feat: persistence interface for SDS state (#66)

This commit is contained in:
Darshan 2026-05-08 03:14:12 +05:30 committed by GitHub
parent 9d08f5995b
commit 881d8cb359
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
9 changed files with 705 additions and 17 deletions

66
sds.nim
View File

@ -7,10 +7,13 @@ export types, protobuf, sds_utils, rolling_bloom_filter
proc newReliabilityManager*(
config: ReliabilityConfig = defaultConfig(),
participantId: SdsParticipantID = "".SdsParticipantID,
persistence: Persistence = noOpPersistence(),
): Result[ReliabilityManager, ReliabilityError] =
## Creates a new multi-channel ReliabilityManager.
## `persistence` defaults to a no-op backend; supply a real one to durably
## store SDS state across restarts.
try:
let rm = ReliabilityManager.new(config, participantId)
let rm = ReliabilityManager.new(config, participantId, persistence)
return ok(rm)
except Exception:
error "Failed to create ReliabilityManager", msg = getCurrentExceptionMsg()
@ -53,7 +56,7 @@ proc reviewAckStatus(rm: ReliabilityManager, msg: SdsMessage) {.gcsafe.} =
return
let channel = rm.channels[msg.channelId]
var toDelete: seq[int] = @[]
var toDelete: seq[(int, SdsMessageID)] = @[]
var i = 0
while i < channel.outgoingBuffer.len:
@ -61,11 +64,13 @@ proc reviewAckStatus(rm: ReliabilityManager, msg: SdsMessage) {.gcsafe.} =
if outMsg.isAcknowledged(msg.causalHistory, rbf):
if not rm.onMessageSent.isNil():
rm.onMessageSent(outMsg.message.messageId, outMsg.message.channelId)
toDelete.add(i)
toDelete.add((i, outMsg.message.messageId))
inc i
for i in countdown(toDelete.high, 0):
channel.outgoingBuffer.delete(toDelete[i])
for k in countdown(toDelete.high, 0):
let (idx, ackedId) = toDelete[k]
channel.outgoingBuffer.delete(idx)
rm.persistence.removeOutgoing(msg.channelId, ackedId)
proc wrapOutgoingMessage*(
rm: ReliabilityManager,
@ -108,6 +113,7 @@ proc wrapOutgoingMessage*(
expiredKeys.add(eligible[i][0])
for key in expiredKeys:
channel.outgoingRepairBuffer.del(key)
rm.persistence.removeOutgoingRepair(channelId, key)
let msg = SdsMessage.init(
messageId = messageId,
@ -120,9 +126,10 @@ proc wrapOutgoingMessage*(
repairRequest = repairReqs,
)
channel.outgoingBuffer.add(
let unackMsg =
UnacknowledgedMessage.init(message = msg, sendTime = getTime(), resendAttempts = 0)
)
channel.outgoingBuffer.add(unackMsg)
rm.persistence.saveOutgoing(channelId, unackMsg)
channel.bloomFilter.add(msg.messageId)
# The full SdsMessage carries senderId and content, so a single
@ -168,11 +175,15 @@ proc processIncomingBuffer(rm: ReliabilityManager, channelId: SdsChannelID) {.gc
if remainingId notin processed:
if msgId in entry.missingDeps:
channel.incomingBuffer[remainingId].missingDeps.excl(msgId)
rm.persistence.saveIncoming(
channelId, channel.incomingBuffer[remainingId]
)
if channel.incomingBuffer[remainingId].missingDeps.len == 0:
readyToProcess.add(remainingId)
for msgId in processed:
channel.incomingBuffer.del(msgId)
rm.persistence.removeIncoming(channelId, msgId)
proc unwrapReceivedMessage*(
rm: ReliabilityManager, message: seq[byte]
@ -193,7 +204,9 @@ proc unwrapReceivedMessage*(
# SDS-R: opportunistic repair-buffer cleanup — applies to duplicates too,
# so rebroadcasts cancel redundant responses on peers that already have the message.
channel.outgoingRepairBuffer.del(msg.messageId)
rm.persistence.removeOutgoingRepair(channelId, msg.messageId)
channel.incomingRepairBuffer.del(msg.messageId)
rm.persistence.removeIncomingRepair(channelId, msg.messageId)
if msg.messageId in channel.messageHistory:
return ok((msg.content, @[], channelId))
@ -211,6 +224,7 @@ proc unwrapReceivedMessage*(
for repairEntry in msg.repairRequest:
# Remove from our own outgoing repair buffer (someone else is also requesting)
channel.outgoingRepairBuffer.del(repairEntry.messageId)
rm.persistence.removeOutgoingRepair(channelId, repairEntry.messageId)
if repairEntry.messageId in channel.messageHistory and
rm.participantId.len > 0 and repairEntry.senderId.len > 0:
if isInResponseGroup(
@ -223,11 +237,13 @@ proc unwrapReceivedMessage*(
rm.participantId, repairEntry.senderId,
repairEntry.messageId, rm.config.repairTMax
)
channel.incomingRepairBuffer[repairEntry.messageId] = IncomingRepairEntry(
let inEntry = IncomingRepairEntry(
inHistEntry: repairEntry,
cachedMessage: serialized.get(),
minTimeRepairResp: now + tResp,
)
channel.incomingRepairBuffer[repairEntry.messageId] = inEntry
rm.persistence.saveIncomingRepair(channelId, repairEntry.messageId, inEntry)
var missingDeps = rm.checkDependencies(msg.causalHistory, channelId)
@ -238,23 +254,30 @@ proc unwrapReceivedMessage*(
depsInBuffer = true
break
if depsInBuffer:
channel.incomingBuffer[msg.messageId] =
let entry =
IncomingMessage.init(message = msg, missingDeps = initHashSet[SdsMessageID]())
channel.incomingBuffer[msg.messageId] = entry
rm.persistence.saveIncoming(channelId, entry)
else:
rm.addToHistory(msg, channelId)
# Unblock any buffered messages that were waiting on this one.
var unblocked: seq[SdsMessageID] = @[]
for pendingId, entry in channel.incomingBuffer:
if msg.messageId in entry.missingDeps:
channel.incomingBuffer[pendingId].missingDeps.excl(msg.messageId)
unblocked.add(pendingId)
for pendingId in unblocked:
rm.persistence.saveIncoming(channelId, channel.incomingBuffer[pendingId])
rm.processIncomingBuffer(channelId)
if not rm.onMessageReady.isNil():
rm.onMessageReady(msg.messageId, channelId)
else:
channel.incomingBuffer[msg.messageId] =
IncomingMessage.init(
message = msg,
missingDeps = missingDeps.getMessageIds().toHashSet(),
)
let entry = IncomingMessage.init(
message = msg,
missingDeps = missingDeps.getMessageIds().toHashSet(),
)
channel.incomingBuffer[msg.messageId] = entry
rm.persistence.saveIncoming(channelId, entry)
if not rm.onMissingDependencies.isNil():
rm.onMissingDependencies(msg.messageId, missingDeps, channelId)
@ -266,10 +289,12 @@ proc unwrapReceivedMessage*(
rm.participantId, dep.messageId,
rm.config.repairTMin, rm.config.repairTMax
)
channel.outgoingRepairBuffer[dep.messageId] = OutgoingRepairEntry(
let outEntry = OutgoingRepairEntry(
outHistEntry: dep,
minTimeRepairReq: now + tReq,
)
channel.outgoingRepairBuffer[dep.messageId] = outEntry
rm.persistence.saveOutgoingRepair(channelId, dep.messageId, outEntry)
return ok((msg.content, missingDeps, channelId))
except Exception:
@ -290,13 +315,19 @@ proc markDependenciesMet*(
if not channel.bloomFilter.contains(msgId):
channel.bloomFilter.add(msgId)
var unblocked: seq[SdsMessageID] = @[]
for pendingId, entry in channel.incomingBuffer:
if msgId in entry.missingDeps:
channel.incomingBuffer[pendingId].missingDeps.excl(msgId)
unblocked.add(pendingId)
for pendingId in unblocked:
rm.persistence.saveIncoming(channelId, channel.incomingBuffer[pendingId])
# SDS-R: clear from repair buffers (dependency now met)
channel.outgoingRepairBuffer.del(msgId)
rm.persistence.removeOutgoingRepair(channelId, msgId)
channel.incomingRepairBuffer.del(msgId)
rm.persistence.removeIncomingRepair(channelId, msgId)
rm.processIncomingBuffer(channelId)
return ok()
@ -343,9 +374,11 @@ proc checkUnacknowledgedMessages(
updatedMsg.resendAttempts += 1
updatedMsg.sendTime = now
newOutgoingBuffer.add(updatedMsg)
rm.persistence.saveOutgoing(channelId, updatedMsg)
else:
if not rm.onMessageSent.isNil():
rm.onMessageSent(unackMsg.message.messageId, channelId)
rm.persistence.removeOutgoing(channelId, unackMsg.message.messageId)
else:
newOutgoingBuffer.add(unackMsg)
@ -400,6 +433,7 @@ proc runRepairSweep*(rm: ReliabilityManager) {.gcsafe, raises: [].} =
for msgId in toRebroadcast:
let entry = channel.incomingRepairBuffer[msgId]
channel.incomingRepairBuffer.del(msgId)
rm.persistence.removeIncomingRepair(channelId, msgId)
if not rm.onRepairReady.isNil():
rm.onRepairReady(entry.cachedMessage, channelId)
@ -411,6 +445,7 @@ proc runRepairSweep*(rm: ReliabilityManager) {.gcsafe, raises: [].} =
toRemove.add(msgId)
for msgId in toRemove:
channel.outgoingRepairBuffer.del(msgId)
rm.persistence.removeOutgoingRepair(channelId, msgId)
except Exception:
error "Error in repair sweep for channel",
channelId = channelId, msg = getCurrentExceptionMsg()
@ -436,6 +471,7 @@ proc resetReliabilityManager*(rm: ReliabilityManager): Result[void, ReliabilityE
withLock rm.lock:
try:
for channelId, channel in rm.channels:
rm.dropChannelFromPersistence(channelId)
channel.lamportTimestamp = 0
channel.messageHistory.clear()
channel.outgoingBuffer.setLen(0)

View File

@ -67,6 +67,7 @@ proc getMyCpu(): string =
task test, "Run the test suite":
exec "nim c -r tests/test_bloom.nim"
exec "nim c -r tests/test_reliability.nim"
exec "nim c -r tests/test_persistence.nim"
task libsdsDynamicWindows, "Generate bindings":
let outLibNameAndExt = "libsds.dll"

View File

@ -14,7 +14,19 @@ export
proc defaultConfig*(): ReliabilityConfig =
return ReliabilityConfig.init()
proc dropChannelFromPersistence*(
rm: ReliabilityManager, channelId: SdsChannelID
) {.gcsafe, raises: [].} =
## Wipes all persisted state for a channel via a single backend call.
## Called by removeChannel / resetReliabilityManager before they clear
## in-memory state. Backend executes the wipe in one transaction.
rm.persistence.dropChannel(channelId)
proc cleanup*(rm: ReliabilityManager) {.raises: [].} =
## Releases in-memory state. Does NOT wipe persistence — the manager may be
## reconstructed against the same backend after cleanup, so disk state must
## survive. For deliberate disk wipe, use `removeChannel` or
## `resetReliabilityManager`.
if not rm.isNil():
try:
withLock rm.lock:
@ -50,12 +62,14 @@ proc addToHistory*(
if channelId in rm.channels:
let channel = rm.channels[channelId]
channel.messageHistory[msg.messageId] = msg
rm.persistence.appendLogEntry(channelId, msg)
while channel.messageHistory.len > rm.config.maxMessageHistory:
var firstKey: SdsMessageID
for k in channel.messageHistory.keys:
firstKey = k
break
channel.messageHistory.del(firstKey)
rm.persistence.removeLogEntry(channelId, firstKey)
except Exception:
error "Failed to add to history",
channelId = channelId, msgId = msg.messageId, error = getCurrentExceptionMsg()
@ -67,6 +81,7 @@ proc updateLamportTimestamp*(
if channelId in rm.channels:
let channel = rm.channels[channelId]
channel.lamportTimestamp = max(msgTs, channel.lamportTimestamp) + 1
rm.persistence.saveLamport(channelId, channel.lamportTimestamp)
except Exception:
error "Failed to update lamport timestamp",
channelId = channelId, msgTs = msgTs, error = getCurrentExceptionMsg()
@ -150,6 +165,8 @@ proc getRecentHistoryEntries*(
var entry = HistoryEntry(messageId: msgId)
if not rm.onRetrievalHint.isNil():
entry.retrievalHint = rm.onRetrievalHint(msgId)
if entry.retrievalHint.len > 0:
rm.persistence.setRetrievalHint(msgId, entry.retrievalHint)
entry.senderId = channel.messageHistory[msgId].senderId
entries.add(entry)
return entries
@ -226,11 +243,29 @@ proc getIncomingBuffer*(
proc getOrCreateChannel*(
rm: ReliabilityManager, channelId: SdsChannelID
): ChannelContext =
## Returns the channel context, creating and bootstrapping it from the
## persistence backend if it does not yet exist in memory. The bloom filter
## is rebuilt deterministically from the loaded message history rather than
## persisted directly. Caller is expected to hold rm.lock.
try:
if channelId notin rm.channels:
rm.channels[channelId] = ChannelContext.new(
let channel = ChannelContext.new(
RollingBloomFilter.init(rm.config.bloomFilterCapacity, rm.config.bloomFilterErrorRate)
)
let snapshot = rm.persistence.loadAllForChannel(channelId)
channel.lamportTimestamp = snapshot.lamportTimestamp
for msg in snapshot.messageHistory:
channel.messageHistory[msg.messageId] = msg
channel.bloomFilter.add(msg.messageId)
for unack in snapshot.outgoingBuffer:
channel.outgoingBuffer.add(unack)
for incoming in snapshot.incomingBuffer:
channel.incomingBuffer[incoming.message.messageId] = incoming
for (msgId, entry) in snapshot.outgoingRepairBuffer:
channel.outgoingRepairBuffer[msgId] = entry
for (msgId, entry) in snapshot.incomingRepairBuffer:
channel.incomingRepairBuffer[msgId] = entry
rm.channels[channelId] = channel
return rm.channels[channelId]
except Exception:
error "Failed to get or create channel",
@ -256,6 +291,7 @@ proc removeChannel*(
try:
if channelId in rm.channels:
let channel = rm.channels[channelId]
rm.dropChannelFromPersistence(channelId)
channel.outgoingBuffer.setLen(0)
channel.incomingBuffer.clear()
channel.messageHistory.clear()

View File

@ -11,6 +11,7 @@ import sds/types/app_callbacks
import sds/types/reliability_config
import sds/types/repair_entry
import sds/types/channel_context
import sds/types/persistence
import sds/types/reliability_manager
import sds/types/protobuf_error
@ -28,5 +29,6 @@ export
reliability_config,
repair_entry,
channel_context,
persistence,
reliability_manager,
protobuf_error

126
sds/types/persistence.nim Normal file
View File

@ -0,0 +1,126 @@
import ./sds_message_id
import ./sds_message
import ./unacknowledged_message
import ./incoming_message
import ./repair_entry
export
sds_message_id, sds_message, unacknowledged_message, incoming_message, repair_entry
## SDS state persistence interface (issue #64).
##
## Defines WHAT operations a persistence backend must provide. The actual
## storage technology (SQLite, encrypted file, in-memory) is supplied by the
## caller — nim-sds knows nothing about it. Every state-mutating proc in the
## protocol calls into one of these procs immediately after the in-memory
## change, so on-disk state stays in lockstep with in-memory state.
##
## Bloom filter is intentionally not persisted: it is rebuilt from the local
## history log on bootstrap. Async timers are likewise recomputed from the
## absolute timestamps stored in the repair buffer entries.
type
ChannelSnapshot* = object
## Returned by `loadAllForChannel` on bootstrap. Carries the entire
## per-channel state needed to repopulate a `ChannelContext`. The bloom
## filter is NOT in the snapshot — callers rebuild it from `messageHistory`.
lamportTimestamp*: int64
messageHistory*: seq[SdsMessage]
## MUST be ordered oldest-first. FIFO eviction relies on insertion order;
## skipping ORDER BY corrupts the log across restarts.
outgoingBuffer*: seq[UnacknowledgedMessage]
incomingBuffer*: seq[IncomingMessage]
outgoingRepairBuffer*: seq[(SdsMessageID, OutgoingRepairEntry)]
incomingRepairBuffer*: seq[(SdsMessageID, IncomingRepairEntry)]
Persistence* = object
## Pluggable persistence contract. The caller supplies an instance of this
## type at `newReliabilityManager` construction time. Each proc field is
## invoked by nim-sds at the corresponding state-mutation point.
# Per-channel lamport clock
saveLamport*:
proc(channelId: SdsChannelID, lamport: int64) {.gcsafe, raises: [].}
# Local log (delivered messages)
appendLogEntry*:
proc(channelId: SdsChannelID, msg: SdsMessage) {.gcsafe, raises: [].}
removeLogEntry*:
proc(channelId: SdsChannelID, msgId: SdsMessageID) {.gcsafe, raises: [].}
setRetrievalHint*:
proc(msgId: SdsMessageID, hint: seq[byte]) {.gcsafe, raises: [].}
# Outgoing unacknowledged buffer
saveOutgoing*:
proc(channelId: SdsChannelID, msg: UnacknowledgedMessage) {.gcsafe, raises: [].}
removeOutgoing*:
proc(channelId: SdsChannelID, msgId: SdsMessageID) {.gcsafe, raises: [].}
# Incoming dependency-waiting buffer
saveIncoming*:
proc(channelId: SdsChannelID, msg: IncomingMessage) {.gcsafe, raises: [].}
removeIncoming*:
proc(channelId: SdsChannelID, msgId: SdsMessageID) {.gcsafe, raises: [].}
# SDS-R outgoing repair buffer
saveOutgoingRepair*: proc(
channelId: SdsChannelID, msgId: SdsMessageID, entry: OutgoingRepairEntry
) {.gcsafe, raises: [].}
removeOutgoingRepair*:
proc(channelId: SdsChannelID, msgId: SdsMessageID) {.gcsafe, raises: [].}
# SDS-R incoming repair buffer
saveIncomingRepair*: proc(
channelId: SdsChannelID, msgId: SdsMessageID, entry: IncomingRepairEntry
) {.gcsafe, raises: [].}
removeIncomingRepair*:
proc(channelId: SdsChannelID, msgId: SdsMessageID) {.gcsafe, raises: [].}
# Wipe all persisted state for a channel in one transactional call.
# Called by removeChannel / resetReliabilityManager. Backends should
# implement this atomically (e.g. one BEGIN/COMMIT) — a per-row loop on
# the nim-sds side would mean N fsyncs per drop.
dropChannel*:
proc(channelId: SdsChannelID) {.gcsafe, raises: [].}
# Bootstrap on `addChannel` / `getOrCreateChannel`.
loadAllForChannel*:
proc(channelId: SdsChannelID): ChannelSnapshot {.gcsafe, raises: [].}
proc noOpPersistence*(): Persistence =
## Default backend that discards every write and returns an empty snapshot.
## Used so existing callers (and tests) that don't care about durability
## keep working without supplying a real backend.
Persistence(
saveLamport: proc(channelId: SdsChannelID, lamport: int64) =
discard,
appendLogEntry: proc(channelId: SdsChannelID, msg: SdsMessage) =
discard,
removeLogEntry: proc(channelId: SdsChannelID, msgId: SdsMessageID) =
discard,
setRetrievalHint: proc(msgId: SdsMessageID, hint: seq[byte]) =
discard,
saveOutgoing: proc(channelId: SdsChannelID, msg: UnacknowledgedMessage) =
discard,
removeOutgoing: proc(channelId: SdsChannelID, msgId: SdsMessageID) =
discard,
saveIncoming: proc(channelId: SdsChannelID, msg: IncomingMessage) =
discard,
removeIncoming: proc(channelId: SdsChannelID, msgId: SdsMessageID) =
discard,
saveOutgoingRepair: proc(
channelId: SdsChannelID, msgId: SdsMessageID, entry: OutgoingRepairEntry
) =
discard,
removeOutgoingRepair: proc(channelId: SdsChannelID, msgId: SdsMessageID) =
discard,
saveIncomingRepair: proc(
channelId: SdsChannelID, msgId: SdsMessageID, entry: IncomingRepairEntry
) =
discard,
removeIncomingRepair: proc(channelId: SdsChannelID, msgId: SdsMessageID) =
discard,
dropChannel: proc(channelId: SdsChannelID) =
discard,
loadAllForChannel: proc(channelId: SdsChannelID): ChannelSnapshot =
ChannelSnapshot(),
)

View File

@ -1,4 +1,5 @@
import std/times
import chronicles
const
DefaultMaxMessageHistory* = 1000
@ -49,6 +50,13 @@ proc init*(
maxRepairRequests: int = DefaultMaxRepairRequests,
repairSweepInterval: Duration = DefaultRepairSweepInterval,
): T =
# Bloom is rebuilt by replaying messageHistory on restart and is also the
# outgoing summary peers see. A bloom smaller than the log causes continuous
# clean() churn and incomplete summaries to peers, with no compensating gain.
if maxMessageHistory > bloomFilterCapacity:
warn "maxMessageHistory > bloomFilterCapacity will cause continuous bloom rebuilds and incomplete summaries to peers; reduce maxMessageHistory or increase bloomFilterCapacity unless you have a specific reason",
maxMessageHistory = maxMessageHistory,
bloomFilterCapacity = bloomFilterCapacity
return T(
bloomFilterCapacity: bloomFilterCapacity,
bloomFilterErrorRate: bloomFilterErrorRate,

View File

@ -4,12 +4,17 @@ import ./history_entry
import ./callbacks
import ./reliability_config
import ./channel_context
export sds_message_id, history_entry, callbacks, reliability_config, channel_context
import ./persistence
export
sds_message_id, history_entry, callbacks, reliability_config, channel_context,
persistence
type ReliabilityManager* = ref object
channels*: Table[SdsChannelID, ChannelContext]
config*: ReliabilityConfig
participantId*: SdsParticipantID
persistence*: Persistence
## Pluggable durability backend; defaults to a no-op when not supplied.
lock*: Lock
onMessageReady*: proc(messageId: SdsMessageID, channelId: SdsChannelID) {.gcsafe.}
onMessageSent*: proc(messageId: SdsMessageID, channelId: SdsChannelID) {.gcsafe.}
@ -24,11 +29,13 @@ proc new*(
T: type ReliabilityManager,
config: ReliabilityConfig,
participantId: SdsParticipantID = "".SdsParticipantID,
persistence: Persistence = noOpPersistence(),
): T =
let rm = T(
channels: initTable[SdsChannelID, ChannelContext](),
config: config,
participantId: participantId,
persistence: persistence,
)
rm.lock.initLock()
return rm

View File

@ -0,0 +1,127 @@
import std/tables
import sds
## Test-only Persistence backend backed by Nim tables. Lets tests verify the
## full write → restart → read-back loop without depending on SQLite (or any
## real storage technology). Exposes the underlying store so tests can assert
## on what got saved.
type InMemoryStore* = ref object
lamports*: Table[SdsChannelID, int64]
log*: Table[SdsChannelID, OrderedTable[SdsMessageID, SdsMessage]]
hints*: Table[SdsMessageID, seq[byte]]
outgoing*: Table[SdsChannelID, OrderedTable[SdsMessageID, UnacknowledgedMessage]]
incoming*: Table[SdsChannelID, OrderedTable[SdsMessageID, IncomingMessage]]
outgoingRepair*: Table[SdsChannelID, OrderedTable[SdsMessageID, OutgoingRepairEntry]]
incomingRepair*: Table[SdsChannelID, OrderedTable[SdsMessageID, IncomingRepairEntry]]
dropChannelCalls*: Table[SdsChannelID, int]
## Per-channel counter; lets tests assert dropChannel is invoked exactly
## once per logical drop (not N times — see PR #66 review).
proc newInMemoryStore*(): InMemoryStore =
InMemoryStore()
proc newInMemoryPersistence*(store: InMemoryStore): Persistence =
Persistence(
saveLamport: proc(channelId: SdsChannelID, lamport: int64) {.gcsafe, raises: [].} =
store.lamports[channelId] = lamport,
appendLogEntry: proc(channelId: SdsChannelID, msg: SdsMessage) {.gcsafe, raises: [].} =
{.cast(raises: []).}:
if channelId notin store.log:
store.log[channelId] = initOrderedTable[SdsMessageID, SdsMessage]()
store.log[channelId][msg.messageId] = msg,
removeLogEntry: proc(channelId: SdsChannelID, msgId: SdsMessageID) {.gcsafe, raises: [].} =
{.cast(raises: []).}:
if channelId in store.log:
store.log[channelId].del(msgId),
setRetrievalHint: proc(msgId: SdsMessageID, hint: seq[byte]) {.gcsafe, raises: [].} =
store.hints[msgId] = hint,
saveOutgoing: proc(channelId: SdsChannelID, msg: UnacknowledgedMessage) {.gcsafe, raises: [].} =
{.cast(raises: []).}:
if channelId notin store.outgoing:
store.outgoing[channelId] =
initOrderedTable[SdsMessageID, UnacknowledgedMessage]()
store.outgoing[channelId][msg.message.messageId] = msg,
removeOutgoing: proc(channelId: SdsChannelID, msgId: SdsMessageID) {.gcsafe, raises: [].} =
{.cast(raises: []).}:
if channelId in store.outgoing:
store.outgoing[channelId].del(msgId),
saveIncoming: proc(channelId: SdsChannelID, msg: IncomingMessage) {.gcsafe, raises: [].} =
{.cast(raises: []).}:
if channelId notin store.incoming:
store.incoming[channelId] =
initOrderedTable[SdsMessageID, IncomingMessage]()
store.incoming[channelId][msg.message.messageId] = msg,
removeIncoming: proc(channelId: SdsChannelID, msgId: SdsMessageID) {.gcsafe, raises: [].} =
{.cast(raises: []).}:
if channelId in store.incoming:
store.incoming[channelId].del(msgId),
saveOutgoingRepair: proc(
channelId: SdsChannelID, msgId: SdsMessageID, entry: OutgoingRepairEntry
) {.gcsafe, raises: [].} =
{.cast(raises: []).}:
if channelId notin store.outgoingRepair:
store.outgoingRepair[channelId] =
initOrderedTable[SdsMessageID, OutgoingRepairEntry]()
store.outgoingRepair[channelId][msgId] = entry,
removeOutgoingRepair: proc(channelId: SdsChannelID, msgId: SdsMessageID) {.gcsafe, raises: [].} =
{.cast(raises: []).}:
if channelId in store.outgoingRepair:
store.outgoingRepair[channelId].del(msgId),
saveIncomingRepair: proc(
channelId: SdsChannelID, msgId: SdsMessageID, entry: IncomingRepairEntry
) {.gcsafe, raises: [].} =
{.cast(raises: []).}:
if channelId notin store.incomingRepair:
store.incomingRepair[channelId] =
initOrderedTable[SdsMessageID, IncomingRepairEntry]()
store.incomingRepair[channelId][msgId] = entry,
removeIncomingRepair: proc(channelId: SdsChannelID, msgId: SdsMessageID) {.gcsafe, raises: [].} =
{.cast(raises: []).}:
if channelId in store.incomingRepair:
store.incomingRepair[channelId].del(msgId),
dropChannel: proc(channelId: SdsChannelID) {.gcsafe, raises: [].} =
{.cast(raises: []).}:
store.lamports.del(channelId)
store.log.del(channelId)
store.outgoing.del(channelId)
store.incoming.del(channelId)
store.outgoingRepair.del(channelId)
store.incomingRepair.del(channelId)
store.dropChannelCalls[channelId] =
store.dropChannelCalls.getOrDefault(channelId) + 1,
loadAllForChannel: proc(channelId: SdsChannelID): ChannelSnapshot {.gcsafe, raises: [].} =
{.cast(raises: []).}:
var snap = ChannelSnapshot()
if channelId in store.lamports:
snap.lamportTimestamp = store.lamports[channelId]
if channelId in store.log:
for msg in store.log[channelId].values:
snap.messageHistory.add(msg)
if channelId in store.outgoing:
for unack in store.outgoing[channelId].values:
snap.outgoingBuffer.add(unack)
if channelId in store.incoming:
for incoming in store.incoming[channelId].values:
snap.incomingBuffer.add(incoming)
if channelId in store.outgoingRepair:
for msgId, entry in store.outgoingRepair[channelId]:
snap.outgoingRepairBuffer.add((msgId, entry))
if channelId in store.incomingRepair:
for msgId, entry in store.incomingRepair[channelId]:
snap.incomingRepairBuffer.add((msgId, entry))
snap,
)

345
tests/test_persistence.nim Normal file
View File

@ -0,0 +1,345 @@
import unittest, results, std/[tables, sets, times]
import sds
import ./in_memory_persistence
converter toParticipantID(s: string): SdsParticipantID = s.SdsParticipantID
const testChannel = "testChannel"
suite "Persistence: write → restart → read-back":
test "outgoing buffer survives restart":
let store = newInMemoryStore()
let p1 = newInMemoryPersistence(store)
let rm1 = newReliabilityManager(persistence = p1).get()
check rm1.ensureChannel(testChannel).isOk()
let wrapped = rm1.wrapOutgoingMessage(@[1.byte, 2, 3], "msg-1", testChannel)
check wrapped.isOk()
check store.outgoing[testChannel].len == 1
check "msg-1" in store.outgoing[testChannel]
rm1.cleanup()
# Simulate restart: fresh manager, same backend
let p2 = newInMemoryPersistence(store)
let rm2 = newReliabilityManager(persistence = p2).get()
check rm2.ensureChannel(testChannel).isOk()
let buf = rm2.getOutgoingBuffer(testChannel)
check buf.len == 1
check buf[0].message.messageId == "msg-1"
rm2.cleanup()
test "lamport clock survives restart":
let store = newInMemoryStore()
let p1 = newInMemoryPersistence(store)
let rm1 = newReliabilityManager(persistence = p1).get()
check rm1.ensureChannel(testChannel).isOk()
rm1.updateLamportTimestamp(42, testChannel)
check store.lamports[testChannel] == 43 # max(42, 0) + 1
rm1.cleanup()
let p2 = newInMemoryPersistence(store)
let rm2 = newReliabilityManager(persistence = p2).get()
check rm2.ensureChannel(testChannel).isOk()
check rm2.channels[testChannel].lamportTimestamp == 43
test "delivered messages survive restart and rebuild bloom":
let store = newInMemoryStore()
let p1 = newInMemoryPersistence(store)
let rm1 = newReliabilityManager(persistence = p1).get()
check rm1.ensureChannel(testChannel).isOk()
let msg = SdsMessage.init(
messageId = "delivered-1",
lamportTimestamp = 1,
causalHistory = @[],
channelId = testChannel,
content = @[9.byte, 9],
bloomFilter = @[],
senderId = "alice",
)
rm1.addToHistory(msg, testChannel)
check store.log[testChannel].len == 1
rm1.cleanup()
let p2 = newInMemoryPersistence(store)
let rm2 = newReliabilityManager(persistence = p2).get()
check rm2.ensureChannel(testChannel).isOk()
let ch = rm2.channels[testChannel]
check ch.messageHistory.len == 1
check "delivered-1" in ch.messageHistory
# Bloom filter rebuilt from log on bootstrap
check ch.bloomFilter.contains("delivered-1")
test "ack removes outgoing entry from persistence":
let store = newInMemoryStore()
let p = newInMemoryPersistence(store)
let rm = newReliabilityManager(persistence = p).get()
check rm.ensureChannel(testChannel).isOk()
discard rm.wrapOutgoingMessage(@[1.byte], "msg-x", testChannel)
check "msg-x" in store.outgoing[testChannel]
# Synthesize an incoming message that ACKs msg-x via causal history
let ackMsg = SdsMessage.init(
messageId = "ack-bearer",
lamportTimestamp = 5,
causalHistory = @[HistoryEntry.init("msg-x", @[])],
channelId = testChannel,
content = @[],
bloomFilter = @[],
senderId = "bob",
)
let serialized = serializeMessage(ackMsg).get()
discard rm.unwrapReceivedMessage(serialized)
check "msg-x" notin store.outgoing[testChannel]
rm.cleanup()
test "removeChannel issues exactly one dropChannel call and wipes all state":
# Regression for PR #66 review: removal must be a single transactional
# drop, not N per-row removes — otherwise SQLite eats N fsyncs per drop.
let store = newInMemoryStore()
let p = newInMemoryPersistence(store)
let rm = newReliabilityManager(persistence = p).get()
check rm.ensureChannel(testChannel).isOk()
discard rm.wrapOutgoingMessage(@[1.byte], "msg-r", testChannel)
check store.outgoing[testChannel].len == 1
check store.lamports[testChannel] > 0
check rm.removeChannel(testChannel).isOk()
check store.dropChannelCalls.getOrDefault(testChannel) == 1
check testChannel notin store.outgoing
check testChannel notin store.lamports
check testChannel notin store.log
check testChannel notin store.incoming
check testChannel notin store.outgoingRepair
check testChannel notin store.incomingRepair
rm.cleanup()
test "noOpPersistence keeps existing manager working":
let rm = newReliabilityManager().get() # default no-op
check rm.ensureChannel(testChannel).isOk()
let wrapped = rm.wrapOutgoingMessage(@[1.byte], "msg-n", testChannel)
check wrapped.isOk()
check rm.getOutgoingBuffer(testChannel).len == 1
rm.cleanup()
test "continue operating after restart: lamport stays monotonic":
let store = newInMemoryStore()
let p1 = newInMemoryPersistence(store)
let rm1 = newReliabilityManager(persistence = p1).get()
check rm1.ensureChannel(testChannel).isOk()
discard rm1.wrapOutgoingMessage(@[1.byte], "m1", testChannel)
let lamportAfterSession1 = store.lamports[testChannel]
check lamportAfterSession1 > 0
rm1.cleanup()
# Restart and send another message — lamport must not regress.
let p2 = newInMemoryPersistence(store)
let rm2 = newReliabilityManager(persistence = p2).get()
check rm2.ensureChannel(testChannel).isOk()
check rm2.channels[testChannel].lamportTimestamp == lamportAfterSession1
discard rm2.wrapOutgoingMessage(@[2.byte], "m2", testChannel)
check store.lamports[testChannel] > lamportAfterSession1
check rm2.getOutgoingBuffer(testChannel).len == 2
rm2.cleanup()
test "multiple restart cycles preserve state":
let store = newInMemoryStore()
for i in 1 .. 3:
let p = newInMemoryPersistence(store)
let rm = newReliabilityManager(persistence = p).get()
check rm.ensureChannel(testChannel).isOk()
discard rm.wrapOutgoingMessage(@[byte(i)], "m" & $i, testChannel)
rm.cleanup()
# Final session: all three messages must be in the buffer.
let pFinal = newInMemoryPersistence(store)
let rmFinal = newReliabilityManager(persistence = pFinal).get()
check rmFinal.ensureChannel(testChannel).isOk()
let buf = rmFinal.getOutgoingBuffer(testChannel)
check buf.len == 3
var ids = newSeq[string]()
for unack in buf:
ids.add(unack.message.messageId.string)
check "m1" in ids
check "m2" in ids
check "m3" in ids
rmFinal.cleanup()
test "incoming dep-waiting buffer survives restart with missingDeps intact":
let store = newInMemoryStore()
let p1 = newInMemoryPersistence(store)
let rm1 = newReliabilityManager(persistence = p1).get()
check rm1.ensureChannel(testChannel).isOk()
# Receive a message whose causal-history references an unknown predecessor.
let depMsg = SdsMessage.init(
messageId = "msg-with-deps",
lamportTimestamp = 10,
causalHistory = @[HistoryEntry.init("missing-dep", @[])],
channelId = testChannel,
content = @[7.byte],
bloomFilter = @[],
senderId = "carol",
)
let serialized = serializeMessage(depMsg).get()
discard rm1.unwrapReceivedMessage(serialized)
check "msg-with-deps" in store.incoming[testChannel]
rm1.cleanup()
# Restart — buffered message and its missing-deps set must be back.
let p2 = newInMemoryPersistence(store)
let rm2 = newReliabilityManager(persistence = p2).get()
check rm2.ensureChannel(testChannel).isOk()
let inbuf = rm2.getIncomingBuffer(testChannel)
check "msg-with-deps" in inbuf
check "missing-dep" in inbuf["msg-with-deps"].missingDeps
rm2.cleanup()
test "removeChannel + recreate does not inherit stale lamport":
# Regression: dropChannel must wipe the lamport row; otherwise a recreate
# of the same channelId after restart picks up the old timestamp.
let store = newInMemoryStore()
let p1 = newInMemoryPersistence(store)
let rm1 = newReliabilityManager(persistence = p1).get()
check rm1.ensureChannel(testChannel).isOk()
discard rm1.wrapOutgoingMessage(@[1.byte], "m-old", testChannel)
check store.lamports[testChannel] > 0
check rm1.removeChannel(testChannel).isOk()
check testChannel notin store.lamports
rm1.cleanup()
# Recreate the same channelId after a restart — must start fresh.
let p2 = newInMemoryPersistence(store)
let rm2 = newReliabilityManager(persistence = p2).get()
check rm2.ensureChannel(testChannel).isOk()
check rm2.channels[testChannel].lamportTimestamp == 0
check rm2.getOutgoingBuffer(testChannel).len == 0
rm2.cleanup()
test "SDS-R outgoing repair buffer survives restart with absolute t_req_at":
let store = newInMemoryStore()
let p1 = newInMemoryPersistence(store)
let rm1 = newReliabilityManager(
participantId = "alice", persistence = p1
).get()
check rm1.ensureChannel(testChannel).isOk()
# Receive a message that references an unknown dep — triggers SDS-R repair.
let depMsg = SdsMessage.init(
messageId = "msg-needs-repair",
lamportTimestamp = 5,
causalHistory = @[HistoryEntry.init("missing-dep", @[])],
channelId = testChannel,
content = @[1.byte],
bloomFilter = @[],
senderId = "bob",
)
discard rm1.unwrapReceivedMessage(serializeMessage(depMsg).get())
check "missing-dep" in store.outgoingRepair[testChannel]
let originalTReqAt = store.outgoingRepair[testChannel]["missing-dep"].minTimeRepairReq
check originalTReqAt.toUnix > 0
rm1.cleanup()
# Restart — repair entry must be back with the SAME absolute time, not "now".
let p2 = newInMemoryPersistence(store)
let rm2 = newReliabilityManager(
participantId = "alice", persistence = p2
).get()
check rm2.ensureChannel(testChannel).isOk()
let buf = rm2.channels[testChannel].outgoingRepairBuffer
check "missing-dep" in buf
check buf["missing-dep"].minTimeRepairReq == originalTReqAt
rm2.cleanup()
test "FIFO eviction state survives restart":
let store = newInMemoryStore()
var smallCfg = defaultConfig()
smallCfg.maxMessageHistory = 3
smallCfg.bloomFilterCapacity = 3
let p1 = newInMemoryPersistence(store)
let rm1 = newReliabilityManager(config = smallCfg, persistence = p1).get()
check rm1.ensureChannel(testChannel).isOk()
# Add 5 delivered messages — first 2 should be evicted by FIFO.
for i in 1 .. 5:
let m = SdsMessage.init(
messageId = "m" & $i,
lamportTimestamp = int64(i),
causalHistory = @[],
channelId = testChannel,
content = @[byte(i)],
bloomFilter = @[],
senderId = "alice",
)
rm1.addToHistory(m, testChannel)
check store.log[testChannel].len == 3
check "m1" notin store.log[testChannel]
check "m2" notin store.log[testChannel]
rm1.cleanup()
# Restart — evicted entries must NOT come back; survivors keep order.
let p2 = newInMemoryPersistence(store)
let rm2 = newReliabilityManager(config = smallCfg, persistence = p2).get()
check rm2.ensureChannel(testChannel).isOk()
let history = rm2.channels[testChannel].messageHistory
check history.len == 3
check "m1" notin history
check "m2" notin history
check "m3" in history
check "m5" in history
# FIFO continues correctly after restart: adding m6 evicts m3, not a stale entry.
let m6 = SdsMessage.init(
messageId = "m6", lamportTimestamp = 6, causalHistory = @[],
channelId = testChannel, content = @[6.byte],
bloomFilter = @[], senderId = "alice",
)
rm2.addToHistory(m6, testChannel)
check "m3" notin store.log[testChannel]
check "m6" in store.log[testChannel]
rm2.cleanup()
test "dep-clear cascade resumes correctly across a restart":
let store = newInMemoryStore()
let p1 = newInMemoryPersistence(store)
let rm1 = newReliabilityManager(persistence = p1).get()
check rm1.ensureChannel(testChannel).isOk()
# Receive c (deps on b), then b (deps on a). Both must buffer.
let msgC = SdsMessage.init(
messageId = "c", lamportTimestamp = 30,
causalHistory = @[HistoryEntry.init("b", @[])],
channelId = testChannel, content = @[3.byte],
bloomFilter = @[], senderId = "carol",
)
let msgB = SdsMessage.init(
messageId = "b", lamportTimestamp = 20,
causalHistory = @[HistoryEntry.init("a", @[])],
channelId = testChannel, content = @[2.byte],
bloomFilter = @[], senderId = "bob",
)
discard rm1.unwrapReceivedMessage(serializeMessage(msgC).get())
discard rm1.unwrapReceivedMessage(serializeMessage(msgB).get())
check "c" in store.incoming[testChannel]
check "b" in store.incoming[testChannel]
rm1.cleanup()
# Restart — both still buffered, with intact missingDeps.
let p2 = newInMemoryPersistence(store)
let rm2 = newReliabilityManager(persistence = p2).get()
check rm2.ensureChannel(testChannel).isOk()
let inbuf = rm2.getIncomingBuffer(testChannel)
check "c" in inbuf
check "b" in inbuf
# Now receive a (root) — should cascade-deliver a, b, c.
let msgA = SdsMessage.init(
messageId = "a", lamportTimestamp = 10, causalHistory = @[],
channelId = testChannel, content = @[1.byte],
bloomFilter = @[], senderId = "alice",
)
discard rm2.unwrapReceivedMessage(serializeMessage(msgA).get())
let history = rm2.channels[testChannel].messageHistory
check "a" in history
check "b" in history
check "c" in history
# Buffer should be drained.
check rm2.getIncomingBuffer(testChannel).len == 0
rm2.cleanup()