From 3db4149c230ee2ba0e3332740385cb3bedc36aab Mon Sep 17 00:00:00 2001 From: Prem Chaitanya Prathi Date: Fri, 29 Aug 2025 18:43:29 +0530 Subject: [PATCH] get shards using callback approach (#3545) --- tests/test_peer_manager.nim | 9 ++++ waku/common/callbacks.nim | 1 + waku/node/peer_manager/peer_manager.nim | 27 +++++++----- waku/node/waku_node.nim | 27 ++++++++---- waku/waku_metadata/protocol.nim | 56 ++++++------------------- 5 files changed, 59 insertions(+), 61 deletions(-) create mode 100644 waku/common/callbacks.nim diff --git a/tests/test_peer_manager.nim b/tests/test_peer_manager.nim index c2639a7c1..1369f3f88 100644 --- a/tests/test_peer_manager.nim +++ b/tests/test_peer_manager.nim @@ -625,6 +625,15 @@ procSuite "Peer Manager": await allFutures(nodes.mapIt(it.mountRelay())) await allFutures(nodes.mapIt(it.start())) + proc simpleHandler( + topic: PubsubTopic, msg: WakuMessage + ): Future[void] {.async, gcsafe.} = + await sleepAsync(0.millis) + + let topic = "/waku/2/rs/0/0" + for node in nodes: + node.wakuRelay.subscribe(topic, simpleHandler) + # Get all peer infos let peerInfos = collect: for i in 0 .. nodes.high: diff --git a/waku/common/callbacks.nim b/waku/common/callbacks.nim new file mode 100644 index 000000000..d1da48067 --- /dev/null +++ b/waku/common/callbacks.nim @@ -0,0 +1 @@ +type GetShards* = proc(): seq[uint16] {.closure, gcsafe, raises: [].} diff --git a/waku/node/peer_manager/peer_manager.nim b/waku/node/peer_manager/peer_manager.nim index 7fcd34a60..e4ac119db 100644 --- a/waku/node/peer_manager/peer_manager.nim +++ b/waku/node/peer_manager/peer_manager.nim @@ -13,6 +13,7 @@ import import ../../common/nimchronos, ../../common/enr, + ../../common/callbacks, ../../common/utils/parse_size_units, ../../waku_core, ../../waku_relay, @@ -99,6 +100,7 @@ type PeerManager* = ref object of RootObj shardedPeerManagement: bool # temp feature flag onConnectionChange*: ConnectionChangeHandler online: bool ## state managed by online_monitor module + getShards: GetShards #~~~~~~~~~~~~~~~~~~~# # Helper Functions # @@ -769,12 +771,12 @@ proc logAndMetrics(pm: PeerManager) {.async.} = protoStreamsOut.float64, labelValues = [$Direction.Out, proto] ) - for shard in pm.wakuMetadata.shards.items: + for shard in pm.getShards().items: waku_connected_peers_per_shard.set(0.0, labelValues = [$shard]) - for shard in pm.wakuMetadata.shards.items: + for shard in pm.getShards().items: let connectedPeers = - peerStore.getPeersByShard(uint16(pm.wakuMetadata.clusterId), uint16(shard)) + peerStore.getPeersByShard(uint16(pm.wakuMetadata.clusterId), shard) waku_connected_peers_per_shard.set( connectedPeers.len.float64, labelValues = [$shard] ) @@ -788,8 +790,9 @@ proc getOnlineStateObserver*(pm: PeerManager): OnOnlineStateChange = #~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~# proc manageRelayPeers*(pm: PeerManager) {.async.} = + let shardsCount = pm.getShards().len #TODO: this check should not be based on whether shards are present, but rather if relay is mounted - if pm.wakuMetadata.shards.len == 0: + if shardsCount == 0: return if not pm.online: @@ -803,19 +806,18 @@ proc manageRelayPeers*(pm: PeerManager) {.async.} = var (inPeers, outPeers) = pm.connectedPeers(WakuRelayCodec) # Calculate in/out target number of peers for each shards - let inTarget = pm.inRelayPeersTarget div pm.wakuMetadata.shards.len - let outTarget = pm.outRelayPeersTarget div pm.wakuMetadata.shards.len + let inTarget = pm.inRelayPeersTarget div shardsCount + let outTarget = pm.outRelayPeersTarget div shardsCount var peerStore = pm.switch.peerStore - for shard in pm.wakuMetadata.shards.items: + for shard in pm.getShards().items: # Filter out peer not on this shard - let connectedInPeers = inPeers.filterIt( - peerStore.hasShard(it, uint16(pm.wakuMetadata.clusterId), uint16(shard)) - ) + let connectedInPeers = + inPeers.filterIt(peerStore.hasShard(it, uint16(pm.wakuMetadata.clusterId), shard)) let connectedOutPeers = outPeers.filterIt( - peerStore.hasShard(it, uint16(pm.wakuMetadata.clusterId), uint16(shard)) + peerStore.hasShard(it, uint16(pm.wakuMetadata.clusterId), shard) ) # Calculate the difference between current values and targets @@ -1001,6 +1003,9 @@ proc addExtPeerEventHandler*( # Initialization and Constructor # #~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~# +proc setShardGetter*(pm: PeerManager, c: GetShards) = + pm.getShards = c + proc start*(pm: PeerManager) = pm.started = true asyncSpawn pm.relayConnectivityLoop() diff --git a/waku/node/waku_node.nim b/waku/node/waku_node.nim index a59fff522..50cba98dc 100644 --- a/waku/node/waku_node.nim +++ b/waku/node/waku_node.nim @@ -48,7 +48,8 @@ import ../waku_rln_relay, ./net_config, ./peer_manager, - ../common/rate_limit/setting + ../common/rate_limit/setting, + ../common/callbacks declarePublicCounter waku_node_messages, "number of messages received", ["type"] declarePublicHistogram waku_histogram_message_size, @@ -123,6 +124,20 @@ type topicSubscriptionQueue*: AsyncEventQueue[SubscriptionEvent] rateLimitSettings*: ProtocolRateLimitSettings +proc getShardsGetter(node: WakuNode): GetShards = + return proc(): seq[uint16] {.closure, gcsafe, raises: [].} = + # fetch pubsubTopics subscribed to relay and convert them to shards + if node.wakuRelay.isNil(): + return @[] + let subTopics = node.wakuRelay.subscribedTopics() + let relayShards = topicsToRelayShards(subTopics).valueOr: + error "could not convert relay topics to shards", error = $error + return @[] + if relayShards.isSome(): + let shards = relayShards.get().shardIds + return shards + return @[] + proc new*( T: type WakuNode, netConfig: NetConfig, @@ -138,7 +153,6 @@ proc new*( info "Initializing networking", addrs = $netConfig.announcedAddresses let queue = newAsyncEventQueue[SubscriptionEvent](0) - let node = WakuNode( peerManager: peerManager, switch: switch, @@ -149,6 +163,8 @@ proc new*( rateLimitSettings: rateLimitSettings, ) + peerManager.setShardGetter(node.getShardsGetter()) + return node proc peerInfo*(node: WakuNode): PeerInfo = @@ -182,16 +198,13 @@ proc connectToNodes*( proc disconnectNode*(node: WakuNode, remotePeer: RemotePeerInfo) {.async.} = await peer_manager.disconnectNode(node.peerManager, remotePeer) -## Waku Metadata - proc mountMetadata*( node: WakuNode, clusterId: uint32, shards: seq[uint16] ): Result[void, string] = if not node.wakuMetadata.isNil(): return err("Waku metadata already mounted, skipping") - let shards32 = shards.mapIt(it.uint32) - let metadata = - WakuMetadata.new(clusterId, shards32.toHashSet(), some(node.topicSubscriptionQueue)) + + let metadata = WakuMetadata.new(clusterId, node.getShardsGetter()) node.wakuMetadata = metadata node.peerManager.wakuMetadata = metadata diff --git a/waku/waku_metadata/protocol.nim b/waku/waku_metadata/protocol.nim index 0112fd45e..01aaf027c 100644 --- a/waku/waku_metadata/protocol.nim +++ b/waku/waku_metadata/protocol.nim @@ -10,7 +10,7 @@ import libp2p/stream/connection, libp2p/crypto/crypto, eth/p2p/discoveryv5/enr -import ../common/nimchronos, ../waku_core, ./rpc +import ../common/nimchronos, ../waku_core, ./rpc, ../common/callbacks from ../waku_core/codecs import WakuMetadataCodec export WakuMetadataCodec @@ -22,14 +22,14 @@ const RpcResponseMaxBytes* = 1024 type WakuMetadata* = ref object of LPProtocol clusterId*: uint32 - shards*: HashSet[uint32] - topicSubscriptionQueue: Option[AsyncEventQueue[SubscriptionEvent]] + getShards: GetShards proc respond( m: WakuMetadata, conn: Connection ): Future[Result[void, string]] {.async, gcsafe.} = - let response = - WakuMetadataResponse(clusterId: some(m.clusterId.uint32), shards: toSeq(m.shards)) + let response = WakuMetadataResponse( + clusterId: some(m.clusterId.uint32), shards: m.getShards().mapIt(it.uint32) + ) let res = catch: await conn.writeLP(response.encode().buffer) @@ -41,8 +41,9 @@ proc respond( proc request*( m: WakuMetadata, conn: Connection ): Future[Result[WakuMetadataResponse, string]] {.async, gcsafe.} = - let request = - WakuMetadataRequest(clusterId: some(m.clusterId), shards: toSeq(m.shards)) + let request = WakuMetadataRequest( + clusterId: some(m.clusterId), shards: m.getShards().mapIt(it.uint32) + ) let writeRes = catch: await conn.writeLP(request.encode().buffer) @@ -89,7 +90,7 @@ proc initProtocolHandler(m: WakuMetadata) = remoteClusterId = response.clusterId, remoteShards = response.shards, localClusterId = m.clusterId, - localShards = m.shards, + localShards = m.getShards(), peer = conn.peerId try: @@ -101,49 +102,18 @@ proc initProtocolHandler(m: WakuMetadata) = m.handler = handler m.codec = WakuMetadataCodec -proc new*( - T: type WakuMetadata, - clusterId: uint32, - shards: HashSet[uint32], - queue: Option[AsyncEventQueue[SubscriptionEvent]], -): T = - let wm = - WakuMetadata(clusterId: clusterId, shards: shards, topicSubscriptionQueue: queue) +proc new*(T: type WakuMetadata, clusterId: uint32, getShards: GetShards): T = + let wm = WakuMetadata(clusterId: clusterId, getShards: getShards) wm.initProtocolHandler() - info "Created WakuMetadata protocol", clusterId = wm.clusterId, shards = wm.shards + info "Created WakuMetadata protocol", + clusterId = wm.clusterId, shards = wm.getShards() return wm -proc subscriptionsListener(wm: WakuMetadata) {.async.} = - ## Listen for pubsub topics subscriptions changes - if wm.topicSubscriptionQueue.isSome(): - let key = wm.topicSubscriptionQueue.get().register() - - while wm.started: - let events = await wm.topicSubscriptionQueue.get().waitEvents(key) - - for event in events: - let parsedShard = RelayShard.parse(event.topic).valueOr: - continue - - if parsedShard.clusterId != wm.clusterId: - continue - - case event.kind - of PubsubSub: - wm.shards.incl(parsedShard.shardId) - of PubsubUnsub: - wm.shards.excl(parsedShard.shardId) - else: - continue - - wm.topicSubscriptionQueue.get().unregister(key) - proc start*(wm: WakuMetadata) = wm.started = true - asyncSpawn wm.subscriptionsListener() proc stop*(wm: WakuMetadata) = wm.started = false