Skip to content

Commit

Permalink
wip
Browse files Browse the repository at this point in the history
  • Loading branch information
nick-bisonai committed Jul 31, 2024
1 parent 85b145c commit 20257d3
Show file tree
Hide file tree
Showing 5 changed files with 126 additions and 64 deletions.
48 changes: 48 additions & 0 deletions node/pkg/dal/api/client.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
package api

import (
"sync"
"time"

"github.com/gofiber/contrib/websocket"
"github.com/rs/zerolog/log"
)

type ThreadSafeClient struct {
Conn *websocket.Conn
mu sync.Mutex
}

func NewThreadSafeClient(conn *websocket.Conn) *ThreadSafeClient {
return &ThreadSafeClient{
Conn: conn,
}
}

func (c *ThreadSafeClient) WriteJSON(data any) error {
c.mu.Lock()
defer c.mu.Unlock()
if err := c.Conn.WriteJSON(data); err != nil {
log.Error().Err(err).Msg("failed to write json msg")
return err
}
return nil
}

func (c *ThreadSafeClient) ReadJSON(data any) error {
c.mu.Lock()
defer c.mu.Unlock()
if err := c.Conn.ReadJSON(&data); err != nil {
log.Error().Err(err).Msg("failed to read json msg")
return err
}
return nil
}

func (c *ThreadSafeClient) Close() error {
return c.Conn.Close()
}

func (c *ThreadSafeClient) WriteControl(messageType int, data []byte, deadline time.Time) error {
return c.Conn.WriteControl(messageType, data, deadline)
}
34 changes: 20 additions & 14 deletions node/pkg/dal/api/controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,11 @@ func HandleWebsocket(conn *websocket.Conn) {
return
}

threadSafeClient := NewThreadSafeClient(conn)

closeHandler := conn.CloseHandler()
conn.SetCloseHandler(func(code int, text string) error {
h.unregister <- conn
h.unregister <- threadSafeClient
return closeHandler(code, text)
})

Expand All @@ -33,7 +35,7 @@ func HandleWebsocket(conn *websocket.Conn) {
return
}

h.register <- conn
h.register <- threadSafeClient
apiKey := conn.Headers("X-Api-Key")

