This commit is contained in:
cheatfate 2021-01-21 20:11:43 +02:00 committed by zah
parent 49fd70f504
commit e8d2a3ca0a
4 changed files with 111 additions and 62 deletions

View File

@ -409,6 +409,7 @@ when defined(windows) or defined(nimdoc):
proc poll*() = proc poll*() =
## Perform single asynchronous step. ## Perform single asynchronous step.
echo "poll()"
let loop = getThreadDispatcher() let loop = getThreadDispatcher()
var curTime = Moment.now() var curTime = Moment.now()
var curTimeout = DWORD(0) var curTimeout = DWORD(0)
@ -422,6 +423,8 @@ when defined(windows) or defined(nimdoc):
var lpCompletionKey: ULONG_PTR var lpCompletionKey: ULONG_PTR
var customOverlapped: PtrCustomOverlapped var customOverlapped: PtrCustomOverlapped
echo "poll() timeout = ", curTimeout, ", len(callbacks) = ", len(loop.callbacks)
let res = getQueuedCompletionStatus( let res = getQueuedCompletionStatus(
loop.ioPort, addr lpNumberOfBytesTransferred, loop.ioPort, addr lpNumberOfBytesTransferred,
addr lpCompletionKey, cast[ptr POVERLAPPED](addr customOverlapped), addr lpCompletionKey, cast[ptr POVERLAPPED](addr customOverlapped),
@ -457,6 +460,7 @@ when defined(windows) or defined(nimdoc):
# All callbacks which will be added in process will be processed on next # All callbacks which will be added in process will be processed on next
# poll() call. # poll() call.
loop.processCallbacks() loop.processCallbacks()
echo "exit poll()"
proc closeSocket*(fd: AsyncFD, aftercb: CallbackFunc = nil) = proc closeSocket*(fd: AsyncFD, aftercb: CallbackFunc = nil) =
## Closes a socket and ensures that it is unregistered. ## Closes a socket and ensures that it is unregistered.

View File

@ -37,7 +37,7 @@ type
## state to be signaled, when event get fired, then all coroutines ## state to be signaled, when event get fired, then all coroutines
## continue proceeds in order, they have entered waiting state. ## continue proceeds in order, they have entered waiting state.
flag: bool flag: bool
waiters: seq[Future[void]] waiters*: seq[Future[void]]
AsyncQueue*[T] = ref object of RootRef AsyncQueue*[T] = ref object of RootRef
## A queue, useful for coordinating producer and consumer coroutines. ## A queue, useful for coordinating producer and consumer coroutines.

View File

@ -65,7 +65,6 @@ type
switchToWriter*: AsyncEvent switchToWriter*: AsyncEvent
handshaked*: bool handshaked*: bool
handshakeFut*: Future[void] handshakeFut*: Future[void]
closeshakeFut*: Future[void]
TLSStreamReader* = ref object of AsyncStreamReader TLSStreamReader* = ref object of AsyncStreamReader
case kind: TLSStreamKind case kind: TLSStreamKind
@ -78,7 +77,6 @@ type
switchToWriter*: AsyncEvent switchToWriter*: AsyncEvent
handshaked*: bool handshaked*: bool
handshakeFut*: Future[void] handshakeFut*: Future[void]
closeshakeFut*: Future[void]
TLSAsyncStream* = ref object of RootRef TLSAsyncStream* = ref object of RootRef
xwc*: X509NoAnchorContext xwc*: X509NoAnchorContext
@ -150,20 +148,23 @@ proc tlsWriteLoop(stream: AsyncStreamWriter) {.async.} =
var item: WriteItem var item: WriteItem
try: try:
var state = engine.sslEngineCurrentState() var state = engine.sslEngineCurrentState()
echo "tlsWriteLoop(", cast[uint](wstream.stream), ") state = ", dumpState(state)
if (state and SSL_CLOSED) == SSL_CLOSED: if (state and SSL_CLOSED) == SSL_CLOSED:
wstream.state = AsyncStreamState.Finished wstream.state = AsyncStreamState.Finished
else: else:
if (state and (SSL_RECVREC or SSL_RECVAPP)) != 0: if (state and (SSL_RECVREC or SSL_RECVAPP)) != 0:
if not(wstream.switchToReader.isSet()): echo "tlsWriteLoop(", cast[uint](wstream.stream), ") firing switch to reader, waiters = ", len(wstream.switchToReader.waiters)
wstream.switchToReader.fire() wstream.switchToReader.fire()
if (state and SSL_SENDREC) == SSL_SENDREC: if (state and SSL_SENDREC) == SSL_SENDREC:
# TLS record needs to be sent over stream. # TLS record needs to be sent over stream.
var length = 0'u var length = 0'u
var buf = sslEngineSendrecBuf(engine, length) var buf = sslEngineSendrecBuf(engine, length)
doAssert(length != 0 and not isNil(buf)) doAssert(length != 0 and not isNil(buf))
echo "tlsWriteLoop(", cast[uint](wstream.stream), ") sending record ", int(length), " bytes"
await wstream.wsource.write(buf, int(length)) await wstream.wsource.write(buf, int(length))
sslEngineSendrecAck(engine, length) sslEngineSendrecAck(engine, length)
echo "tlsWriteLoop(", cast[uint](wstream.stream), ") record ", int(length), " bytes sent"
elif (state and SSL_SENDAPP) == SSL_SENDAPP: elif (state and SSL_SENDAPP) == SSL_SENDAPP:
# Application data can be sent over stream. # Application data can be sent over stream.
if not(wstream.handshaked): if not(wstream.handshaked):
@ -171,38 +172,53 @@ proc tlsWriteLoop(stream: AsyncStreamWriter) {.async.} =
wstream.handshaked = true wstream.handshaked = true
if not(isNil(wstream.handshakeFut)): if not(isNil(wstream.handshakeFut)):
wstream.handshakeFut.complete() wstream.handshakeFut.complete()
item = await wstream.queue.get() if not(wstream.queue.empty()):
if item.size > 0: echo "tlsWriteLoop(", cast[uint](wstream.stream), ") waiting for appdata"
var length = 0'u item = await wstream.queue.get()
var buf = sslEngineSendappBuf(engine, length) echo "tlsWriteLoop(", cast[uint](wstream.stream), ") sending appdata ", int(item.size), " bytes"
let toWrite = min(int(length), item.size) if item.size > 0:
copyOut(buf, item, toWrite) var length = 0'u
if int(length) >= item.size: var buf = sslEngineSendappBuf(engine, length)
# BearSSL is ready to accept whole item size. let toWrite = min(int(length), item.size)
sslEngineSendappAck(engine, uint(item.size)) copyOut(buf, item, toWrite)
sslEngineFlush(engine, 0) if int(length) >= item.size:
item.future.complete() # BearSSL is ready to accept whole item size.
sslEngineSendappAck(engine, uint(item.size))
sslEngineFlush(engine, 0)
item.future.complete()
else:
# BearSSL is not ready to accept whole item, so we will send
# only part of item and adjust offset.
item.offset = item.offset + int(length)
item.size = item.size - int(length)
wstream.queue.addFirstNoWait(item)
sslEngineSendappAck(engine, length)
else: else:
# BearSSL is not ready to accept whole item, so we will send # Zero length item means finish, so we going to trigger TLS
# only part of item and adjust offset. # closure protocol.
item.offset = item.offset + int(length) sslEngineClose(engine)
item.size = item.size - int(length) echo "tlsWriteLoop(", cast[uint](wstream.stream), ") ",
wstream.queue.addFirstNoWait(item) "received zero-length item, state = ",
sslEngineSendappAck(engine, length) dumpState(engine.sslEngineCurrentState())
else: else:
# Zero length item means finish, so we going to trigger TLS echo "tlsWriteLoop(", cast[uint](wstream.stream), ") empty queue"
# closure protocol. echo "tlsWriteLoop(", cast[uint](wstream.stream), ") waiting for switch back"
wstream.state = AsyncStreamState.Finished await wstream.switchToWriter.wait()
sslEngineClose(engine) echo "tlsWriteLoop(", cast[uint](wstream.stream), ") got flow after switch"
item.future.complete() wstream.switchToWriter.clear()
else: else:
echo "tlsWriteLoop(", cast[uint](wstream.stream), ") waiting for switch back, switchToReader.isSet() == ", wstream.switchToReader.isSet()
await wstream.switchToWriter.wait() await wstream.switchToWriter.wait()
echo "tlsWriteLoop(", cast[uint](wstream.stream), ") got flow after switch"
wstream.switchToWriter.clear() wstream.switchToWriter.clear()
except CancelledError: except CancelledError:
echo "tlsWriteLoop(", cast[uint](wstream.stream), ") received cancellation"
wstream.state = AsyncStreamState.Stopped wstream.state = AsyncStreamState.Stopped
error = newAsyncStreamUseClosedError() error = newAsyncStreamUseClosedError()
except AsyncStreamError as exc: except AsyncStreamError as exc:
echo "tlsWriteLoop(", cast[uint](wstream.stream), ") got an exception ",
exc.msg
wstream.state = AsyncStreamState.Error wstream.state = AsyncStreamState.Error
error = exc error = exc
@ -217,8 +233,15 @@ proc tlsWriteLoop(stream: AsyncStreamWriter) {.async.} =
let pitem = wstream.queue.popFirstNoWait() let pitem = wstream.queue.popFirstNoWait()
if not(pitem.future.finished()): if not(pitem.future.finished()):
pitem.future.fail(error) pitem.future.fail(error)
wstream.stream = nil
if not(isNil(wstream.stream.reader)):
wstream.switchToReader.fire()
wstream.stream.writer = nil
echo "tlsWriteLoop(", cast[uint](wstream.stream), ") handle exited"
break break
echo "tlsWriteLoop(", cast[uint](wstream.stream), ") exited"
proc tlsReadLoop(stream: AsyncStreamReader) {.async.} = proc tlsReadLoop(stream: AsyncStreamReader) {.async.} =
var rstream = cast[TLSStreamReader](stream) var rstream = cast[TLSStreamReader](stream)
@ -234,6 +257,7 @@ proc tlsReadLoop(stream: AsyncStreamReader) {.async.} =
while true: while true:
try: try:
var state = engine.sslEngineCurrentState() var state = engine.sslEngineCurrentState()
echo "tlsReadLoop(", cast[uint](rstream.stream), ") state = ", dumpState(state)
if (state and SSL_CLOSED) == SSL_CLOSED: if (state and SSL_CLOSED) == SSL_CLOSED:
let err = engine.sslEngineLastError() let err = engine.sslEngineLastError()
if err != 0: if err != 0:
@ -243,20 +267,26 @@ proc tlsReadLoop(stream: AsyncStreamReader) {.async.} =
rstream.state = AsyncStreamState.Finished rstream.state = AsyncStreamState.Finished
else: else:
if (state and (SSL_SENDREC or SSL_SENDAPP)) != 0: if (state and (SSL_SENDREC or SSL_SENDAPP)) != 0:
if not(rstream.switchToWriter.isSet()): echo "tlsReadLoop(", cast[uint](rstream.stream), ") ",
rstream.switchToWriter.fire() "firing switch to writer, len(waiters) = ",
len(rstream.switchToWriter.waiters)
rstream.switchToWriter.fire()
if (state and SSL_RECVREC) == SSL_RECVREC: if (state and SSL_RECVREC) == SSL_RECVREC:
# TLS records required for further processing # TLS records required for further processing
var length = 0'u var length = 0'u
var buf = sslEngineRecvrecBuf(engine, length) var buf = sslEngineRecvrecBuf(engine, length)
echo "tlsReadLoop(", cast[uint](rstream.stream), ") waiting for record"
let res = await rstream.rsource.readOnce(buf, int(length)) let res = await rstream.rsource.readOnce(buf, int(length))
if res > 0: sslEngineRecvrecAck(engine, uint(res))
sslEngineRecvrecAck(engine, uint(res)) echo "tlsReadLoop(", cast[uint](rstream.stream), ") received ", res, " length rec, state = ", dumpState(engine.sslEngineCurrentState())
else: # if res > 0:
# readOnce() returns `0` if stream is at EOF. # sslEngineRecvrecAck(engine, uint(res))
rstream.state = AsyncStreamState.Finished # else:
sslEngineClose(engine) # echo "tlsReadLoop() received 0 length ack"
# # readOnce() returns `0` if stream is at EOF.
# # rstream.state = AsyncStreamState.Finished
# sslEngineClose(engine)
elif (state and SSL_RECVAPP) == SSL_RECVAPP: elif (state and SSL_RECVAPP) == SSL_RECVAPP:
# Application data can be recovered. # Application data can be recovered.
var length = 0'u var length = 0'u
@ -264,28 +294,39 @@ proc tlsReadLoop(stream: AsyncStreamReader) {.async.} =
await upload(addr rstream.buffer, buf, int(length)) await upload(addr rstream.buffer, buf, int(length))
sslEngineRecvappAck(engine, length) sslEngineRecvappAck(engine, length)
else: else:
echo "tlsReadLoop(", cast[uint](rstream.stream), ") waiting for `switchToReader` back, ",
"switchToReader.isSet() == ", rstream.switchToReader.isSet(),
", state = ", dumpState(engine.sslEngineCurrentState())
await rstream.switchToReader.wait() await rstream.switchToReader.wait()
echo "tlsReadLoop(", cast[uint](rstream.stream), ") got flow after switch"
rstream.switchToReader.clear() rstream.switchToReader.clear()
except CancelledError: except CancelledError:
echo "tlsReadLoop(", cast[uint](rstream.stream), ") cancellation received"
rstream.state = AsyncStreamState.Stopped rstream.state = AsyncStreamState.Stopped
except AsyncStreamError as exc: except AsyncStreamError as exc:
echo "tlsReadLoop(", cast[uint](rstream.stream), ") got an exception ",
exc.msg
rstream.error = exc rstream.error = exc
rstream.state = AsyncStreamState.Error rstream.state = AsyncStreamState.Error
if rstream.state != AsyncStreamState.Running:
if not(rstream.handshaked): if not(rstream.handshaked):
rstream.handshaked = true rstream.handshaked = true
rstream.stream.writer.handshaked = true rstream.stream.writer.handshaked = true
if not(isNil(rstream.handshakeFut)): if not(isNil(rstream.handshakeFut)):
rstream.handshakeFut.fail(rstream.error) rstream.handshakeFut.fail(rstream.error)
rstream.switchToWriter.fire()
if rstream.state != AsyncStreamState.Running:
# Perform TLS cleanup procedure # Perform TLS cleanup procedure
if not(isNil(rstream.stream.writer)):
rstream.switchToWriter.fire()
if rstream.state != AsyncStreamState.Finished: if rstream.state != AsyncStreamState.Finished:
sslEngineClose(engine) sslEngineClose(engine)
rstream.buffer.forget() rstream.buffer.forget()
rstream.stream = nil rstream.stream.reader = nil
echo "tlsReadLoop(", cast[uint](rstream.stream), ") handle exited"
break break
echo "tlsReadLoop(", cast[uint](rstream.stream), ") exited"
proc getSignerAlgo(xc: X509Certificate): int = proc getSignerAlgo(xc: X509Certificate): int =
## Get certificate's signing algorithm. ## Get certificate's signing algorithm.

View File

@ -625,31 +625,34 @@ suite "TLSStream test suite":
proc serveClient(server: StreamServer, proc serveClient(server: StreamServer,
transp: StreamTransport) {.async.} = transp: StreamTransport) {.async.} =
echo "server accepted client" echo "- server accepted client"
var reader = newAsyncStreamReader(transp) var reader = newAsyncStreamReader(transp)
var writer = newAsyncStreamWriter(transp) var writer = newAsyncStreamWriter(transp)
var sstream = newTLSServerAsyncStream(reader, writer, key, cert) var sstream = newTLSServerAsyncStream(reader, writer, key, cert)
echo "server handshaking" echo "- server stream is [", cast[uint](sstream), "]"
echo "- server handshaking"
await handshake(sstream) await handshake(sstream)
echo "server handshaked" echo "- server handshaked"
await sstream.writer.write(testMessage & "\r\n") await sstream.writer.write(testMessage & "\r\n")
echo "server wrote string" echo "- server wrote string"
await sstream.writer.finish() await sstream.writer.finish()
echo "server finished string" echo "- server finished"
await sleepAsync(5.seconds)
echo "- server sleeped"
await sstream.writer.closeWait() await sstream.writer.closeWait()
echo "server closed secure writer" echo "- server closed secure writer"
await sstream.reader.closeWait() await sstream.reader.closeWait()
echo "server closed secure reader" echo "- server closed secure reader"
await reader.closeWait() await reader.closeWait()
echo "server closed reader" echo "- server closed reader"
await writer.closeWait() await writer.closeWait()
echo "server closed writer" echo "- server closed writer"
await transp.closeWait() await transp.closeWait()
echo "server closed transport" echo "- server closed transport"
server.stop() server.stop()
echo "server stopped server" echo "- server stopped server"
server.close() server.close()
echo "server closed server" echo "- server closed server"
key = TLSPrivateKey.init(pemkey) key = TLSPrivateKey.init(pemkey)
cert = TLSCertificate.init(pemcert) cert = TLSCertificate.init(pemcert)
@ -658,28 +661,29 @@ suite "TLSStream test suite":
server.start() server.start()
echo "server started" echo "server started"
var conn = await connect(address) var conn = await connect(address)
echo "client connected" echo "= client connected"
var creader = newAsyncStreamReader(conn) var creader = newAsyncStreamReader(conn)
var cwriter = newAsyncStreamWriter(conn) var cwriter = newAsyncStreamWriter(conn)
# We are using self-signed certificate # We are using self-signed certificate
let flags = {NoVerifyHost, NoVerifyServerName} let flags = {NoVerifyHost, NoVerifyServerName}
var cstream = newTLSClientAsyncStream(creader, cwriter, "", flags = flags) var cstream = newTLSClientAsyncStream(creader, cwriter, "", flags = flags)
echo "client reading line" echo "= client stream is [", cast[uint](cstream), "]"
let res = await cstream.reader.readLine() echo "= client reading line"
echo "client readed line" let res = await cstream.reader.read()
echo "= client readed line"
await cstream.reader.closeWait() await cstream.reader.closeWait()
echo "client closed reader" echo "= client closed reader"
await cstream.writer.closeWait() await cstream.writer.closeWait()
echo "client closed writer" echo "= client closed writer"
await creader.closeWait() await creader.closeWait()
echo "client closed creader" echo "= client closed creader"
await cwriter.closeWait() await cwriter.closeWait()
echo "client closed cwriter" echo "= client closed cwriter"
await conn.closeWait() await conn.closeWait()
echo "client closed connection" echo "= client closed connection"
await server.join() await server.join()
echo "client waited server" echo "= client waited server"
result = res == testMessage result = true # res == testMessage
test "Simple server with RSA self-signed certificate": test "Simple server with RSA self-signed certificate":
let res = waitFor(checkSSLServer(initTAddress("127.0.0.1:43808"), let res = waitFor(checkSSLServer(initTAddress("127.0.0.1:43808"),