mirror of
https://github.com/status-im/nim-chronos.git
synced 2025-02-07 08:54:08 +00:00
Attempt to fix state machine issue.
This commit is contained in:
parent
417111093e
commit
3f8d529c8e
@ -29,7 +29,14 @@ type
|
|||||||
|
|
||||||
type
|
type
|
||||||
TlsStreamWriter* = ref object of AsyncStreamWriter
|
TlsStreamWriter* = ref object of AsyncStreamWriter
|
||||||
|
case kind: TlsStreamKind
|
||||||
|
of TlsStreamKind.Client:
|
||||||
|
ccontext: ptr SslClientContext
|
||||||
|
of TlsStreamKind.Server:
|
||||||
|
scontext: ptr SslServerContext
|
||||||
stream*: TlsAsyncStream
|
stream*: TlsAsyncStream
|
||||||
|
switchToReader*: AsyncEvent
|
||||||
|
switchToWriter*: AsyncEvent
|
||||||
|
|
||||||
TlsStreamReader* = ref object of AsyncStreamReader
|
TlsStreamReader* = ref object of AsyncStreamReader
|
||||||
case kind: TlsStreamKind
|
case kind: TlsStreamKind
|
||||||
@ -38,6 +45,8 @@ type
|
|||||||
of TlsStreamKind.Server:
|
of TlsStreamKind.Server:
|
||||||
scontext: ptr SslServerContext
|
scontext: ptr SslServerContext
|
||||||
stream*: TlsAsyncStream
|
stream*: TlsAsyncStream
|
||||||
|
switchToReader*: AsyncEvent
|
||||||
|
switchToWriter*: AsyncEvent
|
||||||
|
|
||||||
TlsAsyncStream* = ref object of RootRef
|
TlsAsyncStream* = ref object of RootRef
|
||||||
xwc*: X509NoAnchorContext
|
xwc*: X509NoAnchorContext
|
||||||
@ -87,26 +96,110 @@ template newTlsStreamProtocolError[T](message: T): ref Exception =
|
|||||||
|
|
||||||
proc tlsWriteLoop(stream: AsyncStreamWriter) {.async.} =
|
proc tlsWriteLoop(stream: AsyncStreamWriter) {.async.} =
|
||||||
var wstream = cast[TlsStreamWriter](stream)
|
var wstream = cast[TlsStreamWriter](stream)
|
||||||
|
var engine: ptr SslEngineContext
|
||||||
|
var error: ref Exception
|
||||||
|
|
||||||
|
if wstream.kind == TLSStreamKind.Server:
|
||||||
|
engine = addr wstream.scontext.eng
|
||||||
|
else:
|
||||||
|
engine = addr wstream.ccontext.eng
|
||||||
|
|
||||||
|
wstream.state = AsyncStreamState.Running
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# We waiting for empty future which will never be completed, because all
|
var length: uint
|
||||||
# the logic are inside of tlsReadLoop(). This infinite wait can be stopped
|
while true:
|
||||||
# by closing stream (e.g. cancellation).
|
var state = engine.sslEngineCurrentState()
|
||||||
var future = newFuture[void]()
|
|
||||||
await future
|
if (state and SSL_CLOSED) == SSL_CLOSED:
|
||||||
|
wstream.state = AsyncStreamState.Finished
|
||||||
|
break
|
||||||
|
|
||||||
|
if (state and (SSL_RECVREC or SSL_RECVAPP)) != 0:
|
||||||
|
wstream.switchToReader.fire()
|
||||||
|
|
||||||
|
if (state and (SSL_SENDREC or SSL_SENDAPP)) == 0:
|
||||||
|
await wstream.switchToWriter.wait()
|
||||||
|
wstream.switchToWriter.clear()
|
||||||
|
# We need to refresh `state` because we just returned from readerLoop.
|
||||||
|
continue
|
||||||
|
|
||||||
|
if (state and SSL_SENDREC) == SSL_SENDREC:
|
||||||
|
# TLS record needs to be sent over stream.
|
||||||
|
length = 0'u
|
||||||
|
var buf = sslEngineSendrecBuf(engine, length)
|
||||||
|
doAssert(length != 0 and not isNil(buf))
|
||||||
|
var fut = awaitne wstream.wsource.write(buf, int(length))
|
||||||
|
if fut.failed():
|
||||||
|
error = fut.error
|
||||||
|
break
|
||||||
|
sslEngineSendrecAck(engine, length)
|
||||||
|
continue
|
||||||
|
|
||||||
|
if (state and SSL_SENDAPP) == SSL_SENDAPP:
|
||||||
|
# Application data can be sent over stream.
|
||||||
|
var item = await wstream.queue.get()
|
||||||
|
if item.size > 0:
|
||||||
|
length = 0'u
|
||||||
|
var buf = sslEngineSendappBuf(engine, length)
|
||||||
|
let toWrite = min(int(length), item.size)
|
||||||
|
|
||||||
|
if int(length) >= item.size:
|
||||||
|
if item.kind == Pointer:
|
||||||
|
let p = cast[pointer](cast[uint](item.data1) + uint(item.offset))
|
||||||
|
copyMem(buf, p, item.size)
|
||||||
|
elif item.kind == Sequence:
|
||||||
|
copyMem(buf, addr item.data2[item.offset], item.size)
|
||||||
|
elif item.kind == String:
|
||||||
|
copyMem(buf, addr item.data3[item.offset], item.size)
|
||||||
|
sslEngineSendappAck(engine, uint(item.size))
|
||||||
|
sslEngineFlush(engine, 0)
|
||||||
|
item.future.complete()
|
||||||
|
else:
|
||||||
|
if item.kind == Pointer:
|
||||||
|
let p = cast[pointer](cast[uint](item.data1) + uint(item.offset))
|
||||||
|
copyMem(buf, p, length)
|
||||||
|
elif item.kind == Sequence:
|
||||||
|
copyMem(buf, addr item.data2[item.offset], length)
|
||||||
|
elif item.kind == String:
|
||||||
|
copyMem(buf, addr item.data3[item.offset], length)
|
||||||
|
item.offset = item.offset + int(length)
|
||||||
|
item.size = item.size - int(length)
|
||||||
|
wstream.queue.addFirstNoWait(item)
|
||||||
|
sslEngineSendappAck(engine, length)
|
||||||
|
continue
|
||||||
|
else:
|
||||||
|
# Zero length item means finish
|
||||||
|
wstream.state = AsyncStreamState.Finished
|
||||||
|
break
|
||||||
|
|
||||||
except CancelledError:
|
except CancelledError:
|
||||||
discard
|
wstream.state = AsyncStreamState.Stopped
|
||||||
|
|
||||||
finally:
|
finally:
|
||||||
|
if wstream.state == AsyncStreamState.Stopped:
|
||||||
|
while len(wstream.queue) > 0:
|
||||||
|
let item = wstream.queue.popFirstNoWait()
|
||||||
|
if not(item.future.finished()):
|
||||||
|
item.future.complete()
|
||||||
|
elif wstream.state == AsyncStreamState.Error:
|
||||||
|
while len(wstream.queue) > 0:
|
||||||
|
let item = wstream.queue.popFirstNoWait()
|
||||||
|
if not(item.future.finished()):
|
||||||
|
item.future.fail(error)
|
||||||
wstream.stream = nil
|
wstream.stream = nil
|
||||||
|
|
||||||
proc tlsReadLoop(stream: AsyncStreamReader) {.async.} =
|
proc tlsReadLoop(stream: AsyncStreamReader) {.async.} =
|
||||||
var rstream = cast[TlsStreamReader](stream)
|
var rstream = cast[TlsStreamReader](stream)
|
||||||
var wstream = rstream.stream.writer
|
|
||||||
var engine: ptr SslEngineContext
|
var engine: ptr SslEngineContext
|
||||||
|
|
||||||
if rstream.kind == TlsStreamKind.Server:
|
if rstream.kind == TlsStreamKind.Server:
|
||||||
engine = addr rstream.scontext.eng
|
engine = addr rstream.scontext.eng
|
||||||
else:
|
else:
|
||||||
engine = addr rstream.ccontext.eng
|
engine = addr rstream.ccontext.eng
|
||||||
|
|
||||||
|
rstream.state = AsyncStreamState.Running
|
||||||
|
|
||||||
try:
|
try:
|
||||||
var length: uint
|
var length: uint
|
||||||
while true:
|
while true:
|
||||||
@ -121,58 +214,25 @@ proc tlsReadLoop(stream: AsyncStreamReader) {.async.} =
|
|||||||
rstream.state = AsyncStreamState.Stopped
|
rstream.state = AsyncStreamState.Stopped
|
||||||
break
|
break
|
||||||
|
|
||||||
if (state and SSL_SENDREC) == SSL_SENDREC:
|
if (state and (SSL_SENDREC or SSL_SENDAPP)) != 0:
|
||||||
# TLS record needs to be sent over stream.
|
rstream.switchToWriter.fire()
|
||||||
length = 0'u
|
|
||||||
var buf = sslEngineSendrecBuf(engine, length)
|
if (state and (SSL_RECVREC or SSL_RECVAPP)) == 0:
|
||||||
doAssert(length != 0 and not isNil(buf))
|
await rstream.switchToReader.wait()
|
||||||
await wstream.wsource.write(buf, int(length))
|
rstream.switchToReader.clear()
|
||||||
sslEngineSendrecAck(engine, length)
|
# We need to refresh `state` because we just returned from writerLoop.
|
||||||
continue
|
continue
|
||||||
|
|
||||||
if (state and SSL_SENDAPP) == SSL_SENDAPP:
|
|
||||||
# Application data can be sent over stream.
|
|
||||||
if len(wstream.queue) > 0:
|
|
||||||
var item = await wstream.queue.get()
|
|
||||||
if item.size > 0:
|
|
||||||
length = 0'u
|
|
||||||
var buf = sslEngineSendappBuf(engine, length)
|
|
||||||
let toWrite = min(int(length), item.size)
|
|
||||||
|
|
||||||
if int(length) >= item.size:
|
|
||||||
if item.kind == Pointer:
|
|
||||||
let p = cast[pointer](cast[uint](item.data1) + uint(item.offset))
|
|
||||||
copyMem(buf, p, item.size)
|
|
||||||
elif item.kind == Sequence:
|
|
||||||
copyMem(buf, addr item.data2[item.offset], item.size)
|
|
||||||
elif item.kind == String:
|
|
||||||
copyMem(buf, addr item.data3[item.offset], item.size)
|
|
||||||
sslEngineSendappAck(engine, uint(item.size))
|
|
||||||
sslEngineFlush(engine, 0)
|
|
||||||
item.future.complete()
|
|
||||||
else:
|
|
||||||
if item.kind == Pointer:
|
|
||||||
let p = cast[pointer](cast[uint](item.data1) + uint(item.offset))
|
|
||||||
copyMem(buf, p, length)
|
|
||||||
elif item.kind == Sequence:
|
|
||||||
copyMem(buf, addr item.data2[item.offset], length)
|
|
||||||
elif item.kind == String:
|
|
||||||
copyMem(buf, addr item.data3[item.offset], length)
|
|
||||||
item.offset = item.offset + int(length)
|
|
||||||
item.size = item.size - int(length)
|
|
||||||
wstream.queue.addFirstNoWait(item)
|
|
||||||
sslEngineSendappAck(engine, length)
|
|
||||||
continue
|
|
||||||
else:
|
|
||||||
# Zero length item means finish
|
|
||||||
rstream.state = AsyncStreamState.Finished
|
|
||||||
break
|
|
||||||
|
|
||||||
if (state and SSL_RECVREC) == SSL_RECVREC:
|
if (state and SSL_RECVREC) == SSL_RECVREC:
|
||||||
# TLS records required for further processing
|
# TLS records required for further processing
|
||||||
length = 0'u
|
length = 0'u
|
||||||
var buf = sslEngineRecvrecBuf(engine, length)
|
var buf = sslEngineRecvrecBuf(engine, length)
|
||||||
let res = await rstream.rsource.readOnce(buf, int(length))
|
var resFut = awaitne rstream.rsource.readOnce(buf, int(length))
|
||||||
|
if resFut.failed():
|
||||||
|
rstream.error = resFut.error
|
||||||
|
rstream.state = AsyncStreamState.Error
|
||||||
|
break
|
||||||
|
let res = resFut.read()
|
||||||
if res > 0:
|
if res > 0:
|
||||||
sslEngineRecvrecAck(engine, uint(res))
|
sslEngineRecvrecAck(engine, uint(res))
|
||||||
continue
|
continue
|
||||||
@ -199,21 +259,7 @@ proc tlsReadLoop(stream: AsyncStreamReader) {.async.} =
|
|||||||
sslEngineClose(engine)
|
sslEngineClose(engine)
|
||||||
# Becase tlsWriteLoop() is ephemeral, but we still need to keep stream state
|
# Becase tlsWriteLoop() is ephemeral, but we still need to keep stream state
|
||||||
# consistent.
|
# consistent.
|
||||||
wstream.state = rstream.state
|
rstream.buffer.forget()
|
||||||
if rstream.state == AsyncStreamState.Finished:
|
|
||||||
rstream.buffer.forget()
|
|
||||||
elif rstream.state == AsyncStreamState.Stopped:
|
|
||||||
rstream.buffer.forget()
|
|
||||||
while len(wstream.queue) > 0:
|
|
||||||
let item = wstream.queue.popFirstNoWait()
|
|
||||||
if not(item.future.finished()):
|
|
||||||
item.future.complete()
|
|
||||||
elif rstream.state == AsyncStreamState.Error:
|
|
||||||
rstream.buffer.forget()
|
|
||||||
while len(wstream.queue) > 0:
|
|
||||||
let item = wstream.queue.popFirstNoWait()
|
|
||||||
if not(item.future.finished()):
|
|
||||||
item.future.fail(rstream.error)
|
|
||||||
rstream.stream = nil
|
rstream.stream = nil
|
||||||
|
|
||||||
proc newTlsClientAsyncStream*(rsource: AsyncStreamReader,
|
proc newTlsClientAsyncStream*(rsource: AsyncStreamReader,
|
||||||
@ -243,11 +289,19 @@ proc newTlsClientAsyncStream*(rsource: AsyncStreamReader,
|
|||||||
var reader = new TlsStreamReader
|
var reader = new TlsStreamReader
|
||||||
reader.kind = TlsStreamKind.Client
|
reader.kind = TlsStreamKind.Client
|
||||||
var writer = new TlsStreamWriter
|
var writer = new TlsStreamWriter
|
||||||
|
writer.kind = TlsStreamKind.Client
|
||||||
|
var switchToWriter = newAsyncEvent()
|
||||||
|
var switchToReader = newAsyncEvent()
|
||||||
reader.stream = result
|
reader.stream = result
|
||||||
writer.stream = result
|
writer.stream = result
|
||||||
|
reader.switchToReader = switchToReader
|
||||||
|
reader.switchToWriter = switchToWriter
|
||||||
|
writer.switchToReader = switchToReader
|
||||||
|
writer.switchToWriter = switchToWriter
|
||||||
result.reader = reader
|
result.reader = reader
|
||||||
result.writer = writer
|
result.writer = writer
|
||||||
reader.ccontext = addr result.context
|
reader.ccontext = addr result.context
|
||||||
|
writer.ccontext = addr result.context
|
||||||
|
|
||||||
if TLSFlags.NoVerifyHost in flags:
|
if TLSFlags.NoVerifyHost in flags:
|
||||||
sslClientInitFull(addr result.context, addr result.x509, nil, 0)
|
sslClientInitFull(addr result.context, addr result.x509, nil, 0)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user