diff --git a/examples/autobahn_client.nim b/examples/autobahn_client.nim index 4f015672ec..6c7a0b313e 100644 --- a/examples/autobahn_client.nim +++ b/examples/autobahn_client.nim @@ -32,8 +32,8 @@ proc connectServer(path: string, factories: seq[ExtFactory] = @[]): Future[WSSes let ws = await WebSocket.connect( host = "127.0.0.1:$1" % [$serverPort], path = path, - secure=secure, - flags=clientFlags, + secure = secure, + flags = clientFlags, factories = factories ) return ws diff --git a/examples/client.nim b/examples/client.nim index 133e416e19..308bb9a8cd 100644 --- a/examples/client.nim +++ b/examples/client.nim @@ -17,8 +17,9 @@ import ../websock/websock proc main() {.async.} = let ws = when defined tls: await WebSocket.connect( - "127.0.0.1:8888", + "127.0.0.1:8889", path = "/wss", + secure = true, flags = {TLSFlags.NoVerifyHost, TLSFlags.NoVerifyServerName}) else: await WebSocket.connect( @@ -38,7 +39,7 @@ proc main() {.async.} = let dataStr = string.fromBytes(buff) trace "Server Response: ", data = dataStr - assert dataStr == reqData + doAssert dataStr == reqData break except WebSocketError as exc: error "WebSocket error:", exception = exc.msg diff --git a/examples/server.nim b/examples/server.nim index 412e39a9a2..84a8b00c02 100644 --- a/examples/server.nim +++ b/examples/server.nim @@ -17,11 +17,7 @@ import ../tests/keys proc handle(request: HttpRequest) {.async.} = trace "Handling request:", uri = request.uri.path - let path = when defined tls: "/wss" else: "/ws" - if request.uri.path != path: - return - trace "Initiating web socket connection." try: let deflateFactory = deflateFactory() let server = WSServer.new(factories = [deflateFactory]) @@ -49,26 +45,32 @@ proc handle(request: HttpRequest) {.async.} = when isMainModule: # we want to run parallel tests in CI # so we are using different port - const serverAddr = when defined tls: - "127.0.0.1:8889" - else: - "127.0.0.1:8888" - proc main() {.async.} = let - address = initTAddress(serverAddr) socketFlags = {ServerFlags.TcpNoDelay, ServerFlags.ReuseAddr} server = when defined tls: TlsHttpServer.create( - address = address, - handler = handle, + address = initTAddress("127.0.0.1:8889"), tlsPrivateKey = TLSPrivateKey.init(SecureKey), tlsCertificate = TLSCertificate.init(SecureCert), flags = socketFlags) else: - HttpServer.create(address, handle, flags = socketFlags) + HttpServer.create(initTAddress("127.0.0.1:8888"), flags = socketFlags) + + when defined accepts: + proc accepts() {.async, raises: [Defect].} = + while true: + try: + let req = await server.accept() + await req.handle() + except CatchableError as exc: + error "Transport error", exc = exc.msg + + asyncCheck accepts() + else: + server.handler = handle + server.start() - server.start() trace "Server listening on ", data = $server.localAddress() await server.join() diff --git a/websock/http/server.nim b/websock/http/server.nim index cb3a226c8b..530bbdc233 100644 --- a/websock/http/server.nim +++ b/websock/http/server.nim @@ -29,13 +29,17 @@ type HttpServer* = ref object of StreamServer handler*: HttpAsyncCallback + case secure*: bool: + of true: + tlsFlags*: set[TLSFlags] + tlsPrivateKey*: TLSPrivateKey + tlsCertificate*: TLSCertificate + minVersion*: TLSVersion + maxVersion*: TLSVersion + else: + discard - TlsHttpServer* = ref object of HttpServer - tlsFlags*: set[TLSFlags] - tlsPrivateKey*: TLSPrivateKey - tlsCertificate*: TLSCertificate - minVersion*: TLSVersion - maxVersion*: TLSVersion + TlsHttpServer* = HttpServer proc validateRequest( stream: AsyncStreamWriter, @@ -73,7 +77,7 @@ proc parseRequest( # Timeout trace "Timeout expired while receiving headers", address = $remoteAddr await stream.writer.sendError(Http408, version = HttpVersion11) - return + raise newException(HttpError, "Didn't read headers in time!") let hlen = hlenfut.read() buffer.setLen(hlen) @@ -82,7 +86,7 @@ proc parseRequest( # Header could not be parsed trace "Malformed header received", address = $remoteAddr await stream.writer.sendError(Http400, version = HttpVersion11) - return + raise newException(HttpError, "Malformed header received") var vres = await stream.writer.validateRequest(requestData) let hdrs = @@ -94,9 +98,9 @@ proc parseRequest( if vres == ReqStatus.ErrorFailure: trace "Remote peer disconnected", address = $remoteAddr - return + raise newException(HttpError, "Remote peer disconnected") - debug "Received valid HTTP request", address = $remoteAddr + trace "Received valid HTTP request", address = $remoteAddr return HttpRequest( headers: hdrs, stream: stream, @@ -110,8 +114,6 @@ proc parseRequest( trace "Remote peer disconnected", address = $remoteAddr except TransportOsError as exc: trace "Problems with networking", address = $remoteAddr, error = exc.msg - except CatchableError as exc: - debug "Unknown exception", address = $remoteAddr, error = exc.msg proc handleConnCb( server: StreamServer, @@ -160,16 +162,16 @@ proc handleTlsConnCb( finally: await stream.closeWait() -proc accept*(server: HttpServer | TlsHttpServer): Future[HttpRequest] +proc accept*(server: HttpServer): Future[HttpRequest] {.async, raises: [Defect, HttpError].} = if not isNil(server.handler): raise newException(HttpError, - "Callback already registered - cannot mix callback and accepts stypes!") + "Callback already registered - cannot mix callback and accepts styles!") + trace "Awaiting new request" let transp = await StreamServer(server).accept() - var stream: AsyncStream - when server is TlsHttpServer: + let stream = if server.secure: let tlsStream = newTLSServerAsyncStream( newAsyncStreamReader(transp), newAsyncStreamWriter(transp), @@ -179,14 +181,15 @@ proc accept*(server: HttpServer | TlsHttpServer): Future[HttpRequest] maxVersion = server.maxVersion, flags = server.tlsFlags) - stream = AsyncStream( + AsyncStream( reader: tlsStream.reader, writer: tlsStream.writer) else: - stream = AsyncStream( + AsyncStream( reader: newAsyncStreamReader(transp), writer: newAsyncStreamWriter(transp)) + trace "Got new request", isTls = server.secure return await server.parseRequest(stream) proc create*( @@ -206,21 +209,20 @@ proc create*( flags, child = StreamServer(server))) - trace "Created HTTP Server", host = $address + trace "Created HTTP Server", host = $server.localAddress() return server proc create*( _: typedesc[HttpServer], host: string, - port: Port, handler: HttpAsyncCallback = nil, flags: set[ServerFlags] = {}): HttpServer {.raises: [Defect, CatchableError].} = # TODO: remove CatchableError ## Make a new HTTP Server ## - return HttpServer.create(initTAddress(host, port), handler, flags) + return HttpServer.create(initTAddress(host), handler, flags) proc create*( _: typedesc[TlsHttpServer], @@ -235,6 +237,7 @@ proc create*( {.raises: [Defect, CatchableError].} = # TODO: remove CatchableError var server = TlsHttpServer( + secure: true, handler: handler, tlsPrivateKey: tlsPrivateKey, tlsCertificate: tlsCertificate, @@ -248,14 +251,13 @@ proc create*( flags, child = StreamServer(server))) - trace "Created TLS HTTP Server", host = $address + trace "Created TLS HTTP Server", host = $server.localAddress() return server proc create*( _: typedesc[TlsHttpServer], host: string, - port: Port, tlsPrivateKey: TLSPrivateKey, tlsCertificate: TLSCertificate, handler: HttpAsyncCallback = nil, @@ -265,7 +267,7 @@ proc create*( tlsMaxVersion = TLSVersion.TLS12): TlsHttpServer {.raises: [Defect, CatchableError].} = # TODO: remove CatchableError TlsHttpServer.create( - address = initTAddress(host, port), + address = initTAddress(host), handler = handler, tlsPrivateKey = tlsPrivateKey, tlsCertificate = tlsCertificate,