Skip to content

Commit

Permalink
Add tcp-brutal support
Browse files Browse the repository at this point in the history
  • Loading branch information
nekohasekai committed Nov 2, 2023
1 parent 1739640 commit a36b958
Show file tree
Hide file tree
Showing 6 changed files with 256 additions and 54 deletions.
60 changes: 60 additions & 0 deletions brutal.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
package mux

import (
"encoding/binary"
"io"

"github.com/sagernet/sing/common"
"github.com/sagernet/sing/common/buf"
E "github.com/sagernet/sing/common/exceptions"
"github.com/sagernet/sing/common/rw"
)

const (
BrutalExchangeDomain = "_BrutalBwExchange"
BrutalMinSpeedBPS = 65536
)

func WriteBrutalRequest(writer io.Writer, receiveBPS uint64) error {
return binary.Write(writer, binary.BigEndian, receiveBPS)
}

func ReadBrutalRequest(reader io.Reader) (uint64, error) {
var receiveBPS uint64
err := binary.Read(reader, binary.BigEndian, &receiveBPS)
return receiveBPS, err
}

func WriteBrutalResponse(writer io.Writer, receiveBPS uint64, ok bool, message string) error {
buffer := buf.New()
defer buffer.Release()
common.Must(binary.Write(buffer, binary.BigEndian, ok))
if ok {
common.Must(binary.Write(buffer, binary.BigEndian, receiveBPS))
}
if !ok {
err := rw.WriteVString(buffer, message)
if err != nil {
return err
}
}
return common.Error(writer.Write(buffer.Bytes()))
}

func ReadBrutalResponse(reader io.Reader) (uint64, error) {
var ok bool
err := binary.Read(reader, binary.BigEndian, &ok)
if err != nil {
return 0, err
}
if ok {
var receiveBPS uint64
err = binary.Read(reader, binary.BigEndian, &receiveBPS)
return receiveBPS, err
}
message, err := rw.ReadVString(reader)
if err != nil {
return 0, err
}
return 0, E.New("remote error: ", message)
}
53 changes: 53 additions & 0 deletions brutal_linux.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
package mux

import (
"net"
"os"
"reflect"
"syscall"
"unsafe"
_ "unsafe"

"github.com/sagernet/sing/common"
"github.com/sagernet/sing/common/control"
E "github.com/sagernet/sing/common/exceptions"

"golang.org/x/sys/unix"
)

const (
TCP_BRUTAL_PARAMS = 23301
)

type TCPBrutalParams struct {
Rate uint64
CwndGain uint32
}

//go:linkname setsockopt syscall.setsockopt
func setsockopt(s int, level int, name int, val unsafe.Pointer, vallen uintptr) (err error)

func SetBrutalOptions(conn net.Conn, sendBPS uint64) error {
syscallConn, loaded := common.Cast[syscall.Conn](conn)
if !loaded {
return E.New("cannot convert from ", reflect.TypeOf(conn), " to syscall.Conn")
}
return control.Conn(syscallConn, func(fd uintptr) error {
err := unix.SetsockoptString(int(fd), unix.IPPROTO_TCP, unix.TCP_CONGESTION, "brutal")
if err != nil {
return E.Extend(
os.NewSyscallError("setsockopt IPPROTO_TCP TCP_CONGESTION brutal", err),
"please make sure you have installed the tcp-brutal kernel module",
)
}
params := TCPBrutalParams{
Rate: sendBPS,
CwndGain: 20, // hysteria2 default
}
err = setsockopt(int(fd), unix.IPPROTO_TCP, TCP_BRUTAL_PARAMS, unsafe.Pointer(&params), unsafe.Sizeof(params))
if err != nil {
return os.NewSyscallError("setsockopt IPPROTO_TCP TCP_BRUTAL_PARAMS", err)
}
return nil
})
}
13 changes: 13 additions & 0 deletions brutal_stub.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
//go:build !linux

package mux

import (
"net"

E "github.com/sagernet/sing/common/exceptions"
)

