From c27624cfc047d1487819ed25207e93f55eb2589d Mon Sep 17 00:00:00 2001 From: cheatfate Date: Tue, 8 Oct 2019 18:46:27 +0300 Subject: [PATCH] Add TlsStream with client-only connections. --- chronos.nimble | 5 +- chronos/streams/tlsstream.nim | 276 ++++++++++++++++++++++++++++++++++ tests/testasyncstream.nim | 50 +++++- 3 files changed, 328 insertions(+), 3 deletions(-) create mode 100644 chronos/streams/tlsstream.nim diff --git a/chronos.nimble b/chronos.nimble index b397996..82cb803 100644 --- a/chronos.nimble +++ b/chronos.nimble @@ -1,5 +1,5 @@ packageName = "chronos" -version = "2.3.0" +version = "2.3.1" author = "Status Research & Development GmbH" description = "Chronos" license = "Apache License 2.0 or MIT" @@ -7,7 +7,8 @@ skipDirs = @["tests"] ### Dependencies -requires "nim > 0.18.0" +requires "nim > 0.19.4", + "bearssl" task test, "Run all tests": var commands = [ diff --git a/chronos/streams/tlsstream.nim b/chronos/streams/tlsstream.nim new file mode 100644 index 0000000..345b3aa --- /dev/null +++ b/chronos/streams/tlsstream.nim @@ -0,0 +1,276 @@ +# +# Chronos Asynchronous TLS Stream +# (c) Copyright 2019-Present +# Status Research & Development GmbH +# +# Licensed under either of +# Apache License, version 2.0, (LICENSE-APACHEv2) +# MIT license (LICENSE-MIT) + +## This module implements TLS stream reading and writing. +import bearssl, bearssl/cacert +import ../asyncloop, ../timer, ../asyncsync +import asyncstream, ../transports/stream, ../transports/common +import strutils +import hexdump + +type + TLSStreamKind {.pure.} = enum + Client, Server + + TLSVersion* {.pure.} = enum + TLS10 = 0x0301, TLS11 = 0x0302, TLS12 = 0x0303 + + TLSFlags* {.pure.} = enum + NoVerifyHost, # Client: Skip remote certificate check + NoVerifySN, # Client: Skip Server Name Indication (SNI) check + NoRenegotiation, # Server: Reject renegotiations requests + NoClientAuth, # Server: Disable strict client authentication + FailOnAlpnMismatch # Server: Fail on application protocol mismatch + +type + TlsStreamWriter* = ref object of AsyncStreamWriter + + TlsStreamReader* = ref object of AsyncStreamReader + case kind: TlsStreamKind + of TlsStreamKind.Client: + ccontext: ptr SslClientContext + of TlsStreamKind.Server: + scontext: ptr SslServerContext + writer*: TlsStreamWriter + + TlsAsyncStream* = ref object of RootRef + xwc*: X509NoAnchorContext + context*: SslClientContext + sbuffer*: seq[byte] + x509*: X509MinimalContext + reader*: TlsStreamReader + writer*: TlsStreamWriter + + TlsStreamError* = object of CatchableError + TlsStreamProtocolError* = object of TlsStreamError + errCode*: int + +template newTlsStreamProtocolError[T](message: T): ref Exception = + var msg = "" + var code = 0 + when T is string: + msg.add(message) + elif T is cint: + msg.add(sslErrorMsg(message) & " (code: " & $int(message) & ")") + code = int(message) + elif T is int: + msg.add(sslErrorMsg(message) & " (code: " & $message & ")") + code = message + else: + msg.add("Internal Error") + var err = newException(TlsStreamProtocolError, msg) + err.errCode = code + err + +# proc raiseTlsStreamProtoError*[T](message: T) = +# raise newTlsStreamProtocolError(message) + +# proc getStringState*(state: cuint): string = +# var n = newSeq[string]() +# if (state and SSL_CLOSED) == SSL_CLOSED: +# n.add("Closed") +# if (state and SSL_SENDREC) == SSL_SENDREC: +# n.add("SendRec") +# if (state and SSL_RECVREC) == SSL_RECVREC: +# n.add("RecvRec") +# if (state and SSL_SENDAPP) == SSL_SENDAPP: +# n.add("SendApp") +# if (state and SSL_RECVAPP) == SSL_RECVAPP: +# n.add("RecvApp") +# result = "{" & n.join(", ") & "} number (" & $state & ")" + +proc tlsWriteLoop(stream: AsyncStreamWriter) {.async.} = + var wstream = cast[TlsStreamWriter](stream) + try: + # We waiting for empty future which will never be completed, because all + # the logic are inside of tlsReadLoop(). This infinite wait can be stopped + # by closing stream (e.g. cancellation). + var future = newFuture[void]() + await future + except CancelledError: + discard + +proc tlsReadLoop(stream: AsyncStreamReader) {.async.} = + var rstream = cast[TlsStreamReader](stream) + var wstream = rstream.writer + var engine: ptr SslEngineContext + if rstream.kind == TlsStreamKind.Server: + engine = addr rstream.scontext.eng + else: + engine = addr rstream.ccontext.eng + + try: + var length: uint + while true: + var state = engine.sslEngineCurrentState() + if (state and SSL_CLOSED) == SSL_CLOSED: + let err = engine.sslEngineLastError() + if err != 0: + rstream.error = newTlsStreamProtocolError(err) + rstream.state = AsyncStreamState.Error + break + else: + rstream.state = AsyncStreamState.Stopped + break + + if (state and SSL_SENDREC) == SSL_SENDREC: + # TLS record needs to be sent over stream. + length = 0'u + var buf = sslEngineSendrecBuf(engine, length) + doAssert(length != 0 and not isNil(buf)) + await wstream.wsource.write(buf, int(length)) + sslEngineSendrecAck(engine, length) + continue + + if (state and SSL_SENDAPP) == SSL_SENDAPP: + # Application data can be sent over stream. + if len(wstream.queue) > 0: + var item = await wstream.queue.get() + if item.size > 0: + length = 0'u + var buf = sslEngineSendappBuf(engine, length) + let toWrite = min(int(length), item.size) + + if int(length) >= item.size: + if item.kind == Pointer: + let p = cast[pointer](cast[uint](item.data1) + uint(item.offset)) + copyMem(buf, p, item.size) + elif item.kind == Sequence: + copyMem(buf, addr item.data2[item.offset], item.size) + elif item.kind == String: + copyMem(buf, addr item.data3[item.offset], item.size) + sslEngineSendappAck(engine, uint(item.size)) + sslEngineFlush(engine, 0) + item.future.complete() + else: + if item.kind == Pointer: + let p = cast[pointer](cast[uint](item.data1) + uint(item.offset)) + copyMem(buf, p, length) + elif item.kind == Sequence: + copyMem(buf, addr item.data2[item.offset], length) + elif item.kind == String: + copyMem(buf, addr item.data3[item.offset], length) + item.offset = item.offset + int(length) + item.size = item.size - int(length) + wstream.queue.addFirstNoWait(item) + sslEngineSendappAck(engine, length) + continue + else: + # Zero length item means finish + rstream.state = AsyncStreamState.Finished + break + + if (state and SSL_RECVREC) == SSL_RECVREC: + # TLS records required for further processing + length = 0'u + var buf = sslEngineRecvrecBuf(engine, length) + let res = await rstream.rsource.readOnce(buf, int(length)) + if res > 0: + sslEngineRecvrecAck(engine, uint(res)) + continue + else: + rstream.state = AsyncStreamState.Finished + break + + if (state and SSL_RECVAPP) == SSL_RECVAPP: + # Application data can be recovered. + length = 0'u + var buf = sslEngineRecvappBuf(engine, length) + let toRead = min(int(length), rstream.buffer.bufferLen()) + copyMem(rstream.buffer.getBuffer(), buf, toRead) + rstream.buffer.update(toRead) + sslEngineRecvappAck(engine, uint(toRead)) + await rstream.buffer.transfer() + continue + + except CancelledError: + rstream.state = AsyncStreamState.Stopped + + finally: + # Perform TLS cleanup procedure + sslEngineClose(engine) + # Becase tlsWriteLoop() is ephemeral, but we still need to keep stream state + # consistent. + wstream.state = rstream.state + if rstream.state == AsyncStreamState.Finished: + rstream.buffer.forget() + elif rstream.state == AsyncStreamState.Stopped: + rstream.buffer.forget() + while len(wstream.queue) > 0: + let item = wstream.queue.popFirstNoWait() + if not(item.future.finished()): + item.future.complete() + elif rstream.state == AsyncStreamState.Error: + rstream.buffer.forget() + while len(wstream.queue) > 0: + let item = wstream.queue.popFirstNoWait() + if not(item.future.finished()): + item.future.fail(rstream.error) + +proc newTlsClientAsyncStream*(rsource: AsyncStreamReader, + wsource: AsyncStreamWriter, + serverName: string = "", + bufferSize = SSL_BUFSIZE_BIDI, + minVersion = TLSVersion.TLS11, + maxVersion = TLSVersion.TLS12, + flags: set[TlsFlags] = {}): TlsAsyncStream = + ## Create new TLS asynchronous stream using reading stream ``rsource``, + ## writing stream ``wsource``. + ## + ## You can specify remote server name using ``serverName``, if while + ## handshake server reports different name you will get an error. If + ## ``serverName`` is empty string, remote server name checking will be + ## disabled. + ## + ## ``bufferSize`` - is SSL/TLS buffer which is used for encoding/decoding + ## incoming data. + ## + ## ``minVersion`` and ``maxVersion`` are TLS versions which will be used + ## for handshake with remote server. If server's version will be lower then + ## ``minVersion`` of bigger then ``maxVersion`` you will get an error. + ## + ## ``flags`` - custom TLS connection flags. + result = new TlsAsyncStream + var reader = new TlsStreamReader + reader.kind = TlsStreamKind.Client + var writer = new TlsStreamWriter + result.reader = reader + result.writer = writer + result.reader.writer = writer + reader.ccontext = addr result.context + + if TLSFlags.NoVerifyHost in flags: + sslClientInitFull(addr result.context, addr result.x509, nil, 0) + initNoAnchor(addr result.xwc, addr result.x509.vtable) + sslEngineSetX509(addr result.context.eng, addr result.xwc.vtable) + else: + sslClientInitFull(addr result.context, addr result.x509, + unsafeAddr MozillaTrustAnchors[0], + len(MozillaTrustAnchors)) + + let size = max(SSL_BUFSIZE_BIDI, bufferSize) + result.sbuffer = newSeq[byte](size) + sslEngineSetBuffer(addr result.context.eng, addr result.sbuffer[0], + uint(len(result.sbuffer)), 1) + sslEngineSetVersions(addr result.context.eng, uint16(minVersion), + uint16(maxVersion)) + + if TLSFlags.NoVerifySN in flags: + let err = sslClientReset(addr result.context, "", 0) + if err == 0: + raise newException(TlsStreamError, "Could not initialize TLS layer") + else: + let err = sslClientReset(addr result.context, serverName, 0) + if err == 0: + raise newException(TlsStreamError, "Could not initialize TLS layer") + + init(cast[AsyncStreamWriter](result.writer), wsource, tlsWriteLoop, + bufferSize) + init(cast[AsyncStreamReader](result.reader), rsource, tlsReadLoop, + bufferSize) diff --git a/tests/testasyncstream.nim b/tests/testasyncstream.nim index c6dee9f..6343ceb 100644 --- a/tests/testasyncstream.nim +++ b/tests/testasyncstream.nim @@ -6,7 +6,7 @@ # Apache License, version 2.0, (LICENSE-APACHEv2) # MIT license (LICENSE-MIT) import strutils, unittest, os -import ../chronos +import ../chronos, ../chronos/streams/tlsstream suite "AsyncStream test suite": test "AsyncStream(StreamTransport) readExactly() test": @@ -506,3 +506,51 @@ suite "ChunkedStream test suite": break result = res check waitFor(testVectors2(initTAddress("127.0.0.1:46001"))) == true + + test "ChunkedStream leaks test": + check: + getTracker("async.stream.reader").isLeaked() == false + getTracker("async.stream.writer").isLeaked() == false + getTracker("stream.server").isLeaked() == false + getTracker("stream.transport").isLeaked() == false + +suite "TLSStream test suite": + const HttpHeadersMark = @[byte(0x0D), byte(0x0A), byte(0x0D), byte(0x0A)] + test "Simple HTTPS connection": + proc headerClient(address: TransportAddress, + name: string): Future[bool] {.async.} = + var mark = "HTTP/1.1 " + var buffer = newSeq[byte](8192) + var transp = await connect(address) + var reader = newAsyncStreamReader(transp) + var writer = newAsyncStreamWriter(transp) + var tlsstream = newTlsClientAsyncStream(reader, writer, name) + + await tlsstream.writer.write("GET / HTTP/1.1\r\nHost: " & name & + "\r\nConnection: close\r\n\r\n") + var readFut = tlsstream.reader.readUntil(addr buffer[0], len(buffer), + HttpHeadersMark) + let res = await withTimeout(readFut, 5.seconds) + if res: + var length = readFut.read() + buffer.setLen(length) + if len(buffer) > len(mark): + if equalMem(addr buffer[0], addr mark[0], len(mark)): + result = true + + await tlsstream.reader.closeWait() + await tlsstream.writer.closeWait() + await reader.closeWait() + await writer.closeWait() + await transp.closeWait() + + let res = waitFor(headerClient(resolveTAddress("www.google.com:443")[0], + "www.google.com")) + check res == true + + test "TlsStream leaks test": + check: + getTracker("async.stream.reader").isLeaked() == false + getTracker("async.stream.writer").isLeaked() == false + getTracker("stream.server").isLeaked() == false + getTracker("stream.transport").isLeaked() == false