missing dependencies and simplified bloom filter

This commit is contained in:
shash256 2024-10-18 13:11:02 +04:00
parent 7ba27717f1
commit 297516995c
2 changed files with 93 additions and 41 deletions

View File

@ -1,7 +1,13 @@
import std/[times, sets, hashes, random, sequtils, algorithm] import std/[times, hashes, random, sequtils, algorithm]
import nimsha2 import nimsha2
import chronicles import chronicles
const
BloomFilterSize = 10000
BloomFilterHashCount = 7
MaxMessageHistory = 100
MaxCausalHistory = 10
type type
MessageID* = string MessageID* = string
@ -12,20 +18,17 @@ type
causalHistory*: seq[MessageID] causalHistory*: seq[MessageID]
channelId*: string channelId*: string
content*: string content*: string
bloomFilter*: seq[byte] bloomFilter*: RollingBloomFilter
UnacknowledgedMessage* = object UnacknowledgedMessage* = object
message*: Message message*: Message
sendTime*: Time sendTime*: Time
resendAttempts*: int resendAttempts*: int
TimestampedMessageID* = object
id*: MessageID
timestamp*: Time
RollingBloomFilter* = object RollingBloomFilter* = object
# TODO: Implement a proper Bloom filter # TODO: Implement a proper Bloom filter
data: HashSet[MessageID] data: array[BloomFilterSize, bool]
hashCount: int
ReliabilityManager* = ref object ReliabilityManager* = ref object
lamportTimestamp: int64 lamportTimestamp: int64
@ -36,16 +39,30 @@ type
channelId: string channelId: string
onMessageReady*: proc(messageId: MessageID) onMessageReady*: proc(messageId: MessageID)
onMessageSent*: proc(messageId: MessageID) onMessageSent*: proc(messageId: MessageID)
onPeriodicSync*: proc() onMissingDependencies*: proc(messageId: MessageID, missingDeps: seq[MessageID])
proc hash(filter: RollingBloomFilter): Hash =
var h: Hash = 0
for idx, val in filter.data:
h = h !& hash(idx) !& hash(val)
result = !$h
proc newRollingBloomFilter(): RollingBloomFilter = proc newRollingBloomFilter(): RollingBloomFilter =
result.data = initHashSet[MessageID]() result.hashCount = BloomFilterHashCount
proc add(filter: var RollingBloomFilter, item: MessageID) = proc add(filter: var RollingBloomFilter, item: MessageID) =
filter.data.incl(item) let itemHash = hash(item)
for i in 0 ..< filter.hashCount:
let idx = (itemHash + i * i) mod BloomFilterSize
filter.data[idx] = true
proc contains(filter: RollingBloomFilter, item: MessageID): bool = proc contains(filter: RollingBloomFilter, item: MessageID): bool =
item in filter.data let itemHash = hash(item)
for i in 0 ..< filter.hashCount:
let idx = (itemHash + i * i) mod BloomFilterSize
if not filter.data[idx]:
return false
return true
proc newReliabilityManager*(channelId: string): ReliabilityManager = proc newReliabilityManager*(channelId: string): ReliabilityManager =
result = ReliabilityManager( result = ReliabilityManager(
@ -61,7 +78,7 @@ proc generateUniqueID(): MessageID =
$secureHash($getTime().toUnix & $rand(high(int))) $secureHash($getTime().toUnix & $rand(high(int)))
proc updateLamportTimestamp(rm: ReliabilityManager, msgTs: int64) = proc updateLamportTimestamp(rm: ReliabilityManager, msgTs: int64) =
rm.lamportTimestamp = max(msgTs, rm.lamportTimestamp + 1) rm.lamportTimestamp = max(msgTs, rm.lamportTimestamp) + 1
proc getRecentMessageIDs(rm: ReliabilityManager, n: int): seq[MessageID] = proc getRecentMessageIDs(rm: ReliabilityManager, n: int): seq[MessageID] =
result = rm.messageHistory[max(0, rm.messageHistory.len - n) .. ^1] result = rm.messageHistory[max(0, rm.messageHistory.len - n) .. ^1]
@ -72,10 +89,10 @@ proc wrapOutgoingMessage*(rm: ReliabilityManager, message: string): Message =
senderId: "TODO_SENDER_ID", senderId: "TODO_SENDER_ID",
messageId: generateUniqueID(), messageId: generateUniqueID(),
lamportTimestamp: rm.lamportTimestamp, lamportTimestamp: rm.lamportTimestamp,
causalHistory: rm.getRecentMessageIDs(10), causalHistory: rm.getRecentMessageIDs(MaxCausalHistory),
channelId: rm.channelId, channelId: rm.channelId,
content: message, content: message,
bloomFilter: @[] # TODO: Implement proper Bloom filter serialization bloomFilter: rm.bloomFilter
) )
rm.outgoingBuffer.add(UnacknowledgedMessage(message: msg, sendTime: getTime(), resendAttempts: 0)) rm.outgoingBuffer.add(UnacknowledgedMessage(message: msg, sendTime: getTime(), resendAttempts: 0))
msg msg
@ -94,10 +111,14 @@ proc unwrapReceivedMessage*(rm: ReliabilityManager, message: Message): tuple[mes
if missingDeps.len == 0: if missingDeps.len == 0:
rm.messageHistory.add(message.messageId) rm.messageHistory.add(message.messageId)
if rm.messageHistory.len > MaxMessageHistory:
rm.messageHistory.delete(0)
if rm.onMessageReady != nil: if rm.onMessageReady != nil:
rm.onMessageReady(message.messageId) rm.onMessageReady(message.messageId)
else: else:
rm.incomingBuffer.add(message) rm.incomingBuffer.add(message)
if rm.onMissingDependencies != nil:
rm.onMissingDependencies(message.messageId, missingDeps)
(message, missingDeps) (message, missingDeps)
@ -109,27 +130,28 @@ proc markDependenciesMet*(rm: ReliabilityManager, messageIds: seq[MessageID]) =
for msg in processedMessages: for msg in processedMessages:
rm.messageHistory.add(msg.messageId) rm.messageHistory.add(msg.messageId)
if rm.messageHistory.len > MaxMessageHistory:
rm.messageHistory.delete(0)
if rm.onMessageReady != nil: if rm.onMessageReady != nil:
rm.onMessageReady(msg.messageId) rm.onMessageReady(msg.messageId)
proc checkUnacknowledgedMessages(rm: ReliabilityManager) = proc checkUnacknowledgedMessages(rm: ReliabilityManager) =
let now = getTime() let now = getTime()
rm.outgoingBuffer = rm.outgoingBuffer.filterIt((now - it.sendTime).inSeconds < 60) var newOutgoingBuffer: seq[UnacknowledgedMessage] = @[]
for msg in rm.outgoingBuffer: for msg in rm.outgoingBuffer:
if rm.onMessageSent != nil: if (now - msg.sendTime).inSeconds < 60:
newOutgoingBuffer.add(msg)
elif rm.onMessageSent != nil:
rm.onMessageSent(msg.message.messageId) rm.onMessageSent(msg.message.messageId)
rm.outgoingBuffer = newOutgoingBuffer
proc periodicSync(rm: ReliabilityManager) =
if rm.onPeriodicSync != nil:
rm.onPeriodicSync()
proc setCallbacks*(rm: ReliabilityManager, proc setCallbacks*(rm: ReliabilityManager,
onMessageReady: proc(messageId: MessageID), onMessageReady: proc(messageId: MessageID),
onMessageSent: proc(messageId: MessageID), onMessageSent: proc(messageId: MessageID),
onPeriodicSync: proc()) = onMissingDependencies: proc(messageId: MessageID, missingDeps: seq[MessageID])) =
rm.onMessageReady = onMessageReady rm.onMessageReady = onMessageReady
rm.onMessageSent = onMessageSent rm.onMessageSent = onMessageSent
rm.onPeriodicSync = onPeriodicSync rm.onMissingDependencies = onMissingDependencies
# Logging # Logging
proc logInfo(msg: string) = proc logInfo(msg: string) =
@ -150,8 +172,7 @@ type
causalHistoryLen: cint causalHistoryLen: cint
channelId: cstring channelId: cstring
content: cstring content: cstring
bloomFilter: ptr UncheckedArray[byte] bloomFilter: pointer
bloomFilterLen: cint
CUnwrapResult {.bycopy.} = object CUnwrapResult {.bycopy.} = object
message: CMessage message: CMessage
@ -180,9 +201,7 @@ proc wrap_outgoing_message(rmPtr: pointer, message: cstring): CMessage {.exportc
result.causalHistory[i] = id.cstring result.causalHistory[i] = id.cstring
result.channelId = wrappedMsg.channelId.cstring result.channelId = wrappedMsg.channelId.cstring
result.content = wrappedMsg.content.cstring result.content = wrappedMsg.content.cstring
result.bloomFilter = cast[ptr UncheckedArray[byte]](alloc0(wrappedMsg.bloomFilter.len)) result.bloomFilter = cast[pointer](addr wrappedMsg.bloomFilter)
result.bloomFilterLen = wrappedMsg.bloomFilter.len.cint
copyMem(result.bloomFilter, addr wrappedMsg.bloomFilter[0], wrappedMsg.bloomFilter.len)
proc unwrap_received_message(rmPtr: pointer, msg: CMessage): CUnwrapResult {.exportc, cdecl.} = proc unwrap_received_message(rmPtr: pointer, msg: CMessage): CUnwrapResult {.exportc, cdecl.} =
let rm = cast[ReliabilityManager](rmPtr) let rm = cast[ReliabilityManager](rmPtr)
@ -193,11 +212,10 @@ proc unwrap_received_message(rmPtr: pointer, msg: CMessage): CUnwrapResult {.exp
causalHistory: newSeq[string](msg.causalHistoryLen), causalHistory: newSeq[string](msg.causalHistoryLen),
channelId: $msg.channelId, channelId: $msg.channelId,
content: $msg.content, content: $msg.content,
bloomFilter: newSeq[byte](msg.bloomFilterLen) bloomFilter: cast[RollingBloomFilter](msg.bloomFilter)[]
) )
for i in 0 ..< msg.causalHistoryLen: for i in 0 ..< msg.causalHistoryLen:
nimMsg.causalHistory[i] = $msg.causalHistory[i] nimMsg.causalHistory[i] = $msg.causalHistory[i]
copyMem(addr nimMsg.bloomFilter[0], msg.bloomFilter, msg.bloomFilterLen)
let (unwrappedMsg, missingDeps) = rm.unwrapReceivedMessage(nimMsg) let (unwrappedMsg, missingDeps) = rm.unwrapReceivedMessage(nimMsg)
@ -209,12 +227,10 @@ proc unwrap_received_message(rmPtr: pointer, msg: CMessage): CUnwrapResult {.exp
causalHistoryLen: unwrappedMsg.causalHistory.len.cint, causalHistoryLen: unwrappedMsg.causalHistory.len.cint,
channelId: unwrappedMsg.channelId.cstring, channelId: unwrappedMsg.channelId.cstring,
content: unwrappedMsg.content.cstring, content: unwrappedMsg.content.cstring,
bloomFilter: cast[ptr UncheckedArray[byte]](alloc0(unwrappedMsg.bloomFilter.len)), bloomFilter: cast[pointer](addr unwrappedMsg.bloomFilter)
bloomFilterLen: unwrappedMsg.bloomFilter.len.cint
) )
for i, id in unwrappedMsg.causalHistory: for i, id in unwrappedMsg.causalHistory:
result.message.causalHistory[i] = id.cstring result.message.causalHistory[i] = id.cstring
copyMem(result.message.bloomFilter, addr unwrappedMsg.bloomFilter[0], unwrappedMsg.bloomFilter.len)
result.missingDeps = cast[ptr UncheckedArray[cstring]](alloc0(missingDeps.len * sizeof(cstring))) result.missingDeps = cast[ptr UncheckedArray[cstring]](alloc0(missingDeps.len * sizeof(cstring)))
result.missingDepsLen = missingDeps.len.cint result.missingDepsLen = missingDeps.len.cint
@ -231,12 +247,17 @@ proc mark_dependencies_met(rmPtr: pointer, messageIds: ptr UncheckedArray[cstrin
proc set_callbacks(rmPtr: pointer, proc set_callbacks(rmPtr: pointer,
onMessageReady: proc(messageId: cstring) {.cdecl.}, onMessageReady: proc(messageId: cstring) {.cdecl.},
onMessageSent: proc(messageId: cstring) {.cdecl.}, onMessageSent: proc(messageId: cstring) {.cdecl.},
onPeriodicSync: proc() {.cdecl.}) {.exportc, cdecl.} = onMissingDependencies: proc(messageId: cstring, missingDeps: ptr UncheckedArray[cstring], missingDepsLen: cint) {.cdecl.}) {.exportc, cdecl.} =
let rm = cast[ReliabilityManager](rmPtr) let rm = cast[ReliabilityManager](rmPtr)
rm.setCallbacks( rm.setCallbacks(
proc(messageId: MessageID) = onMessageReady(messageId.cstring), proc(messageId: MessageID) = onMessageReady(messageId.cstring),
proc(messageId: MessageID) = onMessageSent(messageId.cstring), proc(messageId: MessageID) = onMessageSent(messageId.cstring),
onPeriodicSync proc(messageId: MessageID, missingDeps: seq[MessageID]) =
var cMissingDeps = cast[ptr UncheckedArray[cstring]](alloc0(missingDeps.len * sizeof(cstring)))
for i, dep in missingDeps:
cMissingDeps[i] = dep.cstring
onMissingDependencies(messageId.cstring, cMissingDeps, missingDeps.len.cint)
dealloc(cMissingDeps)
) )
{.pop.} {.pop.}

View File

@ -34,21 +34,23 @@ suite "ReliabilityManager":
test "callbacks": test "callbacks":
var messageReadyCount = 0 var messageReadyCount = 0
var messageSentCount = 0 var messageSentCount = 0
var periodicSyncCount = 0 var missingDepsCount = 0
rm.setCallbacks( rm.setCallbacks(
proc(messageId: MessageID) = messageReadyCount += 1, proc(messageId: MessageID) = messageReadyCount += 1,
proc(messageId: MessageID) = messageSentCount += 1, proc(messageId: MessageID) = messageSentCount += 1,
proc() = periodicSyncCount += 1 proc(messageId: MessageID, missingDeps: seq[MessageID]) = missingDepsCount += 1
) )
let msg = rm.wrapOutgoingMessage("Test callback") let msg1 = rm.wrapOutgoingMessage("Message 1")
discard rm.unwrapReceivedMessage(msg) let msg2 = rm.wrapOutgoingMessage("Message 2")
discard rm.unwrapReceivedMessage(msg1)
discard rm.unwrapReceivedMessage(msg2)
check: check:
messageReadyCount == 1 messageReadyCount == 2
messageSentCount == 0 # This would be triggered by the checkUnacknowledgedMessages function messageSentCount == 0 # This would be triggered by the checkUnacknowledgedMessages function
periodicSyncCount == 0 # This would be triggered by the periodicSync function missingDepsCount == 0
test "lamport timestamps": test "lamport timestamps":
let msg1 = rm.wrapOutgoingMessage("Message 1") let msg1 = rm.wrapOutgoingMessage("Message 1")
@ -76,4 +78,33 @@ suite "ReliabilityManager":
check missingDeps1.len == 0 check missingDeps1.len == 0
let (_, missingDeps2) = rm.unwrapReceivedMessage(msg1) let (_, missingDeps2) = rm.unwrapReceivedMessage(msg1)
check missingDeps2.len == 0 # The message should be in the bloom filter and not processed again check missingDeps2.len == 0 # The message should be in the bloom filter and not processed again
test "message history limit":
for i in 1..MaxMessageHistory + 10:
let msg = rm.wrapOutgoingMessage($i)
discard rm.unwrapReceivedMessage(msg)
check rm.messageHistory.len <= MaxMessageHistory
test "missing dependencies callback":
var missingDepsReceived: seq[MessageID] = @[]
rm.setCallbacks(
proc(messageId: MessageID) = discard,
proc(messageId: MessageID) = discard,
proc(messageId: MessageID, missingDeps: seq[MessageID]) = missingDepsReceived = missingDeps
)
let msg1 = rm.wrapOutgoingMessage("Message 1")
let msg2 = rm.wrapOutgoingMessage("Message 2")
let msg3 = Message(
messageId: generateUniqueID(),
lamportTimestamp: msg2.lamportTimestamp + 1,
causalHistory: @[msg1.messageId, msg2.messageId],
content: "Message 3"
)
discard rm.unwrapReceivedMessage(msg3)
check missingDepsReceived.len == 2
check missingDepsReceived.contains(msg1.messageId)
check missingDepsReceived.contains(msg2.messageId)