diff --git a/reliability.nimble b/reliability.nimble new file mode 100644 index 0000000..8bc19c8 --- /dev/null +++ b/reliability.nimble @@ -0,0 +1,16 @@ +# Package +version = "0.1.0" +author = "Waku Team" +description = "E2E Reliability Protocol API" +license = "MIT" +srcDir = "src" + +# Dependencies +requires "nim >= 2.0.8" +requires "chronicles" +requires "libp2p" + +# Tasks +task test, "Run the test suite": + exec "nim c -r tests/test_bloom.nim" + exec "nim c -r tests/test_reliability.nim" diff --git a/src/message.nim b/src/message.nim index 2d950d4..f9c68c0 100644 --- a/src/message.nim +++ b/src/message.nim @@ -1,24 +1,25 @@ -import std/times +import std/[times, options, sets] type - MessageID* = string + SdsMessageID* = seq[byte] + SdsChannelID* = seq[byte] - Message* = object - messageId*: MessageID + SdsMessage* = object + messageId*: SdsMessageID lamportTimestamp*: int64 - causalHistory*: seq[MessageID] - channelId*: string + causalHistory*: seq[SdsMessageID] + channelId*: Option[SdsChannelID] content*: seq[byte] bloomFilter*: seq[byte] UnacknowledgedMessage* = object - message*: Message + message*: SdsMessage sendTime*: Time resendAttempts*: int - TimestampedMessageID* = object - id*: MessageID - timestamp*: Time + IncomingMessage* = object + message*: SdsMessage + missingDeps*: HashSet[SdsMessageID] const DefaultMaxMessageHistory* = 1000 diff --git a/src/private/probabilities.nim b/src/private/probabilities.nim index f7afb0e..1588aef 100644 --- a/src/private/probabilities.nim +++ b/src/private/probabilities.nim @@ -9,9 +9,7 @@ type TErrorForK = seq[float] TAllErrorRates* = array[0 .. 12, TErrorForK] -var kErrors* {.threadvar.}: TAllErrorRates - -kErrors = [ +const kErrors*: TAllErrorRates = [ @[1.0], @[ 1.0, 1.0, 0.3930000000, 0.2830000000, 0.2210000000, 0.1810000000, 0.1540000000, diff --git a/src/protobuf.nim b/src/protobuf.nim index 4230aa9..4689da2 100644 --- a/src/protobuf.nim +++ b/src/protobuf.nim @@ -1,30 +1,28 @@ import libp2p/protobuf/minprotobuf import std/options +import endians import ../src/[message, protobufutil, bloom, reliability_utils] -proc toBytes(s: string): seq[byte] = - result = newSeq[byte](s.len) - copyMem(result[0].addr, s[0].unsafeAddr, s.len) - -proc encode*(msg: Message): ProtoBuffer = +proc encode*(msg: SdsMessage): ProtoBuffer = var pb = initProtoBuffer() pb.write(1, msg.messageId) pb.write(2, uint64(msg.lamportTimestamp)) for hist in msg.causalHistory: - pb.write(3, hist.toBytes) # Convert string to bytes for proper length handling + pb.write(3, hist) - pb.write(4, msg.channelId) + if msg.channelId.isSome(): + pb.write(4, msg.channelId.get()) pb.write(5, msg.content) pb.write(6, msg.bloomFilter) pb.finish() pb -proc decode*(T: type Message, buffer: seq[byte]): ProtobufResult[T] = +proc decode*(T: type SdsMessage, buffer: seq[byte]): ProtobufResult[T] = let pb = initProtoBuffer(buffer) - var msg = Message() + var msg = SdsMessage() if not ?pb.getField(1, msg.messageId): return err(ProtobufError.missingRequiredField("messageId")) @@ -34,14 +32,16 @@ proc decode*(T: type Message, buffer: seq[byte]): ProtobufResult[T] = return err(ProtobufError.missingRequiredField("lamportTimestamp")) msg.lamportTimestamp = int64(timestamp) - # Decode causal history - var causalHistory: seq[string] + var causalHistory: seq[seq[byte]] let histResult = pb.getRepeatedField(3, causalHistory) if histResult.isOk: msg.causalHistory = causalHistory - if not ?pb.getField(4, msg.channelId): - return err(ProtobufError.missingRequiredField("channelId")) + var channelId: seq[byte] + if ?pb.getField(4, channelId): + msg.channelId = some(channelId) + else: + msg.channelId = none[SdsChannelID]() if not ?pb.getField(5, msg.content): return err(ProtobufError.missingRequiredField("content")) @@ -51,63 +51,59 @@ proc decode*(T: type Message, buffer: seq[byte]): ProtobufResult[T] = ok(msg) -proc serializeMessage*(msg: Message): Result[seq[byte], ReliabilityError] = - try: - let pb = encode(msg) - ok(pb.buffer) - except: - err(reSerializationError) +proc serializeMessage*(msg: SdsMessage): Result[seq[byte], ReliabilityError] = + let pb = encode(msg) + ok(pb.buffer) -proc deserializeMessage*(data: seq[byte]): Result[Message, ReliabilityError] = - try: - let msgResult = Message.decode(data) - if msgResult.isOk: - ok(msgResult.get) - else: - err(reSerializationError) - except: - err(reDeserializationError) +proc deserializeMessage*(data: seq[byte]): Result[SdsMessage, ReliabilityError] = + let msg = SdsMessage.decode(data).valueOr: + return err(ReliabilityError.reDeserializationError) + ok(msg) proc serializeBloomFilter*(filter: BloomFilter): Result[seq[byte], ReliabilityError] = - try: - var pb = initProtoBuffer() + var pb = initProtoBuffer() - # Convert intArray to bytes + # Convert intArray to bytes + try: var bytes = newSeq[byte](filter.intArray.len * sizeof(int)) for i, val in filter.intArray: + var leVal: int + littleEndian64(addr leVal, unsafeAddr val) let start = i * sizeof(int) - copyMem(addr bytes[start], unsafeAddr val, sizeof(int)) + copyMem(addr bytes[start], addr leVal, sizeof(int)) pb.write(1, bytes) pb.write(2, uint64(filter.capacity)) pb.write(3, uint64(filter.errorRate * 1_000_000)) pb.write(4, uint64(filter.kHashes)) pb.write(5, uint64(filter.mBits)) - - pb.finish() - ok(pb.buffer) except: - err(reSerializationError) + return err(ReliabilityError.reSerializationError) + + pb.finish() + ok(pb.buffer) proc deserializeBloomFilter*(data: seq[byte]): Result[BloomFilter, ReliabilityError] = if data.len == 0: - return err(reDeserializationError) + return err(ReliabilityError.reDeserializationError) + + let pb = initProtoBuffer(data) + var bytes: seq[byte] + var cap, errRate, kHashes, mBits: uint64 try: - let pb = initProtoBuffer(data) - var bytes: seq[byte] - var cap, errRate, kHashes, mBits: uint64 - if not pb.getField(1, bytes).get() or not pb.getField(2, cap).get() or not pb.getField(3, errRate).get() or not pb.getField(4, kHashes).get() or not pb.getField(5, mBits).get(): - return err(reDeserializationError) + return err(ReliabilityError.reDeserializationError) # Convert bytes back to intArray var intArray = newSeq[int](bytes.len div sizeof(int)) for i in 0 ..< intArray.len: + var leVal: int let start = i * sizeof(int) - copyMem(addr intArray[i], unsafeAddr bytes[start], sizeof(int)) + copyMem(addr leVal, unsafeAddr bytes[start], sizeof(int)) + littleEndian64(addr intArray[i], addr leVal) ok( BloomFilter( @@ -119,4 +115,4 @@ proc deserializeBloomFilter*(data: seq[byte]): Result[BloomFilter, ReliabilityEr ) ) except: - err(reDeserializationError) + return err(ReliabilityError.reDeserializationError) diff --git a/src/reliability.nim b/src/reliability.nim index 1262c7d..0f8282d 100644 --- a/src/reliability.nim +++ b/src/reliability.nim @@ -1,9 +1,9 @@ -import std/[times, locks, tables, sets] -import chronos, results -import ../src/[message, protobuf, reliability_utils, rolling_bloom_filter] +import std/[times, locks, tables, sets, options] +import chronos, results, chronicles +import ./[message, protobuf, reliability_utils, rolling_bloom_filter] proc newReliabilityManager*( - channelId: string, config: ReliabilityConfig = defaultConfig() + channelId: Option[SdsChannelID], config: ReliabilityConfig = defaultConfig() ): Result[ReliabilityManager, ReliabilityError] = ## Creates a new ReliabilityManager with the specified channel ID and configuration. ## @@ -13,94 +13,113 @@ proc newReliabilityManager*( ## ## Returns: ## A Result containing either a new ReliabilityManager instance or an error. - if channelId.len == 0: - return err(reInvalidArgument) + if not channelId.isSome(): + return err(ReliabilityError.reInvalidArgument) try: - let bloomFilter = newRollingBloomFilter( - config.bloomFilterCapacity, config.bloomFilterErrorRate, config.bloomFilterWindow - ) + let bloomFilter = + newRollingBloomFilter(config.bloomFilterCapacity, config.bloomFilterErrorRate) let rm = ReliabilityManager( lamportTimestamp: 0, messageHistory: @[], bloomFilter: bloomFilter, outgoingBuffer: @[], - incomingBuffer: @[], + incomingBuffer: initTable[SdsMessageID, IncomingMessage](), channelId: channelId, config: config, ) initLock(rm.lock) return ok(rm) - except: - return err(reOutOfMemory) + except Exception: + error "Failed to create ReliabilityManager", msg = getCurrentExceptionMsg() + return err(ReliabilityError.reOutOfMemory) -proc reviewAckStatus(rm: ReliabilityManager, msg: Message) = - var i = 0 - while i < rm.outgoingBuffer.len: - var acknowledged = false - let outMsg = rm.outgoingBuffer[i] +proc isAcknowledged*( + msg: UnacknowledgedMessage, + causalHistory: seq[SdsMessageID], + rbf: Option[RollingBloomFilter], +): bool = + if msg.message.messageId in causalHistory: + return true - # Check if message is in causal history - for msgID in msg.causalHistory: - if outMsg.message.messageId == msgID: - acknowledged = true - break + if rbf.isSome(): + return rbf.get().contains(msg.message.messageId) - # Check bloom filter if not already acknowledged - if not acknowledged and msg.bloomFilter.len > 0: - let bfResult = deserializeBloomFilter(msg.bloomFilter) - if bfResult.isOk: - var rbf = RollingBloomFilter( - filter: bfResult.get(), window: rm.bloomFilter.window, messages: @[] + false + +proc reviewAckStatus(rm: ReliabilityManager, msg: SdsMessage) {.gcsafe.} = + # Parse bloom filter + var rbf: Option[RollingBloomFilter] + if msg.bloomFilter.len > 0: + let bfResult = deserializeBloomFilter(msg.bloomFilter) + if bfResult.isOk(): + rbf = some( + RollingBloomFilter( + filter: bfResult.get(), + capacity: bfResult.get().capacity, + minCapacity: ( + bfResult.get().capacity.float * (100 - CapacityFlexPercent).float / 100.0 + ).int, + maxCapacity: ( + bfResult.get().capacity.float * (100 + CapacityFlexPercent).float / 100.0 + ).int, + messages: @[], ) - if rbf.contains(outMsg.message.messageId): - acknowledged = true - else: - logError("Failed to deserialize bloom filter") - - if acknowledged: - if rm.onMessageSent != nil: - rm.onMessageSent(outMsg.message.messageId) - rm.outgoingBuffer.delete(i) + ) else: - inc i + error "Failed to deserialize bloom filter", error = bfResult.error + rbf = none[RollingBloomFilter]() + else: + rbf = none[RollingBloomFilter]() + + # Keep track of indices to delete + var toDelete: seq[int] = @[] + var i = 0 + + while i < rm.outgoingBuffer.len: + let outMsg = rm.outgoingBuffer[i] + if outMsg.isAcknowledged(msg.causalHistory, rbf): + if not rm.onMessageSent.isNil(): + rm.onMessageSent(outMsg.message.messageId) + toDelete.add(i) + inc i + + for i in countdown(toDelete.high, 0): # Delete in reverse order to maintain indices + rm.outgoingBuffer.delete(toDelete[i]) proc wrapOutgoingMessage*( - rm: ReliabilityManager, message: seq[byte], messageId: MessageID + rm: ReliabilityManager, message: seq[byte], messageId: SdsMessageID ): Result[seq[byte], ReliabilityError] = ## Wraps an outgoing message with reliability metadata. ## ## Parameters: ## - message: The content of the message to be sent. + ## - messageId: Unique identifier for the message ## ## Returns: - ## A Result containing either a Message object with reliability metadata or an error. + ## A Result containing either wrapped message bytes or an error. if message.len == 0: - return err(reInvalidArgument) + return err(ReliabilityError.reInvalidArgument) if message.len > MaxMessageSize: - return err(reMessageTooLarge) + return err(ReliabilityError.reMessageTooLarge) withLock rm.lock: try: rm.updateLamportTimestamp(getTime().toUnix) - # Serialize current bloom filter - var bloomBytes: seq[byte] let bfResult = serializeBloomFilter(rm.bloomFilter.filter) if bfResult.isErr: - logError("Failed to serialize bloom filter") - bloomBytes = @[] - else: - bloomBytes = bfResult.get() + error "Failed to serialize bloom filter" + return err(ReliabilityError.reSerializationError) - let msg = Message( + let msg = SdsMessage( messageId: messageId, lamportTimestamp: rm.lamportTimestamp, - causalHistory: rm.getRecentMessageIDs(rm.config.maxCausalHistory), + causalHistory: rm.getRecentSdsMessageIDs(rm.config.maxCausalHistory), channelId: rm.channelId, content: message, - bloomFilter: bloomBytes, + bloomFilter: bfResult.get(), ) # Add to outgoing buffer @@ -113,80 +132,61 @@ proc wrapOutgoingMessage*( rm.addToHistory(msg.messageId) return serializeMessage(msg) - except: - return err(reInternalError) + except Exception: + error "Failed to wrap message", msg = getCurrentExceptionMsg() + return err(ReliabilityError.reSerializationError) -proc processIncomingBuffer(rm: ReliabilityManager) = +proc processIncomingBuffer(rm: ReliabilityManager) {.gcsafe.} = withLock rm.lock: if rm.incomingBuffer.len == 0: return - # Create dependency map - var dependencies = initTable[MessageID, seq[MessageID]]() - var readyToProcess: seq[MessageID] = @[] + var processed = initHashSet[SdsMessageID]() + var readyToProcess = newSeq[SdsMessageID]() - # Build dependency graph and find initially ready messages - for msg in rm.incomingBuffer: - var hasMissingDeps = false - for depId in msg.causalHistory: - if not rm.bloomFilter.contains(depId): - hasMissingDeps = true - if depId notin dependencies: - dependencies[depId] = @[] - dependencies[depId].add(msg.messageId) - - if not hasMissingDeps: - readyToProcess.add(msg.messageId) - - # Process ready messages and their dependents - var newIncomingBuffer: seq[Message] = @[] - var processed = initHashSet[MessageID]() + # Find initially ready messages + for msgId, entry in rm.incomingBuffer: + if entry.missingDeps.len == 0: + readyToProcess.add(msgId) while readyToProcess.len > 0: let msgId = readyToProcess.pop() if msgId in processed: continue - # Process this message - for msg in rm.incomingBuffer: - if msg.messageId == msgId: - rm.addToHistory(msg.messageId) - if rm.onMessageReady != nil: - rm.onMessageReady(msg.messageId) - processed.incl(msgId) + if msgId in rm.incomingBuffer: + rm.addToHistory(msgId) + if not rm.onMessageReady.isNil(): + rm.onMessageReady(msgId) + processed.incl(msgId) - # Add any dependent messages that might now be ready - if msgId in dependencies: - for dependentId in dependencies[msgId]: - readyToProcess.add(dependentId) - break + # Update dependencies for remaining messages + for remainingId, entry in rm.incomingBuffer: + if remainingId notin processed: + if msgId in entry.missingDeps: + rm.incomingBuffer[remainingId].missingDeps.excl(msgId) + if rm.incomingBuffer[remainingId].missingDeps.len == 0: + readyToProcess.add(remainingId) - # Update incomingBuffer with remaining messages - for msg in rm.incomingBuffer: - if msg.messageId notin processed: - newIncomingBuffer.add(msg) - - rm.incomingBuffer = newIncomingBuffer + # Remove processed messages + for msgId in processed: + rm.incomingBuffer.del(msgId) proc unwrapReceivedMessage*( rm: ReliabilityManager, message: seq[byte] -): Result[tuple[message: seq[byte], missingDeps: seq[MessageID]], ReliabilityError] {. - gcsafe -.} = +): Result[tuple[message: seq[byte], missingDeps: seq[SdsMessageID]], ReliabilityError] = ## Unwraps a received message and processes its reliability metadata. ## ## Parameters: - ## - message: The received Message object. + ## - message: The received message bytes ## ## Returns: - ## A Result containing either a tuple with the processed message and missing dependencies, or an error. + ## A Result containing either tuple of (processed message, missing dependencies) or an error. try: - let msgResult = deserializeMessage(message) - if not msgResult.isOk: - return err(msgResult.error) + let msg = deserializeMessage(message).valueOr: + return err(ReliabilityError.reDeserializationError) - let msg = msgResult.get - if rm.bloomFilter.contains(msg.messageId): + if msg.messageId in rm.messageHistory: return ok((msg.content, @[])) rm.bloomFilter.add(msg.messageId) @@ -197,38 +197,42 @@ proc unwrapReceivedMessage*( # Review ACK status for outgoing messages rm.reviewAckStatus(msg) - var missingDeps: seq[MessageID] = @[] - for depId in msg.causalHistory: - if not rm.bloomFilter.contains(depId): - missingDeps.add(depId) + var missingDeps = rm.checkDependencies(msg.causalHistory) if missingDeps.len == 0: # Check if any dependencies are still in incoming buffer var depsInBuffer = false - for bufferedMsg in rm.incomingBuffer: - if bufferedMsg.messageId in msg.causalHistory: + for msgId, entry in rm.incomingBuffer.pairs(): + if msgId in msg.causalHistory: depsInBuffer = true break + if depsInBuffer: - rm.incomingBuffer.add(msg) + rm.incomingBuffer[msg.messageId] = IncomingMessage( + message: msg, + missingDeps: initHashSet[SdsMessageID]() + ) else: # All dependencies met, add to history rm.addToHistory(msg.messageId) rm.processIncomingBuffer() - if rm.onMessageReady != nil: + if not rm.onMessageReady.isNil(): rm.onMessageReady(msg.messageId) else: - # Buffer message and request missing dependencies - rm.incomingBuffer.add(msg) - if rm.onMissingDependencies != nil: + rm.incomingBuffer[msg.messageId] = IncomingMessage( + message: msg, + missingDeps: missingDeps.toHashSet() + ) + if not rm.onMissingDependencies.isNil(): rm.onMissingDependencies(msg.messageId, missingDeps) return ok((msg.content, missingDeps)) - except: - return err(reInternalError) + except Exception: + error "Failed to unwrap message", msg = getCurrentExceptionMsg() + return err(ReliabilityError.reDeserializationError) proc markDependenciesMet*( - rm: ReliabilityManager, messageIds: seq[MessageID] + rm: ReliabilityManager, messageIds: seq[SdsMessageID] ): Result[void, ReliabilityError] = ## Marks the specified message dependencies as met. ## @@ -243,11 +247,17 @@ proc markDependenciesMet*( if not rm.bloomFilter.contains(msgId): rm.bloomFilter.add(msgId) # rm.addToHistory(msgId) -- not needed as this proc usually called when msg in long-term storage of application? - rm.processIncomingBuffer() + # Update any pending messages that depend on this one + for pendingId, entry in rm.incomingBuffer: + if msgId in entry.missingDeps: + rm.incomingBuffer[pendingId].missingDeps.excl(msgId) + + rm.processIncomingBuffer() return ok() - except: - return err(reInternalError) + except Exception: + error "Failed to mark dependencies as met", msg = getCurrentExceptionMsg() + return err(ReliabilityError.reInternalError) proc setCallbacks*( rm: ReliabilityManager, @@ -269,53 +279,52 @@ proc setCallbacks*( rm.onMissingDependencies = onMissingDependencies rm.onPeriodicSync = onPeriodicSync -proc checkUnacknowledgedMessages*(rm: ReliabilityManager) {.raises: [].} = +proc checkUnacknowledgedMessages(rm: ReliabilityManager) {.gcsafe.} = ## Checks and processes unacknowledged messages in the outgoing buffer. withLock rm.lock: let now = getTime() var newOutgoingBuffer: seq[UnacknowledgedMessage] = @[] - try: - for unackMsg in rm.outgoingBuffer: - let elapsed = now - unackMsg.sendTime - if elapsed > rm.config.resendInterval: - # Time to attempt resend - if unackMsg.resendAttempts < rm.config.maxResendAttempts: - var updatedMsg = unackMsg - updatedMsg.resendAttempts += 1 - updatedMsg.sendTime = now - newOutgoingBuffer.add(updatedMsg) - else: - if rm.onMessageSent != nil: - rm.onMessageSent(unackMsg.message.messageId) + for unackMsg in rm.outgoingBuffer: + let elapsed = now - unackMsg.sendTime + if elapsed > rm.config.resendInterval: + # Time to attempt resend + if unackMsg.resendAttempts < rm.config.maxResendAttempts: + var updatedMsg = unackMsg + updatedMsg.resendAttempts += 1 + updatedMsg.sendTime = now + newOutgoingBuffer.add(updatedMsg) else: - newOutgoingBuffer.add(unackMsg) + if not rm.onMessageSent.isNil(): + rm.onMessageSent(unackMsg.message.messageId) + else: + newOutgoingBuffer.add(unackMsg) - rm.outgoingBuffer = newOutgoingBuffer - except Exception as e: - logError("Error in checking unacknowledged messages: " & e.msg) + rm.outgoingBuffer = newOutgoingBuffer -proc periodicBufferSweep(rm: ReliabilityManager) {.async: (raises: [CancelledError]).} = +proc periodicBufferSweep( + rm: ReliabilityManager +) {.async: (raises: [CancelledError]), gcsafe.} = ## Periodically sweeps the buffer to clean up and check unacknowledged messages. while true: - {.gcsafe.}: - try: - rm.checkUnacknowledgedMessages() - rm.cleanBloomFilter() - except Exception as e: - logError("Error in periodic buffer sweep: " & e.msg) + try: + rm.checkUnacknowledgedMessages() + rm.cleanBloomFilter() + except Exception: + error "Error in periodic buffer sweep", msg = getCurrentExceptionMsg() await sleepAsync(chronos.milliseconds(rm.config.bufferSweepInterval.inMilliseconds)) -proc periodicSyncMessage(rm: ReliabilityManager) {.async: (raises: [CancelledError]).} = +proc periodicSyncMessage( + rm: ReliabilityManager +) {.async: (raises: [CancelledError]), gcsafe.} = ## Periodically notifies to send a sync message to maintain connectivity. while true: - {.gcsafe.}: - try: - if rm.onPeriodicSync != nil: - rm.onPeriodicSync() - except Exception as e: - logError("Error in periodic sync: " & e.msg) + try: + if not rm.onPeriodicSync.isNil(): + rm.onPeriodicSync() + except Exception: + error "Error in periodic sync", msg = getCurrentExceptionMsg() await sleepAsync(chronos.seconds(rm.config.syncMessageInterval.inSeconds)) proc startPeriodicTasks*(rm: ReliabilityManager) = @@ -329,19 +338,16 @@ proc resetReliabilityManager*(rm: ReliabilityManager): Result[void, ReliabilityE ## Resets the ReliabilityManager to its initial state. ## ## This procedure clears all buffers and resets the Lamport timestamp. - ## - ## Returns: - ## A Result indicating success or an error if the Bloom filter initialization fails. withLock rm.lock: try: rm.lamportTimestamp = 0 rm.messageHistory.setLen(0) rm.outgoingBuffer.setLen(0) - rm.incomingBuffer.setLen(0) + rm.incomingBuffer.clear() rm.bloomFilter = newRollingBloomFilter( - rm.config.bloomFilterCapacity, rm.config.bloomFilterErrorRate, - rm.config.bloomFilterWindow, + rm.config.bloomFilterCapacity, rm.config.bloomFilterErrorRate ) return ok() - except: - return err(reInternalError) + except Exception: + error "Failed to reset ReliabilityManager", msg = getCurrentExceptionMsg() + return err(ReliabilityError.reInternalError) diff --git a/src/reliability_utils.nim b/src/reliability_utils.nim index 367e965..d9bf316 100644 --- a/src/reliability_utils.nim +++ b/src/reliability_utils.nim @@ -1,49 +1,43 @@ -import std/[times, locks] +import std/[times, locks, options] +import chronicles import ./[rolling_bloom_filter, message] type - MessageReadyCallback* = proc(messageId: MessageID) {.gcsafe.} + MessageReadyCallback* = proc(messageId: SdsMessageID) {.gcsafe.} - MessageSentCallback* = proc(messageId: MessageID) {.gcsafe.} + MessageSentCallback* = proc(messageId: SdsMessageID) {.gcsafe.} MissingDependenciesCallback* = - proc(messageId: MessageID, missingDeps: seq[MessageID]) {.gcsafe.} + proc(messageId: SdsMessageID, missingDeps: seq[SdsMessageID]) {.gcsafe.} PeriodicSyncCallback* = proc() {.gcsafe, raises: [].} - AppCallbacks* = ref object - messageReadyCb*: MessageReadyCallback - messageSentCb*: MessageSentCallback - missingDependenciesCb*: MissingDependenciesCallback - periodicSyncCb*: PeriodicSyncCallback - ReliabilityConfig* = object bloomFilterCapacity*: int bloomFilterErrorRate*: float - bloomFilterWindow*: times.Duration maxMessageHistory*: int maxCausalHistory*: int - resendInterval*: times.Duration + resendInterval*: Duration maxResendAttempts*: int - syncMessageInterval*: times.Duration - bufferSweepInterval*: times.Duration + syncMessageInterval*: Duration + bufferSweepInterval*: Duration ReliabilityManager* = ref object lamportTimestamp*: int64 - messageHistory*: seq[MessageID] + messageHistory*: seq[SdsMessageID] bloomFilter*: RollingBloomFilter outgoingBuffer*: seq[UnacknowledgedMessage] - incomingBuffer*: seq[Message] - channelId*: string + incomingBuffer*: Table[SdsMessageID, IncomingMessage] + channelId*: Option[SdsChannelID] config*: ReliabilityConfig lock*: Lock - onMessageReady*: proc(messageId: MessageID) {.gcsafe.} - onMessageSent*: proc(messageId: MessageID) {.gcsafe.} + onMessageReady*: proc(messageId: SdsMessageID) {.gcsafe.} + onMessageSent*: proc(messageId: SdsMessageID) {.gcsafe.} onMissingDependencies*: - proc(messageId: MessageID, missingDeps: seq[MessageID]) {.gcsafe.} - onPeriodicSync*: proc() + proc(messageId: SdsMessageID, missingDeps: seq[SdsMessageID]) {.gcsafe.} + onPeriodicSync*: PeriodicSyncCallback - ReliabilityError* = enum + ReliabilityError* {.pure.} = enum reInvalidArgument reOutOfMemory reInternalError @@ -59,7 +53,6 @@ proc defaultConfig*(): ReliabilityConfig = ReliabilityConfig( bloomFilterCapacity: DefaultBloomFilterCapacity, bloomFilterErrorRate: DefaultBloomFilterErrorRate, - bloomFilterWindow: DefaultBloomFilterWindow, maxMessageHistory: DefaultMaxMessageHistory, maxCausalHistory: DefaultMaxCausalHistory, resendInterval: DefaultResendInterval, @@ -69,23 +62,23 @@ proc defaultConfig*(): ReliabilityConfig = ) proc cleanup*(rm: ReliabilityManager) {.raises: [].} = - if not rm.isNil: - {.gcsafe.}: - try: + if not rm.isNil(): + try: + withLock rm.lock: rm.outgoingBuffer.setLen(0) - rm.incomingBuffer.setLen(0) + rm.incomingBuffer.clear() rm.messageHistory.setLen(0) - except Exception as e: - logError("Error during cleanup: " & e.msg) + except Exception: + error "Error during cleanup", error = getCurrentExceptionMsg() proc cleanBloomFilter*(rm: ReliabilityManager) {.gcsafe, raises: [].} = withLock rm.lock: try: rm.bloomFilter.clean() - except Exception as e: - logError("Failed to clean ReliabilityManager bloom filter: " & e.msg) + except Exception: + error "Failed to clean bloom filter", error = getCurrentExceptionMsg() -proc addToHistory*(rm: ReliabilityManager, msgId: MessageID) {.gcsafe, raises: [].} = +proc addToHistory*(rm: ReliabilityManager, msgId: SdsMessageID) {.gcsafe, raises: [].} = rm.messageHistory.add(msgId) if rm.messageHistory.len > rm.config.maxMessageHistory: rm.messageHistory.delete(0) @@ -95,10 +88,19 @@ proc updateLamportTimestamp*( ) {.gcsafe, raises: [].} = rm.lamportTimestamp = max(msgTs, rm.lamportTimestamp) + 1 -proc getRecentMessageIDs*(rm: ReliabilityManager, n: int): seq[MessageID] = +proc getRecentSdsMessageIDs*(rm: ReliabilityManager, n: int): seq[SdsMessageID] = result = rm.messageHistory[max(0, rm.messageHistory.len - n) .. ^1] -proc getMessageHistory*(rm: ReliabilityManager): seq[MessageID] = +proc checkDependencies*( + rm: ReliabilityManager, deps: seq[SdsMessageID] +): seq[SdsMessageID] = + var missingDeps: seq[SdsMessageID] = @[] + for depId in deps: + if depId notin rm.messageHistory: + missingDeps.add(depId) + return missingDeps + +proc getMessageHistory*(rm: ReliabilityManager): seq[SdsMessageID] = withLock rm.lock: result = rm.messageHistory @@ -106,6 +108,8 @@ proc getOutgoingBuffer*(rm: ReliabilityManager): seq[UnacknowledgedMessage] = withLock rm.lock: result = rm.outgoingBuffer -proc getIncomingBuffer*(rm: ReliabilityManager): seq[Message] = +proc getIncomingBuffer*( + rm: ReliabilityManager +): Table[SdsMessageID, message.IncomingMessage] = withLock rm.lock: result = rm.incomingBuffer diff --git a/src/rolling_bloom_filter.nim b/src/rolling_bloom_filter.nim index c0282be..190ab8a 100644 --- a/src/rolling_bloom_filter.nim +++ b/src/rolling_bloom_filter.nim @@ -1,64 +1,113 @@ -import std/times import chronos import chronicles import ./[bloom, message] type RollingBloomFilter* = object filter*: BloomFilter - window*: times.Duration - messages*: seq[TimestampedMessageID] + capacity*: int + minCapacity*: int + maxCapacity*: int + messages*: seq[SdsMessageID] const DefaultBloomFilterCapacity* = 10000 DefaultBloomFilterErrorRate* = 0.001 - DefaultBloomFilterWindow* = initDuration(hours = 1) - -proc logError*(msg: string) = - error "ReliabilityError", message = msg - -proc logInfo*(msg: string) = - info "ReliabilityInfo", message = msg + CapacityFlexPercent* = 20 proc newRollingBloomFilter*( - capacity: int, errorRate: float, window: times.Duration + capacity: int = DefaultBloomFilterCapacity, + errorRate: float = DefaultBloomFilterErrorRate, ): RollingBloomFilter {.gcsafe.} = - try: - var filterResult: Result[BloomFilter, string] - {.gcsafe.}: - filterResult = initializeBloomFilter(capacity, errorRate) + let targetCapacity = if capacity <= 0: DefaultBloomFilterCapacity else: capacity + let targetError = + if errorRate <= 0.0 or errorRate >= 1.0: DefaultBloomFilterErrorRate else: errorRate + + let filterResult = initializeBloomFilter(targetCapacity, targetError) + if filterResult.isErr: + error "Failed to initialize bloom filter", error = filterResult.error + # Try with default values if custom values failed + if capacity != DefaultBloomFilterCapacity or errorRate != DefaultBloomFilterErrorRate: + let defaultResult = + initializeBloomFilter(DefaultBloomFilterCapacity, DefaultBloomFilterErrorRate) + if defaultResult.isErr: + error "Failed to initialize bloom filter with default parameters", + error = defaultResult.error + + let minCapacity = ( + DefaultBloomFilterCapacity.float * (100 - CapacityFlexPercent).float / 100.0 + ).int + let maxCapacity = ( + DefaultBloomFilterCapacity.float * (100 + CapacityFlexPercent).float / 100.0 + ).int + + info "Successfully initialized bloom filter with default parameters", + capacity = DefaultBloomFilterCapacity, + minCapacity = minCapacity, + maxCapacity = maxCapacity - if filterResult.isOk: - logInfo("Successfully initialized bloom filter") return RollingBloomFilter( - filter: filterResult.get(), # Extract the BloomFilter from Result - window: window, + filter: defaultResult.get(), + capacity: DefaultBloomFilterCapacity, + minCapacity: minCapacity, + maxCapacity: maxCapacity, messages: @[], ) else: - logError("Failed to initialize bloom filter: " & filterResult.error) - # Fall through to default case below - except: - logError("Failed to initialize bloom filter") + error "Could not create bloom filter", error = filterResult.error - # Default fallback case - let defaultResult = - initializeBloomFilter(DefaultBloomFilterCapacity, DefaultBloomFilterErrorRate) - if defaultResult.isOk: - return - RollingBloomFilter(filter: defaultResult.get(), window: window, messages: @[]) - else: - # If even default initialization fails, raise an exception - logError("Failed to initialize bloom filter with default parameters") + let minCapacity = + (targetCapacity.float * (100 - CapacityFlexPercent).float / 100.0).int + let maxCapacity = + (targetCapacity.float * (100 + CapacityFlexPercent).float / 100.0).int -proc add*(rbf: var RollingBloomFilter, messageId: MessageID) {.gcsafe.} = + info "Successfully initialized bloom filter", + capacity = targetCapacity, minCapacity = minCapacity, maxCapacity = maxCapacity + + return RollingBloomFilter( + filter: filterResult.get(), + capacity: targetCapacity, + minCapacity: minCapacity, + maxCapacity: maxCapacity, + messages: @[], + ) + +proc clean*(rbf: var RollingBloomFilter) {.gcsafe.} = + try: + if rbf.messages.len <= rbf.maxCapacity: + return # Don't clean unless we exceed max capacity + + # Initialize new filter + var newFilter = initializeBloomFilter(rbf.maxCapacity, rbf.filter.errorRate).valueOr: + error "Failed to create new bloom filter", error = $error + return + + # Keep most recent messages up to minCapacity + let keepCount = rbf.minCapacity + let startIdx = max(0, rbf.messages.len - keepCount) + var newMessages: seq[SdsMessageID] = @[] + + for i in startIdx ..< rbf.messages.len: + newMessages.add(rbf.messages[i]) + newFilter.insert(cast[string](rbf.messages[i])) + + rbf.messages = newMessages + rbf.filter = newFilter + except Exception: + error "Failed to clean bloom filter", error = getCurrentExceptionMsg() + +proc add*(rbf: var RollingBloomFilter, messageId: SdsMessageID) {.gcsafe.} = ## Adds a message ID to the rolling bloom filter. ## ## Parameters: ## - messageId: The ID of the message to add. - rbf.filter.insert(messageId) - rbf.messages.add(TimestampedMessageID(id: messageId, timestamp: getTime())) + rbf.filter.insert(cast[string](messageId)) + rbf.messages.add(messageId) -proc contains*(rbf: RollingBloomFilter, messageId: MessageID): bool {.gcsafe.} = + # Clean if we exceed max capacity + if rbf.messages.len > rbf.maxCapacity: + rbf.clean() + +proc contains*(rbf: RollingBloomFilter, messageId: SdsMessageID): bool = ## Checks if a message ID is in the rolling bloom filter. ## ## Parameters: @@ -66,29 +115,4 @@ proc contains*(rbf: RollingBloomFilter, messageId: MessageID): bool {.gcsafe.} = ## ## Returns: ## True if the message ID is probably in the filter, false otherwise. - rbf.filter.lookup(messageId) - -proc clean*(rbf: var RollingBloomFilter) {.gcsafe.} = - try: - let now = getTime() - let cutoff = now - rbf.window - var newMessages: seq[TimestampedMessageID] = @[] - - # Initialize new filter - let newFilterResult = - initializeBloomFilter(rbf.filter.capacity, rbf.filter.errorRate) - if newFilterResult.isErr: - logError("Failed to create new bloom filter: " & newFilterResult.error) - return - - var newFilter = newFilterResult.get() - - for msg in rbf.messages: - if msg.timestamp > cutoff: - newMessages.add(msg) - newFilter.insert(msg.id) - - rbf.messages = newMessages - rbf.filter = newFilter - except Exception as e: - logError("Failed to clean bloom filter: " & e.msg) + rbf.filter.lookup(cast[string](messageId))