Add ability to handle REST server errors with custom responses. (#60)

* Add ability to handle REST server errors with custom responses.
Add tests.

* Fix chunked encoding issue.
Add test for chunked encoding.
This commit is contained in:
Eugene Kabanov 2023-10-19 01:16:16 +03:00 committed by GitHub
parent 2ae448ff5b
commit e80546edf2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 146 additions and 25 deletions

View File

@ -17,6 +17,7 @@ type
SecureRestServer* = object of RootObj
server*: SecureHttpServerRef
router*: RestRouter
errorHandler*: RestRequestErrorHandler
SecureRestServerRef* = ref SecureRestServer
@ -35,9 +36,13 @@ proc new*(t: typedesc[SecureRestServerRef],
bufferSize: int = 4096,
httpHeadersTimeout = 10.seconds,
maxHeadersSize: int = 8192,
maxRequestBodySize: int = 1_048_576
maxRequestBodySize: int = 1_048_576,
requestErrorHandler: RestRequestErrorHandler = nil
): RestResult[SecureRestServerRef] =
var server = SecureRestServerRef(router: router)
var server = SecureRestServerRef(
router: router,
errorHandler: requestErrorHandler
)
proc processCallback(rf: RequestFence): Future[HttpResponseRef] =
processRestRequest(server, rf)

View File

@ -17,6 +17,7 @@ type
RestServer* = object of RootObj
server*: HttpServerRef
router*: RestRouter
errorHandler*: RestRequestErrorHandler
RestServerRef* = ref RestServer
@ -32,8 +33,10 @@ proc new*(t: typedesc[RestServerRef],
bufferSize: int = 4096,
httpHeadersTimeout = 10.seconds,
maxHeadersSize: int = 8192,
maxRequestBodySize: int = 1_048_576): RestResult[RestServerRef] =
var server = RestServerRef(router: router)
maxRequestBodySize: int = 1_048_576,
requestErrorHandler: RestRequestErrorHandler = nil
): RestResult[RestServerRef] =
var server = RestServerRef(router: router, errorHandler: requestErrorHandler)
proc processCallback(rf: RequestFence): Future[HttpResponseRef] =
processRestRequest[RestServerRef](server, rf)

View File

@ -7,6 +7,7 @@
# Apache License, version 2.0, (LICENSE-APACHEv2)
# MIT license (LICENSE-MIT)
import std/options
import chronos, chronos/apps/http/httpserver
import chronicles
import common
export chronicles, options
@ -33,3 +34,10 @@ chronicles.expandIt(RestApiError):
type
RestServerState* {.pure.} = enum
Closed, Stopped, Running
RestRequestError* {.pure.} = enum
Invalid, NotFound, InvalidContentBody, InvalidContentType, Unexpected
RestRequestErrorHandler* = proc(
error: RestRequestError,
request: HttpRequestRef): Future[HttpResponseRef] {.async.}

View File

