From 97e2f681b94bb69b3c02ae377bb33972671c4fd0 Mon Sep 17 00:00:00 2001 From: shash256 <111925100+shash256@users.noreply.github.com> Date: Thu, 12 Dec 2024 16:04:24 +0400 Subject: [PATCH] feat: review ack status --- src/common.nim | 1 + src/protobuf.nim | 71 ++++++++++++++--- src/reliability.nim | 81 ++++++++++++++++--- src/utils.nim | 6 ++ tests/test_reliability.nim | 154 ++++++++++++++++++++++++++++--------- 5 files changed, 256 insertions(+), 57 deletions(-) diff --git a/src/common.nim b/src/common.nim index 9ebac22..4887018 100644 --- a/src/common.nim +++ b/src/common.nim @@ -10,6 +10,7 @@ type causalHistory*: seq[MessageID] channelId*: string content*: seq[byte] + bloomFilter*: seq[byte] UnacknowledgedMessage* = object message*: Message diff --git a/src/protobuf.nim b/src/protobuf.nim index b674511..ef07a21 100644 --- a/src/protobuf.nim +++ b/src/protobuf.nim @@ -2,10 +2,7 @@ import ./protobufutil import ./common import libp2p/protobuf/minprotobuf import std/options - -proc toString(bytes: seq[byte]): string = - result = newString(bytes.len) - copyMem(result[0].addr, bytes[0].unsafeAddr, bytes.len) +import "../nim-bloom/src/bloom" proc toBytes(s: string): seq[byte] = result = newSeq[byte](s.len) @@ -22,6 +19,7 @@ proc encode*(msg: Message): ProtoBuffer = pb.write(4, msg.channelId) pb.write(5, msg.content) + pb.write(6, msg.bloomFilter) pb.finish() pb @@ -39,11 +37,10 @@ proc decode*(T: type Message, buffer: seq[byte]): ProtobufResult[T] = msg.lamportTimestamp = int64(timestamp) # Decode causal history - var histories: seq[seq[byte]] - for histBytes in histories: - let hist = histBytes.toString - if hist notin msg.causalHistory: # Avoid duplicate entries - msg.causalHistory.add(hist) + var causalHistory: seq[string] + let histResult = pb.getRepeatedField(3, causalHistory) + if histResult.isOk: + msg.causalHistory = causalHistory if not ?pb.getField(4, msg.channelId): return err(ProtobufError.missingRequiredField("channelId")) @@ -51,6 +48,9 @@ proc decode*(T: type Message, buffer: seq[byte]): ProtobufResult[T] = if not ?pb.getField(5, msg.content): return err(ProtobufError.missingRequiredField("content")) + if not ?pb.getField(6, msg.bloomFilter): + msg.bloomFilter = @[] # Empty if not present + ok(msg) proc serializeMessage*(msg: Message): Result[seq[byte], ReliabilityError] = @@ -67,5 +67,58 @@ proc deserializeMessage*(data: seq[byte]): Result[Message, ReliabilityError] = ok(msgResult.get) else: err(reSerializationError) + except: + err(reDeserializationError) + +proc serializeBloomFilter*(filter: BloomFilter): Result[seq[byte], ReliabilityError] = + try: + var pb = initProtoBuffer() + + # Convert intArray to bytes + var bytes = newSeq[byte](filter.intArray.len * sizeof(int)) + for i, val in filter.intArray: + let start = i * sizeof(int) + copyMem(addr bytes[start], unsafeAddr val, sizeof(int)) + + pb.write(1, bytes) + pb.write(2, uint64(filter.capacity)) + pb.write(3, uint64(filter.errorRate * 1_000_000)) + pb.write(4, uint64(filter.kHashes)) + pb.write(5, uint64(filter.mBits)) + + pb.finish() + ok(pb.buffer) + except: + err(reSerializationError) + +proc deserializeBloomFilter*(data: seq[byte]): Result[BloomFilter, ReliabilityError] = + if data.len == 0: + return err(reDeserializationError) + + try: + let pb = initProtoBuffer(data) + var bytes: seq[byte] + var cap, errRate, kHashes, mBits: uint64 + + if not pb.getField(1, bytes).get() or + not pb.getField(2, cap).get() or + not pb.getField(3, errRate).get() or + not pb.getField(4, kHashes).get() or + not pb.getField(5, mBits).get(): + return err(reDeserializationError) + + # Convert bytes back to intArray + var intArray = newSeq[int](bytes.len div sizeof(int)) + for i in 0 ..< intArray.len: + let start = i * sizeof(int) + copyMem(addr intArray[i], unsafeAddr bytes[start], sizeof(int)) + + ok(BloomFilter( + intArray: intArray, + capacity: int(cap), + errorRate: float(errRate) / 1_000_000, + kHashes: int(kHashes), + mBits: int(mBits) + )) except: err(reDeserializationError) \ No newline at end of file diff --git a/src/reliability.nim b/src/reliability.nim index 79c586e..587133b 100644 --- a/src/reliability.nim +++ b/src/reliability.nim @@ -52,6 +52,39 @@ proc newReliabilityManager*(channelId: string, config: ReliabilityConfig = defau except: return err(reOutOfMemory) +proc reviewAckStatus(rm: ReliabilityManager, msg: Message) = + var i = 0 + while i < rm.outgoingBuffer.len: + var acknowledged = false + let outMsg = rm.outgoingBuffer[i] + + # Check if message is in causal history + for msgID in msg.causalHistory: + if outMsg.message.messageId == msgID: + acknowledged = true + break + + # Check bloom filter if not already acknowledged + if not acknowledged and msg.bloomFilter.len > 0: + let bfResult = deserializeBloomFilter(msg.bloomFilter) + if bfResult.isOk: + var rbf = RollingBloomFilter( + filter: bfResult.get(), + window: rm.bloomFilter.window, + messages: @[] + ) + if rbf.contains(outMsg.message.messageId): + acknowledged = true + else: + logError("Failed to deserialize bloom filter") + + if acknowledged: + if rm.onMessageSent != nil: + rm.onMessageSent(outMsg.message.messageId) + rm.outgoingBuffer.delete(i) + else: + inc i + proc wrapOutgoingMessage*(rm: ReliabilityManager, message: seq[byte], messageId: MessageID): Result[seq[byte], ReliabilityError] = ## Wraps an outgoing message with reliability metadata. ## @@ -68,16 +101,35 @@ proc wrapOutgoingMessage*(rm: ReliabilityManager, message: seq[byte], messageId: withLock rm.lock: try: rm.updateLamportTimestamp(getTime().toUnix) + + # Serialize current bloom filter + var bloomBytes: seq[byte] + let bfResult = serializeBloomFilter(rm.bloomFilter.filter) + if bfResult.isErr: + logError("Failed to serialize bloom filter") + bloomBytes = @[] + else: + bloomBytes = bfResult.get() + let msg = Message( messageId: messageId, lamportTimestamp: rm.lamportTimestamp, causalHistory: rm.getRecentMessageIDs(rm.config.maxCausalHistory), channelId: rm.channelId, - content: message + content: message, + bloomFilter: bloomBytes ) - rm.outgoingBuffer.add(UnacknowledgedMessage(message: msg, sendTime: getTime(), resendAttempts: 0)) - # rm.messageHistory.add(messageId) - # rm.bloomFilter.add(messageId) + + # Add to outgoing buffer + rm.outgoingBuffer.add(UnacknowledgedMessage( + message: msg, + sendTime: getTime(), + resendAttempts: 0 + )) + + # Add to causal history and bloom filter + rm.addToBloomAndHistory(msg) + return serializeMessage(msg) except: return err(reInternalError) @@ -100,21 +152,24 @@ proc unwrapReceivedMessage*(rm: ReliabilityManager, message: seq[byte]): Result[ if rm.bloomFilter.contains(msg.messageId): return ok((msg.content, @[])) - rm.bloomFilter.add(msg.messageId) + # Update Lamport timestamp rm.updateLamportTimestamp(msg.lamportTimestamp) + # Review ACK status for outgoing messages + rm.reviewAckStatus(msg) + var missingDeps: seq[MessageID] = @[] for depId in msg.causalHistory: if not rm.bloomFilter.contains(depId): missingDeps.add(depId) if missingDeps.len == 0: - rm.messageHistory.add(msg.messageId) - if rm.messageHistory.len > rm.config.maxMessageHistory: - rm.messageHistory.delete(0) + # All dependencies met, add to history + rm.addToBloomAndHistory(msg) if rm.onMessageReady != nil: rm.onMessageReady(msg.messageId) else: + # Buffer message and request missing dependencies rm.incomingBuffer.add(msg) if rm.onMissingDependencies != nil: rm.onMissingDependencies(msg.messageId, missingDeps) @@ -136,6 +191,12 @@ proc markDependenciesMet*(rm: ReliabilityManager, messageIds: seq[MessageID]): R var processedMessages: seq[Message] = @[] var newIncomingBuffer: seq[Message] = @[] + # Add all messageIds to both bloom filter and causal history + for msgId in messageIds: + if not rm.bloomFilter.contains(msgId): + rm.bloomFilter.add(msgId) + rm.messageHistory.add(msgId) + for msg in rm.incomingBuffer: var allDependenciesMet = true for depId in msg.causalHistory: @@ -145,15 +206,13 @@ proc markDependenciesMet*(rm: ReliabilityManager, messageIds: seq[MessageID]): R if allDependenciesMet: processedMessages.add(msg) + rm.addToBloomAndHistory(msg) else: newIncomingBuffer.add(msg) rm.incomingBuffer = newIncomingBuffer for msg in processedMessages: - rm.messageHistory.add(msg.messageId) - if rm.messageHistory.len > rm.config.maxMessageHistory: - rm.messageHistory.delete(0) if rm.onMessageReady != nil: rm.onMessageReady(msg.messageId) diff --git a/src/utils.nim b/src/utils.nim index 693d918..20edc48 100644 --- a/src/utils.nim +++ b/src/utils.nim @@ -75,6 +75,12 @@ proc cleanBloomFilter*(rm: ReliabilityManager) {.gcsafe, raises: [].} = except Exception as e: logError("Failed to clean ReliabilityManager bloom filter: " & e.msg) +proc addToBloomAndHistory*(rm: ReliabilityManager, msg: Message) = + rm.messageHistory.add(msg.messageId) + if rm.messageHistory.len > rm.config.maxMessageHistory: + rm.messageHistory.delete(0) + rm.bloomFilter.add(msg.messageId) + proc updateLamportTimestamp*(rm: ReliabilityManager, msgTs: int64) = rm.lamportTimestamp = max(msgTs, rm.lamportTimestamp) + 1 diff --git a/tests/test_reliability.nim b/tests/test_reliability.nim index 48e02d7..645d37c 100644 --- a/tests/test_reliability.nim +++ b/tests/test_reliability.nim @@ -1,6 +1,8 @@ -import unittest, results, chronos, chronicles +import unittest, results, chronos import ../src/reliability import ../src/common +import ../src/protobuf +import ../src/utils suite "ReliabilityManager": var rm: ReliabilityManager @@ -41,38 +43,58 @@ suite "ReliabilityManager": unwrapped == msg missingDeps.len == 0 - test "markDependenciesMet": - # First message - let msg1 = @[byte(1)] + test "marking dependencies": + var messageReadyCount = 0 + var messageSentCount = 0 + var missingDepsCount = 0 + + rm.setCallbacks( + proc(messageId: MessageID) {.gcsafe.} = messageReadyCount += 1, + proc(messageId: MessageID) {.gcsafe.} = messageSentCount += 1, + proc(messageId: MessageID, missingDeps: seq[MessageID]) {.gcsafe.} = missingDepsCount += 1 + ) + + # We'll create dependency IDs that aren't in the bloom filter yet let id1 = "msg1" - let wrap1 = rm.wrapOutgoingMessage(msg1, id1) - check wrap1.isOk() - let wrapped1 = wrap1.get() - - # Second message - let msg2 = @[byte(2)] let id2 = "msg2" - let wrap2 = rm.wrapOutgoingMessage(msg2, id2) - check wrap2.isOk() - let wrapped2 = wrap2.get() - # Third message - let msg3 = @[byte(3)] - let id3 = "msg3" - let wrap3 = rm.wrapOutgoingMessage(msg3, id3) - check wrap3.isOk() - let wrapped3 = wrap3.get() + # Create a message that depends on these IDs + let msg3 = Message( + messageId: "msg3", + lamportTimestamp: 1, + causalHistory: @[id1, id2], # Depends on messages we haven't seen + channelId: "testChannel", + content: @[byte(3)], + bloomFilter: @[] + ) - # Check dependencies - var unwrap3 = rm.unwrapReceivedMessage(wrapped3) - check unwrap3.isOk() - var (_, missing3) = unwrap3.get() + let serializedMsg3 = serializeMessage(msg3) + check serializedMsg3.isOk() - # Mark dependencies as met - let markResult = rm.markDependenciesMet(@[id1, id2]) + # Process the message - should identify missing dependencies + let unwrapResult = rm.unwrapReceivedMessage(serializedMsg3.get()) + check unwrapResult.isOk() + let (_, missingDeps) = unwrapResult.get() + + # Verify missing dependencies were identified + check missingDepsCount == 1 + check missingDeps.len == 2 + check id1 in missingDeps + check id2 in missingDeps + + # Now mark dependencies as met + let markResult = rm.markDependenciesMet(missingDeps) check markResult.isOk() - check missing3.len == 0 + # Process the message again - should now be ready + let reprocessResult = rm.unwrapReceivedMessage(serializedMsg3.get()) + check reprocessResult.isOk() + let (_, remainingDeps) = reprocessResult.get() + + # Verify message is now processed + check remainingDeps.len == 0 + check messageReadyCount == 1 # msg3 should now be ready + check missingDepsCount == 1 # Only the first attempt should report missing deps test "callbacks work correctly": var messageReadyCount = 0 @@ -85,18 +107,76 @@ suite "ReliabilityManager": proc(messageId: MessageID, missingDeps: seq[MessageID]) {.gcsafe.} = missingDepsCount += 1 ) - let msg1Result = rm.wrapOutgoingMessage(@[byte(1)], "msg1") - let msg2Result = rm.wrapOutgoingMessage(@[byte(2)], "msg2") - check msg1Result.isOk() and msg2Result.isOk() - let msg1 = msg1Result.get() - let msg2 = msg2Result.get() - discard rm.unwrapReceivedMessage(msg1) - discard rm.unwrapReceivedMessage(msg2) + # First send our own message + let msg1 = @[byte(1)] + let id1 = "msg1" + let wrap1 = rm.wrapOutgoingMessage(msg1, id1) + check wrap1.isOk() - check: - messageReadyCount == 2 - messageSentCount == 0 # This would be triggered by checkUnacknowledgedMessages - missingDepsCount == 0 + # Create a message that has our message in causal history + let msg2 = Message( + messageId: "msg2", + lamportTimestamp: rm.lamportTimestamp + 1, + causalHistory: @[id1], # Include our message in causal history + channelId: "testChannel", + content: @[byte(2)], + bloomFilter: @[] # Test with an empty bloom filter + ) + + let serializedMsg2 = serializeMessage(msg2) + check serializedMsg2.isOk() + + # Process the "received" message - should trigger callbacks + let unwrapResult = rm.unwrapReceivedMessage(serializedMsg2.get()) + check unwrapResult.isOk() + + check messageReadyCount == 1 # For msg2 which we "received" + check messageSentCount == 1 # For msg1 which was acknowledged via causal history + + test "bloom filter acknowledgment": + var messageSentCount = 0 + + rm.setCallbacks( + proc(messageId: MessageID) {.gcsafe.} = discard, + proc(messageId: MessageID) {.gcsafe.} = messageSentCount += 1, + proc(messageId: MessageID, missingDeps: seq[MessageID]) {.gcsafe.} = discard + ) + + # First send our own message + let msg1 = @[byte(1)] + let id1 = "msg1" + let wrap1 = rm.wrapOutgoingMessage(msg1, id1) + check wrap1.isOk() + + # Create a message simulating another party's message + # with bloom filter containing our message + var otherPartyBloomFilter = newRollingBloomFilter( + DefaultBloomFilterCapacity, + DefaultBloomFilterErrorRate, + DefaultBloomFilterWindow + ) + otherPartyBloomFilter.add(id1) # Add our message to their bloom filter + + let bfResult = serializeBloomFilter(otherPartyBloomFilter.filter) + check bfResult.isOk() + + let msg2 = Message( + messageId: "msg2", + lamportTimestamp: rm.lamportTimestamp + 1, + causalHistory: @[], # Empty causal history as we're using bloom filter + channelId: "testChannel", + content: @[byte(2)], + bloomFilter: bfResult.get() + ) + + let serializedMsg2 = serializeMessage(msg2) + check serializedMsg2.isOk() + + # Process the "received" message - should trigger acknowledgment + let unwrapResult = rm.unwrapReceivedMessage(serializedMsg2.get()) + check unwrapResult.isOk() + + check messageSentCount == 1 # Our message should be acknowledged via bloom filter test "periodic sync callback works": var syncCallCount = 0