diff --git a/tests/test_reliability.nim b/tests/test_reliability.nim index 8fd52a6..7b68a86 100644 --- a/tests/test_reliability.nim +++ b/tests/test_reliability.nim @@ -591,54 +591,82 @@ suite "Multi-Channel ReliabilityManager Tests": msgId1 notin history2 msgId2 notin history1 - # test "multi-channel callbacks": - # var readyMessages: seq[(SdsMessageID, SdsChannelID)] = @[] - # var sentMessages: seq[(SdsMessageID, SdsChannelID)] = @[] - # var missingDeps: seq[(SdsMessageID, seq[SdsMessageID], SdsChannelID)] = @[] + test "multi-channel callbacks": + var readyMessageCount = 0 + var sentMessageCount = 0 + var missingDepsCount = 0 - # rm.setCallbacks( - # proc(messageId: SdsMessageID, channelId: SdsChannelID) {.gcsafe.} = - # readyMessages.add((messageId, channelId)), - # proc(messageId: SdsMessageID, channelId: SdsChannelID) {.gcsafe.} = - # sentMessages.add((messageId, channelId)), - # proc(messageId: SdsMessageID, deps: seq[SdsMessageID], channelId: SdsChannelID) {.gcsafe.} = - # missingDeps.add((messageId, deps, channelId)) - # ) + rm.setCallbacks( + proc(messageId: SdsMessageID, channelId: SdsChannelID) {.gcsafe.} = + readyMessageCount += 1, + proc(messageId: SdsMessageID, channelId: SdsChannelID) {.gcsafe.} = + sentMessageCount += 1, + proc(messageId: SdsMessageID, deps: seq[SdsMessageID], channelId: SdsChannelID) {.gcsafe.} = + missingDepsCount += 1 + ) - # let channel1 = "callback-channel-1" - # let channel2 = "callback-channel-2" + let channel1 = "callback-channel-1" + let channel2 = "callback-channel-2" - # # Create messages in different channels - # let msg1 = @[byte(1)] - # let msgId1 = "callback-msg1" - # let wrapped1 = rm.wrapOutgoingMessage(msg1, msgId1, channel1) - # check wrapped1.isOk() + # Send messages from both channels + let msg1 = @[byte(1)] + let msgId1 = "callback-msg1" + let wrapped1 = rm.wrapOutgoingMessage(msg1, msgId1, channel1) + check wrapped1.isOk() - # let msg2 = @[byte(2)] - # let msgId2 = "callback-msg2" - # let wrapped2 = rm.wrapOutgoingMessage(msg2, msgId2, channel2) - # check wrapped2.isOk() + let msg2 = @[byte(2)] + let msgId2 = "callback-msg2" + let wrapped2 = rm.wrapOutgoingMessage(msg2, msgId2, channel2) + check wrapped2.isOk() - # # Process messages - should trigger callbacks with correct channel IDs - # discard rm.unwrapReceivedMessage(wrapped1.get()) - # discard rm.unwrapReceivedMessage(wrapped2.get()) + # Create acknowledgment messages that include our message IDs in causal history + # to trigger sent callbacks + let ackMsg1 = SdsMessage( + messageId: "ack1", + lamportTimestamp: rm.channels[channel1].lamportTimestamp + 1, + causalHistory: @[msgId1], # Acknowledge msg1 + channelId: channel1, + content: @[byte(100)], + bloomFilter: @[], + ) - # check: - # readyMessages.len == 2 - # (msgId1, channel1) in readyMessages - # (msgId2, channel2) in readyMessages + let ackMsg2 = SdsMessage( + messageId: "ack2", + lamportTimestamp: rm.channels[channel2].lamportTimestamp + 1, + causalHistory: @[msgId2], # Acknowledge msg2 + channelId: channel2, + content: @[byte(101)], + bloomFilter: @[], + ) - # test "channel-specific dependency management": - # let channel1 = "dep-channel-1" - # let depIds = @["dep1", "dep2", "dep3"] + let serializedAck1 = serializeMessage(ackMsg1) + let serializedAck2 = serializeMessage(ackMsg2) + check: + serializedAck1.isOk() + serializedAck2.isOk() - # # Mark dependencies as met for specific channel - # check rm.markDependenciesMet(depIds, channel1).isOk() + # Process acknowledgment messages - should trigger callbacks + discard rm.unwrapReceivedMessage(serializedAck1.get()) + discard rm.unwrapReceivedMessage(serializedAck2.get()) - # # Dependencies should only affect the specified channel - # let channel2 = "dep-channel-2" - # check rm.ensureChannel(channel2).isOk() + check: + readyMessageCount == 2 # Both ack messages should trigger ready callbacks + sentMessageCount == 2 # Both original messages should be marked as sent + missingDepsCount == 0 # No missing dependencies - # # Dependencies in channel1 should not affect channel2 - # check rm.channels[channel1].bloomFilter.contains("dep1") - # check not rm.channels[channel2].bloomFilter.contains("dep1") + test "channel-specific dependency management": + let channel1 = "dep-channel-1" + let channel2 = "dep-channel-2" + let depIds = @["dep1", "dep2", "dep3"] + + # Ensure both channels exist first + check rm.ensureChannel(channel1).isOk() + check rm.ensureChannel(channel2).isOk() + + # Mark dependencies as met for specific channel + check rm.markDependenciesMet(depIds, channel1).isOk() + + # Dependencies should only affect the specified channel + # Dependencies in channel1 should not affect channel2 + check rm.channels[channel1].bloomFilter.contains("dep1") + check not rm.channels[channel2].bloomFilter.contains("dep1")