From 49fd70f5041fedf29882576b04858e2cbe7ea0be Mon Sep 17 00:00:00 2001 From: cheatfate Date: Thu, 21 Jan 2021 05:42:44 +0200 Subject: [PATCH] Attempt #4. --- chronos/asyncloop.nim | 6 -- chronos/streams/tlsstream.nim | 141 ++++++++++++++++++++-------------- chronos/transports/stream.nim | 3 - tests/testasyncstream.nim | 2 + 4 files changed, 84 insertions(+), 68 deletions(-) diff --git a/chronos/asyncloop.nim b/chronos/asyncloop.nim index dd0103b..d7a98ea 100644 --- a/chronos/asyncloop.nim +++ b/chronos/asyncloop.nim @@ -633,12 +633,9 @@ elif unixPlatform: proc continuation(udata: pointer) = if SocketHandle(fd) in loop.selector: - echo "closeSocket() continuation unregistering" unregister(fd) - echo "closeSocket() continuation close()" close(SocketHandle(fd)) if not isNil(aftercb): - echo "closeSocket() invoke user-callback" aftercb(nil) withData(loop.selector, int(fd), adata) do: @@ -648,12 +645,10 @@ elif unixPlatform: # from system queue for this reader and writer. if not(isNil(adata.reader.function)): - echo "closeSocket() scheduling reader" loop.callbacks.addLast(adata.reader) adata.reader = default(AsyncCallback) if not(isNil(adata.writer.function)): - echo "closeSocket() scheduling writer" loop.callbacks.addLast(adata.writer) adata.writer = default(AsyncCallback) @@ -661,7 +656,6 @@ elif unixPlatform: # in such case processing queue will stuck on poll() call, because there # can be no file descriptors registered in system queue. var acb = AsyncCallback(function: continuation) - echo "closeSocket() scheduling actual close" loop.callbacks.addLast(acb) proc closeHandle*(fd: AsyncFD, aftercb: CallbackFunc = nil) {.inline.} = diff --git a/chronos/streams/tlsstream.nim b/chronos/streams/tlsstream.nim index 32258a7..6f20c55 100644 --- a/chronos/streams/tlsstream.nim +++ b/chronos/streams/tlsstream.nim @@ -65,6 +65,7 @@ type switchToWriter*: AsyncEvent handshaked*: bool handshakeFut*: Future[void] + closeshakeFut*: Future[void] TLSStreamReader* = ref object of AsyncStreamReader case kind: TLSStreamKind @@ -77,6 +78,7 @@ type switchToWriter*: AsyncEvent handshaked*: bool handshakeFut*: Future[void] + closeshakeFut*: Future[void] TLSAsyncStream* = ref object of RootRef xwc*: X509NoAnchorContext @@ -110,6 +112,25 @@ template newTLSStreamProtocolError[T](message: T): ref TLSStreamProtocolError = err.errCode = code err +proc dumpState*(state: cuint): string = + var res = "" + if (state and SSL_CLOSED) == SSL_CLOSED: + if len(res) > 0: res.add(", ") + res.add("SSL_CLOSED") + if (state and SSL_SENDREC) == SSL_SENDREC: + if len(res) > 0: res.add(", ") + res.add("SSL_SENDREC") + if (state and SSL_SENDAPP) == SSL_SENDAPP: + if len(res) > 0: res.add(", ") + res.add("SSL_SENDAPP") + if (state and SSL_RECVREC) == SSL_RECVREC: + if len(res) > 0: res.add(", ") + res.add("SSL_RECVREC") + if (state and SSL_RECVAPP) == SSL_RECVAPP: + if len(res) > 0: res.add(", ") + res.add("SSL_RECVAPP") + "{" & res & "}" + template raiseTLSStreamProtoError*[T](message: T) = raise newTLSStreamProtocolError(message) @@ -135,47 +156,49 @@ proc tlsWriteLoop(stream: AsyncStreamWriter) {.async.} = if (state and (SSL_RECVREC or SSL_RECVAPP)) != 0: if not(wstream.switchToReader.isSet()): wstream.switchToReader.fire() - if (state and (SSL_SENDREC or SSL_SENDAPP)) == 0: + + if (state and SSL_SENDREC) == SSL_SENDREC: + # TLS record needs to be sent over stream. + var length = 0'u + var buf = sslEngineSendrecBuf(engine, length) + doAssert(length != 0 and not isNil(buf)) + await wstream.wsource.write(buf, int(length)) + sslEngineSendrecAck(engine, length) + elif (state and SSL_SENDAPP) == SSL_SENDAPP: + # Application data can be sent over stream. + if not(wstream.handshaked): + wstream.stream.reader.handshaked = true + wstream.handshaked = true + if not(isNil(wstream.handshakeFut)): + wstream.handshakeFut.complete() + item = await wstream.queue.get() + if item.size > 0: + var length = 0'u + var buf = sslEngineSendappBuf(engine, length) + let toWrite = min(int(length), item.size) + copyOut(buf, item, toWrite) + if int(length) >= item.size: + # BearSSL is ready to accept whole item size. + sslEngineSendappAck(engine, uint(item.size)) + sslEngineFlush(engine, 0) + item.future.complete() + else: + # BearSSL is not ready to accept whole item, so we will send + # only part of item and adjust offset. + item.offset = item.offset + int(length) + item.size = item.size - int(length) + wstream.queue.addFirstNoWait(item) + sslEngineSendappAck(engine, length) + else: + # Zero length item means finish, so we going to trigger TLS + # closure protocol. + wstream.state = AsyncStreamState.Finished + sslEngineClose(engine) + item.future.complete() + else: await wstream.switchToWriter.wait() wstream.switchToWriter.clear() - # We need to refresh `state` because we just returned from readerLoop. - else: - if (state and SSL_SENDREC) == SSL_SENDREC: - # TLS record needs to be sent over stream. - var length = 0'u - var buf = sslEngineSendrecBuf(engine, length) - doAssert(length != 0 and not isNil(buf)) - await wstream.wsource.write(buf, int(length)) - sslEngineSendrecAck(engine, length) - elif (state and SSL_SENDAPP) == SSL_SENDAPP: - # Application data can be sent over stream. - if not(wstream.handshaked): - wstream.stream.reader.handshaked = true - wstream.handshaked = true - if not(isNil(wstream.handshakeFut)): - wstream.handshakeFut.complete() - item = await wstream.queue.get() - if item.size > 0: - var length = 0'u - var buf = sslEngineSendappBuf(engine, length) - let toWrite = min(int(length), item.size) - copyOut(buf, item, toWrite) - if int(length) >= item.size: - # BearSSL is ready to accept whole item size. - sslEngineSendappAck(engine, uint(item.size)) - sslEngineFlush(engine, 0) - item.future.complete() - else: - # BearSSL is not ready to accept whole item, so we will send - # only part of item and adjust offset. - item.offset = item.offset + int(length) - item.size = item.size - int(length) - wstream.queue.addFirstNoWait(item) - sslEngineSendappAck(engine, length) - else: - # Zero length item means finish, so we going to trigger TLS - # closure protocol. - sslEngineClose(engine) + except CancelledError: wstream.state = AsyncStreamState.Stopped error = newAsyncStreamUseClosedError() @@ -222,28 +245,28 @@ proc tlsReadLoop(stream: AsyncStreamReader) {.async.} = if (state and (SSL_SENDREC or SSL_SENDAPP)) != 0: if not(rstream.switchToWriter.isSet()): rstream.switchToWriter.fire() - if (state and (SSL_RECVREC or SSL_RECVAPP)) == 0: + + if (state and SSL_RECVREC) == SSL_RECVREC: + # TLS records required for further processing + var length = 0'u + var buf = sslEngineRecvrecBuf(engine, length) + let res = await rstream.rsource.readOnce(buf, int(length)) + if res > 0: + sslEngineRecvrecAck(engine, uint(res)) + else: + # readOnce() returns `0` if stream is at EOF. + rstream.state = AsyncStreamState.Finished + sslEngineClose(engine) + elif (state and SSL_RECVAPP) == SSL_RECVAPP: + # Application data can be recovered. + var length = 0'u + var buf = sslEngineRecvappBuf(engine, length) + await upload(addr rstream.buffer, buf, int(length)) + sslEngineRecvappAck(engine, length) + else: await rstream.switchToReader.wait() rstream.switchToReader.clear() - # We need to refresh `state` because we just returned from writerLoop. - else: - if (state and SSL_RECVREC) == SSL_RECVREC: - # TLS records required for further processing - var length = 0'u - var buf = sslEngineRecvrecBuf(engine, length) - let res = await rstream.rsource.readOnce(buf, int(length)) - if res > 0: - sslEngineRecvrecAck(engine, uint(res)) - else: - # readOnce() returns `0` if stream is at EOF, so we initiate TLS - # closure procedure. - sslEngineClose(engine) - elif (state and SSL_RECVAPP) == SSL_RECVAPP: - # Application data can be recovered. - var length = 0'u - var buf = sslEngineRecvappBuf(engine, length) - await upload(addr rstream.buffer, buf, int(length)) - sslEngineRecvappAck(engine, length) + except CancelledError: rstream.state = AsyncStreamState.Stopped except AsyncStreamError as exc: diff --git a/chronos/transports/stream.nim b/chronos/transports/stream.nim index 9c09eaf..752e7f0 100644 --- a/chronos/transports/stream.nim +++ b/chronos/transports/stream.nim @@ -274,15 +274,12 @@ proc failPendingWriteQueue(queue: var Deque[StreamVector], vector.writer.fail(error) proc clean(server: StreamServer) {.inline.} = - echo "cleaning server instance" if not(server.loopFuture.finished()): - echo "cleaning server complete()" untrackServer(server) server.loopFuture.complete() if not isNil(server.udata) and GCUserData in server.flags: GC_unref(cast[ref int](server.udata)) GC_unref(server) - echo "clean server exit" proc clean(transp: StreamTransport) {.inline.} = if not(transp.future.finished()): diff --git a/tests/testasyncstream.nim b/tests/testasyncstream.nim index 34a5194..039af73 100644 --- a/tests/testasyncstream.nim +++ b/tests/testasyncstream.nim @@ -634,6 +634,8 @@ suite "TLSStream test suite": echo "server handshaked" await sstream.writer.write(testMessage & "\r\n") echo "server wrote string" + await sstream.writer.finish() + echo "server finished string" await sstream.writer.closeWait() echo "server closed secure writer" await sstream.reader.closeWait()