From 5df71ad3eaf68172cef39a2e1838ddd871b03b5d Mon Sep 17 00:00:00 2001
From: Akhil <111925100+shash256@users.noreply.github.com>
Date: Mon, 13 Jan 2025 13:49:28 +0400
Subject: [PATCH 1/5] feat: add bloom filter (#3)

---
 .gitignore                    |   6 ++
 src/bloom.nim                 | 123 +++++++++++++++++++++++++++++
 src/private/probabilities.nim |  98 +++++++++++++++++++++++
 tests/test_bloom.nim          | 142 ++++++++++++++++++++++++++++++++++
 4 files changed, 369 insertions(+)
 create mode 100644 .gitignore
 create mode 100644 src/bloom.nim
 create mode 100644 src/private/probabilities.nim
 create mode 100644 tests/test_bloom.nim

diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..cfc9510
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,6 @@
+nimcache
+nimcache/*
+tests/bloom
+nim-bloom/bloom
+.DS_Store
+src/.DS_Store
\ No newline at end of file
diff --git a/src/bloom.nim b/src/bloom.nim
new file mode 100644
index 0000000..92b0712
--- /dev/null
+++ b/src/bloom.nim
@@ -0,0 +1,123 @@
+from math import ceil, ln, pow, round
+import hashes
+import strutils
+import results
+import private/probabilities
+
+type
+  BloomFilter* = object
+    capacity*: int
+    errorRate*: float
+    kHashes*: int
+    mBits*: int
+    intArray: seq[int]
+
+{.push overflowChecks: off.} # Turn off overflow checks for hashing operations
+
+proc hashN(item: string, n: int, maxValue: int): int =
+  ## Get the nth hash using Nim's built-in hash function using
+  ## the double hashing technique from Kirsch and Mitzenmacher, 2008:
+  ## http://www.eecs.harvard.edu/~kirsch/pubs/bbbf/rsa.pdf
+  let
+    hashA = abs(hash(item)) mod maxValue # Use abs to handle negative hashes
+    hashB = abs(hash(item & " b")) mod maxValue # string concatenation
+  abs((hashA + n * hashB)) mod maxValue
+  # # Use bit rotation for the second hash instead of string concatenation if speed is preferred over FP-rate
+  # # Rotate left by 21 bits (lower the rotation, higher the speed but higher the FP-rate too)
+  # hashB = abs(
+  #   ((h shl 21) or (h shr (sizeof(int) * 8 - 21)))
+  # ) mod maxValue
+  # abs((hashA + n.int64 * hashB)) mod maxValue
+
+{.pop.}
+
+proc getMOverNBitsForK*(k: int, targetError: float,
+    probabilityTable = kErrors): Result[int, string] =
+  ## Returns the optimal number of m/n bits for a given k.
+  if k notin 0..12:
+    return err("K must be <= 12 if forceNBitsPerElem is not also specified.")
+
+  for mOverN in 2..probabilityTable[k].high:
+    if probabilityTable[k][mOverN] < targetError:
+      return ok(mOverN)
+
+  err("Specified value of k and error rate not achievable using less than 4 bytes / element.")
+
+proc initializeBloomFilter*(capacity: int, errorRate: float, k = 0,
+    forceNBitsPerElem = 0): Result[BloomFilter, string] =
+  ## Initializes a Bloom filter with specified parameters.
+  ##
+  ## Parameters:
+  ## - capacity: Expected number of elements to be inserted
+  ## - errorRate: Desired false positive rate (e.g., 0.01 for 1%)
+  ## - k: Optional number of hash functions. If 0, calculated optimally
+  ##   See http://pages.cs.wisc.edu/~cao/papers/summary-cache/node8.html for
+  ##   useful tables on k and m/n (n bits per element) combinations.
+ ## - forceNBitsPerElem: Optional override for bits per element + var + kHashes: int + nBitsPerElem: int + + if k < 1: # Calculate optimal k and use that + let bitsPerElem = ceil(-1.0 * (ln(errorRate) / (pow(ln(2.float), 2)))) + kHashes = round(ln(2.float) * bitsPerElem).int + nBitsPerElem = round(bitsPerElem).int + else: # Use specified k if possible + if forceNBitsPerElem < 1: # Use lookup table + let mOverNRes = getMOverNBitsForK(k = k, targetError = errorRate) + if mOverNRes.isErr: + return err(mOverNRes.error) + nBitsPerElem = mOverNRes.value + else: + nBitsPerElem = forceNBitsPerElem + kHashes = k + + let + mBits = capacity * nBitsPerElem + mInts = 1 + mBits div (sizeof(int) * 8) + + ok(BloomFilter( + capacity: capacity, + errorRate: errorRate, + kHashes: kHashes, + mBits: mBits, + intArray: newSeq[int](mInts) + )) + +proc `$`*(bf: BloomFilter): string = + ## Prints the configuration of the Bloom filter. + "Bloom filter with $1 capacity, $2 error rate, $3 hash functions, and requiring $4 bits of memory." % + [$bf.capacity, + formatFloat(bf.errorRate, format = ffScientific, precision = 1), + $bf.kHashes, + $(bf.mBits div bf.capacity)] + +proc computeHashes(bf: BloomFilter, item: string): seq[int] = + var hashes = newSeq[int](bf.kHashes) + for i in 0.. 12 + let errorCase = getMOverNBitsForK(k = 13, targetError = 0.01) + check errorCase.isErr + check errorCase.error == "K must be <= 12 if forceNBitsPerElem is not also specified." + + # Test error case for unachievable error rate + let errorCase2 = getMOverNBitsForK(k = 2, targetError = 0.00001) + check errorCase2.isErr + check errorCase2.error == "Specified value of k and error rate not achievable using less than 4 bytes / element." + + # Test success cases + let case1 = getMOverNBitsForK(k = 2, targetError = 0.1) + check case1.isOk + check case1.value == 6 + + let case2 = getMOverNBitsForK(k = 7, targetError = 0.01) + check case2.isOk + check case2.value == 10 + + let case3 = getMOverNBitsForK(k = 7, targetError = 0.001) + check case3.isOk + check case3.value == 16 + + let bf2Result = initializeBloomFilter(10000, 0.001, k = 4, forceNBitsPerElem = 20) + check bf2Result.isOk + let bf2 = bf2Result.get + check bf2.kHashes == 4 + check bf2.mBits == 200000 + + test "string representation": + let bf3Result = initializeBloomFilter(1000, 0.01, k = 4) + check bf3Result.isOk + let bf3 = bf3Result.get + let str = $bf3 + check str.contains("1000") # Capacity + check str.contains("4 hash") # Hash functions + check str.contains("1.0e-02") # Error rate in scientific notation + +suite "bloom filter special cases": + test "different patterns of strings": + const testSize = 10_000 + let patterns = @[ + "shortstr", + repeat("a", 1000), # Very long string + "special@#$%^&*()", # Special characters + "unicode→★∑≈", # Unicode characters + repeat("pattern", 10) # Repeating pattern + ] + + let bfResult = initializeBloomFilter(testSize, 0.01) + check bfResult.isOk + var bf = bfResult.get + var inserted = newSeq[string](testSize) + + # Test pattern handling + for pattern in patterns: + bf.insert(pattern) + assert bf.lookup(pattern), "failed lookup pattern: " & pattern + + # Test general insertion and lookup + for i in 0.. 
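A minimal usage sketch of the filter API added in this patch (editor's illustration, not part of the diff; assumes the snippet is compiled from the repository root):

import results
import ./src/bloom

let bfRes = initializeBloomFilter(capacity = 10_000, errorRate = 0.01)
doAssert bfRes.isOk
var bf = bfRes.get
bf.insert("hello")
doAssert bf.lookup("hello")  # inserted items are always found: no false negatives
# an absent item can still return true, with probability of roughly errorRate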
Date: Tue, 11 Feb 2025 13:23:19 +0530 Subject: [PATCH 2/5] feat: add rolling bloom filter, reliability utils and protobuf (#4) --- .gitignore | 5 +- reliability.nimble | 16 ++++ src/bloom.nim | 65 +++++++------ src/message.nim | 27 ++++++ src/private/probabilities.nim | 175 +++++++++++++++++----------------- src/protobuf.nim | 114 ++++++++++++++++++++++ src/protobufutil.nim | 32 +++++++ src/reliability_utils.nim | 97 +++++++++++++++++++ src/rolling_bloom_filter.nim | 118 +++++++++++++++++++++++ tests/test_bloom.nim | 64 +++++++------ 10 files changed, 566 insertions(+), 147 deletions(-) create mode 100644 reliability.nimble create mode 100644 src/message.nim create mode 100644 src/protobuf.nim create mode 100644 src/protobufutil.nim create mode 100644 src/reliability_utils.nim create mode 100644 src/rolling_bloom_filter.nim diff --git a/.gitignore b/.gitignore index cfc9510..1431936 100644 --- a/.gitignore +++ b/.gitignore @@ -1,6 +1,7 @@ nimcache nimcache/* -tests/bloom +tests/test_bloom nim-bloom/bloom .DS_Store -src/.DS_Store \ No newline at end of file +src/.DS_Store +nph \ No newline at end of file diff --git a/reliability.nimble b/reliability.nimble new file mode 100644 index 0000000..8bc19c8 --- /dev/null +++ b/reliability.nimble @@ -0,0 +1,16 @@ +# Package +version = "0.1.0" +author = "Waku Team" +description = "E2E Reliability Protocol API" +license = "MIT" +srcDir = "src" + +# Dependencies +requires "nim >= 2.0.8" +requires "chronicles" +requires "libp2p" + +# Tasks +task test, "Run the test suite": + exec "nim c -r tests/test_bloom.nim" + exec "nim c -r tests/test_reliability.nim" diff --git a/src/bloom.nim b/src/bloom.nim index 92b0712..ea3b703 100644 --- a/src/bloom.nim +++ b/src/bloom.nim @@ -4,22 +4,21 @@ import strutils import results import private/probabilities -type - BloomFilter* = object - capacity*: int - errorRate*: float - kHashes*: int - mBits*: int - intArray: seq[int] +type BloomFilter* = object + capacity*: int + errorRate*: float + kHashes*: int + mBits*: int + intArray*: seq[int] -{.push overflowChecks: off.} # Turn off overflow checks for hashing operations +{.push overflowChecks: off.} # Turn off overflow checks for hashing operations proc hashN(item: string, n: int, maxValue: int): int = ## Get the nth hash using Nim's built-in hash function using ## the double hashing technique from Kirsch and Mitzenmacher, 2008: ## http://www.eecs.harvard.edu/~kirsch/pubs/bbbf/rsa.pdf let - hashA = abs(hash(item)) mod maxValue # Use abs to handle negative hashes + hashA = abs(hash(item)) mod maxValue # Use abs to handle negative hashes hashB = abs(hash(item & " b")) mod maxValue # string concatenation abs((hashA + n * hashB)) mod maxValue # # Use bit rotation for second hash instead of string concatenation if speed if preferred over FP-rate @@ -31,20 +30,24 @@ proc hashN(item: string, n: int, maxValue: int): int = {.pop.} -proc getMOverNBitsForK*(k: int, targetError: float, - probabilityTable = kErrors): Result[int, string] = +proc getMOverNBitsForK*( + k: int, targetError: float, probabilityTable = kErrors +): Result[int, string] = ## Returns the optimal number of m/n bits for a given k. - if k notin 0..12: + if k notin 0 .. 12: return err("K must be <= 12 if forceNBitsPerElem is not also specified.") - for mOverN in 2..probabilityTable[k].high: + for mOverN in 2 .. 
probabilityTable[k].high: if probabilityTable[k][mOverN] < targetError: return ok(mOverN) - err("Specified value of k and error rate not achievable using less than 4 bytes / element.") + err( + "Specified value of k and error rate not achievable using less than 4 bytes / element." + ) -proc initializeBloomFilter*(capacity: int, errorRate: float, k = 0, - forceNBitsPerElem = 0): Result[BloomFilter, string] = +proc initializeBloomFilter*( + capacity: int, errorRate: float, k = 0, forceNBitsPerElem = 0 +): Result[BloomFilter, string] = ## Initializes a Bloom filter with specified parameters. ## ## Parameters: @@ -76,25 +79,29 @@ proc initializeBloomFilter*(capacity: int, errorRate: float, k = 0, mBits = capacity * nBitsPerElem mInts = 1 + mBits div (sizeof(int) * 8) - ok(BloomFilter( - capacity: capacity, - errorRate: errorRate, - kHashes: kHashes, - mBits: mBits, - intArray: newSeq[int](mInts) - )) + ok( + BloomFilter( + capacity: capacity, + errorRate: errorRate, + kHashes: kHashes, + mBits: mBits, + intArray: newSeq[int](mInts), + ) + ) proc `$`*(bf: BloomFilter): string = ## Prints the configuration of the Bloom filter. "Bloom filter with $1 capacity, $2 error rate, $3 hash functions, and requiring $4 bits of memory." % - [$bf.capacity, - formatFloat(bf.errorRate, format = ffScientific, precision = 1), - $bf.kHashes, - $(bf.mBits div bf.capacity)] + [ + $bf.capacity, + formatFloat(bf.errorRate, format = ffScientific, precision = 1), + $bf.kHashes, + $(bf.mBits div bf.capacity), + ] proc computeHashes(bf: BloomFilter, item: string): seq[int] = var hashes = newSeq[int](bf.kHashes) - for i in 0.. rm.config.maxMessageHistory: + rm.messageHistory.delete(0) + +proc updateLamportTimestamp*( + rm: ReliabilityManager, msgTs: int64 +) {.gcsafe, raises: [].} = + rm.lamportTimestamp = max(msgTs, rm.lamportTimestamp) + 1 + +proc getRecentSdsMessageIDs*(rm: ReliabilityManager, n: int): seq[SdsMessageID] = + result = rm.messageHistory[max(0, rm.messageHistory.len - n) .. 
^1] + +proc getMessageHistory*(rm: ReliabilityManager): seq[SdsMessageID] = + withLock rm.lock: + result = rm.messageHistory + +proc getOutgoingBuffer*(rm: ReliabilityManager): seq[UnacknowledgedMessage] = + withLock rm.lock: + result = rm.outgoingBuffer + +proc getIncomingBuffer*(rm: ReliabilityManager): seq[SdsMessage] = + withLock rm.lock: + result = rm.incomingBuffer diff --git a/src/rolling_bloom_filter.nim b/src/rolling_bloom_filter.nim new file mode 100644 index 0000000..190ab8a --- /dev/null +++ b/src/rolling_bloom_filter.nim @@ -0,0 +1,118 @@ +import chronos +import chronicles +import ./[bloom, message] + +type RollingBloomFilter* = object + filter*: BloomFilter + capacity*: int + minCapacity*: int + maxCapacity*: int + messages*: seq[SdsMessageID] + +const + DefaultBloomFilterCapacity* = 10000 + DefaultBloomFilterErrorRate* = 0.001 + CapacityFlexPercent* = 20 + +proc newRollingBloomFilter*( + capacity: int = DefaultBloomFilterCapacity, + errorRate: float = DefaultBloomFilterErrorRate, +): RollingBloomFilter {.gcsafe.} = + let targetCapacity = if capacity <= 0: DefaultBloomFilterCapacity else: capacity + let targetError = + if errorRate <= 0.0 or errorRate >= 1.0: DefaultBloomFilterErrorRate else: errorRate + + let filterResult = initializeBloomFilter(targetCapacity, targetError) + if filterResult.isErr: + error "Failed to initialize bloom filter", error = filterResult.error + # Try with default values if custom values failed + if capacity != DefaultBloomFilterCapacity or errorRate != DefaultBloomFilterErrorRate: + let defaultResult = + initializeBloomFilter(DefaultBloomFilterCapacity, DefaultBloomFilterErrorRate) + if defaultResult.isErr: + error "Failed to initialize bloom filter with default parameters", + error = defaultResult.error + + let minCapacity = ( + DefaultBloomFilterCapacity.float * (100 - CapacityFlexPercent).float / 100.0 + ).int + let maxCapacity = ( + DefaultBloomFilterCapacity.float * (100 + CapacityFlexPercent).float / 100.0 + ).int + + info "Successfully initialized bloom filter with default parameters", + capacity = DefaultBloomFilterCapacity, + minCapacity = minCapacity, + maxCapacity = maxCapacity + + return RollingBloomFilter( + filter: defaultResult.get(), + capacity: DefaultBloomFilterCapacity, + minCapacity: minCapacity, + maxCapacity: maxCapacity, + messages: @[], + ) + else: + error "Could not create bloom filter", error = filterResult.error + + let minCapacity = + (targetCapacity.float * (100 - CapacityFlexPercent).float / 100.0).int + let maxCapacity = + (targetCapacity.float * (100 + CapacityFlexPercent).float / 100.0).int + + info "Successfully initialized bloom filter", + capacity = targetCapacity, minCapacity = minCapacity, maxCapacity = maxCapacity + + return RollingBloomFilter( + filter: filterResult.get(), + capacity: targetCapacity, + minCapacity: minCapacity, + maxCapacity: maxCapacity, + messages: @[], + ) + +proc clean*(rbf: var RollingBloomFilter) {.gcsafe.} = + try: + if rbf.messages.len <= rbf.maxCapacity: + return # Don't clean unless we exceed max capacity + + # Initialize new filter + var newFilter = initializeBloomFilter(rbf.maxCapacity, rbf.filter.errorRate).valueOr: + error "Failed to create new bloom filter", error = $error + return + + # Keep most recent messages up to minCapacity + let keepCount = rbf.minCapacity + let startIdx = max(0, rbf.messages.len - keepCount) + var newMessages: seq[SdsMessageID] = @[] + + for i in startIdx ..< rbf.messages.len: + newMessages.add(rbf.messages[i]) + 
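+      # re-register the kept ID in the fresh filter (message IDs are raw bytes cast to string)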
newFilter.insert(cast[string](rbf.messages[i])) + + rbf.messages = newMessages + rbf.filter = newFilter + except Exception: + error "Failed to clean bloom filter", error = getCurrentExceptionMsg() + +proc add*(rbf: var RollingBloomFilter, messageId: SdsMessageID) {.gcsafe.} = + ## Adds a message ID to the rolling bloom filter. + ## + ## Parameters: + ## - messageId: The ID of the message to add. + rbf.filter.insert(cast[string](messageId)) + rbf.messages.add(messageId) + + # Clean if we exceed max capacity + if rbf.messages.len > rbf.maxCapacity: + rbf.clean() + +proc contains*(rbf: RollingBloomFilter, messageId: SdsMessageID): bool = + ## Checks if a message ID is in the rolling bloom filter. + ## + ## Parameters: + ## - messageId: The ID of the message to check. + ## + ## Returns: + ## True if the message ID is probably in the filter, false otherwise. + rbf.filter.lookup(cast[string](messageId)) diff --git a/tests/test_bloom.nim b/tests/test_bloom.nim index 7da555c..540735d 100644 --- a/tests/test_bloom.nim +++ b/tests/test_bloom.nim @@ -1,6 +1,7 @@ import unittest, results, strutils import ../src/bloom from random import rand, randomize +import ../src/[message, protobuf, protobufutil, reliability_utils, rolling_bloom_filter] suite "bloom filter": setup: @@ -13,9 +14,9 @@ suite "bloom filter": sampleChars = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789" testElements = newSeq[string](nElementsToTest) - for i in 0.. 12 let errorCase = getMOverNBitsForK(k = 13, targetError = 0.01) check errorCase.isErr - check errorCase.error == "K must be <= 12 if forceNBitsPerElem is not also specified." + check errorCase.error == + "K must be <= 12 if forceNBitsPerElem is not also specified." # Test error case for unachievable error rate let errorCase2 = getMOverNBitsForK(k = 2, targetError = 0.00001) check errorCase2.isErr - check errorCase2.error == "Specified value of k and error rate not achievable using less than 4 bytes / element." + check errorCase2.error == + "Specified value of k and error rate not achievable using less than 4 bytes / element." # Test success cases let case1 = getMOverNBitsForK(k = 2, targetError = 0.1) @@ -93,50 +96,51 @@ suite "bloom filter": check bf3Result.isOk let bf3 = bf3Result.get let str = $bf3 - check str.contains("1000") # Capacity - check str.contains("4 hash") # Hash functions - check str.contains("1.0e-02") # Error rate in scientific notation + check str.contains("1000") # Capacity + check str.contains("4 hash") # Hash functions + check str.contains("1.0e-02") # Error rate in scientific notation suite "bloom filter special cases": test "different patterns of strings": const testSize = 10_000 - let patterns = @[ - "shortstr", - repeat("a", 1000), # Very long string - "special@#$%^&*()", # Special characters - "unicode→★∑≈", # Unicode characters - repeat("pattern", 10) # Repeating pattern - ] - + let patterns = + @[ + "shortstr", + repeat("a", 1000), # Very long string + "special@#$%^&*()", # Special characters + "unicode→★∑≈", # Unicode characters + repeat("pattern", 10), # Repeating pattern + ] + let bfResult = initializeBloomFilter(testSize, 0.01) check bfResult.isOk var bf = bfResult.get var inserted = newSeq[string](testSize) - + # Test pattern handling for pattern in patterns: bf.insert(pattern) assert bf.lookup(pattern), "failed lookup pattern: " & pattern - + # Test general insertion and lookup - for i in 0.. 
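A hedged sketch of the rolling filter introduced above (illustration only; SdsMessageID is the seq[byte] alias from src/message.nim, and the byte/string cast mirrors the one used in rolling_bloom_filter.nim):

import ./src/[message, rolling_bloom_filter]

var rbf = newRollingBloomFilter()  # defaults: capacity 10_000, error rate 0.001
let id: SdsMessageID = cast[seq[byte]]("msg-1")
rbf.add(id)  # inserts the ID into the filter and tracks it for cleaning
doAssert rbf.contains(id)
# add() invokes clean() once the tracked IDs exceed maxCapacity (capacity + 20%);
# clean() rebuilds the filter around the newest minCapacity (capacity - 20%) IDs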
Date: Mon, 17 Feb 2025 14:47:01 +0530 Subject: [PATCH 3/5] feat: add reliability.nim --- src/reliability.nim | 346 ++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 346 insertions(+) create mode 100644 src/reliability.nim diff --git a/src/reliability.nim b/src/reliability.nim new file mode 100644 index 0000000..0cc490e --- /dev/null +++ b/src/reliability.nim @@ -0,0 +1,346 @@ +import std/[times, locks, tables, sets] +import chronos, results +import ../src/[message, protobuf, reliability_utils, rolling_bloom_filter] + +proc newReliabilityManager*( + channelId: string, config: ReliabilityConfig = defaultConfig() +): Result[ReliabilityManager, ReliabilityError] = + ## Creates a new ReliabilityManager with the specified channel ID and configuration. + ## + ## Parameters: + ## - channelId: A unique identifier for the communication channel. + ## - config: Configuration options for the ReliabilityManager. If not provided, default configuration is used. + ## + ## Returns: + ## A Result containing either a new ReliabilityManager instance or an error. + if channelId.len == 0: + return err(reInvalidArgument) + + try: + let bloomFilter = newRollingBloomFilter( + config.bloomFilterCapacity, config.bloomFilterErrorRate, config.bloomFilterWindow + ) + + let rm = ReliabilityManager( + lamportTimestamp: 0, + messageHistory: @[], + bloomFilter: bloomFilter, + outgoingBuffer: @[], + incomingBuffer: @[], + channelId: channelId, + config: config, + ) + initLock(rm.lock) + return ok(rm) + except: + return err(reOutOfMemory) + +proc reviewAckStatus(rm: ReliabilityManager, msg: Message) = + var i = 0 + while i < rm.outgoingBuffer.len: + var acknowledged = false + let outMsg = rm.outgoingBuffer[i] + + # Check if message is in causal history + for msgID in msg.causalHistory: + if outMsg.message.messageId == msgID: + acknowledged = true + break + + # Check bloom filter if not already acknowledged + if not acknowledged and msg.bloomFilter.len > 0: + let bfResult = deserializeBloomFilter(msg.bloomFilter) + if bfResult.isOk: + var rbf = RollingBloomFilter( + filter: bfResult.get(), window: rm.bloomFilter.window, messages: @[] + ) + if rbf.contains(outMsg.message.messageId): + acknowledged = true + else: + logError("Failed to deserialize bloom filter") + + if acknowledged: + if rm.onMessageSent != nil: + rm.onMessageSent(outMsg.message.messageId) + rm.outgoingBuffer.delete(i) + else: + inc i + +proc wrapOutgoingMessage*( + rm: ReliabilityManager, message: seq[byte], messageId: MessageID +): Result[seq[byte], ReliabilityError] = + ## Wraps an outgoing message with reliability metadata. + ## + ## Parameters: + ## - message: The content of the message to be sent. + ## + ## Returns: + ## A Result containing either a Message object with reliability metadata or an error. 
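+  # reject empty or oversized payloads before taking the lock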
+ if message.len == 0: + return err(reInvalidArgument) + if message.len > MaxMessageSize: + return err(reMessageTooLarge) + + withLock rm.lock: + try: + rm.updateLamportTimestamp(getTime().toUnix) + + # Serialize current bloom filter + var bloomBytes: seq[byte] + let bfResult = serializeBloomFilter(rm.bloomFilter.filter) + if bfResult.isErr: + logError("Failed to serialize bloom filter") + bloomBytes = @[] + else: + bloomBytes = bfResult.get() + + let msg = Message( + messageId: messageId, + lamportTimestamp: rm.lamportTimestamp, + causalHistory: rm.getRecentMessageIDs(rm.config.maxCausalHistory), + channelId: rm.channelId, + content: message, + bloomFilter: bloomBytes, + ) + + # Add to outgoing buffer + rm.outgoingBuffer.add( + UnacknowledgedMessage(message: msg, sendTime: getTime(), resendAttempts: 0) + ) + + # Add to causal history and bloom filter + rm.bloomFilter.add(msg.messageId) + rm.addToHistory(msg.messageId) + + return serializeMessage(msg) + except: + return err(reInternalError) + +proc processIncomingBuffer(rm: ReliabilityManager) = + withLock rm.lock: + if rm.incomingBuffer.len == 0: + return + + # Create dependency map + var dependencies = initTable[MessageID, seq[MessageID]]() + var readyToProcess: seq[MessageID] = @[] + + # Build dependency graph and find initially ready messages + for msg in rm.incomingBuffer: + var hasMissingDeps = false + for depId in msg.causalHistory: + if not rm.bloomFilter.contains(depId): + hasMissingDeps = true + if depId notin dependencies: + dependencies[depId] = @[] + dependencies[depId].add(msg.messageId) + + if not hasMissingDeps: + readyToProcess.add(msg.messageId) + + # Process ready messages and their dependents + var newIncomingBuffer: seq[Message] = @[] + var processed = initHashSet[MessageID]() + + while readyToProcess.len > 0: + let msgId = readyToProcess.pop() + if msgId in processed: + continue + + # Process this message + for msg in rm.incomingBuffer: + if msg.messageId == msgId: + rm.addToHistory(msg.messageId) + if rm.onMessageReady != nil: + rm.onMessageReady(msg.messageId) + processed.incl(msgId) + + # Add any dependent messages that might now be ready + if msgId in dependencies: + for dependentId in dependencies[msgId]: + readyToProcess.add(dependentId) + break + + # Update incomingBuffer with remaining messages + for msg in rm.incomingBuffer: + if msg.messageId notin processed: + newIncomingBuffer.add(msg) + + rm.incomingBuffer = newIncomingBuffer + +proc unwrapReceivedMessage*( + rm: ReliabilityManager, message: seq[byte] +): Result[tuple[message: seq[byte], missingDeps: seq[MessageID]], ReliabilityError] = + ## Unwraps a received message and processes its reliability metadata. + ## + ## Parameters: + ## - message: The received Message object. + ## + ## Returns: + ## A Result containing either a tuple with the processed message and missing dependencies, or an error. 
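+  # deserialize, drop duplicates via the bloom filter, then resolve causal dependencies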
+ try: + let msgResult = deserializeMessage(message) + if not msgResult.isOk: + return err(msgResult.error) + + let msg = msgResult.get + if rm.bloomFilter.contains(msg.messageId): + return ok((msg.content, @[])) + + rm.bloomFilter.add(msg.messageId) + + # Update Lamport timestamp + rm.updateLamportTimestamp(msg.lamportTimestamp) + + # Review ACK status for outgoing messages + rm.reviewAckStatus(msg) + + var missingDeps: seq[MessageID] = @[] + for depId in msg.causalHistory: + if not rm.bloomFilter.contains(depId): + missingDeps.add(depId) + + if missingDeps.len == 0: + # Check if any dependencies are still in incoming buffer + var depsInBuffer = false + for bufferedMsg in rm.incomingBuffer: + if bufferedMsg.messageId in msg.causalHistory: + depsInBuffer = true + break + if depsInBuffer: + rm.incomingBuffer.add(msg) + else: + # All dependencies met, add to history + rm.addToHistory(msg.messageId) + rm.processIncomingBuffer() + if rm.onMessageReady != nil: + rm.onMessageReady(msg.messageId) + else: + # Buffer message and request missing dependencies + rm.incomingBuffer.add(msg) + if rm.onMissingDependencies != nil: + rm.onMissingDependencies(msg.messageId, missingDeps) + + return ok((msg.content, missingDeps)) + except: + return err(reInternalError) + +proc markDependenciesMet*( + rm: ReliabilityManager, messageIds: seq[MessageID] +): Result[void, ReliabilityError] = + ## Marks the specified message dependencies as met. + ## + ## Parameters: + ## - messageIds: A sequence of message IDs to mark as met. + ## + ## Returns: + ## A Result indicating success or an error. + try: + # Add all messageIds to bloom filter + for msgId in messageIds: + if not rm.bloomFilter.contains(msgId): + rm.bloomFilter.add(msgId) + # rm.addToHistory(msgId) -- not needed as this proc usually called when msg in long-term storage of application? + rm.processIncomingBuffer() + + return ok() + except: + return err(reInternalError) + +proc setCallbacks*( + rm: ReliabilityManager, + onMessageReady: proc(messageId: MessageID) {.gcsafe.}, + onMessageSent: proc(messageId: MessageID) {.gcsafe.}, + onMissingDependencies: + proc(messageId: MessageID, missingDeps: seq[MessageID]) {.gcsafe.}, + onPeriodicSync: PeriodicSyncCallback = nil, +) = + ## Sets the callback functions for various events in the ReliabilityManager. + ## + ## Parameters: + ## - onMessageReady: Callback function called when a message is ready to be processed. + ## - onMessageSent: Callback function called when a message is confirmed as sent. + ## - onMissingDependencies: Callback function called when a message has missing dependencies. + ## - onPeriodicSync: Callback function called to notify about periodic sync + withLock rm.lock: + rm.onMessageReady = onMessageReady + rm.onMessageSent = onMessageSent + rm.onMissingDependencies = onMissingDependencies + rm.onPeriodicSync = onPeriodicSync + +proc checkUnacknowledgedMessages*(rm: ReliabilityManager) {.raises: [].} = + ## Checks and processes unacknowledged messages in the outgoing buffer. 
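+  # resend policy: re-queue after resendInterval, give up after maxResendAttempts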
+ withLock rm.lock: + let now = getTime() + var newOutgoingBuffer: seq[UnacknowledgedMessage] = @[] + + try: + for unackMsg in rm.outgoingBuffer: + let elapsed = now - unackMsg.sendTime + if elapsed > rm.config.resendInterval: + # Time to attempt resend + if unackMsg.resendAttempts < rm.config.maxResendAttempts: + var updatedMsg = unackMsg + updatedMsg.resendAttempts += 1 + updatedMsg.sendTime = now + newOutgoingBuffer.add(updatedMsg) + else: + if rm.onMessageSent != nil: + rm.onMessageSent(unackMsg.message.messageId) + else: + newOutgoingBuffer.add(unackMsg) + + rm.outgoingBuffer = newOutgoingBuffer + except Exception as e: + logError("Error in checking unacknowledged messages: " & e.msg) + +proc periodicBufferSweep(rm: ReliabilityManager) {.async: (raises: [CancelledError]).} = + ## Periodically sweeps the buffer to clean up and check unacknowledged messages. + while true: + {.gcsafe.}: + try: + rm.checkUnacknowledgedMessages() + rm.cleanBloomFilter() + except Exception as e: + logError("Error in periodic buffer sweep: " & e.msg) + + await sleepAsync(chronos.milliseconds(rm.config.bufferSweepInterval.inMilliseconds)) + +proc periodicSyncMessage(rm: ReliabilityManager) {.async: (raises: [CancelledError]).} = + ## Periodically notifies to send a sync message to maintain connectivity. + while true: + {.gcsafe.}: + try: + if rm.onPeriodicSync != nil: + rm.onPeriodicSync() + except Exception as e: + logError("Error in periodic sync: " & e.msg) + await sleepAsync(chronos.seconds(rm.config.syncMessageInterval.inSeconds)) + +proc startPeriodicTasks*(rm: ReliabilityManager) = + ## Starts the periodic tasks for buffer sweeping and sync message sending. + ## + ## This procedure should be called after creating a ReliabilityManager to enable automatic maintenance. + asyncSpawn rm.periodicBufferSweep() + asyncSpawn rm.periodicSyncMessage() + +proc resetReliabilityManager*(rm: ReliabilityManager): Result[void, ReliabilityError] = + ## Resets the ReliabilityManager to its initial state. + ## + ## This procedure clears all buffers and resets the Lamport timestamp. + ## + ## Returns: + ## A Result indicating success or an error if the Bloom filter initialization fails. 
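+  # clear all buffers and rebuild the bloom filter under the lock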
+ withLock rm.lock: + try: + rm.lamportTimestamp = 0 + rm.messageHistory.setLen(0) + rm.outgoingBuffer.setLen(0) + rm.incomingBuffer.setLen(0) + rm.bloomFilter = newRollingBloomFilter( + rm.config.bloomFilterCapacity, rm.config.bloomFilterErrorRate, + rm.config.bloomFilterWindow, + ) + return ok() + except: + return err(reInternalError) From 3e25aec7ce27263886db58f31ec9824d33bdb0d8 Mon Sep 17 00:00:00 2001 From: shash256 <111925100+shash256@users.noreply.github.com> Date: Mon, 17 Feb 2025 16:16:08 +0530 Subject: [PATCH 4/5] chore: updates from prev suggestions --- src/reliability.nim | 203 +++++++++++++++++++++----------------------- 1 file changed, 98 insertions(+), 105 deletions(-) diff --git a/src/reliability.nim b/src/reliability.nim index 0cc490e..ebb533d 100644 --- a/src/reliability.nim +++ b/src/reliability.nim @@ -1,9 +1,9 @@ -import std/[times, locks, tables, sets] -import chronos, results -import ../src/[message, protobuf, reliability_utils, rolling_bloom_filter] +import std/[times, locks, tables, sets, sequtils] +import chronos, results, chronicles +import ./[message, protobuf, reliability_utils, rolling_bloom_filter] proc newReliabilityManager*( - channelId: string, config: ReliabilityConfig = defaultConfig() + channelId: SdsChannelID, config: ReliabilityConfig = defaultConfig() ): Result[ReliabilityManager, ReliabilityError] = ## Creates a new ReliabilityManager with the specified channel ID and configuration. ## @@ -14,12 +14,11 @@ proc newReliabilityManager*( ## Returns: ## A Result containing either a new ReliabilityManager instance or an error. if channelId.len == 0: - return err(reInvalidArgument) + return err(ReliabilityError.reInvalidArgument) try: - let bloomFilter = newRollingBloomFilter( - config.bloomFilterCapacity, config.bloomFilterErrorRate, config.bloomFilterWindow - ) + let bloomFilter = + newRollingBloomFilter(config.bloomFilterCapacity, config.bloomFilterErrorRate) let rm = ReliabilityManager( lamportTimestamp: 0, @@ -32,10 +31,11 @@ proc newReliabilityManager*( ) initLock(rm.lock) return ok(rm) - except: - return err(reOutOfMemory) + except Exception: + error "Failed to create ReliabilityManager", msg = getCurrentExceptionMsg() + return err(ReliabilityError.reOutOfMemory) -proc reviewAckStatus(rm: ReliabilityManager, msg: Message) = +proc reviewAckStatus(rm: ReliabilityManager, msg: SdsMessage) {.gcsafe.} = var i = 0 while i < rm.outgoingBuffer.len: var acknowledged = false @@ -50,57 +50,62 @@ proc reviewAckStatus(rm: ReliabilityManager, msg: Message) = # Check bloom filter if not already acknowledged if not acknowledged and msg.bloomFilter.len > 0: let bfResult = deserializeBloomFilter(msg.bloomFilter) - if bfResult.isOk: + if bfResult.isOk(): var rbf = RollingBloomFilter( - filter: bfResult.get(), window: rm.bloomFilter.window, messages: @[] + filter: bfResult.get(), + capacity: bfResult.get().capacity, + minCapacity: ( + bfResult.get().capacity.float * (100 - CapacityFlexPercent).float / 100.0 + ).int, + maxCapacity: ( + bfResult.get().capacity.float * (100 + CapacityFlexPercent).float / 100.0 + ).int, + messages: @[], ) if rbf.contains(outMsg.message.messageId): acknowledged = true else: - logError("Failed to deserialize bloom filter") + error "Failed to deserialize bloom filter", error = bfResult.error if acknowledged: - if rm.onMessageSent != nil: + if not rm.onMessageSent.isNil(): rm.onMessageSent(outMsg.message.messageId) rm.outgoingBuffer.delete(i) else: inc i proc wrapOutgoingMessage*( - rm: ReliabilityManager, message: seq[byte], messageId: 
MessageID + rm: ReliabilityManager, message: seq[byte], messageId: SdsMessageID ): Result[seq[byte], ReliabilityError] = ## Wraps an outgoing message with reliability metadata. ## ## Parameters: ## - message: The content of the message to be sent. + ## - messageId: Unique identifier for the message ## ## Returns: - ## A Result containing either a Message object with reliability metadata or an error. + ## A Result containing either wrapped message bytes or an error. if message.len == 0: - return err(reInvalidArgument) + return err(ReliabilityError.reInvalidArgument) if message.len > MaxMessageSize: - return err(reMessageTooLarge) + return err(ReliabilityError.reMessageTooLarge) withLock rm.lock: try: rm.updateLamportTimestamp(getTime().toUnix) - # Serialize current bloom filter - var bloomBytes: seq[byte] let bfResult = serializeBloomFilter(rm.bloomFilter.filter) if bfResult.isErr: - logError("Failed to serialize bloom filter") - bloomBytes = @[] - else: - bloomBytes = bfResult.get() + error "Failed to serialize bloom filter" + return err(ReliabilityError.reSerializationError) - let msg = Message( + let msg = SdsMessage( messageId: messageId, lamportTimestamp: rm.lamportTimestamp, - causalHistory: rm.getRecentMessageIDs(rm.config.maxCausalHistory), + causalHistory: rm.getRecentSdsMessageIDs(rm.config.maxCausalHistory), channelId: rm.channelId, content: message, - bloomFilter: bloomBytes, + bloomFilter: bfResult.get(), ) # Add to outgoing buffer @@ -113,17 +118,19 @@ proc wrapOutgoingMessage*( rm.addToHistory(msg.messageId) return serializeMessage(msg) - except: - return err(reInternalError) + except Exception: + error "Failed to wrap message", msg = getCurrentExceptionMsg() + return err(ReliabilityError.reSerializationError) -proc processIncomingBuffer(rm: ReliabilityManager) = +proc processIncomingBuffer(rm: ReliabilityManager) {.gcsafe.} = withLock rm.lock: if rm.incomingBuffer.len == 0: return # Create dependency map - var dependencies = initTable[MessageID, seq[MessageID]]() - var readyToProcess: seq[MessageID] = @[] + var dependencies = initTable[SdsMessageID, seq[SdsMessageID]]() + var readyToProcess: seq[SdsMessageID] = @[] + var processed = initHashSet[SdsMessageID]() # Build dependency graph and find initially ready messages for msg in rm.incomingBuffer: @@ -138,10 +145,6 @@ proc processIncomingBuffer(rm: ReliabilityManager) = if not hasMissingDeps: readyToProcess.add(msg.messageId) - # Process ready messages and their dependents - var newIncomingBuffer: seq[Message] = @[] - var processed = initHashSet[MessageID]() - while readyToProcess.len > 0: let msgId = readyToProcess.pop() if msgId in processed: @@ -151,39 +154,31 @@ proc processIncomingBuffer(rm: ReliabilityManager) = for msg in rm.incomingBuffer: if msg.messageId == msgId: rm.addToHistory(msg.messageId) - if rm.onMessageReady != nil: + if not rm.onMessageReady.isNil(): rm.onMessageReady(msg.messageId) processed.incl(msgId) # Add any dependent messages that might now be ready if msgId in dependencies: - for dependentId in dependencies[msgId]: - readyToProcess.add(dependentId) + readyToProcess.add(dependencies[msgId]) break - # Update incomingBuffer with remaining messages - for msg in rm.incomingBuffer: - if msg.messageId notin processed: - newIncomingBuffer.add(msg) - - rm.incomingBuffer = newIncomingBuffer + rm.incomingBuffer = rm.incomingBuffer.filterIt(it.messageId notin processed) proc unwrapReceivedMessage*( rm: ReliabilityManager, message: seq[byte] -): Result[tuple[message: seq[byte], missingDeps: seq[MessageID]], 
ReliabilityError] = +): Result[tuple[message: seq[byte], missingDeps: seq[SdsMessageID]], ReliabilityError] = ## Unwraps a received message and processes its reliability metadata. ## ## Parameters: - ## - message: The received Message object. + ## - message: The received message bytes ## ## Returns: - ## A Result containing either a tuple with the processed message and missing dependencies, or an error. + ## A Result containing either tuple of (processed message, missing dependencies) or an error. try: - let msgResult = deserializeMessage(message) - if not msgResult.isOk: - return err(msgResult.error) + let msg = deserializeMessage(message).valueOr: + return err(ReliabilityError.reDeserializationError) - let msg = msgResult.get if rm.bloomFilter.contains(msg.messageId): return ok((msg.content, @[])) @@ -195,7 +190,7 @@ proc unwrapReceivedMessage*( # Review ACK status for outgoing messages rm.reviewAckStatus(msg) - var missingDeps: seq[MessageID] = @[] + var missingDeps: seq[SdsMessageID] = @[] for depId in msg.causalHistory: if not rm.bloomFilter.contains(depId): missingDeps.add(depId) @@ -207,26 +202,27 @@ proc unwrapReceivedMessage*( if bufferedMsg.messageId in msg.causalHistory: depsInBuffer = true break + if depsInBuffer: rm.incomingBuffer.add(msg) else: # All dependencies met, add to history rm.addToHistory(msg.messageId) rm.processIncomingBuffer() - if rm.onMessageReady != nil: + if not rm.onMessageReady.isNil(): rm.onMessageReady(msg.messageId) else: - # Buffer message and request missing dependencies rm.incomingBuffer.add(msg) - if rm.onMissingDependencies != nil: + if not rm.onMissingDependencies.isNil(): rm.onMissingDependencies(msg.messageId, missingDeps) return ok((msg.content, missingDeps)) - except: - return err(reInternalError) + except Exception: + error "Failed to unwrap message", msg = getCurrentExceptionMsg() + return err(ReliabilityError.reDeserializationError) proc markDependenciesMet*( - rm: ReliabilityManager, messageIds: seq[MessageID] + rm: ReliabilityManager, messageIds: seq[SdsMessageID] ): Result[void, ReliabilityError] = ## Marks the specified message dependencies as met. ## @@ -241,18 +237,19 @@ proc markDependenciesMet*( if not rm.bloomFilter.contains(msgId): rm.bloomFilter.add(msgId) # rm.addToHistory(msgId) -- not needed as this proc usually called when msg in long-term storage of application? - rm.processIncomingBuffer() + rm.processIncomingBuffer() return ok() - except: - return err(reInternalError) + except Exception: + error "Failed to mark dependencies as met", msg = getCurrentExceptionMsg() + return err(ReliabilityError.reInternalError) proc setCallbacks*( rm: ReliabilityManager, - onMessageReady: proc(messageId: MessageID) {.gcsafe.}, - onMessageSent: proc(messageId: MessageID) {.gcsafe.}, + onMessageReady: proc(messageId: SdsMessageID) {.gcsafe.}, + onMessageSent: proc(messageId: SdsMessageID) {.gcsafe.}, onMissingDependencies: - proc(messageId: MessageID, missingDeps: seq[MessageID]) {.gcsafe.}, + proc(messageId: SdsMessageID, missingDeps: seq[SdsMessageID]) {.gcsafe.}, onPeriodicSync: PeriodicSyncCallback = nil, ) = ## Sets the callback functions for various events in the ReliabilityManager. @@ -268,53 +265,52 @@ proc setCallbacks*( rm.onMissingDependencies = onMissingDependencies rm.onPeriodicSync = onPeriodicSync -proc checkUnacknowledgedMessages*(rm: ReliabilityManager) {.raises: [].} = +proc checkUnacknowledgedMessages(rm: ReliabilityManager) {.gcsafe.} = ## Checks and processes unacknowledged messages in the outgoing buffer. 
withLock rm.lock: let now = getTime() var newOutgoingBuffer: seq[UnacknowledgedMessage] = @[] - try: - for unackMsg in rm.outgoingBuffer: - let elapsed = now - unackMsg.sendTime - if elapsed > rm.config.resendInterval: - # Time to attempt resend - if unackMsg.resendAttempts < rm.config.maxResendAttempts: - var updatedMsg = unackMsg - updatedMsg.resendAttempts += 1 - updatedMsg.sendTime = now - newOutgoingBuffer.add(updatedMsg) - else: - if rm.onMessageSent != nil: - rm.onMessageSent(unackMsg.message.messageId) + for unackMsg in rm.outgoingBuffer: + let elapsed = now - unackMsg.sendTime + if elapsed > rm.config.resendInterval: + # Time to attempt resend + if unackMsg.resendAttempts < rm.config.maxResendAttempts: + var updatedMsg = unackMsg + updatedMsg.resendAttempts += 1 + updatedMsg.sendTime = now + newOutgoingBuffer.add(updatedMsg) else: - newOutgoingBuffer.add(unackMsg) + if not rm.onMessageSent.isNil(): + rm.onMessageSent(unackMsg.message.messageId) + else: + newOutgoingBuffer.add(unackMsg) - rm.outgoingBuffer = newOutgoingBuffer - except Exception as e: - logError("Error in checking unacknowledged messages: " & e.msg) + rm.outgoingBuffer = newOutgoingBuffer -proc periodicBufferSweep(rm: ReliabilityManager) {.async: (raises: [CancelledError]).} = +proc periodicBufferSweep( + rm: ReliabilityManager +) {.async: (raises: [CancelledError]), gcsafe.} = ## Periodically sweeps the buffer to clean up and check unacknowledged messages. while true: - {.gcsafe.}: - try: - rm.checkUnacknowledgedMessages() - rm.cleanBloomFilter() - except Exception as e: - logError("Error in periodic buffer sweep: " & e.msg) + try: + rm.checkUnacknowledgedMessages() + rm.cleanBloomFilter() + except Exception: + error "Error in periodic buffer sweep", msg = getCurrentExceptionMsg() await sleepAsync(chronos.milliseconds(rm.config.bufferSweepInterval.inMilliseconds)) -proc periodicSyncMessage(rm: ReliabilityManager) {.async: (raises: [CancelledError]).} = +proc periodicSyncMessage( + rm: ReliabilityManager +) {.async: (raises: [CancelledError]), gcsafe.} = ## Periodically notifies to send a sync message to maintain connectivity. while true: - {.gcsafe.}: - try: - if rm.onPeriodicSync != nil: - rm.onPeriodicSync() - except Exception as e: - logError("Error in periodic sync: " & e.msg) + try: + if not rm.onPeriodicSync.isNil(): + rm.onPeriodicSync() + except Exception: + error "Error in periodic sync", msg = getCurrentExceptionMsg() await sleepAsync(chronos.seconds(rm.config.syncMessageInterval.inSeconds)) proc startPeriodicTasks*(rm: ReliabilityManager) = @@ -328,9 +324,6 @@ proc resetReliabilityManager*(rm: ReliabilityManager): Result[void, ReliabilityE ## Resets the ReliabilityManager to its initial state. ## ## This procedure clears all buffers and resets the Lamport timestamp. - ## - ## Returns: - ## A Result indicating success or an error if the Bloom filter initialization fails. 
withLock rm.lock: try: rm.lamportTimestamp = 0 @@ -338,9 +331,9 @@ proc resetReliabilityManager*(rm: ReliabilityManager): Result[void, ReliabilityE rm.outgoingBuffer.setLen(0) rm.incomingBuffer.setLen(0) rm.bloomFilter = newRollingBloomFilter( - rm.config.bloomFilterCapacity, rm.config.bloomFilterErrorRate, - rm.config.bloomFilterWindow, + rm.config.bloomFilterCapacity, rm.config.bloomFilterErrorRate ) return ok() - except: - return err(reInternalError) + except Exception: + error "Failed to reset ReliabilityManager", msg = getCurrentExceptionMsg() + return err(ReliabilityError.reInternalError) From 1d1c7e683486d901e1311f469f0903ce9c48d506 Mon Sep 17 00:00:00 2001 From: shash256 <111925100+shash256@users.noreply.github.com> Date: Mon, 10 Mar 2025 16:07:00 +0530 Subject: [PATCH 5/5] chore: address comments from review 1 --- src/message.nim | 8 ++- src/protobuf.nim | 10 ++- src/reliability.nim | 145 +++++++++++++++++++++----------------- src/reliability_utils.nim | 21 ++++-- tests/test_bloom.nim | 1 - 5 files changed, 109 insertions(+), 76 deletions(-) diff --git a/src/message.nim b/src/message.nim index 83d1f3a..f9c68c0 100644 --- a/src/message.nim +++ b/src/message.nim @@ -1,4 +1,4 @@ -import std/times +import std/[times, options, sets] type SdsMessageID* = seq[byte] @@ -8,7 +8,7 @@ type messageId*: SdsMessageID lamportTimestamp*: int64 causalHistory*: seq[SdsMessageID] - channelId*: SdsChannelID + channelId*: Option[SdsChannelID] content*: seq[byte] bloomFilter*: seq[byte] @@ -17,6 +17,10 @@ type sendTime*: Time resendAttempts*: int + IncomingMessage* = object + message*: SdsMessage + missingDeps*: HashSet[SdsMessageID] + const DefaultMaxMessageHistory* = 1000 DefaultMaxCausalHistory* = 10 diff --git a/src/protobuf.nim b/src/protobuf.nim index 5229182..4689da2 100644 --- a/src/protobuf.nim +++ b/src/protobuf.nim @@ -12,7 +12,8 @@ proc encode*(msg: SdsMessage): ProtoBuffer = for hist in msg.causalHistory: pb.write(3, hist) - pb.write(4, msg.channelId) + if msg.channelId.isSome(): + pb.write(4, msg.channelId.get()) pb.write(5, msg.content) pb.write(6, msg.bloomFilter) pb.finish() @@ -36,8 +37,11 @@ proc decode*(T: type SdsMessage, buffer: seq[byte]): ProtobufResult[T] = if histResult.isOk: msg.causalHistory = causalHistory - if not ?pb.getField(4, msg.channelId): - return err(ProtobufError.missingRequiredField("channelId")) + var channelId: seq[byte] + if ?pb.getField(4, channelId): + msg.channelId = some(channelId) + else: + msg.channelId = none[SdsChannelID]() if not ?pb.getField(5, msg.content): return err(ProtobufError.missingRequiredField("content")) diff --git a/src/reliability.nim b/src/reliability.nim index ebb533d..a164d7c 100644 --- a/src/reliability.nim +++ b/src/reliability.nim @@ -1,9 +1,9 @@ -import std/[times, locks, tables, sets, sequtils] +import std/[times, locks, tables, sets, options] import chronos, results, chronicles import ./[message, protobuf, reliability_utils, rolling_bloom_filter] proc newReliabilityManager*( - channelId: SdsChannelID, config: ReliabilityConfig = defaultConfig() + channelId: Option[SdsChannelID], config: ReliabilityConfig = defaultConfig() ): Result[ReliabilityManager, ReliabilityError] = ## Creates a new ReliabilityManager with the specified channel ID and configuration. ## @@ -13,7 +13,7 @@ proc newReliabilityManager*( ## ## Returns: ## A Result containing either a new ReliabilityManager instance or an error. 
- if channelId.len == 0: + if not channelId.isSome(): return err(ReliabilityError.reInvalidArgument) try: @@ -25,7 +25,7 @@ proc newReliabilityManager*( messageHistory: @[], bloomFilter: bloomFilter, outgoingBuffer: @[], - incomingBuffer: @[], + incomingBuffer: initTable[SdsMessageID, IncomingMessage](), channelId: channelId, config: config, ) @@ -35,23 +35,27 @@ proc newReliabilityManager*( error "Failed to create ReliabilityManager", msg = getCurrentExceptionMsg() return err(ReliabilityError.reOutOfMemory) +proc isAcknowledged*( + msg: UnacknowledgedMessage, + causalHistory: seq[SdsMessageID], + rbf: Option[RollingBloomFilter], +): bool = + if msg.message.messageId in causalHistory: + return true + + if rbf.isSome(): + return rbf.get().contains(msg.message.messageId) + + false + proc reviewAckStatus(rm: ReliabilityManager, msg: SdsMessage) {.gcsafe.} = - var i = 0 - while i < rm.outgoingBuffer.len: - var acknowledged = false - let outMsg = rm.outgoingBuffer[i] - - # Check if message is in causal history - for msgID in msg.causalHistory: - if outMsg.message.messageId == msgID: - acknowledged = true - break - - # Check bloom filter if not already acknowledged - if not acknowledged and msg.bloomFilter.len > 0: - let bfResult = deserializeBloomFilter(msg.bloomFilter) - if bfResult.isOk(): - var rbf = RollingBloomFilter( + # Parse bloom filter + var rbf: Option[RollingBloomFilter] + if msg.bloomFilter.len > 0: + let bfResult = deserializeBloomFilter(msg.bloomFilter) + if bfResult.isOk(): + rbf = some( + RollingBloomFilter( filter: bfResult.get(), capacity: bfResult.get().capacity, minCapacity: ( @@ -62,17 +66,27 @@ proc reviewAckStatus(rm: ReliabilityManager, msg: SdsMessage) {.gcsafe.} = ).int, messages: @[], ) - if rbf.contains(outMsg.message.messageId): - acknowledged = true - else: - error "Failed to deserialize bloom filter", error = bfResult.error + ) + else: + error "Failed to deserialize bloom filter", error = bfResult.error + rbf = none[RollingBloomFilter]() + else: + rbf = none[RollingBloomFilter]() - if acknowledged: + # Keep track of indices to delete + var toDelete: seq[int] = @[] + var i = 0 + + while i < rm.outgoingBuffer.len: + let outMsg = rm.outgoingBuffer[i] + if outMsg.isAcknowledged(msg.causalHistory, rbf): if not rm.onMessageSent.isNil(): rm.onMessageSent(outMsg.message.messageId) - rm.outgoingBuffer.delete(i) - else: - inc i + toDelete.add(i) + inc i + + for i in countdown(toDelete.high, 0): # Delete in reverse order to maintain indices + rm.outgoingBuffer.delete(toDelete[i]) proc wrapOutgoingMessage*( rm: ReliabilityManager, message: seq[byte], messageId: SdsMessageID @@ -127,43 +141,36 @@ proc processIncomingBuffer(rm: ReliabilityManager) {.gcsafe.} = if rm.incomingBuffer.len == 0: return - # Create dependency map - var dependencies = initTable[SdsMessageID, seq[SdsMessageID]]() - var readyToProcess: seq[SdsMessageID] = @[] var processed = initHashSet[SdsMessageID]() + var readyToProcess = newSeq[SdsMessageID]() - # Build dependency graph and find initially ready messages - for msg in rm.incomingBuffer: - var hasMissingDeps = false - for depId in msg.causalHistory: - if not rm.bloomFilter.contains(depId): - hasMissingDeps = true - if depId notin dependencies: - dependencies[depId] = @[] - dependencies[depId].add(msg.messageId) - - if not hasMissingDeps: - readyToProcess.add(msg.messageId) + # Find initially ready messages + for msgId, entry in rm.incomingBuffer: + if entry.missingDeps.len == 0: + readyToProcess.add(msgId) while readyToProcess.len > 0: let msgId = 
readyToProcess.pop() if msgId in processed: continue - # Process this message - for msg in rm.incomingBuffer: - if msg.messageId == msgId: - rm.addToHistory(msg.messageId) - if not rm.onMessageReady.isNil(): - rm.onMessageReady(msg.messageId) - processed.incl(msgId) + if msgId in rm.incomingBuffer: + rm.addToHistory(msgId) + if not rm.onMessageReady.isNil(): + rm.onMessageReady(msgId) + processed.incl(msgId) - # Add any dependent messages that might now be ready - if msgId in dependencies: - readyToProcess.add(dependencies[msgId]) - break + # Update dependencies for remaining messages + for remainingId, entry in rm.incomingBuffer: + if remainingId notin processed: + if msgId in entry.missingDeps: + rm.incomingBuffer[remainingId].missingDeps.excl(msgId) + if rm.incomingBuffer[remainingId].missingDeps.len == 0: + readyToProcess.add(remainingId) - rm.incomingBuffer = rm.incomingBuffer.filterIt(it.messageId notin processed) + # Remove processed messages + for msgId in processed: + rm.incomingBuffer.del(msgId) proc unwrapReceivedMessage*( rm: ReliabilityManager, message: seq[byte] @@ -179,7 +186,7 @@ proc unwrapReceivedMessage*( let msg = deserializeMessage(message).valueOr: return err(ReliabilityError.reDeserializationError) - if rm.bloomFilter.contains(msg.messageId): + if msg.messageId in rm.messageHistory: return ok((msg.content, @[])) rm.bloomFilter.add(msg.messageId) @@ -190,21 +197,21 @@ proc unwrapReceivedMessage*( # Review ACK status for outgoing messages rm.reviewAckStatus(msg) - var missingDeps: seq[SdsMessageID] = @[] - for depId in msg.causalHistory: - if not rm.bloomFilter.contains(depId): - missingDeps.add(depId) + var missingDeps = rm.checkDependencies(msg.causalHistory) if missingDeps.len == 0: # Check if any dependencies are still in incoming buffer var depsInBuffer = false - for bufferedMsg in rm.incomingBuffer: - if bufferedMsg.messageId in msg.causalHistory: + for msgId, entry in rm.incomingBuffer.pairs(): + if msgId in msg.causalHistory: depsInBuffer = true break if depsInBuffer: - rm.incomingBuffer.add(msg) + rm.incomingBuffer[msg.messageId] = IncomingMessage( + message: msg, + missingDeps: initHashSet[SdsMessageID]() + ) else: # All dependencies met, add to history rm.addToHistory(msg.messageId) @@ -212,7 +219,10 @@ proc unwrapReceivedMessage*( if not rm.onMessageReady.isNil(): rm.onMessageReady(msg.messageId) else: - rm.incomingBuffer.add(msg) + rm.incomingBuffer[msg.messageId] = IncomingMessage( + message: msg, + missingDeps: missingDeps.toHashSet() + ) if not rm.onMissingDependencies.isNil(): rm.onMissingDependencies(msg.messageId, missingDeps) @@ -238,6 +248,11 @@ proc markDependenciesMet*( rm.bloomFilter.add(msgId) # rm.addToHistory(msgId) -- not needed as this proc usually called when msg in long-term storage of application? 
+ # Update any pending messages that depend on this one + for pendingId, entry in rm.incomingBuffer: + if msgId in entry.missingDeps: + rm.incomingBuffer[pendingId].missingDeps.excl(msgId) + rm.processIncomingBuffer() return ok() except Exception: @@ -329,7 +344,7 @@ proc resetReliabilityManager*(rm: ReliabilityManager): Result[void, ReliabilityE rm.lamportTimestamp = 0 rm.messageHistory.setLen(0) rm.outgoingBuffer.setLen(0) - rm.incomingBuffer.setLen(0) + rm.incomingBuffer.clear() rm.bloomFilter = newRollingBloomFilter( rm.config.bloomFilterCapacity, rm.config.bloomFilterErrorRate ) diff --git a/src/reliability_utils.nim b/src/reliability_utils.nim index a8d376f..ac05dab 100644 --- a/src/reliability_utils.nim +++ b/src/reliability_utils.nim @@ -1,4 +1,4 @@ -import std/[times, locks] +import std/[times, locks, options] import chronicles import ./[rolling_bloom_filter, message] @@ -20,8 +20,8 @@ type messageHistory*: seq[SdsMessageID] bloomFilter*: RollingBloomFilter outgoingBuffer*: seq[UnacknowledgedMessage] - incomingBuffer*: seq[SdsMessage] - channelId*: SdsChannelID + incomingBuffer*: Table[SdsMessageID, IncomingMessage] + channelId*: Option[SdsChannelID] config*: ReliabilityConfig lock*: Lock onMessageReady*: proc(messageId: SdsMessageID) {.gcsafe.} @@ -59,7 +59,7 @@ proc cleanup*(rm: ReliabilityManager) {.raises: [].} = try: withLock rm.lock: rm.outgoingBuffer.setLen(0) - rm.incomingBuffer.setLen(0) + rm.incomingBuffer.clear() rm.messageHistory.setLen(0) except Exception: error "Error during cleanup", error = getCurrentExceptionMsg() @@ -84,6 +84,15 @@ proc updateLamportTimestamp*( proc getRecentSdsMessageIDs*(rm: ReliabilityManager, n: int): seq[SdsMessageID] = result = rm.messageHistory[max(0, rm.messageHistory.len - n) .. ^1] +proc checkDependencies*( + rm: ReliabilityManager, deps: seq[SdsMessageID] +): seq[SdsMessageID] = + var missingDeps: seq[SdsMessageID] = @[] + for depId in deps: + if depId notin rm.messageHistory: + missingDeps.add(depId) + return missingDeps + proc getMessageHistory*(rm: ReliabilityManager): seq[SdsMessageID] = withLock rm.lock: result = rm.messageHistory @@ -92,6 +101,8 @@ proc getOutgoingBuffer*(rm: ReliabilityManager): seq[UnacknowledgedMessage] = withLock rm.lock: result = rm.outgoingBuffer -proc getIncomingBuffer*(rm: ReliabilityManager): seq[SdsMessage] = +proc getIncomingBuffer*( + rm: ReliabilityManager +): Table[SdsMessageID, message.IncomingMessage] = withLock rm.lock: result = rm.incomingBuffer diff --git a/tests/test_bloom.nim b/tests/test_bloom.nim index 540735d..ad88bba 100644 --- a/tests/test_bloom.nim +++ b/tests/test_bloom.nim @@ -1,7 +1,6 @@ import unittest, results, strutils import ../src/bloom from random import rand, randomize -import ../src/[message, protobuf, protobufutil, reliability_utils, rolling_bloom_filter] suite "bloom filter": setup:
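For orientation, an end-to-end sketch of the public API as it stands after this series (editor's illustration; the module paths, literal IDs, and chronos event-loop setup are assumptions, not part of any diff):

import std/options
import results, chronos
import ./src/[message, reliability, reliability_utils]

proc onReady(id: SdsMessageID) {.gcsafe.} =
  discard # all causal dependencies met; deliver to the application

proc onSent(id: SdsMessageID) {.gcsafe.} =
  discard # confirmed sent: seen in a peer's causal history or bloom filter

proc onMissing(id: SdsMessageID, deps: seq[SdsMessageID]) {.gcsafe.} =
  discard # fetch the missing messages, then call markDependenciesMet

let channel = some(cast[SdsChannelID]("my-channel"))
let rm = newReliabilityManager(channel).expect("manager should initialize")
rm.setCallbacks(onReady, onSent, onMissing)

# sender side: attach Lamport timestamp, causal history and bloom filter
let wire = rm.wrapOutgoingMessage(cast[seq[byte]]("hello"), cast[SdsMessageID]("id-1")).expect("wrap")

# receiver side: recover the payload plus any still-missing dependency IDs
let (payload, missing) = rm.unwrapReceivedMessage(wire).expect("unwrap")
echo "payload bytes: ", payload.len, ", missing deps: ", missing.len

rm.startPeriodicTasks() # periodic resend/sync; needs a running chronos dispatcher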