Make HttpServer a case object (#84)

* cleanup examples

* more examples cleanup

* make HttServer a case object

* propagate errors when handling requests

* don't extend HttpServer

* remove port from create that takes a string host

make more consistent with client's `connect`
This commit is contained in:
Dmitriy Ryajov 2021-07-15 14:17:55 -06:00 committed by GitHub
parent 06ae75cf7f
commit 7756dd1e77
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 47 additions and 42 deletions

View File

@ -17,8 +17,9 @@ import ../websock/websock
proc main() {.async.} = proc main() {.async.} =
let ws = when defined tls: let ws = when defined tls:
await WebSocket.connect( await WebSocket.connect(
"127.0.0.1:8888", "127.0.0.1:8889",
path = "/wss", path = "/wss",
secure = true,
flags = {TLSFlags.NoVerifyHost, TLSFlags.NoVerifyServerName}) flags = {TLSFlags.NoVerifyHost, TLSFlags.NoVerifyServerName})
else: else:
await WebSocket.connect( await WebSocket.connect(
@ -38,7 +39,7 @@ proc main() {.async.} =
let dataStr = string.fromBytes(buff) let dataStr = string.fromBytes(buff)
trace "Server Response: ", data = dataStr trace "Server Response: ", data = dataStr
assert dataStr == reqData doAssert dataStr == reqData
break break
except WebSocketError as exc: except WebSocketError as exc:
error "WebSocket error:", exception = exc.msg error "WebSocket error:", exception = exc.msg

View File

@ -17,11 +17,7 @@ import ../tests/keys
proc handle(request: HttpRequest) {.async.} = proc handle(request: HttpRequest) {.async.} =
trace "Handling request:", uri = request.uri.path trace "Handling request:", uri = request.uri.path
let path = when defined tls: "/wss" else: "/ws"
if request.uri.path != path:
return
trace "Initiating web socket connection."
try: try:
let deflateFactory = deflateFactory() let deflateFactory = deflateFactory()
let server = WSServer.new(factories = [deflateFactory]) let server = WSServer.new(factories = [deflateFactory])
@ -49,26 +45,32 @@ proc handle(request: HttpRequest) {.async.} =
when isMainModule: when isMainModule:
# we want to run parallel tests in CI # we want to run parallel tests in CI
# so we are using different port # so we are using different port
const serverAddr = when defined tls:
"127.0.0.1:8889"
else:
"127.0.0.1:8888"
proc main() {.async.} = proc main() {.async.} =
let let
address = initTAddress(serverAddr)
socketFlags = {ServerFlags.TcpNoDelay, ServerFlags.ReuseAddr} socketFlags = {ServerFlags.TcpNoDelay, ServerFlags.ReuseAddr}
server = when defined tls: server = when defined tls:
TlsHttpServer.create( TlsHttpServer.create(
address = address, address = initTAddress("127.0.0.1:8889"),
handler = handle,
tlsPrivateKey = TLSPrivateKey.init(SecureKey), tlsPrivateKey = TLSPrivateKey.init(SecureKey),
tlsCertificate = TLSCertificate.init(SecureCert), tlsCertificate = TLSCertificate.init(SecureCert),
flags = socketFlags) flags = socketFlags)
else: else:
HttpServer.create(address, handle, flags = socketFlags) HttpServer.create(initTAddress("127.0.0.1:8888"), flags = socketFlags)
when defined accepts:
proc accepts() {.async, raises: [Defect].} =
while true:
try:
let req = await server.accept()
await req.handle()
except CatchableError as exc:
error "Transport error", exc = exc.msg
asyncCheck accepts()
else:
server.handler = handle
server.start() server.start()
trace "Server listening on ", data = $server.localAddress() trace "Server listening on ", data = $server.localAddress()
await server.join() await server.join()

View File

@ -29,13 +29,17 @@ type
HttpServer* = ref object of StreamServer HttpServer* = ref object of StreamServer
handler*: HttpAsyncCallback handler*: HttpAsyncCallback
case secure*: bool:
TlsHttpServer* = ref object of HttpServer of true:
tlsFlags*: set[TLSFlags] tlsFlags*: set[TLSFlags]
tlsPrivateKey*: TLSPrivateKey tlsPrivateKey*: TLSPrivateKey
tlsCertificate*: TLSCertificate tlsCertificate*: TLSCertificate
minVersion*: TLSVersion minVersion*: TLSVersion
maxVersion*: TLSVersion maxVersion*: TLSVersion
else:
discard
TlsHttpServer* = HttpServer
proc validateRequest( proc validateRequest(
stream: AsyncStreamWriter, stream: AsyncStreamWriter,
@ -73,7 +77,7 @@ proc parseRequest(
# Timeout # Timeout
trace "Timeout expired while receiving headers", address = $remoteAddr trace "Timeout expired while receiving headers", address = $remoteAddr
await stream.writer.sendError(Http408, version = HttpVersion11) await stream.writer.sendError(Http408, version = HttpVersion11)
return raise newException(HttpError, "Didn't read headers in time!")
let hlen = hlenfut.read() let hlen = hlenfut.read()
buffer.setLen(hlen) buffer.setLen(hlen)
@ -82,7 +86,7 @@ proc parseRequest(
# Header could not be parsed # Header could not be parsed
trace "Malformed header received", address = $remoteAddr trace "Malformed header received", address = $remoteAddr
await stream.writer.sendError(Http400, version = HttpVersion11) await stream.writer.sendError(Http400, version = HttpVersion11)
return raise newException(HttpError, "Malformed header received")
var vres = await stream.writer.validateRequest(requestData) var vres = await stream.writer.validateRequest(requestData)
let hdrs = let hdrs =
@ -94,9 +98,9 @@ proc parseRequest(
if vres == ReqStatus.ErrorFailure: if vres == ReqStatus.ErrorFailure:
trace "Remote peer disconnected", address = $remoteAddr trace "Remote peer disconnected", address = $remoteAddr
return raise newException(HttpError, "Remote peer disconnected")
debug "Received valid HTTP request", address = $remoteAddr trace "Received valid HTTP request", address = $remoteAddr
return HttpRequest( return HttpRequest(
headers: hdrs, headers: hdrs,
stream: stream, stream: stream,
@ -110,8 +114,6 @@ proc parseRequest(
trace "Remote peer disconnected", address = $remoteAddr trace "Remote peer disconnected", address = $remoteAddr
except TransportOsError as exc: except TransportOsError as exc:
trace "Problems with networking", address = $remoteAddr, error = exc.msg trace "Problems with networking", address = $remoteAddr, error = exc.msg
except CatchableError as exc:
debug "Unknown exception", address = $remoteAddr, error = exc.msg
proc handleConnCb( proc handleConnCb(
server: StreamServer, server: StreamServer,
@ -160,16 +162,16 @@ proc handleTlsConnCb(
finally: finally:
await stream.closeWait() await stream.closeWait()
proc accept*(server: HttpServer | TlsHttpServer): Future[HttpRequest] proc accept*(server: HttpServer): Future[HttpRequest]
{.async, raises: [Defect, HttpError].} = {.async, raises: [Defect, HttpError].} =
if not isNil(server.handler): if not isNil(server.handler):
raise newException(HttpError, raise newException(HttpError,
"Callback already registered - cannot mix callback and accepts stypes!") "Callback already registered - cannot mix callback and accepts styles!")
trace "Awaiting new request"
let transp = await StreamServer(server).accept() let transp = await StreamServer(server).accept()
var stream: AsyncStream let stream = if server.secure:
when server is TlsHttpServer:
let tlsStream = newTLSServerAsyncStream( let tlsStream = newTLSServerAsyncStream(
newAsyncStreamReader(transp), newAsyncStreamReader(transp),
newAsyncStreamWriter(transp), newAsyncStreamWriter(transp),
@ -179,14 +181,15 @@ proc accept*(server: HttpServer | TlsHttpServer): Future[HttpRequest]
maxVersion = server.maxVersion, maxVersion = server.maxVersion,
flags = server.tlsFlags) flags = server.tlsFlags)
stream = AsyncStream( AsyncStream(
reader: tlsStream.reader, reader: tlsStream.reader,
writer: tlsStream.writer) writer: tlsStream.writer)
else: else:
stream = AsyncStream( AsyncStream(
reader: newAsyncStreamReader(transp), reader: newAsyncStreamReader(transp),
writer: newAsyncStreamWriter(transp)) writer: newAsyncStreamWriter(transp))
trace "Got new request", isTls = server.secure
return await server.parseRequest(stream) return await server.parseRequest(stream)
proc create*( proc create*(
@ -206,21 +209,20 @@ proc create*(
flags, flags,
child = StreamServer(server))) child = StreamServer(server)))
trace "Created HTTP Server", host = $address trace "Created HTTP Server", host = $server.localAddress()
return server return server
proc create*( proc create*(
_: typedesc[HttpServer], _: typedesc[HttpServer],
host: string, host: string,
port: Port,
handler: HttpAsyncCallback = nil, handler: HttpAsyncCallback = nil,
flags: set[ServerFlags] = {}): HttpServer flags: set[ServerFlags] = {}): HttpServer
{.raises: [Defect, CatchableError].} = # TODO: remove CatchableError {.raises: [Defect, CatchableError].} = # TODO: remove CatchableError
## Make a new HTTP Server ## Make a new HTTP Server
## ##
return HttpServer.create(initTAddress(host, port), handler, flags) return HttpServer.create(initTAddress(host), handler, flags)
proc create*( proc create*(
_: typedesc[TlsHttpServer], _: typedesc[TlsHttpServer],
@ -235,6 +237,7 @@ proc create*(
{.raises: [Defect, CatchableError].} = # TODO: remove CatchableError {.raises: [Defect, CatchableError].} = # TODO: remove CatchableError
var server = TlsHttpServer( var server = TlsHttpServer(
secure: true,
handler: handler, handler: handler,
tlsPrivateKey: tlsPrivateKey, tlsPrivateKey: tlsPrivateKey,
tlsCertificate: tlsCertificate, tlsCertificate: tlsCertificate,
@ -248,14 +251,13 @@ proc create*(
flags, flags,
child = StreamServer(server))) child = StreamServer(server)))
trace "Created TLS HTTP Server", host = $address trace "Created TLS HTTP Server", host = $server.localAddress()
return server return server
proc create*( proc create*(
_: typedesc[TlsHttpServer], _: typedesc[TlsHttpServer],
host: string, host: string,
port: Port,
tlsPrivateKey: TLSPrivateKey, tlsPrivateKey: TLSPrivateKey,
tlsCertificate: TLSCertificate, tlsCertificate: TLSCertificate,
handler: HttpAsyncCallback = nil, handler: HttpAsyncCallback = nil,
@ -265,7 +267,7 @@ proc create*(
tlsMaxVersion = TLSVersion.TLS12): TlsHttpServer tlsMaxVersion = TLSVersion.TLS12): TlsHttpServer
{.raises: [Defect, CatchableError].} = # TODO: remove CatchableError {.raises: [Defect, CatchableError].} = # TODO: remove CatchableError
TlsHttpServer.create( TlsHttpServer.create(
address = initTAddress(host, port), address = initTAddress(host),
handler = handler, handler = handler,
tlsPrivateKey = tlsPrivateKey, tlsPrivateKey = tlsPrivateKey,
tlsCertificate = tlsCertificate, tlsCertificate = tlsCertificate,