From 28003f9598bbda0155d9c179ee65aca343b097ef Mon Sep 17 00:00:00 2001 From: nick Date: Thu, 1 Aug 2024 23:25:58 +0900 Subject: [PATCH] feat: check data freshness before updates, use mutex --- node/pkg/dal/api/controller.go | 11 +++-- node/pkg/dal/api/hub.go | 42 +++++++++------- node/pkg/dal/api/types.go | 3 +- node/pkg/dal/collector/collector.go | 74 ++++++++++++++++++++--------- 4 files changed, 86 insertions(+), 44 deletions(-) diff --git a/node/pkg/dal/api/controller.go b/node/pkg/dal/api/controller.go index b8e1a0903..72fe92f45 100644 --- a/node/pkg/dal/api/controller.go +++ b/node/pkg/dal/api/controller.go @@ -63,11 +63,12 @@ func HandleWebsocket(conn *websocket.Conn) { } if msg.Method == "SUBSCRIBE" { - val, ok := h.clients.Load(threadSafeClient) + h.mu.RLock() + subscriptions, ok := h.clients[threadSafeClient] if !ok { - val = make(map[string]bool) + subscriptions = map[string]bool{} } - subscriptions := val.(map[string]bool) + h.mu.RUnlock() valid := []string{} for _, param := range msg.Params { @@ -78,7 +79,9 @@ func HandleWebsocket(conn *websocket.Conn) { subscriptions[symbol] = true valid = append(valid, param) } - h.clients.Store(threadSafeClient, subscriptions) + h.mu.Lock() + h.clients[threadSafeClient] = subscriptions + h.mu.Unlock() err = stats.InsertWebsocketSubscriptions(*ctx, id, valid) if err != nil { log.Error().Err(err).Msg("failed to insert websocket subscription log") diff --git a/node/pkg/dal/api/hub.go b/node/pkg/dal/api/hub.go index 6b201a977..df5818135 100644 --- a/node/pkg/dal/api/hub.go +++ b/node/pkg/dal/api/hub.go @@ -54,18 +54,30 @@ func (c *Hub) handleClientRegistration() { } func (c *Hub) addClient(client *ThreadSafeClient) { - c.clients.LoadOrStore(client, make(map[string]bool)) + c.mu.RLock() + _, ok := c.clients[client] + c.mu.RUnlock() + if ok { + return + } + + c.mu.Lock() + defer c.mu.Unlock() + c.clients[client] = make(map[string]bool) } func (c *Hub) removeClient(client *ThreadSafeClient) { - raw, ok := c.clients.LoadAndDelete(client) - if ok { - subscriptions, typeOk := raw.(map[string]bool) - if typeOk { - for symbol := range subscriptions { - delete(subscriptions, symbol) - } - } + c.mu.RLock() + subscriptions, ok := c.clients[client] + c.mu.RUnlock() + if !ok { + return + } + + c.mu.Lock() + delete(c.clients, client) + for symbol := range subscriptions { + delete(subscriptions, symbol) } err := client.WriteControl( @@ -107,13 +119,10 @@ func (c *Hub) broadcastDataForSymbol(symbol string) { // pass by pointer to reduce memory copy time func (c *Hub) castSubmissionData(data *dalcommon.OutgoingSubmissionData, symbol *string) { var wg sync.WaitGroup - c.clients.Range(func(threadSafeClient, symbols any) bool { - client, ok := threadSafeClient.(*ThreadSafeClient) - if !ok { - return true - } - subscriptions := symbols.(map[string]bool) + c.mu.RLock() + defer c.mu.RUnlock() + for client, subscriptions := range c.clients { if subscriptions[*symbol] { wg.Add(1) go func(entry *ThreadSafeClient) { @@ -124,7 +133,6 @@ func (c *Hub) castSubmissionData(data *dalcommon.OutgoingSubmissionData, symbol } }(client) } - return true - }) + } wg.Wait() } diff --git a/node/pkg/dal/api/types.go b/node/pkg/dal/api/types.go index d74268812..4b7febd2a 100644 --- a/node/pkg/dal/api/types.go +++ b/node/pkg/dal/api/types.go @@ -14,10 +14,11 @@ type Subscription struct { type Hub struct { configs map[string]types.Config - clients sync.Map // map[*ThreadSafeClient]map[string]bool + clients map[*ThreadSafeClient]map[string]bool register chan *ThreadSafeClient unregister chan *ThreadSafeClient broadcast map[string]chan dalcommon.OutgoingSubmissionData + mu sync.RWMutex } type BulkResponse struct { diff --git a/node/pkg/dal/collector/collector.go b/node/pkg/dal/collector/collector.go index a151abaa8..63019963f 100644 --- a/node/pkg/dal/collector/collector.go +++ b/node/pkg/dal/collector/collector.go @@ -6,6 +6,7 @@ import ( "os" "strconv" "sync" + "time" "bisonai.com/orakl/node/pkg/aggregator" "bisonai.com/orakl/node/pkg/chain/websocketchainreader" @@ -26,12 +27,13 @@ const ( ) type Collector struct { - IncomingStream map[int32]chan aggregator.SubmissionData - OutgoingStream map[int32]chan dalcommon.OutgoingSubmissionData - Symbols map[int32]string - FeedHashes map[int32][]byte - CachedWhitelist []klaytncommon.Address - LatestData sync.Map + IncomingStream map[int32]chan aggregator.SubmissionData + OutgoingStream map[int32]chan dalcommon.OutgoingSubmissionData + Symbols map[int32]string + FeedHashes map[int32][]byte + LatestTimestamps map[int32]time.Time + LatestData map[string]*dalcommon.OutgoingSubmissionData + CachedWhitelist []klaytncommon.Address IsRunning bool CancelFunc context.CancelFunc @@ -68,7 +70,8 @@ func NewCollector(ctx context.Context, configs []types.Config) (*Collector, erro OutgoingStream: make(map[int32]chan dalcommon.OutgoingSubmissionData, len(configs)), Symbols: make(map[int32]string, len(configs)), FeedHashes: make(map[int32][]byte, len(configs)), - LatestData: sync.Map{}, + LatestTimestamps: make(map[int32]time.Time), + LatestData: make(map[string]*dalcommon.OutgoingSubmissionData), chainReader: chainReader, CachedWhitelist: initialWhitelist, submissionProxyContractAddr: submissionProxyContractAddr, @@ -100,27 +103,23 @@ func (c *Collector) Start(ctx context.Context) { } func (c *Collector) GetLatestData(symbol string) (*dalcommon.OutgoingSubmissionData, error) { - result, ok := c.LatestData.Load(symbol) + c.mu.RLock() + defer c.mu.RUnlock() + result, ok := c.LatestData[symbol] if !ok { return nil, errors.New("symbol not found") } - - data, ok := result.(*dalcommon.OutgoingSubmissionData) - if !ok { - return nil, errors.New("symbol not converted") - } - - return data, nil + return result, nil } func (c *Collector) GetAllLatestData() []dalcommon.OutgoingSubmissionData { + c.mu.RLock() + defer c.mu.RUnlock() result := make([]dalcommon.OutgoingSubmissionData, 0, len(c.Symbols)) - c.LatestData.Range(func(key, value interface{}) bool { - if data, ok := value.(*dalcommon.OutgoingSubmissionData); ok { - result = append(result, *data) - } - return true - }) + for _, value := range c.LatestData { + result = append(result, *value) + } + return result } @@ -156,14 +155,45 @@ func (c *Collector) receiveEach(ctx context.Context, configId int32) error { } } +func (c *Collector) compareAndSwapLatestTimestamp(data aggregator.SubmissionData) bool { + c.mu.RLock() + old, ok := c.LatestTimestamps[data.GlobalAggregate.ConfigID] + c.mu.RUnlock() + if !ok { + c.mu.Lock() + c.LatestTimestamps[data.GlobalAggregate.ConfigID] = data.GlobalAggregate.Timestamp + c.mu.Unlock() + return true + } + + if old.After(data.GlobalAggregate.Timestamp) { + return false + } + + c.mu.Lock() + c.LatestTimestamps[data.GlobalAggregate.ConfigID] = data.GlobalAggregate.Timestamp + c.mu.Unlock() + return true +} + func (c *Collector) processIncomingData(ctx context.Context, data aggregator.SubmissionData) { + valid := c.compareAndSwapLatestTimestamp(data) + if !valid { + log.Debug().Msg("old data recieved") + return + } + result, err := c.IncomingDataToOutgoingData(ctx, data) if err != nil { log.Error().Err(err).Str("Player", "DalCollector").Msg("failed to convert incoming data to outgoing data") return } - defer c.LatestData.Store(result.Symbol, result) + defer func(data *dalcommon.OutgoingSubmissionData) { + c.mu.Lock() + defer c.mu.Unlock() + c.LatestData[data.Symbol] = data + }(result) c.OutgoingStream[data.GlobalAggregate.ConfigID] <- *result }