diff --git a/const.go b/const.go index e4b2bc2..607c25d 100644 --- a/const.go +++ b/const.go @@ -3,71 +3,7 @@ package yamux import ( "encoding/binary" "fmt" -) - -type Error struct { - msg string - timeout, temporary bool -} - -func (ye *Error) Error() string { - return ye.msg -} - -func (ye *Error) Timeout() bool { - return ye.timeout -} - -func (ye *Error) Temporary() bool { - return ye.temporary -} - -var ( - // ErrInvalidVersion means we received a frame with an - // invalid version - ErrInvalidVersion = &Error{msg: "invalid protocol version"} - - // ErrInvalidMsgType means we received a frame with an - // invalid message type - ErrInvalidMsgType = &Error{msg: "invalid msg type"} - - // ErrSessionShutdown is used if there is a shutdown during - // an operation - ErrSessionShutdown = &Error{msg: "session shutdown"} - - // ErrStreamsExhausted is returned if we have no more - // stream ids to issue - ErrStreamsExhausted = &Error{msg: "streams exhausted"} - - // ErrDuplicateStream is used if a duplicate stream is - // opened inbound - ErrDuplicateStream = &Error{msg: "duplicate stream initiated"} - - // ErrReceiveWindowExceeded indicates the window was exceeded - ErrRecvWindowExceeded = &Error{msg: "recv window exceeded"} - - // ErrTimeout is used when we reach an IO deadline - ErrTimeout = &Error{msg: "i/o deadline reached", timeout: true, temporary: true} - - // ErrStreamClosed is returned when using a closed stream - ErrStreamClosed = &Error{msg: "stream closed"} - - // ErrUnexpectedFlag is set when we get an unexpected flag - ErrUnexpectedFlag = &Error{msg: "unexpected flag"} - - // ErrRemoteGoAway is used when we get a go away from the other side - ErrRemoteGoAway = &Error{msg: "remote end is not accepting connections"} - - // ErrStreamReset is sent if a stream is reset. This can happen - // if the backlog is exceeded, or if there was a remote GoAway. - ErrStreamReset = &Error{msg: "stream reset"} - - // ErrConnectionWriteTimeout indicates that we hit the "safety valve" - // timeout writing to the underlying stream connection. - ErrConnectionWriteTimeout = &Error{msg: "connection write timeout", timeout: true} - - // ErrKeepAliveTimeout is sent if a missed keepalive caused the stream close - ErrKeepAliveTimeout = &Error{msg: "keepalive timeout", timeout: true} + "time" ) const ( @@ -117,6 +53,8 @@ const ( // It's not an implementation choice, the value defined in the specification. initialStreamWindow = 256 * 1024 maxStreamWindow = 16 * 1024 * 1024 + // goAwayWaitTime is the time we'll wait to send a goaway frame on close + goAwayWaitTime = 5 * time.Second ) const ( diff --git a/errors.go b/errors.go new file mode 100644 index 0000000..d5f0c59 --- /dev/null +++ b/errors.go @@ -0,0 +1,97 @@ +package yamux + +import "fmt" + +type Error struct { + msg string + timeout, temporary bool +} + +func (ye *Error) Error() string { + return ye.msg +} + +func (ye *Error) Timeout() bool { + return ye.timeout +} + +func (ye *Error) Temporary() bool { + return ye.temporary +} + +type ErrorGoAway struct { + Remote bool + ErrorCode uint32 +} + +func (e *ErrorGoAway) Error() string { + if e.Remote { + return fmt.Sprintf("remote sent go away, code: %d", e.ErrorCode) + } + return fmt.Sprintf("sent go away, code: %d", e.ErrorCode) +} + +func (e *ErrorGoAway) Timeout() bool { + return false +} + +func (e *ErrorGoAway) Temporary() bool { + return false +} + +func (e *ErrorGoAway) Is(target error) bool { + // to maintain compatibility with errors returned by previous versions + if e.Remote { + return target == ErrRemoteGoAway + } else { + return target == ErrSessionShutdown + } +} + +var ( + // ErrInvalidVersion means we received a frame with an + // invalid version + ErrInvalidVersion = &Error{msg: "invalid protocol version"} + + // ErrInvalidMsgType means we received a frame with an + // invalid message type + ErrInvalidMsgType = &Error{msg: "invalid msg type"} + + // ErrSessionShutdown is used if there is a shutdown during + // an operation + ErrSessionShutdown = &Error{msg: "session shutdown"} + + // ErrStreamsExhausted is returned if we have no more + // stream ids to issue + ErrStreamsExhausted = &Error{msg: "streams exhausted"} + + // ErrDuplicateStream is used if a duplicate stream is + // opened inbound + ErrDuplicateStream = &Error{msg: "duplicate stream initiated"} + + // ErrReceiveWindowExceeded indicates the window was exceeded + ErrRecvWindowExceeded = &Error{msg: "recv window exceeded"} + + // ErrTimeout is used when we reach an IO deadline + ErrTimeout = &Error{msg: "i/o deadline reached", timeout: true, temporary: true} + + // ErrStreamClosed is returned when using a closed stream + ErrStreamClosed = &Error{msg: "stream closed"} + + // ErrUnexpectedFlag is set when we get an unexpected flag + ErrUnexpectedFlag = &Error{msg: "unexpected flag"} + + // ErrRemoteGoAway is used when we get a go away from the other side + ErrRemoteGoAway = &Error{msg: "remote end is not accepting connections"} + + // ErrStreamReset is sent if a stream is reset. This can happen + // if the backlog is exceeded, or if there was a remote GoAway. + ErrStreamReset = &Error{msg: "stream reset"} + + // ErrConnectionWriteTimeout indicates that we hit the "safety valve" + // timeout writing to the underlying stream connection. + ErrConnectionWriteTimeout = &Error{msg: "connection write timeout", timeout: true} + + // ErrKeepAliveTimeout is sent if a missed keepalive caused the stream close + ErrKeepAliveTimeout = &Error{msg: "keepalive timeout", timeout: true} +) diff --git a/session.go b/session.go index c4cd1bd..973fa98 100644 --- a/session.go +++ b/session.go @@ -3,6 +3,7 @@ package yamux import ( "bufio" "context" + "errors" "fmt" "io" "log" @@ -41,17 +42,17 @@ func (n nullMemoryManagerImpl) Done() {} var nullMemoryManager = &nullMemoryManagerImpl{} +type CloseWriter interface { + CloseWrite() error +} + // Session is used to wrap a reliable ordered connection and to // multiplex it into multiple streams. type Session struct { rtt int64 // to be accessed atomically, in nanoseconds - // remoteGoAway indicates the remote side does - // not want futher connections. Must be first for alignment. - remoteGoAway int32 - // localGoAway indicates that we should stop - // accepting futher connections. Must be first for alignment. + // accepting futher streams. Must be first for alignment. localGoAway int32 // nextStreamID is the next stream we should @@ -203,9 +204,6 @@ func (s *Session) OpenStream(ctx context.Context) (*Stream, error) { if s.IsClosed() { return nil, s.shutdownErr } - if atomic.LoadInt32(&s.remoteGoAway) == 1 { - return nil, ErrRemoteGoAway - } // Block if we have too many inflight SYNs select { @@ -284,23 +282,55 @@ func (s *Session) AcceptStream() (*Stream, error) { } // Close is used to close the session and all streams. -// Attempts to send a GoAway before closing the connection. +// Sends a GoAway before closing the connection. func (s *Session) Close() error { + return s.CloseWithError(goAwayNormal) +} + +// CloseWithError closes the session sending errCode in a goaway frame +func (s *Session) CloseWithError(errCode uint32) error { + return s.closeWithError(errCode, true) +} + +func (s *Session) closeWithError(errCode uint32, sendGoAway bool) error { s.shutdownLock.Lock() defer s.shutdownLock.Unlock() - if s.shutdown { return nil } s.shutdown = true if s.shutdownErr == nil { - s.shutdownErr = ErrSessionShutdown + s.shutdownErr = &ErrorGoAway{Remote: !sendGoAway, ErrorCode: errCode} } close(s.shutdownCh) - s.conn.Close() s.stopKeepalive() - <-s.recvDoneCh + + // wait for write loop + _ = s.conn.SetWriteDeadline(time.Now().Add(-1 * time.Hour)) // if SetWriteDeadline errored, any blocked writes will be unblocked <-s.sendDoneCh + // send the goaway frame + if sendGoAway { + buf := pool.Get(headerSize) + hdr := s.goAway(errCode) + copy(buf, hdr[:]) + if err := s.conn.SetWriteDeadline(time.Now().Add(goAwayWaitTime)); err == nil { + if _, err = s.conn.Write(buf); err != nil { + sendGoAway = false + } + } else { + sendGoAway = false + } + } + if w, ok := s.conn.(CloseWriter); ok && sendGoAway { + if err := w.CloseWrite(); err != nil { + s.conn.Close() + } + } else { + s.conn.Close() + } + s.conn.SetReadDeadline(time.Now().Add(-1 * time.Hour)) + // wait for read loop + <-s.recvDoneCh s.streamLock.Lock() defer s.streamLock.Unlock() @@ -320,11 +350,11 @@ func (s *Session) exitErr(err error) { s.shutdownErr = err } s.shutdownLock.Unlock() - s.Close() + s.closeWithError(0, false) } // GoAway can be used to prevent accepting further -// connections. It does not close the underlying conn. +// streams. It does not close the underlying conn. func (s *Session) GoAway() error { return s.sendMsg(s.goAway(goAwayNormal), nil, nil) } @@ -631,7 +661,6 @@ func (s *Session) sendLoop() (err error) { _, err := writer.Write(buf) pool.Put(buf) - if err != nil { if os.IsTimeout(err) { err = ErrConnectionWriteTimeout @@ -666,12 +695,39 @@ func (s *Session) recvLoop() (err error) { err = fmt.Errorf("panic in yamux receive loop: %s", rerr) } }() - defer close(s.recvDoneCh) + + gracefulCloseErr := errors.New("close gracefully") + defer func() { + close(s.recvDoneCh) + errGoAway := &ErrorGoAway{} + if errors.As(err, &errGoAway) { + return + } + if err != gracefulCloseErr { + s.conn.Close() + return + } + if err := s.conn.SetReadDeadline(time.Now().Add(5 * time.Second)); err != nil { + s.conn.Close() + return + } + buf := make([]byte, 1<<16) + for { + _, err := s.conn.Read(buf) + if err != nil { + s.conn.Close() + return + } + } + }() var hdr header for { // fmt.Printf("ReadFull from %#v\n", s.reader) // Read the header if _, err := io.ReadFull(s.reader, hdr[:]); err != nil { + if s.IsClosed() && os.IsTimeout(err) { + return gracefulCloseErr + } if err != io.EOF && !strings.Contains(err.Error(), "closed") && !strings.Contains(err.Error(), "reset by peer") { s.logger.Printf("[ERR] yamux: Failed to read header: %v", err) } @@ -781,18 +837,15 @@ func (s *Session) handleGoAway(hdr header) error { code := hdr.Length() switch code { case goAwayNormal: - atomic.SwapInt32(&s.remoteGoAway, 1) + // Non error termination. Don't log. case goAwayProtoErr: s.logger.Printf("[ERR] yamux: received protocol error go away") - return fmt.Errorf("yamux protocol error") case goAwayInternalErr: s.logger.Printf("[ERR] yamux: received internal error go away") - return fmt.Errorf("remote yamux internal error") default: - s.logger.Printf("[ERR] yamux: received unexpected go away") - return fmt.Errorf("unexpected go away received") + // application error code, let the application log } - return nil + return &ErrorGoAway{ErrorCode: code, Remote: true} } // incomingStream is used to create a new incoming stream diff --git a/session_test.go b/session_test.go index 974b6d5..e698966 100644 --- a/session_test.go +++ b/session_test.go @@ -3,6 +3,7 @@ package yamux import ( "bytes" "context" + "errors" "fmt" "io" "math/rand" @@ -39,6 +40,8 @@ type pipeConn struct { writeDeadline pipeDeadline writeBlocker chan struct{} closeCh chan struct{} + closeOnce sync.Once + closeErr error } func (p *pipeConn) SetDeadline(t time.Time) error { @@ -65,10 +68,12 @@ func (p *pipeConn) Write(b []byte) (int, error) { } func (p *pipeConn) Close() error { - p.writeDeadline.set(time.Time{}) - err := p.Conn.Close() - close(p.closeCh) - return err + p.closeOnce.Do(func() { + close(p.closeCh) + p.writeDeadline.set(time.Time{}) + p.closeErr = p.Conn.Close() + }) + return p.closeErr } func (p *pipeConn) BlockWrites() { @@ -627,29 +632,65 @@ func TestSendData_Large(t *testing.T) { } } +func testTCPConns(t *testing.T) (net.Conn, net.Conn) { + ln, err := net.ListenTCP("tcp", &net.TCPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 0}) + if err != nil { + t.Fatal(err) + } + serverConnCh := make(chan net.Conn, 1) + go func() { + conn, err := ln.Accept() + if err != nil { + return + } + serverConnCh <- conn + }() + + clientConn, err := net.DialTCP("tcp", nil, ln.Addr().(*net.TCPAddr)) + if err != nil { + ln.Close() + t.Fatal(err) + return nil, nil + } + + return clientConn, <-serverConnCh + +} + func TestGoAway(t *testing.T) { // This test is noisy. conf := testConf() conf.LogOutput = io.Discard - client, server := testClientServerConfig(conf) + clientConn, serverConn := testTCPConns(t) + client, _ := Client(clientConn, testConf(), nil) + server, _ := Server(serverConn, testConf(), nil) defer client.Close() defer server.Close() - if err := server.GoAway(); err != nil { + if err := server.CloseWithError(42); err != nil { t.Fatalf("err: %v", err) } for i := 0; i < 100; i++ { s, err := client.Open(context.Background()) - switch err { - case nil: - s.Close() - case ErrRemoteGoAway: + if err != nil { + if !errors.Is(err, ErrRemoteGoAway) { + t.Fatal("expected error to be ErrRemoteGoAway, got", err) + } + errExpected := &ErrorGoAway{Remote: true, ErrorCode: 42} + errGot, ok := err.(*ErrorGoAway) + if !ok { + t.Fatalf("expected type *ErrorGoAway, got %T", err) + } + if *errGot != *errExpected { + t.Fatalf("invalid error, expected %v, got %v", errExpected, errGot) + } return - default: - t.Fatalf("err: %v", err) + } else { + s.Close() } + time.Sleep(20 * time.Millisecond) } t.Fatalf("expected GoAway error") }