diff --git a/library/config_json.nim b/library/config_json.nim new file mode 100644 index 0000000..95c529e --- /dev/null +++ b/library/config_json.nim @@ -0,0 +1,82 @@ +## JSON parser for ReliabilityConfig — used by the FFI constructor +## SdsNewReliabilityManagerWithConfig. +## +## Schema: a JSON object where every field is optional. Missing fields fall +## back to the Default* constants in sds/types/reliability_config.nim. +## Duration fields use the suffix "Ms" and are integer milliseconds. +## +## Empty input ("" or NULL on the C side) returns the default config. + +import std/[json, times] +import results +import sds/types/reliability_config + +proc getJsonInt(node: JsonNode, key: string, default: int): int = + if node.hasKey(key) and node[key].kind == JInt: + return node[key].getInt() + return default + +proc getJsonFloat(node: JsonNode, key: string, default: float): float = + if not node.hasKey(key): + return default + case node[key].kind + of JFloat: node[key].getFloat() + of JInt: node[key].getInt().float + else: default + +proc getJsonDurationMs( + node: JsonNode, key: string, default: Duration +): Duration = + if node.hasKey(key) and node[key].kind == JInt: + return initDuration(milliseconds = node[key].getInt()) + return default + +proc parseReliabilityConfig*( + jsonStr: string +): Result[ReliabilityConfig, string] = + ## Parses a JSON string into a ReliabilityConfig. Empty input returns the + ## default config. Unknown keys are ignored. Type-mismatched values fall + ## back to defaults rather than failing. + if jsonStr.len == 0: + return ok(ReliabilityConfig.init()) + + var node: JsonNode + try: + node = parseJson(jsonStr) + except JsonParsingError, ValueError, Exception: + return err("invalid JSON: " & getCurrentExceptionMsg()) + + if node.isNil or node.kind != JObject: + return err("config must be a JSON object") + + ok( + ReliabilityConfig.init( + bloomFilterCapacity = + getJsonInt(node, "bloomFilterCapacity", DefaultBloomFilterCapacity), + bloomFilterErrorRate = + getJsonFloat(node, "bloomFilterErrorRate", DefaultBloomFilterErrorRate), + maxMessageHistory = + getJsonInt(node, "maxMessageHistory", DefaultMaxMessageHistory), + maxCausalHistory = + getJsonInt(node, "maxCausalHistory", DefaultMaxCausalHistory), + resendInterval = + getJsonDurationMs(node, "resendIntervalMs", DefaultResendInterval), + maxResendAttempts = + getJsonInt(node, "maxResendAttempts", DefaultMaxResendAttempts), + syncMessageInterval = getJsonDurationMs( + node, "syncMessageIntervalMs", DefaultSyncMessageInterval + ), + bufferSweepInterval = getJsonDurationMs( + node, "bufferSweepIntervalMs", DefaultBufferSweepInterval + ), + repairTMin = getJsonDurationMs(node, "repairTMinMs", DefaultRepairTMin), + repairTMax = getJsonDurationMs(node, "repairTMaxMs", DefaultRepairTMax), + numResponseGroups = + getJsonInt(node, "numResponseGroups", DefaultNumResponseGroups), + maxRepairRequests = + getJsonInt(node, "maxRepairRequests", DefaultMaxRepairRequests), + repairSweepInterval = getJsonDurationMs( + node, "repairSweepIntervalMs", DefaultRepairSweepInterval + ), + ) + ) diff --git a/library/events/json_repair_ready_event.nim b/library/events/json_repair_ready_event.nim new file mode 100644 index 0000000..d34c954 --- /dev/null +++ b/library/events/json_repair_ready_event.nim @@ -0,0 +1,20 @@ +import std/[json, base64] +import ./json_base_event, sds/[message] + +type JsonRepairReadyEvent* = ref object of JsonEvent + channelId*: SdsChannelID + message*: seq[byte] + +proc new*( + T: type JsonRepairReadyEvent, message: seq[byte], channelId: SdsChannelID +): T = + return JsonRepairReadyEvent( + eventType: "repair_ready", message: message, channelId: channelId + ) + +method `$`*(jsonRepairReady: JsonRepairReadyEvent): string = + var node = newJObject() + node["eventType"] = %*jsonRepairReady.eventType + node["channelId"] = %*jsonRepairReady.channelId + node["message"] = %*encode(jsonRepairReady.message) + $node diff --git a/library/libsds.h b/library/libsds.h index 0d9840e..553d469 100644 --- a/library/libsds.h +++ b/library/libsds.h @@ -28,6 +28,37 @@ typedef void (*SdsRetrievalHintProvider) (const char* messageId, char** hint, si void* SdsNewReliabilityManager(SdsCallBack callback, void* userData); +// Construct a Reliability Manager with an explicit participant ID and a +// JSON-encoded ReliabilityConfig. +// +// participantId: stable, non-empty identifier for SDS-R. Pass NULL or "" to +// disable SDS-R (the manager will not request or answer +// repairs). It MUST be set-once at construction; do not change +// it across the lifetime of the manager. +// configJson: JSON object with optional fields for ReliabilityConfig. +// Pass NULL or "" to use the full default config. Missing +// fields fall back to per-field defaults. Duration fields use +// the suffix "Ms" (integer milliseconds). +// +// Recognised JSON keys: +// bloomFilterCapacity (int, default 10000) +// bloomFilterErrorRate (float, default 0.001) +// maxMessageHistory (int, default 1000) +// maxCausalHistory (int, default 10) +// resendIntervalMs (int, default 60000) +// maxResendAttempts (int, default 5) +// syncMessageIntervalMs (int, default 30000) +// bufferSweepIntervalMs (int, default 60000) +// repairTMinMs (int, default 30000) +// repairTMaxMs (int, default 300000) +// numResponseGroups (int, default 1) +// maxRepairRequests (int, default 3) +// repairSweepIntervalMs (int, default 5000) +void* SdsNewReliabilityManagerWithConfig(const char* participantId, + const char* configJson, + SdsCallBack callback, + void* userData); + void SdsSetEventCallback(void* ctx, SdsCallBack callback, void* userData); void SdsSetRetrievalHintProvider(void* ctx, SdsRetrievalHintProvider callback, void* userData); @@ -59,6 +90,13 @@ int SdsMarkDependenciesMet(void* ctx, int SdsStartPeriodicTasks(void* ctx, SdsCallBack callback, void* userData); +// Removes a channel and frees its per-channel state (buffers, bloom filter, +// message cache, SDS-R repair entries). Safe to call on a channel that does +// not exist; returns RET_OK in that case. +int SdsRemoveChannel(void* ctx, + const char* channelId, + SdsCallBack callback, + void* userData); #ifdef __cplusplus diff --git a/library/libsds.nim b/library/libsds.nim index 4ae285f..fc90003 100644 --- a/library/libsds.nim +++ b/library/libsds.nim @@ -16,7 +16,7 @@ import sds, ./events/[ json_message_ready_event, json_message_sent_event, json_missing_dependencies_event, - json_periodic_sync_event, + json_periodic_sync_event, json_repair_ready_event, ] ################################################################################ @@ -114,6 +114,11 @@ proc onPeriodicSync(ctx: ptr SdsContext): PeriodicSyncCallback = callEventCallback(ctx, "onPeriodicSync"): $JsonPeriodicSyncEvent.new() +proc onRepairReady(ctx: ptr SdsContext): RepairReadyCallback = + return proc(message: seq[byte], channelId: SdsChannelID) {.gcsafe.} = + callEventCallback(ctx, "onRepairReady"): + $JsonRepairReadyEvent.new(message, channelId) + proc onRetrievalHint(ctx: ptr SdsContext): RetrievalHintProvider = return proc(messageId: SdsMessageID): seq[byte] {.gcsafe.} = if isNil(ctx.retrievalHintProvider): @@ -173,12 +178,17 @@ proc initializeLibrary() {.exported.} = ################################################################################ ### Exported procs -proc SdsNewReliabilityManager( - callback: SdsCallBack, userData: pointer -): pointer {.dynlib, exportc, cdecl.} = +proc createManager( + participantId: cstring, + configJson: cstring, + callback: SdsCallBack, + userData: pointer, +): pointer = + ## Shared implementation for SdsNewReliabilityManager and + ## SdsNewReliabilityManagerWithConfig. Either argument may be NULL or empty + ## to indicate "use defaults". initializeLibrary() - ## Creates a new instance of the Reliability Manager. if isNil(callback): echo "error: missing callback in NewReliabilityManager" return nil @@ -196,13 +206,17 @@ proc SdsNewReliabilityManager( missingDependenciesCb: onMissingDependencies(ctx), periodicSyncCb: onPeriodicSync(ctx), retrievalHintProvider: onRetrievalHint(ctx), + repairReadyCb: onRepairReady(ctx), ) + let pId: cstring = if participantId.isNil: cstring"" else: participantId + let cfg: cstring = if configJson.isNil: cstring"" else: configJson + let retCode = handleRequest( ctx, RequestType.LIFECYCLE, SdsLifecycleRequest.createShared( - SdsLifecycleMsgType.CREATE_RELIABILITY_MANAGER, nil, appCallbacks + SdsLifecycleMsgType.CREATE_RELIABILITY_MANAGER, "", appCallbacks, pId, cfg ), callback, userData, @@ -213,6 +227,26 @@ proc SdsNewReliabilityManager( return ctx +proc SdsNewReliabilityManager( + callback: SdsCallBack, userData: pointer +): pointer {.dynlib, exportc, cdecl.} = + ## Back-compat shim. Constructs a manager with empty participantId (SDS-R + ## disabled) and the default ReliabilityConfig. New code should use + ## SdsNewReliabilityManagerWithConfig. + return createManager(nil, nil, callback, userData) + +proc SdsNewReliabilityManagerWithConfig( + participantId: cstring, + configJson: cstring, + callback: SdsCallBack, + userData: pointer, +): pointer {.dynlib, exportc, cdecl.} = + ## Creates a new instance of the Reliability Manager with an explicit + ## participantId (required for SDS-R) and a JSON-encoded ReliabilityConfig. + ## Either argument may be NULL or empty to fall back to defaults; missing + ## fields inside the JSON also fall back to per-field defaults. + return createManager(participantId, configJson, callback, userData) + proc SdsSetEventCallback( ctx: ptr SdsContext, callback: SdsCallBack, userData: pointer ) {.dynlib, exportc.} = @@ -378,5 +412,37 @@ proc SdsStartPeriodicTasks( userData, ) +proc SdsRemoveChannel( + ctx: ptr SdsContext, + channelId: cstring, + callback: SdsCallBack, + userData: pointer, +): cint {.dynlib, exportc.} = + ## Removes a channel and its associated state from the Reliability Manager. + ## Use this when a user leaves a channel to release per-channel buffers, + ## bloom filter, message cache, and SDS-R repair state. + initializeLibrary() + checkLibsdsParams(ctx, callback, userData) + + if channelId == nil: + let msg = "libsds error: " & "channel ID pointer is NULL" + callback(RET_ERR, unsafeAddr msg[0], cast[csize_t](len(msg)), userData) + return RET_ERR + + if $channelId == "": + let msg = "libsds error: " & "channel ID is empty string" + callback(RET_ERR, unsafeAddr msg[0], cast[csize_t](len(msg)), userData) + return RET_ERR + + handleRequest( + ctx, + RequestType.LIFECYCLE, + SdsLifecycleRequest.createShared( + SdsLifecycleMsgType.REMOVE_CHANNEL, channelId + ), + callback, + userData, + ) + ### End of exported procs ################################################################################ diff --git a/library/sds_thread/inter_thread_communication/requests/sds_lifecycle_request.nim b/library/sds_thread/inter_thread_communication/requests/sds_lifecycle_request.nim index 8d2e9bc..363d442 100644 --- a/library/sds_thread/inter_thread_communication/requests/sds_lifecycle_request.nim +++ b/library/sds_thread/inter_thread_communication/requests/sds_lifecycle_request.nim @@ -2,45 +2,61 @@ import std/json import chronos, chronicles, results import library/alloc +import library/config_json import sds type SdsLifecycleMsgType* = enum CREATE_RELIABILITY_MANAGER RESET_RELIABILITY_MANAGER START_PERIODIC_TASKS + REMOVE_CHANNEL type SdsLifecycleRequest* = object operation: SdsLifecycleMsgType channelId: cstring appCallbacks: AppCallbacks + participantId: cstring + configJson: cstring proc createShared*( T: type SdsLifecycleRequest, op: SdsLifecycleMsgType, channelId: cstring = "", appCallbacks: AppCallbacks = nil, + participantId: cstring = "", + configJson: cstring = "", ): ptr type T = var ret = createShared(T) ret[].operation = op ret[].appCallbacks = appCallbacks ret[].channelId = channelId.alloc() + ret[].participantId = participantId.alloc() + ret[].configJson = configJson.alloc() return ret proc destroyShared(self: ptr SdsLifecycleRequest) = deallocShared(self[].channelId) + deallocShared(self[].participantId) + deallocShared(self[].configJson) deallocShared(self) proc createReliabilityManager( - appCallbacks: AppCallbacks = nil + participantId: string, + configJson: string, + appCallbacks: AppCallbacks = nil, ): Future[Result[ReliabilityManager, string]] {.async.} = - let rm = newReliabilityManager().valueOr: + let config = parseReliabilityConfig(configJson).valueOr: + error "Failed to parse reliability config", error = error + return err("Failed to parse reliability config: " & error) + + let rm = newReliabilityManager(config, participantId).valueOr: error "Failed creating reliability manager", error = error return err("Failed creating reliability manager: " & $error) rm.setCallbacks( appCallbacks.messageReadyCb, appCallbacks.messageSentCb, appCallbacks.missingDependenciesCb, appCallbacks.periodicSyncCb, - appCallbacks.retrievalHintProvider, + appCallbacks.retrievalHintProvider, appCallbacks.repairReadyCb, ) return ok(rm) @@ -53,7 +69,11 @@ proc process*( case self.operation of CREATE_RELIABILITY_MANAGER: - rm[] = (await createReliabilityManager(self.appCallbacks)).valueOr: + rm[] = ( + await createReliabilityManager( + $self.participantId, $self.configJson, self.appCallbacks + ) + ).valueOr: error "CREATE_RELIABILITY_MANAGER failed", error = error return err("error processing CREATE_RELIABILITY_MANAGER request: " & $error) of RESET_RELIABILITY_MANAGER: @@ -62,5 +82,9 @@ proc process*( return err("error processing RESET_RELIABILITY_MANAGER request: " & $error) of START_PERIODIC_TASKS: rm[].startPeriodicTasks() + of REMOVE_CHANNEL: + removeChannel(rm[], $self.channelId).isOkOr: + error "REMOVE_CHANNEL failed", error = error + return err("error processing REMOVE_CHANNEL request: " & $error) return ok("") diff --git a/sds.nim b/sds.nim index 58d1893..c4a1482 100644 --- a/sds.nim +++ b/sds.nim @@ -5,11 +5,12 @@ import sds/[types, protobuf, sds_utils, rolling_bloom_filter] export types, protobuf, sds_utils, rolling_bloom_filter proc newReliabilityManager*( - config: ReliabilityConfig = defaultConfig() + config: ReliabilityConfig = defaultConfig(), + participantId: SdsParticipantID = "", ): Result[ReliabilityManager, ReliabilityError] = ## Creates a new multi-channel ReliabilityManager. try: - let rm = ReliabilityManager.new(config) + let rm = ReliabilityManager.new(config, participantId) return ok(rm) except Exception: error "Failed to create ReliabilityManager", msg = getCurrentExceptionMsg() @@ -88,6 +89,17 @@ proc wrapOutgoingMessage*( error "Failed to serialize bloom filter", channelId = channelId return err(ReliabilityError.reSerializationError) + # SDS-R: collect eligible expired repair requests to attach + var repairReqs: seq[HistoryEntry] = @[] + let now = getTime() + var expiredKeys: seq[SdsMessageID] = @[] + for msgId, repairEntry in channel.outgoingRepairBuffer: + if now >= repairEntry.tReq and repairReqs.len < rm.config.maxRepairRequests: + repairReqs.add(repairEntry.entry) + expiredKeys.add(msgId) + for key in expiredKeys: + channel.outgoingRepairBuffer.del(key) + let msg = SdsMessage.init( messageId = messageId, lamportTimestamp = channel.lamportTimestamp, @@ -95,6 +107,8 @@ proc wrapOutgoingMessage*( channelId = channelId, content = message, bloomFilter = bfResult.get(), + senderId = rm.participantId, + repairRequest = repairReqs, ) channel.outgoingBuffer.add( @@ -104,7 +118,16 @@ proc wrapOutgoingMessage*( channel.bloomFilter.add(msg.messageId) rm.addToHistory(msg.messageId, channelId) - return serializeMessage(msg) + # SDS-R: record sender for future causal-history entries + if rm.participantId.len > 0: + channel.messageSenders[msg.messageId] = rm.participantId + + let serialized = serializeMessage(msg) + if serialized.isOk(): + # SDS-R: cache serialized bytes so we can serve our own message on repair + if channel.messageCache.len < rm.config.maxMessageHistory: + channel.messageCache[msg.messageId] = serialized.get() + return serialized except Exception: error "Failed to wrap message", channelId = channelId, msg = getCurrentExceptionMsg() @@ -164,6 +187,11 @@ proc unwrapReceivedMessage*( let channel = rm.getOrCreateChannel(channelId) + # SDS-R: opportunistic repair-buffer cleanup — applies to duplicates too, + # so rebroadcasts cancel redundant responses on peers that already have the message. + channel.outgoingRepairBuffer.del(msg.messageId) + channel.incomingRepairBuffer.del(msg.messageId) + if msg.messageId in channel.messageHistory: return ok((msg.content, @[], channelId)) @@ -172,6 +200,36 @@ proc unwrapReceivedMessage*( rm.updateLamportTimestamp(msg.lamportTimestamp, channelId) rm.reviewAckStatus(msg) + # SDS-R: cache the raw message for potential repair responses + if channel.messageCache.len < rm.config.maxMessageHistory: + channel.messageCache[msg.messageId] = message + + # SDS-R: record sender so our future causal-history entries carry it + if msg.senderId.len > 0: + channel.messageSenders[msg.messageId] = msg.senderId + + # SDS-R: process incoming repair requests from this message + let now = getTime() + for repairEntry in msg.repairRequest: + # Remove from our own outgoing repair buffer (someone else is also requesting) + channel.outgoingRepairBuffer.del(repairEntry.messageId) + # Check if we can respond to this repair request + if repairEntry.messageId in channel.messageCache and + rm.participantId.len > 0 and repairEntry.senderId.len > 0: + if isInResponseGroup( + rm.participantId, repairEntry.senderId, + repairEntry.messageId, rm.config.numResponseGroups + ): + let tResp = computeTResp( + rm.participantId, repairEntry.senderId, + repairEntry.messageId, rm.config.repairTMax + ) + channel.incomingRepairBuffer[repairEntry.messageId] = IncomingRepairEntry( + entry: repairEntry, + cachedMessage: channel.messageCache[repairEntry.messageId], + tResp: now + tResp, + ) + var missingDeps = rm.checkDependencies(msg.causalHistory, channelId) if missingDeps.len == 0: @@ -185,6 +243,10 @@ proc unwrapReceivedMessage*( IncomingMessage.init(message = msg, missingDeps = initHashSet[SdsMessageID]()) else: rm.addToHistory(msg.messageId, channelId) + # Unblock any buffered messages that were waiting on this one. + for pendingId, entry in channel.incomingBuffer: + if msg.messageId in entry.missingDeps: + channel.incomingBuffer[pendingId].missingDeps.excl(msg.messageId) rm.processIncomingBuffer(channelId) if not rm.onMessageReady.isNil(): rm.onMessageReady(msg.messageId, channelId) @@ -197,6 +259,19 @@ proc unwrapReceivedMessage*( if not rm.onMissingDependencies.isNil(): rm.onMissingDependencies(msg.messageId, missingDeps, channelId) + # SDS-R: add missing deps to outgoing repair buffer + if rm.participantId.len > 0: + for dep in missingDeps: + if dep.messageId notin channel.outgoingRepairBuffer: + let tReq = computeTReq( + rm.participantId, dep.messageId, + rm.config.repairTMin, rm.config.repairTMax + ) + channel.outgoingRepairBuffer[dep.messageId] = OutgoingRepairEntry( + entry: dep, + tReq: now + tReq, + ) + return ok((msg.content, missingDeps, channelId)) except Exception: error "Failed to unwrap message", msg = getCurrentExceptionMsg() @@ -220,6 +295,10 @@ proc markDependenciesMet*( if msgId in entry.missingDeps: channel.incomingBuffer[pendingId].missingDeps.excl(msgId) + # SDS-R: clear from repair buffers (dependency now met) + channel.outgoingRepairBuffer.del(msgId) + channel.incomingRepairBuffer.del(msgId) + rm.processIncomingBuffer(channelId) return ok() except Exception: @@ -234,6 +313,7 @@ proc setCallbacks*( onMissingDependencies: MissingDependenciesCallback, onPeriodicSync: PeriodicSyncCallback = nil, onRetrievalHint: RetrievalHintProvider = nil, + onRepairReady: RepairReadyCallback = nil, ) = ## Sets the callback functions for various events in the ReliabilityManager. withLock rm.lock: @@ -242,6 +322,7 @@ proc setCallbacks*( rm.onMissingDependencies = onMissingDependencies rm.onPeriodicSync = onPeriodicSync rm.onRetrievalHint = onRetrievalHint + rm.onRepairReady = onRepairReady proc checkUnacknowledgedMessages( rm: ReliabilityManager, channelId: SdsChannelID @@ -299,10 +380,54 @@ proc periodicSyncMessage( error "Error in periodic sync", msg = getCurrentExceptionMsg() await sleepAsync(chronos.seconds(rm.config.syncMessageInterval.inSeconds)) +proc runRepairSweep*(rm: ReliabilityManager) {.gcsafe, raises: [].} = + ## SDS-R: Runs a single pass of the repair sweep. + ## - Incoming: fires onRepairReady for expired T_resp entries and removes them + ## - Outgoing: drops entries past T_max window + ## Exposed so it can be driven directly in tests; also invoked by periodicRepairSweep. + try: + let now = getTime() + for channelId, channel in rm.channels: + try: + # Check incoming repair buffer for expired T_resp (time to rebroadcast) + var toRebroadcast: seq[SdsMessageID] = @[] + for msgId, entry in channel.incomingRepairBuffer: + if now >= entry.tResp: + toRebroadcast.add(msgId) + + for msgId in toRebroadcast: + let entry = channel.incomingRepairBuffer[msgId] + channel.incomingRepairBuffer.del(msgId) + if not rm.onRepairReady.isNil(): + rm.onRepairReady(entry.cachedMessage, channelId) + + # Drop expired outgoing repair entries past T_max + var toRemove: seq[SdsMessageID] = @[] + let tMaxDuration = rm.config.repairTMax + for msgId, entry in channel.outgoingRepairBuffer: + if now - entry.tReq > tMaxDuration: + toRemove.add(msgId) + for msgId in toRemove: + channel.outgoingRepairBuffer.del(msgId) + except Exception: + error "Error in repair sweep for channel", + channelId = channelId, msg = getCurrentExceptionMsg() + except Exception: + error "Error in repair sweep", msg = getCurrentExceptionMsg() + +proc periodicRepairSweep( + rm: ReliabilityManager +) {.async: (raises: [CancelledError]), gcsafe.} = + ## SDS-R: Periodically checks repair buffers for expired entries. + while true: + rm.runRepairSweep() + await sleepAsync(chronos.milliseconds(rm.config.repairSweepInterval.inMilliseconds)) + proc startPeriodicTasks*(rm: ReliabilityManager) = ## Starts the periodic tasks for buffer sweeping and sync message sending. asyncSpawn rm.periodicBufferSweep() asyncSpawn rm.periodicSyncMessage() + asyncSpawn rm.periodicRepairSweep() proc resetReliabilityManager*(rm: ReliabilityManager): Result[void, ReliabilityError] = ## Resets the ReliabilityManager to its initial state. @@ -313,6 +438,10 @@ proc resetReliabilityManager*(rm: ReliabilityManager): Result[void, ReliabilityE channel.messageHistory.setLen(0) channel.outgoingBuffer.setLen(0) channel.incomingBuffer.clear() + channel.outgoingRepairBuffer.clear() + channel.incomingRepairBuffer.clear() + channel.messageCache.clear() + channel.messageSenders.clear() channel.bloomFilter = RollingBloomFilter.init(rm.config.bloomFilterCapacity, rm.config.bloomFilterErrorRate) rm.channels.clear() diff --git a/sds/message.nim b/sds/message.nim index ddf5e5f..410fc43 100644 --- a/sds/message.nim +++ b/sds/message.nim @@ -3,6 +3,7 @@ import ./types/history_entry import ./types/sds_message import ./types/unacknowledged_message import ./types/incoming_message +import ./types/repair_entry import ./types/reliability_config export @@ -11,4 +12,5 @@ export sds_message, unacknowledged_message, incoming_message, + repair_entry, reliability_config diff --git a/sds/protobuf.nim b/sds/protobuf.nim index ba1b7ff..24a95d7 100644 --- a/sds/protobuf.nim +++ b/sds/protobuf.nim @@ -5,6 +5,24 @@ import ./protobufutil import ./bloom import ./sds_utils +proc encodeHistoryEntry*(entry: HistoryEntry): ProtoBuffer = + var entryPb = initProtoBuffer() + entryPb.write(1, entry.messageId) + if entry.retrievalHint.len > 0: + entryPb.write(2, entry.retrievalHint) + if entry.senderId.len > 0: + entryPb.write(3, entry.senderId) + entryPb.finish() + entryPb + +proc decodeHistoryEntry*(entryPb: ProtoBuffer): ProtobufResult[HistoryEntry] = + var entry = HistoryEntry.init("") + if not ?entryPb.getField(1, entry.messageId): + return err(ProtobufError.missingRequiredField("HistoryEntry.messageId")) + discard entryPb.getField(2, entry.retrievalHint) + discard entryPb.getField(3, entry.senderId) + ok(entry) + proc encode*(msg: SdsMessage): ProtoBuffer = var pb = initProtoBuffer() @@ -12,16 +30,20 @@ proc encode*(msg: SdsMessage): ProtoBuffer = pb.write(2, uint64(msg.lamportTimestamp)) for entry in msg.causalHistory: - var entryPb = initProtoBuffer() - entryPb.write(1, entry.messageId) - if entry.retrievalHint.len > 0: - entryPb.write(2, entry.retrievalHint) - entryPb.finish() + let entryPb = encodeHistoryEntry(entry) pb.write(3, entryPb.buffer) pb.write(4, msg.channelId) pb.write(5, msg.content) pb.write(6, msg.bloomFilter) + + if msg.senderId.len > 0: + pb.write(7, msg.senderId) + + for entry in msg.repairRequest: + let entryPb = encodeHistoryEntry(entry) + pb.write(13, entryPb.buffer) + pb.finish() return pb @@ -44,11 +66,7 @@ proc decode*(T: type SdsMessage, buffer: seq[byte]): ProtobufResult[T] = # New format: repeated HistoryEntry for histBuffer in historyBuffers: let entryPb = initProtoBuffer(histBuffer) - var entry = HistoryEntry.init("") - if not ?entryPb.getField(1, entry.messageId): - return err(ProtobufError.missingRequiredField("HistoryEntry.messageId")) - # retrievalHint is optional - discard entryPb.getField(2, entry.retrievalHint) + let entry = ?decodeHistoryEntry(entryPb) msg.causalHistory.add(entry) else: # Try old format: repeated string @@ -66,6 +84,17 @@ proc decode*(T: type SdsMessage, buffer: seq[byte]): ProtobufResult[T] = if not ?pb.getField(6, msg.bloomFilter): msg.bloomFilter = @[] # Empty if not present + # SDS-R: decode senderId (field 7, optional) + discard pb.getField(7, msg.senderId) + + # SDS-R: decode repair request (field 13, optional) + var repairBuffers: seq[seq[byte]] + if pb.getRepeatedField(13, repairBuffers).isOk(): + for repairBuffer in repairBuffers: + let entryPb = initProtoBuffer(repairBuffer) + let entry = ?decodeHistoryEntry(entryPb) + msg.repairRequest.add(entry) + return ok(msg) proc extractChannelId*(data: seq[byte]): Result[SdsChannelID, ReliabilityError] = diff --git a/sds/sds_utils.nim b/sds/sds_utils.nim index f1a68ca..7e4eba6 100644 --- a/sds/sds_utils.nim +++ b/sds/sds_utils.nim @@ -1,15 +1,15 @@ -import std/[locks, tables, sequtils] +import std/[times, locks, tables, sequtils, hashes] import chronicles, results import ./rolling_bloom_filter import ./types/[ sds_message_id, history_entry, sds_message, unacknowledged_message, incoming_message, - reliability_error, callbacks, app_callbacks, reliability_config, channel_context, - reliability_manager, + reliability_error, callbacks, app_callbacks, reliability_config, repair_entry, + channel_context, reliability_manager, ] export sds_message_id, history_entry, sds_message, unacknowledged_message, incoming_message, - reliability_error, callbacks, app_callbacks, reliability_config, channel_context, - reliability_manager + reliability_error, callbacks, app_callbacks, reliability_config, repair_entry, + channel_context, reliability_manager proc defaultConfig*(): ReliabilityConfig = return ReliabilityConfig.init() @@ -22,6 +22,10 @@ proc cleanup*(rm: ReliabilityManager) {.raises: [].} = channel.outgoingBuffer.setLen(0) channel.incomingBuffer.clear() channel.messageHistory.setLen(0) + channel.outgoingRepairBuffer.clear() + channel.incomingRepairBuffer.clear() + channel.messageCache.clear() + channel.messageSenders.clear() rm.channels.clear() except Exception: error "Error during cleanup", error = getCurrentExceptionMsg() @@ -70,21 +74,76 @@ proc toCausalHistory*(messageIds: seq[SdsMessageID]): seq[HistoryEntry] = proc getMessageIds*(causalHistory: seq[HistoryEntry]): seq[SdsMessageID] = return causalHistory.mapIt(it.messageId) +## SDS-R: Repair computation functions + +proc computeTReq*( + participantId: SdsParticipantID, + messageId: SdsMessageID, + tMin: Duration, + tMax: Duration, +): Duration = + ## Computes the repair request backoff duration per SDS-R spec: + ## T_req = hash(participant_id, message_id) % (T_max - T_min) + T_min + let h = abs(hash(participantId & messageId)) + let rangeMs = tMax.inMilliseconds - tMin.inMilliseconds + if rangeMs <= 0: + return tMin + let offsetMs = h mod rangeMs + initDuration(milliseconds = tMin.inMilliseconds + offsetMs) + +proc computeTResp*( + participantId: SdsParticipantID, + senderId: SdsParticipantID, + messageId: SdsMessageID, + tMax: Duration, +): Duration = + ## Computes the repair response backoff duration per SDS-R spec: + ## distance = hash(participant_id) XOR hash(sender_id) + ## T_resp = distance * hash(message_id) % T_max + ## Original sender has distance=0, so T_resp=0 (responds immediately). + let distance = abs(hash(participantId) xor hash(senderId)) + let msgHash = abs(hash(messageId)) + let tMaxMs = tMax.inMilliseconds + if tMaxMs <= 0 or distance == 0: + return initDuration(milliseconds = 0) + # Use uint64 to avoid overflow on multiplication + let d = uint64(distance mod tMaxMs) + let m = uint64(msgHash mod tMaxMs) + let offsetMs = int64((d * m) mod uint64(tMaxMs)) + initDuration(milliseconds = offsetMs) + +proc isInResponseGroup*( + participantId: SdsParticipantID, + senderId: SdsParticipantID, + messageId: SdsMessageID, + numResponseGroups: int, +): bool = + ## Determines if this participant is in the response group for a given message per SDS-R spec: + ## hash(participant_id, message_id) % num_groups == hash(sender_id, message_id) % num_groups + if numResponseGroups <= 1: + return true # All participants in the same group + let myGroup = abs(hash(participantId & messageId)) mod numResponseGroups + let senderGroup = abs(hash(senderId & messageId)) mod numResponseGroups + myGroup == senderGroup + proc getRecentHistoryEntries*( rm: ReliabilityManager, n: int, channelId: SdsChannelID ): seq[HistoryEntry] = + ## Get recent history entries for sending in causal history. + ## Populates retrieval hints and senderId (SDS-R) for each entry. try: if channelId in rm.channels: let channel = rm.channels[channelId] let recentMessageIds = channel.messageHistory[max(0, channel.messageHistory.len - n) .. ^1] - if rm.onRetrievalHint.isNil(): - return toCausalHistory(recentMessageIds) - else: - var entries: seq[HistoryEntry] = @[] - for msgId in recentMessageIds: - let hint = rm.onRetrievalHint(msgId) - entries.add(newHistoryEntry(msgId, hint)) - return entries + var entries: seq[HistoryEntry] = @[] + for msgId in recentMessageIds: + var entry = HistoryEntry(messageId: msgId) + if not rm.onRetrievalHint.isNil(): + entry.retrievalHint = rm.onRetrievalHint(msgId) + if msgId in channel.messageSenders: + entry.senderId = channel.messageSenders[msgId] + entries.add(entry) + return entries else: return @[] except Exception: @@ -188,6 +247,10 @@ proc removeChannel*( channel.outgoingBuffer.setLen(0) channel.incomingBuffer.clear() channel.messageHistory.setLen(0) + channel.outgoingRepairBuffer.clear() + channel.incomingRepairBuffer.clear() + channel.messageCache.clear() + channel.messageSenders.clear() rm.channels.del(channelId) return ok() except Exception: diff --git a/sds/types.nim b/sds/types.nim index f37518a..637ec54 100644 --- a/sds/types.nim +++ b/sds/types.nim @@ -9,6 +9,7 @@ import sds/types/reliability_error import sds/types/callbacks import sds/types/app_callbacks import sds/types/reliability_config +import sds/types/repair_entry import sds/types/channel_context import sds/types/reliability_manager import sds/types/protobuf_error @@ -25,6 +26,7 @@ export callbacks, app_callbacks, reliability_config, + repair_entry, channel_context, reliability_manager, protobuf_error diff --git a/sds/types/app_callbacks.nim b/sds/types/app_callbacks.nim index 985a97f..84873a6 100644 --- a/sds/types/app_callbacks.nim +++ b/sds/types/app_callbacks.nim @@ -7,6 +7,7 @@ type AppCallbacks* = ref object missingDependenciesCb*: MissingDependenciesCallback periodicSyncCb*: PeriodicSyncCallback retrievalHintProvider*: RetrievalHintProvider + repairReadyCb*: RepairReadyCallback proc new*( T: type AppCallbacks, @@ -15,6 +16,7 @@ proc new*( missingDependenciesCb: MissingDependenciesCallback = nil, periodicSyncCb: PeriodicSyncCallback = nil, retrievalHintProvider: RetrievalHintProvider = nil, + repairReadyCb: RepairReadyCallback = nil, ): T = return T( messageReadyCb: messageReadyCb, @@ -22,4 +24,5 @@ proc new*( missingDependenciesCb: missingDependenciesCb, periodicSyncCb: periodicSyncCb, retrievalHintProvider: retrievalHintProvider, + repairReadyCb: repairReadyCb, ) diff --git a/sds/types/callbacks.nim b/sds/types/callbacks.nim index f1fc4b3..1894f14 100644 --- a/sds/types/callbacks.nim +++ b/sds/types/callbacks.nim @@ -16,3 +16,5 @@ type RetrievalHintProvider* = proc(messageId: SdsMessageID): seq[byte] {.gcsafe.} PeriodicSyncCallback* = proc() {.gcsafe, raises: [].} + + RepairReadyCallback* = proc(message: seq[byte], channelId: SdsChannelID) {.gcsafe.} diff --git a/sds/types/channel_context.nim b/sds/types/channel_context.nim index 0346d18..2f61584 100644 --- a/sds/types/channel_context.nim +++ b/sds/types/channel_context.nim @@ -3,7 +3,10 @@ import ./sds_message_id import ./rolling_bloom_filter import ./unacknowledged_message import ./incoming_message -export sds_message_id, rolling_bloom_filter, unacknowledged_message, incoming_message +import ./repair_entry +export + sds_message_id, rolling_bloom_filter, unacknowledged_message, incoming_message, + repair_entry type ChannelContext* = ref object lamportTimestamp*: int64 @@ -11,6 +14,13 @@ type ChannelContext* = ref object bloomFilter*: RollingBloomFilter outgoingBuffer*: seq[UnacknowledgedMessage] incomingBuffer*: Table[SdsMessageID, IncomingMessage] + ## SDS-R buffers + outgoingRepairBuffer*: Table[SdsMessageID, OutgoingRepairEntry] + incomingRepairBuffer*: Table[SdsMessageID, IncomingRepairEntry] + messageCache*: Table[SdsMessageID, seq[byte]] + ## Cached serialized messages for repair responses + messageSenders*: Table[SdsMessageID, SdsParticipantID] + ## SDS-R: msgId -> original sender, used to populate causal-history senderId proc new*(T: type ChannelContext, bloomFilter: RollingBloomFilter): T = return T( @@ -19,4 +29,8 @@ proc new*(T: type ChannelContext, bloomFilter: RollingBloomFilter): T = bloomFilter: bloomFilter, outgoingBuffer: @[], incomingBuffer: initTable[SdsMessageID, IncomingMessage](), + outgoingRepairBuffer: initTable[SdsMessageID, OutgoingRepairEntry](), + incomingRepairBuffer: initTable[SdsMessageID, IncomingRepairEntry](), + messageCache: initTable[SdsMessageID, seq[byte]](), + messageSenders: initTable[SdsMessageID, SdsParticipantID](), ) diff --git a/sds/types/history_entry.nim b/sds/types/history_entry.nim index 2435e6f..b55fc20 100644 --- a/sds/types/history_entry.nim +++ b/sds/types/history_entry.nim @@ -3,6 +3,12 @@ import ./sds_message_id type HistoryEntry* = object messageId*: SdsMessageID retrievalHint*: seq[byte] ## Optional hint for efficient retrieval (e.g., Waku message hash) + senderId*: string ## Original message sender's participant ID (SDS-R) -proc init*(T: type HistoryEntry, messageId: SdsMessageID, retrievalHint: seq[byte] = @[]): T = - return T(messageId: messageId, retrievalHint: retrievalHint) +proc init*( + T: type HistoryEntry, + messageId: SdsMessageID, + retrievalHint: seq[byte] = @[], + senderId: string = "", +): T = + return T(messageId: messageId, retrievalHint: retrievalHint, senderId: senderId) diff --git a/sds/types/reliability_config.nim b/sds/types/reliability_config.nim index f4e4e78..7cd20f2 100644 --- a/sds/types/reliability_config.nim +++ b/sds/types/reliability_config.nim @@ -7,6 +7,11 @@ const DefaultMaxResendAttempts* = 5 DefaultSyncMessageInterval* = initDuration(seconds = 30) DefaultBufferSweepInterval* = initDuration(seconds = 60) + DefaultRepairTMin* = initDuration(seconds = 30) + DefaultRepairTMax* = initDuration(seconds = 300) + DefaultNumResponseGroups* = 1 + DefaultMaxRepairRequests* = 3 + DefaultRepairSweepInterval* = initDuration(seconds = 5) MaxMessageSize* = 1024 * 1024 # 1 MB import ./rolling_bloom_filter @@ -21,6 +26,12 @@ type ReliabilityConfig* {.requiresInit.} = object maxResendAttempts*: int syncMessageInterval*: Duration bufferSweepInterval*: Duration + ## SDS-R config + repairTMin*: Duration + repairTMax*: Duration + numResponseGroups*: int + maxRepairRequests*: int + repairSweepInterval*: Duration proc init*( T: type ReliabilityConfig, @@ -32,6 +43,11 @@ proc init*( maxResendAttempts: int = DefaultMaxResendAttempts, syncMessageInterval: Duration = DefaultSyncMessageInterval, bufferSweepInterval: Duration = DefaultBufferSweepInterval, + repairTMin: Duration = DefaultRepairTMin, + repairTMax: Duration = DefaultRepairTMax, + numResponseGroups: int = DefaultNumResponseGroups, + maxRepairRequests: int = DefaultMaxRepairRequests, + repairSweepInterval: Duration = DefaultRepairSweepInterval, ): T = return T( bloomFilterCapacity: bloomFilterCapacity, @@ -42,4 +58,9 @@ proc init*( maxResendAttempts: maxResendAttempts, syncMessageInterval: syncMessageInterval, bufferSweepInterval: bufferSweepInterval, + repairTMin: repairTMin, + repairTMax: repairTMax, + numResponseGroups: numResponseGroups, + maxRepairRequests: maxRepairRequests, + repairSweepInterval: repairSweepInterval, ) diff --git a/sds/types/reliability_manager.nim b/sds/types/reliability_manager.nim index 9bfc244..5545859 100644 --- a/sds/types/reliability_manager.nim +++ b/sds/types/reliability_manager.nim @@ -9,6 +9,7 @@ export sds_message_id, history_entry, callbacks, reliability_config, channel_con type ReliabilityManager* = ref object channels*: Table[SdsChannelID, ChannelContext] config*: ReliabilityConfig + participantId*: SdsParticipantID lock*: Lock onMessageReady*: proc(messageId: SdsMessageID, channelId: SdsChannelID) {.gcsafe.} onMessageSent*: proc(messageId: SdsMessageID, channelId: SdsChannelID) {.gcsafe.} @@ -17,11 +18,17 @@ type ReliabilityManager* = ref object ) {.gcsafe.} onPeriodicSync*: PeriodicSyncCallback onRetrievalHint*: RetrievalHintProvider + onRepairReady*: RepairReadyCallback -proc new*(T: type ReliabilityManager, config: ReliabilityConfig): T = +proc new*( + T: type ReliabilityManager, + config: ReliabilityConfig, + participantId: SdsParticipantID = "", +): T = let rm = T( channels: initTable[SdsChannelID, ChannelContext](), config: config, + participantId: participantId, ) rm.lock.initLock() return rm diff --git a/sds/types/repair_entry.nim b/sds/types/repair_entry.nim new file mode 100644 index 0000000..01f0fd5 --- /dev/null +++ b/sds/types/repair_entry.nim @@ -0,0 +1,28 @@ +import std/times +import ./history_entry +export history_entry + +type + OutgoingRepairEntry* = object + ## Entry in the outgoing repair request buffer (SDS-R). + ## Tracks a missing message we want to request repair for. + entry*: HistoryEntry ## The missing history entry + tReq*: Time ## Timestamp after which we will include this in a repair request + + IncomingRepairEntry* = object + ## Entry in the incoming repair request buffer (SDS-R). + ## Tracks a repair request from a remote peer that we might respond to. + entry*: HistoryEntry ## The requested history entry + cachedMessage*: seq[byte] ## Full serialized SDS message for rebroadcast + tResp*: Time ## Timestamp after which we will rebroadcast + +proc init*(T: type OutgoingRepairEntry, entry: HistoryEntry, tReq: Time): T = + return T(entry: entry, tReq: tReq) + +proc init*( + T: type IncomingRepairEntry, + entry: HistoryEntry, + cachedMessage: seq[byte], + tResp: Time, +): T = + return T(entry: entry, cachedMessage: cachedMessage, tResp: tResp) diff --git a/sds/types/sds_message.nim b/sds/types/sds_message.nim index 12f7add..6ab7a4f 100644 --- a/sds/types/sds_message.nim +++ b/sds/types/sds_message.nim @@ -2,13 +2,16 @@ import ./sds_message_id import ./history_entry export sds_message_id, history_entry -type SdsMessage* {.requiresInit.} = object +type SdsMessage* = object messageId*: SdsMessageID lamportTimestamp*: int64 causalHistory*: seq[HistoryEntry] channelId*: SdsChannelID content*: seq[byte] bloomFilter*: seq[byte] + senderId*: SdsParticipantID ## SDS-R: original sender's participant ID + repairRequest*: seq[HistoryEntry] + ## Capped list of missing entries requesting repair (SDS-R) proc init*( T: type SdsMessage, @@ -18,6 +21,8 @@ proc init*( channelId: SdsChannelID, content: seq[byte], bloomFilter: seq[byte], + senderId: SdsParticipantID = "", + repairRequest: seq[HistoryEntry] = @[], ): T = return T( messageId: messageId, @@ -26,4 +31,6 @@ proc init*( channelId: channelId, content: content, bloomFilter: bloomFilter, + senderId: senderId, + repairRequest: repairRequest, ) diff --git a/sds/types/sds_message_id.nim b/sds/types/sds_message_id.nim index 3e8b7c7..dfeb025 100644 --- a/sds/types/sds_message_id.nim +++ b/sds/types/sds_message_id.nim @@ -1,3 +1,4 @@ type SdsMessageID* = string SdsChannelID* = string + SdsParticipantID* = string diff --git a/tests/test_reliability.nim b/tests/test_reliability.nim index 7100606..7f738c5 100644 --- a/tests/test_reliability.nim +++ b/tests/test_reliability.nim @@ -741,3 +741,893 @@ suite "Multi-Channel ReliabilityManager Tests": # Dependencies in channel1 should not affect channel2 check rm.channels[channel1].bloomFilter.contains("dep1") check not rm.channels[channel2].bloomFilter.contains("dep1") + +# SDS-R Repair tests +suite "SDS-R: Computation Functions": + test "computeTReq returns duration in [tMin, tMax)": + let tMin = initDuration(seconds = 30) + let tMax = initDuration(seconds = 300) + let d = computeTReq("participant1", "msg1", tMin, tMax) + check: + d.inMilliseconds >= tMin.inMilliseconds + d.inMilliseconds < tMax.inMilliseconds + + test "computeTReq is deterministic for same inputs": + let tMin = initDuration(seconds = 30) + let tMax = initDuration(seconds = 300) + let d1 = computeTReq("p1", "m1", tMin, tMax) + let d2 = computeTReq("p1", "m1", tMin, tMax) + check d1 == d2 + + test "computeTReq varies with different participants": + let tMin = initDuration(seconds = 30) + let tMax = initDuration(seconds = 300) + let d1 = computeTReq("participant-A", "msg1", tMin, tMax) + let d2 = computeTReq("participant-B", "msg1", tMin, tMax) + # Different participants should generally get different backoff (not guaranteed but highly likely) + # Just check both are in valid range + check: + d1.inMilliseconds >= tMin.inMilliseconds + d2.inMilliseconds >= tMin.inMilliseconds + + test "computeTResp original sender has zero distance": + let d = computeTResp("sender1", "sender1", "msg1", initDuration(seconds = 300)) + check d.inMilliseconds == 0 + + test "computeTResp non-sender has positive backoff": + let d = computeTResp("other-node", "sender1", "msg1", initDuration(seconds = 300)) + check d.inMilliseconds >= 0 + + test "isInResponseGroup all in same group when numGroups=1": + check isInResponseGroup("p1", "sender1", "msg1", 1) == true + check isInResponseGroup("p2", "sender1", "msg1", 1) == true + + test "isInResponseGroup sender always in own group": + # Original sender must always be in their own response group + for groups in 1 .. 10: + check isInResponseGroup("sender1", "sender1", "msg1", groups) == true + +suite "SDS-R: Repair Buffer Management": + var rm: ReliabilityManager + + setup: + let rmResult = newReliabilityManager( + participantId = "test-participant" + ) + check rmResult.isOk() + rm = rmResult.get() + check rm.ensureChannel(testChannel).isOk() + + teardown: + if not rm.isNil: + rm.cleanup() + + test "missing deps added to outgoing repair buffer": + var missingDepsCount = 0 + + rm.setCallbacks( + proc(messageId: SdsMessageID, channelId: SdsChannelID) {.gcsafe.} = discard, + proc(messageId: SdsMessageID, channelId: SdsChannelID) {.gcsafe.} = discard, + proc(messageId: SdsMessageID, missingDeps: seq[HistoryEntry], channelId: SdsChannelID) {.gcsafe.} = + missingDepsCount += 1, + ) + + # Create a message with a missing dependency + let msg = SdsMessage( + messageId: "msg2", + lamportTimestamp: 2, + causalHistory: @[HistoryEntry(messageId: "msg1", senderId: "sender-A")], + channelId: testChannel, + content: @[byte(2)], + bloomFilter: @[], + ) + + let serialized = serializeMessage(msg).get() + let result = rm.unwrapReceivedMessage(serialized) + check result.isOk() + + # msg1 should be in the outgoing repair buffer + let channel = rm.channels[testChannel] + check: + missingDepsCount == 1 + "msg1" in channel.outgoingRepairBuffer + + test "receiving message clears it from repair buffers": + rm.setCallbacks( + proc(messageId: SdsMessageID, channelId: SdsChannelID) {.gcsafe.} = discard, + proc(messageId: SdsMessageID, channelId: SdsChannelID) {.gcsafe.} = discard, + proc(messageId: SdsMessageID, missingDeps: seq[HistoryEntry], channelId: SdsChannelID) {.gcsafe.} = discard, + ) + + # First, create the missing dep scenario + let msg2 = SdsMessage( + messageId: "msg2", + lamportTimestamp: 2, + causalHistory: @[HistoryEntry(messageId: "msg1", senderId: "sender-A")], + channelId: testChannel, + content: @[byte(2)], + bloomFilter: @[], + ) + discard rm.unwrapReceivedMessage(serializeMessage(msg2).get()) + check "msg1" in rm.channels[testChannel].outgoingRepairBuffer + + # Now receive msg1 — should clear from repair buffer + let msg1 = SdsMessage( + messageId: "msg1", + lamportTimestamp: 1, + causalHistory: @[], + channelId: testChannel, + content: @[byte(1)], + bloomFilter: @[], + ) + discard rm.unwrapReceivedMessage(serializeMessage(msg1).get()) + check "msg1" notin rm.channels[testChannel].outgoingRepairBuffer + + test "markDependenciesMet clears repair buffers": + rm.setCallbacks( + proc(messageId: SdsMessageID, channelId: SdsChannelID) {.gcsafe.} = discard, + proc(messageId: SdsMessageID, channelId: SdsChannelID) {.gcsafe.} = discard, + proc(messageId: SdsMessageID, missingDeps: seq[HistoryEntry], channelId: SdsChannelID) {.gcsafe.} = discard, + ) + + let msg2 = SdsMessage( + messageId: "msg2", + lamportTimestamp: 2, + causalHistory: @[HistoryEntry(messageId: "msg1", senderId: "sender-A")], + channelId: testChannel, + content: @[byte(2)], + bloomFilter: @[], + ) + discard rm.unwrapReceivedMessage(serializeMessage(msg2).get()) + check "msg1" in rm.channels[testChannel].outgoingRepairBuffer + + # Mark as met via store retrieval + check rm.markDependenciesMet(@["msg1"], testChannel).isOk() + check "msg1" notin rm.channels[testChannel].outgoingRepairBuffer + + test "expired repair requests attached to outgoing messages": + rm.setCallbacks( + proc(messageId: SdsMessageID, channelId: SdsChannelID) {.gcsafe.} = discard, + proc(messageId: SdsMessageID, channelId: SdsChannelID) {.gcsafe.} = discard, + proc(messageId: SdsMessageID, missingDeps: seq[HistoryEntry], channelId: SdsChannelID) {.gcsafe.} = discard, + ) + + # Manually add an expired repair entry + let channel = rm.channels[testChannel] + channel.outgoingRepairBuffer["missing-msg"] = OutgoingRepairEntry( + entry: HistoryEntry(messageId: "missing-msg", senderId: "orig-sender"), + tReq: getTime() - initDuration(seconds = 10), # Already expired + ) + + # Send a message — should pick up the expired repair request + let wrapped = rm.wrapOutgoingMessage(@[byte(1)], "new-msg", testChannel) + check wrapped.isOk() + + let unwrapped = deserializeMessage(wrapped.get()).get() + check: + unwrapped.repairRequest.len == 1 + unwrapped.repairRequest[0].messageId == "missing-msg" + # Should be removed from buffer after attaching + "missing-msg" notin channel.outgoingRepairBuffer + + test "incoming repair request adds to incoming repair buffer when eligible": + rm.setCallbacks( + proc(messageId: SdsMessageID, channelId: SdsChannelID) {.gcsafe.} = discard, + proc(messageId: SdsMessageID, channelId: SdsChannelID) {.gcsafe.} = discard, + proc(messageId: SdsMessageID, missingDeps: seq[HistoryEntry], channelId: SdsChannelID) {.gcsafe.} = discard, + ) + + let channel = rm.channels[testChannel] + + # First, cache a message so we can respond to a repair request for it + let cachedMsg = SdsMessage( + messageId: "cached-msg", + lamportTimestamp: 1, + causalHistory: @[], + channelId: testChannel, + content: @[byte(99)], + bloomFilter: @[], + ) + let cachedBytes = serializeMessage(cachedMsg).get() + channel.messageCache["cached-msg"] = cachedBytes + + # Receive a message with a repair request for "cached-msg" + let msgWithRepair = SdsMessage( + messageId: "requester-msg", + lamportTimestamp: 5, + causalHistory: @[], + channelId: testChannel, + content: @[byte(3)], + bloomFilter: @[], + repairRequest: @[HistoryEntry( + messageId: "cached-msg", + senderId: "test-participant", # Same as our participantId so we're in response group + )], + ) + discard rm.unwrapReceivedMessage(serializeMessage(msgWithRepair).get()) + + # We should have added it to the incoming repair buffer (we have the message and are in response group) + check "cached-msg" in channel.incomingRepairBuffer + +suite "SDS-R: Protobuf Roundtrip": + test "senderId in HistoryEntry roundtrips through protobuf": + let msg = SdsMessage( + messageId: "msg1", + lamportTimestamp: 100, + causalHistory: @[ + HistoryEntry(messageId: "dep1", retrievalHint: @[byte(1), 2], senderId: "sender-A"), + HistoryEntry(messageId: "dep2", senderId: "sender-B"), + ], + channelId: "ch1", + content: @[byte(42)], + bloomFilter: @[], + ) + + let serialized = serializeMessage(msg).get() + let decoded = deserializeMessage(serialized).get() + + check: + decoded.causalHistory.len == 2 + decoded.causalHistory[0].messageId == "dep1" + decoded.causalHistory[0].senderId == "sender-A" + decoded.causalHistory[0].retrievalHint == @[byte(1), 2] + decoded.causalHistory[1].messageId == "dep2" + decoded.causalHistory[1].senderId == "sender-B" + + test "repairRequest field roundtrips through protobuf": + let msg = SdsMessage( + messageId: "msg1", + lamportTimestamp: 100, + causalHistory: @[], + channelId: "ch1", + content: @[byte(42)], + bloomFilter: @[], + repairRequest: @[ + HistoryEntry(messageId: "missing1", senderId: "sender-X"), + HistoryEntry(messageId: "missing2", senderId: "sender-Y", retrievalHint: @[byte(5)]), + ], + ) + + let serialized = serializeMessage(msg).get() + let decoded = deserializeMessage(serialized).get() + + check: + decoded.repairRequest.len == 2 + decoded.repairRequest[0].messageId == "missing1" + decoded.repairRequest[0].senderId == "sender-X" + decoded.repairRequest[1].messageId == "missing2" + decoded.repairRequest[1].senderId == "sender-Y" + decoded.repairRequest[1].retrievalHint == @[byte(5)] + + test "backward compat: message without repairRequest decodes fine": + let msg = SdsMessage( + messageId: "msg1", + lamportTimestamp: 100, + causalHistory: @[HistoryEntry(messageId: "dep1")], + channelId: "ch1", + content: @[byte(42)], + bloomFilter: @[], + ) + + let serialized = serializeMessage(msg).get() + let decoded = deserializeMessage(serialized).get() + + check: + decoded.repairRequest.len == 0 + decoded.causalHistory[0].senderId == "" + + test "SdsMessage.senderId roundtrips through protobuf": + let msg = SdsMessage( + messageId: "m1", + lamportTimestamp: 1, + causalHistory: @[], + channelId: "ch1", + content: @[byte(1)], + bloomFilter: @[], + senderId: "alice", + ) + let decoded = deserializeMessage(serializeMessage(msg).get()).get() + check decoded.senderId == "alice" + +# --------------------------------------------------------------------------- +# SDS-R Phase 2 tests: edge cases, lifecycle, sweep, and multi-participant flows +# --------------------------------------------------------------------------- + +suite "SDS-R: Edge Cases and Defensive Branches": + test "computeTReq returns tMin when range is degenerate": + let tMin = initDuration(seconds = 30) + # tMax == tMin + let d1 = computeTReq("p", "m", tMin, tMin) + check d1 == tMin + # tMax < tMin (rangeMs < 0) + let d2 = computeTReq("p", "m", tMin, initDuration(seconds = 10)) + check d2 == tMin + + test "computeTResp returns 0 when tMax is 0": + let d = computeTResp("p", "other", "m", initDuration(milliseconds = 0)) + check d.inMilliseconds == 0 + + test "computeTResp always stays within [0, tMax)": + # Adversarial sweep — result must never wrap negative nor exceed tMax + let tMax = initDuration(seconds = 300) + for i in 0 ..< 500: + let d = computeTResp( + "participant-" & $i, "sender-" & $(i * 13), "msg-" & $(i * 31), tMax + ) + check: + d.inMilliseconds >= 0 + d.inMilliseconds < tMax.inMilliseconds + + test "isInResponseGroup returns true for non-positive numGroups": + check isInResponseGroup("p", "sender", "m", 0) == true + check isInResponseGroup("p", "sender", "m", -1) == true + + test "computeTReq bounds across many random inputs": + let tMin = initDuration(seconds = 30) + let tMax = initDuration(seconds = 300) + for i in 0 ..< 200: + let d = computeTReq("p-" & $i, "m-" & $i, tMin, tMax) + check: + d.inMilliseconds >= tMin.inMilliseconds + d.inMilliseconds < tMax.inMilliseconds + + test "response group distribution is roughly uniform": + # With numGroups=10, ~10% of random participants should share sender's group. + const numGroups = 10 + const totalParticipants = 1000 + let senderId = "alice" + let msgId = "msg-xyz" + var sameGroup = 0 + for i in 0 ..< totalParticipants: + if isInResponseGroup("participant-" & $i, senderId, msgId, numGroups): + sameGroup += 1 + # Expected ~100 (1/N), allow [50, 200] band for hash quirks + check: + sameGroup >= 50 + sameGroup <= 200 + + test "computeTResp monotonicity: self always fastest": + # The original sender (distance=0) must always be first to respond. + let tMax = initDuration(seconds = 300) + let selfD = computeTResp("alice", "alice", "msg-xyz", tMax) + check selfD.inMilliseconds == 0 + for i in 0 ..< 50: + let other = computeTResp("other-" & $i, "alice", "msg-xyz", tMax) + check other.inMilliseconds >= selfD.inMilliseconds + +suite "SDS-R: Lifecycle and State": + test "empty participantId disables outgoing repair creation": + let rm = newReliabilityManager().get() # empty participantId + defer: rm.cleanup() + check rm.ensureChannel(testChannel).isOk() + + rm.setCallbacks( + proc(msgId: SdsMessageID, ch: SdsChannelID) {.gcsafe.} = discard, + proc(msgId: SdsMessageID, ch: SdsChannelID) {.gcsafe.} = discard, + proc(msgId: SdsMessageID, deps: seq[HistoryEntry], ch: SdsChannelID) {.gcsafe.} = discard, + ) + + let msg = SdsMessage( + messageId: "m2", + lamportTimestamp: 2, + causalHistory: @[HistoryEntry(messageId: "m1-missing", senderId: "alice")], + channelId: testChannel, + content: @[byte(2)], + bloomFilter: @[], + ) + discard rm.unwrapReceivedMessage(serializeMessage(msg).get()) + check rm.channels[testChannel].outgoingRepairBuffer.len == 0 + + test "empty senderId in incoming repair request is ignored": + let rm = newReliabilityManager(participantId = "bob").get() + defer: rm.cleanup() + check rm.ensureChannel(testChannel).isOk() + let channel = rm.channels[testChannel] + channel.messageCache["m-wanted"] = @[byte(99), 99, 99] + + rm.setCallbacks( + proc(msgId: SdsMessageID, ch: SdsChannelID) {.gcsafe.} = discard, + proc(msgId: SdsMessageID, ch: SdsChannelID) {.gcsafe.} = discard, + proc(msgId: SdsMessageID, deps: seq[HistoryEntry], ch: SdsChannelID) {.gcsafe.} = discard, + ) + + let msg = SdsMessage( + messageId: "req-msg", + lamportTimestamp: 5, + causalHistory: @[], + channelId: testChannel, + content: @[byte(1)], + bloomFilter: @[], + repairRequest: @[HistoryEntry(messageId: "m-wanted", senderId: "")], + ) + discard rm.unwrapReceivedMessage(serializeMessage(msg).get()) + check "m-wanted" notin channel.incomingRepairBuffer + + test "wrapOutgoingMessage caches bytes and records sender": + # Proves Bug 1 is fixed — the original sender can serve her own message. + let rm = newReliabilityManager(participantId = "alice").get() + defer: rm.cleanup() + check rm.ensureChannel(testChannel).isOk() + + discard rm.wrapOutgoingMessage(@[byte(1), 2, 3], "m1", testChannel) + let channel = rm.channels[testChannel] + check: + "m1" in channel.messageCache + channel.messageCache["m1"].len > 0 + "m1" in channel.messageSenders + channel.messageSenders["m1"] == "alice" + + test "getRecentHistoryEntries carries senderId for own messages": + let rm = newReliabilityManager(participantId = "alice").get() + defer: rm.cleanup() + check rm.ensureChannel(testChannel).isOk() + + discard rm.wrapOutgoingMessage(@[byte(1)], "m1", testChannel) + discard rm.wrapOutgoingMessage(@[byte(2)], "m2", testChannel) + let entries = rm.getRecentHistoryEntries(10, testChannel) + check: + entries.len == 2 + entries[0].senderId == "alice" + entries[1].senderId == "alice" + + test "resetReliabilityManager clears all SDS-R state": + let rm = newReliabilityManager(participantId = "alice").get() + defer: rm.cleanup() + check rm.ensureChannel(testChannel).isOk() + let channel = rm.channels[testChannel] + + channel.outgoingRepairBuffer["a"] = OutgoingRepairEntry( + entry: HistoryEntry(messageId: "a", senderId: "x"), + tReq: getTime(), + ) + channel.incomingRepairBuffer["b"] = IncomingRepairEntry( + entry: HistoryEntry(messageId: "b", senderId: "y"), + cachedMessage: @[byte(1)], + tResp: getTime(), + ) + channel.messageCache["c"] = @[byte(2)] + channel.messageSenders["c"] = "someone" + + check rm.resetReliabilityManager().isOk() + check rm.ensureChannel(testChannel).isOk() + let ch2 = rm.channels[testChannel] + check: + ch2.outgoingRepairBuffer.len == 0 + ch2.incomingRepairBuffer.len == 0 + ch2.messageCache.len == 0 + ch2.messageSenders.len == 0 + + test "SDS-R state is isolated per channel": + let rm = newReliabilityManager(participantId = "alice").get() + defer: rm.cleanup() + check rm.ensureChannel("ch-A").isOk() + check rm.ensureChannel("ch-B").isOk() + + rm.setCallbacks( + proc(msgId: SdsMessageID, ch: SdsChannelID) {.gcsafe.} = discard, + proc(msgId: SdsMessageID, ch: SdsChannelID) {.gcsafe.} = discard, + proc(msgId: SdsMessageID, deps: seq[HistoryEntry], ch: SdsChannelID) {.gcsafe.} = discard, + ) + + let msg = SdsMessage( + messageId: "m2", + lamportTimestamp: 2, + causalHistory: @[HistoryEntry(messageId: "m1-missing", senderId: "bob")], + channelId: "ch-A", + content: @[byte(2)], + bloomFilter: @[], + ) + discard rm.unwrapReceivedMessage(serializeMessage(msg).get()) + check: + rm.channels["ch-A"].outgoingRepairBuffer.len == 1 + rm.channels["ch-B"].outgoingRepairBuffer.len == 0 + + test "duplicate message arrival cancels pending incoming repair entry": + # Covers the dedup-before-cleanup fix: a rebroadcast arriving at a peer who + # already has the message must clear that peer's incomingRepairBuffer entry. + let rm = newReliabilityManager(participantId = "carol").get() + defer: rm.cleanup() + check rm.ensureChannel(testChannel).isOk() + let channel = rm.channels[testChannel] + + rm.setCallbacks( + proc(msgId: SdsMessageID, ch: SdsChannelID) {.gcsafe.} = discard, + proc(msgId: SdsMessageID, ch: SdsChannelID) {.gcsafe.} = discard, + proc(msgId: SdsMessageID, deps: seq[HistoryEntry], ch: SdsChannelID) {.gcsafe.} = discard, + ) + + # Carol already has M1 in history and has a pending incomingRepairBuffer entry + channel.messageHistory.add("m1") + channel.incomingRepairBuffer["m1"] = IncomingRepairEntry( + entry: HistoryEntry(messageId: "m1", senderId: "alice"), + cachedMessage: @[byte(1)], + tResp: getTime() + initDuration(seconds = 10), + ) + + # A rebroadcast of M1 arrives + let msg = SdsMessage( + messageId: "m1", + lamportTimestamp: 1, + causalHistory: @[], + channelId: testChannel, + content: @[byte(1)], + bloomFilter: @[], + senderId: "alice", + ) + discard rm.unwrapReceivedMessage(serializeMessage(msg).get()) + check "m1" notin channel.incomingRepairBuffer + +suite "SDS-R: Repair Sweep": + var rm: ReliabilityManager + + setup: + rm = newReliabilityManager(participantId = "bob").get() + check rm.ensureChannel(testChannel).isOk() + + teardown: + if not rm.isNil: + rm.cleanup() + + test "runRepairSweep fires onRepairReady for expired tResp": + var fireCount = 0 + var firstBytes: seq[byte] = @[] + rm.setCallbacks( + proc(msgId: SdsMessageID, ch: SdsChannelID) {.gcsafe.} = discard, + proc(msgId: SdsMessageID, ch: SdsChannelID) {.gcsafe.} = discard, + proc(msgId: SdsMessageID, deps: seq[HistoryEntry], ch: SdsChannelID) {.gcsafe.} = discard, + onRepairReady = proc(bytes: seq[byte], ch: SdsChannelID) {.gcsafe.} = + {.cast(gcsafe).}: + fireCount += 1 + if fireCount == 1: + firstBytes = bytes, + ) + + let channel = rm.channels[testChannel] + channel.incomingRepairBuffer["m-ready"] = IncomingRepairEntry( + entry: HistoryEntry(messageId: "m-ready", senderId: "alice"), + cachedMessage: @[byte(1), 2, 3], + tResp: getTime() - initDuration(seconds = 1), # expired + ) + channel.incomingRepairBuffer["m-not-ready"] = IncomingRepairEntry( + entry: HistoryEntry(messageId: "m-not-ready", senderId: "alice"), + cachedMessage: @[byte(9), 9, 9], + tResp: getTime() + initDuration(minutes = 10), # far future + ) + + rm.runRepairSweep() + + check: + fireCount == 1 + firstBytes == @[byte(1), 2, 3] + "m-ready" notin channel.incomingRepairBuffer + "m-not-ready" in channel.incomingRepairBuffer + + test "runRepairSweep drops outgoing entries past T_max window": + rm.setCallbacks( + proc(msgId: SdsMessageID, ch: SdsChannelID) {.gcsafe.} = discard, + proc(msgId: SdsMessageID, ch: SdsChannelID) {.gcsafe.} = discard, + proc(msgId: SdsMessageID, deps: seq[HistoryEntry], ch: SdsChannelID) {.gcsafe.} = discard, + ) + + let channel = rm.channels[testChannel] + let tMax = rm.config.repairTMax + channel.outgoingRepairBuffer["m-stale"] = OutgoingRepairEntry( + entry: HistoryEntry(messageId: "m-stale", senderId: "alice"), + tReq: getTime() - (tMax + tMax), # now - 2*T_max, past drop window + ) + channel.outgoingRepairBuffer["m-fresh"] = OutgoingRepairEntry( + entry: HistoryEntry(messageId: "m-fresh", senderId: "alice"), + tReq: getTime(), + ) + + rm.runRepairSweep() + + check: + "m-stale" notin channel.outgoingRepairBuffer + "m-fresh" in channel.outgoingRepairBuffer + + test "runRepairSweep no-op when buffers are empty": + var fireCount = 0 + rm.setCallbacks( + proc(msgId: SdsMessageID, ch: SdsChannelID) {.gcsafe.} = discard, + proc(msgId: SdsMessageID, ch: SdsChannelID) {.gcsafe.} = discard, + proc(msgId: SdsMessageID, deps: seq[HistoryEntry], ch: SdsChannelID) {.gcsafe.} = discard, + onRepairReady = proc(bytes: seq[byte], ch: SdsChannelID) {.gcsafe.} = + fireCount += 1, + ) + rm.runRepairSweep() + check fireCount == 0 + +# --- Multi-participant in-process bus for integration tests --------------- + +type + TestBus = ref object + peers: OrderedTable[SdsParticipantID, ReliabilityManager] + delivered: Table[SdsParticipantID, seq[SdsMessageID]] + # Log of raw message-ids placed on the wire, tagged with the source peer. + wireLog: seq[tuple[senderId: SdsParticipantID, messageId: SdsMessageID]] + +proc newTestBus(): TestBus = + TestBus( + peers: initOrderedTable[SdsParticipantID, ReliabilityManager](), + delivered: initTable[SdsParticipantID, seq[SdsMessageID]](), + wireLog: @[], + ) + +proc recordWire(bus: TestBus, senderId: SdsParticipantID, bytes: seq[byte]) {.gcsafe.} = + let decoded = deserializeMessage(bytes) + if decoded.isOk(): + bus.wireLog.add((senderId, decoded.get().messageId)) + +proc deliverExcept( + bus: TestBus, + senderId: SdsParticipantID, + bytes: seq[byte], + exclude: seq[SdsParticipantID], +) {.gcsafe.} = + for pid, peer in bus.peers: + if pid == senderId or pid in exclude: + continue + discard peer.unwrapReceivedMessage(bytes) + +proc addPeer( + bus: TestBus, + participantId: SdsParticipantID, + config: ReliabilityConfig = defaultConfig(), +): ReliabilityManager = + let rm = newReliabilityManager(config, participantId).get() + doAssert rm.ensureChannel(testChannel).isOk() + bus.peers[participantId] = rm + bus.delivered[participantId] = @[] + + let pid = participantId + let busRef = bus + rm.setCallbacks( + proc(msgId: SdsMessageID, ch: SdsChannelID) {.gcsafe.} = + {.cast(gcsafe).}: + busRef.delivered[pid].add(msgId), + proc(msgId: SdsMessageID, ch: SdsChannelID) {.gcsafe.} = discard, + proc(msgId: SdsMessageID, deps: seq[HistoryEntry], ch: SdsChannelID) {.gcsafe.} = discard, + onRepairReady = proc(bytes: seq[byte], ch: SdsChannelID) {.gcsafe.} = + {.cast(gcsafe).}: + busRef.recordWire(pid, bytes) + busRef.deliverExcept(pid, bytes, @[]), + ) + rm + +proc broadcast( + bus: TestBus, + senderId: SdsParticipantID, + content: seq[byte], + messageId: SdsMessageID, + dropAt: seq[SdsParticipantID] = @[], +) = + let rm = bus.peers[senderId] + let wrapped = rm.wrapOutgoingMessage(content, messageId, testChannel) + doAssert wrapped.isOk() + bus.recordWire(senderId, wrapped.get()) + bus.deliverExcept(senderId, wrapped.get(), dropAt) + +proc forceOutgoingExpired( + rm: ReliabilityManager, messageId: SdsMessageID +) = + ## Push a specific outgoingRepairBuffer entry's tReq into the past so the + ## next wrapOutgoingMessage will pick it up. + let channel = rm.channels[testChannel] + if messageId in channel.outgoingRepairBuffer: + channel.outgoingRepairBuffer[messageId].tReq = + getTime() - initDuration(seconds = 1) + +proc forceIncomingExpired( + rm: ReliabilityManager, messageId: SdsMessageID +) = + ## Push an incomingRepairBuffer entry's tResp into the past so runRepairSweep fires it. + let channel = rm.channels[testChannel] + if messageId in channel.incomingRepairBuffer: + channel.incomingRepairBuffer[messageId].tResp = + getTime() - initDuration(seconds = 1) + +suite "SDS-R: Multi-Participant Integration": + + test "basic single-gap repair (Alice -> Bob misses -> Carol's message triggers repair)": + let bus = newTestBus() + let alice = bus.addPeer("alice") + let bob = bus.addPeer("bob") + let carol = bus.addPeer("carol") + + # Alice sends M1, but Bob is offline for this one. + bus.broadcast("alice", @[byte(1)], "m1", dropAt = @["bob".SdsParticipantID]) + # Carol now has M1; Bob does not. + check "m1" in carol.channels[testChannel].messageHistory + check "m1" notin bob.channels[testChannel].messageHistory + + # Carol sends M2 with causal history referencing M1. + bus.broadcast("carol", @[byte(2)], "m2") + # Bob detects M1 missing and populates his outgoingRepairBuffer. + check "m1" in bob.channels[testChannel].outgoingRepairBuffer + # Bob should have buffered M2. + check "m2" in bob.channels[testChannel].incomingBuffer + check "m2" notin bus.delivered["bob"] + + # Force Bob's T_req so the next wrap attaches the repair request. + bob.forceOutgoingExpired("m1") + + # Bob sends M3 — it must carry repair_request=[M1, sender=alice]. + bus.broadcast("bob", @[byte(3)], "m3") + + # Alice received M3, saw the repair_request, cached-bypass and response-group + # checks pass, so she has an incomingRepairBuffer entry for M1 with tResp=0. + check "m1" in alice.channels[testChannel].incomingRepairBuffer + + # Force alice's tResp to past just to be safe (it's already 0 for self), + # then run her sweep. She rebroadcasts M1. + alice.forceIncomingExpired("m1") + alice.runRepairSweep() + + # Bob now has M1 and M2 delivered. + check: + "m1" in bus.delivered["bob"] + "m2" in bus.delivered["bob"] + + test "response cancellation: only one rebroadcast on the wire": + let bus = newTestBus() + let alice = bus.addPeer("alice") + let bob = bus.addPeer("bob") + let carol = bus.addPeer("carol") + + # Alice sends M1, Bob offline. + bus.broadcast("alice", @[byte(1)], "m1", dropAt = @["bob".SdsParticipantID]) + # Carol sends M2; Bob sees M1 missing. + bus.broadcast("carol", @[byte(2)], "m2") + check "m1" in bob.channels[testChannel].outgoingRepairBuffer + + # Bob requests repair. + bob.forceOutgoingExpired("m1") + bus.broadcast("bob", @[byte(3)], "m3") + + # Both Alice and Carol now have an incomingRepairBuffer entry for M1. + check: + "m1" in alice.channels[testChannel].incomingRepairBuffer + "m1" in carol.channels[testChannel].incomingRepairBuffer + + # Alice fires first (T_resp=0 for self). Her rebroadcast should cancel Carol's + # pending entry when Carol receives the rebroadcast. + alice.forceIncomingExpired("m1") + alice.runRepairSweep() + + # Carol's pending response must have been cleared by the dedup-path cleanup. + check "m1" notin carol.channels[testChannel].incomingRepairBuffer + + # Even if we now force-run Carol's sweep, nothing should fire. + let wireCountBefore = bus.wireLog.len + carol.runRepairSweep() + check bus.wireLog.len == wireCountBefore + + # Bob received exactly one rebroadcast of M1. + var m1RebroadcastCount = 0 + for entry in bus.wireLog: + if entry.messageId == "m1" and entry.senderId != "alice": + discard # only the original Alice->all broadcast had senderId="alice" + if entry.messageId == "m1": + m1RebroadcastCount += 1 + # Two "m1" entries total on wire: (1) Alice's original broadcast, (2) Alice's rebroadcast. + check m1RebroadcastCount == 2 + + test "cancellation on incoming repair request: peer drops its own pending request": + let bus = newTestBus() + let alice = bus.addPeer("alice") + let bob = bus.addPeer("bob") + let carol = bus.addPeer("carol") + + # Alice sends M1 — drop at both Bob and Carol, so both miss it. + bus.broadcast( + "alice", @[byte(1)], "m1", + dropAt = @["bob".SdsParticipantID, "carol".SdsParticipantID], + ) + # Alice sends M2 referencing M1 — both Bob and Carol see M1 missing. + bus.broadcast("alice", @[byte(2)], "m2") + check: + "m1" in bob.channels[testChannel].outgoingRepairBuffer + "m1" in carol.channels[testChannel].outgoingRepairBuffer + + # Bob's T_req fires first. He sends a repair request for M1. + bob.forceOutgoingExpired("m1") + bus.broadcast("bob", @[byte(3)], "m3") + + # Carol, on receiving Bob's repair request, must have dropped her own + # pending outgoingRepairBuffer entry for M1 (cancellation). + check "m1" notin carol.channels[testChannel].outgoingRepairBuffer + + test "response group filtering: only group members respond": + # With numGroups=10, roughly 1/10 of receivers will be in the group. + # Construct a sender+message where a specific non-sender is NOT in the group. + var cfg = defaultConfig() + cfg.numResponseGroups = 10 + + # Pick a msgId where carol is not in the group and bob is + # We probe deterministically because computeTReq/isInResponseGroup are pure. + var chosenMsg = "" + for i in 0 ..< 1000: + let candidate = "probe-" & $i + let bobIn = isInResponseGroup("bob", "alice", candidate, 10) + let carolIn = isInResponseGroup("carol", "alice", candidate, 10) + if bobIn and not carolIn: + chosenMsg = candidate + break + check chosenMsg.len > 0 + + let bus = newTestBus() + discard bus.addPeer("alice", cfg) + let bob = bus.addPeer("bob", cfg) + let carol = bus.addPeer("carol", cfg) + + # Both Bob and Carol receive the original M1 (so both have it in messageCache). + bus.broadcast("alice", @[byte(1)], chosenMsg) + + # Now Dave arrives: build a fake requester message manually so its repair_request + # names Alice as senderId for chosenMsg. + # We inject directly by calling unwrapReceivedMessage on bob/carol. + let dave = bus.addPeer("dave", cfg) + # Dave has no messages, but we can hand-craft a repair request he would send. + let reqMsg = SdsMessage( + messageId: "req-from-dave", + lamportTimestamp: 10, + causalHistory: @[], + channelId: testChannel, + content: @[byte(9)], + bloomFilter: @[], + senderId: "dave", + repairRequest: @[HistoryEntry(messageId: chosenMsg, senderId: "alice")], + ) + let bytes = serializeMessage(reqMsg).get() + discard bob.unwrapReceivedMessage(bytes) + discard carol.unwrapReceivedMessage(bytes) + + check: + chosenMsg in bob.channels[testChannel].incomingRepairBuffer + chosenMsg notin carol.channels[testChannel].incomingRepairBuffer + + test "multi-gap batch repair: many missing deps split across requests": + let bus = newTestBus() + discard bus.addPeer("alice") + let bob = bus.addPeer("bob") + + # Alice sends 5 messages while Bob is offline. + let drops = @["bob".SdsParticipantID] + bus.broadcast("alice", @[byte(1)], "m1", dropAt = drops) + bus.broadcast("alice", @[byte(2)], "m2", dropAt = drops) + bus.broadcast("alice", @[byte(3)], "m3", dropAt = drops) + bus.broadcast("alice", @[byte(4)], "m4", dropAt = drops) + bus.broadcast("alice", @[byte(5)], "m5", dropAt = drops) + + # Bob comes online and receives M6 which depends on m1..m5. + bus.broadcast("alice", @[byte(6)], "m6") + + # Bob should have 5 outgoing repair entries. + let channel = bob.channels[testChannel] + check channel.outgoingRepairBuffer.len == 5 + + # Force all to expired and wrap one message — only maxRepairRequests + # (default 3) should attach to a single outgoing message. + for id in ["m1", "m2", "m3", "m4", "m5"]: + bob.forceOutgoingExpired(id) + + let wrapped = bob.wrapOutgoingMessage(@[byte(99)], "bob-msg-1", testChannel).get() + let decoded = deserializeMessage(wrapped).get() + check decoded.repairRequest.len <= bob.config.maxRepairRequests + + # The attached entries should be removed from the outgoing buffer. + check channel.outgoingRepairBuffer.len == 5 - decoded.repairRequest.len + + test "markDependenciesMet externally clears pending repair entry": + let bus = newTestBus() + discard bus.addPeer("alice") + let bob = bus.addPeer("bob") + + bus.broadcast("alice", @[byte(1)], "m1", dropAt = @["bob".SdsParticipantID]) + bus.broadcast("alice", @[byte(2)], "m2") + check "m1" in bob.channels[testChannel].outgoingRepairBuffer + + # Simulate Bob fetching M1 via an out-of-band store query. + check bob.markDependenciesMet(@["m1"], testChannel).isOk() + check "m1" notin bob.channels[testChannel].outgoingRepairBuffer