Skip to content

Commit

Permalink
feat: optimize
Browse files Browse the repository at this point in the history
  • Loading branch information
nick-bisonai committed Aug 1, 2024
1 parent 738a413 commit c309fcb
Show file tree
Hide file tree
Showing 2 changed files with 141 additions and 218 deletions.
241 changes: 47 additions & 194 deletions node/pkg/aggregator/aggregator.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package aggregator

import (
"bytes"
"context"
"encoding/json"
"sync"
Expand Down Expand Up @@ -32,18 +31,15 @@ func NewAggregator(h host.Host, ps *pubsub.PubSub, topicString string, config Co
aggregateInterval := time.Duration(config.AggregateInterval) * time.Millisecond

aggregator := Aggregator{
Config: config,
Raft: raft.NewRaftNode(h, ps, topic, 100, aggregateInterval),
CollectedPrices: map[int32][]int64{},
CollectedProofs: map[int32][][]byte{},
CollectedAgreements: map[int32][]bool{},
PreparedLocalAggregates: map[int32]int64{},
PreparedGlobalAggregates: map[int32]GlobalAggregate{},
SyncedTimes: map[int32]time.Time{},
AggregatorMutex: sync.Mutex{},
RoundID: 1,
Signer: signHelper,
LatestLocalAggregates: latestLocalAggregates,
Config: config,
Raft: raft.NewRaftNode(h, ps, topic, 100, aggregateInterval),

roundPrices: &RoundPrices{prices: map[int32][]int64{}},
roundProofs: &RoundProofs{proofs: map[int32][][]byte{}},

RoundID: 1,
Signer: signHelper,
LatestLocalAggregates: latestLocalAggregates,
}
aggregator.Raft.LeaderJob = aggregator.LeaderJob
aggregator.Raft.HandleCustomMessage = aggregator.HandleCustomMessage
Expand All @@ -65,15 +61,11 @@ func (n *Aggregator) Run(ctx context.Context) {
func (n *Aggregator) LeaderJob() error {
n.RoundID++
n.Raft.IncreaseTerm()
return n.PublishSyncMessage(n.RoundID, time.Now())
return n.PublishTriggerMessage(n.RoundID, time.Now())
}

func (n *Aggregator) HandleCustomMessage(ctx context.Context, message raft.Message) error {
switch message.Type {
case RoundSync:
return n.HandleRoundSyncMessage(ctx, message)
case SyncReply:
return n.HandleSyncReplyMessage(ctx, message)
case Trigger:
return n.HandleTriggerMessage(ctx, message)
case PriceData:
Expand All @@ -85,28 +77,26 @@ func (n *Aggregator) HandleCustomMessage(ctx context.Context, message raft.Messa
}
}

func (n *Aggregator) HandleRoundSyncMessage(ctx context.Context, msg raft.Message) error {
var roundSyncMessage RoundSyncMessage
err := json.Unmarshal(msg.Data, &roundSyncMessage)
func (n *Aggregator) HandleTriggerMessage(ctx context.Context, msg raft.Message) error {
var triggerMessage TriggerMessage
err := json.Unmarshal(msg.Data, &triggerMessage)
if err != nil {
log.Error().Str("Player", "Aggregator").Err(err).Msg("failed to unmarshal round sync message")
log.Error().Str("Player", "Aggregator").Err(err).Msg("failed to unmarshal trigger message")
return err
}

if msg.SentFrom != n.Raft.GetLeader() {
log.Warn().Str("Player", "Aggregator").Msg("round sync message sent from non-leader")
return errorSentinel.ErrAggregatorNonLeaderRaftMessage
}

if roundSyncMessage.LeaderID == "" || roundSyncMessage.RoundID == 0 {
log.Error().Str("Player", "Aggregator").Msg("invalid round sync message")
if triggerMessage.RoundID == 0 {
log.Error().Str("Player", "Aggregator").Msg("invalid trigger message")
return errorSentinel.ErrAggregatorInvalidRaftMessage
}

if n.Raft.GetRole() != raft.Leader {
n.RoundID = roundSyncMessage.RoundID
if msg.SentFrom != n.Raft.GetLeader() {
log.Warn().Str("Player", "Aggregator").Msg("trigger message sent from non-leader")
return errorSentinel.ErrAggregatorNonLeaderRaftMessage
}

defer n.cleanUpRoundData(triggerMessage.RoundID - 10)

var value int64
localAggregateRaw, ok := n.LatestLocalAggregates.Load(n.ID)
if !ok {
Expand All @@ -120,82 +110,7 @@ func (n *Aggregator) HandleRoundSyncMessage(ctx context.Context, msg raft.Messag
value = localAggregate.Value
}

n.AggregatorMutex.Lock()
defer n.AggregatorMutex.Unlock()
// run cleanup to prevent memory leak
// removes data 10 rounds ago, approximately 4 seconds old data
n.cleanUpRoundData(roundSyncMessage.RoundID - 10)

n.PreparedLocalAggregates[roundSyncMessage.RoundID] = value
n.SyncedTimes[roundSyncMessage.RoundID] = roundSyncMessage.Timestamp
return n.PublishSyncReplyMessage(roundSyncMessage.RoundID, true)
}

func (n *Aggregator) HandleSyncReplyMessage(ctx context.Context, msg raft.Message) error {
if n.Raft.GetRole() != raft.Leader {
log.Debug().Str("Player", "Aggregator").Msg("received sync reply message while not leader")
return nil
}

var syncReplyMessage SyncReplyMessage
err := json.Unmarshal(msg.Data, &syncReplyMessage)
if err != nil {
log.Error().Str("Player", "Aggregator").Err(err).Msg("failed to unmarshal sync reply message")
return err
}

if syncReplyMessage.RoundID == 0 {
log.Error().Str("Player", "Aggregator").Msg("invalid sync reply message")
return errorSentinel.ErrAggregatorInvalidRaftMessage
}

n.AggregatorMutex.Lock()
defer n.AggregatorMutex.Unlock()

if _, ok := n.CollectedAgreements[syncReplyMessage.RoundID]; !ok {
n.CollectedAgreements[syncReplyMessage.RoundID] = []bool{}
}

n.CollectedAgreements[syncReplyMessage.RoundID] = append(n.CollectedAgreements[syncReplyMessage.RoundID], syncReplyMessage.Agreed)
if len(n.CollectedAgreements[syncReplyMessage.RoundID]) >= n.Raft.SubscribersCount()+1 {
defer delete(n.CollectedAgreements, syncReplyMessage.RoundID)
agreeCount := 0
for _, agreed := range n.CollectedAgreements[syncReplyMessage.RoundID] {
if agreed {
agreeCount++
}
}
requiredAgreements := int(float64(n.Raft.SubscribersCount()) * AGREEMENT_QUORUM)
if agreeCount >= requiredAgreements {
return n.PublishTriggerMessage(syncReplyMessage.RoundID)
} else {
log.Warn().Str("Player", "Aggregator").Int("agreeCount", agreeCount).Int("requiredAgreements", requiredAgreements).Msg("not enough agreements, resigning as leader")
n.Raft.ResignLeader()
return nil
}
}
return nil
}

func (n *Aggregator) HandleTriggerMessage(ctx context.Context, msg raft.Message) error {
var triggerMessage TriggerMessage
err := json.Unmarshal(msg.Data, &triggerMessage)
if err != nil {
log.Error().Str("Player", "Aggregator").Err(err).Msg("failed to unmarshal trigger message")
return err
}

if triggerMessage.RoundID == 0 {
log.Error().Str("Player", "Aggregator").Msg("invalid trigger message")
return errorSentinel.ErrAggregatorInvalidRaftMessage
}

if msg.SentFrom != n.Raft.GetLeader() {
log.Warn().Str("Player", "Aggregator").Msg("trigger message sent from non-leader")
return errorSentinel.ErrAggregatorNonLeaderRaftMessage
}
defer delete(n.PreparedLocalAggregates, triggerMessage.RoundID)
return n.PublishPriceDataMessage(triggerMessage.RoundID, n.PreparedLocalAggregates[triggerMessage.RoundID])
return n.PublishPriceDataMessage(triggerMessage.RoundID, value, triggerMessage.Timestamp)
}

func (n *Aggregator) HandlePriceDataMessage(ctx context.Context, msg raft.Message) error {
Expand All @@ -211,19 +126,13 @@ func (n *Aggregator) HandlePriceDataMessage(ctx context.Context, msg raft.Messag
return errorSentinel.ErrAggregatorInvalidRaftMessage
}

n.AggregatorMutex.Lock()
defer n.AggregatorMutex.Unlock()
if _, ok := n.CollectedPrices[priceDataMessage.RoundID]; !ok {
n.CollectedPrices[priceDataMessage.RoundID] = []int64{}
}

n.CollectedPrices[priceDataMessage.RoundID] = append(n.CollectedPrices[priceDataMessage.RoundID], priceDataMessage.PriceData)
if len(n.CollectedPrices[priceDataMessage.RoundID]) >= n.Raft.SubscribersCount()+1 {
log.Debug().Str("Player", "Aggregator").Str("Name", n.Name).Any("collected prices", n.CollectedPrices[priceDataMessage.RoundID]).Int32("roundId", priceDataMessage.RoundID).Msg("collected prices")
defer delete(n.CollectedPrices, priceDataMessage.RoundID)
defer delete(n.SyncedTimes, priceDataMessage.RoundID)
filteredCollectedPrices := FilterNegative(n.CollectedPrices[priceDataMessage.RoundID])
n.roundPrices.push(priceDataMessage.RoundID, priceDataMessage.PriceData)
if n.roundPrices.len(priceDataMessage.RoundID) >= n.Raft.SubscribersCount()+1 {
prices := n.roundPrices.snapshot(priceDataMessage.RoundID)
log.Debug().Str("Player", "Aggregator").Str("Name", n.Name).Any("collected prices", prices).Int32("roundId", priceDataMessage.RoundID).Msg("collected prices")
defer n.roundPrices.delete(priceDataMessage.RoundID)

filteredCollectedPrices := FilterNegative(prices)
if len(filteredCollectedPrices) == 0 {
log.Warn().Str("Player", "Aggregator").Str("Name", n.Name).Int32("roundId", priceDataMessage.RoundID).Msg("no prices collected")
return nil
Expand All @@ -235,19 +144,13 @@ func (n *Aggregator) HandlePriceDataMessage(ctx context.Context, msg raft.Messag
return err
}
log.Debug().Str("Player", "Aggregator").Str("Name", n.Name).Any("filtered collected prices", filteredCollectedPrices).Int32("roundId", priceDataMessage.RoundID).Int64("global_aggregate", median).Msg("global aggregated")
n.PreparedGlobalAggregates[priceDataMessage.RoundID] = GlobalAggregate{
ConfigID: n.ID,
Value: median,
Round: priceDataMessage.RoundID,
Timestamp: n.SyncedTimes[priceDataMessage.RoundID],
}

proof, err := n.Signer.MakeGlobalAggregateProof(median, n.SyncedTimes[priceDataMessage.RoundID], n.Name)
proof, err := n.Signer.MakeGlobalAggregateProof(median, priceDataMessage.Timestamp, n.Name)
if err != nil {
log.Error().Str("Player", "Aggregator").Err(err).Msg("failed to make global aggregate proof")
return err
}
return n.PublishProofMessage(priceDataMessage.RoundID, proof)
return n.PublishProofMessage(priceDataMessage.RoundID, median, proof, priceDataMessage.Timestamp)
}
return nil
}
Expand All @@ -265,22 +168,15 @@ func (n *Aggregator) HandleProofMessage(ctx context.Context, msg raft.Message) e
return errorSentinel.ErrAggregatorInvalidRaftMessage
}

n.AggregatorMutex.Lock()
defer n.AggregatorMutex.Unlock()
if _, ok := n.CollectedProofs[proofMessage.RoundID]; !ok {
n.CollectedProofs[proofMessage.RoundID] = [][]byte{}
}

n.CollectedProofs[proofMessage.RoundID] = append(n.CollectedProofs[proofMessage.RoundID], proofMessage.Proof)
if len(n.CollectedProofs[proofMessage.RoundID]) >= n.Raft.SubscribersCount()+1 {
defer delete(n.CollectedProofs, proofMessage.RoundID)
defer delete(n.PreparedGlobalAggregates, proofMessage.RoundID)
n.roundProofs.push(proofMessage.RoundID, proofMessage.Proof)
if n.roundProofs.len(proofMessage.RoundID) >= n.Raft.SubscribersCount()+1 {
defer n.roundProofs.delete(proofMessage.RoundID)
globalAggregate := GlobalAggregate{
ConfigID: n.ID,
Value: n.PreparedGlobalAggregates[proofMessage.RoundID].Value,
Value: proofMessage.Value,
Round: proofMessage.RoundID,
Timestamp: n.PreparedGlobalAggregates[proofMessage.RoundID].Timestamp}
concatProof := bytes.Join(n.CollectedProofs[proofMessage.RoundID], nil)
Timestamp: proofMessage.Timestamp}
concatProof := n.roundProofs.concat(proofMessage.RoundID)
proof := Proof{ConfigID: n.ID, Round: proofMessage.RoundID, Proof: concatProof}

err := PublishGlobalAggregateAndProof(ctx, globalAggregate, proof)
Expand All @@ -292,55 +188,13 @@ func (n *Aggregator) HandleProofMessage(ctx context.Context, msg raft.Message) e
return nil
}

func (n *Aggregator) PublishSyncMessage(roundId int32, timestamp time.Time) error {
roundMessage := RoundSyncMessage{
func (n *Aggregator) PublishTriggerMessage(roundId int32, timestamp time.Time) error {
triggerMessage := TriggerMessage{
LeaderID: n.Raft.GetHostId(),
RoundID: roundId,
Timestamp: timestamp,
}

marshalledRoundMessage, err := json.Marshal(roundMessage)
if err != nil {
log.Error().Str("Player", "Aggregator").Err(err).Msg("failed to marshal round message")
return err
}

message := raft.Message{
Type: RoundSync,
SentFrom: n.Raft.GetHostId(),
Data: json.RawMessage(marshalledRoundMessage),
}

return n.Raft.PublishMessage(message)
}

func (n *Aggregator) PublishSyncReplyMessage(roundId int32, agreed bool) error {
syncReplyMessage := SyncReplyMessage{
RoundID: roundId,
Agreed: agreed,
}

marshalledSyncReplyMessage, err := json.Marshal(syncReplyMessage)
if err != nil {
log.Error().Str("Player", "Aggregator").Err(err).Msg("failed to marshal sync reply message")
return err
}

message := raft.Message{
Type: SyncReply,
SentFrom: n.Raft.GetHostId(),
Data: json.RawMessage(marshalledSyncReplyMessage),
}

return n.Raft.PublishMessage(message)
}

func (n *Aggregator) PublishTriggerMessage(roundId int32) error {
triggerMessage := TriggerMessage{
LeaderID: n.Raft.GetHostId(),
RoundID: roundId,
}

marshalledTriggerMessage, err := json.Marshal(triggerMessage)
if err != nil {
log.Error().Str("Player", "Aggregator").Err(err).Msg("failed to marshal trigger message")
Expand All @@ -356,10 +210,11 @@ func (n *Aggregator) PublishTriggerMessage(roundId int32) error {
return n.Raft.PublishMessage(message)
}

func (n *Aggregator) PublishPriceDataMessage(roundId int32, value int64) error {
func (n *Aggregator) PublishPriceDataMessage(roundId int32, value int64, timestamp time.Time) error {
priceDataMessage := PriceDataMessage{
RoundID: roundId,
PriceData: value,
Timestamp: timestamp,
}

marshalledPriceDataMessage, err := json.Marshal(priceDataMessage)
Expand All @@ -377,10 +232,12 @@ func (n *Aggregator) PublishPriceDataMessage(roundId int32, value int64) error {
return n.Raft.PublishMessage(message)
}

func (n *Aggregator) PublishProofMessage(roundId int32, proof []byte) error {
func (n *Aggregator) PublishProofMessage(roundId int32, value int64, proof []byte, timestamp time.Time) error {
proofMessage := ProofMessage{
RoundID: roundId,
Proof: proof,
RoundID: roundId,
Value: value,
Proof: proof,
Timestamp: timestamp,
}

marshalledProofMessage, err := json.Marshal(proofMessage)
Expand All @@ -399,10 +256,6 @@ func (n *Aggregator) PublishProofMessage(roundId int32, proof []byte) error {
}

func (n *Aggregator) cleanUpRoundData(roundId int32) {
delete(n.CollectedPrices, roundId)
delete(n.CollectedProofs, roundId)
delete(n.CollectedAgreements, roundId)
delete(n.PreparedLocalAggregates, roundId)
delete(n.PreparedGlobalAggregates, roundId)
delete(n.SyncedTimes, roundId)
n.roundPrices.delete(roundId)
n.roundProofs.delete(roundId)
}
Loading

0 comments on commit c309fcb

Please sign in to comment.