HttpServer now supports TLS.

Some TLSStream fixes to properly support EOF.
Some HttpServer to properly support TLS handshake problems.
HttpServer test suite for HTTPS.
This commit is contained in:
cheatfate 2021-02-03 12:47:03 +02:00 committed by zah
parent 1a3e9162a4
commit d43a9cb92d
4 changed files with 387 additions and 81 deletions

View File

@ -10,8 +10,6 @@ import stew/results, httputils, strutils, uri
export results, httputils, strutils export results, httputils, strutils
const const
useChroniclesLogging* {.booldefine.} = false
HeadersMark* = @[byte(0x0D), byte(0x0A), byte(0x0D), byte(0x0A)] HeadersMark* = @[byte(0x0D), byte(0x0A), byte(0x0D), byte(0x0A)]
PostMethods* = {MethodPost, MethodPatch, MethodPut, MethodDelete} PostMethods* = {MethodPost, MethodPatch, MethodPut, MethodDelete}
@ -19,9 +17,12 @@ type
HttpResult*[T] = Result[T, string] HttpResult*[T] = Result[T, string]
HttpResultCode*[T] = Result[T, HttpCode] HttpResultCode*[T] = Result[T, HttpCode]
HttpDefect* = object of Defect
HttpError* = object of CatchableError HttpError* = object of CatchableError
HttpCriticalError* = object of HttpError HttpCriticalError* = object of HttpError
code*: HttpCode
HttpRecoverableError* = object of HttpError HttpRecoverableError* = object of HttpError
code*: HttpCode
TransferEncodingFlags* {.pure.} = enum TransferEncodingFlags* {.pure.} = enum
Identity, Chunked, Compress, Deflate, Gzip Identity, Chunked, Compress, Deflate, Gzip
@ -29,15 +30,19 @@ type
ContentEncodingFlags* {.pure.} = enum ContentEncodingFlags* {.pure.} = enum
Identity, Br, Compress, Deflate, Gzip Identity, Br, Compress, Deflate, Gzip
template log*(body: untyped) = proc newHttpDefect*(msg: string): ref HttpDefect =
when defined(useChroniclesLogging): newException(HttpDefect, msg)
body
proc newHttpCriticalError*(msg: string): ref HttpCriticalError = proc newHttpCriticalError*(msg: string, code = Http400): ref HttpCriticalError =
newException(HttpCriticalError, msg) var tre = newException(HttpCriticalError, msg)
tre.code = code
tre
proc newHttpRecoverableError*(msg: string): ref HttpRecoverableError = proc newHttpRecoverableError*(msg: string,
newException(HttpRecoverableError, msg) code = Http400): ref HttpRecoverableError =
var tre = newException(HttpRecoverableError, msg)
tre.code = code
tre
iterator queryParams*(query: string): tuple[key: string, value: string] = iterator queryParams*(query: string): tuple[key: string, value: string] =
## Iterate over url-encoded query string. ## Iterate over url-encoded query string.

View File

