Add less strict rules for BoundStream reader/writer.

This commit is contained in:
cheatfate 2021-02-09 13:56:33 +02:00 committed by zah
parent d43a9cb92d
commit 970e5641d7
2 changed files with 320 additions and 239 deletions

View File

@ -19,14 +19,19 @@ import asyncstream, ../transports/stream, ../transports/common
export asyncstream, stream, timer, common
type
BoundCmp* {.pure.} = enum
Equal, LessOrEqual
BoundedStreamReader* = ref object of AsyncStreamReader
boundSize: int
boundary: seq[byte]
offset: int
cmpop: BoundCmp
BoundedStreamWriter* = ref object of AsyncStreamWriter
boundSize: int
offset: int
cmpop: BoundCmp
BoundedStreamError* = object of AsyncStreamError
BoundedStreamIncompleteError* = object of BoundedStreamError
@ -46,18 +51,18 @@ template newBoundedStreamOverflowError*(): ref BoundedStreamError =
proc readUntilBoundary*(rstream: AsyncStreamReader, pbytes: pointer,
nbytes: int, sep: seq[byte]): Future[int] {.async.} =
doAssert(not(isNil(pbytes)), "pbytes must not be nil")
doAssert(len(sep) > 0, "separator must not be empty")
doAssert(nbytes >= 0, "nbytes must be non-negative value")
checkStreamClosed(rstream)
if nbytes == 0:
return 0
var k = 0
var state = 0
var pbuffer = cast[ptr UncheckedArray[byte]](pbytes)
var error: ref AsyncStreamIncompleteError
proc predicate(data: openarray[byte]): tuple[consumed: int, done: bool] =
if len(data) == 0:
error = newAsyncStreamIncompleteError()
(0, true)
else:
var index = 0
@ -68,19 +73,17 @@ proc readUntilBoundary*(rstream: AsyncStreamReader, pbytes: pointer,
inc(index)
pbuffer[k] = ch
inc(k)
if sep[state] == ch:
inc(state)
if state == len(sep):
break
else:
state = 0
(index, state == len(sep) or (k == nbytes))
if len(sep) > 0:
if sep[state] == ch:
inc(state)
if state == len(sep):
break
else:
state = 0
(index, (state == len(sep)) or (k == nbytes))
await rstream.readMessage(predicate)
if not isNil(error):
raise error
else:
return k
return k
func endsWith(s, suffix: openarray[byte]): bool =
var i = 0
@ -95,65 +98,55 @@ proc boundedReadLoop(stream: AsyncStreamReader) {.async.} =
rstream.state = AsyncStreamState.Running
var buffer = newSeq[byte](rstream.buffer.bufferLen())
while true:
if len(rstream.boundary) == 0:
# Only size boundary set
if rstream.offset < rstream.boundSize:
let toRead = min(rstream.boundSize - rstream.offset,
rstream.buffer.bufferLen())
try:
await rstream.rsource.readExactly(rstream.buffer.getBuffer(), toRead)
rstream.offset = rstream.offset + toRead
rstream.buffer.update(toRead)
await rstream.buffer.transfer()
except AsyncStreamIncompleteError:
rstream.state = AsyncStreamState.Error
rstream.error = newBoundedStreamIncompleteError()
except AsyncStreamReadError as exc:
rstream.state = AsyncStreamState.Error
rstream.error = exc
except CancelledError:
rstream.state = AsyncStreamState.Stopped
if rstream.state != AsyncStreamState.Running:
break
else:
rstream.state = AsyncStreamState.Finished
await rstream.buffer.transfer()
break
else:
# Sequence boundary set
if ((rstream.boundSize >= 0) and (rstream.offset < rstream.boundSize)) or
(rstream.boundSize < 0):
let toRead =
if rstream.boundSize < 0:
len(buffer)
else:
min(rstream.boundSize - rstream.offset, len(buffer))
try:
let res = await readUntilBoundary(rstream.rsource, addr buffer[0],
toRead, rstream.boundary)
if endsWith(buffer.toOpenArray(0, res - 1), rstream.boundary):
let length = res - len(rstream.boundary)
rstream.offset = rstream.offset + length
await upload(addr rstream.buffer, addr buffer[0], length)
rstream.state = AsyncStreamState.Finished
# r1 is `true` if `boundSize` was not set.
let r1 = rstream.boundSize < 0
# r2 is `true` if number of bytes read is less then `boundSize`.
let r2 = (rstream.boundSize > 0) and (rstream.offset < rstream.boundSize)
if r1 or r2:
let toRead =
if rstream.boundSize < 0:
len(buffer)
else:
min(rstream.boundSize - rstream.offset, len(buffer))
try:
let res = await readUntilBoundary(rstream.rsource, addr buffer[0],
toRead, rstream.boundary)
if res > 0:
if len(rstream.boundary) > 0:
if endsWith(buffer.toOpenArray(0, res - 1), rstream.boundary):
let length = res - len(rstream.boundary)
rstream.offset = rstream.offset + length
await upload(addr rstream.buffer, addr buffer[0], length)
rstream.state = AsyncStreamState.Finished
else:
rstream.offset = rstream.offset + res
await upload(addr rstream.buffer, addr buffer[0], res)
else:
rstream.offset = rstream.offset + res
await upload(addr rstream.buffer, addr buffer[0], res)
except AsyncStreamIncompleteError:
rstream.state = AsyncStreamState.Error
rstream.error = newBoundedStreamIncompleteError()
except AsyncStreamReadError as exc:
rstream.state = AsyncStreamState.Error
rstream.error = exc
except CancelledError:
rstream.state = AsyncStreamState.Stopped
else:
case rstream.cmpop
of BoundCmp.Equal:
rstream.state = AsyncStreamState.Error
rstream.error = newBoundedStreamIncompleteError()
of BoundCmp.LessOrEqual:
rstream.state = AsyncStreamState.Finished
if rstream.state != AsyncStreamState.Running:
break
else:
rstream.state = AsyncStreamState.Finished
except AsyncStreamReadError as exc:
rstream.state = AsyncStreamState.Error
rstream.error = exc
except CancelledError:
rstream.state = AsyncStreamState.Stopped
if rstream.state != AsyncStreamState.Running:
if rstream.state == AsyncStreamState.Finished:
# This is state when BoundCmp.LessOrEqual and readExactly returned
# `AsyncStreamIncompleteError`.
await rstream.buffer.transfer()
break
else:
rstream.state = AsyncStreamState.Finished
break
# Without this additional wait, procedures such as `read()` could got stuck
# in `await.buffer.wait()` because procedures are unable to detect EOF while
@ -191,8 +184,13 @@ proc boundedWriteLoop(stream: AsyncStreamWriter) {.async.} =
error = newBoundedStreamOverflowError()
else:
if wstream.offset != wstream.boundSize:
wstream.state = AsyncStreamState.Error
error = newBoundedStreamIncompleteError()
case wstream.cmpop
of BoundCmp.Equal:
wstream.state = AsyncStreamState.Error
error = newBoundedStreamIncompleteError()
of BoundCmp.LessOrEqual:
wstream.state = AsyncStreamState.Finished
item.future.complete()
else:
wstream.state = AsyncStreamState.Finished
item.future.complete()
@ -235,22 +233,26 @@ proc init*(child: BoundedStreamReader, rsource: AsyncStreamReader,
proc newBoundedStreamReader*[T](rsource: AsyncStreamReader,
boundSize: int,
boundary: openarray[byte] = [],
comparison = BoundCmp.Equal,
bufferSize = BoundedBufferSize,
udata: ref T): BoundedStreamReader =
doAssert(boundSize >= 0 or len(boundary) > 0,
doAssert(not(boundSize <= 0 and (len(boundary) == 0)),
"At least one type of boundary should be set")
var res = BoundedStreamReader(boundSize: boundSize, boundary: @boundary)
var res = BoundedStreamReader(boundSize: boundSize, boundary: @boundary,
cmpop: comparison)
res.init(rsource, bufferSize, udata)
res
proc newBoundedStreamReader*(rsource: AsyncStreamReader,
boundSize: int,
boundary: openarray[byte] = [],
comparison = BoundCmp.Equal,
bufferSize = BoundedBufferSize,
): BoundedStreamReader =
doAssert(boundSize >= 0 or len(boundary) > 0,
doAssert(not(boundSize <= 0 and (len(boundary) == 0)),
"At least one type of boundary should be set")
var res = BoundedStreamReader(boundSize: boundSize, boundary: @boundary)
var res = BoundedStreamReader(boundSize: boundSize, boundary: @boundary,
cmpop: comparison)
res.init(rsource, bufferSize)
res
@ -265,16 +267,20 @@ proc init*(child: BoundedStreamWriter, wsource: AsyncStreamWriter,
proc newBoundedStreamWriter*[T](wsource: AsyncStreamWriter,
boundSize: int,
comparison = BoundCmp.Equal,
queueSize = AsyncStreamDefaultQueueSize,
udata: ref T): BoundedStreamWriter =
var res = BoundedStreamWriter(boundSize: boundSize)
doAssert(boundSize > 0, "Bound size must be bigger then zero")
var res = BoundedStreamWriter(boundSize: boundSize, cmpop: comparison)
res.init(wsource, queueSize, udata)
res
proc newBoundedStreamWriter*(wsource: AsyncStreamWriter,
boundSize: int,
comparison = BoundCmp.Equal,
queueSize = AsyncStreamDefaultQueueSize,
): BoundedStreamWriter =
var res = BoundedStreamWriter(boundSize: boundSize)
doAssert(boundSize > 0, "Bound size must be bigger then zero")
var res = BoundedStreamWriter(boundSize: boundSize, cmpop: comparison)
res.init(wsource, queueSize)
res

View File

@ -8,7 +8,7 @@
import unittest
import ../chronos
import ../chronos/streams/[tlsstream, chunkstream, boundstream]
import nimcrypto/utils
when defined(nimHasUsed): {.used.}
# To create self-signed certificate and key you can use openssl
@ -672,184 +672,259 @@ suite "TLSStream test suite":
suite "BoundedStream test suite":
proc createBigMessage(size: int): seq[byte] =
var message = "MESSAGE"
result = newSeq[byte](size)
var message = "ABCDEFGHIJKLMNOP"
var res = newSeq[byte](size)
for i in 0 ..< len(result):
result[i] = byte(message[i mod len(message)])
res[i] = byte(message[i mod len(message)])
res
for itemComp in [BoundCmp.Equal, BoundCmp.LessOrEqual]:
for itemSize in [100, 60000]:
for item in [100, 60000]:
proc boundaryTest(address: TransportAddress, test: int, size: int,
boundary: seq[byte],
cmp: BoundCmp): Future[bool] {.async.} =
var message = createBigMessage(size)
var clientRes = false
proc boundaryTest(address: TransportAddress, test: int, size: int,
boundary: seq[byte]): Future[bool] {.async.} =
var message = createBigMessage(size)
var clientRes = false
proc processClient(server: StreamServer,
transp: StreamTransport) {.async.} =
var wstream = newAsyncStreamWriter(transp)
if test == 0:
await wstream.write(message)
await wstream.write(boundary)
await wstream.finish()
await wstream.closeWait()
clientRes = true
elif test == 1:
await wstream.write(message)
await wstream.write(boundary)
await wstream.write(message)
await wstream.finish()
await wstream.closeWait()
clientRes = true
elif test == 2:
var ncmessage = message
ncmessage.setLen(len(message) - 2)
await wstream.write(ncmessage)
await wstream.write(@[0x2D'u8, 0x2D'u8])
await wstream.finish()
await wstream.closeWait()
clientRes = true
elif test == 3:
var ncmessage = message
ncmessage.setLen(len(message) - 2)
await wstream.write(ncmessage)
await wstream.finish()
await wstream.closeWait()
clientRes = true
elif test == 4:
await wstream.write(boundary)
await wstream.finish()
await wstream.closeWait()
clientRes = true
proc processClient(server: StreamServer,
transp: StreamTransport) {.async.} =
var wstream = newAsyncStreamWriter(transp)
await transp.closeWait()
server.stop()
server.close()
var res = false
let flags = {ServerFlags.ReuseAddr, ServerFlags.TcpNoDelay}
var server = createStreamServer(address, processClient, flags = flags)
server.start()
var conn = await connect(address)
var rstream = newAsyncStreamReader(conn)
if test == 0:
await wstream.write(message)
await wstream.write(boundary)
await wstream.finish()
await wstream.closeWait()
clientRes = true
var rbstream = newBoundedStreamReader(rstream, -1, boundary)
let response = await rbstream.read()
if response == message:
res = true
await rbstream.closeWait()
elif test == 1:
await wstream.write(message)
await wstream.write(boundary)
await wstream.write(message)
await wstream.finish()
await wstream.closeWait()
clientRes = true
var rbstream = newBoundedStreamReader(rstream, -1, boundary)
let response1 = await rbstream.read()
await rbstream.closeWait()
let response2 = await rstream.read()
if (response1 == message) and (response2 == message):
res = true
elif test == 2:
var ncmessage = message
ncmessage.setLen(len(message) - 2)
await wstream.write(ncmessage)
await wstream.write(@[0x2D'u8, 0x2D'u8])
await wstream.finish()
await wstream.closeWait()
clientRes = true
var expectMessage = message
expectMessage[^2] = 0x2D'u8
expectMessage[^1] = 0x2D'u8
var rbstream = newBoundedStreamReader(rstream, size, boundary)
let response = await rbstream.read()
await rbstream.closeWait()
if (len(response) == size) and response == expectMessage:
res = true
elif test == 3:
var ncmessage = message
ncmessage.setLen(len(message) - 2)
await wstream.write(ncmessage)
await wstream.finish()
await wstream.closeWait()
clientRes = true
await transp.closeWait()
server.stop()
server.close()
var res = false
var server = createStreamServer(address, processClient,
flags = {ReuseAddr})
server.start()
var conn = await connect(address)
var rstream = newAsyncStreamReader(conn)
if test == 0:
var rbstream = newBoundedStreamReader(rstream, -1, boundary)
let response = await rbstream.read()
if response == message:
res = true
await rbstream.closeWait()
elif test == 1:
var rbstream = newBoundedStreamReader(rstream, -1, boundary)
let response1 = await rbstream.read()
await rbstream.closeWait()
let response2 = await rstream.read()
if (response1 == message) and (response2 == message):
res = true
elif test == 2:
var expectMessage = message
expectMessage[^2] = 0x2D'u8
expectMessage[^1] = 0x2D'u8
var rbstream = newBoundedStreamReader(rstream, size, boundary)
let response = await rbstream.read()
await rbstream.closeWait()
if (len(response) == size) and response == expectMessage:
res = true
elif test == 3:
var rbstream = newBoundedStreamReader(rstream, -1, boundary)
try:
let response {.used.} = await rbstream.read()
except BoundedStreamIncompleteError:
res = true
await rbstream.closeWait()
await rstream.closeWait()
await conn.closeWait()
await server.join()
return (res and clientRes)
proc boundedTest(address: TransportAddress, test: int,
size: int): Future[bool] {.async.} =
var clientRes = false
var res = false
let messagePart = createBigMessage(int(item) div 10)
var message: seq[byte]
for i in 0 ..< 10:
message.add(messagePart)
proc processClient(server: StreamServer,
transp: StreamTransport) {.async.} =
var wstream = newAsyncStreamWriter(transp)
var wbstream = newBoundedStreamWriter(wstream, size)
if test == 0:
for i in 0 ..< 10:
await wbstream.write(messagePart)
await wbstream.finish()
await wbstream.closeWait()
clientRes = true
elif test == 1:
for i in 0 ..< 10:
await wbstream.write(messagePart)
var rbstream = newBoundedStreamReader(rstream, -1, boundary)
try:
await wbstream.write(messagePart)
except BoundedStreamOverflowError:
clientRes = true
await wbstream.closeWait()
elif test == 2:
for i in 0 ..< 9:
await wbstream.write(messagePart)
try:
await wbstream.finish()
let response {.used.} = await rbstream.read()
except BoundedStreamIncompleteError:
res = true
await rbstream.closeWait()
elif test == 4:
var rbstream = newBoundedStreamReader(rstream, -1, boundary)
let response = await rbstream.read()
await rbstream.closeWait()
if len(response) == 0:
res = true
await rstream.closeWait()
await conn.closeWait()
await server.join()
return (res and clientRes)
proc boundedTest(address: TransportAddress, test: int,
size: int, cmp: BoundCmp): Future[bool] {.async.} =
var clientRes = false
var res = false
let messagePart = createBigMessage(int(itemSize) div 10)
var message: seq[byte]
for i in 0 ..< 10:
message.add(messagePart)
proc processClient(server: StreamServer,
transp: StreamTransport) {.async.} =
var wstream = newAsyncStreamWriter(transp)
var wbstream = newBoundedStreamWriter(wstream, size, comparison = cmp)
if test == 0:
for i in 0 ..< 10:
await wbstream.write(messagePart)
await wbstream.finish()
await wbstream.closeWait()
clientRes = true
await wbstream.closeWait()
elif test == 1:
for i in 0 ..< 10:
await wbstream.write(messagePart)
try:
await wbstream.write(messagePart)
except BoundedStreamOverflowError:
clientRes = true
await wbstream.closeWait()
elif test == 2:
for i in 0 ..< 9:
await wbstream.write(messagePart)
case cmp
of BoundCmp.Equal:
try:
await wbstream.finish()
except BoundedStreamIncompleteError:
clientRes = true
of BoundCmp.LessOrEqual:
try:
await wbstream.finish()
clientRes = true
except BoundedStreamIncompleteError:
discard
await wbstream.closeWait()
elif test == 3:
case cmp
of BoundCmp.Equal:
try:
await wbstream.finish()
except BoundedStreamIncompleteError:
clientRes = true
of BoundCmp.LessOrEqual:
try:
await wbstream.finish()
clientRes = true
except BoundedStreamIncompleteError:
discard
await wbstream.closeWait()
await wstream.closeWait()
await transp.closeWait()
server.stop()
server.close()
await wstream.closeWait()
await transp.closeWait()
server.stop()
server.close()
var server = createStreamServer(address, processClient, flags = {ReuseAddr})
server.start()
var conn = await connect(address)
var rstream = newAsyncStreamReader(conn)
var rbstream = newBoundedStreamReader(rstream, size)
if test == 0:
let response = await rbstream.read()
await rbstream.closeWait()
if response == message:
res = true
elif test == 1:
let response = await rbstream.read()
await rbstream.closeWait()
if response == message:
res = true
elif test == 2:
try:
let response {.used.} = await rbstream.read()
except BoundedStreamIncompleteError:
res = true
await rbstream.closeWait()
let flags = {ServerFlags.ReuseAddr, ServerFlags.TcpNoDelay}
var server = createStreamServer(address, processClient, flags = flags)
server.start()
var conn = await connect(address)
var rstream = newAsyncStreamReader(conn)
var rbstream = newBoundedStreamReader(rstream, size, comparison = cmp)
if test == 0:
let response = await rbstream.read()
await rbstream.closeWait()
if response == message:
res = true
elif test == 1:
let response = await rbstream.read()
await rbstream.closeWait()
if response == message:
res = true
elif test == 2:
case cmp
of BoundCmp.Equal:
try:
let response {.used.} = await rbstream.read()
except BoundedStreamIncompleteError:
res = true
of BoundCmp.LessOrEqual:
try:
let response = await rbstream.read()
if len(response) == 9 * len(messagePart):
res = true
except BoundedStreamIncompleteError:
res = false
await rbstream.closeWait()
elif test == 3:
case cmp
of BoundCmp.Equal:
try:
let response {.used.} = await rbstream.read()
except BoundedStreamIncompleteError:
res = true
of BoundCmp.LessOrEqual:
try:
let response = await rbstream.read()
if len(response) == 0:
res = true
except BoundedStreamIncompleteError:
res = false
await rbstream.closeWait()
await rstream.closeWait()
await conn.closeWait()
await server.join()
return (res and clientRes)
await rstream.closeWait()
await conn.closeWait()
await server.join()
return (res and clientRes)
let address = initTAddress("127.0.0.1:48030")
test "BoundedStream(size) reading/writing test [" & $item & "]":
check waitFor(boundedTest(address, 0, item)) == true
test "BoundedStream(size) overflow test [" & $item & "]":
check waitFor(boundedTest(address, 1, item)) == true
test "BoundedStream(size) incomplete test [" & $item & "]":
check waitFor(boundedTest(address, 2, item)) == true
test "BoundedStream(boundary) reading test [" & $item & "]":
check waitFor(boundaryTest(address, 0, item,
@[0x2D'u8, 0x2D'u8, 0x2D'u8]))
test "BoundedStream(boundary) double message test [" & $item & "]":
check waitFor(boundaryTest(address, 1, item,
@[0x2D'u8, 0x2D'u8, 0x2D'u8]))
test "BoundedStream(size+boundary) reading size-bound test [" & $item & "]":
check waitFor(boundaryTest(address, 2, item,
@[0x2D'u8, 0x2D'u8, 0x2D'u8]))
test "BoundedStream(boundary) reading incomplete test [" & $item & "]":
check waitFor(boundaryTest(address, 3, item,
@[0x2D'u8, 0x2D'u8, 0x2D'u8]))
let address = initTAddress("127.0.0.1:48030")
let suffix =
case itemComp
of BoundCmp.Equal:
"== " & $itemSize
of BoundCmp.LessOrEqual:
"<= " & $itemSize
test "BoundedStream(size) reading/writing test [" & suffix & "]":
check waitFor(boundedTest(address, 0, itemSize, itemComp)) == true
test "BoundedStream(size) overflow test [" & suffix & "]":
check waitFor(boundedTest(address, 1, itemSize, itemComp)) == true
test "BoundedStream(size) incomplete test [" & suffix & "]":
check waitFor(boundedTest(address, 2, itemSize, itemComp)) == true
test "BoundedStream(size) empty message test [" & suffix & "]":
check waitFor(boundedTest(address, 3, itemSize, itemComp)) == true
test "BoundedStream(boundary) reading test [" & suffix & "]":
check waitFor(boundaryTest(address, 0, itemSize,
@[0x2D'u8, 0x2D'u8, 0x2D'u8], itemComp))
test "BoundedStream(boundary) double message test [" & suffix & "]":
check waitFor(boundaryTest(address, 1, itemSize,
@[0x2D'u8, 0x2D'u8, 0x2D'u8], itemComp))
test "BoundedStream(size+boundary) reading size-bound test [" &
suffix & "]":
check waitFor(boundaryTest(address, 2, itemSize,
@[0x2D'u8, 0x2D'u8, 0x2D'u8], itemComp))
test "BoundedStream(boundary) reading incomplete test [" &
suffix & "]":
check waitFor(boundaryTest(address, 3, itemSize,
@[0x2D'u8, 0x2D'u8, 0x2D'u8], itemComp))
test "BoundedStream(boundary) empty message test [" &
suffix & "]":
check waitFor(boundaryTest(address, 4, itemSize,
@[0x2D'u8, 0x2D'u8, 0x2D'u8], itemComp))
test "BoundedStream leaks test":
check: