diff --git a/src/raft/consensus_state_machine.nim b/src/raft/consensus_state_machine.nim index e80ee36..0bdf250 100644 --- a/src/raft/consensus_state_machine.nim +++ b/src/raft/consensus_state_machine.nim @@ -15,9 +15,6 @@ import state import std/[times] import std/random - -randomize() - type RaftRpcMessageType* = enum VoteRequest = 0, @@ -30,11 +27,13 @@ type Accepted = 1 DebugLogLevel* = enum - Critical = 0, - Error = 1, - Warning = 2, - Debug = 3, - Info = 4, + None = 0 + Critical = 1, + Error = 2, + Warning = 3, + Debug = 4, + Info = 5, + All = 6, DebugLogEntry* = object level*: DebugLogLevel @@ -113,6 +112,7 @@ type timeNow: times.DateTime startTime: times.DateTime electionTimeout: times.Duration + randomGenerator: Rand state*: RaftStateMachineState @@ -126,7 +126,7 @@ func candidate*(sm: var RaftStateMachine): var CandidateState = return sm.state.candidate func addDebugLogEntry(sm: var RaftStateMachine, level: DebugLogLevel, msg: string) = - sm.output.debugLogs.add(DebugLogEntry(time: sm.timeNow, level: level, msg: msg, nodeId: sm.myId)) + sm.output.debugLogs.add(DebugLogEntry(time: sm.timeNow, state: sm.state.state, level: level, msg: msg, nodeId: sm.myId)) func debug*(sm: var RaftStateMachine, log: string) = sm.addDebugLogEntry(DebugLogLevel.Debug, log) @@ -143,13 +143,11 @@ func info*(sm: var RaftStateMachine, log: string) = func critical*(sm: var RaftStateMachine, log: string) = sm.addDebugLogEntry(DebugLogLevel.Critical, log) - - -proc resetElectionTimeout*(sm: var RaftStateMachine) = +func resetElectionTimeout*(sm: var RaftStateMachine) = # TODO actually pick random time - sm.randomizedElectionTime = sm.electionTimeout + times.initDuration(milliseconds = 100 + rand(200)) + sm.randomizedElectionTime = sm.electionTimeout + times.initDuration(milliseconds = 100 + sm.randomGenerator.rand(200)) -proc initRaftStateMachine*(id: RaftnodeId, currentTerm: RaftNodeTerm, log: RaftLog, commitIndex: RaftLogIndex, config: RaftConfig, now: times.DateTime): RaftStateMachine = +func initRaftStateMachine*(id: RaftnodeId, currentTerm: RaftNodeTerm, log: RaftLog, commitIndex: RaftLogIndex, config: RaftConfig, now: times.DateTime, randomGenerator: Rand): RaftStateMachine = var sm = RaftStateMachine() sm.term = currentTerm sm.log = log @@ -162,6 +160,7 @@ proc initRaftStateMachine*(id: RaftnodeId, currentTerm: RaftNodeTerm, log: RaftL sm.myId = id sm.electionTimeout = times.initDuration(milliseconds = 100) sm.heartbeatTime = times.initDuration(milliseconds = 50) + sm.randomGenerator = randomGenerator sm.resetElectionTimeout() return sm @@ -280,8 +279,8 @@ func becomeCandidate*(sm: var RaftStateMachine) = let request = RaftRpcVoteRequest(currentTerm: sm.term, lastLogIndex: sm.log.lastIndex, lastLogTerm: sm.log.lastTerm, force: false) sm.sendTo(nodeId, request) - sm.debug "Elecation won" & $(sm.candidate.votes) & $sm.myId if sm.candidate.votes.tallyVote == RaftElectionResult.Won: + sm.debug "Elecation won" & $(sm.candidate.votes) & $sm.myId sm.becomeLeader() return @@ -422,7 +421,7 @@ func requestVoteReply*(sm: var RaftStateMachine, fromId: RaftNodeId, request: Ra of RaftElectionResult.Unknown: return of RaftElectionResult.Won: - sm.debug "Win election" + sm.debug "Elecation won" & $(sm.candidate.votes) & $sm.myId sm.becomeLeader() of RaftElectionResult.Lost: sm.debug "Lost election" diff --git a/src/raft/state.nim b/src/raft/state.nim index 125f0c6..af8c16c 100644 --- a/src/raft/state.nim +++ b/src/raft/state.nim @@ -10,7 +10,7 @@ type rnsLeader = 2 # Leader state RaftStateMachineState* = object - case state: RaftNodeState + case state*: RaftNodeState of rnsFollower: follower: FollowerState of rnsCandidate: candidate: CandidateState of rnsLeader: leader: LeaderState diff --git a/tests/test_consensus_state_machine.nim b/tests/test_consensus_state_machine.nim index 0a7739b..4b9b9f4 100644 --- a/tests/test_consensus_state_machine.nim +++ b/tests/test_consensus_state_machine.nim @@ -14,7 +14,7 @@ import ../src/raft/log import ../src/raft/tracker import ../src/raft/state import std/sets -import std/[times, sequtils] +import std/[times, sequtils, random] import uuids import tables import std/algorithm @@ -50,7 +50,7 @@ proc createCluster(ids: seq[RaftnodeId], now: times.DateTime) : TestCluster = for i in 0..