diff --git a/sds.nim b/sds.nim index 479a3e5..58d1893 100644 --- a/sds.nim +++ b/sds.nim @@ -1,24 +1,15 @@ import std/[times, locks, tables, sets, options] import chronos, results, chronicles -import sds/[message, protobuf, sds_utils, rolling_bloom_filter] +import sds/[types, protobuf, sds_utils, rolling_bloom_filter] -export message, protobuf, sds_utils, rolling_bloom_filter +export types, protobuf, sds_utils, rolling_bloom_filter proc newReliabilityManager*( config: ReliabilityConfig = defaultConfig() ): Result[ReliabilityManager, ReliabilityError] = ## Creates a new multi-channel ReliabilityManager. - ## - ## Parameters: - ## - config: Configuration options for the ReliabilityManager. If not provided, default configuration is used. - ## - ## Returns: - ## A Result containing either a new ReliabilityManager instance or an error. try: - let rm = ReliabilityManager( - channels: initTable[SdsChannelID, ChannelContext](), config: config - ) - initLock(rm.lock) + let rm = ReliabilityManager.new(config) return ok(rm) except Exception: error "Failed to create ReliabilityManager", msg = getCurrentExceptionMsg() @@ -35,25 +26,20 @@ proc isAcknowledged*( if rbf.isSome(): return rbf.get().contains(msg.message.messageId) - false + return false proc reviewAckStatus(rm: ReliabilityManager, msg: SdsMessage) {.gcsafe.} = - # Parse bloom filter var rbf: Option[RollingBloomFilter] if msg.bloomFilter.len > 0: let bfResult = deserializeBloomFilter(msg.bloomFilter) if bfResult.isOk(): + let bf = bfResult.get() rbf = some( - RollingBloomFilter( - filter: bfResult.get(), - capacity: bfResult.get().capacity, - minCapacity: ( - bfResult.get().capacity.float * (100 - CapacityFlexPercent).float / 100.0 - ).int, - maxCapacity: ( - bfResult.get().capacity.float * (100 + CapacityFlexPercent).float / 100.0 - ).int, - messages: @[], + RollingBloomFilter.init( + filter = bf, + capacity = bf.capacity, + minCapacity = (bf.capacity.float * (100 - CapacityFlexPercent).float / 100.0).int, + maxCapacity = (bf.capacity.float * (100 + CapacityFlexPercent).float / 100.0).int, ) ) else: @@ -66,7 +52,6 @@ proc reviewAckStatus(rm: ReliabilityManager, msg: SdsMessage) {.gcsafe.} = return let channel = rm.channels[msg.channelId] - # Keep track of indices to delete var toDelete: seq[int] = @[] var i = 0 @@ -78,7 +63,7 @@ proc reviewAckStatus(rm: ReliabilityManager, msg: SdsMessage) {.gcsafe.} = toDelete.add(i) inc i - for i in countdown(toDelete.high, 0): # Delete in reverse order to maintain indices + for i in countdown(toDelete.high, 0): channel.outgoingBuffer.delete(toDelete[i]) proc wrapOutgoingMessage*( @@ -88,14 +73,6 @@ proc wrapOutgoingMessage*( channelId: SdsChannelID, ): Result[seq[byte], ReliabilityError] = ## Wraps an outgoing message with reliability metadata. - ## - ## Parameters: - ## - message: The content of the message to be sent. - ## - messageId: Unique identifier for the message - ## - channelId: Identifier for the channel this message belongs to. - ## - ## Returns: - ## A Result containing either wrapped message bytes or an error. if message.len == 0: return err(ReliabilityError.reInvalidArgument) if message.len > MaxMessageSize: @@ -111,20 +88,19 @@ proc wrapOutgoingMessage*( error "Failed to serialize bloom filter", channelId = channelId return err(ReliabilityError.reSerializationError) - let msg = SdsMessage( - messageId: messageId, - lamportTimestamp: channel.lamportTimestamp, - causalHistory: rm.getRecentHistoryEntries(rm.config.maxCausalHistory, channelId), - channelId: channelId, - content: message, - bloomFilter: bfResult.get(), + let msg = SdsMessage.init( + messageId = messageId, + lamportTimestamp = channel.lamportTimestamp, + causalHistory = rm.getRecentHistoryEntries(rm.config.maxCausalHistory, channelId), + channelId = channelId, + content = message, + bloomFilter = bfResult.get(), ) channel.outgoingBuffer.add( - UnacknowledgedMessage(message: msg, sendTime: getTime(), resendAttempts: 0) + UnacknowledgedMessage.init(message = msg, sendTime = getTime(), resendAttempts = 0) ) - # Add to causal history and bloom filter channel.bloomFilter.add(msg.messageId) rm.addToHistory(msg.messageId, channelId) @@ -147,7 +123,6 @@ proc processIncomingBuffer(rm: ReliabilityManager, channelId: SdsChannelID) {.gc var processed = initHashSet[SdsMessageID]() var readyToProcess = newSeq[SdsMessageID]() - # Find initially ready messages for msgId, entry in channel.incomingBuffer: if entry.missingDeps.len == 0: readyToProcess.add(msgId) @@ -163,7 +138,6 @@ proc processIncomingBuffer(rm: ReliabilityManager, channelId: SdsChannelID) {.gc rm.onMessageReady(msgId, channelId) processed.incl(msgId) - # Update dependencies for remaining messages for remainingId, entry in channel.incomingBuffer: if remainingId notin processed: if msgId in entry.missingDeps: @@ -171,7 +145,6 @@ proc processIncomingBuffer(rm: ReliabilityManager, channelId: SdsChannelID) {.gc if channel.incomingBuffer[remainingId].missingDeps.len == 0: readyToProcess.add(remainingId) - # Remove processed messages for msgId in processed: channel.incomingBuffer.del(msgId) @@ -182,12 +155,6 @@ proc unwrapReceivedMessage*( ReliabilityError, ] = ## Unwraps a received message and processes its reliability metadata. - ## - ## Parameters: - ## - message: The received message bytes - ## - ## Returns: - ## A Result containing either tuple of (processed message, missing dependencies, channel ID) or an error. try: let channelId = extractChannelId(message).valueOr: return err(ReliabilityError.reDeserializationError) @@ -203,7 +170,6 @@ proc unwrapReceivedMessage*( channel.bloomFilter.add(msg.messageId) rm.updateLamportTimestamp(msg.lamportTimestamp, channelId) - # Review ACK status for outgoing messages rm.reviewAckStatus(msg) var missingDeps = rm.checkDependencies(msg.causalHistory, channelId) @@ -214,19 +180,20 @@ proc unwrapReceivedMessage*( if msgId in msg.causalHistory.getMessageIds(): depsInBuffer = true break - # Check if any dependencies are still in incoming buffer if depsInBuffer: channel.incomingBuffer[msg.messageId] = - IncomingMessage(message: msg, missingDeps: initHashSet[SdsMessageID]()) + IncomingMessage.init(message = msg, missingDeps = initHashSet[SdsMessageID]()) else: - # All dependencies met, add to history rm.addToHistory(msg.messageId, channelId) rm.processIncomingBuffer(channelId) if not rm.onMessageReady.isNil(): rm.onMessageReady(msg.messageId, channelId) else: channel.incomingBuffer[msg.messageId] = - IncomingMessage(message: msg, missingDeps: missingDeps.getMessageIds().toHashSet()) + IncomingMessage.init( + message = msg, + missingDeps = missingDeps.getMessageIds().toHashSet(), + ) if not rm.onMissingDependencies.isNil(): rm.onMissingDependencies(msg.messageId, missingDeps, channelId) @@ -239,13 +206,6 @@ proc markDependenciesMet*( rm: ReliabilityManager, messageIds: seq[SdsMessageID], channelId: SdsChannelID ): Result[void, ReliabilityError] = ## Marks the specified message dependencies as met. - ## - ## Parameters: - ## - messageIds: A sequence of message IDs to mark as met. - ## - channelId: Identifier for the channel. - ## - ## Returns: - ## A Result indicating success or an error. try: if channelId notin rm.channels: return err(ReliabilityError.reInvalidArgument) @@ -273,16 +233,9 @@ proc setCallbacks*( onMessageSent: MessageSentCallback, onMissingDependencies: MissingDependenciesCallback, onPeriodicSync: PeriodicSyncCallback = nil, - onRetrievalHint: RetrievalHintProvider = nil + onRetrievalHint: RetrievalHintProvider = nil, ) = ## Sets the callback functions for various events in the ReliabilityManager. - ## - ## Parameters: - ## - onMessageReady: Callback function called when a message is ready to be processed. - ## - onMessageSent: Callback function called when a message is confirmed as sent. - ## - onMissingDependencies: Callback function called when a message has missing dependencies. - ## - onPeriodicSync: Callback function called to notify about periodic sync - ## - onRetrievalHint: Callback function called to get a retrieval hint for a message ID. withLock rm.lock: rm.onMessageReady = onMessageReady rm.onMessageSent = onMessageSent @@ -293,7 +246,6 @@ proc setCallbacks*( proc checkUnacknowledgedMessages( rm: ReliabilityManager, channelId: SdsChannelID ) {.gcsafe.} = - ## Checks and processes unacknowledged messages in the outgoing buffer. withLock rm.lock: if channelId notin rm.channels: error "Channel does not exist", channelId = channelId @@ -322,7 +274,6 @@ proc checkUnacknowledgedMessages( proc periodicBufferSweep( rm: ReliabilityManager ) {.async: (raises: [CancelledError]), gcsafe.} = - ## Periodically sweeps the buffer to clean up and check unacknowledged messages. while true: try: for channelId, channel in rm.channels: @@ -340,7 +291,6 @@ proc periodicBufferSweep( proc periodicSyncMessage( rm: ReliabilityManager ) {.async: (raises: [CancelledError]), gcsafe.} = - ## Periodically notifies to send a sync message to maintain connectivity. while true: try: if not rm.onPeriodicSync.isNil(): @@ -351,15 +301,11 @@ proc periodicSyncMessage( proc startPeriodicTasks*(rm: ReliabilityManager) = ## Starts the periodic tasks for buffer sweeping and sync message sending. - ## - ## This procedure should be called after creating a ReliabilityManager to enable automatic maintenance. asyncSpawn rm.periodicBufferSweep() asyncSpawn rm.periodicSyncMessage() proc resetReliabilityManager*(rm: ReliabilityManager): Result[void, ReliabilityError] = ## Resets the ReliabilityManager to its initial state. - ## - ## This procedure clears all buffers and resets the Lamport timestamp. withLock rm.lock: try: for channelId, channel in rm.channels: @@ -367,9 +313,8 @@ proc resetReliabilityManager*(rm: ReliabilityManager): Result[void, ReliabilityE channel.messageHistory.setLen(0) channel.outgoingBuffer.setLen(0) channel.incomingBuffer.clear() - channel.bloomFilter = newRollingBloomFilter( - rm.config.bloomFilterCapacity, rm.config.bloomFilterErrorRate - ) + channel.bloomFilter = + RollingBloomFilter.init(rm.config.bloomFilterCapacity, rm.config.bloomFilterErrorRate) rm.channels.clear() return ok() except Exception: diff --git a/sds/bloom.nim b/sds/bloom.nim index 7d6f498..20854a2 100644 --- a/sds/bloom.nim +++ b/sds/bloom.nim @@ -3,13 +3,8 @@ import hashes import strutils import results import private/probabilities - -type BloomFilter* = object - capacity*: int - errorRate*: float - kHashes*: int - mBits*: int - intArray*: seq[int] +import ./types/bloom_filter +export bloom_filter {.push overflowChecks: off.} # Turn off overflow checks for hashing operations @@ -20,13 +15,7 @@ proc hashN(item: string, n: int, maxValue: int): int = let hashA = abs(hash(item)) mod maxValue # Use abs to handle negative hashes hashB = abs(hash(item & " b")) mod maxValue # string concatenation - abs((hashA + n * hashB)) mod maxValue - # # Use bit rotation for second hash instead of string concatenation if speed if preferred over FP-rate - # # Rotate left by 21 bits (lower the rotation, higher the speed but higher the FP-rate too) - # hashB = abs( - # ((h shl 21) or (h shr (sizeof(int) * 8 - 21))) - # ) mod maxValue - # abs((hashA + n.int64 * hashB)) mod maxValue + return abs((hashA + n * hashB)) mod maxValue {.pop.} @@ -41,7 +30,7 @@ proc getMOverNBitsForK*( if probabilityTable[k][mOverN] < targetError: return ok(mOverN) - err( + return err( "Specified value of k and error rate not achievable using less than 4 bytes / element." ) @@ -79,31 +68,31 @@ proc initializeBloomFilter*( mBits = capacity * nBitsPerElem mInts = 1 + mBits div (sizeof(int) * 8) - ok( - BloomFilter( - capacity: capacity, - errorRate: errorRate, - kHashes: kHashes, - mBits: mBits, - intArray: newSeq[int](mInts), + return ok( + BloomFilter.init( + capacity = capacity, + errorRate = errorRate, + kHashes = kHashes, + mBits = mBits, + intArray = newSeq[int](mInts), ) ) proc `$`*(bf: BloomFilter): string = ## Prints the configuration of the Bloom filter. - "Bloom filter with $1 capacity, $2 error rate, $3 hash functions, and requiring $4 bits of memory." % - [ - $bf.capacity, - formatFloat(bf.errorRate, format = ffScientific, precision = 1), - $bf.kHashes, - $(bf.mBits div bf.capacity), - ] + return "Bloom filter with $1 capacity, $2 error rate, $3 hash functions, and requiring $4 bits of memory." % + [ + $bf.capacity, + formatFloat(bf.errorRate, format = ffScientific, precision = 1), + $bf.kHashes, + $(bf.mBits div bf.capacity), + ] proc computeHashes(bf: BloomFilter, item: string): seq[int] = var hashes = newSeq[int](bf.kHashes) for i in 0 ..< bf.kHashes: hashes[i] = hashN(item, i, bf.mBits) - hashes + return hashes proc insert*(bf: var BloomFilter, item: string) = ## Insert an item (string) into the Bloom filter. @@ -127,4 +116,4 @@ proc lookup*(bf: BloomFilter, item: string): bool = currentInt = bf.intArray[intAddress] if currentInt != (currentInt or (1 shl bitOffset)): return false - true + return true diff --git a/sds/message.nim b/sds/message.nim index 030a023..ddf5e5f 100644 --- a/sds/message.nim +++ b/sds/message.nim @@ -1,35 +1,14 @@ -import std/[times, sets] +import ./types/sds_message_id +import ./types/history_entry +import ./types/sds_message +import ./types/unacknowledged_message +import ./types/incoming_message +import ./types/reliability_config -type - SdsMessageID* = string - SdsChannelID* = string - - HistoryEntry* = object - messageId*: SdsMessageID - retrievalHint*: seq[byte] ## Optional hint for efficient retrieval (e.g., Waku message hash) - - SdsMessage* = object - messageId*: SdsMessageID - lamportTimestamp*: int64 - causalHistory*: seq[HistoryEntry] - channelId*: SdsChannelID - content*: seq[byte] - bloomFilter*: seq[byte] - - UnacknowledgedMessage* = object - message*: SdsMessage - sendTime*: Time - resendAttempts*: int - - IncomingMessage* = object - message*: SdsMessage - missingDeps*: HashSet[SdsMessageID] - -const - DefaultMaxMessageHistory* = 1000 - DefaultMaxCausalHistory* = 10 - DefaultResendInterval* = initDuration(seconds = 60) - DefaultMaxResendAttempts* = 5 - DefaultSyncMessageInterval* = initDuration(seconds = 30) - DefaultBufferSweepInterval* = initDuration(seconds = 60) - MaxMessageSize* = 1024 * 1024 # 1 MB +export + sds_message_id, + history_entry, + sds_message, + unacknowledged_message, + incoming_message, + reliability_config diff --git a/sds/protobuf.nim b/sds/protobuf.nim index 8eb69ac..ba1b7ff 100644 --- a/sds/protobuf.nim +++ b/sds/protobuf.nim @@ -1,6 +1,9 @@ import libp2p/protobuf/minprotobuf import endians -import sds/[message, protobufutil, bloom, sds_utils] +import ./types/[sds_message_id, history_entry, sds_message, reliability_error] +import ./protobufutil +import ./bloom +import ./sds_utils proc encode*(msg: SdsMessage): ProtoBuffer = var pb = initProtoBuffer() @@ -21,11 +24,11 @@ proc encode*(msg: SdsMessage): ProtoBuffer = pb.write(6, msg.bloomFilter) pb.finish() - pb + return pb proc decode*(T: type SdsMessage, buffer: seq[byte]): ProtobufResult[T] = let pb = initProtoBuffer(buffer) - var msg = SdsMessage() + var msg = SdsMessage.init("", 0, @[], "", @[], @[]) if not ?pb.getField(1, msg.messageId): return err(ProtobufError.missingRequiredField("messageId")) @@ -41,7 +44,7 @@ proc decode*(T: type SdsMessage, buffer: seq[byte]): ProtobufResult[T] = # New format: repeated HistoryEntry for histBuffer in historyBuffers: let entryPb = initProtoBuffer(histBuffer) - var entry: HistoryEntry + var entry = HistoryEntry.init("") if not ?entryPb.getField(1, entry.messageId): return err(ProtobufError.missingRequiredField("HistoryEntry.messageId")) # retrievalHint is optional @@ -63,7 +66,7 @@ proc decode*(T: type SdsMessage, buffer: seq[byte]): ProtobufResult[T] = if not ?pb.getField(6, msg.bloomFilter): msg.bloomFilter = @[] # Empty if not present - ok(msg) + return ok(msg) proc extractChannelId*(data: seq[byte]): Result[SdsChannelID, ReliabilityError] = ## For extraction of channel ID without full message deserialization @@ -74,23 +77,22 @@ proc extractChannelId*(data: seq[byte]): Result[SdsChannelID, ReliabilityError] return err(ReliabilityError.reDeserializationError) if not fieldOk: return err(ReliabilityError.reDeserializationError) - ok(channelId) + return ok(channelId) except: - err(ReliabilityError.reDeserializationError) + return err(ReliabilityError.reDeserializationError) proc serializeMessage*(msg: SdsMessage): Result[seq[byte], ReliabilityError] = let pb = encode(msg) - ok(pb.buffer) + return ok(pb.buffer) proc deserializeMessage*(data: seq[byte]): Result[SdsMessage, ReliabilityError] = let msg = SdsMessage.decode(data).valueOr: return err(ReliabilityError.reDeserializationError) - ok(msg) + return ok(msg) proc serializeBloomFilter*(filter: BloomFilter): Result[seq[byte], ReliabilityError] = var pb = initProtoBuffer() - # Convert intArray to bytes try: var bytes = newSeq[byte](filter.intArray.len * sizeof(int)) for i, val in filter.intArray: @@ -108,7 +110,7 @@ proc serializeBloomFilter*(filter: BloomFilter): Result[seq[byte], ReliabilityEr return err(ReliabilityError.reSerializationError) pb.finish() - ok(pb.buffer) + return ok(pb.buffer) proc deserializeBloomFilter*(data: seq[byte]): Result[BloomFilter, ReliabilityError] = if data.len == 0: @@ -134,7 +136,6 @@ proc deserializeBloomFilter*(data: seq[byte]): Result[BloomFilter, ReliabilityEr if not field1_Ok or not field2_Ok or not field3_Ok or not field4_Ok or not field5_Ok: return err(ReliabilityError.reDeserializationError) - # Convert bytes back to intArray var intArray = newSeq[int](bytes.len div sizeof(int)) for i in 0 ..< intArray.len: var leVal: int @@ -142,13 +143,13 @@ proc deserializeBloomFilter*(data: seq[byte]): Result[BloomFilter, ReliabilityEr copyMem(addr leVal, unsafeAddr bytes[start], sizeof(int)) littleEndian64(addr intArray[i], addr leVal) - ok( - BloomFilter( - intArray: intArray, - capacity: int(cap), - errorRate: float(errRate) / 1_000_000, - kHashes: int(kHashes), - mBits: int(mBits), + return ok( + BloomFilter.init( + capacity = int(cap), + errorRate = float(errRate) / 1_000_000, + kHashes = int(kHashes), + mBits = int(mBits), + intArray = intArray, ) ) except: diff --git a/sds/protobufutil.nim b/sds/protobufutil.nim index d7c928c..3153017 100644 --- a/sds/protobufutil.nim +++ b/sds/protobufutil.nim @@ -4,29 +4,16 @@ import libp2p/protobuf/minprotobuf import libp2p/varint +import ./types/protobuf_error -export minprotobuf, varint - -type - ProtobufErrorKind* {.pure.} = enum - DecodeFailure - MissingRequiredField - - ProtobufError* = object - case kind*: ProtobufErrorKind - of DecodeFailure: - error*: minprotobuf.ProtoError - of MissingRequiredField: - field*: string - - ProtobufResult*[T] = Result[T, ProtobufError] +export minprotobuf, varint, protobuf_error converter toProtobufError*(err: minprotobuf.ProtoError): ProtobufError = case err of minprotobuf.ProtoError.RequiredFieldMissing: - ProtobufError(kind: ProtobufErrorKind.MissingRequiredField, field: "unknown") + return ProtobufError(kind: ProtobufErrorKind.MissingRequiredField, field: "unknown") else: - ProtobufError(kind: ProtobufErrorKind.DecodeFailure, error: err) + return ProtobufError(kind: ProtobufErrorKind.DecodeFailure, error: err) proc missingRequiredField*(T: type ProtobufError, field: string): T = - ProtobufError(kind: ProtobufErrorKind.MissingRequiredField, field: field) + return ProtobufError.init(field) diff --git a/sds/rolling_bloom_filter.nim b/sds/rolling_bloom_filter.nim index 190ab8a..1fe2c79 100644 --- a/sds/rolling_bloom_filter.nim +++ b/sds/rolling_bloom_filter.nim @@ -1,23 +1,14 @@ import chronos import chronicles -import ./[bloom, message] +import ./bloom +import ./types/rolling_bloom_filter +export rolling_bloom_filter -type RollingBloomFilter* = object - filter*: BloomFilter - capacity*: int - minCapacity*: int - maxCapacity*: int - messages*: seq[SdsMessageID] - -const - DefaultBloomFilterCapacity* = 10000 - DefaultBloomFilterErrorRate* = 0.001 - CapacityFlexPercent* = 20 - -proc newRollingBloomFilter*( +proc init*( + T: type RollingBloomFilter, capacity: int = DefaultBloomFilterCapacity, errorRate: float = DefaultBloomFilterErrorRate, -): RollingBloomFilter {.gcsafe.} = +): T {.gcsafe.} = let targetCapacity = if capacity <= 0: DefaultBloomFilterCapacity else: capacity let targetError = if errorRate <= 0.0 or errorRate >= 1.0: DefaultBloomFilterErrorRate else: errorRate @@ -25,7 +16,6 @@ proc newRollingBloomFilter*( let filterResult = initializeBloomFilter(targetCapacity, targetError) if filterResult.isErr: error "Failed to initialize bloom filter", error = filterResult.error - # Try with default values if custom values failed if capacity != DefaultBloomFilterCapacity or errorRate != DefaultBloomFilterErrorRate: let defaultResult = initializeBloomFilter(DefaultBloomFilterCapacity, DefaultBloomFilterErrorRate) @@ -45,12 +35,11 @@ proc newRollingBloomFilter*( minCapacity = minCapacity, maxCapacity = maxCapacity - return RollingBloomFilter( - filter: defaultResult.get(), - capacity: DefaultBloomFilterCapacity, - minCapacity: minCapacity, - maxCapacity: maxCapacity, - messages: @[], + return RollingBloomFilter.init( + filter = defaultResult.get(), + capacity = DefaultBloomFilterCapacity, + minCapacity = minCapacity, + maxCapacity = maxCapacity, ) else: error "Could not create bloom filter", error = filterResult.error @@ -63,12 +52,11 @@ proc newRollingBloomFilter*( info "Successfully initialized bloom filter", capacity = targetCapacity, minCapacity = minCapacity, maxCapacity = maxCapacity - return RollingBloomFilter( - filter: filterResult.get(), - capacity: targetCapacity, - minCapacity: minCapacity, - maxCapacity: maxCapacity, - messages: @[], + return RollingBloomFilter.init( + filter = filterResult.get(), + capacity = targetCapacity, + minCapacity = minCapacity, + maxCapacity = maxCapacity, ) proc clean*(rbf: var RollingBloomFilter) {.gcsafe.} = @@ -97,22 +85,12 @@ proc clean*(rbf: var RollingBloomFilter) {.gcsafe.} = proc add*(rbf: var RollingBloomFilter, messageId: SdsMessageID) {.gcsafe.} = ## Adds a message ID to the rolling bloom filter. - ## - ## Parameters: - ## - messageId: The ID of the message to add. rbf.filter.insert(cast[string](messageId)) rbf.messages.add(messageId) - # Clean if we exceed max capacity if rbf.messages.len > rbf.maxCapacity: rbf.clean() proc contains*(rbf: RollingBloomFilter, messageId: SdsMessageID): bool = ## Checks if a message ID is in the rolling bloom filter. - ## - ## Parameters: - ## - messageId: The ID of the message to check. - ## - ## Returns: - ## True if the message ID is probably in the filter, false otherwise. - rbf.filter.lookup(cast[string](messageId)) + return rbf.filter.lookup(cast[string](messageId)) diff --git a/sds/sds_utils.nim b/sds/sds_utils.nim index e8b0dd6..f1a68ca 100644 --- a/sds/sds_utils.nim +++ b/sds/sds_utils.nim @@ -1,81 +1,18 @@ -import std/[times, locks, tables, sequtils] +import std/[locks, tables, sequtils] import chronicles, results -import ./[rolling_bloom_filter, message] - -type - MessageReadyCallback* = - proc(messageId: SdsMessageID, channelId: SdsChannelID) {.gcsafe.} - - MessageSentCallback* = - proc(messageId: SdsMessageID, channelId: SdsChannelID) {.gcsafe.} - - MissingDependenciesCallback* = proc( - messageId: SdsMessageID, missingDeps: seq[HistoryEntry], channelId: SdsChannelID - ) {.gcsafe.} - - RetrievalHintProvider* = proc(messageId: SdsMessageID): seq[byte] {.gcsafe.} - - PeriodicSyncCallback* = proc() {.gcsafe, raises: [].} - - AppCallbacks* = ref object - messageReadyCb*: MessageReadyCallback - messageSentCb*: MessageSentCallback - missingDependenciesCb*: MissingDependenciesCallback - periodicSyncCb*: PeriodicSyncCallback - retrievalHintProvider*: RetrievalHintProvider - - ReliabilityConfig* = object - bloomFilterCapacity*: int - bloomFilterErrorRate*: float - maxMessageHistory*: int - maxCausalHistory*: int - resendInterval*: Duration - maxResendAttempts*: int - syncMessageInterval*: Duration - bufferSweepInterval*: Duration - - ChannelContext* = ref object - lamportTimestamp*: int64 - messageHistory*: seq[SdsMessageID] - bloomFilter*: RollingBloomFilter - outgoingBuffer*: seq[UnacknowledgedMessage] - incomingBuffer*: Table[SdsMessageID, IncomingMessage] - - ReliabilityManager* = ref object - channels*: Table[SdsChannelID, ChannelContext] - config*: ReliabilityConfig - lock*: Lock - onMessageReady*: proc(messageId: SdsMessageID, channelId: SdsChannelID) {.gcsafe.} - onMessageSent*: proc(messageId: SdsMessageID, channelId: SdsChannelID) {.gcsafe.} - onMissingDependencies*: proc( - messageId: SdsMessageID, missingDeps: seq[HistoryEntry], channelId: SdsChannelID - ) {.gcsafe.} - onPeriodicSync*: PeriodicSyncCallback - onRetrievalHint*: RetrievalHintProvider - - ReliabilityError* {.pure.} = enum - reInvalidArgument - reOutOfMemory - reInternalError - reSerializationError - reDeserializationError - reMessageTooLarge +import ./rolling_bloom_filter +import ./types/[ + sds_message_id, history_entry, sds_message, unacknowledged_message, incoming_message, + reliability_error, callbacks, app_callbacks, reliability_config, channel_context, + reliability_manager, +] +export + sds_message_id, history_entry, sds_message, unacknowledged_message, incoming_message, + reliability_error, callbacks, app_callbacks, reliability_config, channel_context, + reliability_manager proc defaultConfig*(): ReliabilityConfig = - ## Creates a default configuration for the ReliabilityManager. - ## - ## Returns: - ## A ReliabilityConfig object with default values. - ReliabilityConfig( - bloomFilterCapacity: DefaultBloomFilterCapacity, - bloomFilterErrorRate: DefaultBloomFilterErrorRate, - maxMessageHistory: DefaultMaxMessageHistory, - maxCausalHistory: DefaultMaxCausalHistory, - resendInterval: DefaultResendInterval, - maxResendAttempts: DefaultMaxResendAttempts, - syncMessageInterval: DefaultSyncMessageInterval, - bufferSweepInterval: DefaultBufferSweepInterval, - ) + return ReliabilityConfig.init() proc cleanup*(rm: ReliabilityManager) {.raises: [].} = if not rm.isNil(): @@ -124,24 +61,18 @@ proc updateLamportTimestamp*( error "Failed to update lamport timestamp", channelId = channelId, msgTs = msgTs, error = getCurrentExceptionMsg() -# Helper functions for HistoryEntry proc newHistoryEntry*(messageId: SdsMessageID, retrievalHint: seq[byte] = @[]): HistoryEntry = - ## Creates a new HistoryEntry with optional retrieval hint - HistoryEntry(messageId: messageId, retrievalHint: retrievalHint) + return HistoryEntry.init(messageId, retrievalHint) proc toCausalHistory*(messageIds: seq[SdsMessageID]): seq[HistoryEntry] = - ## Converts a sequence of message IDs to HistoryEntry sequence (for backward compatibility) return messageIds.mapIt(newHistoryEntry(it)) proc getMessageIds*(causalHistory: seq[HistoryEntry]): seq[SdsMessageID] = - ## Extracts message IDs from HistoryEntry sequence return causalHistory.mapIt(it.messageId) proc getRecentHistoryEntries*( rm: ReliabilityManager, n: int, channelId: SdsChannelID ): seq[HistoryEntry] = - ## Get recent history entries for sending in causal history. - ## Populates retrieval hints for our own messages using the provider callback. try: if channelId in rm.channels: let channel = rm.channels[channelId] @@ -164,7 +95,6 @@ proc getRecentHistoryEntries*( proc checkDependencies*( rm: ReliabilityManager, deps: seq[HistoryEntry], channelId: SdsChannelID ): seq[HistoryEntry] = - ## Check which dependencies are missing from our message history. var missingDeps: seq[HistoryEntry] = @[] try: if channelId in rm.channels: @@ -173,7 +103,6 @@ proc checkDependencies*( if dep.messageId notin channel.messageHistory: missingDeps.add(dep) else: - # Channel doesn't exist, all deps are missing missingDeps = deps except Exception: error "Failed to check dependencies", @@ -187,13 +116,13 @@ proc getMessageHistory*( withLock rm.lock: try: if channelId in rm.channels: - result = rm.channels[channelId].messageHistory + return rm.channels[channelId].messageHistory else: - result = @[] + return @[] except Exception: error "Failed to get message history", channelId = channelId, error = getCurrentExceptionMsg() - result = @[] + return @[] proc getOutgoingBuffer*( rm: ReliabilityManager, channelId: SdsChannelID @@ -201,43 +130,37 @@ proc getOutgoingBuffer*( withLock rm.lock: try: if channelId in rm.channels: - result = rm.channels[channelId].outgoingBuffer + return rm.channels[channelId].outgoingBuffer else: - result = @[] + return @[] except Exception: error "Failed to get outgoing buffer", channelId = channelId, error = getCurrentExceptionMsg() - result = @[] + return @[] proc getIncomingBuffer*( rm: ReliabilityManager, channelId: SdsChannelID -): Table[SdsMessageID, message.IncomingMessage] = +): Table[SdsMessageID, IncomingMessage] = withLock rm.lock: try: if channelId in rm.channels: - result = rm.channels[channelId].incomingBuffer + return rm.channels[channelId].incomingBuffer else: - result = initTable[SdsMessageID, message.IncomingMessage]() + return initTable[SdsMessageID, IncomingMessage]() except Exception: error "Failed to get incoming buffer", channelId = channelId, error = getCurrentExceptionMsg() - result = initTable[SdsMessageID, message.IncomingMessage]() + return initTable[SdsMessageID, IncomingMessage]() proc getOrCreateChannel*( rm: ReliabilityManager, channelId: SdsChannelID ): ChannelContext = try: if channelId notin rm.channels: - rm.channels[channelId] = ChannelContext( - lamportTimestamp: 0, - messageHistory: @[], - bloomFilter: newRollingBloomFilter( - rm.config.bloomFilterCapacity, rm.config.bloomFilterErrorRate - ), - outgoingBuffer: @[], - incomingBuffer: initTable[SdsMessageID, IncomingMessage](), + rm.channels[channelId] = ChannelContext.new( + RollingBloomFilter.init(rm.config.bloomFilterCapacity, rm.config.bloomFilterErrorRate) ) - result = rm.channels[channelId] + return rm.channels[channelId] except Exception: error "Failed to get or create channel", channelId = channelId, error = getCurrentExceptionMsg() @@ -270,4 +193,4 @@ proc removeChannel*( except Exception: error "Failed to remove channel", channelId = channelId, msg = getCurrentExceptionMsg() - return err(ReliabilityError.reInternalError) \ No newline at end of file + return err(ReliabilityError.reInternalError) diff --git a/sds/types.nim b/sds/types.nim new file mode 100644 index 0000000..f37518a --- /dev/null +++ b/sds/types.nim @@ -0,0 +1,30 @@ +import sds/types/sds_message_id +import sds/types/history_entry +import sds/types/sds_message +import sds/types/unacknowledged_message +import sds/types/incoming_message +import sds/types/bloom_filter +import sds/types/rolling_bloom_filter +import sds/types/reliability_error +import sds/types/callbacks +import sds/types/app_callbacks +import sds/types/reliability_config +import sds/types/channel_context +import sds/types/reliability_manager +import sds/types/protobuf_error + +export + sds_message_id, + history_entry, + sds_message, + unacknowledged_message, + incoming_message, + bloom_filter, + rolling_bloom_filter, + reliability_error, + callbacks, + app_callbacks, + reliability_config, + channel_context, + reliability_manager, + protobuf_error diff --git a/sds/types/app_callbacks.nim b/sds/types/app_callbacks.nim new file mode 100644 index 0000000..985a97f --- /dev/null +++ b/sds/types/app_callbacks.nim @@ -0,0 +1,25 @@ +import ./callbacks +export callbacks + +type AppCallbacks* = ref object + messageReadyCb*: MessageReadyCallback + messageSentCb*: MessageSentCallback + missingDependenciesCb*: MissingDependenciesCallback + periodicSyncCb*: PeriodicSyncCallback + retrievalHintProvider*: RetrievalHintProvider + +proc new*( + T: type AppCallbacks, + messageReadyCb: MessageReadyCallback = nil, + messageSentCb: MessageSentCallback = nil, + missingDependenciesCb: MissingDependenciesCallback = nil, + periodicSyncCb: PeriodicSyncCallback = nil, + retrievalHintProvider: RetrievalHintProvider = nil, +): T = + return T( + messageReadyCb: messageReadyCb, + messageSentCb: messageSentCb, + missingDependenciesCb: missingDependenciesCb, + periodicSyncCb: periodicSyncCb, + retrievalHintProvider: retrievalHintProvider, + ) diff --git a/sds/types/bloom_filter.nim b/sds/types/bloom_filter.nim new file mode 100644 index 0000000..8ca5be0 --- /dev/null +++ b/sds/types/bloom_filter.nim @@ -0,0 +1,22 @@ +type BloomFilter* {.requiresInit.} = object + capacity*: int + errorRate*: float + kHashes*: int + mBits*: int + intArray*: seq[int] + +proc init*( + T: type BloomFilter, + capacity: int, + errorRate: float, + kHashes: int, + mBits: int, + intArray: seq[int], +): T = + return T( + capacity: capacity, + errorRate: errorRate, + kHashes: kHashes, + mBits: mBits, + intArray: intArray, + ) diff --git a/sds/types/callbacks.nim b/sds/types/callbacks.nim new file mode 100644 index 0000000..f1fc4b3 --- /dev/null +++ b/sds/types/callbacks.nim @@ -0,0 +1,18 @@ +import ./sds_message_id +import ./history_entry +export sds_message_id, history_entry + +type + MessageReadyCallback* = + proc(messageId: SdsMessageID, channelId: SdsChannelID) {.gcsafe.} + + MessageSentCallback* = + proc(messageId: SdsMessageID, channelId: SdsChannelID) {.gcsafe.} + + MissingDependenciesCallback* = proc( + messageId: SdsMessageID, missingDeps: seq[HistoryEntry], channelId: SdsChannelID + ) {.gcsafe.} + + RetrievalHintProvider* = proc(messageId: SdsMessageID): seq[byte] {.gcsafe.} + + PeriodicSyncCallback* = proc() {.gcsafe, raises: [].} diff --git a/sds/types/channel_context.nim b/sds/types/channel_context.nim new file mode 100644 index 0000000..0346d18 --- /dev/null +++ b/sds/types/channel_context.nim @@ -0,0 +1,22 @@ +import std/tables +import ./sds_message_id +import ./rolling_bloom_filter +import ./unacknowledged_message +import ./incoming_message +export sds_message_id, rolling_bloom_filter, unacknowledged_message, incoming_message + +type ChannelContext* = ref object + lamportTimestamp*: int64 + messageHistory*: seq[SdsMessageID] + bloomFilter*: RollingBloomFilter + outgoingBuffer*: seq[UnacknowledgedMessage] + incomingBuffer*: Table[SdsMessageID, IncomingMessage] + +proc new*(T: type ChannelContext, bloomFilter: RollingBloomFilter): T = + return T( + lamportTimestamp: 0, + messageHistory: @[], + bloomFilter: bloomFilter, + outgoingBuffer: @[], + incomingBuffer: initTable[SdsMessageID, IncomingMessage](), + ) diff --git a/sds/types/history_entry.nim b/sds/types/history_entry.nim new file mode 100644 index 0000000..2435e6f --- /dev/null +++ b/sds/types/history_entry.nim @@ -0,0 +1,8 @@ +import ./sds_message_id + +type HistoryEntry* = object + messageId*: SdsMessageID + retrievalHint*: seq[byte] ## Optional hint for efficient retrieval (e.g., Waku message hash) + +proc init*(T: type HistoryEntry, messageId: SdsMessageID, retrievalHint: seq[byte] = @[]): T = + return T(messageId: messageId, retrievalHint: retrievalHint) diff --git a/sds/types/incoming_message.nim b/sds/types/incoming_message.nim new file mode 100644 index 0000000..47206b4 --- /dev/null +++ b/sds/types/incoming_message.nim @@ -0,0 +1,13 @@ +import std/sets +import ./sds_message_id +import ./sds_message +export sds_message_id, sds_message + +type IncomingMessage* {.requiresInit.} = object + message*: SdsMessage + missingDeps*: HashSet[SdsMessageID] + +proc init*( + T: type IncomingMessage, message: SdsMessage, missingDeps: HashSet[SdsMessageID] +): T = + return T(message: message, missingDeps: missingDeps) diff --git a/sds/types/protobuf_error.nim b/sds/types/protobuf_error.nim new file mode 100644 index 0000000..aff41df --- /dev/null +++ b/sds/types/protobuf_error.nim @@ -0,0 +1,22 @@ +import results +import libp2p/protobuf/minprotobuf + +type + ProtobufErrorKind* {.pure.} = enum + DecodeFailure + MissingRequiredField + + ProtobufError* = object + case kind*: ProtobufErrorKind + of DecodeFailure: + error*: minprotobuf.ProtoError + of MissingRequiredField: + field*: string + + ProtobufResult*[T] = Result[T, ProtobufError] + +proc init*(T: type ProtobufError, error: minprotobuf.ProtoError): T = + return T(kind: ProtobufErrorKind.DecodeFailure, error: error) + +proc init*(T: type ProtobufError, field: string): T = + return T(kind: ProtobufErrorKind.MissingRequiredField, field: field) diff --git a/sds/types/reliability_config.nim b/sds/types/reliability_config.nim new file mode 100644 index 0000000..f4e4e78 --- /dev/null +++ b/sds/types/reliability_config.nim @@ -0,0 +1,45 @@ +import std/times + +const + DefaultMaxMessageHistory* = 1000 + DefaultMaxCausalHistory* = 10 + DefaultResendInterval* = initDuration(seconds = 60) + DefaultMaxResendAttempts* = 5 + DefaultSyncMessageInterval* = initDuration(seconds = 30) + DefaultBufferSweepInterval* = initDuration(seconds = 60) + MaxMessageSize* = 1024 * 1024 # 1 MB + +import ./rolling_bloom_filter +export rolling_bloom_filter + +type ReliabilityConfig* {.requiresInit.} = object + bloomFilterCapacity*: int + bloomFilterErrorRate*: float + maxMessageHistory*: int + maxCausalHistory*: int + resendInterval*: Duration + maxResendAttempts*: int + syncMessageInterval*: Duration + bufferSweepInterval*: Duration + +proc init*( + T: type ReliabilityConfig, + bloomFilterCapacity: int = DefaultBloomFilterCapacity, + bloomFilterErrorRate: float = DefaultBloomFilterErrorRate, + maxMessageHistory: int = DefaultMaxMessageHistory, + maxCausalHistory: int = DefaultMaxCausalHistory, + resendInterval: Duration = DefaultResendInterval, + maxResendAttempts: int = DefaultMaxResendAttempts, + syncMessageInterval: Duration = DefaultSyncMessageInterval, + bufferSweepInterval: Duration = DefaultBufferSweepInterval, +): T = + return T( + bloomFilterCapacity: bloomFilterCapacity, + bloomFilterErrorRate: bloomFilterErrorRate, + maxMessageHistory: maxMessageHistory, + maxCausalHistory: maxCausalHistory, + resendInterval: resendInterval, + maxResendAttempts: maxResendAttempts, + syncMessageInterval: syncMessageInterval, + bufferSweepInterval: bufferSweepInterval, + ) diff --git a/sds/types/reliability_error.nim b/sds/types/reliability_error.nim new file mode 100644 index 0000000..43af2f7 --- /dev/null +++ b/sds/types/reliability_error.nim @@ -0,0 +1,7 @@ +type ReliabilityError* {.pure.} = enum + reInvalidArgument + reOutOfMemory + reInternalError + reSerializationError + reDeserializationError + reMessageTooLarge diff --git a/sds/types/reliability_manager.nim b/sds/types/reliability_manager.nim new file mode 100644 index 0000000..9bfc244 --- /dev/null +++ b/sds/types/reliability_manager.nim @@ -0,0 +1,27 @@ +import std/[tables, locks] +import ./sds_message_id +import ./history_entry +import ./callbacks +import ./reliability_config +import ./channel_context +export sds_message_id, history_entry, callbacks, reliability_config, channel_context + +type ReliabilityManager* = ref object + channels*: Table[SdsChannelID, ChannelContext] + config*: ReliabilityConfig + lock*: Lock + onMessageReady*: proc(messageId: SdsMessageID, channelId: SdsChannelID) {.gcsafe.} + onMessageSent*: proc(messageId: SdsMessageID, channelId: SdsChannelID) {.gcsafe.} + onMissingDependencies*: proc( + messageId: SdsMessageID, missingDeps: seq[HistoryEntry], channelId: SdsChannelID + ) {.gcsafe.} + onPeriodicSync*: PeriodicSyncCallback + onRetrievalHint*: RetrievalHintProvider + +proc new*(T: type ReliabilityManager, config: ReliabilityConfig): T = + let rm = T( + channels: initTable[SdsChannelID, ChannelContext](), + config: config, + ) + rm.lock.initLock() + return rm diff --git a/sds/types/rolling_bloom_filter.nim b/sds/types/rolling_bloom_filter.nim new file mode 100644 index 0000000..1348384 --- /dev/null +++ b/sds/types/rolling_bloom_filter.nim @@ -0,0 +1,31 @@ +import ./bloom_filter +import ./sds_message_id +export bloom_filter, sds_message_id + +const + DefaultBloomFilterCapacity* = 10000 + DefaultBloomFilterErrorRate* = 0.001 + CapacityFlexPercent* = 20 + +type RollingBloomFilter* {.requiresInit.} = object + filter*: BloomFilter + capacity*: int + minCapacity*: int + maxCapacity*: int + messages*: seq[SdsMessageID] + +proc init*( + T: type RollingBloomFilter, + filter: BloomFilter, + capacity: int, + minCapacity: int, + maxCapacity: int, + messages: seq[SdsMessageID] = @[], +): T = + return T( + filter: filter, + capacity: capacity, + minCapacity: minCapacity, + maxCapacity: maxCapacity, + messages: messages, + ) diff --git a/sds/types/sds_message.nim b/sds/types/sds_message.nim new file mode 100644 index 0000000..12f7add --- /dev/null +++ b/sds/types/sds_message.nim @@ -0,0 +1,29 @@ +import ./sds_message_id +import ./history_entry +export sds_message_id, history_entry + +type SdsMessage* {.requiresInit.} = object + messageId*: SdsMessageID + lamportTimestamp*: int64 + causalHistory*: seq[HistoryEntry] + channelId*: SdsChannelID + content*: seq[byte] + bloomFilter*: seq[byte] + +proc init*( + T: type SdsMessage, + messageId: SdsMessageID, + lamportTimestamp: int64, + causalHistory: seq[HistoryEntry], + channelId: SdsChannelID, + content: seq[byte], + bloomFilter: seq[byte], +): T = + return T( + messageId: messageId, + lamportTimestamp: lamportTimestamp, + causalHistory: causalHistory, + channelId: channelId, + content: content, + bloomFilter: bloomFilter, + ) diff --git a/sds/types/sds_message_id.nim b/sds/types/sds_message_id.nim new file mode 100644 index 0000000..3e8b7c7 --- /dev/null +++ b/sds/types/sds_message_id.nim @@ -0,0 +1,3 @@ +type + SdsMessageID* = string + SdsChannelID* = string diff --git a/sds/types/unacknowledged_message.nim b/sds/types/unacknowledged_message.nim new file mode 100644 index 0000000..ff4a4d3 --- /dev/null +++ b/sds/types/unacknowledged_message.nim @@ -0,0 +1,13 @@ +import std/times +import ./sds_message +export sds_message + +type UnacknowledgedMessage* = object + message*: SdsMessage + sendTime*: Time + resendAttempts*: int + +proc init*( + T: type UnacknowledgedMessage, message: SdsMessage, sendTime: Time, resendAttempts: int +): T = + return T(message: message, sendTime: sendTime, resendAttempts: resendAttempts) diff --git a/tests/test_reliability.nim b/tests/test_reliability.nim index 7770aef..7100606 100644 --- a/tests/test_reliability.nim +++ b/tests/test_reliability.nim @@ -228,7 +228,7 @@ suite "Reliability Mechanisms": # Create a message with bloom filter containing our message var otherPartyBloomFilter = - newRollingBloomFilter(DefaultBloomFilterCapacity, DefaultBloomFilterErrorRate) + RollingBloomFilter.init(DefaultBloomFilterCapacity, DefaultBloomFilterErrorRate) otherPartyBloomFilter.add(id1) let bfResult = serializeBloomFilter(otherPartyBloomFilter.filter)