diff --git a/raft/consensus_module.nim b/raft/consensus_module.nim index 2e56a1b..bbe4009 100644 --- a/raft/consensus_module.nim +++ b/raft/consensus_module.nim @@ -22,31 +22,38 @@ proc raftNodeQuorumMin[SmCommandType, SmStateType](node: RaftNode[SmCommandType, if cnt >= (node.peers.len div 2 + node.peers.len mod 2): result = true +proc raftNodeCheckCommitIndex*[SmCommandType, SmStateType](node: RaftNode[SmCommandType, SmStateType], msg: RaftMessage[SmCommandType, SmStateType]) = + withRLock(node.raftStateMutex): + if msg.commitIndex > node.commitIndex: + let newcommitIndex = min(msg.commitIndex, raftNodeLogIndexGet(node)) + + while node.commitIndex < newcommitIndex: + node.commitIndex.inc + raftNodeApplyLogEntry(node, raftNodeLogEntryGet(node, node.commitIndex)) + proc raftNodeHandleHeartBeat*[SmCommandType, SmStateType](node: RaftNode[SmCommandType, SmStateType], msg: RaftMessage[SmCommandType, SmStateType]): RaftMessageResponse[SmCommandType, SmStateType] = debug "Received heart-beat", node_id=node.id, sender_id=msg.sender_id, node_current_term=node.currentTerm, sender_term=msg.senderTerm result = RaftMessageResponse[SmCommandType, SmStateType](op: rmoAppendLogEntry, senderId: node.id, receiverId: msg.senderId, msgId: msg.msgId, success: false) withRLock(node.raftStateMutex): - if node.state == rnsStopped: - return - if msg.senderTerm >= node.currentTerm: raftNodeCancelTimers(node) if node.state == rnsCandidate: raftNodeAbortElection(node) + result.success = true node.currentTerm = msg.senderTerm node.votedFor = DefaultUUID node.currentLeaderId = msg.senderId + + raftNodeCheckCommitIndex(node, msg) + raftNodeScheduleElectionTimeout(node) proc raftNodeHandleRequestVote*[SmCommandType, SmStateType](node: RaftNode[SmCommandType, SmStateType], msg: RaftMessage[SmCommandType, SmStateType]): RaftMessageResponse[SmCommandType, SmStateType] = result = RaftMessageResponse[SmCommandType, SmStateType](op: rmoRequestVote, msgId: msg.msgId, senderId: node.id, receiverId: msg.senderId, granted: false) withRLock(node.raftStateMutex): - if node.state == rnsStopped: - return - if msg.senderTerm > node.currentTerm and node.votedFor == DefaultUUID: if msg.lastLogTerm > raftNodeLogEntryGet(node, raftNodeLogIndexGet(node)).term or (msg.lastLogTerm == raftNodeLogEntryGet(node, raftNodeLogIndexGet(node)).term and msg.lastLogIndex >= raftNodeLogIndexGet(node)): @@ -85,7 +92,6 @@ proc raftNodeStartElection*[SmCommandType, SmStateType](node: RaftNode[SmCommand ) ) - withRLock(node.raftStateMutex): # Wait for votes or voting timeout let all = allFutures(node.votesFuts) await all or raftTimerCreate(node.votingTimeout, proc()=discard) @@ -114,9 +120,6 @@ proc raftNodeHandleAppendEntries*[SmCommandType, SmStateType](node: RaftNode[SmC RaftMessageResponse[SmCommandType, SmStateType] = result = RaftMessageResponse[SmCommandType, SmStateType](op: rmoAppendLogEntry, senderId: node.id, receiverId: msg.senderId, msgId: msg.msgId, success: false) withRLock(node.raftStateMutex): - if node.state == rnsStopped: - return - if msg.senderTerm >= node.currentTerm: raftNodeCancelTimers(node) if node.state == rnsCandidate: @@ -146,33 +149,48 @@ proc raftNodeHandleAppendEntries*[SmCommandType, SmStateType](node: RaftNode[SmC for entry in msg.logEntries.get: raftNodeLogAppend(node, entry) - if msg.commitIndex > node.commitIndex: - node.commitIndex = min(msg.commitIndex, raftNodeLogIndexGet(node)) - raftNodeApplyLogEntry(node, raftNodeLogEntryGet(node, node.commitIndex)) + raftNodeCheckCommitIndex(node, msg) result.success = true -proc raftNodeReplicateSmCommand*[SmCommandType, SmStateType](node: RaftNode[SmCommandType, SmStateType], cmd: SmCommandType) = +proc raftNodeReplicateSmCommand*[SmCommandType, SmStateType](node: RaftNode[SmCommandType, SmStateType], cmd: SmCommandType): Future[bool] {.async.} = mixin RaftLogEntry, raftTimerCreate + result = false + withRLock(node.raftStateMutex): var logEntry: RaftLogEntry[SmCommandType](term: node.currentTerm, data: cmd, entryType: etData) - raftNodeLogAppend(node, logEntry) for peer in node.peers: var msg: RaftMessage[SmCommandType, SmStateType] = RaftMessage[SmCommandType, SmStateType]( op: rmoAppendLogEntry, msgId: genUUID(), senderId: node.id, receiverId: peer.id, - senderTerm: node.currentTerm, prevLogIndex: raftNodeLogIndexGet(node), - prevLogTerm: raftNodeLogEntryGet(node, raftNodeLogIndexGet(node)).term, + senderTerm: node.currentTerm, prevLogIndex: raftNodeLogIndexGet(node) - 1, + prevLogTerm: raftNodeLogEntryGet(node, raftNodeLogIndexGet(node) - 1).term, commitIndex: node.commitIndex, entries: @[logEntry] ) node.replicateFuts.add(node.msgSendCallback(msg)) - node.commitIndex.inc - raftNodeApplyLogEntry(node, raftNodeLogEntryGet(node, node.commitIndex)) # Apply to state machine - + let allReplicateFuts = allFutures(node.replicateFuts) + await allReplicateFuts or raftTimerCreate(node.appendEntriesTimeout, proc()=discard) + if not allReplicateFuts.finished: + debug "Raft Node Replication timeout", node_id=node.id + + var replicateCnt = 0 + for fut in node.replicateFuts: + if fut.finished and not fut.cancelled: + let resp = RaftMessageResponse[SmCommandType, SmStateType](fut.read) + if resp.success: + replicateCnt.inc + info "Raft Node Replication success", node_id=node.id, sender_id=resp.senderId + else: + info "Raft Node Replication failed", node_id=node.id, sender_id=resp.senderId + + if replicateCnt >= (node.peers.len div 2 + node.peers.len mod 2): + node.commitIndex = raftNodeLogIndexGet(node) + raftNodeApplyLogEntry(node, raftNodeLogEntryGet(node, node.commitIndex)) # Apply to state machine + result = true \ No newline at end of file diff --git a/raft/log_ops.nim b/raft/log_ops.nim index 6dfc6e0..32cc2d9 100644 --- a/raft/log_ops.nim +++ b/raft/log_ops.nim @@ -16,7 +16,7 @@ proc raftNodeLogIndexGet*[SmCommandType, SmStateType](node: RaftNode[SmCommandTy proc raftNodeLogEntryGet*[SmCommandType, SmStateType](node: RaftNode[SmCommandType, SmStateType], logIndex: RaftLogIndex): RaftNodeLogEntry[SmCommandType] = if logIndex > 0: - result = node.log.logData[logIndex] + result = node.log.logData[logIndex - 1] proc raftNodeLogAppend*[SmCommandType, SmStateType](node: RaftNode[SmCommandType, SmStateType], logEntry: RaftNodeLogEntry[SmCommandType]) = node.log.logData.add(logEntry) diff --git a/raft/protocol.nim b/raft/protocol.nim index 6d0c438..a37f1a5 100644 --- a/raft/protocol.nim +++ b/raft/protocol.nim @@ -46,7 +46,8 @@ type RaftNodeClientResponseError* = enum rncreSuccess = 0, rncreFail = 1, - rncreNotLeader = 2 + rncreNotLeader = 2, + rncreStopped = 3 RaftNodeClientRequest*[SmCommandType] = ref object op*: RaftNodeClientRequestOps diff --git a/raft/raft_api.nim b/raft/raft_api.nim index 2369b28..29516ff 100644 --- a/raft/raft_api.nim +++ b/raft/raft_api.nim @@ -39,7 +39,7 @@ proc new*[SmCommandType, SmStateType](T: type RaftNode[SmCommandType, SmStateTyp peers: RaftNodePeers for peerId in peersIds: - peers.add(RaftNodePeer(id: peerId, nextIndex: 0, matchIndex: 0, hasVoted: false, canVote: true)) + peers.add(RaftNodePeer(id: peerId, nextIndex: 1, matchIndex: 0, hasVoted: false, canVote: true)) result = T( id: id, state: rnsFollower, currentTerm: 0, peers: peers, commitIndex: 0, lastApplied: 0, @@ -95,17 +95,26 @@ proc raftNodeMessageDeliver*[SmCommandType, SmStateType](node: RaftNode[SmComman # Process Raft Node Client Requests proc raftNodeServeClientRequest*[SmCommandType, SmStateType](node: RaftNode[SmCommandType, SmStateType], req: RaftNodeClientRequest[SmCommandType]): Future[RaftNodeClientResponse[SmStateType]] {.async, gcsafe.} = - case req.op - of rncroExecSmCommand: - # TODO: implemenmt command handling - discard - of rncroRequestSmState: - if raftNodeIsLeader(node): + + withRLock(node.raftStateMutex): + if not raftNodeIsLeader(node): + return RaftNodeClientResponse(nodeId: node.id, error: rncreNotLeader, currentLeaderId: node.currentLeaderId) + + case req.op + of rncroExecSmCommand: + + let resFut = await raftNodeReplicateSmCommand(node, req.smCommand) + + if resFut.read: + return RaftNodeClientResponse(nodeId: node.id, error: rncreSuccess, state: raftNodeStateGet(node)) + else: + return RaftNodeClientResponse(nodeId: node.id, error: rncreFail, state: raftNodeStateGet(node)) + + of rncroRequestSmState: return RaftNodeClientResponse(nodeId: node.id, error: rncreSuccess, state: raftNodeStateGet(node)) + else: - return RaftNodeClientResponse(nodeId: node.id, error: rncreNotLeader, currentLeaderId: node.currentLeaderId) - else: - raiseAssert "Unknown client request operation." + raiseAssert "Unknown client request operation." # Abstract State Machine Ops func raftNodeSmStateGet*[SmCommandType, SmStateType](node: RaftNode[SmCommandType, SmStateType]): SmStateType = @@ -164,7 +173,6 @@ proc raftNodeStop*[SmCommandType, SmStateType](node: RaftNode[SmCommandType, SmS # Abort election if in election if node.state == rnsCandidate: raftNodeAbortElection(node)s - node.state = rnsStopped # Cancel pending timers (if any) raftNodeCancelTimers(node) diff --git a/raft/types.nim b/raft/types.nim index 702af62..420bee4 100644 --- a/raft/types.nim +++ b/raft/types.nim @@ -30,8 +30,7 @@ type rnsUnknown = 0, rnsFollower = 1, rnsCandidate = 2 - rnsLeader = 3, - rnsStopped = 4 + rnsLeader = 3 RaftNodeId* = UUID # uuid4 uniquely identifying every Raft Node RaftNodeTerm* = int # Raft Node Term Type