From 2e00fb7c642d042e853c3acb16f8348a2a62a971 Mon Sep 17 00:00:00 2001 From: darshankabariya Date: Mon, 6 Apr 2026 12:54:22 +0530 Subject: [PATCH] feat: initial implementation of SDS-Repair --- library/libsds.nim | 8 +- .../requests/sds_lifecycle_request.nim | 2 +- sds.nim | 109 ++++++- sds/message.nim | 2 + sds/protobuf.nim | 43 ++- sds/sds_utils.nim | 68 ++++- sds/types.nim | 2 + sds/types/app_callbacks.nim | 3 + sds/types/callbacks.nim | 2 + sds/types/channel_context.nim | 13 +- sds/types/history_entry.nim | 10 +- sds/types/reliability_config.nim | 21 ++ sds/types/reliability_manager.nim | 9 +- sds/types/repair_entry.nim | 28 ++ sds/types/sds_message.nim | 6 +- sds/types/sds_message_id.nim | 1 + tests/test_reliability.nim | 274 ++++++++++++++++++ 17 files changed, 577 insertions(+), 24 deletions(-) create mode 100644 sds/types/repair_entry.nim diff --git a/library/libsds.nim b/library/libsds.nim index 4ae285f..af05857 100644 --- a/library/libsds.nim +++ b/library/libsds.nim @@ -16,7 +16,7 @@ import sds, ./events/[ json_message_ready_event, json_message_sent_event, json_missing_dependencies_event, - json_periodic_sync_event, + json_periodic_sync_event, json_repair_ready_event, ] ################################################################################ @@ -114,6 +114,11 @@ proc onPeriodicSync(ctx: ptr SdsContext): PeriodicSyncCallback = callEventCallback(ctx, "onPeriodicSync"): $JsonPeriodicSyncEvent.new() +proc onRepairReady(ctx: ptr SdsContext): RepairReadyCallback = + return proc(message: seq[byte], channelId: SdsChannelID) {.gcsafe.} = + callEventCallback(ctx, "onRepairReady"): + $JsonRepairReadyEvent.new(message, channelId) + proc onRetrievalHint(ctx: ptr SdsContext): RetrievalHintProvider = return proc(messageId: SdsMessageID): seq[byte] {.gcsafe.} = if isNil(ctx.retrievalHintProvider): @@ -196,6 +201,7 @@ proc SdsNewReliabilityManager( missingDependenciesCb: onMissingDependencies(ctx), periodicSyncCb: onPeriodicSync(ctx), retrievalHintProvider: onRetrievalHint(ctx), + repairReadyCb: onRepairReady(ctx), ) let retCode = handleRequest( diff --git a/library/sds_thread/inter_thread_communication/requests/sds_lifecycle_request.nim b/library/sds_thread/inter_thread_communication/requests/sds_lifecycle_request.nim index 8d2e9bc..a0f3adb 100644 --- a/library/sds_thread/inter_thread_communication/requests/sds_lifecycle_request.nim +++ b/library/sds_thread/inter_thread_communication/requests/sds_lifecycle_request.nim @@ -40,7 +40,7 @@ proc createReliabilityManager( rm.setCallbacks( appCallbacks.messageReadyCb, appCallbacks.messageSentCb, appCallbacks.missingDependenciesCb, appCallbacks.periodicSyncCb, - appCallbacks.retrievalHintProvider, + appCallbacks.retrievalHintProvider, appCallbacks.repairReadyCb, ) return ok(rm) diff --git a/sds.nim b/sds.nim index 58d1893..d9c4b82 100644 --- a/sds.nim +++ b/sds.nim @@ -5,11 +5,12 @@ import sds/[types, protobuf, sds_utils, rolling_bloom_filter] export types, protobuf, sds_utils, rolling_bloom_filter proc newReliabilityManager*( - config: ReliabilityConfig = defaultConfig() + config: ReliabilityConfig = defaultConfig(), + participantId: SdsParticipantID = "", ): Result[ReliabilityManager, ReliabilityError] = ## Creates a new multi-channel ReliabilityManager. try: - let rm = ReliabilityManager.new(config) + let rm = ReliabilityManager.new(config, participantId) return ok(rm) except Exception: error "Failed to create ReliabilityManager", msg = getCurrentExceptionMsg() @@ -88,6 +89,17 @@ proc wrapOutgoingMessage*( error "Failed to serialize bloom filter", channelId = channelId return err(ReliabilityError.reSerializationError) + # SDS-R: collect eligible expired repair requests to attach + var repairReqs: seq[HistoryEntry] = @[] + let now = getTime() + var expiredKeys: seq[SdsMessageID] = @[] + for msgId, repairEntry in channel.outgoingRepairBuffer: + if now >= repairEntry.tReq and repairReqs.len < rm.config.maxRepairRequests: + repairReqs.add(repairEntry.entry) + expiredKeys.add(msgId) + for key in expiredKeys: + channel.outgoingRepairBuffer.del(key) + let msg = SdsMessage.init( messageId = messageId, lamportTimestamp = channel.lamportTimestamp, @@ -95,6 +107,7 @@ proc wrapOutgoingMessage*( channelId = channelId, content = message, bloomFilter = bfResult.get(), + repairRequest = repairReqs, ) channel.outgoingBuffer.add( @@ -172,6 +185,36 @@ proc unwrapReceivedMessage*( rm.updateLamportTimestamp(msg.lamportTimestamp, channelId) rm.reviewAckStatus(msg) + # SDS-R: remove this message from repair buffers (dependency now met) + channel.outgoingRepairBuffer.del(msg.messageId) + channel.incomingRepairBuffer.del(msg.messageId) + + # SDS-R: cache the raw message for potential repair responses + if channel.messageCache.len < rm.config.maxMessageHistory: + channel.messageCache[msg.messageId] = message + + # SDS-R: process incoming repair requests from this message + let now = getTime() + for repairEntry in msg.repairRequest: + # Remove from our own outgoing repair buffer (someone else is also requesting) + channel.outgoingRepairBuffer.del(repairEntry.messageId) + # Check if we can respond to this repair request + if repairEntry.messageId in channel.messageCache and + rm.participantId.len > 0 and repairEntry.senderId.len > 0: + if isInResponseGroup( + rm.participantId, repairEntry.senderId, + repairEntry.messageId, rm.config.numResponseGroups + ): + let tResp = computeTResp( + rm.participantId, repairEntry.senderId, + repairEntry.messageId, rm.config.repairTMax + ) + channel.incomingRepairBuffer[repairEntry.messageId] = IncomingRepairEntry( + entry: repairEntry, + cachedMessage: channel.messageCache[repairEntry.messageId], + tResp: now + tResp, + ) + var missingDeps = rm.checkDependencies(msg.causalHistory, channelId) if missingDeps.len == 0: @@ -197,6 +240,19 @@ proc unwrapReceivedMessage*( if not rm.onMissingDependencies.isNil(): rm.onMissingDependencies(msg.messageId, missingDeps, channelId) + # SDS-R: add missing deps to outgoing repair buffer + if rm.participantId.len > 0: + for dep in missingDeps: + if dep.messageId notin channel.outgoingRepairBuffer: + let tReq = computeTReq( + rm.participantId, dep.messageId, + rm.config.repairTMin, rm.config.repairTMax + ) + channel.outgoingRepairBuffer[dep.messageId] = OutgoingRepairEntry( + entry: dep, + tReq: now + tReq, + ) + return ok((msg.content, missingDeps, channelId)) except Exception: error "Failed to unwrap message", msg = getCurrentExceptionMsg() @@ -220,6 +276,10 @@ proc markDependenciesMet*( if msgId in entry.missingDeps: channel.incomingBuffer[pendingId].missingDeps.excl(msgId) + # SDS-R: clear from repair buffers (dependency now met) + channel.outgoingRepairBuffer.del(msgId) + channel.incomingRepairBuffer.del(msgId) + rm.processIncomingBuffer(channelId) return ok() except Exception: @@ -234,6 +294,7 @@ proc setCallbacks*( onMissingDependencies: MissingDependenciesCallback, onPeriodicSync: PeriodicSyncCallback = nil, onRetrievalHint: RetrievalHintProvider = nil, + onRepairReady: RepairReadyCallback = nil, ) = ## Sets the callback functions for various events in the ReliabilityManager. withLock rm.lock: @@ -242,6 +303,7 @@ proc setCallbacks*( rm.onMissingDependencies = onMissingDependencies rm.onPeriodicSync = onPeriodicSync rm.onRetrievalHint = onRetrievalHint + rm.onRepairReady = onRepairReady proc checkUnacknowledgedMessages( rm: ReliabilityManager, channelId: SdsChannelID @@ -299,10 +361,50 @@ proc periodicSyncMessage( error "Error in periodic sync", msg = getCurrentExceptionMsg() await sleepAsync(chronos.seconds(rm.config.syncMessageInterval.inSeconds)) +proc periodicRepairSweep( + rm: ReliabilityManager +) {.async: (raises: [CancelledError]), gcsafe.} = + ## SDS-R: Periodically checks repair buffers for expired entries. + ## - Incoming: fires onRepairReady for expired T_resp entries + ## - Outgoing: drops entries past T_max + while true: + try: + let now = getTime() + for channelId, channel in rm.channels: + try: + # Check incoming repair buffer for expired T_resp (time to rebroadcast) + var toRebroadcast: seq[SdsMessageID] = @[] + for msgId, entry in channel.incomingRepairBuffer: + if now >= entry.tResp: + toRebroadcast.add(msgId) + + for msgId in toRebroadcast: + let entry = channel.incomingRepairBuffer[msgId] + channel.incomingRepairBuffer.del(msgId) + if not rm.onRepairReady.isNil(): + rm.onRepairReady(entry.cachedMessage, channelId) + + # Drop expired outgoing repair entries past T_max + var toRemove: seq[SdsMessageID] = @[] + let tMaxDuration = rm.config.repairTMax + for msgId, entry in channel.outgoingRepairBuffer: + if now - entry.tReq > tMaxDuration: + toRemove.add(msgId) + for msgId in toRemove: + channel.outgoingRepairBuffer.del(msgId) + except Exception: + error "Error in repair sweep for channel", + channelId = channelId, msg = getCurrentExceptionMsg() + except Exception: + error "Error in periodic repair sweep", msg = getCurrentExceptionMsg() + + await sleepAsync(chronos.milliseconds(rm.config.repairSweepInterval.inMilliseconds)) + proc startPeriodicTasks*(rm: ReliabilityManager) = ## Starts the periodic tasks for buffer sweeping and sync message sending. asyncSpawn rm.periodicBufferSweep() asyncSpawn rm.periodicSyncMessage() + asyncSpawn rm.periodicRepairSweep() proc resetReliabilityManager*(rm: ReliabilityManager): Result[void, ReliabilityError] = ## Resets the ReliabilityManager to its initial state. @@ -313,6 +415,9 @@ proc resetReliabilityManager*(rm: ReliabilityManager): Result[void, ReliabilityE channel.messageHistory.setLen(0) channel.outgoingBuffer.setLen(0) channel.incomingBuffer.clear() + channel.outgoingRepairBuffer.clear() + channel.incomingRepairBuffer.clear() + channel.messageCache.clear() channel.bloomFilter = RollingBloomFilter.init(rm.config.bloomFilterCapacity, rm.config.bloomFilterErrorRate) rm.channels.clear() diff --git a/sds/message.nim b/sds/message.nim index ddf5e5f..410fc43 100644 --- a/sds/message.nim +++ b/sds/message.nim @@ -3,6 +3,7 @@ import ./types/history_entry import ./types/sds_message import ./types/unacknowledged_message import ./types/incoming_message +import ./types/repair_entry import ./types/reliability_config export @@ -11,4 +12,5 @@ export sds_message, unacknowledged_message, incoming_message, + repair_entry, reliability_config diff --git a/sds/protobuf.nim b/sds/protobuf.nim index ba1b7ff..63830c7 100644 --- a/sds/protobuf.nim +++ b/sds/protobuf.nim @@ -5,6 +5,24 @@ import ./protobufutil import ./bloom import ./sds_utils +proc encodeHistoryEntry*(entry: HistoryEntry): ProtoBuffer = + var entryPb = initProtoBuffer() + entryPb.write(1, entry.messageId) + if entry.retrievalHint.len > 0: + entryPb.write(2, entry.retrievalHint) + if entry.senderId.len > 0: + entryPb.write(3, entry.senderId) + entryPb.finish() + entryPb + +proc decodeHistoryEntry*(entryPb: ProtoBuffer): ProtobufResult[HistoryEntry] = + var entry = HistoryEntry.init("") + if not ?entryPb.getField(1, entry.messageId): + return err(ProtobufError.missingRequiredField("HistoryEntry.messageId")) + discard entryPb.getField(2, entry.retrievalHint) + discard entryPb.getField(3, entry.senderId) + ok(entry) + proc encode*(msg: SdsMessage): ProtoBuffer = var pb = initProtoBuffer() @@ -12,16 +30,17 @@ proc encode*(msg: SdsMessage): ProtoBuffer = pb.write(2, uint64(msg.lamportTimestamp)) for entry in msg.causalHistory: - var entryPb = initProtoBuffer() - entryPb.write(1, entry.messageId) - if entry.retrievalHint.len > 0: - entryPb.write(2, entry.retrievalHint) - entryPb.finish() + let entryPb = encodeHistoryEntry(entry) pb.write(3, entryPb.buffer) pb.write(4, msg.channelId) pb.write(5, msg.content) pb.write(6, msg.bloomFilter) + + for entry in msg.repairRequest: + let entryPb = encodeHistoryEntry(entry) + pb.write(13, entryPb.buffer) + pb.finish() return pb @@ -44,11 +63,7 @@ proc decode*(T: type SdsMessage, buffer: seq[byte]): ProtobufResult[T] = # New format: repeated HistoryEntry for histBuffer in historyBuffers: let entryPb = initProtoBuffer(histBuffer) - var entry = HistoryEntry.init("") - if not ?entryPb.getField(1, entry.messageId): - return err(ProtobufError.missingRequiredField("HistoryEntry.messageId")) - # retrievalHint is optional - discard entryPb.getField(2, entry.retrievalHint) + let entry = ?decodeHistoryEntry(entryPb) msg.causalHistory.add(entry) else: # Try old format: repeated string @@ -66,6 +81,14 @@ proc decode*(T: type SdsMessage, buffer: seq[byte]): ProtobufResult[T] = if not ?pb.getField(6, msg.bloomFilter): msg.bloomFilter = @[] # Empty if not present + # SDS-R: decode repair request (field 13, optional) + var repairBuffers: seq[seq[byte]] + if pb.getRepeatedField(13, repairBuffers).isOk(): + for repairBuffer in repairBuffers: + let entryPb = initProtoBuffer(repairBuffer) + let entry = ?decodeHistoryEntry(entryPb) + msg.repairRequest.add(entry) + return ok(msg) proc extractChannelId*(data: seq[byte]): Result[SdsChannelID, ReliabilityError] = diff --git a/sds/sds_utils.nim b/sds/sds_utils.nim index f1a68ca..f979b3e 100644 --- a/sds/sds_utils.nim +++ b/sds/sds_utils.nim @@ -1,15 +1,15 @@ -import std/[locks, tables, sequtils] +import std/[times, locks, tables, sequtils, hashes] import chronicles, results import ./rolling_bloom_filter import ./types/[ sds_message_id, history_entry, sds_message, unacknowledged_message, incoming_message, - reliability_error, callbacks, app_callbacks, reliability_config, channel_context, - reliability_manager, + reliability_error, callbacks, app_callbacks, reliability_config, repair_entry, + channel_context, reliability_manager, ] export sds_message_id, history_entry, sds_message, unacknowledged_message, incoming_message, - reliability_error, callbacks, app_callbacks, reliability_config, channel_context, - reliability_manager + reliability_error, callbacks, app_callbacks, reliability_config, repair_entry, + channel_context, reliability_manager proc defaultConfig*(): ReliabilityConfig = return ReliabilityConfig.init() @@ -22,6 +22,9 @@ proc cleanup*(rm: ReliabilityManager) {.raises: [].} = channel.outgoingBuffer.setLen(0) channel.incomingBuffer.clear() channel.messageHistory.setLen(0) + channel.outgoingRepairBuffer.clear() + channel.incomingRepairBuffer.clear() + channel.messageCache.clear() rm.channels.clear() except Exception: error "Error during cleanup", error = getCurrentExceptionMsg() @@ -70,6 +73,58 @@ proc toCausalHistory*(messageIds: seq[SdsMessageID]): seq[HistoryEntry] = proc getMessageIds*(causalHistory: seq[HistoryEntry]): seq[SdsMessageID] = return causalHistory.mapIt(it.messageId) +## SDS-R: Repair computation functions + +proc computeTReq*( + participantId: SdsParticipantID, + messageId: SdsMessageID, + tMin: Duration, + tMax: Duration, +): Duration = + ## Computes the repair request backoff duration per SDS-R spec: + ## T_req = hash(participant_id, message_id) % (T_max - T_min) + T_min + let h = abs(hash(participantId & messageId)) + let rangeMs = tMax.inMilliseconds - tMin.inMilliseconds + if rangeMs <= 0: + return tMin + let offsetMs = h mod rangeMs + initDuration(milliseconds = tMin.inMilliseconds + offsetMs) + +proc computeTResp*( + participantId: SdsParticipantID, + senderId: SdsParticipantID, + messageId: SdsMessageID, + tMax: Duration, +): Duration = + ## Computes the repair response backoff duration per SDS-R spec: + ## distance = hash(participant_id) XOR hash(sender_id) + ## T_resp = distance * hash(message_id) % T_max + ## Original sender has distance=0, so T_resp=0 (responds immediately). + let distance = abs(hash(participantId) xor hash(senderId)) + let msgHash = abs(hash(messageId)) + let tMaxMs = tMax.inMilliseconds + if tMaxMs <= 0 or distance == 0: + return initDuration(milliseconds = 0) + # Use uint64 to avoid overflow on multiplication + let d = uint64(distance mod tMaxMs) + let m = uint64(msgHash mod tMaxMs) + let offsetMs = int64((d * m) mod uint64(tMaxMs)) + initDuration(milliseconds = offsetMs) + +proc isInResponseGroup*( + participantId: SdsParticipantID, + senderId: SdsParticipantID, + messageId: SdsMessageID, + numResponseGroups: int, +): bool = + ## Determines if this participant is in the response group for a given message per SDS-R spec: + ## hash(participant_id, message_id) % num_groups == hash(sender_id, message_id) % num_groups + if numResponseGroups <= 1: + return true # All participants in the same group + let myGroup = abs(hash(participantId & messageId)) mod numResponseGroups + let senderGroup = abs(hash(senderId & messageId)) mod numResponseGroups + myGroup == senderGroup + proc getRecentHistoryEntries*( rm: ReliabilityManager, n: int, channelId: SdsChannelID ): seq[HistoryEntry] = @@ -188,6 +243,9 @@ proc removeChannel*( channel.outgoingBuffer.setLen(0) channel.incomingBuffer.clear() channel.messageHistory.setLen(0) + channel.outgoingRepairBuffer.clear() + channel.incomingRepairBuffer.clear() + channel.messageCache.clear() rm.channels.del(channelId) return ok() except Exception: diff --git a/sds/types.nim b/sds/types.nim index f37518a..637ec54 100644 --- a/sds/types.nim +++ b/sds/types.nim @@ -9,6 +9,7 @@ import sds/types/reliability_error import sds/types/callbacks import sds/types/app_callbacks import sds/types/reliability_config +import sds/types/repair_entry import sds/types/channel_context import sds/types/reliability_manager import sds/types/protobuf_error @@ -25,6 +26,7 @@ export callbacks, app_callbacks, reliability_config, + repair_entry, channel_context, reliability_manager, protobuf_error diff --git a/sds/types/app_callbacks.nim b/sds/types/app_callbacks.nim index 985a97f..84873a6 100644 --- a/sds/types/app_callbacks.nim +++ b/sds/types/app_callbacks.nim @@ -7,6 +7,7 @@ type AppCallbacks* = ref object missingDependenciesCb*: MissingDependenciesCallback periodicSyncCb*: PeriodicSyncCallback retrievalHintProvider*: RetrievalHintProvider + repairReadyCb*: RepairReadyCallback proc new*( T: type AppCallbacks, @@ -15,6 +16,7 @@ proc new*( missingDependenciesCb: MissingDependenciesCallback = nil, periodicSyncCb: PeriodicSyncCallback = nil, retrievalHintProvider: RetrievalHintProvider = nil, + repairReadyCb: RepairReadyCallback = nil, ): T = return T( messageReadyCb: messageReadyCb, @@ -22,4 +24,5 @@ proc new*( missingDependenciesCb: missingDependenciesCb, periodicSyncCb: periodicSyncCb, retrievalHintProvider: retrievalHintProvider, + repairReadyCb: repairReadyCb, ) diff --git a/sds/types/callbacks.nim b/sds/types/callbacks.nim index f1fc4b3..1894f14 100644 --- a/sds/types/callbacks.nim +++ b/sds/types/callbacks.nim @@ -16,3 +16,5 @@ type RetrievalHintProvider* = proc(messageId: SdsMessageID): seq[byte] {.gcsafe.} PeriodicSyncCallback* = proc() {.gcsafe, raises: [].} + + RepairReadyCallback* = proc(message: seq[byte], channelId: SdsChannelID) {.gcsafe.} diff --git a/sds/types/channel_context.nim b/sds/types/channel_context.nim index 0346d18..cec11dc 100644 --- a/sds/types/channel_context.nim +++ b/sds/types/channel_context.nim @@ -3,7 +3,10 @@ import ./sds_message_id import ./rolling_bloom_filter import ./unacknowledged_message import ./incoming_message -export sds_message_id, rolling_bloom_filter, unacknowledged_message, incoming_message +import ./repair_entry +export + sds_message_id, rolling_bloom_filter, unacknowledged_message, incoming_message, + repair_entry type ChannelContext* = ref object lamportTimestamp*: int64 @@ -11,6 +14,11 @@ type ChannelContext* = ref object bloomFilter*: RollingBloomFilter outgoingBuffer*: seq[UnacknowledgedMessage] incomingBuffer*: Table[SdsMessageID, IncomingMessage] + ## SDS-R buffers + outgoingRepairBuffer*: Table[SdsMessageID, OutgoingRepairEntry] + incomingRepairBuffer*: Table[SdsMessageID, IncomingRepairEntry] + messageCache*: Table[SdsMessageID, seq[byte]] + ## Cached serialized messages for repair responses proc new*(T: type ChannelContext, bloomFilter: RollingBloomFilter): T = return T( @@ -19,4 +27,7 @@ proc new*(T: type ChannelContext, bloomFilter: RollingBloomFilter): T = bloomFilter: bloomFilter, outgoingBuffer: @[], incomingBuffer: initTable[SdsMessageID, IncomingMessage](), + outgoingRepairBuffer: initTable[SdsMessageID, OutgoingRepairEntry](), + incomingRepairBuffer: initTable[SdsMessageID, IncomingRepairEntry](), + messageCache: initTable[SdsMessageID, seq[byte]](), ) diff --git a/sds/types/history_entry.nim b/sds/types/history_entry.nim index 2435e6f..b55fc20 100644 --- a/sds/types/history_entry.nim +++ b/sds/types/history_entry.nim @@ -3,6 +3,12 @@ import ./sds_message_id type HistoryEntry* = object messageId*: SdsMessageID retrievalHint*: seq[byte] ## Optional hint for efficient retrieval (e.g., Waku message hash) + senderId*: string ## Original message sender's participant ID (SDS-R) -proc init*(T: type HistoryEntry, messageId: SdsMessageID, retrievalHint: seq[byte] = @[]): T = - return T(messageId: messageId, retrievalHint: retrievalHint) +proc init*( + T: type HistoryEntry, + messageId: SdsMessageID, + retrievalHint: seq[byte] = @[], + senderId: string = "", +): T = + return T(messageId: messageId, retrievalHint: retrievalHint, senderId: senderId) diff --git a/sds/types/reliability_config.nim b/sds/types/reliability_config.nim index f4e4e78..7cd20f2 100644 --- a/sds/types/reliability_config.nim +++ b/sds/types/reliability_config.nim @@ -7,6 +7,11 @@ const DefaultMaxResendAttempts* = 5 DefaultSyncMessageInterval* = initDuration(seconds = 30) DefaultBufferSweepInterval* = initDuration(seconds = 60) + DefaultRepairTMin* = initDuration(seconds = 30) + DefaultRepairTMax* = initDuration(seconds = 300) + DefaultNumResponseGroups* = 1 + DefaultMaxRepairRequests* = 3 + DefaultRepairSweepInterval* = initDuration(seconds = 5) MaxMessageSize* = 1024 * 1024 # 1 MB import ./rolling_bloom_filter @@ -21,6 +26,12 @@ type ReliabilityConfig* {.requiresInit.} = object maxResendAttempts*: int syncMessageInterval*: Duration bufferSweepInterval*: Duration + ## SDS-R config + repairTMin*: Duration + repairTMax*: Duration + numResponseGroups*: int + maxRepairRequests*: int + repairSweepInterval*: Duration proc init*( T: type ReliabilityConfig, @@ -32,6 +43,11 @@ proc init*( maxResendAttempts: int = DefaultMaxResendAttempts, syncMessageInterval: Duration = DefaultSyncMessageInterval, bufferSweepInterval: Duration = DefaultBufferSweepInterval, + repairTMin: Duration = DefaultRepairTMin, + repairTMax: Duration = DefaultRepairTMax, + numResponseGroups: int = DefaultNumResponseGroups, + maxRepairRequests: int = DefaultMaxRepairRequests, + repairSweepInterval: Duration = DefaultRepairSweepInterval, ): T = return T( bloomFilterCapacity: bloomFilterCapacity, @@ -42,4 +58,9 @@ proc init*( maxResendAttempts: maxResendAttempts, syncMessageInterval: syncMessageInterval, bufferSweepInterval: bufferSweepInterval, + repairTMin: repairTMin, + repairTMax: repairTMax, + numResponseGroups: numResponseGroups, + maxRepairRequests: maxRepairRequests, + repairSweepInterval: repairSweepInterval, ) diff --git a/sds/types/reliability_manager.nim b/sds/types/reliability_manager.nim index 9bfc244..5545859 100644 --- a/sds/types/reliability_manager.nim +++ b/sds/types/reliability_manager.nim @@ -9,6 +9,7 @@ export sds_message_id, history_entry, callbacks, reliability_config, channel_con type ReliabilityManager* = ref object channels*: Table[SdsChannelID, ChannelContext] config*: ReliabilityConfig + participantId*: SdsParticipantID lock*: Lock onMessageReady*: proc(messageId: SdsMessageID, channelId: SdsChannelID) {.gcsafe.} onMessageSent*: proc(messageId: SdsMessageID, channelId: SdsChannelID) {.gcsafe.} @@ -17,11 +18,17 @@ type ReliabilityManager* = ref object ) {.gcsafe.} onPeriodicSync*: PeriodicSyncCallback onRetrievalHint*: RetrievalHintProvider + onRepairReady*: RepairReadyCallback -proc new*(T: type ReliabilityManager, config: ReliabilityConfig): T = +proc new*( + T: type ReliabilityManager, + config: ReliabilityConfig, + participantId: SdsParticipantID = "", +): T = let rm = T( channels: initTable[SdsChannelID, ChannelContext](), config: config, + participantId: participantId, ) rm.lock.initLock() return rm diff --git a/sds/types/repair_entry.nim b/sds/types/repair_entry.nim new file mode 100644 index 0000000..01f0fd5 --- /dev/null +++ b/sds/types/repair_entry.nim @@ -0,0 +1,28 @@ +import std/times +import ./history_entry +export history_entry + +type + OutgoingRepairEntry* = object + ## Entry in the outgoing repair request buffer (SDS-R). + ## Tracks a missing message we want to request repair for. + entry*: HistoryEntry ## The missing history entry + tReq*: Time ## Timestamp after which we will include this in a repair request + + IncomingRepairEntry* = object + ## Entry in the incoming repair request buffer (SDS-R). + ## Tracks a repair request from a remote peer that we might respond to. + entry*: HistoryEntry ## The requested history entry + cachedMessage*: seq[byte] ## Full serialized SDS message for rebroadcast + tResp*: Time ## Timestamp after which we will rebroadcast + +proc init*(T: type OutgoingRepairEntry, entry: HistoryEntry, tReq: Time): T = + return T(entry: entry, tReq: tReq) + +proc init*( + T: type IncomingRepairEntry, + entry: HistoryEntry, + cachedMessage: seq[byte], + tResp: Time, +): T = + return T(entry: entry, cachedMessage: cachedMessage, tResp: tResp) diff --git a/sds/types/sds_message.nim b/sds/types/sds_message.nim index 12f7add..b50380c 100644 --- a/sds/types/sds_message.nim +++ b/sds/types/sds_message.nim @@ -2,13 +2,15 @@ import ./sds_message_id import ./history_entry export sds_message_id, history_entry -type SdsMessage* {.requiresInit.} = object +type SdsMessage* = object messageId*: SdsMessageID lamportTimestamp*: int64 causalHistory*: seq[HistoryEntry] channelId*: SdsChannelID content*: seq[byte] bloomFilter*: seq[byte] + repairRequest*: seq[HistoryEntry] + ## Capped list of missing entries requesting repair (SDS-R) proc init*( T: type SdsMessage, @@ -18,6 +20,7 @@ proc init*( channelId: SdsChannelID, content: seq[byte], bloomFilter: seq[byte], + repairRequest: seq[HistoryEntry] = @[], ): T = return T( messageId: messageId, @@ -26,4 +29,5 @@ proc init*( channelId: channelId, content: content, bloomFilter: bloomFilter, + repairRequest: repairRequest, ) diff --git a/sds/types/sds_message_id.nim b/sds/types/sds_message_id.nim index 3e8b7c7..dfeb025 100644 --- a/sds/types/sds_message_id.nim +++ b/sds/types/sds_message_id.nim @@ -1,3 +1,4 @@ type SdsMessageID* = string SdsChannelID* = string + SdsParticipantID* = string diff --git a/tests/test_reliability.nim b/tests/test_reliability.nim index 7100606..aa0eb06 100644 --- a/tests/test_reliability.nim +++ b/tests/test_reliability.nim @@ -741,3 +741,277 @@ suite "Multi-Channel ReliabilityManager Tests": # Dependencies in channel1 should not affect channel2 check rm.channels[channel1].bloomFilter.contains("dep1") check not rm.channels[channel2].bloomFilter.contains("dep1") + +# SDS-R Repair tests +suite "SDS-R: Computation Functions": + test "computeTReq returns duration in [tMin, tMax)": + let tMin = initDuration(seconds = 30) + let tMax = initDuration(seconds = 300) + let d = computeTReq("participant1", "msg1", tMin, tMax) + check: + d.inMilliseconds >= tMin.inMilliseconds + d.inMilliseconds < tMax.inMilliseconds + + test "computeTReq is deterministic for same inputs": + let tMin = initDuration(seconds = 30) + let tMax = initDuration(seconds = 300) + let d1 = computeTReq("p1", "m1", tMin, tMax) + let d2 = computeTReq("p1", "m1", tMin, tMax) + check d1 == d2 + + test "computeTReq varies with different participants": + let tMin = initDuration(seconds = 30) + let tMax = initDuration(seconds = 300) + let d1 = computeTReq("participant-A", "msg1", tMin, tMax) + let d2 = computeTReq("participant-B", "msg1", tMin, tMax) + # Different participants should generally get different backoff (not guaranteed but highly likely) + # Just check both are in valid range + check: + d1.inMilliseconds >= tMin.inMilliseconds + d2.inMilliseconds >= tMin.inMilliseconds + + test "computeTResp original sender has zero distance": + let d = computeTResp("sender1", "sender1", "msg1", initDuration(seconds = 300)) + check d.inMilliseconds == 0 + + test "computeTResp non-sender has positive backoff": + let d = computeTResp("other-node", "sender1", "msg1", initDuration(seconds = 300)) + check d.inMilliseconds >= 0 + + test "isInResponseGroup all in same group when numGroups=1": + check isInResponseGroup("p1", "sender1", "msg1", 1) == true + check isInResponseGroup("p2", "sender1", "msg1", 1) == true + + test "isInResponseGroup sender always in own group": + # Original sender must always be in their own response group + for groups in 1 .. 10: + check isInResponseGroup("sender1", "sender1", "msg1", groups) == true + +suite "SDS-R: Repair Buffer Management": + var rm: ReliabilityManager + + setup: + let rmResult = newReliabilityManager( + participantId = "test-participant" + ) + check rmResult.isOk() + rm = rmResult.get() + check rm.ensureChannel(testChannel).isOk() + + teardown: + if not rm.isNil: + rm.cleanup() + + test "missing deps added to outgoing repair buffer": + var missingDepsCount = 0 + + rm.setCallbacks( + proc(messageId: SdsMessageID, channelId: SdsChannelID) {.gcsafe.} = discard, + proc(messageId: SdsMessageID, channelId: SdsChannelID) {.gcsafe.} = discard, + proc(messageId: SdsMessageID, missingDeps: seq[HistoryEntry], channelId: SdsChannelID) {.gcsafe.} = + missingDepsCount += 1, + ) + + # Create a message with a missing dependency + let msg = SdsMessage( + messageId: "msg2", + lamportTimestamp: 2, + causalHistory: @[HistoryEntry(messageId: "msg1", senderId: "sender-A")], + channelId: testChannel, + content: @[byte(2)], + bloomFilter: @[], + ) + + let serialized = serializeMessage(msg).get() + let result = rm.unwrapReceivedMessage(serialized) + check result.isOk() + + # msg1 should be in the outgoing repair buffer + let channel = rm.channels[testChannel] + check: + missingDepsCount == 1 + "msg1" in channel.outgoingRepairBuffer + + test "receiving message clears it from repair buffers": + rm.setCallbacks( + proc(messageId: SdsMessageID, channelId: SdsChannelID) {.gcsafe.} = discard, + proc(messageId: SdsMessageID, channelId: SdsChannelID) {.gcsafe.} = discard, + proc(messageId: SdsMessageID, missingDeps: seq[HistoryEntry], channelId: SdsChannelID) {.gcsafe.} = discard, + ) + + # First, create the missing dep scenario + let msg2 = SdsMessage( + messageId: "msg2", + lamportTimestamp: 2, + causalHistory: @[HistoryEntry(messageId: "msg1", senderId: "sender-A")], + channelId: testChannel, + content: @[byte(2)], + bloomFilter: @[], + ) + discard rm.unwrapReceivedMessage(serializeMessage(msg2).get()) + check "msg1" in rm.channels[testChannel].outgoingRepairBuffer + + # Now receive msg1 — should clear from repair buffer + let msg1 = SdsMessage( + messageId: "msg1", + lamportTimestamp: 1, + causalHistory: @[], + channelId: testChannel, + content: @[byte(1)], + bloomFilter: @[], + ) + discard rm.unwrapReceivedMessage(serializeMessage(msg1).get()) + check "msg1" notin rm.channels[testChannel].outgoingRepairBuffer + + test "markDependenciesMet clears repair buffers": + rm.setCallbacks( + proc(messageId: SdsMessageID, channelId: SdsChannelID) {.gcsafe.} = discard, + proc(messageId: SdsMessageID, channelId: SdsChannelID) {.gcsafe.} = discard, + proc(messageId: SdsMessageID, missingDeps: seq[HistoryEntry], channelId: SdsChannelID) {.gcsafe.} = discard, + ) + + let msg2 = SdsMessage( + messageId: "msg2", + lamportTimestamp: 2, + causalHistory: @[HistoryEntry(messageId: "msg1", senderId: "sender-A")], + channelId: testChannel, + content: @[byte(2)], + bloomFilter: @[], + ) + discard rm.unwrapReceivedMessage(serializeMessage(msg2).get()) + check "msg1" in rm.channels[testChannel].outgoingRepairBuffer + + # Mark as met via store retrieval + check rm.markDependenciesMet(@["msg1"], testChannel).isOk() + check "msg1" notin rm.channels[testChannel].outgoingRepairBuffer + + test "expired repair requests attached to outgoing messages": + rm.setCallbacks( + proc(messageId: SdsMessageID, channelId: SdsChannelID) {.gcsafe.} = discard, + proc(messageId: SdsMessageID, channelId: SdsChannelID) {.gcsafe.} = discard, + proc(messageId: SdsMessageID, missingDeps: seq[HistoryEntry], channelId: SdsChannelID) {.gcsafe.} = discard, + ) + + # Manually add an expired repair entry + let channel = rm.channels[testChannel] + channel.outgoingRepairBuffer["missing-msg"] = OutgoingRepairEntry( + entry: HistoryEntry(messageId: "missing-msg", senderId: "orig-sender"), + tReq: getTime() - initDuration(seconds = 10), # Already expired + ) + + # Send a message — should pick up the expired repair request + let wrapped = rm.wrapOutgoingMessage(@[byte(1)], "new-msg", testChannel) + check wrapped.isOk() + + let unwrapped = deserializeMessage(wrapped.get()).get() + check: + unwrapped.repairRequest.len == 1 + unwrapped.repairRequest[0].messageId == "missing-msg" + # Should be removed from buffer after attaching + "missing-msg" notin channel.outgoingRepairBuffer + + test "incoming repair request adds to incoming repair buffer when eligible": + rm.setCallbacks( + proc(messageId: SdsMessageID, channelId: SdsChannelID) {.gcsafe.} = discard, + proc(messageId: SdsMessageID, channelId: SdsChannelID) {.gcsafe.} = discard, + proc(messageId: SdsMessageID, missingDeps: seq[HistoryEntry], channelId: SdsChannelID) {.gcsafe.} = discard, + ) + + let channel = rm.channels[testChannel] + + # First, cache a message so we can respond to a repair request for it + let cachedMsg = SdsMessage( + messageId: "cached-msg", + lamportTimestamp: 1, + causalHistory: @[], + channelId: testChannel, + content: @[byte(99)], + bloomFilter: @[], + ) + let cachedBytes = serializeMessage(cachedMsg).get() + channel.messageCache["cached-msg"] = cachedBytes + + # Receive a message with a repair request for "cached-msg" + let msgWithRepair = SdsMessage( + messageId: "requester-msg", + lamportTimestamp: 5, + causalHistory: @[], + channelId: testChannel, + content: @[byte(3)], + bloomFilter: @[], + repairRequest: @[HistoryEntry( + messageId: "cached-msg", + senderId: "test-participant", # Same as our participantId so we're in response group + )], + ) + discard rm.unwrapReceivedMessage(serializeMessage(msgWithRepair).get()) + + # We should have added it to the incoming repair buffer (we have the message and are in response group) + check "cached-msg" in channel.incomingRepairBuffer + +suite "SDS-R: Protobuf Roundtrip": + test "senderId in HistoryEntry roundtrips through protobuf": + let msg = SdsMessage( + messageId: "msg1", + lamportTimestamp: 100, + causalHistory: @[ + HistoryEntry(messageId: "dep1", retrievalHint: @[byte(1), 2], senderId: "sender-A"), + HistoryEntry(messageId: "dep2", senderId: "sender-B"), + ], + channelId: "ch1", + content: @[byte(42)], + bloomFilter: @[], + ) + + let serialized = serializeMessage(msg).get() + let decoded = deserializeMessage(serialized).get() + + check: + decoded.causalHistory.len == 2 + decoded.causalHistory[0].messageId == "dep1" + decoded.causalHistory[0].senderId == "sender-A" + decoded.causalHistory[0].retrievalHint == @[byte(1), 2] + decoded.causalHistory[1].messageId == "dep2" + decoded.causalHistory[1].senderId == "sender-B" + + test "repairRequest field roundtrips through protobuf": + let msg = SdsMessage( + messageId: "msg1", + lamportTimestamp: 100, + causalHistory: @[], + channelId: "ch1", + content: @[byte(42)], + bloomFilter: @[], + repairRequest: @[ + HistoryEntry(messageId: "missing1", senderId: "sender-X"), + HistoryEntry(messageId: "missing2", senderId: "sender-Y", retrievalHint: @[byte(5)]), + ], + ) + + let serialized = serializeMessage(msg).get() + let decoded = deserializeMessage(serialized).get() + + check: + decoded.repairRequest.len == 2 + decoded.repairRequest[0].messageId == "missing1" + decoded.repairRequest[0].senderId == "sender-X" + decoded.repairRequest[1].messageId == "missing2" + decoded.repairRequest[1].senderId == "sender-Y" + decoded.repairRequest[1].retrievalHint == @[byte(5)] + + test "backward compat: message without repairRequest decodes fine": + let msg = SdsMessage( + messageId: "msg1", + lamportTimestamp: 100, + causalHistory: @[HistoryEntry(messageId: "dep1")], + channelId: "ch1", + content: @[byte(42)], + bloomFilter: @[], + ) + + let serialized = serializeMessage(msg).get() + let decoded = deserializeMessage(serialized).get() + + check: + decoded.repairRequest.len == 0 + decoded.causalHistory[0].senderId == ""