Skip to content

Commit

Permalink
Fix GSO support
Browse files Browse the repository at this point in the history
  • Loading branch information
nekohasekai committed Dec 11, 2023
1 parent 6122891 commit 1cd52bf
Show file tree
Hide file tree
Showing 6 changed files with 157 additions and 66 deletions.
8 changes: 0 additions & 8 deletions stack.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,14 +28,6 @@ type StackOptions struct {
InterfaceFinder control.InterfaceFinder
}

func (o *StackOptions) BufferSize() uint32 {
if o.TunOptions.GSO {
return o.TunOptions.GSOMaxSize
} else {
return o.TunOptions.MTU
}
}

func NewStack(
stack string,
options StackOptions,
Expand Down
79 changes: 56 additions & 23 deletions stack_mixed.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ func (m *Mixed) Start() error {
if err != nil {
return err
}
endpoint := channel.New(1024, m.mtu, "")
endpoint := channel.New(1024, uint32(m.mtu), "")
ipStack, err := newGVisorStack(endpoint)
if err != nil {
return err
Expand Down Expand Up @@ -95,8 +95,16 @@ func (m *Mixed) tunLoop() {
m.wintunLoop(winTun)
return
}

if batchTUN, isBatchTUN := m.tun.(BatchTUN); isBatchTUN {
batchSize := batchTUN.BatchSize()
if batchSize > 1 {
m.batchLoop(batchTUN, batchSize)
return
}
}
frontHeadroom := m.tun.FrontHeadroom()
packetBuffer := make([]byte, m.bufferSize+frontHeadroom+PacketOffset)
packetBuffer := make([]byte, m.mtu+frontHeadroom+PacketOffset)
for {
n, err := m.tun.Read(packetBuffer[frontHeadroom:])
if err != nil {
Expand All @@ -110,17 +118,7 @@ func (m *Mixed) tunLoop() {
}
rawPacket := packetBuffer[:frontHeadroom+n]
packet := packetBuffer[frontHeadroom+PacketOffset : frontHeadroom+n]
switch ipVersion := packet[0] >> 4; ipVersion {
case 4:
err = m.processIPv4(rawPacket, packet)
case 6:
err = m.processIPv6(rawPacket, packet)
default:
err = E.New("ip: unknown version: ", ipVersion)
}
if err != nil {
m.logger.Trace(err)
}
m.processPacket(rawPacket, packet)
}
}

Expand All @@ -134,18 +132,53 @@ func (m *Mixed) wintunLoop(winTun WinTun) {
release()
continue
}
switch ipVersion := packet[0] >> 4; ipVersion {
case 4:
err = m.processIPv4(packet, packet)
case 6:
err = m.processIPv6(packet, packet)
default:
err = E.New("ip: unknown version: ", ipVersion)
}
m.processPacket(packet, packet)
release()
}
}

func (m *Mixed) batchLoop(linuxTUN BatchTUN, batchSize int) {
frontHeadroom := m.tun.FrontHeadroom()
packetBuffers := make([][]byte, batchSize)
for i := range packetBuffers {
packetBuffers[i] = make([]byte, m.mtu+frontHeadroom+PacketOffset)
}
packetSizes := make([]int, batchSize)
for {
n, err := linuxTUN.BatchRead(packetBuffers, packetSizes)
if err != nil {
m.logger.Trace(err)
if E.IsClosed(err) {
return
}
m.logger.Error(E.Cause(err, "batch read packet"))
}
release()
if n == 0 {
continue
}
for i := 0; i < n; i++ {
packetBuffer := packetBuffers[i][:packetSizes[i]]
if n < clashtcpip.IPv4PacketMinLength {
continue
}
rawPacket := packetBuffer[:frontHeadroom+n]
packet := packetBuffer[frontHeadroom+PacketOffset : frontHeadroom+n]
m.processPacket(rawPacket, packet)
}
}
}

func (m *Mixed) processPacket(rawPacket []byte, packet []byte) {
var err error
switch ipVersion := packet[0] >> 4; ipVersion {
case 4:
err = m.processIPv4(rawPacket, packet)
case 6:
err = m.processIPv6(rawPacket, packet)
default:
err = E.New("ip: unknown version: ", ipVersion)
}
if err != nil {
m.logger.Trace(err)
}
}

Expand Down
104 changes: 70 additions & 34 deletions stack_system.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,7 @@ type System struct {
ctx context.Context
tun Tun
tunName string
mtu uint32
bufferSize int
mtu int
handler Handler
logger logger.Logger
inet4Prefixes []netip.Prefix
Expand Down Expand Up @@ -57,8 +56,7 @@ func NewSystem(options StackOptions) (Stack, error) {
ctx: options.Context,
tun: options.Tun,
tunName: options.TunOptions.Name,
mtu: options.TunOptions.MTU,
bufferSize: int(options.BufferSize()),
mtu: int(options.TunOptions.MTU),
udpTimeout: options.UDPTimeout,
handler: options.Handler,
logger: options.Logger,
Expand Down Expand Up @@ -147,8 +145,15 @@ func (s *System) tunLoop() {
s.wintunLoop(winTun)
return
}
if batchTUN, isBatchTUN := s.tun.(BatchTUN); isBatchTUN {
batchSize := batchTUN.BatchSize()
if batchSize > 1 {
s.batchLoop(batchTUN, batchSize)
return
}
}
frontHeadroom := s.tun.FrontHeadroom()
packetBuffer := make([]byte, s.bufferSize+frontHeadroom+PacketOffset)
packetBuffer := make([]byte, s.mtu+frontHeadroom+PacketOffset)
for {
n, err := s.tun.Read(packetBuffer[frontHeadroom:])
if err != nil {
Expand All @@ -162,17 +167,7 @@ func (s *System) tunLoop() {
}
rawPacket := packetBuffer[:frontHeadroom+n]
packet := packetBuffer[frontHeadroom+PacketOffset : frontHeadroom+n]
switch ipVersion := packet[0] >> 4; ipVersion {
case 4:
err = s.processIPv4(rawPacket, packet)
case 6:
err = s.processIPv6(rawPacket, packet)
default:
err = E.New("ip: unknown version: ", ipVersion)
}
if err != nil {
s.logger.Trace(err)
}
s.processPacket(rawPacket, packet)
}
}

Expand All @@ -186,18 +181,53 @@ func (s *System) wintunLoop(winTun WinTun) {
release()
continue
}
switch ipVersion := packet[0] >> 4; ipVersion {
case 4:
err = s.processIPv4(packet, packet)
case 6:
err = s.processIPv6(packet, packet)
default:
err = E.New("ip: unknown version: ", ipVersion)
}
s.processPacket(packet, packet)
release()
}
}

func (s *System) batchLoop(linuxTUN BatchTUN, batchSize int) {
frontHeadroom := s.tun.FrontHeadroom()
packetBuffers := make([][]byte, batchSize)
for i := range packetBuffers {
packetBuffers[i] = make([]byte, s.mtu+frontHeadroom+PacketOffset)
}
packetSizes := make([]int, batchSize)
for {
n, err := linuxTUN.BatchRead(packetBuffers, packetSizes)
if err != nil {
s.logger.Trace(err)
if E.IsClosed(err) {
return
}
s.logger.Error(E.Cause(err, "batch read packet"))
}
release()
if n == 0 {
continue
}
for i := 0; i < n; i++ {
packetBuffer := packetBuffers[i][:packetSizes[i]]
if n < clashtcpip.IPv4PacketMinLength {
continue
}
rawPacket := packetBuffer[:frontHeadroom+n]
packet := packetBuffer[frontHeadroom+PacketOffset : frontHeadroom+n]
s.processPacket(rawPacket, packet)
}
}
}

func (s *System) processPacket(rawPacket []byte, packet []byte) {
var err error
switch ipVersion := packet[0] >> 4; ipVersion {
case 4:
err = s.processIPv4(rawPacket, packet)
case 6:
err = s.processIPv6(rawPacket, packet)
default:
err = E.New("ip: unknown version: ", ipVersion)
}
if err != nil {
s.logger.Trace(err)
}
}

Expand Down Expand Up @@ -354,7 +384,7 @@ func (s *System) processIPv4UDP(rawPacket []byte, packet clashtcpip.IPv4Packet,
headerLen := packet.HeaderLen() + clashtcpip.UDPHeaderSize
headerCopy := make([]byte, headerLen)
copy(headerCopy, packet[:headerLen])
return &systemUDPPacketWriter4{s.tun, s.tun.FrontHeadroom(), headerCopy, source}
return &systemUDPPacketWriter4{s.tun, s.tun.FrontHeadroom() + PacketOffset, headerCopy, source}
})
return nil
}
Expand All @@ -380,7 +410,7 @@ func (s *System) processIPv6UDP(rawPacket []byte, packet clashtcpip.IPv6Packet,
headerLen := len(packet) - int(header.Length()) + clashtcpip.UDPHeaderSize
headerCopy := make([]byte, headerLen)
copy(headerCopy, packet[:headerLen])
return &systemUDPPacketWriter6{s.tun, s.tun.FrontHeadroom(), headerCopy, source}
return &systemUDPPacketWriter6{s.tun, s.tun.FrontHeadroom() + PacketOffset, headerCopy, source}
})
return nil
}
Expand Down Expand Up @@ -421,8 +451,7 @@ type systemUDPPacketWriter4 struct {
func (w *systemUDPPacketWriter4) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error {
newPacket := buf.NewSize(w.frontHeadroom + len(w.header) + buffer.Len())
defer newPacket.Release()
newPacket.WriteZeroN(w.frontHeadroom)
newPacket.Advance(w.frontHeadroom)
newPacket.Resize(w.frontHeadroom, 0)
newPacket.Write(w.header)
newPacket.Write(buffer.Bytes())
ipHdr := clashtcpip.IPv4Packet(newPacket.Bytes())
Expand All @@ -435,7 +464,11 @@ func (w *systemUDPPacketWriter4) WritePacket(buffer *buf.Buffer, destination M.S
udpHdr.SetLength(uint16(buffer.Len() + clashtcpip.UDPHeaderSize))
udpHdr.ResetChecksum(ipHdr.PseudoSum())
ipHdr.ResetChecksum()
newPacket.Advance(-w.frontHeadroom)
if PacketOffset > 0 {
newPacket.ExtendHeader(PacketOffset)[3] = syscall.AF_INET
} else {
newPacket.Advance(-w.frontHeadroom)
}
return common.Error(w.tun.Write(newPacket.Bytes()))
}

Expand All @@ -449,8 +482,7 @@ type systemUDPPacketWriter6 struct {
func (w *systemUDPPacketWriter6) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error {
newPacket := buf.NewSize(w.frontHeadroom + len(w.header) + buffer.Len())
defer newPacket.Release()
newPacket.WriteZeroN(w.frontHeadroom)
newPacket.Advance(w.frontHeadroom)
newPacket.Resize(w.frontHeadroom, 0)
newPacket.Write(w.header)
newPacket.Write(buffer.Bytes())
ipHdr := clashtcpip.IPv6Packet(newPacket.Bytes())
Expand All @@ -463,6 +495,10 @@ func (w *systemUDPPacketWriter6) WritePacket(buffer *buf.Buffer, destination M.S
udpHdr.SetSourcePort(destination.Port)
udpHdr.SetLength(udpLen)
udpHdr.ResetChecksum(ipHdr.PseudoSum())
newPacket.Advance(-w.frontHeadroom)
if PacketOffset > 0 {
newPacket.ExtendHeader(PacketOffset)[3] = syscall.AF_INET6
} else {
newPacket.Advance(-w.frontHeadroom)
}
return common.Error(w.tun.Write(newPacket.Bytes()))
}
5 changes: 5 additions & 0 deletions tun.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,11 @@ type WinTun interface {
ReadPacket() ([]byte, func(), error)
}

type BatchTUN interface {
BatchSize() int
BatchRead(buffers [][]byte, readN []int) (n int, err error)
}

type Options struct {
Name string
Inet4Address []netip.Prefix
Expand Down
25 changes: 25 additions & 0 deletions tun_linux.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@ import (
"golang.org/x/sys/unix"
)

var _ BatchTUN = (*NativeTun)(nil)

type NativeTun struct {
tunFd int
tunFile *os.File
Expand Down Expand Up @@ -119,6 +121,29 @@ func (t *NativeTun) Write(p []byte) (n int, err error) {
return t.tunFile.Write(p)
}

func (t *NativeTun) BatchSize() int {
if !t.gsoEnabled {
return 1
}
return idealBatchSize
}

func (t *NativeTun) BatchRead(buffers [][]byte, readN []int) (n int, err error) {
if t.gsoEnabled {
n, err = t.tunFile.Read(t.gsoBuffer)
if err != nil {
return
}
n, err = handleVirtioRead(t.gsoBuffer[:n], buffers, readN, 0)
if err != nil {
return
}
return
} else {
return 0, os.ErrInvalid
}
}

var controlPath string

func init() {
Expand Down
2 changes: 1 addition & 1 deletion tun_linux_offload.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ import (

const (
tcpFlagsOffset = 13
idealBatchSize = 1
idealBatchSize = 128
)

const (
Expand Down

0 comments on commit 1cd52bf

Please sign in to comment.