diff --git a/library/events/json_repair_ready_event.nim b/library/events/json_repair_ready_event.nim new file mode 100644 index 0000000..d34c954 --- /dev/null +++ b/library/events/json_repair_ready_event.nim @@ -0,0 +1,20 @@ +import std/[json, base64] +import ./json_base_event, sds/[message] + +type JsonRepairReadyEvent* = ref object of JsonEvent + channelId*: SdsChannelID + message*: seq[byte] + +proc new*( + T: type JsonRepairReadyEvent, message: seq[byte], channelId: SdsChannelID +): T = + return JsonRepairReadyEvent( + eventType: "repair_ready", message: message, channelId: channelId + ) + +method `$`*(jsonRepairReady: JsonRepairReadyEvent): string = + var node = newJObject() + node["eventType"] = %*jsonRepairReady.eventType + node["channelId"] = %*jsonRepairReady.channelId + node["message"] = %*encode(jsonRepairReady.message) + $node diff --git a/library/libsds.nim b/library/libsds.nim index 4ae285f..af05857 100644 --- a/library/libsds.nim +++ b/library/libsds.nim @@ -16,7 +16,7 @@ import sds, ./events/[ json_message_ready_event, json_message_sent_event, json_missing_dependencies_event, - json_periodic_sync_event, + json_periodic_sync_event, json_repair_ready_event, ] ################################################################################ @@ -114,6 +114,11 @@ proc onPeriodicSync(ctx: ptr SdsContext): PeriodicSyncCallback = callEventCallback(ctx, "onPeriodicSync"): $JsonPeriodicSyncEvent.new() +proc onRepairReady(ctx: ptr SdsContext): RepairReadyCallback = + return proc(message: seq[byte], channelId: SdsChannelID) {.gcsafe.} = + callEventCallback(ctx, "onRepairReady"): + $JsonRepairReadyEvent.new(message, channelId) + proc onRetrievalHint(ctx: ptr SdsContext): RetrievalHintProvider = return proc(messageId: SdsMessageID): seq[byte] {.gcsafe.} = if isNil(ctx.retrievalHintProvider): @@ -196,6 +201,7 @@ proc SdsNewReliabilityManager( missingDependenciesCb: onMissingDependencies(ctx), periodicSyncCb: onPeriodicSync(ctx), retrievalHintProvider: onRetrievalHint(ctx), + repairReadyCb: onRepairReady(ctx), ) let retCode = handleRequest( diff --git a/library/sds_thread/inter_thread_communication/requests/sds_lifecycle_request.nim b/library/sds_thread/inter_thread_communication/requests/sds_lifecycle_request.nim index 8d2e9bc..a0f3adb 100644 --- a/library/sds_thread/inter_thread_communication/requests/sds_lifecycle_request.nim +++ b/library/sds_thread/inter_thread_communication/requests/sds_lifecycle_request.nim @@ -40,7 +40,7 @@ proc createReliabilityManager( rm.setCallbacks( appCallbacks.messageReadyCb, appCallbacks.messageSentCb, appCallbacks.missingDependenciesCb, appCallbacks.periodicSyncCb, - appCallbacks.retrievalHintProvider, + appCallbacks.retrievalHintProvider, appCallbacks.repairReadyCb, ) return ok(rm) diff --git a/sds.nim b/sds.nim index 58d1893..202b489 100644 --- a/sds.nim +++ b/sds.nim @@ -1,15 +1,16 @@ -import std/[times, locks, tables, sets, options] +import std/[algorithm, times, locks, tables, sets, options] import chronos, results, chronicles import sds/[types, protobuf, sds_utils, rolling_bloom_filter] export types, protobuf, sds_utils, rolling_bloom_filter proc newReliabilityManager*( - config: ReliabilityConfig = defaultConfig() + config: ReliabilityConfig = defaultConfig(), + participantId: SdsParticipantID = "".SdsParticipantID, ): Result[ReliabilityManager, ReliabilityError] = ## Creates a new multi-channel ReliabilityManager. try: - let rm = ReliabilityManager.new(config) + let rm = ReliabilityManager.new(config, participantId) return ok(rm) except Exception: error "Failed to create ReliabilityManager", msg = getCurrentExceptionMsg() @@ -88,6 +89,26 @@ proc wrapOutgoingMessage*( error "Failed to serialize bloom filter", channelId = channelId return err(ReliabilityError.reSerializationError) + # SDS-R: collect eligible expired repair requests to attach. Per + # spec (sds-r-send-message, RECOMMENDED), prioritise the entries with + # the smallest minTimeRepairReq — they are the most overdue and the + # ones the network most needs us to ask about. + var repairReqs: seq[HistoryEntry] = @[] + let now = getTime() + var expiredKeys: seq[SdsMessageID] = @[] + var eligible: seq[(SdsMessageID, OutgoingRepairEntry)] = @[] + for msgId, repairEntry in channel.outgoingRepairBuffer: + if now >= repairEntry.minTimeRepairReq: + eligible.add((msgId, repairEntry)) + eligible.sort do(a, b: (SdsMessageID, OutgoingRepairEntry)) -> int: + cmp(a[1].minTimeRepairReq, b[1].minTimeRepairReq) + let take = min(eligible.len, rm.config.maxRepairRequests) + for i in 0 ..< take: + repairReqs.add(eligible[i][1].outHistEntry) + expiredKeys.add(eligible[i][0]) + for key in expiredKeys: + channel.outgoingRepairBuffer.del(key) + let msg = SdsMessage.init( messageId = messageId, lamportTimestamp = channel.lamportTimestamp, @@ -95,6 +116,8 @@ proc wrapOutgoingMessage*( channelId = channelId, content = message, bloomFilter = bfResult.get(), + senderId = rm.participantId, + repairRequest = repairReqs, ) channel.outgoingBuffer.add( @@ -102,7 +125,10 @@ proc wrapOutgoingMessage*( ) channel.bloomFilter.add(msg.messageId) - rm.addToHistory(msg.messageId, channelId) + # The full SdsMessage carries senderId and content, so a single + # addToHistory replaces the old triple-write to messageHistory, + # messageCache, and messageSenders. + rm.addToHistory(msg, channelId) return serializeMessage(msg) except Exception: @@ -133,7 +159,7 @@ proc processIncomingBuffer(rm: ReliabilityManager, channelId: SdsChannelID) {.gc continue if msgId in channel.incomingBuffer: - rm.addToHistory(msgId, channelId) + rm.addToHistory(channel.incomingBuffer[msgId].message, channelId) if not rm.onMessageReady.isNil(): rm.onMessageReady(msgId, channelId) processed.incl(msgId) @@ -164,6 +190,11 @@ proc unwrapReceivedMessage*( let channel = rm.getOrCreateChannel(channelId) + # SDS-R: opportunistic repair-buffer cleanup — applies to duplicates too, + # so rebroadcasts cancel redundant responses on peers that already have the message. + channel.outgoingRepairBuffer.del(msg.messageId) + channel.incomingRepairBuffer.del(msg.messageId) + if msg.messageId in channel.messageHistory: return ok((msg.content, @[], channelId)) @@ -172,6 +203,32 @@ proc unwrapReceivedMessage*( rm.updateLamportTimestamp(msg.lamportTimestamp, channelId) rm.reviewAckStatus(msg) + # SDS-R: process incoming repair requests from this message. We can only + # answer for messages we have actually delivered (i.e. that live in + # messageHistory) — buffered-but-undelivered messages are not in a state + # to confidently rebroadcast. + let now = getTime() + for repairEntry in msg.repairRequest: + # Remove from our own outgoing repair buffer (someone else is also requesting) + channel.outgoingRepairBuffer.del(repairEntry.messageId) + if repairEntry.messageId in channel.messageHistory and + rm.participantId.len > 0 and repairEntry.senderId.len > 0: + if isInResponseGroup( + rm.participantId, repairEntry.senderId, + repairEntry.messageId, rm.config.numResponseGroups + ): + let serialized = serializeMessage(channel.messageHistory[repairEntry.messageId]) + if serialized.isOk(): + let tResp = computeTResp( + rm.participantId, repairEntry.senderId, + repairEntry.messageId, rm.config.repairTMax + ) + channel.incomingRepairBuffer[repairEntry.messageId] = IncomingRepairEntry( + inHistEntry: repairEntry, + cachedMessage: serialized.get(), + minTimeRepairResp: now + tResp, + ) + var missingDeps = rm.checkDependencies(msg.causalHistory, channelId) if missingDeps.len == 0: @@ -184,7 +241,11 @@ proc unwrapReceivedMessage*( channel.incomingBuffer[msg.messageId] = IncomingMessage.init(message = msg, missingDeps = initHashSet[SdsMessageID]()) else: - rm.addToHistory(msg.messageId, channelId) + rm.addToHistory(msg, channelId) + # Unblock any buffered messages that were waiting on this one. + for pendingId, entry in channel.incomingBuffer: + if msg.messageId in entry.missingDeps: + channel.incomingBuffer[pendingId].missingDeps.excl(msg.messageId) rm.processIncomingBuffer(channelId) if not rm.onMessageReady.isNil(): rm.onMessageReady(msg.messageId, channelId) @@ -197,6 +258,19 @@ proc unwrapReceivedMessage*( if not rm.onMissingDependencies.isNil(): rm.onMissingDependencies(msg.messageId, missingDeps, channelId) + # SDS-R: add missing deps to outgoing repair buffer + if rm.participantId.len > 0: + for dep in missingDeps: + if dep.messageId notin channel.outgoingRepairBuffer: + let tReq = computeTReq( + rm.participantId, dep.messageId, + rm.config.repairTMin, rm.config.repairTMax + ) + channel.outgoingRepairBuffer[dep.messageId] = OutgoingRepairEntry( + outHistEntry: dep, + minTimeRepairReq: now + tReq, + ) + return ok((msg.content, missingDeps, channelId)) except Exception: error "Failed to unwrap message", msg = getCurrentExceptionMsg() @@ -220,6 +294,10 @@ proc markDependenciesMet*( if msgId in entry.missingDeps: channel.incomingBuffer[pendingId].missingDeps.excl(msgId) + # SDS-R: clear from repair buffers (dependency now met) + channel.outgoingRepairBuffer.del(msgId) + channel.incomingRepairBuffer.del(msgId) + rm.processIncomingBuffer(channelId) return ok() except Exception: @@ -234,6 +312,7 @@ proc setCallbacks*( onMissingDependencies: MissingDependenciesCallback, onPeriodicSync: PeriodicSyncCallback = nil, onRetrievalHint: RetrievalHintProvider = nil, + onRepairReady: RepairReadyCallback = nil, ) = ## Sets the callback functions for various events in the ReliabilityManager. withLock rm.lock: @@ -242,6 +321,7 @@ proc setCallbacks*( rm.onMissingDependencies = onMissingDependencies rm.onPeriodicSync = onPeriodicSync rm.onRetrievalHint = onRetrievalHint + rm.onRepairReady = onRepairReady proc checkUnacknowledgedMessages( rm: ReliabilityManager, channelId: SdsChannelID @@ -299,10 +379,57 @@ proc periodicSyncMessage( error "Error in periodic sync", msg = getCurrentExceptionMsg() await sleepAsync(chronos.seconds(rm.config.syncMessageInterval.inSeconds)) +proc runRepairSweep*(rm: ReliabilityManager) {.gcsafe, raises: [].} = + ## SDS-R: Runs a single pass of the repair sweep. + ## - Incoming: fires onRepairReady for expired T_resp entries and removes them + ## - Outgoing: drops entries past T_max window + ## Exposed so it can be driven directly in tests; also invoked by periodicRepairSweep. + ## Acquires rm.lock so the repair buffers cannot be observed mid-mutation by + ## a concurrent wrapOutgoingMessage / unwrapReceivedMessage on another thread. + withLock rm.lock: + try: + let now = getTime() + for channelId, channel in rm.channels: + try: + # Check incoming repair buffer for expired T_resp (time to rebroadcast) + var toRebroadcast: seq[SdsMessageID] = @[] + for msgId, entry in channel.incomingRepairBuffer: + if now >= entry.minTimeRepairResp: + toRebroadcast.add(msgId) + + for msgId in toRebroadcast: + let entry = channel.incomingRepairBuffer[msgId] + channel.incomingRepairBuffer.del(msgId) + if not rm.onRepairReady.isNil(): + rm.onRepairReady(entry.cachedMessage, channelId) + + # Drop expired outgoing repair entries past T_max + var toRemove: seq[SdsMessageID] = @[] + let tMaxDuration = rm.config.repairTMax + for msgId, entry in channel.outgoingRepairBuffer: + if now - entry.minTimeRepairReq > tMaxDuration: + toRemove.add(msgId) + for msgId in toRemove: + channel.outgoingRepairBuffer.del(msgId) + except Exception: + error "Error in repair sweep for channel", + channelId = channelId, msg = getCurrentExceptionMsg() + except Exception: + error "Error in repair sweep", msg = getCurrentExceptionMsg() + +proc periodicRepairSweep( + rm: ReliabilityManager +) {.async: (raises: [CancelledError]), gcsafe.} = + ## SDS-R: Periodically checks repair buffers for expired entries. + while true: + rm.runRepairSweep() + await sleepAsync(chronos.milliseconds(rm.config.repairSweepInterval.inMilliseconds)) + proc startPeriodicTasks*(rm: ReliabilityManager) = ## Starts the periodic tasks for buffer sweeping and sync message sending. asyncSpawn rm.periodicBufferSweep() asyncSpawn rm.periodicSyncMessage() + asyncSpawn rm.periodicRepairSweep() proc resetReliabilityManager*(rm: ReliabilityManager): Result[void, ReliabilityError] = ## Resets the ReliabilityManager to its initial state. @@ -310,9 +437,11 @@ proc resetReliabilityManager*(rm: ReliabilityManager): Result[void, ReliabilityE try: for channelId, channel in rm.channels: channel.lamportTimestamp = 0 - channel.messageHistory.setLen(0) + channel.messageHistory.clear() channel.outgoingBuffer.setLen(0) channel.incomingBuffer.clear() + channel.outgoingRepairBuffer.clear() + channel.incomingRepairBuffer.clear() channel.bloomFilter = RollingBloomFilter.init(rm.config.bloomFilterCapacity, rm.config.bloomFilterErrorRate) rm.channels.clear() diff --git a/sds/message.nim b/sds/message.nim index ddf5e5f..410fc43 100644 --- a/sds/message.nim +++ b/sds/message.nim @@ -3,6 +3,7 @@ import ./types/history_entry import ./types/sds_message import ./types/unacknowledged_message import ./types/incoming_message +import ./types/repair_entry import ./types/reliability_config export @@ -11,4 +12,5 @@ export sds_message, unacknowledged_message, incoming_message, + repair_entry, reliability_config diff --git a/sds/protobuf.nim b/sds/protobuf.nim index ba1b7ff..916bf18 100644 --- a/sds/protobuf.nim +++ b/sds/protobuf.nim @@ -5,6 +5,26 @@ import ./protobufutil import ./bloom import ./sds_utils +proc encodeHistoryEntry*(entry: HistoryEntry): ProtoBuffer = + var entryPb = initProtoBuffer() + entryPb.write(1, entry.messageId) + if entry.retrievalHint.len > 0: + entryPb.write(2, entry.retrievalHint) + if entry.senderId.len > 0: + entryPb.write(3, entry.senderId.string) + entryPb.finish() + entryPb + +proc decodeHistoryEntry*(entryPb: ProtoBuffer): ProtobufResult[HistoryEntry] = + var entry = HistoryEntry.init("") + if not ?entryPb.getField(1, entry.messageId): + return err(ProtobufError.missingRequiredField("HistoryEntry.messageId")) + discard entryPb.getField(2, entry.retrievalHint) + var senderIdStr: string + if entryPb.getField(3, senderIdStr).valueOr(false): + entry.senderId = senderIdStr.SdsParticipantID + ok(entry) + proc encode*(msg: SdsMessage): ProtoBuffer = var pb = initProtoBuffer() @@ -12,16 +32,20 @@ proc encode*(msg: SdsMessage): ProtoBuffer = pb.write(2, uint64(msg.lamportTimestamp)) for entry in msg.causalHistory: - var entryPb = initProtoBuffer() - entryPb.write(1, entry.messageId) - if entry.retrievalHint.len > 0: - entryPb.write(2, entry.retrievalHint) - entryPb.finish() + let entryPb = encodeHistoryEntry(entry) pb.write(3, entryPb.buffer) pb.write(4, msg.channelId) pb.write(5, msg.content) pb.write(6, msg.bloomFilter) + + if msg.senderId.len > 0: + pb.write(7, msg.senderId.string) + + for entry in msg.repairRequest: + let entryPb = encodeHistoryEntry(entry) + pb.write(13, entryPb.buffer) + pb.finish() return pb @@ -44,11 +68,7 @@ proc decode*(T: type SdsMessage, buffer: seq[byte]): ProtobufResult[T] = # New format: repeated HistoryEntry for histBuffer in historyBuffers: let entryPb = initProtoBuffer(histBuffer) - var entry = HistoryEntry.init("") - if not ?entryPb.getField(1, entry.messageId): - return err(ProtobufError.missingRequiredField("HistoryEntry.messageId")) - # retrievalHint is optional - discard entryPb.getField(2, entry.retrievalHint) + let entry = ?decodeHistoryEntry(entryPb) msg.causalHistory.add(entry) else: # Try old format: repeated string @@ -66,6 +86,19 @@ proc decode*(T: type SdsMessage, buffer: seq[byte]): ProtobufResult[T] = if not ?pb.getField(6, msg.bloomFilter): msg.bloomFilter = @[] # Empty if not present + # SDS-R: decode senderId (field 7, optional) + var msgSenderIdStr: string + if pb.getField(7, msgSenderIdStr).valueOr(false): + msg.senderId = msgSenderIdStr.SdsParticipantID + + # SDS-R: decode repair request (field 13, optional) + var repairBuffers: seq[seq[byte]] + if pb.getRepeatedField(13, repairBuffers).isOk(): + for repairBuffer in repairBuffers: + let entryPb = initProtoBuffer(repairBuffer) + let entry = ?decodeHistoryEntry(entryPb) + msg.repairRequest.add(entry) + return ok(msg) proc extractChannelId*(data: seq[byte]): Result[SdsChannelID, ReliabilityError] = diff --git a/sds/sds_utils.nim b/sds/sds_utils.nim index f1a68ca..eefae43 100644 --- a/sds/sds_utils.nim +++ b/sds/sds_utils.nim @@ -1,15 +1,15 @@ -import std/[locks, tables, sequtils] +import std/[times, locks, tables, sequtils, hashes] import chronicles, results import ./rolling_bloom_filter import ./types/[ sds_message_id, history_entry, sds_message, unacknowledged_message, incoming_message, - reliability_error, callbacks, app_callbacks, reliability_config, channel_context, - reliability_manager, + reliability_error, callbacks, app_callbacks, reliability_config, repair_entry, + channel_context, reliability_manager, ] export sds_message_id, history_entry, sds_message, unacknowledged_message, incoming_message, - reliability_error, callbacks, app_callbacks, reliability_config, channel_context, - reliability_manager + reliability_error, callbacks, app_callbacks, reliability_config, repair_entry, + channel_context, reliability_manager proc defaultConfig*(): ReliabilityConfig = return ReliabilityConfig.init() @@ -21,7 +21,9 @@ proc cleanup*(rm: ReliabilityManager) {.raises: [].} = for channelId, channel in rm.channels: channel.outgoingBuffer.setLen(0) channel.incomingBuffer.clear() - channel.messageHistory.setLen(0) + channel.messageHistory.clear() + channel.outgoingRepairBuffer.clear() + channel.incomingRepairBuffer.clear() rm.channels.clear() except Exception: error "Error during cleanup", error = getCurrentExceptionMsg() @@ -38,17 +40,25 @@ proc cleanBloomFilter*( error = getCurrentExceptionMsg(), channelId = channelId proc addToHistory*( - rm: ReliabilityManager, msgId: SdsMessageID, channelId: SdsChannelID + rm: ReliabilityManager, msg: SdsMessage, channelId: SdsChannelID ) {.gcsafe, raises: [].} = + ## Inserts a delivered message into the channel's history map and evicts the + ## eldest entries when the bound is exceeded. The full SdsMessage is kept so + ## senderId is available for downstream causal-history population and the + ## bytes can be re-serialized on demand to answer SDS-R repair requests. try: if channelId in rm.channels: let channel = rm.channels[channelId] - channel.messageHistory.add(msgId) - if channel.messageHistory.len > rm.config.maxMessageHistory: - channel.messageHistory.delete(0) + channel.messageHistory[msg.messageId] = msg + while channel.messageHistory.len > rm.config.maxMessageHistory: + var firstKey: SdsMessageID + for k in channel.messageHistory.keys: + firstKey = k + break + channel.messageHistory.del(firstKey) except Exception: error "Failed to add to history", - channelId = channelId, msgId = msgId, error = getCurrentExceptionMsg() + channelId = channelId, msgId = msg.messageId, error = getCurrentExceptionMsg() proc updateLamportTimestamp*( rm: ReliabilityManager, msgTs: int64, channelId: SdsChannelID @@ -70,21 +80,79 @@ proc toCausalHistory*(messageIds: seq[SdsMessageID]): seq[HistoryEntry] = proc getMessageIds*(causalHistory: seq[HistoryEntry]): seq[SdsMessageID] = return causalHistory.mapIt(it.messageId) +## SDS-R: Repair computation functions + +proc computeTReq*( + participantId: SdsParticipantID, + messageId: SdsMessageID, + tMin: Duration, + tMax: Duration, +): Duration = + ## Computes the repair request backoff duration per SDS-R spec: + ## T_req = hash(participant_id, message_id) % (T_max - T_min) + T_min + let h = abs(hash(participantId.string & messageId)) + let rangeMs = tMax.inMilliseconds - tMin.inMilliseconds + if rangeMs <= 0: + return tMin + let offsetMs = h mod rangeMs + initDuration(milliseconds = tMin.inMilliseconds + offsetMs) + +proc computeTResp*( + participantId: SdsParticipantID, + senderId: SdsParticipantID, + messageId: SdsMessageID, + tMax: Duration, +): Duration = + ## Computes the repair response backoff duration per SDS-R spec: + ## distance = hash(participant_id) XOR hash(sender_id) + ## T_resp = distance * hash(message_id) % T_max + ## Original sender has distance=0, so T_resp=0 (responds immediately). + let distance = abs(hash(participantId) xor hash(senderId)) + let msgHash = abs(hash(messageId)) + let tMaxMs = tMax.inMilliseconds + if tMaxMs <= 0 or distance == 0: + return initDuration(milliseconds = 0) + # Use uint64 to avoid overflow on multiplication + let d = uint64(distance mod tMaxMs) + let m = uint64(msgHash mod tMaxMs) + let offsetMs = int64((d * m) mod uint64(tMaxMs)) + initDuration(milliseconds = offsetMs) + +proc isInResponseGroup*( + participantId: SdsParticipantID, + senderId: SdsParticipantID, + messageId: SdsMessageID, + numResponseGroups: int, +): bool = + ## Determines if this participant is in the response group for a given message per SDS-R spec: + ## hash(participant_id, message_id) % num_groups == hash(sender_id, message_id) % num_groups + if numResponseGroups <= 1: + return true # All participants in the same group + let myGroup = abs(hash(participantId.string & messageId)) mod numResponseGroups + let senderGroup = abs(hash(senderId.string & messageId)) mod numResponseGroups + myGroup == senderGroup + proc getRecentHistoryEntries*( rm: ReliabilityManager, n: int, channelId: SdsChannelID ): seq[HistoryEntry] = + ## Get recent history entries for sending in causal history. + ## Populates retrieval hints and senderId (SDS-R) for each entry. try: if channelId in rm.channels: let channel = rm.channels[channelId] - let recentMessageIds = channel.messageHistory[max(0, channel.messageHistory.len - n) .. ^1] - if rm.onRetrievalHint.isNil(): - return toCausalHistory(recentMessageIds) - else: - var entries: seq[HistoryEntry] = @[] - for msgId in recentMessageIds: - let hint = rm.onRetrievalHint(msgId) - entries.add(newHistoryEntry(msgId, hint)) - return entries + var orderedIds: seq[SdsMessageID] = @[] + for msgId in channel.messageHistory.keys: + orderedIds.add(msgId) + let recentMessageIds = + orderedIds[max(0, orderedIds.len - n) .. ^1] + var entries: seq[HistoryEntry] = @[] + for msgId in recentMessageIds: + var entry = HistoryEntry(messageId: msgId) + if not rm.onRetrievalHint.isNil(): + entry.retrievalHint = rm.onRetrievalHint(msgId) + entry.senderId = channel.messageHistory[msgId].senderId + entries.add(entry) + return entries else: return @[] except Exception: @@ -116,7 +184,10 @@ proc getMessageHistory*( withLock rm.lock: try: if channelId in rm.channels: - return rm.channels[channelId].messageHistory + var ids: seq[SdsMessageID] = @[] + for msgId in rm.channels[channelId].messageHistory.keys: + ids.add(msgId) + return ids else: return @[] except Exception: @@ -187,7 +258,9 @@ proc removeChannel*( let channel = rm.channels[channelId] channel.outgoingBuffer.setLen(0) channel.incomingBuffer.clear() - channel.messageHistory.setLen(0) + channel.messageHistory.clear() + channel.outgoingRepairBuffer.clear() + channel.incomingRepairBuffer.clear() rm.channels.del(channelId) return ok() except Exception: diff --git a/sds/types.nim b/sds/types.nim index f37518a..637ec54 100644 --- a/sds/types.nim +++ b/sds/types.nim @@ -9,6 +9,7 @@ import sds/types/reliability_error import sds/types/callbacks import sds/types/app_callbacks import sds/types/reliability_config +import sds/types/repair_entry import sds/types/channel_context import sds/types/reliability_manager import sds/types/protobuf_error @@ -25,6 +26,7 @@ export callbacks, app_callbacks, reliability_config, + repair_entry, channel_context, reliability_manager, protobuf_error diff --git a/sds/types/app_callbacks.nim b/sds/types/app_callbacks.nim index 985a97f..84873a6 100644 --- a/sds/types/app_callbacks.nim +++ b/sds/types/app_callbacks.nim @@ -7,6 +7,7 @@ type AppCallbacks* = ref object missingDependenciesCb*: MissingDependenciesCallback periodicSyncCb*: PeriodicSyncCallback retrievalHintProvider*: RetrievalHintProvider + repairReadyCb*: RepairReadyCallback proc new*( T: type AppCallbacks, @@ -15,6 +16,7 @@ proc new*( missingDependenciesCb: MissingDependenciesCallback = nil, periodicSyncCb: PeriodicSyncCallback = nil, retrievalHintProvider: RetrievalHintProvider = nil, + repairReadyCb: RepairReadyCallback = nil, ): T = return T( messageReadyCb: messageReadyCb, @@ -22,4 +24,5 @@ proc new*( missingDependenciesCb: missingDependenciesCb, periodicSyncCb: periodicSyncCb, retrievalHintProvider: retrievalHintProvider, + repairReadyCb: repairReadyCb, ) diff --git a/sds/types/callbacks.nim b/sds/types/callbacks.nim index f1fc4b3..1894f14 100644 --- a/sds/types/callbacks.nim +++ b/sds/types/callbacks.nim @@ -16,3 +16,5 @@ type RetrievalHintProvider* = proc(messageId: SdsMessageID): seq[byte] {.gcsafe.} PeriodicSyncCallback* = proc() {.gcsafe, raises: [].} + + RepairReadyCallback* = proc(message: seq[byte], channelId: SdsChannelID) {.gcsafe.} diff --git a/sds/types/channel_context.nim b/sds/types/channel_context.nim index 0346d18..3f7bee6 100644 --- a/sds/types/channel_context.nim +++ b/sds/types/channel_context.nim @@ -1,22 +1,36 @@ import std/tables import ./sds_message_id +import ./sds_message import ./rolling_bloom_filter import ./unacknowledged_message import ./incoming_message -export sds_message_id, rolling_bloom_filter, unacknowledged_message, incoming_message +import ./repair_entry +export + sds_message_id, sds_message, rolling_bloom_filter, unacknowledged_message, + incoming_message, repair_entry type ChannelContext* = ref object lamportTimestamp*: int64 - messageHistory*: seq[SdsMessageID] + messageHistory*: OrderedTable[SdsMessageID, SdsMessage] + ## Single source of truth for delivered messages. Holds the deserialized + ## SdsMessage (which carries senderId, lamportTimestamp, content, etc.) so + ## causal history, sender lookup, and SDS-R repair responses can all be + ## answered from one place. OrderedTable preserves insertion order for + ## causal-history tail access and FIFO eviction at maxMessageHistory. bloomFilter*: RollingBloomFilter outgoingBuffer*: seq[UnacknowledgedMessage] incomingBuffer*: Table[SdsMessageID, IncomingMessage] + ## SDS-R buffers + outgoingRepairBuffer*: Table[SdsMessageID, OutgoingRepairEntry] + incomingRepairBuffer*: Table[SdsMessageID, IncomingRepairEntry] proc new*(T: type ChannelContext, bloomFilter: RollingBloomFilter): T = return T( lamportTimestamp: 0, - messageHistory: @[], + messageHistory: initOrderedTable[SdsMessageID, SdsMessage](), bloomFilter: bloomFilter, outgoingBuffer: @[], incomingBuffer: initTable[SdsMessageID, IncomingMessage](), + outgoingRepairBuffer: initTable[SdsMessageID, OutgoingRepairEntry](), + incomingRepairBuffer: initTable[SdsMessageID, IncomingRepairEntry](), ) diff --git a/sds/types/history_entry.nim b/sds/types/history_entry.nim index 2435e6f..d06afac 100644 --- a/sds/types/history_entry.nim +++ b/sds/types/history_entry.nim @@ -1,8 +1,15 @@ import ./sds_message_id +export sds_message_id type HistoryEntry* = object messageId*: SdsMessageID retrievalHint*: seq[byte] ## Optional hint for efficient retrieval (e.g., Waku message hash) + senderId*: SdsParticipantID ## Original message sender's participant ID (SDS-R) -proc init*(T: type HistoryEntry, messageId: SdsMessageID, retrievalHint: seq[byte] = @[]): T = - return T(messageId: messageId, retrievalHint: retrievalHint) +proc init*( + T: type HistoryEntry, + messageId: SdsMessageID, + retrievalHint: seq[byte] = @[], + senderId: SdsParticipantID = "".SdsParticipantID, +): T = + return T(messageId: messageId, retrievalHint: retrievalHint, senderId: senderId) diff --git a/sds/types/reliability_config.nim b/sds/types/reliability_config.nim index f4e4e78..7cd20f2 100644 --- a/sds/types/reliability_config.nim +++ b/sds/types/reliability_config.nim @@ -7,6 +7,11 @@ const DefaultMaxResendAttempts* = 5 DefaultSyncMessageInterval* = initDuration(seconds = 30) DefaultBufferSweepInterval* = initDuration(seconds = 60) + DefaultRepairTMin* = initDuration(seconds = 30) + DefaultRepairTMax* = initDuration(seconds = 300) + DefaultNumResponseGroups* = 1 + DefaultMaxRepairRequests* = 3 + DefaultRepairSweepInterval* = initDuration(seconds = 5) MaxMessageSize* = 1024 * 1024 # 1 MB import ./rolling_bloom_filter @@ -21,6 +26,12 @@ type ReliabilityConfig* {.requiresInit.} = object maxResendAttempts*: int syncMessageInterval*: Duration bufferSweepInterval*: Duration + ## SDS-R config + repairTMin*: Duration + repairTMax*: Duration + numResponseGroups*: int + maxRepairRequests*: int + repairSweepInterval*: Duration proc init*( T: type ReliabilityConfig, @@ -32,6 +43,11 @@ proc init*( maxResendAttempts: int = DefaultMaxResendAttempts, syncMessageInterval: Duration = DefaultSyncMessageInterval, bufferSweepInterval: Duration = DefaultBufferSweepInterval, + repairTMin: Duration = DefaultRepairTMin, + repairTMax: Duration = DefaultRepairTMax, + numResponseGroups: int = DefaultNumResponseGroups, + maxRepairRequests: int = DefaultMaxRepairRequests, + repairSweepInterval: Duration = DefaultRepairSweepInterval, ): T = return T( bloomFilterCapacity: bloomFilterCapacity, @@ -42,4 +58,9 @@ proc init*( maxResendAttempts: maxResendAttempts, syncMessageInterval: syncMessageInterval, bufferSweepInterval: bufferSweepInterval, + repairTMin: repairTMin, + repairTMax: repairTMax, + numResponseGroups: numResponseGroups, + maxRepairRequests: maxRepairRequests, + repairSweepInterval: repairSweepInterval, ) diff --git a/sds/types/reliability_manager.nim b/sds/types/reliability_manager.nim index 9bfc244..d28ee5d 100644 --- a/sds/types/reliability_manager.nim +++ b/sds/types/reliability_manager.nim @@ -9,6 +9,7 @@ export sds_message_id, history_entry, callbacks, reliability_config, channel_con type ReliabilityManager* = ref object channels*: Table[SdsChannelID, ChannelContext] config*: ReliabilityConfig + participantId*: SdsParticipantID lock*: Lock onMessageReady*: proc(messageId: SdsMessageID, channelId: SdsChannelID) {.gcsafe.} onMessageSent*: proc(messageId: SdsMessageID, channelId: SdsChannelID) {.gcsafe.} @@ -17,11 +18,17 @@ type ReliabilityManager* = ref object ) {.gcsafe.} onPeriodicSync*: PeriodicSyncCallback onRetrievalHint*: RetrievalHintProvider + onRepairReady*: RepairReadyCallback -proc new*(T: type ReliabilityManager, config: ReliabilityConfig): T = +proc new*( + T: type ReliabilityManager, + config: ReliabilityConfig, + participantId: SdsParticipantID = "".SdsParticipantID, +): T = let rm = T( channels: initTable[SdsChannelID, ChannelContext](), config: config, + participantId: participantId, ) rm.lock.initLock() return rm diff --git a/sds/types/repair_entry.nim b/sds/types/repair_entry.nim new file mode 100644 index 0000000..04503b9 --- /dev/null +++ b/sds/types/repair_entry.nim @@ -0,0 +1,36 @@ +import std/times +import ./history_entry +export history_entry + +type + OutgoingRepairEntry* {.requiresInit.} = object + ## Entry in the outgoing repair request buffer (SDS-R). + ## Tracks a missing message we want to request repair for. + outHistEntry*: HistoryEntry ## The missing history entry + minTimeRepairReq*: Time + ## Earliest time at which we will include this in a repair request (T_REQ in spec) + + IncomingRepairEntry* {.requiresInit.} = object + ## Entry in the incoming repair request buffer (SDS-R). + ## Tracks a repair request from a remote peer that we might respond to. + inHistEntry*: HistoryEntry ## The requested history entry + cachedMessage*: seq[byte] ## Full serialized SDS message for rebroadcast + minTimeRepairResp*: Time + ## Earliest time at which we will rebroadcast (T_RESP in spec) + +proc init*( + T: type OutgoingRepairEntry, outHistEntry: HistoryEntry, minTimeRepairReq: Time +): T = + return T(outHistEntry: outHistEntry, minTimeRepairReq: minTimeRepairReq) + +proc init*( + T: type IncomingRepairEntry, + inHistEntry: HistoryEntry, + cachedMessage: seq[byte], + minTimeRepairResp: Time, +): T = + return T( + inHistEntry: inHistEntry, + cachedMessage: cachedMessage, + minTimeRepairResp: minTimeRepairResp, + ) diff --git a/sds/types/sds_message.nim b/sds/types/sds_message.nim index 12f7add..82197cd 100644 --- a/sds/types/sds_message.nim +++ b/sds/types/sds_message.nim @@ -9,6 +9,9 @@ type SdsMessage* {.requiresInit.} = object channelId*: SdsChannelID content*: seq[byte] bloomFilter*: seq[byte] + senderId*: SdsParticipantID ## SDS-R: original sender's participant ID + repairRequest*: seq[HistoryEntry] + ## Capped list of missing entries requesting repair (SDS-R) proc init*( T: type SdsMessage, @@ -18,6 +21,8 @@ proc init*( channelId: SdsChannelID, content: seq[byte], bloomFilter: seq[byte], + senderId: SdsParticipantID = "".SdsParticipantID, + repairRequest: seq[HistoryEntry] = @[], ): T = return T( messageId: messageId, @@ -26,4 +31,6 @@ proc init*( channelId: channelId, content: content, bloomFilter: bloomFilter, + senderId: senderId, + repairRequest: repairRequest, ) diff --git a/sds/types/sds_message_id.nim b/sds/types/sds_message_id.nim index 3e8b7c7..05f1ab4 100644 --- a/sds/types/sds_message_id.nim +++ b/sds/types/sds_message_id.nim @@ -1,3 +1,11 @@ +import std/hashes + type SdsMessageID* = string SdsChannelID* = string + SdsParticipantID* = distinct string + +proc `==`*(a, b: SdsParticipantID): bool {.borrow.} +proc `$`*(p: SdsParticipantID): string {.borrow.} +proc len*(p: SdsParticipantID): int {.borrow.} +proc hash*(p: SdsParticipantID): Hash {.borrow.} diff --git a/tests/test_reliability.nim b/tests/test_reliability.nim index 7100606..290ac9c 100644 --- a/tests/test_reliability.nim +++ b/tests/test_reliability.nim @@ -1,8 +1,22 @@ import unittest, results, chronos, std/[times, options, tables] import sds +# Test-only convenience: implicit string → SdsParticipantID so test fixtures +# can use string literals. Production code retains the distinct-type safety. +converter toParticipantID(s: string): SdsParticipantID = s.SdsParticipantID + const testChannel = "testChannel" +proc seedBloom( + rm: ReliabilityManager, channel: SdsChannelID, n: int, prefix = "noise" +) = + ## Pre-populate a channel's bloom filter with n unrelated ids so the test + ## exercises the manager against a realistic, non-empty filter rather than + ## the implicit empty one a fresh ReliabilityManager would produce. + let ch = rm.channels[channel] + for i in 0 ..< n: + ch.bloomFilter.add(prefix & $i) + # Core functionality tests suite "Core Operations": var rm: ReliabilityManager @@ -41,24 +55,49 @@ suite "Core Operations": missingDeps.len == 0 channelId == testChannel + test "basic message wrapping and unwrapping (non-empty bloom)": + rm.seedBloom(testChannel, 50) + + let msg = @[byte(1), 2, 3] + let msgId = "test-msg-1" + + let wrappedResult = rm.wrapOutgoingMessage(msg, msgId, testChannel) + check wrappedResult.isOk() + let wrapped = wrappedResult.get() + check wrapped.len > 0 + + # The outgoing message must carry the populated bloom snapshot, not an + # empty one — this is the path that was never exercised before. + let decoded = deserializeMessage(wrapped) + check decoded.isOk() + check decoded.get().bloomFilter.len > 0 + + let unwrapResult = rm.unwrapReceivedMessage(wrapped) + check unwrapResult.isOk() + let (unwrapped, missingDeps, channelId) = unwrapResult.get() + check: + unwrapped == msg + missingDeps.len == 0 + channelId == testChannel + test "message ordering": # Create messages with different timestamps - let msg1 = SdsMessage( - messageId: "msg1", - lamportTimestamp: 1, - causalHistory: @[], - channelId: testChannel, - content: @[byte(1)], - bloomFilter: @[], + let msg1 = SdsMessage.init( + messageId = "msg1", + lamportTimestamp = 1, + causalHistory = @[], + channelId = testChannel, + content = @[byte(1)], + bloomFilter = @[], ) - let msg2 = SdsMessage( - messageId: "msg2", - lamportTimestamp: 5, - causalHistory: @[], - channelId: testChannel, - content: @[byte(2)], - bloomFilter: @[], + let msg2 = SdsMessage.init( + messageId = "msg2", + lamportTimestamp = 5, + causalHistory = @[], + channelId = testChannel, + content = @[byte(2)], + bloomFilter = @[], ) let serialized1 = serializeMessage(msg1) @@ -109,22 +148,22 @@ suite "Reliability Mechanisms": let id3 = "msg3" # Create messages with dependencies - let msg2 = SdsMessage( - messageId: id2, - lamportTimestamp: 2, - causalHistory: toCausalHistory(@[id1]), # msg2 depends on msg1 - channelId: testChannel, - content: @[byte(2)], - bloomFilter: @[], + let msg2 = SdsMessage.init( + messageId = id2, + lamportTimestamp = 2, + causalHistory = toCausalHistory(@[id1]), # msg2 depends on msg1 + channelId = testChannel, + content = @[byte(2)], + bloomFilter = @[], ) - let msg3 = SdsMessage( - messageId: id3, - lamportTimestamp: 3, - causalHistory: toCausalHistory(@[id1, id2]), # msg3 depends on both msg1 and msg2 - channelId: testChannel, - content: @[byte(3)], - bloomFilter: @[], + let msg3 = SdsMessage.init( + messageId = id3, + lamportTimestamp = 3, + causalHistory = toCausalHistory(@[id1, id2]), # msg3 depends on both msg1 and msg2 + channelId = testChannel, + content = @[byte(3)], + bloomFilter = @[], ) let serialized2 = serializeMessage(msg2) @@ -166,6 +205,52 @@ suite "Reliability Mechanisms": messageReadyCount == 2 # Both msg2 and msg3 should be ready missingDepsCount == 2 # Should still be 2 from the initial missing deps + test "dependency detection and resolution (non-empty bloom)": + # A populated bloom filter must not short-circuit the dependency check. + # Dependency resolution reads messageHistory, not the bloom — but a future + # "optimisation" could regress this. Seed the bloom with the dep id so a + # bloom-based shortcut would mistakenly mark the dep as satisfied. + var missingDepsCount = 0 + var messageReadyCount = 0 + + rm.setCallbacks( + proc(messageId: SdsMessageID, channelId: SdsChannelID) {.gcsafe.} = + messageReadyCount += 1, + proc(messageId: SdsMessageID, channelId: SdsChannelID) {.gcsafe.} = + discard, + proc(messageId: SdsMessageID, missingDeps: seq[HistoryEntry], channelId: SdsChannelID) {.gcsafe.} = + missingDepsCount += 1, + ) + + let id1 = "msg1" + let id2 = "msg2" + + rm.seedBloom(testChannel, 30) + # Crucially, also seed the bloom with id1 itself — the dep we will be + # missing from messageHistory. The manager must still report it missing. + rm.channels[testChannel].bloomFilter.add(id1) + + let msg2 = SdsMessage.init( + messageId = id2, + lamportTimestamp = 2, + causalHistory = toCausalHistory(@[id1]), + channelId = testChannel, + content = @[byte(2)], + bloomFilter = @[], + ) + let serialized2 = serializeMessage(msg2) + check serialized2.isOk() + + let unwrapResult = rm.unwrapReceivedMessage(serialized2.get()) + check unwrapResult.isOk() + let (_, missingDeps, _) = unwrapResult.get() + + check: + missingDepsCount == 1 + missingDeps.len == 1 + id1 in missingDeps.getMessageIds() + messageReadyCount == 0 + test "acknowledgment via causal history": var messageReadyCount = 0 var messageSentCount = 0 @@ -187,13 +272,13 @@ suite "Reliability Mechanisms": check wrap1.isOk() # Create a message that has our message in causal history - let msg2 = SdsMessage( - messageId: "msg2", - lamportTimestamp: rm.channels[testChannel].lamportTimestamp + 1, - causalHistory: toCausalHistory(@[id1]), # Include our message in causal history - channelId: testChannel, - content: @[byte(2)], - bloomFilter: @[] # Test with an empty bloom filter + let msg2 = SdsMessage.init( + messageId = "msg2", + lamportTimestamp = rm.channels[testChannel].lamportTimestamp + 1, + causalHistory = toCausalHistory(@[id1]), # Include our message in causal history + channelId = testChannel, + content = @[byte(2)], + bloomFilter = @[] # Test with an empty bloom filter , ) @@ -208,6 +293,47 @@ suite "Reliability Mechanisms": messageReadyCount == 1 # For msg2 which we "received" messageSentCount == 1 # For msg1 which was acknowledged via causal history + test "acknowledgment via causal history (non-empty bloom)": + # The causal-history ack path must not be perturbed by the local channel + # bloom carrying unrelated ids, and the empty bloom on the incoming + # message must not spuriously ack any of them. + var messageReadyCount = 0 + var messageSentCount = 0 + + rm.setCallbacks( + proc(messageId: SdsMessageID, channelId: SdsChannelID) {.gcsafe.} = + messageReadyCount += 1, + proc(messageId: SdsMessageID, channelId: SdsChannelID) {.gcsafe.} = + messageSentCount += 1, + proc(messageId: SdsMessageID, missingDeps: seq[HistoryEntry], channelId: SdsChannelID) {.gcsafe.} = + discard, + ) + + rm.seedBloom(testChannel, 50) + + let msg1 = @[byte(1)] + let id1 = "msg1" + let wrap1 = rm.wrapOutgoingMessage(msg1, id1, testChannel) + check wrap1.isOk() + + let msg2 = SdsMessage.init( + messageId = "msg2", + lamportTimestamp = rm.channels[testChannel].lamportTimestamp + 1, + causalHistory = toCausalHistory(@[id1]), + channelId = testChannel, + content = @[byte(2)], + bloomFilter = @[], + ) + let serializedMsg2 = serializeMessage(msg2) + check serializedMsg2.isOk() + + let unwrapResult = rm.unwrapReceivedMessage(serializedMsg2.get()) + check unwrapResult.isOk() + + check: + messageReadyCount == 1 + messageSentCount == 1 # exactly id1; no spurious acks for the seeded ids + test "acknowledgment via bloom filter": var messageSentCount = 0 @@ -234,13 +360,13 @@ suite "Reliability Mechanisms": let bfResult = serializeBloomFilter(otherPartyBloomFilter.filter) check bfResult.isOk() - let msg2 = SdsMessage( - messageId: "msg2", - lamportTimestamp: rm.channels[testChannel].lamportTimestamp + 1, - causalHistory: @[], # Empty causal history as we're using bloom filter - channelId: testChannel, - content: @[byte(2)], - bloomFilter: bfResult.get(), + let msg2 = SdsMessage.init( + messageId = "msg2", + lamportTimestamp = rm.channels[testChannel].lamportTimestamp + 1, + causalHistory = @[], # Empty causal history as we're using bloom filter + channelId = testChannel, + content = @[byte(2)], + bloomFilter = bfResult.get(), ) let serializedMsg2 = serializeMessage(msg2) @@ -251,6 +377,90 @@ suite "Reliability Mechanisms": check messageSentCount == 1 # Our message should be acknowledged via bloom filter + test "acknowledgment via bloom filter (non-empty bloom)": + # The peer's bloom contains both our outgoing id and a pile of unrelated + # ids. The manager must still ack our message exactly once, and unrelated + # ids in the peer's bloom must not produce spurious sent callbacks. + var messageSentCount = 0 + + rm.setCallbacks( + proc(messageId: SdsMessageID, channelId: SdsChannelID) {.gcsafe.} = + discard, + proc(messageId: SdsMessageID, channelId: SdsChannelID) {.gcsafe.} = + messageSentCount += 1, + proc(messageId: SdsMessageID, missingDeps: seq[HistoryEntry], channelId: SdsChannelID) {.gcsafe.} = + discard, + ) + + let msg1 = @[byte(1)] + let id1 = "msg1" + let wrap1 = rm.wrapOutgoingMessage(msg1, id1, testChannel) + check wrap1.isOk() + + var otherPartyBloomFilter = + RollingBloomFilter.init(DefaultBloomFilterCapacity, DefaultBloomFilterErrorRate) + for i in 0 ..< 100: + otherPartyBloomFilter.add("peer-noise-" & $i) + otherPartyBloomFilter.add(id1) + + let bfResult = serializeBloomFilter(otherPartyBloomFilter.filter) + check bfResult.isOk() + + let msg2 = SdsMessage.init( + messageId = "msg2", + lamportTimestamp = rm.channels[testChannel].lamportTimestamp + 1, + causalHistory = @[], + channelId = testChannel, + content = @[byte(2)], + bloomFilter = bfResult.get(), + ) + let serializedMsg2 = serializeMessage(msg2) + check serializedMsg2.isOk() + + let unwrapResult = rm.unwrapReceivedMessage(serializedMsg2.get()) + check unwrapResult.isOk() + + check messageSentCount == 1 + + test "outgoing message bloom snapshot reflects channel state": + # Until now nothing asserts that wrapOutgoingMessage actually attaches + # the current bloom snapshot — every other test runs against an empty + # filter where the field is empty either way. + rm.seedBloom(testChannel, 40, prefix = "delivered-") + + # Plus a real delivery so we exercise the bloom-on-delivery path too. + let incoming = SdsMessage.init( + messageId = "incoming-1", + lamportTimestamp = 1, + causalHistory = @[], + channelId = testChannel, + content = @[byte(9)], + bloomFilter = @[], + ) + let serIncoming = serializeMessage(incoming) + check serIncoming.isOk() + discard rm.unwrapReceivedMessage(serIncoming.get()) + + let outId = "outgoing-1" + let wrapped = rm.wrapOutgoingMessage(@[byte(1)], outId, testChannel) + check wrapped.isOk() + + let decoded = deserializeMessage(wrapped.get()) + check decoded.isOk() + let attachedFilter = deserializeBloomFilter(decoded.get().bloomFilter) + check attachedFilter.isOk() + + var snapshot = RollingBloomFilter.init( + filter = attachedFilter.get(), + capacity = DefaultBloomFilterCapacity, + minCapacity = 0, + maxCapacity = DefaultBloomFilterCapacity, + ) + check: + snapshot.contains("delivered-0") + snapshot.contains("delivered-39") + snapshot.contains("incoming-1") + test "retrieval hints": var messageReadyCount = 0 var messageSentCount = 0 @@ -287,13 +497,13 @@ suite "Reliability Mechanisms": check unwrappedMsg2.causalHistory[0].retrievalHint == cast[seq[byte]]("hint:" & id1) # Create a message with a missing dependency (no retrieval hint) - let msg3 = SdsMessage( - messageId: "msg3", - lamportTimestamp: 3, - causalHistory: toCausalHistory(@["missing-dep"]), - channelId: testChannel, - content: @[byte(3)], - bloomFilter: @[], + let msg3 = SdsMessage.init( + messageId = "msg3", + lamportTimestamp = 3, + causalHistory = toCausalHistory(@["missing-dep"]), + channelId = testChannel, + content = @[byte(3)], + bloomFilter = @[], ) let serialized3 = serializeMessage(msg3).get() let unwrapResult3 = rm.unwrapReceivedMessage(serialized3) @@ -305,13 +515,13 @@ suite "Reliability Mechanisms": check missingDeps3[0].retrievalHint.len == 0 # Test with a message that HAS a retrieval hint from remote - let msg4 = SdsMessage( - messageId: "msg4", - lamportTimestamp: 4, - causalHistory: @[newHistoryEntry("another-missing", cast[seq[byte]]("remote-hint"))], - channelId: testChannel, - content: @[byte(4)], - bloomFilter: @[], + let msg4 = SdsMessage.init( + messageId = "msg4", + lamportTimestamp = 4, + causalHistory = @[newHistoryEntry("another-missing", cast[seq[byte]]("remote-hint"))], + channelId = testChannel, + content = @[byte(4)], + bloomFilter = @[], ) let serialized4 = serializeMessage(msg4).get() let unwrapResult4 = rm.unwrapReceivedMessage(serialized4) @@ -359,13 +569,13 @@ suite "Periodic Tasks & Buffer Management": check outBuffer.len == 6 # Create message that acknowledges some messages - let ackMsg = SdsMessage( - messageId: "ack1", - lamportTimestamp: rm.channels[testChannel].lamportTimestamp + 1, - causalHistory: toCausalHistory(@["msg0", "msg2", "msg4"]), - channelId: testChannel, - content: @[byte(100)], - bloomFilter: @[], + let ackMsg = SdsMessage.init( + messageId = "ack1", + lamportTimestamp = rm.channels[testChannel].lamportTimestamp + 1, + causalHistory = toCausalHistory(@["msg0", "msg2", "msg4"]), + channelId = testChannel, + content = @[byte(100)], + bloomFilter = @[], ) let serializedAck = serializeMessage(ackMsg) @@ -488,13 +698,13 @@ suite "Special Cases Handling": history[^1] == "msg" & $(rm.config.maxMessageHistory + 5) test "invalid bloom filter handling": - let msgInvalid = SdsMessage( - messageId: "invalid-bf", - lamportTimestamp: 1, - causalHistory: toCausalHistory(@[]), - channelId: testChannel, - content: @[byte(1)], - bloomFilter: @[1.byte, 2.byte, 3.byte] # Invalid filter data + let msgInvalid = SdsMessage.init( + messageId = "invalid-bf", + lamportTimestamp = 1, + causalHistory = toCausalHistory(@[]), + channelId = testChannel, + content = @[byte(1)], + bloomFilter = @[1.byte, 2.byte, 3.byte] # Invalid filter data , ) @@ -519,13 +729,13 @@ suite "Special Cases Handling": ) # Create and process a message - let msg = SdsMessage( - messageId: "dup-msg", - lamportTimestamp: 1, - causalHistory: toCausalHistory(@[]), - channelId: testChannel, - content: @[byte(1)], - bloomFilter: @[], + let msg = SdsMessage.init( + messageId = "dup-msg", + lamportTimestamp = 1, + causalHistory = toCausalHistory(@[]), + channelId = testChannel, + content = @[byte(1)], + bloomFilter = @[], ) let serialized = serializeMessage(msg) @@ -662,6 +872,34 @@ suite "Multi-Channel ReliabilityManager Tests": msgId1 notin history2 msgId2 notin history1 + test "channel isolation (non-empty bloom)": + # With both channels carrying populated blooms, ids on one channel must + # not appear in the other's filter. An empty-bloom test cannot observe + # this — there is nothing to bleed across. + let channel1 = "iso-bloom-1" + let channel2 = "iso-bloom-2" + check rm.ensureChannel(channel1).isOk() + check rm.ensureChannel(channel2).isOk() + + rm.seedBloom(channel1, 25, prefix = "ch1-") + rm.seedBloom(channel2, 25, prefix = "ch2-") + + let wrap1 = rm.wrapOutgoingMessage(@[byte(1)], "iso-msg-1", channel1) + let wrap2 = rm.wrapOutgoingMessage(@[byte(2)], "iso-msg-2", channel2) + check wrap1.isOk() and wrap2.isOk() + + let bf1 = rm.channels[channel1].bloomFilter + let bf2 = rm.channels[channel2].bloomFilter + check: + bf1.contains("ch1-0") + bf1.contains("iso-msg-1") + not bf1.contains("ch2-0") + not bf1.contains("iso-msg-2") + bf2.contains("ch2-0") + bf2.contains("iso-msg-2") + not bf2.contains("ch1-0") + not bf2.contains("iso-msg-1") + test "multi-channel callbacks": var readyMessageCount = 0 var sentMessageCount = 0 @@ -692,22 +930,22 @@ suite "Multi-Channel ReliabilityManager Tests": # Create acknowledgment messages that include our message IDs in causal history # to trigger sent callbacks - let ackMsg1 = SdsMessage( - messageId: "ack1", - lamportTimestamp: rm.channels[channel1].lamportTimestamp + 1, - causalHistory: toCausalHistory(@[msgId1]), # Acknowledge msg1 - channelId: channel1, - content: @[byte(100)], - bloomFilter: @[], + let ackMsg1 = SdsMessage.init( + messageId = "ack1", + lamportTimestamp = rm.channels[channel1].lamportTimestamp + 1, + causalHistory = toCausalHistory(@[msgId1]), # Acknowledge msg1 + channelId = channel1, + content = @[byte(100)], + bloomFilter = @[], ) - let ackMsg2 = SdsMessage( - messageId: "ack2", - lamportTimestamp: rm.channels[channel2].lamportTimestamp + 1, - causalHistory: toCausalHistory(@[msgId2]), # Acknowledge msg2 - channelId: channel2, - content: @[byte(101)], - bloomFilter: @[], + let ackMsg2 = SdsMessage.init( + messageId = "ack2", + lamportTimestamp = rm.channels[channel2].lamportTimestamp + 1, + causalHistory = toCausalHistory(@[msgId2]), # Acknowledge msg2 + channelId = channel2, + content = @[byte(101)], + bloomFilter = @[], ) let serializedAck1 = serializeMessage(ackMsg1) @@ -741,3 +979,951 @@ suite "Multi-Channel ReliabilityManager Tests": # Dependencies in channel1 should not affect channel2 check rm.channels[channel1].bloomFilter.contains("dep1") check not rm.channels[channel2].bloomFilter.contains("dep1") + +# SDS-R Repair tests +suite "SDS-R: Computation Functions": + test "computeTReq returns duration in [tMin, tMax)": + let tMin = initDuration(seconds = 30) + let tMax = initDuration(seconds = 300) + let d = computeTReq("participant1", "msg1", tMin, tMax) + check: + d.inMilliseconds >= tMin.inMilliseconds + d.inMilliseconds < tMax.inMilliseconds + + test "computeTReq is deterministic for same inputs": + let tMin = initDuration(seconds = 30) + let tMax = initDuration(seconds = 300) + let d1 = computeTReq("p1", "m1", tMin, tMax) + let d2 = computeTReq("p1", "m1", tMin, tMax) + check d1 == d2 + + test "computeTReq varies with different participants": + let tMin = initDuration(seconds = 30) + let tMax = initDuration(seconds = 300) + let d1 = computeTReq("participant-A", "msg1", tMin, tMax) + let d2 = computeTReq("participant-B", "msg1", tMin, tMax) + # Different participants should generally get different backoff (not guaranteed but highly likely) + # Just check both are in valid range + check: + d1.inMilliseconds >= tMin.inMilliseconds + d2.inMilliseconds >= tMin.inMilliseconds + + test "computeTResp original sender has zero distance": + let d = computeTResp("sender1", "sender1", "msg1", initDuration(seconds = 300)) + check d.inMilliseconds == 0 + + test "computeTResp non-sender has positive backoff": + let d = computeTResp("other-node", "sender1", "msg1", initDuration(seconds = 300)) + check d.inMilliseconds >= 0 + + test "isInResponseGroup all in same group when numGroups =1": + check isInResponseGroup("p1", "sender1", "msg1", 1) == true + check isInResponseGroup("p2", "sender1", "msg1", 1) == true + + test "isInResponseGroup sender always in own group": + # Original sender must always be in their own response group + for groups in 1 .. 10: + check isInResponseGroup("sender1", "sender1", "msg1", groups) == true + +suite "SDS-R: Repair Buffer Management": + var rm: ReliabilityManager + + setup: + let rmResult = newReliabilityManager( + participantId = "test-participant" + ) + check rmResult.isOk() + rm = rmResult.get() + check rm.ensureChannel(testChannel).isOk() + + teardown: + if not rm.isNil: + rm.cleanup() + + test "missing deps added to outgoing repair buffer": + var missingDepsCount = 0 + + rm.setCallbacks( + proc(messageId: SdsMessageID, channelId: SdsChannelID) {.gcsafe.} = discard, + proc(messageId: SdsMessageID, channelId: SdsChannelID) {.gcsafe.} = discard, + proc(messageId: SdsMessageID, missingDeps: seq[HistoryEntry], channelId: SdsChannelID) {.gcsafe.} = + missingDepsCount += 1, + ) + + # Create a message with a missing dependency + let msg = SdsMessage.init( + messageId = "msg2", + lamportTimestamp = 2, + causalHistory = @[HistoryEntry(messageId: "msg1", senderId: "sender-A")], + channelId = testChannel, + content = @[byte(2)], + bloomFilter = @[], + ) + + let serialized = serializeMessage(msg).get() + let result = rm.unwrapReceivedMessage(serialized) + check result.isOk() + + # msg1 should be in the outgoing repair buffer + let channel = rm.channels[testChannel] + check: + missingDepsCount == 1 + "msg1" in channel.outgoingRepairBuffer + + test "receiving message clears it from repair buffers": + rm.setCallbacks( + proc(messageId: SdsMessageID, channelId: SdsChannelID) {.gcsafe.} = discard, + proc(messageId: SdsMessageID, channelId: SdsChannelID) {.gcsafe.} = discard, + proc(messageId: SdsMessageID, missingDeps: seq[HistoryEntry], channelId: SdsChannelID) {.gcsafe.} = discard, + ) + + # First, create the missing dep scenario + let msg2 = SdsMessage.init( + messageId = "msg2", + lamportTimestamp = 2, + causalHistory = @[HistoryEntry(messageId: "msg1", senderId: "sender-A")], + channelId = testChannel, + content = @[byte(2)], + bloomFilter = @[], + ) + discard rm.unwrapReceivedMessage(serializeMessage(msg2).get()) + check "msg1" in rm.channels[testChannel].outgoingRepairBuffer + + # Now receive msg1 — should clear from repair buffer + let msg1 = SdsMessage.init( + messageId = "msg1", + lamportTimestamp = 1, + causalHistory = @[], + channelId = testChannel, + content = @[byte(1)], + bloomFilter = @[], + ) + discard rm.unwrapReceivedMessage(serializeMessage(msg1).get()) + check "msg1" notin rm.channels[testChannel].outgoingRepairBuffer + + test "markDependenciesMet clears repair buffers": + rm.setCallbacks( + proc(messageId: SdsMessageID, channelId: SdsChannelID) {.gcsafe.} = discard, + proc(messageId: SdsMessageID, channelId: SdsChannelID) {.gcsafe.} = discard, + proc(messageId: SdsMessageID, missingDeps: seq[HistoryEntry], channelId: SdsChannelID) {.gcsafe.} = discard, + ) + + let msg2 = SdsMessage.init( + messageId = "msg2", + lamportTimestamp = 2, + causalHistory = @[HistoryEntry(messageId: "msg1", senderId: "sender-A")], + channelId = testChannel, + content = @[byte(2)], + bloomFilter = @[], + ) + discard rm.unwrapReceivedMessage(serializeMessage(msg2).get()) + check "msg1" in rm.channels[testChannel].outgoingRepairBuffer + + # Mark as met via store retrieval + check rm.markDependenciesMet(@["msg1"], testChannel).isOk() + check "msg1" notin rm.channels[testChannel].outgoingRepairBuffer + + test "expired repair requests attached to outgoing messages": + rm.setCallbacks( + proc(messageId: SdsMessageID, channelId: SdsChannelID) {.gcsafe.} = discard, + proc(messageId: SdsMessageID, channelId: SdsChannelID) {.gcsafe.} = discard, + proc(messageId: SdsMessageID, missingDeps: seq[HistoryEntry], channelId: SdsChannelID) {.gcsafe.} = discard, + ) + + # Manually add an expired repair entry + let channel = rm.channels[testChannel] + channel.outgoingRepairBuffer["missing-msg"] = OutgoingRepairEntry( + outHistEntry: HistoryEntry(messageId: "missing-msg", senderId: "orig-sender"), + minTimeRepairReq: getTime() - initDuration(seconds = 10), # Already expired + ) + + # Send a message — should pick up the expired repair request + let wrapped = rm.wrapOutgoingMessage(@[byte(1)], "new-msg", testChannel) + check wrapped.isOk() + + let unwrapped = deserializeMessage(wrapped.get()).get() + check: + unwrapped.repairRequest.len == 1 + unwrapped.repairRequest[0].messageId == "missing-msg" + # Should be removed from buffer after attaching + "missing-msg" notin channel.outgoingRepairBuffer + + test "expired repair requests attach the most-overdue first when capped": + # Per spec (sds-r-send-message, RECOMMENDED): when more entries are + # eligible than maxRepairRequests, attach the ones with the smallest + # minTimeRepairReq — i.e. the most overdue. + rm.setCallbacks( + proc(messageId: SdsMessageID, channelId: SdsChannelID) {.gcsafe.} = discard, + proc(messageId: SdsMessageID, channelId: SdsChannelID) {.gcsafe.} = discard, + proc(messageId: SdsMessageID, missingDeps: seq[HistoryEntry], channelId: SdsChannelID) {.gcsafe.} = discard, + ) + let channel = rm.channels[testChannel] + let now = getTime() + + # Five eligible entries with strictly ordered minTimeRepairReq (most-overdue first). + # All are expired; the cap is the default 3, so two should be left behind. + let expected = ["oldest", "second", "third", "fourth", "newest"] + for i, id in expected: + channel.outgoingRepairBuffer[id] = OutgoingRepairEntry( + outHistEntry: HistoryEntry(messageId: id, senderId: "sender"), + minTimeRepairReq: now - initDuration(seconds = 50 - i * 10), + ) + + let wrapped = rm.wrapOutgoingMessage(@[byte(1)], "outbound", testChannel) + check wrapped.isOk() + + let attached = deserializeMessage(wrapped.get()).get().repairRequest + check: + attached.len == rm.config.maxRepairRequests + attached[0].messageId == "oldest" + attached[1].messageId == "second" + attached[2].messageId == "third" + # Two least-overdue remain in the buffer for next time. + "fourth" in channel.outgoingRepairBuffer + "newest" in channel.outgoingRepairBuffer + "oldest" notin channel.outgoingRepairBuffer + "second" notin channel.outgoingRepairBuffer + "third" notin channel.outgoingRepairBuffer + + test "incoming repair request adds to incoming repair buffer when eligible": + rm.setCallbacks( + proc(messageId: SdsMessageID, channelId: SdsChannelID) {.gcsafe.} = discard, + proc(messageId: SdsMessageID, channelId: SdsChannelID) {.gcsafe.} = discard, + proc(messageId: SdsMessageID, missingDeps: seq[HistoryEntry], channelId: SdsChannelID) {.gcsafe.} = discard, + ) + + let channel = rm.channels[testChannel] + + # First, seed delivered history so we can respond to a repair request for it + let cachedMsg = SdsMessage.init( + messageId = "cached-msg", + lamportTimestamp = 1, + causalHistory = @[], + channelId = testChannel, + content = @[byte(99)], + bloomFilter = @[], + ) + channel.messageHistory["cached-msg"] = cachedMsg + + # Receive a message with a repair request for "cached-msg" + let msgWithRepair = SdsMessage.init( + messageId = "requester-msg", + lamportTimestamp = 5, + causalHistory = @[], + channelId = testChannel, + content = @[byte(3)], + bloomFilter = @[], + repairRequest = @[HistoryEntry( + messageId: "cached-msg", + senderId: "test-participant", # Same as our participantId so we're in response group + )], + ) + discard rm.unwrapReceivedMessage(serializeMessage(msgWithRepair).get()) + + # We should have added it to the incoming repair buffer (we have the message and are in response group) + check "cached-msg" in channel.incomingRepairBuffer + +suite "SDS-R: Protobuf Roundtrip": + test "senderId in HistoryEntry roundtrips through protobuf": + let msg = SdsMessage.init( + messageId = "msg1", + lamportTimestamp = 100, + causalHistory = @[ + HistoryEntry(messageId: "dep1", retrievalHint: @[byte(1), 2], senderId: "sender-A"), + HistoryEntry(messageId: "dep2", senderId: "sender-B"), + ], + channelId = "ch1", + content = @[byte(42)], + bloomFilter = @[], + ) + + let serialized = serializeMessage(msg).get() + let decoded = deserializeMessage(serialized).get() + + check: + decoded.causalHistory.len == 2 + decoded.causalHistory[0].messageId == "dep1" + decoded.causalHistory[0].senderId == "sender-A" + decoded.causalHistory[0].retrievalHint == @[byte(1), 2] + decoded.causalHistory[1].messageId == "dep2" + decoded.causalHistory[1].senderId == "sender-B" + + test "repairRequest field roundtrips through protobuf": + let msg = SdsMessage.init( + messageId = "msg1", + lamportTimestamp = 100, + causalHistory = @[], + channelId = "ch1", + content = @[byte(42)], + bloomFilter = @[], + repairRequest = @[ + HistoryEntry(messageId: "missing1", senderId: "sender-X"), + HistoryEntry(messageId: "missing2", senderId: "sender-Y", retrievalHint: @[byte(5)]), + ], + ) + + let serialized = serializeMessage(msg).get() + let decoded = deserializeMessage(serialized).get() + + check: + decoded.repairRequest.len == 2 + decoded.repairRequest[0].messageId == "missing1" + decoded.repairRequest[0].senderId == "sender-X" + decoded.repairRequest[1].messageId == "missing2" + decoded.repairRequest[1].senderId == "sender-Y" + decoded.repairRequest[1].retrievalHint == @[byte(5)] + + test "backward compat: message without repairRequest decodes fine": + let msg = SdsMessage.init( + messageId = "msg1", + lamportTimestamp = 100, + causalHistory = @[HistoryEntry(messageId: "dep1")], + channelId = "ch1", + content = @[byte(42)], + bloomFilter = @[], + ) + + let serialized = serializeMessage(msg).get() + let decoded = deserializeMessage(serialized).get() + + check: + decoded.repairRequest.len == 0 + decoded.causalHistory[0].senderId == "" + + test "SdsMessage.senderId roundtrips through protobuf": + let msg = SdsMessage.init( + messageId = "m1", + lamportTimestamp = 1, + causalHistory = @[], + channelId = "ch1", + content = @[byte(1)], + bloomFilter = @[], + senderId = "alice", + ) + let decoded = deserializeMessage(serializeMessage(msg).get()).get() + check decoded.senderId == "alice" + +# --------------------------------------------------------------------------- +# SDS-R Phase 2 tests: edge cases, lifecycle, sweep, and multi-participant flows +# --------------------------------------------------------------------------- + +suite "SDS-R: Edge Cases and Defensive Branches": + test "computeTReq returns tMin when range is degenerate": + let tMin = initDuration(seconds = 30) + # tMax == tMin + let d1 = computeTReq("p", "m", tMin, tMin) + check d1 == tMin + # tMax < tMin (rangeMs < 0) + let d2 = computeTReq("p", "m", tMin, initDuration(seconds = 10)) + check d2 == tMin + + test "computeTResp returns 0 when tMax is 0": + let d = computeTResp("p", "other", "m", initDuration(milliseconds = 0)) + check d.inMilliseconds == 0 + + test "computeTResp always stays within [0, tMax)": + # Adversarial sweep — result must never wrap negative nor exceed tMax + let tMax = initDuration(seconds = 300) + for i in 0 ..< 500: + let d = computeTResp( + "participant-" & $i, "sender-" & $(i * 13), "msg-" & $(i * 31), tMax + ) + check: + d.inMilliseconds >= 0 + d.inMilliseconds < tMax.inMilliseconds + + test "isInResponseGroup returns true for non-positive numGroups": + check isInResponseGroup("p", "sender", "m", 0) == true + check isInResponseGroup("p", "sender", "m", -1) == true + + test "computeTReq bounds across many random inputs": + let tMin = initDuration(seconds = 30) + let tMax = initDuration(seconds = 300) + for i in 0 ..< 200: + let d = computeTReq("p-" & $i, "m-" & $i, tMin, tMax) + check: + d.inMilliseconds >= tMin.inMilliseconds + d.inMilliseconds < tMax.inMilliseconds + + test "response group distribution is roughly uniform": + # With numGroups =10, ~10% of random participants should share sender's group. + const numGroups = 10 + const totalParticipants = 1000 + let senderId = "alice" + let msgId = "msg-xyz" + var sameGroup = 0 + for i in 0 ..< totalParticipants: + if isInResponseGroup("participant-" & $i, senderId, msgId, numGroups): + sameGroup += 1 + # Expected ~100 (1/N), allow [50, 200] band for hash quirks + check: + sameGroup >= 50 + sameGroup <= 200 + + test "computeTResp monotonicity: self always fastest": + # The original sender (distance =0) must always be first to respond. + let tMax = initDuration(seconds = 300) + let selfD = computeTResp("alice", "alice", "msg-xyz", tMax) + check selfD.inMilliseconds == 0 + for i in 0 ..< 50: + let other = computeTResp("other-" & $i, "alice", "msg-xyz", tMax) + check other.inMilliseconds >= selfD.inMilliseconds + +suite "SDS-R: Lifecycle and State": + test "empty participantId disables outgoing repair creation": + let rm = newReliabilityManager().get() # empty participantId + defer: rm.cleanup() + check rm.ensureChannel(testChannel).isOk() + + rm.setCallbacks( + proc(msgId: SdsMessageID, ch: SdsChannelID) {.gcsafe.} = discard, + proc(msgId: SdsMessageID, ch: SdsChannelID) {.gcsafe.} = discard, + proc(msgId: SdsMessageID, deps: seq[HistoryEntry], ch: SdsChannelID) {.gcsafe.} = discard, + ) + + let msg = SdsMessage.init( + messageId = "m2", + lamportTimestamp = 2, + causalHistory = @[HistoryEntry(messageId: "m1-missing", senderId: "alice")], + channelId = testChannel, + content = @[byte(2)], + bloomFilter = @[], + ) + discard rm.unwrapReceivedMessage(serializeMessage(msg).get()) + check rm.channels[testChannel].outgoingRepairBuffer.len == 0 + + test "empty senderId in incoming repair request is ignored": + let rm = newReliabilityManager(participantId = "bob").get() + defer: rm.cleanup() + check rm.ensureChannel(testChannel).isOk() + let channel = rm.channels[testChannel] + channel.messageHistory["m-wanted"] = SdsMessage.init( + messageId = "m-wanted", + lamportTimestamp = 1, + causalHistory = @[], + channelId = testChannel, + content = @[byte(99), 99, 99], + bloomFilter = @[], + ) + + rm.setCallbacks( + proc(msgId: SdsMessageID, ch: SdsChannelID) {.gcsafe.} = discard, + proc(msgId: SdsMessageID, ch: SdsChannelID) {.gcsafe.} = discard, + proc(msgId: SdsMessageID, deps: seq[HistoryEntry], ch: SdsChannelID) {.gcsafe.} = discard, + ) + + let msg = SdsMessage.init( + messageId = "req-msg", + lamportTimestamp = 5, + causalHistory = @[], + channelId = testChannel, + content = @[byte(1)], + bloomFilter = @[], + repairRequest = @[HistoryEntry(messageId: "m-wanted", senderId: "")], + ) + discard rm.unwrapReceivedMessage(serializeMessage(msg).get()) + check "m-wanted" notin channel.incomingRepairBuffer + + test "wrapOutgoingMessage records the message in history with our senderId": + # Proves Bug 1 is fixed — the original sender can serve her own message. + # In the consolidated history model, the SdsMessage itself carries senderId + # and can be re-serialized on demand for repair, so a single membership + # check + senderId read covers both halves of the original assertion. + let rm = newReliabilityManager(participantId = "alice").get() + defer: rm.cleanup() + check rm.ensureChannel(testChannel).isOk() + + discard rm.wrapOutgoingMessage(@[byte(1), 2, 3], "m1", testChannel) + let channel = rm.channels[testChannel] + check: + "m1" in channel.messageHistory + channel.messageHistory["m1"].senderId == "alice" + channel.messageHistory["m1"].content == @[byte(1), 2, 3] + + test "getRecentHistoryEntries carries senderId for own messages": + let rm = newReliabilityManager(participantId = "alice").get() + defer: rm.cleanup() + check rm.ensureChannel(testChannel).isOk() + + discard rm.wrapOutgoingMessage(@[byte(1)], "m1", testChannel) + discard rm.wrapOutgoingMessage(@[byte(2)], "m2", testChannel) + let entries = rm.getRecentHistoryEntries(10, testChannel) + check: + entries.len == 2 + entries[0].senderId == "alice" + entries[1].senderId == "alice" + + test "resetReliabilityManager clears all SDS-R state": + let rm = newReliabilityManager(participantId = "alice").get() + defer: rm.cleanup() + check rm.ensureChannel(testChannel).isOk() + let channel = rm.channels[testChannel] + + channel.outgoingRepairBuffer["a"] = OutgoingRepairEntry( + outHistEntry: HistoryEntry(messageId: "a", senderId: "x"), + minTimeRepairReq: getTime(), + ) + channel.incomingRepairBuffer["b"] = IncomingRepairEntry( + inHistEntry: HistoryEntry(messageId: "b", senderId: "y"), + cachedMessage: @[byte(1)], + minTimeRepairResp: getTime(), + ) + channel.messageHistory["c"] = SdsMessage.init( + messageId = "c", + lamportTimestamp = 1, + causalHistory = @[], + channelId = testChannel, + content = @[byte(2)], + bloomFilter = @[], + senderId = "someone", + ) + + check rm.resetReliabilityManager().isOk() + check rm.ensureChannel(testChannel).isOk() + let ch2 = rm.channels[testChannel] + check: + ch2.outgoingRepairBuffer.len == 0 + ch2.incomingRepairBuffer.len == 0 + ch2.messageHistory.len == 0 + + test "SDS-R state is isolated per channel": + let rm = newReliabilityManager(participantId = "alice").get() + defer: rm.cleanup() + check rm.ensureChannel("ch-A").isOk() + check rm.ensureChannel("ch-B").isOk() + + rm.setCallbacks( + proc(msgId: SdsMessageID, ch: SdsChannelID) {.gcsafe.} = discard, + proc(msgId: SdsMessageID, ch: SdsChannelID) {.gcsafe.} = discard, + proc(msgId: SdsMessageID, deps: seq[HistoryEntry], ch: SdsChannelID) {.gcsafe.} = discard, + ) + + let msg = SdsMessage.init( + messageId = "m2", + lamportTimestamp = 2, + causalHistory = @[HistoryEntry(messageId: "m1-missing", senderId: "bob")], + channelId = "ch-A", + content = @[byte(2)], + bloomFilter = @[], + ) + discard rm.unwrapReceivedMessage(serializeMessage(msg).get()) + check: + rm.channels["ch-A"].outgoingRepairBuffer.len == 1 + rm.channels["ch-B"].outgoingRepairBuffer.len == 0 + + test "duplicate message arrival cancels pending incoming repair entry": + # Covers the dedup-before-cleanup fix: a rebroadcast arriving at a peer who + # already has the message must clear that peer's incomingRepairBuffer entry. + let rm = newReliabilityManager(participantId = "carol").get() + defer: rm.cleanup() + check rm.ensureChannel(testChannel).isOk() + let channel = rm.channels[testChannel] + + rm.setCallbacks( + proc(msgId: SdsMessageID, ch: SdsChannelID) {.gcsafe.} = discard, + proc(msgId: SdsMessageID, ch: SdsChannelID) {.gcsafe.} = discard, + proc(msgId: SdsMessageID, deps: seq[HistoryEntry], ch: SdsChannelID) {.gcsafe.} = discard, + ) + + # Carol already has M1 in history and has a pending incomingRepairBuffer entry + channel.messageHistory["m1"] = SdsMessage.init( + messageId = "m1", + lamportTimestamp = 1, + causalHistory = @[], + channelId = testChannel, + content = @[byte(1)], + bloomFilter = @[], + ) + channel.incomingRepairBuffer["m1"] = IncomingRepairEntry( + inHistEntry: HistoryEntry(messageId: "m1", senderId: "alice"), + cachedMessage: @[byte(1)], + minTimeRepairResp: getTime() + initDuration(seconds = 10), + ) + + # A rebroadcast of M1 arrives + let msg = SdsMessage.init( + messageId = "m1", + lamportTimestamp = 1, + causalHistory = @[], + channelId = testChannel, + content = @[byte(1)], + bloomFilter = @[], + senderId = "alice", + ) + discard rm.unwrapReceivedMessage(serializeMessage(msg).get()) + check "m1" notin channel.incomingRepairBuffer + +suite "SDS-R: Repair Sweep": + var rm: ReliabilityManager + + setup: + rm = newReliabilityManager(participantId = "bob").get() + check rm.ensureChannel(testChannel).isOk() + + teardown: + if not rm.isNil: + rm.cleanup() + + test "runRepairSweep fires onRepairReady for expired tResp": + var fireCount = 0 + var firstBytes: seq[byte] = @[] + rm.setCallbacks( + proc(msgId: SdsMessageID, ch: SdsChannelID) {.gcsafe.} = discard, + proc(msgId: SdsMessageID, ch: SdsChannelID) {.gcsafe.} = discard, + proc(msgId: SdsMessageID, deps: seq[HistoryEntry], ch: SdsChannelID) {.gcsafe.} = discard, + onRepairReady = proc(bytes: seq[byte], ch: SdsChannelID) {.gcsafe.} = + {.cast(gcsafe).}: + fireCount += 1 + if fireCount == 1: + firstBytes = bytes, + ) + + let channel = rm.channels[testChannel] + channel.incomingRepairBuffer["m-ready"] = IncomingRepairEntry( + inHistEntry: HistoryEntry(messageId: "m-ready", senderId: "alice"), + cachedMessage: @[byte(1), 2, 3], + minTimeRepairResp: getTime() - initDuration(seconds = 1), # expired + ) + channel.incomingRepairBuffer["m-not-ready"] = IncomingRepairEntry( + inHistEntry: HistoryEntry(messageId: "m-not-ready", senderId: "alice"), + cachedMessage: @[byte(9), 9, 9], + minTimeRepairResp: getTime() + initDuration(minutes = 10), # far future + ) + + rm.runRepairSweep() + + check: + fireCount == 1 + firstBytes == @[byte(1), 2, 3] + "m-ready" notin channel.incomingRepairBuffer + "m-not-ready" in channel.incomingRepairBuffer + + test "runRepairSweep drops outgoing entries past T_max window": + rm.setCallbacks( + proc(msgId: SdsMessageID, ch: SdsChannelID) {.gcsafe.} = discard, + proc(msgId: SdsMessageID, ch: SdsChannelID) {.gcsafe.} = discard, + proc(msgId: SdsMessageID, deps: seq[HistoryEntry], ch: SdsChannelID) {.gcsafe.} = discard, + ) + + let channel = rm.channels[testChannel] + let tMax = rm.config.repairTMax + channel.outgoingRepairBuffer["m-stale"] = OutgoingRepairEntry( + outHistEntry: HistoryEntry(messageId: "m-stale", senderId: "alice"), + minTimeRepairReq: getTime() - (tMax + tMax), # now - 2*T_max, past drop window + ) + channel.outgoingRepairBuffer["m-fresh"] = OutgoingRepairEntry( + outHistEntry: HistoryEntry(messageId: "m-fresh", senderId: "alice"), + minTimeRepairReq: getTime(), + ) + + rm.runRepairSweep() + + check: + "m-stale" notin channel.outgoingRepairBuffer + "m-fresh" in channel.outgoingRepairBuffer + + test "runRepairSweep no-op when buffers are empty": + var fireCount = 0 + rm.setCallbacks( + proc(msgId: SdsMessageID, ch: SdsChannelID) {.gcsafe.} = discard, + proc(msgId: SdsMessageID, ch: SdsChannelID) {.gcsafe.} = discard, + proc(msgId: SdsMessageID, deps: seq[HistoryEntry], ch: SdsChannelID) {.gcsafe.} = discard, + onRepairReady = proc(bytes: seq[byte], ch: SdsChannelID) {.gcsafe.} = + fireCount += 1, + ) + rm.runRepairSweep() + check fireCount == 0 + +# --- Multi-participant in-process bus for integration tests --------------- + +type + TestBus = ref object + peers: OrderedTable[SdsParticipantID, ReliabilityManager] + delivered: Table[SdsParticipantID, seq[SdsMessageID]] + # Log of raw message-ids placed on the wire, tagged with the source peer. + wireLog: seq[tuple[senderId: SdsParticipantID, messageId: SdsMessageID]] + +proc newTestBus(): TestBus = + TestBus( + peers: initOrderedTable[SdsParticipantID, ReliabilityManager](), + delivered: initTable[SdsParticipantID, seq[SdsMessageID]](), + wireLog: @[], + ) + +proc recordWire(bus: TestBus, senderId: SdsParticipantID, bytes: seq[byte]) {.gcsafe.} = + let decoded = deserializeMessage(bytes) + if decoded.isOk(): + bus.wireLog.add((senderId, decoded.get().messageId)) + +proc deliverExcept( + bus: TestBus, + senderId: SdsParticipantID, + bytes: seq[byte], + exclude: seq[SdsParticipantID], +) {.gcsafe.} = + for pid, peer in bus.peers: + if pid == senderId or pid in exclude: + continue + discard peer.unwrapReceivedMessage(bytes) + +proc addPeer( + bus: TestBus, + participantId: SdsParticipantID, + config: ReliabilityConfig = defaultConfig(), +): ReliabilityManager = + let rm = newReliabilityManager(config, participantId).get() + doAssert rm.ensureChannel(testChannel).isOk() + bus.peers[participantId] = rm + bus.delivered[participantId] = @[] + + let pid = participantId + let busRef = bus + rm.setCallbacks( + proc(msgId: SdsMessageID, ch: SdsChannelID) {.gcsafe.} = + {.cast(gcsafe).}: + busRef.delivered[pid].add(msgId), + proc(msgId: SdsMessageID, ch: SdsChannelID) {.gcsafe.} = discard, + proc(msgId: SdsMessageID, deps: seq[HistoryEntry], ch: SdsChannelID) {.gcsafe.} = discard, + onRepairReady = proc(bytes: seq[byte], ch: SdsChannelID) {.gcsafe.} = + {.cast(gcsafe).}: + busRef.recordWire(pid, bytes) + busRef.deliverExcept(pid, bytes, @[]), + ) + rm + +proc broadcast( + bus: TestBus, + senderId: SdsParticipantID, + content: seq[byte], + messageId: SdsMessageID, + dropAt: seq[SdsParticipantID] = @[], +) = + let rm = bus.peers[senderId] + let wrapped = rm.wrapOutgoingMessage(content, messageId, testChannel) + doAssert wrapped.isOk() + bus.recordWire(senderId, wrapped.get()) + bus.deliverExcept(senderId, wrapped.get(), dropAt) + +proc forceOutgoingExpired( + rm: ReliabilityManager, messageId: SdsMessageID +) = + ## Push a specific outgoingRepairBuffer entry's minTimeRepairReq into the past so the + ## next wrapOutgoingMessage will pick it up. + let channel = rm.channels[testChannel] + if messageId in channel.outgoingRepairBuffer: + channel.outgoingRepairBuffer[messageId].minTimeRepairReq = + getTime() - initDuration(seconds = 1) + +proc forceIncomingExpired( + rm: ReliabilityManager, messageId: SdsMessageID +) = + ## Push an incomingRepairBuffer entry's minTimeRepairResp into the past so runRepairSweep fires it. + let channel = rm.channels[testChannel] + if messageId in channel.incomingRepairBuffer: + channel.incomingRepairBuffer[messageId].minTimeRepairResp = + getTime() - initDuration(seconds = 1) + +suite "SDS-R: Multi-Participant Integration": + + test "basic single-gap repair (Alice -> Bob misses -> Carol's message triggers repair)": + let bus = newTestBus() + let alice = bus.addPeer("alice") + let bob = bus.addPeer("bob") + let carol = bus.addPeer("carol") + + # Alice sends M1, but Bob is offline for this one. + bus.broadcast("alice", @[byte(1)], "m1", dropAt = @["bob".SdsParticipantID]) + # Carol now has M1; Bob does not. + check "m1" in carol.channels[testChannel].messageHistory + check "m1" notin bob.channels[testChannel].messageHistory + + # Carol sends M2 with causal history referencing M1. + bus.broadcast("carol", @[byte(2)], "m2") + # Bob detects M1 missing and populates his outgoingRepairBuffer. + check "m1" in bob.channels[testChannel].outgoingRepairBuffer + # Bob should have buffered M2. + check "m2" in bob.channels[testChannel].incomingBuffer + check "m2" notin bus.delivered["bob"] + + # Force Bob's T_req so the next wrap attaches the repair request. + bob.forceOutgoingExpired("m1") + + # Bob sends M3 — it must carry repair_request =[M1, sender =alice]. + bus.broadcast("bob", @[byte(3)], "m3") + + # Alice received M3, saw the repair_request, cached-bypass and response-group + # checks pass, so she has an incomingRepairBuffer entry for M1 with tResp =0. + check "m1" in alice.channels[testChannel].incomingRepairBuffer + + # Force alice's tResp to past just to be safe (it's already 0 for self), + # then run her sweep. She rebroadcasts M1. + alice.forceIncomingExpired("m1") + alice.runRepairSweep() + + # Bob now has M1 and M2 delivered. + check: + "m1" in bus.delivered["bob"] + "m2" in bus.delivered["bob"] + + test "response cancellation: only one rebroadcast on the wire": + let bus = newTestBus() + let alice = bus.addPeer("alice") + let bob = bus.addPeer("bob") + let carol = bus.addPeer("carol") + + # Alice sends M1, Bob offline. + bus.broadcast("alice", @[byte(1)], "m1", dropAt = @["bob".SdsParticipantID]) + # Carol sends M2; Bob sees M1 missing. + bus.broadcast("carol", @[byte(2)], "m2") + check "m1" in bob.channels[testChannel].outgoingRepairBuffer + + # Bob requests repair. + bob.forceOutgoingExpired("m1") + bus.broadcast("bob", @[byte(3)], "m3") + + # Both Alice and Carol now have an incomingRepairBuffer entry for M1. + check: + "m1" in alice.channels[testChannel].incomingRepairBuffer + "m1" in carol.channels[testChannel].incomingRepairBuffer + + # Alice fires first (T_resp =0 for self). Her rebroadcast should cancel Carol's + # pending entry when Carol receives the rebroadcast. + alice.forceIncomingExpired("m1") + alice.runRepairSweep() + + # Carol's pending response must have been cleared by the dedup-path cleanup. + check "m1" notin carol.channels[testChannel].incomingRepairBuffer + + # Even if we now force-run Carol's sweep, nothing should fire. + let wireCountBefore = bus.wireLog.len + carol.runRepairSweep() + check bus.wireLog.len == wireCountBefore + + # Bob received exactly one rebroadcast of M1. + var m1RebroadcastCount = 0 + for entry in bus.wireLog: + if entry.messageId == "m1" and entry.senderId != "alice": + discard # only the original Alice->all broadcast had senderId ="alice" + if entry.messageId == "m1": + m1RebroadcastCount += 1 + # Two "m1" entries total on wire: (1) Alice's original broadcast, (2) Alice's rebroadcast. + check m1RebroadcastCount == 2 + + test "cancellation on incoming repair request: peer drops its own pending request": + let bus = newTestBus() + let alice = bus.addPeer("alice") + let bob = bus.addPeer("bob") + let carol = bus.addPeer("carol") + + # Alice sends M1 — drop at both Bob and Carol, so both miss it. + bus.broadcast( + "alice", @[byte(1)], "m1", + dropAt = @["bob".SdsParticipantID, "carol".SdsParticipantID], + ) + # Alice sends M2 referencing M1 — both Bob and Carol see M1 missing. + bus.broadcast("alice", @[byte(2)], "m2") + check: + "m1" in bob.channels[testChannel].outgoingRepairBuffer + "m1" in carol.channels[testChannel].outgoingRepairBuffer + + # Bob's T_req fires first. He sends a repair request for M1. + bob.forceOutgoingExpired("m1") + bus.broadcast("bob", @[byte(3)], "m3") + + # Carol, on receiving Bob's repair request, must have dropped her own + # pending outgoingRepairBuffer entry for M1 (cancellation). + check "m1" notin carol.channels[testChannel].outgoingRepairBuffer + + test "response group filtering: only group members respond": + # With numGroups =10, roughly 1/10 of receivers will be in the group. + # Construct a sender+message where a specific non-sender is NOT in the group. + var cfg = defaultConfig() + cfg.numResponseGroups = 10 + + # Pick a msgId where carol is not in the group and bob is + # We probe deterministically because computeTReq/isInResponseGroup are pure. + var chosenMsg = "" + for i in 0 ..< 1000: + let candidate = "probe-" & $i + let bobIn = isInResponseGroup("bob", "alice", candidate, 10) + let carolIn = isInResponseGroup("carol", "alice", candidate, 10) + if bobIn and not carolIn: + chosenMsg = candidate + break + check chosenMsg.len > 0 + + let bus = newTestBus() + discard bus.addPeer("alice", cfg) + let bob = bus.addPeer("bob", cfg) + let carol = bus.addPeer("carol", cfg) + + # Both Bob and Carol receive the original M1 (so both have it in messageHistory). + bus.broadcast("alice", @[byte(1)], chosenMsg) + + # Now Dave arrives: build a fake requester message manually so its repair_request + # names Alice as senderId for chosenMsg. + # We inject directly by calling unwrapReceivedMessage on bob/carol. + let dave = bus.addPeer("dave", cfg) + # Dave has no messages, but we can hand-craft a repair request he would send. + let reqMsg = SdsMessage.init( + messageId = "req-from-dave", + lamportTimestamp = 10, + causalHistory = @[], + channelId = testChannel, + content = @[byte(9)], + bloomFilter = @[], + senderId = "dave", + repairRequest = @[HistoryEntry(messageId: chosenMsg, senderId: "alice")], + ) + let bytes = serializeMessage(reqMsg).get() + discard bob.unwrapReceivedMessage(bytes) + discard carol.unwrapReceivedMessage(bytes) + + check: + chosenMsg in bob.channels[testChannel].incomingRepairBuffer + chosenMsg notin carol.channels[testChannel].incomingRepairBuffer + + test "multi-gap batch repair: many missing deps split across requests": + let bus = newTestBus() + discard bus.addPeer("alice") + let bob = bus.addPeer("bob") + + # Alice sends 5 messages while Bob is offline. + let drops = @["bob".SdsParticipantID] + bus.broadcast("alice", @[byte(1)], "m1", dropAt = drops) + bus.broadcast("alice", @[byte(2)], "m2", dropAt = drops) + bus.broadcast("alice", @[byte(3)], "m3", dropAt = drops) + bus.broadcast("alice", @[byte(4)], "m4", dropAt = drops) + bus.broadcast("alice", @[byte(5)], "m5", dropAt = drops) + + # Bob comes online and receives M6 which depends on m1..m5. + bus.broadcast("alice", @[byte(6)], "m6") + + # Bob should have 5 outgoing repair entries. + let channel = bob.channels[testChannel] + check channel.outgoingRepairBuffer.len == 5 + + # Force all to expired and wrap one message — only maxRepairRequests + # (default 3) should attach to a single outgoing message. + for id in ["m1", "m2", "m3", "m4", "m5"]: + bob.forceOutgoingExpired(id) + + let wrapped = bob.wrapOutgoingMessage(@[byte(99)], "bob-msg-1", testChannel).get() + let decoded = deserializeMessage(wrapped).get() + check decoded.repairRequest.len <= bob.config.maxRepairRequests + + # The attached entries should be removed from the outgoing buffer. + check channel.outgoingRepairBuffer.len == 5 - decoded.repairRequest.len + + test "markDependenciesMet externally clears pending repair entry": + let bus = newTestBus() + discard bus.addPeer("alice") + let bob = bus.addPeer("bob") + + bus.broadcast("alice", @[byte(1)], "m1", dropAt = @["bob".SdsParticipantID]) + bus.broadcast("alice", @[byte(2)], "m2") + check "m1" in bob.channels[testChannel].outgoingRepairBuffer + + # Simulate Bob fetching M1 via an out-of-band store query. + check bob.markDependenciesMet(@["m1"], testChannel).isOk() + check "m1" notin bob.channels[testChannel].outgoingRepairBuffer