diff --git a/chronos/streams/tlsstream.nim b/chronos/streams/tlsstream.nim index bffdde14..cbd5223f 100644 --- a/chronos/streams/tlsstream.nim +++ b/chronos/streams/tlsstream.nim @@ -29,7 +29,14 @@ type type TlsStreamWriter* = ref object of AsyncStreamWriter + case kind: TlsStreamKind + of TlsStreamKind.Client: + ccontext: ptr SslClientContext + of TlsStreamKind.Server: + scontext: ptr SslServerContext stream*: TlsAsyncStream + switchToReader*: AsyncEvent + switchToWriter*: AsyncEvent TlsStreamReader* = ref object of AsyncStreamReader case kind: TlsStreamKind @@ -38,6 +45,8 @@ type of TlsStreamKind.Server: scontext: ptr SslServerContext stream*: TlsAsyncStream + switchToReader*: AsyncEvent + switchToWriter*: AsyncEvent TlsAsyncStream* = ref object of RootRef xwc*: X509NoAnchorContext @@ -87,26 +96,110 @@ template newTlsStreamProtocolError[T](message: T): ref Exception = proc tlsWriteLoop(stream: AsyncStreamWriter) {.async.} = var wstream = cast[TlsStreamWriter](stream) + var engine: ptr SslEngineContext + var error: ref Exception + + if wstream.kind == TLSStreamKind.Server: + engine = addr wstream.scontext.eng + else: + engine = addr wstream.ccontext.eng + + wstream.state = AsyncStreamState.Running + try: - # We waiting for empty future which will never be completed, because all - # the logic are inside of tlsReadLoop(). This infinite wait can be stopped - # by closing stream (e.g. cancellation). - var future = newFuture[void]() - await future + var length: uint + while true: + var state = engine.sslEngineCurrentState() + + if (state and SSL_CLOSED) == SSL_CLOSED: + wstream.state = AsyncStreamState.Finished + break + + if (state and (SSL_RECVREC or SSL_RECVAPP)) != 0: + wstream.switchToReader.fire() + + if (state and (SSL_SENDREC or SSL_SENDAPP)) == 0: + await wstream.switchToWriter.wait() + wstream.switchToWriter.clear() + # We need to refresh `state` because we just returned from readerLoop. + continue + + if (state and SSL_SENDREC) == SSL_SENDREC: + # TLS record needs to be sent over stream. + length = 0'u + var buf = sslEngineSendrecBuf(engine, length) + doAssert(length != 0 and not isNil(buf)) + var fut = awaitne wstream.wsource.write(buf, int(length)) + if fut.failed(): + error = fut.error + break + sslEngineSendrecAck(engine, length) + continue + + if (state and SSL_SENDAPP) == SSL_SENDAPP: + # Application data can be sent over stream. + var item = await wstream.queue.get() + if item.size > 0: + length = 0'u + var buf = sslEngineSendappBuf(engine, length) + let toWrite = min(int(length), item.size) + + if int(length) >= item.size: + if item.kind == Pointer: + let p = cast[pointer](cast[uint](item.data1) + uint(item.offset)) + copyMem(buf, p, item.size) + elif item.kind == Sequence: + copyMem(buf, addr item.data2[item.offset], item.size) + elif item.kind == String: + copyMem(buf, addr item.data3[item.offset], item.size) + sslEngineSendappAck(engine, uint(item.size)) + sslEngineFlush(engine, 0) + item.future.complete() + else: + if item.kind == Pointer: + let p = cast[pointer](cast[uint](item.data1) + uint(item.offset)) + copyMem(buf, p, length) + elif item.kind == Sequence: + copyMem(buf, addr item.data2[item.offset], length) + elif item.kind == String: + copyMem(buf, addr item.data3[item.offset], length) + item.offset = item.offset + int(length) + item.size = item.size - int(length) + wstream.queue.addFirstNoWait(item) + sslEngineSendappAck(engine, length) + continue + else: + # Zero length item means finish + wstream.state = AsyncStreamState.Finished + break + except CancelledError: - discard + wstream.state = AsyncStreamState.Stopped + finally: + if wstream.state == AsyncStreamState.Stopped: + while len(wstream.queue) > 0: + let item = wstream.queue.popFirstNoWait() + if not(item.future.finished()): + item.future.complete() + elif wstream.state == AsyncStreamState.Error: + while len(wstream.queue) > 0: + let item = wstream.queue.popFirstNoWait() + if not(item.future.finished()): + item.future.fail(error) wstream.stream = nil proc tlsReadLoop(stream: AsyncStreamReader) {.async.} = var rstream = cast[TlsStreamReader](stream) - var wstream = rstream.stream.writer var engine: ptr SslEngineContext + if rstream.kind == TlsStreamKind.Server: engine = addr rstream.scontext.eng else: engine = addr rstream.ccontext.eng + rstream.state = AsyncStreamState.Running + try: var length: uint while true: @@ -121,58 +214,25 @@ proc tlsReadLoop(stream: AsyncStreamReader) {.async.} = rstream.state = AsyncStreamState.Stopped break - if (state and SSL_SENDREC) == SSL_SENDREC: - # TLS record needs to be sent over stream. - length = 0'u - var buf = sslEngineSendrecBuf(engine, length) - doAssert(length != 0 and not isNil(buf)) - await wstream.wsource.write(buf, int(length)) - sslEngineSendrecAck(engine, length) + if (state and (SSL_SENDREC or SSL_SENDAPP)) != 0: + rstream.switchToWriter.fire() + + if (state and (SSL_RECVREC or SSL_RECVAPP)) == 0: + await rstream.switchToReader.wait() + rstream.switchToReader.clear() + # We need to refresh `state` because we just returned from writerLoop. continue - if (state and SSL_SENDAPP) == SSL_SENDAPP: - # Application data can be sent over stream. - if len(wstream.queue) > 0: - var item = await wstream.queue.get() - if item.size > 0: - length = 0'u - var buf = sslEngineSendappBuf(engine, length) - let toWrite = min(int(length), item.size) - - if int(length) >= item.size: - if item.kind == Pointer: - let p = cast[pointer](cast[uint](item.data1) + uint(item.offset)) - copyMem(buf, p, item.size) - elif item.kind == Sequence: - copyMem(buf, addr item.data2[item.offset], item.size) - elif item.kind == String: - copyMem(buf, addr item.data3[item.offset], item.size) - sslEngineSendappAck(engine, uint(item.size)) - sslEngineFlush(engine, 0) - item.future.complete() - else: - if item.kind == Pointer: - let p = cast[pointer](cast[uint](item.data1) + uint(item.offset)) - copyMem(buf, p, length) - elif item.kind == Sequence: - copyMem(buf, addr item.data2[item.offset], length) - elif item.kind == String: - copyMem(buf, addr item.data3[item.offset], length) - item.offset = item.offset + int(length) - item.size = item.size - int(length) - wstream.queue.addFirstNoWait(item) - sslEngineSendappAck(engine, length) - continue - else: - # Zero length item means finish - rstream.state = AsyncStreamState.Finished - break - if (state and SSL_RECVREC) == SSL_RECVREC: # TLS records required for further processing length = 0'u var buf = sslEngineRecvrecBuf(engine, length) - let res = await rstream.rsource.readOnce(buf, int(length)) + var resFut = awaitne rstream.rsource.readOnce(buf, int(length)) + if resFut.failed(): + rstream.error = resFut.error + rstream.state = AsyncStreamState.Error + break + let res = resFut.read() if res > 0: sslEngineRecvrecAck(engine, uint(res)) continue @@ -199,21 +259,7 @@ proc tlsReadLoop(stream: AsyncStreamReader) {.async.} = sslEngineClose(engine) # Becase tlsWriteLoop() is ephemeral, but we still need to keep stream state # consistent. - wstream.state = rstream.state - if rstream.state == AsyncStreamState.Finished: - rstream.buffer.forget() - elif rstream.state == AsyncStreamState.Stopped: - rstream.buffer.forget() - while len(wstream.queue) > 0: - let item = wstream.queue.popFirstNoWait() - if not(item.future.finished()): - item.future.complete() - elif rstream.state == AsyncStreamState.Error: - rstream.buffer.forget() - while len(wstream.queue) > 0: - let item = wstream.queue.popFirstNoWait() - if not(item.future.finished()): - item.future.fail(rstream.error) + rstream.buffer.forget() rstream.stream = nil proc newTlsClientAsyncStream*(rsource: AsyncStreamReader, @@ -243,11 +289,19 @@ proc newTlsClientAsyncStream*(rsource: AsyncStreamReader, var reader = new TlsStreamReader reader.kind = TlsStreamKind.Client var writer = new TlsStreamWriter + writer.kind = TlsStreamKind.Client + var switchToWriter = newAsyncEvent() + var switchToReader = newAsyncEvent() reader.stream = result writer.stream = result + reader.switchToReader = switchToReader + reader.switchToWriter = switchToWriter + writer.switchToReader = switchToReader + writer.switchToWriter = switchToWriter result.reader = reader result.writer = writer reader.ccontext = addr result.context + writer.ccontext = addr result.context if TLSFlags.NoVerifyHost in flags: sslClientInitFull(addr result.context, addr result.x509, nil, 0)