mirror of
https://github.com/logos-messaging/nim-sds.git
synced 2026-01-02 14:13:07 +00:00
feat: Implementing Support for Multiple Channels in Single Reliability Manager (#13)
This commit is contained in:
parent
565ec66b30
commit
99121098cc
@ -1,4 +1,4 @@
|
|||||||
import std/[times, options, sets]
|
import std/[times, sets]
|
||||||
|
|
||||||
type
|
type
|
||||||
SdsMessageID* = string
|
SdsMessageID* = string
|
||||||
@ -8,7 +8,7 @@ type
|
|||||||
messageId*: SdsMessageID
|
messageId*: SdsMessageID
|
||||||
lamportTimestamp*: int64
|
lamportTimestamp*: int64
|
||||||
causalHistory*: seq[SdsMessageID]
|
causalHistory*: seq[SdsMessageID]
|
||||||
channelId*: Option[SdsChannelID]
|
channelId*: SdsChannelID
|
||||||
content*: seq[byte]
|
content*: seq[byte]
|
||||||
bloomFilter*: seq[byte]
|
bloomFilter*: seq[byte]
|
||||||
|
|
||||||
|
|||||||
@ -12,8 +12,7 @@ proc encode*(msg: SdsMessage): ProtoBuffer =
|
|||||||
for hist in msg.causalHistory:
|
for hist in msg.causalHistory:
|
||||||
pb.write(3, hist)
|
pb.write(3, hist)
|
||||||
|
|
||||||
if msg.channelId.isSome():
|
pb.write(4, msg.channelId)
|
||||||
pb.write(4, msg.channelId.get())
|
|
||||||
pb.write(5, msg.content)
|
pb.write(5, msg.content)
|
||||||
pb.write(6, msg.bloomFilter)
|
pb.write(6, msg.bloomFilter)
|
||||||
pb.finish()
|
pb.finish()
|
||||||
@ -37,11 +36,8 @@ proc decode*(T: type SdsMessage, buffer: seq[byte]): ProtobufResult[T] =
|
|||||||
if histResult.isOk:
|
if histResult.isOk:
|
||||||
msg.causalHistory = causalHistory
|
msg.causalHistory = causalHistory
|
||||||
|
|
||||||
var channelId: SdsChannelID
|
if not ?pb.getField(4, msg.channelId):
|
||||||
if ?pb.getField(4, channelId):
|
return err(ProtobufError.missingRequiredField("channelId"))
|
||||||
msg.channelId = some(channelId)
|
|
||||||
else:
|
|
||||||
msg.channelId = none[SdsChannelID]()
|
|
||||||
|
|
||||||
if not ?pb.getField(5, msg.content):
|
if not ?pb.getField(5, msg.content):
|
||||||
return err(ProtobufError.missingRequiredField("content"))
|
return err(ProtobufError.missingRequiredField("content"))
|
||||||
@ -51,6 +47,17 @@ proc decode*(T: type SdsMessage, buffer: seq[byte]): ProtobufResult[T] =
|
|||||||
|
|
||||||
ok(msg)
|
ok(msg)
|
||||||
|
|
||||||
|
proc extractChannelId*(data: seq[byte]): Result[SdsChannelID, ReliabilityError] =
  ## Reads only the channel ID (protobuf field 4) from serialized message
  ## bytes, without decoding the other fields.
  ##
  ## Parameters:
  ##   - data: The serialized message bytes.
  ##
  ## Returns:
  ##   The channel ID on success, or `reDeserializationError` when field 4 is
  ##   absent or the field lookup reports an error.
  try:
    let pb = initProtoBuffer(data)
    var channelId: SdsChannelID
    # Handle the lookup Result explicitly: calling `.get()` on an error value
    # and catching the fallout with a bare `except:` turns bad input into a
    # Defect, which is not a recoverable path.
    let found = pb.getField(4, channelId).valueOr:
      return err(ReliabilityError.reDeserializationError)
    if not found:
      # The field is simply missing from the buffer.
      return err(ReliabilityError.reDeserializationError)
    ok(channelId)
  except CatchableError:
    err(ReliabilityError.reDeserializationError)
|
||||||
|
|
||||||
proc serializeMessage*(msg: SdsMessage): Result[seq[byte], ReliabilityError] =
|
proc serializeMessage*(msg: SdsMessage): Result[seq[byte], ReliabilityError] =
|
||||||
let pb = encode(msg)
|
let pb = encode(msg)
|
||||||
ok(pb.buffer)
|
ok(pb.buffer)
|
||||||
|
|||||||
@ -3,31 +3,18 @@ import chronos, results, chronicles
|
|||||||
import ./[message, protobuf, reliability_utils, rolling_bloom_filter]
|
import ./[message, protobuf, reliability_utils, rolling_bloom_filter]
|
||||||
|
|
||||||
proc newReliabilityManager*(
|
proc newReliabilityManager*(
|
||||||
channelId: Option[SdsChannelID], config: ReliabilityConfig = defaultConfig()
|
config: ReliabilityConfig = defaultConfig()
|
||||||
): Result[ReliabilityManager, ReliabilityError] =
|
): Result[ReliabilityManager, ReliabilityError] =
|
||||||
## Creates a new ReliabilityManager with the specified channel ID and configuration.
|
## Creates a new multi-channel ReliabilityManager.
|
||||||
##
|
##
|
||||||
## Parameters:
|
## Parameters:
|
||||||
## - channelId: A unique identifier for the communication channel.
|
|
||||||
## - config: Configuration options for the ReliabilityManager. If not provided, default configuration is used.
|
## - config: Configuration options for the ReliabilityManager. If not provided, default configuration is used.
|
||||||
##
|
##
|
||||||
## Returns:
|
## Returns:
|
||||||
## A Result containing either a new ReliabilityManager instance or an error.
|
## A Result containing either a new ReliabilityManager instance or an error.
|
||||||
if not channelId.isSome():
|
|
||||||
return err(ReliabilityError.reInvalidArgument)
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
let bloomFilter =
|
|
||||||
newRollingBloomFilter(config.bloomFilterCapacity, config.bloomFilterErrorRate)
|
|
||||||
|
|
||||||
let rm = ReliabilityManager(
|
let rm = ReliabilityManager(
|
||||||
lamportTimestamp: 0,
|
channels: initTable[SdsChannelID, ChannelContext](), config: config
|
||||||
messageHistory: @[],
|
|
||||||
bloomFilter: bloomFilter,
|
|
||||||
outgoingBuffer: @[],
|
|
||||||
incomingBuffer: initTable[SdsMessageID, IncomingMessage](),
|
|
||||||
channelId: channelId,
|
|
||||||
config: config,
|
|
||||||
)
|
)
|
||||||
initLock(rm.lock)
|
initLock(rm.lock)
|
||||||
return ok(rm)
|
return ok(rm)
|
||||||
@ -73,29 +60,37 @@ proc reviewAckStatus(rm: ReliabilityManager, msg: SdsMessage) {.gcsafe.} =
|
|||||||
else:
|
else:
|
||||||
rbf = none[RollingBloomFilter]()
|
rbf = none[RollingBloomFilter]()
|
||||||
|
|
||||||
|
if msg.channelId notin rm.channels:
|
||||||
|
return
|
||||||
|
|
||||||
|
let channel = rm.channels[msg.channelId]
|
||||||
# Keep track of indices to delete
|
# Keep track of indices to delete
|
||||||
var toDelete: seq[int] = @[]
|
var toDelete: seq[int] = @[]
|
||||||
var i = 0
|
var i = 0
|
||||||
|
|
||||||
while i < rm.outgoingBuffer.len:
|
while i < channel.outgoingBuffer.len:
|
||||||
let outMsg = rm.outgoingBuffer[i]
|
let outMsg = channel.outgoingBuffer[i]
|
||||||
if outMsg.isAcknowledged(msg.causalHistory, rbf):
|
if outMsg.isAcknowledged(msg.causalHistory, rbf):
|
||||||
if not rm.onMessageSent.isNil():
|
if not rm.onMessageSent.isNil():
|
||||||
rm.onMessageSent(outMsg.message.messageId)
|
rm.onMessageSent(outMsg.message.messageId, outMsg.message.channelId)
|
||||||
toDelete.add(i)
|
toDelete.add(i)
|
||||||
inc i
|
inc i
|
||||||
|
|
||||||
for i in countdown(toDelete.high, 0): # Delete in reverse order to maintain indices
|
for i in countdown(toDelete.high, 0): # Delete in reverse order to maintain indices
|
||||||
rm.outgoingBuffer.delete(toDelete[i])
|
channel.outgoingBuffer.delete(toDelete[i])
|
||||||
|
|
||||||
proc wrapOutgoingMessage*(
|
proc wrapOutgoingMessage*(
|
||||||
rm: ReliabilityManager, message: seq[byte], messageId: SdsMessageID
|
rm: ReliabilityManager,
|
||||||
|
message: seq[byte],
|
||||||
|
messageId: SdsMessageID,
|
||||||
|
channelId: SdsChannelID,
|
||||||
): Result[seq[byte], ReliabilityError] =
|
): Result[seq[byte], ReliabilityError] =
|
||||||
## Wraps an outgoing message with reliability metadata.
|
## Wraps an outgoing message with reliability metadata.
|
||||||
##
|
##
|
||||||
## Parameters:
|
## Parameters:
|
||||||
## - message: The content of the message to be sent.
|
## - message: The content of the message to be sent.
|
||||||
## - messageId: Unique identifier for the message
|
## - messageId: Unique identifier for the message
|
||||||
|
## - channelId: Identifier for the channel this message belongs to.
|
||||||
##
|
##
|
||||||
## Returns:
|
## Returns:
|
||||||
## A Result containing either wrapped message bytes or an error.
|
## A Result containing either wrapped message bytes or an error.
|
||||||
@ -106,46 +101,52 @@ proc wrapOutgoingMessage*(
|
|||||||
|
|
||||||
withLock rm.lock:
|
withLock rm.lock:
|
||||||
try:
|
try:
|
||||||
rm.updateLamportTimestamp(getTime().toUnix)
|
let channel = rm.getOrCreateChannel(channelId)
|
||||||
|
rm.updateLamportTimestamp(getTime().toUnix, channelId)
|
||||||
|
|
||||||
let bfResult = serializeBloomFilter(rm.bloomFilter.filter)
|
let bfResult = serializeBloomFilter(channel.bloomFilter.filter)
|
||||||
if bfResult.isErr:
|
if bfResult.isErr:
|
||||||
error "Failed to serialize bloom filter"
|
error "Failed to serialize bloom filter", channelId = channelId
|
||||||
return err(ReliabilityError.reSerializationError)
|
return err(ReliabilityError.reSerializationError)
|
||||||
|
|
||||||
let msg = SdsMessage(
|
let msg = SdsMessage(
|
||||||
messageId: messageId,
|
messageId: messageId,
|
||||||
lamportTimestamp: rm.lamportTimestamp,
|
lamportTimestamp: channel.lamportTimestamp,
|
||||||
causalHistory: rm.getRecentSdsMessageIDs(rm.config.maxCausalHistory),
|
causalHistory: rm.getRecentSdsMessageIDs(rm.config.maxCausalHistory, channelId),
|
||||||
channelId: rm.channelId,
|
channelId: channelId,
|
||||||
content: message,
|
content: message,
|
||||||
bloomFilter: bfResult.get(),
|
bloomFilter: bfResult.get(),
|
||||||
)
|
)
|
||||||
|
|
||||||
# Add to outgoing buffer
|
channel.outgoingBuffer.add(
|
||||||
rm.outgoingBuffer.add(
|
|
||||||
UnacknowledgedMessage(message: msg, sendTime: getTime(), resendAttempts: 0)
|
UnacknowledgedMessage(message: msg, sendTime: getTime(), resendAttempts: 0)
|
||||||
)
|
)
|
||||||
|
|
||||||
# Add to causal history and bloom filter
|
# Add to causal history and bloom filter
|
||||||
rm.bloomFilter.add(msg.messageId)
|
channel.bloomFilter.add(msg.messageId)
|
||||||
rm.addToHistory(msg.messageId)
|
rm.addToHistory(msg.messageId, channelId)
|
||||||
|
|
||||||
return serializeMessage(msg)
|
return serializeMessage(msg)
|
||||||
except Exception:
|
except Exception:
|
||||||
error "Failed to wrap message", msg = getCurrentExceptionMsg()
|
error "Failed to wrap message",
|
||||||
|
channelId = channelId, msg = getCurrentExceptionMsg()
|
||||||
return err(ReliabilityError.reSerializationError)
|
return err(ReliabilityError.reSerializationError)
|
||||||
|
|
||||||
proc processIncomingBuffer(rm: ReliabilityManager) {.gcsafe.} =
|
proc processIncomingBuffer(rm: ReliabilityManager, channelId: SdsChannelID) {.gcsafe.} =
|
||||||
withLock rm.lock:
|
withLock rm.lock:
|
||||||
if rm.incomingBuffer.len == 0:
|
if channelId notin rm.channels:
|
||||||
|
error "Channel does not exist", channelId = channelId
|
||||||
|
return
|
||||||
|
|
||||||
|
let channel = rm.channels[channelId]
|
||||||
|
if channel.incomingBuffer.len == 0:
|
||||||
return
|
return
|
||||||
|
|
||||||
var processed = initHashSet[SdsMessageID]()
|
var processed = initHashSet[SdsMessageID]()
|
||||||
var readyToProcess = newSeq[SdsMessageID]()
|
var readyToProcess = newSeq[SdsMessageID]()
|
||||||
|
|
||||||
# Find initially ready messages
|
# Find initially ready messages
|
||||||
for msgId, entry in rm.incomingBuffer:
|
for msgId, entry in channel.incomingBuffer:
|
||||||
if entry.missingDeps.len == 0:
|
if entry.missingDeps.len == 0:
|
||||||
readyToProcess.add(msgId)
|
readyToProcess.add(msgId)
|
||||||
|
|
||||||
@ -154,105 +155,114 @@ proc processIncomingBuffer(rm: ReliabilityManager) {.gcsafe.} =
|
|||||||
if msgId in processed:
|
if msgId in processed:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
if msgId in rm.incomingBuffer:
|
if msgId in channel.incomingBuffer:
|
||||||
rm.addToHistory(msgId)
|
rm.addToHistory(msgId, channelId)
|
||||||
if not rm.onMessageReady.isNil():
|
if not rm.onMessageReady.isNil():
|
||||||
rm.onMessageReady(msgId)
|
rm.onMessageReady(msgId, channelId)
|
||||||
processed.incl(msgId)
|
processed.incl(msgId)
|
||||||
|
|
||||||
# Update dependencies for remaining messages
|
# Update dependencies for remaining messages
|
||||||
for remainingId, entry in rm.incomingBuffer:
|
for remainingId, entry in channel.incomingBuffer:
|
||||||
if remainingId notin processed:
|
if remainingId notin processed:
|
||||||
if msgId in entry.missingDeps:
|
if msgId in entry.missingDeps:
|
||||||
rm.incomingBuffer[remainingId].missingDeps.excl(msgId)
|
channel.incomingBuffer[remainingId].missingDeps.excl(msgId)
|
||||||
if rm.incomingBuffer[remainingId].missingDeps.len == 0:
|
if channel.incomingBuffer[remainingId].missingDeps.len == 0:
|
||||||
readyToProcess.add(remainingId)
|
readyToProcess.add(remainingId)
|
||||||
|
|
||||||
# Remove processed messages
|
# Remove processed messages
|
||||||
for msgId in processed:
|
for msgId in processed:
|
||||||
rm.incomingBuffer.del(msgId)
|
channel.incomingBuffer.del(msgId)
|
||||||
|
|
||||||
proc unwrapReceivedMessage*(
|
proc unwrapReceivedMessage*(
|
||||||
rm: ReliabilityManager, message: seq[byte]
|
rm: ReliabilityManager, message: seq[byte]
|
||||||
): Result[tuple[message: seq[byte], missingDeps: seq[SdsMessageID]], ReliabilityError] =
|
): Result[
|
||||||
|
tuple[message: seq[byte], missingDeps: seq[SdsMessageID], channelId: SdsChannelID],
|
||||||
|
ReliabilityError,
|
||||||
|
] =
|
||||||
## Unwraps a received message and processes its reliability metadata.
|
## Unwraps a received message and processes its reliability metadata.
|
||||||
##
|
##
|
||||||
## Parameters:
|
## Parameters:
|
||||||
## - message: The received message bytes
|
## - message: The received message bytes
|
||||||
##
|
##
|
||||||
## Returns:
|
## Returns:
|
||||||
## A Result containing either tuple of (processed message, missing dependencies) or an error.
|
## A Result containing either tuple of (processed message, missing dependencies, channel ID) or an error.
|
||||||
try:
|
try:
|
||||||
|
let channelId = extractChannelId(message).valueOr:
|
||||||
|
return err(ReliabilityError.reDeserializationError)
|
||||||
|
|
||||||
let msg = deserializeMessage(message).valueOr:
|
let msg = deserializeMessage(message).valueOr:
|
||||||
return err(ReliabilityError.reDeserializationError)
|
return err(ReliabilityError.reDeserializationError)
|
||||||
|
|
||||||
if msg.messageId in rm.messageHistory:
|
let channel = rm.getOrCreateChannel(channelId)
|
||||||
return ok((msg.content, @[]))
|
|
||||||
|
|
||||||
rm.bloomFilter.add(msg.messageId)
|
if msg.messageId in channel.messageHistory:
|
||||||
|
return ok((msg.content, @[], channelId))
|
||||||
|
|
||||||
# Update Lamport timestamp
|
channel.bloomFilter.add(msg.messageId)
|
||||||
rm.updateLamportTimestamp(msg.lamportTimestamp)
|
|
||||||
|
|
||||||
|
rm.updateLamportTimestamp(msg.lamportTimestamp, channelId)
|
||||||
# Review ACK status for outgoing messages
|
# Review ACK status for outgoing messages
|
||||||
rm.reviewAckStatus(msg)
|
rm.reviewAckStatus(msg)
|
||||||
|
|
||||||
var missingDeps = rm.checkDependencies(msg.causalHistory)
|
var missingDeps = rm.checkDependencies(msg.causalHistory, channelId)
|
||||||
|
|
||||||
if missingDeps.len == 0:
|
if missingDeps.len == 0:
|
||||||
# Check if any dependencies are still in incoming buffer
|
|
||||||
var depsInBuffer = false
|
var depsInBuffer = false
|
||||||
for msgId, entry in rm.incomingBuffer.pairs():
|
for msgId, entry in channel.incomingBuffer.pairs():
|
||||||
if msgId in msg.causalHistory:
|
if msgId in msg.causalHistory:
|
||||||
depsInBuffer = true
|
depsInBuffer = true
|
||||||
break
|
break
|
||||||
|
# Check if any dependencies are still in incoming buffer
|
||||||
if depsInBuffer:
|
if depsInBuffer:
|
||||||
rm.incomingBuffer[msg.messageId] =
|
channel.incomingBuffer[msg.messageId] =
|
||||||
IncomingMessage(message: msg, missingDeps: initHashSet[SdsMessageID]())
|
IncomingMessage(message: msg, missingDeps: initHashSet[SdsMessageID]())
|
||||||
else:
|
else:
|
||||||
# All dependencies met, add to history
|
# All dependencies met, add to history
|
||||||
rm.addToHistory(msg.messageId)
|
rm.addToHistory(msg.messageId, channelId)
|
||||||
rm.processIncomingBuffer()
|
rm.processIncomingBuffer(channelId)
|
||||||
if not rm.onMessageReady.isNil():
|
if not rm.onMessageReady.isNil():
|
||||||
rm.onMessageReady(msg.messageId)
|
rm.onMessageReady(msg.messageId, channelId)
|
||||||
else:
|
else:
|
||||||
rm.incomingBuffer[msg.messageId] =
|
channel.incomingBuffer[msg.messageId] =
|
||||||
IncomingMessage(message: msg, missingDeps: missingDeps.toHashSet())
|
IncomingMessage(message: msg, missingDeps: missingDeps.toHashSet())
|
||||||
if not rm.onMissingDependencies.isNil():
|
if not rm.onMissingDependencies.isNil():
|
||||||
rm.onMissingDependencies(msg.messageId, missingDeps)
|
rm.onMissingDependencies(msg.messageId, missingDeps, channelId)
|
||||||
|
|
||||||
return ok((msg.content, missingDeps))
|
return ok((msg.content, missingDeps, channelId))
|
||||||
except Exception:
|
except Exception:
|
||||||
error "Failed to unwrap message", msg = getCurrentExceptionMsg()
|
error "Failed to unwrap message", msg = getCurrentExceptionMsg()
|
||||||
return err(ReliabilityError.reDeserializationError)
|
return err(ReliabilityError.reDeserializationError)
|
||||||
|
|
||||||
proc markDependenciesMet*(
|
proc markDependenciesMet*(
|
||||||
rm: ReliabilityManager, messageIds: seq[SdsMessageID]
|
rm: ReliabilityManager, messageIds: seq[SdsMessageID], channelId: SdsChannelID
|
||||||
): Result[void, ReliabilityError] =
|
): Result[void, ReliabilityError] =
|
||||||
## Marks the specified message dependencies as met.
|
## Marks the specified message dependencies as met.
|
||||||
##
|
##
|
||||||
## Parameters:
|
## Parameters:
|
||||||
## - messageIds: A sequence of message IDs to mark as met.
|
## - messageIds: A sequence of message IDs to mark as met.
|
||||||
|
## - channelId: Identifier for the channel.
|
||||||
##
|
##
|
||||||
## Returns:
|
## Returns:
|
||||||
## A Result indicating success or an error.
|
## A Result indicating success or an error.
|
||||||
try:
|
try:
|
||||||
# Add all messageIds to bloom filter
|
if channelId notin rm.channels:
|
||||||
|
return err(ReliabilityError.reInvalidArgument)
|
||||||
|
|
||||||
|
let channel = rm.channels[channelId]
|
||||||
|
|
||||||
for msgId in messageIds:
|
for msgId in messageIds:
|
||||||
if not rm.bloomFilter.contains(msgId):
|
if not channel.bloomFilter.contains(msgId):
|
||||||
rm.bloomFilter.add(msgId)
|
channel.bloomFilter.add(msgId)
|
||||||
# rm.addToHistory(msgId) -- not needed as this proc usually called when msg in long-term storage of application?
|
|
||||||
|
|
||||||
# Update any pending messages that depend on this one
|
for pendingId, entry in channel.incomingBuffer:
|
||||||
for pendingId, entry in rm.incomingBuffer:
|
|
||||||
if msgId in entry.missingDeps:
|
if msgId in entry.missingDeps:
|
||||||
rm.incomingBuffer[pendingId].missingDeps.excl(msgId)
|
channel.incomingBuffer[pendingId].missingDeps.excl(msgId)
|
||||||
|
|
||||||
rm.processIncomingBuffer()
|
rm.processIncomingBuffer(channelId)
|
||||||
return ok()
|
return ok()
|
||||||
except Exception:
|
except Exception:
|
||||||
error "Failed to mark dependencies as met", msg = getCurrentExceptionMsg()
|
error "Failed to mark dependencies as met",
|
||||||
|
channelId = channelId, msg = getCurrentExceptionMsg()
|
||||||
return err(ReliabilityError.reInternalError)
|
return err(ReliabilityError.reInternalError)
|
||||||
|
|
||||||
proc setCallbacks*(
|
proc setCallbacks*(
|
||||||
@ -275,16 +285,22 @@ proc setCallbacks*(
|
|||||||
rm.onMissingDependencies = onMissingDependencies
|
rm.onMissingDependencies = onMissingDependencies
|
||||||
rm.onPeriodicSync = onPeriodicSync
|
rm.onPeriodicSync = onPeriodicSync
|
||||||
|
|
||||||
proc checkUnacknowledgedMessages(rm: ReliabilityManager) {.gcsafe.} =
|
proc checkUnacknowledgedMessages(
|
||||||
|
rm: ReliabilityManager, channelId: SdsChannelID
|
||||||
|
) {.gcsafe.} =
|
||||||
## Checks and processes unacknowledged messages in the outgoing buffer.
|
## Checks and processes unacknowledged messages in the outgoing buffer.
|
||||||
withLock rm.lock:
|
withLock rm.lock:
|
||||||
|
if channelId notin rm.channels:
|
||||||
|
error "Channel does not exist", channelId = channelId
|
||||||
|
return
|
||||||
|
|
||||||
|
let channel = rm.channels[channelId]
|
||||||
let now = getTime()
|
let now = getTime()
|
||||||
var newOutgoingBuffer: seq[UnacknowledgedMessage] = @[]
|
var newOutgoingBuffer: seq[UnacknowledgedMessage] = @[]
|
||||||
|
|
||||||
for unackMsg in rm.outgoingBuffer:
|
for unackMsg in channel.outgoingBuffer:
|
||||||
let elapsed = now - unackMsg.sendTime
|
let elapsed = now - unackMsg.sendTime
|
||||||
if elapsed > rm.config.resendInterval:
|
if elapsed > rm.config.resendInterval:
|
||||||
# Time to attempt resend
|
|
||||||
if unackMsg.resendAttempts < rm.config.maxResendAttempts:
|
if unackMsg.resendAttempts < rm.config.maxResendAttempts:
|
||||||
var updatedMsg = unackMsg
|
var updatedMsg = unackMsg
|
||||||
updatedMsg.resendAttempts += 1
|
updatedMsg.resendAttempts += 1
|
||||||
@ -292,11 +308,11 @@ proc checkUnacknowledgedMessages(rm: ReliabilityManager) {.gcsafe.} =
|
|||||||
newOutgoingBuffer.add(updatedMsg)
|
newOutgoingBuffer.add(updatedMsg)
|
||||||
else:
|
else:
|
||||||
if not rm.onMessageSent.isNil():
|
if not rm.onMessageSent.isNil():
|
||||||
rm.onMessageSent(unackMsg.message.messageId)
|
rm.onMessageSent(unackMsg.message.messageId, channelId)
|
||||||
else:
|
else:
|
||||||
newOutgoingBuffer.add(unackMsg)
|
newOutgoingBuffer.add(unackMsg)
|
||||||
|
|
||||||
rm.outgoingBuffer = newOutgoingBuffer
|
channel.outgoingBuffer = newOutgoingBuffer
|
||||||
|
|
||||||
proc periodicBufferSweep(
|
proc periodicBufferSweep(
|
||||||
rm: ReliabilityManager
|
rm: ReliabilityManager
|
||||||
@ -304,8 +320,13 @@ proc periodicBufferSweep(
|
|||||||
## Periodically sweeps the buffer to clean up and check unacknowledged messages.
|
## Periodically sweeps the buffer to clean up and check unacknowledged messages.
|
||||||
while true:
|
while true:
|
||||||
try:
|
try:
|
||||||
rm.checkUnacknowledgedMessages()
|
for channelId, channel in rm.channels:
|
||||||
rm.cleanBloomFilter()
|
try:
|
||||||
|
rm.checkUnacknowledgedMessages(channelId)
|
||||||
|
rm.cleanBloomFilter(channelId)
|
||||||
|
except Exception:
|
||||||
|
error "Error in buffer sweep for channel",
|
||||||
|
channelId = channelId, msg = getCurrentExceptionMsg()
|
||||||
except Exception:
|
except Exception:
|
||||||
error "Error in periodic buffer sweep", msg = getCurrentExceptionMsg()
|
error "Error in periodic buffer sweep", msg = getCurrentExceptionMsg()
|
||||||
|
|
||||||
@ -336,13 +357,15 @@ proc resetReliabilityManager*(rm: ReliabilityManager): Result[void, ReliabilityE
|
|||||||
## This procedure clears all buffers and resets the Lamport timestamp.
|
## This procedure clears all buffers and resets the Lamport timestamp.
|
||||||
withLock rm.lock:
|
withLock rm.lock:
|
||||||
try:
|
try:
|
||||||
rm.lamportTimestamp = 0
|
for channelId, channel in rm.channels:
|
||||||
rm.messageHistory.setLen(0)
|
channel.lamportTimestamp = 0
|
||||||
rm.outgoingBuffer.setLen(0)
|
channel.messageHistory.setLen(0)
|
||||||
rm.incomingBuffer.clear()
|
channel.outgoingBuffer.setLen(0)
|
||||||
rm.bloomFilter = newRollingBloomFilter(
|
channel.incomingBuffer.clear()
|
||||||
rm.config.bloomFilterCapacity, rm.config.bloomFilterErrorRate
|
channel.bloomFilter = newRollingBloomFilter(
|
||||||
)
|
rm.config.bloomFilterCapacity, rm.config.bloomFilterErrorRate
|
||||||
|
)
|
||||||
|
rm.channels.clear()
|
||||||
return ok()
|
return ok()
|
||||||
except Exception:
|
except Exception:
|
||||||
error "Failed to reset ReliabilityManager", msg = getCurrentExceptionMsg()
|
error "Failed to reset ReliabilityManager", msg = getCurrentExceptionMsg()
|
||||||
|
|||||||
@ -1,14 +1,17 @@
|
|||||||
import std/[times, locks, options]
|
import std/[times, locks, tables]
|
||||||
import chronicles
|
import chronicles, results
|
||||||
import ./[rolling_bloom_filter, message]
|
import ./[rolling_bloom_filter, message]
|
||||||
|
|
||||||
type
|
type
|
||||||
MessageReadyCallback* = proc(messageId: SdsMessageID) {.gcsafe.}
|
MessageReadyCallback* =
|
||||||
|
proc(messageId: SdsMessageID, channelId: SdsChannelID) {.gcsafe.}
|
||||||
|
|
||||||
MessageSentCallback* = proc(messageId: SdsMessageID) {.gcsafe.}
|
MessageSentCallback* =
|
||||||
|
proc(messageId: SdsMessageID, channelId: SdsChannelID) {.gcsafe.}
|
||||||
|
|
||||||
MissingDependenciesCallback* =
|
MissingDependenciesCallback* = proc(
|
||||||
proc(messageId: SdsMessageID, missingDeps: seq[SdsMessageID]) {.gcsafe.}
|
messageId: SdsMessageID, missingDeps: seq[SdsMessageID], channelId: SdsChannelID
|
||||||
|
) {.gcsafe.}
|
||||||
|
|
||||||
PeriodicSyncCallback* = proc() {.gcsafe, raises: [].}
|
PeriodicSyncCallback* = proc() {.gcsafe, raises: [].}
|
||||||
|
|
||||||
@ -28,19 +31,22 @@ type
|
|||||||
syncMessageInterval*: Duration
|
syncMessageInterval*: Duration
|
||||||
bufferSweepInterval*: Duration
|
bufferSweepInterval*: Duration
|
||||||
|
|
||||||
ReliabilityManager* = ref object
|
ChannelContext* = ref object
|
||||||
lamportTimestamp*: int64
|
lamportTimestamp*: int64
|
||||||
messageHistory*: seq[SdsMessageID]
|
messageHistory*: seq[SdsMessageID]
|
||||||
bloomFilter*: RollingBloomFilter
|
bloomFilter*: RollingBloomFilter
|
||||||
outgoingBuffer*: seq[UnacknowledgedMessage]
|
outgoingBuffer*: seq[UnacknowledgedMessage]
|
||||||
incomingBuffer*: Table[SdsMessageID, IncomingMessage]
|
incomingBuffer*: Table[SdsMessageID, IncomingMessage]
|
||||||
channelId*: Option[SdsChannelID]
|
|
||||||
|
ReliabilityManager* = ref object
|
||||||
|
channels*: Table[SdsChannelID, ChannelContext]
|
||||||
config*: ReliabilityConfig
|
config*: ReliabilityConfig
|
||||||
lock*: Lock
|
lock*: Lock
|
||||||
onMessageReady*: proc(messageId: SdsMessageID) {.gcsafe.}
|
onMessageReady*: proc(messageId: SdsMessageID, channelId: SdsChannelID) {.gcsafe.}
|
||||||
onMessageSent*: proc(messageId: SdsMessageID) {.gcsafe.}
|
onMessageSent*: proc(messageId: SdsMessageID, channelId: SdsChannelID) {.gcsafe.}
|
||||||
onMissingDependencies*:
|
onMissingDependencies*: proc(
|
||||||
proc(messageId: SdsMessageID, missingDeps: seq[SdsMessageID]) {.gcsafe.}
|
messageId: SdsMessageID, missingDeps: seq[SdsMessageID], channelId: SdsChannelID
|
||||||
|
) {.gcsafe.}
|
||||||
onPeriodicSync*: PeriodicSyncCallback
|
onPeriodicSync*: PeriodicSyncCallback
|
||||||
|
|
||||||
ReliabilityError* {.pure.} = enum
|
ReliabilityError* {.pure.} = enum
|
||||||
@ -71,51 +77,168 @@ proc cleanup*(rm: ReliabilityManager) {.raises: [].} =
|
|||||||
if not rm.isNil():
|
if not rm.isNil():
|
||||||
try:
|
try:
|
||||||
withLock rm.lock:
|
withLock rm.lock:
|
||||||
rm.outgoingBuffer.setLen(0)
|
for channelId, channel in rm.channels:
|
||||||
rm.incomingBuffer.clear()
|
channel.outgoingBuffer.setLen(0)
|
||||||
rm.messageHistory.setLen(0)
|
channel.incomingBuffer.clear()
|
||||||
|
channel.messageHistory.setLen(0)
|
||||||
|
rm.channels.clear()
|
||||||
except Exception:
|
except Exception:
|
||||||
error "Error during cleanup", error = getCurrentExceptionMsg()
|
error "Error during cleanup", error = getCurrentExceptionMsg()
|
||||||
|
|
||||||
proc cleanBloomFilter*(rm: ReliabilityManager) {.gcsafe, raises: [].} =
|
proc cleanBloomFilter*(
|
||||||
|
rm: ReliabilityManager, channelId: SdsChannelID
|
||||||
|
) {.gcsafe, raises: [].} =
|
||||||
withLock rm.lock:
|
withLock rm.lock:
|
||||||
try:
|
try:
|
||||||
rm.bloomFilter.clean()
|
if channelId in rm.channels:
|
||||||
|
rm.channels[channelId].bloomFilter.clean()
|
||||||
except Exception:
|
except Exception:
|
||||||
error "Failed to clean bloom filter", error = getCurrentExceptionMsg()
|
error "Failed to clean bloom filter",
|
||||||
|
error = getCurrentExceptionMsg(), channelId = channelId
|
||||||
|
|
||||||
proc addToHistory*(rm: ReliabilityManager, msgId: SdsMessageID) {.gcsafe, raises: [].} =
|
proc addToHistory*(
|
||||||
rm.messageHistory.add(msgId)
|
rm: ReliabilityManager, msgId: SdsMessageID, channelId: SdsChannelID
|
||||||
if rm.messageHistory.len > rm.config.maxMessageHistory:
|
) {.gcsafe, raises: [].} =
|
||||||
rm.messageHistory.delete(0)
|
try:
|
||||||
|
if channelId in rm.channels:
|
||||||
|
let channel = rm.channels[channelId]
|
||||||
|
channel.messageHistory.add(msgId)
|
||||||
|
if channel.messageHistory.len > rm.config.maxMessageHistory:
|
||||||
|
channel.messageHistory.delete(0)
|
||||||
|
except Exception:
|
||||||
|
error "Failed to add to history",
|
||||||
|
channelId = channelId, msgId = msgId, error = getCurrentExceptionMsg()
|
||||||
|
|
||||||
proc updateLamportTimestamp*(
|
proc updateLamportTimestamp*(
|
||||||
rm: ReliabilityManager, msgTs: int64
|
rm: ReliabilityManager, msgTs: int64, channelId: SdsChannelID
|
||||||
) {.gcsafe, raises: [].} =
|
) {.gcsafe, raises: [].} =
|
||||||
rm.lamportTimestamp = max(msgTs, rm.lamportTimestamp) + 1
|
try:
|
||||||
|
if channelId in rm.channels:
|
||||||
|
let channel = rm.channels[channelId]
|
||||||
|
channel.lamportTimestamp = max(msgTs, channel.lamportTimestamp) + 1
|
||||||
|
except Exception:
|
||||||
|
error "Failed to update lamport timestamp",
|
||||||
|
channelId = channelId, msgTs = msgTs, error = getCurrentExceptionMsg()
|
||||||
|
|
||||||
proc getRecentSdsMessageIDs*(rm: ReliabilityManager, n: int): seq[SdsMessageID] =
|
proc getRecentSdsMessageIDs*(
|
||||||
result = rm.messageHistory[max(0, rm.messageHistory.len - n) .. ^1]
|
rm: ReliabilityManager, n: int, channelId: SdsChannelID
|
||||||
|
): seq[SdsMessageID] =
|
||||||
|
try:
|
||||||
|
if channelId in rm.channels:
|
||||||
|
let channel = rm.channels[channelId]
|
||||||
|
result = channel.messageHistory[max(0, channel.messageHistory.len - n) .. ^1]
|
||||||
|
else:
|
||||||
|
result = @[]
|
||||||
|
except Exception:
|
||||||
|
error "Failed to get recent message IDs",
|
||||||
|
channelId = channelId, n = n, error = getCurrentExceptionMsg()
|
||||||
|
result = @[]
|
||||||
|
|
||||||
proc checkDependencies*(
|
proc checkDependencies*(
|
||||||
rm: ReliabilityManager, deps: seq[SdsMessageID]
|
rm: ReliabilityManager, deps: seq[SdsMessageID], channelId: SdsChannelID
|
||||||
): seq[SdsMessageID] =
|
): seq[SdsMessageID] =
|
||||||
var missingDeps: seq[SdsMessageID] = @[]
|
var missingDeps: seq[SdsMessageID] = @[]
|
||||||
for depId in deps:
|
try:
|
||||||
if depId notin rm.messageHistory:
|
if channelId in rm.channels:
|
||||||
missingDeps.add(depId)
|
let channel = rm.channels[channelId]
|
||||||
|
for depId in deps:
|
||||||
|
if depId notin channel.messageHistory:
|
||||||
|
missingDeps.add(depId)
|
||||||
|
else:
|
||||||
|
missingDeps = deps
|
||||||
|
except Exception:
|
||||||
|
error "Failed to check dependencies",
|
||||||
|
channelId = channelId, error = getCurrentExceptionMsg()
|
||||||
|
missingDeps = deps
|
||||||
return missingDeps
|
return missingDeps
|
||||||
|
|
||||||
proc getMessageHistory*(rm: ReliabilityManager): seq[SdsMessageID] =
|
proc getMessageHistory*(
|
||||||
|
rm: ReliabilityManager, channelId: SdsChannelID
|
||||||
|
): seq[SdsMessageID] =
|
||||||
withLock rm.lock:
|
withLock rm.lock:
|
||||||
result = rm.messageHistory
|
try:
|
||||||
|
if channelId in rm.channels:
|
||||||
|
result = rm.channels[channelId].messageHistory
|
||||||
|
else:
|
||||||
|
result = @[]
|
||||||
|
except Exception:
|
||||||
|
error "Failed to get message history",
|
||||||
|
channelId = channelId, error = getCurrentExceptionMsg()
|
||||||
|
result = @[]
|
||||||
|
|
||||||
proc getOutgoingBuffer*(rm: ReliabilityManager): seq[UnacknowledgedMessage] =
|
proc getOutgoingBuffer*(
|
||||||
|
rm: ReliabilityManager, channelId: SdsChannelID
|
||||||
|
): seq[UnacknowledgedMessage] =
|
||||||
withLock rm.lock:
|
withLock rm.lock:
|
||||||
result = rm.outgoingBuffer
|
try:
|
||||||
|
if channelId in rm.channels:
|
||||||
|
result = rm.channels[channelId].outgoingBuffer
|
||||||
|
else:
|
||||||
|
result = @[]
|
||||||
|
except Exception:
|
||||||
|
error "Failed to get outgoing buffer",
|
||||||
|
channelId = channelId, error = getCurrentExceptionMsg()
|
||||||
|
result = @[]
|
||||||
|
|
||||||
proc getIncomingBuffer*(
|
proc getIncomingBuffer*(
|
||||||
rm: ReliabilityManager
|
rm: ReliabilityManager, channelId: SdsChannelID
|
||||||
): Table[SdsMessageID, message.IncomingMessage] =
|
): Table[SdsMessageID, message.IncomingMessage] =
|
||||||
withLock rm.lock:
|
withLock rm.lock:
|
||||||
result = rm.incomingBuffer
|
try:
|
||||||
|
if channelId in rm.channels:
|
||||||
|
result = rm.channels[channelId].incomingBuffer
|
||||||
|
else:
|
||||||
|
result = initTable[SdsMessageID, message.IncomingMessage]()
|
||||||
|
except Exception:
|
||||||
|
error "Failed to get incoming buffer",
|
||||||
|
channelId = channelId, error = getCurrentExceptionMsg()
|
||||||
|
result = initTable[SdsMessageID, message.IncomingMessage]()
|
||||||
|
|
||||||
|
proc getOrCreateChannel*(
    rm: ReliabilityManager, channelId: SdsChannelID
): ChannelContext =
  ## Returns the context registered for `channelId`, creating and registering
  ## a fresh one when the channel is not yet known.
  ##
  ## A new context starts with Lamport timestamp 0, empty history and
  ## buffers, and a rolling bloom filter built from `rm.config`.
  ##
  ## This proc does not acquire `rm.lock` itself. Any failure is logged and
  ## re-raised to the caller.
  try:
    if channelId in rm.channels:
      return rm.channels[channelId]

    # Unknown channel: build its state and register it before handing it out.
    let freshBloom = newRollingBloomFilter(
      rm.config.bloomFilterCapacity, rm.config.bloomFilterErrorRate
    )
    let created = ChannelContext(
      lamportTimestamp: 0,
      messageHistory: @[],
      bloomFilter: freshBloom,
      outgoingBuffer: @[],
      incomingBuffer: initTable[SdsMessageID, IncomingMessage](),
    )
    rm.channels[channelId] = created
    return created
  except Exception:
    error "Failed to get or create channel",
      channelId = channelId, error = getCurrentExceptionMsg()
    raise
|
||||||
|
|
||||||
|
proc ensureChannel*(
    rm: ReliabilityManager, channelId: SdsChannelID
): Result[void, ReliabilityError] =
  ## Makes sure a context exists for `channelId`, creating it when missing.
  ##
  ## Runs under `rm.lock`.
  ##
  ## Returns:
  ##   `ok()` once the channel exists, or `reInternalError` if creating or
  ##   looking it up raised.
  withLock rm.lock:
    try:
      # Only the side effect (registration) matters; the context is not needed.
      discard getOrCreateChannel(rm, channelId)
    except Exception:
      error "Failed to ensure channel",
        channelId = channelId, msg = getCurrentExceptionMsg()
      return err(ReliabilityError.reInternalError)
    return ok()
|
||||||
|
|
||||||
|
proc removeChannel*(
    rm: ReliabilityManager, channelId: SdsChannelID
): Result[void, ReliabilityError] =
  ## Drops the context registered for `channelId`, emptying its history and
  ## both buffers before removing it from the manager.
  ##
  ## Runs under `rm.lock`. Removing a channel that is not registered is not
  ## an error.
  ##
  ## Returns:
  ##   `ok()` on success (including when the channel was unknown), or
  ##   `reInternalError` if an exception was raised.
  withLock rm.lock:
    try:
      if channelId notin rm.channels:
        # Nothing registered under this ID; report success without changes.
        return ok()

      let ctx = rm.channels[channelId]
      ctx.messageHistory.setLen(0)
      ctx.incomingBuffer.clear()
      ctx.outgoingBuffer.setLen(0)
      rm.channels.del(channelId)
      return ok()
    except Exception:
      error "Failed to remove channel",
        channelId = channelId, msg = getCurrentExceptionMsg()
      return err(ReliabilityError.reInternalError)
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user