From c27624cfc047d1487819ed25207e93f55eb2589d Mon Sep 17 00:00:00 2001 From: cheatfate Date: Tue, 8 Oct 2019 18:46:27 +0300 Subject: [PATCH 01/12] Add TlsStream with client-only connections. --- chronos.nimble | 5 +- chronos/streams/tlsstream.nim | 276 ++++++++++++++++++++++++++++++++++ tests/testasyncstream.nim | 50 +++++- 3 files changed, 328 insertions(+), 3 deletions(-) create mode 100644 chronos/streams/tlsstream.nim diff --git a/chronos.nimble b/chronos.nimble index b397996..82cb803 100644 --- a/chronos.nimble +++ b/chronos.nimble @@ -1,5 +1,5 @@ packageName = "chronos" -version = "2.3.0" +version = "2.3.1" author = "Status Research & Development GmbH" description = "Chronos" license = "Apache License 2.0 or MIT" @@ -7,7 +7,8 @@ skipDirs = @["tests"] ### Dependencies -requires "nim > 0.18.0" +requires "nim > 0.19.4", + "bearssl" task test, "Run all tests": var commands = [ diff --git a/chronos/streams/tlsstream.nim b/chronos/streams/tlsstream.nim new file mode 100644 index 0000000..345b3aa --- /dev/null +++ b/chronos/streams/tlsstream.nim @@ -0,0 +1,276 @@ +# +# Chronos Asynchronous TLS Stream +# (c) Copyright 2019-Present +# Status Research & Development GmbH +# +# Licensed under either of +# Apache License, version 2.0, (LICENSE-APACHEv2) +# MIT license (LICENSE-MIT) + +## This module implements TLS stream reading and writing. +import bearssl, bearssl/cacert +import ../asyncloop, ../timer, ../asyncsync +import asyncstream, ../transports/stream, ../transports/common +import strutils +import hexdump + +type + TLSStreamKind {.pure.} = enum + Client, Server + + TLSVersion* {.pure.} = enum + TLS10 = 0x0301, TLS11 = 0x0302, TLS12 = 0x0303 + + TLSFlags* {.pure.} = enum + NoVerifyHost, # Client: Skip remote certificate check + NoVerifySN, # Client: Skip Server Name Indication (SNI) check + NoRenegotiation, # Server: Reject renegotiations requests + NoClientAuth, # Server: Disable strict client authentication + FailOnAlpnMismatch # Server: Fail on application protocol mismatch + +type + TlsStreamWriter* = ref object of AsyncStreamWriter + + TlsStreamReader* = ref object of AsyncStreamReader + case kind: TlsStreamKind + of TlsStreamKind.Client: + ccontext: ptr SslClientContext + of TlsStreamKind.Server: + scontext: ptr SslServerContext + writer*: TlsStreamWriter + + TlsAsyncStream* = ref object of RootRef + xwc*: X509NoAnchorContext + context*: SslClientContext + sbuffer*: seq[byte] + x509*: X509MinimalContext + reader*: TlsStreamReader + writer*: TlsStreamWriter + + TlsStreamError* = object of CatchableError + TlsStreamProtocolError* = object of TlsStreamError + errCode*: int + +template newTlsStreamProtocolError[T](message: T): ref Exception = + var msg = "" + var code = 0 + when T is string: + msg.add(message) + elif T is cint: + msg.add(sslErrorMsg(message) & " (code: " & $int(message) & ")") + code = int(message) + elif T is int: + msg.add(sslErrorMsg(message) & " (code: " & $message & ")") + code = message + else: + msg.add("Internal Error") + var err = newException(TlsStreamProtocolError, msg) + err.errCode = code + err + +# proc raiseTlsStreamProtoError*[T](message: T) = +# raise newTlsStreamProtocolError(message) + +# proc getStringState*(state: cuint): string = +# var n = newSeq[string]() +# if (state and SSL_CLOSED) == SSL_CLOSED: +# n.add("Closed") +# if (state and SSL_SENDREC) == SSL_SENDREC: +# n.add("SendRec") +# if (state and SSL_RECVREC) == SSL_RECVREC: +# n.add("RecvRec") +# if (state and SSL_SENDAPP) == SSL_SENDAPP: +# n.add("SendApp") +# if (state and SSL_RECVAPP) == SSL_RECVAPP: +# n.add("RecvApp") +# result = "{" & n.join(", ") & "} number (" & $state & ")" + +proc tlsWriteLoop(stream: AsyncStreamWriter) {.async.} = + var wstream = cast[TlsStreamWriter](stream) + try: + # We waiting for empty future which will never be completed, because all + # the logic are inside of tlsReadLoop(). This infinite wait can be stopped + # by closing stream (e.g. cancellation). + var future = newFuture[void]() + await future + except CancelledError: + discard + +proc tlsReadLoop(stream: AsyncStreamReader) {.async.} = + var rstream = cast[TlsStreamReader](stream) + var wstream = rstream.writer + var engine: ptr SslEngineContext + if rstream.kind == TlsStreamKind.Server: + engine = addr rstream.scontext.eng + else: + engine = addr rstream.ccontext.eng + + try: + var length: uint + while true: + var state = engine.sslEngineCurrentState() + if (state and SSL_CLOSED) == SSL_CLOSED: + let err = engine.sslEngineLastError() + if err != 0: + rstream.error = newTlsStreamProtocolError(err) + rstream.state = AsyncStreamState.Error + break + else: + rstream.state = AsyncStreamState.Stopped + break + + if (state and SSL_SENDREC) == SSL_SENDREC: + # TLS record needs to be sent over stream. + length = 0'u + var buf = sslEngineSendrecBuf(engine, length) + doAssert(length != 0 and not isNil(buf)) + await wstream.wsource.write(buf, int(length)) + sslEngineSendrecAck(engine, length) + continue + + if (state and SSL_SENDAPP) == SSL_SENDAPP: + # Application data can be sent over stream. + if len(wstream.queue) > 0: + var item = await wstream.queue.get() + if item.size > 0: + length = 0'u + var buf = sslEngineSendappBuf(engine, length) + let toWrite = min(int(length), item.size) + + if int(length) >= item.size: + if item.kind == Pointer: + let p = cast[pointer](cast[uint](item.data1) + uint(item.offset)) + copyMem(buf, p, item.size) + elif item.kind == Sequence: + copyMem(buf, addr item.data2[item.offset], item.size) + elif item.kind == String: + copyMem(buf, addr item.data3[item.offset], item.size) + sslEngineSendappAck(engine, uint(item.size)) + sslEngineFlush(engine, 0) + item.future.complete() + else: + if item.kind == Pointer: + let p = cast[pointer](cast[uint](item.data1) + uint(item.offset)) + copyMem(buf, p, length) + elif item.kind == Sequence: + copyMem(buf, addr item.data2[item.offset], length) + elif item.kind == String: + copyMem(buf, addr item.data3[item.offset], length) + item.offset = item.offset + int(length) + item.size = item.size - int(length) + wstream.queue.addFirstNoWait(item) + sslEngineSendappAck(engine, length) + continue + else: + # Zero length item means finish + rstream.state = AsyncStreamState.Finished + break + + if (state and SSL_RECVREC) == SSL_RECVREC: + # TLS records required for further processing + length = 0'u + var buf = sslEngineRecvrecBuf(engine, length) + let res = await rstream.rsource.readOnce(buf, int(length)) + if res > 0: + sslEngineRecvrecAck(engine, uint(res)) + continue + else: + rstream.state = AsyncStreamState.Finished + break + + if (state and SSL_RECVAPP) == SSL_RECVAPP: + # Application data can be recovered. + length = 0'u + var buf = sslEngineRecvappBuf(engine, length) + let toRead = min(int(length), rstream.buffer.bufferLen()) + copyMem(rstream.buffer.getBuffer(), buf, toRead) + rstream.buffer.update(toRead) + sslEngineRecvappAck(engine, uint(toRead)) + await rstream.buffer.transfer() + continue + + except CancelledError: + rstream.state = AsyncStreamState.Stopped + + finally: + # Perform TLS cleanup procedure + sslEngineClose(engine) + # Becase tlsWriteLoop() is ephemeral, but we still need to keep stream state + # consistent. + wstream.state = rstream.state + if rstream.state == AsyncStreamState.Finished: + rstream.buffer.forget() + elif rstream.state == AsyncStreamState.Stopped: + rstream.buffer.forget() + while len(wstream.queue) > 0: + let item = wstream.queue.popFirstNoWait() + if not(item.future.finished()): + item.future.complete() + elif rstream.state == AsyncStreamState.Error: + rstream.buffer.forget() + while len(wstream.queue) > 0: + let item = wstream.queue.popFirstNoWait() + if not(item.future.finished()): + item.future.fail(rstream.error) + +proc newTlsClientAsyncStream*(rsource: AsyncStreamReader, + wsource: AsyncStreamWriter, + serverName: string = "", + bufferSize = SSL_BUFSIZE_BIDI, + minVersion = TLSVersion.TLS11, + maxVersion = TLSVersion.TLS12, + flags: set[TlsFlags] = {}): TlsAsyncStream = + ## Create new TLS asynchronous stream using reading stream ``rsource``, + ## writing stream ``wsource``. + ## + ## You can specify remote server name using ``serverName``, if while + ## handshake server reports different name you will get an error. If + ## ``serverName`` is empty string, remote server name checking will be + ## disabled. + ## + ## ``bufferSize`` - is SSL/TLS buffer which is used for encoding/decoding + ## incoming data. + ## + ## ``minVersion`` and ``maxVersion`` are TLS versions which will be used + ## for handshake with remote server. If server's version will be lower then + ## ``minVersion`` of bigger then ``maxVersion`` you will get an error. + ## + ## ``flags`` - custom TLS connection flags. + result = new TlsAsyncStream + var reader = new TlsStreamReader + reader.kind = TlsStreamKind.Client + var writer = new TlsStreamWriter + result.reader = reader + result.writer = writer + result.reader.writer = writer + reader.ccontext = addr result.context + + if TLSFlags.NoVerifyHost in flags: + sslClientInitFull(addr result.context, addr result.x509, nil, 0) + initNoAnchor(addr result.xwc, addr result.x509.vtable) + sslEngineSetX509(addr result.context.eng, addr result.xwc.vtable) + else: + sslClientInitFull(addr result.context, addr result.x509, + unsafeAddr MozillaTrustAnchors[0], + len(MozillaTrustAnchors)) + + let size = max(SSL_BUFSIZE_BIDI, bufferSize) + result.sbuffer = newSeq[byte](size) + sslEngineSetBuffer(addr result.context.eng, addr result.sbuffer[0], + uint(len(result.sbuffer)), 1) + sslEngineSetVersions(addr result.context.eng, uint16(minVersion), + uint16(maxVersion)) + + if TLSFlags.NoVerifySN in flags: + let err = sslClientReset(addr result.context, "", 0) + if err == 0: + raise newException(TlsStreamError, "Could not initialize TLS layer") + else: + let err = sslClientReset(addr result.context, serverName, 0) + if err == 0: + raise newException(TlsStreamError, "Could not initialize TLS layer") + + init(cast[AsyncStreamWriter](result.writer), wsource, tlsWriteLoop, + bufferSize) + init(cast[AsyncStreamReader](result.reader), rsource, tlsReadLoop, + bufferSize) diff --git a/tests/testasyncstream.nim b/tests/testasyncstream.nim index c6dee9f..6343ceb 100644 --- a/tests/testasyncstream.nim +++ b/tests/testasyncstream.nim @@ -6,7 +6,7 @@ # Apache License, version 2.0, (LICENSE-APACHEv2) # MIT license (LICENSE-MIT) import strutils, unittest, os -import ../chronos +import ../chronos, ../chronos/streams/tlsstream suite "AsyncStream test suite": test "AsyncStream(StreamTransport) readExactly() test": @@ -506,3 +506,51 @@ suite "ChunkedStream test suite": break result = res check waitFor(testVectors2(initTAddress("127.0.0.1:46001"))) == true + + test "ChunkedStream leaks test": + check: + getTracker("async.stream.reader").isLeaked() == false + getTracker("async.stream.writer").isLeaked() == false + getTracker("stream.server").isLeaked() == false + getTracker("stream.transport").isLeaked() == false + +suite "TLSStream test suite": + const HttpHeadersMark = @[byte(0x0D), byte(0x0A), byte(0x0D), byte(0x0A)] + test "Simple HTTPS connection": + proc headerClient(address: TransportAddress, + name: string): Future[bool] {.async.} = + var mark = "HTTP/1.1 " + var buffer = newSeq[byte](8192) + var transp = await connect(address) + var reader = newAsyncStreamReader(transp) + var writer = newAsyncStreamWriter(transp) + var tlsstream = newTlsClientAsyncStream(reader, writer, name) + + await tlsstream.writer.write("GET / HTTP/1.1\r\nHost: " & name & + "\r\nConnection: close\r\n\r\n") + var readFut = tlsstream.reader.readUntil(addr buffer[0], len(buffer), + HttpHeadersMark) + let res = await withTimeout(readFut, 5.seconds) + if res: + var length = readFut.read() + buffer.setLen(length) + if len(buffer) > len(mark): + if equalMem(addr buffer[0], addr mark[0], len(mark)): + result = true + + await tlsstream.reader.closeWait() + await tlsstream.writer.closeWait() + await reader.closeWait() + await writer.closeWait() + await transp.closeWait() + + let res = waitFor(headerClient(resolveTAddress("www.google.com:443")[0], + "www.google.com")) + check res == true + + test "TlsStream leaks test": + check: + getTracker("async.stream.reader").isLeaked() == false + getTracker("async.stream.writer").isLeaked() == false + getTracker("stream.server").isLeaked() == false + getTracker("stream.transport").isLeaked() == false From cae1d0969011e3b53dbfa061bb57f4779b788ba4 Mon Sep 17 00:00:00 2001 From: cheatfate Date: Tue, 8 Oct 2019 19:02:42 +0300 Subject: [PATCH 02/12] Removed debugging imports. --- chronos/streams/tlsstream.nim | 2 -- 1 file changed, 2 deletions(-) diff --git a/chronos/streams/tlsstream.nim b/chronos/streams/tlsstream.nim index 345b3aa..369ecac 100644 --- a/chronos/streams/tlsstream.nim +++ b/chronos/streams/tlsstream.nim @@ -11,8 +11,6 @@ import bearssl, bearssl/cacert import ../asyncloop, ../timer, ../asyncsync import asyncstream, ../transports/stream, ../transports/common -import strutils -import hexdump type TLSStreamKind {.pure.} = enum From e19101d287c71cd23f31d6a5d6434a275bca2c56 Mon Sep 17 00:00:00 2001 From: cheatfate Date: Tue, 8 Oct 2019 20:30:43 +0300 Subject: [PATCH 03/12] Add GC reference to reader and writer. --- chronos/streams/tlsstream.nim | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/chronos/streams/tlsstream.nim b/chronos/streams/tlsstream.nim index 369ecac..dec9720 100644 --- a/chronos/streams/tlsstream.nim +++ b/chronos/streams/tlsstream.nim @@ -11,6 +11,7 @@ import bearssl, bearssl/cacert import ../asyncloop, ../timer, ../asyncsync import asyncstream, ../transports/stream, ../transports/common +import strutils type TLSStreamKind {.pure.} = enum @@ -28,6 +29,7 @@ type type TlsStreamWriter* = ref object of AsyncStreamWriter + stream*: TlsAsyncStream TlsStreamReader* = ref object of AsyncStreamReader case kind: TlsStreamKind @@ -35,7 +37,7 @@ type ccontext: ptr SslClientContext of TlsStreamKind.Server: scontext: ptr SslServerContext - writer*: TlsStreamWriter + stream*: TlsAsyncStream TlsAsyncStream* = ref object of RootRef xwc*: X509NoAnchorContext @@ -96,7 +98,7 @@ proc tlsWriteLoop(stream: AsyncStreamWriter) {.async.} = proc tlsReadLoop(stream: AsyncStreamReader) {.async.} = var rstream = cast[TlsStreamReader](stream) - var wstream = rstream.writer + var wstream = rstream.stream.writer var engine: ptr SslEngineContext if rstream.kind == TlsStreamKind.Server: engine = addr rstream.scontext.eng @@ -238,9 +240,10 @@ proc newTlsClientAsyncStream*(rsource: AsyncStreamReader, var reader = new TlsStreamReader reader.kind = TlsStreamKind.Client var writer = new TlsStreamWriter + reader.stream = result + writer.stream = result result.reader = reader result.writer = writer - result.reader.writer = writer reader.ccontext = addr result.context if TLSFlags.NoVerifyHost in flags: From 417111093e20baa407faf9ca2ae6d07d3425f982 Mon Sep 17 00:00:00 2001 From: cheatfate Date: Tue, 8 Oct 2019 20:38:39 +0300 Subject: [PATCH 04/12] Cleanup references on exit. --- chronos/streams/tlsstream.nim | 3 +++ 1 file changed, 3 insertions(+) diff --git a/chronos/streams/tlsstream.nim b/chronos/streams/tlsstream.nim index dec9720..bffdde1 100644 --- a/chronos/streams/tlsstream.nim +++ b/chronos/streams/tlsstream.nim @@ -95,6 +95,8 @@ proc tlsWriteLoop(stream: AsyncStreamWriter) {.async.} = await future except CancelledError: discard + finally: + wstream.stream = nil proc tlsReadLoop(stream: AsyncStreamReader) {.async.} = var rstream = cast[TlsStreamReader](stream) @@ -212,6 +214,7 @@ proc tlsReadLoop(stream: AsyncStreamReader) {.async.} = let item = wstream.queue.popFirstNoWait() if not(item.future.finished()): item.future.fail(rstream.error) + rstream.stream = nil proc newTlsClientAsyncStream*(rsource: AsyncStreamReader, wsource: AsyncStreamWriter, From 3f8d529c8ef5f07e898ec144c48e2c942f12513a Mon Sep 17 00:00:00 2001 From: cheatfate Date: Wed, 9 Oct 2019 09:12:54 +0300 Subject: [PATCH 05/12] Attempt to fix state machine issue. --- chronos/streams/tlsstream.nim | 190 ++++++++++++++++++++++------------ 1 file changed, 122 insertions(+), 68 deletions(-) diff --git a/chronos/streams/tlsstream.nim b/chronos/streams/tlsstream.nim index bffdde1..cbd5223 100644 --- a/chronos/streams/tlsstream.nim +++ b/chronos/streams/tlsstream.nim @@ -29,7 +29,14 @@ type type TlsStreamWriter* = ref object of AsyncStreamWriter + case kind: TlsStreamKind + of TlsStreamKind.Client: + ccontext: ptr SslClientContext + of TlsStreamKind.Server: + scontext: ptr SslServerContext stream*: TlsAsyncStream + switchToReader*: AsyncEvent + switchToWriter*: AsyncEvent TlsStreamReader* = ref object of AsyncStreamReader case kind: TlsStreamKind @@ -38,6 +45,8 @@ type of TlsStreamKind.Server: scontext: ptr SslServerContext stream*: TlsAsyncStream + switchToReader*: AsyncEvent + switchToWriter*: AsyncEvent TlsAsyncStream* = ref object of RootRef xwc*: X509NoAnchorContext @@ -87,26 +96,110 @@ template newTlsStreamProtocolError[T](message: T): ref Exception = proc tlsWriteLoop(stream: AsyncStreamWriter) {.async.} = var wstream = cast[TlsStreamWriter](stream) + var engine: ptr SslEngineContext + var error: ref Exception + + if wstream.kind == TLSStreamKind.Server: + engine = addr wstream.scontext.eng + else: + engine = addr wstream.ccontext.eng + + wstream.state = AsyncStreamState.Running + try: - # We waiting for empty future which will never be completed, because all - # the logic are inside of tlsReadLoop(). This infinite wait can be stopped - # by closing stream (e.g. cancellation). - var future = newFuture[void]() - await future + var length: uint + while true: + var state = engine.sslEngineCurrentState() + + if (state and SSL_CLOSED) == SSL_CLOSED: + wstream.state = AsyncStreamState.Finished + break + + if (state and (SSL_RECVREC or SSL_RECVAPP)) != 0: + wstream.switchToReader.fire() + + if (state and (SSL_SENDREC or SSL_SENDAPP)) == 0: + await wstream.switchToWriter.wait() + wstream.switchToWriter.clear() + # We need to refresh `state` because we just returned from readerLoop. + continue + + if (state and SSL_SENDREC) == SSL_SENDREC: + # TLS record needs to be sent over stream. + length = 0'u + var buf = sslEngineSendrecBuf(engine, length) + doAssert(length != 0 and not isNil(buf)) + var fut = awaitne wstream.wsource.write(buf, int(length)) + if fut.failed(): + error = fut.error + break + sslEngineSendrecAck(engine, length) + continue + + if (state and SSL_SENDAPP) == SSL_SENDAPP: + # Application data can be sent over stream. + var item = await wstream.queue.get() + if item.size > 0: + length = 0'u + var buf = sslEngineSendappBuf(engine, length) + let toWrite = min(int(length), item.size) + + if int(length) >= item.size: + if item.kind == Pointer: + let p = cast[pointer](cast[uint](item.data1) + uint(item.offset)) + copyMem(buf, p, item.size) + elif item.kind == Sequence: + copyMem(buf, addr item.data2[item.offset], item.size) + elif item.kind == String: + copyMem(buf, addr item.data3[item.offset], item.size) + sslEngineSendappAck(engine, uint(item.size)) + sslEngineFlush(engine, 0) + item.future.complete() + else: + if item.kind == Pointer: + let p = cast[pointer](cast[uint](item.data1) + uint(item.offset)) + copyMem(buf, p, length) + elif item.kind == Sequence: + copyMem(buf, addr item.data2[item.offset], length) + elif item.kind == String: + copyMem(buf, addr item.data3[item.offset], length) + item.offset = item.offset + int(length) + item.size = item.size - int(length) + wstream.queue.addFirstNoWait(item) + sslEngineSendappAck(engine, length) + continue + else: + # Zero length item means finish + wstream.state = AsyncStreamState.Finished + break + except CancelledError: - discard + wstream.state = AsyncStreamState.Stopped + finally: + if wstream.state == AsyncStreamState.Stopped: + while len(wstream.queue) > 0: + let item = wstream.queue.popFirstNoWait() + if not(item.future.finished()): + item.future.complete() + elif wstream.state == AsyncStreamState.Error: + while len(wstream.queue) > 0: + let item = wstream.queue.popFirstNoWait() + if not(item.future.finished()): + item.future.fail(error) wstream.stream = nil proc tlsReadLoop(stream: AsyncStreamReader) {.async.} = var rstream = cast[TlsStreamReader](stream) - var wstream = rstream.stream.writer var engine: ptr SslEngineContext + if rstream.kind == TlsStreamKind.Server: engine = addr rstream.scontext.eng else: engine = addr rstream.ccontext.eng + rstream.state = AsyncStreamState.Running + try: var length: uint while true: @@ -121,58 +214,25 @@ proc tlsReadLoop(stream: AsyncStreamReader) {.async.} = rstream.state = AsyncStreamState.Stopped break - if (state and SSL_SENDREC) == SSL_SENDREC: - # TLS record needs to be sent over stream. - length = 0'u - var buf = sslEngineSendrecBuf(engine, length) - doAssert(length != 0 and not isNil(buf)) - await wstream.wsource.write(buf, int(length)) - sslEngineSendrecAck(engine, length) + if (state and (SSL_SENDREC or SSL_SENDAPP)) != 0: + rstream.switchToWriter.fire() + + if (state and (SSL_RECVREC or SSL_RECVAPP)) == 0: + await rstream.switchToReader.wait() + rstream.switchToReader.clear() + # We need to refresh `state` because we just returned from writerLoop. continue - if (state and SSL_SENDAPP) == SSL_SENDAPP: - # Application data can be sent over stream. - if len(wstream.queue) > 0: - var item = await wstream.queue.get() - if item.size > 0: - length = 0'u - var buf = sslEngineSendappBuf(engine, length) - let toWrite = min(int(length), item.size) - - if int(length) >= item.size: - if item.kind == Pointer: - let p = cast[pointer](cast[uint](item.data1) + uint(item.offset)) - copyMem(buf, p, item.size) - elif item.kind == Sequence: - copyMem(buf, addr item.data2[item.offset], item.size) - elif item.kind == String: - copyMem(buf, addr item.data3[item.offset], item.size) - sslEngineSendappAck(engine, uint(item.size)) - sslEngineFlush(engine, 0) - item.future.complete() - else: - if item.kind == Pointer: - let p = cast[pointer](cast[uint](item.data1) + uint(item.offset)) - copyMem(buf, p, length) - elif item.kind == Sequence: - copyMem(buf, addr item.data2[item.offset], length) - elif item.kind == String: - copyMem(buf, addr item.data3[item.offset], length) - item.offset = item.offset + int(length) - item.size = item.size - int(length) - wstream.queue.addFirstNoWait(item) - sslEngineSendappAck(engine, length) - continue - else: - # Zero length item means finish - rstream.state = AsyncStreamState.Finished - break - if (state and SSL_RECVREC) == SSL_RECVREC: # TLS records required for further processing length = 0'u var buf = sslEngineRecvrecBuf(engine, length) - let res = await rstream.rsource.readOnce(buf, int(length)) + var resFut = awaitne rstream.rsource.readOnce(buf, int(length)) + if resFut.failed(): + rstream.error = resFut.error + rstream.state = AsyncStreamState.Error + break + let res = resFut.read() if res > 0: sslEngineRecvrecAck(engine, uint(res)) continue @@ -199,21 +259,7 @@ proc tlsReadLoop(stream: AsyncStreamReader) {.async.} = sslEngineClose(engine) # Becase tlsWriteLoop() is ephemeral, but we still need to keep stream state # consistent. - wstream.state = rstream.state - if rstream.state == AsyncStreamState.Finished: - rstream.buffer.forget() - elif rstream.state == AsyncStreamState.Stopped: - rstream.buffer.forget() - while len(wstream.queue) > 0: - let item = wstream.queue.popFirstNoWait() - if not(item.future.finished()): - item.future.complete() - elif rstream.state == AsyncStreamState.Error: - rstream.buffer.forget() - while len(wstream.queue) > 0: - let item = wstream.queue.popFirstNoWait() - if not(item.future.finished()): - item.future.fail(rstream.error) + rstream.buffer.forget() rstream.stream = nil proc newTlsClientAsyncStream*(rsource: AsyncStreamReader, @@ -243,11 +289,19 @@ proc newTlsClientAsyncStream*(rsource: AsyncStreamReader, var reader = new TlsStreamReader reader.kind = TlsStreamKind.Client var writer = new TlsStreamWriter + writer.kind = TlsStreamKind.Client + var switchToWriter = newAsyncEvent() + var switchToReader = newAsyncEvent() reader.stream = result writer.stream = result + reader.switchToReader = switchToReader + reader.switchToWriter = switchToWriter + writer.switchToReader = switchToReader + writer.switchToWriter = switchToWriter result.reader = reader result.writer = writer reader.ccontext = addr result.context + writer.ccontext = addr result.context if TLSFlags.NoVerifyHost in flags: sslClientInitFull(addr result.context, addr result.x509, nil, 0) From 5c801a5dbc519f25ee2a189a20d455bfee109ecc Mon Sep 17 00:00:00 2001 From: cheatfate Date: Thu, 10 Oct 2019 12:52:12 +0300 Subject: [PATCH 06/12] Add upload() and some debugging. --- chronos/streams/asyncstream.nim | 18 ++++++++++++ chronos/streams/tlsstream.nim | 49 ++++++++++++++++++++------------- 2 files changed, 48 insertions(+), 19 deletions(-) diff --git a/chronos/streams/asyncstream.nim b/chronos/streams/asyncstream.nim index ade48f4..c1a1554 100644 --- a/chronos/streams/asyncstream.nim +++ b/chronos/streams/asyncstream.nim @@ -141,6 +141,24 @@ proc copyData*(sb: AsyncBuffer, dest: pointer, offset, length: int) {.inline.} = copyMem(cast[pointer](cast[uint](dest) + cast[uint](offset)), unsafeAddr sb.buffer[0], length) +proc upload*(sb: ptr AsyncBuffer, pbytes: ptr byte, + nbytes: int): Future[void] {.async.} = + var length = nbytes + while length > 0: + let size = min(length, sb[].bufferLen()) + if size == 0: + # Internal buffer is full, we need to transfer data to consumer. + await sb[].transfer() + continue + else: + copyMem(addr sb[].buffer[sb.offset], pbytes, size) + sb[].offset = sb[].offset + size + length = length - size + + if length == 0: + # We notify consumers that new data is available. + sb[].forget() + template toDataOpenArray*(sb: AsyncBuffer): auto = toOpenArray(sb.buffer, 0, sb.offset - 1) diff --git a/chronos/streams/tlsstream.nim b/chronos/streams/tlsstream.nim index cbd5223..c25cbc2 100644 --- a/chronos/streams/tlsstream.nim +++ b/chronos/streams/tlsstream.nim @@ -11,7 +11,7 @@ import bearssl, bearssl/cacert import ../asyncloop, ../timer, ../asyncsync import asyncstream, ../transports/stream, ../transports/common -import strutils +import strutils, hexdump type TLSStreamKind {.pure.} = enum @@ -80,19 +80,19 @@ template newTlsStreamProtocolError[T](message: T): ref Exception = # proc raiseTlsStreamProtoError*[T](message: T) = # raise newTlsStreamProtocolError(message) -# proc getStringState*(state: cuint): string = -# var n = newSeq[string]() -# if (state and SSL_CLOSED) == SSL_CLOSED: -# n.add("Closed") -# if (state and SSL_SENDREC) == SSL_SENDREC: -# n.add("SendRec") -# if (state and SSL_RECVREC) == SSL_RECVREC: -# n.add("RecvRec") -# if (state and SSL_SENDAPP) == SSL_SENDAPP: -# n.add("SendApp") -# if (state and SSL_RECVAPP) == SSL_RECVAPP: -# n.add("RecvApp") -# result = "{" & n.join(", ") & "} number (" & $state & ")" +proc getStringState*(state: cuint): string = + var n = newSeq[string]() + if (state and SSL_CLOSED) == SSL_CLOSED: + n.add("Closed") + if (state and SSL_SENDREC) == SSL_SENDREC: + n.add("SendRec") + if (state and SSL_RECVREC) == SSL_RECVREC: + n.add("RecvRec") + if (state and SSL_SENDAPP) == SSL_SENDAPP: + n.add("SendApp") + if (state and SSL_RECVAPP) == SSL_RECVAPP: + n.add("RecvApp") + result = "{" & n.join(", ") & "} number (" & $state & ")" proc tlsWriteLoop(stream: AsyncStreamWriter) {.async.} = var wstream = cast[TlsStreamWriter](stream) @@ -110,15 +110,18 @@ proc tlsWriteLoop(stream: AsyncStreamWriter) {.async.} = var length: uint while true: var state = engine.sslEngineCurrentState() + echo "tlsWriteLoop() state = ", getStringState(state) if (state and SSL_CLOSED) == SSL_CLOSED: wstream.state = AsyncStreamState.Finished break if (state and (SSL_RECVREC or SSL_RECVAPP)) != 0: + echo "tlsWriteLoop() signal to tlsReadLoop()" wstream.switchToReader.fire() if (state and (SSL_SENDREC or SSL_SENDAPP)) == 0: + echo "tlsWriteLoop() waiting" await wstream.switchToWriter.wait() wstream.switchToWriter.clear() # We need to refresh `state` because we just returned from readerLoop. @@ -138,7 +141,9 @@ proc tlsWriteLoop(stream: AsyncStreamWriter) {.async.} = if (state and SSL_SENDAPP) == SSL_SENDAPP: # Application data can be sent over stream. + echo "tlsWriteLoop() waiting for an item" var item = await wstream.queue.get() + echo "tlsWriteLoop() obtained an item" if item.size > 0: length = 0'u var buf = sslEngineSendappBuf(engine, length) @@ -204,6 +209,8 @@ proc tlsReadLoop(stream: AsyncStreamReader) {.async.} = var length: uint while true: var state = engine.sslEngineCurrentState() + echo "tlsReadLoop() state = ", getStringState(state) + if (state and SSL_CLOSED) == SSL_CLOSED: let err = engine.sslEngineLastError() if err != 0: @@ -215,9 +222,11 @@ proc tlsReadLoop(stream: AsyncStreamReader) {.async.} = break if (state and (SSL_SENDREC or SSL_SENDAPP)) != 0: + echo "tlsReadLoop() signal to tlsWriteLoop()" rstream.switchToWriter.fire() if (state and (SSL_RECVREC or SSL_RECVAPP)) == 0: + echo "tlsReadLoop() waiting" await rstream.switchToReader.wait() rstream.switchToReader.clear() # We need to refresh `state` because we just returned from writerLoop. @@ -227,7 +236,9 @@ proc tlsReadLoop(stream: AsyncStreamReader) {.async.} = # TLS records required for further processing length = 0'u var buf = sslEngineRecvrecBuf(engine, length) + echo "tlsReadLoop() reading" var resFut = awaitne rstream.rsource.readOnce(buf, int(length)) + echo "tlsReadLoop() read completed" if resFut.failed(): rstream.error = resFut.error rstream.state = AsyncStreamState.Error @@ -244,11 +255,11 @@ proc tlsReadLoop(stream: AsyncStreamReader) {.async.} = # Application data can be recovered. length = 0'u var buf = sslEngineRecvappBuf(engine, length) - let toRead = min(int(length), rstream.buffer.bufferLen()) - copyMem(rstream.buffer.getBuffer(), buf, toRead) - rstream.buffer.update(toRead) - sslEngineRecvappAck(engine, uint(toRead)) - await rstream.buffer.transfer() + echo "tlsReadLoop(SSL_RECVAPP) received ", length, " bytes" + await upload(addr rstream.buffer, buf, int(length)) + echo dumpHex(buf, int(length)) + echo "tlsReadLoop(SSL_RECVAPP) uploaded ", length, " bytes to buffer" + sslEngineRecvappAck(engine, length) continue except CancelledError: From fe6fca1e677f32353f3991bf0d606bfe566f9241 Mon Sep 17 00:00:00 2001 From: cheatfate Date: Thu, 10 Oct 2019 13:01:14 +0300 Subject: [PATCH 07/12] Add hexdump.nim. --- chronos/streams/hexdump.nim | 92 +++++++++++++++++++++++++++++++++++++ 1 file changed, 92 insertions(+) create mode 100644 chronos/streams/hexdump.nim diff --git a/chronos/streams/hexdump.nim b/chronos/streams/hexdump.nim new file mode 100644 index 0000000..38498d9 --- /dev/null +++ b/chronos/streams/hexdump.nim @@ -0,0 +1,92 @@ +# +# Copyright (c) 2016 Eugene Kabanov +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. +# + +from strutils import toHex, repeat + +proc dumpHex*(pbytes: pointer, nbytes: int, items = 1, ascii = true): string = + ## Return hexadecimal memory dump representation pointed by ``p``. + ## ``nbytes`` - number of bytes to show + ## ``items`` - number of bytes in group (supported ``items`` count is + ## 1, 2, 4, 8) + ## ``ascii`` - if ``true`` show ASCII representation of memory dump. + result = "" + let hexSize = items * 2 + var i = 0 + var slider = pbytes + var asciiText = "" + while i < nbytes: + if i %% 16 == 0: + result = result & toHex(cast[BiggestInt](slider), + sizeof(BiggestInt) * 2) & ": " + var k = 0 + while k < items: + var ch = cast[ptr char](cast[uint](slider) + k.uint)[] + if ord(ch) > 31 and ord(ch) < 127: asciiText &= ch else: asciiText &= "." + inc(k) + case items: + of 1: + result = result & toHex(cast[BiggestInt](cast[ptr uint8](slider)[]), + hexSize) + of 2: + result = result & toHex(cast[BiggestInt](cast[ptr uint16](slider)[]), + hexSize) + of 4: + result = result & toHex(cast[BiggestInt](cast[ptr uint32](slider)[]), + hexSize) + of 8: + result = result & toHex(cast[BiggestInt](cast[ptr uint64](slider)[]), + hexSize) + else: + raise newException(ValueError, "Wrong items size!") + result = result & " " + slider = cast[pointer](cast[uint](slider) + items.uint) + i = i + items + if i %% 16 == 0: + result = result & " " & asciiText + asciiText.setLen(0) + result = result & "\n" + + if i %% 16 != 0: + var spacesCount = ((16 - (i %% 16)) div items) * (hexSize + 1) + 1 + result = result & repeat(' ', spacesCount) + result = result & asciiText + result = result & "\n" + +proc dumpHex*[T](v: openarray[T], items: int = 0, ascii = true): string = + ## Return hexadecimal memory dump representation of openarray[T] ``v``. + ## ``items`` - number of bytes in group (supported ``items`` count is + ## 0, 1, 2, 4, 8). If ``items`` is ``0`` group size will depend on + ## ``sizeof(T)``. + ## ``ascii`` - if ``true`` show ASCII representation of memory dump. + var i = 0 + if items == 0: + when sizeof(T) == 2: + i = 2 + elif sizeof(T) == 4: + i = 4 + elif sizeof(T) == 8: + i = 8 + else: + i = 1 + else: + i = items + result = dumpHex(unsafeAddr v[0], sizeof(T) * len(v), i, ascii) From 161c50209eb854dd416ff7fed33a7929cb9dbe32 Mon Sep 17 00:00:00 2001 From: cheatfate Date: Thu, 10 Oct 2019 14:53:33 +0300 Subject: [PATCH 08/12] Remove debugging echos. --- chronos/streams/hexdump.nim | 92 ----------------------------------- chronos/streams/tlsstream.nim | 21 ++------ 2 files changed, 5 insertions(+), 108 deletions(-) delete mode 100644 chronos/streams/hexdump.nim diff --git a/chronos/streams/hexdump.nim b/chronos/streams/hexdump.nim deleted file mode 100644 index 38498d9..0000000 --- a/chronos/streams/hexdump.nim +++ /dev/null @@ -1,92 +0,0 @@ -# -# Copyright (c) 2016 Eugene Kabanov -# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in all -# copies or substantial portions of the Software. - -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -# SOFTWARE. -# - -from strutils import toHex, repeat - -proc dumpHex*(pbytes: pointer, nbytes: int, items = 1, ascii = true): string = - ## Return hexadecimal memory dump representation pointed by ``p``. - ## ``nbytes`` - number of bytes to show - ## ``items`` - number of bytes in group (supported ``items`` count is - ## 1, 2, 4, 8) - ## ``ascii`` - if ``true`` show ASCII representation of memory dump. - result = "" - let hexSize = items * 2 - var i = 0 - var slider = pbytes - var asciiText = "" - while i < nbytes: - if i %% 16 == 0: - result = result & toHex(cast[BiggestInt](slider), - sizeof(BiggestInt) * 2) & ": " - var k = 0 - while k < items: - var ch = cast[ptr char](cast[uint](slider) + k.uint)[] - if ord(ch) > 31 and ord(ch) < 127: asciiText &= ch else: asciiText &= "." - inc(k) - case items: - of 1: - result = result & toHex(cast[BiggestInt](cast[ptr uint8](slider)[]), - hexSize) - of 2: - result = result & toHex(cast[BiggestInt](cast[ptr uint16](slider)[]), - hexSize) - of 4: - result = result & toHex(cast[BiggestInt](cast[ptr uint32](slider)[]), - hexSize) - of 8: - result = result & toHex(cast[BiggestInt](cast[ptr uint64](slider)[]), - hexSize) - else: - raise newException(ValueError, "Wrong items size!") - result = result & " " - slider = cast[pointer](cast[uint](slider) + items.uint) - i = i + items - if i %% 16 == 0: - result = result & " " & asciiText - asciiText.setLen(0) - result = result & "\n" - - if i %% 16 != 0: - var spacesCount = ((16 - (i %% 16)) div items) * (hexSize + 1) + 1 - result = result & repeat(' ', spacesCount) - result = result & asciiText - result = result & "\n" - -proc dumpHex*[T](v: openarray[T], items: int = 0, ascii = true): string = - ## Return hexadecimal memory dump representation of openarray[T] ``v``. - ## ``items`` - number of bytes in group (supported ``items`` count is - ## 0, 1, 2, 4, 8). If ``items`` is ``0`` group size will depend on - ## ``sizeof(T)``. - ## ``ascii`` - if ``true`` show ASCII representation of memory dump. - var i = 0 - if items == 0: - when sizeof(T) == 2: - i = 2 - elif sizeof(T) == 4: - i = 4 - elif sizeof(T) == 8: - i = 8 - else: - i = 1 - else: - i = items - result = dumpHex(unsafeAddr v[0], sizeof(T) * len(v), i, ascii) diff --git a/chronos/streams/tlsstream.nim b/chronos/streams/tlsstream.nim index c25cbc2..18f7c5d 100644 --- a/chronos/streams/tlsstream.nim +++ b/chronos/streams/tlsstream.nim @@ -11,7 +11,7 @@ import bearssl, bearssl/cacert import ../asyncloop, ../timer, ../asyncsync import asyncstream, ../transports/stream, ../transports/common -import strutils, hexdump +import strutils type TLSStreamKind {.pure.} = enum @@ -110,18 +110,16 @@ proc tlsWriteLoop(stream: AsyncStreamWriter) {.async.} = var length: uint while true: var state = engine.sslEngineCurrentState() - echo "tlsWriteLoop() state = ", getStringState(state) if (state and SSL_CLOSED) == SSL_CLOSED: wstream.state = AsyncStreamState.Finished break if (state and (SSL_RECVREC or SSL_RECVAPP)) != 0: - echo "tlsWriteLoop() signal to tlsReadLoop()" - wstream.switchToReader.fire() + if not(wstream.switchToReader.isSet()): + wstream.switchToReader.fire() if (state and (SSL_SENDREC or SSL_SENDAPP)) == 0: - echo "tlsWriteLoop() waiting" await wstream.switchToWriter.wait() wstream.switchToWriter.clear() # We need to refresh `state` because we just returned from readerLoop. @@ -141,9 +139,7 @@ proc tlsWriteLoop(stream: AsyncStreamWriter) {.async.} = if (state and SSL_SENDAPP) == SSL_SENDAPP: # Application data can be sent over stream. - echo "tlsWriteLoop() waiting for an item" var item = await wstream.queue.get() - echo "tlsWriteLoop() obtained an item" if item.size > 0: length = 0'u var buf = sslEngineSendappBuf(engine, length) @@ -209,7 +205,6 @@ proc tlsReadLoop(stream: AsyncStreamReader) {.async.} = var length: uint while true: var state = engine.sslEngineCurrentState() - echo "tlsReadLoop() state = ", getStringState(state) if (state and SSL_CLOSED) == SSL_CLOSED: let err = engine.sslEngineLastError() @@ -222,11 +217,10 @@ proc tlsReadLoop(stream: AsyncStreamReader) {.async.} = break if (state and (SSL_SENDREC or SSL_SENDAPP)) != 0: - echo "tlsReadLoop() signal to tlsWriteLoop()" - rstream.switchToWriter.fire() + if not(rstream.switchToWriter.isSet()): + rstream.switchToWriter.fire() if (state and (SSL_RECVREC or SSL_RECVAPP)) == 0: - echo "tlsReadLoop() waiting" await rstream.switchToReader.wait() rstream.switchToReader.clear() # We need to refresh `state` because we just returned from writerLoop. @@ -236,9 +230,7 @@ proc tlsReadLoop(stream: AsyncStreamReader) {.async.} = # TLS records required for further processing length = 0'u var buf = sslEngineRecvrecBuf(engine, length) - echo "tlsReadLoop() reading" var resFut = awaitne rstream.rsource.readOnce(buf, int(length)) - echo "tlsReadLoop() read completed" if resFut.failed(): rstream.error = resFut.error rstream.state = AsyncStreamState.Error @@ -255,10 +247,7 @@ proc tlsReadLoop(stream: AsyncStreamReader) {.async.} = # Application data can be recovered. length = 0'u var buf = sslEngineRecvappBuf(engine, length) - echo "tlsReadLoop(SSL_RECVAPP) received ", length, " bytes" await upload(addr rstream.buffer, buf, int(length)) - echo dumpHex(buf, int(length)) - echo "tlsReadLoop(SSL_RECVAPP) uploaded ", length, " bytes to buffer" sslEngineRecvappAck(engine, length) continue From a92ad6d2d2cd27240da8784089c0d160ef9c6e03 Mon Sep 17 00:00:00 2001 From: cheatfate Date: Wed, 16 Oct 2019 09:01:52 +0300 Subject: [PATCH 09/12] Add TLS inbound stream. Fix some review comments. --- chronos/streams/asyncstream.nim | 15 +- chronos/streams/tlsstream.nim | 475 ++++++++++++++++++++++++++------ tests/testasyncstream.nim | 105 ++++++- 3 files changed, 505 insertions(+), 90 deletions(-) diff --git a/chronos/streams/asyncstream.nim b/chronos/streams/asyncstream.nim index c1a1554..3eeb495 100644 --- a/chronos/streams/asyncstream.nim +++ b/chronos/streams/asyncstream.nim @@ -154,10 +154,8 @@ proc upload*(sb: ptr AsyncBuffer, pbytes: ptr byte, copyMem(addr sb[].buffer[sb.offset], pbytes, size) sb[].offset = sb[].offset + size length = length - size - - if length == 0: - # We notify consumers that new data is available. - sb[].forget() + # We notify consumers that new data is available. + sb[].forget() template toDataOpenArray*(sb: AsyncBuffer): auto = toOpenArray(sb.buffer, 0, sb.offset - 1) @@ -165,6 +163,15 @@ template toDataOpenArray*(sb: AsyncBuffer): auto = template toBufferOpenArray*(sb: AsyncBuffer): auto = toOpenArray(sb.buffer, sb.offset, len(sb.buffer) - 1) +template harvestItem*(dest: pointer, item: WriteItem, length: int) = + if item.kind == Pointer: + let p = cast[pointer](cast[uint](item.data1) + uint(item.offset)) + copyMem(dest, p, length) + elif item.kind == Sequence: + copyMem(dest, unsafeAddr item.data2[item.offset], length) + elif item.kind == String: + copyMem(dest, unsafeAddr item.data3[item.offset], length) + proc newAsyncStreamReadError(p: ref Exception): ref Exception {.inline.} = var w = newException(AsyncStreamReadError, "Read stream failed") w.msg = w.msg & ", originated from [" & $p.name & "] " & p.msg diff --git a/chronos/streams/tlsstream.nim b/chronos/streams/tlsstream.nim index 18f7c5d..a31e603 100644 --- a/chronos/streams/tlsstream.nim +++ b/chronos/streams/tlsstream.nim @@ -7,11 +7,11 @@ # Apache License, version 2.0, (LICENSE-APACHEv2) # MIT license (LICENSE-MIT) -## This module implements TLS stream reading and writing. +## This module implements Transport Layer Security (TLS) stream. This module +## uses sources of BearSSL by Thomas Pornin. import bearssl, bearssl/cacert import ../asyncloop, ../timer, ../asyncsync import asyncstream, ../transports/stream, ../transports/common -import strutils type TLSStreamKind {.pure.} = enum @@ -21,46 +21,79 @@ type TLS10 = 0x0301, TLS11 = 0x0302, TLS12 = 0x0303 TLSFlags* {.pure.} = enum - NoVerifyHost, # Client: Skip remote certificate check - NoVerifySN, # Client: Skip Server Name Indication (SNI) check - NoRenegotiation, # Server: Reject renegotiations requests - NoClientAuth, # Server: Disable strict client authentication - FailOnAlpnMismatch # Server: Fail on application protocol mismatch + NoVerifyHost, # Client: Skip remote certificate check + NoVerifyServerName, # Client: Skip Server Name Indication (SNI) check + EnforceServerPref, # Server: Enforce server preferences + NoRenegotiation, # Server: Reject renegotiations requests + TolerateNoClientAuth, # Server: Disable strict client authentication + FailOnAlpnMismatch # Server: Fail on application protocol mismatch -type - TlsStreamWriter* = ref object of AsyncStreamWriter - case kind: TlsStreamKind - of TlsStreamKind.Client: + TLSKeyType {.pure.} = enum + RSA, EC + + TLSPrivateKey* = ref object + case kind: TLSKeyType + of RSA: + rsakey: RsaPrivateKey + of EC: + eckey: EcPrivateKey + storage: seq[byte] + + TLSCertificate* = ref object + certs: seq[X509Certificate] + storage: seq[byte] + + TLSSessionCache* = ref object + storage: seq[byte] + context: SslSessionCacheLru + + PEMElement* = object + name*: string + data*: seq[byte] + + PEMContext = ref object + data: seq[byte] + + TLSStreamWriter* = ref object of AsyncStreamWriter + case kind: TLSStreamKind + of TLSStreamKind.Client: ccontext: ptr SslClientContext - of TlsStreamKind.Server: + of TLSStreamKind.Server: scontext: ptr SslServerContext - stream*: TlsAsyncStream + stream*: TLSAsyncStream switchToReader*: AsyncEvent switchToWriter*: AsyncEvent + handshaked*: bool + handshakeFut*: Future[void] - TlsStreamReader* = ref object of AsyncStreamReader - case kind: TlsStreamKind - of TlsStreamKind.Client: + TLSStreamReader* = ref object of AsyncStreamReader + case kind: TLSStreamKind + of TLSStreamKind.Client: ccontext: ptr SslClientContext - of TlsStreamKind.Server: + of TLSStreamKind.Server: scontext: ptr SslServerContext - stream*: TlsAsyncStream + stream*: TLSAsyncStream switchToReader*: AsyncEvent switchToWriter*: AsyncEvent + handshaked*: bool + handshakeFut*: Future[void] - TlsAsyncStream* = ref object of RootRef + TLSAsyncStream* = ref object of RootRef xwc*: X509NoAnchorContext - context*: SslClientContext + ccontext*: SslClientContext + scontext*: SslServerContext sbuffer*: seq[byte] x509*: X509MinimalContext - reader*: TlsStreamReader - writer*: TlsStreamWriter + reader*: TLSStreamReader + writer*: TLSStreamWriter - TlsStreamError* = object of CatchableError - TlsStreamProtocolError* = object of TlsStreamError + SomeTLSStreamType* = TLSStreamReader|TLSStreamWriter|TLSAsyncStream + + TLSStreamError* = object of CatchableError + TLSStreamProtocolError* = object of TLSStreamError errCode*: int -template newTlsStreamProtocolError[T](message: T): ref Exception = +template newTLSStreamProtocolError[T](message: T): ref Exception = var msg = "" var code = 0 when T is string: @@ -73,29 +106,29 @@ template newTlsStreamProtocolError[T](message: T): ref Exception = code = message else: msg.add("Internal Error") - var err = newException(TlsStreamProtocolError, msg) + var err = newException(TLSStreamProtocolError, msg) err.errCode = code err -# proc raiseTlsStreamProtoError*[T](message: T) = -# raise newTlsStreamProtocolError(message) +proc raiseTLSStreamProtoError*[T](message: T) = + raise newTLSStreamProtocolError(message) -proc getStringState*(state: cuint): string = - var n = newSeq[string]() - if (state and SSL_CLOSED) == SSL_CLOSED: - n.add("Closed") - if (state and SSL_SENDREC) == SSL_SENDREC: - n.add("SendRec") - if (state and SSL_RECVREC) == SSL_RECVREC: - n.add("RecvRec") - if (state and SSL_SENDAPP) == SSL_SENDAPP: - n.add("SendApp") - if (state and SSL_RECVAPP) == SSL_RECVAPP: - n.add("RecvApp") - result = "{" & n.join(", ") & "} number (" & $state & ")" +# proc getStringState*(state: cuint): string = +# var n = newSeq[string]() +# if (state and SSL_CLOSED) == SSL_CLOSED: +# n.add("Closed") +# if (state and SSL_SENDREC) == SSL_SENDREC: +# n.add("SendRec") +# if (state and SSL_RECVREC) == SSL_RECVREC: +# n.add("RecvRec") +# if (state and SSL_SENDAPP) == SSL_SENDAPP: +# n.add("SendApp") +# if (state and SSL_RECVAPP) == SSL_RECVAPP: +# n.add("RecvApp") +# result = "{" & n.join(", ") & "} number (" & $state & ")" proc tlsWriteLoop(stream: AsyncStreamWriter) {.async.} = - var wstream = cast[TlsStreamWriter](stream) + var wstream = cast[TLSStreamWriter](stream) var engine: ptr SslEngineContext var error: ref Exception @@ -139,31 +172,26 @@ proc tlsWriteLoop(stream: AsyncStreamWriter) {.async.} = if (state and SSL_SENDAPP) == SSL_SENDAPP: # Application data can be sent over stream. + if not(wstream.handshaked): + wstream.stream.reader.handshaked = true + wstream.handshaked = true + if not(isNil(wstream.handshakeFut)): + wstream.handshakeFut.complete() + var item = await wstream.queue.get() if item.size > 0: length = 0'u var buf = sslEngineSendappBuf(engine, length) let toWrite = min(int(length), item.size) - + harvestItem(buf, item, toWrite) if int(length) >= item.size: - if item.kind == Pointer: - let p = cast[pointer](cast[uint](item.data1) + uint(item.offset)) - copyMem(buf, p, item.size) - elif item.kind == Sequence: - copyMem(buf, addr item.data2[item.offset], item.size) - elif item.kind == String: - copyMem(buf, addr item.data3[item.offset], item.size) + # BearSSL is ready to accept whole item size. sslEngineSendappAck(engine, uint(item.size)) sslEngineFlush(engine, 0) item.future.complete() else: - if item.kind == Pointer: - let p = cast[pointer](cast[uint](item.data1) + uint(item.offset)) - copyMem(buf, p, length) - elif item.kind == Sequence: - copyMem(buf, addr item.data2[item.offset], length) - elif item.kind == String: - copyMem(buf, addr item.data3[item.offset], length) + # BearSSL is not ready to accept whole item, so we will send only + # part of item and adjust offset. item.offset = item.offset + int(length) item.size = item.size - int(length) wstream.queue.addFirstNoWait(item) @@ -188,13 +216,14 @@ proc tlsWriteLoop(stream: AsyncStreamWriter) {.async.} = let item = wstream.queue.popFirstNoWait() if not(item.future.finished()): item.future.fail(error) + wstream.switchToReader.fire() wstream.stream = nil proc tlsReadLoop(stream: AsyncStreamReader) {.async.} = - var rstream = cast[TlsStreamReader](stream) + var rstream = cast[TLSStreamReader](stream) var engine: ptr SslEngineContext - if rstream.kind == TlsStreamKind.Server: + if rstream.kind == TLSStreamKind.Server: engine = addr rstream.scontext.eng else: engine = addr rstream.ccontext.eng @@ -205,12 +234,17 @@ proc tlsReadLoop(stream: AsyncStreamReader) {.async.} = var length: uint while true: var state = engine.sslEngineCurrentState() - if (state and SSL_CLOSED) == SSL_CLOSED: let err = engine.sslEngineLastError() if err != 0: - rstream.error = newTlsStreamProtocolError(err) + rstream.error = newTLSStreamProtocolError(err) rstream.state = AsyncStreamState.Error + if not(rstream.handshaked): + rstream.handshaked = true + rstream.stream.writer.handshaked = true + if not(isNil(rstream.handshakeFut)): + rstream.handshakeFut.fail(rstream.error) + rstream.switchToWriter.fire() break else: rstream.state = AsyncStreamState.Stopped @@ -234,6 +268,12 @@ proc tlsReadLoop(stream: AsyncStreamReader) {.async.} = if resFut.failed(): rstream.error = resFut.error rstream.state = AsyncStreamState.Error + if not(rstream.handshaked): + rstream.handshaked = true + rstream.stream.writer.handshaked = true + if not(isNil(rstream.handshakeFut)): + rstream.handshakeFut.fail(rstream.error) + rstream.switchToWriter.fire() break let res = resFut.read() if res > 0: @@ -257,20 +297,30 @@ proc tlsReadLoop(stream: AsyncStreamReader) {.async.} = finally: # Perform TLS cleanup procedure sslEngineClose(engine) - # Becase tlsWriteLoop() is ephemeral, but we still need to keep stream state - # consistent. rstream.buffer.forget() + rstream.switchToWriter.fire() rstream.stream = nil -proc newTlsClientAsyncStream*(rsource: AsyncStreamReader, +proc getSignerAlgo(xc: X509Certificate): int = + ## Get certificate's signing algorithm. + var dc: X509DecoderContext + x509DecoderInit(addr dc, nil, nil) + x509DecoderPush(addr dc, xc.data, xc.dataLen) + let err = x509DecoderLastError(addr dc) + if err != 0: + result = -1 + else: + result = int(x509DecoderGetSignerKeyType(addr dc)) + +proc newTLSClientAsyncStream*(rsource: AsyncStreamReader, wsource: AsyncStreamWriter, serverName: string = "", bufferSize = SSL_BUFSIZE_BIDI, minVersion = TLSVersion.TLS11, maxVersion = TLSVersion.TLS12, - flags: set[TlsFlags] = {}): TlsAsyncStream = - ## Create new TLS asynchronous stream using reading stream ``rsource``, - ## writing stream ``wsource``. + flags: set[TLSFlags] = {}): TLSAsyncStream = + ## Create new TLS asynchronous stream for outbound (client) connections + ## using reading stream ``rsource`` and writing stream ``wsource``. ## ## You can specify remote server name using ``serverName``, if while ## handshake server reports different name you will get an error. If @@ -285,11 +335,11 @@ proc newTlsClientAsyncStream*(rsource: AsyncStreamReader, ## ``minVersion`` of bigger then ``maxVersion`` you will get an error. ## ## ``flags`` - custom TLS connection flags. - result = new TlsAsyncStream - var reader = new TlsStreamReader - reader.kind = TlsStreamKind.Client - var writer = new TlsStreamWriter - writer.kind = TlsStreamKind.Client + result = new TLSAsyncStream + var reader = new TLSStreamReader + reader.kind = TLSStreamKind.Client + var writer = new TLSStreamWriter + writer.kind = TLSStreamKind.Client var switchToWriter = newAsyncEvent() var switchToReader = newAsyncEvent() reader.stream = result @@ -300,35 +350,292 @@ proc newTlsClientAsyncStream*(rsource: AsyncStreamReader, writer.switchToWriter = switchToWriter result.reader = reader result.writer = writer - reader.ccontext = addr result.context - writer.ccontext = addr result.context + reader.ccontext = addr result.ccontext + writer.ccontext = addr result.ccontext if TLSFlags.NoVerifyHost in flags: - sslClientInitFull(addr result.context, addr result.x509, nil, 0) + sslClientInitFull(addr result.ccontext, addr result.x509, nil, 0) initNoAnchor(addr result.xwc, addr result.x509.vtable) - sslEngineSetX509(addr result.context.eng, addr result.xwc.vtable) + sslEngineSetX509(addr result.ccontext.eng, addr result.xwc.vtable) else: - sslClientInitFull(addr result.context, addr result.x509, + sslClientInitFull(addr result.ccontext, addr result.x509, unsafeAddr MozillaTrustAnchors[0], len(MozillaTrustAnchors)) let size = max(SSL_BUFSIZE_BIDI, bufferSize) result.sbuffer = newSeq[byte](size) - sslEngineSetBuffer(addr result.context.eng, addr result.sbuffer[0], + sslEngineSetBuffer(addr result.ccontext.eng, addr result.sbuffer[0], uint(len(result.sbuffer)), 1) - sslEngineSetVersions(addr result.context.eng, uint16(minVersion), + sslEngineSetVersions(addr result.ccontext.eng, uint16(minVersion), uint16(maxVersion)) - if TLSFlags.NoVerifySN in flags: - let err = sslClientReset(addr result.context, "", 0) + if TLSFlags.NoVerifyServerName in flags: + let err = sslClientReset(addr result.ccontext, "", 0) if err == 0: - raise newException(TlsStreamError, "Could not initialize TLS layer") + raise newException(TLSStreamError, "Could not initialize TLS layer") else: - let err = sslClientReset(addr result.context, serverName, 0) + let err = sslClientReset(addr result.ccontext, serverName, 0) if err == 0: - raise newException(TlsStreamError, "Could not initialize TLS layer") + raise newException(TLSStreamError, "Could not initialize TLS layer") init(cast[AsyncStreamWriter](result.writer), wsource, tlsWriteLoop, bufferSize) init(cast[AsyncStreamReader](result.reader), rsource, tlsReadLoop, bufferSize) + +proc newTLSServerAsyncStream*(rsource: AsyncStreamReader, + wsource: AsyncStreamWriter, + privateKey: TLSPrivateKey, + certificate: TLSCertificate, + bufferSize = SSL_BUFSIZE_BIDI, + minVersion = TLSVersion.TLS11, + maxVersion = TLSVersion.TLS12, + cache: TLSSessionCache = nil, + flags: set[TLSFlags] = {}): TLSAsyncStream = + ## Create new TLS asynchronous stream for inbound (server) connections + ## using reading stream ``rsource`` and writing stream ``wsource``. + ## + ## You need to specify local private key ``privateKey`` and certificate + ## ``certificate``. + ## + ## ``bufferSize`` - is SSL/TLS buffer which is used for encoding/decoding + ## incoming data. + ## + ## ``minVersion`` and ``maxVersion`` are TLS versions which will be used + ## for handshake with remote server. If server's version will be lower then + ## ``minVersion`` of bigger then ``maxVersion`` you will get an error. + ## + ## ``flags`` - custom TLS connection flags. + if isNil(privateKey) or privateKey.kind notin {TLSKeyType.RSA, TLSKeyType.EC}: + raiseTLSStreamProtoError("Incorrect private key") + if isNil(certificate) or len(certificate.certs) == 0: + raiseTLSStreamProtoError("Incorrect certificate") + + result = new TLSAsyncStream + var reader = new TLSStreamReader + reader.kind = TLSStreamKind.Server + var writer = new TLSStreamWriter + writer.kind = TLSStreamKind.Server + var switchToWriter = newAsyncEvent() + var switchToReader = newAsyncEvent() + reader.stream = result + writer.stream = result + reader.switchToReader = switchToReader + reader.switchToWriter = switchToWriter + writer.switchToReader = switchToReader + writer.switchToWriter = switchToWriter + result.reader = reader + result.writer = writer + reader.scontext = addr result.scontext + writer.scontext = addr result.scontext + + if privateKey.kind == TLSKeyType.EC: + let algo = getSignerAlgo(certificate.certs[0]) + if algo == -1: + raiseTLSStreamProtoError("Could not decode certificate") + sslServerInitFullEc(addr result.scontext, addr certificate.certs[0], + len(certificate.certs), cuint(algo), + addr privateKey.eckey) + elif privateKey.kind == TLSKeyType.RSA: + sslServerInitFullRsa(addr result.scontext, addr certificate.certs[0], + len(certificate.certs), addr privateKey.rsakey) + + let size = max(SSL_BUFSIZE_BIDI, bufferSize) + result.sbuffer = newSeq[byte](size) + sslEngineSetBuffer(addr result.scontext.eng, addr result.sbuffer[0], + uint(len(result.sbuffer)), 1) + sslEngineSetVersions(addr result.scontext.eng, uint16(minVersion), + uint16(maxVersion)) + + if not isNil(cache): + sslServerSetCache(addr result.scontext, addr cache.context.vtable) + + if TLSFlags.EnforceServerPref in flags: + sslEngineAddFlags(addr result.scontext.eng, OPT_ENFORCE_SERVER_PREFERENCES) + if TLSFlags.NoRenegotiation in flags: + sslEngineAddFlags(addr result.scontext.eng, OPT_NO_RENEGOTIATION) + if TLSFlags.TolerateNoClientAuth in flags: + sslEngineAddFlags(addr result.scontext.eng, OPT_TOLERATE_NO_CLIENT_AUTH) + if TLSFlags.FailOnAlpnMismatch in flags: + sslEngineAddFlags(addr result.scontext.eng, OPT_FAIL_ON_ALPN_MISMATCH) + + let err = sslServerReset(addr result.scontext) + if err == 0: + raise newException(TLSStreamError, "Could not initialize TLS layer") + + init(cast[AsyncStreamWriter](result.writer), wsource, tlsWriteLoop, + bufferSize) + init(cast[AsyncStreamReader](result.reader), rsource, tlsReadLoop, + bufferSize) + +proc copyKey(src: RsaPrivateKey): TLSPrivateKey = + ## Creates copy of RsaPrivateKey ``src``. + var offset = 0 + let keySize = src.plen + src.qlen + src.dplen + src.dqlen + src.iqlen + result = TLSPrivateKey(kind: TLSKeyType.RSA) + result.storage = newSeq[byte](keySize) + copyMem(addr result.storage[offset], src.p, src.plen) + result.rsakey.p = cast[ptr cuchar](addr result.storage[offset]) + result.rsakey.plen = src.plen + offset = offset + src.plen + copyMem(addr result.storage[offset], src.q, src.qlen) + result.rsakey.q = cast[ptr cuchar](addr result.storage[offset]) + result.rsakey.qlen = src.qlen + offset = offset + src.qlen + copyMem(addr result.storage[offset], src.dp, src.dplen) + result.rsakey.dp = cast[ptr cuchar](addr result.storage[offset]) + result.rsakey.dplen = src.dplen + offset = offset + src.dplen + copyMem(addr result.storage[offset], src.dq, src.dqlen) + result.rsakey.dq = cast[ptr cuchar](addr result.storage[offset]) + result.rsakey.dqlen = src.dqlen + offset = offset + src.dqlen + copyMem(addr result.storage[offset], src.iq, src.iqlen) + result.rsakey.iq = cast[ptr cuchar](addr result.storage[offset]) + result.rsakey.iqlen = src.iqlen + result.rsakey.nBitlen = src.nBitlen + +proc copyKey(src: EcPrivateKey): TLSPrivateKey = + ## Creates copy of EcPrivateKey ``src``. + var offset = 0 + let keySize = src.xlen + result = TLSPrivateKey(kind: TLSKeyType.EC) + result.storage = newSeq[byte](keySize) + copyMem(addr result.storage[offset], src.x, src.xlen) + result.eckey.x = cast[ptr cuchar](addr result.storage[offset]) + result.eckey.xlen = src.xlen + result.eckey.curve = src.curve + +proc init*(tt: typedesc[TLSPrivateKey], data: openarray[byte]): TLSPrivateKey = + ## Initialize TLS private key from array of bytes ``data``. + ## + ## This procedure initializes private key using raw, DER-encoded format, + ## or wrapped in an unencrypted PKCS#8 archive (again DER-encoded). + var ctx: SkeyDecoderContext + if len(data) == 0: + raiseTLSStreamProtoError("Incorrect private key") + skeyDecoderInit(addr ctx) + skeyDecoderPush(addr ctx, cast[pointer](unsafeAddr data[0]), len(data)) + let err = skeyDecoderLastError(addr ctx) + if err != 0: + raiseTLSStreamProtoError(err) + let keyType = skeyDecoderKeyType(addr ctx) + if keyType == KEYTYPE_RSA: + result = copyKey(ctx.key.rsa) + elif keyType == KEYTYPE_EC: + result = copyKey(ctx.key.ec) + else: + raiseTLSStreamProtoError("Unknown key type (" & $keyType & ")") + +proc pemDecode*(data: openarray[char]): seq[PEMElement] = + ## Decode PEM encoded string and get array of binary blobs. + if len(data) == 0: + raiseTLSStreamProtoError("Empty PEM message") + var ctx: PemDecoderContext + var pctx = new PEMContext + result = newSeq[PEMElement]() + pemDecoderInit(addr ctx) + + proc itemAppend(ctx: pointer, pbytes: pointer, nbytes: int) {.cdecl.} = + var p = cast[PEMContext](ctx) + var o = len(p.data) + p.data.setLen(o + nbytes) + copyMem(addr p.data[o], pbytes, nbytes) + + var length = len(data) + var offset = 0 + var inobj = false + var elem: PEMElement + + while length > 0: + var tlen = pemDecoderPush(addr ctx, + cast[pointer](unsafeAddr data[offset]), length) + offset = offset + tlen + length = length - tlen + + let event = pemDecoderEvent(addr ctx) + if event == PEM_BEGIN_OBJ: + inobj = true + elem.name = $pemDecoderName(addr ctx) + pctx.data = newSeq[byte]() + pemDecoderSetdest(addr ctx, itemAppend, cast[pointer](pctx)) + elif event == PEM_END_OBJ: + if inobj: + elem.data = pctx.data + result.add(elem) + inobj = false + else: + break + else: + raiseTLSStreamProtoError("Invalid PEM encoding") + +proc init*(tt: typedesc[TLSPrivateKey], data: openarray[char]): TLSPrivateKey = + ## Initialize TLS private key from string ``data``. + ## + ## This procedure initializes private key using unencrypted PKCS#8 PEM + ## encoded string. + ## + ## Note that PKCS#1 PEM encoded objects are not supported. + var items = pemDecode(data) + for item in items: + if item.name == "PRIVATE KEY": + result = TLSPrivateKey.init(item.data) + break + if isNil(result): + raiseTLSStreamProtoError("Could not find private key") + +proc init*(tt: typedesc[TLSCertificate], + data: openarray[char]): TLSCertificate = + ## Initialize TLS certificates from string ``data``. + ## + ## This procedure initializes array of certificates from PEM encoded string. + var items = pemDecode(data) + result = new TLSCertificate + for item in items: + if item.name == "CERTIFICATE" and len(item.data) > 0: + let offset = len(result.storage) + result.storage.add(item.data) + let cert = X509Certificate( + data: cast[ptr cuchar](addr result.storage[offset]), + dataLen: len(item.data) + ) + let res = getSignerAlgo(cert) + if res == -1: + raiseTLSStreamProtoError("Could not decode certificate") + elif res != KEYTYPE_RSA and res != KEYTYPE_EC: + raiseTLSStreamProtoError("Unsupported signing key type in certificate") + result.certs.add(cert) + if len(result.storage) == 0: + raiseTLSStreamProtoError("Could not find any certificates") + +proc init*(tt: typedesc[TLSSessionCache], size: int = 4096): TLSSessionCache = + ## Create new TLS session cache with size ``size``. + ## + ## One cached item is near 100 bytes size. + result = new TLSSessionCache + var rsize = min(size, 4096) + result.storage = newSeq[byte](rsize) + sslSessionCacheLruInit(addr result.context, addr result.storage[0], rsize) + +proc handshake*(rws: SomeTLSStreamType): Future[void] = + ## Wait until initial TLS handshake will be successfully performed. + var retFuture = newFuture[void]("tlsstream.handshake") + when rws is TLSStreamReader: + if rws.handshaked: + retFuture.complete() + else: + rws.handshakeFut = retFuture + rws.stream.writer.handshakeFut = retFuture + elif rws is TLSStreamWriter: + if rws.handshaked: + retFuture.complete() + else: + rws.handshakeFut = retFuture + rws.stream.reader.handshakeFut = retFuture + elif rws is TLSAsyncStream: + if rws.reader.handshaked: + retFuture.complete() + else: + rws.reader.handshakeFut = retFuture + rws.writer.handshakeFut = retFuture + return retFuture diff --git a/tests/testasyncstream.nim b/tests/testasyncstream.nim index 6343ceb..038aec1 100644 --- a/tests/testasyncstream.nim +++ b/tests/testasyncstream.nim @@ -8,6 +8,62 @@ import strutils, unittest, os import ../chronos, ../chronos/streams/tlsstream +const SelfSignedRsaKey = """ +-----BEGIN PRIVATE KEY----- +MIIEvgIBADANBgkqhkiG9w0BAQEFAASCBKgwggSkAgEAAoIBAQDSXcKMR6zIIHSy ++UUjgyIlJWzEu3JFCEN+qBpjmwbuae3GNed0YgPels6AKe0AlWQNgpqoWfMQrWco +qNf9dcm9OIdUK5FGMWYC8mJu6OjwnexSXt0R/k07lc5ePPDUbGvkbvgDyHzl1Ynj +jYhe/2ujL0E99/KAuA7cvRILEl4rLpnngE+MMYNrFWmsSDzC+w0Hv5Fuyc8wZoSy +nzODvojvFXTva9Nx4LxPF1W79WidwJHwrJghVlGUUhSeRGLHQWRY+954rj6TavoJ +6gLYSHx6ELnRkFMbqxRrCJ0wAbrV8SwcPHjKSc6LQGRMzBRqxT45DoKmzH9Pwdnw +6P0PGJyPAgMBAAECggEAd5Ck39hpIwIXchXtrwZ8ZMKFtLeZdhUBT7656P0XDnEU +nQDMQcDn1B7A1eV+eENwr6EYyDD/zu3P4TM+OCg3dp3nhPaSRmQTR/995O3qX8BS +rmqOmgiA2yoFNljKxOGu3RIZUwUjv/oDulsaNGxWQFS+bzs7EOAMSngIBlT1QvLc +1etvGOW6hc4nzacrXpzhzWem6EzOabPZBmIy0DDz2ARlND1YSd2P5IMpOv0terNF +ZwYl5SZ6Rnbv6GJWKl5IIpJOOtRtwyIhKNU/bd835vOfW07aaHVAT6GNYlyEoGWT +36UjOyl3YxSLNvQOeIz4Y0n+vupBu/YL+mFtQdxJgQKBgQDtjEi0cNlt+Vz5sp4q +wAVBJ/6h7hu3KJkY+xDpdrLyFTcxttKM9q/dR8V2bEaqMYT/mTvADfmW2BBcfi2J +3VdR1lQ5pXeKAuxt1/Vc+Q4UCnX5OX7UXpP9aSetDdo5FXUC9X2H0hO0BqOcc+h2 +khVyXjKt6TdBwV94dP9bmQp9QQKBgQDitPVqRHGepYBYasScyTUXPp7vL9T7PMSu +PGjqEkwvauhICpUbWpE/j8M0UXk64zwSmOYwQ37uiPpws88GL57oy1ZQZnhF/Hi3 +tM00Mn4x0xbyONbWu/AcFIZwSeSL6QhHYfeyVj7Jb/lqUg8sMiGmO25JjlAQBfTb +vvBgEpcVzwKBgQCwgz87JWfLejIGMR2qcoj1A30IYmAh1377uwO0F0mc7PrYbBtE +N8IyUTR/bLGNocJME1b8vOWrmt19fRzlhp1t6C8prrSGzulULdbawQ4fAi7rhDek +Iqsg8FRVGSgAptsN2dDvbcDKUuycQtyHzsE0/J338IXozIHehkGBlNTggQKBgQCF +RDTj5BoaVVWuJA0x0UGJSYFqP2bmzWEcv1w5BMqOMT0cZEQkkUfC4oKwdZhbGosM +r57ZDkRGenUl3T08eK/kTuuNVb8r/O8Fpp3eKjRum5TojKsWDeJmz1X8GiPkbvcz +5w4RYouEJHOsoVJT+6A2NMdvK946nRXEO2jYQPVZlwKBgBro8qGm0+T0T2xNP21q +IzjP/EHT7iIkM5kiUCc2bPIrfzAGxXImakDzd6AgpgxhhficJOpp792Upe/b/Hwy +bwfmbdWlT7/hPCnlVVH2dgO/ysDyEfxPigBMd+MmucRm6fzGIU7XSQw4KJqH4vQN +9IASWlgzyQ1RytAduzRuepzB +-----END PRIVATE KEY----- +""" + +const SelfSignedRsaCert = """ +-----BEGIN CERTIFICATE----- +MIIDkzCCAnugAwIBAgIUEFovLJkPSn4T8BBZMYBrXPDajF0wDQYJKoZIhvcNAQEL +BQAwWTELMAkGA1UEBhMCQVUxEzARBgNVBAgMClNvbWUtU3RhdGUxITAfBgNVBAoM +GEludGVybmV0IFdpZGdpdHMgUHR5IEx0ZDESMBAGA1UEAwwJbG9jYWxob3N0MB4X +DTE5MTAxMjA1NDQ0N1oXDTIwMTAxMTA1NDQ0OFowWTELMAkGA1UEBhMCQVUxEzAR +BgNVBAgMClNvbWUtU3RhdGUxITAfBgNVBAoMGEludGVybmV0IFdpZGdpdHMgUHR5 +IEx0ZDESMBAGA1UEAwwJbG9jYWxob3N0MIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8A +MIIBCgKCAQEA0l3CjEesyCB0svlFI4MiJSVsxLtyRQhDfqgaY5sG7mntxjXndGID +3pbOgCntAJVkDYKaqFnzEK1nKKjX/XXJvTiHVCuRRjFmAvJibujo8J3sUl7dEf5N +O5XOXjzw1Gxr5G74A8h85dWJ442IXv9roy9BPffygLgO3L0SCxJeKy6Z54BPjDGD +axVprEg8wvsNB7+RbsnPMGaEsp8zg76I7xV072vTceC8TxdVu/VoncCR8KyYIVZR +lFIUnkRix0FkWPveeK4+k2r6CeoC2Eh8ehC50ZBTG6sUawidMAG61fEsHDx4yknO +i0BkTMwUasU+OQ6Cpsx/T8HZ8Oj9DxicjwIDAQABo1MwUTAdBgNVHQ4EFgQUMM+1 +FZ6KmN2eCJfDxY+8xa1JKnYwHwYDVR0jBBgwFoAUMM+1FZ6KmN2eCJfDxY+8xa1J +KnYwDwYDVR0TAQH/BAUwAwEB/zANBgkqhkiG9w0BAQsFAAOCAQEASglh98fXvPwA +KMaEezCUqTeE7DehLlhZ8n6ETKaBDcP3JR4+KTEh9y7gRGJ7DXGFAYfU3itgjyZo +kbgpZIhrTKyYCAsF96Q1mHf/cBQ96UXr0U0SbYXSSJFeeMthMvki556dJZajtxcA +9xR/U0PxPjhC9NIfpVSAv/7ocnXh73qOiFHoN9Cr2smzcGPxsifys2iv1qm5LwDr +Dx5h/RfyfuAjS8e1ZCAhS++PYjb8BX54NilW2lTYF3pwpXL8znc4eBmklBkw5L60 +99jrK7LSQT9Nk8Mf9t4P/77N4hXCqsHIxZIqJlbdgdKfvBF3vRomxm3/aWtGlTVD +vvzZPnlYfQ== +-----END CERTIFICATE----- +""" + suite "AsyncStream test suite": test "AsyncStream(StreamTransport) readExactly() test": proc testReadExactly(address: TransportAddress): Future[bool] {.async.} = @@ -524,7 +580,7 @@ suite "TLSStream test suite": var transp = await connect(address) var reader = newAsyncStreamReader(transp) var writer = newAsyncStreamWriter(transp) - var tlsstream = newTlsClientAsyncStream(reader, writer, name) + var tlsstream = newTLSClientAsyncStream(reader, writer, name) await tlsstream.writer.write("GET / HTTP/1.1\r\nHost: " & name & "\r\nConnection: close\r\n\r\n") @@ -548,7 +604,52 @@ suite "TLSStream test suite": "www.google.com")) check res == true - test "TlsStream leaks test": + proc checkSSLServer(address: TransportAddress, + pemkey, pemcert: string): Future[bool] {.async.} = + var key: TLSPrivateKey + var cert: TLSCertificate + let testMessage = "TEST MESSAGE" + + proc serveClient(server: StreamServer, + transp: StreamTransport) {.async.} = + var reader = newAsyncStreamReader(transp) + var writer = newAsyncStreamWriter(transp) + var sstream = newTLSServerAsyncStream(reader, writer, key, cert) + await handshake(sstream) + await sstream.writer.write(testMessage & "\r\n") + await sstream.writer.closeWait() + await sstream.reader.closeWait() + await reader.closeWait() + await writer.closeWait() + await transp.closeWait() + server.stop() + server.close() + + key = TLSPrivateKey.init(pemkey) + cert = TLSCertificate.init(pemcert) + + var server = createStreamServer(address, serveClient, {ReuseAddr}) + server.start() + var conn = await connect(address) + var creader = newAsyncStreamReader(conn) + var cwriter = newAsyncStreamWriter(conn) + # We are using self-signed certificate + var cstream = newTLSClientAsyncStream(creader, cwriter, + flags = {NoVerifyHost}) + let res = await cstream.reader.readLine() + await cstream.reader.closeWait() + await cstream.writer.closeWait() + await creader.closeWait() + await cwriter.closeWait() + await conn.closeWait() + await server.join() + result = res == testMessage + + test "Simple server with RSA self-signed certificate": + let res = waitFor(checkSSLServer(initTAddress("127.0.0.1:43808"), + SelfSignedRsaKey, SelfSignedRsaCert)) + check res == true + test "TLSStream leaks test": check: getTracker("async.stream.reader").isLeaked() == false getTracker("async.stream.writer").isLeaked() == false From d008fa2087ad2739e0e49558712c618e1ccefc4e Mon Sep 17 00:00:00 2001 From: cheatfate Date: Wed, 16 Oct 2019 09:07:46 +0300 Subject: [PATCH 10/12] Fix make serverName mandatory and check for empty serverName. --- chronos/streams/tlsstream.nim | 19 ++++--------------- tests/testasyncstream.nim | 4 ++-- 2 files changed, 6 insertions(+), 17 deletions(-) diff --git a/chronos/streams/tlsstream.nim b/chronos/streams/tlsstream.nim index a31e603..8605c38 100644 --- a/chronos/streams/tlsstream.nim +++ b/chronos/streams/tlsstream.nim @@ -113,20 +113,6 @@ template newTLSStreamProtocolError[T](message: T): ref Exception = proc raiseTLSStreamProtoError*[T](message: T) = raise newTLSStreamProtocolError(message) -# proc getStringState*(state: cuint): string = -# var n = newSeq[string]() -# if (state and SSL_CLOSED) == SSL_CLOSED: -# n.add("Closed") -# if (state and SSL_SENDREC) == SSL_SENDREC: -# n.add("SendRec") -# if (state and SSL_RECVREC) == SSL_RECVREC: -# n.add("RecvRec") -# if (state and SSL_SENDAPP) == SSL_SENDAPP: -# n.add("SendApp") -# if (state and SSL_RECVAPP) == SSL_RECVAPP: -# n.add("RecvApp") -# result = "{" & n.join(", ") & "} number (" & $state & ")" - proc tlsWriteLoop(stream: AsyncStreamWriter) {.async.} = var wstream = cast[TLSStreamWriter](stream) var engine: ptr SslEngineContext @@ -314,7 +300,7 @@ proc getSignerAlgo(xc: X509Certificate): int = proc newTLSClientAsyncStream*(rsource: AsyncStreamReader, wsource: AsyncStreamWriter, - serverName: string = "", + serverName: string, bufferSize = SSL_BUFSIZE_BIDI, minVersion = TLSVersion.TLS11, maxVersion = TLSVersion.TLS12, @@ -374,6 +360,9 @@ proc newTLSClientAsyncStream*(rsource: AsyncStreamReader, if err == 0: raise newException(TLSStreamError, "Could not initialize TLS layer") else: + if len(serverName) == 0: + raise newException(TLSStreamError, "serverName must not be empty string") + let err = sslClientReset(addr result.ccontext, serverName, 0) if err == 0: raise newException(TLSStreamError, "Could not initialize TLS layer") diff --git a/tests/testasyncstream.nim b/tests/testasyncstream.nim index 038aec1..7b52df5 100644 --- a/tests/testasyncstream.nim +++ b/tests/testasyncstream.nim @@ -634,8 +634,8 @@ suite "TLSStream test suite": var creader = newAsyncStreamReader(conn) var cwriter = newAsyncStreamWriter(conn) # We are using self-signed certificate - var cstream = newTLSClientAsyncStream(creader, cwriter, - flags = {NoVerifyHost}) + let flags = {NoVerifyHost, NoVerifyServerName} + var cstream = newTLSClientAsyncStream(creader, cwriter, "", flags = flags) let res = await cstream.reader.readLine() await cstream.reader.closeWait() await cstream.writer.closeWait() From 9ce714108790bde35307f789090c5347165557e0 Mon Sep 17 00:00:00 2001 From: cheatfate Date: Thu, 17 Oct 2019 14:44:14 +0300 Subject: [PATCH 11/12] Fix cancel() issue. Fix asyncstream.nim not propagating cancellation. Fix tlsstream.nim to proper propagate cancellation. Fix tlsstream.nim stuck on close. --- chronos/asyncfutures2.nim | 8 ++++-- chronos/streams/asyncstream.nim | 32 +++++++++++++++++++++--- chronos/streams/tlsstream.nim | 44 ++++++++++++--------------------- 3 files changed, 51 insertions(+), 33 deletions(-) diff --git a/chronos/asyncfutures2.nim b/chronos/asyncfutures2.nim index f9dd176..cd2fd06 100644 --- a/chronos/asyncfutures2.nim +++ b/chronos/asyncfutures2.nim @@ -254,15 +254,19 @@ proc cancel[T](future: Future[T], loc: ptr SrcLoc) = else: var first = FutureBase(future) var last = first - while (not isNil(last.child)) and (not(last.child.finished())): + while not(isNil(last.child)) and not(last.child.cancelled()): last = last.child if last == first: checkFinished(future, loc) + let isPending = (last.state == FutureState.Pending) last.state = FutureState.Cancelled last.error = newException(CancelledError, "") if not(isNil(last.cancelcb)): last.cancelcb(cast[pointer](last)) - last.callbacks.call() + if isPending: + # If Future's state was `Finished` or `Failed` callbacks are already + # scheduled. + last.callbacks.call() template cancel*[T](future: Future[T]) = ## Cancel ``future``. diff --git a/chronos/streams/asyncstream.nim b/chronos/streams/asyncstream.nim index 3eeb495..b8c7253 100644 --- a/chronos/streams/asyncstream.nim +++ b/chronos/streams/asyncstream.nim @@ -295,6 +295,8 @@ proc readExactly*(rstream: AsyncStreamReader, pbytes: pointer, if isNil(rstream.rsource): try: await readExactly(rstream.tsource, pbytes, nbytes) + except CancelledError: + raise except TransportIncompleteError: raise newAsyncStreamIncompleteError() except: @@ -333,10 +335,12 @@ proc readOnce*(rstream: AsyncStreamReader, pbytes: pointer, if isNil(rstream.rsource): try: result = await readOnce(rstream.tsource, pbytes, nbytes) + except CancelledError: + raise except: raise newAsyncStreamReadError(getCurrentException()) else: - if isNil(rstream.rsource): + if isNil(rstream.readerLoop): result = await readOnce(rstream.rsource, pbytes, nbytes) else: while true: @@ -376,6 +380,8 @@ proc readUntil*(rstream: AsyncStreamReader, pbytes: pointer, nbytes: int, if isNil(rstream.rsource): try: result = await readUntil(rstream.tsource, pbytes, nbytes, sep) + except CancelledError: + raise except TransportIncompleteError: raise newAsyncStreamIncompleteError() except TransportLimitError: @@ -441,6 +447,8 @@ proc readLine*(rstream: AsyncStreamReader, limit = 0, if isNil(rstream.rsource): try: result = await readLine(rstream.tsource, limit, sep) + except CancelledError: + raise except: raise newAsyncStreamReadError(getCurrentException()) else: @@ -494,6 +502,8 @@ proc read*(rstream: AsyncStreamReader, n = 0): Future[seq[byte]] {.async.} = if isNil(rstream.rsource): try: result = await read(rstream.tsource, n) + except CancelledError: + raise except: raise newAsyncStreamReadError(getCurrentException()) else: @@ -542,6 +552,8 @@ proc consume*(rstream: AsyncStreamReader, n = -1): Future[int] {.async.} = if isNil(rstream.rsource): try: result = await consume(rstream.tsource, n) + except CancelledError: + raise except TransportLimitError: raise newAsyncStreamLimitError() except: @@ -594,6 +606,8 @@ proc write*(wstream: AsyncStreamWriter, pbytes: pointer, var res: int try: res = await write(wstream.tsource, pbytes, nbytes) + except CancelledError: + raise except: raise newAsyncStreamWriteError(getCurrentException()) if res != nbytes: @@ -609,6 +623,8 @@ proc write*(wstream: AsyncStreamWriter, pbytes: pointer, await wstream.queue.put(item) try: await item.future + except CancelledError: + raise except: raise newAsyncStreamWriteError(item.future.error) @@ -633,6 +649,8 @@ proc write*(wstream: AsyncStreamWriter, sbytes: seq[byte], var res: int try: res = await write(wstream.tsource, sbytes, msglen) + except CancelledError: + raise except: raise newAsyncStreamWriteError(getCurrentException()) if res != length: @@ -651,6 +669,8 @@ proc write*(wstream: AsyncStreamWriter, sbytes: seq[byte], await wstream.queue.put(item) try: await item.future + except CancelledError: + raise except: raise newAsyncStreamWriteError(item.future.error) @@ -674,6 +694,8 @@ proc write*(wstream: AsyncStreamWriter, sbytes: string, var res: int try: res = await write(wstream.tsource, sbytes, msglen) + except CancelledError: + raise except: raise newAsyncStreamWriteError(getCurrentException()) if res != length: @@ -692,6 +714,8 @@ proc write*(wstream: AsyncStreamWriter, sbytes: string, await wstream.queue.put(item) try: await item.future + except CancelledError: + raise except: raise newAsyncStreamWriteError(item.future.error) @@ -710,6 +734,8 @@ proc finish*(wstream: AsyncStreamWriter) {.async.} = await wstream.queue.put(item) try: await item.future + except CancelledError: + raise except: raise newAsyncStreamWriteError(item.future.error) @@ -760,8 +786,8 @@ proc close*(rw: AsyncStreamRW) = if rw.future.finished(): callSoon(continuation) else: - rw.future.cancel() rw.future.addCallback(continuation) + rw.future.cancel() elif rw is AsyncStreamWriter: if isNil(rw.wsource) or isNil(rw.writerLoop) or isNil(rw.future): callSoon(continuation) @@ -769,8 +795,8 @@ proc close*(rw: AsyncStreamRW) = if rw.future.finished(): callSoon(continuation) else: - rw.future.cancel() rw.future.addCallback(continuation) + rw.future.cancel() proc closeWait*(rw: AsyncStreamRW): Future[void] = ## Close and frees resources of stream ``rw``. diff --git a/chronos/streams/tlsstream.nim b/chronos/streams/tlsstream.nim index 8605c38..16dc5a3 100644 --- a/chronos/streams/tlsstream.nim +++ b/chronos/streams/tlsstream.nim @@ -150,7 +150,9 @@ proc tlsWriteLoop(stream: AsyncStreamWriter) {.async.} = var buf = sslEngineSendrecBuf(engine, length) doAssert(length != 0 and not isNil(buf)) var fut = awaitne wstream.wsource.write(buf, int(length)) - if fut.failed(): + if fut.cancelled(): + raise fut.error + elif fut.failed(): error = fut.error break sslEngineSendrecAck(engine, length) @@ -202,7 +204,6 @@ proc tlsWriteLoop(stream: AsyncStreamWriter) {.async.} = let item = wstream.queue.popFirstNoWait() if not(item.future.finished()): item.future.fail(error) - wstream.switchToReader.fire() wstream.stream = nil proc tlsReadLoop(stream: AsyncStreamReader) {.async.} = @@ -223,18 +224,9 @@ proc tlsReadLoop(stream: AsyncStreamReader) {.async.} = if (state and SSL_CLOSED) == SSL_CLOSED: let err = engine.sslEngineLastError() if err != 0: - rstream.error = newTLSStreamProtocolError(err) - rstream.state = AsyncStreamState.Error - if not(rstream.handshaked): - rstream.handshaked = true - rstream.stream.writer.handshaked = true - if not(isNil(rstream.handshakeFut)): - rstream.handshakeFut.fail(rstream.error) - rstream.switchToWriter.fire() - break - else: - rstream.state = AsyncStreamState.Stopped - break + raise newTLSStreamProtocolError(err) + rstream.state = AsyncStreamState.Stopped + break if (state and (SSL_SENDREC or SSL_SENDAPP)) != 0: if not(rstream.switchToWriter.isSet()): @@ -250,18 +242,7 @@ proc tlsReadLoop(stream: AsyncStreamReader) {.async.} = # TLS records required for further processing length = 0'u var buf = sslEngineRecvrecBuf(engine, length) - var resFut = awaitne rstream.rsource.readOnce(buf, int(length)) - if resFut.failed(): - rstream.error = resFut.error - rstream.state = AsyncStreamState.Error - if not(rstream.handshaked): - rstream.handshaked = true - rstream.stream.writer.handshaked = true - if not(isNil(rstream.handshakeFut)): - rstream.handshakeFut.fail(rstream.error) - rstream.switchToWriter.fire() - break - let res = resFut.read() + let res = await rstream.rsource.readOnce(buf, int(length)) if res > 0: sslEngineRecvrecAck(engine, uint(res)) continue @@ -279,12 +260,19 @@ proc tlsReadLoop(stream: AsyncStreamReader) {.async.} = except CancelledError: rstream.state = AsyncStreamState.Stopped - + except AsyncStreamReadError: + rstream.error = getCurrentException() + rstream.state = AsyncStreamState.Error + if not(rstream.handshaked): + rstream.handshaked = true + rstream.stream.writer.handshaked = true + if not(isNil(rstream.handshakeFut)): + rstream.handshakeFut.fail(rstream.error) + rstream.switchToWriter.fire() finally: # Perform TLS cleanup procedure sslEngineClose(engine) rstream.buffer.forget() - rstream.switchToWriter.fire() rstream.stream = nil proc getSignerAlgo(xc: X509Certificate): int = From 368502c10b945fa2b483cf33886402cb8e18b55a Mon Sep 17 00:00:00 2001 From: cheatfate Date: Fri, 18 Oct 2019 19:24:58 +0300 Subject: [PATCH 12/12] Rename harvestItem to copyOut. --- chronos/streams/asyncstream.nim | 2 +- chronos/streams/tlsstream.nim | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/chronos/streams/asyncstream.nim b/chronos/streams/asyncstream.nim index b8c7253..8a77ad6 100644 --- a/chronos/streams/asyncstream.nim +++ b/chronos/streams/asyncstream.nim @@ -163,7 +163,7 @@ template toDataOpenArray*(sb: AsyncBuffer): auto = template toBufferOpenArray*(sb: AsyncBuffer): auto = toOpenArray(sb.buffer, sb.offset, len(sb.buffer) - 1) -template harvestItem*(dest: pointer, item: WriteItem, length: int) = +template copyOut*(dest: pointer, item: WriteItem, length: int) = if item.kind == Pointer: let p = cast[pointer](cast[uint](item.data1) + uint(item.offset)) copyMem(dest, p, length) diff --git a/chronos/streams/tlsstream.nim b/chronos/streams/tlsstream.nim index 16dc5a3..14c1ff3 100644 --- a/chronos/streams/tlsstream.nim +++ b/chronos/streams/tlsstream.nim @@ -171,7 +171,7 @@ proc tlsWriteLoop(stream: AsyncStreamWriter) {.async.} = length = 0'u var buf = sslEngineSendappBuf(engine, length) let toWrite = min(int(length), item.size) - harvestItem(buf, item, toWrite) + copyOut(buf, item, toWrite) if int(length) >= item.size: # BearSSL is ready to accept whole item size. sslEngineSendappAck(engine, uint(item.size))