Skip to content

Commit

Permalink
raft: improve mutex, concurrent msg read
Browse files Browse the repository at this point in the history
  • Loading branch information
nick-bisonai committed Aug 3, 2024
1 parent 730819a commit a59d240
Show file tree
Hide file tree
Showing 2 changed files with 69 additions and 113 deletions.
60 changes: 0 additions & 60 deletions node/pkg/raft/accessors.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,72 +6,12 @@ func (r *Raft) IncreaseTerm() {
r.Term++
}

func (r *Raft) UpdateTerm(newTerm int) {
r.Mutex.Lock()
defer r.Mutex.Unlock()
r.Term = newTerm
}

func (r *Raft) GetCurrentTerm() int {
r.Mutex.Lock()
defer r.Mutex.Unlock()
return r.Term
}

func (r *Raft) IncreaseVote() {
r.Mutex.Lock()
defer r.Mutex.Unlock()
r.VotesReceived++
}

func (r *Raft) UpdateVoteReceived(votes int) {
r.Mutex.Lock()
defer r.Mutex.Unlock()
r.VotesReceived = votes
}

func (r *Raft) GetVoteReceived() int {
r.Mutex.Lock()
defer r.Mutex.Unlock()
return r.VotesReceived
}

func (r *Raft) UpdateRole(role RoleType) {
r.Mutex.Lock()
defer r.Mutex.Unlock()
r.Role = role
}

func (r *Raft) GetRole() RoleType {
r.Mutex.Lock()
defer r.Mutex.Unlock()
return r.Role
}

func (r *Raft) GetVotedFor() string {
r.Mutex.Lock()
defer r.Mutex.Unlock()
return r.VotedFor
}

func (r *Raft) GetLeader() string {
r.Mutex.Lock()
defer r.Mutex.Unlock()
return r.LeaderID
}

func (r *Raft) UpdateLeader(leader string) {
r.Mutex.Lock()
defer r.Mutex.Unlock()
r.LeaderID = leader
}

func (r *Raft) UpdateVotedFor(votedFor string) {
r.Mutex.Lock()
defer r.Mutex.Unlock()
r.VotedFor = votedFor
}

