nwaku/waku/waku_api/rest/origin_handler.nim
2024-03-16 00:08:47 +01:00

127 lines
4.0 KiB
Nim

# Turn on exception tracking for every routine declared after this point.
# Compilers older than 1.4 get `Defect` listed explicitly in the raises list;
# newer compilers use an empty raises list.
when (NimMajor, NimMinor) < (1, 4):
  {.push raises: [Defect].}
else:
  {.push raises: [].}
import
std/[options, strutils, re],
stew/results,
stew/shims/net,
chronicles,
chronos,
chronos/apps/http/httpserver
type
  OriginHandlerMiddlewareRef* = ref object of HttpServerMiddlewareRef
    ## HTTP server middleware state used to validate the `Origin` request
    ## header and to emit the matching CORS response headers.
    allowedOriginMatcher: Option[Regex]
      ## Compiled pattern of accepted origins; `none` when no origin
      ## restriction has been configured.
    everyOriginAllowed: bool
      ## `true` when the configured allowed origin is exactly `"*"`.
proc isEveryOriginAllowed(maybeAllowedOrigin: Option[string]): bool =
  ## Returns `true` only when an allowed origin is configured and it is
  ## exactly the wildcard `"*"`.
  maybeAllowedOrigin == some("*")
proc compileOriginMatcher(maybeAllowedOrigin: Option[string]): Option[Regex] =
  ## Compiles the configured allowed origin(s) into a single case-insensitive
  ## regular expression.
  ##
  ## `maybeAllowedOrigin` is either the wildcard `"*"` or a comma-separated
  ## list of origins. Each entry may start with `http://` or `https://`;
  ## entries without a scheme accept both. Inside an entry `*` stands for any
  ## run of characters and `?` for at most one character.
  ##
  ## Whitespace around entries is trimmed and empty entries are skipped, so
  ## `"http://a.com, http://b.com"` and `"a.com,"` behave as expected.
  ##
  ## Returns `none(Regex)` when nothing is configured, when the value holds no
  ## usable entry, or when the resulting expression fails to compile (the
  ## failure is logged).
  if maybeAllowedOrigin.isNone():
    return none(Regex)

  let allowedOrigin = maybeAllowedOrigin.get()
  if allowedOrigin.len == 0:
    return none(Regex)

  try:
    if allowedOrigin == "*":
      return some(re(r".*", {reIgnoreCase, reExtended}))

    var matchExpressions: seq[string] = @[]
    for rawEntry in allowedOrigin.split(","):
      # Trim before inspecting the scheme: an entry such as " http://b.com"
      # would otherwise miss the `startsWith` checks below and end up with a
      # second, generic scheme prefix in front of its own.
      let entry = rawEntry.strip()
      if entry.len == 0:
        # Produced by a trailing or doubled comma; carries no origin.
        continue

      var prefix, matchOrigin: string
      if entry.startsWith("http://"):
        prefix = r"http:\/\/"
        matchOrigin = entry.substr(7)
      elif entry.startsWith("https://"):
        prefix = r"https:\/\/"
        matchOrigin = entry.substr(8)
      else:
        prefix = r"https?:\/\/"
        matchOrigin = entry

      # Escape literal dots first, then expand the user-facing wildcards.
      matchOrigin = matchOrigin.replace(".", r"\.")
      matchOrigin = matchOrigin.replace("*", ".*")
      matchOrigin = matchOrigin.replace("?", ".?")
      matchExpressions.add("^" & prefix & matchOrigin & "$")

    if matchExpressions.len == 0:
      # An empty alternation would compile to a pattern that matches every
      # origin, so treat "no usable entry" the same as "not configured".
      return none(Regex)

    return some(re(matchExpressions.join("|"), {reIgnoreCase, reExtended}))
  except RegexError:
    let msg = getCurrentExceptionMsg()
    error "Failed to compile regex", source = allowedOrigin, err = msg
    return none(Regex)
proc originsMatch(
    originHandler: OriginHandlerMiddlewareRef, requestOrigin: string
): bool =
  ## Tells whether `requestOrigin` matches the handler's compiled origin
  ## pattern. Always `false` when the handler has no pattern.
  let matcher = originHandler.allowedOriginMatcher
  if matcher.isSome():
    result = requestOrigin.match(matcher.get())
  else:
    result = false
proc originMiddlewareProc(
    middleware: HttpServerMiddlewareRef,
    reqfence: RequestFence,
    nextHandler: HttpProcessCallback2,
): Future[HttpResponseRef] {.async: (raises: [CancelledError]).} =
  ## Middleware callback that checks the request's `Origin` header against the
  ## configured matcher before passing the request on to `nextHandler`.
  ##
  ## With a matcher configured:
  ## * no `Origin` header: the request is passed through unchanged;
  ## * exactly one `Origin` header: CORS headers are added when it is allowed,
  ##   otherwise the request is answered with 403;
  ## * more than one `Origin` header: the request is answered with 400.
  ##
  ## Without a matcher the request is always passed through.
  if reqfence.isErr():
    # The request already failed before reaching this middleware; the final
    # handler is responsible for dealing with it.
    return await nextHandler(reqfence)

  let
    originHandler = OriginHandlerMiddlewareRef(middleware)
    request = reqfence.get()
    response = request.getResponse()

  if originHandler.allowedOriginMatcher.isSome():
    let origins = request.headers.getList("Origin")
    try:
      case origins.len
      of 0:
        discard
      of 1:
        let requestOrigin = origins[0]
        if originHandler.everyOriginAllowed:
          response.addHeader("Access-Control-Allow-Origin", "*")
        elif originHandler.originsMatch(requestOrigin):
          # `Vary: Origin` must accompany an echoed origin to prevent
          # potential cache poisoning attacks:
          # https://textslashplain.com/2018/08/02/cors-and-vary/
          response.addHeader("Vary", "Origin")
          response.addHeader("Access-Control-Allow-Origin", requestOrigin)
        else:
          return await request.respond(Http403, "Origin not allowed")
      else:
        return await request.respond(
          Http400, "Only a single Origin header must be specified"
        )
    except HttpWriteError as exc:
      # Fall back to the default error response when writing fails.
      return defaultResponse(exc)

  # Hand the request over to the next handler in the chain.
  return await nextHandler(reqfence)
proc new*(
    t: typedesc[OriginHandlerMiddlewareRef],
    allowedOrigin: Option[string] = none(string),
): HttpServerMiddlewareRef =
  ## Creates the origin-checking middleware for `allowedOrigin` and returns
  ## it as a generic `HttpServerMiddlewareRef`.
  ##
  ## The origin matcher and the "every origin allowed" flag are both derived
  ## from `allowedOrigin` once, at construction time.
  let
    matcher = compileOriginMatcher(allowedOrigin)
    allowAll = isEveryOriginAllowed(allowedOrigin)
  HttpServerMiddlewareRef(
    OriginHandlerMiddlewareRef(
      allowedOriginMatcher: matcher,
      everyOriginAllowed: allowAll,
      handler: originMiddlewareProc,
    )
  )