diff --git a/channels/reliable_channel.nim b/channels/reliable_channel.nim index c3fbe5d77..e32b57e36 100644 --- a/channels/reliable_channel.nim +++ b/channels/reliable_channel.nim @@ -13,7 +13,7 @@ ## ## See: https://lip.logos.co/messaging/raw/reliable-channel-api.html -import std/[options, sets, tables] +import std/[options, tables] import results import chronos import bearssl/rand @@ -55,32 +55,35 @@ type Persistent Ephemeral - SegmentSendState {.pure.} = enum - ## Lifecycle of a single segment as tracked by the channel. The - ## messaging layer has its own richer `DeliveryState` (retries, - ## propagated-vs-validated); here we only model what's needed to - ## decide when a `channelReqId` is fully accounted for. - AwaitingRateLimit ## Pushed by `send`; not yet released by rate_limit_manager. - InFlight - ## Released by rate_limit_manager and handed to delivery_service; - ## `messagingReqId` is now set. - Confirmed ## `MessageSentEvent` arrived for `messagingReqId`. - Failed - ## `MessageErrorEvent` arrived for `messagingReqId`, or the local - ## delivery-task construction failed before any id was reachable. - - PendingMessagingRequest = object - ## One entry per segment (i.e. per messaging-layer request). The - ## relative order of `AwaitingRateLimit` entries must match the - ## order in which `rate_limit_manager` re-emits messages, which is - ## FIFO with `send()`. - channelReqId*: RequestId - ## The channel-layer parent id returned to the caller of `send()` in channel layer. - ## One channel request maps to N pending messaging requests. - messagingReqId*: Option[RequestId] - ## Per-segment messaging layer id. `none` until `onReadyToSend` assigns it. + ChannelReqState = object + ## Per channel-level request, tracks how many of its segments are + ## still queued, in flight, or have terminated. 
The channel-level + ## final event fires when `confirmedCount + failedCount` reaches + ## `totalExpectedSegments` AND no segments are still awaiting dispatch + ## or in flight. persistenceReqType: MessagePersistence - segmentSendState*: SegmentSendState + totalExpectedSegments: int + ## Total segments produced by `segmentation.performSegmentation` + ## for this `channelReqId`. Set once in `send`, never mutated. + awaitingDispatch: int + ## Segments enqueued in `rate_limit_manager` but not yet claimed + ## by `onReadyToSend`. Decremented when `onReadyToSend` picks a + ## message and assigns it to this `channelReqId`. + inflightMessagingIds: seq[RequestId] + ## Messaging-layer ids minted by the send handler that have not + ## yet produced a final event. Removed on `MessageSentEvent` / `MessageErrorEvent`. + confirmedCount: int + failedCount: int + + ChannelReqs = OrderedTable[RequestId, ChannelReqState] + ## Key: channelReqId (the parent id returned by channel `send`). Value: + ## per-request state, see `ChannelReqState`. + ## + ## `OrderedTable` preserves insertion order, which matches the FIFO + ## order `rate_limit_manager` re-emits messages in: `onReadyToSend` + ## routes each segment to the first entry with `awaitingDispatch > 0`, + ## and that scan is correct precisely because the outer iteration + ## order matches the order `send` pushed entries. ReliableChannel* = ref object ## Spec-defined public type. Fields are private so callers cannot @@ -95,13 +98,23 @@ type sdsHandler: SdsHandler rateLimit: RateLimitManager - requestIds: Table[RequestId, seq[RequestId]] - pendingMessagingRequests: seq[PendingMessagingRequest] - ## Entries are kept until the matching segment reaches a final - ## state (`Confirmed` or `Failed`); a whole channel request is - ## then pruned in one pass once all its segments are final. 
+ channelReqs: ChannelReqs brokerCtx: BrokerContext +func init( + T: type ChannelReqState, + persistenceReqType: MessagePersistence, + totalExpectedSegments: int, +): T = + return ChannelReqState( + persistenceReqType: persistenceReqType, + totalExpectedSegments: totalExpectedSegments, + awaitingDispatch: totalExpectedSegments, + inflightMessagingIds: @[], + confirmedCount: 0, + failedCount: 0, + ) + func getChannelId*(self: ReliableChannel): ChannelId {.inline.} = self.channelId @@ -111,70 +124,92 @@ func getContentTopic*(self: ReliableChannel): ContentTopic {.inline.} = func getSenderId*(self: ReliableChannel): SdsParticipantID {.inline.} = self.senderId -func isFinal(state: SegmentSendState): bool {.inline.} = - return state in {SegmentSendState.Confirmed, SegmentSendState.Failed} +proc tryFinalizeChannelReq(self: ReliableChannel, channelReqId: RequestId) = + ## Tries to finalize the channel-level request identified by `channelReqId` if + ## certain conditions are met, i.e., no segments are still awaiting dispatch or in flight, + ## and the total number of confirmed + failed segments equals the total expected segments. + ## Therefore, the channel-level request is removed from `self.channelReqs` + ## and the appropriate final event is emitted. + ## + let state = self.channelReqs.getOrDefault(channelReqId) + if state.totalExpectedSegments == 0: + ## Either already finalized (and removed) or never inserted. + return + if state.awaitingDispatch != 0 or state.inflightMessagingIds.len != 0: + return + if state.confirmedCount + state.failedCount < state.totalExpectedSegments: + return -proc pruneCompletedChannelReqs(self: ReliableChannel) = - ## Drop every `pendingMessagingRequests` entry whose `channelReqId` - ## has all of its segments in a final state. A single failing - ## segment doesn't trigger a drop on its own — we wait until siblings - ## are also accounted for, so the channel-level outcome is decided - ## from a complete picture. 
For each fully-final `channelReqId`, emit - ## the channel-level final event before the entries are dropped: - ## `ChannelMessageSentEvent` if every sibling Confirmed, - ## `ChannelMessageErrorEvent` if any sibling Failed. - var hasPending = initHashSet[RequestId]() - var anyFailed = initHashSet[RequestId]() - for entry in self.pendingMessagingRequests: - if not entry.segmentSendState.isFinal(): - hasPending.incl(entry.channelReqId) - elif entry.segmentSendState == SegmentSendState.Failed: - anyFailed.incl(entry.channelReqId) + self.channelReqs.del(channelReqId) - var emitted = initHashSet[RequestId]() - for entry in self.pendingMessagingRequests: - if entry.channelReqId in hasPending or entry.channelReqId in emitted: + if state.failedCount > 0: + ChannelMessageErrorEvent.emit( + self.brokerCtx, + ChannelMessageErrorEvent( + channelId: self.channelId, + requestId: channelReqId, + error: "one or more segments failed", + ), + ) + else: + ChannelMessageSentEvent.emit( + self.brokerCtx, + ChannelMessageSentEvent(channelId: self.channelId, requestId: channelReqId), + ) + +type ClaimedSegment = object + channelReqId: RequestId + isEphemeral: bool + +proc claimAwaitingChannelReq(self: ReliableChannel): Option[ClaimedSegment] = + for channelReqId, state in self.channelReqs.mpairs: + if state.awaitingDispatch > 0: + state.awaitingDispatch.dec() + return some( + ClaimedSegment( + channelReqId: channelReqId, + isEphemeral: state.persistenceReqType == MessagePersistence.Ephemeral, + ) + ) + return none(ClaimedSegment) + +type MessagingOutcome {.pure.} = enum + Sent + Failed + +proc onMessageFinal( + self: ReliableChannel, messagingReqId: RequestId, outcome: MessagingOutcome +) = + for channelReqId, state in self.channelReqs.mpairs: + let idx = state.inflightMessagingIds.find(messagingReqId) + if idx < 0: continue - emitted.incl(entry.channelReqId) - if entry.channelReqId in anyFailed: - ChannelMessageErrorEvent.emit( - self.brokerCtx, - ChannelMessageErrorEvent( - channelId: 
self.channelId, - requestId: entry.channelReqId, - error: "one or more segments failed", - ), - ) - else: - ChannelMessageSentEvent.emit( - self.brokerCtx, - ChannelMessageSentEvent( - channelId: self.channelId, requestId: entry.channelReqId - ), - ) + state.inflightMessagingIds.del(idx) + case outcome + of MessagingOutcome.Sent: + state.confirmedCount.inc() + of MessagingOutcome.Failed: + state.failedCount.inc() + self.tryFinalizeChannelReq(channelReqId) + return - self.pendingMessagingRequests.keepItIf(it.channelReqId in hasPending) +proc markSegmentFailed(self: ReliableChannel, channelReqId: RequestId) = + try: + self.channelReqs[channelReqId].failedCount.inc() + except KeyError as e: + error "unreachable: channelReqId not found in markSegmentFailed", + channelReqId = $channelReqId, error = e.msg + return + self.tryFinalizeChannelReq(channelReqId) -proc onMessageSent(self: ReliableChannel, messagingReqId: RequestId) = - ## Invoked from this channel's `MessageSentEvent` listener. Flips - ## the matching `InFlight` segment to `Confirmed` and prunes. The - ## listener routes every event through here; entries that don't - ## belong to this channel simply don't match and are no-ops. - self.pendingMessagingRequests.applyItIf( - it.segmentSendState == SegmentSendState.InFlight and - it.messagingReqId == some(messagingReqId) - ): - it.segmentSendState = SegmentSendState.Confirmed - self.pruneCompletedChannelReqs() - -proc onMessageError(self: ReliableChannel, messagingReqId: RequestId) = - ## Symmetric to `onMessageSent` but for `MessageErrorEvent`. 
- self.pendingMessagingRequests.applyItIf( - it.segmentSendState == SegmentSendState.InFlight and - it.messagingReqId == some(messagingReqId) - ): - it.segmentSendState = SegmentSendState.Failed - self.pruneCompletedChannelReqs() +proc markSegmentInflight( + self: ReliableChannel, channelReqId: RequestId, messagingReqId: RequestId +) = + try: + self.channelReqs[channelReqId].inflightMessagingIds.add(messagingReqId) + except KeyError as e: + error "unreachable: channelReqId not found in markSegmentInflight", + channelReqId = $channelReqId, error = e.msg proc onReadyToSend( self: ReliableChannel, readyToSendEvent: ReadyToSendEvent @@ -184,30 +219,22 @@ proc onReadyToSend( ## blobs (already-encoded SDS messages): ## ## ... -> rate_limit_manager -> [encryption] -> dispatch - var idx = 0 + ## + ## For each `m`, the next channelReqId still queued in rate-limit + ## claims the slot (FIFO across sibling sends). The channelReqId is + ## captured up front and used as a stable key for every later state + ## update — no positional index is ever held across an `await`, so + ## sibling events mutating other entries (or even this one's + ## `inflightMessagingIds`) cannot corrupt this fiber's view. for m in readyToSendEvent.msgs: - ## The first `AwaitingRateLimit` entry in push order is the one - ## this `m` belongs to: `send()` adds one entry per segment, and - ## `rate_limit_manager` re-emits them in the same FIFO order, so - ## the two sequences advance in lockstep. Earlier entries may - ## already be `InFlight` / `Confirmed` / `Failed` because they - ## live on until every sibling of their `channelReqId` is final, - ## so we walk past those to find the next one that was awaiting for this batch. 
- while idx < self.pendingMessagingRequests.len and - self.pendingMessagingRequests[idx].segmentSendState != - SegmentSendState.AwaitingRateLimit - : - idx.inc() - if idx >= self.pendingMessagingRequests.len: + let claimed = self.claimAwaitingChannelReq().valueOr: ## rate_limit_manager emitted more messages than we have pending — - ## should not happen given `send` pushes one entry per enqueued - ## SDS payload. Drop silently rather than corrupt state. + ## should not happen given `send` increments `awaitingDispatch` + ## once per enqueued SDS payload. Drop silently rather than + ## corrupt state. break - - let channelReqId = self.pendingMessagingRequests[idx].channelReqId - let isEphemeral = - self.pendingMessagingRequests[idx].persistenceReqType == - MessagePersistence.Ephemeral + let channelReqId = claimed.channelReqId + let isEphemeral = claimed.isEphemeral ## TODO: revisit which fields of the SDS message must be encrypted. ## Encrypting the whole encoded blob forces every receiver to attempt @@ -223,15 +250,7 @@ proc onReadyToSend( ), ) - ## Encryption failed *before* we could hand the segment to the - ## delivery layer — no `messagingReqId` was minted and no - ## `DeliveryTask` was queued on `sendService`. The delivery - ## layer will therefore never emit a `MessageSentEvent` / - ## `MessageErrorEvent` for this segment, so `onMessageError` - ## won't fire either. Advance the state machine inline so the - ## parent `channelReqId` can still be pruned once its siblings - ## are also final. 
- self.pendingMessagingRequests[idx].segmentSendState = SegmentSendState.Failed - idx.inc() + ## Encryption failed before a `messagingReqId` was minted; fail the segment inline. + self.markSegmentFailed(channelReqId) continue let wireBytes = seq[byte](encrypted) @@ -261,16 +280,10 @@ proc onReadyToSend( requestId: channelReqId, messageHash: "", error: "waku send failed: " & error ), ) - self.pendingMessagingRequests[idx].segmentSendState = SegmentSendState.Failed - idx.inc() + self.markSegmentFailed(channelReqId) continue - self.pendingMessagingRequests[idx].messagingReqId = some(messagingReqId) - self.pendingMessagingRequests[idx].segmentSendState = SegmentSendState.InFlight - self.requestIds.mgetOrPut(channelReqId, @[]).add(messagingReqId) - idx.inc() - - self.pruneCompletedChannelReqs() + self.markSegmentInflight(channelReqId, messagingReqId) proc send*( self: ReliableChannel, payload: seq[byte], ephemeral: bool = false @@ -283,23 +296,20 @@ proc send*( ## ## `rate_limit_manager.enqueueToSend` emits a `ReadyToSendEvent` with ## the SDS messages cleared for transmission; the channel's listener - ## then runs the final stage (encryption -> dispatch). The - ## `persistenceReqType` is carried alongside each segment in - ## `pendingMessagingRequests` and stamped onto the eventual - ## `MessageEnvelope`. + ## then runs the final stage (encryption -> dispatch). ## ## The returned `RequestId` is the channel-level parent of one-or-more - ## messaging-layer `RequestId`s; the mapping is recorded in - ## `self.requestIds`. + ## messaging-layer `RequestId`s; the mapping is held in + ## `self.channelReqs` until every segment is final. 
if payload.len == 0: return err("empty payload") let channelReqId = RequestId.new(self.rng) - self.requestIds[channelReqId] = @[] - let persistenceReqType = if ephemeral: MessagePersistence.Ephemeral else: MessagePersistence.Persistent + var segmentCount = 0 + var enqueued: seq[seq[byte]] for segmentBytes in self.segmentation.performSegmentation(payload): ## Segments arrive already encoded; the segmentation module owns ## the wire format so SDS only ever sees opaque bytes. @@ -307,14 +317,13 @@ proc send*( self.channelId, self.senderId, segmentBytes ).valueOr: return err("SDS wrap failed: " & error) - self.pendingMessagingRequests.add( - PendingMessagingRequest( - channelReqId: channelReqId, - messagingReqId: none(RequestId), - persistenceReqType: persistenceReqType, - segmentSendState: SegmentSendState.AwaitingRateLimit, - ) - ) + enqueued.add(sdsBytes) + segmentCount.inc() + + self.channelReqs[channelReqId] = + ChannelReqState.init(persistenceReqType, segmentCount) + + for sdsBytes in enqueued: self.rateLimit.enqueueToSend(sdsBytes) return ok(channelReqId) @@ -402,8 +411,7 @@ proc new*( segmentation: SegmentationHandler.new(segConfig), sdsHandler: SdsHandler.new(sdsConfig, senderId), rateLimit: RateLimitManager.new(rateConfig, channelId, brokerCtx), - requestIds: initTable[RequestId, seq[RequestId]](), - pendingMessagingRequests: @[], + channelReqs: initOrderedTable[RequestId, ChannelReqState](), brokerCtx: brokerCtx, ) @@ -411,8 +419,8 @@ proc new*( ## listeners on `chn.brokerCtx`, filtered to traffic addressed to ## this channel. Keeping the listeners (and the handler procs they ## call) inside the channel lets `onReadyToSend` / - ## `onMessageReceived` / `onMessageSent` / `onMessageError` stay - ## private — the manager doesn't need to know about them. + ## `onMessageReceived` / `onMessageFinal` stay private — the + ## manager doesn't need to know about them. 
discard ReadyToSendEvent.listen( chn.brokerCtx, proc(evt: ReadyToSendEvent): Future[void] {.async: (raises: []).} = @@ -441,13 +449,13 @@ proc new*( discard MessageSentEvent.listen( chn.brokerCtx, proc(evt: MessageSentEvent): Future[void] {.async: (raises: []).} = - chn.onMessageSent(evt.requestId), + chn.onMessageFinal(evt.requestId, MessagingOutcome.Sent), ) discard MessageErrorEvent.listen( chn.brokerCtx, proc(evt: MessageErrorEvent): Future[void] {.async: (raises: []).} = - chn.onMessageError(evt.requestId), + chn.onMessageFinal(evt.requestId, MessagingOutcome.Failed), ) return chn diff --git a/tests/channels/test_reliable_channel_send_receive.nim b/tests/channels/test_reliable_channel_send_receive.nim index 2f49182a2..5ea300eb3 100644 --- a/tests/channels/test_reliable_channel_send_receive.nim +++ b/tests/channels/test_reliable_channel_send_receive.nim @@ -315,3 +315,102 @@ suite "Reliable Channel - send state machine": ## `messagingReqId`s from a fake `SendHandler`, finalise some, and ## assert prune only fires once every sibling is final. skip() + + asyncTest "sibling MessageSentEvent during sendHandler await does not corrupt state": + ## Regression test for the prune-during-await race + ## (PR #3914 review comment r3324891059). Locks in that a sibling + ## `MessageSentEvent` firing while `onReadyToSend` is paused at an + ## `await` does not lose the second `channelReqId`'s terminal + ## event. 
+ const + channelId = ChannelId("sm-race-channel") + contentTopic = ContentTopic("/reliable-channel/test/sm-race") + + var manager: ReliableChannelManager + var brokerCtx: BrokerContext + lockNewGlobalBrokerContext: + brokerCtx = globalBrokerContext() + manager = (await ReliableChannelManager.new(createApiNodeConf())).expect( + "Failed to create manager" + ) + + setNoopEncryption() + + var msgReqIds: seq[RequestId] + var sendsReturned = 0 + let fakeSend: SendHandler = proc( + env: MessageEnvelope + ): Future[Result[RequestId, string]] {.async: (raises: [CatchableError]), gcsafe.} = + ## Call 2 fires the first segment's terminal event and then + ## yields, so the listener task runs while the second segment + ## is still mid-`await` in `onReadyToSend` — the exact race + ## window the regression test targets. + let id = RequestId("race-msg-req-" & $(msgReqIds.len + 1)) + msgReqIds.add(id) + if msgReqIds.len == 2: + waku_message_events.MessageSentEvent.emit( + brokerCtx, + waku_message_events.MessageSentEvent(requestId: msgReqIds[0], messageHash: ""), + ) + await sleepAsync(50.milliseconds) + sendsReturned.inc() + return ok(id) + + discard manager + .createReliableChannel( + channelId, contentTopic, SdsParticipantID("local"), sendHandler = fakeSend + ) + .expect("createReliableChannel") + + var finalisedReqIds: seq[RequestId] + let bothFinalised = newFuture[void]("both-finalised") + discard ChannelMessageSentEvent + .listen( + brokerCtx, + proc(evt: ChannelMessageSentEvent) {.async: (raises: []).} = + if evt.channelId == channelId: + finalisedReqIds.add(evt.requestId) + if finalisedReqIds.len == 2 and not bothFinalised.finished(): + bothFinalised.complete() + , + ) + .expect("listen ChannelMessageSentEvent") + + let channelReqId1 = manager.send(channelId, "first".toBytes()).expect("send 1") + + ## Drain the first segment fully before queueing the second, so + ## the rate-limit FIFO between sibling sends isn't itself under + ## test here. 
+ let firstDispatched = Moment.now() + 1.seconds + while Moment.now() < firstDispatched and msgReqIds.len < 1: + await sleepAsync(5.milliseconds) + check msgReqIds.len == 1 + + let channelReqId2 = manager.send(channelId, "second".toBytes()).expect("send 2") + + ## Wait until `fakeSend(m2)` has fully returned and yield once + ## more so `onReadyToSend`'s post-await continuation gets a chance + ## to register `id2` in `inflightMessagingIds` before we emit its + ## terminal event. + let dispatchDeadline = Moment.now() + 1.seconds + while Moment.now() < dispatchDeadline and sendsReturned < 2: + await sleepAsync(5.milliseconds) + check sendsReturned == 2 + await sleepAsync(50.milliseconds) + + ## Finalise the second segment from the outside. If the race + ## corrupted state, `channelReqId2`'s entry would never reach + ## `inflightMessagingIds` and this event would silently miss. + waku_message_events.MessageSentEvent.emit( + brokerCtx, + waku_message_events.MessageSentEvent(requestId: msgReqIds[1], messageHash: ""), + ) + + let arrived = await bothFinalised.withTimeout(2.seconds) + check arrived + if arrived: + check finalisedReqIds.len == 2 + check channelReqId1 in finalisedReqIds + check channelReqId2 in finalisedReqIds + + await manager.stop() diff --git a/tools/sync-nimble-lock.sh b/tools/sync-nimble-lock.sh new file mode 100755 index 000000000..b55826327 --- /dev/null +++ b/tools/sync-nimble-lock.sh @@ -0,0 +1,322 @@ +#!/usr/bin/env bash +# +# sync-nimble-lock.sh +# +# Cross-check git-URL pinned `requires` in waku.nimble against nimble.lock and +# sync the lock entry for any pin that CHANGED relative to a git base ref +# (default: HEAD) -- and ONLY those entries. No other package is touched. +# +# It does NOT run `nimble lock` (which rewrites the whole file and churns +# unrelated packages). 
Instead it computes the package sha1 checksum itself, +# reproducing nimble's algorithm exactly (src/nimblepkg/checksums.nim): +# +# files = `git ls-files` in the package's git checkout at the pinned rev +# files.sort() # lexicographic +# sha1 = SHA1 over, for each existing regular file (in sorted order): +# update(relative_path_string) +# if symlink: update(symlink_target_string) +# else: update(file_bytes) # 8192-byte chunks +# +# For each changed pin it updates exactly three fields of the matching lock +# entry, preserving all formatting and every other entry byte-for-byte: +# version = "#" + (commit or tag) +# vcsRevision = git rev-parse of the ref (resolves tags) +# checksums.sha1 = the self-computed checksum +# +# The `dependencies` array is intentionally left untouched (see NOTE below). +# +# Usage: +# tools/sync-nimble-lock.sh # dry-run; exit 1 if drift +# tools/sync-nimble-lock.sh --apply # update nimble.lock +# tools/sync-nimble-lock.sh --base origin/master # compare against a ref +# +# Exit codes: 0 = in sync / applied, 1 = drift (dry-run), 2 = usage/tooling error +# +# Portable across macOS (bash 3.2, BSD tools) and Linux: all logic is in +# python3; bash only parses args and checks tools. Requires: git, python3. +# +# NOTE on `dependencies`: a version bump can in principle change a package's +# direct dependency set. Reproducing nimble's dependency-name normalization +# without running nimble is fragile, and the user-requested scope is +# version/vcsRevision/sha1. If a bumped dependency added/removed a `requires`, +# update its lock `dependencies` array by hand. The script warns when the +# bumped package's own .nimble `requires` count differs from the lock entry. 
+ +set -euo pipefail + +APPLY=0 +BASE="HEAD" + +usage() { sed -n '2,43p' "$0" | sed 's/^#\{0,1\} \{0,1\}//'; } # prints only the header comment (script lines 2-43) + +while [ $# -gt 0 ]; do + case "$1" in + --apply) APPLY=1 ;; + --base) shift; [ $# -gt 0 ] || { echo "error: --base needs a ref" >&2; exit 2; }; BASE="$1" ;; + --base=*) BASE="${1#*=}" ;; + -h|--help) usage; exit 0 ;; + *) echo "error: unknown argument: $1" >&2; exit 2 ;; + esac + shift +done + +command -v python3 >/dev/null 2>&1 || { echo "error: python3 is required" >&2; exit 2; } +command -v git >/dev/null 2>&1 || { echo "error: git is required" >&2; exit 2; } + +ROOT="$(git rev-parse --show-toplevel 2>/dev/null)" || { echo "error: not in a git repo" >&2; exit 2; } + +export SYNC_ROOT="$ROOT" SYNC_APPLY="$APPLY" SYNC_BASE="$BASE" SYNC_PKGCACHE="${HOME}/.nimble/pkgcache" + +exec python3 - <<'PYEOF' +import hashlib +import json +import os +import re +import shutil +import subprocess +import sys +import tempfile + +ROOT = os.environ["SYNC_ROOT"] +APPLY = os.environ["SYNC_APPLY"] == "1" +BASE = os.environ["SYNC_BASE"] +PKGCACHE = os.environ["SYNC_PKGCACHE"] + +NIMBLE_FILE = os.path.join(ROOT, "waku.nimble") +LOCK_FILE = os.path.join(ROOT, "nimble.lock") + +REQ_RE = re.compile(r'requires\s+"(https?://[^"#]+)#([^"]+)"') +COMMIT_RE = re.compile(r"^[0-9a-f]{40}$") +NEAR_HASH_RE = re.compile(r"^[0-9a-fx]{38,42}$") # catches the leading-`x` typo + + +def fail(msg): + sys.stderr.write("error: %s\n" % msg) + sys.exit(2) + + +def warn(msg): + sys.stderr.write("warning: %s\n" % msg) + + +def norm_url(url): + u = url.rstrip("/") + return u[:-4] if u.endswith(".git") else u + + +def git(args, cwd=None, check=True): + r = subprocess.run(["git"] + args, cwd=cwd, capture_output=True, text=True) + if check and r.returncode != 0: + fail("git %s failed: %s" % (" ".join(args), (r.stderr or r.stdout).strip())) + return r + + +# --------------------------------------------------------------------------- +# nimble checksum reproduction (verified byte-for-byte against nimble 
v0.22.3) +# --------------------------------------------------------------------------- +def compute_checksum(checkout_dir): + out = git(["-C", checkout_dir, "ls-files"]).stdout + files = out.strip().splitlines() + files.sort() + h = hashlib.sha1() + for rel in files: + path = os.path.join(checkout_dir, rel) + if not os.path.isfile(path): + # Skips directories / gitlinks / broken symlinks, matching nimble's + # `fileExists` guard (regular file or symlink-to-file only). + continue + h.update(rel.encode("utf-8")) + if os.path.islink(path): + h.update(os.readlink(path).encode("utf-8")) + else: + with open(path, "rb") as fh: + while True: + chunk = fh.read(8192) + if not chunk: + break + h.update(chunk) + return h.hexdigest() + + +def get_checkout(url, rev, tmpdir): + """Return (checkout_dir, cleanup_fn). Reuses ~/.nimble/pkgcache when the + exact commit is already cloned; otherwise clones from the URL.""" + # pkgcache dirs are suffixed with the commit sha (commit pins only). + if os.path.isdir(PKGCACHE): + for name in os.listdir(PKGCACHE): + if name.endswith("_" + rev) and os.path.isdir(os.path.join(PKGCACHE, name, ".git")): + cache = os.path.join(PKGCACHE, name) + git(["-C", cache, "checkout", "-q", rev]) + return cache, (lambda: None) + # Fall back to a fresh clone (network). Full clone, then checkout the ref. + dest = os.path.join(tmpdir, "clone") + print(" cloning %s ..." 
% url) + git(["clone", "--quiet", url, dest]) + r = git(["-C", dest, "checkout", "-q", rev], check=False) + if r.returncode != 0: + # commit may live on a ref not fetched by default; try fetching it + git(["-C", dest, "fetch", "--quiet", "origin", rev], check=False) + git(["-C", dest, "checkout", "-q", rev]) + return dest, (lambda: shutil.rmtree(dest, ignore_errors=True)) + + +def dep_requires_count(checkout_dir): + """Best-effort count of git/registry `requires` in the dep's .nimble file, + for a heads-up if the lock `dependencies` array may be stale.""" + nimbles = [f for f in os.listdir(checkout_dir) if f.endswith(".nimble")] + if not nimbles: + return None + try: + txt = open(os.path.join(checkout_dir, nimbles[0])).read() + except OSError: + return None + n = 0 + for m in re.finditer(r'requires\s+"([^"]+)"', txt): + n += len([p for p in m.group(1).split(",") if p.strip()]) + return n or None + + +# --------------------------------------------------------------------------- +# detect changes +# --------------------------------------------------------------------------- +def parse_changed(base): + r = git(["-C", ROOT, "diff", base, "--", "waku.nimble"], check=False) + if r.returncode != 0: + fail("git diff against %r failed: %s" % (base, r.stderr.strip())) + changed, seen = [], set() + for line in r.stdout.splitlines(): + if not line.startswith("+") or line.startswith("+++"): + continue + m = REQ_RE.search(line[1:]) + if not m: + continue + url, rev = m.group(1), m.group(2) + key = norm_url(url) + if key in seen: + continue + seen.add(key) + if not COMMIT_RE.match(rev) and NEAR_HASH_RE.match(rev): + fail("invalid commit hash for %s: %r is not a valid 40-char hex SHA " + "(stray character / typo?)" % (url, rev)) + changed.append((url, rev)) + return changed + + +# --------------------------------------------------------------------------- +# surgical lock patch (text-level: preserves formatting & all other entries) +# 
--------------------------------------------------------------------------- +PKG_OPEN_RE = re.compile(r'^\s{4}"[^"]+":\s*\{\s*$') +PKG_CLOSE_RE = re.compile(r'^\s{4}\},?\s*$') + + +def set_value(line, key, val): + return re.sub(r'(^\s*"' + re.escape(key) + r'":\s*")[^"]*(")', + lambda m: m.group(1) + val + m.group(2), line, count=1) + + +def patch_lock_text(text, url, version, vcs_rev, sha1): + lines = text.splitlines(keepends=True) + url_re = re.compile(r'^\s*"url":\s*"' + re.escape(norm_url(url)) + r'(?:\.git)?/*"\s*,?\s*$') # tolerate ".git" / trailing "/" exactly like the norm_url() lookup that detected the drift + ui = next((i for i, l in enumerate(lines) if url_re.match(l)), None) + if ui is None: + return None + # block bounds + start = next(i for i in range(ui, -1, -1) if PKG_OPEN_RE.match(lines[i])) + end = next(i for i in range(ui, len(lines)) if PKG_CLOSE_RE.match(lines[i])) + done = set() + for i in range(start, end + 1): + if "version" not in done and re.match(r'^\s*"version":', lines[i]): + lines[i] = set_value(lines[i], "version", version); done.add("version") + elif "vcsRevision" not in done and re.match(r'^\s*"vcsRevision":', lines[i]): + lines[i] = set_value(lines[i], "vcsRevision", vcs_rev); done.add("vcsRevision") + elif "sha1" not in done and re.match(r'^\s*"sha1":', lines[i]): + lines[i] = set_value(lines[i], "sha1", sha1); done.add("sha1") + missing = {"version", "vcsRevision", "sha1"} - done + if missing: + fail("could not locate field(s) %s in lock block for %s" % (sorted(missing), url)) + return "".join(lines) + + +# --------------------------------------------------------------------------- +def main(): + for p in (NIMBLE_FILE, LOCK_FILE): + if not os.path.isfile(p): + fail("%s not found" % p) + + changed = parse_changed(BASE) + if not changed: + print("No changed git-URL `requires` in waku.nimble vs %s — nothing to sync." 
% BASE) + return 0 + + lock = json.load(open(LOCK_FILE)) + by_url = {} + for name, e in lock.get("packages", {}).items(): + if e.get("url"): + by_url[norm_url(e["url"])] = (name, e) + + drift = [] # (url, rev, name_or_None, cur_version_or_None) + for url, rev in changed: + hit = by_url.get(norm_url(url)) + want = "#" + rev + if hit is None: + drift.append((url, rev, None, None)) + elif hit[1].get("version") != want: + drift.append((url, rev, hit[0], hit[1].get("version"))) + + if not drift: + print("nimble.lock already in sync with waku.nimble (%d changed pin(s) checked)." % len(changed)) + return 0 + + print("Dependency drift (waku.nimble vs nimble.lock):") + for url, rev, name, cur in drift: + tag = name or "(missing)" + print(" ~ %s [%s]\n waku.nimble: #%s\n nimble.lock: %s" % (url, tag, rev, cur)) + + if not APPLY: + print("\nRun with --apply to update nimble.lock (computes checksum itself; no `nimble lock`).") + return 1 + + print("\nApplying (computing checksums; not running `nimble lock`)...") + text = open(LOCK_FILE).read() + updated = [] + tmproot = tempfile.mkdtemp(prefix="sync-nimble-lock.") + try: + for url, rev, name, _cur in drift: + if name is None: + fail("%s has no entry in nimble.lock; this script updates existing " + "entries only (add new deps with a normal nimble install first)." % url) + sub = os.path.join(tmproot, re.sub(r"\W+", "_", norm_url(url))) + os.makedirs(sub, exist_ok=True) + checkout, cleanup = get_checkout(url, rev, sub) + try: + vcs_rev = git(["-C", checkout, "rev-parse", "HEAD"]).stdout.strip() + sha1 = compute_checksum(checkout) + # dependency-drift heads-up + cnt = dep_requires_count(checkout) + lock_deps = len(by_url[norm_url(url)][1].get("dependencies", [])) + if cnt is not None and lock_deps and cnt != lock_deps: + warn("%s: .nimble has %d `requires` but lock lists %d dependencies; " + "review the `dependencies` array manually." 
% (name, cnt, lock_deps)) + finally: + cleanup() + new_text = patch_lock_text(text, url, "#" + rev, vcs_rev, sha1) + if new_text is None: + fail("could not find lock block for url %s" % url) + text = new_text + updated.append((name, "#" + rev, vcs_rev, sha1)) + finally: + shutil.rmtree(tmproot, ignore_errors=True) + + with open(LOCK_FILE, "w") as f: + f.write(text) + + print("\nUpdated nimble.lock (only these entries; all others untouched):") + for name, ver, vcs, sha1 in updated: + print(" %-16s version=%s" % (name, ver)) + print(" %-16s vcsRevision=%s" % ("", vcs)) + print(" %-16s sha1=%s" % ("", sha1)) + return 0 + + +sys.exit(main()) +PYEOF diff --git a/waku.nimble b/waku.nimble index 05bac2ba1..adce59b32 100644 --- a/waku.nimble +++ b/waku.nimble @@ -59,7 +59,8 @@ requires "nim >= 2.2.4", "unittest2" # Packages not on nimble (use git URLs) -requires "https://github.com/logos-messaging/nim-ffi#06111de155253b34e47ed2aaed1d61d08d62cc1b" + +requires "https://github.com/logos-messaging/nim-ffi#v0.1.3" requires "https://github.com/logos-messaging/nim-sds.git#2e9a7683f0e180bf112135fae3a3803eed8490d4" @@ -528,3 +529,67 @@ task liblogosdeliveryStaticLinux, "Generate bindings": task liblogosdeliveryStaticMac, "Generate bindings": buildLibStaticMac("liblogosdelivery", "liblogosdelivery") + +### Formatting tasks + +task nphchanges, "Run nph on .nim/.nims/.nimble files changed on this branch/PR": + ## Formats every Nim source file that differs from the base branch. + ## The set covers committed changes on the branch, working-tree edits + ## (staged or not) and untracked files. The base branch is auto-detected + ## (origin's default branch, else local main/master); override it with + ## the NPH_BASE_BRANCH env var. + let nph = + if findExe("nph").len > 0: findExe("nph") + else: getHomeDir() / ".nimble" / "bin" / "nph" + if not fileExists(nph): + quit "nph not found. Run `make build-nph` first.", 1 + + proc detectBaseBranch(): string = + # Explicit override wins. 
+ if existsEnv("NPH_BASE_BRANCH"): + return getEnv("NPH_BASE_BRANCH") + # origin's default branch, e.g. "origin/main" -> "main". + let (head, hCode) = + gorgeEx("git symbolic-ref --short refs/remotes/origin/HEAD") + if hCode == 0 and head.strip().len > 0: + let parts = head.strip().split('/') + return parts[^1] + # Fall back to whichever local branch exists. + for candidate in ["main", "master"]: + let (_, vCode) = + gorgeEx("git rev-parse --verify --quiet " & candidate) + if vCode == 0: + return candidate + return "master" + + let baseBranch = detectBaseBranch() + + # Diff against the merge-base so we only touch what this branch introduced. + var diffRef = baseBranch + let (mergeBase, mbCode) = gorgeEx("git merge-base HEAD " & baseBranch) + if mbCode == 0 and mergeBase.strip().len > 0: + diffRef = mergeBase.strip() + + let (changed, dCode) = gorgeEx("git diff --name-only --diff-filter=ACMR " & diffRef) + if dCode != 0: + quit "git diff failed: " & changed, 1 + let (untracked, _) = gorgeEx("git ls-files --others --exclude-standard") + + var files: seq[string] + for line in (changed & "\n" & untracked).splitLines(): + let f = line.strip() + if f.len == 0: + continue + if not (f.endsWith(".nim") or f.endsWith(".nims") or f.endsWith(".nimble")): + continue + if fileExists(f) and f notin files: + files.add(f) + + if files.len == 0: + echo "nphchanges: no changed .nim/.nims/.nimble files to format" + return + + echo "nphchanges: formatting " & $files.len & " file(s) (base: " & baseBranch & ")" + for f in files: + echo "Formatting " & f + exec nph & " \"" & f & "\""