From a92ad6d2d2cd27240da8784089c0d160ef9c6e03 Mon Sep 17 00:00:00 2001 From: cheatfate Date: Wed, 16 Oct 2019 09:01:52 +0300 Subject: [PATCH] Add TLS inbound stream. Fix some review comments. --- chronos/streams/asyncstream.nim | 15 +- chronos/streams/tlsstream.nim | 475 ++++++++++++++++++++++++++------ tests/testasyncstream.nim | 105 ++++++- 3 files changed, 505 insertions(+), 90 deletions(-) diff --git a/chronos/streams/asyncstream.nim b/chronos/streams/asyncstream.nim index c1a1554..3eeb495 100644 --- a/chronos/streams/asyncstream.nim +++ b/chronos/streams/asyncstream.nim @@ -154,10 +154,8 @@ proc upload*(sb: ptr AsyncBuffer, pbytes: ptr byte, copyMem(addr sb[].buffer[sb.offset], pbytes, size) sb[].offset = sb[].offset + size length = length - size - - if length == 0: - # We notify consumers that new data is available. - sb[].forget() + # We notify consumers that new data is available. + sb[].forget() template toDataOpenArray*(sb: AsyncBuffer): auto = toOpenArray(sb.buffer, 0, sb.offset - 1) @@ -165,6 +163,15 @@ template toDataOpenArray*(sb: AsyncBuffer): auto = template toBufferOpenArray*(sb: AsyncBuffer): auto = toOpenArray(sb.buffer, sb.offset, len(sb.buffer) - 1) +template harvestItem*(dest: pointer, item: WriteItem, length: int) = + if item.kind == Pointer: + let p = cast[pointer](cast[uint](item.data1) + uint(item.offset)) + copyMem(dest, p, length) + elif item.kind == Sequence: + copyMem(dest, unsafeAddr item.data2[item.offset], length) + elif item.kind == String: + copyMem(dest, unsafeAddr item.data3[item.offset], length) + proc newAsyncStreamReadError(p: ref Exception): ref Exception {.inline.} = var w = newException(AsyncStreamReadError, "Read stream failed") w.msg = w.msg & ", originated from [" & $p.name & "] " & p.msg diff --git a/chronos/streams/tlsstream.nim b/chronos/streams/tlsstream.nim index 18f7c5d..a31e603 100644 --- a/chronos/streams/tlsstream.nim +++ b/chronos/streams/tlsstream.nim @@ -7,11 +7,11 @@ # Apache License, version 2.0, (LICENSE-APACHEv2) # MIT license (LICENSE-MIT) -## This module implements TLS stream reading and writing. +## This module implements Transport Layer Security (TLS) stream. This module +## uses sources of BearSSL by Thomas Pornin. import bearssl, bearssl/cacert import ../asyncloop, ../timer, ../asyncsync import asyncstream, ../transports/stream, ../transports/common -import strutils type TLSStreamKind {.pure.} = enum @@ -21,46 +21,79 @@ type TLS10 = 0x0301, TLS11 = 0x0302, TLS12 = 0x0303 TLSFlags* {.pure.} = enum - NoVerifyHost, # Client: Skip remote certificate check - NoVerifySN, # Client: Skip Server Name Indication (SNI) check - NoRenegotiation, # Server: Reject renegotiations requests - NoClientAuth, # Server: Disable strict client authentication - FailOnAlpnMismatch # Server: Fail on application protocol mismatch + NoVerifyHost, # Client: Skip remote certificate check + NoVerifyServerName, # Client: Skip Server Name Indication (SNI) check + EnforceServerPref, # Server: Enforce server preferences + NoRenegotiation, # Server: Reject renegotiations requests + TolerateNoClientAuth, # Server: Disable strict client authentication + FailOnAlpnMismatch # Server: Fail on application protocol mismatch -type - TlsStreamWriter* = ref object of AsyncStreamWriter - case kind: TlsStreamKind - of TlsStreamKind.Client: + TLSKeyType {.pure.} = enum + RSA, EC + + TLSPrivateKey* = ref object + case kind: TLSKeyType + of RSA: + rsakey: RsaPrivateKey + of EC: + eckey: EcPrivateKey + storage: seq[byte] + + TLSCertificate* = ref object + certs: seq[X509Certificate] + storage: seq[byte] + + TLSSessionCache* = ref object + storage: seq[byte] + context: SslSessionCacheLru + + PEMElement* = object + name*: string + data*: seq[byte] + + PEMContext = ref object + data: seq[byte] + + TLSStreamWriter* = ref object of AsyncStreamWriter + case kind: TLSStreamKind + of TLSStreamKind.Client: ccontext: ptr SslClientContext - of TlsStreamKind.Server: + of TLSStreamKind.Server: scontext: ptr SslServerContext - stream*: TlsAsyncStream + stream*: TLSAsyncStream switchToReader*: AsyncEvent switchToWriter*: AsyncEvent + handshaked*: bool + handshakeFut*: Future[void] - TlsStreamReader* = ref object of AsyncStreamReader - case kind: TlsStreamKind - of TlsStreamKind.Client: + TLSStreamReader* = ref object of AsyncStreamReader + case kind: TLSStreamKind + of TLSStreamKind.Client: ccontext: ptr SslClientContext - of TlsStreamKind.Server: + of TLSStreamKind.Server: scontext: ptr SslServerContext - stream*: TlsAsyncStream + stream*: TLSAsyncStream switchToReader*: AsyncEvent switchToWriter*: AsyncEvent + handshaked*: bool + handshakeFut*: Future[void] - TlsAsyncStream* = ref object of RootRef + TLSAsyncStream* = ref object of RootRef xwc*: X509NoAnchorContext - context*: SslClientContext + ccontext*: SslClientContext + scontext*: SslServerContext sbuffer*: seq[byte] x509*: X509MinimalContext - reader*: TlsStreamReader - writer*: TlsStreamWriter + reader*: TLSStreamReader + writer*: TLSStreamWriter - TlsStreamError* = object of CatchableError - TlsStreamProtocolError* = object of TlsStreamError + SomeTLSStreamType* = TLSStreamReader|TLSStreamWriter|TLSAsyncStream + + TLSStreamError* = object of CatchableError + TLSStreamProtocolError* = object of TLSStreamError errCode*: int -template newTlsStreamProtocolError[T](message: T): ref Exception = +template newTLSStreamProtocolError[T](message: T): ref Exception = var msg = "" var code = 0 when T is string: @@ -73,29 +106,29 @@ template newTlsStreamProtocolError[T](message: T): ref Exception = code = message else: msg.add("Internal Error") - var err = newException(TlsStreamProtocolError, msg) + var err = newException(TLSStreamProtocolError, msg) err.errCode = code err -# proc raiseTlsStreamProtoError*[T](message: T) = -# raise newTlsStreamProtocolError(message) +proc raiseTLSStreamProtoError*[T](message: T) = + raise newTLSStreamProtocolError(message) -proc getStringState*(state: cuint): string = - var n = newSeq[string]() - if (state and SSL_CLOSED) == SSL_CLOSED: - n.add("Closed") - if (state and SSL_SENDREC) == SSL_SENDREC: - n.add("SendRec") - if (state and SSL_RECVREC) == SSL_RECVREC: - n.add("RecvRec") - if (state and SSL_SENDAPP) == SSL_SENDAPP: - n.add("SendApp") - if (state and SSL_RECVAPP) == SSL_RECVAPP: - n.add("RecvApp") - result = "{" & n.join(", ") & "} number (" & $state & ")" +# proc getStringState*(state: cuint): string = +# var n = newSeq[string]() +# if (state and SSL_CLOSED) == SSL_CLOSED: +# n.add("Closed") +# if (state and SSL_SENDREC) == SSL_SENDREC: +# n.add("SendRec") +# if (state and SSL_RECVREC) == SSL_RECVREC: +# n.add("RecvRec") +# if (state and SSL_SENDAPP) == SSL_SENDAPP: +# n.add("SendApp") +# if (state and SSL_RECVAPP) == SSL_RECVAPP: +# n.add("RecvApp") +# result = "{" & n.join(", ") & "} number (" & $state & ")" proc tlsWriteLoop(stream: AsyncStreamWriter) {.async.} = - var wstream = cast[TlsStreamWriter](stream) + var wstream = cast[TLSStreamWriter](stream) var engine: ptr SslEngineContext var error: ref Exception @@ -139,31 +172,26 @@ proc tlsWriteLoop(stream: AsyncStreamWriter) {.async.} = if (state and SSL_SENDAPP) == SSL_SENDAPP: # Application data can be sent over stream. + if not(wstream.handshaked): + wstream.stream.reader.handshaked = true + wstream.handshaked = true + if not(isNil(wstream.handshakeFut)): + wstream.handshakeFut.complete() + var item = await wstream.queue.get() if item.size > 0: length = 0'u var buf = sslEngineSendappBuf(engine, length) let toWrite = min(int(length), item.size) - + harvestItem(buf, item, toWrite) if int(length) >= item.size: - if item.kind == Pointer: - let p = cast[pointer](cast[uint](item.data1) + uint(item.offset)) - copyMem(buf, p, item.size) - elif item.kind == Sequence: - copyMem(buf, addr item.data2[item.offset], item.size) - elif item.kind == String: - copyMem(buf, addr item.data3[item.offset], item.size) + # BearSSL is ready to accept whole item size. sslEngineSendappAck(engine, uint(item.size)) sslEngineFlush(engine, 0) item.future.complete() else: - if item.kind == Pointer: - let p = cast[pointer](cast[uint](item.data1) + uint(item.offset)) - copyMem(buf, p, length) - elif item.kind == Sequence: - copyMem(buf, addr item.data2[item.offset], length) - elif item.kind == String: - copyMem(buf, addr item.data3[item.offset], length) + # BearSSL is not ready to accept whole item, so we will send only + # part of item and adjust offset. item.offset = item.offset + int(length) item.size = item.size - int(length) wstream.queue.addFirstNoWait(item) @@ -188,13 +216,14 @@ proc tlsWriteLoop(stream: AsyncStreamWriter) {.async.} = let item = wstream.queue.popFirstNoWait() if not(item.future.finished()): item.future.fail(error) + wstream.switchToReader.fire() wstream.stream = nil proc tlsReadLoop(stream: AsyncStreamReader) {.async.} = - var rstream = cast[TlsStreamReader](stream) + var rstream = cast[TLSStreamReader](stream) var engine: ptr SslEngineContext - if rstream.kind == TlsStreamKind.Server: + if rstream.kind == TLSStreamKind.Server: engine = addr rstream.scontext.eng else: engine = addr rstream.ccontext.eng @@ -205,12 +234,17 @@ proc tlsReadLoop(stream: AsyncStreamReader) {.async.} = var length: uint while true: var state = engine.sslEngineCurrentState() - if (state and SSL_CLOSED) == SSL_CLOSED: let err = engine.sslEngineLastError() if err != 0: - rstream.error = newTlsStreamProtocolError(err) + rstream.error = newTLSStreamProtocolError(err) rstream.state = AsyncStreamState.Error + if not(rstream.handshaked): + rstream.handshaked = true + rstream.stream.writer.handshaked = true + if not(isNil(rstream.handshakeFut)): + rstream.handshakeFut.fail(rstream.error) + rstream.switchToWriter.fire() break else: rstream.state = AsyncStreamState.Stopped @@ -234,6 +268,12 @@ proc tlsReadLoop(stream: AsyncStreamReader) {.async.} = if resFut.failed(): rstream.error = resFut.error rstream.state = AsyncStreamState.Error + if not(rstream.handshaked): + rstream.handshaked = true + rstream.stream.writer.handshaked = true + if not(isNil(rstream.handshakeFut)): + rstream.handshakeFut.fail(rstream.error) + rstream.switchToWriter.fire() break let res = resFut.read() if res > 0: @@ -257,20 +297,30 @@ proc tlsReadLoop(stream: AsyncStreamReader) {.async.} = finally: # Perform TLS cleanup procedure sslEngineClose(engine) - # Becase tlsWriteLoop() is ephemeral, but we still need to keep stream state - # consistent. rstream.buffer.forget() + rstream.switchToWriter.fire() rstream.stream = nil -proc newTlsClientAsyncStream*(rsource: AsyncStreamReader, +proc getSignerAlgo(xc: X509Certificate): int = + ## Get certificate's signing algorithm. + var dc: X509DecoderContext + x509DecoderInit(addr dc, nil, nil) + x509DecoderPush(addr dc, xc.data, xc.dataLen) + let err = x509DecoderLastError(addr dc) + if err != 0: + result = -1 + else: + result = int(x509DecoderGetSignerKeyType(addr dc)) + +proc newTLSClientAsyncStream*(rsource: AsyncStreamReader, wsource: AsyncStreamWriter, serverName: string = "", bufferSize = SSL_BUFSIZE_BIDI, minVersion = TLSVersion.TLS11, maxVersion = TLSVersion.TLS12, - flags: set[TlsFlags] = {}): TlsAsyncStream = - ## Create new TLS asynchronous stream using reading stream ``rsource``, - ## writing stream ``wsource``. + flags: set[TLSFlags] = {}): TLSAsyncStream = + ## Create new TLS asynchronous stream for outbound (client) connections + ## using reading stream ``rsource`` and writing stream ``wsource``. ## ## You can specify remote server name using ``serverName``, if while ## handshake server reports different name you will get an error. If @@ -285,11 +335,11 @@ proc newTlsClientAsyncStream*(rsource: AsyncStreamReader, ## ``minVersion`` of bigger then ``maxVersion`` you will get an error. ## ## ``flags`` - custom TLS connection flags. - result = new TlsAsyncStream - var reader = new TlsStreamReader - reader.kind = TlsStreamKind.Client - var writer = new TlsStreamWriter - writer.kind = TlsStreamKind.Client + result = new TLSAsyncStream + var reader = new TLSStreamReader + reader.kind = TLSStreamKind.Client + var writer = new TLSStreamWriter + writer.kind = TLSStreamKind.Client var switchToWriter = newAsyncEvent() var switchToReader = newAsyncEvent() reader.stream = result @@ -300,35 +350,292 @@ proc newTlsClientAsyncStream*(rsource: AsyncStreamReader, writer.switchToWriter = switchToWriter result.reader = reader result.writer = writer - reader.ccontext = addr result.context - writer.ccontext = addr result.context + reader.ccontext = addr result.ccontext + writer.ccontext = addr result.ccontext if TLSFlags.NoVerifyHost in flags: - sslClientInitFull(addr result.context, addr result.x509, nil, 0) + sslClientInitFull(addr result.ccontext, addr result.x509, nil, 0) initNoAnchor(addr result.xwc, addr result.x509.vtable) - sslEngineSetX509(addr result.context.eng, addr result.xwc.vtable) + sslEngineSetX509(addr result.ccontext.eng, addr result.xwc.vtable) else: - sslClientInitFull(addr result.context, addr result.x509, + sslClientInitFull(addr result.ccontext, addr result.x509, unsafeAddr MozillaTrustAnchors[0], len(MozillaTrustAnchors)) let size = max(SSL_BUFSIZE_BIDI, bufferSize) result.sbuffer = newSeq[byte](size) - sslEngineSetBuffer(addr result.context.eng, addr result.sbuffer[0], + sslEngineSetBuffer(addr result.ccontext.eng, addr result.sbuffer[0], uint(len(result.sbuffer)), 1) - sslEngineSetVersions(addr result.context.eng, uint16(minVersion), + sslEngineSetVersions(addr result.ccontext.eng, uint16(minVersion), uint16(maxVersion)) - if TLSFlags.NoVerifySN in flags: - let err = sslClientReset(addr result.context, "", 0) + if TLSFlags.NoVerifyServerName in flags: + let err = sslClientReset(addr result.ccontext, "", 0) if err == 0: - raise newException(TlsStreamError, "Could not initialize TLS layer") + raise newException(TLSStreamError, "Could not initialize TLS layer") else: - let err = sslClientReset(addr result.context, serverName, 0) + let err = sslClientReset(addr result.ccontext, serverName, 0) if err == 0: - raise newException(TlsStreamError, "Could not initialize TLS layer") + raise newException(TLSStreamError, "Could not initialize TLS layer") init(cast[AsyncStreamWriter](result.writer), wsource, tlsWriteLoop, bufferSize) init(cast[AsyncStreamReader](result.reader), rsource, tlsReadLoop, bufferSize) + +proc newTLSServerAsyncStream*(rsource: AsyncStreamReader, + wsource: AsyncStreamWriter, + privateKey: TLSPrivateKey, + certificate: TLSCertificate, + bufferSize = SSL_BUFSIZE_BIDI, + minVersion = TLSVersion.TLS11, + maxVersion = TLSVersion.TLS12, + cache: TLSSessionCache = nil, + flags: set[TLSFlags] = {}): TLSAsyncStream = + ## Create new TLS asynchronous stream for inbound (server) connections + ## using reading stream ``rsource`` and writing stream ``wsource``. + ## + ## You need to specify local private key ``privateKey`` and certificate + ## ``certificate``. + ## + ## ``bufferSize`` - is SSL/TLS buffer which is used for encoding/decoding + ## incoming data. + ## + ## ``minVersion`` and ``maxVersion`` are TLS versions which will be used + ## for handshake with remote server. If server's version will be lower then + ## ``minVersion`` of bigger then ``maxVersion`` you will get an error. + ## + ## ``flags`` - custom TLS connection flags. + if isNil(privateKey) or privateKey.kind notin {TLSKeyType.RSA, TLSKeyType.EC}: + raiseTLSStreamProtoError("Incorrect private key") + if isNil(certificate) or len(certificate.certs) == 0: + raiseTLSStreamProtoError("Incorrect certificate") + + result = new TLSAsyncStream + var reader = new TLSStreamReader + reader.kind = TLSStreamKind.Server + var writer = new TLSStreamWriter + writer.kind = TLSStreamKind.Server + var switchToWriter = newAsyncEvent() + var switchToReader = newAsyncEvent() + reader.stream = result + writer.stream = result + reader.switchToReader = switchToReader + reader.switchToWriter = switchToWriter + writer.switchToReader = switchToReader + writer.switchToWriter = switchToWriter + result.reader = reader + result.writer = writer + reader.scontext = addr result.scontext + writer.scontext = addr result.scontext + + if privateKey.kind == TLSKeyType.EC: + let algo = getSignerAlgo(certificate.certs[0]) + if algo == -1: + raiseTLSStreamProtoError("Could not decode certificate") + sslServerInitFullEc(addr result.scontext, addr certificate.certs[0], + len(certificate.certs), cuint(algo), + addr privateKey.eckey) + elif privateKey.kind == TLSKeyType.RSA: + sslServerInitFullRsa(addr result.scontext, addr certificate.certs[0], + len(certificate.certs), addr privateKey.rsakey) + + let size = max(SSL_BUFSIZE_BIDI, bufferSize) + result.sbuffer = newSeq[byte](size) + sslEngineSetBuffer(addr result.scontext.eng, addr result.sbuffer[0], + uint(len(result.sbuffer)), 1) + sslEngineSetVersions(addr result.scontext.eng, uint16(minVersion), + uint16(maxVersion)) + + if not isNil(cache): + sslServerSetCache(addr result.scontext, addr cache.context.vtable) + + if TLSFlags.EnforceServerPref in flags: + sslEngineAddFlags(addr result.scontext.eng, OPT_ENFORCE_SERVER_PREFERENCES) + if TLSFlags.NoRenegotiation in flags: + sslEngineAddFlags(addr result.scontext.eng, OPT_NO_RENEGOTIATION) + if TLSFlags.TolerateNoClientAuth in flags: + sslEngineAddFlags(addr result.scontext.eng, OPT_TOLERATE_NO_CLIENT_AUTH) + if TLSFlags.FailOnAlpnMismatch in flags: + sslEngineAddFlags(addr result.scontext.eng, OPT_FAIL_ON_ALPN_MISMATCH) + + let err = sslServerReset(addr result.scontext) + if err == 0: + raise newException(TLSStreamError, "Could not initialize TLS layer") + + init(cast[AsyncStreamWriter](result.writer), wsource, tlsWriteLoop, + bufferSize) + init(cast[AsyncStreamReader](result.reader), rsource, tlsReadLoop, + bufferSize) + +proc copyKey(src: RsaPrivateKey): TLSPrivateKey = + ## Creates copy of RsaPrivateKey ``src``. + var offset = 0 + let keySize = src.plen + src.qlen + src.dplen + src.dqlen + src.iqlen + result = TLSPrivateKey(kind: TLSKeyType.RSA) + result.storage = newSeq[byte](keySize) + copyMem(addr result.storage[offset], src.p, src.plen) + result.rsakey.p = cast[ptr cuchar](addr result.storage[offset]) + result.rsakey.plen = src.plen + offset = offset + src.plen + copyMem(addr result.storage[offset], src.q, src.qlen) + result.rsakey.q = cast[ptr cuchar](addr result.storage[offset]) + result.rsakey.qlen = src.qlen + offset = offset + src.qlen + copyMem(addr result.storage[offset], src.dp, src.dplen) + result.rsakey.dp = cast[ptr cuchar](addr result.storage[offset]) + result.rsakey.dplen = src.dplen + offset = offset + src.dplen + copyMem(addr result.storage[offset], src.dq, src.dqlen) + result.rsakey.dq = cast[ptr cuchar](addr result.storage[offset]) + result.rsakey.dqlen = src.dqlen + offset = offset + src.dqlen + copyMem(addr result.storage[offset], src.iq, src.iqlen) + result.rsakey.iq = cast[ptr cuchar](addr result.storage[offset]) + result.rsakey.iqlen = src.iqlen + result.rsakey.nBitlen = src.nBitlen + +proc copyKey(src: EcPrivateKey): TLSPrivateKey = + ## Creates copy of EcPrivateKey ``src``. + var offset = 0 + let keySize = src.xlen + result = TLSPrivateKey(kind: TLSKeyType.EC) + result.storage = newSeq[byte](keySize) + copyMem(addr result.storage[offset], src.x, src.xlen) + result.eckey.x = cast[ptr cuchar](addr result.storage[offset]) + result.eckey.xlen = src.xlen + result.eckey.curve = src.curve + +proc init*(tt: typedesc[TLSPrivateKey], data: openarray[byte]): TLSPrivateKey = + ## Initialize TLS private key from array of bytes ``data``. + ## + ## This procedure initializes private key using raw, DER-encoded format, + ## or wrapped in an unencrypted PKCS#8 archive (again DER-encoded). + var ctx: SkeyDecoderContext + if len(data) == 0: + raiseTLSStreamProtoError("Incorrect private key") + skeyDecoderInit(addr ctx) + skeyDecoderPush(addr ctx, cast[pointer](unsafeAddr data[0]), len(data)) + let err = skeyDecoderLastError(addr ctx) + if err != 0: + raiseTLSStreamProtoError(err) + let keyType = skeyDecoderKeyType(addr ctx) + if keyType == KEYTYPE_RSA: + result = copyKey(ctx.key.rsa) + elif keyType == KEYTYPE_EC: + result = copyKey(ctx.key.ec) + else: + raiseTLSStreamProtoError("Unknown key type (" & $keyType & ")") + +proc pemDecode*(data: openarray[char]): seq[PEMElement] = + ## Decode PEM encoded string and get array of binary blobs. + if len(data) == 0: + raiseTLSStreamProtoError("Empty PEM message") + var ctx: PemDecoderContext + var pctx = new PEMContext + result = newSeq[PEMElement]() + pemDecoderInit(addr ctx) + + proc itemAppend(ctx: pointer, pbytes: pointer, nbytes: int) {.cdecl.} = + var p = cast[PEMContext](ctx) + var o = len(p.data) + p.data.setLen(o + nbytes) + copyMem(addr p.data[o], pbytes, nbytes) + + var length = len(data) + var offset = 0 + var inobj = false + var elem: PEMElement + + while length > 0: + var tlen = pemDecoderPush(addr ctx, + cast[pointer](unsafeAddr data[offset]), length) + offset = offset + tlen + length = length - tlen + + let event = pemDecoderEvent(addr ctx) + if event == PEM_BEGIN_OBJ: + inobj = true + elem.name = $pemDecoderName(addr ctx) + pctx.data = newSeq[byte]() + pemDecoderSetdest(addr ctx, itemAppend, cast[pointer](pctx)) + elif event == PEM_END_OBJ: + if inobj: + elem.data = pctx.data + result.add(elem) + inobj = false + else: + break + else: + raiseTLSStreamProtoError("Invalid PEM encoding") + +proc init*(tt: typedesc[TLSPrivateKey], data: openarray[char]): TLSPrivateKey = + ## Initialize TLS private key from string ``data``. + ## + ## This procedure initializes private key using unencrypted PKCS#8 PEM + ## encoded string. + ## + ## Note that PKCS#1 PEM encoded objects are not supported. + var items = pemDecode(data) + for item in items: + if item.name == "PRIVATE KEY": + result = TLSPrivateKey.init(item.data) + break + if isNil(result): + raiseTLSStreamProtoError("Could not find private key") + +proc init*(tt: typedesc[TLSCertificate], + data: openarray[char]): TLSCertificate = + ## Initialize TLS certificates from string ``data``. + ## + ## This procedure initializes array of certificates from PEM encoded string. + var items = pemDecode(data) + result = new TLSCertificate + for item in items: + if item.name == "CERTIFICATE" and len(item.data) > 0: + let offset = len(result.storage) + result.storage.add(item.data) + let cert = X509Certificate( + data: cast[ptr cuchar](addr result.storage[offset]), + dataLen: len(item.data) + ) + let res = getSignerAlgo(cert) + if res == -1: + raiseTLSStreamProtoError("Could not decode certificate") + elif res != KEYTYPE_RSA and res != KEYTYPE_EC: + raiseTLSStreamProtoError("Unsupported signing key type in certificate") + result.certs.add(cert) + if len(result.storage) == 0: + raiseTLSStreamProtoError("Could not find any certificates") + +proc init*(tt: typedesc[TLSSessionCache], size: int = 4096): TLSSessionCache = + ## Create new TLS session cache with size ``size``. + ## + ## One cached item is near 100 bytes size. + result = new TLSSessionCache + var rsize = min(size, 4096) + result.storage = newSeq[byte](rsize) + sslSessionCacheLruInit(addr result.context, addr result.storage[0], rsize) + +proc handshake*(rws: SomeTLSStreamType): Future[void] = + ## Wait until initial TLS handshake will be successfully performed. + var retFuture = newFuture[void]("tlsstream.handshake") + when rws is TLSStreamReader: + if rws.handshaked: + retFuture.complete() + else: + rws.handshakeFut = retFuture + rws.stream.writer.handshakeFut = retFuture + elif rws is TLSStreamWriter: + if rws.handshaked: + retFuture.complete() + else: + rws.handshakeFut = retFuture + rws.stream.reader.handshakeFut = retFuture + elif rws is TLSAsyncStream: + if rws.reader.handshaked: + retFuture.complete() + else: + rws.reader.handshakeFut = retFuture + rws.writer.handshakeFut = retFuture + return retFuture diff --git a/tests/testasyncstream.nim b/tests/testasyncstream.nim index 6343ceb..038aec1 100644 --- a/tests/testasyncstream.nim +++ b/tests/testasyncstream.nim @@ -8,6 +8,62 @@ import strutils, unittest, os import ../chronos, ../chronos/streams/tlsstream +const SelfSignedRsaKey = """ +-----BEGIN PRIVATE KEY----- +MIIEvgIBADANBgkqhkiG9w0BAQEFAASCBKgwggSkAgEAAoIBAQDSXcKMR6zIIHSy ++UUjgyIlJWzEu3JFCEN+qBpjmwbuae3GNed0YgPels6AKe0AlWQNgpqoWfMQrWco +qNf9dcm9OIdUK5FGMWYC8mJu6OjwnexSXt0R/k07lc5ePPDUbGvkbvgDyHzl1Ynj +jYhe/2ujL0E99/KAuA7cvRILEl4rLpnngE+MMYNrFWmsSDzC+w0Hv5Fuyc8wZoSy +nzODvojvFXTva9Nx4LxPF1W79WidwJHwrJghVlGUUhSeRGLHQWRY+954rj6TavoJ +6gLYSHx6ELnRkFMbqxRrCJ0wAbrV8SwcPHjKSc6LQGRMzBRqxT45DoKmzH9Pwdnw +6P0PGJyPAgMBAAECggEAd5Ck39hpIwIXchXtrwZ8ZMKFtLeZdhUBT7656P0XDnEU +nQDMQcDn1B7A1eV+eENwr6EYyDD/zu3P4TM+OCg3dp3nhPaSRmQTR/995O3qX8BS +rmqOmgiA2yoFNljKxOGu3RIZUwUjv/oDulsaNGxWQFS+bzs7EOAMSngIBlT1QvLc +1etvGOW6hc4nzacrXpzhzWem6EzOabPZBmIy0DDz2ARlND1YSd2P5IMpOv0terNF +ZwYl5SZ6Rnbv6GJWKl5IIpJOOtRtwyIhKNU/bd835vOfW07aaHVAT6GNYlyEoGWT +36UjOyl3YxSLNvQOeIz4Y0n+vupBu/YL+mFtQdxJgQKBgQDtjEi0cNlt+Vz5sp4q +wAVBJ/6h7hu3KJkY+xDpdrLyFTcxttKM9q/dR8V2bEaqMYT/mTvADfmW2BBcfi2J +3VdR1lQ5pXeKAuxt1/Vc+Q4UCnX5OX7UXpP9aSetDdo5FXUC9X2H0hO0BqOcc+h2 +khVyXjKt6TdBwV94dP9bmQp9QQKBgQDitPVqRHGepYBYasScyTUXPp7vL9T7PMSu +PGjqEkwvauhICpUbWpE/j8M0UXk64zwSmOYwQ37uiPpws88GL57oy1ZQZnhF/Hi3 +tM00Mn4x0xbyONbWu/AcFIZwSeSL6QhHYfeyVj7Jb/lqUg8sMiGmO25JjlAQBfTb +vvBgEpcVzwKBgQCwgz87JWfLejIGMR2qcoj1A30IYmAh1377uwO0F0mc7PrYbBtE +N8IyUTR/bLGNocJME1b8vOWrmt19fRzlhp1t6C8prrSGzulULdbawQ4fAi7rhDek +Iqsg8FRVGSgAptsN2dDvbcDKUuycQtyHzsE0/J338IXozIHehkGBlNTggQKBgQCF +RDTj5BoaVVWuJA0x0UGJSYFqP2bmzWEcv1w5BMqOMT0cZEQkkUfC4oKwdZhbGosM +r57ZDkRGenUl3T08eK/kTuuNVb8r/O8Fpp3eKjRum5TojKsWDeJmz1X8GiPkbvcz +5w4RYouEJHOsoVJT+6A2NMdvK946nRXEO2jYQPVZlwKBgBro8qGm0+T0T2xNP21q +IzjP/EHT7iIkM5kiUCc2bPIrfzAGxXImakDzd6AgpgxhhficJOpp792Upe/b/Hwy +bwfmbdWlT7/hPCnlVVH2dgO/ysDyEfxPigBMd+MmucRm6fzGIU7XSQw4KJqH4vQN +9IASWlgzyQ1RytAduzRuepzB +-----END PRIVATE KEY----- +""" + +const SelfSignedRsaCert = """ +-----BEGIN CERTIFICATE----- +MIIDkzCCAnugAwIBAgIUEFovLJkPSn4T8BBZMYBrXPDajF0wDQYJKoZIhvcNAQEL +BQAwWTELMAkGA1UEBhMCQVUxEzARBgNVBAgMClNvbWUtU3RhdGUxITAfBgNVBAoM +GEludGVybmV0IFdpZGdpdHMgUHR5IEx0ZDESMBAGA1UEAwwJbG9jYWxob3N0MB4X +DTE5MTAxMjA1NDQ0N1oXDTIwMTAxMTA1NDQ0OFowWTELMAkGA1UEBhMCQVUxEzAR +BgNVBAgMClNvbWUtU3RhdGUxITAfBgNVBAoMGEludGVybmV0IFdpZGdpdHMgUHR5 +IEx0ZDESMBAGA1UEAwwJbG9jYWxob3N0MIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8A +MIIBCgKCAQEA0l3CjEesyCB0svlFI4MiJSVsxLtyRQhDfqgaY5sG7mntxjXndGID +3pbOgCntAJVkDYKaqFnzEK1nKKjX/XXJvTiHVCuRRjFmAvJibujo8J3sUl7dEf5N +O5XOXjzw1Gxr5G74A8h85dWJ442IXv9roy9BPffygLgO3L0SCxJeKy6Z54BPjDGD +axVprEg8wvsNB7+RbsnPMGaEsp8zg76I7xV072vTceC8TxdVu/VoncCR8KyYIVZR +lFIUnkRix0FkWPveeK4+k2r6CeoC2Eh8ehC50ZBTG6sUawidMAG61fEsHDx4yknO +i0BkTMwUasU+OQ6Cpsx/T8HZ8Oj9DxicjwIDAQABo1MwUTAdBgNVHQ4EFgQUMM+1 +FZ6KmN2eCJfDxY+8xa1JKnYwHwYDVR0jBBgwFoAUMM+1FZ6KmN2eCJfDxY+8xa1J +KnYwDwYDVR0TAQH/BAUwAwEB/zANBgkqhkiG9w0BAQsFAAOCAQEASglh98fXvPwA +KMaEezCUqTeE7DehLlhZ8n6ETKaBDcP3JR4+KTEh9y7gRGJ7DXGFAYfU3itgjyZo +kbgpZIhrTKyYCAsF96Q1mHf/cBQ96UXr0U0SbYXSSJFeeMthMvki556dJZajtxcA +9xR/U0PxPjhC9NIfpVSAv/7ocnXh73qOiFHoN9Cr2smzcGPxsifys2iv1qm5LwDr +Dx5h/RfyfuAjS8e1ZCAhS++PYjb8BX54NilW2lTYF3pwpXL8znc4eBmklBkw5L60 +99jrK7LSQT9Nk8Mf9t4P/77N4hXCqsHIxZIqJlbdgdKfvBF3vRomxm3/aWtGlTVD +vvzZPnlYfQ== +-----END CERTIFICATE----- +""" + suite "AsyncStream test suite": test "AsyncStream(StreamTransport) readExactly() test": proc testReadExactly(address: TransportAddress): Future[bool] {.async.} = @@ -524,7 +580,7 @@ suite "TLSStream test suite": var transp = await connect(address) var reader = newAsyncStreamReader(transp) var writer = newAsyncStreamWriter(transp) - var tlsstream = newTlsClientAsyncStream(reader, writer, name) + var tlsstream = newTLSClientAsyncStream(reader, writer, name) await tlsstream.writer.write("GET / HTTP/1.1\r\nHost: " & name & "\r\nConnection: close\r\n\r\n") @@ -548,7 +604,52 @@ suite "TLSStream test suite": "www.google.com")) check res == true - test "TlsStream leaks test": + proc checkSSLServer(address: TransportAddress, + pemkey, pemcert: string): Future[bool] {.async.} = + var key: TLSPrivateKey + var cert: TLSCertificate + let testMessage = "TEST MESSAGE" + + proc serveClient(server: StreamServer, + transp: StreamTransport) {.async.} = + var reader = newAsyncStreamReader(transp) + var writer = newAsyncStreamWriter(transp) + var sstream = newTLSServerAsyncStream(reader, writer, key, cert) + await handshake(sstream) + await sstream.writer.write(testMessage & "\r\n") + await sstream.writer.closeWait() + await sstream.reader.closeWait() + await reader.closeWait() + await writer.closeWait() + await transp.closeWait() + server.stop() + server.close() + + key = TLSPrivateKey.init(pemkey) + cert = TLSCertificate.init(pemcert) + + var server = createStreamServer(address, serveClient, {ReuseAddr}) + server.start() + var conn = await connect(address) + var creader = newAsyncStreamReader(conn) + var cwriter = newAsyncStreamWriter(conn) + # We are using self-signed certificate + var cstream = newTLSClientAsyncStream(creader, cwriter, + flags = {NoVerifyHost}) + let res = await cstream.reader.readLine() + await cstream.reader.closeWait() + await cstream.writer.closeWait() + await creader.closeWait() + await cwriter.closeWait() + await conn.closeWait() + await server.join() + result = res == testMessage + + test "Simple server with RSA self-signed certificate": + let res = waitFor(checkSSLServer(initTAddress("127.0.0.1:43808"), + SelfSignedRsaKey, SelfSignedRsaCert)) + check res == true + test "TLSStream leaks test": check: getTracker("async.stream.reader").isLeaked() == false getTracker("async.stream.writer").isLeaked() == false