From 8355396af66c0397862e1fbd4dd4a09eae6c23ff Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Fri, 22 Nov 2024 16:36:35 +0800 Subject: [PATCH] Copy UDP GSO support from tailscale --- stack_system.go | 2 +- tun.go | 6 +- tun_darwin.go | 2 +- tun_linux.go | 154 ++++-- tun_linux_flags.go | 16 +- tun_linux_gvisor.go | 2 +- tun_linux_offload.go | 768 ----------------------------- tun_linux_offload_errors.go | 5 - tun_offload.go | 229 +++++++++ tun_offload_errors.go | 10 + tun_offload_linux.go | 937 ++++++++++++++++++++++++++++++++++++ 11 files changed, 1312 insertions(+), 819 deletions(-) delete mode 100644 tun_linux_offload.go delete mode 100644 tun_linux_offload_errors.go create mode 100644 tun_offload.go create mode 100644 tun_offload_errors.go create mode 100644 tun_offload_linux.go diff --git a/stack_system.go b/stack_system.go index dde0664..e7b68dd 100644 --- a/stack_system.go +++ b/stack_system.go @@ -244,7 +244,7 @@ func (s *System) batchLoop(linuxTUN LinuxTUN, batchSize int) { } } if len(writeBuffers) > 0 { - err = linuxTUN.BatchWrite(writeBuffers, s.frontHeadroom) + _, err = linuxTUN.BatchWrite(writeBuffers, s.frontHeadroom) if err != nil { s.logger.Trace(E.Cause(err, "batch write packet")) } diff --git a/tun.go b/tun.go index 6c7020b..3fe654f 100644 --- a/tun.go +++ b/tun.go @@ -1,7 +1,6 @@ package tun import ( - "github.com/sagernet/sing/common/control" "io" "net" "net/netip" @@ -9,6 +8,7 @@ import ( "strconv" "strings" + "github.com/sagernet/sing/common/control" F "github.com/sagernet/sing/common/format" "github.com/sagernet/sing/common/logger" M "github.com/sagernet/sing/common/metadata" @@ -39,7 +39,9 @@ type LinuxTUN interface { N.FrontHeadroom BatchSize() int BatchRead(buffers [][]byte, offset int, readN []int) (n int, err error) - BatchWrite(buffers [][]byte, offset int) error + BatchWrite(buffers [][]byte, offset int) (n int, err error) + DisableUDPGRO() + DisableTCPGRO() TXChecksumOffload() bool } diff --git a/tun_darwin.go b/tun_darwin.go index e3f39e5..b7f8aa7 100644 --- a/tun_darwin.go +++ b/tun_darwin.go @@ -3,13 +3,13 @@ package tun import ( "errors" "fmt" - "github.com/sagernet/sing-tun/internal/gtcpip/header" "net" "net/netip" "os" "syscall" "unsafe" + "github.com/sagernet/sing-tun/internal/gtcpip/header" "github.com/sagernet/sing/common" "github.com/sagernet/sing/common/buf" "github.com/sagernet/sing/common/bufio" diff --git a/tun_linux.go b/tun_linux.go index 09c9262..841bbea 100644 --- a/tun_linux.go +++ b/tun_linux.go @@ -2,6 +2,7 @@ package tun import ( "errors" + "fmt" "math/rand" "net" "net/netip" @@ -35,13 +36,15 @@ type NativeTun struct { interfaceCallback *list.Element[DefaultInterfaceUpdateCallback] options Options ruleIndex6 []int - gsoEnabled bool - gsoBuffer []byte + readAccess sync.Mutex + writeAccess sync.Mutex + vnetHdr bool + writeBuffer []byte gsoToWrite []int - gsoReadAccess sync.Mutex - tcpGROAccess sync.Mutex - tcp4GROTable *tcpGROTable - tcp6GROTable *tcpGROTable + tcpGROTable *tcpGROTable + udpGroAccess sync.Mutex + udpGROTable *udpGROTable + gro groDisablementFlags txChecksumOffload bool } @@ -81,20 +84,23 @@ func New(options Options) (Tun, error) { } func (t *NativeTun) FrontHeadroom() int { - if t.gsoEnabled { + if t.vnetHdr { return virtioNetHdrLen } return 0 } func (t *NativeTun) Read(p []byte) (n int, err error) { - if t.gsoEnabled { - n, err = t.tunFile.Read(t.gsoBuffer) + if t.vnetHdr { + n, err = t.tunFile.Read(t.writeBuffer) if err != nil { + if errors.Is(err, syscall.EBADFD) { + err = os.ErrClosed + } return } 
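The LinuxTUN interface change above means BatchWrite now reports the number of bytes written alongside an error that may be os.ErrClosed (the fd is gone) or a join of per-packet failures. A minimal caller sketch under those assumptions; flushBatch and its logging are illustrative, not part of this patch:

func flushBatch(tun LinuxTUN, buffers [][]byte, headroom int) (stop bool) {
	// BatchWrite attempts every packet in the batch; n totals the
	// bytes actually written across the packets that succeeded.
	n, err := tun.BatchWrite(buffers, headroom)
	if errors.Is(err, os.ErrClosed) {
		return true // the TUN device was closed; stop the write loop
	}
	if err != nil {
		// err may wrap several per-packet errors via errors.Join;
		// the remaining packets in the batch were still attempted.
		log.Printf("batch write: %v (wrote %d bytes)", err, n)
	}
	return false
}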
var sizes [1]int - n, err = handleVirtioRead(t.gsoBuffer[:n], [][]byte{p}, sizes[:], 0) + n, err = handleVirtioRead(t.writeBuffer[:n], [][]byte{p}, sizes[:], 0) if err != nil { return } @@ -108,9 +114,50 @@ func (t *NativeTun) Read(p []byte) (n int, err error) { } } +// handleVirtioRead splits in into bufs, leaving offset bytes at the front of +// each buffer. It mutates sizes to reflect the size of each element of bufs, +// and returns the number of packets read. +func handleVirtioRead(in []byte, bufs [][]byte, sizes []int, offset int) (int, error) { + var hdr virtioNetHdr + err := hdr.decode(in) + if err != nil { + return 0, err + } + in = in[virtioNetHdrLen:] + + options, err := hdr.toGSOOptions() + if err != nil { + return 0, err + } + + // Don't trust HdrLen from the kernel as it can be equal to the length + // of the entire first packet when the kernel is handling it as part of a + // FORWARD path. Instead, parse the transport header length and add it onto + // CsumStart, which is synonymous for IP header length. + if options.GSOType == GSOUDPL4 { + options.HdrLen = options.CsumStart + 8 + } else if options.GSOType != GSONone { + if len(in) <= int(options.CsumStart+12) { + return 0, errors.New("packet is too short") + } + + tcpHLen := uint16(in[options.CsumStart+12] >> 4 * 4) + if tcpHLen < 20 || tcpHLen > 60 { + // A TCP header must be between 20 and 60 bytes in length. + return 0, fmt.Errorf("tcp header len is invalid: %d", tcpHLen) + } + options.HdrLen = options.CsumStart + tcpHLen + } + + return GSOSplit(in, options, bufs, sizes, offset) +} + func (t *NativeTun) Write(p []byte) (n int, err error) { - if t.gsoEnabled { - err = t.BatchWrite([][]byte{p}, virtioNetHdrLen) + if t.vnetHdr { + buffer := buf.Get(virtioNetHdrLen + len(p)) + copy(buffer[virtioNetHdrLen:], p) + _, err = t.BatchWrite([][]byte{buffer}, virtioNetHdrLen) + buf.Put(buffer) if err != nil { return } @@ -121,7 +168,7 @@ func (t *NativeTun) Write(p []byte) (n int, err error) { } func (t *NativeTun) WriteVectorised(buffers []*buf.Buffer) error { - if t.gsoEnabled { + if t.vnetHdr { n := buf.LenMulti(buffers) buffer := buf.NewSize(virtioNetHdrLen + n) buffer.Truncate(virtioNetHdrLen) @@ -135,7 +182,7 @@ func (t *NativeTun) WriteVectorised(buffers []*buf.Buffer) error { } func (t *NativeTun) BatchSize() int { - if !t.gsoEnabled { + if !t.vnetHdr { return 1 } /* // Not works on some devices: https://github.com/SagerNet/sing-box/issues/1605 @@ -147,36 +194,67 @@ func (t *NativeTun) BatchSize() int { return idealBatchSize } +// DisableUDPGRO disables UDP GRO if it is enabled. See the GRODevice interface +// for cases where it should be called. +func (t *NativeTun) DisableUDPGRO() { + t.writeAccess.Lock() + t.gro.disableUDPGRO() + t.writeAccess.Unlock() +} + +// DisableTCPGRO disables TCP GRO if it is enabled. See the GRODevice interface +// for cases where it should be called. 
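The HdrLen fixup in handleVirtioRead above reduces to fixed arithmetic: CsumStart is synonymous with the IP header length, a UDP header is always 8 bytes, and a TCP header length comes from the data-offset nibble in units of 32-bit words. The same rule as a standalone sketch; transportHdrLen is a hypothetical helper, not part of this patch:

// transportHdrLen derives HdrLen from CsumStart rather than trusting
// the kernel-supplied value, mirroring handleVirtioRead.
func transportHdrLen(pkt []byte, csumStart uint16, isUDP bool) (uint16, error) {
	if isUDP {
		return csumStart + 8, nil // a UDP header is always 8 bytes
	}
	if len(pkt) <= int(csumStart+12) {
		return 0, errors.New("packet is too short")
	}
	// The TCP data offset is the high nibble of byte 12, in 32-bit words.
	tcpHLen := uint16(pkt[csumStart+12]>>4) * 4
	if tcpHLen < 20 || tcpHLen > 60 {
		return 0, fmt.Errorf("tcp header len is invalid: %d", tcpHLen)
	}
	return csumStart + tcpHLen, nil
}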
+func (t *NativeTun) DisableTCPGRO() { + t.writeAccess.Lock() + t.gro.disableTCPGRO() + t.writeAccess.Unlock() +} + func (t *NativeTun) BatchRead(buffers [][]byte, offset int, readN []int) (n int, err error) { - t.gsoReadAccess.Lock() - defer t.gsoReadAccess.Unlock() - n, err = t.tunFile.Read(t.gsoBuffer) + t.readAccess.Lock() + defer t.readAccess.Unlock() + n, err = t.tunFile.Read(t.writeBuffer) if err != nil { return } - return handleVirtioRead(t.gsoBuffer[:n], buffers, readN, offset) + return handleVirtioRead(t.writeBuffer[:n], buffers, readN, offset) } -func (t *NativeTun) BatchWrite(buffers [][]byte, offset int) error { - t.tcpGROAccess.Lock() +func (t *NativeTun) BatchWrite(buffers [][]byte, offset int) (int, error) { + t.writeAccess.Lock() defer func() { - t.tcp4GROTable.reset() - t.tcp6GROTable.reset() - t.tcpGROAccess.Unlock() + t.tcpGROTable.reset() + t.udpGROTable.reset() + t.writeAccess.Unlock() }() + var ( + errs error + total int + ) t.gsoToWrite = t.gsoToWrite[:0] - err := handleGRO(buffers, offset, t.tcp4GROTable, t.tcp6GROTable, &t.gsoToWrite) - if err != nil { - return err + if t.vnetHdr { + err := handleGRO(buffers, offset, t.tcpGROTable, t.udpGROTable, t.gro, &t.gsoToWrite) + if err != nil { + return 0, err + } + offset -= virtioNetHdrLen + } else { + for i := range buffers { + t.gsoToWrite = append(t.gsoToWrite, i) + } } - offset -= virtioNetHdrLen - for _, bufferIndex := range t.gsoToWrite { - _, err = t.tunFile.Write(buffers[bufferIndex][offset:]) + for _, toWrite := range t.gsoToWrite { + n, err := t.tunFile.Write(buffers[toWrite][offset:]) + if errors.Is(err, syscall.EBADFD) { + return total, os.ErrClosed + } if err != nil { - return err + errs = errors.Join(errs, err) + } else { + total += n } } - return nil + return total, errs } var controlPath string @@ -262,10 +340,14 @@ func (t *NativeTun) configure(tunLink netlink.Link) error { if err != nil { return err } - t.gsoEnabled = true - t.gsoBuffer = make([]byte, virtioNetHdrLen+int(gsoMaxSize)) - t.tcp4GROTable = newTCPGROTable() - t.tcp6GROTable = newTCPGROTable() + t.vnetHdr = true + t.writeBuffer = make([]byte, virtioNetHdrLen+int(gsoMaxSize)) + t.tcpGROTable = newTCPGROTable() + t.udpGROTable = newUDPGROTable() + err = setUDPOffload(t.tunFd) + if err != nil { + t.gro.disableUDPGRO() + } } var rxChecksumOffload bool @@ -280,7 +362,7 @@ func (t *NativeTun) configure(tunLink netlink.Link) error { if err != nil { return err } - if err == nil && !txChecksumOffload { + if !txChecksumOffload { err = setChecksumOffload(t.options.Name, unix.ETHTOOL_STXCSUM) if err != nil { return err diff --git a/tun_linux_flags.go b/tun_linux_flags.go index 1b84baa..53fff08 100644 --- a/tun_linux_flags.go +++ b/tun_linux_flags.go @@ -12,6 +12,12 @@ import ( "golang.org/x/sys/unix" ) +const ( + // TODO: support TSO with ECN bits + tunTCPOffloads = unix.TUN_F_CSUM | unix.TUN_F_TSO4 | unix.TUN_F_TSO6 + tunUDPOffloads = unix.TUN_F_USO4 | unix.TUN_F_USO6 +) + func checkVNETHDREnabled(fd int, name string) (bool, error) { ifr, err := unix.NewIfreq(name) if err != nil { @@ -25,17 +31,17 @@ func checkVNETHDREnabled(fd int, name string) (bool, error) { } func setTCPOffload(fd int) error { - const ( - // TODO: support TSO with ECN bits - tunOffloads = unix.TUN_F_CSUM | unix.TUN_F_TSO4 | unix.TUN_F_TSO6 - ) - err := unix.IoctlSetInt(fd, unix.TUNSETOFFLOAD, tunOffloads) + err := unix.IoctlSetInt(fd, unix.TUNSETOFFLOAD, tunTCPOffloads) if err != nil { return E.Cause(os.NewSyscallError("TUNSETOFFLOAD", err), "enable offload") } return nil } +func 
setUDPOffload(fd int) error { + return unix.IoctlSetInt(fd, unix.TUNSETOFFLOAD, tunTCPOffloads|tunUDPOffloads) +} + type ifreqData struct { ifrName [unix.IFNAMSIZ]byte ifrData uintptr diff --git a/tun_linux_gvisor.go b/tun_linux_gvisor.go index 1edeab1..064f740 100644 --- a/tun_linux_gvisor.go +++ b/tun_linux_gvisor.go @@ -10,7 +10,7 @@ import ( var _ GVisorTun = (*NativeTun)(nil) func (t *NativeTun) NewEndpoint() (stack.LinkEndpoint, error) { - if t.gsoEnabled { + if t.vnetHdr { return fdbased.New(&fdbased.Options{ FDs: []int{t.tunFd}, MTU: t.options.MTU, diff --git a/tun_linux_offload.go b/tun_linux_offload.go deleted file mode 100644 index 488d3ff..0000000 --- a/tun_linux_offload.go +++ /dev/null @@ -1,768 +0,0 @@ -//go:build linux - -/* SPDX-License-Identifier: MIT - * - * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. - */ - -package tun - -import ( - "bytes" - "encoding/binary" - "errors" - "fmt" - "io" - "unsafe" - - "github.com/sagernet/sing-tun/internal/gtcpip/checksum" - E "github.com/sagernet/sing/common/exceptions" - - "golang.org/x/sys/unix" -) - -const ( - gsoMaxSize = 65536 - tcpFlagsOffset = 13 - idealBatchSize = 128 -) - -const ( - tcpFlagFIN uint8 = 0x01 - tcpFlagPSH uint8 = 0x08 - tcpFlagACK uint8 = 0x10 -) - -// virtioNetHdr is defined in the kernel in include/uapi/linux/virtio_net.h. The -// kernel symbol is virtio_net_hdr. -type virtioNetHdr struct { - flags uint8 - gsoType uint8 - hdrLen uint16 - gsoSize uint16 - csumStart uint16 - csumOffset uint16 -} - -func (v *virtioNetHdr) decode(b []byte) error { - if len(b) < virtioNetHdrLen { - return io.ErrShortBuffer - } - copy(unsafe.Slice((*byte)(unsafe.Pointer(v)), virtioNetHdrLen), b[:virtioNetHdrLen]) - return nil -} - -func (v *virtioNetHdr) encode(b []byte) error { - if len(b) < virtioNetHdrLen { - return io.ErrShortBuffer - } - copy(b[:virtioNetHdrLen], unsafe.Slice((*byte)(unsafe.Pointer(v)), virtioNetHdrLen)) - return nil -} - -const ( - // virtioNetHdrLen is the length in bytes of virtioNetHdr. This matches the - // shape of the C ABI for its kernel counterpart -- sizeof(virtio_net_hdr). - virtioNetHdrLen = int(unsafe.Sizeof(virtioNetHdr{})) -) - -// flowKey represents the key for a flow. -type flowKey struct { - srcAddr, dstAddr [16]byte - srcPort, dstPort uint16 - rxAck uint32 // varying ack values should not be coalesced. Treat them as separate flows. -} - -// tcpGROTable holds flow and coalescing information for the purposes of GRO. -type tcpGROTable struct { - itemsByFlow map[flowKey][]tcpGROItem - itemsPool [][]tcpGROItem -} - -func newTCPGROTable() *tcpGROTable { - t := &tcpGROTable{ - itemsByFlow: make(map[flowKey][]tcpGROItem, idealBatchSize), - itemsPool: make([][]tcpGROItem, idealBatchSize), - } - for i := range t.itemsPool { - t.itemsPool[i] = make([]tcpGROItem, 0, idealBatchSize) - } - return t -} - -func newFlowKey(pkt []byte, srcAddr, dstAddr, tcphOffset int) flowKey { - key := flowKey{} - addrSize := dstAddr - srcAddr - copy(key.srcAddr[:], pkt[srcAddr:dstAddr]) - copy(key.dstAddr[:], pkt[dstAddr:dstAddr+addrSize]) - key.srcPort = binary.BigEndian.Uint16(pkt[tcphOffset:]) - key.dstPort = binary.BigEndian.Uint16(pkt[tcphOffset+2:]) - key.rxAck = binary.BigEndian.Uint32(pkt[tcphOffset+8:]) - return key -} - -// lookupOrInsert looks up a flow for the provided packet and metadata, -// returning the packets found for the flow, or inserting a new one if none -// is found. 
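setUDPOffload above fails on kernels whose TUN driver rejects the USO flags, which is why configure() downgrades to TCP-only offloads and disables UDP GRO instead of treating the failure as fatal. The probe order in isolation, as a sketch reusing the constants this patch defines:

// probeOffloads enables TCP segmentation offload first, then tries to
// add UDP (USO) on top. When the second ioctl fails, TCP offloads are
// still active and the caller must keep UDP GRO disabled.
func probeOffloads(fd int) (udpOK bool, err error) {
	if err = unix.IoctlSetInt(fd, unix.TUNSETOFFLOAD, tunTCPOffloads); err != nil {
		return false, err // no offload support at all
	}
	if unix.IoctlSetInt(fd, unix.TUNSETOFFLOAD, tunTCPOffloads|tunUDPOffloads) != nil {
		return false, nil // TCP-only kernel: disable UDP GRO
	}
	return true, nil
}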
-func (t *tcpGROTable) lookupOrInsert(pkt []byte, srcAddrOffset, dstAddrOffset, tcphOffset, tcphLen, bufsIndex int) ([]tcpGROItem, bool) { - key := newFlowKey(pkt, srcAddrOffset, dstAddrOffset, tcphOffset) - items, ok := t.itemsByFlow[key] - if ok { - return items, ok - } - // TODO: insert() performs another map lookup. This could be rearranged to avoid. - t.insert(pkt, srcAddrOffset, dstAddrOffset, tcphOffset, tcphLen, bufsIndex) - return nil, false -} - -// insert an item in the table for the provided packet and packet metadata. -func (t *tcpGROTable) insert(pkt []byte, srcAddrOffset, dstAddrOffset, tcphOffset, tcphLen, bufsIndex int) { - key := newFlowKey(pkt, srcAddrOffset, dstAddrOffset, tcphOffset) - item := tcpGROItem{ - key: key, - bufsIndex: uint16(bufsIndex), - gsoSize: uint16(len(pkt[tcphOffset+tcphLen:])), - iphLen: uint8(tcphOffset), - tcphLen: uint8(tcphLen), - sentSeq: binary.BigEndian.Uint32(pkt[tcphOffset+4:]), - pshSet: pkt[tcphOffset+tcpFlagsOffset]&tcpFlagPSH != 0, - } - items, ok := t.itemsByFlow[key] - if !ok { - items = t.newItems() - } - items = append(items, item) - t.itemsByFlow[key] = items -} - -func (t *tcpGROTable) updateAt(item tcpGROItem, i int) { - items, _ := t.itemsByFlow[item.key] - items[i] = item -} - -func (t *tcpGROTable) deleteAt(key flowKey, i int) { - items, _ := t.itemsByFlow[key] - items = append(items[:i], items[i+1:]...) - t.itemsByFlow[key] = items -} - -// tcpGROItem represents bookkeeping data for a TCP packet during the lifetime -// of a GRO evaluation across a vector of packets. -type tcpGROItem struct { - key flowKey - sentSeq uint32 // the sequence number - bufsIndex uint16 // the index into the original bufs slice - numMerged uint16 // the number of packets merged into this item - gsoSize uint16 // payload size - iphLen uint8 // ip header len - tcphLen uint8 // tcp header len - pshSet bool // psh flag is set -} - -func (t *tcpGROTable) newItems() []tcpGROItem { - var items []tcpGROItem - items, t.itemsPool = t.itemsPool[len(t.itemsPool)-1], t.itemsPool[:len(t.itemsPool)-1] - return items -} - -func (t *tcpGROTable) reset() { - for k, items := range t.itemsByFlow { - items = items[:0] - t.itemsPool = append(t.itemsPool, items) - delete(t.itemsByFlow, k) - } -} - -// canCoalesce represents the outcome of checking if two TCP packets are -// candidates for coalescing. -type canCoalesce int - -const ( - coalescePrepend canCoalesce = -1 - coalesceUnavailable canCoalesce = 0 - coalesceAppend canCoalesce = 1 -) - -// tcpPacketsCanCoalesce evaluates if pkt can be coalesced with the packet -// described by item. This function makes considerations that match the kernel's -// GRO self tests, which can be found in tools/testing/selftests/net/gro.c. 
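The itemsPool mechanics in this table (newItems/reset, carried over unchanged into the replacement file) form a small free list: zero-length slices are handed out per flow and pushed back truncated, so steady-state batches reuse their backing arrays instead of allocating. The pattern in isolation, as a generic sketch rather than this file's code:

// slicePool hands out zero-length slices that keep their capacity.
type slicePool[T any] struct{ free [][]T }

// get assumes the pool is never exhausted; the GRO tables guarantee
// this by sizing the pool to idealBatchSize, the maximum number of
// flows a single batch can produce.
func (p *slicePool[T]) get() []T {
	n := len(p.free)
	s := p.free[n-1]
	p.free = p.free[:n-1]
	return s
}

// put truncates to length zero so the next get() reuses the array.
func (p *slicePool[T]) put(s []T) { p.free = append(p.free, s[:0]) }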
-func tcpPacketsCanCoalesce(pkt []byte, iphLen, tcphLen uint8, seq uint32, pshSet bool, gsoSize uint16, item tcpGROItem, bufs [][]byte, bufsOffset int) canCoalesce { - pktTarget := bufs[item.bufsIndex][bufsOffset:] - if tcphLen != item.tcphLen { - // cannot coalesce with unequal tcp options len - return coalesceUnavailable - } - if tcphLen > 20 { - if !bytes.Equal(pkt[iphLen+20:iphLen+tcphLen], pktTarget[item.iphLen+20:iphLen+tcphLen]) { - // cannot coalesce with unequal tcp options - return coalesceUnavailable - } - } - if pkt[0]>>4 == 6 { - if pkt[0] != pktTarget[0] || pkt[1]>>4 != pktTarget[1]>>4 { - // cannot coalesce with unequal Traffic class values - return coalesceUnavailable - } - if pkt[7] != pktTarget[7] { - // cannot coalesce with unequal Hop limit values - return coalesceUnavailable - } - } else { - if pkt[1] != pktTarget[1] { - // cannot coalesce with unequal ToS values - return coalesceUnavailable - } - if pkt[6]>>5 != pktTarget[6]>>5 { - // cannot coalesce with unequal DF or reserved bits. MF is checked - // further up the stack. - return coalesceUnavailable - } - if pkt[8] != pktTarget[8] { - // cannot coalesce with unequal TTL values - return coalesceUnavailable - } - } - // seq adjacency - lhsLen := item.gsoSize - lhsLen += item.numMerged * item.gsoSize - if seq == item.sentSeq+uint32(lhsLen) { // pkt aligns following item from a seq num perspective - if item.pshSet { - // We cannot append to a segment that has the PSH flag set, PSH - // can only be set on the final segment in a reassembled group. - return coalesceUnavailable - } - if len(pktTarget[iphLen+tcphLen:])%int(item.gsoSize) != 0 { - // A smaller than gsoSize packet has been appended previously. - // Nothing can come after a smaller packet on the end. - return coalesceUnavailable - } - if gsoSize > item.gsoSize { - // We cannot have a larger packet following a smaller one. - return coalesceUnavailable - } - return coalesceAppend - } else if seq+uint32(gsoSize) == item.sentSeq { // pkt aligns in front of item from a seq num perspective - if pshSet { - // We cannot prepend with a segment that has the PSH flag set, PSH - // can only be set on the final segment in a reassembled group. - return coalesceUnavailable - } - if gsoSize < item.gsoSize { - // We cannot have a larger packet following a smaller one. - return coalesceUnavailable - } - if gsoSize > item.gsoSize && item.numMerged > 0 { - // There's at least one previous merge, and we're larger than all - // previous. This would put multiple smaller packets on the end. - return coalesceUnavailable - } - return coalescePrepend - } - return coalesceUnavailable -} - -func tcpChecksumValid(pkt []byte, iphLen uint8, isV6 bool) bool { - srcAddrAt := ipv4SrcAddrOffset - addrSize := 4 - if isV6 { - srcAddrAt = ipv6SrcAddrOffset - addrSize = 16 - } - tcpTotalLen := uint16(len(pkt) - int(iphLen)) - tcpCSumNoFold := pseudoHeaderChecksumNoFold(unix.IPPROTO_TCP, pkt[srcAddrAt:srcAddrAt+addrSize], pkt[srcAddrAt+addrSize:srcAddrAt+addrSize*2], tcpTotalLen) - return ^checksumFold(pkt[iphLen:], tcpCSumNoFold) == 0 -} - -// coalesceResult represents the result of attempting to coalesce two TCP -// packets. -type coalesceResult int - -const ( - coalesceInsufficientCap coalesceResult = iota - coalescePSHEnding - coalesceItemInvalidCSum - coalescePktInvalidCSum - coalesceSuccess -) - -// coalesceTCPPackets attempts to coalesce pkt with the packet described by -// item, returning the outcome. 
This function may swap bufs elements in the -// event of a prepend as item's bufs index is already being tracked for writing -// to a Device. -func coalesceTCPPackets(mode canCoalesce, pkt []byte, pktBuffsIndex int, gsoSize uint16, seq uint32, pshSet bool, item *tcpGROItem, bufs [][]byte, bufsOffset int, isV6 bool) coalesceResult { - var pktHead []byte // the packet that will end up at the front - headersLen := item.iphLen + item.tcphLen - coalescedLen := len(bufs[item.bufsIndex][bufsOffset:]) + len(pkt) - int(headersLen) - - // Copy data - if mode == coalescePrepend { - pktHead = pkt - if cap(pkt)-bufsOffset < coalescedLen { - // We don't want to allocate a new underlying array if capacity is - // too small. - return coalesceInsufficientCap - } - if pshSet { - return coalescePSHEnding - } - if item.numMerged == 0 { - if !tcpChecksumValid(bufs[item.bufsIndex][bufsOffset:], item.iphLen, isV6) { - return coalesceItemInvalidCSum - } - } - if !tcpChecksumValid(pkt, item.iphLen, isV6) { - return coalescePktInvalidCSum - } - item.sentSeq = seq - extendBy := coalescedLen - len(pktHead) - bufs[pktBuffsIndex] = append(bufs[pktBuffsIndex], make([]byte, extendBy)...) - copy(bufs[pktBuffsIndex][bufsOffset+len(pkt):], bufs[item.bufsIndex][bufsOffset+int(headersLen):]) - // Flip the slice headers in bufs as part of prepend. The index of item - // is already being tracked for writing. - bufs[item.bufsIndex], bufs[pktBuffsIndex] = bufs[pktBuffsIndex], bufs[item.bufsIndex] - } else { - pktHead = bufs[item.bufsIndex][bufsOffset:] - if cap(pktHead)-bufsOffset < coalescedLen { - // We don't want to allocate a new underlying array if capacity is - // too small. - return coalesceInsufficientCap - } - if item.numMerged == 0 { - if !tcpChecksumValid(bufs[item.bufsIndex][bufsOffset:], item.iphLen, isV6) { - return coalesceItemInvalidCSum - } - } - if !tcpChecksumValid(pkt, item.iphLen, isV6) { - return coalescePktInvalidCSum - } - if pshSet { - // We are appending a segment with PSH set. - item.pshSet = pshSet - pktHead[item.iphLen+tcpFlagsOffset] |= tcpFlagPSH - } - extendBy := len(pkt) - int(headersLen) - bufs[item.bufsIndex] = append(bufs[item.bufsIndex], make([]byte, extendBy)...) - copy(bufs[item.bufsIndex][bufsOffset+len(pktHead):], pkt[headersLen:]) - } - - if gsoSize > item.gsoSize { - item.gsoSize = gsoSize - } - - item.numMerged++ - return coalesceSuccess -} - -const ( - ipv4FlagMoreFragments uint8 = 0x20 -) - -const ( - ipv4SrcAddrOffset = 12 - ipv6SrcAddrOffset = 8 - maxUint16 = 1<<16 - 1 -) - -type tcpGROResult int - -const ( - tcpGROResultNoop tcpGROResult = iota - tcpGROResultTableInsert - tcpGROResultCoalesced -) - -// tcpGRO evaluates the TCP packet at pktI in bufs for coalescing with -// existing packets tracked in table. It returns a tcpGROResultNoop when no -// action was taken, tcpGROResultTableInsert when the evaluated packet was -// inserted into table, and tcpGROResultCoalesced when the evaluated packet was -// coalesced with another packet in table. -func tcpGRO(bufs [][]byte, offset int, pktI int, table *tcpGROTable, isV6 bool) tcpGROResult { - pkt := bufs[pktI][offset:] - if len(pkt) > maxUint16 { - // A valid IPv4 or IPv6 packet will never exceed this. 
- return tcpGROResultNoop - } - iphLen := int((pkt[0] & 0x0F) * 4) - if isV6 { - iphLen = 40 - ipv6HPayloadLen := int(binary.BigEndian.Uint16(pkt[4:])) - if ipv6HPayloadLen != len(pkt)-iphLen { - return tcpGROResultNoop - } - } else { - totalLen := int(binary.BigEndian.Uint16(pkt[2:])) - if totalLen != len(pkt) { - return tcpGROResultNoop - } - } - if len(pkt) < iphLen { - return tcpGROResultNoop - } - tcphLen := int((pkt[iphLen+12] >> 4) * 4) - if tcphLen < 20 || tcphLen > 60 { - return tcpGROResultNoop - } - if len(pkt) < iphLen+tcphLen { - return tcpGROResultNoop - } - if !isV6 { - if pkt[6]&ipv4FlagMoreFragments != 0 || pkt[6]<<3 != 0 || pkt[7] != 0 { - // no GRO support for fragmented segments for now - return tcpGROResultNoop - } - } - tcpFlags := pkt[iphLen+tcpFlagsOffset] - var pshSet bool - // not a candidate if any non-ACK flags (except PSH+ACK) are set - if tcpFlags != tcpFlagACK { - if pkt[iphLen+tcpFlagsOffset] != tcpFlagACK|tcpFlagPSH { - return tcpGROResultNoop - } - pshSet = true - } - gsoSize := uint16(len(pkt) - tcphLen - iphLen) - // not a candidate if payload len is 0 - if gsoSize < 1 { - return tcpGROResultNoop - } - seq := binary.BigEndian.Uint32(pkt[iphLen+4:]) - srcAddrOffset := ipv4SrcAddrOffset - addrLen := 4 - if isV6 { - srcAddrOffset = ipv6SrcAddrOffset - addrLen = 16 - } - items, existing := table.lookupOrInsert(pkt, srcAddrOffset, srcAddrOffset+addrLen, iphLen, tcphLen, pktI) - if !existing { - return tcpGROResultNoop - } - for i := len(items) - 1; i >= 0; i-- { - // In the best case of packets arriving in order iterating in reverse is - // more efficient if there are multiple items for a given flow. This - // also enables a natural table.deleteAt() in the - // coalesceItemInvalidCSum case without the need for index tracking. - // This algorithm makes a best effort to coalesce in the event of - // unordered packets, where pkt may land anywhere in items from a - // sequence number perspective, however once an item is inserted into - // the table it is never compared across other items later. - item := items[i] - can := tcpPacketsCanCoalesce(pkt, uint8(iphLen), uint8(tcphLen), seq, pshSet, gsoSize, item, bufs, offset) - if can != coalesceUnavailable { - result := coalesceTCPPackets(can, pkt, pktI, gsoSize, seq, pshSet, &item, bufs, offset, isV6) - switch result { - case coalesceSuccess: - table.updateAt(item, i) - return tcpGROResultCoalesced - case coalesceItemInvalidCSum: - // delete the item with an invalid csum - table.deleteAt(item.key, i) - case coalescePktInvalidCSum: - // no point in inserting an item that we can't coalesce - return tcpGROResultNoop - default: - } - } - } - // failed to coalesce with any other packets; store the item in the flow - table.insert(pkt, srcAddrOffset, srcAddrOffset+addrLen, iphLen, tcphLen, pktI) - return tcpGROResultTableInsert -} - -func isTCP4NoIPOptions(b []byte) bool { - if len(b) < 40 { - return false - } - if b[0]>>4 != 4 { - return false - } - if b[0]&0x0F != 5 { - return false - } - if b[9] != unix.IPPROTO_TCP { - return false - } - return true -} - -func isTCP6NoEH(b []byte) bool { - if len(b) < 60 { - return false - } - if b[0]>>4 != 6 { - return false - } - if b[6] != unix.IPPROTO_TCP { - return false - } - return true -} - -// applyCoalesceAccounting updates bufs to account for coalescing based on the -// metadata found in table. 
-func applyCoalesceAccounting(bufs [][]byte, offset int, table *tcpGROTable, isV6 bool) error { - for _, items := range table.itemsByFlow { - for _, item := range items { - if item.numMerged > 0 { - hdr := virtioNetHdr{ - flags: unix.VIRTIO_NET_HDR_F_NEEDS_CSUM, // this turns into CHECKSUM_PARTIAL in the skb - hdrLen: uint16(item.iphLen + item.tcphLen), - gsoSize: item.gsoSize, - csumStart: uint16(item.iphLen), - csumOffset: 16, - } - pkt := bufs[item.bufsIndex][offset:] - - // Recalculate the total len (IPv4) or payload len (IPv6). - // Recalculate the (IPv4) header checksum. - if isV6 { - hdr.gsoType = unix.VIRTIO_NET_HDR_GSO_TCPV6 - binary.BigEndian.PutUint16(pkt[4:], uint16(len(pkt))-uint16(item.iphLen)) // set new IPv6 header payload len - } else { - hdr.gsoType = unix.VIRTIO_NET_HDR_GSO_TCPV4 - pkt[10], pkt[11] = 0, 0 - binary.BigEndian.PutUint16(pkt[2:], uint16(len(pkt))) // set new total length - iphCSum := ^checksumFold(pkt[:item.iphLen], 0) // compute IPv4 header checksum - binary.BigEndian.PutUint16(pkt[10:], iphCSum) // set IPv4 header checksum field - } - err := hdr.encode(bufs[item.bufsIndex][offset-virtioNetHdrLen:]) - if err != nil { - return err - } - - // Calculate the pseudo header checksum and place it at the TCP - // checksum offset. Downstream checksum offloading will combine - // this with computation of the tcp header and payload checksum. - addrLen := 4 - addrOffset := ipv4SrcAddrOffset - if isV6 { - addrLen = 16 - addrOffset = ipv6SrcAddrOffset - } - srcAddrAt := offset + addrOffset - srcAddr := bufs[item.bufsIndex][srcAddrAt : srcAddrAt+addrLen] - dstAddr := bufs[item.bufsIndex][srcAddrAt+addrLen : srcAddrAt+addrLen*2] - psum := pseudoHeaderChecksumNoFold(unix.IPPROTO_TCP, srcAddr, dstAddr, uint16(len(pkt)-int(item.iphLen))) - binary.BigEndian.PutUint16(pkt[hdr.csumStart+hdr.csumOffset:], checksumFold([]byte{}, psum)) - } else { - hdr := virtioNetHdr{} - err := hdr.encode(bufs[item.bufsIndex][offset-virtioNetHdrLen:]) - if err != nil { - return err - } - } - } - } - return nil -} - -// handleGRO evaluates bufs for GRO, and writes the indices of the resulting -// packets into toWrite. toWrite, tcp4Table, and tcp6Table should initially be -// empty (but non-nil), and are passed in to save allocs as the caller may reset -// and recycle them across vectors of packets. -func handleGRO(bufs [][]byte, offset int, tcp4Table, tcp6Table *tcpGROTable, toWrite *[]int) error { - for i := range bufs { - if offset < virtioNetHdrLen || offset > len(bufs[i])-1 { - return errors.New("invalid offset") - } - var result tcpGROResult - switch { - case isTCP4NoIPOptions(bufs[i][offset:]): // ipv4 packets w/IP options do not coalesce - result = tcpGRO(bufs, offset, i, tcp4Table, false) - case isTCP6NoEH(bufs[i][offset:]): // ipv6 packets w/extension headers do not coalesce - result = tcpGRO(bufs, offset, i, tcp6Table, true) - } - switch result { - case tcpGROResultNoop: - hdr := virtioNetHdr{} - err := hdr.encode(bufs[i][offset-virtioNetHdrLen:]) - if err != nil { - return err - } - fallthrough - case tcpGROResultTableInsert: - *toWrite = append(*toWrite, i) - } - } - err4 := applyCoalesceAccounting(bufs, offset, tcp4Table, false) - err6 := applyCoalesceAccounting(bufs, offset, tcp6Table, true) - return E.Errors(err4, err6) -} - -// tcpTSO splits packets from in into outBuffs, writing the size of each -// element into sizes. It returns the number of buffers populated, and/or an -// error. 
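applyCoalesceAccounting below leans on CHECKSUM_PARTIAL semantics: a folded but non-inverted pseudo-header sum is stored at csumStart+csumOffset, and whoever finishes the checksum later (the kernel, a NIC, or the TSO path on the way back out) sums the transport header and payload on top of that seed and inverts. Both halves as a self-contained sketch; these are simplified stand-ins, not the gtcpip helpers:

// fold reduces a one's-complement running sum to 16 bits.
func fold(sum uint32) uint16 {
	for sum > 0xffff {
		sum = sum>>16 + sum&0xffff
	}
	return uint16(sum)
}

// checksum16 is the RFC 1071 internet checksum over b, seeded with a
// previously folded partial sum.
func checksum16(b []byte, initial uint16) uint16 {
	sum := uint32(initial)
	for ; len(b) >= 2; b = b[2:] {
		sum += uint32(b[0])<<8 | uint32(b[1])
	}
	if len(b) == 1 {
		sum += uint32(b[0]) << 8 // pad the odd tail byte
	}
	return fold(sum)
}

// seedPseudoHeader computes the value GRO stores in the checksum
// field: src + dst + proto + transport length, folded, not inverted.
func seedPseudoHeader(proto uint8, src, dst []byte, transportLen uint16) uint16 {
	sum := checksum16(src, 0)
	sum = checksum16(dst, sum)
	sum = checksum16([]byte{0, proto}, sum)
	var l [2]byte
	binary.BigEndian.PutUint16(l[:], transportLen)
	return checksum16(l[:], sum)
}

// finishPartial completes the deferred checksum: read the seed, zero
// the field, then store the inverted sum over header plus payload.
func finishPartial(pkt []byte, csumStart, csumOffset uint16) {
	at := csumStart + csumOffset
	seed := binary.BigEndian.Uint16(pkt[at:])
	pkt[at], pkt[at+1] = 0, 0 // the field must be zero while summing
	binary.BigEndian.PutUint16(pkt[at:], ^checksum16(pkt[csumStart:], seed))
}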
-func tcpTSO(in []byte, hdr virtioNetHdr, outBuffs [][]byte, sizes []int, outOffset int) (int, error) { - iphLen := int(hdr.csumStart) - srcAddrOffset := ipv6SrcAddrOffset - addrLen := 16 - if hdr.gsoType == unix.VIRTIO_NET_HDR_GSO_TCPV4 { - in[10], in[11] = 0, 0 // clear ipv4 header checksum - srcAddrOffset = ipv4SrcAddrOffset - addrLen = 4 - } - tcpCSumAt := int(hdr.csumStart + hdr.csumOffset) - in[tcpCSumAt], in[tcpCSumAt+1] = 0, 0 // clear tcp checksum - firstTCPSeqNum := binary.BigEndian.Uint32(in[hdr.csumStart+4:]) - nextSegmentDataAt := int(hdr.hdrLen) - i := 0 - for ; nextSegmentDataAt < len(in); i++ { - if i == len(outBuffs) { - return i - 1, ErrTooManySegments - } - nextSegmentEnd := nextSegmentDataAt + int(hdr.gsoSize) - if nextSegmentEnd > len(in) { - nextSegmentEnd = len(in) - } - segmentDataLen := nextSegmentEnd - nextSegmentDataAt - totalLen := int(hdr.hdrLen) + segmentDataLen - sizes[i] = totalLen - out := outBuffs[i][outOffset:] - - copy(out, in[:iphLen]) - if hdr.gsoType == unix.VIRTIO_NET_HDR_GSO_TCPV4 { - // For IPv4 we are responsible for incrementing the ID field, - // updating the total len field, and recalculating the header - // checksum. - if i > 0 { - id := binary.BigEndian.Uint16(out[4:]) - id += uint16(i) - binary.BigEndian.PutUint16(out[4:], id) - } - binary.BigEndian.PutUint16(out[2:], uint16(totalLen)) - ipv4CSum := ^checksumFold(out[:iphLen], 0) - binary.BigEndian.PutUint16(out[10:], ipv4CSum) - } else { - // For IPv6 we are responsible for updating the payload length field. - binary.BigEndian.PutUint16(out[4:], uint16(totalLen-iphLen)) - } - - // TCP header - copy(out[hdr.csumStart:hdr.hdrLen], in[hdr.csumStart:hdr.hdrLen]) - tcpSeq := firstTCPSeqNum + uint32(hdr.gsoSize*uint16(i)) - binary.BigEndian.PutUint32(out[hdr.csumStart+4:], tcpSeq) - if nextSegmentEnd != len(in) { - // FIN and PSH should only be set on last segment - clearFlags := tcpFlagFIN | tcpFlagPSH - out[hdr.csumStart+tcpFlagsOffset] &^= clearFlags - } - - // payload - copy(out[hdr.hdrLen:], in[nextSegmentDataAt:nextSegmentEnd]) - - // TCP checksum - tcpHLen := int(hdr.hdrLen - hdr.csumStart) - tcpLenForPseudo := uint16(tcpHLen + segmentDataLen) - tcpCSumNoFold := pseudoHeaderChecksumNoFold(unix.IPPROTO_TCP, in[srcAddrOffset:srcAddrOffset+addrLen], in[srcAddrOffset+addrLen:srcAddrOffset+addrLen*2], tcpLenForPseudo) - tcpCSum := ^checksumFold(out[hdr.csumStart:totalLen], tcpCSumNoFold) - binary.BigEndian.PutUint16(out[hdr.csumStart+hdr.csumOffset:], tcpCSum) - - nextSegmentDataAt += int(hdr.gsoSize) - } - return i, nil -} - -func gsoNoneChecksum(in []byte, cSumStart, cSumOffset uint16) error { - cSumAt := cSumStart + cSumOffset - // The initial value at the checksum offset should be summed with the - // checksum we compute. This is typically the pseudo-header checksum. - initial := binary.BigEndian.Uint16(in[cSumAt:]) - in[cSumAt], in[cSumAt+1] = 0, 0 - binary.BigEndian.PutUint16(in[cSumAt:], ^checksumFold(in[cSumStart:], uint64(initial))) - return nil -} - -// handleVirtioRead splits in into bufs, leaving offset bytes at the front of -// each buffer. It mutates sizes to reflect the size of each element of bufs, -// and returns the number of packets read. 
-func handleVirtioRead(in []byte, bufs [][]byte, sizes []int, offset int) (int, error) { - var hdr virtioNetHdr - err := hdr.decode(in) - if err != nil { - return 0, err - } - in = in[virtioNetHdrLen:] - if hdr.gsoType == unix.VIRTIO_NET_HDR_GSO_NONE { - if hdr.flags&unix.VIRTIO_NET_HDR_F_NEEDS_CSUM != 0 { - // This means CHECKSUM_PARTIAL in skb context. We are responsible - // for computing the checksum starting at hdr.csumStart and placing - // at hdr.csumOffset. - err = gsoNoneChecksum(in, hdr.csumStart, hdr.csumOffset) - if err != nil { - return 0, err - } - } - if len(in) > len(bufs[0][offset:]) { - return 0, fmt.Errorf("read len %d overflows bufs element len %d", len(in), len(bufs[0][offset:])) - } - n := copy(bufs[0][offset:], in) - sizes[0] = n - return 1, nil - } - if hdr.gsoType != unix.VIRTIO_NET_HDR_GSO_TCPV4 && hdr.gsoType != unix.VIRTIO_NET_HDR_GSO_TCPV6 { - return 0, fmt.Errorf("unsupported virtio GSO type: %d", hdr.gsoType) - } - - ipVersion := in[0] >> 4 - switch ipVersion { - case 4: - if hdr.gsoType != unix.VIRTIO_NET_HDR_GSO_TCPV4 { - return 0, fmt.Errorf("ip header version: %d, GSO type: %d", ipVersion, hdr.gsoType) - } - case 6: - if hdr.gsoType != unix.VIRTIO_NET_HDR_GSO_TCPV6 { - return 0, fmt.Errorf("ip header version: %d, GSO type: %d", ipVersion, hdr.gsoType) - } - default: - return 0, fmt.Errorf("invalid ip header version: %d", ipVersion) - } - - if len(in) <= int(hdr.csumStart+12) { - return 0, errors.New("packet is too short") - } - // Don't trust hdr.hdrLen from the kernel as it can be equal to the length - // of the entire first packet when the kernel is handling it as part of a - // FORWARD path. Instead, parse the TCP header length and add it onto - // csumStart, which is synonymous for IP header length. - tcpHLen := uint16(in[hdr.csumStart+12] >> 4 * 4) - if tcpHLen < 20 || tcpHLen > 60 { - // A TCP header must be between 20 and 60 bytes in length. 
- return 0, fmt.Errorf("tcp header len is invalid: %d", tcpHLen) - } - hdr.hdrLen = hdr.csumStart + tcpHLen - - if len(in) < int(hdr.hdrLen) { - return 0, fmt.Errorf("length of packet (%d) < virtioNetHdr.hdrLen (%d)", len(in), hdr.hdrLen) - } - - if hdr.hdrLen < hdr.csumStart { - return 0, fmt.Errorf("virtioNetHdr.hdrLen (%d) < virtioNetHdr.csumStart (%d)", hdr.hdrLen, hdr.csumStart) - } - cSumAt := int(hdr.csumStart + hdr.csumOffset) - if cSumAt+1 >= len(in) { - return 0, fmt.Errorf("end of checksum offset (%d) exceeds packet length (%d)", cSumAt+1, len(in)) - } - - return tcpTSO(in, hdr, bufs, sizes, offset) -} - -func checksumNoFold(b []byte, initial uint64) uint64 { - return uint64(checksum.Checksum(b, uint16(initial))) -} - -func checksumFold(b []byte, initial uint64) uint16 { - ac := checksumNoFold(b, initial) - ac = (ac >> 16) + (ac & 0xffff) - ac = (ac >> 16) + (ac & 0xffff) - ac = (ac >> 16) + (ac & 0xffff) - ac = (ac >> 16) + (ac & 0xffff) - return uint16(ac) -} - -func pseudoHeaderChecksumNoFold(protocol uint8, srcAddr, dstAddr []byte, totalLen uint16) uint64 { - sum := checksumNoFold(srcAddr, 0) - sum = checksumNoFold(dstAddr, sum) - sum = checksumNoFold([]byte{0, protocol}, sum) - tmp := make([]byte, 2) - binary.BigEndian.PutUint16(tmp, totalLen) - return checksumNoFold(tmp, sum) -} diff --git a/tun_linux_offload_errors.go b/tun_linux_offload_errors.go deleted file mode 100644 index 8e5db90..0000000 --- a/tun_linux_offload_errors.go +++ /dev/null @@ -1,5 +0,0 @@ -package tun - -import E "github.com/sagernet/sing/common/exceptions" - -var ErrTooManySegments = E.New("too many segments") diff --git a/tun_offload.go b/tun_offload.go new file mode 100644 index 0000000..a0eee82 --- /dev/null +++ b/tun_offload.go @@ -0,0 +1,229 @@ +package tun + +import ( + "encoding/binary" + "fmt" + + "github.com/sagernet/sing-tun/internal/gtcpip" + "github.com/sagernet/sing-tun/internal/gtcpip/checksum" + "github.com/sagernet/sing-tun/internal/gtcpip/header" +) + +const ( + gsoMaxSize = 65536 + idealBatchSize = 128 +) + +// GSOType represents the type of segmentation offload. +type GSOType int + +const ( + GSONone GSOType = iota + GSOTCPv4 + GSOTCPv6 + GSOUDPL4 +) + +func (g GSOType) String() string { + switch g { + case GSONone: + return "GSONone" + case GSOTCPv4: + return "GSOTCPv4" + case GSOTCPv6: + return "GSOTCPv6" + case GSOUDPL4: + return "GSOUDPL4" + default: + return "unknown" + } +} + +// GSOOptions is loosely modeled after struct virtio_net_hdr from the VIRTIO +// specification. It is a common representation of GSO metadata that can be +// applied to support packet GSO across tun.Device implementations. +type GSOOptions struct { + // GSOType represents the type of segmentation offload. + GSOType GSOType + // HdrLen is the sum of the layer 3 and 4 header lengths. This field may be + // zero when GSOType == GSONone. + HdrLen uint16 + // CsumStart is the head byte index of the packet data to be checksummed, + // i.e. the start of the TCP or UDP header. + CsumStart uint16 + // CsumOffset is the offset from CsumStart where the 2-byte checksum value + // should be placed. + CsumOffset uint16 + // GSOSize is the size of each segment exclusive of HdrLen. The tail segment + // may be smaller than this value. + GSOSize uint16 + // NeedsCsum may be set where GSOType == GSONone. When set, the checksum + // at CsumStart + CsumOffset must be a partial checksum, i.e. the + // pseudo-header sum. 
+ NeedsCsum bool +} + +const ( + ipv4SrcAddrOffset = 12 + ipv6SrcAddrOffset = 8 +) + +const tcpFlagsOffset = 13 + +const ( + tcpFlagFIN uint8 = 0x01 + tcpFlagPSH uint8 = 0x08 + tcpFlagACK uint8 = 0x10 +) + +const ( + // defined here in order to avoid importation of any platform-specific pkgs + ipProtoTCP = 6 + ipProtoUDP = 17 +) + +// GSOSplit splits packets from 'in' into outBufs[][outOffset:], writing +// the size of each element into sizes. It returns the number of buffers +// populated, and/or an error. Callers may pass an 'in' slice that overlaps with +// the first element of outBuffers, i.e. &in[0] may be equal to +// &outBufs[0][outOffset]. GSONone is a valid options.GSOType regardless of the +// value of options.NeedsCsum. Length of each outBufs element must be greater +// than or equal to the length of 'in', otherwise output may be silently +// truncated. +func GSOSplit(in []byte, options GSOOptions, outBufs [][]byte, sizes []int, outOffset int) (int, error) { + cSumAt := int(options.CsumStart) + int(options.CsumOffset) + if cSumAt+1 >= len(in) { + return 0, fmt.Errorf("end of checksum offset (%d) exceeds packet length (%d)", cSumAt+1, len(in)) + } + + if len(in) < int(options.HdrLen) { + return 0, fmt.Errorf("length of packet (%d) < GSO HdrLen (%d)", len(in), options.HdrLen) + } + + // Handle the conditions where we are copying a single element to outBuffs. + payloadLen := len(in) - int(options.HdrLen) + if options.GSOType == GSONone || payloadLen < int(options.GSOSize) { + if len(in) > len(outBufs[0][outOffset:]) { + return 0, fmt.Errorf("length of packet (%d) exceeds output element length (%d)", len(in), len(outBufs[0][outOffset:])) + } + if options.NeedsCsum { + // The initial value at the checksum offset should be summed with + // the checksum we compute. This is typically the pseudo-header sum. 
+ initial := binary.BigEndian.Uint16(in[cSumAt:]) + in[cSumAt], in[cSumAt+1] = 0, 0 + binary.BigEndian.PutUint16(in[cSumAt:], ^checksum.Checksum(in[options.CsumStart:], initial)) + } + sizes[0] = copy(outBufs[0][outOffset:], in) + return 1, nil + } + + if options.HdrLen < options.CsumStart { + return 0, fmt.Errorf("GSO HdrLen (%d) < GSO CsumStart (%d)", options.HdrLen, options.CsumStart) + } + + ipVersion := in[0] >> 4 + switch ipVersion { + case 4: + if options.GSOType != GSOTCPv4 && options.GSOType != GSOUDPL4 { + return 0, fmt.Errorf("ip header version: %d, GSO type: %s", ipVersion, options.GSOType) + } + if len(in) < 20 { + return 0, fmt.Errorf("length of packet (%d) < minimum ipv4 header size (%d)", len(in), 20) + } + case 6: + if options.GSOType != GSOTCPv6 && options.GSOType != GSOUDPL4 { + return 0, fmt.Errorf("ip header version: %d, GSO type: %s", ipVersion, options.GSOType) + } + if len(in) < 40 { + return 0, fmt.Errorf("length of packet (%d) < minimum ipv6 header size (%d)", len(in), 40) + } + default: + return 0, fmt.Errorf("invalid ip header version: %d", ipVersion) + } + + iphLen := int(options.CsumStart) + srcAddrOffset := ipv6SrcAddrOffset + addrLen := 16 + if ipVersion == 4 { + srcAddrOffset = ipv4SrcAddrOffset + addrLen = 4 + } + transportCsumAt := int(options.CsumStart + options.CsumOffset) + var firstTCPSeqNum uint32 + var protocol uint8 + if options.GSOType == GSOTCPv4 || options.GSOType == GSOTCPv6 { + protocol = ipProtoTCP + if len(in) < int(options.CsumStart)+20 { + return 0, fmt.Errorf("length of packet (%d) < GSO CsumStart (%d) + minimum TCP header size (%d)", + len(in), options.CsumStart, 20) + } + firstTCPSeqNum = binary.BigEndian.Uint32(in[options.CsumStart+4:]) + } else { + protocol = ipProtoUDP + } + nextSegmentDataAt := int(options.HdrLen) + i := 0 + for ; nextSegmentDataAt < len(in); i++ { + if i == len(outBufs) { + return i - 1, ErrTooManySegments + } + nextSegmentEnd := nextSegmentDataAt + int(options.GSOSize) + if nextSegmentEnd > len(in) { + nextSegmentEnd = len(in) + } + segmentDataLen := nextSegmentEnd - nextSegmentDataAt + totalLen := int(options.HdrLen) + segmentDataLen + sizes[i] = totalLen + out := outBufs[i][outOffset:] + + copy(out, in[:iphLen]) + if ipVersion == 4 { + // For IPv4 we are responsible for incrementing the ID field, + // updating the total len field, and recalculating the header + // checksum. + if i > 0 { + id := binary.BigEndian.Uint16(out[4:]) + id += uint16(i) + binary.BigEndian.PutUint16(out[4:], id) + } + out[10], out[11] = 0, 0 // clear ipv4 header checksum + binary.BigEndian.PutUint16(out[2:], uint16(totalLen)) + ipv4CSum := ^checksum.Checksum(out[:iphLen], 0) + binary.BigEndian.PutUint16(out[10:], ipv4CSum) + } else { + // For IPv6 we are responsible for updating the payload length field. 
+ binary.BigEndian.PutUint16(out[4:], uint16(totalLen-iphLen)) + } + + // copy transport header + copy(out[options.CsumStart:options.HdrLen], in[options.CsumStart:options.HdrLen]) + + if protocol == ipProtoTCP { + // set TCP seq and adjust TCP flags + tcpSeq := firstTCPSeqNum + uint32(options.GSOSize*uint16(i)) + binary.BigEndian.PutUint32(out[options.CsumStart+4:], tcpSeq) + if nextSegmentEnd != len(in) { + // FIN and PSH should only be set on last segment + clearFlags := tcpFlagFIN | tcpFlagPSH + out[options.CsumStart+tcpFlagsOffset] &^= clearFlags + } + } else { + // set UDP header len + binary.BigEndian.PutUint16(out[options.CsumStart+4:], uint16(segmentDataLen)+(options.HdrLen-options.CsumStart)) + } + + // payload + copy(out[options.HdrLen:], in[nextSegmentDataAt:nextSegmentEnd]) + + // transport checksum + out[transportCsumAt], out[transportCsumAt+1] = 0, 0 // clear tcp/udp checksum + transportHeaderLen := int(options.HdrLen - options.CsumStart) + lenForPseudo := uint16(transportHeaderLen + segmentDataLen) + transportCSum := header.PseudoHeaderChecksum(tcpip.TransportProtocolNumber(protocol), in[srcAddrOffset:srcAddrOffset+addrLen], in[srcAddrOffset+addrLen:srcAddrOffset+addrLen*2], lenForPseudo) + transportCSum = ^checksum.Checksum(out[options.CsumStart:totalLen], transportCSum) + binary.BigEndian.PutUint16(out[options.CsumStart+options.CsumOffset:], transportCSum) + + nextSegmentDataAt += int(options.GSOSize) + } + return i, nil +} diff --git a/tun_offload_errors.go b/tun_offload_errors.go new file mode 100644 index 0000000..2c49fc7 --- /dev/null +++ b/tun_offload_errors.go @@ -0,0 +1,10 @@ +package tun + +import ( + "errors" +) + +// ErrTooManySegments is returned by Device.Read() when segmentation +// overflows the length of supplied buffers. This error should not cause +// reads to cease. +var ErrTooManySegments = errors.New("too many segments") diff --git a/tun_offload_linux.go b/tun_offload_linux.go new file mode 100644 index 0000000..7e44a55 --- /dev/null +++ b/tun_offload_linux.go @@ -0,0 +1,937 @@ +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. + */ + +package tun + +import ( + "bytes" + "encoding/binary" + "errors" + "fmt" + "io" + "unsafe" + + "github.com/sagernet/sing-tun/internal/gtcpip" + "github.com/sagernet/sing-tun/internal/gtcpip/checksum" + "github.com/sagernet/sing-tun/internal/gtcpip/header" + + "golang.org/x/sys/unix" +) + +// virtioNetHdr is defined in the kernel in include/uapi/linux/virtio_net.h. The +// kernel symbol is virtio_net_hdr. 
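GSOSplit and ErrTooManySegments together let a caller salvage the segments that did fit when the output vector is exhausted: the returned count is still valid, and the error is a signal rather than a reason to stop reading. A usage sketch; splitAndForward, forward, and the buffer shapes are illustrative:

// splitAndForward splits one oversized GSO packet and hands each
// resulting segment to forward.
func splitAndForward(in []byte, opts GSOOptions, offset int, forward func([]byte)) error {
	out := make([][]byte, idealBatchSize)
	for i := range out {
		// Each element must be at least len(in) long past the offset,
		// or GSOSplit may silently truncate.
		out[i] = make([]byte, offset+gsoMaxSize)
	}
	sizes := make([]int, idealBatchSize)
	n, err := GSOSplit(in, opts, out, sizes, offset)
	if err != nil && !errors.Is(err, ErrTooManySegments) {
		return err
	}
	for i := 0; i < n; i++ {
		forward(out[i][offset : offset+sizes[i]])
	}
	return err // ErrTooManySegments is reported, but reads should continue
}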
+type virtioNetHdr struct { + flags uint8 + gsoType uint8 + hdrLen uint16 + gsoSize uint16 + csumStart uint16 + csumOffset uint16 +} + +func (v *virtioNetHdr) toGSOOptions() (GSOOptions, error) { + var gsoType GSOType + switch v.gsoType { + case unix.VIRTIO_NET_HDR_GSO_NONE: + gsoType = GSONone + case unix.VIRTIO_NET_HDR_GSO_TCPV4: + gsoType = GSOTCPv4 + case unix.VIRTIO_NET_HDR_GSO_TCPV6: + gsoType = GSOTCPv6 + case unix.VIRTIO_NET_HDR_GSO_UDP_L4: + gsoType = GSOUDPL4 + default: + return GSOOptions{}, fmt.Errorf("unsupported virtio gsoType: %d", v.gsoType) + } + return GSOOptions{ + GSOType: gsoType, + HdrLen: v.hdrLen, + CsumStart: v.csumStart, + CsumOffset: v.csumOffset, + GSOSize: v.gsoSize, + NeedsCsum: v.flags&unix.VIRTIO_NET_HDR_F_NEEDS_CSUM != 0, + }, nil +} + +func (v *virtioNetHdr) decode(b []byte) error { + if len(b) < virtioNetHdrLen { + return io.ErrShortBuffer + } + copy(unsafe.Slice((*byte)(unsafe.Pointer(v)), virtioNetHdrLen), b[:virtioNetHdrLen]) + return nil +} + +func (v *virtioNetHdr) encode(b []byte) error { + if len(b) < virtioNetHdrLen { + return io.ErrShortBuffer + } + copy(b[:virtioNetHdrLen], unsafe.Slice((*byte)(unsafe.Pointer(v)), virtioNetHdrLen)) + return nil +} + +const ( + // virtioNetHdrLen is the length in bytes of virtioNetHdr. This matches the + // shape of the C ABI for its kernel counterpart -- sizeof(virtio_net_hdr). + virtioNetHdrLen = int(unsafe.Sizeof(virtioNetHdr{})) +) + +// tcpFlowKey represents the key for a TCP flow. +type tcpFlowKey struct { + srcAddr, dstAddr [16]byte + srcPort, dstPort uint16 + rxAck uint32 // varying ack values should not be coalesced. Treat them as separate flows. + isV6 bool +} + +// tcpGROTable holds flow and coalescing information for the purposes of TCP GRO. +type tcpGROTable struct { + itemsByFlow map[tcpFlowKey][]tcpGROItem + itemsPool [][]tcpGROItem +} + +func newTCPGROTable() *tcpGROTable { + t := &tcpGROTable{ + itemsByFlow: make(map[tcpFlowKey][]tcpGROItem, idealBatchSize), + itemsPool: make([][]tcpGROItem, idealBatchSize), + } + for i := range t.itemsPool { + t.itemsPool[i] = make([]tcpGROItem, 0, idealBatchSize) + } + return t +} + +func newTCPFlowKey(pkt []byte, srcAddrOffset, dstAddrOffset, tcphOffset int) tcpFlowKey { + key := tcpFlowKey{} + addrSize := dstAddrOffset - srcAddrOffset + copy(key.srcAddr[:], pkt[srcAddrOffset:dstAddrOffset]) + copy(key.dstAddr[:], pkt[dstAddrOffset:dstAddrOffset+addrSize]) + key.srcPort = binary.BigEndian.Uint16(pkt[tcphOffset:]) + key.dstPort = binary.BigEndian.Uint16(pkt[tcphOffset+2:]) + key.rxAck = binary.BigEndian.Uint32(pkt[tcphOffset+8:]) + key.isV6 = addrSize == 16 + return key +} + +// lookupOrInsert looks up a flow for the provided packet and metadata, +// returning the packets found for the flow, or inserting a new one if none +// is found. +func (t *tcpGROTable) lookupOrInsert(pkt []byte, srcAddrOffset, dstAddrOffset, tcphOffset, tcphLen, bufsIndex int) ([]tcpGROItem, bool) { + key := newTCPFlowKey(pkt, srcAddrOffset, dstAddrOffset, tcphOffset) + items, ok := t.itemsByFlow[key] + if ok { + return items, ok + } + // TODO: insert() performs another map lookup. This could be rearranged to avoid. + t.insert(pkt, srcAddrOffset, dstAddrOffset, tcphOffset, tcphLen, bufsIndex) + return nil, false +} + +// insert an item in the table for the provided packet and packet metadata. 
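virtioNetHdr above occupies the first virtioNetHdrLen (10) bytes of every read and write on a vnet_hdr-enabled TUN; decode and encode are plain struct copies, so the fields stay in host byte order exactly like the kernel's struct virtio_net_hdr. A round-trip sketch using this file's own methods, with illustrative field values:

func virtioRoundTrip() {
	hdr := virtioNetHdr{
		flags:      unix.VIRTIO_NET_HDR_F_NEEDS_CSUM,
		gsoType:    unix.VIRTIO_NET_HDR_GSO_TCPV4,
		hdrLen:     40,   // 20-byte IPv4 header + 20-byte TCP header
		gsoSize:    1460, // the MSS
		csumStart:  20,   // IP header length
		csumOffset: 16,   // TCP checksum field offset
	}
	b := make([]byte, virtioNetHdrLen)
	_ = hdr.encode(b) // fails only when len(b) < virtioNetHdrLen
	var back virtioNetHdr
	_ = back.decode(b) // back now equals hdr
	opts, _ := back.toGSOOptions()
	_ = opts // opts.GSOType == GSOTCPv4, opts.NeedsCsum == true
}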
+func (t *tcpGROTable) insert(pkt []byte, srcAddrOffset, dstAddrOffset, tcphOffset, tcphLen, bufsIndex int) { + key := newTCPFlowKey(pkt, srcAddrOffset, dstAddrOffset, tcphOffset) + item := tcpGROItem{ + key: key, + bufsIndex: uint16(bufsIndex), + gsoSize: uint16(len(pkt[tcphOffset+tcphLen:])), + iphLen: uint8(tcphOffset), + tcphLen: uint8(tcphLen), + sentSeq: binary.BigEndian.Uint32(pkt[tcphOffset+4:]), + pshSet: pkt[tcphOffset+tcpFlagsOffset]&tcpFlagPSH != 0, + } + items, ok := t.itemsByFlow[key] + if !ok { + items = t.newItems() + } + items = append(items, item) + t.itemsByFlow[key] = items +} + +func (t *tcpGROTable) updateAt(item tcpGROItem, i int) { + items, _ := t.itemsByFlow[item.key] + items[i] = item +} + +func (t *tcpGROTable) deleteAt(key tcpFlowKey, i int) { + items, _ := t.itemsByFlow[key] + items = append(items[:i], items[i+1:]...) + t.itemsByFlow[key] = items +} + +// tcpGROItem represents bookkeeping data for a TCP packet during the lifetime +// of a GRO evaluation across a vector of packets. +type tcpGROItem struct { + key tcpFlowKey + sentSeq uint32 // the sequence number + bufsIndex uint16 // the index into the original bufs slice + numMerged uint16 // the number of packets merged into this item + gsoSize uint16 // payload size + iphLen uint8 // ip header len + tcphLen uint8 // tcp header len + pshSet bool // psh flag is set +} + +func (t *tcpGROTable) newItems() []tcpGROItem { + var items []tcpGROItem + items, t.itemsPool = t.itemsPool[len(t.itemsPool)-1], t.itemsPool[:len(t.itemsPool)-1] + return items +} + +func (t *tcpGROTable) reset() { + for k, items := range t.itemsByFlow { + items = items[:0] + t.itemsPool = append(t.itemsPool, items) + delete(t.itemsByFlow, k) + } +} + +// udpFlowKey represents the key for a UDP flow. +type udpFlowKey struct { + srcAddr, dstAddr [16]byte + srcPort, dstPort uint16 + isV6 bool +} + +// udpGROTable holds flow and coalescing information for the purposes of UDP GRO. +type udpGROTable struct { + itemsByFlow map[udpFlowKey][]udpGROItem + itemsPool [][]udpGROItem +} + +func newUDPGROTable() *udpGROTable { + u := &udpGROTable{ + itemsByFlow: make(map[udpFlowKey][]udpGROItem, idealBatchSize), + itemsPool: make([][]udpGROItem, idealBatchSize), + } + for i := range u.itemsPool { + u.itemsPool[i] = make([]udpGROItem, 0, idealBatchSize) + } + return u +} + +func newUDPFlowKey(pkt []byte, srcAddrOffset, dstAddrOffset, udphOffset int) udpFlowKey { + key := udpFlowKey{} + addrSize := dstAddrOffset - srcAddrOffset + copy(key.srcAddr[:], pkt[srcAddrOffset:dstAddrOffset]) + copy(key.dstAddr[:], pkt[dstAddrOffset:dstAddrOffset+addrSize]) + key.srcPort = binary.BigEndian.Uint16(pkt[udphOffset:]) + key.dstPort = binary.BigEndian.Uint16(pkt[udphOffset+2:]) + key.isV6 = addrSize == 16 + return key +} + +// lookupOrInsert looks up a flow for the provided packet and metadata, +// returning the packets found for the flow, or inserting a new one if none +// is found. +func (u *udpGROTable) lookupOrInsert(pkt []byte, srcAddrOffset, dstAddrOffset, udphOffset, bufsIndex int) ([]udpGROItem, bool) { + key := newUDPFlowKey(pkt, srcAddrOffset, dstAddrOffset, udphOffset) + items, ok := u.itemsByFlow[key] + if ok { + return items, ok + } + // TODO: insert() performs another map lookup. This could be rearranged to avoid. + u.insert(pkt, srcAddrOffset, dstAddrOffset, udphOffset, bufsIndex, false) + return nil, false +} + +// insert an item in the table for the provided packet and packet metadata. 
+func (u *udpGROTable) insert(pkt []byte, srcAddrOffset, dstAddrOffset, udphOffset, bufsIndex int, cSumKnownInvalid bool) { + key := newUDPFlowKey(pkt, srcAddrOffset, dstAddrOffset, udphOffset) + item := udpGROItem{ + key: key, + bufsIndex: uint16(bufsIndex), + gsoSize: uint16(len(pkt[udphOffset+udphLen:])), + iphLen: uint8(udphOffset), + cSumKnownInvalid: cSumKnownInvalid, + } + items, ok := u.itemsByFlow[key] + if !ok { + items = u.newItems() + } + items = append(items, item) + u.itemsByFlow[key] = items +} + +func (u *udpGROTable) updateAt(item udpGROItem, i int) { + items, _ := u.itemsByFlow[item.key] + items[i] = item +} + +// udpGROItem represents bookkeeping data for a UDP packet during the lifetime +// of a GRO evaluation across a vector of packets. +type udpGROItem struct { + key udpFlowKey + bufsIndex uint16 // the index into the original bufs slice + numMerged uint16 // the number of packets merged into this item + gsoSize uint16 // payload size + iphLen uint8 // ip header len + cSumKnownInvalid bool // UDP header checksum validity; a false value DOES NOT imply valid, just unknown. +} + +func (u *udpGROTable) newItems() []udpGROItem { + var items []udpGROItem + items, u.itemsPool = u.itemsPool[len(u.itemsPool)-1], u.itemsPool[:len(u.itemsPool)-1] + return items +} + +func (u *udpGROTable) reset() { + for k, items := range u.itemsByFlow { + items = items[:0] + u.itemsPool = append(u.itemsPool, items) + delete(u.itemsByFlow, k) + } +} + +// canCoalesce represents the outcome of checking if two TCP packets are +// candidates for coalescing. +type canCoalesce int + +const ( + coalescePrepend canCoalesce = -1 + coalesceUnavailable canCoalesce = 0 + coalesceAppend canCoalesce = 1 +) + +// ipHeadersCanCoalesce returns true if the IP headers found in pktA and pktB +// meet all requirements to be merged as part of a GRO operation, otherwise it +// returns false. +func ipHeadersCanCoalesce(pktA, pktB []byte) bool { + if len(pktA) < 9 || len(pktB) < 9 { + return false + } + if pktA[0]>>4 == 6 { + if pktA[0] != pktB[0] || pktA[1]>>4 != pktB[1]>>4 { + // cannot coalesce with unequal Traffic class values + return false + } + if pktA[7] != pktB[7] { + // cannot coalesce with unequal Hop limit values + return false + } + } else { + if pktA[1] != pktB[1] { + // cannot coalesce with unequal ToS values + return false + } + if pktA[6]>>5 != pktB[6]>>5 { + // cannot coalesce with unequal DF or reserved bits. MF is checked + // further up the stack. + return false + } + if pktA[8] != pktB[8] { + // cannot coalesce with unequal TTL values + return false + } + } + return true +} + +// udpPacketsCanCoalesce evaluates if pkt can be coalesced with the packet +// described by item. iphLen and gsoSize describe pkt. bufs is the vector of +// packets involved in the current GRO evaluation. bufsOffset is the offset at +// which packet data begins within bufs. +func udpPacketsCanCoalesce(pkt []byte, iphLen uint8, gsoSize uint16, item udpGROItem, bufs [][]byte, bufsOffset int) canCoalesce { + pktTarget := bufs[item.bufsIndex][bufsOffset:] + if !ipHeadersCanCoalesce(pkt, pktTarget) { + return coalesceUnavailable + } + if len(pktTarget[iphLen+udphLen:])%int(item.gsoSize) != 0 { + // A smaller than gsoSize packet has been appended previously. + // Nothing can come after a smaller packet on the end. + return coalesceUnavailable + } + if gsoSize > item.gsoSize { + // We cannot have a larger packet following a smaller one. 
+ return coalesceUnavailable + } + return coalesceAppend +} + +// tcpPacketsCanCoalesce evaluates if pkt can be coalesced with the packet +// described by item. This function makes considerations that match the kernel's +// GRO self tests, which can be found in tools/testing/selftests/net/gro.c. +func tcpPacketsCanCoalesce(pkt []byte, iphLen, tcphLen uint8, seq uint32, pshSet bool, gsoSize uint16, item tcpGROItem, bufs [][]byte, bufsOffset int) canCoalesce { + pktTarget := bufs[item.bufsIndex][bufsOffset:] + if tcphLen != item.tcphLen { + // cannot coalesce with unequal tcp options len + return coalesceUnavailable + } + if tcphLen > 20 { + if !bytes.Equal(pkt[iphLen+20:iphLen+tcphLen], pktTarget[item.iphLen+20:iphLen+tcphLen]) { + // cannot coalesce with unequal tcp options + return coalesceUnavailable + } + } + if !ipHeadersCanCoalesce(pkt, pktTarget) { + return coalesceUnavailable + } + // seq adjacency + lhsLen := item.gsoSize + lhsLen += item.numMerged * item.gsoSize + if seq == item.sentSeq+uint32(lhsLen) { // pkt aligns following item from a seq num perspective + if item.pshSet { + // We cannot append to a segment that has the PSH flag set, PSH + // can only be set on the final segment in a reassembled group. + return coalesceUnavailable + } + if len(pktTarget[iphLen+tcphLen:])%int(item.gsoSize) != 0 { + // A smaller than gsoSize packet has been appended previously. + // Nothing can come after a smaller packet on the end. + return coalesceUnavailable + } + if gsoSize > item.gsoSize { + // We cannot have a larger packet following a smaller one. + return coalesceUnavailable + } + return coalesceAppend + } else if seq+uint32(gsoSize) == item.sentSeq { // pkt aligns in front of item from a seq num perspective + if pshSet { + // We cannot prepend with a segment that has the PSH flag set, PSH + // can only be set on the final segment in a reassembled group. + return coalesceUnavailable + } + if gsoSize < item.gsoSize { + // We cannot have a larger packet following a smaller one. + return coalesceUnavailable + } + if gsoSize > item.gsoSize && item.numMerged > 0 { + // There's at least one previous merge, and we're larger than all + // previous. This would put multiple smaller packets on the end. + return coalesceUnavailable + } + return coalescePrepend + } + return coalesceUnavailable +} + +func checksumValid(pkt []byte, iphLen, proto uint8, isV6 bool) bool { + srcAddrAt := ipv4SrcAddrOffset + addrSize := 4 + if isV6 { + srcAddrAt = ipv6SrcAddrOffset + addrSize = 16 + } + lenForPseudo := uint16(len(pkt) - int(iphLen)) + cSum := header.PseudoHeaderChecksum(tcpip.TransportProtocolNumber(proto), pkt[srcAddrAt:srcAddrAt+addrSize], pkt[srcAddrAt+addrSize:srcAddrAt+addrSize*2], lenForPseudo) + return ^checksum.Checksum(pkt[iphLen:], cSum) == 0 +} + +// coalesceResult represents the result of attempting to coalesce two TCP +// packets. +type coalesceResult int + +const ( + coalesceInsufficientCap coalesceResult = iota + coalescePSHEnding + coalesceItemInvalidCSum + coalescePktInvalidCSum + coalesceSuccess +) + +// coalesceUDPPackets attempts to coalesce pkt with the packet described by +// item, and returns the outcome. 
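udpPacketsCanCoalesce above is deliberately simpler than its TCP counterpart: with no sequence numbers to reason about, datagrams only ever append in arrival order, every datagram merged so far must be exactly gsoSize, and a shorter datagram may only sit at the tail. The size rule in isolation; canAppendUDP is a hypothetical distillation:

// canAppendUDP distills the UDP GRO size checks: existingPayload is
// the coalesced payload so far, itemGSOSize the established segment
// size, newGSOSize the candidate datagram's payload length.
func canAppendUDP(existingPayload, itemGSOSize, newGSOSize int) canCoalesce {
	if existingPayload%itemGSOSize != 0 {
		// a short tail was already appended; nothing may follow it
		return coalesceUnavailable
	}
	if newGSOSize > itemGSOSize {
		// a larger datagram cannot follow a smaller one
		return coalesceUnavailable
	}
	return coalesceAppend // UDP coalescing never prepends
}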
+func coalesceUDPPackets(pkt []byte, item *udpGROItem, bufs [][]byte, bufsOffset int, isV6 bool) coalesceResult { + pktHead := bufs[item.bufsIndex][bufsOffset:] // the packet that will end up at the front + headersLen := item.iphLen + udphLen + coalescedLen := len(bufs[item.bufsIndex][bufsOffset:]) + len(pkt) - int(headersLen) + + if cap(pktHead)-bufsOffset < coalescedLen { + // We don't want to allocate a new underlying array if capacity is + // too small. + return coalesceInsufficientCap + } + if item.numMerged == 0 { + if item.cSumKnownInvalid || !checksumValid(bufs[item.bufsIndex][bufsOffset:], item.iphLen, unix.IPPROTO_UDP, isV6) { + return coalesceItemInvalidCSum + } + } + if !checksumValid(pkt, item.iphLen, unix.IPPROTO_UDP, isV6) { + return coalescePktInvalidCSum + } + extendBy := len(pkt) - int(headersLen) + bufs[item.bufsIndex] = append(bufs[item.bufsIndex], make([]byte, extendBy)...) + copy(bufs[item.bufsIndex][bufsOffset+len(pktHead):], pkt[headersLen:]) + + item.numMerged++ + return coalesceSuccess +} + +// coalesceTCPPackets attempts to coalesce pkt with the packet described by +// item, and returns the outcome. This function may swap bufs elements in the +// event of a prepend as item's bufs index is already being tracked for writing +// to a Device. +func coalesceTCPPackets(mode canCoalesce, pkt []byte, pktBuffsIndex int, gsoSize uint16, seq uint32, pshSet bool, item *tcpGROItem, bufs [][]byte, bufsOffset int, isV6 bool) coalesceResult { + var pktHead []byte // the packet that will end up at the front + headersLen := item.iphLen + item.tcphLen + coalescedLen := len(bufs[item.bufsIndex][bufsOffset:]) + len(pkt) - int(headersLen) + + // Copy data + if mode == coalescePrepend { + pktHead = pkt + if cap(pkt)-bufsOffset < coalescedLen { + // We don't want to allocate a new underlying array if capacity is + // too small. + return coalesceInsufficientCap + } + if pshSet { + return coalescePSHEnding + } + if item.numMerged == 0 { + if !checksumValid(bufs[item.bufsIndex][bufsOffset:], item.iphLen, unix.IPPROTO_TCP, isV6) { + return coalesceItemInvalidCSum + } + } + if !checksumValid(pkt, item.iphLen, unix.IPPROTO_TCP, isV6) { + return coalescePktInvalidCSum + } + item.sentSeq = seq + extendBy := coalescedLen - len(pktHead) + bufs[pktBuffsIndex] = append(bufs[pktBuffsIndex], make([]byte, extendBy)...) + copy(bufs[pktBuffsIndex][bufsOffset+len(pkt):], bufs[item.bufsIndex][bufsOffset+int(headersLen):]) + // Flip the slice headers in bufs as part of prepend. The index of item + // is already being tracked for writing. + bufs[item.bufsIndex], bufs[pktBuffsIndex] = bufs[pktBuffsIndex], bufs[item.bufsIndex] + } else { + pktHead = bufs[item.bufsIndex][bufsOffset:] + if cap(pktHead)-bufsOffset < coalescedLen { + // We don't want to allocate a new underlying array if capacity is + // too small. + return coalesceInsufficientCap + } + if item.numMerged == 0 { + if !checksumValid(bufs[item.bufsIndex][bufsOffset:], item.iphLen, unix.IPPROTO_TCP, isV6) { + return coalesceItemInvalidCSum + } + } + if !checksumValid(pkt, item.iphLen, unix.IPPROTO_TCP, isV6) { + return coalescePktInvalidCSum + } + if pshSet { + // We are appending a segment with PSH set. + item.pshSet = pshSet + pktHead[item.iphLen+tcpFlagsOffset] |= tcpFlagPSH + } + extendBy := len(pkt) - int(headersLen) + bufs[item.bufsIndex] = append(bufs[item.bufsIndex], make([]byte, extendBy)...) 
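+ // Append the payload of pkt (everything beyond its headers) to the end
+ // of the existing segment chain.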
+ copy(bufs[item.bufsIndex][bufsOffset+len(pktHead):], pkt[headersLen:]) + } + + if gsoSize > item.gsoSize { + item.gsoSize = gsoSize + } + + item.numMerged++ + return coalesceSuccess +} + +const ( + ipv4FlagMoreFragments uint8 = 0x20 +) + +const ( + maxUint16 = 1<<16 - 1 +) + +type groResult int + +const ( + groResultNoop groResult = iota + groResultTableInsert + groResultCoalesced +) + +// tcpGRO evaluates the TCP packet at pktI in bufs for coalescing with +// existing packets tracked in table. It returns a groResultNoop when no +// action was taken, groResultTableInsert when the evaluated packet was +// inserted into table, and groResultCoalesced when the evaluated packet was +// coalesced with another packet in table. +func tcpGRO(bufs [][]byte, offset int, pktI int, table *tcpGROTable, isV6 bool) groResult { + pkt := bufs[pktI][offset:] + if len(pkt) > maxUint16 { + // A valid IPv4 or IPv6 packet will never exceed this. + return groResultNoop + } + iphLen := int((pkt[0] & 0x0F) * 4) + if isV6 { + iphLen = 40 + ipv6HPayloadLen := int(binary.BigEndian.Uint16(pkt[4:])) + if ipv6HPayloadLen != len(pkt)-iphLen { + return groResultNoop + } + } else { + totalLen := int(binary.BigEndian.Uint16(pkt[2:])) + if totalLen != len(pkt) { + return groResultNoop + } + } + if len(pkt) < iphLen { + return groResultNoop + } + tcphLen := int((pkt[iphLen+12] >> 4) * 4) + if tcphLen < 20 || tcphLen > 60 { + return groResultNoop + } + if len(pkt) < iphLen+tcphLen { + return groResultNoop + } + if !isV6 { + if pkt[6]&ipv4FlagMoreFragments != 0 || pkt[6]<<3 != 0 || pkt[7] != 0 { + // no GRO support for fragmented segments for now + return groResultNoop + } + } + tcpFlags := pkt[iphLen+tcpFlagsOffset] + var pshSet bool + // not a candidate if any non-ACK flags (except PSH+ACK) are set + if tcpFlags != tcpFlagACK { + if pkt[iphLen+tcpFlagsOffset] != tcpFlagACK|tcpFlagPSH { + return groResultNoop + } + pshSet = true + } + gsoSize := uint16(len(pkt) - tcphLen - iphLen) + // not a candidate if payload len is 0 + if gsoSize < 1 { + return groResultNoop + } + seq := binary.BigEndian.Uint32(pkt[iphLen+4:]) + srcAddrOffset := ipv4SrcAddrOffset + addrLen := 4 + if isV6 { + srcAddrOffset = ipv6SrcAddrOffset + addrLen = 16 + } + items, existing := table.lookupOrInsert(pkt, srcAddrOffset, srcAddrOffset+addrLen, iphLen, tcphLen, pktI) + if !existing { + return groResultTableInsert + } + for i := len(items) - 1; i >= 0; i-- { + // In the best case of packets arriving in order iterating in reverse is + // more efficient if there are multiple items for a given flow. This + // also enables a natural table.deleteAt() in the + // coalesceItemInvalidCSum case without the need for index tracking. + // This algorithm makes a best effort to coalesce in the event of + // unordered packets, where pkt may land anywhere in items from a + // sequence number perspective, however once an item is inserted into + // the table it is never compared across other items later. 
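+ // An item that fails checksum validation is deleted from the table and
+ // evaluation continues with the next, older item for the flow.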
+ item := items[i] + can := tcpPacketsCanCoalesce(pkt, uint8(iphLen), uint8(tcphLen), seq, pshSet, gsoSize, item, bufs, offset) + if can != coalesceUnavailable { + result := coalesceTCPPackets(can, pkt, pktI, gsoSize, seq, pshSet, &item, bufs, offset, isV6) + switch result { + case coalesceSuccess: + table.updateAt(item, i) + return groResultCoalesced + case coalesceItemInvalidCSum: + // delete the item with an invalid csum + table.deleteAt(item.key, i) + case coalescePktInvalidCSum: + // no point in inserting an item that we can't coalesce + return groResultNoop + default: + } + } + } + // failed to coalesce with any other packets; store the item in the flow + table.insert(pkt, srcAddrOffset, srcAddrOffset+addrLen, iphLen, tcphLen, pktI) + return groResultTableInsert +} + +// applyTCPCoalesceAccounting updates bufs to account for coalescing based on the +// metadata found in table. +func applyTCPCoalesceAccounting(bufs [][]byte, offset int, table *tcpGROTable) error { + for _, items := range table.itemsByFlow { + for _, item := range items { + if item.numMerged > 0 { + hdr := virtioNetHdr{ + flags: unix.VIRTIO_NET_HDR_F_NEEDS_CSUM, // this turns into CHECKSUM_PARTIAL in the skb + hdrLen: uint16(item.iphLen + item.tcphLen), + gsoSize: item.gsoSize, + csumStart: uint16(item.iphLen), + csumOffset: 16, + } + pkt := bufs[item.bufsIndex][offset:] + + // Recalculate the total len (IPv4) or payload len (IPv6). + // Recalculate the (IPv4) header checksum. + if item.key.isV6 { + hdr.gsoType = unix.VIRTIO_NET_HDR_GSO_TCPV6 + binary.BigEndian.PutUint16(pkt[4:], uint16(len(pkt))-uint16(item.iphLen)) // set new IPv6 header payload len + } else { + hdr.gsoType = unix.VIRTIO_NET_HDR_GSO_TCPV4 + pkt[10], pkt[11] = 0, 0 + binary.BigEndian.PutUint16(pkt[2:], uint16(len(pkt))) // set new total length + iphCSum := ^checksum.Checksum(pkt[:item.iphLen], 0) // compute IPv4 header checksum + binary.BigEndian.PutUint16(pkt[10:], iphCSum) // set IPv4 header checksum field + } + err := hdr.encode(bufs[item.bufsIndex][offset-virtioNetHdrLen:]) + if err != nil { + return err + } + + // Calculate the pseudo header checksum and place it at the TCP + // checksum offset. Downstream checksum offloading will combine + // this with computation of the tcp header and payload checksum. + addrLen := 4 + addrOffset := ipv4SrcAddrOffset + if item.key.isV6 { + addrLen = 16 + addrOffset = ipv6SrcAddrOffset + } + srcAddrAt := offset + addrOffset + srcAddr := bufs[item.bufsIndex][srcAddrAt : srcAddrAt+addrLen] + dstAddr := bufs[item.bufsIndex][srcAddrAt+addrLen : srcAddrAt+addrLen*2] + psum := header.PseudoHeaderChecksum(unix.IPPROTO_TCP, srcAddr, dstAddr, uint16(len(pkt)-int(item.iphLen))) + binary.BigEndian.PutUint16(pkt[hdr.csumStart+hdr.csumOffset:], checksum.Checksum([]byte{}, psum)) + } else { + hdr := virtioNetHdr{} + err := hdr.encode(bufs[item.bufsIndex][offset-virtioNetHdrLen:]) + if err != nil { + return err + } + } + } + } + return nil +} + +// applyUDPCoalesceAccounting updates bufs to account for coalescing based on the +// metadata found in table. 
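+// Buffers that saw no merges still get an all-zero virtioNetHdr prepended so
+// that the kernel treats them as ordinary, non-GSO writes.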
+func applyUDPCoalesceAccounting(bufs [][]byte, offset int, table *udpGROTable) error { + for _, items := range table.itemsByFlow { + for _, item := range items { + if item.numMerged > 0 { + hdr := virtioNetHdr{ + flags: unix.VIRTIO_NET_HDR_F_NEEDS_CSUM, // this turns into CHECKSUM_PARTIAL in the skb + hdrLen: uint16(item.iphLen + udphLen), + gsoSize: item.gsoSize, + csumStart: uint16(item.iphLen), + csumOffset: 6, + } + pkt := bufs[item.bufsIndex][offset:] + + // Recalculate the total len (IPv4) or payload len (IPv6). + // Recalculate the (IPv4) header checksum. + hdr.gsoType = unix.VIRTIO_NET_HDR_GSO_UDP_L4 + if item.key.isV6 { + binary.BigEndian.PutUint16(pkt[4:], uint16(len(pkt))-uint16(item.iphLen)) // set new IPv6 header payload len + } else { + pkt[10], pkt[11] = 0, 0 + binary.BigEndian.PutUint16(pkt[2:], uint16(len(pkt))) // set new total length + iphCSum := ^checksum.Checksum(pkt[:item.iphLen], 0) // compute IPv4 header checksum + binary.BigEndian.PutUint16(pkt[10:], iphCSum) // set IPv4 header checksum field + } + err := hdr.encode(bufs[item.bufsIndex][offset-virtioNetHdrLen:]) + if err != nil { + return err + } + + // Recalculate the UDP len field value + binary.BigEndian.PutUint16(pkt[item.iphLen+4:], uint16(len(pkt[item.iphLen:]))) + + // Calculate the pseudo header checksum and place it at the UDP + // checksum offset. Downstream checksum offloading will combine + // this with computation of the udp header and payload checksum. + addrLen := 4 + addrOffset := ipv4SrcAddrOffset + if item.key.isV6 { + addrLen = 16 + addrOffset = ipv6SrcAddrOffset + } + srcAddrAt := offset + addrOffset + srcAddr := bufs[item.bufsIndex][srcAddrAt : srcAddrAt+addrLen] + dstAddr := bufs[item.bufsIndex][srcAddrAt+addrLen : srcAddrAt+addrLen*2] + psum := header.PseudoHeaderChecksum(unix.IPPROTO_UDP, srcAddr, dstAddr, uint16(len(pkt)-int(item.iphLen))) + binary.BigEndian.PutUint16(pkt[hdr.csumStart+hdr.csumOffset:], checksum.Checksum([]byte{}, psum)) + } else { + hdr := virtioNetHdr{} + err := hdr.encode(bufs[item.bufsIndex][offset-virtioNetHdrLen:]) + if err != nil { + return err + } + } + } + } + return nil +} + +type groCandidateType uint8 + +const ( + notGROCandidate groCandidateType = iota + tcp4GROCandidate + tcp6GROCandidate + udp4GROCandidate + udp6GROCandidate +) + +type groDisablementFlags int + +const ( + tcpGRODisabled groDisablementFlags = 1 << iota + udpGRODisabled +) + +func (g *groDisablementFlags) disableTCPGRO() { + *g |= tcpGRODisabled +} + +func (g *groDisablementFlags) canTCPGRO() bool { + return (*g)&tcpGRODisabled == 0 +} + +func (g *groDisablementFlags) disableUDPGRO() { + *g |= udpGRODisabled +} + +func (g *groDisablementFlags) canUDPGRO() bool { + return (*g)&udpGRODisabled == 0 +} + +func packetIsGROCandidate(b []byte, gro groDisablementFlags) groCandidateType { + if len(b) < 28 { + return notGROCandidate + } + if b[0]>>4 == 4 { + if b[0]&0x0F != 5 { + // IPv4 packets w/IP options do not coalesce + return notGROCandidate + } + if b[9] == unix.IPPROTO_TCP && len(b) >= 40 && gro.canTCPGRO() { + return tcp4GROCandidate + } + if b[9] == unix.IPPROTO_UDP && gro.canUDPGRO() { + return udp4GROCandidate + } + } else if b[0]>>4 == 6 { + if b[6] == unix.IPPROTO_TCP && len(b) >= 60 && gro.canTCPGRO() { + return tcp6GROCandidate + } + if b[6] == unix.IPPROTO_UDP && len(b) >= 48 && gro.canUDPGRO() { + return udp6GROCandidate + } + } + return notGROCandidate +} + +const ( + udphLen = 8 +) + +// udpGRO evaluates the UDP packet at pktI in bufs for coalescing with +// existing packets tracked 
in table. It returns a groResultNoop when no +// action was taken, groResultTableInsert when the evaluated packet was +// inserted into table, and groResultCoalesced when the evaluated packet was +// coalesced with another packet in table. +func udpGRO(bufs [][]byte, offset int, pktI int, table *udpGROTable, isV6 bool) groResult { + pkt := bufs[pktI][offset:] + if len(pkt) > maxUint16 { + // A valid IPv4 or IPv6 packet will never exceed this. + return groResultNoop + } + iphLen := int((pkt[0] & 0x0F) * 4) + if isV6 { + iphLen = 40 + ipv6HPayloadLen := int(binary.BigEndian.Uint16(pkt[4:])) + if ipv6HPayloadLen != len(pkt)-iphLen { + return groResultNoop + } + } else { + totalLen := int(binary.BigEndian.Uint16(pkt[2:])) + if totalLen != len(pkt) { + return groResultNoop + } + } + if len(pkt) < iphLen { + return groResultNoop + } + if len(pkt) < iphLen+udphLen { + return groResultNoop + } + if !isV6 { + if pkt[6]&ipv4FlagMoreFragments != 0 || pkt[6]<<3 != 0 || pkt[7] != 0 { + // no GRO support for fragmented segments for now + return groResultNoop + } + } + gsoSize := uint16(len(pkt) - udphLen - iphLen) + // not a candidate if payload len is 0 + if gsoSize < 1 { + return groResultNoop + } + srcAddrOffset := ipv4SrcAddrOffset + addrLen := 4 + if isV6 { + srcAddrOffset = ipv6SrcAddrOffset + addrLen = 16 + } + items, existing := table.lookupOrInsert(pkt, srcAddrOffset, srcAddrOffset+addrLen, iphLen, pktI) + if !existing { + return groResultTableInsert + } + // With UDP we only check the last item, otherwise we could reorder packets + // for a given flow. We must also always insert a new item, or successfully + // coalesce with an existing item, for the same reason. + item := items[len(items)-1] + can := udpPacketsCanCoalesce(pkt, uint8(iphLen), gsoSize, item, bufs, offset) + var pktCSumKnownInvalid bool + if can == coalesceAppend { + result := coalesceUDPPackets(pkt, &item, bufs, offset, isV6) + switch result { + case coalesceSuccess: + table.updateAt(item, len(items)-1) + return groResultCoalesced + case coalesceItemInvalidCSum: + // If the existing item has an invalid csum we take no action. A new + // item will be stored after it, and the existing item will never be + // revisited as part of future coalescing candidacy checks. + case coalescePktInvalidCSum: + // We must insert a new item, but we also mark it as invalid csum + // to prevent a repeat checksum validation. + pktCSumKnownInvalid = true + default: + } + } + // failed to coalesce with any other packets; store the item in the flow + table.insert(pkt, srcAddrOffset, srcAddrOffset+addrLen, iphLen, pktI, pktCSumKnownInvalid) + return groResultTableInsert +} + +// handleGRO evaluates bufs for GRO, and writes the indices of the resulting +// packets into toWrite. toWrite, tcpTable, and udpTable should initially be +// empty (but non-nil), and are passed in to save allocs as the caller may reset +// and recycle them across vectors of packets. gro indicates if TCP and UDP GRO +// are supported/enabled. 
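+//
+// Packets that coalesce into an earlier buffer are not appended to toWrite, so
+// only the indices in toWrite should be written to the device.
+//
+// A minimal write-path sketch (assuming a *NativeTun t with vnetHdr enabled
+// and each buffer carrying virtioNetHdrLen bytes of front headroom):
+//
+//	t.gsoToWrite = t.gsoToWrite[:0]
+//	if err := handleGRO(buffers, offset, t.tcpGROTable, t.udpGROTable, t.gro, &t.gsoToWrite); err != nil {
+//		return 0, err
+//	}
+//	// on success, write buffers[i] for each i in t.gsoToWrite, then reset
+//	// both tables before reusing them for the next vector.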
+func handleGRO(bufs [][]byte, offset int, tcpTable *tcpGROTable, udpTable *udpGROTable, gro groDisablementFlags, toWrite *[]int) error { + for i := range bufs { + if offset < virtioNetHdrLen || offset > len(bufs[i])-1 { + return errors.New("invalid offset") + } + var result groResult + switch packetIsGROCandidate(bufs[i][offset:], gro) { + case tcp4GROCandidate: + result = tcpGRO(bufs, offset, i, tcpTable, false) + case tcp6GROCandidate: + result = tcpGRO(bufs, offset, i, tcpTable, true) + case udp4GROCandidate: + result = udpGRO(bufs, offset, i, udpTable, false) + case udp6GROCandidate: + result = udpGRO(bufs, offset, i, udpTable, true) + } + switch result { + case groResultNoop: + hdr := virtioNetHdr{} + err := hdr.encode(bufs[i][offset-virtioNetHdrLen:]) + if err != nil { + return err + } + fallthrough + case groResultTableInsert: + *toWrite = append(*toWrite, i) + } + } + errTCP := applyTCPCoalesceAccounting(bufs, offset, tcpTable) + errUDP := applyUDPCoalesceAccounting(bufs, offset, udpTable) + return errors.Join(errTCP, errUDP) +}