From 273d2b4fbbd97361d947e18eceb5bc1dda99c7ab Mon Sep 17 00:00:00 2001 From: sukun Date: Thu, 22 Aug 2024 23:52:21 +0530 Subject: [PATCH 1/3] allow exchanging error codes on session termination --- const.go | 68 ++-------------------------------- errors.go | 97 +++++++++++++++++++++++++++++++++++++++++++++++++ session.go | 62 ++++++++++++++++++++----------- session_test.go | 52 ++++++++++++++++++++++---- 4 files changed, 185 insertions(+), 94 deletions(-) create mode 100644 errors.go diff --git a/const.go b/const.go index e4b2bc2..607c25d 100644 --- a/const.go +++ b/const.go @@ -3,71 +3,7 @@ package yamux import ( "encoding/binary" "fmt" -) - -type Error struct { - msg string - timeout, temporary bool -} - -func (ye *Error) Error() string { - return ye.msg -} - -func (ye *Error) Timeout() bool { - return ye.timeout -} - -func (ye *Error) Temporary() bool { - return ye.temporary -} - -var ( - // ErrInvalidVersion means we received a frame with an - // invalid version - ErrInvalidVersion = &Error{msg: "invalid protocol version"} - - // ErrInvalidMsgType means we received a frame with an - // invalid message type - ErrInvalidMsgType = &Error{msg: "invalid msg type"} - - // ErrSessionShutdown is used if there is a shutdown during - // an operation - ErrSessionShutdown = &Error{msg: "session shutdown"} - - // ErrStreamsExhausted is returned if we have no more - // stream ids to issue - ErrStreamsExhausted = &Error{msg: "streams exhausted"} - - // ErrDuplicateStream is used if a duplicate stream is - // opened inbound - ErrDuplicateStream = &Error{msg: "duplicate stream initiated"} - - // ErrReceiveWindowExceeded indicates the window was exceeded - ErrRecvWindowExceeded = &Error{msg: "recv window exceeded"} - - // ErrTimeout is used when we reach an IO deadline - ErrTimeout = &Error{msg: "i/o deadline reached", timeout: true, temporary: true} - - // ErrStreamClosed is returned when using a closed stream - ErrStreamClosed = &Error{msg: "stream closed"} - - // ErrUnexpectedFlag is set when we get an unexpected flag - ErrUnexpectedFlag = &Error{msg: "unexpected flag"} - - // ErrRemoteGoAway is used when we get a go away from the other side - ErrRemoteGoAway = &Error{msg: "remote end is not accepting connections"} - - // ErrStreamReset is sent if a stream is reset. This can happen - // if the backlog is exceeded, or if there was a remote GoAway. - ErrStreamReset = &Error{msg: "stream reset"} - - // ErrConnectionWriteTimeout indicates that we hit the "safety valve" - // timeout writing to the underlying stream connection. - ErrConnectionWriteTimeout = &Error{msg: "connection write timeout", timeout: true} - - // ErrKeepAliveTimeout is sent if a missed keepalive caused the stream close - ErrKeepAliveTimeout = &Error{msg: "keepalive timeout", timeout: true} + "time" ) const ( @@ -117,6 +53,8 @@ const ( // It's not an implementation choice, the value defined in the specification. initialStreamWindow = 256 * 1024 maxStreamWindow = 16 * 1024 * 1024 + // goAwayWaitTime is the time we'll wait to send a goaway frame on close + goAwayWaitTime = 5 * time.Second ) const ( diff --git a/errors.go b/errors.go new file mode 100644 index 0000000..d5f0c59 --- /dev/null +++ b/errors.go @@ -0,0 +1,97 @@ +package yamux + +import "fmt" + +type Error struct { + msg string + timeout, temporary bool +} + +func (ye *Error) Error() string { + return ye.msg +} + +func (ye *Error) Timeout() bool { + return ye.timeout +} + +func (ye *Error) Temporary() bool { + return ye.temporary +} + +type ErrorGoAway struct { + Remote bool + ErrorCode uint32 +} + +func (e *ErrorGoAway) Error() string { + if e.Remote { + return fmt.Sprintf("remote sent go away, code: %d", e.ErrorCode) + } + return fmt.Sprintf("sent go away, code: %d", e.ErrorCode) +} + +func (e *ErrorGoAway) Timeout() bool { + return false +} + +func (e *ErrorGoAway) Temporary() bool { + return false +} + +func (e *ErrorGoAway) Is(target error) bool { + // to maintain compatibility with errors returned by previous versions + if e.Remote { + return target == ErrRemoteGoAway + } else { + return target == ErrSessionShutdown + } +} + +var ( + // ErrInvalidVersion means we received a frame with an + // invalid version + ErrInvalidVersion = &Error{msg: "invalid protocol version"} + + // ErrInvalidMsgType means we received a frame with an + // invalid message type + ErrInvalidMsgType = &Error{msg: "invalid msg type"} + + // ErrSessionShutdown is used if there is a shutdown during + // an operation + ErrSessionShutdown = &Error{msg: "session shutdown"} + + // ErrStreamsExhausted is returned if we have no more + // stream ids to issue + ErrStreamsExhausted = &Error{msg: "streams exhausted"} + + // ErrDuplicateStream is used if a duplicate stream is + // opened inbound + ErrDuplicateStream = &Error{msg: "duplicate stream initiated"} + + // ErrReceiveWindowExceeded indicates the window was exceeded + ErrRecvWindowExceeded = &Error{msg: "recv window exceeded"} + + // ErrTimeout is used when we reach an IO deadline + ErrTimeout = &Error{msg: "i/o deadline reached", timeout: true, temporary: true} + + // ErrStreamClosed is returned when using a closed stream + ErrStreamClosed = &Error{msg: "stream closed"} + + // ErrUnexpectedFlag is set when we get an unexpected flag + ErrUnexpectedFlag = &Error{msg: "unexpected flag"} + + // ErrRemoteGoAway is used when we get a go away from the other side + ErrRemoteGoAway = &Error{msg: "remote end is not accepting connections"} + + // ErrStreamReset is sent if a stream is reset. This can happen + // if the backlog is exceeded, or if there was a remote GoAway. + ErrStreamReset = &Error{msg: "stream reset"} + + // ErrConnectionWriteTimeout indicates that we hit the "safety valve" + // timeout writing to the underlying stream connection. + ErrConnectionWriteTimeout = &Error{msg: "connection write timeout", timeout: true} + + // ErrKeepAliveTimeout is sent if a missed keepalive caused the stream close + ErrKeepAliveTimeout = &Error{msg: "keepalive timeout", timeout: true} +) diff --git a/session.go b/session.go index c4cd1bd..0f9a2e3 100644 --- a/session.go +++ b/session.go @@ -3,6 +3,7 @@ package yamux import ( "bufio" "context" + "errors" "fmt" "io" "log" @@ -46,12 +47,8 @@ var nullMemoryManager = &nullMemoryManagerImpl{} type Session struct { rtt int64 // to be accessed atomically, in nanoseconds - // remoteGoAway indicates the remote side does - // not want futher connections. Must be first for alignment. - remoteGoAway int32 - // localGoAway indicates that we should stop - // accepting futher connections. Must be first for alignment. + // accepting futher streams. Must be first for alignment. localGoAway int32 // nextStreamID is the next stream we should @@ -203,9 +200,6 @@ func (s *Session) OpenStream(ctx context.Context) (*Stream, error) { if s.IsClosed() { return nil, s.shutdownErr } - if atomic.LoadInt32(&s.remoteGoAway) == 1 { - return nil, ErrRemoteGoAway - } // Block if we have too many inflight SYNs select { @@ -284,24 +278,47 @@ func (s *Session) AcceptStream() (*Stream, error) { } // Close is used to close the session and all streams. -// Attempts to send a GoAway before closing the connection. +// Sends a GoAway before closing the connection. func (s *Session) Close() error { + return s.CloseWithError(goAwayNormal) +} + +// CloseWithError closes the session sending errCode in a goaway frame +func (s *Session) CloseWithError(errCode uint32) error { + return s.closeWithError(errCode, true) +} + +func (s *Session) closeWithError(errCode uint32, sendGoAway bool) error { s.shutdownLock.Lock() defer s.shutdownLock.Unlock() - if s.shutdown { return nil } s.shutdown = true if s.shutdownErr == nil { - s.shutdownErr = ErrSessionShutdown + s.shutdownErr = &ErrorGoAway{Remote: !sendGoAway, ErrorCode: errCode} } close(s.shutdownCh) - s.conn.Close() s.stopKeepalive() - <-s.recvDoneCh + + // wait for write loop + _ = s.conn.SetWriteDeadline(time.Now().Add(-1 * time.Hour)) // if SetWriteDeadline errored, any blocked writes will be unblocked <-s.sendDoneCh + // send the goaway frame + if sendGoAway { + buf := pool.Get(headerSize) + hdr := s.goAway(errCode) + copy(buf, hdr[:]) + if err := s.conn.SetWriteDeadline(time.Now().Add(goAwayWaitTime)); err == nil { + _, _ = s.conn.Write(buf) // Ignore the error. We are going to close the connection anyway + } + } + s.conn.Close() + + // wait for read loop + <-s.recvDoneCh + s.streamLock.Lock() defer s.streamLock.Unlock() for id, stream := range s.streams { @@ -320,11 +337,11 @@ func (s *Session) exitErr(err error) { s.shutdownErr = err } s.shutdownLock.Unlock() - s.Close() + s.closeWithError(0, false) } // GoAway can be used to prevent accepting further -// connections. It does not close the underlying conn. +// streams. It does not close the underlying conn. func (s *Session) GoAway() error { return s.sendMsg(s.goAway(goAwayNormal), nil, nil) } @@ -516,6 +533,12 @@ func (s *Session) sendMsg(hdr header, body []byte, deadline <-chan struct{}) err // send is a long running goroutine that sends data func (s *Session) send() { if err := s.sendLoop(); err != nil { + if !s.IsClosed() && (errors.Is(err, net.ErrClosed) || errors.Is(err, io.ErrClosedPipe) || strings.Contains(err.Error(), "reset") || strings.Contains(err.Error(), "broken pipe")) { + // if remote has closed the connection, wait for recv loop to exit + // unfortunately it is impossible to close the connection such that FIN is sent and not RST + <-s.recvDoneCh + return + } s.exitErr(err) } } @@ -781,18 +804,15 @@ func (s *Session) handleGoAway(hdr header) error { code := hdr.Length() switch code { case goAwayNormal: - atomic.SwapInt32(&s.remoteGoAway, 1) + // Non error termination. Don't log. case goAwayProtoErr: s.logger.Printf("[ERR] yamux: received protocol error go away") - return fmt.Errorf("yamux protocol error") case goAwayInternalErr: s.logger.Printf("[ERR] yamux: received internal error go away") - return fmt.Errorf("remote yamux internal error") default: - s.logger.Printf("[ERR] yamux: received unexpected go away") - return fmt.Errorf("unexpected go away received") + // application error code, let the application log } - return nil + return &ErrorGoAway{ErrorCode: code, Remote: true} } // incomingStream is used to create a new incoming stream diff --git a/session_test.go b/session_test.go index 974b6d5..9ef7a3d 100644 --- a/session_test.go +++ b/session_test.go @@ -3,6 +3,7 @@ package yamux import ( "bytes" "context" + "errors" "fmt" "io" "math/rand" @@ -627,28 +628,63 @@ func TestSendData_Large(t *testing.T) { } } +func testTCPConns(t *testing.T) (*Session, *Session) { + ln, err := net.ListenTCP("tcp", &net.TCPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 0}) + if err != nil { + t.Fatal(err) + } + serverConnCh := make(chan net.Conn, 1) + go func() { + conn, err := ln.Accept() + if err != nil { + return + } + serverConnCh <- conn + }() + + clientConn, err := net.DialTCP("tcp", nil, ln.Addr().(*net.TCPAddr)) + if err != nil { + ln.Close() + t.Fatal(err) + return nil, nil + } + + client, _ := Client(clientConn, testConf(), nil) + server, _ := Server(<-serverConnCh, testConf(), nil) + return client, server + +} + func TestGoAway(t *testing.T) { // This test is noisy. conf := testConf() conf.LogOutput = io.Discard - client, server := testClientServerConfig(conf) + client, server := testTCPConns(t) defer client.Close() defer server.Close() - if err := server.GoAway(); err != nil { + if err := server.CloseWithError(42); err != nil { t.Fatalf("err: %v", err) } for i := 0; i < 100; i++ { s, err := client.Open(context.Background()) - switch err { - case nil: - s.Close() - case ErrRemoteGoAway: + if err != nil { + if !errors.Is(err, ErrRemoteGoAway) { + t.Fatal("expected error to be ErrRemoteGoAway, got", err) + } + errExpected := &ErrorGoAway{Remote: true, ErrorCode: 42} + errGot, ok := err.(*ErrorGoAway) + if !ok { + t.Fatalf("expected type *ErrorGoAway, got %T", err) + } + if *errGot != *errExpected { + t.Fatalf("invalid error, expected %v, got %v", errExpected, errGot) + } return - default: - t.Fatalf("err: %v", err) + } else { + s.Close() } } t.Fatalf("expected GoAway error") From 276b8919b02f1e2d8f3eabd5eae7bd5121a944ea Mon Sep 17 00:00:00 2001 From: sukun Date: Fri, 23 Aug 2024 16:03:37 +0530 Subject: [PATCH 2/3] debug test --- session_test.go | 35 ++++++++++++++++++++++++++++++----- 1 file changed, 30 insertions(+), 5 deletions(-) diff --git a/session_test.go b/session_test.go index 9ef7a3d..46d0c7c 100644 --- a/session_test.go +++ b/session_test.go @@ -628,7 +628,7 @@ func TestSendData_Large(t *testing.T) { } } -func testTCPConns(t *testing.T) (*Session, *Session) { +func testTCPConns(t *testing.T) (net.Conn, net.Conn) { ln, err := net.ListenTCP("tcp", &net.TCPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 0}) if err != nil { t.Fatal(err) @@ -649,9 +649,7 @@ func testTCPConns(t *testing.T) (*Session, *Session) { return nil, nil } - client, _ := Client(clientConn, testConf(), nil) - server, _ := Server(<-serverConnCh, testConf(), nil) - return client, server + return clientConn, <-serverConnCh } @@ -660,7 +658,9 @@ func TestGoAway(t *testing.T) { conf := testConf() conf.LogOutput = io.Discard - client, server := testTCPConns(t) + clientConn, serverConn := testTCPConns(t) + client, _ := Client(clientConn, testConf(), nil) + server, _ := Server(serverConn, testConf(), nil) defer client.Close() defer server.Close() @@ -1821,3 +1821,28 @@ func TestMaxIncomingStreams(t *testing.T) { _, err = str.Read([]byte{0}) require.NoError(t, err) } + +func TestRSTBehavior(t *testing.T) { + client, server := testTCPConns(t) + var wg sync.WaitGroup + wg.Add(1) + go func() { + defer wg.Done() + defer server.Close() + server.Write([]byte("hello")) + time.Sleep(20 * time.Second) + buf := make([]byte, 10) + n, err := server.Read(buf) + if err != nil { + t.Error(err) + } else { + t.Log(string(buf[:n])) + } + + }() + client.Write([]byte("world")) + time.Sleep(10 * time.Second) + // close client without reading server msg. This ensures that the TCP stack sends an RST + client.Close() + wg.Wait() +} From 1ed15bb8621ce7d49d811c0a024e7fdf6fdde1ce Mon Sep 17 00:00:00 2001 From: sukun Date: Fri, 23 Aug 2024 20:34:19 +0530 Subject: [PATCH 3/3] do a graceful close --- session.go | 57 ++++++++++++++++++++++++++++++++++++++----------- session_test.go | 38 ++++++++------------------------- 2 files changed, 54 insertions(+), 41 deletions(-) diff --git a/session.go b/session.go index 0f9a2e3..973fa98 100644 --- a/session.go +++ b/session.go @@ -42,6 +42,10 @@ func (n nullMemoryManagerImpl) Done() {} var nullMemoryManager = &nullMemoryManagerImpl{} +type CloseWriter interface { + CloseWrite() error +} + // Session is used to wrap a reliable ordered connection and to // multiplex it into multiple streams. type Session struct { @@ -304,18 +308,27 @@ func (s *Session) closeWithError(errCode uint32, sendGoAway bool) error { // wait for write loop _ = s.conn.SetWriteDeadline(time.Now().Add(-1 * time.Hour)) // if SetWriteDeadline errored, any blocked writes will be unblocked <-s.sendDoneCh - // send the goaway frame if sendGoAway { buf := pool.Get(headerSize) hdr := s.goAway(errCode) copy(buf, hdr[:]) if err := s.conn.SetWriteDeadline(time.Now().Add(goAwayWaitTime)); err == nil { - _, _ = s.conn.Write(buf) // Ignore the error. We are going to close the connection anyway + if _, err = s.conn.Write(buf); err != nil { + sendGoAway = false + } + } else { + sendGoAway = false } } - s.conn.Close() - + if w, ok := s.conn.(CloseWriter); ok && sendGoAway { + if err := w.CloseWrite(); err != nil { + s.conn.Close() + } + } else { + s.conn.Close() + } + s.conn.SetReadDeadline(time.Now().Add(-1 * time.Hour)) // wait for read loop <-s.recvDoneCh @@ -533,12 +546,6 @@ func (s *Session) sendMsg(hdr header, body []byte, deadline <-chan struct{}) err // send is a long running goroutine that sends data func (s *Session) send() { if err := s.sendLoop(); err != nil { - if !s.IsClosed() && (errors.Is(err, net.ErrClosed) || errors.Is(err, io.ErrClosedPipe) || strings.Contains(err.Error(), "reset") || strings.Contains(err.Error(), "broken pipe")) { - // if remote has closed the connection, wait for recv loop to exit - // unfortunately it is impossible to close the connection such that FIN is sent and not RST - <-s.recvDoneCh - return - } s.exitErr(err) } } @@ -654,7 +661,6 @@ func (s *Session) sendLoop() (err error) { _, err := writer.Write(buf) pool.Put(buf) - if err != nil { if os.IsTimeout(err) { err = ErrConnectionWriteTimeout @@ -689,12 +695,39 @@ func (s *Session) recvLoop() (err error) { err = fmt.Errorf("panic in yamux receive loop: %s", rerr) } }() - defer close(s.recvDoneCh) + + gracefulCloseErr := errors.New("close gracefully") + defer func() { + close(s.recvDoneCh) + errGoAway := &ErrorGoAway{} + if errors.As(err, &errGoAway) { + return + } + if err != gracefulCloseErr { + s.conn.Close() + return + } + if err := s.conn.SetReadDeadline(time.Now().Add(5 * time.Second)); err != nil { + s.conn.Close() + return + } + buf := make([]byte, 1<<16) + for { + _, err := s.conn.Read(buf) + if err != nil { + s.conn.Close() + return + } + } + }() var hdr header for { // fmt.Printf("ReadFull from %#v\n", s.reader) // Read the header if _, err := io.ReadFull(s.reader, hdr[:]); err != nil { + if s.IsClosed() && os.IsTimeout(err) { + return gracefulCloseErr + } if err != io.EOF && !strings.Contains(err.Error(), "closed") && !strings.Contains(err.Error(), "reset by peer") { s.logger.Printf("[ERR] yamux: Failed to read header: %v", err) } diff --git a/session_test.go b/session_test.go index 46d0c7c..e698966 100644 --- a/session_test.go +++ b/session_test.go @@ -40,6 +40,8 @@ type pipeConn struct { writeDeadline pipeDeadline writeBlocker chan struct{} closeCh chan struct{} + closeOnce sync.Once + closeErr error } func (p *pipeConn) SetDeadline(t time.Time) error { @@ -66,10 +68,12 @@ func (p *pipeConn) Write(b []byte) (int, error) { } func (p *pipeConn) Close() error { - p.writeDeadline.set(time.Time{}) - err := p.Conn.Close() - close(p.closeCh) - return err + p.closeOnce.Do(func() { + close(p.closeCh) + p.writeDeadline.set(time.Time{}) + p.closeErr = p.Conn.Close() + }) + return p.closeErr } func (p *pipeConn) BlockWrites() { @@ -686,6 +690,7 @@ func TestGoAway(t *testing.T) { } else { s.Close() } + time.Sleep(20 * time.Millisecond) } t.Fatalf("expected GoAway error") } @@ -1821,28 +1826,3 @@ func TestMaxIncomingStreams(t *testing.T) { _, err = str.Read([]byte{0}) require.NoError(t, err) } - -func TestRSTBehavior(t *testing.T) { - client, server := testTCPConns(t) - var wg sync.WaitGroup - wg.Add(1) - go func() { - defer wg.Done() - defer server.Close() - server.Write([]byte("hello")) - time.Sleep(20 * time.Second) - buf := make([]byte, 10) - n, err := server.Read(buf) - if err != nil { - t.Error(err) - } else { - t.Log(string(buf[:n])) - } - - }() - client.Write([]byte("world")) - time.Sleep(10 * time.Second) - // close client without reading server msg. This ensures that the TCP stack sends an RST - client.Close() - wg.Wait() -}