mirror of
https://github.com/status-im/nim-chronos.git
synced 2025-01-31 05:25:09 +00:00
commit
a291f26c82
@ -1,5 +1,5 @@
|
||||
packageName = "chronos"
|
||||
version = "2.3.0"
|
||||
version = "2.3.1"
|
||||
author = "Status Research & Development GmbH"
|
||||
description = "Chronos"
|
||||
license = "Apache License 2.0 or MIT"
|
||||
@ -7,7 +7,8 @@ skipDirs = @["tests"]
|
||||
|
||||
### Dependencies
|
||||
|
||||
requires "nim > 0.18.0"
|
||||
requires "nim > 0.19.4",
|
||||
"bearssl"
|
||||
|
||||
task test, "Run all tests":
|
||||
var commands = [
|
||||
|
@ -254,15 +254,19 @@ proc cancel[T](future: Future[T], loc: ptr SrcLoc) =
|
||||
else:
|
||||
var first = FutureBase(future)
|
||||
var last = first
|
||||
while (not isNil(last.child)) and (not(last.child.finished())):
|
||||
while not(isNil(last.child)) and not(last.child.cancelled()):
|
||||
last = last.child
|
||||
if last == first:
|
||||
checkFinished(future, loc)
|
||||
let isPending = (last.state == FutureState.Pending)
|
||||
last.state = FutureState.Cancelled
|
||||
last.error = newException(CancelledError, "")
|
||||
if not(isNil(last.cancelcb)):
|
||||
last.cancelcb(cast[pointer](last))
|
||||
last.callbacks.call()
|
||||
if isPending:
|
||||
# If Future's state was `Finished` or `Failed` callbacks are already
|
||||
# scheduled.
|
||||
last.callbacks.call()
|
||||
|
||||
template cancel*[T](future: Future[T]) =
|
||||
## Cancel ``future``.
|
||||
|
@ -141,12 +141,37 @@ proc copyData*(sb: AsyncBuffer, dest: pointer, offset, length: int) {.inline.} =
|
||||
copyMem(cast[pointer](cast[uint](dest) + cast[uint](offset)),
|
||||
unsafeAddr sb.buffer[0], length)
|
||||
|
||||
proc upload*(sb: ptr AsyncBuffer, pbytes: ptr byte,
|
||||
nbytes: int): Future[void] {.async.} =
|
||||
var length = nbytes
|
||||
while length > 0:
|
||||
let size = min(length, sb[].bufferLen())
|
||||
if size == 0:
|
||||
# Internal buffer is full, we need to transfer data to consumer.
|
||||
await sb[].transfer()
|
||||
continue
|
||||
else:
|
||||
copyMem(addr sb[].buffer[sb.offset], pbytes, size)
|
||||
sb[].offset = sb[].offset + size
|
||||
length = length - size
|
||||
# We notify consumers that new data is available.
|
||||
sb[].forget()
|
||||
|
||||
template toDataOpenArray*(sb: AsyncBuffer): auto =
|
||||
toOpenArray(sb.buffer, 0, sb.offset - 1)
|
||||
|
||||
template toBufferOpenArray*(sb: AsyncBuffer): auto =
|
||||
toOpenArray(sb.buffer, sb.offset, len(sb.buffer) - 1)
|
||||
|
||||
template copyOut*(dest: pointer, item: WriteItem, length: int) =
|
||||
if item.kind == Pointer:
|
||||
let p = cast[pointer](cast[uint](item.data1) + uint(item.offset))
|
||||
copyMem(dest, p, length)
|
||||
elif item.kind == Sequence:
|
||||
copyMem(dest, unsafeAddr item.data2[item.offset], length)
|
||||
elif item.kind == String:
|
||||
copyMem(dest, unsafeAddr item.data3[item.offset], length)
|
||||
|
||||
proc newAsyncStreamReadError(p: ref Exception): ref Exception {.inline.} =
|
||||
var w = newException(AsyncStreamReadError, "Read stream failed")
|
||||
w.msg = w.msg & ", originated from [" & $p.name & "] " & p.msg
|
||||
@ -270,6 +295,8 @@ proc readExactly*(rstream: AsyncStreamReader, pbytes: pointer,
|
||||
if isNil(rstream.rsource):
|
||||
try:
|
||||
await readExactly(rstream.tsource, pbytes, nbytes)
|
||||
except CancelledError:
|
||||
raise
|
||||
except TransportIncompleteError:
|
||||
raise newAsyncStreamIncompleteError()
|
||||
except:
|
||||
@ -308,10 +335,12 @@ proc readOnce*(rstream: AsyncStreamReader, pbytes: pointer,
|
||||
if isNil(rstream.rsource):
|
||||
try:
|
||||
result = await readOnce(rstream.tsource, pbytes, nbytes)
|
||||
except CancelledError:
|
||||
raise
|
||||
except:
|
||||
raise newAsyncStreamReadError(getCurrentException())
|
||||
else:
|
||||
if isNil(rstream.rsource):
|
||||
if isNil(rstream.readerLoop):
|
||||
result = await readOnce(rstream.rsource, pbytes, nbytes)
|
||||
else:
|
||||
while true:
|
||||
@ -351,6 +380,8 @@ proc readUntil*(rstream: AsyncStreamReader, pbytes: pointer, nbytes: int,
|
||||
if isNil(rstream.rsource):
|
||||
try:
|
||||
result = await readUntil(rstream.tsource, pbytes, nbytes, sep)
|
||||
except CancelledError:
|
||||
raise
|
||||
except TransportIncompleteError:
|
||||
raise newAsyncStreamIncompleteError()
|
||||
except TransportLimitError:
|
||||
@ -416,6 +447,8 @@ proc readLine*(rstream: AsyncStreamReader, limit = 0,
|
||||
if isNil(rstream.rsource):
|
||||
try:
|
||||
result = await readLine(rstream.tsource, limit, sep)
|
||||
except CancelledError:
|
||||
raise
|
||||
except:
|
||||
raise newAsyncStreamReadError(getCurrentException())
|
||||
else:
|
||||
@ -469,6 +502,8 @@ proc read*(rstream: AsyncStreamReader, n = 0): Future[seq[byte]] {.async.} =
|
||||
if isNil(rstream.rsource):
|
||||
try:
|
||||
result = await read(rstream.tsource, n)
|
||||
except CancelledError:
|
||||
raise
|
||||
except:
|
||||
raise newAsyncStreamReadError(getCurrentException())
|
||||
else:
|
||||
@ -517,6 +552,8 @@ proc consume*(rstream: AsyncStreamReader, n = -1): Future[int] {.async.} =
|
||||
if isNil(rstream.rsource):
|
||||
try:
|
||||
result = await consume(rstream.tsource, n)
|
||||
except CancelledError:
|
||||
raise
|
||||
except TransportLimitError:
|
||||
raise newAsyncStreamLimitError()
|
||||
except:
|
||||
@ -569,6 +606,8 @@ proc write*(wstream: AsyncStreamWriter, pbytes: pointer,
|
||||
var res: int
|
||||
try:
|
||||
res = await write(wstream.tsource, pbytes, nbytes)
|
||||
except CancelledError:
|
||||
raise
|
||||
except:
|
||||
raise newAsyncStreamWriteError(getCurrentException())
|
||||
if res != nbytes:
|
||||
@ -584,6 +623,8 @@ proc write*(wstream: AsyncStreamWriter, pbytes: pointer,
|
||||
await wstream.queue.put(item)
|
||||
try:
|
||||
await item.future
|
||||
except CancelledError:
|
||||
raise
|
||||
except:
|
||||
raise newAsyncStreamWriteError(item.future.error)
|
||||
|
||||
@ -608,6 +649,8 @@ proc write*(wstream: AsyncStreamWriter, sbytes: seq[byte],
|
||||
var res: int
|
||||
try:
|
||||
res = await write(wstream.tsource, sbytes, msglen)
|
||||
except CancelledError:
|
||||
raise
|
||||
except:
|
||||
raise newAsyncStreamWriteError(getCurrentException())
|
||||
if res != length:
|
||||
@ -626,6 +669,8 @@ proc write*(wstream: AsyncStreamWriter, sbytes: seq[byte],
|
||||
await wstream.queue.put(item)
|
||||
try:
|
||||
await item.future
|
||||
except CancelledError:
|
||||
raise
|
||||
except:
|
||||
raise newAsyncStreamWriteError(item.future.error)
|
||||
|
||||
@ -649,6 +694,8 @@ proc write*(wstream: AsyncStreamWriter, sbytes: string,
|
||||
var res: int
|
||||
try:
|
||||
res = await write(wstream.tsource, sbytes, msglen)
|
||||
except CancelledError:
|
||||
raise
|
||||
except:
|
||||
raise newAsyncStreamWriteError(getCurrentException())
|
||||
if res != length:
|
||||
@ -667,6 +714,8 @@ proc write*(wstream: AsyncStreamWriter, sbytes: string,
|
||||
await wstream.queue.put(item)
|
||||
try:
|
||||
await item.future
|
||||
except CancelledError:
|
||||
raise
|
||||
except:
|
||||
raise newAsyncStreamWriteError(item.future.error)
|
||||
|
||||
@ -685,6 +734,8 @@ proc finish*(wstream: AsyncStreamWriter) {.async.} =
|
||||
await wstream.queue.put(item)
|
||||
try:
|
||||
await item.future
|
||||
except CancelledError:
|
||||
raise
|
||||
except:
|
||||
raise newAsyncStreamWriteError(item.future.error)
|
||||
|
||||
@ -735,8 +786,8 @@ proc close*(rw: AsyncStreamRW) =
|
||||
if rw.future.finished():
|
||||
callSoon(continuation)
|
||||
else:
|
||||
rw.future.cancel()
|
||||
rw.future.addCallback(continuation)
|
||||
rw.future.cancel()
|
||||
elif rw is AsyncStreamWriter:
|
||||
if isNil(rw.wsource) or isNil(rw.writerLoop) or isNil(rw.future):
|
||||
callSoon(continuation)
|
||||
@ -744,8 +795,8 @@ proc close*(rw: AsyncStreamRW) =
|
||||
if rw.future.finished():
|
||||
callSoon(continuation)
|
||||
else:
|
||||
rw.future.cancel()
|
||||
rw.future.addCallback(continuation)
|
||||
rw.future.cancel()
|
||||
|
||||
proc closeWait*(rw: AsyncStreamRW): Future[void] =
|
||||
## Close and frees resources of stream ``rw``.
|
||||
|
618
chronos/streams/tlsstream.nim
Normal file
618
chronos/streams/tlsstream.nim
Normal file
@ -0,0 +1,618 @@
|
||||
#
|
||||
# Chronos Asynchronous TLS Stream
|
||||
# (c) Copyright 2019-Present
|
||||
# Status Research & Development GmbH
|
||||
#
|
||||
# Licensed under either of
|
||||
# Apache License, version 2.0, (LICENSE-APACHEv2)
|
||||
# MIT license (LICENSE-MIT)
|
||||
|
||||
## This module implements Transport Layer Security (TLS) stream. This module
|
||||
## uses sources of BearSSL <https://www.bearssl.org> by Thomas Pornin.
|
||||
import bearssl, bearssl/cacert
|
||||
import ../asyncloop, ../timer, ../asyncsync
|
||||
import asyncstream, ../transports/stream, ../transports/common
|
||||
|
||||
type
|
||||
TLSStreamKind {.pure.} = enum
|
||||
Client, Server
|
||||
|
||||
TLSVersion* {.pure.} = enum
|
||||
TLS10 = 0x0301, TLS11 = 0x0302, TLS12 = 0x0303
|
||||
|
||||
TLSFlags* {.pure.} = enum
|
||||
NoVerifyHost, # Client: Skip remote certificate check
|
||||
NoVerifyServerName, # Client: Skip Server Name Indication (SNI) check
|
||||
EnforceServerPref, # Server: Enforce server preferences
|
||||
NoRenegotiation, # Server: Reject renegotiations requests
|
||||
TolerateNoClientAuth, # Server: Disable strict client authentication
|
||||
FailOnAlpnMismatch # Server: Fail on application protocol mismatch
|
||||
|
||||
TLSKeyType {.pure.} = enum
|
||||
RSA, EC
|
||||
|
||||
TLSPrivateKey* = ref object
|
||||
case kind: TLSKeyType
|
||||
of RSA:
|
||||
rsakey: RsaPrivateKey
|
||||
of EC:
|
||||
eckey: EcPrivateKey
|
||||
storage: seq[byte]
|
||||
|
||||
TLSCertificate* = ref object
|
||||
certs: seq[X509Certificate]
|
||||
storage: seq[byte]
|
||||
|
||||
TLSSessionCache* = ref object
|
||||
storage: seq[byte]
|
||||
context: SslSessionCacheLru
|
||||
|
||||
PEMElement* = object
|
||||
name*: string
|
||||
data*: seq[byte]
|
||||
|
||||
PEMContext = ref object
|
||||
data: seq[byte]
|
||||
|
||||
TLSStreamWriter* = ref object of AsyncStreamWriter
|
||||
case kind: TLSStreamKind
|
||||
of TLSStreamKind.Client:
|
||||
ccontext: ptr SslClientContext
|
||||
of TLSStreamKind.Server:
|
||||
scontext: ptr SslServerContext
|
||||
stream*: TLSAsyncStream
|
||||
switchToReader*: AsyncEvent
|
||||
switchToWriter*: AsyncEvent
|
||||
handshaked*: bool
|
||||
handshakeFut*: Future[void]
|
||||
|
||||
TLSStreamReader* = ref object of AsyncStreamReader
|
||||
case kind: TLSStreamKind
|
||||
of TLSStreamKind.Client:
|
||||
ccontext: ptr SslClientContext
|
||||
of TLSStreamKind.Server:
|
||||
scontext: ptr SslServerContext
|
||||
stream*: TLSAsyncStream
|
||||
switchToReader*: AsyncEvent
|
||||
switchToWriter*: AsyncEvent
|
||||
handshaked*: bool
|
||||
handshakeFut*: Future[void]
|
||||
|
||||
TLSAsyncStream* = ref object of RootRef
|
||||
xwc*: X509NoAnchorContext
|
||||
ccontext*: SslClientContext
|
||||
scontext*: SslServerContext
|
||||
sbuffer*: seq[byte]
|
||||
x509*: X509MinimalContext
|
||||
reader*: TLSStreamReader
|
||||
writer*: TLSStreamWriter
|
||||
|
||||
SomeTLSStreamType* = TLSStreamReader|TLSStreamWriter|TLSAsyncStream
|
||||
|
||||
TLSStreamError* = object of CatchableError
|
||||
TLSStreamProtocolError* = object of TLSStreamError
|
||||
errCode*: int
|
||||
|
||||
template newTLSStreamProtocolError[T](message: T): ref Exception =
|
||||
var msg = ""
|
||||
var code = 0
|
||||
when T is string:
|
||||
msg.add(message)
|
||||
elif T is cint:
|
||||
msg.add(sslErrorMsg(message) & " (code: " & $int(message) & ")")
|
||||
code = int(message)
|
||||
elif T is int:
|
||||
msg.add(sslErrorMsg(message) & " (code: " & $message & ")")
|
||||
code = message
|
||||
else:
|
||||
msg.add("Internal Error")
|
||||
var err = newException(TLSStreamProtocolError, msg)
|
||||
err.errCode = code
|
||||
err
|
||||
|
||||
proc raiseTLSStreamProtoError*[T](message: T) =
|
||||
raise newTLSStreamProtocolError(message)
|
||||
|
||||
proc tlsWriteLoop(stream: AsyncStreamWriter) {.async.} =
|
||||
var wstream = cast[TLSStreamWriter](stream)
|
||||
var engine: ptr SslEngineContext
|
||||
var error: ref Exception
|
||||
|
||||
if wstream.kind == TLSStreamKind.Server:
|
||||
engine = addr wstream.scontext.eng
|
||||
else:
|
||||
engine = addr wstream.ccontext.eng
|
||||
|
||||
wstream.state = AsyncStreamState.Running
|
||||
|
||||
try:
|
||||
var length: uint
|
||||
while true:
|
||||
var state = engine.sslEngineCurrentState()
|
||||
|
||||
if (state and SSL_CLOSED) == SSL_CLOSED:
|
||||
wstream.state = AsyncStreamState.Finished
|
||||
break
|
||||
|
||||
if (state and (SSL_RECVREC or SSL_RECVAPP)) != 0:
|
||||
if not(wstream.switchToReader.isSet()):
|
||||
wstream.switchToReader.fire()
|
||||
|
||||
if (state and (SSL_SENDREC or SSL_SENDAPP)) == 0:
|
||||
await wstream.switchToWriter.wait()
|
||||
wstream.switchToWriter.clear()
|
||||
# We need to refresh `state` because we just returned from readerLoop.
|
||||
continue
|
||||
|
||||
if (state and SSL_SENDREC) == SSL_SENDREC:
|
||||
# TLS record needs to be sent over stream.
|
||||
length = 0'u
|
||||
var buf = sslEngineSendrecBuf(engine, length)
|
||||
doAssert(length != 0 and not isNil(buf))
|
||||
var fut = awaitne wstream.wsource.write(buf, int(length))
|
||||
if fut.cancelled():
|
||||
raise fut.error
|
||||
elif fut.failed():
|
||||
error = fut.error
|
||||
break
|
||||
sslEngineSendrecAck(engine, length)
|
||||
continue
|
||||
|
||||
if (state and SSL_SENDAPP) == SSL_SENDAPP:
|
||||
# Application data can be sent over stream.
|
||||
if not(wstream.handshaked):
|
||||
wstream.stream.reader.handshaked = true
|
||||
wstream.handshaked = true
|
||||
if not(isNil(wstream.handshakeFut)):
|
||||
wstream.handshakeFut.complete()
|
||||
|
||||
var item = await wstream.queue.get()
|
||||
if item.size > 0:
|
||||
length = 0'u
|
||||
var buf = sslEngineSendappBuf(engine, length)
|
||||
let toWrite = min(int(length), item.size)
|
||||
copyOut(buf, item, toWrite)
|
||||
if int(length) >= item.size:
|
||||
# BearSSL is ready to accept whole item size.
|
||||
sslEngineSendappAck(engine, uint(item.size))
|
||||
sslEngineFlush(engine, 0)
|
||||
item.future.complete()
|
||||
else:
|
||||
# BearSSL is not ready to accept whole item, so we will send only
|
||||
# part of item and adjust offset.
|
||||
item.offset = item.offset + int(length)
|
||||
item.size = item.size - int(length)
|
||||
wstream.queue.addFirstNoWait(item)
|
||||
sslEngineSendappAck(engine, length)
|
||||
continue
|
||||
else:
|
||||
# Zero length item means finish
|
||||
wstream.state = AsyncStreamState.Finished
|
||||
break
|
||||
|
||||
except CancelledError:
|
||||
wstream.state = AsyncStreamState.Stopped
|
||||
|
||||
finally:
|
||||
if wstream.state == AsyncStreamState.Stopped:
|
||||
while len(wstream.queue) > 0:
|
||||
let item = wstream.queue.popFirstNoWait()
|
||||
if not(item.future.finished()):
|
||||
item.future.complete()
|
||||
elif wstream.state == AsyncStreamState.Error:
|
||||
while len(wstream.queue) > 0:
|
||||
let item = wstream.queue.popFirstNoWait()
|
||||
if not(item.future.finished()):
|
||||
item.future.fail(error)
|
||||
wstream.stream = nil
|
||||
|
||||
proc tlsReadLoop(stream: AsyncStreamReader) {.async.} =
|
||||
var rstream = cast[TLSStreamReader](stream)
|
||||
var engine: ptr SslEngineContext
|
||||
|
||||
if rstream.kind == TLSStreamKind.Server:
|
||||
engine = addr rstream.scontext.eng
|
||||
else:
|
||||
engine = addr rstream.ccontext.eng
|
||||
|
||||
rstream.state = AsyncStreamState.Running
|
||||
|
||||
try:
|
||||
var length: uint
|
||||
while true:
|
||||
var state = engine.sslEngineCurrentState()
|
||||
if (state and SSL_CLOSED) == SSL_CLOSED:
|
||||
let err = engine.sslEngineLastError()
|
||||
if err != 0:
|
||||
raise newTLSStreamProtocolError(err)
|
||||
rstream.state = AsyncStreamState.Stopped
|
||||
break
|
||||
|
||||
if (state and (SSL_SENDREC or SSL_SENDAPP)) != 0:
|
||||
if not(rstream.switchToWriter.isSet()):
|
||||
rstream.switchToWriter.fire()
|
||||
|
||||
if (state and (SSL_RECVREC or SSL_RECVAPP)) == 0:
|
||||
await rstream.switchToReader.wait()
|
||||
rstream.switchToReader.clear()
|
||||
# We need to refresh `state` because we just returned from writerLoop.
|
||||
continue
|
||||
|
||||
if (state and SSL_RECVREC) == SSL_RECVREC:
|
||||
# TLS records required for further processing
|
||||
length = 0'u
|
||||
var buf = sslEngineRecvrecBuf(engine, length)
|
||||
let res = await rstream.rsource.readOnce(buf, int(length))
|
||||
if res > 0:
|
||||
sslEngineRecvrecAck(engine, uint(res))
|
||||
continue
|
||||
else:
|
||||
rstream.state = AsyncStreamState.Finished
|
||||
break
|
||||
|
||||
if (state and SSL_RECVAPP) == SSL_RECVAPP:
|
||||
# Application data can be recovered.
|
||||
length = 0'u
|
||||
var buf = sslEngineRecvappBuf(engine, length)
|
||||
await upload(addr rstream.buffer, buf, int(length))
|
||||
sslEngineRecvappAck(engine, length)
|
||||
continue
|
||||
|
||||
except CancelledError:
|
||||
rstream.state = AsyncStreamState.Stopped
|
||||
except AsyncStreamReadError:
|
||||
rstream.error = getCurrentException()
|
||||
rstream.state = AsyncStreamState.Error
|
||||
if not(rstream.handshaked):
|
||||
rstream.handshaked = true
|
||||
rstream.stream.writer.handshaked = true
|
||||
if not(isNil(rstream.handshakeFut)):
|
||||
rstream.handshakeFut.fail(rstream.error)
|
||||
rstream.switchToWriter.fire()
|
||||
finally:
|
||||
# Perform TLS cleanup procedure
|
||||
sslEngineClose(engine)
|
||||
rstream.buffer.forget()
|
||||
rstream.stream = nil
|
||||
|
||||
proc getSignerAlgo(xc: X509Certificate): int =
|
||||
## Get certificate's signing algorithm.
|
||||
var dc: X509DecoderContext
|
||||
x509DecoderInit(addr dc, nil, nil)
|
||||
x509DecoderPush(addr dc, xc.data, xc.dataLen)
|
||||
let err = x509DecoderLastError(addr dc)
|
||||
if err != 0:
|
||||
result = -1
|
||||
else:
|
||||
result = int(x509DecoderGetSignerKeyType(addr dc))
|
||||
|
||||
proc newTLSClientAsyncStream*(rsource: AsyncStreamReader,
|
||||
wsource: AsyncStreamWriter,
|
||||
serverName: string,
|
||||
bufferSize = SSL_BUFSIZE_BIDI,
|
||||
minVersion = TLSVersion.TLS11,
|
||||
maxVersion = TLSVersion.TLS12,
|
||||
flags: set[TLSFlags] = {}): TLSAsyncStream =
|
||||
## Create new TLS asynchronous stream for outbound (client) connections
|
||||
## using reading stream ``rsource`` and writing stream ``wsource``.
|
||||
##
|
||||
## You can specify remote server name using ``serverName``, if while
|
||||
## handshake server reports different name you will get an error. If
|
||||
## ``serverName`` is empty string, remote server name checking will be
|
||||
## disabled.
|
||||
##
|
||||
## ``bufferSize`` - is SSL/TLS buffer which is used for encoding/decoding
|
||||
## incoming data.
|
||||
##
|
||||
## ``minVersion`` and ``maxVersion`` are TLS versions which will be used
|
||||
## for handshake with remote server. If server's version will be lower then
|
||||
## ``minVersion`` of bigger then ``maxVersion`` you will get an error.
|
||||
##
|
||||
## ``flags`` - custom TLS connection flags.
|
||||
result = new TLSAsyncStream
|
||||
var reader = new TLSStreamReader
|
||||
reader.kind = TLSStreamKind.Client
|
||||
var writer = new TLSStreamWriter
|
||||
writer.kind = TLSStreamKind.Client
|
||||
var switchToWriter = newAsyncEvent()
|
||||
var switchToReader = newAsyncEvent()
|
||||
reader.stream = result
|
||||
writer.stream = result
|
||||
reader.switchToReader = switchToReader
|
||||
reader.switchToWriter = switchToWriter
|
||||
writer.switchToReader = switchToReader
|
||||
writer.switchToWriter = switchToWriter
|
||||
result.reader = reader
|
||||
result.writer = writer
|
||||
reader.ccontext = addr result.ccontext
|
||||
writer.ccontext = addr result.ccontext
|
||||
|
||||
if TLSFlags.NoVerifyHost in flags:
|
||||
sslClientInitFull(addr result.ccontext, addr result.x509, nil, 0)
|
||||
initNoAnchor(addr result.xwc, addr result.x509.vtable)
|
||||
sslEngineSetX509(addr result.ccontext.eng, addr result.xwc.vtable)
|
||||
else:
|
||||
sslClientInitFull(addr result.ccontext, addr result.x509,
|
||||
unsafeAddr MozillaTrustAnchors[0],
|
||||
len(MozillaTrustAnchors))
|
||||
|
||||
let size = max(SSL_BUFSIZE_BIDI, bufferSize)
|
||||
result.sbuffer = newSeq[byte](size)
|
||||
sslEngineSetBuffer(addr result.ccontext.eng, addr result.sbuffer[0],
|
||||
uint(len(result.sbuffer)), 1)
|
||||
sslEngineSetVersions(addr result.ccontext.eng, uint16(minVersion),
|
||||
uint16(maxVersion))
|
||||
|
||||
if TLSFlags.NoVerifyServerName in flags:
|
||||
let err = sslClientReset(addr result.ccontext, "", 0)
|
||||
if err == 0:
|
||||
raise newException(TLSStreamError, "Could not initialize TLS layer")
|
||||
else:
|
||||
if len(serverName) == 0:
|
||||
raise newException(TLSStreamError, "serverName must not be empty string")
|
||||
|
||||
let err = sslClientReset(addr result.ccontext, serverName, 0)
|
||||
if err == 0:
|
||||
raise newException(TLSStreamError, "Could not initialize TLS layer")
|
||||
|
||||
init(cast[AsyncStreamWriter](result.writer), wsource, tlsWriteLoop,
|
||||
bufferSize)
|
||||
init(cast[AsyncStreamReader](result.reader), rsource, tlsReadLoop,
|
||||
bufferSize)
|
||||
|
||||
proc newTLSServerAsyncStream*(rsource: AsyncStreamReader,
|
||||
wsource: AsyncStreamWriter,
|
||||
privateKey: TLSPrivateKey,
|
||||
certificate: TLSCertificate,
|
||||
bufferSize = SSL_BUFSIZE_BIDI,
|
||||
minVersion = TLSVersion.TLS11,
|
||||
maxVersion = TLSVersion.TLS12,
|
||||
cache: TLSSessionCache = nil,
|
||||
flags: set[TLSFlags] = {}): TLSAsyncStream =
|
||||
## Create new TLS asynchronous stream for inbound (server) connections
|
||||
## using reading stream ``rsource`` and writing stream ``wsource``.
|
||||
##
|
||||
## You need to specify local private key ``privateKey`` and certificate
|
||||
## ``certificate``.
|
||||
##
|
||||
## ``bufferSize`` - is SSL/TLS buffer which is used for encoding/decoding
|
||||
## incoming data.
|
||||
##
|
||||
## ``minVersion`` and ``maxVersion`` are TLS versions which will be used
|
||||
## for handshake with remote server. If server's version will be lower then
|
||||
## ``minVersion`` of bigger then ``maxVersion`` you will get an error.
|
||||
##
|
||||
## ``flags`` - custom TLS connection flags.
|
||||
if isNil(privateKey) or privateKey.kind notin {TLSKeyType.RSA, TLSKeyType.EC}:
|
||||
raiseTLSStreamProtoError("Incorrect private key")
|
||||
if isNil(certificate) or len(certificate.certs) == 0:
|
||||
raiseTLSStreamProtoError("Incorrect certificate")
|
||||
|
||||
result = new TLSAsyncStream
|
||||
var reader = new TLSStreamReader
|
||||
reader.kind = TLSStreamKind.Server
|
||||
var writer = new TLSStreamWriter
|
||||
writer.kind = TLSStreamKind.Server
|
||||
var switchToWriter = newAsyncEvent()
|
||||
var switchToReader = newAsyncEvent()
|
||||
reader.stream = result
|
||||
writer.stream = result
|
||||
reader.switchToReader = switchToReader
|
||||
reader.switchToWriter = switchToWriter
|
||||
writer.switchToReader = switchToReader
|
||||
writer.switchToWriter = switchToWriter
|
||||
result.reader = reader
|
||||
result.writer = writer
|
||||
reader.scontext = addr result.scontext
|
||||
writer.scontext = addr result.scontext
|
||||
|
||||
if privateKey.kind == TLSKeyType.EC:
|
||||
let algo = getSignerAlgo(certificate.certs[0])
|
||||
if algo == -1:
|
||||
raiseTLSStreamProtoError("Could not decode certificate")
|
||||
sslServerInitFullEc(addr result.scontext, addr certificate.certs[0],
|
||||
len(certificate.certs), cuint(algo),
|
||||
addr privateKey.eckey)
|
||||
elif privateKey.kind == TLSKeyType.RSA:
|
||||
sslServerInitFullRsa(addr result.scontext, addr certificate.certs[0],
|
||||
len(certificate.certs), addr privateKey.rsakey)
|
||||
|
||||
let size = max(SSL_BUFSIZE_BIDI, bufferSize)
|
||||
result.sbuffer = newSeq[byte](size)
|
||||
sslEngineSetBuffer(addr result.scontext.eng, addr result.sbuffer[0],
|
||||
uint(len(result.sbuffer)), 1)
|
||||
sslEngineSetVersions(addr result.scontext.eng, uint16(minVersion),
|
||||
uint16(maxVersion))
|
||||
|
||||
if not isNil(cache):
|
||||
sslServerSetCache(addr result.scontext, addr cache.context.vtable)
|
||||
|
||||
if TLSFlags.EnforceServerPref in flags:
|
||||
sslEngineAddFlags(addr result.scontext.eng, OPT_ENFORCE_SERVER_PREFERENCES)
|
||||
if TLSFlags.NoRenegotiation in flags:
|
||||
sslEngineAddFlags(addr result.scontext.eng, OPT_NO_RENEGOTIATION)
|
||||
if TLSFlags.TolerateNoClientAuth in flags:
|
||||
sslEngineAddFlags(addr result.scontext.eng, OPT_TOLERATE_NO_CLIENT_AUTH)
|
||||
if TLSFlags.FailOnAlpnMismatch in flags:
|
||||
sslEngineAddFlags(addr result.scontext.eng, OPT_FAIL_ON_ALPN_MISMATCH)
|
||||
|
||||
let err = sslServerReset(addr result.scontext)
|
||||
if err == 0:
|
||||
raise newException(TLSStreamError, "Could not initialize TLS layer")
|
||||
|
||||
init(cast[AsyncStreamWriter](result.writer), wsource, tlsWriteLoop,
|
||||
bufferSize)
|
||||
init(cast[AsyncStreamReader](result.reader), rsource, tlsReadLoop,
|
||||
bufferSize)
|
||||
|
||||
proc copyKey(src: RsaPrivateKey): TLSPrivateKey =
|
||||
## Creates copy of RsaPrivateKey ``src``.
|
||||
var offset = 0
|
||||
let keySize = src.plen + src.qlen + src.dplen + src.dqlen + src.iqlen
|
||||
result = TLSPrivateKey(kind: TLSKeyType.RSA)
|
||||
result.storage = newSeq[byte](keySize)
|
||||
copyMem(addr result.storage[offset], src.p, src.plen)
|
||||
result.rsakey.p = cast[ptr cuchar](addr result.storage[offset])
|
||||
result.rsakey.plen = src.plen
|
||||
offset = offset + src.plen
|
||||
copyMem(addr result.storage[offset], src.q, src.qlen)
|
||||
result.rsakey.q = cast[ptr cuchar](addr result.storage[offset])
|
||||
result.rsakey.qlen = src.qlen
|
||||
offset = offset + src.qlen
|
||||
copyMem(addr result.storage[offset], src.dp, src.dplen)
|
||||
result.rsakey.dp = cast[ptr cuchar](addr result.storage[offset])
|
||||
result.rsakey.dplen = src.dplen
|
||||
offset = offset + src.dplen
|
||||
copyMem(addr result.storage[offset], src.dq, src.dqlen)
|
||||
result.rsakey.dq = cast[ptr cuchar](addr result.storage[offset])
|
||||
result.rsakey.dqlen = src.dqlen
|
||||
offset = offset + src.dqlen
|
||||
copyMem(addr result.storage[offset], src.iq, src.iqlen)
|
||||
result.rsakey.iq = cast[ptr cuchar](addr result.storage[offset])
|
||||
result.rsakey.iqlen = src.iqlen
|
||||
result.rsakey.nBitlen = src.nBitlen
|
||||
|
||||
proc copyKey(src: EcPrivateKey): TLSPrivateKey =
|
||||
## Creates copy of EcPrivateKey ``src``.
|
||||
var offset = 0
|
||||
let keySize = src.xlen
|
||||
result = TLSPrivateKey(kind: TLSKeyType.EC)
|
||||
result.storage = newSeq[byte](keySize)
|
||||
copyMem(addr result.storage[offset], src.x, src.xlen)
|
||||
result.eckey.x = cast[ptr cuchar](addr result.storage[offset])
|
||||
result.eckey.xlen = src.xlen
|
||||
result.eckey.curve = src.curve
|
||||
|
||||
proc init*(tt: typedesc[TLSPrivateKey], data: openarray[byte]): TLSPrivateKey =
|
||||
## Initialize TLS private key from array of bytes ``data``.
|
||||
##
|
||||
## This procedure initializes private key using raw, DER-encoded format,
|
||||
## or wrapped in an unencrypted PKCS#8 archive (again DER-encoded).
|
||||
var ctx: SkeyDecoderContext
|
||||
if len(data) == 0:
|
||||
raiseTLSStreamProtoError("Incorrect private key")
|
||||
skeyDecoderInit(addr ctx)
|
||||
skeyDecoderPush(addr ctx, cast[pointer](unsafeAddr data[0]), len(data))
|
||||
let err = skeyDecoderLastError(addr ctx)
|
||||
if err != 0:
|
||||
raiseTLSStreamProtoError(err)
|
||||
let keyType = skeyDecoderKeyType(addr ctx)
|
||||
if keyType == KEYTYPE_RSA:
|
||||
result = copyKey(ctx.key.rsa)
|
||||
elif keyType == KEYTYPE_EC:
|
||||
result = copyKey(ctx.key.ec)
|
||||
else:
|
||||
raiseTLSStreamProtoError("Unknown key type (" & $keyType & ")")
|
||||
|
||||
proc pemDecode*(data: openarray[char]): seq[PEMElement] =
|
||||
## Decode PEM encoded string and get array of binary blobs.
|
||||
if len(data) == 0:
|
||||
raiseTLSStreamProtoError("Empty PEM message")
|
||||
var ctx: PemDecoderContext
|
||||
var pctx = new PEMContext
|
||||
result = newSeq[PEMElement]()
|
||||
pemDecoderInit(addr ctx)
|
||||
|
||||
proc itemAppend(ctx: pointer, pbytes: pointer, nbytes: int) {.cdecl.} =
|
||||
var p = cast[PEMContext](ctx)
|
||||
var o = len(p.data)
|
||||
p.data.setLen(o + nbytes)
|
||||
copyMem(addr p.data[o], pbytes, nbytes)
|
||||
|
||||
var length = len(data)
|
||||
var offset = 0
|
||||
var inobj = false
|
||||
var elem: PEMElement
|
||||
|
||||
while length > 0:
|
||||
var tlen = pemDecoderPush(addr ctx,
|
||||
cast[pointer](unsafeAddr data[offset]), length)
|
||||
offset = offset + tlen
|
||||
length = length - tlen
|
||||
|
||||
let event = pemDecoderEvent(addr ctx)
|
||||
if event == PEM_BEGIN_OBJ:
|
||||
inobj = true
|
||||
elem.name = $pemDecoderName(addr ctx)
|
||||
pctx.data = newSeq[byte]()
|
||||
pemDecoderSetdest(addr ctx, itemAppend, cast[pointer](pctx))
|
||||
elif event == PEM_END_OBJ:
|
||||
if inobj:
|
||||
elem.data = pctx.data
|
||||
result.add(elem)
|
||||
inobj = false
|
||||
else:
|
||||
break
|
||||
else:
|
||||
raiseTLSStreamProtoError("Invalid PEM encoding")
|
||||
|
||||
proc init*(tt: typedesc[TLSPrivateKey], data: openarray[char]): TLSPrivateKey =
|
||||
## Initialize TLS private key from string ``data``.
|
||||
##
|
||||
## This procedure initializes private key using unencrypted PKCS#8 PEM
|
||||
## encoded string.
|
||||
##
|
||||
## Note that PKCS#1 PEM encoded objects are not supported.
|
||||
var items = pemDecode(data)
|
||||
for item in items:
|
||||
if item.name == "PRIVATE KEY":
|
||||
result = TLSPrivateKey.init(item.data)
|
||||
break
|
||||
if isNil(result):
|
||||
raiseTLSStreamProtoError("Could not find private key")
|
||||
|
||||
proc init*(tt: typedesc[TLSCertificate],
|
||||
data: openarray[char]): TLSCertificate =
|
||||
## Initialize TLS certificates from string ``data``.
|
||||
##
|
||||
## This procedure initializes array of certificates from PEM encoded string.
|
||||
var items = pemDecode(data)
|
||||
result = new TLSCertificate
|
||||
for item in items:
|
||||
if item.name == "CERTIFICATE" and len(item.data) > 0:
|
||||
let offset = len(result.storage)
|
||||
result.storage.add(item.data)
|
||||
let cert = X509Certificate(
|
||||
data: cast[ptr cuchar](addr result.storage[offset]),
|
||||
dataLen: len(item.data)
|
||||
)
|
||||
let res = getSignerAlgo(cert)
|
||||
if res == -1:
|
||||
raiseTLSStreamProtoError("Could not decode certificate")
|
||||
elif res != KEYTYPE_RSA and res != KEYTYPE_EC:
|
||||
raiseTLSStreamProtoError("Unsupported signing key type in certificate")
|
||||
result.certs.add(cert)
|
||||
if len(result.storage) == 0:
|
||||
raiseTLSStreamProtoError("Could not find any certificates")
|
||||
|
||||
proc init*(tt: typedesc[TLSSessionCache], size: int = 4096): TLSSessionCache =
|
||||
## Create new TLS session cache with size ``size``.
|
||||
##
|
||||
## One cached item is near 100 bytes size.
|
||||
result = new TLSSessionCache
|
||||
var rsize = min(size, 4096)
|
||||
result.storage = newSeq[byte](rsize)
|
||||
sslSessionCacheLruInit(addr result.context, addr result.storage[0], rsize)
|
||||
|
||||
proc handshake*(rws: SomeTLSStreamType): Future[void] =
|
||||
## Wait until initial TLS handshake will be successfully performed.
|
||||
var retFuture = newFuture[void]("tlsstream.handshake")
|
||||
when rws is TLSStreamReader:
|
||||
if rws.handshaked:
|
||||
retFuture.complete()
|
||||
else:
|
||||
rws.handshakeFut = retFuture
|
||||
rws.stream.writer.handshakeFut = retFuture
|
||||
elif rws is TLSStreamWriter:
|
||||
if rws.handshaked:
|
||||
retFuture.complete()
|
||||
else:
|
||||
rws.handshakeFut = retFuture
|
||||
rws.stream.reader.handshakeFut = retFuture
|
||||
elif rws is TLSAsyncStream:
|
||||
if rws.reader.handshaked:
|
||||
retFuture.complete()
|
||||
else:
|
||||
rws.reader.handshakeFut = retFuture
|
||||
rws.writer.handshakeFut = retFuture
|
||||
return retFuture
|
@ -6,7 +6,63 @@
|
||||
# Apache License, version 2.0, (LICENSE-APACHEv2)
|
||||
# MIT license (LICENSE-MIT)
|
||||
import strutils, unittest, os
|
||||
import ../chronos
|
||||
import ../chronos, ../chronos/streams/tlsstream
|
||||
|
||||
const SelfSignedRsaKey = """
|
||||
-----BEGIN PRIVATE KEY-----
|
||||
MIIEvgIBADANBgkqhkiG9w0BAQEFAASCBKgwggSkAgEAAoIBAQDSXcKMR6zIIHSy
|
||||
+UUjgyIlJWzEu3JFCEN+qBpjmwbuae3GNed0YgPels6AKe0AlWQNgpqoWfMQrWco
|
||||
qNf9dcm9OIdUK5FGMWYC8mJu6OjwnexSXt0R/k07lc5ePPDUbGvkbvgDyHzl1Ynj
|
||||
jYhe/2ujL0E99/KAuA7cvRILEl4rLpnngE+MMYNrFWmsSDzC+w0Hv5Fuyc8wZoSy
|
||||
nzODvojvFXTva9Nx4LxPF1W79WidwJHwrJghVlGUUhSeRGLHQWRY+954rj6TavoJ
|
||||
6gLYSHx6ELnRkFMbqxRrCJ0wAbrV8SwcPHjKSc6LQGRMzBRqxT45DoKmzH9Pwdnw
|
||||
6P0PGJyPAgMBAAECggEAd5Ck39hpIwIXchXtrwZ8ZMKFtLeZdhUBT7656P0XDnEU
|
||||
nQDMQcDn1B7A1eV+eENwr6EYyDD/zu3P4TM+OCg3dp3nhPaSRmQTR/995O3qX8BS
|
||||
rmqOmgiA2yoFNljKxOGu3RIZUwUjv/oDulsaNGxWQFS+bzs7EOAMSngIBlT1QvLc
|
||||
1etvGOW6hc4nzacrXpzhzWem6EzOabPZBmIy0DDz2ARlND1YSd2P5IMpOv0terNF
|
||||
ZwYl5SZ6Rnbv6GJWKl5IIpJOOtRtwyIhKNU/bd835vOfW07aaHVAT6GNYlyEoGWT
|
||||
36UjOyl3YxSLNvQOeIz4Y0n+vupBu/YL+mFtQdxJgQKBgQDtjEi0cNlt+Vz5sp4q
|
||||
wAVBJ/6h7hu3KJkY+xDpdrLyFTcxttKM9q/dR8V2bEaqMYT/mTvADfmW2BBcfi2J
|
||||
3VdR1lQ5pXeKAuxt1/Vc+Q4UCnX5OX7UXpP9aSetDdo5FXUC9X2H0hO0BqOcc+h2
|
||||
khVyXjKt6TdBwV94dP9bmQp9QQKBgQDitPVqRHGepYBYasScyTUXPp7vL9T7PMSu
|
||||
PGjqEkwvauhICpUbWpE/j8M0UXk64zwSmOYwQ37uiPpws88GL57oy1ZQZnhF/Hi3
|
||||
tM00Mn4x0xbyONbWu/AcFIZwSeSL6QhHYfeyVj7Jb/lqUg8sMiGmO25JjlAQBfTb
|
||||
vvBgEpcVzwKBgQCwgz87JWfLejIGMR2qcoj1A30IYmAh1377uwO0F0mc7PrYbBtE
|
||||
N8IyUTR/bLGNocJME1b8vOWrmt19fRzlhp1t6C8prrSGzulULdbawQ4fAi7rhDek
|
||||
Iqsg8FRVGSgAptsN2dDvbcDKUuycQtyHzsE0/J338IXozIHehkGBlNTggQKBgQCF
|
||||
RDTj5BoaVVWuJA0x0UGJSYFqP2bmzWEcv1w5BMqOMT0cZEQkkUfC4oKwdZhbGosM
|
||||
r57ZDkRGenUl3T08eK/kTuuNVb8r/O8Fpp3eKjRum5TojKsWDeJmz1X8GiPkbvcz
|
||||
5w4RYouEJHOsoVJT+6A2NMdvK946nRXEO2jYQPVZlwKBgBro8qGm0+T0T2xNP21q
|
||||
IzjP/EHT7iIkM5kiUCc2bPIrfzAGxXImakDzd6AgpgxhhficJOpp792Upe/b/Hwy
|
||||
bwfmbdWlT7/hPCnlVVH2dgO/ysDyEfxPigBMd+MmucRm6fzGIU7XSQw4KJqH4vQN
|
||||
9IASWlgzyQ1RytAduzRuepzB
|
||||
-----END PRIVATE KEY-----
|
||||
"""
|
||||
|
||||
const SelfSignedRsaCert = """
|
||||
-----BEGIN CERTIFICATE-----
|
||||
MIIDkzCCAnugAwIBAgIUEFovLJkPSn4T8BBZMYBrXPDajF0wDQYJKoZIhvcNAQEL
|
||||
BQAwWTELMAkGA1UEBhMCQVUxEzARBgNVBAgMClNvbWUtU3RhdGUxITAfBgNVBAoM
|
||||
GEludGVybmV0IFdpZGdpdHMgUHR5IEx0ZDESMBAGA1UEAwwJbG9jYWxob3N0MB4X
|
||||
DTE5MTAxMjA1NDQ0N1oXDTIwMTAxMTA1NDQ0OFowWTELMAkGA1UEBhMCQVUxEzAR
|
||||
BgNVBAgMClNvbWUtU3RhdGUxITAfBgNVBAoMGEludGVybmV0IFdpZGdpdHMgUHR5
|
||||
IEx0ZDESMBAGA1UEAwwJbG9jYWxob3N0MIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8A
|
||||
MIIBCgKCAQEA0l3CjEesyCB0svlFI4MiJSVsxLtyRQhDfqgaY5sG7mntxjXndGID
|
||||
3pbOgCntAJVkDYKaqFnzEK1nKKjX/XXJvTiHVCuRRjFmAvJibujo8J3sUl7dEf5N
|
||||
O5XOXjzw1Gxr5G74A8h85dWJ442IXv9roy9BPffygLgO3L0SCxJeKy6Z54BPjDGD
|
||||
axVprEg8wvsNB7+RbsnPMGaEsp8zg76I7xV072vTceC8TxdVu/VoncCR8KyYIVZR
|
||||
lFIUnkRix0FkWPveeK4+k2r6CeoC2Eh8ehC50ZBTG6sUawidMAG61fEsHDx4yknO
|
||||
i0BkTMwUasU+OQ6Cpsx/T8HZ8Oj9DxicjwIDAQABo1MwUTAdBgNVHQ4EFgQUMM+1
|
||||
FZ6KmN2eCJfDxY+8xa1JKnYwHwYDVR0jBBgwFoAUMM+1FZ6KmN2eCJfDxY+8xa1J
|
||||
KnYwDwYDVR0TAQH/BAUwAwEB/zANBgkqhkiG9w0BAQsFAAOCAQEASglh98fXvPwA
|
||||
KMaEezCUqTeE7DehLlhZ8n6ETKaBDcP3JR4+KTEh9y7gRGJ7DXGFAYfU3itgjyZo
|
||||
kbgpZIhrTKyYCAsF96Q1mHf/cBQ96UXr0U0SbYXSSJFeeMthMvki556dJZajtxcA
|
||||
9xR/U0PxPjhC9NIfpVSAv/7ocnXh73qOiFHoN9Cr2smzcGPxsifys2iv1qm5LwDr
|
||||
Dx5h/RfyfuAjS8e1ZCAhS++PYjb8BX54NilW2lTYF3pwpXL8znc4eBmklBkw5L60
|
||||
99jrK7LSQT9Nk8Mf9t4P/77N4hXCqsHIxZIqJlbdgdKfvBF3vRomxm3/aWtGlTVD
|
||||
vvzZPnlYfQ==
|
||||
-----END CERTIFICATE-----
|
||||
"""
|
||||
|
||||
suite "AsyncStream test suite":
|
||||
test "AsyncStream(StreamTransport) readExactly() test":
|
||||
@ -506,3 +562,96 @@ suite "ChunkedStream test suite":
|
||||
break
|
||||
result = res
|
||||
check waitFor(testVectors2(initTAddress("127.0.0.1:46001"))) == true
|
||||
|
||||
test "ChunkedStream leaks test":
|
||||
check:
|
||||
getTracker("async.stream.reader").isLeaked() == false
|
||||
getTracker("async.stream.writer").isLeaked() == false
|
||||
getTracker("stream.server").isLeaked() == false
|
||||
getTracker("stream.transport").isLeaked() == false
|
||||
|
||||
suite "TLSStream test suite":
|
||||
const HttpHeadersMark = @[byte(0x0D), byte(0x0A), byte(0x0D), byte(0x0A)]
|
||||
test "Simple HTTPS connection":
|
||||
proc headerClient(address: TransportAddress,
|
||||
name: string): Future[bool] {.async.} =
|
||||
var mark = "HTTP/1.1 "
|
||||
var buffer = newSeq[byte](8192)
|
||||
var transp = await connect(address)
|
||||
var reader = newAsyncStreamReader(transp)
|
||||
var writer = newAsyncStreamWriter(transp)
|
||||
var tlsstream = newTLSClientAsyncStream(reader, writer, name)
|
||||
|
||||
await tlsstream.writer.write("GET / HTTP/1.1\r\nHost: " & name &
|
||||
"\r\nConnection: close\r\n\r\n")
|
||||
var readFut = tlsstream.reader.readUntil(addr buffer[0], len(buffer),
|
||||
HttpHeadersMark)
|
||||
let res = await withTimeout(readFut, 5.seconds)
|
||||
if res:
|
||||
var length = readFut.read()
|
||||
buffer.setLen(length)
|
||||
if len(buffer) > len(mark):
|
||||
if equalMem(addr buffer[0], addr mark[0], len(mark)):
|
||||
result = true
|
||||
|
||||
await tlsstream.reader.closeWait()
|
||||
await tlsstream.writer.closeWait()
|
||||
await reader.closeWait()
|
||||
await writer.closeWait()
|
||||
await transp.closeWait()
|
||||
|
||||
let res = waitFor(headerClient(resolveTAddress("www.google.com:443")[0],
|
||||
"www.google.com"))
|
||||
check res == true
|
||||
|
||||
proc checkSSLServer(address: TransportAddress,
|
||||
pemkey, pemcert: string): Future[bool] {.async.} =
|
||||
var key: TLSPrivateKey
|
||||
var cert: TLSCertificate
|
||||
let testMessage = "TEST MESSAGE"
|
||||
|
||||
proc serveClient(server: StreamServer,
|
||||
transp: StreamTransport) {.async.} =
|
||||
var reader = newAsyncStreamReader(transp)
|
||||
var writer = newAsyncStreamWriter(transp)
|
||||
var sstream = newTLSServerAsyncStream(reader, writer, key, cert)
|
||||
await handshake(sstream)
|
||||
await sstream.writer.write(testMessage & "\r\n")
|
||||
await sstream.writer.closeWait()
|
||||
await sstream.reader.closeWait()
|
||||
await reader.closeWait()
|
||||
await writer.closeWait()
|
||||
await transp.closeWait()
|
||||
server.stop()
|
||||
server.close()
|
||||
|
||||
key = TLSPrivateKey.init(pemkey)
|
||||
cert = TLSCertificate.init(pemcert)
|
||||
|
||||
var server = createStreamServer(address, serveClient, {ReuseAddr})
|
||||
server.start()
|
||||
var conn = await connect(address)
|
||||
var creader = newAsyncStreamReader(conn)
|
||||
var cwriter = newAsyncStreamWriter(conn)
|
||||
# We are using self-signed certificate
|
||||
let flags = {NoVerifyHost, NoVerifyServerName}
|
||||
var cstream = newTLSClientAsyncStream(creader, cwriter, "", flags = flags)
|
||||
let res = await cstream.reader.readLine()
|
||||
await cstream.reader.closeWait()
|
||||
await cstream.writer.closeWait()
|
||||
await creader.closeWait()
|
||||
await cwriter.closeWait()
|
||||
await conn.closeWait()
|
||||
await server.join()
|
||||
result = res == testMessage
|
||||
|
||||
test "Simple server with RSA self-signed certificate":
|
||||
let res = waitFor(checkSSLServer(initTAddress("127.0.0.1:43808"),
|
||||
SelfSignedRsaKey, SelfSignedRsaCert))
|
||||
check res == true
|
||||
test "TLSStream leaks test":
|
||||
check:
|
||||
getTracker("async.stream.reader").isLeaked() == false
|
||||
getTracker("async.stream.writer").isLeaked() == false
|
||||
getTracker("stream.server").isLeaked() == false
|
||||
getTracker("stream.transport").isLeaked() == false
|
||||
|
Loading…
x
Reference in New Issue
Block a user