{.push raises: [].} import std/[options, strutils, re, net], results, chronicles, chronos, chronos/apps/http/httpserver type OriginHandlerMiddlewareRef* = ref object of HttpServerMiddlewareRef allowedOriginMatcher: Option[Regex] everyOriginAllowed: bool proc isEveryOriginAllowed(maybeAllowedOrigin: Option[string]): bool = return maybeAllowedOrigin.isSome() and maybeAllowedOrigin.get() == "*" proc compileOriginMatcher(maybeAllowedOrigin: Option[string]): Option[Regex] = if maybeAllowedOrigin.isNone(): return none(Regex) let allowedOrigin = maybeAllowedOrigin.get() if (len(allowedOrigin) == 0): return none(Regex) try: var matchOrigin: string if allowedOrigin == "*": matchOrigin = r".*" return some(re(matchOrigin, {reIgnoreCase, reExtended})) let allowedOrigins = allowedOrigin.split(",") var matchExpressions: seq[string] = @[] var prefix: string for allowedOrigin in allowedOrigins: if allowedOrigin.startsWith("http://"): prefix = r"http:\/\/" matchOrigin = allowedOrigin.substr(7) elif allowedOrigin.startsWith("https://"): prefix = r"https:\/\/" matchOrigin = allowedOrigin.substr(8) else: prefix = r"https?:\/\/" matchOrigin = allowedOrigin matchOrigin = matchOrigin.replace(".", r"\.") matchOrigin = matchOrigin.replace("*", ".*") matchOrigin = matchOrigin.replace("?", ".?") matchExpressions.add("^" & prefix & matchOrigin & "$") let finalExpression = matchExpressions.join("|") return some(re(finalExpression, {reIgnoreCase, reExtended})) except RegexError: var msg = getCurrentExceptionMsg() error "Failed to compile regex", source = allowedOrigin, err = msg return none(Regex) proc originsMatch( originHandler: OriginHandlerMiddlewareRef, requestOrigin: string ): bool = if originHandler.allowedOriginMatcher.isNone(): return false return requestOrigin.match(originHandler.allowedOriginMatcher.get()) proc originMiddlewareProc( middleware: HttpServerMiddlewareRef, reqfence: RequestFence, nextHandler: HttpProcessCallback2, ): Future[HttpResponseRef] {.async: (raises: [CancelledError]).} = if reqfence.isErr(): # Ignore request errors that detected before our middleware. # Let final handler deal with it. return await nextHandler(reqfence) let self = OriginHandlerMiddlewareRef(middleware) let request = reqfence.get() var reqHeaders = request.headers var response = request.getResponse() if self.allowedOriginMatcher.isSome(): let origin = reqHeaders.getList("Origin") try: if origin.len == 1: if self.everyOriginAllowed: response.addHeader("Access-Control-Allow-Origin", "*") response.addHeader("Access-Control-Allow-Headers", "Content-Type") elif self.originsMatch(origin[0]): # The Vary: Origin header to must be set to prevent # potential cache poisoning attacks: # https://textslashplain.com/2018/08/02/cors-and-vary/ response.addHeader("Vary", "Origin") response.addHeader("Access-Control-Allow-Origin", origin[0]) response.addHeader("Access-Control-Allow-Headers", "Content-Type") else: return await request.respond(Http403, "Origin not allowed") elif origin.len == 0: discard elif origin.len > 1: return await request.respond( Http400, "Only a single Origin header must be specified" ) except HttpWriteError as exc: # We use default error handler if we unable to send response. return defaultResponse(exc) # Calling next handler. return await nextHandler(reqfence) proc new*( t: typedesc[OriginHandlerMiddlewareRef], allowedOrigin: Option[string] = none(string), ): HttpServerMiddlewareRef = let middleware = OriginHandlerMiddlewareRef( allowedOriginMatcher: compileOriginMatcher(allowedOrigin), everyOriginAllowed: isEveryOriginAllowed(allowedOrigin), handler: originMiddlewareProc, ) return HttpServerMiddlewareRef(middleware)