Skip to content

Commit

Permalink
feat: timeout for aggregator msg handling
Browse files Browse the repository at this point in the history
  • Loading branch information
nick-bisonai committed Sep 29, 2024
1 parent ce21f43 commit 214e0f0
Showing 1 changed file with 112 additions and 42 deletions.
154 changes: 112 additions & 42 deletions node/pkg/aggregator/aggregator.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@ import (
"github.com/rs/zerolog/log"
)

// maxLeaderMsgReceiveTimeout is the duration of the timer started by the
// price and proof collection-timeout goroutines; when it fires, whatever has
// been collected for the round so far is considered for processing.
const maxLeaderMsgReceiveTimeout = 100 * time.Millisecond

func NewAggregator(h host.Host, ps *pubsub.PubSub, topicString string, config Config, signHelper *helper.Signer, latestLocalAggregates *LatestLocalAggregates) (*Aggregator, error) {
if h == nil || ps == nil || topicString == "" {
return nil, errorSentinel.ErrAggregatorInvalidInitValue
Expand Down Expand Up @@ -165,37 +167,72 @@ func (n *Aggregator) HandlePriceDataMessage(ctx context.Context, msg raft.Messag
return nil
}

if prices, ok := n.roundPrices.prices[priceDataMessage.RoundID]; ok {
n.roundPrices.prices[priceDataMessage.RoundID] = append(prices, priceDataMessage.PriceData)
n.roundPrices.senders[priceDataMessage.RoundID] = append(n.roundPrices.senders[priceDataMessage.RoundID], msg.SentFrom)
} else {
n.roundPrices.prices[priceDataMessage.RoundID] = []int64{priceDataMessage.PriceData}
n.roundPrices.senders[priceDataMessage.RoundID] = []string{msg.SentFrom}
}
n.storeRoundPriceData(priceDataMessage.RoundID, priceDataMessage.PriceData, msg.SentFrom)

if len(n.roundPrices.prices[priceDataMessage.RoundID]) == n.Raft.SubscribersCount()+1 {
n.roundPrices.locked[priceDataMessage.RoundID] = true
// if all messages have been received for the round
return n.processCollectedPrices(ctx, priceDataMessage.RoundID, priceDataMessage.Timestamp)
} else if len(n.roundPrices.prices[priceDataMessage.RoundID]) == 1 {
// if it's first message for the round
go n.startPriceCollectionTimeout(ctx, priceDataMessage.RoundID, priceDataMessage.Timestamp)
}

return nil
}

// storeRoundPriceData records priceData and the sender that reported it under
// roundID. The first value seen for a round starts fresh price and sender
// slices; later values are appended to the existing ones.
func (n *Aggregator) storeRoundPriceData(roundID int32, priceData int64, sender string) {
	existing, seen := n.roundPrices.prices[roundID]
	if !seen {
		// First price for this round: initialise both slices together.
		n.roundPrices.prices[roundID] = []int64{priceData}
		n.roundPrices.senders[roundID] = []string{sender}
		return
	}

	n.roundPrices.prices[roundID] = append(existing, priceData)
	n.roundPrices.senders[roundID] = append(n.roundPrices.senders[roundID], sender)
}

if n.Raft.GetRole() == raft.Leader {
prices := n.roundPrices.prices[priceDataMessage.RoundID]
log.Debug().Str("Player", "Aggregator").Int("peerCount", n.Raft.SubscribersCount()).Str("Name", n.Name).Any("collected prices", prices).Int32("roundId", priceDataMessage.RoundID).Msg("collected prices")
// startPriceCollectionTimeout is launched in its own goroutine when the first
// price message of a round arrives. It waits for maxLeaderMsgReceiveTimeout;
// when the timer fires it takes roundPrices.mu and, if the round has not been
// locked yet and at least (SubscribersCount()+1)/2 prices were collected, it
// calls processCollectedPrices with the prices available so far. A failure
// from processCollectedPrices is only logged. If ctx is canceled before the
// timer fires, it returns without processing.
//
// NOTE(review): the threshold uses integer division, so with two subscribers
// (three participants) a single collected price already satisfies it —
// confirm this is the intended quorum.
func (n *Aggregator) startPriceCollectionTimeout(ctx context.Context, roundID int32, timestamp time.Time) {
timer := time.NewTimer(maxLeaderMsgReceiveTimeout)
defer timer.Stop()

filteredCollectedPrices := FilterNegative(prices)
if len(filteredCollectedPrices) == 0 {
log.Warn().Str("Player", "Aggregator").Str("Name", n.Name).Int32("roundId", priceDataMessage.RoundID).Msg("no prices collected")
return nil
}
select {
case <-timer.C:
n.roundPrices.mu.Lock()
defer n.roundPrices.mu.Unlock()

median, err := calculator.GetInt64Med(filteredCollectedPrices)
if !n.roundPrices.locked[roundID] && len(n.roundPrices.prices[roundID]) >= (n.Raft.SubscribersCount()+1)/2 {
log.Debug().Str("Player", "Aggregator").Int32("roundId", roundID).Msg("timeout reached, processing available prices")
err := n.processCollectedPrices(ctx, roundID, timestamp)
if err != nil {
log.Error().Str("Player", "Aggregator").Err(err).Msg("failed to get median")
return err
log.Error().Err(err).Int32("roundId", roundID).Msg("failed to process collected prices")
}

return n.PublishPriceFixMessage(ctx, priceDataMessage.RoundID, median, priceDataMessage.Timestamp)
}
case <-ctx.Done():
return
}
return nil
}

// processCollectedPrices marks roundID as locked and, when this node is the
// raft leader, publishes a price-fix message carrying the median of the
// round's collected prices that pass FilterNegative.
//
// It returns nil without publishing when the node is not the leader or when
// FilterNegative returns no prices. Otherwise it returns the error from the
// median calculation or from PublishPriceFixMessage.
func (n *Aggregator) processCollectedPrices(ctx context.Context, roundID int32, timestamp time.Time) error {
	// Lock the round first: the collection timeout checks this flag before
	// processing, so the round is handled only once.
	n.roundPrices.locked[roundID] = true

	if role := n.Raft.GetRole(); role != raft.Leader {
		return nil
	}

	collected := n.roundPrices.prices[roundID]
	log.Debug().
		Str("Player", "Aggregator").
		Int("peerCount", n.Raft.SubscribersCount()).
		Str("Name", n.Name).
		Any("collected prices", collected).
		Int32("roundId", roundID).
		Msg("collected prices")

	usable := FilterNegative(collected)
	if len(usable) == 0 {
		log.Warn().
			Str("Player", "Aggregator").
			Str("Name", n.Name).
			Int32("roundId", roundID).
			Msg("no prices collected")
		return nil
	}

	med, medErr := calculator.GetInt64Med(usable)
	if medErr != nil {
		log.Error().
			Str("Player", "Aggregator").
			Err(medErr).
			Msg("failed to get median")
		return medErr
	}

	return n.PublishPriceFixMessage(ctx, roundID, med, timestamp)
}

func (n *Aggregator) HandlePriceFixMessage(ctx context.Context, msg raft.Message) error {
Expand Down Expand Up @@ -257,37 +294,70 @@ func (n *Aggregator) HandleProofMessage(ctx context.Context, msg raft.Message) e
return nil
}

if proofs, ok := n.roundProofs.proofs[proofMessage.RoundID]; ok {
n.roundProofs.proofs[proofMessage.RoundID] = append(proofs, proofMessage.Proof)
n.roundProofs.senders[proofMessage.RoundID] = append(n.roundProofs.senders[proofMessage.RoundID], msg.SentFrom)
} else {
n.roundProofs.proofs[proofMessage.RoundID] = [][]byte{proofMessage.Proof}
n.roundProofs.senders[proofMessage.RoundID] = []string{msg.SentFrom}
}
n.storeRoundProofData(proofMessage.RoundID, proofMessage.Proof, msg.SentFrom)

if len(n.roundProofs.proofs[proofMessage.RoundID]) == n.Raft.SubscribersCount()+1 {
n.roundProofs.locked[proofMessage.RoundID] = true
return n.processCollectedProofs(ctx, proofMessage)
} else if len(n.roundProofs.proofs[proofMessage.RoundID]) == 1 {
go n.startProofCollectionTimeout(ctx, proofMessage)
}

log.Debug().Str("Player", "Aggregator").Str("Name", n.Name).Int("peerCount", n.Raft.SubscribersCount()).Int32("roundId", proofMessage.RoundID).Any("collected proofs", n.roundProofs.proofs[proofMessage.RoundID]).Msg("collected proofs")
return nil
}

// storeRoundProofData records proofData and the sender that reported it under
// roundID. The first proof seen for a round starts fresh proof and sender
// slices; later proofs are appended to the existing ones.
func (n *Aggregator) storeRoundProofData(roundID int32, proofData []byte, sender string) {
	existing, seen := n.roundProofs.proofs[roundID]
	if !seen {
		// First proof for this round: initialise both slices together.
		n.roundProofs.proofs[roundID] = [][]byte{proofData}
		n.roundProofs.senders[roundID] = []string{sender}
		return
	}

	n.roundProofs.proofs[roundID] = append(existing, proofData)
	n.roundProofs.senders[roundID] = append(n.roundProofs.senders[roundID], sender)
}

globalAggregate := GlobalAggregate{
ConfigID: n.ID,
Value: proofMessage.Value,
Round: proofMessage.RoundID,
Timestamp: proofMessage.Timestamp}
// startProofCollectionTimeout is launched in its own goroutine when the first
// proof message of a round arrives, and receives that message as
// proofMessage. It waits for maxLeaderMsgReceiveTimeout; when the timer fires
// it takes roundProofs.mu and, if the round has not been locked yet and at
// least (SubscribersCount()+1)/2 proofs were collected, it calls
// processCollectedProofs with the proofs available so far. A failure from
// processCollectedProofs is only logged. If ctx is canceled before the timer
// fires, it logs that and returns without processing.
//
// NOTE(review): the threshold uses integer division, so with two subscribers
// (three participants) a single collected proof already satisfies it —
// confirm this is the intended quorum.
func (n *Aggregator) startProofCollectionTimeout(ctx context.Context, proofMessage ProofMessage) {
timer := time.NewTimer(maxLeaderMsgReceiveTimeout)
defer timer.Stop()

concatProof := bytes.Join(n.roundProofs.proofs[proofMessage.RoundID], nil)
proof := Proof{ConfigID: n.ID, Round: proofMessage.RoundID, Proof: concatProof}
select {
case <-timer.C:
n.roundProofs.mu.Lock()
defer n.roundProofs.mu.Unlock()

err := PublishGlobalAggregateAndProof(ctx, n.Name, globalAggregate, proof)
if err != nil {
log.Error().Str("Player", "Aggregator").Err(err).Msg("failed to publish global aggregate and proof")
if !n.roundProofs.locked[proofMessage.RoundID] && len(n.roundProofs.proofs[proofMessage.RoundID]) >= (n.Raft.SubscribersCount()+1)/2 {
log.Debug().Str("Player", "Aggregator").Int32("roundId", proofMessage.RoundID).Msg("timeout reached, processing available proofs")
err := n.processCollectedProofs(ctx, proofMessage)
if err != nil {
log.Error().Err(err).Int32("roundId", proofMessage.RoundID).Msg("failed to process collected proofs")
}
}
case <-ctx.Done():
log.Debug().Str("Player", "Aggregator").Int32("roundId", proofMessage.RoundID).Msg("context canceled, stopping timeout")
return
}
}

// processCollectedProofs marks the round of proofMessage as locked, joins all
// proofs collected for that round into one byte slice, and publishes it with
// a GlobalAggregate built from proofMessage's value, round and timestamp.
//
// It returns the error from PublishGlobalAggregateAndProof, which is also
// logged here, or nil on success.
func (n *Aggregator) processCollectedProofs(ctx context.Context, proofMessage ProofMessage) error {
	round := proofMessage.RoundID

	// Lock the round first: the collection timeout checks this flag before
	// processing, so the round is handled only once.
	n.roundProofs.locked[round] = true

	log.Debug().
		Str("Player", "Aggregator").
		Str("Name", n.Name).
		Int("peerCount", n.Raft.SubscribersCount()).
		Int32("roundId", round).
		Any("collected proofs", n.roundProofs.proofs[round]).
		Msg("collected proofs")

	aggregate := GlobalAggregate{
		ConfigID:  n.ID,
		Value:     proofMessage.Value,
		Round:     round,
		Timestamp: proofMessage.Timestamp,
	}
	joined := Proof{
		ConfigID: n.ID,
		Round:    round,
		Proof:    bytes.Join(n.roundProofs.proofs[round], nil),
	}

	if pubErr := PublishGlobalAggregateAndProof(ctx, n.Name, aggregate, joined); pubErr != nil {
		log.Error().
			Str("Player", "Aggregator").
			Err(pubErr).
			Msg("failed to publish global aggregate and proof")
		return pubErr
	}

	return nil
}

func (n *Aggregator) PublishTriggerMessage(ctx context.Context, roundId int32, timestamp time.Time) error {
triggerMessage := TriggerMessage{
LeaderID: n.Raft.GetHostId(),
Expand Down

0 comments on commit 214e0f0

Please sign in to comment.