id, err := stats.InsertWebsocketConnection(*ctx, apiKey)
Expand All @@ -44,7 +46,7 @@ func HandleWebsocket(conn *websocket.Conn) {
log.Info().Int32("id", id).Msg("inserted websocket connection")

defer func() {
h.unregister <- conn
h.unregister <- threadSafeClient
err = stats.UpdateWebsocketConnection(*ctx, id)
if err != nil {
log.Error().Err(err).Msg("failed to update websocket connection")
Expand All @@ -55,28 +57,32 @@ func HandleWebsocket(conn *websocket.Conn) {

for {
var msg Subscription
if err = conn.ReadJSON(&msg); err != nil {
if err = threadSafeClient.ReadJSON(&msg); err != nil {
log.Error().Err(err).Msg("failed to read message")
return
continue
}

if msg.Method == "SUBSCRIBE" {
h.mu.Lock()
if h.clients[conn] == nil {
h.clients[conn] = make(map[string]bool)
val, ok := h.clients.Load(threadSafeClient)
if !ok {
val = make(map[string]bool)
}
subscriptions := val.(map[string]bool)
valid := []string{}

for _, param := range msg.Params {
symbol := strings.TrimPrefix(param, "submission@")
if _, ok := h.configs[symbol]; !ok {
continue
}
h.clients[conn][symbol] = true
err = stats.InsertWebsocketSubscription(*ctx, id, param)
if err != nil {
log.Error().Err(err).Msg("failed to insert websocket subscription")
}
subscriptions[symbol] = true
valid = append(valid, param)
}
h.clients.Store(threadSafeClient, subscriptions)
err = stats.InsertWebsocketSubscriptions(*ctx, id, valid)
if err != nil {
log.Error().Err(err).Msg("failed to insert websocket subscription log")
}
h.mu.Unlock()
}
}
}
Expand Down
90 changes: 46 additions & 44 deletions node/pkg/dal/api/hub.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package api

import (
"context"
"sync"
"time"

"bisonai.com/orakl/node/pkg/common/types"
Expand All @@ -25,9 +26,8 @@ func NewHub(configs map[string]types.Config) *Hub {
return &Hub{
configs: configs,

clients: make(map[*websocket.Conn]map[string]bool),
register: make(chan *websocket.Conn),
unregister: make(chan *websocket.Conn),
register: make(chan *ThreadSafeClient),
unregister: make(chan *ThreadSafeClient),
broadcast: make(map[string]chan dalcommon.OutgoingSubmissionData),
}
}
Expand All @@ -45,50 +45,41 @@ func (c *Hub) Start(ctx context.Context, collector *collector.Collector) {
func (c *Hub) handleClientRegistration() {
for {
select {
case conn := <-c.register:
c.addClient(conn)
case conn := <-c.unregister:
c.removeClient(conn)
case client := <-c.register:
c.addClient(client)
case client := <-c.unregister:
c.removeClient(client)
}
}
}

func (c *Hub) addClient(conn *websocket.Conn) {
c.mu.RLock()
if _, ok := c.clients[conn]; ok {
c.mu.RUnlock()
return
}
c.mu.RUnlock()

c.mu.Lock()
defer c.mu.Unlock()

c.clients[conn] = make(map[string]bool)
func (c *Hub) addClient(client *ThreadSafeClient) {
c.clients.LoadOrStore(client, make(map[string]bool))
}

func (c *Hub) removeClient(conn *websocket.Conn) {
c.mu.RLock()
if _, ok := c.clients[conn]; !ok {
c.mu.RUnlock()
return
}
c.mu.RUnlock()

c.mu.Lock()
defer c.mu.Unlock()

for symbol := range c.clients[conn] {
delete(c.clients[conn], symbol)
func (c *Hub) removeClient(client *ThreadSafeClient) {
raw, ok := c.clients.LoadAndDelete(client)
if ok {
subscriptions, typeOk := raw.(map[string]bool)
if typeOk {
for symbol := range subscriptions {
delete(subscriptions, symbol)
}
}
}
delete(c.clients, conn)

_ = conn.WriteControl(
err := client.WriteControl(
websocket.CloseMessage,
websocket.FormatCloseMessage(websocket.CloseNormalClosure, ""),
time.Now().Add(time.Second),
)
conn.Close()
if err != nil {
log.Warn().Err(err).Msg("failed to write close message")
}
err = client.Close()
if err != nil {
log.Warn().Err(err).Msg("failed to close connection")
}
}

func (c *Hub) initializeBroadcastChannels(collector *collector.Collector) {
Expand All @@ -115,14 +106,25 @@ func (c *Hub) broadcastDataForSymbol(symbol string) {

// pass by pointer to reduce memory copy time
func (c *Hub) castSubmissionData(data *dalcommon.OutgoingSubmissionData, symbol *string) {
c.mu.Lock()
defer c.mu.Unlock()
for conn := range c.clients {
if _, ok := c.clients[conn][*symbol]; ok {
if err := conn.WriteJSON(*data); err != nil {
log.Error().Err(err).Msg("failed to write message")
go func() { c.unregister <- conn }()
}
var wg sync.WaitGroup
c.clients.Range(func(threadSafeClient, symbols any) bool {
client, ok := threadSafeClient.(*ThreadSafeClient)
if !ok {
return true
}
}

subscriptions := symbols.(map[string]bool)
if subscriptions[*symbol] {
wg.Add(1)
go func(entry *ThreadSafeClient) {
defer wg.Done()
if err := entry.WriteJSON(*data); err != nil {
log.Error().Err(err).Msg("failed to write message")
c.unregister <- entry
}
}(client)
}
return true
})
wg.Wait()
}
9 changes: 3 additions & 6 deletions node/pkg/dal/api/types.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ import (

"bisonai.com/orakl/node/pkg/common/types"
dalcommon "bisonai.com/orakl/node/pkg/dal/common"
"github.com/gofiber/contrib/websocket"
)

type Subscription struct {
Expand All @@ -15,12 +14,10 @@ type Subscription struct {

type Hub struct {
configs map[string]types.Config
clients map[*websocket.Conn]map[string]bool
register chan *websocket.Conn
unregister chan *websocket.Conn
clients sync.Map // map[*ThreadSafeClient]map[string]bool
register chan *ThreadSafeClient
unregister chan *ThreadSafeClient
broadcast map[string]chan dalcommon.OutgoingSubmissionData

mu sync.RWMutex
}

type BulkResponse struct {
Expand Down
9 changes: 9 additions & 0 deletions node/pkg/dal/utils/stats/stats.go
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,15 @@ func InsertWebsocketSubscription(ctx context.Context, connectionId int32, topic
})
}

func InsertWebsocketSubscriptions(ctx context.Context, connectionId int32, topics []string) error {
entries := [][]any{}
for _, topic := range topics {
entries = append(entries, []any{connectionId, topic})
}

return db.BulkInsert(ctx, "websocket_subscriptions", []string{"connection_id", "topic"}, entries)
}

func StatsMiddleware(c *fiber.Ctx) error {
start := time.Now()
if err := c.Next(); err != nil {
Expand Down

0 comments on commit 20257d3

Please sign in to comment.