diff --git a/src/common.nim b/src/common.nim
index 4887018..7e098d8 100644
--- a/src/common.nim
+++ b/src/common.nim
@@ -36,6 +36,8 @@ type
     maxCausalHistory*: int
     resendInterval*: times.Duration
     maxResendAttempts*: int
+    syncMessageInterval*: times.Duration
+    bufferSweepInterval*: times.Duration
 
   ReliabilityManager* = ref object
     lamportTimestamp*: int64
@@ -65,6 +67,8 @@ const
   DefaultBloomFilterWindow* = initDuration(hours = 1)
   DefaultMaxMessageHistory* = 1000
   DefaultMaxCausalHistory* = 10
-  DefaultResendInterval* = initDuration(seconds = 30)
+  DefaultResendInterval* = initDuration(seconds = 60)
   DefaultMaxResendAttempts* = 5
+  DefaultSyncMessageInterval* = initDuration(seconds = 30)
+  DefaultBufferSweepInterval* = initDuration(seconds = 60)
   MaxMessageSize* = 1024 * 1024 # 1 MB
\ No newline at end of file
diff --git a/src/reliability.nim b/src/reliability.nim
index 587133b..fbd64ce 100644
--- a/src/reliability.nim
+++ b/src/reliability.nim
@@ -3,6 +3,7 @@ import chronos, results
 import ./common
 import ./utils
 import ./protobuf
+import std/[tables, sets]
 
 proc defaultConfig*(): ReliabilityConfig =
   ## Creates a default configuration for the ReliabilityManager.
@@ -16,7 +17,9 @@ proc defaultConfig*(): ReliabilityConfig =
     maxMessageHistory: DefaultMaxMessageHistory,
     maxCausalHistory: DefaultMaxCausalHistory,
     resendInterval: DefaultResendInterval,
-    maxResendAttempts: DefaultMaxResendAttempts
+    maxResendAttempts: DefaultMaxResendAttempts,
+    syncMessageInterval: DefaultSyncMessageInterval,
+    bufferSweepInterval: DefaultBufferSweepInterval
   )
 
 proc newReliabilityManager*(channelId: string, config: ReliabilityConfig = defaultConfig()): Result[ReliabilityManager, ReliabilityError] =
@@ -128,12 +131,102 @@ proc wrapOutgoingMessage*(rm: ReliabilityManager, message: seq[byte], messageId:
       ))
 
       # Add to causal history and bloom filter
-      rm.addToBloomAndHistory(msg)
+      rm.bloomFilter.add(msg.messageId)
+      rm.addToHistory(msg.messageId)
 
       return serializeMessage(msg)
     except:
       return err(reInternalError)
 
+proc processIncomingBuffer(rm: ReliabilityManager) =
+  ## Delivers every buffered message whose causal dependencies are met,
+  ## in dependency order, and keeps the still-blocked messages buffered.
+  withLock rm.lock:
+    if rm.incomingBuffer.len == 0:
+      return
+
+    # Messages still awaiting delivery. A dependency that is merely
+    # buffered (already in the bloom filter but not yet delivered)
+    # does not count as met.
+    var pending = initHashSet[MessageID]()
+    for msg in rm.incomingBuffer:
+      pending.incl(msg.messageId)
+
+    # Build the dependency graph and find the initially ready messages
+    var dependencies = initTable[MessageID, seq[MessageID]]()
+    var readyToProcess: seq[MessageID] = @[]
+
+    for msg in rm.incomingBuffer:
+      var hasMissingDeps = false
+      for depId in msg.causalHistory:
+        if depId in pending or not rm.bloomFilter.contains(depId):
+          hasMissingDeps = true
+          if depId notin dependencies:
+            dependencies[depId] = @[]
+          dependencies[depId].add(msg.messageId)
+
+      if not hasMissingDeps:
+        readyToProcess.add(msg.messageId)
+
+    # Process ready messages and any dependents they unblock
+    var processed = initHashSet[MessageID]()
+
+    while readyToProcess.len > 0:
+      let msgId = readyToProcess.pop()
+      if msgId in processed:
+        continue
+
+      for msg in rm.incomingBuffer:
+        if msg.messageId == msgId:
+          # Re-check: another dependency of this message may still be unmet
+          var stillBlocked = false
+          for depId in msg.causalHistory:
+            if depId in pending or not rm.bloomFilter.contains(depId):
+              stillBlocked = true
+              break
+
+          if not stillBlocked:
+            rm.addToHistory(msg.messageId)
+            if rm.onMessageReady != nil:
+              rm.onMessageReady(msg.messageId)
+            processed.incl(msgId)
+            pending.excl(msgId)
+
+            # Any dependent messages might now be ready
+            if msgId in dependencies:
+              for dependentId in dependencies[msgId]:
+                readyToProcess.add(dependentId)
+          break
+
+    # Keep only the messages that are still blocked
+    var newIncomingBuffer: seq[Message] = @[]
+    for msg in rm.incomingBuffer:
+      if msg.messageId notin processed:
+        newIncomingBuffer.add(msg)
+
+    rm.incomingBuffer = newIncomingBuffer
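+
+  # Worked example (illustrative): if the buffer holds msg2 {deps: msg1} and
+  # msg3 {deps: msg1, msg2}, and msg1 has just been added to the bloom filter,
+  # then msg2 is initially ready; delivering msg2 unblocks msg3, both fire
+  # onMessageReady in causal order, and the buffer ends up empty.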
 
 proc unwrapReceivedMessage*(rm: ReliabilityManager, message: seq[byte]): Result[tuple[message: seq[byte], missingDeps: seq[MessageID]], ReliabilityError] =
   ## Unwraps a received message and processes its reliability metadata.
   ##
@@ -142,41 +235,52 @@ proc unwrapReceivedMessage*(rm: ReliabilityManager, message: seq[byte]): Result[
   ##
   ## Returns:
   ##   A Result containing either a tuple with the processed message and missing dependencies, or an error.
-  withLock rm.lock:
-    try:
-      let msgResult = deserializeMessage(message)
-      if not msgResult.isOk:
-        return err(msgResult.error)
-
-      let msg = msgResult.get
-      if rm.bloomFilter.contains(msg.messageId):
-        return ok((msg.content, @[]))
+  try:
+    let msgResult = deserializeMessage(message)
+    if not msgResult.isOk:
+      return err(msgResult.error)
+
+    let msg = msgResult.get
+    if rm.bloomFilter.contains(msg.messageId):
+      return ok((msg.content, @[]))
 
-      # Update Lamport timestamp
-      rm.updateLamportTimestamp(msg.lamportTimestamp)
+    rm.bloomFilter.add(msg.messageId)
 
-      # Review ACK status for outgoing messages
-      rm.reviewAckStatus(msg)
+    # Update Lamport timestamp
+    rm.updateLamportTimestamp(msg.lamportTimestamp)
 
-      var missingDeps: seq[MessageID] = @[]
-      for depId in msg.causalHistory:
-        if not rm.bloomFilter.contains(depId):
-          missingDeps.add(depId)
+    # Review ACK status for outgoing messages
+    rm.reviewAckStatus(msg)
 
-      if missingDeps.len == 0:
+    var missingDeps: seq[MessageID] = @[]
+    for depId in msg.causalHistory:
+      if not rm.bloomFilter.contains(depId):
+        missingDeps.add(depId)
+
+    if missingDeps.len == 0:
+      # Check if any dependencies are still in incoming buffer
+      var depsInBuffer = false
+      for bufferedMsg in rm.incomingBuffer:
+        if bufferedMsg.messageId in msg.causalHistory:
+          depsInBuffer = true
+          break
+      if depsInBuffer:
+        rm.incomingBuffer.add(msg)
+      else:
         # All dependencies met, add to history
-        rm.addToBloomAndHistory(msg)
+        rm.addToHistory(msg.messageId)
+        rm.processIncomingBuffer()
         if rm.onMessageReady != nil:
           rm.onMessageReady(msg.messageId)
-      else:
-        # Buffer message and request missing dependencies
-        rm.incomingBuffer.add(msg)
-        if rm.onMissingDependencies != nil:
-          rm.onMissingDependencies(msg.messageId, missingDeps)
+    else:
+      # Buffer message and request missing dependencies
+      rm.incomingBuffer.add(msg)
+      if rm.onMissingDependencies != nil:
+        rm.onMissingDependencies(msg.messageId, missingDeps)
 
-      return ok((msg.content, missingDeps))
-    except:
-      return err(reInternalError)
+    return ok((msg.content, missingDeps))
+  except:
+    return err(reInternalError)
 
 proc markDependenciesMet*(rm: ReliabilityManager, messageIds: seq[MessageID]): Result[void, ReliabilityError] =
   ## Marks the specified message dependencies as met.
@@ -186,39 +290,17 @@ proc markDependenciesMet*(rm: ReliabilityManager, messageIds: seq[MessageID]): R
   ##
   ## Returns:
   ##   A Result indicating success or an error.
-  withLock rm.lock:
-    try:
-      var processedMessages: seq[Message] = @[]
-      var newIncomingBuffer: seq[Message] = @[]
-
-      # Add all messageIds to both bloom filter and causal history
-      for msgId in messageIds:
-        if not rm.bloomFilter.contains(msgId):
-          rm.bloomFilter.add(msgId)
-          rm.messageHistory.add(msgId)
-
-      for msg in rm.incomingBuffer:
-        var allDependenciesMet = true
-        for depId in msg.causalHistory:
-          if depId notin messageIds and not rm.bloomFilter.contains(depId):
-            allDependenciesMet = false
-            break
-
-        if allDependenciesMet:
-          processedMessages.add(msg)
-          rm.addToBloomAndHistory(msg)
-        else:
-          newIncomingBuffer.add(msg)
-
-      rm.incomingBuffer = newIncomingBuffer
-
-      for msg in processedMessages:
-        if rm.onMessageReady != nil:
-          rm.onMessageReady(msg.messageId)
-
-      return ok()
-    except:
-      return err(reInternalError)
+  try:
+    # Add all messageIds to the bloom filter. They are deliberately not added
+    # to messageHistory, since this proc is typically called for messages that
+    # already live in the application's long-term storage.
+    for msgId in messageIds:
+      if not rm.bloomFilter.contains(msgId):
+        rm.bloomFilter.add(msgId)
+    rm.processIncomingBuffer()
+
+    return ok()
+  except:
+    return err(reInternalError)
 
 proc setCallbacks*(rm: ReliabilityManager,
                    onMessageReady: proc(messageId: MessageID) {.gcsafe.},
@@ -262,8 +344,6 @@ proc checkUnacknowledgedMessages*(rm: ReliabilityManager) {.raises: [].} =
 
 proc periodicBufferSweep(rm: ReliabilityManager) {.async: (raises: [CancelledError]).} =
   ## Periodically sweeps the buffer to clean up and check unacknowledged messages.
-  ##
-  ## This is an internal function and should not be called directly.
   while true:
     {.gcsafe.}:
       try:
@@ -271,7 +351,7 @@ proc periodicBufferSweep(rm: ReliabilityManager) {.async: (raises: [CancelledErr
         rm.cleanBloomFilter()
       except Exception as e:
         logError("Error in periodic buffer sweep: " & e.msg)
-    await sleepAsync(chronos.seconds(5))
+    await sleepAsync(chronos.seconds(rm.config.bufferSweepInterval.inSeconds))
 
 proc periodicSyncMessage(rm: ReliabilityManager) {.async: (raises: [CancelledError]).} =
   ## Periodically notifies to send a sync message to maintain connectivity.
@@ -282,7 +362,7 @@ proc periodicSyncMessage(rm: ReliabilityManager) {.async: (raises: [CancelledErr
         rm.onPeriodicSync()
       except Exception as e:
         logError("Error in periodic sync: " & e.msg)
-    await sleepAsync(chronos.seconds(30))
+    await sleepAsync(chronos.seconds(rm.config.syncMessageInterval.inSeconds))
 
 proc startPeriodicTasks*(rm: ReliabilityManager) =
   ## Starts the periodic tasks for buffer sweeping and sync message sending.
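+
+# Usage sketch (illustrative; the channel id and `expect` message are
+# hypothetical): tuning the new intervals before starting the periodic tasks.
+#
+#   var config = defaultConfig()
+#   config.syncMessageInterval = initDuration(seconds = 10)
+#   config.bufferSweepInterval = initDuration(seconds = 30)
+#   let rm = newReliabilityManager("my-channel", config).expect("valid config")
+#   rm.startPeriodicTasks()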
diff --git a/src/utils.nim b/src/utils.nim
index e5ad83e..e5b3ccf 100644
--- a/src/utils.nim
+++ b/src/utils.nim
@@ -75,11 +75,10 @@ proc cleanBloomFilter*(rm: ReliabilityManager) {.gcsafe, raises: [].} =
   except Exception as e:
     logError("Failed to clean ReliabilityManager bloom filter: " & e.msg)
 
-proc addToBloomAndHistory*(rm: ReliabilityManager, msg: Message) =
-  rm.messageHistory.add(msg.messageId)
+proc addToHistory*(rm: ReliabilityManager, msgId: MessageID) =
+  rm.messageHistory.add(msgId)
   if rm.messageHistory.len > rm.config.maxMessageHistory:
     rm.messageHistory.delete(0)
-  rm.bloomFilter.add(msg.messageId)
 
 proc updateLamportTimestamp*(rm: ReliabilityManager, msgTs: int64) =
   rm.lamportTimestamp = max(msgTs, rm.lamportTimestamp) + 1
diff --git a/tests/test_reliability.nim b/tests/test_reliability.nim
index ca7f220..f592b0b 100644
--- a/tests/test_reliability.nim
+++ b/tests/test_reliability.nim
@@ -1,10 +1,11 @@
-import unittest, results, chronos, chronicles
+import unittest, results, chronos, std/times
 import ../src/reliability
 import ../src/common
 import ../src/protobuf
 import ../src/utils
 
-suite "ReliabilityManager":
+# Core functionality tests
+suite "Core Operations":
   var rm: ReliabilityManager
 
   setup:
@@ -40,7 +41,54 @@ suite "ReliabilityManager":
       unwrapped == msg
       missingDeps.len == 0
 
-  test "marking dependencies":
+  test "message ordering":
+    # Create messages with different timestamps
+    let msg1 = Message(
+      messageId: "msg1",
+      lamportTimestamp: 1,
+      causalHistory: @[],
+      channelId: "testChannel",
+      content: @[byte(1)],
+      bloomFilter: @[]
+    )
+
+    let msg2 = Message(
+      messageId: "msg2",
+      lamportTimestamp: 5,
+      causalHistory: @[],
+      channelId: "testChannel",
+      content: @[byte(2)],
+      bloomFilter: @[]
+    )
+
+    let serialized1 = serializeMessage(msg1)
+    let serialized2 = serializeMessage(msg2)
+    check:
+      serialized1.isOk()
+      serialized2.isOk()
+
+    # Process out of order; the Lamport clock must still advance monotonically
+    discard rm.unwrapReceivedMessage(serialized2.get())
+    let timestampAfterMsg2 = rm.lamportTimestamp
+    discard rm.unwrapReceivedMessage(serialized1.get())
+    let timestampAfterMsg1 = rm.lamportTimestamp
+
+    check timestampAfterMsg1 > timestampAfterMsg2
+
+# Reliability mechanism tests
+suite "Reliability Mechanisms":
+  var rm: ReliabilityManager
+
+  setup:
+    let rmResult = newReliabilityManager("testChannel")
+    check rmResult.isOk()
+    rm = rmResult.get()
+
+  teardown:
+    if not rm.isNil:
+      rm.cleanup()
+
+  test "dependency detection and resolution":
     var messageReadyCount = 0
     var messageSentCount = 0
     var missingDepsCount = 0
@@ -51,49 +99,70 @@ suite "ReliabilityManager":
       proc(messageId: MessageID, missingDeps: seq[MessageID]) {.gcsafe.} =
         missingDepsCount += 1
     )
 
-    # Create dependency IDs that aren't in bloom filter yet
+    # Create dependency chain: msg3 -> msg2 -> msg1
     let id1 = "msg1"
     let id2 = "msg2"
+    let id3 = "msg3"
+
+    # Create messages with dependencies
+    let msg2 = Message(
+      messageId: id2,
+      lamportTimestamp: 2,
+      causalHistory: @[id1], # msg2 depends on msg1
+      channelId: "testChannel",
+      content: @[byte(2)],
+      bloomFilter: @[]
+    )
 
-    # Create message depending on these IDs
     let msg3 = Message(
-      messageId: "msg3",
-      lamportTimestamp: 1,
-      causalHistory: @[id1, id2], # Depends on messages we haven't seen
+      messageId: id3,
+      lamportTimestamp: 3,
+      causalHistory: @[id1, id2], # msg3 depends on both msg1 and msg2
       channelId: "testChannel",
      content: @[byte(3)],
      bloomFilter: @[]
    )
 
-    let serializedMsg3 = serializeMessage(msg3)
-    check serializedMsg3.isOk()
+    let serialized2 = serializeMessage(msg2)
+    let serialized3 = serializeMessage(msg3)
+    check:
+      serialized2.isOk()
+      serialized3.isOk()
 
-    # Process message - should identify missing dependencies
-    let unwrapResult = rm.unwrapReceivedMessage(serializedMsg3.get())
-    check unwrapResult.isOk()
-    let (_, missingDeps) = unwrapResult.get()
+    # First try processing msg3 (which depends on msg2, which depends on msg1)
+    let unwrapResult3 = rm.unwrapReceivedMessage(serialized3.get())
+    check unwrapResult3.isOk()
+    let (_, missingDeps3) = unwrapResult3.get()
 
     check:
-      missingDepsCount == 1
-      missingDeps.len == 2
-      id1 in missingDeps
-      id2 in missingDeps
+      missingDepsCount == 1 # Should trigger missing deps callback
+      missingDeps3.len == 2 # Should be missing both msg1 and msg2
+      id1 in missingDeps3
+      id2 in missingDeps3
 
-    # Mark dependencies as met
-    let markResult = rm.markDependenciesMet(missingDeps)
-    check markResult.isOk()
+    # Then try processing msg2 (which only depends on msg1)
+    let unwrapResult2 = rm.unwrapReceivedMessage(serialized2.get())
+    check unwrapResult2.isOk()
+    let (_, missingDeps2) = unwrapResult2.get()
+
+    check:
+      missingDepsCount == 2 # Should have triggered another missing deps callback
+      missingDeps2.len == 1 # Should only be missing msg1
+      id1 in missingDeps2
+      messageReadyCount == 0 # No messages should be ready yet
 
-    # Process message again - should now be ready
-    let reprocessResult = rm.unwrapReceivedMessage(serializedMsg3.get())
-    check reprocessResult.isOk()
-    let (_, remainingDeps) = reprocessResult.get()
+    # Mark first dependency (msg1) as met
+    let markResult1 = rm.markDependenciesMet(@[id1])
+    check markResult1.isOk()
+
+    let incomingBuffer = rm.getIncomingBuffer()
 
     check:
-      remainingDeps.len == 0
-      messageReadyCount == 1
-      missingDepsCount == 1
+      incomingBuffer.len == 0
+      messageReadyCount == 2 # Both msg2 and msg3 should be ready
+      missingDepsCount == 2 # Unchanged; no new missing-deps callbacks fired
 
-  test "callbacks work correctly":
+  test "acknowledgment via causal history":
     var messageReadyCount = 0
     var messageSentCount = 0
     var missingDepsCount = 0
@@ -104,7 +173,7 @@ suite "ReliabilityManager":
       proc(messageId: MessageID, missingDeps: seq[MessageID]) {.gcsafe.} =
         missingDepsCount += 1
     )
 
-    # First send our own message
+    # Send our message
     let msg1 = @[byte(1)]
     let id1 = "msg1"
     let wrap1 = rm.wrapOutgoingMessage(msg1, id1)
@@ -131,7 +200,7 @@ suite "ReliabilityManager":
       messageReadyCount == 1 # For msg2 which we "received"
       messageSentCount == 1 # For msg1 which was acknowledged via causal history
 
-  test "bloom filter acknowledgment":
+  test "acknowledgment via bloom filter":
     var messageSentCount = 0
 
     rm.setCallbacks(
@@ -140,7 +209,7 @@ suite "ReliabilityManager":
       proc(messageId: MessageID, missingDeps: seq[MessageID]) {.gcsafe.} = discard
     )
 
-    # First send our own message
+    # Send our message
     let msg1 = @[byte(1)]
     let id1 = "msg1"
     let wrap1 = rm.wrapOutgoingMessage(msg1, id1)
@@ -174,23 +243,20 @@ suite "ReliabilityManager":
 
     check messageSentCount == 1 # Our message should be acknowledged via bloom filter
 
-  test "periodic sync callback":
-    var syncCallCount = 0
-    rm.setCallbacks(
-      proc(messageId: MessageID) {.gcsafe.} = discard,
-      proc(messageId: MessageID) {.gcsafe.} = discard,
-      proc(messageId: MessageID, missingDeps: seq[MessageID]) {.gcsafe.} = discard,
-      proc() {.gcsafe.} = syncCallCount += 1
-    )
+# Periodic task & buffer management tests
+suite "Periodic Tasks & Buffer Management":
+  var rm: ReliabilityManager
 
-    rm.startPeriodicTasks()
-    # Sleep briefly to allow periodic tasks to run
-    waitFor sleepAsync(chronos.seconds(1))
-    rm.cleanup()
-
-    check syncCallCount > 0
+  setup:
+    let rmResult = newReliabilityManager("testChannel")
+    check rmResult.isOk()
+    rm = rmResult.get()
 
-  test "buffer management":
+  teardown:
+    if not rm.isNil:
+      rm.cleanup()
+
+  test "outgoing buffer management":
     var messageSentCount = 0
 
     rm.setCallbacks(
@@ -199,7 +265,7 @@ suite "ReliabilityManager":
       proc(messageId: MessageID, missingDeps: seq[MessageID]) {.gcsafe.} = discard
     )
 
-    # Add multiple messages to outgoing buffer
+    # Add multiple messages
     for i in 0..5:
       let msg = @[byte(i)]
       let id = "msg" & $i
@@ -230,21 +296,136 @@ suite "ReliabilityManager":
       finalBuffer.len == 3 # Should have removed acknowledged messages
       messageSentCount == 3 # Should have triggered sent callback for acknowledged messages
 
-  test "handles empty message":
-    let msg: seq[byte] = @[]
-    let msgId = "test-empty-msg"
-    let wrappedResult = rm.wrapOutgoingMessage(msg, msgId)
-    check:
-      not wrappedResult.isOk()
-      wrappedResult.error == reInvalidArgument
+  test "periodic buffer sweep":
+    # Recreate the manager with short intervals so the sweep and the resend
+    # check both fire within the test window (the defaults are 60 seconds)
+    var config = defaultConfig()
+    config.resendInterval = initDuration(seconds = 1)
+    config.bufferSweepInterval = initDuration(seconds = 1)
+    rm.cleanup()
+    let fastRmResult = newReliabilityManager("testChannel", config)
+    check fastRmResult.isOk()
+    rm = fastRmResult.get()
+
+    var messageSentCount = 0
+
+    rm.setCallbacks(
+      proc(messageId: MessageID) {.gcsafe.} = discard,
+      proc(messageId: MessageID) {.gcsafe.} = messageSentCount += 1,
+      proc(messageId: MessageID, missingDeps: seq[MessageID]) {.gcsafe.} = discard
+    )
 
-  test "handles message too large":
-    let msg = newSeq[byte](MaxMessageSize + 1)
-    let msgId = "test-large-msg"
-    let wrappedResult = rm.wrapOutgoingMessage(msg, msgId)
-    check:
-      not wrappedResult.isOk()
-      wrappedResult.error == reMessageTooLarge
+    # Add message to buffer
+    let msg1 = @[byte(1)]
+    let id1 = "msg1"
+    let wrap1 = rm.wrapOutgoingMessage(msg1, id1)
+    check wrap1.isOk()
+
+    let initialBuffer = rm.getOutgoingBuffer()
+    check initialBuffer[0].resendAttempts == 0
+
+    rm.startPeriodicTasks()
+    waitFor sleepAsync(chronos.seconds(6))
+
+    let finalBuffer = rm.getOutgoingBuffer()
+    check:
+      finalBuffer.len == 1
+      finalBuffer[0].resendAttempts > 0
+
+  test "periodic sync":
+    var syncCallCount = 0
+    rm.setCallbacks(
+      proc(messageId: MessageID) {.gcsafe.} = discard,
+      proc(messageId: MessageID) {.gcsafe.} = discard,
+      proc(messageId: MessageID, missingDeps: seq[MessageID]) {.gcsafe.} = discard,
+      proc() {.gcsafe.} = syncCallCount += 1
+    )
+
+    rm.startPeriodicTasks()
+    waitFor sleepAsync(chronos.seconds(1))
+    rm.cleanup()
+
+    check syncCallCount > 0
+
+# Special cases handling
+suite "Special Cases Handling":
+  var rm: ReliabilityManager
+
+  setup:
+    let rmResult = newReliabilityManager("testChannel")
+    check rmResult.isOk()
+    rm = rmResult.get()
+
+  teardown:
+    if not rm.isNil:
+      rm.cleanup()
+
+  test "message history limits":
+    # Add more messages than the history can hold
+    for i in 0..rm.config.maxMessageHistory + 5:
+      let msg = @[byte(i mod 256)] # byte(i) would overflow for i > 255
+      let id = "msg" & $i
+      let wrap = rm.wrapOutgoingMessage(msg, id)
+      check wrap.isOk()
+
+    let history = rm.getMessageHistory()
+    check:
+      history.len <= rm.config.maxMessageHistory
+      history[^1] == "msg" & $(rm.config.maxMessageHistory + 5)
+
+  test "invalid bloom filter handling":
+    let msgInvalid = Message(
+      messageId: "invalid-bf",
+      lamportTimestamp: 1,
+      causalHistory: @[],
+      channelId: "testChannel",
+      content: @[byte(1)],
+      bloomFilter: @[1.byte, 2.byte, 3.byte] # Invalid filter data
+    )
+
+    let serializedInvalid = serializeMessage(msgInvalid)
+    check serializedInvalid.isOk()
+
+    # Should handle invalid bloom filter gracefully
+    let result = rm.unwrapReceivedMessage(serializedInvalid.get())
+    check:
+      result.isOk()
+      result.get()[1].len == 0 # No missing dependencies
+
+  test "duplicate message handling":
+    var messageReadyCount = 0
+    rm.setCallbacks(
+      proc(messageId: MessageID) {.gcsafe.} = messageReadyCount += 1,
+      proc(messageId: MessageID) {.gcsafe.} = discard,
+      proc(messageId: MessageID, missingDeps: seq[MessageID]) {.gcsafe.} = discard
+    )
+
+    # Create and process a message
+    let msg = Message(
+      messageId: "dup-msg",
+      lamportTimestamp: 1,
+      causalHistory: @[],
+      channelId: "testChannel",
+      content: @[byte(1)],
+      bloomFilter: @[]
+    )
+
+    let serialized = serializeMessage(msg)
+    check serialized.isOk()
+
+    # Process same message twice
+    let result1 = rm.unwrapReceivedMessage(serialized.get())
+    check result1.isOk()
+    let result2 = rm.unwrapReceivedMessage(serialized.get())
+    check:
+      result2.isOk()
+      result2.get()[1].len == 0 # No missing deps on second process
+      messageReadyCount == 1 # Message should only be processed once
+
+  test "error handling":
+    # Empty message
+    let emptyMsg: seq[byte] = @[]
+    let emptyResult = rm.wrapOutgoingMessage(emptyMsg, "empty")
+    check:
+      not emptyResult.isOk()
+      emptyResult.error == reInvalidArgument
+
+    # Oversized message
+    let largeMsg = newSeq[byte](MaxMessageSize + 1)
+    let largeResult = rm.wrapOutgoingMessage(largeMsg, "large")
+    check:
+      not largeResult.isOk()
+      largeResult.error == reMessageTooLarge
 
 suite "cleanup":
   test "cleanup works correctly":