Skip to content

Commit

Permalink
Implement read waiter for UDP
Browse files Browse the repository at this point in the history
  • Loading branch information
nekohasekai committed Dec 7, 2023
1 parent 13614c0 commit f08f673
Show file tree
Hide file tree
Showing 5 changed files with 125 additions and 44 deletions.
6 changes: 4 additions & 2 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,8 @@ func (c *Client) DialContext(ctx context.Context, network string, destination M.
if err != nil {
return nil, err
}
return bufio.NewUnbindPacketConn(&clientPacketConn{ExtendedConn: bufio.NewExtendedConn(stream), destination: destination}), nil
extendedConn := bufio.NewExtendedConn(stream)
return &clientPacketConn{AbstractConn: extendedConn, conn: extendedConn, destination: destination}, nil
default:
return nil, E.Extend(N.ErrUnknownNetwork, network)
}
Expand All @@ -97,7 +98,8 @@ func (c *Client) ListenPacket(ctx context.Context, destination M.Socksaddr) (net
if err != nil {
return nil, err
}
return &clientPacketAddrConn{ExtendedConn: bufio.NewExtendedConn(stream), destination: destination}, nil
extendedConn := bufio.NewExtendedConn(stream)
return &clientPacketAddrConn{AbstractConn: extendedConn, conn: extendedConn, destination: destination}, nil
}

func (c *Client) openStream(ctx context.Context) (net.Conn, error) {
Expand Down
84 changes: 45 additions & 39 deletions client_conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -93,20 +93,24 @@ func (c *clientConn) Upstream() any {
return c.Conn
}

var _ N.NetPacketConn = (*clientPacketConn)(nil)

type clientPacketConn struct {
N.ExtendedConn
access sync.Mutex
destination M.Socksaddr
requestWritten bool
responseRead bool
N.AbstractConn
conn N.ExtendedConn
access sync.Mutex
destination M.Socksaddr
requestWritten bool
responseRead bool
readWaitOptions N.ReadWaitOptions
}

func (c *clientPacketConn) NeedHandshake() bool {
return !c.requestWritten
}

func (c *clientPacketConn) readResponse() error {
response, err := ReadStreamResponse(c.ExtendedConn)
response, err := ReadStreamResponse(c.conn)
if err != nil {
return err
}
Expand All @@ -125,14 +129,14 @@ func (c *clientPacketConn) Read(b []byte) (n int, err error) {
c.responseRead = true
}
var length uint16
err = binary.Read(c.ExtendedConn, binary.BigEndian, &length)
err = binary.Read(c.conn, binary.BigEndian, &length)
if err != nil {
return
}
if cap(b) < int(length) {
return 0, io.ErrShortBuffer
}
return io.ReadFull(c.ExtendedConn, b[:length])
return io.ReadFull(c.conn, b[:length])
}

func (c *clientPacketConn) writeRequest(payload []byte) (n int, err error) {
Expand All @@ -156,7 +160,7 @@ func (c *clientPacketConn) writeRequest(payload []byte) (n int, err error) {
common.Error(buffer.Write(payload)),
)
}
_, err = c.ExtendedConn.Write(buffer.Bytes())
_, err = c.conn.Write(buffer.Bytes())
if err != nil {
return
}
Expand All @@ -174,11 +178,11 @@ func (c *clientPacketConn) Write(b []byte) (n int, err error) {
return c.writeRequest(b)
}
}
err = binary.Write(c.ExtendedConn, binary.BigEndian, uint16(len(b)))
err = binary.Write(c.conn, binary.BigEndian, uint16(len(b)))
if err != nil {
return
}
return c.ExtendedConn.Write(b)
return c.conn.Write(b)
}

func (c *clientPacketConn) ReadBuffer(buffer *buf.Buffer) (err error) {
Expand All @@ -190,11 +194,11 @@ func (c *clientPacketConn) ReadBuffer(buffer *buf.Buffer) (err error) {
c.responseRead = true
}
var length uint16
err = binary.Read(c.ExtendedConn, binary.BigEndian, &length)
err = binary.Read(c.conn, binary.BigEndian, &length)
if err != nil {
return
}
_, err = buffer.ReadFullFrom(c.ExtendedConn, int(length))
_, err = buffer.ReadFullFrom(c.conn, int(length))
return
}

Expand All @@ -211,7 +215,7 @@ func (c *clientPacketConn) WriteBuffer(buffer *buf.Buffer) error {
}
bLen := buffer.Len()
binary.BigEndian.PutUint16(buffer.ExtendHeader(2), uint16(bLen))
return c.ExtendedConn.WriteBuffer(buffer)
return c.conn.WriteBuffer(buffer)
}

func (c *clientPacketConn) FrontHeadroom() int {
Expand All @@ -227,14 +231,14 @@ func (c *clientPacketConn) ReadFrom(p []byte) (n int, addr net.Addr, err error)
c.responseRead = true
}
var length uint16
err = binary.Read(c.ExtendedConn, binary.BigEndian, &length)
err = binary.Read(c.conn, binary.BigEndian, &length)
if err != nil {
return
}
if cap(p) < int(length) {
return 0, nil, io.ErrShortBuffer
}
n, err = io.ReadFull(c.ExtendedConn, p[:length])
n, err = io.ReadFull(c.conn, p[:length])
return
}

