diff --git a/chronos.nimble b/chronos.nimble index b397996..82cb803 100644 --- a/chronos.nimble +++ b/chronos.nimble @@ -1,5 +1,5 @@ packageName = "chronos" -version = "2.3.0" +version = "2.3.1" author = "Status Research & Development GmbH" description = "Chronos" license = "Apache License 2.0 or MIT" @@ -7,7 +7,8 @@ skipDirs = @["tests"] ### Dependencies -requires "nim > 0.18.0" +requires "nim > 0.19.4", + "bearssl" task test, "Run all tests": var commands = [ diff --git a/chronos/asyncfutures2.nim b/chronos/asyncfutures2.nim index f9dd176..cd2fd06 100644 --- a/chronos/asyncfutures2.nim +++ b/chronos/asyncfutures2.nim @@ -254,15 +254,19 @@ proc cancel[T](future: Future[T], loc: ptr SrcLoc) = else: var first = FutureBase(future) var last = first - while (not isNil(last.child)) and (not(last.child.finished())): + while not(isNil(last.child)) and not(last.child.cancelled()): last = last.child if last == first: checkFinished(future, loc) + let isPending = (last.state == FutureState.Pending) last.state = FutureState.Cancelled last.error = newException(CancelledError, "") if not(isNil(last.cancelcb)): last.cancelcb(cast[pointer](last)) - last.callbacks.call() + if isPending: + # If Future's state was `Finished` or `Failed` callbacks are already + # scheduled. + last.callbacks.call() template cancel*[T](future: Future[T]) = ## Cancel ``future``. diff --git a/chronos/streams/asyncstream.nim b/chronos/streams/asyncstream.nim index ade48f4..8a77ad6 100644 --- a/chronos/streams/asyncstream.nim +++ b/chronos/streams/asyncstream.nim @@ -141,12 +141,37 @@ proc copyData*(sb: AsyncBuffer, dest: pointer, offset, length: int) {.inline.} = copyMem(cast[pointer](cast[uint](dest) + cast[uint](offset)), unsafeAddr sb.buffer[0], length) +proc upload*(sb: ptr AsyncBuffer, pbytes: ptr byte, + nbytes: int): Future[void] {.async.} = + var length = nbytes + while length > 0: + let size = min(length, sb[].bufferLen()) + if size == 0: + # Internal buffer is full, we need to transfer data to consumer. + await sb[].transfer() + continue + else: + copyMem(addr sb[].buffer[sb.offset], pbytes, size) + sb[].offset = sb[].offset + size + length = length - size + # We notify consumers that new data is available. + sb[].forget() + template toDataOpenArray*(sb: AsyncBuffer): auto = toOpenArray(sb.buffer, 0, sb.offset - 1) template toBufferOpenArray*(sb: AsyncBuffer): auto = toOpenArray(sb.buffer, sb.offset, len(sb.buffer) - 1) +template copyOut*(dest: pointer, item: WriteItem, length: int) = + if item.kind == Pointer: + let p = cast[pointer](cast[uint](item.data1) + uint(item.offset)) + copyMem(dest, p, length) + elif item.kind == Sequence: + copyMem(dest, unsafeAddr item.data2[item.offset], length) + elif item.kind == String: + copyMem(dest, unsafeAddr item.data3[item.offset], length) + proc newAsyncStreamReadError(p: ref Exception): ref Exception {.inline.} = var w = newException(AsyncStreamReadError, "Read stream failed") w.msg = w.msg & ", originated from [" & $p.name & "] " & p.msg @@ -270,6 +295,8 @@ proc readExactly*(rstream: AsyncStreamReader, pbytes: pointer, if isNil(rstream.rsource): try: await readExactly(rstream.tsource, pbytes, nbytes) + except CancelledError: + raise except TransportIncompleteError: raise newAsyncStreamIncompleteError() except: @@ -308,10 +335,12 @@ proc readOnce*(rstream: AsyncStreamReader, pbytes: pointer, if isNil(rstream.rsource): try: result = await readOnce(rstream.tsource, pbytes, nbytes) + except CancelledError: + raise except: raise newAsyncStreamReadError(getCurrentException()) else: - if isNil(rstream.rsource): + if isNil(rstream.readerLoop): result = await readOnce(rstream.rsource, pbytes, nbytes) else: while true: @@ -351,6 +380,8 @@ proc readUntil*(rstream: AsyncStreamReader, pbytes: pointer, nbytes: int, if isNil(rstream.rsource): try: result = await readUntil(rstream.tsource, pbytes, nbytes, sep) + except CancelledError: + raise except TransportIncompleteError: raise newAsyncStreamIncompleteError() except TransportLimitError: @@ -416,6 +447,8 @@ proc readLine*(rstream: AsyncStreamReader, limit = 0, if isNil(rstream.rsource): try: result = await readLine(rstream.tsource, limit, sep) + except CancelledError: + raise except: raise newAsyncStreamReadError(getCurrentException()) else: @@ -469,6 +502,8 @@ proc read*(rstream: AsyncStreamReader, n = 0): Future[seq[byte]] {.async.} = if isNil(rstream.rsource): try: result = await read(rstream.tsource, n) + except CancelledError: + raise except: raise newAsyncStreamReadError(getCurrentException()) else: @@ -517,6 +552,8 @@ proc consume*(rstream: AsyncStreamReader, n = -1): Future[int] {.async.} = if isNil(rstream.rsource): try: result = await consume(rstream.tsource, n) + except CancelledError: + raise except TransportLimitError: raise newAsyncStreamLimitError() except: @@ -569,6 +606,8 @@ proc write*(wstream: AsyncStreamWriter, pbytes: pointer, var res: int try: res = await write(wstream.tsource, pbytes, nbytes) + except CancelledError: + raise except: raise newAsyncStreamWriteError(getCurrentException()) if res != nbytes: @@ -584,6 +623,8 @@ proc write*(wstream: AsyncStreamWriter, pbytes: pointer, await wstream.queue.put(item) try: await item.future + except CancelledError: + raise except: raise newAsyncStreamWriteError(item.future.error) @@ -608,6 +649,8 @@ proc write*(wstream: AsyncStreamWriter, sbytes: seq[byte], var res: int try: res = await write(wstream.tsource, sbytes, msglen) + except CancelledError: + raise except: raise newAsyncStreamWriteError(getCurrentException()) if res != length: @@ -626,6 +669,8 @@ proc write*(wstream: AsyncStreamWriter, sbytes: seq[byte], await wstream.queue.put(item) try: await item.future + except CancelledError: + raise except: raise newAsyncStreamWriteError(item.future.error) @@ -649,6 +694,8 @@ proc write*(wstream: AsyncStreamWriter, sbytes: string, var res: int try: res = await write(wstream.tsource, sbytes, msglen) + except CancelledError: + raise except: raise newAsyncStreamWriteError(getCurrentException()) if res != length: @@ -667,6 +714,8 @@ proc write*(wstream: AsyncStreamWriter, sbytes: string, await wstream.queue.put(item) try: await item.future + except CancelledError: + raise except: raise newAsyncStreamWriteError(item.future.error) @@ -685,6 +734,8 @@ proc finish*(wstream: AsyncStreamWriter) {.async.} = await wstream.queue.put(item) try: await item.future + except CancelledError: + raise except: raise newAsyncStreamWriteError(item.future.error) @@ -735,8 +786,8 @@ proc close*(rw: AsyncStreamRW) = if rw.future.finished(): callSoon(continuation) else: - rw.future.cancel() rw.future.addCallback(continuation) + rw.future.cancel() elif rw is AsyncStreamWriter: if isNil(rw.wsource) or isNil(rw.writerLoop) or isNil(rw.future): callSoon(continuation) @@ -744,8 +795,8 @@ proc close*(rw: AsyncStreamRW) = if rw.future.finished(): callSoon(continuation) else: - rw.future.cancel() rw.future.addCallback(continuation) + rw.future.cancel() proc closeWait*(rw: AsyncStreamRW): Future[void] = ## Close and frees resources of stream ``rw``. diff --git a/chronos/streams/tlsstream.nim b/chronos/streams/tlsstream.nim new file mode 100644 index 0000000..14c1ff3 --- /dev/null +++ b/chronos/streams/tlsstream.nim @@ -0,0 +1,618 @@ +# +# Chronos Asynchronous TLS Stream +# (c) Copyright 2019-Present +# Status Research & Development GmbH +# +# Licensed under either of +# Apache License, version 2.0, (LICENSE-APACHEv2) +# MIT license (LICENSE-MIT) + +## This module implements Transport Layer Security (TLS) stream. This module +## uses sources of BearSSL by Thomas Pornin. +import bearssl, bearssl/cacert +import ../asyncloop, ../timer, ../asyncsync +import asyncstream, ../transports/stream, ../transports/common + +type + TLSStreamKind {.pure.} = enum + Client, Server + + TLSVersion* {.pure.} = enum + TLS10 = 0x0301, TLS11 = 0x0302, TLS12 = 0x0303 + + TLSFlags* {.pure.} = enum + NoVerifyHost, # Client: Skip remote certificate check + NoVerifyServerName, # Client: Skip Server Name Indication (SNI) check + EnforceServerPref, # Server: Enforce server preferences + NoRenegotiation, # Server: Reject renegotiations requests + TolerateNoClientAuth, # Server: Disable strict client authentication + FailOnAlpnMismatch # Server: Fail on application protocol mismatch + + TLSKeyType {.pure.} = enum + RSA, EC + + TLSPrivateKey* = ref object + case kind: TLSKeyType + of RSA: + rsakey: RsaPrivateKey + of EC: + eckey: EcPrivateKey + storage: seq[byte] + + TLSCertificate* = ref object + certs: seq[X509Certificate] + storage: seq[byte] + + TLSSessionCache* = ref object + storage: seq[byte] + context: SslSessionCacheLru + + PEMElement* = object + name*: string + data*: seq[byte] + + PEMContext = ref object + data: seq[byte] + + TLSStreamWriter* = ref object of AsyncStreamWriter + case kind: TLSStreamKind + of TLSStreamKind.Client: + ccontext: ptr SslClientContext + of TLSStreamKind.Server: + scontext: ptr SslServerContext + stream*: TLSAsyncStream + switchToReader*: AsyncEvent + switchToWriter*: AsyncEvent + handshaked*: bool + handshakeFut*: Future[void] + + TLSStreamReader* = ref object of AsyncStreamReader + case kind: TLSStreamKind + of TLSStreamKind.Client: + ccontext: ptr SslClientContext + of TLSStreamKind.Server: + scontext: ptr SslServerContext + stream*: TLSAsyncStream + switchToReader*: AsyncEvent + switchToWriter*: AsyncEvent + handshaked*: bool + handshakeFut*: Future[void] + + TLSAsyncStream* = ref object of RootRef + xwc*: X509NoAnchorContext + ccontext*: SslClientContext + scontext*: SslServerContext + sbuffer*: seq[byte] + x509*: X509MinimalContext + reader*: TLSStreamReader + writer*: TLSStreamWriter + + SomeTLSStreamType* = TLSStreamReader|TLSStreamWriter|TLSAsyncStream + + TLSStreamError* = object of CatchableError + TLSStreamProtocolError* = object of TLSStreamError + errCode*: int + +template newTLSStreamProtocolError[T](message: T): ref Exception = + var msg = "" + var code = 0 + when T is string: + msg.add(message) + elif T is cint: + msg.add(sslErrorMsg(message) & " (code: " & $int(message) & ")") + code = int(message) + elif T is int: + msg.add(sslErrorMsg(message) & " (code: " & $message & ")") + code = message + else: + msg.add("Internal Error") + var err = newException(TLSStreamProtocolError, msg) + err.errCode = code + err + +proc raiseTLSStreamProtoError*[T](message: T) = + raise newTLSStreamProtocolError(message) + +proc tlsWriteLoop(stream: AsyncStreamWriter) {.async.} = + var wstream = cast[TLSStreamWriter](stream) + var engine: ptr SslEngineContext + var error: ref Exception + + if wstream.kind == TLSStreamKind.Server: + engine = addr wstream.scontext.eng + else: + engine = addr wstream.ccontext.eng + + wstream.state = AsyncStreamState.Running + + try: + var length: uint + while true: + var state = engine.sslEngineCurrentState() + + if (state and SSL_CLOSED) == SSL_CLOSED: + wstream.state = AsyncStreamState.Finished + break + + if (state and (SSL_RECVREC or SSL_RECVAPP)) != 0: + if not(wstream.switchToReader.isSet()): + wstream.switchToReader.fire() + + if (state and (SSL_SENDREC or SSL_SENDAPP)) == 0: + await wstream.switchToWriter.wait() + wstream.switchToWriter.clear() + # We need to refresh `state` because we just returned from readerLoop. + continue + + if (state and SSL_SENDREC) == SSL_SENDREC: + # TLS record needs to be sent over stream. + length = 0'u + var buf = sslEngineSendrecBuf(engine, length) + doAssert(length != 0 and not isNil(buf)) + var fut = awaitne wstream.wsource.write(buf, int(length)) + if fut.cancelled(): + raise fut.error + elif fut.failed(): + error = fut.error + break + sslEngineSendrecAck(engine, length) + continue + + if (state and SSL_SENDAPP) == SSL_SENDAPP: + # Application data can be sent over stream. + if not(wstream.handshaked): + wstream.stream.reader.handshaked = true + wstream.handshaked = true + if not(isNil(wstream.handshakeFut)): + wstream.handshakeFut.complete() + + var item = await wstream.queue.get() + if item.size > 0: + length = 0'u + var buf = sslEngineSendappBuf(engine, length) + let toWrite = min(int(length), item.size) + copyOut(buf, item, toWrite) + if int(length) >= item.size: + # BearSSL is ready to accept whole item size. + sslEngineSendappAck(engine, uint(item.size)) + sslEngineFlush(engine, 0) + item.future.complete() + else: + # BearSSL is not ready to accept whole item, so we will send only + # part of item and adjust offset. + item.offset = item.offset + int(length) + item.size = item.size - int(length) + wstream.queue.addFirstNoWait(item) + sslEngineSendappAck(engine, length) + continue + else: + # Zero length item means finish + wstream.state = AsyncStreamState.Finished + break + + except CancelledError: + wstream.state = AsyncStreamState.Stopped + + finally: + if wstream.state == AsyncStreamState.Stopped: + while len(wstream.queue) > 0: + let item = wstream.queue.popFirstNoWait() + if not(item.future.finished()): + item.future.complete() + elif wstream.state == AsyncStreamState.Error: + while len(wstream.queue) > 0: + let item = wstream.queue.popFirstNoWait() + if not(item.future.finished()): + item.future.fail(error) + wstream.stream = nil + +proc tlsReadLoop(stream: AsyncStreamReader) {.async.} = + var rstream = cast[TLSStreamReader](stream) + var engine: ptr SslEngineContext + + if rstream.kind == TLSStreamKind.Server: + engine = addr rstream.scontext.eng + else: + engine = addr rstream.ccontext.eng + + rstream.state = AsyncStreamState.Running + + try: + var length: uint + while true: + var state = engine.sslEngineCurrentState() + if (state and SSL_CLOSED) == SSL_CLOSED: + let err = engine.sslEngineLastError() + if err != 0: + raise newTLSStreamProtocolError(err) + rstream.state = AsyncStreamState.Stopped + break + + if (state and (SSL_SENDREC or SSL_SENDAPP)) != 0: + if not(rstream.switchToWriter.isSet()): + rstream.switchToWriter.fire() + + if (state and (SSL_RECVREC or SSL_RECVAPP)) == 0: + await rstream.switchToReader.wait() + rstream.switchToReader.clear() + # We need to refresh `state` because we just returned from writerLoop. + continue + + if (state and SSL_RECVREC) == SSL_RECVREC: + # TLS records required for further processing + length = 0'u + var buf = sslEngineRecvrecBuf(engine, length) + let res = await rstream.rsource.readOnce(buf, int(length)) + if res > 0: + sslEngineRecvrecAck(engine, uint(res)) + continue + else: + rstream.state = AsyncStreamState.Finished + break + + if (state and SSL_RECVAPP) == SSL_RECVAPP: + # Application data can be recovered. + length = 0'u + var buf = sslEngineRecvappBuf(engine, length) + await upload(addr rstream.buffer, buf, int(length)) + sslEngineRecvappAck(engine, length) + continue + + except CancelledError: + rstream.state = AsyncStreamState.Stopped + except AsyncStreamReadError: + rstream.error = getCurrentException() + rstream.state = AsyncStreamState.Error + if not(rstream.handshaked): + rstream.handshaked = true + rstream.stream.writer.handshaked = true + if not(isNil(rstream.handshakeFut)): + rstream.handshakeFut.fail(rstream.error) + rstream.switchToWriter.fire() + finally: + # Perform TLS cleanup procedure + sslEngineClose(engine) + rstream.buffer.forget() + rstream.stream = nil + +proc getSignerAlgo(xc: X509Certificate): int = + ## Get certificate's signing algorithm. + var dc: X509DecoderContext + x509DecoderInit(addr dc, nil, nil) + x509DecoderPush(addr dc, xc.data, xc.dataLen) + let err = x509DecoderLastError(addr dc) + if err != 0: + result = -1 + else: + result = int(x509DecoderGetSignerKeyType(addr dc)) + +proc newTLSClientAsyncStream*(rsource: AsyncStreamReader, + wsource: AsyncStreamWriter, + serverName: string, + bufferSize = SSL_BUFSIZE_BIDI, + minVersion = TLSVersion.TLS11, + maxVersion = TLSVersion.TLS12, + flags: set[TLSFlags] = {}): TLSAsyncStream = + ## Create new TLS asynchronous stream for outbound (client) connections + ## using reading stream ``rsource`` and writing stream ``wsource``. + ## + ## You can specify remote server name using ``serverName``, if while + ## handshake server reports different name you will get an error. If + ## ``serverName`` is empty string, remote server name checking will be + ## disabled. + ## + ## ``bufferSize`` - is SSL/TLS buffer which is used for encoding/decoding + ## incoming data. + ## + ## ``minVersion`` and ``maxVersion`` are TLS versions which will be used + ## for handshake with remote server. If server's version will be lower then + ## ``minVersion`` of bigger then ``maxVersion`` you will get an error. + ## + ## ``flags`` - custom TLS connection flags. + result = new TLSAsyncStream + var reader = new TLSStreamReader + reader.kind = TLSStreamKind.Client + var writer = new TLSStreamWriter + writer.kind = TLSStreamKind.Client + var switchToWriter = newAsyncEvent() + var switchToReader = newAsyncEvent() + reader.stream = result + writer.stream = result + reader.switchToReader = switchToReader + reader.switchToWriter = switchToWriter + writer.switchToReader = switchToReader + writer.switchToWriter = switchToWriter + result.reader = reader + result.writer = writer + reader.ccontext = addr result.ccontext + writer.ccontext = addr result.ccontext + + if TLSFlags.NoVerifyHost in flags: + sslClientInitFull(addr result.ccontext, addr result.x509, nil, 0) + initNoAnchor(addr result.xwc, addr result.x509.vtable) + sslEngineSetX509(addr result.ccontext.eng, addr result.xwc.vtable) + else: + sslClientInitFull(addr result.ccontext, addr result.x509, + unsafeAddr MozillaTrustAnchors[0], + len(MozillaTrustAnchors)) + + let size = max(SSL_BUFSIZE_BIDI, bufferSize) + result.sbuffer = newSeq[byte](size) + sslEngineSetBuffer(addr result.ccontext.eng, addr result.sbuffer[0], + uint(len(result.sbuffer)), 1) + sslEngineSetVersions(addr result.ccontext.eng, uint16(minVersion), + uint16(maxVersion)) + + if TLSFlags.NoVerifyServerName in flags: + let err = sslClientReset(addr result.ccontext, "", 0) + if err == 0: + raise newException(TLSStreamError, "Could not initialize TLS layer") + else: + if len(serverName) == 0: + raise newException(TLSStreamError, "serverName must not be empty string") + + let err = sslClientReset(addr result.ccontext, serverName, 0) + if err == 0: + raise newException(TLSStreamError, "Could not initialize TLS layer") + + init(cast[AsyncStreamWriter](result.writer), wsource, tlsWriteLoop, + bufferSize) + init(cast[AsyncStreamReader](result.reader), rsource, tlsReadLoop, + bufferSize) + +proc newTLSServerAsyncStream*(rsource: AsyncStreamReader, + wsource: AsyncStreamWriter, + privateKey: TLSPrivateKey, + certificate: TLSCertificate, + bufferSize = SSL_BUFSIZE_BIDI, + minVersion = TLSVersion.TLS11, + maxVersion = TLSVersion.TLS12, + cache: TLSSessionCache = nil, + flags: set[TLSFlags] = {}): TLSAsyncStream = + ## Create new TLS asynchronous stream for inbound (server) connections + ## using reading stream ``rsource`` and writing stream ``wsource``. + ## + ## You need to specify local private key ``privateKey`` and certificate + ## ``certificate``. + ## + ## ``bufferSize`` - is SSL/TLS buffer which is used for encoding/decoding + ## incoming data. + ## + ## ``minVersion`` and ``maxVersion`` are TLS versions which will be used + ## for handshake with remote server. If server's version will be lower then + ## ``minVersion`` of bigger then ``maxVersion`` you will get an error. + ## + ## ``flags`` - custom TLS connection flags. + if isNil(privateKey) or privateKey.kind notin {TLSKeyType.RSA, TLSKeyType.EC}: + raiseTLSStreamProtoError("Incorrect private key") + if isNil(certificate) or len(certificate.certs) == 0: + raiseTLSStreamProtoError("Incorrect certificate") + + result = new TLSAsyncStream + var reader = new TLSStreamReader + reader.kind = TLSStreamKind.Server + var writer = new TLSStreamWriter + writer.kind = TLSStreamKind.Server + var switchToWriter = newAsyncEvent() + var switchToReader = newAsyncEvent() + reader.stream = result + writer.stream = result + reader.switchToReader = switchToReader + reader.switchToWriter = switchToWriter + writer.switchToReader = switchToReader + writer.switchToWriter = switchToWriter + result.reader = reader + result.writer = writer + reader.scontext = addr result.scontext + writer.scontext = addr result.scontext + + if privateKey.kind == TLSKeyType.EC: + let algo = getSignerAlgo(certificate.certs[0]) + if algo == -1: + raiseTLSStreamProtoError("Could not decode certificate") + sslServerInitFullEc(addr result.scontext, addr certificate.certs[0], + len(certificate.certs), cuint(algo), + addr privateKey.eckey) + elif privateKey.kind == TLSKeyType.RSA: + sslServerInitFullRsa(addr result.scontext, addr certificate.certs[0], + len(certificate.certs), addr privateKey.rsakey) + + let size = max(SSL_BUFSIZE_BIDI, bufferSize) + result.sbuffer = newSeq[byte](size) + sslEngineSetBuffer(addr result.scontext.eng, addr result.sbuffer[0], + uint(len(result.sbuffer)), 1) + sslEngineSetVersions(addr result.scontext.eng, uint16(minVersion), + uint16(maxVersion)) + + if not isNil(cache): + sslServerSetCache(addr result.scontext, addr cache.context.vtable) + + if TLSFlags.EnforceServerPref in flags: + sslEngineAddFlags(addr result.scontext.eng, OPT_ENFORCE_SERVER_PREFERENCES) + if TLSFlags.NoRenegotiation in flags: + sslEngineAddFlags(addr result.scontext.eng, OPT_NO_RENEGOTIATION) + if TLSFlags.TolerateNoClientAuth in flags: + sslEngineAddFlags(addr result.scontext.eng, OPT_TOLERATE_NO_CLIENT_AUTH) + if TLSFlags.FailOnAlpnMismatch in flags: + sslEngineAddFlags(addr result.scontext.eng, OPT_FAIL_ON_ALPN_MISMATCH) + + let err = sslServerReset(addr result.scontext) + if err == 0: + raise newException(TLSStreamError, "Could not initialize TLS layer") + + init(cast[AsyncStreamWriter](result.writer), wsource, tlsWriteLoop, + bufferSize) + init(cast[AsyncStreamReader](result.reader), rsource, tlsReadLoop, + bufferSize) + +proc copyKey(src: RsaPrivateKey): TLSPrivateKey = + ## Creates copy of RsaPrivateKey ``src``. + var offset = 0 + let keySize = src.plen + src.qlen + src.dplen + src.dqlen + src.iqlen + result = TLSPrivateKey(kind: TLSKeyType.RSA) + result.storage = newSeq[byte](keySize) + copyMem(addr result.storage[offset], src.p, src.plen) + result.rsakey.p = cast[ptr cuchar](addr result.storage[offset]) + result.rsakey.plen = src.plen + offset = offset + src.plen + copyMem(addr result.storage[offset], src.q, src.qlen) + result.rsakey.q = cast[ptr cuchar](addr result.storage[offset]) + result.rsakey.qlen = src.qlen + offset = offset + src.qlen + copyMem(addr result.storage[offset], src.dp, src.dplen) + result.rsakey.dp = cast[ptr cuchar](addr result.storage[offset]) + result.rsakey.dplen = src.dplen + offset = offset + src.dplen + copyMem(addr result.storage[offset], src.dq, src.dqlen) + result.rsakey.dq = cast[ptr cuchar](addr result.storage[offset]) + result.rsakey.dqlen = src.dqlen + offset = offset + src.dqlen + copyMem(addr result.storage[offset], src.iq, src.iqlen) + result.rsakey.iq = cast[ptr cuchar](addr result.storage[offset]) + result.rsakey.iqlen = src.iqlen + result.rsakey.nBitlen = src.nBitlen + +proc copyKey(src: EcPrivateKey): TLSPrivateKey = + ## Creates copy of EcPrivateKey ``src``. + var offset = 0 + let keySize = src.xlen + result = TLSPrivateKey(kind: TLSKeyType.EC) + result.storage = newSeq[byte](keySize) + copyMem(addr result.storage[offset], src.x, src.xlen) + result.eckey.x = cast[ptr cuchar](addr result.storage[offset]) + result.eckey.xlen = src.xlen + result.eckey.curve = src.curve + +proc init*(tt: typedesc[TLSPrivateKey], data: openarray[byte]): TLSPrivateKey = + ## Initialize TLS private key from array of bytes ``data``. + ## + ## This procedure initializes private key using raw, DER-encoded format, + ## or wrapped in an unencrypted PKCS#8 archive (again DER-encoded). + var ctx: SkeyDecoderContext + if len(data) == 0: + raiseTLSStreamProtoError("Incorrect private key") + skeyDecoderInit(addr ctx) + skeyDecoderPush(addr ctx, cast[pointer](unsafeAddr data[0]), len(data)) + let err = skeyDecoderLastError(addr ctx) + if err != 0: + raiseTLSStreamProtoError(err) + let keyType = skeyDecoderKeyType(addr ctx) + if keyType == KEYTYPE_RSA: + result = copyKey(ctx.key.rsa) + elif keyType == KEYTYPE_EC: + result = copyKey(ctx.key.ec) + else: + raiseTLSStreamProtoError("Unknown key type (" & $keyType & ")") + +proc pemDecode*(data: openarray[char]): seq[PEMElement] = + ## Decode PEM encoded string and get array of binary blobs. + if len(data) == 0: + raiseTLSStreamProtoError("Empty PEM message") + var ctx: PemDecoderContext + var pctx = new PEMContext + result = newSeq[PEMElement]() + pemDecoderInit(addr ctx) + + proc itemAppend(ctx: pointer, pbytes: pointer, nbytes: int) {.cdecl.} = + var p = cast[PEMContext](ctx) + var o = len(p.data) + p.data.setLen(o + nbytes) + copyMem(addr p.data[o], pbytes, nbytes) + + var length = len(data) + var offset = 0 + var inobj = false + var elem: PEMElement + + while length > 0: + var tlen = pemDecoderPush(addr ctx, + cast[pointer](unsafeAddr data[offset]), length) + offset = offset + tlen + length = length - tlen + + let event = pemDecoderEvent(addr ctx) + if event == PEM_BEGIN_OBJ: + inobj = true + elem.name = $pemDecoderName(addr ctx) + pctx.data = newSeq[byte]() + pemDecoderSetdest(addr ctx, itemAppend, cast[pointer](pctx)) + elif event == PEM_END_OBJ: + if inobj: + elem.data = pctx.data + result.add(elem) + inobj = false + else: + break + else: + raiseTLSStreamProtoError("Invalid PEM encoding") + +proc init*(tt: typedesc[TLSPrivateKey], data: openarray[char]): TLSPrivateKey = + ## Initialize TLS private key from string ``data``. + ## + ## This procedure initializes private key using unencrypted PKCS#8 PEM + ## encoded string. + ## + ## Note that PKCS#1 PEM encoded objects are not supported. + var items = pemDecode(data) + for item in items: + if item.name == "PRIVATE KEY": + result = TLSPrivateKey.init(item.data) + break + if isNil(result): + raiseTLSStreamProtoError("Could not find private key") + +proc init*(tt: typedesc[TLSCertificate], + data: openarray[char]): TLSCertificate = + ## Initialize TLS certificates from string ``data``. + ## + ## This procedure initializes array of certificates from PEM encoded string. + var items = pemDecode(data) + result = new TLSCertificate + for item in items: + if item.name == "CERTIFICATE" and len(item.data) > 0: + let offset = len(result.storage) + result.storage.add(item.data) + let cert = X509Certificate( + data: cast[ptr cuchar](addr result.storage[offset]), + dataLen: len(item.data) + ) + let res = getSignerAlgo(cert) + if res == -1: + raiseTLSStreamProtoError("Could not decode certificate") + elif res != KEYTYPE_RSA and res != KEYTYPE_EC: + raiseTLSStreamProtoError("Unsupported signing key type in certificate") + result.certs.add(cert) + if len(result.storage) == 0: + raiseTLSStreamProtoError("Could not find any certificates") + +proc init*(tt: typedesc[TLSSessionCache], size: int = 4096): TLSSessionCache = + ## Create new TLS session cache with size ``size``. + ## + ## One cached item is near 100 bytes size. + result = new TLSSessionCache + var rsize = min(size, 4096) + result.storage = newSeq[byte](rsize) + sslSessionCacheLruInit(addr result.context, addr result.storage[0], rsize) + +proc handshake*(rws: SomeTLSStreamType): Future[void] = + ## Wait until initial TLS handshake will be successfully performed. + var retFuture = newFuture[void]("tlsstream.handshake") + when rws is TLSStreamReader: + if rws.handshaked: + retFuture.complete() + else: + rws.handshakeFut = retFuture + rws.stream.writer.handshakeFut = retFuture + elif rws is TLSStreamWriter: + if rws.handshaked: + retFuture.complete() + else: + rws.handshakeFut = retFuture + rws.stream.reader.handshakeFut = retFuture + elif rws is TLSAsyncStream: + if rws.reader.handshaked: + retFuture.complete() + else: + rws.reader.handshakeFut = retFuture + rws.writer.handshakeFut = retFuture + return retFuture diff --git a/tests/testasyncstream.nim b/tests/testasyncstream.nim index c6dee9f..7b52df5 100644 --- a/tests/testasyncstream.nim +++ b/tests/testasyncstream.nim @@ -6,7 +6,63 @@ # Apache License, version 2.0, (LICENSE-APACHEv2) # MIT license (LICENSE-MIT) import strutils, unittest, os -import ../chronos +import ../chronos, ../chronos/streams/tlsstream + +const SelfSignedRsaKey = """ +-----BEGIN PRIVATE KEY----- +MIIEvgIBADANBgkqhkiG9w0BAQEFAASCBKgwggSkAgEAAoIBAQDSXcKMR6zIIHSy ++UUjgyIlJWzEu3JFCEN+qBpjmwbuae3GNed0YgPels6AKe0AlWQNgpqoWfMQrWco +qNf9dcm9OIdUK5FGMWYC8mJu6OjwnexSXt0R/k07lc5ePPDUbGvkbvgDyHzl1Ynj +jYhe/2ujL0E99/KAuA7cvRILEl4rLpnngE+MMYNrFWmsSDzC+w0Hv5Fuyc8wZoSy +nzODvojvFXTva9Nx4LxPF1W79WidwJHwrJghVlGUUhSeRGLHQWRY+954rj6TavoJ +6gLYSHx6ELnRkFMbqxRrCJ0wAbrV8SwcPHjKSc6LQGRMzBRqxT45DoKmzH9Pwdnw +6P0PGJyPAgMBAAECggEAd5Ck39hpIwIXchXtrwZ8ZMKFtLeZdhUBT7656P0XDnEU +nQDMQcDn1B7A1eV+eENwr6EYyDD/zu3P4TM+OCg3dp3nhPaSRmQTR/995O3qX8BS +rmqOmgiA2yoFNljKxOGu3RIZUwUjv/oDulsaNGxWQFS+bzs7EOAMSngIBlT1QvLc +1etvGOW6hc4nzacrXpzhzWem6EzOabPZBmIy0DDz2ARlND1YSd2P5IMpOv0terNF +ZwYl5SZ6Rnbv6GJWKl5IIpJOOtRtwyIhKNU/bd835vOfW07aaHVAT6GNYlyEoGWT +36UjOyl3YxSLNvQOeIz4Y0n+vupBu/YL+mFtQdxJgQKBgQDtjEi0cNlt+Vz5sp4q +wAVBJ/6h7hu3KJkY+xDpdrLyFTcxttKM9q/dR8V2bEaqMYT/mTvADfmW2BBcfi2J +3VdR1lQ5pXeKAuxt1/Vc+Q4UCnX5OX7UXpP9aSetDdo5FXUC9X2H0hO0BqOcc+h2 +khVyXjKt6TdBwV94dP9bmQp9QQKBgQDitPVqRHGepYBYasScyTUXPp7vL9T7PMSu +PGjqEkwvauhICpUbWpE/j8M0UXk64zwSmOYwQ37uiPpws88GL57oy1ZQZnhF/Hi3 +tM00Mn4x0xbyONbWu/AcFIZwSeSL6QhHYfeyVj7Jb/lqUg8sMiGmO25JjlAQBfTb +vvBgEpcVzwKBgQCwgz87JWfLejIGMR2qcoj1A30IYmAh1377uwO0F0mc7PrYbBtE +N8IyUTR/bLGNocJME1b8vOWrmt19fRzlhp1t6C8prrSGzulULdbawQ4fAi7rhDek +Iqsg8FRVGSgAptsN2dDvbcDKUuycQtyHzsE0/J338IXozIHehkGBlNTggQKBgQCF +RDTj5BoaVVWuJA0x0UGJSYFqP2bmzWEcv1w5BMqOMT0cZEQkkUfC4oKwdZhbGosM +r57ZDkRGenUl3T08eK/kTuuNVb8r/O8Fpp3eKjRum5TojKsWDeJmz1X8GiPkbvcz +5w4RYouEJHOsoVJT+6A2NMdvK946nRXEO2jYQPVZlwKBgBro8qGm0+T0T2xNP21q +IzjP/EHT7iIkM5kiUCc2bPIrfzAGxXImakDzd6AgpgxhhficJOpp792Upe/b/Hwy +bwfmbdWlT7/hPCnlVVH2dgO/ysDyEfxPigBMd+MmucRm6fzGIU7XSQw4KJqH4vQN +9IASWlgzyQ1RytAduzRuepzB +-----END PRIVATE KEY----- +""" + +const SelfSignedRsaCert = """ +-----BEGIN CERTIFICATE----- +MIIDkzCCAnugAwIBAgIUEFovLJkPSn4T8BBZMYBrXPDajF0wDQYJKoZIhvcNAQEL +BQAwWTELMAkGA1UEBhMCQVUxEzARBgNVBAgMClNvbWUtU3RhdGUxITAfBgNVBAoM +GEludGVybmV0IFdpZGdpdHMgUHR5IEx0ZDESMBAGA1UEAwwJbG9jYWxob3N0MB4X +DTE5MTAxMjA1NDQ0N1oXDTIwMTAxMTA1NDQ0OFowWTELMAkGA1UEBhMCQVUxEzAR +BgNVBAgMClNvbWUtU3RhdGUxITAfBgNVBAoMGEludGVybmV0IFdpZGdpdHMgUHR5 +IEx0ZDESMBAGA1UEAwwJbG9jYWxob3N0MIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8A +MIIBCgKCAQEA0l3CjEesyCB0svlFI4MiJSVsxLtyRQhDfqgaY5sG7mntxjXndGID +3pbOgCntAJVkDYKaqFnzEK1nKKjX/XXJvTiHVCuRRjFmAvJibujo8J3sUl7dEf5N +O5XOXjzw1Gxr5G74A8h85dWJ442IXv9roy9BPffygLgO3L0SCxJeKy6Z54BPjDGD +axVprEg8wvsNB7+RbsnPMGaEsp8zg76I7xV072vTceC8TxdVu/VoncCR8KyYIVZR +lFIUnkRix0FkWPveeK4+k2r6CeoC2Eh8ehC50ZBTG6sUawidMAG61fEsHDx4yknO +i0BkTMwUasU+OQ6Cpsx/T8HZ8Oj9DxicjwIDAQABo1MwUTAdBgNVHQ4EFgQUMM+1 +FZ6KmN2eCJfDxY+8xa1JKnYwHwYDVR0jBBgwFoAUMM+1FZ6KmN2eCJfDxY+8xa1J +KnYwDwYDVR0TAQH/BAUwAwEB/zANBgkqhkiG9w0BAQsFAAOCAQEASglh98fXvPwA +KMaEezCUqTeE7DehLlhZ8n6ETKaBDcP3JR4+KTEh9y7gRGJ7DXGFAYfU3itgjyZo +kbgpZIhrTKyYCAsF96Q1mHf/cBQ96UXr0U0SbYXSSJFeeMthMvki556dJZajtxcA +9xR/U0PxPjhC9NIfpVSAv/7ocnXh73qOiFHoN9Cr2smzcGPxsifys2iv1qm5LwDr +Dx5h/RfyfuAjS8e1ZCAhS++PYjb8BX54NilW2lTYF3pwpXL8znc4eBmklBkw5L60 +99jrK7LSQT9Nk8Mf9t4P/77N4hXCqsHIxZIqJlbdgdKfvBF3vRomxm3/aWtGlTVD +vvzZPnlYfQ== +-----END CERTIFICATE----- +""" suite "AsyncStream test suite": test "AsyncStream(StreamTransport) readExactly() test": @@ -506,3 +562,96 @@ suite "ChunkedStream test suite": break result = res check waitFor(testVectors2(initTAddress("127.0.0.1:46001"))) == true + + test "ChunkedStream leaks test": + check: + getTracker("async.stream.reader").isLeaked() == false + getTracker("async.stream.writer").isLeaked() == false + getTracker("stream.server").isLeaked() == false + getTracker("stream.transport").isLeaked() == false + +suite "TLSStream test suite": + const HttpHeadersMark = @[byte(0x0D), byte(0x0A), byte(0x0D), byte(0x0A)] + test "Simple HTTPS connection": + proc headerClient(address: TransportAddress, + name: string): Future[bool] {.async.} = + var mark = "HTTP/1.1 " + var buffer = newSeq[byte](8192) + var transp = await connect(address) + var reader = newAsyncStreamReader(transp) + var writer = newAsyncStreamWriter(transp) + var tlsstream = newTLSClientAsyncStream(reader, writer, name) + + await tlsstream.writer.write("GET / HTTP/1.1\r\nHost: " & name & + "\r\nConnection: close\r\n\r\n") + var readFut = tlsstream.reader.readUntil(addr buffer[0], len(buffer), + HttpHeadersMark) + let res = await withTimeout(readFut, 5.seconds) + if res: + var length = readFut.read() + buffer.setLen(length) + if len(buffer) > len(mark): + if equalMem(addr buffer[0], addr mark[0], len(mark)): + result = true + + await tlsstream.reader.closeWait() + await tlsstream.writer.closeWait() + await reader.closeWait() + await writer.closeWait() + await transp.closeWait() + + let res = waitFor(headerClient(resolveTAddress("www.google.com:443")[0], + "www.google.com")) + check res == true + + proc checkSSLServer(address: TransportAddress, + pemkey, pemcert: string): Future[bool] {.async.} = + var key: TLSPrivateKey + var cert: TLSCertificate + let testMessage = "TEST MESSAGE" + + proc serveClient(server: StreamServer, + transp: StreamTransport) {.async.} = + var reader = newAsyncStreamReader(transp) + var writer = newAsyncStreamWriter(transp) + var sstream = newTLSServerAsyncStream(reader, writer, key, cert) + await handshake(sstream) + await sstream.writer.write(testMessage & "\r\n") + await sstream.writer.closeWait() + await sstream.reader.closeWait() + await reader.closeWait() + await writer.closeWait() + await transp.closeWait() + server.stop() + server.close() + + key = TLSPrivateKey.init(pemkey) + cert = TLSCertificate.init(pemcert) + + var server = createStreamServer(address, serveClient, {ReuseAddr}) + server.start() + var conn = await connect(address) + var creader = newAsyncStreamReader(conn) + var cwriter = newAsyncStreamWriter(conn) + # We are using self-signed certificate + let flags = {NoVerifyHost, NoVerifyServerName} + var cstream = newTLSClientAsyncStream(creader, cwriter, "", flags = flags) + let res = await cstream.reader.readLine() + await cstream.reader.closeWait() + await cstream.writer.closeWait() + await creader.closeWait() + await cwriter.closeWait() + await conn.closeWait() + await server.join() + result = res == testMessage + + test "Simple server with RSA self-signed certificate": + let res = waitFor(checkSSLServer(initTAddress("127.0.0.1:43808"), + SelfSignedRsaKey, SelfSignedRsaCert)) + check res == true + test "TLSStream leaks test": + check: + getTracker("async.stream.reader").isLeaked() == false + getTracker("async.stream.writer").isLeaked() == false + getTracker("stream.server").isLeaked() == false + getTracker("stream.transport").isLeaked() == false