generic refactor to make the code more aligned to logos-delivery style (#62)

* generic refactor to make the code more aligned to logos-delivery style
* use explicit return statement
This commit is contained in:
Ivan FB 2026-04-24 09:50:18 +02:00 committed by GitHub
parent 6f49a9742a
commit 8ee857c908
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
23 changed files with 445 additions and 328 deletions

109
sds.nim
View File

@ -1,24 +1,15 @@
import std/[times, locks, tables, sets, options]
import chronos, results, chronicles
import sds/[message, protobuf, sds_utils, rolling_bloom_filter]
import sds/[types, protobuf, sds_utils, rolling_bloom_filter]
export message, protobuf, sds_utils, rolling_bloom_filter
export types, protobuf, sds_utils, rolling_bloom_filter
proc newReliabilityManager*(
config: ReliabilityConfig = defaultConfig()
): Result[ReliabilityManager, ReliabilityError] =
## Creates a new multi-channel ReliabilityManager.
##
## Parameters:
## - config: Configuration options for the ReliabilityManager. If not provided, default configuration is used.
##
## Returns:
## A Result containing either a new ReliabilityManager instance or an error.
try:
let rm = ReliabilityManager(
channels: initTable[SdsChannelID, ChannelContext](), config: config
)
initLock(rm.lock)
let rm = ReliabilityManager.new(config)
return ok(rm)
except Exception:
error "Failed to create ReliabilityManager", msg = getCurrentExceptionMsg()
@ -35,25 +26,20 @@ proc isAcknowledged*(
if rbf.isSome():
return rbf.get().contains(msg.message.messageId)
false
return false
proc reviewAckStatus(rm: ReliabilityManager, msg: SdsMessage) {.gcsafe.} =
# Parse bloom filter
var rbf: Option[RollingBloomFilter]
if msg.bloomFilter.len > 0:
let bfResult = deserializeBloomFilter(msg.bloomFilter)
if bfResult.isOk():
let bf = bfResult.get()
rbf = some(
RollingBloomFilter(
filter: bfResult.get(),
capacity: bfResult.get().capacity,
minCapacity: (
bfResult.get().capacity.float * (100 - CapacityFlexPercent).float / 100.0
).int,
maxCapacity: (
bfResult.get().capacity.float * (100 + CapacityFlexPercent).float / 100.0
).int,
messages: @[],
RollingBloomFilter.init(
filter = bf,
capacity = bf.capacity,
minCapacity = (bf.capacity.float * (100 - CapacityFlexPercent).float / 100.0).int,
maxCapacity = (bf.capacity.float * (100 + CapacityFlexPercent).float / 100.0).int,
)
)
else:
@ -66,7 +52,6 @@ proc reviewAckStatus(rm: ReliabilityManager, msg: SdsMessage) {.gcsafe.} =
return
let channel = rm.channels[msg.channelId]
# Keep track of indices to delete
var toDelete: seq[int] = @[]
var i = 0
@ -78,7 +63,7 @@ proc reviewAckStatus(rm: ReliabilityManager, msg: SdsMessage) {.gcsafe.} =
toDelete.add(i)
inc i
for i in countdown(toDelete.high, 0): # Delete in reverse order to maintain indices
for i in countdown(toDelete.high, 0):
channel.outgoingBuffer.delete(toDelete[i])
proc wrapOutgoingMessage*(
@ -88,14 +73,6 @@ proc wrapOutgoingMessage*(
channelId: SdsChannelID,
): Result[seq[byte], ReliabilityError] =
## Wraps an outgoing message with reliability metadata.
##
## Parameters:
## - message: The content of the message to be sent.
## - messageId: Unique identifier for the message
## - channelId: Identifier for the channel this message belongs to.
##
## Returns:
## A Result containing either wrapped message bytes or an error.
if message.len == 0:
return err(ReliabilityError.reInvalidArgument)
if message.len > MaxMessageSize:
@ -111,20 +88,19 @@ proc wrapOutgoingMessage*(
error "Failed to serialize bloom filter", channelId = channelId
return err(ReliabilityError.reSerializationError)
let msg = SdsMessage(
messageId: messageId,
lamportTimestamp: channel.lamportTimestamp,
causalHistory: rm.getRecentHistoryEntries(rm.config.maxCausalHistory, channelId),
channelId: channelId,
content: message,
bloomFilter: bfResult.get(),
let msg = SdsMessage.init(
messageId = messageId,
lamportTimestamp = channel.lamportTimestamp,
causalHistory = rm.getRecentHistoryEntries(rm.config.maxCausalHistory, channelId),
channelId = channelId,
content = message,
bloomFilter = bfResult.get(),
)
channel.outgoingBuffer.add(
UnacknowledgedMessage(message: msg, sendTime: getTime(), resendAttempts: 0)
UnacknowledgedMessage.init(message = msg, sendTime = getTime(), resendAttempts = 0)
)
# Add to causal history and bloom filter
channel.bloomFilter.add(msg.messageId)
rm.addToHistory(msg.messageId, channelId)
@ -147,7 +123,6 @@ proc processIncomingBuffer(rm: ReliabilityManager, channelId: SdsChannelID) {.gc
var processed = initHashSet[SdsMessageID]()
var readyToProcess = newSeq[SdsMessageID]()
# Find initially ready messages
for msgId, entry in channel.incomingBuffer:
if entry.missingDeps.len == 0:
readyToProcess.add(msgId)
@ -163,7 +138,6 @@ proc processIncomingBuffer(rm: ReliabilityManager, channelId: SdsChannelID) {.gc
rm.onMessageReady(msgId, channelId)
processed.incl(msgId)
# Update dependencies for remaining messages
for remainingId, entry in channel.incomingBuffer:
if remainingId notin processed:
if msgId in entry.missingDeps:
@ -171,7 +145,6 @@ proc processIncomingBuffer(rm: ReliabilityManager, channelId: SdsChannelID) {.gc
if channel.incomingBuffer[remainingId].missingDeps.len == 0:
readyToProcess.add(remainingId)
# Remove processed messages
for msgId in processed:
channel.incomingBuffer.del(msgId)
@ -182,12 +155,6 @@ proc unwrapReceivedMessage*(
ReliabilityError,
] =
## Unwraps a received message and processes its reliability metadata.
##
## Parameters:
## - message: The received message bytes
##
## Returns:
## A Result containing either tuple of (processed message, missing dependencies, channel ID) or an error.
try:
let channelId = extractChannelId(message).valueOr:
return err(ReliabilityError.reDeserializationError)
@ -203,7 +170,6 @@ proc unwrapReceivedMessage*(
channel.bloomFilter.add(msg.messageId)
rm.updateLamportTimestamp(msg.lamportTimestamp, channelId)
# Review ACK status for outgoing messages
rm.reviewAckStatus(msg)
var missingDeps = rm.checkDependencies(msg.causalHistory, channelId)
@ -214,19 +180,20 @@ proc unwrapReceivedMessage*(
if msgId in msg.causalHistory.getMessageIds():
depsInBuffer = true
break
# Check if any dependencies are still in incoming buffer
if depsInBuffer:
channel.incomingBuffer[msg.messageId] =
IncomingMessage(message: msg, missingDeps: initHashSet[SdsMessageID]())
IncomingMessage.init(message = msg, missingDeps = initHashSet[SdsMessageID]())
else:
# All dependencies met, add to history
rm.addToHistory(msg.messageId, channelId)
rm.processIncomingBuffer(channelId)
if not rm.onMessageReady.isNil():
rm.onMessageReady(msg.messageId, channelId)
else:
channel.incomingBuffer[msg.messageId] =
IncomingMessage(message: msg, missingDeps: missingDeps.getMessageIds().toHashSet())
IncomingMessage.init(
message = msg,
missingDeps = missingDeps.getMessageIds().toHashSet(),
)
if not rm.onMissingDependencies.isNil():
rm.onMissingDependencies(msg.messageId, missingDeps, channelId)
@ -239,13 +206,6 @@ proc markDependenciesMet*(
rm: ReliabilityManager, messageIds: seq[SdsMessageID], channelId: SdsChannelID
): Result[void, ReliabilityError] =
## Marks the specified message dependencies as met.
##
## Parameters:
## - messageIds: A sequence of message IDs to mark as met.
## - channelId: Identifier for the channel.
##
## Returns:
## A Result indicating success or an error.
try:
if channelId notin rm.channels:
return err(ReliabilityError.reInvalidArgument)
@ -273,16 +233,9 @@ proc setCallbacks*(
onMessageSent: MessageSentCallback,
onMissingDependencies: MissingDependenciesCallback,
onPeriodicSync: PeriodicSyncCallback = nil,
onRetrievalHint: RetrievalHintProvider = nil
onRetrievalHint: RetrievalHintProvider = nil,
) =
## Sets the callback functions for various events in the ReliabilityManager.
##
## Parameters:
## - onMessageReady: Callback function called when a message is ready to be processed.
## - onMessageSent: Callback function called when a message is confirmed as sent.
## - onMissingDependencies: Callback function called when a message has missing dependencies.
## - onPeriodicSync: Callback function called to notify about periodic sync
## - onRetrievalHint: Callback function called to get a retrieval hint for a message ID.
withLock rm.lock:
rm.onMessageReady = onMessageReady
rm.onMessageSent = onMessageSent
@ -293,7 +246,6 @@ proc setCallbacks*(
proc checkUnacknowledgedMessages(
rm: ReliabilityManager, channelId: SdsChannelID
) {.gcsafe.} =
## Checks and processes unacknowledged messages in the outgoing buffer.
withLock rm.lock:
if channelId notin rm.channels:
error "Channel does not exist", channelId = channelId
@ -322,7 +274,6 @@ proc checkUnacknowledgedMessages(
proc periodicBufferSweep(
rm: ReliabilityManager
) {.async: (raises: [CancelledError]), gcsafe.} =
## Periodically sweeps the buffer to clean up and check unacknowledged messages.
while true:
try:
for channelId, channel in rm.channels:
@ -340,7 +291,6 @@ proc periodicBufferSweep(
proc periodicSyncMessage(
rm: ReliabilityManager
) {.async: (raises: [CancelledError]), gcsafe.} =
## Periodically notifies to send a sync message to maintain connectivity.
while true:
try:
if not rm.onPeriodicSync.isNil():
@ -351,15 +301,11 @@ proc periodicSyncMessage(
proc startPeriodicTasks*(rm: ReliabilityManager) =
## Starts the periodic tasks for buffer sweeping and sync message sending.
##
## This procedure should be called after creating a ReliabilityManager to enable automatic maintenance.
asyncSpawn rm.periodicBufferSweep()
asyncSpawn rm.periodicSyncMessage()
proc resetReliabilityManager*(rm: ReliabilityManager): Result[void, ReliabilityError] =
## Resets the ReliabilityManager to its initial state.
##
## This procedure clears all buffers and resets the Lamport timestamp.
withLock rm.lock:
try:
for channelId, channel in rm.channels:
@ -367,9 +313,8 @@ proc resetReliabilityManager*(rm: ReliabilityManager): Result[void, ReliabilityE
channel.messageHistory.setLen(0)
channel.outgoingBuffer.setLen(0)
channel.incomingBuffer.clear()
channel.bloomFilter = newRollingBloomFilter(
rm.config.bloomFilterCapacity, rm.config.bloomFilterErrorRate
)
channel.bloomFilter =
RollingBloomFilter.init(rm.config.bloomFilterCapacity, rm.config.bloomFilterErrorRate)
rm.channels.clear()
return ok()
except Exception:

View File

@ -3,13 +3,8 @@ import hashes
import strutils
import results
import private/probabilities
type BloomFilter* = object
capacity*: int
errorRate*: float
kHashes*: int
mBits*: int
intArray*: seq[int]
import ./types/bloom_filter
export bloom_filter
{.push overflowChecks: off.} # Turn off overflow checks for hashing operations
@ -20,13 +15,7 @@ proc hashN(item: string, n: int, maxValue: int): int =
let
hashA = abs(hash(item)) mod maxValue # Use abs to handle negative hashes
hashB = abs(hash(item & " b")) mod maxValue # string concatenation
abs((hashA + n * hashB)) mod maxValue
# # Use bit rotation for second hash instead of string concatenation if speed if preferred over FP-rate
# # Rotate left by 21 bits (lower the rotation, higher the speed but higher the FP-rate too)
# hashB = abs(
# ((h shl 21) or (h shr (sizeof(int) * 8 - 21)))
# ) mod maxValue
# abs((hashA + n.int64 * hashB)) mod maxValue
return abs((hashA + n * hashB)) mod maxValue
{.pop.}
@ -41,7 +30,7 @@ proc getMOverNBitsForK*(
if probabilityTable[k][mOverN] < targetError:
return ok(mOverN)
err(
return err(
"Specified value of k and error rate not achievable using less than 4 bytes / element."
)
@ -79,31 +68,31 @@ proc initializeBloomFilter*(
mBits = capacity * nBitsPerElem
mInts = 1 + mBits div (sizeof(int) * 8)
ok(
BloomFilter(
capacity: capacity,
errorRate: errorRate,
kHashes: kHashes,
mBits: mBits,
intArray: newSeq[int](mInts),
return ok(
BloomFilter.init(
capacity = capacity,
errorRate = errorRate,
kHashes = kHashes,
mBits = mBits,
intArray = newSeq[int](mInts),
)
)
proc `$`*(bf: BloomFilter): string =
## Prints the configuration of the Bloom filter.
"Bloom filter with $1 capacity, $2 error rate, $3 hash functions, and requiring $4 bits of memory." %
[
$bf.capacity,
formatFloat(bf.errorRate, format = ffScientific, precision = 1),
$bf.kHashes,
$(bf.mBits div bf.capacity),
]
return "Bloom filter with $1 capacity, $2 error rate, $3 hash functions, and requiring $4 bits of memory." %
[
$bf.capacity,
formatFloat(bf.errorRate, format = ffScientific, precision = 1),
$bf.kHashes,
$(bf.mBits div bf.capacity),
]
proc computeHashes(bf: BloomFilter, item: string): seq[int] =
var hashes = newSeq[int](bf.kHashes)
for i in 0 ..< bf.kHashes:
hashes[i] = hashN(item, i, bf.mBits)
hashes
return hashes
proc insert*(bf: var BloomFilter, item: string) =
## Insert an item (string) into the Bloom filter.
@ -127,4 +116,4 @@ proc lookup*(bf: BloomFilter, item: string): bool =
currentInt = bf.intArray[intAddress]
if currentInt != (currentInt or (1 shl bitOffset)):
return false
true
return true

View File

@ -1,35 +1,14 @@
import std/[times, sets]
import ./types/sds_message_id
import ./types/history_entry
import ./types/sds_message
import ./types/unacknowledged_message
import ./types/incoming_message
import ./types/reliability_config
type
SdsMessageID* = string
SdsChannelID* = string
HistoryEntry* = object
messageId*: SdsMessageID
retrievalHint*: seq[byte] ## Optional hint for efficient retrieval (e.g., Waku message hash)
SdsMessage* = object
messageId*: SdsMessageID
lamportTimestamp*: int64
causalHistory*: seq[HistoryEntry]
channelId*: SdsChannelID
content*: seq[byte]
bloomFilter*: seq[byte]
UnacknowledgedMessage* = object
message*: SdsMessage
sendTime*: Time
resendAttempts*: int
IncomingMessage* = object
message*: SdsMessage
missingDeps*: HashSet[SdsMessageID]
const
DefaultMaxMessageHistory* = 1000
DefaultMaxCausalHistory* = 10
DefaultResendInterval* = initDuration(seconds = 60)
DefaultMaxResendAttempts* = 5
DefaultSyncMessageInterval* = initDuration(seconds = 30)
DefaultBufferSweepInterval* = initDuration(seconds = 60)
MaxMessageSize* = 1024 * 1024 # 1 MB
export
sds_message_id,
history_entry,
sds_message,
unacknowledged_message,
incoming_message,
reliability_config

View File

@ -1,6 +1,9 @@
import libp2p/protobuf/minprotobuf
import endians
import sds/[message, protobufutil, bloom, sds_utils]
import ./types/[sds_message_id, history_entry, sds_message, reliability_error]
import ./protobufutil
import ./bloom
import ./sds_utils
proc encode*(msg: SdsMessage): ProtoBuffer =
var pb = initProtoBuffer()
@ -21,11 +24,11 @@ proc encode*(msg: SdsMessage): ProtoBuffer =
pb.write(6, msg.bloomFilter)
pb.finish()
pb
return pb
proc decode*(T: type SdsMessage, buffer: seq[byte]): ProtobufResult[T] =
let pb = initProtoBuffer(buffer)
var msg = SdsMessage()
var msg = SdsMessage.init("", 0, @[], "", @[], @[])
if not ?pb.getField(1, msg.messageId):
return err(ProtobufError.missingRequiredField("messageId"))
@ -41,7 +44,7 @@ proc decode*(T: type SdsMessage, buffer: seq[byte]): ProtobufResult[T] =
# New format: repeated HistoryEntry
for histBuffer in historyBuffers:
let entryPb = initProtoBuffer(histBuffer)
var entry: HistoryEntry
var entry = HistoryEntry.init("")
if not ?entryPb.getField(1, entry.messageId):
return err(ProtobufError.missingRequiredField("HistoryEntry.messageId"))
# retrievalHint is optional
@ -63,7 +66,7 @@ proc decode*(T: type SdsMessage, buffer: seq[byte]): ProtobufResult[T] =
if not ?pb.getField(6, msg.bloomFilter):
msg.bloomFilter = @[] # Empty if not present
ok(msg)
return ok(msg)
proc extractChannelId*(data: seq[byte]): Result[SdsChannelID, ReliabilityError] =
## For extraction of channel ID without full message deserialization
@ -74,23 +77,22 @@ proc extractChannelId*(data: seq[byte]): Result[SdsChannelID, ReliabilityError]
return err(ReliabilityError.reDeserializationError)
if not fieldOk:
return err(ReliabilityError.reDeserializationError)
ok(channelId)
return ok(channelId)
except:
err(ReliabilityError.reDeserializationError)
return err(ReliabilityError.reDeserializationError)
proc serializeMessage*(msg: SdsMessage): Result[seq[byte], ReliabilityError] =
let pb = encode(msg)
ok(pb.buffer)
return ok(pb.buffer)
proc deserializeMessage*(data: seq[byte]): Result[SdsMessage, ReliabilityError] =
let msg = SdsMessage.decode(data).valueOr:
return err(ReliabilityError.reDeserializationError)
ok(msg)
return ok(msg)
proc serializeBloomFilter*(filter: BloomFilter): Result[seq[byte], ReliabilityError] =
var pb = initProtoBuffer()
# Convert intArray to bytes
try:
var bytes = newSeq[byte](filter.intArray.len * sizeof(int))
for i, val in filter.intArray:
@ -108,7 +110,7 @@ proc serializeBloomFilter*(filter: BloomFilter): Result[seq[byte], ReliabilityEr
return err(ReliabilityError.reSerializationError)
pb.finish()
ok(pb.buffer)
return ok(pb.buffer)
proc deserializeBloomFilter*(data: seq[byte]): Result[BloomFilter, ReliabilityError] =
if data.len == 0:
@ -134,7 +136,6 @@ proc deserializeBloomFilter*(data: seq[byte]): Result[BloomFilter, ReliabilityEr
if not field1_Ok or not field2_Ok or not field3_Ok or not field4_Ok or not field5_Ok:
return err(ReliabilityError.reDeserializationError)
# Convert bytes back to intArray
var intArray = newSeq[int](bytes.len div sizeof(int))
for i in 0 ..< intArray.len:
var leVal: int
@ -142,13 +143,13 @@ proc deserializeBloomFilter*(data: seq[byte]): Result[BloomFilter, ReliabilityEr
copyMem(addr leVal, unsafeAddr bytes[start], sizeof(int))
littleEndian64(addr intArray[i], addr leVal)
ok(
BloomFilter(
intArray: intArray,
capacity: int(cap),
errorRate: float(errRate) / 1_000_000,
kHashes: int(kHashes),
mBits: int(mBits),
return ok(
BloomFilter.init(
capacity = int(cap),
errorRate = float(errRate) / 1_000_000,
kHashes = int(kHashes),
mBits = int(mBits),
intArray = intArray,
)
)
except:

View File

@ -4,29 +4,16 @@
import libp2p/protobuf/minprotobuf
import libp2p/varint
import ./types/protobuf_error
export minprotobuf, varint
type
ProtobufErrorKind* {.pure.} = enum
DecodeFailure
MissingRequiredField
ProtobufError* = object
case kind*: ProtobufErrorKind
of DecodeFailure:
error*: minprotobuf.ProtoError
of MissingRequiredField:
field*: string
ProtobufResult*[T] = Result[T, ProtobufError]
export minprotobuf, varint, protobuf_error
converter toProtobufError*(err: minprotobuf.ProtoError): ProtobufError =
case err
of minprotobuf.ProtoError.RequiredFieldMissing:
ProtobufError(kind: ProtobufErrorKind.MissingRequiredField, field: "unknown")
return ProtobufError(kind: ProtobufErrorKind.MissingRequiredField, field: "unknown")
else:
ProtobufError(kind: ProtobufErrorKind.DecodeFailure, error: err)
return ProtobufError(kind: ProtobufErrorKind.DecodeFailure, error: err)
proc missingRequiredField*(T: type ProtobufError, field: string): T =
ProtobufError(kind: ProtobufErrorKind.MissingRequiredField, field: field)
return ProtobufError.init(field)

View File

@ -1,23 +1,14 @@
import chronos
import chronicles
import ./[bloom, message]
import ./bloom
import ./types/rolling_bloom_filter
export rolling_bloom_filter
type RollingBloomFilter* = object
filter*: BloomFilter
capacity*: int
minCapacity*: int
maxCapacity*: int
messages*: seq[SdsMessageID]
const
DefaultBloomFilterCapacity* = 10000
DefaultBloomFilterErrorRate* = 0.001
CapacityFlexPercent* = 20
proc newRollingBloomFilter*(
proc init*(
T: type RollingBloomFilter,
capacity: int = DefaultBloomFilterCapacity,
errorRate: float = DefaultBloomFilterErrorRate,
): RollingBloomFilter {.gcsafe.} =
): T {.gcsafe.} =
let targetCapacity = if capacity <= 0: DefaultBloomFilterCapacity else: capacity
let targetError =
if errorRate <= 0.0 or errorRate >= 1.0: DefaultBloomFilterErrorRate else: errorRate
@ -25,7 +16,6 @@ proc newRollingBloomFilter*(
let filterResult = initializeBloomFilter(targetCapacity, targetError)
if filterResult.isErr:
error "Failed to initialize bloom filter", error = filterResult.error
# Try with default values if custom values failed
if capacity != DefaultBloomFilterCapacity or errorRate != DefaultBloomFilterErrorRate:
let defaultResult =
initializeBloomFilter(DefaultBloomFilterCapacity, DefaultBloomFilterErrorRate)
@ -45,12 +35,11 @@ proc newRollingBloomFilter*(
minCapacity = minCapacity,
maxCapacity = maxCapacity
return RollingBloomFilter(
filter: defaultResult.get(),
capacity: DefaultBloomFilterCapacity,
minCapacity: minCapacity,
maxCapacity: maxCapacity,
messages: @[],
return RollingBloomFilter.init(
filter = defaultResult.get(),
capacity = DefaultBloomFilterCapacity,
minCapacity = minCapacity,
maxCapacity = maxCapacity,
)
else:
error "Could not create bloom filter", error = filterResult.error
@ -63,12 +52,11 @@ proc newRollingBloomFilter*(
info "Successfully initialized bloom filter",
capacity = targetCapacity, minCapacity = minCapacity, maxCapacity = maxCapacity
return RollingBloomFilter(
filter: filterResult.get(),
capacity: targetCapacity,
minCapacity: minCapacity,
maxCapacity: maxCapacity,
messages: @[],
return RollingBloomFilter.init(
filter = filterResult.get(),
capacity = targetCapacity,
minCapacity = minCapacity,
maxCapacity = maxCapacity,
)
proc clean*(rbf: var RollingBloomFilter) {.gcsafe.} =
@ -97,22 +85,12 @@ proc clean*(rbf: var RollingBloomFilter) {.gcsafe.} =
proc add*(rbf: var RollingBloomFilter, messageId: SdsMessageID) {.gcsafe.} =
## Adds a message ID to the rolling bloom filter.
##
## Parameters:
## - messageId: The ID of the message to add.
rbf.filter.insert(cast[string](messageId))
rbf.messages.add(messageId)
# Clean if we exceed max capacity
if rbf.messages.len > rbf.maxCapacity:
rbf.clean()
proc contains*(rbf: RollingBloomFilter, messageId: SdsMessageID): bool =
## Checks if a message ID is in the rolling bloom filter.
##
## Parameters:
## - messageId: The ID of the message to check.
##
## Returns:
## True if the message ID is probably in the filter, false otherwise.
rbf.filter.lookup(cast[string](messageId))
return rbf.filter.lookup(cast[string](messageId))

View File

@ -1,81 +1,18 @@
import std/[times, locks, tables, sequtils]
import std/[locks, tables, sequtils]
import chronicles, results
import ./[rolling_bloom_filter, message]
type
MessageReadyCallback* =
proc(messageId: SdsMessageID, channelId: SdsChannelID) {.gcsafe.}
MessageSentCallback* =
proc(messageId: SdsMessageID, channelId: SdsChannelID) {.gcsafe.}
MissingDependenciesCallback* = proc(
messageId: SdsMessageID, missingDeps: seq[HistoryEntry], channelId: SdsChannelID
) {.gcsafe.}
RetrievalHintProvider* = proc(messageId: SdsMessageID): seq[byte] {.gcsafe.}
PeriodicSyncCallback* = proc() {.gcsafe, raises: [].}
AppCallbacks* = ref object
messageReadyCb*: MessageReadyCallback
messageSentCb*: MessageSentCallback
missingDependenciesCb*: MissingDependenciesCallback
periodicSyncCb*: PeriodicSyncCallback
retrievalHintProvider*: RetrievalHintProvider
ReliabilityConfig* = object
bloomFilterCapacity*: int
bloomFilterErrorRate*: float
maxMessageHistory*: int
maxCausalHistory*: int
resendInterval*: Duration
maxResendAttempts*: int
syncMessageInterval*: Duration
bufferSweepInterval*: Duration
ChannelContext* = ref object
lamportTimestamp*: int64
messageHistory*: seq[SdsMessageID]
bloomFilter*: RollingBloomFilter
outgoingBuffer*: seq[UnacknowledgedMessage]
incomingBuffer*: Table[SdsMessageID, IncomingMessage]
ReliabilityManager* = ref object
channels*: Table[SdsChannelID, ChannelContext]
config*: ReliabilityConfig
lock*: Lock
onMessageReady*: proc(messageId: SdsMessageID, channelId: SdsChannelID) {.gcsafe.}
onMessageSent*: proc(messageId: SdsMessageID, channelId: SdsChannelID) {.gcsafe.}
onMissingDependencies*: proc(
messageId: SdsMessageID, missingDeps: seq[HistoryEntry], channelId: SdsChannelID
) {.gcsafe.}
onPeriodicSync*: PeriodicSyncCallback
onRetrievalHint*: RetrievalHintProvider
ReliabilityError* {.pure.} = enum
reInvalidArgument
reOutOfMemory
reInternalError
reSerializationError
reDeserializationError
reMessageTooLarge
import ./rolling_bloom_filter
import ./types/[
sds_message_id, history_entry, sds_message, unacknowledged_message, incoming_message,
reliability_error, callbacks, app_callbacks, reliability_config, channel_context,
reliability_manager,
]
export
sds_message_id, history_entry, sds_message, unacknowledged_message, incoming_message,
reliability_error, callbacks, app_callbacks, reliability_config, channel_context,
reliability_manager
proc defaultConfig*(): ReliabilityConfig =
## Creates a default configuration for the ReliabilityManager.
##
## Returns:
## A ReliabilityConfig object with default values.
ReliabilityConfig(
bloomFilterCapacity: DefaultBloomFilterCapacity,
bloomFilterErrorRate: DefaultBloomFilterErrorRate,
maxMessageHistory: DefaultMaxMessageHistory,
maxCausalHistory: DefaultMaxCausalHistory,
resendInterval: DefaultResendInterval,
maxResendAttempts: DefaultMaxResendAttempts,
syncMessageInterval: DefaultSyncMessageInterval,
bufferSweepInterval: DefaultBufferSweepInterval,
)
return ReliabilityConfig.init()
proc cleanup*(rm: ReliabilityManager) {.raises: [].} =
if not rm.isNil():
@ -124,24 +61,18 @@ proc updateLamportTimestamp*(
error "Failed to update lamport timestamp",
channelId = channelId, msgTs = msgTs, error = getCurrentExceptionMsg()
# Helper functions for HistoryEntry
proc newHistoryEntry*(messageId: SdsMessageID, retrievalHint: seq[byte] = @[]): HistoryEntry =
## Creates a new HistoryEntry with optional retrieval hint
HistoryEntry(messageId: messageId, retrievalHint: retrievalHint)
return HistoryEntry.init(messageId, retrievalHint)
proc toCausalHistory*(messageIds: seq[SdsMessageID]): seq[HistoryEntry] =
## Converts a sequence of message IDs to HistoryEntry sequence (for backward compatibility)
return messageIds.mapIt(newHistoryEntry(it))
proc getMessageIds*(causalHistory: seq[HistoryEntry]): seq[SdsMessageID] =
## Extracts message IDs from HistoryEntry sequence
return causalHistory.mapIt(it.messageId)
proc getRecentHistoryEntries*(
rm: ReliabilityManager, n: int, channelId: SdsChannelID
): seq[HistoryEntry] =
## Get recent history entries for sending in causal history.
## Populates retrieval hints for our own messages using the provider callback.
try:
if channelId in rm.channels:
let channel = rm.channels[channelId]
@ -164,7 +95,6 @@ proc getRecentHistoryEntries*(
proc checkDependencies*(
rm: ReliabilityManager, deps: seq[HistoryEntry], channelId: SdsChannelID
): seq[HistoryEntry] =
## Check which dependencies are missing from our message history.
var missingDeps: seq[HistoryEntry] = @[]
try:
if channelId in rm.channels:
@ -173,7 +103,6 @@ proc checkDependencies*(
if dep.messageId notin channel.messageHistory:
missingDeps.add(dep)
else:
# Channel doesn't exist, all deps are missing
missingDeps = deps
except Exception:
error "Failed to check dependencies",
@ -187,13 +116,13 @@ proc getMessageHistory*(
withLock rm.lock:
try:
if channelId in rm.channels:
result = rm.channels[channelId].messageHistory
return rm.channels[channelId].messageHistory
else:
result = @[]
return @[]
except Exception:
error "Failed to get message history",
channelId = channelId, error = getCurrentExceptionMsg()
result = @[]
return @[]
proc getOutgoingBuffer*(
rm: ReliabilityManager, channelId: SdsChannelID
@ -201,43 +130,37 @@ proc getOutgoingBuffer*(
withLock rm.lock:
try:
if channelId in rm.channels:
result = rm.channels[channelId].outgoingBuffer
return rm.channels[channelId].outgoingBuffer
else:
result = @[]
return @[]
except Exception:
error "Failed to get outgoing buffer",
channelId = channelId, error = getCurrentExceptionMsg()
result = @[]
return @[]
proc getIncomingBuffer*(
rm: ReliabilityManager, channelId: SdsChannelID
): Table[SdsMessageID, message.IncomingMessage] =
): Table[SdsMessageID, IncomingMessage] =
withLock rm.lock:
try:
if channelId in rm.channels:
result = rm.channels[channelId].incomingBuffer
return rm.channels[channelId].incomingBuffer
else:
result = initTable[SdsMessageID, message.IncomingMessage]()
return initTable[SdsMessageID, IncomingMessage]()
except Exception:
error "Failed to get incoming buffer",
channelId = channelId, error = getCurrentExceptionMsg()
result = initTable[SdsMessageID, message.IncomingMessage]()
return initTable[SdsMessageID, IncomingMessage]()
proc getOrCreateChannel*(
rm: ReliabilityManager, channelId: SdsChannelID
): ChannelContext =
try:
if channelId notin rm.channels:
rm.channels[channelId] = ChannelContext(
lamportTimestamp: 0,
messageHistory: @[],
bloomFilter: newRollingBloomFilter(
rm.config.bloomFilterCapacity, rm.config.bloomFilterErrorRate
),
outgoingBuffer: @[],
incomingBuffer: initTable[SdsMessageID, IncomingMessage](),
rm.channels[channelId] = ChannelContext.new(
RollingBloomFilter.init(rm.config.bloomFilterCapacity, rm.config.bloomFilterErrorRate)
)
result = rm.channels[channelId]
return rm.channels[channelId]
except Exception:
error "Failed to get or create channel",
channelId = channelId, error = getCurrentExceptionMsg()
@ -270,4 +193,4 @@ proc removeChannel*(
except Exception:
error "Failed to remove channel",
channelId = channelId, msg = getCurrentExceptionMsg()
return err(ReliabilityError.reInternalError)
return err(ReliabilityError.reInternalError)

30
sds/types.nim Normal file
View File

@ -0,0 +1,30 @@
import sds/types/sds_message_id
import sds/types/history_entry
import sds/types/sds_message
import sds/types/unacknowledged_message
import sds/types/incoming_message
import sds/types/bloom_filter
import sds/types/rolling_bloom_filter
import sds/types/reliability_error
import sds/types/callbacks
import sds/types/app_callbacks
import sds/types/reliability_config
import sds/types/channel_context
import sds/types/reliability_manager
import sds/types/protobuf_error
export
sds_message_id,
history_entry,
sds_message,
unacknowledged_message,
incoming_message,
bloom_filter,
rolling_bloom_filter,
reliability_error,
callbacks,
app_callbacks,
reliability_config,
channel_context,
reliability_manager,
protobuf_error

View File

@ -0,0 +1,25 @@
import ./callbacks
export callbacks
type AppCallbacks* = ref object
messageReadyCb*: MessageReadyCallback
messageSentCb*: MessageSentCallback
missingDependenciesCb*: MissingDependenciesCallback
periodicSyncCb*: PeriodicSyncCallback
retrievalHintProvider*: RetrievalHintProvider
proc new*(
T: type AppCallbacks,
messageReadyCb: MessageReadyCallback = nil,
messageSentCb: MessageSentCallback = nil,
missingDependenciesCb: MissingDependenciesCallback = nil,
periodicSyncCb: PeriodicSyncCallback = nil,
retrievalHintProvider: RetrievalHintProvider = nil,
): T =
return T(
messageReadyCb: messageReadyCb,
messageSentCb: messageSentCb,
missingDependenciesCb: missingDependenciesCb,
periodicSyncCb: periodicSyncCb,
retrievalHintProvider: retrievalHintProvider,
)

View File

@ -0,0 +1,22 @@
type BloomFilter* {.requiresInit.} = object
capacity*: int
errorRate*: float
kHashes*: int
mBits*: int
intArray*: seq[int]
proc init*(
T: type BloomFilter,
capacity: int,
errorRate: float,
kHashes: int,
mBits: int,
intArray: seq[int],
): T =
return T(
capacity: capacity,
errorRate: errorRate,
kHashes: kHashes,
mBits: mBits,
intArray: intArray,
)

18
sds/types/callbacks.nim Normal file
View File

@ -0,0 +1,18 @@
import ./sds_message_id
import ./history_entry
export sds_message_id, history_entry
type
MessageReadyCallback* =
proc(messageId: SdsMessageID, channelId: SdsChannelID) {.gcsafe.}
MessageSentCallback* =
proc(messageId: SdsMessageID, channelId: SdsChannelID) {.gcsafe.}
MissingDependenciesCallback* = proc(
messageId: SdsMessageID, missingDeps: seq[HistoryEntry], channelId: SdsChannelID
) {.gcsafe.}
RetrievalHintProvider* = proc(messageId: SdsMessageID): seq[byte] {.gcsafe.}
PeriodicSyncCallback* = proc() {.gcsafe, raises: [].}

View File

@ -0,0 +1,22 @@
import std/tables
import ./sds_message_id
import ./rolling_bloom_filter
import ./unacknowledged_message
import ./incoming_message
export sds_message_id, rolling_bloom_filter, unacknowledged_message, incoming_message
type ChannelContext* = ref object
lamportTimestamp*: int64
messageHistory*: seq[SdsMessageID]
bloomFilter*: RollingBloomFilter
outgoingBuffer*: seq[UnacknowledgedMessage]
incomingBuffer*: Table[SdsMessageID, IncomingMessage]
proc new*(T: type ChannelContext, bloomFilter: RollingBloomFilter): T =
return T(
lamportTimestamp: 0,
messageHistory: @[],
bloomFilter: bloomFilter,
outgoingBuffer: @[],
incomingBuffer: initTable[SdsMessageID, IncomingMessage](),
)

View File

@ -0,0 +1,8 @@
import ./sds_message_id
type HistoryEntry* = object
messageId*: SdsMessageID
retrievalHint*: seq[byte] ## Optional hint for efficient retrieval (e.g., Waku message hash)
proc init*(T: type HistoryEntry, messageId: SdsMessageID, retrievalHint: seq[byte] = @[]): T =
return T(messageId: messageId, retrievalHint: retrievalHint)

View File

@ -0,0 +1,13 @@
import std/sets
import ./sds_message_id
import ./sds_message
export sds_message_id, sds_message
type IncomingMessage* {.requiresInit.} = object
message*: SdsMessage
missingDeps*: HashSet[SdsMessageID]
proc init*(
T: type IncomingMessage, message: SdsMessage, missingDeps: HashSet[SdsMessageID]
): T =
return T(message: message, missingDeps: missingDeps)

View File

@ -0,0 +1,22 @@
import results
import libp2p/protobuf/minprotobuf
type
ProtobufErrorKind* {.pure.} = enum
DecodeFailure
MissingRequiredField
ProtobufError* = object
case kind*: ProtobufErrorKind
of DecodeFailure:
error*: minprotobuf.ProtoError
of MissingRequiredField:
field*: string
ProtobufResult*[T] = Result[T, ProtobufError]
proc init*(T: type ProtobufError, error: minprotobuf.ProtoError): T =
return T(kind: ProtobufErrorKind.DecodeFailure, error: error)
proc init*(T: type ProtobufError, field: string): T =
return T(kind: ProtobufErrorKind.MissingRequiredField, field: field)

View File

@ -0,0 +1,45 @@
import std/times
const
DefaultMaxMessageHistory* = 1000
DefaultMaxCausalHistory* = 10
DefaultResendInterval* = initDuration(seconds = 60)
DefaultMaxResendAttempts* = 5
DefaultSyncMessageInterval* = initDuration(seconds = 30)
DefaultBufferSweepInterval* = initDuration(seconds = 60)
MaxMessageSize* = 1024 * 1024 # 1 MB
import ./rolling_bloom_filter
export rolling_bloom_filter
type ReliabilityConfig* {.requiresInit.} = object
bloomFilterCapacity*: int
bloomFilterErrorRate*: float
maxMessageHistory*: int
maxCausalHistory*: int
resendInterval*: Duration
maxResendAttempts*: int
syncMessageInterval*: Duration
bufferSweepInterval*: Duration
proc init*(
T: type ReliabilityConfig,
bloomFilterCapacity: int = DefaultBloomFilterCapacity,
bloomFilterErrorRate: float = DefaultBloomFilterErrorRate,
maxMessageHistory: int = DefaultMaxMessageHistory,
maxCausalHistory: int = DefaultMaxCausalHistory,
resendInterval: Duration = DefaultResendInterval,
maxResendAttempts: int = DefaultMaxResendAttempts,
syncMessageInterval: Duration = DefaultSyncMessageInterval,
bufferSweepInterval: Duration = DefaultBufferSweepInterval,
): T =
return T(
bloomFilterCapacity: bloomFilterCapacity,
bloomFilterErrorRate: bloomFilterErrorRate,
maxMessageHistory: maxMessageHistory,
maxCausalHistory: maxCausalHistory,
resendInterval: resendInterval,
maxResendAttempts: maxResendAttempts,
syncMessageInterval: syncMessageInterval,
bufferSweepInterval: bufferSweepInterval,
)

View File

@ -0,0 +1,7 @@
type ReliabilityError* {.pure.} = enum
reInvalidArgument
reOutOfMemory
reInternalError
reSerializationError
reDeserializationError
reMessageTooLarge

View File

@ -0,0 +1,27 @@
import std/[tables, locks]
import ./sds_message_id
import ./history_entry
import ./callbacks
import ./reliability_config
import ./channel_context
export sds_message_id, history_entry, callbacks, reliability_config, channel_context
type ReliabilityManager* = ref object
channels*: Table[SdsChannelID, ChannelContext]
config*: ReliabilityConfig
lock*: Lock
onMessageReady*: proc(messageId: SdsMessageID, channelId: SdsChannelID) {.gcsafe.}
onMessageSent*: proc(messageId: SdsMessageID, channelId: SdsChannelID) {.gcsafe.}
onMissingDependencies*: proc(
messageId: SdsMessageID, missingDeps: seq[HistoryEntry], channelId: SdsChannelID
) {.gcsafe.}
onPeriodicSync*: PeriodicSyncCallback
onRetrievalHint*: RetrievalHintProvider
proc new*(T: type ReliabilityManager, config: ReliabilityConfig): T =
let rm = T(
channels: initTable[SdsChannelID, ChannelContext](),
config: config,
)
rm.lock.initLock()
return rm

View File

@ -0,0 +1,31 @@
import ./bloom_filter
import ./sds_message_id
export bloom_filter, sds_message_id
const
DefaultBloomFilterCapacity* = 10000
DefaultBloomFilterErrorRate* = 0.001
CapacityFlexPercent* = 20
type RollingBloomFilter* {.requiresInit.} = object
filter*: BloomFilter
capacity*: int
minCapacity*: int
maxCapacity*: int
messages*: seq[SdsMessageID]
proc init*(
T: type RollingBloomFilter,
filter: BloomFilter,
capacity: int,
minCapacity: int,
maxCapacity: int,
messages: seq[SdsMessageID] = @[],
): T =
return T(
filter: filter,
capacity: capacity,
minCapacity: minCapacity,
maxCapacity: maxCapacity,
messages: messages,
)

29
sds/types/sds_message.nim Normal file
View File

@ -0,0 +1,29 @@
import ./sds_message_id
import ./history_entry
export sds_message_id, history_entry
type SdsMessage* {.requiresInit.} = object
messageId*: SdsMessageID
lamportTimestamp*: int64
causalHistory*: seq[HistoryEntry]
channelId*: SdsChannelID
content*: seq[byte]
bloomFilter*: seq[byte]
proc init*(
T: type SdsMessage,
messageId: SdsMessageID,
lamportTimestamp: int64,
causalHistory: seq[HistoryEntry],
channelId: SdsChannelID,
content: seq[byte],
bloomFilter: seq[byte],
): T =
return T(
messageId: messageId,
lamportTimestamp: lamportTimestamp,
causalHistory: causalHistory,
channelId: channelId,
content: content,
bloomFilter: bloomFilter,
)

View File

@ -0,0 +1,3 @@
type
SdsMessageID* = string
SdsChannelID* = string

View File

@ -0,0 +1,13 @@
import std/times
import ./sds_message
export sds_message
type UnacknowledgedMessage* = object
message*: SdsMessage
sendTime*: Time
resendAttempts*: int
proc init*(
T: type UnacknowledgedMessage, message: SdsMessage, sendTime: Time, resendAttempts: int
): T =
return T(message: message, sendTime: sendTime, resendAttempts: resendAttempts)

View File

@ -228,7 +228,7 @@ suite "Reliability Mechanisms":
# Create a message with bloom filter containing our message
var otherPartyBloomFilter =
newRollingBloomFilter(DefaultBloomFilterCapacity, DefaultBloomFilterErrorRate)
RollingBloomFilter.init(DefaultBloomFilterCapacity, DefaultBloomFilterErrorRate)
otherPartyBloomFilter.add(id1)
let bfResult = serializeBloomFilter(otherPartyBloomFilter.filter)