This commit is contained in:
cheatfate 2021-01-21 05:42:44 +02:00 committed by zah
parent f1b43aeb04
commit 49fd70f504
4 changed files with 84 additions and 68 deletions

View File

@ -633,12 +633,9 @@ elif unixPlatform:
proc continuation(udata: pointer) = proc continuation(udata: pointer) =
if SocketHandle(fd) in loop.selector: if SocketHandle(fd) in loop.selector:
echo "closeSocket() continuation unregistering"
unregister(fd) unregister(fd)
echo "closeSocket() continuation close()"
close(SocketHandle(fd)) close(SocketHandle(fd))
if not isNil(aftercb): if not isNil(aftercb):
echo "closeSocket() invoke user-callback"
aftercb(nil) aftercb(nil)
withData(loop.selector, int(fd), adata) do: withData(loop.selector, int(fd), adata) do:
@ -648,12 +645,10 @@ elif unixPlatform:
# from system queue for this reader and writer. # from system queue for this reader and writer.
if not(isNil(adata.reader.function)): if not(isNil(adata.reader.function)):
echo "closeSocket() scheduling reader"
loop.callbacks.addLast(adata.reader) loop.callbacks.addLast(adata.reader)
adata.reader = default(AsyncCallback) adata.reader = default(AsyncCallback)
if not(isNil(adata.writer.function)): if not(isNil(adata.writer.function)):
echo "closeSocket() scheduling writer"
loop.callbacks.addLast(adata.writer) loop.callbacks.addLast(adata.writer)
adata.writer = default(AsyncCallback) adata.writer = default(AsyncCallback)
@ -661,7 +656,6 @@ elif unixPlatform:
# in such case processing queue will stuck on poll() call, because there # in such case processing queue will stuck on poll() call, because there
# can be no file descriptors registered in system queue. # can be no file descriptors registered in system queue.
var acb = AsyncCallback(function: continuation) var acb = AsyncCallback(function: continuation)
echo "closeSocket() scheduling actual close"
loop.callbacks.addLast(acb) loop.callbacks.addLast(acb)
proc closeHandle*(fd: AsyncFD, aftercb: CallbackFunc = nil) {.inline.} = proc closeHandle*(fd: AsyncFD, aftercb: CallbackFunc = nil) {.inline.} =

View File

