diff --git a/apps/wakunode2/app.nim b/apps/wakunode2/app.nim index 91863cc35..1abc200a2 100644 --- a/apps/wakunode2/app.nim +++ b/apps/wakunode2/app.nim @@ -541,12 +541,14 @@ proc startNode(node: WakuNode, conf: WakuNodeConf, proc startApp*(app: App): Future[AppResult[void]] {.async.} = if app.wakuDiscv5.isSome(): - let res = app.wakuDiscv5.get().start() + let wakuDiscv5 = app.wakuDiscv5.get() + let res = wakuDiscv5.start() if res.isErr(): return err("failed to start waku discovery v5: " & $res.error) - asyncSpawn app.wakuDiscv5.get().searchLoop(app.node.peerManager, some(app.record)) + asyncSpawn wakuDiscv5.searchLoop(app.node.peerManager, some(app.record)) + asyncSpawn wakuDiscv5.subscriptionsListener(app.node.topicSubscriptionQueue) return await startNode( app.node, diff --git a/tests/test_waku_discv5.nim b/tests/test_waku_discv5.nim index f3c58ad4d..a99380d99 100644 --- a/tests/test_waku_discv5.nim +++ b/tests/test_waku_discv5.nim @@ -5,10 +5,12 @@ import stew/results, stew/shims/net, chronos, + chronicles, testutils/unittests, libp2p/crypto/crypto as libp2p_keys, eth/keys as eth_keys import + ../../waku/waku_core/topics, ../../waku/waku_enr, ../../waku/waku_discv5, ./testlib/common, @@ -282,7 +284,7 @@ procSuite "Waku Discovery v5": let gibberish = @["aedyttydcb/uioasduyio", "jhdfsjhlsdfjhk/sadjhk", "khfsd/hjfdsgjh/dfs"] let empty: seq[string] = @[] - let relayShards = RelayShards.init(0, @[uint16(2), uint16(4), uint16(8)]) + let relayShards = RelayShards.init(0, @[uint16(2), uint16(4), uint16(8)]).expect("Valid Shards") ## When @@ -314,7 +316,7 @@ procSuite "Waku Discovery v5": shardCluster: uint16 = 21 shardIndices: seq[uint16] = @[1u16, 2u16, 5u16, 7u16, 9u16, 11u16] - let shards = RelayShards.init(shardCluster, shardIndices) + let shards = RelayShards.init(shardCluster, shardIndices).expect("Valid Shards") var builder = EnrBuilder.init(enrPrivKey, seqNum = enrSeqNum) require builder.withWakuRelaySharding(shards).isOk() @@ -332,7 +334,7 @@ procSuite "Waku Discovery v5": shardCluster: uint16 = 22 shardIndices: seq[uint16] = @[2u16, 4u16, 5u16, 8u16, 10u16, 12u16] - let shards = RelayShards.init(shardCluster, shardIndices) + let shards = RelayShards.init(shardCluster, shardIndices).expect("Valid Shards") var builder = EnrBuilder.init(enrPrivKey, seqNum = enrSeqNum) require builder.withWakuRelaySharding(shards).isOk() @@ -350,7 +352,7 @@ procSuite "Waku Discovery v5": shardCluster: uint16 = 22 shardIndices: seq[uint16] = @[1u16, 3u16, 6u16, 7u16, 9u16, 11u16] - let shards = RelayShards.init(shardCluster, shardIndices) + let shards = RelayShards.init(shardCluster, shardIndices).expect("Valid Shards") var builder = EnrBuilder.init(enrPrivKey, seqNum = enrSeqNum) require builder.withWakuRelaySharding(shards).isOk() @@ -377,4 +379,76 @@ procSuite "Waku Discovery v5": predicateCluster22(recordCluster22Indices1) == true predicateCluster22(recordCluster22Indices2) == false + asyncTest "update ENR from subscriptions": + ## Given + let + shard1 = "/waku/2/rs/0/1" + shard2 = "/waku/2/rs/0/2" + shard3 = "/waku/2/rs/0/3" + privKey = generateSecp256k1Key() + bindIp = "0.0.0.0" + extIp = "127.0.0.1" + tcpPort = 61500u16 + udpPort = 9000u16 + + let record = newTestEnrRecord( + privKey = privKey, + extIp = extIp, + tcpPort = tcpPort, + udpPort = udpPort, + ) + + let node = newTestDiscv5( + privKey = privKey, + bindIp = bindIp, + tcpPort = tcpPort, + udpPort = udpPort, + record = record + ) + + let res = node.start() + assert res.isOk(), res.error + + let queue = newAsyncEventQueue[SubscriptionEvent](0) + + ## When + asyncSpawn node.subscriptionsListener(queue) + + ## Then + queue.emit(SubscriptionEvent(kind: PubsubSub, pubsubSub: shard1)) + queue.emit(SubscriptionEvent(kind: PubsubSub, pubsubSub: shard2)) + queue.emit(SubscriptionEvent(kind: PubsubSub, pubsubSub: shard3)) + + await sleepAsync(1.seconds) + + check: + node.protocol.localNode.record.containsShard(shard1) == true + node.protocol.localNode.record.containsShard(shard2) == true + node.protocol.localNode.record.containsShard(shard3) == true + + queue.emit(SubscriptionEvent(kind: PubsubSub, pubsubSub: shard1)) + queue.emit(SubscriptionEvent(kind: PubsubSub, pubsubSub: shard2)) + queue.emit(SubscriptionEvent(kind: PubsubSub, pubsubSub: shard3)) + + await sleepAsync(1.seconds) + + check: + node.protocol.localNode.record.containsShard(shard1) == true + node.protocol.localNode.record.containsShard(shard2) == true + node.protocol.localNode.record.containsShard(shard3) == true + + queue.emit(SubscriptionEvent(kind: PubsubUnsub, pubsubUnsub: shard1)) + queue.emit(SubscriptionEvent(kind: PubsubUnsub, pubsubUnsub: shard2)) + queue.emit(SubscriptionEvent(kind: PubsubUnsub, pubsubUnsub: shard3)) + + await sleepAsync(1.seconds) + + check: + node.protocol.localNode.record.containsShard(shard1) == false + node.protocol.localNode.record.containsShard(shard2) == false + node.protocol.localNode.record.containsShard(shard3) == false + + ## Cleanup + await node.stop() + diff --git a/tests/test_waku_enr.nim b/tests/test_waku_enr.nim index a0b2bc1b8..aebd4e5ef 100644 --- a/tests/test_waku_enr.nim +++ b/tests/test_waku_enr.nim @@ -262,8 +262,10 @@ suite "Waku ENR - Relay static sharding": shardIndex: uint16 = 1024 ## When - expect Defect: - discard RelayShards.init(shardCluster, shardIndex) + let res = RelayShards.init(shardCluster, shardIndex) + + ## Then + assert res.isErr(), $res.get() test "new relay shards field with single invalid index in list": ## Given @@ -272,8 +274,10 @@ suite "Waku ENR - Relay static sharding": shardIndices: seq[uint16] = @[1u16, 1u16, 2u16, 3u16, 5u16, 8u16, 1024u16] ## When - expect Defect: - discard RelayShards.init(shardCluster, shardIndices) + let res = RelayShards.init(shardCluster, shardIndices) + + ## Then + assert res.isErr(), $res.get() test "new relay shards field with single valid index": ## Given @@ -284,7 +288,7 @@ suite "Waku ENR - Relay static sharding": let topic = NsPubsubTopic.staticSharding(shardCluster, shardIndex) ## When - let shards = RelayShards.init(shardCluster, shardIndex) + let shards = RelayShards.init(shardCluster, shardIndex).expect("Valid Shards") ## Then check: @@ -310,7 +314,7 @@ suite "Waku ENR - Relay static sharding": shardIndices: seq[uint16] = @[1u16, 2u16, 2u16, 3u16, 3u16, 3u16] ## When - let shards = RelayShards.init(shardCluster, shardIndices) + let shards = RelayShards.init(shardCluster, shardIndices).expect("Valid Shards") ## Then check: @@ -344,7 +348,7 @@ suite "Waku ENR - Relay static sharding": shardCluster: uint16 = 22 shardIndices: seq[uint16] = @[1u16, 1u16, 2u16, 3u16, 5u16, 8u16] - let shards = RelayShards.init(shardCluster, shardIndices) + let shards = RelayShards.init(shardCluster, shardIndices).expect("Valid Shards") ## When var builder = EnrBuilder.init(enrPrivKey, seqNum = enrSeqNum) @@ -370,7 +374,7 @@ suite "Waku ENR - Relay static sharding": enrSeqNum = 1u64 enrPrivKey = generatesecp256k1key() - let shards = RelayShards.init(33, toSeq(0u16 ..< 64u16)) + let shards = RelayShards.init(33, toSeq(0u16 ..< 64u16)).expect("Valid Shards") var builder = EnrBuilder.init(enrPrivKey, seqNum = enrSeqNum) require builder.withWakuRelaySharding(shards).isOk() @@ -398,8 +402,8 @@ suite "Waku ENR - Relay static sharding": enrPrivKey = generatesecp256k1key() let - shardsIndicesList = RelayShards.init(22, @[1u16, 1u16, 2u16, 3u16, 5u16, 8u16]) - shardsBitVector = RelayShards.init(33, @[13u16, 24u16, 37u16, 61u16, 98u16, 159u16]) + shardsIndicesList = RelayShards.init(22, @[1u16, 1u16, 2u16, 3u16, 5u16, 8u16]).expect("Valid Shards") + shardsBitVector = RelayShards.init(33, @[13u16, 24u16, 37u16, 61u16, 98u16, 159u16]).expect("Valid Shards") var builder = EnrBuilder.init(enrPrivKey, seqNum = enrSeqNum) diff --git a/waku/node/jsonrpc/relay/handlers.nim b/waku/node/jsonrpc/relay/handlers.nim index 8da726995..7a8d952e0 100644 --- a/waku/node/jsonrpc/relay/handlers.nim +++ b/waku/node/jsonrpc/relay/handlers.nim @@ -56,10 +56,9 @@ proc installRelayApiHandlers*(node: WakuNode, server: RpcServer, cache: MessageC debug "post_waku_v2_relay_v1_subscriptions" # Subscribe to all requested topics - for topic in topics: - if cache.isSubscribed(topic): - continue + let newTopics = topics.filterIt(not cache.isSubscribed(it)) + for topic in newTopics: cache.subscribe(topic) node.subscribe(topic, topicHandler) @@ -70,7 +69,9 @@ proc installRelayApiHandlers*(node: WakuNode, server: RpcServer, cache: MessageC debug "delete_waku_v2_relay_v1_subscriptions" # Unsubscribe all handlers from requested topics - for topic in topics: + let subscribedTopics = topics.filterIt(cache.isSubscribed(it)) + + for topic in subscribedTopics: node.unsubscribe(topic) cache.unsubscribe(topic) diff --git a/waku/node/rest/relay/handlers.nim b/waku/node/rest/relay/handlers.nim index 6274f685a..757972eec 100644 --- a/waku/node/rest/relay/handlers.nim +++ b/waku/node/rest/relay/handlers.nim @@ -55,13 +55,12 @@ proc installRelayPostSubscriptionsV1Handler*(router: var RestRouter, node: WakuN let req: RelayPostSubscriptionsRequest = reqResult.get() - for topic in req: - if cache.isSubscribed(string(topic)): - # Only subscribe to topics for which we have no subscribed topic handlers yet - continue + # Only subscribe to topics for which we have no subscribed topic handlers yet + let newTopics = req.filterIt(not cache.isSubscribed(it)) - cache.subscribe(string(topic)) - node.subscribe(string(topic), cache.messageHandler()) + for topic in newTopics: + cache.subscribe(topic) + node.subscribe(topic, cache.messageHandler()) return RestApiResponse.ok() @@ -88,8 +87,8 @@ proc installRelayDeleteSubscriptionsV1Handler*(router: var RestRouter, node: Wak # Unsubscribe all handlers from requested topics for topic in req: - node.unsubscribe(string(topic)) - cache.unsubscribe(string(topic)) + node.unsubscribe(topic) + cache.unsubscribe(topic) # Successfully unsubscribed from all requested topics return RestApiResponse.ok() diff --git a/waku/node/waku_node.nim b/waku/node/waku_node.nim index ade92de41..5df6200a1 100644 --- a/waku/node/waku_node.nim +++ b/waku/node/waku_node.nim @@ -103,6 +103,7 @@ type rendezvous*: RendezVous announcedAddresses* : seq[MultiAddress] started*: bool # Indicates that node has started listening + topicSubscriptionQueue*: AsyncEventQueue[SubscriptionEvent] proc getAutonatService*(rng: ref HmacDrbgContext): AutonatService = ## AutonatService request other peers to dial us back @@ -141,12 +142,15 @@ proc new*(T: type WakuNode, info "Initializing networking", addrs= $netConfig.announcedAddresses + let queue = newAsyncEventQueue[SubscriptionEvent](30) + return WakuNode( peerManager: peerManager, switch: switch, rng: rng, enr: enr, announcedAddresses: netConfig.announcedAddresses, + topicSubscriptionQueue: queue ) proc peerInfo*(node: WakuNode): PeerInfo = @@ -229,6 +233,7 @@ proc subscribe*(node: WakuNode, topic: PubsubTopic) = debug "subscribe", pubsubTopic= topic + node.topicSubscriptionQueue.emit(SubscriptionEvent(kind: PubsubSub, pubsubSub: topic)) node.registerRelayDefaultHandler(topic) proc subscribe*(node: WakuNode, topic: PubsubTopic, handler: WakuRelayHandler) = @@ -240,6 +245,7 @@ proc subscribe*(node: WakuNode, topic: PubsubTopic, handler: WakuRelayHandler) = debug "subscribe", pubsubTopic= topic + node.topicSubscriptionQueue.emit(SubscriptionEvent(kind: PubsubSub, pubsubSub: topic)) node.registerRelayDefaultHandler(topic) node.wakuRelay.subscribe(topic, handler) @@ -252,6 +258,7 @@ proc unsubscribe*(node: WakuNode, topic: PubsubTopic) = info "unsubscribe", pubsubTopic=topic + node.topicSubscriptionQueue.emit(SubscriptionEvent(kind: PubsubUnsub, pubsubUnsub: topic)) node.wakuRelay.unsubscribe(topic) diff --git a/waku/waku_core/topics.nim b/waku/waku_core/topics.nim index 08519ce6f..5c44aaf78 100644 --- a/waku/waku_core/topics.nim +++ b/waku/waku_core/topics.nim @@ -7,3 +7,12 @@ export content_topic, pubsub_topic, sharding + +type + SubscriptionKind* = enum ContentSub, ContentUnsub, PubsubSub, PubsubUnsub + SubscriptionEvent* = object + case kind*: SubscriptionKind + of PubsubSub: pubsubSub*: string + of ContentSub: contentSub*: string + of PubsubUnsub: pubsubUnsub*: string + of ContentUnsub: contentUnsub*: string \ No newline at end of file diff --git a/waku/waku_discv5.nim b/waku/waku_discv5.nim index 32406fb64..6f6a880d9 100644 --- a/waku/waku_discv5.nim +++ b/waku/waku_discv5.nim @@ -4,7 +4,7 @@ else: {.push raises: [].} import - std/[sequtils, strutils, options], + std/[sequtils, strutils, options, sugar, sets], stew/results, stew/shims/net, chronos, @@ -122,6 +122,79 @@ proc new*(T: type WakuDiscoveryV5, WakuDiscoveryV5.new(rng, conf, some(record)) +proc updateENRShards(wd: WakuDiscoveryV5, + newTopics: seq[PubsubTopic], add: bool): Result[void, string] = + ## Add or remove shards from the Discv5 ENR + + let newShardOp = ?topicsToRelayShards(newTopics) + + let newShard = + if newShardOp.isSome(): + newShardOp.get() + else: + return ok() + + let typedRecordRes = wd.protocol.localNode.record.toTyped() + let typedRecord = + if typedRecordRes.isErr(): + return err($typedRecordRes.error) + else: + typedRecordRes.get() + + let currentShardsOp = typedRecord.relaySharding() + + let resultShard = + if add and currentShardsOp.isSome(): + let currentShard = currentShardsOp.get() + + if currentShard.cluster != newShard.cluster: + return err("ENR are limited to one shard cluster") + + ?RelayShards.init(currentShard.cluster, currentShard.indices & newShard.indices) + elif not add and currentShardsOp.isSome(): + let currentShard = currentShardsOp.get() + + if currentShard.cluster != newShard.cluster: + return err("ENR are limited to one shard cluster") + + let currentSet = toHashSet(currentShard.indices) + let newSet = toHashSet(newShard.indices) + + let indices = toSeq(currentSet - newSet) + + if indices.len == 0: + # Can't create RelayShard with no indices so update then return + let (field, value) = (ShardingIndicesListEnrField, newSeq[byte](3)) + + let res = wd.protocol.updateRecord([(field, value)]) + if res.isErr(): + return err($res.error) + + return ok() + + ?RelayShards.init(currentShard.cluster, indices) + elif add and currentShardsOp.isNone(): newShard + else: return ok() + + let (field, value) = + if resultShard.indices.len >= ShardingIndicesListMaxLength: + (ShardingBitVectorEnrField, resultShard.toBitVector()) + else: + let listRes = resultShard.toIndicesList() + let list = + if listRes.isErr(): + return err($listRes.error) + else: + listRes.get() + + (ShardingIndicesListEnrField, list) + + let res = wd.protocol.updateRecord([(field, value)]) + if res.isErr(): + return err($res.error) + + return ok() + proc shardingPredicate*(record: Record): Option[WakuDiscv5Predicate] = ## Filter peers based on relay sharding information @@ -219,6 +292,33 @@ proc stop*(wd: WakuDiscoveryV5): Future[void] {.async.} = debug "Successfully stopped discovery v5 service" +proc subscriptionsListener*(wd: WakuDiscoveryV5, topicSubscriptionQueue: AsyncEventQueue[SubscriptionEvent]) {.async.} = + ## Listen for pubsub topics subscriptions changes + + let key = topicSubscriptionQueue.register() + + while wd.listening: + let events = await topicSubscriptionQueue.waitEvents(key) + + # Since we don't know the events we will receive we have to anticipate. + + let subs = events.filterIt(it.kind == SubscriptionKind.PubsubSub).mapIt(it.pubsubSub) + let unsubs = events.filterIt(it.kind == SubscriptionKind.PubsubUnsub).mapIt(it.pubsubUnsub) + + let unsubRes = wd.updateENRShards(unsubs, false) + let subRes = wd.updateENRShards(subs, true) + + if subRes.isErr(): + debug "ENR shard addition failed", reason= $subRes.error + + if unsubRes.isErr(): + debug "ENR shard removal failed", reason= $unsubRes.error + + if subRes.isOk() and unsubRes.isOk(): + debug "ENR updated successfully" + + topicSubscriptionQueue.unregister(key) + ## Helper functions proc parseBootstrapAddress(address: string): Result[enr.Record, cstring] = diff --git a/waku/waku_enr/sharding.nim b/waku/waku_enr/sharding.nim index e431b0c59..17954803c 100644 --- a/waku/waku_enr/sharding.nim +++ b/waku/waku_enr/sharding.nim @@ -23,6 +23,7 @@ const MaxShardIndex: uint16 = 1023 const ShardingIndicesListEnrField* = "rs" + ShardingIndicesListMaxLength* = 64 ShardingBitVectorEnrField* = "rsv" @@ -42,31 +43,31 @@ func topics*(rs: RelayShards): seq[NsPubsubTopic] = rs.indices.mapIt(NsPubsubTopic.staticSharding(rs.cluster, it)) -func init*(T: type RelayShards, cluster, index: uint16): T = +func init*(T: type RelayShards, cluster, index: uint16): Result[T, string] = if index > MaxShardIndex: - raise newException(Defect, "invalid index") + return err("invalid index") - RelayShards(cluster: cluster, indices: @[index]) + ok(RelayShards(cluster: cluster, indices: @[index])) -func init*(T: type RelayShards, cluster: uint16, indices: varargs[uint16]): T = +func init*(T: type RelayShards, cluster: uint16, indices: varargs[uint16]): Result[T, string] = if toSeq(indices).anyIt(it > MaxShardIndex): - raise newException(Defect, "invalid index") + return err("invalid index") let indicesSeq = deduplicate(@indices) if indices.len < 1: - raise newException(Defect, "invalid index count") + return err("invalid index count") - RelayShards(cluster: cluster, indices: indicesSeq) + ok(RelayShards(cluster: cluster, indices: indicesSeq)) -func init*(T: type RelayShards, cluster: uint16, indices: seq[uint16]): T = +func init*(T: type RelayShards, cluster: uint16, indices: seq[uint16]): Result[T, string] = if indices.anyIt(it > MaxShardIndex): - raise newException(Defect, "invalid index") + return err("invalid index") let indicesSeq = deduplicate(indices) if indices.len < 1: - raise newException(Defect, "invalid index count") + return err("invalid index count") - RelayShards(cluster: cluster, indices: indicesSeq) + ok(RelayShards(cluster: cluster, indices: indicesSeq)) func topicsToRelayShards*(topics: seq[string]): Result[Option[RelayShards], string] = if topics.len < 1: @@ -87,7 +88,9 @@ func topicsToRelayShards*(topics: seq[string]): Result[Option[RelayShards], stri if parsedTopicsRes.anyIt(it.get().cluster != parsedTopicsRes[0].get().cluster): return err("use sharded topics within the same cluster.") - return ok(some(RelayShards.init(parsedTopicsRes[0].get().cluster, parsedTopicsRes.mapIt(it.get().shard)))) + let relayShard = ?RelayShards.init(parsedTopicsRes[0].get().cluster, parsedTopicsRes.mapIt(it.get().shard)) + + return ok(some(relayShard)) func contains*(rs: RelayShards, cluster, index: uint16): bool = rs.cluster == cluster and rs.indices.contains(index) @@ -108,7 +111,7 @@ func contains*(rs: RelayShards, topic: PubsubTopic|string): bool = # ENR builder extension -func toIndicesList(rs: RelayShards): EnrResult[seq[byte]] = +func toIndicesList*(rs: RelayShards): EnrResult[seq[byte]] = if rs.indices.len > high(uint8).int: return err("indices list too long") @@ -137,7 +140,7 @@ func fromIndicesList(buf: seq[byte]): Result[RelayShards, string] = ok(RelayShards(cluster: cluster, indices: indices)) -func toBitVector(rs: RelayShards): seq[byte] = +func toBitVector*(rs: RelayShards): seq[byte] = ## The value is comprised of a two-byte shard cluster index in network byte ## order concatenated with a 128-byte wide bit vector. The bit vector ## indicates which shards of the respective shard cluster the node is part @@ -182,7 +185,7 @@ func withWakuRelayShardingBitVector*(builder: var EnrBuilder, rs: RelayShards): ok() func withWakuRelaySharding*(builder: var EnrBuilder, rs: RelayShards): EnrResult[void] = - if rs.indices.len >= 64: + if rs.indices.len >= ShardingIndicesListMaxLength: builder.withWakuRelayShardingBitVector(rs) else: builder.withWakuRelayShardingIndicesList(rs) @@ -264,7 +267,7 @@ proc containsShard*(r: Record, topic: NsPubsubTopic): bool = containsShard(r, topic.cluster, topic.shard) -func containsShard*(r: Record, topic: PubsubTopic|string): bool = +proc containsShard*(r: Record, topic: PubsubTopic|string): bool = let parseRes = NsPubsubTopic.parse(topic) if parseRes.isErr(): debug "invalid static sharding topic", topic = topic, error = parseRes.error