@ -30,18 +30,16 @@ proc getContentBody*(r: HttpRequestRef): Future[Option[ContentBody]] {.
async.} =
if r.meth notin PostMethods:
return none[ContentBody]()
else:
if r.hasBody() and r.contentLength > 0:
if r.contentTypeData.isNone():
raise newException(RestBadRequestError,
"Incorrect/missing Content-Type header")
let
data = await r.getBody()
cbody = ContentBody(contentType: r.contentTypeData.get(),
data: data)
return some[ContentBody](cbody)
else:
return none[ContentBody]()
if not(r.hasBody()):
return none[ContentBody]()
if (HttpRequestFlags.BoundBody in r.requestFlags) and (r.contentLength == 0):
return none[ContentBody]()
if r.contentTypeData.isNone():
raise newException(RestBadRequestError,
"Incorrect/missing Content-Type header")
let data = await r.getBody()
return some[ContentBody](
ContentBody(contentType: r.contentTypeData.get(), data: data))
proc originsMatch(requestOrigin, allowedOrigin: string): bool =
if allowedOrigin.startsWith("http://") or
@ -67,7 +65,6 @@ when defined(metrics):
if RestServerMetricsType.Status in route.metrics:
let
endpoint = $route.routePath
icode = toInt(code)
scode = Base10.toString(uint64(toInt(code)))
presto_server_response_status_count.inc(1, @[endpoint, scode])
@ -117,7 +114,12 @@ proc processRestRequest*[T](server: T,
when defined(metrics):
processStatusMetrics(route, Http400)
return await request.respond(Http400)
return
if isNil(server.errorHandler):
await request.respond(Http400)
else:
await server.errorHandler(
RestRequestError.InvalidContentBody, request)
except RestBadRequestError as exc:
debug "Request has incorrect content type", uri = $request.uri,
peer = $request.remoteAddress(), meth = $request.meth,
@ -126,7 +128,12 @@ proc processRestRequest*[T](server: T,
when defined(metrics):
processStatusMetrics(route, Http400)
return await request.respond(Http400)
return
if isNil(server.errorHandler):
await request.respond(Http400)
else:
await server.errorHandler(
RestRequestError.InvalidContentType, request)
except CatchableError as exc:
warn "Unexpected exception while getting request body",
uri = $request.uri, peer = $request.remoteAddress(),
@ -136,7 +143,12 @@ proc processRestRequest*[T](server: T,
when defined(metrics):
processStatusMetrics(route, Http400)
return await request.respond(Http400)
return
if isNil(server.errorHandler):
await request.respond(Http400)
else:
await server.errorHandler(
RestRequestError.Unexpected, request)
else:
none[ContentBody]()
@ -324,7 +336,11 @@ proc processRestRequest*[T](server: T,
when defined(metrics):
presto_server_missing_requests_count.inc()
return await request.respond(Http404, "", HttpTable.init())
return
if isNil(server.errorHandler):
await request.respond(Http404, "", HttpTable.init())
else:
await server.errorHandler(RestRequestError.NotFound, request)
else:
debug "Received invalid request", peer = $request.remoteAddress(),
meth = $request.meth, uri = $request.uri
@ -332,7 +348,11 @@ proc processRestRequest*[T](server: T,
when defined(metrics):
presto_server_invalid_requests_count.inc()
return await request.respond(Http400, "", HttpTable.init())
return
if isNil(server.errorHandler):
await request.respond(Http400, "", HttpTable.init())
else:
await server.errorHandler(RestRequestError.Invalid, request)
else:
let httpErr = rf.error()
if httpErr.error == HttpServerError.DisconnectError:

View File

@ -61,14 +61,21 @@ proc init(t: typedesc[ClientResponse], status: int, data: string,
proc httpClient(server: TransportAddress, meth: HttpMethod, url: string,
body: string, ctype = "",
accept = ""): Future[ClientResponse] {.async.} =
accept = "", encoding = "",
length = -1): Future[ClientResponse] {.async.} =
var request = $meth & " " & $parseUri(url) & " HTTP/1.1\r\n"
request.add("Host: " & $server & "\r\n")
request.add("Content-Length: " & $len(body) & "\r\n")
if len(encoding) == 0:
if length >= 0:
request.add("Content-Length: " & $length & "\r\n")
else:
request.add("Content-Length: " & $len(body) & "\r\n")
if len(ctype) > 0:
request.add("Content-Type: " & ctype & "\r\n")
if len(accept) > 0:
request.add("Accept: " & accept & "\r\n")
if len(encoding) > 0:
request.add("Transfer-Encoding: " & encoding & "\r\n")
request.add("\r\n")
if len(body) > 0:
@ -424,7 +431,7 @@ suite "REST API server test suite":
("/test/1/2/0xaa?opt1=1&opt2=2&opt3=0xbb&opt4=2&opt4=3&opt4=4&opt5=t&" &
"opt5=e&opt5=s&opt5=t&opt6=0xCA&opt6=0xFE", "text/plain", "textbody"),
ClientResponse.init(200,
"1:2:aa:1:2:bb:2,3,4:t,e,s,t:ca,fe:text/plain,textbody")
"1:2:aa:1:2:bb:2,3,4:t,e,s,t:ca,fe:text/plain,textbody")
)
]
@ -440,6 +447,18 @@ suite "REST API server test suite":
if len(item[1].data) > 0:
check res.data == item[1].data
block:
let res = await httpClient(serverAddress, MethodPost,
url = "/test/1/2/0xaa",
body = "4\r\nWiki\r\n5\r\npedia\r\nE\r\n " &
"in\r\n\r\nchunks.\r\n0\r\n\r\n",
ctype = "application/octet-stream",
accept = "*/*",
encoding = "chunked")
check:
res.status == 200
res.data == "1:2:aa:::::::application/octet-stream,Wikipedia " &
"in\r\n\r\nchunks."
finally:
await server.closeWait()
@ -933,5 +952,71 @@ suite "REST API server test suite":
RestServerMetrics) do () -> RestApiResponse:
return RestApiResponse.response("ok-10", Http200, "test/test")
asyncTest "Custom error handlers test":
const
InvalidRequest = "////////////////////////////////////////////////////////////////////test"
var router = RestRouter.init(testValidate)
router.api(MethodGet, "/test") do () -> RestApiResponse:
return RestApiResponse.response("test", Http200)
router.api(MethodPost, "/post") do () -> RestApiResponse:
return RestApiResponse.response("post", Http200)
proc processError(
kind: RestRequestError,
request: HttpRequestRef
): Future[HttpResponseRef] {.async.} =
case kind
of RestRequestError.Invalid:
return await request.respond(Http201, "INVALID", HttpTable.init())
of RestRequestError.NotFound:
return await request.respond(Http202, "NOT FOUND", HttpTable.init())
of RestRequestError.InvalidContentBody:
# This type of error is tough to emulate for test, its only possible
# with chunked encoding with incorrect encoding headers.
return await request.respond(Http203, "CONTENT BODY", HttpTable.init())
of RestRequestError.InvalidContentType:
return await request.respond(Http204, "CONTENT TYPE", HttpTable.init())
of RestRequestError.Unexpected:
# This type of error should not be happened at all
return defaultResponse()
var sres = RestServerRef.new(router, serverAddress,
requestErrorHandler = processError)
let server = sres.get()
server.start()
let address = server.server.instance.localAddress()
block:
let res = await httpClient(address, MethodGet, InvalidRequest, "")
check:
res.status == 201
res.data == "INVALID"
block:
let res1 = await httpClient(address, MethodPost, "/test", "")
let res2 = await httpClient(address, MethodGet, "/tes", "")
check:
res1.status == 202
res2.status == 202
res1.data == "NOT FOUND"
res2.data == "NOT FOUND"
block:
# Invalid content body
let res = await httpClient(address, MethodPost, "/post", "z\r\n1",
ctype = "application/octet-stream",
encoding = "chunked")
check:
res.status == 203
res.data == "CONTENT BODY"
block:
# Missing `Content-Type` header for requests which has body.
let res = await httpClient(address, MethodPost, "/post", "data")
check:
res.status == 204
res.data == "CONTENT TYPE"
await server.stop()
await server.closeWait()
test "Leaks test":
checkLeaks()