diff --git a/presto/common.nim b/presto/common.nim index 00e2b06..db9b743 100644 --- a/presto/common.nim +++ b/presto/common.nim @@ -6,7 +6,7 @@ # Licensed under either of # Apache License, version 2.0, (LICENSE-APACHEv2) # MIT license (LICENSE-MIT) -import chronos/apps +import chronos/apps, chronos/apps/http/httpclient import stew/[results, byteutils] export results, apps @@ -72,11 +72,11 @@ proc response*(t: typedesc[RestApiResponse], data: ByteChar, else: block: var default: seq[byte] - if len(data) > 0: - ContentBody(contentType: contentType, data: toBytes(data)) - else: - ContentBody(contentType: contentType, data: default) - RestApiResponse(kind: RestApiResponseKind.Content, status: status, + ContentBody(contentType: contentType, + data: if len(data) > 0: toBytes(data) else: default) + + RestApiResponse(kind: RestApiResponseKind.Content, + status: status, content: content) proc redirect*(t: typedesc[RestApiResponse], status: HttpCode = Http307, diff --git a/presto/route.nim b/presto/route.nim index 704ad51..3d1d950 100644 --- a/presto/route.nim +++ b/presto/route.nim @@ -6,8 +6,10 @@ # Licensed under either of # Apache License, version 2.0, (LICENSE-APACHEv2) # MIT license (LICENSE-MIT) -import chronos, chronos/apps/http/[httpcommon, httptable] + import std/[macros, options] +import chronos, chronos/apps/http/[httpcommon, httptable, httpclient] +import httputils import stew/bitops2 import btrees import common, segpath, macrocommon @@ -46,13 +48,22 @@ type RestRouter* = object patternCallback*: PatternCallback routes*: BTree[SegmentedPath, RestRouteItem] + allowedOrigin*: Option[string] proc init*(t: typedesc[RestRouter], - patternCallback: PatternCallback): RestRouter {.raises: [Defect].} = + patternCallback: PatternCallback, + allowedOrigin = none(string)): RestRouter {.raises: [Defect].} = doAssert(not(isNil(patternCallback)), "Pattern validation callback must not be nil") RestRouter(patternCallback: patternCallback, - routes: initBTree[SegmentedPath, RestRouteItem]()) + routes: initBTree[SegmentedPath, RestRouteItem](), + allowedOrigin: allowedOrigin) + +proc optionsRequestHandler(request: HttpRequestRef, + pathParams: HttpTable, + queryParams: HttpTable, + body: Option[ContentBody]): Future[RestApiResponse] {.async.} = + return RestApiResponse.response("", Http200) proc addRoute*(rr: var RestRouter, request: HttpMethod, path: string, flags: set[RestRouterFlag], handler: RestApiCallback) {. @@ -65,6 +76,25 @@ proc addRoute*(rr: var RestRouter, request: HttpMethod, path: string, let item = RestRouteItem(kind: RestRouteKind.Handler, path: spath, flags: flags, callback: handler) rr.routes.add(spath, item) + + if rr.allowedOrigin.isSome: + let + optionsPath = SegmentedPath.init( + MethodOptions, path, rr.patternCallback) + optionsRoute = rr.routes.getOrDefault( + optionsPath, RestRouteItem(kind: RestRouteKind.None)) + case route.kind + of RestRouteKind.None: + let optionsHandler = RestRouteItem(kind: RestRouteKind.Handler, + path: optionsPath, + flags: {RestRouterFlag.Raw}, + callback: optionsRequestHandler) + rr.routes.add(optionsPath, optionsHandler) + else: + # This may happen if we use the same URL path in separate GET and + # POST handlers. Reusing the previously installed OPTIONS handler + # is perfectly fine. + discard else: raiseAssert("The route is already in the routing table") diff --git a/presto/serverprivate.nim b/presto/serverprivate.nim index 85fae6f..0358ef3 100644 --- a/presto/serverprivate.nim +++ b/presto/serverprivate.nim @@ -28,6 +28,16 @@ proc getContentBody*(r: HttpRequestRef): Future[Option[ContentBody]] {.async.} = let cbody = ContentBody(contentType: cres.get(), data: data) return some[ContentBody](cbody) +proc originsMatch(requestOrigin, allowedOrigin: string): bool = + if allowedOrigin.startsWith("http://") or allowedOrigin.startsWith("https://"): + requestOrigin == allowedOrigin + elif requestOrigin.startsWith("http://"): + requestOrigin.toOpenArray(7, requestOrigin.len - 1) == allowedOrigin + elif requestOrigin.startsWith("https://"): + requestOrigin.toOpenArray(8, requestOrigin.len - 1) == allowedOrigin + else: + false + proc processRestRequest*[T](server: T, rf: RequestFence): Future[HttpResponseRef] {. gcsafe, async.} = @@ -98,14 +108,36 @@ proc processRestRequest*[T](server: T, uri = $request.uri return await request.respond(Http410) of RestApiResponseKind.Content: - let headers = HttpTable.init([("Content-Type", + var headers = HttpTable.init([("Content-Type", restRes.content.contentType)]) + if server.router.allowedOrigin.isSome: + let origin = request.headers.getList("Origin") + let everyOriginAllowed = server.router.allowedOrigin.get == "*" + if origin.len == 1: + if everyOriginAllowed: + headers.add("Access-Control-Allow-Origin", "*") + elif originsMatch(origin[0], server.router.allowedOrigin.get): + # The Vary: Origin header to must be set to prevent + # potential cache poisoning attacks: + # https://textslashplain.com/2018/08/02/cors-and-vary/ + headers.add("Vary", "Origin") + headers.add("Access-Control-Allow-Origin", origin[0]) + else: + return await request.respond(Http403, "Origin not allowed") + elif origin.len > 1: + return await request.respond(Http400, + "Only a single Origin header must be specified") + elif not everyOriginAllowed: + return await request.respond(Http403, + "Service can be used only from CORS-enabled clients") + debug "Received response from handler", status = restRes.status.toInt(), meth = $request.meth, peer = $request.remoteAddress(), uri = $request.uri, content_type = restRes.content.contentType, content_size = len(restRes.content.data) + return await request.respond(restRes.status, restRes.content.data, headers) of RestApiResponseKind.Error: