Simplification and fixes for TLSStream state machine.

This commit is contained in:
cheatfate 2021-01-22 10:36:37 +02:00 committed by zah
parent e8d2a3ca0a
commit 13eddf382d
5 changed files with 243 additions and 253 deletions

View File

@ -409,7 +409,6 @@ when defined(windows) or defined(nimdoc):
proc poll*() = proc poll*() =
## Perform single asynchronous step. ## Perform single asynchronous step.
echo "poll()"
let loop = getThreadDispatcher() let loop = getThreadDispatcher()
var curTime = Moment.now() var curTime = Moment.now()
var curTimeout = DWORD(0) var curTimeout = DWORD(0)
@ -423,8 +422,6 @@ when defined(windows) or defined(nimdoc):
var lpCompletionKey: ULONG_PTR var lpCompletionKey: ULONG_PTR
var customOverlapped: PtrCustomOverlapped var customOverlapped: PtrCustomOverlapped
echo "poll() timeout = ", curTimeout, ", len(callbacks) = ", len(loop.callbacks)
let res = getQueuedCompletionStatus( let res = getQueuedCompletionStatus(
loop.ioPort, addr lpNumberOfBytesTransferred, loop.ioPort, addr lpNumberOfBytesTransferred,
addr lpCompletionKey, cast[ptr POVERLAPPED](addr customOverlapped), addr lpCompletionKey, cast[ptr POVERLAPPED](addr customOverlapped),
@ -460,7 +457,6 @@ when defined(windows) or defined(nimdoc):
# All callbacks which will be added in process will be processed on next # All callbacks which will be added in process will be processed on next
# poll() call. # poll() call.
loop.processCallbacks() loop.processCallbacks()
echo "exit poll()"
proc closeSocket*(fd: AsyncFD, aftercb: CallbackFunc = nil) = proc closeSocket*(fd: AsyncFD, aftercb: CallbackFunc = nil) =
## Closes a socket and ensures that it is unregistered. ## Closes a socket and ensures that it is unregistered.

View File

@ -80,6 +80,7 @@ type
writerLoop*: StreamWriterLoop writerLoop*: StreamWriterLoop
state*: AsyncStreamState state*: AsyncStreamState
queue*: AsyncQueue[WriteItem] queue*: AsyncQueue[WriteItem]
error*: ref AsyncStreamError
udata: pointer udata: pointer
bytesCount*: uint64 bytesCount*: uint64
future: Future[void] future: Future[void]

View File

@ -61,8 +61,6 @@ type
of TLSStreamKind.Server: of TLSStreamKind.Server:
scontext: ptr SslServerContext scontext: ptr SslServerContext
stream*: TLSAsyncStream stream*: TLSAsyncStream
switchToReader*: AsyncEvent
switchToWriter*: AsyncEvent
handshaked*: bool handshaked*: bool
handshakeFut*: Future[void] handshakeFut*: Future[void]
@ -73,8 +71,6 @@ type
of TLSStreamKind.Server: of TLSStreamKind.Server:
scontext: ptr SslServerContext scontext: ptr SslServerContext
stream*: TLSAsyncStream stream*: TLSAsyncStream
switchToReader*: AsyncEvent
switchToWriter*: AsyncEvent
handshaked*: bool handshaked*: bool
handshakeFut*: Future[void] handshakeFut*: Future[void]
@ -86,13 +82,33 @@ type
x509*: X509MinimalContext x509*: X509MinimalContext
reader*: TLSStreamReader reader*: TLSStreamReader
writer*: TLSStreamWriter writer*: TLSStreamWriter
mainLoop*: Future[void]
SomeTLSStreamType* = TLSStreamReader|TLSStreamWriter|TLSAsyncStream SomeTLSStreamType* = TLSStreamReader|TLSStreamWriter|TLSAsyncStream
TLSStreamError* = object of AsyncStreamError TLSStreamError* = object of AsyncStreamError
TLSStreamHandshakeError* = object of TLSStreamError
TLSStreamReadError* = object of TLSStreamError
par*: ref AsyncStreamError
TLSStreamWriteError* = object of TLSStreamError
par*: ref AsyncStreamError
TLSStreamProtocolError* = object of TLSStreamError TLSStreamProtocolError* = object of TLSStreamError
errCode*: int errCode*: int
proc newTLSStreamReadError(p: ref AsyncStreamError): ref TLSStreamReadError {.
inline.} =
var w = newException(TLSStreamReadError, "Read stream failed")
w.msg = w.msg & ", originated from [" & $p.name & "] " & p.msg
w.par = p
w
proc newTLSStreamWriteError(p: ref AsyncStreamError): ref TLSStreamWriteError {.
inline.} =
var w = newException(TLSStreamWriteError, "Write stream failed")
w.msg = w.msg & ", originated from [" & $p.name & "] " & p.msg
w.par = p
w
template newTLSStreamProtocolError[T](message: T): ref TLSStreamProtocolError = template newTLSStreamProtocolError[T](message: T): ref TLSStreamProtocolError =
var msg = "" var msg = ""
var code = 0 var code = 0
@ -110,72 +126,27 @@ template newTLSStreamProtocolError[T](message: T): ref TLSStreamProtocolError =
err.errCode = code err.errCode = code
err err
proc dumpState*(state: cuint): string = proc tlsWriteRec(engine: ptr SslEngineContext,
var res = "" writer: TLSStreamWriter): Future[bool] {.async.} =
if (state and SSL_CLOSED) == SSL_CLOSED:
if len(res) > 0: res.add(", ")
res.add("SSL_CLOSED")
if (state and SSL_SENDREC) == SSL_SENDREC:
if len(res) > 0: res.add(", ")
res.add("SSL_SENDREC")
if (state and SSL_SENDAPP) == SSL_SENDAPP:
if len(res) > 0: res.add(", ")
res.add("SSL_SENDAPP")
if (state and SSL_RECVREC) == SSL_RECVREC:
if len(res) > 0: res.add(", ")
res.add("SSL_RECVREC")
if (state and SSL_RECVAPP) == SSL_RECVAPP:
if len(res) > 0: res.add(", ")
res.add("SSL_RECVAPP")
"{" & res & "}"
template raiseTLSStreamProtoError*[T](message: T) =
raise newTLSStreamProtocolError(message)
proc tlsWriteLoop(stream: AsyncStreamWriter) {.async.} =
var wstream = cast[TLSStreamWriter](stream)
var engine: ptr SslEngineContext
var error: ref AsyncStreamError
if wstream.kind == TLSStreamKind.Server:
engine = addr wstream.scontext.eng
else:
engine = addr wstream.ccontext.eng
wstream.state = AsyncStreamState.Running
while true:
var item: WriteItem
try: try:
var state = engine.sslEngineCurrentState()
echo "tlsWriteLoop(", cast[uint](wstream.stream), ") state = ", dumpState(state)
if (state and SSL_CLOSED) == SSL_CLOSED:
wstream.state = AsyncStreamState.Finished
else:
if (state and (SSL_RECVREC or SSL_RECVAPP)) != 0:
echo "tlsWriteLoop(", cast[uint](wstream.stream), ") firing switch to reader, waiters = ", len(wstream.switchToReader.waiters)
wstream.switchToReader.fire()
if (state and SSL_SENDREC) == SSL_SENDREC:
# TLS record needs to be sent over stream.
var length = 0'u var length = 0'u
var buf = sslEngineSendrecBuf(engine, length) var buf = sslEngineSendrecBuf(engine, length)
doAssert(length != 0 and not isNil(buf)) doAssert(length != 0 and not isNil(buf))
echo "tlsWriteLoop(", cast[uint](wstream.stream), ") sending record ", int(length), " bytes" await writer.wsource.write(buf, int(length))
await wstream.wsource.write(buf, int(length))
sslEngineSendrecAck(engine, length) sslEngineSendrecAck(engine, length)
echo "tlsWriteLoop(", cast[uint](wstream.stream), ") record ", int(length), " bytes sent" return true
elif (state and SSL_SENDAPP) == SSL_SENDAPP: except AsyncStreamError as exc:
# Application data can be sent over stream. writer.state = AsyncStreamState.Error
if not(wstream.handshaked): writer.error = exc
wstream.stream.reader.handshaked = true except CancelledError:
wstream.handshaked = true writer.state = AsyncStreamState.Stopped
if not(isNil(wstream.handshakeFut)):
wstream.handshakeFut.complete() return false
if not(wstream.queue.empty()):
echo "tlsWriteLoop(", cast[uint](wstream.stream), ") waiting for appdata" proc tlsWriteApp(engine: ptr SslEngineContext,
item = await wstream.queue.get() writer: TLSStreamWriter): Future[bool] {.async.} =
echo "tlsWriteLoop(", cast[uint](wstream.stream), ") sending appdata ", int(item.size), " bytes" try:
var item = await writer.queue.get()
if item.size > 0: if item.size > 0:
var length = 0'u var length = 0'u
var buf = sslEngineSendappBuf(engine, length) var buf = sslEngineSendappBuf(engine, length)
@ -186,147 +157,210 @@ proc tlsWriteLoop(stream: AsyncStreamWriter) {.async.} =
sslEngineSendappAck(engine, uint(item.size)) sslEngineSendappAck(engine, uint(item.size))
sslEngineFlush(engine, 0) sslEngineFlush(engine, 0)
item.future.complete() item.future.complete()
return true
else: else:
# BearSSL is not ready to accept whole item, so we will send # BearSSL is not ready to accept whole item, so we will send
# only part of item and adjust offset. # only part of item and adjust offset.
item.offset = item.offset + int(length) item.offset = item.offset + int(length)
item.size = item.size - int(length) item.size = item.size - int(length)
wstream.queue.addFirstNoWait(item) writer.queue.addFirstNoWait(item)
sslEngineSendappAck(engine, length) sslEngineSendappAck(engine, length)
return true
else: else:
# Zero length item means finish, so we going to trigger TLS
# closure protocol.
sslEngineClose(engine) sslEngineClose(engine)
echo "tlsWriteLoop(", cast[uint](wstream.stream), ") ", item.future.complete()
"received zero-length item, state = ", return true
dumpState(engine.sslEngineCurrentState())
else:
echo "tlsWriteLoop(", cast[uint](wstream.stream), ") empty queue"
echo "tlsWriteLoop(", cast[uint](wstream.stream), ") waiting for switch back"
await wstream.switchToWriter.wait()
echo "tlsWriteLoop(", cast[uint](wstream.stream), ") got flow after switch"
wstream.switchToWriter.clear()
else:
echo "tlsWriteLoop(", cast[uint](wstream.stream), ") waiting for switch back, switchToReader.isSet() == ", wstream.switchToReader.isSet()
await wstream.switchToWriter.wait()
echo "tlsWriteLoop(", cast[uint](wstream.stream), ") got flow after switch"
wstream.switchToWriter.clear()
except CancelledError: except CancelledError:
echo "tlsWriteLoop(", cast[uint](wstream.stream), ") received cancellation" writer.state = AsyncStreamState.Stopped
wstream.state = AsyncStreamState.Stopped
error = newAsyncStreamUseClosedError()
except AsyncStreamError as exc:
echo "tlsWriteLoop(", cast[uint](wstream.stream), ") got an exception ",
exc.msg
wstream.state = AsyncStreamState.Error
error = exc
if wstream.state != AsyncStreamState.Running: return false
if wstream.state == AsyncStreamState.Finished:
error = newAsyncStreamUseClosedError() proc tlsReadRec(engine: ptr SslEngineContext,
reader: TLSStreamReader): Future[bool] {.async.} =
try:
var length = 0'u
var buf = sslEngineRecvrecBuf(engine, length)
let res = await reader.rsource.readOnce(buf, int(length))
sslEngineRecvrecAck(engine, uint(res))
return true
except CancelledError:
reader.state = AsyncStreamState.Stopped
except AsyncStreamError as exc:
reader.state = AsyncStreamState.Error
reader.error = exc
return false
proc tlsReadApp(engine: ptr SslEngineContext,
reader: TLSStreamReader): Future[bool] {.async.} =
try:
var length = 0'u
var buf = sslEngineRecvappBuf(engine, length)
await upload(addr reader.buffer, buf, int(length))
sslEngineRecvappAck(engine, length)
return true
except CancelledError:
reader.state = AsyncStreamState.Stopped
return false
template raiseTLSStreamProtoError*[T](message: T) =
raise newTLSStreamProtocolError(message)
template readAndReset(fut: untyped) =
if fut.finished():
if fut.read():
fut = nil
continue
else: else:
if not(isNil(item.future)): fut = nil
loopState = AsyncStreamState.Error
break
proc cancelAndWait*(a, b, c, d: Future[bool]): Future[void] =
var waiting: seq[Future[bool]]
if not(isNil(a)) and not(a.finished()):
a.cancel()
waiting.add(a)
if not(isNil(b)) and not(b.finished()):
b.cancel()
waiting.add(b)
if not(isNil(c)) and not(c.finished()):
c.cancel()
waiting.add(c)
if not(isNil(d)) and not(d.finished()):
d.cancel()
waiting.add(d)
allFutures(waiting)
proc tlsLoop*(stream: TLSAsyncStream) {.async.} =
var
sendRecFut, sendAppFut: Future[bool]
recvRecFut, recvAppFut: Future[bool]
let engine =
case stream.reader.kind
of TLSStreamKind.Server:
addr stream.scontext.eng
of TLSStreamKind.Client:
addr stream.ccontext.eng
var loopState = AsyncStreamState.Running
while true:
var waiting: seq[Future[bool]]
var state = sslEngineCurrentState(engine)
if (state and SSL_CLOSED) == SSL_CLOSED:
loopState = AsyncStreamState.Finished
break
if isNil(sendRecFut):
if (state and SSL_SENDREC) == SSL_SENDREC:
sendRecFut = tlsWriteRec(engine, stream.writer)
else:
sendRecFut.readAndReset()
if isNil(sendAppFut):
if (state and SSL_SENDAPP) == SSL_SENDAPP:
# Application data can be sent over stream.
if not(stream.writer.handshaked):
stream.reader.handshaked = true
stream.writer.handshaked = true
if not(isNil(stream.writer.handshakeFut)):
stream.writer.handshakeFut.complete()
sendAppFut = tlsWriteApp(engine, stream.writer)
else:
sendAppFut.readAndReset()
if isNil(recvRecFut):
if (state and SSL_RECVREC) == SSL_RECVREC:
recvRecFut = tlsReadRec(engine, stream.reader)
else:
recvRecFut.readAndReset()
if isNil(recvAppFut):
if (state and SSL_RECVAPP) == SSL_RECVAPP:
recvAppFut = tlsReadApp(engine, stream.reader)
else:
recvAppFut.readAndReset()
if not(isNil(sendRecFut)):
waiting.add(sendRecFut)
if not(isNil(sendAppFut)):
waiting.add(sendAppFut)
if not(isNil(recvRecFut)):
waiting.add(recvRecFut)
if not(isNil(recvAppFut)):
waiting.add(recvAppFut)
if len(waiting) > 0:
try:
discard await one(waiting)
except CancelledError:
loopState = AsyncStreamState.Stopped
if loopState != AsyncStreamState.Running:
break
# Cancelling and waiting all the pending operations
await cancelAndWait(sendRecFut, sendAppFut, recvRecFut, recvAppFut)
# Calculating error
let error =
case loopState
of AsyncStreamState.Stopped:
newAsyncStreamUseClosedError()
of AsyncStreamState.Error:
if not(isNil(stream.writer.error)):
stream.writer.error
else:
newTLSStreamWriteError(stream.reader.error)
of AsyncStreamState.Finished:
let err = engine.sslEngineLastError()
if err != 0:
newTLSStreamProtocolError(err)
else:
nil
of AsyncStreamState.Running:
nil
else:
nil
# Syncing state for reader and writer
stream.writer.state = loopState
if loopState == AsyncStreamState.Error:
if isNil(stream.reader.error):
stream.reader.error = newTLSStreamReadError(error)
stream.reader.state = loopState
if not(isNil(error)):
# Completing all pending writes
while(not(stream.writer.queue.empty())):
let item = stream.writer.queue.popFirstNoWait()
if not(item.future.finished()): if not(item.future.finished()):
item.future.fail(error) item.future.fail(error)
while not(wstream.queue.empty()): # Completing handshake
let pitem = wstream.queue.popFirstNoWait() if not(stream.writer.handshaked):
if not(pitem.future.finished()): if not(isNil(stream.writer.handshakeFut)):
pitem.future.fail(error) if not(stream.writer.handshakeFut.finished()):
stream.writer.handshakeFut.fail(error)
# Completing readers
stream.reader.buffer.forget()
if not(isNil(wstream.stream.reader)): proc tlsWriteLoop(stream: AsyncStreamWriter) {.async.} =
wstream.switchToReader.fire() var wstream = cast[TLSStreamWriter](stream)
wstream.state = AsyncStreamState.Running
wstream.stream.writer = nil await stepsAsync(1)
if isNil(wstream.stream.mainLoop):
echo "tlsWriteLoop(", cast[uint](wstream.stream), ") handle exited" wstream.stream.mainLoop = tlsLoop(wstream.stream)
break await wstream.stream.mainLoop
echo "tlsWriteLoop(", cast[uint](wstream.stream), ") exited"
proc tlsReadLoop(stream: AsyncStreamReader) {.async.} = proc tlsReadLoop(stream: AsyncStreamReader) {.async.} =
var rstream = cast[TLSStreamReader](stream) var rstream = cast[TLSStreamReader](stream)
var engine: ptr SslEngineContext
if rstream.kind == TLSStreamKind.Server:
engine = addr rstream.scontext.eng
else:
engine = addr rstream.ccontext.eng
rstream.state = AsyncStreamState.Running rstream.state = AsyncStreamState.Running
await stepsAsync(1)
while true: if isNil(rstream.stream.mainLoop):
try: rstream.stream.mainLoop = tlsLoop(rstream.stream)
var state = engine.sslEngineCurrentState() await rstream.stream.mainLoop
echo "tlsReadLoop(", cast[uint](rstream.stream), ") state = ", dumpState(state)
if (state and SSL_CLOSED) == SSL_CLOSED:
let err = engine.sslEngineLastError()
if err != 0:
rstream.error = newTLSStreamProtocolError(err)
rstream.state = AsyncStreamState.Error
else:
rstream.state = AsyncStreamState.Finished
else:
if (state and (SSL_SENDREC or SSL_SENDAPP)) != 0:
echo "tlsReadLoop(", cast[uint](rstream.stream), ") ",
"firing switch to writer, len(waiters) = ",
len(rstream.switchToWriter.waiters)
rstream.switchToWriter.fire()
if (state and SSL_RECVREC) == SSL_RECVREC:
# TLS records required for further processing
var length = 0'u
var buf = sslEngineRecvrecBuf(engine, length)
echo "tlsReadLoop(", cast[uint](rstream.stream), ") waiting for record"
let res = await rstream.rsource.readOnce(buf, int(length))
sslEngineRecvrecAck(engine, uint(res))
echo "tlsReadLoop(", cast[uint](rstream.stream), ") received ", res, " length rec, state = ", dumpState(engine.sslEngineCurrentState())
# if res > 0:
# sslEngineRecvrecAck(engine, uint(res))
# else:
# echo "tlsReadLoop() received 0 length ack"
# # readOnce() returns `0` if stream is at EOF.
# # rstream.state = AsyncStreamState.Finished
# sslEngineClose(engine)
elif (state and SSL_RECVAPP) == SSL_RECVAPP:
# Application data can be recovered.
var length = 0'u
var buf = sslEngineRecvappBuf(engine, length)
await upload(addr rstream.buffer, buf, int(length))
sslEngineRecvappAck(engine, length)
else:
echo "tlsReadLoop(", cast[uint](rstream.stream), ") waiting for `switchToReader` back, ",
"switchToReader.isSet() == ", rstream.switchToReader.isSet(),
", state = ", dumpState(engine.sslEngineCurrentState())
await rstream.switchToReader.wait()
echo "tlsReadLoop(", cast[uint](rstream.stream), ") got flow after switch"
rstream.switchToReader.clear()
except CancelledError:
echo "tlsReadLoop(", cast[uint](rstream.stream), ") cancellation received"
rstream.state = AsyncStreamState.Stopped
except AsyncStreamError as exc:
echo "tlsReadLoop(", cast[uint](rstream.stream), ") got an exception ",
exc.msg
rstream.error = exc
rstream.state = AsyncStreamState.Error
if rstream.state != AsyncStreamState.Running:
if not(rstream.handshaked):
rstream.handshaked = true
rstream.stream.writer.handshaked = true
if not(isNil(rstream.handshakeFut)):
rstream.handshakeFut.fail(rstream.error)
# Perform TLS cleanup procedure
if not(isNil(rstream.stream.writer)):
rstream.switchToWriter.fire()
if rstream.state != AsyncStreamState.Finished:
sslEngineClose(engine)
rstream.buffer.forget()
rstream.stream.reader = nil
echo "tlsReadLoop(", cast[uint](rstream.stream), ") handle exited"
break
echo "tlsReadLoop(", cast[uint](rstream.stream), ") exited"
proc getSignerAlgo(xc: X509Certificate): int = proc getSignerAlgo(xc: X509Certificate): int =
## Get certificate's signing algorithm. ## Get certificate's signing algorithm.
@ -362,21 +396,15 @@ proc newTLSClientAsyncStream*(rsource: AsyncStreamReader,
## ``minVersion`` of bigger then ``maxVersion`` you will get an error. ## ``minVersion`` of bigger then ``maxVersion`` you will get an error.
## ##
## ``flags`` - custom TLS connection flags. ## ``flags`` - custom TLS connection flags.
let switchToWriter = newAsyncEvent()
let switchToReader = newAsyncEvent()
var res = TLSAsyncStream() var res = TLSAsyncStream()
var reader = TLSStreamReader( var reader = TLSStreamReader(
kind: TLSStreamKind.Client, kind: TLSStreamKind.Client,
stream: res, stream: res,
switchToReader: switchToReader,
switchToWriter: switchToWriter,
ccontext: addr res.ccontext ccontext: addr res.ccontext
) )
var writer = TLSStreamWriter( var writer = TLSStreamWriter(
kind: TLSStreamKind.Client, kind: TLSStreamKind.Client,
stream: res, stream: res,
switchToReader: switchToReader,
switchToWriter: switchToWriter,
ccontext: addr res.ccontext ccontext: addr res.ccontext
) )
res.reader = reader res.reader = reader
@ -444,22 +472,15 @@ proc newTLSServerAsyncStream*(rsource: AsyncStreamReader,
if isNil(certificate) or len(certificate.certs) == 0: if isNil(certificate) or len(certificate.certs) == 0:
raiseTLSStreamProtoError("Incorrect certificate") raiseTLSStreamProtoError("Incorrect certificate")
let switchToWriter = newAsyncEvent()
let switchToReader = newAsyncEvent()
var res = TLSAsyncStream() var res = TLSAsyncStream()
var reader = TLSStreamReader( var reader = TLSStreamReader(
kind: TLSStreamKind.Server, kind: TLSStreamKind.Server,
stream: res, stream: res,
switchToReader: switchToReader,
switchToWriter: switchToWriter,
scontext: addr res.scontext scontext: addr res.scontext
) )
var writer = TLSStreamWriter( var writer = TLSStreamWriter(
kind: TLSStreamKind.Server, kind: TLSStreamKind.Server,
stream: res, stream: res,
switchToReader: switchToReader,
switchToWriter: switchToWriter,
scontext: addr res.scontext scontext: addr res.scontext
) )
res.reader = reader res.reader = reader

View File

@ -1639,7 +1639,6 @@ proc close*(server: StreamServer) =
## Please note that release of resources is not completed immediately, to be ## Please note that release of resources is not completed immediately, to be
## sure all resources got released please use ``await server.join()``. ## sure all resources got released please use ``await server.join()``.
proc continuation(udata: pointer) {.gcsafe.} = proc continuation(udata: pointer) {.gcsafe.} =
echo "server close() continuation"
server.clean() server.clean()
let r1 = (server.status == ServerStatus.Stopped) and let r1 = (server.status == ServerStatus.Stopped) and

View File

@ -594,7 +594,6 @@ suite "TLSStream test suite":
var reader = newAsyncStreamReader(transp) var reader = newAsyncStreamReader(transp)
var writer = newAsyncStreamWriter(transp) var writer = newAsyncStreamWriter(transp)
var tlsstream = newTLSClientAsyncStream(reader, writer, name) var tlsstream = newTLSClientAsyncStream(reader, writer, name)
await tlsstream.writer.write("GET / HTTP/1.1\r\nHost: " & name & await tlsstream.writer.write("GET / HTTP/1.1\r\nHost: " & name &
"\r\nConnection: close\r\n\r\n") "\r\nConnection: close\r\n\r\n")
var readFut = tlsstream.reader.readUntil(addr buffer[0], len(buffer), var readFut = tlsstream.reader.readUntil(addr buffer[0], len(buffer),
@ -625,65 +624,39 @@ suite "TLSStream test suite":
proc serveClient(server: StreamServer, proc serveClient(server: StreamServer,
transp: StreamTransport) {.async.} = transp: StreamTransport) {.async.} =
echo "- server accepted client"
var reader = newAsyncStreamReader(transp) var reader = newAsyncStreamReader(transp)
var writer = newAsyncStreamWriter(transp) var writer = newAsyncStreamWriter(transp)
var sstream = newTLSServerAsyncStream(reader, writer, key, cert) var sstream = newTLSServerAsyncStream(reader, writer, key, cert)
echo "- server stream is [", cast[uint](sstream), "]"
echo "- server handshaking"
await handshake(sstream) await handshake(sstream)
echo "- server handshaked"
await sstream.writer.write(testMessage & "\r\n") await sstream.writer.write(testMessage & "\r\n")
echo "- server wrote string"
await sstream.writer.finish() await sstream.writer.finish()
echo "- server finished"
await sleepAsync(5.seconds)
echo "- server sleeped"
await sstream.writer.closeWait() await sstream.writer.closeWait()
echo "- server closed secure writer"
await sstream.reader.closeWait() await sstream.reader.closeWait()
echo "- server closed secure reader"
await reader.closeWait() await reader.closeWait()
echo "- server closed reader"
await writer.closeWait() await writer.closeWait()
echo "- server closed writer"
await transp.closeWait() await transp.closeWait()
echo "- server closed transport"
server.stop() server.stop()
echo "- server stopped server"
server.close() server.close()
echo "- server closed server"
key = TLSPrivateKey.init(pemkey) key = TLSPrivateKey.init(pemkey)
cert = TLSCertificate.init(pemcert) cert = TLSCertificate.init(pemcert)
var server = createStreamServer(address, serveClient, {ReuseAddr}) var server = createStreamServer(address, serveClient, {ReuseAddr})
server.start() server.start()
echo "server started"
var conn = await connect(address) var conn = await connect(address)
echo "= client connected"
var creader = newAsyncStreamReader(conn) var creader = newAsyncStreamReader(conn)
var cwriter = newAsyncStreamWriter(conn) var cwriter = newAsyncStreamWriter(conn)
# We are using self-signed certificate # We are using self-signed certificate
let flags = {NoVerifyHost, NoVerifyServerName} let flags = {NoVerifyHost, NoVerifyServerName}
var cstream = newTLSClientAsyncStream(creader, cwriter, "", flags = flags) var cstream = newTLSClientAsyncStream(creader, cwriter, "", flags = flags)
echo "= client stream is [", cast[uint](cstream), "]"
echo "= client reading line"
let res = await cstream.reader.read() let res = await cstream.reader.read()
echo "= client readed line"
await cstream.reader.closeWait() await cstream.reader.closeWait()
echo "= client closed reader"
await cstream.writer.closeWait() await cstream.writer.closeWait()
echo "= client closed writer"
await creader.closeWait() await creader.closeWait()
echo "= client closed creader"
await cwriter.closeWait() await cwriter.closeWait()
echo "= client closed cwriter"
await conn.closeWait() await conn.closeWait()
echo "= client closed connection"
await server.join() await server.join()
echo "= client waited server" return cast[string](res) == (testMessage & "\r\n")
result = true # res == testMessage
test "Simple server with RSA self-signed certificate": test "Simple server with RSA self-signed certificate":
let res = waitFor(checkSSLServer(initTAddress("127.0.0.1:43808"), let res = waitFor(checkSSLServer(initTAddress("127.0.0.1:43808"),