Attempt to fix state machine issue.

This commit is contained in:
cheatfate 2019-10-09 09:12:54 +03:00
parent 417111093e
commit 3f8d529c8e
No known key found for this signature in database
GPG Key ID: 46ADD633A7201F95

View File

@ -29,7 +29,14 @@ type
type
TlsStreamWriter* = ref object of AsyncStreamWriter
case kind: TlsStreamKind
of TlsStreamKind.Client:
ccontext: ptr SslClientContext
of TlsStreamKind.Server:
scontext: ptr SslServerContext
stream*: TlsAsyncStream
switchToReader*: AsyncEvent
switchToWriter*: AsyncEvent
TlsStreamReader* = ref object of AsyncStreamReader
case kind: TlsStreamKind
@ -38,6 +45,8 @@ type
of TlsStreamKind.Server:
scontext: ptr SslServerContext
stream*: TlsAsyncStream
switchToReader*: AsyncEvent
switchToWriter*: AsyncEvent
TlsAsyncStream* = ref object of RootRef
xwc*: X509NoAnchorContext
@ -87,52 +96,48 @@ template newTlsStreamProtocolError[T](message: T): ref Exception =
proc tlsWriteLoop(stream: AsyncStreamWriter) {.async.} =
var wstream = cast[TlsStreamWriter](stream)
try:
# We waiting for empty future which will never be completed, because all
# the logic are inside of tlsReadLoop(). This infinite wait can be stopped
# by closing stream (e.g. cancellation).
var future = newFuture[void]()
await future
except CancelledError:
discard
finally:
wstream.stream = nil
proc tlsReadLoop(stream: AsyncStreamReader) {.async.} =
var rstream = cast[TlsStreamReader](stream)
var wstream = rstream.stream.writer
var engine: ptr SslEngineContext
if rstream.kind == TlsStreamKind.Server:
engine = addr rstream.scontext.eng
var error: ref Exception
if wstream.kind == TLSStreamKind.Server:
engine = addr wstream.scontext.eng
else:
engine = addr rstream.ccontext.eng
engine = addr wstream.ccontext.eng
wstream.state = AsyncStreamState.Running
try:
var length: uint
while true:
var state = engine.sslEngineCurrentState()
if (state and SSL_CLOSED) == SSL_CLOSED:
let err = engine.sslEngineLastError()
if err != 0:
rstream.error = newTlsStreamProtocolError(err)
rstream.state = AsyncStreamState.Error
break
else:
rstream.state = AsyncStreamState.Stopped
wstream.state = AsyncStreamState.Finished
break
if (state and (SSL_RECVREC or SSL_RECVAPP)) != 0:
wstream.switchToReader.fire()
if (state and (SSL_SENDREC or SSL_SENDAPP)) == 0:
await wstream.switchToWriter.wait()
wstream.switchToWriter.clear()
# We need to refresh `state` because we just returned from readerLoop.
continue
if (state and SSL_SENDREC) == SSL_SENDREC:
# TLS record needs to be sent over stream.
length = 0'u
var buf = sslEngineSendrecBuf(engine, length)
doAssert(length != 0 and not isNil(buf))
await wstream.wsource.write(buf, int(length))
var fut = awaitne wstream.wsource.write(buf, int(length))
if fut.failed():
error = fut.error
break
sslEngineSendrecAck(engine, length)
continue
if (state and SSL_SENDAPP) == SSL_SENDAPP:
# Application data can be sent over stream.
if len(wstream.queue) > 0:
var item = await wstream.queue.get()
if item.size > 0:
length = 0'u
@ -165,14 +170,69 @@ proc tlsReadLoop(stream: AsyncStreamReader) {.async.} =
continue
else:
# Zero length item means finish
rstream.state = AsyncStreamState.Finished
wstream.state = AsyncStreamState.Finished
break
except CancelledError:
wstream.state = AsyncStreamState.Stopped
finally:
if wstream.state == AsyncStreamState.Stopped:
while len(wstream.queue) > 0:
let item = wstream.queue.popFirstNoWait()
if not(item.future.finished()):
item.future.complete()
elif wstream.state == AsyncStreamState.Error:
while len(wstream.queue) > 0:
let item = wstream.queue.popFirstNoWait()
if not(item.future.finished()):
item.future.fail(error)
wstream.stream = nil
proc tlsReadLoop(stream: AsyncStreamReader) {.async.} =
var rstream = cast[TlsStreamReader](stream)
var engine: ptr SslEngineContext
if rstream.kind == TlsStreamKind.Server:
engine = addr rstream.scontext.eng
else:
engine = addr rstream.ccontext.eng
rstream.state = AsyncStreamState.Running
try:
var length: uint
while true:
var state = engine.sslEngineCurrentState()
if (state and SSL_CLOSED) == SSL_CLOSED:
let err = engine.sslEngineLastError()
if err != 0:
rstream.error = newTlsStreamProtocolError(err)
rstream.state = AsyncStreamState.Error
break
else:
rstream.state = AsyncStreamState.Stopped
break
if (state and (SSL_SENDREC or SSL_SENDAPP)) != 0:
rstream.switchToWriter.fire()
if (state and (SSL_RECVREC or SSL_RECVAPP)) == 0:
await rstream.switchToReader.wait()
rstream.switchToReader.clear()
# We need to refresh `state` because we just returned from writerLoop.
continue
if (state and SSL_RECVREC) == SSL_RECVREC:
# TLS records required for further processing
length = 0'u
var buf = sslEngineRecvrecBuf(engine, length)
let res = await rstream.rsource.readOnce(buf, int(length))
var resFut = awaitne rstream.rsource.readOnce(buf, int(length))
if resFut.failed():
rstream.error = resFut.error
rstream.state = AsyncStreamState.Error
break
let res = resFut.read()
if res > 0:
sslEngineRecvrecAck(engine, uint(res))
continue
@ -199,21 +259,7 @@ proc tlsReadLoop(stream: AsyncStreamReader) {.async.} =
sslEngineClose(engine)
# Becase tlsWriteLoop() is ephemeral, but we still need to keep stream state
# consistent.
wstream.state = rstream.state
if rstream.state == AsyncStreamState.Finished:
rstream.buffer.forget()
elif rstream.state == AsyncStreamState.Stopped:
rstream.buffer.forget()
while len(wstream.queue) > 0:
let item = wstream.queue.popFirstNoWait()
if not(item.future.finished()):
item.future.complete()
elif rstream.state == AsyncStreamState.Error:
rstream.buffer.forget()
while len(wstream.queue) > 0:
let item = wstream.queue.popFirstNoWait()
if not(item.future.finished()):
item.future.fail(rstream.error)
rstream.stream = nil
proc newTlsClientAsyncStream*(rsource: AsyncStreamReader,
@ -243,11 +289,19 @@ proc newTlsClientAsyncStream*(rsource: AsyncStreamReader,
var reader = new TlsStreamReader
reader.kind = TlsStreamKind.Client
var writer = new TlsStreamWriter
writer.kind = TlsStreamKind.Client
var switchToWriter = newAsyncEvent()
var switchToReader = newAsyncEvent()
reader.stream = result
writer.stream = result
reader.switchToReader = switchToReader
reader.switchToWriter = switchToWriter
writer.switchToReader = switchToReader
writer.switchToWriter = switchToWriter
result.reader = reader
result.writer = writer
reader.ccontext = addr result.context
writer.ccontext = addr result.context
if TLSFlags.NoVerifyHost in flags:
sslClientInitFull(addr result.context, addr result.x509, nil, 0)