Address review comments.
parent 6f8d06f12d
commit fc0d1bcb43
@@ -63,28 +63,15 @@ proc atBound*(bstream: HttpBodyReader): bool {.
   let breader = cast[BoundedStreamReader](lreader)
   breader.atEof() and (breader.bytesLeft() == 0)
 
-proc newHttpDefect*(msg: string): ref HttpDefect {.
-     raises: [HttpDefect].} =
-  newException(HttpDefect, msg)
+proc raiseHttpCriticalError*(msg: string,
+                             code = Http400) {.noinline, noreturn.} =
+  raise (ref HttpCriticalError)(code: code, msg: msg)
 
-proc newHttpCriticalError*(msg: string,
-                           code = Http400): ref HttpCriticalError {.
-     raises: [HttpCriticalError].} =
-  var tre = newException(HttpCriticalError, msg)
-  tre.code = code
-  tre
+proc raiseHttpDisconnectError*() {.noinline, noreturn.} =
+  raise (ref HttpDisconnectError)(msg: "Remote peer disconnected")
 
-proc newHttpRecoverableError*(msg: string,
-                              code = Http400): ref HttpRecoverableError {.
-     raises: [HttpRecoverableError].} =
-  var tre = newException(HttpRecoverableError, msg)
-  tre.code = code
-  tre
+proc raiseHttpDefect*(msg: string) {.noinline, noreturn.} =
+  raise (ref HttpDefect)(msg: msg)
 
-proc newHttpDisconnectError*(): ref HttpDisconnectError {.
-     raises: [HttpDisconnectError].} =
-  var tre = newException(HttpDisconnectError, "Remote peer disconnected")
-  tre
-
 iterator queryParams*(query: string): tuple[key: string, value: string] {.
   raises: [Defect].} =
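Note on the API change above (an editorial sketch, not part of the commit): the newHttp*Error constructors are folded into raiseHttp* helpers that raise directly. A minimal standalone Nim illustration follows; the HttpCode enum here is a simplified stand-in for the real status-code type, everything else mirrors the diff.

# Sketch only: simplified stand-in types, not the chronos httpcommon module.
type
  HttpCode = enum
    Http400, Http413
  HttpCriticalError = object of CatchableError
    code: HttpCode

proc raiseHttpCriticalError(msg: string,
                            code = Http400) {.noinline, noreturn.} =
  # Raises directly; {.noreturn.} tells the compiler the call never returns,
  # so call sites need no `raise` keyword of their own.
  raise (ref HttpCriticalError)(code: code, msg: msg)

try:
  # Old call sites read: raise newHttpCriticalError("...", Http413)
  # New call sites read:
  raiseHttpCriticalError("Maximum size of body reached", Http413)
except HttpCriticalError as exc:
  echo exc.msg, " (", exc.code, ")"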
@@ -349,7 +349,7 @@ proc handleExpect*(request: HttpRequestRef) {.async.} =
       except CancelledError as exc:
         raise exc
       except AsyncStreamWriteError, AsyncStreamIncompleteError:
-        raise newHttpCriticalError("Unable to send `100-continue` response")
+        raiseHttpCriticalError("Unable to send `100-continue` response")
 
 proc getBody*(request: HttpRequestRef): Future[seq[byte]] {.async.} =
   ## Obtain request's body as sequence of bytes.

@@ -363,9 +363,9 @@ proc getBody*(request: HttpRequestRef): Future[seq[byte]] {.async.} =
       return await reader.read()
     except AsyncStreamError:
       if reader.atBound():
-        raise newHttpCriticalError("Maximum size of body reached", Http413)
+        raiseHttpCriticalError("Maximum size of body reached", Http413)
       else:
-        raise newHttpCriticalError("Unable to read request's body")
+        raiseHttpCriticalError("Unable to read request's body")
     finally:
       await closeWait(res.get())

@@ -381,9 +381,9 @@ proc consumeBody*(request: HttpRequestRef): Future[void] {.async.} =
       discard await reader.consume()
     except AsyncStreamError:
       if reader.atBound():
-        raise newHttpCriticalError("Maximum size of body reached", Http413)
+        raiseHttpCriticalError("Maximum size of body reached", Http413)
       else:
-        raise newHttpCriticalError("Unable to read request's body")
+        raiseHttpCriticalError("Unable to read request's body")
     finally:
       await closeWait(res.get())

@@ -422,18 +422,17 @@ proc getRequest(conn: HttpConnectionRef): Future[HttpRequestRef] {.async.} =
     conn.buffer.setLen(res)
     let header = parseRequest(conn.buffer)
     if header.failed():
-      raise newHttpCriticalError("Malformed request recieved")
+      raiseHttpCriticalError("Malformed request recieved")
     else:
       let res = prepareRequest(conn, header)
       if res.isErr():
-        raise newHttpCriticalError("Invalid request received", res.error)
+        raiseHttpCriticalError("Invalid request received", res.error)
       else:
         return res.get()
   except AsyncStreamIncompleteError, AsyncStreamReadError:
-    raise newHttpDisconnectError()
+    raiseHttpDisconnectError()
   except AsyncStreamLimitError:
-    raise newHttpCriticalError("Maximum size of request headers reached",
-                               Http413)
+    raiseHttpCriticalError("Maximum size of request headers reached", Http413)
 
 proc new(ht: typedesc[HttpConnectionRef], server: HttpServerRef,
          transp: StreamTransport): HttpConnectionRef =

@@ -503,7 +502,7 @@ proc createConnection(server: HttpServerRef,
       raise exc
     except TLSStreamError:
       await conn.closeWait()
-      raise newHttpCriticalError("Unable to establish secure connection")
+      raiseHttpCriticalError("Unable to establish secure connection")
 
 proc processLoop(server: HttpServerRef, transp: StreamTransport) {.async.} =
   var

@@ -534,7 +533,7 @@ proc processLoop(server: HttpServerRef, transp: StreamTransport) {.async.} =
       return
     except CatchableError as exc:
       # There should be no exceptions, so we will raise `Defect`.
-      raise newHttpDefect("Unexpected exception catched [" & $exc.name & "]")
+      raiseHttpDefect("Unexpected exception catched [" & $exc.name & "]")
 
   var breakLoop = false
   while runLoop:

@@ -763,7 +762,7 @@ proc post*(req: HttpRequestRef): Future[HttpTable] {.async.} =
     var table = HttpTable.init()
     let res = getMultipartReader(req)
     if res.isErr():
-      raise newHttpCriticalError("Unable to retrieve multipart form data")
+      raiseHttpCriticalError("Unable to retrieve multipart form data")
     var mpreader = res.get()
 
     # We must handle `Expect` first.

@@ -808,10 +807,10 @@ proc post*(req: HttpRequestRef): Future[HttpTable] {.async.} =
   else:
     if HttpRequestFlags.BoundBody in req.requestFlags:
       if req.contentLength != 0:
-        raise newHttpCriticalError("Unsupported request body")
+        raiseHttpCriticalError("Unsupported request body")
       return HttpTable.init()
     elif HttpRequestFlags.UnboundBody in req.requestFlags:
-      raise newHttpCriticalError("Unsupported request body")
+      raiseHttpCriticalError("Unsupported request body")
 
 proc `keepalive=`*(resp: HttpResponseRef, value: bool) =
   doAssert(resp.state == HttpResponseState.Empty)

@@ -854,7 +853,7 @@ template doHeaderVal(buf, name, value) =
 
 template checkPending(t: untyped) =
   if t.state != HttpResponseState.Empty:
-    raise newHttpCriticalError("Response body was already sent")
+    raiseHttpCriticalError("Response body was already sent")
 
 proc prepareLengthHeaders(resp: HttpResponseRef, length: int): string {.
     raises: [Defect].}=

@@ -910,7 +909,7 @@ proc sendBody*(resp: HttpResponseRef, pbytes: pointer, nbytes: int) {.async.} =
     raise exc
   except AsyncStreamWriteError, AsyncStreamIncompleteError:
     resp.state = HttpResponseState.Failed
-    raise newHttpCriticalError("Unable to send response")
+    raiseHttpCriticalError("Unable to send response")
 
 proc sendBody*[T: string|seq[byte]](resp: HttpResponseRef, data: T) {.async.} =
   ## Send HTTP response at once by using data ``data``.

@@ -928,7 +927,7 @@ proc sendBody*[T: string|seq[byte]](resp: HttpResponseRef, data: T) {.async.} =
     raise exc
   except AsyncStreamWriteError, AsyncStreamIncompleteError:
     resp.state = HttpResponseState.Failed
-    raise newHttpCriticalError("Unable to send response")
+    raiseHttpCriticalError("Unable to send response")
 
 proc sendError*(resp: HttpResponseRef, code: HttpCode, body = "") {.async.} =
   ## Send HTTP error status response.

@@ -947,7 +946,7 @@ proc sendError*(resp: HttpResponseRef, code: HttpCode, body = "") {.async.} =
     raise exc
   except AsyncStreamWriteError, AsyncStreamIncompleteError:
     resp.state = HttpResponseState.Failed
-    raise newHttpCriticalError("Unable to send response")
+    raiseHttpCriticalError("Unable to send response")
 
 proc prepare*(resp: HttpResponseRef) {.async.} =
   ## Prepare for HTTP stream response.

@@ -966,16 +965,16 @@ proc prepare*(resp: HttpResponseRef) {.async.} =
     raise exc
   except AsyncStreamWriteError, AsyncStreamIncompleteError:
     resp.state = HttpResponseState.Failed
-    raise newHttpCriticalError("Unable to send response")
+    raiseHttpCriticalError("Unable to send response")
 
 proc sendChunk*(resp: HttpResponseRef, pbytes: pointer, nbytes: int) {.async.} =
   ## Send single chunk of data pointed by ``pbytes`` and ``nbytes``.
   doAssert(not(isNil(pbytes)), "pbytes must not be nil")
   doAssert(nbytes >= 0, "nbytes should be bigger or equal to zero")
   if HttpResponseFlags.Chunked notin resp.flags:
-    raise newHttpCriticalError("Response was not prepared")
+    raiseHttpCriticalError("Response was not prepared")
   if resp.state notin {HttpResponseState.Prepared, HttpResponseState.Sending}:
-    raise newHttpCriticalError("Response in incorrect state")
+    raiseHttpCriticalError("Response in incorrect state")
   try:
     resp.state = HttpResponseState.Sending
     await resp.chunkedWriter.write(pbytes, nbytes)

@@ -985,15 +984,15 @@ proc sendChunk*(resp: HttpResponseRef, pbytes: pointer, nbytes: int) {.async.} =
     raise exc
   except AsyncStreamWriteError, AsyncStreamIncompleteError:
     resp.state = HttpResponseState.Failed
-    raise newHttpCriticalError("Unable to send response")
+    raiseHttpCriticalError("Unable to send response")
 
 proc sendChunk*[T: string|seq[byte]](resp: HttpResponseRef,
                                      data: T) {.async.} =
   ## Send single chunk of data ``data``.
   if HttpResponseFlags.Chunked notin resp.flags:
-    raise newHttpCriticalError("Response was not prepared")
+    raiseHttpCriticalError("Response was not prepared")
   if resp.state notin {HttpResponseState.Prepared, HttpResponseState.Sending}:
-    raise newHttpCriticalError("Response in incorrect state")
+    raiseHttpCriticalError("Response in incorrect state")
   try:
     resp.state = HttpResponseState.Sending
     await resp.chunkedWriter.write(data)

@@ -1003,14 +1002,14 @@ proc sendChunk*[T: string|seq[byte]](resp: HttpResponseRef,
     raise exc
   except AsyncStreamWriteError, AsyncStreamIncompleteError:
     resp.state = HttpResponseState.Failed
-    raise newHttpCriticalError("Unable to send response")
+    raiseHttpCriticalError("Unable to send response")
 
 proc finish*(resp: HttpResponseRef) {.async.} =
   ## Sending last chunk of data, so it will indicate end of HTTP response.
   if HttpResponseFlags.Chunked notin resp.flags:
-    raise newHttpCriticalError("Response was not prepared")
+    raiseHttpCriticalError("Response was not prepared")
   if resp.state notin {HttpResponseState.Prepared, HttpResponseState.Sending}:
-    raise newHttpCriticalError("Response in incorrect state")
+    raiseHttpCriticalError("Response in incorrect state")
   try:
     resp.state = HttpResponseState.Sending
     await resp.chunkedWriter.finish()

@@ -1020,7 +1019,7 @@ proc finish*(resp: HttpResponseRef) {.async.} =
     raise exc
   except AsyncStreamWriteError, AsyncStreamIncompleteError:
     resp.state = HttpResponseState.Failed
-    raise newHttpCriticalError("Unable to send response")
+    raiseHttpCriticalError("Unable to send response")
 
 proc respond*(req: HttpRequestRef, code: HttpCode, content: string,
               headers: HttpTable): Future[HttpResponseRef] {.async.} =
@@ -149,14 +149,14 @@ proc readPart*(mpr: MultiPartReaderRef): Future[MultiPart] {.async.} =
       mpr.firstTime = false
       if not(startsWith(mpr.buffer.toOpenArray(0, len(mpr.boundary) - 3),
                         mpr.boundary.toOpenArray(2, len(mpr.boundary) - 1))):
-        raise newHttpCriticalError("Unexpected boundary encountered")
+        raiseHttpCriticalError("Unexpected boundary encountered")
     except CancelledError as exc:
       raise exc
     except AsyncStreamError:
       if mpr.stream.atBound():
-        raise newHttpCriticalError("Maximum size of body reached", Http413)
+        raiseHttpCriticalError("Maximum size of body reached", Http413)
       else:
-        raise newHttpCriticalError("Unable to read multipart body")
+        raiseHttpCriticalError("Unable to read multipart body")
 
     # Reading part's headers
     try:

@@ -170,9 +170,9 @@ proc readPart*(mpr: MultiPartReaderRef): Future[MultiPart] {.async.} =
         raise newException(MultipartEOMError,
                            "End of multipart message")
       else:
-        raise newHttpCriticalError("Incorrect multipart header found")
+        raiseHttpCriticalError("Incorrect multipart header found")
     if mpr.buffer[0] != 0x0D'u8 or mpr.buffer[1] != 0x0A'u8:
-      raise newHttpCriticalError("Incorrect multipart boundary found")
+      raiseHttpCriticalError("Incorrect multipart boundary found")
 
     # If two bytes are CRLF we are at the part beginning.
     # Reading part's headers

@@ -180,7 +180,7 @@ proc readPart*(mpr: MultiPartReaderRef): Future[MultiPart] {.async.} =
                                               HeadersMark)
       var headersList = parseHeaders(mpr.buffer.toOpenArray(0, res - 1), false)
       if headersList.failed():
-        raise newHttpCriticalError("Incorrect multipart's headers found")
+        raiseHttpCriticalError("Incorrect multipart's headers found")
       inc(mpr.counter)
 
       var part = MultiPart(

@@ -196,16 +196,16 @@ proc readPart*(mpr: MultiPartReaderRef): Future[MultiPart] {.async.} =
 
       let sres = part.setPartNames()
       if sres.isErr():
-        raise newHttpCriticalError(sres.error)
+        raiseHttpCriticalError($sres.error)
       return part
 
     except CancelledError as exc:
       raise exc
     except AsyncStreamError:
       if mpr.stream.atBound():
-        raise newHttpCriticalError("Maximum size of body reached", Http413)
+        raiseHttpCriticalError("Maximum size of body reached", Http413)
       else:
-        raise newHttpCriticalError("Unable to read multipart body")
+        raiseHttpCriticalError("Unable to read multipart body")
 
 proc atBound*(mp: MultiPart): bool =
   ## Returns ``true`` if MultiPart's stream reached request body maximum size.

@@ -220,9 +220,9 @@ proc getBody*(mp: MultiPart): Future[seq[byte]] {.async.} =
       return res
     except AsyncStreamError:
       if mp.breader.atBound():
-        raise newHttpCriticalError("Maximum size of body reached", Http413)
+        raiseHttpCriticalError("Maximum size of body reached", Http413)
       else:
-        raise newHttpCriticalError("Unable to read multipart body")
+        raiseHttpCriticalError("Unable to read multipart body")
   of MultiPartSource.Buffer:
     return mp.buffer

@@ -234,9 +234,9 @@ proc consumeBody*(mp: MultiPart) {.async.} =
      discard await mp.stream.consume()
    except AsyncStreamError:
      if mp.breader.atBound():
-        raise newHttpCriticalError("Maximum size of body reached", Http413)
+        raiseHttpCriticalError("Maximum size of body reached", Http413)
      else:
-        raise newHttpCriticalError("Unable to consume multipart body")
+        raiseHttpCriticalError("Unable to consume multipart body")
   of MultiPartSource.Buffer:
     discard
@@ -22,7 +22,7 @@ const
 
 type
   AsyncStreamError* = object of CatchableError
-  AsyncStreamIncorrectError* = object of Defect
+  AsyncStreamIncorrectDefect* = object of Defect
   AsyncStreamIncompleteError* = object of AsyncStreamError
   AsyncStreamLimitError* = object of AsyncStreamError
   AsyncStreamUseClosedError* = object of AsyncStreamError

@@ -179,36 +179,49 @@ template copyOut*(dest: pointer, item: WriteItem, length: int) =
   copyMem(dest, unsafeAddr item.data3[item.offset], length)
 
 proc newAsyncStreamReadError(p: ref CatchableError): ref AsyncStreamReadError {.
-     inline.} =
+     noinline.} =
   var w = newException(AsyncStreamReadError, "Read stream failed")
   w.msg = w.msg & ", originated from [" & $p.name & "] " & p.msg
   w.par = p
   w
 
 proc newAsyncStreamWriteError(p: ref CatchableError): ref AsyncStreamWriteError {.
-     inline.} =
+     noinline.} =
   var w = newException(AsyncStreamWriteError, "Write stream failed")
   w.msg = w.msg & ", originated from [" & $p.name & "] " & p.msg
   w.par = p
   w
 
 proc newAsyncStreamIncompleteError*(): ref AsyncStreamIncompleteError {.
-     inline.} =
+     noinline.} =
   newException(AsyncStreamIncompleteError, "Incomplete data sent or received")
 
-proc newAsyncStreamLimitError*(): ref AsyncStreamLimitError {.inline.} =
+proc newAsyncStreamLimitError*(): ref AsyncStreamLimitError {.noinline.} =
   newException(AsyncStreamLimitError, "Buffer limit reached")
 
-proc newAsyncStreamUseClosedError*(): ref AsyncStreamUseClosedError {.inline.} =
+proc newAsyncStreamUseClosedError*(): ref AsyncStreamUseClosedError {.
+     noinline.} =
   newException(AsyncStreamUseClosedError, "Stream is already closed")
 
-proc newAsyncStreamIncorrectError*(m: string): ref AsyncStreamIncorrectError {.
-    inline.} =
-  newException(AsyncStreamIncorrectError, m)
+proc raiseAsyncStreamUseClosedError*() {.noinline, noreturn.} =
+  raise newAsyncStreamUseClosedError()
+
+proc raiseAsyncStreamLimitError*() {.noinline, noreturn.} =
+  raise newAsyncStreamLimitError()
+
+proc raiseAsyncStreamIncompleteError*() {.noinline, noreturn.} =
+  raise newAsyncStreamIncompleteError()
+
+proc raiseAsyncStreamIncorrectDefect*(m: string) {.noinline, noreturn.} =
+  raise newException(AsyncStreamIncorrectDefect, m)
+
+proc raiseEmptyMessageDefect*() {.noinline, noreturn.} =
+  raise newException(AsyncStreamIncorrectDefect,
+                     "Could not write empty message")
 
 template checkStreamClosed*(t: untyped) =
   if t.state == AsyncStreamState.Closed:
-    raise newAsyncStreamUseClosedError()
+    raiseAsyncStreamUseClosedError()
 
 proc atEof*(rstream: AsyncStreamReader): bool =
   ## Returns ``true`` is reading stream is closed or finished and internal
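A short aside on the {.noreturn.} pragma used by the new raise helpers above (an illustrative sketch with made-up Demo* names, not the chronos procs): because the helper is declared noreturn, a branch may end with the call and the compiler knows control never falls through, so no extra raise or result value is needed.

# Sketch only: hypothetical names, shows the noreturn raise-helper pattern.
type DemoLimitError = object of CatchableError

proc newDemoLimitError(): ref DemoLimitError {.noinline.} =
  newException(DemoLimitError, "Buffer limit reached")

proc raiseDemoLimitError() {.noinline, noreturn.} =
  raise newDemoLimitError()

proc checkLen(n, limit: int): int =
  if n > limit:
    raiseDemoLimitError()  # noreturn: compiler knows control stops here
  n

echo checkLen(3, 10)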
@@ -677,7 +690,7 @@ proc write*(wstream: AsyncStreamWriter, pbytes: pointer,
   ## ``nbytes` must be more then zero.
   checkStreamClosed(wstream)
   if nbytes <= 0:
-    raise newAsyncStreamIncorrectError("Zero length message")
+    raiseEmptyMessageDefect()
 
   if isNil(wstream.wsource):
     var res: int

@@ -725,7 +738,7 @@ proc write*(wstream: AsyncStreamWriter, sbytes: seq[byte],
   checkStreamClosed(wstream)
   let length = if msglen <= 0: len(sbytes) else: min(msglen, len(sbytes))
   if length <= 0:
-    raise newAsyncStreamIncorrectError("Zero length message")
+    raiseEmptyMessageDefect()
 
   if isNil(wstream.wsource):
     var res: int

@@ -773,7 +786,7 @@ proc write*(wstream: AsyncStreamWriter, sbytes: string,
   checkStreamClosed(wstream)
   let length = if msglen <= 0: len(sbytes) else: min(msglen, len(sbytes))
   if length <= 0:
-    raise newAsyncStreamIncorrectError("Zero length message")
+    raiseEmptyMessageDefect()
 
   if isNil(wstream.wsource):
     var res: int

@@ -857,7 +870,7 @@ proc close*(rw: AsyncStreamRW) =
   ##
   ## Note close() procedure is not completed immediately!
   if rw.closed():
-    raise newAsyncStreamIncorrectError("Stream is already closed!")
+    raiseAsyncStreamIncorrectDefect("Stream is already closed!")
 
   rw.state = AsyncStreamState.Closed
@@ -42,11 +42,9 @@ type
 const
   BoundedBufferSize* = 4096
 
-template newBoundedStreamIncompleteError*(): ref BoundedStreamError =
+proc newBoundedStreamIncompleteError*(): ref BoundedStreamError {.noinline.} =
   newException(BoundedStreamIncompleteError,
                "Stream boundary is not reached yet")
-template newBoundedStreamOverflowError*(): ref BoundedStreamError =
-  newException(BoundedStreamOverflowError, "Stream boundary exceeded")
 
 proc readUntilBoundary*(rstream: AsyncStreamReader, pbytes: pointer,
                         nbytes: int, sep: seq[byte]): Future[int] {.async.} =

@@ -94,7 +92,7 @@ func endsWith(s, suffix: openarray[byte]): bool =
     if i >= len(suffix): return true
 
 proc boundedReadLoop(stream: AsyncStreamReader) {.async.} =
-  var rstream = cast[BoundedStreamReader](stream)
+  var rstream = BoundedStreamReader(stream)
   rstream.state = AsyncStreamState.Running
   var buffer = newSeq[byte](rstream.buffer.bufferLen())
   while true:

@@ -157,7 +155,7 @@ proc boundedReadLoop(stream: AsyncStreamReader) {.async.} =
         rstream.buffer.forget()
 
 proc boundedWriteLoop(stream: AsyncStreamWriter) {.async.} =
-  var wstream = cast[BoundedStreamWriter](stream)
+  var wstream = BoundedStreamWriter(stream)
 
   wstream.state = AsyncStreamState.Running
   while true:

@@ -181,7 +179,8 @@ proc boundedWriteLoop(stream: AsyncStreamWriter) {.async.} =
           item.future.complete()
         else:
           wstream.state = AsyncStreamState.Error
-          error = newBoundedStreamOverflowError()
+          error = newException(BoundedStreamOverflowError,
+                               "Stream boundary exceeded")
       else:
         if wstream.offset != wstream.boundSize:
           case wstream.cmpop

@@ -223,12 +222,12 @@ proc bytesLeft*(stream: BoundedStreamRW): uint64 =
 
 proc init*[T](child: BoundedStreamReader, rsource: AsyncStreamReader,
               bufferSize = BoundedBufferSize, udata: ref T) =
-  init(cast[AsyncStreamReader](child), rsource, boundedReadLoop, bufferSize,
+  init(AsyncStreamReader(child), rsource, boundedReadLoop, bufferSize,
        udata)
 
 proc init*(child: BoundedStreamReader, rsource: AsyncStreamReader,
            bufferSize = BoundedBufferSize) =
-  init(cast[AsyncStreamReader](child), rsource, boundedReadLoop, bufferSize)
+  init(AsyncStreamReader(child), rsource, boundedReadLoop, bufferSize)
 
 proc newBoundedStreamReader*[T](rsource: AsyncStreamReader,
                                 boundSize: int,

@@ -258,12 +257,12 @@ proc newBoundedStreamReader*(rsource: AsyncStreamReader,
 
 proc init*[T](child: BoundedStreamWriter, wsource: AsyncStreamWriter,
               queueSize = AsyncStreamDefaultQueueSize, udata: ref T) =
-  init(cast[AsyncStreamWriter](child), wsource, boundedWriteLoop, queueSize,
+  init(AsyncStreamWriter(child), wsource, boundedWriteLoop, queueSize,
        udata)
 
 proc init*(child: BoundedStreamWriter, wsource: AsyncStreamWriter,
            queueSize = AsyncStreamDefaultQueueSize) =
-  init(cast[AsyncStreamWriter](child), wsource, boundedWriteLoop, queueSize)
+  init(AsyncStreamWriter(child), wsource, boundedWriteLoop, queueSize)
 
 proc newBoundedStreamWriter*[T](wsource: AsyncStreamWriter,
                                 boundSize: int,
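Several hunks above (and below for the chunked and TLS streams) also swap cast[...](stream) for a plain type conversion such as BoundedStreamReader(stream). An illustrative sketch of the difference, with hypothetical Base/Child types rather than the chronos ones: a ref-object conversion is checked against the dynamic type at run time, while cast merely reinterprets the pointer.

# Sketch only: hypothetical types, shows checked ref-object conversion vs cast.
type
  Base = ref object of RootObj
  Child = ref object of Base
    tag: string

let b: Base = Child(tag: "ok")
# Conversion verifies the dynamic type (it raises an object-conversion
# defect if `b` were not really a Child); `cast` would skip that check.
let c = Child(b)
echo c.tag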
@@ -10,11 +10,15 @@
 ## This module implements HTTP/1.1 chunked-encoded stream reading and writing.
 import ../asyncloop, ../timer
 import asyncstream, ../transports/stream, ../transports/common
+import stew/results
 export asyncstream, stream, timer, common
 
 const
   ChunkBufferSize = 4096
   ChunkHeaderSize = 8
+    # This is limit for chunk size to 8 hexadecimal digits, so maximum
+    # chunk size for this implementation become:
+    # 2^32 == FFFF_FFFF'u32 == 4,294,967,295 bytes.
   CRLF = @[byte(0x0D), byte(0x0A)]
 
 type

@@ -25,12 +29,6 @@ type
   ChunkedStreamProtocolError* = object of ChunkedStreamError
   ChunkedStreamIncompleteError* = object of ChunkedStreamError
 
-proc newChunkedProtocolError(): ref ChunkedStreamProtocolError {.inline.} =
-  newException(ChunkedStreamProtocolError, "Protocol error!")
-
-proc newChunkedIncompleteError(): ref ChunkedStreamIncompleteError {.inline.} =
-  newException(ChunkedStreamIncompleteError, "Incomplete data received!")
-
 proc `-`(x: uint32): uint32 {.inline.} =
   result = (0xFFFF_FFFF'u32 - x) + 1'u32

@@ -47,18 +45,16 @@ proc hexValue(c: byte): int =
     ((z + 11'u32) and -LT(z, 6))
   int(r) - 1
 
-proc getChunkSize(buffer: openarray[byte]): uint64 =
+proc getChunkSize(buffer: openarray[byte]): Result[uint64, cstring] =
   # We using `uint64` representation, but allow only 2^32 chunk size,
   # ChunkHeaderSize.
   var res = 0'u64
-  for i in 0..<min(len(buffer), ChunkHeaderSize):
+  for i in 0 ..< min(len(buffer), ChunkHeaderSize):
     let value = hexValue(buffer[i])
-    if value >= 0:
-      res = (res shl 4) or uint64(value)
-    else:
-      res = 0xFFFF_FFFF_FFFF_FFFF'u64
-      break
-  res
+    if value < 0:
+      return err("Incorrect chunk size encoding")
+    res = (res shl 4) or uint64(value)
+  ok(res)
 
 proc setChunkSize(buffer: var openarray[byte], length: int64): int =
   # Store length as chunk header size (hexadecimal value) with CRLF.
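In the hunk above, getChunkSize now reports a bad hexadecimal digit through Result[uint64, cstring] from stew/results instead of returning 0xFFFF_FFFF_FFFF_FFFF as a sentinel. A standalone sketch of the same calling pattern follows; parseHexDigit is a made-up helper, not part of the diff.

# Sketch only: mirrors the Result[uint64, cstring] style used above.
import stew/results

proc parseHexDigit(c: char): Result[uint64, cstring] =
  case c
  of '0'..'9': return ok(uint64(ord(c) - ord('0')))
  of 'a'..'f': return ok(uint64(ord(c) - ord('a') + 10))
  of 'A'..'F': return ok(uint64(ord(c) - ord('A') + 10))
  else: return err("Incorrect chunk size encoding")

let res = parseHexDigit('C')
if res.isErr():
  echo "error: ", res.error   # callers branch on isErr()/get(),
else:                         # exactly as chunkedReadLoop does below
  echo "value: ", res.get()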
@@ -87,48 +83,54 @@ proc setChunkSize(buffer: var openarray[byte], length: int64): int =
     i = i - 4
   buffer[c] = byte(0x0D)
   buffer[c + 1] = byte(0x0A)
-  c + 2
+  (c + 2)
 
 proc chunkedReadLoop(stream: AsyncStreamReader) {.async.} =
-  var rstream = cast[ChunkedStreamReader](stream)
+  var rstream = ChunkedStreamReader(stream)
   var buffer = newSeq[byte](1024)
   rstream.state = AsyncStreamState.Running
 
   while true:
     try:
       # Reading chunk size
-      let res = await rstream.rsource.readUntil(addr buffer[0], 1024, CRLF)
-      var chunksize = getChunkSize(buffer.toOpenArray(0, res - len(CRLF) - 1))
+      let res = await rstream.rsource.readUntil(addr buffer[0], len(buffer),
+                                                CRLF)
+      let cres = getChunkSize(buffer.toOpenArray(0, res - len(CRLF) - 1))
 
-      if chunksize == 0xFFFF_FFFF_FFFF_FFFF'u64:
-        rstream.error = newChunkedProtocolError()
+      if cres.isErr():
+        rstream.error = newException(ChunkedStreamProtocolError, $cres.error)
         rstream.state = AsyncStreamState.Error
-      elif chunksize > 0'u64:
-        while chunksize > 0'u64:
-          let toRead = min(int(chunksize), rstream.buffer.bufferLen())
-          await rstream.rsource.readExactly(rstream.buffer.getBuffer(), toRead)
-          rstream.buffer.update(toRead)
-          await rstream.buffer.transfer()
-          chunksize = chunksize - uint64(toRead)
-
-        if rstream.state == AsyncStreamState.Running:
-          # Reading chunk trailing CRLF
-          await rstream.rsource.readExactly(addr buffer[0], 2)
-
-          if buffer[0] != CRLF[0] or buffer[1] != CRLF[1]:
-            rstream.error = newChunkedProtocolError()
-            rstream.state = AsyncStreamState.Error
       else:
-        # Reading trailing line for last chunk
-        discard await rstream.rsource.readUntil(addr buffer[0],
-                                                len(buffer), CRLF)
-        rstream.state = AsyncStreamState.Finished
-        await rstream.buffer.transfer()
+        var chunksize = cres.get()
+        if chunksize > 0'u64:
+          while chunksize > 0'u64:
+            let toRead = min(int(chunksize), rstream.buffer.bufferLen())
+            await rstream.rsource.readExactly(rstream.buffer.getBuffer(),
+                                              toRead)
+            rstream.buffer.update(toRead)
+            await rstream.buffer.transfer()
+            chunksize = chunksize - uint64(toRead)
+
+          if rstream.state == AsyncStreamState.Running:
+            # Reading chunk trailing CRLF
+            await rstream.rsource.readExactly(addr buffer[0], 2)
+
+            if buffer[0] != CRLF[0] or buffer[1] != CRLF[1]:
+              rstream.error = newException(ChunkedStreamProtocolError,
+                                           "Unexpected trailing bytes")
+              rstream.state = AsyncStreamState.Error
+        else:
+          # Reading trailing line for last chunk
+          discard await rstream.rsource.readUntil(addr buffer[0],
+                                                  len(buffer), CRLF)
+          rstream.state = AsyncStreamState.Finished
+          await rstream.buffer.transfer()
     except CancelledError:
       rstream.state = AsyncStreamState.Stopped
     except AsyncStreamIncompleteError:
       rstream.state = AsyncStreamState.Error
-      rstream.error = newChunkedIncompleteError()
+      rstream.error = newException(ChunkedStreamIncompleteError,
+                                   "Incomplete chunk received")
     except AsyncStreamReadError as exc:
       rstream.state = AsyncStreamState.Error
       rstream.error = exc

@@ -140,7 +142,7 @@ proc chunkedReadLoop(stream: AsyncStreamReader) {.async.} =
       break
 
 proc chunkedWriteLoop(stream: AsyncStreamWriter) {.async.} =
-  var wstream = cast[ChunkedStreamWriter](stream)
+  var wstream = ChunkedStreamWriter(stream)
   var buffer: array[16, byte]
   var error: ref AsyncStreamError
   wstream.state = AsyncStreamState.Running

@@ -200,12 +202,12 @@ proc chunkedWriteLoop(stream: AsyncStreamWriter) {.async.} =
 
 proc init*[T](child: ChunkedStreamReader, rsource: AsyncStreamReader,
               bufferSize = ChunkBufferSize, udata: ref T) =
-  init(cast[AsyncStreamReader](child), rsource, chunkedReadLoop, bufferSize,
+  init(AsyncStreamReader(child), rsource, chunkedReadLoop, bufferSize,
       udata)
 
 proc init*(child: ChunkedStreamReader, rsource: AsyncStreamReader,
           bufferSize = ChunkBufferSize) =
-  init(cast[AsyncStreamReader](child), rsource, chunkedReadLoop, bufferSize)
+  init(AsyncStreamReader(child), rsource, chunkedReadLoop, bufferSize)
 
 proc newChunkedStreamReader*[T](rsource: AsyncStreamReader,
                                 bufferSize = AsyncStreamDefaultBufferSize,

@@ -223,12 +225,12 @@ proc newChunkedStreamReader*(rsource: AsyncStreamReader,
 
 proc init*[T](child: ChunkedStreamWriter, wsource: AsyncStreamWriter,
               queueSize = AsyncStreamDefaultQueueSize, udata: ref T) =
-  init(cast[AsyncStreamWriter](child), wsource, chunkedWriteLoop, queueSize,
+  init(AsyncStreamWriter(child), wsource, chunkedWriteLoop, queueSize,
       udata)
 
 proc init*(child: ChunkedStreamWriter, wsource: AsyncStreamWriter,
           queueSize = AsyncStreamDefaultQueueSize) =
-  init(cast[AsyncStreamWriter](child), wsource, chunkedWriteLoop, queueSize)
+  init(AsyncStreamWriter(child), wsource, chunkedWriteLoop, queueSize)
 
 proc newChunkedStreamWriter*[T](wsource: AsyncStreamWriter,
                                 queueSize = AsyncStreamDefaultQueueSize,
@@ -91,6 +91,7 @@ type
 
   TLSStreamError* = object of AsyncStreamError
   TLSStreamHandshakeError* = object of TLSStreamError
+  TLSStreamInitError* = object of TLSStreamError
   TLSStreamReadError* = object of TLSStreamError
     par*: ref AsyncStreamError
   TLSStreamWriteError* = object of TLSStreamError

@@ -99,20 +100,20 @@ type
     errCode*: int
 
 proc newTLSStreamReadError(p: ref AsyncStreamError): ref TLSStreamReadError {.
-     inline.} =
+     noinline.} =
   var w = newException(TLSStreamReadError, "Read stream failed")
   w.msg = w.msg & ", originated from [" & $p.name & "] " & p.msg
   w.par = p
   w
 
 proc newTLSStreamWriteError(p: ref AsyncStreamError): ref TLSStreamWriteError {.
-     inline.} =
+     noinline.} =
   var w = newException(TLSStreamWriteError, "Write stream failed")
   w.msg = w.msg & ", originated from [" & $p.name & "] " & p.msg
   w.par = p
   w
 
-template newTLSStreamProtocolError[T](message: T): ref TLSStreamProtocolError =
+template newTLSStreamProtocolImpl[T](message: T): ref TLSStreamProtocolError =
   var msg = ""
   var code = 0
   when T is string:

@@ -129,6 +130,12 @@ template newTLSStreamProtocolError[T](message: T): ref TLSStreamProtocolError =
   err.errCode = code
   err
 
+proc newTLSStreamProtocolError[T](message: T): ref TLSStreamProtocolError =
+  newTLSStreamProtocolImpl(message)
+
+proc raiseTLSStreamProtocolError[T](message: T) {.noreturn, noinline.} =
+  raise newTLSStreamProtocolImpl(message)
+
 proc tlsWriteRec(engine: ptr SslEngineContext,
                  writer: TLSStreamWriter): Future[TLSResult] {.async.} =
   try:

@@ -208,9 +215,6 @@ proc tlsReadApp(engine: ptr SslEngineContext,
       reader.state = AsyncStreamState.Stopped
       return TLSResult.Error
 
-template raiseTLSStreamProtoError*[T](message: T) =
-  raise newTLSStreamProtocolError(message)
-
 template readAndReset(fut: untyped) =
   if fut.finished():
     let res = fut.read()

@@ -386,7 +390,7 @@ proc tlsLoop*(stream: TLSAsyncStream) {.async.} =
     stream.reader.buffer.forget()
 
 proc tlsWriteLoop(stream: AsyncStreamWriter) {.async.} =
-  var wstream = cast[TLSStreamWriter](stream)
+  var wstream = TLSStreamWriter(stream)
   wstream.state = AsyncStreamState.Running
   await stepsAsync(1)
   if isNil(wstream.stream.mainLoop):

@@ -394,7 +398,7 @@ proc tlsWriteLoop(stream: AsyncStreamWriter) {.async.} =
   await wstream.stream.mainLoop
 
 proc tlsReadLoop(stream: AsyncStreamReader) {.async.} =
-  var rstream = cast[TLSStreamReader](stream)
+  var rstream = TLSStreamReader(stream)
   rstream.state = AsyncStreamState.Running
   await stepsAsync(1)
   if isNil(rstream.stream.mainLoop):

@@ -468,18 +472,19 @@ proc newTLSClientAsyncStream*(rsource: AsyncStreamReader,
   if TLSFlags.NoVerifyServerName in flags:
     let err = sslClientReset(addr res.ccontext, "", 0)
     if err == 0:
-      raise newException(TLSStreamError, "Could not initialize TLS layer")
+      raise newException(TLSStreamInitError, "Could not initialize TLS layer")
   else:
     if len(serverName) == 0:
-      raise newException(TLSStreamError, "serverName must not be empty string")
+      raise newException(TLSStreamInitError,
+                         "serverName must not be empty string")
 
     let err = sslClientReset(addr res.ccontext, serverName, 0)
     if err == 0:
-      raise newException(TLSStreamError, "Could not initialize TLS layer")
+      raise newException(TLSStreamInitError, "Could not initialize TLS layer")
 
-  init(cast[AsyncStreamWriter](res.writer), wsource, tlsWriteLoop,
+  init(AsyncStreamWriter(res.writer), wsource, tlsWriteLoop,
        bufferSize)
-  init(cast[AsyncStreamReader](res.reader), rsource, tlsReadLoop,
+  init(AsyncStreamReader(res.reader), rsource, tlsReadLoop,
       bufferSize)
   res
 
@@ -507,9 +512,9 @@ proc newTLSServerAsyncStream*(rsource: AsyncStreamReader,
   ##
   ## ``flags`` - custom TLS connection flags.
   if isNil(privateKey) or privateKey.kind notin {TLSKeyType.RSA, TLSKeyType.EC}:
-    raiseTLSStreamProtoError("Incorrect private key")
+    raiseTLSStreamProtocolError("Incorrect private key")
   if isNil(certificate) or len(certificate.certs) == 0:
-    raiseTLSStreamProtoError("Incorrect certificate")
+    raiseTLSStreamProtocolError("Incorrect certificate")
 
   var res = TLSAsyncStream()
   var reader = TLSStreamReader(

@@ -528,7 +533,7 @@ proc newTLSServerAsyncStream*(rsource: AsyncStreamReader,
   if privateKey.kind == TLSKeyType.EC:
     let algo = getSignerAlgo(certificate.certs[0])
     if algo == -1:
-      raiseTLSStreamProtoError("Could not decode certificate")
+      raiseTLSStreamProtocolError("Could not decode certificate")
     sslServerInitFullEc(addr res.scontext, addr certificate.certs[0],
                         len(certificate.certs), cuint(algo),
                         addr privateKey.eckey)

@@ -557,11 +562,11 @@ proc newTLSServerAsyncStream*(rsource: AsyncStreamReader,
 
   let err = sslServerReset(addr res.scontext)
   if err == 0:
-    raise newException(TLSStreamError, "Could not initialize TLS layer")
+    raise newException(TLSStreamInitError, "Could not initialize TLS layer")
 
-  init(cast[AsyncStreamWriter](res.writer), wsource, tlsWriteLoop,
+  init(AsyncStreamWriter(res.writer), wsource, tlsWriteLoop,
       bufferSize)
-  init(cast[AsyncStreamReader](res.reader), rsource, tlsReadLoop,
+  init(AsyncStreamReader(res.reader), rsource, tlsReadLoop,
      bufferSize)
   res

@@ -610,12 +615,12 @@ proc init*(tt: typedesc[TLSPrivateKey], data: openarray[byte]): TLSPrivateKey =
   ## or wrapped in an unencrypted PKCS#8 archive (again DER-encoded).
   var ctx: SkeyDecoderContext
   if len(data) == 0:
-    raiseTLSStreamProtoError("Incorrect private key")
+    raiseTLSStreamProtocolError("Incorrect private key")
   skeyDecoderInit(addr ctx)
   skeyDecoderPush(addr ctx, cast[pointer](unsafeAddr data[0]), len(data))
   let err = skeyDecoderLastError(addr ctx)
   if err != 0:
-    raiseTLSStreamProtoError(err)
+    raiseTLSStreamProtocolError(err)
   let keyType = skeyDecoderKeyType(addr ctx)
   let res =
     if keyType == KEYTYPE_RSA:

@@ -623,13 +628,13 @@ proc init*(tt: typedesc[TLSPrivateKey], data: openarray[byte]): TLSPrivateKey =
     elif keyType == KEYTYPE_EC:
       copyKey(ctx.key.ec)
     else:
-      raiseTLSStreamProtoError("Unknown key type (" & $keyType & ")")
+      raiseTLSStreamProtocolError("Unknown key type (" & $keyType & ")")
   res
 
 proc pemDecode*(data: openarray[char]): seq[PEMElement] =
   ## Decode PEM encoded string and get array of binary blobs.
   if len(data) == 0:
-    raiseTLSStreamProtoError("Empty PEM message")
+    raiseTLSStreamProtocolError("Empty PEM message")
   var ctx: PemDecoderContext
   var pctx = new PEMContext
   var res = newSeq[PEMElement]()

@@ -666,7 +671,7 @@ proc pemDecode*(data: openarray[char]): seq[PEMElement] =
       else:
         break
     else:
-      raiseTLSStreamProtoError("Invalid PEM encoding")
+      raiseTLSStreamProtocolError("Invalid PEM encoding")
   res
 
 proc init*(tt: typedesc[TLSPrivateKey], data: openarray[char]): TLSPrivateKey =

@@ -683,7 +688,7 @@ proc init*(tt: typedesc[TLSPrivateKey], data: openarray[char]): TLSPrivateKey =
       res = TLSPrivateKey.init(item.data)
       break
   if isNil(res):
-    raiseTLSStreamProtoError("Could not find private key")
+    raiseTLSStreamProtocolError("Could not find private key")
   res
 
 proc init*(tt: typedesc[TLSCertificate],

@@ -703,12 +708,13 @@ proc init*(tt: typedesc[TLSCertificate],
     )
     let ares = getSignerAlgo(cert)
     if ares == -1:
-      raiseTLSStreamProtoError("Could not decode certificate")
+      raiseTLSStreamProtocolError("Could not decode certificate")
     elif ares != KEYTYPE_RSA and ares != KEYTYPE_EC:
-      raiseTLSStreamProtoError("Unsupported signing key type in certificate")
+      raiseTLSStreamProtocolError(
+        "Unsupported signing key type in certificate")
     res.certs.add(cert)
   if len(res.storage) == 0:
-    raiseTLSStreamProtoError("Could not find any certificates")
+    raiseTLSStreamProtocolError("Could not find any certificates")
   res
 
 proc init*(tt: typedesc[TLSSessionCache], size: int = 4096): TLSSessionCache =
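The TLS hunks above add a dedicated TLSStreamInitError subtype (of TLSStreamError) for setup failures such as a bad key, certificate, or empty serverName, so callers can separate configuration problems from other TLS stream errors. An illustrative sketch with stand-in Demo* types, not the chronos ones:

# Sketch only: hypothetical exception hierarchy mirroring the one above.
type
  DemoTLSError = object of CatchableError
  DemoTLSInitError = object of DemoTLSError      # bad key/certificate/serverName
  DemoTLSProtocolError = object of DemoTLSError  # failure while talking TLS

proc connect(serverName: string) =
  if serverName.len == 0:
    raise newException(DemoTLSInitError, "serverName must not be empty string")
  # ... handshake would go here ...

try:
  connect("")
except DemoTLSInitError as exc:
  echo "setup problem: ", exc.msg  # fix configuration rather than retrying
except DemoTLSError as exc:
  echo "TLS failure: ", exc.msg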
@@ -813,6 +813,7 @@ suite "HTTP server testing suite":
        res2.isOk()
        res2.get() == FlagsVectors[i]
        res3.isErr()
+       res4.isErr()
        res5.isOk()
        res5.get() == FlagsVectors[i]

@@ -864,6 +865,7 @@ suite "HTTP server testing suite":
        res2.isOk()
        res2.get() == FlagsVectors[i]
        res3.isErr()
+       res4.isErr()
        res5.isOk()
        res5.get() == FlagsVectors[i]