2019-10-08 15:46:27 +00:00
|
|
|
#
|
|
|
|
# Chronos Asynchronous TLS Stream
|
|
|
|
# (c) Copyright 2019-Present
|
|
|
|
# Status Research & Development GmbH
|
|
|
|
#
|
|
|
|
# Licensed under either of
|
|
|
|
# Apache License, version 2.0, (LICENSE-APACHEv2)
|
|
|
|
# MIT license (LICENSE-MIT)
|
|
|
|
|
|
|
|
## This module implements TLS stream reading and writing.
|
|
|
|
import bearssl, bearssl/cacert
|
|
|
|
import ../asyncloop, ../timer, ../asyncsync
|
|
|
|
import asyncstream, ../transports/stream, ../transports/common
|
2019-10-08 17:30:43 +00:00
|
|
|
import strutils
|
2019-10-08 15:46:27 +00:00
|
|
|
|
|
|
|
type
  TLSStreamKind {.pure.} = enum ## Role of the local side of a TLS stream.
    Client,                     ## Stream is driven by a client context.
    Server                      ## Stream is driven by a server context.

  TLSVersion* {.pure.} = enum ## Protocol versions, as numeric version codes.
    TLS10 = 0x0301,           ## TLS 1.0
    TLS11 = 0x0302,           ## TLS 1.1
    TLS12 = 0x0303            ## TLS 1.2

  TLSFlags* {.pure.} = enum ## Optional switches for a TLS connection.
    NoVerifyHost,           ## Client: Skip remote certificate check
    NoVerifySN,             ## Client: Skip Server Name Indication (SNI) check
    NoRenegotiation,        ## Server: Reject renegotiations requests
    NoClientAuth,           ## Server: Disable strict client authentication
    FailOnAlpnMismatch      ## Server: Fail on application protocol mismatch
|
|
|
|
|
|
|
|
type
  TlsStreamWriter* = ref object of AsyncStreamWriter
    ## Writing half of a TLS stream.
    stream*: TlsAsyncStream ## Owning TLS stream; reset to `nil` when the
                            ## write loop exits.

  TlsStreamReader* = ref object of AsyncStreamReader
    ## Reading half of a TLS stream. It also carries the pointer to the
    ## BearSSL context which the read loop uses to reach the engine.
    case kind: TlsStreamKind
    of TlsStreamKind.Client:
      ccontext: ptr SslClientContext ## Client context (client streams only).
    of TlsStreamKind.Server:
      scontext: ptr SslServerContext ## Server context (server streams only).
    stream*: TlsAsyncStream ## Owning TLS stream; reset to `nil` when the
                            ## read loop exits.

  TlsAsyncStream* = ref object of RootRef
    ## TLS stream: BearSSL state plus the reader/writer pair built on it.
    xwc*: X509NoAnchorContext  ## No-anchor X.509 wrapper, initialised only
                               ## when `TLSFlags.NoVerifyHost` is requested.
    context*: SslClientContext ## BearSSL client context.
    sbuffer*: seq[byte]        ## I/O buffer handed to the BearSSL engine.
    x509*: X509MinimalContext  ## X.509 validation context.
    reader*: TlsStreamReader   ## Decrypted (application data) reader.
    writer*: TlsStreamWriter   ## Application data writer.

  TlsStreamError* = object of CatchableError
    ## Base type for all errors raised by this module.

  TlsStreamProtocolError* = object of TlsStreamError
    ## TLS protocol failure.
    errCode*: int ## Numeric error code, `0` when the error was created
                  ## from a plain text message.
|
|
|
|
|
|
|
|
template newTlsStreamProtocolError[T](message: T): ref Exception =
  ## Creates a new ``TlsStreamProtocolError``.
  ##
  ## * ``string`` argument: used verbatim as the error message, ``errCode``
  ##   stays ``0``.
  ## * ``cint``/``int`` argument: treated as an error code; the message is
  ##   ``sslErrorMsg(code)`` followed by ``" (code: N)"`` and ``errCode`` is
  ##   set to that code.
  ## * any other type: the message is ``"Internal Error"``.
  var
    text = ""
    errorCode = 0
  when T is string:
    text.add(message)
  elif (T is cint) or (T is int):
    # Both integer flavours produce the same message and error code.
    text.add(sslErrorMsg(message) & " (code: " & $int(message) & ")")
    errorCode = int(message)
  else:
    text.add("Internal Error")
  let protoError = newException(TlsStreamProtocolError, text)
  protoError.errCode = errorCode
  protoError
