Skip to content

Commit

Permalink
Improve read waiter interface
Browse files Browse the repository at this point in the history
  • Loading branch information
nekohasekai committed May 12, 2023
1 parent ab3e469 commit 9be7806
Show file tree
Hide file tree
Showing 2 changed files with 51 additions and 32 deletions.
77 changes: 47 additions & 30 deletions common/bufio/copy_direct_posix.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"errors"
"io"
"net/netip"
"os"
"syscall"

"github.com/sagernet/sing/common/buf"
Expand All @@ -25,24 +26,21 @@ func copyWaitWithPool(originDestination io.Writer, destination N.ExtendedWriter,
bufferSize = buf.BufferSize
}
var (
buffer *buf.Buffer
readBuffer *buf.Buffer
buffer *buf.Buffer
readBuffer *buf.Buffer
notFirstTime bool
)
newBuffer := func() *buf.Buffer {
if buffer != nil {
buffer.Release()
}
source.InitializeReadWaiter(func() *buf.Buffer {
buffer = buf.NewSize(bufferSize)
readBufferRaw := buffer.Slice()
readBuffer = buf.With(readBufferRaw[:len(readBufferRaw)-rearHeadroom])
readBuffer.Resize(frontHeadroom, 0)
return readBuffer
}
var notFirstTime bool
})
defer source.InitializeReadWaiter(nil)
for {
err = source.WaitReadBuffer(newBuffer)
err = source.WaitReadBuffer()
if err != nil {
buffer.Release()
if errors.Is(err, io.EOF) {
err = nil
return
Expand All @@ -56,9 +54,7 @@ func copyWaitWithPool(originDestination io.Writer, destination N.ExtendedWriter,
buffer.Resize(readBuffer.Start(), dataLen)
err = destination.WriteBuffer(buffer)
if err != nil {
if buffer != nil {
buffer.Release()
}
buffer.Release()
return
}
n += int64(dataLen)
Expand All @@ -83,25 +79,22 @@ func copyPacketWaitWithPool(destinationConn N.PacketWriter, source N.PacketReadW
bufferSize = buf.UDPBufferSize
}
var (
buffer *buf.Buffer
readBuffer *buf.Buffer
buffer *buf.Buffer
readBuffer *buf.Buffer
destination M.Socksaddr
notFirstTime bool
)
newBuffer := func() *buf.Buffer {
if buffer != nil {
buffer.Release()
}
source.InitializeReadWaiter(func() *buf.Buffer {
buffer = buf.NewSize(bufferSize)
readBufferRaw := buffer.Slice()
readBuffer = buf.With(readBufferRaw[:len(readBufferRaw)-rearHeadroom])
readBuffer.Resize(frontHeadroom, 0)
return readBuffer
}
var destination M.Socksaddr
var notFirstTime bool
})
defer source.InitializeReadWaiter(nil)
for {
destination, err = source.WaitReadPacket(newBuffer)
destination, err = source.WaitReadPacket()
if err != nil {
buffer.Release()
if !notFirstTime {
err = N.HandshakeFailure(destinationConn, err)
}
Expand All @@ -113,8 +106,6 @@ func copyPacketWaitWithPool(destinationConn N.PacketWriter, source N.PacketReadW
if err != nil {
buffer.Release()
return
} else {
buffer = nil
}
n += int64(dataLen)
for _, counter := range readCounters {
Expand All @@ -127,6 +118,8 @@ func copyPacketWaitWithPool(destinationConn N.PacketWriter, source N.PacketReadW
}
}

var _ N.ReadWaiter = (*syscallReadWaiter)(nil)

type syscallReadWaiter struct {
rawConn syscall.RawConn
readErr error
Expand All @@ -143,8 +136,11 @@ func createSyscallReadWaiter(reader any) (*syscallReadWaiter, bool) {
return nil, false
}

func (w *syscallReadWaiter) WaitReadBuffer(newBuffer func() *buf.Buffer) error {
if w.readFunc == nil {
func (w *syscallReadWaiter) InitializeReadWaiter(newBuffer func() *buf.Buffer) {
w.readErr = nil
if newBuffer == nil {
w.readFunc = nil
} else {
w.readFunc = func(fd uintptr) (done bool) {
buffer := newBuffer()
var readN int
Expand All @@ -164,16 +160,27 @@ func (w *syscallReadWaiter) WaitReadBuffer(newBuffer func() *buf.Buffer) error {
return true
}
}
}

func (w *syscallReadWaiter) WaitReadBuffer() error {
if w.readFunc == nil {
return os.ErrInvalid
}
err := w.rawConn.Read(w.readFunc)
if err != nil {
return err
}
if w.readErr != nil {
if w.readErr == io.EOF {
return io.EOF
}
return E.Cause(w.readErr, "raw read")
}
return nil
}

var _ N.PacketReadWaiter = (*syscallPacketReadWaiter)(nil)

type syscallPacketReadWaiter struct {
rawConn syscall.RawConn
readErr error
Expand All @@ -191,8 +198,12 @@ func createSyscallPacketReadWaiter(reader any) (*syscallPacketReadWaiter, bool)
return nil, false
}

func (w *syscallPacketReadWaiter) WaitReadPacket(newBuffer func() *buf.Buffer) (destination M.Socksaddr, err error) {
if w.readFunc == nil {
func (w *syscallPacketReadWaiter) InitializeReadWaiter(newBuffer func() *buf.Buffer) {
w.readErr = nil
w.readFrom = M.Socksaddr{}
if newBuffer == nil {
w.readFunc = nil
} else {
w.readFunc = func(fd uintptr) (done bool) {
buffer := newBuffer()
var readN int
Expand Down Expand Up @@ -221,6 +232,12 @@ func (w *syscallPacketReadWaiter) WaitReadPacket(newBuffer func() *buf.Buffer) (
return true
}
}
}

func (w *syscallPacketReadWaiter) WaitReadPacket() (destination M.Socksaddr, err error) {
if w.readFunc == nil {
return M.Socksaddr{}, os.ErrInvalid
}
err = w.rawConn.Read(w.readFunc)
if err != nil {
return
Expand Down
6 changes: 4 additions & 2 deletions common/network/direct.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,17 @@ import (
)

type ReadWaiter interface {
WaitReadBuffer(newBuffer func() *buf.Buffer) error
InitializeReadWaiter(newBuffer func() *buf.Buffer)
WaitReadBuffer() error
}

type ReadWaitCreator interface {
CreateReadWaiter() (ReadWaiter, bool)
}

type PacketReadWaiter interface {
WaitReadPacket(newBuffer func() *buf.Buffer) (destination M.Socksaddr, err error)
InitializeReadWaiter(newBuffer func() *buf.Buffer)
WaitReadPacket() (destination M.Socksaddr, err error)
}

type PacketReadWaitCreator interface {
Expand Down

0 comments on commit 9be7806

Please sign in to comment.