feat: updates for retrieval hint

This commit is contained in:
shash256 2025-07-24 20:06:23 +05:30
parent 4e10d77218
commit b7b24c6747
10 changed files with 180 additions and 36 deletions

View File

@ -5,6 +5,10 @@ type SdsCallBack* = proc(
callerRet: cint, msg: ptr cchar, len: csize_t, userData: pointer
) {.cdecl, gcsafe, raises: [].}
type SdsRetrievalHintProvider* = proc(
messageId: cstring, hint: ptr cstring, hintLen: ptr csize_t, userData: pointer
) {.cdecl, gcsafe, raises: [].}
const RET_OK*: cint = 0
const RET_ERR*: cint = 1
const RET_MISSING_CALLBACK*: cint = 2

View File

@ -20,6 +20,8 @@ extern "C" {
typedef void (*SdsCallBack) (int callerRet, const char* msg, size_t len, void* userData);
typedef void (*SdsRetrievalHintProvider) (const char* messageId, char** hint, size_t* hintLen, void* userData);
// --- Core API Functions ---
@ -28,6 +30,8 @@ void* SdsNewReliabilityManager(SdsCallBack callback, void* userData);
void SdsSetEventCallback(void* ctx, SdsCallBack callback, void* userData);
void SdsSetRetrievalHintProvider(void* ctx, SdsRetrievalHintProvider callback, void* userData);
int SdsCleanupReliabilityManager(void* ctx, SdsCallBack callback, void* userData);
int SdsResetReliabilityManager(void* ctx, SdsCallBack callback, void* userData);

View File

@ -91,6 +91,22 @@ proc onPeriodicSync(ctx: ptr SdsContext): PeriodicSyncCallback =
callEventCallback(ctx, "onPeriodicSync"):
$JsonPeriodicSyncEvent.new()
proc onRetrievalHint(ctx: ptr SdsContext): RetrievalHintProvider =
return proc(messageId: SdsMessageID): seq[byte] {.gcsafe.} =
if isNil(ctx.retrievalHintProvider):
return @[]
var hint: cstring
var hintLen: csize_t
cast[SdsRetrievalHintProvider](ctx.retrievalHintProvider)(
messageId.cstring, addr hint, addr hintLen, ctx.retrievalHintUserData
)
if not isNil(hint) and hintLen > 0:
result = newSeq[byte](hintLen)
copyMem(addr result[0], hint, hintLen)
deallocShared(hint)
### End of not-exported components
################################################################################
@ -153,6 +169,7 @@ proc SdsNewReliabilityManager(
messageSentCb: onMessageSent(ctx),
missingDependenciesCb: onMissingDependencies(ctx),
periodicSyncCb: onPeriodicSync(ctx),
retrievalHintProvider: onRetrievalHint(ctx),
)
let retCode = handleRequest(
@ -177,6 +194,13 @@ proc SdsSetEventCallback(
ctx[].eventCallback = cast[pointer](callback)
ctx[].eventUserData = userData
proc SdsSetRetrievalHintProvider(
ctx: ptr SdsContext, callback: SdsRetrievalHintProvider, userData: pointer
) {.dynlib, exportc.} =
initializeLibrary()
ctx[].retrievalHintProvider = cast[pointer](callback)
ctx[].retrievalHintUserData = userData
proc SdsCleanupReliabilityManager(
ctx: ptr SdsContext, callback: SdsCallBack, userData: pointer
): cint {.dynlib, exportc.} =

View File

@ -65,7 +65,7 @@ proc process*(
error "UNWRAP_MESSAGE failed", error = error
return err("error processing UNWRAP_MESSAGE request: " & $error)
let res = SdsUnwrapResponse(message: unwrappedMessage, missingDeps: missingDeps)
let res = SdsUnwrapResponse(message: unwrappedMessage, missingDeps: missingDeps.getMessageIds())
# return the result as a json string
return ok($(%*(res)))

View File

@ -20,6 +20,8 @@ type SdsContext* = object
userData*: pointer
eventCallback*: pointer
eventUserdata*: pointer
retrievalHintProvider*: pointer
retrievalHintUserData*: pointer
running: Atomic[bool] # To control when the thread is running
proc runSds(ctx: ptr SdsContext) {.async.} =

View File

@ -4,10 +4,14 @@ type
SdsMessageID* = string
SdsChannelID* = string
HistoryEntry* = object
messageId*: SdsMessageID
retrievalHint*: seq[byte] ## Optional hint for efficient retrieval (e.g., Waku message hash)
SdsMessage* = object
messageId*: SdsMessageID
lamportTimestamp*: int64
causalHistory*: seq[SdsMessageID]
causalHistory*: seq[HistoryEntry]
channelId*: SdsChannelID
content*: seq[byte]
bloomFilter*: seq[byte]
@ -29,3 +33,24 @@ const
DefaultSyncMessageInterval* = initDuration(seconds = 30)
DefaultBufferSweepInterval* = initDuration(seconds = 60)
MaxMessageSize* = 1024 * 1024 # 1 MB
# Helper functions for HistoryEntry
proc newHistoryEntry*(messageId: SdsMessageID, retrievalHint: seq[byte] = @[]): HistoryEntry =
## Creates a new HistoryEntry with optional retrieval hint
HistoryEntry(messageId: messageId, retrievalHint: retrievalHint)
proc newHistoryEntry*(messageId: SdsMessageID, retrievalHint: string): HistoryEntry =
## Creates a new HistoryEntry with string retrieval hint
HistoryEntry(messageId: messageId, retrievalHint: cast[seq[byte]](retrievalHint))
proc toCausalHistory*(messageIds: seq[SdsMessageID]): seq[HistoryEntry] =
## Converts a sequence of message IDs to HistoryEntry sequence (for backward compatibility)
result = newSeq[HistoryEntry](messageIds.len)
for i, msgId in messageIds:
result[i] = newHistoryEntry(msgId)
proc getMessageIds*(causalHistory: seq[HistoryEntry]): seq[SdsMessageID] =
## Extracts message IDs from HistoryEntry sequence
result = newSeq[SdsMessageID](causalHistory.len)
for i, entry in causalHistory:
result[i] = entry.messageId

View File

@ -9,8 +9,13 @@ proc encode*(msg: SdsMessage): ProtoBuffer =
pb.write(1, msg.messageId)
pb.write(2, uint64(msg.lamportTimestamp))
for hist in msg.causalHistory:
pb.write(3, hist)
for entry in msg.causalHistory:
var entryPb = initProtoBuffer()
entryPb.write(1, entry.messageId)
if entry.retrievalHint.len > 0:
entryPb.write(2, entry.retrievalHint)
entryPb.finish()
pb.write(3, entryPb.buffer)
pb.write(4, msg.channelId)
pb.write(5, msg.content)
@ -31,10 +36,24 @@ proc decode*(T: type SdsMessage, buffer: seq[byte]): ProtobufResult[T] =
return err(ProtobufError.missingRequiredField("lamportTimestamp"))
msg.lamportTimestamp = int64(timestamp)
var causalHistory: seq[SdsMessageID]
let histResult = pb.getRepeatedField(3, causalHistory)
if histResult.isOk:
msg.causalHistory = causalHistory
# Handle both old and new causal history formats
var historyBuffers: seq[seq[byte]]
if pb.getRepeatedField(3, historyBuffers).isOk:
# New format: repeated HistoryEntry
for histBuffer in historyBuffers:
let entryPb = initProtoBuffer(histBuffer)
var entry: HistoryEntry
if not ?entryPb.getField(1, entry.messageId):
return err(ProtobufError.missingRequiredField("HistoryEntry.messageId"))
# retrievalHint is optional
discard entryPb.getField(2, entry.retrievalHint)
msg.causalHistory.add(entry)
else:
# Try old format: repeated string
var causalHistory: seq[SdsMessageID]
let histResult = pb.getRepeatedField(3, causalHistory)
if histResult.isOk:
msg.causalHistory = toCausalHistory(causalHistory)
if not ?pb.getField(4, msg.channelId):
return err(ProtobufError.missingRequiredField("channelId"))

View File

@ -24,10 +24,10 @@ proc newReliabilityManager*(
proc isAcknowledged*(
msg: UnacknowledgedMessage,
causalHistory: seq[SdsMessageID],
causalHistory: seq[HistoryEntry],
rbf: Option[RollingBloomFilter],
): bool =
if msg.message.messageId in causalHistory:
if msg.message.messageId in causalHistory.getMessageIds():
return true
if rbf.isSome():
@ -112,7 +112,7 @@ proc wrapOutgoingMessage*(
let msg = SdsMessage(
messageId: messageId,
lamportTimestamp: channel.lamportTimestamp,
causalHistory: rm.getRecentSdsMessageIDs(rm.config.maxCausalHistory, channelId),
causalHistory: rm.getRecentHistoryEntries(rm.config.maxCausalHistory, channelId),
channelId: channelId,
content: message,
bloomFilter: bfResult.get(),
@ -176,7 +176,7 @@ proc processIncomingBuffer(rm: ReliabilityManager, channelId: SdsChannelID) {.gc
proc unwrapReceivedMessage*(
rm: ReliabilityManager, message: seq[byte]
): Result[
tuple[message: seq[byte], missingDeps: seq[SdsMessageID], channelId: SdsChannelID],
tuple[message: seq[byte], missingDeps: seq[HistoryEntry], channelId: SdsChannelID],
ReliabilityError,
] =
## Unwraps a received message and processes its reliability metadata.
@ -209,7 +209,7 @@ proc unwrapReceivedMessage*(
if missingDeps.len == 0:
var depsInBuffer = false
for msgId, entry in channel.incomingBuffer.pairs():
if msgId in msg.causalHistory:
if msgId in msg.causalHistory.getMessageIds():
depsInBuffer = true
break
# Check if any dependencies are still in incoming buffer
@ -224,9 +224,9 @@ proc unwrapReceivedMessage*(
rm.onMessageReady(msg.messageId, channelId)
else:
channel.incomingBuffer[msg.messageId] =
IncomingMessage(message: msg, missingDeps: missingDeps.toHashSet())
IncomingMessage(message: msg, missingDeps: missingDeps.getMessageIds().toHashSet())
if not rm.onMissingDependencies.isNil():
rm.onMissingDependencies(msg.messageId, missingDeps, channelId)
rm.onMissingDependencies(msg.messageId, missingDeps.getMessageIds(), channelId)
return ok((msg.content, missingDeps, channelId))
except Exception:
@ -271,6 +271,7 @@ proc setCallbacks*(
onMessageSent: MessageSentCallback,
onMissingDependencies: MissingDependenciesCallback,
onPeriodicSync: PeriodicSyncCallback = nil,
onRetrievalHint: RetrievalHintProvider = nil
) =
## Sets the callback functions for various events in the ReliabilityManager.
##
@ -279,11 +280,13 @@ proc setCallbacks*(
## - onMessageSent: Callback function called when a message is confirmed as sent.
## - onMissingDependencies: Callback function called when a message has missing dependencies.
## - onPeriodicSync: Callback function called to notify about periodic sync
## - onRetrievalHint: Callback function called to get a retrieval hint for a message ID.
withLock rm.lock:
rm.onMessageReady = onMessageReady
rm.onMessageSent = onMessageSent
rm.onMissingDependencies = onMissingDependencies
rm.onPeriodicSync = onPeriodicSync
rm.onRetrievalHint = onRetrievalHint
proc checkUnacknowledgedMessages(
rm: ReliabilityManager, channelId: SdsChannelID

View File

@ -13,6 +13,8 @@ type
messageId: SdsMessageID, missingDeps: seq[SdsMessageID], channelId: SdsChannelID
) {.gcsafe.}
RetrievalHintProvider* = proc(messageId: SdsMessageID): seq[byte] {.gcsafe.}
PeriodicSyncCallback* = proc() {.gcsafe, raises: [].}
AppCallbacks* = ref object
@ -20,6 +22,7 @@ type
messageSentCb*: MessageSentCallback
missingDependenciesCb*: MissingDependenciesCallback
periodicSyncCb*: PeriodicSyncCallback
retrievalHintProvider*: RetrievalHintProvider
ReliabilityConfig* = object
bloomFilterCapacity*: int
@ -48,6 +51,7 @@ type
messageId: SdsMessageID, missingDeps: seq[SdsMessageID], channelId: SdsChannelID
) {.gcsafe.}
onPeriodicSync*: PeriodicSyncCallback
onRetrievalHint*: RetrievalHintProvider
ReliabilityError* {.pure.} = enum
reInvalidArgument
@ -120,30 +124,36 @@ proc updateLamportTimestamp*(
error "Failed to update lamport timestamp",
channelId = channelId, msgTs = msgTs, error = getCurrentExceptionMsg()
proc getRecentSdsMessageIDs*(
proc getRecentHistoryEntries*(
rm: ReliabilityManager, n: int, channelId: SdsChannelID
): seq[SdsMessageID] =
): seq[HistoryEntry] =
try:
if channelId in rm.channels:
let channel = rm.channels[channelId]
result = channel.messageHistory[max(0, channel.messageHistory.len - n) .. ^1]
let recentMessageIds = channel.messageHistory[max(0, channel.messageHistory.len - n) .. ^1]
if rm.onRetrievalHint.isNil():
return toCausalHistory(recentMessageIds)
else:
for msgId in recentMessageIds:
let hint = rm.onRetrievalHint(msgId)
result.add(newHistoryEntry(msgId, hint))
else:
result = @[]
except Exception:
error "Failed to get recent message IDs",
error "Failed to get recent history entries",
channelId = channelId, n = n, error = getCurrentExceptionMsg()
result = @[]
proc checkDependencies*(
rm: ReliabilityManager, deps: seq[SdsMessageID], channelId: SdsChannelID
): seq[SdsMessageID] =
var missingDeps: seq[SdsMessageID] = @[]
rm: ReliabilityManager, deps: seq[HistoryEntry], channelId: SdsChannelID
): seq[HistoryEntry] =
var missingDeps: seq[HistoryEntry] = @[]
try:
if channelId in rm.channels:
let channel = rm.channels[channelId]
for depId in deps:
if depId notin channel.messageHistory:
missingDeps.add(depId)
for dep in deps:
if dep.messageId notin channel.messageHistory:
missingDeps.add(dep)
else:
missingDeps = deps
except Exception:

View File

@ -112,7 +112,7 @@ suite "Reliability Mechanisms":
let msg2 = SdsMessage(
messageId: id2,
lamportTimestamp: 2,
causalHistory: @[id1], # msg2 depends on msg1
causalHistory: toCausalHistory(@[id1]), # msg2 depends on msg1
channelId: testChannel,
content: @[byte(2)],
bloomFilter: @[],
@ -121,7 +121,7 @@ suite "Reliability Mechanisms":
let msg3 = SdsMessage(
messageId: id3,
lamportTimestamp: 3,
causalHistory: @[id1, id2], # msg3 depends on both msg1 and msg2
causalHistory: toCausalHistory(@[id1, id2]), # msg3 depends on both msg1 and msg2
channelId: testChannel,
content: @[byte(3)],
bloomFilter: @[],
@ -141,8 +141,8 @@ suite "Reliability Mechanisms":
check:
missingDepsCount == 1 # Should trigger missing deps callback
missingDeps3.len == 2 # Should be missing both msg1 and msg2
id1 in missingDeps3
id2 in missingDeps3
id1 in missingDeps3.getMessageIds()
id2 in missingDeps3.getMessageIds()
# Then try processing msg2 (which only depends on msg1)
let unwrapResult2 = rm.unwrapReceivedMessage(serialized2.get())
@ -152,7 +152,7 @@ suite "Reliability Mechanisms":
check:
missingDepsCount == 2 # Should have triggered another missing deps callback
missingDeps2.len == 1 # Should only be missing msg1
id1 in missingDeps2
id1 in missingDeps2.getMessageIds()
messageReadyCount == 0 # No messages should be ready yet
# Mark first dependency (msg1) as met
@ -190,7 +190,7 @@ suite "Reliability Mechanisms":
let msg2 = SdsMessage(
messageId: "msg2",
lamportTimestamp: rm.channels[testChannel].lamportTimestamp + 1,
causalHistory: @[id1], # Include our message in causal history
causalHistory: toCausalHistory(@[id1]), # Include our message in causal history
channelId: testChannel,
content: @[byte(2)],
bloomFilter: @[] # Test with an empty bloom filter
@ -251,6 +251,59 @@ suite "Reliability Mechanisms":
check messageSentCount == 1 # Our message should be acknowledged via bloom filter
test "retrieval hints":
var messageReadyCount = 0
var messageSentCount = 0
var missingDepsCount = 0
rm.setCallbacks(
proc(messageId: SdsMessageID, channelId: SdsChannelID) {.gcsafe.} =
messageReadyCount += 1,
proc(messageId: SdsMessageID, channelId: SdsChannelID) {.gcsafe.} =
messageSentCount += 1,
proc(messageId: SdsMessageID, missingDeps: seq[SdsMessageID], channelId: SdsChannelID) {.gcsafe.} =
missingDepsCount += 1,
nil,
proc(messageId: SdsMessageID): seq[byte] =
return cast[seq[byte]]("hint:" & messageId)
)
# Send a first message to populate history
let msg1 = @[byte(1)]
let id1 = "msg1"
let wrap1 = rm.wrapOutgoingMessage(msg1, id1, testChannel)
check wrap1.isOk()
# Send a second message, which should have the first in its causal history
let msg2 = @[byte(2)]
let id2 = "msg2"
let wrap2 = rm.wrapOutgoingMessage(msg2, id2, testChannel)
check wrap2.isOk()
# Check that the wrapped message contains the hint
let unwrappedMsg2 = deserializeMessage(wrap2.get()).get()
check unwrappedMsg2.causalHistory.len > 0
check unwrappedMsg2.causalHistory[0].messageId == id1
check unwrappedMsg2.causalHistory[0].retrievalHint == cast[seq[byte]]("hint:" & id1)
# Create a message with a missing dependency
let msg3 = SdsMessage(
messageId: "msg3",
lamportTimestamp: 3,
causalHistory: toCausalHistory(@["missing-dep"]),
channelId: testChannel,
content: @[byte(3)],
bloomFilter: @[],
)
let serialized3 = serializeMessage(msg3).get()
let unwrapResult3 = rm.unwrapReceivedMessage(serialized3)
check unwrapResult3.isOk()
let (_, missingDeps3, _) = unwrapResult3.get()
check missingDeps3.len == 1
check missingDeps3[0].messageId == "missing-dep"
# The hint is empty because it was not in our history, so the provider was not called
check missingDeps3[0].retrievalHint.len == 0
# Periodic task & Buffer management tests
suite "Periodic Tasks & Buffer Management":
var rm: ReliabilityManager
@ -291,7 +344,7 @@ suite "Periodic Tasks & Buffer Management":
let ackMsg = SdsMessage(
messageId: "ack1",
lamportTimestamp: rm.channels[testChannel].lamportTimestamp + 1,
causalHistory: @["msg0", "msg2", "msg4"],
causalHistory: toCausalHistory(@["msg0", "msg2", "msg4"]),
channelId: testChannel,
content: @[byte(100)],
bloomFilter: @[],
@ -420,7 +473,7 @@ suite "Special Cases Handling":
let msgInvalid = SdsMessage(
messageId: "invalid-bf",
lamportTimestamp: 1,
causalHistory: @[],
causalHistory: toCausalHistory(@[]),
channelId: testChannel,
content: @[byte(1)],
bloomFilter: @[1.byte, 2.byte, 3.byte] # Invalid filter data
@ -451,7 +504,7 @@ suite "Special Cases Handling":
let msg = SdsMessage(
messageId: "dup-msg",
lamportTimestamp: 1,
causalHistory: @[],
causalHistory: toCausalHistory(@[]),
channelId: testChannel,
content: @[byte(1)],
bloomFilter: @[],
@ -624,7 +677,7 @@ suite "Multi-Channel ReliabilityManager Tests":
let ackMsg1 = SdsMessage(
messageId: "ack1",
lamportTimestamp: rm.channels[channel1].lamportTimestamp + 1,
causalHistory: @[msgId1], # Acknowledge msg1
causalHistory: toCausalHistory(@[msgId1]), # Acknowledge msg1
channelId: channel1,
content: @[byte(100)],
bloomFilter: @[],
@ -633,7 +686,7 @@ suite "Multi-Channel ReliabilityManager Tests":
let ackMsg2 = SdsMessage(
messageId: "ack2",
lamportTimestamp: rm.channels[channel2].lamportTimestamp + 1,
causalHistory: @[msgId2], # Acknowledge msg2
causalHistory: toCausalHistory(@[msgId2]), # Acknowledge msg2
channelId: channel2,
content: @[byte(101)],
bloomFilter: @[],