diff --git a/src/ws.nim b/src/ws.nim index 199810fe..a47c065e 100644 --- a/src/ws.nim +++ b/src/ws.nim @@ -7,32 +7,6 @@ const HttpBodyTimeout = 12.seconds # timeout for receiving body (12 sec) HeadersMark = @[byte(0x0D), byte(0x0A), byte(0x0D), byte(0x0A)] -type HeaderVerificationError* {.pure.} = enum - None - ## No error. - UnsupportedVersion - ## The Sec-Websocket-Version header gave an unsupported version. - ## The only currently supported version is 13. - NoKey - ## No Sec-Websocket-Key was provided. - ProtocolAdvertised - ## A protocol was advertised but the server gave no protocol. - NoProtocolsSupported - ## None of the advertised protocols match the server protocol. - NoProtocolAdvertised - ## Server asked for a protocol but no protocol was advertised. - -proc `$`*(error: HeaderVerificationError): string = - const errorTable: array[HeaderVerificationError, string] = [ - "no error", - "the only supported sec-websocket-version is 13", - "no sec-websocket-key provided", - "server does not support protocol negotation", - "no advertised protocol supported", - "no protocol advertised" - ] - result = errorTable[error] - type ReadyState* = enum Connecting = 0 # The connection is not yet open. @@ -41,23 +15,23 @@ type Closed = 3 # The connection is closed or couldn't be opened. WebSocket* = ref object + tcpSocket*: StreamTransport version*: int key*: string protocol*: string readyState*: ReadyState masked*: bool # send masked packets - HttpServer* = ref object - server*: StreamServer + AsyncCallback = proc (transp: StreamTransport, header: HttpRequestHeader): Future[void] {.closure, gcsafe.} + HttpServer* = ref object of StreamServer callback: AsyncCallback - maxBody: int ReqStatus = enum Success, Error, ErrorFailure WebSocketError* = object of IOError -proc sendAnswer(transp: StreamTransport, version: HttpVersion, code: HttpCode, +proc sendHTTPResponse*(transp: StreamTransport, version: HttpVersion, code: HttpCode, data: string = ""): Future[bool] {.async.} = var answer = $version answer.add(" ") @@ -83,7 +57,7 @@ proc validateRequest(transp: StreamTransport, if header.meth notin {MethodGet}: # Request method is either PUT or DELETE. debug "GET method is only allowed", address = transp.remoteAddress() - if await transp.sendAnswer(header.version, Http405): + if await transp.sendHTTPResponse(header.version, Http405): result = Error else: result = ErrorFailure @@ -93,7 +67,7 @@ proc validateRequest(transp: StreamTransport, if length <= 0: # request length could not be calculated. debug "Content-Length is missing or 0", address = transp.remoteAddress() - if await transp.sendAnswer(header.version, Http411): + if await transp.sendHTTPResponse(header.version, Http411): result = Error else: result = ErrorFailure @@ -103,7 +77,7 @@ proc validateRequest(transp: StreamTransport, # request length is more then `MaxHttpRequestSize`. debug "Maximum size of request body reached", address = transp.remoteAddress() - if await transp.sendAnswer(header.version, Http413): + if await transp.sendHTTPResponse(header.version, Http413): result = Error else: result = ErrorFailure @@ -111,8 +85,9 @@ proc validateRequest(transp: StreamTransport, result = Success -proc serveClient(server: StreamServer, transp: StreamTransport) {.async, gcsafe.} = +proc serveClient(server: StreamServer, transp: StreamTransport) {.async.} = ## Process transport data to the RPC server + var httpServer = cast[HttpServer](server) var buffer = newSeq[byte](MaxHttpHeadersSize) var header: HttpRequestHeader var connection: string @@ -125,7 +100,7 @@ proc serveClient(server: StreamServer, transp: StreamTransport) {.async, gcsafe. # Timeout debug "Timeout expired while receiving headers", address = transp.remoteAddress() - let res = await transp.sendAnswer(HttpVersion11, Http408) + let res = await transp.sendHTTPResponse(HttpVersion11, Http408) await transp.closeWait() return else: @@ -136,14 +111,14 @@ proc serveClient(server: StreamServer, transp: StreamTransport) {.async, gcsafe. # Header could not be parsed debug "Malformed header received", address = transp.remoteAddress() - let res = await transp.sendAnswer(HttpVersion11, Http400) + let res = await transp.sendHTTPResponse(HttpVersion11, Http400) await transp.closeWait() return except TransportLimitError: # size of headers exceeds `MaxHttpHeadersSize` debug "Maximum size of headers limit reached", address = transp.remoteAddress() - let res = await transp.sendAnswer(HttpVersion11, Http413) + let res = await transp.sendHTTPResponse(HttpVersion11, Http413) await transp.closeWait() return except TransportIncompleteError: @@ -166,15 +141,18 @@ proc serveClient(server: StreamServer, transp: StreamTransport) {.async, gcsafe. if vres == Success: trace "Received valid RPC request", address = $transp.remoteAddress() - info "Header: ", header - debug "Disconnecting client", address = transp.remoteAddress() - await transp.closeWait() + + # Call the user's callback. + if httpServer.callback != nil: + await httpServer.callback(transp, header) + elif vres == ErrorFailure: debug "Remote peer disconnected", address = transp.remoteAddress() await transp.closeWait() -proc newHttpServer*(address: string, +proc newHttpServer*(address: string, handler:AsyncCallback, flags: set[ServerFlags] = {ReuseAddr}): HttpServer = - let address = initTAddress(address) new result - result.server = createStreamServer(address, serveClient, {ReuseAddr}) + let address = initTAddress(address) + result.callback = handler + result = cast[HttpServer](createStreamServer(address, serveClient, flags, child = cast[StreamServer](result))) diff --git a/test/server.nim b/test/server.nim index 9dba1cee..17716f87 100644 --- a/test/server.nim +++ b/test/server.nim @@ -1,7 +1,13 @@ -import ws, chronos +import ws, chronos, chronicles, httputils + +proc cb(transp: StreamTransport, header: HttpRequestHeader) {.async.} = + info "Header: ", header + let res = await transp.sendHTTPResponse(HttpVersion11, Http200, "Hello World") + debug "Disconnecting client", address = transp.remoteAddress() + await transp.closeWait() when isMainModule: let address = "127.0.0.1:8888" - var httpServer = newHttpServer(address) - httpServer.server.start() - waitFor httpServer.server.join() + var httpServer = newHttpServer(address, cb) + httpServer.start() + waitFor httpServer.join()