@ -65,6 +65,7 @@ type
switchToWriter*: AsyncEvent switchToWriter*: AsyncEvent
handshaked*: bool handshaked*: bool
handshakeFut*: Future[void] handshakeFut*: Future[void]
closeshakeFut*: Future[void]
TLSStreamReader* = ref object of AsyncStreamReader TLSStreamReader* = ref object of AsyncStreamReader
case kind: TLSStreamKind case kind: TLSStreamKind
@ -77,6 +78,7 @@ type
switchToWriter*: AsyncEvent switchToWriter*: AsyncEvent
handshaked*: bool handshaked*: bool
handshakeFut*: Future[void] handshakeFut*: Future[void]
closeshakeFut*: Future[void]
TLSAsyncStream* = ref object of RootRef TLSAsyncStream* = ref object of RootRef
xwc*: X509NoAnchorContext xwc*: X509NoAnchorContext
@ -110,6 +112,25 @@ template newTLSStreamProtocolError[T](message: T): ref TLSStreamProtocolError =
err.errCode = code err.errCode = code
err err
proc dumpState*(state: cuint): string =
var res = ""
if (state and SSL_CLOSED) == SSL_CLOSED:
if len(res) > 0: res.add(", ")
res.add("SSL_CLOSED")
if (state and SSL_SENDREC) == SSL_SENDREC:
if len(res) > 0: res.add(", ")
res.add("SSL_SENDREC")
if (state and SSL_SENDAPP) == SSL_SENDAPP:
if len(res) > 0: res.add(", ")
res.add("SSL_SENDAPP")
if (state and SSL_RECVREC) == SSL_RECVREC:
if len(res) > 0: res.add(", ")
res.add("SSL_RECVREC")
if (state and SSL_RECVAPP) == SSL_RECVAPP:
if len(res) > 0: res.add(", ")
res.add("SSL_RECVAPP")
"{" & res & "}"
template raiseTLSStreamProtoError*[T](message: T) = template raiseTLSStreamProtoError*[T](message: T) =
raise newTLSStreamProtocolError(message) raise newTLSStreamProtocolError(message)
@ -135,47 +156,49 @@ proc tlsWriteLoop(stream: AsyncStreamWriter) {.async.} =
if (state and (SSL_RECVREC or SSL_RECVAPP)) != 0: if (state and (SSL_RECVREC or SSL_RECVAPP)) != 0:
if not(wstream.switchToReader.isSet()): if not(wstream.switchToReader.isSet()):
wstream.switchToReader.fire() wstream.switchToReader.fire()
if (state and (SSL_SENDREC or SSL_SENDAPP)) == 0:
if (state and SSL_SENDREC) == SSL_SENDREC:
# TLS record needs to be sent over stream.
var length = 0'u
var buf = sslEngineSendrecBuf(engine, length)
doAssert(length != 0 and not isNil(buf))
await wstream.wsource.write(buf, int(length))
sslEngineSendrecAck(engine, length)
elif (state and SSL_SENDAPP) == SSL_SENDAPP:
# Application data can be sent over stream.
if not(wstream.handshaked):
wstream.stream.reader.handshaked = true
wstream.handshaked = true
if not(isNil(wstream.handshakeFut)):
wstream.handshakeFut.complete()
item = await wstream.queue.get()
if item.size > 0:
var length = 0'u
var buf = sslEngineSendappBuf(engine, length)
let toWrite = min(int(length), item.size)
copyOut(buf, item, toWrite)
if int(length) >= item.size:
# BearSSL is ready to accept whole item size.
sslEngineSendappAck(engine, uint(item.size))
sslEngineFlush(engine, 0)
item.future.complete()
else:
# BearSSL is not ready to accept whole item, so we will send
# only part of item and adjust offset.
item.offset = item.offset + int(length)
item.size = item.size - int(length)
wstream.queue.addFirstNoWait(item)
sslEngineSendappAck(engine, length)
else:
# Zero length item means finish, so we going to trigger TLS
# closure protocol.
wstream.state = AsyncStreamState.Finished
sslEngineClose(engine)
item.future.complete()
else:
await wstream.switchToWriter.wait() await wstream.switchToWriter.wait()
wstream.switchToWriter.clear() wstream.switchToWriter.clear()
# We need to refresh `state` because we just returned from readerLoop.
else:
if (state and SSL_SENDREC) == SSL_SENDREC:
# TLS record needs to be sent over stream.
var length = 0'u
var buf = sslEngineSendrecBuf(engine, length)
doAssert(length != 0 and not isNil(buf))
await wstream.wsource.write(buf, int(length))
sslEngineSendrecAck(engine, length)
elif (state and SSL_SENDAPP) == SSL_SENDAPP:
# Application data can be sent over stream.
if not(wstream.handshaked):
wstream.stream.reader.handshaked = true
wstream.handshaked = true
if not(isNil(wstream.handshakeFut)):
wstream.handshakeFut.complete()
item = await wstream.queue.get()
if item.size > 0:
var length = 0'u
var buf = sslEngineSendappBuf(engine, length)
let toWrite = min(int(length), item.size)
copyOut(buf, item, toWrite)
if int(length) >= item.size:
# BearSSL is ready to accept whole item size.
sslEngineSendappAck(engine, uint(item.size))
sslEngineFlush(engine, 0)
item.future.complete()
else:
# BearSSL is not ready to accept whole item, so we will send
# only part of item and adjust offset.
item.offset = item.offset + int(length)
item.size = item.size - int(length)
wstream.queue.addFirstNoWait(item)
sslEngineSendappAck(engine, length)
else:
# Zero length item means finish, so we going to trigger TLS
# closure protocol.
sslEngineClose(engine)
except CancelledError: except CancelledError:
wstream.state = AsyncStreamState.Stopped wstream.state = AsyncStreamState.Stopped
error = newAsyncStreamUseClosedError() error = newAsyncStreamUseClosedError()
@ -222,28 +245,28 @@ proc tlsReadLoop(stream: AsyncStreamReader) {.async.} =
if (state and (SSL_SENDREC or SSL_SENDAPP)) != 0: if (state and (SSL_SENDREC or SSL_SENDAPP)) != 0:
if not(rstream.switchToWriter.isSet()): if not(rstream.switchToWriter.isSet()):
rstream.switchToWriter.fire() rstream.switchToWriter.fire()
if (state and (SSL_RECVREC or SSL_RECVAPP)) == 0:
if (state and SSL_RECVREC) == SSL_RECVREC:
# TLS records required for further processing
var length = 0'u
var buf = sslEngineRecvrecBuf(engine, length)
let res = await rstream.rsource.readOnce(buf, int(length))
if res > 0:
sslEngineRecvrecAck(engine, uint(res))
else:
# readOnce() returns `0` if stream is at EOF.
rstream.state = AsyncStreamState.Finished
sslEngineClose(engine)
elif (state and SSL_RECVAPP) == SSL_RECVAPP:
# Application data can be recovered.
var length = 0'u
var buf = sslEngineRecvappBuf(engine, length)
await upload(addr rstream.buffer, buf, int(length))
sslEngineRecvappAck(engine, length)
else:
await rstream.switchToReader.wait() await rstream.switchToReader.wait()
rstream.switchToReader.clear() rstream.switchToReader.clear()
# We need to refresh `state` because we just returned from writerLoop.
else:
if (state and SSL_RECVREC) == SSL_RECVREC:
# TLS records required for further processing
var length = 0'u
var buf = sslEngineRecvrecBuf(engine, length)
let res = await rstream.rsource.readOnce(buf, int(length))
if res > 0:
sslEngineRecvrecAck(engine, uint(res))
else:
# readOnce() returns `0` if stream is at EOF, so we initiate TLS
# closure procedure.
sslEngineClose(engine)
elif (state and SSL_RECVAPP) == SSL_RECVAPP:
# Application data can be recovered.
var length = 0'u
var buf = sslEngineRecvappBuf(engine, length)
await upload(addr rstream.buffer, buf, int(length))
sslEngineRecvappAck(engine, length)
except CancelledError: except CancelledError:
rstream.state = AsyncStreamState.Stopped rstream.state = AsyncStreamState.Stopped
except AsyncStreamError as exc: except AsyncStreamError as exc:

