diff --git a/common/bufio/copy_direct_posix.go b/common/bufio/copy_direct_posix.go index d682558fb..e155e3c59 100644 --- a/common/bufio/copy_direct_posix.go +++ b/common/bufio/copy_direct_posix.go @@ -213,7 +213,6 @@ func (w *syscallPacketReadWaiter) InitializeReadWaiter(newBuffer func() *buf.Buf buffer.Truncate(readN) } else { buffer.Release() - buffer = nil } if w.readErr == syscall.EAGAIN { return false diff --git a/common/uot/conn.go b/common/uot/conn.go index 81382a478..3306cc6eb 100644 --- a/common/uot/conn.go +++ b/common/uot/conn.go @@ -13,11 +13,17 @@ import ( N "github.com/sagernet/sing/common/network" ) +var ( + _ N.NetPacketConn = (*Conn)(nil) + _ N.PacketReadWaiter = (*Conn)(nil) +) + type Conn struct { net.Conn isConnect bool destination M.Socksaddr writer N.VectorisedWriter + newBuffer func() *buf.Buffer } func NewConn(conn net.Conn, request Request) *Conn { @@ -135,6 +141,33 @@ func (c *Conn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error { return c.writer.WriteVectorised([]*buf.Buffer{header, buffer}) } +func (c *Conn) InitializeReadWaiter(newBuffer func() *buf.Buffer) { + c.newBuffer = newBuffer +} + +func (c *Conn) WaitReadPacket() (destination M.Socksaddr, err error) { + if c.isConnect { + destination = c.destination + } else { + destination, err = AddrParser.ReadAddrPort(c.Conn) + if err != nil { + return + } + } + var length uint16 + err = binary.Read(c.Conn, binary.BigEndian, &length) + if err != nil { + return + } + buffer := c.newBuffer() + _, err = buffer.ReadFullFrom(c.Conn, int(length)) + if err != nil { + buffer.Release() + return M.Socksaddr{}, E.Cause(err, "UoT read") + } + return +} + func (c *Conn) NeedAdditionalReadDeadline() bool { return true }