diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..82f3c3e --- /dev/null +++ b/.gitignore @@ -0,0 +1,3 @@ + +.DS_Store +tests/test_reliability diff --git a/nim-bloom/src/bloom.nim b/nim-bloom/src/bloom.nim index 333ea7a..d5f8f71 100644 --- a/nim-bloom/src/bloom.nim +++ b/nim-bloom/src/bloom.nim @@ -7,7 +7,7 @@ import private/probabilities {.compile: "murmur3.c".} type - BloomFilterError = object of CatchableError + BloomFilterError* = object of CatchableError MurmurHashes = array[0..1, int] BloomFilter* = object capacity*: int diff --git a/reliability.nimble b/reliability.nimble index 7f51edb..6ba5913 100644 --- a/reliability.nimble +++ b/reliability.nimble @@ -8,6 +8,7 @@ srcDir = "src" # Dependencies requires "nim >= 2.0.8" requires "chronicles" +requires "libp2p" # Tasks task test, "Run the test suite": diff --git a/src/common.nim b/src/common.nim index ca776d5..9ebac22 100644 --- a/src/common.nim +++ b/src/common.nim @@ -1,16 +1,15 @@ -import std/[times, json, locks] +import std/[times, locks] import "../nim-bloom/src/bloom" type MessageID* = string Message* = object - senderId*: string messageId*: MessageID lamportTimestamp*: int64 causalHistory*: seq[MessageID] channelId*: string - content*: string + content*: seq[byte] UnacknowledgedMessage* = object message*: Message @@ -21,18 +20,20 @@ type id*: MessageID timestamp*: Time + PeriodicSyncCallback* = proc() {.gcsafe, raises: [].} + RollingBloomFilter* = object filter*: BloomFilter - window*: Duration + window*: times.Duration messages*: seq[TimestampedMessageID] ReliabilityConfig* = object bloomFilterCapacity*: int bloomFilterErrorRate*: float - bloomFilterWindow*: Duration + bloomFilterWindow*: times.Duration maxMessageHistory*: int maxCausalHistory*: int - resendInterval*: Duration + resendInterval*: times.Duration maxResendAttempts*: int ReliabilityManager* = ref object @@ -44,26 +45,19 @@ type channelId*: string config*: ReliabilityConfig lock*: Lock - onMessageReady*: proc(messageId: MessageID) - onMessageSent*: proc(messageId: MessageID) - onMissingDependencies*: proc(messageId: MessageID, missingDeps: seq[MessageID]) + onMessageReady*: proc(messageId: MessageID) {.gcsafe.} + onMessageSent*: proc(messageId: MessageID) {.gcsafe.} + onMissingDependencies*: proc(messageId: MessageID, missingDeps: seq[MessageID]) {.gcsafe.} + onPeriodicSync*: PeriodicSyncCallback ReliabilityError* = enum - reSuccess, - reInvalidArgument, - reOutOfMemory, - reInternalError, - reSerializationError, - reDeserializationError, + reInvalidArgument + reOutOfMemory + reInternalError + reSerializationError + reDeserializationError reMessageTooLarge - Result*[T] = object - case isOk*: bool - of true: - value*: T - of false: - error*: ReliabilityError - const DefaultBloomFilterCapacity* = 10000 DefaultBloomFilterErrorRate* = 0.001 @@ -72,10 +66,4 @@ const DefaultMaxCausalHistory* = 10 DefaultResendInterval* = initDuration(seconds = 30) DefaultMaxResendAttempts* = 5 - MaxMessageSize* = 1024 * 1024 # 1 MB - -proc ok*[T](value: T): Result[T] = - Result[T](isOk: true, value: value) - -proc err*[T](error: ReliabilityError): Result[T] = - Result[T](isOk: false, error: error) \ No newline at end of file + MaxMessageSize* = 1024 * 1024 # 1 MB \ No newline at end of file diff --git a/src/protobuf.nim b/src/protobuf.nim new file mode 100644 index 0000000..794cbf4 --- /dev/null +++ b/src/protobuf.nim @@ -0,0 +1,58 @@ +import ./protobufutil +import ./common +import libp2p/protobuf/minprotobuf +import std/options + +proc encode*(msg: Message): ProtoBuffer = + var pb = initProtoBuffer() + + pb.write(1, msg.messageId) + pb.write(2, uint64(msg.lamportTimestamp)) + for hist in msg.causalHistory: + pb.write(3, hist) + pb.write(4, msg.channelId) + pb.write(5, msg.content) + pb.finish() + + pb + +proc decode*(T: type Message, buffer: seq[byte]): ProtobufResult[T] = + let pb = initProtoBuffer(buffer) + var msg = Message() + + if not ?pb.getField(1, msg.messageId): + return err(ProtobufError.missingRequiredField("messageId")) + + var timestamp: uint64 + if not ?pb.getField(2, timestamp): + return err(ProtobufError.missingRequiredField("lamportTimestamp")) + msg.lamportTimestamp = int64(timestamp) + + var hist: string + while ?pb.getField(3, hist): + msg.causalHistory.add(hist) + + if not ?pb.getField(4, msg.channelId): + return err(ProtobufError.missingRequiredField("channelId")) + + if not ?pb.getField(5, msg.content): + return err(ProtobufError.missingRequiredField("content")) + + ok(msg) + +proc serializeMessage*(msg: Message): Result[seq[byte], ReliabilityError] = + try: + let pb = encode(msg) + ok(pb.buffer) + except: + err(reSerializationError) + +proc deserializeMessage*(data: seq[byte]): Result[Message, ReliabilityError] = + try: + let msgResult = Message.decode(data) + if msgResult.isOk: + ok(msgResult.get) + else: + err(reSerializationError) + except: + err(reDeserializationError) \ No newline at end of file diff --git a/src/protobufutil.nim b/src/protobufutil.nim new file mode 100644 index 0000000..15b3e33 --- /dev/null +++ b/src/protobufutil.nim @@ -0,0 +1,36 @@ +# adapted from https://github.com/waku-org/nwaku/blob/master/waku/common/protobuf.nim + +{.push raises: [].} + +import libp2p/protobuf/minprotobuf +import libp2p/varint + +export minprotobuf, varint + +type + ProtobufErrorKind* {.pure.} = enum + DecodeFailure + MissingRequiredField + InvalidLengthField + + ProtobufError* = object + case kind*: ProtobufErrorKind + of DecodeFailure: + error*: minprotobuf.ProtoError + of MissingRequiredField, InvalidLengthField: + field*: string + + ProtobufResult*[T] = Result[T, ProtobufError] + +converter toProtobufError*(err: minprotobuf.ProtoError): ProtobufError = + case err + of minprotobuf.ProtoError.RequiredFieldMissing: + ProtobufError(kind: ProtobufErrorKind.MissingRequiredField, field: "unknown") + else: + ProtobufError(kind: ProtobufErrorKind.DecodeFailure, error: err) + +proc missingRequiredField*(T: type ProtobufError, field: string): T = + ProtobufError(kind: ProtobufErrorKind.MissingRequiredField, field: field) + +proc invalidLengthField*(T: type ProtobufError, field: string): T = + ProtobufError(kind: ProtobufErrorKind.InvalidLengthField, field: field) \ No newline at end of file diff --git a/src/reliability.nim b/src/reliability.nim index 521079e..fea7048 100644 --- a/src/reliability.nim +++ b/src/reliability.nim @@ -1,4 +1,8 @@ -import ./common, ./utils +import std/[times, locks] +import chronos, results +import ./common +import ./utils +import ./protobuf proc defaultConfig*(): ReliabilityConfig = ## Creates a default configuration for the ReliabilityManager. @@ -15,7 +19,7 @@ proc defaultConfig*(): ReliabilityConfig = maxResendAttempts: DefaultMaxResendAttempts ) -proc newReliabilityManager*(channelId: string, config: ReliabilityConfig = defaultConfig()): Result[ReliabilityManager] = +proc newReliabilityManager*(channelId: string, config: ReliabilityConfig = defaultConfig()): Result[ReliabilityManager, ReliabilityError] = ## Creates a new ReliabilityManager with the specified channel ID and configuration. ## ## Parameters: @@ -25,17 +29,19 @@ proc newReliabilityManager*(channelId: string, config: ReliabilityConfig = defau ## Returns: ## A Result containing either a new ReliabilityManager instance or an error. if channelId.len == 0: - return err[ReliabilityManager](reInvalidArgument) + return err(reInvalidArgument) try: - let bloomFilterResult = newRollingBloomFilter(config.bloomFilterCapacity, config.bloomFilterErrorRate, config.bloomFilterWindow) - if bloomFilterResult.isErr: - return err[ReliabilityManager](bloomFilterResult.error) - + let bloomFilter = newRollingBloomFilter( + config.bloomFilterCapacity, + config.bloomFilterErrorRate, + config.bloomFilterWindow + ) + let rm = ReliabilityManager( lamportTimestamp: 0, messageHistory: @[], - bloomFilter: bloomFilterResult.value, + bloomFilter: bloomFilter, outgoingBuffer: @[], incomingBuffer: @[], channelId: channelId, @@ -44,9 +50,9 @@ proc newReliabilityManager*(channelId: string, config: ReliabilityConfig = defau initLock(rm.lock) return ok(rm) except: - return err[ReliabilityManager](reOutOfMemory) + return err(reOutOfMemory) -proc wrapOutgoingMessage*(rm: ReliabilityManager, message: seq[byte]): Result[seq[byte]] = +proc wrapOutgoingMessage*(rm: ReliabilityManager, message: seq[byte], messageId: MessageID): Result[seq[byte], ReliabilityError] = ## Wraps an outgoing message with reliability metadata. ## ## Parameters: @@ -55,15 +61,14 @@ proc wrapOutgoingMessage*(rm: ReliabilityManager, message: seq[byte]): Result[se ## Returns: ## A Result containing either a Message object with reliability metadata or an error. if message.len == 0: - return err[Message](reInvalidArgument) + return err(reInvalidArgument) if message.len > MaxMessageSize: - return err[Message](reMessageTooLarge) + return err(reMessageTooLarge) withLock rm.lock: try: let msg = Message( - senderId: "TODO_SENDER_ID", - messageId: generateUniqueID(), + messageId: messageId, lamportTimestamp: rm.lamportTimestamp, causalHistory: rm.getRecentMessageIDs(rm.config.maxCausalHistory), channelId: rm.channelId, @@ -71,11 +76,11 @@ proc wrapOutgoingMessage*(rm: ReliabilityManager, message: seq[byte]): Result[se ) rm.updateLamportTimestamp(getTime().toUnix) rm.outgoingBuffer.add(UnacknowledgedMessage(message: msg, sendTime: getTime(), resendAttempts: 0)) - return ok(msg) + return serializeMessage(msg) except: - return err[Message](reInternalError) + return err(reInternalError) -proc unwrapReceivedMessage*(rm: ReliabilityManager, message: seq[byte]): Result[tuple[message: seq[byte], missingDeps: seq[MessageID]]] = +proc unwrapReceivedMessage*(rm: ReliabilityManager, message: seq[byte]): Result[tuple[message: seq[byte], missingDeps: seq[MessageID]], ReliabilityError] = ## Unwraps a received message and processes its reliability metadata. ## ## Parameters: @@ -85,33 +90,38 @@ proc unwrapReceivedMessage*(rm: ReliabilityManager, message: seq[byte]): Result[ ## A Result containing either a tuple with the processed message and missing dependencies, or an error. withLock rm.lock: try: - if rm.bloomFilter.contains(message.messageId): - return ok((message, @[])) + let msgResult = deserializeMessage(message) + if not msgResult.isOk: + return err(msgResult.error) + + let msg = msgResult.get + if rm.bloomFilter.contains(msg.messageId): + return ok((msg.content, @[])) - rm.bloomFilter.add(message.messageId) - rm.updateLamportTimestamp(message.lamportTimestamp) + rm.bloomFilter.add(msg.messageId) + rm.updateLamportTimestamp(msg.lamportTimestamp) var missingDeps: seq[MessageID] = @[] - for depId in message.causalHistory: + for depId in msg.causalHistory: if not rm.bloomFilter.contains(depId): missingDeps.add(depId) if missingDeps.len == 0: - rm.messageHistory.add(message.messageId) + rm.messageHistory.add(msg.messageId) if rm.messageHistory.len > rm.config.maxMessageHistory: rm.messageHistory.delete(0) if rm.onMessageReady != nil: - rm.onMessageReady(message.messageId) + rm.onMessageReady(msg.messageId) else: - rm.incomingBuffer.add(message) + rm.incomingBuffer.add(msg) if rm.onMissingDependencies != nil: - rm.onMissingDependencies(message.messageId, missingDeps) + rm.onMissingDependencies(msg.messageId, missingDeps) - return ok((message, missingDeps)) + return ok((msg.content, missingDeps)) except: - return err[(Message, seq[MessageID])](reInternalError) + return err(reInternalError) -proc markDependenciesMet*(rm: ReliabilityManager, messageIds: seq[MessageID]): Result[void] = +proc markDependenciesMet*(rm: ReliabilityManager, messageIds: seq[MessageID]): Result[void, ReliabilityError] = ## Marks the specified message dependencies as met. ## ## Parameters: @@ -122,9 +132,21 @@ proc markDependenciesMet*(rm: ReliabilityManager, messageIds: seq[MessageID]): R withLock rm.lock: try: var processedMessages: seq[Message] = @[] - rm.incomingBuffer = rm.incomingBuffer.filterIt( - not messageIds.allIt(it in it.causalHistory or rm.bloomFilter.contains(it)) - ) + var newIncomingBuffer: seq[Message] = @[] + + for msg in rm.incomingBuffer: + var allDependenciesMet = true + for depId in msg.causalHistory: + if depId notin messageIds and not rm.bloomFilter.contains(depId): + allDependenciesMet = false + break + + if allDependenciesMet: + processedMessages.add(msg) + else: + newIncomingBuffer.add(msg) + + rm.incomingBuffer = newIncomingBuffer for msg in processedMessages: rm.messageHistory.add(msg.messageId) @@ -135,72 +157,80 @@ proc markDependenciesMet*(rm: ReliabilityManager, messageIds: seq[MessageID]): R return ok() except: - return err[void](reInternalError) + return err(reInternalError) proc setCallbacks*(rm: ReliabilityManager, - onMessageReady: proc(messageId: MessageID), - onMessageSent: proc(messageId: MessageID), - onMissingDependencies: proc(messageId: MessageID, missingDeps: seq[MessageID])) = + onMessageReady: proc(messageId: MessageID) {.gcsafe.}, + onMessageSent: proc(messageId: MessageID) {.gcsafe.}, + onMissingDependencies: proc(messageId: MessageID, missingDeps: seq[MessageID]) {.gcsafe.}, + onPeriodicSync: PeriodicSyncCallback = nil) = ## Sets the callback functions for various events in the ReliabilityManager. ## ## Parameters: ## - onMessageReady: Callback function called when a message is ready to be processed. ## - onMessageSent: Callback function called when a message is confirmed as sent. ## - onMissingDependencies: Callback function called when a message has missing dependencies. + ## - onPeriodicSync: Callback function called to notify about periodic sync withLock rm.lock: rm.onMessageReady = onMessageReady rm.onMessageSent = onMessageSent rm.onMissingDependencies = onMissingDependencies + rm.onPeriodicSync = onPeriodicSync -proc checkUnacknowledgedMessages*(rm: ReliabilityManager) = +proc checkUnacknowledgedMessages*(rm: ReliabilityManager) {.raises: [].} = ## Checks and processes unacknowledged messages in the outgoing buffer. withLock rm.lock: let now = getTime() var newOutgoingBuffer: seq[UnacknowledgedMessage] = @[] - for msg in rm.outgoingBuffer: - if (now - msg.sendTime) < rm.config.resendInterval: - newOutgoingBuffer.add(msg) - elif msg.resendAttempts < rm.config.maxResendAttempts: - # Resend the message - msg.resendAttempts += 1 - msg.sendTime = now - newOutgoingBuffer.add(msg) - # Here you would actually resend the message - elif rm.onMessageSent != nil: - rm.onMessageSent(msg.message.messageId) - rm.outgoingBuffer = newOutgoingBuffer + + try: + for msg in rm.outgoingBuffer: + if (now - msg.sendTime) < rm.config.resendInterval: + newOutgoingBuffer.add(msg) + elif msg.resendAttempts < rm.config.maxResendAttempts: + var updatedMsg = msg + updatedMsg.resendAttempts += 1 + updatedMsg.sendTime = now + newOutgoingBuffer.add(updatedMsg) + elif rm.onMessageSent != nil: + rm.onMessageSent(msg.message.messageId) + + rm.outgoingBuffer = newOutgoingBuffer + except: + discard -proc periodicBufferSweep(rm: ReliabilityManager) {.async.} = - ## Periodically sweeps the buffer to clean up and resend messages. +proc periodicBufferSweep(rm: ReliabilityManager) {.async: (raises: [CancelledError]).} = + ## Periodically sweeps the buffer to clean up and check unacknowledged messages. ## ## This is an internal function and should not be called directly. while true: - rm.checkUnacknowledgedMessages() - rm.cleanBloomFilter() - await sleepAsync(5000) # Sleep for 5 seconds + {.gcsafe.}: + try: + rm.checkUnacknowledgedMessages() + rm.cleanBloomFilter() + except Exception as e: + logError("Error in periodic buffer sweep: " & e.msg) + await sleepAsync(chronos.seconds(5)) -proc periodicSyncMessage(rm: ReliabilityManager) {.async.} = - ## Periodically sends a sync message to maintain connectivity. - ## - ## This is an internal function and should not be called directly. +proc periodicSyncMessage(rm: ReliabilityManager) {.async: (raises: [CancelledError]).} = + ## Periodically notifies to send a sync message to maintain connectivity. while true: - discard rm.wrapOutgoingMessage("") # Empty content for sync messages - await sleepAsync(30000) # Sleep for 30 seconds + {.gcsafe.}: + try: + if rm.onPeriodicSync != nil: + rm.onPeriodicSync() + except Exception as e: + logError("Error in periodic sync: " & e.msg) + await sleepAsync(chronos.seconds(30)) proc startPeriodicTasks*(rm: ReliabilityManager) = ## Starts the periodic tasks for buffer sweeping and sync message sending. ## ## This procedure should be called after creating a ReliabilityManager to enable automatic maintenance. - asyncCheck rm.periodicBufferSweep() - asyncCheck rm.periodicSyncMessage() + asyncSpawn rm.periodicBufferSweep() + asyncSpawn rm.periodicSyncMessage() -# # To demonstrate how to use the ReliabilityManager -# proc processMessage*(rm: ReliabilityManager, message: string): seq[MessageID] = -# let wrappedMsg = checkAndLogError(rm.wrapOutgoingMessage(message), "Failed to wrap message") -# let (_, missingDeps) = checkAndLogError(rm.unwrapReceivedMessage(wrappedMsg), "Failed to unwrap message") -# return missingDeps - -proc resetReliabilityManager*(rm: ReliabilityManager): Result[void] = +proc resetReliabilityManager*(rm: ReliabilityManager): Result[void, ReliabilityError] = ## Resets the ReliabilityManager to its initial state. ## ## This procedure clears all buffers and resets the Lamport timestamp. @@ -208,49 +238,26 @@ proc resetReliabilityManager*(rm: ReliabilityManager): Result[void] = ## Returns: ## A Result indicating success or an error if the Bloom filter initialization fails. withLock rm.lock: - let bloomFilterResult = newRollingBloomFilter(rm.config.bloomFilterCapacity, rm.config.bloomFilterErrorRate, rm.config.bloomFilterWindow) - if bloomFilterResult.isErr: - return err[void](bloomFilterResult.error) + try: + rm.lamportTimestamp = 0 + rm.messageHistory.setLen(0) + rm.outgoingBuffer.setLen(0) + rm.incomingBuffer.setLen(0) + rm.bloomFilter = newRollingBloomFilter( + rm.config.bloomFilterCapacity, + rm.config.bloomFilterErrorRate, + rm.config.bloomFilterWindow + ) + return ok() + except: + return err(reInternalError) - rm.lamportTimestamp = 0 - rm.messageHistory.setLen(0) - rm.outgoingBuffer.setLen(0) - rm.incomingBuffer.setLen(0) - rm.bloomFilter = bloomFilterResult.value - return ok() - -proc `=destroy`(rm: var ReliabilityManager) = - ## Destructor for ReliabilityManager. Ensures proper cleanup of resources. - deinitLock(rm.lock) - -when isMainModule: - # Example usage and basic tests - let config = defaultConfig() - let rmResult = newReliabilityManager("testChannel", config) - if rmResult.isOk: - let rm = rmResult.value - rm.setCallbacks( - proc(messageId: MessageID) = echo "Message ready: ", messageId, - proc(messageId: MessageID) = echo "Message sent: ", messageId, - proc(messageId: MessageID, missingDeps: seq[MessageID]) = echo "Missing dependencies for ", messageId, ": ", missingDeps - ) - - let msgResult = rm.wrapOutgoingMessage("Hello, World!") - if msgResult.isOk: - let msg = msgResult.value - echo "Wrapped message: ", msg - - let unwrapResult = rm.unwrapReceivedMessage(msg) - if unwrapResult.isOk: - let (unwrappedMsg, missingDeps) = unwrapResult.value - echo "Unwrapped message: ", unwrappedMsg - echo "Missing dependencies: ", missingDeps - else: - echo "Error unwrapping message: ", unwrapResult.error - else: - echo "Error wrapping message: ", msgResult.error - - rm.startPeriodicTasks() - # In a real application, you'd keep the program running to allow periodic tasks to execute - else: - echo "Error creating ReliabilityManager: ", rmResult.error \ No newline at end of file +proc cleanup*(rm: ReliabilityManager) {.raises: [].} = + if not rm.isNil: + {.gcsafe.}: + try: + rm.outgoingBuffer.setLen(0) + rm.incomingBuffer.setLen(0) + rm.messageHistory.setLen(0) + except Exception as e: + logError("Error during cleanup: " & e.msg) \ No newline at end of file diff --git a/src/utils.nim b/src/utils.nim index f4b1f40..693d918 100644 --- a/src/utils.nim +++ b/src/utils.nim @@ -1,20 +1,37 @@ -import std/[times, hashes, random, sequtils, algorithm, json, options, locks, asyncdispatch] -import chronicles +import std/[times, locks] +import chronos, chronicles import "../nim-bloom/src/bloom" import ./common -proc newRollingBloomFilter*(capacity: int, errorRate: float, window: Duration): Result[RollingBloomFilter] = +proc logError*(msg: string) = + error "ReliabilityError", message = msg + +proc logInfo*(msg: string) = + info "ReliabilityInfo", message = msg + +proc newRollingBloomFilter*(capacity: int, errorRate: float, window: times.Duration): RollingBloomFilter {.gcsafe.} = try: - let filter = initializeBloomFilter(capacity, errorRate) - return ok(RollingBloomFilter( + var filter: BloomFilter + {.gcsafe.}: + filter = initializeBloomFilter(capacity, errorRate) + logInfo("Successfully initialized bloom filter") + RollingBloomFilter( filter: filter, window: window, messages: @[] - )) + ) except: - return err[RollingBloomFilter](reInternalError) + logError("Failed to initialize bloom filter") + var filter: BloomFilter + {.gcsafe.}: + filter = initializeBloomFilter(DefaultBloomFilterCapacity, DefaultBloomFilterErrorRate) + RollingBloomFilter( + filter: filter, + window: window, + messages: @[] + ) -proc add*(rbf: var RollingBloomFilter, messageId: MessageID) = +proc add*(rbf: var RollingBloomFilter, messageId: MessageID) {.gcsafe.} = ## Adds a message ID to the rolling bloom filter. ## ## Parameters: @@ -22,7 +39,7 @@ proc add*(rbf: var RollingBloomFilter, messageId: MessageID) = rbf.filter.insert(messageId) rbf.messages.add(TimestampedMessageID(id: messageId, timestamp: getTime())) -proc contains*(rbf: RollingBloomFilter, messageId: MessageID): bool = +proc contains*(rbf: RollingBloomFilter, messageId: MessageID): bool {.gcsafe.} = ## Checks if a message ID is in the rolling bloom filter. ## ## Parameters: @@ -32,125 +49,34 @@ proc contains*(rbf: RollingBloomFilter, messageId: MessageID): bool = ## True if the message ID is probably in the filter, false otherwise. rbf.filter.lookup(messageId) -proc clean*(rbf: var RollingBloomFilter) = - ## Removes outdated entries from the rolling bloom filter. - let now = getTime() - let cutoff = now - rbf.window - var newMessages: seq[TimestampedMessageID] = @[] - var newFilter = initializeBloomFilter(rbf.filter.capacity, rbf.filter.errorRate) +proc clean*(rbf: var RollingBloomFilter) {.gcsafe.} = + try: + let now = getTime() + let cutoff = now - rbf.window + var newMessages: seq[TimestampedMessageID] = @[] + var newFilter: BloomFilter + {.gcsafe.}: + newFilter = initializeBloomFilter(rbf.filter.capacity, rbf.filter.errorRate) - for msg in rbf.messages: - if msg.timestamp > cutoff: - newMessages.add(msg) - newFilter.insert(msg.id) + for msg in rbf.messages: + if msg.timestamp > cutoff: + newMessages.add(msg) + newFilter.insert(msg.id) - rbf.messages = newMessages - rbf.filter = newFilter + rbf.messages = newMessages + rbf.filter = newFilter + except Exception as e: + logError("Failed to clean bloom filter: " & e.msg) -proc cleanBloomFilter*(rm: ReliabilityManager) = - ## Cleans the rolling bloom filter, removing outdated entries. +proc cleanBloomFilter*(rm: ReliabilityManager) {.gcsafe, raises: [].} = withLock rm.lock: - rm.bloomFilter.clean() + try: + rm.bloomFilter.clean() + except Exception as e: + logError("Failed to clean ReliabilityManager bloom filter: " & e.msg) -proc updateLamportTimestamp(rm: ReliabilityManager, msgTs: int64) = +proc updateLamportTimestamp*(rm: ReliabilityManager, msgTs: int64) = rm.lamportTimestamp = max(msgTs, rm.lamportTimestamp) + 1 -proc getRecentMessageIDs(rm: ReliabilityManager, n: int): seq[MessageID] = - result = rm.messageHistory[max(0, rm.messageHistory.len - n) .. ^1] - -proc generateUniqueID*(): MessageID = - let timestamp = getTime().toUnix - let randomPart = rand(high(int)) - result = $hash($timestamp & $randomPart) - -proc serializeMessage*(msg: Message): Result[string] = - ## Serializes a Message object to a JSON string. - ## - ## Parameters: - ## - msg: The Message object to serialize. - ## - ## Returns: - ## A Result containing either the serialized JSON string or an error. - try: - let jsonNode = %*{ - "senderId": msg.senderId, - "messageId": msg.messageId, - "lamportTimestamp": msg.lamportTimestamp, - "causalHistory": msg.causalHistory, - "channelId": msg.channelId, - "content": msg.content - } - return ok($jsonNode) - except: - return err[string](reSerializationError) - -proc deserializeMessage*(data: string): Result[Message] = - ## Deserializes a JSON string to a Message object. - ## - ## Parameters: - ## - data: The JSON string to deserialize. - ## - ## Returns: - ## A Result containing either the deserialized Message object or an error. - try: - let jsonNode = parseJson(data) - return ok(Message( - senderId: jsonNode["senderId"].getStr, - messageId: jsonNode["messageId"].getStr, - lamportTimestamp: jsonNode["lamportTimestamp"].getBiggestInt, - causalHistory: jsonNode["causalHistory"].to(seq[string]), - channelId: jsonNode["channelId"].getStr, - content: jsonNode["content"].getStr - )) - except: - return err[Message](reDeserializationError) - -proc getMessageHistory*(rm: ReliabilityManager): seq[MessageID] = - ## Retrieves the current message history from the ReliabilityManager. - ## - ## Returns: - ## A sequence of MessageIDs representing the current message history. - withLock rm.lock: - return rm.messageHistory - -proc getOutgoingBufferSize*(rm: ReliabilityManager): int = - ## Returns the current size of the outgoing message buffer. - ## - ## Returns: - ## The number of messages in the outgoing buffer. - withLock rm.lock: - return rm.outgoingBuffer.len - -proc getIncomingBufferSize*(rm: ReliabilityManager): int = - ## Returns the current size of the incoming message buffer. - ## - ## Returns: - ## The number of messages in the incoming buffer. - withLock rm.lock: - return rm.incomingBuffer.len - -proc logError*(msg: string) = - ## Logs an error message - error "ReliabilityError", message = msg - -proc logInfo*(msg: string) = - ## Logs an informational message - info "ReliabilityInfo", message = msg - -proc checkAndLogError*[T](res: Result[T], errorMsg: string): T = - ## Checks the result of an operation, logs any errors, and returns the value or raises an exception. - ## - ## Parameters: - ## - res: A Result[T] object to check. - ## - errorMsg: A message to log if an error occurred. - ## - ## Returns: - ## The value contained in the Result if it was successful. - ## - ## Raises: - ## An exception with the error message if the Result contains an error. - if res.isOk: - return res.value - else: - logError(errorMsg & ": " & $res.error) - raise newException(ValueError, errorMsg) \ No newline at end of file +proc getRecentMessageIDs*(rm: ReliabilityManager, n: int): seq[MessageID] = + result = rm.messageHistory[max(0, rm.messageHistory.len - n) .. ^1] \ No newline at end of file diff --git a/tests/test_reliability.nim b/tests/test_reliability.nim index a7cb656..d30eb0d 100644 --- a/tests/test_reliability.nim +++ b/tests/test_reliability.nim @@ -1,92 +1,180 @@ -import unittest +import unittest, results, chronos, chronicles import ../src/reliability +import ../src/common suite "ReliabilityManager": + var rm: ReliabilityManager + setup: let rmResult = newReliabilityManager("testChannel") - check rmResult.isOk - let rm = rmResult.value + check rmResult.isOk() + rm = rmResult.get() + + teardown: + if not rm.isNil: + rm.cleanup() + + test "can create with default config": + let config = defaultConfig() + check config.bloomFilterCapacity == DefaultBloomFilterCapacity + check config.bloomFilterErrorRate == DefaultBloomFilterErrorRate + check config.bloomFilterWindow == DefaultBloomFilterWindow test "wrapOutgoingMessage": - let msgResult = rm.wrapOutgoingMessage("Hello, World!") - check msgResult.isOk - let msg = msgResult.value - check: - msg.content == "Hello, World!" - msg.channelId == "testChannel" - msg.causalHistory.len == 0 + let msg = @[byte(1), 2, 3] + let msgId = "test-msg-1" + let wrappedResult = rm.wrapOutgoingMessage(msg, msgId) + check wrappedResult.isOk() + let wrapped = wrappedResult.get() + check wrapped.len > 0 test "unwrapReceivedMessage": - let wrappedMsgResult = rm.wrapOutgoingMessage("Test message") - check wrappedMsgResult.isOk - let wrappedMsg = wrappedMsgResult.value - let unwrapResult = rm.unwrapReceivedMessage(wrappedMsg) - check unwrapResult.isOk - let (unwrappedMsg, missingDeps) = unwrapResult.value + let msg = @[byte(1), 2, 3] + let msgId = "test-msg-1" + let wrappedResult = rm.wrapOutgoingMessage(msg, msgId) + check wrappedResult.isOk() + let wrapped = wrappedResult.get() + let unwrapResult = rm.unwrapReceivedMessage(wrapped) + check unwrapResult.isOk() + let (unwrapped, missingDeps) = unwrapResult.get() check: - unwrappedMsg.content == "Test message" + unwrapped == msg missingDeps.len == 0 test "markDependenciesMet": - var msg1Result = rm.wrapOutgoingMessage("Message 1") - var msg2Result = rm.wrapOutgoingMessage("Message 2") - var msg3Result = rm.wrapOutgoingMessage("Message 3") - check msg1Result.isOk and msg2Result.isOk and msg3Result.isOk - let msg1 = msg1Result.value - let msg2 = msg2Result.value - let msg3 = msg3Result.value + info "test_state", state="starting markDependenciesMet test" - var unwrapResult = rm.unwrapReceivedMessage(msg3) - check unwrapResult.isOk - var (_, missingDeps) = unwrapResult.value - check missingDeps.len == 2 + block message1: + let msg1 = @[byte(1)] + let id1 = "msg1" + info "message_creation", msg="message 1", id=id1 + let wrap1 = rm.wrapOutgoingMessage(msg1, id1) + check wrap1.isOk() + let wrapped1 = wrap1.get() - let markResult = rm.markDependenciesMet(@[msg1.messageId, msg2.messageId]) - check markResult.isOk + info "message_processing", msg="message 1", id=id1 + let unwrap1 = rm.unwrapReceivedMessage(wrapped1) + check unwrap1.isOk() + let (content1, deps1) = unwrap1.get() + info "message_processed", msg="message 1", deps_count=deps1.len + check content1 == msg1 - unwrapResult = rm.unwrapReceivedMessage(msg3) - check unwrapResult.isOk - (_, missingDeps) = unwrapResult.value - check missingDeps.len == 0 + block message2: + let msg2 = @[byte(2)] + let id2 = "msg2" + info "message_creation", msg="message 2", id=id2 + let wrap2 = rm.wrapOutgoingMessage(msg2, id2) + check wrap2.isOk() + let wrapped2 = wrap2.get() - test "callbacks": + info "message_processing", msg="message 2", id=id2 + let unwrap2 = rm.unwrapReceivedMessage(wrapped2) + check unwrap2.isOk() + let (content2, deps2) = unwrap2.get() + info "message_processed", msg="message 2", deps_count=deps2.len + check content2 == msg2 + + block message3: + info "message_creation", msg="message 3" + let msg3 = @[byte(3)] + let id3 = "msg3" + let wrap3 = rm.wrapOutgoingMessage(msg3, id3) + check wrap3.isOk() + info "message_wrapped", msg="message 3", id=id3 + let wrapped3 = wrap3.get() + + info "checking_dependencies", msg="message 3", id=id3 + var unwrap3 = rm.unwrapReceivedMessage(wrapped3) + check unwrap3.isOk() + var (content3, missing3) = unwrap3.get() + info "dependencies_checked", msg="message 3", missing_deps=missing3.len + + info "test_state", state="completed" + + test "callbacks work correctly": var messageReadyCount = 0 var messageSentCount = 0 var missingDepsCount = 0 rm.setCallbacks( - proc(messageId: MessageID) = messageReadyCount += 1, - proc(messageId: MessageID) = messageSentCount += 1, - proc(messageId: MessageID, missingDeps: seq[MessageID]) = missingDepsCount += 1 + proc(messageId: MessageID) {.gcsafe.} = messageReadyCount += 1, + proc(messageId: MessageID) {.gcsafe.} = messageSentCount += 1, + proc(messageId: MessageID, missingDeps: seq[MessageID]) {.gcsafe.} = missingDepsCount += 1 ) - let msg1Result = rm.wrapOutgoingMessage("Message 1") - let msg2Result = rm.wrapOutgoingMessage("Message 2") - check msg1Result.isOk and msg2Result.isOk - let msg1 = msg1Result.value - let msg2 = msg2Result.value + let msg1Result = rm.wrapOutgoingMessage(@[byte(1)], "msg1") + let msg2Result = rm.wrapOutgoingMessage(@[byte(2)], "msg2") + check msg1Result.isOk() and msg2Result.isOk() + let msg1 = msg1Result.get() + let msg2 = msg2Result.get() discard rm.unwrapReceivedMessage(msg1) discard rm.unwrapReceivedMessage(msg2) check: messageReadyCount == 2 - messageSentCount == 0 # This would be triggered by the checkUnacknowledgedMessages function + messageSentCount == 0 # This would be triggered by checkUnacknowledgedMessages missingDepsCount == 0 - test "serialization": - let msgResult = rm.wrapOutgoingMessage("Test serialization") - check msgResult.isOk - let msg = msgResult.value - let serializeResult = serializeMessage(msg) - check serializeResult.isOk - let serialized = serializeResult.value - let deserializeResult = deserializeMessage(serialized) - check deserializeResult.isOk - let deserialized = deserializeResult.value - check: - deserialized.content == "Test serialization" - deserialized.messageId == msg.messageId - deserialized.lamportTimestamp == msg.lamportTimestamp + test "periodic sync callback works": + var syncCallCount = 0 + rm.setCallbacks( + proc(messageId: MessageID) {.gcsafe.} = discard, + proc(messageId: MessageID) {.gcsafe.} = discard, + proc(messageId: MessageID, missingDeps: seq[MessageID]) {.gcsafe.} = discard, + proc() {.gcsafe.} = syncCallCount += 1 + ) -when isMainModule: - unittest.run() \ No newline at end of file + rm.startPeriodicTasks() + # Sleep briefly to allow periodic tasks to run + waitFor sleepAsync(chronos.seconds(1)) + rm.cleanup() + + check(syncCallCount > 0) + + test "protobuf serialization": + let msg = @[byte(1), 2, 3] + let msgId = "test-msg-1" + let msgResult = rm.wrapOutgoingMessage(msg, msgId) + check msgResult.isOk() + let wrapped = msgResult.get() + + let unwrapResult = rm.unwrapReceivedMessage(wrapped) + check unwrapResult.isOk() + let (unwrapped, _) = unwrapResult.get() + + check: + unwrapped == msg + unwrapped.len == msg.len + + test "handles empty message": + let msg: seq[byte] = @[] + let msgId = "test-empty-msg" + let wrappedResult = rm.wrapOutgoingMessage(msg, msgId) + check(not wrappedResult.isOk()) + check(wrappedResult.error == reInvalidArgument) + + test "handles message too large": + let msg = newSeq[byte](MaxMessageSize + 1) + let msgId = "test-large-msg" + let wrappedResult = rm.wrapOutgoingMessage(msg, msgId) + check(not wrappedResult.isOk()) + check(wrappedResult.error == reMessageTooLarge) + +suite "cleanup": + test "cleanup works correctly": + let rmResult = newReliabilityManager("testChannel") + check rmResult.isOk() + let rm = rmResult.get() + + # Add some messages + let msg = @[byte(1), 2, 3] + let msgId = "test-msg-1" + discard rm.wrapOutgoingMessage(msg, msgId) + + # Cleanup + rm.cleanup() + + # Check buffers are empty + check(rm.outgoingBuffer.len == 0) + check(rm.incomingBuffer.len == 0) + check(rm.messageHistory.len == 0) \ No newline at end of file