func SetBrutalOptions(conn net.Conn, sendBPS uint64) error {
return E.New("TCP Brutal is only supported on Linux")
}
53 changes: 53 additions & 0 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,38 +8,51 @@ import (
"github.com/sagernet/sing/common"
"github.com/sagernet/sing/common/bufio"
E "github.com/sagernet/sing/common/exceptions"
"github.com/sagernet/sing/common/logger"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
"github.com/sagernet/sing/common/x/list"
)

type Client struct {
dialer N.Dialer
logger logger.Logger
protocol byte
maxConnections int
minStreams int
maxStreams int
padding bool
access sync.Mutex
connections list.List[abstractSession]
brutal BrutalOptions
}

type Options struct {
Dialer N.Dialer
Logger logger.Logger
Protocol string
MaxConnections int
MinStreams int
MaxStreams int
Padding bool
Brutal BrutalOptions
}

type BrutalOptions struct {
Enabled bool
SendBPS uint64
ReceiveBPS uint64
}

func NewClient(options Options) (*Client, error) {
client := &Client{
dialer: options.Dialer,
logger: options.Logger,
maxConnections: options.MaxConnections,
minStreams: options.MinStreams,
maxStreams: options.MaxStreams,
padding: options.Padding,
brutal: options.Brutal,
}
if client.dialer == nil {
client.dialer = N.SystemDialer
Expand Down Expand Up @@ -125,6 +138,12 @@ func (c *Client) offer(ctx context.Context) (abstractSession, error) {
sessions = append(sessions, element.Value)
element = element.Next()
}
if c.brutal.Enabled {
if len(sessions) > 0 {
return sessions[0], nil
}
return c.offerNew(ctx)
}
session := common.MinBy(common.Filter(sessions, abstractSession.CanTakeNewRequest), abstractSession.NumStreams)
if session == nil {
return c.offerNew(ctx)
Expand Down Expand Up @@ -169,10 +188,44 @@ func (c *Client) offerNew(ctx context.Context) (abstractSession, error) {
conn.Close()
return nil, err
}
if c.brutal.Enabled {
err = c.brutalExchange(conn, session)
if err != nil {
conn.Close()
session.Close()
return nil, E.Cause(err, "brutal exchange")
}
}
c.connections.PushBack(session)
return session, nil
}

func (c *Client) brutalExchange(sessionConn net.Conn, session abstractSession) error {
stream, err := session.Open()
if err != nil {
return err
}
conn := &clientConn{Conn: &wrapStream{stream}, destination: M.Socksaddr{Fqdn: BrutalExchangeDomain}}
err = WriteBrutalRequest(conn, c.brutal.ReceiveBPS)
if err != nil {
return err
}
serverReceiveBPS, err := ReadBrutalResponse(conn)
if err != nil {
return err
}
conn.Close()
sendBPS := c.brutal.SendBPS
if serverReceiveBPS < sendBPS {
sendBPS = serverReceiveBPS
}
clientBrutalErr := SetBrutalOptions(sessionConn, sendBPS)
if clientBrutalErr != nil {
c.logger.Debug(E.Cause(err, "failed to enable TCP Brutal at client"))
}
return nil
}

func (c *Client) Reset() {
c.access.Lock()
defer c.access.Unlock()
Expand Down
95 changes: 77 additions & 18 deletions server.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,19 +12,46 @@ import (
"github.com/sagernet/sing/common/task"
)

type ServerHandler interface {
type ServiceHandler interface {
N.TCPConnectionHandler
N.UDPConnectionHandler
E.Handler
}

func HandleConnection(ctx context.Context, handler ServerHandler, logger logger.ContextLogger, conn net.Conn, metadata M.Metadata) error {
type Service struct {
newStreamContext func(context.Context, net.Conn) context.Context
logger logger.ContextLogger
handler ServiceHandler
padding bool
brutal BrutalOptions
}

type ServiceOptions struct {
NewStreamContext func(context.Context, net.Conn) context.Context
Logger logger.ContextLogger
Handler ServiceHandler
Padding bool
Brutal BrutalOptions
}

func NewService(options ServiceOptions) (*Service, error) {
return &Service{
newStreamContext: options.NewStreamContext,
logger: options.Logger,
handler: options.Handler,
padding: options.Padding,
brutal: options.Brutal,
}, nil
}

func (s *Service) NewConnection(ctx context.Context, conn net.Conn, metadata M.Metadata) error {
request, err := ReadRequest(conn)
if err != nil {
return err
}
if request.Padding {
conn = newPaddingConn(conn)
} else if s.padding {
return E.New("non-padded connection rejected")
}
session, err := newServerSession(conn, request.Protocol)
if err != nil {
Expand All @@ -38,7 +65,13 @@ func HandleConnection(ctx context.Context, handler ServerHandler, logger logger.
if err != nil {
return err
}
go newConnection(ctx, handler, logger, stream, metadata)
streamCtx := s.newStreamContext(ctx, stream)
go func() {
hErr := s.newConnection(streamCtx, stream, metadata)
if hErr != nil {
s.logger.ErrorContext(streamCtx, E.Cause(hErr, "handle connection"))
}
}()
}
})
group.Cleanup(func() {
Expand All @@ -47,34 +80,60 @@ func HandleConnection(ctx context.Context, handler ServerHandler, logger logger.
return group.Run(ctx)
}

func newConnection(ctx context.Context, handler ServerHandler, logger logger.ContextLogger, stream net.Conn, metadata M.Metadata) {
func (s *Service) newConnection(ctx context.Context, stream net.Conn, metadata M.Metadata) error {
stream = &wrapStream{stream}
request, err := ReadStreamRequest(stream)
if err != nil {
logger.ErrorContext(ctx, err)
return
return E.Cause(err, "read multiplex stream request")
}
if request.Destination.Fqdn == BrutalExchangeDomain {
defer stream.Close()
var clientReceiveBPS uint64
clientReceiveBPS, err = ReadBrutalRequest(stream)
if err != nil {
return E.Cause(err, "read brutal request")
}
if !s.brutal.Enabled {
err = WriteBrutalResponse(stream, 0, false, "brutal is not enabled by the server")
if err != nil {
return E.Cause(err, "write brutal response")
}
return nil
}
sendBPS := s.brutal.SendBPS
if clientReceiveBPS < sendBPS {
sendBPS = clientReceiveBPS
}
err = SetBrutalOptions(stream, sendBPS)
if err != nil {
err = WriteBrutalResponse(stream, 0, false, E.Cause(err, "enable TCP Brutal").Error())
if err != nil {
return E.Cause(err, "write brutal response")
}
return nil
}
err = WriteBrutalResponse(stream, s.brutal.ReceiveBPS, true, "")
if err != nil {
return E.Cause(err, "write brutal response")
}
return nil
}
metadata.Destination = request.Destination
if request.Network == N.NetworkTCP {
logger.InfoContext(ctx, "inbound multiplex connection to ", metadata.Destination)
hErr := handler.NewConnection(ctx, &serverConn{ExtendedConn: bufio.NewExtendedConn(stream)}, metadata)
s.logger.InfoContext(ctx, "inbound multiplex connection to ", metadata.Destination)
s.handler.NewConnection(ctx, &serverConn{ExtendedConn: bufio.NewExtendedConn(stream)}, metadata)
stream.Close()
if hErr != nil {
handler.NewError(ctx, hErr)
}
} else {
var packetConn N.PacketConn
if !request.PacketAddr {
logger.InfoContext(ctx, "inbound multiplex packet connection to ", metadata.Destination)
s.logger.InfoContext(ctx, "inbound multiplex packet connection to ", metadata.Destination)
packetConn = &serverPacketConn{ExtendedConn: bufio.NewExtendedConn(stream), destination: request.Destination}
} else {
logger.InfoContext(ctx, "inbound multiplex packet connection")
s.logger.InfoContext(ctx, "inbound multiplex packet connection")
packetConn = &serverPacketAddrConn{ExtendedConn: bufio.NewExtendedConn(stream)}
}
hErr := handler.NewPacketConnection(ctx, packetConn, metadata)
s.handler.NewPacketConnection(ctx, packetConn, metadata)
stream.Close()
if hErr != nil {
handler.NewError(ctx, hErr)
}
}
return nil
}
Loading

0 comments on commit a36b958

Please sign in to comment.