diff --git a/src/reliability.nim b/src/reliability.nim index 9805a22..87d238a 100644 --- a/src/reliability.nim +++ b/src/reliability.nim @@ -1,7 +1,13 @@ -import std/[times, sets, hashes, random, sequtils, algorithm] +import std/[times, hashes, random, sequtils, algorithm] import nimsha2 import chronicles +const + BloomFilterSize = 10000 + BloomFilterHashCount = 7 + MaxMessageHistory = 100 + MaxCausalHistory = 10 + type MessageID* = string @@ -12,20 +18,17 @@ type causalHistory*: seq[MessageID] channelId*: string content*: string - bloomFilter*: seq[byte] + bloomFilter*: RollingBloomFilter UnacknowledgedMessage* = object message*: Message sendTime*: Time resendAttempts*: int - TimestampedMessageID* = object - id*: MessageID - timestamp*: Time - RollingBloomFilter* = object # TODO: Implement a proper Bloom filter - data: HashSet[MessageID] + data: array[BloomFilterSize, bool] + hashCount: int ReliabilityManager* = ref object lamportTimestamp: int64 @@ -36,16 +39,30 @@ type channelId: string onMessageReady*: proc(messageId: MessageID) onMessageSent*: proc(messageId: MessageID) - onPeriodicSync*: proc() + onMissingDependencies*: proc(messageId: MessageID, missingDeps: seq[MessageID]) + +proc hash(filter: RollingBloomFilter): Hash = + var h: Hash = 0 + for idx, val in filter.data: + h = h !& hash(idx) !& hash(val) + result = !$h proc newRollingBloomFilter(): RollingBloomFilter = - result.data = initHashSet[MessageID]() + result.hashCount = BloomFilterHashCount proc add(filter: var RollingBloomFilter, item: MessageID) = - filter.data.incl(item) + let itemHash = hash(item) + for i in 0 ..< filter.hashCount: + let idx = (itemHash + i * i) mod BloomFilterSize + filter.data[idx] = true proc contains(filter: RollingBloomFilter, item: MessageID): bool = - item in filter.data + let itemHash = hash(item) + for i in 0 ..< filter.hashCount: + let idx = (itemHash + i * i) mod BloomFilterSize + if not filter.data[idx]: + return false + return true proc newReliabilityManager*(channelId: string): ReliabilityManager = result = ReliabilityManager( @@ -61,7 +78,7 @@ proc generateUniqueID(): MessageID = $secureHash($getTime().toUnix & $rand(high(int))) proc updateLamportTimestamp(rm: ReliabilityManager, msgTs: int64) = - rm.lamportTimestamp = max(msgTs, rm.lamportTimestamp + 1) + rm.lamportTimestamp = max(msgTs, rm.lamportTimestamp) + 1 proc getRecentMessageIDs(rm: ReliabilityManager, n: int): seq[MessageID] = result = rm.messageHistory[max(0, rm.messageHistory.len - n) .. ^1] @@ -72,10 +89,10 @@ proc wrapOutgoingMessage*(rm: ReliabilityManager, message: string): Message = senderId: "TODO_SENDER_ID", messageId: generateUniqueID(), lamportTimestamp: rm.lamportTimestamp, - causalHistory: rm.getRecentMessageIDs(10), + causalHistory: rm.getRecentMessageIDs(MaxCausalHistory), channelId: rm.channelId, content: message, - bloomFilter: @[] # TODO: Implement proper Bloom filter serialization + bloomFilter: rm.bloomFilter ) rm.outgoingBuffer.add(UnacknowledgedMessage(message: msg, sendTime: getTime(), resendAttempts: 0)) msg @@ -94,10 +111,14 @@ proc unwrapReceivedMessage*(rm: ReliabilityManager, message: Message): tuple[mes if missingDeps.len == 0: rm.messageHistory.add(message.messageId) + if rm.messageHistory.len > MaxMessageHistory: + rm.messageHistory.delete(0) if rm.onMessageReady != nil: rm.onMessageReady(message.messageId) else: rm.incomingBuffer.add(message) + if rm.onMissingDependencies != nil: + rm.onMissingDependencies(message.messageId, missingDeps) (message, missingDeps) @@ -109,27 +130,28 @@ proc markDependenciesMet*(rm: ReliabilityManager, messageIds: seq[MessageID]) = for msg in processedMessages: rm.messageHistory.add(msg.messageId) + if rm.messageHistory.len > MaxMessageHistory: + rm.messageHistory.delete(0) if rm.onMessageReady != nil: rm.onMessageReady(msg.messageId) proc checkUnacknowledgedMessages(rm: ReliabilityManager) = let now = getTime() - rm.outgoingBuffer = rm.outgoingBuffer.filterIt((now - it.sendTime).inSeconds < 60) + var newOutgoingBuffer: seq[UnacknowledgedMessage] = @[] for msg in rm.outgoingBuffer: - if rm.onMessageSent != nil: + if (now - msg.sendTime).inSeconds < 60: + newOutgoingBuffer.add(msg) + elif rm.onMessageSent != nil: rm.onMessageSent(msg.message.messageId) - -proc periodicSync(rm: ReliabilityManager) = - if rm.onPeriodicSync != nil: - rm.onPeriodicSync() + rm.outgoingBuffer = newOutgoingBuffer proc setCallbacks*(rm: ReliabilityManager, onMessageReady: proc(messageId: MessageID), onMessageSent: proc(messageId: MessageID), - onPeriodicSync: proc()) = + onMissingDependencies: proc(messageId: MessageID, missingDeps: seq[MessageID])) = rm.onMessageReady = onMessageReady rm.onMessageSent = onMessageSent - rm.onPeriodicSync = onPeriodicSync + rm.onMissingDependencies = onMissingDependencies # Logging proc logInfo(msg: string) = @@ -150,8 +172,7 @@ type causalHistoryLen: cint channelId: cstring content: cstring - bloomFilter: ptr UncheckedArray[byte] - bloomFilterLen: cint + bloomFilter: pointer CUnwrapResult {.bycopy.} = object message: CMessage @@ -180,9 +201,7 @@ proc wrap_outgoing_message(rmPtr: pointer, message: cstring): CMessage {.exportc result.causalHistory[i] = id.cstring result.channelId = wrappedMsg.channelId.cstring result.content = wrappedMsg.content.cstring - result.bloomFilter = cast[ptr UncheckedArray[byte]](alloc0(wrappedMsg.bloomFilter.len)) - result.bloomFilterLen = wrappedMsg.bloomFilter.len.cint - copyMem(result.bloomFilter, addr wrappedMsg.bloomFilter[0], wrappedMsg.bloomFilter.len) + result.bloomFilter = cast[pointer](addr wrappedMsg.bloomFilter) proc unwrap_received_message(rmPtr: pointer, msg: CMessage): CUnwrapResult {.exportc, cdecl.} = let rm = cast[ReliabilityManager](rmPtr) @@ -193,11 +212,10 @@ proc unwrap_received_message(rmPtr: pointer, msg: CMessage): CUnwrapResult {.exp causalHistory: newSeq[string](msg.causalHistoryLen), channelId: $msg.channelId, content: $msg.content, - bloomFilter: newSeq[byte](msg.bloomFilterLen) + bloomFilter: cast[RollingBloomFilter](msg.bloomFilter)[] ) for i in 0 ..< msg.causalHistoryLen: nimMsg.causalHistory[i] = $msg.causalHistory[i] - copyMem(addr nimMsg.bloomFilter[0], msg.bloomFilter, msg.bloomFilterLen) let (unwrappedMsg, missingDeps) = rm.unwrapReceivedMessage(nimMsg) @@ -209,12 +227,10 @@ proc unwrap_received_message(rmPtr: pointer, msg: CMessage): CUnwrapResult {.exp causalHistoryLen: unwrappedMsg.causalHistory.len.cint, channelId: unwrappedMsg.channelId.cstring, content: unwrappedMsg.content.cstring, - bloomFilter: cast[ptr UncheckedArray[byte]](alloc0(unwrappedMsg.bloomFilter.len)), - bloomFilterLen: unwrappedMsg.bloomFilter.len.cint + bloomFilter: cast[pointer](addr unwrappedMsg.bloomFilter) ) for i, id in unwrappedMsg.causalHistory: result.message.causalHistory[i] = id.cstring - copyMem(result.message.bloomFilter, addr unwrappedMsg.bloomFilter[0], unwrappedMsg.bloomFilter.len) result.missingDeps = cast[ptr UncheckedArray[cstring]](alloc0(missingDeps.len * sizeof(cstring))) result.missingDepsLen = missingDeps.len.cint @@ -231,12 +247,17 @@ proc mark_dependencies_met(rmPtr: pointer, messageIds: ptr UncheckedArray[cstrin proc set_callbacks(rmPtr: pointer, onMessageReady: proc(messageId: cstring) {.cdecl.}, onMessageSent: proc(messageId: cstring) {.cdecl.}, - onPeriodicSync: proc() {.cdecl.}) {.exportc, cdecl.} = + onMissingDependencies: proc(messageId: cstring, missingDeps: ptr UncheckedArray[cstring], missingDepsLen: cint) {.cdecl.}) {.exportc, cdecl.} = let rm = cast[ReliabilityManager](rmPtr) rm.setCallbacks( proc(messageId: MessageID) = onMessageReady(messageId.cstring), proc(messageId: MessageID) = onMessageSent(messageId.cstring), - onPeriodicSync + proc(messageId: MessageID, missingDeps: seq[MessageID]) = + var cMissingDeps = cast[ptr UncheckedArray[cstring]](alloc0(missingDeps.len * sizeof(cstring))) + for i, dep in missingDeps: + cMissingDeps[i] = dep.cstring + onMissingDependencies(messageId.cstring, cMissingDeps, missingDeps.len.cint) + dealloc(cMissingDeps) ) {.pop.} diff --git a/tests/test_reliability.nim b/tests/test_reliability.nim index 241ae43..3ddb36f 100644 --- a/tests/test_reliability.nim +++ b/tests/test_reliability.nim @@ -34,21 +34,23 @@ suite "ReliabilityManager": test "callbacks": var messageReadyCount = 0 var messageSentCount = 0 - var periodicSyncCount = 0 + var missingDepsCount = 0 rm.setCallbacks( proc(messageId: MessageID) = messageReadyCount += 1, proc(messageId: MessageID) = messageSentCount += 1, - proc() = periodicSyncCount += 1 + proc(messageId: MessageID, missingDeps: seq[MessageID]) = missingDepsCount += 1 ) - let msg = rm.wrapOutgoingMessage("Test callback") - discard rm.unwrapReceivedMessage(msg) + let msg1 = rm.wrapOutgoingMessage("Message 1") + let msg2 = rm.wrapOutgoingMessage("Message 2") + discard rm.unwrapReceivedMessage(msg1) + discard rm.unwrapReceivedMessage(msg2) check: - messageReadyCount == 1 + messageReadyCount == 2 messageSentCount == 0 # This would be triggered by the checkUnacknowledgedMessages function - periodicSyncCount == 0 # This would be triggered by the periodicSync function + missingDepsCount == 0 test "lamport timestamps": let msg1 = rm.wrapOutgoingMessage("Message 1") @@ -76,4 +78,33 @@ suite "ReliabilityManager": check missingDeps1.len == 0 let (_, missingDeps2) = rm.unwrapReceivedMessage(msg1) - check missingDeps2.len == 0 # The message should be in the bloom filter and not processed again \ No newline at end of file + check missingDeps2.len == 0 # The message should be in the bloom filter and not processed again + + test "message history limit": + for i in 1..MaxMessageHistory + 10: + let msg = rm.wrapOutgoingMessage($i) + discard rm.unwrapReceivedMessage(msg) + + check rm.messageHistory.len <= MaxMessageHistory + + test "missing dependencies callback": + var missingDepsReceived: seq[MessageID] = @[] + rm.setCallbacks( + proc(messageId: MessageID) = discard, + proc(messageId: MessageID) = discard, + proc(messageId: MessageID, missingDeps: seq[MessageID]) = missingDepsReceived = missingDeps + ) + + let msg1 = rm.wrapOutgoingMessage("Message 1") + let msg2 = rm.wrapOutgoingMessage("Message 2") + let msg3 = Message( + messageId: generateUniqueID(), + lamportTimestamp: msg2.lamportTimestamp + 1, + causalHistory: @[msg1.messageId, msg2.messageId], + content: "Message 3" + ) + + discard rm.unwrapReceivedMessage(msg3) + check missingDepsReceived.len == 2 + check missingDepsReceived.contains(msg1.messageId) + check missingDepsReceived.contains(msg2.messageId) \ No newline at end of file