From b02b9608c3c4a4815da39583847dad026d89781d Mon Sep 17 00:00:00 2001 From: Eugene Kabanov Date: Fri, 12 Jan 2024 15:27:36 +0200 Subject: [PATCH] HTTP server middleware implementation. (#483) * HTTP server middleware implementation and test. * Address review comments. * Address review comments. --- chronos/apps/http/httpserver.nim | 339 ++++++++++++++++++++--------- docs/examples/middleware.nim | 130 +++++++++++ docs/src/SUMMARY.md | 1 + docs/src/examples.md | 2 + docs/src/http_server_middleware.md | 102 +++++++++ tests/testhttpserver.nim | 299 ++++++++++++++++++++++++- 6 files changed, 771 insertions(+), 102 deletions(-) create mode 100644 docs/examples/middleware.nim create mode 100644 docs/src/http_server_middleware.md diff --git a/chronos/apps/http/httpserver.nim b/chronos/apps/http/httpserver.nim index 9646956..c716d14 100644 --- a/chronos/apps/http/httpserver.nim +++ b/chronos/apps/http/httpserver.nim @@ -11,11 +11,14 @@ import std/[tables, uri, strutils] import stew/[base10], httputils, results -import ../../asyncloop, ../../asyncsync +import ../../[asyncloop, asyncsync] import ../../streams/[asyncstream, boundstream, chunkstream] import "."/[httptable, httpcommon, multipart] +from ../../transports/common import TransportAddress, ServerFlags, `$`, `==` + export asyncloop, asyncsync, httptable, httpcommon, httputils, multipart, asyncstream, boundstream, chunkstream, uri, tables, results +export TransportAddress, ServerFlags, `$`, `==` type HttpServerFlags* {.pure.} = enum @@ -107,6 +110,7 @@ type maxRequestBodySize*: int processCallback*: HttpProcessCallback2 createConnCallback*: HttpConnectionCallback + middlewares: seq[HttpProcessCallback2] HttpServerRef* = ref HttpServer @@ -158,6 +162,16 @@ type HttpConnectionRef* = ref HttpConnection + MiddlewareHandleCallback* = proc( + middleware: HttpServerMiddlewareRef, request: RequestFence, + handler: HttpProcessCallback2): Future[HttpResponseRef] {. + async: (raises: [CancelledError]).} + + HttpServerMiddleware* = object of RootObj + handler*: MiddlewareHandleCallback + + HttpServerMiddlewareRef* = ref HttpServerMiddleware + ByteChar* = string | seq[byte] proc init(htype: typedesc[HttpProcessError], error: HttpServerError, @@ -175,6 +189,8 @@ proc init(htype: typedesc[HttpProcessError], proc defaultResponse*(exc: ref CatchableError): HttpResponseRef +proc defaultResponse*(msg: HttpMessage): HttpResponseRef + proc new(htype: typedesc[HttpConnectionHolderRef], server: HttpServerRef, transp: StreamTransport, connectionId: string): HttpConnectionHolderRef = @@ -188,20 +204,54 @@ proc createConnection(server: HttpServerRef, transp: StreamTransport): Future[HttpConnectionRef] {. async: (raises: [CancelledError, HttpConnectionError]).} -proc new*(htype: typedesc[HttpServerRef], - address: TransportAddress, - processCallback: HttpProcessCallback2, - serverFlags: set[HttpServerFlags] = {}, - socketFlags: set[ServerFlags] = {ReuseAddr}, - serverUri = Uri(), - serverIdent = "", - maxConnections: int = -1, - bufferSize: int = 4096, - backlogSize: int = DefaultBacklogSize, - httpHeadersTimeout = 10.seconds, - maxHeadersSize: int = 8192, - maxRequestBodySize: int = 1_048_576, - dualstack = DualStackType.Auto): HttpResult[HttpServerRef] = +proc prepareMiddlewares( + requestProcessCallback: HttpProcessCallback2, + middlewares: openArray[HttpServerMiddlewareRef] + ): seq[HttpProcessCallback2] = + var + handlers: seq[HttpProcessCallback2] + currentHandler = requestProcessCallback + + if len(middlewares) == 0: + return handlers + + let mws = @middlewares + handlers = newSeq[HttpProcessCallback2](len(mws)) + + for index in countdown(len(mws) - 1, 0): + let processor = + block: + var res: HttpProcessCallback2 + closureScope: + let + middleware = mws[index] + realHandler = currentHandler + res = + proc(request: RequestFence): Future[HttpResponseRef] {. + async: (raises: [CancelledError], raw: true).} = + middleware.handler(middleware, request, realHandler) + res + handlers[index] = processor + currentHandler = processor + handlers + +proc new*( + htype: typedesc[HttpServerRef], + address: TransportAddress, + processCallback: HttpProcessCallback2, + serverFlags: set[HttpServerFlags] = {}, + socketFlags: set[ServerFlags] = {ReuseAddr}, + serverUri = Uri(), + serverIdent = "", + maxConnections: int = -1, + bufferSize: int = 4096, + backlogSize: int = DefaultBacklogSize, + httpHeadersTimeout = 10.seconds, + maxHeadersSize: int = 8192, + maxRequestBodySize: int = 1_048_576, + dualstack = DualStackType.Auto, + middlewares: openArray[HttpServerMiddlewareRef] = [] + ): HttpResult[HttpServerRef] = let serverUri = if len(serverUri.hostname) > 0: @@ -240,24 +290,28 @@ proc new*(htype: typedesc[HttpServerRef], # else: # nil lifetime: newFuture[void]("http.server.lifetime"), - connections: initOrderedTable[string, HttpConnectionHolderRef]() + connections: initOrderedTable[string, HttpConnectionHolderRef](), + middlewares: prepareMiddlewares(processCallback, middlewares) ) ok(res) -proc new*(htype: typedesc[HttpServerRef], - address: TransportAddress, - processCallback: HttpProcessCallback, - serverFlags: set[HttpServerFlags] = {}, - socketFlags: set[ServerFlags] = {ReuseAddr}, - serverUri = Uri(), - serverIdent = "", - maxConnections: int = -1, - bufferSize: int = 4096, - backlogSize: int = DefaultBacklogSize, - httpHeadersTimeout = 10.seconds, - maxHeadersSize: int = 8192, - maxRequestBodySize: int = 1_048_576, - dualstack = DualStackType.Auto): HttpResult[HttpServerRef] {. +proc new*( + htype: typedesc[HttpServerRef], + address: TransportAddress, + processCallback: HttpProcessCallback, + serverFlags: set[HttpServerFlags] = {}, + socketFlags: set[ServerFlags] = {ReuseAddr}, + serverUri = Uri(), + serverIdent = "", + maxConnections: int = -1, + bufferSize: int = 4096, + backlogSize: int = DefaultBacklogSize, + httpHeadersTimeout = 10.seconds, + maxHeadersSize: int = 8192, + maxRequestBodySize: int = 1_048_576, + dualstack = DualStackType.Auto, + middlewares: openArray[HttpServerMiddlewareRef] = [] + ): HttpResult[HttpServerRef] {. deprecated: "Callback could raise only CancelledError, annotate with " & "{.async: (raises: [CancelledError]).}".} = @@ -273,7 +327,7 @@ proc new*(htype: typedesc[HttpServerRef], HttpServerRef.new(address, wrap, serverFlags, socketFlags, serverUri, serverIdent, maxConnections, bufferSize, backlogSize, httpHeadersTimeout, maxHeadersSize, maxRequestBodySize, - dualstack) + dualstack, middlewares) proc getServerFlags(req: HttpRequestRef): set[HttpServerFlags] = var defaultFlags: set[HttpServerFlags] = {} @@ -345,6 +399,18 @@ proc defaultResponse*(exc: ref CatchableError): HttpResponseRef = else: HttpResponseRef(state: HttpResponseState.ErrorCode, status: Http503) +proc defaultResponse*(msg: HttpMessage): HttpResponseRef = + HttpResponseRef(state: HttpResponseState.ErrorCode, status: msg.code) + +proc defaultResponse*(err: HttpProcessError): HttpResponseRef = + HttpResponseRef(state: HttpResponseState.ErrorCode, status: err.code) + +proc dropResponse*(): HttpResponseRef = + HttpResponseRef(state: HttpResponseState.Failed) + +proc codeResponse*(status: HttpCode): HttpResponseRef = + HttpResponseRef(state: HttpResponseState.ErrorCode, status: status) + proc dumbResponse*(): HttpResponseRef {. deprecated: "Please use defaultResponse() instead".} = ## Create an empty response to return when request processor got no request. @@ -362,29 +428,21 @@ proc hasBody*(request: HttpRequestRef): bool = request.requestFlags * {HttpRequestFlags.BoundBody, HttpRequestFlags.UnboundBody} != {} -proc prepareRequest(conn: HttpConnectionRef, - req: HttpRequestHeader): HttpResultMessage[HttpRequestRef] = - var request = HttpRequestRef(connection: conn, state: HttpState.Alive) +func new(t: typedesc[HttpRequestRef], conn: HttpConnectionRef): HttpRequestRef = + HttpRequestRef(connection: conn, state: HttpState.Alive) - if req.version notin {HttpVersion10, HttpVersion11}: - return err(HttpMessage.init(Http505, "Unsupported HTTP protocol version")) +proc updateRequest*(request: HttpRequestRef, scheme: string, meth: HttpMethod, + version: HttpVersion, requestUri: string, + headers: HttpTable): HttpResultMessage[void] = + ## Update HTTP request object using base request object with new properties. - request.scheme = - if HttpServerFlags.Secure in conn.server.flags: - "https" - else: - "http" - - request.version = req.version - request.meth = req.meth - - request.rawPath = - block: - let res = req.uri() - if len(res) == 0: - return err(HttpMessage.init(Http400, "Invalid request URI")) - res + # Store request version and call method. + request.scheme = scheme + request.version = version + request.meth = meth + # Processing request's URI + request.rawPath = requestUri request.uri = if request.rawPath != "*": let uri = parseUri(request.rawPath) @@ -396,10 +454,11 @@ proc prepareRequest(conn: HttpConnectionRef, uri.path = "*" uri + # Conversion of request query string to HttpTable. request.query = block: let queryFlags = - if QueryCommaSeparatedArray in conn.server.flags: + if QueryCommaSeparatedArray in request.connection.server.flags: {QueryParamsFlag.CommaSeparatedArray} else: {} @@ -408,22 +467,8 @@ proc prepareRequest(conn: HttpConnectionRef, table.add(key, value) table - request.headers = - block: - var table = HttpTable.init() - # Retrieve headers and values - for key, value in req.headers(): - table.add(key, value) - # Validating HTTP request headers - # Some of the headers must be present only once. - if table.count(ContentTypeHeader) > 1: - return err(HttpMessage.init(Http400, "Multiple Content-Type headers")) - if table.count(ContentLengthHeader) > 1: - return err(HttpMessage.init(Http400, "Multiple Content-Length headers")) - if table.count(TransferEncodingHeader) > 1: - return err(HttpMessage.init(Http400, - "Multuple Transfer-Encoding headers")) - table + # Store request headers + request.headers = headers # Preprocessing "Content-Encoding" header. request.contentEncoding = @@ -443,15 +488,17 @@ proc prepareRequest(conn: HttpConnectionRef, # steps to reveal information about body. request.contentLength = if ContentLengthHeader in request.headers: + # Request headers has `Content-Length` header present. let length = request.headers.getInt(ContentLengthHeader) if length != 0: if request.meth == MethodTrace: let msg = "TRACE requests could not have request body" return err(HttpMessage.init(Http400, msg)) - # Because of coversion to `int` we should avoid unexpected OverflowError. + # Because of coversion to `int` we should avoid unexpected + # OverflowError. if length > uint64(high(int)): return err(HttpMessage.init(Http413, "Unsupported content length")) - if length > uint64(conn.server.maxRequestBodySize): + if length > uint64(request.connection.server.maxRequestBodySize): return err(HttpMessage.init(Http413, "Content length exceeds limits")) request.requestFlags.incl(HttpRequestFlags.BoundBody) int(length) @@ -459,6 +506,7 @@ proc prepareRequest(conn: HttpConnectionRef, 0 else: if TransferEncodingFlags.Chunked in request.transferEncoding: + # Request headers has "Transfer-Encoding: chunked" header present. if request.meth == MethodTrace: let msg = "TRACE requests could not have request body" return err(HttpMessage.init(Http400, msg)) @@ -466,8 +514,9 @@ proc prepareRequest(conn: HttpConnectionRef, 0 if request.hasBody(): - # If request has body, we going to understand how its encoded. + # If the request has a body, we will determine how it is encoded. if ContentTypeHeader in request.headers: + # Request headers has "Content-Type" header present. let contentType = getContentType(request.headers.getList(ContentTypeHeader)).valueOr: let msg = "Incorrect or missing Content-Type header" @@ -477,12 +526,67 @@ proc prepareRequest(conn: HttpConnectionRef, elif contentType == MultipartContentType: request.requestFlags.incl(HttpRequestFlags.MultipartForm) request.contentTypeData = Opt.some(contentType) - + # If `Expect` header is present, we will handle expectation procedure. if ExpectHeader in request.headers: let expectHeader = request.headers.getString(ExpectHeader) if strip(expectHeader).toLowerAscii() == "100-continue": request.requestFlags.incl(HttpRequestFlags.ClientExpect) + ok() + +proc updateRequest*(request: HttpRequestRef, meth: HttpMethod, + requestUri: string, + headers: HttpTable): HttpResultMessage[void] = + ## Update HTTP request object using base request object with new properties. + updateRequest(request, request.scheme, meth, request.version, requestUri, + headers) + +proc updateRequest*(request: HttpRequestRef, requestUri: string, + headers: HttpTable): HttpResultMessage[void] = + ## Update HTTP request object using base request object with new properties. + updateRequest(request, request.scheme, request.meth, request.version, + requestUri, headers) + +proc updateRequest*(request: HttpRequestRef, + requestUri: string): HttpResultMessage[void] = + ## Update HTTP request object using base request object with new properties. + updateRequest(request, request.scheme, request.meth, request.version, + requestUri, request.headers) + +proc updateRequest*(request: HttpRequestRef, + headers: HttpTable): HttpResultMessage[void] = + ## Update HTTP request object using base request object with new properties. + updateRequest(request, request.scheme, request.meth, request.version, + request.rawPath, headers) + +proc prepareRequest(conn: HttpConnectionRef, + req: HttpRequestHeader): HttpResultMessage[HttpRequestRef] = + let + request = HttpRequestRef.new(conn) + scheme = + if HttpServerFlags.Secure in conn.server.flags: + "https" + else: + "http" + headers = + block: + var table = HttpTable.init() + # Retrieve headers and values + for key, value in req.headers(): + table.add(key, value) + # Validating HTTP request headers + # Some of the headers must be present only once. + if table.count(ContentTypeHeader) > 1: + return err(HttpMessage.init(Http400, + "Multiple Content-Type headers")) + if table.count(ContentLengthHeader) > 1: + return err(HttpMessage.init(Http400, + "Multiple Content-Length headers")) + if table.count(TransferEncodingHeader) > 1: + return err(HttpMessage.init(Http400, + "Multuple Transfer-Encoding headers")) + table + ? updateRequest(request, scheme, req.meth, req.version, req.uri(), headers) trackCounter(HttpServerRequestTrackerName) ok(request) @@ -736,16 +840,19 @@ proc sendDefaultResponse( # Response was ignored, so we respond with not found. await conn.sendErrorResponse(version, Http404, keepConnection.toBool()) + response.setResponseState(HttpResponseState.Finished) keepConnection of HttpResponseState.Prepared: # Response was prepared but not sent, so we can respond with some # error code await conn.sendErrorResponse(HttpVersion11, Http409, keepConnection.toBool()) + response.setResponseState(HttpResponseState.Finished) keepConnection of HttpResponseState.ErrorCode: # Response with error code await conn.sendErrorResponse(version, response.status, false) + response.setResponseState(HttpResponseState.Finished) HttpProcessExitType.Immediate of HttpResponseState.Sending, HttpResponseState.Failed, HttpResponseState.Cancelled: @@ -755,6 +862,7 @@ proc sendDefaultResponse( # Response was ignored, so we respond with not found. await conn.sendErrorResponse(version, Http404, keepConnection.toBool()) + response.setResponseState(HttpResponseState.Finished) keepConnection of HttpResponseState.Finished: keepConnection @@ -878,6 +986,25 @@ proc getRemoteAddress(connection: HttpConnectionRef): Opt[TransportAddress] = if isNil(connection): return Opt.none(TransportAddress) getRemoteAddress(connection.transp) +proc getLocalAddress(transp: StreamTransport): Opt[TransportAddress] = + if isNil(transp): return Opt.none(TransportAddress) + try: + Opt.some(transp.localAddress()) + except TransportOsError: + Opt.none(TransportAddress) + +proc getLocalAddress(connection: HttpConnectionRef): Opt[TransportAddress] = + if isNil(connection): return Opt.none(TransportAddress) + getLocalAddress(connection.transp) + +proc remote*(request: HttpRequestRef): Opt[TransportAddress] = + ## Returns remote address of HTTP request's connection. + request.connection.getRemoteAddress() + +proc local*(request: HttpRequestRef): Opt[TransportAddress] = + ## Returns local address of HTTP request's connection. + request.connection.getLocalAddress() + proc getRequestFence*(server: HttpServerRef, connection: HttpConnectionRef): Future[RequestFence] {. async: (raises: []).} = @@ -920,6 +1047,14 @@ proc getConnectionFence*(server: HttpServerRef, ConnectionFence.err(HttpProcessError.init( HttpServerError.DisconnectError, exc, address, Http400)) +proc invokeProcessCallback(server: HttpServerRef, + req: RequestFence): Future[HttpResponseRef] {. + async: (raw: true, raises: [CancelledError]).} = + if len(server.middlewares) > 0: + server.middlewares[0](req) + else: + server.processCallback(req) + proc processRequest(server: HttpServerRef, connection: HttpConnectionRef, connId: string): Future[HttpProcessExitType] {. @@ -941,7 +1076,7 @@ proc processRequest(server: HttpServerRef, try: let response = try: - await connection.server.processCallback(requestFence) + await invokeProcessCallback(connection.server, requestFence) except CancelledError: # Cancelled, exiting return HttpProcessExitType.Immediate @@ -962,7 +1097,7 @@ proc processLoop(holder: HttpConnectionHolderRef) {.async: (raises: []).} = if res.isErr(): if res.error.kind != HttpServerError.InterruptError: discard await noCancel( - server.processCallback(RequestFence.err(res.error))) + invokeProcessCallback(server, RequestFence.err(res.error))) server.connections.del(connectionId) return res.get() @@ -1160,31 +1295,6 @@ proc post*(req: HttpRequestRef): Future[HttpTable] {. elif HttpRequestFlags.UnboundBody in req.requestFlags: raiseHttpProtocolError(Http400, "Unsupported request body") -proc setHeader*(resp: HttpResponseRef, key, value: string) = - ## Sets value of header ``key`` to ``value``. - doAssert(resp.getResponseState() == HttpResponseState.Empty) - resp.headersTable.set(key, value) - -proc setHeaderDefault*(resp: HttpResponseRef, key, value: string) = - ## Sets value of header ``key`` to ``value``, only if header ``key`` is not - ## present in the headers table. - discard resp.headersTable.hasKeyOrPut(key, value) - -proc addHeader*(resp: HttpResponseRef, key, value: string) = - ## Adds value ``value`` to header's ``key`` value. - doAssert(resp.getResponseState() == HttpResponseState.Empty) - resp.headersTable.add(key, value) - -proc getHeader*(resp: HttpResponseRef, key: string, - default: string = ""): string = - ## Returns value of header with name ``name`` or ``default``, if header is - ## not present in the table. - resp.headersTable.getString(key, default) - -proc hasHeader*(resp: HttpResponseRef, key: string): bool = - ## Returns ``true`` if header with name ``key`` present in the headers table. - key in resp.headersTable - template checkPending(t: untyped) = let currentState = t.getResponseState() doAssert(currentState == HttpResponseState.Empty, @@ -1199,10 +1309,41 @@ template checkStreamResponseState(t: untyped) = {HttpResponseState.Prepared, HttpResponseState.Sending}, "Response is in the wrong state") +template checkResponseCanBeModified(t: untyped) = + doAssert(t.getResponseState() in + {HttpResponseState.Empty, HttpResponseState.ErrorCode}, + "Response could not be modified at this stage") + template checkPointerLength(t1, t2: untyped) = doAssert(not(isNil(t1)), "pbytes must not be nil") doAssert(t2 >= 0, "nbytes should be bigger or equal to zero") +proc setHeader*(resp: HttpResponseRef, key, value: string) = + ## Sets value of header ``key`` to ``value``. + checkResponseCanBeModified(resp) + resp.headersTable.set(key, value) + +proc setHeaderDefault*(resp: HttpResponseRef, key, value: string) = + ## Sets value of header ``key`` to ``value``, only if header ``key`` is not + ## present in the headers table. + checkResponseCanBeModified(resp) + discard resp.headersTable.hasKeyOrPut(key, value) + +proc addHeader*(resp: HttpResponseRef, key, value: string) = + ## Adds value ``value`` to header's ``key`` value. + checkResponseCanBeModified(resp) + resp.headersTable.add(key, value) + +proc getHeader*(resp: HttpResponseRef, key: string, + default: string = ""): string = + ## Returns value of header with name ``name`` or ``default``, if header is + ## not present in the table. + resp.headersTable.getString(key, default) + +proc hasHeader*(resp: HttpResponseRef, key: string): bool = + ## Returns ``true`` if header with name ``key`` present in the headers table. + key in resp.headersTable + func createHeaders(resp: HttpResponseRef): string = var answer = $(resp.version) & " " & $(resp.status) & "\r\n" for k, v in resp.headersTable.stringItems(): diff --git a/docs/examples/middleware.nim b/docs/examples/middleware.nim new file mode 100644 index 0000000..9d06a89 --- /dev/null +++ b/docs/examples/middleware.nim @@ -0,0 +1,130 @@ +import chronos/apps/http/httpserver + +{.push raises: [].} + +proc firstMiddlewareHandler( + middleware: HttpServerMiddlewareRef, + reqfence: RequestFence, + nextHandler: HttpProcessCallback2 +): Future[HttpResponseRef] {.async: (raises: [CancelledError]).} = + if reqfence.isErr(): + # Ignore request errors + return await nextHandler(reqfence) + + let request = reqfence.get() + var headers = request.headers + + if request.uri.path.startsWith("/path/to/hidden/resources"): + headers.add("X-Filter", "drop") + elif request.uri.path.startsWith("/path/to/blocked/resources"): + headers.add("X-Filter", "block") + else: + headers.add("X-Filter", "pass") + + # Updating request by adding new HTTP header `X-Filter`. + let res = request.updateRequest(headers) + if res.isErr(): + # We use default error handler in case of error which will respond with + # proper HTTP status code error. + return defaultResponse(res.error) + + # Calling next handler. + await nextHandler(reqfence) + +proc secondMiddlewareHandler( + middleware: HttpServerMiddlewareRef, + reqfence: RequestFence, + nextHandler: HttpProcessCallback2 +): Future[HttpResponseRef] {.async: (raises: [CancelledError]).} = + if reqfence.isErr(): + # Ignore request errors + return await nextHandler(reqfence) + + let + request = reqfence.get() + filtered = request.headers.getString("X-Filter", "pass") + + if filtered == "drop": + # Force HTTP server to drop connection with remote peer. + dropResponse() + elif filtered == "block": + # Force HTTP server to respond with HTTP `404 Not Found` error code. + codeResponse(Http404) + else: + # Calling next handler. + await nextHandler(reqfence) + +proc thirdMiddlewareHandler( + middleware: HttpServerMiddlewareRef, + reqfence: RequestFence, + nextHandler: HttpProcessCallback2 +): Future[HttpResponseRef] {.async: (raises: [CancelledError]).} = + if reqfence.isErr(): + # Ignore request errors + return await nextHandler(reqfence) + + let request = reqfence.get() + echo "QUERY = [", request.rawPath, "]" + echo request.headers + try: + if request.uri.path == "/path/to/plugin/resources/page1": + await request.respond(Http200, "PLUGIN PAGE1") + elif request.uri.path == "/path/to/plugin/resources/page2": + await request.respond(Http200, "PLUGIN PAGE2") + else: + # Calling next handler. + await nextHandler(reqfence) + except HttpWriteError as exc: + # We use default error handler if we unable to send response. + defaultResponse(exc) + +proc mainHandler( + reqfence: RequestFence +): Future[HttpResponseRef] {.async: (raises: [CancelledError]).} = + if reqfence.isErr(): + return defaultResponse() + + let request = reqfence.get() + try: + if request.uri.path == "/path/to/original/page1": + await request.respond(Http200, "ORIGINAL PAGE1") + elif request.uri.path == "/path/to/original/page2": + await request.respond(Http200, "ORIGINAL PAGE2") + else: + # Force HTTP server to respond with `404 Not Found` status code. + codeResponse(Http404) + except HttpWriteError as exc: + defaultResponse(exc) + +proc middlewareExample() {.async: (raises: []).} = + let + middlewares = [ + HttpServerMiddlewareRef(handler: firstMiddlewareHandler), + HttpServerMiddlewareRef(handler: secondMiddlewareHandler), + HttpServerMiddlewareRef(handler: thirdMiddlewareHandler) + ] + socketFlags = {ServerFlags.TcpNoDelay, ServerFlags.ReuseAddr} + boundAddress = + if isAvailable(AddressFamily.IPv6): + AnyAddress6 + else: + AnyAddress + res = HttpServerRef.new(boundAddress, mainHandler, + socketFlags = socketFlags, + middlewares = middlewares) + + doAssert(res.isOk(), "Unable to start HTTP server") + let server = res.get() + server.start() + let address = server.instance.localAddress() + echo "HTTP server running on ", address + try: + await server.join() + except CancelledError: + discard + finally: + await server.stop() + await server.closeWait() + +when isMainModule: + waitFor(middlewareExample()) diff --git a/docs/src/SUMMARY.md b/docs/src/SUMMARY.md index 4f2ee56..f834367 100644 --- a/docs/src/SUMMARY.md +++ b/docs/src/SUMMARY.md @@ -8,6 +8,7 @@ - [Errors and exceptions](./error_handling.md) - [Tips, tricks and best practices](./tips.md) - [Porting code to `chronos`](./porting.md) +- [HTTP server middleware](./http_server_middleware.md) # Developer guide diff --git a/docs/src/examples.md b/docs/src/examples.md index c71247c..49c6dc4 100644 --- a/docs/src/examples.md +++ b/docs/src/examples.md @@ -16,3 +16,5 @@ Examples are available in the [`docs/examples/`](https://github.com/status-im/ni * [httpget](https://github.com/status-im/nim-chronos/tree/master/docs/examples/httpget.nim) - Downloading a web page using the http client * [twogets](https://github.com/status-im/nim-chronos/tree/master/docs/examples/twogets.nim) - Download two pages concurrently +* [middleware](https://github.com/status-im/nim-chronos/tree/master/docs/examples/middleware.nim) +- Deploy multiple HTTP server middlewares diff --git a/docs/src/http_server_middleware.md b/docs/src/http_server_middleware.md new file mode 100644 index 0000000..6edd9b5 --- /dev/null +++ b/docs/src/http_server_middleware.md @@ -0,0 +1,102 @@ +## HTTP server middleware + +Chronos provides a powerful mechanism for customizing HTTP request handlers via +middlewares. + +A middleware is a coroutine that can modify, block or filter HTTP request. + +Single HTTP server could support unlimited number of middlewares, but you need to consider that each request in worst case could go through all the middlewares, and therefore a huge number of middlewares can have a significant impact on HTTP server performance. + +Order of middlewares is also important: right after HTTP server has received request, it will be sent to the first middleware in list, and each middleware will be responsible for passing control to other middlewares. Therefore, when building a list, it would be a good idea to place the request handlers at the end of the list, while keeping the middleware that could block or modify the request at the beginning of the list. + +Middleware could also modify HTTP server request, and these changes will be visible to all handlers (either middlewares or the original request handler). This can be done using the following helpers: + +```nim + proc updateRequest*(request: HttpRequestRef, scheme: string, meth: HttpMethod, + version: HttpVersion, requestUri: string, + headers: HttpTable): HttpResultMessage[void] + + proc updateRequest*(request: HttpRequestRef, meth: HttpMethod, + requestUri: string, + headers: HttpTable): HttpResultMessage[void] + + proc updateRequest*(request: HttpRequestRef, requestUri: string, + headers: HttpTable): HttpResultMessage[void] + + proc updateRequest*(request: HttpRequestRef, + requestUri: string): HttpResultMessage[void] + + proc updateRequest*(request: HttpRequestRef, + headers: HttpTable): HttpResultMessage[void] +``` + +As you can see all the HTTP request parameters could be modified: request method, version, request path and request headers. + +Middleware could also use helpers to obtain more information about remote and local addresses of request's connection (this could be helpful when you need to do some IP address filtering). + +```nim + proc remote*(request: HttpRequestRef): Opt[TransportAddress] + ## Returns remote address of HTTP request's connection. + proc local*(request: HttpRequestRef): Opt[TransportAddress] = + ## Returns local address of HTTP request's connection. +``` + +Every middleware is the coroutine which looks like this: + +```nim + proc middlewareHandler( + middleware: HttpServerMiddlewareRef, + reqfence: RequestFence, + nextHandler: HttpProcessCallback2 + ): Future[HttpResponseRef] {.async: (raises: [CancelledError]).} = +``` + +Where `middleware` argument is the object which could hold some specific values, `reqfence` is HTTP request which is enclosed with HTTP server error information and `nextHandler` is reference to next request handler, it could be either middleware handler or the original request processing callback handler. + +```nim + await nextHandler(reqfence) +``` + +You should perform await for the response from the `nextHandler(reqfence)`. Usually you should call next handler when you dont want to handle request or you dont know how to handle it, for example: + +```nim + proc middlewareHandler( + middleware: HttpServerMiddlewareRef, + reqfence: RequestFence, + nextHandler: HttpProcessCallback2 + ): Future[HttpResponseRef] {.async: (raises: [CancelledError]).} = + if reqfence.isErr(): + # We dont know or do not want to handle failed requests, so we call next handler. + return await nextHandler(reqfence) + let request = reqfence.get() + if request.uri.path == "/path/we/able/to/respond": + try: + # Sending some response. + await request.respond(Http200, "TEST") + except HttpWriteError as exc: + # We could also return default response for exception or other types of error. + defaultResponse(exc) + elif request.uri.path == "/path/for/rewrite": + # We going to modify request object for this request, next handler will receive it with different request path. + let res = request.updateRequest("/path/to/new/location") + if res.isErr(): + return defaultResponse(res.error) + await nextHandler(reqfence) + elif request.uri.path == "/restricted/path": + if request.remote().isNone(): + # We can't obtain remote address, so we force HTTP server to respond with `401 Unauthorized` status code. + return codeResponse(Http401) + if $(request.remote().get()).startsWith("127.0.0.1"): + # Remote peer's address starts with "127.0.0.1", sending proper response. + await request.respond(Http200, "AUTHORIZED") + else: + # Force HTTP server to respond with `403 Forbidden` status code. + codeResponse(Http403) + elif request.uri.path == "/blackhole": + # Force HTTP server to drop connection with remote peer. + dropResponse() + else: + # All other requests should be handled by somebody else. + await nextHandler(reqfence) +``` + diff --git a/tests/testhttpserver.nim b/tests/testhttpserver.nim index 0183f1b..91064f5 100644 --- a/tests/testhttpserver.nim +++ b/tests/testhttpserver.nim @@ -18,9 +18,16 @@ suite "HTTP server testing suite": TooBigTest = enum GetBodyTest, ConsumeBodyTest, PostUrlTest, PostMultipartTest TestHttpResponse = object + status: int headers: HttpTable data: string + FirstMiddlewareRef = ref object of HttpServerMiddlewareRef + someInteger: int + + SecondMiddlewareRef = ref object of HttpServerMiddlewareRef + someString: string + proc httpClient(address: TransportAddress, data: string): Future[string] {.async.} = var transp: StreamTransport @@ -50,7 +57,7 @@ suite "HTTP server testing suite": zeroMem(addr buffer[0], len(buffer)) await transp.readExactly(addr buffer[0], length) let data = bytesToString(buffer.toOpenArray(0, length - 1)) - let headers = + let (status, headers) = block: let resp = parseResponse(hdata, false) if resp.failed(): @@ -58,8 +65,38 @@ suite "HTTP server testing suite": var res = HttpTable.init() for key, value in resp.headers(hdata): res.add(key, value) - res - return TestHttpResponse(headers: headers, data: data) + (resp.code, res) + TestHttpResponse(status: status, headers: headers, data: data) + + proc httpClient3(address: TransportAddress, + data: string): Future[TestHttpResponse] {.async.} = + var + transp: StreamTransport + buffer = newSeq[byte](4096) + sep = @[0x0D'u8, 0x0A'u8, 0x0D'u8, 0x0A'u8] + try: + transp = await connect(address) + if len(data) > 0: + let wres = await transp.write(data) + if wres != len(data): + raise newException(ValueError, "Unable to write full request") + let hres = await transp.readUntil(addr buffer[0], len(buffer), sep) + var hdata = @buffer + hdata.setLen(hres) + var rres = bytesToString(await transp.read()) + let (status, headers) = + block: + let resp = parseResponse(hdata, false) + if resp.failed(): + raise newException(ValueError, "Unable to decode response headers") + var res = HttpTable.init() + for key, value in resp.headers(hdata): + res.add(key, value) + (resp.code, res) + TestHttpResponse(status: status, headers: headers, data: rres) + finally: + if not(isNil(transp)): + await closeWait(transp) proc testTooBigBodyChunked(operation: TooBigTest): Future[bool] {.async.} = var serverRes = false @@ -1490,5 +1527,261 @@ suite "HTTP server testing suite": await server.stop() await server.closeWait() + asyncTest "HTTP middleware request filtering test": + proc init(t: typedesc[FirstMiddlewareRef], + data: int): HttpServerMiddlewareRef = + proc shandler( + middleware: HttpServerMiddlewareRef, + reqfence: RequestFence, + nextHandler: HttpProcessCallback2 + ): Future[HttpResponseRef] {.async: (raises: [CancelledError]).} = + let mw = FirstMiddlewareRef(middleware) + if reqfence.isErr(): + # Our handler is not supposed to handle request errors, so we + # call next handler in sequence which could process errors. + return await nextHandler(reqfence) + + let request = reqfence.get() + if request.uri.path == "/first": + # This is request we are waiting for, so we going to process it. + try: + await request.respond(Http200, $mw.someInteger) + except HttpWriteError as exc: + defaultResponse(exc) + else: + # We know nothing about request's URI, so we pass this request to the + # next handler which could process such request. + await nextHandler(reqfence) + + HttpServerMiddlewareRef( + FirstMiddlewareRef(someInteger: data, handler: shandler)) + + proc init(t: typedesc[SecondMiddlewareRef], + data: string): HttpServerMiddlewareRef = + proc shandler( + middleware: HttpServerMiddlewareRef, + reqfence: RequestFence, + nextHandler: HttpProcessCallback2 + ): Future[HttpResponseRef] {.async: (raises: [CancelledError]).} = + let mw = SecondMiddlewareRef(middleware) + if reqfence.isErr(): + # Our handler is not supposed to handle request errors, so we + # call next handler in sequence which could process errors. + return await nextHandler(reqfence) + + let request = reqfence.get() + + if request.uri.path == "/second": + # This is request we are waiting for, so we going to process it. + try: + await request.respond(Http200, mw.someString) + except HttpWriteError as exc: + defaultResponse(exc) + else: + # We know nothing about request's URI, so we pass this request to the + # next handler which could process such request. + await nextHandler(reqfence) + + HttpServerMiddlewareRef( + SecondMiddlewareRef(someString: data, handler: shandler)) + + proc process(r: RequestFence): Future[HttpResponseRef] {. + async: (raises: [CancelledError]).} = + if r.isOk(): + let request = r.get() + if request.uri.path == "/test": + try: + await request.respond(Http200, "ORIGIN") + except HttpWriteError as exc: + defaultResponse(exc) + else: + defaultResponse() + else: + defaultResponse() + + let + middlewares = [FirstMiddlewareRef.init(655370), + SecondMiddlewareRef.init("SECOND")] + socketFlags = {ServerFlags.TcpNoDelay, ServerFlags.ReuseAddr} + res = HttpServerRef.new(initTAddress("127.0.0.1:0"), process, + socketFlags = socketFlags, + middlewares = middlewares) + check res.isOk() + + let server = res.get() + server.start() + let + address = server.instance.localAddress() + req1 = "GET /test HTTP/1.1\r\n\r\n" + req2 = "GET /first HTTP/1.1\r\n\r\n" + req3 = "GET /second HTTP/1.1\r\n\r\n" + req4 = "GET /noway HTTP/1.1\r\n\r\n" + resp1 = await httpClient3(address, req1) + resp2 = await httpClient3(address, req2) + resp3 = await httpClient3(address, req3) + resp4 = await httpClient3(address, req4) + + check: + resp1.status == 200 + resp1.data == "ORIGIN" + resp2.status == 200 + resp2.data == "655370" + resp3.status == 200 + resp3.data == "SECOND" + resp4.status == 404 + + await server.stop() + await server.closeWait() + + asyncTest "HTTP middleware request modification test": + proc init(t: typedesc[FirstMiddlewareRef], + data: int): HttpServerMiddlewareRef = + proc shandler( + middleware: HttpServerMiddlewareRef, + reqfence: RequestFence, + nextHandler: HttpProcessCallback2 + ): Future[HttpResponseRef] {.async: (raises: [CancelledError]).} = + let mw = FirstMiddlewareRef(middleware) + if reqfence.isErr(): + # Our handler is not supposed to handle request errors, so we + # call next handler in sequence which could process errors. + return await nextHandler(reqfence) + + let + request = reqfence.get() + modifiedUri = "/modified/" & $mw.someInteger & request.rawPath + var modifiedHeaders = request.headers + modifiedHeaders.add("X-Modified", "test-value") + + let res = request.updateRequest(modifiedUri, modifiedHeaders) + if res.isErr(): + return defaultResponse(res.error) + + # We sending modified request to the next handler. + await nextHandler(reqfence) + + HttpServerMiddlewareRef( + FirstMiddlewareRef(someInteger: data, handler: shandler)) + + proc process(r: RequestFence): Future[HttpResponseRef] {. + async: (raises: [CancelledError]).} = + if r.isOk(): + let request = r.get() + try: + await request.respond(Http200, request.rawPath & ":" & + request.headers.getString("x-modified")) + except HttpWriteError as exc: + defaultResponse(exc) + else: + defaultResponse() + + let + middlewares = [FirstMiddlewareRef.init(655370)] + socketFlags = {ServerFlags.TcpNoDelay, ServerFlags.ReuseAddr} + res = HttpServerRef.new(initTAddress("127.0.0.1:0"), process, + socketFlags = socketFlags, + middlewares = middlewares) + check res.isOk() + + let server = res.get() + server.start() + let + address = server.instance.localAddress() + req1 = "GET /test HTTP/1.1\r\n\r\n" + req2 = "GET /first HTTP/1.1\r\n\r\n" + req3 = "GET /second HTTP/1.1\r\n\r\n" + req4 = "GET /noway HTTP/1.1\r\n\r\n" + resp1 = await httpClient3(address, req1) + resp2 = await httpClient3(address, req2) + resp3 = await httpClient3(address, req3) + resp4 = await httpClient3(address, req4) + + check: + resp1.status == 200 + resp1.data == "/modified/655370/test:test-value" + resp2.status == 200 + resp2.data == "/modified/655370/first:test-value" + resp3.status == 200 + resp3.data == "/modified/655370/second:test-value" + resp4.status == 200 + resp4.data == "/modified/655370/noway:test-value" + + await server.stop() + await server.closeWait() + + asyncTest "HTTP middleware request blocking test": + proc init(t: typedesc[FirstMiddlewareRef], + data: int): HttpServerMiddlewareRef = + proc shandler( + middleware: HttpServerMiddlewareRef, + reqfence: RequestFence, + nextHandler: HttpProcessCallback2 + ): Future[HttpResponseRef] {.async: (raises: [CancelledError]).} = + if reqfence.isErr(): + # Our handler is not supposed to handle request errors, so we + # call next handler in sequence which could process errors. + return await nextHandler(reqfence) + + let request = reqfence.get() + if request.uri.path == "/first": + # Blocking request by disconnecting remote peer. + dropResponse() + elif request.uri.path == "/second": + # Blocking request by sending HTTP error message with 401 code. + codeResponse(Http401) + else: + # Allow all other requests to be processed by next handler. + await nextHandler(reqfence) + + HttpServerMiddlewareRef( + FirstMiddlewareRef(someInteger: data, handler: shandler)) + + proc process(r: RequestFence): Future[HttpResponseRef] {. + async: (raises: [CancelledError]).} = + if r.isOk(): + let request = r.get() + try: + await request.respond(Http200, "ORIGIN") + except HttpWriteError as exc: + defaultResponse(exc) + else: + defaultResponse() + + let + middlewares = [FirstMiddlewareRef.init(655370)] + socketFlags = {ServerFlags.TcpNoDelay, ServerFlags.ReuseAddr} + res = HttpServerRef.new(initTAddress("127.0.0.1:0"), process, + socketFlags = socketFlags, + middlewares = middlewares) + check res.isOk() + + let server = res.get() + server.start() + let + address = server.instance.localAddress() + req1 = "GET /test HTTP/1.1\r\n\r\n" + req2 = "GET /first HTTP/1.1\r\n\r\n" + req3 = "GET /second HTTP/1.1\r\n\r\n" + resp1 = await httpClient3(address, req1) + resp3 = await httpClient3(address, req3) + + check: + resp1.status == 200 + resp1.data == "ORIGIN" + resp3.status == 401 + + let checked = + try: + let res {.used.} = await httpClient3(address, req2) + false + except TransportIncompleteError: + true + + check: + checked == true + + await server.stop() + await server.closeWait() + test "Leaks test": checkLeaks()