Skip to content

Commit

Permalink
Minor fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
nekohasekai committed Nov 5, 2024
1 parent f86d3aa commit c35b14a
Show file tree
Hide file tree
Showing 4 changed files with 20 additions and 80 deletions.
1 change: 0 additions & 1 deletion stack.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@ type StackOptions struct {
Context context.Context
Tun Tun
TunOptions Options
EndpointIndependentNat bool
UDPTimeout time.Duration
Handler Handler
Logger logger.Logger
Expand Down
59 changes: 15 additions & 44 deletions stack_gvisor.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,6 @@ import (
"github.com/sagernet/gvisor/pkg/tcpip/transport/icmp"
"github.com/sagernet/gvisor/pkg/tcpip/transport/tcp"
"github.com/sagernet/gvisor/pkg/tcpip/transport/udp"
"github.com/sagernet/gvisor/pkg/waiter"
"github.com/sagernet/sing/common/bufio"
"github.com/sagernet/sing/common/canceler"
E "github.com/sagernet/sing/common/exceptions"
"github.com/sagernet/sing/common/logger"
M "github.com/sagernet/sing/common/metadata"
Expand All @@ -31,15 +28,14 @@ const WithGVisor = true
const defaultNIC tcpip.NICID = 1

type GVisor struct {
ctx context.Context
tun GVisorTun
endpointIndependentNat bool
udpTimeout time.Duration
broadcastAddr netip.Addr
handler Handler
logger logger.Logger
stack *stack.Stack
endpoint stack.LinkEndpoint
ctx context.Context
tun GVisorTun
udpTimeout time.Duration
broadcastAddr netip.Addr
handler Handler
logger logger.Logger
stack *stack.Stack
endpoint stack.LinkEndpoint
}

type GVisorTun interface {
Expand All @@ -56,13 +52,12 @@ func NewGVisor(
}

gStack := &GVisor{
ctx: options.Context,
tun: gTun,
endpointIndependentNat: options.EndpointIndependentNat,
udpTimeout: options.UDPTimeout,
broadcastAddr: BroadcastAddr(options.TunOptions.Inet4Address),
handler: options.Handler,
logger: options.Logger,
ctx: options.Context,
tun: gTun,
udpTimeout: options.UDPTimeout,
broadcastAddr: BroadcastAddr(options.TunOptions.Inet4Address),
handler: options.Handler,
logger: options.Logger,
}
return gStack, nil
}
Expand Down Expand Up @@ -95,31 +90,7 @@ func (t *GVisor) Start() error {
go t.handler.NewConnectionEx(t.ctx, conn, source, destination, nil)
})
ipStack.SetTransportProtocolHandler(tcp.ProtocolNumber, tcpForwarder.HandlePacket)
if !t.endpointIndependentNat {
udpForwarder := udp.NewForwarder(ipStack, func(r *udp.ForwarderRequest) {
source := M.SocksaddrFrom(AddrFromAddress(r.ID().RemoteAddress), r.ID().RemotePort)
destination := M.SocksaddrFrom(AddrFromAddress(r.ID().LocalAddress), r.ID().LocalPort)
pErr := t.handler.PrepareConnection(N.NetworkUDP, source, destination)
if pErr != nil {
gWriteUnreachable(t.stack, r.Packet(), err)
r.Packet().DecRef()
return
}
var wq waiter.Queue
endpoint, err := r.CreateEndpoint(&wq)
if err != nil {
return
}
go func() {
ctx, conn := canceler.NewPacketConn(t.ctx, bufio.NewUnbindPacketConnWithAddr(gonet.NewUDPConn(&wq, endpoint), destination), t.udpTimeout)
t.handler.NewPacketConnectionEx(ctx, conn, source, destination, nil)
}()
})
ipStack.SetTransportProtocolHandler(udp.ProtocolNumber, udpForwarder.HandlePacket)
} else {
ipStack.SetTransportProtocolHandler(udp.ProtocolNumber, NewUDPForwarder(t.ctx, ipStack, t.handler, t.udpTimeout).HandlePacket)
}

ipStack.SetTransportProtocolHandler(udp.ProtocolNumber, NewUDPForwarder(t.ctx, ipStack, t.handler, t.udpTimeout).HandlePacket)
t.stack = ipStack
t.endpoint = linkEndpoint
return nil
Expand Down
38 changes: 4 additions & 34 deletions stack_mixed.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,25 +5,19 @@ package tun
import (
"github.com/sagernet/gvisor/pkg/buffer"
"github.com/sagernet/gvisor/pkg/tcpip"
"github.com/sagernet/gvisor/pkg/tcpip/adapters/gonet"
gHdr "github.com/sagernet/gvisor/pkg/tcpip/header"
"github.com/sagernet/gvisor/pkg/tcpip/link/channel"
"github.com/sagernet/gvisor/pkg/tcpip/stack"
"github.com/sagernet/gvisor/pkg/tcpip/transport/udp"
"github.com/sagernet/gvisor/pkg/waiter"
"github.com/sagernet/sing-tun/internal/gtcpip/header"
"github.com/sagernet/sing/common/bufio"
"github.com/sagernet/sing/common/canceler"
E "github.com/sagernet/sing/common/exceptions"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
)

type Mixed struct {
*System
endpointIndependentNat bool
stack *stack.Stack
endpoint *channel.Endpoint
stack *stack.Stack
endpoint *channel.Endpoint
}

func NewMixed(
Expand All @@ -34,8 +28,7 @@ func NewMixed(
return nil, err
}
return &Mixed{
System: system.(*System),
endpointIndependentNat: options.EndpointIndependentNat,
System: system.(*System),
}, nil
}

Expand All @@ -49,30 +42,7 @@ func (m *Mixed) Start() error {
if err != nil {
return err
}
if !m.endpointIndependentNat {
udpForwarder := udp.NewForwarder(ipStack, func(r *udp.ForwarderRequest) {
source := M.SocksaddrFrom(AddrFromAddress(r.ID().RemoteAddress), r.ID().RemotePort)
destination := M.SocksaddrFrom(AddrFromAddress(r.ID().LocalAddress), r.ID().LocalPort)
pErr := m.handler.PrepareConnection(N.NetworkUDP, source, destination)
if pErr != nil {
gWriteUnreachable(m.stack, r.Packet(), err)
r.Packet().DecRef()
return
}
var wq waiter.Queue
endpoint, err := r.CreateEndpoint(&wq)
if err != nil {
return
}
go func() {
ctx, conn := canceler.NewPacketConn(m.ctx, bufio.NewUnbindPacketConnWithAddr(gonet.NewUDPConn(&wq, endpoint), destination), m.udpTimeout)
m.handler.NewPacketConnectionEx(ctx, conn, source, destination, nil)
}()
})
ipStack.SetTransportProtocolHandler(udp.ProtocolNumber, udpForwarder.HandlePacket)
} else {
ipStack.SetTransportProtocolHandler(udp.ProtocolNumber, NewUDPForwarder(m.ctx, ipStack, m.handler, m.udpTimeout).HandlePacket)
}
ipStack.SetTransportProtocolHandler(udp.ProtocolNumber, NewUDPForwarder(m.ctx, ipStack, m.handler, m.udpTimeout).HandlePacket)
m.stack = ipStack
m.endpoint = endpoint
go m.tunLoop()
Expand Down
2 changes: 1 addition & 1 deletion stack_system.go
Original file line number Diff line number Diff line change
Expand Up @@ -731,7 +731,7 @@ func (w *systemUDPPacketWriter4) WritePacket(buffer *buf.Buffer, destination M.S
newPacket.Write(buffer.Bytes())
ipHdr := header.IPv4(newPacket.Bytes())
ipHdr.SetTotalLength(uint16(newPacket.Len()))
ipHdr.SetSourceAddress(ipHdr.SourceAddress())
ipHdr.SetDestinationAddress(ipHdr.SourceAddress())
ipHdr.SetSourceAddr(destination.Addr)
udpHdr := header.UDP(ipHdr.Payload())
udpHdr.SetDestinationPort(udpHdr.SourcePort())
Expand Down

0 comments on commit c35b14a

Please sign in to comment.