feat: review ack status

This commit is contained in:
shash256 2024-12-12 16:04:24 +04:00
parent d92d2f733a
commit 97e2f681b9
5 changed files with 256 additions and 57 deletions

View File

@ -10,6 +10,7 @@ type
causalHistory*: seq[MessageID]
channelId*: string
content*: seq[byte]
bloomFilter*: seq[byte]
UnacknowledgedMessage* = object
message*: Message

View File

@ -2,10 +2,7 @@ import ./protobufutil
import ./common
import libp2p/protobuf/minprotobuf
import std/options
proc toString(bytes: seq[byte]): string =
result = newString(bytes.len)
copyMem(result[0].addr, bytes[0].unsafeAddr, bytes.len)
import "../nim-bloom/src/bloom"
proc toBytes(s: string): seq[byte] =
result = newSeq[byte](s.len)
@ -22,6 +19,7 @@ proc encode*(msg: Message): ProtoBuffer =
pb.write(4, msg.channelId)
pb.write(5, msg.content)
pb.write(6, msg.bloomFilter)
pb.finish()
pb
@ -39,11 +37,10 @@ proc decode*(T: type Message, buffer: seq[byte]): ProtobufResult[T] =
msg.lamportTimestamp = int64(timestamp)
# Decode causal history
var histories: seq[seq[byte]]
for histBytes in histories:
let hist = histBytes.toString
if hist notin msg.causalHistory: # Avoid duplicate entries
msg.causalHistory.add(hist)
var causalHistory: seq[string]
let histResult = pb.getRepeatedField(3, causalHistory)
if histResult.isOk:
msg.causalHistory = causalHistory
if not ?pb.getField(4, msg.channelId):
return err(ProtobufError.missingRequiredField("channelId"))
@ -51,6 +48,9 @@ proc decode*(T: type Message, buffer: seq[byte]): ProtobufResult[T] =
if not ?pb.getField(5, msg.content):
return err(ProtobufError.missingRequiredField("content"))
if not ?pb.getField(6, msg.bloomFilter):
msg.bloomFilter = @[] # Empty if not present
ok(msg)
proc serializeMessage*(msg: Message): Result[seq[byte], ReliabilityError] =
@ -67,5 +67,58 @@ proc deserializeMessage*(data: seq[byte]): Result[Message, ReliabilityError] =
ok(msgResult.get)
else:
err(reSerializationError)
except:
err(reDeserializationError)
proc serializeBloomFilter*(filter: BloomFilter): Result[seq[byte], ReliabilityError] =
try:
var pb = initProtoBuffer()
# Convert intArray to bytes
var bytes = newSeq[byte](filter.intArray.len * sizeof(int))
for i, val in filter.intArray:
let start = i * sizeof(int)
copyMem(addr bytes[start], unsafeAddr val, sizeof(int))
pb.write(1, bytes)
pb.write(2, uint64(filter.capacity))
pb.write(3, uint64(filter.errorRate * 1_000_000))
pb.write(4, uint64(filter.kHashes))
pb.write(5, uint64(filter.mBits))
pb.finish()
ok(pb.buffer)
except:
err(reSerializationError)
proc deserializeBloomFilter*(data: seq[byte]): Result[BloomFilter, ReliabilityError] =
if data.len == 0:
return err(reDeserializationError)
try:
let pb = initProtoBuffer(data)
var bytes: seq[byte]
var cap, errRate, kHashes, mBits: uint64
if not pb.getField(1, bytes).get() or
not pb.getField(2, cap).get() or
not pb.getField(3, errRate).get() or
not pb.getField(4, kHashes).get() or
not pb.getField(5, mBits).get():
return err(reDeserializationError)
# Convert bytes back to intArray
var intArray = newSeq[int](bytes.len div sizeof(int))
for i in 0 ..< intArray.len:
let start = i * sizeof(int)
copyMem(addr intArray[i], unsafeAddr bytes[start], sizeof(int))
ok(BloomFilter(
intArray: intArray,
capacity: int(cap),
errorRate: float(errRate) / 1_000_000,
kHashes: int(kHashes),
mBits: int(mBits)
))
except:
err(reDeserializationError)

