Add TLS inbound stream.

Fix some review comments.
This commit is contained in:
cheatfate 2019-10-16 09:01:52 +03:00
parent 161c50209e
commit a92ad6d2d2
No known key found for this signature in database
GPG Key ID: 46ADD633A7201F95
3 changed files with 505 additions and 90 deletions

View File

@ -154,10 +154,8 @@ proc upload*(sb: ptr AsyncBuffer, pbytes: ptr byte,
copyMem(addr sb[].buffer[sb.offset], pbytes, size) copyMem(addr sb[].buffer[sb.offset], pbytes, size)
sb[].offset = sb[].offset + size sb[].offset = sb[].offset + size
length = length - size length = length - size
# We notify consumers that new data is available.
if length == 0: sb[].forget()
# We notify consumers that new data is available.
sb[].forget()
template toDataOpenArray*(sb: AsyncBuffer): auto = template toDataOpenArray*(sb: AsyncBuffer): auto =
toOpenArray(sb.buffer, 0, sb.offset - 1) toOpenArray(sb.buffer, 0, sb.offset - 1)
@ -165,6 +163,15 @@ template toDataOpenArray*(sb: AsyncBuffer): auto =
template toBufferOpenArray*(sb: AsyncBuffer): auto = template toBufferOpenArray*(sb: AsyncBuffer): auto =
toOpenArray(sb.buffer, sb.offset, len(sb.buffer) - 1) toOpenArray(sb.buffer, sb.offset, len(sb.buffer) - 1)
template harvestItem*(dest: pointer, item: WriteItem, length: int) =
if item.kind == Pointer:
let p = cast[pointer](cast[uint](item.data1) + uint(item.offset))
copyMem(dest, p, length)
elif item.kind == Sequence:
copyMem(dest, unsafeAddr item.data2[item.offset], length)
elif item.kind == String:
copyMem(dest, unsafeAddr item.data3[item.offset], length)
proc newAsyncStreamReadError(p: ref Exception): ref Exception {.inline.} = proc newAsyncStreamReadError(p: ref Exception): ref Exception {.inline.} =
var w = newException(AsyncStreamReadError, "Read stream failed") var w = newException(AsyncStreamReadError, "Read stream failed")
w.msg = w.msg & ", originated from [" & $p.name & "] " & p.msg w.msg = w.msg & ", originated from [" & $p.name & "] " & p.msg

View File

@ -7,11 +7,11 @@
# Apache License, version 2.0, (LICENSE-APACHEv2) # Apache License, version 2.0, (LICENSE-APACHEv2)
# MIT license (LICENSE-MIT) # MIT license (LICENSE-MIT)
## This module implements TLS stream reading and writing. ## This module implements Transport Layer Security (TLS) stream. This module
## uses sources of BearSSL <https://www.bearssl.org> by Thomas Pornin.
import bearssl, bearssl/cacert import bearssl, bearssl/cacert
import ../asyncloop, ../timer, ../asyncsync import ../asyncloop, ../timer, ../asyncsync
import asyncstream, ../transports/stream, ../transports/common import asyncstream, ../transports/stream, ../transports/common
import strutils
type type
TLSStreamKind {.pure.} = enum TLSStreamKind {.pure.} = enum
@ -21,46 +21,79 @@ type
TLS10 = 0x0301, TLS11 = 0x0302, TLS12 = 0x0303 TLS10 = 0x0301, TLS11 = 0x0302, TLS12 = 0x0303
TLSFlags* {.pure.} = enum TLSFlags* {.pure.} = enum
NoVerifyHost, # Client: Skip remote certificate check NoVerifyHost, # Client: Skip remote certificate check
NoVerifySN, # Client: Skip Server Name Indication (SNI) check NoVerifyServerName, # Client: Skip Server Name Indication (SNI) check
NoRenegotiation, # Server: Reject renegotiations requests EnforceServerPref, # Server: Enforce server preferences
NoClientAuth, # Server: Disable strict client authentication NoRenegotiation, # Server: Reject renegotiations requests
FailOnAlpnMismatch # Server: Fail on application protocol mismatch TolerateNoClientAuth, # Server: Disable strict client authentication
FailOnAlpnMismatch # Server: Fail on application protocol mismatch
type TLSKeyType {.pure.} = enum
TlsStreamWriter* = ref object of AsyncStreamWriter RSA, EC
case kind: TlsStreamKind
of TlsStreamKind.Client: TLSPrivateKey* = ref object
case kind: TLSKeyType
of RSA:
rsakey: RsaPrivateKey
of EC:
eckey: EcPrivateKey
storage: seq[byte]
TLSCertificate* = ref object
certs: seq[X509Certificate]
storage: seq[byte]
TLSSessionCache* = ref object
storage: seq[byte]
context: SslSessionCacheLru
PEMElement* = object
name*: string
data*: seq[byte]
PEMContext = ref object
data: seq[byte]
TLSStreamWriter* = ref object of AsyncStreamWriter
case kind: TLSStreamKind
of TLSStreamKind.Client:
ccontext: ptr SslClientContext ccontext: ptr SslClientContext
of TlsStreamKind.Server: of TLSStreamKind.Server:
scontext: ptr SslServerContext scontext: ptr SslServerContext
stream*: TlsAsyncStream stream*: TLSAsyncStream
switchToReader*: AsyncEvent switchToReader*: AsyncEvent
switchToWriter*: AsyncEvent switchToWriter*: AsyncEvent
handshaked*: bool
handshakeFut*: Future[void]
TlsStreamReader* = ref object of AsyncStreamReader TLSStreamReader* = ref object of AsyncStreamReader
case kind: TlsStreamKind case kind: TLSStreamKind
of TlsStreamKind.Client: of TLSStreamKind.Client:
ccontext: ptr SslClientContext ccontext: ptr SslClientContext
of TlsStreamKind.Server: of TLSStreamKind.Server:
scontext: ptr SslServerContext scontext: ptr SslServerContext
stream*: TlsAsyncStream stream*: TLSAsyncStream
switchToReader*: AsyncEvent switchToReader*: AsyncEvent
switchToWriter*: AsyncEvent switchToWriter*: AsyncEvent
handshaked*: bool
handshakeFut*: Future[void]
TlsAsyncStream* = ref object of RootRef TLSAsyncStream* = ref object of RootRef
xwc*: X509NoAnchorContext xwc*: X509NoAnchorContext
context*: SslClientContext ccontext*: SslClientContext
scontext*: SslServerContext
sbuffer*: seq[byte] sbuffer*: seq[byte]
x509*: X509MinimalContext x509*: X509MinimalContext
reader*: TlsStreamReader reader*: TLSStreamReader
writer*: TlsStreamWriter writer*: TLSStreamWriter
TlsStreamError* = object of CatchableError SomeTLSStreamType* = TLSStreamReader|TLSStreamWriter|TLSAsyncStream
TlsStreamProtocolError* = object of TlsStreamError
TLSStreamError* = object of CatchableError
TLSStreamProtocolError* = object of TLSStreamError
errCode*: int errCode*: int
template newTlsStreamProtocolError[T](message: T): ref Exception = template newTLSStreamProtocolError[T](message: T): ref Exception =
var msg = "" var msg = ""
var code = 0 var code = 0
when T is string: when T is string:
@ -73,29 +106,29 @@ template newTlsStreamProtocolError[T](message: T): ref Exception =
code = message code = message
else: else:
msg.add("Internal Error") msg.add("Internal Error")
var err = newException(TlsStreamProtocolError, msg) var err = newException(TLSStreamProtocolError, msg)
err.errCode = code err.errCode = code
err err
# proc raiseTlsStreamProtoError*[T](message: T) = proc raiseTLSStreamProtoError*[T](message: T) =
# raise newTlsStreamProtocolError(message) raise newTLSStreamProtocolError(message)
proc getStringState*(state: cuint): string = # proc getStringState*(state: cuint): string =
var n = newSeq[string]() # var n = newSeq[string]()
if (state and SSL_CLOSED) == SSL_CLOSED: # if (state and SSL_CLOSED) == SSL_CLOSED:
n.add("Closed") # n.add("Closed")
if (state and SSL_SENDREC) == SSL_SENDREC: # if (state and SSL_SENDREC) == SSL_SENDREC:
n.add("SendRec") # n.add("SendRec")
if (state and SSL_RECVREC) == SSL_RECVREC: # if (state and SSL_RECVREC) == SSL_RECVREC:
n.add("RecvRec") # n.add("RecvRec")
if (state and SSL_SENDAPP) == SSL_SENDAPP: # if (state and SSL_SENDAPP) == SSL_SENDAPP:
n.add("SendApp") # n.add("SendApp")
if (state and SSL_RECVAPP) == SSL_RECVAPP: # if (state and SSL_RECVAPP) == SSL_RECVAPP:
n.add("RecvApp") # n.add("RecvApp")
result = "{" & n.join(", ") & "} number (" & $state & ")" # result = "{" & n.join(", ") & "} number (" & $state & ")"
proc tlsWriteLoop(stream: AsyncStreamWriter) {.async.} = proc tlsWriteLoop(stream: AsyncStreamWriter) {.async.} =
var wstream = cast[TlsStreamWriter](stream) var wstream = cast[TLSStreamWriter](stream)
var engine: ptr SslEngineContext var engine: ptr SslEngineContext
var error: ref Exception var error: ref Exception
@ -139,31 +172,26 @@ proc tlsWriteLoop(stream: AsyncStreamWriter) {.async.} =
if (state and SSL_SENDAPP) == SSL_SENDAPP: if (state and SSL_SENDAPP) == SSL_SENDAPP:
# Application data can be sent over stream. # Application data can be sent over stream.
if not(wstream.handshaked):
wstream.stream.reader.handshaked = true
wstream.handshaked = true
if not(isNil(wstream.handshakeFut)):
wstream.handshakeFut.complete()
var item = await wstream.queue.get() var item = await wstream.queue.get()
if item.size > 0: if item.size > 0:
length = 0'u length = 0'u
var buf = sslEngineSendappBuf(engine, length) var buf = sslEngineSendappBuf(engine, length)
let toWrite = min(int(length), item.size) let toWrite = min(int(length), item.size)
harvestItem(buf, item, toWrite)
if int(length) >= item.size: if int(length) >= item.size:
if item.kind == Pointer: # BearSSL is ready to accept whole item size.
let p = cast[pointer](cast[uint](item.data1) + uint(item.offset))
copyMem(buf, p, item.size)
elif item.kind == Sequence:
copyMem(buf, addr item.data2[item.offset], item.size)
elif item.kind == String:
copyMem(buf, addr item.data3[item.offset], item.size)
sslEngineSendappAck(engine, uint(item.size)) sslEngineSendappAck(engine, uint(item.size))
sslEngineFlush(engine, 0) sslEngineFlush(engine, 0)
item.future.complete() item.future.complete()
else: else:
if item.kind == Pointer: # BearSSL is not ready to accept whole item, so we will send only
let p = cast[pointer](cast[uint](item.data1) + uint(item.offset)) # part of item and adjust offset.
copyMem(buf, p, length)
elif item.kind == Sequence:
copyMem(buf, addr item.data2[item.offset], length)
elif item.kind == String:
copyMem(buf, addr item.data3[item.offset], length)
item.offset = item.offset + int(length) item.offset = item.offset + int(length)
item.size = item.size - int(length) item.size = item.size - int(length)
wstream.queue.addFirstNoWait(item) wstream.queue.addFirstNoWait(item)
@ -188,13 +216,14 @@ proc tlsWriteLoop(stream: AsyncStreamWriter) {.async.} =
let item = wstream.queue.popFirstNoWait() let item = wstream.queue.popFirstNoWait()
if not(item.future.finished()): if not(item.future.finished()):
item.future.fail(error) item.future.fail(error)
wstream.switchToReader.fire()
wstream.stream = nil wstream.stream = nil
proc tlsReadLoop(stream: AsyncStreamReader) {.async.} = proc tlsReadLoop(stream: AsyncStreamReader) {.async.} =
var rstream = cast[TlsStreamReader](stream) var rstream = cast[TLSStreamReader](stream)
var engine: ptr SslEngineContext var engine: ptr SslEngineContext
if rstream.kind == TlsStreamKind.Server: if rstream.kind == TLSStreamKind.Server:
engine = addr rstream.scontext.eng engine = addr rstream.scontext.eng
else: else:
engine = addr rstream.ccontext.eng engine = addr rstream.ccontext.eng
@ -205,12 +234,17 @@ proc tlsReadLoop(stream: AsyncStreamReader) {.async.} =
var length: uint var length: uint
while true: while true:
var state = engine.sslEngineCurrentState() var state = engine.sslEngineCurrentState()
if (state and SSL_CLOSED) == SSL_CLOSED: if (state and SSL_CLOSED) == SSL_CLOSED:
let err = engine.sslEngineLastError() let err = engine.sslEngineLastError()
if err != 0: if err != 0:
rstream.error = newTlsStreamProtocolError(err) rstream.error = newTLSStreamProtocolError(err)
rstream.state = AsyncStreamState.Error rstream.state = AsyncStreamState.Error
if not(rstream.handshaked):
rstream.handshaked = true
rstream.stream.writer.handshaked = true
if not(isNil(rstream.handshakeFut)):
rstream.handshakeFut.fail(rstream.error)
rstream.switchToWriter.fire()
break break
else: else:
rstream.state = AsyncStreamState.Stopped rstream.state = AsyncStreamState.Stopped
@ -234,6 +268,12 @@ proc tlsReadLoop(stream: AsyncStreamReader) {.async.} =
if resFut.failed(): if resFut.failed():
rstream.error = resFut.error rstream.error = resFut.error
rstream.state = AsyncStreamState.Error rstream.state = AsyncStreamState.Error
if not(rstream.handshaked):
rstream.handshaked = true
rstream.stream.writer.handshaked = true
if not(isNil(rstream.handshakeFut)):
rstream.handshakeFut.fail(rstream.error)
rstream.switchToWriter.fire()
break break
let res = resFut.read() let res = resFut.read()
if res > 0: if res > 0:
@ -257,20 +297,30 @@ proc tlsReadLoop(stream: AsyncStreamReader) {.async.} =
finally: finally:
# Perform TLS cleanup procedure # Perform TLS cleanup procedure
sslEngineClose(engine) sslEngineClose(engine)
# Becase tlsWriteLoop() is ephemeral, but we still need to keep stream state
# consistent.
rstream.buffer.forget() rstream.buffer.forget()
rstream.switchToWriter.fire()
rstream.stream = nil rstream.stream = nil
proc newTlsClientAsyncStream*(rsource: AsyncStreamReader, proc getSignerAlgo(xc: X509Certificate): int =
## Get certificate's signing algorithm.
var dc: X509DecoderContext
x509DecoderInit(addr dc, nil, nil)
x509DecoderPush(addr dc, xc.data, xc.dataLen)
let err = x509DecoderLastError(addr dc)
if err != 0:
result = -1
else:
result = int(x509DecoderGetSignerKeyType(addr dc))
proc newTLSClientAsyncStream*(rsource: AsyncStreamReader,
wsource: AsyncStreamWriter, wsource: AsyncStreamWriter,
serverName: string = "", serverName: string = "",
bufferSize = SSL_BUFSIZE_BIDI, bufferSize = SSL_BUFSIZE_BIDI,
minVersion = TLSVersion.TLS11, minVersion = TLSVersion.TLS11,
maxVersion = TLSVersion.TLS12, maxVersion = TLSVersion.TLS12,
flags: set[TlsFlags] = {}): TlsAsyncStream = flags: set[TLSFlags] = {}): TLSAsyncStream =
## Create new TLS asynchronous stream using reading stream ``rsource``, ## Create new TLS asynchronous stream for outbound (client) connections
## writing stream ``wsource``. ## using reading stream ``rsource`` and writing stream ``wsource``.
## ##
## You can specify remote server name using ``serverName``, if while ## You can specify remote server name using ``serverName``, if while
## handshake server reports different name you will get an error. If ## handshake server reports different name you will get an error. If
@ -285,11 +335,11 @@ proc newTlsClientAsyncStream*(rsource: AsyncStreamReader,
## ``minVersion`` of bigger then ``maxVersion`` you will get an error. ## ``minVersion`` of bigger then ``maxVersion`` you will get an error.
## ##
## ``flags`` - custom TLS connection flags. ## ``flags`` - custom TLS connection flags.
result = new TlsAsyncStream result = new TLSAsyncStream
var reader = new TlsStreamReader var reader = new TLSStreamReader
reader.kind = TlsStreamKind.Client reader.kind = TLSStreamKind.Client
var writer = new TlsStreamWriter var writer = new TLSStreamWriter
writer.kind = TlsStreamKind.Client writer.kind = TLSStreamKind.Client
var switchToWriter = newAsyncEvent() var switchToWriter = newAsyncEvent()
var switchToReader = newAsyncEvent() var switchToReader = newAsyncEvent()
reader.stream = result reader.stream = result
@ -300,35 +350,292 @@ proc newTlsClientAsyncStream*(rsource: AsyncStreamReader,
writer.switchToWriter = switchToWriter writer.switchToWriter = switchToWriter
result.reader = reader result.reader = reader
result.writer = writer result.writer = writer
reader.ccontext = addr result.context reader.ccontext = addr result.ccontext
writer.ccontext = addr result.context writer.ccontext = addr result.ccontext
if TLSFlags.NoVerifyHost in flags: if TLSFlags.NoVerifyHost in flags:
sslClientInitFull(addr result.context, addr result.x509, nil, 0) sslClientInitFull(addr result.ccontext, addr result.x509, nil, 0)
initNoAnchor(addr result.xwc, addr result.x509.vtable) initNoAnchor(addr result.xwc, addr result.x509.vtable)
sslEngineSetX509(addr result.context.eng, addr result.xwc.vtable) sslEngineSetX509(addr result.ccontext.eng, addr result.xwc.vtable)
else: else:
sslClientInitFull(addr result.context, addr result.x509, sslClientInitFull(addr result.ccontext, addr result.x509,
unsafeAddr MozillaTrustAnchors[0], unsafeAddr MozillaTrustAnchors[0],
len(MozillaTrustAnchors)) len(MozillaTrustAnchors))
let size = max(SSL_BUFSIZE_BIDI, bufferSize) let size = max(SSL_BUFSIZE_BIDI, bufferSize)
result.sbuffer = newSeq[byte](size) result.sbuffer = newSeq[byte](size)
sslEngineSetBuffer(addr result.context.eng, addr result.sbuffer[0], sslEngineSetBuffer(addr result.ccontext.eng, addr result.sbuffer[0],
uint(len(result.sbuffer)), 1) uint(len(result.sbuffer)), 1)
sslEngineSetVersions(addr result.context.eng, uint16(minVersion), sslEngineSetVersions(addr result.ccontext.eng, uint16(minVersion),
uint16(maxVersion)) uint16(maxVersion))
if TLSFlags.NoVerifySN in flags: if TLSFlags.NoVerifyServerName in flags:
let err = sslClientReset(addr result.context, "", 0) let err = sslClientReset(addr result.ccontext, "", 0)
if err == 0: if err == 0:
raise newException(TlsStreamError, "Could not initialize TLS layer") raise newException(TLSStreamError, "Could not initialize TLS layer")
else: else:
let err = sslClientReset(addr result.context, serverName, 0) let err = sslClientReset(addr result.ccontext, serverName, 0)
if err == 0: if err == 0:
raise newException(TlsStreamError, "Could not initialize TLS layer") raise newException(TLSStreamError, "Could not initialize TLS layer")
init(cast[AsyncStreamWriter](result.writer), wsource, tlsWriteLoop, init(cast[AsyncStreamWriter](result.writer), wsource, tlsWriteLoop,
bufferSize) bufferSize)
init(cast[AsyncStreamReader](result.reader), rsource, tlsReadLoop, init(cast[AsyncStreamReader](result.reader), rsource, tlsReadLoop,
bufferSize) bufferSize)
proc newTLSServerAsyncStream*(rsource: AsyncStreamReader,
wsource: AsyncStreamWriter,
privateKey: TLSPrivateKey,
certificate: TLSCertificate,
bufferSize = SSL_BUFSIZE_BIDI,
minVersion = TLSVersion.TLS11,
maxVersion = TLSVersion.TLS12,
cache: TLSSessionCache = nil,
flags: set[TLSFlags] = {}): TLSAsyncStream =
## Create new TLS asynchronous stream for inbound (server) connections
## using reading stream ``rsource`` and writing stream ``wsource``.
##
## You need to specify local private key ``privateKey`` and certificate
## ``certificate``.
##
## ``bufferSize`` - is SSL/TLS buffer which is used for encoding/decoding
## incoming data.
##
## ``minVersion`` and ``maxVersion`` are TLS versions which will be used
## for handshake with remote server. If server's version will be lower then
## ``minVersion`` of bigger then ``maxVersion`` you will get an error.
##
## ``flags`` - custom TLS connection flags.
if isNil(privateKey) or privateKey.kind notin {TLSKeyType.RSA, TLSKeyType.EC}:
raiseTLSStreamProtoError("Incorrect private key")
if isNil(certificate) or len(certificate.certs) == 0:
raiseTLSStreamProtoError("Incorrect certificate")
result = new TLSAsyncStream
var reader = new TLSStreamReader
reader.kind = TLSStreamKind.Server
var writer = new TLSStreamWriter
writer.kind = TLSStreamKind.Server
var switchToWriter = newAsyncEvent()
var switchToReader = newAsyncEvent()
reader.stream = result
writer.stream = result
reader.switchToReader = switchToReader
reader.switchToWriter = switchToWriter
writer.switchToReader = switchToReader
writer.switchToWriter = switchToWriter
result.reader = reader
result.writer = writer
reader.scontext = addr result.scontext
writer.scontext = addr result.scontext
if privateKey.kind == TLSKeyType.EC:
let algo = getSignerAlgo(certificate.certs[0])
if algo == -1:
raiseTLSStreamProtoError("Could not decode certificate")
sslServerInitFullEc(addr result.scontext, addr certificate.certs[0],
len(certificate.certs), cuint(algo),
addr privateKey.eckey)
elif privateKey.kind == TLSKeyType.RSA:
sslServerInitFullRsa(addr result.scontext, addr certificate.certs[0],
len(certificate.certs), addr privateKey.rsakey)
let size = max(SSL_BUFSIZE_BIDI, bufferSize)
result.sbuffer = newSeq[byte](size)
sslEngineSetBuffer(addr result.scontext.eng, addr result.sbuffer[0],
uint(len(result.sbuffer)), 1)
sslEngineSetVersions(addr result.scontext.eng, uint16(minVersion),
uint16(maxVersion))
if not isNil(cache):
sslServerSetCache(addr result.scontext, addr cache.context.vtable)
if TLSFlags.EnforceServerPref in flags:
sslEngineAddFlags(addr result.scontext.eng, OPT_ENFORCE_SERVER_PREFERENCES)
if TLSFlags.NoRenegotiation in flags:
sslEngineAddFlags(addr result.scontext.eng, OPT_NO_RENEGOTIATION)
if TLSFlags.TolerateNoClientAuth in flags:
sslEngineAddFlags(addr result.scontext.eng, OPT_TOLERATE_NO_CLIENT_AUTH)
if TLSFlags.FailOnAlpnMismatch in flags:
sslEngineAddFlags(addr result.scontext.eng, OPT_FAIL_ON_ALPN_MISMATCH)
let err = sslServerReset(addr result.scontext)
if err == 0:
raise newException(TLSStreamError, "Could not initialize TLS layer")
init(cast[AsyncStreamWriter](result.writer), wsource, tlsWriteLoop,
bufferSize)
init(cast[AsyncStreamReader](result.reader), rsource, tlsReadLoop,
bufferSize)
proc copyKey(src: RsaPrivateKey): TLSPrivateKey =
## Creates copy of RsaPrivateKey ``src``.
var offset = 0
let keySize = src.plen + src.qlen + src.dplen + src.dqlen + src.iqlen
result = TLSPrivateKey(kind: TLSKeyType.RSA)
result.storage = newSeq[byte](keySize)
copyMem(addr result.storage[offset], src.p, src.plen)
result.rsakey.p = cast[ptr cuchar](addr result.storage[offset])
result.rsakey.plen = src.plen
offset = offset + src.plen
copyMem(addr result.storage[offset], src.q, src.qlen)
result.rsakey.q = cast[ptr cuchar](addr result.storage[offset])
result.rsakey.qlen = src.qlen
offset = offset + src.qlen
copyMem(addr result.storage[offset], src.dp, src.dplen)
result.rsakey.dp = cast[ptr cuchar](addr result.storage[offset])
result.rsakey.dplen = src.dplen
offset = offset + src.dplen
copyMem(addr result.storage[offset], src.dq, src.dqlen)
result.rsakey.dq = cast[ptr cuchar](addr result.storage[offset])
result.rsakey.dqlen = src.dqlen
offset = offset + src.dqlen
copyMem(addr result.storage[offset], src.iq, src.iqlen)
result.rsakey.iq = cast[ptr cuchar](addr result.storage[offset])
result.rsakey.iqlen = src.iqlen
result.rsakey.nBitlen = src.nBitlen
proc copyKey(src: EcPrivateKey): TLSPrivateKey =
## Creates copy of EcPrivateKey ``src``.
var offset = 0
let keySize = src.xlen
result = TLSPrivateKey(kind: TLSKeyType.EC)
result.storage = newSeq[byte](keySize)
copyMem(addr result.storage[offset], src.x, src.xlen)
result.eckey.x = cast[ptr cuchar](addr result.storage[offset])
result.eckey.xlen = src.xlen
result.eckey.curve = src.curve
proc init*(tt: typedesc[TLSPrivateKey], data: openarray[byte]): TLSPrivateKey =
## Initialize TLS private key from array of bytes ``data``.
##
## This procedure initializes private key using raw, DER-encoded format,
## or wrapped in an unencrypted PKCS#8 archive (again DER-encoded).
var ctx: SkeyDecoderContext
if len(data) == 0:
raiseTLSStreamProtoError("Incorrect private key")
skeyDecoderInit(addr ctx)
skeyDecoderPush(addr ctx, cast[pointer](unsafeAddr data[0]), len(data))
let err = skeyDecoderLastError(addr ctx)
if err != 0:
raiseTLSStreamProtoError(err)
let keyType = skeyDecoderKeyType(addr ctx)
if keyType == KEYTYPE_RSA:
result = copyKey(ctx.key.rsa)
elif keyType == KEYTYPE_EC:
result = copyKey(ctx.key.ec)
else:
raiseTLSStreamProtoError("Unknown key type (" & $keyType & ")")
proc pemDecode*(data: openarray[char]): seq[PEMElement] =
## Decode PEM encoded string and get array of binary blobs.
if len(data) == 0:
raiseTLSStreamProtoError("Empty PEM message")
var ctx: PemDecoderContext
var pctx = new PEMContext
result = newSeq[PEMElement]()
pemDecoderInit(addr ctx)
proc itemAppend(ctx: pointer, pbytes: pointer, nbytes: int) {.cdecl.} =
var p = cast[PEMContext](ctx)
var o = len(p.data)
p.data.setLen(o + nbytes)
copyMem(addr p.data[o], pbytes, nbytes)
var length = len(data)
var offset = 0
var inobj = false
var elem: PEMElement
while length > 0:
var tlen = pemDecoderPush(addr ctx,
cast[pointer](unsafeAddr data[offset]), length)
offset = offset + tlen
length = length - tlen
let event = pemDecoderEvent(addr ctx)
if event == PEM_BEGIN_OBJ:
inobj = true
elem.name = $pemDecoderName(addr ctx)
pctx.data = newSeq[byte]()
pemDecoderSetdest(addr ctx, itemAppend, cast[pointer](pctx))
elif event == PEM_END_OBJ:
if inobj:
elem.data = pctx.data
result.add(elem)
inobj = false
else:
break
else:
raiseTLSStreamProtoError("Invalid PEM encoding")
proc init*(tt: typedesc[TLSPrivateKey], data: openarray[char]): TLSPrivateKey =
## Initialize TLS private key from string ``data``.
##
## This procedure initializes private key using unencrypted PKCS#8 PEM
## encoded string.
##
## Note that PKCS#1 PEM encoded objects are not supported.
var items = pemDecode(data)
for item in items:
if item.name == "PRIVATE KEY":
result = TLSPrivateKey.init(item.data)
break
if isNil(result):
raiseTLSStreamProtoError("Could not find private key")
proc init*(tt: typedesc[TLSCertificate],
data: openarray[char]): TLSCertificate =
## Initialize TLS certificates from string ``data``.
##
## This procedure initializes array of certificates from PEM encoded string.
var items = pemDecode(data)
result = new TLSCertificate
for item in items:
if item.name == "CERTIFICATE" and len(item.data) > 0:
let offset = len(result.storage)
result.storage.add(item.data)
let cert = X509Certificate(
data: cast[ptr cuchar](addr result.storage[offset]),
dataLen: len(item.data)
)
let res = getSignerAlgo(cert)
if res == -1:
raiseTLSStreamProtoError("Could not decode certificate")
elif res != KEYTYPE_RSA and res != KEYTYPE_EC:
raiseTLSStreamProtoError("Unsupported signing key type in certificate")
result.certs.add(cert)
if len(result.storage) == 0:
raiseTLSStreamProtoError("Could not find any certificates")
proc init*(tt: typedesc[TLSSessionCache], size: int = 4096): TLSSessionCache =
## Create new TLS session cache with size ``size``.
##
## One cached item is near 100 bytes size.
result = new TLSSessionCache
var rsize = min(size, 4096)
result.storage = newSeq[byte](rsize)
sslSessionCacheLruInit(addr result.context, addr result.storage[0], rsize)
proc handshake*(rws: SomeTLSStreamType): Future[void] =
## Wait until initial TLS handshake will be successfully performed.
var retFuture = newFuture[void]("tlsstream.handshake")
when rws is TLSStreamReader:
if rws.handshaked:
retFuture.complete()
else:
rws.handshakeFut = retFuture
rws.stream.writer.handshakeFut = retFuture
elif rws is TLSStreamWriter:
if rws.handshaked:
retFuture.complete()
else:
rws.handshakeFut = retFuture
rws.stream.reader.handshakeFut = retFuture
elif rws is TLSAsyncStream:
if rws.reader.handshaked:
retFuture.complete()
else:
rws.reader.handshakeFut = retFuture
rws.writer.handshakeFut = retFuture
return retFuture

View File

@ -8,6 +8,62 @@
import strutils, unittest, os import strutils, unittest, os
import ../chronos, ../chronos/streams/tlsstream import ../chronos, ../chronos/streams/tlsstream
const SelfSignedRsaKey = """
-----BEGIN PRIVATE KEY-----
MIIEvgIBADANBgkqhkiG9w0BAQEFAASCBKgwggSkAgEAAoIBAQDSXcKMR6zIIHSy
+UUjgyIlJWzEu3JFCEN+qBpjmwbuae3GNed0YgPels6AKe0AlWQNgpqoWfMQrWco
qNf9dcm9OIdUK5FGMWYC8mJu6OjwnexSXt0R/k07lc5ePPDUbGvkbvgDyHzl1Ynj
jYhe/2ujL0E99/KAuA7cvRILEl4rLpnngE+MMYNrFWmsSDzC+w0Hv5Fuyc8wZoSy
nzODvojvFXTva9Nx4LxPF1W79WidwJHwrJghVlGUUhSeRGLHQWRY+954rj6TavoJ
6gLYSHx6ELnRkFMbqxRrCJ0wAbrV8SwcPHjKSc6LQGRMzBRqxT45DoKmzH9Pwdnw
6P0PGJyPAgMBAAECggEAd5Ck39hpIwIXchXtrwZ8ZMKFtLeZdhUBT7656P0XDnEU
nQDMQcDn1B7A1eV+eENwr6EYyDD/zu3P4TM+OCg3dp3nhPaSRmQTR/995O3qX8BS
rmqOmgiA2yoFNljKxOGu3RIZUwUjv/oDulsaNGxWQFS+bzs7EOAMSngIBlT1QvLc
1etvGOW6hc4nzacrXpzhzWem6EzOabPZBmIy0DDz2ARlND1YSd2P5IMpOv0terNF
ZwYl5SZ6Rnbv6GJWKl5IIpJOOtRtwyIhKNU/bd835vOfW07aaHVAT6GNYlyEoGWT
36UjOyl3YxSLNvQOeIz4Y0n+vupBu/YL+mFtQdxJgQKBgQDtjEi0cNlt+Vz5sp4q
wAVBJ/6h7hu3KJkY+xDpdrLyFTcxttKM9q/dR8V2bEaqMYT/mTvADfmW2BBcfi2J
3VdR1lQ5pXeKAuxt1/Vc+Q4UCnX5OX7UXpP9aSetDdo5FXUC9X2H0hO0BqOcc+h2
khVyXjKt6TdBwV94dP9bmQp9QQKBgQDitPVqRHGepYBYasScyTUXPp7vL9T7PMSu
PGjqEkwvauhICpUbWpE/j8M0UXk64zwSmOYwQ37uiPpws88GL57oy1ZQZnhF/Hi3
tM00Mn4x0xbyONbWu/AcFIZwSeSL6QhHYfeyVj7Jb/lqUg8sMiGmO25JjlAQBfTb
vvBgEpcVzwKBgQCwgz87JWfLejIGMR2qcoj1A30IYmAh1377uwO0F0mc7PrYbBtE
N8IyUTR/bLGNocJME1b8vOWrmt19fRzlhp1t6C8prrSGzulULdbawQ4fAi7rhDek
Iqsg8FRVGSgAptsN2dDvbcDKUuycQtyHzsE0/J338IXozIHehkGBlNTggQKBgQCF
RDTj5BoaVVWuJA0x0UGJSYFqP2bmzWEcv1w5BMqOMT0cZEQkkUfC4oKwdZhbGosM
r57ZDkRGenUl3T08eK/kTuuNVb8r/O8Fpp3eKjRum5TojKsWDeJmz1X8GiPkbvcz
5w4RYouEJHOsoVJT+6A2NMdvK946nRXEO2jYQPVZlwKBgBro8qGm0+T0T2xNP21q
IzjP/EHT7iIkM5kiUCc2bPIrfzAGxXImakDzd6AgpgxhhficJOpp792Upe/b/Hwy
bwfmbdWlT7/hPCnlVVH2dgO/ysDyEfxPigBMd+MmucRm6fzGIU7XSQw4KJqH4vQN
9IASWlgzyQ1RytAduzRuepzB
-----END PRIVATE KEY-----
"""
const SelfSignedRsaCert = """
-----BEGIN CERTIFICATE-----
MIIDkzCCAnugAwIBAgIUEFovLJkPSn4T8BBZMYBrXPDajF0wDQYJKoZIhvcNAQEL
BQAwWTELMAkGA1UEBhMCQVUxEzARBgNVBAgMClNvbWUtU3RhdGUxITAfBgNVBAoM
GEludGVybmV0IFdpZGdpdHMgUHR5IEx0ZDESMBAGA1UEAwwJbG9jYWxob3N0MB4X
DTE5MTAxMjA1NDQ0N1oXDTIwMTAxMTA1NDQ0OFowWTELMAkGA1UEBhMCQVUxEzAR
BgNVBAgMClNvbWUtU3RhdGUxITAfBgNVBAoMGEludGVybmV0IFdpZGdpdHMgUHR5
IEx0ZDESMBAGA1UEAwwJbG9jYWxob3N0MIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8A
MIIBCgKCAQEA0l3CjEesyCB0svlFI4MiJSVsxLtyRQhDfqgaY5sG7mntxjXndGID
3pbOgCntAJVkDYKaqFnzEK1nKKjX/XXJvTiHVCuRRjFmAvJibujo8J3sUl7dEf5N
O5XOXjzw1Gxr5G74A8h85dWJ442IXv9roy9BPffygLgO3L0SCxJeKy6Z54BPjDGD
axVprEg8wvsNB7+RbsnPMGaEsp8zg76I7xV072vTceC8TxdVu/VoncCR8KyYIVZR
lFIUnkRix0FkWPveeK4+k2r6CeoC2Eh8ehC50ZBTG6sUawidMAG61fEsHDx4yknO
i0BkTMwUasU+OQ6Cpsx/T8HZ8Oj9DxicjwIDAQABo1MwUTAdBgNVHQ4EFgQUMM+1
FZ6KmN2eCJfDxY+8xa1JKnYwHwYDVR0jBBgwFoAUMM+1FZ6KmN2eCJfDxY+8xa1J
KnYwDwYDVR0TAQH/BAUwAwEB/zANBgkqhkiG9w0BAQsFAAOCAQEASglh98fXvPwA
KMaEezCUqTeE7DehLlhZ8n6ETKaBDcP3JR4+KTEh9y7gRGJ7DXGFAYfU3itgjyZo
kbgpZIhrTKyYCAsF96Q1mHf/cBQ96UXr0U0SbYXSSJFeeMthMvki556dJZajtxcA
9xR/U0PxPjhC9NIfpVSAv/7ocnXh73qOiFHoN9Cr2smzcGPxsifys2iv1qm5LwDr
Dx5h/RfyfuAjS8e1ZCAhS++PYjb8BX54NilW2lTYF3pwpXL8znc4eBmklBkw5L60
99jrK7LSQT9Nk8Mf9t4P/77N4hXCqsHIxZIqJlbdgdKfvBF3vRomxm3/aWtGlTVD
vvzZPnlYfQ==
-----END CERTIFICATE-----
"""
suite "AsyncStream test suite": suite "AsyncStream test suite":
test "AsyncStream(StreamTransport) readExactly() test": test "AsyncStream(StreamTransport) readExactly() test":
proc testReadExactly(address: TransportAddress): Future[bool] {.async.} = proc testReadExactly(address: TransportAddress): Future[bool] {.async.} =
@ -524,7 +580,7 @@ suite "TLSStream test suite":
var transp = await connect(address) var transp = await connect(address)
var reader = newAsyncStreamReader(transp) var reader = newAsyncStreamReader(transp)
var writer = newAsyncStreamWriter(transp) var writer = newAsyncStreamWriter(transp)
var tlsstream = newTlsClientAsyncStream(reader, writer, name) var tlsstream = newTLSClientAsyncStream(reader, writer, name)
await tlsstream.writer.write("GET / HTTP/1.1\r\nHost: " & name & await tlsstream.writer.write("GET / HTTP/1.1\r\nHost: " & name &
"\r\nConnection: close\r\n\r\n") "\r\nConnection: close\r\n\r\n")
@ -548,7 +604,52 @@ suite "TLSStream test suite":
"www.google.com")) "www.google.com"))
check res == true check res == true
test "TlsStream leaks test": proc checkSSLServer(address: TransportAddress,
pemkey, pemcert: string): Future[bool] {.async.} =
var key: TLSPrivateKey
var cert: TLSCertificate
let testMessage = "TEST MESSAGE"
proc serveClient(server: StreamServer,
transp: StreamTransport) {.async.} =
var reader = newAsyncStreamReader(transp)
var writer = newAsyncStreamWriter(transp)
var sstream = newTLSServerAsyncStream(reader, writer, key, cert)
await handshake(sstream)
await sstream.writer.write(testMessage & "\r\n")
await sstream.writer.closeWait()
await sstream.reader.closeWait()
await reader.closeWait()
await writer.closeWait()
await transp.closeWait()
server.stop()
server.close()
key = TLSPrivateKey.init(pemkey)
cert = TLSCertificate.init(pemcert)
var server = createStreamServer(address, serveClient, {ReuseAddr})
server.start()
var conn = await connect(address)
var creader = newAsyncStreamReader(conn)
var cwriter = newAsyncStreamWriter(conn)
# We are using self-signed certificate
var cstream = newTLSClientAsyncStream(creader, cwriter,
flags = {NoVerifyHost})
let res = await cstream.reader.readLine()
await cstream.reader.closeWait()
await cstream.writer.closeWait()
await creader.closeWait()
await cwriter.closeWait()
await conn.closeWait()
await server.join()
result = res == testMessage
test "Simple server with RSA self-signed certificate":
let res = waitFor(checkSSLServer(initTAddress("127.0.0.1:43808"),
SelfSignedRsaKey, SelfSignedRsaCert))
check res == true
test "TLSStream leaks test":
check: check:
getTracker("async.stream.reader").isLeaked() == false getTracker("async.stream.reader").isLeaked() == false
getTracker("async.stream.writer").isLeaked() == false getTracker("async.stream.writer").isLeaked() == false