diff --git a/chronos/streams/boundstream.nim b/chronos/streams/boundstream.nim index 1792347b..d5a3f7d4 100644 --- a/chronos/streams/boundstream.nim +++ b/chronos/streams/boundstream.nim @@ -19,14 +19,19 @@ import asyncstream, ../transports/stream, ../transports/common export asyncstream, stream, timer, common type + BoundCmp* {.pure.} = enum + Equal, LessOrEqual + BoundedStreamReader* = ref object of AsyncStreamReader boundSize: int boundary: seq[byte] offset: int + cmpop: BoundCmp BoundedStreamWriter* = ref object of AsyncStreamWriter boundSize: int offset: int + cmpop: BoundCmp BoundedStreamError* = object of AsyncStreamError BoundedStreamIncompleteError* = object of BoundedStreamError @@ -46,18 +51,18 @@ template newBoundedStreamOverflowError*(): ref BoundedStreamError = proc readUntilBoundary*(rstream: AsyncStreamReader, pbytes: pointer, nbytes: int, sep: seq[byte]): Future[int] {.async.} = doAssert(not(isNil(pbytes)), "pbytes must not be nil") - doAssert(len(sep) > 0, "separator must not be empty") doAssert(nbytes >= 0, "nbytes must be non-negative value") checkStreamClosed(rstream) + if nbytes == 0: + return 0 + var k = 0 var state = 0 var pbuffer = cast[ptr UncheckedArray[byte]](pbytes) - var error: ref AsyncStreamIncompleteError proc predicate(data: openarray[byte]): tuple[consumed: int, done: bool] = if len(data) == 0: - error = newAsyncStreamIncompleteError() (0, true) else: var index = 0 @@ -68,19 +73,17 @@ proc readUntilBoundary*(rstream: AsyncStreamReader, pbytes: pointer, inc(index) pbuffer[k] = ch inc(k) - if sep[state] == ch: - inc(state) - if state == len(sep): - break - else: - state = 0 - (index, state == len(sep) or (k == nbytes)) + if len(sep) > 0: + if sep[state] == ch: + inc(state) + if state == len(sep): + break + else: + state = 0 + (index, (state == len(sep)) or (k == nbytes)) await rstream.readMessage(predicate) - if not isNil(error): - raise error - else: - return k + return k func endsWith(s, suffix: openarray[byte]): bool = var i = 0 @@ -95,65 +98,55 @@ proc boundedReadLoop(stream: AsyncStreamReader) {.async.} = rstream.state = AsyncStreamState.Running var buffer = newSeq[byte](rstream.buffer.bufferLen()) while true: - if len(rstream.boundary) == 0: - # Only size boundary set - if rstream.offset < rstream.boundSize: - let toRead = min(rstream.boundSize - rstream.offset, - rstream.buffer.bufferLen()) - try: - await rstream.rsource.readExactly(rstream.buffer.getBuffer(), toRead) - rstream.offset = rstream.offset + toRead - rstream.buffer.update(toRead) - await rstream.buffer.transfer() - except AsyncStreamIncompleteError: - rstream.state = AsyncStreamState.Error - rstream.error = newBoundedStreamIncompleteError() - except AsyncStreamReadError as exc: - rstream.state = AsyncStreamState.Error - rstream.error = exc - except CancelledError: - rstream.state = AsyncStreamState.Stopped - - if rstream.state != AsyncStreamState.Running: - break - else: - rstream.state = AsyncStreamState.Finished - await rstream.buffer.transfer() - break - else: - # Sequence boundary set - if ((rstream.boundSize >= 0) and (rstream.offset < rstream.boundSize)) or - (rstream.boundSize < 0): - let toRead = - if rstream.boundSize < 0: - len(buffer) - else: - min(rstream.boundSize - rstream.offset, len(buffer)) - try: - let res = await readUntilBoundary(rstream.rsource, addr buffer[0], - toRead, rstream.boundary) - if endsWith(buffer.toOpenArray(0, res - 1), rstream.boundary): - let length = res - len(rstream.boundary) - rstream.offset = rstream.offset + length - await upload(addr rstream.buffer, addr buffer[0], length) - rstream.state = AsyncStreamState.Finished + # r1 is `true` if `boundSize` was not set. + let r1 = rstream.boundSize < 0 + # r2 is `true` if number of bytes read is less then `boundSize`. + let r2 = (rstream.boundSize > 0) and (rstream.offset < rstream.boundSize) + if r1 or r2: + let toRead = + if rstream.boundSize < 0: + len(buffer) + else: + min(rstream.boundSize - rstream.offset, len(buffer)) + try: + let res = await readUntilBoundary(rstream.rsource, addr buffer[0], + toRead, rstream.boundary) + if res > 0: + if len(rstream.boundary) > 0: + if endsWith(buffer.toOpenArray(0, res - 1), rstream.boundary): + let length = res - len(rstream.boundary) + rstream.offset = rstream.offset + length + await upload(addr rstream.buffer, addr buffer[0], length) + rstream.state = AsyncStreamState.Finished + else: + rstream.offset = rstream.offset + res + await upload(addr rstream.buffer, addr buffer[0], res) else: rstream.offset = rstream.offset + res await upload(addr rstream.buffer, addr buffer[0], res) - except AsyncStreamIncompleteError: - rstream.state = AsyncStreamState.Error - rstream.error = newBoundedStreamIncompleteError() - except AsyncStreamReadError as exc: - rstream.state = AsyncStreamState.Error - rstream.error = exc - except CancelledError: - rstream.state = AsyncStreamState.Stopped + else: + case rstream.cmpop + of BoundCmp.Equal: + rstream.state = AsyncStreamState.Error + rstream.error = newBoundedStreamIncompleteError() + of BoundCmp.LessOrEqual: + rstream.state = AsyncStreamState.Finished - if rstream.state != AsyncStreamState.Running: - break - else: - rstream.state = AsyncStreamState.Finished + except AsyncStreamReadError as exc: + rstream.state = AsyncStreamState.Error + rstream.error = exc + except CancelledError: + rstream.state = AsyncStreamState.Stopped + + if rstream.state != AsyncStreamState.Running: + if rstream.state == AsyncStreamState.Finished: + # This is state when BoundCmp.LessOrEqual and readExactly returned + # `AsyncStreamIncompleteError`. + await rstream.buffer.transfer() break + else: + rstream.state = AsyncStreamState.Finished + break # Without this additional wait, procedures such as `read()` could got stuck # in `await.buffer.wait()` because procedures are unable to detect EOF while @@ -191,8 +184,13 @@ proc boundedWriteLoop(stream: AsyncStreamWriter) {.async.} = error = newBoundedStreamOverflowError() else: if wstream.offset != wstream.boundSize: - wstream.state = AsyncStreamState.Error - error = newBoundedStreamIncompleteError() + case wstream.cmpop + of BoundCmp.Equal: + wstream.state = AsyncStreamState.Error + error = newBoundedStreamIncompleteError() + of BoundCmp.LessOrEqual: + wstream.state = AsyncStreamState.Finished + item.future.complete() else: wstream.state = AsyncStreamState.Finished item.future.complete() @@ -235,22 +233,26 @@ proc init*(child: BoundedStreamReader, rsource: AsyncStreamReader, proc newBoundedStreamReader*[T](rsource: AsyncStreamReader, boundSize: int, boundary: openarray[byte] = [], + comparison = BoundCmp.Equal, bufferSize = BoundedBufferSize, udata: ref T): BoundedStreamReader = - doAssert(boundSize >= 0 or len(boundary) > 0, + doAssert(not(boundSize <= 0 and (len(boundary) == 0)), "At least one type of boundary should be set") - var res = BoundedStreamReader(boundSize: boundSize, boundary: @boundary) + var res = BoundedStreamReader(boundSize: boundSize, boundary: @boundary, + cmpop: comparison) res.init(rsource, bufferSize, udata) res proc newBoundedStreamReader*(rsource: AsyncStreamReader, boundSize: int, boundary: openarray[byte] = [], + comparison = BoundCmp.Equal, bufferSize = BoundedBufferSize, ): BoundedStreamReader = - doAssert(boundSize >= 0 or len(boundary) > 0, + doAssert(not(boundSize <= 0 and (len(boundary) == 0)), "At least one type of boundary should be set") - var res = BoundedStreamReader(boundSize: boundSize, boundary: @boundary) + var res = BoundedStreamReader(boundSize: boundSize, boundary: @boundary, + cmpop: comparison) res.init(rsource, bufferSize) res @@ -265,16 +267,20 @@ proc init*(child: BoundedStreamWriter, wsource: AsyncStreamWriter, proc newBoundedStreamWriter*[T](wsource: AsyncStreamWriter, boundSize: int, + comparison = BoundCmp.Equal, queueSize = AsyncStreamDefaultQueueSize, udata: ref T): BoundedStreamWriter = - var res = BoundedStreamWriter(boundSize: boundSize) + doAssert(boundSize > 0, "Bound size must be bigger then zero") + var res = BoundedStreamWriter(boundSize: boundSize, cmpop: comparison) res.init(wsource, queueSize, udata) res proc newBoundedStreamWriter*(wsource: AsyncStreamWriter, boundSize: int, + comparison = BoundCmp.Equal, queueSize = AsyncStreamDefaultQueueSize, ): BoundedStreamWriter = - var res = BoundedStreamWriter(boundSize: boundSize) + doAssert(boundSize > 0, "Bound size must be bigger then zero") + var res = BoundedStreamWriter(boundSize: boundSize, cmpop: comparison) res.init(wsource, queueSize) res diff --git a/tests/testasyncstream.nim b/tests/testasyncstream.nim index e4c53304..3cedfe78 100644 --- a/tests/testasyncstream.nim +++ b/tests/testasyncstream.nim @@ -8,7 +8,7 @@ import unittest import ../chronos import ../chronos/streams/[tlsstream, chunkstream, boundstream] - +import nimcrypto/utils when defined(nimHasUsed): {.used.} # To create self-signed certificate and key you can use openssl @@ -672,184 +672,259 @@ suite "TLSStream test suite": suite "BoundedStream test suite": proc createBigMessage(size: int): seq[byte] = - var message = "MESSAGE" - result = newSeq[byte](size) + var message = "ABCDEFGHIJKLMNOP" + var res = newSeq[byte](size) for i in 0 ..< len(result): - result[i] = byte(message[i mod len(message)]) + res[i] = byte(message[i mod len(message)]) + res + for itemComp in [BoundCmp.Equal, BoundCmp.LessOrEqual]: + for itemSize in [100, 60000]: - for item in [100, 60000]: + proc boundaryTest(address: TransportAddress, test: int, size: int, + boundary: seq[byte], + cmp: BoundCmp): Future[bool] {.async.} = + var message = createBigMessage(size) + var clientRes = false - proc boundaryTest(address: TransportAddress, test: int, size: int, - boundary: seq[byte]): Future[bool] {.async.} = - var message = createBigMessage(size) - var clientRes = false + proc processClient(server: StreamServer, + transp: StreamTransport) {.async.} = + var wstream = newAsyncStreamWriter(transp) + if test == 0: + await wstream.write(message) + await wstream.write(boundary) + await wstream.finish() + await wstream.closeWait() + clientRes = true + elif test == 1: + await wstream.write(message) + await wstream.write(boundary) + await wstream.write(message) + await wstream.finish() + await wstream.closeWait() + clientRes = true + elif test == 2: + var ncmessage = message + ncmessage.setLen(len(message) - 2) + await wstream.write(ncmessage) + await wstream.write(@[0x2D'u8, 0x2D'u8]) + await wstream.finish() + await wstream.closeWait() + clientRes = true + elif test == 3: + var ncmessage = message + ncmessage.setLen(len(message) - 2) + await wstream.write(ncmessage) + await wstream.finish() + await wstream.closeWait() + clientRes = true + elif test == 4: + await wstream.write(boundary) + await wstream.finish() + await wstream.closeWait() + clientRes = true - proc processClient(server: StreamServer, - transp: StreamTransport) {.async.} = - var wstream = newAsyncStreamWriter(transp) + await transp.closeWait() + server.stop() + server.close() + + var res = false + let flags = {ServerFlags.ReuseAddr, ServerFlags.TcpNoDelay} + var server = createStreamServer(address, processClient, flags = flags) + server.start() + var conn = await connect(address) + var rstream = newAsyncStreamReader(conn) if test == 0: - await wstream.write(message) - await wstream.write(boundary) - await wstream.finish() - await wstream.closeWait() - clientRes = true + var rbstream = newBoundedStreamReader(rstream, -1, boundary) + let response = await rbstream.read() + if response == message: + res = true + await rbstream.closeWait() elif test == 1: - await wstream.write(message) - await wstream.write(boundary) - await wstream.write(message) - await wstream.finish() - await wstream.closeWait() - clientRes = true + var rbstream = newBoundedStreamReader(rstream, -1, boundary) + let response1 = await rbstream.read() + await rbstream.closeWait() + let response2 = await rstream.read() + if (response1 == message) and (response2 == message): + res = true elif test == 2: - var ncmessage = message - ncmessage.setLen(len(message) - 2) - await wstream.write(ncmessage) - await wstream.write(@[0x2D'u8, 0x2D'u8]) - await wstream.finish() - await wstream.closeWait() - clientRes = true + var expectMessage = message + expectMessage[^2] = 0x2D'u8 + expectMessage[^1] = 0x2D'u8 + var rbstream = newBoundedStreamReader(rstream, size, boundary) + let response = await rbstream.read() + await rbstream.closeWait() + if (len(response) == size) and response == expectMessage: + res = true elif test == 3: - var ncmessage = message - ncmessage.setLen(len(message) - 2) - await wstream.write(ncmessage) - await wstream.finish() - await wstream.closeWait() - clientRes = true - - await transp.closeWait() - server.stop() - server.close() - - var res = false - var server = createStreamServer(address, processClient, - flags = {ReuseAddr}) - server.start() - var conn = await connect(address) - var rstream = newAsyncStreamReader(conn) - if test == 0: - var rbstream = newBoundedStreamReader(rstream, -1, boundary) - let response = await rbstream.read() - if response == message: - res = true - await rbstream.closeWait() - elif test == 1: - var rbstream = newBoundedStreamReader(rstream, -1, boundary) - let response1 = await rbstream.read() - await rbstream.closeWait() - let response2 = await rstream.read() - if (response1 == message) and (response2 == message): - res = true - elif test == 2: - var expectMessage = message - expectMessage[^2] = 0x2D'u8 - expectMessage[^1] = 0x2D'u8 - var rbstream = newBoundedStreamReader(rstream, size, boundary) - let response = await rbstream.read() - await rbstream.closeWait() - if (len(response) == size) and response == expectMessage: - res = true - elif test == 3: - var rbstream = newBoundedStreamReader(rstream, -1, boundary) - try: - let response {.used.} = await rbstream.read() - except BoundedStreamIncompleteError: - res = true - await rbstream.closeWait() - - await rstream.closeWait() - await conn.closeWait() - await server.join() - return (res and clientRes) - - proc boundedTest(address: TransportAddress, test: int, - size: int): Future[bool] {.async.} = - var clientRes = false - var res = false - - let messagePart = createBigMessage(int(item) div 10) - var message: seq[byte] - for i in 0 ..< 10: - message.add(messagePart) - - proc processClient(server: StreamServer, - transp: StreamTransport) {.async.} = - var wstream = newAsyncStreamWriter(transp) - var wbstream = newBoundedStreamWriter(wstream, size) - if test == 0: - for i in 0 ..< 10: - await wbstream.write(messagePart) - await wbstream.finish() - await wbstream.closeWait() - clientRes = true - elif test == 1: - for i in 0 ..< 10: - await wbstream.write(messagePart) + var rbstream = newBoundedStreamReader(rstream, -1, boundary) try: - await wbstream.write(messagePart) - except BoundedStreamOverflowError: - clientRes = true - await wbstream.closeWait() - elif test == 2: - for i in 0 ..< 9: - await wbstream.write(messagePart) - try: - await wbstream.finish() + let response {.used.} = await rbstream.read() except BoundedStreamIncompleteError: + res = true + await rbstream.closeWait() + elif test == 4: + var rbstream = newBoundedStreamReader(rstream, -1, boundary) + let response = await rbstream.read() + await rbstream.closeWait() + if len(response) == 0: + res = true + + await rstream.closeWait() + await conn.closeWait() + await server.join() + return (res and clientRes) + + proc boundedTest(address: TransportAddress, test: int, + size: int, cmp: BoundCmp): Future[bool] {.async.} = + var clientRes = false + var res = false + + let messagePart = createBigMessage(int(itemSize) div 10) + var message: seq[byte] + for i in 0 ..< 10: + message.add(messagePart) + + proc processClient(server: StreamServer, + transp: StreamTransport) {.async.} = + var wstream = newAsyncStreamWriter(transp) + var wbstream = newBoundedStreamWriter(wstream, size, comparison = cmp) + if test == 0: + for i in 0 ..< 10: + await wbstream.write(messagePart) + await wbstream.finish() + await wbstream.closeWait() clientRes = true - await wbstream.closeWait() + elif test == 1: + for i in 0 ..< 10: + await wbstream.write(messagePart) + try: + await wbstream.write(messagePart) + except BoundedStreamOverflowError: + clientRes = true + await wbstream.closeWait() + elif test == 2: + for i in 0 ..< 9: + await wbstream.write(messagePart) + case cmp + of BoundCmp.Equal: + try: + await wbstream.finish() + except BoundedStreamIncompleteError: + clientRes = true + of BoundCmp.LessOrEqual: + try: + await wbstream.finish() + clientRes = true + except BoundedStreamIncompleteError: + discard + await wbstream.closeWait() + elif test == 3: + case cmp + of BoundCmp.Equal: + try: + await wbstream.finish() + except BoundedStreamIncompleteError: + clientRes = true + of BoundCmp.LessOrEqual: + try: + await wbstream.finish() + clientRes = true + except BoundedStreamIncompleteError: + discard + await wbstream.closeWait() - await wstream.closeWait() - await transp.closeWait() - server.stop() - server.close() + await wstream.closeWait() + await transp.closeWait() + server.stop() + server.close() - var server = createStreamServer(address, processClient, flags = {ReuseAddr}) - server.start() - var conn = await connect(address) - var rstream = newAsyncStreamReader(conn) - var rbstream = newBoundedStreamReader(rstream, size) - if test == 0: - let response = await rbstream.read() - await rbstream.closeWait() - if response == message: - res = true - elif test == 1: - let response = await rbstream.read() - await rbstream.closeWait() - if response == message: - res = true - elif test == 2: - try: - let response {.used.} = await rbstream.read() - except BoundedStreamIncompleteError: - res = true - await rbstream.closeWait() + let flags = {ServerFlags.ReuseAddr, ServerFlags.TcpNoDelay} + var server = createStreamServer(address, processClient, flags = flags) + server.start() + var conn = await connect(address) + var rstream = newAsyncStreamReader(conn) + var rbstream = newBoundedStreamReader(rstream, size, comparison = cmp) + if test == 0: + let response = await rbstream.read() + await rbstream.closeWait() + if response == message: + res = true + elif test == 1: + let response = await rbstream.read() + await rbstream.closeWait() + if response == message: + res = true + elif test == 2: + case cmp + of BoundCmp.Equal: + try: + let response {.used.} = await rbstream.read() + except BoundedStreamIncompleteError: + res = true + of BoundCmp.LessOrEqual: + try: + let response = await rbstream.read() + if len(response) == 9 * len(messagePart): + res = true + except BoundedStreamIncompleteError: + res = false + await rbstream.closeWait() + elif test == 3: + case cmp + of BoundCmp.Equal: + try: + let response {.used.} = await rbstream.read() + except BoundedStreamIncompleteError: + res = true + of BoundCmp.LessOrEqual: + try: + let response = await rbstream.read() + if len(response) == 0: + res = true + except BoundedStreamIncompleteError: + res = false + await rbstream.closeWait() - await rstream.closeWait() - await conn.closeWait() - await server.join() - return (res and clientRes) + await rstream.closeWait() + await conn.closeWait() + await server.join() + return (res and clientRes) - let address = initTAddress("127.0.0.1:48030") - test "BoundedStream(size) reading/writing test [" & $item & "]": - check waitFor(boundedTest(address, 0, item)) == true - test "BoundedStream(size) overflow test [" & $item & "]": - check waitFor(boundedTest(address, 1, item)) == true - test "BoundedStream(size) incomplete test [" & $item & "]": - check waitFor(boundedTest(address, 2, item)) == true - test "BoundedStream(boundary) reading test [" & $item & "]": - check waitFor(boundaryTest(address, 0, item, - @[0x2D'u8, 0x2D'u8, 0x2D'u8])) - test "BoundedStream(boundary) double message test [" & $item & "]": - check waitFor(boundaryTest(address, 1, item, - @[0x2D'u8, 0x2D'u8, 0x2D'u8])) - test "BoundedStream(size+boundary) reading size-bound test [" & $item & "]": - check waitFor(boundaryTest(address, 2, item, - @[0x2D'u8, 0x2D'u8, 0x2D'u8])) - test "BoundedStream(boundary) reading incomplete test [" & $item & "]": - check waitFor(boundaryTest(address, 3, item, - @[0x2D'u8, 0x2D'u8, 0x2D'u8])) + let address = initTAddress("127.0.0.1:48030") + let suffix = + case itemComp + of BoundCmp.Equal: + "== " & $itemSize + of BoundCmp.LessOrEqual: + "<= " & $itemSize + + test "BoundedStream(size) reading/writing test [" & suffix & "]": + check waitFor(boundedTest(address, 0, itemSize, itemComp)) == true + test "BoundedStream(size) overflow test [" & suffix & "]": + check waitFor(boundedTest(address, 1, itemSize, itemComp)) == true + test "BoundedStream(size) incomplete test [" & suffix & "]": + check waitFor(boundedTest(address, 2, itemSize, itemComp)) == true + test "BoundedStream(size) empty message test [" & suffix & "]": + check waitFor(boundedTest(address, 3, itemSize, itemComp)) == true + test "BoundedStream(boundary) reading test [" & suffix & "]": + check waitFor(boundaryTest(address, 0, itemSize, + @[0x2D'u8, 0x2D'u8, 0x2D'u8], itemComp)) + test "BoundedStream(boundary) double message test [" & suffix & "]": + check waitFor(boundaryTest(address, 1, itemSize, + @[0x2D'u8, 0x2D'u8, 0x2D'u8], itemComp)) + test "BoundedStream(size+boundary) reading size-bound test [" & + suffix & "]": + check waitFor(boundaryTest(address, 2, itemSize, + @[0x2D'u8, 0x2D'u8, 0x2D'u8], itemComp)) + test "BoundedStream(boundary) reading incomplete test [" & + suffix & "]": + check waitFor(boundaryTest(address, 3, itemSize, + @[0x2D'u8, 0x2D'u8, 0x2D'u8], itemComp)) + test "BoundedStream(boundary) empty message test [" & + suffix & "]": + check waitFor(boundaryTest(address, 4, itemSize, + @[0x2D'u8, 0x2D'u8, 0x2D'u8], itemComp)) test "BoundedStream leaks test": check: