Address review comments.

This commit is contained in:
cheatfate 2021-02-17 02:03:12 +02:00 committed by zah
parent 6f8d06f12d
commit fc0d1bcb43
8 changed files with 168 additions and 160 deletions

View File

@ -63,28 +63,15 @@ proc atBound*(bstream: HttpBodyReader): bool {.
let breader = cast[BoundedStreamReader](lreader) let breader = cast[BoundedStreamReader](lreader)
breader.atEof() and (breader.bytesLeft() == 0) breader.atEof() and (breader.bytesLeft() == 0)
proc newHttpDefect*(msg: string): ref HttpDefect {. proc raiseHttpCriticalError*(msg: string,
raises: [HttpDefect].} = code = Http400) {.noinline, noreturn.} =
newException(HttpDefect, msg) raise (ref HttpCriticalError)(code: code, msg: msg)
proc newHttpCriticalError*(msg: string, proc raiseHttpDisconnectError*() {.noinline, noreturn.} =
code = Http400): ref HttpCriticalError {. raise (ref HttpDisconnectError)(msg: "Remote peer disconnected")
raises: [HttpCriticalError].} =
var tre = newException(HttpCriticalError, msg)
tre.code = code
tre
proc newHttpRecoverableError*(msg: string, proc raiseHttpDefect*(msg: string) {.noinline, noreturn.} =
code = Http400): ref HttpRecoverableError {. raise (ref HttpDefect)(msg: msg)
raises: [HttpRecoverableError].} =
var tre = newException(HttpRecoverableError, msg)
tre.code = code
tre
proc newHttpDisconnectError*(): ref HttpDisconnectError {.
raises: [HttpDisconnectError].} =
var tre = newException(HttpDisconnectError, "Remote peer disconnected")
tre
iterator queryParams*(query: string): tuple[key: string, value: string] {. iterator queryParams*(query: string): tuple[key: string, value: string] {.
raises: [Defect].} = raises: [Defect].} =

View File

@ -349,7 +349,7 @@ proc handleExpect*(request: HttpRequestRef) {.async.} =
except CancelledError as exc: except CancelledError as exc:
raise exc raise exc
except AsyncStreamWriteError, AsyncStreamIncompleteError: except AsyncStreamWriteError, AsyncStreamIncompleteError:
raise newHttpCriticalError("Unable to send `100-continue` response") raiseHttpCriticalError("Unable to send `100-continue` response")
proc getBody*(request: HttpRequestRef): Future[seq[byte]] {.async.} = proc getBody*(request: HttpRequestRef): Future[seq[byte]] {.async.} =
## Obtain request's body as sequence of bytes. ## Obtain request's body as sequence of bytes.
@ -363,9 +363,9 @@ proc getBody*(request: HttpRequestRef): Future[seq[byte]] {.async.} =
return await reader.read() return await reader.read()
except AsyncStreamError: except AsyncStreamError:
if reader.atBound(): if reader.atBound():
raise newHttpCriticalError("Maximum size of body reached", Http413) raiseHttpCriticalError("Maximum size of body reached", Http413)
else: else:
raise newHttpCriticalError("Unable to read request's body") raiseHttpCriticalError("Unable to read request's body")
finally: finally:
await closeWait(res.get()) await closeWait(res.get())
@ -381,9 +381,9 @@ proc consumeBody*(request: HttpRequestRef): Future[void] {.async.} =
discard await reader.consume() discard await reader.consume()
except AsyncStreamError: except AsyncStreamError:
if reader.atBound(): if reader.atBound():
raise newHttpCriticalError("Maximum size of body reached", Http413) raiseHttpCriticalError("Maximum size of body reached", Http413)
else: else:
raise newHttpCriticalError("Unable to read request's body") raiseHttpCriticalError("Unable to read request's body")
finally: finally:
await closeWait(res.get()) await closeWait(res.get())
@ -422,18 +422,17 @@ proc getRequest(conn: HttpConnectionRef): Future[HttpRequestRef] {.async.} =
conn.buffer.setLen(res) conn.buffer.setLen(res)
let header = parseRequest(conn.buffer) let header = parseRequest(conn.buffer)
if header.failed(): if header.failed():
raise newHttpCriticalError("Malformed request recieved") raiseHttpCriticalError("Malformed request recieved")
else: else:
let res = prepareRequest(conn, header) let res = prepareRequest(conn, header)
if res.isErr(): if res.isErr():
raise newHttpCriticalError("Invalid request received", res.error) raiseHttpCriticalError("Invalid request received", res.error)
else: else:
return res.get() return res.get()
except AsyncStreamIncompleteError, AsyncStreamReadError: except AsyncStreamIncompleteError, AsyncStreamReadError:
raise newHttpDisconnectError() raiseHttpDisconnectError()
except AsyncStreamLimitError: except AsyncStreamLimitError:
raise newHttpCriticalError("Maximum size of request headers reached", raiseHttpCriticalError("Maximum size of request headers reached", Http413)
Http413)
proc new(ht: typedesc[HttpConnectionRef], server: HttpServerRef, proc new(ht: typedesc[HttpConnectionRef], server: HttpServerRef,
transp: StreamTransport): HttpConnectionRef = transp: StreamTransport): HttpConnectionRef =
@ -503,7 +502,7 @@ proc createConnection(server: HttpServerRef,
raise exc raise exc
except TLSStreamError: except TLSStreamError:
await conn.closeWait() await conn.closeWait()
raise newHttpCriticalError("Unable to establish secure connection") raiseHttpCriticalError("Unable to establish secure connection")
proc processLoop(server: HttpServerRef, transp: StreamTransport) {.async.} = proc processLoop(server: HttpServerRef, transp: StreamTransport) {.async.} =
var var
@ -534,7 +533,7 @@ proc processLoop(server: HttpServerRef, transp: StreamTransport) {.async.} =
return return
except CatchableError as exc: except CatchableError as exc:
# There should be no exceptions, so we will raise `Defect`. # There should be no exceptions, so we will raise `Defect`.
raise newHttpDefect("Unexpected exception catched [" & $exc.name & "]") raiseHttpDefect("Unexpected exception catched [" & $exc.name & "]")
var breakLoop = false var breakLoop = false
while runLoop: while runLoop:
@ -763,7 +762,7 @@ proc post*(req: HttpRequestRef): Future[HttpTable] {.async.} =
var table = HttpTable.init() var table = HttpTable.init()
let res = getMultipartReader(req) let res = getMultipartReader(req)
if res.isErr(): if res.isErr():
raise newHttpCriticalError("Unable to retrieve multipart form data") raiseHttpCriticalError("Unable to retrieve multipart form data")
var mpreader = res.get() var mpreader = res.get()
# We must handle `Expect` first. # We must handle `Expect` first.
@ -808,10 +807,10 @@ proc post*(req: HttpRequestRef): Future[HttpTable] {.async.} =
else: else:
if HttpRequestFlags.BoundBody in req.requestFlags: if HttpRequestFlags.BoundBody in req.requestFlags:
if req.contentLength != 0: if req.contentLength != 0:
raise newHttpCriticalError("Unsupported request body") raiseHttpCriticalError("Unsupported request body")
return HttpTable.init() return HttpTable.init()
elif HttpRequestFlags.UnboundBody in req.requestFlags: elif HttpRequestFlags.UnboundBody in req.requestFlags:
raise newHttpCriticalError("Unsupported request body") raiseHttpCriticalError("Unsupported request body")
proc `keepalive=`*(resp: HttpResponseRef, value: bool) = proc `keepalive=`*(resp: HttpResponseRef, value: bool) =
doAssert(resp.state == HttpResponseState.Empty) doAssert(resp.state == HttpResponseState.Empty)
@ -854,7 +853,7 @@ template doHeaderVal(buf, name, value) =
template checkPending(t: untyped) = template checkPending(t: untyped) =
if t.state != HttpResponseState.Empty: if t.state != HttpResponseState.Empty:
raise newHttpCriticalError("Response body was already sent") raiseHttpCriticalError("Response body was already sent")
proc prepareLengthHeaders(resp: HttpResponseRef, length: int): string {. proc prepareLengthHeaders(resp: HttpResponseRef, length: int): string {.
raises: [Defect].}= raises: [Defect].}=
@ -910,7 +909,7 @@ proc sendBody*(resp: HttpResponseRef, pbytes: pointer, nbytes: int) {.async.} =
raise exc raise exc
except AsyncStreamWriteError, AsyncStreamIncompleteError: except AsyncStreamWriteError, AsyncStreamIncompleteError:
resp.state = HttpResponseState.Failed resp.state = HttpResponseState.Failed
raise newHttpCriticalError("Unable to send response") raiseHttpCriticalError("Unable to send response")
proc sendBody*[T: string|seq[byte]](resp: HttpResponseRef, data: T) {.async.} = proc sendBody*[T: string|seq[byte]](resp: HttpResponseRef, data: T) {.async.} =
## Send HTTP response at once by using data ``data``. ## Send HTTP response at once by using data ``data``.
@ -928,7 +927,7 @@ proc sendBody*[T: string|seq[byte]](resp: HttpResponseRef, data: T) {.async.} =
raise exc raise exc
except AsyncStreamWriteError, AsyncStreamIncompleteError: except AsyncStreamWriteError, AsyncStreamIncompleteError:
resp.state = HttpResponseState.Failed resp.state = HttpResponseState.Failed
raise newHttpCriticalError("Unable to send response") raiseHttpCriticalError("Unable to send response")
proc sendError*(resp: HttpResponseRef, code: HttpCode, body = "") {.async.} = proc sendError*(resp: HttpResponseRef, code: HttpCode, body = "") {.async.} =
## Send HTTP error status response. ## Send HTTP error status response.
@ -947,7 +946,7 @@ proc sendError*(resp: HttpResponseRef, code: HttpCode, body = "") {.async.} =
raise exc raise exc
except AsyncStreamWriteError, AsyncStreamIncompleteError: except AsyncStreamWriteError, AsyncStreamIncompleteError:
resp.state = HttpResponseState.Failed resp.state = HttpResponseState.Failed
raise newHttpCriticalError("Unable to send response") raiseHttpCriticalError("Unable to send response")
proc prepare*(resp: HttpResponseRef) {.async.} = proc prepare*(resp: HttpResponseRef) {.async.} =
## Prepare for HTTP stream response. ## Prepare for HTTP stream response.
@ -966,16 +965,16 @@ proc prepare*(resp: HttpResponseRef) {.async.} =
raise exc raise exc
except AsyncStreamWriteError, AsyncStreamIncompleteError: except AsyncStreamWriteError, AsyncStreamIncompleteError:
resp.state = HttpResponseState.Failed resp.state = HttpResponseState.Failed
raise newHttpCriticalError("Unable to send response") raiseHttpCriticalError("Unable to send response")
proc sendChunk*(resp: HttpResponseRef, pbytes: pointer, nbytes: int) {.async.} = proc sendChunk*(resp: HttpResponseRef, pbytes: pointer, nbytes: int) {.async.} =
## Send single chunk of data pointed by ``pbytes`` and ``nbytes``. ## Send single chunk of data pointed by ``pbytes`` and ``nbytes``.
doAssert(not(isNil(pbytes)), "pbytes must not be nil") doAssert(not(isNil(pbytes)), "pbytes must not be nil")
doAssert(nbytes >= 0, "nbytes should be bigger or equal to zero") doAssert(nbytes >= 0, "nbytes should be bigger or equal to zero")
if HttpResponseFlags.Chunked notin resp.flags: if HttpResponseFlags.Chunked notin resp.flags:
raise newHttpCriticalError("Response was not prepared") raiseHttpCriticalError("Response was not prepared")
if resp.state notin {HttpResponseState.Prepared, HttpResponseState.Sending}: if resp.state notin {HttpResponseState.Prepared, HttpResponseState.Sending}:
raise newHttpCriticalError("Response in incorrect state") raiseHttpCriticalError("Response in incorrect state")
try: try:
resp.state = HttpResponseState.Sending resp.state = HttpResponseState.Sending
await resp.chunkedWriter.write(pbytes, nbytes) await resp.chunkedWriter.write(pbytes, nbytes)
@ -985,15 +984,15 @@ proc sendChunk*(resp: HttpResponseRef, pbytes: pointer, nbytes: int) {.async.} =
raise exc raise exc
except AsyncStreamWriteError, AsyncStreamIncompleteError: except AsyncStreamWriteError, AsyncStreamIncompleteError:
resp.state = HttpResponseState.Failed resp.state = HttpResponseState.Failed
raise newHttpCriticalError("Unable to send response") raiseHttpCriticalError("Unable to send response")
proc sendChunk*[T: string|seq[byte]](resp: HttpResponseRef, proc sendChunk*[T: string|seq[byte]](resp: HttpResponseRef,
data: T) {.async.} = data: T) {.async.} =
## Send single chunk of data ``data``. ## Send single chunk of data ``data``.
if HttpResponseFlags.Chunked notin resp.flags: if HttpResponseFlags.Chunked notin resp.flags:
raise newHttpCriticalError("Response was not prepared") raiseHttpCriticalError("Response was not prepared")
if resp.state notin {HttpResponseState.Prepared, HttpResponseState.Sending}: if resp.state notin {HttpResponseState.Prepared, HttpResponseState.Sending}:
raise newHttpCriticalError("Response in incorrect state") raiseHttpCriticalError("Response in incorrect state")
try: try:
resp.state = HttpResponseState.Sending resp.state = HttpResponseState.Sending
await resp.chunkedWriter.write(data) await resp.chunkedWriter.write(data)
@ -1003,14 +1002,14 @@ proc sendChunk*[T: string|seq[byte]](resp: HttpResponseRef,
raise exc raise exc
except AsyncStreamWriteError, AsyncStreamIncompleteError: except AsyncStreamWriteError, AsyncStreamIncompleteError:
resp.state = HttpResponseState.Failed resp.state = HttpResponseState.Failed
raise newHttpCriticalError("Unable to send response") raiseHttpCriticalError("Unable to send response")
proc finish*(resp: HttpResponseRef) {.async.} = proc finish*(resp: HttpResponseRef) {.async.} =
## Sending last chunk of data, so it will indicate end of HTTP response. ## Sending last chunk of data, so it will indicate end of HTTP response.
if HttpResponseFlags.Chunked notin resp.flags: if HttpResponseFlags.Chunked notin resp.flags:
raise newHttpCriticalError("Response was not prepared") raiseHttpCriticalError("Response was not prepared")
if resp.state notin {HttpResponseState.Prepared, HttpResponseState.Sending}: if resp.state notin {HttpResponseState.Prepared, HttpResponseState.Sending}:
raise newHttpCriticalError("Response in incorrect state") raiseHttpCriticalError("Response in incorrect state")
try: try:
resp.state = HttpResponseState.Sending resp.state = HttpResponseState.Sending
await resp.chunkedWriter.finish() await resp.chunkedWriter.finish()
@ -1020,7 +1019,7 @@ proc finish*(resp: HttpResponseRef) {.async.} =
raise exc raise exc
except AsyncStreamWriteError, AsyncStreamIncompleteError: except AsyncStreamWriteError, AsyncStreamIncompleteError:
resp.state = HttpResponseState.Failed resp.state = HttpResponseState.Failed
raise newHttpCriticalError("Unable to send response") raiseHttpCriticalError("Unable to send response")
proc respond*(req: HttpRequestRef, code: HttpCode, content: string, proc respond*(req: HttpRequestRef, code: HttpCode, content: string,
headers: HttpTable): Future[HttpResponseRef] {.async.} = headers: HttpTable): Future[HttpResponseRef] {.async.} =

View File

@ -149,14 +149,14 @@ proc readPart*(mpr: MultiPartReaderRef): Future[MultiPart] {.async.} =
mpr.firstTime = false mpr.firstTime = false
if not(startsWith(mpr.buffer.toOpenArray(0, len(mpr.boundary) - 3), if not(startsWith(mpr.buffer.toOpenArray(0, len(mpr.boundary) - 3),
mpr.boundary.toOpenArray(2, len(mpr.boundary) - 1))): mpr.boundary.toOpenArray(2, len(mpr.boundary) - 1))):
raise newHttpCriticalError("Unexpected boundary encountered") raiseHttpCriticalError("Unexpected boundary encountered")
except CancelledError as exc: except CancelledError as exc:
raise exc raise exc
except AsyncStreamError: except AsyncStreamError:
if mpr.stream.atBound(): if mpr.stream.atBound():
raise newHttpCriticalError("Maximum size of body reached", Http413) raiseHttpCriticalError("Maximum size of body reached", Http413)
else: else:
raise newHttpCriticalError("Unable to read multipart body") raiseHttpCriticalError("Unable to read multipart body")
# Reading part's headers # Reading part's headers
try: try:
@ -170,9 +170,9 @@ proc readPart*(mpr: MultiPartReaderRef): Future[MultiPart] {.async.} =
raise newException(MultipartEOMError, raise newException(MultipartEOMError,
"End of multipart message") "End of multipart message")
else: else:
raise newHttpCriticalError("Incorrect multipart header found") raiseHttpCriticalError("Incorrect multipart header found")
if mpr.buffer[0] != 0x0D'u8 or mpr.buffer[1] != 0x0A'u8: if mpr.buffer[0] != 0x0D'u8 or mpr.buffer[1] != 0x0A'u8:
raise newHttpCriticalError("Incorrect multipart boundary found") raiseHttpCriticalError("Incorrect multipart boundary found")
# If two bytes are CRLF we are at the part beginning. # If two bytes are CRLF we are at the part beginning.
# Reading part's headers # Reading part's headers
@ -180,7 +180,7 @@ proc readPart*(mpr: MultiPartReaderRef): Future[MultiPart] {.async.} =
HeadersMark) HeadersMark)
var headersList = parseHeaders(mpr.buffer.toOpenArray(0, res - 1), false) var headersList = parseHeaders(mpr.buffer.toOpenArray(0, res - 1), false)
if headersList.failed(): if headersList.failed():
raise newHttpCriticalError("Incorrect multipart's headers found") raiseHttpCriticalError("Incorrect multipart's headers found")
inc(mpr.counter) inc(mpr.counter)
var part = MultiPart( var part = MultiPart(
@ -196,16 +196,16 @@ proc readPart*(mpr: MultiPartReaderRef): Future[MultiPart] {.async.} =
let sres = part.setPartNames() let sres = part.setPartNames()
if sres.isErr(): if sres.isErr():
raise newHttpCriticalError(sres.error) raiseHttpCriticalError($sres.error)
return part return part
except CancelledError as exc: except CancelledError as exc:
raise exc raise exc
except AsyncStreamError: except AsyncStreamError:
if mpr.stream.atBound(): if mpr.stream.atBound():
raise newHttpCriticalError("Maximum size of body reached", Http413) raiseHttpCriticalError("Maximum size of body reached", Http413)
else: else:
raise newHttpCriticalError("Unable to read multipart body") raiseHttpCriticalError("Unable to read multipart body")
proc atBound*(mp: MultiPart): bool = proc atBound*(mp: MultiPart): bool =
## Returns ``true`` if MultiPart's stream reached request body maximum size. ## Returns ``true`` if MultiPart's stream reached request body maximum size.
@ -220,9 +220,9 @@ proc getBody*(mp: MultiPart): Future[seq[byte]] {.async.} =
return res return res
except AsyncStreamError: except AsyncStreamError:
if mp.breader.atBound(): if mp.breader.atBound():
raise newHttpCriticalError("Maximum size of body reached", Http413) raiseHttpCriticalError("Maximum size of body reached", Http413)
else: else:
raise newHttpCriticalError("Unable to read multipart body") raiseHttpCriticalError("Unable to read multipart body")
of MultiPartSource.Buffer: of MultiPartSource.Buffer:
return mp.buffer return mp.buffer
@ -234,9 +234,9 @@ proc consumeBody*(mp: MultiPart) {.async.} =
discard await mp.stream.consume() discard await mp.stream.consume()
except AsyncStreamError: except AsyncStreamError:
if mp.breader.atBound(): if mp.breader.atBound():
raise newHttpCriticalError("Maximum size of body reached", Http413) raiseHttpCriticalError("Maximum size of body reached", Http413)
else: else:
raise newHttpCriticalError("Unable to consume multipart body") raiseHttpCriticalError("Unable to consume multipart body")
of MultiPartSource.Buffer: of MultiPartSource.Buffer:
discard discard

View File

@ -22,7 +22,7 @@ const
type type
AsyncStreamError* = object of CatchableError AsyncStreamError* = object of CatchableError
AsyncStreamIncorrectError* = object of Defect AsyncStreamIncorrectDefect* = object of Defect
AsyncStreamIncompleteError* = object of AsyncStreamError AsyncStreamIncompleteError* = object of AsyncStreamError
AsyncStreamLimitError* = object of AsyncStreamError AsyncStreamLimitError* = object of AsyncStreamError
AsyncStreamUseClosedError* = object of AsyncStreamError AsyncStreamUseClosedError* = object of AsyncStreamError
@ -179,36 +179,49 @@ template copyOut*(dest: pointer, item: WriteItem, length: int) =
copyMem(dest, unsafeAddr item.data3[item.offset], length) copyMem(dest, unsafeAddr item.data3[item.offset], length)
proc newAsyncStreamReadError(p: ref CatchableError): ref AsyncStreamReadError {. proc newAsyncStreamReadError(p: ref CatchableError): ref AsyncStreamReadError {.
inline.} = noinline.} =
var w = newException(AsyncStreamReadError, "Read stream failed") var w = newException(AsyncStreamReadError, "Read stream failed")
w.msg = w.msg & ", originated from [" & $p.name & "] " & p.msg w.msg = w.msg & ", originated from [" & $p.name & "] " & p.msg
w.par = p w.par = p
w w
proc newAsyncStreamWriteError(p: ref CatchableError): ref AsyncStreamWriteError {. proc newAsyncStreamWriteError(p: ref CatchableError): ref AsyncStreamWriteError {.
inline.} = noinline.} =
var w = newException(AsyncStreamWriteError, "Write stream failed") var w = newException(AsyncStreamWriteError, "Write stream failed")
w.msg = w.msg & ", originated from [" & $p.name & "] " & p.msg w.msg = w.msg & ", originated from [" & $p.name & "] " & p.msg
w.par = p w.par = p
w w
proc newAsyncStreamIncompleteError*(): ref AsyncStreamIncompleteError {. proc newAsyncStreamIncompleteError*(): ref AsyncStreamIncompleteError {.
inline.} = noinline.} =
newException(AsyncStreamIncompleteError, "Incomplete data sent or received") newException(AsyncStreamIncompleteError, "Incomplete data sent or received")
proc newAsyncStreamLimitError*(): ref AsyncStreamLimitError {.inline.} = proc newAsyncStreamLimitError*(): ref AsyncStreamLimitError {.noinline.} =
newException(AsyncStreamLimitError, "Buffer limit reached") newException(AsyncStreamLimitError, "Buffer limit reached")
proc newAsyncStreamUseClosedError*(): ref AsyncStreamUseClosedError {.inline.} = proc newAsyncStreamUseClosedError*(): ref AsyncStreamUseClosedError {.
noinline.} =
newException(AsyncStreamUseClosedError, "Stream is already closed") newException(AsyncStreamUseClosedError, "Stream is already closed")
proc newAsyncStreamIncorrectError*(m: string): ref AsyncStreamIncorrectError {. proc raiseAsyncStreamUseClosedError*() {.noinline, noreturn.} =
inline.} = raise newAsyncStreamUseClosedError()
newException(AsyncStreamIncorrectError, m)
proc raiseAsyncStreamLimitError*() {.noinline, noreturn.} =
raise newAsyncStreamLimitError()
proc raiseAsyncStreamIncompleteError*() {.noinline, noreturn.} =
raise newAsyncStreamIncompleteError()
proc raiseAsyncStreamIncorrectDefect*(m: string) {.noinline, noreturn.} =
raise newException(AsyncStreamIncorrectDefect, m)
proc raiseEmptyMessageDefect*() {.noinline, noreturn.} =
raise newException(AsyncStreamIncorrectDefect,
"Could not write empty message")
template checkStreamClosed*(t: untyped) = template checkStreamClosed*(t: untyped) =
if t.state == AsyncStreamState.Closed: if t.state == AsyncStreamState.Closed:
raise newAsyncStreamUseClosedError() raiseAsyncStreamUseClosedError()
proc atEof*(rstream: AsyncStreamReader): bool = proc atEof*(rstream: AsyncStreamReader): bool =
## Returns ``true`` is reading stream is closed or finished and internal ## Returns ``true`` is reading stream is closed or finished and internal
@ -677,7 +690,7 @@ proc write*(wstream: AsyncStreamWriter, pbytes: pointer,
## ``nbytes` must be more then zero. ## ``nbytes` must be more then zero.
checkStreamClosed(wstream) checkStreamClosed(wstream)
if nbytes <= 0: if nbytes <= 0:
raise newAsyncStreamIncorrectError("Zero length message") raiseEmptyMessageDefect()
if isNil(wstream.wsource): if isNil(wstream.wsource):
var res: int var res: int
@ -725,7 +738,7 @@ proc write*(wstream: AsyncStreamWriter, sbytes: seq[byte],
checkStreamClosed(wstream) checkStreamClosed(wstream)
let length = if msglen <= 0: len(sbytes) else: min(msglen, len(sbytes)) let length = if msglen <= 0: len(sbytes) else: min(msglen, len(sbytes))
if length <= 0: if length <= 0:
raise newAsyncStreamIncorrectError("Zero length message") raiseEmptyMessageDefect()
if isNil(wstream.wsource): if isNil(wstream.wsource):
var res: int var res: int
@ -773,7 +786,7 @@ proc write*(wstream: AsyncStreamWriter, sbytes: string,
checkStreamClosed(wstream) checkStreamClosed(wstream)
let length = if msglen <= 0: len(sbytes) else: min(msglen, len(sbytes)) let length = if msglen <= 0: len(sbytes) else: min(msglen, len(sbytes))
if length <= 0: if length <= 0:
raise newAsyncStreamIncorrectError("Zero length message") raiseEmptyMessageDefect()
if isNil(wstream.wsource): if isNil(wstream.wsource):
var res: int var res: int
@ -857,7 +870,7 @@ proc close*(rw: AsyncStreamRW) =
## ##
## Note close() procedure is not completed immediately! ## Note close() procedure is not completed immediately!
if rw.closed(): if rw.closed():
raise newAsyncStreamIncorrectError("Stream is already closed!") raiseAsyncStreamIncorrectDefect("Stream is already closed!")
rw.state = AsyncStreamState.Closed rw.state = AsyncStreamState.Closed

View File

@ -42,11 +42,9 @@ type
const const
BoundedBufferSize* = 4096 BoundedBufferSize* = 4096
template newBoundedStreamIncompleteError*(): ref BoundedStreamError = proc newBoundedStreamIncompleteError*(): ref BoundedStreamError {.noinline.} =
newException(BoundedStreamIncompleteError, newException(BoundedStreamIncompleteError,
"Stream boundary is not reached yet") "Stream boundary is not reached yet")
template newBoundedStreamOverflowError*(): ref BoundedStreamError =
newException(BoundedStreamOverflowError, "Stream boundary exceeded")
proc readUntilBoundary*(rstream: AsyncStreamReader, pbytes: pointer, proc readUntilBoundary*(rstream: AsyncStreamReader, pbytes: pointer,
nbytes: int, sep: seq[byte]): Future[int] {.async.} = nbytes: int, sep: seq[byte]): Future[int] {.async.} =
@ -94,7 +92,7 @@ func endsWith(s, suffix: openarray[byte]): bool =
if i >= len(suffix): return true if i >= len(suffix): return true
proc boundedReadLoop(stream: AsyncStreamReader) {.async.} = proc boundedReadLoop(stream: AsyncStreamReader) {.async.} =
var rstream = cast[BoundedStreamReader](stream) var rstream = BoundedStreamReader(stream)
rstream.state = AsyncStreamState.Running rstream.state = AsyncStreamState.Running
var buffer = newSeq[byte](rstream.buffer.bufferLen()) var buffer = newSeq[byte](rstream.buffer.bufferLen())
while true: while true:
@ -157,7 +155,7 @@ proc boundedReadLoop(stream: AsyncStreamReader) {.async.} =
rstream.buffer.forget() rstream.buffer.forget()
proc boundedWriteLoop(stream: AsyncStreamWriter) {.async.} = proc boundedWriteLoop(stream: AsyncStreamWriter) {.async.} =
var wstream = cast[BoundedStreamWriter](stream) var wstream = BoundedStreamWriter(stream)
wstream.state = AsyncStreamState.Running wstream.state = AsyncStreamState.Running
while true: while true:
@ -181,7 +179,8 @@ proc boundedWriteLoop(stream: AsyncStreamWriter) {.async.} =
item.future.complete() item.future.complete()
else: else:
wstream.state = AsyncStreamState.Error wstream.state = AsyncStreamState.Error
error = newBoundedStreamOverflowError() error = newException(BoundedStreamOverflowError,
"Stream boundary exceeded")
else: else:
if wstream.offset != wstream.boundSize: if wstream.offset != wstream.boundSize:
case wstream.cmpop case wstream.cmpop
@ -223,12 +222,12 @@ proc bytesLeft*(stream: BoundedStreamRW): uint64 =
proc init*[T](child: BoundedStreamReader, rsource: AsyncStreamReader, proc init*[T](child: BoundedStreamReader, rsource: AsyncStreamReader,
bufferSize = BoundedBufferSize, udata: ref T) = bufferSize = BoundedBufferSize, udata: ref T) =
init(cast[AsyncStreamReader](child), rsource, boundedReadLoop, bufferSize, init(AsyncStreamReader(child), rsource, boundedReadLoop, bufferSize,
udata) udata)
proc init*(child: BoundedStreamReader, rsource: AsyncStreamReader, proc init*(child: BoundedStreamReader, rsource: AsyncStreamReader,
bufferSize = BoundedBufferSize) = bufferSize = BoundedBufferSize) =
init(cast[AsyncStreamReader](child), rsource, boundedReadLoop, bufferSize) init(AsyncStreamReader(child), rsource, boundedReadLoop, bufferSize)
proc newBoundedStreamReader*[T](rsource: AsyncStreamReader, proc newBoundedStreamReader*[T](rsource: AsyncStreamReader,
boundSize: int, boundSize: int,
@ -258,12 +257,12 @@ proc newBoundedStreamReader*(rsource: AsyncStreamReader,
proc init*[T](child: BoundedStreamWriter, wsource: AsyncStreamWriter, proc init*[T](child: BoundedStreamWriter, wsource: AsyncStreamWriter,
queueSize = AsyncStreamDefaultQueueSize, udata: ref T) = queueSize = AsyncStreamDefaultQueueSize, udata: ref T) =
init(cast[AsyncStreamWriter](child), wsource, boundedWriteLoop, queueSize, init(AsyncStreamWriter(child), wsource, boundedWriteLoop, queueSize,
udata) udata)
proc init*(child: BoundedStreamWriter, wsource: AsyncStreamWriter, proc init*(child: BoundedStreamWriter, wsource: AsyncStreamWriter,
queueSize = AsyncStreamDefaultQueueSize) = queueSize = AsyncStreamDefaultQueueSize) =
init(cast[AsyncStreamWriter](child), wsource, boundedWriteLoop, queueSize) init(AsyncStreamWriter(child), wsource, boundedWriteLoop, queueSize)
proc newBoundedStreamWriter*[T](wsource: AsyncStreamWriter, proc newBoundedStreamWriter*[T](wsource: AsyncStreamWriter,
boundSize: int, boundSize: int,

View File

@ -10,11 +10,15 @@
## This module implements HTTP/1.1 chunked-encoded stream reading and writing. ## This module implements HTTP/1.1 chunked-encoded stream reading and writing.
import ../asyncloop, ../timer import ../asyncloop, ../timer
import asyncstream, ../transports/stream, ../transports/common import asyncstream, ../transports/stream, ../transports/common
import stew/results
export asyncstream, stream, timer, common export asyncstream, stream, timer, common
const const
ChunkBufferSize = 4096 ChunkBufferSize = 4096
ChunkHeaderSize = 8 ChunkHeaderSize = 8
# This is limit for chunk size to 8 hexadecimal digits, so maximum
# chunk size for this implementation become:
# 2^32 == FFFF_FFFF'u32 == 4,294,967,295 bytes.
CRLF = @[byte(0x0D), byte(0x0A)] CRLF = @[byte(0x0D), byte(0x0A)]
type type
@ -25,12 +29,6 @@ type
ChunkedStreamProtocolError* = object of ChunkedStreamError ChunkedStreamProtocolError* = object of ChunkedStreamError
ChunkedStreamIncompleteError* = object of ChunkedStreamError ChunkedStreamIncompleteError* = object of ChunkedStreamError
proc newChunkedProtocolError(): ref ChunkedStreamProtocolError {.inline.} =
newException(ChunkedStreamProtocolError, "Protocol error!")
proc newChunkedIncompleteError(): ref ChunkedStreamIncompleteError {.inline.} =
newException(ChunkedStreamIncompleteError, "Incomplete data received!")
proc `-`(x: uint32): uint32 {.inline.} = proc `-`(x: uint32): uint32 {.inline.} =
result = (0xFFFF_FFFF'u32 - x) + 1'u32 result = (0xFFFF_FFFF'u32 - x) + 1'u32
@ -47,18 +45,16 @@ proc hexValue(c: byte): int =
((z + 11'u32) and -LT(z, 6)) ((z + 11'u32) and -LT(z, 6))
int(r) - 1 int(r) - 1
proc getChunkSize(buffer: openarray[byte]): uint64 = proc getChunkSize(buffer: openarray[byte]): Result[uint64, cstring] =
# We using `uint64` representation, but allow only 2^32 chunk size, # We using `uint64` representation, but allow only 2^32 chunk size,
# ChunkHeaderSize. # ChunkHeaderSize.
var res = 0'u64 var res = 0'u64
for i in 0..<min(len(buffer), ChunkHeaderSize): for i in 0 ..< min(len(buffer), ChunkHeaderSize):
let value = hexValue(buffer[i]) let value = hexValue(buffer[i])
if value >= 0: if value < 0:
res = (res shl 4) or uint64(value) return err("Incorrect chunk size encoding")
else: res = (res shl 4) or uint64(value)
res = 0xFFFF_FFFF_FFFF_FFFF'u64 ok(res)
break
res
proc setChunkSize(buffer: var openarray[byte], length: int64): int = proc setChunkSize(buffer: var openarray[byte], length: int64): int =
# Store length as chunk header size (hexadecimal value) with CRLF. # Store length as chunk header size (hexadecimal value) with CRLF.
@ -87,48 +83,54 @@ proc setChunkSize(buffer: var openarray[byte], length: int64): int =
i = i - 4 i = i - 4
buffer[c] = byte(0x0D) buffer[c] = byte(0x0D)
buffer[c + 1] = byte(0x0A) buffer[c + 1] = byte(0x0A)
c + 2 (c + 2)
proc chunkedReadLoop(stream: AsyncStreamReader) {.async.} = proc chunkedReadLoop(stream: AsyncStreamReader) {.async.} =
var rstream = cast[ChunkedStreamReader](stream) var rstream = ChunkedStreamReader(stream)
var buffer = newSeq[byte](1024) var buffer = newSeq[byte](1024)
rstream.state = AsyncStreamState.Running rstream.state = AsyncStreamState.Running
while true: while true:
try: try:
# Reading chunk size # Reading chunk size
let res = await rstream.rsource.readUntil(addr buffer[0], 1024, CRLF) let res = await rstream.rsource.readUntil(addr buffer[0], len(buffer),
var chunksize = getChunkSize(buffer.toOpenArray(0, res - len(CRLF) - 1)) CRLF)
let cres = getChunkSize(buffer.toOpenArray(0, res - len(CRLF) - 1))
if chunksize == 0xFFFF_FFFF_FFFF_FFFF'u64: if cres.isErr():
rstream.error = newChunkedProtocolError() rstream.error = newException(ChunkedStreamProtocolError, $cres.error)
rstream.state = AsyncStreamState.Error rstream.state = AsyncStreamState.Error
elif chunksize > 0'u64:
while chunksize > 0'u64:
let toRead = min(int(chunksize), rstream.buffer.bufferLen())
await rstream.rsource.readExactly(rstream.buffer.getBuffer(), toRead)
rstream.buffer.update(toRead)
await rstream.buffer.transfer()
chunksize = chunksize - uint64(toRead)
if rstream.state == AsyncStreamState.Running:
# Reading chunk trailing CRLF
await rstream.rsource.readExactly(addr buffer[0], 2)
if buffer[0] != CRLF[0] or buffer[1] != CRLF[1]:
rstream.error = newChunkedProtocolError()
rstream.state = AsyncStreamState.Error
else: else:
# Reading trailing line for last chunk var chunksize = cres.get()
discard await rstream.rsource.readUntil(addr buffer[0], if chunksize > 0'u64:
len(buffer), CRLF) while chunksize > 0'u64:
rstream.state = AsyncStreamState.Finished let toRead = min(int(chunksize), rstream.buffer.bufferLen())
await rstream.buffer.transfer() await rstream.rsource.readExactly(rstream.buffer.getBuffer(),
toRead)
rstream.buffer.update(toRead)
await rstream.buffer.transfer()
chunksize = chunksize - uint64(toRead)
if rstream.state == AsyncStreamState.Running:
# Reading chunk trailing CRLF
await rstream.rsource.readExactly(addr buffer[0], 2)
if buffer[0] != CRLF[0] or buffer[1] != CRLF[1]:
rstream.error = newException(ChunkedStreamProtocolError,
"Unexpected trailing bytes")
rstream.state = AsyncStreamState.Error
else:
# Reading trailing line for last chunk
discard await rstream.rsource.readUntil(addr buffer[0],
len(buffer), CRLF)
rstream.state = AsyncStreamState.Finished
await rstream.buffer.transfer()
except CancelledError: except CancelledError:
rstream.state = AsyncStreamState.Stopped rstream.state = AsyncStreamState.Stopped
except AsyncStreamIncompleteError: except AsyncStreamIncompleteError:
rstream.state = AsyncStreamState.Error rstream.state = AsyncStreamState.Error
rstream.error = newChunkedIncompleteError() rstream.error = newException(ChunkedStreamIncompleteError,
"Incomplete chunk received")
except AsyncStreamReadError as exc: except AsyncStreamReadError as exc:
rstream.state = AsyncStreamState.Error rstream.state = AsyncStreamState.Error
rstream.error = exc rstream.error = exc
@ -140,7 +142,7 @@ proc chunkedReadLoop(stream: AsyncStreamReader) {.async.} =
break break
proc chunkedWriteLoop(stream: AsyncStreamWriter) {.async.} = proc chunkedWriteLoop(stream: AsyncStreamWriter) {.async.} =
var wstream = cast[ChunkedStreamWriter](stream) var wstream = ChunkedStreamWriter(stream)
var buffer: array[16, byte] var buffer: array[16, byte]
var error: ref AsyncStreamError var error: ref AsyncStreamError
wstream.state = AsyncStreamState.Running wstream.state = AsyncStreamState.Running
@ -200,12 +202,12 @@ proc chunkedWriteLoop(stream: AsyncStreamWriter) {.async.} =
proc init*[T](child: ChunkedStreamReader, rsource: AsyncStreamReader, proc init*[T](child: ChunkedStreamReader, rsource: AsyncStreamReader,
bufferSize = ChunkBufferSize, udata: ref T) = bufferSize = ChunkBufferSize, udata: ref T) =
init(cast[AsyncStreamReader](child), rsource, chunkedReadLoop, bufferSize, init(AsyncStreamReader(child), rsource, chunkedReadLoop, bufferSize,
udata) udata)
proc init*(child: ChunkedStreamReader, rsource: AsyncStreamReader, proc init*(child: ChunkedStreamReader, rsource: AsyncStreamReader,
bufferSize = ChunkBufferSize) = bufferSize = ChunkBufferSize) =
init(cast[AsyncStreamReader](child), rsource, chunkedReadLoop, bufferSize) init(AsyncStreamReader(child), rsource, chunkedReadLoop, bufferSize)
proc newChunkedStreamReader*[T](rsource: AsyncStreamReader, proc newChunkedStreamReader*[T](rsource: AsyncStreamReader,
bufferSize = AsyncStreamDefaultBufferSize, bufferSize = AsyncStreamDefaultBufferSize,
@ -223,12 +225,12 @@ proc newChunkedStreamReader*(rsource: AsyncStreamReader,
proc init*[T](child: ChunkedStreamWriter, wsource: AsyncStreamWriter, proc init*[T](child: ChunkedStreamWriter, wsource: AsyncStreamWriter,
queueSize = AsyncStreamDefaultQueueSize, udata: ref T) = queueSize = AsyncStreamDefaultQueueSize, udata: ref T) =
init(cast[AsyncStreamWriter](child), wsource, chunkedWriteLoop, queueSize, init(AsyncStreamWriter(child), wsource, chunkedWriteLoop, queueSize,
udata) udata)
proc init*(child: ChunkedStreamWriter, wsource: AsyncStreamWriter, proc init*(child: ChunkedStreamWriter, wsource: AsyncStreamWriter,
queueSize = AsyncStreamDefaultQueueSize) = queueSize = AsyncStreamDefaultQueueSize) =
init(cast[AsyncStreamWriter](child), wsource, chunkedWriteLoop, queueSize) init(AsyncStreamWriter(child), wsource, chunkedWriteLoop, queueSize)
proc newChunkedStreamWriter*[T](wsource: AsyncStreamWriter, proc newChunkedStreamWriter*[T](wsource: AsyncStreamWriter,
queueSize = AsyncStreamDefaultQueueSize, queueSize = AsyncStreamDefaultQueueSize,

View File

@ -91,6 +91,7 @@ type
TLSStreamError* = object of AsyncStreamError TLSStreamError* = object of AsyncStreamError
TLSStreamHandshakeError* = object of TLSStreamError TLSStreamHandshakeError* = object of TLSStreamError
TLSStreamInitError* = object of TLSStreamError
TLSStreamReadError* = object of TLSStreamError TLSStreamReadError* = object of TLSStreamError
par*: ref AsyncStreamError par*: ref AsyncStreamError
TLSStreamWriteError* = object of TLSStreamError TLSStreamWriteError* = object of TLSStreamError
@ -99,20 +100,20 @@ type
errCode*: int errCode*: int
proc newTLSStreamReadError(p: ref AsyncStreamError): ref TLSStreamReadError {. proc newTLSStreamReadError(p: ref AsyncStreamError): ref TLSStreamReadError {.
inline.} = noinline.} =
var w = newException(TLSStreamReadError, "Read stream failed") var w = newException(TLSStreamReadError, "Read stream failed")
w.msg = w.msg & ", originated from [" & $p.name & "] " & p.msg w.msg = w.msg & ", originated from [" & $p.name & "] " & p.msg
w.par = p w.par = p
w w
proc newTLSStreamWriteError(p: ref AsyncStreamError): ref TLSStreamWriteError {. proc newTLSStreamWriteError(p: ref AsyncStreamError): ref TLSStreamWriteError {.
inline.} = noinline.} =
var w = newException(TLSStreamWriteError, "Write stream failed") var w = newException(TLSStreamWriteError, "Write stream failed")
w.msg = w.msg & ", originated from [" & $p.name & "] " & p.msg w.msg = w.msg & ", originated from [" & $p.name & "] " & p.msg
w.par = p w.par = p
w w
template newTLSStreamProtocolError[T](message: T): ref TLSStreamProtocolError = template newTLSStreamProtocolImpl[T](message: T): ref TLSStreamProtocolError =
var msg = "" var msg = ""
var code = 0 var code = 0
when T is string: when T is string:
@ -129,6 +130,12 @@ template newTLSStreamProtocolError[T](message: T): ref TLSStreamProtocolError =
err.errCode = code err.errCode = code
err err
proc newTLSStreamProtocolError[T](message: T): ref TLSStreamProtocolError =
newTLSStreamProtocolImpl(message)
proc raiseTLSStreamProtocolError[T](message: T) {.noreturn, noinline.} =
raise newTLSStreamProtocolImpl(message)
proc tlsWriteRec(engine: ptr SslEngineContext, proc tlsWriteRec(engine: ptr SslEngineContext,
writer: TLSStreamWriter): Future[TLSResult] {.async.} = writer: TLSStreamWriter): Future[TLSResult] {.async.} =
try: try:
@ -208,9 +215,6 @@ proc tlsReadApp(engine: ptr SslEngineContext,
reader.state = AsyncStreamState.Stopped reader.state = AsyncStreamState.Stopped
return TLSResult.Error return TLSResult.Error
template raiseTLSStreamProtoError*[T](message: T) =
raise newTLSStreamProtocolError(message)
template readAndReset(fut: untyped) = template readAndReset(fut: untyped) =
if fut.finished(): if fut.finished():
let res = fut.read() let res = fut.read()
@ -386,7 +390,7 @@ proc tlsLoop*(stream: TLSAsyncStream) {.async.} =
stream.reader.buffer.forget() stream.reader.buffer.forget()
proc tlsWriteLoop(stream: AsyncStreamWriter) {.async.} = proc tlsWriteLoop(stream: AsyncStreamWriter) {.async.} =
var wstream = cast[TLSStreamWriter](stream) var wstream = TLSStreamWriter(stream)
wstream.state = AsyncStreamState.Running wstream.state = AsyncStreamState.Running
await stepsAsync(1) await stepsAsync(1)
if isNil(wstream.stream.mainLoop): if isNil(wstream.stream.mainLoop):
@ -394,7 +398,7 @@ proc tlsWriteLoop(stream: AsyncStreamWriter) {.async.} =
await wstream.stream.mainLoop await wstream.stream.mainLoop
proc tlsReadLoop(stream: AsyncStreamReader) {.async.} = proc tlsReadLoop(stream: AsyncStreamReader) {.async.} =
var rstream = cast[TLSStreamReader](stream) var rstream = TLSStreamReader(stream)
rstream.state = AsyncStreamState.Running rstream.state = AsyncStreamState.Running
await stepsAsync(1) await stepsAsync(1)
if isNil(rstream.stream.mainLoop): if isNil(rstream.stream.mainLoop):
@ -468,18 +472,19 @@ proc newTLSClientAsyncStream*(rsource: AsyncStreamReader,
if TLSFlags.NoVerifyServerName in flags: if TLSFlags.NoVerifyServerName in flags:
let err = sslClientReset(addr res.ccontext, "", 0) let err = sslClientReset(addr res.ccontext, "", 0)
if err == 0: if err == 0:
raise newException(TLSStreamError, "Could not initialize TLS layer") raise newException(TLSStreamInitError, "Could not initialize TLS layer")
else: else:
if len(serverName) == 0: if len(serverName) == 0:
raise newException(TLSStreamError, "serverName must not be empty string") raise newException(TLSStreamInitError,
"serverName must not be empty string")
let err = sslClientReset(addr res.ccontext, serverName, 0) let err = sslClientReset(addr res.ccontext, serverName, 0)
if err == 0: if err == 0:
raise newException(TLSStreamError, "Could not initialize TLS layer") raise newException(TLSStreamInitError, "Could not initialize TLS layer")
init(cast[AsyncStreamWriter](res.writer), wsource, tlsWriteLoop, init(AsyncStreamWriter(res.writer), wsource, tlsWriteLoop,
bufferSize) bufferSize)
init(cast[AsyncStreamReader](res.reader), rsource, tlsReadLoop, init(AsyncStreamReader(res.reader), rsource, tlsReadLoop,
bufferSize) bufferSize)
res res
@ -507,9 +512,9 @@ proc newTLSServerAsyncStream*(rsource: AsyncStreamReader,
## ##
## ``flags`` - custom TLS connection flags. ## ``flags`` - custom TLS connection flags.
if isNil(privateKey) or privateKey.kind notin {TLSKeyType.RSA, TLSKeyType.EC}: if isNil(privateKey) or privateKey.kind notin {TLSKeyType.RSA, TLSKeyType.EC}:
raiseTLSStreamProtoError("Incorrect private key") raiseTLSStreamProtocolError("Incorrect private key")
if isNil(certificate) or len(certificate.certs) == 0: if isNil(certificate) or len(certificate.certs) == 0:
raiseTLSStreamProtoError("Incorrect certificate") raiseTLSStreamProtocolError("Incorrect certificate")
var res = TLSAsyncStream() var res = TLSAsyncStream()
var reader = TLSStreamReader( var reader = TLSStreamReader(
@ -528,7 +533,7 @@ proc newTLSServerAsyncStream*(rsource: AsyncStreamReader,
if privateKey.kind == TLSKeyType.EC: if privateKey.kind == TLSKeyType.EC:
let algo = getSignerAlgo(certificate.certs[0]) let algo = getSignerAlgo(certificate.certs[0])
if algo == -1: if algo == -1:
raiseTLSStreamProtoError("Could not decode certificate") raiseTLSStreamProtocolError("Could not decode certificate")
sslServerInitFullEc(addr res.scontext, addr certificate.certs[0], sslServerInitFullEc(addr res.scontext, addr certificate.certs[0],
len(certificate.certs), cuint(algo), len(certificate.certs), cuint(algo),
addr privateKey.eckey) addr privateKey.eckey)
@ -557,11 +562,11 @@ proc newTLSServerAsyncStream*(rsource: AsyncStreamReader,
let err = sslServerReset(addr res.scontext) let err = sslServerReset(addr res.scontext)
if err == 0: if err == 0:
raise newException(TLSStreamError, "Could not initialize TLS layer") raise newException(TLSStreamInitError, "Could not initialize TLS layer")
init(cast[AsyncStreamWriter](res.writer), wsource, tlsWriteLoop, init(AsyncStreamWriter(res.writer), wsource, tlsWriteLoop,
bufferSize) bufferSize)
init(cast[AsyncStreamReader](res.reader), rsource, tlsReadLoop, init(AsyncStreamReader(res.reader), rsource, tlsReadLoop,
bufferSize) bufferSize)
res res
@ -610,12 +615,12 @@ proc init*(tt: typedesc[TLSPrivateKey], data: openarray[byte]): TLSPrivateKey =
## or wrapped in an unencrypted PKCS#8 archive (again DER-encoded). ## or wrapped in an unencrypted PKCS#8 archive (again DER-encoded).
var ctx: SkeyDecoderContext var ctx: SkeyDecoderContext
if len(data) == 0: if len(data) == 0:
raiseTLSStreamProtoError("Incorrect private key") raiseTLSStreamProtocolError("Incorrect private key")
skeyDecoderInit(addr ctx) skeyDecoderInit(addr ctx)
skeyDecoderPush(addr ctx, cast[pointer](unsafeAddr data[0]), len(data)) skeyDecoderPush(addr ctx, cast[pointer](unsafeAddr data[0]), len(data))
let err = skeyDecoderLastError(addr ctx) let err = skeyDecoderLastError(addr ctx)
if err != 0: if err != 0:
raiseTLSStreamProtoError(err) raiseTLSStreamProtocolError(err)
let keyType = skeyDecoderKeyType(addr ctx) let keyType = skeyDecoderKeyType(addr ctx)
let res = let res =
if keyType == KEYTYPE_RSA: if keyType == KEYTYPE_RSA:
@ -623,13 +628,13 @@ proc init*(tt: typedesc[TLSPrivateKey], data: openarray[byte]): TLSPrivateKey =
elif keyType == KEYTYPE_EC: elif keyType == KEYTYPE_EC:
copyKey(ctx.key.ec) copyKey(ctx.key.ec)
else: else:
raiseTLSStreamProtoError("Unknown key type (" & $keyType & ")") raiseTLSStreamProtocolError("Unknown key type (" & $keyType & ")")
res res
proc pemDecode*(data: openarray[char]): seq[PEMElement] = proc pemDecode*(data: openarray[char]): seq[PEMElement] =
## Decode PEM encoded string and get array of binary blobs. ## Decode PEM encoded string and get array of binary blobs.
if len(data) == 0: if len(data) == 0:
raiseTLSStreamProtoError("Empty PEM message") raiseTLSStreamProtocolError("Empty PEM message")
var ctx: PemDecoderContext var ctx: PemDecoderContext
var pctx = new PEMContext var pctx = new PEMContext
var res = newSeq[PEMElement]() var res = newSeq[PEMElement]()
@ -666,7 +671,7 @@ proc pemDecode*(data: openarray[char]): seq[PEMElement] =
else: else:
break break
else: else:
raiseTLSStreamProtoError("Invalid PEM encoding") raiseTLSStreamProtocolError("Invalid PEM encoding")
res res
proc init*(tt: typedesc[TLSPrivateKey], data: openarray[char]): TLSPrivateKey = proc init*(tt: typedesc[TLSPrivateKey], data: openarray[char]): TLSPrivateKey =
@ -683,7 +688,7 @@ proc init*(tt: typedesc[TLSPrivateKey], data: openarray[char]): TLSPrivateKey =
res = TLSPrivateKey.init(item.data) res = TLSPrivateKey.init(item.data)
break break
if isNil(res): if isNil(res):
raiseTLSStreamProtoError("Could not find private key") raiseTLSStreamProtocolError("Could not find private key")
res res
proc init*(tt: typedesc[TLSCertificate], proc init*(tt: typedesc[TLSCertificate],
@ -703,12 +708,13 @@ proc init*(tt: typedesc[TLSCertificate],
) )
let ares = getSignerAlgo(cert) let ares = getSignerAlgo(cert)
if ares == -1: if ares == -1:
raiseTLSStreamProtoError("Could not decode certificate") raiseTLSStreamProtocolError("Could not decode certificate")
elif ares != KEYTYPE_RSA and ares != KEYTYPE_EC: elif ares != KEYTYPE_RSA and ares != KEYTYPE_EC:
raiseTLSStreamProtoError("Unsupported signing key type in certificate") raiseTLSStreamProtocolError(
"Unsupported signing key type in certificate")
res.certs.add(cert) res.certs.add(cert)
if len(res.storage) == 0: if len(res.storage) == 0:
raiseTLSStreamProtoError("Could not find any certificates") raiseTLSStreamProtocolError("Could not find any certificates")
res res
proc init*(tt: typedesc[TLSSessionCache], size: int = 4096): TLSSessionCache = proc init*(tt: typedesc[TLSSessionCache], size: int = 4096): TLSSessionCache =

View File

@ -813,6 +813,7 @@ suite "HTTP server testing suite":
res2.isOk() res2.isOk()
res2.get() == FlagsVectors[i] res2.get() == FlagsVectors[i]
res3.isErr() res3.isErr()
res4.isErr()
res5.isOk() res5.isOk()
res5.get() == FlagsVectors[i] res5.get() == FlagsVectors[i]
@ -864,6 +865,7 @@ suite "HTTP server testing suite":
res2.isOk() res2.isOk()
res2.get() == FlagsVectors[i] res2.get() == FlagsVectors[i]
res3.isErr() res3.isErr()
res4.isErr()
res5.isOk() res5.isOk()
res5.get() == FlagsVectors[i] res5.get() == FlagsVectors[i]