Skip to content

Commit

Permalink
fix: fix merge conflict
Browse files Browse the repository at this point in the history
  • Loading branch information
nick-bisonai committed Aug 2, 2024
1 parent 25d8884 commit cd2002b
Show file tree
Hide file tree
Showing 3 changed files with 147 additions and 26 deletions.
3 changes: 3 additions & 0 deletions node/pkg/dal/api/controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,9 @@ func handleSubscribe(h *Hub, client *ThreadSafeClient, msg Subscription, ctx con
h.clients[client] = subscriptions

defer func(subscribed []string) {
if len(valid) == 0 {
return
}
if err := stats.InsertWebsocketSubscriptions(ctx, id, valid); err != nil {
log.Error().Err(err).Msg("failed to insert websocket subscription log")
}
Expand Down
61 changes: 35 additions & 26 deletions node/pkg/dal/api/hub.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ func NewHub(configs map[string]types.Config) *Hub {
register: make(chan *ThreadSafeClient),
unregister: make(chan *ThreadSafeClient),
broadcast: make(map[string]chan dalcommon.OutgoingSubmissionData),
connPerIP: make(map[string][]*websocket.Conn),
connPerIP: make(map[string][]*ThreadSafeClient),
}
}

Expand All @@ -46,36 +46,43 @@ func (h *Hub) Start(ctx context.Context, collector *collector.Collector) {
func (h *Hub) handleClientRegistration() {
for {
select {
case client := <-c.register:
c.addClient(client)
case client := <-c.unregister:
c.removeClient(client)
case client := <-h.register:
h.addClient(client)
case client := <-h.unregister:
h.removeClient(client)
}
}
}

func (c *Hub) addClient(client *ThreadSafeClient) {
c.mu.Lock() // Use write lock for both checking and insertion
defer c.mu.Unlock()
if _, ok := c.clients[client]; ok {
func (h *Hub) addClient(client *ThreadSafeClient) {
h.mu.Lock() // Use write lock for both checking and insertion
defer h.mu.Unlock()
if _, ok := h.clients[client]; ok {
return
}
c.clients[client] = make(map[string]bool)

if _, ok := h.connPerIP[conn.IP()]; !ok {
h.connPerIP[conn.IP()] = make([]*websocket.Conn, 0)
h.clients[client] = make(map[string]bool)

ip := client.Conn.IP()
if _, ok := h.connPerIP[ip]; !ok {
h.connPerIP[ip] = make([]*ThreadSafeClient, 0)
}

h.connPerIP[conn.IP()] = append(h.connPerIP[conn.IP()], conn)

if len(h.connPerIP) > MAX_CONNECTIONS {
oldConn := h.connPerIP[conn.IP()][0]
h.connPerIP[ip] = append(h.connPerIP[ip], client)
if len(h.connPerIP[ip]) > MAX_CONNECTIONS {
oldConn := h.connPerIP[ip][0]
if subs, ok := h.clients[oldConn]; ok {
for k := range subs {
delete(h.clients[oldConn], k)
}
}
subscriptions, ok := h.clients[oldConn]
if !ok {
return
}
delete(h.clients, oldConn)
for symbol := range subscriptions {
delete(subscriptions, symbol)
}
h.connPerIP[ip] = h.connPerIP[ip][1:]
oldConn.WriteControl(
websocket.CloseMessage,
Expand All @@ -86,20 +93,22 @@ func (c *Hub) addClient(client *ThreadSafeClient) {
}
}

func (c *Hub) removeClient(client *ThreadSafeClient) {
c.mu.Lock() // Use write lock for both checking and removal
defer c.mu.Unlock()
subscriptions, ok := c.clients[client]
func (h *Hub) removeClient(client *ThreadSafeClient) {
h.mu.Lock() // Use write lock for both checking and removal
defer h.mu.Unlock()
subscriptions, ok := h.clients[client]
if !ok {
return
}
delete(c.clients, client)
delete(h.clients, client)
for symbol := range subscriptions {
delete(subscriptions, symbol)
}

for i, c := range h.connPerIP[conn.IP()] {
if c == conn {
ip := client.Conn.IP()

for i, entry := range h.connPerIP[ip] {
if entry == client {
h.connPerIP[ip] = append(h.connPerIP[ip][:i], h.connPerIP[ip][i+1:]...)
if len(h.connPerIP) == 0 {
delete(h.connPerIP, ip)
Expand All @@ -124,12 +133,12 @@ func (c *Hub) removeClient(client *ThreadSafeClient) {

func (h *Hub) initializeBroadcastChannels(collector *collector.Collector) {
for configId, stream := range collector.OutgoingStream {
symbol := c.configIdToSymbol(configId)
symbol := h.configIdToSymbol(configId)
if symbol == "" {
continue
}

c.broadcast[symbol] = stream
h.broadcast[symbol] = stream
}
}

Expand Down
109 changes: 109 additions & 0 deletions node/pkg/dal/tests/api_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package test

import (
"context"
"sync"
"testing"
"time"

Expand Down Expand Up @@ -222,6 +223,14 @@ func TestApiWebsocket(t *testing.T) {

t.Run("test fail for 10+ dial", func(t *testing.T) {
conns := []*wss.WebsocketHelper{}
defer func() {
for _, conn := range conns {
err := conn.Close()
if err != nil {
t.Logf("error closing websocket: %v", err)
}
}
}()
for i := 0; i < 11; i++ {
conn, err := wss.NewWebsocketHelper(ctx, wss.WithEndpoint("ws://localhost:8090/ws"), wss.WithRequestHeaders(headers))
if err != nil {
Expand All @@ -238,5 +247,105 @@ func TestApiWebsocket(t *testing.T) {
time.Sleep(1 * time.Second)

assert.Error(t, conns[0].IsAlive(ctx), "expected to fail due to too many connections")

})

t.Run("test multiconnection and multi subscriptions", func(t *testing.T) {
var wg sync.WaitGroup
headers := map[string]string{"X-API-Key": testItems.ApiKey}
connCount := 10
subscriptions := []string{"submission@test-aggregate"}

// Create a channel to collect all results
resultsChan := make(chan common.OutgoingSubmissionData, connCount*len(subscriptions))

for i := 0; i < connCount; i++ {
wg.Add(1)
go func(clientID int) {
defer wg.Done()
conn, err := wss.NewWebsocketHelper(ctx, wss.WithEndpoint("ws://localhost:8090/ws"), wss.WithRequestHeaders(headers))
if err != nil {
t.Errorf("error creating websocket helper for client %d: %v", clientID, err)
return
}

err = conn.Dial(ctx)
if err != nil {
t.Errorf("error dialing websocket for client %d: %v", clientID, err)
return
}

defer func() {
err = conn.Close()
if err != nil {
t.Errorf("error closing websocket for client %d: %v", clientID, err)
}
}()

err = conn.Write(ctx, api.Subscription{
Method: "SUBSCRIBE",
Params: subscriptions,
})
if err != nil {
t.Errorf("error subscribing for client %d: %v", clientID, err)
return
}

// Receive messages
ch := make(chan any)
go conn.Read(ctx, ch)

// Read messages from the channel and store the results
for j := 0; j < len(subscriptions); j++ {
select {
case sample := <-ch:
result, err := wsfcommon.MessageToStruct[common.OutgoingSubmissionData](sample.(map[string]any))
if err != nil {
t.Errorf("error converting sample to struct for client %d: %v", clientID, err)
return
}
resultsChan <- result
case <-time.After(10 * time.Second): // Timeout if no message is received
t.Errorf("timeout waiting for message for client %d", clientID)
return
}
}
}(i)
}

// Simulate data publication
expectedData, err := generateSampleSubmissionData(
testItems.TmpConfig.ID,
int64(15),
time.Now(),
1,
"test-aggregate",
)
if err != nil {
t.Fatalf("error generating expected data: %v", err)
}
expected, err := testItems.Collector.IncomingDataToOutgoingData(ctx, *expectedData)
if err != nil {
t.Fatalf("error converting sample submission data to outgoing data: %v", err)
}

// Publish data
err = testPublishData(ctx, *expectedData)
if err != nil {
t.Fatalf("error publishing sample submission data: %v", err)
}

// Wait for all goroutines to finish
wg.Wait()
close(resultsChan)

// Verify results
for result := range resultsChan {
if result.Symbol == expected.Symbol {
assert.Equal(t, *expected, result)
} else {
t.Errorf("unexpected data received: %v", result)
}
}
})
}

0 comments on commit cd2002b

Please sign in to comment.