Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

allow exchanging error codes on session termination #119

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 3 additions & 65 deletions const.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,71 +3,7 @@ package yamux
import (
"encoding/binary"
"fmt"
)

type Error struct {
msg string
timeout, temporary bool
}

func (ye *Error) Error() string {
return ye.msg
}

func (ye *Error) Timeout() bool {
return ye.timeout
}

func (ye *Error) Temporary() bool {
return ye.temporary
}

var (
// ErrInvalidVersion means we received a frame with an
// invalid version
ErrInvalidVersion = &Error{msg: "invalid protocol version"}

// ErrInvalidMsgType means we received a frame with an
// invalid message type
ErrInvalidMsgType = &Error{msg: "invalid msg type"}

// ErrSessionShutdown is used if there is a shutdown during
// an operation
ErrSessionShutdown = &Error{msg: "session shutdown"}

// ErrStreamsExhausted is returned if we have no more
// stream ids to issue
ErrStreamsExhausted = &Error{msg: "streams exhausted"}

// ErrDuplicateStream is used if a duplicate stream is
// opened inbound
ErrDuplicateStream = &Error{msg: "duplicate stream initiated"}

// ErrReceiveWindowExceeded indicates the window was exceeded
ErrRecvWindowExceeded = &Error{msg: "recv window exceeded"}

// ErrTimeout is used when we reach an IO deadline
ErrTimeout = &Error{msg: "i/o deadline reached", timeout: true, temporary: true}

// ErrStreamClosed is returned when using a closed stream
ErrStreamClosed = &Error{msg: "stream closed"}

// ErrUnexpectedFlag is set when we get an unexpected flag
ErrUnexpectedFlag = &Error{msg: "unexpected flag"}

// ErrRemoteGoAway is used when we get a go away from the other side
ErrRemoteGoAway = &Error{msg: "remote end is not accepting connections"}

// ErrStreamReset is sent if a stream is reset. This can happen
// if the backlog is exceeded, or if there was a remote GoAway.
ErrStreamReset = &Error{msg: "stream reset"}

// ErrConnectionWriteTimeout indicates that we hit the "safety valve"
// timeout writing to the underlying stream connection.
ErrConnectionWriteTimeout = &Error{msg: "connection write timeout", timeout: true}

// ErrKeepAliveTimeout is sent if a missed keepalive caused the stream close
ErrKeepAliveTimeout = &Error{msg: "keepalive timeout", timeout: true}
"time"
)

