Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

(OraklNode) Optimize aggregator #1988

Merged
merged 2 commits into from
Aug 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
242 changes: 47 additions & 195 deletions node/pkg/aggregator/aggregator.go
Original file line number Diff line number Diff line change
@@ -1,10 +1,8 @@
package aggregator

import (
"bytes"
"context"
"encoding/json"
"sync"

"time"

Expand All @@ -31,18 +29,15 @@ func NewAggregator(h host.Host, ps *pubsub.PubSub, topicString string, config Co
aggregateInterval := time.Duration(config.AggregateInterval) * time.Millisecond

aggregator := Aggregator{
Config: config,
Raft: raft.NewRaftNode(h, ps, topic, 100, aggregateInterval),
CollectedPrices: map[int32][]int64{},
CollectedProofs: map[int32][][]byte{},
CollectedAgreements: map[int32][]bool{},
PreparedLocalAggregates: map[int32]int64{},
PreparedGlobalAggregates: map[int32]GlobalAggregate{},
SyncedTimes: map[int32]time.Time{},
AggregatorMutex: sync.Mutex{},
RoundID: 1,
Signer: signHelper,
LatestLocalAggregates: latestLocalAggregates,
Config: config,
Raft: raft.NewRaftNode(h, ps, topic, 100, aggregateInterval),

roundPrices: &RoundPrices{prices: map[int32][]int64{}},
roundProofs: &RoundProofs{proofs: map[int32][][]byte{}},

RoundID: 1,
Signer: signHelper,
LatestLocalAggregates: latestLocalAggregates,
}
aggregator.Raft.LeaderJob = aggregator.LeaderJob
aggregator.Raft.HandleCustomMessage = aggregator.HandleCustomMessage
Expand All @@ -64,15 +59,11 @@ func (n *Aggregator) Run(ctx context.Context) {
func (n *Aggregator) LeaderJob() error {
n.RoundID++
n.Raft.IncreaseTerm()
return n.PublishSyncMessage(n.RoundID, time.Now())
return n.PublishTriggerMessage(n.RoundID, time.Now())
}

func (n *Aggregator) HandleCustomMessage(ctx context.Context, message raft.Message) error {
switch message.Type {
case RoundSync:
return n.HandleRoundSyncMessage(ctx, message)
case SyncReply:
return n.HandleSyncReplyMessage(ctx, message)
case Trigger:
return n.HandleTriggerMessage(ctx, message)
case PriceData:
Expand All @@ -84,28 +75,26 @@ func (n *Aggregator) HandleCustomMessage(ctx context.Context, message raft.Messa
}
}

func (n *Aggregator) HandleRoundSyncMessage(ctx context.Context, msg raft.Message) error {
var roundSyncMessage RoundSyncMessage
err := json.Unmarshal(msg.Data, &roundSyncMessage)
func (n *Aggregator) HandleTriggerMessage(ctx context.Context, msg raft.Message) error {
var triggerMessage TriggerMessage
err := json.Unmarshal(msg.Data, &triggerMessage)
if err != nil {
log.Error().Str("Player", "Aggregator").Err(err).Msg("failed to unmarshal round sync message")
log.Error().Str("Player", "Aggregator").Err(err).Msg("failed to unmarshal trigger message")
return err
}

if msg.SentFrom != n.Raft.GetLeader() {
log.Warn().Str("Player", "Aggregator").Msg("round sync message sent from non-leader")
return errorSentinel.ErrAggregatorNonLeaderRaftMessage
}

if roundSyncMessage.LeaderID == "" || roundSyncMessage.RoundID == 0 {
log.Error().Str("Player", "Aggregator").Msg("invalid round sync message")
if triggerMessage.RoundID == 0 {
log.Error().Str("Player", "Aggregator").Msg("invalid trigger message")
return errorSentinel.ErrAggregatorInvalidRaftMessage
}

if n.Raft.GetRole() != raft.Leader {
n.RoundID = roundSyncMessage.RoundID
if msg.SentFrom != n.Raft.GetLeader() {
log.Warn().Str("Player", "Aggregator").Msg("trigger message sent from non-leader")
return errorSentinel.ErrAggregatorNonLeaderRaftMessage
}

defer n.cleanUpRoundData(triggerMessage.RoundID - 10)

var value int64
localAggregate, ok := n.LatestLocalAggregates.Load(n.ID)
if !ok {
Expand All @@ -118,82 +107,7 @@ func (n *Aggregator) HandleRoundSyncMessage(ctx context.Context, msg raft.Messag
value = localAggregate.Value
}

n.AggregatorMutex.Lock()
defer n.AggregatorMutex.Unlock()
// run cleanup to prevent memory leak
// removes data 10 rounds ago, approximately 4 seconds old data
n.cleanUpRoundData(roundSyncMessage.RoundID - 10)

n.PreparedLocalAggregates[roundSyncMessage.RoundID] = value
n.SyncedTimes[roundSyncMessage.RoundID] = roundSyncMessage.Timestamp
return n.PublishSyncReplyMessage(roundSyncMessage.RoundID, true)
}

func (n *Aggregator) HandleSyncReplyMessage(ctx context.Context, msg raft.Message) error {
if n.Raft.GetRole() != raft.Leader {
log.Debug().Str("Player", "Aggregator").Msg("received sync reply message while not leader")
return nil
}

var syncReplyMessage SyncReplyMessage
err := json.Unmarshal(msg.Data, &syncReplyMessage)
if err != nil {
log.Error().Str("Player", "Aggregator").Err(err).Msg("failed to unmarshal sync reply message")
return err
}

if syncReplyMessage.RoundID == 0 {
log.Error().Str("Player", "Aggregator").Msg("invalid sync reply message")
return errorSentinel.ErrAggregatorInvalidRaftMessage
}

n.AggregatorMutex.Lock()
defer n.AggregatorMutex.Unlock()

if _, ok := n.CollectedAgreements[syncReplyMessage.RoundID]; !ok {
n.CollectedAgreements[syncReplyMessage.RoundID] = []bool{}
}

n.CollectedAgreements[syncReplyMessage.RoundID] = append(n.CollectedAgreements[syncReplyMessage.RoundID], syncReplyMessage.Agreed)
if len(n.CollectedAgreements[syncReplyMessage.RoundID]) >= n.Raft.SubscribersCount()+1 {
defer delete(n.CollectedAgreements, syncReplyMessage.RoundID)
agreeCount := 0
for _, agreed := range n.CollectedAgreements[syncReplyMessage.RoundID] {
if agreed {
agreeCount++
}
}
requiredAgreements := int(float64(n.Raft.SubscribersCount()) * AGREEMENT_QUORUM)
if agreeCount >= requiredAgreements {
return n.PublishTriggerMessage(syncReplyMessage.RoundID)
} else {
log.Warn().Str("Player", "Aggregator").Int("agreeCount", agreeCount).Int("requiredAgreements", requiredAgreements).Msg("not enough agreements, resigning as leader")
n.Raft.ResignLeader()
return nil
}
}
return nil
}

func (n *Aggregator) HandleTriggerMessage(ctx context.Context, msg raft.Message) error {
var triggerMessage TriggerMessage
err := json.Unmarshal(msg.Data, &triggerMessage)
if err != nil {
log.Error().Str("Player", "Aggregator").Err(err).Msg("failed to unmarshal trigger message")
return err
}

if triggerMessage.RoundID == 0 {
log.Error().Str("Player", "Aggregator").Msg("invalid trigger message")
return errorSentinel.ErrAggregatorInvalidRaftMessage
}

if msg.SentFrom != n.Raft.GetLeader() {
log.Warn().Str("Player", "Aggregator").Msg("trigger message sent from non-leader")
return errorSentinel.ErrAggregatorNonLeaderRaftMessage
}
defer delete(n.PreparedLocalAggregates, triggerMessage.RoundID)
return n.PublishPriceDataMessage(triggerMessage.RoundID, n.PreparedLocalAggregates[triggerMessage.RoundID])
return n.PublishPriceDataMessage(triggerMessage.RoundID, value, triggerMessage.Timestamp)
}

func (n *Aggregator) HandlePriceDataMessage(ctx context.Context, msg raft.Message) error {
Expand All @@ -209,19 +123,13 @@ func (n *Aggregator) HandlePriceDataMessage(ctx context.Context, msg raft.Messag
return errorSentinel.ErrAggregatorInvalidRaftMessage
}

n.AggregatorMutex.Lock()
defer n.AggregatorMutex.Unlock()
if _, ok := n.CollectedPrices[priceDataMessage.RoundID]; !ok {
n.CollectedPrices[priceDataMessage.RoundID] = []int64{}
}

n.CollectedPrices[priceDataMessage.RoundID] = append(n.CollectedPrices[priceDataMessage.RoundID], priceDataMessage.PriceData)
if len(n.CollectedPrices[priceDataMessage.RoundID]) >= n.Raft.SubscribersCount()+1 {
log.Debug().Str("Player", "Aggregator").Str("Name", n.Name).Any("collected prices", n.CollectedPrices[priceDataMessage.RoundID]).Int32("roundId", priceDataMessage.RoundID).Msg("collected prices")
defer delete(n.CollectedPrices, priceDataMessage.RoundID)
defer delete(n.SyncedTimes, priceDataMessage.RoundID)
filteredCollectedPrices := FilterNegative(n.CollectedPrices[priceDataMessage.RoundID])
n.roundPrices.push(priceDataMessage.RoundID, priceDataMessage.PriceData)
if n.roundPrices.len(priceDataMessage.RoundID) >= n.Raft.SubscribersCount()+1 {
prices := n.roundPrices.snapshot(priceDataMessage.RoundID)
log.Debug().Str("Player", "Aggregator").Str("Name", n.Name).Any("collected prices", prices).Int32("roundId", priceDataMessage.RoundID).Msg("collected prices")
defer n.roundPrices.delete(priceDataMessage.RoundID)

filteredCollectedPrices := FilterNegative(prices)
if len(filteredCollectedPrices) == 0 {
log.Warn().Str("Player", "Aggregator").Str("Name", n.Name).Int32("roundId", priceDataMessage.RoundID).Msg("no prices collected")
return nil
Expand All @@ -233,19 +141,13 @@ func (n *Aggregator) HandlePriceDataMessage(ctx context.Context, msg raft.Messag
return err
}
log.Debug().Str("Player", "Aggregator").Str("Name", n.Name).Any("filtered collected prices", filteredCollectedPrices).Int32("roundId", priceDataMessage.RoundID).Int64("global_aggregate", median).Msg("global aggregated")
n.PreparedGlobalAggregates[priceDataMessage.RoundID] = GlobalAggregate{
ConfigID: n.ID,
Value: median,
Round: priceDataMessage.RoundID,
Timestamp: n.SyncedTimes[priceDataMessage.RoundID],
}

proof, err := n.Signer.MakeGlobalAggregateProof(median, n.SyncedTimes[priceDataMessage.RoundID], n.Name)
proof, err := n.Signer.MakeGlobalAggregateProof(median, priceDataMessage.Timestamp, n.Name)
if err != nil {
log.Error().Str("Player", "Aggregator").Err(err).Msg("failed to make global aggregate proof")
return err
}
return n.PublishProofMessage(priceDataMessage.RoundID, proof)
return n.PublishProofMessage(priceDataMessage.RoundID, median, proof, priceDataMessage.Timestamp)
}
return nil
}
Expand All @@ -263,22 +165,15 @@ func (n *Aggregator) HandleProofMessage(ctx context.Context, msg raft.Message) e
return errorSentinel.ErrAggregatorInvalidRaftMessage
}

n.AggregatorMutex.Lock()
defer n.AggregatorMutex.Unlock()
if _, ok := n.CollectedProofs[proofMessage.RoundID]; !ok {
n.CollectedProofs[proofMessage.RoundID] = [][]byte{}
}

n.CollectedProofs[proofMessage.RoundID] = append(n.CollectedProofs[proofMessage.RoundID], proofMessage.Proof)
if len(n.CollectedProofs[proofMessage.RoundID]) >= n.Raft.SubscribersCount()+1 {
defer delete(n.CollectedProofs, proofMessage.RoundID)
defer delete(n.PreparedGlobalAggregates, proofMessage.RoundID)
n.roundProofs.push(proofMessage.RoundID, proofMessage.Proof)
if n.roundProofs.len(proofMessage.RoundID) >= n.Raft.SubscribersCount()+1 {
defer n.roundProofs.delete(proofMessage.RoundID)
globalAggregate := GlobalAggregate{
ConfigID: n.ID,
Value: n.PreparedGlobalAggregates[proofMessage.RoundID].Value,
Value: proofMessage.Value,
Round: proofMessage.RoundID,
Timestamp: n.PreparedGlobalAggregates[proofMessage.RoundID].Timestamp}
concatProof := bytes.Join(n.CollectedProofs[proofMessage.RoundID], nil)
Timestamp: proofMessage.Timestamp}
concatProof := n.roundProofs.concat(proofMessage.RoundID)
proof := Proof{ConfigID: n.ID, Round: proofMessage.RoundID, Proof: concatProof}

err := PublishGlobalAggregateAndProof(ctx, globalAggregate, proof)
Expand All @@ -290,55 +185,13 @@ func (n *Aggregator) HandleProofMessage(ctx context.Context, msg raft.Message) e
return nil
}

func (n *Aggregator) PublishSyncMessage(roundId int32, timestamp time.Time) error {
roundMessage := RoundSyncMessage{
func (n *Aggregator) PublishTriggerMessage(roundId int32, timestamp time.Time) error {
triggerMessage := TriggerMessage{
LeaderID: n.Raft.GetHostId(),
RoundID: roundId,
Timestamp: timestamp,
}

marshalledRoundMessage, err := json.Marshal(roundMessage)
if err != nil {
log.Error().Str("Player", "Aggregator").Err(err).Msg("failed to marshal round message")
return err
}

message := raft.Message{
Type: RoundSync,
SentFrom: n.Raft.GetHostId(),
Data: json.RawMessage(marshalledRoundMessage),
}

return n.Raft.PublishMessage(message)
}

func (n *Aggregator) PublishSyncReplyMessage(roundId int32, agreed bool) error {
syncReplyMessage := SyncReplyMessage{
RoundID: roundId,
Agreed: agreed,
}

marshalledSyncReplyMessage, err := json.Marshal(syncReplyMessage)
if err != nil {
log.Error().Str("Player", "Aggregator").Err(err).Msg("failed to marshal sync reply message")
return err
}

message := raft.Message{
Type: SyncReply,
SentFrom: n.Raft.GetHostId(),
Data: json.RawMessage(marshalledSyncReplyMessage),
}

return n.Raft.PublishMessage(message)
}

func (n *Aggregator) PublishTriggerMessage(roundId int32) error {
triggerMessage := TriggerMessage{
LeaderID: n.Raft.GetHostId(),
RoundID: roundId,
}

marshalledTriggerMessage, err := json.Marshal(triggerMessage)
if err != nil {
log.Error().Str("Player", "Aggregator").Err(err).Msg("failed to marshal trigger message")
Expand All @@ -354,10 +207,11 @@ func (n *Aggregator) PublishTriggerMessage(roundId int32) error {
return n.Raft.PublishMessage(message)
}

func (n *Aggregator) PublishPriceDataMessage(roundId int32, value int64) error {
func (n *Aggregator) PublishPriceDataMessage(roundId int32, value int64, timestamp time.Time) error {
priceDataMessage := PriceDataMessage{
RoundID: roundId,
PriceData: value,
Timestamp: timestamp,
}

marshalledPriceDataMessage, err := json.Marshal(priceDataMessage)
Expand All @@ -375,10 +229,12 @@ func (n *Aggregator) PublishPriceDataMessage(roundId int32, value int64) error {
return n.Raft.PublishMessage(message)
}

func (n *Aggregator) PublishProofMessage(roundId int32, proof []byte) error {
func (n *Aggregator) PublishProofMessage(roundId int32, value int64, proof []byte, timestamp time.Time) error {
proofMessage := ProofMessage{
RoundID: roundId,
Proof: proof,
RoundID: roundId,
Value: value,
Proof: proof,
Timestamp: timestamp,
}

marshalledProofMessage, err := json.Marshal(proofMessage)
Expand All @@ -397,10 +253,6 @@ func (n *Aggregator) PublishProofMessage(roundId int32, proof []byte) error {
}

func (n *Aggregator) cleanUpRoundData(roundId int32) {
delete(n.CollectedPrices, roundId)
delete(n.CollectedProofs, roundId)
delete(n.CollectedAgreements, roundId)
delete(n.PreparedLocalAggregates, roundId)
delete(n.PreparedGlobalAggregates, roundId)
delete(n.SyncedTimes, roundId)
n.roundPrices.delete(roundId)
n.roundProofs.delete(roundId)
}
Loading