Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

(Node) Timeout for aggregator msg handling #2292

Merged
merged 1 commit into from
Sep 29, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
154 changes: 112 additions & 42 deletions node/pkg/aggregator/aggregator.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@ import (
"github.com/rs/zerolog/log"
)

const maxLeaderMsgReceiveTimeout = 100 * time.Millisecond

nick-bisonai marked this conversation as resolved.
Show resolved Hide resolved
func NewAggregator(h host.Host, ps *pubsub.PubSub, topicString string, config Config, signHelper *helper.Signer, latestLocalAggregates *LatestLocalAggregates) (*Aggregator, error) {
if h == nil || ps == nil || topicString == "" {
return nil, errorSentinel.ErrAggregatorInvalidInitValue
Expand Down Expand Up @@ -165,37 +167,72 @@ func (n *Aggregator) HandlePriceDataMessage(ctx context.Context, msg raft.Messag
return nil
}

if prices, ok := n.roundPrices.prices[priceDataMessage.RoundID]; ok {
n.roundPrices.prices[priceDataMessage.RoundID] = append(prices, priceDataMessage.PriceData)
n.roundPrices.senders[priceDataMessage.RoundID] = append(n.roundPrices.senders[priceDataMessage.RoundID], msg.SentFrom)
} else {
n.roundPrices.prices[priceDataMessage.RoundID] = []int64{priceDataMessage.PriceData}
n.roundPrices.senders[priceDataMessage.RoundID] = []string{msg.SentFrom}
}
n.storeRoundPriceData(priceDataMessage.RoundID, priceDataMessage.PriceData, msg.SentFrom)

if len(n.roundPrices.prices[priceDataMessage.RoundID]) == n.Raft.SubscribersCount()+1 {
n.roundPrices.locked[priceDataMessage.RoundID] = true
// if all messsages received for the round
return n.processCollectedPrices(ctx, priceDataMessage.RoundID, priceDataMessage.Timestamp)
} else if len(n.roundPrices.prices[priceDataMessage.RoundID]) == 1 {
// if it's first message for the round
go n.startPriceCollectionTimeout(ctx, priceDataMessage.RoundID, priceDataMessage.Timestamp)
}

return nil
}

func (n *Aggregator) storeRoundPriceData(roundID int32, priceData int64, sender string) {
if prices, ok := n.roundPrices.prices[roundID]; ok {
n.roundPrices.prices[roundID] = append(prices, priceData)
n.roundPrices.senders[roundID] = append(n.roundPrices.senders[roundID], sender)
} else {
n.roundPrices.prices[roundID] = []int64{priceData}
n.roundPrices.senders[roundID] = []string{sender}
}
}

if n.Raft.GetRole() == raft.Leader {
prices := n.roundPrices.prices[priceDataMessage.RoundID]
log.Debug().Str("Player", "Aggregator").Int("peerCount", n.Raft.SubscribersCount()).Str("Name", n.Name).Any("collected prices", prices).Int32("roundId", priceDataMessage.RoundID).Msg("collected prices")
func (n *Aggregator) startPriceCollectionTimeout(ctx context.Context, roundID int32, timestamp time.Time) {
timer := time.NewTimer(maxLeaderMsgReceiveTimeout)
defer timer.Stop()

filteredCollectedPrices := FilterNegative(prices)
if len(filteredCollectedPrices) == 0 {
log.Warn().Str("Player", "Aggregator").Str("Name", n.Name).Int32("roundId", priceDataMessage.RoundID).Msg("no prices collected")
return nil
}
select {
case <-timer.C:
n.roundPrices.mu.Lock()
defer n.roundPrices.mu.Unlock()

median, err := calculator.GetInt64Med(filteredCollectedPrices)
if !n.roundPrices.locked[roundID] && len(n.roundPrices.prices[roundID]) >= (n.Raft.SubscribersCount()+1)/2 {
log.Debug().Str("Player", "Aggregator").Int32("roundId", roundID).Msg("timeout reached, processing available prices")
err := n.processCollectedPrices(ctx, roundID, timestamp)
if err != nil {
log.Error().Str("Player", "Aggregator").Err(err).Msg("failed to get median")
return err
log.Error().Err(err).Int32("roundId", roundID).Msg("failed to process collected prices")
}

return n.PublishPriceFixMessage(ctx, priceDataMessage.RoundID, median, priceDataMessage.Timestamp)
}
case <-ctx.Done():
return
}
return nil
}
nick-bisonai marked this conversation as resolved.
Show resolved Hide resolved

func (n *Aggregator) processCollectedPrices(ctx context.Context, roundID int32, timestamp time.Time) error {
n.roundPrices.locked[roundID] = true
if n.Raft.GetRole() != raft.Leader {
return nil
}

prices := n.roundPrices.prices[roundID]
log.Debug().Str("Player", "Aggregator").Int("peerCount", n.Raft.SubscribersCount()).Str("Name", n.Name).Any("collected prices", prices).Int32("roundId", roundID).Msg("collected prices")

filteredCollectedPrices := FilterNegative(prices)
if len(filteredCollectedPrices) == 0 {
log.Warn().Str("Player", "Aggregator").Str("Name", n.Name).Int32("roundId", roundID).Msg("no prices collected")
return nil
}

median, err := calculator.GetInt64Med(filteredCollectedPrices)
if err != nil {
log.Error().Str("Player", "Aggregator").Err(err).Msg("failed to get median")
return err
}

return n.PublishPriceFixMessage(ctx, roundID, median, timestamp)
nick-bisonai marked this conversation as resolved.
Show resolved Hide resolved
}

func (n *Aggregator) HandlePriceFixMessage(ctx context.Context, msg raft.Message) error {
Expand Down Expand Up @@ -257,37 +294,70 @@ func (n *Aggregator) HandleProofMessage(ctx context.Context, msg raft.Message) e
return nil
}

if proofs, ok := n.roundProofs.proofs[proofMessage.RoundID]; ok {
n.roundProofs.proofs[proofMessage.RoundID] = append(proofs, proofMessage.Proof)
n.roundProofs.senders[proofMessage.RoundID] = append(n.roundProofs.senders[proofMessage.RoundID], msg.SentFrom)
} else {
n.roundProofs.proofs[proofMessage.RoundID] = [][]byte{proofMessage.Proof}
n.roundProofs.senders[proofMessage.RoundID] = []string{msg.SentFrom}
}
n.storeRoundProofData(proofMessage.RoundID, proofMessage.Proof, msg.SentFrom)

if len(n.roundProofs.proofs[proofMessage.RoundID]) == n.Raft.SubscribersCount()+1 {
n.roundProofs.locked[proofMessage.RoundID] = true
return n.processCollectedProofs(ctx, proofMessage)
} else if len(n.roundProofs.proofs[proofMessage.RoundID]) == 1 {
go n.startProofCollectionTimeout(ctx, proofMessage)
}

log.Debug().Str("Player", "Aggregator").Str("Name", n.Name).Int("peerCount", n.Raft.SubscribersCount()).Int32("roundId", proofMessage.RoundID).Any("collected proofs", n.roundProofs.proofs[proofMessage.RoundID]).Msg("collected proofs")
return nil
}

func (n *Aggregator) storeRoundProofData(roundID int32, proofData []byte, sender string) {
if proofs, ok := n.roundProofs.proofs[roundID]; ok {
n.roundProofs.proofs[roundID] = append(proofs, proofData)
n.roundProofs.senders[roundID] = append(n.roundProofs.senders[roundID], sender)
} else {
n.roundProofs.proofs[roundID] = [][]byte{proofData}
n.roundProofs.senders[roundID] = []string{sender}
}
}

globalAggregate := GlobalAggregate{
ConfigID: n.ID,
Value: proofMessage.Value,
Round: proofMessage.RoundID,
Timestamp: proofMessage.Timestamp}
func (n *Aggregator) startProofCollectionTimeout(ctx context.Context, proofMessage ProofMessage) {
timer := time.NewTimer(maxLeaderMsgReceiveTimeout)
defer timer.Stop()

concatProof := bytes.Join(n.roundProofs.proofs[proofMessage.RoundID], nil)
proof := Proof{ConfigID: n.ID, Round: proofMessage.RoundID, Proof: concatProof}
select {
case <-timer.C:
n.roundProofs.mu.Lock()
defer n.roundProofs.mu.Unlock()

err := PublishGlobalAggregateAndProof(ctx, n.Name, globalAggregate, proof)
if err != nil {
log.Error().Str("Player", "Aggregator").Err(err).Msg("failed to publish global aggregate and proof")
if !n.roundProofs.locked[proofMessage.RoundID] && len(n.roundProofs.proofs[proofMessage.RoundID]) >= (n.Raft.SubscribersCount()+1)/2 {
log.Debug().Str("Player", "Aggregator").Int32("roundId", proofMessage.RoundID).Msg("timeout reached, processing available proofs")
err := n.processCollectedProofs(ctx, proofMessage)
if err != nil {
log.Error().Err(err).Int32("roundId", proofMessage.RoundID).Msg("failed to process collected proofs")
}
}
case <-ctx.Done():
log.Debug().Str("Player", "Aggregator").Int32("roundId", proofMessage.RoundID).Msg("context canceled, stopping timeout")
return
}
}

func (n *Aggregator) processCollectedProofs(ctx context.Context, proofMessage ProofMessage) error {
n.roundProofs.locked[proofMessage.RoundID] = true
log.Debug().Str("Player", "Aggregator").Str("Name", n.Name).Int("peerCount", n.Raft.SubscribersCount()).Int32("roundId", proofMessage.RoundID).Any("collected proofs", n.roundProofs.proofs[proofMessage.RoundID]).Msg("collected proofs")

globalAggregate := GlobalAggregate{
ConfigID: n.ID,
Value: proofMessage.Value,
Round: proofMessage.RoundID,
Timestamp: proofMessage.Timestamp}

concatProof := bytes.Join(n.roundProofs.proofs[proofMessage.RoundID], nil)
proof := Proof{ConfigID: n.ID, Round: proofMessage.RoundID, Proof: concatProof}

err := PublishGlobalAggregateAndProof(ctx, n.Name, globalAggregate, proof)
if err != nil {
log.Error().Str("Player", "Aggregator").Err(err).Msg("failed to publish global aggregate and proof")
return err
}

return nil
}

func (n *Aggregator) PublishTriggerMessage(ctx context.Context, roundId int32, timestamp time.Time) error {
triggerMessage := TriggerMessage{
LeaderID: n.Raft.GetHostId(),
Expand Down