const (
Expand Down Expand Up @@ -117,6 +53,8 @@ const (
// It's not an implementation choice, the value defined in the specification.
initialStreamWindow = 256 * 1024
maxStreamWindow = 16 * 1024 * 1024
// goAwayWaitTime is the time we'll wait to send a goaway frame on close
goAwayWaitTime = 5 * time.Second
)

const (
Expand Down
97 changes: 97 additions & 0 deletions errors.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
package yamux

import "fmt"

type Error struct {
msg string
timeout, temporary bool
}

func (ye *Error) Error() string {
return ye.msg
}

func (ye *Error) Timeout() bool {
return ye.timeout
}

func (ye *Error) Temporary() bool {
return ye.temporary

Check warning on line 19 in errors.go

View check run for this annotation

Codecov / codecov/patch

errors.go#L18-L19

Added lines #L18 - L19 were not covered by tests
}

type ErrorGoAway struct {
Remote bool
ErrorCode uint32
}

func (e *ErrorGoAway) Error() string {
if e.Remote {
return fmt.Sprintf("remote sent go away, code: %d", e.ErrorCode)
}
return fmt.Sprintf("sent go away, code: %d", e.ErrorCode)

Check warning on line 31 in errors.go

View check run for this annotation

Codecov / codecov/patch

errors.go#L31

Added line #L31 was not covered by tests
}

func (e *ErrorGoAway) Timeout() bool {
return false

Check warning on line 35 in errors.go

View check run for this annotation

Codecov / codecov/patch

errors.go#L34-L35

Added lines #L34 - L35 were not covered by tests
}

func (e *ErrorGoAway) Temporary() bool {
return false

Check warning on line 39 in errors.go

View check run for this annotation

Codecov / codecov/patch

errors.go#L38-L39

Added lines #L38 - L39 were not covered by tests
}

func (e *ErrorGoAway) Is(target error) bool {
// to maintain compatibility with errors returned by previous versions
if e.Remote {
return target == ErrRemoteGoAway
} else {
return target == ErrSessionShutdown

Check warning on line 47 in errors.go

View check run for this annotation

Codecov / codecov/patch

errors.go#L47

Added line #L47 was not covered by tests
}
}

var (
// ErrInvalidVersion means we received a frame with an
// invalid version
ErrInvalidVersion = &Error{msg: "invalid protocol version"}

// ErrInvalidMsgType means we received a frame with an
// invalid message type
ErrInvalidMsgType = &Error{msg: "invalid msg type"}

// ErrSessionShutdown is used if there is a shutdown during
// an operation
ErrSessionShutdown = &Error{msg: "session shutdown"}

// ErrStreamsExhausted is returned if we have no more
// stream ids to issue
ErrStreamsExhausted = &Error{msg: "streams exhausted"}

// ErrDuplicateStream is used if a duplicate stream is
// opened inbound
ErrDuplicateStream = &Error{msg: "duplicate stream initiated"}

// ErrReceiveWindowExceeded indicates the window was exceeded
ErrRecvWindowExceeded = &Error{msg: "recv window exceeded"}

// ErrTimeout is used when we reach an IO deadline
ErrTimeout = &Error{msg: "i/o deadline reached", timeout: true, temporary: true}

// ErrStreamClosed is returned when using a closed stream
ErrStreamClosed = &Error{msg: "stream closed"}

// ErrUnexpectedFlag is set when we get an unexpected flag
ErrUnexpectedFlag = &Error{msg: "unexpected flag"}

// ErrRemoteGoAway is used when we get a go away from the other side
ErrRemoteGoAway = &Error{msg: "remote end is not accepting connections"}

// ErrStreamReset is sent if a stream is reset. This can happen
// if the backlog is exceeded, or if there was a remote GoAway.
ErrStreamReset = &Error{msg: "stream reset"}

// ErrConnectionWriteTimeout indicates that we hit the "safety valve"
// timeout writing to the underlying stream connection.
ErrConnectionWriteTimeout = &Error{msg: "connection write timeout", timeout: true}

// ErrKeepAliveTimeout is sent if a missed keepalive caused the stream close
ErrKeepAliveTimeout = &Error{msg: "keepalive timeout", timeout: true}
)
99 changes: 76 additions & 23 deletions session.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import (
"bufio"
"context"
"errors"
"fmt"
"io"
"log"
Expand Down Expand Up @@ -41,17 +42,17 @@

var nullMemoryManager = &nullMemoryManagerImpl{}

type CloseWriter interface {
CloseWrite() error
}

// Session is used to wrap a reliable ordered connection and to
// multiplex it into multiple streams.
type Session struct {
rtt int64 // to be accessed atomically, in nanoseconds

// remoteGoAway indicates the remote side does
// not want futher connections. Must be first for alignment.
remoteGoAway int32

// localGoAway indicates that we should stop
// accepting futher connections. Must be first for alignment.
// accepting futher streams. Must be first for alignment.
localGoAway int32

// nextStreamID is the next stream we should
Expand Down Expand Up @@ -203,9 +204,6 @@
if s.IsClosed() {
return nil, s.shutdownErr
}
if atomic.LoadInt32(&s.remoteGoAway) == 1 {
return nil, ErrRemoteGoAway
}

// Block if we have too many inflight SYNs
select {
Expand Down Expand Up @@ -284,23 +282,55 @@
}

// Close is used to close the session and all streams.
// Attempts to send a GoAway before closing the connection.
// Sends a GoAway before closing the connection.
func (s *Session) Close() error {
return s.CloseWithError(goAwayNormal)
}

// CloseWithError closes the session sending errCode in a goaway frame
func (s *Session) CloseWithError(errCode uint32) error {
return s.closeWithError(errCode, true)
}

func (s *Session) closeWithError(errCode uint32, sendGoAway bool) error {
s.shutdownLock.Lock()
defer s.shutdownLock.Unlock()

if s.shutdown {
return nil
}
s.shutdown = true
if s.shutdownErr == nil {
s.shutdownErr = ErrSessionShutdown
s.shutdownErr = &ErrorGoAway{Remote: !sendGoAway, ErrorCode: errCode}
}
close(s.shutdownCh)
s.conn.Close()
s.stopKeepalive()
<-s.recvDoneCh

// wait for write loop
_ = s.conn.SetWriteDeadline(time.Now().Add(-1 * time.Hour)) // if SetWriteDeadline errored, any blocked writes will be unblocked
<-s.sendDoneCh
// send the goaway frame
if sendGoAway {
buf := pool.Get(headerSize)
hdr := s.goAway(errCode)
copy(buf, hdr[:])
if err := s.conn.SetWriteDeadline(time.Now().Add(goAwayWaitTime)); err == nil {
if _, err = s.conn.Write(buf); err != nil {
sendGoAway = false
}
} else {
sendGoAway = false
}
}
if w, ok := s.conn.(CloseWriter); ok && sendGoAway {
if err := w.CloseWrite(); err != nil {
s.conn.Close()

Check warning on line 326 in session.go

View check run for this annotation

Codecov / codecov/patch

session.go#L326

Added line #L326 was not covered by tests
}
} else {
s.conn.Close()
}
s.conn.SetReadDeadline(time.Now().Add(-1 * time.Hour))
// wait for read loop
<-s.recvDoneCh

s.streamLock.Lock()
defer s.streamLock.Unlock()
Expand All @@ -320,11 +350,11 @@
s.shutdownErr = err
}
s.shutdownLock.Unlock()
s.Close()
s.closeWithError(0, false)
}

// GoAway can be used to prevent accepting further
// connections. It does not close the underlying conn.
// streams. It does not close the underlying conn.
func (s *Session) GoAway() error {
return s.sendMsg(s.goAway(goAwayNormal), nil, nil)
}
Expand Down Expand Up @@ -631,7 +661,6 @@

_, err := writer.Write(buf)
pool.Put(buf)

if err != nil {
if os.IsTimeout(err) {
err = ErrConnectionWriteTimeout
Expand Down Expand Up @@ -666,12 +695,39 @@
err = fmt.Errorf("panic in yamux receive loop: %s", rerr)
}
}()
defer close(s.recvDoneCh)

gracefulCloseErr := errors.New("close gracefully")
defer func() {
close(s.recvDoneCh)
errGoAway := &ErrorGoAway{}
if errors.As(err, &errGoAway) {
return
}
if err != gracefulCloseErr {
s.conn.Close()
return
}
if err := s.conn.SetReadDeadline(time.Now().Add(5 * time.Second)); err != nil {
s.conn.Close()
return

Check warning on line 712 in session.go

View check run for this annotation

Codecov / codecov/patch

session.go#L711-L712

Added lines #L711 - L712 were not covered by tests
}
buf := make([]byte, 1<<16)
for {
_, err := s.conn.Read(buf)
if err != nil {
s.conn.Close()
return
}
}
}()
var hdr header
for {
// fmt.Printf("ReadFull from %#v\n", s.reader)
// Read the header
if _, err := io.ReadFull(s.reader, hdr[:]); err != nil {
if s.IsClosed() && os.IsTimeout(err) {
return gracefulCloseErr
}
if err != io.EOF && !strings.Contains(err.Error(), "closed") && !strings.Contains(err.Error(), "reset by peer") {
s.logger.Printf("[ERR] yamux: Failed to read header: %v", err)
}
Expand Down Expand Up @@ -781,18 +837,15 @@
code := hdr.Length()
switch code {
case goAwayNormal:
atomic.SwapInt32(&s.remoteGoAway, 1)
// Non error termination. Don't log.
case goAwayProtoErr:
s.logger.Printf("[ERR] yamux: received protocol error go away")
return fmt.Errorf("yamux protocol error")
case goAwayInternalErr:
s.logger.Printf("[ERR] yamux: received internal error go away")
return fmt.Errorf("remote yamux internal error")
default:
s.logger.Printf("[ERR] yamux: received unexpected go away")
return fmt.Errorf("unexpected go away received")
// application error code, let the application log
}
return nil
return &ErrorGoAway{ErrorCode: code, Remote: true}
}

// incomingStream is used to create a new incoming stream
Expand Down
Loading