Big refactoring of AsyncStreams.

1. Implement all read() primitives using readLoop() like it was done in streams.
2. Fix readLine() bug.
3. Add readMessage() primitive.
4. Fixing exception hierarchy, handling code and simplification of (break/continue + exception).
5. Fix TLSStream closure procedure.
6. Add BoundedStream stream and tests.
7. Remove `result` usage from the code.
This commit is contained in:
cheatfate 2021-01-20 15:40:15 +02:00 committed by zah
parent 39456e9c18
commit 0cb6840f03
5 changed files with 1086 additions and 692 deletions

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,212 @@
#
# Chronos Asynchronous Bound Stream
# (c) Copyright 2021-Present
# Status Research & Development GmbH
#
# Licensed under either of
# Apache License, version 2.0, (LICENSE-APACHEv2)
# MIT license (LICENSE-MIT)
## This module implements bounded stream reading and writing.
##
## For stream reading it means that you should read exactly bounded size of
## bytes.
##
## For stream writing it means that you should write exactly bounded size
## of bytes, and if you wrote not enough bytes error will appear on stream
## close.
import ../asyncloop, ../timer
import asyncstream, ../transports/stream, ../transports/common
export asyncstream, stream, timer, common
type
BoundedStreamReader* = ref object of AsyncStreamReader
boundSize: uint64
offset: uint64
BoundedStreamWriter* = ref object of AsyncStreamWriter
boundSize: uint64
offset: uint64
BoundedStreamError* = object of AsyncStreamError
BoundedStreamIncompleteError* = object of BoundedStreamError
BoundedStreamOverflowError* = object of BoundedStreamError
BoundedStreamRW* = BoundedStreamReader | BoundedStreamWriter
const
BoundedBufferSize* = 4096
template newBoundedStreamIncompleteError*(): ref BoundedStreamError =
newException(BoundedStreamIncompleteError,
"Stream boundary is not reached yet")
template newBoundedStreamOverflowError*(): ref BoundedStreamError =
newException(BoundedStreamOverflowError, "Stream boundary exceeded")
proc boundedReadLoop(stream: AsyncStreamReader) {.async.} =
var rstream = cast[BoundedStreamReader](stream)
rstream.state = AsyncStreamState.Running
while true:
if rstream.offset < rstream.boundSize:
let toRead = int(min(rstream.boundSize - rstream.offset,
uint64(rstream.buffer.bufferLen())))
try:
await rstream.rsource.readExactly(rstream.buffer.getBuffer(), toRead)
rstream.offset = rstream.offset + uint64(toRead)
rstream.buffer.update(toRead)
await rstream.buffer.transfer()
except AsyncStreamIncompleteError:
rstream.state = AsyncStreamState.Error
rstream.error = newBoundedStreamIncompleteError()
except AsyncStreamReadError as exc:
rstream.state = AsyncStreamState.Error
rstream.error = exc
except CancelledError:
rstream.state = AsyncStreamState.Stopped
if rstream.state != AsyncStreamState.Running:
break
else:
rstream.state = AsyncStreamState.Finished
await rstream.buffer.transfer()
break
if rstream.state in {AsyncStreamState.Stopped, AsyncStreamState.Error}:
# We need to notify consumer about error/close, but we do not care about
# incoming data anymore.
rstream.buffer.forget()
proc boundedWriteLoop(stream: AsyncStreamWriter) {.async.} =
var wstream = cast[BoundedStreamWriter](stream)
wstream.state = AsyncStreamState.Running
while true:
var
item: WriteItem
error: ref AsyncStreamError
try:
item = await wstream.queue.get()
if item.size > 0:
if uint64(item.size) <= (wstream.boundSize - wstream.offset):
# Writing chunk data.
case item.kind
of WriteType.Pointer:
await wstream.wsource.write(item.data1, item.size)
of WriteType.Sequence:
await wstream.wsource.write(addr item.data2[0], item.size)
of WriteType.String:
await wstream.wsource.write(addr item.data3[0], item.size)
wstream.offset = wstream.offset + uint64(item.size)
item.future.complete()
else:
wstream.state = AsyncStreamState.Error
error = newBoundedStreamOverflowError()
else:
if wstream.offset != wstream.boundSize:
wstream.state = AsyncStreamState.Error
error = newBoundedStreamIncompleteError()
else:
wstream.state = AsyncStreamState.Finished
item.future.complete()
except CancelledError:
wstream.state = AsyncStreamState.Stopped
error = newAsyncStreamUseClosedError()
except AsyncStreamWriteError as exc:
wstream.state = AsyncStreamState.Error
error = exc
except AsyncStreamIncompleteError as exc:
wstream.state = AsyncStreamState.Error
error = exc
if wstream.state != AsyncStreamState.Running:
if wstream.state == AsyncStreamState.Finished:
error = newAsyncStreamUseClosedError()
else:
if not(isNil(item.future)):
if not(item.future.finished()):
item.future.fail(error)
while not(wstream.queue.empty()):
let pitem = wstream.queue.popFirstNoWait()
if not(pitem.future.finished()):
pitem.future.fail(error)
break
proc bytesLeft*(stream: BoundedStreamRW): uint64 =
## Returns number of bytes left in stream.
stream.boundSize - stream.bytesCount
proc init*[T](child: BoundedStreamReader, rsource: AsyncStreamReader,
bufferSize = BoundedBufferSize, udata: ref T) =
init(cast[AsyncStreamReader](child), rsource, boundedReadLoop, bufferSize,
udata)
proc init*(child: BoundedStreamReader, rsource: AsyncStreamReader,
bufferSize = BoundedBufferSize) =
init(cast[AsyncStreamReader](child), rsource, boundedReadLoop, bufferSize)
proc newBoundedStreamReader*[T](rsource: AsyncStreamReader,
boundSize: uint64,
bufferSize = BoundedBufferSize,
udata: ref T): BoundedStreamReader =
var res = BoundedStreamReader(boundSize: boundSize)
res.init(rsource, bufferSize, udata)
res
proc newBoundedStreamReader*(rsource: AsyncStreamReader,
boundSize: uint64,
bufferSize = BoundedBufferSize,
): BoundedStreamReader =
doAssert(boundSize >= 0)
var res = BoundedStreamReader(boundSize: boundSize)
res.init(rsource, bufferSize)
res
proc init*[T](child: BoundedStreamWriter, wsource: AsyncStreamWriter,
queueSize = AsyncStreamDefaultQueueSize, udata: ref T) =
init(cast[AsyncStreamWriter](child), wsource, boundedWriteLoop, queueSize,
udata)
proc init*(child: BoundedStreamWriter, wsource: AsyncStreamWriter,
queueSize = AsyncStreamDefaultQueueSize) =
init(cast[AsyncStreamWriter](child), wsource, boundedWriteLoop, queueSize)
proc newBoundedStreamWriter*[T](wsource: AsyncStreamWriter,
boundSize: uint64,
queueSize = AsyncStreamDefaultQueueSize,
udata: ref T): BoundedStreamWriter =
var res = BoundedStreamWriter(boundSize: boundSize)
res.init(wsource, queueSize, udata)
res
proc newBoundedStreamWriter*(wsource: AsyncStreamWriter,
boundSize: uint64,
queueSize = AsyncStreamDefaultQueueSize,
): BoundedStreamWriter =
var res = BoundedStreamWriter(boundSize: boundSize)
res.init(wsource, queueSize)
res
proc close*(rw: BoundedStreamRW) =
## Close and frees resources of stream ``rw``.
##
## Note close() procedure is not completed immediately.
if rw.closed():
raise newAsyncStreamIncorrectError("Stream is already closed!")
# We do not want to raise one more IncompleteError if it was already raised
# by one of the read()/write() primitives.
if rw.state != AsyncStreamState.Error:
if rw.bytesLeft() != 0'u64:
raise newBoundedStreamIncompleteError()
when rw is BoundedStreamReader:
cast[AsyncStreamReader](rw).close()
elif rw is BoundedStreamWriter:
cast[AsyncStreamWriter](rw).close()
proc closeWait*(rw: BoundedStreamRW): Future[void] =
## Close and frees resources of stream ``rw``.
rw.close()
when rw is BoundedStreamReader:
cast[AsyncStreamReader](rw).join()
elif rw is BoundedStreamWriter:
cast[AsyncStreamWriter](rw).join()

