Add ability to handle REST server errors with custom responses. (#60)
* Add ability to handle REST server errors with custom responses. Add tests. * Fix chunked encoding issue. Add test for chunked encoding.
This commit is contained in:
parent
2ae448ff5b
commit
e80546edf2
|
@ -17,6 +17,7 @@ type
|
|||
SecureRestServer* = object of RootObj
|
||||
server*: SecureHttpServerRef
|
||||
router*: RestRouter
|
||||
errorHandler*: RestRequestErrorHandler
|
||||
|
||||
SecureRestServerRef* = ref SecureRestServer
|
||||
|
||||
|
@ -35,9 +36,13 @@ proc new*(t: typedesc[SecureRestServerRef],
|
|||
bufferSize: int = 4096,
|
||||
httpHeadersTimeout = 10.seconds,
|
||||
maxHeadersSize: int = 8192,
|
||||
maxRequestBodySize: int = 1_048_576
|
||||
maxRequestBodySize: int = 1_048_576,
|
||||
requestErrorHandler: RestRequestErrorHandler = nil
|
||||
): RestResult[SecureRestServerRef] =
|
||||
var server = SecureRestServerRef(router: router)
|
||||
var server = SecureRestServerRef(
|
||||
router: router,
|
||||
errorHandler: requestErrorHandler
|
||||
)
|
||||
|
||||
proc processCallback(rf: RequestFence): Future[HttpResponseRef] =
|
||||
processRestRequest(server, rf)
|
||||
|
|
|
@ -17,6 +17,7 @@ type
|
|||
RestServer* = object of RootObj
|
||||
server*: HttpServerRef
|
||||
router*: RestRouter
|
||||
errorHandler*: RestRequestErrorHandler
|
||||
|
||||
RestServerRef* = ref RestServer
|
||||
|
||||
|
@ -32,8 +33,10 @@ proc new*(t: typedesc[RestServerRef],
|
|||
bufferSize: int = 4096,
|
||||
httpHeadersTimeout = 10.seconds,
|
||||
maxHeadersSize: int = 8192,
|
||||
maxRequestBodySize: int = 1_048_576): RestResult[RestServerRef] =
|
||||
var server = RestServerRef(router: router)
|
||||
maxRequestBodySize: int = 1_048_576,
|
||||
requestErrorHandler: RestRequestErrorHandler = nil
|
||||
): RestResult[RestServerRef] =
|
||||
var server = RestServerRef(router: router, errorHandler: requestErrorHandler)
|
||||
|
||||
proc processCallback(rf: RequestFence): Future[HttpResponseRef] =
|
||||
processRestRequest[RestServerRef](server, rf)
|
||||
|
|
|
@ -7,6 +7,7 @@
|
|||
# Apache License, version 2.0, (LICENSE-APACHEv2)
|
||||
# MIT license (LICENSE-MIT)
|
||||
import std/options
|
||||
import chronos, chronos/apps/http/httpserver
|
||||
import chronicles
|
||||
import common
|
||||
export chronicles, options
|
||||
|
@ -33,3 +34,10 @@ chronicles.expandIt(RestApiError):
|
|||
type
|
||||
RestServerState* {.pure.} = enum
|
||||
Closed, Stopped, Running
|
||||
|
||||
RestRequestError* {.pure.} = enum
|
||||
Invalid, NotFound, InvalidContentBody, InvalidContentType, Unexpected
|
||||
|
||||
RestRequestErrorHandler* = proc(
|
||||
error: RestRequestError,
|
||||
request: HttpRequestRef): Future[HttpResponseRef] {.async.}
|
||||
|
|
|
@ -30,18 +30,16 @@ proc getContentBody*(r: HttpRequestRef): Future[Option[ContentBody]] {.
|
|||
async.} =
|
||||
if r.meth notin PostMethods:
|
||||
return none[ContentBody]()
|
||||
else:
|
||||
if r.hasBody() and r.contentLength > 0:
|
||||
if r.contentTypeData.isNone():
|
||||
raise newException(RestBadRequestError,
|
||||
"Incorrect/missing Content-Type header")
|
||||
let
|
||||
data = await r.getBody()
|
||||
cbody = ContentBody(contentType: r.contentTypeData.get(),
|
||||
data: data)
|
||||
return some[ContentBody](cbody)
|
||||
else:
|
||||
return none[ContentBody]()
|
||||
if not(r.hasBody()):
|
||||
return none[ContentBody]()
|
||||
if (HttpRequestFlags.BoundBody in r.requestFlags) and (r.contentLength == 0):
|
||||
return none[ContentBody]()
|
||||
if r.contentTypeData.isNone():
|
||||
raise newException(RestBadRequestError,
|
||||
"Incorrect/missing Content-Type header")
|
||||
let data = await r.getBody()
|
||||
return some[ContentBody](
|
||||
ContentBody(contentType: r.contentTypeData.get(), data: data))
|
||||
|
||||
proc originsMatch(requestOrigin, allowedOrigin: string): bool =
|
||||
if allowedOrigin.startsWith("http://") or
|
||||
|
@ -67,7 +65,6 @@ when defined(metrics):
|
|||
if RestServerMetricsType.Status in route.metrics:
|
||||
let
|
||||
endpoint = $route.routePath
|
||||
icode = toInt(code)
|
||||
scode = Base10.toString(uint64(toInt(code)))
|
||||
presto_server_response_status_count.inc(1, @[endpoint, scode])
|
||||
|
||||
|
@ -117,7 +114,12 @@ proc processRestRequest*[T](server: T,
|
|||
when defined(metrics):
|
||||
processStatusMetrics(route, Http400)
|
||||
|
||||
return await request.respond(Http400)
|
||||
return
|
||||
if isNil(server.errorHandler):
|
||||
await request.respond(Http400)
|
||||
else:
|
||||
await server.errorHandler(
|
||||
RestRequestError.InvalidContentBody, request)
|
||||
except RestBadRequestError as exc:
|
||||
debug "Request has incorrect content type", uri = $request.uri,
|
||||
peer = $request.remoteAddress(), meth = $request.meth,
|
||||
|
@ -126,7 +128,12 @@ proc processRestRequest*[T](server: T,
|
|||
when defined(metrics):
|
||||
processStatusMetrics(route, Http400)
|
||||
|
||||
return await request.respond(Http400)
|
||||
return
|
||||
if isNil(server.errorHandler):
|
||||
await request.respond(Http400)
|
||||
else:
|
||||
await server.errorHandler(
|
||||
RestRequestError.InvalidContentType, request)
|
||||
except CatchableError as exc:
|
||||
warn "Unexpected exception while getting request body",
|
||||
uri = $request.uri, peer = $request.remoteAddress(),
|
||||
|
@ -136,7 +143,12 @@ proc processRestRequest*[T](server: T,
|
|||
when defined(metrics):
|
||||
processStatusMetrics(route, Http400)
|
||||
|
||||
return await request.respond(Http400)
|
||||
return
|
||||
if isNil(server.errorHandler):
|
||||
await request.respond(Http400)
|
||||
else:
|
||||
await server.errorHandler(
|
||||
RestRequestError.Unexpected, request)
|
||||
else:
|
||||
none[ContentBody]()
|
||||
|
||||
|
@ -324,7 +336,11 @@ proc processRestRequest*[T](server: T,
|
|||
when defined(metrics):
|
||||
presto_server_missing_requests_count.inc()
|
||||
|
||||
return await request.respond(Http404, "", HttpTable.init())
|
||||
return
|
||||
if isNil(server.errorHandler):
|
||||
await request.respond(Http404, "", HttpTable.init())
|
||||
else:
|
||||
await server.errorHandler(RestRequestError.NotFound, request)
|
||||
else:
|
||||
debug "Received invalid request", peer = $request.remoteAddress(),
|
||||
meth = $request.meth, uri = $request.uri
|
||||
|
@ -332,7 +348,11 @@ proc processRestRequest*[T](server: T,
|
|||
when defined(metrics):
|
||||
presto_server_invalid_requests_count.inc()
|
||||
|
||||
return await request.respond(Http400, "", HttpTable.init())
|
||||
return
|
||||
if isNil(server.errorHandler):
|
||||
await request.respond(Http400, "", HttpTable.init())
|
||||
else:
|
||||
await server.errorHandler(RestRequestError.Invalid, request)
|
||||
else:
|
||||
let httpErr = rf.error()
|
||||
if httpErr.error == HttpServerError.DisconnectError:
|
||||
|
|
|
@ -61,14 +61,21 @@ proc init(t: typedesc[ClientResponse], status: int, data: string,
|
|||
|
||||
proc httpClient(server: TransportAddress, meth: HttpMethod, url: string,
|
||||
body: string, ctype = "",
|
||||
accept = ""): Future[ClientResponse] {.async.} =
|
||||
accept = "", encoding = "",
|
||||
length = -1): Future[ClientResponse] {.async.} =
|
||||
var request = $meth & " " & $parseUri(url) & " HTTP/1.1\r\n"
|
||||
request.add("Host: " & $server & "\r\n")
|
||||
request.add("Content-Length: " & $len(body) & "\r\n")
|
||||
if len(encoding) == 0:
|
||||
if length >= 0:
|
||||
request.add("Content-Length: " & $length & "\r\n")
|
||||
else:
|
||||
request.add("Content-Length: " & $len(body) & "\r\n")
|
||||
if len(ctype) > 0:
|
||||
request.add("Content-Type: " & ctype & "\r\n")
|
||||
if len(accept) > 0:
|
||||
request.add("Accept: " & accept & "\r\n")
|
||||
if len(encoding) > 0:
|
||||
request.add("Transfer-Encoding: " & encoding & "\r\n")
|
||||
request.add("\r\n")
|
||||
|
||||
if len(body) > 0:
|
||||
|
@ -424,7 +431,7 @@ suite "REST API server test suite":
|
|||
("/test/1/2/0xaa?opt1=1&opt2=2&opt3=0xbb&opt4=2&opt4=3&opt4=4&opt5=t&" &
|
||||
"opt5=e&opt5=s&opt5=t&opt6=0xCA&opt6=0xFE", "text/plain", "textbody"),
|
||||
ClientResponse.init(200,
|
||||
"1:2:aa:1:2:bb:2,3,4:t,e,s,t:ca,fe:text/plain,textbody")
|
||||
"1:2:aa:1:2:bb:2,3,4:t,e,s,t:ca,fe:text/plain,textbody")
|
||||
)
|
||||
]
|
||||
|
||||
|
@ -440,6 +447,18 @@ suite "REST API server test suite":
|
|||
if len(item[1].data) > 0:
|
||||
check res.data == item[1].data
|
||||
|
||||
block:
|
||||
let res = await httpClient(serverAddress, MethodPost,
|
||||
url = "/test/1/2/0xaa",
|
||||
body = "4\r\nWiki\r\n5\r\npedia\r\nE\r\n " &
|
||||
"in\r\n\r\nchunks.\r\n0\r\n\r\n",
|
||||
ctype = "application/octet-stream",
|
||||
accept = "*/*",
|
||||
encoding = "chunked")
|
||||
check:
|
||||
res.status == 200
|
||||
res.data == "1:2:aa:::::::application/octet-stream,Wikipedia " &
|
||||
"in\r\n\r\nchunks."
|
||||
finally:
|
||||
await server.closeWait()
|
||||
|
||||
|
@ -933,5 +952,71 @@ suite "REST API server test suite":
|
|||
RestServerMetrics) do () -> RestApiResponse:
|
||||
return RestApiResponse.response("ok-10", Http200, "test/test")
|
||||
|
||||
asyncTest "Custom error handlers test":
|
||||
const
|
||||
InvalidRequest = "////////////////////////////////////////////////////////////////////test"
|
||||
|
||||
var router = RestRouter.init(testValidate)
|
||||
router.api(MethodGet, "/test") do () -> RestApiResponse:
|
||||
return RestApiResponse.response("test", Http200)
|
||||
router.api(MethodPost, "/post") do () -> RestApiResponse:
|
||||
return RestApiResponse.response("post", Http200)
|
||||
|
||||
proc processError(
|
||||
kind: RestRequestError,
|
||||
request: HttpRequestRef
|
||||
): Future[HttpResponseRef] {.async.} =
|
||||
case kind
|
||||
of RestRequestError.Invalid:
|
||||
return await request.respond(Http201, "INVALID", HttpTable.init())
|
||||
of RestRequestError.NotFound:
|
||||
return await request.respond(Http202, "NOT FOUND", HttpTable.init())
|
||||
of RestRequestError.InvalidContentBody:
|
||||
# This type of error is tough to emulate for test, its only possible
|
||||
# with chunked encoding with incorrect encoding headers.
|
||||
return await request.respond(Http203, "CONTENT BODY", HttpTable.init())
|
||||
of RestRequestError.InvalidContentType:
|
||||
return await request.respond(Http204, "CONTENT TYPE", HttpTable.init())
|
||||
of RestRequestError.Unexpected:
|
||||
# This type of error should not be happened at all
|
||||
return defaultResponse()
|
||||
|
||||
var sres = RestServerRef.new(router, serverAddress,
|
||||
requestErrorHandler = processError)
|
||||
let server = sres.get()
|
||||
server.start()
|
||||
let address = server.server.instance.localAddress()
|
||||
|
||||
block:
|
||||
let res = await httpClient(address, MethodGet, InvalidRequest, "")
|
||||
check:
|
||||
res.status == 201
|
||||
res.data == "INVALID"
|
||||
block:
|
||||
let res1 = await httpClient(address, MethodPost, "/test", "")
|
||||
let res2 = await httpClient(address, MethodGet, "/tes", "")
|
||||
check:
|
||||
res1.status == 202
|
||||
res2.status == 202
|
||||
res1.data == "NOT FOUND"
|
||||
res2.data == "NOT FOUND"
|
||||
block:
|
||||
# Invalid content body
|
||||
let res = await httpClient(address, MethodPost, "/post", "z\r\n1",
|
||||
ctype = "application/octet-stream",
|
||||
encoding = "chunked")
|
||||
check:
|
||||
res.status == 203
|
||||
res.data == "CONTENT BODY"
|
||||
block:
|
||||
# Missing `Content-Type` header for requests which has body.
|
||||
let res = await httpClient(address, MethodPost, "/post", "data")
|
||||
check:
|
||||
res.status == 204
|
||||
res.data == "CONTENT TYPE"
|
||||
|
||||
await server.stop()
|
||||
await server.closeWait()
|
||||
|
||||
test "Leaks test":
|
||||
checkLeaks()
|
||||
|
|
Loading…
Reference in New Issue