Big refactoring of AsyncStreams.
1. Implement all read() primitives using readLoop() like it was done in streams. 2. Fix readLine() bug. 3. Add readMessage() primitive. 4. Fixing exception hierarchy, handling code and simplification of (break/continue + exception). 5. Fix TLSStream closure procedure. 6. Add BoundedStream stream and tests. 7. Remove `result` usage from the code.
This commit is contained in:
parent
39456e9c18
commit
0cb6840f03
File diff suppressed because it is too large
Load Diff
|
@ -0,0 +1,212 @@
|
|||
#
|
||||
# Chronos Asynchronous Bound Stream
|
||||
# (c) Copyright 2021-Present
|
||||
# Status Research & Development GmbH
|
||||
#
|
||||
# Licensed under either of
|
||||
# Apache License, version 2.0, (LICENSE-APACHEv2)
|
||||
# MIT license (LICENSE-MIT)
|
||||
|
||||
## This module implements bounded stream reading and writing.
|
||||
##
|
||||
## For stream reading it means that you should read exactly bounded size of
|
||||
## bytes.
|
||||
##
|
||||
## For stream writing it means that you should write exactly bounded size
|
||||
## of bytes, and if you wrote not enough bytes error will appear on stream
|
||||
## close.
|
||||
import ../asyncloop, ../timer
|
||||
import asyncstream, ../transports/stream, ../transports/common
|
||||
export asyncstream, stream, timer, common
|
||||
|
||||
type
|
||||
BoundedStreamReader* = ref object of AsyncStreamReader
|
||||
boundSize: uint64
|
||||
offset: uint64
|
||||
|
||||
BoundedStreamWriter* = ref object of AsyncStreamWriter
|
||||
boundSize: uint64
|
||||
offset: uint64
|
||||
|
||||
BoundedStreamError* = object of AsyncStreamError
|
||||
BoundedStreamIncompleteError* = object of BoundedStreamError
|
||||
BoundedStreamOverflowError* = object of BoundedStreamError
|
||||
|
||||
BoundedStreamRW* = BoundedStreamReader | BoundedStreamWriter
|
||||
|
||||
const
|
||||
BoundedBufferSize* = 4096
|
||||
|
||||
template newBoundedStreamIncompleteError*(): ref BoundedStreamError =
|
||||
newException(BoundedStreamIncompleteError,
|
||||
"Stream boundary is not reached yet")
|
||||
template newBoundedStreamOverflowError*(): ref BoundedStreamError =
|
||||
newException(BoundedStreamOverflowError, "Stream boundary exceeded")
|
||||
|
||||
proc boundedReadLoop(stream: AsyncStreamReader) {.async.} =
|
||||
var rstream = cast[BoundedStreamReader](stream)
|
||||
rstream.state = AsyncStreamState.Running
|
||||
while true:
|
||||
if rstream.offset < rstream.boundSize:
|
||||
let toRead = int(min(rstream.boundSize - rstream.offset,
|
||||
uint64(rstream.buffer.bufferLen())))
|
||||
try:
|
||||
await rstream.rsource.readExactly(rstream.buffer.getBuffer(), toRead)
|
||||
rstream.offset = rstream.offset + uint64(toRead)
|
||||
rstream.buffer.update(toRead)
|
||||
await rstream.buffer.transfer()
|
||||
except AsyncStreamIncompleteError:
|
||||
rstream.state = AsyncStreamState.Error
|
||||
rstream.error = newBoundedStreamIncompleteError()
|
||||
except AsyncStreamReadError as exc:
|
||||
rstream.state = AsyncStreamState.Error
|
||||
rstream.error = exc
|
||||
except CancelledError:
|
||||
rstream.state = AsyncStreamState.Stopped
|
||||
|
||||
if rstream.state != AsyncStreamState.Running:
|
||||
break
|
||||
else:
|
||||
rstream.state = AsyncStreamState.Finished
|
||||
await rstream.buffer.transfer()
|
||||
break
|
||||
|
||||
if rstream.state in {AsyncStreamState.Stopped, AsyncStreamState.Error}:
|
||||
# We need to notify consumer about error/close, but we do not care about
|
||||
# incoming data anymore.
|
||||
rstream.buffer.forget()
|
||||
|
||||
proc boundedWriteLoop(stream: AsyncStreamWriter) {.async.} =
|
||||
var wstream = cast[BoundedStreamWriter](stream)
|
||||
|
||||
wstream.state = AsyncStreamState.Running
|
||||
while true:
|
||||
var
|
||||
item: WriteItem
|
||||
error: ref AsyncStreamError
|
||||
|
||||
try:
|
||||
item = await wstream.queue.get()
|
||||
if item.size > 0:
|
||||
if uint64(item.size) <= (wstream.boundSize - wstream.offset):
|
||||
# Writing chunk data.
|
||||
case item.kind
|
||||
of WriteType.Pointer:
|
||||
await wstream.wsource.write(item.data1, item.size)
|
||||
of WriteType.Sequence:
|
||||
await wstream.wsource.write(addr item.data2[0], item.size)
|
||||
of WriteType.String:
|
||||
await wstream.wsource.write(addr item.data3[0], item.size)
|
||||
wstream.offset = wstream.offset + uint64(item.size)
|
||||
item.future.complete()
|
||||
else:
|
||||
wstream.state = AsyncStreamState.Error
|
||||
error = newBoundedStreamOverflowError()
|
||||
else:
|
||||
if wstream.offset != wstream.boundSize:
|
||||
wstream.state = AsyncStreamState.Error
|
||||
error = newBoundedStreamIncompleteError()
|
||||
else:
|
||||
wstream.state = AsyncStreamState.Finished
|
||||
item.future.complete()
|
||||
except CancelledError:
|
||||
wstream.state = AsyncStreamState.Stopped
|
||||
error = newAsyncStreamUseClosedError()
|
||||
except AsyncStreamWriteError as exc:
|
||||
wstream.state = AsyncStreamState.Error
|
||||
error = exc
|
||||
except AsyncStreamIncompleteError as exc:
|
||||
wstream.state = AsyncStreamState.Error
|
||||
error = exc
|
||||
|
||||
if wstream.state != AsyncStreamState.Running:
|
||||
if wstream.state == AsyncStreamState.Finished:
|
||||
error = newAsyncStreamUseClosedError()
|
||||
else:
|
||||
if not(isNil(item.future)):
|
||||
if not(item.future.finished()):
|
||||
item.future.fail(error)
|
||||
while not(wstream.queue.empty()):
|
||||
let pitem = wstream.queue.popFirstNoWait()
|
||||
if not(pitem.future.finished()):
|
||||
pitem.future.fail(error)
|
||||
break
|
||||
|
||||
proc bytesLeft*(stream: BoundedStreamRW): uint64 =
|
||||
## Returns number of bytes left in stream.
|
||||
stream.boundSize - stream.bytesCount
|
||||
|
||||
proc init*[T](child: BoundedStreamReader, rsource: AsyncStreamReader,
|
||||
bufferSize = BoundedBufferSize, udata: ref T) =
|
||||
init(cast[AsyncStreamReader](child), rsource, boundedReadLoop, bufferSize,
|
||||
udata)
|
||||
|
||||
proc init*(child: BoundedStreamReader, rsource: AsyncStreamReader,
|
||||
bufferSize = BoundedBufferSize) =
|
||||
init(cast[AsyncStreamReader](child), rsource, boundedReadLoop, bufferSize)
|
||||
|
||||
proc newBoundedStreamReader*[T](rsource: AsyncStreamReader,
|
||||
boundSize: uint64,
|
||||
bufferSize = BoundedBufferSize,
|
||||
udata: ref T): BoundedStreamReader =
|
||||
var res = BoundedStreamReader(boundSize: boundSize)
|
||||
res.init(rsource, bufferSize, udata)
|
||||
res
|
||||
|
||||
proc newBoundedStreamReader*(rsource: AsyncStreamReader,
|
||||
boundSize: uint64,
|
||||
bufferSize = BoundedBufferSize,
|
||||
): BoundedStreamReader =
|
||||
doAssert(boundSize >= 0)
|
||||
var res = BoundedStreamReader(boundSize: boundSize)
|
||||
res.init(rsource, bufferSize)
|
||||
res
|
||||
|
||||
proc init*[T](child: BoundedStreamWriter, wsource: AsyncStreamWriter,
|
||||
queueSize = AsyncStreamDefaultQueueSize, udata: ref T) =
|
||||
init(cast[AsyncStreamWriter](child), wsource, boundedWriteLoop, queueSize,
|
||||
udata)
|
||||
|
||||
proc init*(child: BoundedStreamWriter, wsource: AsyncStreamWriter,
|
||||
queueSize = AsyncStreamDefaultQueueSize) =
|
||||
init(cast[AsyncStreamWriter](child), wsource, boundedWriteLoop, queueSize)
|
||||
|
||||
proc newBoundedStreamWriter*[T](wsource: AsyncStreamWriter,
|
||||
boundSize: uint64,
|
||||
queueSize = AsyncStreamDefaultQueueSize,
|
||||
udata: ref T): BoundedStreamWriter =
|
||||
var res = BoundedStreamWriter(boundSize: boundSize)
|
||||
res.init(wsource, queueSize, udata)
|
||||
res
|
||||
|
||||
proc newBoundedStreamWriter*(wsource: AsyncStreamWriter,
|
||||
boundSize: uint64,
|
||||
queueSize = AsyncStreamDefaultQueueSize,
|
||||
): BoundedStreamWriter =
|
||||
var res = BoundedStreamWriter(boundSize: boundSize)
|
||||
res.init(wsource, queueSize)
|
||||
res
|
||||
|
||||
proc close*(rw: BoundedStreamRW) =
|
||||
## Close and frees resources of stream ``rw``.
|
||||
##
|
||||
## Note close() procedure is not completed immediately.
|
||||
if rw.closed():
|
||||
raise newAsyncStreamIncorrectError("Stream is already closed!")
|
||||
# We do not want to raise one more IncompleteError if it was already raised
|
||||
# by one of the read()/write() primitives.
|
||||
if rw.state != AsyncStreamState.Error:
|
||||
if rw.bytesLeft() != 0'u64:
|
||||
raise newBoundedStreamIncompleteError()
|
||||
when rw is BoundedStreamReader:
|
||||
cast[AsyncStreamReader](rw).close()
|
||||
elif rw is BoundedStreamWriter:
|
||||
cast[AsyncStreamWriter](rw).close()
|
||||
|
||||
proc closeWait*(rw: BoundedStreamRW): Future[void] =
|
||||
## Close and frees resources of stream ``rw``.
|
||||
rw.close()
|
||||
when rw is BoundedStreamReader:
|
||||
cast[AsyncStreamReader](rw).join()
|
||||
elif rw is BoundedStreamWriter:
|
||||
cast[AsyncStreamWriter](rw).join()
|
|
@ -21,25 +21,44 @@ type
|
|||
ChunkedStreamReader* = ref object of AsyncStreamReader
|
||||
ChunkedStreamWriter* = ref object of AsyncStreamWriter
|
||||
|
||||
ChunkedStreamError* = object of CatchableError
|
||||
ChunkedStreamError* = object of AsyncStreamError
|
||||
ChunkedStreamProtocolError* = object of ChunkedStreamError
|
||||
ChunkedStreamIncompleteError* = object of ChunkedStreamError
|
||||
|
||||
proc newProtocolError(): ref Exception {.inline.} =
|
||||
proc newChunkedProtocolError(): ref ChunkedStreamProtocolError {.inline.} =
|
||||
newException(ChunkedStreamProtocolError, "Protocol error!")
|
||||
|
||||
proc newChunkedIncompleteError(): ref ChunkedStreamIncompleteError {.inline.} =
|
||||
newException(ChunkedStreamIncompleteError, "Incomplete data received!")
|
||||
|
||||
proc `-`(x: uint32): uint32 {.inline.} =
|
||||
result = (0xFFFF_FFFF'u32 - x) + 1'u32
|
||||
|
||||
proc LT(x, y: uint32): uint32 {.inline.} =
|
||||
let z = x - y
|
||||
(z xor ((y xor x) and (y xor z))) shr 31
|
||||
|
||||
proc hexValue(c: byte): int =
|
||||
let x = uint32(c) - 0x30'u32
|
||||
let y = uint32(c) - 0x41'u32
|
||||
let z = uint32(c) - 0x61'u32
|
||||
let r = ((x + 1'u32) and -LT(x, 10)) or
|
||||
((y + 11'u32) and -LT(y, 6)) or
|
||||
((z + 11'u32) and -LT(z, 6))
|
||||
int(r) - 1
|
||||
|
||||
proc getChunkSize(buffer: openarray[byte]): uint64 =
|
||||
# We using `uint64` representation, but allow only 2^32 chunk size,
|
||||
# ChunkHeaderSize.
|
||||
var res = 0'u64
|
||||
for i in 0..<min(len(buffer), ChunkHeaderSize):
|
||||
let ch = buffer[i]
|
||||
if char(ch) in {'0'..'9', 'a'..'f', 'A'..'F'}:
|
||||
if ch >= byte('0') and ch <= byte('9'):
|
||||
result = (result shl 4) or uint64(ch - byte('0'))
|
||||
else:
|
||||
result = (result shl 4) or uint64((ch and 0x0F) + 9)
|
||||
let value = hexValue(buffer[i])
|
||||
if value >= 0:
|
||||
res = (res shl 4) or uint64(value)
|
||||
else:
|
||||
result = 0xFFFF_FFFF_FFFF_FFFF'u64
|
||||
res = 0xFFFF_FFFF_FFFF_FFFF'u64
|
||||
break
|
||||
res
|
||||
|
||||
proc setChunkSize(buffer: var openarray[byte], length: int64): int =
|
||||
# Store length as chunk header size (hexadecimal value) with CRLF.
|
||||
|
@ -53,7 +72,7 @@ proc setChunkSize(buffer: var openarray[byte], length: int64): int =
|
|||
buffer[0] = byte('0')
|
||||
buffer[1] = byte(0x0D)
|
||||
buffer[2] = byte(0x0A)
|
||||
result = 3
|
||||
3
|
||||
else:
|
||||
while n != 0:
|
||||
var v = length and n
|
||||
|
@ -68,161 +87,116 @@ proc setChunkSize(buffer: var openarray[byte], length: int64): int =
|
|||
i = i - 4
|
||||
buffer[c] = byte(0x0D)
|
||||
buffer[c + 1] = byte(0x0A)
|
||||
result = c + 2
|
||||
c + 2
|
||||
|
||||
proc chunkedReadLoop(stream: AsyncStreamReader) {.async.} =
|
||||
var rstream = cast[ChunkedStreamReader](stream)
|
||||
var buffer = newSeq[byte](1024)
|
||||
rstream.state = AsyncStreamState.Running
|
||||
|
||||
try:
|
||||
while true:
|
||||
while true:
|
||||
try:
|
||||
# Reading chunk size
|
||||
var ruFut1 = awaitne rstream.rsource.readUntil(addr buffer[0], 1024, CRLF)
|
||||
if ruFut1.failed():
|
||||
rstream.error = ruFut1.error
|
||||
rstream.state = AsyncStreamState.Error
|
||||
break
|
||||
|
||||
let length = ruFut1.read()
|
||||
var chunksize = getChunkSize(buffer.toOpenArray(0,
|
||||
length - len(CRLF) - 1))
|
||||
let res = await rstream.rsource.readUntil(addr buffer[0], 1024, CRLF)
|
||||
var chunksize = getChunkSize(buffer.toOpenArray(0, res - len(CRLF) - 1))
|
||||
|
||||
if chunksize == 0xFFFF_FFFF_FFFF_FFFF'u64:
|
||||
rstream.error = newProtocolError()
|
||||
rstream.error = newChunkedProtocolError()
|
||||
rstream.state = AsyncStreamState.Error
|
||||
break
|
||||
elif chunksize > 0'u64:
|
||||
while chunksize > 0'u64:
|
||||
let toRead = min(int(chunksize), rstream.buffer.bufferLen())
|
||||
var reFut2 = awaitne rstream.rsource.readExactly(
|
||||
rstream.buffer.getBuffer(), toRead)
|
||||
if reFut2.failed():
|
||||
rstream.error = reFut2.error
|
||||
rstream.state = AsyncStreamState.Error
|
||||
break
|
||||
|
||||
await rstream.rsource.readExactly(rstream.buffer.getBuffer(), toRead)
|
||||
rstream.buffer.update(toRead)
|
||||
await rstream.buffer.transfer()
|
||||
chunksize = chunksize - uint64(toRead)
|
||||
|
||||
if rstream.state != AsyncStreamState.Running:
|
||||
break
|
||||
if rstream.state == AsyncStreamState.Running:
|
||||
# Reading chunk trailing CRLF
|
||||
await rstream.rsource.readExactly(addr buffer[0], 2)
|
||||
|
||||
# Reading chunk trailing CRLF
|
||||
var reFut3 = awaitne rstream.rsource.readExactly(addr buffer[0], 2)
|
||||
if reFut3.failed():
|
||||
rstream.error = reFut3.error
|
||||
rstream.state = AsyncStreamState.Error
|
||||
break
|
||||
|
||||
if buffer[0] != CRLF[0] or buffer[1] != CRLF[1]:
|
||||
rstream.error = newProtocolError()
|
||||
rstream.state = AsyncStreamState.Error
|
||||
break
|
||||
if buffer[0] != CRLF[0] or buffer[1] != CRLF[1]:
|
||||
rstream.error = newChunkedProtocolError()
|
||||
rstream.state = AsyncStreamState.Error
|
||||
else:
|
||||
# Reading trailing line for last chunk
|
||||
var ruFut4 = awaitne rstream.rsource.readUntil(addr buffer[0],
|
||||
len(buffer), CRLF)
|
||||
if ruFut4.failed():
|
||||
rstream.error = ruFut4.error
|
||||
rstream.state = AsyncStreamState.Error
|
||||
break
|
||||
|
||||
discard await rstream.rsource.readUntil(addr buffer[0],
|
||||
len(buffer), CRLF)
|
||||
rstream.state = AsyncStreamState.Finished
|
||||
await rstream.buffer.transfer()
|
||||
break
|
||||
except CancelledError:
|
||||
rstream.state = AsyncStreamState.Stopped
|
||||
except AsyncStreamIncompleteError:
|
||||
rstream.state = AsyncStreamState.Error
|
||||
rstream.error = newChunkedIncompleteError()
|
||||
except AsyncStreamReadError as exc:
|
||||
rstream.state = AsyncStreamState.Error
|
||||
rstream.error = exc
|
||||
|
||||
except CancelledError:
|
||||
rstream.state = AsyncStreamState.Stopped
|
||||
finally:
|
||||
if rstream.state in {AsyncStreamState.Stopped, AsyncStreamState.Error}:
|
||||
# We need to notify consumer about error/close, but we do not care about
|
||||
# incoming data anymore.
|
||||
rstream.buffer.forget()
|
||||
break
|
||||
|
||||
proc chunkedWriteLoop(stream: AsyncStreamWriter) {.async.} =
|
||||
var wstream = cast[ChunkedStreamWriter](stream)
|
||||
var buffer: array[16, byte]
|
||||
var wFut1, wFut2: Future[void]
|
||||
var error: ref Exception
|
||||
var error: ref AsyncStreamError
|
||||
wstream.state = AsyncStreamState.Running
|
||||
|
||||
try:
|
||||
while true:
|
||||
# Getting new item from stream's queue.
|
||||
var item = await wstream.queue.get()
|
||||
while true:
|
||||
var item: WriteItem
|
||||
# Getting new item from stream's queue.
|
||||
try:
|
||||
item = await wstream.queue.get()
|
||||
# `item.size == 0` is marker of stream finish, while `item.size != 0` is
|
||||
# data's marker.
|
||||
if item.size > 0:
|
||||
let length = setChunkSize(buffer, int64(item.size))
|
||||
# Writing chunk header <length>CRLF.
|
||||
wFut1 = awaitne wstream.wsource.write(addr buffer[0], length)
|
||||
if wFut1.failed():
|
||||
error = wFut1.error
|
||||
item.future.fail(error)
|
||||
continue
|
||||
|
||||
await wstream.wsource.write(addr buffer[0], length)
|
||||
# Writing chunk data.
|
||||
if item.kind == Pointer:
|
||||
wFut2 = awaitne wstream.wsource.write(item.data1, item.size)
|
||||
elif item.kind == Sequence:
|
||||
wFut2 = awaitne wstream.wsource.write(addr item.data2[0], item.size)
|
||||
elif item.kind == String:
|
||||
wFut2 = awaitne wstream.wsource.write(addr item.data3[0], item.size)
|
||||
if wFut2.failed():
|
||||
error = wFut2.error
|
||||
item.future.fail(error)
|
||||
continue
|
||||
|
||||
case item.kind
|
||||
of WriteType.Pointer:
|
||||
await wstream.wsource.write(item.data1, item.size)
|
||||
of WriteType.Sequence:
|
||||
await wstream.wsource.write(addr item.data2[0], item.size)
|
||||
of WriteType.String:
|
||||
await wstream.wsource.write(addr item.data3[0], item.size)
|
||||
# Writing chunk footer CRLF.
|
||||
var wFut3 = awaitne wstream.wsource.write(CRLF)
|
||||
if wFut3.failed():
|
||||
error = wFut3.error
|
||||
item.future.fail(error)
|
||||
continue
|
||||
|
||||
await wstream.wsource.write(CRLF)
|
||||
# Everything is fine, completing queue item's future.
|
||||
item.future.complete()
|
||||
else:
|
||||
let length = setChunkSize(buffer, 0'i64)
|
||||
|
||||
# Write finish chunk `0`.
|
||||
wFut1 = awaitne wstream.wsource.write(addr buffer[0], length)
|
||||
if wFut1.failed():
|
||||
error = wFut1.error
|
||||
item.future.fail(error)
|
||||
# We break here, because this is last chunk
|
||||
break
|
||||
|
||||
await wstream.wsource.write(addr buffer[0], length)
|
||||
# Write trailing CRLF.
|
||||
wFut2 = awaitne wstream.wsource.write(CRLF)
|
||||
if wFut2.failed():
|
||||
error = wFut2.error
|
||||
item.future.fail(error)
|
||||
# We break here, because this is last chunk
|
||||
break
|
||||
|
||||
await wstream.wsource.write(CRLF)
|
||||
# Everything is fine, completing queue item's future.
|
||||
item.future.complete()
|
||||
|
||||
# Set stream state to Finished.
|
||||
wstream.state = AsyncStreamState.Finished
|
||||
break
|
||||
except CancelledError:
|
||||
wstream.state = AsyncStreamState.Stopped
|
||||
finally:
|
||||
if wstream.state == AsyncStreamState.Stopped:
|
||||
while len(wstream.queue) > 0:
|
||||
let item = wstream.queue.popFirstNoWait()
|
||||
if not(item.future.finished()):
|
||||
item.future.complete()
|
||||
elif wstream.state == AsyncStreamState.Error:
|
||||
while len(wstream.queue) > 0:
|
||||
let item = wstream.queue.popFirstNoWait()
|
||||
if not(item.future.finished()):
|
||||
if not isNil(error):
|
||||
except CancelledError:
|
||||
wstream.state = AsyncStreamState.Stopped
|
||||
error = newAsyncStreamUseClosedError()
|
||||
except AsyncStreamError as exc:
|
||||
wstream.state = AsyncStreamState.Error
|
||||
error = exc
|
||||
|
||||
if wstream.state != AsyncStreamState.Running:
|
||||
if wstream.state == AsyncStreamState.Finished:
|
||||
error = newAsyncStreamUseClosedError()
|
||||
else:
|
||||
if not(isNil(item.future)):
|
||||
if not(item.future.finished()):
|
||||
item.future.fail(error)
|
||||
while not(wstream.queue.empty()):
|
||||
let pitem = wstream.queue.popFirstNoWait()
|
||||
if not(pitem.future.finished()):
|
||||
pitem.future.fail(error)
|
||||
break
|
||||
|
||||
proc init*[T](child: ChunkedStreamReader, rsource: AsyncStreamReader,
|
||||
bufferSize = ChunkBufferSize, udata: ref T) =
|
||||
|
@ -236,14 +210,16 @@ proc init*(child: ChunkedStreamReader, rsource: AsyncStreamReader,
|
|||
proc newChunkedStreamReader*[T](rsource: AsyncStreamReader,
|
||||
bufferSize = AsyncStreamDefaultBufferSize,
|
||||
udata: ref T): ChunkedStreamReader =
|
||||
result = new ChunkedStreamReader
|
||||
result.init(rsource, bufferSize, udata)
|
||||
var res = ChunkedStreamReader()
|
||||
res.init(rsource, bufferSize, udata)
|
||||
res
|
||||
|
||||
proc newChunkedStreamReader*(rsource: AsyncStreamReader,
|
||||
bufferSize = AsyncStreamDefaultBufferSize,
|
||||
): ChunkedStreamReader =
|
||||
result = new ChunkedStreamReader
|
||||
result.init(rsource, bufferSize)
|
||||
var res = ChunkedStreamReader()
|
||||
res.init(rsource, bufferSize)
|
||||
res
|
||||
|
||||
proc init*[T](child: ChunkedStreamWriter, wsource: AsyncStreamWriter,
|
||||
queueSize = AsyncStreamDefaultQueueSize, udata: ref T) =
|
||||
|
@ -257,11 +233,13 @@ proc init*(child: ChunkedStreamWriter, wsource: AsyncStreamWriter,
|
|||
proc newChunkedStreamWriter*[T](wsource: AsyncStreamWriter,
|
||||
queueSize = AsyncStreamDefaultQueueSize,
|
||||
udata: ref T): ChunkedStreamWriter =
|
||||
result = new ChunkedStreamWriter
|
||||
result.init(wsource, queueSize, udata)
|
||||
var res = ChunkedStreamWriter()
|
||||
res.init(wsource, queueSize, udata)
|
||||
res
|
||||
|
||||
proc newChunkedStreamWriter*(wsource: AsyncStreamWriter,
|
||||
queueSize = AsyncStreamDefaultQueueSize,
|
||||
): ChunkedStreamWriter =
|
||||
result = new ChunkedStreamWriter
|
||||
result.init(wsource, queueSize)
|
||||
var res = ChunkedStreamWriter()
|
||||
res.init(wsource, queueSize)
|
||||
res
|
||||
|
|
|
@ -89,11 +89,11 @@ type
|
|||
|
||||
SomeTLSStreamType* = TLSStreamReader|TLSStreamWriter|TLSAsyncStream
|
||||
|
||||
TLSStreamError* = object of CatchableError
|
||||
TLSStreamError* = object of AsyncStreamError
|
||||
TLSStreamProtocolError* = object of TLSStreamError
|
||||
errCode*: int
|
||||
|
||||
template newTLSStreamProtocolError[T](message: T): ref Exception =
|
||||
template newTLSStreamProtocolError[T](message: T): ref TLSStreamProtocolError =
|
||||
var msg = ""
|
||||
var code = 0
|
||||
when T is string:
|
||||
|
@ -110,13 +110,13 @@ template newTLSStreamProtocolError[T](message: T): ref Exception =
|
|||
err.errCode = code
|
||||
err
|
||||
|
||||
proc raiseTLSStreamProtoError*[T](message: T) =
|
||||
template raiseTLSStreamProtoError*[T](message: T) =
|
||||
raise newTLSStreamProtocolError(message)
|
||||
|
||||
proc tlsWriteLoop(stream: AsyncStreamWriter) {.async.} =
|
||||
var wstream = cast[TLSStreamWriter](stream)
|
||||
var engine: ptr SslEngineContext
|
||||
var error: ref Exception
|
||||
var error: ref AsyncStreamError
|
||||
|
||||
if wstream.kind == TLSStreamKind.Server:
|
||||
engine = addr wstream.scontext.eng
|
||||
|
@ -125,86 +125,77 @@ proc tlsWriteLoop(stream: AsyncStreamWriter) {.async.} =
|
|||
|
||||
wstream.state = AsyncStreamState.Running
|
||||
|
||||
try:
|
||||
var length: uint
|
||||
while true:
|
||||
while true:
|
||||
var item: WriteItem
|
||||
try:
|
||||
var state = engine.sslEngineCurrentState()
|
||||
|
||||
if (state and SSL_CLOSED) == SSL_CLOSED:
|
||||
wstream.state = AsyncStreamState.Finished
|
||||
break
|
||||
|
||||
if (state and (SSL_RECVREC or SSL_RECVAPP)) != 0:
|
||||
if not(wstream.switchToReader.isSet()):
|
||||
wstream.switchToReader.fire()
|
||||
|
||||
if (state and (SSL_SENDREC or SSL_SENDAPP)) == 0:
|
||||
await wstream.switchToWriter.wait()
|
||||
wstream.switchToWriter.clear()
|
||||
# We need to refresh `state` because we just returned from readerLoop.
|
||||
continue
|
||||
|
||||
if (state and SSL_SENDREC) == SSL_SENDREC:
|
||||
# TLS record needs to be sent over stream.
|
||||
length = 0'u
|
||||
var buf = sslEngineSendrecBuf(engine, length)
|
||||
doAssert(length != 0 and not isNil(buf))
|
||||
var fut = awaitne wstream.wsource.write(buf, int(length))
|
||||
if fut.cancelled():
|
||||
raise fut.error
|
||||
elif fut.failed():
|
||||
error = fut.error
|
||||
break
|
||||
sslEngineSendrecAck(engine, length)
|
||||
continue
|
||||
|
||||
if (state and SSL_SENDAPP) == SSL_SENDAPP:
|
||||
# Application data can be sent over stream.
|
||||
if not(wstream.handshaked):
|
||||
wstream.stream.reader.handshaked = true
|
||||
wstream.handshaked = true
|
||||
if not(isNil(wstream.handshakeFut)):
|
||||
wstream.handshakeFut.complete()
|
||||
|
||||
var item = await wstream.queue.get()
|
||||
if item.size > 0:
|
||||
length = 0'u
|
||||
var buf = sslEngineSendappBuf(engine, length)
|
||||
let toWrite = min(int(length), item.size)
|
||||
copyOut(buf, item, toWrite)
|
||||
if int(length) >= item.size:
|
||||
# BearSSL is ready to accept whole item size.
|
||||
sslEngineSendappAck(engine, uint(item.size))
|
||||
sslEngineFlush(engine, 0)
|
||||
item.future.complete()
|
||||
else:
|
||||
# BearSSL is not ready to accept whole item, so we will send only
|
||||
# part of item and adjust offset.
|
||||
item.offset = item.offset + int(length)
|
||||
item.size = item.size - int(length)
|
||||
wstream.queue.addFirstNoWait(item)
|
||||
sslEngineSendappAck(engine, length)
|
||||
continue
|
||||
else:
|
||||
if (state and (SSL_RECVREC or SSL_RECVAPP)) != 0:
|
||||
if not(wstream.switchToReader.isSet()):
|
||||
wstream.switchToReader.fire()
|
||||
if (state and (SSL_SENDREC or SSL_SENDAPP)) == 0:
|
||||
await wstream.switchToWriter.wait()
|
||||
wstream.switchToWriter.clear()
|
||||
# We need to refresh `state` because we just returned from readerLoop.
|
||||
else:
|
||||
# Zero length item means finish
|
||||
wstream.state = AsyncStreamState.Finished
|
||||
break
|
||||
if (state and SSL_SENDREC) == SSL_SENDREC:
|
||||
# TLS record needs to be sent over stream.
|
||||
var length = 0'u
|
||||
var buf = sslEngineSendrecBuf(engine, length)
|
||||
doAssert(length != 0 and not isNil(buf))
|
||||
await wstream.wsource.write(buf, int(length))
|
||||
sslEngineSendrecAck(engine, length)
|
||||
elif (state and SSL_SENDAPP) == SSL_SENDAPP:
|
||||
# Application data can be sent over stream.
|
||||
if not(wstream.handshaked):
|
||||
wstream.stream.reader.handshaked = true
|
||||
wstream.handshaked = true
|
||||
if not(isNil(wstream.handshakeFut)):
|
||||
wstream.handshakeFut.complete()
|
||||
item = await wstream.queue.get()
|
||||
if item.size > 0:
|
||||
var length = 0'u
|
||||
var buf = sslEngineSendappBuf(engine, length)
|
||||
let toWrite = min(int(length), item.size)
|
||||
copyOut(buf, item, toWrite)
|
||||
if int(length) >= item.size:
|
||||
# BearSSL is ready to accept whole item size.
|
||||
sslEngineSendappAck(engine, uint(item.size))
|
||||
sslEngineFlush(engine, 0)
|
||||
item.future.complete()
|
||||
else:
|
||||
# BearSSL is not ready to accept whole item, so we will send
|
||||
# only part of item and adjust offset.
|
||||
item.offset = item.offset + int(length)
|
||||
item.size = item.size - int(length)
|
||||
wstream.queue.addFirstNoWait(item)
|
||||
sslEngineSendappAck(engine, length)
|
||||
else:
|
||||
# Zero length item means finish, so we going to trigger TLS
|
||||
# closure protocol.
|
||||
sslEngineClose(engine)
|
||||
except CancelledError:
|
||||
wstream.state = AsyncStreamState.Stopped
|
||||
error = newAsyncStreamUseClosedError()
|
||||
except AsyncStreamError as exc:
|
||||
wstream.state = AsyncStreamState.Error
|
||||
error = exc
|
||||
|
||||
except CancelledError:
|
||||
wstream.state = AsyncStreamState.Stopped
|
||||
|
||||
finally:
|
||||
if wstream.state == AsyncStreamState.Stopped:
|
||||
while len(wstream.queue) > 0:
|
||||
let item = wstream.queue.popFirstNoWait()
|
||||
if not(item.future.finished()):
|
||||
item.future.complete()
|
||||
elif wstream.state == AsyncStreamState.Error:
|
||||
while len(wstream.queue) > 0:
|
||||
let item = wstream.queue.popFirstNoWait()
|
||||
if not(item.future.finished()):
|
||||
item.future.fail(error)
|
||||
wstream.stream = nil
|
||||
if wstream.state != AsyncStreamState.Running:
|
||||
if wstream.state == AsyncStreamState.Finished:
|
||||
error = newAsyncStreamUseClosedError()
|
||||
else:
|
||||
if not(isNil(item.future)):
|
||||
if not(item.future.finished()):
|
||||
item.future.fail(error)
|
||||
while not(wstream.queue.empty()):
|
||||
let pitem = wstream.queue.popFirstNoWait()
|
||||
if not(pitem.future.finished()):
|
||||
pitem.future.fail(error)
|
||||
wstream.stream = nil
|
||||
break
|
||||
|
||||
proc tlsReadLoop(stream: AsyncStreamReader) {.async.} =
|
||||
var rstream = cast[TLSStreamReader](stream)
|
||||
|
@ -217,72 +208,61 @@ proc tlsReadLoop(stream: AsyncStreamReader) {.async.} =
|
|||
|
||||
rstream.state = AsyncStreamState.Running
|
||||
|
||||
try:
|
||||
var length: uint
|
||||
while true:
|
||||
while true:
|
||||
try:
|
||||
var state = engine.sslEngineCurrentState()
|
||||
if (state and SSL_CLOSED) == SSL_CLOSED:
|
||||
let err = engine.sslEngineLastError()
|
||||
if err != 0:
|
||||
raise newTLSStreamProtocolError(err)
|
||||
rstream.state = AsyncStreamState.Stopped
|
||||
break
|
||||
|
||||
if (state and (SSL_SENDREC or SSL_SENDAPP)) != 0:
|
||||
if not(rstream.switchToWriter.isSet()):
|
||||
rstream.switchToWriter.fire()
|
||||
|
||||
if (state and (SSL_RECVREC or SSL_RECVAPP)) == 0:
|
||||
await rstream.switchToReader.wait()
|
||||
rstream.switchToReader.clear()
|
||||
# We need to refresh `state` because we just returned from writerLoop.
|
||||
continue
|
||||
|
||||
if (state and SSL_RECVREC) == SSL_RECVREC:
|
||||
# TLS records required for further processing
|
||||
length = 0'u
|
||||
var buf = sslEngineRecvrecBuf(engine, length)
|
||||
let res = await rstream.rsource.readOnce(buf, int(length))
|
||||
if res > 0:
|
||||
sslEngineRecvrecAck(engine, uint(res))
|
||||
continue
|
||||
rstream.error = newTLSStreamProtocolError(err)
|
||||
rstream.state = AsyncStreamState.Error
|
||||
else:
|
||||
rstream.state = AsyncStreamState.Finished
|
||||
break
|
||||
else:
|
||||
if (state and (SSL_SENDREC or SSL_SENDAPP)) != 0:
|
||||
if not(rstream.switchToWriter.isSet()):
|
||||
rstream.switchToWriter.fire()
|
||||
if (state and (SSL_RECVREC or SSL_RECVAPP)) == 0:
|
||||
await rstream.switchToReader.wait()
|
||||
rstream.switchToReader.clear()
|
||||
# We need to refresh `state` because we just returned from writerLoop.
|
||||
else:
|
||||
if (state and SSL_RECVREC) == SSL_RECVREC:
|
||||
# TLS records required for further processing
|
||||
var length = 0'u
|
||||
var buf = sslEngineRecvrecBuf(engine, length)
|
||||
let res = await rstream.rsource.readOnce(buf, int(length))
|
||||
if res > 0:
|
||||
sslEngineRecvrecAck(engine, uint(res))
|
||||
else:
|
||||
# readOnce() returns `0` if stream is at EOF, so we initiate TLS
|
||||
# closure procedure.
|
||||
sslEngineClose(engine)
|
||||
elif (state and SSL_RECVAPP) == SSL_RECVAPP:
|
||||
# Application data can be recovered.
|
||||
var length = 0'u
|
||||
var buf = sslEngineRecvappBuf(engine, length)
|
||||
await upload(addr rstream.buffer, buf, int(length))
|
||||
sslEngineRecvappAck(engine, length)
|
||||
except CancelledError:
|
||||
rstream.state = AsyncStreamState.Stopped
|
||||
except AsyncStreamError as exc:
|
||||
rstream.error = exc
|
||||
rstream.state = AsyncStreamState.Error
|
||||
if not(rstream.handshaked):
|
||||
rstream.handshaked = true
|
||||
rstream.stream.writer.handshaked = true
|
||||
if not(isNil(rstream.handshakeFut)):
|
||||
rstream.handshakeFut.fail(rstream.error)
|
||||
rstream.switchToWriter.fire()
|
||||
|
||||
if (state and SSL_RECVAPP) == SSL_RECVAPP:
|
||||
# Application data can be recovered.
|
||||
length = 0'u
|
||||
var buf = sslEngineRecvappBuf(engine, length)
|
||||
await upload(addr rstream.buffer, buf, int(length))
|
||||
sslEngineRecvappAck(engine, length)
|
||||
continue
|
||||
|
||||
except CancelledError:
|
||||
rstream.state = AsyncStreamState.Stopped
|
||||
except TLSStreamProtocolError as exc:
|
||||
rstream.error = exc
|
||||
rstream.state = AsyncStreamState.Error
|
||||
if not(rstream.handshaked):
|
||||
rstream.handshaked = true
|
||||
rstream.stream.writer.handshaked = true
|
||||
if not(isNil(rstream.handshakeFut)):
|
||||
rstream.handshakeFut.fail(rstream.error)
|
||||
rstream.switchToWriter.fire()
|
||||
except AsyncStreamReadError as exc:
|
||||
rstream.error = exc
|
||||
rstream.state = AsyncStreamState.Error
|
||||
if not(rstream.handshaked):
|
||||
rstream.handshaked = true
|
||||
rstream.stream.writer.handshaked = true
|
||||
if not(isNil(rstream.handshakeFut)):
|
||||
rstream.handshakeFut.fail(rstream.error)
|
||||
rstream.switchToWriter.fire()
|
||||
finally:
|
||||
# Perform TLS cleanup procedure
|
||||
sslEngineClose(engine)
|
||||
rstream.buffer.forget()
|
||||
rstream.stream = nil
|
||||
if rstream.state != AsyncStreamState.Running:
|
||||
# Perform TLS cleanup procedure
|
||||
if rstream.state != AsyncStreamState.Finished:
|
||||
sslEngineClose(engine)
|
||||
rstream.buffer.forget()
|
||||
rstream.stream = nil
|
||||
break
|
||||
|
||||
proc getSignerAlgo(xc: X509Certificate): int =
|
||||
## Get certificate's signing algorithm.
|
||||
|
@ -291,9 +271,9 @@ proc getSignerAlgo(xc: X509Certificate): int =
|
|||
x509DecoderPush(addr dc, xc.data, xc.dataLen)
|
||||
let err = x509DecoderLastError(addr dc)
|
||||
if err != 0:
|
||||
result = -1
|
||||
-1
|
||||
else:
|
||||
result = int(x509DecoderGetSignerKeyType(addr dc))
|
||||
int(x509DecoderGetSignerKeyType(addr dc))
|
||||
|
||||
proc newTLSClientAsyncStream*(rsource: AsyncStreamReader,
|
||||
wsource: AsyncStreamWriter,
|
||||
|
@ -318,54 +298,59 @@ proc newTLSClientAsyncStream*(rsource: AsyncStreamReader,
|
|||
## ``minVersion`` of bigger then ``maxVersion`` you will get an error.
|
||||
##
|
||||
## ``flags`` - custom TLS connection flags.
|
||||
result = new TLSAsyncStream
|
||||
var reader = TLSStreamReader(kind: TLSStreamKind.Client)
|
||||
var writer = TLSStreamWriter(kind: TLSStreamKind.Client)
|
||||
var switchToWriter = newAsyncEvent()
|
||||
var switchToReader = newAsyncEvent()
|
||||
reader.stream = result
|
||||
writer.stream = result
|
||||
reader.switchToReader = switchToReader
|
||||
reader.switchToWriter = switchToWriter
|
||||
writer.switchToReader = switchToReader
|
||||
writer.switchToWriter = switchToWriter
|
||||
result.reader = reader
|
||||
result.writer = writer
|
||||
reader.ccontext = addr result.ccontext
|
||||
writer.ccontext = addr result.ccontext
|
||||
let switchToWriter = newAsyncEvent()
|
||||
let switchToReader = newAsyncEvent()
|
||||
var res = TLSAsyncStream()
|
||||
var reader = TLSStreamReader(
|
||||
kind: TLSStreamKind.Client,
|
||||
stream: res,
|
||||
switchToReader: switchToReader,
|
||||
switchToWriter: switchToWriter,
|
||||
ccontext: addr res.ccontext
|
||||
)
|
||||
var writer = TLSStreamWriter(
|
||||
kind: TLSStreamKind.Client,
|
||||
stream: res,
|
||||
switchToReader: switchToReader,
|
||||
switchToWriter: switchToWriter,
|
||||
ccontext: addr res.ccontext
|
||||
)
|
||||
res.reader = reader
|
||||
res.writer = writer
|
||||
|
||||
if TLSFlags.NoVerifyHost in flags:
|
||||
sslClientInitFull(addr result.ccontext, addr result.x509, nil, 0)
|
||||
initNoAnchor(addr result.xwc, addr result.x509.vtable)
|
||||
sslEngineSetX509(addr result.ccontext.eng, addr result.xwc.vtable)
|
||||
sslClientInitFull(addr res.ccontext, addr res.x509, nil, 0)
|
||||
initNoAnchor(addr res.xwc, addr res.x509.vtable)
|
||||
sslEngineSetX509(addr res.ccontext.eng, addr res.xwc.vtable)
|
||||
else:
|
||||
sslClientInitFull(addr result.ccontext, addr result.x509,
|
||||
sslClientInitFull(addr res.ccontext, addr res.x509,
|
||||
unsafeAddr MozillaTrustAnchors[0],
|
||||
len(MozillaTrustAnchors))
|
||||
|
||||
let size = max(SSL_BUFSIZE_BIDI, bufferSize)
|
||||
result.sbuffer = newSeq[byte](size)
|
||||
sslEngineSetBuffer(addr result.ccontext.eng, addr result.sbuffer[0],
|
||||
uint(len(result.sbuffer)), 1)
|
||||
sslEngineSetVersions(addr result.ccontext.eng, uint16(minVersion),
|
||||
res.sbuffer = newSeq[byte](size)
|
||||
sslEngineSetBuffer(addr res.ccontext.eng, addr res.sbuffer[0],
|
||||
uint(len(res.sbuffer)), 1)
|
||||
sslEngineSetVersions(addr res.ccontext.eng, uint16(minVersion),
|
||||
uint16(maxVersion))
|
||||
|
||||
if TLSFlags.NoVerifyServerName in flags:
|
||||
let err = sslClientReset(addr result.ccontext, "", 0)
|
||||
let err = sslClientReset(addr res.ccontext, "", 0)
|
||||
if err == 0:
|
||||
raise newException(TLSStreamError, "Could not initialize TLS layer")
|
||||
else:
|
||||
if len(serverName) == 0:
|
||||
raise newException(TLSStreamError, "serverName must not be empty string")
|
||||
|
||||
let err = sslClientReset(addr result.ccontext, serverName, 0)
|
||||
let err = sslClientReset(addr res.ccontext, serverName, 0)
|
||||
if err == 0:
|
||||
raise newException(TLSStreamError, "Could not initialize TLS layer")
|
||||
|
||||
init(cast[AsyncStreamWriter](result.writer), wsource, tlsWriteLoop,
|
||||
init(cast[AsyncStreamWriter](res.writer), wsource, tlsWriteLoop,
|
||||
bufferSize)
|
||||
init(cast[AsyncStreamReader](result.reader), rsource, tlsReadLoop,
|
||||
init(cast[AsyncStreamReader](res.reader), rsource, tlsReadLoop,
|
||||
bufferSize)
|
||||
res
|
||||
|
||||
proc newTLSServerAsyncStream*(rsource: AsyncStreamReader,
|
||||
wsource: AsyncStreamWriter,
|
||||
|
@ -395,98 +380,104 @@ proc newTLSServerAsyncStream*(rsource: AsyncStreamReader,
|
|||
if isNil(certificate) or len(certificate.certs) == 0:
|
||||
raiseTLSStreamProtoError("Incorrect certificate")
|
||||
|
||||
result = new TLSAsyncStream
|
||||
var reader = TLSStreamReader(kind: TLSStreamKind.Server)
|
||||
var writer = TLSStreamWriter(kind: TLSStreamKind.Server)
|
||||
var switchToWriter = newAsyncEvent()
|
||||
var switchToReader = newAsyncEvent()
|
||||
reader.stream = result
|
||||
writer.stream = result
|
||||
reader.switchToReader = switchToReader
|
||||
reader.switchToWriter = switchToWriter
|
||||
writer.switchToReader = switchToReader
|
||||
writer.switchToWriter = switchToWriter
|
||||
result.reader = reader
|
||||
result.writer = writer
|
||||
reader.scontext = addr result.scontext
|
||||
writer.scontext = addr result.scontext
|
||||
let switchToWriter = newAsyncEvent()
|
||||
let switchToReader = newAsyncEvent()
|
||||
|
||||
var res = TLSAsyncStream()
|
||||
var reader = TLSStreamReader(
|
||||
kind: TLSStreamKind.Server,
|
||||
stream: res,
|
||||
switchToReader: switchToReader,
|
||||
switchToWriter: switchToWriter,
|
||||
scontext: addr res.scontext
|
||||
)
|
||||
var writer = TLSStreamWriter(
|
||||
kind: TLSStreamKind.Server,
|
||||
stream: res,
|
||||
switchToReader: switchToReader,
|
||||
switchToWriter: switchToWriter,
|
||||
scontext: addr res.scontext
|
||||
)
|
||||
res.reader = reader
|
||||
res.writer = writer
|
||||
|
||||
if privateKey.kind == TLSKeyType.EC:
|
||||
let algo = getSignerAlgo(certificate.certs[0])
|
||||
if algo == -1:
|
||||
raiseTLSStreamProtoError("Could not decode certificate")
|
||||
sslServerInitFullEc(addr result.scontext, addr certificate.certs[0],
|
||||
sslServerInitFullEc(addr res.scontext, addr certificate.certs[0],
|
||||
len(certificate.certs), cuint(algo),
|
||||
addr privateKey.eckey)
|
||||
elif privateKey.kind == TLSKeyType.RSA:
|
||||
sslServerInitFullRsa(addr result.scontext, addr certificate.certs[0],
|
||||
sslServerInitFullRsa(addr res.scontext, addr certificate.certs[0],
|
||||
len(certificate.certs), addr privateKey.rsakey)
|
||||
|
||||
let size = max(SSL_BUFSIZE_BIDI, bufferSize)
|
||||
result.sbuffer = newSeq[byte](size)
|
||||
sslEngineSetBuffer(addr result.scontext.eng, addr result.sbuffer[0],
|
||||
uint(len(result.sbuffer)), 1)
|
||||
sslEngineSetVersions(addr result.scontext.eng, uint16(minVersion),
|
||||
res.sbuffer = newSeq[byte](size)
|
||||
sslEngineSetBuffer(addr res.scontext.eng, addr res.sbuffer[0],
|
||||
uint(len(res.sbuffer)), 1)
|
||||
sslEngineSetVersions(addr res.scontext.eng, uint16(minVersion),
|
||||
uint16(maxVersion))
|
||||
|
||||
if not isNil(cache):
|
||||
sslServerSetCache(addr result.scontext, addr cache.context.vtable)
|
||||
sslServerSetCache(addr res.scontext, addr cache.context.vtable)
|
||||
|
||||
if TLSFlags.EnforceServerPref in flags:
|
||||
sslEngineAddFlags(addr result.scontext.eng, OPT_ENFORCE_SERVER_PREFERENCES)
|
||||
sslEngineAddFlags(addr res.scontext.eng, OPT_ENFORCE_SERVER_PREFERENCES)
|
||||
if TLSFlags.NoRenegotiation in flags:
|
||||
sslEngineAddFlags(addr result.scontext.eng, OPT_NO_RENEGOTIATION)
|
||||
sslEngineAddFlags(addr res.scontext.eng, OPT_NO_RENEGOTIATION)
|
||||
if TLSFlags.TolerateNoClientAuth in flags:
|
||||
sslEngineAddFlags(addr result.scontext.eng, OPT_TOLERATE_NO_CLIENT_AUTH)
|
||||
sslEngineAddFlags(addr res.scontext.eng, OPT_TOLERATE_NO_CLIENT_AUTH)
|
||||
if TLSFlags.FailOnAlpnMismatch in flags:
|
||||
sslEngineAddFlags(addr result.scontext.eng, OPT_FAIL_ON_ALPN_MISMATCH)
|
||||
sslEngineAddFlags(addr res.scontext.eng, OPT_FAIL_ON_ALPN_MISMATCH)
|
||||
|
||||
let err = sslServerReset(addr result.scontext)
|
||||
let err = sslServerReset(addr res.scontext)
|
||||
if err == 0:
|
||||
raise newException(TLSStreamError, "Could not initialize TLS layer")
|
||||
|
||||
init(cast[AsyncStreamWriter](result.writer), wsource, tlsWriteLoop,
|
||||
init(cast[AsyncStreamWriter](res.writer), wsource, tlsWriteLoop,
|
||||
bufferSize)
|
||||
init(cast[AsyncStreamReader](result.reader), rsource, tlsReadLoop,
|
||||
init(cast[AsyncStreamReader](res.reader), rsource, tlsReadLoop,
|
||||
bufferSize)
|
||||
res
|
||||
|
||||
proc copyKey(src: RsaPrivateKey): TLSPrivateKey =
|
||||
## Creates copy of RsaPrivateKey ``src``.
|
||||
var offset = 0
|
||||
let keySize = src.plen + src.qlen + src.dplen + src.dqlen + src.iqlen
|
||||
result = TLSPrivateKey(kind: TLSKeyType.RSA)
|
||||
result.storage = newSeq[byte](keySize)
|
||||
copyMem(addr result.storage[offset], src.p, src.plen)
|
||||
result.rsakey.p = cast[ptr cuchar](addr result.storage[offset])
|
||||
result.rsakey.plen = src.plen
|
||||
var res = TLSPrivateKey(kind: TLSKeyType.RSA, storage: newSeq[byte](keySize))
|
||||
copyMem(addr res.storage[offset], src.p, src.plen)
|
||||
res.rsakey.p = cast[ptr cuchar](addr res.storage[offset])
|
||||
res.rsakey.plen = src.plen
|
||||
offset = offset + src.plen
|
||||
copyMem(addr result.storage[offset], src.q, src.qlen)
|
||||
result.rsakey.q = cast[ptr cuchar](addr result.storage[offset])
|
||||
result.rsakey.qlen = src.qlen
|
||||
copyMem(addr res.storage[offset], src.q, src.qlen)
|
||||
res.rsakey.q = cast[ptr cuchar](addr res.storage[offset])
|
||||
res.rsakey.qlen = src.qlen
|
||||
offset = offset + src.qlen
|
||||
copyMem(addr result.storage[offset], src.dp, src.dplen)
|
||||
result.rsakey.dp = cast[ptr cuchar](addr result.storage[offset])
|
||||
result.rsakey.dplen = src.dplen
|
||||
copyMem(addr res.storage[offset], src.dp, src.dplen)
|
||||
res.rsakey.dp = cast[ptr cuchar](addr res.storage[offset])
|
||||
res.rsakey.dplen = src.dplen
|
||||
offset = offset + src.dplen
|
||||
copyMem(addr result.storage[offset], src.dq, src.dqlen)
|
||||
result.rsakey.dq = cast[ptr cuchar](addr result.storage[offset])
|
||||
result.rsakey.dqlen = src.dqlen
|
||||
copyMem(addr res.storage[offset], src.dq, src.dqlen)
|
||||
res.rsakey.dq = cast[ptr cuchar](addr res.storage[offset])
|
||||
res.rsakey.dqlen = src.dqlen
|
||||
offset = offset + src.dqlen
|
||||
copyMem(addr result.storage[offset], src.iq, src.iqlen)
|
||||
result.rsakey.iq = cast[ptr cuchar](addr result.storage[offset])
|
||||
result.rsakey.iqlen = src.iqlen
|
||||
result.rsakey.nBitlen = src.nBitlen
|
||||
copyMem(addr res.storage[offset], src.iq, src.iqlen)
|
||||
res.rsakey.iq = cast[ptr cuchar](addr res.storage[offset])
|
||||
res.rsakey.iqlen = src.iqlen
|
||||
res.rsakey.nBitlen = src.nBitlen
|
||||
res
|
||||
|
||||
proc copyKey(src: EcPrivateKey): TLSPrivateKey =
|
||||
## Creates copy of EcPrivateKey ``src``.
|
||||
var offset = 0
|
||||
let keySize = src.xlen
|
||||
result = TLSPrivateKey(kind: TLSKeyType.EC)
|
||||
result.storage = newSeq[byte](keySize)
|
||||
copyMem(addr result.storage[offset], src.x, src.xlen)
|
||||
result.eckey.x = cast[ptr cuchar](addr result.storage[offset])
|
||||
result.eckey.xlen = src.xlen
|
||||
result.eckey.curve = src.curve
|
||||
var res = TLSPrivateKey(kind: TLSKeyType.EC, storage: newSeq[byte](keySize))
|
||||
copyMem(addr res.storage[offset], src.x, src.xlen)
|
||||
res.eckey.x = cast[ptr cuchar](addr res.storage[offset])
|
||||
res.eckey.xlen = src.xlen
|
||||
res.eckey.curve = src.curve
|
||||
res
|
||||
|
||||
proc init*(tt: typedesc[TLSPrivateKey], data: openarray[byte]): TLSPrivateKey =
|
||||
## Initialize TLS private key from array of bytes ``data``.
|
||||
|
@ -502,12 +493,14 @@ proc init*(tt: typedesc[TLSPrivateKey], data: openarray[byte]): TLSPrivateKey =
|
|||
if err != 0:
|
||||
raiseTLSStreamProtoError(err)
|
||||
let keyType = skeyDecoderKeyType(addr ctx)
|
||||
if keyType == KEYTYPE_RSA:
|
||||
result = copyKey(ctx.key.rsa)
|
||||
elif keyType == KEYTYPE_EC:
|
||||
result = copyKey(ctx.key.ec)
|
||||
else:
|
||||
raiseTLSStreamProtoError("Unknown key type (" & $keyType & ")")
|
||||
let res =
|
||||
if keyType == KEYTYPE_RSA:
|
||||
copyKey(ctx.key.rsa)
|
||||
elif keyType == KEYTYPE_EC:
|
||||
copyKey(ctx.key.ec)
|
||||
else:
|
||||
raiseTLSStreamProtoError("Unknown key type (" & $keyType & ")")
|
||||
res
|
||||
|
||||
proc pemDecode*(data: openarray[char]): seq[PEMElement] =
|
||||
## Decode PEM encoded string and get array of binary blobs.
|
||||
|
@ -515,7 +508,7 @@ proc pemDecode*(data: openarray[char]): seq[PEMElement] =
|
|||
raiseTLSStreamProtoError("Empty PEM message")
|
||||
var ctx: PemDecoderContext
|
||||
var pctx = new PEMContext
|
||||
result = newSeq[PEMElement]()
|
||||
var res = newSeq[PEMElement]()
|
||||
pemDecoderInit(addr ctx)
|
||||
|
||||
proc itemAppend(ctx: pointer, pbytes: pointer, nbytes: int) {.cdecl.} =
|
||||
|
@ -544,12 +537,13 @@ proc pemDecode*(data: openarray[char]): seq[PEMElement] =
|
|||
elif event == PEM_END_OBJ:
|
||||
if inobj:
|
||||
elem.data = pctx.data
|
||||
result.add(elem)
|
||||
res.add(elem)
|
||||
inobj = false
|
||||
else:
|
||||
break
|
||||
else:
|
||||
raiseTLSStreamProtoError("Invalid PEM encoding")
|
||||
res
|
||||
|
||||
proc init*(tt: typedesc[TLSPrivateKey], data: openarray[char]): TLSPrivateKey =
|
||||
## Initialize TLS private key from string ``data``.
|
||||
|
@ -558,13 +552,15 @@ proc init*(tt: typedesc[TLSPrivateKey], data: openarray[char]): TLSPrivateKey =
|
|||
## encoded string.
|
||||
##
|
||||
## Note that PKCS#1 PEM encoded objects are not supported.
|
||||
var res: TLSPrivateKey
|
||||
var items = pemDecode(data)
|
||||
for item in items:
|
||||
if item.name == "PRIVATE KEY":
|
||||
result = TLSPrivateKey.init(item.data)
|
||||
res = TLSPrivateKey.init(item.data)
|
||||
break
|
||||
if isNil(result):
|
||||
if isNil(res):
|
||||
raiseTLSStreamProtoError("Could not find private key")
|
||||
res
|
||||
|
||||
proc init*(tt: typedesc[TLSCertificate],
|
||||
data: openarray[char]): TLSCertificate =
|
||||
|
@ -572,32 +568,33 @@ proc init*(tt: typedesc[TLSCertificate],
|
|||
##
|
||||
## This procedure initializes array of certificates from PEM encoded string.
|
||||
var items = pemDecode(data)
|
||||
result = new TLSCertificate
|
||||
var res = TLSCertificate()
|
||||
for item in items:
|
||||
if item.name == "CERTIFICATE" and len(item.data) > 0:
|
||||
let offset = len(result.storage)
|
||||
result.storage.add(item.data)
|
||||
let offset = len(res.storage)
|
||||
res.storage.add(item.data)
|
||||
let cert = X509Certificate(
|
||||
data: cast[ptr cuchar](addr result.storage[offset]),
|
||||
data: cast[ptr cuchar](addr res.storage[offset]),
|
||||
dataLen: len(item.data)
|
||||
)
|
||||
let res = getSignerAlgo(cert)
|
||||
if res == -1:
|
||||
let ares = getSignerAlgo(cert)
|
||||
if ares == -1:
|
||||
raiseTLSStreamProtoError("Could not decode certificate")
|
||||
elif res != KEYTYPE_RSA and res != KEYTYPE_EC:
|
||||
elif ares != KEYTYPE_RSA and ares != KEYTYPE_EC:
|
||||
raiseTLSStreamProtoError("Unsupported signing key type in certificate")
|
||||
result.certs.add(cert)
|
||||
if len(result.storage) == 0:
|
||||
res.certs.add(cert)
|
||||
if len(res.storage) == 0:
|
||||
raiseTLSStreamProtoError("Could not find any certificates")
|
||||
res
|
||||
|
||||
proc init*(tt: typedesc[TLSSessionCache], size: int = 4096): TLSSessionCache =
|
||||
## Create new TLS session cache with size ``size``.
|
||||
##
|
||||
## One cached item is near 100 bytes size.
|
||||
result = new TLSSessionCache
|
||||
var rsize = min(size, 4096)
|
||||
result.storage = newSeq[byte](rsize)
|
||||
sslSessionCacheLruInit(addr result.context, addr result.storage[0], rsize)
|
||||
var res = TLSSessionCache(storage: newSeq[byte](rsize))
|
||||
sslSessionCacheLruInit(addr res.context, addr res.storage[0], rsize)
|
||||
res
|
||||
|
||||
proc handshake*(rws: SomeTLSStreamType): Future[void] =
|
||||
## Wait until initial TLS handshake will be successfully performed.
|
||||
|
@ -620,4 +617,4 @@ proc handshake*(rws: SomeTLSStreamType): Future[void] =
|
|||
else:
|
||||
rws.reader.handshakeFut = retFuture
|
||||
rws.writer.handshakeFut = retFuture
|
||||
return retFuture
|
||||
retFuture
|
||||
|
|
|
@ -6,7 +6,8 @@
|
|||
# Apache License, version 2.0, (LICENSE-APACHEv2)
|
||||
# MIT license (LICENSE-MIT)
|
||||
import unittest
|
||||
import ../chronos, ../chronos/streams/tlsstream
|
||||
import ../chronos
|
||||
import ../chronos/streams/[tlsstream, chunkstream, boundstream]
|
||||
|
||||
when defined(nimHasUsed): {.used.}
|
||||
|
||||
|
@ -553,8 +554,12 @@ suite "ChunkedStream test suite":
|
|||
try:
|
||||
var r = await rstream2.read()
|
||||
doAssert(len(r) > 0)
|
||||
except AsyncStreamReadError:
|
||||
res = true
|
||||
except ChunkedStreamIncompleteError:
|
||||
if inputstr == "100000000 \r\n1":
|
||||
res = true
|
||||
except ChunkedStreamProtocolError:
|
||||
if inputstr == "z\r\n1":
|
||||
res = true
|
||||
await rstream2.closeWait()
|
||||
await rstream.closeWait()
|
||||
await transp.closeWait()
|
||||
|
@ -663,3 +668,119 @@ suite "TLSStream test suite":
|
|||
getTracker("async.stream.writer").isLeaked() == false
|
||||
getTracker("stream.server").isLeaked() == false
|
||||
getTracker("stream.transport").isLeaked() == false
|
||||
|
||||
suite "BoundedStream test suite":
|
||||
|
||||
proc createBigMessage(size: int): seq[byte] =
|
||||
var message = "MESSAGE"
|
||||
result = newSeq[byte](size)
|
||||
for i in 0 ..< len(result):
|
||||
result[i] = byte(message[i mod len(message)])
|
||||
|
||||
for item in [100'u64, 60000'u64]:
|
||||
|
||||
proc boundedTest(address: TransportAddress, test: int,
|
||||
size: uint64): Future[bool] {.async.} =
|
||||
var clientRes = false
|
||||
var res = false
|
||||
|
||||
let messagePart = createBigMessage(int(item) div 10)
|
||||
var message: seq[byte]
|
||||
for i in 0 ..< 10:
|
||||
message.add(messagePart)
|
||||
|
||||
proc processClient(server: StreamServer,
|
||||
transp: StreamTransport) {.async.} =
|
||||
var wstream = newAsyncStreamWriter(transp)
|
||||
var wbstream = newBoundedStreamWriter(wstream, size)
|
||||
if test == 0:
|
||||
for i in 0 ..< 10:
|
||||
await wbstream.write(messagePart)
|
||||
await wbstream.finish()
|
||||
await wbstream.closeWait()
|
||||
clientRes = true
|
||||
elif test == 1:
|
||||
for i in 0 ..< 10:
|
||||
await wbstream.write(messagePart)
|
||||
try:
|
||||
await wbstream.write(messagePart)
|
||||
except BoundedStreamOverflowError:
|
||||
clientRes = true
|
||||
await wbstream.closeWait()
|
||||
elif test == 2:
|
||||
for i in 0 ..< 9:
|
||||
await wbstream.write(messagePart)
|
||||
try:
|
||||
await wbstream.finish()
|
||||
except BoundedStreamIncompleteError:
|
||||
clientRes = true
|
||||
await wbstream.closeWait()
|
||||
elif test == 3:
|
||||
for i in 0 ..< 10:
|
||||
await wbstream.write(messagePart)
|
||||
await wbstream.finish()
|
||||
await wbstream.closeWait()
|
||||
clientRes = true
|
||||
elif test == 4:
|
||||
for i in 0 ..< 9:
|
||||
await wbstream.write(messagePart)
|
||||
try:
|
||||
await wbstream.closeWait()
|
||||
except BoundedStreamIncompleteError:
|
||||
clientRes = true
|
||||
|
||||
await wstream.closeWait()
|
||||
await transp.closeWait()
|
||||
server.stop()
|
||||
server.close()
|
||||
|
||||
var server = createStreamServer(address, processClient, flags = {ReuseAddr})
|
||||
server.start()
|
||||
var conn = await connect(address)
|
||||
var rstream = newAsyncStreamReader(conn)
|
||||
var rbstream = newBoundedStreamReader(rstream, size)
|
||||
if test == 0:
|
||||
let response = await rbstream.read()
|
||||
await rbstream.closeWait()
|
||||
if response == message:
|
||||
res = true
|
||||
elif test == 1:
|
||||
let response = await rbstream.read()
|
||||
await rbstream.closeWait()
|
||||
if response == message:
|
||||
res = true
|
||||
elif test == 2:
|
||||
try:
|
||||
let response {.used.} = await rbstream.read()
|
||||
except BoundedStreamIncompleteError:
|
||||
res = true
|
||||
await rbstream.closeWait()
|
||||
elif test == 3:
|
||||
let response {.used.} = await rbstream.read(int(size) - 1)
|
||||
try:
|
||||
await rbstream.closeWait()
|
||||
except BoundedStreamIncompleteError:
|
||||
res = true
|
||||
elif test == 4:
|
||||
try:
|
||||
let response {.used.} = await rbstream.read()
|
||||
except BoundedStreamIncompleteError:
|
||||
res = true
|
||||
await rbstream.closeWait()
|
||||
|
||||
await rstream.closeWait()
|
||||
await conn.closeWait()
|
||||
await server.join()
|
||||
return (res and clientRes)
|
||||
|
||||
let address = initTAddress("127.0.0.1:48030")
|
||||
test "BoundedStream reading/writing test [" & $item & "]":
|
||||
check waitFor(boundedTest(address, 0, item)) == true
|
||||
test "BoundedStream overflow test [" & $item & "]":
|
||||
check waitFor(boundedTest(address, 1, item)) == true
|
||||
test "BoundedStream incomplete test [" & $item & "]":
|
||||
check waitFor(boundedTest(address, 2, item)) == true
|
||||
test "BoundedStream read() close test [" & $item & "]":
|
||||
check waitFor(boundedTest(address, 3, item)) == true
|
||||
test "BoundedStream write() close test [" & $item & "]":
|
||||
check waitFor(boundedTest(address, 4, item)) == true
|
||||
|
|
Loading…
Reference in New Issue