From c0552def9a7f04af33b3705cd4efc5bcd454dff6 Mon Sep 17 00:00:00 2001 From: nick Date: Thu, 1 Aug 2024 00:27:26 +0900 Subject: [PATCH] wip --- node/pkg/dal/api/client.go | 48 ++++++++++++++++ node/pkg/dal/api/controller.go | 34 +++++++----- node/pkg/dal/api/hub.go | 92 +++++++++++++++---------------- node/pkg/dal/api/types.go | 9 +-- node/pkg/dal/utils/stats/stats.go | 9 +++ 5 files changed, 126 insertions(+), 66 deletions(-) create mode 100644 node/pkg/dal/api/client.go diff --git a/node/pkg/dal/api/client.go b/node/pkg/dal/api/client.go new file mode 100644 index 000000000..1b0a5814e --- /dev/null +++ b/node/pkg/dal/api/client.go @@ -0,0 +1,48 @@ +package api + +import ( + "sync" + "time" + + "github.com/gofiber/contrib/websocket" + "github.com/rs/zerolog/log" +) + +type ThreadSafeClient struct { + Conn *websocket.Conn + mu sync.Mutex +} + +func NewThreadSafeClient(conn *websocket.Conn) *ThreadSafeClient { + return &ThreadSafeClient{ + Conn: conn, + } +} + +func (c *ThreadSafeClient) WriteJSON(data any) error { + c.mu.Lock() + defer c.mu.Unlock() + if err := c.Conn.WriteJSON(data); err != nil { + log.Error().Err(err).Msg("failed to write json msg") + return err + } + return nil +} + +func (c *ThreadSafeClient) ReadJSON(data any) error { + c.mu.Lock() + defer c.mu.Unlock() + if err := c.Conn.ReadJSON(&data); err != nil { + log.Error().Err(err).Msg("failed to read json msg") + return err + } + return nil +} + +func (c *ThreadSafeClient) Close() error { + return c.Conn.Close() +} + +func (c *ThreadSafeClient) WriteControl(messageType int, data []byte, deadline time.Time) error { + return c.Conn.WriteControl(messageType, data, deadline) +} diff --git a/node/pkg/dal/api/controller.go b/node/pkg/dal/api/controller.go index 764ba3ba3..fc32b16aa 100644 --- a/node/pkg/dal/api/controller.go +++ b/node/pkg/dal/api/controller.go @@ -21,9 +21,11 @@ func HandleWebsocket(conn *websocket.Conn) { return } + threadSafeClient := NewThreadSafeClient(conn) + closeHandler := conn.CloseHandler() conn.SetCloseHandler(func(code int, text string) error { - h.unregister <- conn + h.unregister <- threadSafeClient return closeHandler(code, text) }) @@ -33,7 +35,7 @@ func HandleWebsocket(conn *websocket.Conn) { return } - h.register <- conn + h.register <- threadSafeClient apiKey := conn.Headers("X-Api-Key") id, err := stats.InsertWebsocketConnection(*ctx, apiKey) @@ -44,7 +46,7 @@ func HandleWebsocket(conn *websocket.Conn) { log.Info().Int32("id", id).Msg("inserted websocket connection") defer func() { - h.unregister <- conn + h.unregister <- threadSafeClient err = stats.UpdateWebsocketConnection(*ctx, id) if err != nil { log.Error().Err(err).Msg("failed to update websocket connection") @@ -55,28 +57,32 @@ func HandleWebsocket(conn *websocket.Conn) { for { var msg Subscription - if err = conn.ReadJSON(&msg); err != nil { + if err = threadSafeClient.ReadJSON(&msg); err != nil { log.Error().Err(err).Msg("failed to read message") - return + continue } if msg.Method == "SUBSCRIBE" { - h.mu.Lock() - if h.clients[conn] == nil { - h.clients[conn] = make(map[string]bool) + val, ok := h.clients.Load(threadSafeClient) + if !ok { + val = make(map[string]bool) } + subscriptions := val.(map[string]bool) + valid := []string{} + for _, param := range msg.Params { symbol := strings.TrimPrefix(param, "submission@") if _, ok := h.configs[symbol]; !ok { continue } - h.clients[conn][symbol] = true - err = stats.InsertWebsocketSubscription(*ctx, id, param) - if err != nil { - log.Error().Err(err).Msg("failed to insert websocket subscription") - } + subscriptions[symbol] = true + valid = append(valid, param) + } + h.clients.Store(threadSafeClient, subscriptions) + err = stats.InsertWebsocketSubscriptions(*ctx, id, valid) + if err != nil { + log.Error().Err(err).Msg("failed to insert websocket subscription log") } - h.mu.Unlock() } } } diff --git a/node/pkg/dal/api/hub.go b/node/pkg/dal/api/hub.go index a166f8f03..6b201a977 100644 --- a/node/pkg/dal/api/hub.go +++ b/node/pkg/dal/api/hub.go @@ -2,6 +2,7 @@ package api import ( "context" + "sync" "time" "bisonai.com/orakl/node/pkg/common/types" @@ -25,9 +26,8 @@ func NewHub(configs map[string]types.Config) *Hub { return &Hub{ configs: configs, - clients: make(map[*websocket.Conn]map[string]bool), - register: make(chan *websocket.Conn), - unregister: make(chan *websocket.Conn), + register: make(chan *ThreadSafeClient), + unregister: make(chan *ThreadSafeClient), broadcast: make(map[string]chan dalcommon.OutgoingSubmissionData), } } @@ -45,50 +45,41 @@ func (c *Hub) Start(ctx context.Context, collector *collector.Collector) { func (c *Hub) handleClientRegistration() { for { select { - case conn := <-c.register: - c.addClient(conn) - case conn := <-c.unregister: - c.removeClient(conn) + case client := <-c.register: + c.addClient(client) + case client := <-c.unregister: + c.removeClient(client) } } } -func (c *Hub) addClient(conn *websocket.Conn) { - c.mu.RLock() - if _, ok := c.clients[conn]; ok { - c.mu.RUnlock() - return - } - c.mu.RUnlock() - - c.mu.Lock() - defer c.mu.Unlock() - - c.clients[conn] = make(map[string]bool) +func (c *Hub) addClient(client *ThreadSafeClient) { + c.clients.LoadOrStore(client, make(map[string]bool)) } -func (c *Hub) removeClient(conn *websocket.Conn) { - c.mu.RLock() - if _, ok := c.clients[conn]; !ok { - c.mu.RUnlock() - return - } - c.mu.RUnlock() - - c.mu.Lock() - defer c.mu.Unlock() - - for symbol := range c.clients[conn] { - delete(c.clients[conn], symbol) +func (c *Hub) removeClient(client *ThreadSafeClient) { + raw, ok := c.clients.LoadAndDelete(client) + if ok { + subscriptions, typeOk := raw.(map[string]bool) + if typeOk { + for symbol := range subscriptions { + delete(subscriptions, symbol) + } + } } - delete(c.clients, conn) - _ = conn.WriteControl( + err := client.WriteControl( websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, ""), time.Now().Add(time.Second), ) - conn.Close() + if err != nil { + log.Warn().Err(err).Msg("failed to write close message") + } + err = client.Close() + if err != nil { + log.Warn().Err(err).Msg("failed to close connection") + } } func (c *Hub) initializeBroadcastChannels(collector *collector.Collector) { @@ -115,16 +106,25 @@ func (c *Hub) broadcastDataForSymbol(symbol string) { // pass by pointer to reduce memory copy time func (c *Hub) castSubmissionData(data *dalcommon.OutgoingSubmissionData, symbol *string) { - c.mu.Lock() - defer c.mu.Unlock() - for conn := range c.clients { - if _, ok := c.clients[conn][*symbol]; ok { - if err := conn.WriteJSON(*data); err != nil { - log.Error().Err(err).Msg("failed to write message") - go func(conn *websocket.Conn) { - c.unregister <- conn - }(conn) - } + var wg sync.WaitGroup + c.clients.Range(func(threadSafeClient, symbols any) bool { + client, ok := threadSafeClient.(*ThreadSafeClient) + if !ok { + return true } - } + + subscriptions := symbols.(map[string]bool) + if subscriptions[*symbol] { + wg.Add(1) + go func(entry *ThreadSafeClient) { + defer wg.Done() + if err := entry.WriteJSON(*data); err != nil { + log.Error().Err(err).Msg("failed to write message") + c.unregister <- entry + } + }(client) + } + return true + }) + wg.Wait() } diff --git a/node/pkg/dal/api/types.go b/node/pkg/dal/api/types.go index eb7128576..d74268812 100644 --- a/node/pkg/dal/api/types.go +++ b/node/pkg/dal/api/types.go @@ -5,7 +5,6 @@ import ( "bisonai.com/orakl/node/pkg/common/types" dalcommon "bisonai.com/orakl/node/pkg/dal/common" - "github.com/gofiber/contrib/websocket" ) type Subscription struct { @@ -15,12 +14,10 @@ type Subscription struct { type Hub struct { configs map[string]types.Config - clients map[*websocket.Conn]map[string]bool - register chan *websocket.Conn - unregister chan *websocket.Conn + clients sync.Map // map[*ThreadSafeClient]map[string]bool + register chan *ThreadSafeClient + unregister chan *ThreadSafeClient broadcast map[string]chan dalcommon.OutgoingSubmissionData - - mu sync.RWMutex } type BulkResponse struct { diff --git a/node/pkg/dal/utils/stats/stats.go b/node/pkg/dal/utils/stats/stats.go index 0a915b7c4..a859f41d9 100644 --- a/node/pkg/dal/utils/stats/stats.go +++ b/node/pkg/dal/utils/stats/stats.go @@ -71,6 +71,15 @@ func InsertWebsocketSubscription(ctx context.Context, connectionId int32, topic }) } +func InsertWebsocketSubscriptions(ctx context.Context, connectionId int32, topics []string) error { + entries := [][]any{} + for _, topic := range topics { + entries = append(entries, []any{connectionId, topic}) + } + + return db.BulkInsert(ctx, "websocket_subscriptions", []string{"connection_id", "topic"}, entries) +} + func StatsMiddleware(c *fiber.Ctx) error { start := time.Now() if err := c.Next(); err != nil {