Attempt #4.
This commit is contained in:
parent
f1b43aeb04
commit
49fd70f504
|
@ -633,12 +633,9 @@ elif unixPlatform:
|
||||||
|
|
||||||
proc continuation(udata: pointer) =
|
proc continuation(udata: pointer) =
|
||||||
if SocketHandle(fd) in loop.selector:
|
if SocketHandle(fd) in loop.selector:
|
||||||
echo "closeSocket() continuation unregistering"
|
|
||||||
unregister(fd)
|
unregister(fd)
|
||||||
echo "closeSocket() continuation close()"
|
|
||||||
close(SocketHandle(fd))
|
close(SocketHandle(fd))
|
||||||
if not isNil(aftercb):
|
if not isNil(aftercb):
|
||||||
echo "closeSocket() invoke user-callback"
|
|
||||||
aftercb(nil)
|
aftercb(nil)
|
||||||
|
|
||||||
withData(loop.selector, int(fd), adata) do:
|
withData(loop.selector, int(fd), adata) do:
|
||||||
|
@ -648,12 +645,10 @@ elif unixPlatform:
|
||||||
# from system queue for this reader and writer.
|
# from system queue for this reader and writer.
|
||||||
|
|
||||||
if not(isNil(adata.reader.function)):
|
if not(isNil(adata.reader.function)):
|
||||||
echo "closeSocket() scheduling reader"
|
|
||||||
loop.callbacks.addLast(adata.reader)
|
loop.callbacks.addLast(adata.reader)
|
||||||
adata.reader = default(AsyncCallback)
|
adata.reader = default(AsyncCallback)
|
||||||
|
|
||||||
if not(isNil(adata.writer.function)):
|
if not(isNil(adata.writer.function)):
|
||||||
echo "closeSocket() scheduling writer"
|
|
||||||
loop.callbacks.addLast(adata.writer)
|
loop.callbacks.addLast(adata.writer)
|
||||||
adata.writer = default(AsyncCallback)
|
adata.writer = default(AsyncCallback)
|
||||||
|
|
||||||
|
@ -661,7 +656,6 @@ elif unixPlatform:
|
||||||
# in such case processing queue will stuck on poll() call, because there
|
# in such case processing queue will stuck on poll() call, because there
|
||||||
# can be no file descriptors registered in system queue.
|
# can be no file descriptors registered in system queue.
|
||||||
var acb = AsyncCallback(function: continuation)
|
var acb = AsyncCallback(function: continuation)
|
||||||
echo "closeSocket() scheduling actual close"
|
|
||||||
loop.callbacks.addLast(acb)
|
loop.callbacks.addLast(acb)
|
||||||
|
|
||||||
proc closeHandle*(fd: AsyncFD, aftercb: CallbackFunc = nil) {.inline.} =
|
proc closeHandle*(fd: AsyncFD, aftercb: CallbackFunc = nil) {.inline.} =
|
||||||
|
|
|
@ -65,6 +65,7 @@ type
|
||||||
switchToWriter*: AsyncEvent
|
switchToWriter*: AsyncEvent
|
||||||
handshaked*: bool
|
handshaked*: bool
|
||||||
handshakeFut*: Future[void]
|
handshakeFut*: Future[void]
|
||||||
|
closeshakeFut*: Future[void]
|
||||||
|
|
||||||
TLSStreamReader* = ref object of AsyncStreamReader
|
TLSStreamReader* = ref object of AsyncStreamReader
|
||||||
case kind: TLSStreamKind
|
case kind: TLSStreamKind
|
||||||
|
@ -77,6 +78,7 @@ type
|
||||||
switchToWriter*: AsyncEvent
|
switchToWriter*: AsyncEvent
|
||||||
handshaked*: bool
|
handshaked*: bool
|
||||||
handshakeFut*: Future[void]
|
handshakeFut*: Future[void]
|
||||||
|
closeshakeFut*: Future[void]
|
||||||
|
|
||||||
TLSAsyncStream* = ref object of RootRef
|
TLSAsyncStream* = ref object of RootRef
|
||||||
xwc*: X509NoAnchorContext
|
xwc*: X509NoAnchorContext
|
||||||
|
@ -110,6 +112,25 @@ template newTLSStreamProtocolError[T](message: T): ref TLSStreamProtocolError =
|
||||||
err.errCode = code
|
err.errCode = code
|
||||||
err
|
err
|
||||||
|
|
||||||
|
proc dumpState*(state: cuint): string =
|
||||||
|
var res = ""
|
||||||
|
if (state and SSL_CLOSED) == SSL_CLOSED:
|
||||||
|
if len(res) > 0: res.add(", ")
|
||||||
|
res.add("SSL_CLOSED")
|
||||||
|
if (state and SSL_SENDREC) == SSL_SENDREC:
|
||||||
|
if len(res) > 0: res.add(", ")
|
||||||
|
res.add("SSL_SENDREC")
|
||||||
|
if (state and SSL_SENDAPP) == SSL_SENDAPP:
|
||||||
|
if len(res) > 0: res.add(", ")
|
||||||
|
res.add("SSL_SENDAPP")
|
||||||
|
if (state and SSL_RECVREC) == SSL_RECVREC:
|
||||||
|
if len(res) > 0: res.add(", ")
|
||||||
|
res.add("SSL_RECVREC")
|
||||||
|
if (state and SSL_RECVAPP) == SSL_RECVAPP:
|
||||||
|
if len(res) > 0: res.add(", ")
|
||||||
|
res.add("SSL_RECVAPP")
|
||||||
|
"{" & res & "}"
|
||||||
|
|
||||||
template raiseTLSStreamProtoError*[T](message: T) =
|
template raiseTLSStreamProtoError*[T](message: T) =
|
||||||
raise newTLSStreamProtocolError(message)
|
raise newTLSStreamProtocolError(message)
|
||||||
|
|
||||||
|
@ -135,47 +156,49 @@ proc tlsWriteLoop(stream: AsyncStreamWriter) {.async.} =
|
||||||
if (state and (SSL_RECVREC or SSL_RECVAPP)) != 0:
|
if (state and (SSL_RECVREC or SSL_RECVAPP)) != 0:
|
||||||
if not(wstream.switchToReader.isSet()):
|
if not(wstream.switchToReader.isSet()):
|
||||||
wstream.switchToReader.fire()
|
wstream.switchToReader.fire()
|
||||||
if (state and (SSL_SENDREC or SSL_SENDAPP)) == 0:
|
|
||||||
|
if (state and SSL_SENDREC) == SSL_SENDREC:
|
||||||
|
# TLS record needs to be sent over stream.
|
||||||
|
var length = 0'u
|
||||||
|
var buf = sslEngineSendrecBuf(engine, length)
|
||||||
|
doAssert(length != 0 and not isNil(buf))
|
||||||
|
await wstream.wsource.write(buf, int(length))
|
||||||
|
sslEngineSendrecAck(engine, length)
|
||||||
|
elif (state and SSL_SENDAPP) == SSL_SENDAPP:
|
||||||
|
# Application data can be sent over stream.
|
||||||
|
if not(wstream.handshaked):
|
||||||
|
wstream.stream.reader.handshaked = true
|
||||||
|
wstream.handshaked = true
|
||||||
|
if not(isNil(wstream.handshakeFut)):
|
||||||
|
wstream.handshakeFut.complete()
|
||||||
|
item = await wstream.queue.get()
|
||||||
|
if item.size > 0:
|
||||||
|
var length = 0'u
|
||||||
|
var buf = sslEngineSendappBuf(engine, length)
|
||||||
|
let toWrite = min(int(length), item.size)
|
||||||
|
copyOut(buf, item, toWrite)
|
||||||
|
if int(length) >= item.size:
|
||||||
|
# BearSSL is ready to accept whole item size.
|
||||||
|
sslEngineSendappAck(engine, uint(item.size))
|
||||||
|
sslEngineFlush(engine, 0)
|
||||||
|
item.future.complete()
|
||||||
|
else:
|
||||||
|
# BearSSL is not ready to accept whole item, so we will send
|
||||||
|
# only part of item and adjust offset.
|
||||||
|
item.offset = item.offset + int(length)
|
||||||
|
item.size = item.size - int(length)
|
||||||
|
wstream.queue.addFirstNoWait(item)
|
||||||
|
sslEngineSendappAck(engine, length)
|
||||||
|
else:
|
||||||
|
# Zero length item means finish, so we going to trigger TLS
|
||||||
|
# closure protocol.
|
||||||
|
wstream.state = AsyncStreamState.Finished
|
||||||
|
sslEngineClose(engine)
|
||||||
|
item.future.complete()
|
||||||
|
else:
|
||||||
await wstream.switchToWriter.wait()
|
await wstream.switchToWriter.wait()
|
||||||
wstream.switchToWriter.clear()
|
wstream.switchToWriter.clear()
|
||||||
# We need to refresh `state` because we just returned from readerLoop.
|
|
||||||
else:
|
|
||||||
if (state and SSL_SENDREC) == SSL_SENDREC:
|
|
||||||
# TLS record needs to be sent over stream.
|
|
||||||
var length = 0'u
|
|
||||||
var buf = sslEngineSendrecBuf(engine, length)
|
|
||||||
doAssert(length != 0 and not isNil(buf))
|
|
||||||
await wstream.wsource.write(buf, int(length))
|
|
||||||
sslEngineSendrecAck(engine, length)
|
|
||||||
elif (state and SSL_SENDAPP) == SSL_SENDAPP:
|
|
||||||
# Application data can be sent over stream.
|
|
||||||
if not(wstream.handshaked):
|
|
||||||
wstream.stream.reader.handshaked = true
|
|
||||||
wstream.handshaked = true
|
|
||||||
if not(isNil(wstream.handshakeFut)):
|
|
||||||
wstream.handshakeFut.complete()
|
|
||||||
item = await wstream.queue.get()
|
|
||||||
if item.size > 0:
|
|
||||||
var length = 0'u
|
|
||||||
var buf = sslEngineSendappBuf(engine, length)
|
|
||||||
let toWrite = min(int(length), item.size)
|
|
||||||
copyOut(buf, item, toWrite)
|
|
||||||
if int(length) >= item.size:
|
|
||||||
# BearSSL is ready to accept whole item size.
|
|
||||||
sslEngineSendappAck(engine, uint(item.size))
|
|
||||||
sslEngineFlush(engine, 0)
|
|
||||||
item.future.complete()
|
|
||||||
else:
|
|
||||||
# BearSSL is not ready to accept whole item, so we will send
|
|
||||||
# only part of item and adjust offset.
|
|
||||||
item.offset = item.offset + int(length)
|
|
||||||
item.size = item.size - int(length)
|
|
||||||
wstream.queue.addFirstNoWait(item)
|
|
||||||
sslEngineSendappAck(engine, length)
|
|
||||||
else:
|
|
||||||
# Zero length item means finish, so we going to trigger TLS
|
|
||||||
# closure protocol.
|
|
||||||
sslEngineClose(engine)
|
|
||||||
except CancelledError:
|
except CancelledError:
|
||||||
wstream.state = AsyncStreamState.Stopped
|
wstream.state = AsyncStreamState.Stopped
|
||||||
error = newAsyncStreamUseClosedError()
|
error = newAsyncStreamUseClosedError()
|
||||||
|
@ -222,28 +245,28 @@ proc tlsReadLoop(stream: AsyncStreamReader) {.async.} =
|
||||||
if (state and (SSL_SENDREC or SSL_SENDAPP)) != 0:
|
if (state and (SSL_SENDREC or SSL_SENDAPP)) != 0:
|
||||||
if not(rstream.switchToWriter.isSet()):
|
if not(rstream.switchToWriter.isSet()):
|
||||||
rstream.switchToWriter.fire()
|
rstream.switchToWriter.fire()
|
||||||
if (state and (SSL_RECVREC or SSL_RECVAPP)) == 0:
|
|
||||||
|
if (state and SSL_RECVREC) == SSL_RECVREC:
|
||||||
|
# TLS records required for further processing
|
||||||
|
var length = 0'u
|
||||||
|
var buf = sslEngineRecvrecBuf(engine, length)
|
||||||
|
let res = await rstream.rsource.readOnce(buf, int(length))
|
||||||
|
if res > 0:
|
||||||
|
sslEngineRecvrecAck(engine, uint(res))
|
||||||
|
else:
|
||||||
|
# readOnce() returns `0` if stream is at EOF.
|
||||||
|
rstream.state = AsyncStreamState.Finished
|
||||||
|
sslEngineClose(engine)
|
||||||
|
elif (state and SSL_RECVAPP) == SSL_RECVAPP:
|
||||||
|
# Application data can be recovered.
|
||||||
|
var length = 0'u
|
||||||
|
var buf = sslEngineRecvappBuf(engine, length)
|
||||||
|
await upload(addr rstream.buffer, buf, int(length))
|
||||||
|
sslEngineRecvappAck(engine, length)
|
||||||
|
else:
|
||||||
await rstream.switchToReader.wait()
|
await rstream.switchToReader.wait()
|
||||||
rstream.switchToReader.clear()
|
rstream.switchToReader.clear()
|
||||||
# We need to refresh `state` because we just returned from writerLoop.
|
|
||||||
else:
|
|
||||||
if (state and SSL_RECVREC) == SSL_RECVREC:
|
|
||||||
# TLS records required for further processing
|
|
||||||
var length = 0'u
|
|
||||||
var buf = sslEngineRecvrecBuf(engine, length)
|
|
||||||
let res = await rstream.rsource.readOnce(buf, int(length))
|
|
||||||
if res > 0:
|
|
||||||
sslEngineRecvrecAck(engine, uint(res))
|
|
||||||
else:
|
|
||||||
# readOnce() returns `0` if stream is at EOF, so we initiate TLS
|
|
||||||
# closure procedure.
|
|
||||||
sslEngineClose(engine)
|
|
||||||
elif (state and SSL_RECVAPP) == SSL_RECVAPP:
|
|
||||||
# Application data can be recovered.
|
|
||||||
var length = 0'u
|
|
||||||
var buf = sslEngineRecvappBuf(engine, length)
|
|
||||||
await upload(addr rstream.buffer, buf, int(length))
|
|
||||||
sslEngineRecvappAck(engine, length)
|
|
||||||
except CancelledError:
|
except CancelledError:
|
||||||
rstream.state = AsyncStreamState.Stopped
|
rstream.state = AsyncStreamState.Stopped
|
||||||
except AsyncStreamError as exc:
|
except AsyncStreamError as exc:
|
||||||
|
|
|
@ -274,15 +274,12 @@ proc failPendingWriteQueue(queue: var Deque[StreamVector],
|
||||||
vector.writer.fail(error)
|
vector.writer.fail(error)
|
||||||
|
|
||||||
proc clean(server: StreamServer) {.inline.} =
|
proc clean(server: StreamServer) {.inline.} =
|
||||||
echo "cleaning server instance"
|
|
||||||
if not(server.loopFuture.finished()):
|
if not(server.loopFuture.finished()):
|
||||||
echo "cleaning server complete()"
|
|
||||||
untrackServer(server)
|
untrackServer(server)
|
||||||
server.loopFuture.complete()
|
server.loopFuture.complete()
|
||||||
if not isNil(server.udata) and GCUserData in server.flags:
|
if not isNil(server.udata) and GCUserData in server.flags:
|
||||||
GC_unref(cast[ref int](server.udata))
|
GC_unref(cast[ref int](server.udata))
|
||||||
GC_unref(server)
|
GC_unref(server)
|
||||||
echo "clean server exit"
|
|
||||||
|
|
||||||
proc clean(transp: StreamTransport) {.inline.} =
|
proc clean(transp: StreamTransport) {.inline.} =
|
||||||
if not(transp.future.finished()):
|
if not(transp.future.finished()):
|
||||||
|
|
|
@ -634,6 +634,8 @@ suite "TLSStream test suite":
|
||||||
echo "server handshaked"
|
echo "server handshaked"
|
||||||
await sstream.writer.write(testMessage & "\r\n")
|
await sstream.writer.write(testMessage & "\r\n")
|
||||||
echo "server wrote string"
|
echo "server wrote string"
|
||||||
|
await sstream.writer.finish()
|
||||||
|
echo "server finished string"
|
||||||
await sstream.writer.closeWait()
|
await sstream.writer.closeWait()
|
||||||
echo "server closed secure writer"
|
echo "server closed secure writer"
|
||||||
await sstream.reader.closeWait()
|
await sstream.reader.closeWait()
|
||||||
|
|
Loading…
Reference in New Issue