diff --git a/.github/ISSUE_TEMPLATE/prepare_beta_release.md b/.github/ISSUE_TEMPLATE/prepare_beta_release.md index 9afaefbd1..383d9018c 100644 --- a/.github/ISSUE_TEMPLATE/prepare_beta_release.md +++ b/.github/ISSUE_TEMPLATE/prepare_beta_release.md @@ -22,7 +22,7 @@ All items below are to be completed by the owner of the given release. - [ ] Generate and edit release notes in CHANGELOG.md. - [ ] **Waku test and fleets validation** - - [ ] Ensure all the unit tests (specifically js-waku tests) are green against the release candidate. + - [ ] Ensure all the unit tests (specifically logos-messaging-js tests) are green against the release candidate. - [ ] Deploy the release candidate to `waku.test` only through [deploy-waku-test job](https://ci.infra.status.im/job/nim-waku/job/deploy-waku-test/) and wait for it to finish (Jenkins access required; ask the infra team if you don't have it). - After completion, disable [deployment job](https://ci.infra.status.im/job/nim-waku/) so that its version is not updated on every merge to master. - Verify the deployed version at https://fleets.waku.org/. diff --git a/.github/ISSUE_TEMPLATE/prepare_full_release.md b/.github/ISSUE_TEMPLATE/prepare_full_release.md index 314146f60..d7458a8e3 100644 --- a/.github/ISSUE_TEMPLATE/prepare_full_release.md +++ b/.github/ISSUE_TEMPLATE/prepare_full_release.md @@ -24,7 +24,7 @@ All items below are to be completed by the owner of the given release. - [ ] **Validation of release candidate** - [ ] **Automated testing** - - [ ] Ensure all the unit tests (specifically js-waku tests) are green against the release candidate. + - [ ] Ensure all the unit tests (specifically logos-messaging-js tests) are green against the release candidate. - [ ] Ask Vac-QA and Vac-DST to perform the available tests against the release candidate. - [ ] Vac-DST (an additional report is needed; see [this](https://www.notion.so/DST-Reports-1228f96fb65c80729cd1d98a7496fe6f)) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 2b12a5109..da8383e43 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -151,14 +151,14 @@ jobs: js-waku-node: needs: build-docker-image - uses: logos-messaging/js-waku/.github/workflows/test-node.yml@master + uses: logos-messaging/logos-messaging-js/.github/workflows/test-node.yml@master with: nim_wakunode_image: ${{ needs.build-docker-image.outputs.image }} test_type: node js-waku-node-optional: needs: build-docker-image - uses: logos-messaging/js-waku/.github/workflows/test-node.yml@master + uses: logos-messaging/logos-messaging-js/.github/workflows/test-node.yml@master with: nim_wakunode_image: ${{ needs.build-docker-image.outputs.image }} test_type: node-optional diff --git a/.github/workflows/pre-release.yml b/.github/workflows/pre-release.yml index 380ec755f..faded198b 100644 --- a/.github/workflows/pre-release.yml +++ b/.github/workflows/pre-release.yml @@ -98,7 +98,7 @@ jobs: js-waku-node: needs: build-docker-image - uses: logos-messaging/js-waku/.github/workflows/test-node.yml@master + uses: logos-messaging/logos-messaging-js/.github/workflows/test-node.yml@master with: nim_wakunode_image: ${{ needs.build-docker-image.outputs.image }} test_type: node @@ -106,7 +106,7 @@ jobs: js-waku-node-optional: needs: build-docker-image - uses: logos-messaging/js-waku/.github/workflows/test-node.yml@master + uses: logos-messaging/logos-messaging-js/.github/workflows/test-node.yml@master with: nim_wakunode_image: ${{ needs.build-docker-image.outputs.image }} test_type: node-optional diff --git a/.gitmodules b/.gitmodules index 4d56c4333..6a63491e3 100644 --- a/.gitmodules +++ b/.gitmodules @@ -184,6 +184,12 @@ url = https://github.com/logos-messaging/waku-rlnv2-contract.git ignore = untracked branch = master +[submodule "vendor/nim-lsquic"] + path = vendor/nim-lsquic + url = https://github.com/vacp2p/nim-lsquic +[submodule "vendor/nim-jwt"] + path = vendor/nim-jwt + url = https://github.com/vacp2p/nim-jwt.git [submodule "vendor/nim-ffi"] path = vendor/nim-ffi url = https://github.com/logos-messaging/nim-ffi/ diff --git a/README.md b/README.md index ce352d6f5..c64479738 100644 --- a/README.md +++ b/README.md @@ -1,19 +1,21 @@ -# Nwaku +# Logos Messaging Nim ## Introduction -The nwaku repository implements Waku, and provides tools related to it. +The logos-messaging-nim, a.k.a. lmn or nwaku, repository implements a set of libp2p protocols aimed to bring +private communications. -- A Nim implementation of the [Waku (v2) protocol](https://specs.vac.dev/specs/waku/v2/waku-v2.html). -- CLI application `wakunode2` that allows you to run a Waku node. -- Examples of Waku usage. +- Nim implementation of [these specs](https://github.com/vacp2p/rfc-index/tree/main/waku). +- C library that exposes the implemented protocols. +- CLI application that allows you to run an lmn node. +- Examples. - Various tests of above. For more details see the [source code](waku/README.md) ## How to Build & Run ( Linux, MacOS & WSL ) -These instructions are generic. For more detailed instructions, see the Waku source code above. +These instructions are generic. For more detailed instructions, see the source code above. ### Prerequisites diff --git a/tests/common/test_all.nim b/tests/common/test_all.nim index 7495c7c9e..d597a7424 100644 --- a/tests/common/test_all.nim +++ b/tests/common/test_all.nim @@ -6,7 +6,6 @@ import ./test_protobuf_validation, ./test_sqlite_migrations, ./test_parse_size, - ./test_tokenbucket, ./test_requestratelimiter, ./test_ratelimit_setting, ./test_timed_map, diff --git a/tests/common/test_event_broker.nim b/tests/common/test_event_broker.nim index cead1277f..bcd081f4f 100644 --- a/tests/common/test_event_broker.nim +++ b/tests/common/test_event_broker.nim @@ -4,6 +4,15 @@ import testutils/unittests import waku/common/broker/event_broker +type ExternalDefinedEventType = object + label*: string + +EventBroker: + type IntEvent = int + +EventBroker: + type ExternalAliasEvent = distinct ExternalDefinedEventType + EventBroker: type SampleEvent = object value*: int @@ -123,3 +132,70 @@ suite "EventBroker": check counter == 21 # 1+2+3 + 4+5+6 RefEvent.dropAllListeners() + + test "supports BrokerContext-scoped listeners": + SampleEvent.dropAllListeners() + + let ctxA = NewBrokerContext() + let ctxB = NewBrokerContext() + + var seenA: seq[int] = @[] + var seenB: seq[int] = @[] + + discard SampleEvent.listen( + ctxA, + proc(evt: SampleEvent): Future[void] {.async: (raises: []).} = + seenA.add(evt.value), + ) + + discard SampleEvent.listen( + ctxB, + proc(evt: SampleEvent): Future[void] {.async: (raises: []).} = + seenB.add(evt.value), + ) + + SampleEvent.emit(ctxA, SampleEvent(value: 1, label: "a")) + SampleEvent.emit(ctxB, SampleEvent(value: 2, label: "b")) + waitForListeners() + + check seenA == @[1] + check seenB == @[2] + + SampleEvent.dropAllListeners(ctxA) + SampleEvent.emit(ctxA, SampleEvent(value: 3, label: "a2")) + SampleEvent.emit(ctxB, SampleEvent(value: 4, label: "b2")) + waitForListeners() + + check seenA == @[1] + check seenB == @[2, 4] + + SampleEvent.dropAllListeners(ctxB) + + test "supports non-object event types (auto-distinct)": + var seen: seq[int] = @[] + + discard IntEvent.listen( + proc(evt: IntEvent): Future[void] {.async: (raises: []).} = + seen.add(int(evt)) + ) + + IntEvent.emit(IntEvent(42)) + waitForListeners() + + check seen == @[42] + IntEvent.dropAllListeners() + + test "supports externally-defined type aliases (auto-distinct)": + var seen: seq[string] = @[] + + discard ExternalAliasEvent.listen( + proc(evt: ExternalAliasEvent): Future[void] {.async: (raises: []).} = + let base = ExternalDefinedEventType(evt) + seen.add(base.label) + ) + + ExternalAliasEvent.emit(ExternalAliasEvent(ExternalDefinedEventType(label: "x"))) + waitForListeners() + + check seen == @["x"] + ExternalAliasEvent.dropAllListeners() diff --git a/tests/common/test_multi_request_broker.nim b/tests/common/test_multi_request_broker.nim index 3bf10a54d..39ed90eea 100644 --- a/tests/common/test_multi_request_broker.nim +++ b/tests/common/test_multi_request_broker.nim @@ -31,6 +31,23 @@ MultiRequestBroker: suffix: string ): Future[Result[DualResponse, string]] {.async.} +type ExternalBaseType = string + +MultiRequestBroker: + type NativeIntResponse = int + + proc signatureFetch*(): Future[Result[NativeIntResponse, string]] {.async.} + +MultiRequestBroker: + type ExternalAliasResponse = ExternalBaseType + + proc signatureFetch*(): Future[Result[ExternalAliasResponse, string]] {.async.} + +MultiRequestBroker: + type AlreadyDistinctResponse = distinct int + + proc signatureFetch*(): Future[Result[AlreadyDistinctResponse, string]] {.async.} + suite "MultiRequestBroker": test "aggregates zero-argument providers": discard NoArgResponse.setProvider( @@ -194,7 +211,6 @@ suite "MultiRequestBroker": let firstHandler = NoArgResponse.setProvider( proc(): Future[Result[NoArgResponse, string]] {.async.} = raise newException(ValueError, "first handler raised") - ok(NoArgResponse(label: "any")) ) discard NoArgResponse.setProvider( @@ -211,6 +227,99 @@ suite "MultiRequestBroker": test "ref providers returning nil fail request": DualResponse.clearProviders() + test "supports native request types": + NativeIntResponse.clearProviders() + + discard NativeIntResponse.setProvider( + proc(): Future[Result[NativeIntResponse, string]] {.async.} = + ok(NativeIntResponse(1)) + ) + + discard NativeIntResponse.setProvider( + proc(): Future[Result[NativeIntResponse, string]] {.async.} = + ok(NativeIntResponse(2)) + ) + + let res = waitFor NativeIntResponse.request() + check res.isOk() + check res.get().len == 2 + check res.get().anyIt(int(it) == 1) + check res.get().anyIt(int(it) == 2) + + NativeIntResponse.clearProviders() + + test "supports external request types": + ExternalAliasResponse.clearProviders() + + discard ExternalAliasResponse.setProvider( + proc(): Future[Result[ExternalAliasResponse, string]] {.async.} = + ok(ExternalAliasResponse("hello")) + ) + + let res = waitFor ExternalAliasResponse.request() + check res.isOk() + check res.get().len == 1 + check ExternalBaseType(res.get()[0]) == "hello" + + ExternalAliasResponse.clearProviders() + + test "supports already-distinct request types": + AlreadyDistinctResponse.clearProviders() + + discard AlreadyDistinctResponse.setProvider( + proc(): Future[Result[AlreadyDistinctResponse, string]] {.async.} = + ok(AlreadyDistinctResponse(7)) + ) + + let res = waitFor AlreadyDistinctResponse.request() + check res.isOk() + check res.get().len == 1 + check int(res.get()[0]) == 7 + + AlreadyDistinctResponse.clearProviders() + + test "context-aware providers are isolated": + NoArgResponse.clearProviders() + let ctxA = NewBrokerContext() + let ctxB = NewBrokerContext() + + discard NoArgResponse.setProvider( + ctxA, + proc(): Future[Result[NoArgResponse, string]] {.async.} = + ok(NoArgResponse(label: "a")), + ) + discard NoArgResponse.setProvider( + ctxB, + proc(): Future[Result[NoArgResponse, string]] {.async.} = + ok(NoArgResponse(label: "b")), + ) + + let resA = waitFor NoArgResponse.request(ctxA) + check resA.isOk() + check resA.get().len == 1 + check resA.get()[0].label == "a" + + let resB = waitFor NoArgResponse.request(ctxB) + check resB.isOk() + check resB.get().len == 1 + check resB.get()[0].label == "b" + + let resDefault = waitFor NoArgResponse.request() + check resDefault.isOk() + check resDefault.get().len == 0 + + NoArgResponse.clearProviders(ctxA) + let clearedA = waitFor NoArgResponse.request(ctxA) + check clearedA.isOk() + check clearedA.get().len == 0 + + let stillB = waitFor NoArgResponse.request(ctxB) + check stillB.isOk() + check stillB.get().len == 1 + check stillB.get()[0].label == "b" + + NoArgResponse.clearProviders(ctxB) + discard DualResponse.setProvider( proc(): Future[Result[DualResponse, string]] {.async.} = let nilResponse: DualResponse = nil diff --git a/tests/common/test_request_broker.nim b/tests/common/test_request_broker.nim index a534216dc..87065a916 100644 --- a/tests/common/test_request_broker.nim +++ b/tests/common/test_request_broker.nim @@ -203,6 +203,104 @@ suite "RequestBroker macro (async mode)": DualResponse.clearProvider() + test "supports keyed providers (async, zero-arg)": + SimpleResponse.clearProvider() + + check SimpleResponse + .setProvider( + proc(): Future[Result[SimpleResponse, string]] {.async.} = + ok(SimpleResponse(value: "default")) + ) + .isOk() + + check SimpleResponse + .setProvider( + BrokerContext(0x11111111'u32), + proc(): Future[Result[SimpleResponse, string]] {.async.} = + ok(SimpleResponse(value: "one")), + ) + .isOk() + + check SimpleResponse + .setProvider( + BrokerContext(0x22222222'u32), + proc(): Future[Result[SimpleResponse, string]] {.async.} = + ok(SimpleResponse(value: "two")), + ) + .isOk() + + let defaultRes = waitFor SimpleResponse.request() + check defaultRes.isOk() + check defaultRes.value.value == "default" + + let res1 = waitFor SimpleResponse.request(BrokerContext(0x11111111'u32)) + check res1.isOk() + check res1.value.value == "one" + + let res2 = waitFor SimpleResponse.request(BrokerContext(0x22222222'u32)) + check res2.isOk() + check res2.value.value == "two" + + let missing = waitFor SimpleResponse.request(BrokerContext(0x33333333'u32)) + check missing.isErr() + check missing.error.contains("no provider registered for broker context") + + check SimpleResponse + .setProvider( + BrokerContext(0x11111111'u32), + proc(): Future[Result[SimpleResponse, string]] {.async.} = + ok(SimpleResponse(value: "dup")), + ) + .isErr() + + SimpleResponse.clearProvider() + + test "supports keyed providers (async, with args)": + KeyedResponse.clearProvider() + + check KeyedResponse + .setProvider( + proc(key: string, subKey: int): Future[Result[KeyedResponse, string]] {.async.} = + ok(KeyedResponse(key: "default-" & key, payload: $subKey)) + ) + .isOk() + + check KeyedResponse + .setProvider( + BrokerContext(0xABCDEF01'u32), + proc(key: string, subKey: int): Future[Result[KeyedResponse, string]] {.async.} = + ok(KeyedResponse(key: "k1-" & key, payload: "p" & $subKey)), + ) + .isOk() + + check KeyedResponse + .setProvider( + BrokerContext(0xABCDEF02'u32), + proc(key: string, subKey: int): Future[Result[KeyedResponse, string]] {.async.} = + ok(KeyedResponse(key: "k2-" & key, payload: "q" & $subKey)), + ) + .isOk() + + let d = waitFor KeyedResponse.request("topic", 7) + check d.isOk() + check d.value.key == "default-topic" + + let k1 = waitFor KeyedResponse.request(BrokerContext(0xABCDEF01'u32), "topic", 7) + check k1.isOk() + check k1.value.key == "k1-topic" + check k1.value.payload == "p7" + + let k2 = waitFor KeyedResponse.request(BrokerContext(0xABCDEF02'u32), "topic", 7) + check k2.isOk() + check k2.value.key == "k2-topic" + check k2.value.payload == "q7" + + let miss = waitFor KeyedResponse.request(BrokerContext(0xDEADBEEF'u32), "topic", 7) + check miss.isErr() + check miss.error.contains("no provider registered for broker context") + + KeyedResponse.clearProvider() + ## --------------------------------------------------------------------------- ## Sync-mode brokers + tests ## --------------------------------------------------------------------------- @@ -370,6 +468,71 @@ suite "RequestBroker macro (sync mode)": ImplicitResponseSync.clearProvider() + test "supports keyed providers (sync, zero-arg)": + SimpleResponseSync.clearProvider() + + check SimpleResponseSync + .setProvider( + proc(): Result[SimpleResponseSync, string] = + ok(SimpleResponseSync(value: "default")) + ) + .isOk() + + check SimpleResponseSync + .setProvider( + BrokerContext(0x10101010'u32), + proc(): Result[SimpleResponseSync, string] = + ok(SimpleResponseSync(value: "ten")), + ) + .isOk() + + let defaultRes = SimpleResponseSync.request() + check defaultRes.isOk() + check defaultRes.value.value == "default" + + let keyedRes = SimpleResponseSync.request(BrokerContext(0x10101010'u32)) + check keyedRes.isOk() + check keyedRes.value.value == "ten" + + let miss = SimpleResponseSync.request(BrokerContext(0x20202020'u32)) + check miss.isErr() + check miss.error.contains("no provider registered for broker context") + + SimpleResponseSync.clearProvider() + + test "supports keyed providers (sync, with args)": + KeyedResponseSync.clearProvider() + + check KeyedResponseSync + .setProvider( + proc(key: string, subKey: int): Result[KeyedResponseSync, string] = + ok(KeyedResponseSync(key: "default-" & key, payload: $subKey)) + ) + .isOk() + + check KeyedResponseSync + .setProvider( + BrokerContext(0xA0A0A0A0'u32), + proc(key: string, subKey: int): Result[KeyedResponseSync, string] = + ok(KeyedResponseSync(key: "k-" & key, payload: "p" & $subKey)), + ) + .isOk() + + let d = KeyedResponseSync.request("topic", 2) + check d.isOk() + check d.value.key == "default-topic" + + let keyed = KeyedResponseSync.request(BrokerContext(0xA0A0A0A0'u32), "topic", 2) + check keyed.isOk() + check keyed.value.key == "k-topic" + check keyed.value.payload == "p2" + + let miss = KeyedResponseSync.request(BrokerContext(0xB0B0B0B0'u32), "topic", 2) + check miss.isErr() + check miss.error.contains("no provider registered for broker context") + + KeyedResponseSync.clearProvider() + ## --------------------------------------------------------------------------- ## POD / external type brokers + tests (distinct/alias behavior) ## --------------------------------------------------------------------------- diff --git a/tests/common/test_tokenbucket.nim b/tests/common/test_tokenbucket.nim deleted file mode 100644 index 5bc1a0583..000000000 --- a/tests/common/test_tokenbucket.nim +++ /dev/null @@ -1,69 +0,0 @@ -# Chronos Test Suite -# (c) Copyright 2022-Present -# Status Research & Development GmbH -# -# Licensed under either of -# Apache License, version 2.0, (LICENSE-APACHEv2) -# MIT license (LICENSE-MIT) - -{.used.} - -import testutils/unittests -import chronos -import ../../waku/common/rate_limit/token_bucket - -suite "Token Bucket": - test "TokenBucket Sync test - strict": - var bucket = TokenBucket.newStrict(1000, 1.milliseconds) - let - start = Moment.now() - fullTime = start + 1.milliseconds - check: - bucket.tryConsume(800, start) == true - bucket.tryConsume(200, start) == true - # Out of budget - bucket.tryConsume(100, start) == false - bucket.tryConsume(800, fullTime) == true - bucket.tryConsume(200, fullTime) == true - # Out of budget - bucket.tryConsume(100, fullTime) == false - - test "TokenBucket Sync test - compensating": - var bucket = TokenBucket.new(1000, 1.milliseconds) - let - start = Moment.now() - fullTime = start + 1.milliseconds - check: - bucket.tryConsume(800, start) == true - bucket.tryConsume(200, start) == true - # Out of budget - bucket.tryConsume(100, start) == false - bucket.tryConsume(800, fullTime) == true - bucket.tryConsume(200, fullTime) == true - # Due not using the bucket for a full period the compensation will satisfy this request - bucket.tryConsume(100, fullTime) == true - - test "TokenBucket Max compensation": - var bucket = TokenBucket.new(1000, 1.minutes) - var reqTime = Moment.now() - - check bucket.tryConsume(1000, reqTime) - check bucket.tryConsume(1, reqTime) == false - reqTime += 1.minutes - check bucket.tryConsume(500, reqTime) == true - reqTime += 1.minutes - check bucket.tryConsume(1000, reqTime) == true - reqTime += 10.seconds - # max compensation is 25% so try to consume 250 more - check bucket.tryConsume(250, reqTime) == true - reqTime += 49.seconds - # out of budget within the same period - check bucket.tryConsume(1, reqTime) == false - - test "TokenBucket Short replenish": - var bucket = TokenBucket.new(15000, 1.milliseconds) - let start = Moment.now() - check bucket.tryConsume(15000, start) - check bucket.tryConsume(1, start) == false - - check bucket.tryConsume(15000, start + 1.milliseconds) == true diff --git a/tests/test_peer_manager.nim b/tests/test_peer_manager.nim index 1369f3f88..97df39582 100644 --- a/tests/test_peer_manager.nim +++ b/tests/test_peer_manager.nim @@ -997,6 +997,7 @@ procSuite "Peer Manager": .build(), maxFailedAttempts = 1, storage = nil, + maxConnections = 20, ) # Create 30 peers and add them to the peerstore @@ -1063,6 +1064,7 @@ procSuite "Peer Manager": backoffFactor = 2, maxFailedAttempts = 10, storage = nil, + maxConnections = 20, ) var p1: PeerId require p1.init("QmeuZJbXrszW2jdT7GdduSjQskPU3S7vvGWKtKgDfkDvW" & "1") @@ -1116,6 +1118,7 @@ procSuite "Peer Manager": .build(), maxFailedAttempts = 150, storage = nil, + maxConnections = 20, ) # Should result in backoff > 1 week @@ -1131,6 +1134,7 @@ procSuite "Peer Manager": .build(), maxFailedAttempts = 10, storage = nil, + maxConnections = 20, ) let pm = PeerManager.new( @@ -1144,6 +1148,7 @@ procSuite "Peer Manager": .build(), maxFailedAttempts = 5, storage = nil, + maxConnections = 20, ) asyncTest "colocationLimit is enforced by pruneConnsByIp()": diff --git a/tests/waku_filter_v2/test_waku_filter_dos_protection.nim b/tests/waku_filter_v2/test_waku_filter_dos_protection.nim index 7c8c640ba..fd3d8c837 100644 --- a/tests/waku_filter_v2/test_waku_filter_dos_protection.nim +++ b/tests/waku_filter_v2/test_waku_filter_dos_protection.nim @@ -122,24 +122,51 @@ suite "Waku Filter - DOS protection": check client2.subscribe(serverRemotePeerInfo, pubsubTopic, contentTopicSeq) == none(FilterSubscribeErrorKind) - await sleepAsync(20.milliseconds) - check client1.subscribe(serverRemotePeerInfo, pubsubTopic, contentTopicSeq) == - none(FilterSubscribeErrorKind) + # Avoid using tiny sleeps to control refill behavior: CI scheduling can + # oversleep and mint additional tokens. Instead, issue a small burst of + # subscribe requests and require at least one TOO_MANY_REQUESTS. + var c1SubscribeFutures = newSeq[Future[FilterSubscribeResult]]() + for i in 0 ..< 6: + c1SubscribeFutures.add( + client1.wakuFilterClient.subscribe( + serverRemotePeerInfo, pubsubTopic, contentTopicSeq + ) + ) + + let c1Finished = await allFinished(c1SubscribeFutures) + var c1GotTooMany = false + for fut in c1Finished: + check not fut.failed() + let res = fut.read() + if res.isErr() and res.error().kind == FilterSubscribeErrorKind.TOO_MANY_REQUESTS: + c1GotTooMany = true + break + check c1GotTooMany + + # Ensure the other client is not affected by client1's rate limit. check client2.subscribe(serverRemotePeerInfo, pubsubTopic, contentTopicSeq) == none(FilterSubscribeErrorKind) - await sleepAsync(20.milliseconds) - check client1.subscribe(serverRemotePeerInfo, pubsubTopic, contentTopicSeq) == - none(FilterSubscribeErrorKind) - await sleepAsync(20.milliseconds) - check client1.subscribe(serverRemotePeerInfo, pubsubTopic, contentTopicSeq) == - some(FilterSubscribeErrorKind.TOO_MANY_REQUESTS) - check client2.subscribe(serverRemotePeerInfo, pubsubTopic, contentTopicSeq) == - none(FilterSubscribeErrorKind) - check client2.subscribe(serverRemotePeerInfo, pubsubTopic, contentTopicSeq) == - some(FilterSubscribeErrorKind.TOO_MANY_REQUESTS) + + var c2SubscribeFutures = newSeq[Future[FilterSubscribeResult]]() + for i in 0 ..< 6: + c2SubscribeFutures.add( + client2.wakuFilterClient.subscribe( + serverRemotePeerInfo, pubsubTopic, contentTopicSeq + ) + ) + + let c2Finished = await allFinished(c2SubscribeFutures) + var c2GotTooMany = false + for fut in c2Finished: + check not fut.failed() + let res = fut.read() + if res.isErr() and res.error().kind == FilterSubscribeErrorKind.TOO_MANY_REQUESTS: + c2GotTooMany = true + break + check c2GotTooMany # ensure period of time has passed and clients can again use the service - await sleepAsync(1000.milliseconds) + await sleepAsync(1100.milliseconds) check client1.subscribe(serverRemotePeerInfo, pubsubTopic, contentTopicSeq) == none(FilterSubscribeErrorKind) check client2.subscribe(serverRemotePeerInfo, pubsubTopic, contentTopicSeq) == @@ -147,29 +174,54 @@ suite "Waku Filter - DOS protection": asyncTest "Ensure normal usage allowed": # Given + # Rate limit setting is (3 requests / 1000ms) per peer. + # In a token-bucket model this means: + # - capacity = 3 tokens + # - refill rate = 3 tokens / second => ~1 token every ~333ms + # - each request consumes 1 token (including UNSUBSCRIBE) check client1.subscribe(serverRemotePeerInfo, pubsubTopic, contentTopicSeq) == none(FilterSubscribeErrorKind) check wakuFilter.subscriptions.isSubscribed(client1.clientPeerId) - await sleepAsync(500.milliseconds) - check client1.ping(serverRemotePeerInfo) == none(FilterSubscribeErrorKind) - check wakuFilter.subscriptions.isSubscribed(client1.clientPeerId) + # Expected remaining tokens (approx): 2 await sleepAsync(500.milliseconds) check client1.ping(serverRemotePeerInfo) == none(FilterSubscribeErrorKind) check wakuFilter.subscriptions.isSubscribed(client1.clientPeerId) - await sleepAsync(50.milliseconds) + # After ~500ms, ~1 token refilled; PING consumes 1 => expected remaining: 2 + + await sleepAsync(500.milliseconds) + check client1.ping(serverRemotePeerInfo) == none(FilterSubscribeErrorKind) + check wakuFilter.subscriptions.isSubscribed(client1.clientPeerId) + + # After another ~500ms, ~1 token refilled; PING consumes 1 => expected remaining: 2 + check client1.unsubscribe(serverRemotePeerInfo, pubsubTopic, contentTopicSeq) == none(FilterSubscribeErrorKind) check wakuFilter.subscriptions.isSubscribed(client1.clientPeerId) == false - await sleepAsync(50.milliseconds) check client1.ping(serverRemotePeerInfo) == some(FilterSubscribeErrorKind.NOT_FOUND) - check client1.ping(serverRemotePeerInfo) == some(FilterSubscribeErrorKind.NOT_FOUND) - await sleepAsync(50.milliseconds) - check client1.ping(serverRemotePeerInfo) == - some(FilterSubscribeErrorKind.TOO_MANY_REQUESTS) + # After unsubscribing, PING is expected to return NOT_FOUND while still + # counting towards the rate limit. + + # CI can oversleep / schedule slowly, which can mint extra tokens between + # requests. To make the test robust, issue a small burst of pings and + # require at least one TOO_MANY_REQUESTS response. + var pingFutures = newSeq[Future[FilterSubscribeResult]]() + for i in 0 ..< 9: + pingFutures.add(client1.wakuFilterClient.ping(serverRemotePeerInfo)) + + let finished = await allFinished(pingFutures) + var gotTooMany = false + for fut in finished: + check not fut.failed() + let pingRes = fut.read() + if pingRes.isErr() and pingRes.error().kind == FilterSubscribeErrorKind.TOO_MANY_REQUESTS: + gotTooMany = true + break + + check gotTooMany check client2.subscribe(serverRemotePeerInfo, pubsubTopic, contentTopicSeq) == none(FilterSubscribeErrorKind) diff --git a/tests/waku_lightpush/test_ratelimit.nim b/tests/waku_lightpush/test_ratelimit.nim index 7420a4e56..bdab3f074 100644 --- a/tests/waku_lightpush/test_ratelimit.nim +++ b/tests/waku_lightpush/test_ratelimit.nim @@ -80,11 +80,12 @@ suite "Rate limited push service": await allFutures(serverSwitch.start(), clientSwitch.start()) ## Given - var handlerFuture = newFuture[(string, WakuMessage)]() + # Don't rely on per-request timing assumptions or a single shared Future. + # CI can be slow enough that sequential requests accidentally refill tokens. + # Instead we issue a small burst and assert we observe at least one rejection. let handler = proc( peer: PeerId, pubsubTopic: PubsubTopic, message: WakuMessage ): Future[WakuLightPushResult] {.async.} = - handlerFuture.complete((pubsubTopic, message)) return lightpushSuccessResult(1) let @@ -93,45 +94,38 @@ suite "Rate limited push service": client = newTestWakuLightpushClient(clientSwitch) let serverPeerId = serverSwitch.peerInfo.toRemotePeerInfo() - let topic = DefaultPubsubTopic + let tokenPeriod = 500.millis - let successProc = proc(): Future[void] {.async.} = + # Fire a burst of requests; require at least one success and one rejection. + var publishFutures = newSeq[Future[WakuLightPushResult]]() + for i in 0 ..< 10: let message = fakeWakuMessage() - handlerFuture = newFuture[(string, WakuMessage)]() - let requestRes = - await client.publish(some(DefaultPubsubTopic), message, serverPeerId) - discard await handlerFuture.withTimeout(10.millis) + publishFutures.add( + client.publish(some(DefaultPubsubTopic), message, serverPeerId) + ) - check: - requestRes.isOk() - handlerFuture.finished() - let (handledMessagePubsubTopic, handledMessage) = handlerFuture.read() - check: - handledMessagePubsubTopic == DefaultPubsubTopic - handledMessage == message + let finished = await allFinished(publishFutures) + var gotOk = false + var gotTooMany = false + for fut in finished: + check not fut.failed() + let res = fut.read() + if res.isOk(): + gotOk = true + else: + check res.error.code == LightPushErrorCode.TOO_MANY_REQUESTS + check res.error.desc == some(TooManyRequestsMessage) + gotTooMany = true - let rejectProc = proc(): Future[void] {.async.} = - let message = fakeWakuMessage() - handlerFuture = newFuture[(string, WakuMessage)]() - let requestRes = - await client.publish(some(DefaultPubsubTopic), message, serverPeerId) - discard await handlerFuture.withTimeout(10.millis) + check gotOk + check gotTooMany - check: - requestRes.isErr() - requestRes.error.code == LightPushErrorCode.TOO_MANY_REQUESTS - requestRes.error.desc == some(TooManyRequestsMessage) - - for testCnt in 0 .. 2: - await successProc() - await sleepAsync(20.millis) - - await rejectProc() - - await sleepAsync(500.millis) - - ## next one shall succeed due to the rate limit time window has passed - await successProc() + # ensure period of time has passed and the client can again use the service + await sleepAsync(tokenPeriod + 100.millis) + let recoveryRes = await client.publish( + some(DefaultPubsubTopic), fakeWakuMessage(), serverPeerId + ) + check recoveryRes.isOk() ## Cleanup await allFutures(clientSwitch.stop(), serverSwitch.stop()) diff --git a/tests/waku_lightpush_legacy/test_ratelimit.nim b/tests/waku_lightpush_legacy/test_ratelimit.nim index 3df8d369d..37c43a066 100644 --- a/tests/waku_lightpush_legacy/test_ratelimit.nim +++ b/tests/waku_lightpush_legacy/test_ratelimit.nim @@ -86,58 +86,52 @@ suite "Rate limited push service": await allFutures(serverSwitch.start(), clientSwitch.start()) ## Given - var handlerFuture = newFuture[(string, WakuMessage)]() let handler = proc( peer: PeerId, pubsubTopic: PubsubTopic, message: WakuMessage ): Future[WakuLightPushResult[void]] {.async.} = - handlerFuture.complete((pubsubTopic, message)) return ok() let + tokenPeriod = 500.millis server = await newTestWakuLegacyLightpushNode( - serverSwitch, handler, some((3, 500.millis)) + serverSwitch, handler, some((3, tokenPeriod)) ) client = newTestWakuLegacyLightpushClient(clientSwitch) let serverPeerId = serverSwitch.peerInfo.toRemotePeerInfo() - let topic = DefaultPubsubTopic - let successProc = proc(): Future[void] {.async.} = - let message = fakeWakuMessage() - handlerFuture = newFuture[(string, WakuMessage)]() - let requestRes = - await client.publish(DefaultPubsubTopic, message, peer = serverPeerId) - discard await handlerFuture.withTimeout(10.millis) + # Avoid assuming the exact Nth request will be rejected. With Chronos TokenBucket + # minting semantics and real network latency, CI timing can allow refills. + # Instead, send a short burst and require that we observe at least one rejection. + let burstSize = 10 + var publishFutures: seq[Future[WakuLightPushResult[string]]] = @[] + for _ in 0 ..< burstSize: + publishFutures.add( + client.publish(DefaultPubsubTopic, fakeWakuMessage(), peer = serverPeerId) + ) - check: - requestRes.isOk() - handlerFuture.finished() - let (handledMessagePubsubTopic, handledMessage) = handlerFuture.read() - check: - handledMessagePubsubTopic == DefaultPubsubTopic - handledMessage == message + let finished = await allFinished(publishFutures) + var gotOk = false + var gotTooMany = false + for fut in finished: + check not fut.failed() + let res = fut.read() + if res.isOk(): + gotOk = true + elif res.error == "TOO_MANY_REQUESTS": + gotTooMany = true - let rejectProc = proc(): Future[void] {.async.} = - let message = fakeWakuMessage() - handlerFuture = newFuture[(string, WakuMessage)]() - let requestRes = - await client.publish(DefaultPubsubTopic, message, peer = serverPeerId) - discard await handlerFuture.withTimeout(10.millis) + check: + gotOk + gotTooMany - check: - requestRes.isErr() - requestRes.error == "TOO_MANY_REQUESTS" - - for testCnt in 0 .. 2: - await successProc() - await sleepAsync(20.millis) - - await rejectProc() - - await sleepAsync(500.millis) + await sleepAsync(tokenPeriod + 100.millis) ## next one shall succeed due to the rate limit time window has passed - await successProc() + let afterCooldownRes = + await client.publish(DefaultPubsubTopic, fakeWakuMessage(), peer = serverPeerId) + check: + afterCooldownRes.isOk() ## Cleanup await allFutures(clientSwitch.stop(), serverSwitch.stop()) diff --git a/tests/waku_rln_relay/utils.nim b/tests/waku_rln_relay/utils.nim index a4247ab44..8aed18f9b 100644 --- a/tests/waku_rln_relay/utils.nim +++ b/tests/waku_rln_relay/utils.nim @@ -24,7 +24,6 @@ proc deployContract*( tr.`from` = Opt.some(web3.defaultAccount) let sData = code & contractInput tr.data = Opt.some(hexToSeqByte(sData)) - tr.gas = Opt.some(Quantity(3000000000000)) if gasPrice != 0: tr.gasPrice = Opt.some(gasPrice.Quantity) diff --git a/tests/waku_rln_relay/utils_onchain.nim b/tests/waku_rln_relay/utils_onchain.nim index d8bb13a62..9f1048097 100644 --- a/tests/waku_rln_relay/utils_onchain.nim +++ b/tests/waku_rln_relay/utils_onchain.nim @@ -529,6 +529,7 @@ proc runAnvil*( # --chain-id Chain ID of the network. # --load-state Initialize the chain from a previously saved state snapshot (read-only) # --dump-state Dump the state on exit to the given file (write-only) + # Values used are representative of Linea Sepolia testnet # See anvil documentation https://book.getfoundry.sh/reference/anvil/ for more details try: let anvilPath = getAnvilPath() @@ -539,11 +540,16 @@ proc runAnvil*( "--port", $port, "--gas-limit", - "300000000000000", + "30000000", + "--gas-price", + "7", + "--base-fee", + "7", "--balance", - "1000000000", + "10000000000", "--chain-id", $chainId, + "--disable-min-priority-fee", ] # Add state file argument if provided diff --git a/tests/waku_store/test_wakunode_store.nim b/tests/waku_store/test_wakunode_store.nim index b20309079..7d1a44ecc 100644 --- a/tests/waku_store/test_wakunode_store.nim +++ b/tests/waku_store/test_wakunode_store.nim @@ -413,7 +413,7 @@ procSuite "WakuNode - Store": for count in 0 ..< 3: waitFor successProc() - waitFor sleepAsync(20.millis) + waitFor sleepAsync(5.millis) waitFor failsProc() diff --git a/vendor/nim-chronos b/vendor/nim-chronos index 0646c444f..85af4db76 160000 --- a/vendor/nim-chronos +++ b/vendor/nim-chronos @@ -1 +1 @@ -Subproject commit 0646c444fce7c7ed08ef6f2c9a7abfd172ffe655 +Subproject commit 85af4db764ecd3573c4704139560df3943216cf1 diff --git a/vendor/nim-jwt b/vendor/nim-jwt new file mode 160000 index 000000000..18f8378de --- /dev/null +++ b/vendor/nim-jwt @@ -0,0 +1 @@ +Subproject commit 18f8378de52b241f321c1f9ea905456e89b95c6f diff --git a/vendor/nim-libp2p b/vendor/nim-libp2p index e82080f7b..eb7e6ff89 160000 --- a/vendor/nim-libp2p +++ b/vendor/nim-libp2p @@ -1 +1 @@ -Subproject commit e82080f7b1aa61c6d35fa5311b873f41eff4bb52 +Subproject commit eb7e6ff89889e41b57515f891ba82986c54809fb diff --git a/vendor/nim-lsquic b/vendor/nim-lsquic new file mode 160000 index 000000000..f3fe33462 --- /dev/null +++ b/vendor/nim-lsquic @@ -0,0 +1 @@ +Subproject commit f3fe33462601ea34eb2e8e9c357c92e61f8d121b diff --git a/waku.nimble b/waku.nimble index 5c5c09763..afc0ad634 100644 --- a/waku.nimble +++ b/waku.nimble @@ -31,6 +31,8 @@ requires "nim >= 2.2.4", "results", "db_connector", "minilru", + "lsquic", + "jwt", "ffi" ### Helper functions @@ -148,7 +150,8 @@ task chat2, "Build example Waku chat usage": let name = "chat2" buildBinary name, "apps/chat2/", - "-d:chronicles_sinks=textlines[file] -d:ssl -d:chronicles_log_level='TRACE' " + "-d:chronicles_sinks=textlines[file] -d:chronicles_log_level='TRACE' " + # -d:ssl - cause unlisted exception error in libp2p/utility... task chat2mix, "Build example Waku chat mix usage": # NOTE For debugging, set debug level. For chat usage we want minimal log @@ -158,7 +161,8 @@ task chat2mix, "Build example Waku chat mix usage": let name = "chat2mix" buildBinary name, "apps/chat2mix/", - "-d:chronicles_sinks=textlines[file] -d:ssl -d:chronicles_log_level='TRACE' " + "-d:chronicles_sinks=textlines[file] -d:chronicles_log_level='TRACE' " + # -d:ssl - cause unlisted exception error in libp2p/utility... task chat2bridge, "Build chat2bridge": let name = "chat2bridge" diff --git a/waku/common/broker/broker_context.nim b/waku/common/broker/broker_context.nim new file mode 100644 index 000000000..483a2e3a7 --- /dev/null +++ b/waku/common/broker/broker_context.nim @@ -0,0 +1,68 @@ +{.push raises: [].} + +import std/[strutils, concurrency/atomics], chronos + +type BrokerContext* = distinct uint32 + +func `==`*(a, b: BrokerContext): bool = + uint32(a) == uint32(b) + +func `!=`*(a, b: BrokerContext): bool = + uint32(a) != uint32(b) + +func `$`*(bc: BrokerContext): string = + toHex(uint32(bc), 8) + +const DefaultBrokerContext* = BrokerContext(0xCAFFE14E'u32) + +# Global broker context accessor. +# +# NOTE: This intentionally creates a *single* active BrokerContext per process +# (per event loop thread). Use only if you accept serialization of all broker +# context usage through the lock. +var globalBrokerContextLock {.threadvar.}: AsyncLock +globalBrokerContextLock = newAsyncLock() +var globalBrokerContextValue {.threadvar.}: BrokerContext +globalBrokerContextValue = DefaultBrokerContext +proc globalBrokerContext*(): BrokerContext = + ## Returns the currently active global broker context. + ## + ## This is intentionally lock-free; callers should use it inside + ## `withNewGlobalBrokerContext` / `withGlobalBrokerContext`. + globalBrokerContextValue + +var gContextCounter: Atomic[uint32] + +proc NewBrokerContext*(): BrokerContext = + var nextId = gContextCounter.fetchAdd(1, moRelaxed) + if nextId == uint32(DefaultBrokerContext): + nextId = gContextCounter.fetchAdd(1, moRelaxed) + return BrokerContext(nextId) + +template lockGlobalBrokerContext*(brokerCtx: BrokerContext, body: untyped): untyped = + ## Runs `body` while holding the global broker context lock with the provided + ## `brokerCtx` installed as the globally accessible context. + ## + ## This template is intended for use from within `chronos` async procs. + block: + await noCancel(globalBrokerContextLock.acquire()) + let previousBrokerCtx = globalBrokerContextValue + globalBrokerContextValue = brokerCtx + try: + body + finally: + globalBrokerContextValue = previousBrokerCtx + try: + globalBrokerContextLock.release() + except AsyncLockError: + doAssert false, "globalBrokerContextLock.release(): lock not held" + +template lockNewGlobalBrokerContext*(body: untyped): untyped = + ## Runs `body` while holding the global broker context lock with a freshly + ## generated broker context installed as the global accessor. + ## + ## The previous global broker context (if any) is restored on exit. + lockGlobalBrokerContext(NewBrokerContext()): + body + +{.pop.} diff --git a/waku/common/broker/event_broker.nim b/waku/common/broker/event_broker.nim index 05d7b50ab..779689f88 100644 --- a/waku/common/broker/event_broker.nim +++ b/waku/common/broker/event_broker.nim @@ -5,10 +5,35 @@ ## need for direct dependencies in between emitters and listeners. ## Worth considering using it in a single or many emitters to many listeners scenario. ## -## Generates a standalone, type-safe event broker for the declared object type. +## Generates a standalone, type-safe event broker for the declared type. ## The macro exports the value type itself plus a broker companion that manages ## listeners via thread-local storage. ## +## Type definitions: +## - Inline `object` / `ref object` definitions are supported. +## - Native types, aliases, and externally-defined types are also supported. +## In that case, EventBroker will automatically wrap the declared RHS type in +## `distinct` unless you already used `distinct`. +## This keeps event types unique even when multiple brokers share the same +## underlying base type. +## +## Default vs. context aware use: +## Every generated broker is a thread-local global instance. This means EventBroker +## enables decoupled event exchange threadwise. +## +## Sometimes we use brokers inside a context (e.g. within a component that has many +## modules or subsystems). If you instantiate multiple such components in a single +## thread, and each component must have its own listener set for the same EventBroker +## type, you can use context-aware EventBroker. +## +## Context awareness is supported through the `BrokerContext` argument for +## `listen`, `emit`, `dropListener`, and `dropAllListeners`. +## Listener stores are kept separate per broker context. +## +## Default broker context is defined as `DefaultBrokerContext`. If you don't need +## context awareness, you can keep using the interfaces without the context +## argument, which operate on `DefaultBrokerContext`. +## ## Usage: ## Declare your desired event type inside an `EventBroker` macro, add any number of fields.: ## ```nim @@ -47,87 +72,46 @@ ## GreetingEvent.dropListener(handle) ## ``` +## Example (non-object event type): +## ```nim +## EventBroker: +## type CounterEvent = int # exported as: `distinct int` +## +## discard CounterEvent.listen( +## proc(evt: CounterEvent): Future[void] {.async.} = +## echo int(evt) +## ) +## CounterEvent.emit(CounterEvent(42)) +## ``` + import std/[macros, tables] import chronos, chronicles, results -import ./helper/broker_utils +import ./helper/broker_utils, broker_context -export chronicles, results, chronos +export chronicles, results, chronos, broker_context macro EventBroker*(body: untyped): untyped = when defined(eventBrokerDebug): echo body.treeRepr - var typeIdent: NimNode = nil - var objectDef: NimNode = nil - var fieldNames: seq[NimNode] = @[] - var fieldTypes: seq[NimNode] = @[] - var isRefObject = false - for stmt in body: - if stmt.kind == nnkTypeSection: - for def in stmt: - if def.kind != nnkTypeDef: - continue - let rhs = def[2] - var objectType: NimNode - case rhs.kind - of nnkObjectTy: - objectType = rhs - of nnkRefTy: - isRefObject = true - if rhs.len != 1 or rhs[0].kind != nnkObjectTy: - error("EventBroker ref object must wrap a concrete object definition", rhs) - objectType = rhs[0] - else: - continue - if not typeIdent.isNil(): - error("Only one object type may be declared inside EventBroker", def) - typeIdent = baseTypeIdent(def[0]) - let recList = objectType[2] - if recList.kind != nnkRecList: - error("EventBroker object must declare a standard field list", objectType) - var exportedRecList = newTree(nnkRecList) - for field in recList: - case field.kind - of nnkIdentDefs: - ensureFieldDef(field) - let fieldTypeNode = field[field.len - 2] - for i in 0 ..< field.len - 2: - let baseFieldIdent = baseTypeIdent(field[i]) - fieldNames.add(copyNimTree(baseFieldIdent)) - fieldTypes.add(copyNimTree(fieldTypeNode)) - var cloned = copyNimTree(field) - for i in 0 ..< cloned.len - 2: - cloned[i] = exportIdentNode(cloned[i]) - exportedRecList.add(cloned) - of nnkEmpty: - discard - else: - error( - "EventBroker object definition only supports simple field declarations", - field, - ) - let exportedObjectType = newTree( - nnkObjectTy, - copyNimTree(objectType[0]), - copyNimTree(objectType[1]), - exportedRecList, - ) - if isRefObject: - objectDef = newTree(nnkRefTy, exportedObjectType) - else: - objectDef = exportedObjectType - if typeIdent.isNil(): - error("EventBroker body must declare exactly one object type", body) + let parsed = parseSingleTypeDef(body, "EventBroker", collectFieldInfo = true) + let typeIdent = parsed.typeIdent + let objectDef = parsed.objectDef + let fieldNames = parsed.fieldNames + let fieldTypes = parsed.fieldTypes + let hasInlineFields = parsed.hasInlineFields let exportedTypeIdent = postfix(copyNimTree(typeIdent), "*") let sanitized = sanitizeIdentName(typeIdent) let typeNameLit = newLit($typeIdent) - let isRefObjectLit = newLit(isRefObject) let handlerProcIdent = ident(sanitized & "ListenerProc") let listenerHandleIdent = ident(sanitized & "Listener") let brokerTypeIdent = ident(sanitized & "Broker") let exportedHandlerProcIdent = postfix(copyNimTree(handlerProcIdent), "*") let exportedListenerHandleIdent = postfix(copyNimTree(listenerHandleIdent), "*") let exportedBrokerTypeIdent = postfix(copyNimTree(brokerTypeIdent), "*") + let bucketTypeIdent = ident(sanitized & "CtxBucket") + let findBucketIdxIdent = ident(sanitized & "FindBucketIdx") + let getOrCreateBucketIdxIdent = ident(sanitized & "GetOrCreateBucketIdx") let accessProcIdent = ident("access" & sanitized & "Broker") let globalVarIdent = ident("g" & sanitized & "Broker") let listenImplIdent = ident("register" & sanitized & "Listener") @@ -147,10 +131,14 @@ macro EventBroker*(body: untyped): untyped = `exportedHandlerProcIdent` = proc(event: `typeIdent`): Future[void] {.async: (raises: []), gcsafe.} - `exportedBrokerTypeIdent` = ref object + `bucketTypeIdent` = object + brokerCtx: BrokerContext listeners: Table[uint64, `handlerProcIdent`] nextId: uint64 + `exportedBrokerTypeIdent` = ref object + buckets: seq[`bucketTypeIdent`] + ) result.add( @@ -163,49 +151,102 @@ macro EventBroker*(body: untyped): untyped = proc `accessProcIdent`(): `brokerTypeIdent` = if `globalVarIdent`.isNil(): new(`globalVarIdent`) - `globalVarIdent`.listeners = initTable[uint64, `handlerProcIdent`]() + `globalVarIdent`.buckets = + @[ + `bucketTypeIdent`( + brokerCtx: DefaultBrokerContext, + listeners: initTable[uint64, `handlerProcIdent`](), + nextId: 1'u64, + ) + ] `globalVarIdent` ) result.add( quote do: + proc `findBucketIdxIdent`( + broker: `brokerTypeIdent`, brokerCtx: BrokerContext + ): int = + if brokerCtx == DefaultBrokerContext: + return 0 + for i in 1 ..< broker.buckets.len: + if broker.buckets[i].brokerCtx == brokerCtx: + return i + return -1 + + proc `getOrCreateBucketIdxIdent`( + broker: `brokerTypeIdent`, brokerCtx: BrokerContext + ): int = + let idx = `findBucketIdxIdent`(broker, brokerCtx) + if idx >= 0: + return idx + broker.buckets.add( + `bucketTypeIdent`( + brokerCtx: brokerCtx, + listeners: initTable[uint64, `handlerProcIdent`](), + nextId: 1'u64, + ) + ) + return broker.buckets.high + proc `listenImplIdent`( - handler: `handlerProcIdent` + brokerCtx: BrokerContext, handler: `handlerProcIdent` ): Result[`listenerHandleIdent`, string] = if handler.isNil(): return err("Must provide a non-nil event handler") var broker = `accessProcIdent`() - if broker.nextId == 0'u64: - broker.nextId = 1'u64 - if broker.nextId == high(uint64): - error "Cannot add more listeners: ID space exhausted", nextId = $broker.nextId + + let bucketIdx = `getOrCreateBucketIdxIdent`(broker, brokerCtx) + if broker.buckets[bucketIdx].nextId == 0'u64: + broker.buckets[bucketIdx].nextId = 1'u64 + + if broker.buckets[bucketIdx].nextId == high(uint64): + error "Cannot add more listeners: ID space exhausted", + nextId = $broker.buckets[bucketIdx].nextId return err("Cannot add more listeners, listener ID space exhausted") - let newId = broker.nextId - inc broker.nextId - broker.listeners[newId] = handler + + let newId = broker.buckets[bucketIdx].nextId + inc broker.buckets[bucketIdx].nextId + broker.buckets[bucketIdx].listeners[newId] = handler return ok(`listenerHandleIdent`(id: newId)) ) result.add( quote do: - proc `dropListenerImplIdent`(handle: `listenerHandleIdent`) = + proc `dropListenerImplIdent`( + brokerCtx: BrokerContext, handle: `listenerHandleIdent` + ) = if handle.id == 0'u64: return var broker = `accessProcIdent`() - if broker.listeners.len == 0: + + let bucketIdx = `findBucketIdxIdent`(broker, brokerCtx) + if bucketIdx < 0: return - broker.listeners.del(handle.id) + + if broker.buckets[bucketIdx].listeners.len == 0: + return + broker.buckets[bucketIdx].listeners.del(handle.id) + if brokerCtx != DefaultBrokerContext and + broker.buckets[bucketIdx].listeners.len == 0: + broker.buckets.delete(bucketIdx) ) result.add( quote do: - proc `dropAllListenersImplIdent`() = + proc `dropAllListenersImplIdent`(brokerCtx: BrokerContext) = var broker = `accessProcIdent`() - if broker.listeners.len > 0: - broker.listeners.clear() + + let bucketIdx = `findBucketIdxIdent`(broker, brokerCtx) + if bucketIdx < 0: + return + if broker.buckets[bucketIdx].listeners.len > 0: + broker.buckets[bucketIdx].listeners.clear() + if brokerCtx != DefaultBrokerContext: + broker.buckets.delete(bucketIdx) ) @@ -214,17 +255,34 @@ macro EventBroker*(body: untyped): untyped = proc listen*( _: typedesc[`typeIdent`], handler: `handlerProcIdent` ): Result[`listenerHandleIdent`, string] = - return `listenImplIdent`(handler) + return `listenImplIdent`(DefaultBrokerContext, handler) + + proc listen*( + _: typedesc[`typeIdent`], + brokerCtx: BrokerContext, + handler: `handlerProcIdent`, + ): Result[`listenerHandleIdent`, string] = + return `listenImplIdent`(brokerCtx, handler) ) result.add( quote do: proc dropListener*(_: typedesc[`typeIdent`], handle: `listenerHandleIdent`) = - `dropListenerImplIdent`(handle) + `dropListenerImplIdent`(DefaultBrokerContext, handle) + + proc dropListener*( + _: typedesc[`typeIdent`], + brokerCtx: BrokerContext, + handle: `listenerHandleIdent`, + ) = + `dropListenerImplIdent`(brokerCtx, handle) proc dropAllListeners*(_: typedesc[`typeIdent`]) = - `dropAllListenersImplIdent`() + `dropAllListenersImplIdent`(DefaultBrokerContext) + + proc dropAllListeners*(_: typedesc[`typeIdent`], brokerCtx: BrokerContext) = + `dropAllListenersImplIdent`(brokerCtx) ) @@ -241,68 +299,114 @@ macro EventBroker*(body: untyped): untyped = error "Failed to execute event listener", error = getCurrentExceptionMsg() proc `emitImplIdent`( - event: `typeIdent` + brokerCtx: BrokerContext, event: `typeIdent` ): Future[void] {.async: (raises: []), gcsafe.} = - when `isRefObjectLit`: + when compiles(event.isNil()): if event.isNil(): error "Cannot emit uninitialized event object", eventType = `typeNameLit` return let broker = `accessProcIdent`() - if broker.listeners.len == 0: + let bucketIdx = `findBucketIdxIdent`(broker, brokerCtx) + if bucketIdx < 0: # nothing to do as nobody is listening return + if broker.buckets[bucketIdx].listeners.len == 0: + return var callbacks: seq[`handlerProcIdent`] = @[] - for cb in broker.listeners.values: + for cb in broker.buckets[bucketIdx].listeners.values: callbacks.add(cb) for cb in callbacks: asyncSpawn `listenerTaskIdent`(cb, event) proc emit*(event: `typeIdent`) = - asyncSpawn `emitImplIdent`(event) + asyncSpawn `emitImplIdent`(DefaultBrokerContext, event) proc emit*(_: typedesc[`typeIdent`], event: `typeIdent`) = - asyncSpawn `emitImplIdent`(event) + asyncSpawn `emitImplIdent`(DefaultBrokerContext, event) + + proc emit*( + _: typedesc[`typeIdent`], brokerCtx: BrokerContext, event: `typeIdent` + ) = + asyncSpawn `emitImplIdent`(brokerCtx, event) ) - var emitCtorParams = newTree(nnkFormalParams, newEmptyNode()) - let typedescParamType = - newTree(nnkBracketExpr, ident("typedesc"), copyNimTree(typeIdent)) - emitCtorParams.add( - newTree(nnkIdentDefs, ident("_"), typedescParamType, newEmptyNode()) - ) - for i in 0 ..< fieldNames.len: + if hasInlineFields: + # Typedesc emit constructor overloads for inline object/ref object types. + var emitCtorParams = newTree(nnkFormalParams, newEmptyNode()) + let typedescParamType = + newTree(nnkBracketExpr, ident("typedesc"), copyNimTree(typeIdent)) emitCtorParams.add( - newTree( - nnkIdentDefs, - copyNimTree(fieldNames[i]), - copyNimTree(fieldTypes[i]), - newEmptyNode(), + newTree(nnkIdentDefs, ident("_"), typedescParamType, newEmptyNode()) + ) + for i in 0 ..< fieldNames.len: + emitCtorParams.add( + newTree( + nnkIdentDefs, + copyNimTree(fieldNames[i]), + copyNimTree(fieldTypes[i]), + newEmptyNode(), + ) ) + + var emitCtorExpr = newTree(nnkObjConstr, copyNimTree(typeIdent)) + for i in 0 ..< fieldNames.len: + emitCtorExpr.add( + newTree( + nnkExprColonExpr, copyNimTree(fieldNames[i]), copyNimTree(fieldNames[i]) + ) + ) + + let emitCtorCallDefault = + newCall(copyNimTree(emitImplIdent), ident("DefaultBrokerContext"), emitCtorExpr) + let emitCtorBodyDefault = quote: + asyncSpawn `emitCtorCallDefault` + + let typedescEmitProcDefault = newTree( + nnkProcDef, + postfix(ident("emit"), "*"), + newEmptyNode(), + newEmptyNode(), + emitCtorParams, + newEmptyNode(), + newEmptyNode(), + emitCtorBodyDefault, ) + result.add(typedescEmitProcDefault) - var emitCtorExpr = newTree(nnkObjConstr, copyNimTree(typeIdent)) - for i in 0 ..< fieldNames.len: - emitCtorExpr.add( - newTree(nnkExprColonExpr, copyNimTree(fieldNames[i]), copyNimTree(fieldNames[i])) + var emitCtorParamsCtx = newTree(nnkFormalParams, newEmptyNode()) + emitCtorParamsCtx.add( + newTree(nnkIdentDefs, ident("_"), typedescParamType, newEmptyNode()) ) + emitCtorParamsCtx.add( + newTree(nnkIdentDefs, ident("brokerCtx"), ident("BrokerContext"), newEmptyNode()) + ) + for i in 0 ..< fieldNames.len: + emitCtorParamsCtx.add( + newTree( + nnkIdentDefs, + copyNimTree(fieldNames[i]), + copyNimTree(fieldTypes[i]), + newEmptyNode(), + ) + ) - let emitCtorCall = newCall(copyNimTree(emitImplIdent), emitCtorExpr) - let emitCtorBody = quote: - asyncSpawn `emitCtorCall` + let emitCtorCallCtx = + newCall(copyNimTree(emitImplIdent), ident("brokerCtx"), copyNimTree(emitCtorExpr)) + let emitCtorBodyCtx = quote: + asyncSpawn `emitCtorCallCtx` - let typedescEmitProc = newTree( - nnkProcDef, - postfix(ident("emit"), "*"), - newEmptyNode(), - newEmptyNode(), - emitCtorParams, - newEmptyNode(), - newEmptyNode(), - emitCtorBody, - ) - - result.add(typedescEmitProc) + let typedescEmitProcCtx = newTree( + nnkProcDef, + postfix(ident("emit"), "*"), + newEmptyNode(), + newEmptyNode(), + emitCtorParamsCtx, + newEmptyNode(), + newEmptyNode(), + emitCtorBodyCtx, + ) + result.add(typedescEmitProcCtx) when defined(eventBrokerDebug): echo result.repr diff --git a/waku/common/broker/helper/broker_utils.nim b/waku/common/broker/helper/broker_utils.nim index ea9f85750..90f2055d3 100644 --- a/waku/common/broker/helper/broker_utils.nim +++ b/waku/common/broker/helper/broker_utils.nim @@ -1,5 +1,21 @@ import std/macros +type ParsedBrokerType* = object + ## Result of parsing the single `type` definition inside a broker macro body. + ## + ## - `typeIdent`: base identifier for the declared type name + ## - `objectDef`: exported type definition RHS (inline object fields exported; + ## non-object types wrapped in `distinct` unless already distinct) + ## - `isRefObject`: true only for inline `ref object` definitions + ## - `hasInlineFields`: true for inline `object` / `ref object` + ## - `fieldNames`/`fieldTypes`: populated only when `collectFieldInfo = true` + typeIdent*: NimNode + objectDef*: NimNode + isRefObject*: bool + hasInlineFields*: bool + fieldNames*: seq[NimNode] + fieldTypes*: seq[NimNode] + proc sanitizeIdentName*(node: NimNode): string = var raw = $node var sanitizedName = newStringOfCap(raw.len) @@ -41,3 +57,150 @@ proc baseTypeIdent*(defName: NimNode): NimNode = baseTypeIdent(defName[0]) else: error("Unsupported type name in broker definition", defName) + +proc ensureDistinctType*(rhs: NimNode): NimNode = + ## For PODs / aliases / externally-defined types, wrap in `distinct` unless + ## it's already distinct. + if rhs.kind == nnkDistinctTy: + return copyNimTree(rhs) + newTree(nnkDistinctTy, copyNimTree(rhs)) + +proc cloneParams*(params: seq[NimNode]): seq[NimNode] = + ## Deep copy parameter definitions so they can be inserted in multiple places. + result = @[] + for param in params: + result.add(copyNimTree(param)) + +proc collectParamNames*(params: seq[NimNode]): seq[NimNode] = + ## Extract all identifier symbols declared across IdentDefs nodes. + result = @[] + for param in params: + assert param.kind == nnkIdentDefs + for i in 0 ..< param.len - 2: + let nameNode = param[i] + if nameNode.kind == nnkEmpty: + continue + result.add(ident($nameNode)) + +proc parseSingleTypeDef*( + body: NimNode, + macroName: string, + allowRefToNonObject = false, + collectFieldInfo = false, +): ParsedBrokerType = + ## Parses exactly one `type` definition from a broker macro body. + ## + ## Supported RHS: + ## - inline `object` / `ref object` (fields are auto-exported) + ## - non-object types / aliases / externally-defined types (wrapped in `distinct`) + ## - optionally: `ref SomeType` when `allowRefToNonObject = true` + var typeIdent: NimNode = nil + var objectDef: NimNode = nil + var isRefObject = false + var hasInlineFields = false + var fieldNames: seq[NimNode] = @[] + var fieldTypes: seq[NimNode] = @[] + + for stmt in body: + if stmt.kind != nnkTypeSection: + continue + for def in stmt: + if def.kind != nnkTypeDef: + continue + if not typeIdent.isNil(): + error("Only one type may be declared inside " & macroName, def) + typeIdent = baseTypeIdent(def[0]) + let rhs = def[2] + + case rhs.kind + of nnkObjectTy: + let recList = rhs[2] + if recList.kind != nnkRecList: + error(macroName & " object must declare a standard field list", rhs) + var exportedRecList = newTree(nnkRecList) + for field in recList: + case field.kind + of nnkIdentDefs: + ensureFieldDef(field) + if collectFieldInfo: + let fieldTypeNode = field[field.len - 2] + for i in 0 ..< field.len - 2: + let baseFieldIdent = baseTypeIdent(field[i]) + fieldNames.add(copyNimTree(baseFieldIdent)) + fieldTypes.add(copyNimTree(fieldTypeNode)) + var cloned = copyNimTree(field) + for i in 0 ..< cloned.len - 2: + cloned[i] = exportIdentNode(cloned[i]) + exportedRecList.add(cloned) + of nnkEmpty: + discard + else: + error( + macroName & " object definition only supports simple field declarations", + field, + ) + objectDef = newTree( + nnkObjectTy, copyNimTree(rhs[0]), copyNimTree(rhs[1]), exportedRecList + ) + isRefObject = false + hasInlineFields = true + of nnkRefTy: + if rhs.len != 1: + error(macroName & " ref type must have a single base", rhs) + if rhs[0].kind == nnkObjectTy: + let obj = rhs[0] + let recList = obj[2] + if recList.kind != nnkRecList: + error(macroName & " object must declare a standard field list", obj) + var exportedRecList = newTree(nnkRecList) + for field in recList: + case field.kind + of nnkIdentDefs: + ensureFieldDef(field) + if collectFieldInfo: + let fieldTypeNode = field[field.len - 2] + for i in 0 ..< field.len - 2: + let baseFieldIdent = baseTypeIdent(field[i]) + fieldNames.add(copyNimTree(baseFieldIdent)) + fieldTypes.add(copyNimTree(fieldTypeNode)) + var cloned = copyNimTree(field) + for i in 0 ..< cloned.len - 2: + cloned[i] = exportIdentNode(cloned[i]) + exportedRecList.add(cloned) + of nnkEmpty: + discard + else: + error( + macroName & " object definition only supports simple field declarations", + field, + ) + let exportedObjectType = newTree( + nnkObjectTy, copyNimTree(obj[0]), copyNimTree(obj[1]), exportedRecList + ) + objectDef = newTree(nnkRefTy, exportedObjectType) + isRefObject = true + hasInlineFields = true + elif allowRefToNonObject: + ## `ref SomeType` (SomeType can be defined elsewhere) + objectDef = ensureDistinctType(rhs) + isRefObject = false + hasInlineFields = false + else: + error(macroName & " ref object must wrap a concrete object definition", rhs) + else: + ## Non-object type / alias. + objectDef = ensureDistinctType(rhs) + isRefObject = false + hasInlineFields = false + + if typeIdent.isNil(): + error(macroName & " body must declare exactly one type", body) + + result = ParsedBrokerType( + typeIdent: typeIdent, + objectDef: objectDef, + isRefObject: isRefObject, + hasInlineFields: hasInlineFields, + fieldNames: fieldNames, + fieldTypes: fieldTypes, + ) diff --git a/waku/common/broker/multi_request_broker.nim b/waku/common/broker/multi_request_broker.nim index 7f4161f5a..2baa19940 100644 --- a/waku/common/broker/multi_request_broker.nim +++ b/waku/common/broker/multi_request_broker.nim @@ -5,12 +5,35 @@ ## need for direct dependencies in between. ## Worth considering using it for use cases where you need to collect data from multiple providers. ## -## Provides a declarative way to define an immutable value type together with a -## thread-local broker that can register multiple asynchronous providers, dispatch -## typed requests, and clear handlers. Unlike `RequestBroker`, -## every call to `request` fan-outs to every registered provider and returns with -## collected responses. -## Request succeeds if all providers succeed, otherwise fails with an error. +## Generates a standalone, type-safe request broker for the declared type. +## The macro exports the value type itself plus a broker companion that manages +## providers via thread-local storage. +## +## Unlike `RequestBroker`, every call to `request` fan-outs to every registered +## provider and returns all collected responses. +## The request succeeds only if all providers succeed, otherwise it fails. +## +## Type definitions: +## - Inline `object` / `ref object` definitions are supported. +## - Native types, aliases, and externally-defined types are also supported. +## In that case, MultiRequestBroker will automatically wrap the declared RHS +## type in `distinct` unless you already used `distinct`. +## This keeps request types unique even when multiple brokers share the same +## underlying base type. +## +## Default vs. context aware use: +## Every generated broker is a thread-local global instance. +## Sometimes you want multiple independent provider sets for the same request +## type within the same thread (e.g. multiple components). For that, you can use +## context-aware MultiRequestBroker. +## +## Context awareness is supported through the `BrokerContext` argument for +## `setProvider`, `request`, `removeProvider`, and `clearProviders`. +## Provider stores are kept separate per broker context. +## +## Default broker context is defined as `DefaultBrokerContext`. If you don't +## need context awareness, you can keep using the interfaces without the context +## argument, which operate on `DefaultBrokerContext`. ## ## Usage: ## @@ -29,14 +52,17 @@ ## ## ``` ## -## You regiser request processor (proveder) at any place of the code without the need to know of who ever may request. -## Respectively to the defined signatures register provider functions with `TypeName.setProvider(...)`. -## Providers are async procs or lambdas that return with a Future[Result[seq[TypeName], string]]. -## Notice MultiRequestBroker's `setProvider` return with a handler that can be used to remove the provider later (or error). +## You can register a request processor (provider) anywhere without the need to +## know who will request. +## Register provider functions with `TypeName.setProvider(...)`. +## Providers are async procs or lambdas that return `Future[Result[TypeName, string]]`. +## `setProvider` returns a handle (or an error) that can later be used to remove +## the provider. -## Requests can be made from anywhere with no direct dependency on the provider(s) by -## calling `TypeName.request()` - with arguments respecting the signature(s). -## This will asynchronously call the registered provider and return the collected data, in form of `Future[Result[seq[TypeName], string]]`. +## Requests can be made from anywhere with no direct dependency on the provider(s) +## by calling `TypeName.request()` (with arguments respecting the declared signature). +## This will asynchronously call all registered providers and return the collected +## responses as `Future[Result[seq[TypeName], string]]`. ## ## Whenever you don't want to process requests anymore (or your object instance that provides the request goes out of scope), ## you can remove it from the broker with `TypeName.removeProvider(handle)`. @@ -77,8 +103,9 @@ import std/[macros, strutils, tables, sugar] import chronos import results import ./helper/broker_utils +import ./broker_context -export results, chronos +export results, chronos, broker_context proc isReturnTypeValid(returnType, typeIdent: NimNode): bool = ## Accept Future[Result[TypeIdent, string]] as the contract. @@ -95,23 +122,6 @@ proc isReturnTypeValid(returnType, typeIdent: NimNode): bool = return false inner[2].kind == nnkIdent and inner[2].eqIdent("string") -proc cloneParams(params: seq[NimNode]): seq[NimNode] = - ## Deep copy parameter definitions so they can be reused in generated nodes. - result = @[] - for param in params: - result.add(copyNimTree(param)) - -proc collectParamNames(params: seq[NimNode]): seq[NimNode] = - ## Extract identifiers declared in parameter definitions. - result = @[] - for param in params: - assert param.kind == nnkIdentDefs - for i in 0 ..< param.len - 2: - let nameNode = param[i] - if nameNode.kind == nnkEmpty: - continue - result.add(ident($nameNode)) - proc makeProcType(returnType: NimNode, params: seq[NimNode]): NimNode = var formal = newTree(nnkFormalParams) formal.add(returnType) @@ -126,65 +136,10 @@ proc makeProcType(returnType: NimNode, params: seq[NimNode]): NimNode = macro MultiRequestBroker*(body: untyped): untyped = when defined(requestBrokerDebug): echo body.treeRepr - var typeIdent: NimNode = nil - var objectDef: NimNode = nil - var isRefObject = false - for stmt in body: - if stmt.kind == nnkTypeSection: - for def in stmt: - if def.kind != nnkTypeDef: - continue - let rhs = def[2] - var objectType: NimNode - case rhs.kind - of nnkObjectTy: - objectType = rhs - of nnkRefTy: - isRefObject = true - if rhs.len != 1 or rhs[0].kind != nnkObjectTy: - error( - "MultiRequestBroker ref object must wrap a concrete object definition", - rhs, - ) - objectType = rhs[0] - else: - continue - if not typeIdent.isNil(): - error("Only one object type may be declared inside MultiRequestBroker", def) - typeIdent = baseTypeIdent(def[0]) - let recList = objectType[2] - if recList.kind != nnkRecList: - error( - "MultiRequestBroker object must declare a standard field list", objectType - ) - var exportedRecList = newTree(nnkRecList) - for field in recList: - case field.kind - of nnkIdentDefs: - ensureFieldDef(field) - var cloned = copyNimTree(field) - for i in 0 ..< cloned.len - 2: - cloned[i] = exportIdentNode(cloned[i]) - exportedRecList.add(cloned) - of nnkEmpty: - discard - else: - error( - "MultiRequestBroker object definition only supports simple field declarations", - field, - ) - let exportedObjectType = newTree( - nnkObjectTy, - copyNimTree(objectType[0]), - copyNimTree(objectType[1]), - exportedRecList, - ) - if isRefObject: - objectDef = newTree(nnkRefTy, exportedObjectType) - else: - objectDef = exportedObjectType - if typeIdent.isNil(): - error("MultiRequestBroker body must declare exactly one object type", body) + let parsed = parseSingleTypeDef(body, "MultiRequestBroker") + let typeIdent = parsed.typeIdent + let objectDef = parsed.objectDef + let isRefObject = parsed.isRefObject when defined(requestBrokerDebug): echo "MultiRequestBroker generating type: ", $typeIdent @@ -193,12 +148,13 @@ macro MultiRequestBroker*(body: untyped): untyped = let sanitized = sanitizeIdentName(typeIdent) let typeNameLit = newLit($typeIdent) let isRefObjectLit = newLit(isRefObject) - let tableSym = bindSym"Table" - let initTableSym = bindSym"initTable" let uint64Ident = ident("uint64") let providerKindIdent = ident(sanitized & "ProviderKind") let providerHandleIdent = ident(sanitized & "ProviderHandle") let exportedProviderHandleIdent = postfix(copyNimTree(providerHandleIdent), "*") + let bucketTypeIdent = ident(sanitized & "CtxBucket") + let findBucketIdxIdent = ident(sanitized & "FindBucketIdx") + let getOrCreateBucketIdxIdent = ident(sanitized & "GetOrCreateBucketIdx") let zeroKindIdent = ident("pk" & sanitized & "NoArgs") let argKindIdent = ident("pk" & sanitized & "WithArgs") var zeroArgSig: NimNode = nil @@ -306,63 +262,90 @@ macro MultiRequestBroker*(body: untyped): untyped = let procType = makeProcType(returnType, cloneParams(argParams)) typeSection.add(newTree(nnkTypeDef, argProviderName, newEmptyNode(), procType)) - var brokerRecList = newTree(nnkRecList) + var bucketRecList = newTree(nnkRecList) + bucketRecList.add( + newTree(nnkIdentDefs, ident("brokerCtx"), ident("BrokerContext"), newEmptyNode()) + ) if not zeroArgSig.isNil(): - brokerRecList.add( + bucketRecList.add( newTree( nnkIdentDefs, zeroArgFieldName, - newTree(nnkBracketExpr, tableSym, uint64Ident, zeroArgProviderName), + newTree(nnkBracketExpr, ident("seq"), zeroArgProviderName), newEmptyNode(), ) ) if not argSig.isNil(): - brokerRecList.add( + bucketRecList.add( newTree( nnkIdentDefs, argFieldName, - newTree(nnkBracketExpr, tableSym, uint64Ident, argProviderName), + newTree(nnkBracketExpr, ident("seq"), argProviderName), newEmptyNode(), ) ) - brokerRecList.add(newTree(nnkIdentDefs, ident("nextId"), uint64Ident, newEmptyNode())) - let brokerTypeIdent = ident(sanitizeIdentName(typeIdent) & "Broker") - let brokerTypeDef = newTree( - nnkTypeDef, - brokerTypeIdent, - newEmptyNode(), + typeSection.add( newTree( - nnkRefTy, newTree(nnkObjectTy, newEmptyNode(), newEmptyNode(), brokerRecList) - ), + nnkTypeDef, + bucketTypeIdent, + newEmptyNode(), + newTree(nnkObjectTy, newEmptyNode(), newEmptyNode(), bucketRecList), + ) + ) + + var brokerRecList = newTree(nnkRecList) + brokerRecList.add( + newTree( + nnkIdentDefs, + ident("buckets"), + newTree(nnkBracketExpr, ident("seq"), bucketTypeIdent), + newEmptyNode(), + ) + ) + let brokerTypeIdent = ident(sanitizeIdentName(typeIdent) & "Broker") + typeSection.add( + newTree( + nnkTypeDef, + brokerTypeIdent, + newEmptyNode(), + newTree( + nnkRefTy, newTree(nnkObjectTy, newEmptyNode(), newEmptyNode(), brokerRecList) + ), + ) ) - typeSection.add(brokerTypeDef) result = newStmtList() result.add(typeSection) let globalVarIdent = ident("g" & sanitizeIdentName(typeIdent) & "Broker") let accessProcIdent = ident("access" & sanitizeIdentName(typeIdent) & "Broker") - var initStatements = newStmtList() - if not zeroArgSig.isNil(): - initStatements.add( - quote do: - `globalVarIdent`.`zeroArgFieldName` = - `initTableSym`[`uint64Ident`, `zeroArgProviderName`]() - ) - if not argSig.isNil(): - initStatements.add( - quote do: - `globalVarIdent`.`argFieldName` = - `initTableSym`[`uint64Ident`, `argProviderName`]() - ) result.add( quote do: var `globalVarIdent` {.threadvar.}: `brokerTypeIdent` + proc `findBucketIdxIdent`( + broker: `brokerTypeIdent`, brokerCtx: BrokerContext + ): int = + if brokerCtx == DefaultBrokerContext: + return 0 + for i in 1 ..< broker.buckets.len: + if broker.buckets[i].brokerCtx == brokerCtx: + return i + return -1 + + proc `getOrCreateBucketIdxIdent`( + broker: `brokerTypeIdent`, brokerCtx: BrokerContext + ): int = + let idx = `findBucketIdxIdent`(broker, brokerCtx) + if idx >= 0: + return idx + broker.buckets.add(`bucketTypeIdent`(brokerCtx: brokerCtx)) + return broker.buckets.high + proc `accessProcIdent`(): `brokerTypeIdent` = if `globalVarIdent`.isNil(): new(`globalVarIdent`) - `globalVarIdent`.nextId = 1'u64 - `initStatements` + `globalVarIdent`.buckets = + @[`bucketTypeIdent`(brokerCtx: DefaultBrokerContext)] return `globalVarIdent` ) @@ -372,40 +355,47 @@ macro MultiRequestBroker*(body: untyped): untyped = result.add( quote do: proc setProvider*( - _: typedesc[`typeIdent`], handler: `zeroArgProviderName` + _: typedesc[`typeIdent`], + brokerCtx: BrokerContext, + handler: `zeroArgProviderName`, ): Result[`providerHandleIdent`, string] = if handler.isNil(): return err("Provider handler must be provided") let broker = `accessProcIdent`() - if broker.nextId == 0'u64: - broker.nextId = 1'u64 - for existingId, existing in broker.`zeroArgFieldName`.pairs: - if existing == handler: - return ok(`providerHandleIdent`(id: existingId, kind: `zeroKindIdent`)) - let newId = broker.nextId - inc broker.nextId - broker.`zeroArgFieldName`[newId] = handler - return ok(`providerHandleIdent`(id: newId, kind: `zeroKindIdent`)) + let bucketIdx = `getOrCreateBucketIdxIdent`(broker, brokerCtx) + for i, existing in broker.buckets[bucketIdx].`zeroArgFieldName`: + if not existing.isNil() and existing == handler: + return ok(`providerHandleIdent`(id: uint64(i + 1), kind: `zeroKindIdent`)) + broker.buckets[bucketIdx].`zeroArgFieldName`.add(handler) + return ok( + `providerHandleIdent`( + id: uint64(broker.buckets[bucketIdx].`zeroArgFieldName`.len), + kind: `zeroKindIdent`, + ) + ) + + proc setProvider*( + _: typedesc[`typeIdent`], handler: `zeroArgProviderName` + ): Result[`providerHandleIdent`, string] = + return setProvider(`typeIdent`, DefaultBrokerContext, handler) - ) - clearBody.add( - quote do: - let broker = `accessProcIdent`() - if not broker.isNil() and broker.`zeroArgFieldName`.len > 0: - broker.`zeroArgFieldName`.clear() ) result.add( quote do: proc request*( - _: typedesc[`typeIdent`] + _: typedesc[`typeIdent`], brokerCtx: BrokerContext ): Future[Result[seq[`typeIdent`], string]] {.async: (raises: []), gcsafe.} = var aggregated: seq[`typeIdent`] = @[] - let providers = `accessProcIdent`().`zeroArgFieldName` + let broker = `accessProcIdent`() + let bucketIdx = `findBucketIdxIdent`(broker, brokerCtx) + if bucketIdx < 0: + return ok(aggregated) + let providers = broker.buckets[bucketIdx].`zeroArgFieldName` if providers.len == 0: return ok(aggregated) # var providersFut: seq[Future[Result[`typeIdent`, string]]] = collect: var providersFut = collect(newSeq): - for provider in providers.values: + for provider in providers: if provider.isNil(): continue provider() @@ -435,32 +425,40 @@ macro MultiRequestBroker*(body: untyped): untyped = return ok(aggregated) + proc request*( + _: typedesc[`typeIdent`] + ): Future[Result[seq[`typeIdent`], string]] = + return request(`typeIdent`, DefaultBrokerContext) + ) if not argSig.isNil(): result.add( quote do: proc setProvider*( - _: typedesc[`typeIdent`], handler: `argProviderName` + _: typedesc[`typeIdent`], + brokerCtx: BrokerContext, + handler: `argProviderName`, ): Result[`providerHandleIdent`, string] = if handler.isNil(): return err("Provider handler must be provided") let broker = `accessProcIdent`() - if broker.nextId == 0'u64: - broker.nextId = 1'u64 - for existingId, existing in broker.`argFieldName`.pairs: - if existing == handler: - return ok(`providerHandleIdent`(id: existingId, kind: `argKindIdent`)) - let newId = broker.nextId - inc broker.nextId - broker.`argFieldName`[newId] = handler - return ok(`providerHandleIdent`(id: newId, kind: `argKindIdent`)) + let bucketIdx = `getOrCreateBucketIdxIdent`(broker, brokerCtx) + for i, existing in broker.buckets[bucketIdx].`argFieldName`: + if not existing.isNil() and existing == handler: + return ok(`providerHandleIdent`(id: uint64(i + 1), kind: `argKindIdent`)) + broker.buckets[bucketIdx].`argFieldName`.add(handler) + return ok( + `providerHandleIdent`( + id: uint64(broker.buckets[bucketIdx].`argFieldName`.len), + kind: `argKindIdent`, + ) + ) + + proc setProvider*( + _: typedesc[`typeIdent`], handler: `argProviderName` + ): Result[`providerHandleIdent`, string] = + return setProvider(`typeIdent`, DefaultBrokerContext, handler) - ) - clearBody.add( - quote do: - let broker = `accessProcIdent`() - if not broker.isNil() and broker.`argFieldName`.len > 0: - broker.`argFieldName`.clear() ) let requestParamDefs = cloneParams(argParams) let argNameIdents = collectParamNames(requestParamDefs) @@ -481,17 +479,24 @@ macro MultiRequestBroker*(body: untyped): untyped = newEmptyNode(), ) ) + formalParams.add( + newTree(nnkIdentDefs, ident("brokerCtx"), ident("BrokerContext"), newEmptyNode()) + ) for paramDef in requestParamDefs: formalParams.add(paramDef) let requestPragmas = quote: {.async: (raises: []), gcsafe.} let requestBody = quote: var aggregated: seq[`typeIdent`] = @[] - let providers = `accessProcIdent`().`argFieldName` + let broker = `accessProcIdent`() + let bucketIdx = `findBucketIdxIdent`(broker, brokerCtx) + if bucketIdx < 0: + return ok(aggregated) + let providers = broker.buckets[bucketIdx].`argFieldName` if providers.len == 0: return ok(aggregated) var providersFut = collect(newSeq): - for provider in providers.values: + for provider in providers: if provider.isNil(): continue let `providerSym` = provider @@ -531,53 +536,208 @@ macro MultiRequestBroker*(body: untyped): untyped = ) ) - result.add( - quote do: - proc clearProviders*(_: typedesc[`typeIdent`]) = - `clearBody` - let broker = `accessProcIdent`() - if not broker.isNil(): - broker.nextId = 1'u64 - - ) - - let removeHandleSym = genSym(nskParam, "handle") - let removeBrokerSym = genSym(nskLet, "broker") - var removeBody = newStmtList() - removeBody.add( - quote do: - if `removeHandleSym`.id == 0'u64: - return - let `removeBrokerSym` = `accessProcIdent`() - if `removeBrokerSym`.isNil(): - return - ) - if not zeroArgSig.isNil(): - removeBody.add( + # Backward-compatible default-context overload (no brokerCtx parameter). + var formalParamsDefault = newTree(nnkFormalParams) + formalParamsDefault.add( quote do: - if `removeHandleSym`.kind == `zeroKindIdent`: - `removeBrokerSym`.`zeroArgFieldName`.del(`removeHandleSym`.id) - return + Future[Result[seq[`typeIdent`], string]] ) - if not argSig.isNil(): - removeBody.add( - quote do: - if `removeHandleSym`.kind == `argKindIdent`: - `removeBrokerSym`.`argFieldName`.del(`removeHandleSym`.id) - return + formalParamsDefault.add( + newTree( + nnkIdentDefs, + ident("_"), + newTree(nnkBracketExpr, ident("typedesc"), copyNimTree(typeIdent)), + newEmptyNode(), + ) ) - removeBody.add( - quote do: - discard - ) - result.add( - quote do: - proc removeProvider*( - _: typedesc[`typeIdent`], `removeHandleSym`: `providerHandleIdent` - ) = - `removeBody` + for paramDef in requestParamDefs: + formalParamsDefault.add(copyNimTree(paramDef)) - ) + var wrapperCall = newCall(ident("request")) + wrapperCall.add(copyNimTree(typeIdent)) + wrapperCall.add(ident("DefaultBrokerContext")) + for argName in argNameIdents: + wrapperCall.add(copyNimTree(argName)) + + result.add( + newTree( + nnkProcDef, + postfix(ident("request"), "*"), + newEmptyNode(), + newEmptyNode(), + formalParamsDefault, + newEmptyNode(), + newEmptyNode(), + newStmtList(newTree(nnkReturnStmt, wrapperCall)), + ) + ) + let removeHandleCtxSym = genSym(nskParam, "handle") + let removeHandleDefaultSym = genSym(nskParam, "handle") + + when true: + # Generate clearProviders / removeProvider with macro-time knowledge about which + # provider lists exist (zero-arg and/or arg providers). + if not zeroArgSig.isNil() and not argSig.isNil(): + result.add( + quote do: + proc clearProviders*(_: typedesc[`typeIdent`], brokerCtx: BrokerContext) = + let broker = `accessProcIdent`() + if broker.isNil(): + return + let bucketIdx = `findBucketIdxIdent`(broker, brokerCtx) + if bucketIdx < 0: + return + broker.buckets[bucketIdx].`zeroArgFieldName`.setLen(0) + broker.buckets[bucketIdx].`argFieldName`.setLen(0) + if brokerCtx != DefaultBrokerContext: + broker.buckets.delete(bucketIdx) + + proc clearProviders*(_: typedesc[`typeIdent`]) = + clearProviders(`typeIdent`, DefaultBrokerContext) + + proc removeProvider*( + _: typedesc[`typeIdent`], + brokerCtx: BrokerContext, + `removeHandleCtxSym`: `providerHandleIdent`, + ) = + if `removeHandleCtxSym`.id == 0'u64: + return + let broker = `accessProcIdent`() + if broker.isNil(): + return + let bucketIdx = `findBucketIdxIdent`(broker, brokerCtx) + if bucketIdx < 0: + return + + if `removeHandleCtxSym`.kind == `zeroKindIdent`: + let idx = int(`removeHandleCtxSym`.id) - 1 + if idx >= 0 and idx < broker.buckets[bucketIdx].`zeroArgFieldName`.len: + broker.buckets[bucketIdx].`zeroArgFieldName`[idx] = nil + elif `removeHandleCtxSym`.kind == `argKindIdent`: + let idx = int(`removeHandleCtxSym`.id) - 1 + if idx >= 0 and idx < broker.buckets[bucketIdx].`argFieldName`.len: + broker.buckets[bucketIdx].`argFieldName`[idx] = nil + + if brokerCtx != DefaultBrokerContext: + var hasAny = false + for p in broker.buckets[bucketIdx].`zeroArgFieldName`: + if not p.isNil(): + hasAny = true + break + if not hasAny: + for p in broker.buckets[bucketIdx].`argFieldName`: + if not p.isNil(): + hasAny = true + break + if not hasAny: + broker.buckets.delete(bucketIdx) + + proc removeProvider*( + _: typedesc[`typeIdent`], `removeHandleDefaultSym`: `providerHandleIdent` + ) = + removeProvider(`typeIdent`, DefaultBrokerContext, `removeHandleDefaultSym`) + + ) + elif not zeroArgSig.isNil(): + result.add( + quote do: + proc clearProviders*(_: typedesc[`typeIdent`], brokerCtx: BrokerContext) = + let broker = `accessProcIdent`() + if broker.isNil(): + return + let bucketIdx = `findBucketIdxIdent`(broker, brokerCtx) + if bucketIdx < 0: + return + broker.buckets[bucketIdx].`zeroArgFieldName`.setLen(0) + if brokerCtx != DefaultBrokerContext: + broker.buckets.delete(bucketIdx) + + proc clearProviders*(_: typedesc[`typeIdent`]) = + clearProviders(`typeIdent`, DefaultBrokerContext) + + proc removeProvider*( + _: typedesc[`typeIdent`], + brokerCtx: BrokerContext, + `removeHandleCtxSym`: `providerHandleIdent`, + ) = + if `removeHandleCtxSym`.id == 0'u64: + return + let broker = `accessProcIdent`() + if broker.isNil(): + return + let bucketIdx = `findBucketIdxIdent`(broker, brokerCtx) + if bucketIdx < 0: + return + if `removeHandleCtxSym`.kind != `zeroKindIdent`: + return + let idx = int(`removeHandleCtxSym`.id) - 1 + if idx >= 0 and idx < broker.buckets[bucketIdx].`zeroArgFieldName`.len: + broker.buckets[bucketIdx].`zeroArgFieldName`[idx] = nil + if brokerCtx != DefaultBrokerContext: + var hasAny = false + for p in broker.buckets[bucketIdx].`zeroArgFieldName`: + if not p.isNil(): + hasAny = true + break + if not hasAny: + broker.buckets.delete(bucketIdx) + + proc removeProvider*( + _: typedesc[`typeIdent`], `removeHandleDefaultSym`: `providerHandleIdent` + ) = + removeProvider(`typeIdent`, DefaultBrokerContext, `removeHandleDefaultSym`) + + ) + else: + result.add( + quote do: + proc clearProviders*(_: typedesc[`typeIdent`], brokerCtx: BrokerContext) = + let broker = `accessProcIdent`() + if broker.isNil(): + return + let bucketIdx = `findBucketIdxIdent`(broker, brokerCtx) + if bucketIdx < 0: + return + broker.buckets[bucketIdx].`argFieldName`.setLen(0) + if brokerCtx != DefaultBrokerContext: + broker.buckets.delete(bucketIdx) + + proc clearProviders*(_: typedesc[`typeIdent`]) = + clearProviders(`typeIdent`, DefaultBrokerContext) + + proc removeProvider*( + _: typedesc[`typeIdent`], + brokerCtx: BrokerContext, + `removeHandleCtxSym`: `providerHandleIdent`, + ) = + if `removeHandleCtxSym`.id == 0'u64: + return + let broker = `accessProcIdent`() + if broker.isNil(): + return + let bucketIdx = `findBucketIdxIdent`(broker, brokerCtx) + if bucketIdx < 0: + return + if `removeHandleCtxSym`.kind != `argKindIdent`: + return + let idx = int(`removeHandleCtxSym`.id) - 1 + if idx >= 0 and idx < broker.buckets[bucketIdx].`argFieldName`.len: + broker.buckets[bucketIdx].`argFieldName`[idx] = nil + if brokerCtx != DefaultBrokerContext: + var hasAny = false + for p in broker.buckets[bucketIdx].`argFieldName`: + if not p.isNil(): + hasAny = true + break + if not hasAny: + broker.buckets.delete(bucketIdx) + + proc removeProvider*( + _: typedesc[`typeIdent`], `removeHandleDefaultSym`: `providerHandleIdent` + ) = + removeProvider(`typeIdent`, DefaultBrokerContext, `removeHandleDefaultSym`) + + ) when defined(requestBrokerDebug): echo result.repr diff --git a/waku/common/broker/request_broker.nim b/waku/common/broker/request_broker.nim index dece77381..46f7d7d16 100644 --- a/waku/common/broker/request_broker.nim +++ b/waku/common/broker/request_broker.nim @@ -16,6 +16,18 @@ ## `async` mode is better to be used when you request date that may involve some long IO operation ## or action. ## +## Default vs. context aware use: +## Every generated broker is a thread-local global instance. This means each RequestBroker enables decoupled +## data exchange threadwise. Sometimes we use brokers inside a context - like inside a component that has many modules or subsystems. +## In case you would instantiate multiple such components in a single thread, and each component must has its own provider for the same RequestBroker type, +## in order to avoid provider collision, you can use context aware RequestBroker. +## Context awareness is supported through the `BrokerContext` argument for `setProvider`, `request`, `clearProvider` interfaces. +## Suce use requires generating a new unique `BrokerContext` value per component instance, and spread it to all modules using the brokers. +## Example, store the `BrokerContext` as a field inside the top level component instance, and spread around at initialization of the subcomponents.. +## +## Default broker context is defined as `DefaultBrokerContext` constant. But if you don't need context awareness, you can use the +## interfaces without context argument. +## ## Usage: ## Declare your desired request type inside a `RequestBroker` macro, add any number of fields. ## Define the provider signature, that is enforced at compile time. @@ -89,7 +101,13 @@ ## After this, you can register a provider anywhere in your code with ## `TypeName.setProvider(...)`, which returns error if already having a provider. ## Providers are async procs/lambdas in default mode and sync procs in sync mode. -## Only one provider can be registered at a time per signature type (zero arg and/or multi arg). +## +## Providers are stored as a broker-context keyed list: +## - the default provider is always stored at index 0 (reserved broker context: 0) +## - additional providers can be registered under arbitrary non-zero broker contexts +## +## The original `setProvider(handler)` / `request(...)` APIs continue to operate +## on the default provider (broker context 0) for backward compatibility. ## ## Requests can be made from anywhere with no direct dependency on the provider by ## calling `TypeName.request()` - with arguments respecting the signature(s). @@ -139,11 +157,12 @@ ## automatically, so the caller only needs to provide the type definition. import std/[macros, strutils] +from std/sequtils import keepItIf import chronos import results -import ./helper/broker_utils +import ./helper/broker_utils, broker_context -export results, chronos +export results, chronos, keepItIf, broker_context proc errorFuture[T](message: string): Future[Result[T, string]] {.inline.} = ## Build a future that is already completed with an error result. @@ -187,23 +206,6 @@ proc isReturnTypeValid(returnType, typeIdent: NimNode, mode: RequestBrokerMode): of rbSync: isSyncReturnTypeValid(returnType, typeIdent) -proc cloneParams(params: seq[NimNode]): seq[NimNode] = - ## Deep copy parameter definitions so they can be inserted in multiple places. - result = @[] - for param in params: - result.add(copyNimTree(param)) - -proc collectParamNames(params: seq[NimNode]): seq[NimNode] = - ## Extract all identifier symbols declared across IdentDefs nodes. - result = @[] - for param in params: - assert param.kind == nnkIdentDefs - for i in 0 ..< param.len - 2: - let nameNode = param[i] - if nameNode.kind == nnkEmpty: - continue - result.add(ident($nameNode)) - proc makeProcType( returnType: NimNode, params: seq[NimNode], mode: RequestBrokerMode ): NimNode = @@ -234,92 +236,13 @@ proc parseMode(modeNode: NimNode): RequestBrokerMode = else: error("RequestBroker mode must be `sync` or `async` (default is async)", modeNode) -proc ensureDistinctType(rhs: NimNode): NimNode = - ## For PODs / aliases / externally-defined types, wrap in `distinct` unless - ## it's already distinct. - if rhs.kind == nnkDistinctTy: - return copyNimTree(rhs) - newTree(nnkDistinctTy, copyNimTree(rhs)) - proc generateRequestBroker(body: NimNode, mode: RequestBrokerMode): NimNode = when defined(requestBrokerDebug): echo body.treeRepr echo "RequestBroker mode: ", $mode - var typeIdent: NimNode = nil - var objectDef: NimNode = nil - for stmt in body: - if stmt.kind == nnkTypeSection: - for def in stmt: - if def.kind != nnkTypeDef: - continue - if not typeIdent.isNil(): - error("Only one type may be declared inside RequestBroker", def) - - typeIdent = baseTypeIdent(def[0]) - let rhs = def[2] - - ## Support inline object types (fields are auto-exported) - ## AND non-object types / aliases (e.g. `string`, `int`, `OtherType`). - case rhs.kind - of nnkObjectTy: - let recList = rhs[2] - if recList.kind != nnkRecList: - error("RequestBroker object must declare a standard field list", rhs) - var exportedRecList = newTree(nnkRecList) - for field in recList: - case field.kind - of nnkIdentDefs: - ensureFieldDef(field) - var cloned = copyNimTree(field) - for i in 0 ..< cloned.len - 2: - cloned[i] = exportIdentNode(cloned[i]) - exportedRecList.add(cloned) - of nnkEmpty: - discard - else: - error( - "RequestBroker object definition only supports simple field declarations", - field, - ) - objectDef = newTree( - nnkObjectTy, copyNimTree(rhs[0]), copyNimTree(rhs[1]), exportedRecList - ) - of nnkRefTy: - if rhs.len != 1: - error("RequestBroker ref type must have a single base", rhs) - if rhs[0].kind == nnkObjectTy: - let obj = rhs[0] - let recList = obj[2] - if recList.kind != nnkRecList: - error("RequestBroker object must declare a standard field list", obj) - var exportedRecList = newTree(nnkRecList) - for field in recList: - case field.kind - of nnkIdentDefs: - ensureFieldDef(field) - var cloned = copyNimTree(field) - for i in 0 ..< cloned.len - 2: - cloned[i] = exportIdentNode(cloned[i]) - exportedRecList.add(cloned) - of nnkEmpty: - discard - else: - error( - "RequestBroker object definition only supports simple field declarations", - field, - ) - let exportedObjectType = newTree( - nnkObjectTy, copyNimTree(obj[0]), copyNimTree(obj[1]), exportedRecList - ) - objectDef = newTree(nnkRefTy, exportedObjectType) - else: - ## `ref SomeType` (SomeType can be defined elsewhere) - objectDef = ensureDistinctType(rhs) - else: - ## Non-object type / alias (e.g. `string`, `int`, `SomeExternalType`). - objectDef = ensureDistinctType(rhs) - if typeIdent.isNil(): - error("RequestBroker body must declare exactly one type", body) + let parsed = parseSingleTypeDef(body, "RequestBroker", allowRefToNonObject = true) + let typeIdent = parsed.typeIdent + let objectDef = parsed.objectDef when defined(requestBrokerDebug): echo "RequestBroker generating type: ", $typeIdent @@ -329,11 +252,9 @@ proc generateRequestBroker(body: NimNode, mode: RequestBrokerMode): NimNode = let typeNameLit = newLit(typeDisplayName) var zeroArgSig: NimNode = nil var zeroArgProviderName: NimNode = nil - var zeroArgFieldName: NimNode = nil var argSig: NimNode = nil var argParams: seq[NimNode] = @[] var argProviderName: NimNode = nil - var argFieldName: NimNode = nil for stmt in body: case stmt.kind @@ -368,7 +289,6 @@ proc generateRequestBroker(body: NimNode, mode: RequestBrokerMode): NimNode = error("Only one zero-argument signature is allowed", stmt) zeroArgSig = stmt zeroArgProviderName = ident(sanitizeIdentName(typeIdent) & "ProviderNoArgs") - zeroArgFieldName = ident("providerNoArgs") elif paramCount >= 1: if argSig != nil: error("Only one argument-based signature is allowed", stmt) @@ -391,7 +311,6 @@ proc generateRequestBroker(body: NimNode, mode: RequestBrokerMode): NimNode = error("Signature parameter must declare a name", paramDef) argParams.add(copyNimTree(paramDef)) argProviderName = ident(sanitizeIdentName(typeIdent) & "ProviderWithArgs") - argFieldName = ident("providerWithArgs") of nnkTypeSection, nnkEmpty: discard else: @@ -400,7 +319,6 @@ proc generateRequestBroker(body: NimNode, mode: RequestBrokerMode): NimNode = if zeroArgSig.isNil() and argSig.isNil(): zeroArgSig = newEmptyNode() zeroArgProviderName = ident(sanitizeIdentName(typeIdent) & "ProviderNoArgs") - zeroArgFieldName = ident("providerNoArgs") var typeSection = newTree(nnkTypeSection) typeSection.add(newTree(nnkTypeDef, exportedTypeIdent, newEmptyNode(), objectDef)) @@ -423,12 +341,29 @@ proc generateRequestBroker(body: NimNode, mode: RequestBrokerMode): NimNode = var brokerRecList = newTree(nnkRecList) if not zeroArgSig.isNil(): + let zeroArgProvidersFieldName = ident("providersNoArgs") + let zeroArgProvidersTupleTy = newTree( + nnkTupleTy, + newTree(nnkIdentDefs, ident("brokerCtx"), ident("BrokerContext"), newEmptyNode()), + newTree(nnkIdentDefs, ident("handler"), zeroArgProviderName, newEmptyNode()), + ) + let zeroArgProvidersSeqTy = + newTree(nnkBracketExpr, ident("seq"), zeroArgProvidersTupleTy) brokerRecList.add( - newTree(nnkIdentDefs, zeroArgFieldName, zeroArgProviderName, newEmptyNode()) + newTree( + nnkIdentDefs, zeroArgProvidersFieldName, zeroArgProvidersSeqTy, newEmptyNode() + ) ) if not argSig.isNil(): + let argProvidersFieldName = ident("providersWithArgs") + let argProvidersTupleTy = newTree( + nnkTupleTy, + newTree(nnkIdentDefs, ident("brokerCtx"), ident("BrokerContext"), newEmptyNode()), + newTree(nnkIdentDefs, ident("handler"), argProviderName, newEmptyNode()), + ) + let argProvidersSeqTy = newTree(nnkBracketExpr, ident("seq"), argProvidersTupleTy) brokerRecList.add( - newTree(nnkIdentDefs, argFieldName, argProviderName, newEmptyNode()) + newTree(nnkIdentDefs, argProvidersFieldName, argProvidersSeqTy, newEmptyNode()) ) let brokerTypeIdent = ident(sanitizeIdentName(typeIdent) & "Broker") let brokerTypeDef = newTree( @@ -443,31 +378,97 @@ proc generateRequestBroker(body: NimNode, mode: RequestBrokerMode): NimNode = let globalVarIdent = ident("g" & sanitizeIdentName(typeIdent) & "Broker") let accessProcIdent = ident("access" & sanitizeIdentName(typeIdent) & "Broker") + + var brokerNewBody = newStmtList() + if not zeroArgSig.isNil(): + brokerNewBody.add( + quote do: + result.providersNoArgs = + @[(brokerCtx: DefaultBrokerContext, handler: default(`zeroArgProviderName`))] + ) + if not argSig.isNil(): + brokerNewBody.add( + quote do: + result.providersWithArgs = + @[(brokerCtx: DefaultBrokerContext, handler: default(`argProviderName`))] + ) + + var brokerInitChecks = newStmtList() + if not zeroArgSig.isNil(): + brokerInitChecks.add( + quote do: + if `globalVarIdent`.providersNoArgs.len == 0: + `globalVarIdent` = `brokerTypeIdent`.new() + ) + if not argSig.isNil(): + brokerInitChecks.add( + quote do: + if `globalVarIdent`.providersWithArgs.len == 0: + `globalVarIdent` = `brokerTypeIdent`.new() + ) + result.add( quote do: var `globalVarIdent` {.threadvar.}: `brokerTypeIdent` + proc new(_: type `brokerTypeIdent`): `brokerTypeIdent` = + result = `brokerTypeIdent`() + `brokerNewBody` + proc `accessProcIdent`(): var `brokerTypeIdent` = + `brokerInitChecks` `globalVarIdent` ) - var clearBody = newStmtList() + var clearBodyKeyed = newStmtList() + let brokerCtxParamIdent = ident("brokerCtx") if not zeroArgSig.isNil(): + let zeroArgProvidersFieldName = ident("providersNoArgs") result.add( quote do: proc setProvider*( _: typedesc[`typeIdent`], handler: `zeroArgProviderName` ): Result[void, string] = - if not `accessProcIdent`().`zeroArgFieldName`.isNil(): + if not `accessProcIdent`().`zeroArgProvidersFieldName`[0].handler.isNil(): return err("Zero-arg provider already set") - `accessProcIdent`().`zeroArgFieldName` = handler + `accessProcIdent`().`zeroArgProvidersFieldName`[0].handler = handler return ok() ) - clearBody.add( + + result.add( quote do: - `accessProcIdent`().`zeroArgFieldName` = nil + proc setProvider*( + _: typedesc[`typeIdent`], + brokerCtx: BrokerContext, + handler: `zeroArgProviderName`, + ): Result[void, string] = + if brokerCtx == DefaultBrokerContext: + return setProvider(`typeIdent`, handler) + + for entry in `accessProcIdent`().`zeroArgProvidersFieldName`: + if entry.brokerCtx == brokerCtx: + return err( + "RequestBroker(" & `typeNameLit` & + "): provider already set for broker context " & $brokerCtx + ) + + `accessProcIdent`().`zeroArgProvidersFieldName`.add( + (brokerCtx: brokerCtx, handler: handler) + ) + return ok() + + ) + clearBodyKeyed.add( + quote do: + if `brokerCtxParamIdent` == DefaultBrokerContext: + `accessProcIdent`().`zeroArgProvidersFieldName`[0].handler = + default(`zeroArgProviderName`) + else: + `accessProcIdent`().`zeroArgProvidersFieldName`.keepItIf( + it.brokerCtx != `brokerCtxParamIdent` + ) ) case mode of rbAsync: @@ -476,11 +477,34 @@ proc generateRequestBroker(body: NimNode, mode: RequestBrokerMode): NimNode = proc request*( _: typedesc[`typeIdent`] ): Future[Result[`typeIdent`, string]] {.async: (raises: []).} = - let provider = `accessProcIdent`().`zeroArgFieldName` + return await request(`typeIdent`, DefaultBrokerContext) + + ) + + result.add( + quote do: + proc request*( + _: typedesc[`typeIdent`], brokerCtx: BrokerContext + ): Future[Result[`typeIdent`, string]] {.async: (raises: []).} = + var provider: `zeroArgProviderName` + if brokerCtx == DefaultBrokerContext: + provider = `accessProcIdent`().`zeroArgProvidersFieldName`[0].handler + else: + for entry in `accessProcIdent`().`zeroArgProvidersFieldName`: + if entry.brokerCtx == brokerCtx: + provider = entry.handler + break + if provider.isNil(): + if brokerCtx == DefaultBrokerContext: + return err( + "RequestBroker(" & `typeNameLit` & "): no zero-arg provider registered" + ) return err( - "RequestBroker(" & `typeNameLit` & "): no zero-arg provider registered" + "RequestBroker(" & `typeNameLit` & + "): no provider registered for broker context " & $brokerCtx ) + let catchedRes = catch: await provider() @@ -507,10 +531,32 @@ proc generateRequestBroker(body: NimNode, mode: RequestBrokerMode): NimNode = proc request*( _: typedesc[`typeIdent`] ): Result[`typeIdent`, string] {.gcsafe, raises: [].} = - let provider = `accessProcIdent`().`zeroArgFieldName` + return request(`typeIdent`, DefaultBrokerContext) + + ) + + result.add( + quote do: + proc request*( + _: typedesc[`typeIdent`], brokerCtx: BrokerContext + ): Result[`typeIdent`, string] {.gcsafe, raises: [].} = + var provider: `zeroArgProviderName` + if brokerCtx == DefaultBrokerContext: + provider = `accessProcIdent`().`zeroArgProvidersFieldName`[0].handler + else: + for entry in `accessProcIdent`().`zeroArgProvidersFieldName`: + if entry.brokerCtx == brokerCtx: + provider = entry.handler + break + if provider.isNil(): + if brokerCtx == DefaultBrokerContext: + return err( + "RequestBroker(" & `typeNameLit` & "): no zero-arg provider registered" + ) return err( - "RequestBroker(" & `typeNameLit` & "): no zero-arg provider registered" + "RequestBroker(" & `typeNameLit` & + "): no provider registered for broker context " & $brokerCtx ) var providerRes: Result[`typeIdent`, string] @@ -533,24 +579,54 @@ proc generateRequestBroker(body: NimNode, mode: RequestBrokerMode): NimNode = ) if not argSig.isNil(): + let argProvidersFieldName = ident("providersWithArgs") result.add( quote do: proc setProvider*( _: typedesc[`typeIdent`], handler: `argProviderName` ): Result[void, string] = - if not `accessProcIdent`().`argFieldName`.isNil(): + if not `accessProcIdent`().`argProvidersFieldName`[0].handler.isNil(): return err("Provider already set") - `accessProcIdent`().`argFieldName` = handler + `accessProcIdent`().`argProvidersFieldName`[0].handler = handler return ok() ) - clearBody.add( + + result.add( quote do: - `accessProcIdent`().`argFieldName` = nil + proc setProvider*( + _: typedesc[`typeIdent`], + brokerCtx: BrokerContext, + handler: `argProviderName`, + ): Result[void, string] = + if brokerCtx == DefaultBrokerContext: + return setProvider(`typeIdent`, handler) + + for entry in `accessProcIdent`().`argProvidersFieldName`: + if entry.brokerCtx == brokerCtx: + return err( + "RequestBroker(" & `typeNameLit` & + "): provider already set for broker context " & $brokerCtx + ) + + `accessProcIdent`().`argProvidersFieldName`.add( + (brokerCtx: brokerCtx, handler: handler) + ) + return ok() + + ) + clearBodyKeyed.add( + quote do: + if `brokerCtxParamIdent` == DefaultBrokerContext: + `accessProcIdent`().`argProvidersFieldName`[0].handler = + default(`argProviderName`) + else: + `accessProcIdent`().`argProvidersFieldName`.keepItIf( + it.brokerCtx != `brokerCtxParamIdent` + ) ) let requestParamDefs = cloneParams(argParams) let argNameIdents = collectParamNames(requestParamDefs) - let providerSym = genSym(nskLet, "provider") var formalParams = newTree(nnkFormalParams) formalParams.add(copyNimTree(returnType)) formalParams.add( @@ -572,29 +648,96 @@ proc generateRequestBroker(body: NimNode, mode: RequestBrokerMode): NimNode = of rbSync: quote: {.gcsafe, raises: [].} - var providerCall = newCall(providerSym) + + var forwardCall = newCall(ident("request")) + forwardCall.add(copyNimTree(typeIdent)) + forwardCall.add(ident("DefaultBrokerContext")) for argName in argNameIdents: - providerCall.add(argName) + forwardCall.add(argName) + var requestBody = newStmtList() - requestBody.add( - quote do: - let `providerSym` = `accessProcIdent`().`argFieldName` + case mode + of rbAsync: + requestBody.add( + quote do: + return await `forwardCall` + ) + of rbSync: + requestBody.add( + quote do: + return `forwardCall` + ) + + result.add( + newTree( + nnkProcDef, + postfix(ident("request"), "*"), + newEmptyNode(), + newEmptyNode(), + formalParams, + requestPragmas, + newEmptyNode(), + requestBody, + ) ) - requestBody.add( + + # Keyed request variant for the argument-based signature. + let requestParamDefsKeyed = cloneParams(argParams) + let argNameIdentsKeyed = collectParamNames(requestParamDefsKeyed) + let providerSymKeyed = genSym(nskVar, "provider") + var formalParamsKeyed = newTree(nnkFormalParams) + formalParamsKeyed.add(copyNimTree(returnType)) + formalParamsKeyed.add( + newTree( + nnkIdentDefs, + ident("_"), + newTree(nnkBracketExpr, ident("typedesc"), copyNimTree(typeIdent)), + newEmptyNode(), + ) + ) + formalParamsKeyed.add( + newTree(nnkIdentDefs, ident("brokerCtx"), ident("BrokerContext"), newEmptyNode()) + ) + for paramDef in requestParamDefsKeyed: + formalParamsKeyed.add(paramDef) + + let requestPragmasKeyed = requestPragmas + var providerCallKeyed = newCall(providerSymKeyed) + for argName in argNameIdentsKeyed: + providerCallKeyed.add(argName) + + var requestBodyKeyed = newStmtList() + requestBodyKeyed.add( quote do: - if `providerSym`.isNil(): + var `providerSymKeyed`: `argProviderName` + if brokerCtx == DefaultBrokerContext: + `providerSymKeyed` = `accessProcIdent`().`argProvidersFieldName`[0].handler + else: + for entry in `accessProcIdent`().`argProvidersFieldName`: + if entry.brokerCtx == brokerCtx: + `providerSymKeyed` = entry.handler + break + ) + requestBodyKeyed.add( + quote do: + if `providerSymKeyed`.isNil(): + if brokerCtx == DefaultBrokerContext: + return err( + "RequestBroker(" & `typeNameLit` & + "): no provider registered for input signature" + ) return err( "RequestBroker(" & `typeNameLit` & - "): no provider registered for input signature" + "): no provider registered for broker context " & $brokerCtx ) ) case mode of rbAsync: - requestBody.add( + requestBodyKeyed.add( quote do: let catchedRes = catch: - await `providerCall` + await `providerCallKeyed` if catchedRes.isErr(): return err( "RequestBroker(" & `typeNameLit` & "): provider threw exception: " & @@ -612,11 +755,11 @@ proc generateRequestBroker(body: NimNode, mode: RequestBrokerMode): NimNode = return providerRes ) of rbSync: - requestBody.add( + requestBodyKeyed.add( quote do: var providerRes: Result[`typeIdent`, string] try: - providerRes = `providerCall` + providerRes = `providerCallKeyed` except CatchableError as e: return err( "RequestBroker(" & `typeNameLit` & "): provider threw exception: " & e.msg @@ -631,24 +774,52 @@ proc generateRequestBroker(body: NimNode, mode: RequestBrokerMode): NimNode = ) return providerRes ) - # requestBody.add(providerCall) + result.add( newTree( nnkProcDef, postfix(ident("request"), "*"), newEmptyNode(), newEmptyNode(), - formalParams, - requestPragmas, + formalParamsKeyed, + requestPragmasKeyed, newEmptyNode(), - requestBody, + requestBodyKeyed, + ) + ) + + block: + var formalParamsClearKeyed = newTree(nnkFormalParams) + formalParamsClearKeyed.add(newEmptyNode()) + formalParamsClearKeyed.add( + newTree( + nnkIdentDefs, + ident("_"), + newTree(nnkBracketExpr, ident("typedesc"), copyNimTree(typeIdent)), + newEmptyNode(), + ) + ) + formalParamsClearKeyed.add( + newTree(nnkIdentDefs, brokerCtxParamIdent, ident("BrokerContext"), newEmptyNode()) + ) + + result.add( + newTree( + nnkProcDef, + postfix(ident("clearProvider"), "*"), + newEmptyNode(), + newEmptyNode(), + formalParamsClearKeyed, + newEmptyNode(), + newEmptyNode(), + clearBodyKeyed, ) ) result.add( quote do: proc clearProvider*(_: typedesc[`typeIdent`]) = - `clearBody` + clearProvider(`typeIdent`, DefaultBrokerContext) ) diff --git a/waku/common/rate_limit/per_peer_limiter.nim b/waku/common/rate_limit/per_peer_limiter.nim index 5cb96a2d1..16b6bf065 100644 --- a/waku/common/rate_limit/per_peer_limiter.nim +++ b/waku/common/rate_limit/per_peer_limiter.nim @@ -20,7 +20,7 @@ proc mgetOrPut( perPeerRateLimiter: var PerPeerRateLimiter, peerId: PeerId ): var Option[TokenBucket] = return perPeerRateLimiter.peerBucket.mgetOrPut( - peerId, newTokenBucket(perPeerRateLimiter.setting, ReplenishMode.Compensating) + peerId, newTokenBucket(perPeerRateLimiter.setting, ReplenishMode.Continuous) ) template checkUsageLimit*( diff --git a/waku/common/rate_limit/request_limiter.nim b/waku/common/rate_limit/request_limiter.nim index 0ede20be4..bc318e151 100644 --- a/waku/common/rate_limit/request_limiter.nim +++ b/waku/common/rate_limit/request_limiter.nim @@ -39,38 +39,82 @@ const SECONDS_RATIO = 3 const MINUTES_RATIO = 2 type RequestRateLimiter* = ref object of RootObj - tokenBucket: Option[TokenBucket] + tokenBucket: TokenBucket setting*: Option[RateLimitSetting] + mainBucketSetting: RateLimitSetting + ratio: int peerBucketSetting*: RateLimitSetting peerUsage: TimedMap[PeerId, TokenBucket] + checkUsageImpl: proc( + t: var RequestRateLimiter, proto: string, conn: Connection, now: Moment + ): bool {.gcsafe, raises: [].} + +proc newMainTokenBucket( + setting: RateLimitSetting, ratio: int, startTime: Moment +): TokenBucket = + ## RequestRateLimiter's global bucket should keep the *rate* of the configured + ## setting while allowing a larger burst window. We achieve this by scaling + ## both capacity and fillDuration by the same ratio. + ## + ## This matches previous behavior where unused tokens could effectively + ## accumulate across multiple periods. + let burstCapacity = setting.volume * ratio + var bucket = TokenBucket.new( + capacity = burstCapacity, + fillDuration = setting.period * ratio, + startTime = startTime, + mode = Continuous, + ) + + # Start with the configured volume (not the burst capacity) so that the + # initial burst behavior matches the raw setting, while still allowing + # accumulation up to `burstCapacity` over time. + let excess = burstCapacity - setting.volume + if excess > 0: + discard bucket.tryConsume(excess, startTime) + + return bucket proc mgetOrPut( - requestRateLimiter: var RequestRateLimiter, peerId: PeerId + requestRateLimiter: var RequestRateLimiter, peerId: PeerId, now: Moment ): var TokenBucket = - let bucketForNew = newTokenBucket(some(requestRateLimiter.peerBucketSetting)).valueOr: + let bucketForNew = newTokenBucket( + some(requestRateLimiter.peerBucketSetting), Discrete, now + ).valueOr: raiseAssert "This branch is not allowed to be reached as it will not be called if the setting is None." return requestRateLimiter.peerUsage.mgetOrPut(peerId, bucketForNew) -proc checkUsage*( - t: var RequestRateLimiter, proto: string, conn: Connection, now = Moment.now() -): bool {.raises: [].} = - if t.tokenBucket.isNone(): - return true +proc checkUsageUnlimited( + t: var RequestRateLimiter, proto: string, conn: Connection, now: Moment +): bool {.gcsafe, raises: [].} = + true - let peerBucket = t.mgetOrPut(conn.peerId) +proc checkUsageLimited( + t: var RequestRateLimiter, proto: string, conn: Connection, now: Moment +): bool {.gcsafe, raises: [].} = + # Lazy-init the main bucket using the first observed request time. This makes + # refill behavior deterministic under tests where `now` is controlled. + if isNil(t.tokenBucket): + t.tokenBucket = newMainTokenBucket(t.mainBucketSetting, t.ratio, now) + + let peerBucket = t.mgetOrPut(conn.peerId, now) ## check requesting peer's usage is not over the calculated ratio and let that peer go which not requested much/or this time... if not peerBucket.tryConsume(1, now): trace "peer usage limit reached", peer = conn.peerId return false # Ok if the peer can consume, check the overall budget we have left - let tokenBucket = t.tokenBucket.get() - if not tokenBucket.tryConsume(1, now): + if not t.tokenBucket.tryConsume(1, now): return false return true +proc checkUsage*( + t: var RequestRateLimiter, proto: string, conn: Connection, now = Moment.now() +): bool {.raises: [].} = + t.checkUsageImpl(t, proto, conn, now) + template checkUsageLimit*( t: var RequestRateLimiter, proto: string, @@ -135,9 +179,19 @@ func calcPeerTokenSetting( proc newRequestRateLimiter*(setting: Option[RateLimitSetting]): RequestRateLimiter = let ratio = calcPeriodRatio(setting) + let isLimited = setting.isSome() and not setting.get().isUnlimited() + let mainBucketSetting = + if isLimited: + setting.get() + else: + (0, 0.minutes) + return RequestRateLimiter( - tokenBucket: newTokenBucket(setting), + tokenBucket: nil, setting: setting, + mainBucketSetting: mainBucketSetting, + ratio: ratio, peerBucketSetting: calcPeerTokenSetting(setting, ratio), peerUsage: init(TimedMap[PeerId, TokenBucket], calcCacheTimeout(setting, ratio)), + checkUsageImpl: (if isLimited: checkUsageLimited else: checkUsageUnlimited), ) diff --git a/waku/common/rate_limit/single_token_limiter.nim b/waku/common/rate_limit/single_token_limiter.nim index 50fb2d64c..fc4b0acd5 100644 --- a/waku/common/rate_limit/single_token_limiter.nim +++ b/waku/common/rate_limit/single_token_limiter.nim @@ -6,12 +6,15 @@ import std/[options], chronos/timer, libp2p/stream/connection, libp2p/utility import std/times except TimeInterval, Duration -import ./[token_bucket, setting, service_metrics] +import chronos/ratelimit as token_bucket + +import ./[setting, service_metrics] export token_bucket, setting, service_metrics proc newTokenBucket*( setting: Option[RateLimitSetting], - replenishMode: ReplenishMode = ReplenishMode.Compensating, + replenishMode: static[ReplenishMode] = ReplenishMode.Continuous, + startTime: Moment = Moment.now(), ): Option[TokenBucket] = if setting.isNone(): return none[TokenBucket]() @@ -19,7 +22,14 @@ proc newTokenBucket*( if setting.get().isUnlimited(): return none[TokenBucket]() - return some(TokenBucket.new(setting.get().volume, setting.get().period)) + return some( + TokenBucket.new( + capacity = setting.get().volume, + fillDuration = setting.get().period, + startTime = startTime, + mode = replenishMode, + ) + ) proc checkUsage( t: var TokenBucket, proto: string, now = Moment.now() diff --git a/waku/common/rate_limit/token_bucket.nim b/waku/common/rate_limit/token_bucket.nim deleted file mode 100644 index 799817ebd..000000000 --- a/waku/common/rate_limit/token_bucket.nim +++ /dev/null @@ -1,182 +0,0 @@ -{.push raises: [].} - -import chronos, std/math, std/options - -const BUDGET_COMPENSATION_LIMIT_PERCENT = 0.25 - -## This is an extract from chronos/rate_limit.nim due to the found bug in the original implementation. -## Unfortunately that bug cannot be solved without harm the original features of TokenBucket class. -## So, this current shortcut is used to enable move ahead with nwaku rate limiter implementation. -## ref: https://github.com/status-im/nim-chronos/issues/500 -## -## This version of TokenBucket is different from the original one in chronos/rate_limit.nim in many ways: -## - It has a new mode called `Compensating` which is the default mode. -## Compensation is calculated as the not used bucket capacity in the last measured period(s) in average. -## or up until maximum the allowed compansation treshold (Currently it is const 25%). -## Also compensation takes care of the proper time period calculation to avoid non-usage periods that can lead to -## overcompensation. -## - Strict mode is also available which will only replenish when time period is over but also will fill -## the bucket to the max capacity. - -type - ReplenishMode* = enum - Strict - Compensating - - TokenBucket* = ref object - budget: int ## Current number of tokens in the bucket - budgetCap: int ## Bucket capacity - lastTimeFull: Moment - ## This timer measures the proper periodizaiton of the bucket refilling - fillDuration: Duration ## Refill period - case replenishMode*: ReplenishMode - of Strict: - ## In strict mode, the bucket is refilled only till the budgetCap - discard - of Compensating: - ## This is the default mode. - maxCompensation: float - -func periodDistance(bucket: TokenBucket, currentTime: Moment): float = - ## notice fillDuration cannot be zero by design - ## period distance is a float number representing the calculated period time - ## since the last time bucket was refilled. - return - nanoseconds(currentTime - bucket.lastTimeFull).float / - nanoseconds(bucket.fillDuration).float - -func getUsageAverageSince(bucket: TokenBucket, distance: float): float = - if distance == 0.float: - ## in case there is zero time difference than the usage percentage is 100% - return 1.0 - - ## budgetCap can never be zero - ## usage average is calculated as a percentage of total capacity available over - ## the measured period - return bucket.budget.float / bucket.budgetCap.float / distance - -proc calcCompensation(bucket: TokenBucket, averageUsage: float): int = - # if we already fully used or even overused the tokens, there is no place for compensation - if averageUsage >= 1.0: - return 0 - - ## compensation is the not used bucket capacity in the last measured period(s) in average. - ## or maximum the allowed compansation treshold - let compensationPercent = - min((1.0 - averageUsage) * bucket.budgetCap.float, bucket.maxCompensation) - return trunc(compensationPercent).int - -func periodElapsed(bucket: TokenBucket, currentTime: Moment): bool = - return currentTime - bucket.lastTimeFull >= bucket.fillDuration - -## Update will take place if bucket is empty and trying to consume tokens. -## It checks if the bucket can be replenished as refill duration is passed or not. -## - strict mode: -proc updateStrict(bucket: TokenBucket, currentTime: Moment) = - if bucket.fillDuration == default(Duration): - bucket.budget = min(bucket.budgetCap, bucket.budget) - return - - if not periodElapsed(bucket, currentTime): - return - - bucket.budget = bucket.budgetCap - bucket.lastTimeFull = currentTime - -## - compensating - ballancing load: -## - between updates we calculate average load (current bucket capacity / number of periods till last update) -## - gives the percentage load used recently -## - with this we can replenish bucket up to 100% + calculated leftover from previous period (caped with max treshold) -proc updateWithCompensation(bucket: TokenBucket, currentTime: Moment) = - if bucket.fillDuration == default(Duration): - bucket.budget = min(bucket.budgetCap, bucket.budget) - return - - # do not replenish within the same period - if not periodElapsed(bucket, currentTime): - return - - let distance = bucket.periodDistance(currentTime) - let recentAvgUsage = bucket.getUsageAverageSince(distance) - let compensation = bucket.calcCompensation(recentAvgUsage) - - bucket.budget = bucket.budgetCap + compensation - bucket.lastTimeFull = currentTime - -proc update(bucket: TokenBucket, currentTime: Moment) = - if bucket.replenishMode == ReplenishMode.Compensating: - updateWithCompensation(bucket, currentTime) - else: - updateStrict(bucket, currentTime) - -proc tryConsume*(bucket: TokenBucket, tokens: int, now = Moment.now()): bool = - ## If `tokens` are available, consume them, - ## Otherwhise, return false. - - if bucket.budget >= bucket.budgetCap: - bucket.lastTimeFull = now - - if bucket.budget >= tokens: - bucket.budget -= tokens - return true - - bucket.update(now) - - if bucket.budget >= tokens: - bucket.budget -= tokens - return true - else: - return false - -proc replenish*(bucket: TokenBucket, tokens: int, now = Moment.now()) = - ## Add `tokens` to the budget (capped to the bucket capacity) - bucket.budget += tokens - bucket.update(now) - -proc new*( - T: type[TokenBucket], - budgetCap: int, - fillDuration: Duration = 1.seconds, - mode: ReplenishMode = ReplenishMode.Compensating, -): T = - assert not isZero(fillDuration) - assert budgetCap != 0 - - ## Create different mode TokenBucket - case mode - of ReplenishMode.Strict: - return T( - budget: budgetCap, - budgetCap: budgetCap, - fillDuration: fillDuration, - lastTimeFull: Moment.now(), - replenishMode: mode, - ) - of ReplenishMode.Compensating: - T( - budget: budgetCap, - budgetCap: budgetCap, - fillDuration: fillDuration, - lastTimeFull: Moment.now(), - replenishMode: mode, - maxCompensation: budgetCap.float * BUDGET_COMPENSATION_LIMIT_PERCENT, - ) - -proc newStrict*(T: type[TokenBucket], capacity: int, period: Duration): TokenBucket = - T.new(capacity, period, ReplenishMode.Strict) - -proc newCompensating*( - T: type[TokenBucket], capacity: int, period: Duration -): TokenBucket = - T.new(capacity, period, ReplenishMode.Compensating) - -func `$`*(b: TokenBucket): string {.inline.} = - if isNil(b): - return "nil" - return $b.budgetCap & "/" & $b.fillDuration - -func `$`*(ob: Option[TokenBucket]): string {.inline.} = - if ob.isNone(): - return "no-limit" - - return $ob.get() diff --git a/waku/factory/builder.nim b/waku/factory/builder.nim index 772cfbffd..f379f92bb 100644 --- a/waku/factory/builder.nim +++ b/waku/factory/builder.nim @@ -209,6 +209,7 @@ proc build*(builder: WakuNodeBuilder): Result[WakuNode, string] = maxServicePeers = some(builder.maxServicePeers), colocationLimit = builder.colocationLimit, shardedPeerManagement = builder.shardAware, + maxConnections = builder.switchMaxConnections.get(builders.MaxConnections), ) var node: WakuNode diff --git a/waku/factory/waku.nim b/waku/factory/waku.nim index c0380ccc9..d55206f97 100644 --- a/waku/factory/waku.nim +++ b/waku/factory/waku.nim @@ -13,7 +13,6 @@ import libp2p/services/autorelayservice, libp2p/services/hpservice, libp2p/peerid, - libp2p/discovery/rendezvousinterface, eth/keys, eth/p2p/discoveryv5/enr, presto, diff --git a/waku/node/peer_manager/peer_manager.nim b/waku/node/peer_manager/peer_manager.nim index c2358763b..bdb68905e 100644 --- a/waku/node/peer_manager/peer_manager.nim +++ b/waku/node/peer_manager/peer_manager.nim @@ -103,6 +103,7 @@ type PeerManager* = ref object of RootObj onConnectionChange*: ConnectionChangeHandler online: bool ## state managed by online_monitor module getShards: GetShards + maxConnections: int #~~~~~~~~~~~~~~~~~~~# # Helper Functions # @@ -748,7 +749,6 @@ proc logAndMetrics(pm: PeerManager) {.async.} = var peerStore = pm.switch.peerStore # log metrics let (inRelayPeers, outRelayPeers) = pm.connectedPeers(WakuRelayCodec) - let maxConnections = pm.switch.connManager.inSema.size let notConnectedPeers = peerStore.getDisconnectedPeers().mapIt(RemotePeerInfo.init(it.peerId, it.addrs)) let outsideBackoffPeers = notConnectedPeers.filterIt(pm.canBeConnected(it.peerId)) @@ -758,7 +758,7 @@ proc logAndMetrics(pm: PeerManager) {.async.} = info "Relay peer connections", inRelayConns = $inRelayPeers.len & "/" & $pm.inRelayPeersTarget, outRelayConns = $outRelayPeers.len & "/" & $pm.outRelayPeersTarget, - totalConnections = $totalConnections & "/" & $maxConnections, + totalConnections = $totalConnections & "/" & $pm.maxConnections, notConnectedPeers = notConnectedPeers.len, outsideBackoffPeers = outsideBackoffPeers.len @@ -1048,9 +1048,9 @@ proc new*( maxFailedAttempts = MaxFailedAttempts, colocationLimit = DefaultColocationLimit, shardedPeerManagement = false, + maxConnections: int = MaxConnections, ): PeerManager {.gcsafe.} = let capacity = switch.peerStore.capacity - let maxConnections = switch.connManager.inSema.size if maxConnections > capacity: error "Max number of connections can't be greater than PeerManager capacity", capacity = capacity, maxConnections = maxConnections @@ -1099,6 +1099,7 @@ proc new*( colocationLimit: colocationLimit, shardedPeerManagement: shardedPeerManagement, online: true, + maxConnections: maxConnections, ) proc peerHook( diff --git a/waku/waku_rendezvous/protocol.nim b/waku/waku_rendezvous/protocol.nim index 7b97375ff..00b5f1a5c 100644 --- a/waku/waku_rendezvous/protocol.nim +++ b/waku/waku_rendezvous/protocol.nim @@ -8,7 +8,6 @@ import stew/byteutils, libp2p/protocols/rendezvous, libp2p/protocols/rendezvous/protobuf, - libp2p/discovery/discoverymngr, libp2p/utils/semaphore, libp2p/utils/offsettedseq, libp2p/crypto/curve25519,