Skip to content

Commit

Permalink
feat: add phase, remove logs
Browse files Browse the repository at this point in the history
  • Loading branch information
nick-bisonai committed Mar 2, 2024
1 parent a24f85f commit 62c6b91
Show file tree
Hide file tree
Showing 6 changed files with 132 additions and 31 deletions.
9 changes: 0 additions & 9 deletions node/pkg/aggregator/aggregator.go
Original file line number Diff line number Diff line change
Expand Up @@ -72,20 +72,11 @@ func (a *App) startAggregator(ctx context.Context, aggregator *AggregatorNode) e
return errors.New("aggregator already running")
}

latestGlobalAggregate, err := db.QueryRow[globalAggregate](ctx, SelectLatestGlobalAggregateQuery, map[string]interface{}{"name": aggregator.Name})
if err != nil {
log.Error().Err(err).Msg("failed to get latest global aggregate")
return err
}
nodeCtx, cancel := context.WithCancel(ctx)
aggregator.nodeCtx = nodeCtx
aggregator.nodeCancel = cancel
aggregator.isRunning = true

if latestGlobalAggregate.Round > 0 {
aggregator.RoundID = latestGlobalAggregate.Round
}

aggregator.Run(ctx)
return nil
}
Expand Down
110 changes: 99 additions & 11 deletions node/pkg/aggregator/node.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package aggregator
import (
"context"
"encoding/json"
"os"
"sync"

"time"
Expand Down Expand Up @@ -41,6 +42,7 @@ func NewNode(h host.Host, ps *pubsub.PubSub, topicString string) (*AggregatorNod
}

func (n *AggregatorNode) Run(ctx context.Context) {
n.loadLatestRoundId(ctx)

Check failure on line 45 in node/pkg/aggregator/node.go

View workflow job for this annotation

GitHub Actions / core-build

Error return value of `n.loadLatestRoundId` is not checked (errcheck)
n.Raft.Run(ctx, n)
}

Expand All @@ -64,6 +66,7 @@ func (n *AggregatorNode) SetLeaderJobTicker(d *time.Duration) error {
func (n *AggregatorNode) LeaderJob() error {
// leader continously sends roundId in regular basis and triggers all other nodes to run its job
n.RoundID++
n.Raft.Term++
roundMessage := RoundSyncMessage{
LeaderID: n.Raft.Host.ID().String(),
RoundID: n.RoundID,
Expand All @@ -87,30 +90,95 @@ func (n *AggregatorNode) HandleCustomMessage(message raft.Message) error {
switch message.Type {
case RoundSync:
return n.HandleRoundSyncMessage(message)
// every node runs its job when leader sends roundSync message
case PriceData:
return n.HandlePriceDataMessage(message)
case RoundReply:
return n.HandleRoundReplyMessage(message)
case TriggerAggregate:
return n.HandleTriggerAggregateMessage(message)
}
return nil
}

/*
TODO: adding another phase to agree on roundId
1. leader sends roundSync message
2. followers check if the leader's roundId is greater than its own roundId
3. if it is, follower will send signal to leader to update roundId
*/
func (n *AggregatorNode) HandleRoundSyncMessage(msg raft.Message) error {
var roundSyncMessage RoundSyncMessage
err := json.Unmarshal(msg.Data, &roundSyncMessage)
if err != nil {
return err
}
n.RoundID = roundSyncMessage.RoundID

if n.RoundID < roundSyncMessage.RoundID {
n.RoundID = roundSyncMessage.RoundID
}

roundReplyMessage := RoundReplyMessage{
RoundId: n.RoundID,
}

marshalledRoundReplyMessage, err := json.Marshal(roundReplyMessage)
if err != nil {
return err
}

message := raft.Message{
Type: RoundReply,
SentFrom: n.Raft.Host.ID().String(),
Data: json.RawMessage(marshalledRoundReplyMessage),
}

return n.Raft.PublishMessage(message)

}

func (n *AggregatorNode) HandleRoundReplyMessage(msg raft.Message) error {
if n.Raft.GetRole() != raft.Leader {
return nil
}

n.RoundSyncReplies++
var roundReplyMessage RoundReplyMessage
err := json.Unmarshal(msg.Data, &roundReplyMessage)
if err != nil {
return err
}

if roundReplyMessage.RoundId > n.RoundID {
n.RoundID = roundReplyMessage.RoundId
}

if n.RoundSyncReplies > n.Raft.SubscribersCount() {
triggerAggregateMessage := TriggerAggregateMessage{
RoundID: n.RoundID,
}

marshalledTriggerAggregateMessage, err := json.Marshal(triggerAggregateMessage)
if err != nil {
return err
}

message := raft.Message{
Type: TriggerAggregate,
SentFrom: n.Raft.Host.ID().String(),
Data: json.RawMessage(marshalledTriggerAggregateMessage),
}
n.RoundSyncReplies = 0
return n.Raft.PublishMessage(message)
}
return nil
}

func (n *AggregatorNode) HandleTriggerAggregateMessage(msg raft.Message) error {
var triggerAggregateMessage TriggerAggregateMessage
err := json.Unmarshal(msg.Data, &triggerAggregateMessage)
if err != nil {
return err
}

if triggerAggregateMessage.RoundID != n.RoundID {
n.RoundID = triggerAggregateMessage.RoundID
}

// pull latest local aggregate and send to peers
// latestAggregate := utils.RandomNumberGenerator()
var updateValue int64
value, updateTime, err := n.getLatestLocalAggregate(n.nodeCtx)
if err != nil {
Expand Down Expand Up @@ -157,7 +225,7 @@ func (n *AggregatorNode) HandlePriceDataMessage(msg raft.Message) error {
}

n.CollectedPrices[priceDataMessage.RoundID] = append(n.CollectedPrices[priceDataMessage.RoundID], priceDataMessage.PriceData)
if len(n.CollectedPrices[priceDataMessage.RoundID]) >= len(n.Raft.Ps.ListPeers(n.Raft.Topic.String()))+1 {
if len(n.CollectedPrices[priceDataMessage.RoundID]) >= n.Raft.SubscribersCount()+1 {
filteredCollectedPrices := FilterNegative(n.CollectedPrices[priceDataMessage.RoundID])

// handle aggregation here once all the data have been collected
Expand All @@ -174,6 +242,9 @@ func (n *AggregatorNode) HandlePriceDataMessage(msg raft.Message) error {
}

func (n *AggregatorNode) getLatestLocalAggregate(ctx context.Context) (int64, time.Time, error) {
if os.Getenv("TEST") == "true" {
return int64(utils.RandomNumberGenerator()), time.Now(), nil
}
redisAggregate, err := GetLatestLocalAggregateFromRdb(ctx, n.Name)
if err != nil {
pgsqlAggregate, err := GetLatestLocalAggregateFromPgs(ctx, n.Name)
Expand All @@ -186,13 +257,30 @@ func (n *AggregatorNode) getLatestLocalAggregate(ctx context.Context) (int64, ti
}

func (n *AggregatorNode) insertGlobalAggregate(ctx context.Context, name string, value int64, round int64) error {
if os.Getenv("TEST") == "true" {
return nil
}
_, err := db.QueryRow[globalAggregate](ctx, InsertGlobalAggregateQuery, map[string]any{"name": name, "value": value, "round": round})
if err != nil {
return err
}
return nil
}

func (n *AggregatorNode) loadLatestRoundId(ctx context.Context) error {
if os.Getenv("TEST") == "true" {
n.RoundID = 0
return nil
}
latestGlobalAggregate, err := db.QueryRow[globalAggregate](ctx, SelectLatestGlobalAggregateQuery, map[string]any{"name": n.Name})
if err != nil {
return err
}
n.RoundID = latestGlobalAggregate.Round
return nil

}

func (n *AggregatorNode) executeDeviation() error {
// signals for deviation job which triggers immediate aggregation and sends submission request to submitter
return nil
Expand Down
24 changes: 18 additions & 6 deletions node/pkg/aggregator/types.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,15 @@ import (
)

const (
RoundSync raft.MessageType = "roundSync"
PriceData raft.MessageType = "priceData"
SelectActiveAggregatorsQuery = `SELECT * FROM aggregators WHERE active = true`
SelectLatestLocalAggregateQuery = `SELECT * FROM local_aggregates WHERE name = @name ORDER BY timestamp DESC LIMIT 1`
InsertGlobalAggregateQuery = `INSERT INTO global_aggregates (name, value, round) VALUES (@name, @value, @round) RETURNING *`
SelectLatestGlobalAggregateQuery = `SELECT * FROM global_aggregates WHERE name = @name ORDER BY round DESC LIMIT 1`
RoundSync raft.MessageType = "roundSync"
RoundReply raft.MessageType = "roundReply"
TriggerAggregate raft.MessageType = "triggerAggregate"
PriceData raft.MessageType = "priceData"

SelectActiveAggregatorsQuery = `SELECT * FROM aggregators WHERE active = true`
SelectLatestLocalAggregateQuery = `SELECT * FROM local_aggregates WHERE name = @name ORDER BY timestamp DESC LIMIT 1`
InsertGlobalAggregateQuery = `INSERT INTO global_aggregates (name, value, round) VALUES (@name, @value, @round) RETURNING *`
SelectLatestGlobalAggregateQuery = `SELECT * FROM global_aggregates WHERE name = @name ORDER BY round DESC LIMIT 1`
)

type redisLocalAggregate struct {
Expand Down Expand Up @@ -59,6 +62,7 @@ type AggregatorNode struct {

LastLocalAggregateTime time.Time
RoundID int64
RoundSyncReplies int

nodeCtx context.Context
nodeCancel context.CancelFunc
Expand All @@ -70,6 +74,14 @@ type RoundSyncMessage struct {
RoundID int64 `json:"roundID"`
}

type RoundReplyMessage struct {
RoundId int64 `json:"roundId"`
}

type TriggerAggregateMessage struct {
RoundID int64 `json:"roundID"`
}

type PriceDataMessage struct {
RoundID int64 `json:"roundID"`
PriceData int64 `json:"priceData"`
Expand Down
6 changes: 3 additions & 3 deletions node/pkg/libp2p/libp2p.go
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ func DiscoverPeers(ctx context.Context, h host.Host, topicName string, bootstrap
anyConnected := false
var wg sync.WaitGroup
for !anyConnected {
log.Debug().Msg("Searching for peers...")
// log.Debug().Msg("Searching for peers...")
peerChan, err := routingDiscovery.FindPeers(ctx, topicName)
if err != nil {
return err
Expand All @@ -132,9 +132,9 @@ func DiscoverPeers(ctx context.Context, h host.Host, topicName string, bootstrap
defer wg.Done()
err := h.Connect(ctx, p)
if err != nil {
log.Trace().Msg("Failed connecting to " + p.ID.String())
// log.Trace().Msg("Failed connecting to " + p.ID.String())
} else {
log.Trace().Str("connectedTo", p.ID.String()).Msg("Connected to peer")
// log.Trace().Str("connectedTo", p.ID.String()).Msg("Connected to peer")
anyConnected = true
}
}(p)
Expand Down
8 changes: 6 additions & 2 deletions node/pkg/raft/raft.go
Original file line number Diff line number Diff line change
Expand Up @@ -112,16 +112,19 @@ func (r *Raft) handleHeartbeat(node Node, msg Message) error {
return fmt.Errorf("leader id mismatch")
}

r.StopHeartbeatTicker(node)
r.startElectionTimer()

currentRole := r.GetRole()
currentTerm := r.GetCurrentTerm()
currentLeader := r.GetLeader()

// log.Debug().Str("current role", string(currentRole)).Str("current leader", currentLeader).Int("current term", currentTerm).Msg("received heartbeat")

// If the current role is Candidate or the current role is Leader and the current term is less than the heartbeat term, update the role to Follower
shouldUpdateRoleToFollower := (currentRole == Candidate) || (currentRole == Leader && currentTerm < heartbeatMessage.Term)
shouldUpdateRoleToFollower := (currentRole == Candidate) || (currentRole == Leader && currentTerm <= heartbeatMessage.Term)
if shouldUpdateRoleToFollower {
log.Debug().Msg("updating role to follower\n")
r.StopHeartbeatTicker(node)
r.UpdateRole(Follower)
}

Expand Down Expand Up @@ -303,6 +306,7 @@ func (r *Raft) becomeLeader(ctx context.Context, node Node) {
r.Resign = make(chan interface{})
r.ElectionTimer.Stop()
r.UpdateRole(Leader)
r.UpdateLeader(r.GetHostId())
r.HeartbeatTicker = time.NewTicker(r.HeartbeatTimeout)

var leaderJobTicker <-chan time.Time
Expand Down
6 changes: 6 additions & 0 deletions node/pkg/utils/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,9 @@ func RandomNumberGenerator() int {
}

func FindMedian(nums []int) int {
if len(nums) == 0 {
return 0
}
sort.Ints(nums)
n := len(nums)
if n%2 == 0 {
Expand All @@ -25,6 +28,9 @@ func FindMedian(nums []int) int {
}

func FindMedianInt64(nums []int64) int64 {
if len(nums) == 0 {
return 0
}
sort.Slice(nums, func(i, j int) bool { return nums[i] < nums[j] })
n := len(nums)
if n%2 == 0 {
Expand Down

0 comments on commit 62c6b91

Please sign in to comment.