From 3b5b396d06f727c63ecb8e0e2c1cfaeb5beae3a2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Wed, 23 Oct 2024 13:41:50 +0800 Subject: [PATCH] Minor fixes --- go.mod | 2 +- go.sum | 4 ++-- stack.go | 1 - stack_gvisor.go | 59 +++++++++++++------------------------------------ stack_mixed.go | 38 ++++--------------------------- stack_system.go | 2 +- 6 files changed, 23 insertions(+), 83 deletions(-) diff --git a/go.mod b/go.mod index 8876e33..5b5eeb2 100644 --- a/go.mod +++ b/go.mod @@ -9,7 +9,7 @@ require ( github.com/sagernet/gvisor v0.0.0-20241021032506-a4324256e4a3 github.com/sagernet/netlink v0.0.0-20240612041022-b9a21c07ac6a github.com/sagernet/nftables v0.3.0-beta.4 - github.com/sagernet/sing v0.5.0-rc.4.0.20241021153852-cf58af1a4627 + github.com/sagernet/sing v0.5.0-rc.4.0.20241023053048-94f058276959 go4.org/netipx v0.0.0-20231129151722-fdeea329fbba golang.org/x/exp v0.0.0-20240613232115-7f521ea00fb8 golang.org/x/net v0.26.0 diff --git a/go.sum b/go.sum index 2e9bf29..ea4a1f7 100644 --- a/go.sum +++ b/go.sum @@ -22,8 +22,8 @@ github.com/sagernet/netlink v0.0.0-20240612041022-b9a21c07ac6a h1:ObwtHN2VpqE0ZN github.com/sagernet/netlink v0.0.0-20240612041022-b9a21c07ac6a/go.mod h1:xLnfdiJbSp8rNqYEdIW/6eDO4mVoogml14Bh2hSiFpM= github.com/sagernet/nftables v0.3.0-beta.4 h1:kbULlAwAC3jvdGAC1P5Fa3GSxVwQJibNenDW2zaXr8I= github.com/sagernet/nftables v0.3.0-beta.4/go.mod h1:OQXAjvjNGGFxaTgVCSTRIhYB5/llyVDeapVoENYBDS8= -github.com/sagernet/sing v0.5.0-rc.4.0.20241021153852-cf58af1a4627 h1:wWRmqHPHfyWRPUIGsjAmYshvXF+pC/csl9pAmo/vGpo= -github.com/sagernet/sing v0.5.0-rc.4.0.20241021153852-cf58af1a4627/go.mod h1:ARkL0gM13/Iv5VCZmci/NuoOlePoIsW0m7BWfln/Hak= +github.com/sagernet/sing v0.5.0-rc.4.0.20241023053048-94f058276959 h1:8BzTt5cU8h6HK4CcRq1UQHKsgUi942GjO0by/ntFZIs= +github.com/sagernet/sing v0.5.0-rc.4.0.20241023053048-94f058276959/go.mod h1:ARkL0gM13/Iv5VCZmci/NuoOlePoIsW0m7BWfln/Hak= github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= github.com/vishvananda/netns v0.0.4 h1:Oeaw1EM2JMxD51g9uhtC0D7erkIjgmj8+JZc26m1YX8= github.com/vishvananda/netns v0.0.4/go.mod h1:SpkAiCQRtJ6TvvxPnOSyH3BMl6unz3xZlaprSwhNNJM= diff --git a/stack.go b/stack.go index 5ba18e4..f664aab 100644 --- a/stack.go +++ b/stack.go @@ -23,7 +23,6 @@ type StackOptions struct { Context context.Context Tun Tun TunOptions Options - EndpointIndependentNat bool UDPTimeout time.Duration Handler Handler Logger logger.Logger diff --git a/stack_gvisor.go b/stack_gvisor.go index 6c5c27f..60af865 100644 --- a/stack_gvisor.go +++ b/stack_gvisor.go @@ -17,9 +17,6 @@ import ( "github.com/sagernet/gvisor/pkg/tcpip/transport/icmp" "github.com/sagernet/gvisor/pkg/tcpip/transport/tcp" "github.com/sagernet/gvisor/pkg/tcpip/transport/udp" - "github.com/sagernet/gvisor/pkg/waiter" - "github.com/sagernet/sing/common/bufio" - "github.com/sagernet/sing/common/canceler" E "github.com/sagernet/sing/common/exceptions" "github.com/sagernet/sing/common/logger" M "github.com/sagernet/sing/common/metadata" @@ -31,15 +28,14 @@ const WithGVisor = true const defaultNIC tcpip.NICID = 1 type GVisor struct { - ctx context.Context - tun GVisorTun - endpointIndependentNat bool - udpTimeout time.Duration - broadcastAddr netip.Addr - handler Handler - logger logger.Logger - stack *stack.Stack - endpoint stack.LinkEndpoint + ctx context.Context + tun GVisorTun + udpTimeout time.Duration + broadcastAddr netip.Addr + handler Handler + logger logger.Logger + stack *stack.Stack + endpoint stack.LinkEndpoint } type GVisorTun interface { @@ -56,13 +52,12 @@ func NewGVisor( } gStack := &GVisor{ - ctx: options.Context, - tun: gTun, - endpointIndependentNat: options.EndpointIndependentNat, - udpTimeout: options.UDPTimeout, - broadcastAddr: BroadcastAddr(options.TunOptions.Inet4Address), - handler: options.Handler, - logger: options.Logger, + ctx: options.Context, + tun: gTun, + udpTimeout: options.UDPTimeout, + broadcastAddr: BroadcastAddr(options.TunOptions.Inet4Address), + handler: options.Handler, + logger: options.Logger, } return gStack, nil } @@ -95,31 +90,7 @@ func (t *GVisor) Start() error { go t.handler.NewConnectionEx(t.ctx, conn, source, destination, nil) }) ipStack.SetTransportProtocolHandler(tcp.ProtocolNumber, tcpForwarder.HandlePacket) - if !t.endpointIndependentNat { - udpForwarder := udp.NewForwarder(ipStack, func(r *udp.ForwarderRequest) { - source := M.SocksaddrFrom(AddrFromAddress(r.ID().RemoteAddress), r.ID().RemotePort) - destination := M.SocksaddrFrom(AddrFromAddress(r.ID().LocalAddress), r.ID().LocalPort) - pErr := t.handler.PrepareConnection(N.NetworkUDP, source, destination) - if pErr != nil { - gWriteUnreachable(t.stack, r.Packet(), err) - r.Packet().DecRef() - return - } - var wq waiter.Queue - endpoint, err := r.CreateEndpoint(&wq) - if err != nil { - return - } - go func() { - ctx, conn := canceler.NewPacketConn(t.ctx, bufio.NewUnbindPacketConnWithAddr(gonet.NewUDPConn(&wq, endpoint), destination), t.udpTimeout) - t.handler.NewPacketConnectionEx(ctx, conn, source, destination, nil) - }() - }) - ipStack.SetTransportProtocolHandler(udp.ProtocolNumber, udpForwarder.HandlePacket) - } else { - ipStack.SetTransportProtocolHandler(udp.ProtocolNumber, NewUDPForwarder(t.ctx, ipStack, t.handler, t.udpTimeout).HandlePacket) - } - + ipStack.SetTransportProtocolHandler(udp.ProtocolNumber, NewUDPForwarder(t.ctx, ipStack, t.handler, t.udpTimeout).HandlePacket) t.stack = ipStack t.endpoint = linkEndpoint return nil diff --git a/stack_mixed.go b/stack_mixed.go index 8388cb9..3b7314e 100644 --- a/stack_mixed.go +++ b/stack_mixed.go @@ -5,25 +5,19 @@ package tun import ( "github.com/sagernet/gvisor/pkg/buffer" "github.com/sagernet/gvisor/pkg/tcpip" - "github.com/sagernet/gvisor/pkg/tcpip/adapters/gonet" gHdr "github.com/sagernet/gvisor/pkg/tcpip/header" "github.com/sagernet/gvisor/pkg/tcpip/link/channel" "github.com/sagernet/gvisor/pkg/tcpip/stack" "github.com/sagernet/gvisor/pkg/tcpip/transport/udp" - "github.com/sagernet/gvisor/pkg/waiter" "github.com/sagernet/sing-tun/internal/gtcpip/header" "github.com/sagernet/sing/common/bufio" - "github.com/sagernet/sing/common/canceler" E "github.com/sagernet/sing/common/exceptions" - M "github.com/sagernet/sing/common/metadata" - N "github.com/sagernet/sing/common/network" ) type Mixed struct { *System - endpointIndependentNat bool - stack *stack.Stack - endpoint *channel.Endpoint + stack *stack.Stack + endpoint *channel.Endpoint } func NewMixed( @@ -34,8 +28,7 @@ func NewMixed( return nil, err } return &Mixed{ - System: system.(*System), - endpointIndependentNat: options.EndpointIndependentNat, + System: system.(*System), }, nil } @@ -49,30 +42,7 @@ func (m *Mixed) Start() error { if err != nil { return err } - if !m.endpointIndependentNat { - udpForwarder := udp.NewForwarder(ipStack, func(r *udp.ForwarderRequest) { - source := M.SocksaddrFrom(AddrFromAddress(r.ID().RemoteAddress), r.ID().RemotePort) - destination := M.SocksaddrFrom(AddrFromAddress(r.ID().LocalAddress), r.ID().LocalPort) - pErr := m.handler.PrepareConnection(N.NetworkUDP, source, destination) - if pErr != nil { - gWriteUnreachable(m.stack, r.Packet(), err) - r.Packet().DecRef() - return - } - var wq waiter.Queue - endpoint, err := r.CreateEndpoint(&wq) - if err != nil { - return - } - go func() { - ctx, conn := canceler.NewPacketConn(m.ctx, bufio.NewUnbindPacketConnWithAddr(gonet.NewUDPConn(&wq, endpoint), destination), m.udpTimeout) - m.handler.NewPacketConnectionEx(ctx, conn, source, destination, nil) - }() - }) - ipStack.SetTransportProtocolHandler(udp.ProtocolNumber, udpForwarder.HandlePacket) - } else { - ipStack.SetTransportProtocolHandler(udp.ProtocolNumber, NewUDPForwarder(m.ctx, ipStack, m.handler, m.udpTimeout).HandlePacket) - } + ipStack.SetTransportProtocolHandler(udp.ProtocolNumber, NewUDPForwarder(m.ctx, ipStack, m.handler, m.udpTimeout).HandlePacket) m.stack = ipStack m.endpoint = endpoint go m.tunLoop() diff --git a/stack_system.go b/stack_system.go index 08f2ba3..d7cd02e 100644 --- a/stack_system.go +++ b/stack_system.go @@ -731,7 +731,7 @@ func (w *systemUDPPacketWriter4) WritePacket(buffer *buf.Buffer, destination M.S newPacket.Write(buffer.Bytes()) ipHdr := header.IPv4(newPacket.Bytes()) ipHdr.SetTotalLength(uint16(newPacket.Len())) - ipHdr.SetSourceAddress(ipHdr.SourceAddress()) + ipHdr.SetDestinationAddress(ipHdr.SourceAddress()) ipHdr.SetSourceAddr(destination.Addr) udpHdr := header.UDP(ipHdr.Payload()) udpHdr.SetDestinationPort(udpHdr.SourcePort())