feat - add rolling bloom filter, reliability utils and protobuf

This commit is contained in:
shash256 2025-01-13 17:39:26 +04:00
parent 5df71ad3ea
commit 6b0b9c34fa
6 changed files with 390 additions and 0 deletions

16
reliability.nimble Normal file
View File

@ -0,0 +1,16 @@
# Package
version = "0.1.0"
author = "Waku Team"
description = "E2E Reliability Protocol API"
license = "MIT"
srcDir = "src"
# Dependencies
requires "nim >= 2.0.8"
requires "chronicles"
requires "libp2p"
# Tasks
task test, "Run the test suite":
exec "nim c -r tests/test_bloom.nim"
exec "nim c -r tests/test_reliability.nim"

30
src/message.nim Normal file
View File

@ -0,0 +1,30 @@
import std/times
type
MessageID* = string
Message* = object
messageId*: MessageID
lamportTimestamp*: int64
causalHistory*: seq[MessageID]
channelId*: string
content*: seq[byte]
bloomFilter*: seq[byte]
UnacknowledgedMessage* = object
message*: Message
sendTime*: Time
resendAttempts*: int
TimestampedMessageID* = object
id*: MessageID
timestamp*: Time
const
DefaultMaxMessageHistory* = 1000
DefaultMaxCausalHistory* = 10
DefaultResendInterval* = initDuration(seconds = 60)
DefaultMaxResendAttempts* = 5
DefaultSyncMessageInterval* = initDuration(seconds = 30)
DefaultBufferSweepInterval* = initDuration(seconds = 60)
MaxMessageSize* = 1024 * 1024 # 1 MB

122
src/protobuf.nim Normal file
View File

@ -0,0 +1,122 @@
import libp2p/protobuf/minprotobuf
import std/options
import ../src/[message, protobufutil, bloom, reliability_utils]
proc toBytes(s: string): seq[byte] =
result = newSeq[byte](s.len)
copyMem(result[0].addr, s[0].unsafeAddr, s.len)
proc encode*(msg: Message): ProtoBuffer =
var pb = initProtoBuffer()
pb.write(1, msg.messageId)
pb.write(2, uint64(msg.lamportTimestamp))
for hist in msg.causalHistory:
pb.write(3, hist.toBytes) # Convert string to bytes for proper length handling
pb.write(4, msg.channelId)
pb.write(5, msg.content)
pb.write(6, msg.bloomFilter)
pb.finish()
pb
proc decode*(T: type Message, buffer: seq[byte]): ProtobufResult[T] =
let pb = initProtoBuffer(buffer)
var msg = Message()
if not ?pb.getField(1, msg.messageId):
return err(ProtobufError.missingRequiredField("messageId"))
var timestamp: uint64
if not ?pb.getField(2, timestamp):
return err(ProtobufError.missingRequiredField("lamportTimestamp"))
msg.lamportTimestamp = int64(timestamp)
# Decode causal history
var causalHistory: seq[string]
let histResult = pb.getRepeatedField(3, causalHistory)
if histResult.isOk:
msg.causalHistory = causalHistory
if not ?pb.getField(4, msg.channelId):
return err(ProtobufError.missingRequiredField("channelId"))
if not ?pb.getField(5, msg.content):
return err(ProtobufError.missingRequiredField("content"))
if not ?pb.getField(6, msg.bloomFilter):
msg.bloomFilter = @[] # Empty if not present
ok(msg)
proc serializeMessage*(msg: Message): Result[seq[byte], ReliabilityError] =
try:
let pb = encode(msg)
ok(pb.buffer)
except:
err(reSerializationError)
proc deserializeMessage*(data: seq[byte]): Result[Message, ReliabilityError] =
try:
let msgResult = Message.decode(data)
if msgResult.isOk:
ok(msgResult.get)
else:
err(reSerializationError)
except:
err(reDeserializationError)
proc serializeBloomFilter*(filter: BloomFilter): Result[seq[byte], ReliabilityError] =
try:
var pb = initProtoBuffer()
# Convert intArray to bytes
var bytes = newSeq[byte](filter.intArray.len * sizeof(int))
for i, val in filter.intArray:
let start = i * sizeof(int)
copyMem(addr bytes[start], unsafeAddr val, sizeof(int))
pb.write(1, bytes)
pb.write(2, uint64(filter.capacity))
pb.write(3, uint64(filter.errorRate * 1_000_000))
pb.write(4, uint64(filter.kHashes))
pb.write(5, uint64(filter.mBits))
pb.finish()
ok(pb.buffer)
except:
err(reSerializationError)
proc deserializeBloomFilter*(data: seq[byte]): Result[BloomFilter, ReliabilityError] =
if data.len == 0:
return err(reDeserializationError)
try:
let pb = initProtoBuffer(data)
var bytes: seq[byte]
var cap, errRate, kHashes, mBits: uint64
if not pb.getField(1, bytes).get() or
not pb.getField(2, cap).get() or
not pb.getField(3, errRate).get() or
not pb.getField(4, kHashes).get() or
not pb.getField(5, mBits).get():
return err(reDeserializationError)
# Convert bytes back to intArray
var intArray = newSeq[int](bytes.len div sizeof(int))
for i in 0 ..< intArray.len:
let start = i * sizeof(int)
copyMem(addr intArray[i], unsafeAddr bytes[start], sizeof(int))
ok(BloomFilter(
intArray: intArray,
capacity: int(cap),
errorRate: float(errRate) / 1_000_000,
kHashes: int(kHashes),
mBits: int(mBits)
))
except:
err(reDeserializationError)

