From fc0d1bcb43a4d5176cfc82840cf728b83cbe7206 Mon Sep 17 00:00:00 2001 From: cheatfate Date: Wed, 17 Feb 2021 02:03:12 +0200 Subject: [PATCH] Address review comments. --- chronos/apps/http/httpcommon.nim | 27 +++------ chronos/apps/http/httpserver.nim | 57 ++++++++++--------- chronos/apps/http/multipart.nim | 26 ++++----- chronos/streams/asyncstream.nim | 41 +++++++++----- chronos/streams/boundstream.nim | 19 +++---- chronos/streams/chunkstream.nim | 94 ++++++++++++++++---------------- chronos/streams/tlsstream.nim | 62 +++++++++++---------- tests/testhttpserver.nim | 2 + 8 files changed, 168 insertions(+), 160 deletions(-) diff --git a/chronos/apps/http/httpcommon.nim b/chronos/apps/http/httpcommon.nim index c186573..52e440e 100644 --- a/chronos/apps/http/httpcommon.nim +++ b/chronos/apps/http/httpcommon.nim @@ -63,28 +63,15 @@ proc atBound*(bstream: HttpBodyReader): bool {. let breader = cast[BoundedStreamReader](lreader) breader.atEof() and (breader.bytesLeft() == 0) -proc newHttpDefect*(msg: string): ref HttpDefect {. - raises: [HttpDefect].} = - newException(HttpDefect, msg) +proc raiseHttpCriticalError*(msg: string, + code = Http400) {.noinline, noreturn.} = + raise (ref HttpCriticalError)(code: code, msg: msg) -proc newHttpCriticalError*(msg: string, - code = Http400): ref HttpCriticalError {. - raises: [HttpCriticalError].} = - var tre = newException(HttpCriticalError, msg) - tre.code = code - tre +proc raiseHttpDisconnectError*() {.noinline, noreturn.} = + raise (ref HttpDisconnectError)(msg: "Remote peer disconnected") -proc newHttpRecoverableError*(msg: string, - code = Http400): ref HttpRecoverableError {. - raises: [HttpRecoverableError].} = - var tre = newException(HttpRecoverableError, msg) - tre.code = code - tre - -proc newHttpDisconnectError*(): ref HttpDisconnectError {. - raises: [HttpDisconnectError].} = - var tre = newException(HttpDisconnectError, "Remote peer disconnected") - tre +proc raiseHttpDefect*(msg: string) {.noinline, noreturn.} = + raise (ref HttpDefect)(msg: msg) iterator queryParams*(query: string): tuple[key: string, value: string] {. raises: [Defect].} = diff --git a/chronos/apps/http/httpserver.nim b/chronos/apps/http/httpserver.nim index 08fbd15..f538d2f 100644 --- a/chronos/apps/http/httpserver.nim +++ b/chronos/apps/http/httpserver.nim @@ -349,7 +349,7 @@ proc handleExpect*(request: HttpRequestRef) {.async.} = except CancelledError as exc: raise exc except AsyncStreamWriteError, AsyncStreamIncompleteError: - raise newHttpCriticalError("Unable to send `100-continue` response") + raiseHttpCriticalError("Unable to send `100-continue` response") proc getBody*(request: HttpRequestRef): Future[seq[byte]] {.async.} = ## Obtain request's body as sequence of bytes. @@ -363,9 +363,9 @@ proc getBody*(request: HttpRequestRef): Future[seq[byte]] {.async.} = return await reader.read() except AsyncStreamError: if reader.atBound(): - raise newHttpCriticalError("Maximum size of body reached", Http413) + raiseHttpCriticalError("Maximum size of body reached", Http413) else: - raise newHttpCriticalError("Unable to read request's body") + raiseHttpCriticalError("Unable to read request's body") finally: await closeWait(res.get()) @@ -381,9 +381,9 @@ proc consumeBody*(request: HttpRequestRef): Future[void] {.async.} = discard await reader.consume() except AsyncStreamError: if reader.atBound(): - raise newHttpCriticalError("Maximum size of body reached", Http413) + raiseHttpCriticalError("Maximum size of body reached", Http413) else: - raise newHttpCriticalError("Unable to read request's body") + raiseHttpCriticalError("Unable to read request's body") finally: await closeWait(res.get()) @@ -422,18 +422,17 @@ proc getRequest(conn: HttpConnectionRef): Future[HttpRequestRef] {.async.} = conn.buffer.setLen(res) let header = parseRequest(conn.buffer) if header.failed(): - raise newHttpCriticalError("Malformed request recieved") + raiseHttpCriticalError("Malformed request recieved") else: let res = prepareRequest(conn, header) if res.isErr(): - raise newHttpCriticalError("Invalid request received", res.error) + raiseHttpCriticalError("Invalid request received", res.error) else: return res.get() except AsyncStreamIncompleteError, AsyncStreamReadError: - raise newHttpDisconnectError() + raiseHttpDisconnectError() except AsyncStreamLimitError: - raise newHttpCriticalError("Maximum size of request headers reached", - Http413) + raiseHttpCriticalError("Maximum size of request headers reached", Http413) proc new(ht: typedesc[HttpConnectionRef], server: HttpServerRef, transp: StreamTransport): HttpConnectionRef = @@ -503,7 +502,7 @@ proc createConnection(server: HttpServerRef, raise exc except TLSStreamError: await conn.closeWait() - raise newHttpCriticalError("Unable to establish secure connection") + raiseHttpCriticalError("Unable to establish secure connection") proc processLoop(server: HttpServerRef, transp: StreamTransport) {.async.} = var @@ -534,7 +533,7 @@ proc processLoop(server: HttpServerRef, transp: StreamTransport) {.async.} = return except CatchableError as exc: # There should be no exceptions, so we will raise `Defect`. - raise newHttpDefect("Unexpected exception catched [" & $exc.name & "]") + raiseHttpDefect("Unexpected exception catched [" & $exc.name & "]") var breakLoop = false while runLoop: @@ -763,7 +762,7 @@ proc post*(req: HttpRequestRef): Future[HttpTable] {.async.} = var table = HttpTable.init() let res = getMultipartReader(req) if res.isErr(): - raise newHttpCriticalError("Unable to retrieve multipart form data") + raiseHttpCriticalError("Unable to retrieve multipart form data") var mpreader = res.get() # We must handle `Expect` first. @@ -808,10 +807,10 @@ proc post*(req: HttpRequestRef): Future[HttpTable] {.async.} = else: if HttpRequestFlags.BoundBody in req.requestFlags: if req.contentLength != 0: - raise newHttpCriticalError("Unsupported request body") + raiseHttpCriticalError("Unsupported request body") return HttpTable.init() elif HttpRequestFlags.UnboundBody in req.requestFlags: - raise newHttpCriticalError("Unsupported request body") + raiseHttpCriticalError("Unsupported request body") proc `keepalive=`*(resp: HttpResponseRef, value: bool) = doAssert(resp.state == HttpResponseState.Empty) @@ -854,7 +853,7 @@ template doHeaderVal(buf, name, value) = template checkPending(t: untyped) = if t.state != HttpResponseState.Empty: - raise newHttpCriticalError("Response body was already sent") + raiseHttpCriticalError("Response body was already sent") proc prepareLengthHeaders(resp: HttpResponseRef, length: int): string {. raises: [Defect].}= @@ -910,7 +909,7 @@ proc sendBody*(resp: HttpResponseRef, pbytes: pointer, nbytes: int) {.async.} = raise exc except AsyncStreamWriteError, AsyncStreamIncompleteError: resp.state = HttpResponseState.Failed - raise newHttpCriticalError("Unable to send response") + raiseHttpCriticalError("Unable to send response") proc sendBody*[T: string|seq[byte]](resp: HttpResponseRef, data: T) {.async.} = ## Send HTTP response at once by using data ``data``. @@ -928,7 +927,7 @@ proc sendBody*[T: string|seq[byte]](resp: HttpResponseRef, data: T) {.async.} = raise exc except AsyncStreamWriteError, AsyncStreamIncompleteError: resp.state = HttpResponseState.Failed - raise newHttpCriticalError("Unable to send response") + raiseHttpCriticalError("Unable to send response") proc sendError*(resp: HttpResponseRef, code: HttpCode, body = "") {.async.} = ## Send HTTP error status response. @@ -947,7 +946,7 @@ proc sendError*(resp: HttpResponseRef, code: HttpCode, body = "") {.async.} = raise exc except AsyncStreamWriteError, AsyncStreamIncompleteError: resp.state = HttpResponseState.Failed - raise newHttpCriticalError("Unable to send response") + raiseHttpCriticalError("Unable to send response") proc prepare*(resp: HttpResponseRef) {.async.} = ## Prepare for HTTP stream response. @@ -966,16 +965,16 @@ proc prepare*(resp: HttpResponseRef) {.async.} = raise exc except AsyncStreamWriteError, AsyncStreamIncompleteError: resp.state = HttpResponseState.Failed - raise newHttpCriticalError("Unable to send response") + raiseHttpCriticalError("Unable to send response") proc sendChunk*(resp: HttpResponseRef, pbytes: pointer, nbytes: int) {.async.} = ## Send single chunk of data pointed by ``pbytes`` and ``nbytes``. doAssert(not(isNil(pbytes)), "pbytes must not be nil") doAssert(nbytes >= 0, "nbytes should be bigger or equal to zero") if HttpResponseFlags.Chunked notin resp.flags: - raise newHttpCriticalError("Response was not prepared") + raiseHttpCriticalError("Response was not prepared") if resp.state notin {HttpResponseState.Prepared, HttpResponseState.Sending}: - raise newHttpCriticalError("Response in incorrect state") + raiseHttpCriticalError("Response in incorrect state") try: resp.state = HttpResponseState.Sending await resp.chunkedWriter.write(pbytes, nbytes) @@ -985,15 +984,15 @@ proc sendChunk*(resp: HttpResponseRef, pbytes: pointer, nbytes: int) {.async.} = raise exc except AsyncStreamWriteError, AsyncStreamIncompleteError: resp.state = HttpResponseState.Failed - raise newHttpCriticalError("Unable to send response") + raiseHttpCriticalError("Unable to send response") proc sendChunk*[T: string|seq[byte]](resp: HttpResponseRef, data: T) {.async.} = ## Send single chunk of data ``data``. if HttpResponseFlags.Chunked notin resp.flags: - raise newHttpCriticalError("Response was not prepared") + raiseHttpCriticalError("Response was not prepared") if resp.state notin {HttpResponseState.Prepared, HttpResponseState.Sending}: - raise newHttpCriticalError("Response in incorrect state") + raiseHttpCriticalError("Response in incorrect state") try: resp.state = HttpResponseState.Sending await resp.chunkedWriter.write(data) @@ -1003,14 +1002,14 @@ proc sendChunk*[T: string|seq[byte]](resp: HttpResponseRef, raise exc except AsyncStreamWriteError, AsyncStreamIncompleteError: resp.state = HttpResponseState.Failed - raise newHttpCriticalError("Unable to send response") + raiseHttpCriticalError("Unable to send response") proc finish*(resp: HttpResponseRef) {.async.} = ## Sending last chunk of data, so it will indicate end of HTTP response. if HttpResponseFlags.Chunked notin resp.flags: - raise newHttpCriticalError("Response was not prepared") + raiseHttpCriticalError("Response was not prepared") if resp.state notin {HttpResponseState.Prepared, HttpResponseState.Sending}: - raise newHttpCriticalError("Response in incorrect state") + raiseHttpCriticalError("Response in incorrect state") try: resp.state = HttpResponseState.Sending await resp.chunkedWriter.finish() @@ -1020,7 +1019,7 @@ proc finish*(resp: HttpResponseRef) {.async.} = raise exc except AsyncStreamWriteError, AsyncStreamIncompleteError: resp.state = HttpResponseState.Failed - raise newHttpCriticalError("Unable to send response") + raiseHttpCriticalError("Unable to send response") proc respond*(req: HttpRequestRef, code: HttpCode, content: string, headers: HttpTable): Future[HttpResponseRef] {.async.} = diff --git a/chronos/apps/http/multipart.nim b/chronos/apps/http/multipart.nim index 2e99c27..88f4dd3 100644 --- a/chronos/apps/http/multipart.nim +++ b/chronos/apps/http/multipart.nim @@ -149,14 +149,14 @@ proc readPart*(mpr: MultiPartReaderRef): Future[MultiPart] {.async.} = mpr.firstTime = false if not(startsWith(mpr.buffer.toOpenArray(0, len(mpr.boundary) - 3), mpr.boundary.toOpenArray(2, len(mpr.boundary) - 1))): - raise newHttpCriticalError("Unexpected boundary encountered") + raiseHttpCriticalError("Unexpected boundary encountered") except CancelledError as exc: raise exc except AsyncStreamError: if mpr.stream.atBound(): - raise newHttpCriticalError("Maximum size of body reached", Http413) + raiseHttpCriticalError("Maximum size of body reached", Http413) else: - raise newHttpCriticalError("Unable to read multipart body") + raiseHttpCriticalError("Unable to read multipart body") # Reading part's headers try: @@ -170,9 +170,9 @@ proc readPart*(mpr: MultiPartReaderRef): Future[MultiPart] {.async.} = raise newException(MultipartEOMError, "End of multipart message") else: - raise newHttpCriticalError("Incorrect multipart header found") + raiseHttpCriticalError("Incorrect multipart header found") if mpr.buffer[0] != 0x0D'u8 or mpr.buffer[1] != 0x0A'u8: - raise newHttpCriticalError("Incorrect multipart boundary found") + raiseHttpCriticalError("Incorrect multipart boundary found") # If two bytes are CRLF we are at the part beginning. # Reading part's headers @@ -180,7 +180,7 @@ proc readPart*(mpr: MultiPartReaderRef): Future[MultiPart] {.async.} = HeadersMark) var headersList = parseHeaders(mpr.buffer.toOpenArray(0, res - 1), false) if headersList.failed(): - raise newHttpCriticalError("Incorrect multipart's headers found") + raiseHttpCriticalError("Incorrect multipart's headers found") inc(mpr.counter) var part = MultiPart( @@ -196,16 +196,16 @@ proc readPart*(mpr: MultiPartReaderRef): Future[MultiPart] {.async.} = let sres = part.setPartNames() if sres.isErr(): - raise newHttpCriticalError(sres.error) + raiseHttpCriticalError($sres.error) return part except CancelledError as exc: raise exc except AsyncStreamError: if mpr.stream.atBound(): - raise newHttpCriticalError("Maximum size of body reached", Http413) + raiseHttpCriticalError("Maximum size of body reached", Http413) else: - raise newHttpCriticalError("Unable to read multipart body") + raiseHttpCriticalError("Unable to read multipart body") proc atBound*(mp: MultiPart): bool = ## Returns ``true`` if MultiPart's stream reached request body maximum size. @@ -220,9 +220,9 @@ proc getBody*(mp: MultiPart): Future[seq[byte]] {.async.} = return res except AsyncStreamError: if mp.breader.atBound(): - raise newHttpCriticalError("Maximum size of body reached", Http413) + raiseHttpCriticalError("Maximum size of body reached", Http413) else: - raise newHttpCriticalError("Unable to read multipart body") + raiseHttpCriticalError("Unable to read multipart body") of MultiPartSource.Buffer: return mp.buffer @@ -234,9 +234,9 @@ proc consumeBody*(mp: MultiPart) {.async.} = discard await mp.stream.consume() except AsyncStreamError: if mp.breader.atBound(): - raise newHttpCriticalError("Maximum size of body reached", Http413) + raiseHttpCriticalError("Maximum size of body reached", Http413) else: - raise newHttpCriticalError("Unable to consume multipart body") + raiseHttpCriticalError("Unable to consume multipart body") of MultiPartSource.Buffer: discard diff --git a/chronos/streams/asyncstream.nim b/chronos/streams/asyncstream.nim index 05dc16f..e25d108 100644 --- a/chronos/streams/asyncstream.nim +++ b/chronos/streams/asyncstream.nim @@ -22,7 +22,7 @@ const type AsyncStreamError* = object of CatchableError - AsyncStreamIncorrectError* = object of Defect + AsyncStreamIncorrectDefect* = object of Defect AsyncStreamIncompleteError* = object of AsyncStreamError AsyncStreamLimitError* = object of AsyncStreamError AsyncStreamUseClosedError* = object of AsyncStreamError @@ -179,36 +179,49 @@ template copyOut*(dest: pointer, item: WriteItem, length: int) = copyMem(dest, unsafeAddr item.data3[item.offset], length) proc newAsyncStreamReadError(p: ref CatchableError): ref AsyncStreamReadError {. - inline.} = + noinline.} = var w = newException(AsyncStreamReadError, "Read stream failed") w.msg = w.msg & ", originated from [" & $p.name & "] " & p.msg w.par = p w proc newAsyncStreamWriteError(p: ref CatchableError): ref AsyncStreamWriteError {. - inline.} = + noinline.} = var w = newException(AsyncStreamWriteError, "Write stream failed") w.msg = w.msg & ", originated from [" & $p.name & "] " & p.msg w.par = p w proc newAsyncStreamIncompleteError*(): ref AsyncStreamIncompleteError {. - inline.} = + noinline.} = newException(AsyncStreamIncompleteError, "Incomplete data sent or received") -proc newAsyncStreamLimitError*(): ref AsyncStreamLimitError {.inline.} = +proc newAsyncStreamLimitError*(): ref AsyncStreamLimitError {.noinline.} = newException(AsyncStreamLimitError, "Buffer limit reached") -proc newAsyncStreamUseClosedError*(): ref AsyncStreamUseClosedError {.inline.} = +proc newAsyncStreamUseClosedError*(): ref AsyncStreamUseClosedError {. + noinline.} = newException(AsyncStreamUseClosedError, "Stream is already closed") -proc newAsyncStreamIncorrectError*(m: string): ref AsyncStreamIncorrectError {. - inline.} = - newException(AsyncStreamIncorrectError, m) +proc raiseAsyncStreamUseClosedError*() {.noinline, noreturn.} = + raise newAsyncStreamUseClosedError() + +proc raiseAsyncStreamLimitError*() {.noinline, noreturn.} = + raise newAsyncStreamLimitError() + +proc raiseAsyncStreamIncompleteError*() {.noinline, noreturn.} = + raise newAsyncStreamIncompleteError() + +proc raiseAsyncStreamIncorrectDefect*(m: string) {.noinline, noreturn.} = + raise newException(AsyncStreamIncorrectDefect, m) + +proc raiseEmptyMessageDefect*() {.noinline, noreturn.} = + raise newException(AsyncStreamIncorrectDefect, + "Could not write empty message") template checkStreamClosed*(t: untyped) = if t.state == AsyncStreamState.Closed: - raise newAsyncStreamUseClosedError() + raiseAsyncStreamUseClosedError() proc atEof*(rstream: AsyncStreamReader): bool = ## Returns ``true`` is reading stream is closed or finished and internal @@ -677,7 +690,7 @@ proc write*(wstream: AsyncStreamWriter, pbytes: pointer, ## ``nbytes` must be more then zero. checkStreamClosed(wstream) if nbytes <= 0: - raise newAsyncStreamIncorrectError("Zero length message") + raiseEmptyMessageDefect() if isNil(wstream.wsource): var res: int @@ -725,7 +738,7 @@ proc write*(wstream: AsyncStreamWriter, sbytes: seq[byte], checkStreamClosed(wstream) let length = if msglen <= 0: len(sbytes) else: min(msglen, len(sbytes)) if length <= 0: - raise newAsyncStreamIncorrectError("Zero length message") + raiseEmptyMessageDefect() if isNil(wstream.wsource): var res: int @@ -773,7 +786,7 @@ proc write*(wstream: AsyncStreamWriter, sbytes: string, checkStreamClosed(wstream) let length = if msglen <= 0: len(sbytes) else: min(msglen, len(sbytes)) if length <= 0: - raise newAsyncStreamIncorrectError("Zero length message") + raiseEmptyMessageDefect() if isNil(wstream.wsource): var res: int @@ -857,7 +870,7 @@ proc close*(rw: AsyncStreamRW) = ## ## Note close() procedure is not completed immediately! if rw.closed(): - raise newAsyncStreamIncorrectError("Stream is already closed!") + raiseAsyncStreamIncorrectDefect("Stream is already closed!") rw.state = AsyncStreamState.Closed diff --git a/chronos/streams/boundstream.nim b/chronos/streams/boundstream.nim index aa11d06..e927768 100644 --- a/chronos/streams/boundstream.nim +++ b/chronos/streams/boundstream.nim @@ -42,11 +42,9 @@ type const BoundedBufferSize* = 4096 -template newBoundedStreamIncompleteError*(): ref BoundedStreamError = +proc newBoundedStreamIncompleteError*(): ref BoundedStreamError {.noinline.} = newException(BoundedStreamIncompleteError, "Stream boundary is not reached yet") -template newBoundedStreamOverflowError*(): ref BoundedStreamError = - newException(BoundedStreamOverflowError, "Stream boundary exceeded") proc readUntilBoundary*(rstream: AsyncStreamReader, pbytes: pointer, nbytes: int, sep: seq[byte]): Future[int] {.async.} = @@ -94,7 +92,7 @@ func endsWith(s, suffix: openarray[byte]): bool = if i >= len(suffix): return true proc boundedReadLoop(stream: AsyncStreamReader) {.async.} = - var rstream = cast[BoundedStreamReader](stream) + var rstream = BoundedStreamReader(stream) rstream.state = AsyncStreamState.Running var buffer = newSeq[byte](rstream.buffer.bufferLen()) while true: @@ -157,7 +155,7 @@ proc boundedReadLoop(stream: AsyncStreamReader) {.async.} = rstream.buffer.forget() proc boundedWriteLoop(stream: AsyncStreamWriter) {.async.} = - var wstream = cast[BoundedStreamWriter](stream) + var wstream = BoundedStreamWriter(stream) wstream.state = AsyncStreamState.Running while true: @@ -181,7 +179,8 @@ proc boundedWriteLoop(stream: AsyncStreamWriter) {.async.} = item.future.complete() else: wstream.state = AsyncStreamState.Error - error = newBoundedStreamOverflowError() + error = newException(BoundedStreamOverflowError, + "Stream boundary exceeded") else: if wstream.offset != wstream.boundSize: case wstream.cmpop @@ -223,12 +222,12 @@ proc bytesLeft*(stream: BoundedStreamRW): uint64 = proc init*[T](child: BoundedStreamReader, rsource: AsyncStreamReader, bufferSize = BoundedBufferSize, udata: ref T) = - init(cast[AsyncStreamReader](child), rsource, boundedReadLoop, bufferSize, + init(AsyncStreamReader(child), rsource, boundedReadLoop, bufferSize, udata) proc init*(child: BoundedStreamReader, rsource: AsyncStreamReader, bufferSize = BoundedBufferSize) = - init(cast[AsyncStreamReader](child), rsource, boundedReadLoop, bufferSize) + init(AsyncStreamReader(child), rsource, boundedReadLoop, bufferSize) proc newBoundedStreamReader*[T](rsource: AsyncStreamReader, boundSize: int, @@ -258,12 +257,12 @@ proc newBoundedStreamReader*(rsource: AsyncStreamReader, proc init*[T](child: BoundedStreamWriter, wsource: AsyncStreamWriter, queueSize = AsyncStreamDefaultQueueSize, udata: ref T) = - init(cast[AsyncStreamWriter](child), wsource, boundedWriteLoop, queueSize, + init(AsyncStreamWriter(child), wsource, boundedWriteLoop, queueSize, udata) proc init*(child: BoundedStreamWriter, wsource: AsyncStreamWriter, queueSize = AsyncStreamDefaultQueueSize) = - init(cast[AsyncStreamWriter](child), wsource, boundedWriteLoop, queueSize) + init(AsyncStreamWriter(child), wsource, boundedWriteLoop, queueSize) proc newBoundedStreamWriter*[T](wsource: AsyncStreamWriter, boundSize: int, diff --git a/chronos/streams/chunkstream.nim b/chronos/streams/chunkstream.nim index 9e65d8f..0cf61fd 100644 --- a/chronos/streams/chunkstream.nim +++ b/chronos/streams/chunkstream.nim @@ -10,11 +10,15 @@ ## This module implements HTTP/1.1 chunked-encoded stream reading and writing. import ../asyncloop, ../timer import asyncstream, ../transports/stream, ../transports/common +import stew/results export asyncstream, stream, timer, common const ChunkBufferSize = 4096 ChunkHeaderSize = 8 + # This is limit for chunk size to 8 hexadecimal digits, so maximum + # chunk size for this implementation become: + # 2^32 == FFFF_FFFF'u32 == 4,294,967,295 bytes. CRLF = @[byte(0x0D), byte(0x0A)] type @@ -25,12 +29,6 @@ type ChunkedStreamProtocolError* = object of ChunkedStreamError ChunkedStreamIncompleteError* = object of ChunkedStreamError -proc newChunkedProtocolError(): ref ChunkedStreamProtocolError {.inline.} = - newException(ChunkedStreamProtocolError, "Protocol error!") - -proc newChunkedIncompleteError(): ref ChunkedStreamIncompleteError {.inline.} = - newException(ChunkedStreamIncompleteError, "Incomplete data received!") - proc `-`(x: uint32): uint32 {.inline.} = result = (0xFFFF_FFFF'u32 - x) + 1'u32 @@ -47,18 +45,16 @@ proc hexValue(c: byte): int = ((z + 11'u32) and -LT(z, 6)) int(r) - 1 -proc getChunkSize(buffer: openarray[byte]): uint64 = +proc getChunkSize(buffer: openarray[byte]): Result[uint64, cstring] = # We using `uint64` representation, but allow only 2^32 chunk size, # ChunkHeaderSize. var res = 0'u64 - for i in 0..= 0: - res = (res shl 4) or uint64(value) - else: - res = 0xFFFF_FFFF_FFFF_FFFF'u64 - break - res + if value < 0: + return err("Incorrect chunk size encoding") + res = (res shl 4) or uint64(value) + ok(res) proc setChunkSize(buffer: var openarray[byte], length: int64): int = # Store length as chunk header size (hexadecimal value) with CRLF. @@ -87,48 +83,54 @@ proc setChunkSize(buffer: var openarray[byte], length: int64): int = i = i - 4 buffer[c] = byte(0x0D) buffer[c + 1] = byte(0x0A) - c + 2 + (c + 2) proc chunkedReadLoop(stream: AsyncStreamReader) {.async.} = - var rstream = cast[ChunkedStreamReader](stream) + var rstream = ChunkedStreamReader(stream) var buffer = newSeq[byte](1024) rstream.state = AsyncStreamState.Running while true: try: # Reading chunk size - let res = await rstream.rsource.readUntil(addr buffer[0], 1024, CRLF) - var chunksize = getChunkSize(buffer.toOpenArray(0, res - len(CRLF) - 1)) + let res = await rstream.rsource.readUntil(addr buffer[0], len(buffer), + CRLF) + let cres = getChunkSize(buffer.toOpenArray(0, res - len(CRLF) - 1)) - if chunksize == 0xFFFF_FFFF_FFFF_FFFF'u64: - rstream.error = newChunkedProtocolError() + if cres.isErr(): + rstream.error = newException(ChunkedStreamProtocolError, $cres.error) rstream.state = AsyncStreamState.Error - elif chunksize > 0'u64: - while chunksize > 0'u64: - let toRead = min(int(chunksize), rstream.buffer.bufferLen()) - await rstream.rsource.readExactly(rstream.buffer.getBuffer(), toRead) - rstream.buffer.update(toRead) - await rstream.buffer.transfer() - chunksize = chunksize - uint64(toRead) - - if rstream.state == AsyncStreamState.Running: - # Reading chunk trailing CRLF - await rstream.rsource.readExactly(addr buffer[0], 2) - - if buffer[0] != CRLF[0] or buffer[1] != CRLF[1]: - rstream.error = newChunkedProtocolError() - rstream.state = AsyncStreamState.Error else: - # Reading trailing line for last chunk - discard await rstream.rsource.readUntil(addr buffer[0], - len(buffer), CRLF) - rstream.state = AsyncStreamState.Finished - await rstream.buffer.transfer() + var chunksize = cres.get() + if chunksize > 0'u64: + while chunksize > 0'u64: + let toRead = min(int(chunksize), rstream.buffer.bufferLen()) + await rstream.rsource.readExactly(rstream.buffer.getBuffer(), + toRead) + rstream.buffer.update(toRead) + await rstream.buffer.transfer() + chunksize = chunksize - uint64(toRead) + + if rstream.state == AsyncStreamState.Running: + # Reading chunk trailing CRLF + await rstream.rsource.readExactly(addr buffer[0], 2) + + if buffer[0] != CRLF[0] or buffer[1] != CRLF[1]: + rstream.error = newException(ChunkedStreamProtocolError, + "Unexpected trailing bytes") + rstream.state = AsyncStreamState.Error + else: + # Reading trailing line for last chunk + discard await rstream.rsource.readUntil(addr buffer[0], + len(buffer), CRLF) + rstream.state = AsyncStreamState.Finished + await rstream.buffer.transfer() except CancelledError: rstream.state = AsyncStreamState.Stopped except AsyncStreamIncompleteError: rstream.state = AsyncStreamState.Error - rstream.error = newChunkedIncompleteError() + rstream.error = newException(ChunkedStreamIncompleteError, + "Incomplete chunk received") except AsyncStreamReadError as exc: rstream.state = AsyncStreamState.Error rstream.error = exc @@ -140,7 +142,7 @@ proc chunkedReadLoop(stream: AsyncStreamReader) {.async.} = break proc chunkedWriteLoop(stream: AsyncStreamWriter) {.async.} = - var wstream = cast[ChunkedStreamWriter](stream) + var wstream = ChunkedStreamWriter(stream) var buffer: array[16, byte] var error: ref AsyncStreamError wstream.state = AsyncStreamState.Running @@ -200,12 +202,12 @@ proc chunkedWriteLoop(stream: AsyncStreamWriter) {.async.} = proc init*[T](child: ChunkedStreamReader, rsource: AsyncStreamReader, bufferSize = ChunkBufferSize, udata: ref T) = - init(cast[AsyncStreamReader](child), rsource, chunkedReadLoop, bufferSize, + init(AsyncStreamReader(child), rsource, chunkedReadLoop, bufferSize, udata) proc init*(child: ChunkedStreamReader, rsource: AsyncStreamReader, bufferSize = ChunkBufferSize) = - init(cast[AsyncStreamReader](child), rsource, chunkedReadLoop, bufferSize) + init(AsyncStreamReader(child), rsource, chunkedReadLoop, bufferSize) proc newChunkedStreamReader*[T](rsource: AsyncStreamReader, bufferSize = AsyncStreamDefaultBufferSize, @@ -223,12 +225,12 @@ proc newChunkedStreamReader*(rsource: AsyncStreamReader, proc init*[T](child: ChunkedStreamWriter, wsource: AsyncStreamWriter, queueSize = AsyncStreamDefaultQueueSize, udata: ref T) = - init(cast[AsyncStreamWriter](child), wsource, chunkedWriteLoop, queueSize, + init(AsyncStreamWriter(child), wsource, chunkedWriteLoop, queueSize, udata) proc init*(child: ChunkedStreamWriter, wsource: AsyncStreamWriter, queueSize = AsyncStreamDefaultQueueSize) = - init(cast[AsyncStreamWriter](child), wsource, chunkedWriteLoop, queueSize) + init(AsyncStreamWriter(child), wsource, chunkedWriteLoop, queueSize) proc newChunkedStreamWriter*[T](wsource: AsyncStreamWriter, queueSize = AsyncStreamDefaultQueueSize, diff --git a/chronos/streams/tlsstream.nim b/chronos/streams/tlsstream.nim index eda191f..960f56f 100644 --- a/chronos/streams/tlsstream.nim +++ b/chronos/streams/tlsstream.nim @@ -91,6 +91,7 @@ type TLSStreamError* = object of AsyncStreamError TLSStreamHandshakeError* = object of TLSStreamError + TLSStreamInitError* = object of TLSStreamError TLSStreamReadError* = object of TLSStreamError par*: ref AsyncStreamError TLSStreamWriteError* = object of TLSStreamError @@ -99,20 +100,20 @@ type errCode*: int proc newTLSStreamReadError(p: ref AsyncStreamError): ref TLSStreamReadError {. - inline.} = + noinline.} = var w = newException(TLSStreamReadError, "Read stream failed") w.msg = w.msg & ", originated from [" & $p.name & "] " & p.msg w.par = p w proc newTLSStreamWriteError(p: ref AsyncStreamError): ref TLSStreamWriteError {. - inline.} = + noinline.} = var w = newException(TLSStreamWriteError, "Write stream failed") w.msg = w.msg & ", originated from [" & $p.name & "] " & p.msg w.par = p w -template newTLSStreamProtocolError[T](message: T): ref TLSStreamProtocolError = +template newTLSStreamProtocolImpl[T](message: T): ref TLSStreamProtocolError = var msg = "" var code = 0 when T is string: @@ -129,6 +130,12 @@ template newTLSStreamProtocolError[T](message: T): ref TLSStreamProtocolError = err.errCode = code err +proc newTLSStreamProtocolError[T](message: T): ref TLSStreamProtocolError = + newTLSStreamProtocolImpl(message) + +proc raiseTLSStreamProtocolError[T](message: T) {.noreturn, noinline.} = + raise newTLSStreamProtocolImpl(message) + proc tlsWriteRec(engine: ptr SslEngineContext, writer: TLSStreamWriter): Future[TLSResult] {.async.} = try: @@ -208,9 +215,6 @@ proc tlsReadApp(engine: ptr SslEngineContext, reader.state = AsyncStreamState.Stopped return TLSResult.Error -template raiseTLSStreamProtoError*[T](message: T) = - raise newTLSStreamProtocolError(message) - template readAndReset(fut: untyped) = if fut.finished(): let res = fut.read() @@ -386,7 +390,7 @@ proc tlsLoop*(stream: TLSAsyncStream) {.async.} = stream.reader.buffer.forget() proc tlsWriteLoop(stream: AsyncStreamWriter) {.async.} = - var wstream = cast[TLSStreamWriter](stream) + var wstream = TLSStreamWriter(stream) wstream.state = AsyncStreamState.Running await stepsAsync(1) if isNil(wstream.stream.mainLoop): @@ -394,7 +398,7 @@ proc tlsWriteLoop(stream: AsyncStreamWriter) {.async.} = await wstream.stream.mainLoop proc tlsReadLoop(stream: AsyncStreamReader) {.async.} = - var rstream = cast[TLSStreamReader](stream) + var rstream = TLSStreamReader(stream) rstream.state = AsyncStreamState.Running await stepsAsync(1) if isNil(rstream.stream.mainLoop): @@ -468,18 +472,19 @@ proc newTLSClientAsyncStream*(rsource: AsyncStreamReader, if TLSFlags.NoVerifyServerName in flags: let err = sslClientReset(addr res.ccontext, "", 0) if err == 0: - raise newException(TLSStreamError, "Could not initialize TLS layer") + raise newException(TLSStreamInitError, "Could not initialize TLS layer") else: if len(serverName) == 0: - raise newException(TLSStreamError, "serverName must not be empty string") + raise newException(TLSStreamInitError, + "serverName must not be empty string") let err = sslClientReset(addr res.ccontext, serverName, 0) if err == 0: - raise newException(TLSStreamError, "Could not initialize TLS layer") + raise newException(TLSStreamInitError, "Could not initialize TLS layer") - init(cast[AsyncStreamWriter](res.writer), wsource, tlsWriteLoop, + init(AsyncStreamWriter(res.writer), wsource, tlsWriteLoop, bufferSize) - init(cast[AsyncStreamReader](res.reader), rsource, tlsReadLoop, + init(AsyncStreamReader(res.reader), rsource, tlsReadLoop, bufferSize) res @@ -507,9 +512,9 @@ proc newTLSServerAsyncStream*(rsource: AsyncStreamReader, ## ## ``flags`` - custom TLS connection flags. if isNil(privateKey) or privateKey.kind notin {TLSKeyType.RSA, TLSKeyType.EC}: - raiseTLSStreamProtoError("Incorrect private key") + raiseTLSStreamProtocolError("Incorrect private key") if isNil(certificate) or len(certificate.certs) == 0: - raiseTLSStreamProtoError("Incorrect certificate") + raiseTLSStreamProtocolError("Incorrect certificate") var res = TLSAsyncStream() var reader = TLSStreamReader( @@ -528,7 +533,7 @@ proc newTLSServerAsyncStream*(rsource: AsyncStreamReader, if privateKey.kind == TLSKeyType.EC: let algo = getSignerAlgo(certificate.certs[0]) if algo == -1: - raiseTLSStreamProtoError("Could not decode certificate") + raiseTLSStreamProtocolError("Could not decode certificate") sslServerInitFullEc(addr res.scontext, addr certificate.certs[0], len(certificate.certs), cuint(algo), addr privateKey.eckey) @@ -557,11 +562,11 @@ proc newTLSServerAsyncStream*(rsource: AsyncStreamReader, let err = sslServerReset(addr res.scontext) if err == 0: - raise newException(TLSStreamError, "Could not initialize TLS layer") + raise newException(TLSStreamInitError, "Could not initialize TLS layer") - init(cast[AsyncStreamWriter](res.writer), wsource, tlsWriteLoop, + init(AsyncStreamWriter(res.writer), wsource, tlsWriteLoop, bufferSize) - init(cast[AsyncStreamReader](res.reader), rsource, tlsReadLoop, + init(AsyncStreamReader(res.reader), rsource, tlsReadLoop, bufferSize) res @@ -610,12 +615,12 @@ proc init*(tt: typedesc[TLSPrivateKey], data: openarray[byte]): TLSPrivateKey = ## or wrapped in an unencrypted PKCS#8 archive (again DER-encoded). var ctx: SkeyDecoderContext if len(data) == 0: - raiseTLSStreamProtoError("Incorrect private key") + raiseTLSStreamProtocolError("Incorrect private key") skeyDecoderInit(addr ctx) skeyDecoderPush(addr ctx, cast[pointer](unsafeAddr data[0]), len(data)) let err = skeyDecoderLastError(addr ctx) if err != 0: - raiseTLSStreamProtoError(err) + raiseTLSStreamProtocolError(err) let keyType = skeyDecoderKeyType(addr ctx) let res = if keyType == KEYTYPE_RSA: @@ -623,13 +628,13 @@ proc init*(tt: typedesc[TLSPrivateKey], data: openarray[byte]): TLSPrivateKey = elif keyType == KEYTYPE_EC: copyKey(ctx.key.ec) else: - raiseTLSStreamProtoError("Unknown key type (" & $keyType & ")") + raiseTLSStreamProtocolError("Unknown key type (" & $keyType & ")") res proc pemDecode*(data: openarray[char]): seq[PEMElement] = ## Decode PEM encoded string and get array of binary blobs. if len(data) == 0: - raiseTLSStreamProtoError("Empty PEM message") + raiseTLSStreamProtocolError("Empty PEM message") var ctx: PemDecoderContext var pctx = new PEMContext var res = newSeq[PEMElement]() @@ -666,7 +671,7 @@ proc pemDecode*(data: openarray[char]): seq[PEMElement] = else: break else: - raiseTLSStreamProtoError("Invalid PEM encoding") + raiseTLSStreamProtocolError("Invalid PEM encoding") res proc init*(tt: typedesc[TLSPrivateKey], data: openarray[char]): TLSPrivateKey = @@ -683,7 +688,7 @@ proc init*(tt: typedesc[TLSPrivateKey], data: openarray[char]): TLSPrivateKey = res = TLSPrivateKey.init(item.data) break if isNil(res): - raiseTLSStreamProtoError("Could not find private key") + raiseTLSStreamProtocolError("Could not find private key") res proc init*(tt: typedesc[TLSCertificate], @@ -703,12 +708,13 @@ proc init*(tt: typedesc[TLSCertificate], ) let ares = getSignerAlgo(cert) if ares == -1: - raiseTLSStreamProtoError("Could not decode certificate") + raiseTLSStreamProtocolError("Could not decode certificate") elif ares != KEYTYPE_RSA and ares != KEYTYPE_EC: - raiseTLSStreamProtoError("Unsupported signing key type in certificate") + raiseTLSStreamProtocolError( + "Unsupported signing key type in certificate") res.certs.add(cert) if len(res.storage) == 0: - raiseTLSStreamProtoError("Could not find any certificates") + raiseTLSStreamProtocolError("Could not find any certificates") res proc init*(tt: typedesc[TLSSessionCache], size: int = 4096): TLSSessionCache = diff --git a/tests/testhttpserver.nim b/tests/testhttpserver.nim index 76793af..3e6e809 100644 --- a/tests/testhttpserver.nim +++ b/tests/testhttpserver.nim @@ -813,6 +813,7 @@ suite "HTTP server testing suite": res2.isOk() res2.get() == FlagsVectors[i] res3.isErr() + res4.isErr() res5.isOk() res5.get() == FlagsVectors[i] @@ -864,6 +865,7 @@ suite "HTTP server testing suite": res2.isOk() res2.get() == FlagsVectors[i] res3.isErr() + res4.isErr() res5.isOk() res5.get() == FlagsVectors[i]