From 248ddaf555771d32357aca563c096d61a56160c6 Mon Sep 17 00:00:00 2001 From: shash256 <111925100+shash256@users.noreply.github.com> Date: Mon, 7 Jul 2025 15:11:28 +0530 Subject: [PATCH] feat: add updates to core api --- src/message.nim | 4 +- src/protobuf.nim | 21 ++- src/reliability.nim | 283 ++++++++++++++++++-------------------- src/reliability_utils.nim | 158 ++++++++++++++++----- 4 files changed, 277 insertions(+), 189 deletions(-) diff --git a/src/message.nim b/src/message.nim index 4f6640c..f23ad13 100644 --- a/src/message.nim +++ b/src/message.nim @@ -1,4 +1,4 @@ -import std/[times, options, sets] +import std/[times, sets] type SdsMessageID* = string @@ -8,7 +8,7 @@ type messageId*: SdsMessageID lamportTimestamp*: int64 causalHistory*: seq[SdsMessageID] - channelId*: Option[SdsChannelID] + channelId*: SdsChannelID content*: seq[byte] bloomFilter*: seq[byte] diff --git a/src/protobuf.nim b/src/protobuf.nim index 6a147c8..1f6d600 100644 --- a/src/protobuf.nim +++ b/src/protobuf.nim @@ -12,8 +12,7 @@ proc encode*(msg: SdsMessage): ProtoBuffer = for hist in msg.causalHistory: pb.write(3, hist) - if msg.channelId.isSome(): - pb.write(4, msg.channelId.get()) + pb.write(4, msg.channelId) pb.write(5, msg.content) pb.write(6, msg.bloomFilter) pb.finish() @@ -37,11 +36,8 @@ proc decode*(T: type SdsMessage, buffer: seq[byte]): ProtobufResult[T] = if histResult.isOk: msg.causalHistory = causalHistory - var channelId: SdsChannelID - if ?pb.getField(4, channelId): - msg.channelId = some(channelId) - else: - msg.channelId = none[SdsChannelID]() + if not ?pb.getField(4, msg.channelId): + return err(ProtobufError.missingRequiredField("channelId")) if not ?pb.getField(5, msg.content): return err(ProtobufError.missingRequiredField("content")) @@ -51,6 +47,17 @@ proc decode*(T: type SdsMessage, buffer: seq[byte]): ProtobufResult[T] = ok(msg) +proc extractChannelId*(data: seq[byte]): Result[SdsChannelID, ReliabilityError] = + ## For extraction of channel ID without full message deserialization + try: + let pb = initProtoBuffer(data) + var channelId: SdsChannelID + if not pb.getField(4, channelId).get(): + return err(ReliabilityError.reDeserializationError) + ok(channelId) + except: + err(ReliabilityError.reDeserializationError) + proc serializeMessage*(msg: SdsMessage): Result[seq[byte], ReliabilityError] = let pb = encode(msg) ok(pb.buffer) diff --git a/src/reliability.nim b/src/reliability.nim index 97a39a9..38be506 100644 --- a/src/reliability.nim +++ b/src/reliability.nim @@ -3,30 +3,18 @@ import chronos, results, chronicles import ./[message, protobuf, reliability_utils, rolling_bloom_filter] proc newReliabilityManager*( - channelId: Option[SdsChannelID], config: ReliabilityConfig = defaultConfig() + config: ReliabilityConfig = defaultConfig() ): Result[ReliabilityManager, ReliabilityError] = - ## Creates a new ReliabilityManager with the specified channel ID and configuration. + ## Creates a new multi-channel ReliabilityManager. ## ## Parameters: - ## - channelId: A unique identifier for the communication channel. ## - config: Configuration options for the ReliabilityManager. If not provided, default configuration is used. ## ## Returns: ## A Result containing either a new ReliabilityManager instance or an error. - if not channelId.isSome(): - return err(ReliabilityError.reInvalidArgument) - try: - let bloomFilter = - newRollingBloomFilter(config.bloomFilterCapacity, config.bloomFilterErrorRate) - let rm = ReliabilityManager( - lamportTimestamp: 0, - messageHistory: @[], - bloomFilter: bloomFilter, - outgoingBuffer: @[], - incomingBuffer: initTable[SdsMessageID, IncomingMessage](), - channelId: channelId, + channels: initTable[SdsChannelID, ChannelContext](), config: config, ) initLock(rm.lock) @@ -48,7 +36,7 @@ proc isAcknowledged*( false -proc reviewAckStatus(rm: ReliabilityManager, msg: SdsMessage) {.gcsafe.} = +proc reviewAckStatus(rm: ReliabilityManager, msg: SdsMessage, channelId: SdsChannelID) {.gcsafe.} = # Parse bloom filter var rbf: Option[RollingBloomFilter] if msg.bloomFilter.len > 0: @@ -73,23 +61,27 @@ proc reviewAckStatus(rm: ReliabilityManager, msg: SdsMessage) {.gcsafe.} = else: rbf = none[RollingBloomFilter]() + if channelId notin rm.channels: + return + + let channel = rm.channels[channelId] # Keep track of indices to delete var toDelete: seq[int] = @[] var i = 0 - while i < rm.outgoingBuffer.len: - let outMsg = rm.outgoingBuffer[i] + while i < channel.outgoingBuffer.len: + let outMsg = channel.outgoingBuffer[i] if outMsg.isAcknowledged(msg.causalHistory, rbf): if not rm.onMessageSent.isNil(): - rm.onMessageSent(outMsg.message.messageId) + rm.onMessageSent(outMsg.message.messageId, channelId) toDelete.add(i) inc i for i in countdown(toDelete.high, 0): # Delete in reverse order to maintain indices - rm.outgoingBuffer.delete(toDelete[i]) + channel.outgoingBuffer.delete(toDelete[i]) proc wrapOutgoingMessage*( - rm: ReliabilityManager, message: seq[byte], messageId: SdsMessageID + rm: ReliabilityManager, message: seq[byte], messageId: SdsMessageID, channelId: SdsChannelID ): Result[seq[byte], ReliabilityError] = ## Wraps an outgoing message with reliability metadata. ## @@ -106,153 +98,144 @@ proc wrapOutgoingMessage*( withLock rm.lock: try: - rm.updateLamportTimestamp(getTime().toUnix) + let channel = rm.getOrCreateChannel(channelId) + rm.updateLamportTimestamp(getTime().toUnix, channelId) - let bfResult = serializeBloomFilter(rm.bloomFilter.filter) + let bfResult = serializeBloomFilter(channel.bloomFilter.filter) if bfResult.isErr: - error "Failed to serialize bloom filter" + error "Failed to serialize bloom filter", channelId = channelId return err(ReliabilityError.reSerializationError) let msg = SdsMessage( messageId: messageId, - lamportTimestamp: rm.lamportTimestamp, - causalHistory: rm.getRecentSdsMessageIDs(rm.config.maxCausalHistory), - channelId: rm.channelId, + lamportTimestamp: channel.lamportTimestamp, + causalHistory: rm.getRecentSdsMessageIDs(rm.config.maxCausalHistory, channelId), + channelId: channelId, content: message, bloomFilter: bfResult.get(), ) - # Add to outgoing buffer - rm.outgoingBuffer.add( + channel.outgoingBuffer.add( UnacknowledgedMessage(message: msg, sendTime: getTime(), resendAttempts: 0) ) # Add to causal history and bloom filter - rm.bloomFilter.add(msg.messageId) - rm.addToHistory(msg.messageId) + channel.bloomFilter.add(msg.messageId) + rm.addToHistory(msg.messageId, channelId) return serializeMessage(msg) except Exception: - error "Failed to wrap message", msg = getCurrentExceptionMsg() + error "Failed to wrap message", channelId = channelId, msg = getCurrentExceptionMsg() return err(ReliabilityError.reSerializationError) -proc processIncomingBuffer(rm: ReliabilityManager) {.gcsafe.} = - withLock rm.lock: - if rm.incomingBuffer.len == 0: - return +proc processIncomingBuffer(rm: ReliabilityManager, channelId: SdsChannelID) {.gcsafe.} = + if channelId notin rm.channels: + return - var processed = initHashSet[SdsMessageID]() - var readyToProcess = newSeq[SdsMessageID]() + let channel = rm.channels[channelId] + if channel.incomingBuffer.len == 0: + return - # Find initially ready messages - for msgId, entry in rm.incomingBuffer: - if entry.missingDeps.len == 0: - readyToProcess.add(msgId) + var processed = initHashSet[SdsMessageID]() + var readyToProcess = newSeq[SdsMessageID]() - while readyToProcess.len > 0: - let msgId = readyToProcess.pop() - if msgId in processed: - continue + for msgId, entry in channel.incomingBuffer: + if entry.missingDeps.len == 0: + readyToProcess.add(msgId) - if msgId in rm.incomingBuffer: - rm.addToHistory(msgId) - if not rm.onMessageReady.isNil(): - rm.onMessageReady(msgId) - processed.incl(msgId) + while readyToProcess.len > 0: + let msgId = readyToProcess.pop() + if msgId in processed: + continue - # Update dependencies for remaining messages - for remainingId, entry in rm.incomingBuffer: - if remainingId notin processed: - if msgId in entry.missingDeps: - rm.incomingBuffer[remainingId].missingDeps.excl(msgId) - if rm.incomingBuffer[remainingId].missingDeps.len == 0: - readyToProcess.add(remainingId) + if msgId in channel.incomingBuffer: + rm.addToHistory(msgId, channelId) + if not rm.onMessageReady.isNil(): + rm.onMessageReady(msgId, channelId) + processed.incl(msgId) - # Remove processed messages - for msgId in processed: - rm.incomingBuffer.del(msgId) + for remainingId, entry in channel.incomingBuffer: + if remainingId notin processed: + if msgId in entry.missingDeps: + channel.incomingBuffer[remainingId].missingDeps.excl(msgId) + if channel.incomingBuffer[remainingId].missingDeps.len == 0: + readyToProcess.add(remainingId) + + for msgId in processed: + channel.incomingBuffer.del(msgId) proc unwrapReceivedMessage*( rm: ReliabilityManager, message: seq[byte] -): Result[tuple[message: seq[byte], missingDeps: seq[SdsMessageID]], ReliabilityError] = - ## Unwraps a received message and processes its reliability metadata. - ## - ## Parameters: - ## - message: The received message bytes - ## - ## Returns: - ## A Result containing either tuple of (processed message, missing dependencies) or an error. +): Result[tuple[message: seq[byte], missingDeps: seq[SdsMessageID], channelId: SdsChannelID], ReliabilityError] = try: + let channelId = extractChannelId(message).valueOr: + return err(ReliabilityError.reDeserializationError) + let msg = deserializeMessage(message).valueOr: return err(ReliabilityError.reDeserializationError) - if msg.messageId in rm.messageHistory: - return ok((msg.content, @[])) + withLock rm.lock: + let channel = rm.getOrCreateChannel(channelId) - rm.bloomFilter.add(msg.messageId) + if msg.messageId in channel.messageHistory: + return ok((msg.content, @[], channelId)) - # Update Lamport timestamp - rm.updateLamportTimestamp(msg.lamportTimestamp) + channel.bloomFilter.add(msg.messageId) - # Review ACK status for outgoing messages - rm.reviewAckStatus(msg) + rm.updateLamportTimestamp(msg.lamportTimestamp, channelId) - var missingDeps = rm.checkDependencies(msg.causalHistory) + rm.reviewAckStatus(msg, channelId) - if missingDeps.len == 0: - # Check if any dependencies are still in incoming buffer - var depsInBuffer = false - for msgId, entry in rm.incomingBuffer.pairs(): - if msgId in msg.causalHistory: - depsInBuffer = true - break + var missingDeps = rm.checkDependencies(msg.causalHistory, channelId) - if depsInBuffer: - rm.incomingBuffer[msg.messageId] = - IncomingMessage(message: msg, missingDeps: initHashSet[SdsMessageID]()) + if missingDeps.len == 0: + var depsInBuffer = false + for msgId, entry in channel.incomingBuffer.pairs(): + if msgId in msg.causalHistory: + depsInBuffer = true + break + + if depsInBuffer: + channel.incomingBuffer[msg.messageId] = + IncomingMessage(message: msg, missingDeps: initHashSet[SdsMessageID]()) + else: + rm.addToHistory(msg.messageId, channelId) + rm.processIncomingBuffer(channelId) + if not rm.onMessageReady.isNil(): + rm.onMessageReady(msg.messageId, channelId) else: - # All dependencies met, add to history - rm.addToHistory(msg.messageId) - rm.processIncomingBuffer() - if not rm.onMessageReady.isNil(): - rm.onMessageReady(msg.messageId) - else: - rm.incomingBuffer[msg.messageId] = - IncomingMessage(message: msg, missingDeps: missingDeps.toHashSet()) - if not rm.onMissingDependencies.isNil(): - rm.onMissingDependencies(msg.messageId, missingDeps) + channel.incomingBuffer[msg.messageId] = + IncomingMessage(message: msg, missingDeps: missingDeps.toHashSet()) + if not rm.onMissingDependencies.isNil(): + rm.onMissingDependencies(msg.messageId, missingDeps, channelId) - return ok((msg.content, missingDeps)) + return ok((msg.content, missingDeps, channelId)) except Exception: error "Failed to unwrap message", msg = getCurrentExceptionMsg() return err(ReliabilityError.reDeserializationError) proc markDependenciesMet*( - rm: ReliabilityManager, messageIds: seq[SdsMessageID] + rm: ReliabilityManager, messageIds: seq[SdsMessageID], channelId: SdsChannelID ): Result[void, ReliabilityError] = - ## Marks the specified message dependencies as met. - ## - ## Parameters: - ## - messageIds: A sequence of message IDs to mark as met. - ## - ## Returns: - ## A Result indicating success or an error. try: - # Add all messageIds to bloom filter - for msgId in messageIds: - if not rm.bloomFilter.contains(msgId): - rm.bloomFilter.add(msgId) - # rm.addToHistory(msgId) -- not needed as this proc usually called when msg in long-term storage of application? + withLock rm.lock: + if channelId notin rm.channels: + return err(ReliabilityError.reInvalidArgument) - # Update any pending messages that depend on this one - for pendingId, entry in rm.incomingBuffer: - if msgId in entry.missingDeps: - rm.incomingBuffer[pendingId].missingDeps.excl(msgId) + let channel = rm.channels[channelId] - rm.processIncomingBuffer() - return ok() + for msgId in messageIds: + if not channel.bloomFilter.contains(msgId): + channel.bloomFilter.add(msgId) + + for pendingId, entry in channel.incomingBuffer: + if msgId in entry.missingDeps: + channel.incomingBuffer[pendingId].missingDeps.excl(msgId) + + rm.processIncomingBuffer(channelId) + return ok() except Exception: - error "Failed to mark dependencies as met", msg = getCurrentExceptionMsg() + error "Failed to mark dependencies as met", channelId = channelId, msg = getCurrentExceptionMsg() return err(ReliabilityError.reInternalError) proc setCallbacks*( @@ -275,28 +258,29 @@ proc setCallbacks*( rm.onMissingDependencies = onMissingDependencies rm.onPeriodicSync = onPeriodicSync -proc checkUnacknowledgedMessages(rm: ReliabilityManager) {.gcsafe.} = - ## Checks and processes unacknowledged messages in the outgoing buffer. - withLock rm.lock: - let now = getTime() - var newOutgoingBuffer: seq[UnacknowledgedMessage] = @[] +proc checkUnacknowledgedMessages(rm: ReliabilityManager, channelId: SdsChannelID) {.gcsafe.} = + if channelId notin rm.channels: + return - for unackMsg in rm.outgoingBuffer: - let elapsed = now - unackMsg.sendTime - if elapsed > rm.config.resendInterval: - # Time to attempt resend - if unackMsg.resendAttempts < rm.config.maxResendAttempts: - var updatedMsg = unackMsg - updatedMsg.resendAttempts += 1 - updatedMsg.sendTime = now - newOutgoingBuffer.add(updatedMsg) - else: - if not rm.onMessageSent.isNil(): - rm.onMessageSent(unackMsg.message.messageId) + let channel = rm.channels[channelId] + let now = getTime() + var newOutgoingBuffer: seq[UnacknowledgedMessage] = @[] + + for unackMsg in channel.outgoingBuffer: + let elapsed = now - unackMsg.sendTime + if elapsed > rm.config.resendInterval: + if unackMsg.resendAttempts < rm.config.maxResendAttempts: + var updatedMsg = unackMsg + updatedMsg.resendAttempts += 1 + updatedMsg.sendTime = now + newOutgoingBuffer.add(updatedMsg) else: - newOutgoingBuffer.add(unackMsg) + if not rm.onMessageSent.isNil(): + rm.onMessageSent(unackMsg.message.messageId, channelId) + else: + newOutgoingBuffer.add(unackMsg) - rm.outgoingBuffer = newOutgoingBuffer + channel.outgoingBuffer = newOutgoingBuffer proc periodicBufferSweep( rm: ReliabilityManager @@ -304,8 +288,13 @@ proc periodicBufferSweep( ## Periodically sweeps the buffer to clean up and check unacknowledged messages. while true: try: - rm.checkUnacknowledgedMessages() - rm.cleanBloomFilter() + withLock rm.lock: + for channelId, channel in rm.channels: + try: + rm.checkUnacknowledgedMessages(channelId) + rm.cleanBloomFilter(channelId) + except Exception: + error "Error in buffer sweep for channel", channelId = channelId, msg = getCurrentExceptionMsg() except Exception: error "Error in periodic buffer sweep", msg = getCurrentExceptionMsg() @@ -336,13 +325,15 @@ proc resetReliabilityManager*(rm: ReliabilityManager): Result[void, ReliabilityE ## This procedure clears all buffers and resets the Lamport timestamp. withLock rm.lock: try: - rm.lamportTimestamp = 0 - rm.messageHistory.setLen(0) - rm.outgoingBuffer.setLen(0) - rm.incomingBuffer.clear() - rm.bloomFilter = newRollingBloomFilter( - rm.config.bloomFilterCapacity, rm.config.bloomFilterErrorRate - ) + for channelId, channel in rm.channels: + channel.lamportTimestamp = 0 + channel.messageHistory.setLen(0) + channel.outgoingBuffer.setLen(0) + channel.incomingBuffer.clear() + channel.bloomFilter = newRollingBloomFilter( + rm.config.bloomFilterCapacity, rm.config.bloomFilterErrorRate + ) + rm.channels.clear() return ok() except Exception: error "Failed to reset ReliabilityManager", msg = getCurrentExceptionMsg() diff --git a/src/reliability_utils.nim b/src/reliability_utils.nim index 3cc23fa..c9001ea 100644 --- a/src/reliability_utils.nim +++ b/src/reliability_utils.nim @@ -1,14 +1,14 @@ -import std/[times, locks, options] -import chronicles +import std/[times, locks, tables] +import chronicles, results import ./[rolling_bloom_filter, message] type - MessageReadyCallback* = proc(messageId: SdsMessageID) {.gcsafe.} + MessageReadyCallback* = proc(messageId: SdsMessageID, channelId: SdsChannelID) {.gcsafe.} - MessageSentCallback* = proc(messageId: SdsMessageID) {.gcsafe.} + MessageSentCallback* = proc(messageId: SdsMessageID, channelId: SdsChannelID) {.gcsafe.} MissingDependenciesCallback* = - proc(messageId: SdsMessageID, missingDeps: seq[SdsMessageID]) {.gcsafe.} + proc(messageId: SdsMessageID, missingDeps: seq[SdsMessageID], channelId: SdsChannelID) {.gcsafe.} PeriodicSyncCallback* = proc() {.gcsafe, raises: [].} @@ -28,19 +28,21 @@ type syncMessageInterval*: Duration bufferSweepInterval*: Duration - ReliabilityManager* = ref object + ChannelContext* = ref object lamportTimestamp*: int64 messageHistory*: seq[SdsMessageID] bloomFilter*: RollingBloomFilter outgoingBuffer*: seq[UnacknowledgedMessage] incomingBuffer*: Table[SdsMessageID, IncomingMessage] - channelId*: Option[SdsChannelID] + + ReliabilityManager* = ref object + channels*: Table[SdsChannelID, ChannelContext] config*: ReliabilityConfig lock*: Lock - onMessageReady*: proc(messageId: SdsMessageID) {.gcsafe.} - onMessageSent*: proc(messageId: SdsMessageID) {.gcsafe.} + onMessageReady*: proc(messageId: SdsMessageID, channelId: SdsChannelID) {.gcsafe.} + onMessageSent*: proc(messageId: SdsMessageID, channelId: SdsChannelID) {.gcsafe.} onMissingDependencies*: - proc(messageId: SdsMessageID, missingDeps: seq[SdsMessageID]) {.gcsafe.} + proc(messageId: SdsMessageID, missingDeps: seq[SdsMessageID], channelId: SdsChannelID) {.gcsafe.} onPeriodicSync*: PeriodicSyncCallback ReliabilityError* {.pure.} = enum @@ -71,51 +73,139 @@ proc cleanup*(rm: ReliabilityManager) {.raises: [].} = if not rm.isNil(): try: withLock rm.lock: - rm.outgoingBuffer.setLen(0) - rm.incomingBuffer.clear() - rm.messageHistory.setLen(0) + for channelId, channel in rm.channels: + channel.outgoingBuffer.setLen(0) + channel.incomingBuffer.clear() + channel.messageHistory.setLen(0) + rm.channels.clear() except Exception: error "Error during cleanup", error = getCurrentExceptionMsg() -proc cleanBloomFilter*(rm: ReliabilityManager) {.gcsafe, raises: [].} = +proc cleanBloomFilter*(rm: ReliabilityManager, channelId: SdsChannelID) {.gcsafe, raises: [].} = withLock rm.lock: try: - rm.bloomFilter.clean() + if channelId in rm.channels: + rm.channels[channelId].bloomFilter.clean() except Exception: - error "Failed to clean bloom filter", error = getCurrentExceptionMsg() + error "Failed to clean bloom filter", error = getCurrentExceptionMsg(), channelId = channelId -proc addToHistory*(rm: ReliabilityManager, msgId: SdsMessageID) {.gcsafe, raises: [].} = - rm.messageHistory.add(msgId) - if rm.messageHistory.len > rm.config.maxMessageHistory: - rm.messageHistory.delete(0) +proc addToHistory*(rm: ReliabilityManager, msgId: SdsMessageID, channelId: SdsChannelID) {.gcsafe, raises: [].} = + try: + if channelId in rm.channels: + let channel = rm.channels[channelId] + channel.messageHistory.add(msgId) + if channel.messageHistory.len > rm.config.maxMessageHistory: + channel.messageHistory.delete(0) + except Exception: + error "Failed to add to history", channelId = channelId, msgId = msgId, error = getCurrentExceptionMsg() proc updateLamportTimestamp*( - rm: ReliabilityManager, msgTs: int64 + rm: ReliabilityManager, msgTs: int64, channelId: SdsChannelID ) {.gcsafe, raises: [].} = - rm.lamportTimestamp = max(msgTs, rm.lamportTimestamp) + 1 + try: + if channelId in rm.channels: + let channel = rm.channels[channelId] + channel.lamportTimestamp = max(msgTs, channel.lamportTimestamp) + 1 + except Exception: + error "Failed to update lamport timestamp", channelId = channelId, msgTs = msgTs, error = getCurrentExceptionMsg() -proc getRecentSdsMessageIDs*(rm: ReliabilityManager, n: int): seq[SdsMessageID] = - result = rm.messageHistory[max(0, rm.messageHistory.len - n) .. ^1] +proc getRecentSdsMessageIDs*(rm: ReliabilityManager, n: int, channelId: SdsChannelID): seq[SdsMessageID] = + try: + if channelId in rm.channels: + let channel = rm.channels[channelId] + result = channel.messageHistory[max(0, channel.messageHistory.len - n) .. ^1] + else: + result = @[] + except Exception: + error "Failed to get recent message IDs", channelId = channelId, n = n, error = getCurrentExceptionMsg() + result = @[] proc checkDependencies*( - rm: ReliabilityManager, deps: seq[SdsMessageID] + rm: ReliabilityManager, deps: seq[SdsMessageID], channelId: SdsChannelID ): seq[SdsMessageID] = var missingDeps: seq[SdsMessageID] = @[] - for depId in deps: - if depId notin rm.messageHistory: - missingDeps.add(depId) + try: + if channelId in rm.channels: + let channel = rm.channels[channelId] + for depId in deps: + if depId notin channel.messageHistory: + missingDeps.add(depId) + else: + missingDeps = deps + except Exception: + error "Failed to check dependencies", channelId = channelId, error = getCurrentExceptionMsg() + missingDeps = deps return missingDeps -proc getMessageHistory*(rm: ReliabilityManager): seq[SdsMessageID] = +proc getMessageHistory*(rm: ReliabilityManager, channelId: SdsChannelID): seq[SdsMessageID] = withLock rm.lock: - result = rm.messageHistory + try: + if channelId in rm.channels: + result = rm.channels[channelId].messageHistory + else: + result = @[] + except Exception: + error "Failed to get message history", channelId = channelId, error = getCurrentExceptionMsg() + result = @[] -proc getOutgoingBuffer*(rm: ReliabilityManager): seq[UnacknowledgedMessage] = +proc getOutgoingBuffer*(rm: ReliabilityManager, channelId: SdsChannelID): seq[UnacknowledgedMessage] = withLock rm.lock: - result = rm.outgoingBuffer + try: + if channelId in rm.channels: + result = rm.channels[channelId].outgoingBuffer + else: + result = @[] + except Exception: + error "Failed to get outgoing buffer", channelId = channelId, error = getCurrentExceptionMsg() + result = @[] proc getIncomingBuffer*( - rm: ReliabilityManager + rm: ReliabilityManager, channelId: SdsChannelID ): Table[SdsMessageID, message.IncomingMessage] = withLock rm.lock: - result = rm.incomingBuffer + try: + if channelId in rm.channels: + result = rm.channels[channelId].incomingBuffer + else: + result = initTable[SdsMessageID, message.IncomingMessage]() + except Exception: + error "Failed to get incoming buffer", channelId = channelId, error = getCurrentExceptionMsg() + result = initTable[SdsMessageID, message.IncomingMessage]() + +proc getOrCreateChannel*(rm: ReliabilityManager, channelId: SdsChannelID): ChannelContext = + try: + if channelId notin rm.channels: + rm.channels[channelId] = ChannelContext( + lamportTimestamp: 0, + messageHistory: @[], + bloomFilter: newRollingBloomFilter(rm.config.bloomFilterCapacity, rm.config.bloomFilterErrorRate), + outgoingBuffer: @[], + incomingBuffer: initTable[SdsMessageID, IncomingMessage]() + ) + result = rm.channels[channelId] + except Exception: + error "Failed to get or create channel", channelId = channelId, error = getCurrentExceptionMsg() + raise + +proc ensureChannel*(rm: ReliabilityManager, channelId: SdsChannelID): Result[void, ReliabilityError] = + withLock rm.lock: + try: + discard rm.getOrCreateChannel(channelId) + return ok() + except Exception: + error "Failed to ensure channel", channelId = channelId, msg = getCurrentExceptionMsg() + return err(ReliabilityError.reInternalError) + +proc removeChannel*(rm: ReliabilityManager, channelId: SdsChannelID): Result[void, ReliabilityError] = + withLock rm.lock: + try: + if channelId in rm.channels: + let channel = rm.channels[channelId] + channel.outgoingBuffer.setLen(0) + channel.incomingBuffer.clear() + channel.messageHistory.setLen(0) + rm.channels.del(channelId) + return ok() + except Exception: + error "Failed to remove channel", channelId = channelId, msg = getCurrentExceptionMsg() + return err(ReliabilityError.reInternalError)