mirror of
https://github.com/logos-messaging/nim-sds.git
synced 2026-01-07 16:43:07 +00:00
feat - add rolling bloom filter, reliability utils and protobuf
This commit is contained in:
parent
5df71ad3ea
commit
6b0b9c34fa
16
reliability.nimble
Normal file
16
reliability.nimble
Normal file
@ -0,0 +1,16 @@
|
|||||||
|
# Package
|
||||||
|
version = "0.1.0"
|
||||||
|
author = "Waku Team"
|
||||||
|
description = "E2E Reliability Protocol API"
|
||||||
|
license = "MIT"
|
||||||
|
srcDir = "src"
|
||||||
|
|
||||||
|
# Dependencies
|
||||||
|
requires "nim >= 2.0.8"
|
||||||
|
requires "chronicles"
|
||||||
|
requires "libp2p"
|
||||||
|
|
||||||
|
# Tasks
|
||||||
|
task test, "Run the test suite":
|
||||||
|
exec "nim c -r tests/test_bloom.nim"
|
||||||
|
exec "nim c -r tests/test_reliability.nim"
|
||||||
30
src/message.nim
Normal file
30
src/message.nim
Normal file
@ -0,0 +1,30 @@
|
|||||||
|
import std/times
|
||||||
|
|
||||||
|
type
|
||||||
|
MessageID* = string
|
||||||
|
|
||||||
|
Message* = object
|
||||||
|
messageId*: MessageID
|
||||||
|
lamportTimestamp*: int64
|
||||||
|
causalHistory*: seq[MessageID]
|
||||||
|
channelId*: string
|
||||||
|
content*: seq[byte]
|
||||||
|
bloomFilter*: seq[byte]
|
||||||
|
|
||||||
|
UnacknowledgedMessage* = object
|
||||||
|
message*: Message
|
||||||
|
sendTime*: Time
|
||||||
|
resendAttempts*: int
|
||||||
|
|
||||||
|
TimestampedMessageID* = object
|
||||||
|
id*: MessageID
|
||||||
|
timestamp*: Time
|
||||||
|
|
||||||
|
const
|
||||||
|
DefaultMaxMessageHistory* = 1000
|
||||||
|
DefaultMaxCausalHistory* = 10
|
||||||
|
DefaultResendInterval* = initDuration(seconds = 60)
|
||||||
|
DefaultMaxResendAttempts* = 5
|
||||||
|
DefaultSyncMessageInterval* = initDuration(seconds = 30)
|
||||||
|
DefaultBufferSweepInterval* = initDuration(seconds = 60)
|
||||||
|
MaxMessageSize* = 1024 * 1024 # 1 MB
|
||||||
122
src/protobuf.nim
Normal file
122
src/protobuf.nim
Normal file
@ -0,0 +1,122 @@
|
|||||||
|
import libp2p/protobuf/minprotobuf
|
||||||
|
import std/options
|
||||||
|
import ../src/[message, protobufutil, bloom, reliability_utils]
|
||||||
|
|
||||||
|
proc toBytes(s: string): seq[byte] =
|
||||||
|
result = newSeq[byte](s.len)
|
||||||
|
copyMem(result[0].addr, s[0].unsafeAddr, s.len)
|
||||||
|
|
||||||
|
proc encode*(msg: Message): ProtoBuffer =
|
||||||
|
var pb = initProtoBuffer()
|
||||||
|
|
||||||
|
pb.write(1, msg.messageId)
|
||||||
|
pb.write(2, uint64(msg.lamportTimestamp))
|
||||||
|
|
||||||
|
for hist in msg.causalHistory:
|
||||||
|
pb.write(3, hist.toBytes) # Convert string to bytes for proper length handling
|
||||||
|
|
||||||
|
pb.write(4, msg.channelId)
|
||||||
|
pb.write(5, msg.content)
|
||||||
|
pb.write(6, msg.bloomFilter)
|
||||||
|
pb.finish()
|
||||||
|
|
||||||
|
pb
|
||||||
|
|
||||||
|
proc decode*(T: type Message, buffer: seq[byte]): ProtobufResult[T] =
|
||||||
|
let pb = initProtoBuffer(buffer)
|
||||||
|
var msg = Message()
|
||||||
|
|
||||||
|
if not ?pb.getField(1, msg.messageId):
|
||||||
|
return err(ProtobufError.missingRequiredField("messageId"))
|
||||||
|
|
||||||
|
var timestamp: uint64
|
||||||
|
if not ?pb.getField(2, timestamp):
|
||||||
|
return err(ProtobufError.missingRequiredField("lamportTimestamp"))
|
||||||
|
msg.lamportTimestamp = int64(timestamp)
|
||||||
|
|
||||||
|
# Decode causal history
|
||||||
|
var causalHistory: seq[string]
|
||||||
|
let histResult = pb.getRepeatedField(3, causalHistory)
|
||||||
|
if histResult.isOk:
|
||||||
|
msg.causalHistory = causalHistory
|
||||||
|
|
||||||
|
if not ?pb.getField(4, msg.channelId):
|
||||||
|
return err(ProtobufError.missingRequiredField("channelId"))
|
||||||
|
|
||||||
|
if not ?pb.getField(5, msg.content):
|
||||||
|
return err(ProtobufError.missingRequiredField("content"))
|
||||||
|
|
||||||
|
if not ?pb.getField(6, msg.bloomFilter):
|
||||||
|
msg.bloomFilter = @[] # Empty if not present
|
||||||
|
|
||||||
|
ok(msg)
|
||||||
|
|
||||||
|
proc serializeMessage*(msg: Message): Result[seq[byte], ReliabilityError] =
|
||||||
|
try:
|
||||||
|
let pb = encode(msg)
|
||||||
|
ok(pb.buffer)
|
||||||
|
except:
|
||||||
|
err(reSerializationError)
|
||||||
|
|
||||||
|
proc deserializeMessage*(data: seq[byte]): Result[Message, ReliabilityError] =
|
||||||
|
try:
|
||||||
|
let msgResult = Message.decode(data)
|
||||||
|
if msgResult.isOk:
|
||||||
|
ok(msgResult.get)
|
||||||
|
else:
|
||||||
|
err(reSerializationError)
|
||||||
|
except:
|
||||||
|
err(reDeserializationError)
|
||||||
|
|
||||||
|
proc serializeBloomFilter*(filter: BloomFilter): Result[seq[byte], ReliabilityError] =
|
||||||
|
try:
|
||||||
|
var pb = initProtoBuffer()
|
||||||
|
|
||||||
|
# Convert intArray to bytes
|
||||||
|
var bytes = newSeq[byte](filter.intArray.len * sizeof(int))
|
||||||
|
for i, val in filter.intArray:
|
||||||
|
let start = i * sizeof(int)
|
||||||
|
copyMem(addr bytes[start], unsafeAddr val, sizeof(int))
|
||||||
|
|
||||||
|
pb.write(1, bytes)
|
||||||
|
pb.write(2, uint64(filter.capacity))
|
||||||
|
pb.write(3, uint64(filter.errorRate * 1_000_000))
|
||||||
|
pb.write(4, uint64(filter.kHashes))
|
||||||
|
pb.write(5, uint64(filter.mBits))
|
||||||
|
|
||||||
|
pb.finish()
|
||||||
|
ok(pb.buffer)
|
||||||
|
except:
|
||||||
|
err(reSerializationError)
|
||||||
|
|
||||||
|
proc deserializeBloomFilter*(data: seq[byte]): Result[BloomFilter, ReliabilityError] =
|
||||||
|
if data.len == 0:
|
||||||
|
return err(reDeserializationError)
|
||||||
|
|
||||||
|
try:
|
||||||
|
let pb = initProtoBuffer(data)
|
||||||
|
var bytes: seq[byte]
|
||||||
|
var cap, errRate, kHashes, mBits: uint64
|
||||||
|
|
||||||
|
if not pb.getField(1, bytes).get() or
|
||||||
|
not pb.getField(2, cap).get() or
|
||||||
|
not pb.getField(3, errRate).get() or
|
||||||
|
not pb.getField(4, kHashes).get() or
|
||||||
|
not pb.getField(5, mBits).get():
|
||||||
|
return err(reDeserializationError)
|
||||||
|
|
||||||
|
# Convert bytes back to intArray
|
||||||
|
var intArray = newSeq[int](bytes.len div sizeof(int))
|
||||||
|
for i in 0 ..< intArray.len:
|
||||||
|
let start = i * sizeof(int)
|
||||||
|
copyMem(addr intArray[i], unsafeAddr bytes[start], sizeof(int))
|
||||||
|
|
||||||
|
ok(BloomFilter(
|
||||||
|
intArray: intArray,
|
||||||
|
capacity: int(cap),
|
||||||
|
errorRate: float(errRate) / 1_000_000,
|
||||||
|
kHashes: int(kHashes),
|
||||||
|
mBits: int(mBits)
|
||||||
|
))
|
||||||
|
except:
|
||||||
|
err(reDeserializationError)
|
||||||
32
src/protobufutil.nim
Normal file
32
src/protobufutil.nim
Normal file
@ -0,0 +1,32 @@
|
|||||||
|
# adapted from https://github.com/waku-org/nwaku/blob/master/waku/common/protobuf.nim
|
||||||
|
|
||||||
|
{.push raises: [].}
|
||||||
|
|
||||||
|
import libp2p/protobuf/minprotobuf
|
||||||
|
import libp2p/varint
|
||||||
|
|
||||||
|
export minprotobuf, varint
|
||||||
|
|
||||||
|
type
|
||||||
|
ProtobufErrorKind* {.pure.} = enum
|
||||||
|
DecodeFailure
|
||||||
|
MissingRequiredField
|
||||||
|
|
||||||
|
ProtobufError* = object
|
||||||
|
case kind*: ProtobufErrorKind
|
||||||
|
of DecodeFailure:
|
||||||
|
error*: minprotobuf.ProtoError
|
||||||
|
of MissingRequiredField:
|
||||||
|
field*: string
|
||||||
|
|
||||||
|
ProtobufResult*[T] = Result[T, ProtobufError]
|
||||||
|
|
||||||
|
converter toProtobufError*(err: minprotobuf.ProtoError): ProtobufError =
|
||||||
|
case err
|
||||||
|
of minprotobuf.ProtoError.RequiredFieldMissing:
|
||||||
|
ProtobufError(kind: ProtobufErrorKind.MissingRequiredField, field: "unknown")
|
||||||
|
else:
|
||||||
|
ProtobufError(kind: ProtobufErrorKind.DecodeFailure, error: err)
|
||||||
|
|
||||||
|
proc missingRequiredField*(T: type ProtobufError, field: string): T =
|
||||||
|
ProtobufError(kind: ProtobufErrorKind.MissingRequiredField, field: field)
|
||||||
95
src/reliability_utils.nim
Normal file
95
src/reliability_utils.nim
Normal file
@ -0,0 +1,95 @@
|
|||||||
|
import std/[times, locks]
|
||||||
|
import ./[rolling_bloom_filter, message]
|
||||||
|
|
||||||
|
type
|
||||||
|
PeriodicSyncCallback* = proc() {.gcsafe, raises: [].}
|
||||||
|
|
||||||
|
ReliabilityConfig* = object
|
||||||
|
bloomFilterCapacity*: int
|
||||||
|
bloomFilterErrorRate*: float
|
||||||
|
bloomFilterWindow*: times.Duration
|
||||||
|
maxMessageHistory*: int
|
||||||
|
maxCausalHistory*: int
|
||||||
|
resendInterval*: times.Duration
|
||||||
|
maxResendAttempts*: int
|
||||||
|
syncMessageInterval*: times.Duration
|
||||||
|
bufferSweepInterval*: times.Duration
|
||||||
|
|
||||||
|
ReliabilityManager* = ref object
|
||||||
|
lamportTimestamp*: int64
|
||||||
|
messageHistory*: seq[MessageID]
|
||||||
|
bloomFilter*: RollingBloomFilter
|
||||||
|
outgoingBuffer*: seq[UnacknowledgedMessage]
|
||||||
|
incomingBuffer*: seq[Message]
|
||||||
|
channelId*: string
|
||||||
|
config*: ReliabilityConfig
|
||||||
|
lock*: Lock
|
||||||
|
onMessageReady*: proc(messageId: MessageID) {.gcsafe.}
|
||||||
|
onMessageSent*: proc(messageId: MessageID) {.gcsafe.}
|
||||||
|
onMissingDependencies*: proc(messageId: MessageID, missingDeps: seq[MessageID]) {.gcsafe.}
|
||||||
|
onPeriodicSync*: PeriodicSyncCallback
|
||||||
|
|
||||||
|
ReliabilityError* = enum
|
||||||
|
reInvalidArgument
|
||||||
|
reOutOfMemory
|
||||||
|
reInternalError
|
||||||
|
reSerializationError
|
||||||
|
reDeserializationError
|
||||||
|
reMessageTooLarge
|
||||||
|
|
||||||
|
proc defaultConfig*(): ReliabilityConfig =
|
||||||
|
## Creates a default configuration for the ReliabilityManager.
|
||||||
|
##
|
||||||
|
## Returns:
|
||||||
|
## A ReliabilityConfig object with default values.
|
||||||
|
ReliabilityConfig(
|
||||||
|
bloomFilterCapacity: DefaultBloomFilterCapacity,
|
||||||
|
bloomFilterErrorRate: DefaultBloomFilterErrorRate,
|
||||||
|
bloomFilterWindow: DefaultBloomFilterWindow,
|
||||||
|
maxMessageHistory: DefaultMaxMessageHistory,
|
||||||
|
maxCausalHistory: DefaultMaxCausalHistory,
|
||||||
|
resendInterval: DefaultResendInterval,
|
||||||
|
maxResendAttempts: DefaultMaxResendAttempts,
|
||||||
|
syncMessageInterval: DefaultSyncMessageInterval,
|
||||||
|
bufferSweepInterval: DefaultBufferSweepInterval
|
||||||
|
)
|
||||||
|
|
||||||
|
proc cleanup*(rm: ReliabilityManager) {.raises: [].} =
|
||||||
|
if not rm.isNil:
|
||||||
|
{.gcsafe.}:
|
||||||
|
try:
|
||||||
|
rm.outgoingBuffer.setLen(0)
|
||||||
|
rm.incomingBuffer.setLen(0)
|
||||||
|
rm.messageHistory.setLen(0)
|
||||||
|
except Exception as e:
|
||||||
|
logError("Error during cleanup: " & e.msg)
|
||||||
|
|
||||||
|
proc cleanBloomFilter*(rm: ReliabilityManager) {.gcsafe, raises: [].} =
|
||||||
|
withLock rm.lock:
|
||||||
|
try:
|
||||||
|
rm.bloomFilter.clean()
|
||||||
|
except Exception as e:
|
||||||
|
logError("Failed to clean ReliabilityManager bloom filter: " & e.msg)
|
||||||
|
|
||||||
|
proc addToHistory*(rm: ReliabilityManager, msgId: MessageID) {.gcsafe, raises: [].} =
|
||||||
|
rm.messageHistory.add(msgId)
|
||||||
|
if rm.messageHistory.len > rm.config.maxMessageHistory:
|
||||||
|
rm.messageHistory.delete(0)
|
||||||
|
|
||||||
|
proc updateLamportTimestamp*(rm: ReliabilityManager, msgTs: int64) {.gcsafe, raises: [].} =
|
||||||
|
rm.lamportTimestamp = max(msgTs, rm.lamportTimestamp) + 1
|
||||||
|
|
||||||
|
proc getRecentMessageIDs*(rm: ReliabilityManager, n: int): seq[MessageID] =
|
||||||
|
result = rm.messageHistory[max(0, rm.messageHistory.len - n) .. ^1]
|
||||||
|
|
||||||
|
proc getMessageHistory*(rm: ReliabilityManager): seq[MessageID] =
|
||||||
|
withLock rm.lock:
|
||||||
|
result = rm.messageHistory
|
||||||
|
|
||||||
|
proc getOutgoingBuffer*(rm: ReliabilityManager): seq[UnacknowledgedMessage] =
|
||||||
|
withLock rm.lock:
|
||||||
|
result = rm.outgoingBuffer
|
||||||
|
|
||||||
|
proc getIncomingBuffer*(rm: ReliabilityManager): seq[Message] =
|
||||||
|
withLock rm.lock:
|
||||||
|
result = rm.incomingBuffer
|
||||||
95
src/rolling_bloom_filter.nim
Normal file
95
src/rolling_bloom_filter.nim
Normal file
@ -0,0 +1,95 @@
|
|||||||
|
import std/times
|
||||||
|
import chronos
|
||||||
|
import chronicles
|
||||||
|
import ./[bloom, message]
|
||||||
|
|
||||||
|
type
|
||||||
|
RollingBloomFilter* = object
|
||||||
|
filter*: BloomFilter
|
||||||
|
window*: times.Duration
|
||||||
|
messages*: seq[TimestampedMessageID]
|
||||||
|
|
||||||
|
const
|
||||||
|
DefaultBloomFilterCapacity* = 10000
|
||||||
|
DefaultBloomFilterErrorRate* = 0.001
|
||||||
|
DefaultBloomFilterWindow* = initDuration(hours = 1)
|
||||||
|
|
||||||
|
proc logError*(msg: string) =
|
||||||
|
error "ReliabilityError", message = msg
|
||||||
|
|
||||||
|
proc logInfo*(msg: string) =
|
||||||
|
info "ReliabilityInfo", message = msg
|
||||||
|
|
||||||
|
proc newRollingBloomFilter*(capacity: int, errorRate: float, window: times.Duration): RollingBloomFilter {.gcsafe.} =
|
||||||
|
try:
|
||||||
|
var filterResult: Result[BloomFilter, string]
|
||||||
|
{.gcsafe.}:
|
||||||
|
filterResult = initializeBloomFilter(capacity, errorRate)
|
||||||
|
|
||||||
|
if filterResult.isOk:
|
||||||
|
logInfo("Successfully initialized bloom filter")
|
||||||
|
return RollingBloomFilter(
|
||||||
|
filter: filterResult.get(), # Extract the BloomFilter from Result
|
||||||
|
window: window,
|
||||||
|
messages: @[]
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
logError("Failed to initialize bloom filter: " & filterResult.error)
|
||||||
|
# Fall through to default case below
|
||||||
|
|
||||||
|
except:
|
||||||
|
logError("Failed to initialize bloom filter")
|
||||||
|
|
||||||
|
# Default fallback case
|
||||||
|
let defaultResult = initializeBloomFilter(DefaultBloomFilterCapacity, DefaultBloomFilterErrorRate)
|
||||||
|
if defaultResult.isOk:
|
||||||
|
return RollingBloomFilter(
|
||||||
|
filter: defaultResult.get(),
|
||||||
|
window: window,
|
||||||
|
messages: @[]
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# If even default initialization fails, raise an exception
|
||||||
|
logError("Failed to initialize bloom filter with default parameters")
|
||||||
|
|
||||||
|
proc add*(rbf: var RollingBloomFilter, messageId: MessageID) {.gcsafe.} =
|
||||||
|
## Adds a message ID to the rolling bloom filter.
|
||||||
|
##
|
||||||
|
## Parameters:
|
||||||
|
## - messageId: The ID of the message to add.
|
||||||
|
rbf.filter.insert(messageId)
|
||||||
|
rbf.messages.add(TimestampedMessageID(id: messageId, timestamp: getTime()))
|
||||||
|
|
||||||
|
proc contains*(rbf: RollingBloomFilter, messageId: MessageID): bool {.gcsafe.} =
|
||||||
|
## Checks if a message ID is in the rolling bloom filter.
|
||||||
|
##
|
||||||
|
## Parameters:
|
||||||
|
## - messageId: The ID of the message to check.
|
||||||
|
##
|
||||||
|
## Returns:
|
||||||
|
## True if the message ID is probably in the filter, false otherwise.
|
||||||
|
rbf.filter.lookup(messageId)
|
||||||
|
|
||||||
|
proc clean*(rbf: var RollingBloomFilter) {.gcsafe.} =
|
||||||
|
try:
|
||||||
|
let now = getTime()
|
||||||
|
let cutoff = now - rbf.window
|
||||||
|
var newMessages: seq[TimestampedMessageID] = @[]
|
||||||
|
|
||||||
|
# Initialize new filter
|
||||||
|
let newFilterResult = initializeBloomFilter(rbf.filter.capacity, rbf.filter.errorRate)
|
||||||
|
if newFilterResult.isErr:
|
||||||
|
logError("Failed to create new bloom filter: " & newFilterResult.error)
|
||||||
|
return
|
||||||
|
|
||||||
|
var newFilter = newFilterResult.get()
|
||||||
|
|
||||||
|
for msg in rbf.messages:
|
||||||
|
if msg.timestamp > cutoff:
|
||||||
|
newMessages.add(msg)
|
||||||
|
newFilter.insert(msg.id)
|
||||||
|
|
||||||
|
rbf.messages = newMessages
|
||||||
|
rbf.filter = newFilter
|
||||||
|
except Exception as e:
|
||||||
|
logError("Failed to clean bloom filter: " & e.msg)
|
||||||
Loading…
x
Reference in New Issue
Block a user