mirror of
https://github.com/logos-messaging/nim-sds.git
synced 2026-01-02 14:13:07 +00:00
chore: some fixes
This commit is contained in:
parent
cb1f40c9c2
commit
db190d914e
3
.gitignore
vendored
Normal file
3
.gitignore
vendored
Normal file
@ -0,0 +1,3 @@
|
||||
|
||||
.DS_Store
|
||||
tests/test_reliability
|
||||
@ -7,7 +7,7 @@ import private/probabilities
|
||||
{.compile: "murmur3.c".}
|
||||
|
||||
type
|
||||
BloomFilterError = object of CatchableError
|
||||
BloomFilterError* = object of CatchableError
|
||||
MurmurHashes = array[0..1, int]
|
||||
BloomFilter* = object
|
||||
capacity*: int
|
||||
|
||||
@ -8,6 +8,7 @@ srcDir = "src"
|
||||
# Dependencies
|
||||
requires "nim >= 2.0.8"
|
||||
requires "chronicles"
|
||||
requires "libp2p"
|
||||
|
||||
# Tasks
|
||||
task test, "Run the test suite":
|
||||
|
||||
@ -1,16 +1,15 @@
|
||||
import std/[times, json, locks]
|
||||
import std/[times, locks]
|
||||
import "../nim-bloom/src/bloom"
|
||||
|
||||
type
|
||||
MessageID* = string
|
||||
|
||||
Message* = object
|
||||
senderId*: string
|
||||
messageId*: MessageID
|
||||
lamportTimestamp*: int64
|
||||
causalHistory*: seq[MessageID]
|
||||
channelId*: string
|
||||
content*: string
|
||||
content*: seq[byte]
|
||||
|
||||
UnacknowledgedMessage* = object
|
||||
message*: Message
|
||||
@ -21,18 +20,20 @@ type
|
||||
id*: MessageID
|
||||
timestamp*: Time
|
||||
|
||||
PeriodicSyncCallback* = proc() {.gcsafe, raises: [].}
|
||||
|
||||
RollingBloomFilter* = object
|
||||
filter*: BloomFilter
|
||||
window*: Duration
|
||||
window*: times.Duration
|
||||
messages*: seq[TimestampedMessageID]
|
||||
|
||||
ReliabilityConfig* = object
|
||||
bloomFilterCapacity*: int
|
||||
bloomFilterErrorRate*: float
|
||||
bloomFilterWindow*: Duration
|
||||
bloomFilterWindow*: times.Duration
|
||||
maxMessageHistory*: int
|
||||
maxCausalHistory*: int
|
||||
resendInterval*: Duration
|
||||
resendInterval*: times.Duration
|
||||
maxResendAttempts*: int
|
||||
|
||||
ReliabilityManager* = ref object
|
||||
@ -44,26 +45,19 @@ type
|
||||
channelId*: string
|
||||
config*: ReliabilityConfig
|
||||
lock*: Lock
|
||||
onMessageReady*: proc(messageId: MessageID)
|
||||
onMessageSent*: proc(messageId: MessageID)
|
||||
onMissingDependencies*: proc(messageId: MessageID, missingDeps: seq[MessageID])
|
||||
onMessageReady*: proc(messageId: MessageID) {.gcsafe.}
|
||||
onMessageSent*: proc(messageId: MessageID) {.gcsafe.}
|
||||
onMissingDependencies*: proc(messageId: MessageID, missingDeps: seq[MessageID]) {.gcsafe.}
|
||||
onPeriodicSync*: PeriodicSyncCallback
|
||||
|
||||
ReliabilityError* = enum
|
||||
reSuccess,
|
||||
reInvalidArgument,
|
||||
reOutOfMemory,
|
||||
reInternalError,
|
||||
reSerializationError,
|
||||
reDeserializationError,
|
||||
reInvalidArgument
|
||||
reOutOfMemory
|
||||
reInternalError
|
||||
reSerializationError
|
||||
reDeserializationError
|
||||
reMessageTooLarge
|
||||
|
||||
Result*[T] = object
|
||||
case isOk*: bool
|
||||
of true:
|
||||
value*: T
|
||||
of false:
|
||||
error*: ReliabilityError
|
||||
|
||||
const
|
||||
DefaultBloomFilterCapacity* = 10000
|
||||
DefaultBloomFilterErrorRate* = 0.001
|
||||
@ -72,10 +66,4 @@ const
|
||||
DefaultMaxCausalHistory* = 10
|
||||
DefaultResendInterval* = initDuration(seconds = 30)
|
||||
DefaultMaxResendAttempts* = 5
|
||||
MaxMessageSize* = 1024 * 1024 # 1 MB
|
||||
|
||||
proc ok*[T](value: T): Result[T] =
|
||||
Result[T](isOk: true, value: value)
|
||||
|
||||
proc err*[T](error: ReliabilityError): Result[T] =
|
||||
Result[T](isOk: false, error: error)
|
||||
MaxMessageSize* = 1024 * 1024 # 1 MB
|
||||
58
src/protobuf.nim
Normal file
58
src/protobuf.nim
Normal file
@ -0,0 +1,58 @@
|
||||
import ./protobufutil
|
||||
import ./common
|
||||
import libp2p/protobuf/minprotobuf
|
||||
import std/options
|
||||
|
||||
proc encode*(msg: Message): ProtoBuffer =
|
||||
var pb = initProtoBuffer()
|
||||
|
||||
pb.write(1, msg.messageId)
|
||||
pb.write(2, uint64(msg.lamportTimestamp))
|
||||
for hist in msg.causalHistory:
|
||||
pb.write(3, hist)
|
||||
pb.write(4, msg.channelId)
|
||||
pb.write(5, msg.content)
|
||||
pb.finish()
|
||||
|
||||
pb
|
||||
|
||||
proc decode*(T: type Message, buffer: seq[byte]): ProtobufResult[T] =
|
||||
let pb = initProtoBuffer(buffer)
|
||||
var msg = Message()
|
||||
|
||||
if not ?pb.getField(1, msg.messageId):
|
||||
return err(ProtobufError.missingRequiredField("messageId"))
|
||||
|
||||
var timestamp: uint64
|
||||
if not ?pb.getField(2, timestamp):
|
||||
return err(ProtobufError.missingRequiredField("lamportTimestamp"))
|
||||
msg.lamportTimestamp = int64(timestamp)
|
||||
|
||||
var hist: string
|
||||
while ?pb.getField(3, hist):
|
||||
msg.causalHistory.add(hist)
|
||||
|
||||
if not ?pb.getField(4, msg.channelId):
|
||||
return err(ProtobufError.missingRequiredField("channelId"))
|
||||
|
||||
if not ?pb.getField(5, msg.content):
|
||||
return err(ProtobufError.missingRequiredField("content"))
|
||||
|
||||
ok(msg)
|
||||
|
||||
proc serializeMessage*(msg: Message): Result[seq[byte], ReliabilityError] =
|
||||
try:
|
||||
let pb = encode(msg)
|
||||
ok(pb.buffer)
|
||||
except:
|
||||
err(reSerializationError)
|
||||
|
||||
proc deserializeMessage*(data: seq[byte]): Result[Message, ReliabilityError] =
|
||||
try:
|
||||
let msgResult = Message.decode(data)
|
||||
if msgResult.isOk:
|
||||
ok(msgResult.get)
|
||||
else:
|
||||
err(reSerializationError)
|
||||
except:
|
||||
err(reDeserializationError)
|
||||
36
src/protobufutil.nim
Normal file
36
src/protobufutil.nim
Normal file
@ -0,0 +1,36 @@
|
||||
# adapted from https://github.com/waku-org/nwaku/blob/master/waku/common/protobuf.nim
|
||||
|
||||
{.push raises: [].}
|
||||
|
||||
import libp2p/protobuf/minprotobuf
|
||||
import libp2p/varint
|
||||
|
||||
export minprotobuf, varint
|
||||
|
||||
type
|
||||
ProtobufErrorKind* {.pure.} = enum
|
||||
DecodeFailure
|
||||
MissingRequiredField
|
||||
InvalidLengthField
|
||||
|
||||
ProtobufError* = object
|
||||
case kind*: ProtobufErrorKind
|
||||
of DecodeFailure:
|
||||
error*: minprotobuf.ProtoError
|
||||
of MissingRequiredField, InvalidLengthField:
|
||||
field*: string
|
||||
|
||||
ProtobufResult*[T] = Result[T, ProtobufError]
|
||||
|
||||
converter toProtobufError*(err: minprotobuf.ProtoError): ProtobufError =
|
||||
case err
|
||||
of minprotobuf.ProtoError.RequiredFieldMissing:
|
||||
ProtobufError(kind: ProtobufErrorKind.MissingRequiredField, field: "unknown")
|
||||
else:
|
||||
ProtobufError(kind: ProtobufErrorKind.DecodeFailure, error: err)
|
||||
|
||||
proc missingRequiredField*(T: type ProtobufError, field: string): T =
|
||||
ProtobufError(kind: ProtobufErrorKind.MissingRequiredField, field: field)
|
||||
|
||||
proc invalidLengthField*(T: type ProtobufError, field: string): T =
|
||||
ProtobufError(kind: ProtobufErrorKind.InvalidLengthField, field: field)
|
||||
@ -1,4 +1,8 @@
|
||||
import ./common, ./utils
|
||||
import std/[times, locks]
|
||||
import chronos, results
|
||||
import ./common
|
||||
import ./utils
|
||||
import ./protobuf
|
||||
|
||||
proc defaultConfig*(): ReliabilityConfig =
|
||||
## Creates a default configuration for the ReliabilityManager.
|
||||
@ -15,7 +19,7 @@ proc defaultConfig*(): ReliabilityConfig =
|
||||
maxResendAttempts: DefaultMaxResendAttempts
|
||||
)
|
||||
|
||||
proc newReliabilityManager*(channelId: string, config: ReliabilityConfig = defaultConfig()): Result[ReliabilityManager] =
|
||||
proc newReliabilityManager*(channelId: string, config: ReliabilityConfig = defaultConfig()): Result[ReliabilityManager, ReliabilityError] =
|
||||
## Creates a new ReliabilityManager with the specified channel ID and configuration.
|
||||
##
|
||||
## Parameters:
|
||||
@ -25,17 +29,19 @@ proc newReliabilityManager*(channelId: string, config: ReliabilityConfig = defau
|
||||
## Returns:
|
||||
## A Result containing either a new ReliabilityManager instance or an error.
|
||||
if channelId.len == 0:
|
||||
return err[ReliabilityManager](reInvalidArgument)
|
||||
return err(reInvalidArgument)
|
||||
|
||||
try:
|
||||
let bloomFilterResult = newRollingBloomFilter(config.bloomFilterCapacity, config.bloomFilterErrorRate, config.bloomFilterWindow)
|
||||
if bloomFilterResult.isErr:
|
||||
return err[ReliabilityManager](bloomFilterResult.error)
|
||||
|
||||
let bloomFilter = newRollingBloomFilter(
|
||||
config.bloomFilterCapacity,
|
||||
config.bloomFilterErrorRate,
|
||||
config.bloomFilterWindow
|
||||
)
|
||||
|
||||
let rm = ReliabilityManager(
|
||||
lamportTimestamp: 0,
|
||||
messageHistory: @[],
|
||||
bloomFilter: bloomFilterResult.value,
|
||||
bloomFilter: bloomFilter,
|
||||
outgoingBuffer: @[],
|
||||
incomingBuffer: @[],
|
||||
channelId: channelId,
|
||||
@ -44,9 +50,9 @@ proc newReliabilityManager*(channelId: string, config: ReliabilityConfig = defau
|
||||
initLock(rm.lock)
|
||||
return ok(rm)
|
||||
except:
|
||||
return err[ReliabilityManager](reOutOfMemory)
|
||||
return err(reOutOfMemory)
|
||||
|
||||
proc wrapOutgoingMessage*(rm: ReliabilityManager, message: seq[byte]): Result[seq[byte]] =
|
||||
proc wrapOutgoingMessage*(rm: ReliabilityManager, message: seq[byte], messageId: MessageID): Result[seq[byte], ReliabilityError] =
|
||||
## Wraps an outgoing message with reliability metadata.
|
||||
##
|
||||
## Parameters:
|
||||
@ -55,15 +61,14 @@ proc wrapOutgoingMessage*(rm: ReliabilityManager, message: seq[byte]): Result[se
|
||||
## Returns:
|
||||
## A Result containing either a Message object with reliability metadata or an error.
|
||||
if message.len == 0:
|
||||
return err[Message](reInvalidArgument)
|
||||
return err(reInvalidArgument)
|
||||
if message.len > MaxMessageSize:
|
||||
return err[Message](reMessageTooLarge)
|
||||
return err(reMessageTooLarge)
|
||||
|
||||
withLock rm.lock:
|
||||
try:
|
||||
let msg = Message(
|
||||
senderId: "TODO_SENDER_ID",
|
||||
messageId: generateUniqueID(),
|
||||
messageId: messageId,
|
||||
lamportTimestamp: rm.lamportTimestamp,
|
||||
causalHistory: rm.getRecentMessageIDs(rm.config.maxCausalHistory),
|
||||
channelId: rm.channelId,
|
||||
@ -71,11 +76,11 @@ proc wrapOutgoingMessage*(rm: ReliabilityManager, message: seq[byte]): Result[se
|
||||
)
|
||||
rm.updateLamportTimestamp(getTime().toUnix)
|
||||
rm.outgoingBuffer.add(UnacknowledgedMessage(message: msg, sendTime: getTime(), resendAttempts: 0))
|
||||
return ok(msg)
|
||||
return serializeMessage(msg)
|
||||
except:
|
||||
return err[Message](reInternalError)
|
||||
return err(reInternalError)
|
||||
|
||||
proc unwrapReceivedMessage*(rm: ReliabilityManager, message: seq[byte]): Result[tuple[message: seq[byte], missingDeps: seq[MessageID]]] =
|
||||
proc unwrapReceivedMessage*(rm: ReliabilityManager, message: seq[byte]): Result[tuple[message: seq[byte], missingDeps: seq[MessageID]], ReliabilityError] =
|
||||
## Unwraps a received message and processes its reliability metadata.
|
||||
##
|
||||
## Parameters:
|
||||
@ -85,33 +90,38 @@ proc unwrapReceivedMessage*(rm: ReliabilityManager, message: seq[byte]): Result[
|
||||
## A Result containing either a tuple with the processed message and missing dependencies, or an error.
|
||||
withLock rm.lock:
|
||||
try:
|
||||
if rm.bloomFilter.contains(message.messageId):
|
||||
return ok((message, @[]))
|
||||
let msgResult = deserializeMessage(message)
|
||||
if not msgResult.isOk:
|
||||
return err(msgResult.error)
|
||||
|
||||
let msg = msgResult.get
|
||||
if rm.bloomFilter.contains(msg.messageId):
|
||||
return ok((msg.content, @[]))
|
||||
|
||||
rm.bloomFilter.add(message.messageId)
|
||||
rm.updateLamportTimestamp(message.lamportTimestamp)
|
||||
rm.bloomFilter.add(msg.messageId)
|
||||
rm.updateLamportTimestamp(msg.lamportTimestamp)
|
||||
|
||||
var missingDeps: seq[MessageID] = @[]
|
||||
for depId in message.causalHistory:
|
||||
for depId in msg.causalHistory:
|
||||
if not rm.bloomFilter.contains(depId):
|
||||
missingDeps.add(depId)
|
||||
|
||||
if missingDeps.len == 0:
|
||||
rm.messageHistory.add(message.messageId)
|
||||
rm.messageHistory.add(msg.messageId)
|
||||
if rm.messageHistory.len > rm.config.maxMessageHistory:
|
||||
rm.messageHistory.delete(0)
|
||||
if rm.onMessageReady != nil:
|
||||
rm.onMessageReady(message.messageId)
|
||||
rm.onMessageReady(msg.messageId)
|
||||
else:
|
||||
rm.incomingBuffer.add(message)
|
||||
rm.incomingBuffer.add(msg)
|
||||
if rm.onMissingDependencies != nil:
|
||||
rm.onMissingDependencies(message.messageId, missingDeps)
|
||||
rm.onMissingDependencies(msg.messageId, missingDeps)
|
||||
|
||||
return ok((message, missingDeps))
|
||||
return ok((msg.content, missingDeps))
|
||||
except:
|
||||
return err[(Message, seq[MessageID])](reInternalError)
|
||||
return err(reInternalError)
|
||||
|
||||
proc markDependenciesMet*(rm: ReliabilityManager, messageIds: seq[MessageID]): Result[void] =
|
||||
proc markDependenciesMet*(rm: ReliabilityManager, messageIds: seq[MessageID]): Result[void, ReliabilityError] =
|
||||
## Marks the specified message dependencies as met.
|
||||
##
|
||||
## Parameters:
|
||||
@ -122,9 +132,21 @@ proc markDependenciesMet*(rm: ReliabilityManager, messageIds: seq[MessageID]): R
|
||||
withLock rm.lock:
|
||||
try:
|
||||
var processedMessages: seq[Message] = @[]
|
||||
rm.incomingBuffer = rm.incomingBuffer.filterIt(
|
||||
not messageIds.allIt(it in it.causalHistory or rm.bloomFilter.contains(it))
|
||||
)
|
||||
var newIncomingBuffer: seq[Message] = @[]
|
||||
|
||||
for msg in rm.incomingBuffer:
|
||||
var allDependenciesMet = true
|
||||
for depId in msg.causalHistory:
|
||||
if depId notin messageIds and not rm.bloomFilter.contains(depId):
|
||||
allDependenciesMet = false
|
||||
break
|
||||
|
||||
if allDependenciesMet:
|
||||
processedMessages.add(msg)
|
||||
else:
|
||||
newIncomingBuffer.add(msg)
|
||||
|
||||
rm.incomingBuffer = newIncomingBuffer
|
||||
|
||||
for msg in processedMessages:
|
||||
rm.messageHistory.add(msg.messageId)
|
||||
@ -135,72 +157,80 @@ proc markDependenciesMet*(rm: ReliabilityManager, messageIds: seq[MessageID]): R
|
||||
|
||||
return ok()
|
||||
except:
|
||||
return err[void](reInternalError)
|
||||
return err(reInternalError)
|
||||
|
||||
proc setCallbacks*(rm: ReliabilityManager,
|
||||
onMessageReady: proc(messageId: MessageID),
|
||||
onMessageSent: proc(messageId: MessageID),
|
||||
onMissingDependencies: proc(messageId: MessageID, missingDeps: seq[MessageID])) =
|
||||
onMessageReady: proc(messageId: MessageID) {.gcsafe.},
|
||||
onMessageSent: proc(messageId: MessageID) {.gcsafe.},
|
||||
onMissingDependencies: proc(messageId: MessageID, missingDeps: seq[MessageID]) {.gcsafe.},
|
||||
onPeriodicSync: PeriodicSyncCallback = nil) =
|
||||
## Sets the callback functions for various events in the ReliabilityManager.
|
||||
##
|
||||
## Parameters:
|
||||
## - onMessageReady: Callback function called when a message is ready to be processed.
|
||||
## - onMessageSent: Callback function called when a message is confirmed as sent.
|
||||
## - onMissingDependencies: Callback function called when a message has missing dependencies.
|
||||
## - onPeriodicSync: Callback function called to notify about periodic sync
|
||||
withLock rm.lock:
|
||||
rm.onMessageReady = onMessageReady
|
||||
rm.onMessageSent = onMessageSent
|
||||
rm.onMissingDependencies = onMissingDependencies
|
||||
rm.onPeriodicSync = onPeriodicSync
|
||||
|
||||
proc checkUnacknowledgedMessages*(rm: ReliabilityManager) =
|
||||
proc checkUnacknowledgedMessages*(rm: ReliabilityManager) {.raises: [].} =
|
||||
## Checks and processes unacknowledged messages in the outgoing buffer.
|
||||
withLock rm.lock:
|
||||
let now = getTime()
|
||||
var newOutgoingBuffer: seq[UnacknowledgedMessage] = @[]
|
||||
for msg in rm.outgoingBuffer:
|
||||
if (now - msg.sendTime) < rm.config.resendInterval:
|
||||
newOutgoingBuffer.add(msg)
|
||||
elif msg.resendAttempts < rm.config.maxResendAttempts:
|
||||
# Resend the message
|
||||
msg.resendAttempts += 1
|
||||
msg.sendTime = now
|
||||
newOutgoingBuffer.add(msg)
|
||||
# Here you would actually resend the message
|
||||
elif rm.onMessageSent != nil:
|
||||
rm.onMessageSent(msg.message.messageId)
|
||||
rm.outgoingBuffer = newOutgoingBuffer
|
||||
|
||||
try:
|
||||
for msg in rm.outgoingBuffer:
|
||||
if (now - msg.sendTime) < rm.config.resendInterval:
|
||||
newOutgoingBuffer.add(msg)
|
||||
elif msg.resendAttempts < rm.config.maxResendAttempts:
|
||||
var updatedMsg = msg
|
||||
updatedMsg.resendAttempts += 1
|
||||
updatedMsg.sendTime = now
|
||||
newOutgoingBuffer.add(updatedMsg)
|
||||
elif rm.onMessageSent != nil:
|
||||
rm.onMessageSent(msg.message.messageId)
|
||||
|
||||
rm.outgoingBuffer = newOutgoingBuffer
|
||||
except:
|
||||
discard
|
||||
|
||||
proc periodicBufferSweep(rm: ReliabilityManager) {.async.} =
|
||||
## Periodically sweeps the buffer to clean up and resend messages.
|
||||
proc periodicBufferSweep(rm: ReliabilityManager) {.async: (raises: [CancelledError]).} =
|
||||
## Periodically sweeps the buffer to clean up and check unacknowledged messages.
|
||||
##
|
||||
## This is an internal function and should not be called directly.
|
||||
while true:
|
||||
rm.checkUnacknowledgedMessages()
|
||||
rm.cleanBloomFilter()
|
||||
await sleepAsync(5000) # Sleep for 5 seconds
|
||||
{.gcsafe.}:
|
||||
try:
|
||||
rm.checkUnacknowledgedMessages()
|
||||
rm.cleanBloomFilter()
|
||||
except Exception as e:
|
||||
logError("Error in periodic buffer sweep: " & e.msg)
|
||||
await sleepAsync(chronos.seconds(5))
|
||||
|
||||
proc periodicSyncMessage(rm: ReliabilityManager) {.async.} =
|
||||
## Periodically sends a sync message to maintain connectivity.
|
||||
##
|
||||
## This is an internal function and should not be called directly.
|
||||
proc periodicSyncMessage(rm: ReliabilityManager) {.async: (raises: [CancelledError]).} =
|
||||
## Periodically notifies to send a sync message to maintain connectivity.
|
||||
while true:
|
||||
discard rm.wrapOutgoingMessage("") # Empty content for sync messages
|
||||
await sleepAsync(30000) # Sleep for 30 seconds
|
||||
{.gcsafe.}:
|
||||
try:
|
||||
if rm.onPeriodicSync != nil:
|
||||
rm.onPeriodicSync()
|
||||
except Exception as e:
|
||||
logError("Error in periodic sync: " & e.msg)
|
||||
await sleepAsync(chronos.seconds(30))
|
||||
|
||||
proc startPeriodicTasks*(rm: ReliabilityManager) =
|
||||
## Starts the periodic tasks for buffer sweeping and sync message sending.
|
||||
##
|
||||
## This procedure should be called after creating a ReliabilityManager to enable automatic maintenance.
|
||||
asyncCheck rm.periodicBufferSweep()
|
||||
asyncCheck rm.periodicSyncMessage()
|
||||
asyncSpawn rm.periodicBufferSweep()
|
||||
asyncSpawn rm.periodicSyncMessage()
|
||||
|
||||
# # To demonstrate how to use the ReliabilityManager
|
||||
# proc processMessage*(rm: ReliabilityManager, message: string): seq[MessageID] =
|
||||
# let wrappedMsg = checkAndLogError(rm.wrapOutgoingMessage(message), "Failed to wrap message")
|
||||
# let (_, missingDeps) = checkAndLogError(rm.unwrapReceivedMessage(wrappedMsg), "Failed to unwrap message")
|
||||
# return missingDeps
|
||||
|
||||
proc resetReliabilityManager*(rm: ReliabilityManager): Result[void] =
|
||||
proc resetReliabilityManager*(rm: ReliabilityManager): Result[void, ReliabilityError] =
|
||||
## Resets the ReliabilityManager to its initial state.
|
||||
##
|
||||
## This procedure clears all buffers and resets the Lamport timestamp.
|
||||
@ -208,49 +238,26 @@ proc resetReliabilityManager*(rm: ReliabilityManager): Result[void] =
|
||||
## Returns:
|
||||
## A Result indicating success or an error if the Bloom filter initialization fails.
|
||||
withLock rm.lock:
|
||||
let bloomFilterResult = newRollingBloomFilter(rm.config.bloomFilterCapacity, rm.config.bloomFilterErrorRate, rm.config.bloomFilterWindow)
|
||||
if bloomFilterResult.isErr:
|
||||
return err[void](bloomFilterResult.error)
|
||||
try:
|
||||
rm.lamportTimestamp = 0
|
||||
rm.messageHistory.setLen(0)
|
||||
rm.outgoingBuffer.setLen(0)
|
||||
rm.incomingBuffer.setLen(0)
|
||||
rm.bloomFilter = newRollingBloomFilter(
|
||||
rm.config.bloomFilterCapacity,
|
||||
rm.config.bloomFilterErrorRate,
|
||||
rm.config.bloomFilterWindow
|
||||
)
|
||||
return ok()
|
||||
except:
|
||||
return err(reInternalError)
|
||||
|
||||
rm.lamportTimestamp = 0
|
||||
rm.messageHistory.setLen(0)
|
||||
rm.outgoingBuffer.setLen(0)
|
||||
rm.incomingBuffer.setLen(0)
|
||||
rm.bloomFilter = bloomFilterResult.value
|
||||
return ok()
|
||||
|
||||
proc `=destroy`(rm: var ReliabilityManager) =
|
||||
## Destructor for ReliabilityManager. Ensures proper cleanup of resources.
|
||||
deinitLock(rm.lock)
|
||||
|
||||
when isMainModule:
|
||||
# Example usage and basic tests
|
||||
let config = defaultConfig()
|
||||
let rmResult = newReliabilityManager("testChannel", config)
|
||||
if rmResult.isOk:
|
||||
let rm = rmResult.value
|
||||
rm.setCallbacks(
|
||||
proc(messageId: MessageID) = echo "Message ready: ", messageId,
|
||||
proc(messageId: MessageID) = echo "Message sent: ", messageId,
|
||||
proc(messageId: MessageID, missingDeps: seq[MessageID]) = echo "Missing dependencies for ", messageId, ": ", missingDeps
|
||||
)
|
||||
|
||||
let msgResult = rm.wrapOutgoingMessage("Hello, World!")
|
||||
if msgResult.isOk:
|
||||
let msg = msgResult.value
|
||||
echo "Wrapped message: ", msg
|
||||
|
||||
let unwrapResult = rm.unwrapReceivedMessage(msg)
|
||||
if unwrapResult.isOk:
|
||||
let (unwrappedMsg, missingDeps) = unwrapResult.value
|
||||
echo "Unwrapped message: ", unwrappedMsg
|
||||
echo "Missing dependencies: ", missingDeps
|
||||
else:
|
||||
echo "Error unwrapping message: ", unwrapResult.error
|
||||
else:
|
||||
echo "Error wrapping message: ", msgResult.error
|
||||
|
||||
rm.startPeriodicTasks()
|
||||
# In a real application, you'd keep the program running to allow periodic tasks to execute
|
||||
else:
|
||||
echo "Error creating ReliabilityManager: ", rmResult.error
|
||||
proc cleanup*(rm: ReliabilityManager) {.raises: [].} =
|
||||
if not rm.isNil:
|
||||
{.gcsafe.}:
|
||||
try:
|
||||
rm.outgoingBuffer.setLen(0)
|
||||
rm.incomingBuffer.setLen(0)
|
||||
rm.messageHistory.setLen(0)
|
||||
except Exception as e:
|
||||
logError("Error during cleanup: " & e.msg)
|
||||
174
src/utils.nim
174
src/utils.nim
@ -1,20 +1,37 @@
|
||||
import std/[times, hashes, random, sequtils, algorithm, json, options, locks, asyncdispatch]
|
||||
import chronicles
|
||||
import std/[times, locks]
|
||||
import chronos, chronicles
|
||||
import "../nim-bloom/src/bloom"
|
||||
import ./common
|
||||
|
||||
proc newRollingBloomFilter*(capacity: int, errorRate: float, window: Duration): Result[RollingBloomFilter] =
|
||||
proc logError*(msg: string) =
|
||||
error "ReliabilityError", message = msg
|
||||
|
||||
proc logInfo*(msg: string) =
|
||||
info "ReliabilityInfo", message = msg
|
||||
|
||||
proc newRollingBloomFilter*(capacity: int, errorRate: float, window: times.Duration): RollingBloomFilter {.gcsafe.} =
|
||||
try:
|
||||
let filter = initializeBloomFilter(capacity, errorRate)
|
||||
return ok(RollingBloomFilter(
|
||||
var filter: BloomFilter
|
||||
{.gcsafe.}:
|
||||
filter = initializeBloomFilter(capacity, errorRate)
|
||||
logInfo("Successfully initialized bloom filter")
|
||||
RollingBloomFilter(
|
||||
filter: filter,
|
||||
window: window,
|
||||
messages: @[]
|
||||
))
|
||||
)
|
||||
except:
|
||||
return err[RollingBloomFilter](reInternalError)
|
||||
logError("Failed to initialize bloom filter")
|
||||
var filter: BloomFilter
|
||||
{.gcsafe.}:
|
||||
filter = initializeBloomFilter(DefaultBloomFilterCapacity, DefaultBloomFilterErrorRate)
|
||||
RollingBloomFilter(
|
||||
filter: filter,
|
||||
window: window,
|
||||
messages: @[]
|
||||
)
|
||||
|
||||
proc add*(rbf: var RollingBloomFilter, messageId: MessageID) =
|
||||
proc add*(rbf: var RollingBloomFilter, messageId: MessageID) {.gcsafe.} =
|
||||
## Adds a message ID to the rolling bloom filter.
|
||||
##
|
||||
## Parameters:
|
||||
@ -22,7 +39,7 @@ proc add*(rbf: var RollingBloomFilter, messageId: MessageID) =
|
||||
rbf.filter.insert(messageId)
|
||||
rbf.messages.add(TimestampedMessageID(id: messageId, timestamp: getTime()))
|
||||
|
||||
proc contains*(rbf: RollingBloomFilter, messageId: MessageID): bool =
|
||||
proc contains*(rbf: RollingBloomFilter, messageId: MessageID): bool {.gcsafe.} =
|
||||
## Checks if a message ID is in the rolling bloom filter.
|
||||
##
|
||||
## Parameters:
|
||||
@ -32,125 +49,34 @@ proc contains*(rbf: RollingBloomFilter, messageId: MessageID): bool =
|
||||
## True if the message ID is probably in the filter, false otherwise.
|
||||
rbf.filter.lookup(messageId)
|
||||
|
||||
proc clean*(rbf: var RollingBloomFilter) =
|
||||
## Removes outdated entries from the rolling bloom filter.
|
||||
let now = getTime()
|
||||
let cutoff = now - rbf.window
|
||||
var newMessages: seq[TimestampedMessageID] = @[]
|
||||
var newFilter = initializeBloomFilter(rbf.filter.capacity, rbf.filter.errorRate)
|
||||
proc clean*(rbf: var RollingBloomFilter) {.gcsafe.} =
|
||||
try:
|
||||
let now = getTime()
|
||||
let cutoff = now - rbf.window
|
||||
var newMessages: seq[TimestampedMessageID] = @[]
|
||||
var newFilter: BloomFilter
|
||||
{.gcsafe.}:
|
||||
newFilter = initializeBloomFilter(rbf.filter.capacity, rbf.filter.errorRate)
|
||||
|
||||
for msg in rbf.messages:
|
||||
if msg.timestamp > cutoff:
|
||||
newMessages.add(msg)
|
||||
newFilter.insert(msg.id)
|
||||
for msg in rbf.messages:
|
||||
if msg.timestamp > cutoff:
|
||||
newMessages.add(msg)
|
||||
newFilter.insert(msg.id)
|
||||
|
||||
rbf.messages = newMessages
|
||||
rbf.filter = newFilter
|
||||
rbf.messages = newMessages
|
||||
rbf.filter = newFilter
|
||||
except Exception as e:
|
||||
logError("Failed to clean bloom filter: " & e.msg)
|
||||
|
||||
proc cleanBloomFilter*(rm: ReliabilityManager) =
|
||||
## Cleans the rolling bloom filter, removing outdated entries.
|
||||
proc cleanBloomFilter*(rm: ReliabilityManager) {.gcsafe, raises: [].} =
|
||||
withLock rm.lock:
|
||||
rm.bloomFilter.clean()
|
||||
try:
|
||||
rm.bloomFilter.clean()
|
||||
except Exception as e:
|
||||
logError("Failed to clean ReliabilityManager bloom filter: " & e.msg)
|
||||
|
||||
proc updateLamportTimestamp(rm: ReliabilityManager, msgTs: int64) =
|
||||
proc updateLamportTimestamp*(rm: ReliabilityManager, msgTs: int64) =
|
||||
rm.lamportTimestamp = max(msgTs, rm.lamportTimestamp) + 1
|
||||
|
||||
proc getRecentMessageIDs(rm: ReliabilityManager, n: int): seq[MessageID] =
|
||||
result = rm.messageHistory[max(0, rm.messageHistory.len - n) .. ^1]
|
||||
|
||||
proc generateUniqueID*(): MessageID =
|
||||
let timestamp = getTime().toUnix
|
||||
let randomPart = rand(high(int))
|
||||
result = $hash($timestamp & $randomPart)
|
||||
|
||||
proc serializeMessage*(msg: Message): Result[string] =
|
||||
## Serializes a Message object to a JSON string.
|
||||
##
|
||||
## Parameters:
|
||||
## - msg: The Message object to serialize.
|
||||
##
|
||||
## Returns:
|
||||
## A Result containing either the serialized JSON string or an error.
|
||||
try:
|
||||
let jsonNode = %*{
|
||||
"senderId": msg.senderId,
|
||||
"messageId": msg.messageId,
|
||||
"lamportTimestamp": msg.lamportTimestamp,
|
||||
"causalHistory": msg.causalHistory,
|
||||
"channelId": msg.channelId,
|
||||
"content": msg.content
|
||||
}
|
||||
return ok($jsonNode)
|
||||
except:
|
||||
return err[string](reSerializationError)
|
||||
|
||||
proc deserializeMessage*(data: string): Result[Message] =
|
||||
## Deserializes a JSON string to a Message object.
|
||||
##
|
||||
## Parameters:
|
||||
## - data: The JSON string to deserialize.
|
||||
##
|
||||
## Returns:
|
||||
## A Result containing either the deserialized Message object or an error.
|
||||
try:
|
||||
let jsonNode = parseJson(data)
|
||||
return ok(Message(
|
||||
senderId: jsonNode["senderId"].getStr,
|
||||
messageId: jsonNode["messageId"].getStr,
|
||||
lamportTimestamp: jsonNode["lamportTimestamp"].getBiggestInt,
|
||||
causalHistory: jsonNode["causalHistory"].to(seq[string]),
|
||||
channelId: jsonNode["channelId"].getStr,
|
||||
content: jsonNode["content"].getStr
|
||||
))
|
||||
except:
|
||||
return err[Message](reDeserializationError)
|
||||
|
||||
proc getMessageHistory*(rm: ReliabilityManager): seq[MessageID] =
|
||||
## Retrieves the current message history from the ReliabilityManager.
|
||||
##
|
||||
## Returns:
|
||||
## A sequence of MessageIDs representing the current message history.
|
||||
withLock rm.lock:
|
||||
return rm.messageHistory
|
||||
|
||||
proc getOutgoingBufferSize*(rm: ReliabilityManager): int =
|
||||
## Returns the current size of the outgoing message buffer.
|
||||
##
|
||||
## Returns:
|
||||
## The number of messages in the outgoing buffer.
|
||||
withLock rm.lock:
|
||||
return rm.outgoingBuffer.len
|
||||
|
||||
proc getIncomingBufferSize*(rm: ReliabilityManager): int =
|
||||
## Returns the current size of the incoming message buffer.
|
||||
##
|
||||
## Returns:
|
||||
## The number of messages in the incoming buffer.
|
||||
withLock rm.lock:
|
||||
return rm.incomingBuffer.len
|
||||
|
||||
proc logError*(msg: string) =
|
||||
## Logs an error message
|
||||
error "ReliabilityError", message = msg
|
||||
|
||||
proc logInfo*(msg: string) =
|
||||
## Logs an informational message
|
||||
info "ReliabilityInfo", message = msg
|
||||
|
||||
proc checkAndLogError*[T](res: Result[T], errorMsg: string): T =
|
||||
## Checks the result of an operation, logs any errors, and returns the value or raises an exception.
|
||||
##
|
||||
## Parameters:
|
||||
## - res: A Result[T] object to check.
|
||||
## - errorMsg: A message to log if an error occurred.
|
||||
##
|
||||
## Returns:
|
||||
## The value contained in the Result if it was successful.
|
||||
##
|
||||
## Raises:
|
||||
## An exception with the error message if the Result contains an error.
|
||||
if res.isOk:
|
||||
return res.value
|
||||
else:
|
||||
logError(errorMsg & ": " & $res.error)
|
||||
raise newException(ValueError, errorMsg)
|
||||
proc getRecentMessageIDs*(rm: ReliabilityManager, n: int): seq[MessageID] =
|
||||
result = rm.messageHistory[max(0, rm.messageHistory.len - n) .. ^1]
|
||||
@ -1,92 +1,180 @@
|
||||
import unittest
|
||||
import unittest, results, chronos, chronicles
|
||||
import ../src/reliability
|
||||
import ../src/common
|
||||
|
||||
suite "ReliabilityManager":
|
||||
var rm: ReliabilityManager
|
||||
|
||||
setup:
|
||||
let rmResult = newReliabilityManager("testChannel")
|
||||
check rmResult.isOk
|
||||
let rm = rmResult.value
|
||||
check rmResult.isOk()
|
||||
rm = rmResult.get()
|
||||
|
||||
teardown:
|
||||
if not rm.isNil:
|
||||
rm.cleanup()
|
||||
|
||||
test "can create with default config":
|
||||
let config = defaultConfig()
|
||||
check config.bloomFilterCapacity == DefaultBloomFilterCapacity
|
||||
check config.bloomFilterErrorRate == DefaultBloomFilterErrorRate
|
||||
check config.bloomFilterWindow == DefaultBloomFilterWindow
|
||||
|
||||
test "wrapOutgoingMessage":
|
||||
let msgResult = rm.wrapOutgoingMessage("Hello, World!")
|
||||
check msgResult.isOk
|
||||
let msg = msgResult.value
|
||||
check:
|
||||
msg.content == "Hello, World!"
|
||||
msg.channelId == "testChannel"
|
||||
msg.causalHistory.len == 0
|
||||
let msg = @[byte(1), 2, 3]
|
||||
let msgId = "test-msg-1"
|
||||
let wrappedResult = rm.wrapOutgoingMessage(msg, msgId)
|
||||
check wrappedResult.isOk()
|
||||
let wrapped = wrappedResult.get()
|
||||
check wrapped.len > 0
|
||||
|
||||
test "unwrapReceivedMessage":
|
||||
let wrappedMsgResult = rm.wrapOutgoingMessage("Test message")
|
||||
check wrappedMsgResult.isOk
|
||||
let wrappedMsg = wrappedMsgResult.value
|
||||
let unwrapResult = rm.unwrapReceivedMessage(wrappedMsg)
|
||||
check unwrapResult.isOk
|
||||
let (unwrappedMsg, missingDeps) = unwrapResult.value
|
||||
let msg = @[byte(1), 2, 3]
|
||||
let msgId = "test-msg-1"
|
||||
let wrappedResult = rm.wrapOutgoingMessage(msg, msgId)
|
||||
check wrappedResult.isOk()
|
||||
let wrapped = wrappedResult.get()
|
||||
let unwrapResult = rm.unwrapReceivedMessage(wrapped)
|
||||
check unwrapResult.isOk()
|
||||
let (unwrapped, missingDeps) = unwrapResult.get()
|
||||
check:
|
||||
unwrappedMsg.content == "Test message"
|
||||
unwrapped == msg
|
||||
missingDeps.len == 0
|
||||
|
||||
test "markDependenciesMet":
|
||||
var msg1Result = rm.wrapOutgoingMessage("Message 1")
|
||||
var msg2Result = rm.wrapOutgoingMessage("Message 2")
|
||||
var msg3Result = rm.wrapOutgoingMessage("Message 3")
|
||||
check msg1Result.isOk and msg2Result.isOk and msg3Result.isOk
|
||||
let msg1 = msg1Result.value
|
||||
let msg2 = msg2Result.value
|
||||
let msg3 = msg3Result.value
|
||||
info "test_state", state="starting markDependenciesMet test"
|
||||
|
||||
var unwrapResult = rm.unwrapReceivedMessage(msg3)
|
||||
check unwrapResult.isOk
|
||||
var (_, missingDeps) = unwrapResult.value
|
||||
check missingDeps.len == 2
|
||||
block message1:
|
||||
let msg1 = @[byte(1)]
|
||||
let id1 = "msg1"
|
||||
info "message_creation", msg="message 1", id=id1
|
||||
let wrap1 = rm.wrapOutgoingMessage(msg1, id1)
|
||||
check wrap1.isOk()
|
||||
let wrapped1 = wrap1.get()
|
||||
|
||||
let markResult = rm.markDependenciesMet(@[msg1.messageId, msg2.messageId])
|
||||
check markResult.isOk
|
||||
info "message_processing", msg="message 1", id=id1
|
||||
let unwrap1 = rm.unwrapReceivedMessage(wrapped1)
|
||||
check unwrap1.isOk()
|
||||
let (content1, deps1) = unwrap1.get()
|
||||
info "message_processed", msg="message 1", deps_count=deps1.len
|
||||
check content1 == msg1
|
||||
|
||||
unwrapResult = rm.unwrapReceivedMessage(msg3)
|
||||
check unwrapResult.isOk
|
||||
(_, missingDeps) = unwrapResult.value
|
||||
check missingDeps.len == 0
|
||||
block message2:
|
||||
let msg2 = @[byte(2)]
|
||||
let id2 = "msg2"
|
||||
info "message_creation", msg="message 2", id=id2
|
||||
let wrap2 = rm.wrapOutgoingMessage(msg2, id2)
|
||||
check wrap2.isOk()
|
||||
let wrapped2 = wrap2.get()
|
||||
|
||||
test "callbacks":
|
||||
info "message_processing", msg="message 2", id=id2
|
||||
let unwrap2 = rm.unwrapReceivedMessage(wrapped2)
|
||||
check unwrap2.isOk()
|
||||
let (content2, deps2) = unwrap2.get()
|
||||
info "message_processed", msg="message 2", deps_count=deps2.len
|
||||
check content2 == msg2
|
||||
|
||||
block message3:
|
||||
info "message_creation", msg="message 3"
|
||||
let msg3 = @[byte(3)]
|
||||
let id3 = "msg3"
|
||||
let wrap3 = rm.wrapOutgoingMessage(msg3, id3)
|
||||
check wrap3.isOk()
|
||||
info "message_wrapped", msg="message 3", id=id3
|
||||
let wrapped3 = wrap3.get()
|
||||
|
||||
info "checking_dependencies", msg="message 3", id=id3
|
||||
var unwrap3 = rm.unwrapReceivedMessage(wrapped3)
|
||||
check unwrap3.isOk()
|
||||
var (content3, missing3) = unwrap3.get()
|
||||
info "dependencies_checked", msg="message 3", missing_deps=missing3.len
|
||||
|
||||
info "test_state", state="completed"
|
||||
|
||||
test "callbacks work correctly":
|
||||
var messageReadyCount = 0
|
||||
var messageSentCount = 0
|
||||
var missingDepsCount = 0
|
||||
|
||||
rm.setCallbacks(
|
||||
proc(messageId: MessageID) = messageReadyCount += 1,
|
||||
proc(messageId: MessageID) = messageSentCount += 1,
|
||||
proc(messageId: MessageID, missingDeps: seq[MessageID]) = missingDepsCount += 1
|
||||
proc(messageId: MessageID) {.gcsafe.} = messageReadyCount += 1,
|
||||
proc(messageId: MessageID) {.gcsafe.} = messageSentCount += 1,
|
||||
proc(messageId: MessageID, missingDeps: seq[MessageID]) {.gcsafe.} = missingDepsCount += 1
|
||||
)
|
||||
|
||||
let msg1Result = rm.wrapOutgoingMessage("Message 1")
|
||||
let msg2Result = rm.wrapOutgoingMessage("Message 2")
|
||||
check msg1Result.isOk and msg2Result.isOk
|
||||
let msg1 = msg1Result.value
|
||||
let msg2 = msg2Result.value
|
||||
let msg1Result = rm.wrapOutgoingMessage(@[byte(1)], "msg1")
|
||||
let msg2Result = rm.wrapOutgoingMessage(@[byte(2)], "msg2")
|
||||
check msg1Result.isOk() and msg2Result.isOk()
|
||||
let msg1 = msg1Result.get()
|
||||
let msg2 = msg2Result.get()
|
||||
discard rm.unwrapReceivedMessage(msg1)
|
||||
discard rm.unwrapReceivedMessage(msg2)
|
||||
|
||||
check:
|
||||
messageReadyCount == 2
|
||||
messageSentCount == 0 # This would be triggered by the checkUnacknowledgedMessages function
|
||||
messageSentCount == 0 # This would be triggered by checkUnacknowledgedMessages
|
||||
missingDepsCount == 0
|
||||
|
||||
test "serialization":
|
||||
let msgResult = rm.wrapOutgoingMessage("Test serialization")
|
||||
check msgResult.isOk
|
||||
let msg = msgResult.value
|
||||
let serializeResult = serializeMessage(msg)
|
||||
check serializeResult.isOk
|
||||
let serialized = serializeResult.value
|
||||
let deserializeResult = deserializeMessage(serialized)
|
||||
check deserializeResult.isOk
|
||||
let deserialized = deserializeResult.value
|
||||
check:
|
||||
deserialized.content == "Test serialization"
|
||||
deserialized.messageId == msg.messageId
|
||||
deserialized.lamportTimestamp == msg.lamportTimestamp
|
||||
test "periodic sync callback works":
|
||||
var syncCallCount = 0
|
||||
rm.setCallbacks(
|
||||
proc(messageId: MessageID) {.gcsafe.} = discard,
|
||||
proc(messageId: MessageID) {.gcsafe.} = discard,
|
||||
proc(messageId: MessageID, missingDeps: seq[MessageID]) {.gcsafe.} = discard,
|
||||
proc() {.gcsafe.} = syncCallCount += 1
|
||||
)
|
||||
|
||||
when isMainModule:
|
||||
unittest.run()
|
||||
rm.startPeriodicTasks()
|
||||
# Sleep briefly to allow periodic tasks to run
|
||||
waitFor sleepAsync(chronos.seconds(1))
|
||||
rm.cleanup()
|
||||
|
||||
check(syncCallCount > 0)
|
||||
|
||||
test "protobuf serialization":
|
||||
let msg = @[byte(1), 2, 3]
|
||||
let msgId = "test-msg-1"
|
||||
let msgResult = rm.wrapOutgoingMessage(msg, msgId)
|
||||
check msgResult.isOk()
|
||||
let wrapped = msgResult.get()
|
||||
|
||||
let unwrapResult = rm.unwrapReceivedMessage(wrapped)
|
||||
check unwrapResult.isOk()
|
||||
let (unwrapped, _) = unwrapResult.get()
|
||||
|
||||
check:
|
||||
unwrapped == msg
|
||||
unwrapped.len == msg.len
|
||||
|
||||
test "handles empty message":
|
||||
let msg: seq[byte] = @[]
|
||||
let msgId = "test-empty-msg"
|
||||
let wrappedResult = rm.wrapOutgoingMessage(msg, msgId)
|
||||
check(not wrappedResult.isOk())
|
||||
check(wrappedResult.error == reInvalidArgument)
|
||||
|
||||
test "handles message too large":
|
||||
let msg = newSeq[byte](MaxMessageSize + 1)
|
||||
let msgId = "test-large-msg"
|
||||
let wrappedResult = rm.wrapOutgoingMessage(msg, msgId)
|
||||
check(not wrappedResult.isOk())
|
||||
check(wrappedResult.error == reMessageTooLarge)
|
||||
|
||||
suite "cleanup":
|
||||
test "cleanup works correctly":
|
||||
let rmResult = newReliabilityManager("testChannel")
|
||||
check rmResult.isOk()
|
||||
let rm = rmResult.get()
|
||||
|
||||
# Add some messages
|
||||
let msg = @[byte(1), 2, 3]
|
||||
let msgId = "test-msg-1"
|
||||
discard rm.wrapOutgoingMessage(msg, msgId)
|
||||
|
||||
# Cleanup
|
||||
rm.cleanup()
|
||||
|
||||
# Check buffers are empty
|
||||
check(rm.outgoingBuffer.len == 0)
|
||||
check(rm.incomingBuffer.len == 0)
|
||||
check(rm.messageHistory.len == 0)
|
||||
Loading…
x
Reference in New Issue
Block a user