Expand All @@ -248,11 +252,11 @@ func (c *clientPacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) {
return c.writeRequest(p)
}
}
err = binary.Write(c.ExtendedConn, binary.BigEndian, uint16(len(p)))
err = binary.Write(c.conn, binary.BigEndian, uint16(len(p)))
if err != nil {
return
}
return c.ExtendedConn.Write(p)
return c.conn.Write(p)
}

func (c *clientPacketConn) ReadPacket(buffer *buf.Buffer) (destination M.Socksaddr, err error) {
Expand All @@ -265,7 +269,7 @@ func (c *clientPacketConn) WritePacket(buffer *buf.Buffer, destination M.Socksad
}

func (c *clientPacketConn) LocalAddr() net.Addr {
return c.ExtendedConn.LocalAddr()
return c.conn.LocalAddr()
}

func (c *clientPacketConn) RemoteAddr() net.Addr {
Expand All @@ -277,25 +281,27 @@ func (c *clientPacketConn) NeedAdditionalReadDeadline() bool {
}

func (c *clientPacketConn) Upstream() any {
return c.ExtendedConn
return c.conn
}

var _ N.NetPacketConn = (*clientPacketAddrConn)(nil)

type clientPacketAddrConn struct {
N.ExtendedConn
access sync.Mutex
destination M.Socksaddr
requestWritten bool
responseRead bool
N.AbstractConn
conn N.ExtendedConn
access sync.Mutex
destination M.Socksaddr
requestWritten bool
responseRead bool
readWaitOptions N.ReadWaitOptions
}

func (c *clientPacketAddrConn) NeedHandshake() bool {
return !c.requestWritten
}

func (c *clientPacketAddrConn) readResponse() error {
response, err := ReadStreamResponse(c.ExtendedConn)
response, err := ReadStreamResponse(c.conn)
if err != nil {
return err
}
Expand All @@ -313,7 +319,7 @@ func (c *clientPacketAddrConn) ReadFrom(p []byte) (n int, addr net.Addr, err err
}
c.responseRead = true
}
destination, err := M.SocksaddrSerializer.ReadAddrPort(c.ExtendedConn)
destination, err := M.SocksaddrSerializer.ReadAddrPort(c.conn)
if err != nil {
return
}
Expand All @@ -323,14 +329,14 @@ func (c *clientPacketAddrConn) ReadFrom(p []byte) (n int, addr net.Addr, err err
addr = destination.UDPAddr()
}
var length uint16
err = binary.Read(c.ExtendedConn, binary.BigEndian, &length)
err = binary.Read(c.conn, binary.BigEndian, &length)
if err != nil {
return
}
if cap(p) < int(length) {
return 0, nil, io.ErrShortBuffer
}
n, err = io.ReadFull(c.ExtendedConn, p[:length])
n, err = io.ReadFull(c.conn, p[:length])
return
}

Expand Down Expand Up @@ -360,7 +366,7 @@ func (c *clientPacketAddrConn) writeRequest(payload []byte, destination M.Socksa
common.Error(buffer.Write(payload)),
)
}
_, err = c.ExtendedConn.Write(buffer.Bytes())
_, err = c.conn.Write(buffer.Bytes())
if err != nil {
return
}
Expand All @@ -378,15 +384,15 @@ func (c *clientPacketAddrConn) WriteTo(p []byte, addr net.Addr) (n int, err erro
return c.writeRequest(p, M.SocksaddrFromNet(addr))
}
}
err = M.SocksaddrSerializer.WriteAddrPort(c.ExtendedConn, M.SocksaddrFromNet(addr))
err = M.SocksaddrSerializer.WriteAddrPort(c.conn, M.SocksaddrFromNet(addr))
if err != nil {
return
}
err = binary.Write(c.ExtendedConn, binary.BigEndian, uint16(len(p)))
err = binary.Write(c.conn, binary.BigEndian, uint16(len(p)))
if err != nil {
return
}
return c.ExtendedConn.Write(p)
return c.conn.Write(p)
}

func (c *clientPacketAddrConn) ReadPacket(buffer *buf.Buffer) (destination M.Socksaddr, err error) {
Expand All @@ -397,16 +403,16 @@ func (c *clientPacketAddrConn) ReadPacket(buffer *buf.Buffer) (destination M.Soc
}
c.responseRead = true
}
destination, err = M.SocksaddrSerializer.ReadAddrPort(c.ExtendedConn)
destination, err = M.SocksaddrSerializer.ReadAddrPort(c.conn)
if err != nil {
return
}
var length uint16
err = binary.Read(c.ExtendedConn, binary.BigEndian, &length)
err = binary.Read(c.conn, binary.BigEndian, &length)
if err != nil {
return
}
_, err = buffer.ReadFullFrom(c.ExtendedConn, int(length))
_, err = buffer.ReadFullFrom(c.conn, int(length))
return
}

