mirror of
https://github.com/logos-messaging/nim-sds.git
synced 2026-05-18 07:59:54 +00:00
feat: persistence interface for SDS state
This commit is contained in:
parent
9d08f5995b
commit
c977124a7e
67
sds.nim
67
sds.nim
@ -7,10 +7,13 @@ export types, protobuf, sds_utils, rolling_bloom_filter
|
||||
proc newReliabilityManager*(
|
||||
config: ReliabilityConfig = defaultConfig(),
|
||||
participantId: SdsParticipantID = "".SdsParticipantID,
|
||||
persistence: Persistence = noOpPersistence(),
|
||||
): Result[ReliabilityManager, ReliabilityError] =
|
||||
## Creates a new multi-channel ReliabilityManager.
|
||||
## `persistence` defaults to a no-op backend; supply a real one to durably
|
||||
## store SDS state across restarts.
|
||||
try:
|
||||
let rm = ReliabilityManager.new(config, participantId)
|
||||
let rm = ReliabilityManager.new(config, participantId, persistence)
|
||||
return ok(rm)
|
||||
except Exception:
|
||||
error "Failed to create ReliabilityManager", msg = getCurrentExceptionMsg()
|
||||
@ -53,7 +56,7 @@ proc reviewAckStatus(rm: ReliabilityManager, msg: SdsMessage) {.gcsafe.} =
|
||||
return
|
||||
|
||||
let channel = rm.channels[msg.channelId]
|
||||
var toDelete: seq[int] = @[]
|
||||
var toDelete: seq[(int, SdsMessageID)] = @[]
|
||||
var i = 0
|
||||
|
||||
while i < channel.outgoingBuffer.len:
|
||||
@ -61,11 +64,13 @@ proc reviewAckStatus(rm: ReliabilityManager, msg: SdsMessage) {.gcsafe.} =
|
||||
if outMsg.isAcknowledged(msg.causalHistory, rbf):
|
||||
if not rm.onMessageSent.isNil():
|
||||
rm.onMessageSent(outMsg.message.messageId, outMsg.message.channelId)
|
||||
toDelete.add(i)
|
||||
toDelete.add((i, outMsg.message.messageId))
|
||||
inc i
|
||||
|
||||
for i in countdown(toDelete.high, 0):
|
||||
channel.outgoingBuffer.delete(toDelete[i])
|
||||
for k in countdown(toDelete.high, 0):
|
||||
let (idx, ackedId) = toDelete[k]
|
||||
channel.outgoingBuffer.delete(idx)
|
||||
rm.persistence.removeOutgoing(msg.channelId, ackedId)
|
||||
|
||||
proc wrapOutgoingMessage*(
|
||||
rm: ReliabilityManager,
|
||||
@ -108,6 +113,7 @@ proc wrapOutgoingMessage*(
|
||||
expiredKeys.add(eligible[i][0])
|
||||
for key in expiredKeys:
|
||||
channel.outgoingRepairBuffer.del(key)
|
||||
rm.persistence.removeOutgoingRepair(channelId, key)
|
||||
|
||||
let msg = SdsMessage.init(
|
||||
messageId = messageId,
|
||||
@ -120,9 +126,10 @@ proc wrapOutgoingMessage*(
|
||||
repairRequest = repairReqs,
|
||||
)
|
||||
|
||||
channel.outgoingBuffer.add(
|
||||
let unackMsg =
|
||||
UnacknowledgedMessage.init(message = msg, sendTime = getTime(), resendAttempts = 0)
|
||||
)
|
||||
channel.outgoingBuffer.add(unackMsg)
|
||||
rm.persistence.saveOutgoing(channelId, unackMsg)
|
||||
|
||||
channel.bloomFilter.add(msg.messageId)
|
||||
# The full SdsMessage carries senderId and content, so a single
|
||||
@ -168,11 +175,15 @@ proc processIncomingBuffer(rm: ReliabilityManager, channelId: SdsChannelID) {.gc
|
||||
if remainingId notin processed:
|
||||
if msgId in entry.missingDeps:
|
||||
channel.incomingBuffer[remainingId].missingDeps.excl(msgId)
|
||||
rm.persistence.saveIncoming(
|
||||
channelId, channel.incomingBuffer[remainingId]
|
||||
)
|
||||
if channel.incomingBuffer[remainingId].missingDeps.len == 0:
|
||||
readyToProcess.add(remainingId)
|
||||
|
||||
for msgId in processed:
|
||||
channel.incomingBuffer.del(msgId)
|
||||
rm.persistence.removeIncoming(channelId, msgId)
|
||||
|
||||
proc unwrapReceivedMessage*(
|
||||
rm: ReliabilityManager, message: seq[byte]
|
||||
@ -193,7 +204,9 @@ proc unwrapReceivedMessage*(
|
||||
# SDS-R: opportunistic repair-buffer cleanup — applies to duplicates too,
|
||||
# so rebroadcasts cancel redundant responses on peers that already have the message.
|
||||
channel.outgoingRepairBuffer.del(msg.messageId)
|
||||
rm.persistence.removeOutgoingRepair(channelId, msg.messageId)
|
||||
channel.incomingRepairBuffer.del(msg.messageId)
|
||||
rm.persistence.removeIncomingRepair(channelId, msg.messageId)
|
||||
|
||||
if msg.messageId in channel.messageHistory:
|
||||
return ok((msg.content, @[], channelId))
|
||||
@ -211,6 +224,7 @@ proc unwrapReceivedMessage*(
|
||||
for repairEntry in msg.repairRequest:
|
||||
# Remove from our own outgoing repair buffer (someone else is also requesting)
|
||||
channel.outgoingRepairBuffer.del(repairEntry.messageId)
|
||||
rm.persistence.removeOutgoingRepair(channelId, repairEntry.messageId)
|
||||
if repairEntry.messageId in channel.messageHistory and
|
||||
rm.participantId.len > 0 and repairEntry.senderId.len > 0:
|
||||
if isInResponseGroup(
|
||||
@ -223,11 +237,13 @@ proc unwrapReceivedMessage*(
|
||||
rm.participantId, repairEntry.senderId,
|
||||
repairEntry.messageId, rm.config.repairTMax
|
||||
)
|
||||
channel.incomingRepairBuffer[repairEntry.messageId] = IncomingRepairEntry(
|
||||
let inEntry = IncomingRepairEntry(
|
||||
inHistEntry: repairEntry,
|
||||
cachedMessage: serialized.get(),
|
||||
minTimeRepairResp: now + tResp,
|
||||
)
|
||||
channel.incomingRepairBuffer[repairEntry.messageId] = inEntry
|
||||
rm.persistence.saveIncomingRepair(channelId, repairEntry.messageId, inEntry)
|
||||
|
||||
var missingDeps = rm.checkDependencies(msg.causalHistory, channelId)
|
||||
|
||||
@ -238,23 +254,30 @@ proc unwrapReceivedMessage*(
|
||||
depsInBuffer = true
|
||||
break
|
||||
if depsInBuffer:
|
||||
channel.incomingBuffer[msg.messageId] =
|
||||
let entry =
|
||||
IncomingMessage.init(message = msg, missingDeps = initHashSet[SdsMessageID]())
|
||||
channel.incomingBuffer[msg.messageId] = entry
|
||||
rm.persistence.saveIncoming(channelId, entry)
|
||||
else:
|
||||
rm.addToHistory(msg, channelId)
|
||||
# Unblock any buffered messages that were waiting on this one.
|
||||
var unblocked: seq[SdsMessageID] = @[]
|
||||
for pendingId, entry in channel.incomingBuffer:
|
||||
if msg.messageId in entry.missingDeps:
|
||||
channel.incomingBuffer[pendingId].missingDeps.excl(msg.messageId)
|
||||
unblocked.add(pendingId)
|
||||
for pendingId in unblocked:
|
||||
rm.persistence.saveIncoming(channelId, channel.incomingBuffer[pendingId])
|
||||
rm.processIncomingBuffer(channelId)
|
||||
if not rm.onMessageReady.isNil():
|
||||
rm.onMessageReady(msg.messageId, channelId)
|
||||
else:
|
||||
channel.incomingBuffer[msg.messageId] =
|
||||
IncomingMessage.init(
|
||||
message = msg,
|
||||
missingDeps = missingDeps.getMessageIds().toHashSet(),
|
||||
)
|
||||
let entry = IncomingMessage.init(
|
||||
message = msg,
|
||||
missingDeps = missingDeps.getMessageIds().toHashSet(),
|
||||
)
|
||||
channel.incomingBuffer[msg.messageId] = entry
|
||||
rm.persistence.saveIncoming(channelId, entry)
|
||||
if not rm.onMissingDependencies.isNil():
|
||||
rm.onMissingDependencies(msg.messageId, missingDeps, channelId)
|
||||
|
||||
@ -266,10 +289,12 @@ proc unwrapReceivedMessage*(
|
||||
rm.participantId, dep.messageId,
|
||||
rm.config.repairTMin, rm.config.repairTMax
|
||||
)
|
||||
channel.outgoingRepairBuffer[dep.messageId] = OutgoingRepairEntry(
|
||||
let outEntry = OutgoingRepairEntry(
|
||||
outHistEntry: dep,
|
||||
minTimeRepairReq: now + tReq,
|
||||
)
|
||||
channel.outgoingRepairBuffer[dep.messageId] = outEntry
|
||||
rm.persistence.saveOutgoingRepair(channelId, dep.messageId, outEntry)
|
||||
|
||||
return ok((msg.content, missingDeps, channelId))
|
||||
except Exception:
|
||||
@ -290,13 +315,19 @@ proc markDependenciesMet*(
|
||||
if not channel.bloomFilter.contains(msgId):
|
||||
channel.bloomFilter.add(msgId)
|
||||
|
||||
var unblocked: seq[SdsMessageID] = @[]
|
||||
for pendingId, entry in channel.incomingBuffer:
|
||||
if msgId in entry.missingDeps:
|
||||
channel.incomingBuffer[pendingId].missingDeps.excl(msgId)
|
||||
unblocked.add(pendingId)
|
||||
for pendingId in unblocked:
|
||||
rm.persistence.saveIncoming(channelId, channel.incomingBuffer[pendingId])
|
||||
|
||||
# SDS-R: clear from repair buffers (dependency now met)
|
||||
channel.outgoingRepairBuffer.del(msgId)
|
||||
rm.persistence.removeOutgoingRepair(channelId, msgId)
|
||||
channel.incomingRepairBuffer.del(msgId)
|
||||
rm.persistence.removeIncomingRepair(channelId, msgId)
|
||||
|
||||
rm.processIncomingBuffer(channelId)
|
||||
return ok()
|
||||
@ -343,9 +374,11 @@ proc checkUnacknowledgedMessages(
|
||||
updatedMsg.resendAttempts += 1
|
||||
updatedMsg.sendTime = now
|
||||
newOutgoingBuffer.add(updatedMsg)
|
||||
rm.persistence.saveOutgoing(channelId, updatedMsg)
|
||||
else:
|
||||
if not rm.onMessageSent.isNil():
|
||||
rm.onMessageSent(unackMsg.message.messageId, channelId)
|
||||
rm.persistence.removeOutgoing(channelId, unackMsg.message.messageId)
|
||||
else:
|
||||
newOutgoingBuffer.add(unackMsg)
|
||||
|
||||
@ -400,6 +433,7 @@ proc runRepairSweep*(rm: ReliabilityManager) {.gcsafe, raises: [].} =
|
||||
for msgId in toRebroadcast:
|
||||
let entry = channel.incomingRepairBuffer[msgId]
|
||||
channel.incomingRepairBuffer.del(msgId)
|
||||
rm.persistence.removeIncomingRepair(channelId, msgId)
|
||||
if not rm.onRepairReady.isNil():
|
||||
rm.onRepairReady(entry.cachedMessage, channelId)
|
||||
|
||||
@ -411,6 +445,7 @@ proc runRepairSweep*(rm: ReliabilityManager) {.gcsafe, raises: [].} =
|
||||
toRemove.add(msgId)
|
||||
for msgId in toRemove:
|
||||
channel.outgoingRepairBuffer.del(msgId)
|
||||
rm.persistence.removeOutgoingRepair(channelId, msgId)
|
||||
except Exception:
|
||||
error "Error in repair sweep for channel",
|
||||
channelId = channelId, msg = getCurrentExceptionMsg()
|
||||
@ -436,6 +471,8 @@ proc resetReliabilityManager*(rm: ReliabilityManager): Result[void, ReliabilityE
|
||||
withLock rm.lock:
|
||||
try:
|
||||
for channelId, channel in rm.channels:
|
||||
rm.dropChannelFromPersistence(channelId, channel)
|
||||
rm.persistence.saveLamport(channelId, 0)
|
||||
channel.lamportTimestamp = 0
|
||||
channel.messageHistory.clear()
|
||||
channel.outgoingBuffer.setLen(0)
|
||||
|
||||
@ -67,6 +67,7 @@ proc getMyCpu(): string =
|
||||
task test, "Run the test suite":
|
||||
exec "nim c -r tests/test_bloom.nim"
|
||||
exec "nim c -r tests/test_reliability.nim"
|
||||
exec "nim c -r tests/test_persistence.nim"
|
||||
|
||||
task libsdsDynamicWindows, "Generate bindings":
|
||||
let outLibNameAndExt = "libsds.dll"
|
||||
|
||||
@ -14,7 +14,28 @@ export
|
||||
proc defaultConfig*(): ReliabilityConfig =
|
||||
return ReliabilityConfig.init()
|
||||
|
||||
proc dropChannelFromPersistence*(
    rm: ReliabilityManager, channelId: SdsChannelID, channel: ChannelContext
) {.gcsafe, raises: [].} =
  ## Issues one persistence `remove*` call for every entry currently held in
  ## `channel`: the outgoing buffer, the incoming buffer, the message history
  ## and both SDS-R repair buffers.
  ##
  ## `removeChannel` and `resetReliabilityManager` call this before they clear
  ## the in-memory buffers, so the backend is told about each entry while it
  ## can still be enumerated.
  ##
  ## No lamport or retrieval-hint call is made here; `resetReliabilityManager`
  ## issues its own `saveLamport(channelId, 0)` afterwards.
  let backend = rm.persistence

  # Unacknowledged outgoing messages are keyed by their message id.
  for pending in channel.outgoingBuffer:
    backend.removeOutgoing(channelId, pending.message.messageId)

  # Messages still waiting on missing dependencies.
  for bufferedId in channel.incomingBuffer.keys:
    backend.removeIncoming(channelId, bufferedId)

  # Delivered-message log.
  for deliveredId in channel.messageHistory.keys:
    backend.removeLogEntry(channelId, deliveredId)

  # SDS-R repair requests we intend to send.
  for requestedId in channel.outgoingRepairBuffer.keys:
    backend.removeOutgoingRepair(channelId, requestedId)

  # SDS-R repair responses we intend to rebroadcast.
  for responseId in channel.incomingRepairBuffer.keys:
    backend.removeIncomingRepair(channelId, responseId)
|
||||
|
||||
proc cleanup*(rm: ReliabilityManager) {.raises: [].} =
|
||||
## Releases in-memory state. Does NOT wipe persistence — the manager may be
|
||||
## reconstructed against the same backend after cleanup, so disk state must
|
||||
## survive. For deliberate disk wipe, use `removeChannel` or
|
||||
## `resetReliabilityManager`.
|
||||
if not rm.isNil():
|
||||
try:
|
||||
withLock rm.lock:
|
||||
@ -50,12 +71,14 @@ proc addToHistory*(
|
||||
if channelId in rm.channels:
|
||||
let channel = rm.channels[channelId]
|
||||
channel.messageHistory[msg.messageId] = msg
|
||||
rm.persistence.appendLogEntry(channelId, msg)
|
||||
while channel.messageHistory.len > rm.config.maxMessageHistory:
|
||||
var firstKey: SdsMessageID
|
||||
for k in channel.messageHistory.keys:
|
||||
firstKey = k
|
||||
break
|
||||
channel.messageHistory.del(firstKey)
|
||||
rm.persistence.removeLogEntry(channelId, firstKey)
|
||||
except Exception:
|
||||
error "Failed to add to history",
|
||||
channelId = channelId, msgId = msg.messageId, error = getCurrentExceptionMsg()
|
||||
@ -67,6 +90,7 @@ proc updateLamportTimestamp*(
|
||||
if channelId in rm.channels:
|
||||
let channel = rm.channels[channelId]
|
||||
channel.lamportTimestamp = max(msgTs, channel.lamportTimestamp) + 1
|
||||
rm.persistence.saveLamport(channelId, channel.lamportTimestamp)
|
||||
except Exception:
|
||||
error "Failed to update lamport timestamp",
|
||||
channelId = channelId, msgTs = msgTs, error = getCurrentExceptionMsg()
|
||||
@ -150,6 +174,8 @@ proc getRecentHistoryEntries*(
|
||||
var entry = HistoryEntry(messageId: msgId)
|
||||
if not rm.onRetrievalHint.isNil():
|
||||
entry.retrievalHint = rm.onRetrievalHint(msgId)
|
||||
if entry.retrievalHint.len > 0:
|
||||
rm.persistence.setRetrievalHint(msgId, entry.retrievalHint)
|
||||
entry.senderId = channel.messageHistory[msgId].senderId
|
||||
entries.add(entry)
|
||||
return entries
|
||||
@ -226,11 +252,29 @@ proc getIncomingBuffer*(
|
||||
proc getOrCreateChannel*(
|
||||
rm: ReliabilityManager, channelId: SdsChannelID
|
||||
): ChannelContext =
|
||||
## Returns the channel context, creating and bootstrapping it from the
|
||||
## persistence backend if it does not yet exist in memory. The bloom filter
|
||||
## is rebuilt deterministically from the loaded message history rather than
|
||||
## persisted directly. Caller is expected to hold rm.lock.
|
||||
try:
|
||||
if channelId notin rm.channels:
|
||||
rm.channels[channelId] = ChannelContext.new(
|
||||
let channel = ChannelContext.new(
|
||||
RollingBloomFilter.init(rm.config.bloomFilterCapacity, rm.config.bloomFilterErrorRate)
|
||||
)
|
||||
let snapshot = rm.persistence.loadAllForChannel(channelId)
|
||||
channel.lamportTimestamp = snapshot.lamportTimestamp
|
||||
for msg in snapshot.messageHistory:
|
||||
channel.messageHistory[msg.messageId] = msg
|
||||
channel.bloomFilter.add(msg.messageId)
|
||||
for unack in snapshot.outgoingBuffer:
|
||||
channel.outgoingBuffer.add(unack)
|
||||
for incoming in snapshot.incomingBuffer:
|
||||
channel.incomingBuffer[incoming.message.messageId] = incoming
|
||||
for (msgId, entry) in snapshot.outgoingRepairBuffer:
|
||||
channel.outgoingRepairBuffer[msgId] = entry
|
||||
for (msgId, entry) in snapshot.incomingRepairBuffer:
|
||||
channel.incomingRepairBuffer[msgId] = entry
|
||||
rm.channels[channelId] = channel
|
||||
return rm.channels[channelId]
|
||||
except Exception:
|
||||
error "Failed to get or create channel",
|
||||
@ -256,6 +300,7 @@ proc removeChannel*(
|
||||
try:
|
||||
if channelId in rm.channels:
|
||||
let channel = rm.channels[channelId]
|
||||
rm.dropChannelFromPersistence(channelId, channel)
|
||||
channel.outgoingBuffer.setLen(0)
|
||||
channel.incomingBuffer.clear()
|
||||
channel.messageHistory.clear()
|
||||
|
||||
@ -11,6 +11,7 @@ import sds/types/app_callbacks
|
||||
import sds/types/reliability_config
|
||||
import sds/types/repair_entry
|
||||
import sds/types/channel_context
|
||||
import sds/types/persistence
|
||||
import sds/types/reliability_manager
|
||||
import sds/types/protobuf_error
|
||||
|
||||
@ -28,5 +29,6 @@ export
|
||||
reliability_config,
|
||||
repair_entry,
|
||||
channel_context,
|
||||
persistence,
|
||||
reliability_manager,
|
||||
protobuf_error
|
||||
|
||||
117
sds/types/persistence.nim
Normal file
117
sds/types/persistence.nim
Normal file
@ -0,0 +1,117 @@
|
||||
import ./sds_message_id
|
||||
import ./sds_message
|
||||
import ./unacknowledged_message
|
||||
import ./incoming_message
|
||||
import ./repair_entry
|
||||
export
|
||||
sds_message_id, sds_message, unacknowledged_message, incoming_message, repair_entry
|
||||
|
||||
## SDS state persistence interface (issue #64).
|
||||
##
|
||||
## Defines WHAT operations a persistence backend must provide. The actual
|
||||
## storage technology (SQLite, encrypted file, in-memory) is supplied by the
|
||||
## caller — nim-sds knows nothing about it. Every state-mutating proc in the
|
||||
## protocol calls into one of these procs immediately after the in-memory
|
||||
## change, so on-disk state stays in lockstep with in-memory state.
|
||||
##
|
||||
## Bloom filter is intentionally not persisted: it is rebuilt from the local
|
||||
## history log on bootstrap. Async timers are likewise recomputed from the
|
||||
## absolute timestamps stored in the repair buffer entries.
|
||||
|
||||
type
  ChannelSnapshot* = object
    ## Full per-channel state handed back by `loadAllForChannel` so that a
    ## `ChannelContext` can be repopulated on bootstrap. The bloom filter is
    ## deliberately absent: callers rebuild it from `messageHistory`.
    lamportTimestamp*: int64
    messageHistory*: seq[SdsMessage]
      ## Delivered messages, oldest first (insertion order preserved for
      ## causal-history tail access and FIFO eviction).
    outgoingBuffer*: seq[UnacknowledgedMessage]
    incomingBuffer*: seq[IncomingMessage]
    outgoingRepairBuffer*: seq[(SdsMessageID, OutgoingRepairEntry)]
    incomingRepairBuffer*: seq[(SdsMessageID, IncomingRepairEntry)]

  Persistence* = object
    ## Pluggable persistence contract: a record of callbacks supplied by the
    ## caller when constructing the manager via `newReliabilityManager`.
    ## nim-sds invokes the matching callback at each state-mutation point.
    ## Every callback must be `gcsafe` and must not raise.

    # --- Per-channel lamport clock ---
    saveLamport*: proc(channelId: SdsChannelID, lamport: int64) {.gcsafe, raises: [].}

    # --- Local log of delivered messages ---
    appendLogEntry*: proc(channelId: SdsChannelID, msg: SdsMessage) {.gcsafe, raises: [].}
    removeLogEntry*: proc(channelId: SdsChannelID, msgId: SdsMessageID) {.gcsafe, raises: [].}
    # Note: retrieval hints are keyed by message id only, not by channel.
    setRetrievalHint*: proc(msgId: SdsMessageID, hint: seq[byte]) {.gcsafe, raises: [].}

    # --- Outgoing buffer of unacknowledged messages ---
    saveOutgoing*: proc(channelId: SdsChannelID, msg: UnacknowledgedMessage) {.gcsafe, raises: [].}
    removeOutgoing*: proc(channelId: SdsChannelID, msgId: SdsMessageID) {.gcsafe, raises: [].}

    # --- Incoming buffer of messages waiting on dependencies ---
    saveIncoming*: proc(channelId: SdsChannelID, msg: IncomingMessage) {.gcsafe, raises: [].}
    removeIncoming*: proc(channelId: SdsChannelID, msgId: SdsMessageID) {.gcsafe, raises: [].}

    # --- SDS-R outgoing repair buffer ---
    saveOutgoingRepair*: proc(
      channelId: SdsChannelID, msgId: SdsMessageID, entry: OutgoingRepairEntry
    ) {.gcsafe, raises: [].}
    removeOutgoingRepair*: proc(channelId: SdsChannelID, msgId: SdsMessageID) {.gcsafe, raises: [].}

    # --- SDS-R incoming repair buffer ---
    saveIncomingRepair*: proc(
      channelId: SdsChannelID, msgId: SdsMessageID, entry: IncomingRepairEntry
    ) {.gcsafe, raises: [].}
    removeIncomingRepair*: proc(channelId: SdsChannelID, msgId: SdsMessageID) {.gcsafe, raises: [].}

    # --- Bootstrap on `addChannel` / `getOrCreateChannel` ---
    loadAllForChannel*: proc(channelId: SdsChannelID): ChannelSnapshot {.gcsafe, raises: [].}
|
||||
|
||||
proc noOpPersistence*(): Persistence =
  ## Returns the default backend: every write/remove callback does nothing and
  ## `loadAllForChannel` yields an empty `ChannelSnapshot`. This keeps callers
  ## and tests that do not need durability working without a real backend.

  # Per-channel lamport clock
  result.saveLamport = proc(
      channelId: SdsChannelID, lamport: int64
  ) {.gcsafe, raises: [].} =
    discard

  # Local log (delivered messages)
  result.appendLogEntry = proc(
      channelId: SdsChannelID, msg: SdsMessage
  ) {.gcsafe, raises: [].} =
    discard
  result.removeLogEntry = proc(
      channelId: SdsChannelID, msgId: SdsMessageID
  ) {.gcsafe, raises: [].} =
    discard
  result.setRetrievalHint = proc(
      msgId: SdsMessageID, hint: seq[byte]
  ) {.gcsafe, raises: [].} =
    discard

  # Outgoing unacknowledged buffer
  result.saveOutgoing = proc(
      channelId: SdsChannelID, msg: UnacknowledgedMessage
  ) {.gcsafe, raises: [].} =
    discard
  result.removeOutgoing = proc(
      channelId: SdsChannelID, msgId: SdsMessageID
  ) {.gcsafe, raises: [].} =
    discard

  # Incoming dependency-waiting buffer
  result.saveIncoming = proc(
      channelId: SdsChannelID, msg: IncomingMessage
  ) {.gcsafe, raises: [].} =
    discard
  result.removeIncoming = proc(
      channelId: SdsChannelID, msgId: SdsMessageID
  ) {.gcsafe, raises: [].} =
    discard

  # SDS-R outgoing repair buffer
  result.saveOutgoingRepair = proc(
      channelId: SdsChannelID, msgId: SdsMessageID, entry: OutgoingRepairEntry
  ) {.gcsafe, raises: [].} =
    discard
  result.removeOutgoingRepair = proc(
      channelId: SdsChannelID, msgId: SdsMessageID
  ) {.gcsafe, raises: [].} =
    discard

  # SDS-R incoming repair buffer
  result.saveIncomingRepair = proc(
      channelId: SdsChannelID, msgId: SdsMessageID, entry: IncomingRepairEntry
  ) {.gcsafe, raises: [].} =
    discard
  result.removeIncomingRepair = proc(
      channelId: SdsChannelID, msgId: SdsMessageID
  ) {.gcsafe, raises: [].} =
    discard

  # Bootstrap: nothing was ever stored, so hand back a default snapshot.
  result.loadAllForChannel = proc(
      channelId: SdsChannelID
  ): ChannelSnapshot {.gcsafe, raises: [].} =
    ChannelSnapshot()
|
||||
@ -4,12 +4,17 @@ import ./history_entry
|
||||
import ./callbacks
|
||||
import ./reliability_config
|
||||
import ./channel_context
|
||||
export sds_message_id, history_entry, callbacks, reliability_config, channel_context
|
||||
import ./persistence
|
||||
export
|
||||
sds_message_id, history_entry, callbacks, reliability_config, channel_context,
|
||||
persistence
|
||||
|
||||
type ReliabilityManager* = ref object
|
||||
channels*: Table[SdsChannelID, ChannelContext]
|
||||
config*: ReliabilityConfig
|
||||
participantId*: SdsParticipantID
|
||||
persistence*: Persistence
|
||||
## Pluggable durability backend; defaults to a no-op when not supplied.
|
||||
lock*: Lock
|
||||
onMessageReady*: proc(messageId: SdsMessageID, channelId: SdsChannelID) {.gcsafe.}
|
||||
onMessageSent*: proc(messageId: SdsMessageID, channelId: SdsChannelID) {.gcsafe.}
|
||||
@ -24,11 +29,13 @@ proc new*(
|
||||
T: type ReliabilityManager,
|
||||
config: ReliabilityConfig,
|
||||
participantId: SdsParticipantID = "".SdsParticipantID,
|
||||
persistence: Persistence = noOpPersistence(),
|
||||
): T =
|
||||
let rm = T(
|
||||
channels: initTable[SdsChannelID, ChannelContext](),
|
||||
config: config,
|
||||
participantId: participantId,
|
||||
persistence: persistence,
|
||||
)
|
||||
rm.lock.initLock()
|
||||
return rm
|
||||
|
||||
113
tests/in_memory_persistence.nim
Normal file
113
tests/in_memory_persistence.nim
Normal file
@ -0,0 +1,113 @@
|
||||
import std/tables
|
||||
import sds
|
||||
|
||||
## Test-only Persistence backend backed by Nim tables. Lets tests verify the
|
||||
## full write → restart → read-back loop without depending on SQLite (or any
|
||||
## real storage technology). Exposes the underlying store so tests can assert
|
||||
## on what got saved.
|
||||
|
||||
type InMemoryStore* = ref object
  ## Backing storage for the test-only persistence backend. It is a `ref` so
  ## that several `Persistence` instances (one per simulated restart) can share
  ## the same data, and all fields are exported so tests can assert on what
  ## was written.

  # Keyed by channel only.
  lamports*: Table[SdsChannelID, int64]

  # Keyed by message id only (the retrieval-hint callback receives no channel).
  hints*: Table[SdsMessageID, seq[byte]]

  # Keyed by channel, then by message id; insertion order is kept per channel.
  log*: Table[SdsChannelID, OrderedTable[SdsMessageID, SdsMessage]]
  outgoing*: Table[SdsChannelID, OrderedTable[SdsMessageID, UnacknowledgedMessage]]
  incoming*: Table[SdsChannelID, OrderedTable[SdsMessageID, IncomingMessage]]
  outgoingRepair*: Table[SdsChannelID, OrderedTable[SdsMessageID, OutgoingRepairEntry]]
  incomingRepair*: Table[SdsChannelID, OrderedTable[SdsMessageID, IncomingRepairEntry]]
|
||||
|
||||
proc newInMemoryStore*(): InMemoryStore =
  ## Allocates a store whose tables all start out empty.
  new(result)
|
||||
|
||||
proc newInMemoryPersistence*(store: InMemoryStore): Persistence =
  ## Builds a `Persistence` whose callbacks read from and write to `store`.
  ## Each callback is a closure over `store`, so two backends created from the
  ## same store see each other's writes (used by tests to simulate a restart).
  ##
  ## Table lookups are wrapped in `{.cast(raises: []).}` because the
  ## `Persistence` callback types forbid raising.

  proc saveLamportImpl(channelId: SdsChannelID, lamport: int64) {.gcsafe, raises: [].} =
    store.lamports[channelId] = lamport

  proc appendLogImpl(channelId: SdsChannelID, msg: SdsMessage) {.gcsafe, raises: [].} =
    {.cast(raises: []).}:
      # Create the per-channel table on first use, then upsert by message id.
      let empty = initOrderedTable[SdsMessageID, SdsMessage]()
      store.log.mgetOrPut(channelId, empty)[msg.messageId] = msg

  proc removeLogImpl(channelId: SdsChannelID, msgId: SdsMessageID) {.gcsafe, raises: [].} =
    {.cast(raises: []).}:
      if channelId in store.log:
        store.log[channelId].del(msgId)

  proc setHintImpl(msgId: SdsMessageID, hint: seq[byte]) {.gcsafe, raises: [].} =
    store.hints[msgId] = hint

  proc saveOutgoingImpl(
      channelId: SdsChannelID, msg: UnacknowledgedMessage
  ) {.gcsafe, raises: [].} =
    {.cast(raises: []).}:
      let empty = initOrderedTable[SdsMessageID, UnacknowledgedMessage]()
      store.outgoing.mgetOrPut(channelId, empty)[msg.message.messageId] = msg

  proc removeOutgoingImpl(
      channelId: SdsChannelID, msgId: SdsMessageID
  ) {.gcsafe, raises: [].} =
    {.cast(raises: []).}:
      if channelId in store.outgoing:
        store.outgoing[channelId].del(msgId)

  proc saveIncomingImpl(
      channelId: SdsChannelID, msg: IncomingMessage
  ) {.gcsafe, raises: [].} =
    {.cast(raises: []).}:
      let empty = initOrderedTable[SdsMessageID, IncomingMessage]()
      store.incoming.mgetOrPut(channelId, empty)[msg.message.messageId] = msg

  proc removeIncomingImpl(
      channelId: SdsChannelID, msgId: SdsMessageID
  ) {.gcsafe, raises: [].} =
    {.cast(raises: []).}:
      if channelId in store.incoming:
        store.incoming[channelId].del(msgId)

  proc saveOutgoingRepairImpl(
      channelId: SdsChannelID, msgId: SdsMessageID, entry: OutgoingRepairEntry
  ) {.gcsafe, raises: [].} =
    {.cast(raises: []).}:
      let empty = initOrderedTable[SdsMessageID, OutgoingRepairEntry]()
      store.outgoingRepair.mgetOrPut(channelId, empty)[msgId] = entry

  proc removeOutgoingRepairImpl(
      channelId: SdsChannelID, msgId: SdsMessageID
  ) {.gcsafe, raises: [].} =
    {.cast(raises: []).}:
      if channelId in store.outgoingRepair:
        store.outgoingRepair[channelId].del(msgId)

  proc saveIncomingRepairImpl(
      channelId: SdsChannelID, msgId: SdsMessageID, entry: IncomingRepairEntry
  ) {.gcsafe, raises: [].} =
    {.cast(raises: []).}:
      let empty = initOrderedTable[SdsMessageID, IncomingRepairEntry]()
      store.incomingRepair.mgetOrPut(channelId, empty)[msgId] = entry

  proc removeIncomingRepairImpl(
      channelId: SdsChannelID, msgId: SdsMessageID
  ) {.gcsafe, raises: [].} =
    {.cast(raises: []).}:
      if channelId in store.incomingRepair:
        store.incomingRepair[channelId].del(msgId)

  proc loadAllImpl(channelId: SdsChannelID): ChannelSnapshot {.gcsafe, raises: [].} =
    ## Copies everything stored for `channelId` into a fresh snapshot; a
    ## channel that was never written yields a default (empty) snapshot.
    {.cast(raises: []).}:
      # A missing clock entry falls back to 0, the snapshot's default.
      result.lamportTimestamp = store.lamports.getOrDefault(channelId)
      if channelId in store.log:
        for delivered in store.log[channelId].values:
          result.messageHistory.add(delivered)
      if channelId in store.outgoing:
        for pending in store.outgoing[channelId].values:
          result.outgoingBuffer.add(pending)
      if channelId in store.incoming:
        for waiting in store.incoming[channelId].values:
          result.incomingBuffer.add(waiting)
      if channelId in store.outgoingRepair:
        for repairId, repairEntry in store.outgoingRepair[channelId]:
          result.outgoingRepairBuffer.add((repairId, repairEntry))
      if channelId in store.incomingRepair:
        for repairId, repairEntry in store.incomingRepair[channelId]:
          result.incomingRepairBuffer.add((repairId, repairEntry))

  Persistence(
    saveLamport: saveLamportImpl,
    appendLogEntry: appendLogImpl,
    removeLogEntry: removeLogImpl,
    setRetrievalHint: setHintImpl,
    saveOutgoing: saveOutgoingImpl,
    removeOutgoing: removeOutgoingImpl,
    saveIncoming: saveIncomingImpl,
    removeIncoming: removeIncomingImpl,
    saveOutgoingRepair: saveOutgoingRepairImpl,
    removeOutgoingRepair: removeOutgoingRepairImpl,
    saveIncomingRepair: saveIncomingRepairImpl,
    removeIncomingRepair: removeIncomingRepairImpl,
    loadAllForChannel: loadAllImpl,
  )
|
||||
111
tests/test_persistence.nim
Normal file
111
tests/test_persistence.nim
Normal file
@ -0,0 +1,111 @@
|
||||
import unittest, results, std/tables
|
||||
import sds
|
||||
import ./in_memory_persistence
|
||||
|
||||
# Lets the tests pass plain string literals where an SdsParticipantID is
# expected (e.g. `senderId = "alice"`).
converter toParticipantID(s: string): SdsParticipantID = s.SdsParticipantID

const testChannel = "testChannel"

suite "Persistence: write → restart → read-back":
  # Each "restart" test builds a second manager and backend over the same
  # InMemoryStore and checks that state written by the first manager is
  # restored. Every manager created here is released with `cleanup()`.

  test "outgoing buffer survives restart":
    let store = newInMemoryStore()
    let p1 = newInMemoryPersistence(store)
    let rm1 = newReliabilityManager(persistence = p1).get()
    check rm1.ensureChannel(testChannel).isOk()
    let wrapped = rm1.wrapOutgoingMessage(@[1.byte, 2, 3], "msg-1", testChannel)
    check wrapped.isOk()
    check store.outgoing[testChannel].len == 1
    check "msg-1" in store.outgoing[testChannel]
    rm1.cleanup()

    # Simulate restart: fresh manager, same backend
    let p2 = newInMemoryPersistence(store)
    let rm2 = newReliabilityManager(persistence = p2).get()
    check rm2.ensureChannel(testChannel).isOk()
    let buf = rm2.getOutgoingBuffer(testChannel)
    check buf.len == 1
    check buf[0].message.messageId == "msg-1"
    rm2.cleanup()

  test "lamport clock survives restart":
    let store = newInMemoryStore()
    let p1 = newInMemoryPersistence(store)
    let rm1 = newReliabilityManager(persistence = p1).get()
    check rm1.ensureChannel(testChannel).isOk()
    rm1.updateLamportTimestamp(42, testChannel)
    check store.lamports[testChannel] == 43 # max(42, 0) + 1
    rm1.cleanup()

    # Simulate restart: fresh manager, same backend
    let p2 = newInMemoryPersistence(store)
    let rm2 = newReliabilityManager(persistence = p2).get()
    check rm2.ensureChannel(testChannel).isOk()
    check rm2.channels[testChannel].lamportTimestamp == 43
    # Release the restarted manager too, as the other tests do.
    rm2.cleanup()

  test "delivered messages survive restart and rebuild bloom":
    let store = newInMemoryStore()
    let p1 = newInMemoryPersistence(store)
    let rm1 = newReliabilityManager(persistence = p1).get()
    check rm1.ensureChannel(testChannel).isOk()
    let msg = SdsMessage.init(
      messageId = "delivered-1",
      lamportTimestamp = 1,
      causalHistory = @[],
      channelId = testChannel,
      content = @[9.byte, 9],
      bloomFilter = @[],
      senderId = "alice",
    )
    rm1.addToHistory(msg, testChannel)
    check store.log[testChannel].len == 1
    rm1.cleanup()

    # Simulate restart: fresh manager, same backend
    let p2 = newInMemoryPersistence(store)
    let rm2 = newReliabilityManager(persistence = p2).get()
    check rm2.ensureChannel(testChannel).isOk()
    let ch = rm2.channels[testChannel]
    check ch.messageHistory.len == 1
    check "delivered-1" in ch.messageHistory
    # Bloom filter rebuilt from log on bootstrap
    check ch.bloomFilter.contains("delivered-1")
    # Release the restarted manager only after the last read of `ch`.
    rm2.cleanup()

  test "ack removes outgoing entry from persistence":
    let store = newInMemoryStore()
    let p = newInMemoryPersistence(store)
    let rm = newReliabilityManager(persistence = p).get()
    check rm.ensureChannel(testChannel).isOk()
    discard rm.wrapOutgoingMessage(@[1.byte], "msg-x", testChannel)
    check "msg-x" in store.outgoing[testChannel]

    # Synthesize an incoming message that ACKs msg-x via causal history
    let ackMsg = SdsMessage.init(
      messageId = "ack-bearer",
      lamportTimestamp = 5,
      causalHistory = @[HistoryEntry.init("msg-x", @[])],
      channelId = testChannel,
      content = @[],
      bloomFilter = @[],
      senderId = "bob",
    )
    let serialized = serializeMessage(ackMsg).get()
    discard rm.unwrapReceivedMessage(serialized)
    check "msg-x" notin store.outgoing[testChannel]
    rm.cleanup()

  test "removeChannel fires per-entry removes":
    let store = newInMemoryStore()
    let p = newInMemoryPersistence(store)
    let rm = newReliabilityManager(persistence = p).get()
    check rm.ensureChannel(testChannel).isOk()
    discard rm.wrapOutgoingMessage(@[1.byte], "msg-r", testChannel)
    check store.outgoing[testChannel].len == 1
    check rm.removeChannel(testChannel).isOk()
    # The per-channel table remains in the store but must now be empty.
    check store.outgoing[testChannel].len == 0
    rm.cleanup()

  test "noOpPersistence keeps existing manager working":
    let rm = newReliabilityManager().get() # default no-op
    check rm.ensureChannel(testChannel).isOk()
    let wrapped = rm.wrapOutgoingMessage(@[1.byte], "msg-n", testChannel)
    check wrapped.isOk()
    check rm.getOutgoingBuffer(testChannel).len == 1
    rm.cleanup()
|
||||
Loading…
x
Reference in New Issue
Block a user