From 0cb6840f03ea52545f0be0ce50cdce1a30f67d99 Mon Sep 17 00:00:00 2001 From: cheatfate Date: Wed, 20 Jan 2021 15:40:15 +0200 Subject: [PATCH] Big refactoring of AsyncStreams. 1. Implement all read() primitives using readLoop() like it was done in streams. 2. Fix readLine() bug. 3. Add readMessage() primitive. 4. Fixing exception hierarchy, handling code and simplification of (break/continue + exception). 5. Fix TLSStream closure procedure. 6. Add BoundedStream stream and tests. 7. Remove `result` usage from the code. --- chronos/streams/asyncstream.nim | 734 ++++++++++++++++++-------------- chronos/streams/boundstream.nim | 212 +++++++++ chronos/streams/chunkstream.nim | 222 +++++----- chronos/streams/tlsstream.nim | 483 +++++++++++---------- tests/testasyncstream.nim | 127 +++++- 5 files changed, 1086 insertions(+), 692 deletions(-) create mode 100644 chronos/streams/boundstream.nim diff --git a/chronos/streams/asyncstream.nim b/chronos/streams/asyncstream.nim index bc1f6f9..905c57d 100644 --- a/chronos/streams/asyncstream.nim +++ b/chronos/streams/asyncstream.nim @@ -21,6 +21,16 @@ const ## AsyncStreamWriter leaks tracker name type + AsyncStreamError* = object of CatchableError + AsyncStreamIncorrectError* = object of Defect + AsyncStreamIncompleteError* = object of AsyncStreamError + AsyncStreamLimitError* = object of AsyncStreamError + AsyncStreamUseClosedError* = object of AsyncStreamError + AsyncStreamReadError* = object of AsyncStreamError + par*: ref CatchableError + AsyncStreamWriteError* = object of AsyncStreamError + par*: ref CatchableError + AsyncBuffer* = object offset*: int buffer*: seq[byte] @@ -60,7 +70,8 @@ type state*: AsyncStreamState buffer*: AsyncBuffer udata: pointer - error*: ref Exception + error*: ref AsyncStreamError + bytesCount*: uint64 future: Future[void] AsyncStreamWriter* = ref object of RootRef @@ -70,6 +81,7 @@ type state*: AsyncStreamState queue*: AsyncQueue[WriteItem] udata: pointer + bytesCount*: uint64 future: Future[void] AsyncStream* = object of RootObj @@ -82,36 +94,29 @@ type AsyncStreamRW* = AsyncStreamReader | AsyncStreamWriter - AsyncStreamError* = object of CatchableError - AsyncStreamIncompleteError* = object of AsyncStreamError - AsyncStreamIncorrectError* = object of Defect - AsyncStreamLimitError* = object of AsyncStreamError - AsyncStreamReadError* = object of AsyncStreamError - par*: ref Exception - AsyncStreamWriteError* = object of AsyncStreamError - par*: ref Exception - proc init*(t: typedesc[AsyncBuffer], size: int): AsyncBuffer = - result.buffer = newSeq[byte](size) - result.events[0] = newAsyncEvent() - result.events[1] = newAsyncEvent() - result.offset = 0 + var res = AsyncBuffer( + buffer: newSeq[byte](size), + events: [newAsyncEvent(), newAsyncEvent()], + offset: 0 + ) + res proc getBuffer*(sb: AsyncBuffer): pointer {.inline.} = - result = unsafeAddr sb.buffer[sb.offset] + unsafeAddr sb.buffer[sb.offset] proc bufferLen*(sb: AsyncBuffer): int {.inline.} = - result = len(sb.buffer) - sb.offset + len(sb.buffer) - sb.offset proc getData*(sb: AsyncBuffer): pointer {.inline.} = - result = unsafeAddr sb.buffer[0] + unsafeAddr sb.buffer[0] -proc dataLen*(sb: AsyncBuffer): int {.inline.} = - result = sb.offset +template dataLen*(sb: AsyncBuffer): int = + sb.offset proc `[]`*(sb: AsyncBuffer, index: int): byte {.inline.} = doAssert(index < sb.offset) - result = sb.buffer[index] + sb.buffer[index] proc update*(sb: var AsyncBuffer, size: int) {.inline.} = sb.offset += size @@ -119,12 +124,12 @@ proc update*(sb: var AsyncBuffer, size: int) {.inline.} = proc wait*(sb: var AsyncBuffer): Future[void] = sb.events[0].clear() sb.events[1].fire() - result = sb.events[0].wait() + sb.events[0].wait() proc transfer*(sb: var AsyncBuffer): Future[void] = sb.events[1].clear() sb.events[0].fire() - result = sb.events[1].wait() + sb.events[1].wait() proc forget*(sb: var AsyncBuffer) {.inline.} = sb.events[1].clear() @@ -172,36 +177,47 @@ template copyOut*(dest: pointer, item: WriteItem, length: int) = elif item.kind == String: copyMem(dest, unsafeAddr item.data3[item.offset], length) -proc newAsyncStreamReadError(p: ref Exception): ref Exception {.inline.} = +proc newAsyncStreamReadError(p: ref CatchableError): ref AsyncStreamReadError {. + inline.} = var w = newException(AsyncStreamReadError, "Read stream failed") w.msg = w.msg & ", originated from [" & $p.name & "] " & p.msg w.par = p - result = w + w -proc newAsyncStreamWriteError(p: ref Exception): ref Exception {.inline.} = +proc newAsyncStreamWriteError(p: ref CatchableError): ref AsyncStreamWriteError {. + inline.} = var w = newException(AsyncStreamWriteError, "Write stream failed") w.msg = w.msg & ", originated from [" & $p.name & "] " & p.msg w.par = p - result = w + w -proc newAsyncStreamIncompleteError(): ref Exception {.inline.} = - result = newException(AsyncStreamIncompleteError, "Incomplete data received") +proc newAsyncStreamIncompleteError*(): ref AsyncStreamIncompleteError {. + inline.} = + newException(AsyncStreamIncompleteError, "Incomplete data sent or received") -proc newAsyncStreamLimitError(): ref Exception {.inline.} = - result = newException(AsyncStreamLimitError, "Buffer limit reached") +proc newAsyncStreamLimitError*(): ref AsyncStreamLimitError {.inline.} = + newException(AsyncStreamLimitError, "Buffer limit reached") -proc newAsyncStreamIncorrectError(m: string): ref Exception {.inline.} = - result = newException(AsyncStreamIncorrectError, m) +proc newAsyncStreamUseClosedError*(): ref AsyncStreamUseClosedError {.inline.} = + newException(AsyncStreamUseClosedError, "Stream is already closed") + +proc newAsyncStreamIncorrectError*(m: string): ref AsyncStreamIncorrectError {. + inline.} = + newException(AsyncStreamIncorrectError, m) + +template checkRunning(t: untyped) = + if not(t.running()): + raise newAsyncStreamIncorrectError("Incorrect stream state") proc atEof*(rstream: AsyncStreamReader): bool = ## Returns ``true`` is reading stream is closed or finished and internal ## buffer do not have any bytes left. - result = rstream.state in {AsyncStreamState.Stopped, Finished, Closed} and - (rstream.buffer.dataLen() == 0) + rstream.state in {AsyncStreamState.Stopped, Finished, Closed} and + (rstream.buffer.dataLen() == 0) proc atEof*(wstream: AsyncStreamWriter): bool = ## Returns ``true`` is writing stream ``wstream`` closed or finished. - result = wstream.state in {AsyncStreamState.Stopped, Finished, Closed} + wstream.state in {AsyncStreamState.Stopped, Finished, Closed} proc closed*(rw: AsyncStreamRW): bool {.inline.} = ## Returns ``true`` is reading/writing stream is closed. @@ -223,32 +239,36 @@ proc setupAsyncStreamReaderTracker(): AsyncStreamTracker {.gcsafe.} proc setupAsyncStreamWriterTracker(): AsyncStreamTracker {.gcsafe.} proc getAsyncStreamReaderTracker(): AsyncStreamTracker {.inline.} = - result = cast[AsyncStreamTracker](getTracker(AsyncStreamReaderTrackerName)) - if isNil(result): - result = setupAsyncStreamReaderTracker() + var res = cast[AsyncStreamTracker](getTracker(AsyncStreamReaderTrackerName)) + if isNil(res): + res = setupAsyncStreamReaderTracker() + res proc getAsyncStreamWriterTracker(): AsyncStreamTracker {.inline.} = - result = cast[AsyncStreamTracker](getTracker(AsyncStreamWriterTrackerName)) - if isNil(result): - result = setupAsyncStreamWriterTracker() + var res = cast[AsyncStreamTracker](getTracker(AsyncStreamWriterTrackerName)) + if isNil(res): + res = setupAsyncStreamWriterTracker() + res proc dumpAsyncStreamReaderTracking(): string {.gcsafe.} = var tracker = getAsyncStreamReaderTracker() - result = "Opened async stream readers: " & $tracker.opened & "\n" & - "Closed async stream readers: " & $tracker.closed + let res = "Opened async stream readers: " & $tracker.opened & "\n" & + "Closed async stream readers: " & $tracker.closed + res proc dumpAsyncStreamWriterTracking(): string {.gcsafe.} = var tracker = getAsyncStreamWriterTracker() - result = "Opened async stream writers: " & $tracker.opened & "\n" & - "Closed async stream writers: " & $tracker.closed + let res = "Opened async stream writers: " & $tracker.opened & "\n" & + "Closed async stream writers: " & $tracker.closed + res proc leakAsyncStreamReader(): bool {.gcsafe.} = var tracker = getAsyncStreamReaderTracker() - result = tracker.opened != tracker.closed + tracker.opened != tracker.closed proc leakAsyncStreamWriter(): bool {.gcsafe.} = var tracker = getAsyncStreamWriterTracker() - result = tracker.opened != tracker.closed + tracker.opened != tracker.closed proc trackAsyncStreamReader(t: AsyncStreamReader) {.inline.} = var tracker = getAsyncStreamReaderTracker() @@ -267,20 +287,39 @@ proc untrackAsyncStreamWriter*(t: AsyncStreamWriter) {.inline.} = inc(tracker.closed) proc setupAsyncStreamReaderTracker(): AsyncStreamTracker {.gcsafe.} = - result = new AsyncStreamTracker - result.opened = 0 - result.closed = 0 - result.dump = dumpAsyncStreamReaderTracking - result.isLeaked = leakAsyncStreamReader - addTracker(AsyncStreamReaderTrackerName, result) + var res = AsyncStreamTracker( + opened: 0, + closed: 0, + dump: dumpAsyncStreamReaderTracking, + isLeaked: leakAsyncStreamReader + ) + addTracker(AsyncStreamReaderTrackerName, res) + res proc setupAsyncStreamWriterTracker(): AsyncStreamTracker {.gcsafe.} = - result = new AsyncStreamTracker - result.opened = 0 - result.closed = 0 - result.dump = dumpAsyncStreamWriterTracking - result.isLeaked = leakAsyncStreamWriter - addTracker(AsyncStreamWriterTrackerName, result) + var res = AsyncStreamTracker( + opened: 0, + closed: 0, + dump: dumpAsyncStreamWriterTracking, + isLeaked: leakAsyncStreamWriter + ) + addTracker(AsyncStreamWriterTrackerName, res) + res + +template readLoop(body: untyped): untyped = + while true: + if rstream.buffer.dataLen() == 0: + if rstream.state == AsyncStreamState.Error: + raise rstream.error + + let (consumed, done) = body + + rstream.buffer.shift(consumed) + rstream.bytesCount = rstream.bytesCount + uint64(consumed) + if done: + break + else: + await rstream.buffer.wait() proc readExactly*(rstream: AsyncStreamReader, pbytes: pointer, nbytes: int) {.async.} = @@ -292,17 +331,16 @@ proc readExactly*(rstream: AsyncStreamReader, pbytes: pointer, doAssert(not(isNil(pbytes)), "pbytes must not be nil") doAssert(nbytes >= 0, "nbytes must be non-negative integer") + checkRunning(rstream) + if nbytes == 0: return - if not rstream.running(): - raise newAsyncStreamIncorrectError("Incorrect stream state") - if isNil(rstream.rsource): try: await readExactly(rstream.tsource, pbytes, nbytes) - except CancelledError: - raise + except CancelledError as exc: + raise exc except TransportIncompleteError: raise newAsyncStreamIncompleteError() except CatchableError as exc: @@ -312,61 +350,47 @@ proc readExactly*(rstream: AsyncStreamReader, pbytes: pointer, await readExactly(rstream.rsource, pbytes, nbytes) else: var index = 0 - while true: - let datalen = rstream.buffer.dataLen() - if rstream.state == Error: - raise newAsyncStreamReadError(rstream.error) - if datalen == 0 and rstream.atEof(): - raise newAsyncStreamIncompleteError() - - if datalen >= (nbytes - index): - rstream.buffer.copyData(pbytes, index, nbytes - index) - rstream.buffer.shift(nbytes - index) - break - else: - rstream.buffer.copyData(pbytes, index, datalen) - index += datalen - rstream.buffer.shift(datalen) - await rstream.buffer.wait() + var pbuffer = cast[ptr UncheckedArray[byte]](pbytes) + readLoop(): + if rstream.buffer.dataLen() == 0: + if rstream.atEof(): + raise newAsyncStreamIncompleteError() + let count = min(nbytes - index, rstream.buffer.dataLen()) + if count > 0: + rstream.buffer.copyData(addr pbuffer[index], 0, count) + index += count + (consumed: count, done: index == nbytes) proc readOnce*(rstream: AsyncStreamReader, pbytes: pointer, nbytes: int): Future[int] {.async.} = ## Perform one read operation on read-only stream ``rstream``. ## ## If internal buffer is not empty, ``nbytes`` bytes will be transferred from - ## internal buffer, otherwise it will wait until some bytes will be received. + ## internal buffer, otherwise it will wait until some bytes will be available. doAssert(not(isNil(pbytes)), "pbytes must not be nil") doAssert(nbytes > 0, "nbytes must be positive value") - - if not rstream.running(): - raise newAsyncStreamIncorrectError("Incorrect stream state") + checkRunning(rstream) if isNil(rstream.rsource): try: - result = await readOnce(rstream.tsource, pbytes, nbytes) - except CancelledError: - raise + return await readOnce(rstream.tsource, pbytes, nbytes) + except CancelledError as exc: + raise exc except CatchableError as exc: raise newAsyncStreamReadError(exc) else: if isNil(rstream.readerLoop): - result = await readOnce(rstream.rsource, pbytes, nbytes) + return await readOnce(rstream.rsource, pbytes, nbytes) else: - while true: - let datalen = rstream.buffer.dataLen() - if rstream.state == Error: - raise newAsyncStreamReadError(rstream.error) - if datalen == 0: - if rstream.atEof(): - result = 0 - break - await rstream.buffer.wait() + var count = 0 + readLoop(): + if rstream.buffer.dataLen() == 0: + (0, rstream.atEof()) else: - let size = min(datalen, nbytes) - rstream.buffer.copyData(pbytes, 0, size) - rstream.buffer.shift(size) - result = size - break + count = min(rstream.buffer.dataLen(), nbytes) + rstream.buffer.copyData(pbytes, 0, count) + (count, true) + return count proc readUntil*(rstream: AsyncStreamReader, pbytes: pointer, nbytes: int, sep: seq[byte]): Future[int] {.async.} = @@ -386,18 +410,16 @@ proc readUntil*(rstream: AsyncStreamReader, pbytes: pointer, nbytes: int, doAssert(not(isNil(pbytes)), "pbytes must not be nil") doAssert(len(sep) > 0, "separator must not be empty") doAssert(nbytes >= 0, "nbytes must be non-negative value") + checkRunning(rstream) if nbytes == 0: raise newAsyncStreamLimitError() - if not rstream.running(): - raise newAsyncStreamIncorrectError("Incorrect stream state") - if isNil(rstream.rsource): try: - result = await readUntil(rstream.tsource, pbytes, nbytes, sep) - except CancelledError: - raise + return await readUntil(rstream.tsource, pbytes, nbytes, sep) + except CancelledError as exc: + raise exc except TransportIncompleteError: raise newAsyncStreamIncompleteError() except TransportLimitError: @@ -406,43 +428,30 @@ proc readUntil*(rstream: AsyncStreamReader, pbytes: pointer, nbytes: int, raise newAsyncStreamReadError(exc) else: if isNil(rstream.readerLoop): - result = await readUntil(rstream.rsource, pbytes, nbytes, sep) + return await readUntil(rstream.rsource, pbytes, nbytes, sep) else: - var - dest = cast[ptr UncheckedArray[byte]](pbytes) - state = 0 - k = 0 - - while true: - let datalen = rstream.buffer.dataLen() - if rstream.state == Error: - raise newAsyncStreamReadError(rstream.error) - if datalen == 0 and rstream.atEof(): + var pbuffer = cast[ptr UncheckedArray[byte]](pbytes) + var state = 0 + var k = 0 + readLoop(): + if rstream.atEof(): raise newAsyncStreamIncompleteError() - var index = 0 - while index < datalen: + while index < rstream.buffer.dataLen(): + if k >= nbytes: + raise newAsyncStreamLimitError() let ch = rstream.buffer[index] + inc(index) + pbuffer[k] = ch + inc(k) if sep[state] == ch: inc(state) + if state == len(sep): + break else: state = 0 - if k < nbytes: - dest[k] = ch - inc(k) - else: - raise newAsyncStreamLimitError() - if state == len(sep): - break - inc(index) - - if state == len(sep): - rstream.buffer.shift(index + 1) - result = k - break - else: - rstream.buffer.shift(datalen) - await rstream.buffer.wait() + (index, state == len(sep)) + return k proc readLine*(rstream: AsyncStreamReader, limit = 0, sep = "\r\n"): Future[string] {.async.} = @@ -457,155 +466,207 @@ proc readLine*(rstream: AsyncStreamReader, limit = 0, ## ## If ``limit`` more then 0, then result string will be limited to ``limit`` ## bytes. - if not rstream.running(): - raise newAsyncStreamIncorrectError("Incorrect stream state") + checkRunning(rstream) if isNil(rstream.rsource): try: - result = await readLine(rstream.tsource, limit, sep) - except CancelledError: - raise + return await readLine(rstream.tsource, limit, sep) + except CancelledError as exc: + raise exc except CatchableError as exc: raise newAsyncStreamReadError(exc) else: if isNil(rstream.readerLoop): - result = await readLine(rstream.rsource, limit, sep) + return await readLine(rstream.rsource, limit, sep) else: + let lim = if limit <= 0: -1 else: limit + var state = 0 var res = "" - var - lim = if limit <= 0: -1 else: limit - state = 0 - - while true: - let datalen = rstream.buffer.dataLen() - if rstream.state == Error: - raise newAsyncStreamReadError(rstream.error) - if datalen == 0 and rstream.atEof(): - result = res - break - - var index = 0 - while index < datalen: - let ch = char(rstream.buffer[index]) - if sep[state] == ch: - inc(state) - if state == len(sep) or len(res) == lim: - rstream.buffer.shift(index + 1) - break - else: - state = 0 - res.add(ch) - if len(res) == lim: - rstream.buffer.shift(index + 1) - break - inc(index) - - if state == len(sep) or (lim == len(res)): - result = res - break + readLoop(): + if rstream.atEof(): + (0, true) else: - rstream.buffer.shift(datalen) - await rstream.buffer.wait() + var index = 0 + while index < rstream.buffer.dataLen(): + let ch = char(rstream.buffer[index]) + inc(index) -proc read*(rstream: AsyncStreamReader, n = 0): Future[seq[byte]] {.async.} = - ## Read all bytes (n <= 0) or exactly `n` bytes from read-only stream - ## ``rstream``. + if sep[state] == ch: + inc(state) + if state == len(sep): + break + else: + if state != 0: + if limit > 0: + let missing = min(state, lim - len(res) - 1) + res.add(sep[0 ..< missing]) + else: + res.add(sep[0 ..< state]) + res.add(ch) + if len(res) == lim: + break + (index, (state == len(sep)) or (lim == len(res))) + return res + +proc read*(rstream: AsyncStreamReader): Future[seq[byte]] {.async.} = + ## Read all bytes from read-only stream ``rstream``. ## ## This procedure allocates buffer seq[byte] and return it as result. - if not rstream.running(): - raise newAsyncStreamIncorrectError("Incorrect stream state") + checkRunning(rstream) if isNil(rstream.rsource): try: - result = await read(rstream.tsource, n) - except CancelledError: - raise - except CatchableError as exc: - raise newAsyncStreamReadError(exc) - else: - if isNil(rstream.readerLoop): - result = await read(rstream.rsource, n) - else: - var res = newSeq[byte]() - while true: - let datalen = rstream.buffer.dataLen() - if rstream.state == Error: - raise newAsyncStreamReadError(rstream.error) - if datalen == 0 and rstream.atEof(): - result = res - break - - if datalen > 0: - let s = len(res) - let o = s + datalen - if n <= 0: - res.setLen(o) - rstream.buffer.copyData(addr res[s], 0, datalen) - rstream.buffer.shift(datalen) - else: - let left = n - s - if datalen >= left: - res.setLen(n) - rstream.buffer.copyData(addr res[s], 0, left) - rstream.buffer.shift(left) - result = res - break - else: - res.setLen(o) - rstream.buffer.copyData(addr res[s], 0, datalen) - rstream.buffer.shift(datalen) - - await rstream.buffer.wait() - -proc consume*(rstream: AsyncStreamReader, n = -1): Future[int] {.async.} = - ## Consume (discard) all bytes (n <= 0) or ``n`` bytes from read-only stream - ## ``rstream``. - ## - ## Return number of bytes actually consumed (discarded). - if not rstream.running(): - raise newAsyncStreamIncorrectError("Incorrect stream state") - - if isNil(rstream.rsource): - try: - result = await consume(rstream.tsource, n) - except CancelledError: - raise + return await read(rstream.tsource) + except CancelledError as exc: + raise exc except TransportLimitError: raise newAsyncStreamLimitError() except CatchableError as exc: raise newAsyncStreamReadError(exc) else: if isNil(rstream.readerLoop): - result = await consume(rstream.rsource, n) + return await read(rstream.rsource) + else: + var res = newSeq[byte]() + readLoop(): + if rstream.atEof(): + (0, true) + else: + let count = rstream.buffer.dataLen() + res.add(rstream.buffer.buffer.toOpenArray(0, count - 1)) + (count, false) + return res + +proc read*(rstream: AsyncStreamReader, n: int): Future[seq[byte]] {.async.} = + ## Read all bytes (n <= 0) or exactly `n` bytes from read-only stream + ## ``rstream``. + ## + ## This procedure allocates buffer seq[byte] and return it as result. + checkRunning(rstream) + + if isNil(rstream.rsource): + try: + return await read(rstream.tsource, n) + except CancelledError as exc: + raise exc + except CatchableError as exc: + raise newAsyncStreamReadError(exc) + else: + if isNil(rstream.readerLoop): + return await read(rstream.rsource, n) + else: + if n <= 0: + return await read(rstream.rsource) + else: + var res = newSeq[byte]() + readLoop(): + if rstream.atEof(): + (0, true) + else: + let count = min(rstream.buffer.dataLen(), n - len(res)) + res.add(rstream.buffer.buffer.toOpenArray(0, count - 1)) + (count, len(res) == n) + return res + +proc consume*(rstream: AsyncStreamReader): Future[int] {.async.} = + ## Consume (discard) all bytes from read-only stream ``rstream``. + ## + ## Return number of bytes actually consumed (discarded). + checkRunning(rstream) + + if isNil(rstream.rsource): + try: + return await consume(rstream.tsource) + except CancelledError as exc: + raise exc + except TransportLimitError: + raise newAsyncStreamLimitError() + except CatchableError as exc: + raise newAsyncStreamReadError(exc) + else: + if isNil(rstream.readerLoop): + return await consume(rstream.rsource) else: var res = 0 - while true: - let datalen = rstream.buffer.dataLen() - if rstream.state == Error: - raise newAsyncStreamReadError(rstream.error) - if datalen == 0: - if rstream.atEof(): - if n <= 0: - result = res - break - else: - raise newAsyncStreamLimitError() + readLoop(): + if rstream.atEof(): + (0, true) else: - if n <= 0: - res += datalen - rstream.buffer.shift(datalen) - else: - let left = n - res - if datalen >= left: - res += left - rstream.buffer.shift(left) - result = res - break - else: - res += datalen - rstream.buffer.shift(datalen) + res += rstream.buffer.dataLen() + (rstream.buffer.dataLen(), false) + return res - await rstream.buffer.wait() +proc consume*(rstream: AsyncStreamReader, n: int): Future[int] {.async.} = + ## Consume (discard) all bytes (n <= 0) or ``n`` bytes from read-only stream + ## ``rstream``. + ## + ## Return number of bytes actually consumed (discarded). + checkRunning(rstream) + + if isNil(rstream.rsource): + try: + return await consume(rstream.tsource, n) + except CancelledError as exc: + raise exc + except TransportLimitError: + raise newAsyncStreamLimitError() + except CatchableError as exc: + raise newAsyncStreamReadError(exc) + else: + if isNil(rstream.readerLoop): + return await consume(rstream.rsource, n) + else: + if n <= 0: + return await rstream.consume() + else: + var res = 0 + readLoop(): + if rstream.atEof(): + (0, true) + else: + let count = min(rstream.buffer.dataLen(), n - res) + res += count + (count, res == n) + return res + +proc readMessage*(rstream: AsyncStreamReader, pred: ReadMessagePredicate) {. + async.} = + ## Read all bytes from stream ``rstream`` until ``predicate`` callback + ## will not be satisfied. + ## + ## ``predicate`` callback should return tuple ``(consumed, result)``, where + ## ``consumed`` is the number of bytes processed and ``result`` is a + ## completion flag (``true`` if readMessage() should stop reading data, + ## or ``false`` if readMessage() should continue to read data from stream). + ## + ## ``predicate`` callback must copy all the data from ``data`` array and + ## return number of bytes it is going to consume. + ## ``predicate`` callback will receive (zero-length) openarray, if stream + ## is at EOF. + doAssert(not(isNil(pred)), "`predicate` callback should not be `nil`") + checkRunning(rstream) + + if isNil(rstream.rsource): + try: + await readMessage(rstream.tsource, pred) + except CancelledError as exc: + raise exc + except CatchableError as exc: + raise newAsyncStreamReadError(exc) + else: + if isNil(rstream.readerLoop): + await readMessage(rstream.rsource, pred) + else: + readLoop(): + let count = rstream.buffer.dataLen() + if count == 0: + if rstream.atEof(): + pred([]) + else: + # Case, when transport's buffer is not yet filled with data. + (0, false) + else: + pred(rstream.buffer.buffer.toOpenArray(0, count - 1)) proc write*(wstream: AsyncStreamWriter, pbytes: pointer, nbytes: int) {.async.} = @@ -613,8 +674,7 @@ proc write*(wstream: AsyncStreamWriter, pbytes: pointer, ## writer stream ``wstream``. ## ## ``nbytes` must be more then zero. - if not wstream.running(): - raise newAsyncStreamIncorrectError("Incorrect stream state") + checkRunning(wstream) if nbytes <= 0: raise newAsyncStreamIncorrectError("Zero length message") @@ -622,27 +682,34 @@ proc write*(wstream: AsyncStreamWriter, pbytes: pointer, var res: int try: res = await write(wstream.tsource, pbytes, nbytes) - except CancelledError: - raise + except CancelledError as exc: + raise exc + except AsyncStreamError as exc: + raise exc except CatchableError as exc: raise newAsyncStreamWriteError(exc) if res != nbytes: raise newAsyncStreamIncompleteError() + wstream.bytesCount = wstream.bytesCount + uint64(nbytes) else: if isNil(wstream.writerLoop): await write(wstream.wsource, pbytes, nbytes) + wstream.bytesCount = wstream.bytesCount + uint64(nbytes) else: var item = WriteItem(kind: Pointer) item.data1 = pbytes item.size = nbytes item.future = newFuture[void]("async.stream.write(pointer)") - await wstream.queue.put(item) try: + await wstream.queue.put(item) await item.future - except CancelledError: - raise - except: - raise newAsyncStreamWriteError(item.future.error) + wstream.bytesCount = wstream.bytesCount + uint64(item.size) + except CancelledError as exc: + raise exc + except AsyncStreamError as exc: + raise exc + except CatchableError as exc: + raise newAsyncStreamWriteError(exc) proc write*(wstream: AsyncStreamWriter, sbytes: seq[byte], msglen = -1) {.async.} = @@ -654,10 +721,8 @@ proc write*(wstream: AsyncStreamWriter, sbytes: seq[byte], ## If ``msglen < 0`` whole sequence ``sbytes`` will be writen to stream. ## If ``msglen > len(sbytes)`` only ``len(sbytes)`` bytes will be written to ## stream. + checkRunning(wstream) let length = if msglen <= 0: len(sbytes) else: min(msglen, len(sbytes)) - - if not wstream.running(): - raise newAsyncStreamIncorrectError("Incorrect stream state") if length <= 0: raise newAsyncStreamIncorrectError("Zero length message") @@ -665,15 +730,17 @@ proc write*(wstream: AsyncStreamWriter, sbytes: seq[byte], var res: int try: res = await write(wstream.tsource, sbytes, msglen) - except CancelledError: - raise + except CancelledError as exc: + raise exc except CatchableError as exc: raise newAsyncStreamWriteError(exc) if res != length: raise newAsyncStreamIncompleteError() + wstream.bytesCount = wstream.bytesCount + uint64(msglen) else: if isNil(wstream.writerLoop): await write(wstream.wsource, sbytes, msglen) + wstream.bytesCount = wstream.bytesCount + uint64(msglen) else: var item = WriteItem(kind: Sequence) if not isLiteral(sbytes): @@ -682,13 +749,16 @@ proc write*(wstream: AsyncStreamWriter, sbytes: seq[byte], item.data2 = sbytes item.size = length item.future = newFuture[void]("async.stream.write(seq)") - await wstream.queue.put(item) try: + await wstream.queue.put(item) await item.future - except CancelledError: - raise - except: - raise newAsyncStreamWriteError(item.future.error) + wstream.bytesCount = wstream.bytesCount + uint64(item.size) + except CancelledError as exc: + raise exc + except AsyncStreamError as exc: + raise exc + except CatchableError as exc: + raise newAsyncStreamWriteError(exc) proc write*(wstream: AsyncStreamWriter, sbytes: string, msglen = -1) {.async.} = @@ -699,10 +769,8 @@ proc write*(wstream: AsyncStreamWriter, sbytes: string, ## If ``msglen < 0`` whole string ``sbytes`` will be writen to stream. ## If ``msglen > len(sbytes)`` only ``len(sbytes)`` bytes will be written to ## stream. + checkRunning(wstream) let length = if msglen <= 0: len(sbytes) else: min(msglen, len(sbytes)) - - if not wstream.running(): - raise newAsyncStreamIncorrectError("Incorrect stream state") if length <= 0: raise newAsyncStreamIncorrectError("Zero length message") @@ -710,15 +778,17 @@ proc write*(wstream: AsyncStreamWriter, sbytes: string, var res: int try: res = await write(wstream.tsource, sbytes, msglen) - except CancelledError: - raise + except CancelledError as exc: + raise exc except CatchableError as exc: raise newAsyncStreamWriteError(exc) if res != length: raise newAsyncStreamIncompleteError() + wstream.bytesCount = wstream.bytesCount + uint64(msglen) else: if isNil(wstream.writerLoop): await write(wstream.wsource, sbytes, msglen) + wstream.bytesCount = wstream.bytesCount + uint64(msglen) else: var item = WriteItem(kind: String) if not isLiteral(sbytes): @@ -727,18 +797,20 @@ proc write*(wstream: AsyncStreamWriter, sbytes: string, item.data3 = sbytes item.size = length item.future = newFuture[void]("async.stream.write(string)") - await wstream.queue.put(item) try: + await wstream.queue.put(item) await item.future - except CancelledError: - raise - except: - raise newAsyncStreamWriteError(item.future.error) + wstream.bytesCount = wstream.bytesCount + uint64(item.size) + except CancelledError as exc: + raise exc + except AsyncStreamError as exc: + raise exc + except CatchableError as exc: + raise newAsyncStreamWriteError(exc) proc finish*(wstream: AsyncStreamWriter) {.async.} = ## Finish write stream ``wstream``. - if not wstream.running(): - raise newAsyncStreamIncorrectError("Incorrect stream state") + checkRunning(wstream) if not isNil(wstream.wsource): if isNil(wstream.writerLoop): @@ -747,13 +819,15 @@ proc finish*(wstream: AsyncStreamWriter) {.async.} = var item = WriteItem(kind: Pointer) item.size = 0 item.future = newFuture[void]("async.stream.finish") - await wstream.queue.put(item) try: + await wstream.queue.put(item) await item.future - except CancelledError: - raise - except: - raise newAsyncStreamWriteError(item.future.error) + except CancelledError as exc: + raise exc + except AsyncStreamError as exc: + raise exc + except CatchableError as exc: + raise newAsyncStreamWriteError(exc) proc join*(rw: AsyncStreamRW): Future[void] = ## Get Future[void] which will be completed when stream become finished or @@ -766,12 +840,12 @@ proc join*(rw: AsyncStreamRW): Future[void] = proc continuation(udata: pointer) {.gcsafe.} = retFuture.complete() - proc cancel(udata: pointer) {.gcsafe.} = + proc cancellation(udata: pointer) {.gcsafe.} = rw.future.removeCallback(continuation, cast[pointer](retFuture)) if not(rw.future.finished()): rw.future.addCallback(continuation, cast[pointer](retFuture)) - rw.future.cancelCallback = cancel + rw.future.cancelCallback = cancellation else: retFuture.complete() @@ -818,7 +892,7 @@ proc close*(rw: AsyncStreamRW) = proc closeWait*(rw: AsyncStreamRW): Future[void] = ## Close and frees resources of stream ``rw``. rw.close() - result = rw.join() + rw.join() proc startReader(rstream: AsyncStreamReader) = rstream.state = Running @@ -981,8 +1055,9 @@ proc newAsyncStreamReader*[T](rsource: AsyncStreamReader, ## ## ``udata`` - user object which will be associated with new AsyncStreamReader ## object. - result = new AsyncStreamReader - result.init(rsource, loop, bufferSize, udata) + var res = AsyncStreamReader() + res.init(rsource, loop, bufferSize, udata) + res proc newAsyncStreamReader*(rsource: AsyncStreamReader, loop: StreamReaderLoop, @@ -994,8 +1069,9 @@ proc newAsyncStreamReader*(rsource: AsyncStreamReader, ## ``loop`` is main reading loop procedure. ## ## ``bufferSize`` is internal buffer size. - result = new AsyncStreamReader - result.init(rsource, loop, bufferSize) + var res = AsyncStreamReader() + res.init(rsource, loop, bufferSize) + res proc newAsyncStreamReader*[T](tsource: StreamTransport, udata: ref T): AsyncStreamReader = @@ -1004,14 +1080,16 @@ proc newAsyncStreamReader*[T](tsource: StreamTransport, ## ## ``udata`` - user object which will be associated with new AsyncStreamWriter ## object. - result = new AsyncStreamReader - result.init(tsource, udata) + var res = AsyncStreamReader() + res.init(tsource, udata) + res proc newAsyncStreamReader*(tsource: StreamTransport): AsyncStreamReader = ## Create new AsyncStreamReader object, which will use stream transport ## ``tsource`` as source data channel. - result = new AsyncStreamReader - result.init(tsource) + var res = AsyncStreamReader() + res.init(tsource) + res proc newAsyncStreamWriter*[T](wsource: AsyncStreamWriter, loop: StreamWriterLoop, @@ -1026,8 +1104,9 @@ proc newAsyncStreamWriter*[T](wsource: AsyncStreamWriter, ## ## ``udata`` - user object which will be associated with new AsyncStreamWriter ## object. - result = new AsyncStreamWriter - result.init(wsource, loop, queueSize, udata) + var res = AsyncStreamWriter() + res.init(wsource, loop, queueSize, udata) + res proc newAsyncStreamWriter*(wsource: AsyncStreamWriter, loop: StreamWriterLoop, @@ -1039,8 +1118,9 @@ proc newAsyncStreamWriter*(wsource: AsyncStreamWriter, ## ``loop`` is main writing loop procedure. ## ## ``queueSize`` is writing queue size (default size is unlimited). - result = new AsyncStreamWriter - result.init(wsource, loop, queueSize) + var res = AsyncStreamWriter() + res.init(wsource, loop, queueSize) + res proc newAsyncStreamWriter*[T](tsource: StreamTransport, udata: ref T): AsyncStreamWriter = @@ -1049,14 +1129,16 @@ proc newAsyncStreamWriter*[T](tsource: StreamTransport, ## ## ``udata`` - user object which will be associated with new AsyncStreamWriter ## object. - result = new AsyncStreamWriter - result.init(tsource, udata) + var res = AsyncStreamWriter() + res.init(tsource, udata) + res proc newAsyncStreamWriter*(tsource: StreamTransport): AsyncStreamWriter = ## Create new AsyncStreamWriter object which will use stream transport ## ``tsource`` as data channel. - result = new AsyncStreamWriter - result.init(tsource) + var res = AsyncStreamWriter() + res.init(tsource) + res proc newAsyncStreamWriter*[T](wsource: AsyncStreamWriter, udata: ref T): AsyncStreamWriter = @@ -1064,13 +1146,15 @@ proc newAsyncStreamWriter*[T](wsource: AsyncStreamWriter, ## ## ``udata`` - user object which will be associated with new AsyncStreamWriter ## object. - result = new AsyncStreamWriter - result.init(wsource, udata) + var res = AsyncStreamWriter() + res.init(wsource, udata) + res proc newAsyncStreamWriter*(wsource: AsyncStreamWriter): AsyncStreamWriter = ## Create copy of AsyncStreamWriter object ``wsource``. - result = new AsyncStreamWriter - result.init(wsource) + var res = AsyncStreamWriter() + res.init(wsource) + res proc newAsyncStreamReader*[T](rsource: AsyncStreamWriter, udata: ref T): AsyncStreamWriter = @@ -1078,15 +1162,17 @@ proc newAsyncStreamReader*[T](rsource: AsyncStreamWriter, ## ## ``udata`` - user object which will be associated with new AsyncStreamReader ## object. - result = new AsyncStreamReader - result.init(rsource, udata) + var res = AsyncStreamReader() + res.init(rsource, udata) + res proc newAsyncStreamReader*(rsource: AsyncStreamReader): AsyncStreamReader = ## Create copy of AsyncStreamReader object ``rsource``. - result = new AsyncStreamReader - result.init(rsource) + var res = AsyncStreamReader() + res.init(rsource) + res proc getUserData*[T](rw: AsyncStreamRW): T {.inline.} = ## Obtain user data associated with AsyncStreamReader or AsyncStreamWriter ## object ``rw``. - result = cast[T](rw.udata) + cast[T](rw.udata) diff --git a/chronos/streams/boundstream.nim b/chronos/streams/boundstream.nim new file mode 100644 index 0000000..0136a86 --- /dev/null +++ b/chronos/streams/boundstream.nim @@ -0,0 +1,212 @@ +# +# Chronos Asynchronous Bound Stream +# (c) Copyright 2021-Present +# Status Research & Development GmbH +# +# Licensed under either of +# Apache License, version 2.0, (LICENSE-APACHEv2) +# MIT license (LICENSE-MIT) + +## This module implements bounded stream reading and writing. +## +## For stream reading it means that you should read exactly bounded size of +## bytes. +## +## For stream writing it means that you should write exactly bounded size +## of bytes, and if you wrote not enough bytes error will appear on stream +## close. +import ../asyncloop, ../timer +import asyncstream, ../transports/stream, ../transports/common +export asyncstream, stream, timer, common + +type + BoundedStreamReader* = ref object of AsyncStreamReader + boundSize: uint64 + offset: uint64 + + BoundedStreamWriter* = ref object of AsyncStreamWriter + boundSize: uint64 + offset: uint64 + + BoundedStreamError* = object of AsyncStreamError + BoundedStreamIncompleteError* = object of BoundedStreamError + BoundedStreamOverflowError* = object of BoundedStreamError + + BoundedStreamRW* = BoundedStreamReader | BoundedStreamWriter + +const + BoundedBufferSize* = 4096 + +template newBoundedStreamIncompleteError*(): ref BoundedStreamError = + newException(BoundedStreamIncompleteError, + "Stream boundary is not reached yet") +template newBoundedStreamOverflowError*(): ref BoundedStreamError = + newException(BoundedStreamOverflowError, "Stream boundary exceeded") + +proc boundedReadLoop(stream: AsyncStreamReader) {.async.} = + var rstream = cast[BoundedStreamReader](stream) + rstream.state = AsyncStreamState.Running + while true: + if rstream.offset < rstream.boundSize: + let toRead = int(min(rstream.boundSize - rstream.offset, + uint64(rstream.buffer.bufferLen()))) + try: + await rstream.rsource.readExactly(rstream.buffer.getBuffer(), toRead) + rstream.offset = rstream.offset + uint64(toRead) + rstream.buffer.update(toRead) + await rstream.buffer.transfer() + except AsyncStreamIncompleteError: + rstream.state = AsyncStreamState.Error + rstream.error = newBoundedStreamIncompleteError() + except AsyncStreamReadError as exc: + rstream.state = AsyncStreamState.Error + rstream.error = exc + except CancelledError: + rstream.state = AsyncStreamState.Stopped + + if rstream.state != AsyncStreamState.Running: + break + else: + rstream.state = AsyncStreamState.Finished + await rstream.buffer.transfer() + break + + if rstream.state in {AsyncStreamState.Stopped, AsyncStreamState.Error}: + # We need to notify consumer about error/close, but we do not care about + # incoming data anymore. + rstream.buffer.forget() + +proc boundedWriteLoop(stream: AsyncStreamWriter) {.async.} = + var wstream = cast[BoundedStreamWriter](stream) + + wstream.state = AsyncStreamState.Running + while true: + var + item: WriteItem + error: ref AsyncStreamError + + try: + item = await wstream.queue.get() + if item.size > 0: + if uint64(item.size) <= (wstream.boundSize - wstream.offset): + # Writing chunk data. + case item.kind + of WriteType.Pointer: + await wstream.wsource.write(item.data1, item.size) + of WriteType.Sequence: + await wstream.wsource.write(addr item.data2[0], item.size) + of WriteType.String: + await wstream.wsource.write(addr item.data3[0], item.size) + wstream.offset = wstream.offset + uint64(item.size) + item.future.complete() + else: + wstream.state = AsyncStreamState.Error + error = newBoundedStreamOverflowError() + else: + if wstream.offset != wstream.boundSize: + wstream.state = AsyncStreamState.Error + error = newBoundedStreamIncompleteError() + else: + wstream.state = AsyncStreamState.Finished + item.future.complete() + except CancelledError: + wstream.state = AsyncStreamState.Stopped + error = newAsyncStreamUseClosedError() + except AsyncStreamWriteError as exc: + wstream.state = AsyncStreamState.Error + error = exc + except AsyncStreamIncompleteError as exc: + wstream.state = AsyncStreamState.Error + error = exc + + if wstream.state != AsyncStreamState.Running: + if wstream.state == AsyncStreamState.Finished: + error = newAsyncStreamUseClosedError() + else: + if not(isNil(item.future)): + if not(item.future.finished()): + item.future.fail(error) + while not(wstream.queue.empty()): + let pitem = wstream.queue.popFirstNoWait() + if not(pitem.future.finished()): + pitem.future.fail(error) + break + +proc bytesLeft*(stream: BoundedStreamRW): uint64 = + ## Returns number of bytes left in stream. + stream.boundSize - stream.bytesCount + +proc init*[T](child: BoundedStreamReader, rsource: AsyncStreamReader, + bufferSize = BoundedBufferSize, udata: ref T) = + init(cast[AsyncStreamReader](child), rsource, boundedReadLoop, bufferSize, + udata) + +proc init*(child: BoundedStreamReader, rsource: AsyncStreamReader, + bufferSize = BoundedBufferSize) = + init(cast[AsyncStreamReader](child), rsource, boundedReadLoop, bufferSize) + +proc newBoundedStreamReader*[T](rsource: AsyncStreamReader, + boundSize: uint64, + bufferSize = BoundedBufferSize, + udata: ref T): BoundedStreamReader = + var res = BoundedStreamReader(boundSize: boundSize) + res.init(rsource, bufferSize, udata) + res + +proc newBoundedStreamReader*(rsource: AsyncStreamReader, + boundSize: uint64, + bufferSize = BoundedBufferSize, + ): BoundedStreamReader = + doAssert(boundSize >= 0) + var res = BoundedStreamReader(boundSize: boundSize) + res.init(rsource, bufferSize) + res + +proc init*[T](child: BoundedStreamWriter, wsource: AsyncStreamWriter, + queueSize = AsyncStreamDefaultQueueSize, udata: ref T) = + init(cast[AsyncStreamWriter](child), wsource, boundedWriteLoop, queueSize, + udata) + +proc init*(child: BoundedStreamWriter, wsource: AsyncStreamWriter, + queueSize = AsyncStreamDefaultQueueSize) = + init(cast[AsyncStreamWriter](child), wsource, boundedWriteLoop, queueSize) + +proc newBoundedStreamWriter*[T](wsource: AsyncStreamWriter, + boundSize: uint64, + queueSize = AsyncStreamDefaultQueueSize, + udata: ref T): BoundedStreamWriter = + var res = BoundedStreamWriter(boundSize: boundSize) + res.init(wsource, queueSize, udata) + res + +proc newBoundedStreamWriter*(wsource: AsyncStreamWriter, + boundSize: uint64, + queueSize = AsyncStreamDefaultQueueSize, + ): BoundedStreamWriter = + var res = BoundedStreamWriter(boundSize: boundSize) + res.init(wsource, queueSize) + res + +proc close*(rw: BoundedStreamRW) = + ## Close and frees resources of stream ``rw``. + ## + ## Note close() procedure is not completed immediately. + if rw.closed(): + raise newAsyncStreamIncorrectError("Stream is already closed!") + # We do not want to raise one more IncompleteError if it was already raised + # by one of the read()/write() primitives. + if rw.state != AsyncStreamState.Error: + if rw.bytesLeft() != 0'u64: + raise newBoundedStreamIncompleteError() + when rw is BoundedStreamReader: + cast[AsyncStreamReader](rw).close() + elif rw is BoundedStreamWriter: + cast[AsyncStreamWriter](rw).close() + +proc closeWait*(rw: BoundedStreamRW): Future[void] = + ## Close and frees resources of stream ``rw``. + rw.close() + when rw is BoundedStreamReader: + cast[AsyncStreamReader](rw).join() + elif rw is BoundedStreamWriter: + cast[AsyncStreamWriter](rw).join() diff --git a/chronos/streams/chunkstream.nim b/chronos/streams/chunkstream.nim index b847cac..9e65d8f 100644 --- a/chronos/streams/chunkstream.nim +++ b/chronos/streams/chunkstream.nim @@ -21,25 +21,44 @@ type ChunkedStreamReader* = ref object of AsyncStreamReader ChunkedStreamWriter* = ref object of AsyncStreamWriter - ChunkedStreamError* = object of CatchableError + ChunkedStreamError* = object of AsyncStreamError ChunkedStreamProtocolError* = object of ChunkedStreamError + ChunkedStreamIncompleteError* = object of ChunkedStreamError -proc newProtocolError(): ref Exception {.inline.} = +proc newChunkedProtocolError(): ref ChunkedStreamProtocolError {.inline.} = newException(ChunkedStreamProtocolError, "Protocol error!") +proc newChunkedIncompleteError(): ref ChunkedStreamIncompleteError {.inline.} = + newException(ChunkedStreamIncompleteError, "Incomplete data received!") + +proc `-`(x: uint32): uint32 {.inline.} = + result = (0xFFFF_FFFF'u32 - x) + 1'u32 + +proc LT(x, y: uint32): uint32 {.inline.} = + let z = x - y + (z xor ((y xor x) and (y xor z))) shr 31 + +proc hexValue(c: byte): int = + let x = uint32(c) - 0x30'u32 + let y = uint32(c) - 0x41'u32 + let z = uint32(c) - 0x61'u32 + let r = ((x + 1'u32) and -LT(x, 10)) or + ((y + 11'u32) and -LT(y, 6)) or + ((z + 11'u32) and -LT(z, 6)) + int(r) - 1 + proc getChunkSize(buffer: openarray[byte]): uint64 = # We using `uint64` representation, but allow only 2^32 chunk size, # ChunkHeaderSize. + var res = 0'u64 for i in 0..= byte('0') and ch <= byte('9'): - result = (result shl 4) or uint64(ch - byte('0')) - else: - result = (result shl 4) or uint64((ch and 0x0F) + 9) + let value = hexValue(buffer[i]) + if value >= 0: + res = (res shl 4) or uint64(value) else: - result = 0xFFFF_FFFF_FFFF_FFFF'u64 + res = 0xFFFF_FFFF_FFFF_FFFF'u64 break + res proc setChunkSize(buffer: var openarray[byte], length: int64): int = # Store length as chunk header size (hexadecimal value) with CRLF. @@ -53,7 +72,7 @@ proc setChunkSize(buffer: var openarray[byte], length: int64): int = buffer[0] = byte('0') buffer[1] = byte(0x0D) buffer[2] = byte(0x0A) - result = 3 + 3 else: while n != 0: var v = length and n @@ -68,161 +87,116 @@ proc setChunkSize(buffer: var openarray[byte], length: int64): int = i = i - 4 buffer[c] = byte(0x0D) buffer[c + 1] = byte(0x0A) - result = c + 2 + c + 2 proc chunkedReadLoop(stream: AsyncStreamReader) {.async.} = var rstream = cast[ChunkedStreamReader](stream) var buffer = newSeq[byte](1024) rstream.state = AsyncStreamState.Running - try: - while true: + while true: + try: # Reading chunk size - var ruFut1 = awaitne rstream.rsource.readUntil(addr buffer[0], 1024, CRLF) - if ruFut1.failed(): - rstream.error = ruFut1.error - rstream.state = AsyncStreamState.Error - break - - let length = ruFut1.read() - var chunksize = getChunkSize(buffer.toOpenArray(0, - length - len(CRLF) - 1)) + let res = await rstream.rsource.readUntil(addr buffer[0], 1024, CRLF) + var chunksize = getChunkSize(buffer.toOpenArray(0, res - len(CRLF) - 1)) if chunksize == 0xFFFF_FFFF_FFFF_FFFF'u64: - rstream.error = newProtocolError() + rstream.error = newChunkedProtocolError() rstream.state = AsyncStreamState.Error - break elif chunksize > 0'u64: while chunksize > 0'u64: let toRead = min(int(chunksize), rstream.buffer.bufferLen()) - var reFut2 = awaitne rstream.rsource.readExactly( - rstream.buffer.getBuffer(), toRead) - if reFut2.failed(): - rstream.error = reFut2.error - rstream.state = AsyncStreamState.Error - break - + await rstream.rsource.readExactly(rstream.buffer.getBuffer(), toRead) rstream.buffer.update(toRead) await rstream.buffer.transfer() chunksize = chunksize - uint64(toRead) - if rstream.state != AsyncStreamState.Running: - break + if rstream.state == AsyncStreamState.Running: + # Reading chunk trailing CRLF + await rstream.rsource.readExactly(addr buffer[0], 2) - # Reading chunk trailing CRLF - var reFut3 = awaitne rstream.rsource.readExactly(addr buffer[0], 2) - if reFut3.failed(): - rstream.error = reFut3.error - rstream.state = AsyncStreamState.Error - break - - if buffer[0] != CRLF[0] or buffer[1] != CRLF[1]: - rstream.error = newProtocolError() - rstream.state = AsyncStreamState.Error - break + if buffer[0] != CRLF[0] or buffer[1] != CRLF[1]: + rstream.error = newChunkedProtocolError() + rstream.state = AsyncStreamState.Error else: # Reading trailing line for last chunk - var ruFut4 = awaitne rstream.rsource.readUntil(addr buffer[0], - len(buffer), CRLF) - if ruFut4.failed(): - rstream.error = ruFut4.error - rstream.state = AsyncStreamState.Error - break - + discard await rstream.rsource.readUntil(addr buffer[0], + len(buffer), CRLF) rstream.state = AsyncStreamState.Finished await rstream.buffer.transfer() - break + except CancelledError: + rstream.state = AsyncStreamState.Stopped + except AsyncStreamIncompleteError: + rstream.state = AsyncStreamState.Error + rstream.error = newChunkedIncompleteError() + except AsyncStreamReadError as exc: + rstream.state = AsyncStreamState.Error + rstream.error = exc - except CancelledError: - rstream.state = AsyncStreamState.Stopped - finally: if rstream.state in {AsyncStreamState.Stopped, AsyncStreamState.Error}: # We need to notify consumer about error/close, but we do not care about # incoming data anymore. rstream.buffer.forget() + break proc chunkedWriteLoop(stream: AsyncStreamWriter) {.async.} = var wstream = cast[ChunkedStreamWriter](stream) var buffer: array[16, byte] - var wFut1, wFut2: Future[void] - var error: ref Exception + var error: ref AsyncStreamError wstream.state = AsyncStreamState.Running - try: - while true: - # Getting new item from stream's queue. - var item = await wstream.queue.get() + while true: + var item: WriteItem + # Getting new item from stream's queue. + try: + item = await wstream.queue.get() # `item.size == 0` is marker of stream finish, while `item.size != 0` is # data's marker. if item.size > 0: let length = setChunkSize(buffer, int64(item.size)) # Writing chunk header CRLF. - wFut1 = awaitne wstream.wsource.write(addr buffer[0], length) - if wFut1.failed(): - error = wFut1.error - item.future.fail(error) - continue - + await wstream.wsource.write(addr buffer[0], length) # Writing chunk data. - if item.kind == Pointer: - wFut2 = awaitne wstream.wsource.write(item.data1, item.size) - elif item.kind == Sequence: - wFut2 = awaitne wstream.wsource.write(addr item.data2[0], item.size) - elif item.kind == String: - wFut2 = awaitne wstream.wsource.write(addr item.data3[0], item.size) - if wFut2.failed(): - error = wFut2.error - item.future.fail(error) - continue - + case item.kind + of WriteType.Pointer: + await wstream.wsource.write(item.data1, item.size) + of WriteType.Sequence: + await wstream.wsource.write(addr item.data2[0], item.size) + of WriteType.String: + await wstream.wsource.write(addr item.data3[0], item.size) # Writing chunk footer CRLF. - var wFut3 = awaitne wstream.wsource.write(CRLF) - if wFut3.failed(): - error = wFut3.error - item.future.fail(error) - continue - + await wstream.wsource.write(CRLF) # Everything is fine, completing queue item's future. item.future.complete() else: let length = setChunkSize(buffer, 0'i64) - # Write finish chunk `0`. - wFut1 = awaitne wstream.wsource.write(addr buffer[0], length) - if wFut1.failed(): - error = wFut1.error - item.future.fail(error) - # We break here, because this is last chunk - break - + await wstream.wsource.write(addr buffer[0], length) # Write trailing CRLF. - wFut2 = awaitne wstream.wsource.write(CRLF) - if wFut2.failed(): - error = wFut2.error - item.future.fail(error) - # We break here, because this is last chunk - break - + await wstream.wsource.write(CRLF) # Everything is fine, completing queue item's future. item.future.complete() - # Set stream state to Finished. wstream.state = AsyncStreamState.Finished - break - except CancelledError: - wstream.state = AsyncStreamState.Stopped - finally: - if wstream.state == AsyncStreamState.Stopped: - while len(wstream.queue) > 0: - let item = wstream.queue.popFirstNoWait() - if not(item.future.finished()): - item.future.complete() - elif wstream.state == AsyncStreamState.Error: - while len(wstream.queue) > 0: - let item = wstream.queue.popFirstNoWait() - if not(item.future.finished()): - if not isNil(error): + except CancelledError: + wstream.state = AsyncStreamState.Stopped + error = newAsyncStreamUseClosedError() + except AsyncStreamError as exc: + wstream.state = AsyncStreamState.Error + error = exc + + if wstream.state != AsyncStreamState.Running: + if wstream.state == AsyncStreamState.Finished: + error = newAsyncStreamUseClosedError() + else: + if not(isNil(item.future)): + if not(item.future.finished()): item.future.fail(error) + while not(wstream.queue.empty()): + let pitem = wstream.queue.popFirstNoWait() + if not(pitem.future.finished()): + pitem.future.fail(error) + break proc init*[T](child: ChunkedStreamReader, rsource: AsyncStreamReader, bufferSize = ChunkBufferSize, udata: ref T) = @@ -236,14 +210,16 @@ proc init*(child: ChunkedStreamReader, rsource: AsyncStreamReader, proc newChunkedStreamReader*[T](rsource: AsyncStreamReader, bufferSize = AsyncStreamDefaultBufferSize, udata: ref T): ChunkedStreamReader = - result = new ChunkedStreamReader - result.init(rsource, bufferSize, udata) + var res = ChunkedStreamReader() + res.init(rsource, bufferSize, udata) + res proc newChunkedStreamReader*(rsource: AsyncStreamReader, bufferSize = AsyncStreamDefaultBufferSize, ): ChunkedStreamReader = - result = new ChunkedStreamReader - result.init(rsource, bufferSize) + var res = ChunkedStreamReader() + res.init(rsource, bufferSize) + res proc init*[T](child: ChunkedStreamWriter, wsource: AsyncStreamWriter, queueSize = AsyncStreamDefaultQueueSize, udata: ref T) = @@ -257,11 +233,13 @@ proc init*(child: ChunkedStreamWriter, wsource: AsyncStreamWriter, proc newChunkedStreamWriter*[T](wsource: AsyncStreamWriter, queueSize = AsyncStreamDefaultQueueSize, udata: ref T): ChunkedStreamWriter = - result = new ChunkedStreamWriter - result.init(wsource, queueSize, udata) + var res = ChunkedStreamWriter() + res.init(wsource, queueSize, udata) + res proc newChunkedStreamWriter*(wsource: AsyncStreamWriter, queueSize = AsyncStreamDefaultQueueSize, ): ChunkedStreamWriter = - result = new ChunkedStreamWriter - result.init(wsource, queueSize) + var res = ChunkedStreamWriter() + res.init(wsource, queueSize) + res diff --git a/chronos/streams/tlsstream.nim b/chronos/streams/tlsstream.nim index ec8b267..32258a7 100644 --- a/chronos/streams/tlsstream.nim +++ b/chronos/streams/tlsstream.nim @@ -89,11 +89,11 @@ type SomeTLSStreamType* = TLSStreamReader|TLSStreamWriter|TLSAsyncStream - TLSStreamError* = object of CatchableError + TLSStreamError* = object of AsyncStreamError TLSStreamProtocolError* = object of TLSStreamError errCode*: int -template newTLSStreamProtocolError[T](message: T): ref Exception = +template newTLSStreamProtocolError[T](message: T): ref TLSStreamProtocolError = var msg = "" var code = 0 when T is string: @@ -110,13 +110,13 @@ template newTLSStreamProtocolError[T](message: T): ref Exception = err.errCode = code err -proc raiseTLSStreamProtoError*[T](message: T) = +template raiseTLSStreamProtoError*[T](message: T) = raise newTLSStreamProtocolError(message) proc tlsWriteLoop(stream: AsyncStreamWriter) {.async.} = var wstream = cast[TLSStreamWriter](stream) var engine: ptr SslEngineContext - var error: ref Exception + var error: ref AsyncStreamError if wstream.kind == TLSStreamKind.Server: engine = addr wstream.scontext.eng @@ -125,86 +125,77 @@ proc tlsWriteLoop(stream: AsyncStreamWriter) {.async.} = wstream.state = AsyncStreamState.Running - try: - var length: uint - while true: + while true: + var item: WriteItem + try: var state = engine.sslEngineCurrentState() - if (state and SSL_CLOSED) == SSL_CLOSED: wstream.state = AsyncStreamState.Finished - break - - if (state and (SSL_RECVREC or SSL_RECVAPP)) != 0: - if not(wstream.switchToReader.isSet()): - wstream.switchToReader.fire() - - if (state and (SSL_SENDREC or SSL_SENDAPP)) == 0: - await wstream.switchToWriter.wait() - wstream.switchToWriter.clear() - # We need to refresh `state` because we just returned from readerLoop. - continue - - if (state and SSL_SENDREC) == SSL_SENDREC: - # TLS record needs to be sent over stream. - length = 0'u - var buf = sslEngineSendrecBuf(engine, length) - doAssert(length != 0 and not isNil(buf)) - var fut = awaitne wstream.wsource.write(buf, int(length)) - if fut.cancelled(): - raise fut.error - elif fut.failed(): - error = fut.error - break - sslEngineSendrecAck(engine, length) - continue - - if (state and SSL_SENDAPP) == SSL_SENDAPP: - # Application data can be sent over stream. - if not(wstream.handshaked): - wstream.stream.reader.handshaked = true - wstream.handshaked = true - if not(isNil(wstream.handshakeFut)): - wstream.handshakeFut.complete() - - var item = await wstream.queue.get() - if item.size > 0: - length = 0'u - var buf = sslEngineSendappBuf(engine, length) - let toWrite = min(int(length), item.size) - copyOut(buf, item, toWrite) - if int(length) >= item.size: - # BearSSL is ready to accept whole item size. - sslEngineSendappAck(engine, uint(item.size)) - sslEngineFlush(engine, 0) - item.future.complete() - else: - # BearSSL is not ready to accept whole item, so we will send only - # part of item and adjust offset. - item.offset = item.offset + int(length) - item.size = item.size - int(length) - wstream.queue.addFirstNoWait(item) - sslEngineSendappAck(engine, length) - continue + else: + if (state and (SSL_RECVREC or SSL_RECVAPP)) != 0: + if not(wstream.switchToReader.isSet()): + wstream.switchToReader.fire() + if (state and (SSL_SENDREC or SSL_SENDAPP)) == 0: + await wstream.switchToWriter.wait() + wstream.switchToWriter.clear() + # We need to refresh `state` because we just returned from readerLoop. else: - # Zero length item means finish - wstream.state = AsyncStreamState.Finished - break + if (state and SSL_SENDREC) == SSL_SENDREC: + # TLS record needs to be sent over stream. + var length = 0'u + var buf = sslEngineSendrecBuf(engine, length) + doAssert(length != 0 and not isNil(buf)) + await wstream.wsource.write(buf, int(length)) + sslEngineSendrecAck(engine, length) + elif (state and SSL_SENDAPP) == SSL_SENDAPP: + # Application data can be sent over stream. + if not(wstream.handshaked): + wstream.stream.reader.handshaked = true + wstream.handshaked = true + if not(isNil(wstream.handshakeFut)): + wstream.handshakeFut.complete() + item = await wstream.queue.get() + if item.size > 0: + var length = 0'u + var buf = sslEngineSendappBuf(engine, length) + let toWrite = min(int(length), item.size) + copyOut(buf, item, toWrite) + if int(length) >= item.size: + # BearSSL is ready to accept whole item size. + sslEngineSendappAck(engine, uint(item.size)) + sslEngineFlush(engine, 0) + item.future.complete() + else: + # BearSSL is not ready to accept whole item, so we will send + # only part of item and adjust offset. + item.offset = item.offset + int(length) + item.size = item.size - int(length) + wstream.queue.addFirstNoWait(item) + sslEngineSendappAck(engine, length) + else: + # Zero length item means finish, so we going to trigger TLS + # closure protocol. + sslEngineClose(engine) + except CancelledError: + wstream.state = AsyncStreamState.Stopped + error = newAsyncStreamUseClosedError() + except AsyncStreamError as exc: + wstream.state = AsyncStreamState.Error + error = exc - except CancelledError: - wstream.state = AsyncStreamState.Stopped - - finally: - if wstream.state == AsyncStreamState.Stopped: - while len(wstream.queue) > 0: - let item = wstream.queue.popFirstNoWait() - if not(item.future.finished()): - item.future.complete() - elif wstream.state == AsyncStreamState.Error: - while len(wstream.queue) > 0: - let item = wstream.queue.popFirstNoWait() - if not(item.future.finished()): - item.future.fail(error) - wstream.stream = nil + if wstream.state != AsyncStreamState.Running: + if wstream.state == AsyncStreamState.Finished: + error = newAsyncStreamUseClosedError() + else: + if not(isNil(item.future)): + if not(item.future.finished()): + item.future.fail(error) + while not(wstream.queue.empty()): + let pitem = wstream.queue.popFirstNoWait() + if not(pitem.future.finished()): + pitem.future.fail(error) + wstream.stream = nil + break proc tlsReadLoop(stream: AsyncStreamReader) {.async.} = var rstream = cast[TLSStreamReader](stream) @@ -217,72 +208,61 @@ proc tlsReadLoop(stream: AsyncStreamReader) {.async.} = rstream.state = AsyncStreamState.Running - try: - var length: uint - while true: + while true: + try: var state = engine.sslEngineCurrentState() if (state and SSL_CLOSED) == SSL_CLOSED: let err = engine.sslEngineLastError() if err != 0: - raise newTLSStreamProtocolError(err) - rstream.state = AsyncStreamState.Stopped - break - - if (state and (SSL_SENDREC or SSL_SENDAPP)) != 0: - if not(rstream.switchToWriter.isSet()): - rstream.switchToWriter.fire() - - if (state and (SSL_RECVREC or SSL_RECVAPP)) == 0: - await rstream.switchToReader.wait() - rstream.switchToReader.clear() - # We need to refresh `state` because we just returned from writerLoop. - continue - - if (state and SSL_RECVREC) == SSL_RECVREC: - # TLS records required for further processing - length = 0'u - var buf = sslEngineRecvrecBuf(engine, length) - let res = await rstream.rsource.readOnce(buf, int(length)) - if res > 0: - sslEngineRecvrecAck(engine, uint(res)) - continue + rstream.error = newTLSStreamProtocolError(err) + rstream.state = AsyncStreamState.Error else: rstream.state = AsyncStreamState.Finished - break + else: + if (state and (SSL_SENDREC or SSL_SENDAPP)) != 0: + if not(rstream.switchToWriter.isSet()): + rstream.switchToWriter.fire() + if (state and (SSL_RECVREC or SSL_RECVAPP)) == 0: + await rstream.switchToReader.wait() + rstream.switchToReader.clear() + # We need to refresh `state` because we just returned from writerLoop. + else: + if (state and SSL_RECVREC) == SSL_RECVREC: + # TLS records required for further processing + var length = 0'u + var buf = sslEngineRecvrecBuf(engine, length) + let res = await rstream.rsource.readOnce(buf, int(length)) + if res > 0: + sslEngineRecvrecAck(engine, uint(res)) + else: + # readOnce() returns `0` if stream is at EOF, so we initiate TLS + # closure procedure. + sslEngineClose(engine) + elif (state and SSL_RECVAPP) == SSL_RECVAPP: + # Application data can be recovered. + var length = 0'u + var buf = sslEngineRecvappBuf(engine, length) + await upload(addr rstream.buffer, buf, int(length)) + sslEngineRecvappAck(engine, length) + except CancelledError: + rstream.state = AsyncStreamState.Stopped + except AsyncStreamError as exc: + rstream.error = exc + rstream.state = AsyncStreamState.Error + if not(rstream.handshaked): + rstream.handshaked = true + rstream.stream.writer.handshaked = true + if not(isNil(rstream.handshakeFut)): + rstream.handshakeFut.fail(rstream.error) + rstream.switchToWriter.fire() - if (state and SSL_RECVAPP) == SSL_RECVAPP: - # Application data can be recovered. - length = 0'u - var buf = sslEngineRecvappBuf(engine, length) - await upload(addr rstream.buffer, buf, int(length)) - sslEngineRecvappAck(engine, length) - continue - - except CancelledError: - rstream.state = AsyncStreamState.Stopped - except TLSStreamProtocolError as exc: - rstream.error = exc - rstream.state = AsyncStreamState.Error - if not(rstream.handshaked): - rstream.handshaked = true - rstream.stream.writer.handshaked = true - if not(isNil(rstream.handshakeFut)): - rstream.handshakeFut.fail(rstream.error) - rstream.switchToWriter.fire() - except AsyncStreamReadError as exc: - rstream.error = exc - rstream.state = AsyncStreamState.Error - if not(rstream.handshaked): - rstream.handshaked = true - rstream.stream.writer.handshaked = true - if not(isNil(rstream.handshakeFut)): - rstream.handshakeFut.fail(rstream.error) - rstream.switchToWriter.fire() - finally: - # Perform TLS cleanup procedure - sslEngineClose(engine) - rstream.buffer.forget() - rstream.stream = nil + if rstream.state != AsyncStreamState.Running: + # Perform TLS cleanup procedure + if rstream.state != AsyncStreamState.Finished: + sslEngineClose(engine) + rstream.buffer.forget() + rstream.stream = nil + break proc getSignerAlgo(xc: X509Certificate): int = ## Get certificate's signing algorithm. @@ -291,9 +271,9 @@ proc getSignerAlgo(xc: X509Certificate): int = x509DecoderPush(addr dc, xc.data, xc.dataLen) let err = x509DecoderLastError(addr dc) if err != 0: - result = -1 + -1 else: - result = int(x509DecoderGetSignerKeyType(addr dc)) + int(x509DecoderGetSignerKeyType(addr dc)) proc newTLSClientAsyncStream*(rsource: AsyncStreamReader, wsource: AsyncStreamWriter, @@ -318,54 +298,59 @@ proc newTLSClientAsyncStream*(rsource: AsyncStreamReader, ## ``minVersion`` of bigger then ``maxVersion`` you will get an error. ## ## ``flags`` - custom TLS connection flags. - result = new TLSAsyncStream - var reader = TLSStreamReader(kind: TLSStreamKind.Client) - var writer = TLSStreamWriter(kind: TLSStreamKind.Client) - var switchToWriter = newAsyncEvent() - var switchToReader = newAsyncEvent() - reader.stream = result - writer.stream = result - reader.switchToReader = switchToReader - reader.switchToWriter = switchToWriter - writer.switchToReader = switchToReader - writer.switchToWriter = switchToWriter - result.reader = reader - result.writer = writer - reader.ccontext = addr result.ccontext - writer.ccontext = addr result.ccontext + let switchToWriter = newAsyncEvent() + let switchToReader = newAsyncEvent() + var res = TLSAsyncStream() + var reader = TLSStreamReader( + kind: TLSStreamKind.Client, + stream: res, + switchToReader: switchToReader, + switchToWriter: switchToWriter, + ccontext: addr res.ccontext + ) + var writer = TLSStreamWriter( + kind: TLSStreamKind.Client, + stream: res, + switchToReader: switchToReader, + switchToWriter: switchToWriter, + ccontext: addr res.ccontext + ) + res.reader = reader + res.writer = writer if TLSFlags.NoVerifyHost in flags: - sslClientInitFull(addr result.ccontext, addr result.x509, nil, 0) - initNoAnchor(addr result.xwc, addr result.x509.vtable) - sslEngineSetX509(addr result.ccontext.eng, addr result.xwc.vtable) + sslClientInitFull(addr res.ccontext, addr res.x509, nil, 0) + initNoAnchor(addr res.xwc, addr res.x509.vtable) + sslEngineSetX509(addr res.ccontext.eng, addr res.xwc.vtable) else: - sslClientInitFull(addr result.ccontext, addr result.x509, + sslClientInitFull(addr res.ccontext, addr res.x509, unsafeAddr MozillaTrustAnchors[0], len(MozillaTrustAnchors)) let size = max(SSL_BUFSIZE_BIDI, bufferSize) - result.sbuffer = newSeq[byte](size) - sslEngineSetBuffer(addr result.ccontext.eng, addr result.sbuffer[0], - uint(len(result.sbuffer)), 1) - sslEngineSetVersions(addr result.ccontext.eng, uint16(minVersion), + res.sbuffer = newSeq[byte](size) + sslEngineSetBuffer(addr res.ccontext.eng, addr res.sbuffer[0], + uint(len(res.sbuffer)), 1) + sslEngineSetVersions(addr res.ccontext.eng, uint16(minVersion), uint16(maxVersion)) if TLSFlags.NoVerifyServerName in flags: - let err = sslClientReset(addr result.ccontext, "", 0) + let err = sslClientReset(addr res.ccontext, "", 0) if err == 0: raise newException(TLSStreamError, "Could not initialize TLS layer") else: if len(serverName) == 0: raise newException(TLSStreamError, "serverName must not be empty string") - let err = sslClientReset(addr result.ccontext, serverName, 0) + let err = sslClientReset(addr res.ccontext, serverName, 0) if err == 0: raise newException(TLSStreamError, "Could not initialize TLS layer") - init(cast[AsyncStreamWriter](result.writer), wsource, tlsWriteLoop, + init(cast[AsyncStreamWriter](res.writer), wsource, tlsWriteLoop, bufferSize) - init(cast[AsyncStreamReader](result.reader), rsource, tlsReadLoop, + init(cast[AsyncStreamReader](res.reader), rsource, tlsReadLoop, bufferSize) + res proc newTLSServerAsyncStream*(rsource: AsyncStreamReader, wsource: AsyncStreamWriter, @@ -395,98 +380,104 @@ proc newTLSServerAsyncStream*(rsource: AsyncStreamReader, if isNil(certificate) or len(certificate.certs) == 0: raiseTLSStreamProtoError("Incorrect certificate") - result = new TLSAsyncStream - var reader = TLSStreamReader(kind: TLSStreamKind.Server) - var writer = TLSStreamWriter(kind: TLSStreamKind.Server) - var switchToWriter = newAsyncEvent() - var switchToReader = newAsyncEvent() - reader.stream = result - writer.stream = result - reader.switchToReader = switchToReader - reader.switchToWriter = switchToWriter - writer.switchToReader = switchToReader - writer.switchToWriter = switchToWriter - result.reader = reader - result.writer = writer - reader.scontext = addr result.scontext - writer.scontext = addr result.scontext + let switchToWriter = newAsyncEvent() + let switchToReader = newAsyncEvent() + + var res = TLSAsyncStream() + var reader = TLSStreamReader( + kind: TLSStreamKind.Server, + stream: res, + switchToReader: switchToReader, + switchToWriter: switchToWriter, + scontext: addr res.scontext + ) + var writer = TLSStreamWriter( + kind: TLSStreamKind.Server, + stream: res, + switchToReader: switchToReader, + switchToWriter: switchToWriter, + scontext: addr res.scontext + ) + res.reader = reader + res.writer = writer if privateKey.kind == TLSKeyType.EC: let algo = getSignerAlgo(certificate.certs[0]) if algo == -1: raiseTLSStreamProtoError("Could not decode certificate") - sslServerInitFullEc(addr result.scontext, addr certificate.certs[0], + sslServerInitFullEc(addr res.scontext, addr certificate.certs[0], len(certificate.certs), cuint(algo), addr privateKey.eckey) elif privateKey.kind == TLSKeyType.RSA: - sslServerInitFullRsa(addr result.scontext, addr certificate.certs[0], + sslServerInitFullRsa(addr res.scontext, addr certificate.certs[0], len(certificate.certs), addr privateKey.rsakey) let size = max(SSL_BUFSIZE_BIDI, bufferSize) - result.sbuffer = newSeq[byte](size) - sslEngineSetBuffer(addr result.scontext.eng, addr result.sbuffer[0], - uint(len(result.sbuffer)), 1) - sslEngineSetVersions(addr result.scontext.eng, uint16(minVersion), + res.sbuffer = newSeq[byte](size) + sslEngineSetBuffer(addr res.scontext.eng, addr res.sbuffer[0], + uint(len(res.sbuffer)), 1) + sslEngineSetVersions(addr res.scontext.eng, uint16(minVersion), uint16(maxVersion)) if not isNil(cache): - sslServerSetCache(addr result.scontext, addr cache.context.vtable) + sslServerSetCache(addr res.scontext, addr cache.context.vtable) if TLSFlags.EnforceServerPref in flags: - sslEngineAddFlags(addr result.scontext.eng, OPT_ENFORCE_SERVER_PREFERENCES) + sslEngineAddFlags(addr res.scontext.eng, OPT_ENFORCE_SERVER_PREFERENCES) if TLSFlags.NoRenegotiation in flags: - sslEngineAddFlags(addr result.scontext.eng, OPT_NO_RENEGOTIATION) + sslEngineAddFlags(addr res.scontext.eng, OPT_NO_RENEGOTIATION) if TLSFlags.TolerateNoClientAuth in flags: - sslEngineAddFlags(addr result.scontext.eng, OPT_TOLERATE_NO_CLIENT_AUTH) + sslEngineAddFlags(addr res.scontext.eng, OPT_TOLERATE_NO_CLIENT_AUTH) if TLSFlags.FailOnAlpnMismatch in flags: - sslEngineAddFlags(addr result.scontext.eng, OPT_FAIL_ON_ALPN_MISMATCH) + sslEngineAddFlags(addr res.scontext.eng, OPT_FAIL_ON_ALPN_MISMATCH) - let err = sslServerReset(addr result.scontext) + let err = sslServerReset(addr res.scontext) if err == 0: raise newException(TLSStreamError, "Could not initialize TLS layer") - init(cast[AsyncStreamWriter](result.writer), wsource, tlsWriteLoop, + init(cast[AsyncStreamWriter](res.writer), wsource, tlsWriteLoop, bufferSize) - init(cast[AsyncStreamReader](result.reader), rsource, tlsReadLoop, + init(cast[AsyncStreamReader](res.reader), rsource, tlsReadLoop, bufferSize) + res proc copyKey(src: RsaPrivateKey): TLSPrivateKey = ## Creates copy of RsaPrivateKey ``src``. var offset = 0 let keySize = src.plen + src.qlen + src.dplen + src.dqlen + src.iqlen - result = TLSPrivateKey(kind: TLSKeyType.RSA) - result.storage = newSeq[byte](keySize) - copyMem(addr result.storage[offset], src.p, src.plen) - result.rsakey.p = cast[ptr cuchar](addr result.storage[offset]) - result.rsakey.plen = src.plen + var res = TLSPrivateKey(kind: TLSKeyType.RSA, storage: newSeq[byte](keySize)) + copyMem(addr res.storage[offset], src.p, src.plen) + res.rsakey.p = cast[ptr cuchar](addr res.storage[offset]) + res.rsakey.plen = src.plen offset = offset + src.plen - copyMem(addr result.storage[offset], src.q, src.qlen) - result.rsakey.q = cast[ptr cuchar](addr result.storage[offset]) - result.rsakey.qlen = src.qlen + copyMem(addr res.storage[offset], src.q, src.qlen) + res.rsakey.q = cast[ptr cuchar](addr res.storage[offset]) + res.rsakey.qlen = src.qlen offset = offset + src.qlen - copyMem(addr result.storage[offset], src.dp, src.dplen) - result.rsakey.dp = cast[ptr cuchar](addr result.storage[offset]) - result.rsakey.dplen = src.dplen + copyMem(addr res.storage[offset], src.dp, src.dplen) + res.rsakey.dp = cast[ptr cuchar](addr res.storage[offset]) + res.rsakey.dplen = src.dplen offset = offset + src.dplen - copyMem(addr result.storage[offset], src.dq, src.dqlen) - result.rsakey.dq = cast[ptr cuchar](addr result.storage[offset]) - result.rsakey.dqlen = src.dqlen + copyMem(addr res.storage[offset], src.dq, src.dqlen) + res.rsakey.dq = cast[ptr cuchar](addr res.storage[offset]) + res.rsakey.dqlen = src.dqlen offset = offset + src.dqlen - copyMem(addr result.storage[offset], src.iq, src.iqlen) - result.rsakey.iq = cast[ptr cuchar](addr result.storage[offset]) - result.rsakey.iqlen = src.iqlen - result.rsakey.nBitlen = src.nBitlen + copyMem(addr res.storage[offset], src.iq, src.iqlen) + res.rsakey.iq = cast[ptr cuchar](addr res.storage[offset]) + res.rsakey.iqlen = src.iqlen + res.rsakey.nBitlen = src.nBitlen + res proc copyKey(src: EcPrivateKey): TLSPrivateKey = ## Creates copy of EcPrivateKey ``src``. var offset = 0 let keySize = src.xlen - result = TLSPrivateKey(kind: TLSKeyType.EC) - result.storage = newSeq[byte](keySize) - copyMem(addr result.storage[offset], src.x, src.xlen) - result.eckey.x = cast[ptr cuchar](addr result.storage[offset]) - result.eckey.xlen = src.xlen - result.eckey.curve = src.curve + var res = TLSPrivateKey(kind: TLSKeyType.EC, storage: newSeq[byte](keySize)) + copyMem(addr res.storage[offset], src.x, src.xlen) + res.eckey.x = cast[ptr cuchar](addr res.storage[offset]) + res.eckey.xlen = src.xlen + res.eckey.curve = src.curve + res proc init*(tt: typedesc[TLSPrivateKey], data: openarray[byte]): TLSPrivateKey = ## Initialize TLS private key from array of bytes ``data``. @@ -502,12 +493,14 @@ proc init*(tt: typedesc[TLSPrivateKey], data: openarray[byte]): TLSPrivateKey = if err != 0: raiseTLSStreamProtoError(err) let keyType = skeyDecoderKeyType(addr ctx) - if keyType == KEYTYPE_RSA: - result = copyKey(ctx.key.rsa) - elif keyType == KEYTYPE_EC: - result = copyKey(ctx.key.ec) - else: - raiseTLSStreamProtoError("Unknown key type (" & $keyType & ")") + let res = + if keyType == KEYTYPE_RSA: + copyKey(ctx.key.rsa) + elif keyType == KEYTYPE_EC: + copyKey(ctx.key.ec) + else: + raiseTLSStreamProtoError("Unknown key type (" & $keyType & ")") + res proc pemDecode*(data: openarray[char]): seq[PEMElement] = ## Decode PEM encoded string and get array of binary blobs. @@ -515,7 +508,7 @@ proc pemDecode*(data: openarray[char]): seq[PEMElement] = raiseTLSStreamProtoError("Empty PEM message") var ctx: PemDecoderContext var pctx = new PEMContext - result = newSeq[PEMElement]() + var res = newSeq[PEMElement]() pemDecoderInit(addr ctx) proc itemAppend(ctx: pointer, pbytes: pointer, nbytes: int) {.cdecl.} = @@ -544,12 +537,13 @@ proc pemDecode*(data: openarray[char]): seq[PEMElement] = elif event == PEM_END_OBJ: if inobj: elem.data = pctx.data - result.add(elem) + res.add(elem) inobj = false else: break else: raiseTLSStreamProtoError("Invalid PEM encoding") + res proc init*(tt: typedesc[TLSPrivateKey], data: openarray[char]): TLSPrivateKey = ## Initialize TLS private key from string ``data``. @@ -558,13 +552,15 @@ proc init*(tt: typedesc[TLSPrivateKey], data: openarray[char]): TLSPrivateKey = ## encoded string. ## ## Note that PKCS#1 PEM encoded objects are not supported. + var res: TLSPrivateKey var items = pemDecode(data) for item in items: if item.name == "PRIVATE KEY": - result = TLSPrivateKey.init(item.data) + res = TLSPrivateKey.init(item.data) break - if isNil(result): + if isNil(res): raiseTLSStreamProtoError("Could not find private key") + res proc init*(tt: typedesc[TLSCertificate], data: openarray[char]): TLSCertificate = @@ -572,32 +568,33 @@ proc init*(tt: typedesc[TLSCertificate], ## ## This procedure initializes array of certificates from PEM encoded string. var items = pemDecode(data) - result = new TLSCertificate + var res = TLSCertificate() for item in items: if item.name == "CERTIFICATE" and len(item.data) > 0: - let offset = len(result.storage) - result.storage.add(item.data) + let offset = len(res.storage) + res.storage.add(item.data) let cert = X509Certificate( - data: cast[ptr cuchar](addr result.storage[offset]), + data: cast[ptr cuchar](addr res.storage[offset]), dataLen: len(item.data) ) - let res = getSignerAlgo(cert) - if res == -1: + let ares = getSignerAlgo(cert) + if ares == -1: raiseTLSStreamProtoError("Could not decode certificate") - elif res != KEYTYPE_RSA and res != KEYTYPE_EC: + elif ares != KEYTYPE_RSA and ares != KEYTYPE_EC: raiseTLSStreamProtoError("Unsupported signing key type in certificate") - result.certs.add(cert) - if len(result.storage) == 0: + res.certs.add(cert) + if len(res.storage) == 0: raiseTLSStreamProtoError("Could not find any certificates") + res proc init*(tt: typedesc[TLSSessionCache], size: int = 4096): TLSSessionCache = ## Create new TLS session cache with size ``size``. ## ## One cached item is near 100 bytes size. - result = new TLSSessionCache var rsize = min(size, 4096) - result.storage = newSeq[byte](rsize) - sslSessionCacheLruInit(addr result.context, addr result.storage[0], rsize) + var res = TLSSessionCache(storage: newSeq[byte](rsize)) + sslSessionCacheLruInit(addr res.context, addr res.storage[0], rsize) + res proc handshake*(rws: SomeTLSStreamType): Future[void] = ## Wait until initial TLS handshake will be successfully performed. @@ -620,4 +617,4 @@ proc handshake*(rws: SomeTLSStreamType): Future[void] = else: rws.reader.handshakeFut = retFuture rws.writer.handshakeFut = retFuture - return retFuture + retFuture diff --git a/tests/testasyncstream.nim b/tests/testasyncstream.nim index d2720f5..13b426c 100644 --- a/tests/testasyncstream.nim +++ b/tests/testasyncstream.nim @@ -6,7 +6,8 @@ # Apache License, version 2.0, (LICENSE-APACHEv2) # MIT license (LICENSE-MIT) import unittest -import ../chronos, ../chronos/streams/tlsstream +import ../chronos +import ../chronos/streams/[tlsstream, chunkstream, boundstream] when defined(nimHasUsed): {.used.} @@ -553,8 +554,12 @@ suite "ChunkedStream test suite": try: var r = await rstream2.read() doAssert(len(r) > 0) - except AsyncStreamReadError: - res = true + except ChunkedStreamIncompleteError: + if inputstr == "100000000 \r\n1": + res = true + except ChunkedStreamProtocolError: + if inputstr == "z\r\n1": + res = true await rstream2.closeWait() await rstream.closeWait() await transp.closeWait() @@ -663,3 +668,119 @@ suite "TLSStream test suite": getTracker("async.stream.writer").isLeaked() == false getTracker("stream.server").isLeaked() == false getTracker("stream.transport").isLeaked() == false + +suite "BoundedStream test suite": + + proc createBigMessage(size: int): seq[byte] = + var message = "MESSAGE" + result = newSeq[byte](size) + for i in 0 ..< len(result): + result[i] = byte(message[i mod len(message)]) + + for item in [100'u64, 60000'u64]: + + proc boundedTest(address: TransportAddress, test: int, + size: uint64): Future[bool] {.async.} = + var clientRes = false + var res = false + + let messagePart = createBigMessage(int(item) div 10) + var message: seq[byte] + for i in 0 ..< 10: + message.add(messagePart) + + proc processClient(server: StreamServer, + transp: StreamTransport) {.async.} = + var wstream = newAsyncStreamWriter(transp) + var wbstream = newBoundedStreamWriter(wstream, size) + if test == 0: + for i in 0 ..< 10: + await wbstream.write(messagePart) + await wbstream.finish() + await wbstream.closeWait() + clientRes = true + elif test == 1: + for i in 0 ..< 10: + await wbstream.write(messagePart) + try: + await wbstream.write(messagePart) + except BoundedStreamOverflowError: + clientRes = true + await wbstream.closeWait() + elif test == 2: + for i in 0 ..< 9: + await wbstream.write(messagePart) + try: + await wbstream.finish() + except BoundedStreamIncompleteError: + clientRes = true + await wbstream.closeWait() + elif test == 3: + for i in 0 ..< 10: + await wbstream.write(messagePart) + await wbstream.finish() + await wbstream.closeWait() + clientRes = true + elif test == 4: + for i in 0 ..< 9: + await wbstream.write(messagePart) + try: + await wbstream.closeWait() + except BoundedStreamIncompleteError: + clientRes = true + + await wstream.closeWait() + await transp.closeWait() + server.stop() + server.close() + + var server = createStreamServer(address, processClient, flags = {ReuseAddr}) + server.start() + var conn = await connect(address) + var rstream = newAsyncStreamReader(conn) + var rbstream = newBoundedStreamReader(rstream, size) + if test == 0: + let response = await rbstream.read() + await rbstream.closeWait() + if response == message: + res = true + elif test == 1: + let response = await rbstream.read() + await rbstream.closeWait() + if response == message: + res = true + elif test == 2: + try: + let response {.used.} = await rbstream.read() + except BoundedStreamIncompleteError: + res = true + await rbstream.closeWait() + elif test == 3: + let response {.used.} = await rbstream.read(int(size) - 1) + try: + await rbstream.closeWait() + except BoundedStreamIncompleteError: + res = true + elif test == 4: + try: + let response {.used.} = await rbstream.read() + except BoundedStreamIncompleteError: + res = true + await rbstream.closeWait() + + await rstream.closeWait() + await conn.closeWait() + await server.join() + return (res and clientRes) + + let address = initTAddress("127.0.0.1:48030") + test "BoundedStream reading/writing test [" & $item & "]": + check waitFor(boundedTest(address, 0, item)) == true + test "BoundedStream overflow test [" & $item & "]": + check waitFor(boundedTest(address, 1, item)) == true + test "BoundedStream incomplete test [" & $item & "]": + check waitFor(boundedTest(address, 2, item)) == true + test "BoundedStream read() close test [" & $item & "]": + check waitFor(boundedTest(address, 3, item)) == true + test "BoundedStream write() close test [" & $item & "]": + check waitFor(boundedTest(address, 4, item)) == true