From 8b0e21fadaf3a4115cf050a95f6903eb6230b33b Mon Sep 17 00:00:00 2001 From: Ivan FB <128452529+Ivansete-status@users.noreply.github.com> Date: Tue, 2 Jun 2026 14:16:13 +0200 Subject: [PATCH] enhance reliable channel segment states (#3919) --- channels/reliable_channel.nim | 308 +++++++++--------- .../test_reliable_channel_send_receive.nim | 99 ++++++ 2 files changed, 257 insertions(+), 150 deletions(-) diff --git a/channels/reliable_channel.nim b/channels/reliable_channel.nim index c3fbe5d77..e32b57e36 100644 --- a/channels/reliable_channel.nim +++ b/channels/reliable_channel.nim @@ -13,7 +13,7 @@ ## ## See: https://lip.logos.co/messaging/raw/reliable-channel-api.html -import std/[options, sets, tables] +import std/[options, tables] import results import chronos import bearssl/rand @@ -55,32 +55,35 @@ type Persistent Ephemeral - SegmentSendState {.pure.} = enum - ## Lifecycle of a single segment as tracked by the channel. The - ## messaging layer has its own richer `DeliveryState` (retries, - ## propagated-vs-validated); here we only model what's needed to - ## decide when a `channelReqId` is fully accounted for. - AwaitingRateLimit ## Pushed by `send`; not yet released by rate_limit_manager. - InFlight - ## Released by rate_limit_manager and handed to delivery_service; - ## `messagingReqId` is now set. - Confirmed ## `MessageSentEvent` arrived for `messagingReqId`. - Failed - ## `MessageErrorEvent` arrived for `messagingReqId`, or the local - ## delivery-task construction failed before any id was reachable. - - PendingMessagingRequest = object - ## One entry per segment (i.e. per messaging-layer request). The - ## relative order of `AwaitingRateLimit` entries must match the - ## order in which `rate_limit_manager` re-emits messages, which is - ## FIFO with `send()`. - channelReqId*: RequestId - ## The channel-layer parent id returned to the caller of `send()` in channel layer. - ## One channel request maps to N pending messaging requests. - messagingReqId*: Option[RequestId] - ## Per-segment messaging layer id. `none` until `onReadyToSend` assigns it. + ChannelReqState = object + ## Per channel-level request, tracks how many of its segments are + ## still queued, in flight, or have terminated. The channel-level + ## final event fires when `confirmedCount + failedCount` reaches + ## `totalExpectedSegments` AND no segments are still awaiting dispatch + ## or in flight. persistenceReqType: MessagePersistence - segmentSendState*: SegmentSendState + totalExpectedSegments: int + ## Total segments produced by `segmentation.performSegmentation` + ## for this `channelReqId`. Set once in `send`, never mutated. + awaitingDispatch: int + ## Segments enqueued in `rate_limit_manager` but not yet claimed + ## by `onReadyToSend`. Decremented when `onReadyToSend` picks a + ## message and assigns it to this `channelReqId`. + inflightMessagingIds: seq[RequestId] + ## Messaging-layer ids minted by the send handler that have not + ## yet produced a final event. Removed on `MessageSentEvent` / `MessageErrorEvent`. + confirmedCount: int + failedCount: int + + ChannelReqs = OrderedTable[RequestId, ChannelReqState] + ## Key: channelReqId (the parent id returned by channel `send`). Value: + ## per-request state, see `ChannelReqState`. + ## + ## `OrderedTable` preserves insertion order, which matches the FIFO + ## order `rate_limit_manager` re-emits messages in: `onReadyToSend` + ## routes each segment to the first entry with `awaitingDispatch > 0`, + ## and that scan is correct precisely because the outer iteration + ## order matches the order `send` pushed entries. ReliableChannel* = ref object ## Spec-defined public type. Fields are private so callers cannot @@ -95,13 +98,23 @@ type sdsHandler: SdsHandler rateLimit: RateLimitManager - requestIds: Table[RequestId, seq[RequestId]] - pendingMessagingRequests: seq[PendingMessagingRequest] - ## Entries are kept until the matching segment reaches a final - ## state (`Confirmed` or `Failed`); a whole channel request is - ## then pruned in one pass once all its segments are final. + channelReqs: ChannelReqs brokerCtx: BrokerContext +func init( + T: type ChannelReqState, + persistenceReqType: MessagePersistence, + totalExpectedSegments: int, +): T = + return ChannelReqState( + persistenceReqType: persistenceReqType, + totalExpectedSegments: totalExpectedSegments, + awaitingDispatch: totalExpectedSegments, + inflightMessagingIds: @[], + confirmedCount: 0, + failedCount: 0, + ) + func getChannelId*(self: ReliableChannel): ChannelId {.inline.} = self.channelId @@ -111,70 +124,92 @@ func getContentTopic*(self: ReliableChannel): ContentTopic {.inline.} = func getSenderId*(self: ReliableChannel): SdsParticipantID {.inline.} = self.senderId -func isFinal(state: SegmentSendState): bool {.inline.} = - return state in {SegmentSendState.Confirmed, SegmentSendState.Failed} +proc tryFinalizeChannelReq(self: ReliableChannel, channelReqId: RequestId) = + ## Tries to finalize the channel-level request identified by `channelReqId` if + ## certain conditions are met, i.e., no segments are still awaiting dispatch or in flight, + ## and the total number of confirmed + failed segments equals the total expected segments. + ## Therefore, the channel-level request is removed from `self.channelReqs` + ## and the appropriate final event is emitted. + ## + let state = self.channelReqs.getOrDefault(channelReqId) + if state.totalExpectedSegments == 0: + ## Either already finalized (and removed) or never inserted. + return + if state.awaitingDispatch != 0 or state.inflightMessagingIds.len != 0: + return + if state.confirmedCount + state.failedCount < state.totalExpectedSegments: + return -proc pruneCompletedChannelReqs(self: ReliableChannel) = - ## Drop every `pendingMessagingRequests` entry whose `channelReqId` - ## has all of its segments in a final state. A single failing - ## segment doesn't trigger a drop on its own — we wait until siblings - ## are also accounted for, so the channel-level outcome is decided - ## from a complete picture. For each fully-final `channelReqId`, emit - ## the channel-level final event before the entries are dropped: - ## `ChannelMessageSentEvent` if every sibling Confirmed, - ## `ChannelMessageErrorEvent` if any sibling Failed. - var hasPending = initHashSet[RequestId]() - var anyFailed = initHashSet[RequestId]() - for entry in self.pendingMessagingRequests: - if not entry.segmentSendState.isFinal(): - hasPending.incl(entry.channelReqId) - elif entry.segmentSendState == SegmentSendState.Failed: - anyFailed.incl(entry.channelReqId) + self.channelReqs.del(channelReqId) - var emitted = initHashSet[RequestId]() - for entry in self.pendingMessagingRequests: - if entry.channelReqId in hasPending or entry.channelReqId in emitted: + if state.failedCount > 0: + ChannelMessageErrorEvent.emit( + self.brokerCtx, + ChannelMessageErrorEvent( + channelId: self.channelId, + requestId: channelReqId, + error: "one or more segments failed", + ), + ) + else: + ChannelMessageSentEvent.emit( + self.brokerCtx, + ChannelMessageSentEvent(channelId: self.channelId, requestId: channelReqId), + ) + +type ClaimedSegment = object + channelReqId: RequestId + isEphemeral: bool + +proc claimAwaitingChannelReq(self: ReliableChannel): Option[ClaimedSegment] = + for channelReqId, state in self.channelReqs.mpairs: + if state.awaitingDispatch > 0: + state.awaitingDispatch.dec() + return some( + ClaimedSegment( + channelReqId: channelReqId, + isEphemeral: state.persistenceReqType == MessagePersistence.Ephemeral, + ) + ) + return none(ClaimedSegment) + +type MessagingOutcome {.pure.} = enum + Sent + Failed + +proc onMessageFinal( + self: ReliableChannel, messagingReqId: RequestId, outcome: MessagingOutcome +) = + for channelReqId, state in self.channelReqs.mpairs: + let idx = state.inflightMessagingIds.find(messagingReqId) + if idx < 0: continue - emitted.incl(entry.channelReqId) - if entry.channelReqId in anyFailed: - ChannelMessageErrorEvent.emit( - self.brokerCtx, - ChannelMessageErrorEvent( - channelId: self.channelId, - requestId: entry.channelReqId, - error: "one or more segments failed", - ), - ) - else: - ChannelMessageSentEvent.emit( - self.brokerCtx, - ChannelMessageSentEvent( - channelId: self.channelId, requestId: entry.channelReqId - ), - ) + state.inflightMessagingIds.del(idx) + case outcome + of MessagingOutcome.Sent: + state.confirmedCount.inc() + of MessagingOutcome.Failed: + state.failedCount.inc() + self.tryFinalizeChannelReq(channelReqId) + return - self.pendingMessagingRequests.keepItIf(it.channelReqId in hasPending) +proc markSegmentFailed(self: ReliableChannel, channelReqId: RequestId) = + try: + self.channelReqs[channelReqId].failedCount.inc() + except KeyError as e: + error "unreachable: channelReqId not found in markSegmentFailed", + channelReqId = $channelReqId, error = e.msg + return + self.tryFinalizeChannelReq(channelReqId) -proc onMessageSent(self: ReliableChannel, messagingReqId: RequestId) = - ## Invoked from this channel's `MessageSentEvent` listener. Flips - ## the matching `InFlight` segment to `Confirmed` and prunes. The - ## listener routes every event through here; entries that don't - ## belong to this channel simply don't match and are no-ops. - self.pendingMessagingRequests.applyItIf( - it.segmentSendState == SegmentSendState.InFlight and - it.messagingReqId == some(messagingReqId) - ): - it.segmentSendState = SegmentSendState.Confirmed - self.pruneCompletedChannelReqs() - -proc onMessageError(self: ReliableChannel, messagingReqId: RequestId) = - ## Symmetric to `onMessageSent` but for `MessageErrorEvent`. - self.pendingMessagingRequests.applyItIf( - it.segmentSendState == SegmentSendState.InFlight and - it.messagingReqId == some(messagingReqId) - ): - it.segmentSendState = SegmentSendState.Failed - self.pruneCompletedChannelReqs() +proc markSegmentInflight( + self: ReliableChannel, channelReqId: RequestId, messagingReqId: RequestId +) = + try: + self.channelReqs[channelReqId].inflightMessagingIds.add(messagingReqId) + except KeyError as e: + error "unreachable: channelReqId not found in markSegmentInflight", + channelReqId = $channelReqId, error = e.msg proc onReadyToSend( self: ReliableChannel, readyToSendEvent: ReadyToSendEvent @@ -184,30 +219,22 @@ proc onReadyToSend( ## blobs (already-encoded SDS messages): ## ## ... -> rate_limit_manager -> [encryption] -> dispatch - var idx = 0 + ## + ## For each `m`, the next channelReqId still queued in rate-limit + ## claims the slot (FIFO across sibling sends). The channelReqId is + ## captured up front and used as a stable key for every later state + ## update — no positional index is ever held across an `await`, so + ## sibling events mutating other entries (or even this one's + ## `inflightMessagingIds`) cannot corrupt this fiber's view. for m in readyToSendEvent.msgs: - ## The first `AwaitingRateLimit` entry in push order is the one - ## this `m` belongs to: `send()` adds one entry per segment, and - ## `rate_limit_manager` re-emits them in the same FIFO order, so - ## the two sequences advance in lockstep. Earlier entries may - ## already be `InFlight` / `Confirmed` / `Failed` because they - ## live on until every sibling of their `channelReqId` is final, - ## so we walk past those to find the next one that was awaiting for this batch. - while idx < self.pendingMessagingRequests.len and - self.pendingMessagingRequests[idx].segmentSendState != - SegmentSendState.AwaitingRateLimit - : - idx.inc() - if idx >= self.pendingMessagingRequests.len: + let claimed = self.claimAwaitingChannelReq().valueOr: ## rate_limit_manager emitted more messages than we have pending — - ## should not happen given `send` pushes one entry per enqueued - ## SDS payload. Drop silently rather than corrupt state. + ## should not happen given `send` increments `awaitingDispatch` + ## once per enqueued SDS payload. Drop silently rather than + ## corrupt state. break - - let channelReqId = self.pendingMessagingRequests[idx].channelReqId - let isEphemeral = - self.pendingMessagingRequests[idx].persistenceReqType == - MessagePersistence.Ephemeral + let channelReqId = claimed.channelReqId + let isEphemeral = claimed.isEphemeral ## TODO: revisit which fields of the SDS message must be encrypted. ## Encrypting the whole encoded blob forces every receiver to attempt @@ -223,15 +250,7 @@ proc onReadyToSend( ), ) ## Encryption failed *before* we could hand the segment to the - ## delivery layer — no `messagingReqId` was minted and no - ## `DeliveryTask` was queued on `sendService`. The delivery - ## layer will therefore never emit a `MessageSentEvent` / - ## `MessageErrorEvent` for this segment, so `onMessageError` - ## won't fire either. Advance the state machine inline so the - ## parent `channelReqId` can still be pruned once its siblings - ## are also final. - self.pendingMessagingRequests[idx].segmentSendState = SegmentSendState.Failed - idx.inc() + self.markSegmentFailed(channelReqId) continue let wireBytes = seq[byte](encrypted) @@ -261,16 +280,10 @@ proc onReadyToSend( requestId: channelReqId, messageHash: "", error: "waku send failed: " & error ), ) - self.pendingMessagingRequests[idx].segmentSendState = SegmentSendState.Failed - idx.inc() + self.markSegmentFailed(channelReqId) continue - self.pendingMessagingRequests[idx].messagingReqId = some(messagingReqId) - self.pendingMessagingRequests[idx].segmentSendState = SegmentSendState.InFlight - self.requestIds.mgetOrPut(channelReqId, @[]).add(messagingReqId) - idx.inc() - - self.pruneCompletedChannelReqs() + self.markSegmentInflight(channelReqId, messagingReqId) proc send*( self: ReliableChannel, payload: seq[byte], ephemeral: bool = false @@ -283,23 +296,20 @@ proc send*( ## ## `rate_limit_manager.enqueueToSend` emits a `ReadyToSendEvent` with ## the SDS messages cleared for transmission; the channel's listener - ## then runs the final stage (encryption -> dispatch). The - ## `persistenceReqType` is carried alongside each segment in - ## `pendingMessagingRequests` and stamped onto the eventual - ## `MessageEnvelope`. + ## then runs the final stage (encryption -> dispatch). ## ## The returned `RequestId` is the channel-level parent of one-or-more - ## messaging-layer `RequestId`s; the mapping is recorded in - ## `self.requestIds`. + ## messaging-layer `RequestId`s; the mapping is held in + ## `self.channelReqs` until every segment is final. if payload.len == 0: return err("empty payload") let channelReqId = RequestId.new(self.rng) - self.requestIds[channelReqId] = @[] - let persistenceReqType = if ephemeral: MessagePersistence.Ephemeral else: MessagePersistence.Persistent + var segmentCount = 0 + var enqueued: seq[seq[byte]] for segmentBytes in self.segmentation.performSegmentation(payload): ## Segments arrive already encoded; the segmentation module owns ## the wire format so SDS only ever sees opaque bytes. @@ -307,14 +317,13 @@ proc send*( self.channelId, self.senderId, segmentBytes ).valueOr: return err("SDS wrap failed: " & error) - self.pendingMessagingRequests.add( - PendingMessagingRequest( - channelReqId: channelReqId, - messagingReqId: none(RequestId), - persistenceReqType: persistenceReqType, - segmentSendState: SegmentSendState.AwaitingRateLimit, - ) - ) + enqueued.add(sdsBytes) + segmentCount.inc() + + self.channelReqs[channelReqId] = + ChannelReqState.init(persistenceReqType, segmentCount) + + for sdsBytes in enqueued: self.rateLimit.enqueueToSend(sdsBytes) return ok(channelReqId) @@ -402,8 +411,7 @@ proc new*( segmentation: SegmentationHandler.new(segConfig), sdsHandler: SdsHandler.new(sdsConfig, senderId), rateLimit: RateLimitManager.new(rateConfig, channelId, brokerCtx), - requestIds: initTable[RequestId, seq[RequestId]](), - pendingMessagingRequests: @[], + channelReqs: initOrderedTable[RequestId, ChannelReqState](), brokerCtx: brokerCtx, ) @@ -411,8 +419,8 @@ proc new*( ## listeners on `chn.brokerCtx`, filtered to traffic addressed to ## this channel. Keeping the listeners (and the handler procs they ## call) inside the channel lets `onReadyToSend` / - ## `onMessageReceived` / `onMessageSent` / `onMessageError` stay - ## private — the manager doesn't need to know about them. + ## `onMessageReceived` / `onMessageFinal` stay private — the + ## manager doesn't need to know about them. discard ReadyToSendEvent.listen( chn.brokerCtx, proc(evt: ReadyToSendEvent): Future[void] {.async: (raises: []).} = @@ -441,13 +449,13 @@ proc new*( discard MessageSentEvent.listen( chn.brokerCtx, proc(evt: MessageSentEvent): Future[void] {.async: (raises: []).} = - chn.onMessageSent(evt.requestId), + chn.onMessageFinal(evt.requestId, MessagingOutcome.Sent), ) discard MessageErrorEvent.listen( chn.brokerCtx, proc(evt: MessageErrorEvent): Future[void] {.async: (raises: []).} = - chn.onMessageError(evt.requestId), + chn.onMessageFinal(evt.requestId, MessagingOutcome.Failed), ) return chn diff --git a/tests/channels/test_reliable_channel_send_receive.nim b/tests/channels/test_reliable_channel_send_receive.nim index 2f49182a2..5ea300eb3 100644 --- a/tests/channels/test_reliable_channel_send_receive.nim +++ b/tests/channels/test_reliable_channel_send_receive.nim @@ -315,3 +315,102 @@ suite "Reliable Channel - send state machine": ## `messagingReqId`s from a fake `SendHandler`, finalise some, and ## assert prune only fires once every sibling is final. skip() + + asyncTest "sibling MessageSentEvent during sendHandler await does not corrupt state": + ## Regression test for the prune-during-await race + ## (PR #3914 review comment r3324891059). Locks in that a sibling + ## `MessageSentEvent` firing while `onReadyToSend` is paused at an + ## `await` does not lose the second `channelReqId`'s terminal + ## event. + const + channelId = ChannelId("sm-race-channel") + contentTopic = ContentTopic("/reliable-channel/test/sm-race") + + var manager: ReliableChannelManager + var brokerCtx: BrokerContext + lockNewGlobalBrokerContext: + brokerCtx = globalBrokerContext() + manager = (await ReliableChannelManager.new(createApiNodeConf())).expect( + "Failed to create manager" + ) + + setNoopEncryption() + + var msgReqIds: seq[RequestId] + var sendsReturned = 0 + let fakeSend: SendHandler = proc( + env: MessageEnvelope + ): Future[Result[RequestId, string]] {.async: (raises: [CatchableError]), gcsafe.} = + ## Call 2 fires the first segment's terminal event and then + ## yields, so the listener task runs while the second segment + ## is still mid-`await` in `onReadyToSend` — the exact race + ## window the regression test targets. + let id = RequestId("race-msg-req-" & $(msgReqIds.len + 1)) + msgReqIds.add(id) + if msgReqIds.len == 2: + waku_message_events.MessageSentEvent.emit( + brokerCtx, + waku_message_events.MessageSentEvent(requestId: msgReqIds[0], messageHash: ""), + ) + await sleepAsync(50.milliseconds) + sendsReturned.inc() + return ok(id) + + discard manager + .createReliableChannel( + channelId, contentTopic, SdsParticipantID("local"), sendHandler = fakeSend + ) + .expect("createReliableChannel") + + var finalisedReqIds: seq[RequestId] + let bothFinalised = newFuture[void]("both-finalised") + discard ChannelMessageSentEvent + .listen( + brokerCtx, + proc(evt: ChannelMessageSentEvent) {.async: (raises: []).} = + if evt.channelId == channelId: + finalisedReqIds.add(evt.requestId) + if finalisedReqIds.len == 2 and not bothFinalised.finished(): + bothFinalised.complete() + , + ) + .expect("listen ChannelMessageSentEvent") + + let channelReqId1 = manager.send(channelId, "first".toBytes()).expect("send 1") + + ## Drain the first segment fully before queueing the second, so + ## the rate-limit FIFO between sibling sends isn't itself under + ## test here. + let firstDispatched = Moment.now() + 1.seconds + while Moment.now() < firstDispatched and msgReqIds.len < 1: + await sleepAsync(5.milliseconds) + check msgReqIds.len == 1 + + let channelReqId2 = manager.send(channelId, "second".toBytes()).expect("send 2") + + ## Wait until `fakeSend(m2)` has fully returned and yield once + ## more so `onReadyToSend`'s post-await continuation gets a chance + ## to register `id2` in `inflightMessagingIds` before we emit its + ## terminal event. + let dispatchDeadline = Moment.now() + 1.seconds + while Moment.now() < dispatchDeadline and sendsReturned < 2: + await sleepAsync(5.milliseconds) + check sendsReturned == 2 + await sleepAsync(50.milliseconds) + + ## Finalise the second segment from the outside. If the race + ## corrupted state, `channelReqId2`'s entry would never reach + ## `inflightMessagingIds` and this event would silently miss. + waku_message_events.MessageSentEvent.emit( + brokerCtx, + waku_message_events.MessageSentEvent(requestId: msgReqIds[1], messageHash: ""), + ) + + let arrived = await bothFinalised.withTimeout(2.seconds) + check arrived + if arrived: + check finalisedReqIds.len == 2 + check channelReqId1 in finalisedReqIds + check channelReqId2 in finalisedReqIds + + await manager.stop()