nwaku/waku/waku_api/rest/origin_handler.nim
NagyZoltanPeter d832f92a43
chore: Implemented CORS handling for nwaku REST server (#2470)
* Add allowOrigin configuration for wakunode and WakuRestServer
Update nim-presto to the latest master, which contains middleware support.
Rework the REST server in waku to use the new chronos and presto middleware design, and add proper CORS handling.
Add CORS tests and fixes.

Co-authored-by: Ivan FB <128452529+Ivansete-status@users.noreply.github.com>
2024-02-29 09:48:14 +01:00

126 lines
4.1 KiB
Nim

# Module-wide exception tracking: every proc below is checked by the compiler
# to raise no tracked exceptions unless it declares them explicitly.
# Compilers older than 1.4 get `Defect` in the list instead of an empty list.
when (NimMajor, NimMinor) < (1, 4):
  {.push raises: [Defect].}
else:
  {.push raises: [].}
import
std/[options, strutils, re],
stew/results,
stew/shims/net,
chronicles,
chronos,
chronos/apps/http/httpserver
type
  OriginHandlerMiddlewareRef* = ref object of HttpServerMiddlewareRef
    ## HTTP server middleware that checks the request's `Origin` header
    ## against a configured allow-list and adds CORS response headers.
    # Compiled allow-list; `none` means the middleware performs no origin
    # checks and adds no CORS headers.
    allowedOriginMatcher: Option[Regex]
    # True when the configured allow-list is exactly "*"; matching requests
    # with a single Origin header then get `Access-Control-Allow-Origin: *`.
    everyOriginAllowed: bool
proc isEveryOriginAllowed(maybeAllowedOrigin: Option[string]): bool =
  ## Tells whether the configured allow-list is the bare wildcard "*".
  ## A missing configuration is treated like an empty string, i.e. false.
  maybeAllowedOrigin.get("") == "*"
func escapeOriginPattern(origin: string): string =
  ## Translates one allow-list entry (scheme already removed) into a regex
  ## fragment. `*` matches any run of characters and `?` matches at most one
  ## character. Every other non-alphanumeric ASCII character is
  ## backslash-escaped so it matches literally. For example, the brackets in
  ## an IPv6 origin such as `[::1]:8645` must not become a character class.
  result = newStringOfCap(origin.len * 2)
  for ch in origin:
    case ch
    of '*':
      result.add(".*")
    of '?':
      result.add(".?")
    of 'a'..'z', 'A'..'Z', '0'..'9', '\x80'..'\xFF':
      result.add(ch)
    else:
      # In PCRE a backslash before a non-alphanumeric char makes it literal.
      result.add('\\')
      result.add(ch)

proc compileOriginMatcher(maybeAllowedOrigin: Option[string]): Option[Regex] =
  ## Compiles a comma-separated allow-list of origins into one anchored,
  ## case-insensitive regex.
  ##
  ## * `none`, or a value that is blank after trimming, gives `none(Regex)`,
  ##   which disables origin checking.
  ## * `"*"` gives a regex that matches any origin.
  ## * Otherwise each entry may start with `http://` or `https://` (scheme
  ##   compared case-insensitively). An entry without a scheme matches both.
  ##   Whitespace around entries is ignored and empty entries are skipped.
  ##   `*` and `?` are wildcards; all other characters match literally.
  ##
  ## If the regex fails to compile, the error is logged and `none(Regex)`
  ## is returned.
  if maybeAllowedOrigin.isNone():
    return none(Regex)
  let allowedOrigin = maybeAllowedOrigin.get().strip()
  if allowedOrigin.len == 0:
    return none(Regex)
  # `reExtended` is deliberately not used: in extended mode, whitespace in
  # the pattern is ignored and `#` starts a comment.
  const flags = {reIgnoreCase}
  try:
    if allowedOrigin == "*":
      return some(re(r".*", flags))
    var matchExpressions: seq[string] = @[]
    for rawEntry in allowedOrigin.split(","):
      # Tolerate "a.com, b.com" and trailing commas.
      let entry = rawEntry.strip()
      if entry.len == 0:
        continue
      let lowered = entry.toLowerAscii()
      var prefix, host: string
      if lowered.startsWith("http://"):
        prefix = r"http:\/\/"
        host = entry.substr(7)
      elif lowered.startsWith("https://"):
        prefix = r"https:\/\/"
        host = entry.substr(8)
      else:
        prefix = r"https?:\/\/"
        host = entry
      matchExpressions.add("^" & prefix & escapeOriginPattern(host) & "$")
    if matchExpressions.len == 0:
      # Only separators/whitespace were given: same as an empty setting.
      return none(Regex)
    return some(re(matchExpressions.join("|"), flags))
  except RegexError as exc:
    error "Failed to compile regex", source=allowedOrigin, err=exc.msg
    return none(Regex)
proc originsMatch(originHandler: OriginHandlerMiddlewareRef,
                  requestOrigin: string): bool =
  ## Checks `requestOrigin` against the handler's compiled allow-list.
  ## Yields false whenever no allow-list has been compiled.
  let matcher = originHandler.allowedOriginMatcher
  matcher.isSome() and requestOrigin.match(matcher.get())
proc originMiddlewareProc(
    middleware: HttpServerMiddlewareRef,
    reqfence: RequestFence,
    nextHandler: HttpProcessCallback2
): Future[HttpResponseRef] {.async: (raises: [CancelledError]).} =
  ## Middleware entry point. It checks the `Origin` request header and adds
  ## CORS headers to the response before delegating to `nextHandler`.
  ##
  ## * A request that already failed upstream is passed on unchanged.
  ## * If no allow-list is configured, or no `Origin` header is present,
  ##   nothing is added.
  ## * With exactly one `Origin` header: if every origin is allowed,
  ##   `Access-Control-Allow-Origin: *` is added. Otherwise a matching origin
  ##   is echoed back together with `Vary: Origin`, and a non-matching origin
  ##   is rejected with 403.
  ## * With several `Origin` headers the request is rejected with 400.
  ## If a rejection cannot be written, chronos' `defaultResponse` is used.
  if reqfence.isErr():
    # Errors detected before this middleware are left to the final handler.
    return await nextHandler(reqfence)

  let
    self = OriginHandlerMiddlewareRef(middleware)
    request = reqfence.get()
    response = request.getResponse()

  if self.allowedOriginMatcher.isSome():
    let origins = request.headers.getList("Origin")
    try:
      case origins.len
      of 0:
        discard
      of 1:
        let origin = origins[0]
        if self.everyOriginAllowed:
          response.addHeader("Access-Control-Allow-Origin", "*")
        elif self.originsMatch(origin):
          # Vary: Origin keeps caches from serving a response carrying one
          # origin's CORS header to another origin (cache poisoning):
          # https://textslashplain.com/2018/08/02/cors-and-vary/
          response.addHeader("Vary", "Origin")
          response.addHeader("Access-Control-Allow-Origin", origin)
        else:
          return await request.respond(Http403, "Origin not allowed")
      else:
        return await request.respond(
          Http400, "Only a single Origin header must be specified")
    except HttpWriteError as exc:
      # The rejection could not be sent; fall back to the default handler.
      return defaultResponse(exc)

  return await nextHandler(reqfence)
proc new*(t: typedesc[OriginHandlerMiddlewareRef],
          allowedOrigin: Option[string] = none(string)
         ): HttpServerMiddlewareRef =
  ## Builds the CORS origin-checking middleware.
  ##
  ## `allowedOrigin` is a comma-separated allow-list, or "*". When it is
  ## `none` or empty, no allow-list is compiled and the middleware does no
  ## origin checks.
  let
    originMatcher = compileOriginMatcher(allowedOrigin)
    allowsAnyOrigin = isEveryOriginAllowed(allowedOrigin)
  HttpServerMiddlewareRef(
    OriginHandlerMiddlewareRef(
      allowedOriginMatcher: originMatcher,
      everyOriginAllowed: allowsAnyOrigin,
      handler: originMiddlewareProc))