32
src/protobufutil.nim Normal file
View File

@ -0,0 +1,32 @@
# adapted from https://github.com/waku-org/nwaku/blob/master/waku/common/protobuf.nim
{.push raises: [].}
import libp2p/protobuf/minprotobuf
import libp2p/varint
export minprotobuf, varint
type
ProtobufErrorKind* {.pure.} = enum
DecodeFailure
MissingRequiredField
ProtobufError* = object
case kind*: ProtobufErrorKind
of DecodeFailure:
error*: minprotobuf.ProtoError
of MissingRequiredField:
field*: string
ProtobufResult*[T] = Result[T, ProtobufError]
converter toProtobufError*(err: minprotobuf.ProtoError): ProtobufError =
case err
of minprotobuf.ProtoError.RequiredFieldMissing:
ProtobufError(kind: ProtobufErrorKind.MissingRequiredField, field: "unknown")
else:
ProtobufError(kind: ProtobufErrorKind.DecodeFailure, error: err)
proc missingRequiredField*(T: type ProtobufError, field: string): T =
ProtobufError(kind: ProtobufErrorKind.MissingRequiredField, field: field)

95
src/reliability_utils.nim Normal file
View File

@ -0,0 +1,95 @@
import std/[times, locks]
import ./[rolling_bloom_filter, message]
type
PeriodicSyncCallback* = proc() {.gcsafe, raises: [].}
ReliabilityConfig* = object
bloomFilterCapacity*: int
bloomFilterErrorRate*: float
bloomFilterWindow*: times.Duration
maxMessageHistory*: int
maxCausalHistory*: int
resendInterval*: times.Duration
maxResendAttempts*: int
syncMessageInterval*: times.Duration
bufferSweepInterval*: times.Duration
ReliabilityManager* = ref object
lamportTimestamp*: int64
messageHistory*: seq[MessageID]
bloomFilter*: RollingBloomFilter
outgoingBuffer*: seq[UnacknowledgedMessage]
incomingBuffer*: seq[Message]
channelId*: string
config*: ReliabilityConfig
lock*: Lock
onMessageReady*: proc(messageId: MessageID) {.gcsafe.}
onMessageSent*: proc(messageId: MessageID) {.gcsafe.}
onMissingDependencies*: proc(messageId: MessageID, missingDeps: seq[MessageID]) {.gcsafe.}
onPeriodicSync*: PeriodicSyncCallback
ReliabilityError* = enum
reInvalidArgument
reOutOfMemory
reInternalError
reSerializationError
reDeserializationError
reMessageTooLarge
proc defaultConfig*(): ReliabilityConfig =
## Creates a default configuration for the ReliabilityManager.
##
## Returns:
## A ReliabilityConfig object with default values.
ReliabilityConfig(
bloomFilterCapacity: DefaultBloomFilterCapacity,
bloomFilterErrorRate: DefaultBloomFilterErrorRate,
bloomFilterWindow: DefaultBloomFilterWindow,
maxMessageHistory: DefaultMaxMessageHistory,
maxCausalHistory: DefaultMaxCausalHistory,
resendInterval: DefaultResendInterval,
maxResendAttempts: DefaultMaxResendAttempts,
syncMessageInterval: DefaultSyncMessageInterval,
bufferSweepInterval: DefaultBufferSweepInterval
)
proc cleanup*(rm: ReliabilityManager) {.raises: [].} =
if not rm.isNil:
{.gcsafe.}:
try:
rm.outgoingBuffer.setLen(0)
rm.incomingBuffer.setLen(0)
rm.messageHistory.setLen(0)
except Exception as e:
logError("Error during cleanup: " & e.msg)
proc cleanBloomFilter*(rm: ReliabilityManager) {.gcsafe, raises: [].} =
withLock rm.lock:
try:
rm.bloomFilter.clean()
except Exception as e:
logError("Failed to clean ReliabilityManager bloom filter: " & e.msg)
proc addToHistory*(rm: ReliabilityManager, msgId: MessageID) {.gcsafe, raises: [].} =
rm.messageHistory.add(msgId)
if rm.messageHistory.len > rm.config.maxMessageHistory:
rm.messageHistory.delete(0)
proc updateLamportTimestamp*(rm: ReliabilityManager, msgTs: int64) {.gcsafe, raises: [].} =
rm.lamportTimestamp = max(msgTs, rm.lamportTimestamp) + 1
proc getRecentMessageIDs*(rm: ReliabilityManager, n: int): seq[MessageID] =
result = rm.messageHistory[max(0, rm.messageHistory.len - n) .. ^1]
proc getMessageHistory*(rm: ReliabilityManager): seq[MessageID] =
withLock rm.lock:
result = rm.messageHistory
proc getOutgoingBuffer*(rm: ReliabilityManager): seq[UnacknowledgedMessage] =
withLock rm.lock:
result = rm.outgoingBuffer
proc getIncomingBuffer*(rm: ReliabilityManager): seq[Message] =
withLock rm.lock:
result = rm.incomingBuffer

