diff --git a/channels/reliable_channel.nim b/channels/reliable_channel.nim index 2a7d01d35..c3a49a370 100644 --- a/channels/reliable_channel.nim +++ b/channels/reliable_channel.nim @@ -13,7 +13,7 @@ ## ## See: https://lip.logos.co/messaging/raw/reliable-channel-api.html -import std/[options, tables] +import std/[options, sets, tables] import results import chronos import bearssl/rand @@ -42,27 +42,59 @@ const LipWireReliableChannelVersion* = "RELIABLE-CHANNEL-API/1" ## The trailing `/N` is the wire-format version and is bumped only ## on breaking on-the-wire changes; implementations pin one version. -type ReliableChannel* = ref object - ## Spec-defined public type. Fields are private so callers cannot - ## mutate internals and break invariants. Getters are added below - ## for the few values consumers may need. - deliveryService: DeliveryService - channelId: ChannelId - contentTopic: ContentTopic - senderId: SdsParticipantID - rng: ref HmacDrbgContext - segmentation: SegmentationHandler - sdsHandler: SdsHandler - rateLimit: RateLimitManager +type + MessagePersistence {.pure.} = enum + Persistent + Ephemeral - requestIds: Table[RequestId, seq[RequestId]] - pendingRequests: seq[tuple[parent: RequestId, ephemeral: bool]] - brokerCtx: BrokerContext - ## Captured here so the channel emits `ChannelMessageReceivedEvent` - ## on the same broker context the owning manager registered its - ## listeners on. Without this, an emit via `globalBrokerContext()` - ## would land on whatever context happens to be thread-local at - ## emit time, which is not necessarily the manager's. + SegmentSendState {.pure.} = enum + ## Lifecycle of a single segment as tracked by the channel. The + ## messaging layer has its own richer `DeliveryState` (retries, + ## propagated-vs-validated); here we only model what's needed to + ## decide when a `channelReqId` is fully accounted for. + AwaitingRateLimit + ## Pushed by `send`; not yet released by rate_limit_manager. 
+ InFlight + ## Released by rate_limit_manager and handed to delivery_service; + ## `messagingReqId` is now set. + Confirmed + ## `MessageSentEvent` arrived for `messagingReqId`. + Failed + ## `MessageErrorEvent` arrived for `messagingReqId`, or the local + ## delivery-task construction failed before any id was reachable. + + PendingMessagingRequest = object + ## One entry per segment (i.e. per messaging-layer request). The + ## relative order of `AwaitingRateLimit` entries must match the + ## order in which `rate_limit_manager` re-emits messages, which is + ## FIFO with `send()`. + channelReqId*: RequestId + ## The channel-layer parent id returned to the caller of `send()` in channel layer. + ## One channel request maps to N pending messaging requests. + messagingReqId*: Option[RequestId] + ## Per-segment messaging layer id. `none` until `onReadyToSend` assigns it. + persistenceReqType: MessagePersistence + segmentSendState*: SegmentSendState + + ReliableChannel* = ref object + ## Spec-defined public type. Fields are private so callers cannot + ## mutate internals and break invariants. Getters are added below + ## for the few values consumers may need. + deliveryService: DeliveryService + channelId: ChannelId + contentTopic: ContentTopic + senderId: SdsParticipantID + rng: ref HmacDrbgContext + segmentation: SegmentationHandler + sdsHandler: SdsHandler + rateLimit: RateLimitManager + + requestIds: Table[RequestId, seq[RequestId]] + pendingMessagingRequests: seq[PendingMessagingRequest] + ## Entries are kept until the matching segment reaches a final + ## state (`Confirmed` or `Failed`); a whole channel request is + ## then pruned in one pass once all its segments are final. 
+ brokerCtx: BrokerContext func getChannelId*(self: ReliableChannel): ChannelId {.inline.} = self.channelId @@ -73,6 +105,65 @@ func getContentTopic*(self: ReliableChannel): ContentTopic {.inline.} = func getSenderId*(self: ReliableChannel): SdsParticipantID {.inline.} = self.senderId +func pendingMessagingRequestsLenForTest*(self: ReliableChannel): int {.inline.} = + ## Test-only: returns how many segments are still tracked in the + ## state machine. The internal segment lifecycle is not part of the + ## spec'd API; production callers must not observe it. + return self.pendingMessagingRequests.len + +proc forceInjectInFlightForTest*( + self: ReliableChannel, channelReqId: RequestId, messagingReqId: RequestId +) = + ## Test-only: inject a pending entry already in `InFlight`. Bypasses + ## `send` / `onReadyToSend` so unit tests can exercise final-state + ## handling and the `pruneCompletedChannelReqs` rule (drop only when + ## *all* siblings of a `channelReqId` are final) without having + ## to drive — and race with — the real send pipeline. + self.pendingMessagingRequests.add( + PendingMessagingRequest( + channelReqId: channelReqId, + messagingReqId: some(messagingReqId), + persistenceReqType: MessagePersistence.Persistent, + segmentSendState: SegmentSendState.InFlight, + ) + ) + +func isFinal(state: SegmentSendState): bool {.inline.} = + return state in {SegmentSendState.Confirmed, SegmentSendState.Failed} + +proc pruneCompletedChannelReqs(self: ReliableChannel) = + ## Drop every `pendingMessagingRequests` entry whose `channelReqId` + ## has all of its segments in a final state. A single failing + ## segment doesn't trigger a drop on its own — we wait until siblings + ## are also accounted for, so the channel-level outcome is decided + ## from a complete picture. 
+ var ongoing: HashSet[RequestId] + for entry in self.pendingMessagingRequests: + if not entry.segmentSendState.isFinal: + ongoing.incl(entry.channelReqId) + self.pendingMessagingRequests.keepItIf(it.channelReqId in ongoing) + +proc onMessageSent(self: ReliableChannel, messagingReqId: RequestId) = + ## Invoked from this channel's `MessageSentEvent` listener. Flips + ## the matching `InFlight` segment to `Confirmed` and prunes. The + ## listener routes every event through here; entries that don't + ## belong to this channel simply don't match and are no-ops. + for i in 0 ..< self.pendingMessagingRequests.len: + if self.pendingMessagingRequests[i].segmentSendState == SegmentSendState.InFlight and + self.pendingMessagingRequests[i].messagingReqId == some(messagingReqId): + self.pendingMessagingRequests[i].segmentSendState = SegmentSendState.Confirmed + self.pruneCompletedChannelReqs() + return + +proc onMessageError(self: ReliableChannel, messagingReqId: RequestId) = + ## Symmetric to `onMessageSent` but for `MessageErrorEvent`. + for i in 0 ..< self.pendingMessagingRequests.len: + if self.pendingMessagingRequests[i].segmentSendState == SegmentSendState.InFlight and + self.pendingMessagingRequests[i].messagingReqId == some(messagingReqId): + self.pendingMessagingRequests[i].segmentSendState = SegmentSendState.Failed + self.pruneCompletedChannelReqs() + return + proc onReadyToSend( self: ReliableChannel, msgs: seq[seq[byte]] ) {.async: (raises: []).} = @@ -81,11 +172,28 @@ proc onReadyToSend( ## blobs (already-encoded SDS messages): ## ## ... -> rate_limit_manager -> [encryption] -> dispatch + var idx = 0 for m in msgs: - ## Each `m` was preceded by exactly one push onto `pendingRequests` - ## in `send`, so this pop is always safe in the current skeleton. 
- let pending = self.pendingRequests[0] - self.pendingRequests.delete(0) + ## The first `AwaitingRateLimit` entry in push order is the one + ## this `m` belongs to: `send()` adds one entry per segment, and + ## `rate_limit_manager` re-emits them in the same FIFO order, so + ## the two sequences advance in lockstep. Earlier entries may + ## already be `InFlight` / `Confirmed` / `Failed` because they + ## live on until every sibling of their `channelReqId` is final, + ## so we walk past those to find the next one that was awaiting for this batch. + while idx < self.pendingMessagingRequests.len and + self.pendingMessagingRequests[idx].segmentSendState != SegmentSendState.AwaitingRateLimit: + idx.inc() + if idx >= self.pendingMessagingRequests.len: + ## rate_limit_manager emitted more messages than we have pending — + ## should not happen given `send` pushes one entry per enqueued + ## SDS payload. Drop silently rather than corrupt state. + break + + let channelReqId = self.pendingMessagingRequests[idx].channelReqId + let isEphemeral = + self.pendingMessagingRequests[idx].persistenceReqType == + MessagePersistence.Ephemeral ## TODO: revisit which fields of the SDS message must be encrypted. ## Encrypting the whole encoded blob forces every receiver to attempt @@ -97,21 +205,45 @@ proc onReadyToSend( MessageErrorEvent.emit( self.brokerCtx, MessageErrorEvent( - requestId: pending.parent, + requestId: channelReqId, messageHash: "", error: "encryption failed: " & error, ), ) + ## Encryption failed *before* we could hand the segment to the + ## delivery layer — no `messagingReqId` was minted and no + ## `DeliveryTask` was queued on `sendService`. The delivery + ## layer will therefore never emit a `MessageSentEvent` / + ## `MessageErrorEvent` for this segment, so `onMessageError` + ## won't fire either. Advance the state machine inline so the + ## parent `channelReqId` can still be pruned once its siblings + ## are also final. 
+ self.pendingMessagingRequests[idx].segmentSendState = SegmentSendState.Failed + self.pruneCompletedChannelReqs() + idx = 0 # prune compacts the seq in place and may shift entries; rescan for the next AwaitingRateLimit continue let wireBytes = seq[byte](encrypted) let envelope = MessageEnvelope( - contentTopic: self.contentTopic, payload: wireBytes, ephemeral: pending.ephemeral + contentTopic: self.contentTopic, payload: wireBytes, ephemeral: isEphemeral ) - let deliveryReqId = RequestId.new(self.rng) - let deliveryTask = DeliveryTask.new(deliveryReqId, envelope, globalBrokerContext()).valueOr: - ## TODO: emit waku `MessageErrorEvent` for the parent request id. + let messagingReqId = RequestId.new(self.rng) + let deliveryTask = DeliveryTask.new( + messagingReqId, envelope, self.brokerCtx + ).valueOr: + MessageErrorEvent.emit( + self.brokerCtx, + MessageErrorEvent( + requestId: channelReqId, + messageHash: "", + error: "delivery task setup failed: " & error, + ), + ) + self.pendingMessagingRequests[idx].messagingReqId = some(messagingReqId) + self.pendingMessagingRequests[idx].segmentSendState = SegmentSendState.Failed + self.pruneCompletedChannelReqs() + idx = 0 # same as above: indices are stale after pruning, so restart the scan from the front continue ## Stamp the Reliable Channel wire-format spec marker so the ingress @@ -121,8 +253,11 @@ proc onReadyToSend( ## `meta` field. deliveryTask.msg.meta = LipWireReliableChannelVersion.toBytes() + self.pendingMessagingRequests[idx].messagingReqId = some(messagingReqId) + self.pendingMessagingRequests[idx].segmentSendState = SegmentSendState.InFlight asyncSpawn self.deliveryService.sendService.send(deliveryTask) - self.requestIds.mgetOrPut(pending.parent, @[]).add(deliveryReqId) + self.requestIds.mgetOrPut(channelReqId, @[]).add(messagingReqId) + idx.inc() proc send*( self: ReliableChannel, payload: seq[byte], ephemeral: bool = false @@ -135,18 +270,22 @@ proc send*( ## ## `rate_limit_manager.enqueueToSend` emits a `ReadyToSendEvent` with ## the SDS messages cleared for transmission; the channel's listener - ## then runs the final stage (encryption -> dispatch). 
The `ephemeral` - ## flag is carried alongside each segment in `pendingRequests` and - ## stamped onto the eventual `MessageEnvelope`. + ## then runs the final stage (encryption -> dispatch). The + ## `persistenceReqType` is carried alongside each segment in + ## `pendingMessagingRequests` and stamped onto the eventual + ## `MessageEnvelope`. ## - ## The returned `RequestId` is the parent of one-or-more - ## delivery-service `RequestId`s; the mapping is recorded in + ## The returned `RequestId` is the channel-level parent of one-or-more + ## messaging-layer `RequestId`s; the mapping is recorded in ## `self.requestIds`. if payload.len == 0: return err("empty payload") - let parentReqId = RequestId.new(self.rng) - self.requestIds[parentReqId] = @[] + let channelReqId = RequestId.new(self.rng) + self.requestIds[channelReqId] = @[] + + let persistenceReqType = + if ephemeral: MessagePersistence.Ephemeral else: MessagePersistence.Persistent for segmentBytes in self.segmentation.performSegmentation(payload): ## Segments arrive already encoded; the segmentation module owns @@ -155,10 +294,17 @@ proc send*( self.channelId, self.senderId, segmentBytes ).valueOr: return err("SDS wrap failed: " & error) - self.pendingRequests.add((parent: parentReqId, ephemeral: ephemeral)) + self.pendingMessagingRequests.add( + PendingMessagingRequest( + channelReqId: channelReqId, + messagingReqId: none(RequestId), + persistenceReqType: persistenceReqType, + segmentSendState: SegmentSendState.AwaitingRateLimit, + ) + ) self.rateLimit.enqueueToSend(sdsBytes) - return ok(parentReqId) + return ok(channelReqId) proc onMessageReceived( self: ReliableChannel, messageHash: string, payload: seq[byte] @@ -231,15 +377,16 @@ proc new*( sdsHandler: SdsHandler.new(sdsConfig, senderId), rateLimit: RateLimitManager.new(rateConfig, channelId, brokerCtx), requestIds: initTable[RequestId, seq[RequestId]](), - pendingRequests: @[], + pendingMessagingRequests: @[], brokerCtx: brokerCtx, ) - ## Each channel 
owns its own egress + ingress listeners on - ## `chn.brokerCtx`, filtered to traffic addressed to this channel. - ## Keeping the listeners (and the procs they call) inside the - ## channel lets `onReadyToSend` and `onMessageReceived` stay private - ## — the manager doesn't need to know about them. + ## Each channel owns its own egress + ingress + send-completion + ## listeners on `chn.brokerCtx`, filtered to traffic addressed to + ## this channel. Keeping the listeners (and the handler procs they + ## call) inside the channel lets `onReadyToSend` / + ## `onMessageReceived` / `onMessageSent` / `onMessageError` stay + ## private — the manager doesn't need to know about them. discard ReadyToSendEvent.listen( chn.brokerCtx, proc(evt: ReadyToSendEvent): Future[void] {.async: (raises: []).} = @@ -261,4 +408,22 @@ proc new*( , ) + ## Send-completion events are tagged with the per-segment messaging + ## `requestId` — globally unique, so we don't need any channel filter + ## up front. The handler scans this channel's pending entries for a + ## match and is a no-op when the id belongs to a different channel. 
+ discard MessageSentEvent.listen( + chn.brokerCtx, + proc(evt: MessageSentEvent): Future[void] {.async: (raises: []).} = + chn.onMessageSent(evt.requestId) + , + ) + + discard MessageErrorEvent.listen( + chn.brokerCtx, + proc(evt: MessageErrorEvent): Future[void] {.async: (raises: []).} = + chn.onMessageError(evt.requestId) + , + ) + return chn diff --git a/channels/reliable_channel_manager.nim b/channels/reliable_channel_manager.nim index ddbdb37a6..b95a67263 100644 --- a/channels/reliable_channel_manager.nim +++ b/channels/reliable_channel_manager.nim @@ -61,6 +61,15 @@ proc stop*(self: ReliableChannelManager) {.async.} = if not self.deliveryService.isNil(): await self.deliveryService.stopDeliveryService() +proc getChannelForTest*( + self: ReliableChannelManager, channelId: ChannelId +): ReliableChannel = + ## Test-only: returns the channel for `channelId`, or `nil` if none + ## exists. Production callers must address channels by `channelId` + ## through the manager's `send` / `closeChannel` API — direct + ## references bypass the manager's lifecycle and pipeline. + self.channels.getOrDefault(channelId) + proc createReliableChannel*( self: ReliableChannelManager, channelId: ChannelId, diff --git a/tests/channels/test_reliable_channel_send_receive.nim b/tests/channels/test_reliable_channel_send_receive.nim index 052cd35c9..997a9d5cb 100644 --- a/tests/channels/test_reliable_channel_send_receive.nim +++ b/tests/channels/test_reliable_channel_send_receive.nim @@ -147,3 +147,104 @@ suite "Reliable Channel - ingress": check not fired await manager.stop() + +suite "Reliable Channel - send state machine": + asyncTest "MessageSentEvent flips InFlight -> Confirmed and prunes": + ## Exercises the channel-side state machine in isolation. We + ## inject a pending entry already in `InFlight` (so we don't have + ## to drive — and race with — the real send pipeline), then emit + ## the delivery-layer `MessageSentEvent` for its `messagingReqId`. 
+ ## The channel's own listener flips the entry to `Confirmed` and, + ## since it's the only segment for that `channelReqId`, prunes it. + const + channelId = ChannelId("sm-success-channel") + contentTopic = ContentTopic("/reliable-channel/test/sm-success") + + var manager: ReliableChannelManager + var brokerCtx: BrokerContext + lockNewGlobalBrokerContext: + brokerCtx = globalBrokerContext() + manager = (await ReliableChannelManager.new(createApiNodeConf())).expect( + "Failed to create manager" + ) + + setNoopEncryption() + discard manager + .createReliableChannel(channelId, contentTopic, SdsParticipantID("local")) + .expect("createReliableChannel") + + let chn = manager.getChannelForTest(channelId) + doAssert not chn.isNil() + check chn.pendingMessagingRequestsLenForTest == 0 + + let channelReqId = RequestId("test-channel-req") + let messagingReqId = RequestId("test-msg-req") + chn.forceInjectInFlightForTest(channelReqId, messagingReqId) + check chn.pendingMessagingRequestsLenForTest == 1 + + waku_message_events.MessageSentEvent.emit( + brokerCtx, + waku_message_events.MessageSentEvent(requestId: messagingReqId, messageHash: ""), + ) + + let deadline = Moment.now() + 1.seconds + while Moment.now() < deadline and chn.pendingMessagingRequestsLenForTest > 0: + await sleepAsync(5.milliseconds) + check chn.pendingMessagingRequestsLenForTest == 0 + + await manager.stop() + + asyncTest "channelReqId not pruned until ALL its segments are final": + ## Validates `pruneCompletedChannelReqs`'s "wait for siblings" rule: + ## a channel request with multiple segments is only dropped once + ## every segment is `Confirmed` or `Failed`. Confirm the first + ## segment and assert both entries are still tracked; fail the + ## second and assert both are pruned. 
+ const + channelId = ChannelId("sm-multi-channel") + contentTopic = ContentTopic("/reliable-channel/test/sm-multi") + + var manager: ReliableChannelManager + var brokerCtx: BrokerContext + lockNewGlobalBrokerContext: + brokerCtx = globalBrokerContext() + manager = (await ReliableChannelManager.new(createApiNodeConf())).expect( + "Failed to create manager" + ) + + setNoopEncryption() + discard manager + .createReliableChannel(channelId, contentTopic, SdsParticipantID("local")) + .expect("createReliableChannel") + + let chn = manager.getChannelForTest(channelId) + doAssert not chn.isNil() + + let channelReqId = RequestId("multi-channel-req") + let msgReqId1 = RequestId("multi-msg-req-1") + let msgReqId2 = RequestId("multi-msg-req-2") + chn.forceInjectInFlightForTest(channelReqId, msgReqId1) + chn.forceInjectInFlightForTest(channelReqId, msgReqId2) + check chn.pendingMessagingRequestsLenForTest == 2 + + waku_message_events.MessageSentEvent.emit( + brokerCtx, + waku_message_events.MessageSentEvent(requestId: msgReqId1, messageHash: ""), + ) + await sleepAsync(50.milliseconds) + ## Sibling msgReqId2 is still `InFlight`, so prune must NOT fire + ## yet — both entries remain tracked. + check chn.pendingMessagingRequestsLenForTest == 2 + + waku_message_events.MessageErrorEvent.emit( + brokerCtx, + waku_message_events.MessageErrorEvent( + requestId: msgReqId2, messageHash: "", error: "synthetic" + ), + ) + let deadline = Moment.now() + 1.seconds + while Moment.now() < deadline and chn.pendingMessagingRequestsLenForTest > 0: + await sleepAsync(5.milliseconds) + check chn.pendingMessagingRequestsLenForTest == 0 + + await manager.stop()