View File

@ -52,6 +52,39 @@ proc newReliabilityManager*(channelId: string, config: ReliabilityConfig = defau
except:
return err(reOutOfMemory)
proc reviewAckStatus(rm: ReliabilityManager, msg: Message) =
var i = 0
while i < rm.outgoingBuffer.len:
var acknowledged = false
let outMsg = rm.outgoingBuffer[i]
# Check if message is in causal history
for msgID in msg.causalHistory:
if outMsg.message.messageId == msgID:
acknowledged = true
break
# Check bloom filter if not already acknowledged
if not acknowledged and msg.bloomFilter.len > 0:
let bfResult = deserializeBloomFilter(msg.bloomFilter)
if bfResult.isOk:
var rbf = RollingBloomFilter(
filter: bfResult.get(),
window: rm.bloomFilter.window,
messages: @[]
)
if rbf.contains(outMsg.message.messageId):
acknowledged = true
else:
logError("Failed to deserialize bloom filter")
if acknowledged:
if rm.onMessageSent != nil:
rm.onMessageSent(outMsg.message.messageId)
rm.outgoingBuffer.delete(i)
else:
inc i
proc wrapOutgoingMessage*(rm: ReliabilityManager, message: seq[byte], messageId: MessageID): Result[seq[byte], ReliabilityError] =
## Wraps an outgoing message with reliability metadata.
##
@ -68,16 +101,35 @@ proc wrapOutgoingMessage*(rm: ReliabilityManager, message: seq[byte], messageId:
withLock rm.lock:
try:
rm.updateLamportTimestamp(getTime().toUnix)
# Serialize current bloom filter
var bloomBytes: seq[byte]
let bfResult = serializeBloomFilter(rm.bloomFilter.filter)
if bfResult.isErr:
logError("Failed to serialize bloom filter")
bloomBytes = @[]
else:
bloomBytes = bfResult.get()
let msg = Message(
messageId: messageId,
lamportTimestamp: rm.lamportTimestamp,
causalHistory: rm.getRecentMessageIDs(rm.config.maxCausalHistory),
channelId: rm.channelId,
content: message
content: message,
bloomFilter: bloomBytes
)
rm.outgoingBuffer.add(UnacknowledgedMessage(message: msg, sendTime: getTime(), resendAttempts: 0))
# rm.messageHistory.add(messageId)
# rm.bloomFilter.add(messageId)
# Add to outgoing buffer
rm.outgoingBuffer.add(UnacknowledgedMessage(
message: msg,
sendTime: getTime(),
resendAttempts: 0
))
# Add to causal history and bloom filter
rm.addToBloomAndHistory(msg)
return serializeMessage(msg)
except:
return err(reInternalError)
@ -100,21 +152,24 @@ proc unwrapReceivedMessage*(rm: ReliabilityManager, message: seq[byte]): Result[
if rm.bloomFilter.contains(msg.messageId):
return ok((msg.content, @[]))
rm.bloomFilter.add(msg.messageId)
# Update Lamport timestamp
rm.updateLamportTimestamp(msg.lamportTimestamp)
# Review ACK status for outgoing messages
rm.reviewAckStatus(msg)
var missingDeps: seq[MessageID] = @[]
for depId in msg.causalHistory:
if not rm.bloomFilter.contains(depId):
missingDeps.add(depId)
if missingDeps.len == 0:
rm.messageHistory.add(msg.messageId)
if rm.messageHistory.len > rm.config.maxMessageHistory:
rm.messageHistory.delete(0)
# All dependencies met, add to history
rm.addToBloomAndHistory(msg)
if rm.onMessageReady != nil:
rm.onMessageReady(msg.messageId)
else:
# Buffer message and request missing dependencies
rm.incomingBuffer.add(msg)
if rm.onMissingDependencies != nil:
rm.onMissingDependencies(msg.messageId, missingDeps)
@ -136,6 +191,12 @@ proc markDependenciesMet*(rm: ReliabilityManager, messageIds: seq[MessageID]): R
var processedMessages: seq[Message] = @[]
var newIncomingBuffer: seq[Message] = @[]
# Add all messageIds to both bloom filter and causal history
for msgId in messageIds:
if not rm.bloomFilter.contains(msgId):
rm.bloomFilter.add(msgId)
rm.messageHistory.add(msgId)
for msg in rm.incomingBuffer:
var allDependenciesMet = true
for depId in msg.causalHistory:
@ -145,15 +206,13 @@ proc markDependenciesMet*(rm: ReliabilityManager, messageIds: seq[MessageID]): R
if allDependenciesMet:
processedMessages.add(msg)
rm.addToBloomAndHistory(msg)
else:
newIncomingBuffer.add(msg)
rm.incomingBuffer = newIncomingBuffer
for msg in processedMessages:
rm.messageHistory.add(msg.messageId)
if rm.messageHistory.len > rm.config.maxMessageHistory:
rm.messageHistory.delete(0)
if rm.onMessageReady != nil:
rm.onMessageReady(msg.messageId)

View File

@ -75,6 +75,12 @@ proc cleanBloomFilter*(rm: ReliabilityManager) {.gcsafe, raises: [].} =
except Exception as e:
logError("Failed to clean ReliabilityManager bloom filter: " & e.msg)
proc addToBloomAndHistory*(rm: ReliabilityManager, msg: Message) =
rm.messageHistory.add(msg.messageId)
if rm.messageHistory.len > rm.config.maxMessageHistory:
rm.messageHistory.delete(0)
rm.bloomFilter.add(msg.messageId)
proc updateLamportTimestamp*(rm: ReliabilityManager, msgTs: int64) =
rm.lamportTimestamp = max(msgTs, rm.lamportTimestamp) + 1

View File

@ -1,6 +1,8 @@
import unittest, results, chronos, chronicles
import unittest, results, chronos
import ../src/reliability
import ../src/common
import ../src/protobuf
import ../src/utils
suite "ReliabilityManager":
var rm: ReliabilityManager
@ -41,38 +43,58 @@ suite "ReliabilityManager":
unwrapped == msg
missingDeps.len == 0
test "markDependenciesMet":
# First message
let msg1 = @[byte(1)]
test "marking dependencies":
var messageReadyCount = 0
var messageSentCount = 0
var missingDepsCount = 0
rm.setCallbacks(
proc(messageId: MessageID) {.gcsafe.} = messageReadyCount += 1,
proc(messageId: MessageID) {.gcsafe.} = messageSentCount += 1,
proc(messageId: MessageID, missingDeps: seq[MessageID]) {.gcsafe.} = missingDepsCount += 1
)
# We'll create dependency IDs that aren't in the bloom filter yet
let id1 = "msg1"
let wrap1 = rm.wrapOutgoingMessage(msg1, id1)
check wrap1.isOk()
let wrapped1 = wrap1.get()
# Second message
let msg2 = @[byte(2)]
let id2 = "msg2"
let wrap2 = rm.wrapOutgoingMessage(msg2, id2)
check wrap2.isOk()
let wrapped2 = wrap2.get()
# Third message
let msg3 = @[byte(3)]
let id3 = "msg3"
let wrap3 = rm.wrapOutgoingMessage(msg3, id3)
check wrap3.isOk()
let wrapped3 = wrap3.get()
# Create a message that depends on these IDs
let msg3 = Message(
messageId: "msg3",
lamportTimestamp: 1,
causalHistory: @[id1, id2], # Depends on messages we haven't seen
channelId: "testChannel",
content: @[byte(3)],
bloomFilter: @[]
)
# Check dependencies
var unwrap3 = rm.unwrapReceivedMessage(wrapped3)
check unwrap3.isOk()
var (_, missing3) = unwrap3.get()
let serializedMsg3 = serializeMessage(msg3)
check serializedMsg3.isOk()
# Mark dependencies as met
let markResult = rm.markDependenciesMet(@[id1, id2])
# Process the message - should identify missing dependencies
let unwrapResult = rm.unwrapReceivedMessage(serializedMsg3.get())
check unwrapResult.isOk()
let (_, missingDeps) = unwrapResult.get()
# Verify missing dependencies were identified
check missingDepsCount == 1
check missingDeps.len == 2
check id1 in missingDeps
check id2 in missingDeps
# Now mark dependencies as met
let markResult = rm.markDependenciesMet(missingDeps)
check markResult.isOk()
check missing3.len == 0
# Process the message again - should now be ready
let reprocessResult = rm.unwrapReceivedMessage(serializedMsg3.get())
check reprocessResult.isOk()
let (_, remainingDeps) = reprocessResult.get()
# Verify message is now processed
check remainingDeps.len == 0
check messageReadyCount == 1 # msg3 should now be ready
check missingDepsCount == 1 # Only the first attempt should report missing deps
test "callbacks work correctly":
var messageReadyCount = 0
@ -85,18 +107,76 @@ suite "ReliabilityManager":
proc(messageId: MessageID, missingDeps: seq[MessageID]) {.gcsafe.} = missingDepsCount += 1
)
let msg1Result = rm.wrapOutgoingMessage(@[byte(1)], "msg1")
let msg2Result = rm.wrapOutgoingMessage(@[byte(2)], "msg2")
check msg1Result.isOk() and msg2Result.isOk()
let msg1 = msg1Result.get()
let msg2 = msg2Result.get()
discard rm.unwrapReceivedMessage(msg1)
discard rm.unwrapReceivedMessage(msg2)
# First send our own message
let msg1 = @[byte(1)]
let id1 = "msg1"
let wrap1 = rm.wrapOutgoingMessage(msg1, id1)
check wrap1.isOk()
check:
messageReadyCount == 2
messageSentCount == 0 # This would be triggered by checkUnacknowledgedMessages
missingDepsCount == 0
# Create a message that has our message in causal history
let msg2 = Message(
messageId: "msg2",
lamportTimestamp: rm.lamportTimestamp + 1,
causalHistory: @[id1], # Include our message in causal history
channelId: "testChannel",
content: @[byte(2)],
bloomFilter: @[] # Test with an empty bloom filter
)
let serializedMsg2 = serializeMessage(msg2)
check serializedMsg2.isOk()
# Process the "received" message - should trigger callbacks
let unwrapResult = rm.unwrapReceivedMessage(serializedMsg2.get())
check unwrapResult.isOk()
check messageReadyCount == 1 # For msg2 which we "received"
check messageSentCount == 1 # For msg1 which was acknowledged via causal history
test "bloom filter acknowledgment":
var messageSentCount = 0
rm.setCallbacks(
proc(messageId: MessageID) {.gcsafe.} = discard,
proc(messageId: MessageID) {.gcsafe.} = messageSentCount += 1,
proc(messageId: MessageID, missingDeps: seq[MessageID]) {.gcsafe.} = discard
)
# First send our own message
let msg1 = @[byte(1)]
let id1 = "msg1"
let wrap1 = rm.wrapOutgoingMessage(msg1, id1)
check wrap1.isOk()
# Create a message simulating another party's message
# with bloom filter containing our message
var otherPartyBloomFilter = newRollingBloomFilter(
DefaultBloomFilterCapacity,
DefaultBloomFilterErrorRate,
DefaultBloomFilterWindow
)
otherPartyBloomFilter.add(id1) # Add our message to their bloom filter
let bfResult = serializeBloomFilter(otherPartyBloomFilter.filter)
check bfResult.isOk()
let msg2 = Message(
messageId: "msg2",
lamportTimestamp: rm.lamportTimestamp + 1,
causalHistory: @[], # Empty causal history as we're using bloom filter
channelId: "testChannel",
content: @[byte(2)],
bloomFilter: bfResult.get()
)
let serializedMsg2 = serializeMessage(msg2)
check serializedMsg2.isOk()
# Process the "received" message - should trigger acknowledgment
let unwrapResult = rm.unwrapReceivedMessage(serializedMsg2.get())
check unwrapResult.isOk()
check messageSentCount == 1 # Our message should be acknowledged via bloom filter
test "periodic sync callback works":
var syncCallCount = 0