Make HttpServer a case object (#84)
* cleanup examples * more examples cleanup * make HttServer a case object * propagate errors when handling requests * don't extend HttpServer * remove port from create that takes a string host make more consistent with client's `connect`
This commit is contained in:
parent
06ae75cf7f
commit
7756dd1e77
|
@ -32,8 +32,8 @@ proc connectServer(path: string, factories: seq[ExtFactory] = @[]): Future[WSSes
|
||||||
let ws = await WebSocket.connect(
|
let ws = await WebSocket.connect(
|
||||||
host = "127.0.0.1:$1" % [$serverPort],
|
host = "127.0.0.1:$1" % [$serverPort],
|
||||||
path = path,
|
path = path,
|
||||||
secure=secure,
|
secure = secure,
|
||||||
flags=clientFlags,
|
flags = clientFlags,
|
||||||
factories = factories
|
factories = factories
|
||||||
)
|
)
|
||||||
return ws
|
return ws
|
||||||
|
|
|
@ -17,8 +17,9 @@ import ../websock/websock
|
||||||
proc main() {.async.} =
|
proc main() {.async.} =
|
||||||
let ws = when defined tls:
|
let ws = when defined tls:
|
||||||
await WebSocket.connect(
|
await WebSocket.connect(
|
||||||
"127.0.0.1:8888",
|
"127.0.0.1:8889",
|
||||||
path = "/wss",
|
path = "/wss",
|
||||||
|
secure = true,
|
||||||
flags = {TLSFlags.NoVerifyHost, TLSFlags.NoVerifyServerName})
|
flags = {TLSFlags.NoVerifyHost, TLSFlags.NoVerifyServerName})
|
||||||
else:
|
else:
|
||||||
await WebSocket.connect(
|
await WebSocket.connect(
|
||||||
|
@ -38,7 +39,7 @@ proc main() {.async.} =
|
||||||
let dataStr = string.fromBytes(buff)
|
let dataStr = string.fromBytes(buff)
|
||||||
trace "Server Response: ", data = dataStr
|
trace "Server Response: ", data = dataStr
|
||||||
|
|
||||||
assert dataStr == reqData
|
doAssert dataStr == reqData
|
||||||
break
|
break
|
||||||
except WebSocketError as exc:
|
except WebSocketError as exc:
|
||||||
error "WebSocket error:", exception = exc.msg
|
error "WebSocket error:", exception = exc.msg
|
||||||
|
|
|
@ -17,11 +17,7 @@ import ../tests/keys
|
||||||
|
|
||||||
proc handle(request: HttpRequest) {.async.} =
|
proc handle(request: HttpRequest) {.async.} =
|
||||||
trace "Handling request:", uri = request.uri.path
|
trace "Handling request:", uri = request.uri.path
|
||||||
let path = when defined tls: "/wss" else: "/ws"
|
|
||||||
if request.uri.path != path:
|
|
||||||
return
|
|
||||||
|
|
||||||
trace "Initiating web socket connection."
|
|
||||||
try:
|
try:
|
||||||
let deflateFactory = deflateFactory()
|
let deflateFactory = deflateFactory()
|
||||||
let server = WSServer.new(factories = [deflateFactory])
|
let server = WSServer.new(factories = [deflateFactory])
|
||||||
|
@ -49,26 +45,32 @@ proc handle(request: HttpRequest) {.async.} =
|
||||||
when isMainModule:
|
when isMainModule:
|
||||||
# we want to run parallel tests in CI
|
# we want to run parallel tests in CI
|
||||||
# so we are using different port
|
# so we are using different port
|
||||||
const serverAddr = when defined tls:
|
|
||||||
"127.0.0.1:8889"
|
|
||||||
else:
|
|
||||||
"127.0.0.1:8888"
|
|
||||||
|
|
||||||
proc main() {.async.} =
|
proc main() {.async.} =
|
||||||
let
|
let
|
||||||
address = initTAddress(serverAddr)
|
|
||||||
socketFlags = {ServerFlags.TcpNoDelay, ServerFlags.ReuseAddr}
|
socketFlags = {ServerFlags.TcpNoDelay, ServerFlags.ReuseAddr}
|
||||||
server = when defined tls:
|
server = when defined tls:
|
||||||
TlsHttpServer.create(
|
TlsHttpServer.create(
|
||||||
address = address,
|
address = initTAddress("127.0.0.1:8889"),
|
||||||
handler = handle,
|
|
||||||
tlsPrivateKey = TLSPrivateKey.init(SecureKey),
|
tlsPrivateKey = TLSPrivateKey.init(SecureKey),
|
||||||
tlsCertificate = TLSCertificate.init(SecureCert),
|
tlsCertificate = TLSCertificate.init(SecureCert),
|
||||||
flags = socketFlags)
|
flags = socketFlags)
|
||||||
else:
|
else:
|
||||||
HttpServer.create(address, handle, flags = socketFlags)
|
HttpServer.create(initTAddress("127.0.0.1:8888"), flags = socketFlags)
|
||||||
|
|
||||||
|
when defined accepts:
|
||||||
|
proc accepts() {.async, raises: [Defect].} =
|
||||||
|
while true:
|
||||||
|
try:
|
||||||
|
let req = await server.accept()
|
||||||
|
await req.handle()
|
||||||
|
except CatchableError as exc:
|
||||||
|
error "Transport error", exc = exc.msg
|
||||||
|
|
||||||
|
asyncCheck accepts()
|
||||||
|
else:
|
||||||
|
server.handler = handle
|
||||||
|
server.start()
|
||||||
|
|
||||||
server.start()
|
|
||||||
trace "Server listening on ", data = $server.localAddress()
|
trace "Server listening on ", data = $server.localAddress()
|
||||||
await server.join()
|
await server.join()
|
||||||
|
|
||||||
|
|
|
@ -29,13 +29,17 @@ type
|
||||||
|
|
||||||
HttpServer* = ref object of StreamServer
|
HttpServer* = ref object of StreamServer
|
||||||
handler*: HttpAsyncCallback
|
handler*: HttpAsyncCallback
|
||||||
|
case secure*: bool:
|
||||||
|
of true:
|
||||||
|
tlsFlags*: set[TLSFlags]
|
||||||
|
tlsPrivateKey*: TLSPrivateKey
|
||||||
|
tlsCertificate*: TLSCertificate
|
||||||
|
minVersion*: TLSVersion
|
||||||
|
maxVersion*: TLSVersion
|
||||||
|
else:
|
||||||
|
discard
|
||||||
|
|
||||||
TlsHttpServer* = ref object of HttpServer
|
TlsHttpServer* = HttpServer
|
||||||
tlsFlags*: set[TLSFlags]
|
|
||||||
tlsPrivateKey*: TLSPrivateKey
|
|
||||||
tlsCertificate*: TLSCertificate
|
|
||||||
minVersion*: TLSVersion
|
|
||||||
maxVersion*: TLSVersion
|
|
||||||
|
|
||||||
proc validateRequest(
|
proc validateRequest(
|
||||||
stream: AsyncStreamWriter,
|
stream: AsyncStreamWriter,
|
||||||
|
@ -73,7 +77,7 @@ proc parseRequest(
|
||||||
# Timeout
|
# Timeout
|
||||||
trace "Timeout expired while receiving headers", address = $remoteAddr
|
trace "Timeout expired while receiving headers", address = $remoteAddr
|
||||||
await stream.writer.sendError(Http408, version = HttpVersion11)
|
await stream.writer.sendError(Http408, version = HttpVersion11)
|
||||||
return
|
raise newException(HttpError, "Didn't read headers in time!")
|
||||||
|
|
||||||
let hlen = hlenfut.read()
|
let hlen = hlenfut.read()
|
||||||
buffer.setLen(hlen)
|
buffer.setLen(hlen)
|
||||||
|
@ -82,7 +86,7 @@ proc parseRequest(
|
||||||
# Header could not be parsed
|
# Header could not be parsed
|
||||||
trace "Malformed header received", address = $remoteAddr
|
trace "Malformed header received", address = $remoteAddr
|
||||||
await stream.writer.sendError(Http400, version = HttpVersion11)
|
await stream.writer.sendError(Http400, version = HttpVersion11)
|
||||||
return
|
raise newException(HttpError, "Malformed header received")
|
||||||
|
|
||||||
var vres = await stream.writer.validateRequest(requestData)
|
var vres = await stream.writer.validateRequest(requestData)
|
||||||
let hdrs =
|
let hdrs =
|
||||||
|
@ -94,9 +98,9 @@ proc parseRequest(
|
||||||
|
|
||||||
if vres == ReqStatus.ErrorFailure:
|
if vres == ReqStatus.ErrorFailure:
|
||||||
trace "Remote peer disconnected", address = $remoteAddr
|
trace "Remote peer disconnected", address = $remoteAddr
|
||||||
return
|
raise newException(HttpError, "Remote peer disconnected")
|
||||||
|
|
||||||
debug "Received valid HTTP request", address = $remoteAddr
|
trace "Received valid HTTP request", address = $remoteAddr
|
||||||
return HttpRequest(
|
return HttpRequest(
|
||||||
headers: hdrs,
|
headers: hdrs,
|
||||||
stream: stream,
|
stream: stream,
|
||||||
|
@ -110,8 +114,6 @@ proc parseRequest(
|
||||||
trace "Remote peer disconnected", address = $remoteAddr
|
trace "Remote peer disconnected", address = $remoteAddr
|
||||||
except TransportOsError as exc:
|
except TransportOsError as exc:
|
||||||
trace "Problems with networking", address = $remoteAddr, error = exc.msg
|
trace "Problems with networking", address = $remoteAddr, error = exc.msg
|
||||||
except CatchableError as exc:
|
|
||||||
debug "Unknown exception", address = $remoteAddr, error = exc.msg
|
|
||||||
|
|
||||||
proc handleConnCb(
|
proc handleConnCb(
|
||||||
server: StreamServer,
|
server: StreamServer,
|
||||||
|
@ -160,16 +162,16 @@ proc handleTlsConnCb(
|
||||||
finally:
|
finally:
|
||||||
await stream.closeWait()
|
await stream.closeWait()
|
||||||
|
|
||||||
proc accept*(server: HttpServer | TlsHttpServer): Future[HttpRequest]
|
proc accept*(server: HttpServer): Future[HttpRequest]
|
||||||
{.async, raises: [Defect, HttpError].} =
|
{.async, raises: [Defect, HttpError].} =
|
||||||
|
|
||||||
if not isNil(server.handler):
|
if not isNil(server.handler):
|
||||||
raise newException(HttpError,
|
raise newException(HttpError,
|
||||||
"Callback already registered - cannot mix callback and accepts stypes!")
|
"Callback already registered - cannot mix callback and accepts styles!")
|
||||||
|
|
||||||
|
trace "Awaiting new request"
|
||||||
let transp = await StreamServer(server).accept()
|
let transp = await StreamServer(server).accept()
|
||||||
var stream: AsyncStream
|
let stream = if server.secure:
|
||||||
when server is TlsHttpServer:
|
|
||||||
let tlsStream = newTLSServerAsyncStream(
|
let tlsStream = newTLSServerAsyncStream(
|
||||||
newAsyncStreamReader(transp),
|
newAsyncStreamReader(transp),
|
||||||
newAsyncStreamWriter(transp),
|
newAsyncStreamWriter(transp),
|
||||||
|
@ -179,14 +181,15 @@ proc accept*(server: HttpServer | TlsHttpServer): Future[HttpRequest]
|
||||||
maxVersion = server.maxVersion,
|
maxVersion = server.maxVersion,
|
||||||
flags = server.tlsFlags)
|
flags = server.tlsFlags)
|
||||||
|
|
||||||
stream = AsyncStream(
|
AsyncStream(
|
||||||
reader: tlsStream.reader,
|
reader: tlsStream.reader,
|
||||||
writer: tlsStream.writer)
|
writer: tlsStream.writer)
|
||||||
else:
|
else:
|
||||||
stream = AsyncStream(
|
AsyncStream(
|
||||||
reader: newAsyncStreamReader(transp),
|
reader: newAsyncStreamReader(transp),
|
||||||
writer: newAsyncStreamWriter(transp))
|
writer: newAsyncStreamWriter(transp))
|
||||||
|
|
||||||
|
trace "Got new request", isTls = server.secure
|
||||||
return await server.parseRequest(stream)
|
return await server.parseRequest(stream)
|
||||||
|
|
||||||
proc create*(
|
proc create*(
|
||||||
|
@ -206,21 +209,20 @@ proc create*(
|
||||||
flags,
|
flags,
|
||||||
child = StreamServer(server)))
|
child = StreamServer(server)))
|
||||||
|
|
||||||
trace "Created HTTP Server", host = $address
|
trace "Created HTTP Server", host = $server.localAddress()
|
||||||
|
|
||||||
return server
|
return server
|
||||||
|
|
||||||
proc create*(
|
proc create*(
|
||||||
_: typedesc[HttpServer],
|
_: typedesc[HttpServer],
|
||||||
host: string,
|
host: string,
|
||||||
port: Port,
|
|
||||||
handler: HttpAsyncCallback = nil,
|
handler: HttpAsyncCallback = nil,
|
||||||
flags: set[ServerFlags] = {}): HttpServer
|
flags: set[ServerFlags] = {}): HttpServer
|
||||||
{.raises: [Defect, CatchableError].} = # TODO: remove CatchableError
|
{.raises: [Defect, CatchableError].} = # TODO: remove CatchableError
|
||||||
## Make a new HTTP Server
|
## Make a new HTTP Server
|
||||||
##
|
##
|
||||||
|
|
||||||
return HttpServer.create(initTAddress(host, port), handler, flags)
|
return HttpServer.create(initTAddress(host), handler, flags)
|
||||||
|
|
||||||
proc create*(
|
proc create*(
|
||||||
_: typedesc[TlsHttpServer],
|
_: typedesc[TlsHttpServer],
|
||||||
|
@ -235,6 +237,7 @@ proc create*(
|
||||||
{.raises: [Defect, CatchableError].} = # TODO: remove CatchableError
|
{.raises: [Defect, CatchableError].} = # TODO: remove CatchableError
|
||||||
|
|
||||||
var server = TlsHttpServer(
|
var server = TlsHttpServer(
|
||||||
|
secure: true,
|
||||||
handler: handler,
|
handler: handler,
|
||||||
tlsPrivateKey: tlsPrivateKey,
|
tlsPrivateKey: tlsPrivateKey,
|
||||||
tlsCertificate: tlsCertificate,
|
tlsCertificate: tlsCertificate,
|
||||||
|
@ -248,14 +251,13 @@ proc create*(
|
||||||
flags,
|
flags,
|
||||||
child = StreamServer(server)))
|
child = StreamServer(server)))
|
||||||
|
|
||||||
trace "Created TLS HTTP Server", host = $address
|
trace "Created TLS HTTP Server", host = $server.localAddress()
|
||||||
|
|
||||||
return server
|
return server
|
||||||
|
|
||||||
proc create*(
|
proc create*(
|
||||||
_: typedesc[TlsHttpServer],
|
_: typedesc[TlsHttpServer],
|
||||||
host: string,
|
host: string,
|
||||||
port: Port,
|
|
||||||
tlsPrivateKey: TLSPrivateKey,
|
tlsPrivateKey: TLSPrivateKey,
|
||||||
tlsCertificate: TLSCertificate,
|
tlsCertificate: TLSCertificate,
|
||||||
handler: HttpAsyncCallback = nil,
|
handler: HttpAsyncCallback = nil,
|
||||||
|
@ -265,7 +267,7 @@ proc create*(
|
||||||
tlsMaxVersion = TLSVersion.TLS12): TlsHttpServer
|
tlsMaxVersion = TLSVersion.TLS12): TlsHttpServer
|
||||||
{.raises: [Defect, CatchableError].} = # TODO: remove CatchableError
|
{.raises: [Defect, CatchableError].} = # TODO: remove CatchableError
|
||||||
TlsHttpServer.create(
|
TlsHttpServer.create(
|
||||||
address = initTAddress(host, port),
|
address = initTAddress(host),
|
||||||
handler = handler,
|
handler = handler,
|
||||||
tlsPrivateKey = tlsPrivateKey,
|
tlsPrivateKey = tlsPrivateKey,
|
||||||
tlsCertificate = tlsCertificate,
|
tlsCertificate = tlsCertificate,
|
||||||
|
|
Loading…
Reference in New Issue