From d832f92a439eb541f689073b4c40caf02afbf0ad Mon Sep 17 00:00:00 2001 From: NagyZoltanPeter <113987313+NagyZoltanPeter@users.noreply.github.com> Date: Thu, 29 Feb 2024 09:48:14 +0100 Subject: [PATCH] chore: Implemented CORS handling for nwaku REST server (#2470) * Add allowOrigin configuration for wakunode and WakuRestServer Update nim-presto to the latest master that contains middleware support Rework Rest Server in waku to utilize chronos' and presto's new middleware design and added proper CORS handling. Added cors tests and fixes Co-authored-by: Ivan FB <128452529+Ivansete-status@users.noreply.github.com> --- apps/wakunode2/app.nim | 63 ++-- apps/wakunode2/external_config.nim | 8 + tests/all_tests_waku.nim | 3 +- tests/wakunode_rest/test_all.nim | 3 +- tests/wakunode_rest/test_rest_admin.nim | 4 +- tests/wakunode_rest/test_rest_cors.nim | 269 ++++++++++++++++++ tests/wakunode_rest/test_rest_debug.nim | 4 +- tests/wakunode_rest/test_rest_filter.nim | 8 +- tests/wakunode_rest/test_rest_health.nim | 2 +- .../wakunode_rest/test_rest_legacy_filter.nim | 4 +- tests/wakunode_rest/test_rest_lightpush.nim | 4 +- tests/wakunode_rest/test_rest_relay.nim | 58 ++-- tests/wakunode_rest/test_rest_store.nim | 20 +- vendor/nim-presto | 2 +- waku/waku_api/rest/origin_handler.nim | 125 ++++++++ waku/waku_api/rest/server.nim | 143 ++++++++-- 16 files changed, 623 insertions(+), 97 deletions(-) create mode 100644 tests/wakunode_rest/test_rest_cors.nim create mode 100644 waku/waku_api/rest/origin_handler.nim diff --git a/apps/wakunode2/app.nim b/apps/wakunode2/app.nim index b064c2453..10a91abf3 100644 --- a/apps/wakunode2/app.nim +++ b/apps/wakunode2/app.nim @@ -84,7 +84,7 @@ type node: WakuNode rpcServer: Option[RpcHttpServer] - restServer: Option[RestServerRef] + restServer: Option[WakuRestServerRef] metricsServer: Option[MetricsHttpServerRef] AppResult*[T] = Result[T, string] @@ -667,34 +667,55 @@ proc startApp*(app: var App): AppResult[void] = ## Monitoring and external interfaces -proc startRestServer(app: App, address: IpAddress, port: Port, conf: WakuNodeConf): AppResult[RestServerRef] = +proc startRestServer(app: App, + address: IpAddress, + port: Port, + conf: WakuNodeConf): + AppResult[WakuRestServerRef] = # Used to register api endpoints that are not currently installed as keys, # values are holding error messages to be returned to the client var notInstalledTab: Table[string, string] = initTable[string, string]() - proc requestErrorHandler(error: RestRequestError, - request: HttpRequestRef): - Future[HttpResponseRef] {.async.} = - case error - of RestRequestError.Invalid: - return await request.respond(Http400, "Invalid request", HttpTable.init()) - of RestRequestError.NotFound: - let rootPath = request.rawPath.split("/")[1] - if notInstalledTab.hasKey(rootPath): - return await request.respond(Http404, notInstalledTab[rootPath], HttpTable.init()) - else: - return await request.respond(Http400, "Bad request initiated. Invalid path or method used.", HttpTable.init()) - of RestRequestError.InvalidContentBody: - return await request.respond(Http400, "Invalid content body", HttpTable.init()) - of RestRequestError.InvalidContentType: - return await request.respond(Http400, "Invalid content type", HttpTable.init()) - of RestRequestError.Unexpected: - return defaultResponse() + let requestErrorHandler : RestRequestErrorHandler = proc (error: RestRequestError, + request: HttpRequestRef): + Future[HttpResponseRef] + {.async: (raises: [CancelledError]).} = + try: + case error + of RestRequestError.Invalid: + return await request.respond(Http400, "Invalid request", HttpTable.init()) + of RestRequestError.NotFound: + let paths = request.rawPath.split("/") + let rootPath = if len(paths) > 1: + paths[1] + else: + "" + notInstalledTab.withValue(rootPath, errMsg): + return await request.respond(Http404, errMsg[], HttpTable.init()) + do: + return await request.respond(Http400, "Bad request initiated. Invalid path or method used.", HttpTable.init()) + of RestRequestError.InvalidContentBody: + return await request.respond(Http400, "Invalid content body", HttpTable.init()) + of RestRequestError.InvalidContentType: + return await request.respond(Http400, "Invalid content type", HttpTable.init()) + of RestRequestError.Unexpected: + return defaultResponse() + except HttpWriteError: + error "Failed to write response to client", error = getCurrentExceptionMsg() + discard return defaultResponse() - let server = ? newRestHttpServer(address, port, requestErrorHandler = requestErrorHandler) + let allowedOrigin = if len(conf.restAllowOrigin) > 0 : + some(conf.restAllowOrigin.join(",")) + else: + none(string) + + let server = ? newRestHttpServer(address, port, + allowedOrigin = allowedOrigin, + requestErrorHandler = requestErrorHandler) + ## Admin REST API if conf.restAdmin: installAdminApiHandlers(server.router, app.node) diff --git a/apps/wakunode2/external_config.nim b/apps/wakunode2/external_config.nim index 88739de04..e9c1d1e9f 100644 --- a/apps/wakunode2/external_config.nim +++ b/apps/wakunode2/external_config.nim @@ -419,6 +419,14 @@ type defaultValue: false name: "rest-private" }: bool + restAllowOrigin* {. + desc: "Allow cross-origin requests from the specified origin." & + "Argument may be repeated." & + "Wildcards: * or ? allowed." & + "Ex.: \"localhost:*\" or \"127.0.0.1:8080\"", + defaultValue: newSeq[string]() + name: "rest-allow-origin" }: seq[string] + ## Metrics config metricsServer* {. diff --git a/tests/all_tests_waku.nim b/tests/all_tests_waku.nim index 303021404..17d76a50c 100644 --- a/tests/all_tests_waku.nim +++ b/tests/all_tests_waku.nim @@ -85,7 +85,8 @@ import ./wakunode_rest/test_rest_filter, ./wakunode_rest/test_rest_legacy_filter, ./wakunode_rest/test_rest_lightpush, - ./wakunode_rest/test_rest_admin + ./wakunode_rest/test_rest_admin, + ./wakunode_rest/test_rest_cors import ./waku_rln_relay/test_waku_rln_relay, diff --git a/tests/wakunode_rest/test_all.nim b/tests/wakunode_rest/test_all.nim index 620ae8a70..9829a78f2 100644 --- a/tests/wakunode_rest/test_all.nim +++ b/tests/wakunode_rest/test_all.nim @@ -10,4 +10,5 @@ import ./test_rest_relay, ./test_rest_serdes, ./test_rest_store, - ./test_rest_admin + ./test_rest_admin, + ./test_rest_cors diff --git a/tests/wakunode_rest/test_rest_admin.nim b/tests/wakunode_rest/test_rest_admin.nim index 932dcd743..628359ab3 100644 --- a/tests/wakunode_rest/test_rest_admin.nim +++ b/tests/wakunode_rest/test_rest_admin.nim @@ -30,7 +30,7 @@ suite "Waku v2 Rest API - Admin": var peerInfo1 {.threadvar.}: RemotePeerInfo var peerInfo2 {.threadvar.}: RemotePeerInfo var peerInfo3 {.threadvar.}: RemotePeerInfo - var restServer {.threadvar.}: RestServerRef + var restServer {.threadvar.}: WakuRestServerRef var client{.threadvar.}: RestClientRef asyncSetup: @@ -46,7 +46,7 @@ suite "Waku v2 Rest API - Admin": let restPort = Port(58011) let restAddress = parseIpAddress("127.0.0.1") - restServer = RestServerRef.init(restAddress, restPort).tryGet() + restServer = WakuRestServerRef.init(restAddress, restPort).tryGet() installAdminApiHandlers(restServer.router, node1) diff --git a/tests/wakunode_rest/test_rest_cors.nim b/tests/wakunode_rest/test_rest_cors.nim new file mode 100644 index 000000000..3ac166f07 --- /dev/null +++ b/tests/wakunode_rest/test_rest_cors.nim @@ -0,0 +1,269 @@ +{.used.} + +import + stew/shims/net, + testutils/unittests, + presto, + presto/client as presto_client, + libp2p/peerinfo, + libp2p/multiaddress, + libp2p/crypto/crypto +import + ../../waku/waku_node, + ../../waku/node/waku_node as waku_node2, + ../../waku/waku_api/rest/server, + ../../waku/waku_api/rest/client, + ../../waku/waku_api/rest/responses, + ../../waku/waku_api/rest/debug/handlers as debug_api, + ../../waku/waku_api/rest/debug/client as debug_api_client, + ../testlib/common, + ../testlib/wakucore, + ../testlib/wakunode + + +type + TestResponseTuple = tuple[status: int, data: string, headers: HttpTable] + +proc testWakuNode(): WakuNode = + let + privkey = crypto.PrivateKey.random(Secp256k1, rng[]).tryGet() + bindIp = parseIpAddress("0.0.0.0") + extIp = parseIpAddress("127.0.0.1") + port = Port(58000) + + newTestWakuNode(privkey, bindIp, port, some(extIp), some(port)) + +proc fetchWithHeader(request: HttpClientRequestRef): Future[TestResponseTuple] + {.async: (raises: [CancelledError, HttpError]).} = + var response: HttpClientResponseRef + try: + response = await request.send() + let buffer = await response.getBodyBytes() + let status = response.status + let headers = response.headers + await response.closeWait() + response = nil + return (status, buffer.bytesToString(), headers) + except HttpError as exc: + if not(isNil(response)): await response.closeWait() + assert false + except CancelledError as exc: + if not(isNil(response)): await response.closeWait() + assert false + +proc issueRequest( + address: HttpAddress, + reqOrigin: Option[string] = none(string) + ): Future[TestResponseTuple] {.async.} = + + var + session = HttpSessionRef.new({HttpClientFlag.Http11Pipeline}) + data: TestResponseTuple + + var originHeader : seq[HttpHeaderTuple] + if reqOrigin.isSome(): + originHeader.insert(("Origin", reqOrigin.get())) + + var + request = HttpClientRequestRef.new(session, + address, + version = HttpVersion11, + headers = originHeader) + try: + data = await request.fetchWithHeader() + finally: + await request.closeWait() + return data + +proc checkResponse(response: TestResponseTuple, + expectedStatus : int, + expectedOrigin : Option[string]): bool = + if response.status != expectedStatus: + echo(" -> check failed: expected status" & $expectedStatus & + " got " & $response.status) + return false + + if not (expectedOrigin.isNone() or + (expectedOrigin.isSome() and + response.headers.contains("Access-Control-Allow-Origin") and + response.headers.getLastString("Access-Control-Allow-Origin") == expectedOrigin.get())): + echo(" -> check failed: expected origin " & $expectedOrigin & " got " & + response.headers.getLastString("Access-Control-Allow-Origin")) + return false + + return true + +suite "Waku v2 REST API CORS Handling": + asyncTest "AllowedOrigin matches": + # Given + let node = testWakuNode() + await node.start() + await node.mountRelay() + + let restPort = Port(58001) + let restAddress = parseIpAddress("0.0.0.0") + let restServer = WakuRestServerRef.init(restAddress, + restPort, + allowedOrigin=some("test.net:1234,https://localhost:*,http://127.0.0.1:?8,?waku*.net:*80*") + ).tryGet() + + installDebugApiHandlers(restServer.router, node) + restServer.start() + + let srvAddr = restServer.localAddress() + let ha = getAddress(srvAddr, HttpClientScheme.NonSecure, "/debug/v1/info") + + # When + var response = await issueRequest(ha, some("http://test.net:1234")) + check checkResponse(response, 200, some("http://test.net:1234")) + + response = await issueRequest(ha, some("https://test.net:1234")) + check checkResponse(response, 200, some("https://test.net:1234")) + + response = await issueRequest(ha, some("https://localhost:8080")) + check checkResponse(response, 200, some("https://localhost:8080")) + + response = await issueRequest(ha, some("https://localhost:80")) + check checkResponse(response, 200, some("https://localhost:80")) + + response = await issueRequest(ha, some("http://127.0.0.1:78")) + check checkResponse(response, 200, some("http://127.0.0.1:78")) + + response = await issueRequest(ha, some("http://wakuTHE.net:8078")) + check checkResponse(response, 200, some("http://wakuTHE.net:8078")) + + response = await issueRequest(ha, some("http://nwaku.main.net:1980")) + check checkResponse(response, 200, some("http://nwaku.main.net:1980")) + + response = await issueRequest(ha, some("http://nwaku.main.net:80")) + check checkResponse(response, 200, some("http://nwaku.main.net:80")) + + await restServer.stop() + await restServer.closeWait() + await node.stop() + + asyncTest "AllowedOrigin reject": + # Given + let node = testWakuNode() + await node.start() + await node.mountRelay() + + let restPort = Port(58001) + let restAddress = parseIpAddress("0.0.0.0") + let restServer = WakuRestServerRef.init(restAddress, + restPort, + allowedOrigin=some("test.net:1234,https://localhost:*,http://127.0.0.1:?8,?waku*.net:*80*") + ).tryGet() + + installDebugApiHandlers(restServer.router, node) + restServer.start() + + let srvAddr = restServer.localAddress() + let ha = getAddress(srvAddr, HttpClientScheme.NonSecure, "/debug/v1/info") + + # When + var response = await issueRequest(ha, some("http://test.net:12334")) + check checkResponse(response, 403, none(string)) + + response = await issueRequest(ha, some("http://test.net:12345")) + check checkResponse(response, 403, none(string)) + + response = await issueRequest(ha, some("xhttp://test.net:1234")) + check checkResponse(response, 403, none(string)) + + response = await issueRequest(ha, some("https://xtest.net:1234")) + check checkResponse(response, 403, none(string)) + + response = await issueRequest(ha, some("http://localhost:8080")) + check checkResponse(response, 403, none(string)) + + response = await issueRequest(ha, some("https://127.0.0.1:78")) + check checkResponse(response, 403, none(string)) + + response = await issueRequest(ha, some("http://127.0.0.1:89")) + check checkResponse(response, 403, none(string)) + + response = await issueRequest(ha, some("http://the.waku.net:8078")) + check checkResponse(response, 403, none(string)) + + response = await issueRequest(ha, some("http://nwaku.main.net:1900")) + check checkResponse(response, 403, none(string)) + + await restServer.stop() + await restServer.closeWait() + await node.stop() + + asyncTest "AllowedOrigin allmatches": + # Given + let node = testWakuNode() + await node.start() + await node.mountRelay() + + let restPort = Port(58001) + let restAddress = parseIpAddress("0.0.0.0") + let restServer = WakuRestServerRef.init(restAddress, + restPort, + allowedOrigin=some("*") + ).tryGet() + + installDebugApiHandlers(restServer.router, node) + restServer.start() + + let srvAddr = restServer.localAddress() + let ha = getAddress(srvAddr, HttpClientScheme.NonSecure, "/debug/v1/info") + + # When + var response = await issueRequest(ha, some("http://test.net:1234")) + check checkResponse(response, 200, some("*")) + + response = await issueRequest(ha, some("https://test.net:1234")) + check checkResponse(response, 200, some("*")) + + response = await issueRequest(ha, some("https://localhost:8080")) + check checkResponse(response, 200, some("*")) + + response = await issueRequest(ha, some("https://localhost:80")) + check checkResponse(response, 200, some("*")) + + response = await issueRequest(ha, some("http://127.0.0.1:78")) + check checkResponse(response, 200, some("*")) + + response = await issueRequest(ha, some("http://wakuTHE.net:8078")) + check checkResponse(response, 200, some("*")) + + response = await issueRequest(ha, some("http://nwaku.main.net:1980")) + check checkResponse(response, 200, some("*")) + + response = await issueRequest(ha, some("http://nwaku.main.net:80")) + check checkResponse(response, 200, some("*")) + + await restServer.stop() + await restServer.closeWait() + await node.stop() + + asyncTest "No origin goes through": + # Given + let node = testWakuNode() + await node.start() + await node.mountRelay() + + let restPort = Port(58001) + let restAddress = parseIpAddress("0.0.0.0") + let restServer = WakuRestServerRef.init(restAddress, + restPort, + allowedOrigin=some("test.net:1234,https://localhost:*,http://127.0.0.1:?8,?waku*.net:*80*") + ).tryGet() + + installDebugApiHandlers(restServer.router, node) + restServer.start() + + let srvAddr = restServer.localAddress() + let ha = getAddress(srvAddr, HttpClientScheme.NonSecure, "/debug/v1/info") + + # When + var response = await issueRequest(ha, none(string)) + check checkResponse(response, 200, none(string)) + + await restServer.stop() + await restServer.closeWait() + await node.stop() diff --git a/tests/wakunode_rest/test_rest_debug.nim b/tests/wakunode_rest/test_rest_debug.nim index ab458ad0c..4f1c4ac6a 100644 --- a/tests/wakunode_rest/test_rest_debug.nim +++ b/tests/wakunode_rest/test_rest_debug.nim @@ -40,7 +40,7 @@ suite "Waku v2 REST API - Debug": let restPort = Port(58001) let restAddress = parseIpAddress("0.0.0.0") - let restServer = RestServerRef.init(restAddress, restPort).tryGet() + let restServer = WakuRestServerRef.init(restAddress, restPort).tryGet() installDebugApiHandlers(restServer.router, node) restServer.start() @@ -67,7 +67,7 @@ suite "Waku v2 REST API - Debug": let restPort = Port(58002) let restAddress = parseIpAddress("0.0.0.0") - let restServer = RestServerRef.init(restAddress, restPort).tryGet() + let restServer = WakuRestServerRef.init(restAddress, restPort).tryGet() installDebugApiHandlers(restServer.router, node) restServer.start() diff --git a/tests/wakunode_rest/test_rest_filter.nim b/tests/wakunode_rest/test_rest_filter.nim index b86bb8925..10a2d5b8f 100644 --- a/tests/wakunode_rest/test_rest_filter.nim +++ b/tests/wakunode_rest/test_rest_filter.nim @@ -39,8 +39,8 @@ proc testWakuNode(): WakuNode = type RestFilterTest = object serviceNode: WakuNode subscriberNode: WakuNode - restServer: RestServerRef - restServerForService: RestServerRef + restServer: WakuRestServerRef + restServerForService: WakuRestServerRef messageCache: MessageCache client: RestClientRef clientTwdServiceNode: RestClientRef @@ -61,10 +61,10 @@ proc init(T: type RestFilterTest): Future[T] {.async.} = let restPort = Port(58011) let restAddress = parseIpAddress("127.0.0.1") - testSetup.restServer = RestServerRef.init(restAddress, restPort).tryGet() + testSetup.restServer = WakuRestServerRef.init(restAddress, restPort).tryGet() let restPort2 = Port(58012) - testSetup.restServerForService = RestServerRef.init(restAddress, restPort2).tryGet() + testSetup.restServerForService = WakuRestServerRef.init(restAddress, restPort2).tryGet() # through this one we will see if messages are pushed according to our content topic sub testSetup.messageCache = MessageCache.init() diff --git a/tests/wakunode_rest/test_rest_health.nim b/tests/wakunode_rest/test_rest_health.nim index 9560384a0..5937b8dba 100644 --- a/tests/wakunode_rest/test_rest_health.nim +++ b/tests/wakunode_rest/test_rest_health.nim @@ -44,7 +44,7 @@ suite "Waku v2 REST API - health": let restPort = Port(58001) let restAddress = parseIpAddress("0.0.0.0") - let restServer = RestServerRef.init(restAddress, restPort).tryGet() + let restServer = WakuRestServerRef.init(restAddress, restPort).tryGet() installHealthApiHandler(restServer.router, node) restServer.start() diff --git a/tests/wakunode_rest/test_rest_legacy_filter.nim b/tests/wakunode_rest/test_rest_legacy_filter.nim index fe82d8117..f7de83372 100644 --- a/tests/wakunode_rest/test_rest_legacy_filter.nim +++ b/tests/wakunode_rest/test_rest_legacy_filter.nim @@ -38,7 +38,7 @@ proc testWakuNode(): WakuNode = type RestFilterTest = object filterNode: WakuNode clientNode: WakuNode - restServer: RestServerRef + restServer: WakuRestServerRef messageCache: MessageCache client: RestClientRef @@ -58,7 +58,7 @@ proc setupRestFilter(): Future[RestFilterTest] {.async.} = let restPort = Port(58011) let restAddress = parseIpAddress("0.0.0.0") - result.restServer = RestServerRef.init(restAddress, restPort).tryGet() + result.restServer = WakuRestServerRef.init(restAddress, restPort).tryGet() result.messageCache = MessageCache.init() installLegacyFilterRestApiHandlers(result.restServer.router diff --git a/tests/wakunode_rest/test_rest_lightpush.nim b/tests/wakunode_rest/test_rest_lightpush.nim index 876be7c64..34dc40dd2 100644 --- a/tests/wakunode_rest/test_rest_lightpush.nim +++ b/tests/wakunode_rest/test_rest_lightpush.nim @@ -39,7 +39,7 @@ type RestLightPushTest = object serviceNode: WakuNode pushNode: WakuNode consumerNode: WakuNode - restServer: RestServerRef + restServer: WakuRestServerRef client: RestClientRef @@ -71,7 +71,7 @@ proc init(T: type RestLightPushTest): Future[T] {.async.} = let restPort = Port(58011) let restAddress = parseIpAddress("127.0.0.1") - testSetup.restServer = RestServerRef.init(restAddress, restPort).tryGet() + testSetup.restServer = WakuRestServerRef.init(restAddress, restPort).tryGet() installLightPushRequestHandler(testSetup.restServer.router, testSetup.pushNode) diff --git a/tests/wakunode_rest/test_rest_relay.nim b/tests/wakunode_rest/test_rest_relay.nim index 0e0f40424..9ec2485af 100644 --- a/tests/wakunode_rest/test_rest_relay.nim +++ b/tests/wakunode_rest/test_rest_relay.nim @@ -43,9 +43,9 @@ suite "Waku v2 Rest API - Relay": var restPort = Port(0) let restAddress = parseIpAddress("0.0.0.0") - let restServer = RestServerRef.init(restAddress, restPort).tryGet() + let restServer = WakuRestServerRef.init(restAddress, restPort).tryGet() - restPort = restServer.server.address.port # update with bound port for client use + restPort = restServer.httpServer.address.port # update with bound port for client use let cache = MessageCache.init() @@ -93,9 +93,9 @@ suite "Waku v2 Rest API - Relay": var restPort = Port(0) let restAddress = parseIpAddress("0.0.0.0") - let restServer = RestServerRef.init(restAddress, restPort).tryGet() + let restServer = WakuRestServerRef.init(restAddress, restPort).tryGet() - restPort = restServer.server.address.port # update with bound port for client use + restPort = restServer.httpServer.address.port # update with bound port for client use let cache = MessageCache.init() cache.pubsubSubscribe("pubsub-topic-1") @@ -147,12 +147,12 @@ suite "Waku v2 Rest API - Relay": var restPort = Port(0) let restAddress = parseIpAddress("0.0.0.0") - let restServer = RestServerRef.init(restAddress, restPort).tryGet() + let restServer = WakuRestServerRef.init(restAddress, restPort).tryGet() - restPort = restServer.server.address.port # update with bound port for client use + restPort = restServer.httpServer.address.port # update with bound port for client use let pubSubTopic = "/waku/2/default-waku/proto" - + var messages = @[ fakeWakuMessage(contentTopic = "content-topic-x", payload = toBytes("TEST-1"), meta = toBytes("test-meta") ) @@ -168,7 +168,7 @@ suite "Waku v2 Rest API - Relay": meta = toBytes("test-meta")) messages.add(msg) - + let cache = MessageCache.init() cache.pubsubSubscribe(pubSubTopic) @@ -216,9 +216,9 @@ suite "Waku v2 Rest API - Relay": # RPC server setup var restPort = Port(0) let restAddress = parseIpAddress("0.0.0.0") - let restServer = RestServerRef.init(restAddress, restPort).tryGet() + let restServer = WakuRestServerRef.init(restAddress, restPort).tryGet() - restPort = restServer.server.address.port # update with bound port for client use + restPort = restServer.httpServer.address.port # update with bound port for client use let cache = MessageCache.init() @@ -258,9 +258,9 @@ suite "Waku v2 Rest API - Relay": var restPort = Port(0) let restAddress = parseIpAddress("0.0.0.0") - let restServer = RestServerRef.init(restAddress, restPort).tryGet() + let restServer = WakuRestServerRef.init(restAddress, restPort).tryGet() - restPort = restServer.server.address.port # update with bound port for client use + restPort = restServer.httpServer.address.port # update with bound port for client use let cache = MessageCache.init() @@ -306,9 +306,9 @@ suite "Waku v2 Rest API - Relay": var restPort = Port(0) let restAddress = parseIpAddress("0.0.0.0") - let restServer = RestServerRef.init(restAddress, restPort).tryGet() + let restServer = WakuRestServerRef.init(restAddress, restPort).tryGet() - restPort = restServer.server.address.port # update with bound port for client use + restPort = restServer.httpServer.address.port # update with bound port for client use let contentTopics = @[ ContentTopic("/waku/2/default-content1/proto"), @@ -354,9 +354,9 @@ suite "Waku v2 Rest API - Relay": var restPort = Port(0) let restAddress = parseIpAddress("0.0.0.0") - let restServer = RestServerRef.init(restAddress, restPort).tryGet() + let restServer = WakuRestServerRef.init(restAddress, restPort).tryGet() - restPort = restServer.server.address.port # update with bound port for client use + restPort = restServer.httpServer.address.port # update with bound port for client use let contentTopic = DefaultContentTopic @@ -370,9 +370,9 @@ suite "Waku v2 Rest API - Relay": while msg == messages[i]: msg = fakeWakuMessage(contentTopic = DefaultContentTopic, payload = toBytes("TEST-1")) - + messages.add(msg) - + let cache = MessageCache.init() cache.contentSubscribe(contentTopic) @@ -419,9 +419,9 @@ suite "Waku v2 Rest API - Relay": # RPC server setup var restPort = Port(0) let restAddress = parseIpAddress("0.0.0.0") - let restServer = RestServerRef.init(restAddress, restPort).tryGet() + let restServer = WakuRestServerRef.init(restAddress, restPort).tryGet() - restPort = restServer.server.address.port # update with bound port for client use + restPort = restServer.httpServer.address.port # update with bound port for client use let cache = MessageCache.init() installRelayApiHandlers(restServer.router, node, cache) @@ -464,9 +464,9 @@ suite "Waku v2 Rest API - Relay": # RPC server setup var restPort = Port(0) let restAddress = parseIpAddress("0.0.0.0") - let restServer = RestServerRef.init(restAddress, restPort).tryGet() + let restServer = WakuRestServerRef.init(restAddress, restPort).tryGet() - restPort = restServer.server.address.port # update with bound port for client use + restPort = restServer.httpServer.address.port # update with bound port for client use let cache = MessageCache.init() installRelayApiHandlers(restServer.router, node, cache) @@ -504,9 +504,9 @@ suite "Waku v2 Rest API - Relay": # RPC server setup var restPort = Port(0) let restAddress = parseIpAddress("0.0.0.0") - let restServer = RestServerRef.init(restAddress, restPort).tryGet() + let restServer = WakuRestServerRef.init(restAddress, restPort).tryGet() - restPort = restServer.server.address.port # update with bound port for client use + restPort = restServer.httpServer.address.port # update with bound port for client use let cache = MessageCache.init() @@ -525,7 +525,7 @@ suite "Waku v2 Rest API - Relay": contentTopic: some(DefaultContentTopic), timestamp: some(int64(2022)) )) - + # Then check: response.status == 400 @@ -549,9 +549,9 @@ suite "Waku v2 Rest API - Relay": # RPC server setup var restPort = Port(0) let restAddress = parseIpAddress("0.0.0.0") - let restServer = RestServerRef.init(restAddress, restPort).tryGet() + let restServer = WakuRestServerRef.init(restAddress, restPort).tryGet() - restPort = restServer.server.address.port # update with bound port for client use + restPort = restServer.httpServer.address.port # update with bound port for client use let cache = MessageCache.init() @@ -570,7 +570,7 @@ suite "Waku v2 Rest API - Relay": contentTopic: some(DefaultContentTopic), timestamp: some(int64(2022)) )) - + # Then check: response.status == 400 @@ -579,4 +579,4 @@ suite "Waku v2 Rest API - Relay": await restServer.stop() await restServer.closeWait() - await node.stop() \ No newline at end of file + await node.stop() diff --git a/tests/wakunode_rest/test_rest_store.nim b/tests/wakunode_rest/test_rest_store.nim index 6b617970c..908eec571 100644 --- a/tests/wakunode_rest/test_rest_store.nim +++ b/tests/wakunode_rest/test_rest_store.nim @@ -83,7 +83,7 @@ procSuite "Waku v2 Rest API - Store": let restPort = Port(58011) let restAddress = parseIpAddress("0.0.0.0") - let restServer = RestServerRef.init(restAddress, restPort).tryGet() + let restServer = WakuRestServerRef.init(restAddress, restPort).tryGet() installStoreApiHandlers(restServer.router, node) restServer.start() @@ -153,7 +153,7 @@ procSuite "Waku v2 Rest API - Store": let restPort = Port(58012) let restAddress = parseIpAddress("0.0.0.0") - let restServer = RestServerRef.init(restAddress, restPort).tryGet() + let restServer = WakuRestServerRef.init(restAddress, restPort).tryGet() installStoreApiHandlers(restServer.router, node) restServer.start() @@ -251,7 +251,7 @@ procSuite "Waku v2 Rest API - Store": let restPort = Port(58013) let restAddress = parseIpAddress("0.0.0.0") - let restServer = RestServerRef.init(restAddress, restPort).tryGet() + let restServer = WakuRestServerRef.init(restAddress, restPort).tryGet() installStoreApiHandlers(restServer.router, node) restServer.start() @@ -325,7 +325,7 @@ procSuite "Waku v2 Rest API - Store": let restPort = Port(58014) let restAddress = parseIpAddress("0.0.0.0") - let restServer = RestServerRef.init(restAddress, restPort).tryGet() + let restServer = WakuRestServerRef.init(restAddress, restPort).tryGet() installStoreApiHandlers(restServer.router, node) restServer.start() @@ -416,7 +416,7 @@ procSuite "Waku v2 Rest API - Store": let restPort = Port(58015) let restAddress = parseIpAddress("0.0.0.0") - let restServer = RestServerRef.init(restAddress, restPort).tryGet() + let restServer = WakuRestServerRef.init(restAddress, restPort).tryGet() installStoreApiHandlers(restServer.router, node) restServer.start() @@ -473,7 +473,7 @@ procSuite "Waku v2 Rest API - Store": let restPort = Port(58016) let restAddress = parseIpAddress("0.0.0.0") - let restServer = RestServerRef.init(restAddress, restPort).tryGet() + let restServer = WakuRestServerRef.init(restAddress, restPort).tryGet() installStoreApiHandlers(restServer.router, node) restServer.start() @@ -482,7 +482,7 @@ procSuite "Waku v2 Rest API - Store": let driver: ArchiveDriver = QueueDriver.new() let mountArchiveRes = node.mountArchive(driver) assert mountArchiveRes.isOk(), mountArchiveRes.error - + await node.mountStore() node.mountStoreClient() @@ -547,7 +547,7 @@ procSuite "Waku v2 Rest API - Store": let restPort = Port(58014) let restAddress = parseIpAddress("0.0.0.0") - let restServer = RestServerRef.init(restAddress, restPort).tryGet() + let restServer = WakuRestServerRef.init(restAddress, restPort).tryGet() installStoreApiHandlers(restServer.router, node) restServer.start() @@ -609,9 +609,9 @@ procSuite "Waku v2 Rest API - Store": var restPort = Port(0) let restAddress = parseIpAddress("0.0.0.0") - let restServer = RestServerRef.init(restAddress, restPort).tryGet() + let restServer = WakuRestServerRef.init(restAddress, restPort).tryGet() - restPort = restServer.server.address.port # update with bound port for client use + restPort = restServer.httpServer.address.port # update with bound port for client use installStoreApiHandlers(restServer.router, node) restServer.start() diff --git a/vendor/nim-presto b/vendor/nim-presto index 5ca16485e..223aadeb8 160000 --- a/vendor/nim-presto +++ b/vendor/nim-presto @@ -1 +1 @@ -Subproject commit 5ca16485e4d74e531d50d289ebc0f869d9e6352b +Subproject commit 223aadeb82d35b57e6dae99f0b325ec6345ce7ff diff --git a/waku/waku_api/rest/origin_handler.nim b/waku/waku_api/rest/origin_handler.nim new file mode 100644 index 000000000..fa3453a06 --- /dev/null +++ b/waku/waku_api/rest/origin_handler.nim @@ -0,0 +1,125 @@ +when (NimMajor, NimMinor) < (1, 4): + {.push raises: [Defect].} +else: + {.push raises: [].} + +import + std/[options, strutils, re], + stew/results, + stew/shims/net, + chronicles, + chronos, + chronos/apps/http/httpserver + +type + OriginHandlerMiddlewareRef* = ref object of HttpServerMiddlewareRef + allowedOriginMatcher: Option[Regex] + everyOriginAllowed: bool + + +proc isEveryOriginAllowed(maybeAllowedOrigin: Option[string]): bool = + return maybeAllowedOrigin.isSome() and maybeAllowedOrigin.get() == "*" + +proc compileOriginMatcher(maybeAllowedOrigin: Option[string]): Option[Regex] = + if maybeAllowedOrigin.isNone(): + return none(Regex) + + let allowedOrigin = maybeAllowedOrigin.get() + + if (len(allowedOrigin) == 0): + return none(Regex) + + try: + var matchOrigin : string + + if allowedOrigin == "*": + matchOrigin = r".*" + return some(re(matchOrigin, {reIgnoreCase, reExtended})) + + let allowedOrigins = allowedOrigin.split(",") + + var matchExpressions : seq[string] = @[] + + var prefix : string + for allowedOrigin in allowedOrigins: + if allowedOrigin.startsWith("http://"): + prefix = r"http:\/\/" + matchOrigin = allowedOrigin.substr(7) + elif allowedOrigin.startsWith("https://"): + prefix = r"https:\/\/" + matchOrigin = allowedOrigin.substr(8) + else: + prefix = r"https?:\/\/" + matchOrigin = allowedOrigin + + matchOrigin = matchOrigin.replace(".", r"\.") + matchOrigin = matchOrigin.replace("*", ".*") + matchOrigin = matchOrigin.replace("?", ".?") + + matchExpressions.add("^" & prefix & matchOrigin & "$") + + let finalExpression = matchExpressions.join("|") + + return some(re(finalExpression, {reIgnoreCase, reExtended})) + except RegexError: + var msg = getCurrentExceptionMsg() + error "Failed to compile regex", source=allowedOrigin, err=msg + return none(Regex) + +proc originsMatch(originHandler: OriginHandlerMiddlewareRef, + requestOrigin: string): bool = + + if originHandler.allowedOriginMatcher.isNone(): + return false + + return requestOrigin.match(originHandler.allowedOriginMatcher.get()) + +proc originMiddlewareProc( + middleware: HttpServerMiddlewareRef, + reqfence: RequestFence, + nextHandler: HttpProcessCallback2 + ): Future[HttpResponseRef] {.async: (raises: [CancelledError]).} = + if reqfence.isErr(): + # Ignore request errors that detected before our middleware. + # Let final handler deal with it. + return await nextHandler(reqfence) + + let self = OriginHandlerMiddlewareRef(middleware) + let request = reqfence.get() + var reqHeaders = request.headers + var response = request.getResponse() + + if self.allowedOriginMatcher.isSome(): + let origin = reqHeaders.getList("Origin") + try: + if origin.len == 1: + if self.everyOriginAllowed: + response.addHeader("Access-Control-Allow-Origin", "*") + elif self.originsMatch(origin[0]): + # The Vary: Origin header to must be set to prevent + # potential cache poisoning attacks: + # https://textslashplain.com/2018/08/02/cors-and-vary/ + response.addHeader("Vary", "Origin") + response.addHeader("Access-Control-Allow-Origin", origin[0]) + else: + return await request.respond(Http403, "Origin not allowed") + elif origin.len == 0: + discard + elif origin.len > 1: + return await request.respond(Http400, "Only a single Origin header must be specified") + except HttpWriteError as exc: + # We use default error handler if we unable to send response. + return defaultResponse(exc) + + # Calling next handler. + return await nextHandler(reqfence) + +proc new*(t: typedesc[OriginHandlerMiddlewareRef], + allowedOrigin: Option[string] = none(string) + ): HttpServerMiddlewareRef = + + let middleware = + OriginHandlerMiddlewareRef(allowedOriginMatcher: compileOriginMatcher(allowedOrigin), + everyOriginAllowed: isEveryOriginAllowed(allowedOrigin), + handler: originMiddlewareProc) + return HttpServerMiddlewareRef(middleware) diff --git a/waku/waku_api/rest/server.nim b/waku/waku_api/rest/server.nim index ac73beb54..95a76d848 100644 --- a/waku/waku_api/rest/server.nim +++ b/waku/waku_api/rest/server.nim @@ -8,11 +8,23 @@ import stew/shims/net, chronicles, chronos, - presto + chronos/apps/http/httpserver, + presto, + presto/middleware, + presto/servercommon + +import + ./origin_handler -type RestServerResult*[T] = Result[T, string] +type + RestServerResult*[T] = Result[T, string] + WakuRestServer* = object of RootObj + router*: RestRouter + httpServer*: HttpServerRef + + WakuRestServerRef* = ref WakuRestServer ### Configuration @@ -46,7 +58,59 @@ proc default*(T: type RestServerConf): T = ### Initialization -proc getRouter(allowedOrigin: Option[string]): RestRouter = +proc new*(t: typedesc[WakuRestServerRef], + router: RestRouter, + address: TransportAddress, + serverIdent: string = PrestoIdent, + serverFlags = {HttpServerFlags.NotifyDisconnect}, + socketFlags: set[ServerFlags] = {ReuseAddr}, + serverUri = Uri(), + maxConnections: int = -1, + backlogSize: int = DefaultBacklogSize, + bufferSize: int = 4096, + httpHeadersTimeout = 10.seconds, + maxHeadersSize: int = 8192, + maxRequestBodySize: int = 1_048_576, + requestErrorHandler: RestRequestErrorHandler = nil, + dualstack = DualStackType.Auto, + allowedOrigin: Option[string] = none(string) + ): RestServerResult[WakuRestServerRef] = + var server = WakuRestServerRef(router: router) + + let restMiddleware = RestServerMiddlewareRef.new(router = server.router, errorHandler = requestErrorHandler) + let originHandlerMiddleware = OriginHandlerMiddlewareRef.new(allowedOrigin) + + let middlewares = [originHandlerMiddleware, + restMiddleware] + + ## This must be empty and needed only to confirm original initialization requirements of + ## the RestHttpServer now combining old and new middleware approach. + proc defaultProcessCallback(rf: RequestFence): Future[HttpResponseRef] {. + async: (raises: [CancelledError]).} = + discard + + + let sres = HttpServerRef.new(address + , defaultProcessCallback + , serverFlags + , socketFlags + , serverUri + , serverIdent + , maxConnections + , bufferSize + , backlogSize + , httpHeadersTimeout + , maxHeadersSize + , maxRequestBodySize + , dualstack = dualstack + , middlewares = middlewares) + if sres.isOk(): + server.httpServer = sres.get() + ok(server) + else: + err(sres.error) + +proc getRouter(): RestRouter = # TODO: Review this `validate` method. Check in nim-presto what is this used for. proc validate(pattern: string, value: string): int = ## This is rough validation procedure which should be simple and fast, @@ -54,9 +118,10 @@ proc getRouter(allowedOrigin: Option[string]): RestRouter = if pattern.startsWith("{") and pattern.endsWith("}"): 0 else: 1 - RestRouter.init(validate, allowedOrigin = allowedOrigin) + # disable allowed origin handling by presto, we add our own handling as middleware + RestRouter.init(validate, allowedOrigin = none(string)) -proc init*(T: type RestServerRef, +proc init*(T: type WakuRestServerRef, ip: IpAddress, port: Port, allowedOrigin=none(string), conf=RestServerConf.default(), @@ -73,28 +138,64 @@ proc init*(T: type RestServerRef, maxHeadersSize = conf.maxRequestHeadersSize * 1024 maxRequestBodySize = conf.maxRequestBodySize * 1024 - let router = getRouter(allowedOrigin) + let router = getRouter() - var res: RestResult[RestServerRef] try: - res = RestServerRef.new( - router, - address, - serverFlags = serverFlags, - httpHeadersTimeout = headersTimeout, - maxHeadersSize = maxHeadersSize, - maxRequestBodySize = maxRequestBodySize, - requestErrorHandler = requestErrorHandler - ) + return WakuRestServerRef.new( + router, + address, + serverFlags = serverFlags, + httpHeadersTimeout = headersTimeout, + maxHeadersSize = maxHeadersSize, + maxRequestBodySize = maxRequestBodySize, + requestErrorHandler = requestErrorHandler, + allowedOrigin = allowedOrigin + ) except CatchableError: return err(getCurrentExceptionMsg()) - # RestResult error type is cstring, so we need to map it to string - res.mapErr(proc(err: cstring): string = $err) - proc newRestHttpServer*(ip: IpAddress, port: Port, allowedOrigin=none(string), conf=RestServerConf.default(), requestErrorHandler: RestRequestErrorHandler = nil): - RestServerResult[RestServerRef] = - RestServerRef.init(ip, port, allowedOrigin, conf, requestErrorHandler) + RestServerResult[WakuRestServerRef] = + WakuRestServerRef.init(ip, port, allowedOrigin, conf, requestErrorHandler) + +proc localAddress*(rs: WakuRestServerRef): TransportAddress = + ## Returns `rs` bound local socket address. + rs.httpServer.instance.localAddress() + +proc state*(rs: WakuRestServerRef): RestServerState = + ## Returns current REST server's state. + case rs.httpServer.state + of HttpServerState.ServerClosed: + RestServerState.Closed + of HttpServerState.ServerStopped: + RestServerState.Stopped + of HttpServerState.ServerRunning: + RestServerState.Running + +proc start*(rs: WakuRestServerRef) = + ## Starts REST server. + rs.httpServer.start() + notice "REST service started", address = $rs.localAddress() + +proc stop*(rs: WakuRestServerRef) {.async: (raises: []).} = + ## Stop REST server from accepting new connections. + await rs.httpServer.stop() + notice "REST service stopped", address = $rs.localAddress() + +proc drop*(rs: WakuRestServerRef): Future[void] {. + async: (raw: true, raises: []).} = + ## Drop all pending connections. + rs.httpServer.drop() + +proc closeWait*(rs: WakuRestServerRef) {.async: (raises: []).} = + ## Stop REST server and drop all the pending connections. + await rs.httpServer.closeWait() + notice "REST service closed", address = $rs.localAddress() + +proc join*(rs: WakuRestServerRef): Future[void] {. + async: (raw: true, raises: [CancelledError]).} = + ## Wait until REST server will not be closed. + rs.httpServer.join()