|
|
|
|
|
|
|
|
# proc raiseTlsStreamProtoError*[T](message: T) =
|
|
|
|
# raise newTlsStreamProtocolError(message)
|
|
|
|
|
|
|
|
# proc getStringState*(state: cuint): string =
|
|
|
|
# var n = newSeq[string]()
|
|
|
|
# if (state and SSL_CLOSED) == SSL_CLOSED:
|
|
|
|
# n.add("Closed")
|
|
|
|
# if (state and SSL_SENDREC) == SSL_SENDREC:
|
|
|
|
# n.add("SendRec")
|
|
|
|
# if (state and SSL_RECVREC) == SSL_RECVREC:
|
|
|
|
# n.add("RecvRec")
|
|
|
|
# if (state and SSL_SENDAPP) == SSL_SENDAPP:
|
|
|
|
# n.add("SendApp")
|
|
|
|
# if (state and SSL_RECVAPP) == SSL_RECVAPP:
|
|
|
|
# n.add("RecvApp")
|
|
|
|
# result = "{" & n.join(", ") & "} number (" & $state & ")"
|
|
|
|
|
|
|
|
proc tlsWriteLoop(stream: AsyncStreamWriter) {.async.} =
  ## Placeholder write loop of a TLS stream.
  ##
  ## All TLS processing (including sending) happens in ``tlsReadLoop()``,
  ## so this loop only parks itself until it is cancelled, and then drops
  ## the writer's reference to the owning TLS stream.
  let tlsWriter = cast[TlsStreamWriter](stream)
  try:
    # Wait on a future that is never completed; the wait ends only when the
    # stream is closed, which cancels this loop.
    let neverCompleted = newFuture[void]()
    await neverCompleted
  except CancelledError:
    discard
  finally:
    tlsWriter.stream = nil
|
2019-10-08 15:46:27 +00:00
|
|
|
|
|
|
|
proc tlsReadLoop(stream: AsyncStreamReader) {.async.} =
  ## Main processing loop of a TLS stream.
  ##
  ## The loop polls the BearSSL engine state and serves, in this order:
  ##
  ## * ``SSL_CLOSED``  - stop; the reader state becomes ``Error`` when the
  ##   engine reports a non-zero last error, ``Stopped`` otherwise.
  ## * ``SSL_SENDREC`` - write pending TLS record bytes to the writer's
  ##   source stream.
  ## * ``SSL_SENDAPP`` - move queued application data into the engine; a
  ##   queued item of zero size finishes the stream.
  ## * ``SSL_RECVREC`` - read TLS record bytes from the reader's source
  ##   stream; a read of zero bytes finishes the stream.
  ## * ``SSL_RECVAPP`` - copy decrypted application data into the reader's
  ##   buffer.
  ##
  ## On exit the engine is closed, the writer state is synchronised with the
  ## reader state and pending write requests are completed (``Stopped``) or
  ## failed (``Error``).
  let rstream = cast[TlsStreamReader](stream)
  let wstream = rstream.stream.writer
  let engine =
    if rstream.kind == TlsStreamKind.Server:
      addr(rstream.scontext.eng)
    else:
      addr(rstream.ccontext.eng)

  try:
    var length: uint
    while true:
      let state = engine.sslEngineCurrentState()

      if (state and SSL_CLOSED) == SSL_CLOSED:
        let err = engine.sslEngineLastError()
        if err != 0:
          rstream.error = newTlsStreamProtocolError(err)
          rstream.state = AsyncStreamState.Error
        else:
          rstream.state = AsyncStreamState.Stopped
        break

      if (state and SSL_SENDREC) == SSL_SENDREC:
        # TLS record needs to be sent over stream.
        length = 0'u
        let buf = sslEngineSendrecBuf(engine, length)
        doAssert(length != 0 and not isNil(buf))
        await wstream.wsource.write(buf, int(length))
        sslEngineSendrecAck(engine, length)
        continue

      if (state and SSL_SENDAPP) == SSL_SENDAPP:
        # Application data can be sent over stream.
        if len(wstream.queue) > 0:
          var item = await wstream.queue.get()
          if item.size > 0:
            length = 0'u
            let buf = sslEngineSendappBuf(engine, length)
            # Never copy more than the engine accepts or the item holds.
            let toWrite = min(int(length), item.size)

            if item.kind == Pointer:
              let p = cast[pointer](cast[uint](item.data1) +
                                    uint(item.offset))
              copyMem(buf, p, toWrite)
            elif item.kind == Sequence:
              copyMem(buf, addr item.data2[item.offset], toWrite)
            elif item.kind == String:
              copyMem(buf, addr item.data3[item.offset], toWrite)
            sslEngineSendappAck(engine, uint(toWrite))

            if toWrite == item.size:
              # Whole item was consumed: flush the record and report
              # completion to the writer.
              sslEngineFlush(engine, 0)
              item.future.complete()
            else:
              # Only a part fits: put the remainder back at the head of the
              # queue so it is continued on the next iteration.
              item.offset = item.offset + toWrite
              item.size = item.size - toWrite
              wstream.queue.addFirstNoWait(item)
            # Engine state changed, so it must be queried again.
            continue
          else:
            # Zero length item means finish
            rstream.state = AsyncStreamState.Finished
            break

      if (state and SSL_RECVREC) == SSL_RECVREC:
        # TLS records required for further processing
        length = 0'u
        let buf = sslEngineRecvrecBuf(engine, length)
        let res = await rstream.rsource.readOnce(buf, int(length))
        if res > 0:
          sslEngineRecvrecAck(engine, uint(res))
          continue
        else:
          rstream.state = AsyncStreamState.Finished
          break

      if (state and SSL_RECVAPP) == SSL_RECVAPP:
        # Application data can be recovered.
        length = 0'u
        let buf = sslEngineRecvappBuf(engine, length)
        let toRead = min(int(length), rstream.buffer.bufferLen())
        copyMem(rstream.buffer.getBuffer(), buf, toRead)
        rstream.buffer.update(toRead)
        sslEngineRecvappAck(engine, uint(toRead))
        await rstream.buffer.transfer()
        continue

  except CancelledError:
    rstream.state = AsyncStreamState.Stopped

  finally:
    # Perform TLS cleanup procedure
    sslEngineClose(engine)
    # Because tlsWriteLoop() is ephemeral, the writer state has to be kept
    # consistent with the reader state from here.
    wstream.state = rstream.state
    let finalState = rstream.state
    if finalState == AsyncStreamState.Finished:
      rstream.buffer.forget()
    elif finalState == AsyncStreamState.Stopped or
         finalState == AsyncStreamState.Error:
      rstream.buffer.forget()
      # Release every writer still waiting in the queue: successfully when
      # the stream was stopped, with the stored error otherwise.
      while len(wstream.queue) > 0:
        let item = wstream.queue.popFirstNoWait()
        if not(item.future.finished()):
          if finalState == AsyncStreamState.Error:
            item.future.fail(rstream.error)
          else:
            item.future.complete()
    rstream.stream = nil
