From 887a0ac59abc379b374995f8412c324735376915 Mon Sep 17 00:00:00 2001 From: Ivan FB Date: Sat, 30 May 2026 01:25:03 +0200 Subject: [PATCH] enhance reliable channel segment states --- channels/reliable_channel.nim | 302 +++++++++--------- .../test_reliable_channel_send_receive.nim | 134 ++++++++ 2 files changed, 286 insertions(+), 150 deletions(-) diff --git a/channels/reliable_channel.nim b/channels/reliable_channel.nim index c3fbe5d77..b68cfbfab 100644 --- a/channels/reliable_channel.nim +++ b/channels/reliable_channel.nim @@ -13,7 +13,7 @@ ## ## See: https://lip.logos.co/messaging/raw/reliable-channel-api.html -import std/[options, sets, tables] +import std/[options, tables] import results import chronos import bearssl/rand @@ -55,32 +55,35 @@ type Persistent Ephemeral - SegmentSendState {.pure.} = enum - ## Lifecycle of a single segment as tracked by the channel. The - ## messaging layer has its own richer `DeliveryState` (retries, - ## propagated-vs-validated); here we only model what's needed to - ## decide when a `channelReqId` is fully accounted for. - AwaitingRateLimit ## Pushed by `send`; not yet released by rate_limit_manager. - InFlight - ## Released by rate_limit_manager and handed to delivery_service; - ## `messagingReqId` is now set. - Confirmed ## `MessageSentEvent` arrived for `messagingReqId`. - Failed - ## `MessageErrorEvent` arrived for `messagingReqId`, or the local - ## delivery-task construction failed before any id was reachable. - - PendingMessagingRequest = object - ## One entry per segment (i.e. per messaging-layer request). The - ## relative order of `AwaitingRateLimit` entries must match the - ## order in which `rate_limit_manager` re-emits messages, which is - ## FIFO with `send()`. - channelReqId*: RequestId - ## The channel-layer parent id returned to the caller of `send()` in channel layer. - ## One channel request maps to N pending messaging requests. - messagingReqId*: Option[RequestId] - ## Per-segment messaging layer id. 
`none` until `onReadyToSend` assigns it. + ChannelReqState = object + ## Per channel-level request, tracks how many of its segments are + ## still queued, in flight, or have terminated. The channel-level + ## final event fires when `confirmedCount + failedCount` reaches + ## `totalExpectedSegments` AND no segments are still awaiting dispatch + ## or in flight. persistenceReqType: MessagePersistence - segmentSendState*: SegmentSendState + totalExpectedSegments: int + ## Total segments produced by `segmentation.performSegmentation` + ## for this `channelReqId`. Set once in `send`, never mutated. + awaitingDispatch: int + ## Segments enqueued in `rate_limit_manager` but not yet claimed + ## by `onReadyToSend`. Decremented when `onReadyToSend` picks a + ## message and assigns it to this `channelReqId`. + inflightMessagingIds: seq[RequestId] + ## Messaging-layer ids minted by the send handler that have not + ## yet produced a final event. Removed on `MessageSentEvent` / `MessageErrorEvent`. + confirmedCount: int + failedCount: int + + ChannelReqs = OrderedTable[RequestId, ChannelReqState] + ## Key: channelReqId (the parent id returned by channel `send`). Value: + ## per-request state, see `ChannelReqState`. + ## + ## `OrderedTable` preserves insertion order, which matches the FIFO + ## order `rate_limit_manager` re-emits messages in: `onReadyToSend` + ## routes each segment to the first entry with `awaitingDispatch > 0`, + ## and that scan is correct precisely because the outer iteration + ## order matches the order `send` pushed entries. ReliableChannel* = ref object ## Spec-defined public type. 
Fields are private so callers cannot @@ -95,13 +98,23 @@ type sdsHandler: SdsHandler rateLimit: RateLimitManager - requestIds: Table[RequestId, seq[RequestId]] - pendingMessagingRequests: seq[PendingMessagingRequest] - ## Entries are kept until the matching segment reaches a final - ## state (`Confirmed` or `Failed`); a whole channel request is - ## then pruned in one pass once all its segments are final. + channelReqs: ChannelReqs brokerCtx: BrokerContext +func init( + T: type ChannelReqState, + persistenceReqType: MessagePersistence, + totalExpectedSegments: int, +): T = + return ChannelReqState( + persistenceReqType: persistenceReqType, + totalExpectedSegments: totalExpectedSegments, + awaitingDispatch: totalExpectedSegments, + inflightMessagingIds: @[], + confirmedCount: 0, + failedCount: 0, + ) + func getChannelId*(self: ReliableChannel): ChannelId {.inline.} = self.channelId @@ -111,70 +124,86 @@ func getContentTopic*(self: ReliableChannel): ContentTopic {.inline.} = func getSenderId*(self: ReliableChannel): SdsParticipantID {.inline.} = self.senderId -func isFinal(state: SegmentSendState): bool {.inline.} = - return state in {SegmentSendState.Confirmed, SegmentSendState.Failed} +proc tryFinalizeChannelReq(self: ReliableChannel, channelReqId: RequestId) = + let state = self.channelReqs.getOrDefault(channelReqId) + if state.totalExpectedSegments == 0: + ## Either already finalized (and removed) or never inserted. + return + if state.awaitingDispatch != 0 or state.inflightMessagingIds.len != 0: + return + if state.confirmedCount + state.failedCount < state.totalExpectedSegments: + return -proc pruneCompletedChannelReqs(self: ReliableChannel) = - ## Drop every `pendingMessagingRequests` entry whose `channelReqId` - ## has all of its segments in a final state. A single failing - ## segment doesn't trigger a drop on its own — we wait until siblings - ## are also accounted for, so the channel-level outcome is decided - ## from a complete picture. 
For each fully-final `channelReqId`, emit - ## the channel-level final event before the entries are dropped: - ## `ChannelMessageSentEvent` if every sibling Confirmed, - ## `ChannelMessageErrorEvent` if any sibling Failed. - var hasPending = initHashSet[RequestId]() - var anyFailed = initHashSet[RequestId]() - for entry in self.pendingMessagingRequests: - if not entry.segmentSendState.isFinal(): - hasPending.incl(entry.channelReqId) - elif entry.segmentSendState == SegmentSendState.Failed: - anyFailed.incl(entry.channelReqId) + self.channelReqs.del(channelReqId) - var emitted = initHashSet[RequestId]() - for entry in self.pendingMessagingRequests: - if entry.channelReqId in hasPending or entry.channelReqId in emitted: + if state.failedCount > 0: + ChannelMessageErrorEvent.emit( + self.brokerCtx, + ChannelMessageErrorEvent( + channelId: self.channelId, + requestId: channelReqId, + error: "one or more segments failed", + ), + ) + else: + ChannelMessageSentEvent.emit( + self.brokerCtx, + ChannelMessageSentEvent(channelId: self.channelId, requestId: channelReqId), + ) + +type ClaimedSegment = object + channelReqId: RequestId + isEphemeral: bool + +proc claimAwaitingChannelReq(self: ReliableChannel): Option[ClaimedSegment] = + for channelReqId, state in self.channelReqs.mpairs: + if state.awaitingDispatch > 0: + state.awaitingDispatch.dec() + return some( + ClaimedSegment( + channelReqId: channelReqId, + isEphemeral: state.persistenceReqType == MessagePersistence.Ephemeral, + ) + ) + return none(ClaimedSegment) + +type MessagingOutcome {.pure.} = enum + Sent + Failed + +proc onMessageFinal( + self: ReliableChannel, messagingReqId: RequestId, outcome: MessagingOutcome +) = + for channelReqId, state in self.channelReqs.mpairs: + let idx = state.inflightMessagingIds.find(messagingReqId) + if idx < 0: continue - emitted.incl(entry.channelReqId) - if entry.channelReqId in anyFailed: - ChannelMessageErrorEvent.emit( - self.brokerCtx, - ChannelMessageErrorEvent( - channelId: 
self.channelId, - requestId: entry.channelReqId, - error: "one or more segments failed", - ), - ) - else: - ChannelMessageSentEvent.emit( - self.brokerCtx, - ChannelMessageSentEvent( - channelId: self.channelId, requestId: entry.channelReqId - ), - ) + state.inflightMessagingIds.del(idx) + case outcome + of MessagingOutcome.Sent: + state.confirmedCount.inc() + of MessagingOutcome.Failed: + state.failedCount.inc() + self.tryFinalizeChannelReq(channelReqId) + return - self.pendingMessagingRequests.keepItIf(it.channelReqId in hasPending) +proc markSegmentFailed(self: ReliableChannel, channelReqId: RequestId) = + try: + self.channelReqs[channelReqId].failedCount.inc() + except KeyError as e: + error "unreachable: channelReqId not found in markSegmentFailed", + channelReqId = $channelReqId, error = e.msg + return + self.tryFinalizeChannelReq(channelReqId) -proc onMessageSent(self: ReliableChannel, messagingReqId: RequestId) = - ## Invoked from this channel's `MessageSentEvent` listener. Flips - ## the matching `InFlight` segment to `Confirmed` and prunes. The - ## listener routes every event through here; entries that don't - ## belong to this channel simply don't match and are no-ops. - self.pendingMessagingRequests.applyItIf( - it.segmentSendState == SegmentSendState.InFlight and - it.messagingReqId == some(messagingReqId) - ): - it.segmentSendState = SegmentSendState.Confirmed - self.pruneCompletedChannelReqs() - -proc onMessageError(self: ReliableChannel, messagingReqId: RequestId) = - ## Symmetric to `onMessageSent` but for `MessageErrorEvent`. 
- self.pendingMessagingRequests.applyItIf( - it.segmentSendState == SegmentSendState.InFlight and - it.messagingReqId == some(messagingReqId) - ): - it.segmentSendState = SegmentSendState.Failed - self.pruneCompletedChannelReqs() +proc markSegmentInflight( + self: ReliableChannel, channelReqId: RequestId, messagingReqId: RequestId +) = + try: + self.channelReqs[channelReqId].inflightMessagingIds.add(messagingReqId) + except KeyError as e: + error "unreachable: channelReqId not found in markSegmentInflight", + channelReqId = $channelReqId, error = e.msg proc onReadyToSend( self: ReliableChannel, readyToSendEvent: ReadyToSendEvent @@ -184,30 +213,22 @@ proc onReadyToSend( ## blobs (already-encoded SDS messages): ## ## ... -> rate_limit_manager -> [encryption] -> dispatch - var idx = 0 + ## + ## For each `m`, the next channelReqId still queued in rate-limit + ## claims the slot (FIFO across sibling sends). The channelReqId is + ## captured up front and used as a stable key for every later state + ## update — no positional index is ever held across an `await`, so + ## sibling events mutating other entries (or even this one's + ## `inflightMessagingIds`) cannot corrupt this fiber's view. for m in readyToSendEvent.msgs: - ## The first `AwaitingRateLimit` entry in push order is the one - ## this `m` belongs to: `send()` adds one entry per segment, and - ## `rate_limit_manager` re-emits them in the same FIFO order, so - ## the two sequences advance in lockstep. Earlier entries may - ## already be `InFlight` / `Confirmed` / `Failed` because they - ## live on until every sibling of their `channelReqId` is final, - ## so we walk past those to find the next one that was awaiting for this batch. 
- while idx < self.pendingMessagingRequests.len and - self.pendingMessagingRequests[idx].segmentSendState != - SegmentSendState.AwaitingRateLimit - : - idx.inc() - if idx >= self.pendingMessagingRequests.len: + let claimed = self.claimAwaitingChannelReq().valueOr: ## rate_limit_manager emitted more messages than we have pending — - ## should not happen given `send` pushes one entry per enqueued - ## SDS payload. Drop silently rather than corrupt state. + ## should not happen given `send` increments `awaitingDispatch` + ## once per enqueued SDS payload. Drop silently rather than + ## corrupt state. break - - let channelReqId = self.pendingMessagingRequests[idx].channelReqId - let isEphemeral = - self.pendingMessagingRequests[idx].persistenceReqType == - MessagePersistence.Ephemeral + let channelReqId = claimed.channelReqId + let isEphemeral = claimed.isEphemeral ## TODO: revisit which fields of the SDS message must be encrypted. ## Encrypting the whole encoded blob forces every receiver to attempt @@ -223,15 +244,7 @@ proc onReadyToSend( ), ) ## Encryption failed *before* we could hand the segment to the - ## delivery layer — no `messagingReqId` was minted and no - ## `DeliveryTask` was queued on `sendService`. The delivery - ## layer will therefore never emit a `MessageSentEvent` / - ## `MessageErrorEvent` for this segment, so `onMessageError` - ## won't fire either. Advance the state machine inline so the - ## parent `channelReqId` can still be pruned once its siblings - ## are also final. 
- self.pendingMessagingRequests[idx].segmentSendState = SegmentSendState.Failed - idx.inc() + self.markSegmentFailed(channelReqId) continue let wireBytes = seq[byte](encrypted) @@ -261,16 +274,10 @@ proc onReadyToSend( requestId: channelReqId, messageHash: "", error: "waku send failed: " & error ), ) - self.pendingMessagingRequests[idx].segmentSendState = SegmentSendState.Failed - idx.inc() + self.markSegmentFailed(channelReqId) continue - self.pendingMessagingRequests[idx].messagingReqId = some(messagingReqId) - self.pendingMessagingRequests[idx].segmentSendState = SegmentSendState.InFlight - self.requestIds.mgetOrPut(channelReqId, @[]).add(messagingReqId) - idx.inc() - - self.pruneCompletedChannelReqs() + self.markSegmentInflight(channelReqId, messagingReqId) proc send*( self: ReliableChannel, payload: seq[byte], ephemeral: bool = false @@ -283,23 +290,20 @@ proc send*( ## ## `rate_limit_manager.enqueueToSend` emits a `ReadyToSendEvent` with ## the SDS messages cleared for transmission; the channel's listener - ## then runs the final stage (encryption -> dispatch). The - ## `persistenceReqType` is carried alongside each segment in - ## `pendingMessagingRequests` and stamped onto the eventual - ## `MessageEnvelope`. + ## then runs the final stage (encryption -> dispatch). ## ## The returned `RequestId` is the channel-level parent of one-or-more - ## messaging-layer `RequestId`s; the mapping is recorded in - ## `self.requestIds`. + ## messaging-layer `RequestId`s; the mapping is held in + ## `self.channelReqs` until every segment is final. 
if payload.len == 0: return err("empty payload") let channelReqId = RequestId.new(self.rng) - self.requestIds[channelReqId] = @[] - let persistenceReqType = if ephemeral: MessagePersistence.Ephemeral else: MessagePersistence.Persistent + var segmentCount = 0 + var enqueued: seq[seq[byte]] for segmentBytes in self.segmentation.performSegmentation(payload): ## Segments arrive already encoded; the segmentation module owns ## the wire format so SDS only ever sees opaque bytes. @@ -307,14 +311,13 @@ proc send*( self.channelId, self.senderId, segmentBytes ).valueOr: return err("SDS wrap failed: " & error) - self.pendingMessagingRequests.add( - PendingMessagingRequest( - channelReqId: channelReqId, - messagingReqId: none(RequestId), - persistenceReqType: persistenceReqType, - segmentSendState: SegmentSendState.AwaitingRateLimit, - ) - ) + enqueued.add(sdsBytes) + segmentCount.inc() + + self.channelReqs[channelReqId] = + ChannelReqState.init(persistenceReqType, segmentCount) + + for sdsBytes in enqueued: self.rateLimit.enqueueToSend(sdsBytes) return ok(channelReqId) @@ -402,8 +405,7 @@ proc new*( segmentation: SegmentationHandler.new(segConfig), sdsHandler: SdsHandler.new(sdsConfig, senderId), rateLimit: RateLimitManager.new(rateConfig, channelId, brokerCtx), - requestIds: initTable[RequestId, seq[RequestId]](), - pendingMessagingRequests: @[], + channelReqs: initOrderedTable[RequestId, ChannelReqState](), brokerCtx: brokerCtx, ) @@ -411,8 +413,8 @@ proc new*( ## listeners on `chn.brokerCtx`, filtered to traffic addressed to ## this channel. Keeping the listeners (and the handler procs they ## call) inside the channel lets `onReadyToSend` / - ## `onMessageReceived` / `onMessageSent` / `onMessageError` stay - ## private — the manager doesn't need to know about them. + ## `onMessageReceived` / `onMessageFinal` stay private — the + ## manager doesn't need to know about them. 
discard ReadyToSendEvent.listen( chn.brokerCtx, proc(evt: ReadyToSendEvent): Future[void] {.async: (raises: []).} = @@ -441,13 +443,13 @@ proc new*( discard MessageSentEvent.listen( chn.brokerCtx, proc(evt: MessageSentEvent): Future[void] {.async: (raises: []).} = - chn.onMessageSent(evt.requestId), + chn.onMessageFinal(evt.requestId, MessagingOutcome.Sent), ) discard MessageErrorEvent.listen( chn.brokerCtx, proc(evt: MessageErrorEvent): Future[void] {.async: (raises: []).} = - chn.onMessageError(evt.requestId), + chn.onMessageFinal(evt.requestId, MessagingOutcome.Failed), ) return chn diff --git a/tests/channels/test_reliable_channel_send_receive.nim b/tests/channels/test_reliable_channel_send_receive.nim index 2f49182a2..04ad539d0 100644 --- a/tests/channels/test_reliable_channel_send_receive.nim +++ b/tests/channels/test_reliable_channel_send_receive.nim @@ -315,3 +315,137 @@ suite "Reliable Channel - send state machine": ## `messagingReqId`s from a fake `SendHandler`, finalise some, and ## assert prune only fires once every sibling is final. skip() + + asyncTest "sibling MessageSentEvent during sendHandler await does not corrupt state": + ## Regression test for the prune-during-await race + ## (PR #3914 review comment r3324891059). The historical model + ## tracked segments in a flat `seq[PendingMessagingRequest]`, + ## so a sibling `MessageSentEvent` arriving while `onReadyToSend` + ## was paused at `await self.sendHandler(...)` would call + ## `keepItIf` on that seq, shift the entries under the live `idx` + ## walk, and either misassign the in-flight `messagingReqId` to + ## the wrong row or crash on out-of-bounds. The current model + ## keys per-request state by `channelReqId` in an `OrderedTable`, + ## so every lookup is by key (not position) and stays valid + ## across awaits. This test locks in the contract: both + ## `channelReqId`s must still produce exactly one terminal + ## `ChannelMessageSentEvent`. 
+ const + channelId = ChannelId("sm-race-channel") + contentTopic = ContentTopic("/reliable-channel/test/sm-race") + + var manager: ReliableChannelManager + var brokerCtx: BrokerContext + lockNewGlobalBrokerContext: + brokerCtx = globalBrokerContext() + manager = (await ReliableChannelManager.new(createApiNodeConf())).expect( + "Failed to create manager" + ) + + setNoopEncryption() + + var msgReqIds: seq[RequestId] + var sendsReturned = 0 + let fakeSend: SendHandler = proc( + env: MessageEnvelope + ): Future[Result[RequestId, string]] {.async: (raises: [CatchableError]), gcsafe.} = + ## Call 1: return immediately so the first segment's + ## `messagingReqId` lands in `inflightMessagingIds`. Call 2: + ## emit the final `MessageSentEvent` for the first segment, + ## then yield via `sleepAsync` so the listener task runs while + ## the second segment is still mid-`await` in `onReadyToSend`. + ## Under the old positional-index model this is exactly the + ## window that corrupted state; under the table-keyed model + ## the listener mutates a different key and leaves our entry + ## untouched. 
+ let id = RequestId("race-msg-req-" & $(msgReqIds.len + 1)) + msgReqIds.add(id) + if msgReqIds.len == 2: + waku_message_events.MessageSentEvent.emit( + brokerCtx, + waku_message_events.MessageSentEvent(requestId: msgReqIds[0], messageHash: ""), + ) + await sleepAsync(50.milliseconds) + sendsReturned.inc() + return ok(id) + discard manager + .createReliableChannel( + channelId, contentTopic, SdsParticipantID("local"), sendHandler = fakeSend + ) + .expect("createReliableChannel") + var finalisedReqIds: seq[RequestId] + let bothFinalised = newFuture[void]("both-finalised") + discard ChannelMessageSentEvent + .listen( + brokerCtx, + proc(evt: ChannelMessageSentEvent) {.async: (raises: []).} = + if evt.channelId == channelId: + finalisedReqIds.add(evt.requestId) + if finalisedReqIds.len == 2 and not bothFinalised.finished(): + bothFinalised.complete() + , + ) + .expect("listen ChannelMessageSentEvent") + let channelReqId1 = manager.send(channelId, "first".toBytes()).expect("send 1") + ## Let the first segment fully traverse the pipeline so that + ## `channelReq1` has its `messagingReqId` recorded in + ## `inflightMessagingIds` before send2 inserts `channelReq2`. Without + ## this, `channelReq1` could still have `awaitingDispatch > 0` and + ## claim m2 — that is a different, pre-existing concurrency assumption + ## (rate-limit FIFO between sibling sends) and not the bug we are testing. + let firstDispatched = Moment.now() + 1.seconds + while Moment.now() < firstDispatched and msgReqIds.len < 1: + await sleepAsync(5.milliseconds) + check msgReqIds.len == 1 + let channelReqId2 = manager.send(channelId, "second".toBytes()).expect("send 2") + ## Wait until the second `fakeSend` has fully returned (not just + ## entered). `sendsReturned` ticks after the sleep inside the if + ## branch, so once it reaches 2 we know `fakeSend(m2)` has handed + ## control back to `onReadyToSend`. 
The extra `sleepAsync` below + ## then yields one more time so the chronos scheduler can run + ## `onReadyToSend`'s post-await continuation — which is where + ## `markSegmentInflight` adds `id2` to `inflightMessagingIds`. + ## Without that final yield, the emit below would race the + ## continuation and `MessageSentEvent(id2)` would find no + ## matching entry in `inflightMessagingIds`. + let dispatchDeadline = Moment.now() + 1.seconds + while Moment.now() < dispatchDeadline and sendsReturned < 2: + await sleepAsync(5.milliseconds) + check sendsReturned == 2 + await sleepAsync(50.milliseconds) + ## Now finalise the second segment from the outside; its final + ## event must drive `channelReqId2` to `ChannelMessageSentEvent`. + ## If the race corrupted state during the await, the second + ## `messagingReqId` would never have been written to the right + ## entry and this event would silently never fire. + waku_message_events.MessageSentEvent.emit( + brokerCtx, + waku_message_events.MessageSentEvent(requestId: msgReqIds[1], messageHash: ""), + ) + ## Under the table-keyed model both events fire because no fiber + ## ever holds a positional reference: `onMessageFinal(id1)` finds + ## `id1` in `channelReq1`'s in-flight ids, bumps `confirmedCount`, + ## and finalizes (deleting that table key); meanwhile `onReadyToSend` + ## is still processing `channelReq2` — a different key — so its + ## post-await write to `inflightMessagingIds` lands on the correct, + ## untouched entry. The external `MessageSentEvent(id2)` above then + ## resolves `channelReq2`. Under the old `seq[PendingMessagingRequest]` + ## model the sibling listener's `keepItIf` would have shifted the + ## seq under the live `idx` walk and either crashed with + ## `IndexDefect` or silently lost `channelReqId2`'s terminal event. 
+ let arrived = await bothFinalised.withTimeout(2.seconds) + check arrived + if arrived: + check finalisedReqIds.len == 2 + check channelReqId1 in finalisedReqIds + check channelReqId2 in finalisedReqIds + + await manager.stop()