mirror of
https://github.com/waku-org/nwaku.git
synced 2025-01-15 01:14:56 +00:00
126 lines
4.1 KiB
Nim
126 lines
4.1 KiB
Nim
|
when (NimMajor, NimMinor) >= (1, 4):
  # Every proc below is checked to raise nothing beyond what it declares.
  {.push raises: [].}
else:
  # Older compilers need Defect listed explicitly in the raises list.
  {.push raises: [Defect].}
|
||
|
|
||
|
import
|
||
|
std/[options, strutils, re],
|
||
|
stew/results,
|
||
|
stew/shims/net,
|
||
|
chronicles,
|
||
|
chronos,
|
||
|
chronos/apps/http/httpserver
|
||
|
|
||
|
type
  OriginHandlerMiddlewareRef* = ref object of HttpServerMiddlewareRef
    ## HTTP server middleware that checks the request `Origin` header
    ## against a configured allow-list and sets CORS response headers.
    allowedOriginMatcher: Option[Regex]
      ## Compiled allow-list; `none` disables origin checking entirely.
    everyOriginAllowed: bool
      ## When true, allowed requests get `Access-Control-Allow-Origin: *`.
|
||
|
|
||
|
|
||
|
proc isEveryOriginAllowed(maybeAllowedOrigin: Option[string]): bool =
  ## True only when an origin is configured and it is exactly the wildcard
  ## "*". An absent option, or any other value, yields false.
  result = maybeAllowedOrigin.get("") == "*"
|
||
|
|
||
|
proc compileOriginMatcher(maybeAllowedOrigin: Option[string]): Option[Regex] =
  ## Compiles the configured allowed-origin specification into a single
  ## case-insensitive regex used to validate the request `Origin` header.
  ##
  ## `maybeAllowedOrigin` is a comma-separated list of origins; blanks
  ## around entries are ignored, as are empty entries. An entry starting
  ## with `http://` or `https://` only accepts that scheme; without a scheme
  ## both are accepted. Within an entry `*` matches any run of characters,
  ## `?` matches at most one character, and every other character is matched
  ## literally. The special value `"*"` matches every origin.
  ##
  ## Returns `none(Regex)` when no origin is configured, when the
  ## specification is empty or contains only blank entries, or when the
  ## regex fails to compile (the failure is logged).
  if maybeAllowedOrigin.isNone():
    return none(Regex)

  let allowedOrigin = maybeAllowedOrigin.get()

  if allowedOrigin.len == 0:
    return none(Regex)

  try:
    if allowedOrigin == "*":
      return some(re(r".*", {reIgnoreCase, reExtended}))

    var matchExpressions: seq[string] = @[]

    for rawEntry in allowedOrigin.split(","):
      # Tolerate "a.com, b.com": a leading blank would otherwise defeat the
      # scheme detection below (and reExtended drops pattern whitespace).
      let entry = rawEntry.strip()
      if entry.len == 0:
        # A stray comma must not produce a pattern matching a bare scheme.
        continue

      let (prefix, host) =
        if entry.startsWith("http://"):
          (r"http:\/\/", entry.substr("http://".len))
        elif entry.startsWith("https://"):
          (r"https:\/\/", entry.substr("https://".len))
        else:
          (r"https?:\/\/", entry)

      # Translate the wildcards and escape everything else so that regex
      # metacharacters in the configuration are matched literally.
      var hostPattern = newStringOfCap(host.len * 2)
      for c in host:
        case c
        of '*': hostPattern.add(".*")
        of '?': hostPattern.add(".?")
        else: hostPattern.add(escapeRe($c))

      # Anchor each alternative so it has to match the whole origin.
      matchExpressions.add("^" & prefix & hostPattern & "$")

    if matchExpressions.len == 0:
      # An empty regex would match every origin; treat it as "not configured".
      return none(Regex)

    return some(re(matchExpressions.join("|"), {reIgnoreCase, reExtended}))
  except RegexError:
    let msg = getCurrentExceptionMsg()
    error "Failed to compile regex", source=allowedOrigin, err=msg
    return none(Regex)
|
||
|
|
||
|
proc originsMatch(originHandler: OriginHandlerMiddlewareRef,
                  requestOrigin: string): bool =
  ## Tests `requestOrigin` against the handler's compiled allow-list.
  ## Without a compiled matcher, no origin is considered a match.
  if originHandler.allowedOriginMatcher.isSome():
    let matcher = originHandler.allowedOriginMatcher.get()
    result = requestOrigin.match(matcher)
  else:
    result = false
|
||
|
|
||
|
proc originMiddlewareProc(
    middleware: HttpServerMiddlewareRef,
    reqfence: RequestFence,
    nextHandler: HttpProcessCallback2
): Future[HttpResponseRef] {.async: (raises: [CancelledError]).} =
  ## Middleware entry point: enforces the origin allow-list, then hands the
  ## request on to `nextHandler`.
  ##
  ## * no matcher configured, or no `Origin` header: passed through untouched
  ## * exactly one `Origin` header:
  ##   - wildcard configuration: `Access-Control-Allow-Origin: *` is added
  ##   - allowed origin: `Vary: Origin` is added and the origin is echoed back
  ##   - anything else: answered with 403 "Origin not allowed"
  ## * more than one `Origin` header: answered with 400
  if reqfence.isErr():
    # A failure detected before this middleware ran is left for the final
    # handler to report.
    return await nextHandler(reqfence)

  let
    handler = OriginHandlerMiddlewareRef(middleware)
    request = reqfence.get()
    response = request.getResponse()

  if handler.allowedOriginMatcher.isSome():
    let origins = request.headers.getList("Origin")
    try:
      case origins.len
      of 0:
        discard
      of 1:
        let origin = origins[0]
        if handler.everyOriginAllowed:
          response.addHeader("Access-Control-Allow-Origin", "*")
        elif handler.originsMatch(origin):
          # "Vary: Origin" must be set to prevent potential cache poisoning
          # attacks: https://textslashplain.com/2018/08/02/cors-and-vary/
          response.addHeader("Vary", "Origin")
          response.addHeader("Access-Control-Allow-Origin", origin)
        else:
          return await request.respond(Http403, "Origin not allowed")
      else:
        return await request.respond(
          Http400, "Only a single Origin header must be specified")
    except HttpWriteError as exc:
      # Our own error response could not be written; use the default one.
      return defaultResponse(exc)

  return await nextHandler(reqfence)
|
||
|
|
||
|
proc new*(t: typedesc[OriginHandlerMiddlewareRef],
          allowedOrigin: Option[string] = none(string)
          ): HttpServerMiddlewareRef =
  ## Builds the origin-checking middleware.
  ##
  ## `allowedOrigin` is a comma-separated origin allow-list (supporting `*`
  ## and `?` wildcards), or `"*"` to allow any origin. When it is absent,
  ## empty, or fails to compile, the resulting middleware forwards every
  ## request without origin checks.
  let
    matcher = compileOriginMatcher(allowedOrigin)
    wildcard = isEveryOriginAllowed(allowedOrigin)
  result = HttpServerMiddlewareRef(
    OriginHandlerMiddlewareRef(
      allowedOriginMatcher: matcher,
      everyOriginAllowed: wildcard,
      handler: originMiddlewareProc))
|