diff --git a/raft/consensus_module.nim b/raft/consensus_module.nim index 21c93fb..fd188ab 100644 --- a/raft/consensus_module.nim +++ b/raft/consensus_module.nim @@ -72,11 +72,22 @@ proc raftNodeAbortElection*[SmCommandType, SmStateType](node: RaftNode[SmCommand proc raftNodeStartElection*[SmCommandType, SmStateType](node: RaftNode[SmCommandType, SmStateType]) {.async.} = mixin raftNodeScheduleElectionTimeout, raftTimerCreate - raftNodeScheduleElectionTimeout(node) withRLock(node.raftStateMutex): + if node.state == rnsLeader and node.hrtBtSuccess: + raftNodeScheduleElectionTimeout(node) + return + + if node.state == rnsLeader and not node.hrtBtSuccess: + node.state = rnsFollower + node.currentLeaderId = DefaultUUID + node.votedFor = DefaultUUID + raftNodeScheduleElectionTimeout(node) + return + while node.votesFuts.len > 0: discard node.votesFuts.pop + node.currentTerm.inc node.state = rnsCandidate node.votedFor = node.id @@ -94,7 +105,7 @@ proc raftNodeStartElection*[SmCommandType, SmStateType](node: RaftNode[SmCommand # Wait for votes or voting timeout let all = allFutures(node.votesFuts) - await all or raftTimerCreate(node.votingTimeout, proc()=discard) + await all or raftTimerCreate(node.votingRespTimeout, proc()=discard) if not all.finished: debug "Raft Node Voting timeout", node_id=node.id @@ -114,6 +125,7 @@ proc raftNodeStartElection*[SmCommandType, SmStateType](node: RaftNode[SmCommand await cancelAndWait(node.electionTimeoutTimer) debug "Raft Node transition to leader", node_id=node.id node.state = rnsLeader # Transition to leader state and send Heart-Beat to establish this node as the cluster leader + raftNodeScheduleElectionTimeout(node) asyncSpawn raftNodeSendHeartBeat(node) proc raftNodeHandleAppendEntries*[SmCommandType, SmStateType](node: RaftNode[SmCommandType, SmStateType], msg: RaftMessage[SmCommandType, SmStateType]): @@ -177,7 +189,7 @@ proc raftNodeReplicateSmCommand*[SmCommandType, SmStateType](node: RaftNode[SmCo node.replicateFuts.add(node.msgSendCallback(msg)) let allReplicateFuts = allFutures(node.replicateFuts) - await allReplicateFuts or raftTimerCreate(node.appendEntriesTimeout, proc()=discard) + await allReplicateFuts or raftTimerCreate(node.appendEntriesRespTimeout, proc()=discard) if not allReplicateFuts.finished: debug "Raft Node Replication timeout", node_id=node.id diff --git a/raft/raft_api.nim b/raft/raft_api.nim index 7737e2b..c8675c9 100644 --- a/raft/raft_api.nim +++ b/raft/raft_api.nim @@ -32,8 +32,9 @@ proc new*[SmCommandType, SmStateType](T: type RaftNode[SmCommandType, SmStateTyp msgSendCallback: RaftMessageSendCallback; electionTimeout: int=150; heartBeatTimeout: int=150; - appendEntriesTimeout: int=30; - votingTimeout: int=20 + appendEntriesRespTimeout: int=20; + votingRespTimeout: int=20; + heartBeatRespTimeout: int=10 ): T = var peers: RaftNodePeers @@ -44,8 +45,8 @@ proc new*[SmCommandType, SmStateType](T: type RaftNode[SmCommandType, SmStateTyp result = T( id: id, state: rnsFollower, currentTerm: 0, peers: peers, commitIndex: 0, lastApplied: 0, msgSendCallback: msgSendCallback, votedFor: DefaultUUID, currentLeaderId: DefaultUUID, - electionTimeout: electionTimeout, heartBeatTimeout: heartBeatTimeout, appendEntriesTimeout: appendEntriesTimeout, - votingTimeout: votingTimeout + electionTimeout: electionTimeout, heartBeatTimeout: heartBeatTimeout, appendEntriesRespTimeout: appendEntriesRespTimeout, + heartBeatRespTimeout: heartBeatRespTimeout, votingRespTimeout: votingRespTimeout, hrtBtSuccess: false ) raftNodeSmInit(result.stateMachine) @@ -146,13 +147,27 @@ proc raftNodeSendHeartBeat*[SmCommandType, SmStateType](node: RaftNode[SmCommand debug "Raft Node sending Heart-Beat to peers", node_id=node.id withRLock(node.raftStateMutex): + var hrtBtFuts: seq[Future[RaftMessageResponseBase[SmCommandType, SmStateType]]] + for raftPeer in node.peers: let msgHrtBt = RaftMessage[SmCommandType, SmStateType]( op: rmoAppendLogEntry, senderId: node.id, receiverId: raftPeer.id, senderTerm: raftNodeTermGet(node), commitIndex: node.commitIndex, prevLogIndex: raftNodeLogIndexGet(node) - 1, prevLogTerm: if raftNodeLogIndexGet(node) > 0: raftNodeLogEntryGet(node, raftNodeLogIndexGet(node) - 1).term else: 0 ) - discard node.msgSendCallback(msgHrtBt) + hrtBtFuts.add(node.msgSendCallback(msgHrtBt)) + let allHrtBtFuts = allFutures(hrtBtFuts) + await allHrtBtFuts or raftTimerCreate(node.heartBeatRespTimeout, proc()=discard) + + var successCnt = 0 + for fut in hrtBtFuts: + if fut.finished: + let resp = RaftMessageResponse[SmCommandType, SmStateType](fut.read) + if resp.success: + successCnt.inc + + if successCnt >= (node.peers.len div 2 + node.peers.len mod 2): + node.hrtBtSuccess = true raftNodeScheduleHeartBeat(node) @@ -161,6 +176,7 @@ proc raftNodeScheduleElectionTimeout*[SmCommandType, SmStateType](node: RaftNode node.electionTimeoutTimer = raftTimerCreate(node.electionTimeout + rand(node.electionTimeout), proc = asyncSpawn raftNodeStartElection(node) ) + node.hrtBtSuccess = false # Raft Node Control proc raftNodeCancelTimers*[SmCommandType, SmStateType](node: RaftNode[SmCommandType, SmStateType]) = diff --git a/raft/types.nim b/raft/types.nim index 420bee4..295ca2c 100644 --- a/raft/types.nim +++ b/raft/types.nim @@ -127,8 +127,9 @@ type electionTimeout*: int heartBeatTimeout*: int - appendEntriesTimeout*: int - votingTimeout*: int + appendEntriesRespTimeout*: int + votingRespTimeout*: int + heartBeatRespTimeout*: int heartBeatTimer*: Future[void] electionTimeoutTimer*: Future[void] @@ -139,6 +140,7 @@ type # Misc msgSendCallback*: RaftMessageSendCallback[SmCommandType, SmStateType] persistentStorage: RaftNodePersistentStorage[SmCommandType, SmStateType] + hrtBtSuccess*: bool # Persistent state id*: RaftNodeId # This Raft Node ID diff --git a/tests/basic_cluster.nim b/tests/basic_cluster.nim index 15d15c8..23da196 100644 --- a/tests/basic_cluster.nim +++ b/tests/basic_cluster.nim @@ -46,7 +46,8 @@ proc basicRaftClusterClientRequest*(cluster: BasicRaftCluster, req: RaftNodeClie of rncroExecSmCommand: discard -proc basicRaftClusterInit*(nodesIds: seq[RaftNodeId], electionTimeout=150, heartBeatTimeout=150, appendEntriesTimeout=20, votingTimeout=20): BasicRaftCluster = +proc basicRaftClusterInit*(nodesIds: seq[RaftNodeId], electionTimeout: int=150, heartBeatTimeout: int=150, appendEntriesRespTimeout: int=20, votingRespTimeout: int=20, + heartBeatRespTimeout: int=10): BasicRaftCluster = new(result) for nodeId in nodesIds: var @@ -55,5 +56,5 @@ proc basicRaftClusterInit*(nodesIds: seq[RaftNodeId], electionTimeout=150, heart peersIds.del(peersIds.find(nodeId)) result.nodes[nodeId] = BasicRaftNode.new(nodeId, peersIds, basicRaftClusterRaftMessageSendCallbackCreate[SmCommand, SmState](result), - electionTimeout, heartBeatTimeout, appendEntriesTimeout, votingTimeout) + electionTimeout, heartBeatTimeout, appendEntriesRespTimeout, votingRespTimeout, heartBeatRespTimeout) diff --git a/tests/test_basic_cluster_election.nim b/tests/test_basic_cluster_election.nim index e5c7b01..7d858fd 100644 --- a/tests/test_basic_cluster_election.nim +++ b/tests/test_basic_cluster_election.nim @@ -20,7 +20,7 @@ proc basicClusterElectionMain*() = test "Basic Raft Cluster Init (5 nodes)": for i in 0..4: nodesIds[i] = genUUID() - cluster = basicRaftClusterInit(nodesIds, 150, 150, 20, 20) + cluster = basicRaftClusterInit(nodesIds, 150, 150, 20, 20, 10) check cluster != nil test "Start Basic Raft Cluster and wait it to converge for a 2 seconds interval (Elect a Leader)":