diff --git a/stack_gvisor.go b/stack_gvisor.go index 65bb7bd..8f7dbff 100644 --- a/stack_gvisor.go +++ b/stack_gvisor.go @@ -70,8 +70,8 @@ func (t *GVisor) Start() error { if err != nil { return err } - ipStack.SetTransportProtocolHandler(tcp.ProtocolNumber, NewTCPForwarder(t.ctx, ipStack, t.handler).HandlePacket) - ipStack.SetTransportProtocolHandler(udp.ProtocolNumber, NewUDPForwarder(t.ctx, ipStack, t.handler, t.udpTimeout).HandlePacket) + ipStack.SetTransportProtocolHandler(tcp.ProtocolNumber, NewTCPForwarder(t.ctx, ipStack, false, t.handler).HandlePacket) + ipStack.SetTransportProtocolHandler(udp.ProtocolNumber, NewUDPForwarder(t.ctx, ipStack, false, t.handler, t.udpTimeout).HandlePacket) t.stack = ipStack t.endpoint = linkEndpoint return nil diff --git a/stack_gvisor_lazy.go b/stack_gvisor_lazy.go index 59c993b..8536e25 100644 --- a/stack_gvisor_lazy.go +++ b/stack_gvisor_lazy.go @@ -10,6 +10,7 @@ import ( "github.com/sagernet/gvisor/pkg/tcpip" "github.com/sagernet/gvisor/pkg/tcpip/adapters/gonet" + "github.com/sagernet/gvisor/pkg/tcpip/header" "github.com/sagernet/gvisor/pkg/tcpip/stack" "github.com/sagernet/gvisor/pkg/tcpip/transport/tcp" "github.com/sagernet/gvisor/pkg/waiter" @@ -17,6 +18,7 @@ import ( ) type gLazyConn struct { + ethernet bool tcpConn *gonet.TCPConn parentCtx context.Context stack *stack.Stack @@ -35,9 +37,13 @@ func (c *gLazyConn) HandshakeContext(ctx context.Context) error { c.handshakeDone = true }() var ( - wq waiter.Queue - endpoint tcpip.Endpoint + wq waiter.Queue + endpoint tcpip.Endpoint + linkAddress tcpip.LinkAddress ) + if c.ethernet { + linkAddress = header.Ethernet(c.request.Packet().LinkHeader().Slice()).DestinationAddress() + } handshakeCtx, cancel := context.WithCancel(ctx) go func() { select { @@ -54,6 +60,9 @@ func (c *gLazyConn) HandshakeContext(ctx context.Context) error { c.request.Complete(true) return gErr } + if c.ethernet { + endpoint.SetOwner(&EthernetOwner{linkAddress}) + } c.request.Complete(false) endpoint.SocketOptions().SetKeepAlive(true) endpoint.SetSockOpt(common.Ptr(tcpip.KeepaliveIdleOption(15 * time.Second))) @@ -63,6 +72,18 @@ func (c *gLazyConn) HandshakeContext(ctx context.Context) error { return nil } +type EthernetOwner struct { + Destination tcpip.LinkAddress +} + +func (o *EthernetOwner) KUID() uint32 { + return 0 +} + +func (o *EthernetOwner) KGID() uint32 { + return 0 +} + func (c *gLazyConn) HandshakeFailure(err error) error { if c.handshakeDone { return os.ErrInvalid diff --git a/stack_gvisor_tcp.go b/stack_gvisor_tcp.go index 33cf40e..986a455 100644 --- a/stack_gvisor_tcp.go +++ b/stack_gvisor_tcp.go @@ -14,15 +14,17 @@ import ( type TCPForwarder struct { ctx context.Context stack *stack.Stack + ethernet bool handler Handler forwarder *tcp.Forwarder } -func NewTCPForwarder(ctx context.Context, stack *stack.Stack, handler Handler) *TCPForwarder { +func NewTCPForwarder(ctx context.Context, stack *stack.Stack, ethernet bool, handler Handler) *TCPForwarder { forwarder := &TCPForwarder{ - ctx: ctx, - stack: stack, - handler: handler, + ctx: ctx, + stack: stack, + ethernet: ethernet, + handler: handler, } forwarder.forwarder = tcp.NewForwarder(stack, 0, 1024, forwarder.Forward) return forwarder @@ -41,6 +43,7 @@ func (f *TCPForwarder) Forward(r *tcp.ForwarderRequest) { return } conn := &gLazyConn{ + ethernet: f.ethernet, parentCtx: f.ctx, stack: f.stack, request: r, diff --git a/stack_gvisor_udp.go b/stack_gvisor_udp.go index 3027798..6288a05 100644 --- a/stack_gvisor_udp.go +++ b/stack_gvisor_udp.go @@ -25,17 +25,19 @@ import ( ) type UDPForwarder struct { - ctx context.Context - stack *stack.Stack - handler Handler - udpNat *udpnat.Service + ctx context.Context + stack *stack.Stack + ethernet bool + handler Handler + udpNat *udpnat.Service } -func NewUDPForwarder(ctx context.Context, stack *stack.Stack, handler Handler, timeout time.Duration) *UDPForwarder { +func NewUDPForwarder(ctx context.Context, stack *stack.Stack, ethernet bool, handler Handler, timeout time.Duration) *UDPForwarder { forwarder := &UDPForwarder{ - ctx: ctx, - stack: stack, - handler: handler, + ctx: ctx, + stack: stack, + ethernet: ethernet, + handler: handler, } forwarder.udpNat = udpnat.New(handler, forwarder.PreparePacketConnection, timeout, true) return forwarder @@ -77,16 +79,23 @@ func (f *UDPForwarder) PreparePacketConnection(source M.Socksaddr, destination M sourcePort: source.Port, sourceNetwork: sourceNetwork, } + if f.ethernet { + ethHdr := header.Ethernet(userData.(*stack.PacketBuffer).LinkHeader().Slice()) + writer.linkSource = ethHdr.SourceAddress() + writer.linkDestination = ethHdr.DestinationAddress() + } return true, f.ctx, writer, nil } type UDPBackWriter struct { - access sync.Mutex - stack *stack.Stack - packet *stack.PacketBuffer - source tcpip.Address - sourcePort uint16 - sourceNetwork tcpip.NetworkProtocolNumber + access sync.Mutex + stack *stack.Stack + packet *stack.PacketBuffer + source tcpip.Address + sourcePort uint16 + sourceNetwork tcpip.NetworkProtocolNumber + linkSource tcpip.LinkAddress + linkDestination tcpip.LinkAddress } func (w *UDPBackWriter) HandshakeSuccess() error { @@ -139,7 +148,14 @@ func (w *UDPBackWriter) WritePacket(packetBuffer *buf.Buffer, destination M.Sock Payload: buffer.MakeWithData(packetBuffer.Bytes()), }) defer packet.DecRef() - + if w.linkSource != "" { + packet.LinkHeader().Consume(header.EthernetMinimumSize) + header.Ethernet(packet.LinkHeader().Slice()).Encode(&header.EthernetFields{ + SrcAddr: w.linkDestination, + DstAddr: w.linkSource, + Type: w.sourceNetwork, + }) + } packet.TransportProtocolNumber = header.UDPProtocolNumber udpHdr := header.UDP(packet.TransportHeader().Push(header.UDPMinimumSize)) pLen := uint16(packet.Size()) diff --git a/stack_mixed.go b/stack_mixed.go index 9293fb8..c18fba9 100644 --- a/stack_mixed.go +++ b/stack_mixed.go @@ -42,7 +42,7 @@ func (m *Mixed) Start() error { if err != nil { return err } - ipStack.SetTransportProtocolHandler(udp.ProtocolNumber, NewUDPForwarder(m.ctx, ipStack, m.handler, m.udpTimeout).HandlePacket) + ipStack.SetTransportProtocolHandler(udp.ProtocolNumber, NewUDPForwarder(m.ctx, ipStack, false, m.handler, m.udpTimeout).HandlePacket) m.stack = ipStack m.endpoint = endpoint go m.tunLoop()