@ -9,12 +9,9 @@
import std/[tables, options, uri, strutils] import std/[tables, options, uri, strutils]
import stew/results, httputils import stew/results, httputils
import ../../asyncloop, ../../asyncsync import ../../asyncloop, ../../asyncsync
import ../../streams/[asyncstream, boundstream, chunkstream] import ../../streams/[asyncstream, boundstream, chunkstream, tlsstream]
import httptable, httpcommon, multipart import httptable, httpcommon, multipart
export httptable, httpcommon, multipart export httptable, httpcommon, multipart, tlsstream, asyncstream
when defined(useChroniclesLogging):
import chronicles
type type
HttpServerFlags* {.pure.} = enum HttpServerFlags* {.pure.} = enum
@ -28,6 +25,7 @@ type
HttpProcessError* = object HttpProcessError* = object
error*: HTTPServerError error*: HTTPServerError
code*: HttpCode
exc*: ref CatchableError exc*: ref CatchableError
remote*: TransportAddress remote*: TransportAddress
@ -54,6 +52,7 @@ type
baseUri*: Uri baseUri*: Uri
flags*: set[HttpServerFlags] flags*: set[HttpServerFlags]
socketFlags*: set[ServerFlags] socketFlags*: set[ServerFlags]
secureFlags*: set[TLSFlags]
connections*: Table[string, Future[void]] connections*: Table[string, Future[void]]
acceptLoop*: Future[void] acceptLoop*: Future[void]
lifetime*: Future[void] lifetime*: Future[void]
@ -62,6 +61,8 @@ type
maxHeadersSize: int maxHeadersSize: int
maxRequestBodySize: int maxRequestBodySize: int
processCallback: HttpProcessCallback processCallback: HttpProcessCallback
tlsPrivateKey: TLSPrivateKey
tlsCertificate: TLSCertificate
HttpServerRef* = ref HttpServer HttpServerRef* = ref HttpServer
@ -99,16 +100,19 @@ type
HttpConnection* = object of RootObj HttpConnection* = object of RootObj
server*: HttpServerRef server*: HttpServerRef
transp: StreamTransport transp: StreamTransport
mainReader*: AsyncStreamReader mainReader: AsyncStreamReader
mainWriter*: AsyncStreamWriter mainWriter: AsyncStreamWriter
tlsStream: TLSAsyncStream
reader*: AsyncStreamReader
writer*: AsyncStreamWriter
buffer: seq[byte] buffer: seq[byte]
HttpConnectionRef* = ref HttpConnection HttpConnectionRef* = ref HttpConnection
proc init(htype: typedesc[HttpProcessError], error: HTTPServerError, proc init(htype: typedesc[HttpProcessError], error: HTTPServerError,
exc: ref CatchableError, exc: ref CatchableError, remote: TransportAddress,
remote: TransportAddress): HttpProcessError = code: HttpCode): HttpProcessError =
HttpProcessError(error: error, exc: exc, remote: remote) HttpProcessError(error: error, exc: exc, remote: remote, code: code)
proc new*(htype: typedesc[HttpServerRef], proc new*(htype: typedesc[HttpServerRef],
address: TransportAddress, address: TransportAddress,
@ -116,6 +120,9 @@ proc new*(htype: typedesc[HttpServerRef],
serverFlags: set[HttpServerFlags] = {}, serverFlags: set[HttpServerFlags] = {},
socketFlags: set[ServerFlags] = {ReuseAddr}, socketFlags: set[ServerFlags] = {ReuseAddr},
serverUri = Uri(), serverUri = Uri(),
tlsPrivateKey: TLSPrivateKey = nil,
tlsCertificate: TLSCertificate = nil,
secureFlags: set[TLSFlags] = {},
maxConnections: int = -1, maxConnections: int = -1,
bufferSize: int = 4096, bufferSize: int = 4096,
backlogSize: int = 100, backlogSize: int = 100,
@ -124,6 +131,10 @@ proc new*(htype: typedesc[HttpServerRef],
maxHeadersSize: int = 8192, maxHeadersSize: int = 8192,
maxRequestBodySize: int = 1_048_576): HttpResult[HttpServerRef] = maxRequestBodySize: int = 1_048_576): HttpResult[HttpServerRef] =
if HttpServerFlags.Secure in serverFlags:
if isNil(tlsPrivateKey) or isNil(tlsCertificate):
return err("PrivateKey or Certificate is missing")
var res = HttpServerRef( var res = HttpServerRef(
maxConnections: maxConnections, maxConnections: maxConnections,
headersTimeout: httpHeadersTimeout, headersTimeout: httpHeadersTimeout,
@ -133,7 +144,9 @@ proc new*(htype: typedesc[HttpServerRef],
processCallback: processCallback, processCallback: processCallback,
backLogSize: backLogSize, backLogSize: backLogSize,
flags: serverFlags, flags: serverFlags,
socketFlags: socketFlags socketFlags: socketFlags,
tlsPrivateKey: tlsPrivateKey,
tlsCertificate: tlsCertificate
) )
res.baseUri = res.baseUri =
@ -311,10 +324,10 @@ proc getBodyStream*(request: HttpRequestRef): HttpResult[AsyncStreamReader] =
## Streams which was obtained using this procedure must be closed to avoid ## Streams which was obtained using this procedure must be closed to avoid
## leaks. ## leaks.
if HttpRequestFlags.BoundBody in request.requestFlags: if HttpRequestFlags.BoundBody in request.requestFlags:
ok(newBoundedStreamReader(request.connection.mainReader, ok(newBoundedStreamReader(request.connection.reader,
request.contentLength)) request.contentLength))
elif HttpRequestFlags.UnboundBody in request.requestFlags: elif HttpRequestFlags.UnboundBody in request.requestFlags:
ok(newChunkedStreamReader(request.connection.mainReader)) ok(newChunkedStreamReader(request.connection.reader))
else: else:
err("Request do not have body available") err("Request do not have body available")
@ -326,7 +339,7 @@ proc handleExpect*(request: HttpRequestRef) {.async.} =
if request.version == HttpVersion11: if request.version == HttpVersion11:
try: try:
let message = $request.version & " " & $Http100 & "\r\n\r\n" let message = $request.version & " " & $Http100 & "\r\n\r\n"
await request.connection.mainWriter.write(message) await request.connection.writer.write(message)
except CancelledError as exc: except CancelledError as exc:
raise exc raise exc
except AsyncStreamWriteError, AsyncStreamIncompleteError: except AsyncStreamWriteError, AsyncStreamIncompleteError:
@ -379,7 +392,7 @@ proc sendErrorResponse(conn: HttpConnectionRef, version: HttpVersion,
if len(databody) > 0: if len(databody) > 0:
answer.add(databody) answer.add(databody)
try: try:
await conn.mainWriter.write(answer) await conn.writer.write(answer)
return true return true
except CancelledError: except CancelledError:
return false return false
@ -388,43 +401,72 @@ proc sendErrorResponse(conn: HttpConnectionRef, version: HttpVersion,
except AsyncStreamIncompleteError: except AsyncStreamIncompleteError:
return false return false
proc getRequest*(conn: HttpConnectionRef): Future[HttpRequestRef] {.async.} = proc getRequest(conn: HttpConnectionRef): Future[HttpRequestRef] {.async.} =
try: try:
conn.buffer.setLen(conn.server.maxHeadersSize) conn.buffer.setLen(conn.server.maxHeadersSize)
let res = await conn.transp.readUntil(addr conn.buffer[0], len(conn.buffer), let res = await conn.reader.readUntil(addr conn.buffer[0], len(conn.buffer),
HeadersMark) HeadersMark)
conn.buffer.setLen(res) conn.buffer.setLen(res)
let header = parseRequest(conn.buffer) let header = parseRequest(conn.buffer)
if header.failed(): if header.failed():
discard await conn.sendErrorResponse(HttpVersion11, Http400, false)
raise newHttpCriticalError("Malformed request recieved") raise newHttpCriticalError("Malformed request recieved")
else: else:
let res = prepareRequest(conn, header) let res = prepareRequest(conn, header)
if res.isErr(): if res.isErr():
discard await conn.sendErrorResponse(HttpVersion11, Http400, false)
raise newHttpCriticalError("Invalid request received") raise newHttpCriticalError("Invalid request received")
else: else:
return res.get() return res.get()
except TransportOsError: except AsyncStreamIncompleteError:
raise newHttpCriticalError("Unexpected OS error")
except TransportIncompleteError:
raise newHttpCriticalError("Remote peer disconnected") raise newHttpCriticalError("Remote peer disconnected")
except TransportLimitError: except AsyncStreamReadError:
discard await conn.sendErrorResponse(HttpVersion11, Http413, false) raise newHttpCriticalError("Connection with remote peer has been lost")
raise newHttpCriticalError("Maximum size of request headers reached") except AsyncStreamLimitError:
raise newHttpCriticalError("Maximum size of request headers reached",
Http413)
proc new(ht: typedesc[HttpConnectionRef], server: HttpServerRef, proc new(ht: typedesc[HttpConnectionRef], server: HttpServerRef,
transp: StreamTransport): HttpConnectionRef = transp: StreamTransport): HttpConnectionRef =
let mainReader = newAsyncStreamReader(transp)
let mainWriter = newAsyncStreamWriter(transp)
let tlsStream =
if HttpServerFlags.Secure in server.flags:
newTLSServerAsyncStream(mainReader, mainWriter, server.tlsPrivateKey,
server.tlsCertificate,
minVersion = TLSVersion.TLS12,
flags = server.secureFlags)
else:
nil
let reader =
if isNil(tlsStream):
mainReader
else:
cast[AsyncStreamReader](tlsStream.reader)
let writer =
if isNil(tlsStream):
mainWriter
else:
cast[AsyncStreamWriter](tlsStream.writer)
HttpConnectionRef( HttpConnectionRef(
transp: transp, transp: transp,
server: server, server: server,
buffer: newSeq[byte](server.maxHeadersSize), buffer: newSeq[byte](server.maxHeadersSize),
mainReader: newAsyncStreamReader(transp), mainReader: mainReader,
mainWriter: newAsyncStreamWriter(transp) mainWriter: mainWriter,
tlsStream: tlsStream,
reader: reader,
writer: writer
) )
proc close(conn: HttpConnectionRef): Future[void] = proc close(conn: HttpConnectionRef) {.async.} =
allFutures(conn.mainReader.closeWait(), conn.mainWriter.closeWait(), if HttpServerFlags.Secure in conn.server.flags:
# First we will close TLS streams.
await allFutures(conn.reader.closeWait(), conn.writer.closeWait())
# After we going to close everything else.
await allFutures(conn.mainReader.closeWait(), conn.mainWriter.closeWait(),
conn.transp.closeWait()) conn.transp.closeWait())
proc close(req: HttpRequestRef) {.async.} = proc close(req: HttpRequestRef) {.async.} =
@ -434,10 +476,57 @@ proc close(req: HttpRequestRef) {.async.} =
not(isNil(resp.chunkedWriter)): not(isNil(resp.chunkedWriter)):
await resp.chunkedWriter.closeWait() await resp.chunkedWriter.closeWait()
proc processLoop(server: HttpServerRef, transp: StreamTransport) {.async.} = proc createConnection(server: HttpServerRef,
transp: StreamTransport): Future[HttpConnectionRef] {.
async.} =
var conn = HttpConnectionRef.new(server, transp) var conn = HttpConnectionRef.new(server, transp)
if HttpServerFlags.Secure notin server.flags:
# Non secure connection
return conn
try:
await handshake(conn.tlsStream)
return conn
except CancelledError as exc:
await conn.close()
raise exc
except TLSStreamError:
await conn.close()
raise newHttpCriticalError("Unable to establish secure connection")
proc processLoop(server: HttpServerRef, transp: StreamTransport) {.async.} =
var
conn: HttpConnectionRef
connArg: RequestFence[HttpRequestRef]
runLoop = false
try:
conn = await createConnection(server, transp)
runLoop = true
except CancelledError:
# We could be cancelled only when we perform TLS handshake, connection
server.connections.del(transp.getId())
return
except HttpCriticalError as exc:
let error = HttpProcessError.init(HTTPServerError.CriticalError, exc,
transp.remoteAddress(), exc.code)
connArg = RequestFence[HttpRequestRef].err(error)
runLoop = false
if not(runLoop):
try:
# We still want to notify process callback about failure, but we ignore
# result and swallow all the exceptions.
discard await server.processCallback(connArg)
except CancelledError:
server.connections.del(transp.getId())
return
except CatchableError as exc:
# There should be no exceptions, so we will raise `Defect`.
raise newHttpDefect("Unexpected exception catched [" & $exc.name & "]")
var breakLoop = false var breakLoop = false
while true: while runLoop:
var var
arg: RequestFence[HttpRequestRef] arg: RequestFence[HttpRequestRef]
resp: HttpResponseRef resp: HttpResponseRef
@ -449,19 +538,19 @@ proc processLoop(server: HttpServerRef, transp: StreamTransport) {.async.} =
breakLoop = true breakLoop = true
except AsyncTimeoutError as exc: except AsyncTimeoutError as exc:
let error = HttpProcessError.init(HTTPServerError.TimeoutError, exc, let error = HttpProcessError.init(HTTPServerError.TimeoutError, exc,
transp.remoteAddress()) transp.remoteAddress(), Http408)
arg = RequestFence[HttpRequestRef].err(error) arg = RequestFence[HttpRequestRef].err(error)
except HttpRecoverableError as exc: except HttpRecoverableError as exc:
let error = HttpProcessError.init(HTTPServerError.RecoverableError, exc, let error = HttpProcessError.init(HTTPServerError.RecoverableError, exc,
transp.remoteAddress()) transp.remoteAddress(), exc.code)
arg = RequestFence[HttpRequestRef].err(error) arg = RequestFence[HttpRequestRef].err(error)
except HttpCriticalError as exc: except HttpCriticalError as exc:
let error = HttpProcessError.init(HTTPServerError.CriticalError, exc, let error = HttpProcessError.init(HTTPServerError.CriticalError, exc,
transp.remoteAddress()) transp.remoteAddress(), exc.code)
arg = RequestFence[HttpRequestRef].err(error) arg = RequestFence[HttpRequestRef].err(error)
except CatchableError as exc: except CatchableError as exc:
let error = HttpProcessError.init(HTTPServerError.CatchableError, exc, let error = HttpProcessError.init(HTTPServerError.CatchableError, exc,
transp.remoteAddress()) transp.remoteAddress(), Http500)
arg = RequestFence[HttpRequestRef].err(error) arg = RequestFence[HttpRequestRef].err(error)
if breakLoop: if breakLoop:
@ -481,15 +570,16 @@ proc processLoop(server: HttpServerRef, transp: StreamTransport) {.async.} =
break break
if arg.isErr(): if arg.isErr():
let code = arg.error().code
case arg.error().error case arg.error().error
of HTTPServerError.TimeoutError: of HTTPServerError.TimeoutError:
discard await conn.sendErrorResponse(HttpVersion11, Http408, false) discard await conn.sendErrorResponse(HttpVersion11, code, false)
of HTTPServerError.RecoverableError: of HTTPServerError.RecoverableError:
discard await conn.sendErrorResponse(HttpVersion11, Http400, false) discard await conn.sendErrorResponse(HttpVersion11, code, false)
of HTTPServerError.CriticalError: of HTTPServerError.CriticalError:
discard await conn.sendErrorResponse(HttpVersion11, Http400, false) discard await conn.sendErrorResponse(HttpVersion11, code, false)
of HTTPServerError.CatchableError: of HTTPServerError.CatchableError:
discard await conn.sendErrorResponse(HttpVersion11, Http400, false) discard await conn.sendErrorResponse(HttpVersion11, code, false)
break break
else: else:
let request = arg.get() let request = arg.get()
@ -521,7 +611,10 @@ proc processLoop(server: HttpServerRef, transp: StreamTransport) {.async.} =
if not(keepConn): if not(keepConn):
break break
# Connection could be `nil` only when secure handshake is failed.
if not(isNil(conn)):
await conn.close() await conn.close()
server.connections.del(transp.getId()) server.connections.del(transp.getId())
# if server.maxConnections > 0: # if server.maxConnections > 0:
# server.semaphore.release() # server.semaphore.release()
@ -784,9 +877,9 @@ proc sendBody*(resp: HttpResponseRef, pbytes: pointer, nbytes: int) {.async.} =
resp.state = HttpResponseState.Prepared resp.state = HttpResponseState.Prepared
try: try:
resp.state = HttpResponseState.Sending resp.state = HttpResponseState.Sending
await resp.connection.mainWriter.write(responseHeaders) await resp.connection.writer.write(responseHeaders)
if nbytes > 0: if nbytes > 0:
await resp.connection.mainWriter.write(pbytes, nbytes) await resp.connection.writer.write(pbytes, nbytes)
resp.state = HttpResponseState.Finished resp.state = HttpResponseState.Finished
except CancelledError as exc: except CancelledError as exc:
resp.state = HttpResponseState.Cancelled resp.state = HttpResponseState.Cancelled
@ -802,9 +895,9 @@ proc sendBody*[T: string|seq[byte]](resp: HttpResponseRef, data: T) {.async.} =
resp.state = HttpResponseState.Prepared resp.state = HttpResponseState.Prepared
try: try:
resp.state = HttpResponseState.Sending resp.state = HttpResponseState.Sending
await resp.connection.mainWriter.write(responseHeaders) await resp.connection.writer.write(responseHeaders)
if len(data) > 0: if len(data) > 0:
await resp.connection.mainWriter.write(data) await resp.connection.writer.write(data)
resp.state = HttpResponseState.Finished resp.state = HttpResponseState.Finished
except CancelledError as exc: except CancelledError as exc:
resp.state = HttpResponseState.Cancelled resp.state = HttpResponseState.Cancelled
@ -821,9 +914,9 @@ proc sendError*(resp: HttpResponseRef, code: HttpCode, body = "") {.async.} =
resp.state = HttpResponseState.Prepared resp.state = HttpResponseState.Prepared
try: try:
resp.state = HttpResponseState.Sending resp.state = HttpResponseState.Sending
await resp.connection.mainWriter.write(responseHeaders) await resp.connection.writer.write(responseHeaders)
if len(body) > 0: if len(body) > 0:
await resp.connection.mainWriter.write(body) await resp.connection.writer.write(body)
resp.state = HttpResponseState.Finished resp.state = HttpResponseState.Finished
except CancelledError as exc: except CancelledError as exc:
resp.state = HttpResponseState.Cancelled resp.state = HttpResponseState.Cancelled
@ -841,8 +934,8 @@ proc prepare*(resp: HttpResponseRef) {.async.} =
resp.state = HttpResponseState.Prepared resp.state = HttpResponseState.Prepared
try: try:
resp.state = HttpResponseState.Sending resp.state = HttpResponseState.Sending
await resp.connection.mainWriter.write(responseHeaders) await resp.connection.writer.write(responseHeaders)
resp.chunkedWriter = newChunkedStreamWriter(resp.connection.mainWriter) resp.chunkedWriter = newChunkedStreamWriter(resp.connection.writer)
resp.flags.incl(HttpResponseFlags.Chunked) resp.flags.incl(HttpResponseFlags.Chunked)
except CancelledError as exc: except CancelledError as exc:
resp.state = HttpResponseState.Cancelled resp.state = HttpResponseState.Cancelled

View File

@ -31,6 +31,9 @@ type
TLSKeyType {.pure.} = enum TLSKeyType {.pure.} = enum
RSA, EC RSA, EC
TLSResult {.pure.} = enum
Success, Error, EOF
TLSPrivateKey* = ref object TLSPrivateKey* = ref object
case kind: TLSKeyType case kind: TLSKeyType
of RSA: of RSA:
@ -127,24 +130,23 @@ template newTLSStreamProtocolError[T](message: T): ref TLSStreamProtocolError =
err err
proc tlsWriteRec(engine: ptr SslEngineContext, proc tlsWriteRec(engine: ptr SslEngineContext,
writer: TLSStreamWriter): Future[bool] {.async.} = writer: TLSStreamWriter): Future[TLSResult] {.async.} =
try: try:
var length = 0'u var length = 0'u
var buf = sslEngineSendrecBuf(engine, length) var buf = sslEngineSendrecBuf(engine, length)
doAssert(length != 0 and not isNil(buf)) doAssert(length != 0 and not isNil(buf))
await writer.wsource.write(buf, int(length)) await writer.wsource.write(buf, int(length))
sslEngineSendrecAck(engine, length) sslEngineSendrecAck(engine, length)
return true return TLSResult.Success
except AsyncStreamError as exc: except AsyncStreamError as exc:
writer.state = AsyncStreamState.Error writer.state = AsyncStreamState.Error
writer.error = exc writer.error = exc
except CancelledError: except CancelledError:
writer.state = AsyncStreamState.Stopped writer.state = AsyncStreamState.Stopped
return TLSResult.Error
return false
proc tlsWriteApp(engine: ptr SslEngineContext, proc tlsWriteApp(engine: ptr SslEngineContext,
writer: TLSStreamWriter): Future[bool] {.async.} = writer: TLSStreamWriter): Future[TLSResult] {.async.} =
try: try:
var item = await writer.queue.get() var item = await writer.queue.get()
if item.size > 0: if item.size > 0:
@ -157,7 +159,7 @@ proc tlsWriteApp(engine: ptr SslEngineContext,
sslEngineSendappAck(engine, uint(item.size)) sslEngineSendappAck(engine, uint(item.size))
sslEngineFlush(engine, 0) sslEngineFlush(engine, 0)
item.future.complete() item.future.complete()
return true return TLSResult.Success
else: else:
# BearSSL is not ready to accept whole item, so we will send # BearSSL is not ready to accept whole item, so we will send
# only part of item and adjust offset. # only part of item and adjust offset.
@ -165,58 +167,68 @@ proc tlsWriteApp(engine: ptr SslEngineContext,
item.size = item.size - int(length) item.size = item.size - int(length)
writer.queue.addFirstNoWait(item) writer.queue.addFirstNoWait(item)
sslEngineSendappAck(engine, length) sslEngineSendappAck(engine, length)
return true return TLSResult.Success
else: else:
sslEngineClose(engine) sslEngineClose(engine)
item.future.complete() item.future.complete()
return true return TLSResult.Success
except CancelledError: except CancelledError:
writer.state = AsyncStreamState.Stopped writer.state = AsyncStreamState.Stopped
return TLSResult.Error
return false
proc tlsReadRec(engine: ptr SslEngineContext, proc tlsReadRec(engine: ptr SslEngineContext,
reader: TLSStreamReader): Future[bool] {.async.} = reader: TLSStreamReader): Future[TLSResult] {.async.} =
try: try:
var length = 0'u var length = 0'u
var buf = sslEngineRecvrecBuf(engine, length) var buf = sslEngineRecvrecBuf(engine, length)
let res = await reader.rsource.readOnce(buf, int(length)) let res = await reader.rsource.readOnce(buf, int(length))
sslEngineRecvrecAck(engine, uint(res)) sslEngineRecvrecAck(engine, uint(res))
return true if res == 0:
sslEngineClose(engine)
return TLSResult.EOF
else:
return TLSResult.Success
except CancelledError: except CancelledError:
reader.state = AsyncStreamState.Stopped reader.state = AsyncStreamState.Stopped
except AsyncStreamError as exc: except AsyncStreamError as exc:
reader.state = AsyncStreamState.Error reader.state = AsyncStreamState.Error
reader.error = exc reader.error = exc
return false return TLSResult.Error
proc tlsReadApp(engine: ptr SslEngineContext, proc tlsReadApp(engine: ptr SslEngineContext,
reader: TLSStreamReader): Future[bool] {.async.} = reader: TLSStreamReader): Future[TLSResult] {.async.} =
try: try:
var length = 0'u var length = 0'u
var buf = sslEngineRecvappBuf(engine, length) var buf = sslEngineRecvappBuf(engine, length)
await upload(addr reader.buffer, buf, int(length)) await upload(addr reader.buffer, buf, int(length))
sslEngineRecvappAck(engine, length) sslEngineRecvappAck(engine, length)
return true return TLSResult.Success
except CancelledError: except CancelledError:
reader.state = AsyncStreamState.Stopped reader.state = AsyncStreamState.Stopped
return false return TLSResult.Error
template raiseTLSStreamProtoError*[T](message: T) = template raiseTLSStreamProtoError*[T](message: T) =
raise newTLSStreamProtocolError(message) raise newTLSStreamProtocolError(message)
template readAndReset(fut: untyped) = template readAndReset(fut: untyped) =
if fut.finished(): if fut.finished():
if fut.read(): let res = fut.read()
case res
of TLSREsult.Success:
fut = nil fut = nil
continue continue
else: of TLSResult.Error:
fut = nil fut = nil
loopState = AsyncStreamState.Error loopState = AsyncStreamState.Error
break break
of TLSResult.EOF:
fut = nil
loopState = AsyncStreamState.Finished
break
proc cancelAndWait*(a, b, c, d: Future[bool]): Future[void] = proc cancelAndWait*(a, b, c, d: Future[TLSResult]): Future[void] =
var waiting: seq[Future[bool]] var waiting: seq[Future[TLSResult]]
if not(isNil(a)) and not(a.finished()): if not(isNil(a)) and not(a.finished()):
a.cancel() a.cancel()
waiting.add(a) waiting.add(a)
@ -231,10 +243,29 @@ proc cancelAndWait*(a, b, c, d: Future[bool]): Future[void] =
waiting.add(d) waiting.add(d)
allFutures(waiting) allFutures(waiting)
proc dumpState*(state: cuint): string =
var res = ""
if (state and SSL_CLOSED) == SSL_CLOSED:
if len(res) > 0: res.add(", ")
res.add("SSL_CLOSED")
if (state and SSL_SENDREC) == SSL_SENDREC:
if len(res) > 0: res.add(", ")
res.add("SSL_SENDREC")
if (state and SSL_SENDAPP) == SSL_SENDAPP:
if len(res) > 0: res.add(", ")
res.add("SSL_SENDAPP")
if (state and SSL_RECVREC) == SSL_RECVREC:
if len(res) > 0: res.add(", ")
res.add("SSL_RECVREC")
if (state and SSL_RECVAPP) == SSL_RECVAPP:
if len(res) > 0: res.add(", ")
res.add("SSL_RECVAPP")
"{" & res & "}"
proc tlsLoop*(stream: TLSAsyncStream) {.async.} = proc tlsLoop*(stream: TLSAsyncStream) {.async.} =
var var
sendRecFut, sendAppFut: Future[bool] sendRecFut, sendAppFut: Future[TLSResult]
recvRecFut, recvAppFut: Future[bool] recvRecFut, recvAppFut: Future[TLSResult]
let engine = let engine =
case stream.reader.kind case stream.reader.kind
@ -246,7 +277,7 @@ proc tlsLoop*(stream: TLSAsyncStream) {.async.} =
var loopState = AsyncStreamState.Running var loopState = AsyncStreamState.Running
while true: while true:
var waiting: seq[Future[bool]] var waiting: seq[Future[TLSResult]]
var state = sslEngineCurrentState(engine) var state = sslEngineCurrentState(engine)
if (state and SSL_CLOSED) == SSL_CLOSED: if (state and SSL_CLOSED) == SSL_CLOSED:
@ -343,6 +374,14 @@ proc tlsLoop*(stream: TLSAsyncStream) {.async.} =
if not(isNil(stream.writer.handshakeFut)): if not(isNil(stream.writer.handshakeFut)):
if not(stream.writer.handshakeFut.finished()): if not(stream.writer.handshakeFut.finished()):
stream.writer.handshakeFut.fail(error) stream.writer.handshakeFut.fail(error)
else:
if not(stream.writer.handshaked):
if not(isNil(stream.writer.handshakeFut)):
if not(stream.writer.handshakeFut.finished()):
stream.writer.handshakeFut.fail(
newTLSStreamProtocolError("Connection with remote peer lost")
)
# Completing readers # Completing readers
stream.reader.buffer.forget() stream.reader.buffer.forget()

View File

@ -8,6 +8,68 @@
import std/[strutils, unittest, algorithm, strutils] import std/[strutils, unittest, algorithm, strutils]
import ../chronos, ../chronos/apps import ../chronos, ../chronos/apps
# To create self-signed certificate and key you can use openssl
# openssl req -new -x509 -sha256 -newkey rsa:2048 -nodes \
# -keyout example-com.key.pem -days 3650 -out example-com.cert.pem
const HttpsSelfSignedRsaKey = """
-----BEGIN PRIVATE KEY-----
MIIEvgIBADANBgkqhkiG9w0BAQEFAASCBKgwggSkAgEAAoIBAQCn7tXGLKMIMzOG
tVzUixax1/ftlSLcpEAkZMORuiCCnYjtIJhGZdzRFZC8fBlfAJZpLIAOfX2L2f1J
ZuwpwDkOIvNqKMBrl5Mvkl5azPT0rtnjuwrcqN5NFtbmZPKFYvbjex2aXGqjl5MW
nQIs/ZA++DVEXmaN9oDxcZsvRMDKfrGQf9iLeoVL47Gx9KpqNqD/JLIn4LpieumV
yYidm6ukTOqHRvrWm36y6VvKW4TE97THacULmkeahtTf8zDJbbh4EO+gifgwgJ2W
BUS0+5hMcWu8111mXmanlOVlcoW8fH8RmPjL1eK1Z3j3SVHEf7oWZtIVW5gGA0jQ
nfA4K51RAgMBAAECggEANZ7/R13tWKrwouy6DWuz/WlWUtgx333atUQvZhKmWs5u
cDjeJmxUC7b1FhoSB9GqNT7uTLIpKkSaqZthgRtNnIPwcU890Zz+dEwqMJgNByvl
it+oYjjRco/+YmaNQaYN6yjelPE5Y678WlYb4b29Fz4t0/zIhj/VgEKkKH2tiXpS
TIicoM7pSOscEUfaW3yp5bS5QwNU6/AaF1wws0feBACd19ZkcdPvr52jopbhxlXw
h3XTV/vXIJd5zWGp0h/Jbd4xcD4MVo2GjfkeORKY6SjDaNzt8OGtePcKnnbUVu8b
2XlDxukhDQXqJ3g0sHz47mhvo4JeIM+FgymRm+3QmQKBgQDTawrEA3Zy9WvucaC7
Zah02oE9nuvpF12lZ7WJh7+tZ/1ss+Fm7YspEKaUiEk7nn1CAVFtem4X4YCXTBiC
Oqq/o+ipv1yTur0ae6m4pwLm5wcMWBh3H5zjfQTfrClNN8yjWv8u3/sq8KesHPnT
R92/sMAptAChPgTzQphWbxFiYwKBgQDLWFaBqXfZYVnTyUvKX8GorS6jGWc6Eh4l
lAFA+2EBWDICrUxsDPoZjEXrWCixdqLhyehaI3KEFIx2bcPv6X2c7yx3IG5lA/Gx
TZiKlY74c6jOTstkdLW9RJbg1VUHUVZMf/Owt802YmEfUI5S5v7jFmKW6VG+io+K
+5KYeHD1uwKBgQDMf53KPA82422jFwYCPjLT1QduM2q97HwIomhWv5gIg63+l4BP
rzYMYq6+vZUYthUy41OAMgyLzPQ1ZMXQMi83b7R9fTxvKRIBq9xfYCzObGnE5vHD
SDDZWvR75muM5Yxr9nkfPkgVIPMO6Hg+hiVYZf96V0LEtNjU9HWmJYkLQQKBgQCQ
ULGUdGHKtXy7AjH3/t3CiKaAupa4cANVSCVbqQy/l4hmvfdu+AbH+vXkgTzgNgKD
nHh7AI1Vj//gTSayLlQn/Nbh9PJkXtg5rYiFUn+VdQBo6yMOuIYDPZqXFtCx0Nge
kvCwisHpxwiG4PUhgS+Em259DDonsM8PJFx2OYRx4QKBgEQpGhg71Oi9MhPJshN7
dYTowaMS5eLTk2264ARaY+hAIV7fgvUa+5bgTVaWL+Cfs33hi4sMRqlEwsmfds2T
cnQiJ4cU20Euldfwa5FLnk6LaWdOyzYt/ICBJnKFRwfCUbS4Bu5rtMEM+3t0wxnJ
IgaD04WhoL9EX0Qo3DC1+0kG
-----END PRIVATE KEY-----
"""
# This SSL certificate will expire 13 October 2030.
const HttpsSelfSignedRsaCert = """
-----BEGIN CERTIFICATE-----
MIIDnzCCAoegAwIBAgIUUdcusjDd3XQi3FPM8urdFG3qI+8wDQYJKoZIhvcNAQEL
BQAwXzELMAkGA1UEBhMCQVUxEzARBgNVBAgMClNvbWUtU3RhdGUxITAfBgNVBAoM
GEludGVybmV0IFdpZGdpdHMgUHR5IEx0ZDEYMBYGA1UEAwwPMTI3LjAuMC4xOjQz
ODA4MB4XDTIwMTAxMjIxNDUwMVoXDTMwMTAxMDIxNDUwMVowXzELMAkGA1UEBhMC
QVUxEzARBgNVBAgMClNvbWUtU3RhdGUxITAfBgNVBAoMGEludGVybmV0IFdpZGdp
dHMgUHR5IEx0ZDEYMBYGA1UEAwwPMTI3LjAuMC4xOjQzODA4MIIBIjANBgkqhkiG
9w0BAQEFAAOCAQ8AMIIBCgKCAQEAp+7VxiyjCDMzhrVc1IsWsdf37ZUi3KRAJGTD
kboggp2I7SCYRmXc0RWQvHwZXwCWaSyADn19i9n9SWbsKcA5DiLzaijAa5eTL5Je
Wsz09K7Z47sK3KjeTRbW5mTyhWL243sdmlxqo5eTFp0CLP2QPvg1RF5mjfaA8XGb
L0TAyn6xkH/Yi3qFS+OxsfSqajag/ySyJ+C6YnrplcmInZurpEzqh0b61pt+sulb
yluExPe0x2nFC5pHmobU3/MwyW24eBDvoIn4MICdlgVEtPuYTHFrvNddZl5mp5Tl
ZXKFvHx/EZj4y9XitWd490lRxH+6FmbSFVuYBgNI0J3wOCudUQIDAQABo1MwUTAd
BgNVHQ4EFgQUBKha84woY5WkFxKw7qx1cONg1H8wHwYDVR0jBBgwFoAUBKha84wo
Y5WkFxKw7qx1cONg1H8wDwYDVR0TAQH/BAUwAwEB/zANBgkqhkiG9w0BAQsFAAOC
AQEAHZMYt9Ry+Xj3vTbzpGFQzYQVTJlfJWSN6eWNOivRFQE5io9kOBEe5noa8aLo
dLkw6ztxRP2QRJmlhGCO9/HwS17ckrkgZp3EC2LFnzxcBmoZu+owfxOT1KqpO52O
IKOl8eVohi1pEicE4dtTJVcpI7VCMovnXUhzx1Ci4Vibns4a6H+BQa19a1JSpifN
tO8U5jkjJ8Jprs/VPFhJj2O3di53oDHaYSE5eOrm2ZO14KFHSk9cGcOGmcYkUv8B
nV5vnGadH5Lvfxb/BCpuONabeRdOxMt9u9yQ89vNpxFtRdZDCpGKZBCfmUP+5m3m
N8r5CwGcIX/XPC3lKazzbZ8baA==
-----END CERTIFICATE-----
"""
suite "HTTP server testing suite": suite "HTTP server testing suite":
proc httpClient(address: TransportAddress, proc httpClient(address: TransportAddress,
data: string): Future[string] {.async.} = data: string): Future[string] {.async.} =
@ -27,6 +89,37 @@ suite "HTTP server testing suite":
if not(isNil(transp)): if not(isNil(transp)):
await closeWait(transp) await closeWait(transp)
proc httpsClient(address: TransportAddress,
data: string, flags = {NoVerifyHost, NoVerifyServerName}
): Future[string] {.async.} =
var
transp: StreamTransport
tlsstream: TlsAsyncStream
reader: AsyncStreamReader
writer: AsyncStreamWriter
try:
transp = await connect(address)
reader = newAsyncStreamReader(transp)
writer = newAsyncStreamWriter(transp)
tlsstream = newTLSClientAsyncStream(reader, writer, "", flags = flags)
if len(data) > 0:
await tlsstream.writer.write(data)
var rres = await tlsstream.reader.read()
var sres = newString(len(rres))
if len(rres) > 0:
copyMem(addr sres[0], addr rres[0], len(rres))
return sres
except CatchableError:
return "EXCEPTION"
finally:
if not(isNil(tlsstream)):
await allFutures(tlsstream.reader.closeWait(),
tlsstream.writer.closeWait())
if not(isNil(reader)):
await allFutures(reader.closeWait(), writer.closeWait(),
transp.closeWait())
test "Request headers timeout test": test "Request headers timeout test":
proc testTimeout(address: TransportAddress): Future[bool] {.async.} = proc testTimeout(address: TransportAddress): Future[bool] {.async.} =
var serverRes = false var serverRes = false
@ -359,6 +452,82 @@ suite "HTTP server testing suite":
check waitFor(testPostMultipart2(initTAddress("127.0.0.1:30080"))) == true check waitFor(testPostMultipart2(initTAddress("127.0.0.1:30080"))) == true
test "HTTPS server (successful handshake) test":
proc testHTTPS(address: TransportAddress): Future[bool] {.async.} =
var serverRes = false
proc process(r: RequestFence[HttpRequestRef]): Future[HttpResponseRef] {.
async.} =
if r.isOk():
let request = r.get()
serverRes = true
return await request.respond(Http200, "TEST_OK:" & $request.meth,
HttpTable.init())
else:
serverRes = false
return dumbResponse()
let socketFlags = {ServerFlags.TcpNoDelay, ServerFlags.ReuseAddr}
let serverFlags = {Secure}
let secureKey = TLSPrivateKey.init(HttpsSelfSignedRsaKey)
let secureCert = TLSCertificate.init(HttpsSelfSignedRsaCert)
let res = HttpServerRef.new(address, process,
socketFlags = socketFlags,
serverFlags = serverFlags,
tlsPrivateKey = secureKey,
tlsCertificate = secureCert)
if res.isErr():
return false
let server = res.get()
server.start()
let message = "GET / HTTP/1.0\r\nHost: https://127.0.0.1:80\r\n\r\n"
let data = await httpsClient(address, message)
await server.stop()
await server.close()
return serverRes and (data.find("TEST_OK:GET") >= 0)
check waitFor(testHTTPS(initTAddress("127.0.0.1:30080"))) == true
test "HTTPS server (failed handshake) test":
proc testHTTPS2(address: TransportAddress): Future[bool] {.async.} =
var serverRes = false
var testFut = newFuture[void]()
proc process(r: RequestFence[HttpRequestRef]): Future[HttpResponseRef] {.
async.} =
if r.isOk():
let request = r.get()
serverRes = false
return await request.respond(Http200, "TEST_OK:" & $request.meth,
HttpTable.init())
else:
serverRes = true
testFut.complete()
return dumbResponse()
let socketFlags = {ServerFlags.TcpNoDelay, ServerFlags.ReuseAddr}
let serverFlags = {Secure}
let secureKey = TLSPrivateKey.init(HttpsSelfSignedRsaKey)
let secureCert = TLSCertificate.init(HttpsSelfSignedRsaCert)
let res = HttpServerRef.new(address, process,
socketFlags = socketFlags,
serverFlags = serverFlags,
tlsPrivateKey = secureKey,
tlsCertificate = secureCert)
if res.isErr():
return false
let server = res.get()
server.start()
let message = "GET / HTTP/1.0\r\nHost: https://127.0.0.1:80\r\n\r\n"
let data = await httpsClient(address, message, {NoVerifyServerName})
await testFut
await server.stop()
await server.close()
return serverRes and data == "EXCEPTION"
check waitFor(testHTTPS2(initTAddress("127.0.0.1:30080"))) == true
test "Leaks test": test "Leaks test":
check: check:
getTracker("async.stream.reader").isLeaked() == false getTracker("async.stream.reader").isLeaked() == false