diff --git a/app/app.go b/app/app.go index 59d1095bc..f62913f71 100644 --- a/app/app.go +++ b/app/app.go @@ -75,26 +75,30 @@ func RunWarp(ctx context.Context, opts WarpOptions) error { switch { case opts.Psiphon != nil: // run primary warp on a random tcp port and run psiphon on bind address - warpErr = runWarpWithPsiphon(opts.Bind, endpoints, opts.Psiphon.Country, opts.LogLevel == "debug", ctx) + warpErr = runWarpWithPsiphon(ctx, opts.Bind, endpoints, opts.Psiphon.Country, opts.LogLevel == "debug") case opts.Gool: // run warp in warp - warpErr = runWarpInWarp(opts.Bind, endpoints, opts.LogLevel == "debug", ctx) + warpErr = runWarpInWarp(ctx, opts.Bind, endpoints, opts.LogLevel == "debug") default: // just run primary warp on bindAddress - _, _, warpErr = runWarp(opts.Bind, endpoints, "./primary/wgcf-profile.ini", opts.LogLevel == "debug", true, ctx) + _, _, warpErr = runWarp(ctx, opts.Bind, endpoints, "./primary/wgcf-profile.ini", opts.LogLevel == "debug", true, true) } return warpErr } -func runWarp(bind netip.AddrPort, endpoints []string, confPath string, verbose, startProxy bool, ctx context.Context) (*wiresocks.VirtualTun, int, error) { +func runWarp(ctx context.Context, bind netip.AddrPort, endpoints []string, confPath string, verbose, startProxy bool, trick bool) (*wiresocks.VirtualTun, int, error) { conf, err := wiresocks.ParseConfig(confPath, endpoints[0]) if err != nil { log.Println(err) return nil, 0, err } - tnet, err := wiresocks.StartWireguard(conf.Device, verbose, ctx) + if trick { + conf.Device.Trick = trick + } + + tnet, err := wiresocks.StartWireguard(ctx, conf.Device, verbose) if err != nil { log.Println(err) return nil, 0, err @@ -107,7 +111,7 @@ func runWarp(bind netip.AddrPort, endpoints []string, confPath string, verbose, return tnet, conf.Device.MTU, nil } -func runWarpWithPsiphon(bind netip.AddrPort, endpoints []string, country string, verbose bool, ctx context.Context) error { +func runWarpWithPsiphon(ctx context.Context, bind netip.AddrPort, endpoints []string, country string, verbose bool) error { // make a random bind address for warp warpBindAddress, err := findFreePort("tcp") if err != nil { @@ -115,7 +119,7 @@ func runWarpWithPsiphon(bind netip.AddrPort, endpoints []string, country string, return err } - _, _, err = runWarp(warpBindAddress, endpoints, "./primary/wgcf-profile.ini", verbose, true, ctx) + _, _, err = runWarp(ctx, warpBindAddress, endpoints, "./primary/wgcf-profile.ini", verbose, true, true) if err != nil { return err } @@ -132,9 +136,9 @@ func runWarpWithPsiphon(bind netip.AddrPort, endpoints []string, country string, return nil } -func runWarpInWarp(bind netip.AddrPort, endpoints []string, verbose bool, ctx context.Context) error { +func runWarpInWarp(ctx context.Context, bind netip.AddrPort, endpoints []string, verbose bool) error { // run secondary warp - vTUN, mtu, err := runWarp(netip.AddrPort{}, endpoints, "./secondary/wgcf-profile.ini", verbose, false, ctx) + vTUN, mtu, err := runWarp(ctx, netip.AddrPort{}, endpoints, "./secondary/wgcf-profile.ini", verbose, false, true) if err != nil { return err } @@ -146,13 +150,6 @@ func runWarpInWarp(bind netip.AddrPort, endpoints []string, verbose bool, ctx co return err } addr := endpoints[1] - if addr == "" { - warpEndpoint, err := warp.RandomWarpEndpoint() - if err != nil { - return err - } - addr = warpEndpoint.String() - } err = wiresocks.NewVtunUDPForwarder(virtualEndpointBindAddress.String(), addr, vTUN, mtu+100, ctx) if err != nil { log.Println(err) @@ -160,7 +157,7 @@ func runWarpInWarp(bind netip.AddrPort, endpoints []string, verbose bool, ctx co } // run primary warp - _, _, err = runWarp(bind, []string{virtualEndpointBindAddress.String()}, "./primary/wgcf-profile.ini", verbose, true, ctx) + _, _, err = runWarp(ctx, bind, []string{virtualEndpointBindAddress.String()}, "./primary/wgcf-profile.ini", verbose, true, false) if err != nil { return err } diff --git a/device/device.go b/device/device.go index d3568ed1c..526ed2da5 100644 --- a/device/device.go +++ b/device/device.go @@ -86,6 +86,8 @@ type Device struct { mtu atomic.Int32 } + trick bool + ipcMutex sync.RWMutex closed chan struct{} log *Logger @@ -281,8 +283,9 @@ func (device *Device) SetPrivateKey(sk NoisePrivateKey) error { return nil } -func NewDevice(tunDevice tun.Device, bind conn.Bind, logger *Logger) *Device { +func NewDevice(tunDevice tun.Device, bind conn.Bind, logger *Logger, trick bool) *Device { device := new(Device) + device.trick = trick device.state.state.Store(uint32(deviceStateDown)) device.closed = make(chan struct{}) device.log = logger diff --git a/device/device_test.go b/device/device_test.go index 8b170a266..8aa310b03 100644 --- a/device/device_test.go +++ b/device/device_test.go @@ -166,7 +166,7 @@ func genTestPair(tb testing.TB, realSocket bool) (pair testPair) { if _, ok := tb.(*testing.B); ok && !testing.Verbose() { level = LogLevelError } - p.dev = NewDevice(p.tun.TUN(), binds[i], NewLogger(level, fmt.Sprintf("dev%d: ", i))) + p.dev = NewDevice(p.tun.TUN(), binds[i], NewLogger(level, fmt.Sprintf("dev%d: ", i)), false) if err := p.dev.IpcSet(cfg[i]); err != nil { tb.Errorf("failed to configure device %d: %v", i, err) p.dev.Close() diff --git a/device/noise_test.go b/device/noise_test.go index d6aae0851..83c3420af 100644 --- a/device/noise_test.go +++ b/device/noise_test.go @@ -39,7 +39,7 @@ func randDevice(t *testing.T) *Device { } tun := tuntest.NewChannelTUN() logger := NewLogger(LogLevelError, "") - device := NewDevice(tun.TUN(), conn.NewDefaultBind(), logger) + device := NewDevice(tun.TUN(), conn.NewDefaultBind(), logger, false) device.SetPrivateKey(sk) return device } diff --git a/device/peer.go b/device/peer.go index f0cee9c9d..623104622 100644 --- a/device/peer.go +++ b/device/peer.go @@ -53,6 +53,8 @@ type Peer struct { inbound *autodrainingInboundQueue // sequential ordering of tun writing } + trick bool + cookieGenerator CookieGenerator trieEntries list.List persistentKeepaliveInterval atomic.Uint32 @@ -78,6 +80,7 @@ func (device *Device) NewPeer(pk NoisePublicKey) (*Peer, error) { // create peer peer := new(Peer) + peer.trick = true peer.cookieGenerator.Init(pk) peer.device = device peer.queue.outbound = newAutodrainingOutboundQueue(device) diff --git a/device/send.go b/device/send.go index 26c32b22b..b1ebe0e63 100644 --- a/device/send.go +++ b/device/send.go @@ -117,7 +117,9 @@ func (peer *Peer) sendRandomPackets() { func (peer *Peer) SendKeepalive() { if len(peer.queue.staged) == 0 && peer.isRunning.Load() { // Send some random packets on every keepalive - peer.sendRandomPackets() + if peer.trick { + peer.sendRandomPackets() + } elem := peer.device.NewOutboundElement() elemsContainer := peer.device.GetOutboundElementsContainer() @@ -153,7 +155,9 @@ func (peer *Peer) SendHandshakeInitiation(isRetry bool) error { } // send some random packets on handshake - peer.sendRandomPackets() + if peer.trick { + peer.sendRandomPackets() + } peer.handshake.lastSentHandshake = time.Now() peer.handshake.mutex.Unlock() diff --git a/wiresocks/config.go b/wiresocks/config.go index e83ef6097..1a3a27f9e 100644 --- a/wiresocks/config.go +++ b/wiresocks/config.go @@ -27,6 +27,7 @@ type DeviceConfig struct { DNS []netip.Addr MTU int ListenPort *int + Trick bool } type Configuration struct { diff --git a/wiresocks/wiresocks.go b/wiresocks/wiresocks.go index ba7301b99..1f920a8f0 100644 --- a/wiresocks/wiresocks.go +++ b/wiresocks/wiresocks.go @@ -60,7 +60,7 @@ func createIPCRequest(conf *DeviceConfig) (*DeviceSetting, error) { } // StartWireguard creates a tun interface on netstack given a configuration -func StartWireguard(conf *DeviceConfig, verbose bool, ctx context.Context) (*VirtualTun, error) { +func StartWireguard(ctx context.Context, conf *DeviceConfig, verbose bool) (*VirtualTun, error) { setting, err := createIPCRequest(conf) if err != nil { return nil, err @@ -76,7 +76,7 @@ func StartWireguard(conf *DeviceConfig, verbose bool, ctx context.Context) (*Vir logLevel = device.LogLevelSilent } - dev := device.NewDevice(tun, conn.NewDefaultBind(), device.NewLogger(logLevel, "")) + dev := device.NewDevice(tun, conn.NewDefaultBind(), device.NewLogger(logLevel, ""), conf.Trick) err = dev.IpcSet(setting.ipcRequest) if err != nil { return nil, err