View File

@ -0,0 +1,95 @@
import std/times
import chronos
import chronicles
import ./[bloom, message]
type
RollingBloomFilter* = object
filter*: BloomFilter
window*: times.Duration
messages*: seq[TimestampedMessageID]
const
DefaultBloomFilterCapacity* = 10000
DefaultBloomFilterErrorRate* = 0.001
DefaultBloomFilterWindow* = initDuration(hours = 1)
proc logError*(msg: string) =
error "ReliabilityError", message = msg
proc logInfo*(msg: string) =
info "ReliabilityInfo", message = msg
proc newRollingBloomFilter*(capacity: int, errorRate: float, window: times.Duration): RollingBloomFilter {.gcsafe.} =
try:
var filterResult: Result[BloomFilter, string]
{.gcsafe.}:
filterResult = initializeBloomFilter(capacity, errorRate)
if filterResult.isOk:
logInfo("Successfully initialized bloom filter")
return RollingBloomFilter(
filter: filterResult.get(), # Extract the BloomFilter from Result
window: window,
messages: @[]
)
else:
logError("Failed to initialize bloom filter: " & filterResult.error)
# Fall through to default case below
except:
logError("Failed to initialize bloom filter")
# Default fallback case
let defaultResult = initializeBloomFilter(DefaultBloomFilterCapacity, DefaultBloomFilterErrorRate)
if defaultResult.isOk:
return RollingBloomFilter(
filter: defaultResult.get(),
window: window,
messages: @[]
)
else:
# If even default initialization fails, raise an exception
logError("Failed to initialize bloom filter with default parameters")
proc add*(rbf: var RollingBloomFilter, messageId: MessageID) {.gcsafe.} =
## Adds a message ID to the rolling bloom filter.
##
## Parameters:
## - messageId: The ID of the message to add.
rbf.filter.insert(messageId)
rbf.messages.add(TimestampedMessageID(id: messageId, timestamp: getTime()))
proc contains*(rbf: RollingBloomFilter, messageId: MessageID): bool {.gcsafe.} =
## Checks if a message ID is in the rolling bloom filter.
##
## Parameters:
## - messageId: The ID of the message to check.
##
## Returns:
## True if the message ID is probably in the filter, false otherwise.
rbf.filter.lookup(messageId)
proc clean*(rbf: var RollingBloomFilter) {.gcsafe.} =
try:
let now = getTime()
let cutoff = now - rbf.window
var newMessages: seq[TimestampedMessageID] = @[]
# Initialize new filter
let newFilterResult = initializeBloomFilter(rbf.filter.capacity, rbf.filter.errorRate)
if newFilterResult.isErr:
logError("Failed to create new bloom filter: " & newFilterResult.error)
return
var newFilter = newFilterResult.get()
for msg in rbf.messages:
if msg.timestamp > cutoff:
newMessages.add(msg)
newFilter.insert(msg.id)
rbf.messages = newMessages
rbf.filter = newFilter
except Exception as e:
logError("Failed to clean bloom filter: " & e.msg)