Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

gracefully close connection fixes: #448 #487

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 51 additions & 6 deletions conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import (
"io/ioutil"
"math/rand"
"net"
"reflect"
"strconv"
"sync"
"time"
Expand Down Expand Up @@ -219,7 +220,7 @@ var validReceivedCloseCodes = map[int]bool{
CloseTLSHandshake: false,
}

func isValidReceivedCloseCode(code int) bool {
func isValidCloseCode(code int) bool {
return validReceivedCloseCodes[code] || (code >= 3000 && code <= 4999)
}

Expand Down Expand Up @@ -325,10 +326,53 @@ func (c *Conn) Subprotocol() string {
return c.subprotocol
}

// Close closes the underlying network connection without sending or waiting
// for a close message.
func (c *Conn) Close() error {
return c.conn.Close()
// Close sends close frame and waits for one in response
// it expects two args. `closeCode int` and `closeMessage string` in order
// it uses variadic args to maintain backwards compatibility
func (c *Conn) Close(args ...interface{}) error {
sanketplus marked this conversation as resolved.
Show resolved Hide resolved
closeCode := CloseNoStatusReceived
message := ""
ok := false
if len(args) == 2 {
closeCode, ok = args[0].(int)
if !ok {
closeCode = CloseNoStatusReceived
}
message, ok = args[1].(string)
if !ok {
message = ""
}
}
err := c.Shutdown(closeCode, message)
if err != nil {
return err
}
c.conn.Close()
return nil
}

// Shutdown sends a close frame and waits for one in response
func (c *Conn) Shutdown(closeCode int, closeMessage string) error {
if !isValidCloseCode(closeCode) {
// we do not shutdown connection
return errors.New("invalid close code received")
}
if !utf8.ValidString(closeMessage) {
return errors.New("invalid utf8 payload for shutdown message")
}
message := FormatCloseMessage(closeCode, closeMessage)
err := c.WriteControl(CloseMessage, message, time.Now().Add(writeWait))
if err != nil {
return err
}
timeStart := time.Now()
c.conn.SetReadDeadline(time.Now().Add(time.Minute))
sanketplus marked this conversation as resolved.
Show resolved Hide resolved
for _, _, err := c.ReadMessage(); reflect.TypeOf(err) != reflect.TypeOf(&CloseError{}) ; {
sanketplus marked this conversation as resolved.
Show resolved Hide resolved
if timeStart.Sub(time.Now()) > time.Minute {
break
}
}
return nil
}

// LocalAddr returns the local network address.
Expand Down Expand Up @@ -496,6 +540,7 @@ func (c *Conn) beginMessage(mw *messageWriter, messageType int) error {
// All message types (TextMessage, BinaryMessage, CloseMessage, PingMessage and
// PongMessage) are supported.
func (c *Conn) NextWriter(messageType int) (io.WriteCloser, error) {

var mw messageWriter
if err := c.beginMessage(&mw, messageType); err != nil {
return nil, err
Expand Down Expand Up @@ -902,7 +947,7 @@ func (c *Conn) advanceFrame() (int, error) {
closeText := ""
if len(payload) >= 2 {
closeCode = int(binary.BigEndian.Uint16(payload))
if !isValidReceivedCloseCode(closeCode) {
if !isValidCloseCode(closeCode) {
return noFrame, c.handleProtocolError("invalid close code")
}
closeText = string(payload[2:])
Expand Down