Skip to content

Commit

Permalink
fix ping/ping on gorilla/websocket
Browse files Browse the repository at this point in the history
  • Loading branch information
scottfeldman committed Dec 12, 2024
1 parent 46675ed commit 0354559
Show file tree
Hide file tree
Showing 3 changed files with 61 additions and 60 deletions.
28 changes: 23 additions & 5 deletions pkg/device/ws-client.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,16 +6,31 @@
package device

import (
"net/http"
"net/url"
"time"

"github.com/gorilla/websocket"
)

func wsDial(wsURL *url.URL, user, passwd string) {

var hdr = http.Header{}

// If valid user, set the basic auth header for the request
if user != "" {
req, err := http.NewRequest("GET", wsURL.String(), nil)
if err != nil {
LogError("Dialing", "url", wsURL, "err", err)
return
}
req.SetBasicAuth(user, passwd)
hdr = req.Header
}

for {
// Connect to the server
conn, _, err := websocket.DefaultDialer.Dial(wsURL.String(), nil)
// Connect to the server with custom headers
conn, _, err := websocket.DefaultDialer.Dial(wsURL.String(), hdr)
if err == nil {
// Service the client websocket
wsClient(conn)
Expand Down Expand Up @@ -43,6 +58,9 @@ func wsClient(conn *websocket.Conn) {
Path: "/announce",
}

link.setPongHandler()
link.startPing()

pkt.Marshal(&ann)

// Send announcement
Expand All @@ -53,8 +71,8 @@ func wsClient(conn *websocket.Conn) {
return
}

// Receive welcome within 1 sec
pkt, err = link.receiveTimeout(time.Second)
// Receive welcome
pkt, err = link.receive()
if err != nil {
LogError("Receiving", "err", err)
return
Expand All @@ -75,7 +93,7 @@ func wsClient(conn *websocket.Conn) {
// Route incoming packets down to the destination device
LogInfo("Receiving packets")
for {
pkt, err := link.receivePoll()
pkt, err := link.receive()
if err != nil {
LogError("Receiving packet", "err", err)
break
Expand Down
5 changes: 4 additions & 1 deletion pkg/device/ws-server.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,9 @@ func wsServer(conn *websocket.Conn) {

var link = &wsLink{conn: conn}

link.setPongHandler()
link.startPing()

// First receive should be an /announce packet
pkt, err := link.receive()
if err != nil {
Expand Down Expand Up @@ -72,7 +75,7 @@ func wsServer(conn *websocket.Conn) {

// Route incoming packets up to the destination device
for {
pkt, err := link.receivePoll()
pkt, err := link.receive()
if err != nil {
LogError("Receiving packet", "err", err)
break
Expand Down
88 changes: 34 additions & 54 deletions pkg/device/ws.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,18 +6,17 @@
package device

import (
"encoding/json"
"fmt"
"html/template"
"sync"
"time"

"github.com/gorilla/websocket"
)

type wsLink struct {
conn *websocket.Conn
lastRecv time.Time
lastSend time.Time
conn *websocket.Conn
sync.Mutex
}

type announcement struct {
Expand All @@ -28,66 +27,47 @@ type announcement struct {
}

func (l *wsLink) Send(pkt *Packet) error {
data, err := json.Marshal(pkt)
if err != nil {
return fmt.Errorf("Marshal error: %w", err)
}
if err := l.conn.WriteMessage(websocket.TextMessage, data); err != nil {
return fmt.Errorf("Send error: %w", err)
}
l.lastSend = time.Now()
return nil
l.Lock()
defer l.Unlock()
return l.conn.WriteJSON(pkt)
}

func (l *wsLink) Close() {
l.conn.Close()
}

func (l *wsLink) receive() (*Packet, error) {
_, data, err := l.conn.ReadMessage()
if err != nil {
return nil, err
}

l.lastRecv = time.Now()

var pkt Packet
if err := json.Unmarshal(data, &pkt); err != nil {
LogError("Unmarshal Error", "data", string(data))
return nil, fmt.Errorf("Unmarshalling error: %w", err)
}
return &pkt, nil
}
var wsPingPeriod = 5 * time.Second

func (l *wsLink) receiveTimeout(timeout time.Duration) (*Packet, error) {
l.conn.SetReadDeadline(time.Now().Add(timeout))
pkt, err := l.receive()
l.conn.SetReadDeadline(time.Time{})
return pkt, err
func (l *wsLink) setPongHandler() {
l.conn.SetReadDeadline(time.Now().Add(wsPingPeriod + time.Second))
l.conn.SetPongHandler(func(appData string) error {
l.conn.SetReadDeadline(time.Now().Add(wsPingPeriod + time.Second))
LogInfo("Pong received, read deadline extended")
return nil
})
}

var pingDuration = 4 * time.Second
var pingTimeout = 2*pingDuration + time.Second

func (l *wsLink) receivePoll() (*Packet, error) {
for {
if time.Since(l.lastSend) >= pingDuration {
if err := l.Send(&Packet{Path: "/ping"}); err != nil {
return nil, err
}
}
pkt, err := l.receiveTimeout(time.Second)
if err == nil {
if pkt.Path == "/ping" {
continue
func (l *wsLink) startPing() {
go func() {
for {
time.Sleep(wsPingPeriod)
l.Lock()
if err := l.conn.WriteMessage(websocket.PingMessage, nil); err != nil {
LogInfo("Ping error:", "err", err)
l.Unlock()
return
}
return pkt, nil
}
if websocket.IsCloseError(err, websocket.CloseNormalClosure, websocket.CloseGoingAway) {
return nil, err
}
if time.Since(l.lastRecv) > pingTimeout {
return nil, err
l.Unlock()
LogInfo("Ping sent")
}
}()
}

func (l *wsLink) receive() (*Packet, error) {
var pkt Packet
if err := l.conn.ReadJSON(&pkt); err != nil {
LogError("ReadJSON Error", "err", err)
return nil, fmt.Errorf("ReadJSON error: %w", err)
}
return &pkt, nil
}

0 comments on commit 0354559

Please sign in to comment.