mirror of
https://github.com/logos-messaging/nim-sds.git
synced 2026-01-02 14:13:07 +00:00
feat: review ack status
This commit is contained in:
parent
d92d2f733a
commit
97e2f681b9
@ -10,6 +10,7 @@ type
|
||||
causalHistory*: seq[MessageID]
|
||||
channelId*: string
|
||||
content*: seq[byte]
|
||||
bloomFilter*: seq[byte]
|
||||
|
||||
UnacknowledgedMessage* = object
|
||||
message*: Message
|
||||
|
||||
@ -2,10 +2,7 @@ import ./protobufutil
|
||||
import ./common
|
||||
import libp2p/protobuf/minprotobuf
|
||||
import std/options
|
||||
|
||||
proc toString(bytes: seq[byte]): string =
|
||||
result = newString(bytes.len)
|
||||
copyMem(result[0].addr, bytes[0].unsafeAddr, bytes.len)
|
||||
import "../nim-bloom/src/bloom"
|
||||
|
||||
proc toBytes(s: string): seq[byte] =
|
||||
result = newSeq[byte](s.len)
|
||||
@ -22,6 +19,7 @@ proc encode*(msg: Message): ProtoBuffer =
|
||||
|
||||
pb.write(4, msg.channelId)
|
||||
pb.write(5, msg.content)
|
||||
pb.write(6, msg.bloomFilter)
|
||||
pb.finish()
|
||||
|
||||
pb
|
||||
@ -39,11 +37,10 @@ proc decode*(T: type Message, buffer: seq[byte]): ProtobufResult[T] =
|
||||
msg.lamportTimestamp = int64(timestamp)
|
||||
|
||||
# Decode causal history
|
||||
var histories: seq[seq[byte]]
|
||||
for histBytes in histories:
|
||||
let hist = histBytes.toString
|
||||
if hist notin msg.causalHistory: # Avoid duplicate entries
|
||||
msg.causalHistory.add(hist)
|
||||
var causalHistory: seq[string]
|
||||
let histResult = pb.getRepeatedField(3, causalHistory)
|
||||
if histResult.isOk:
|
||||
msg.causalHistory = causalHistory
|
||||
|
||||
if not ?pb.getField(4, msg.channelId):
|
||||
return err(ProtobufError.missingRequiredField("channelId"))
|
||||
@ -51,6 +48,9 @@ proc decode*(T: type Message, buffer: seq[byte]): ProtobufResult[T] =
|
||||
if not ?pb.getField(5, msg.content):
|
||||
return err(ProtobufError.missingRequiredField("content"))
|
||||
|
||||
if not ?pb.getField(6, msg.bloomFilter):
|
||||
msg.bloomFilter = @[] # Empty if not present
|
||||
|
||||
ok(msg)
|
||||
|
||||
proc serializeMessage*(msg: Message): Result[seq[byte], ReliabilityError] =
|
||||
@ -67,5 +67,58 @@ proc deserializeMessage*(data: seq[byte]): Result[Message, ReliabilityError] =
|
||||
ok(msgResult.get)
|
||||
else:
|
||||
err(reSerializationError)
|
||||
except:
|
||||
err(reDeserializationError)
|
||||
|
||||
proc serializeBloomFilter*(filter: BloomFilter): Result[seq[byte], ReliabilityError] =
|
||||
try:
|
||||
var pb = initProtoBuffer()
|
||||
|
||||
# Convert intArray to bytes
|
||||
var bytes = newSeq[byte](filter.intArray.len * sizeof(int))
|
||||
for i, val in filter.intArray:
|
||||
let start = i * sizeof(int)
|
||||
copyMem(addr bytes[start], unsafeAddr val, sizeof(int))
|
||||
|
||||
pb.write(1, bytes)
|
||||
pb.write(2, uint64(filter.capacity))
|
||||
pb.write(3, uint64(filter.errorRate * 1_000_000))
|
||||
pb.write(4, uint64(filter.kHashes))
|
||||
pb.write(5, uint64(filter.mBits))
|
||||
|
||||
pb.finish()
|
||||
ok(pb.buffer)
|
||||
except:
|
||||
err(reSerializationError)
|
||||
|
||||
proc deserializeBloomFilter*(data: seq[byte]): Result[BloomFilter, ReliabilityError] =
|
||||
if data.len == 0:
|
||||
return err(reDeserializationError)
|
||||
|
||||
try:
|
||||
let pb = initProtoBuffer(data)
|
||||
var bytes: seq[byte]
|
||||
var cap, errRate, kHashes, mBits: uint64
|
||||
|
||||
if not pb.getField(1, bytes).get() or
|
||||
not pb.getField(2, cap).get() or
|
||||
not pb.getField(3, errRate).get() or
|
||||
not pb.getField(4, kHashes).get() or
|
||||
not pb.getField(5, mBits).get():
|
||||
return err(reDeserializationError)
|
||||
|
||||
# Convert bytes back to intArray
|
||||
var intArray = newSeq[int](bytes.len div sizeof(int))
|
||||
for i in 0 ..< intArray.len:
|
||||
let start = i * sizeof(int)
|
||||
copyMem(addr intArray[i], unsafeAddr bytes[start], sizeof(int))
|
||||
|
||||
ok(BloomFilter(
|
||||
intArray: intArray,
|
||||
capacity: int(cap),
|
||||
errorRate: float(errRate) / 1_000_000,
|
||||
kHashes: int(kHashes),
|
||||
mBits: int(mBits)
|
||||
))
|
||||
except:
|
||||
err(reDeserializationError)
|
||||
@ -52,6 +52,39 @@ proc newReliabilityManager*(channelId: string, config: ReliabilityConfig = defau
|
||||
except:
|
||||
return err(reOutOfMemory)
|
||||
|
||||
proc reviewAckStatus(rm: ReliabilityManager, msg: Message) =
|
||||
var i = 0
|
||||
while i < rm.outgoingBuffer.len:
|
||||
var acknowledged = false
|
||||
let outMsg = rm.outgoingBuffer[i]
|
||||
|
||||
# Check if message is in causal history
|
||||
for msgID in msg.causalHistory:
|
||||
if outMsg.message.messageId == msgID:
|
||||
acknowledged = true
|
||||
break
|
||||
|
||||
# Check bloom filter if not already acknowledged
|
||||
if not acknowledged and msg.bloomFilter.len > 0:
|
||||
let bfResult = deserializeBloomFilter(msg.bloomFilter)
|
||||
if bfResult.isOk:
|
||||
var rbf = RollingBloomFilter(
|
||||
filter: bfResult.get(),
|
||||
window: rm.bloomFilter.window,
|
||||
messages: @[]
|
||||
)
|
||||
if rbf.contains(outMsg.message.messageId):
|
||||
acknowledged = true
|
||||
else:
|
||||
logError("Failed to deserialize bloom filter")
|
||||
|
||||
if acknowledged:
|
||||
if rm.onMessageSent != nil:
|
||||
rm.onMessageSent(outMsg.message.messageId)
|
||||
rm.outgoingBuffer.delete(i)
|
||||
else:
|
||||
inc i
|
||||
|
||||
proc wrapOutgoingMessage*(rm: ReliabilityManager, message: seq[byte], messageId: MessageID): Result[seq[byte], ReliabilityError] =
|
||||
## Wraps an outgoing message with reliability metadata.
|
||||
##
|
||||
@ -68,16 +101,35 @@ proc wrapOutgoingMessage*(rm: ReliabilityManager, message: seq[byte], messageId:
|
||||
withLock rm.lock:
|
||||
try:
|
||||
rm.updateLamportTimestamp(getTime().toUnix)
|
||||
|
||||
# Serialize current bloom filter
|
||||
var bloomBytes: seq[byte]
|
||||
let bfResult = serializeBloomFilter(rm.bloomFilter.filter)
|
||||
if bfResult.isErr:
|
||||
logError("Failed to serialize bloom filter")
|
||||
bloomBytes = @[]
|
||||
else:
|
||||
bloomBytes = bfResult.get()
|
||||
|
||||
let msg = Message(
|
||||
messageId: messageId,
|
||||
lamportTimestamp: rm.lamportTimestamp,
|
||||
causalHistory: rm.getRecentMessageIDs(rm.config.maxCausalHistory),
|
||||
channelId: rm.channelId,
|
||||
content: message
|
||||
content: message,
|
||||
bloomFilter: bloomBytes
|
||||
)
|
||||
rm.outgoingBuffer.add(UnacknowledgedMessage(message: msg, sendTime: getTime(), resendAttempts: 0))
|
||||
# rm.messageHistory.add(messageId)
|
||||
# rm.bloomFilter.add(messageId)
|
||||
|
||||
# Add to outgoing buffer
|
||||
rm.outgoingBuffer.add(UnacknowledgedMessage(
|
||||
message: msg,
|
||||
sendTime: getTime(),
|
||||
resendAttempts: 0
|
||||
))
|
||||
|
||||
# Add to causal history and bloom filter
|
||||
rm.addToBloomAndHistory(msg)
|
||||
|
||||
return serializeMessage(msg)
|
||||
except:
|
||||
return err(reInternalError)
|
||||
@ -100,21 +152,24 @@ proc unwrapReceivedMessage*(rm: ReliabilityManager, message: seq[byte]): Result[
|
||||
if rm.bloomFilter.contains(msg.messageId):
|
||||
return ok((msg.content, @[]))
|
||||
|
||||
rm.bloomFilter.add(msg.messageId)
|
||||
# Update Lamport timestamp
|
||||
rm.updateLamportTimestamp(msg.lamportTimestamp)
|
||||
|
||||
# Review ACK status for outgoing messages
|
||||
rm.reviewAckStatus(msg)
|
||||
|
||||
var missingDeps: seq[MessageID] = @[]
|
||||
for depId in msg.causalHistory:
|
||||
if not rm.bloomFilter.contains(depId):
|
||||
missingDeps.add(depId)
|
||||
|
||||
if missingDeps.len == 0:
|
||||
rm.messageHistory.add(msg.messageId)
|
||||
if rm.messageHistory.len > rm.config.maxMessageHistory:
|
||||
rm.messageHistory.delete(0)
|
||||
# All dependencies met, add to history
|
||||
rm.addToBloomAndHistory(msg)
|
||||
if rm.onMessageReady != nil:
|
||||
rm.onMessageReady(msg.messageId)
|
||||
else:
|
||||
# Buffer message and request missing dependencies
|
||||
rm.incomingBuffer.add(msg)
|
||||
if rm.onMissingDependencies != nil:
|
||||
rm.onMissingDependencies(msg.messageId, missingDeps)
|
||||
@ -136,6 +191,12 @@ proc markDependenciesMet*(rm: ReliabilityManager, messageIds: seq[MessageID]): R
|
||||
var processedMessages: seq[Message] = @[]
|
||||
var newIncomingBuffer: seq[Message] = @[]
|
||||
|
||||
# Add all messageIds to both bloom filter and causal history
|
||||
for msgId in messageIds:
|
||||
if not rm.bloomFilter.contains(msgId):
|
||||
rm.bloomFilter.add(msgId)
|
||||
rm.messageHistory.add(msgId)
|
||||
|
||||
for msg in rm.incomingBuffer:
|
||||
var allDependenciesMet = true
|
||||
for depId in msg.causalHistory:
|
||||
@ -145,15 +206,13 @@ proc markDependenciesMet*(rm: ReliabilityManager, messageIds: seq[MessageID]): R
|
||||
|
||||
if allDependenciesMet:
|
||||
processedMessages.add(msg)
|
||||
rm.addToBloomAndHistory(msg)
|
||||
else:
|
||||
newIncomingBuffer.add(msg)
|
||||
|
||||
rm.incomingBuffer = newIncomingBuffer
|
||||
|
||||
for msg in processedMessages:
|
||||
rm.messageHistory.add(msg.messageId)
|
||||
if rm.messageHistory.len > rm.config.maxMessageHistory:
|
||||
rm.messageHistory.delete(0)
|
||||
if rm.onMessageReady != nil:
|
||||
rm.onMessageReady(msg.messageId)
|
||||
|
||||
|
||||
@ -75,6 +75,12 @@ proc cleanBloomFilter*(rm: ReliabilityManager) {.gcsafe, raises: [].} =
|
||||
except Exception as e:
|
||||
logError("Failed to clean ReliabilityManager bloom filter: " & e.msg)
|
||||
|
||||
proc addToBloomAndHistory*(rm: ReliabilityManager, msg: Message) =
|
||||
rm.messageHistory.add(msg.messageId)
|
||||
if rm.messageHistory.len > rm.config.maxMessageHistory:
|
||||
rm.messageHistory.delete(0)
|
||||
rm.bloomFilter.add(msg.messageId)
|
||||
|
||||
proc updateLamportTimestamp*(rm: ReliabilityManager, msgTs: int64) =
|
||||
rm.lamportTimestamp = max(msgTs, rm.lamportTimestamp) + 1
|
||||
|
||||
|
||||
@ -1,6 +1,8 @@
|
||||
import unittest, results, chronos, chronicles
|
||||
import unittest, results, chronos
|
||||
import ../src/reliability
|
||||
import ../src/common
|
||||
import ../src/protobuf
|
||||
import ../src/utils
|
||||
|
||||
suite "ReliabilityManager":
|
||||
var rm: ReliabilityManager
|
||||
@ -41,38 +43,58 @@ suite "ReliabilityManager":
|
||||
unwrapped == msg
|
||||
missingDeps.len == 0
|
||||
|
||||
test "markDependenciesMet":
|
||||
# First message
|
||||
let msg1 = @[byte(1)]
|
||||
test "marking dependencies":
|
||||
var messageReadyCount = 0
|
||||
var messageSentCount = 0
|
||||
var missingDepsCount = 0
|
||||
|
||||
rm.setCallbacks(
|
||||
proc(messageId: MessageID) {.gcsafe.} = messageReadyCount += 1,
|
||||
proc(messageId: MessageID) {.gcsafe.} = messageSentCount += 1,
|
||||
proc(messageId: MessageID, missingDeps: seq[MessageID]) {.gcsafe.} = missingDepsCount += 1
|
||||
)
|
||||
|
||||
# We'll create dependency IDs that aren't in the bloom filter yet
|
||||
let id1 = "msg1"
|
||||
let wrap1 = rm.wrapOutgoingMessage(msg1, id1)
|
||||
check wrap1.isOk()
|
||||
let wrapped1 = wrap1.get()
|
||||
|
||||
# Second message
|
||||
let msg2 = @[byte(2)]
|
||||
let id2 = "msg2"
|
||||
let wrap2 = rm.wrapOutgoingMessage(msg2, id2)
|
||||
check wrap2.isOk()
|
||||
let wrapped2 = wrap2.get()
|
||||
|
||||
# Third message
|
||||
let msg3 = @[byte(3)]
|
||||
let id3 = "msg3"
|
||||
let wrap3 = rm.wrapOutgoingMessage(msg3, id3)
|
||||
check wrap3.isOk()
|
||||
let wrapped3 = wrap3.get()
|
||||
# Create a message that depends on these IDs
|
||||
let msg3 = Message(
|
||||
messageId: "msg3",
|
||||
lamportTimestamp: 1,
|
||||
causalHistory: @[id1, id2], # Depends on messages we haven't seen
|
||||
channelId: "testChannel",
|
||||
content: @[byte(3)],
|
||||
bloomFilter: @[]
|
||||
)
|
||||
|
||||
# Check dependencies
|
||||
var unwrap3 = rm.unwrapReceivedMessage(wrapped3)
|
||||
check unwrap3.isOk()
|
||||
var (_, missing3) = unwrap3.get()
|
||||
let serializedMsg3 = serializeMessage(msg3)
|
||||
check serializedMsg3.isOk()
|
||||
|
||||
# Mark dependencies as met
|
||||
let markResult = rm.markDependenciesMet(@[id1, id2])
|
||||
# Process the message - should identify missing dependencies
|
||||
let unwrapResult = rm.unwrapReceivedMessage(serializedMsg3.get())
|
||||
check unwrapResult.isOk()
|
||||
let (_, missingDeps) = unwrapResult.get()
|
||||
|
||||
# Verify missing dependencies were identified
|
||||
check missingDepsCount == 1
|
||||
check missingDeps.len == 2
|
||||
check id1 in missingDeps
|
||||
check id2 in missingDeps
|
||||
|
||||
# Now mark dependencies as met
|
||||
let markResult = rm.markDependenciesMet(missingDeps)
|
||||
check markResult.isOk()
|
||||
|
||||
check missing3.len == 0
|
||||
# Process the message again - should now be ready
|
||||
let reprocessResult = rm.unwrapReceivedMessage(serializedMsg3.get())
|
||||
check reprocessResult.isOk()
|
||||
let (_, remainingDeps) = reprocessResult.get()
|
||||
|
||||
# Verify message is now processed
|
||||
check remainingDeps.len == 0
|
||||
check messageReadyCount == 1 # msg3 should now be ready
|
||||
check missingDepsCount == 1 # Only the first attempt should report missing deps
|
||||
|
||||
test "callbacks work correctly":
|
||||
var messageReadyCount = 0
|
||||
@ -85,18 +107,76 @@ suite "ReliabilityManager":
|
||||
proc(messageId: MessageID, missingDeps: seq[MessageID]) {.gcsafe.} = missingDepsCount += 1
|
||||
)
|
||||
|
||||
let msg1Result = rm.wrapOutgoingMessage(@[byte(1)], "msg1")
|
||||
let msg2Result = rm.wrapOutgoingMessage(@[byte(2)], "msg2")
|
||||
check msg1Result.isOk() and msg2Result.isOk()
|
||||
let msg1 = msg1Result.get()
|
||||
let msg2 = msg2Result.get()
|
||||
discard rm.unwrapReceivedMessage(msg1)
|
||||
discard rm.unwrapReceivedMessage(msg2)
|
||||
# First send our own message
|
||||
let msg1 = @[byte(1)]
|
||||
let id1 = "msg1"
|
||||
let wrap1 = rm.wrapOutgoingMessage(msg1, id1)
|
||||
check wrap1.isOk()
|
||||
|
||||
check:
|
||||
messageReadyCount == 2
|
||||
messageSentCount == 0 # This would be triggered by checkUnacknowledgedMessages
|
||||
missingDepsCount == 0
|
||||
# Create a message that has our message in causal history
|
||||
let msg2 = Message(
|
||||
messageId: "msg2",
|
||||
lamportTimestamp: rm.lamportTimestamp + 1,
|
||||
causalHistory: @[id1], # Include our message in causal history
|
||||
channelId: "testChannel",
|
||||
content: @[byte(2)],
|
||||
bloomFilter: @[] # Test with an empty bloom filter
|
||||
)
|
||||
|
||||
let serializedMsg2 = serializeMessage(msg2)
|
||||
check serializedMsg2.isOk()
|
||||
|
||||
# Process the "received" message - should trigger callbacks
|
||||
let unwrapResult = rm.unwrapReceivedMessage(serializedMsg2.get())
|
||||
check unwrapResult.isOk()
|
||||
|
||||
check messageReadyCount == 1 # For msg2 which we "received"
|
||||
check messageSentCount == 1 # For msg1 which was acknowledged via causal history
|
||||
|
||||
test "bloom filter acknowledgment":
|
||||
var messageSentCount = 0
|
||||
|
||||
rm.setCallbacks(
|
||||
proc(messageId: MessageID) {.gcsafe.} = discard,
|
||||
proc(messageId: MessageID) {.gcsafe.} = messageSentCount += 1,
|
||||
proc(messageId: MessageID, missingDeps: seq[MessageID]) {.gcsafe.} = discard
|
||||
)
|
||||
|
||||
# First send our own message
|
||||
let msg1 = @[byte(1)]
|
||||
let id1 = "msg1"
|
||||
let wrap1 = rm.wrapOutgoingMessage(msg1, id1)
|
||||
check wrap1.isOk()
|
||||
|
||||
# Create a message simulating another party's message
|
||||
# with bloom filter containing our message
|
||||
var otherPartyBloomFilter = newRollingBloomFilter(
|
||||
DefaultBloomFilterCapacity,
|
||||
DefaultBloomFilterErrorRate,
|
||||
DefaultBloomFilterWindow
|
||||
)
|
||||
otherPartyBloomFilter.add(id1) # Add our message to their bloom filter
|
||||
|
||||
let bfResult = serializeBloomFilter(otherPartyBloomFilter.filter)
|
||||
check bfResult.isOk()
|
||||
|
||||
let msg2 = Message(
|
||||
messageId: "msg2",
|
||||
lamportTimestamp: rm.lamportTimestamp + 1,
|
||||
causalHistory: @[], # Empty causal history as we're using bloom filter
|
||||
channelId: "testChannel",
|
||||
content: @[byte(2)],
|
||||
bloomFilter: bfResult.get()
|
||||
)
|
||||
|
||||
let serializedMsg2 = serializeMessage(msg2)
|
||||
check serializedMsg2.isOk()
|
||||
|
||||
# Process the "received" message - should trigger acknowledgment
|
||||
let unwrapResult = rm.unwrapReceivedMessage(serializedMsg2.get())
|
||||
check unwrapResult.isOk()
|
||||
|
||||
check messageSentCount == 1 # Our message should be acknowledged via bloom filter
|
||||
|
||||
test "periodic sync callback works":
|
||||
var syncCallCount = 0
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user