Skip to content

Commit

Permalink
Add support for ethernet
Browse files Browse the repository at this point in the history
  • Loading branch information
nekohasekai committed Jan 3, 2025
1 parent aa9d9c6 commit 7b60951
Show file tree
Hide file tree
Showing 5 changed files with 64 additions and 24 deletions.
4 changes: 2 additions & 2 deletions stack_gvisor.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,8 +70,8 @@ func (t *GVisor) Start() error {
if err != nil {
return err
}
ipStack.SetTransportProtocolHandler(tcp.ProtocolNumber, NewTCPForwarder(t.ctx, ipStack, t.handler).HandlePacket)
ipStack.SetTransportProtocolHandler(udp.ProtocolNumber, NewUDPForwarder(t.ctx, ipStack, t.handler, t.udpTimeout).HandlePacket)
ipStack.SetTransportProtocolHandler(tcp.ProtocolNumber, NewTCPForwarder(t.ctx, ipStack, false, t.handler).HandlePacket)
ipStack.SetTransportProtocolHandler(udp.ProtocolNumber, NewUDPForwarder(t.ctx, ipStack, false, t.handler, t.udpTimeout).HandlePacket)
t.stack = ipStack
t.endpoint = linkEndpoint
return nil
Expand Down
25 changes: 23 additions & 2 deletions stack_gvisor_lazy.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,15 @@ import (

"github.com/sagernet/gvisor/pkg/tcpip"
"github.com/sagernet/gvisor/pkg/tcpip/adapters/gonet"
"github.com/sagernet/gvisor/pkg/tcpip/header"
"github.com/sagernet/gvisor/pkg/tcpip/stack"
"github.com/sagernet/gvisor/pkg/tcpip/transport/tcp"
"github.com/sagernet/gvisor/pkg/waiter"
"github.com/sagernet/sing/common"
)

type gLazyConn struct {
ethernet bool
tcpConn *gonet.TCPConn
parentCtx context.Context
stack *stack.Stack
Expand All @@ -35,9 +37,13 @@ func (c *gLazyConn) HandshakeContext(ctx context.Context) error {
c.handshakeDone = true
}()
var (
wq waiter.Queue
endpoint tcpip.Endpoint
wq waiter.Queue
endpoint tcpip.Endpoint
linkAddress tcpip.LinkAddress
)
if c.ethernet {
linkAddress = header.Ethernet(c.request.Packet().LinkHeader().Slice()).DestinationAddress()
}
handshakeCtx, cancel := context.WithCancel(ctx)
go func() {
select {
Expand All @@ -54,6 +60,9 @@ func (c *gLazyConn) HandshakeContext(ctx context.Context) error {
c.request.Complete(true)
return gErr
}
if c.ethernet {
endpoint.SetOwner(&EthernetOwner{linkAddress})
}
c.request.Complete(false)
endpoint.SocketOptions().SetKeepAlive(true)
endpoint.SetSockOpt(common.Ptr(tcpip.KeepaliveIdleOption(15 * time.Second)))
Expand All @@ -63,6 +72,18 @@ func (c *gLazyConn) HandshakeContext(ctx context.Context) error {
return nil
}

type EthernetOwner struct {
Destination tcpip.LinkAddress
}

func (o *EthernetOwner) KUID() uint32 {
return 0
}

func (o *EthernetOwner) KGID() uint32 {
return 0
}

func (c *gLazyConn) HandshakeFailure(err error) error {
if c.handshakeDone {
return os.ErrInvalid
Expand Down
11 changes: 7 additions & 4 deletions stack_gvisor_tcp.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,15 +14,17 @@ import (
type TCPForwarder struct {
ctx context.Context
stack *stack.Stack
ethernet bool
handler Handler
forwarder *tcp.Forwarder
}

func NewTCPForwarder(ctx context.Context, stack *stack.Stack, handler Handler) *TCPForwarder {
func NewTCPForwarder(ctx context.Context, stack *stack.Stack, ethernet bool, handler Handler) *TCPForwarder {
forwarder := &TCPForwarder{
ctx: ctx,
stack: stack,
handler: handler,
ctx: ctx,
stack: stack,
ethernet: ethernet,
handler: handler,
}
forwarder.forwarder = tcp.NewForwarder(stack, 0, 1024, forwarder.Forward)
return forwarder
Expand All @@ -41,6 +43,7 @@ func (f *TCPForwarder) Forward(r *tcp.ForwarderRequest) {
return
}
conn := &gLazyConn{
ethernet: f.ethernet,
parentCtx: f.ctx,
stack: f.stack,
request: r,
Expand Down
46 changes: 31 additions & 15 deletions stack_gvisor_udp.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,17 +25,19 @@ import (
)

type UDPForwarder struct {
ctx context.Context
stack *stack.Stack
handler Handler
udpNat *udpnat.Service
ctx context.Context
stack *stack.Stack
ethernet bool
handler Handler
udpNat *udpnat.Service
}

func NewUDPForwarder(ctx context.Context, stack *stack.Stack, handler Handler, timeout time.Duration) *UDPForwarder {
func NewUDPForwarder(ctx context.Context, stack *stack.Stack, ethernet bool, handler Handler, timeout time.Duration) *UDPForwarder {
forwarder := &UDPForwarder{
ctx: ctx,
stack: stack,
handler: handler,
ctx: ctx,
stack: stack,
ethernet: ethernet,
handler: handler,
}
forwarder.udpNat = udpnat.New(handler, forwarder.PreparePacketConnection, timeout, true)
return forwarder
Expand Down Expand Up @@ -77,16 +79,23 @@ func (f *UDPForwarder) PreparePacketConnection(source M.Socksaddr, destination M
sourcePort: source.Port,
sourceNetwork: sourceNetwork,
}
if f.ethernet {
ethHdr := header.Ethernet(userData.(*stack.PacketBuffer).LinkHeader().Slice())
writer.linkSource = ethHdr.SourceAddress()
writer.linkDestination = ethHdr.DestinationAddress()
}
return true, f.ctx, writer, nil
}

type UDPBackWriter struct {
access sync.Mutex
stack *stack.Stack
packet *stack.PacketBuffer
source tcpip.Address
sourcePort uint16
sourceNetwork tcpip.NetworkProtocolNumber
access sync.Mutex
stack *stack.Stack
packet *stack.PacketBuffer
source tcpip.Address
sourcePort uint16
sourceNetwork tcpip.NetworkProtocolNumber
linkSource tcpip.LinkAddress
linkDestination tcpip.LinkAddress
}

func (w *UDPBackWriter) HandshakeSuccess() error {
Expand Down Expand Up @@ -139,7 +148,14 @@ func (w *UDPBackWriter) WritePacket(packetBuffer *buf.Buffer, destination M.Sock
Payload: buffer.MakeWithData(packetBuffer.Bytes()),
})
defer packet.DecRef()

if w.linkSource != "" {
packet.LinkHeader().Consume(header.EthernetMinimumSize)
header.Ethernet(packet.LinkHeader().Slice()).Encode(&header.EthernetFields{
SrcAddr: w.linkDestination,
DstAddr: w.linkSource,
Type: w.sourceNetwork,
})
}
packet.TransportProtocolNumber = header.UDPProtocolNumber
udpHdr := header.UDP(packet.TransportHeader().Push(header.UDPMinimumSize))
pLen := uint16(packet.Size())
Expand Down
2 changes: 1 addition & 1 deletion stack_mixed.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ func (m *Mixed) Start() error {
if err != nil {
return err
}
ipStack.SetTransportProtocolHandler(udp.ProtocolNumber, NewUDPForwarder(m.ctx, ipStack, m.handler, m.udpTimeout).HandlePacket)
ipStack.SetTransportProtocolHandler(udp.ProtocolNumber, NewUDPForwarder(m.ctx, ipStack, false, m.handler, m.udpTimeout).HandlePacket)
m.stack = ipStack
m.endpoint = endpoint
go m.tunLoop()
Expand Down

0 comments on commit 7b60951

Please sign in to comment.