mirror of https://github.com/logos-messaging/nim-sds.git
synced 2026-01-02 14:13:07 +00:00

fix: add reliability fixes
commit 40a87a7064

reliability.nimble (new file, 16 lines)
@@ -0,0 +1,16 @@
+# Package
+version = "0.1.0"
+author = "Waku Team"
+description = "E2E Reliability Protocol API"
+license = "MIT"
+srcDir = "src"
+
+# Dependencies
+requires "nim >= 2.0.8"
+requires "chronicles"
+requires "libp2p"
+
+# Tasks
+task test, "Run the test suite":
+  exec "nim c -r tests/test_bloom.nim"
+  exec "nim c -r tests/test_reliability.nim"
@@ -1,24 +1,25 @@
-import std/times
+import std/[times, options, sets]

 type
-  MessageID* = string
+  SdsMessageID* = seq[byte]
+  SdsChannelID* = seq[byte]

-  Message* = object
-    messageId*: MessageID
+  SdsMessage* = object
+    messageId*: SdsMessageID
     lamportTimestamp*: int64
-    causalHistory*: seq[MessageID]
-    channelId*: string
+    causalHistory*: seq[SdsMessageID]
+    channelId*: Option[SdsChannelID]
     content*: seq[byte]
     bloomFilter*: seq[byte]

   UnacknowledgedMessage* = object
-    message*: Message
+    message*: SdsMessage
     sendTime*: Time
     resendAttempts*: int

-  TimestampedMessageID* = object
-    id*: MessageID
-    timestamp*: Time
+  IncomingMessage* = object
+    message*: SdsMessage
+    missingDeps*: HashSet[SdsMessageID]

 const
   DefaultMaxMessageHistory* = 1000
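A quick sketch of what the new types mean for callers: message IDs are raw bytes rather than strings, and the channel ID is now explicitly optional. This is a minimal standalone example (type declarations copied from the hunk above so it compiles on its own; `toBytes` is a local helper for the sketch, not part of the commit):

    import std/options

    type
      SdsMessageID = seq[byte]
      SdsChannelID = seq[byte]
      SdsMessage = object
        messageId: SdsMessageID
        lamportTimestamp: int64
        causalHistory: seq[SdsMessageID]
        channelId: Option[SdsChannelID]
        content: seq[byte]
        bloomFilter: seq[byte]

    proc toBytes(s: string): seq[byte] =
      # Local helper for the sketch; IDs are plain bytes end to end now.
      cast[seq[byte]](s)

    let msg = SdsMessage(
      messageId: toBytes("msg-1"),
      lamportTimestamp: 1,
      causalHistory: @[toBytes("msg-0")],
      channelId: some(toBytes("chan-a")), # none(SdsChannelID) is legal too
      content: toBytes("hello"),
      bloomFilter: @[],
    )
    assert msg.channelId.isSome()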
@@ -9,9 +9,7 @@ type
   TErrorForK = seq[float]
   TAllErrorRates* = array[0 .. 12, TErrorForK]

-var kErrors* {.threadvar.}: TAllErrorRates
-
-kErrors = [
+const kErrors*: TAllErrorRates = [
   @[1.0],
   @[
     1.0, 1.0, 0.3930000000, 0.2830000000, 0.2210000000, 0.1810000000, 0.1540000000,
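Replacing the `{.threadvar.}` plus runtime assignment with a `const` bakes the table in at compile time, so any thread can read it without per-thread initialization. A tiny access sketch (the indexing convention, outer index = number of hash functions k, inner index = bits per element, is an assumption based on classic Bloom-filter error-rate tables and is not stated in the diff):

    import ./bloom   # assumed module path for the kErrors table above

    assert kErrors[0] == @[1.0]   # zero hash functions: every lookup is a false positive
    echo kErrors[1][2]            # 0.393 for k = 1 under the assumed convention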
@@ -1,30 +1,28 @@
 import libp2p/protobuf/minprotobuf
+import std/options
+import endians
 import ../src/[message, protobufutil, bloom, reliability_utils]

-proc toBytes(s: string): seq[byte] =
-  result = newSeq[byte](s.len)
-  copyMem(result[0].addr, s[0].unsafeAddr, s.len)
-
-proc encode*(msg: Message): ProtoBuffer =
+proc encode*(msg: SdsMessage): ProtoBuffer =
   var pb = initProtoBuffer()

   pb.write(1, msg.messageId)
   pb.write(2, uint64(msg.lamportTimestamp))

   for hist in msg.causalHistory:
-    pb.write(3, hist.toBytes) # Convert string to bytes for proper length handling
+    pb.write(3, hist)

-  pb.write(4, msg.channelId)
+  if msg.channelId.isSome():
+    pb.write(4, msg.channelId.get())
   pb.write(5, msg.content)
   pb.write(6, msg.bloomFilter)
   pb.finish()

   pb

-proc decode*(T: type Message, buffer: seq[byte]): ProtobufResult[T] =
+proc decode*(T: type SdsMessage, buffer: seq[byte]): ProtobufResult[T] =
   let pb = initProtoBuffer(buffer)
-  var msg = Message()
+  var msg = SdsMessage()

   if not ?pb.getField(1, msg.messageId):
     return err(ProtobufError.missingRequiredField("messageId"))
@@ -34,14 +32,16 @@ proc decode*(T: type Message, buffer: seq[byte]): ProtobufResult[T] =
     return err(ProtobufError.missingRequiredField("lamportTimestamp"))
   msg.lamportTimestamp = int64(timestamp)

   # Decode causal history
-  var causalHistory: seq[string]
+  var causalHistory: seq[seq[byte]]
   let histResult = pb.getRepeatedField(3, causalHistory)
   if histResult.isOk:
     msg.causalHistory = causalHistory

-  if not ?pb.getField(4, msg.channelId):
-    return err(ProtobufError.missingRequiredField("channelId"))
+  var channelId: seq[byte]
+  if ?pb.getField(4, channelId):
+    msg.channelId = some(channelId)
+  else:
+    msg.channelId = none[SdsChannelID]()

   if not ?pb.getField(5, msg.content):
     return err(ProtobufError.missingRequiredField("content"))
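Field 4 goes from required to optional: a message without a channel ID now decodes to `none` instead of failing with `missingRequiredField`. A sketch of the caller-visible behavior, using only the minprotobuf calls that already appear in this diff (`getField` returning a Result whose `get()` is a presence flag is inferred from the usage above):

    import libp2p/protobuf/minprotobuf

    var pb = initProtoBuffer()
    pb.write(1, @[byte 1, 2, 3])   # messageId only; field 4 deliberately omitted
    pb.finish()

    let decoded = initProtoBuffer(pb.buffer)
    var channelId: seq[byte]
    let present = decoded.getField(4, channelId)
    # Old code treated this as a hard decode error; now the absent field
    # simply maps to none(SdsChannelID) in the decoder above.
    assert present.isOk() and not present.get()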
@@ -51,63 +51,59 @@ proc decode*(T: type Message, buffer: seq[byte]): ProtobufResult[T] =

   ok(msg)

-proc serializeMessage*(msg: Message): Result[seq[byte], ReliabilityError] =
-  try:
-    let pb = encode(msg)
-    ok(pb.buffer)
-  except:
-    err(reSerializationError)
+proc serializeMessage*(msg: SdsMessage): Result[seq[byte], ReliabilityError] =
+  let pb = encode(msg)
+  ok(pb.buffer)

-proc deserializeMessage*(data: seq[byte]): Result[Message, ReliabilityError] =
-  try:
-    let msgResult = Message.decode(data)
-    if msgResult.isOk:
-      ok(msgResult.get)
-    else:
-      err(reSerializationError)
-  except:
-    err(reDeserializationError)
+proc deserializeMessage*(data: seq[byte]): Result[SdsMessage, ReliabilityError] =
+  let msg = SdsMessage.decode(data).valueOr:
+    return err(ReliabilityError.reDeserializationError)
+  ok(msg)

 proc serializeBloomFilter*(filter: BloomFilter): Result[seq[byte], ReliabilityError] =
-  try:
-    var pb = initProtoBuffer()
+  var pb = initProtoBuffer()

-    # Convert intArray to bytes
+  # Convert intArray to bytes
+  try:
     var bytes = newSeq[byte](filter.intArray.len * sizeof(int))
     for i, val in filter.intArray:
+      var leVal: int
+      littleEndian64(addr leVal, unsafeAddr val)
       let start = i * sizeof(int)
-      copyMem(addr bytes[start], unsafeAddr val, sizeof(int))
+      copyMem(addr bytes[start], addr leVal, sizeof(int))

     pb.write(1, bytes)
     pb.write(2, uint64(filter.capacity))
     pb.write(3, uint64(filter.errorRate * 1_000_000))
     pb.write(4, uint64(filter.kHashes))
     pb.write(5, uint64(filter.mBits))
-
-    pb.finish()
-    ok(pb.buffer)
   except:
-    err(reSerializationError)
+    return err(ReliabilityError.reSerializationError)
+
+  pb.finish()
+  ok(pb.buffer)

 proc deserializeBloomFilter*(data: seq[byte]): Result[BloomFilter, ReliabilityError] =
   if data.len == 0:
-    return err(reDeserializationError)
+    return err(ReliabilityError.reDeserializationError)
+
+  let pb = initProtoBuffer(data)
+  var bytes: seq[byte]
+  var cap, errRate, kHashes, mBits: uint64

   try:
-    let pb = initProtoBuffer(data)
-    var bytes: seq[byte]
-    var cap, errRate, kHashes, mBits: uint64
-
     if not pb.getField(1, bytes).get() or not pb.getField(2, cap).get() or
         not pb.getField(3, errRate).get() or not pb.getField(4, kHashes).get() or
         not pb.getField(5, mBits).get():
-      return err(reDeserializationError)
+      return err(ReliabilityError.reDeserializationError)

     # Convert bytes back to intArray
     var intArray = newSeq[int](bytes.len div sizeof(int))
     for i in 0 ..< intArray.len:
+      var leVal: int
       let start = i * sizeof(int)
-      copyMem(addr intArray[i], unsafeAddr bytes[start], sizeof(int))
+      copyMem(addr leVal, unsafeAddr bytes[start], sizeof(int))
+      littleEndian64(addr intArray[i], addr leVal)

     ok(
       BloomFilter(
@@ -119,4 +115,4 @@ proc deserializeBloomFilter*(data: seq[byte]): Result[BloomFilter, ReliabilityError] =
       )
     )
   except:
-    err(reDeserializationError)
+    return err(ReliabilityError.reDeserializationError)
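The new `littleEndian64` conversion on both sides pins the wire format of `intArray` to little-endian, so a filter serialized on a big-endian host now deserializes correctly elsewhere. A self-contained sketch of the round trip over one plain int (the real code applies the same two steps per element of the filter's `intArray`):

    import endians

    proc toLeBytes(val: int): array[8, byte] =
      # Mirror of the serialize loop: normalize to little-endian, then copy.
      var leVal: int
      littleEndian64(addr leVal, unsafeAddr val)
      copyMem(addr result[0], addr leVal, sizeof(int))

    proc fromLeBytes(bytes: array[8, byte]): int =
      # Mirror of the deserialize loop: copy, then convert back to host order.
      var leVal: int
      copyMem(addr leVal, unsafeAddr bytes[0], sizeof(int))
      littleEndian64(addr result, addr leVal)

    let x = 0x0102030405060708
    assert fromLeBytes(toLeBytes(x)) == x
    assert toLeBytes(x)[0] == 0x08'u8   # least-significant byte first, on any host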
@@ -1,9 +1,9 @@
-import std/[times, locks, tables, sets]
-import chronos, results
-import ../src/[message, protobuf, reliability_utils, rolling_bloom_filter]
+import std/[times, locks, tables, sets, options]
+import chronos, results, chronicles
+import ./[message, protobuf, reliability_utils, rolling_bloom_filter]

 proc newReliabilityManager*(
-    channelId: string, config: ReliabilityConfig = defaultConfig()
+    channelId: Option[SdsChannelID], config: ReliabilityConfig = defaultConfig()
 ): Result[ReliabilityManager, ReliabilityError] =
   ## Creates a new ReliabilityManager with the specified channel ID and configuration.
   ##
@@ -13,94 +13,113 @@ proc newReliabilityManager*(
   ##
   ## Returns:
   ##   A Result containing either a new ReliabilityManager instance or an error.
-  if channelId.len == 0:
-    return err(reInvalidArgument)
+  if not channelId.isSome():
+    return err(ReliabilityError.reInvalidArgument)

   try:
-    let bloomFilter = newRollingBloomFilter(
-      config.bloomFilterCapacity, config.bloomFilterErrorRate, config.bloomFilterWindow
-    )
+    let bloomFilter =
+      newRollingBloomFilter(config.bloomFilterCapacity, config.bloomFilterErrorRate)

     let rm = ReliabilityManager(
       lamportTimestamp: 0,
       messageHistory: @[],
       bloomFilter: bloomFilter,
       outgoingBuffer: @[],
-      incomingBuffer: @[],
+      incomingBuffer: initTable[SdsMessageID, IncomingMessage](),
       channelId: channelId,
       config: config,
     )
     initLock(rm.lock)
     return ok(rm)
-  except:
-    return err(reOutOfMemory)
+  except Exception:
+    error "Failed to create ReliabilityManager", msg = getCurrentExceptionMsg()
+    return err(ReliabilityError.reOutOfMemory)

-proc reviewAckStatus(rm: ReliabilityManager, msg: Message) =
-  var i = 0
-  while i < rm.outgoingBuffer.len:
-    var acknowledged = false
-    let outMsg = rm.outgoingBuffer[i]
-
-    # Check if message is in causal history
-    for msgID in msg.causalHistory:
-      if outMsg.message.messageId == msgID:
-        acknowledged = true
-        break
-
-    # Check bloom filter if not already acknowledged
-    if not acknowledged and msg.bloomFilter.len > 0:
-      let bfResult = deserializeBloomFilter(msg.bloomFilter)
-      if bfResult.isOk:
-        var rbf = RollingBloomFilter(
-          filter: bfResult.get(), window: rm.bloomFilter.window, messages: @[]
-        )
-        if rbf.contains(outMsg.message.messageId):
-          acknowledged = true
-      else:
-        logError("Failed to deserialize bloom filter")
-
-    if acknowledged:
-      if rm.onMessageSent != nil:
-        rm.onMessageSent(outMsg.message.messageId)
-      rm.outgoingBuffer.delete(i)
-    else:
-      inc i
+proc isAcknowledged*(
+    msg: UnacknowledgedMessage,
+    causalHistory: seq[SdsMessageID],
+    rbf: Option[RollingBloomFilter],
+): bool =
+  if msg.message.messageId in causalHistory:
+    return true
+
+  if rbf.isSome():
+    return rbf.get().contains(msg.message.messageId)
+
+  false
+
+proc reviewAckStatus(rm: ReliabilityManager, msg: SdsMessage) {.gcsafe.} =
+  # Parse bloom filter
+  var rbf: Option[RollingBloomFilter]
+  if msg.bloomFilter.len > 0:
+    let bfResult = deserializeBloomFilter(msg.bloomFilter)
+    if bfResult.isOk():
+      rbf = some(
+        RollingBloomFilter(
+          filter: bfResult.get(),
+          capacity: bfResult.get().capacity,
+          minCapacity: (
+            bfResult.get().capacity.float * (100 - CapacityFlexPercent).float / 100.0
+          ).int,
+          maxCapacity: (
+            bfResult.get().capacity.float * (100 + CapacityFlexPercent).float / 100.0
+          ).int,
+          messages: @[],
+        )
+      )
+    else:
+      error "Failed to deserialize bloom filter", error = bfResult.error
+      rbf = none[RollingBloomFilter]()
+  else:
+    rbf = none[RollingBloomFilter]()
+
+  # Keep track of indices to delete
+  var toDelete: seq[int] = @[]
+  var i = 0
+
+  while i < rm.outgoingBuffer.len:
+    let outMsg = rm.outgoingBuffer[i]
+    if outMsg.isAcknowledged(msg.causalHistory, rbf):
+      if not rm.onMessageSent.isNil():
+        rm.onMessageSent(outMsg.message.messageId)
+      toDelete.add(i)
+    inc i
+
+  for i in countdown(toDelete.high, 0): # Delete in reverse order to maintain indices
+    rm.outgoingBuffer.delete(toDelete[i])
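Extracting `isAcknowledged` makes the ACK rule testable on its own: a buffered message counts as acknowledged if its ID appears in the peer's causal history, or, failing that, if the peer's Bloom filter probably contains it. A minimal sketch of the same rule over plain data (`FakeFilter` is a stand-in for RollingBloomFilter, so no manager or real filter is needed):

    import std/options

    type
      SdsMessageID = seq[byte]
      FakeFilter = object
        members: seq[SdsMessageID]   # exact set instead of a probabilistic one

    proc contains(f: FakeFilter, id: SdsMessageID): bool =
      id in f.members

    proc isAcknowledged(id: SdsMessageID, causalHistory: seq[SdsMessageID],
                        rbf: Option[FakeFilter]): bool =
      # Same precedence as the diff: causal history first, then the filter.
      if id in causalHistory:
        return true
      if rbf.isSome():
        return rbf.get().contains(id)
      false

    let a = @[byte 1]
    let b = @[byte 2]
    assert isAcknowledged(a, @[a], none(FakeFilter))                  # via history
    assert isAcknowledged(b, @[a], some(FakeFilter(members: @[b])))   # via filter
    assert not isAcknowledged(b, @[a], none(FakeFilter))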

 proc wrapOutgoingMessage*(
-    rm: ReliabilityManager, message: seq[byte], messageId: MessageID
+    rm: ReliabilityManager, message: seq[byte], messageId: SdsMessageID
 ): Result[seq[byte], ReliabilityError] =
   ## Wraps an outgoing message with reliability metadata.
   ##
   ## Parameters:
   ##   - message: The content of the message to be sent.
   ##   - messageId: Unique identifier for the message
   ##
   ## Returns:
-  ##   A Result containing either a Message object with reliability metadata or an error.
+  ##   A Result containing either wrapped message bytes or an error.
   if message.len == 0:
-    return err(reInvalidArgument)
+    return err(ReliabilityError.reInvalidArgument)
   if message.len > MaxMessageSize:
-    return err(reMessageTooLarge)
+    return err(ReliabilityError.reMessageTooLarge)

   withLock rm.lock:
     try:
       rm.updateLamportTimestamp(getTime().toUnix)

       # Serialize current bloom filter
-      var bloomBytes: seq[byte]
       let bfResult = serializeBloomFilter(rm.bloomFilter.filter)
       if bfResult.isErr:
-        logError("Failed to serialize bloom filter")
-        bloomBytes = @[]
-      else:
-        bloomBytes = bfResult.get()
+        error "Failed to serialize bloom filter"
+        return err(ReliabilityError.reSerializationError)

-      let msg = Message(
+      let msg = SdsMessage(
         messageId: messageId,
         lamportTimestamp: rm.lamportTimestamp,
-        causalHistory: rm.getRecentMessageIDs(rm.config.maxCausalHistory),
+        causalHistory: rm.getRecentSdsMessageIDs(rm.config.maxCausalHistory),
         channelId: rm.channelId,
         content: message,
-        bloomFilter: bloomBytes,
+        bloomFilter: bfResult.get(),
       )

       # Add to outgoing buffer
@@ -113,80 +132,61 @@ proc wrapOutgoingMessage*(
       rm.addToHistory(msg.messageId)

       return serializeMessage(msg)
-    except:
-      return err(reInternalError)
+    except Exception:
+      error "Failed to wrap message", msg = getCurrentExceptionMsg()
+      return err(ReliabilityError.reSerializationError)

-proc processIncomingBuffer(rm: ReliabilityManager) =
+proc processIncomingBuffer(rm: ReliabilityManager) {.gcsafe.} =
   withLock rm.lock:
     if rm.incomingBuffer.len == 0:
       return

-    # Create dependency map
-    var dependencies = initTable[MessageID, seq[MessageID]]()
-    var readyToProcess: seq[MessageID] = @[]
-
-    # Build dependency graph and find initially ready messages
-    for msg in rm.incomingBuffer:
-      var hasMissingDeps = false
-      for depId in msg.causalHistory:
-        if not rm.bloomFilter.contains(depId):
-          hasMissingDeps = true
-          if depId notin dependencies:
-            dependencies[depId] = @[]
-          dependencies[depId].add(msg.messageId)
-
-      if not hasMissingDeps:
-        readyToProcess.add(msg.messageId)
-
-    # Process ready messages and their dependents
-    var newIncomingBuffer: seq[Message] = @[]
-    var processed = initHashSet[MessageID]()
+    var processed = initHashSet[SdsMessageID]()
+    var readyToProcess = newSeq[SdsMessageID]()
+
+    # Find initially ready messages
+    for msgId, entry in rm.incomingBuffer:
+      if entry.missingDeps.len == 0:
+        readyToProcess.add(msgId)

     while readyToProcess.len > 0:
       let msgId = readyToProcess.pop()
       if msgId in processed:
         continue

-      # Process this message
-      for msg in rm.incomingBuffer:
-        if msg.messageId == msgId:
-          rm.addToHistory(msg.messageId)
-          if rm.onMessageReady != nil:
-            rm.onMessageReady(msg.messageId)
-          processed.incl(msgId)
-
-          # Add any dependent messages that might now be ready
-          if msgId in dependencies:
-            for dependentId in dependencies[msgId]:
-              readyToProcess.add(dependentId)
-          break
-
-    # Update incomingBuffer with remaining messages
-    for msg in rm.incomingBuffer:
-      if msg.messageId notin processed:
-        newIncomingBuffer.add(msg)
-
-    rm.incomingBuffer = newIncomingBuffer
+      if msgId in rm.incomingBuffer:
+        rm.addToHistory(msgId)
+        if not rm.onMessageReady.isNil():
+          rm.onMessageReady(msgId)
+        processed.incl(msgId)
+
+        # Update dependencies for remaining messages
+        for remainingId, entry in rm.incomingBuffer:
+          if remainingId notin processed:
+            if msgId in entry.missingDeps:
+              rm.incomingBuffer[remainingId].missingDeps.excl(msgId)
+              if rm.incomingBuffer[remainingId].missingDeps.len == 0:
+                readyToProcess.add(remainingId)
+
+    # Remove processed messages
+    for msgId in processed:
+      rm.incomingBuffer.del(msgId)
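Tracking `missingDeps` per buffered message (instead of rebuilding a dependency graph on every pass) turns buffer processing into a plain work-queue drain: delivering one message shrinks the dependency sets of the rest, and any set that empties joins the queue. A self-contained sketch of that drain with integer IDs (no manager state; mirrors Table[SdsMessageID, IncomingMessage] above):

    import std/[tables, sets, sequtils]

    type MsgId = int

    var buffer = {
      1: initHashSet[MsgId](),    # ready immediately
      2: [1].toHashSet(),         # waits on 1
      3: [1, 2].toHashSet(),      # waits on 1 and 2
    }.toTable()

    var ready: seq[MsgId]
    for id, deps in buffer:
      if deps.len == 0:
        ready.add(id)

    var delivered: seq[MsgId]
    while ready.len > 0:
      let id = ready.pop()
      delivered.add(id)
      buffer.del(id)
      for otherId in toSeq(buffer.keys):   # snapshot keys before mutating sets
        if id in buffer[otherId]:
          buffer[otherId].excl(id)
          if buffer[otherId].len == 0:
            ready.add(otherId)

    assert delivered == @[1, 2, 3]
    assert buffer.len == 0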

 proc unwrapReceivedMessage*(
     rm: ReliabilityManager, message: seq[byte]
-): Result[tuple[message: seq[byte], missingDeps: seq[MessageID]], ReliabilityError] {.
-    gcsafe
-.} =
+): Result[tuple[message: seq[byte], missingDeps: seq[SdsMessageID]], ReliabilityError] =
   ## Unwraps a received message and processes its reliability metadata.
   ##
   ## Parameters:
-  ##   - message: The received Message object.
+  ##   - message: The received message bytes
   ##
   ## Returns:
-  ##   A Result containing either a tuple with the processed message and missing dependencies, or an error.
+  ##   A Result containing either tuple of (processed message, missing dependencies) or an error.
   try:
-    let msgResult = deserializeMessage(message)
-    if not msgResult.isOk:
-      return err(msgResult.error)
-
-    let msg = msgResult.get
-    if rm.bloomFilter.contains(msg.messageId):
+    let msg = deserializeMessage(message).valueOr:
+      return err(ReliabilityError.reDeserializationError)
+
+    if msg.messageId in rm.messageHistory:
       return ok((msg.content, @[]))

     rm.bloomFilter.add(msg.messageId)
@@ -197,38 +197,42 @@ proc unwrapReceivedMessage*(
     # Review ACK status for outgoing messages
     rm.reviewAckStatus(msg)

-    var missingDeps: seq[MessageID] = @[]
-    for depId in msg.causalHistory:
-      if not rm.bloomFilter.contains(depId):
-        missingDeps.add(depId)
+    var missingDeps = rm.checkDependencies(msg.causalHistory)

     if missingDeps.len == 0:
       # Check if any dependencies are still in incoming buffer
       var depsInBuffer = false
-      for bufferedMsg in rm.incomingBuffer:
-        if bufferedMsg.messageId in msg.causalHistory:
+      for msgId, entry in rm.incomingBuffer.pairs():
+        if msgId in msg.causalHistory:
           depsInBuffer = true
           break

       if depsInBuffer:
-        rm.incomingBuffer.add(msg)
+        rm.incomingBuffer[msg.messageId] = IncomingMessage(
+          message: msg,
+          missingDeps: initHashSet[SdsMessageID]()
+        )
       else:
         # All dependencies met, add to history
         rm.addToHistory(msg.messageId)
         rm.processIncomingBuffer()
-        if rm.onMessageReady != nil:
+        if not rm.onMessageReady.isNil():
           rm.onMessageReady(msg.messageId)
     else:
       # Buffer message and request missing dependencies
-      rm.incomingBuffer.add(msg)
-      if rm.onMissingDependencies != nil:
+      rm.incomingBuffer[msg.messageId] = IncomingMessage(
+        message: msg,
+        missingDeps: missingDeps.toHashSet()
+      )
+      if not rm.onMissingDependencies.isNil():
        rm.onMissingDependencies(msg.messageId, missingDeps)

     return ok((msg.content, missingDeps))
-  except:
-    return err(reInternalError)
+  except Exception:
+    error "Failed to unwrap message", msg = getCurrentExceptionMsg()
+    return err(ReliabilityError.reDeserializationError)

 proc markDependenciesMet*(
-    rm: ReliabilityManager, messageIds: seq[MessageID]
+    rm: ReliabilityManager, messageIds: seq[SdsMessageID]
 ): Result[void, ReliabilityError] =
   ## Marks the specified message dependencies as met.
   ##
@@ -243,11 +247,17 @@ proc markDependenciesMet*(
       if not rm.bloomFilter.contains(msgId):
         rm.bloomFilter.add(msgId)
         # rm.addToHistory(msgId) -- not needed as this proc usually called when msg in long-term storage of application?
-    rm.processIncomingBuffer()

+      # Update any pending messages that depend on this one
+      for pendingId, entry in rm.incomingBuffer:
+        if msgId in entry.missingDeps:
+          rm.incomingBuffer[pendingId].missingDeps.excl(msgId)
+
+    rm.processIncomingBuffer()
     return ok()
-  except:
-    return err(reInternalError)
+  except Exception:
+    error "Failed to mark dependencies as met", msg = getCurrentExceptionMsg()
+    return err(ReliabilityError.reInternalError)

 proc setCallbacks*(
     rm: ReliabilityManager,
@@ -269,53 +279,52 @@ proc setCallbacks*(
   rm.onMissingDependencies = onMissingDependencies
   rm.onPeriodicSync = onPeriodicSync

-proc checkUnacknowledgedMessages*(rm: ReliabilityManager) {.raises: [].} =
+proc checkUnacknowledgedMessages(rm: ReliabilityManager) {.gcsafe.} =
   ## Checks and processes unacknowledged messages in the outgoing buffer.
   withLock rm.lock:
     let now = getTime()
     var newOutgoingBuffer: seq[UnacknowledgedMessage] = @[]

-    try:
-      for unackMsg in rm.outgoingBuffer:
-        let elapsed = now - unackMsg.sendTime
-        if elapsed > rm.config.resendInterval:
-          # Time to attempt resend
-          if unackMsg.resendAttempts < rm.config.maxResendAttempts:
-            var updatedMsg = unackMsg
-            updatedMsg.resendAttempts += 1
-            updatedMsg.sendTime = now
-            newOutgoingBuffer.add(updatedMsg)
-          else:
-            if rm.onMessageSent != nil:
-              rm.onMessageSent(unackMsg.message.messageId)
-        else:
-          newOutgoingBuffer.add(unackMsg)
-
-      rm.outgoingBuffer = newOutgoingBuffer
-    except Exception as e:
-      logError("Error in checking unacknowledged messages: " & e.msg)
+    for unackMsg in rm.outgoingBuffer:
+      let elapsed = now - unackMsg.sendTime
+      if elapsed > rm.config.resendInterval:
+        # Time to attempt resend
+        if unackMsg.resendAttempts < rm.config.maxResendAttempts:
+          var updatedMsg = unackMsg
+          updatedMsg.resendAttempts += 1
+          updatedMsg.sendTime = now
+          newOutgoingBuffer.add(updatedMsg)
+        else:
+          if not rm.onMessageSent.isNil():
+            rm.onMessageSent(unackMsg.message.messageId)
+      else:
+        newOutgoingBuffer.add(unackMsg)
+
+    rm.outgoingBuffer = newOutgoingBuffer
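The resend policy is: once `resendInterval` elapses, a message is retried up to `maxResendAttempts` times, after which it is dropped from the buffer (the real code fires `onMessageSent` at that point). A standalone sketch of the same bookkeeping over plain values, with no locks or manager:

    import std/times

    type Unacked = object
      id: int
      sendTime: Time
      resendAttempts: int

    let resendInterval = initDuration(seconds = 30)
    let maxResendAttempts = 3
    let now = getTime()

    var buffer = @[
      Unacked(id: 1, sendTime: now - initDuration(seconds = 60), resendAttempts: 0),
      Unacked(id: 2, sendTime: now - initDuration(seconds = 60), resendAttempts: 3),
      Unacked(id: 3, sendTime: now, resendAttempts: 0),
    ]

    var kept: seq[Unacked]
    for m in buffer:
      if now - m.sendTime > resendInterval:
        if m.resendAttempts < maxResendAttempts:
          var updated = m          # due for resend: bump attempts, reset timer
          updated.resendAttempts += 1
          updated.sendTime = now
          kept.add(updated)
        # else: give up on this message entirely
      else:
        kept.add(m)                # not due yet: keep waiting

    assert kept.len == 2           # id 2 exhausted its attempts and was dropped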

-proc periodicBufferSweep(rm: ReliabilityManager) {.async: (raises: [CancelledError]).} =
+proc periodicBufferSweep(
+    rm: ReliabilityManager
+) {.async: (raises: [CancelledError]), gcsafe.} =
   ## Periodically sweeps the buffer to clean up and check unacknowledged messages.
   while true:
-    {.gcsafe.}:
-      try:
-        rm.checkUnacknowledgedMessages()
-        rm.cleanBloomFilter()
-      except Exception as e:
-        logError("Error in periodic buffer sweep: " & e.msg)
+    try:
+      rm.checkUnacknowledgedMessages()
+      rm.cleanBloomFilter()
+    except Exception:
+      error "Error in periodic buffer sweep", msg = getCurrentExceptionMsg()

     await sleepAsync(chronos.milliseconds(rm.config.bufferSweepInterval.inMilliseconds))

-proc periodicSyncMessage(rm: ReliabilityManager) {.async: (raises: [CancelledError]).} =
+proc periodicSyncMessage(
+    rm: ReliabilityManager
+) {.async: (raises: [CancelledError]), gcsafe.} =
   ## Periodically notifies to send a sync message to maintain connectivity.
   while true:
-    {.gcsafe.}:
-      try:
-        if rm.onPeriodicSync != nil:
-          rm.onPeriodicSync()
-      except Exception as e:
-        logError("Error in periodic sync: " & e.msg)
+    try:
+      if not rm.onPeriodicSync.isNil():
+        rm.onPeriodicSync()
+    except Exception:
+      error "Error in periodic sync", msg = getCurrentExceptionMsg()
     await sleepAsync(chronos.seconds(rm.config.syncMessageInterval.inSeconds))

 proc startPeriodicTasks*(rm: ReliabilityManager) =
@@ -329,19 +338,16 @@ proc resetReliabilityManager*(rm: ReliabilityManager): Result[void, ReliabilityError] =
   ## Resets the ReliabilityManager to its initial state.
   ##
   ## This procedure clears all buffers and resets the Lamport timestamp.
   ##
   ## Returns:
   ##   A Result indicating success or an error if the Bloom filter initialization fails.
   withLock rm.lock:
     try:
       rm.lamportTimestamp = 0
       rm.messageHistory.setLen(0)
       rm.outgoingBuffer.setLen(0)
-      rm.incomingBuffer.setLen(0)
+      rm.incomingBuffer.clear()
       rm.bloomFilter = newRollingBloomFilter(
-        rm.config.bloomFilterCapacity, rm.config.bloomFilterErrorRate,
-        rm.config.bloomFilterWindow,
+        rm.config.bloomFilterCapacity, rm.config.bloomFilterErrorRate
       )
       return ok()
-    except:
-      return err(reInternalError)
+    except Exception:
+      error "Failed to reset ReliabilityManager", msg = getCurrentExceptionMsg()
+      return err(ReliabilityError.reInternalError)
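Taken together, the reliability layer is driven through a wrap/unwrap pair: the sender wraps content bytes with metadata, the receiver unwraps them and gets back the content plus any missing causal dependencies. A usage sketch under the assumption that this code compiles as a module named `reliability` with the API shown in this diff:

    import std/options
    import ./reliability   # assumed module path

    let channel = some(cast[SdsChannelID]("my-channel"))
    let sender = newReliabilityManager(channel).expect("manager init")
    let receiver = newReliabilityManager(channel).expect("manager init")

    let msgId = cast[SdsMessageID]("msg-1")
    let wrapped =
      sender.wrapOutgoingMessage(cast[seq[byte]]("hello"), msgId).expect("wrap")

    let (content, missingDeps) = receiver.unwrapReceivedMessage(wrapped).expect("unwrap")
    assert content == cast[seq[byte]]("hello")
    assert missingDeps.len == 0   # first message: no causal history to miss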
@@ -1,49 +1,43 @@
-import std/[times, locks]
+import std/[times, locks, options]
 import chronicles
 import ./[rolling_bloom_filter, message]

 type
-  MessageReadyCallback* = proc(messageId: MessageID) {.gcsafe.}
+  MessageReadyCallback* = proc(messageId: SdsMessageID) {.gcsafe.}

-  MessageSentCallback* = proc(messageId: MessageID) {.gcsafe.}
+  MessageSentCallback* = proc(messageId: SdsMessageID) {.gcsafe.}

   MissingDependenciesCallback* =
-    proc(messageId: MessageID, missingDeps: seq[MessageID]) {.gcsafe.}
+    proc(messageId: SdsMessageID, missingDeps: seq[SdsMessageID]) {.gcsafe.}

   PeriodicSyncCallback* = proc() {.gcsafe, raises: [].}

   AppCallbacks* = ref object
     messageReadyCb*: MessageReadyCallback
     messageSentCb*: MessageSentCallback
     missingDependenciesCb*: MissingDependenciesCallback
     periodicSyncCb*: PeriodicSyncCallback

   ReliabilityConfig* = object
     bloomFilterCapacity*: int
     bloomFilterErrorRate*: float
-    bloomFilterWindow*: times.Duration
     maxMessageHistory*: int
     maxCausalHistory*: int
-    resendInterval*: times.Duration
+    resendInterval*: Duration
     maxResendAttempts*: int
-    syncMessageInterval*: times.Duration
-    bufferSweepInterval*: times.Duration
+    syncMessageInterval*: Duration
+    bufferSweepInterval*: Duration

   ReliabilityManager* = ref object
     lamportTimestamp*: int64
-    messageHistory*: seq[MessageID]
+    messageHistory*: seq[SdsMessageID]
     bloomFilter*: RollingBloomFilter
     outgoingBuffer*: seq[UnacknowledgedMessage]
-    incomingBuffer*: seq[Message]
-    channelId*: string
+    incomingBuffer*: Table[SdsMessageID, IncomingMessage]
+    channelId*: Option[SdsChannelID]
     config*: ReliabilityConfig
     lock*: Lock
-    onMessageReady*: proc(messageId: MessageID) {.gcsafe.}
-    onMessageSent*: proc(messageId: MessageID) {.gcsafe.}
+    onMessageReady*: proc(messageId: SdsMessageID) {.gcsafe.}
+    onMessageSent*: proc(messageId: SdsMessageID) {.gcsafe.}
     onMissingDependencies*:
-      proc(messageId: MessageID, missingDeps: seq[MessageID]) {.gcsafe.}
-    onPeriodicSync*: proc()
+      proc(messageId: SdsMessageID, missingDeps: seq[SdsMessageID]) {.gcsafe.}
+    onPeriodicSync*: PeriodicSyncCallback

-  ReliabilityError* = enum
+  ReliabilityError* {.pure.} = enum
     reInvalidArgument
     reOutOfMemory
     reInternalError
@@ -59,7 +53,6 @@ proc defaultConfig*(): ReliabilityConfig =
   ReliabilityConfig(
     bloomFilterCapacity: DefaultBloomFilterCapacity,
     bloomFilterErrorRate: DefaultBloomFilterErrorRate,
-    bloomFilterWindow: DefaultBloomFilterWindow,
     maxMessageHistory: DefaultMaxMessageHistory,
     maxCausalHistory: DefaultMaxCausalHistory,
     resendInterval: DefaultResendInterval,
@@ -69,23 +62,23 @@ proc defaultConfig*(): ReliabilityConfig =
   )

 proc cleanup*(rm: ReliabilityManager) {.raises: [].} =
-  if not rm.isNil:
-    {.gcsafe.}:
-      try:
-        withLock rm.lock:
-          rm.outgoingBuffer.setLen(0)
-          rm.incomingBuffer.setLen(0)
-          rm.messageHistory.setLen(0)
-      except Exception as e:
-        logError("Error during cleanup: " & e.msg)
+  if not rm.isNil():
+    try:
+      withLock rm.lock:
+        rm.outgoingBuffer.setLen(0)
+        rm.incomingBuffer.clear()
+        rm.messageHistory.setLen(0)
+    except Exception:
+      error "Error during cleanup", error = getCurrentExceptionMsg()

 proc cleanBloomFilter*(rm: ReliabilityManager) {.gcsafe, raises: [].} =
   withLock rm.lock:
     try:
       rm.bloomFilter.clean()
-    except Exception as e:
-      logError("Failed to clean ReliabilityManager bloom filter: " & e.msg)
+    except Exception:
+      error "Failed to clean bloom filter", error = getCurrentExceptionMsg()

-proc addToHistory*(rm: ReliabilityManager, msgId: MessageID) {.gcsafe, raises: [].} =
+proc addToHistory*(rm: ReliabilityManager, msgId: SdsMessageID) {.gcsafe, raises: [].} =
   rm.messageHistory.add(msgId)
   if rm.messageHistory.len > rm.config.maxMessageHistory:
     rm.messageHistory.delete(0)
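The callback aliases defined above give applications a typed integration surface. A sketch of wiring handlers up (hypothetical handler bodies; the `setCallbacks` parameter order is assumed from the assignment order shown earlier in this diff, and `./reliability` is an assumed module path):

    import std/options
    import ./reliability   # assumed module path

    proc onReady(messageId: SdsMessageID) {.gcsafe.} =
      echo "message ready for the app: ", messageId

    proc onSent(messageId: SdsMessageID) {.gcsafe.} =
      echo "message acknowledged by peers: ", messageId

    proc onMissing(messageId: SdsMessageID, deps: seq[SdsMessageID]) {.gcsafe.} =
      echo "fetch ", deps.len, " missing dependencies for ", messageId

    proc onSync() {.gcsafe, raises: [].} =
      echo "send a periodic sync message now"

    let rm = newReliabilityManager(some(cast[SdsChannelID]("chan"))).expect("init")
    rm.setCallbacks(onReady, onSent, onMissing, onSync)
    rm.startPeriodicTasks()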
@@ -95,10 +88,19 @@ proc updateLamportTimestamp*(
 ) {.gcsafe, raises: [].} =
   rm.lamportTimestamp = max(msgTs, rm.lamportTimestamp) + 1

-proc getRecentMessageIDs*(rm: ReliabilityManager, n: int): seq[MessageID] =
+proc getRecentSdsMessageIDs*(rm: ReliabilityManager, n: int): seq[SdsMessageID] =
   result = rm.messageHistory[max(0, rm.messageHistory.len - n) .. ^1]

-proc getMessageHistory*(rm: ReliabilityManager): seq[MessageID] =
+proc checkDependencies*(
+    rm: ReliabilityManager, deps: seq[SdsMessageID]
+): seq[SdsMessageID] =
+  var missingDeps: seq[SdsMessageID] = @[]
+  for depId in deps:
+    if depId notin rm.messageHistory:
+      missingDeps.add(depId)
+  return missingDeps
+
+proc getMessageHistory*(rm: ReliabilityManager): seq[SdsMessageID] =
   withLock rm.lock:
     result = rm.messageHistory

@@ -106,6 +108,8 @@ proc getOutgoingBuffer*(rm: ReliabilityManager): seq[UnacknowledgedMessage] =
   withLock rm.lock:
     result = rm.outgoingBuffer

-proc getIncomingBuffer*(rm: ReliabilityManager): seq[Message] =
+proc getIncomingBuffer*(
+    rm: ReliabilityManager
+): Table[SdsMessageID, message.IncomingMessage] =
   withLock rm.lock:
     result = rm.incomingBuffer
@@ -1,64 +1,113 @@
 import std/times
 import chronos
 import chronicles
 import ./[bloom, message]

 type RollingBloomFilter* = object
   filter*: BloomFilter
-  window*: times.Duration
-  messages*: seq[TimestampedMessageID]
+  capacity*: int
+  minCapacity*: int
+  maxCapacity*: int
+  messages*: seq[SdsMessageID]

 const
   DefaultBloomFilterCapacity* = 10000
   DefaultBloomFilterErrorRate* = 0.001
-  DefaultBloomFilterWindow* = initDuration(hours = 1)
-
-proc logError*(msg: string) =
-  error "ReliabilityError", message = msg
-
-proc logInfo*(msg: string) =
-  info "ReliabilityInfo", message = msg
+  CapacityFlexPercent* = 20

 proc newRollingBloomFilter*(
-    capacity: int, errorRate: float, window: times.Duration
+    capacity: int = DefaultBloomFilterCapacity,
+    errorRate: float = DefaultBloomFilterErrorRate,
 ): RollingBloomFilter {.gcsafe.} =
-  try:
-    var filterResult: Result[BloomFilter, string]
-    {.gcsafe.}:
-      filterResult = initializeBloomFilter(capacity, errorRate)
-
-    if filterResult.isOk:
-      logInfo("Successfully initialized bloom filter")
-      return RollingBloomFilter(
-        filter: filterResult.get(), # Extract the BloomFilter from Result
-        window: window,
-        messages: @[],
-      )
-    else:
-      logError("Failed to initialize bloom filter: " & filterResult.error)
-      # Fall through to default case below
-  except:
-    logError("Failed to initialize bloom filter")
-
-  # Default fallback case
-  let defaultResult =
-    initializeBloomFilter(DefaultBloomFilterCapacity, DefaultBloomFilterErrorRate)
-  if defaultResult.isOk:
-    return
-      RollingBloomFilter(filter: defaultResult.get(), window: window, messages: @[])
-  else:
-    # If even default initialization fails, raise an exception
-    logError("Failed to initialize bloom filter with default parameters")
+  let targetCapacity = if capacity <= 0: DefaultBloomFilterCapacity else: capacity
+  let targetError =
+    if errorRate <= 0.0 or errorRate >= 1.0: DefaultBloomFilterErrorRate else: errorRate
+
+  let filterResult = initializeBloomFilter(targetCapacity, targetError)
+  if filterResult.isErr:
+    error "Failed to initialize bloom filter", error = filterResult.error
+    # Try with default values if custom values failed
+    if capacity != DefaultBloomFilterCapacity or errorRate != DefaultBloomFilterErrorRate:
+      let defaultResult =
+        initializeBloomFilter(DefaultBloomFilterCapacity, DefaultBloomFilterErrorRate)
+      if defaultResult.isErr:
+        error "Failed to initialize bloom filter with default parameters",
+          error = defaultResult.error
+
+      let minCapacity = (
+        DefaultBloomFilterCapacity.float * (100 - CapacityFlexPercent).float / 100.0
+      ).int
+      let maxCapacity = (
+        DefaultBloomFilterCapacity.float * (100 + CapacityFlexPercent).float / 100.0
+      ).int
+
+      info "Successfully initialized bloom filter with default parameters",
+        capacity = DefaultBloomFilterCapacity,
+        minCapacity = minCapacity,
+        maxCapacity = maxCapacity
+
+      return RollingBloomFilter(
+        filter: defaultResult.get(),
+        capacity: DefaultBloomFilterCapacity,
+        minCapacity: minCapacity,
+        maxCapacity: maxCapacity,
+        messages: @[],
+      )
+    else:
+      error "Could not create bloom filter", error = filterResult.error
+
+  let minCapacity =
+    (targetCapacity.float * (100 - CapacityFlexPercent).float / 100.0).int
+  let maxCapacity =
+    (targetCapacity.float * (100 + CapacityFlexPercent).float / 100.0).int
+
+  info "Successfully initialized bloom filter",
+    capacity = targetCapacity, minCapacity = minCapacity, maxCapacity = maxCapacity
+
+  return RollingBloomFilter(
+    filter: filterResult.get(),
+    capacity: targetCapacity,
+    minCapacity: minCapacity,
+    maxCapacity: maxCapacity,
+    messages: @[],
+  )
+
+proc clean*(rbf: var RollingBloomFilter) {.gcsafe.} =
+  try:
+    if rbf.messages.len <= rbf.maxCapacity:
+      return # Don't clean unless we exceed max capacity
+
+    # Initialize new filter
+    var newFilter = initializeBloomFilter(rbf.maxCapacity, rbf.filter.errorRate).valueOr:
+      error "Failed to create new bloom filter", error = $error
+      return
+
+    # Keep most recent messages up to minCapacity
+    let keepCount = rbf.minCapacity
+    let startIdx = max(0, rbf.messages.len - keepCount)
+    var newMessages: seq[SdsMessageID] = @[]
+
+    for i in startIdx ..< rbf.messages.len:
+      newMessages.add(rbf.messages[i])
+      newFilter.insert(cast[string](rbf.messages[i]))
+
+    rbf.messages = newMessages
+    rbf.filter = newFilter
+  except Exception:
+    error "Failed to clean bloom filter", error = getCurrentExceptionMsg()

-proc add*(rbf: var RollingBloomFilter, messageId: MessageID) {.gcsafe.} =
+proc add*(rbf: var RollingBloomFilter, messageId: SdsMessageID) {.gcsafe.} =
   ## Adds a message ID to the rolling bloom filter.
   ##
   ## Parameters:
   ##   - messageId: The ID of the message to add.
-  rbf.filter.insert(messageId)
-  rbf.messages.add(TimestampedMessageID(id: messageId, timestamp: getTime()))
+  rbf.filter.insert(cast[string](messageId))
+  rbf.messages.add(messageId)
+
+  # Clean if we exceed max capacity
+  if rbf.messages.len > rbf.maxCapacity:
+    rbf.clean()

-proc contains*(rbf: RollingBloomFilter, messageId: MessageID): bool {.gcsafe.} =
+proc contains*(rbf: RollingBloomFilter, messageId: SdsMessageID): bool =
   ## Checks if a message ID is in the rolling bloom filter.
   ##
   ## Parameters:
@@ -66,29 +115,4 @@ proc contains*(rbf: RollingBloomFilter, messageId: MessageID): bool {.gcsafe.} =
   ##
   ## Returns:
   ##   True if the message ID is probably in the filter, false otherwise.
-  rbf.filter.lookup(messageId)
-
-proc clean*(rbf: var RollingBloomFilter) {.gcsafe.} =
-  try:
-    let now = getTime()
-    let cutoff = now - rbf.window
-    var newMessages: seq[TimestampedMessageID] = @[]
-
-    # Initialize new filter
-    let newFilterResult =
-      initializeBloomFilter(rbf.filter.capacity, rbf.filter.errorRate)
-    if newFilterResult.isErr:
-      logError("Failed to create new bloom filter: " & newFilterResult.error)
-      return
-
-    var newFilter = newFilterResult.get()
-
-    for msg in rbf.messages:
-      if msg.timestamp > cutoff:
-        newMessages.add(msg)
-        newFilter.insert(msg.id)
-
-    rbf.messages = newMessages
-    rbf.filter = newFilter
-  except Exception as e:
-    logError("Failed to clean bloom filter: " & e.msg)
+  rbf.filter.lookup(cast[string](messageId))
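The rolling filter no longer rolls on wall-clock time; it rolls on size. With `CapacityFlexPercent = 20`, a filter created for capacity N cleans itself once it tracks more than 1.2 * N IDs and rebuilds from the most recent 0.8 * N. A sketch of the arithmetic:

    const CapacityFlexPercent = 20

    proc flexBounds(capacity: int): tuple[minCap, maxCap: int] =
      # Same formulas as newRollingBloomFilter above.
      let minCap = (capacity.float * (100 - CapacityFlexPercent).float / 100.0).int
      let maxCap = (capacity.float * (100 + CapacityFlexPercent).float / 100.0).int
      (minCap, maxCap)

    let (minCap, maxCap) = flexBounds(10_000)
    assert minCap == 8_000    # clean keeps the newest 8k IDs
    assert maxCap == 12_000   # clean triggers past 12k tracked IDs

This trades the old time-window semantics for deterministic memory bounds: the filter's size now depends only on message volume, not on how long the process has been running.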