Simplification and fixes for TLSStream state machine.
This commit is contained in:
parent
e8d2a3ca0a
commit
13eddf382d
|
@ -409,7 +409,6 @@ when defined(windows) or defined(nimdoc):
|
||||||
|
|
||||||
proc poll*() =
|
proc poll*() =
|
||||||
## Perform single asynchronous step.
|
## Perform single asynchronous step.
|
||||||
echo "poll()"
|
|
||||||
let loop = getThreadDispatcher()
|
let loop = getThreadDispatcher()
|
||||||
var curTime = Moment.now()
|
var curTime = Moment.now()
|
||||||
var curTimeout = DWORD(0)
|
var curTimeout = DWORD(0)
|
||||||
|
@ -423,8 +422,6 @@ when defined(windows) or defined(nimdoc):
|
||||||
var lpCompletionKey: ULONG_PTR
|
var lpCompletionKey: ULONG_PTR
|
||||||
var customOverlapped: PtrCustomOverlapped
|
var customOverlapped: PtrCustomOverlapped
|
||||||
|
|
||||||
echo "poll() timeout = ", curTimeout, ", len(callbacks) = ", len(loop.callbacks)
|
|
||||||
|
|
||||||
let res = getQueuedCompletionStatus(
|
let res = getQueuedCompletionStatus(
|
||||||
loop.ioPort, addr lpNumberOfBytesTransferred,
|
loop.ioPort, addr lpNumberOfBytesTransferred,
|
||||||
addr lpCompletionKey, cast[ptr POVERLAPPED](addr customOverlapped),
|
addr lpCompletionKey, cast[ptr POVERLAPPED](addr customOverlapped),
|
||||||
|
@ -460,7 +457,6 @@ when defined(windows) or defined(nimdoc):
|
||||||
# All callbacks which will be added in process will be processed on next
|
# All callbacks which will be added in process will be processed on next
|
||||||
# poll() call.
|
# poll() call.
|
||||||
loop.processCallbacks()
|
loop.processCallbacks()
|
||||||
echo "exit poll()"
|
|
||||||
|
|
||||||
proc closeSocket*(fd: AsyncFD, aftercb: CallbackFunc = nil) =
|
proc closeSocket*(fd: AsyncFD, aftercb: CallbackFunc = nil) =
|
||||||
## Closes a socket and ensures that it is unregistered.
|
## Closes a socket and ensures that it is unregistered.
|
||||||
|
|
|
@ -80,6 +80,7 @@ type
|
||||||
writerLoop*: StreamWriterLoop
|
writerLoop*: StreamWriterLoop
|
||||||
state*: AsyncStreamState
|
state*: AsyncStreamState
|
||||||
queue*: AsyncQueue[WriteItem]
|
queue*: AsyncQueue[WriteItem]
|
||||||
|
error*: ref AsyncStreamError
|
||||||
udata: pointer
|
udata: pointer
|
||||||
bytesCount*: uint64
|
bytesCount*: uint64
|
||||||
future: Future[void]
|
future: Future[void]
|
||||||
|
|
|
@ -61,8 +61,6 @@ type
|
||||||
of TLSStreamKind.Server:
|
of TLSStreamKind.Server:
|
||||||
scontext: ptr SslServerContext
|
scontext: ptr SslServerContext
|
||||||
stream*: TLSAsyncStream
|
stream*: TLSAsyncStream
|
||||||
switchToReader*: AsyncEvent
|
|
||||||
switchToWriter*: AsyncEvent
|
|
||||||
handshaked*: bool
|
handshaked*: bool
|
||||||
handshakeFut*: Future[void]
|
handshakeFut*: Future[void]
|
||||||
|
|
||||||
|
@ -73,8 +71,6 @@ type
|
||||||
of TLSStreamKind.Server:
|
of TLSStreamKind.Server:
|
||||||
scontext: ptr SslServerContext
|
scontext: ptr SslServerContext
|
||||||
stream*: TLSAsyncStream
|
stream*: TLSAsyncStream
|
||||||
switchToReader*: AsyncEvent
|
|
||||||
switchToWriter*: AsyncEvent
|
|
||||||
handshaked*: bool
|
handshaked*: bool
|
||||||
handshakeFut*: Future[void]
|
handshakeFut*: Future[void]
|
||||||
|
|
||||||
|
@ -86,13 +82,33 @@ type
|
||||||
x509*: X509MinimalContext
|
x509*: X509MinimalContext
|
||||||
reader*: TLSStreamReader
|
reader*: TLSStreamReader
|
||||||
writer*: TLSStreamWriter
|
writer*: TLSStreamWriter
|
||||||
|
mainLoop*: Future[void]
|
||||||
|
|
||||||
SomeTLSStreamType* = TLSStreamReader|TLSStreamWriter|TLSAsyncStream
|
SomeTLSStreamType* = TLSStreamReader|TLSStreamWriter|TLSAsyncStream
|
||||||
|
|
||||||
TLSStreamError* = object of AsyncStreamError
|
TLSStreamError* = object of AsyncStreamError
|
||||||
|
TLSStreamHandshakeError* = object of TLSStreamError
|
||||||
|
TLSStreamReadError* = object of TLSStreamError
|
||||||
|
par*: ref AsyncStreamError
|
||||||
|
TLSStreamWriteError* = object of TLSStreamError
|
||||||
|
par*: ref AsyncStreamError
|
||||||
TLSStreamProtocolError* = object of TLSStreamError
|
TLSStreamProtocolError* = object of TLSStreamError
|
||||||
errCode*: int
|
errCode*: int
|
||||||
|
|
||||||
|
proc newTLSStreamReadError(p: ref AsyncStreamError): ref TLSStreamReadError {.
|
||||||
|
inline.} =
|
||||||
|
var w = newException(TLSStreamReadError, "Read stream failed")
|
||||||
|
w.msg = w.msg & ", originated from [" & $p.name & "] " & p.msg
|
||||||
|
w.par = p
|
||||||
|
w
|
||||||
|
|
||||||
|
proc newTLSStreamWriteError(p: ref AsyncStreamError): ref TLSStreamWriteError {.
|
||||||
|
inline.} =
|
||||||
|
var w = newException(TLSStreamWriteError, "Write stream failed")
|
||||||
|
w.msg = w.msg & ", originated from [" & $p.name & "] " & p.msg
|
||||||
|
w.par = p
|
||||||
|
w
|
||||||
|
|
||||||
template newTLSStreamProtocolError[T](message: T): ref TLSStreamProtocolError =
|
template newTLSStreamProtocolError[T](message: T): ref TLSStreamProtocolError =
|
||||||
var msg = ""
|
var msg = ""
|
||||||
var code = 0
|
var code = 0
|
||||||
|
@ -110,223 +126,241 @@ template newTLSStreamProtocolError[T](message: T): ref TLSStreamProtocolError =
|
||||||
err.errCode = code
|
err.errCode = code
|
||||||
err
|
err
|
||||||
|
|
||||||
proc dumpState*(state: cuint): string =
|
proc tlsWriteRec(engine: ptr SslEngineContext,
|
||||||
var res = ""
|
writer: TLSStreamWriter): Future[bool] {.async.} =
|
||||||
if (state and SSL_CLOSED) == SSL_CLOSED:
|
try:
|
||||||
if len(res) > 0: res.add(", ")
|
var length = 0'u
|
||||||
res.add("SSL_CLOSED")
|
var buf = sslEngineSendrecBuf(engine, length)
|
||||||
if (state and SSL_SENDREC) == SSL_SENDREC:
|
doAssert(length != 0 and not isNil(buf))
|
||||||
if len(res) > 0: res.add(", ")
|
await writer.wsource.write(buf, int(length))
|
||||||
res.add("SSL_SENDREC")
|
sslEngineSendrecAck(engine, length)
|
||||||
if (state and SSL_SENDAPP) == SSL_SENDAPP:
|
return true
|
||||||
if len(res) > 0: res.add(", ")
|
except AsyncStreamError as exc:
|
||||||
res.add("SSL_SENDAPP")
|
writer.state = AsyncStreamState.Error
|
||||||
if (state and SSL_RECVREC) == SSL_RECVREC:
|
writer.error = exc
|
||||||
if len(res) > 0: res.add(", ")
|
except CancelledError:
|
||||||
res.add("SSL_RECVREC")
|
writer.state = AsyncStreamState.Stopped
|
||||||
if (state and SSL_RECVAPP) == SSL_RECVAPP:
|
|
||||||
if len(res) > 0: res.add(", ")
|
return false
|
||||||
res.add("SSL_RECVAPP")
|
|
||||||
"{" & res & "}"
|
proc tlsWriteApp(engine: ptr SslEngineContext,
|
||||||
|
writer: TLSStreamWriter): Future[bool] {.async.} =
|
||||||
|
try:
|
||||||
|
var item = await writer.queue.get()
|
||||||
|
if item.size > 0:
|
||||||
|
var length = 0'u
|
||||||
|
var buf = sslEngineSendappBuf(engine, length)
|
||||||
|
let toWrite = min(int(length), item.size)
|
||||||
|
copyOut(buf, item, toWrite)
|
||||||
|
if int(length) >= item.size:
|
||||||
|
# BearSSL is ready to accept whole item size.
|
||||||
|
sslEngineSendappAck(engine, uint(item.size))
|
||||||
|
sslEngineFlush(engine, 0)
|
||||||
|
item.future.complete()
|
||||||
|
return true
|
||||||
|
else:
|
||||||
|
# BearSSL is not ready to accept whole item, so we will send
|
||||||
|
# only part of item and adjust offset.
|
||||||
|
item.offset = item.offset + int(length)
|
||||||
|
item.size = item.size - int(length)
|
||||||
|
writer.queue.addFirstNoWait(item)
|
||||||
|
sslEngineSendappAck(engine, length)
|
||||||
|
return true
|
||||||
|
else:
|
||||||
|
sslEngineClose(engine)
|
||||||
|
item.future.complete()
|
||||||
|
return true
|
||||||
|
except CancelledError:
|
||||||
|
writer.state = AsyncStreamState.Stopped
|
||||||
|
|
||||||
|
return false
|
||||||
|
|
||||||
|
proc tlsReadRec(engine: ptr SslEngineContext,
|
||||||
|
reader: TLSStreamReader): Future[bool] {.async.} =
|
||||||
|
try:
|
||||||
|
var length = 0'u
|
||||||
|
var buf = sslEngineRecvrecBuf(engine, length)
|
||||||
|
let res = await reader.rsource.readOnce(buf, int(length))
|
||||||
|
sslEngineRecvrecAck(engine, uint(res))
|
||||||
|
return true
|
||||||
|
except CancelledError:
|
||||||
|
reader.state = AsyncStreamState.Stopped
|
||||||
|
except AsyncStreamError as exc:
|
||||||
|
reader.state = AsyncStreamState.Error
|
||||||
|
reader.error = exc
|
||||||
|
return false
|
||||||
|
|
||||||
|
proc tlsReadApp(engine: ptr SslEngineContext,
|
||||||
|
reader: TLSStreamReader): Future[bool] {.async.} =
|
||||||
|
try:
|
||||||
|
var length = 0'u
|
||||||
|
var buf = sslEngineRecvappBuf(engine, length)
|
||||||
|
await upload(addr reader.buffer, buf, int(length))
|
||||||
|
sslEngineRecvappAck(engine, length)
|
||||||
|
return true
|
||||||
|
except CancelledError:
|
||||||
|
reader.state = AsyncStreamState.Stopped
|
||||||
|
return false
|
||||||
|
|
||||||
template raiseTLSStreamProtoError*[T](message: T) =
|
template raiseTLSStreamProtoError*[T](message: T) =
|
||||||
raise newTLSStreamProtocolError(message)
|
raise newTLSStreamProtocolError(message)
|
||||||
|
|
||||||
proc tlsWriteLoop(stream: AsyncStreamWriter) {.async.} =
|
template readAndReset(fut: untyped) =
|
||||||
var wstream = cast[TLSStreamWriter](stream)
|
if fut.finished():
|
||||||
var engine: ptr SslEngineContext
|
if fut.read():
|
||||||
var error: ref AsyncStreamError
|
fut = nil
|
||||||
|
continue
|
||||||
|
else:
|
||||||
|
fut = nil
|
||||||
|
loopState = AsyncStreamState.Error
|
||||||
|
break
|
||||||
|
|
||||||
if wstream.kind == TLSStreamKind.Server:
|
proc cancelAndWait*(a, b, c, d: Future[bool]): Future[void] =
|
||||||
engine = addr wstream.scontext.eng
|
var waiting: seq[Future[bool]]
|
||||||
else:
|
if not(isNil(a)) and not(a.finished()):
|
||||||
engine = addr wstream.ccontext.eng
|
a.cancel()
|
||||||
|
waiting.add(a)
|
||||||
|
if not(isNil(b)) and not(b.finished()):
|
||||||
|
b.cancel()
|
||||||
|
waiting.add(b)
|
||||||
|
if not(isNil(c)) and not(c.finished()):
|
||||||
|
c.cancel()
|
||||||
|
waiting.add(c)
|
||||||
|
if not(isNil(d)) and not(d.finished()):
|
||||||
|
d.cancel()
|
||||||
|
waiting.add(d)
|
||||||
|
allFutures(waiting)
|
||||||
|
|
||||||
wstream.state = AsyncStreamState.Running
|
proc tlsLoop*(stream: TLSAsyncStream) {.async.} =
|
||||||
|
var
|
||||||
|
sendRecFut, sendAppFut: Future[bool]
|
||||||
|
recvRecFut, recvAppFut: Future[bool]
|
||||||
|
|
||||||
|
let engine =
|
||||||
|
case stream.reader.kind
|
||||||
|
of TLSStreamKind.Server:
|
||||||
|
addr stream.scontext.eng
|
||||||
|
of TLSStreamKind.Client:
|
||||||
|
addr stream.ccontext.eng
|
||||||
|
|
||||||
|
var loopState = AsyncStreamState.Running
|
||||||
|
|
||||||
while true:
|
while true:
|
||||||
var item: WriteItem
|
var waiting: seq[Future[bool]]
|
||||||
try:
|
var state = sslEngineCurrentState(engine)
|
||||||
var state = engine.sslEngineCurrentState()
|
|
||||||
echo "tlsWriteLoop(", cast[uint](wstream.stream), ") state = ", dumpState(state)
|
|
||||||
if (state and SSL_CLOSED) == SSL_CLOSED:
|
|
||||||
wstream.state = AsyncStreamState.Finished
|
|
||||||
else:
|
|
||||||
if (state and (SSL_RECVREC or SSL_RECVAPP)) != 0:
|
|
||||||
echo "tlsWriteLoop(", cast[uint](wstream.stream), ") firing switch to reader, waiters = ", len(wstream.switchToReader.waiters)
|
|
||||||
wstream.switchToReader.fire()
|
|
||||||
|
|
||||||
if (state and SSL_SENDREC) == SSL_SENDREC:
|
if (state and SSL_CLOSED) == SSL_CLOSED:
|
||||||
# TLS record needs to be sent over stream.
|
loopState = AsyncStreamState.Finished
|
||||||
var length = 0'u
|
|
||||||
var buf = sslEngineSendrecBuf(engine, length)
|
|
||||||
doAssert(length != 0 and not isNil(buf))
|
|
||||||
echo "tlsWriteLoop(", cast[uint](wstream.stream), ") sending record ", int(length), " bytes"
|
|
||||||
await wstream.wsource.write(buf, int(length))
|
|
||||||
sslEngineSendrecAck(engine, length)
|
|
||||||
echo "tlsWriteLoop(", cast[uint](wstream.stream), ") record ", int(length), " bytes sent"
|
|
||||||
elif (state and SSL_SENDAPP) == SSL_SENDAPP:
|
|
||||||
# Application data can be sent over stream.
|
|
||||||
if not(wstream.handshaked):
|
|
||||||
wstream.stream.reader.handshaked = true
|
|
||||||
wstream.handshaked = true
|
|
||||||
if not(isNil(wstream.handshakeFut)):
|
|
||||||
wstream.handshakeFut.complete()
|
|
||||||
if not(wstream.queue.empty()):
|
|
||||||
echo "tlsWriteLoop(", cast[uint](wstream.stream), ") waiting for appdata"
|
|
||||||
item = await wstream.queue.get()
|
|
||||||
echo "tlsWriteLoop(", cast[uint](wstream.stream), ") sending appdata ", int(item.size), " bytes"
|
|
||||||
if item.size > 0:
|
|
||||||
var length = 0'u
|
|
||||||
var buf = sslEngineSendappBuf(engine, length)
|
|
||||||
let toWrite = min(int(length), item.size)
|
|
||||||
copyOut(buf, item, toWrite)
|
|
||||||
if int(length) >= item.size:
|
|
||||||
# BearSSL is ready to accept whole item size.
|
|
||||||
sslEngineSendappAck(engine, uint(item.size))
|
|
||||||
sslEngineFlush(engine, 0)
|
|
||||||
item.future.complete()
|
|
||||||
else:
|
|
||||||
# BearSSL is not ready to accept whole item, so we will send
|
|
||||||
# only part of item and adjust offset.
|
|
||||||
item.offset = item.offset + int(length)
|
|
||||||
item.size = item.size - int(length)
|
|
||||||
wstream.queue.addFirstNoWait(item)
|
|
||||||
sslEngineSendappAck(engine, length)
|
|
||||||
else:
|
|
||||||
# Zero length item means finish, so we going to trigger TLS
|
|
||||||
# closure protocol.
|
|
||||||
sslEngineClose(engine)
|
|
||||||
echo "tlsWriteLoop(", cast[uint](wstream.stream), ") ",
|
|
||||||
"received zero-length item, state = ",
|
|
||||||
dumpState(engine.sslEngineCurrentState())
|
|
||||||
else:
|
|
||||||
echo "tlsWriteLoop(", cast[uint](wstream.stream), ") empty queue"
|
|
||||||
echo "tlsWriteLoop(", cast[uint](wstream.stream), ") waiting for switch back"
|
|
||||||
await wstream.switchToWriter.wait()
|
|
||||||
echo "tlsWriteLoop(", cast[uint](wstream.stream), ") got flow after switch"
|
|
||||||
wstream.switchToWriter.clear()
|
|
||||||
else:
|
|
||||||
echo "tlsWriteLoop(", cast[uint](wstream.stream), ") waiting for switch back, switchToReader.isSet() == ", wstream.switchToReader.isSet()
|
|
||||||
await wstream.switchToWriter.wait()
|
|
||||||
echo "tlsWriteLoop(", cast[uint](wstream.stream), ") got flow after switch"
|
|
||||||
wstream.switchToWriter.clear()
|
|
||||||
|
|
||||||
except CancelledError:
|
|
||||||
echo "tlsWriteLoop(", cast[uint](wstream.stream), ") received cancellation"
|
|
||||||
wstream.state = AsyncStreamState.Stopped
|
|
||||||
error = newAsyncStreamUseClosedError()
|
|
||||||
except AsyncStreamError as exc:
|
|
||||||
echo "tlsWriteLoop(", cast[uint](wstream.stream), ") got an exception ",
|
|
||||||
exc.msg
|
|
||||||
wstream.state = AsyncStreamState.Error
|
|
||||||
error = exc
|
|
||||||
|
|
||||||
if wstream.state != AsyncStreamState.Running:
|
|
||||||
if wstream.state == AsyncStreamState.Finished:
|
|
||||||
error = newAsyncStreamUseClosedError()
|
|
||||||
else:
|
|
||||||
if not(isNil(item.future)):
|
|
||||||
if not(item.future.finished()):
|
|
||||||
item.future.fail(error)
|
|
||||||
while not(wstream.queue.empty()):
|
|
||||||
let pitem = wstream.queue.popFirstNoWait()
|
|
||||||
if not(pitem.future.finished()):
|
|
||||||
pitem.future.fail(error)
|
|
||||||
|
|
||||||
if not(isNil(wstream.stream.reader)):
|
|
||||||
wstream.switchToReader.fire()
|
|
||||||
|
|
||||||
wstream.stream.writer = nil
|
|
||||||
|
|
||||||
echo "tlsWriteLoop(", cast[uint](wstream.stream), ") handle exited"
|
|
||||||
break
|
break
|
||||||
echo "tlsWriteLoop(", cast[uint](wstream.stream), ") exited"
|
|
||||||
|
if isNil(sendRecFut):
|
||||||
|
if (state and SSL_SENDREC) == SSL_SENDREC:
|
||||||
|
sendRecFut = tlsWriteRec(engine, stream.writer)
|
||||||
|
else:
|
||||||
|
sendRecFut.readAndReset()
|
||||||
|
|
||||||
|
if isNil(sendAppFut):
|
||||||
|
if (state and SSL_SENDAPP) == SSL_SENDAPP:
|
||||||
|
# Application data can be sent over stream.
|
||||||
|
if not(stream.writer.handshaked):
|
||||||
|
stream.reader.handshaked = true
|
||||||
|
stream.writer.handshaked = true
|
||||||
|
if not(isNil(stream.writer.handshakeFut)):
|
||||||
|
stream.writer.handshakeFut.complete()
|
||||||
|
|
||||||
|
sendAppFut = tlsWriteApp(engine, stream.writer)
|
||||||
|
else:
|
||||||
|
sendAppFut.readAndReset()
|
||||||
|
|
||||||
|
if isNil(recvRecFut):
|
||||||
|
if (state and SSL_RECVREC) == SSL_RECVREC:
|
||||||
|
recvRecFut = tlsReadRec(engine, stream.reader)
|
||||||
|
else:
|
||||||
|
recvRecFut.readAndReset()
|
||||||
|
|
||||||
|
if isNil(recvAppFut):
|
||||||
|
if (state and SSL_RECVAPP) == SSL_RECVAPP:
|
||||||
|
recvAppFut = tlsReadApp(engine, stream.reader)
|
||||||
|
else:
|
||||||
|
recvAppFut.readAndReset()
|
||||||
|
|
||||||
|
if not(isNil(sendRecFut)):
|
||||||
|
waiting.add(sendRecFut)
|
||||||
|
if not(isNil(sendAppFut)):
|
||||||
|
waiting.add(sendAppFut)
|
||||||
|
if not(isNil(recvRecFut)):
|
||||||
|
waiting.add(recvRecFut)
|
||||||
|
if not(isNil(recvAppFut)):
|
||||||
|
waiting.add(recvAppFut)
|
||||||
|
|
||||||
|
if len(waiting) > 0:
|
||||||
|
try:
|
||||||
|
discard await one(waiting)
|
||||||
|
except CancelledError:
|
||||||
|
loopState = AsyncStreamState.Stopped
|
||||||
|
|
||||||
|
if loopState != AsyncStreamState.Running:
|
||||||
|
break
|
||||||
|
|
||||||
|
# Cancelling and waiting all the pending operations
|
||||||
|
await cancelAndWait(sendRecFut, sendAppFut, recvRecFut, recvAppFut)
|
||||||
|
# Calculating error
|
||||||
|
let error =
|
||||||
|
case loopState
|
||||||
|
of AsyncStreamState.Stopped:
|
||||||
|
newAsyncStreamUseClosedError()
|
||||||
|
of AsyncStreamState.Error:
|
||||||
|
if not(isNil(stream.writer.error)):
|
||||||
|
stream.writer.error
|
||||||
|
else:
|
||||||
|
newTLSStreamWriteError(stream.reader.error)
|
||||||
|
of AsyncStreamState.Finished:
|
||||||
|
let err = engine.sslEngineLastError()
|
||||||
|
if err != 0:
|
||||||
|
newTLSStreamProtocolError(err)
|
||||||
|
else:
|
||||||
|
nil
|
||||||
|
of AsyncStreamState.Running:
|
||||||
|
nil
|
||||||
|
else:
|
||||||
|
nil
|
||||||
|
|
||||||
|
# Syncing state for reader and writer
|
||||||
|
stream.writer.state = loopState
|
||||||
|
if loopState == AsyncStreamState.Error:
|
||||||
|
if isNil(stream.reader.error):
|
||||||
|
stream.reader.error = newTLSStreamReadError(error)
|
||||||
|
stream.reader.state = loopState
|
||||||
|
|
||||||
|
if not(isNil(error)):
|
||||||
|
# Completing all pending writes
|
||||||
|
while(not(stream.writer.queue.empty())):
|
||||||
|
let item = stream.writer.queue.popFirstNoWait()
|
||||||
|
if not(item.future.finished()):
|
||||||
|
item.future.fail(error)
|
||||||
|
# Completing handshake
|
||||||
|
if not(stream.writer.handshaked):
|
||||||
|
if not(isNil(stream.writer.handshakeFut)):
|
||||||
|
if not(stream.writer.handshakeFut.finished()):
|
||||||
|
stream.writer.handshakeFut.fail(error)
|
||||||
|
# Completing readers
|
||||||
|
stream.reader.buffer.forget()
|
||||||
|
|
||||||
|
proc tlsWriteLoop(stream: AsyncStreamWriter) {.async.} =
|
||||||
|
var wstream = cast[TLSStreamWriter](stream)
|
||||||
|
wstream.state = AsyncStreamState.Running
|
||||||
|
await stepsAsync(1)
|
||||||
|
if isNil(wstream.stream.mainLoop):
|
||||||
|
wstream.stream.mainLoop = tlsLoop(wstream.stream)
|
||||||
|
await wstream.stream.mainLoop
|
||||||
|
|
||||||
proc tlsReadLoop(stream: AsyncStreamReader) {.async.} =
|
proc tlsReadLoop(stream: AsyncStreamReader) {.async.} =
|
||||||
var rstream = cast[TLSStreamReader](stream)
|
var rstream = cast[TLSStreamReader](stream)
|
||||||
var engine: ptr SslEngineContext
|
|
||||||
|
|
||||||
if rstream.kind == TLSStreamKind.Server:
|
|
||||||
engine = addr rstream.scontext.eng
|
|
||||||
else:
|
|
||||||
engine = addr rstream.ccontext.eng
|
|
||||||
|
|
||||||
rstream.state = AsyncStreamState.Running
|
rstream.state = AsyncStreamState.Running
|
||||||
|
await stepsAsync(1)
|
||||||
while true:
|
if isNil(rstream.stream.mainLoop):
|
||||||
try:
|
rstream.stream.mainLoop = tlsLoop(rstream.stream)
|
||||||
var state = engine.sslEngineCurrentState()
|
await rstream.stream.mainLoop
|
||||||
echo "tlsReadLoop(", cast[uint](rstream.stream), ") state = ", dumpState(state)
|
|
||||||
if (state and SSL_CLOSED) == SSL_CLOSED:
|
|
||||||
let err = engine.sslEngineLastError()
|
|
||||||
if err != 0:
|
|
||||||
rstream.error = newTLSStreamProtocolError(err)
|
|
||||||
rstream.state = AsyncStreamState.Error
|
|
||||||
else:
|
|
||||||
rstream.state = AsyncStreamState.Finished
|
|
||||||
else:
|
|
||||||
if (state and (SSL_SENDREC or SSL_SENDAPP)) != 0:
|
|
||||||
echo "tlsReadLoop(", cast[uint](rstream.stream), ") ",
|
|
||||||
"firing switch to writer, len(waiters) = ",
|
|
||||||
len(rstream.switchToWriter.waiters)
|
|
||||||
rstream.switchToWriter.fire()
|
|
||||||
|
|
||||||
if (state and SSL_RECVREC) == SSL_RECVREC:
|
|
||||||
# TLS records required for further processing
|
|
||||||
var length = 0'u
|
|
||||||
var buf = sslEngineRecvrecBuf(engine, length)
|
|
||||||
echo "tlsReadLoop(", cast[uint](rstream.stream), ") waiting for record"
|
|
||||||
let res = await rstream.rsource.readOnce(buf, int(length))
|
|
||||||
sslEngineRecvrecAck(engine, uint(res))
|
|
||||||
echo "tlsReadLoop(", cast[uint](rstream.stream), ") received ", res, " length rec, state = ", dumpState(engine.sslEngineCurrentState())
|
|
||||||
# if res > 0:
|
|
||||||
# sslEngineRecvrecAck(engine, uint(res))
|
|
||||||
# else:
|
|
||||||
# echo "tlsReadLoop() received 0 length ack"
|
|
||||||
# # readOnce() returns `0` if stream is at EOF.
|
|
||||||
# # rstream.state = AsyncStreamState.Finished
|
|
||||||
# sslEngineClose(engine)
|
|
||||||
elif (state and SSL_RECVAPP) == SSL_RECVAPP:
|
|
||||||
# Application data can be recovered.
|
|
||||||
var length = 0'u
|
|
||||||
var buf = sslEngineRecvappBuf(engine, length)
|
|
||||||
await upload(addr rstream.buffer, buf, int(length))
|
|
||||||
sslEngineRecvappAck(engine, length)
|
|
||||||
else:
|
|
||||||
echo "tlsReadLoop(", cast[uint](rstream.stream), ") waiting for `switchToReader` back, ",
|
|
||||||
"switchToReader.isSet() == ", rstream.switchToReader.isSet(),
|
|
||||||
", state = ", dumpState(engine.sslEngineCurrentState())
|
|
||||||
await rstream.switchToReader.wait()
|
|
||||||
echo "tlsReadLoop(", cast[uint](rstream.stream), ") got flow after switch"
|
|
||||||
rstream.switchToReader.clear()
|
|
||||||
|
|
||||||
except CancelledError:
|
|
||||||
echo "tlsReadLoop(", cast[uint](rstream.stream), ") cancellation received"
|
|
||||||
rstream.state = AsyncStreamState.Stopped
|
|
||||||
except AsyncStreamError as exc:
|
|
||||||
echo "tlsReadLoop(", cast[uint](rstream.stream), ") got an exception ",
|
|
||||||
exc.msg
|
|
||||||
rstream.error = exc
|
|
||||||
rstream.state = AsyncStreamState.Error
|
|
||||||
|
|
||||||
if rstream.state != AsyncStreamState.Running:
|
|
||||||
if not(rstream.handshaked):
|
|
||||||
rstream.handshaked = true
|
|
||||||
rstream.stream.writer.handshaked = true
|
|
||||||
if not(isNil(rstream.handshakeFut)):
|
|
||||||
rstream.handshakeFut.fail(rstream.error)
|
|
||||||
|
|
||||||
# Perform TLS cleanup procedure
|
|
||||||
if not(isNil(rstream.stream.writer)):
|
|
||||||
rstream.switchToWriter.fire()
|
|
||||||
if rstream.state != AsyncStreamState.Finished:
|
|
||||||
sslEngineClose(engine)
|
|
||||||
rstream.buffer.forget()
|
|
||||||
rstream.stream.reader = nil
|
|
||||||
echo "tlsReadLoop(", cast[uint](rstream.stream), ") handle exited"
|
|
||||||
break
|
|
||||||
echo "tlsReadLoop(", cast[uint](rstream.stream), ") exited"
|
|
||||||
|
|
||||||
proc getSignerAlgo(xc: X509Certificate): int =
|
proc getSignerAlgo(xc: X509Certificate): int =
|
||||||
## Get certificate's signing algorithm.
|
## Get certificate's signing algorithm.
|
||||||
|
@ -362,21 +396,15 @@ proc newTLSClientAsyncStream*(rsource: AsyncStreamReader,
|
||||||
## ``minVersion`` of bigger then ``maxVersion`` you will get an error.
|
## ``minVersion`` of bigger then ``maxVersion`` you will get an error.
|
||||||
##
|
##
|
||||||
## ``flags`` - custom TLS connection flags.
|
## ``flags`` - custom TLS connection flags.
|
||||||
let switchToWriter = newAsyncEvent()
|
|
||||||
let switchToReader = newAsyncEvent()
|
|
||||||
var res = TLSAsyncStream()
|
var res = TLSAsyncStream()
|
||||||
var reader = TLSStreamReader(
|
var reader = TLSStreamReader(
|
||||||
kind: TLSStreamKind.Client,
|
kind: TLSStreamKind.Client,
|
||||||
stream: res,
|
stream: res,
|
||||||
switchToReader: switchToReader,
|
|
||||||
switchToWriter: switchToWriter,
|
|
||||||
ccontext: addr res.ccontext
|
ccontext: addr res.ccontext
|
||||||
)
|
)
|
||||||
var writer = TLSStreamWriter(
|
var writer = TLSStreamWriter(
|
||||||
kind: TLSStreamKind.Client,
|
kind: TLSStreamKind.Client,
|
||||||
stream: res,
|
stream: res,
|
||||||
switchToReader: switchToReader,
|
|
||||||
switchToWriter: switchToWriter,
|
|
||||||
ccontext: addr res.ccontext
|
ccontext: addr res.ccontext
|
||||||
)
|
)
|
||||||
res.reader = reader
|
res.reader = reader
|
||||||
|
@ -444,22 +472,15 @@ proc newTLSServerAsyncStream*(rsource: AsyncStreamReader,
|
||||||
if isNil(certificate) or len(certificate.certs) == 0:
|
if isNil(certificate) or len(certificate.certs) == 0:
|
||||||
raiseTLSStreamProtoError("Incorrect certificate")
|
raiseTLSStreamProtoError("Incorrect certificate")
|
||||||
|
|
||||||
let switchToWriter = newAsyncEvent()
|
|
||||||
let switchToReader = newAsyncEvent()
|
|
||||||
|
|
||||||
var res = TLSAsyncStream()
|
var res = TLSAsyncStream()
|
||||||
var reader = TLSStreamReader(
|
var reader = TLSStreamReader(
|
||||||
kind: TLSStreamKind.Server,
|
kind: TLSStreamKind.Server,
|
||||||
stream: res,
|
stream: res,
|
||||||
switchToReader: switchToReader,
|
|
||||||
switchToWriter: switchToWriter,
|
|
||||||
scontext: addr res.scontext
|
scontext: addr res.scontext
|
||||||
)
|
)
|
||||||
var writer = TLSStreamWriter(
|
var writer = TLSStreamWriter(
|
||||||
kind: TLSStreamKind.Server,
|
kind: TLSStreamKind.Server,
|
||||||
stream: res,
|
stream: res,
|
||||||
switchToReader: switchToReader,
|
|
||||||
switchToWriter: switchToWriter,
|
|
||||||
scontext: addr res.scontext
|
scontext: addr res.scontext
|
||||||
)
|
)
|
||||||
res.reader = reader
|
res.reader = reader
|
||||||
|
|
|
@ -1639,7 +1639,6 @@ proc close*(server: StreamServer) =
|
||||||
## Please note that release of resources is not completed immediately, to be
|
## Please note that release of resources is not completed immediately, to be
|
||||||
## sure all resources got released please use ``await server.join()``.
|
## sure all resources got released please use ``await server.join()``.
|
||||||
proc continuation(udata: pointer) {.gcsafe.} =
|
proc continuation(udata: pointer) {.gcsafe.} =
|
||||||
echo "server close() continuation"
|
|
||||||
server.clean()
|
server.clean()
|
||||||
|
|
||||||
let r1 = (server.status == ServerStatus.Stopped) and
|
let r1 = (server.status == ServerStatus.Stopped) and
|
||||||
|
|
|
@ -594,7 +594,6 @@ suite "TLSStream test suite":
|
||||||
var reader = newAsyncStreamReader(transp)
|
var reader = newAsyncStreamReader(transp)
|
||||||
var writer = newAsyncStreamWriter(transp)
|
var writer = newAsyncStreamWriter(transp)
|
||||||
var tlsstream = newTLSClientAsyncStream(reader, writer, name)
|
var tlsstream = newTLSClientAsyncStream(reader, writer, name)
|
||||||
|
|
||||||
await tlsstream.writer.write("GET / HTTP/1.1\r\nHost: " & name &
|
await tlsstream.writer.write("GET / HTTP/1.1\r\nHost: " & name &
|
||||||
"\r\nConnection: close\r\n\r\n")
|
"\r\nConnection: close\r\n\r\n")
|
||||||
var readFut = tlsstream.reader.readUntil(addr buffer[0], len(buffer),
|
var readFut = tlsstream.reader.readUntil(addr buffer[0], len(buffer),
|
||||||
|
@ -625,65 +624,39 @@ suite "TLSStream test suite":
|
||||||
|
|
||||||
proc serveClient(server: StreamServer,
|
proc serveClient(server: StreamServer,
|
||||||
transp: StreamTransport) {.async.} =
|
transp: StreamTransport) {.async.} =
|
||||||
echo "- server accepted client"
|
|
||||||
var reader = newAsyncStreamReader(transp)
|
var reader = newAsyncStreamReader(transp)
|
||||||
var writer = newAsyncStreamWriter(transp)
|
var writer = newAsyncStreamWriter(transp)
|
||||||
var sstream = newTLSServerAsyncStream(reader, writer, key, cert)
|
var sstream = newTLSServerAsyncStream(reader, writer, key, cert)
|
||||||
echo "- server stream is [", cast[uint](sstream), "]"
|
|
||||||
echo "- server handshaking"
|
|
||||||
await handshake(sstream)
|
await handshake(sstream)
|
||||||
echo "- server handshaked"
|
|
||||||
await sstream.writer.write(testMessage & "\r\n")
|
await sstream.writer.write(testMessage & "\r\n")
|
||||||
echo "- server wrote string"
|
|
||||||
await sstream.writer.finish()
|
await sstream.writer.finish()
|
||||||
echo "- server finished"
|
|
||||||
await sleepAsync(5.seconds)
|
|
||||||
echo "- server sleeped"
|
|
||||||
await sstream.writer.closeWait()
|
await sstream.writer.closeWait()
|
||||||
echo "- server closed secure writer"
|
|
||||||
await sstream.reader.closeWait()
|
await sstream.reader.closeWait()
|
||||||
echo "- server closed secure reader"
|
|
||||||
await reader.closeWait()
|
await reader.closeWait()
|
||||||
echo "- server closed reader"
|
|
||||||
await writer.closeWait()
|
await writer.closeWait()
|
||||||
echo "- server closed writer"
|
|
||||||
await transp.closeWait()
|
await transp.closeWait()
|
||||||
echo "- server closed transport"
|
|
||||||
server.stop()
|
server.stop()
|
||||||
echo "- server stopped server"
|
|
||||||
server.close()
|
server.close()
|
||||||
echo "- server closed server"
|
|
||||||
|
|
||||||
key = TLSPrivateKey.init(pemkey)
|
key = TLSPrivateKey.init(pemkey)
|
||||||
cert = TLSCertificate.init(pemcert)
|
cert = TLSCertificate.init(pemcert)
|
||||||
|
|
||||||
var server = createStreamServer(address, serveClient, {ReuseAddr})
|
var server = createStreamServer(address, serveClient, {ReuseAddr})
|
||||||
server.start()
|
server.start()
|
||||||
echo "server started"
|
|
||||||
var conn = await connect(address)
|
var conn = await connect(address)
|
||||||
echo "= client connected"
|
|
||||||
var creader = newAsyncStreamReader(conn)
|
var creader = newAsyncStreamReader(conn)
|
||||||
var cwriter = newAsyncStreamWriter(conn)
|
var cwriter = newAsyncStreamWriter(conn)
|
||||||
# We are using self-signed certificate
|
# We are using self-signed certificate
|
||||||
let flags = {NoVerifyHost, NoVerifyServerName}
|
let flags = {NoVerifyHost, NoVerifyServerName}
|
||||||
var cstream = newTLSClientAsyncStream(creader, cwriter, "", flags = flags)
|
var cstream = newTLSClientAsyncStream(creader, cwriter, "", flags = flags)
|
||||||
echo "= client stream is [", cast[uint](cstream), "]"
|
|
||||||
echo "= client reading line"
|
|
||||||
let res = await cstream.reader.read()
|
let res = await cstream.reader.read()
|
||||||
echo "= client readed line"
|
|
||||||
await cstream.reader.closeWait()
|
await cstream.reader.closeWait()
|
||||||
echo "= client closed reader"
|
|
||||||
await cstream.writer.closeWait()
|
await cstream.writer.closeWait()
|
||||||
echo "= client closed writer"
|
|
||||||
await creader.closeWait()
|
await creader.closeWait()
|
||||||
echo "= client closed creader"
|
|
||||||
await cwriter.closeWait()
|
await cwriter.closeWait()
|
||||||
echo "= client closed cwriter"
|
|
||||||
await conn.closeWait()
|
await conn.closeWait()
|
||||||
echo "= client closed connection"
|
|
||||||
await server.join()
|
await server.join()
|
||||||
echo "= client waited server"
|
return cast[string](res) == (testMessage & "\r\n")
|
||||||
result = true # res == testMessage
|
|
||||||
|
|
||||||
test "Simple server with RSA self-signed certificate":
|
test "Simple server with RSA self-signed certificate":
|
||||||
let res = waitFor(checkSSLServer(initTAddress("127.0.0.1:43808"),
|
let res = waitFor(checkSSLServer(initTAddress("127.0.0.1:43808"),
|
||||||
|
|
Loading…
Reference in New Issue