diff --git a/tests/common/test_all.nim b/tests/common/test_all.nim index 5b4515093..7495c7c9e 100644 --- a/tests/common/test_all.nim +++ b/tests/common/test_all.nim @@ -9,4 +9,7 @@ import ./test_tokenbucket, ./test_requestratelimiter, ./test_ratelimit_setting, - ./test_timed_map + ./test_timed_map, + ./test_event_broker, + ./test_request_broker, + ./test_multi_request_broker diff --git a/tests/common/test_event_broker.nim b/tests/common/test_event_broker.nim new file mode 100644 index 000000000..cead1277f --- /dev/null +++ b/tests/common/test_event_broker.nim @@ -0,0 +1,125 @@ +import chronos +import std/sequtils +import testutils/unittests + +import waku/common/broker/event_broker + +EventBroker: + type SampleEvent = object + value*: int + label*: string + +EventBroker: + type BinaryEvent = object + flag*: bool + +EventBroker: + type RefEvent = ref object + payload*: seq[int] + +template waitForListeners() = + waitFor sleepAsync(1.milliseconds) + +suite "EventBroker": + test "delivers events to all listeners": + var seen: seq[(int, string)] = @[] + + discard SampleEvent.listen( + proc(evt: SampleEvent): Future[void] {.async: (raises: []).} = + seen.add((evt.value, evt.label)) + ) + + discard SampleEvent.listen( + proc(evt: SampleEvent): Future[void] {.async: (raises: []).} = + seen.add((evt.value * 2, evt.label & "!")) + ) + + let evt = SampleEvent(value: 5, label: "hi") + SampleEvent.emit(evt) + waitForListeners() + + check seen.len == 2 + check seen.anyIt(it == (5, "hi")) + check seen.anyIt(it == (10, "hi!")) + + SampleEvent.dropAllListeners() + + test "forget removes a single listener": + var counter = 0 + + let handleA = SampleEvent.listen( + proc(evt: SampleEvent): Future[void] {.async: (raises: []).} = + inc counter + ) + + let handleB = SampleEvent.listen( + proc(evt: SampleEvent): Future[void] {.async: (raises: []).} = + inc(counter, 2) + ) + + SampleEvent.dropListener(handleA.get()) + let eventVal = SampleEvent(value: 1, label: "one") + SampleEvent.emit(eventVal) + waitForListeners() + check counter == 2 + + SampleEvent.dropAllListeners() + + test "forgetAll clears every listener": + var triggered = false + + let handle1 = SampleEvent.listen( + proc(evt: SampleEvent): Future[void] {.async: (raises: []).} = + triggered = true + ) + let handle2 = SampleEvent.listen( + proc(evt: SampleEvent): Future[void] {.async: (raises: []).} = + discard + ) + + SampleEvent.dropAllListeners() + SampleEvent.emit(42, "noop") + SampleEvent.emit(label = "noop", value = 42) + waitForListeners() + check not triggered + + let freshHandle = SampleEvent.listen( + proc(evt: SampleEvent): Future[void] {.async: (raises: []).} = + discard + ) + check freshHandle.get().id > 0'u64 + SampleEvent.dropListener(freshHandle.get()) + + test "broker helpers operate via typedesc": + var toggles: seq[bool] = @[] + + let handle = BinaryEvent.listen( + proc(evt: BinaryEvent): Future[void] {.async: (raises: []).} = + toggles.add(evt.flag) + ) + + BinaryEvent(flag: true).emit() + waitForListeners() + let binaryEvent = BinaryEvent(flag: false) + BinaryEvent.emit(binaryEvent) + waitForListeners() + + check toggles == @[true, false] + BinaryEvent.dropAllListeners() + + test "ref typed event": + var counter: int = 0 + + let handle = RefEvent.listen( + proc(evt: RefEvent): Future[void] {.async: (raises: []).} = + for n in evt.payload: + counter += n + ) + + RefEvent(payload: @[1, 2, 3]).emit() + waitForListeners() + RefEvent.emit(payload = @[4, 5, 6]) + waitForListeners() + + check counter == 21 # 1+2+3 + 4+5+6 + RefEvent.dropAllListeners() diff --git a/tests/common/test_multi_request_broker.nim b/tests/common/test_multi_request_broker.nim new file mode 100644 index 000000000..3bf10a54d --- /dev/null +++ b/tests/common/test_multi_request_broker.nim @@ -0,0 +1,234 @@ +{.used.} + +import testutils/unittests +import chronos +import std/sequtils +import std/strutils + +import waku/common/broker/multi_request_broker + +MultiRequestBroker: + type NoArgResponse = object + label*: string + + proc signatureFetch*(): Future[Result[NoArgResponse, string]] {.async.} + +MultiRequestBroker: + type ArgResponse = object + id*: string + + proc signatureFetch*( + suffix: string, numsuffix: int + ): Future[Result[ArgResponse, string]] {.async.} + +MultiRequestBroker: + type DualResponse = ref object + note*: string + suffix*: string + + proc signatureBase*(): Future[Result[DualResponse, string]] {.async.} + proc signatureWithInput*( + suffix: string + ): Future[Result[DualResponse, string]] {.async.} + +suite "MultiRequestBroker": + test "aggregates zero-argument providers": + discard NoArgResponse.setProvider( + proc(): Future[Result[NoArgResponse, string]] {.async.} = + ok(NoArgResponse(label: "one")) + ) + + discard NoArgResponse.setProvider( + proc(): Future[Result[NoArgResponse, string]] {.async.} = + discard catch: + await sleepAsync(1.milliseconds) + ok(NoArgResponse(label: "two")) + ) + + let responses = waitFor NoArgResponse.request() + check responses.get().len == 2 + check responses.get().anyIt(it.label == "one") + check responses.get().anyIt(it.label == "two") + + NoArgResponse.clearProviders() + + test "aggregates argument providers": + discard ArgResponse.setProvider( + proc(suffix: string, num: int): Future[Result[ArgResponse, string]] {.async.} = + ok(ArgResponse(id: suffix & "-a-" & $num)) + ) + + discard ArgResponse.setProvider( + proc(suffix: string, num: int): Future[Result[ArgResponse, string]] {.async.} = + ok(ArgResponse(id: suffix & "-b-" & $num)) + ) + + let keyed = waitFor ArgResponse.request("topic", 1) + check keyed.get().len == 2 + check keyed.get().anyIt(it.id == "topic-a-1") + check keyed.get().anyIt(it.id == "topic-b-1") + + ArgResponse.clearProviders() + + test "clearProviders resets both provider lists": + discard DualResponse.setProvider( + proc(): Future[Result[DualResponse, string]] {.async.} = + ok(DualResponse(note: "base", suffix: "")) + ) + + discard DualResponse.setProvider( + proc(suffix: string): Future[Result[DualResponse, string]] {.async.} = + ok(DualResponse(note: "base" & suffix, suffix: suffix)) + ) + + let noArgs = waitFor DualResponse.request() + check noArgs.get().len == 1 + + let param = waitFor DualResponse.request("-extra") + check param.get().len == 1 + check param.get()[0].suffix == "-extra" + + DualResponse.clearProviders() + + let emptyNoArgs = waitFor DualResponse.request() + check emptyNoArgs.get().len == 0 + + let emptyWithArgs = waitFor DualResponse.request("-extra") + check emptyWithArgs.get().len == 0 + + test "request returns empty seq when no providers registered": + let empty = waitFor NoArgResponse.request() + check empty.get().len == 0 + + test "failed providers will fail the request": + NoArgResponse.clearProviders() + discard NoArgResponse.setProvider( + proc(): Future[Result[NoArgResponse, string]] {.async.} = + err("boom") + ) + + discard NoArgResponse.setProvider( + proc(): Future[Result[NoArgResponse, string]] {.async.} = + ok(NoArgResponse(label: "survivor")) + ) + + let filtered = waitFor NoArgResponse.request() + check filtered.isErr() + + NoArgResponse.clearProviders() + + test "deduplicates identical zero-argument providers": + NoArgResponse.clearProviders() + var invocations = 0 + let sharedHandler = proc(): Future[Result[NoArgResponse, string]] {.async.} = + inc invocations + ok(NoArgResponse(label: "dup")) + + let first = NoArgResponse.setProvider(sharedHandler) + let second = NoArgResponse.setProvider(sharedHandler) + + check first.get().id == second.get().id + check first.get().kind == second.get().kind + + let dupResponses = waitFor NoArgResponse.request() + check dupResponses.get().len == 1 + check invocations == 1 + + NoArgResponse.clearProviders() + + test "removeProvider deletes registered handlers": + var removedCalled = false + var keptCalled = false + + let removable = NoArgResponse.setProvider( + proc(): Future[Result[NoArgResponse, string]] {.async.} = + removedCalled = true + ok(NoArgResponse(label: "removed")) + ) + + discard NoArgResponse.setProvider( + proc(): Future[Result[NoArgResponse, string]] {.async.} = + keptCalled = true + ok(NoArgResponse(label: "kept")) + ) + + NoArgResponse.removeProvider(removable.get()) + + let afterRemoval = (waitFor NoArgResponse.request()).valueOr: + assert false, "request failed" + @[] + check afterRemoval.len == 1 + check afterRemoval[0].label == "kept" + check not removedCalled + check keptCalled + + NoArgResponse.clearProviders() + + test "removeProvider works for argument signatures": + var invoked: seq[string] = @[] + + discard ArgResponse.setProvider( + proc(suffix: string, num: int): Future[Result[ArgResponse, string]] {.async.} = + invoked.add("first" & suffix) + ok(ArgResponse(id: suffix & "-one-" & $num)) + ) + + let handle = ArgResponse.setProvider( + proc(suffix: string, num: int): Future[Result[ArgResponse, string]] {.async.} = + invoked.add("second" & suffix) + ok(ArgResponse(id: suffix & "-two-" & $num)) + ) + + ArgResponse.removeProvider(handle.get()) + + let single = (waitFor ArgResponse.request("topic", 1)).valueOr: + assert false, "request failed" + @[] + check single.len == 1 + check single[0].id == "topic-one-1" + check invoked == @["firsttopic"] + + ArgResponse.clearProviders() + + test "catches exception from providers and report error": + let firstHandler = NoArgResponse.setProvider( + proc(): Future[Result[NoArgResponse, string]] {.async.} = + raise newException(ValueError, "first handler raised") + ok(NoArgResponse(label: "any")) + ) + + discard NoArgResponse.setProvider( + proc(): Future[Result[NoArgResponse, string]] {.async.} = + ok(NoArgResponse(label: "just ok")) + ) + + let afterException = waitFor NoArgResponse.request() + check afterException.isErr() + check afterException.error().contains("first handler raised") + + NoArgResponse.clearProviders() + + test "ref providers returning nil fail request": + DualResponse.clearProviders() + + discard DualResponse.setProvider( + proc(): Future[Result[DualResponse, string]] {.async.} = + let nilResponse: DualResponse = nil + ok(nilResponse) + ) + + let zeroArg = waitFor DualResponse.request() + check zeroArg.isErr() + + DualResponse.clearProviders() + + discard DualResponse.setProvider( + proc(suffix: string): Future[Result[DualResponse, string]] {.async.} = + let nilResponse: DualResponse = nil + ok(nilResponse) + ) + + let withInput = waitFor DualResponse.request("-extra") + check withInput.isErr() + + DualResponse.clearProviders() diff --git a/tests/common/test_request_broker.nim b/tests/common/test_request_broker.nim new file mode 100644 index 000000000..2ffd9cbf8 --- /dev/null +++ b/tests/common/test_request_broker.nim @@ -0,0 +1,198 @@ +{.used.} + +import testutils/unittests +import chronos +import std/strutils + +import waku/common/broker/request_broker + +RequestBroker: + type SimpleResponse = object + value*: string + + proc signatureFetch*(): Future[Result[SimpleResponse, string]] {.async.} + +RequestBroker: + type KeyedResponse = object + key*: string + payload*: string + + proc signatureFetchWithKey*( + key: string, subKey: int + ): Future[Result[KeyedResponse, string]] {.async.} + +RequestBroker: + type DualResponse = object + note*: string + count*: int + + proc signatureNoInput*(): Future[Result[DualResponse, string]] {.async.} + proc signatureWithInput*( + suffix: string + ): Future[Result[DualResponse, string]] {.async.} + +RequestBroker: + type ImplicitResponse = ref object + note*: string + +suite "RequestBroker macro": + test "serves zero-argument providers": + check SimpleResponse + .setProvider( + proc(): Future[Result[SimpleResponse, string]] {.async.} = + ok(SimpleResponse(value: "hi")) + ) + .isOk() + + let res = waitFor SimpleResponse.request() + check res.isOk() + check res.value.value == "hi" + + SimpleResponse.clearProvider() + + test "zero-argument request errors when unset": + let res = waitFor SimpleResponse.request() + check res.isErr + check res.error.contains("no zero-arg provider") + + test "serves input-based providers": + var seen: seq[string] = @[] + check KeyedResponse + .setProvider( + proc(key: string, subKey: int): Future[Result[KeyedResponse, string]] {.async.} = + seen.add(key) + ok(KeyedResponse(key: key, payload: key & "-payload+" & $subKey)) + ) + .isOk() + + let res = waitFor KeyedResponse.request("topic", 1) + check res.isOk() + check res.value.key == "topic" + check res.value.payload == "topic-payload+1" + check seen == @["topic"] + + KeyedResponse.clearProvider() + + test "catches provider exception": + check KeyedResponse + .setProvider( + proc(key: string, subKey: int): Future[Result[KeyedResponse, string]] {.async.} = + raise newException(ValueError, "simulated failure") + ok(KeyedResponse(key: key, payload: "")) + ) + .isOk() + + let res = waitFor KeyedResponse.request("neglected", 11) + check res.isErr() + check res.error.contains("simulated failure") + + KeyedResponse.clearProvider() + + test "input request errors when unset": + let res = waitFor KeyedResponse.request("foo", 2) + check res.isErr + check res.error.contains("input signature") + + test "supports both provider types simultaneously": + check DualResponse + .setProvider( + proc(): Future[Result[DualResponse, string]] {.async.} = + ok(DualResponse(note: "base", count: 1)) + ) + .isOk() + + check DualResponse + .setProvider( + proc(suffix: string): Future[Result[DualResponse, string]] {.async.} = + ok(DualResponse(note: "base" & suffix, count: suffix.len)) + ) + .isOk() + + let noInput = waitFor DualResponse.request() + check noInput.isOk + check noInput.value.note == "base" + + let withInput = waitFor DualResponse.request("-extra") + check withInput.isOk + check withInput.value.note == "base-extra" + check withInput.value.count == 6 + + DualResponse.clearProvider() + + test "clearProvider resets both entries": + check DualResponse + .setProvider( + proc(): Future[Result[DualResponse, string]] {.async.} = + ok(DualResponse(note: "temp", count: 0)) + ) + .isOk() + DualResponse.clearProvider() + + let res = waitFor DualResponse.request() + check res.isErr + + test "implicit zero-argument provider works by default": + check ImplicitResponse + .setProvider( + proc(): Future[Result[ImplicitResponse, string]] {.async.} = + ok(ImplicitResponse(note: "auto")) + ) + .isOk() + + let res = waitFor ImplicitResponse.request() + check res.isOk + + ImplicitResponse.clearProvider() + check res.value.note == "auto" + + test "implicit zero-argument request errors when unset": + let res = waitFor ImplicitResponse.request() + check res.isErr + check res.error.contains("no zero-arg provider") + + test "no provider override": + check DualResponse + .setProvider( + proc(): Future[Result[DualResponse, string]] {.async.} = + ok(DualResponse(note: "base", count: 1)) + ) + .isOk() + + check DualResponse + .setProvider( + proc(suffix: string): Future[Result[DualResponse, string]] {.async.} = + ok(DualResponse(note: "base" & suffix, count: suffix.len)) + ) + .isOk() + + let overrideProc = proc(): Future[Result[DualResponse, string]] {.async.} = + ok(DualResponse(note: "something else", count: 1)) + + check DualResponse.setProvider(overrideProc).isErr() + + let noInput = waitFor DualResponse.request() + check noInput.isOk + check noInput.value.note == "base" + + let stillResponse = waitFor DualResponse.request(" still works") + check stillResponse.isOk() + check stillResponse.value.note.contains("base still works") + + DualResponse.clearProvider() + + let noResponse = waitFor DualResponse.request() + check noResponse.isErr() + check noResponse.error.contains("no zero-arg provider") + + let noResponseArg = waitFor DualResponse.request("Should not work") + check noResponseArg.isErr() + check noResponseArg.error.contains("no provider") + + check DualResponse.setProvider(overrideProc).isOk() + + let nowSuccWithOverride = waitFor DualResponse.request() + check nowSuccWithOverride.isOk + check nowSuccWithOverride.value.note == "something else" + check nowSuccWithOverride.value.count == 1 + + DualResponse.clearProvider() diff --git a/waku/common/broker/event_broker.nim b/waku/common/broker/event_broker.nim new file mode 100644 index 000000000..05d7b50ab --- /dev/null +++ b/waku/common/broker/event_broker.nim @@ -0,0 +1,308 @@ +## EventBroker +## ------------------- +## EventBroker represents a reactive decoupling pattern, that +## allows event-driven development without +## need for direct dependencies in between emitters and listeners. +## Worth considering using it in a single or many emitters to many listeners scenario. +## +## Generates a standalone, type-safe event broker for the declared object type. +## The macro exports the value type itself plus a broker companion that manages +## listeners via thread-local storage. +## +## Usage: +## Declare your desired event type inside an `EventBroker` macro, add any number of fields.: +## ```nim +## EventBroker: +## type TypeName = object +## field1*: FieldType +## field2*: AnotherFieldType +## ``` +## +## After this, you can register async listeners anywhere in your code with +## `TypeName.listen(...)`, which returns a handle to the registered listener. +## Listeners are async procs or lambdas that take a single argument of the event type. +## Any number of listeners can be registered in different modules. +## +## Events can be emitted from anywhere with no direct dependency on the listeners by +## calling `TypeName.emit(...)` with an instance of the event type. +## This will asynchronously notify all registered listeners with the emitted event. +## +## Whenever you no longer need a listener (or your object instance that listen to the event goes out of scope), +## you can remove it from the broker with the handle returned by `listen`. +## This is done by calling `TypeName.dropListener(handle)`. +## Alternatively, you can remove all registered listeners through `TypeName.dropAllListeners()`. +## +## +## Example: +## ```nim +## EventBroker: +## type GreetingEvent = object +## text*: string +## +## let handle = GreetingEvent.listen( +## proc(evt: GreetingEvent): Future[void] {.async.} = +## echo evt.text +## ) +## GreetingEvent.emit(text= "hi") +## GreetingEvent.dropListener(handle) +## ``` + +import std/[macros, tables] +import chronos, chronicles, results +import ./helper/broker_utils + +export chronicles, results, chronos + +macro EventBroker*(body: untyped): untyped = + when defined(eventBrokerDebug): + echo body.treeRepr + var typeIdent: NimNode = nil + var objectDef: NimNode = nil + var fieldNames: seq[NimNode] = @[] + var fieldTypes: seq[NimNode] = @[] + var isRefObject = false + for stmt in body: + if stmt.kind == nnkTypeSection: + for def in stmt: + if def.kind != nnkTypeDef: + continue + let rhs = def[2] + var objectType: NimNode + case rhs.kind + of nnkObjectTy: + objectType = rhs + of nnkRefTy: + isRefObject = true + if rhs.len != 1 or rhs[0].kind != nnkObjectTy: + error("EventBroker ref object must wrap a concrete object definition", rhs) + objectType = rhs[0] + else: + continue + if not typeIdent.isNil(): + error("Only one object type may be declared inside EventBroker", def) + typeIdent = baseTypeIdent(def[0]) + let recList = objectType[2] + if recList.kind != nnkRecList: + error("EventBroker object must declare a standard field list", objectType) + var exportedRecList = newTree(nnkRecList) + for field in recList: + case field.kind + of nnkIdentDefs: + ensureFieldDef(field) + let fieldTypeNode = field[field.len - 2] + for i in 0 ..< field.len - 2: + let baseFieldIdent = baseTypeIdent(field[i]) + fieldNames.add(copyNimTree(baseFieldIdent)) + fieldTypes.add(copyNimTree(fieldTypeNode)) + var cloned = copyNimTree(field) + for i in 0 ..< cloned.len - 2: + cloned[i] = exportIdentNode(cloned[i]) + exportedRecList.add(cloned) + of nnkEmpty: + discard + else: + error( + "EventBroker object definition only supports simple field declarations", + field, + ) + let exportedObjectType = newTree( + nnkObjectTy, + copyNimTree(objectType[0]), + copyNimTree(objectType[1]), + exportedRecList, + ) + if isRefObject: + objectDef = newTree(nnkRefTy, exportedObjectType) + else: + objectDef = exportedObjectType + if typeIdent.isNil(): + error("EventBroker body must declare exactly one object type", body) + + let exportedTypeIdent = postfix(copyNimTree(typeIdent), "*") + let sanitized = sanitizeIdentName(typeIdent) + let typeNameLit = newLit($typeIdent) + let isRefObjectLit = newLit(isRefObject) + let handlerProcIdent = ident(sanitized & "ListenerProc") + let listenerHandleIdent = ident(sanitized & "Listener") + let brokerTypeIdent = ident(sanitized & "Broker") + let exportedHandlerProcIdent = postfix(copyNimTree(handlerProcIdent), "*") + let exportedListenerHandleIdent = postfix(copyNimTree(listenerHandleIdent), "*") + let exportedBrokerTypeIdent = postfix(copyNimTree(brokerTypeIdent), "*") + let accessProcIdent = ident("access" & sanitized & "Broker") + let globalVarIdent = ident("g" & sanitized & "Broker") + let listenImplIdent = ident("register" & sanitized & "Listener") + let dropListenerImplIdent = ident("drop" & sanitized & "Listener") + let dropAllListenersImplIdent = ident("dropAll" & sanitized & "Listeners") + let emitImplIdent = ident("emit" & sanitized & "Value") + let listenerTaskIdent = ident("notify" & sanitized & "Listener") + + result = newStmtList() + + result.add( + quote do: + type + `exportedTypeIdent` = `objectDef` + `exportedListenerHandleIdent` = object + id*: uint64 + + `exportedHandlerProcIdent` = + proc(event: `typeIdent`): Future[void] {.async: (raises: []), gcsafe.} + `exportedBrokerTypeIdent` = ref object + listeners: Table[uint64, `handlerProcIdent`] + nextId: uint64 + + ) + + result.add( + quote do: + var `globalVarIdent` {.threadvar.}: `brokerTypeIdent` + ) + + result.add( + quote do: + proc `accessProcIdent`(): `brokerTypeIdent` = + if `globalVarIdent`.isNil(): + new(`globalVarIdent`) + `globalVarIdent`.listeners = initTable[uint64, `handlerProcIdent`]() + `globalVarIdent` + + ) + + result.add( + quote do: + proc `listenImplIdent`( + handler: `handlerProcIdent` + ): Result[`listenerHandleIdent`, string] = + if handler.isNil(): + return err("Must provide a non-nil event handler") + var broker = `accessProcIdent`() + if broker.nextId == 0'u64: + broker.nextId = 1'u64 + if broker.nextId == high(uint64): + error "Cannot add more listeners: ID space exhausted", nextId = $broker.nextId + return err("Cannot add more listeners, listener ID space exhausted") + let newId = broker.nextId + inc broker.nextId + broker.listeners[newId] = handler + return ok(`listenerHandleIdent`(id: newId)) + + ) + + result.add( + quote do: + proc `dropListenerImplIdent`(handle: `listenerHandleIdent`) = + if handle.id == 0'u64: + return + var broker = `accessProcIdent`() + if broker.listeners.len == 0: + return + broker.listeners.del(handle.id) + + ) + + result.add( + quote do: + proc `dropAllListenersImplIdent`() = + var broker = `accessProcIdent`() + if broker.listeners.len > 0: + broker.listeners.clear() + + ) + + result.add( + quote do: + proc listen*( + _: typedesc[`typeIdent`], handler: `handlerProcIdent` + ): Result[`listenerHandleIdent`, string] = + return `listenImplIdent`(handler) + + ) + + result.add( + quote do: + proc dropListener*(_: typedesc[`typeIdent`], handle: `listenerHandleIdent`) = + `dropListenerImplIdent`(handle) + + proc dropAllListeners*(_: typedesc[`typeIdent`]) = + `dropAllListenersImplIdent`() + + ) + + result.add( + quote do: + proc `listenerTaskIdent`( + callback: `handlerProcIdent`, event: `typeIdent` + ) {.async: (raises: []), gcsafe.} = + if callback.isNil(): + return + try: + await callback(event) + except Exception: + error "Failed to execute event listener", error = getCurrentExceptionMsg() + + proc `emitImplIdent`( + event: `typeIdent` + ): Future[void] {.async: (raises: []), gcsafe.} = + when `isRefObjectLit`: + if event.isNil(): + error "Cannot emit uninitialized event object", eventType = `typeNameLit` + return + let broker = `accessProcIdent`() + if broker.listeners.len == 0: + # nothing to do as nobody is listening + return + var callbacks: seq[`handlerProcIdent`] = @[] + for cb in broker.listeners.values: + callbacks.add(cb) + for cb in callbacks: + asyncSpawn `listenerTaskIdent`(cb, event) + + proc emit*(event: `typeIdent`) = + asyncSpawn `emitImplIdent`(event) + + proc emit*(_: typedesc[`typeIdent`], event: `typeIdent`) = + asyncSpawn `emitImplIdent`(event) + + ) + + var emitCtorParams = newTree(nnkFormalParams, newEmptyNode()) + let typedescParamType = + newTree(nnkBracketExpr, ident("typedesc"), copyNimTree(typeIdent)) + emitCtorParams.add( + newTree(nnkIdentDefs, ident("_"), typedescParamType, newEmptyNode()) + ) + for i in 0 ..< fieldNames.len: + emitCtorParams.add( + newTree( + nnkIdentDefs, + copyNimTree(fieldNames[i]), + copyNimTree(fieldTypes[i]), + newEmptyNode(), + ) + ) + + var emitCtorExpr = newTree(nnkObjConstr, copyNimTree(typeIdent)) + for i in 0 ..< fieldNames.len: + emitCtorExpr.add( + newTree(nnkExprColonExpr, copyNimTree(fieldNames[i]), copyNimTree(fieldNames[i])) + ) + + let emitCtorCall = newCall(copyNimTree(emitImplIdent), emitCtorExpr) + let emitCtorBody = quote: + asyncSpawn `emitCtorCall` + + let typedescEmitProc = newTree( + nnkProcDef, + postfix(ident("emit"), "*"), + newEmptyNode(), + newEmptyNode(), + emitCtorParams, + newEmptyNode(), + newEmptyNode(), + emitCtorBody, + ) + + result.add(typedescEmitProc) + + when defined(eventBrokerDebug): + echo result.repr diff --git a/waku/common/broker/helper/broker_utils.nim b/waku/common/broker/helper/broker_utils.nim new file mode 100644 index 000000000..ea9f85750 --- /dev/null +++ b/waku/common/broker/helper/broker_utils.nim @@ -0,0 +1,43 @@ +import std/macros + +proc sanitizeIdentName*(node: NimNode): string = + var raw = $node + var sanitizedName = newStringOfCap(raw.len) + for ch in raw: + case ch + of 'A' .. 'Z', 'a' .. 'z', '0' .. '9', '_': + sanitizedName.add(ch) + else: + sanitizedName.add('_') + sanitizedName + +proc ensureFieldDef*(node: NimNode) = + if node.kind != nnkIdentDefs or node.len < 3: + error("Expected field definition of the form `name: Type`", node) + let typeSlot = node.len - 2 + if node[typeSlot].kind == nnkEmpty: + error("Field `" & $node[0] & "` must declare a type", node) + +proc exportIdentNode*(node: NimNode): NimNode = + case node.kind + of nnkIdent: + postfix(copyNimTree(node), "*") + of nnkPostfix: + node + else: + error("Unsupported identifier form in field definition", node) + +proc baseTypeIdent*(defName: NimNode): NimNode = + case defName.kind + of nnkIdent: + defName + of nnkAccQuoted: + if defName.len != 1: + error("Unsupported quoted identifier", defName) + defName[0] + of nnkPostfix: + baseTypeIdent(defName[1]) + of nnkPragmaExpr: + baseTypeIdent(defName[0]) + else: + error("Unsupported type name in broker definition", defName) diff --git a/waku/common/broker/multi_request_broker.nim b/waku/common/broker/multi_request_broker.nim new file mode 100644 index 000000000..7f4161f5a --- /dev/null +++ b/waku/common/broker/multi_request_broker.nim @@ -0,0 +1,583 @@ +## MultiRequestBroker +## -------------------- +## MultiRequestBroker represents a proactive decoupling pattern, that +## allows defining request-response style interactions between modules without +## need for direct dependencies in between. +## Worth considering using it for use cases where you need to collect data from multiple providers. +## +## Provides a declarative way to define an immutable value type together with a +## thread-local broker that can register multiple asynchronous providers, dispatch +## typed requests, and clear handlers. Unlike `RequestBroker`, +## every call to `request` fan-outs to every registered provider and returns with +## collected responses. +## Request succeeds if all providers succeed, otherwise fails with an error. +## +## Usage: +## +## Declare collectable request data type inside a `MultiRequestBroker` macro, add any number of fields: +## ```nim +## MultiRequestBroker: +## type TypeName = object +## field1*: Type1 +## field2*: Type2 +## +## ## Define the request and provider signature, that is enforced at compile time. +## proc signature*(): Future[Result[TypeName, string]] {.async: (raises: []).} +## +## ## Also possible to define signature with arbitrary input arguments. +## proc signature*(arg1: ArgType, arg2: AnotherArgType): Future[Result[TypeName, string]] {.async: (raises: []).} +## +## ``` +## +## You regiser request processor (proveder) at any place of the code without the need to know of who ever may request. +## Respectively to the defined signatures register provider functions with `TypeName.setProvider(...)`. +## Providers are async procs or lambdas that return with a Future[Result[seq[TypeName], string]]. +## Notice MultiRequestBroker's `setProvider` return with a handler that can be used to remove the provider later (or error). + +## Requests can be made from anywhere with no direct dependency on the provider(s) by +## calling `TypeName.request()` - with arguments respecting the signature(s). +## This will asynchronously call the registered provider and return the collected data, in form of `Future[Result[seq[TypeName], string]]`. +## +## Whenever you don't want to process requests anymore (or your object instance that provides the request goes out of scope), +## you can remove it from the broker with `TypeName.removeProvider(handle)`. +## Alternatively, you can remove all registered providers through `TypeName.clearProviders()`. +## +## Example: +## ```nim +## MultiRequestBroker: +## type Greeting = object +## text*: string +## +## ## Define the request and provider signature, that is enforced at compile time. +## proc signature*(): Future[Result[Greeting, string]] {.async: (raises: []).} +## +## ## Also possible to define signature with arbitrary input arguments. +## proc signature*(lang: string): Future[Result[Greeting, string]] {.async: (raises: []).} +## +## ... +## let handle = Greeting.setProvider( +## proc(): Future[Result[Greeting, string]] {.async: (raises: []).} = +## ok(Greeting(text: "hello")) +## ) +## +## let anotherHandle = Greeting.setProvider( +## proc(): Future[Result[Greeting, string]] {.async: (raises: []).} = +## ok(Greeting(text: "szia")) +## ) +## +## let responses = (await Greeting.request()).valueOr(@[Greeting(text: "default")]) +## +## echo responses.len +## Greeting.clearProviders() +## ``` +## If no `signature` proc is declared, a zero-argument form is generated +## automatically, so the caller only needs to provide the type definition. + +import std/[macros, strutils, tables, sugar] +import chronos +import results +import ./helper/broker_utils + +export results, chronos + +proc isReturnTypeValid(returnType, typeIdent: NimNode): bool = + ## Accept Future[Result[TypeIdent, string]] as the contract. + if returnType.kind != nnkBracketExpr or returnType.len != 2: + return false + if returnType[0].kind != nnkIdent or not returnType[0].eqIdent("Future"): + return false + let inner = returnType[1] + if inner.kind != nnkBracketExpr or inner.len != 3: + return false + if inner[0].kind != nnkIdent or not inner[0].eqIdent("Result"): + return false + if inner[1].kind != nnkIdent or not inner[1].eqIdent($typeIdent): + return false + inner[2].kind == nnkIdent and inner[2].eqIdent("string") + +proc cloneParams(params: seq[NimNode]): seq[NimNode] = + ## Deep copy parameter definitions so they can be reused in generated nodes. + result = @[] + for param in params: + result.add(copyNimTree(param)) + +proc collectParamNames(params: seq[NimNode]): seq[NimNode] = + ## Extract identifiers declared in parameter definitions. + result = @[] + for param in params: + assert param.kind == nnkIdentDefs + for i in 0 ..< param.len - 2: + let nameNode = param[i] + if nameNode.kind == nnkEmpty: + continue + result.add(ident($nameNode)) + +proc makeProcType(returnType: NimNode, params: seq[NimNode]): NimNode = + var formal = newTree(nnkFormalParams) + formal.add(returnType) + for param in params: + formal.add(param) + + let pragmas = quote: + {.async.} + + newTree(nnkProcTy, formal, pragmas) + +macro MultiRequestBroker*(body: untyped): untyped = + when defined(requestBrokerDebug): + echo body.treeRepr + var typeIdent: NimNode = nil + var objectDef: NimNode = nil + var isRefObject = false + for stmt in body: + if stmt.kind == nnkTypeSection: + for def in stmt: + if def.kind != nnkTypeDef: + continue + let rhs = def[2] + var objectType: NimNode + case rhs.kind + of nnkObjectTy: + objectType = rhs + of nnkRefTy: + isRefObject = true + if rhs.len != 1 or rhs[0].kind != nnkObjectTy: + error( + "MultiRequestBroker ref object must wrap a concrete object definition", + rhs, + ) + objectType = rhs[0] + else: + continue + if not typeIdent.isNil(): + error("Only one object type may be declared inside MultiRequestBroker", def) + typeIdent = baseTypeIdent(def[0]) + let recList = objectType[2] + if recList.kind != nnkRecList: + error( + "MultiRequestBroker object must declare a standard field list", objectType + ) + var exportedRecList = newTree(nnkRecList) + for field in recList: + case field.kind + of nnkIdentDefs: + ensureFieldDef(field) + var cloned = copyNimTree(field) + for i in 0 ..< cloned.len - 2: + cloned[i] = exportIdentNode(cloned[i]) + exportedRecList.add(cloned) + of nnkEmpty: + discard + else: + error( + "MultiRequestBroker object definition only supports simple field declarations", + field, + ) + let exportedObjectType = newTree( + nnkObjectTy, + copyNimTree(objectType[0]), + copyNimTree(objectType[1]), + exportedRecList, + ) + if isRefObject: + objectDef = newTree(nnkRefTy, exportedObjectType) + else: + objectDef = exportedObjectType + if typeIdent.isNil(): + error("MultiRequestBroker body must declare exactly one object type", body) + + when defined(requestBrokerDebug): + echo "MultiRequestBroker generating type: ", $typeIdent + + let exportedTypeIdent = postfix(copyNimTree(typeIdent), "*") + let sanitized = sanitizeIdentName(typeIdent) + let typeNameLit = newLit($typeIdent) + let isRefObjectLit = newLit(isRefObject) + let tableSym = bindSym"Table" + let initTableSym = bindSym"initTable" + let uint64Ident = ident("uint64") + let providerKindIdent = ident(sanitized & "ProviderKind") + let providerHandleIdent = ident(sanitized & "ProviderHandle") + let exportedProviderHandleIdent = postfix(copyNimTree(providerHandleIdent), "*") + let zeroKindIdent = ident("pk" & sanitized & "NoArgs") + let argKindIdent = ident("pk" & sanitized & "WithArgs") + var zeroArgSig: NimNode = nil + var zeroArgProviderName: NimNode = nil + var zeroArgFieldName: NimNode = nil + var argSig: NimNode = nil + var argParams: seq[NimNode] = @[] + var argProviderName: NimNode = nil + var argFieldName: NimNode = nil + + for stmt in body: + case stmt.kind + of nnkProcDef: + let procName = stmt[0] + let procNameIdent = + case procName.kind + of nnkIdent: + procName + of nnkPostfix: + procName[1] + else: + procName + let procNameStr = $procNameIdent + if not procNameStr.startsWith("signature"): + error("Signature proc names must start with `signature`", procName) + let params = stmt.params + if params.len == 0: + error("Signature must declare a return type", stmt) + let returnType = params[0] + if not isReturnTypeValid(returnType, typeIdent): + error( + "Signature must return Future[Result[`" & $typeIdent & "`, string]]", stmt + ) + let paramCount = params.len - 1 + if paramCount == 0: + if zeroArgSig != nil: + error("Only one zero-argument signature is allowed", stmt) + zeroArgSig = stmt + zeroArgProviderName = ident(sanitizeIdentName(typeIdent) & "ProviderNoArgs") + zeroArgFieldName = ident("providerNoArgs") + elif paramCount >= 1: + if argSig != nil: + error("Only one argument-based signature is allowed", stmt) + argSig = stmt + argParams = @[] + for idx in 1 ..< params.len: + let paramDef = params[idx] + if paramDef.kind != nnkIdentDefs: + error( + "Signature parameter must be a standard identifier declaration", paramDef + ) + let paramTypeNode = paramDef[paramDef.len - 2] + if paramTypeNode.kind == nnkEmpty: + error("Signature parameter must declare a type", paramDef) + var hasName = false + for i in 0 ..< paramDef.len - 2: + if paramDef[i].kind != nnkEmpty: + hasName = true + if not hasName: + error("Signature parameter must declare a name", paramDef) + argParams.add(copyNimTree(paramDef)) + argProviderName = ident(sanitizeIdentName(typeIdent) & "ProviderWithArgs") + argFieldName = ident("providerWithArgs") + of nnkTypeSection, nnkEmpty: + discard + else: + error("Unsupported statement inside MultiRequestBroker definition", stmt) + + if zeroArgSig.isNil() and argSig.isNil(): + zeroArgSig = newEmptyNode() + zeroArgProviderName = ident(sanitizeIdentName(typeIdent) & "ProviderNoArgs") + zeroArgFieldName = ident("providerNoArgs") + + var typeSection = newTree(nnkTypeSection) + typeSection.add(newTree(nnkTypeDef, exportedTypeIdent, newEmptyNode(), objectDef)) + + var kindEnum = newTree(nnkEnumTy, newEmptyNode()) + if not zeroArgSig.isNil(): + kindEnum.add(zeroKindIdent) + if not argSig.isNil(): + kindEnum.add(argKindIdent) + typeSection.add(newTree(nnkTypeDef, providerKindIdent, newEmptyNode(), kindEnum)) + + var handleRecList = newTree(nnkRecList) + handleRecList.add(newTree(nnkIdentDefs, ident("id"), uint64Ident, newEmptyNode())) + handleRecList.add( + newTree(nnkIdentDefs, ident("kind"), providerKindIdent, newEmptyNode()) + ) + typeSection.add( + newTree( + nnkTypeDef, + exportedProviderHandleIdent, + newEmptyNode(), + newTree(nnkObjectTy, newEmptyNode(), newEmptyNode(), handleRecList), + ) + ) + + let returnType = quote: + Future[Result[`typeIdent`, string]] + + if not zeroArgSig.isNil(): + let procType = makeProcType(returnType, @[]) + typeSection.add(newTree(nnkTypeDef, zeroArgProviderName, newEmptyNode(), procType)) + if not argSig.isNil(): + let procType = makeProcType(returnType, cloneParams(argParams)) + typeSection.add(newTree(nnkTypeDef, argProviderName, newEmptyNode(), procType)) + + var brokerRecList = newTree(nnkRecList) + if not zeroArgSig.isNil(): + brokerRecList.add( + newTree( + nnkIdentDefs, + zeroArgFieldName, + newTree(nnkBracketExpr, tableSym, uint64Ident, zeroArgProviderName), + newEmptyNode(), + ) + ) + if not argSig.isNil(): + brokerRecList.add( + newTree( + nnkIdentDefs, + argFieldName, + newTree(nnkBracketExpr, tableSym, uint64Ident, argProviderName), + newEmptyNode(), + ) + ) + brokerRecList.add(newTree(nnkIdentDefs, ident("nextId"), uint64Ident, newEmptyNode())) + let brokerTypeIdent = ident(sanitizeIdentName(typeIdent) & "Broker") + let brokerTypeDef = newTree( + nnkTypeDef, + brokerTypeIdent, + newEmptyNode(), + newTree( + nnkRefTy, newTree(nnkObjectTy, newEmptyNode(), newEmptyNode(), brokerRecList) + ), + ) + typeSection.add(brokerTypeDef) + result = newStmtList() + result.add(typeSection) + + let globalVarIdent = ident("g" & sanitizeIdentName(typeIdent) & "Broker") + let accessProcIdent = ident("access" & sanitizeIdentName(typeIdent) & "Broker") + var initStatements = newStmtList() + if not zeroArgSig.isNil(): + initStatements.add( + quote do: + `globalVarIdent`.`zeroArgFieldName` = + `initTableSym`[`uint64Ident`, `zeroArgProviderName`]() + ) + if not argSig.isNil(): + initStatements.add( + quote do: + `globalVarIdent`.`argFieldName` = + `initTableSym`[`uint64Ident`, `argProviderName`]() + ) + result.add( + quote do: + var `globalVarIdent` {.threadvar.}: `brokerTypeIdent` + + proc `accessProcIdent`(): `brokerTypeIdent` = + if `globalVarIdent`.isNil(): + new(`globalVarIdent`) + `globalVarIdent`.nextId = 1'u64 + `initStatements` + return `globalVarIdent` + + ) + + var clearBody = newStmtList() + if not zeroArgSig.isNil(): + result.add( + quote do: + proc setProvider*( + _: typedesc[`typeIdent`], handler: `zeroArgProviderName` + ): Result[`providerHandleIdent`, string] = + if handler.isNil(): + return err("Provider handler must be provided") + let broker = `accessProcIdent`() + if broker.nextId == 0'u64: + broker.nextId = 1'u64 + for existingId, existing in broker.`zeroArgFieldName`.pairs: + if existing == handler: + return ok(`providerHandleIdent`(id: existingId, kind: `zeroKindIdent`)) + let newId = broker.nextId + inc broker.nextId + broker.`zeroArgFieldName`[newId] = handler + return ok(`providerHandleIdent`(id: newId, kind: `zeroKindIdent`)) + + ) + clearBody.add( + quote do: + let broker = `accessProcIdent`() + if not broker.isNil() and broker.`zeroArgFieldName`.len > 0: + broker.`zeroArgFieldName`.clear() + ) + result.add( + quote do: + proc request*( + _: typedesc[`typeIdent`] + ): Future[Result[seq[`typeIdent`], string]] {.async: (raises: []), gcsafe.} = + var aggregated: seq[`typeIdent`] = @[] + let providers = `accessProcIdent`().`zeroArgFieldName` + if providers.len == 0: + return ok(aggregated) + # var providersFut: seq[Future[Result[`typeIdent`, string]]] = collect: + var providersFut = collect(newSeq): + for provider in providers.values: + if provider.isNil(): + continue + provider() + + let catchable = catch: + await allFinished(providersFut) + + catchable.isOkOr: + return err("Some provider(s) failed:" & error.msg) + + for fut in catchable.get(): + if fut.failed(): + return err("Some provider(s) failed:" & fut.error.msg) + elif fut.finished(): + let providerResult = fut.value() + if providerResult.isOk: + let providerValue = providerResult.get() + when `isRefObjectLit`: + if providerValue.isNil(): + return err( + "MultiRequestBroker(" & `typeNameLit` & + "): provider returned nil result" + ) + aggregated.add(providerValue) + else: + return err("Some provider(s) failed:" & providerResult.error) + + return ok(aggregated) + + ) + if not argSig.isNil(): + result.add( + quote do: + proc setProvider*( + _: typedesc[`typeIdent`], handler: `argProviderName` + ): Result[`providerHandleIdent`, string] = + if handler.isNil(): + return err("Provider handler must be provided") + let broker = `accessProcIdent`() + if broker.nextId == 0'u64: + broker.nextId = 1'u64 + for existingId, existing in broker.`argFieldName`.pairs: + if existing == handler: + return ok(`providerHandleIdent`(id: existingId, kind: `argKindIdent`)) + let newId = broker.nextId + inc broker.nextId + broker.`argFieldName`[newId] = handler + return ok(`providerHandleIdent`(id: newId, kind: `argKindIdent`)) + + ) + clearBody.add( + quote do: + let broker = `accessProcIdent`() + if not broker.isNil() and broker.`argFieldName`.len > 0: + broker.`argFieldName`.clear() + ) + let requestParamDefs = cloneParams(argParams) + let argNameIdents = collectParamNames(requestParamDefs) + let providerSym = genSym(nskLet, "providerVal") + var providerCall = newCall(providerSym) + for argName in argNameIdents: + providerCall.add(argName) + var formalParams = newTree(nnkFormalParams) + formalParams.add( + quote do: + Future[Result[seq[`typeIdent`], string]] + ) + formalParams.add( + newTree( + nnkIdentDefs, + ident("_"), + newTree(nnkBracketExpr, ident("typedesc"), copyNimTree(typeIdent)), + newEmptyNode(), + ) + ) + for paramDef in requestParamDefs: + formalParams.add(paramDef) + let requestPragmas = quote: + {.async: (raises: []), gcsafe.} + let requestBody = quote: + var aggregated: seq[`typeIdent`] = @[] + let providers = `accessProcIdent`().`argFieldName` + if providers.len == 0: + return ok(aggregated) + var providersFut = collect(newSeq): + for provider in providers.values: + if provider.isNil(): + continue + let `providerSym` = provider + `providerCall` + let catchable = catch: + await allFinished(providersFut) + catchable.isOkOr: + return err("Some provider(s) failed:" & error.msg) + for fut in catchable.get(): + if fut.failed(): + return err("Some provider(s) failed:" & fut.error.msg) + elif fut.finished(): + let providerResult = fut.value() + if providerResult.isOk: + let providerValue = providerResult.get() + when `isRefObjectLit`: + if providerValue.isNil(): + return err( + "MultiRequestBroker(" & `typeNameLit` & + "): provider returned nil result" + ) + aggregated.add(providerValue) + else: + return err("Some provider(s) failed:" & providerResult.error) + return ok(aggregated) + + result.add( + newTree( + nnkProcDef, + postfix(ident("request"), "*"), + newEmptyNode(), + newEmptyNode(), + formalParams, + requestPragmas, + newEmptyNode(), + requestBody, + ) + ) + + result.add( + quote do: + proc clearProviders*(_: typedesc[`typeIdent`]) = + `clearBody` + let broker = `accessProcIdent`() + if not broker.isNil(): + broker.nextId = 1'u64 + + ) + + let removeHandleSym = genSym(nskParam, "handle") + let removeBrokerSym = genSym(nskLet, "broker") + var removeBody = newStmtList() + removeBody.add( + quote do: + if `removeHandleSym`.id == 0'u64: + return + let `removeBrokerSym` = `accessProcIdent`() + if `removeBrokerSym`.isNil(): + return + ) + if not zeroArgSig.isNil(): + removeBody.add( + quote do: + if `removeHandleSym`.kind == `zeroKindIdent`: + `removeBrokerSym`.`zeroArgFieldName`.del(`removeHandleSym`.id) + return + ) + if not argSig.isNil(): + removeBody.add( + quote do: + if `removeHandleSym`.kind == `argKindIdent`: + `removeBrokerSym`.`argFieldName`.del(`removeHandleSym`.id) + return + ) + removeBody.add( + quote do: + discard + ) + result.add( + quote do: + proc removeProvider*( + _: typedesc[`typeIdent`], `removeHandleSym`: `providerHandleIdent` + ) = + `removeBody` + + ) + + when defined(requestBrokerDebug): + echo result.repr diff --git a/waku/common/broker/request_broker.nim b/waku/common/broker/request_broker.nim new file mode 100644 index 000000000..a8a6651d7 --- /dev/null +++ b/waku/common/broker/request_broker.nim @@ -0,0 +1,438 @@ +## RequestBroker +## -------------------- +## RequestBroker represents a proactive decoupling pattern, that +## allows defining request-response style interactions between modules without +## need for direct dependencies in between. +## Worth considering using it in a single provider, many requester scenario. +## +## Provides a declarative way to define an immutable value type together with a +## thread-local broker that can register an asynchronous provider, dispatch typed +## requests and clear provider. +## +## Usage: +## Declare your desired request type inside a `RequestBroker` macro, add any number of fields. +## Define the provider signature, that is enforced at compile time. +## +## ```nim +## RequestBroker: +## type TypeName = object +## field1*: FieldType +## field2*: AnotherFieldType +## +## proc signature*(): Future[Result[TypeName, string]] +## ## Also possible to define signature with arbitrary input arguments. +## proc signature*(arg1: ArgType, arg2: AnotherArgType): Future[Result[TypeName, string]] +## +## ``` +## The 'TypeName' object defines the requestable data (but also can be seen as request for action with return value). +## The 'signature' proc defines the provider(s) signature, that is enforced at compile time. +## One signature can be with no arguments, another with any number of arguments - where the input arguments are +## not related to the request type - but alternative inputs for the request to be processed. +## +## After this, you can register a provider anywhere in your code with +## `TypeName.setProvider(...)`, which returns error if already having a provider. +## Providers are async procs or lambdas that take no arguments and return a Future[Result[TypeName, string]]. +## Only one provider can be registered at a time per signature type (zero arg and/or multi arg). +## +## Requests can be made from anywhere with no direct dependency on the provider by +## calling `TypeName.request()` - with arguments respecting the signature(s). +## This will asynchronously call the registered provider and return a Future[Result[TypeName, string]]. +## +## Whenever you no want to process requests (or your object instance that provides the request goes out of scope), +## you can remove it from the broker with `TypeName.clearProvider()`. +## +## +## Example: +## ```nim +## RequestBroker: +## type Greeting = object +## text*: string +## +## ## Define the request and provider signature, that is enforced at compile time. +## proc signature*(): Future[Result[Greeting, string]] +## +## ## Also possible to define signature with arbitrary input arguments. +## proc signature*(lang: string): Future[Result[Greeting, string]] +## +## ... +## Greeting.setProvider( +## proc(): Future[Result[Greeting, string]] {.async.} = +## ok(Greeting(text: "hello")) +## ) +## let res = await Greeting.request() +## ``` +## If no `signature` proc is declared, a zero-argument form is generated +## automatically, so the caller only needs to provide the type definition. + +import std/[macros, strutils] +import chronos +import results +import ./helper/broker_utils + +export results, chronos + +proc errorFuture[T](message: string): Future[Result[T, string]] {.inline.} = + ## Build a future that is already completed with an error result. + let fut = newFuture[Result[T, string]]("request_broker.errorFuture") + fut.complete(err(Result[T, string], message)) + fut + +proc isReturnTypeValid(returnType, typeIdent: NimNode): bool = + ## Accept Future[Result[TypeIdent, string]] as the contract. + if returnType.kind != nnkBracketExpr or returnType.len != 2: + return false + if returnType[0].kind != nnkIdent or not returnType[0].eqIdent("Future"): + return false + let inner = returnType[1] + if inner.kind != nnkBracketExpr or inner.len != 3: + return false + if inner[0].kind != nnkIdent or not inner[0].eqIdent("Result"): + return false + if inner[1].kind != nnkIdent or not inner[1].eqIdent($typeIdent): + return false + inner[2].kind == nnkIdent and inner[2].eqIdent("string") + +proc cloneParams(params: seq[NimNode]): seq[NimNode] = + ## Deep copy parameter definitions so they can be inserted in multiple places. + result = @[] + for param in params: + result.add(copyNimTree(param)) + +proc collectParamNames(params: seq[NimNode]): seq[NimNode] = + ## Extract all identifier symbols declared across IdentDefs nodes. + result = @[] + for param in params: + assert param.kind == nnkIdentDefs + for i in 0 ..< param.len - 2: + let nameNode = param[i] + if nameNode.kind == nnkEmpty: + continue + result.add(ident($nameNode)) + +proc makeProcType(returnType: NimNode, params: seq[NimNode]): NimNode = + var formal = newTree(nnkFormalParams) + formal.add(returnType) + for param in params: + formal.add(param) + let pragmas = newTree(nnkPragma, ident("async")) + newTree(nnkProcTy, formal, pragmas) + +macro RequestBroker*(body: untyped): untyped = + when defined(requestBrokerDebug): + echo body.treeRepr + var typeIdent: NimNode = nil + var objectDef: NimNode = nil + var isRefObject = false + for stmt in body: + if stmt.kind == nnkTypeSection: + for def in stmt: + if def.kind != nnkTypeDef: + continue + let rhs = def[2] + var objectType: NimNode + case rhs.kind + of nnkObjectTy: + objectType = rhs + of nnkRefTy: + isRefObject = true + if rhs.len != 1 or rhs[0].kind != nnkObjectTy: + error( + "RequestBroker ref object must wrap a concrete object definition", rhs + ) + objectType = rhs[0] + else: + continue + if not typeIdent.isNil(): + error("Only one object type may be declared inside RequestBroker", def) + typeIdent = baseTypeIdent(def[0]) + let recList = objectType[2] + if recList.kind != nnkRecList: + error("RequestBroker object must declare a standard field list", objectType) + var exportedRecList = newTree(nnkRecList) + for field in recList: + case field.kind + of nnkIdentDefs: + ensureFieldDef(field) + var cloned = copyNimTree(field) + for i in 0 ..< cloned.len - 2: + cloned[i] = exportIdentNode(cloned[i]) + exportedRecList.add(cloned) + of nnkEmpty: + discard + else: + error( + "RequestBroker object definition only supports simple field declarations", + field, + ) + let exportedObjectType = newTree( + nnkObjectTy, + copyNimTree(objectType[0]), + copyNimTree(objectType[1]), + exportedRecList, + ) + if isRefObject: + objectDef = newTree(nnkRefTy, exportedObjectType) + else: + objectDef = exportedObjectType + if typeIdent.isNil(): + error("RequestBroker body must declare exactly one object type", body) + + when defined(requestBrokerDebug): + echo "RequestBroker generating type: ", $typeIdent + + let exportedTypeIdent = postfix(copyNimTree(typeIdent), "*") + let typeDisplayName = sanitizeIdentName(typeIdent) + let typeNameLit = newLit(typeDisplayName) + let isRefObjectLit = newLit(isRefObject) + var zeroArgSig: NimNode = nil + var zeroArgProviderName: NimNode = nil + var zeroArgFieldName: NimNode = nil + var argSig: NimNode = nil + var argParams: seq[NimNode] = @[] + var argProviderName: NimNode = nil + var argFieldName: NimNode = nil + + for stmt in body: + case stmt.kind + of nnkProcDef: + let procName = stmt[0] + let procNameIdent = + case procName.kind + of nnkIdent: + procName + of nnkPostfix: + procName[1] + else: + procName + let procNameStr = $procNameIdent + if not procNameStr.startsWith("signature"): + error("Signature proc names must start with `signature`", procName) + let params = stmt.params + if params.len == 0: + error("Signature must declare a return type", stmt) + let returnType = params[0] + if not isReturnTypeValid(returnType, typeIdent): + error( + "Signature must return Future[Result[`" & $typeIdent & "`, string]]", stmt + ) + let paramCount = params.len - 1 + if paramCount == 0: + if zeroArgSig != nil: + error("Only one zero-argument signature is allowed", stmt) + zeroArgSig = stmt + zeroArgProviderName = ident(sanitizeIdentName(typeIdent) & "ProviderNoArgs") + zeroArgFieldName = ident("providerNoArgs") + elif paramCount >= 1: + if argSig != nil: + error("Only one argument-based signature is allowed", stmt) + argSig = stmt + argParams = @[] + for idx in 1 ..< params.len: + let paramDef = params[idx] + if paramDef.kind != nnkIdentDefs: + error( + "Signature parameter must be a standard identifier declaration", paramDef + ) + let paramTypeNode = paramDef[paramDef.len - 2] + if paramTypeNode.kind == nnkEmpty: + error("Signature parameter must declare a type", paramDef) + var hasName = false + for i in 0 ..< paramDef.len - 2: + if paramDef[i].kind != nnkEmpty: + hasName = true + if not hasName: + error("Signature parameter must declare a name", paramDef) + argParams.add(copyNimTree(paramDef)) + argProviderName = ident(sanitizeIdentName(typeIdent) & "ProviderWithArgs") + argFieldName = ident("providerWithArgs") + of nnkTypeSection, nnkEmpty: + discard + else: + error("Unsupported statement inside RequestBroker definition", stmt) + + if zeroArgSig.isNil() and argSig.isNil(): + zeroArgSig = newEmptyNode() + zeroArgProviderName = ident(sanitizeIdentName(typeIdent) & "ProviderNoArgs") + zeroArgFieldName = ident("providerNoArgs") + + var typeSection = newTree(nnkTypeSection) + typeSection.add(newTree(nnkTypeDef, exportedTypeIdent, newEmptyNode(), objectDef)) + + let returnType = quote: + Future[Result[`typeIdent`, string]] + + if not zeroArgSig.isNil(): + let procType = makeProcType(returnType, @[]) + typeSection.add(newTree(nnkTypeDef, zeroArgProviderName, newEmptyNode(), procType)) + if not argSig.isNil(): + let procType = makeProcType(returnType, cloneParams(argParams)) + typeSection.add(newTree(nnkTypeDef, argProviderName, newEmptyNode(), procType)) + + var brokerRecList = newTree(nnkRecList) + if not zeroArgSig.isNil(): + brokerRecList.add( + newTree(nnkIdentDefs, zeroArgFieldName, zeroArgProviderName, newEmptyNode()) + ) + if not argSig.isNil(): + brokerRecList.add( + newTree(nnkIdentDefs, argFieldName, argProviderName, newEmptyNode()) + ) + let brokerTypeIdent = ident(sanitizeIdentName(typeIdent) & "Broker") + let brokerTypeDef = newTree( + nnkTypeDef, + brokerTypeIdent, + newEmptyNode(), + newTree(nnkObjectTy, newEmptyNode(), newEmptyNode(), brokerRecList), + ) + typeSection.add(brokerTypeDef) + result = newStmtList() + result.add(typeSection) + + let globalVarIdent = ident("g" & sanitizeIdentName(typeIdent) & "Broker") + let accessProcIdent = ident("access" & sanitizeIdentName(typeIdent) & "Broker") + result.add( + quote do: + var `globalVarIdent` {.threadvar.}: `brokerTypeIdent` + + proc `accessProcIdent`(): var `brokerTypeIdent` = + `globalVarIdent` + + ) + + var clearBody = newStmtList() + if not zeroArgSig.isNil(): + result.add( + quote do: + proc setProvider*( + _: typedesc[`typeIdent`], handler: `zeroArgProviderName` + ): Result[void, string] = + if not `accessProcIdent`().`zeroArgFieldName`.isNil(): + return err("Zero-arg provider already set") + `accessProcIdent`().`zeroArgFieldName` = handler + return ok() + + ) + clearBody.add( + quote do: + `accessProcIdent`().`zeroArgFieldName` = nil + ) + result.add( + quote do: + proc request*( + _: typedesc[`typeIdent`] + ): Future[Result[`typeIdent`, string]] {.async: (raises: []).} = + let provider = `accessProcIdent`().`zeroArgFieldName` + if provider.isNil(): + return err( + "RequestBroker(" & `typeNameLit` & "): no zero-arg provider registered" + ) + let catchedRes = catch: + await provider() + + if catchedRes.isErr(): + return err("Request failed:" & catchedRes.error.msg) + + let providerRes = catchedRes.get() + when `isRefObjectLit`: + if providerRes.isOk(): + let resultValue = providerRes.get() + if resultValue.isNil(): + return err( + "RequestBroker(" & `typeNameLit` & "): provider returned nil result" + ) + return providerRes + + ) + if not argSig.isNil(): + result.add( + quote do: + proc setProvider*( + _: typedesc[`typeIdent`], handler: `argProviderName` + ): Result[void, string] = + if not `accessProcIdent`().`argFieldName`.isNil(): + return err("Provider already set") + `accessProcIdent`().`argFieldName` = handler + return ok() + + ) + clearBody.add( + quote do: + `accessProcIdent`().`argFieldName` = nil + ) + let requestParamDefs = cloneParams(argParams) + let argNameIdents = collectParamNames(requestParamDefs) + let providerSym = genSym(nskLet, "provider") + var formalParams = newTree(nnkFormalParams) + formalParams.add( + quote do: + Future[Result[`typeIdent`, string]] + ) + formalParams.add( + newTree( + nnkIdentDefs, + ident("_"), + newTree(nnkBracketExpr, ident("typedesc"), copyNimTree(typeIdent)), + newEmptyNode(), + ) + ) + for paramDef in requestParamDefs: + formalParams.add(paramDef) + + let requestPragmas = quote: + {.async: (raises: []), gcsafe.} + var providerCall = newCall(providerSym) + for argName in argNameIdents: + providerCall.add(argName) + var requestBody = newStmtList() + requestBody.add( + quote do: + let `providerSym` = `accessProcIdent`().`argFieldName` + ) + requestBody.add( + quote do: + if `providerSym`.isNil(): + return err( + "RequestBroker(" & `typeNameLit` & + "): no provider registered for input signature" + ) + ) + requestBody.add( + quote do: + let catchedRes = catch: + await `providerCall` + if catchedRes.isErr(): + return err("Request failed:" & catchedRes.error.msg) + + let providerRes = catchedRes.get() + when `isRefObjectLit`: + if providerRes.isOk(): + let resultValue = providerRes.get() + if resultValue.isNil(): + return err( + "RequestBroker(" & `typeNameLit` & "): provider returned nil result" + ) + return providerRes + ) + # requestBody.add(providerCall) + result.add( + newTree( + nnkProcDef, + postfix(ident("request"), "*"), + newEmptyNode(), + newEmptyNode(), + formalParams, + requestPragmas, + newEmptyNode(), + requestBody, + ) + ) + + result.add( + quote do: + proc clearProvider*(_: typedesc[`typeIdent`]) = + `clearBody` + + ) + + when defined(requestBrokerDebug): + echo result.repr