diff --git a/chronos/streams/chunkstream.nim b/chronos/streams/chunkstream.nim index bca01ac9..598d641e 100644 --- a/chronos/streams/chunkstream.nim +++ b/chronos/streams/chunkstream.nim @@ -15,7 +15,8 @@ export asyncstream, stream, timer, common const ChunkBufferSize = 4096 - ChunkHeaderSize = 8 + MaxChunkHeaderSize = 1024 + ChunkHeaderValueSize = 8 # This is limit for chunk size to 8 hexadecimal digits, so maximum # chunk size for this implementation become: # 2^32 == FFFF_FFFF'u32 == 4,294,967,295 bytes. @@ -49,9 +50,9 @@ proc hexValue*(c: byte): int = proc getChunkSize(buffer: openarray[byte]): Result[uint64, cstring] = # We using `uint64` representation, but allow only 2^32 chunk size, - # ChunkHeaderSize. + # ChunkHeaderValueSize. var res = 0'u64 - for i in 0 ..< min(len(buffer), ChunkHeaderSize + 1): + for i in 0 ..< min(len(buffer), ChunkHeaderValueSize + 1): let value = hexValue(buffer[i]) if value < 0: if buffer[i] == byte(';'): @@ -60,7 +61,7 @@ proc getChunkSize(buffer: openarray[byte]): Result[uint64, cstring] = else: return err("Incorrect chunk size encoding") else: - if i >= ChunkHeaderSize: + if i >= ChunkHeaderValueSize: return err("The chunk size exceeds the limit") res = (res shl 4) or uint64(value) ok(res) @@ -96,7 +97,7 @@ proc setChunkSize(buffer: var openarray[byte], length: int64): int = proc chunkedReadLoop(stream: AsyncStreamReader) {.async.} = var rstream = ChunkedStreamReader(stream) - var buffer = newSeq[byte](1024) + var buffer = newSeq[byte](MaxChunkHeaderSize) rstream.state = AsyncStreamState.Running while true: @@ -137,6 +138,10 @@ proc chunkedReadLoop(stream: AsyncStreamReader) {.async.} = await rstream.buffer.transfer() except CancelledError: rstream.state = AsyncStreamState.Stopped + except AsyncStreamLimitError: + rstream.state = AsyncStreamState.Error + rstream.error = newException(ChunkedStreamProtocolError, + "Chunk header exceeds maximum size") except AsyncStreamIncompleteError: rstream.state = AsyncStreamState.Error rstream.error = newException(ChunkedStreamIncompleteError, diff --git a/tests/testasyncstream.nim b/tests/testasyncstream.nim index 7a4f2112..39abbe45 100644 --- a/tests/testasyncstream.nim +++ b/tests/testasyncstream.nim @@ -485,6 +485,12 @@ suite "AsyncStream test suite": getTracker("stream.transport").isLeaked() == false suite "ChunkedStream test suite": + proc createBigMessage(message: string, size: int): string = + var res = newString(size) + for i in 0 ..< len(res): + res[i] = chr(ord(message[i mod len(message)])) + res + test "ChunkedStream test vectors": const ChunkedVectors = [ ["4\r\nWiki\r\n5\r\npedia\r\nE\r\n in\r\n\r\nchunks.\r\n0\r\n\r\n", @@ -642,6 +648,91 @@ suite "ChunkedStream test suite": else: check hexValue(byte(ch)) == -1 + test "ChunkedStream too big chunk header test": + proc checkTooBigChunkHeader(address: TransportAddress, + inputstr: string): Future[bool] {.async.} = + proc serveClient(server: StreamServer, + transp: StreamTransport) {.async.} = + var wstream = newAsyncStreamWriter(transp) + var data = inputstr + await wstream.write(data) + await wstream.finish() + await wstream.closeWait() + await transp.closeWait() + server.stop() + server.close() + + var server = createStreamServer(address, serveClient, {ReuseAddr}) + server.start() + var transp = await connect(address) + var rstream = newAsyncStreamReader(transp) + var rstream2 = newChunkedStreamReader(rstream) + let res = + try: + var data {.used.} = await rstream2.read() + false + except ChunkedStreamProtocolError: + true + except CatchableError: + false + await rstream2.closeWait() + await rstream.closeWait() + await transp.closeWait() + await server.join() + return res + + let address = initTAddress("127.0.0.1:46001") + var data1 = createBigMessage("REQUESTSTREAMMESSAGE", 65600) + var data2 = createBigMessage("REQUESTSTREAMMESSAGE", 262400) + check waitFor(checkTooBigChunkHeader(address, data1)) == true + check waitFor(checkTooBigChunkHeader(address, data2)) == true + + test "ChunkedStream read/write test": + proc checkVector(address: TransportAddress, + inputstr: string, chsize: int): Future[string] {.async.} = + proc serveClient(server: StreamServer, + transp: StreamTransport) {.async.} = + var wstream = newAsyncStreamWriter(transp) + var wstream2 = newChunkedStreamWriter(wstream) + var data = inputstr + var offset = 0 + while true: + if len(data) == offset: + break + let toWrite = min(chsize, len(data) - offset) + await wstream2.write(addr data[offset], toWrite) + offset = offset + toWrite + await wstream2.finish() + await wstream2.closeWait() + await wstream.closeWait() + await transp.closeWait() + server.stop() + server.close() + + var server = createStreamServer(address, serveClient, {ReuseAddr}) + server.start() + var transp = await connect(address) + var rstream = newAsyncStreamReader(transp) + var rstream2 = newChunkedStreamReader(rstream) + var res = await rstream2.read() + var ress = cast[string](res) + await rstream2.closeWait() + await rstream.closeWait() + await transp.closeWait() + await server.join() + result = ress + + proc testBigData(address: TransportAddress, + datasize: int, chunksize: int): Future[bool] {.async.} = + var data = createBigMessage("REQUESTSTREAMMESSAGE", datasize) + var check = await checkVector(address, data, chunksize) + return (data == check) + + let address = initTAddress("127.0.0.1:46001") + check waitFor(testBigData(address, 65600, 1024)) == true + check waitFor(testBigData(address, 262400, 4096)) == true + check waitFor(testBigData(address, 767309, 4457)) == true + test "ChunkedStream leaks test": check: getTracker("async.stream.reader").isLeaked() == false