func (r *Raft) SubscribersCount() int {
return len(r.Ps.ListPeers(r.Topic.String()))
}
Expand Down
122 changes: 69 additions & 53 deletions node/pkg/raft/raft.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,15 +47,10 @@ func NewRaftNode(
func (r *Raft) Run(ctx context.Context) {
go r.subscribe(ctx)
r.startElectionTimer()

for {
select {
case msg := <-r.MessageBuffer:
err := r.handleMessage(ctx, msg)
if err != nil {
log.Error().Err(err).Msg("failed to handle message")
}

go r.handleMessage(ctx, msg)
case <-r.ElectionTimer.C:
r.startElection()
case <-ctx.Done():
Expand Down Expand Up @@ -84,11 +79,13 @@ func (r *Raft) subscribe(ctx context.Context) {
log.Error().Err(err).Msg("failed to get message from topic")
continue
}

msg, err := r.unmarshalMessage(rawMsg.Data)
if err != nil {
log.Error().Err(err).Msg("failed to unmarshal message")
continue
}

r.MessageBuffer <- msg
}
}
Expand All @@ -113,7 +110,6 @@ func (r *Raft) handleHeartbeat(msg Message) error {
if msg.SentFrom == r.GetHostId() {
return nil
}

var heartbeatMessage HeartbeatMessage
err := json.Unmarshal(msg.Data, &heartbeatMessage)
if err != nil {
Expand All @@ -125,9 +121,12 @@ func (r *Raft) handleHeartbeat(msg Message) error {
return errorSentinel.ErrRaftLeaderIdMismatch
}

currentRole := r.GetRole()
currentTerm := r.GetCurrentTerm()
currentLeader := r.GetLeader()
r.Mutex.Lock()
defer r.Mutex.Unlock()

currentRole := r.Role
currentTerm := r.Term
currentLeader := r.LeaderID

if currentTerm > heartbeatMessage.Term && currentRole != Leader {
r.startElectionTimer()
Expand All @@ -141,21 +140,22 @@ func (r *Raft) handleHeartbeat(msg Message) error {
if currentRole == Leader {
r.ResignLeader()
} else if currentRole == Candidate {
r.UpdateRole(Follower)
r.Role = Follower
}

r.startElectionTimer()
r.UpdateTerm(heartbeatMessage.Term)
r.Term = heartbeatMessage.Term

if currentLeader != heartbeatMessage.LeaderID {
r.UpdateLeader(heartbeatMessage.LeaderID)
r.LeaderID = heartbeatMessage.LeaderID
}

return nil
}

func (r *Raft) handleRequestVote(msg Message) error {
if r.GetRole() == Leader {
r.Mutex.Lock()
defer r.Mutex.Unlock()
if r.Role == Leader {
return nil
}

Expand All @@ -166,32 +166,34 @@ func (r *Raft) handleRequestVote(msg Message) error {
return err
}

currentTerm := r.GetCurrentTerm()
currentTerm := r.Term

if RequestVoteMessage.Term > currentTerm {
r.UpdateTerm(RequestVoteMessage.Term)
r.Term = RequestVoteMessage.Term
}

if RequestVoteMessage.Term < currentTerm {
return r.sendReplyVote(msg.SentFrom, false)
}

if r.GetRole() == Candidate && RequestVoteMessage.Term == currentTerm && msg.SentFrom != r.GetHostId() {
r.UpdateRole(Follower)
if r.Role == Candidate && RequestVoteMessage.Term == currentTerm && msg.SentFrom != r.GetHostId() {
r.Role = Follower
return r.sendReplyVote(msg.SentFrom, false)
}

voteGranted := false
if r.GetVotedFor() == "" || r.GetVotedFor() == msg.SentFrom {
if r.VotedFor == "" || r.VotedFor == msg.SentFrom {
voteGranted = true
r.UpdateVotedFor(msg.SentFrom)
r.VotedFor = msg.SentFrom
}
log.Debug().Bool("vote granted", voteGranted).Msg("voted")
return r.sendReplyVote(msg.SentFrom, voteGranted)
}

func (r *Raft) handleReplyVote(ctx context.Context, msg Message) error {
if r.GetRole() != Candidate {
r.Mutex.Lock()
defer r.Mutex.Unlock()
if r.Role != Candidate {
return nil
}

Expand All @@ -205,11 +207,11 @@ func (r *Raft) handleReplyVote(ctx context.Context, msg Message) error {
return nil
}

if replyVoteMessage.VoteGranted && replyVoteMessage.LeaderID == r.GetHostId() && r.GetRole() == Candidate {
r.IncreaseVote()
log.Debug().Int("vote received", r.GetVoteReceived()).Msg("vote received")
if replyVoteMessage.VoteGranted && replyVoteMessage.LeaderID == r.GetHostId() && r.Role == Candidate {
r.VotesReceived++
log.Debug().Int("vote received", r.VotesReceived).Msg("vote received")
log.Debug().Int("subscribers count", r.SubscribersCount()).Msg("subscribers count")
if r.GetVoteReceived() >= (r.SubscribersCount()+1)/2 {
if r.VotesReceived >= (r.SubscribersCount()+1)/2 {
r.becomeLeader(ctx)
}
}
Expand All @@ -227,11 +229,12 @@ func (r *Raft) PublishMessage(msg Message) error {
}

func (r *Raft) sendHeartbeat() error {

r.Mutex.Lock()
heartbeatMessage := HeartbeatMessage{
LeaderID: r.GetHostId(),
Term: r.GetCurrentTerm(),
Term: r.Term,
}
r.Mutex.Unlock()
marshalledHeartbeatMsg, err := json.Marshal(heartbeatMessage)
if err != nil {
log.Error().Err(err).Msg("failed to marshal heartbeat message")
Expand Down Expand Up @@ -274,7 +277,7 @@ func (r *Raft) sendReplyVote(to string, voteGranted bool) error {

func (r *Raft) sendRequestVote() error {
requestVoteMessage := RequestVoteMessage{
Term: r.GetCurrentTerm(),
Term: r.Term,
}
marshalledRequestVoteMsg, err := json.Marshal(requestVoteMessage)
if err != nil {
Expand All @@ -299,32 +302,37 @@ func (r *Raft) ResignLeader() {
if r.Resign != nil {
close(r.Resign)
r.Resign = nil

r.UpdateRole(Follower)
r.UpdateLeader("")
r.Role = Follower
r.LeaderID = ""
r.startElectionTimer()
}
}

func (r *Raft) becomeLeader(ctx context.Context) {

log.Debug().Msg("becoming leader")

func (r *Raft) setLeaderState() {
r.Resign = make(chan interface{})
r.ElectionTimer.Stop()
r.UpdateRole(Leader)
r.UpdateLeader(r.GetHostId())
r.Role = Leader
r.LeaderID = r.GetHostId()
r.HeartbeatTicker = time.NewTicker(r.HeartbeatTimeout)
r.LeaderJobTicker = time.NewTicker(r.LeaderJobTimeout)
}

func (r *Raft) becomeLeader(ctx context.Context) {
r.setLeaderState()
go func() {
defer func() {
if r := recover(); r != nil {
log.Error().Msgf("recovered from panic in leader job: %v", r)
}
}()

for {
select {
case <-r.Resign:
log.Debug().Msg("resigning as leader")
r.Mutex.Lock()
r.HeartbeatTicker.Stop()
r.LeaderJobTicker.Stop()

r.Mutex.Unlock()
return

case <-r.HeartbeatTicker.C:
Expand All @@ -335,11 +343,6 @@ func (r *Raft) becomeLeader(ctx context.Context) {

case <-r.LeaderJobTicker.C:
go func() {
defer func() {
if r := recover(); r != nil {
log.Error().Msgf("recovered from panic in leader job: %v", r)
}
}()
err := r.LeaderJob()
if err != nil {
log.Error().Err(err).Msg("failed to execute leader job")
Expand All @@ -348,8 +351,10 @@ func (r *Raft) becomeLeader(ctx context.Context) {

case <-ctx.Done():
log.Debug().Msg("context cancelled")
r.Mutex.Lock()
r.HeartbeatTicker.Stop()
r.LeaderJobTicker.Stop()
r.Mutex.Unlock()
return
}
}
Expand All @@ -359,23 +364,34 @@ func (r *Raft) becomeLeader(ctx context.Context) {
func (r *Raft) getRandomElectionTimeout() time.Duration {
minTimeout := int(r.HeartbeatTimeout) * 3
maxTimeout := int(r.HeartbeatTimeout) * 6
return time.Duration(minTimeout + rand.Intn(maxTimeout-minTimeout))
duration := time.Duration(minTimeout + rand.Intn(maxTimeout-minTimeout))
return duration
}

func (r *Raft) startElectionTimer() {
if r.ElectionTimer != nil {
r.ElectionTimer.Stop()
if !r.ElectionTimer.Stop() {
select {
case <-r.ElectionTimer.C:
log.Debug().Msg("Old timer channel drained")
default:
log.Debug().Msg("Old timer channel already empty")
}
}
r.ElectionTimer.Reset(r.getRandomElectionTimeout())
} else {
r.ElectionTimer = time.NewTimer(r.getRandomElectionTimeout())
}
r.ElectionTimer = time.NewTimer(r.getRandomElectionTimeout())
}

func (r *Raft) startElection() {
r.IncreaseTerm()
r.UpdateVoteReceived(0)
log.Debug().Msg("start election")

r.UpdateRole(Candidate)
r.UpdateVotedFor(r.GetHostId())
r.Mutex.Lock()
defer r.Mutex.Unlock()
r.Term++
r.VotesReceived = 0
r.Role = Candidate
r.VotedFor = r.GetHostId()

r.startElectionTimer()

Expand Down

0 comments on commit a59d240

Please sign in to comment.