From 214e0f05c1fb91e8cdb9d368c98b6c15e252e37b Mon Sep 17 00:00:00 2001 From: nick Date: Sun, 29 Sep 2024 14:29:28 +0900 Subject: [PATCH] feat: timeout for aggregator msg handling --- node/pkg/aggregator/aggregator.go | 154 ++++++++++++++++++++++-------- 1 file changed, 112 insertions(+), 42 deletions(-) diff --git a/node/pkg/aggregator/aggregator.go b/node/pkg/aggregator/aggregator.go index e33430310..42c81721e 100644 --- a/node/pkg/aggregator/aggregator.go +++ b/node/pkg/aggregator/aggregator.go @@ -16,6 +16,8 @@ import ( "github.com/rs/zerolog/log" ) +const maxLeaderMsgReceiveTimeout = 100 * time.Millisecond + func NewAggregator(h host.Host, ps *pubsub.PubSub, topicString string, config Config, signHelper *helper.Signer, latestLocalAggregates *LatestLocalAggregates) (*Aggregator, error) { if h == nil || ps == nil || topicString == "" { return nil, errorSentinel.ErrAggregatorInvalidInitValue @@ -165,37 +167,72 @@ func (n *Aggregator) HandlePriceDataMessage(ctx context.Context, msg raft.Messag return nil } - if prices, ok := n.roundPrices.prices[priceDataMessage.RoundID]; ok { - n.roundPrices.prices[priceDataMessage.RoundID] = append(prices, priceDataMessage.PriceData) - n.roundPrices.senders[priceDataMessage.RoundID] = append(n.roundPrices.senders[priceDataMessage.RoundID], msg.SentFrom) - } else { - n.roundPrices.prices[priceDataMessage.RoundID] = []int64{priceDataMessage.PriceData} - n.roundPrices.senders[priceDataMessage.RoundID] = []string{msg.SentFrom} - } + n.storeRoundPriceData(priceDataMessage.RoundID, priceDataMessage.PriceData, msg.SentFrom) if len(n.roundPrices.prices[priceDataMessage.RoundID]) == n.Raft.SubscribersCount()+1 { - n.roundPrices.locked[priceDataMessage.RoundID] = true + // if all messsages received for the round + return n.processCollectedPrices(ctx, priceDataMessage.RoundID, priceDataMessage.Timestamp) + } else if len(n.roundPrices.prices[priceDataMessage.RoundID]) == 1 { + // if it's first message for the round + go n.startPriceCollectionTimeout(ctx, priceDataMessage.RoundID, priceDataMessage.Timestamp) + } + + return nil +} + +func (n *Aggregator) storeRoundPriceData(roundID int32, priceData int64, sender string) { + if prices, ok := n.roundPrices.prices[roundID]; ok { + n.roundPrices.prices[roundID] = append(prices, priceData) + n.roundPrices.senders[roundID] = append(n.roundPrices.senders[roundID], sender) + } else { + n.roundPrices.prices[roundID] = []int64{priceData} + n.roundPrices.senders[roundID] = []string{sender} + } +} - if n.Raft.GetRole() == raft.Leader { - prices := n.roundPrices.prices[priceDataMessage.RoundID] - log.Debug().Str("Player", "Aggregator").Int("peerCount", n.Raft.SubscribersCount()).Str("Name", n.Name).Any("collected prices", prices).Int32("roundId", priceDataMessage.RoundID).Msg("collected prices") +func (n *Aggregator) startPriceCollectionTimeout(ctx context.Context, roundID int32, timestamp time.Time) { + timer := time.NewTimer(maxLeaderMsgReceiveTimeout) + defer timer.Stop() - filteredCollectedPrices := FilterNegative(prices) - if len(filteredCollectedPrices) == 0 { - log.Warn().Str("Player", "Aggregator").Str("Name", n.Name).Int32("roundId", priceDataMessage.RoundID).Msg("no prices collected") - return nil - } + select { + case <-timer.C: + n.roundPrices.mu.Lock() + defer n.roundPrices.mu.Unlock() - median, err := calculator.GetInt64Med(filteredCollectedPrices) + if !n.roundPrices.locked[roundID] && len(n.roundPrices.prices[roundID]) >= (n.Raft.SubscribersCount()+1)/2 { + log.Debug().Str("Player", "Aggregator").Int32("roundId", roundID).Msg("timeout reached, processing available prices") + err := n.processCollectedPrices(ctx, roundID, timestamp) if err != nil { - log.Error().Str("Player", "Aggregator").Err(err).Msg("failed to get median") - return err + log.Error().Err(err).Int32("roundId", roundID).Msg("failed to process collected prices") } - - return n.PublishPriceFixMessage(ctx, priceDataMessage.RoundID, median, priceDataMessage.Timestamp) } + case <-ctx.Done(): + return } - return nil +} + +func (n *Aggregator) processCollectedPrices(ctx context.Context, roundID int32, timestamp time.Time) error { + n.roundPrices.locked[roundID] = true + if n.Raft.GetRole() != raft.Leader { + return nil + } + + prices := n.roundPrices.prices[roundID] + log.Debug().Str("Player", "Aggregator").Int("peerCount", n.Raft.SubscribersCount()).Str("Name", n.Name).Any("collected prices", prices).Int32("roundId", roundID).Msg("collected prices") + + filteredCollectedPrices := FilterNegative(prices) + if len(filteredCollectedPrices) == 0 { + log.Warn().Str("Player", "Aggregator").Str("Name", n.Name).Int32("roundId", roundID).Msg("no prices collected") + return nil + } + + median, err := calculator.GetInt64Med(filteredCollectedPrices) + if err != nil { + log.Error().Str("Player", "Aggregator").Err(err).Msg("failed to get median") + return err + } + + return n.PublishPriceFixMessage(ctx, roundID, median, timestamp) } func (n *Aggregator) HandlePriceFixMessage(ctx context.Context, msg raft.Message) error { @@ -257,37 +294,70 @@ func (n *Aggregator) HandleProofMessage(ctx context.Context, msg raft.Message) e return nil } - if proofs, ok := n.roundProofs.proofs[proofMessage.RoundID]; ok { - n.roundProofs.proofs[proofMessage.RoundID] = append(proofs, proofMessage.Proof) - n.roundProofs.senders[proofMessage.RoundID] = append(n.roundProofs.senders[proofMessage.RoundID], msg.SentFrom) - } else { - n.roundProofs.proofs[proofMessage.RoundID] = [][]byte{proofMessage.Proof} - n.roundProofs.senders[proofMessage.RoundID] = []string{msg.SentFrom} - } + n.storeRoundProofData(proofMessage.RoundID, proofMessage.Proof, msg.SentFrom) if len(n.roundProofs.proofs[proofMessage.RoundID]) == n.Raft.SubscribersCount()+1 { - n.roundProofs.locked[proofMessage.RoundID] = true + return n.processCollectedProofs(ctx, proofMessage) + } else if len(n.roundProofs.proofs[proofMessage.RoundID]) == 1 { + go n.startProofCollectionTimeout(ctx, proofMessage) + } - log.Debug().Str("Player", "Aggregator").Str("Name", n.Name).Int("peerCount", n.Raft.SubscribersCount()).Int32("roundId", proofMessage.RoundID).Any("collected proofs", n.roundProofs.proofs[proofMessage.RoundID]).Msg("collected proofs") + return nil +} + +func (n *Aggregator) storeRoundProofData(roundID int32, proofData []byte, sender string) { + if proofs, ok := n.roundProofs.proofs[roundID]; ok { + n.roundProofs.proofs[roundID] = append(proofs, proofData) + n.roundProofs.senders[roundID] = append(n.roundProofs.senders[roundID], sender) + } else { + n.roundProofs.proofs[roundID] = [][]byte{proofData} + n.roundProofs.senders[roundID] = []string{sender} + } +} - globalAggregate := GlobalAggregate{ - ConfigID: n.ID, - Value: proofMessage.Value, - Round: proofMessage.RoundID, - Timestamp: proofMessage.Timestamp} +func (n *Aggregator) startProofCollectionTimeout(ctx context.Context, proofMessage ProofMessage) { + timer := time.NewTimer(maxLeaderMsgReceiveTimeout) + defer timer.Stop() - concatProof := bytes.Join(n.roundProofs.proofs[proofMessage.RoundID], nil) - proof := Proof{ConfigID: n.ID, Round: proofMessage.RoundID, Proof: concatProof} + select { + case <-timer.C: + n.roundProofs.mu.Lock() + defer n.roundProofs.mu.Unlock() - err := PublishGlobalAggregateAndProof(ctx, n.Name, globalAggregate, proof) - if err != nil { - log.Error().Str("Player", "Aggregator").Err(err).Msg("failed to publish global aggregate and proof") + if !n.roundProofs.locked[proofMessage.RoundID] && len(n.roundProofs.proofs[proofMessage.RoundID]) >= (n.Raft.SubscribersCount()+1)/2 { + log.Debug().Str("Player", "Aggregator").Int32("roundId", proofMessage.RoundID).Msg("timeout reached, processing available proofs") + err := n.processCollectedProofs(ctx, proofMessage) + if err != nil { + log.Error().Err(err).Int32("roundId", proofMessage.RoundID).Msg("failed to process collected proofs") + } } + case <-ctx.Done(): + log.Debug().Str("Player", "Aggregator").Int32("roundId", proofMessage.RoundID).Msg("context canceled, stopping timeout") + return + } +} + +func (n *Aggregator) processCollectedProofs(ctx context.Context, proofMessage ProofMessage) error { + n.roundProofs.locked[proofMessage.RoundID] = true + log.Debug().Str("Player", "Aggregator").Str("Name", n.Name).Int("peerCount", n.Raft.SubscribersCount()).Int32("roundId", proofMessage.RoundID).Any("collected proofs", n.roundProofs.proofs[proofMessage.RoundID]).Msg("collected proofs") + + globalAggregate := GlobalAggregate{ + ConfigID: n.ID, + Value: proofMessage.Value, + Round: proofMessage.RoundID, + Timestamp: proofMessage.Timestamp} + concatProof := bytes.Join(n.roundProofs.proofs[proofMessage.RoundID], nil) + proof := Proof{ConfigID: n.ID, Round: proofMessage.RoundID, Proof: concatProof} + + err := PublishGlobalAggregateAndProof(ctx, n.Name, globalAggregate, proof) + if err != nil { + log.Error().Str("Player", "Aggregator").Err(err).Msg("failed to publish global aggregate and proof") + return err } + return nil } - func (n *Aggregator) PublishTriggerMessage(ctx context.Context, roundId int32, timestamp time.Time) error { triggerMessage := TriggerMessage{ LeaderID: n.Raft.GetHostId(),