diff --git a/src/reliability.nim b/src/reliability.nim index 38be506..a39fac3 100644 --- a/src/reliability.nim +++ b/src/reliability.nim @@ -14,8 +14,7 @@ proc newReliabilityManager*( ## A Result containing either a new ReliabilityManager instance or an error. try: let rm = ReliabilityManager( - channels: initTable[SdsChannelID, ChannelContext](), - config: config, + channels: initTable[SdsChannelID, ChannelContext](), config: config ) initLock(rm.lock) return ok(rm) @@ -36,7 +35,7 @@ proc isAcknowledged*( false -proc reviewAckStatus(rm: ReliabilityManager, msg: SdsMessage, channelId: SdsChannelID) {.gcsafe.} = +proc reviewAckStatus(rm: ReliabilityManager, msg: SdsMessage) {.gcsafe.} = # Parse bloom filter var rbf: Option[RollingBloomFilter] if msg.bloomFilter.len > 0: @@ -61,10 +60,10 @@ proc reviewAckStatus(rm: ReliabilityManager, msg: SdsMessage, channelId: SdsChan else: rbf = none[RollingBloomFilter]() - if channelId notin rm.channels: + if msg.channelId notin rm.channels: return - let channel = rm.channels[channelId] + let channel = rm.channels[msg.channelId] # Keep track of indices to delete var toDelete: seq[int] = @[] var i = 0 @@ -73,7 +72,7 @@ proc reviewAckStatus(rm: ReliabilityManager, msg: SdsMessage, channelId: SdsChan let outMsg = channel.outgoingBuffer[i] if outMsg.isAcknowledged(msg.causalHistory, rbf): if not rm.onMessageSent.isNil(): - rm.onMessageSent(outMsg.message.messageId, channelId) + rm.onMessageSent(outMsg.message.messageId, outMsg.message.channelId) toDelete.add(i) inc i @@ -81,13 +80,17 @@ proc reviewAckStatus(rm: ReliabilityManager, msg: SdsMessage, channelId: SdsChan channel.outgoingBuffer.delete(toDelete[i]) proc wrapOutgoingMessage*( - rm: ReliabilityManager, message: seq[byte], messageId: SdsMessageID, channelId: SdsChannelID + rm: ReliabilityManager, + message: seq[byte], + messageId: SdsMessageID, + channelId: SdsChannelID, ): Result[seq[byte], ReliabilityError] = ## Wraps an outgoing message with reliability metadata. ## ## Parameters: ## - message: The content of the message to be sent. ## - messageId: Unique identifier for the message + ## - channelId: Identifier for the channel this message belongs to. ## ## Returns: ## A Result containing either wrapped message bytes or an error. @@ -125,48 +128,64 @@ proc wrapOutgoingMessage*( return serializeMessage(msg) except Exception: - error "Failed to wrap message", channelId = channelId, msg = getCurrentExceptionMsg() + error "Failed to wrap message", + channelId = channelId, msg = getCurrentExceptionMsg() return err(ReliabilityError.reSerializationError) proc processIncomingBuffer(rm: ReliabilityManager, channelId: SdsChannelID) {.gcsafe.} = - if channelId notin rm.channels: - return + withLock rm.lock: + if channelId notin rm.channels: + error "Channel does not exist", channelId = channelId + return - let channel = rm.channels[channelId] - if channel.incomingBuffer.len == 0: - return + let channel = rm.channels[channelId] + if channel.incomingBuffer.len == 0: + return - var processed = initHashSet[SdsMessageID]() - var readyToProcess = newSeq[SdsMessageID]() + var processed = initHashSet[SdsMessageID]() + var readyToProcess = newSeq[SdsMessageID]() - for msgId, entry in channel.incomingBuffer: - if entry.missingDeps.len == 0: - readyToProcess.add(msgId) + # Find initially ready messages + for msgId, entry in channel.incomingBuffer: + if entry.missingDeps.len == 0: + readyToProcess.add(msgId) - while readyToProcess.len > 0: - let msgId = readyToProcess.pop() - if msgId in processed: - continue + while readyToProcess.len > 0: + let msgId = readyToProcess.pop() + if msgId in processed: + continue - if msgId in channel.incomingBuffer: - rm.addToHistory(msgId, channelId) - if not rm.onMessageReady.isNil(): - rm.onMessageReady(msgId, channelId) - processed.incl(msgId) + if msgId in channel.incomingBuffer: + rm.addToHistory(msgId, channelId) + if not rm.onMessageReady.isNil(): + rm.onMessageReady(msgId, channelId) + processed.incl(msgId) - for remainingId, entry in channel.incomingBuffer: - if remainingId notin processed: - if msgId in entry.missingDeps: - channel.incomingBuffer[remainingId].missingDeps.excl(msgId) - if channel.incomingBuffer[remainingId].missingDeps.len == 0: - readyToProcess.add(remainingId) + # Update dependencies for remaining messages + for remainingId, entry in channel.incomingBuffer: + if remainingId notin processed: + if msgId in entry.missingDeps: + channel.incomingBuffer[remainingId].missingDeps.excl(msgId) + if channel.incomingBuffer[remainingId].missingDeps.len == 0: + readyToProcess.add(remainingId) - for msgId in processed: - channel.incomingBuffer.del(msgId) + # Remove processed messages + for msgId in processed: + channel.incomingBuffer.del(msgId) proc unwrapReceivedMessage*( rm: ReliabilityManager, message: seq[byte] -): Result[tuple[message: seq[byte], missingDeps: seq[SdsMessageID], channelId: SdsChannelID], ReliabilityError] = +): Result[ + tuple[message: seq[byte], missingDeps: seq[SdsMessageID], channelId: SdsChannelID], + ReliabilityError, +] = + ## Unwraps a received message and processes its reliability metadata. + ## + ## Parameters: + ## - message: The received message bytes + ## + ## Returns: + ## A Result containing either tuple of (processed message, missing dependencies, channel ID) or an error. try: let channelId = extractChannelId(message).valueOr: return err(ReliabilityError.reDeserializationError) @@ -174,42 +193,42 @@ proc unwrapReceivedMessage*( let msg = deserializeMessage(message).valueOr: return err(ReliabilityError.reDeserializationError) - withLock rm.lock: - let channel = rm.getOrCreateChannel(channelId) + let channel = rm.getOrCreateChannel(channelId) - if msg.messageId in channel.messageHistory: - return ok((msg.content, @[], channelId)) + if msg.messageId in channel.messageHistory: + return ok((msg.content, @[], channelId)) - channel.bloomFilter.add(msg.messageId) + channel.bloomFilter.add(msg.messageId) - rm.updateLamportTimestamp(msg.lamportTimestamp, channelId) + rm.updateLamportTimestamp(msg.lamportTimestamp, channelId) + # Review ACK status for outgoing messages + rm.reviewAckStatus(msg) - rm.reviewAckStatus(msg, channelId) + var missingDeps = rm.checkDependencies(msg.causalHistory, channelId) - var missingDeps = rm.checkDependencies(msg.causalHistory, channelId) - - if missingDeps.len == 0: - var depsInBuffer = false - for msgId, entry in channel.incomingBuffer.pairs(): - if msgId in msg.causalHistory: - depsInBuffer = true - break - - if depsInBuffer: - channel.incomingBuffer[msg.messageId] = - IncomingMessage(message: msg, missingDeps: initHashSet[SdsMessageID]()) - else: - rm.addToHistory(msg.messageId, channelId) - rm.processIncomingBuffer(channelId) - if not rm.onMessageReady.isNil(): - rm.onMessageReady(msg.messageId, channelId) - else: + if missingDeps.len == 0: + var depsInBuffer = false + for msgId, entry in channel.incomingBuffer.pairs(): + if msgId in msg.causalHistory: + depsInBuffer = true + break + # Check if any dependencies are still in incoming buffer + if depsInBuffer: channel.incomingBuffer[msg.messageId] = - IncomingMessage(message: msg, missingDeps: missingDeps.toHashSet()) - if not rm.onMissingDependencies.isNil(): - rm.onMissingDependencies(msg.messageId, missingDeps, channelId) + IncomingMessage(message: msg, missingDeps: initHashSet[SdsMessageID]()) + else: + # All dependencies met, add to history + rm.addToHistory(msg.messageId, channelId) + rm.processIncomingBuffer(channelId) + if not rm.onMessageReady.isNil(): + rm.onMessageReady(msg.messageId, channelId) + else: + channel.incomingBuffer[msg.messageId] = + IncomingMessage(message: msg, missingDeps: missingDeps.toHashSet()) + if not rm.onMissingDependencies.isNil(): + rm.onMissingDependencies(msg.messageId, missingDeps, channelId) - return ok((msg.content, missingDeps, channelId)) + return ok((msg.content, missingDeps, channelId)) except Exception: error "Failed to unwrap message", msg = getCurrentExceptionMsg() return err(ReliabilityError.reDeserializationError) @@ -217,25 +236,33 @@ proc unwrapReceivedMessage*( proc markDependenciesMet*( rm: ReliabilityManager, messageIds: seq[SdsMessageID], channelId: SdsChannelID ): Result[void, ReliabilityError] = + ## Marks the specified message dependencies as met. + ## + ## Parameters: + ## - messageIds: A sequence of message IDs to mark as met. + ## - channelId: Identifier for the channel. + ## + ## Returns: + ## A Result indicating success or an error. try: - withLock rm.lock: - if channelId notin rm.channels: - return err(ReliabilityError.reInvalidArgument) + if channelId notin rm.channels: + return err(ReliabilityError.reInvalidArgument) - let channel = rm.channels[channelId] + let channel = rm.channels[channelId] - for msgId in messageIds: - if not channel.bloomFilter.contains(msgId): - channel.bloomFilter.add(msgId) + for msgId in messageIds: + if not channel.bloomFilter.contains(msgId): + channel.bloomFilter.add(msgId) - for pendingId, entry in channel.incomingBuffer: - if msgId in entry.missingDeps: - channel.incomingBuffer[pendingId].missingDeps.excl(msgId) + for pendingId, entry in channel.incomingBuffer: + if msgId in entry.missingDeps: + channel.incomingBuffer[pendingId].missingDeps.excl(msgId) - rm.processIncomingBuffer(channelId) - return ok() + rm.processIncomingBuffer(channelId) + return ok() except Exception: - error "Failed to mark dependencies as met", channelId = channelId, msg = getCurrentExceptionMsg() + error "Failed to mark dependencies as met", + channelId = channelId, msg = getCurrentExceptionMsg() return err(ReliabilityError.reInternalError) proc setCallbacks*( @@ -258,29 +285,34 @@ proc setCallbacks*( rm.onMissingDependencies = onMissingDependencies rm.onPeriodicSync = onPeriodicSync -proc checkUnacknowledgedMessages(rm: ReliabilityManager, channelId: SdsChannelID) {.gcsafe.} = - if channelId notin rm.channels: - return +proc checkUnacknowledgedMessages( + rm: ReliabilityManager, channelId: SdsChannelID +) {.gcsafe.} = + ## Checks and processes unacknowledged messages in the outgoing buffer. + withLock rm.lock: + if channelId notin rm.channels: + error "Channel does not exist", channelId = channelId + return - let channel = rm.channels[channelId] - let now = getTime() - var newOutgoingBuffer: seq[UnacknowledgedMessage] = @[] + let channel = rm.channels[channelId] + let now = getTime() + var newOutgoingBuffer: seq[UnacknowledgedMessage] = @[] - for unackMsg in channel.outgoingBuffer: - let elapsed = now - unackMsg.sendTime - if elapsed > rm.config.resendInterval: - if unackMsg.resendAttempts < rm.config.maxResendAttempts: - var updatedMsg = unackMsg - updatedMsg.resendAttempts += 1 - updatedMsg.sendTime = now - newOutgoingBuffer.add(updatedMsg) + for unackMsg in channel.outgoingBuffer: + let elapsed = now - unackMsg.sendTime + if elapsed > rm.config.resendInterval: + if unackMsg.resendAttempts < rm.config.maxResendAttempts: + var updatedMsg = unackMsg + updatedMsg.resendAttempts += 1 + updatedMsg.sendTime = now + newOutgoingBuffer.add(updatedMsg) + else: + if not rm.onMessageSent.isNil(): + rm.onMessageSent(unackMsg.message.messageId, channelId) else: - if not rm.onMessageSent.isNil(): - rm.onMessageSent(unackMsg.message.messageId, channelId) - else: - newOutgoingBuffer.add(unackMsg) + newOutgoingBuffer.add(unackMsg) - channel.outgoingBuffer = newOutgoingBuffer + channel.outgoingBuffer = newOutgoingBuffer proc periodicBufferSweep( rm: ReliabilityManager @@ -288,13 +320,13 @@ proc periodicBufferSweep( ## Periodically sweeps the buffer to clean up and check unacknowledged messages. while true: try: - withLock rm.lock: - for channelId, channel in rm.channels: - try: - rm.checkUnacknowledgedMessages(channelId) - rm.cleanBloomFilter(channelId) - except Exception: - error "Error in buffer sweep for channel", channelId = channelId, msg = getCurrentExceptionMsg() + for channelId, channel in rm.channels: + try: + rm.checkUnacknowledgedMessages(channelId) + rm.cleanBloomFilter(channelId) + except Exception: + error "Error in buffer sweep for channel", + channelId = channelId, msg = getCurrentExceptionMsg() except Exception: error "Error in periodic buffer sweep", msg = getCurrentExceptionMsg() diff --git a/src/reliability_utils.nim b/src/reliability_utils.nim index c9001ea..28248da 100644 --- a/src/reliability_utils.nim +++ b/src/reliability_utils.nim @@ -3,12 +3,15 @@ import chronicles, results import ./[rolling_bloom_filter, message] type - MessageReadyCallback* = proc(messageId: SdsMessageID, channelId: SdsChannelID) {.gcsafe.} + MessageReadyCallback* = + proc(messageId: SdsMessageID, channelId: SdsChannelID) {.gcsafe.} - MessageSentCallback* = proc(messageId: SdsMessageID, channelId: SdsChannelID) {.gcsafe.} + MessageSentCallback* = + proc(messageId: SdsMessageID, channelId: SdsChannelID) {.gcsafe.} - MissingDependenciesCallback* = - proc(messageId: SdsMessageID, missingDeps: seq[SdsMessageID], channelId: SdsChannelID) {.gcsafe.} + MissingDependenciesCallback* = proc( + messageId: SdsMessageID, missingDeps: seq[SdsMessageID], channelId: SdsChannelID + ) {.gcsafe.} PeriodicSyncCallback* = proc() {.gcsafe, raises: [].} @@ -41,8 +44,9 @@ type lock*: Lock onMessageReady*: proc(messageId: SdsMessageID, channelId: SdsChannelID) {.gcsafe.} onMessageSent*: proc(messageId: SdsMessageID, channelId: SdsChannelID) {.gcsafe.} - onMissingDependencies*: - proc(messageId: SdsMessageID, missingDeps: seq[SdsMessageID], channelId: SdsChannelID) {.gcsafe.} + onMissingDependencies*: proc( + messageId: SdsMessageID, missingDeps: seq[SdsMessageID], channelId: SdsChannelID + ) {.gcsafe.} onPeriodicSync*: PeriodicSyncCallback ReliabilityError* {.pure.} = enum @@ -81,15 +85,20 @@ proc cleanup*(rm: ReliabilityManager) {.raises: [].} = except Exception: error "Error during cleanup", error = getCurrentExceptionMsg() -proc cleanBloomFilter*(rm: ReliabilityManager, channelId: SdsChannelID) {.gcsafe, raises: [].} = +proc cleanBloomFilter*( + rm: ReliabilityManager, channelId: SdsChannelID +) {.gcsafe, raises: [].} = withLock rm.lock: try: if channelId in rm.channels: rm.channels[channelId].bloomFilter.clean() except Exception: - error "Failed to clean bloom filter", error = getCurrentExceptionMsg(), channelId = channelId + error "Failed to clean bloom filter", + error = getCurrentExceptionMsg(), channelId = channelId -proc addToHistory*(rm: ReliabilityManager, msgId: SdsMessageID, channelId: SdsChannelID) {.gcsafe, raises: [].} = +proc addToHistory*( + rm: ReliabilityManager, msgId: SdsMessageID, channelId: SdsChannelID +) {.gcsafe, raises: [].} = try: if channelId in rm.channels: let channel = rm.channels[channelId] @@ -97,7 +106,8 @@ proc addToHistory*(rm: ReliabilityManager, msgId: SdsMessageID, channelId: SdsCh if channel.messageHistory.len > rm.config.maxMessageHistory: channel.messageHistory.delete(0) except Exception: - error "Failed to add to history", channelId = channelId, msgId = msgId, error = getCurrentExceptionMsg() + error "Failed to add to history", + channelId = channelId, msgId = msgId, error = getCurrentExceptionMsg() proc updateLamportTimestamp*( rm: ReliabilityManager, msgTs: int64, channelId: SdsChannelID @@ -107,9 +117,12 @@ proc updateLamportTimestamp*( let channel = rm.channels[channelId] channel.lamportTimestamp = max(msgTs, channel.lamportTimestamp) + 1 except Exception: - error "Failed to update lamport timestamp", channelId = channelId, msgTs = msgTs, error = getCurrentExceptionMsg() + error "Failed to update lamport timestamp", + channelId = channelId, msgTs = msgTs, error = getCurrentExceptionMsg() -proc getRecentSdsMessageIDs*(rm: ReliabilityManager, n: int, channelId: SdsChannelID): seq[SdsMessageID] = +proc getRecentSdsMessageIDs*( + rm: ReliabilityManager, n: int, channelId: SdsChannelID +): seq[SdsMessageID] = try: if channelId in rm.channels: let channel = rm.channels[channelId] @@ -117,7 +130,8 @@ proc getRecentSdsMessageIDs*(rm: ReliabilityManager, n: int, channelId: SdsChann else: result = @[] except Exception: - error "Failed to get recent message IDs", channelId = channelId, n = n, error = getCurrentExceptionMsg() + error "Failed to get recent message IDs", + channelId = channelId, n = n, error = getCurrentExceptionMsg() result = @[] proc checkDependencies*( @@ -133,11 +147,14 @@ proc checkDependencies*( else: missingDeps = deps except Exception: - error "Failed to check dependencies", channelId = channelId, error = getCurrentExceptionMsg() + error "Failed to check dependencies", + channelId = channelId, error = getCurrentExceptionMsg() missingDeps = deps return missingDeps -proc getMessageHistory*(rm: ReliabilityManager, channelId: SdsChannelID): seq[SdsMessageID] = +proc getMessageHistory*( + rm: ReliabilityManager, channelId: SdsChannelID +): seq[SdsMessageID] = withLock rm.lock: try: if channelId in rm.channels: @@ -145,10 +162,13 @@ proc getMessageHistory*(rm: ReliabilityManager, channelId: SdsChannelID): seq[Sd else: result = @[] except Exception: - error "Failed to get message history", channelId = channelId, error = getCurrentExceptionMsg() + error "Failed to get message history", + channelId = channelId, error = getCurrentExceptionMsg() result = @[] -proc getOutgoingBuffer*(rm: ReliabilityManager, channelId: SdsChannelID): seq[UnacknowledgedMessage] = +proc getOutgoingBuffer*( + rm: ReliabilityManager, channelId: SdsChannelID +): seq[UnacknowledgedMessage] = withLock rm.lock: try: if channelId in rm.channels: @@ -156,7 +176,8 @@ proc getOutgoingBuffer*(rm: ReliabilityManager, channelId: SdsChannelID): seq[Un else: result = @[] except Exception: - error "Failed to get outgoing buffer", channelId = channelId, error = getCurrentExceptionMsg() + error "Failed to get outgoing buffer", + channelId = channelId, error = getCurrentExceptionMsg() result = @[] proc getIncomingBuffer*( @@ -169,34 +190,45 @@ proc getIncomingBuffer*( else: result = initTable[SdsMessageID, message.IncomingMessage]() except Exception: - error "Failed to get incoming buffer", channelId = channelId, error = getCurrentExceptionMsg() + error "Failed to get incoming buffer", + channelId = channelId, error = getCurrentExceptionMsg() result = initTable[SdsMessageID, message.IncomingMessage]() -proc getOrCreateChannel*(rm: ReliabilityManager, channelId: SdsChannelID): ChannelContext = +proc getOrCreateChannel*( + rm: ReliabilityManager, channelId: SdsChannelID +): ChannelContext = try: if channelId notin rm.channels: rm.channels[channelId] = ChannelContext( lamportTimestamp: 0, messageHistory: @[], - bloomFilter: newRollingBloomFilter(rm.config.bloomFilterCapacity, rm.config.bloomFilterErrorRate), + bloomFilter: newRollingBloomFilter( + rm.config.bloomFilterCapacity, rm.config.bloomFilterErrorRate + ), outgoingBuffer: @[], - incomingBuffer: initTable[SdsMessageID, IncomingMessage]() + incomingBuffer: initTable[SdsMessageID, IncomingMessage](), ) result = rm.channels[channelId] except Exception: - error "Failed to get or create channel", channelId = channelId, error = getCurrentExceptionMsg() + error "Failed to get or create channel", + channelId = channelId, error = getCurrentExceptionMsg() raise -proc ensureChannel*(rm: ReliabilityManager, channelId: SdsChannelID): Result[void, ReliabilityError] = +proc ensureChannel*( + rm: ReliabilityManager, channelId: SdsChannelID +): Result[void, ReliabilityError] = withLock rm.lock: try: discard rm.getOrCreateChannel(channelId) return ok() except Exception: - error "Failed to ensure channel", channelId = channelId, msg = getCurrentExceptionMsg() + error "Failed to ensure channel", + channelId = channelId, msg = getCurrentExceptionMsg() return err(ReliabilityError.reInternalError) -proc removeChannel*(rm: ReliabilityManager, channelId: SdsChannelID): Result[void, ReliabilityError] = +proc removeChannel*( + rm: ReliabilityManager, channelId: SdsChannelID +): Result[void, ReliabilityError] = withLock rm.lock: try: if channelId in rm.channels: @@ -207,5 +239,6 @@ proc removeChannel*(rm: ReliabilityManager, channelId: SdsChannelID): Result[voi rm.channels.del(channelId) return ok() except Exception: - error "Failed to remove channel", channelId = channelId, msg = getCurrentExceptionMsg() + error "Failed to remove channel", + channelId = channelId, msg = getCurrentExceptionMsg() return err(ReliabilityError.reInternalError)