Expand All @@ -428,11 +434,11 @@ func (c *clientPacketAddrConn) WritePacket(buffer *buf.Buffer, destination M.Soc
return err
}
common.Must(binary.Write(header, binary.BigEndian, uint16(bLen)))
return c.ExtendedConn.WriteBuffer(buffer)
return c.conn.WriteBuffer(buffer)
}

func (c *clientPacketAddrConn) LocalAddr() net.Addr {
return c.ExtendedConn.LocalAddr()
return c.conn.LocalAddr()
}

func (c *clientPacketAddrConn) FrontHeadroom() int {
Expand All @@ -444,5 +450,5 @@ func (c *clientPacketAddrConn) NeedAdditionalReadDeadline() bool {
}

func (c *clientPacketAddrConn) Upstream() any {
return c.ExtendedConn
return c.conn
}
73 changes: 73 additions & 0 deletions client_conn_wait.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
package mux

import (
"encoding/binary"

"github.com/sagernet/sing/common/buf"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
)

var _ N.PacketReadWaiter = (*clientPacketConn)(nil)

func (c *clientPacketConn) InitializeReadWaiter(options N.ReadWaitOptions) (needCopy bool) {
c.readWaitOptions = options
return false
}

func (c *clientPacketConn) WaitReadPacket() (buffer *buf.Buffer, destination M.Socksaddr, err error) {
if !c.responseRead {
err = c.readResponse()
if err != nil {
return
}
c.responseRead = true
}
var length uint16
err = binary.Read(c.conn, binary.BigEndian, &length)
if err != nil {
return
}
buffer = c.readWaitOptions.NewPacketBuffer()
_, err = buffer.ReadFullFrom(c.conn, int(length))
if err != nil {
buffer.Release()
return nil, M.Socksaddr{}, err
}
c.readWaitOptions.PostReturn(buffer)
return
}

var _ N.PacketReadWaiter = (*clientPacketAddrConn)(nil)

func (c *clientPacketAddrConn) InitializeReadWaiter(options N.ReadWaitOptions) (needCopy bool) {
c.readWaitOptions = options
return false
}

func (c *clientPacketAddrConn) WaitReadPacket() (buffer *buf.Buffer, destination M.Socksaddr, err error) {
if !c.responseRead {
err = c.readResponse()
if err != nil {
return
}
c.responseRead = true
}
destination, err = M.SocksaddrSerializer.ReadAddrPort(c.conn)
if err != nil {
return
}
var length uint16
err = binary.Read(c.conn, binary.BigEndian, &length)
if err != nil {
return
}
buffer = c.readWaitOptions.NewPacketBuffer()
_, err = buffer.ReadFullFrom(c.conn, int(length))
if err != nil {
buffer.Release()
return nil, M.Socksaddr{}, err
}
c.readWaitOptions.PostReturn(buffer)
return
}
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ go 1.18

require (
github.com/hashicorp/yamux v0.1.1
github.com/sagernet/sing v0.2.18
github.com/sagernet/sing v0.2.19-0.20231207032540-dbccc28f8194
github.com/sagernet/smux v0.0.0-20230312102458-337ec2a5af37
golang.org/x/net v0.19.0
golang.org/x/sys v0.15.0
Expand Down
4 changes: 2 additions & 2 deletions go.sum
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
github.com/hashicorp/yamux v0.1.1 h1:yrQxtgseBDrq9Y652vSRDvsKCJKOUD+GzTS4Y0Y8pvE=
github.com/hashicorp/yamux v0.1.1/go.mod h1:CtWFDAQgb7dxtzFs4tWbplKIe2jSi3+5vKbgIO0SLnQ=
github.com/sagernet/sing v0.1.8/go.mod h1:jt1w2u7lJQFFSGLiRrRIs5YWmx4kAPfWuOejuDW9qMk=
github.com/sagernet/sing v0.2.18 h1:2Ce4dl0pkWft+4914NGXPb8OiQpgA8UHQ9xFOmgvKuY=
github.com/sagernet/sing v0.2.18/go.mod h1:OL6k2F0vHmEzXz2KW19qQzu172FDgSbUSODylighuVo=
github.com/sagernet/sing v0.2.19-0.20231207032540-dbccc28f8194 h1:lphv+waf4VhMIPkOiTewsHaCrBC7Jyrkt/uOKgjLnso=
github.com/sagernet/sing v0.2.19-0.20231207032540-dbccc28f8194/go.mod h1:Ce5LNojQOgOiWhiD8pPD6E9H7e2KgtOe3Zxx4Ou5u80=
github.com/sagernet/smux v0.0.0-20230312102458-337ec2a5af37 h1:HuE6xSwco/Xed8ajZ+coeYLmioq0Qp1/Z2zczFaV8as=
github.com/sagernet/smux v0.0.0-20230312102458-337ec2a5af37/go.mod h1:3skNSftZDJWTGVtVaM2jfbce8qHnmH/AGDRe62iNOg0=
golang.org/x/net v0.19.0 h1:zTwKpTd2XuCqf8huc7Fo2iSy+4RHPd10s4KzeTnVr1c=
Expand Down

0 comments on commit f08f673

Please sign in to comment.