diff --git a/presto/secureserver.nim b/presto/secureserver.nim index 144c010..e1d2312 100644 --- a/presto/secureserver.nim +++ b/presto/secureserver.nim @@ -17,6 +17,7 @@ type SecureRestServer* = object of RootObj server*: SecureHttpServerRef router*: RestRouter + errorHandler*: RestRequestErrorHandler SecureRestServerRef* = ref SecureRestServer @@ -35,9 +36,13 @@ proc new*(t: typedesc[SecureRestServerRef], bufferSize: int = 4096, httpHeadersTimeout = 10.seconds, maxHeadersSize: int = 8192, - maxRequestBodySize: int = 1_048_576 + maxRequestBodySize: int = 1_048_576, + requestErrorHandler: RestRequestErrorHandler = nil ): RestResult[SecureRestServerRef] = - var server = SecureRestServerRef(router: router) + var server = SecureRestServerRef( + router: router, + errorHandler: requestErrorHandler + ) proc processCallback(rf: RequestFence): Future[HttpResponseRef] = processRestRequest(server, rf) diff --git a/presto/server.nim b/presto/server.nim index 4544b05..b435ce6 100644 --- a/presto/server.nim +++ b/presto/server.nim @@ -17,6 +17,7 @@ type RestServer* = object of RootObj server*: HttpServerRef router*: RestRouter + errorHandler*: RestRequestErrorHandler RestServerRef* = ref RestServer @@ -32,8 +33,10 @@ proc new*(t: typedesc[RestServerRef], bufferSize: int = 4096, httpHeadersTimeout = 10.seconds, maxHeadersSize: int = 8192, - maxRequestBodySize: int = 1_048_576): RestResult[RestServerRef] = - var server = RestServerRef(router: router) + maxRequestBodySize: int = 1_048_576, + requestErrorHandler: RestRequestErrorHandler = nil + ): RestResult[RestServerRef] = + var server = RestServerRef(router: router, errorHandler: requestErrorHandler) proc processCallback(rf: RequestFence): Future[HttpResponseRef] = processRestRequest[RestServerRef](server, rf) diff --git a/presto/servercommon.nim b/presto/servercommon.nim index c89884c..8e391f8 100644 --- a/presto/servercommon.nim +++ b/presto/servercommon.nim @@ -7,6 +7,7 @@ # Apache License, version 2.0, (LICENSE-APACHEv2) # MIT license (LICENSE-MIT) import std/options +import chronos, chronos/apps/http/httpserver import chronicles import common export chronicles, options @@ -33,3 +34,10 @@ chronicles.expandIt(RestApiError): type RestServerState* {.pure.} = enum Closed, Stopped, Running + + RestRequestError* {.pure.} = enum + Invalid, NotFound, InvalidContentBody, InvalidContentType, Unexpected + + RestRequestErrorHandler* = proc( + error: RestRequestError, + request: HttpRequestRef): Future[HttpResponseRef] {.async.} diff --git a/presto/serverprivate.nim b/presto/serverprivate.nim index b6046b6..2aaa13c 100644 --- a/presto/serverprivate.nim +++ b/presto/serverprivate.nim @@ -30,18 +30,16 @@ proc getContentBody*(r: HttpRequestRef): Future[Option[ContentBody]] {. async.} = if r.meth notin PostMethods: return none[ContentBody]() - else: - if r.hasBody() and r.contentLength > 0: - if r.contentTypeData.isNone(): - raise newException(RestBadRequestError, - "Incorrect/missing Content-Type header") - let - data = await r.getBody() - cbody = ContentBody(contentType: r.contentTypeData.get(), - data: data) - return some[ContentBody](cbody) - else: - return none[ContentBody]() + if not(r.hasBody()): + return none[ContentBody]() + if (HttpRequestFlags.BoundBody in r.requestFlags) and (r.contentLength == 0): + return none[ContentBody]() + if r.contentTypeData.isNone(): + raise newException(RestBadRequestError, + "Incorrect/missing Content-Type header") + let data = await r.getBody() + return some[ContentBody]( + ContentBody(contentType: r.contentTypeData.get(), data: data)) proc originsMatch(requestOrigin, allowedOrigin: string): bool = if allowedOrigin.startsWith("http://") or @@ -67,7 +65,6 @@ when defined(metrics): if RestServerMetricsType.Status in route.metrics: let endpoint = $route.routePath - icode = toInt(code) scode = Base10.toString(uint64(toInt(code))) presto_server_response_status_count.inc(1, @[endpoint, scode]) @@ -117,7 +114,12 @@ proc processRestRequest*[T](server: T, when defined(metrics): processStatusMetrics(route, Http400) - return await request.respond(Http400) + return + if isNil(server.errorHandler): + await request.respond(Http400) + else: + await server.errorHandler( + RestRequestError.InvalidContentBody, request) except RestBadRequestError as exc: debug "Request has incorrect content type", uri = $request.uri, peer = $request.remoteAddress(), meth = $request.meth, @@ -126,7 +128,12 @@ proc processRestRequest*[T](server: T, when defined(metrics): processStatusMetrics(route, Http400) - return await request.respond(Http400) + return + if isNil(server.errorHandler): + await request.respond(Http400) + else: + await server.errorHandler( + RestRequestError.InvalidContentType, request) except CatchableError as exc: warn "Unexpected exception while getting request body", uri = $request.uri, peer = $request.remoteAddress(), @@ -136,7 +143,12 @@ proc processRestRequest*[T](server: T, when defined(metrics): processStatusMetrics(route, Http400) - return await request.respond(Http400) + return + if isNil(server.errorHandler): + await request.respond(Http400) + else: + await server.errorHandler( + RestRequestError.Unexpected, request) else: none[ContentBody]() @@ -324,7 +336,11 @@ proc processRestRequest*[T](server: T, when defined(metrics): presto_server_missing_requests_count.inc() - return await request.respond(Http404, "", HttpTable.init()) + return + if isNil(server.errorHandler): + await request.respond(Http404, "", HttpTable.init()) + else: + await server.errorHandler(RestRequestError.NotFound, request) else: debug "Received invalid request", peer = $request.remoteAddress(), meth = $request.meth, uri = $request.uri @@ -332,7 +348,11 @@ proc processRestRequest*[T](server: T, when defined(metrics): presto_server_invalid_requests_count.inc() - return await request.respond(Http400, "", HttpTable.init()) + return + if isNil(server.errorHandler): + await request.respond(Http400, "", HttpTable.init()) + else: + await server.errorHandler(RestRequestError.Invalid, request) else: let httpErr = rf.error() if httpErr.error == HttpServerError.DisconnectError: diff --git a/tests/testserver.nim b/tests/testserver.nim index e5dcce7..84cbe95 100644 --- a/tests/testserver.nim +++ b/tests/testserver.nim @@ -61,14 +61,21 @@ proc init(t: typedesc[ClientResponse], status: int, data: string, proc httpClient(server: TransportAddress, meth: HttpMethod, url: string, body: string, ctype = "", - accept = ""): Future[ClientResponse] {.async.} = + accept = "", encoding = "", + length = -1): Future[ClientResponse] {.async.} = var request = $meth & " " & $parseUri(url) & " HTTP/1.1\r\n" request.add("Host: " & $server & "\r\n") - request.add("Content-Length: " & $len(body) & "\r\n") + if len(encoding) == 0: + if length >= 0: + request.add("Content-Length: " & $length & "\r\n") + else: + request.add("Content-Length: " & $len(body) & "\r\n") if len(ctype) > 0: request.add("Content-Type: " & ctype & "\r\n") if len(accept) > 0: request.add("Accept: " & accept & "\r\n") + if len(encoding) > 0: + request.add("Transfer-Encoding: " & encoding & "\r\n") request.add("\r\n") if len(body) > 0: @@ -424,7 +431,7 @@ suite "REST API server test suite": ("/test/1/2/0xaa?opt1=1&opt2=2&opt3=0xbb&opt4=2&opt4=3&opt4=4&opt5=t&" & "opt5=e&opt5=s&opt5=t&opt6=0xCA&opt6=0xFE", "text/plain", "textbody"), ClientResponse.init(200, - "1:2:aa:1:2:bb:2,3,4:t,e,s,t:ca,fe:text/plain,textbody") + "1:2:aa:1:2:bb:2,3,4:t,e,s,t:ca,fe:text/plain,textbody") ) ] @@ -440,6 +447,18 @@ suite "REST API server test suite": if len(item[1].data) > 0: check res.data == item[1].data + block: + let res = await httpClient(serverAddress, MethodPost, + url = "/test/1/2/0xaa", + body = "4\r\nWiki\r\n5\r\npedia\r\nE\r\n " & + "in\r\n\r\nchunks.\r\n0\r\n\r\n", + ctype = "application/octet-stream", + accept = "*/*", + encoding = "chunked") + check: + res.status == 200 + res.data == "1:2:aa:::::::application/octet-stream,Wikipedia " & + "in\r\n\r\nchunks." finally: await server.closeWait() @@ -933,5 +952,71 @@ suite "REST API server test suite": RestServerMetrics) do () -> RestApiResponse: return RestApiResponse.response("ok-10", Http200, "test/test") + asyncTest "Custom error handlers test": + const + InvalidRequest = "////////////////////////////////////////////////////////////////////test" + + var router = RestRouter.init(testValidate) + router.api(MethodGet, "/test") do () -> RestApiResponse: + return RestApiResponse.response("test", Http200) + router.api(MethodPost, "/post") do () -> RestApiResponse: + return RestApiResponse.response("post", Http200) + + proc processError( + kind: RestRequestError, + request: HttpRequestRef + ): Future[HttpResponseRef] {.async.} = + case kind + of RestRequestError.Invalid: + return await request.respond(Http201, "INVALID", HttpTable.init()) + of RestRequestError.NotFound: + return await request.respond(Http202, "NOT FOUND", HttpTable.init()) + of RestRequestError.InvalidContentBody: + # This type of error is tough to emulate for test, its only possible + # with chunked encoding with incorrect encoding headers. + return await request.respond(Http203, "CONTENT BODY", HttpTable.init()) + of RestRequestError.InvalidContentType: + return await request.respond(Http204, "CONTENT TYPE", HttpTable.init()) + of RestRequestError.Unexpected: + # This type of error should not be happened at all + return defaultResponse() + + var sres = RestServerRef.new(router, serverAddress, + requestErrorHandler = processError) + let server = sres.get() + server.start() + let address = server.server.instance.localAddress() + + block: + let res = await httpClient(address, MethodGet, InvalidRequest, "") + check: + res.status == 201 + res.data == "INVALID" + block: + let res1 = await httpClient(address, MethodPost, "/test", "") + let res2 = await httpClient(address, MethodGet, "/tes", "") + check: + res1.status == 202 + res2.status == 202 + res1.data == "NOT FOUND" + res2.data == "NOT FOUND" + block: + # Invalid content body + let res = await httpClient(address, MethodPost, "/post", "z\r\n1", + ctype = "application/octet-stream", + encoding = "chunked") + check: + res.status == 203 + res.data == "CONTENT BODY" + block: + # Missing `Content-Type` header for requests which has body. + let res = await httpClient(address, MethodPost, "/post", "data") + check: + res.status == 204 + res.data == "CONTENT TYPE" + + await server.stop() + await server.closeWait() + test "Leaks test": checkLeaks()