View File

@ -274,15 +274,12 @@ proc failPendingWriteQueue(queue: var Deque[StreamVector],
vector.writer.fail(error) vector.writer.fail(error)
proc clean(server: StreamServer) {.inline.} = proc clean(server: StreamServer) {.inline.} =
echo "cleaning server instance"
if not(server.loopFuture.finished()): if not(server.loopFuture.finished()):
echo "cleaning server complete()"
untrackServer(server) untrackServer(server)
server.loopFuture.complete() server.loopFuture.complete()
if not isNil(server.udata) and GCUserData in server.flags: if not isNil(server.udata) and GCUserData in server.flags:
GC_unref(cast[ref int](server.udata)) GC_unref(cast[ref int](server.udata))
GC_unref(server) GC_unref(server)
echo "clean server exit"
proc clean(transp: StreamTransport) {.inline.} = proc clean(transp: StreamTransport) {.inline.} =
if not(transp.future.finished()): if not(transp.future.finished()):

View File

@ -634,6 +634,8 @@ suite "TLSStream test suite":
echo "server handshaked" echo "server handshaked"
await sstream.writer.write(testMessage & "\r\n") await sstream.writer.write(testMessage & "\r\n")
echo "server wrote string" echo "server wrote string"
await sstream.writer.finish()
echo "server finished string"
await sstream.writer.closeWait() await sstream.writer.closeWait()
echo "server closed secure writer" echo "server closed secure writer"
await sstream.reader.closeWait() await sstream.reader.closeWait()