diff --git a/chronos/streams/asyncstream.nim b/chronos/streams/asyncstream.nim index bc1f6f9..905c57d 100644 --- a/chronos/streams/asyncstream.nim +++ b/chronos/streams/asyncstream.nim @@ -21,6 +21,16 @@ const ## AsyncStreamWriter leaks tracker name type + AsyncStreamError* = object of CatchableError + AsyncStreamIncorrectError* = object of Defect + AsyncStreamIncompleteError* = object of AsyncStreamError + AsyncStreamLimitError* = object of AsyncStreamError + AsyncStreamUseClosedError* = object of AsyncStreamError + AsyncStreamReadError* = object of AsyncStreamError + par*: ref CatchableError + AsyncStreamWriteError* = object of AsyncStreamError + par*: ref CatchableError + AsyncBuffer* = object offset*: int buffer*: seq[byte] @@ -60,7 +70,8 @@ type state*: AsyncStreamState buffer*: AsyncBuffer udata: pointer - error*: ref Exception + error*: ref AsyncStreamError + bytesCount*: uint64 future: Future[void] AsyncStreamWriter* = ref object of RootRef @@ -70,6 +81,7 @@ type state*: AsyncStreamState queue*: AsyncQueue[WriteItem] udata: pointer + bytesCount*: uint64 future: Future[void] AsyncStream* = object of RootObj @@ -82,36 +94,29 @@ type AsyncStreamRW* = AsyncStreamReader | AsyncStreamWriter - AsyncStreamError* = object of CatchableError - AsyncStreamIncompleteError* = object of AsyncStreamError - AsyncStreamIncorrectError* = object of Defect - AsyncStreamLimitError* = object of AsyncStreamError - AsyncStreamReadError* = object of AsyncStreamError - par*: ref Exception - AsyncStreamWriteError* = object of AsyncStreamError - par*: ref Exception - proc init*(t: typedesc[AsyncBuffer], size: int): AsyncBuffer = - result.buffer = newSeq[byte](size) - result.events[0] = newAsyncEvent() - result.events[1] = newAsyncEvent() - result.offset = 0 + var res = AsyncBuffer( + buffer: newSeq[byte](size), + events: [newAsyncEvent(), newAsyncEvent()], + offset: 0 + ) + res proc getBuffer*(sb: AsyncBuffer): pointer {.inline.} = - result = unsafeAddr sb.buffer[sb.offset] + unsafeAddr sb.buffer[sb.offset] proc bufferLen*(sb: AsyncBuffer): int {.inline.} = - result = len(sb.buffer) - sb.offset + len(sb.buffer) - sb.offset proc getData*(sb: AsyncBuffer): pointer {.inline.} = - result = unsafeAddr sb.buffer[0] + unsafeAddr sb.buffer[0] -proc dataLen*(sb: AsyncBuffer): int {.inline.} = - result = sb.offset +template dataLen*(sb: AsyncBuffer): int = + sb.offset proc `[]`*(sb: AsyncBuffer, index: int): byte {.inline.} = doAssert(index < sb.offset) - result = sb.buffer[index] + sb.buffer[index] proc update*(sb: var AsyncBuffer, size: int) {.inline.} = sb.offset += size @@ -119,12 +124,12 @@ proc update*(sb: var AsyncBuffer, size: int) {.inline.} = proc wait*(sb: var AsyncBuffer): Future[void] = sb.events[0].clear() sb.events[1].fire() - result = sb.events[0].wait() + sb.events[0].wait() proc transfer*(sb: var AsyncBuffer): Future[void] = sb.events[1].clear() sb.events[0].fire() - result = sb.events[1].wait() + sb.events[1].wait() proc forget*(sb: var AsyncBuffer) {.inline.} = sb.events[1].clear() @@ -172,36 +177,47 @@ template copyOut*(dest: pointer, item: WriteItem, length: int) = elif item.kind == String: copyMem(dest, unsafeAddr item.data3[item.offset], length) -proc newAsyncStreamReadError(p: ref Exception): ref Exception {.inline.} = +proc newAsyncStreamReadError(p: ref CatchableError): ref AsyncStreamReadError {. + inline.} = var w = newException(AsyncStreamReadError, "Read stream failed") w.msg = w.msg & ", originated from [" & $p.name & "] " & p.msg w.par = p - result = w + w -proc newAsyncStreamWriteError(p: ref Exception): ref Exception {.inline.} = +proc newAsyncStreamWriteError(p: ref CatchableError): ref AsyncStreamWriteError {. + inline.} = var w = newException(AsyncStreamWriteError, "Write stream failed") w.msg = w.msg & ", originated from [" & $p.name & "] " & p.msg w.par = p - result = w + w -proc newAsyncStreamIncompleteError(): ref Exception {.inline.} = - result = newException(AsyncStreamIncompleteError, "Incomplete data received") +proc newAsyncStreamIncompleteError*(): ref AsyncStreamIncompleteError {. + inline.} = + newException(AsyncStreamIncompleteError, "Incomplete data sent or received") -proc newAsyncStreamLimitError(): ref Exception {.inline.} = - result = newException(AsyncStreamLimitError, "Buffer limit reached") +proc newAsyncStreamLimitError*(): ref AsyncStreamLimitError {.inline.} = + newException(AsyncStreamLimitError, "Buffer limit reached") -proc newAsyncStreamIncorrectError(m: string): ref Exception {.inline.} = - result = newException(AsyncStreamIncorrectError, m) +proc newAsyncStreamUseClosedError*(): ref AsyncStreamUseClosedError {.inline.} = + newException(AsyncStreamUseClosedError, "Stream is already closed") + +proc newAsyncStreamIncorrectError*(m: string): ref AsyncStreamIncorrectError {. + inline.} = + newException(AsyncStreamIncorrectError, m) + +template checkRunning(t: untyped) = + if not(t.running()): + raise newAsyncStreamIncorrectError("Incorrect stream state") proc atEof*(rstream: AsyncStreamReader): bool = ## Returns ``true`` is reading stream is closed or finished and internal ## buffer do not have any bytes left. - result = rstream.state in {AsyncStreamState.Stopped, Finished, Closed} and - (rstream.buffer.dataLen() == 0) + rstream.state in {AsyncStreamState.Stopped, Finished, Closed} and + (rstream.buffer.dataLen() == 0) proc atEof*(wstream: AsyncStreamWriter): bool = ## Returns ``true`` is writing stream ``wstream`` closed or finished. - result = wstream.state in {AsyncStreamState.Stopped, Finished, Closed} + wstream.state in {AsyncStreamState.Stopped, Finished, Closed} proc closed*(rw: AsyncStreamRW): bool {.inline.} = ## Returns ``true`` is reading/writing stream is closed. @@ -223,32 +239,36 @@ proc setupAsyncStreamReaderTracker(): AsyncStreamTracker {.gcsafe.} proc setupAsyncStreamWriterTracker(): AsyncStreamTracker {.gcsafe.} proc getAsyncStreamReaderTracker(): AsyncStreamTracker {.inline.} = - result = cast[AsyncStreamTracker](getTracker(AsyncStreamReaderTrackerName)) - if isNil(result): - result = setupAsyncStreamReaderTracker() + var res = cast[AsyncStreamTracker](getTracker(AsyncStreamReaderTrackerName)) + if isNil(res): + res = setupAsyncStreamReaderTracker() + res proc getAsyncStreamWriterTracker(): AsyncStreamTracker {.inline.} = - result = cast[AsyncStreamTracker](getTracker(AsyncStreamWriterTrackerName)) - if isNil(result): - result = setupAsyncStreamWriterTracker() + var res = cast[AsyncStreamTracker](getTracker(AsyncStreamWriterTrackerName)) + if isNil(res): + res = setupAsyncStreamWriterTracker() + res proc dumpAsyncStreamReaderTracking(): string {.gcsafe.} = var tracker = getAsyncStreamReaderTracker() - result = "Opened async stream readers: " & $tracker.opened & "\n" & - "Closed async stream readers: " & $tracker.closed + let res = "Opened async stream readers: " & $tracker.opened & "\n" & + "Closed async stream readers: " & $tracker.closed + res proc dumpAsyncStreamWriterTracking(): string {.gcsafe.} = var tracker = getAsyncStreamWriterTracker() - result = "Opened async stream writers: " & $tracker.opened & "\n" & - "Closed async stream writers: " & $tracker.closed + let res = "Opened async stream writers: " & $tracker.opened & "\n" & + "Closed async stream writers: " & $tracker.closed + res proc leakAsyncStreamReader(): bool {.gcsafe.} = var tracker = getAsyncStreamReaderTracker() - result = tracker.opened != tracker.closed + tracker.opened != tracker.closed proc leakAsyncStreamWriter(): bool {.gcsafe.} = var tracker = getAsyncStreamWriterTracker() - result = tracker.opened != tracker.closed + tracker.opened != tracker.closed proc trackAsyncStreamReader(t: AsyncStreamReader) {.inline.} = var tracker = getAsyncStreamReaderTracker() @@ -267,20 +287,39 @@ proc untrackAsyncStreamWriter*(t: AsyncStreamWriter) {.inline.} = inc(tracker.closed) proc setupAsyncStreamReaderTracker(): AsyncStreamTracker {.gcsafe.} = - result = new AsyncStreamTracker - result.opened = 0 - result.closed = 0 - result.dump = dumpAsyncStreamReaderTracking - result.isLeaked = leakAsyncStreamReader - addTracker(AsyncStreamReaderTrackerName, result) + var res = AsyncStreamTracker( + opened: 0, + closed: 0, + dump: dumpAsyncStreamReaderTracking, + isLeaked: leakAsyncStreamReader + ) + addTracker(AsyncStreamReaderTrackerName, res) + res proc setupAsyncStreamWriterTracker(): AsyncStreamTracker {.gcsafe.} = - result = new AsyncStreamTracker - result.opened = 0 - result.closed = 0 - result.dump = dumpAsyncStreamWriterTracking - result.isLeaked = leakAsyncStreamWriter - addTracker(AsyncStreamWriterTrackerName, result) + var res = AsyncStreamTracker( + opened: 0, + closed: 0, + dump: dumpAsyncStreamWriterTracking, + isLeaked: leakAsyncStreamWriter + ) + addTracker(AsyncStreamWriterTrackerName, res) + res + +template readLoop(body: untyped): untyped = + while true: + if rstream.buffer.dataLen() == 0: + if rstream.state == AsyncStreamState.Error: + raise rstream.error + + let (consumed, done) = body + + rstream.buffer.shift(consumed) + rstream.bytesCount = rstream.bytesCount + uint64(consumed) + if done: + break + else: + await rstream.buffer.wait() proc readExactly*(rstream: AsyncStreamReader, pbytes: pointer, nbytes: int) {.async.} = @@ -292,17 +331,16 @@ proc readExactly*(rstream: AsyncStreamReader, pbytes: pointer, doAssert(not(isNil(pbytes)), "pbytes must not be nil") doAssert(nbytes >= 0, "nbytes must be non-negative integer") + checkRunning(rstream) + if nbytes == 0: return - if not rstream.running(): - raise newAsyncStreamIncorrectError("Incorrect stream state") - if isNil(rstream.rsource): try: await readExactly(rstream.tsource, pbytes, nbytes) - except CancelledError: - raise + except CancelledError as exc: + raise exc except TransportIncompleteError: raise newAsyncStreamIncompleteError() except CatchableError as exc: @@ -312,61 +350,47 @@ proc readExactly*(rstream: AsyncStreamReader, pbytes: pointer, await readExactly(rstream.rsource, pbytes, nbytes) else: var index = 0 - while true: - let datalen = rstream.buffer.dataLen() - if rstream.state == Error: - raise newAsyncStreamReadError(rstream.error) - if datalen == 0 and rstream.atEof(): - raise newAsyncStreamIncompleteError() - - if datalen >= (nbytes - index): - rstream.buffer.copyData(pbytes, index, nbytes - index) - rstream.buffer.shift(nbytes - index) - break - else: - rstream.buffer.copyData(pbytes, index, datalen) - index += datalen - rstream.buffer.shift(datalen) - await rstream.buffer.wait() + var pbuffer = cast[ptr UncheckedArray[byte]](pbytes) + readLoop(): + if rstream.buffer.dataLen() == 0: + if rstream.atEof(): + raise newAsyncStreamIncompleteError() + let count = min(nbytes - index, rstream.buffer.dataLen()) + if count > 0: + rstream.buffer.copyData(addr pbuffer[index], 0, count) + index += count + (consumed: count, done: index == nbytes) proc readOnce*(rstream: AsyncStreamReader, pbytes: pointer, nbytes: int): Future[int] {.async.} = ## Perform one read operation on read-only stream ``rstream``. ## ## If internal buffer is not empty, ``nbytes`` bytes will be transferred from - ## internal buffer, otherwise it will wait until some bytes will be received. + ## internal buffer, otherwise it will wait until some bytes will be available. doAssert(not(isNil(pbytes)), "pbytes must not be nil") doAssert(nbytes > 0, "nbytes must be positive value") - - if not rstream.running(): - raise newAsyncStreamIncorrectError("Incorrect stream state") + checkRunning(rstream) if isNil(rstream.rsource): try: - result = await readOnce(rstream.tsource, pbytes, nbytes) - except CancelledError: - raise + return await readOnce(rstream.tsource, pbytes, nbytes) + except CancelledError as exc: + raise exc except CatchableError as exc: raise newAsyncStreamReadError(exc) else: if isNil(rstream.readerLoop): - result = await readOnce(rstream.rsource, pbytes, nbytes) + return await readOnce(rstream.rsource, pbytes, nbytes) else: - while true: - let datalen = rstream.buffer.dataLen() - if rstream.state == Error: - raise newAsyncStreamReadError(rstream.error) - if datalen == 0: - if rstream.atEof(): - result = 0 - break - await rstream.buffer.wait() + var count = 0 + readLoop(): + if rstream.buffer.dataLen() == 0: + (0, rstream.atEof()) else: - let size = min(datalen, nbytes) - rstream.buffer.copyData(pbytes, 0, size) - rstream.buffer.shift(size) - result = size - break + count = min(rstream.buffer.dataLen(), nbytes) + rstream.buffer.copyData(pbytes, 0, count) + (count, true) + return count proc readUntil*(rstream: AsyncStreamReader, pbytes: pointer, nbytes: int, sep: seq[byte]): Future[int] {.async.} = @@ -386,18 +410,16 @@ proc readUntil*(rstream: AsyncStreamReader, pbytes: pointer, nbytes: int, doAssert(not(isNil(pbytes)), "pbytes must not be nil") doAssert(len(sep) > 0, "separator must not be empty") doAssert(nbytes >= 0, "nbytes must be non-negative value") + checkRunning(rstream) if nbytes == 0: raise newAsyncStreamLimitError() - if not rstream.running(): - raise newAsyncStreamIncorrectError("Incorrect stream state") - if isNil(rstream.rsource): try: - result = await readUntil(rstream.tsource, pbytes, nbytes, sep) - except CancelledError: - raise + return await readUntil(rstream.tsource, pbytes, nbytes, sep) + except CancelledError as exc: + raise exc except TransportIncompleteError: raise newAsyncStreamIncompleteError() except TransportLimitError: @@ -406,43 +428,30 @@ proc readUntil*(rstream: AsyncStreamReader, pbytes: pointer, nbytes: int, raise newAsyncStreamReadError(exc) else: if isNil(rstream.readerLoop): - result = await readUntil(rstream.rsource, pbytes, nbytes, sep) + return await readUntil(rstream.rsource, pbytes, nbytes, sep) else: - var - dest = cast[ptr UncheckedArray[byte]](pbytes) - state = 0 - k = 0 - - while true: - let datalen = rstream.buffer.dataLen() - if rstream.state == Error: - raise newAsyncStreamReadError(rstream.error) - if datalen == 0 and rstream.atEof(): + var pbuffer = cast[ptr UncheckedArray[byte]](pbytes) + var state = 0 + var k = 0 + readLoop(): + if rstream.atEof(): raise newAsyncStreamIncompleteError() - var index = 0 - while index < datalen: + while index < rstream.buffer.dataLen(): + if k >= nbytes: + raise newAsyncStreamLimitError() let ch = rstream.buffer[index] + inc(index) + pbuffer[k] = ch + inc(k) if sep[state] == ch: inc(state) + if state == len(sep): + break else: state = 0 - if k < nbytes: - dest[k] = ch - inc(k) - else: - raise newAsyncStreamLimitError() - if state == len(sep): - break - inc(index) - - if state == len(sep): - rstream.buffer.shift(index + 1) - result = k - break - else: - rstream.buffer.shift(datalen) - await rstream.buffer.wait() + (index, state == len(sep)) + return k proc readLine*(rstream: AsyncStreamReader, limit = 0, sep = "\r\n"): Future[string] {.async.} = @@ -457,155 +466,207 @@ proc readLine*(rstream: AsyncStreamReader, limit = 0, ## ## If ``limit`` more then 0, then result string will be limited to ``limit`` ## bytes. - if not rstream.running(): - raise newAsyncStreamIncorrectError("Incorrect stream state") + checkRunning(rstream) if isNil(rstream.rsource): try: - result = await readLine(rstream.tsource, limit, sep) - except CancelledError: - raise + return await readLine(rstream.tsource, limit, sep) + except CancelledError as exc: + raise exc except CatchableError as exc: raise newAsyncStreamReadError(exc) else: if isNil(rstream.readerLoop): - result = await readLine(rstream.rsource, limit, sep) + return await readLine(rstream.rsource, limit, sep) else: + let lim = if limit <= 0: -1 else: limit + var state = 0 var res = "" - var - lim = if limit <= 0: -1 else: limit - state = 0 - - while true: - let datalen = rstream.buffer.dataLen() - if rstream.state == Error: - raise newAsyncStreamReadError(rstream.error) - if datalen == 0 and rstream.atEof(): - result = res - break - - var index = 0 - while index < datalen: - let ch = char(rstream.buffer[index]) - if sep[state] == ch: - inc(state) - if state == len(sep) or len(res) == lim: - rstream.buffer.shift(index + 1) - break - else: - state = 0 - res.add(ch) - if len(res) == lim: - rstream.buffer.shift(index + 1) - break - inc(index) - - if state == len(sep) or (lim == len(res)): - result = res - break + readLoop(): + if rstream.atEof(): + (0, true) else: - rstream.buffer.shift(datalen) - await rstream.buffer.wait() + var index = 0 + while index < rstream.buffer.dataLen(): + let ch = char(rstream.buffer[index]) + inc(index) -proc read*(rstream: AsyncStreamReader, n = 0): Future[seq[byte]] {.async.} = - ## Read all bytes (n <= 0) or exactly `n` bytes from read-only stream - ## ``rstream``. + if sep[state] == ch: + inc(state) + if state == len(sep): + break + else: + if state != 0: + if limit > 0: + let missing = min(state, lim - len(res) - 1) + res.add(sep[0 ..< missing]) + else: + res.add(sep[0 ..< state]) + res.add(ch) + if len(res) == lim: + break + (index, (state == len(sep)) or (lim == len(res))) + return res + +proc read*(rstream: AsyncStreamReader): Future[seq[byte]] {.async.} = + ## Read all bytes from read-only stream ``rstream``. ## ## This procedure allocates buffer seq[byte] and return it as result. - if not rstream.running(): - raise newAsyncStreamIncorrectError("Incorrect stream state") + checkRunning(rstream) if isNil(rstream.rsource): try: - result = await read(rstream.tsource, n) - except CancelledError: - raise - except CatchableError as exc: - raise newAsyncStreamReadError(exc) - else: - if isNil(rstream.readerLoop): - result = await read(rstream.rsource, n) - else: - var res = newSeq[byte]() - while true: - let datalen = rstream.buffer.dataLen() - if rstream.state == Error: - raise newAsyncStreamReadError(rstream.error) - if datalen == 0 and rstream.atEof(): - result = res - break - - if datalen > 0: - let s = len(res) - let o = s + datalen - if n <= 0: - res.setLen(o) - rstream.buffer.copyData(addr res[s], 0, datalen) - rstream.buffer.shift(datalen) - else: - let left = n - s - if datalen >= left: - res.setLen(n) - rstream.buffer.copyData(addr res[s], 0, left) - rstream.buffer.shift(left) - result = res - break - else: - res.setLen(o) - rstream.buffer.copyData(addr res[s], 0, datalen) - rstream.buffer.shift(datalen) - - await rstream.buffer.wait() - -proc consume*(rstream: AsyncStreamReader, n = -1): Future[int] {.async.} = - ## Consume (discard) all bytes (n <= 0) or ``n`` bytes from read-only stream - ## ``rstream``. - ## - ## Return number of bytes actually consumed (discarded). - if not rstream.running(): - raise newAsyncStreamIncorrectError("Incorrect stream state") - - if isNil(rstream.rsource): - try: - result = await consume(rstream.tsource, n) - except CancelledError: - raise + return await read(rstream.tsource) + except CancelledError as exc: + raise exc except TransportLimitError: raise newAsyncStreamLimitError() except CatchableError as exc: raise newAsyncStreamReadError(exc) else: if isNil(rstream.readerLoop): - result = await consume(rstream.rsource, n) + return await read(rstream.rsource) + else: + var res = newSeq[byte]() + readLoop(): + if rstream.atEof(): + (0, true) + else: + let count = rstream.buffer.dataLen() + res.add(rstream.buffer.buffer.toOpenArray(0, count - 1)) + (count, false) + return res + +proc read*(rstream: AsyncStreamReader, n: int): Future[seq[byte]] {.async.} = + ## Read all bytes (n <= 0) or exactly `n` bytes from read-only stream + ## ``rstream``. + ## + ## This procedure allocates buffer seq[byte] and return it as result. + checkRunning(rstream) + + if isNil(rstream.rsource): + try: + return await read(rstream.tsource, n) + except CancelledError as exc: + raise exc + except CatchableError as exc: + raise newAsyncStreamReadError(exc) + else: + if isNil(rstream.readerLoop): + return await read(rstream.rsource, n) + else: + if n <= 0: + return await read(rstream.rsource) + else: + var res = newSeq[byte]() + readLoop(): + if rstream.atEof(): + (0, true) + else: + let count = min(rstream.buffer.dataLen(), n - len(res)) + res.add(rstream.buffer.buffer.toOpenArray(0, count - 1)) + (count, len(res) == n) + return res + +proc consume*(rstream: AsyncStreamReader): Future[int] {.async.} = + ## Consume (discard) all bytes from read-only stream ``rstream``. + ## + ## Return number of bytes actually consumed (discarded). + checkRunning(rstream) + + if isNil(rstream.rsource): + try: + return await consume(rstream.tsource) + except CancelledError as exc: + raise exc + except TransportLimitError: + raise newAsyncStreamLimitError() + except CatchableError as exc: + raise newAsyncStreamReadError(exc) + else: + if isNil(rstream.readerLoop): + return await consume(rstream.rsource) else: var res = 0 - while true: - let datalen = rstream.buffer.dataLen() - if rstream.state == Error: - raise newAsyncStreamReadError(rstream.error) - if datalen == 0: - if rstream.atEof(): - if n <= 0: - result = res - break - else: - raise newAsyncStreamLimitError() + readLoop(): + if rstream.atEof(): + (0, true) else: - if n <= 0: - res += datalen - rstream.buffer.shift(datalen) - else: - let left = n - res - if datalen >= left: - res += left - rstream.buffer.shift(left) - result = res - break - else: - res += datalen - rstream.buffer.shift(datalen) + res += rstream.buffer.dataLen() + (rstream.buffer.dataLen(), false) + return res - await rstream.buffer.wait() +proc consume*(rstream: AsyncStreamReader, n: int): Future[int] {.async.} = + ## Consume (discard) all bytes (n <= 0) or ``n`` bytes from read-only stream + ## ``rstream``. + ## + ## Return number of bytes actually consumed (discarded). + checkRunning(rstream) + + if isNil(rstream.rsource): + try: + return await consume(rstream.tsource, n) + except CancelledError as exc: + raise exc + except TransportLimitError: + raise newAsyncStreamLimitError() + except CatchableError as exc: + raise newAsyncStreamReadError(exc) + else: + if isNil(rstream.readerLoop): + return await consume(rstream.rsource, n) + else: + if n <= 0: + return await rstream.consume() + else: + var res = 0 + readLoop(): + if rstream.atEof(): + (0, true) + else: + let count = min(rstream.buffer.dataLen(), n - res) + res += count + (count, res == n) + return res + +proc readMessage*(rstream: AsyncStreamReader, pred: ReadMessagePredicate) {. + async.} = + ## Read all bytes from stream ``rstream`` until ``predicate`` callback + ## will not be satisfied. + ## + ## ``predicate`` callback should return tuple ``(consumed, result)``, where + ## ``consumed`` is the number of bytes processed and ``result`` is a + ## completion flag (``true`` if readMessage() should stop reading data, + ## or ``false`` if readMessage() should continue to read data from stream). + ## + ## ``predicate`` callback must copy all the data from ``data`` array and + ## return number of bytes it is going to consume. + ## ``predicate`` callback will receive (zero-length) openarray, if stream + ## is at EOF. + doAssert(not(isNil(pred)), "`predicate` callback should not be `nil`") + checkRunning(rstream) + + if isNil(rstream.rsource): + try: + await readMessage(rstream.tsource, pred) + except CancelledError as exc: + raise exc + except CatchableError as exc: + raise newAsyncStreamReadError(exc) + else: + if isNil(rstream.readerLoop): + await readMessage(rstream.rsource, pred) + else: + readLoop(): + let count = rstream.buffer.dataLen() + if count == 0: + if rstream.atEof(): + pred([]) + else: + # Case, when transport's buffer is not yet filled with data. + (0, false) + else: + pred(rstream.buffer.buffer.toOpenArray(0, count - 1)) proc write*(wstream: AsyncStreamWriter, pbytes: pointer, nbytes: int) {.async.} = @@ -613,8 +674,7 @@ proc write*(wstream: AsyncStreamWriter, pbytes: pointer, ## writer stream ``wstream``. ## ## ``nbytes` must be more then zero. - if not wstream.running(): - raise newAsyncStreamIncorrectError("Incorrect stream state") + checkRunning(wstream) if nbytes <= 0: raise newAsyncStreamIncorrectError("Zero length message") @@ -622,27 +682,34 @@ proc write*(wstream: AsyncStreamWriter, pbytes: pointer, var res: int try: res = await write(wstream.tsource, pbytes, nbytes) - except CancelledError: - raise + except CancelledError as exc: + raise exc + except AsyncStreamError as exc: + raise exc except CatchableError as exc: raise newAsyncStreamWriteError(exc) if res != nbytes: raise newAsyncStreamIncompleteError() + wstream.bytesCount = wstream.bytesCount + uint64(nbytes) else: if isNil(wstream.writerLoop): await write(wstream.wsource, pbytes, nbytes) + wstream.bytesCount = wstream.bytesCount + uint64(nbytes) else: var item = WriteItem(kind: Pointer) item.data1 = pbytes item.size = nbytes item.future = newFuture[void]("async.stream.write(pointer)") - await wstream.queue.put(item) try: + await wstream.queue.put(item) await item.future - except CancelledError: - raise - except: - raise newAsyncStreamWriteError(item.future.error) + wstream.bytesCount = wstream.bytesCount + uint64(item.size) + except CancelledError as exc: + raise exc + except AsyncStreamError as exc: + raise exc + except CatchableError as exc: + raise newAsyncStreamWriteError(exc) proc write*(wstream: AsyncStreamWriter, sbytes: seq[byte], msglen = -1) {.async.} = @@ -654,10 +721,8 @@ proc write*(wstream: AsyncStreamWriter, sbytes: seq[byte], ## If ``msglen < 0`` whole sequence ``sbytes`` will be writen to stream. ## If ``msglen > len(sbytes)`` only ``len(sbytes)`` bytes will be written to ## stream. + checkRunning(wstream) let length = if msglen <= 0: len(sbytes) else: min(msglen, len(sbytes)) - - if not wstream.running(): - raise newAsyncStreamIncorrectError("Incorrect stream state") if length <= 0: raise newAsyncStreamIncorrectError("Zero length message") @@ -665,15 +730,17 @@ proc write*(wstream: AsyncStreamWriter, sbytes: seq[byte], var res: int try: res = await write(wstream.tsource, sbytes, msglen) - except CancelledError: - raise + except CancelledError as exc: + raise exc except CatchableError as exc: raise newAsyncStreamWriteError(exc) if res != length: raise newAsyncStreamIncompleteError() + wstream.bytesCount = wstream.bytesCount + uint64(msglen) else: if isNil(wstream.writerLoop): await write(wstream.wsource, sbytes, msglen) + wstream.bytesCount = wstream.bytesCount + uint64(msglen) else: var item = WriteItem(kind: Sequence) if not isLiteral(sbytes): @@ -682,13 +749,16 @@ proc write*(wstream: AsyncStreamWriter, sbytes: seq[byte], item.data2 = sbytes item.size = length item.future = newFuture[void]("async.stream.write(seq)") - await wstream.queue.put(item) try: + await wstream.queue.put(item) await item.future - except CancelledError: - raise - except: - raise newAsyncStreamWriteError(item.future.error) + wstream.bytesCount = wstream.bytesCount + uint64(item.size) + except CancelledError as exc: + raise exc + except AsyncStreamError as exc: + raise exc + except CatchableError as exc: + raise newAsyncStreamWriteError(exc) proc write*(wstream: AsyncStreamWriter, sbytes: string, msglen = -1) {.async.} = @@ -699,10 +769,8 @@ proc write*(wstream: AsyncStreamWriter, sbytes: string, ## If ``msglen < 0`` whole string ``sbytes`` will be writen to stream. ## If ``msglen > len(sbytes)`` only ``len(sbytes)`` bytes will be written to ## stream. + checkRunning(wstream) let length = if msglen <= 0: len(sbytes) else: min(msglen, len(sbytes)) - - if not wstream.running(): - raise newAsyncStreamIncorrectError("Incorrect stream state") if length <= 0: raise newAsyncStreamIncorrectError("Zero length message") @@ -710,15 +778,17 @@ proc write*(wstream: AsyncStreamWriter, sbytes: string, var res: int try: res = await write(wstream.tsource, sbytes, msglen) - except CancelledError: - raise + except CancelledError as exc: + raise exc except CatchableError as exc: raise newAsyncStreamWriteError(exc) if res != length: raise newAsyncStreamIncompleteError() + wstream.bytesCount = wstream.bytesCount + uint64(msglen) else: if isNil(wstream.writerLoop): await write(wstream.wsource, sbytes, msglen) + wstream.bytesCount = wstream.bytesCount + uint64(msglen) else: var item = WriteItem(kind: String) if not isLiteral(sbytes): @@ -727,18 +797,20 @@ proc write*(wstream: AsyncStreamWriter, sbytes: string, item.data3 = sbytes item.size = length item.future = newFuture[void]("async.stream.write(string)") - await wstream.queue.put(item) try: + await wstream.queue.put(item) await item.future - except CancelledError: - raise - except: - raise newAsyncStreamWriteError(item.future.error) + wstream.bytesCount = wstream.bytesCount + uint64(item.size) + except CancelledError as exc: + raise exc + except AsyncStreamError as exc: + raise exc + except CatchableError as exc: + raise newAsyncStreamWriteError(exc) proc finish*(wstream: AsyncStreamWriter) {.async.} = ## Finish write stream ``wstream``. - if not wstream.running(): - raise newAsyncStreamIncorrectError("Incorrect stream state") + checkRunning(wstream) if not isNil(wstream.wsource): if isNil(wstream.writerLoop): @@ -747,13 +819,15 @@ proc finish*(wstream: AsyncStreamWriter) {.async.} = var item = WriteItem(kind: Pointer) item.size = 0 item.future = newFuture[void]("async.stream.finish") - await wstream.queue.put(item) try: + await wstream.queue.put(item) await item.future - except CancelledError: - raise - except: - raise newAsyncStreamWriteError(item.future.error) + except CancelledError as exc: + raise exc + except AsyncStreamError as exc: + raise exc + except CatchableError as exc: + raise newAsyncStreamWriteError(exc) proc join*(rw: AsyncStreamRW): Future[void] = ## Get Future[void] which will be completed when stream become finished or @@ -766,12 +840,12 @@ proc join*(rw: AsyncStreamRW): Future[void] = proc continuation(udata: pointer) {.gcsafe.} = retFuture.complete() - proc cancel(udata: pointer) {.gcsafe.} = + proc cancellation(udata: pointer) {.gcsafe.} = rw.future.removeCallback(continuation, cast[pointer](retFuture)) if not(rw.future.finished()): rw.future.addCallback(continuation, cast[pointer](retFuture)) - rw.future.cancelCallback = cancel + rw.future.cancelCallback = cancellation else: retFuture.complete() @@ -818,7 +892,7 @@ proc close*(rw: AsyncStreamRW) = proc closeWait*(rw: AsyncStreamRW): Future[void] = ## Close and frees resources of stream ``rw``. rw.close() - result = rw.join() + rw.join() proc startReader(rstream: AsyncStreamReader) = rstream.state = Running @@ -981,8 +1055,9 @@ proc newAsyncStreamReader*[T](rsource: AsyncStreamReader, ## ## ``udata`` - user object which will be associated with new AsyncStreamReader ## object. - result = new AsyncStreamReader - result.init(rsource, loop, bufferSize, udata) + var res = AsyncStreamReader() + res.init(rsource, loop, bufferSize, udata) + res proc newAsyncStreamReader*(rsource: AsyncStreamReader, loop: StreamReaderLoop, @@ -994,8 +1069,9 @@ proc newAsyncStreamReader*(rsource: AsyncStreamReader, ## ``loop`` is main reading loop procedure. ## ## ``bufferSize`` is internal buffer size. - result = new AsyncStreamReader - result.init(rsource, loop, bufferSize) + var res = AsyncStreamReader() + res.init(rsource, loop, bufferSize) + res proc newAsyncStreamReader*[T](tsource: StreamTransport, udata: ref T): AsyncStreamReader = @@ -1004,14 +1080,16 @@ proc newAsyncStreamReader*[T](tsource: StreamTransport, ## ## ``udata`` - user object which will be associated with new AsyncStreamWriter ## object. - result = new AsyncStreamReader - result.init(tsource, udata) + var res = AsyncStreamReader() + res.init(tsource, udata) + res proc newAsyncStreamReader*(tsource: StreamTransport): AsyncStreamReader = ## Create new AsyncStreamReader object, which will use stream transport ## ``tsource`` as source data channel. - result = new AsyncStreamReader - result.init(tsource) + var res = AsyncStreamReader() + res.init(tsource) + res proc newAsyncStreamWriter*[T](wsource: AsyncStreamWriter, loop: StreamWriterLoop, @@ -1026,8 +1104,9 @@ proc newAsyncStreamWriter*[T](wsource: AsyncStreamWriter, ## ## ``udata`` - user object which will be associated with new AsyncStreamWriter ## object. - result = new AsyncStreamWriter - result.init(wsource, loop, queueSize, udata) + var res = AsyncStreamWriter() + res.init(wsource, loop, queueSize, udata) + res proc newAsyncStreamWriter*(wsource: AsyncStreamWriter, loop: StreamWriterLoop, @@ -1039,8 +1118,9 @@ proc newAsyncStreamWriter*(wsource: AsyncStreamWriter, ## ``loop`` is main writing loop procedure. ## ## ``queueSize`` is writing queue size (default size is unlimited). - result = new AsyncStreamWriter - result.init(wsource, loop, queueSize) + var res = AsyncStreamWriter() + res.init(wsource, loop, queueSize) + res proc newAsyncStreamWriter*[T](tsource: StreamTransport, udata: ref T): AsyncStreamWriter = @@ -1049,14 +1129,16 @@ proc newAsyncStreamWriter*[T](tsource: StreamTransport, ## ## ``udata`` - user object which will be associated with new AsyncStreamWriter ## object. - result = new AsyncStreamWriter - result.init(tsource, udata) + var res = AsyncStreamWriter() + res.init(tsource, udata) + res proc newAsyncStreamWriter*(tsource: StreamTransport): AsyncStreamWriter = ## Create new AsyncStreamWriter object which will use stream transport ## ``tsource`` as data channel. - result = new AsyncStreamWriter - result.init(tsource) + var res = AsyncStreamWriter() + res.init(tsource) + res proc newAsyncStreamWriter*[T](wsource: AsyncStreamWriter, udata: ref T): AsyncStreamWriter = @@ -1064,13 +1146,15 @@ proc newAsyncStreamWriter*[T](wsource: AsyncStreamWriter, ## ## ``udata`` - user object which will be associated with new AsyncStreamWriter ## object. - result = new AsyncStreamWriter - result.init(wsource, udata) + var res = AsyncStreamWriter() + res.init(wsource, udata) + res proc newAsyncStreamWriter*(wsource: AsyncStreamWriter): AsyncStreamWriter = ## Create copy of AsyncStreamWriter object ``wsource``. - result = new AsyncStreamWriter - result.init(wsource) + var res = AsyncStreamWriter() + res.init(wsource) + res proc newAsyncStreamReader*[T](rsource: AsyncStreamWriter, udata: ref T): AsyncStreamWriter = @@ -1078,15 +1162,17 @@ proc newAsyncStreamReader*[T](rsource: AsyncStreamWriter, ## ## ``udata`` - user object which will be associated with new AsyncStreamReader ## object. - result = new AsyncStreamReader - result.init(rsource, udata) + var res = AsyncStreamReader() + res.init(rsource, udata) + res proc newAsyncStreamReader*(rsource: AsyncStreamReader): AsyncStreamReader = ## Create copy of AsyncStreamReader object ``rsource``. - result = new AsyncStreamReader - result.init(rsource) + var res = AsyncStreamReader() + res.init(rsource) + res proc getUserData*[T](rw: AsyncStreamRW): T {.inline.} = ## Obtain user data associated with AsyncStreamReader or AsyncStreamWriter ## object ``rw``. - result = cast[T](rw.udata) + cast[T](rw.udata) diff --git a/chronos/streams/boundstream.nim b/chronos/streams/boundstream.nim new file mode 100644 index 0000000..0136a86 --- /dev/null +++ b/chronos/streams/boundstream.nim @@ -0,0 +1,212 @@ +# +# Chronos Asynchronous Bound Stream +# (c) Copyright 2021-Present +# Status Research & Development GmbH +# +# Licensed under either of +# Apache License, version 2.0, (LICENSE-APACHEv2) +# MIT license (LICENSE-MIT) + +## This module implements bounded stream reading and writing. +## +## For stream reading it means that you should read exactly bounded size of +## bytes. +## +## For stream writing it means that you should write exactly bounded size +## of bytes, and if you wrote not enough bytes error will appear on stream +## close. +import ../asyncloop, ../timer +import asyncstream, ../transports/stream, ../transports/common +export asyncstream, stream, timer, common + +type + BoundedStreamReader* = ref object of AsyncStreamReader + boundSize: uint64 + offset: uint64 + + BoundedStreamWriter* = ref object of AsyncStreamWriter + boundSize: uint64 + offset: uint64 + + BoundedStreamError* = object of AsyncStreamError + BoundedStreamIncompleteError* = object of BoundedStreamError + BoundedStreamOverflowError* = object of BoundedStreamError + + BoundedStreamRW* = BoundedStreamReader | BoundedStreamWriter + +const + BoundedBufferSize* = 4096 + +template newBoundedStreamIncompleteError*(): ref BoundedStreamError = + newException(BoundedStreamIncompleteError, + "Stream boundary is not reached yet") +template newBoundedStreamOverflowError*(): ref BoundedStreamError = + newException(BoundedStreamOverflowError, "Stream boundary exceeded") + +proc boundedReadLoop(stream: AsyncStreamReader) {.async.} = + var rstream = cast[BoundedStreamReader](stream) + rstream.state = AsyncStreamState.Running + while true: + if rstream.offset < rstream.boundSize: + let toRead = int(min(rstream.boundSize - rstream.offset, + uint64(rstream.buffer.bufferLen()))) + try: + await rstream.rsource.readExactly(rstream.buffer.getBuffer(), toRead) + rstream.offset = rstream.offset + uint64(toRead) + rstream.buffer.update(toRead) + await rstream.buffer.transfer() + except AsyncStreamIncompleteError: + rstream.state = AsyncStreamState.Error + rstream.error = newBoundedStreamIncompleteError() + except AsyncStreamReadError as exc: + rstream.state = AsyncStreamState.Error + rstream.error = exc + except CancelledError: + rstream.state = AsyncStreamState.Stopped + + if rstream.state != AsyncStreamState.Running: + break + else: + rstream.state = AsyncStreamState.Finished + await rstream.buffer.transfer() + break + + if rstream.state in {AsyncStreamState.Stopped, AsyncStreamState.Error}: + # We need to notify consumer about error/close, but we do not care about + # incoming data anymore. + rstream.buffer.forget() + +proc boundedWriteLoop(stream: AsyncStreamWriter) {.async.} = + var wstream = cast[BoundedStreamWriter](stream) + + wstream.state = AsyncStreamState.Running + while true: + var + item: WriteItem + error: ref AsyncStreamError + + try: + item = await wstream.queue.get() + if item.size > 0: + if uint64(item.size) <= (wstream.boundSize - wstream.offset): + # Writing chunk data. + case item.kind + of WriteType.Pointer: + await wstream.wsource.write(item.data1, item.size) + of WriteType.Sequence: + await wstream.wsource.write(addr item.data2[0], item.size) + of WriteType.String: + await wstream.wsource.write(addr item.data3[0], item.size) + wstream.offset = wstream.offset + uint64(item.size) + item.future.complete() + else: + wstream.state = AsyncStreamState.Error + error = newBoundedStreamOverflowError() + else: + if wstream.offset != wstream.boundSize: + wstream.state = AsyncStreamState.Error + error = newBoundedStreamIncompleteError() + else: + wstream.state = AsyncStreamState.Finished + item.future.complete() + except CancelledError: + wstream.state = AsyncStreamState.Stopped + error = newAsyncStreamUseClosedError() + except AsyncStreamWriteError as exc: + wstream.state = AsyncStreamState.Error + error = exc + except AsyncStreamIncompleteError as exc: + wstream.state = AsyncStreamState.Error + error = exc + + if wstream.state != AsyncStreamState.Running: + if wstream.state == AsyncStreamState.Finished: + error = newAsyncStreamUseClosedError() + else: + if not(isNil(item.future)): + if not(item.future.finished()): + item.future.fail(error) + while not(wstream.queue.empty()): + let pitem = wstream.queue.popFirstNoWait() + if not(pitem.future.finished()): + pitem.future.fail(error) + break + +proc bytesLeft*(stream: BoundedStreamRW): uint64 = + ## Returns number of bytes left in stream. + stream.boundSize - stream.bytesCount + +proc init*[T](child: BoundedStreamReader, rsource: AsyncStreamReader, + bufferSize = BoundedBufferSize, udata: ref T) = + init(cast[AsyncStreamReader](child), rsource, boundedReadLoop, bufferSize, + udata) + +proc init*(child: BoundedStreamReader, rsource: AsyncStreamReader, + bufferSize = BoundedBufferSize) = + init(cast[AsyncStreamReader](child), rsource, boundedReadLoop, bufferSize) + +proc newBoundedStreamReader*[T](rsource: AsyncStreamReader, + boundSize: uint64, + bufferSize = BoundedBufferSize, + udata: ref T): BoundedStreamReader = + var res = BoundedStreamReader(boundSize: boundSize) + res.init(rsource, bufferSize, udata) + res + +proc newBoundedStreamReader*(rsource: AsyncStreamReader, + boundSize: uint64, + bufferSize = BoundedBufferSize, + ): BoundedStreamReader = + doAssert(boundSize >= 0) + var res = BoundedStreamReader(boundSize: boundSize) + res.init(rsource, bufferSize) + res + +proc init*[T](child: BoundedStreamWriter, wsource: AsyncStreamWriter, + queueSize = AsyncStreamDefaultQueueSize, udata: ref T) = + init(cast[AsyncStreamWriter](child), wsource, boundedWriteLoop, queueSize, + udata) + +proc init*(child: BoundedStreamWriter, wsource: AsyncStreamWriter, + queueSize = AsyncStreamDefaultQueueSize) = + init(cast[AsyncStreamWriter](child), wsource, boundedWriteLoop, queueSize) + +proc newBoundedStreamWriter*[T](wsource: AsyncStreamWriter, + boundSize: uint64, + queueSize = AsyncStreamDefaultQueueSize, + udata: ref T): BoundedStreamWriter = + var res = BoundedStreamWriter(boundSize: boundSize) + res.init(wsource, queueSize, udata) + res + +proc newBoundedStreamWriter*(wsource: AsyncStreamWriter, + boundSize: uint64, + queueSize = AsyncStreamDefaultQueueSize, + ): BoundedStreamWriter = + var res = BoundedStreamWriter(boundSize: boundSize) + res.init(wsource, queueSize) + res + +proc close*(rw: BoundedStreamRW) = + ## Close and frees resources of stream ``rw``. + ## + ## Note close() procedure is not completed immediately. + if rw.closed(): + raise newAsyncStreamIncorrectError("Stream is already closed!") + # We do not want to raise one more IncompleteError if it was already raised + # by one of the read()/write() primitives. + if rw.state != AsyncStreamState.Error: + if rw.bytesLeft() != 0'u64: + raise newBoundedStreamIncompleteError() + when rw is BoundedStreamReader: + cast[AsyncStreamReader](rw).close() + elif rw is BoundedStreamWriter: + cast[AsyncStreamWriter](rw).close() + +proc closeWait*(rw: BoundedStreamRW): Future[void] = + ## Close and frees resources of stream ``rw``. + rw.close() + when rw is BoundedStreamReader: + cast[AsyncStreamReader](rw).join() + elif rw is BoundedStreamWriter: + cast[AsyncStreamWriter](rw).join() diff --git a/chronos/streams/chunkstream.nim b/chronos/streams/chunkstream.nim index b847cac..9e65d8f 100644 --- a/chronos/streams/chunkstream.nim +++ b/chronos/streams/chunkstream.nim @@ -21,25 +21,44 @@ type ChunkedStreamReader* = ref object of AsyncStreamReader ChunkedStreamWriter* = ref object of AsyncStreamWriter - ChunkedStreamError* = object of CatchableError + ChunkedStreamError* = object of AsyncStreamError ChunkedStreamProtocolError* = object of ChunkedStreamError + ChunkedStreamIncompleteError* = object of ChunkedStreamError -proc newProtocolError(): ref Exception {.inline.} = +proc newChunkedProtocolError(): ref ChunkedStreamProtocolError {.inline.} = newException(ChunkedStreamProtocolError, "Protocol error!") +proc newChunkedIncompleteError(): ref ChunkedStreamIncompleteError {.inline.} = + newException(ChunkedStreamIncompleteError, "Incomplete data received!") + +proc `-`(x: uint32): uint32 {.inline.} = + result = (0xFFFF_FFFF'u32 - x) + 1'u32 + +proc LT(x, y: uint32): uint32 {.inline.} = + let z = x - y + (z xor ((y xor x) and (y xor z))) shr 31 + +proc hexValue(c: byte): int = + let x = uint32(c) - 0x30'u32 + let y = uint32(c) - 0x41'u32 + let z = uint32(c) - 0x61'u32 + let r = ((x + 1'u32) and -LT(x, 10)) or + ((y + 11'u32) and -LT(y, 6)) or + ((z + 11'u32) and -LT(z, 6)) + int(r) - 1 + proc getChunkSize(buffer: openarray[byte]): uint64 = # We using `uint64` representation, but allow only 2^32 chunk size, # ChunkHeaderSize. + var res = 0'u64 for i in 0..= byte('0') and ch <= byte('9'): - result = (result shl 4) or uint64(ch - byte('0')) - else: - result = (result shl 4) or uint64((ch and 0x0F) + 9) + let value = hexValue(buffer[i]) + if value >= 0: + res = (res shl 4) or uint64(value) else: - result = 0xFFFF_FFFF_FFFF_FFFF'u64 + res = 0xFFFF_FFFF_FFFF_FFFF'u64 break + res proc setChunkSize(buffer: var openarray[byte], length: int64): int = # Store length as chunk header size (hexadecimal value) with CRLF. @@ -53,7 +72,7 @@ proc setChunkSize(buffer: var openarray[byte], length: int64): int = buffer[0] = byte('0') buffer[1] = byte(0x0D) buffer[2] = byte(0x0A) - result = 3 + 3 else: while n != 0: var v = length and n @@ -68,161 +87,116 @@ proc setChunkSize(buffer: var openarray[byte], length: int64): int = i = i - 4 buffer[c] = byte(0x0D) buffer[c + 1] = byte(0x0A) - result = c + 2 + c + 2 proc chunkedReadLoop(stream: AsyncStreamReader) {.async.} = var rstream = cast[ChunkedStreamReader](stream) var buffer = newSeq[byte](1024) rstream.state = AsyncStreamState.Running - try: - while true: + while true: + try: # Reading chunk size - var ruFut1 = awaitne rstream.rsource.readUntil(addr buffer[0], 1024, CRLF) - if ruFut1.failed(): - rstream.error = ruFut1.error - rstream.state = AsyncStreamState.Error - break - - let length = ruFut1.read() - var chunksize = getChunkSize(buffer.toOpenArray(0, - length - len(CRLF) - 1)) + let res = await rstream.rsource.readUntil(addr buffer[0], 1024, CRLF) + var chunksize = getChunkSize(buffer.toOpenArray(0, res - len(CRLF) - 1)) if chunksize == 0xFFFF_FFFF_FFFF_FFFF'u64: - rstream.error = newProtocolError() + rstream.error = newChunkedProtocolError() rstream.state = AsyncStreamState.Error - break elif chunksize > 0'u64: while chunksize > 0'u64: let toRead = min(int(chunksize), rstream.buffer.bufferLen()) - var reFut2 = awaitne rstream.rsource.readExactly( - rstream.buffer.getBuffer(), toRead) - if reFut2.failed(): - rstream.error = reFut2.error - rstream.state = AsyncStreamState.Error - break - + await rstream.rsource.readExactly(rstream.buffer.getBuffer(), toRead) rstream.buffer.update(toRead) await rstream.buffer.transfer() chunksize = chunksize - uint64(toRead) - if rstream.state != AsyncStreamState.Running: - break + if rstream.state == AsyncStreamState.Running: + # Reading chunk trailing CRLF + await rstream.rsource.readExactly(addr buffer[0], 2) - # Reading chunk trailing CRLF - var reFut3 = awaitne rstream.rsource.readExactly(addr buffer[0], 2) - if reFut3.failed(): - rstream.error = reFut3.error - rstream.state = AsyncStreamState.Error - break - - if buffer[0] != CRLF[0] or buffer[1] != CRLF[1]: - rstream.error = newProtocolError() - rstream.state = AsyncStreamState.Error - break + if buffer[0] != CRLF[0] or buffer[1] != CRLF[1]: + rstream.error = newChunkedProtocolError() + rstream.state = AsyncStreamState.Error else: # Reading trailing line for last chunk - var ruFut4 = awaitne rstream.rsource.readUntil(addr buffer[0], - len(buffer), CRLF) - if ruFut4.failed(): - rstream.error = ruFut4.error - rstream.state = AsyncStreamState.Error - break - + discard await rstream.rsource.readUntil(addr buffer[0], + len(buffer), CRLF) rstream.state = AsyncStreamState.Finished await rstream.buffer.transfer() - break + except CancelledError: + rstream.state = AsyncStreamState.Stopped + except AsyncStreamIncompleteError: + rstream.state = AsyncStreamState.Error + rstream.error = newChunkedIncompleteError() + except AsyncStreamReadError as exc: + rstream.state = AsyncStreamState.Error + rstream.error = exc - except CancelledError: - rstream.state = AsyncStreamState.Stopped - finally: if rstream.state in {AsyncStreamState.Stopped, AsyncStreamState.Error}: # We need to notify consumer about error/close, but we do not care about # incoming data anymore. rstream.buffer.forget() + break proc chunkedWriteLoop(stream: AsyncStreamWriter) {.async.} = var wstream = cast[ChunkedStreamWriter](stream) var buffer: array[16, byte] - var wFut1, wFut2: Future[void] - var error: ref Exception + var error: ref AsyncStreamError wstream.state = AsyncStreamState.Running - try: - while true: - # Getting new item from stream's queue. - var item = await wstream.queue.get() + while true: + var item: WriteItem + # Getting new item from stream's queue. + try: + item = await wstream.queue.get() # `item.size == 0` is marker of stream finish, while `item.size != 0` is # data's marker. if item.size > 0: let length = setChunkSize(buffer, int64(item.size)) # Writing chunk header CRLF. - wFut1 = awaitne wstream.wsource.write(addr buffer[0], length) - if wFut1.failed(): - error = wFut1.error - item.future.fail(error) - continue - + await wstream.wsource.write(addr buffer[0], length) # Writing chunk data. - if item.kind == Pointer: - wFut2 = awaitne wstream.wsource.write(item.data1, item.size) - elif item.kind == Sequence: - wFut2 = awaitne wstream.wsource.write(addr item.data2[0], item.size) - elif item.kind == String: - wFut2 = awaitne wstream.wsource.write(addr item.data3[0], item.size) - if wFut2.failed(): - error = wFut2.error - item.future.fail(error) - continue - + case item.kind + of WriteType.Pointer: + await wstream.wsource.write(item.data1, item.size) + of WriteType.Sequence: + await wstream.wsource.write(addr item.data2[0], item.size) + of WriteType.String: + await wstream.wsource.write(addr item.data3[0], item.size) # Writing chunk footer CRLF. - var wFut3 = awaitne wstream.wsource.write(CRLF) - if wFut3.failed(): - error = wFut3.error - item.future.fail(error) - continue - + await wstream.wsource.write(CRLF) # Everything is fine, completing queue item's future. item.future.complete() else: let length = setChunkSize(buffer, 0'i64) - # Write finish chunk `0`. - wFut1 = awaitne wstream.wsource.write(addr buffer[0], length) - if wFut1.failed(): - error = wFut1.error - item.future.fail(error) - # We break here, because this is last chunk - break - + await wstream.wsource.write(addr buffer[0], length) # Write trailing CRLF. - wFut2 = awaitne wstream.wsource.write(CRLF) - if wFut2.failed(): - error = wFut2.error - item.future.fail(error) - # We break here, because this is last chunk - break - + await wstream.wsource.write(CRLF) # Everything is fine, completing queue item's future. item.future.complete() - # Set stream state to Finished. wstream.state = AsyncStreamState.Finished - break - except CancelledError: - wstream.state = AsyncStreamState.Stopped - finally: - if wstream.state == AsyncStreamState.Stopped: - while len(wstream.queue) > 0: - let item = wstream.queue.popFirstNoWait() - if not(item.future.finished()): - item.future.complete() - elif wstream.state == AsyncStreamState.Error: - while len(wstream.queue) > 0: - let item = wstream.queue.popFirstNoWait() - if not(item.future.finished()): - if not isNil(error): + except CancelledError: + wstream.state = AsyncStreamState.Stopped + error = newAsyncStreamUseClosedError() + except AsyncStreamError as exc: + wstream.state = AsyncStreamState.Error + error = exc + + if wstream.state != AsyncStreamState.Running: + if wstream.state == AsyncStreamState.Finished: + error = newAsyncStreamUseClosedError() + else: + if not(isNil(item.future)): + if not(item.future.finished()): item.future.fail(error) + while not(wstream.queue.empty()): + let pitem = wstream.queue.popFirstNoWait() + if not(pitem.future.finished()): + pitem.future.fail(error) + break proc init*[T](child: ChunkedStreamReader, rsource: AsyncStreamReader, bufferSize = ChunkBufferSize, udata: ref T) = @@ -236,14 +210,16 @@ proc init*(child: ChunkedStreamReader, rsource: AsyncStreamReader, proc newChunkedStreamReader*[T](rsource: AsyncStreamReader, bufferSize = AsyncStreamDefaultBufferSize, udata: ref T): ChunkedStreamReader = - result = new ChunkedStreamReader - result.init(rsource, bufferSize, udata) + var res = ChunkedStreamReader() + res.init(rsource, bufferSize, udata) + res proc newChunkedStreamReader*(rsource: AsyncStreamReader, bufferSize = AsyncStreamDefaultBufferSize, ): ChunkedStreamReader = - result = new ChunkedStreamReader - result.init(rsource, bufferSize) + var res = ChunkedStreamReader() + res.init(rsource, bufferSize) + res proc init*[T](child: ChunkedStreamWriter, wsource: AsyncStreamWriter, queueSize = AsyncStreamDefaultQueueSize, udata: ref T) = @@ -257,11 +233,13 @@ proc init*(child: ChunkedStreamWriter, wsource: AsyncStreamWriter, proc newChunkedStreamWriter*[T](wsource: AsyncStreamWriter, queueSize = AsyncStreamDefaultQueueSize, udata: ref T): ChunkedStreamWriter = - result = new ChunkedStreamWriter - result.init(wsource, queueSize, udata) + var res = ChunkedStreamWriter() + res.init(wsource, queueSize, udata) + res proc newChunkedStreamWriter*(wsource: AsyncStreamWriter, queueSize = AsyncStreamDefaultQueueSize, ): ChunkedStreamWriter = - result = new ChunkedStreamWriter - result.init(wsource, queueSize) + var res = ChunkedStreamWriter() + res.init(wsource, queueSize) + res diff --git a/chronos/streams/tlsstream.nim b/chronos/streams/tlsstream.nim index ec8b267..32258a7 100644 --- a/chronos/streams/tlsstream.nim +++ b/chronos/streams/tlsstream.nim @@ -89,11 +89,11 @@ type SomeTLSStreamType* = TLSStreamReader|TLSStreamWriter|TLSAsyncStream - TLSStreamError* = object of CatchableError + TLSStreamError* = object of AsyncStreamError TLSStreamProtocolError* = object of TLSStreamError errCode*: int -template newTLSStreamProtocolError[T](message: T): ref Exception = +template newTLSStreamProtocolError[T](message: T): ref TLSStreamProtocolError = var msg = "" var code = 0 when T is string: @@ -110,13 +110,13 @@ template newTLSStreamProtocolError[T](message: T): ref Exception = err.errCode = code err -proc raiseTLSStreamProtoError*[T](message: T) = +template raiseTLSStreamProtoError*[T](message: T) = raise newTLSStreamProtocolError(message) proc tlsWriteLoop(stream: AsyncStreamWriter) {.async.} = var wstream = cast[TLSStreamWriter](stream) var engine: ptr SslEngineContext - var error: ref Exception + var error: ref AsyncStreamError if wstream.kind == TLSStreamKind.Server: engine = addr wstream.scontext.eng @@ -125,86 +125,77 @@ proc tlsWriteLoop(stream: AsyncStreamWriter) {.async.} = wstream.state = AsyncStreamState.Running - try: - var length: uint - while true: + while true: + var item: WriteItem + try: var state = engine.sslEngineCurrentState() - if (state and SSL_CLOSED) == SSL_CLOSED: wstream.state = AsyncStreamState.Finished - break - - if (state and (SSL_RECVREC or SSL_RECVAPP)) != 0: - if not(wstream.switchToReader.isSet()): - wstream.switchToReader.fire() - - if (state and (SSL_SENDREC or SSL_SENDAPP)) == 0: - await wstream.switchToWriter.wait() - wstream.switchToWriter.clear() - # We need to refresh `state` because we just returned from readerLoop. - continue - - if (state and SSL_SENDREC) == SSL_SENDREC: - # TLS record needs to be sent over stream. - length = 0'u - var buf = sslEngineSendrecBuf(engine, length) - doAssert(length != 0 and not isNil(buf)) - var fut = awaitne wstream.wsource.write(buf, int(length)) - if fut.cancelled(): - raise fut.error - elif fut.failed(): - error = fut.error - break - sslEngineSendrecAck(engine, length) - continue - - if (state and SSL_SENDAPP) == SSL_SENDAPP: - # Application data can be sent over stream. - if not(wstream.handshaked): - wstream.stream.reader.handshaked = true - wstream.handshaked = true - if not(isNil(wstream.handshakeFut)): - wstream.handshakeFut.complete() - - var item = await wstream.queue.get() - if item.size > 0: - length = 0'u - var buf = sslEngineSendappBuf(engine, length) - let toWrite = min(int(length), item.size) - copyOut(buf, item, toWrite) - if int(length) >= item.size: - # BearSSL is ready to accept whole item size. - sslEngineSendappAck(engine, uint(item.size)) - sslEngineFlush(engine, 0) - item.future.complete() - else: - # BearSSL is not ready to accept whole item, so we will send only - # part of item and adjust offset. - item.offset = item.offset + int(length) - item.size = item.size - int(length) - wstream.queue.addFirstNoWait(item) - sslEngineSendappAck(engine, length) - continue + else: + if (state and (SSL_RECVREC or SSL_RECVAPP)) != 0: + if not(wstream.switchToReader.isSet()): + wstream.switchToReader.fire() + if (state and (SSL_SENDREC or SSL_SENDAPP)) == 0: + await wstream.switchToWriter.wait() + wstream.switchToWriter.clear() + # We need to refresh `state` because we just returned from readerLoop. else: - # Zero length item means finish - wstream.state = AsyncStreamState.Finished - break + if (state and SSL_SENDREC) == SSL_SENDREC: + # TLS record needs to be sent over stream. + var length = 0'u + var buf = sslEngineSendrecBuf(engine, length) + doAssert(length != 0 and not isNil(buf)) + await wstream.wsource.write(buf, int(length)) + sslEngineSendrecAck(engine, length) + elif (state and SSL_SENDAPP) == SSL_SENDAPP: + # Application data can be sent over stream. + if not(wstream.handshaked): + wstream.stream.reader.handshaked = true + wstream.handshaked = true + if not(isNil(wstream.handshakeFut)): + wstream.handshakeFut.complete() + item = await wstream.queue.get() + if item.size > 0: + var length = 0'u + var buf = sslEngineSendappBuf(engine, length) + let toWrite = min(int(length), item.size) + copyOut(buf, item, toWrite) + if int(length) >= item.size: + # BearSSL is ready to accept whole item size. + sslEngineSendappAck(engine, uint(item.size)) + sslEngineFlush(engine, 0) + item.future.complete() + else: + # BearSSL is not ready to accept whole item, so we will send + # only part of item and adjust offset. + item.offset = item.offset + int(length) + item.size = item.size - int(length) + wstream.queue.addFirstNoWait(item) + sslEngineSendappAck(engine, length) + else: + # Zero length item means finish, so we going to trigger TLS + # closure protocol. + sslEngineClose(engine) + except CancelledError: + wstream.state = AsyncStreamState.Stopped + error = newAsyncStreamUseClosedError() + except AsyncStreamError as exc: + wstream.state = AsyncStreamState.Error + error = exc - except CancelledError: - wstream.state = AsyncStreamState.Stopped - - finally: - if wstream.state == AsyncStreamState.Stopped: - while len(wstream.queue) > 0: - let item = wstream.queue.popFirstNoWait() - if not(item.future.finished()): - item.future.complete() - elif wstream.state == AsyncStreamState.Error: - while len(wstream.queue) > 0: - let item = wstream.queue.popFirstNoWait() - if not(item.future.finished()): - item.future.fail(error) - wstream.stream = nil + if wstream.state != AsyncStreamState.Running: + if wstream.state == AsyncStreamState.Finished: + error = newAsyncStreamUseClosedError() + else: + if not(isNil(item.future)): + if not(item.future.finished()): + item.future.fail(error) + while not(wstream.queue.empty()): + let pitem = wstream.queue.popFirstNoWait() + if not(pitem.future.finished()): + pitem.future.fail(error) + wstream.stream = nil + break proc tlsReadLoop(stream: AsyncStreamReader) {.async.} = var rstream = cast[TLSStreamReader](stream) @@ -217,72 +208,61 @@ proc tlsReadLoop(stream: AsyncStreamReader) {.async.} = rstream.state = AsyncStreamState.Running - try: - var length: uint - while true: + while true: + try: var state = engine.sslEngineCurrentState() if (state and SSL_CLOSED) == SSL_CLOSED: let err = engine.sslEngineLastError() if err != 0: - raise newTLSStreamProtocolError(err) - rstream.state = AsyncStreamState.Stopped - break - - if (state and (SSL_SENDREC or SSL_SENDAPP)) != 0: - if not(rstream.switchToWriter.isSet()): - rstream.switchToWriter.fire() - - if (state and (SSL_RECVREC or SSL_RECVAPP)) == 0: - await rstream.switchToReader.wait() - rstream.switchToReader.clear() - # We need to refresh `state` because we just returned from writerLoop. - continue - - if (state and SSL_RECVREC) == SSL_RECVREC: - # TLS records required for further processing - length = 0'u - var buf = sslEngineRecvrecBuf(engine, length) - let res = await rstream.rsource.readOnce(buf, int(length)) - if res > 0: - sslEngineRecvrecAck(engine, uint(res)) - continue + rstream.error = newTLSStreamProtocolError(err) + rstream.state = AsyncStreamState.Error else: rstream.state = AsyncStreamState.Finished - break + else: + if (state and (SSL_SENDREC or SSL_SENDAPP)) != 0: + if not(rstream.switchToWriter.isSet()): + rstream.switchToWriter.fire() + if (state and (SSL_RECVREC or SSL_RECVAPP)) == 0: + await rstream.switchToReader.wait() + rstream.switchToReader.clear() + # We need to refresh `state` because we just returned from writerLoop. + else: + if (state and SSL_RECVREC) == SSL_RECVREC: + # TLS records required for further processing + var length = 0'u + var buf = sslEngineRecvrecBuf(engine, length) + let res = await rstream.rsource.readOnce(buf, int(length)) + if res > 0: + sslEngineRecvrecAck(engine, uint(res)) + else: + # readOnce() returns `0` if stream is at EOF, so we initiate TLS + # closure procedure. + sslEngineClose(engine) + elif (state and SSL_RECVAPP) == SSL_RECVAPP: + # Application data can be recovered. + var length = 0'u + var buf = sslEngineRecvappBuf(engine, length) + await upload(addr rstream.buffer, buf, int(length)) + sslEngineRecvappAck(engine, length) + except CancelledError: + rstream.state = AsyncStreamState.Stopped + except AsyncStreamError as exc: + rstream.error = exc + rstream.state = AsyncStreamState.Error + if not(rstream.handshaked): + rstream.handshaked = true + rstream.stream.writer.handshaked = true + if not(isNil(rstream.handshakeFut)): + rstream.handshakeFut.fail(rstream.error) + rstream.switchToWriter.fire() - if (state and SSL_RECVAPP) == SSL_RECVAPP: - # Application data can be recovered. - length = 0'u - var buf = sslEngineRecvappBuf(engine, length) - await upload(addr rstream.buffer, buf, int(length)) - sslEngineRecvappAck(engine, length) - continue - - except CancelledError: - rstream.state = AsyncStreamState.Stopped - except TLSStreamProtocolError as exc: - rstream.error = exc - rstream.state = AsyncStreamState.Error - if not(rstream.handshaked): - rstream.handshaked = true - rstream.stream.writer.handshaked = true - if not(isNil(rstream.handshakeFut)): - rstream.handshakeFut.fail(rstream.error) - rstream.switchToWriter.fire() - except AsyncStreamReadError as exc: - rstream.error = exc - rstream.state = AsyncStreamState.Error - if not(rstream.handshaked): - rstream.handshaked = true - rstream.stream.writer.handshaked = true - if not(isNil(rstream.handshakeFut)): - rstream.handshakeFut.fail(rstream.error) - rstream.switchToWriter.fire() - finally: - # Perform TLS cleanup procedure - sslEngineClose(engine) - rstream.buffer.forget() - rstream.stream = nil + if rstream.state != AsyncStreamState.Running: + # Perform TLS cleanup procedure + if rstream.state != AsyncStreamState.Finished: + sslEngineClose(engine) + rstream.buffer.forget() + rstream.stream = nil + break proc getSignerAlgo(xc: X509Certificate): int = ## Get certificate's signing algorithm. @@ -291,9 +271,9 @@ proc getSignerAlgo(xc: X509Certificate): int = x509DecoderPush(addr dc, xc.data, xc.dataLen) let err = x509DecoderLastError(addr dc) if err != 0: - result = -1 + -1 else: - result = int(x509DecoderGetSignerKeyType(addr dc)) + int(x509DecoderGetSignerKeyType(addr dc)) proc newTLSClientAsyncStream*(rsource: AsyncStreamReader, wsource: AsyncStreamWriter, @@ -318,54 +298,59 @@ proc newTLSClientAsyncStream*(rsource: AsyncStreamReader, ## ``minVersion`` of bigger then ``maxVersion`` you will get an error. ## ## ``flags`` - custom TLS connection flags. - result = new TLSAsyncStream - var reader = TLSStreamReader(kind: TLSStreamKind.Client) - var writer = TLSStreamWriter(kind: TLSStreamKind.Client) - var switchToWriter = newAsyncEvent() - var switchToReader = newAsyncEvent() - reader.stream = result - writer.stream = result - reader.switchToReader = switchToReader - reader.switchToWriter = switchToWriter - writer.switchToReader = switchToReader - writer.switchToWriter = switchToWriter - result.reader = reader - result.writer = writer - reader.ccontext = addr result.ccontext - writer.ccontext = addr result.ccontext + let switchToWriter = newAsyncEvent() + let switchToReader = newAsyncEvent() + var res = TLSAsyncStream() + var reader = TLSStreamReader( + kind: TLSStreamKind.Client, + stream: res, + switchToReader: switchToReader, + switchToWriter: switchToWriter, + ccontext: addr res.ccontext + ) + var writer = TLSStreamWriter( + kind: TLSStreamKind.Client, + stream: res, + switchToReader: switchToReader, + switchToWriter: switchToWriter, + ccontext: addr res.ccontext + ) + res.reader = reader + res.writer = writer if TLSFlags.NoVerifyHost in flags: - sslClientInitFull(addr result.ccontext, addr result.x509, nil, 0) - initNoAnchor(addr result.xwc, addr result.x509.vtable) - sslEngineSetX509(addr result.ccontext.eng, addr result.xwc.vtable) + sslClientInitFull(addr res.ccontext, addr res.x509, nil, 0) + initNoAnchor(addr res.xwc, addr res.x509.vtable) + sslEngineSetX509(addr res.ccontext.eng, addr res.xwc.vtable) else: - sslClientInitFull(addr result.ccontext, addr result.x509, + sslClientInitFull(addr res.ccontext, addr res.x509, unsafeAddr MozillaTrustAnchors[0], len(MozillaTrustAnchors)) let size = max(SSL_BUFSIZE_BIDI, bufferSize) - result.sbuffer = newSeq[byte](size) - sslEngineSetBuffer(addr result.ccontext.eng, addr result.sbuffer[0], - uint(len(result.sbuffer)), 1) - sslEngineSetVersions(addr result.ccontext.eng, uint16(minVersion), + res.sbuffer = newSeq[byte](size) + sslEngineSetBuffer(addr res.ccontext.eng, addr res.sbuffer[0], + uint(len(res.sbuffer)), 1) + sslEngineSetVersions(addr res.ccontext.eng, uint16(minVersion), uint16(maxVersion)) if TLSFlags.NoVerifyServerName in flags: - let err = sslClientReset(addr result.ccontext, "", 0) + let err = sslClientReset(addr res.ccontext, "", 0) if err == 0: raise newException(TLSStreamError, "Could not initialize TLS layer") else: if len(serverName) == 0: raise newException(TLSStreamError, "serverName must not be empty string") - let err = sslClientReset(addr result.ccontext, serverName, 0) + let err = sslClientReset(addr res.ccontext, serverName, 0) if err == 0: raise newException(TLSStreamError, "Could not initialize TLS layer") - init(cast[AsyncStreamWriter](result.writer), wsource, tlsWriteLoop, + init(cast[AsyncStreamWriter](res.writer), wsource, tlsWriteLoop, bufferSize) - init(cast[AsyncStreamReader](result.reader), rsource, tlsReadLoop, + init(cast[AsyncStreamReader](res.reader), rsource, tlsReadLoop, bufferSize) + res proc newTLSServerAsyncStream*(rsource: AsyncStreamReader, wsource: AsyncStreamWriter, @@ -395,98 +380,104 @@ proc newTLSServerAsyncStream*(rsource: AsyncStreamReader, if isNil(certificate) or len(certificate.certs) == 0: raiseTLSStreamProtoError("Incorrect certificate") - result = new TLSAsyncStream - var reader = TLSStreamReader(kind: TLSStreamKind.Server) - var writer = TLSStreamWriter(kind: TLSStreamKind.Server) - var switchToWriter = newAsyncEvent() - var switchToReader = newAsyncEvent() - reader.stream = result - writer.stream = result - reader.switchToReader = switchToReader - reader.switchToWriter = switchToWriter - writer.switchToReader = switchToReader - writer.switchToWriter = switchToWriter - result.reader = reader - result.writer = writer - reader.scontext = addr result.scontext - writer.scontext = addr result.scontext + let switchToWriter = newAsyncEvent() + let switchToReader = newAsyncEvent() + + var res = TLSAsyncStream() + var reader = TLSStreamReader( + kind: TLSStreamKind.Server, + stream: res, + switchToReader: switchToReader, + switchToWriter: switchToWriter, + scontext: addr res.scontext + ) + var writer = TLSStreamWriter( + kind: TLSStreamKind.Server, + stream: res, + switchToReader: switchToReader, + switchToWriter: switchToWriter, + scontext: addr res.scontext + ) + res.reader = reader + res.writer = writer if privateKey.kind == TLSKeyType.EC: let algo = getSignerAlgo(certificate.certs[0]) if algo == -1: raiseTLSStreamProtoError("Could not decode certificate") - sslServerInitFullEc(addr result.scontext, addr certificate.certs[0], + sslServerInitFullEc(addr res.scontext, addr certificate.certs[0], len(certificate.certs), cuint(algo), addr privateKey.eckey) elif privateKey.kind == TLSKeyType.RSA: - sslServerInitFullRsa(addr result.scontext, addr certificate.certs[0], + sslServerInitFullRsa(addr res.scontext, addr certificate.certs[0], len(certificate.certs), addr privateKey.rsakey) let size = max(SSL_BUFSIZE_BIDI, bufferSize) - result.sbuffer = newSeq[byte](size) - sslEngineSetBuffer(addr result.scontext.eng, addr result.sbuffer[0], - uint(len(result.sbuffer)), 1) - sslEngineSetVersions(addr result.scontext.eng, uint16(minVersion), + res.sbuffer = newSeq[byte](size) + sslEngineSetBuffer(addr res.scontext.eng, addr res.sbuffer[0], + uint(len(res.sbuffer)), 1) + sslEngineSetVersions(addr res.scontext.eng, uint16(minVersion), uint16(maxVersion)) if not isNil(cache): - sslServerSetCache(addr result.scontext, addr cache.context.vtable) + sslServerSetCache(addr res.scontext, addr cache.context.vtable) if TLSFlags.EnforceServerPref in flags: - sslEngineAddFlags(addr result.scontext.eng, OPT_ENFORCE_SERVER_PREFERENCES) + sslEngineAddFlags(addr res.scontext.eng, OPT_ENFORCE_SERVER_PREFERENCES) if TLSFlags.NoRenegotiation in flags: - sslEngineAddFlags(addr result.scontext.eng, OPT_NO_RENEGOTIATION) + sslEngineAddFlags(addr res.scontext.eng, OPT_NO_RENEGOTIATION) if TLSFlags.TolerateNoClientAuth in flags: - sslEngineAddFlags(addr result.scontext.eng, OPT_TOLERATE_NO_CLIENT_AUTH) + sslEngineAddFlags(addr res.scontext.eng, OPT_TOLERATE_NO_CLIENT_AUTH) if TLSFlags.FailOnAlpnMismatch in flags: - sslEngineAddFlags(addr result.scontext.eng, OPT_FAIL_ON_ALPN_MISMATCH) + sslEngineAddFlags(addr res.scontext.eng, OPT_FAIL_ON_ALPN_MISMATCH) - let err = sslServerReset(addr result.scontext) + let err = sslServerReset(addr res.scontext) if err == 0: raise newException(TLSStreamError, "Could not initialize TLS layer") - init(cast[AsyncStreamWriter](result.writer), wsource, tlsWriteLoop, + init(cast[AsyncStreamWriter](res.writer), wsource, tlsWriteLoop, bufferSize) - init(cast[AsyncStreamReader](result.reader), rsource, tlsReadLoop, + init(cast[AsyncStreamReader](res.reader), rsource, tlsReadLoop, bufferSize) + res proc copyKey(src: RsaPrivateKey): TLSPrivateKey = ## Creates copy of RsaPrivateKey ``src``. var offset = 0 let keySize = src.plen + src.qlen + src.dplen + src.dqlen + src.iqlen - result = TLSPrivateKey(kind: TLSKeyType.RSA) - result.storage = newSeq[byte](keySize) - copyMem(addr result.storage[offset], src.p, src.plen) - result.rsakey.p = cast[ptr cuchar](addr result.storage[offset]) - result.rsakey.plen = src.plen + var res = TLSPrivateKey(kind: TLSKeyType.RSA, storage: newSeq[byte](keySize)) + copyMem(addr res.storage[offset], src.p, src.plen) + res.rsakey.p = cast[ptr cuchar](addr res.storage[offset]) + res.rsakey.plen = src.plen offset = offset + src.plen - copyMem(addr result.storage[offset], src.q, src.qlen) - result.rsakey.q = cast[ptr cuchar](addr result.storage[offset]) - result.rsakey.qlen = src.qlen + copyMem(addr res.storage[offset], src.q, src.qlen) + res.rsakey.q = cast[ptr cuchar](addr res.storage[offset]) + res.rsakey.qlen = src.qlen offset = offset + src.qlen - copyMem(addr result.storage[offset], src.dp, src.dplen) - result.rsakey.dp = cast[ptr cuchar](addr result.storage[offset]) - result.rsakey.dplen = src.dplen + copyMem(addr res.storage[offset], src.dp, src.dplen) + res.rsakey.dp = cast[ptr cuchar](addr res.storage[offset]) + res.rsakey.dplen = src.dplen offset = offset + src.dplen - copyMem(addr result.storage[offset], src.dq, src.dqlen) - result.rsakey.dq = cast[ptr cuchar](addr result.storage[offset]) - result.rsakey.dqlen = src.dqlen + copyMem(addr res.storage[offset], src.dq, src.dqlen) + res.rsakey.dq = cast[ptr cuchar](addr res.storage[offset]) + res.rsakey.dqlen = src.dqlen offset = offset + src.dqlen - copyMem(addr result.storage[offset], src.iq, src.iqlen) - result.rsakey.iq = cast[ptr cuchar](addr result.storage[offset]) - result.rsakey.iqlen = src.iqlen - result.rsakey.nBitlen = src.nBitlen + copyMem(addr res.storage[offset], src.iq, src.iqlen) + res.rsakey.iq = cast[ptr cuchar](addr res.storage[offset]) + res.rsakey.iqlen = src.iqlen + res.rsakey.nBitlen = src.nBitlen + res proc copyKey(src: EcPrivateKey): TLSPrivateKey = ## Creates copy of EcPrivateKey ``src``. var offset = 0 let keySize = src.xlen - result = TLSPrivateKey(kind: TLSKeyType.EC) - result.storage = newSeq[byte](keySize) - copyMem(addr result.storage[offset], src.x, src.xlen) - result.eckey.x = cast[ptr cuchar](addr result.storage[offset]) - result.eckey.xlen = src.xlen - result.eckey.curve = src.curve + var res = TLSPrivateKey(kind: TLSKeyType.EC, storage: newSeq[byte](keySize)) + copyMem(addr res.storage[offset], src.x, src.xlen) + res.eckey.x = cast[ptr cuchar](addr res.storage[offset]) + res.eckey.xlen = src.xlen + res.eckey.curve = src.curve + res proc init*(tt: typedesc[TLSPrivateKey], data: openarray[byte]): TLSPrivateKey = ## Initialize TLS private key from array of bytes ``data``. @@ -502,12 +493,14 @@ proc init*(tt: typedesc[TLSPrivateKey], data: openarray[byte]): TLSPrivateKey = if err != 0: raiseTLSStreamProtoError(err) let keyType = skeyDecoderKeyType(addr ctx) - if keyType == KEYTYPE_RSA: - result = copyKey(ctx.key.rsa) - elif keyType == KEYTYPE_EC: - result = copyKey(ctx.key.ec) - else: - raiseTLSStreamProtoError("Unknown key type (" & $keyType & ")") + let res = + if keyType == KEYTYPE_RSA: + copyKey(ctx.key.rsa) + elif keyType == KEYTYPE_EC: + copyKey(ctx.key.ec) + else: + raiseTLSStreamProtoError("Unknown key type (" & $keyType & ")") + res proc pemDecode*(data: openarray[char]): seq[PEMElement] = ## Decode PEM encoded string and get array of binary blobs. @@ -515,7 +508,7 @@ proc pemDecode*(data: openarray[char]): seq[PEMElement] = raiseTLSStreamProtoError("Empty PEM message") var ctx: PemDecoderContext var pctx = new PEMContext - result = newSeq[PEMElement]() + var res = newSeq[PEMElement]() pemDecoderInit(addr ctx) proc itemAppend(ctx: pointer, pbytes: pointer, nbytes: int) {.cdecl.} = @@ -544,12 +537,13 @@ proc pemDecode*(data: openarray[char]): seq[PEMElement] = elif event == PEM_END_OBJ: if inobj: elem.data = pctx.data - result.add(elem) + res.add(elem) inobj = false else: break else: raiseTLSStreamProtoError("Invalid PEM encoding") + res proc init*(tt: typedesc[TLSPrivateKey], data: openarray[char]): TLSPrivateKey = ## Initialize TLS private key from string ``data``. @@ -558,13 +552,15 @@ proc init*(tt: typedesc[TLSPrivateKey], data: openarray[char]): TLSPrivateKey = ## encoded string. ## ## Note that PKCS#1 PEM encoded objects are not supported. + var res: TLSPrivateKey var items = pemDecode(data) for item in items: if item.name == "PRIVATE KEY": - result = TLSPrivateKey.init(item.data) + res = TLSPrivateKey.init(item.data) break - if isNil(result): + if isNil(res): raiseTLSStreamProtoError("Could not find private key") + res proc init*(tt: typedesc[TLSCertificate], data: openarray[char]): TLSCertificate = @@ -572,32 +568,33 @@ proc init*(tt: typedesc[TLSCertificate], ## ## This procedure initializes array of certificates from PEM encoded string. var items = pemDecode(data) - result = new TLSCertificate + var res = TLSCertificate() for item in items: if item.name == "CERTIFICATE" and len(item.data) > 0: - let offset = len(result.storage) - result.storage.add(item.data) + let offset = len(res.storage) + res.storage.add(item.data) let cert = X509Certificate( - data: cast[ptr cuchar](addr result.storage[offset]), + data: cast[ptr cuchar](addr res.storage[offset]), dataLen: len(item.data) ) - let res = getSignerAlgo(cert) - if res == -1: + let ares = getSignerAlgo(cert) + if ares == -1: raiseTLSStreamProtoError("Could not decode certificate") - elif res != KEYTYPE_RSA and res != KEYTYPE_EC: + elif ares != KEYTYPE_RSA and ares != KEYTYPE_EC: raiseTLSStreamProtoError("Unsupported signing key type in certificate") - result.certs.add(cert) - if len(result.storage) == 0: + res.certs.add(cert) + if len(res.storage) == 0: raiseTLSStreamProtoError("Could not find any certificates") + res proc init*(tt: typedesc[TLSSessionCache], size: int = 4096): TLSSessionCache = ## Create new TLS session cache with size ``size``. ## ## One cached item is near 100 bytes size. - result = new TLSSessionCache var rsize = min(size, 4096) - result.storage = newSeq[byte](rsize) - sslSessionCacheLruInit(addr result.context, addr result.storage[0], rsize) + var res = TLSSessionCache(storage: newSeq[byte](rsize)) + sslSessionCacheLruInit(addr res.context, addr res.storage[0], rsize) + res proc handshake*(rws: SomeTLSStreamType): Future[void] = ## Wait until initial TLS handshake will be successfully performed. @@ -620,4 +617,4 @@ proc handshake*(rws: SomeTLSStreamType): Future[void] = else: rws.reader.handshakeFut = retFuture rws.writer.handshakeFut = retFuture - return retFuture + retFuture diff --git a/tests/testasyncstream.nim b/tests/testasyncstream.nim index d2720f5..13b426c 100644 --- a/tests/testasyncstream.nim +++ b/tests/testasyncstream.nim @@ -6,7 +6,8 @@ # Apache License, version 2.0, (LICENSE-APACHEv2) # MIT license (LICENSE-MIT) import unittest -import ../chronos, ../chronos/streams/tlsstream +import ../chronos +import ../chronos/streams/[tlsstream, chunkstream, boundstream] when defined(nimHasUsed): {.used.} @@ -553,8 +554,12 @@ suite "ChunkedStream test suite": try: var r = await rstream2.read() doAssert(len(r) > 0) - except AsyncStreamReadError: - res = true + except ChunkedStreamIncompleteError: + if inputstr == "100000000 \r\n1": + res = true + except ChunkedStreamProtocolError: + if inputstr == "z\r\n1": + res = true await rstream2.closeWait() await rstream.closeWait() await transp.closeWait() @@ -663,3 +668,119 @@ suite "TLSStream test suite": getTracker("async.stream.writer").isLeaked() == false getTracker("stream.server").isLeaked() == false getTracker("stream.transport").isLeaked() == false + +suite "BoundedStream test suite": + + proc createBigMessage(size: int): seq[byte] = + var message = "MESSAGE" + result = newSeq[byte](size) + for i in 0 ..< len(result): + result[i] = byte(message[i mod len(message)]) + + for item in [100'u64, 60000'u64]: + + proc boundedTest(address: TransportAddress, test: int, + size: uint64): Future[bool] {.async.} = + var clientRes = false + var res = false + + let messagePart = createBigMessage(int(item) div 10) + var message: seq[byte] + for i in 0 ..< 10: + message.add(messagePart) + + proc processClient(server: StreamServer, + transp: StreamTransport) {.async.} = + var wstream = newAsyncStreamWriter(transp) + var wbstream = newBoundedStreamWriter(wstream, size) + if test == 0: + for i in 0 ..< 10: + await wbstream.write(messagePart) + await wbstream.finish() + await wbstream.closeWait() + clientRes = true + elif test == 1: + for i in 0 ..< 10: + await wbstream.write(messagePart) + try: + await wbstream.write(messagePart) + except BoundedStreamOverflowError: + clientRes = true + await wbstream.closeWait() + elif test == 2: + for i in 0 ..< 9: + await wbstream.write(messagePart) + try: + await wbstream.finish() + except BoundedStreamIncompleteError: + clientRes = true + await wbstream.closeWait() + elif test == 3: + for i in 0 ..< 10: + await wbstream.write(messagePart) + await wbstream.finish() + await wbstream.closeWait() + clientRes = true + elif test == 4: + for i in 0 ..< 9: + await wbstream.write(messagePart) + try: + await wbstream.closeWait() + except BoundedStreamIncompleteError: + clientRes = true + + await wstream.closeWait() + await transp.closeWait() + server.stop() + server.close() + + var server = createStreamServer(address, processClient, flags = {ReuseAddr}) + server.start() + var conn = await connect(address) + var rstream = newAsyncStreamReader(conn) + var rbstream = newBoundedStreamReader(rstream, size) + if test == 0: + let response = await rbstream.read() + await rbstream.closeWait() + if response == message: + res = true + elif test == 1: + let response = await rbstream.read() + await rbstream.closeWait() + if response == message: + res = true + elif test == 2: + try: + let response {.used.} = await rbstream.read() + except BoundedStreamIncompleteError: + res = true + await rbstream.closeWait() + elif test == 3: + let response {.used.} = await rbstream.read(int(size) - 1) + try: + await rbstream.closeWait() + except BoundedStreamIncompleteError: + res = true + elif test == 4: + try: + let response {.used.} = await rbstream.read() + except BoundedStreamIncompleteError: + res = true + await rbstream.closeWait() + + await rstream.closeWait() + await conn.closeWait() + await server.join() + return (res and clientRes) + + let address = initTAddress("127.0.0.1:48030") + test "BoundedStream reading/writing test [" & $item & "]": + check waitFor(boundedTest(address, 0, item)) == true + test "BoundedStream overflow test [" & $item & "]": + check waitFor(boundedTest(address, 1, item)) == true + test "BoundedStream incomplete test [" & $item & "]": + check waitFor(boundedTest(address, 2, item)) == true + test "BoundedStream read() close test [" & $item & "]": + check waitFor(boundedTest(address, 3, item)) == true + test "BoundedStream write() close test [" & $item & "]": + check waitFor(boundedTest(address, 4, item)) == true