|
2019-10-08 15:46:27 +00:00
|
|
|
|
|
|
|
proc newTlsClientAsyncStream*(rsource: AsyncStreamReader,
                              wsource: AsyncStreamWriter,
                              serverName: string = "",
                              bufferSize = SSL_BUFSIZE_BIDI,
                              minVersion = TLSVersion.TLS11,
                              maxVersion = TLSVersion.TLS12,
                              flags: set[TlsFlags] = {}): TlsAsyncStream =
  ## Create a new client-side TLS asynchronous stream which reads encrypted
  ## data from ``rsource`` and writes encrypted data to ``wsource``.
  ##
  ## ``serverName`` is the expected name of the remote server. If the server
  ## reports a different name during the handshake you will get an error.
  ## If ``serverName`` is an empty string, remote server name checking will
  ## be disabled.
  ##
  ## ``bufferSize`` is the size of the SSL/TLS buffer used for
  ## encoding/decoding data; the engine buffer is never smaller than
  ## ``SSL_BUFSIZE_BIDI``.
  ##
  ## ``minVersion`` and ``maxVersion`` are the TLS versions which will be
  ## used for the handshake with the remote server. If the server's version
  ## is lower than ``minVersion`` or bigger than ``maxVersion`` you will get
  ## an error.
  ##
  ## ``flags`` - custom TLS connection flags:
  ##
  ## * ``TLSFlags.NoVerifyHost`` - no trust anchors are installed and the
  ##   no-anchor X.509 wrapper is used; otherwise ``MozillaTrustAnchors``
  ##   are used.
  ## * ``TLSFlags.NoVerifySN`` - ``serverName`` is ignored and an empty
  ##   name is passed to the TLS layer.
  ##
  ## Raises ``TlsStreamError`` if the TLS layer could not be initialized.
  result = new TlsAsyncStream
  let
    clientContext = addr result.context
    engine = addr result.context.eng

  # Reader and writer both keep a reference to the stream that owns them.
  result.reader = TlsStreamReader(kind: TlsStreamKind.Client,
                                  ccontext: clientContext,
                                  stream: result)
  result.writer = TlsStreamWriter(stream: result)

  if TLSFlags.NoVerifyHost in flags:
    sslClientInitFull(clientContext, addr result.x509, nil, 0)
    initNoAnchor(addr result.xwc, addr result.x509.vtable)
    sslEngineSetX509(engine, addr result.xwc.vtable)
  else:
    sslClientInitFull(clientContext, addr result.x509,
                      unsafeAddr MozillaTrustAnchors[0],
                      len(MozillaTrustAnchors))

  result.sbuffer = newSeq[byte](max(SSL_BUFSIZE_BIDI, bufferSize))
  sslEngineSetBuffer(engine, addr result.sbuffer[0],
                     uint(len(result.sbuffer)), 1)
  sslEngineSetVersions(engine, uint16(minVersion), uint16(maxVersion))

  # With ``NoVerifySN`` the supplied server name is replaced by an empty one.
  let expectedName =
    if TLSFlags.NoVerifySN in flags: "" else: serverName
  if sslClientReset(clientContext, expectedName, 0) == 0:
    raise newException(TlsStreamError, "Could not initialize TLS layer")

  init(cast[AsyncStreamWriter](result.writer), wsource, tlsWriteLoop,
       bufferSize)
  init(cast[AsyncStreamReader](result.reader), rsource, tlsReadLoop,
       bufferSize)
|