diff --git a/apps/wakunode2/app.nim b/apps/wakunode2/app.nim index b064c2453..10a91abf3 100644 --- a/apps/wakunode2/app.nim +++ b/apps/wakunode2/app.nim @@ -84,7 +84,7 @@ type node: WakuNode rpcServer: Option[RpcHttpServer] - restServer: Option[RestServerRef] + restServer: Option[WakuRestServerRef] metricsServer: Option[MetricsHttpServerRef] AppResult*[T] = Result[T, string] @@ -667,34 +667,55 @@ proc startApp*(app: var App): AppResult[void] = ## Monitoring and external interfaces -proc startRestServer(app: App, address: IpAddress, port: Port, conf: WakuNodeConf): AppResult[RestServerRef] = +proc startRestServer(app: App, + address: IpAddress, + port: Port, + conf: WakuNodeConf): + AppResult[WakuRestServerRef] = # Used to register api endpoints that are not currently installed as keys, # values are holding error messages to be returned to the client var notInstalledTab: Table[string, string] = initTable[string, string]() - proc requestErrorHandler(error: RestRequestError, - request: HttpRequestRef): - Future[HttpResponseRef] {.async.} = - case error - of RestRequestError.Invalid: - return await request.respond(Http400, "Invalid request", HttpTable.init()) - of RestRequestError.NotFound: - let rootPath = request.rawPath.split("/")[1] - if notInstalledTab.hasKey(rootPath): - return await request.respond(Http404, notInstalledTab[rootPath], HttpTable.init()) - else: - return await request.respond(Http400, "Bad request initiated. Invalid path or method used.", HttpTable.init()) - of RestRequestError.InvalidContentBody: - return await request.respond(Http400, "Invalid content body", HttpTable.init()) - of RestRequestError.InvalidContentType: - return await request.respond(Http400, "Invalid content type", HttpTable.init()) - of RestRequestError.Unexpected: - return defaultResponse() + let requestErrorHandler : RestRequestErrorHandler = proc (error: RestRequestError, + request: HttpRequestRef): + Future[HttpResponseRef] + {.async: (raises: [CancelledError]).} = + try: + case error + of RestRequestError.Invalid: + return await request.respond(Http400, "Invalid request", HttpTable.init()) + of RestRequestError.NotFound: + let paths = request.rawPath.split("/") + let rootPath = if len(paths) > 1: + paths[1] + else: + "" + notInstalledTab.withValue(rootPath, errMsg): + return await request.respond(Http404, errMsg[], HttpTable.init()) + do: + return await request.respond(Http400, "Bad request initiated. Invalid path or method used.", HttpTable.init()) + of RestRequestError.InvalidContentBody: + return await request.respond(Http400, "Invalid content body", HttpTable.init()) + of RestRequestError.InvalidContentType: + return await request.respond(Http400, "Invalid content type", HttpTable.init()) + of RestRequestError.Unexpected: + return defaultResponse() + except HttpWriteError: + error "Failed to write response to client", error = getCurrentExceptionMsg() + discard return defaultResponse() - let server = ? newRestHttpServer(address, port, requestErrorHandler = requestErrorHandler) + let allowedOrigin = if len(conf.restAllowOrigin) > 0 : + some(conf.restAllowOrigin.join(",")) + else: + none(string) + + let server = ? newRestHttpServer(address, port, + allowedOrigin = allowedOrigin, + requestErrorHandler = requestErrorHandler) + ## Admin REST API if conf.restAdmin: installAdminApiHandlers(server.router, app.node) diff --git a/apps/wakunode2/external_config.nim b/apps/wakunode2/external_config.nim index 88739de04..e9c1d1e9f 100644 --- a/apps/wakunode2/external_config.nim +++ b/apps/wakunode2/external_config.nim @@ -419,6 +419,14 @@ type defaultValue: false name: "rest-private" }: bool + restAllowOrigin* {. + desc: "Allow cross-origin requests from the specified origin." & + "Argument may be repeated." & + "Wildcards: * or ? allowed." & + "Ex.: \"localhost:*\" or \"127.0.0.1:8080\"", + defaultValue: newSeq[string]() + name: "rest-allow-origin" }: seq[string] + ## Metrics config metricsServer* {. diff --git a/tests/all_tests_waku.nim b/tests/all_tests_waku.nim index 303021404..17d76a50c 100644 --- a/tests/all_tests_waku.nim +++ b/tests/all_tests_waku.nim @@ -85,7 +85,8 @@ import ./wakunode_rest/test_rest_filter, ./wakunode_rest/test_rest_legacy_filter, ./wakunode_rest/test_rest_lightpush, - ./wakunode_rest/test_rest_admin + ./wakunode_rest/test_rest_admin, + ./wakunode_rest/test_rest_cors import ./waku_rln_relay/test_waku_rln_relay, diff --git a/tests/wakunode_rest/test_all.nim b/tests/wakunode_rest/test_all.nim index 620ae8a70..9829a78f2 100644 --- a/tests/wakunode_rest/test_all.nim +++ b/tests/wakunode_rest/test_all.nim @@ -10,4 +10,5 @@ import ./test_rest_relay, ./test_rest_serdes, ./test_rest_store, - ./test_rest_admin + ./test_rest_admin, + ./test_rest_cors diff --git a/tests/wakunode_rest/test_rest_admin.nim b/tests/wakunode_rest/test_rest_admin.nim index 932dcd743..628359ab3 100644 --- a/tests/wakunode_rest/test_rest_admin.nim +++ b/tests/wakunode_rest/test_rest_admin.nim @@ -30,7 +30,7 @@ suite "Waku v2 Rest API - Admin": var peerInfo1 {.threadvar.}: RemotePeerInfo var peerInfo2 {.threadvar.}: RemotePeerInfo var peerInfo3 {.threadvar.}: RemotePeerInfo - var restServer {.threadvar.}: RestServerRef + var restServer {.threadvar.}: WakuRestServerRef var client{.threadvar.}: RestClientRef asyncSetup: @@ -46,7 +46,7 @@ suite "Waku v2 Rest API - Admin": let restPort = Port(58011) let restAddress = parseIpAddress("127.0.0.1") - restServer = RestServerRef.init(restAddress, restPort).tryGet() + restServer = WakuRestServerRef.init(restAddress, restPort).tryGet() installAdminApiHandlers(restServer.router, node1) diff --git a/tests/wakunode_rest/test_rest_cors.nim b/tests/wakunode_rest/test_rest_cors.nim new file mode 100644 index 000000000..3ac166f07 --- /dev/null +++ b/tests/wakunode_rest/test_rest_cors.nim @@ -0,0 +1,269 @@ +{.used.} + +import + stew/shims/net, + testutils/unittests, + presto, + presto/client as presto_client, + libp2p/peerinfo, + libp2p/multiaddress, + libp2p/crypto/crypto +import + ../../waku/waku_node, + ../../waku/node/waku_node as waku_node2, + ../../waku/waku_api/rest/server, + ../../waku/waku_api/rest/client, + ../../waku/waku_api/rest/responses, + ../../waku/waku_api/rest/debug/handlers as debug_api, + ../../waku/waku_api/rest/debug/client as debug_api_client, + ../testlib/common, + ../testlib/wakucore, + ../testlib/wakunode + + +type + TestResponseTuple = tuple[status: int, data: string, headers: HttpTable] + +proc testWakuNode(): WakuNode = + let + privkey = crypto.PrivateKey.random(Secp256k1, rng[]).tryGet() + bindIp = parseIpAddress("0.0.0.0") + extIp = parseIpAddress("127.0.0.1") + port = Port(58000) + + newTestWakuNode(privkey, bindIp, port, some(extIp), some(port)) + +proc fetchWithHeader(request: HttpClientRequestRef): Future[TestResponseTuple] + {.async: (raises: [CancelledError, HttpError]).} = + var response: HttpClientResponseRef + try: + response = await request.send() + let buffer = await response.getBodyBytes() + let status = response.status + let headers = response.headers + await response.closeWait() + response = nil + return (status, buffer.bytesToString(), headers) + except HttpError as exc: + if not(isNil(response)): await response.closeWait() + assert false + except CancelledError as exc: + if not(isNil(response)): await response.closeWait() + assert false + +proc issueRequest( + address: HttpAddress, + reqOrigin: Option[string] = none(string) + ): Future[TestResponseTuple] {.async.} = + + var + session = HttpSessionRef.new({HttpClientFlag.Http11Pipeline}) + data: TestResponseTuple + + var originHeader : seq[HttpHeaderTuple] + if reqOrigin.isSome(): + originHeader.insert(("Origin", reqOrigin.get())) + + var + request = HttpClientRequestRef.new(session, + address, + version = HttpVersion11, + headers = originHeader) + try: + data = await request.fetchWithHeader() + finally: + await request.closeWait() + return data + +proc checkResponse(response: TestResponseTuple, + expectedStatus : int, + expectedOrigin : Option[string]): bool = + if response.status != expectedStatus: + echo(" -> check failed: expected status" & $expectedStatus & + " got " & $response.status) + return false + + if not (expectedOrigin.isNone() or + (expectedOrigin.isSome() and + response.headers.contains("Access-Control-Allow-Origin") and + response.headers.getLastString("Access-Control-Allow-Origin") == expectedOrigin.get())): + echo(" -> check failed: expected origin " & $expectedOrigin & " got " & + response.headers.getLastString("Access-Control-Allow-Origin")) + return false + + return true + +suite "Waku v2 REST API CORS Handling": + asyncTest "AllowedOrigin matches": + # Given + let node = testWakuNode() + await node.start() + await node.mountRelay() + + let restPort = Port(58001) + let restAddress = parseIpAddress("0.0.0.0") + let restServer = WakuRestServerRef.init(restAddress, + restPort, + allowedOrigin=some("test.net:1234,https://localhost:*,http://127.0.0.1:?8,?waku*.net:*80*") + ).tryGet() + + installDebugApiHandlers(restServer.router, node) + restServer.start() + + let srvAddr = restServer.localAddress() + let ha = getAddress(srvAddr, HttpClientScheme.NonSecure, "/debug/v1/info") + + # When + var response = await issueRequest(ha, some("http://test.net:1234")) + check checkResponse(response, 200, some("http://test.net:1234")) + + response = await issueRequest(ha, some("https://test.net:1234")) + check checkResponse(response, 200, some("https://test.net:1234")) + + response = await issueRequest(ha, some("https://localhost:8080")) + check checkResponse(response, 200, some("https://localhost:8080")) + + response = await issueRequest(ha, some("https://localhost:80")) + check checkResponse(response, 200, some("https://localhost:80")) + + response = await issueRequest(ha, some("http://127.0.0.1:78")) + check checkResponse(response, 200, some("http://127.0.0.1:78")) + + response = await issueRequest(ha, some("http://wakuTHE.net:8078")) + check checkResponse(response, 200, some("http://wakuTHE.net:8078")) + + response = await issueRequest(ha, some("http://nwaku.main.net:1980")) + check checkResponse(response, 200, some("http://nwaku.main.net:1980")) + + response = await issueRequest(ha, some("http://nwaku.main.net:80")) + check checkResponse(response, 200, some("http://nwaku.main.net:80")) + + await restServer.stop() + await restServer.closeWait() + await node.stop() + + asyncTest "AllowedOrigin reject": + # Given + let node = testWakuNode() + await node.start() + await node.mountRelay() + + let restPort = Port(58001) + let restAddress = parseIpAddress("0.0.0.0") + let restServer = WakuRestServerRef.init(restAddress, + restPort, + allowedOrigin=some("test.net:1234,https://localhost:*,http://127.0.0.1:?8,?waku*.net:*80*") + ).tryGet() + + installDebugApiHandlers(restServer.router, node) + restServer.start() + + let srvAddr = restServer.localAddress() + let ha = getAddress(srvAddr, HttpClientScheme.NonSecure, "/debug/v1/info") + + # When + var response = await issueRequest(ha, some("http://test.net:12334")) + check checkResponse(response, 403, none(string)) + + response = await issueRequest(ha, some("http://test.net:12345")) + check checkResponse(response, 403, none(string)) + + response = await issueRequest(ha, some("xhttp://test.net:1234")) + check checkResponse(response, 403, none(string)) + + response = await issueRequest(ha, some("https://xtest.net:1234")) + check checkResponse(response, 403, none(string)) + + response = await issueRequest(ha, some("http://localhost:8080")) + check checkResponse(response, 403, none(string)) + + response = await issueRequest(ha, some("https://127.0.0.1:78")) + check checkResponse(response, 403, none(string)) + + response = await issueRequest(ha, some("http://127.0.0.1:89")) + check checkResponse(response, 403, none(string)) + + response = await issueRequest(ha, some("http://the.waku.net:8078")) + check checkResponse(response, 403, none(string)) + + response = await issueRequest(ha, some("http://nwaku.main.net:1900")) + check checkResponse(response, 403, none(string)) + + await restServer.stop() + await restServer.closeWait() + await node.stop() + + asyncTest "AllowedOrigin allmatches": + # Given + let node = testWakuNode() + await node.start() + await node.mountRelay() + + let restPort = Port(58001) + let restAddress = parseIpAddress("0.0.0.0") + let restServer = WakuRestServerRef.init(restAddress, + restPort, + allowedOrigin=some("*") + ).tryGet() + + installDebugApiHandlers(restServer.router, node) + restServer.start() + + let srvAddr = restServer.localAddress() + let ha = getAddress(srvAddr, HttpClientScheme.NonSecure, "/debug/v1/info") + + # When + var response = await issueRequest(ha, some("http://test.net:1234")) + check checkResponse(response, 200, some("*")) + + response = await issueRequest(ha, some("https://test.net:1234")) + check checkResponse(response, 200, some("*")) + + response = await issueRequest(ha, some("https://localhost:8080")) + check checkResponse(response, 200, some("*")) + + response = await issueRequest(ha, some("https://localhost:80")) + check checkResponse(response, 200, some("*")) + + response = await issueRequest(ha, some("http://127.0.0.1:78")) + check checkResponse(response, 200, some("*")) + + response = await issueRequest(ha, some("http://wakuTHE.net:8078")) + check checkResponse(response, 200, some("*")) + + response = await issueRequest(ha, some("http://nwaku.main.net:1980")) + check checkResponse(response, 200, some("*")) + + response = await issueRequest(ha, some("http://nwaku.main.net:80")) + check checkResponse(response, 200, some("*")) + + await restServer.stop() + await restServer.closeWait() + await node.stop() + + asyncTest "No origin goes through": + # Given + let node = testWakuNode() + await node.start() + await node.mountRelay() + + let restPort = Port(58001) + let restAddress = parseIpAddress("0.0.0.0") + let restServer = WakuRestServerRef.init(restAddress, + restPort, + allowedOrigin=some("test.net:1234,https://localhost:*,http://127.0.0.1:?8,?waku*.net:*80*") + ).tryGet() + + installDebugApiHandlers(restServer.router, node) + restServer.start() + + let srvAddr = restServer.localAddress() + let ha = getAddress(srvAddr, HttpClientScheme.NonSecure, "/debug/v1/info") + + # When + var response = await issueRequest(ha, none(string)) + check checkResponse(response, 200, none(string)) + + await restServer.stop() + await restServer.closeWait() + await node.stop() diff --git a/tests/wakunode_rest/test_rest_debug.nim b/tests/wakunode_rest/test_rest_debug.nim index ab458ad0c..4f1c4ac6a 100644 --- a/tests/wakunode_rest/test_rest_debug.nim +++ b/tests/wakunode_rest/test_rest_debug.nim @@ -40,7 +40,7 @@ suite "Waku v2 REST API - Debug": let restPort = Port(58001) let restAddress = parseIpAddress("0.0.0.0") - let restServer = RestServerRef.init(restAddress, restPort).tryGet() + let restServer = WakuRestServerRef.init(restAddress, restPort).tryGet() installDebugApiHandlers(restServer.router, node) restServer.start() @@ -67,7 +67,7 @@ suite "Waku v2 REST API - Debug": let restPort = Port(58002) let restAddress = parseIpAddress("0.0.0.0") - let restServer = RestServerRef.init(restAddress, restPort).tryGet() + let restServer = WakuRestServerRef.init(restAddress, restPort).tryGet() installDebugApiHandlers(restServer.router, node) restServer.start() diff --git a/tests/wakunode_rest/test_rest_filter.nim b/tests/wakunode_rest/test_rest_filter.nim index b86bb8925..10a2d5b8f 100644 --- a/tests/wakunode_rest/test_rest_filter.nim +++ b/tests/wakunode_rest/test_rest_filter.nim @@ -39,8 +39,8 @@ proc testWakuNode(): WakuNode = type RestFilterTest = object serviceNode: WakuNode subscriberNode: WakuNode - restServer: RestServerRef - restServerForService: RestServerRef + restServer: WakuRestServerRef + restServerForService: WakuRestServerRef messageCache: MessageCache client: RestClientRef clientTwdServiceNode: RestClientRef @@ -61,10 +61,10 @@ proc init(T: type RestFilterTest): Future[T] {.async.} = let restPort = Port(58011) let restAddress = parseIpAddress("127.0.0.1") - testSetup.restServer = RestServerRef.init(restAddress, restPort).tryGet() + testSetup.restServer = WakuRestServerRef.init(restAddress, restPort).tryGet() let restPort2 = Port(58012) - testSetup.restServerForService = RestServerRef.init(restAddress, restPort2).tryGet() + testSetup.restServerForService = WakuRestServerRef.init(restAddress, restPort2).tryGet() # through this one we will see if messages are pushed according to our content topic sub testSetup.messageCache = MessageCache.init() diff --git a/tests/wakunode_rest/test_rest_health.nim b/tests/wakunode_rest/test_rest_health.nim index 9560384a0..5937b8dba 100644 --- a/tests/wakunode_rest/test_rest_health.nim +++ b/tests/wakunode_rest/test_rest_health.nim @@ -44,7 +44,7 @@ suite "Waku v2 REST API - health": let restPort = Port(58001) let restAddress = parseIpAddress("0.0.0.0") - let restServer = RestServerRef.init(restAddress, restPort).tryGet() + let restServer = WakuRestServerRef.init(restAddress, restPort).tryGet() installHealthApiHandler(restServer.router, node) restServer.start() diff --git a/tests/wakunode_rest/test_rest_legacy_filter.nim b/tests/wakunode_rest/test_rest_legacy_filter.nim index fe82d8117..f7de83372 100644 --- a/tests/wakunode_rest/test_rest_legacy_filter.nim +++ b/tests/wakunode_rest/test_rest_legacy_filter.nim @@ -38,7 +38,7 @@ proc testWakuNode(): WakuNode = type RestFilterTest = object filterNode: WakuNode clientNode: WakuNode - restServer: RestServerRef + restServer: WakuRestServerRef messageCache: MessageCache client: RestClientRef @@ -58,7 +58,7 @@ proc setupRestFilter(): Future[RestFilterTest] {.async.} = let restPort = Port(58011) let restAddress = parseIpAddress("0.0.0.0") - result.restServer = RestServerRef.init(restAddress, restPort).tryGet() + result.restServer = WakuRestServerRef.init(restAddress, restPort).tryGet() result.messageCache = MessageCache.init() installLegacyFilterRestApiHandlers(result.restServer.router diff --git a/tests/wakunode_rest/test_rest_lightpush.nim b/tests/wakunode_rest/test_rest_lightpush.nim index 876be7c64..34dc40dd2 100644 --- a/tests/wakunode_rest/test_rest_lightpush.nim +++ b/tests/wakunode_rest/test_rest_lightpush.nim @@ -39,7 +39,7 @@ type RestLightPushTest = object serviceNode: WakuNode pushNode: WakuNode consumerNode: WakuNode - restServer: RestServerRef + restServer: WakuRestServerRef client: RestClientRef @@ -71,7 +71,7 @@ proc init(T: type RestLightPushTest): Future[T] {.async.} = let restPort = Port(58011) let restAddress = parseIpAddress("127.0.0.1") - testSetup.restServer = RestServerRef.init(restAddress, restPort).tryGet() + testSetup.restServer = WakuRestServerRef.init(restAddress, restPort).tryGet() installLightPushRequestHandler(testSetup.restServer.router, testSetup.pushNode) diff --git a/tests/wakunode_rest/test_rest_relay.nim b/tests/wakunode_rest/test_rest_relay.nim index 0e0f40424..9ec2485af 100644 --- a/tests/wakunode_rest/test_rest_relay.nim +++ b/tests/wakunode_rest/test_rest_relay.nim @@ -43,9 +43,9 @@ suite "Waku v2 Rest API - Relay": var restPort = Port(0) let restAddress = parseIpAddress("0.0.0.0") - let restServer = RestServerRef.init(restAddress, restPort).tryGet() + let restServer = WakuRestServerRef.init(restAddress, restPort).tryGet() - restPort = restServer.server.address.port # update with bound port for client use + restPort = restServer.httpServer.address.port # update with bound port for client use let cache = MessageCache.init() @@ -93,9 +93,9 @@ suite "Waku v2 Rest API - Relay": var restPort = Port(0) let restAddress = parseIpAddress("0.0.0.0") - let restServer = RestServerRef.init(restAddress, restPort).tryGet() + let restServer = WakuRestServerRef.init(restAddress, restPort).tryGet() - restPort = restServer.server.address.port # update with bound port for client use + restPort = restServer.httpServer.address.port # update with bound port for client use let cache = MessageCache.init() cache.pubsubSubscribe("pubsub-topic-1") @@ -147,12 +147,12 @@ suite "Waku v2 Rest API - Relay": var restPort = Port(0) let restAddress = parseIpAddress("0.0.0.0") - let restServer = RestServerRef.init(restAddress, restPort).tryGet() + let restServer = WakuRestServerRef.init(restAddress, restPort).tryGet() - restPort = restServer.server.address.port # update with bound port for client use + restPort = restServer.httpServer.address.port # update with bound port for client use let pubSubTopic = "/waku/2/default-waku/proto" - + var messages = @[ fakeWakuMessage(contentTopic = "content-topic-x", payload = toBytes("TEST-1"), meta = toBytes("test-meta") ) @@ -168,7 +168,7 @@ suite "Waku v2 Rest API - Relay": meta = toBytes("test-meta")) messages.add(msg) - + let cache = MessageCache.init() cache.pubsubSubscribe(pubSubTopic) @@ -216,9 +216,9 @@ suite "Waku v2 Rest API - Relay": # RPC server setup var restPort = Port(0) let restAddress = parseIpAddress("0.0.0.0") - let restServer = RestServerRef.init(restAddress, restPort).tryGet() + let restServer = WakuRestServerRef.init(restAddress, restPort).tryGet() - restPort = restServer.server.address.port # update with bound port for client use + restPort = restServer.httpServer.address.port # update with bound port for client use let cache = MessageCache.init() @@ -258,9 +258,9 @@ suite "Waku v2 Rest API - Relay": var restPort = Port(0) let restAddress = parseIpAddress("0.0.0.0") - let restServer = RestServerRef.init(restAddress, restPort).tryGet() + let restServer = WakuRestServerRef.init(restAddress, restPort).tryGet() - restPort = restServer.server.address.port # update with bound port for client use + restPort = restServer.httpServer.address.port # update with bound port for client use let cache = MessageCache.init() @@ -306,9 +306,9 @@ suite "Waku v2 Rest API - Relay": var restPort = Port(0) let restAddress = parseIpAddress("0.0.0.0") - let restServer = RestServerRef.init(restAddress, restPort).tryGet() + let restServer = WakuRestServerRef.init(restAddress, restPort).tryGet() - restPort = restServer.server.address.port # update with bound port for client use + restPort = restServer.httpServer.address.port # update with bound port for client use let contentTopics = @[ ContentTopic("/waku/2/default-content1/proto"), @@ -354,9 +354,9 @@ suite "Waku v2 Rest API - Relay": var restPort = Port(0) let restAddress = parseIpAddress("0.0.0.0") - let restServer = RestServerRef.init(restAddress, restPort).tryGet() + let restServer = WakuRestServerRef.init(restAddress, restPort).tryGet() - restPort = restServer.server.address.port # update with bound port for client use + restPort = restServer.httpServer.address.port # update with bound port for client use let contentTopic = DefaultContentTopic @@ -370,9 +370,9 @@ suite "Waku v2 Rest API - Relay": while msg == messages[i]: msg = fakeWakuMessage(contentTopic = DefaultContentTopic, payload = toBytes("TEST-1")) - + messages.add(msg) - + let cache = MessageCache.init() cache.contentSubscribe(contentTopic) @@ -419,9 +419,9 @@ suite "Waku v2 Rest API - Relay": # RPC server setup var restPort = Port(0) let restAddress = parseIpAddress("0.0.0.0") - let restServer = RestServerRef.init(restAddress, restPort).tryGet() + let restServer = WakuRestServerRef.init(restAddress, restPort).tryGet() - restPort = restServer.server.address.port # update with bound port for client use + restPort = restServer.httpServer.address.port # update with bound port for client use let cache = MessageCache.init() installRelayApiHandlers(restServer.router, node, cache) @@ -464,9 +464,9 @@ suite "Waku v2 Rest API - Relay": # RPC server setup var restPort = Port(0) let restAddress = parseIpAddress("0.0.0.0") - let restServer = RestServerRef.init(restAddress, restPort).tryGet() + let restServer = WakuRestServerRef.init(restAddress, restPort).tryGet() - restPort = restServer.server.address.port # update with bound port for client use + restPort = restServer.httpServer.address.port # update with bound port for client use let cache = MessageCache.init() installRelayApiHandlers(restServer.router, node, cache) @@ -504,9 +504,9 @@ suite "Waku v2 Rest API - Relay": # RPC server setup var restPort = Port(0) let restAddress = parseIpAddress("0.0.0.0") - let restServer = RestServerRef.init(restAddress, restPort).tryGet() + let restServer = WakuRestServerRef.init(restAddress, restPort).tryGet() - restPort = restServer.server.address.port # update with bound port for client use + restPort = restServer.httpServer.address.port # update with bound port for client use let cache = MessageCache.init() @@ -525,7 +525,7 @@ suite "Waku v2 Rest API - Relay": contentTopic: some(DefaultContentTopic), timestamp: some(int64(2022)) )) - + # Then check: response.status == 400 @@ -549,9 +549,9 @@ suite "Waku v2 Rest API - Relay": # RPC server setup var restPort = Port(0) let restAddress = parseIpAddress("0.0.0.0") - let restServer = RestServerRef.init(restAddress, restPort).tryGet() + let restServer = WakuRestServerRef.init(restAddress, restPort).tryGet() - restPort = restServer.server.address.port # update with bound port for client use + restPort = restServer.httpServer.address.port # update with bound port for client use let cache = MessageCache.init() @@ -570,7 +570,7 @@ suite "Waku v2 Rest API - Relay": contentTopic: some(DefaultContentTopic), timestamp: some(int64(2022)) )) - + # Then check: response.status == 400 @@ -579,4 +579,4 @@ suite "Waku v2 Rest API - Relay": await restServer.stop() await restServer.closeWait() - await node.stop() \ No newline at end of file + await node.stop() diff --git a/tests/wakunode_rest/test_rest_store.nim b/tests/wakunode_rest/test_rest_store.nim index 6b617970c..908eec571 100644 --- a/tests/wakunode_rest/test_rest_store.nim +++ b/tests/wakunode_rest/test_rest_store.nim @@ -83,7 +83,7 @@ procSuite "Waku v2 Rest API - Store": let restPort = Port(58011) let restAddress = parseIpAddress("0.0.0.0") - let restServer = RestServerRef.init(restAddress, restPort).tryGet() + let restServer = WakuRestServerRef.init(restAddress, restPort).tryGet() installStoreApiHandlers(restServer.router, node) restServer.start() @@ -153,7 +153,7 @@ procSuite "Waku v2 Rest API - Store": let restPort = Port(58012) let restAddress = parseIpAddress("0.0.0.0") - let restServer = RestServerRef.init(restAddress, restPort).tryGet() + let restServer = WakuRestServerRef.init(restAddress, restPort).tryGet() installStoreApiHandlers(restServer.router, node) restServer.start() @@ -251,7 +251,7 @@ procSuite "Waku v2 Rest API - Store": let restPort = Port(58013) let restAddress = parseIpAddress("0.0.0.0") - let restServer = RestServerRef.init(restAddress, restPort).tryGet() + let restServer = WakuRestServerRef.init(restAddress, restPort).tryGet() installStoreApiHandlers(restServer.router, node) restServer.start() @@ -325,7 +325,7 @@ procSuite "Waku v2 Rest API - Store": let restPort = Port(58014) let restAddress = parseIpAddress("0.0.0.0") - let restServer = RestServerRef.init(restAddress, restPort).tryGet() + let restServer = WakuRestServerRef.init(restAddress, restPort).tryGet() installStoreApiHandlers(restServer.router, node) restServer.start() @@ -416,7 +416,7 @@ procSuite "Waku v2 Rest API - Store": let restPort = Port(58015) let restAddress = parseIpAddress("0.0.0.0") - let restServer = RestServerRef.init(restAddress, restPort).tryGet() + let restServer = WakuRestServerRef.init(restAddress, restPort).tryGet() installStoreApiHandlers(restServer.router, node) restServer.start() @@ -473,7 +473,7 @@ procSuite "Waku v2 Rest API - Store": let restPort = Port(58016) let restAddress = parseIpAddress("0.0.0.0") - let restServer = RestServerRef.init(restAddress, restPort).tryGet() + let restServer = WakuRestServerRef.init(restAddress, restPort).tryGet() installStoreApiHandlers(restServer.router, node) restServer.start() @@ -482,7 +482,7 @@ procSuite "Waku v2 Rest API - Store": let driver: ArchiveDriver = QueueDriver.new() let mountArchiveRes = node.mountArchive(driver) assert mountArchiveRes.isOk(), mountArchiveRes.error - + await node.mountStore() node.mountStoreClient() @@ -547,7 +547,7 @@ procSuite "Waku v2 Rest API - Store": let restPort = Port(58014) let restAddress = parseIpAddress("0.0.0.0") - let restServer = RestServerRef.init(restAddress, restPort).tryGet() + let restServer = WakuRestServerRef.init(restAddress, restPort).tryGet() installStoreApiHandlers(restServer.router, node) restServer.start() @@ -609,9 +609,9 @@ procSuite "Waku v2 Rest API - Store": var restPort = Port(0) let restAddress = parseIpAddress("0.0.0.0") - let restServer = RestServerRef.init(restAddress, restPort).tryGet() + let restServer = WakuRestServerRef.init(restAddress, restPort).tryGet() - restPort = restServer.server.address.port # update with bound port for client use + restPort = restServer.httpServer.address.port # update with bound port for client use installStoreApiHandlers(restServer.router, node) restServer.start() diff --git a/vendor/nim-presto b/vendor/nim-presto index 5ca16485e..223aadeb8 160000 --- a/vendor/nim-presto +++ b/vendor/nim-presto @@ -1 +1 @@ -Subproject commit 5ca16485e4d74e531d50d289ebc0f869d9e6352b +Subproject commit 223aadeb82d35b57e6dae99f0b325ec6345ce7ff diff --git a/waku/waku_api/rest/origin_handler.nim b/waku/waku_api/rest/origin_handler.nim new file mode 100644 index 000000000..fa3453a06 --- /dev/null +++ b/waku/waku_api/rest/origin_handler.nim @@ -0,0 +1,125 @@ +when (NimMajor, NimMinor) < (1, 4): + {.push raises: [Defect].} +else: + {.push raises: [].} + +import + std/[options, strutils, re], + stew/results, + stew/shims/net, + chronicles, + chronos, + chronos/apps/http/httpserver + +type + OriginHandlerMiddlewareRef* = ref object of HttpServerMiddlewareRef + allowedOriginMatcher: Option[Regex] + everyOriginAllowed: bool + + +proc isEveryOriginAllowed(maybeAllowedOrigin: Option[string]): bool = + return maybeAllowedOrigin.isSome() and maybeAllowedOrigin.get() == "*" + +proc compileOriginMatcher(maybeAllowedOrigin: Option[string]): Option[Regex] = + if maybeAllowedOrigin.isNone(): + return none(Regex) + + let allowedOrigin = maybeAllowedOrigin.get() + + if (len(allowedOrigin) == 0): + return none(Regex) + + try: + var matchOrigin : string + + if allowedOrigin == "*": + matchOrigin = r".*" + return some(re(matchOrigin, {reIgnoreCase, reExtended})) + + let allowedOrigins = allowedOrigin.split(",") + + var matchExpressions : seq[string] = @[] + + var prefix : string + for allowedOrigin in allowedOrigins: + if allowedOrigin.startsWith("http://"): + prefix = r"http:\/\/" + matchOrigin = allowedOrigin.substr(7) + elif allowedOrigin.startsWith("https://"): + prefix = r"https:\/\/" + matchOrigin = allowedOrigin.substr(8) + else: + prefix = r"https?:\/\/" + matchOrigin = allowedOrigin + + matchOrigin = matchOrigin.replace(".", r"\.") + matchOrigin = matchOrigin.replace("*", ".*") + matchOrigin = matchOrigin.replace("?", ".?") + + matchExpressions.add("^" & prefix & matchOrigin & "$") + + let finalExpression = matchExpressions.join("|") + + return some(re(finalExpression, {reIgnoreCase, reExtended})) + except RegexError: + var msg = getCurrentExceptionMsg() + error "Failed to compile regex", source=allowedOrigin, err=msg + return none(Regex) + +proc originsMatch(originHandler: OriginHandlerMiddlewareRef, + requestOrigin: string): bool = + + if originHandler.allowedOriginMatcher.isNone(): + return false + + return requestOrigin.match(originHandler.allowedOriginMatcher.get()) + +proc originMiddlewareProc( + middleware: HttpServerMiddlewareRef, + reqfence: RequestFence, + nextHandler: HttpProcessCallback2 + ): Future[HttpResponseRef] {.async: (raises: [CancelledError]).} = + if reqfence.isErr(): + # Ignore request errors that detected before our middleware. + # Let final handler deal with it. + return await nextHandler(reqfence) + + let self = OriginHandlerMiddlewareRef(middleware) + let request = reqfence.get() + var reqHeaders = request.headers + var response = request.getResponse() + + if self.allowedOriginMatcher.isSome(): + let origin = reqHeaders.getList("Origin") + try: + if origin.len == 1: + if self.everyOriginAllowed: + response.addHeader("Access-Control-Allow-Origin", "*") + elif self.originsMatch(origin[0]): + # The Vary: Origin header to must be set to prevent + # potential cache poisoning attacks: + # https://textslashplain.com/2018/08/02/cors-and-vary/ + response.addHeader("Vary", "Origin") + response.addHeader("Access-Control-Allow-Origin", origin[0]) + else: + return await request.respond(Http403, "Origin not allowed") + elif origin.len == 0: + discard + elif origin.len > 1: + return await request.respond(Http400, "Only a single Origin header must be specified") + except HttpWriteError as exc: + # We use default error handler if we unable to send response. + return defaultResponse(exc) + + # Calling next handler. + return await nextHandler(reqfence) + +proc new*(t: typedesc[OriginHandlerMiddlewareRef], + allowedOrigin: Option[string] = none(string) + ): HttpServerMiddlewareRef = + + let middleware = + OriginHandlerMiddlewareRef(allowedOriginMatcher: compileOriginMatcher(allowedOrigin), + everyOriginAllowed: isEveryOriginAllowed(allowedOrigin), + handler: originMiddlewareProc) + return HttpServerMiddlewareRef(middleware) diff --git a/waku/waku_api/rest/server.nim b/waku/waku_api/rest/server.nim index ac73beb54..95a76d848 100644 --- a/waku/waku_api/rest/server.nim +++ b/waku/waku_api/rest/server.nim @@ -8,11 +8,23 @@ import stew/shims/net, chronicles, chronos, - presto + chronos/apps/http/httpserver, + presto, + presto/middleware, + presto/servercommon + +import + ./origin_handler -type RestServerResult*[T] = Result[T, string] +type + RestServerResult*[T] = Result[T, string] + WakuRestServer* = object of RootObj + router*: RestRouter + httpServer*: HttpServerRef + + WakuRestServerRef* = ref WakuRestServer ### Configuration @@ -46,7 +58,59 @@ proc default*(T: type RestServerConf): T = ### Initialization -proc getRouter(allowedOrigin: Option[string]): RestRouter = +proc new*(t: typedesc[WakuRestServerRef], + router: RestRouter, + address: TransportAddress, + serverIdent: string = PrestoIdent, + serverFlags = {HttpServerFlags.NotifyDisconnect}, + socketFlags: set[ServerFlags] = {ReuseAddr}, + serverUri = Uri(), + maxConnections: int = -1, + backlogSize: int = DefaultBacklogSize, + bufferSize: int = 4096, + httpHeadersTimeout = 10.seconds, + maxHeadersSize: int = 8192, + maxRequestBodySize: int = 1_048_576, + requestErrorHandler: RestRequestErrorHandler = nil, + dualstack = DualStackType.Auto, + allowedOrigin: Option[string] = none(string) + ): RestServerResult[WakuRestServerRef] = + var server = WakuRestServerRef(router: router) + + let restMiddleware = RestServerMiddlewareRef.new(router = server.router, errorHandler = requestErrorHandler) + let originHandlerMiddleware = OriginHandlerMiddlewareRef.new(allowedOrigin) + + let middlewares = [originHandlerMiddleware, + restMiddleware] + + ## This must be empty and needed only to confirm original initialization requirements of + ## the RestHttpServer now combining old and new middleware approach. + proc defaultProcessCallback(rf: RequestFence): Future[HttpResponseRef] {. + async: (raises: [CancelledError]).} = + discard + + + let sres = HttpServerRef.new(address + , defaultProcessCallback + , serverFlags + , socketFlags + , serverUri + , serverIdent + , maxConnections + , bufferSize + , backlogSize + , httpHeadersTimeout + , maxHeadersSize + , maxRequestBodySize + , dualstack = dualstack + , middlewares = middlewares) + if sres.isOk(): + server.httpServer = sres.get() + ok(server) + else: + err(sres.error) + +proc getRouter(): RestRouter = # TODO: Review this `validate` method. Check in nim-presto what is this used for. proc validate(pattern: string, value: string): int = ## This is rough validation procedure which should be simple and fast, @@ -54,9 +118,10 @@ proc getRouter(allowedOrigin: Option[string]): RestRouter = if pattern.startsWith("{") and pattern.endsWith("}"): 0 else: 1 - RestRouter.init(validate, allowedOrigin = allowedOrigin) + # disable allowed origin handling by presto, we add our own handling as middleware + RestRouter.init(validate, allowedOrigin = none(string)) -proc init*(T: type RestServerRef, +proc init*(T: type WakuRestServerRef, ip: IpAddress, port: Port, allowedOrigin=none(string), conf=RestServerConf.default(), @@ -73,28 +138,64 @@ proc init*(T: type RestServerRef, maxHeadersSize = conf.maxRequestHeadersSize * 1024 maxRequestBodySize = conf.maxRequestBodySize * 1024 - let router = getRouter(allowedOrigin) + let router = getRouter() - var res: RestResult[RestServerRef] try: - res = RestServerRef.new( - router, - address, - serverFlags = serverFlags, - httpHeadersTimeout = headersTimeout, - maxHeadersSize = maxHeadersSize, - maxRequestBodySize = maxRequestBodySize, - requestErrorHandler = requestErrorHandler - ) + return WakuRestServerRef.new( + router, + address, + serverFlags = serverFlags, + httpHeadersTimeout = headersTimeout, + maxHeadersSize = maxHeadersSize, + maxRequestBodySize = maxRequestBodySize, + requestErrorHandler = requestErrorHandler, + allowedOrigin = allowedOrigin + ) except CatchableError: return err(getCurrentExceptionMsg()) - # RestResult error type is cstring, so we need to map it to string - res.mapErr(proc(err: cstring): string = $err) - proc newRestHttpServer*(ip: IpAddress, port: Port, allowedOrigin=none(string), conf=RestServerConf.default(), requestErrorHandler: RestRequestErrorHandler = nil): - RestServerResult[RestServerRef] = - RestServerRef.init(ip, port, allowedOrigin, conf, requestErrorHandler) + RestServerResult[WakuRestServerRef] = + WakuRestServerRef.init(ip, port, allowedOrigin, conf, requestErrorHandler) + +proc localAddress*(rs: WakuRestServerRef): TransportAddress = + ## Returns `rs` bound local socket address. + rs.httpServer.instance.localAddress() + +proc state*(rs: WakuRestServerRef): RestServerState = + ## Returns current REST server's state. + case rs.httpServer.state + of HttpServerState.ServerClosed: + RestServerState.Closed + of HttpServerState.ServerStopped: + RestServerState.Stopped + of HttpServerState.ServerRunning: + RestServerState.Running + +proc start*(rs: WakuRestServerRef) = + ## Starts REST server. + rs.httpServer.start() + notice "REST service started", address = $rs.localAddress() + +proc stop*(rs: WakuRestServerRef) {.async: (raises: []).} = + ## Stop REST server from accepting new connections. + await rs.httpServer.stop() + notice "REST service stopped", address = $rs.localAddress() + +proc drop*(rs: WakuRestServerRef): Future[void] {. + async: (raw: true, raises: []).} = + ## Drop all pending connections. + rs.httpServer.drop() + +proc closeWait*(rs: WakuRestServerRef) {.async: (raises: []).} = + ## Stop REST server and drop all the pending connections. + await rs.httpServer.closeWait() + notice "REST service closed", address = $rs.localAddress() + +proc join*(rs: WakuRestServerRef): Future[void] {. + async: (raw: true, raises: [CancelledError]).} = + ## Wait until REST server will not be closed. + rs.httpServer.join()