From 99121098cc59d6667d955f136c6aee777b22d7a1 Mon Sep 17 00:00:00 2001 From: Akhil <111925100+shash256@users.noreply.github.com> Date: Thu, 10 Jul 2025 22:26:34 +0530 Subject: [PATCH] feat: Implementing Support for Multiple Channels in Single Reliability Manager (#13) --- src/message.nim | 4 +- src/protobuf.nim | 21 ++-- src/reliability.nim | 191 +++++++++++++++++++++---------------- src/reliability_utils.nim | 195 +++++++++++++++++++++++++++++++------- 4 files changed, 282 insertions(+), 129 deletions(-) diff --git a/src/message.nim b/src/message.nim index 4f6640c..f23ad13 100644 --- a/src/message.nim +++ b/src/message.nim @@ -1,4 +1,4 @@ -import std/[times, options, sets] +import std/[times, sets] type SdsMessageID* = string @@ -8,7 +8,7 @@ type messageId*: SdsMessageID lamportTimestamp*: int64 causalHistory*: seq[SdsMessageID] - channelId*: Option[SdsChannelID] + channelId*: SdsChannelID content*: seq[byte] bloomFilter*: seq[byte] diff --git a/src/protobuf.nim b/src/protobuf.nim index 6a147c8..1f6d600 100644 --- a/src/protobuf.nim +++ b/src/protobuf.nim @@ -12,8 +12,7 @@ proc encode*(msg: SdsMessage): ProtoBuffer = for hist in msg.causalHistory: pb.write(3, hist) - if msg.channelId.isSome(): - pb.write(4, msg.channelId.get()) + pb.write(4, msg.channelId) pb.write(5, msg.content) pb.write(6, msg.bloomFilter) pb.finish() @@ -37,11 +36,8 @@ proc decode*(T: type SdsMessage, buffer: seq[byte]): ProtobufResult[T] = if histResult.isOk: msg.causalHistory = causalHistory - var channelId: SdsChannelID - if ?pb.getField(4, channelId): - msg.channelId = some(channelId) - else: - msg.channelId = none[SdsChannelID]() + if not ?pb.getField(4, msg.channelId): + return err(ProtobufError.missingRequiredField("channelId")) if not ?pb.getField(5, msg.content): return err(ProtobufError.missingRequiredField("content")) @@ -51,6 +47,17 @@ proc decode*(T: type SdsMessage, buffer: seq[byte]): ProtobufResult[T] = ok(msg) +proc extractChannelId*(data: seq[byte]): Result[SdsChannelID, ReliabilityError] = + ## For extraction of channel ID without full message deserialization + try: + let pb = initProtoBuffer(data) + var channelId: SdsChannelID + if not pb.getField(4, channelId).get(): + return err(ReliabilityError.reDeserializationError) + ok(channelId) + except: + err(ReliabilityError.reDeserializationError) + proc serializeMessage*(msg: SdsMessage): Result[seq[byte], ReliabilityError] = let pb = encode(msg) ok(pb.buffer) diff --git a/src/reliability.nim b/src/reliability.nim index 97a39a9..a39fac3 100644 --- a/src/reliability.nim +++ b/src/reliability.nim @@ -3,31 +3,18 @@ import chronos, results, chronicles import ./[message, protobuf, reliability_utils, rolling_bloom_filter] proc newReliabilityManager*( - channelId: Option[SdsChannelID], config: ReliabilityConfig = defaultConfig() + config: ReliabilityConfig = defaultConfig() ): Result[ReliabilityManager, ReliabilityError] = - ## Creates a new ReliabilityManager with the specified channel ID and configuration. + ## Creates a new multi-channel ReliabilityManager. ## ## Parameters: - ## - channelId: A unique identifier for the communication channel. ## - config: Configuration options for the ReliabilityManager. If not provided, default configuration is used. ## ## Returns: ## A Result containing either a new ReliabilityManager instance or an error. - if not channelId.isSome(): - return err(ReliabilityError.reInvalidArgument) - try: - let bloomFilter = - newRollingBloomFilter(config.bloomFilterCapacity, config.bloomFilterErrorRate) - let rm = ReliabilityManager( - lamportTimestamp: 0, - messageHistory: @[], - bloomFilter: bloomFilter, - outgoingBuffer: @[], - incomingBuffer: initTable[SdsMessageID, IncomingMessage](), - channelId: channelId, - config: config, + channels: initTable[SdsChannelID, ChannelContext](), config: config ) initLock(rm.lock) return ok(rm) @@ -73,29 +60,37 @@ proc reviewAckStatus(rm: ReliabilityManager, msg: SdsMessage) {.gcsafe.} = else: rbf = none[RollingBloomFilter]() + if msg.channelId notin rm.channels: + return + + let channel = rm.channels[msg.channelId] # Keep track of indices to delete var toDelete: seq[int] = @[] var i = 0 - while i < rm.outgoingBuffer.len: - let outMsg = rm.outgoingBuffer[i] + while i < channel.outgoingBuffer.len: + let outMsg = channel.outgoingBuffer[i] if outMsg.isAcknowledged(msg.causalHistory, rbf): if not rm.onMessageSent.isNil(): - rm.onMessageSent(outMsg.message.messageId) + rm.onMessageSent(outMsg.message.messageId, outMsg.message.channelId) toDelete.add(i) inc i for i in countdown(toDelete.high, 0): # Delete in reverse order to maintain indices - rm.outgoingBuffer.delete(toDelete[i]) + channel.outgoingBuffer.delete(toDelete[i]) proc wrapOutgoingMessage*( - rm: ReliabilityManager, message: seq[byte], messageId: SdsMessageID + rm: ReliabilityManager, + message: seq[byte], + messageId: SdsMessageID, + channelId: SdsChannelID, ): Result[seq[byte], ReliabilityError] = ## Wraps an outgoing message with reliability metadata. ## ## Parameters: ## - message: The content of the message to be sent. ## - messageId: Unique identifier for the message + ## - channelId: Identifier for the channel this message belongs to. ## ## Returns: ## A Result containing either wrapped message bytes or an error. @@ -106,46 +101,52 @@ proc wrapOutgoingMessage*( withLock rm.lock: try: - rm.updateLamportTimestamp(getTime().toUnix) + let channel = rm.getOrCreateChannel(channelId) + rm.updateLamportTimestamp(getTime().toUnix, channelId) - let bfResult = serializeBloomFilter(rm.bloomFilter.filter) + let bfResult = serializeBloomFilter(channel.bloomFilter.filter) if bfResult.isErr: - error "Failed to serialize bloom filter" + error "Failed to serialize bloom filter", channelId = channelId return err(ReliabilityError.reSerializationError) let msg = SdsMessage( messageId: messageId, - lamportTimestamp: rm.lamportTimestamp, - causalHistory: rm.getRecentSdsMessageIDs(rm.config.maxCausalHistory), - channelId: rm.channelId, + lamportTimestamp: channel.lamportTimestamp, + causalHistory: rm.getRecentSdsMessageIDs(rm.config.maxCausalHistory, channelId), + channelId: channelId, content: message, bloomFilter: bfResult.get(), ) - # Add to outgoing buffer - rm.outgoingBuffer.add( + channel.outgoingBuffer.add( UnacknowledgedMessage(message: msg, sendTime: getTime(), resendAttempts: 0) ) # Add to causal history and bloom filter - rm.bloomFilter.add(msg.messageId) - rm.addToHistory(msg.messageId) + channel.bloomFilter.add(msg.messageId) + rm.addToHistory(msg.messageId, channelId) return serializeMessage(msg) except Exception: - error "Failed to wrap message", msg = getCurrentExceptionMsg() + error "Failed to wrap message", + channelId = channelId, msg = getCurrentExceptionMsg() return err(ReliabilityError.reSerializationError) -proc processIncomingBuffer(rm: ReliabilityManager) {.gcsafe.} = +proc processIncomingBuffer(rm: ReliabilityManager, channelId: SdsChannelID) {.gcsafe.} = withLock rm.lock: - if rm.incomingBuffer.len == 0: + if channelId notin rm.channels: + error "Channel does not exist", channelId = channelId + return + + let channel = rm.channels[channelId] + if channel.incomingBuffer.len == 0: return var processed = initHashSet[SdsMessageID]() var readyToProcess = newSeq[SdsMessageID]() # Find initially ready messages - for msgId, entry in rm.incomingBuffer: + for msgId, entry in channel.incomingBuffer: if entry.missingDeps.len == 0: readyToProcess.add(msgId) @@ -154,105 +155,114 @@ proc processIncomingBuffer(rm: ReliabilityManager) {.gcsafe.} = if msgId in processed: continue - if msgId in rm.incomingBuffer: - rm.addToHistory(msgId) + if msgId in channel.incomingBuffer: + rm.addToHistory(msgId, channelId) if not rm.onMessageReady.isNil(): - rm.onMessageReady(msgId) + rm.onMessageReady(msgId, channelId) processed.incl(msgId) # Update dependencies for remaining messages - for remainingId, entry in rm.incomingBuffer: + for remainingId, entry in channel.incomingBuffer: if remainingId notin processed: if msgId in entry.missingDeps: - rm.incomingBuffer[remainingId].missingDeps.excl(msgId) - if rm.incomingBuffer[remainingId].missingDeps.len == 0: + channel.incomingBuffer[remainingId].missingDeps.excl(msgId) + if channel.incomingBuffer[remainingId].missingDeps.len == 0: readyToProcess.add(remainingId) # Remove processed messages for msgId in processed: - rm.incomingBuffer.del(msgId) + channel.incomingBuffer.del(msgId) proc unwrapReceivedMessage*( rm: ReliabilityManager, message: seq[byte] -): Result[tuple[message: seq[byte], missingDeps: seq[SdsMessageID]], ReliabilityError] = +): Result[ + tuple[message: seq[byte], missingDeps: seq[SdsMessageID], channelId: SdsChannelID], + ReliabilityError, +] = ## Unwraps a received message and processes its reliability metadata. ## ## Parameters: ## - message: The received message bytes ## ## Returns: - ## A Result containing either tuple of (processed message, missing dependencies) or an error. + ## A Result containing either tuple of (processed message, missing dependencies, channel ID) or an error. try: + let channelId = extractChannelId(message).valueOr: + return err(ReliabilityError.reDeserializationError) + let msg = deserializeMessage(message).valueOr: return err(ReliabilityError.reDeserializationError) - if msg.messageId in rm.messageHistory: - return ok((msg.content, @[])) + let channel = rm.getOrCreateChannel(channelId) - rm.bloomFilter.add(msg.messageId) + if msg.messageId in channel.messageHistory: + return ok((msg.content, @[], channelId)) - # Update Lamport timestamp - rm.updateLamportTimestamp(msg.lamportTimestamp) + channel.bloomFilter.add(msg.messageId) + rm.updateLamportTimestamp(msg.lamportTimestamp, channelId) # Review ACK status for outgoing messages rm.reviewAckStatus(msg) - var missingDeps = rm.checkDependencies(msg.causalHistory) + var missingDeps = rm.checkDependencies(msg.causalHistory, channelId) if missingDeps.len == 0: - # Check if any dependencies are still in incoming buffer var depsInBuffer = false - for msgId, entry in rm.incomingBuffer.pairs(): + for msgId, entry in channel.incomingBuffer.pairs(): if msgId in msg.causalHistory: depsInBuffer = true break - + # Check if any dependencies are still in incoming buffer if depsInBuffer: - rm.incomingBuffer[msg.messageId] = + channel.incomingBuffer[msg.messageId] = IncomingMessage(message: msg, missingDeps: initHashSet[SdsMessageID]()) else: # All dependencies met, add to history - rm.addToHistory(msg.messageId) - rm.processIncomingBuffer() + rm.addToHistory(msg.messageId, channelId) + rm.processIncomingBuffer(channelId) if not rm.onMessageReady.isNil(): - rm.onMessageReady(msg.messageId) + rm.onMessageReady(msg.messageId, channelId) else: - rm.incomingBuffer[msg.messageId] = + channel.incomingBuffer[msg.messageId] = IncomingMessage(message: msg, missingDeps: missingDeps.toHashSet()) if not rm.onMissingDependencies.isNil(): - rm.onMissingDependencies(msg.messageId, missingDeps) + rm.onMissingDependencies(msg.messageId, missingDeps, channelId) - return ok((msg.content, missingDeps)) + return ok((msg.content, missingDeps, channelId)) except Exception: error "Failed to unwrap message", msg = getCurrentExceptionMsg() return err(ReliabilityError.reDeserializationError) proc markDependenciesMet*( - rm: ReliabilityManager, messageIds: seq[SdsMessageID] + rm: ReliabilityManager, messageIds: seq[SdsMessageID], channelId: SdsChannelID ): Result[void, ReliabilityError] = ## Marks the specified message dependencies as met. ## ## Parameters: ## - messageIds: A sequence of message IDs to mark as met. + ## - channelId: Identifier for the channel. ## ## Returns: ## A Result indicating success or an error. try: - # Add all messageIds to bloom filter + if channelId notin rm.channels: + return err(ReliabilityError.reInvalidArgument) + + let channel = rm.channels[channelId] + for msgId in messageIds: - if not rm.bloomFilter.contains(msgId): - rm.bloomFilter.add(msgId) - # rm.addToHistory(msgId) -- not needed as this proc usually called when msg in long-term storage of application? + if not channel.bloomFilter.contains(msgId): + channel.bloomFilter.add(msgId) - # Update any pending messages that depend on this one - for pendingId, entry in rm.incomingBuffer: + for pendingId, entry in channel.incomingBuffer: if msgId in entry.missingDeps: - rm.incomingBuffer[pendingId].missingDeps.excl(msgId) + channel.incomingBuffer[pendingId].missingDeps.excl(msgId) - rm.processIncomingBuffer() + rm.processIncomingBuffer(channelId) return ok() except Exception: - error "Failed to mark dependencies as met", msg = getCurrentExceptionMsg() + error "Failed to mark dependencies as met", + channelId = channelId, msg = getCurrentExceptionMsg() return err(ReliabilityError.reInternalError) proc setCallbacks*( @@ -275,16 +285,22 @@ proc setCallbacks*( rm.onMissingDependencies = onMissingDependencies rm.onPeriodicSync = onPeriodicSync -proc checkUnacknowledgedMessages(rm: ReliabilityManager) {.gcsafe.} = +proc checkUnacknowledgedMessages( + rm: ReliabilityManager, channelId: SdsChannelID +) {.gcsafe.} = ## Checks and processes unacknowledged messages in the outgoing buffer. withLock rm.lock: + if channelId notin rm.channels: + error "Channel does not exist", channelId = channelId + return + + let channel = rm.channels[channelId] let now = getTime() var newOutgoingBuffer: seq[UnacknowledgedMessage] = @[] - for unackMsg in rm.outgoingBuffer: + for unackMsg in channel.outgoingBuffer: let elapsed = now - unackMsg.sendTime if elapsed > rm.config.resendInterval: - # Time to attempt resend if unackMsg.resendAttempts < rm.config.maxResendAttempts: var updatedMsg = unackMsg updatedMsg.resendAttempts += 1 @@ -292,11 +308,11 @@ proc checkUnacknowledgedMessages(rm: ReliabilityManager) {.gcsafe.} = newOutgoingBuffer.add(updatedMsg) else: if not rm.onMessageSent.isNil(): - rm.onMessageSent(unackMsg.message.messageId) + rm.onMessageSent(unackMsg.message.messageId, channelId) else: newOutgoingBuffer.add(unackMsg) - rm.outgoingBuffer = newOutgoingBuffer + channel.outgoingBuffer = newOutgoingBuffer proc periodicBufferSweep( rm: ReliabilityManager @@ -304,8 +320,13 @@ proc periodicBufferSweep( ## Periodically sweeps the buffer to clean up and check unacknowledged messages. while true: try: - rm.checkUnacknowledgedMessages() - rm.cleanBloomFilter() + for channelId, channel in rm.channels: + try: + rm.checkUnacknowledgedMessages(channelId) + rm.cleanBloomFilter(channelId) + except Exception: + error "Error in buffer sweep for channel", + channelId = channelId, msg = getCurrentExceptionMsg() except Exception: error "Error in periodic buffer sweep", msg = getCurrentExceptionMsg() @@ -336,13 +357,15 @@ proc resetReliabilityManager*(rm: ReliabilityManager): Result[void, ReliabilityE ## This procedure clears all buffers and resets the Lamport timestamp. withLock rm.lock: try: - rm.lamportTimestamp = 0 - rm.messageHistory.setLen(0) - rm.outgoingBuffer.setLen(0) - rm.incomingBuffer.clear() - rm.bloomFilter = newRollingBloomFilter( - rm.config.bloomFilterCapacity, rm.config.bloomFilterErrorRate - ) + for channelId, channel in rm.channels: + channel.lamportTimestamp = 0 + channel.messageHistory.setLen(0) + channel.outgoingBuffer.setLen(0) + channel.incomingBuffer.clear() + channel.bloomFilter = newRollingBloomFilter( + rm.config.bloomFilterCapacity, rm.config.bloomFilterErrorRate + ) + rm.channels.clear() return ok() except Exception: error "Failed to reset ReliabilityManager", msg = getCurrentExceptionMsg() diff --git a/src/reliability_utils.nim b/src/reliability_utils.nim index 3cc23fa..28248da 100644 --- a/src/reliability_utils.nim +++ b/src/reliability_utils.nim @@ -1,14 +1,17 @@ -import std/[times, locks, options] -import chronicles +import std/[times, locks, tables] +import chronicles, results import ./[rolling_bloom_filter, message] type - MessageReadyCallback* = proc(messageId: SdsMessageID) {.gcsafe.} + MessageReadyCallback* = + proc(messageId: SdsMessageID, channelId: SdsChannelID) {.gcsafe.} - MessageSentCallback* = proc(messageId: SdsMessageID) {.gcsafe.} + MessageSentCallback* = + proc(messageId: SdsMessageID, channelId: SdsChannelID) {.gcsafe.} - MissingDependenciesCallback* = - proc(messageId: SdsMessageID, missingDeps: seq[SdsMessageID]) {.gcsafe.} + MissingDependenciesCallback* = proc( + messageId: SdsMessageID, missingDeps: seq[SdsMessageID], channelId: SdsChannelID + ) {.gcsafe.} PeriodicSyncCallback* = proc() {.gcsafe, raises: [].} @@ -28,19 +31,22 @@ type syncMessageInterval*: Duration bufferSweepInterval*: Duration - ReliabilityManager* = ref object + ChannelContext* = ref object lamportTimestamp*: int64 messageHistory*: seq[SdsMessageID] bloomFilter*: RollingBloomFilter outgoingBuffer*: seq[UnacknowledgedMessage] incomingBuffer*: Table[SdsMessageID, IncomingMessage] - channelId*: Option[SdsChannelID] + + ReliabilityManager* = ref object + channels*: Table[SdsChannelID, ChannelContext] config*: ReliabilityConfig lock*: Lock - onMessageReady*: proc(messageId: SdsMessageID) {.gcsafe.} - onMessageSent*: proc(messageId: SdsMessageID) {.gcsafe.} - onMissingDependencies*: - proc(messageId: SdsMessageID, missingDeps: seq[SdsMessageID]) {.gcsafe.} + onMessageReady*: proc(messageId: SdsMessageID, channelId: SdsChannelID) {.gcsafe.} + onMessageSent*: proc(messageId: SdsMessageID, channelId: SdsChannelID) {.gcsafe.} + onMissingDependencies*: proc( + messageId: SdsMessageID, missingDeps: seq[SdsMessageID], channelId: SdsChannelID + ) {.gcsafe.} onPeriodicSync*: PeriodicSyncCallback ReliabilityError* {.pure.} = enum @@ -71,51 +77,168 @@ proc cleanup*(rm: ReliabilityManager) {.raises: [].} = if not rm.isNil(): try: withLock rm.lock: - rm.outgoingBuffer.setLen(0) - rm.incomingBuffer.clear() - rm.messageHistory.setLen(0) + for channelId, channel in rm.channels: + channel.outgoingBuffer.setLen(0) + channel.incomingBuffer.clear() + channel.messageHistory.setLen(0) + rm.channels.clear() except Exception: error "Error during cleanup", error = getCurrentExceptionMsg() -proc cleanBloomFilter*(rm: ReliabilityManager) {.gcsafe, raises: [].} = +proc cleanBloomFilter*( + rm: ReliabilityManager, channelId: SdsChannelID +) {.gcsafe, raises: [].} = withLock rm.lock: try: - rm.bloomFilter.clean() + if channelId in rm.channels: + rm.channels[channelId].bloomFilter.clean() except Exception: - error "Failed to clean bloom filter", error = getCurrentExceptionMsg() + error "Failed to clean bloom filter", + error = getCurrentExceptionMsg(), channelId = channelId -proc addToHistory*(rm: ReliabilityManager, msgId: SdsMessageID) {.gcsafe, raises: [].} = - rm.messageHistory.add(msgId) - if rm.messageHistory.len > rm.config.maxMessageHistory: - rm.messageHistory.delete(0) +proc addToHistory*( + rm: ReliabilityManager, msgId: SdsMessageID, channelId: SdsChannelID +) {.gcsafe, raises: [].} = + try: + if channelId in rm.channels: + let channel = rm.channels[channelId] + channel.messageHistory.add(msgId) + if channel.messageHistory.len > rm.config.maxMessageHistory: + channel.messageHistory.delete(0) + except Exception: + error "Failed to add to history", + channelId = channelId, msgId = msgId, error = getCurrentExceptionMsg() proc updateLamportTimestamp*( - rm: ReliabilityManager, msgTs: int64 + rm: ReliabilityManager, msgTs: int64, channelId: SdsChannelID ) {.gcsafe, raises: [].} = - rm.lamportTimestamp = max(msgTs, rm.lamportTimestamp) + 1 + try: + if channelId in rm.channels: + let channel = rm.channels[channelId] + channel.lamportTimestamp = max(msgTs, channel.lamportTimestamp) + 1 + except Exception: + error "Failed to update lamport timestamp", + channelId = channelId, msgTs = msgTs, error = getCurrentExceptionMsg() -proc getRecentSdsMessageIDs*(rm: ReliabilityManager, n: int): seq[SdsMessageID] = - result = rm.messageHistory[max(0, rm.messageHistory.len - n) .. ^1] +proc getRecentSdsMessageIDs*( + rm: ReliabilityManager, n: int, channelId: SdsChannelID +): seq[SdsMessageID] = + try: + if channelId in rm.channels: + let channel = rm.channels[channelId] + result = channel.messageHistory[max(0, channel.messageHistory.len - n) .. ^1] + else: + result = @[] + except Exception: + error "Failed to get recent message IDs", + channelId = channelId, n = n, error = getCurrentExceptionMsg() + result = @[] proc checkDependencies*( - rm: ReliabilityManager, deps: seq[SdsMessageID] + rm: ReliabilityManager, deps: seq[SdsMessageID], channelId: SdsChannelID ): seq[SdsMessageID] = var missingDeps: seq[SdsMessageID] = @[] - for depId in deps: - if depId notin rm.messageHistory: - missingDeps.add(depId) + try: + if channelId in rm.channels: + let channel = rm.channels[channelId] + for depId in deps: + if depId notin channel.messageHistory: + missingDeps.add(depId) + else: + missingDeps = deps + except Exception: + error "Failed to check dependencies", + channelId = channelId, error = getCurrentExceptionMsg() + missingDeps = deps return missingDeps -proc getMessageHistory*(rm: ReliabilityManager): seq[SdsMessageID] = +proc getMessageHistory*( + rm: ReliabilityManager, channelId: SdsChannelID +): seq[SdsMessageID] = withLock rm.lock: - result = rm.messageHistory + try: + if channelId in rm.channels: + result = rm.channels[channelId].messageHistory + else: + result = @[] + except Exception: + error "Failed to get message history", + channelId = channelId, error = getCurrentExceptionMsg() + result = @[] -proc getOutgoingBuffer*(rm: ReliabilityManager): seq[UnacknowledgedMessage] = +proc getOutgoingBuffer*( + rm: ReliabilityManager, channelId: SdsChannelID +): seq[UnacknowledgedMessage] = withLock rm.lock: - result = rm.outgoingBuffer + try: + if channelId in rm.channels: + result = rm.channels[channelId].outgoingBuffer + else: + result = @[] + except Exception: + error "Failed to get outgoing buffer", + channelId = channelId, error = getCurrentExceptionMsg() + result = @[] proc getIncomingBuffer*( - rm: ReliabilityManager + rm: ReliabilityManager, channelId: SdsChannelID ): Table[SdsMessageID, message.IncomingMessage] = withLock rm.lock: - result = rm.incomingBuffer + try: + if channelId in rm.channels: + result = rm.channels[channelId].incomingBuffer + else: + result = initTable[SdsMessageID, message.IncomingMessage]() + except Exception: + error "Failed to get incoming buffer", + channelId = channelId, error = getCurrentExceptionMsg() + result = initTable[SdsMessageID, message.IncomingMessage]() + +proc getOrCreateChannel*( + rm: ReliabilityManager, channelId: SdsChannelID +): ChannelContext = + try: + if channelId notin rm.channels: + rm.channels[channelId] = ChannelContext( + lamportTimestamp: 0, + messageHistory: @[], + bloomFilter: newRollingBloomFilter( + rm.config.bloomFilterCapacity, rm.config.bloomFilterErrorRate + ), + outgoingBuffer: @[], + incomingBuffer: initTable[SdsMessageID, IncomingMessage](), + ) + result = rm.channels[channelId] + except Exception: + error "Failed to get or create channel", + channelId = channelId, error = getCurrentExceptionMsg() + raise + +proc ensureChannel*( + rm: ReliabilityManager, channelId: SdsChannelID +): Result[void, ReliabilityError] = + withLock rm.lock: + try: + discard rm.getOrCreateChannel(channelId) + return ok() + except Exception: + error "Failed to ensure channel", + channelId = channelId, msg = getCurrentExceptionMsg() + return err(ReliabilityError.reInternalError) + +proc removeChannel*( + rm: ReliabilityManager, channelId: SdsChannelID +): Result[void, ReliabilityError] = + withLock rm.lock: + try: + if channelId in rm.channels: + let channel = rm.channels[channelId] + channel.outgoingBuffer.setLen(0) + channel.incomingBuffer.clear() + channel.messageHistory.setLen(0) + rm.channels.del(channelId) + return ok() + except Exception: + error "Failed to remove channel", + channelId = channelId, msg = getCurrentExceptionMsg() + return err(ReliabilityError.reInternalError)