View File

@ -21,25 +21,44 @@ type
ChunkedStreamReader* = ref object of AsyncStreamReader
ChunkedStreamWriter* = ref object of AsyncStreamWriter
ChunkedStreamError* = object of CatchableError
ChunkedStreamError* = object of AsyncStreamError
ChunkedStreamProtocolError* = object of ChunkedStreamError
ChunkedStreamIncompleteError* = object of ChunkedStreamError
proc newProtocolError(): ref Exception {.inline.} =
proc newChunkedProtocolError(): ref ChunkedStreamProtocolError {.inline.} =
newException(ChunkedStreamProtocolError, "Protocol error!")
proc newChunkedIncompleteError(): ref ChunkedStreamIncompleteError {.inline.} =
newException(ChunkedStreamIncompleteError, "Incomplete data received!")
proc `-`(x: uint32): uint32 {.inline.} =
result = (0xFFFF_FFFF'u32 - x) + 1'u32
proc LT(x, y: uint32): uint32 {.inline.} =
let z = x - y
(z xor ((y xor x) and (y xor z))) shr 31
proc hexValue(c: byte): int =
let x = uint32(c) - 0x30'u32
let y = uint32(c) - 0x41'u32
let z = uint32(c) - 0x61'u32
let r = ((x + 1'u32) and -LT(x, 10)) or
((y + 11'u32) and -LT(y, 6)) or
((z + 11'u32) and -LT(z, 6))
int(r) - 1
proc getChunkSize(buffer: openarray[byte]): uint64 =
# We using `uint64` representation, but allow only 2^32 chunk size,
# ChunkHeaderSize.
var res = 0'u64
for i in 0..<min(len(buffer), ChunkHeaderSize):
let ch = buffer[i]
if char(ch) in {'0'..'9', 'a'..'f', 'A'..'F'}:
if ch >= byte('0') and ch <= byte('9'):
result = (result shl 4) or uint64(ch - byte('0'))
else:
result = (result shl 4) or uint64((ch and 0x0F) + 9)
let value = hexValue(buffer[i])
if value >= 0:
res = (res shl 4) or uint64(value)
else:
result = 0xFFFF_FFFF_FFFF_FFFF'u64
res = 0xFFFF_FFFF_FFFF_FFFF'u64
break
res
proc setChunkSize(buffer: var openarray[byte], length: int64): int =
# Store length as chunk header size (hexadecimal value) with CRLF.
@ -53,7 +72,7 @@ proc setChunkSize(buffer: var openarray[byte], length: int64): int =
buffer[0] = byte('0')
buffer[1] = byte(0x0D)
buffer[2] = byte(0x0A)
result = 3
3
else:
while n != 0:
var v = length and n
@ -68,161 +87,116 @@ proc setChunkSize(buffer: var openarray[byte], length: int64): int =
i = i - 4
buffer[c] = byte(0x0D)
buffer[c + 1] = byte(0x0A)
result = c + 2
c + 2
proc chunkedReadLoop(stream: AsyncStreamReader) {.async.} =
var rstream = cast[ChunkedStreamReader](stream)
var buffer = newSeq[byte](1024)
rstream.state = AsyncStreamState.Running
try:
while true:
while true:
try:
# Reading chunk size
var ruFut1 = awaitne rstream.rsource.readUntil(addr buffer[0], 1024, CRLF)
if ruFut1.failed():
rstream.error = ruFut1.error
rstream.state = AsyncStreamState.Error
break
let length = ruFut1.read()
var chunksize = getChunkSize(buffer.toOpenArray(0,
length - len(CRLF) - 1))
let res = await rstream.rsource.readUntil(addr buffer[0], 1024, CRLF)
var chunksize = getChunkSize(buffer.toOpenArray(0, res - len(CRLF) - 1))
if chunksize == 0xFFFF_FFFF_FFFF_FFFF'u64:
rstream.error = newProtocolError()
rstream.error = newChunkedProtocolError()
rstream.state = AsyncStreamState.Error
break
elif chunksize > 0'u64:
while chunksize > 0'u64:
let toRead = min(int(chunksize), rstream.buffer.bufferLen())
var reFut2 = awaitne rstream.rsource.readExactly(
rstream.buffer.getBuffer(), toRead)
if reFut2.failed():
rstream.error = reFut2.error
rstream.state = AsyncStreamState.Error
break
await rstream.rsource.readExactly(rstream.buffer.getBuffer(), toRead)
rstream.buffer.update(toRead)
await rstream.buffer.transfer()
chunksize = chunksize - uint64(toRead)
if rstream.state != AsyncStreamState.Running:
break
if rstream.state == AsyncStreamState.Running:
# Reading chunk trailing CRLF
await rstream.rsource.readExactly(addr buffer[0], 2)
# Reading chunk trailing CRLF
var reFut3 = awaitne rstream.rsource.readExactly(addr buffer[0], 2)
if reFut3.failed():
rstream.error = reFut3.error
rstream.state = AsyncStreamState.Error
break
if buffer[0] != CRLF[0] or buffer[1] != CRLF[1]:
rstream.error = newProtocolError()
rstream.state = AsyncStreamState.Error
break
if buffer[0] != CRLF[0] or buffer[1] != CRLF[1]:
rstream.error = newChunkedProtocolError()
rstream.state = AsyncStreamState.Error
else:
# Reading trailing line for last chunk
var ruFut4 = awaitne rstream.rsource.readUntil(addr buffer[0],
len(buffer), CRLF)
if ruFut4.failed():
rstream.error = ruFut4.error
rstream.state = AsyncStreamState.Error
break
discard await rstream.rsource.readUntil(addr buffer[0],
len(buffer), CRLF)
rstream.state = AsyncStreamState.Finished
await rstream.buffer.transfer()
break
except CancelledError:
rstream.state = AsyncStreamState.Stopped
except AsyncStreamIncompleteError:
rstream.state = AsyncStreamState.Error
rstream.error = newChunkedIncompleteError()
except AsyncStreamReadError as exc:
rstream.state = AsyncStreamState.Error
rstream.error = exc
except CancelledError:
rstream.state = AsyncStreamState.Stopped
finally:
if rstream.state in {AsyncStreamState.Stopped, AsyncStreamState.Error}:
# We need to notify consumer about error/close, but we do not care about
# incoming data anymore.
rstream.buffer.forget()
break
proc chunkedWriteLoop(stream: AsyncStreamWriter) {.async.} =
var wstream = cast[ChunkedStreamWriter](stream)
var buffer: array[16, byte]
var wFut1, wFut2: Future[void]
var error: ref Exception
var error: ref AsyncStreamError
wstream.state = AsyncStreamState.Running
try:
while true:
# Getting new item from stream's queue.
var item = await wstream.queue.get()
while true:
var item: WriteItem
# Getting new item from stream's queue.
try:
item = await wstream.queue.get()
# `item.size == 0` is marker of stream finish, while `item.size != 0` is
# data's marker.
if item.size > 0:
let length = setChunkSize(buffer, int64(item.size))
# Writing chunk header <length>CRLF.
wFut1 = awaitne wstream.wsource.write(addr buffer[0], length)
if wFut1.failed():
error = wFut1.error
item.future.fail(error)
continue
await wstream.wsource.write(addr buffer[0], length)
# Writing chunk data.
if item.kind == Pointer:
wFut2 = awaitne wstream.wsource.write(item.data1, item.size)
elif item.kind == Sequence:
wFut2 = awaitne wstream.wsource.write(addr item.data2[0], item.size)
elif item.kind == String:
wFut2 = awaitne wstream.wsource.write(addr item.data3[0], item.size)
if wFut2.failed():
error = wFut2.error
item.future.fail(error)
continue
case item.kind
of WriteType.Pointer:
await wstream.wsource.write(item.data1, item.size)
of WriteType.Sequence:
await wstream.wsource.write(addr item.data2[0], item.size)
of WriteType.String:
await wstream.wsource.write(addr item.data3[0], item.size)
# Writing chunk footer CRLF.
var wFut3 = awaitne wstream.wsource.write(CRLF)
if wFut3.failed():
error = wFut3.error
item.future.fail(error)
continue
await wstream.wsource.write(CRLF)
# Everything is fine, completing queue item's future.
item.future.complete()
else:
let length = setChunkSize(buffer, 0'i64)
# Write finish chunk `0`.
wFut1 = awaitne wstream.wsource.write(addr buffer[0], length)
if wFut1.failed():
error = wFut1.error
item.future.fail(error)
# We break here, because this is last chunk
break
await wstream.wsource.write(addr buffer[0], length)
# Write trailing CRLF.
wFut2 = awaitne wstream.wsource.write(CRLF)
if wFut2.failed():
error = wFut2.error
item.future.fail(error)
# We break here, because this is last chunk
break
await wstream.wsource.write(CRLF)
# Everything is fine, completing queue item's future.
item.future.complete()
# Set stream state to Finished.
wstream.state = AsyncStreamState.Finished
break
except CancelledError:
wstream.state = AsyncStreamState.Stopped
finally:
if wstream.state == AsyncStreamState.Stopped:
while len(wstream.queue) > 0:
let item = wstream.queue.popFirstNoWait()
if not(item.future.finished()):
item.future.complete()
elif wstream.state == AsyncStreamState.Error:
while len(wstream.queue) > 0:
let item = wstream.queue.popFirstNoWait()
if not(item.future.finished()):
if not isNil(error):
except CancelledError:
wstream.state = AsyncStreamState.Stopped
error = newAsyncStreamUseClosedError()
except AsyncStreamError as exc:
wstream.state = AsyncStreamState.Error
error = exc
if wstream.state != AsyncStreamState.Running:
if wstream.state == AsyncStreamState.Finished:
error = newAsyncStreamUseClosedError()
else:
if not(isNil(item.future)):
if not(item.future.finished()):
item.future.fail(error)
while not(wstream.queue.empty()):
let pitem = wstream.queue.popFirstNoWait()
if not(pitem.future.finished()):
pitem.future.fail(error)
break
proc init*[T](child: ChunkedStreamReader, rsource: AsyncStreamReader,
bufferSize = ChunkBufferSize, udata: ref T) =
@ -236,14 +210,16 @@ proc init*(child: ChunkedStreamReader, rsource: AsyncStreamReader,
proc newChunkedStreamReader*[T](rsource: AsyncStreamReader,
bufferSize = AsyncStreamDefaultBufferSize,
udata: ref T): ChunkedStreamReader =
result = new ChunkedStreamReader
result.init(rsource, bufferSize, udata)
var res = ChunkedStreamReader()
res.init(rsource, bufferSize, udata)
res
proc newChunkedStreamReader*(rsource: AsyncStreamReader,
bufferSize = AsyncStreamDefaultBufferSize,
): ChunkedStreamReader =
result = new ChunkedStreamReader
result.init(rsource, bufferSize)
var res = ChunkedStreamReader()
res.init(rsource, bufferSize)
res
proc init*[T](child: ChunkedStreamWriter, wsource: AsyncStreamWriter,
queueSize = AsyncStreamDefaultQueueSize, udata: ref T) =
@ -257,11 +233,13 @@ proc init*(child: ChunkedStreamWriter, wsource: AsyncStreamWriter,
proc newChunkedStreamWriter*[T](wsource: AsyncStreamWriter,
queueSize = AsyncStreamDefaultQueueSize,
udata: ref T): ChunkedStreamWriter =
result = new ChunkedStreamWriter
result.init(wsource, queueSize, udata)
var res = ChunkedStreamWriter()
res.init(wsource, queueSize, udata)
res
proc newChunkedStreamWriter*(wsource: AsyncStreamWriter,
queueSize = AsyncStreamDefaultQueueSize,
): ChunkedStreamWriter =
result = new ChunkedStreamWriter
result.init(wsource, queueSize)
var res = ChunkedStreamWriter()
res.init(wsource, queueSize)
res

View File

@ -89,11 +89,11 @@ type
SomeTLSStreamType* = TLSStreamReader|TLSStreamWriter|TLSAsyncStream
TLSStreamError* = object of CatchableError
TLSStreamError* = object of AsyncStreamError
TLSStreamProtocolError* = object of TLSStreamError
errCode*: int
template newTLSStreamProtocolError[T](message: T): ref Exception =
template newTLSStreamProtocolError[T](message: T): ref TLSStreamProtocolError =
var msg = ""
var code = 0
when T is string:
@ -110,13 +110,13 @@ template newTLSStreamProtocolError[T](message: T): ref Exception =
err.errCode = code
err
proc raiseTLSStreamProtoError*[T](message: T) =
template raiseTLSStreamProtoError*[T](message: T) =
raise newTLSStreamProtocolError(message)
proc tlsWriteLoop(stream: AsyncStreamWriter) {.async.} =
var wstream = cast[TLSStreamWriter](stream)
var engine: ptr SslEngineContext
var error: ref Exception
var error: ref AsyncStreamError
if wstream.kind == TLSStreamKind.Server:
engine = addr wstream.scontext.eng
@ -125,86 +125,77 @@ proc tlsWriteLoop(stream: AsyncStreamWriter) {.async.} =
wstream.state = AsyncStreamState.Running
try:
var length: uint
while true:
while true:
var item: WriteItem
try:
var state = engine.sslEngineCurrentState()
if (state and SSL_CLOSED) == SSL_CLOSED:
wstream.state = AsyncStreamState.Finished
break
if (state and (SSL_RECVREC or SSL_RECVAPP)) != 0:
if not(wstream.switchToReader.isSet()):
wstream.switchToReader.fire()
if (state and (SSL_SENDREC or SSL_SENDAPP)) == 0:
await wstream.switchToWriter.wait()
wstream.switchToWriter.clear()
# We need to refresh `state` because we just returned from readerLoop.
continue
if (state and SSL_SENDREC) == SSL_SENDREC:
# TLS record needs to be sent over stream.
length = 0'u
var buf = sslEngineSendrecBuf(engine, length)
doAssert(length != 0 and not isNil(buf))
var fut = awaitne wstream.wsource.write(buf, int(length))
if fut.cancelled():
raise fut.error
elif fut.failed():
error = fut.error
break
sslEngineSendrecAck(engine, length)
continue
if (state and SSL_SENDAPP) == SSL_SENDAPP:
# Application data can be sent over stream.
if not(wstream.handshaked):
wstream.stream.reader.handshaked = true
wstream.handshaked = true
if not(isNil(wstream.handshakeFut)):
wstream.handshakeFut.complete()
var item = await wstream.queue.get()
if item.size > 0:
length = 0'u
var buf = sslEngineSendappBuf(engine, length)
let toWrite = min(int(length), item.size)
copyOut(buf, item, toWrite)
if int(length) >= item.size:
# BearSSL is ready to accept whole item size.
sslEngineSendappAck(engine, uint(item.size))
sslEngineFlush(engine, 0)
item.future.complete()
else:
# BearSSL is not ready to accept whole item, so we will send only
# part of item and adjust offset.
item.offset = item.offset + int(length)
item.size = item.size - int(length)
wstream.queue.addFirstNoWait(item)
sslEngineSendappAck(engine, length)
continue
else:
if (state and (SSL_RECVREC or SSL_RECVAPP)) != 0:
if not(wstream.switchToReader.isSet()):
wstream.switchToReader.fire()
if (state and (SSL_SENDREC or SSL_SENDAPP)) == 0:
await wstream.switchToWriter.wait()
wstream.switchToWriter.clear()
# We need to refresh `state` because we just returned from readerLoop.
else:
# Zero length item means finish
wstream.state = AsyncStreamState.Finished
break
if (state and SSL_SENDREC) == SSL_SENDREC:
# TLS record needs to be sent over stream.
var length = 0'u
var buf = sslEngineSendrecBuf(engine, length)
doAssert(length != 0 and not isNil(buf))
await wstream.wsource.write(buf, int(length))
sslEngineSendrecAck(engine, length)
elif (state and SSL_SENDAPP) == SSL_SENDAPP:
# Application data can be sent over stream.
if not(wstream.handshaked):
wstream.stream.reader.handshaked = true
wstream.handshaked = true
if not(isNil(wstream.handshakeFut)):
wstream.handshakeFut.complete()
item = await wstream.queue.get()
if item.size > 0:
var length = 0'u
var buf = sslEngineSendappBuf(engine, length)
let toWrite = min(int(length), item.size)
copyOut(buf, item, toWrite)
if int(length) >= item.size:
# BearSSL is ready to accept whole item size.
sslEngineSendappAck(engine, uint(item.size))
sslEngineFlush(engine, 0)
item.future.complete()
else:
# BearSSL is not ready to accept whole item, so we will send
# only part of item and adjust offset.
item.offset = item.offset + int(length)
item.size = item.size - int(length)
wstream.queue.addFirstNoWait(item)
sslEngineSendappAck(engine, length)
else:
# Zero length item means finish, so we going to trigger TLS
# closure protocol.
sslEngineClose(engine)
except CancelledError:
wstream.state = AsyncStreamState.Stopped
error = newAsyncStreamUseClosedError()
except AsyncStreamError as exc:
wstream.state = AsyncStreamState.Error
error = exc
except CancelledError:
wstream.state = AsyncStreamState.Stopped
finally:
if wstream.state == AsyncStreamState.Stopped:
while len(wstream.queue) > 0:
let item = wstream.queue.popFirstNoWait()
if not(item.future.finished()):
item.future.complete()
elif wstream.state == AsyncStreamState.Error:
while len(wstream.queue) > 0:
let item = wstream.queue.popFirstNoWait()
if not(item.future.finished()):
item.future.fail(error)
wstream.stream = nil
if wstream.state != AsyncStreamState.Running:
if wstream.state == AsyncStreamState.Finished:
error = newAsyncStreamUseClosedError()
else:
if not(isNil(item.future)):
if not(item.future.finished()):
item.future.fail(error)
while not(wstream.queue.empty()):
let pitem = wstream.queue.popFirstNoWait()
if not(pitem.future.finished()):
pitem.future.fail(error)
wstream.stream = nil
break
proc tlsReadLoop(stream: AsyncStreamReader) {.async.} =
var rstream = cast[TLSStreamReader](stream)
@ -217,72 +208,61 @@ proc tlsReadLoop(stream: AsyncStreamReader) {.async.} =
rstream.state = AsyncStreamState.Running
try:
var length: uint
while true:
while true:
try:
var state = engine.sslEngineCurrentState()
if (state and SSL_CLOSED) == SSL_CLOSED:
let err = engine.sslEngineLastError()
if err != 0:
raise newTLSStreamProtocolError(err)
rstream.state = AsyncStreamState.Stopped
break
if (state and (SSL_SENDREC or SSL_SENDAPP)) != 0:
if not(rstream.switchToWriter.isSet()):
rstream.switchToWriter.fire()
if (state and (SSL_RECVREC or SSL_RECVAPP)) == 0:
await rstream.switchToReader.wait()
rstream.switchToReader.clear()
# We need to refresh `state` because we just returned from writerLoop.
continue
if (state and SSL_RECVREC) == SSL_RECVREC:
# TLS records required for further processing
length = 0'u
var buf = sslEngineRecvrecBuf(engine, length)
let res = await rstream.rsource.readOnce(buf, int(length))
if res > 0:
sslEngineRecvrecAck(engine, uint(res))
continue
rstream.error = newTLSStreamProtocolError(err)
rstream.state = AsyncStreamState.Error
else:
rstream.state = AsyncStreamState.Finished
break
else:
if (state and (SSL_SENDREC or SSL_SENDAPP)) != 0:
if not(rstream.switchToWriter.isSet()):
rstream.switchToWriter.fire()
if (state and (SSL_RECVREC or SSL_RECVAPP)) == 0:
await rstream.switchToReader.wait()
rstream.switchToReader.clear()
# We need to refresh `state` because we just returned from writerLoop.
else:
if (state and SSL_RECVREC) == SSL_RECVREC:
# TLS records required for further processing
var length = 0'u
var buf = sslEngineRecvrecBuf(engine, length)
let res = await rstream.rsource.readOnce(buf, int(length))
if res > 0:
sslEngineRecvrecAck(engine, uint(res))
else:
# readOnce() returns `0` if stream is at EOF, so we initiate TLS
# closure procedure.
sslEngineClose(engine)
elif (state and SSL_RECVAPP) == SSL_RECVAPP:
# Application data can be recovered.
var length = 0'u
var buf = sslEngineRecvappBuf(engine, length)
await upload(addr rstream.buffer, buf, int(length))
sslEngineRecvappAck(engine, length)
except CancelledError:
rstream.state = AsyncStreamState.Stopped
except AsyncStreamError as exc:
rstream.error = exc
rstream.state = AsyncStreamState.Error
if not(rstream.handshaked):
rstream.handshaked = true
rstream.stream.writer.handshaked = true
if not(isNil(rstream.handshakeFut)):
rstream.handshakeFut.fail(rstream.error)
rstream.switchToWriter.fire()
if (state and SSL_RECVAPP) == SSL_RECVAPP:
# Application data can be recovered.
length = 0'u
var buf = sslEngineRecvappBuf(engine, length)
await upload(addr rstream.buffer, buf, int(length))
sslEngineRecvappAck(engine, length)
continue
except CancelledError:
rstream.state = AsyncStreamState.Stopped
except TLSStreamProtocolError as exc:
rstream.error = exc
rstream.state = AsyncStreamState.Error
if not(rstream.handshaked):
rstream.handshaked = true
rstream.stream.writer.handshaked = true
if not(isNil(rstream.handshakeFut)):
rstream.handshakeFut.fail(rstream.error)
rstream.switchToWriter.fire()
except AsyncStreamReadError as exc:
rstream.error = exc
rstream.state = AsyncStreamState.Error
if not(rstream.handshaked):
rstream.handshaked = true
rstream.stream.writer.handshaked = true
if not(isNil(rstream.handshakeFut)):
rstream.handshakeFut.fail(rstream.error)
rstream.switchToWriter.fire()
finally:
# Perform TLS cleanup procedure
sslEngineClose(engine)
rstream.buffer.forget()
rstream.stream = nil
if rstream.state != AsyncStreamState.Running:
# Perform TLS cleanup procedure
if rstream.state != AsyncStreamState.Finished:
sslEngineClose(engine)
rstream.buffer.forget()
rstream.stream = nil
break
proc getSignerAlgo(xc: X509Certificate): int =
## Get certificate's signing algorithm.
@ -291,9 +271,9 @@ proc getSignerAlgo(xc: X509Certificate): int =
x509DecoderPush(addr dc, xc.data, xc.dataLen)
let err = x509DecoderLastError(addr dc)
if err != 0:
result = -1
-1
else:
result = int(x509DecoderGetSignerKeyType(addr dc))
int(x509DecoderGetSignerKeyType(addr dc))
proc newTLSClientAsyncStream*(rsource: AsyncStreamReader,
wsource: AsyncStreamWriter,
@ -318,54 +298,59 @@ proc newTLSClientAsyncStream*(rsource: AsyncStreamReader,
## ``minVersion`` of bigger then ``maxVersion`` you will get an error.
##
## ``flags`` - custom TLS connection flags.
result = new TLSAsyncStream
var reader = TLSStreamReader(kind: TLSStreamKind.Client)
var writer = TLSStreamWriter(kind: TLSStreamKind.Client)
var switchToWriter = newAsyncEvent()
var switchToReader = newAsyncEvent()
reader.stream = result
writer.stream = result
reader.switchToReader = switchToReader
reader.switchToWriter = switchToWriter
writer.switchToReader = switchToReader
writer.switchToWriter = switchToWriter
result.reader = reader
result.writer = writer
reader.ccontext = addr result.ccontext
writer.ccontext = addr result.ccontext
let switchToWriter = newAsyncEvent()
let switchToReader = newAsyncEvent()
var res = TLSAsyncStream()
var reader = TLSStreamReader(
kind: TLSStreamKind.Client,
stream: res,
switchToReader: switchToReader,
switchToWriter: switchToWriter,
ccontext: addr res.ccontext
)
var writer = TLSStreamWriter(
kind: TLSStreamKind.Client,
stream: res,
switchToReader: switchToReader,
switchToWriter: switchToWriter,
ccontext: addr res.ccontext
)
res.reader = reader
res.writer = writer
if TLSFlags.NoVerifyHost in flags:
sslClientInitFull(addr result.ccontext, addr result.x509, nil, 0)
initNoAnchor(addr result.xwc, addr result.x509.vtable)
sslEngineSetX509(addr result.ccontext.eng, addr result.xwc.vtable)
sslClientInitFull(addr res.ccontext, addr res.x509, nil, 0)
initNoAnchor(addr res.xwc, addr res.x509.vtable)
sslEngineSetX509(addr res.ccontext.eng, addr res.xwc.vtable)
else:
sslClientInitFull(addr result.ccontext, addr result.x509,
sslClientInitFull(addr res.ccontext, addr res.x509,
unsafeAddr MozillaTrustAnchors[0],
len(MozillaTrustAnchors))
let size = max(SSL_BUFSIZE_BIDI, bufferSize)
result.sbuffer = newSeq[byte](size)
sslEngineSetBuffer(addr result.ccontext.eng, addr result.sbuffer[0],
uint(len(result.sbuffer)), 1)
sslEngineSetVersions(addr result.ccontext.eng, uint16(minVersion),
res.sbuffer = newSeq[byte](size)
sslEngineSetBuffer(addr res.ccontext.eng, addr res.sbuffer[0],
uint(len(res.sbuffer)), 1)
sslEngineSetVersions(addr res.ccontext.eng, uint16(minVersion),
uint16(maxVersion))
if TLSFlags.NoVerifyServerName in flags:
let err = sslClientReset(addr result.ccontext, "", 0)
let err = sslClientReset(addr res.ccontext, "", 0)
if err == 0:
raise newException(TLSStreamError, "Could not initialize TLS layer")
else:
if len(serverName) == 0:
raise newException(TLSStreamError, "serverName must not be empty string")
let err = sslClientReset(addr result.ccontext, serverName, 0)
let err = sslClientReset(addr res.ccontext, serverName, 0)
if err == 0:
raise newException(TLSStreamError, "Could not initialize TLS layer")
init(cast[AsyncStreamWriter](result.writer), wsource, tlsWriteLoop,
init(cast[AsyncStreamWriter](res.writer), wsource, tlsWriteLoop,
bufferSize)
init(cast[AsyncStreamReader](result.reader), rsource, tlsReadLoop,
init(cast[AsyncStreamReader](res.reader), rsource, tlsReadLoop,
bufferSize)
res
proc newTLSServerAsyncStream*(rsource: AsyncStreamReader,
wsource: AsyncStreamWriter,
@ -395,98 +380,104 @@ proc newTLSServerAsyncStream*(rsource: AsyncStreamReader,
if isNil(certificate) or len(certificate.certs) == 0:
raiseTLSStreamProtoError("Incorrect certificate")
result = new TLSAsyncStream
var reader = TLSStreamReader(kind: TLSStreamKind.Server)
var writer = TLSStreamWriter(kind: TLSStreamKind.Server)
var switchToWriter = newAsyncEvent()
var switchToReader = newAsyncEvent()
reader.stream = result
writer.stream = result
reader.switchToReader = switchToReader
reader.switchToWriter = switchToWriter
writer.switchToReader = switchToReader
writer.switchToWriter = switchToWriter
result.reader = reader
result.writer = writer
reader.scontext = addr result.scontext
writer.scontext = addr result.scontext
let switchToWriter = newAsyncEvent()
let switchToReader = newAsyncEvent()
var res = TLSAsyncStream()
var reader = TLSStreamReader(
kind: TLSStreamKind.Server,
stream: res,
switchToReader: switchToReader,
switchToWriter: switchToWriter,
scontext: addr res.scontext
)
var writer = TLSStreamWriter(
kind: TLSStreamKind.Server,
stream: res,
switchToReader: switchToReader,
switchToWriter: switchToWriter,
scontext: addr res.scontext
)
res.reader = reader
res.writer = writer
if privateKey.kind == TLSKeyType.EC:
let algo = getSignerAlgo(certificate.certs[0])
if algo == -1:
raiseTLSStreamProtoError("Could not decode certificate")
sslServerInitFullEc(addr result.scontext, addr certificate.certs[0],
sslServerInitFullEc(addr res.scontext, addr certificate.certs[0],
len(certificate.certs), cuint(algo),
addr privateKey.eckey)
elif privateKey.kind == TLSKeyType.RSA:
sslServerInitFullRsa(addr result.scontext, addr certificate.certs[0],
sslServerInitFullRsa(addr res.scontext, addr certificate.certs[0],
len(certificate.certs), addr privateKey.rsakey)
let size = max(SSL_BUFSIZE_BIDI, bufferSize)
result.sbuffer = newSeq[byte](size)
sslEngineSetBuffer(addr result.scontext.eng, addr result.sbuffer[0],
uint(len(result.sbuffer)), 1)
sslEngineSetVersions(addr result.scontext.eng, uint16(minVersion),
res.sbuffer = newSeq[byte](size)
sslEngineSetBuffer(addr res.scontext.eng, addr res.sbuffer[0],
uint(len(res.sbuffer)), 1)
sslEngineSetVersions(addr res.scontext.eng, uint16(minVersion),
uint16(maxVersion))
if not isNil(cache):
sslServerSetCache(addr result.scontext, addr cache.context.vtable)
sslServerSetCache(addr res.scontext, addr cache.context.vtable)
if TLSFlags.EnforceServerPref in flags:
sslEngineAddFlags(addr result.scontext.eng, OPT_ENFORCE_SERVER_PREFERENCES)
sslEngineAddFlags(addr res.scontext.eng, OPT_ENFORCE_SERVER_PREFERENCES)
if TLSFlags.NoRenegotiation in flags:
sslEngineAddFlags(addr result.scontext.eng, OPT_NO_RENEGOTIATION)
sslEngineAddFlags(addr res.scontext.eng, OPT_NO_RENEGOTIATION)
if TLSFlags.TolerateNoClientAuth in flags:
sslEngineAddFlags(addr result.scontext.eng, OPT_TOLERATE_NO_CLIENT_AUTH)
sslEngineAddFlags(addr res.scontext.eng, OPT_TOLERATE_NO_CLIENT_AUTH)
if TLSFlags.FailOnAlpnMismatch in flags:
sslEngineAddFlags(addr result.scontext.eng, OPT_FAIL_ON_ALPN_MISMATCH)
sslEngineAddFlags(addr res.scontext.eng, OPT_FAIL_ON_ALPN_MISMATCH)
let err = sslServerReset(addr result.scontext)
let err = sslServerReset(addr res.scontext)
if err == 0:
raise newException(TLSStreamError, "Could not initialize TLS layer")
init(cast[AsyncStreamWriter](result.writer), wsource, tlsWriteLoop,
init(cast[AsyncStreamWriter](res.writer), wsource, tlsWriteLoop,
bufferSize)
init(cast[AsyncStreamReader](result.reader), rsource, tlsReadLoop,
init(cast[AsyncStreamReader](res.reader), rsource, tlsReadLoop,
bufferSize)
res
proc copyKey(src: RsaPrivateKey): TLSPrivateKey =
## Creates copy of RsaPrivateKey ``src``.
var offset = 0
let keySize = src.plen + src.qlen + src.dplen + src.dqlen + src.iqlen
result = TLSPrivateKey(kind: TLSKeyType.RSA)
result.storage = newSeq[byte](keySize)
copyMem(addr result.storage[offset], src.p, src.plen)
result.rsakey.p = cast[ptr cuchar](addr result.storage[offset])
result.rsakey.plen = src.plen
var res = TLSPrivateKey(kind: TLSKeyType.RSA, storage: newSeq[byte](keySize))
copyMem(addr res.storage[offset], src.p, src.plen)
res.rsakey.p = cast[ptr cuchar](addr res.storage[offset])
res.rsakey.plen = src.plen
offset = offset + src.plen
copyMem(addr result.storage[offset], src.q, src.qlen)
result.rsakey.q = cast[ptr cuchar](addr result.storage[offset])
result.rsakey.qlen = src.qlen
copyMem(addr res.storage[offset], src.q, src.qlen)
res.rsakey.q = cast[ptr cuchar](addr res.storage[offset])
res.rsakey.qlen = src.qlen
offset = offset + src.qlen
copyMem(addr result.storage[offset], src.dp, src.dplen)
result.rsakey.dp = cast[ptr cuchar](addr result.storage[offset])
result.rsakey.dplen = src.dplen
copyMem(addr res.storage[offset], src.dp, src.dplen)
res.rsakey.dp = cast[ptr cuchar](addr res.storage[offset])
res.rsakey.dplen = src.dplen
offset = offset + src.dplen
copyMem(addr result.storage[offset], src.dq, src.dqlen)
result.rsakey.dq = cast[ptr cuchar](addr result.storage[offset])
result.rsakey.dqlen = src.dqlen
copyMem(addr res.storage[offset], src.dq, src.dqlen)
res.rsakey.dq = cast[ptr cuchar](addr res.storage[offset])
res.rsakey.dqlen = src.dqlen
offset = offset + src.dqlen
copyMem(addr result.storage[offset], src.iq, src.iqlen)
result.rsakey.iq = cast[ptr cuchar](addr result.storage[offset])
result.rsakey.iqlen = src.iqlen
result.rsakey.nBitlen = src.nBitlen
copyMem(addr res.storage[offset], src.iq, src.iqlen)
res.rsakey.iq = cast[ptr cuchar](addr res.storage[offset])
res.rsakey.iqlen = src.iqlen
res.rsakey.nBitlen = src.nBitlen
res
proc copyKey(src: EcPrivateKey): TLSPrivateKey =
## Creates copy of EcPrivateKey ``src``.
var offset = 0
let keySize = src.xlen
result = TLSPrivateKey(kind: TLSKeyType.EC)
result.storage = newSeq[byte](keySize)
copyMem(addr result.storage[offset], src.x, src.xlen)
result.eckey.x = cast[ptr cuchar](addr result.storage[offset])
result.eckey.xlen = src.xlen
result.eckey.curve = src.curve
var res = TLSPrivateKey(kind: TLSKeyType.EC, storage: newSeq[byte](keySize))
copyMem(addr res.storage[offset], src.x, src.xlen)
res.eckey.x = cast[ptr cuchar](addr res.storage[offset])
res.eckey.xlen = src.xlen
res.eckey.curve = src.curve
res
proc init*(tt: typedesc[TLSPrivateKey], data: openarray[byte]): TLSPrivateKey =
## Initialize TLS private key from array of bytes ``data``.
@ -502,12 +493,14 @@ proc init*(tt: typedesc[TLSPrivateKey], data: openarray[byte]): TLSPrivateKey =
if err != 0:
raiseTLSStreamProtoError(err)
let keyType = skeyDecoderKeyType(addr ctx)
if keyType == KEYTYPE_RSA:
result = copyKey(ctx.key.rsa)
elif keyType == KEYTYPE_EC:
result = copyKey(ctx.key.ec)
else:
raiseTLSStreamProtoError("Unknown key type (" & $keyType & ")")
let res =
if keyType == KEYTYPE_RSA:
copyKey(ctx.key.rsa)
elif keyType == KEYTYPE_EC:
copyKey(ctx.key.ec)
else:
raiseTLSStreamProtoError("Unknown key type (" & $keyType & ")")
res
proc pemDecode*(data: openarray[char]): seq[PEMElement] =
## Decode PEM encoded string and get array of binary blobs.
@ -515,7 +508,7 @@ proc pemDecode*(data: openarray[char]): seq[PEMElement] =
raiseTLSStreamProtoError("Empty PEM message")
var ctx: PemDecoderContext
var pctx = new PEMContext
result = newSeq[PEMElement]()
var res = newSeq[PEMElement]()
pemDecoderInit(addr ctx)
proc itemAppend(ctx: pointer, pbytes: pointer, nbytes: int) {.cdecl.} =
@ -544,12 +537,13 @@ proc pemDecode*(data: openarray[char]): seq[PEMElement] =
elif event == PEM_END_OBJ:
if inobj:
elem.data = pctx.data
result.add(elem)
res.add(elem)
inobj = false
else:
break
else:
raiseTLSStreamProtoError("Invalid PEM encoding")
res
proc init*(tt: typedesc[TLSPrivateKey], data: openarray[char]): TLSPrivateKey =
## Initialize TLS private key from string ``data``.
@ -558,13 +552,15 @@ proc init*(tt: typedesc[TLSPrivateKey], data: openarray[char]): TLSPrivateKey =
## encoded string.
##
## Note that PKCS#1 PEM encoded objects are not supported.
var res: TLSPrivateKey
var items = pemDecode(data)
for item in items:
if item.name == "PRIVATE KEY":
result = TLSPrivateKey.init(item.data)
res = TLSPrivateKey.init(item.data)
break
if isNil(result):
if isNil(res):
raiseTLSStreamProtoError("Could not find private key")
res
proc init*(tt: typedesc[TLSCertificate],
data: openarray[char]): TLSCertificate =
@ -572,32 +568,33 @@ proc init*(tt: typedesc[TLSCertificate],
##
## This procedure initializes array of certificates from PEM encoded string.
var items = pemDecode(data)
result = new TLSCertificate
var res = TLSCertificate()
for item in items:
if item.name == "CERTIFICATE" and len(item.data) > 0:
let offset = len(result.storage)
result.storage.add(item.data)
let offset = len(res.storage)
res.storage.add(item.data)
let cert = X509Certificate(
data: cast[ptr cuchar](addr result.storage[offset]),
data: cast[ptr cuchar](addr res.storage[offset]),
dataLen: len(item.data)
)
let res = getSignerAlgo(cert)
if res == -1:
let ares = getSignerAlgo(cert)
if ares == -1:
raiseTLSStreamProtoError("Could not decode certificate")
elif res != KEYTYPE_RSA and res != KEYTYPE_EC:
elif ares != KEYTYPE_RSA and ares != KEYTYPE_EC:
raiseTLSStreamProtoError("Unsupported signing key type in certificate")
result.certs.add(cert)
if len(result.storage) == 0:
res.certs.add(cert)
if len(res.storage) == 0:
raiseTLSStreamProtoError("Could not find any certificates")
res
proc init*(tt: typedesc[TLSSessionCache], size: int = 4096): TLSSessionCache =
## Create new TLS session cache with size ``size``.
##
## One cached item is near 100 bytes size.
result = new TLSSessionCache
var rsize = min(size, 4096)
result.storage = newSeq[byte](rsize)
sslSessionCacheLruInit(addr result.context, addr result.storage[0], rsize)
var res = TLSSessionCache(storage: newSeq[byte](rsize))
sslSessionCacheLruInit(addr res.context, addr res.storage[0], rsize)
res
proc handshake*(rws: SomeTLSStreamType): Future[void] =
## Wait until initial TLS handshake will be successfully performed.
@ -620,4 +617,4 @@ proc handshake*(rws: SomeTLSStreamType): Future[void] =
else:
rws.reader.handshakeFut = retFuture
rws.writer.handshakeFut = retFuture
return retFuture
retFuture

View File

@ -6,7 +6,8 @@
# Apache License, version 2.0, (LICENSE-APACHEv2)
# MIT license (LICENSE-MIT)
import unittest
import ../chronos, ../chronos/streams/tlsstream
import ../chronos
import ../chronos/streams/[tlsstream, chunkstream, boundstream]
when defined(nimHasUsed): {.used.}
@ -553,8 +554,12 @@ suite "ChunkedStream test suite":
try:
var r = await rstream2.read()
doAssert(len(r) > 0)
except AsyncStreamReadError:
res = true
except ChunkedStreamIncompleteError:
if inputstr == "100000000 \r\n1":
res = true
except ChunkedStreamProtocolError:
if inputstr == "z\r\n1":
res = true
await rstream2.closeWait()
await rstream.closeWait()
await transp.closeWait()
@ -663,3 +668,119 @@ suite "TLSStream test suite":
getTracker("async.stream.writer").isLeaked() == false
getTracker("stream.server").isLeaked() == false
getTracker("stream.transport").isLeaked() == false
suite "BoundedStream test suite":
proc createBigMessage(size: int): seq[byte] =
var message = "MESSAGE"
result = newSeq[byte](size)
for i in 0 ..< len(result):
result[i] = byte(message[i mod len(message)])
for item in [100'u64, 60000'u64]:
proc boundedTest(address: TransportAddress, test: int,
size: uint64): Future[bool] {.async.} =
var clientRes = false
var res = false
let messagePart = createBigMessage(int(item) div 10)
var message: seq[byte]
for i in 0 ..< 10:
message.add(messagePart)
proc processClient(server: StreamServer,
transp: StreamTransport) {.async.} =
var wstream = newAsyncStreamWriter(transp)
var wbstream = newBoundedStreamWriter(wstream, size)
if test == 0:
for i in 0 ..< 10:
await wbstream.write(messagePart)
await wbstream.finish()
await wbstream.closeWait()
clientRes = true
elif test == 1:
for i in 0 ..< 10:
await wbstream.write(messagePart)
try:
await wbstream.write(messagePart)
except BoundedStreamOverflowError:
clientRes = true
await wbstream.closeWait()
elif test == 2:
for i in 0 ..< 9:
await wbstream.write(messagePart)
try:
await wbstream.finish()
except BoundedStreamIncompleteError:
clientRes = true
await wbstream.closeWait()
elif test == 3:
for i in 0 ..< 10:
await wbstream.write(messagePart)
await wbstream.finish()
await wbstream.closeWait()
clientRes = true
elif test == 4:
for i in 0 ..< 9:
await wbstream.write(messagePart)
try:
await wbstream.closeWait()
except BoundedStreamIncompleteError:
clientRes = true
await wstream.closeWait()
await transp.closeWait()
server.stop()
server.close()
var server = createStreamServer(address, processClient, flags = {ReuseAddr})
server.start()
var conn = await connect(address)
var rstream = newAsyncStreamReader(conn)
var rbstream = newBoundedStreamReader(rstream, size)
if test == 0:
let response = await rbstream.read()
await rbstream.closeWait()
if response == message:
res = true
elif test == 1:
let response = await rbstream.read()
await rbstream.closeWait()
if response == message:
res = true
elif test == 2:
try:
let response {.used.} = await rbstream.read()
except BoundedStreamIncompleteError:
res = true
await rbstream.closeWait()
elif test == 3:
let response {.used.} = await rbstream.read(int(size) - 1)
try:
await rbstream.closeWait()
except BoundedStreamIncompleteError:
res = true
elif test == 4:
try:
let response {.used.} = await rbstream.read()
except BoundedStreamIncompleteError:
res = true
await rbstream.closeWait()
await rstream.closeWait()
await conn.closeWait()
await server.join()
return (res and clientRes)
let address = initTAddress("127.0.0.1:48030")
test "BoundedStream reading/writing test [" & $item & "]":
check waitFor(boundedTest(address, 0, item)) == true
test "BoundedStream overflow test [" & $item & "]":
check waitFor(boundedTest(address, 1, item)) == true
test "BoundedStream incomplete test [" & $item & "]":
check waitFor(boundedTest(address, 2, item)) == true
test "BoundedStream read() close test [" & $item & "]":
check waitFor(boundedTest(address, 3, item)) == true
test "BoundedStream write() close test [" & $item & "]":
check waitFor(boundedTest(address, 4, item)) == true