From c977124a7e991dd8874210c421abffaeba780d82 Mon Sep 17 00:00:00 2001 From: darshankabariya Date: Mon, 4 May 2026 03:51:05 +0530 Subject: [PATCH] feat: feat: persistence interface for SDS state --- sds.nim | 67 +++++++++++++---- sds.nimble | 1 + sds/sds_utils.nim | 47 +++++++++++- sds/types.nim | 2 + sds/types/persistence.nim | 117 ++++++++++++++++++++++++++++++ sds/types/reliability_manager.nim | 9 ++- tests/in_memory_persistence.nim | 113 +++++++++++++++++++++++++++++ tests/test_persistence.nim | 111 ++++++++++++++++++++++++++++ 8 files changed, 450 insertions(+), 17 deletions(-) create mode 100644 sds/types/persistence.nim create mode 100644 tests/in_memory_persistence.nim create mode 100644 tests/test_persistence.nim diff --git a/sds.nim b/sds.nim index 202b489..30392e1 100644 --- a/sds.nim +++ b/sds.nim @@ -7,10 +7,13 @@ export types, protobuf, sds_utils, rolling_bloom_filter proc newReliabilityManager*( config: ReliabilityConfig = defaultConfig(), participantId: SdsParticipantID = "".SdsParticipantID, + persistence: Persistence = noOpPersistence(), ): Result[ReliabilityManager, ReliabilityError] = ## Creates a new multi-channel ReliabilityManager. + ## `persistence` defaults to a no-op backend; supply a real one to durably + ## store SDS state across restarts. try: - let rm = ReliabilityManager.new(config, participantId) + let rm = ReliabilityManager.new(config, participantId, persistence) return ok(rm) except Exception: error "Failed to create ReliabilityManager", msg = getCurrentExceptionMsg() @@ -53,7 +56,7 @@ proc reviewAckStatus(rm: ReliabilityManager, msg: SdsMessage) {.gcsafe.} = return let channel = rm.channels[msg.channelId] - var toDelete: seq[int] = @[] + var toDelete: seq[(int, SdsMessageID)] = @[] var i = 0 while i < channel.outgoingBuffer.len: @@ -61,11 +64,13 @@ proc reviewAckStatus(rm: ReliabilityManager, msg: SdsMessage) {.gcsafe.} = if outMsg.isAcknowledged(msg.causalHistory, rbf): if not rm.onMessageSent.isNil(): rm.onMessageSent(outMsg.message.messageId, outMsg.message.channelId) - toDelete.add(i) + toDelete.add((i, outMsg.message.messageId)) inc i - for i in countdown(toDelete.high, 0): - channel.outgoingBuffer.delete(toDelete[i]) + for k in countdown(toDelete.high, 0): + let (idx, ackedId) = toDelete[k] + channel.outgoingBuffer.delete(idx) + rm.persistence.removeOutgoing(msg.channelId, ackedId) proc wrapOutgoingMessage*( rm: ReliabilityManager, @@ -108,6 +113,7 @@ proc wrapOutgoingMessage*( expiredKeys.add(eligible[i][0]) for key in expiredKeys: channel.outgoingRepairBuffer.del(key) + rm.persistence.removeOutgoingRepair(channelId, key) let msg = SdsMessage.init( messageId = messageId, @@ -120,9 +126,10 @@ proc wrapOutgoingMessage*( repairRequest = repairReqs, ) - channel.outgoingBuffer.add( + let unackMsg = UnacknowledgedMessage.init(message = msg, sendTime = getTime(), resendAttempts = 0) - ) + channel.outgoingBuffer.add(unackMsg) + rm.persistence.saveOutgoing(channelId, unackMsg) channel.bloomFilter.add(msg.messageId) # The full SdsMessage carries senderId and content, so a single @@ -168,11 +175,15 @@ proc processIncomingBuffer(rm: ReliabilityManager, channelId: SdsChannelID) {.gc if remainingId notin processed: if msgId in entry.missingDeps: channel.incomingBuffer[remainingId].missingDeps.excl(msgId) + rm.persistence.saveIncoming( + channelId, channel.incomingBuffer[remainingId] + ) if channel.incomingBuffer[remainingId].missingDeps.len == 0: readyToProcess.add(remainingId) for msgId in processed: channel.incomingBuffer.del(msgId) + rm.persistence.removeIncoming(channelId, msgId) proc unwrapReceivedMessage*( rm: ReliabilityManager, message: seq[byte] @@ -193,7 +204,9 @@ proc unwrapReceivedMessage*( # SDS-R: opportunistic repair-buffer cleanup — applies to duplicates too, # so rebroadcasts cancel redundant responses on peers that already have the message. channel.outgoingRepairBuffer.del(msg.messageId) + rm.persistence.removeOutgoingRepair(channelId, msg.messageId) channel.incomingRepairBuffer.del(msg.messageId) + rm.persistence.removeIncomingRepair(channelId, msg.messageId) if msg.messageId in channel.messageHistory: return ok((msg.content, @[], channelId)) @@ -211,6 +224,7 @@ proc unwrapReceivedMessage*( for repairEntry in msg.repairRequest: # Remove from our own outgoing repair buffer (someone else is also requesting) channel.outgoingRepairBuffer.del(repairEntry.messageId) + rm.persistence.removeOutgoingRepair(channelId, repairEntry.messageId) if repairEntry.messageId in channel.messageHistory and rm.participantId.len > 0 and repairEntry.senderId.len > 0: if isInResponseGroup( @@ -223,11 +237,13 @@ proc unwrapReceivedMessage*( rm.participantId, repairEntry.senderId, repairEntry.messageId, rm.config.repairTMax ) - channel.incomingRepairBuffer[repairEntry.messageId] = IncomingRepairEntry( + let inEntry = IncomingRepairEntry( inHistEntry: repairEntry, cachedMessage: serialized.get(), minTimeRepairResp: now + tResp, ) + channel.incomingRepairBuffer[repairEntry.messageId] = inEntry + rm.persistence.saveIncomingRepair(channelId, repairEntry.messageId, inEntry) var missingDeps = rm.checkDependencies(msg.causalHistory, channelId) @@ -238,23 +254,30 @@ proc unwrapReceivedMessage*( depsInBuffer = true break if depsInBuffer: - channel.incomingBuffer[msg.messageId] = + let entry = IncomingMessage.init(message = msg, missingDeps = initHashSet[SdsMessageID]()) + channel.incomingBuffer[msg.messageId] = entry + rm.persistence.saveIncoming(channelId, entry) else: rm.addToHistory(msg, channelId) # Unblock any buffered messages that were waiting on this one. + var unblocked: seq[SdsMessageID] = @[] for pendingId, entry in channel.incomingBuffer: if msg.messageId in entry.missingDeps: channel.incomingBuffer[pendingId].missingDeps.excl(msg.messageId) + unblocked.add(pendingId) + for pendingId in unblocked: + rm.persistence.saveIncoming(channelId, channel.incomingBuffer[pendingId]) rm.processIncomingBuffer(channelId) if not rm.onMessageReady.isNil(): rm.onMessageReady(msg.messageId, channelId) else: - channel.incomingBuffer[msg.messageId] = - IncomingMessage.init( - message = msg, - missingDeps = missingDeps.getMessageIds().toHashSet(), - ) + let entry = IncomingMessage.init( + message = msg, + missingDeps = missingDeps.getMessageIds().toHashSet(), + ) + channel.incomingBuffer[msg.messageId] = entry + rm.persistence.saveIncoming(channelId, entry) if not rm.onMissingDependencies.isNil(): rm.onMissingDependencies(msg.messageId, missingDeps, channelId) @@ -266,10 +289,12 @@ proc unwrapReceivedMessage*( rm.participantId, dep.messageId, rm.config.repairTMin, rm.config.repairTMax ) - channel.outgoingRepairBuffer[dep.messageId] = OutgoingRepairEntry( + let outEntry = OutgoingRepairEntry( outHistEntry: dep, minTimeRepairReq: now + tReq, ) + channel.outgoingRepairBuffer[dep.messageId] = outEntry + rm.persistence.saveOutgoingRepair(channelId, dep.messageId, outEntry) return ok((msg.content, missingDeps, channelId)) except Exception: @@ -290,13 +315,19 @@ proc markDependenciesMet*( if not channel.bloomFilter.contains(msgId): channel.bloomFilter.add(msgId) + var unblocked: seq[SdsMessageID] = @[] for pendingId, entry in channel.incomingBuffer: if msgId in entry.missingDeps: channel.incomingBuffer[pendingId].missingDeps.excl(msgId) + unblocked.add(pendingId) + for pendingId in unblocked: + rm.persistence.saveIncoming(channelId, channel.incomingBuffer[pendingId]) # SDS-R: clear from repair buffers (dependency now met) channel.outgoingRepairBuffer.del(msgId) + rm.persistence.removeOutgoingRepair(channelId, msgId) channel.incomingRepairBuffer.del(msgId) + rm.persistence.removeIncomingRepair(channelId, msgId) rm.processIncomingBuffer(channelId) return ok() @@ -343,9 +374,11 @@ proc checkUnacknowledgedMessages( updatedMsg.resendAttempts += 1 updatedMsg.sendTime = now newOutgoingBuffer.add(updatedMsg) + rm.persistence.saveOutgoing(channelId, updatedMsg) else: if not rm.onMessageSent.isNil(): rm.onMessageSent(unackMsg.message.messageId, channelId) + rm.persistence.removeOutgoing(channelId, unackMsg.message.messageId) else: newOutgoingBuffer.add(unackMsg) @@ -400,6 +433,7 @@ proc runRepairSweep*(rm: ReliabilityManager) {.gcsafe, raises: [].} = for msgId in toRebroadcast: let entry = channel.incomingRepairBuffer[msgId] channel.incomingRepairBuffer.del(msgId) + rm.persistence.removeIncomingRepair(channelId, msgId) if not rm.onRepairReady.isNil(): rm.onRepairReady(entry.cachedMessage, channelId) @@ -411,6 +445,7 @@ proc runRepairSweep*(rm: ReliabilityManager) {.gcsafe, raises: [].} = toRemove.add(msgId) for msgId in toRemove: channel.outgoingRepairBuffer.del(msgId) + rm.persistence.removeOutgoingRepair(channelId, msgId) except Exception: error "Error in repair sweep for channel", channelId = channelId, msg = getCurrentExceptionMsg() @@ -436,6 +471,8 @@ proc resetReliabilityManager*(rm: ReliabilityManager): Result[void, ReliabilityE withLock rm.lock: try: for channelId, channel in rm.channels: + rm.dropChannelFromPersistence(channelId, channel) + rm.persistence.saveLamport(channelId, 0) channel.lamportTimestamp = 0 channel.messageHistory.clear() channel.outgoingBuffer.setLen(0) diff --git a/sds.nimble b/sds.nimble index fb5f43d..a3c75b3 100644 --- a/sds.nimble +++ b/sds.nimble @@ -67,6 +67,7 @@ proc getMyCpu(): string = task test, "Run the test suite": exec "nim c -r tests/test_bloom.nim" exec "nim c -r tests/test_reliability.nim" + exec "nim c -r tests/test_persistence.nim" task libsdsDynamicWindows, "Generate bindings": let outLibNameAndExt = "libsds.dll" diff --git a/sds/sds_utils.nim b/sds/sds_utils.nim index eefae43..bf1fc8c 100644 --- a/sds/sds_utils.nim +++ b/sds/sds_utils.nim @@ -14,7 +14,28 @@ export proc defaultConfig*(): ReliabilityConfig = return ReliabilityConfig.init() +proc dropChannelFromPersistence*( + rm: ReliabilityManager, channelId: SdsChannelID, channel: ChannelContext +) {.gcsafe, raises: [].} = + ## Fires per-entry remove calls for every buffered entry in the channel. + ## Called by cleanup / removeChannel / resetReliabilityManager before they + ## wipe in-memory state, so on-disk state stays consistent. + for unack in channel.outgoingBuffer: + rm.persistence.removeOutgoing(channelId, unack.message.messageId) + for msgId in channel.incomingBuffer.keys: + rm.persistence.removeIncoming(channelId, msgId) + for msgId in channel.messageHistory.keys: + rm.persistence.removeLogEntry(channelId, msgId) + for msgId in channel.outgoingRepairBuffer.keys: + rm.persistence.removeOutgoingRepair(channelId, msgId) + for msgId in channel.incomingRepairBuffer.keys: + rm.persistence.removeIncomingRepair(channelId, msgId) + proc cleanup*(rm: ReliabilityManager) {.raises: [].} = + ## Releases in-memory state. Does NOT wipe persistence — the manager may be + ## reconstructed against the same backend after cleanup, so disk state must + ## survive. For deliberate disk wipe, use `removeChannel` or + ## `resetReliabilityManager`. if not rm.isNil(): try: withLock rm.lock: @@ -50,12 +71,14 @@ proc addToHistory*( if channelId in rm.channels: let channel = rm.channels[channelId] channel.messageHistory[msg.messageId] = msg + rm.persistence.appendLogEntry(channelId, msg) while channel.messageHistory.len > rm.config.maxMessageHistory: var firstKey: SdsMessageID for k in channel.messageHistory.keys: firstKey = k break channel.messageHistory.del(firstKey) + rm.persistence.removeLogEntry(channelId, firstKey) except Exception: error "Failed to add to history", channelId = channelId, msgId = msg.messageId, error = getCurrentExceptionMsg() @@ -67,6 +90,7 @@ proc updateLamportTimestamp*( if channelId in rm.channels: let channel = rm.channels[channelId] channel.lamportTimestamp = max(msgTs, channel.lamportTimestamp) + 1 + rm.persistence.saveLamport(channelId, channel.lamportTimestamp) except Exception: error "Failed to update lamport timestamp", channelId = channelId, msgTs = msgTs, error = getCurrentExceptionMsg() @@ -150,6 +174,8 @@ proc getRecentHistoryEntries*( var entry = HistoryEntry(messageId: msgId) if not rm.onRetrievalHint.isNil(): entry.retrievalHint = rm.onRetrievalHint(msgId) + if entry.retrievalHint.len > 0: + rm.persistence.setRetrievalHint(msgId, entry.retrievalHint) entry.senderId = channel.messageHistory[msgId].senderId entries.add(entry) return entries @@ -226,11 +252,29 @@ proc getIncomingBuffer*( proc getOrCreateChannel*( rm: ReliabilityManager, channelId: SdsChannelID ): ChannelContext = + ## Returns the channel context, creating and bootstrapping it from the + ## persistence backend if it does not yet exist in memory. The bloom filter + ## is rebuilt deterministically from the loaded message history rather than + ## persisted directly. Caller is expected to hold rm.lock. try: if channelId notin rm.channels: - rm.channels[channelId] = ChannelContext.new( + let channel = ChannelContext.new( RollingBloomFilter.init(rm.config.bloomFilterCapacity, rm.config.bloomFilterErrorRate) ) + let snapshot = rm.persistence.loadAllForChannel(channelId) + channel.lamportTimestamp = snapshot.lamportTimestamp + for msg in snapshot.messageHistory: + channel.messageHistory[msg.messageId] = msg + channel.bloomFilter.add(msg.messageId) + for unack in snapshot.outgoingBuffer: + channel.outgoingBuffer.add(unack) + for incoming in snapshot.incomingBuffer: + channel.incomingBuffer[incoming.message.messageId] = incoming + for (msgId, entry) in snapshot.outgoingRepairBuffer: + channel.outgoingRepairBuffer[msgId] = entry + for (msgId, entry) in snapshot.incomingRepairBuffer: + channel.incomingRepairBuffer[msgId] = entry + rm.channels[channelId] = channel return rm.channels[channelId] except Exception: error "Failed to get or create channel", @@ -256,6 +300,7 @@ proc removeChannel*( try: if channelId in rm.channels: let channel = rm.channels[channelId] + rm.dropChannelFromPersistence(channelId, channel) channel.outgoingBuffer.setLen(0) channel.incomingBuffer.clear() channel.messageHistory.clear() diff --git a/sds/types.nim b/sds/types.nim index 637ec54..3402993 100644 --- a/sds/types.nim +++ b/sds/types.nim @@ -11,6 +11,7 @@ import sds/types/app_callbacks import sds/types/reliability_config import sds/types/repair_entry import sds/types/channel_context +import sds/types/persistence import sds/types/reliability_manager import sds/types/protobuf_error @@ -28,5 +29,6 @@ export reliability_config, repair_entry, channel_context, + persistence, reliability_manager, protobuf_error diff --git a/sds/types/persistence.nim b/sds/types/persistence.nim new file mode 100644 index 0000000..865b5f0 --- /dev/null +++ b/sds/types/persistence.nim @@ -0,0 +1,117 @@ +import ./sds_message_id +import ./sds_message +import ./unacknowledged_message +import ./incoming_message +import ./repair_entry +export + sds_message_id, sds_message, unacknowledged_message, incoming_message, repair_entry + +## SDS state persistence interface (issue #64). +## +## Defines WHAT operations a persistence backend must provide. The actual +## storage technology (SQLite, encrypted file, in-memory) is supplied by the +## caller — nim-sds knows nothing about it. Every state-mutating proc in the +## protocol calls into one of these procs immediately after the in-memory +## change, so on-disk state stays in lockstep with in-memory state. +## +## Bloom filter is intentionally not persisted: it is rebuilt from the local +## history log on bootstrap. Async timers are likewise recomputed from the +## absolute timestamps stored in the repair buffer entries. + +type + ChannelSnapshot* = object + ## Returned by `loadAllForChannel` on bootstrap. Carries the entire + ## per-channel state needed to repopulate a `ChannelContext`. The bloom + ## filter is NOT in the snapshot — callers rebuild it from `messageHistory`. + lamportTimestamp*: int64 + messageHistory*: seq[SdsMessage] + ## Delivered messages, oldest first (insertion order preserved for + ## causal-history tail access and FIFO eviction). + outgoingBuffer*: seq[UnacknowledgedMessage] + incomingBuffer*: seq[IncomingMessage] + outgoingRepairBuffer*: seq[(SdsMessageID, OutgoingRepairEntry)] + incomingRepairBuffer*: seq[(SdsMessageID, IncomingRepairEntry)] + + Persistence* = object + ## Pluggable persistence contract. The caller supplies an instance of this + ## type at `newReliabilityManager` construction time. Each proc field is + ## invoked by nim-sds at the corresponding state-mutation point. + + # Per-channel lamport clock + saveLamport*: + proc(channelId: SdsChannelID, lamport: int64) {.gcsafe, raises: [].} + + # Local log (delivered messages) + appendLogEntry*: + proc(channelId: SdsChannelID, msg: SdsMessage) {.gcsafe, raises: [].} + removeLogEntry*: + proc(channelId: SdsChannelID, msgId: SdsMessageID) {.gcsafe, raises: [].} + setRetrievalHint*: + proc(msgId: SdsMessageID, hint: seq[byte]) {.gcsafe, raises: [].} + + # Outgoing unacknowledged buffer + saveOutgoing*: + proc(channelId: SdsChannelID, msg: UnacknowledgedMessage) {.gcsafe, raises: [].} + removeOutgoing*: + proc(channelId: SdsChannelID, msgId: SdsMessageID) {.gcsafe, raises: [].} + + # Incoming dependency-waiting buffer + saveIncoming*: + proc(channelId: SdsChannelID, msg: IncomingMessage) {.gcsafe, raises: [].} + removeIncoming*: + proc(channelId: SdsChannelID, msgId: SdsMessageID) {.gcsafe, raises: [].} + + # SDS-R outgoing repair buffer + saveOutgoingRepair*: proc( + channelId: SdsChannelID, msgId: SdsMessageID, entry: OutgoingRepairEntry + ) {.gcsafe, raises: [].} + removeOutgoingRepair*: + proc(channelId: SdsChannelID, msgId: SdsMessageID) {.gcsafe, raises: [].} + + # SDS-R incoming repair buffer + saveIncomingRepair*: proc( + channelId: SdsChannelID, msgId: SdsMessageID, entry: IncomingRepairEntry + ) {.gcsafe, raises: [].} + removeIncomingRepair*: + proc(channelId: SdsChannelID, msgId: SdsMessageID) {.gcsafe, raises: [].} + + # Bootstrap on `addChannel` / `getOrCreateChannel` + loadAllForChannel*: + proc(channelId: SdsChannelID): ChannelSnapshot {.gcsafe, raises: [].} + +proc noOpPersistence*(): Persistence = + ## Default backend that discards every write and returns an empty snapshot. + ## Used so existing callers (and tests) that don't care about durability + ## keep working without supplying a real backend. + Persistence( + saveLamport: proc(channelId: SdsChannelID, lamport: int64) = + discard, + appendLogEntry: proc(channelId: SdsChannelID, msg: SdsMessage) = + discard, + removeLogEntry: proc(channelId: SdsChannelID, msgId: SdsMessageID) = + discard, + setRetrievalHint: proc(msgId: SdsMessageID, hint: seq[byte]) = + discard, + saveOutgoing: proc(channelId: SdsChannelID, msg: UnacknowledgedMessage) = + discard, + removeOutgoing: proc(channelId: SdsChannelID, msgId: SdsMessageID) = + discard, + saveIncoming: proc(channelId: SdsChannelID, msg: IncomingMessage) = + discard, + removeIncoming: proc(channelId: SdsChannelID, msgId: SdsMessageID) = + discard, + saveOutgoingRepair: proc( + channelId: SdsChannelID, msgId: SdsMessageID, entry: OutgoingRepairEntry + ) = + discard, + removeOutgoingRepair: proc(channelId: SdsChannelID, msgId: SdsMessageID) = + discard, + saveIncomingRepair: proc( + channelId: SdsChannelID, msgId: SdsMessageID, entry: IncomingRepairEntry + ) = + discard, + removeIncomingRepair: proc(channelId: SdsChannelID, msgId: SdsMessageID) = + discard, + loadAllForChannel: proc(channelId: SdsChannelID): ChannelSnapshot = + ChannelSnapshot(), + ) diff --git a/sds/types/reliability_manager.nim b/sds/types/reliability_manager.nim index d28ee5d..2c12596 100644 --- a/sds/types/reliability_manager.nim +++ b/sds/types/reliability_manager.nim @@ -4,12 +4,17 @@ import ./history_entry import ./callbacks import ./reliability_config import ./channel_context -export sds_message_id, history_entry, callbacks, reliability_config, channel_context +import ./persistence +export + sds_message_id, history_entry, callbacks, reliability_config, channel_context, + persistence type ReliabilityManager* = ref object channels*: Table[SdsChannelID, ChannelContext] config*: ReliabilityConfig participantId*: SdsParticipantID + persistence*: Persistence + ## Pluggable durability backend; defaults to a no-op when not supplied. lock*: Lock onMessageReady*: proc(messageId: SdsMessageID, channelId: SdsChannelID) {.gcsafe.} onMessageSent*: proc(messageId: SdsMessageID, channelId: SdsChannelID) {.gcsafe.} @@ -24,11 +29,13 @@ proc new*( T: type ReliabilityManager, config: ReliabilityConfig, participantId: SdsParticipantID = "".SdsParticipantID, + persistence: Persistence = noOpPersistence(), ): T = let rm = T( channels: initTable[SdsChannelID, ChannelContext](), config: config, participantId: participantId, + persistence: persistence, ) rm.lock.initLock() return rm diff --git a/tests/in_memory_persistence.nim b/tests/in_memory_persistence.nim new file mode 100644 index 0000000..3f44909 --- /dev/null +++ b/tests/in_memory_persistence.nim @@ -0,0 +1,113 @@ +import std/tables +import sds + +## Test-only Persistence backend backed by Nim tables. Lets tests verify the +## full write → restart → read-back loop without depending on SQLite (or any +## real storage technology). Exposes the underlying store so tests can assert +## on what got saved. + +type InMemoryStore* = ref object + lamports*: Table[SdsChannelID, int64] + log*: Table[SdsChannelID, OrderedTable[SdsMessageID, SdsMessage]] + hints*: Table[SdsMessageID, seq[byte]] + outgoing*: Table[SdsChannelID, OrderedTable[SdsMessageID, UnacknowledgedMessage]] + incoming*: Table[SdsChannelID, OrderedTable[SdsMessageID, IncomingMessage]] + outgoingRepair*: Table[SdsChannelID, OrderedTable[SdsMessageID, OutgoingRepairEntry]] + incomingRepair*: Table[SdsChannelID, OrderedTable[SdsMessageID, IncomingRepairEntry]] + +proc newInMemoryStore*(): InMemoryStore = + InMemoryStore() + +proc newInMemoryPersistence*(store: InMemoryStore): Persistence = + Persistence( + saveLamport: proc(channelId: SdsChannelID, lamport: int64) {.gcsafe, raises: [].} = + store.lamports[channelId] = lamport, + + appendLogEntry: proc(channelId: SdsChannelID, msg: SdsMessage) {.gcsafe, raises: [].} = + {.cast(raises: []).}: + if channelId notin store.log: + store.log[channelId] = initOrderedTable[SdsMessageID, SdsMessage]() + store.log[channelId][msg.messageId] = msg, + + removeLogEntry: proc(channelId: SdsChannelID, msgId: SdsMessageID) {.gcsafe, raises: [].} = + {.cast(raises: []).}: + if channelId in store.log: + store.log[channelId].del(msgId), + + setRetrievalHint: proc(msgId: SdsMessageID, hint: seq[byte]) {.gcsafe, raises: [].} = + store.hints[msgId] = hint, + + saveOutgoing: proc(channelId: SdsChannelID, msg: UnacknowledgedMessage) {.gcsafe, raises: [].} = + {.cast(raises: []).}: + if channelId notin store.outgoing: + store.outgoing[channelId] = + initOrderedTable[SdsMessageID, UnacknowledgedMessage]() + store.outgoing[channelId][msg.message.messageId] = msg, + + removeOutgoing: proc(channelId: SdsChannelID, msgId: SdsMessageID) {.gcsafe, raises: [].} = + {.cast(raises: []).}: + if channelId in store.outgoing: + store.outgoing[channelId].del(msgId), + + saveIncoming: proc(channelId: SdsChannelID, msg: IncomingMessage) {.gcsafe, raises: [].} = + {.cast(raises: []).}: + if channelId notin store.incoming: + store.incoming[channelId] = + initOrderedTable[SdsMessageID, IncomingMessage]() + store.incoming[channelId][msg.message.messageId] = msg, + + removeIncoming: proc(channelId: SdsChannelID, msgId: SdsMessageID) {.gcsafe, raises: [].} = + {.cast(raises: []).}: + if channelId in store.incoming: + store.incoming[channelId].del(msgId), + + saveOutgoingRepair: proc( + channelId: SdsChannelID, msgId: SdsMessageID, entry: OutgoingRepairEntry + ) {.gcsafe, raises: [].} = + {.cast(raises: []).}: + if channelId notin store.outgoingRepair: + store.outgoingRepair[channelId] = + initOrderedTable[SdsMessageID, OutgoingRepairEntry]() + store.outgoingRepair[channelId][msgId] = entry, + + removeOutgoingRepair: proc(channelId: SdsChannelID, msgId: SdsMessageID) {.gcsafe, raises: [].} = + {.cast(raises: []).}: + if channelId in store.outgoingRepair: + store.outgoingRepair[channelId].del(msgId), + + saveIncomingRepair: proc( + channelId: SdsChannelID, msgId: SdsMessageID, entry: IncomingRepairEntry + ) {.gcsafe, raises: [].} = + {.cast(raises: []).}: + if channelId notin store.incomingRepair: + store.incomingRepair[channelId] = + initOrderedTable[SdsMessageID, IncomingRepairEntry]() + store.incomingRepair[channelId][msgId] = entry, + + removeIncomingRepair: proc(channelId: SdsChannelID, msgId: SdsMessageID) {.gcsafe, raises: [].} = + {.cast(raises: []).}: + if channelId in store.incomingRepair: + store.incomingRepair[channelId].del(msgId), + + loadAllForChannel: proc(channelId: SdsChannelID): ChannelSnapshot {.gcsafe, raises: [].} = + {.cast(raises: []).}: + var snap = ChannelSnapshot() + if channelId in store.lamports: + snap.lamportTimestamp = store.lamports[channelId] + if channelId in store.log: + for msg in store.log[channelId].values: + snap.messageHistory.add(msg) + if channelId in store.outgoing: + for unack in store.outgoing[channelId].values: + snap.outgoingBuffer.add(unack) + if channelId in store.incoming: + for incoming in store.incoming[channelId].values: + snap.incomingBuffer.add(incoming) + if channelId in store.outgoingRepair: + for msgId, entry in store.outgoingRepair[channelId]: + snap.outgoingRepairBuffer.add((msgId, entry)) + if channelId in store.incomingRepair: + for msgId, entry in store.incomingRepair[channelId]: + snap.incomingRepairBuffer.add((msgId, entry)) + snap, + ) diff --git a/tests/test_persistence.nim b/tests/test_persistence.nim new file mode 100644 index 0000000..020996e --- /dev/null +++ b/tests/test_persistence.nim @@ -0,0 +1,111 @@ +import unittest, results, std/tables +import sds +import ./in_memory_persistence + +converter toParticipantID(s: string): SdsParticipantID = s.SdsParticipantID + +const testChannel = "testChannel" + +suite "Persistence: write → restart → read-back": + test "outgoing buffer survives restart": + let store = newInMemoryStore() + let p1 = newInMemoryPersistence(store) + let rm1 = newReliabilityManager(persistence = p1).get() + check rm1.ensureChannel(testChannel).isOk() + let wrapped = rm1.wrapOutgoingMessage(@[1.byte, 2, 3], "msg-1", testChannel) + check wrapped.isOk() + check store.outgoing[testChannel].len == 1 + check "msg-1" in store.outgoing[testChannel] + rm1.cleanup() + + # Simulate restart: fresh manager, same backend + let p2 = newInMemoryPersistence(store) + let rm2 = newReliabilityManager(persistence = p2).get() + check rm2.ensureChannel(testChannel).isOk() + let buf = rm2.getOutgoingBuffer(testChannel) + check buf.len == 1 + check buf[0].message.messageId == "msg-1" + rm2.cleanup() + + test "lamport clock survives restart": + let store = newInMemoryStore() + let p1 = newInMemoryPersistence(store) + let rm1 = newReliabilityManager(persistence = p1).get() + check rm1.ensureChannel(testChannel).isOk() + rm1.updateLamportTimestamp(42, testChannel) + check store.lamports[testChannel] == 43 # max(42, 0) + 1 + rm1.cleanup() + + let p2 = newInMemoryPersistence(store) + let rm2 = newReliabilityManager(persistence = p2).get() + check rm2.ensureChannel(testChannel).isOk() + check rm2.channels[testChannel].lamportTimestamp == 43 + + test "delivered messages survive restart and rebuild bloom": + let store = newInMemoryStore() + let p1 = newInMemoryPersistence(store) + let rm1 = newReliabilityManager(persistence = p1).get() + check rm1.ensureChannel(testChannel).isOk() + let msg = SdsMessage.init( + messageId = "delivered-1", + lamportTimestamp = 1, + causalHistory = @[], + channelId = testChannel, + content = @[9.byte, 9], + bloomFilter = @[], + senderId = "alice", + ) + rm1.addToHistory(msg, testChannel) + check store.log[testChannel].len == 1 + rm1.cleanup() + + let p2 = newInMemoryPersistence(store) + let rm2 = newReliabilityManager(persistence = p2).get() + check rm2.ensureChannel(testChannel).isOk() + let ch = rm2.channels[testChannel] + check ch.messageHistory.len == 1 + check "delivered-1" in ch.messageHistory + # Bloom filter rebuilt from log on bootstrap + check ch.bloomFilter.contains("delivered-1") + + test "ack removes outgoing entry from persistence": + let store = newInMemoryStore() + let p = newInMemoryPersistence(store) + let rm = newReliabilityManager(persistence = p).get() + check rm.ensureChannel(testChannel).isOk() + discard rm.wrapOutgoingMessage(@[1.byte], "msg-x", testChannel) + check "msg-x" in store.outgoing[testChannel] + + # Synthesize an incoming message that ACKs msg-x via causal history + let ackMsg = SdsMessage.init( + messageId = "ack-bearer", + lamportTimestamp = 5, + causalHistory = @[HistoryEntry.init("msg-x", @[])], + channelId = testChannel, + content = @[], + bloomFilter = @[], + senderId = "bob", + ) + let serialized = serializeMessage(ackMsg).get() + discard rm.unwrapReceivedMessage(serialized) + check "msg-x" notin store.outgoing[testChannel] + rm.cleanup() + + test "removeChannel fires per-entry removes": + let store = newInMemoryStore() + let p = newInMemoryPersistence(store) + let rm = newReliabilityManager(persistence = p).get() + check rm.ensureChannel(testChannel).isOk() + discard rm.wrapOutgoingMessage(@[1.byte], "msg-r", testChannel) + check store.outgoing[testChannel].len == 1 + check rm.removeChannel(testChannel).isOk() + check store.outgoing[testChannel].len == 0 + rm.cleanup() + + test "noOpPersistence keeps existing manager working": + let rm = newReliabilityManager().get() # default no-op + check rm.ensureChannel(testChannel).isOk() + let wrapped = rm.wrapOutgoingMessage(@[1.byte], "msg-n", testChannel) + check wrapped.isOk() + check rm.getOutgoingBuffer(testChannel).len == 1 + rm.cleanup()