Fix lastTerm access

This commit is contained in:
Marto 2024-02-29 19:26:11 +02:00
parent 5950ed20eb
commit 5bba48662c
4 changed files with 117 additions and 48 deletions

View File

@ -59,8 +59,8 @@ type
lastNewIndex: RaftLogIndex lastNewIndex: RaftLogIndex
RaftRpcAppendReply* = object RaftRpcAppendReply* = object
commitIndex: RaftLogIndex commitIndex*: RaftLogIndex
term: RaftNodeTerm term*: RaftNodeTerm
case result: RaftRpcCode: case result: RaftRpcCode:
of Accepted: accepted: RaftRpcAppendReplyAccepted of Accepted: accepted: RaftRpcAppendReplyAccepted
of Rejected: rejected: RaftRpcAppendReplyRejected of Rejected: rejected: RaftRpcAppendReplyRejected
@ -331,8 +331,15 @@ func becomeCandidate*(sm: var RaftStateMachine) =
func heartbeat(sm: var RaftStateMachine, follower: var RaftFollowerProgressTracker) = func heartbeat(sm: var RaftStateMachine, follower: var RaftFollowerProgressTracker) =
sm.info "heartbeat " & $follower.nextIndex sm.info "heartbeat " & $follower.nextIndex
# TODO: we should just send empty array instead adding new empty entries on each heartbeat var previousTerm = 0
sm.addEntry(Empty()) if sm.log.lastIndex > 1:
previousTerm = sm.log.termForIndex(follower.nextIndex - 1).get()
let request = RaftRpcAppendRequest(
previousTerm: previousTerm,
previousLogIndex: follower.nextIndex - 1,
commitIndex: sm.commitIndex,
entries: @[])
sm.sendTo(follower.id, request)
func tickLeader*(sm: var RaftStateMachine, now: times.DateTime) = func tickLeader*(sm: var RaftStateMachine, now: times.DateTime) =
sm.timeNow = now sm.timeNow = now
@ -371,13 +378,14 @@ func commit(sm: var RaftStateMachine) =
return return
var newIndex = sm.commitIndex var newIndex = sm.commitIndex
var nextIndex = sm.commitIndex + 1 var nextIndex = sm.commitIndex + 1
while nextIndex < sm.log.lastIndex: while nextIndex < sm.log.nextIndex:
var replicationCnt = 1 var replicationCnt = 1
for p in sm.leader.tracker.progress: for p in sm.leader.tracker.progress:
if p.matchIndex > newIndex: if p.matchIndex > newIndex:
replicationCnt += 1 replicationCnt += 1
sm.debug "replication count" & $replicationCnt sm.debug "replication count: " & $replicationCnt & " for log index: " & $nextIndex
if replicationCnt >= (sm.leader.tracker.progress.len div 2 + 1): if replicationCnt >= (sm.leader.tracker.progress.len div 2 + 1):
sm.debug "Commit index: " & $nextIndex
sm.commitIndex = nextIndex; sm.commitIndex = nextIndex;
nextIndex += 1 nextIndex += 1
newIndex += 1 newIndex += 1
@ -414,7 +422,7 @@ func appendEntryReply*(sm: var RaftStateMachine, fromId: RaftNodeId, reply: Raft
case reply.result: case reply.result:
of RaftRpcCode.Accepted: of RaftRpcCode.Accepted:
let lastIndex = reply.accepted.lastNewIndex let lastIndex = reply.accepted.lastNewIndex
sm.debug "Accpeted" & $fromId & " " & $lastIndex sm.debug "Accpeted message from" & $fromId & " last log index: " & $lastIndex
follower.get().accepted(lastIndex) follower.get().accepted(lastIndex)
# TODO: add leader stepping down logic here # TODO: add leader stepping down logic here
if not sm.state.isLeader: if not sm.state.isLeader:
@ -432,6 +440,7 @@ func appendEntryReply*(sm: var RaftStateMachine, fromId: RaftNodeId, reply: Raft
func advanceCommitIdx(sm: var RaftStateMachine, leaderIdx: RaftLogIndex) = func advanceCommitIdx(sm: var RaftStateMachine, leaderIdx: RaftLogIndex) =
let newIdx = min(leaderIdx, sm.log.lastIndex) let newIdx = min(leaderIdx, sm.log.lastIndex)
if newIdx > sm.commitIndex: if newIdx > sm.commitIndex:
sm.debug "Commit index is changed. Old:" & $sm.commitIndex & " New:" & $newIdx
sm.commitIndex = newIdx sm.commitIndex = newIdx
# TODO: signal the output for the update # TODO: signal the output for the update

View File

@ -97,7 +97,7 @@ func matchTerm*(rf: RaftLog, index: RaftLogIndex, term: RaftNodeTerm): (bool, Ra
func termForIndex*(rf: RaftLog, index: RaftLogIndex): Option[RaftNodeTerm] = func termForIndex*(rf: RaftLog, index: RaftLogIndex): Option[RaftNodeTerm] =
# TODO: snapshot support # TODO: snapshot support
assert rf.logEntries.len > index assert rf.logEntries.len > index - rf.firstIndex, $rf.logEntries.len & " " & $index & "" & $rf
if rf.logEntries.len > 0 and index >= rf.firstIndex: if rf.logEntries.len > 0 and index >= rf.firstIndex:
return some(rf.logEntries[index].term) return some(rf.logEntries[index - rf.firstIndex].term)
return none(RaftNodeTerm) return none(RaftNodeTerm)

View File

@ -5,6 +5,7 @@ import ../src/raft/tracker
import ../src/raft/state import ../src/raft/state
import std/[times, sequtils, random] import std/[times, sequtils, random]
import std/sugar
import std/sets import std/sets
import std/json import std/json
import std/jsonutils import std/jsonutils
@ -32,6 +33,7 @@ type
SignedLogEntry = object SignedLogEntry = object
hash: Hash hash: Hash
logIndex: RaftLogIndex
signature: SignedShare signature: SignedShare
BLSTestNode* = ref object BLSTestNode* = ref object
@ -46,11 +48,23 @@ type
BLSTestCluster* = object BLSTestCluster* = object
nodes*: Table[RaftnodeId, BLSTestNode] nodes*: Table[RaftnodeId, BLSTestNode]
delayer*: MessageDelayer
SecretShare = object SecretShare = object
secret: SecretKey secret: SecretKey
id: ID id: ID
DelayedMessage* = object
msg: SignedRpcMessage
executeAt: times.DateTime
MessageDelayer* = object
messages: seq[DelayedMessage]
randomGenerator: Rand
meanDelay: float
stdDelay: float
minDelayMs: int
SignedShare = object SignedShare = object
sign: Signature sign: Signature
pubkey: PublicKey pubkey: PublicKey
@ -72,6 +86,26 @@ var test_ids_1 = @[
RaftnodeId(parseUUID("a8409b39-f17b-4682-aaef-a19cc9f356fb")), RaftnodeId(parseUUID("a8409b39-f17b-4682-aaef-a19cc9f356fb")),
] ]
proc initDelayer(mean: float, std: float, minInMs: int, generator: Rand): MessageDelayer =
var delayer = MessageDelayer()
delayer.meanDelay = mean
delayer.stdDelay = std
delayer.minDelayMs = minInMs
delayer.randomGenerator = generator
return delayer
proc getMessages(delayer: var MessageDelayer, now: times.DateTime): seq[SignedRpcMessage] =
result = delayer.messages.filter(m => m.executeAt <= now).map(m => m.msg)
delayer.messages = delayer.messages.filter(m => m.executeAt > now)
return result
proc add(delayer: var MessageDelayer, message: SignedRpcMessage, now: times.DateTime) =
let rndDelay = delayer.randomGenerator.gauss(delayer.meanDelay, delayer.stdDelay)
let at = now + times.initDuration(milliseconds = delayer.minDelayMs + rndDelay.int)
delayer.messages.add(DelayedMessage(msg: message, executeAt: at))
proc signs(shares: openArray[SignedShare]): seq[Signature] = proc signs(shares: openArray[SignedShare]): seq[Signature] =
shares.mapIt(it.sign) shares.mapIt(it.sign)
@ -110,6 +144,16 @@ func `$`*(de: DebugLogEntry): string =
return "[" & $de.level & "][" & de.time.format("HH:mm:ss:fff") & "][" & (($de.nodeId)[0..7]) & "...][" & $de.state & "]: " & de.msg return "[" & $de.level & "][" & de.time.format("HH:mm:ss:fff") & "][" & (($de.nodeId)[0..7]) & "...][" & $de.state & "]: " & de.msg
proc sign(node: BLSTestNode, msg: Message): SignedShare =
var pk: PublicKey
discard pk.publicFromSecret(node.keyShare.secret)
echo "Produce signature from node: " & $node.stm.myId & " with public key: " & $pk.toHex & "over msg " & $msg.toJson
return SignedShare(
sign: node.keyShare.secret.sign(msg.toBytes),
pubkey: pk,
id: node.keyShare.id,
)
proc pollMessages(node: BLSTestNode): seq[SignedRpcMessage] = proc pollMessages(node: BLSTestNode): seq[SignedRpcMessage] =
var output = node.stm.poll() var output = node.stm.poll()
var debugLogs = output.debugLogs var debugLogs = output.debugLogs
@ -122,7 +166,9 @@ proc pollMessages(node: BLSTestNode): seq[SignedRpcMessage] =
raftMsg: msg, raftMsg: msg,
signEntries: node.signEntries signEntries: node.signEntries
)) ))
node.signEntries = @[] let commitIndex = msg.appendReply.commitIndex
# remove the signature of all entries that are already commited
node.signEntries = node.signEntries.filter(x => x.logIndex > commitIndex)
else: else:
msgs.add(SignedRpcMessage( msgs.add(SignedRpcMessage(
raftMsg: msg, raftMsg: msg,
@ -135,6 +181,7 @@ proc pollMessages(node: BLSTestNode): seq[SignedRpcMessage] =
var orgMsg = commitedMsg.command.toMessage var orgMsg = commitedMsg.command.toMessage
var shares = node.messageSignatures[orgMsg.fieldInt] var shares = node.messageSignatures[orgMsg.fieldInt]
echo "Try to recover message" & $orgMsg.toBytes echo "Try to recover message" & $orgMsg.toBytes
echo "Shares: " & $shares.signs
var recoveredSignature = recover(shares.signs, shares.ids).expect("valid shares") var recoveredSignature = recover(shares.signs, shares.ids).expect("valid shares")
if not node.clusterPublicKey.verify(orgMsg.toBytes, recoveredSignature): if not node.clusterPublicKey.verify(orgMsg.toBytes, recoveredSignature):
node.us.lastCommitedMsg = orgMsg node.us.lastCommitedMsg = orgMsg
@ -148,20 +195,18 @@ proc pollMessages(node: BLSTestNode): seq[SignedRpcMessage] =
echo $msg echo $msg
return msgs return msgs
proc acceptMessage(node: BLSTestNode, msg: SignedRpcMessage, now: times.DateTime) = proc acceptMessage(node: var BLSTestNode, msg: SignedRpcMessage, now: times.DateTime) =
if msg.raftMsg.kind == RaftRpcMessageType.AppendReply and node.stm.state.isFollower: if msg.raftMsg.kind == RaftRpcMessageType.AppendRequest and node.stm.state.isFollower:
var pk: PublicKey var pk: PublicKey
discard pk.publicFromSecret(node.keyShare.secret) discard pk.publicFromSecret(node.keyShare.secret)
for entry in msg.raftMsg.appendRequest.entries: for entry in msg.raftMsg.appendRequest.entries:
if entry.kind == rletEmpty:
continue
var orgMsg = entry.command.toMessage var orgMsg = entry.command.toMessage
echo "Sign message" & $orgMsg.toBytes
var share = SignedLogEntry( var share = SignedLogEntry(
hash: orgMsg.fieldInt, hash: orgMsg.fieldInt,
signature: SignedShare( logIndex: msg.raftMsg.appendRequest.previousLogIndex + 1,
sign: node.keyShare.secret.sign(orgMsg.toBytes), signature: node.sign(orgMsg)
pubkey: pk,
id: node.keyShare.id,
)
) )
node.signEntries.add(share) node.signEntries.add(share)
node.stm.advance(msg.raftMsg, now) node.stm.advance(msg.raftMsg, now)
@ -192,19 +237,21 @@ proc generateSecretShares(sk: SecretKey, k: int, n: int): seq[SecretShare] =
let secret = genSecretShare(originPts, id) let secret = genSecretShare(originPts, id)
result.add(SecretShare(secret: secret, id: id)) result.add(SecretShare(secret: secret, id: id))
proc createBLSCluster(ids: seq[RaftnodeId], now: times.DateTime) : BLSTestCluster = proc createBLSCluster(ids: seq[RaftnodeId], now: times.DateTime, k: int, n: int, delayer: MessageDelayer) : BLSTestCluster =
var sk: SecretKey var sk: SecretKey
discard sk.fromHex("1b500388741efd98239a9b3a689721a89a92e8b209aabb10fb7dc3f844976dc2") discard sk.fromHex("1b500388741efd98239a9b3a689721a89a92e8b209aabb10fb7dc3f844976dc2")
var pk: PublicKey var pk: PublicKey
discard pk.publicFromSecret(sk) discard pk.publicFromSecret(sk)
var blsShares = generateSecretShares(sk, 2, 3) var blsShares = generateSecretShares(sk, k, n)
var config = createConfigFromIds(ids) var config = createConfigFromIds(ids)
var cluster = BLSTestCluster() var cluster = BLSTestCluster()
cluster.delayer = delayer
cluster.nodes = initTable[RaftnodeId, BLSTestNode]() cluster.nodes = initTable[RaftnodeId, BLSTestNode]()
for i in 0..<config.currentSet.len: for i in 0..<config.currentSet.len:
let id = config.currentSet[i] let id = config.currentSet[i]
var log = initRaftLog(1) var log = initRaftLog(1)
@ -223,6 +270,11 @@ proc advance*(tc: var BLSTestCluster, now: times.DateTime, logLevel: DebugLogLev
node.tick(now) node.tick(now)
var msgs = node.pollMessages() var msgs = node.pollMessages()
for msg in msgs: for msg in msgs:
tc.delayer.add(msg, now)
var msgs = tc.delayer.getMessages(now)
for msg in msgs:
echo "eloooooooooooooooooooooooooooooooo" & $ msg
tc.nodes[msg.raftMsg.receiver].acceptMessage(msg, now) tc.nodes[msg.raftMsg.receiver].acceptMessage(msg, now)
func getLeader*(tc: BLSTestCluster): Option[BLSTestNode] = func getLeader*(tc: BLSTestCluster): Option[BLSTestNode] =
@ -239,11 +291,7 @@ proc submitMessage(tc: var BLSTestCluster, msg: Message): bool =
var pk: PublicKey var pk: PublicKey
discard pk.publicFromSecret(leader.get.keyShare.secret) discard pk.publicFromSecret(leader.get.keyShare.secret)
echo "Leader Sign message" & $msg.toBytes echo "Leader Sign message" & $msg.toBytes
var share = SignedShare( var share = leader.get().sign(msg)
sign: leader.get.keyShare.secret.sign(msg.toBytes),
pubkey: pk,
id: leader.get.keyShare.id,
)
if not leader.get.messageSignatures.hasKey(msg.fieldInt): if not leader.get.messageSignatures.hasKey(msg.fieldInt):
leader.get.messageSignatures[msg.fieldInt] = @[] leader.get.messageSignatures[msg.fieldInt] = @[]
leader.get.messageSignatures[msg.fieldInt].add(share) leader.get.messageSignatures[msg.fieldInt].add(share)
@ -252,38 +300,45 @@ proc submitMessage(tc: var BLSTestCluster, msg: Message): bool =
proc blsconsensusMain*() = proc blsconsensusMain*() =
suite "BLS consensus tests": suite "BLS consensus tests":
# test "create single node cluster": test "create single node cluster":
# var timeNow = dateTime(2017, mMar, 01, 00, 00, 00, 00, utc()) var timeNow = dateTime(2017, mMar, 01, 00, 00, 00, 00, utc())
# var cluster = createBLSCluster(test_ids_1, timeNow) var delayer = initDelayer(3, 3, 1, initRand(42))
var cluster = createBLSCluster(test_ids_1, timeNow, 1, 1, delayer)
# timeNow += 300.milliseconds timeNow += 300.milliseconds
# cluster.advance(timeNow) cluster.advance(timeNow)
# echo cluster.getLeader().get().stm.state echo cluster.getLeader().get().stm.state
# discard cluster.submitMessage(Message(fieldInt: 1)) discard cluster.submitMessage(Message(fieldInt: 1))
# discard cluster.submitMessage(Message(fieldInt: 2)) discard cluster.submitMessage(Message(fieldInt: 2))
# for i in 0..<305: for i in 0..<305:
# timeNow += 5.milliseconds timeNow += 5.milliseconds
# cluster.advance(timeNow) cluster.advance(timeNow)
# echo "Helloo" & $cluster.getLeader().get.us.lastCommitedMsg echo "Helloo" & $cluster.getLeader().get.us.lastCommitedMsg
test "create 3 node cluster": test "create 3 node cluster":
var timeNow = dateTime(2017, mMar, 01, 00, 00, 00, 00, utc()) var timeNow = dateTime(2017, mMar, 01, 00, 00, 00, 00, utc())
var cluster = createBLSCluster(test_ids_3, timeNow) var delayer = initDelayer(3, 3, 1, initRand(42))
var cluster = createBLSCluster(test_ids_3, timeNow, 2, 3, delayer)
#timeNow += 300.milliseconds # skip time until first election
timeNow += 200.milliseconds
cluster.advance(timeNow) cluster.advance(timeNow)
var added = false var added = false
for i in 0..<305: var commited = false
for i in 0..<50:
cluster.advance(timeNow) cluster.advance(timeNow)
if cluster.getLeader().isSome() and not added: if cluster.getLeader().isSome() and not added:
discard cluster.submitMessage(Message(fieldInt: 1)) discard cluster.submitMessage(Message(fieldInt: 42))
added = true added = true
echo "Add to the entry log" echo "Add to the entry log"
timeNow += 5.milliseconds timeNow += 5.milliseconds
if cluster.getLeader().isSome():
#echo $cluster.nodes echo cluster.getLeader().get.us.lastCommitedMsg
echo "Last state" & $cluster.getLeader().get.us.lastCommitedMsg if cluster.getLeader().get.us.lastCommitedMsg.fieldInt == 42:
commited = true
#break
check commited == true
if isMainModule: if isMainModule:

View File

@ -231,8 +231,11 @@ proc consensusstatemachineMain*() =
timeNow += 500.milliseconds timeNow += 500.milliseconds
sm.tick(timeNow) sm.tick(timeNow)
output = sm.poll() output = sm.poll()
echo output
check output.logEntries.len == 0 check output.logEntries.len == 0
check output.committed.len == 0 # When the node became a leader it will produce empty message in the log
# and because we have single node cluster the node will commit that entry immediately
check output.committed.len == 1
check output.messages.len == 0 check output.messages.len == 0
check sm.state.isLeader check sm.state.isLeader
check sm.term == 1 check sm.term == 1
@ -247,7 +250,9 @@ proc consensusstatemachineMain*() =
sm.tick(timeNow) sm.tick(timeNow)
var output = sm.poll() var output = sm.poll()
check output.logEntries.len == 0 check output.logEntries.len == 0
check output.committed.len == 0 # When the node became a leader it will produce empty message in the log
# and because we have single node cluster the node will commit that entry immediately
check output.committed.len == 1
check output.messages.len == 0 check output.messages.len == 0
check sm.state.isLeader check sm.state.isLeader
sm.addEntry(Empty()) sm.addEntry(Empty())