diff --git a/json_rpc/router.nim b/json_rpc/router.nim index a3dad24..81b1401 100644 --- a/json_rpc/router.nim +++ b/json_rpc/router.nim @@ -196,17 +196,15 @@ proc route*(router: RpcRouter, data: string): return reply -proc tryRoute*(router: RpcRouter, data: JsonString, +proc tryRoute*(router: RpcRouter, req: RequestRx, fut: var Future[JsonString]): Result[void, string] = ## Route to RPC, returns false if the method or params cannot be found. - ## Expects json input and returns json output. + ## Expects RequestRx input and returns json output. when defined(nimHasWarnBareExcept): {.push warning[BareExcept]:off.} {.push warning[UnreachableCode]:off.} try: - let req = JrpcSys.decode(data.string, RequestRx) - if req.jsonrpc.isNone: return err("`jsonrpc` missing or invalid") @@ -229,6 +227,16 @@ proc tryRoute*(router: RpcRouter, data: JsonString, {.pop warning[BareExcept]:on.} {.pop warning[UnreachableCode]:on.} +proc tryRoute*(router: RpcRouter, data: JsonString, + fut: var Future[JsonString]): Result[void, string] = + ## Route to RPC, returns false if the method or params cannot be found. + ## Expects json input and returns json output. + try: + let req = JrpcSys.decode(data.string, RequestRx) + return router.tryRoute(req, fut) + except CatchableError as ex: + return err(ex.msg) + macro rpc*(server: RpcRouter, path: static[string], body: untyped): untyped = ## Define a remote procedure call. ## Input and return parameters are defined using the ``do`` notation. diff --git a/json_rpc/servers/socketserver.nim b/json_rpc/servers/socketserver.nim index dfb9df9..ad12f19 100644 --- a/json_rpc/servers/socketserver.nim +++ b/json_rpc/servers/socketserver.nim @@ -18,6 +18,7 @@ export errors, server type RpcSocketServer* = ref object of RpcServer servers: seq[StreamServer] + processClientHook: StreamCallback2 proc processClient(server: StreamServer, transport: StreamTransport) {.async: (raises: []), gcsafe.} = ## Process transport data to the RPC server @@ -44,7 +45,7 @@ proc processClient(server: StreamServer, transport: StreamTransport) {.async: (r proc addStreamServer*(server: RpcSocketServer, address: TransportAddress) = try: info "Starting JSON-RPC socket server", address = $address - var transportServer = createStreamServer(address, processClient, {ReuseAddr}, udata = server) + var transportServer = createStreamServer(address, server.processClientHook, {ReuseAddr}, udata = server) server.servers.add(transportServer) except CatchableError as exc: error "Failed to create server", address = $address, message = exc.msg @@ -135,7 +136,7 @@ proc addStreamServer*(server: RpcSocketServer, address: string, port: Port) = "Could not setup server on " & address & ":" & $int(port)) proc new(T: type RpcSocketServer): T = - T(router: RpcRouter.init(), servers: @[]) + T(router: RpcRouter.init(), servers: @[], processClientHook: processClient) proc newRpcSocketServer*(): RpcSocketServer = RpcSocketServer.new() @@ -155,6 +156,11 @@ proc newRpcSocketServer*(address: string, port: Port = Port(8545)): RpcSocketSer result = RpcSocketServer.new() result.addStreamServer(address, port) +proc newRpcSocketServer*(processClientHook: StreamCallback2): RpcSocketServer = + ## Create new server with custom processClientHook. + result = RpcSocketServer.new() + result.processClientHook = processClientHook + proc start*(server: RpcSocketServer) = ## Start the RPC server. for item in server.servers: diff --git a/tests/testserverclient.nim b/tests/testserverclient.nim index 23562d1..052887b 100644 --- a/tests/testserverclient.nim +++ b/tests/testserverclient.nim @@ -114,3 +114,20 @@ suite "Websocket Server/Client RPC with Compression": srv.stop() waitFor srv.closeWait() +suite "Custom processClient": + test "Should be able to use custom processClient": + var wasCalled: bool = false + + proc processClientHook(server: StreamServer, transport: StreamTransport) {.async: (raises: []), gcsafe.} = + wasCalled = true + + var srv = newRpcSocketServer(processClientHook) + srv.addStreamServer("localhost", Port(8888)) + var client = newRpcSocketClient() + srv.setupServer() + srv.start() + waitFor client.connect(srv.localAddress()[0]) + asyncCheck client.call("", %[]) + srv.stop() + waitFor srv.closeWait() + check wasCalled