diff --git a/raft/consensus_module.nim b/raft/consensus_module.nim index 9d6cbd6..a4cdb5f 100644 --- a/raft/consensus_module.nim +++ b/raft/consensus_module.nim @@ -27,6 +27,9 @@ proc RaftNodeHandleHeartBeat*[SmCommandType, SmStateType](node: RaftNode[SmComma debug "Received heart-beat", node_id=node.id, sender_id=msg.sender_id, node_current_term=node.currentTerm, sender_term=msg.senderTerm result = RaftMessageResponse[SmCommandType, SmStateType](op: rmoAppendLogEntry, senderId: node.id, receiverId: msg.senderId, msgId: msg.msgId, success: false) withRLock(node.raftStateMutex): + if node.state == rnsStopped: + return + if msg.senderTerm >= node.currentTerm: RaftNodeCancelAllTimers(node) if node.state == rnsCandidate: @@ -39,12 +42,16 @@ proc RaftNodeHandleHeartBeat*[SmCommandType, SmStateType](node: RaftNode[SmComma proc RaftNodeHandleRequestVote*[SmCommandType, SmStateType](node: RaftNode[SmCommandType, SmStateType], msg: RaftMessage[SmCommandType, SmStateType]): RaftMessageResponse[SmCommandType, SmStateType] = + result = RaftMessageResponse[SmCommandType, SmStateType](op: rmoRequestVote, msgId: msg.msgId, senderId: node.id, receiverId: msg.senderId, granted: false) withRLock(node.raftStateMutex): - result = RaftMessageResponse[SmCommandType, SmStateType](op: rmoRequestVote, msgId: msg.msgId, senderId: node.id, receiverId: msg.senderId, granted: false) - if node.state != rnsCandidate and node.state != rnsStopped and msg.senderTerm > node.currentTerm and node.votedFor == DefaultUUID: - if msg.lastLogTerm >= RaftNodeLogEntryGet(node, RaftNodeLogIndexGet(node)).term or + if node.state == rnsStopped: + return + + if msg.senderTerm > node.currentTerm and node.votedFor == DefaultUUID: + if msg.lastLogTerm > RaftNodeLogEntryGet(node, RaftNodeLogIndexGet(node)).term or (msg.lastLogTerm == RaftNodeLogEntryGet(node, RaftNodeLogIndexGet(node)).term and msg.lastLogIndex >= RaftNodeLogIndexGet(node)): - asyncSpawn cancelAndWait(node.electionTimeoutTimer) + if node.electionTimeoutTimer != nil: + asyncSpawn cancelAndWait(node.electionTimeoutTimer) node.votedFor = msg.senderId result.granted = true RaftNodeScheduleElectionTimeout(node) @@ -52,16 +59,16 @@ proc RaftNodeHandleRequestVote*[SmCommandType, SmStateType](node: RaftNode[SmCom proc RaftNodeAbortElection*[SmCommandType, SmStateType](node: RaftNode[SmCommandType, SmStateType]) = withRLock(node.raftStateMutex): node.state = rnsFollower - # for fut in node.votesFuts: - # waitFor cancelAndWait(fut) + for fut in node.votesFuts: + waitFor cancelAndWait(fut) proc RaftNodeStartElection*[SmCommandType, SmStateType](node: RaftNode[SmCommandType, SmStateType]) {.async.} = - while node.votesFuts.len > 0: - discard node.votesFuts.pop mixin RaftNodeScheduleElectionTimeout RaftNodeScheduleElectionTimeout(node) withRLock(node.raftStateMutex): + while node.votesFuts.len > 0: + discard node.votesFuts.pop node.currentTerm.inc node.state = rnsCandidate node.votedFor = node.id @@ -77,22 +84,23 @@ proc RaftNodeStartElection*[SmCommandType, SmStateType](node: RaftNode[SmCommand ) ) - # Process votes (if any) - for voteFut in node.votesFuts: - try: - await voteFut or sleepAsync(milliseconds(node.votingTimeout)) - if not voteFut.finished: - await cancelAndWait(voteFut) - else: - if not voteFut.cancelled: - let respVote = RaftMessageResponse[SmCommandType, SmStateType](voteFut.read) - debug "Received vote", node_id=node.id, sender_id=respVote.senderId, granted=respVote.granted + withRLock(node.raftStateMutex): + # Process votes (if any) + for voteFut in node.votesFuts: + try: + await voteFut or sleepAsync(milliseconds(node.votingTimeout)) + if not voteFut.finished: + await cancelAndWait(voteFut) + else: + if not voteFut.cancelled: + let respVote = RaftMessageResponse[SmCommandType, SmStateType](voteFut.read) + debug "Received vote", node_id=node.id, sender_id=respVote.senderId, granted=respVote.granted - for p in node.peers: - if p.id == respVote.senderId: - p.hasVoted = respVote.granted - except Exception as e: - discard + for p in node.peers: + if p.id == respVote.senderId: + p.hasVoted = respVote.granted + except Exception as e: + discard withRLock(node.raftStateMutex): if node.state == rnsCandidate: diff --git a/raft/consensus_state_machine.nim b/raft/consensus_state_machine.nim new file mode 100644 index 0000000..9e17029 --- /dev/null +++ b/raft/consensus_state_machine.nim @@ -0,0 +1,58 @@ +# nim-raft +# Copyright (c) 2023 Status Research & Development GmbH +# Licensed under either of +# * Apache License, version 2.0, ([LICENSE-APACHE](LICENSE-APACHE)) +# * MIT license ([LICENSE-MIT](LICENSE-MIT)) +# at your option. +# This file may not be copied, modified, or distributed except according to +# those terms. + +import std/tables +import std/rlocks + +type + # Define callback to use with Terminals + ConsensusFSMCallbackType*[NodeType] = proc(node: NodeType) {.gcsafe.} + # Define Non-Terminals as a (unique) tuples of the internal state and a sequence of callbacks + NonTerminalSymbol*[NodeType, NodeStates] = (NodeStates, seq[ConsensusFSMCallbackType[NodeType]]) + # Define loose conditions computed from our NodeType + Condition*[NodeType] = proc(node: NodeType): bool + # Define Terminals as a tuple of a Event and (Hash) Table of sequences of (loose) conditions and their respective values computed from NodeType (Truth Table) + TerminalSymbol*[NodeType, EventType] = (Table[EventType, (seq[Condition[NodeType]], seq[bool])]) + # Define State Transition Rules LUT of the form ( NonTerminal -> Terminal ) -> NonTerminal ) + StateTransitionsRulesLUT*[NodeType, EventType, NodeStates] = Table[ + (NonTerminalSymbol[NodeType, NodeStates], TerminalSymbol[NodeType, EventType]), + NonTerminalSymbol[NodeType, NodeStates]] + + # FSM type definition + ConsensusFSM*[NodeType, EventType, NodeStates] = ref object + mtx: RLock + state: NonTerminalSymbol[NodeType, NodeStates] + stateTransitionsLUT: StateTransitionsRulesLUT[NodeType, EventType, NodeStates] + +# FSM type constructor +proc new*[NodeType, EventType, NodeStates](T: type ConsensusFSM[NodeType, EventType, NodeStates], + lut: StateTransitionsRulesLUT[NodeType, EventType, NodeStates], + startSymbol: NonTerminalSymbol[NodeType, NodeStates] + ): T = + result = new(ConsensusFSM[NodeType, EventType, NodeStates]) + initRLock(result.mtx) + result.state = startSymbol + result.stateTransitionsLUT = lut + +proc computeFSMInputRobustLogic[NodeType, EventType](node: NodeType, event: EventType, rawInput: TerminalSymbol[NodeType, EventType]): + TerminalSymbol[NodeType, EventType] = + var + robustLogicEventTerminal = rawInput[event] + for f, v in robustLogicEventTerminal: + v = f(node) + rawInput[event] = robustLogicEventTerminal + result = rawInput + +proc consensusFSMAdvance[NodeType, EventType, NodeStates](fsm: ConsensusFSM[NodeType, EventType, NodeStates], node: NodeType, event: EventType, + rawInput: TerminalSymbol[NodeType, EventType]): NonTerminalSymbol[NodeType, NodeStates] = + withRLock(): + var + input = computeFSMInputRobustLogic(node, event, rawInput) + fsm.state = fsm.stateTransitionsLUT[fsm.state, input] + result = fsm.state \ No newline at end of file