Skip to content

Commit

Permalink
fix: fix err, update test
Browse files Browse the repository at this point in the history
  • Loading branch information
nick-bisonai committed Jul 29, 2024
1 parent 3739c9c commit 67c6911
Show file tree
Hide file tree
Showing 3 changed files with 92 additions and 53 deletions.
20 changes: 12 additions & 8 deletions node/pkg/dal/api/hub.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,24 +57,27 @@ func (h *Hub) handleClientRegistration() {
func (h *Hub) addClient(conn *websocket.Conn) {
h.mu.Lock()
defer h.mu.Unlock()
ip := conn.IP()
if _, ok := h.clients[conn]; ok {
return
}
h.clients[conn] = make(map[string]bool)
if _, ok := h.connPerIP[conn.IP()]; !ok {
h.connPerIP[conn.IP()] = make([]*websocket.Conn, 0)
if _, ok := h.connPerIP[ip]; !ok {
h.connPerIP[ip] = make([]*websocket.Conn, 0)
}

h.connPerIP[conn.IP()] = append(h.connPerIP[conn.IP()], conn)
h.connPerIP[ip] = append(h.connPerIP[ip], conn)

if len(h.connPerIP) > MAX_CONNECTIONS {
oldConn := h.connPerIP[conn.IP()][0]
if len(h.connPerIP[ip]) > MAX_CONNECTIONS {
log.Info().Msg("removing old connection")
oldConn := h.connPerIP[ip][0]
if subs, ok := h.clients[oldConn]; ok {
for k := range subs {
delete(h.clients[oldConn], k)
}
}
delete(h.clients, oldConn)
h.connPerIP[ip] = h.connPerIP[ip][1:]
oldConn.WriteControl(
websocket.CloseMessage,
websocket.FormatCloseMessage(websocket.ClosePolicyViolation, "too many connections"),
Expand All @@ -87,17 +90,18 @@ func (h *Hub) addClient(conn *websocket.Conn) {
func (h *Hub) removeClient(conn *websocket.Conn) {
h.mu.Lock()
defer h.mu.Unlock()
ip := conn.IP()
if _, ok := h.clients[conn]; ok {
for symbol := range h.clients[conn] {
delete(h.clients[conn], symbol)
}
delete(h.clients, conn)
}
for i, c := range h.connPerIP[conn.IP()] {
for i, c := range h.connPerIP[ip] {
if c == conn {
h.connPerIP[conn.IP()] = append(h.connPerIP[conn.IP()][:i], h.connPerIP[conn.IP()][i+1:]...)
h.connPerIP[ip] = append(h.connPerIP[ip][:i], h.connPerIP[ip][i+1:]...)
if len(h.connPerIP) == 0 {
delete(h.connPerIP, conn.IP())
delete(h.connPerIP, ip)
}
}
}
Expand Down
111 changes: 66 additions & 45 deletions node/pkg/dal/tests/api_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -163,58 +163,79 @@ func TestApiWebsocket(t *testing.T) {

go testItems.App.Listen(":8090")

conn, err := wss.NewWebsocketHelper(ctx, wss.WithEndpoint("ws://localhost:8090/ws"), wss.WithRequestHeaders(headers))
if err != nil {
t.Fatalf("error creating websocket helper: %v", err)
}
t.Run("test subscription", func(t *testing.T) {

err = conn.Dial(ctx)
if err != nil {
t.Fatalf("error dialing websocket: %v", err)
}
conn, err := wss.NewWebsocketHelper(ctx, wss.WithEndpoint("ws://localhost:8090/ws"), wss.WithRequestHeaders(headers))
if err != nil {
t.Fatalf("error creating websocket helper: %v", err)
}

err = conn.Write(ctx, api.Subscription{
Method: "SUBSCRIBE",
Params: []string{"submission@test-aggregate"},
})
if err != nil {
t.Fatalf("error subscribing to websocket: %v", err)
}
err = conn.Dial(ctx)
if err != nil {
t.Fatalf("error dialing websocket: %v", err)
}

sampleSubmissionData, err := generateSampleSubmissionData(
testItems.TmpConfig.ID,
int64(15),
time.Now(),
1,
"test-aggregate",
)
if err != nil {
t.Fatalf("error generating sample submission data: %v", err)
}
err = conn.Write(ctx, api.Subscription{
Method: "SUBSCRIBE",
Params: []string{"submission@test-aggregate"},
})
if err != nil {
t.Fatalf("error subscribing to websocket: %v", err)
}

err = testPublishData(ctx, *sampleSubmissionData)
if err != nil {
t.Fatalf("error publishing sample submission data: %v", err)
}
sampleSubmissionData, err := generateSampleSubmissionData(
testItems.TmpConfig.ID,
int64(15),
time.Now(),
1,
"test-aggregate",
)
if err != nil {
t.Fatalf("error generating sample submission data: %v", err)
}

ch := make(chan any)
go conn.Read(ctx, ch)
err = testPublishData(ctx, *sampleSubmissionData)
if err != nil {
t.Fatalf("error publishing sample submission data: %v", err)
}

expected, err := testItems.Collector.IncomingDataToOutgoingData(ctx, *sampleSubmissionData)
if err != nil {
t.Fatalf("error converting sample submission data to outgoing data: %v", err)
}
ch := make(chan any)
go conn.Read(ctx, ch)

sample := <-ch
expected, err := testItems.Collector.IncomingDataToOutgoingData(ctx, *sampleSubmissionData)
if err != nil {
t.Fatalf("error converting sample submission data to outgoing data: %v", err)
}

result, err := wsfcommon.MessageToStruct[common.OutgoingSubmissionData](sample.(map[string]any))
if err != nil {
t.Fatalf("error converting sample to struct: %v", err)
}
assert.Equal(t, *expected, result)
sample := <-ch

err = conn.Close()
if err != nil {
t.Fatalf("error closing websocket: %v", err)
}
result, err := wsfcommon.MessageToStruct[common.OutgoingSubmissionData](sample.(map[string]any))
if err != nil {
t.Fatalf("error converting sample to struct: %v", err)
}
assert.Equal(t, *expected, result)
err = conn.Close()
if err != nil {
t.Fatalf("error closing websocket: %v", err)
}
})

t.Run("test fail for 10+ dial", func(t *testing.T) {
conns := []*wss.WebsocketHelper{}
for i := 0; i < 11; i++ {
conn, err := wss.NewWebsocketHelper(ctx, wss.WithEndpoint("ws://localhost:8090/ws"), wss.WithRequestHeaders(headers))
if err != nil {
t.Fatalf("error creating websocket helper: %v", err)
}

err = conn.Dial(ctx)
if err != nil {
t.Fatalf("error dialing websocket: %v", err)
}

conns = append(conns, conn)
}

assert.Error(t, conns[0].IsAlive(ctx), "expected to fail due to too many connections")
})
}
14 changes: 14 additions & 0 deletions node/pkg/wss/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -281,6 +281,20 @@ func (ws *WebsocketHelper) Close() error {
return nil
}

func (ws *WebsocketHelper) IsAlive(ctx context.Context) error {
if ws.Conn == nil {
return fmt.Errorf("websocket is not running")
}
ctx = ws.Conn.CloseRead(ctx)

err := ws.Conn.Ping(ctx)
if err != nil {
log.Error().Err(err).Str("endpoint", ws.Endpoint).Msg("error pinging websocket")
return err
}
return nil
}

func defaultReader(ctx context.Context, conn *websocket.Conn) (map[string]interface{}, error) {
var data map[string]interface{}
err := wsjson.Read(ctx, conn, &data)
Expand Down

0 comments on commit 67c6911

Please sign in to comment.