diff --git a/brutal.go b/brutal.go new file mode 100644 index 0000000..d1bcc65 --- /dev/null +++ b/brutal.go @@ -0,0 +1,60 @@ +package mux + +import ( + "encoding/binary" + "io" + + "github.com/sagernet/sing/common" + "github.com/sagernet/sing/common/buf" + E "github.com/sagernet/sing/common/exceptions" + "github.com/sagernet/sing/common/rw" +) + +const ( + BrutalExchangeDomain = "_BrutalBwExchange" + BrutalMinSpeedBPS = 65536 +) + +func WriteBrutalRequest(writer io.Writer, receiveBPS uint64) error { + return binary.Write(writer, binary.BigEndian, receiveBPS) +} + +func ReadBrutalRequest(reader io.Reader) (uint64, error) { + var receiveBPS uint64 + err := binary.Read(reader, binary.BigEndian, &receiveBPS) + return receiveBPS, err +} + +func WriteBrutalResponse(writer io.Writer, receiveBPS uint64, ok bool, message string) error { + buffer := buf.New() + defer buffer.Release() + common.Must(binary.Write(buffer, binary.BigEndian, ok)) + if ok { + common.Must(binary.Write(buffer, binary.BigEndian, receiveBPS)) + } + if !ok { + err := rw.WriteVString(buffer, message) + if err != nil { + return err + } + } + return common.Error(writer.Write(buffer.Bytes())) +} + +func ReadBrutalResponse(reader io.Reader) (uint64, error) { + var ok bool + err := binary.Read(reader, binary.BigEndian, &ok) + if err != nil { + return 0, err + } + if ok { + var receiveBPS uint64 + err = binary.Read(reader, binary.BigEndian, &receiveBPS) + return receiveBPS, err + } + message, err := rw.ReadVString(reader) + if err != nil { + return 0, err + } + return 0, E.New("remote error: ", message) +} diff --git a/brutal_linux.go b/brutal_linux.go new file mode 100644 index 0000000..d68096d --- /dev/null +++ b/brutal_linux.go @@ -0,0 +1,53 @@ +package mux + +import ( + "net" + "os" + "reflect" + "syscall" + "unsafe" + _ "unsafe" + + "github.com/sagernet/sing/common" + "github.com/sagernet/sing/common/control" + E "github.com/sagernet/sing/common/exceptions" + + "golang.org/x/sys/unix" +) + +const ( + TCP_BRUTAL_PARAMS = 23301 +) + +type TCPBrutalParams struct { + Rate uint64 + CwndGain uint32 +} + +//go:linkname setsockopt syscall.setsockopt +func setsockopt(s int, level int, name int, val unsafe.Pointer, vallen uintptr) (err error) + +func SetBrutalOptions(conn net.Conn, sendBPS uint64) error { + syscallConn, loaded := common.Cast[syscall.Conn](conn) + if !loaded { + return E.New("cannot convert from ", reflect.TypeOf(conn), " to syscall.Conn") + } + return control.Conn(syscallConn, func(fd uintptr) error { + err := unix.SetsockoptString(int(fd), unix.IPPROTO_TCP, unix.TCP_CONGESTION, "brutal") + if err != nil { + return E.Extend( + os.NewSyscallError("setsockopt IPPROTO_TCP TCP_CONGESTION brutal", err), + "please make sure you have installed the tcp-brutal kernel module", + ) + } + params := TCPBrutalParams{ + Rate: sendBPS, + CwndGain: 20, // hysteria2 default + } + err = setsockopt(int(fd), unix.IPPROTO_TCP, TCP_BRUTAL_PARAMS, unsafe.Pointer(¶ms), unsafe.Sizeof(params)) + if err != nil { + return os.NewSyscallError("setsockopt IPPROTO_TCP TCP_BRUTAL_PARAMS", err) + } + return nil + }) +} diff --git a/brutal_stub.go b/brutal_stub.go new file mode 100644 index 0000000..3960802 --- /dev/null +++ b/brutal_stub.go @@ -0,0 +1,13 @@ +//go:build !linux + +package mux + +import ( + "net" + + E "github.com/sagernet/sing/common/exceptions" +) + +func SetBrutalOptions(conn net.Conn, sendBPS uint64) error { + return E.New("TCP Brutal is only supported on Linux") +} diff --git a/client.go b/client.go index 7bba001..cf60ef7 100644 --- a/client.go +++ b/client.go @@ -8,6 +8,7 @@ import ( "github.com/sagernet/sing/common" "github.com/sagernet/sing/common/bufio" E "github.com/sagernet/sing/common/exceptions" + "github.com/sagernet/sing/common/logger" M "github.com/sagernet/sing/common/metadata" N "github.com/sagernet/sing/common/network" "github.com/sagernet/sing/common/x/list" @@ -15,6 +16,7 @@ import ( type Client struct { dialer N.Dialer + logger logger.Logger protocol byte maxConnections int minStreams int @@ -22,24 +24,35 @@ type Client struct { padding bool access sync.Mutex connections list.List[abstractSession] + brutal BrutalOptions } type Options struct { Dialer N.Dialer + Logger logger.Logger Protocol string MaxConnections int MinStreams int MaxStreams int Padding bool + Brutal BrutalOptions +} + +type BrutalOptions struct { + Enabled bool + SendBPS uint64 + ReceiveBPS uint64 } func NewClient(options Options) (*Client, error) { client := &Client{ dialer: options.Dialer, + logger: options.Logger, maxConnections: options.MaxConnections, minStreams: options.MinStreams, maxStreams: options.MaxStreams, padding: options.Padding, + brutal: options.Brutal, } if client.dialer == nil { client.dialer = N.SystemDialer @@ -125,6 +138,12 @@ func (c *Client) offer(ctx context.Context) (abstractSession, error) { sessions = append(sessions, element.Value) element = element.Next() } + if c.brutal.Enabled { + if len(sessions) > 0 { + return sessions[0], nil + } + return c.offerNew(ctx) + } session := common.MinBy(common.Filter(sessions, abstractSession.CanTakeNewRequest), abstractSession.NumStreams) if session == nil { return c.offerNew(ctx) @@ -169,10 +188,44 @@ func (c *Client) offerNew(ctx context.Context) (abstractSession, error) { conn.Close() return nil, err } + if c.brutal.Enabled { + err = c.brutalExchange(conn, session) + if err != nil { + conn.Close() + session.Close() + return nil, E.Cause(err, "brutal exchange") + } + } c.connections.PushBack(session) return session, nil } +func (c *Client) brutalExchange(sessionConn net.Conn, session abstractSession) error { + stream, err := session.Open() + if err != nil { + return err + } + conn := &clientConn{Conn: &wrapStream{stream}, destination: M.Socksaddr{Fqdn: BrutalExchangeDomain}} + err = WriteBrutalRequest(conn, c.brutal.ReceiveBPS) + if err != nil { + return err + } + serverReceiveBPS, err := ReadBrutalResponse(conn) + if err != nil { + return err + } + conn.Close() + sendBPS := c.brutal.SendBPS + if serverReceiveBPS < sendBPS { + sendBPS = serverReceiveBPS + } + clientBrutalErr := SetBrutalOptions(sessionConn, sendBPS) + if clientBrutalErr != nil { + c.logger.Debug(E.Cause(err, "failed to enable TCP Brutal at client")) + } + return nil +} + func (c *Client) Reset() { c.access.Lock() defer c.access.Unlock() diff --git a/server.go b/server.go index a805254..5bad622 100644 --- a/server.go +++ b/server.go @@ -12,19 +12,46 @@ import ( "github.com/sagernet/sing/common/task" ) -type ServerHandler interface { +type ServiceHandler interface { N.TCPConnectionHandler N.UDPConnectionHandler - E.Handler } -func HandleConnection(ctx context.Context, handler ServerHandler, logger logger.ContextLogger, conn net.Conn, metadata M.Metadata) error { +type Service struct { + newStreamContext func(context.Context, net.Conn) context.Context + logger logger.ContextLogger + handler ServiceHandler + padding bool + brutal BrutalOptions +} + +type ServiceOptions struct { + NewStreamContext func(context.Context, net.Conn) context.Context + Logger logger.ContextLogger + Handler ServiceHandler + Padding bool + Brutal BrutalOptions +} + +func NewService(options ServiceOptions) (*Service, error) { + return &Service{ + newStreamContext: options.NewStreamContext, + logger: options.Logger, + handler: options.Handler, + padding: options.Padding, + brutal: options.Brutal, + }, nil +} + +func (s *Service) NewConnection(ctx context.Context, conn net.Conn, metadata M.Metadata) error { request, err := ReadRequest(conn) if err != nil { return err } if request.Padding { conn = newPaddingConn(conn) + } else if s.padding { + return E.New("non-padded connection rejected") } session, err := newServerSession(conn, request.Protocol) if err != nil { @@ -38,7 +65,13 @@ func HandleConnection(ctx context.Context, handler ServerHandler, logger logger. if err != nil { return err } - go newConnection(ctx, handler, logger, stream, metadata) + streamCtx := s.newStreamContext(ctx, stream) + go func() { + hErr := s.newConnection(streamCtx, stream, metadata) + if hErr != nil { + s.logger.ErrorContext(streamCtx, E.Cause(hErr, "handle connection")) + } + }() } }) group.Cleanup(func() { @@ -47,34 +80,60 @@ func HandleConnection(ctx context.Context, handler ServerHandler, logger logger. return group.Run(ctx) } -func newConnection(ctx context.Context, handler ServerHandler, logger logger.ContextLogger, stream net.Conn, metadata M.Metadata) { +func (s *Service) newConnection(ctx context.Context, stream net.Conn, metadata M.Metadata) error { stream = &wrapStream{stream} request, err := ReadStreamRequest(stream) if err != nil { - logger.ErrorContext(ctx, err) - return + return E.Cause(err, "read multiplex stream request") + } + if request.Destination.Fqdn == BrutalExchangeDomain { + defer stream.Close() + var clientReceiveBPS uint64 + clientReceiveBPS, err = ReadBrutalRequest(stream) + if err != nil { + return E.Cause(err, "read brutal request") + } + if !s.brutal.Enabled { + err = WriteBrutalResponse(stream, 0, false, "brutal is not enabled by the server") + if err != nil { + return E.Cause(err, "write brutal response") + } + return nil + } + sendBPS := s.brutal.SendBPS + if clientReceiveBPS < sendBPS { + sendBPS = clientReceiveBPS + } + err = SetBrutalOptions(stream, sendBPS) + if err != nil { + err = WriteBrutalResponse(stream, 0, false, E.Cause(err, "enable TCP Brutal").Error()) + if err != nil { + return E.Cause(err, "write brutal response") + } + return nil + } + err = WriteBrutalResponse(stream, s.brutal.ReceiveBPS, true, "") + if err != nil { + return E.Cause(err, "write brutal response") + } + return nil } metadata.Destination = request.Destination if request.Network == N.NetworkTCP { - logger.InfoContext(ctx, "inbound multiplex connection to ", metadata.Destination) - hErr := handler.NewConnection(ctx, &serverConn{ExtendedConn: bufio.NewExtendedConn(stream)}, metadata) + s.logger.InfoContext(ctx, "inbound multiplex connection to ", metadata.Destination) + s.handler.NewConnection(ctx, &serverConn{ExtendedConn: bufio.NewExtendedConn(stream)}, metadata) stream.Close() - if hErr != nil { - handler.NewError(ctx, hErr) - } } else { var packetConn N.PacketConn if !request.PacketAddr { - logger.InfoContext(ctx, "inbound multiplex packet connection to ", metadata.Destination) + s.logger.InfoContext(ctx, "inbound multiplex packet connection to ", metadata.Destination) packetConn = &serverPacketConn{ExtendedConn: bufio.NewExtendedConn(stream), destination: request.Destination} } else { - logger.InfoContext(ctx, "inbound multiplex packet connection") + s.logger.InfoContext(ctx, "inbound multiplex packet connection") packetConn = &serverPacketAddrConn{ExtendedConn: bufio.NewExtendedConn(stream)} } - hErr := handler.NewPacketConnection(ctx, packetConn, metadata) + s.handler.NewPacketConnection(ctx, packetConn, metadata) stream.Close() - if hErr != nil { - handler.NewError(ctx, hErr) - } } + return nil } diff --git a/server_default.go b/server_default.go deleted file mode 100644 index f10247e..0000000 --- a/server_default.go +++ /dev/null @@ -1,36 +0,0 @@ -package mux - -import ( - "context" - "net" - - "github.com/sagernet/sing/common/bufio" - "github.com/sagernet/sing/common/logger" - M "github.com/sagernet/sing/common/metadata" - N "github.com/sagernet/sing/common/network" -) - -func HandleConnectionDefault(ctx context.Context, conn net.Conn) error { - return HandleConnection(ctx, (*defaultServerHandler)(nil), logger.NOP(), conn, M.Metadata{}) -} - -type defaultServerHandler struct{} - -func (h *defaultServerHandler) NewConnection(ctx context.Context, conn net.Conn, metadata M.Metadata) error { - remoteConn, err := N.SystemDialer.DialContext(ctx, N.NetworkTCP, metadata.Destination) - if err != nil { - return err - } - return bufio.CopyConn(ctx, conn, remoteConn) -} - -func (h *defaultServerHandler) NewPacketConnection(ctx context.Context, conn N.PacketConn, metadata M.Metadata) error { - remoteConn, err := N.SystemDialer.ListenPacket(ctx, metadata.Destination) - if err != nil { - return err - } - return bufio.CopyPacketConn(ctx, conn, bufio.NewPacketConn(remoteConn)) -} - -func (h *defaultServerHandler) NewError(ctx context.Context, err error) { -}