diff --git a/src/raft/consensus_state_machine.nim b/src/raft/consensus_state_machine.nim index 2156bee..d88ed1c 100644 --- a/src/raft/consensus_state_machine.nim +++ b/src/raft/consensus_state_machine.nim @@ -59,8 +59,8 @@ type lastNewIndex: RaftLogIndex RaftRpcAppendReply* = object - commitIndex: RaftLogIndex - term: RaftNodeTerm + commitIndex*: RaftLogIndex + term*: RaftNodeTerm case result: RaftRpcCode: of Accepted: accepted: RaftRpcAppendReplyAccepted of Rejected: rejected: RaftRpcAppendReplyRejected @@ -224,7 +224,7 @@ func sendToImpl*(sm: var RaftStateMachine, id: RaftNodeId, request: RaftInstallS func sendTo[MsgType](sm: var RaftStateMachine, id: RaftNodeId, request: MsgType) = - sm.debug "Send to" & $id & $request + sm.debug "Send to " & $id & $request if sm.state.isLeader: var follower = sm.findFollowerProggressById(id) if follower.isSome: @@ -331,8 +331,15 @@ func becomeCandidate*(sm: var RaftStateMachine) = func heartbeat(sm: var RaftStateMachine, follower: var RaftFollowerProgressTracker) = sm.info "heartbeat " & $follower.nextIndex - # TODO: we should just send empty array instead adding new empty entries on each heartbeat - sm.addEntry(Empty()) + var previousTerm = 0 + if sm.log.lastIndex > 1: + previousTerm = sm.log.termForIndex(follower.nextIndex - 1).get() + let request = RaftRpcAppendRequest( + previousTerm: previousTerm, + previousLogIndex: follower.nextIndex - 1, + commitIndex: sm.commitIndex, + entries: @[]) + sm.sendTo(follower.id, request) func tickLeader*(sm: var RaftStateMachine, now: times.DateTime) = sm.timeNow = now @@ -371,13 +378,14 @@ func commit(sm: var RaftStateMachine) = return var newIndex = sm.commitIndex var nextIndex = sm.commitIndex + 1 - while nextIndex < sm.log.lastIndex: + while nextIndex < sm.log.nextIndex: var replicationCnt = 1 for p in sm.leader.tracker.progress: if p.matchIndex > newIndex: replicationCnt += 1 - sm.debug "replication count" & $replicationCnt + sm.debug "replication count: " & $replicationCnt & " for log index: " & $nextIndex if replicationCnt >= (sm.leader.tracker.progress.len div 2 + 1): + sm.debug "Commit index: " & $nextIndex sm.commitIndex = nextIndex; nextIndex += 1 newIndex += 1 @@ -414,7 +422,7 @@ func appendEntryReply*(sm: var RaftStateMachine, fromId: RaftNodeId, reply: Raft case reply.result: of RaftRpcCode.Accepted: let lastIndex = reply.accepted.lastNewIndex - sm.debug "Accpeted" & $fromId & " " & $lastIndex + sm.debug "Accpeted message from" & $fromId & " last log index: " & $lastIndex follower.get().accepted(lastIndex) # TODO: add leader stepping down logic here if not sm.state.isLeader: @@ -432,6 +440,7 @@ func appendEntryReply*(sm: var RaftStateMachine, fromId: RaftNodeId, reply: Raft func advanceCommitIdx(sm: var RaftStateMachine, leaderIdx: RaftLogIndex) = let newIdx = min(leaderIdx, sm.log.lastIndex) if newIdx > sm.commitIndex: + sm.debug "Commit index is changed. Old:" & $sm.commitIndex & " New:" & $newIdx sm.commitIndex = newIdx # TODO: signal the output for the update diff --git a/src/raft/log.nim b/src/raft/log.nim index 58db552..55d6830 100644 --- a/src/raft/log.nim +++ b/src/raft/log.nim @@ -97,7 +97,7 @@ func matchTerm*(rf: RaftLog, index: RaftLogIndex, term: RaftNodeTerm): (bool, Ra func termForIndex*(rf: RaftLog, index: RaftLogIndex): Option[RaftNodeTerm] = # TODO: snapshot support - assert rf.logEntries.len > index + assert rf.logEntries.len > index - rf.firstIndex, $rf.logEntries.len & " " & $index & "" & $rf if rf.logEntries.len > 0 and index >= rf.firstIndex: - return some(rf.logEntries[index].term) + return some(rf.logEntries[index - rf.firstIndex].term) return none(RaftNodeTerm) diff --git a/tests/test_bls_cluester.nim b/tests/test_bls_cluester.nim index 5421a58..7f34729 100644 --- a/tests/test_bls_cluester.nim +++ b/tests/test_bls_cluester.nim @@ -5,6 +5,7 @@ import ../src/raft/tracker import ../src/raft/state import std/[times, sequtils, random] +import std/sugar import std/sets import std/json import std/jsonutils @@ -32,6 +33,7 @@ type SignedLogEntry = object hash: Hash + logIndex: RaftLogIndex signature: SignedShare BLSTestNode* = ref object @@ -46,11 +48,23 @@ type BLSTestCluster* = object nodes*: Table[RaftnodeId, BLSTestNode] + delayer*: MessageDelayer SecretShare = object secret: SecretKey id: ID + DelayedMessage* = object + msg: SignedRpcMessage + executeAt: times.DateTime + + MessageDelayer* = object + messages: seq[DelayedMessage] + randomGenerator: Rand + meanDelay: float + stdDelay: float + minDelayMs: int + SignedShare = object sign: Signature pubkey: PublicKey @@ -72,6 +86,26 @@ var test_ids_1 = @[ RaftnodeId(parseUUID("a8409b39-f17b-4682-aaef-a19cc9f356fb")), ] + +proc initDelayer(mean: float, std: float, minInMs: int, generator: Rand): MessageDelayer = + var delayer = MessageDelayer() + delayer.meanDelay = mean + delayer.stdDelay = std + delayer.minDelayMs = minInMs + delayer.randomGenerator = generator + return delayer + +proc getMessages(delayer: var MessageDelayer, now: times.DateTime): seq[SignedRpcMessage] = + result = delayer.messages.filter(m => m.executeAt <= now).map(m => m.msg) + delayer.messages = delayer.messages.filter(m => m.executeAt > now) + return result + +proc add(delayer: var MessageDelayer, message: SignedRpcMessage, now: times.DateTime) = + let rndDelay = delayer.randomGenerator.gauss(delayer.meanDelay, delayer.stdDelay) + let at = now + times.initDuration(milliseconds = delayer.minDelayMs + rndDelay.int) + delayer.messages.add(DelayedMessage(msg: message, executeAt: at)) + + proc signs(shares: openArray[SignedShare]): seq[Signature] = shares.mapIt(it.sign) @@ -110,6 +144,16 @@ func `$`*(de: DebugLogEntry): string = return "[" & $de.level & "][" & de.time.format("HH:mm:ss:fff") & "][" & (($de.nodeId)[0..7]) & "...][" & $de.state & "]: " & de.msg +proc sign(node: BLSTestNode, msg: Message): SignedShare = + var pk: PublicKey + discard pk.publicFromSecret(node.keyShare.secret) + echo "Produce signature from node: " & $node.stm.myId & " with public key: " & $pk.toHex & "over msg " & $msg.toJson + return SignedShare( + sign: node.keyShare.secret.sign(msg.toBytes), + pubkey: pk, + id: node.keyShare.id, + ) + proc pollMessages(node: BLSTestNode): seq[SignedRpcMessage] = var output = node.stm.poll() var debugLogs = output.debugLogs @@ -122,7 +166,9 @@ proc pollMessages(node: BLSTestNode): seq[SignedRpcMessage] = raftMsg: msg, signEntries: node.signEntries )) - node.signEntries = @[] + let commitIndex = msg.appendReply.commitIndex + # remove the signature of all entries that are already commited + node.signEntries = node.signEntries.filter(x => x.logIndex > commitIndex) else: msgs.add(SignedRpcMessage( raftMsg: msg, @@ -135,6 +181,7 @@ proc pollMessages(node: BLSTestNode): seq[SignedRpcMessage] = var orgMsg = commitedMsg.command.toMessage var shares = node.messageSignatures[orgMsg.fieldInt] echo "Try to recover message" & $orgMsg.toBytes + echo "Shares: " & $shares.signs var recoveredSignature = recover(shares.signs, shares.ids).expect("valid shares") if not node.clusterPublicKey.verify(orgMsg.toBytes, recoveredSignature): node.us.lastCommitedMsg = orgMsg @@ -148,20 +195,18 @@ proc pollMessages(node: BLSTestNode): seq[SignedRpcMessage] = echo $msg return msgs -proc acceptMessage(node: BLSTestNode, msg: SignedRpcMessage, now: times.DateTime) = - if msg.raftMsg.kind == RaftRpcMessageType.AppendReply and node.stm.state.isFollower: +proc acceptMessage(node: var BLSTestNode, msg: SignedRpcMessage, now: times.DateTime) = + if msg.raftMsg.kind == RaftRpcMessageType.AppendRequest and node.stm.state.isFollower: var pk: PublicKey discard pk.publicFromSecret(node.keyShare.secret) for entry in msg.raftMsg.appendRequest.entries: + if entry.kind == rletEmpty: + continue var orgMsg = entry.command.toMessage - echo "Sign message" & $orgMsg.toBytes var share = SignedLogEntry( hash: orgMsg.fieldInt, - signature: SignedShare( - sign: node.keyShare.secret.sign(orgMsg.toBytes), - pubkey: pk, - id: node.keyShare.id, - ) + logIndex: msg.raftMsg.appendRequest.previousLogIndex + 1, + signature: node.sign(orgMsg) ) node.signEntries.add(share) node.stm.advance(msg.raftMsg, now) @@ -192,18 +237,20 @@ proc generateSecretShares(sk: SecretKey, k: int, n: int): seq[SecretShare] = let secret = genSecretShare(originPts, id) result.add(SecretShare(secret: secret, id: id)) -proc createBLSCluster(ids: seq[RaftnodeId], now: times.DateTime) : BLSTestCluster = +proc createBLSCluster(ids: seq[RaftnodeId], now: times.DateTime, k: int, n: int, delayer: MessageDelayer) : BLSTestCluster = var sk: SecretKey discard sk.fromHex("1b500388741efd98239a9b3a689721a89a92e8b209aabb10fb7dc3f844976dc2") var pk: PublicKey discard pk.publicFromSecret(sk) - var blsShares = generateSecretShares(sk, 2, 3) + var blsShares = generateSecretShares(sk, k, n) var config = createConfigFromIds(ids) var cluster = BLSTestCluster() + cluster.delayer = delayer cluster.nodes = initTable[RaftnodeId, BLSTestNode]() + for i in 0..