From f0cd0c362415cf4f6b8440fc9a58cf703a1d6357 Mon Sep 17 00:00:00 2001 From: shash256 <111925100+shash256@users.noreply.github.com> Date: Mon, 17 Feb 2025 14:47:01 +0530 Subject: [PATCH 1/3] feat: add reliability.nim --- src/reliability.nim | 346 ++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 346 insertions(+) create mode 100644 src/reliability.nim diff --git a/src/reliability.nim b/src/reliability.nim new file mode 100644 index 0000000..0cc490e --- /dev/null +++ b/src/reliability.nim @@ -0,0 +1,346 @@ +import std/[times, locks, tables, sets] +import chronos, results +import ../src/[message, protobuf, reliability_utils, rolling_bloom_filter] + +proc newReliabilityManager*( + channelId: string, config: ReliabilityConfig = defaultConfig() +): Result[ReliabilityManager, ReliabilityError] = + ## Creates a new ReliabilityManager with the specified channel ID and configuration. + ## + ## Parameters: + ## - channelId: A unique identifier for the communication channel. + ## - config: Configuration options for the ReliabilityManager. If not provided, default configuration is used. + ## + ## Returns: + ## A Result containing either a new ReliabilityManager instance or an error. + if channelId.len == 0: + return err(reInvalidArgument) + + try: + let bloomFilter = newRollingBloomFilter( + config.bloomFilterCapacity, config.bloomFilterErrorRate, config.bloomFilterWindow + ) + + let rm = ReliabilityManager( + lamportTimestamp: 0, + messageHistory: @[], + bloomFilter: bloomFilter, + outgoingBuffer: @[], + incomingBuffer: @[], + channelId: channelId, + config: config, + ) + initLock(rm.lock) + return ok(rm) + except: + return err(reOutOfMemory) + +proc reviewAckStatus(rm: ReliabilityManager, msg: Message) = + var i = 0 + while i < rm.outgoingBuffer.len: + var acknowledged = false + let outMsg = rm.outgoingBuffer[i] + + # Check if message is in causal history + for msgID in msg.causalHistory: + if outMsg.message.messageId == msgID: + acknowledged = true + break + + # Check bloom filter if not already acknowledged + if not acknowledged and msg.bloomFilter.len > 0: + let bfResult = deserializeBloomFilter(msg.bloomFilter) + if bfResult.isOk: + var rbf = RollingBloomFilter( + filter: bfResult.get(), window: rm.bloomFilter.window, messages: @[] + ) + if rbf.contains(outMsg.message.messageId): + acknowledged = true + else: + logError("Failed to deserialize bloom filter") + + if acknowledged: + if rm.onMessageSent != nil: + rm.onMessageSent(outMsg.message.messageId) + rm.outgoingBuffer.delete(i) + else: + inc i + +proc wrapOutgoingMessage*( + rm: ReliabilityManager, message: seq[byte], messageId: MessageID +): Result[seq[byte], ReliabilityError] = + ## Wraps an outgoing message with reliability metadata. + ## + ## Parameters: + ## - message: The content of the message to be sent. + ## + ## Returns: + ## A Result containing either a Message object with reliability metadata or an error. + if message.len == 0: + return err(reInvalidArgument) + if message.len > MaxMessageSize: + return err(reMessageTooLarge) + + withLock rm.lock: + try: + rm.updateLamportTimestamp(getTime().toUnix) + + # Serialize current bloom filter + var bloomBytes: seq[byte] + let bfResult = serializeBloomFilter(rm.bloomFilter.filter) + if bfResult.isErr: + logError("Failed to serialize bloom filter") + bloomBytes = @[] + else: + bloomBytes = bfResult.get() + + let msg = Message( + messageId: messageId, + lamportTimestamp: rm.lamportTimestamp, + causalHistory: rm.getRecentMessageIDs(rm.config.maxCausalHistory), + channelId: rm.channelId, + content: message, + bloomFilter: bloomBytes, + ) + + # Add to outgoing buffer + rm.outgoingBuffer.add( + UnacknowledgedMessage(message: msg, sendTime: getTime(), resendAttempts: 0) + ) + + # Add to causal history and bloom filter + rm.bloomFilter.add(msg.messageId) + rm.addToHistory(msg.messageId) + + return serializeMessage(msg) + except: + return err(reInternalError) + +proc processIncomingBuffer(rm: ReliabilityManager) = + withLock rm.lock: + if rm.incomingBuffer.len == 0: + return + + # Create dependency map + var dependencies = initTable[MessageID, seq[MessageID]]() + var readyToProcess: seq[MessageID] = @[] + + # Build dependency graph and find initially ready messages + for msg in rm.incomingBuffer: + var hasMissingDeps = false + for depId in msg.causalHistory: + if not rm.bloomFilter.contains(depId): + hasMissingDeps = true + if depId notin dependencies: + dependencies[depId] = @[] + dependencies[depId].add(msg.messageId) + + if not hasMissingDeps: + readyToProcess.add(msg.messageId) + + # Process ready messages and their dependents + var newIncomingBuffer: seq[Message] = @[] + var processed = initHashSet[MessageID]() + + while readyToProcess.len > 0: + let msgId = readyToProcess.pop() + if msgId in processed: + continue + + # Process this message + for msg in rm.incomingBuffer: + if msg.messageId == msgId: + rm.addToHistory(msg.messageId) + if rm.onMessageReady != nil: + rm.onMessageReady(msg.messageId) + processed.incl(msgId) + + # Add any dependent messages that might now be ready + if msgId in dependencies: + for dependentId in dependencies[msgId]: + readyToProcess.add(dependentId) + break + + # Update incomingBuffer with remaining messages + for msg in rm.incomingBuffer: + if msg.messageId notin processed: + newIncomingBuffer.add(msg) + + rm.incomingBuffer = newIncomingBuffer + +proc unwrapReceivedMessage*( + rm: ReliabilityManager, message: seq[byte] +): Result[tuple[message: seq[byte], missingDeps: seq[MessageID]], ReliabilityError] = + ## Unwraps a received message and processes its reliability metadata. + ## + ## Parameters: + ## - message: The received Message object. + ## + ## Returns: + ## A Result containing either a tuple with the processed message and missing dependencies, or an error. + try: + let msgResult = deserializeMessage(message) + if not msgResult.isOk: + return err(msgResult.error) + + let msg = msgResult.get + if rm.bloomFilter.contains(msg.messageId): + return ok((msg.content, @[])) + + rm.bloomFilter.add(msg.messageId) + + # Update Lamport timestamp + rm.updateLamportTimestamp(msg.lamportTimestamp) + + # Review ACK status for outgoing messages + rm.reviewAckStatus(msg) + + var missingDeps: seq[MessageID] = @[] + for depId in msg.causalHistory: + if not rm.bloomFilter.contains(depId): + missingDeps.add(depId) + + if missingDeps.len == 0: + # Check if any dependencies are still in incoming buffer + var depsInBuffer = false + for bufferedMsg in rm.incomingBuffer: + if bufferedMsg.messageId in msg.causalHistory: + depsInBuffer = true + break + if depsInBuffer: + rm.incomingBuffer.add(msg) + else: + # All dependencies met, add to history + rm.addToHistory(msg.messageId) + rm.processIncomingBuffer() + if rm.onMessageReady != nil: + rm.onMessageReady(msg.messageId) + else: + # Buffer message and request missing dependencies + rm.incomingBuffer.add(msg) + if rm.onMissingDependencies != nil: + rm.onMissingDependencies(msg.messageId, missingDeps) + + return ok((msg.content, missingDeps)) + except: + return err(reInternalError) + +proc markDependenciesMet*( + rm: ReliabilityManager, messageIds: seq[MessageID] +): Result[void, ReliabilityError] = + ## Marks the specified message dependencies as met. + ## + ## Parameters: + ## - messageIds: A sequence of message IDs to mark as met. + ## + ## Returns: + ## A Result indicating success or an error. + try: + # Add all messageIds to bloom filter + for msgId in messageIds: + if not rm.bloomFilter.contains(msgId): + rm.bloomFilter.add(msgId) + # rm.addToHistory(msgId) -- not needed as this proc usually called when msg in long-term storage of application? + rm.processIncomingBuffer() + + return ok() + except: + return err(reInternalError) + +proc setCallbacks*( + rm: ReliabilityManager, + onMessageReady: proc(messageId: MessageID) {.gcsafe.}, + onMessageSent: proc(messageId: MessageID) {.gcsafe.}, + onMissingDependencies: + proc(messageId: MessageID, missingDeps: seq[MessageID]) {.gcsafe.}, + onPeriodicSync: PeriodicSyncCallback = nil, +) = + ## Sets the callback functions for various events in the ReliabilityManager. + ## + ## Parameters: + ## - onMessageReady: Callback function called when a message is ready to be processed. + ## - onMessageSent: Callback function called when a message is confirmed as sent. + ## - onMissingDependencies: Callback function called when a message has missing dependencies. + ## - onPeriodicSync: Callback function called to notify about periodic sync + withLock rm.lock: + rm.onMessageReady = onMessageReady + rm.onMessageSent = onMessageSent + rm.onMissingDependencies = onMissingDependencies + rm.onPeriodicSync = onPeriodicSync + +proc checkUnacknowledgedMessages*(rm: ReliabilityManager) {.raises: [].} = + ## Checks and processes unacknowledged messages in the outgoing buffer. + withLock rm.lock: + let now = getTime() + var newOutgoingBuffer: seq[UnacknowledgedMessage] = @[] + + try: + for unackMsg in rm.outgoingBuffer: + let elapsed = now - unackMsg.sendTime + if elapsed > rm.config.resendInterval: + # Time to attempt resend + if unackMsg.resendAttempts < rm.config.maxResendAttempts: + var updatedMsg = unackMsg + updatedMsg.resendAttempts += 1 + updatedMsg.sendTime = now + newOutgoingBuffer.add(updatedMsg) + else: + if rm.onMessageSent != nil: + rm.onMessageSent(unackMsg.message.messageId) + else: + newOutgoingBuffer.add(unackMsg) + + rm.outgoingBuffer = newOutgoingBuffer + except Exception as e: + logError("Error in checking unacknowledged messages: " & e.msg) + +proc periodicBufferSweep(rm: ReliabilityManager) {.async: (raises: [CancelledError]).} = + ## Periodically sweeps the buffer to clean up and check unacknowledged messages. + while true: + {.gcsafe.}: + try: + rm.checkUnacknowledgedMessages() + rm.cleanBloomFilter() + except Exception as e: + logError("Error in periodic buffer sweep: " & e.msg) + + await sleepAsync(chronos.milliseconds(rm.config.bufferSweepInterval.inMilliseconds)) + +proc periodicSyncMessage(rm: ReliabilityManager) {.async: (raises: [CancelledError]).} = + ## Periodically notifies to send a sync message to maintain connectivity. + while true: + {.gcsafe.}: + try: + if rm.onPeriodicSync != nil: + rm.onPeriodicSync() + except Exception as e: + logError("Error in periodic sync: " & e.msg) + await sleepAsync(chronos.seconds(rm.config.syncMessageInterval.inSeconds)) + +proc startPeriodicTasks*(rm: ReliabilityManager) = + ## Starts the periodic tasks for buffer sweeping and sync message sending. + ## + ## This procedure should be called after creating a ReliabilityManager to enable automatic maintenance. + asyncSpawn rm.periodicBufferSweep() + asyncSpawn rm.periodicSyncMessage() + +proc resetReliabilityManager*(rm: ReliabilityManager): Result[void, ReliabilityError] = + ## Resets the ReliabilityManager to its initial state. + ## + ## This procedure clears all buffers and resets the Lamport timestamp. + ## + ## Returns: + ## A Result indicating success or an error if the Bloom filter initialization fails. + withLock rm.lock: + try: + rm.lamportTimestamp = 0 + rm.messageHistory.setLen(0) + rm.outgoingBuffer.setLen(0) + rm.incomingBuffer.setLen(0) + rm.bloomFilter = newRollingBloomFilter( + rm.config.bloomFilterCapacity, rm.config.bloomFilterErrorRate, + rm.config.bloomFilterWindow, + ) + return ok() + except: + return err(reInternalError) From 3e25aec7ce27263886db58f31ec9824d33bdb0d8 Mon Sep 17 00:00:00 2001 From: shash256 <111925100+shash256@users.noreply.github.com> Date: Mon, 17 Feb 2025 16:16:08 +0530 Subject: [PATCH 2/3] chore: updates from prev suggestions --- src/reliability.nim | 203 +++++++++++++++++++++----------------------- 1 file changed, 98 insertions(+), 105 deletions(-) diff --git a/src/reliability.nim b/src/reliability.nim index 0cc490e..ebb533d 100644 --- a/src/reliability.nim +++ b/src/reliability.nim @@ -1,9 +1,9 @@ -import std/[times, locks, tables, sets] -import chronos, results -import ../src/[message, protobuf, reliability_utils, rolling_bloom_filter] +import std/[times, locks, tables, sets, sequtils] +import chronos, results, chronicles +import ./[message, protobuf, reliability_utils, rolling_bloom_filter] proc newReliabilityManager*( - channelId: string, config: ReliabilityConfig = defaultConfig() + channelId: SdsChannelID, config: ReliabilityConfig = defaultConfig() ): Result[ReliabilityManager, ReliabilityError] = ## Creates a new ReliabilityManager with the specified channel ID and configuration. ## @@ -14,12 +14,11 @@ proc newReliabilityManager*( ## Returns: ## A Result containing either a new ReliabilityManager instance or an error. if channelId.len == 0: - return err(reInvalidArgument) + return err(ReliabilityError.reInvalidArgument) try: - let bloomFilter = newRollingBloomFilter( - config.bloomFilterCapacity, config.bloomFilterErrorRate, config.bloomFilterWindow - ) + let bloomFilter = + newRollingBloomFilter(config.bloomFilterCapacity, config.bloomFilterErrorRate) let rm = ReliabilityManager( lamportTimestamp: 0, @@ -32,10 +31,11 @@ proc newReliabilityManager*( ) initLock(rm.lock) return ok(rm) - except: - return err(reOutOfMemory) + except Exception: + error "Failed to create ReliabilityManager", msg = getCurrentExceptionMsg() + return err(ReliabilityError.reOutOfMemory) -proc reviewAckStatus(rm: ReliabilityManager, msg: Message) = +proc reviewAckStatus(rm: ReliabilityManager, msg: SdsMessage) {.gcsafe.} = var i = 0 while i < rm.outgoingBuffer.len: var acknowledged = false @@ -50,57 +50,62 @@ proc reviewAckStatus(rm: ReliabilityManager, msg: Message) = # Check bloom filter if not already acknowledged if not acknowledged and msg.bloomFilter.len > 0: let bfResult = deserializeBloomFilter(msg.bloomFilter) - if bfResult.isOk: + if bfResult.isOk(): var rbf = RollingBloomFilter( - filter: bfResult.get(), window: rm.bloomFilter.window, messages: @[] + filter: bfResult.get(), + capacity: bfResult.get().capacity, + minCapacity: ( + bfResult.get().capacity.float * (100 - CapacityFlexPercent).float / 100.0 + ).int, + maxCapacity: ( + bfResult.get().capacity.float * (100 + CapacityFlexPercent).float / 100.0 + ).int, + messages: @[], ) if rbf.contains(outMsg.message.messageId): acknowledged = true else: - logError("Failed to deserialize bloom filter") + error "Failed to deserialize bloom filter", error = bfResult.error if acknowledged: - if rm.onMessageSent != nil: + if not rm.onMessageSent.isNil(): rm.onMessageSent(outMsg.message.messageId) rm.outgoingBuffer.delete(i) else: inc i proc wrapOutgoingMessage*( - rm: ReliabilityManager, message: seq[byte], messageId: MessageID + rm: ReliabilityManager, message: seq[byte], messageId: SdsMessageID ): Result[seq[byte], ReliabilityError] = ## Wraps an outgoing message with reliability metadata. ## ## Parameters: ## - message: The content of the message to be sent. + ## - messageId: Unique identifier for the message ## ## Returns: - ## A Result containing either a Message object with reliability metadata or an error. + ## A Result containing either wrapped message bytes or an error. if message.len == 0: - return err(reInvalidArgument) + return err(ReliabilityError.reInvalidArgument) if message.len > MaxMessageSize: - return err(reMessageTooLarge) + return err(ReliabilityError.reMessageTooLarge) withLock rm.lock: try: rm.updateLamportTimestamp(getTime().toUnix) - # Serialize current bloom filter - var bloomBytes: seq[byte] let bfResult = serializeBloomFilter(rm.bloomFilter.filter) if bfResult.isErr: - logError("Failed to serialize bloom filter") - bloomBytes = @[] - else: - bloomBytes = bfResult.get() + error "Failed to serialize bloom filter" + return err(ReliabilityError.reSerializationError) - let msg = Message( + let msg = SdsMessage( messageId: messageId, lamportTimestamp: rm.lamportTimestamp, - causalHistory: rm.getRecentMessageIDs(rm.config.maxCausalHistory), + causalHistory: rm.getRecentSdsMessageIDs(rm.config.maxCausalHistory), channelId: rm.channelId, content: message, - bloomFilter: bloomBytes, + bloomFilter: bfResult.get(), ) # Add to outgoing buffer @@ -113,17 +118,19 @@ proc wrapOutgoingMessage*( rm.addToHistory(msg.messageId) return serializeMessage(msg) - except: - return err(reInternalError) + except Exception: + error "Failed to wrap message", msg = getCurrentExceptionMsg() + return err(ReliabilityError.reSerializationError) -proc processIncomingBuffer(rm: ReliabilityManager) = +proc processIncomingBuffer(rm: ReliabilityManager) {.gcsafe.} = withLock rm.lock: if rm.incomingBuffer.len == 0: return # Create dependency map - var dependencies = initTable[MessageID, seq[MessageID]]() - var readyToProcess: seq[MessageID] = @[] + var dependencies = initTable[SdsMessageID, seq[SdsMessageID]]() + var readyToProcess: seq[SdsMessageID] = @[] + var processed = initHashSet[SdsMessageID]() # Build dependency graph and find initially ready messages for msg in rm.incomingBuffer: @@ -138,10 +145,6 @@ proc processIncomingBuffer(rm: ReliabilityManager) = if not hasMissingDeps: readyToProcess.add(msg.messageId) - # Process ready messages and their dependents - var newIncomingBuffer: seq[Message] = @[] - var processed = initHashSet[MessageID]() - while readyToProcess.len > 0: let msgId = readyToProcess.pop() if msgId in processed: @@ -151,39 +154,31 @@ proc processIncomingBuffer(rm: ReliabilityManager) = for msg in rm.incomingBuffer: if msg.messageId == msgId: rm.addToHistory(msg.messageId) - if rm.onMessageReady != nil: + if not rm.onMessageReady.isNil(): rm.onMessageReady(msg.messageId) processed.incl(msgId) # Add any dependent messages that might now be ready if msgId in dependencies: - for dependentId in dependencies[msgId]: - readyToProcess.add(dependentId) + readyToProcess.add(dependencies[msgId]) break - # Update incomingBuffer with remaining messages - for msg in rm.incomingBuffer: - if msg.messageId notin processed: - newIncomingBuffer.add(msg) - - rm.incomingBuffer = newIncomingBuffer + rm.incomingBuffer = rm.incomingBuffer.filterIt(it.messageId notin processed) proc unwrapReceivedMessage*( rm: ReliabilityManager, message: seq[byte] -): Result[tuple[message: seq[byte], missingDeps: seq[MessageID]], ReliabilityError] = +): Result[tuple[message: seq[byte], missingDeps: seq[SdsMessageID]], ReliabilityError] = ## Unwraps a received message and processes its reliability metadata. ## ## Parameters: - ## - message: The received Message object. + ## - message: The received message bytes ## ## Returns: - ## A Result containing either a tuple with the processed message and missing dependencies, or an error. + ## A Result containing either tuple of (processed message, missing dependencies) or an error. try: - let msgResult = deserializeMessage(message) - if not msgResult.isOk: - return err(msgResult.error) + let msg = deserializeMessage(message).valueOr: + return err(ReliabilityError.reDeserializationError) - let msg = msgResult.get if rm.bloomFilter.contains(msg.messageId): return ok((msg.content, @[])) @@ -195,7 +190,7 @@ proc unwrapReceivedMessage*( # Review ACK status for outgoing messages rm.reviewAckStatus(msg) - var missingDeps: seq[MessageID] = @[] + var missingDeps: seq[SdsMessageID] = @[] for depId in msg.causalHistory: if not rm.bloomFilter.contains(depId): missingDeps.add(depId) @@ -207,26 +202,27 @@ proc unwrapReceivedMessage*( if bufferedMsg.messageId in msg.causalHistory: depsInBuffer = true break + if depsInBuffer: rm.incomingBuffer.add(msg) else: # All dependencies met, add to history rm.addToHistory(msg.messageId) rm.processIncomingBuffer() - if rm.onMessageReady != nil: + if not rm.onMessageReady.isNil(): rm.onMessageReady(msg.messageId) else: - # Buffer message and request missing dependencies rm.incomingBuffer.add(msg) - if rm.onMissingDependencies != nil: + if not rm.onMissingDependencies.isNil(): rm.onMissingDependencies(msg.messageId, missingDeps) return ok((msg.content, missingDeps)) - except: - return err(reInternalError) + except Exception: + error "Failed to unwrap message", msg = getCurrentExceptionMsg() + return err(ReliabilityError.reDeserializationError) proc markDependenciesMet*( - rm: ReliabilityManager, messageIds: seq[MessageID] + rm: ReliabilityManager, messageIds: seq[SdsMessageID] ): Result[void, ReliabilityError] = ## Marks the specified message dependencies as met. ## @@ -241,18 +237,19 @@ proc markDependenciesMet*( if not rm.bloomFilter.contains(msgId): rm.bloomFilter.add(msgId) # rm.addToHistory(msgId) -- not needed as this proc usually called when msg in long-term storage of application? - rm.processIncomingBuffer() + rm.processIncomingBuffer() return ok() - except: - return err(reInternalError) + except Exception: + error "Failed to mark dependencies as met", msg = getCurrentExceptionMsg() + return err(ReliabilityError.reInternalError) proc setCallbacks*( rm: ReliabilityManager, - onMessageReady: proc(messageId: MessageID) {.gcsafe.}, - onMessageSent: proc(messageId: MessageID) {.gcsafe.}, + onMessageReady: proc(messageId: SdsMessageID) {.gcsafe.}, + onMessageSent: proc(messageId: SdsMessageID) {.gcsafe.}, onMissingDependencies: - proc(messageId: MessageID, missingDeps: seq[MessageID]) {.gcsafe.}, + proc(messageId: SdsMessageID, missingDeps: seq[SdsMessageID]) {.gcsafe.}, onPeriodicSync: PeriodicSyncCallback = nil, ) = ## Sets the callback functions for various events in the ReliabilityManager. @@ -268,53 +265,52 @@ proc setCallbacks*( rm.onMissingDependencies = onMissingDependencies rm.onPeriodicSync = onPeriodicSync -proc checkUnacknowledgedMessages*(rm: ReliabilityManager) {.raises: [].} = +proc checkUnacknowledgedMessages(rm: ReliabilityManager) {.gcsafe.} = ## Checks and processes unacknowledged messages in the outgoing buffer. withLock rm.lock: let now = getTime() var newOutgoingBuffer: seq[UnacknowledgedMessage] = @[] - try: - for unackMsg in rm.outgoingBuffer: - let elapsed = now - unackMsg.sendTime - if elapsed > rm.config.resendInterval: - # Time to attempt resend - if unackMsg.resendAttempts < rm.config.maxResendAttempts: - var updatedMsg = unackMsg - updatedMsg.resendAttempts += 1 - updatedMsg.sendTime = now - newOutgoingBuffer.add(updatedMsg) - else: - if rm.onMessageSent != nil: - rm.onMessageSent(unackMsg.message.messageId) + for unackMsg in rm.outgoingBuffer: + let elapsed = now - unackMsg.sendTime + if elapsed > rm.config.resendInterval: + # Time to attempt resend + if unackMsg.resendAttempts < rm.config.maxResendAttempts: + var updatedMsg = unackMsg + updatedMsg.resendAttempts += 1 + updatedMsg.sendTime = now + newOutgoingBuffer.add(updatedMsg) else: - newOutgoingBuffer.add(unackMsg) + if not rm.onMessageSent.isNil(): + rm.onMessageSent(unackMsg.message.messageId) + else: + newOutgoingBuffer.add(unackMsg) - rm.outgoingBuffer = newOutgoingBuffer - except Exception as e: - logError("Error in checking unacknowledged messages: " & e.msg) + rm.outgoingBuffer = newOutgoingBuffer -proc periodicBufferSweep(rm: ReliabilityManager) {.async: (raises: [CancelledError]).} = +proc periodicBufferSweep( + rm: ReliabilityManager +) {.async: (raises: [CancelledError]), gcsafe.} = ## Periodically sweeps the buffer to clean up and check unacknowledged messages. while true: - {.gcsafe.}: - try: - rm.checkUnacknowledgedMessages() - rm.cleanBloomFilter() - except Exception as e: - logError("Error in periodic buffer sweep: " & e.msg) + try: + rm.checkUnacknowledgedMessages() + rm.cleanBloomFilter() + except Exception: + error "Error in periodic buffer sweep", msg = getCurrentExceptionMsg() await sleepAsync(chronos.milliseconds(rm.config.bufferSweepInterval.inMilliseconds)) -proc periodicSyncMessage(rm: ReliabilityManager) {.async: (raises: [CancelledError]).} = +proc periodicSyncMessage( + rm: ReliabilityManager +) {.async: (raises: [CancelledError]), gcsafe.} = ## Periodically notifies to send a sync message to maintain connectivity. while true: - {.gcsafe.}: - try: - if rm.onPeriodicSync != nil: - rm.onPeriodicSync() - except Exception as e: - logError("Error in periodic sync: " & e.msg) + try: + if not rm.onPeriodicSync.isNil(): + rm.onPeriodicSync() + except Exception: + error "Error in periodic sync", msg = getCurrentExceptionMsg() await sleepAsync(chronos.seconds(rm.config.syncMessageInterval.inSeconds)) proc startPeriodicTasks*(rm: ReliabilityManager) = @@ -328,9 +324,6 @@ proc resetReliabilityManager*(rm: ReliabilityManager): Result[void, ReliabilityE ## Resets the ReliabilityManager to its initial state. ## ## This procedure clears all buffers and resets the Lamport timestamp. - ## - ## Returns: - ## A Result indicating success or an error if the Bloom filter initialization fails. withLock rm.lock: try: rm.lamportTimestamp = 0 @@ -338,9 +331,9 @@ proc resetReliabilityManager*(rm: ReliabilityManager): Result[void, ReliabilityE rm.outgoingBuffer.setLen(0) rm.incomingBuffer.setLen(0) rm.bloomFilter = newRollingBloomFilter( - rm.config.bloomFilterCapacity, rm.config.bloomFilterErrorRate, - rm.config.bloomFilterWindow, + rm.config.bloomFilterCapacity, rm.config.bloomFilterErrorRate ) return ok() - except: - return err(reInternalError) + except Exception: + error "Failed to reset ReliabilityManager", msg = getCurrentExceptionMsg() + return err(ReliabilityError.reInternalError) From 1d1c7e683486d901e1311f469f0903ce9c48d506 Mon Sep 17 00:00:00 2001 From: shash256 <111925100+shash256@users.noreply.github.com> Date: Mon, 10 Mar 2025 16:07:00 +0530 Subject: [PATCH 3/3] chore: address comments from review 1 --- src/message.nim | 8 ++- src/protobuf.nim | 10 ++- src/reliability.nim | 145 +++++++++++++++++++++----------------- src/reliability_utils.nim | 21 ++++-- tests/test_bloom.nim | 1 - 5 files changed, 109 insertions(+), 76 deletions(-) diff --git a/src/message.nim b/src/message.nim index 83d1f3a..f9c68c0 100644 --- a/src/message.nim +++ b/src/message.nim @@ -1,4 +1,4 @@ -import std/times +import std/[times, options, sets] type SdsMessageID* = seq[byte] @@ -8,7 +8,7 @@ type messageId*: SdsMessageID lamportTimestamp*: int64 causalHistory*: seq[SdsMessageID] - channelId*: SdsChannelID + channelId*: Option[SdsChannelID] content*: seq[byte] bloomFilter*: seq[byte] @@ -17,6 +17,10 @@ type sendTime*: Time resendAttempts*: int + IncomingMessage* = object + message*: SdsMessage + missingDeps*: HashSet[SdsMessageID] + const DefaultMaxMessageHistory* = 1000 DefaultMaxCausalHistory* = 10 diff --git a/src/protobuf.nim b/src/protobuf.nim index 5229182..4689da2 100644 --- a/src/protobuf.nim +++ b/src/protobuf.nim @@ -12,7 +12,8 @@ proc encode*(msg: SdsMessage): ProtoBuffer = for hist in msg.causalHistory: pb.write(3, hist) - pb.write(4, msg.channelId) + if msg.channelId.isSome(): + pb.write(4, msg.channelId.get()) pb.write(5, msg.content) pb.write(6, msg.bloomFilter) pb.finish() @@ -36,8 +37,11 @@ proc decode*(T: type SdsMessage, buffer: seq[byte]): ProtobufResult[T] = if histResult.isOk: msg.causalHistory = causalHistory - if not ?pb.getField(4, msg.channelId): - return err(ProtobufError.missingRequiredField("channelId")) + var channelId: seq[byte] + if ?pb.getField(4, channelId): + msg.channelId = some(channelId) + else: + msg.channelId = none[SdsChannelID]() if not ?pb.getField(5, msg.content): return err(ProtobufError.missingRequiredField("content")) diff --git a/src/reliability.nim b/src/reliability.nim index ebb533d..a164d7c 100644 --- a/src/reliability.nim +++ b/src/reliability.nim @@ -1,9 +1,9 @@ -import std/[times, locks, tables, sets, sequtils] +import std/[times, locks, tables, sets, options] import chronos, results, chronicles import ./[message, protobuf, reliability_utils, rolling_bloom_filter] proc newReliabilityManager*( - channelId: SdsChannelID, config: ReliabilityConfig = defaultConfig() + channelId: Option[SdsChannelID], config: ReliabilityConfig = defaultConfig() ): Result[ReliabilityManager, ReliabilityError] = ## Creates a new ReliabilityManager with the specified channel ID and configuration. ## @@ -13,7 +13,7 @@ proc newReliabilityManager*( ## ## Returns: ## A Result containing either a new ReliabilityManager instance or an error. - if channelId.len == 0: + if not channelId.isSome(): return err(ReliabilityError.reInvalidArgument) try: @@ -25,7 +25,7 @@ proc newReliabilityManager*( messageHistory: @[], bloomFilter: bloomFilter, outgoingBuffer: @[], - incomingBuffer: @[], + incomingBuffer: initTable[SdsMessageID, IncomingMessage](), channelId: channelId, config: config, ) @@ -35,23 +35,27 @@ proc newReliabilityManager*( error "Failed to create ReliabilityManager", msg = getCurrentExceptionMsg() return err(ReliabilityError.reOutOfMemory) +proc isAcknowledged*( + msg: UnacknowledgedMessage, + causalHistory: seq[SdsMessageID], + rbf: Option[RollingBloomFilter], +): bool = + if msg.message.messageId in causalHistory: + return true + + if rbf.isSome(): + return rbf.get().contains(msg.message.messageId) + + false + proc reviewAckStatus(rm: ReliabilityManager, msg: SdsMessage) {.gcsafe.} = - var i = 0 - while i < rm.outgoingBuffer.len: - var acknowledged = false - let outMsg = rm.outgoingBuffer[i] - - # Check if message is in causal history - for msgID in msg.causalHistory: - if outMsg.message.messageId == msgID: - acknowledged = true - break - - # Check bloom filter if not already acknowledged - if not acknowledged and msg.bloomFilter.len > 0: - let bfResult = deserializeBloomFilter(msg.bloomFilter) - if bfResult.isOk(): - var rbf = RollingBloomFilter( + # Parse bloom filter + var rbf: Option[RollingBloomFilter] + if msg.bloomFilter.len > 0: + let bfResult = deserializeBloomFilter(msg.bloomFilter) + if bfResult.isOk(): + rbf = some( + RollingBloomFilter( filter: bfResult.get(), capacity: bfResult.get().capacity, minCapacity: ( @@ -62,17 +66,27 @@ proc reviewAckStatus(rm: ReliabilityManager, msg: SdsMessage) {.gcsafe.} = ).int, messages: @[], ) - if rbf.contains(outMsg.message.messageId): - acknowledged = true - else: - error "Failed to deserialize bloom filter", error = bfResult.error + ) + else: + error "Failed to deserialize bloom filter", error = bfResult.error + rbf = none[RollingBloomFilter]() + else: + rbf = none[RollingBloomFilter]() - if acknowledged: + # Keep track of indices to delete + var toDelete: seq[int] = @[] + var i = 0 + + while i < rm.outgoingBuffer.len: + let outMsg = rm.outgoingBuffer[i] + if outMsg.isAcknowledged(msg.causalHistory, rbf): if not rm.onMessageSent.isNil(): rm.onMessageSent(outMsg.message.messageId) - rm.outgoingBuffer.delete(i) - else: - inc i + toDelete.add(i) + inc i + + for i in countdown(toDelete.high, 0): # Delete in reverse order to maintain indices + rm.outgoingBuffer.delete(toDelete[i]) proc wrapOutgoingMessage*( rm: ReliabilityManager, message: seq[byte], messageId: SdsMessageID @@ -127,43 +141,36 @@ proc processIncomingBuffer(rm: ReliabilityManager) {.gcsafe.} = if rm.incomingBuffer.len == 0: return - # Create dependency map - var dependencies = initTable[SdsMessageID, seq[SdsMessageID]]() - var readyToProcess: seq[SdsMessageID] = @[] var processed = initHashSet[SdsMessageID]() + var readyToProcess = newSeq[SdsMessageID]() - # Build dependency graph and find initially ready messages - for msg in rm.incomingBuffer: - var hasMissingDeps = false - for depId in msg.causalHistory: - if not rm.bloomFilter.contains(depId): - hasMissingDeps = true - if depId notin dependencies: - dependencies[depId] = @[] - dependencies[depId].add(msg.messageId) - - if not hasMissingDeps: - readyToProcess.add(msg.messageId) + # Find initially ready messages + for msgId, entry in rm.incomingBuffer: + if entry.missingDeps.len == 0: + readyToProcess.add(msgId) while readyToProcess.len > 0: let msgId = readyToProcess.pop() if msgId in processed: continue - # Process this message - for msg in rm.incomingBuffer: - if msg.messageId == msgId: - rm.addToHistory(msg.messageId) - if not rm.onMessageReady.isNil(): - rm.onMessageReady(msg.messageId) - processed.incl(msgId) + if msgId in rm.incomingBuffer: + rm.addToHistory(msgId) + if not rm.onMessageReady.isNil(): + rm.onMessageReady(msgId) + processed.incl(msgId) - # Add any dependent messages that might now be ready - if msgId in dependencies: - readyToProcess.add(dependencies[msgId]) - break + # Update dependencies for remaining messages + for remainingId, entry in rm.incomingBuffer: + if remainingId notin processed: + if msgId in entry.missingDeps: + rm.incomingBuffer[remainingId].missingDeps.excl(msgId) + if rm.incomingBuffer[remainingId].missingDeps.len == 0: + readyToProcess.add(remainingId) - rm.incomingBuffer = rm.incomingBuffer.filterIt(it.messageId notin processed) + # Remove processed messages + for msgId in processed: + rm.incomingBuffer.del(msgId) proc unwrapReceivedMessage*( rm: ReliabilityManager, message: seq[byte] @@ -179,7 +186,7 @@ proc unwrapReceivedMessage*( let msg = deserializeMessage(message).valueOr: return err(ReliabilityError.reDeserializationError) - if rm.bloomFilter.contains(msg.messageId): + if msg.messageId in rm.messageHistory: return ok((msg.content, @[])) rm.bloomFilter.add(msg.messageId) @@ -190,21 +197,21 @@ proc unwrapReceivedMessage*( # Review ACK status for outgoing messages rm.reviewAckStatus(msg) - var missingDeps: seq[SdsMessageID] = @[] - for depId in msg.causalHistory: - if not rm.bloomFilter.contains(depId): - missingDeps.add(depId) + var missingDeps = rm.checkDependencies(msg.causalHistory) if missingDeps.len == 0: # Check if any dependencies are still in incoming buffer var depsInBuffer = false - for bufferedMsg in rm.incomingBuffer: - if bufferedMsg.messageId in msg.causalHistory: + for msgId, entry in rm.incomingBuffer.pairs(): + if msgId in msg.causalHistory: depsInBuffer = true break if depsInBuffer: - rm.incomingBuffer.add(msg) + rm.incomingBuffer[msg.messageId] = IncomingMessage( + message: msg, + missingDeps: initHashSet[SdsMessageID]() + ) else: # All dependencies met, add to history rm.addToHistory(msg.messageId) @@ -212,7 +219,10 @@ proc unwrapReceivedMessage*( if not rm.onMessageReady.isNil(): rm.onMessageReady(msg.messageId) else: - rm.incomingBuffer.add(msg) + rm.incomingBuffer[msg.messageId] = IncomingMessage( + message: msg, + missingDeps: missingDeps.toHashSet() + ) if not rm.onMissingDependencies.isNil(): rm.onMissingDependencies(msg.messageId, missingDeps) @@ -238,6 +248,11 @@ proc markDependenciesMet*( rm.bloomFilter.add(msgId) # rm.addToHistory(msgId) -- not needed as this proc usually called when msg in long-term storage of application? + # Update any pending messages that depend on this one + for pendingId, entry in rm.incomingBuffer: + if msgId in entry.missingDeps: + rm.incomingBuffer[pendingId].missingDeps.excl(msgId) + rm.processIncomingBuffer() return ok() except Exception: @@ -329,7 +344,7 @@ proc resetReliabilityManager*(rm: ReliabilityManager): Result[void, ReliabilityE rm.lamportTimestamp = 0 rm.messageHistory.setLen(0) rm.outgoingBuffer.setLen(0) - rm.incomingBuffer.setLen(0) + rm.incomingBuffer.clear() rm.bloomFilter = newRollingBloomFilter( rm.config.bloomFilterCapacity, rm.config.bloomFilterErrorRate ) diff --git a/src/reliability_utils.nim b/src/reliability_utils.nim index a8d376f..ac05dab 100644 --- a/src/reliability_utils.nim +++ b/src/reliability_utils.nim @@ -1,4 +1,4 @@ -import std/[times, locks] +import std/[times, locks, options] import chronicles import ./[rolling_bloom_filter, message] @@ -20,8 +20,8 @@ type messageHistory*: seq[SdsMessageID] bloomFilter*: RollingBloomFilter outgoingBuffer*: seq[UnacknowledgedMessage] - incomingBuffer*: seq[SdsMessage] - channelId*: SdsChannelID + incomingBuffer*: Table[SdsMessageID, IncomingMessage] + channelId*: Option[SdsChannelID] config*: ReliabilityConfig lock*: Lock onMessageReady*: proc(messageId: SdsMessageID) {.gcsafe.} @@ -59,7 +59,7 @@ proc cleanup*(rm: ReliabilityManager) {.raises: [].} = try: withLock rm.lock: rm.outgoingBuffer.setLen(0) - rm.incomingBuffer.setLen(0) + rm.incomingBuffer.clear() rm.messageHistory.setLen(0) except Exception: error "Error during cleanup", error = getCurrentExceptionMsg() @@ -84,6 +84,15 @@ proc updateLamportTimestamp*( proc getRecentSdsMessageIDs*(rm: ReliabilityManager, n: int): seq[SdsMessageID] = result = rm.messageHistory[max(0, rm.messageHistory.len - n) .. ^1] +proc checkDependencies*( + rm: ReliabilityManager, deps: seq[SdsMessageID] +): seq[SdsMessageID] = + var missingDeps: seq[SdsMessageID] = @[] + for depId in deps: + if depId notin rm.messageHistory: + missingDeps.add(depId) + return missingDeps + proc getMessageHistory*(rm: ReliabilityManager): seq[SdsMessageID] = withLock rm.lock: result = rm.messageHistory @@ -92,6 +101,8 @@ proc getOutgoingBuffer*(rm: ReliabilityManager): seq[UnacknowledgedMessage] = withLock rm.lock: result = rm.outgoingBuffer -proc getIncomingBuffer*(rm: ReliabilityManager): seq[SdsMessage] = +proc getIncomingBuffer*( + rm: ReliabilityManager +): Table[SdsMessageID, message.IncomingMessage] = withLock rm.lock: result = rm.incomingBuffer diff --git a/tests/test_bloom.nim b/tests/test_bloom.nim index 540735d..ad88bba 100644 --- a/tests/test_bloom.nim +++ b/tests/test_bloom.nim @@ -1,7 +1,6 @@ import unittest, results, strutils import ../src/bloom from random import rand, randomize -import ../src/[message, protobuf, protobufutil, reliability_utils, rolling_bloom_filter] suite "bloom filter": setup: