diff --git a/waku/node/delivery_service/subscription_manager.nim b/waku/node/delivery_service/subscription_manager.nim index 70d8df7f0..f00d9024c 100644 --- a/waku/node/delivery_service/subscription_manager.nim +++ b/waku/node/delivery_service/subscription_manager.nim @@ -115,7 +115,7 @@ proc subscribePubsubTopics( if isNil(self.node.wakuRelay): return err("subscribePubsubTopics requires a Relay") - var errors: seq[string] = @[] + var errors: seq[string] for shard in shards: if not self.contentTopicSubs.hasKey(shard): @@ -196,6 +196,11 @@ const EdgeFilterLoopInterval = chronos.seconds(30) const EdgeFilterSubLoopDebounce = chronos.seconds(1) ## Debounce delay to coalesce rapid-fire wakeups into a single reconciliation pass. +type EdgeDialTask = object + peer: RemotePeerInfo + shard: PubsubTopic + topics: seq[ContentTopic] + proc updateShardHealth( self: SubscriptionManager, shard: PubsubTopic, state: var EdgeFilterSubState ) = @@ -335,7 +340,7 @@ proc edgeFilterHealthLoop*(self: SubscriptionManager) {.async.} = var alive = initHashSet[PeerId]() if connected.len > 0: - var pingTasks: seq[(PeerId, Future[FilterSubscribeResult])] = @[] + var pingTasks: seq[(PeerId, Future[FilterSubscribeResult])] for peer in connected.values: pingTasks.add( (peer.peerId, self.node.wakuFilterClient.ping(peer, EdgeFilterPingTimeout)) @@ -362,6 +367,36 @@ proc edgeFilterHealthLoop*(self: SubscriptionManager) {.async.} = if changed: self.edgeFilterWakeup.fire() +proc selectFilterCandidates( + self: SubscriptionManager, shard: PubsubTopic, exclude: HashSet[PeerId], needed: int +): seq[RemotePeerInfo] = + ## Select filter service peer candidates for a shard. + + # Start with every filter server peer that can serve the shard + var allCandidates = self.node.peerManager.selectPeers( + filter_common.WakuFilterSubscribeCodec, some(shard) + ) + + # Remove all already used in this shard or being dialed for it + allCandidates.keepItIf(it.peerId notin exclude) + + # Collect peer IDs already tracked on other shards + var trackedOnOther = initHashSet[PeerId]() + for otherShard, otherState in self.edgeFilterSubStates.pairs: + if otherShard != shard: + for peer in otherState.peers: + trackedOnOther.incl(peer.peerId) + + # Prefer peers we already have a connection to first, preserving shuffle + var candidates = + allCandidates.filterIt(it.peerId in trackedOnOther) & + allCandidates.filterIt(it.peerId notin trackedOnOther) + + # We need to return 'needed' peers only + if candidates.len > needed: + candidates.setLen(needed) + return candidates + proc edgeFilterSubLoop*(self: SubscriptionManager) {.async.} = ## Reconciles filter subscriptions with the desired state from SubscriptionManager. var lastSynced = initTable[PubsubTopic, HashSet[ContentTopic]]() @@ -382,6 +417,12 @@ proc edgeFilterSubLoop*(self: SubscriptionManager) {.async.} = let allShards = toHashSet(toSeq(desired.keys)) + toHashSet(toSeq(lastSynced.keys)) + # Step 1: read state across all shards at once and + # create a list of peer dial tasks and shard tracking to delete. + + var dialTasks: seq[EdgeDialTask] + var shardsToDelete: seq[PubsubTopic] + for shard in allShards: let currTopics = desired.getOrDefault(shard) let prevTopics = lastSynced.getOrDefault(shard) @@ -404,11 +445,7 @@ proc edgeFilterSubLoop*(self: SubscriptionManager) {.async.} = asyncSpawn self.syncFilterDeltas(peer, shard, addedTopics, removedTopics) if currTopics.len == 0: - for fut in state.pending: - if not fut.finished: - await fut.cancelAndWait() - self.edgeFilterSubStates.del(shard) - # invalidates `state` — do not use after this + shardsToDelete.add(shard) else: self.updateShardHealth(shard, state[]) @@ -416,11 +453,7 @@ proc edgeFilterSubLoop*(self: SubscriptionManager) {.async.} = if needed > 0: let tracked = state.peers.mapIt(it.peerId).toHashSet() + state.pendingPeers - var candidates = self.node.peerManager.selectPeers( - filter_common.WakuFilterSubscribeCodec, some(shard) - ) - candidates.keepItIf(it.peerId notin tracked) - + let candidates = self.selectFilterCandidates(shard, tracked, needed) let toDial = min(needed, candidates.len) trace "edgeFilterSubLoop: shard reconciliation", @@ -432,8 +465,25 @@ proc edgeFilterSubLoop*(self: SubscriptionManager) {.async.} = toDial = toDial for i in 0 ..< toDial: - let fut = self.dialFilterPeer(candidates[i], shard, toSeq(currTopics)) - state.pending.add(fut) + dialTasks.add( + EdgeDialTask( + peer: candidates[i], shard: shard, topics: toSeq(currTopics) + ) + ) + + # Step 2: execute deferred shard tracking deletion and dial tasks. + + for shard in shardsToDelete: + self.edgeFilterSubStates.withValue(shard, state): + for fut in state.pending: + if not fut.finished: + await fut.cancelAndWait() + self.edgeFilterSubStates.del(shard) + + for task in dialTasks: + let fut = self.dialFilterPeer(task.peer, task.shard, task.topics) + self.edgeFilterSubStates.withValue(task.shard, state): + state.pending.add(fut) lastSynced = desired