diff --git a/sds.nim b/sds.nim index d9c4b82..c4a1482 100644 --- a/sds.nim +++ b/sds.nim @@ -107,6 +107,7 @@ proc wrapOutgoingMessage*( channelId = channelId, content = message, bloomFilter = bfResult.get(), + senderId = rm.participantId, repairRequest = repairReqs, ) @@ -117,7 +118,16 @@ proc wrapOutgoingMessage*( channel.bloomFilter.add(msg.messageId) rm.addToHistory(msg.messageId, channelId) - return serializeMessage(msg) + # SDS-R: record sender for future causal-history entries + if rm.participantId.len > 0: + channel.messageSenders[msg.messageId] = rm.participantId + + let serialized = serializeMessage(msg) + if serialized.isOk(): + # SDS-R: cache serialized bytes so we can serve our own message on repair + if channel.messageCache.len < rm.config.maxMessageHistory: + channel.messageCache[msg.messageId] = serialized.get() + return serialized except Exception: error "Failed to wrap message", channelId = channelId, msg = getCurrentExceptionMsg() @@ -177,6 +187,11 @@ proc unwrapReceivedMessage*( let channel = rm.getOrCreateChannel(channelId) + # SDS-R: opportunistic repair-buffer cleanup — applies to duplicates too, + # so rebroadcasts cancel redundant responses on peers that already have the message. + channel.outgoingRepairBuffer.del(msg.messageId) + channel.incomingRepairBuffer.del(msg.messageId) + if msg.messageId in channel.messageHistory: return ok((msg.content, @[], channelId)) @@ -185,14 +200,14 @@ proc unwrapReceivedMessage*( rm.updateLamportTimestamp(msg.lamportTimestamp, channelId) rm.reviewAckStatus(msg) - # SDS-R: remove this message from repair buffers (dependency now met) - channel.outgoingRepairBuffer.del(msg.messageId) - channel.incomingRepairBuffer.del(msg.messageId) - # SDS-R: cache the raw message for potential repair responses if channel.messageCache.len < rm.config.maxMessageHistory: channel.messageCache[msg.messageId] = message + # SDS-R: record sender so our future causal-history entries carry it + if msg.senderId.len > 0: + channel.messageSenders[msg.messageId] = msg.senderId + # SDS-R: process incoming repair requests from this message let now = getTime() for repairEntry in msg.repairRequest: @@ -228,6 +243,10 @@ proc unwrapReceivedMessage*( IncomingMessage.init(message = msg, missingDeps = initHashSet[SdsMessageID]()) else: rm.addToHistory(msg.messageId, channelId) + # Unblock any buffered messages that were waiting on this one. + for pendingId, entry in channel.incomingBuffer: + if msg.messageId in entry.missingDeps: + channel.incomingBuffer[pendingId].missingDeps.excl(msg.messageId) rm.processIncomingBuffer(channelId) if not rm.onMessageReady.isNil(): rm.onMessageReady(msg.messageId, channelId) @@ -361,43 +380,47 @@ proc periodicSyncMessage( error "Error in periodic sync", msg = getCurrentExceptionMsg() await sleepAsync(chronos.seconds(rm.config.syncMessageInterval.inSeconds)) +proc runRepairSweep*(rm: ReliabilityManager) {.gcsafe, raises: [].} = + ## SDS-R: Runs a single pass of the repair sweep. + ## - Incoming: fires onRepairReady for expired T_resp entries and removes them + ## - Outgoing: drops entries past T_max window + ## Exposed so it can be driven directly in tests; also invoked by periodicRepairSweep. + try: + let now = getTime() + for channelId, channel in rm.channels: + try: + # Check incoming repair buffer for expired T_resp (time to rebroadcast) + var toRebroadcast: seq[SdsMessageID] = @[] + for msgId, entry in channel.incomingRepairBuffer: + if now >= entry.tResp: + toRebroadcast.add(msgId) + + for msgId in toRebroadcast: + let entry = channel.incomingRepairBuffer[msgId] + channel.incomingRepairBuffer.del(msgId) + if not rm.onRepairReady.isNil(): + rm.onRepairReady(entry.cachedMessage, channelId) + + # Drop expired outgoing repair entries past T_max + var toRemove: seq[SdsMessageID] = @[] + let tMaxDuration = rm.config.repairTMax + for msgId, entry in channel.outgoingRepairBuffer: + if now - entry.tReq > tMaxDuration: + toRemove.add(msgId) + for msgId in toRemove: + channel.outgoingRepairBuffer.del(msgId) + except Exception: + error "Error in repair sweep for channel", + channelId = channelId, msg = getCurrentExceptionMsg() + except Exception: + error "Error in repair sweep", msg = getCurrentExceptionMsg() + proc periodicRepairSweep( rm: ReliabilityManager ) {.async: (raises: [CancelledError]), gcsafe.} = ## SDS-R: Periodically checks repair buffers for expired entries. - ## - Incoming: fires onRepairReady for expired T_resp entries - ## - Outgoing: drops entries past T_max while true: - try: - let now = getTime() - for channelId, channel in rm.channels: - try: - # Check incoming repair buffer for expired T_resp (time to rebroadcast) - var toRebroadcast: seq[SdsMessageID] = @[] - for msgId, entry in channel.incomingRepairBuffer: - if now >= entry.tResp: - toRebroadcast.add(msgId) - - for msgId in toRebroadcast: - let entry = channel.incomingRepairBuffer[msgId] - channel.incomingRepairBuffer.del(msgId) - if not rm.onRepairReady.isNil(): - rm.onRepairReady(entry.cachedMessage, channelId) - - # Drop expired outgoing repair entries past T_max - var toRemove: seq[SdsMessageID] = @[] - let tMaxDuration = rm.config.repairTMax - for msgId, entry in channel.outgoingRepairBuffer: - if now - entry.tReq > tMaxDuration: - toRemove.add(msgId) - for msgId in toRemove: - channel.outgoingRepairBuffer.del(msgId) - except Exception: - error "Error in repair sweep for channel", - channelId = channelId, msg = getCurrentExceptionMsg() - except Exception: - error "Error in periodic repair sweep", msg = getCurrentExceptionMsg() - + rm.runRepairSweep() await sleepAsync(chronos.milliseconds(rm.config.repairSweepInterval.inMilliseconds)) proc startPeriodicTasks*(rm: ReliabilityManager) = @@ -418,6 +441,7 @@ proc resetReliabilityManager*(rm: ReliabilityManager): Result[void, ReliabilityE channel.outgoingRepairBuffer.clear() channel.incomingRepairBuffer.clear() channel.messageCache.clear() + channel.messageSenders.clear() channel.bloomFilter = RollingBloomFilter.init(rm.config.bloomFilterCapacity, rm.config.bloomFilterErrorRate) rm.channels.clear() diff --git a/sds/protobuf.nim b/sds/protobuf.nim index 63830c7..24a95d7 100644 --- a/sds/protobuf.nim +++ b/sds/protobuf.nim @@ -37,6 +37,9 @@ proc encode*(msg: SdsMessage): ProtoBuffer = pb.write(5, msg.content) pb.write(6, msg.bloomFilter) + if msg.senderId.len > 0: + pb.write(7, msg.senderId) + for entry in msg.repairRequest: let entryPb = encodeHistoryEntry(entry) pb.write(13, entryPb.buffer) @@ -81,6 +84,9 @@ proc decode*(T: type SdsMessage, buffer: seq[byte]): ProtobufResult[T] = if not ?pb.getField(6, msg.bloomFilter): msg.bloomFilter = @[] # Empty if not present + # SDS-R: decode senderId (field 7, optional) + discard pb.getField(7, msg.senderId) + # SDS-R: decode repair request (field 13, optional) var repairBuffers: seq[seq[byte]] if pb.getRepeatedField(13, repairBuffers).isOk(): diff --git a/sds/sds_utils.nim b/sds/sds_utils.nim index f979b3e..7e4eba6 100644 --- a/sds/sds_utils.nim +++ b/sds/sds_utils.nim @@ -25,6 +25,7 @@ proc cleanup*(rm: ReliabilityManager) {.raises: [].} = channel.outgoingRepairBuffer.clear() channel.incomingRepairBuffer.clear() channel.messageCache.clear() + channel.messageSenders.clear() rm.channels.clear() except Exception: error "Error during cleanup", error = getCurrentExceptionMsg() @@ -128,18 +129,21 @@ proc isInResponseGroup*( proc getRecentHistoryEntries*( rm: ReliabilityManager, n: int, channelId: SdsChannelID ): seq[HistoryEntry] = + ## Get recent history entries for sending in causal history. + ## Populates retrieval hints and senderId (SDS-R) for each entry. try: if channelId in rm.channels: let channel = rm.channels[channelId] let recentMessageIds = channel.messageHistory[max(0, channel.messageHistory.len - n) .. ^1] - if rm.onRetrievalHint.isNil(): - return toCausalHistory(recentMessageIds) - else: - var entries: seq[HistoryEntry] = @[] - for msgId in recentMessageIds: - let hint = rm.onRetrievalHint(msgId) - entries.add(newHistoryEntry(msgId, hint)) - return entries + var entries: seq[HistoryEntry] = @[] + for msgId in recentMessageIds: + var entry = HistoryEntry(messageId: msgId) + if not rm.onRetrievalHint.isNil(): + entry.retrievalHint = rm.onRetrievalHint(msgId) + if msgId in channel.messageSenders: + entry.senderId = channel.messageSenders[msgId] + entries.add(entry) + return entries else: return @[] except Exception: @@ -246,6 +250,7 @@ proc removeChannel*( channel.outgoingRepairBuffer.clear() channel.incomingRepairBuffer.clear() channel.messageCache.clear() + channel.messageSenders.clear() rm.channels.del(channelId) return ok() except Exception: diff --git a/sds/types/channel_context.nim b/sds/types/channel_context.nim index cec11dc..2f61584 100644 --- a/sds/types/channel_context.nim +++ b/sds/types/channel_context.nim @@ -19,6 +19,8 @@ type ChannelContext* = ref object incomingRepairBuffer*: Table[SdsMessageID, IncomingRepairEntry] messageCache*: Table[SdsMessageID, seq[byte]] ## Cached serialized messages for repair responses + messageSenders*: Table[SdsMessageID, SdsParticipantID] + ## SDS-R: msgId -> original sender, used to populate causal-history senderId proc new*(T: type ChannelContext, bloomFilter: RollingBloomFilter): T = return T( @@ -30,4 +32,5 @@ proc new*(T: type ChannelContext, bloomFilter: RollingBloomFilter): T = outgoingRepairBuffer: initTable[SdsMessageID, OutgoingRepairEntry](), incomingRepairBuffer: initTable[SdsMessageID, IncomingRepairEntry](), messageCache: initTable[SdsMessageID, seq[byte]](), + messageSenders: initTable[SdsMessageID, SdsParticipantID](), ) diff --git a/sds/types/sds_message.nim b/sds/types/sds_message.nim index b50380c..6ab7a4f 100644 --- a/sds/types/sds_message.nim +++ b/sds/types/sds_message.nim @@ -9,6 +9,7 @@ type SdsMessage* = object channelId*: SdsChannelID content*: seq[byte] bloomFilter*: seq[byte] + senderId*: SdsParticipantID ## SDS-R: original sender's participant ID repairRequest*: seq[HistoryEntry] ## Capped list of missing entries requesting repair (SDS-R) @@ -20,6 +21,7 @@ proc init*( channelId: SdsChannelID, content: seq[byte], bloomFilter: seq[byte], + senderId: SdsParticipantID = "", repairRequest: seq[HistoryEntry] = @[], ): T = return T( @@ -29,5 +31,6 @@ proc init*( channelId: channelId, content: content, bloomFilter: bloomFilter, + senderId: senderId, repairRequest: repairRequest, ) diff --git a/tests/test_bloom b/tests/test_bloom index cf17bfa..f4776c4 100755 Binary files a/tests/test_bloom and b/tests/test_bloom differ diff --git a/tests/test_reliability.nim b/tests/test_reliability.nim index aa0eb06..7f738c5 100644 --- a/tests/test_reliability.nim +++ b/tests/test_reliability.nim @@ -1015,3 +1015,619 @@ suite "SDS-R: Protobuf Roundtrip": check: decoded.repairRequest.len == 0 decoded.causalHistory[0].senderId == "" + + test "SdsMessage.senderId roundtrips through protobuf": + let msg = SdsMessage( + messageId: "m1", + lamportTimestamp: 1, + causalHistory: @[], + channelId: "ch1", + content: @[byte(1)], + bloomFilter: @[], + senderId: "alice", + ) + let decoded = deserializeMessage(serializeMessage(msg).get()).get() + check decoded.senderId == "alice" + +# --------------------------------------------------------------------------- +# SDS-R Phase 2 tests: edge cases, lifecycle, sweep, and multi-participant flows +# --------------------------------------------------------------------------- + +suite "SDS-R: Edge Cases and Defensive Branches": + test "computeTReq returns tMin when range is degenerate": + let tMin = initDuration(seconds = 30) + # tMax == tMin + let d1 = computeTReq("p", "m", tMin, tMin) + check d1 == tMin + # tMax < tMin (rangeMs < 0) + let d2 = computeTReq("p", "m", tMin, initDuration(seconds = 10)) + check d2 == tMin + + test "computeTResp returns 0 when tMax is 0": + let d = computeTResp("p", "other", "m", initDuration(milliseconds = 0)) + check d.inMilliseconds == 0 + + test "computeTResp always stays within [0, tMax)": + # Adversarial sweep — result must never wrap negative nor exceed tMax + let tMax = initDuration(seconds = 300) + for i in 0 ..< 500: + let d = computeTResp( + "participant-" & $i, "sender-" & $(i * 13), "msg-" & $(i * 31), tMax + ) + check: + d.inMilliseconds >= 0 + d.inMilliseconds < tMax.inMilliseconds + + test "isInResponseGroup returns true for non-positive numGroups": + check isInResponseGroup("p", "sender", "m", 0) == true + check isInResponseGroup("p", "sender", "m", -1) == true + + test "computeTReq bounds across many random inputs": + let tMin = initDuration(seconds = 30) + let tMax = initDuration(seconds = 300) + for i in 0 ..< 200: + let d = computeTReq("p-" & $i, "m-" & $i, tMin, tMax) + check: + d.inMilliseconds >= tMin.inMilliseconds + d.inMilliseconds < tMax.inMilliseconds + + test "response group distribution is roughly uniform": + # With numGroups=10, ~10% of random participants should share sender's group. + const numGroups = 10 + const totalParticipants = 1000 + let senderId = "alice" + let msgId = "msg-xyz" + var sameGroup = 0 + for i in 0 ..< totalParticipants: + if isInResponseGroup("participant-" & $i, senderId, msgId, numGroups): + sameGroup += 1 + # Expected ~100 (1/N), allow [50, 200] band for hash quirks + check: + sameGroup >= 50 + sameGroup <= 200 + + test "computeTResp monotonicity: self always fastest": + # The original sender (distance=0) must always be first to respond. + let tMax = initDuration(seconds = 300) + let selfD = computeTResp("alice", "alice", "msg-xyz", tMax) + check selfD.inMilliseconds == 0 + for i in 0 ..< 50: + let other = computeTResp("other-" & $i, "alice", "msg-xyz", tMax) + check other.inMilliseconds >= selfD.inMilliseconds + +suite "SDS-R: Lifecycle and State": + test "empty participantId disables outgoing repair creation": + let rm = newReliabilityManager().get() # empty participantId + defer: rm.cleanup() + check rm.ensureChannel(testChannel).isOk() + + rm.setCallbacks( + proc(msgId: SdsMessageID, ch: SdsChannelID) {.gcsafe.} = discard, + proc(msgId: SdsMessageID, ch: SdsChannelID) {.gcsafe.} = discard, + proc(msgId: SdsMessageID, deps: seq[HistoryEntry], ch: SdsChannelID) {.gcsafe.} = discard, + ) + + let msg = SdsMessage( + messageId: "m2", + lamportTimestamp: 2, + causalHistory: @[HistoryEntry(messageId: "m1-missing", senderId: "alice")], + channelId: testChannel, + content: @[byte(2)], + bloomFilter: @[], + ) + discard rm.unwrapReceivedMessage(serializeMessage(msg).get()) + check rm.channels[testChannel].outgoingRepairBuffer.len == 0 + + test "empty senderId in incoming repair request is ignored": + let rm = newReliabilityManager(participantId = "bob").get() + defer: rm.cleanup() + check rm.ensureChannel(testChannel).isOk() + let channel = rm.channels[testChannel] + channel.messageCache["m-wanted"] = @[byte(99), 99, 99] + + rm.setCallbacks( + proc(msgId: SdsMessageID, ch: SdsChannelID) {.gcsafe.} = discard, + proc(msgId: SdsMessageID, ch: SdsChannelID) {.gcsafe.} = discard, + proc(msgId: SdsMessageID, deps: seq[HistoryEntry], ch: SdsChannelID) {.gcsafe.} = discard, + ) + + let msg = SdsMessage( + messageId: "req-msg", + lamportTimestamp: 5, + causalHistory: @[], + channelId: testChannel, + content: @[byte(1)], + bloomFilter: @[], + repairRequest: @[HistoryEntry(messageId: "m-wanted", senderId: "")], + ) + discard rm.unwrapReceivedMessage(serializeMessage(msg).get()) + check "m-wanted" notin channel.incomingRepairBuffer + + test "wrapOutgoingMessage caches bytes and records sender": + # Proves Bug 1 is fixed — the original sender can serve her own message. + let rm = newReliabilityManager(participantId = "alice").get() + defer: rm.cleanup() + check rm.ensureChannel(testChannel).isOk() + + discard rm.wrapOutgoingMessage(@[byte(1), 2, 3], "m1", testChannel) + let channel = rm.channels[testChannel] + check: + "m1" in channel.messageCache + channel.messageCache["m1"].len > 0 + "m1" in channel.messageSenders + channel.messageSenders["m1"] == "alice" + + test "getRecentHistoryEntries carries senderId for own messages": + let rm = newReliabilityManager(participantId = "alice").get() + defer: rm.cleanup() + check rm.ensureChannel(testChannel).isOk() + + discard rm.wrapOutgoingMessage(@[byte(1)], "m1", testChannel) + discard rm.wrapOutgoingMessage(@[byte(2)], "m2", testChannel) + let entries = rm.getRecentHistoryEntries(10, testChannel) + check: + entries.len == 2 + entries[0].senderId == "alice" + entries[1].senderId == "alice" + + test "resetReliabilityManager clears all SDS-R state": + let rm = newReliabilityManager(participantId = "alice").get() + defer: rm.cleanup() + check rm.ensureChannel(testChannel).isOk() + let channel = rm.channels[testChannel] + + channel.outgoingRepairBuffer["a"] = OutgoingRepairEntry( + entry: HistoryEntry(messageId: "a", senderId: "x"), + tReq: getTime(), + ) + channel.incomingRepairBuffer["b"] = IncomingRepairEntry( + entry: HistoryEntry(messageId: "b", senderId: "y"), + cachedMessage: @[byte(1)], + tResp: getTime(), + ) + channel.messageCache["c"] = @[byte(2)] + channel.messageSenders["c"] = "someone" + + check rm.resetReliabilityManager().isOk() + check rm.ensureChannel(testChannel).isOk() + let ch2 = rm.channels[testChannel] + check: + ch2.outgoingRepairBuffer.len == 0 + ch2.incomingRepairBuffer.len == 0 + ch2.messageCache.len == 0 + ch2.messageSenders.len == 0 + + test "SDS-R state is isolated per channel": + let rm = newReliabilityManager(participantId = "alice").get() + defer: rm.cleanup() + check rm.ensureChannel("ch-A").isOk() + check rm.ensureChannel("ch-B").isOk() + + rm.setCallbacks( + proc(msgId: SdsMessageID, ch: SdsChannelID) {.gcsafe.} = discard, + proc(msgId: SdsMessageID, ch: SdsChannelID) {.gcsafe.} = discard, + proc(msgId: SdsMessageID, deps: seq[HistoryEntry], ch: SdsChannelID) {.gcsafe.} = discard, + ) + + let msg = SdsMessage( + messageId: "m2", + lamportTimestamp: 2, + causalHistory: @[HistoryEntry(messageId: "m1-missing", senderId: "bob")], + channelId: "ch-A", + content: @[byte(2)], + bloomFilter: @[], + ) + discard rm.unwrapReceivedMessage(serializeMessage(msg).get()) + check: + rm.channels["ch-A"].outgoingRepairBuffer.len == 1 + rm.channels["ch-B"].outgoingRepairBuffer.len == 0 + + test "duplicate message arrival cancels pending incoming repair entry": + # Covers the dedup-before-cleanup fix: a rebroadcast arriving at a peer who + # already has the message must clear that peer's incomingRepairBuffer entry. + let rm = newReliabilityManager(participantId = "carol").get() + defer: rm.cleanup() + check rm.ensureChannel(testChannel).isOk() + let channel = rm.channels[testChannel] + + rm.setCallbacks( + proc(msgId: SdsMessageID, ch: SdsChannelID) {.gcsafe.} = discard, + proc(msgId: SdsMessageID, ch: SdsChannelID) {.gcsafe.} = discard, + proc(msgId: SdsMessageID, deps: seq[HistoryEntry], ch: SdsChannelID) {.gcsafe.} = discard, + ) + + # Carol already has M1 in history and has a pending incomingRepairBuffer entry + channel.messageHistory.add("m1") + channel.incomingRepairBuffer["m1"] = IncomingRepairEntry( + entry: HistoryEntry(messageId: "m1", senderId: "alice"), + cachedMessage: @[byte(1)], + tResp: getTime() + initDuration(seconds = 10), + ) + + # A rebroadcast of M1 arrives + let msg = SdsMessage( + messageId: "m1", + lamportTimestamp: 1, + causalHistory: @[], + channelId: testChannel, + content: @[byte(1)], + bloomFilter: @[], + senderId: "alice", + ) + discard rm.unwrapReceivedMessage(serializeMessage(msg).get()) + check "m1" notin channel.incomingRepairBuffer + +suite "SDS-R: Repair Sweep": + var rm: ReliabilityManager + + setup: + rm = newReliabilityManager(participantId = "bob").get() + check rm.ensureChannel(testChannel).isOk() + + teardown: + if not rm.isNil: + rm.cleanup() + + test "runRepairSweep fires onRepairReady for expired tResp": + var fireCount = 0 + var firstBytes: seq[byte] = @[] + rm.setCallbacks( + proc(msgId: SdsMessageID, ch: SdsChannelID) {.gcsafe.} = discard, + proc(msgId: SdsMessageID, ch: SdsChannelID) {.gcsafe.} = discard, + proc(msgId: SdsMessageID, deps: seq[HistoryEntry], ch: SdsChannelID) {.gcsafe.} = discard, + onRepairReady = proc(bytes: seq[byte], ch: SdsChannelID) {.gcsafe.} = + {.cast(gcsafe).}: + fireCount += 1 + if fireCount == 1: + firstBytes = bytes, + ) + + let channel = rm.channels[testChannel] + channel.incomingRepairBuffer["m-ready"] = IncomingRepairEntry( + entry: HistoryEntry(messageId: "m-ready", senderId: "alice"), + cachedMessage: @[byte(1), 2, 3], + tResp: getTime() - initDuration(seconds = 1), # expired + ) + channel.incomingRepairBuffer["m-not-ready"] = IncomingRepairEntry( + entry: HistoryEntry(messageId: "m-not-ready", senderId: "alice"), + cachedMessage: @[byte(9), 9, 9], + tResp: getTime() + initDuration(minutes = 10), # far future + ) + + rm.runRepairSweep() + + check: + fireCount == 1 + firstBytes == @[byte(1), 2, 3] + "m-ready" notin channel.incomingRepairBuffer + "m-not-ready" in channel.incomingRepairBuffer + + test "runRepairSweep drops outgoing entries past T_max window": + rm.setCallbacks( + proc(msgId: SdsMessageID, ch: SdsChannelID) {.gcsafe.} = discard, + proc(msgId: SdsMessageID, ch: SdsChannelID) {.gcsafe.} = discard, + proc(msgId: SdsMessageID, deps: seq[HistoryEntry], ch: SdsChannelID) {.gcsafe.} = discard, + ) + + let channel = rm.channels[testChannel] + let tMax = rm.config.repairTMax + channel.outgoingRepairBuffer["m-stale"] = OutgoingRepairEntry( + entry: HistoryEntry(messageId: "m-stale", senderId: "alice"), + tReq: getTime() - (tMax + tMax), # now - 2*T_max, past drop window + ) + channel.outgoingRepairBuffer["m-fresh"] = OutgoingRepairEntry( + entry: HistoryEntry(messageId: "m-fresh", senderId: "alice"), + tReq: getTime(), + ) + + rm.runRepairSweep() + + check: + "m-stale" notin channel.outgoingRepairBuffer + "m-fresh" in channel.outgoingRepairBuffer + + test "runRepairSweep no-op when buffers are empty": + var fireCount = 0 + rm.setCallbacks( + proc(msgId: SdsMessageID, ch: SdsChannelID) {.gcsafe.} = discard, + proc(msgId: SdsMessageID, ch: SdsChannelID) {.gcsafe.} = discard, + proc(msgId: SdsMessageID, deps: seq[HistoryEntry], ch: SdsChannelID) {.gcsafe.} = discard, + onRepairReady = proc(bytes: seq[byte], ch: SdsChannelID) {.gcsafe.} = + fireCount += 1, + ) + rm.runRepairSweep() + check fireCount == 0 + +# --- Multi-participant in-process bus for integration tests --------------- + +type + TestBus = ref object + peers: OrderedTable[SdsParticipantID, ReliabilityManager] + delivered: Table[SdsParticipantID, seq[SdsMessageID]] + # Log of raw message-ids placed on the wire, tagged with the source peer. + wireLog: seq[tuple[senderId: SdsParticipantID, messageId: SdsMessageID]] + +proc newTestBus(): TestBus = + TestBus( + peers: initOrderedTable[SdsParticipantID, ReliabilityManager](), + delivered: initTable[SdsParticipantID, seq[SdsMessageID]](), + wireLog: @[], + ) + +proc recordWire(bus: TestBus, senderId: SdsParticipantID, bytes: seq[byte]) {.gcsafe.} = + let decoded = deserializeMessage(bytes) + if decoded.isOk(): + bus.wireLog.add((senderId, decoded.get().messageId)) + +proc deliverExcept( + bus: TestBus, + senderId: SdsParticipantID, + bytes: seq[byte], + exclude: seq[SdsParticipantID], +) {.gcsafe.} = + for pid, peer in bus.peers: + if pid == senderId or pid in exclude: + continue + discard peer.unwrapReceivedMessage(bytes) + +proc addPeer( + bus: TestBus, + participantId: SdsParticipantID, + config: ReliabilityConfig = defaultConfig(), +): ReliabilityManager = + let rm = newReliabilityManager(config, participantId).get() + doAssert rm.ensureChannel(testChannel).isOk() + bus.peers[participantId] = rm + bus.delivered[participantId] = @[] + + let pid = participantId + let busRef = bus + rm.setCallbacks( + proc(msgId: SdsMessageID, ch: SdsChannelID) {.gcsafe.} = + {.cast(gcsafe).}: + busRef.delivered[pid].add(msgId), + proc(msgId: SdsMessageID, ch: SdsChannelID) {.gcsafe.} = discard, + proc(msgId: SdsMessageID, deps: seq[HistoryEntry], ch: SdsChannelID) {.gcsafe.} = discard, + onRepairReady = proc(bytes: seq[byte], ch: SdsChannelID) {.gcsafe.} = + {.cast(gcsafe).}: + busRef.recordWire(pid, bytes) + busRef.deliverExcept(pid, bytes, @[]), + ) + rm + +proc broadcast( + bus: TestBus, + senderId: SdsParticipantID, + content: seq[byte], + messageId: SdsMessageID, + dropAt: seq[SdsParticipantID] = @[], +) = + let rm = bus.peers[senderId] + let wrapped = rm.wrapOutgoingMessage(content, messageId, testChannel) + doAssert wrapped.isOk() + bus.recordWire(senderId, wrapped.get()) + bus.deliverExcept(senderId, wrapped.get(), dropAt) + +proc forceOutgoingExpired( + rm: ReliabilityManager, messageId: SdsMessageID +) = + ## Push a specific outgoingRepairBuffer entry's tReq into the past so the + ## next wrapOutgoingMessage will pick it up. + let channel = rm.channels[testChannel] + if messageId in channel.outgoingRepairBuffer: + channel.outgoingRepairBuffer[messageId].tReq = + getTime() - initDuration(seconds = 1) + +proc forceIncomingExpired( + rm: ReliabilityManager, messageId: SdsMessageID +) = + ## Push an incomingRepairBuffer entry's tResp into the past so runRepairSweep fires it. + let channel = rm.channels[testChannel] + if messageId in channel.incomingRepairBuffer: + channel.incomingRepairBuffer[messageId].tResp = + getTime() - initDuration(seconds = 1) + +suite "SDS-R: Multi-Participant Integration": + + test "basic single-gap repair (Alice -> Bob misses -> Carol's message triggers repair)": + let bus = newTestBus() + let alice = bus.addPeer("alice") + let bob = bus.addPeer("bob") + let carol = bus.addPeer("carol") + + # Alice sends M1, but Bob is offline for this one. + bus.broadcast("alice", @[byte(1)], "m1", dropAt = @["bob".SdsParticipantID]) + # Carol now has M1; Bob does not. + check "m1" in carol.channels[testChannel].messageHistory + check "m1" notin bob.channels[testChannel].messageHistory + + # Carol sends M2 with causal history referencing M1. + bus.broadcast("carol", @[byte(2)], "m2") + # Bob detects M1 missing and populates his outgoingRepairBuffer. + check "m1" in bob.channels[testChannel].outgoingRepairBuffer + # Bob should have buffered M2. + check "m2" in bob.channels[testChannel].incomingBuffer + check "m2" notin bus.delivered["bob"] + + # Force Bob's T_req so the next wrap attaches the repair request. + bob.forceOutgoingExpired("m1") + + # Bob sends M3 — it must carry repair_request=[M1, sender=alice]. + bus.broadcast("bob", @[byte(3)], "m3") + + # Alice received M3, saw the repair_request, cached-bypass and response-group + # checks pass, so she has an incomingRepairBuffer entry for M1 with tResp=0. + check "m1" in alice.channels[testChannel].incomingRepairBuffer + + # Force alice's tResp to past just to be safe (it's already 0 for self), + # then run her sweep. She rebroadcasts M1. + alice.forceIncomingExpired("m1") + alice.runRepairSweep() + + # Bob now has M1 and M2 delivered. + check: + "m1" in bus.delivered["bob"] + "m2" in bus.delivered["bob"] + + test "response cancellation: only one rebroadcast on the wire": + let bus = newTestBus() + let alice = bus.addPeer("alice") + let bob = bus.addPeer("bob") + let carol = bus.addPeer("carol") + + # Alice sends M1, Bob offline. + bus.broadcast("alice", @[byte(1)], "m1", dropAt = @["bob".SdsParticipantID]) + # Carol sends M2; Bob sees M1 missing. + bus.broadcast("carol", @[byte(2)], "m2") + check "m1" in bob.channels[testChannel].outgoingRepairBuffer + + # Bob requests repair. + bob.forceOutgoingExpired("m1") + bus.broadcast("bob", @[byte(3)], "m3") + + # Both Alice and Carol now have an incomingRepairBuffer entry for M1. + check: + "m1" in alice.channels[testChannel].incomingRepairBuffer + "m1" in carol.channels[testChannel].incomingRepairBuffer + + # Alice fires first (T_resp=0 for self). Her rebroadcast should cancel Carol's + # pending entry when Carol receives the rebroadcast. + alice.forceIncomingExpired("m1") + alice.runRepairSweep() + + # Carol's pending response must have been cleared by the dedup-path cleanup. + check "m1" notin carol.channels[testChannel].incomingRepairBuffer + + # Even if we now force-run Carol's sweep, nothing should fire. + let wireCountBefore = bus.wireLog.len + carol.runRepairSweep() + check bus.wireLog.len == wireCountBefore + + # Bob received exactly one rebroadcast of M1. + var m1RebroadcastCount = 0 + for entry in bus.wireLog: + if entry.messageId == "m1" and entry.senderId != "alice": + discard # only the original Alice->all broadcast had senderId="alice" + if entry.messageId == "m1": + m1RebroadcastCount += 1 + # Two "m1" entries total on wire: (1) Alice's original broadcast, (2) Alice's rebroadcast. + check m1RebroadcastCount == 2 + + test "cancellation on incoming repair request: peer drops its own pending request": + let bus = newTestBus() + let alice = bus.addPeer("alice") + let bob = bus.addPeer("bob") + let carol = bus.addPeer("carol") + + # Alice sends M1 — drop at both Bob and Carol, so both miss it. + bus.broadcast( + "alice", @[byte(1)], "m1", + dropAt = @["bob".SdsParticipantID, "carol".SdsParticipantID], + ) + # Alice sends M2 referencing M1 — both Bob and Carol see M1 missing. + bus.broadcast("alice", @[byte(2)], "m2") + check: + "m1" in bob.channels[testChannel].outgoingRepairBuffer + "m1" in carol.channels[testChannel].outgoingRepairBuffer + + # Bob's T_req fires first. He sends a repair request for M1. + bob.forceOutgoingExpired("m1") + bus.broadcast("bob", @[byte(3)], "m3") + + # Carol, on receiving Bob's repair request, must have dropped her own + # pending outgoingRepairBuffer entry for M1 (cancellation). + check "m1" notin carol.channels[testChannel].outgoingRepairBuffer + + test "response group filtering: only group members respond": + # With numGroups=10, roughly 1/10 of receivers will be in the group. + # Construct a sender+message where a specific non-sender is NOT in the group. + var cfg = defaultConfig() + cfg.numResponseGroups = 10 + + # Pick a msgId where carol is not in the group and bob is + # We probe deterministically because computeTReq/isInResponseGroup are pure. + var chosenMsg = "" + for i in 0 ..< 1000: + let candidate = "probe-" & $i + let bobIn = isInResponseGroup("bob", "alice", candidate, 10) + let carolIn = isInResponseGroup("carol", "alice", candidate, 10) + if bobIn and not carolIn: + chosenMsg = candidate + break + check chosenMsg.len > 0 + + let bus = newTestBus() + discard bus.addPeer("alice", cfg) + let bob = bus.addPeer("bob", cfg) + let carol = bus.addPeer("carol", cfg) + + # Both Bob and Carol receive the original M1 (so both have it in messageCache). + bus.broadcast("alice", @[byte(1)], chosenMsg) + + # Now Dave arrives: build a fake requester message manually so its repair_request + # names Alice as senderId for chosenMsg. + # We inject directly by calling unwrapReceivedMessage on bob/carol. + let dave = bus.addPeer("dave", cfg) + # Dave has no messages, but we can hand-craft a repair request he would send. + let reqMsg = SdsMessage( + messageId: "req-from-dave", + lamportTimestamp: 10, + causalHistory: @[], + channelId: testChannel, + content: @[byte(9)], + bloomFilter: @[], + senderId: "dave", + repairRequest: @[HistoryEntry(messageId: chosenMsg, senderId: "alice")], + ) + let bytes = serializeMessage(reqMsg).get() + discard bob.unwrapReceivedMessage(bytes) + discard carol.unwrapReceivedMessage(bytes) + + check: + chosenMsg in bob.channels[testChannel].incomingRepairBuffer + chosenMsg notin carol.channels[testChannel].incomingRepairBuffer + + test "multi-gap batch repair: many missing deps split across requests": + let bus = newTestBus() + discard bus.addPeer("alice") + let bob = bus.addPeer("bob") + + # Alice sends 5 messages while Bob is offline. + let drops = @["bob".SdsParticipantID] + bus.broadcast("alice", @[byte(1)], "m1", dropAt = drops) + bus.broadcast("alice", @[byte(2)], "m2", dropAt = drops) + bus.broadcast("alice", @[byte(3)], "m3", dropAt = drops) + bus.broadcast("alice", @[byte(4)], "m4", dropAt = drops) + bus.broadcast("alice", @[byte(5)], "m5", dropAt = drops) + + # Bob comes online and receives M6 which depends on m1..m5. + bus.broadcast("alice", @[byte(6)], "m6") + + # Bob should have 5 outgoing repair entries. + let channel = bob.channels[testChannel] + check channel.outgoingRepairBuffer.len == 5 + + # Force all to expired and wrap one message — only maxRepairRequests + # (default 3) should attach to a single outgoing message. + for id in ["m1", "m2", "m3", "m4", "m5"]: + bob.forceOutgoingExpired(id) + + let wrapped = bob.wrapOutgoingMessage(@[byte(99)], "bob-msg-1", testChannel).get() + let decoded = deserializeMessage(wrapped).get() + check decoded.repairRequest.len <= bob.config.maxRepairRequests + + # The attached entries should be removed from the outgoing buffer. + check channel.outgoingRepairBuffer.len == 5 - decoded.repairRequest.len + + test "markDependenciesMet externally clears pending repair entry": + let bus = newTestBus() + discard bus.addPeer("alice") + let bob = bus.addPeer("bob") + + bus.broadcast("alice", @[byte(1)], "m1", dropAt = @["bob".SdsParticipantID]) + bus.broadcast("alice", @[byte(2)], "m2") + check "m1" in bob.channels[testChannel].outgoingRepairBuffer + + # Simulate Bob fetching M1 via an out-of-band store query. + check bob.markDependenciesMet(@["m1"], testChannel).isOk() + check "m1" notin bob.channels[testChannel].outgoingRepairBuffer