From e4f00698eadc94f09b828adc948eeecf27db72d7 Mon Sep 17 00:00:00 2001 From: Arijit Das Date: Tue, 6 Apr 2021 02:31:10 +0530 Subject: [PATCH] Update http to use chronos http. (#6) * Update http to use chronos http. * Add stream.nim file. * Address comments. * Fix CI failure. * Minor change. * Address comments. * Fix windows CI failing test. * minor cleanup * spacess * more idiomatic connect * use stew/base10 Co-authored-by: Dmitriy Ryajov --- examples/client.nim | 11 +- examples/server.nim | 77 +++++---- src/http.nim | 255 ----------------------------- src/random.nim | 1 - src/stream.nim | 53 ++++++ src/ws.nim | 190 ++++++++++++++-------- tests/helpers.nim | 54 ------- tests/testframes.nim | 1 - tests/testwebsockets.nim | 342 +++++++++++++++++++-------------------- 9 files changed, 395 insertions(+), 589 deletions(-) delete mode 100644 src/http.nim create mode 100644 src/stream.nim delete mode 100644 tests/helpers.nim diff --git a/examples/client.nim b/examples/client.nim index fb7c109..cb00025 100644 --- a/examples/client.nim +++ b/examples/client.nim @@ -1,8 +1,9 @@ -import ../src/ws, nativesockets, chronos, os, chronicles, stew/byteutils +import ../src/ws, nativesockets, chronos,chronicles, stew/byteutils proc main() {.async.} = - let ws = await connect( - "127.0.0.1", Port(8888), + let ws = await WebSocket.connect( + "127.0.0.1", + Port(8888), path = "/ws") debug "Websocket client: ", State = ws.readyState @@ -16,10 +17,10 @@ proc main() {.async.} = break let dataStr = string.fromBytes(buff) - debug "Server:", data = dataStr + debug "Server Response: ", data = dataStr assert dataStr == reqData - return # bail out + break except WebSocketError as exc: error "WebSocket error:", exception = exc.msg diff --git a/examples/server.nim b/examples/server.nim index b388881..379f047 100644 --- a/examples/server.nim +++ b/examples/server.nim @@ -1,37 +1,48 @@ -import ../src/ws, ../src/http, chronos, chronicles, httputils, stew/byteutils + import pkg/[chronos, + chronos/apps/http/httpserver, + chronicles, + httputils, + stew/byteutils] +import ../src/ws -proc cb(transp: StreamTransport, header: HttpRequestHeader) {.async.} = - debug "Handling request:", uri = header.uri() - if header.uri() == "/ws": - debug "Initiating web socket connection." - try: - var ws = await createServer(header, transp, "") - if ws.readyState != Open: - error "Failed to open websocket connection." - return +proc process(r: RequestFence): Future[HttpResponseRef] {.async.} = + if r.isOk(): + let request = r.get() + debug "Handling request:", uri = request.uri.path + if request.uri.path == "/ws": + debug "Initiating web socket connection." + try: + var ws = await createServer(request,"") + if ws.readyState != Open: + error "Failed to open websocket connection." + return + debug "Websocket handshake completed." + while ws.readyState != ReadyState.Closed: + # Only reads header for data frame. + var recvData = await ws.recv() + if recvData.len <= 0: + debug "Empty messages" + break - debug "Websocket handshake completed." - while ws.readyState != ReadyState.Closed: - # Only reads header for data frame. - var recvData = await ws.recv() - if recvData.len <= 0: - debug "Empty messages" - break - - # debug "Response: ", data = string.fromBytes(recvData), size = recvData.len - debug "Response: ", size = recvData.len - await ws.send(recvData) - # await ws.close() - - except WebSocketError as exc: - error "WebSocket error:", exception = exc.msg - - discard await transp.sendHTTPResponse(HttpVersion11, Http200, "Hello World") - await transp.closeWait() + # debug "Client Response: ", data = string.fromBytes(recvData), size = recvData.len + debug "Client Response: ", size = recvData.len + await ws.send(recvData) + # await ws.close() + + except WebSocketError as exc: + error "WebSocket error:", exception = exc.msg + discard await request.respond(Http200, "Hello World") + else: + return dumbResponse() when isMainModule: - let address = "127.0.0.1:8888" - var httpServer = newHttpServer(address, cb) - httpServer.start() - echo "Server started..." - waitFor httpServer.join() + let address = initTAddress("127.0.0.1:8888") + let socketFlags = {ServerFlags.TcpNoDelay, ServerFlags.ReuseAddr} + let res = HttpServerRef.new( + address, process, + socketFlags = socketFlags) + + let server = res.get() + server.start() + info "Server listening at ", data = address + waitFor server.join() \ No newline at end of file diff --git a/src/http.nim b/src/http.nim deleted file mode 100644 index 0ccfc49..0000000 --- a/src/http.nim +++ /dev/null @@ -1,255 +0,0 @@ -import chronos, chronos/timer, httputils, chronicles, uri, tables, strutils - -const - MaxHttpHeadersSize = 8192 # maximum size of HTTP headers in octets - MaxHttpRequestSize = 128 * 1024 # maximum size of HTTP body in octets - HttpHeadersTimeout = timer.seconds(120) # timeout for receiving headers (120 sec) - CRLF* = "\r\n" - HeaderSep = @[byte('\c'), byte('\L'), byte('\c'), byte('\L')] - -type - HttpClient* = ref object - connected: bool - currentURL: Uri ## Where we are currently connected. - headers: HttpHeaders ## Headers to send in requests. - transp*: StreamTransport - buf: seq[byte] - - HttpHeaders* = object - table*: TableRef[string, seq[string]] - - ReqStatus = enum - Success, Error, ErrorFailure - - AsyncCallback = proc (transp: StreamTransport, - header: HttpRequestHeader): Future[void] {.closure, gcsafe.} - - HttpServer* = ref object of StreamServer - callback: AsyncCallback - -proc recvData(transp: StreamTransport): Future[seq[byte]] {.async.} = - var buffer = newSeq[byte](MaxHttpHeadersSize) - var error = false - try: - let hlenfut = transp.readUntil(addr buffer[0], MaxHttpHeadersSize, - sep = HeaderSep) - let ores = await withTimeout(hlenfut, HttpHeadersTimeout) - if not ores: - # Timeout - debug "Timeout expired while receiving headers", - address = transp.remoteAddress() - error = true - else: - let hlen = hlenfut.read() - buffer.setLen(hlen) - except TransportLimitError: - # size of headers exceeds `MaxHttpHeadersSize` - debug "Maximum size of headers limit reached", - address = transp.remoteAddress() - error = true - except TransportIncompleteError: - # remote peer disconnected - debug "Remote peer disconnected", address = transp.remoteAddress() - error = true - except TransportOsError as exc: - debug "Problems with networking", address = transp.remoteAddress(), - error = exc.msg - error = true - - if error: - buffer.setLen(0) - return buffer - -proc connect(client: HttpClient, url: Uri) {.async.} = - if client.connected: - return - - let port = - if url.port == "": 80 - else: url.port.parseInt - - client.transp = await connect(initTAddress(url.hostname, port)) - - # May be connected through proxy but remember actual URL being accessed - client.currentURL = url - client.connected = true - -proc generateHeaders( - requestUrl: Uri, - httpMethod: string, - additionalHeaders: HttpHeaders): string = - # GET - var headers = httpMethod.toUpperAscii() - headers.add ' ' - - if not requestUrl.path.startsWith("/"): headers.add '/' - headers.add(requestUrl.path) - - # HTTP/1.1\c\l - headers.add(" HTTP/1.1" & CRLF) - - for key, val in additionalHeaders.table: - headers.add(key & ": " & val.join(", ") & CRLF) - headers.add(CRLF) - return headers - -# Send request to the client. Currently only supports HTTP get method. -proc request*( - client: HttpClient, - url, - httpMethod: string, - body = "", - headers: HttpHeaders): Future[seq[byte]] {.async.} = - # Helper that actually makes the request. Does not handle redirects. - - let requestUrl = parseUri(url) - if requestUrl.scheme == "": - raise newException(ValueError, "No uri scheme supplied.") - - await connect(client, requestUrl) - - let headerString = generateHeaders(requestUrl, httpMethod, headers) - let res = await client.transp.write(headerString) - if res != len(headerString): - raise newException(ValueError, "Error while send request to client") - - var value = await client.transp.recvData() - if value.len == 0: - raise newException(ValueError, "Empty response from server") - return value - -proc sendHTTPResponse*( - transp: StreamTransport, - version: HttpVersion, - code: HttpCode, - data: string = ""): Future[bool] {.async.} = - - var answer = $version - answer.add(" ") - answer.add($code) - answer.add(CRLF) - answer.add("Date: " & httpDate() & CRLF) - if len(data) > 0: - answer.add("Content-Type: application/json" & CRLF) - answer.add("Content-Length: " & $len(data) & CRLF) - answer.add(CRLF) - if len(data) > 0: - answer.add(data) - - let res = await transp.write(answer) - if res == len(answer): - return true - - raise newException(IOError, "Failed to send http request.") - -proc validateRequest( - transp: StreamTransport, - header: HttpRequestHeader): Future[ReqStatus] {.async.} = - - if header.meth notin {MethodGet}: - debug "GET method is only allowed", address = transp.remoteAddress() - if await transp.sendHTTPResponse(header.version, Http405): - return Error - else: - return ErrorFailure - - var hlen = header.contentLength() - if hlen < 0 or hlen > MaxHttpRequestSize: - debug "Invalid header length", address = transp.remoteAddress() - if await transp.sendHTTPResponse(header.version, Http413): - return Error - else: - return ErrorFailure - - return Success - -proc serveClient(server: StreamServer, transp: StreamTransport) {.async.} = - ## Process transport data to the HTTP server - ## - - var httpServer = cast[HttpServer](server) - var buffer = newSeq[byte](MaxHttpHeadersSize) - var header: HttpRequestHeader - - debug "Received connection", address = $transp.remoteAddress() - try: - let hlenfut = transp.readUntil(addr buffer[0], MaxHttpHeadersSize, - sep = HeaderSep) - let ores = await withTimeout(hlenfut, HttpHeadersTimeout) - if not ores: - # Timeout - debug "Timeout expired while receiving headers", - address = transp.remoteAddress() - discard await transp.sendHTTPResponse(HttpVersion11, Http408) - await transp.closeWait() - return - else: - let hlen = hlenfut.read() - buffer.setLen(hlen) - header = buffer.parseRequest() - if header.failed(): - # Header could not be parsed - debug "Malformed header received", - address = transp.remoteAddress() - discard await transp.sendHTTPResponse(HttpVersion11, Http400) - await transp.closeWait() - return - var vres = await validateRequest(transp, header) - if vres == Success: - debug "Received valid HTTP request", address = $transp.remoteAddress() - # Call the user's callback. - if httpServer.callback != nil: - await httpServer.callback(transp, header) - elif vres == ErrorFailure: - debug "Remote peer disconnected", address = transp.remoteAddress() - except TransportLimitError: - # size of headers exceeds `MaxHttpHeadersSize` - debug "Maximum size of headers limit reached", - address = transp.remoteAddress() - discard await transp.sendHTTPResponse(HttpVersion11, Http413) - except TransportIncompleteError: - # remote peer disconnected - debug "Remote peer disconnected", address = transp.remoteAddress() - except TransportOsError as exc: - debug "Problems with networking", address = transp.remoteAddress(), - error = exc.msg - except CatchableError as exc: - debug "Unknown exception", address = transp.remoteAddress(), - error = exc.msg - await transp.closeWait() - -proc newHttpServer*(address: string, handler: AsyncCallback, - flags: set[ServerFlags] = {ReuseAddr}): HttpServer = - let address = initTAddress(address) - var server = HttpServer(callback: handler) - server = cast[HttpServer](createStreamServer(address, serveClient, flags, - child = cast[StreamServer](server))) - return server - -func newHttpHeaders*(): HttpHeaders = - ## Returns a new ``HttpHeaders`` object. if ``titleCase`` is set to true, - ## headers are passed to the server in title case (e.g. "Content-Length") - return HttpHeaders(table: newTable[string, seq[string]]()) - -func newHttpHeaders*(keyValuePairs: - openArray[tuple[key: string, val: string]]): HttpHeaders = - ## Returns a new ``HttpHeaders`` object from an array. if ``titleCase`` is set to true, - ## headers are passed to the server in title case (e.g. "Content-Length") - var headers = newHttpHeaders() - - for pair in keyValuePairs: - let key = toUpperAscii(pair.key) - if key in headers.table: - headers.table[key].add(pair.val) - else: - headers.table[key] = @[pair.val] - return headers - -proc newHttpClient*(headers = newHttpHeaders()): HttpClient = - return HttpClient(headers: headers) - -proc close*(client: HttpClient) = - ## Closes any connections held by the HTTP client. - if client.connected: - client.transp.close() - client.connected = false diff --git a/src/random.nim b/src/random.nim index 8b6906e..e3ea2d6 100644 --- a/src/random.nim +++ b/src/random.nim @@ -3,7 +3,6 @@ import bearssl ## Random helpers: similar as in stdlib, but with BrHmacDrbgContext rng const randMax = 18_446_744_073_709_551_615'u64 - proc rand*(rng: var BrHmacDrbgContext, max: Natural): int = if max == 0: return 0 var x: uint64 diff --git a/src/stream.nim b/src/stream.nim new file mode 100644 index 0000000..1d9ccbe --- /dev/null +++ b/src/stream.nim @@ -0,0 +1,53 @@ + import pkg/[chronos, + chronos/apps/http/httpserver, + chronos/timer, + chronicles, + httputils] +import strutils + +const + HttpHeadersTimeout = timer.seconds(120) # timeout for receiving headers (120 sec) + HeaderSep = @[byte('\c'), byte('\L'), byte('\c'), byte('\L')] + MaxHttpHeadersSize = 8192 # maximum size of HTTP headers in octets + +proc readHeaders*(rstream: AsyncStreamReader): Future[seq[byte]] {.async.} = + var buffer = newSeq[byte](MaxHttpHeadersSize) + var error = false + try: + let hlenfut = rstream.readUntil( + addr buffer[0], MaxHttpHeadersSize, + sep = HeaderSep) + let ores = await withTimeout(hlenfut, HttpHeadersTimeout) + if not ores: + # Timeout + debug "Timeout expired while receiving headers", + address = rstream.tsource.remoteAddress() + error = true + else: + let hlen = hlenfut.read() + buffer.setLen(hlen) + except AsyncStreamLimitError: + # size of headers exceeds `MaxHttpHeadersSize` + debug "Maximum size of headers limit reached", + address = rstream.tsource.remoteAddress() + error = true + except AsyncStreamIncompleteError: + # remote peer disconnected + debug "Remote peer disconnected", address = rstream.tsource.remoteAddress() + error = true + except AsyncStreamError as exc: + debug "Problems with networking", address = rstream.tsource.remoteAddress(), + error = exc.msg + error = true + + if error: + buffer.setLen(0) + return buffer + +proc closeWait*(wsStream : AsyncStream): Future[void] {.async.} = + if not wsStream.writer.tsource.closed(): + await wsStream.writer.tsource.closeWait() + if not wsStream.reader.tsource.closed(): + await wsStream.reader.tsource.closeWait() + +# TODO: Implement stream read and write wrapper. diff --git a/src/ws.nim b/src/ws.nim index 51e519f..120a797 100644 --- a/src/ws.nim +++ b/src/ws.nim @@ -4,16 +4,19 @@ import std/[tables, parseutils] import pkg/[chronos, + chronos/apps/http/httptable, + chronos/apps/http/httpserver, + chronos/streams/asyncstream, chronicles, httputils, stew/byteutils, stew/endians2, stew/base64, - eth/keys] + stew/base10, + eth/keys, + nimcrypto/sha] -import pkg/nimcrypto/sha - -import ./random, ./http +import ./random, ./stream #[ +---------------------------------------------------------------+ @@ -44,6 +47,7 @@ const WSDefaultFrameSize* = 1 shl 20 # 1mb WSMaxMessageSize* = 20 shl 20 # 20mb WSGuid* = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11" + CRLF* = "\r\n" type ReadyState* {.pure.} = enum @@ -131,8 +135,8 @@ type CloseResult {.gcsafe.} WebSocket* = ref object - tcpSocket*: StreamTransport - version*: int + stream*: AsyncStream + version*: uint key*: string protocol*: string readyState*: ReadyState @@ -147,6 +151,19 @@ type template remainder*(frame: Frame): uint64 = frame.length - frame.consumed +proc `$`(ht: HttpTables): string = + ## Returns string representation of HttpTable/Ref. + var res = "" + for key,value in ht.stringItems(true): + res.add(key.normalizeHeaderName()) + res.add(": ") + res.add(value) + res.add(CRLF) + + ## add for end of header mark + res.add(CRLF) + res + proc unmask*( data: var openArray[byte], maskKey: array[4, char], @@ -164,20 +181,26 @@ proc prepareCloseBody(code: Status, reason: string): seq[byte] = proc handshake*( ws: WebSocket, - header: HttpRequestHeader, - version = WSDefaultVersion) {.async.} = + request: HttpRequestRef, + version: uint = WSDefaultVersion) {.async.} = ## Handles the websocket handshake. ## + let + reqHeaders = request.headers + + ws.version = Base10.decode( + uint, + reqHeaders.getString("Sec-WebSocket-Version")) + .tryGet() # this method throws - discard parseSaturatedNatural(header["Sec-WebSocket-Version"], ws.version) if ws.version != version: raise newException(WSVersionError, "Websocket version not supported, Version: " & - header["Sec-WebSocket-Version"]) + reqHeaders.getString("Sec-WebSocket-Version")) - ws.key = header["Sec-WebSocket-Key"].strip() - if header.contains("Sec-WebSocket-Protocol"): - let wantProtocol = header["Sec-WebSocket-Protocol"].strip() + ws.key = reqHeaders.getString("Sec-WebSocket-Key").strip() + if reqHeaders.contains("Sec-WebSocket-Protocol"): + let wantProtocol = reqHeaders.getString("Sec-WebSocket-Protocol").strip() if ws.protocol != wantProtocol: raise newException(WSProtoMismatchError, "Protocol mismatch (expected: " & ws.protocol & ", got: " & @@ -186,23 +209,20 @@ proc handshake*( let cKey = ws.key & WSGuid let acceptKey = Base64Pad.encode(sha1.digest(cKey.toOpenArray(0, cKey.high)).data) - var response = "HTTP/1.1 101 Web Socket Protocol Handshake" & CRLF - response.add("Sec-WebSocket-Accept: " & acceptKey & CRLF) - response.add("Connection: Upgrade" & CRLF) - response.add("Upgrade: webSocket" & CRLF) - + var headerData = [("Connection", "Upgrade"),("Upgrade", "webSocket" ), + ("Sec-WebSocket-Accept", acceptKey)] + var headers = HttpTable.init(headerData) if ws.protocol != "": - response.add("Sec-WebSocket-Protocol: " & ws.protocol & CRLF) - response.add CRLF + headers.add("Sec-WebSocket-Protocol", ws.protocol) - let res = await ws.tcpSocket.write(response) - if res != len(response): - raise newException(WSSendError, "Failed to send handshake response to client") + try: + discard await request.respond(httputils.Http101, "", headers) + except CatchableError as exc: + raise newException(WSHandshakeError, "Failed to sent handshake response. Error: " & exc.msg) ws.readyState = ReadyState.Open proc createServer*( - header: HttpRequestHeader, - transp: StreamTransport, + request: HttpRequestRef, protocol: string = "", frameSize = WSDefaultFrameSize, onPing: ControlCb = nil, @@ -211,11 +231,15 @@ proc createServer*( ## Creates a new socket from a request. ## - if not header.contains("Sec-WebSocket-Version"): + if not request.headers.contains("Sec-WebSocket-Version"): raise newException(WSHandshakeError, "Missing version header") + let wsStream = AsyncStream( + reader: request.connection.reader, + writer: request.connection.writer) + var ws = WebSocket( - tcpSocket: transp, + stream: wsStream, protocol: protocol, masked: false, rng: newRng(), @@ -224,7 +248,7 @@ proc createServer*( onPong: onPong, onClose: onClose) - await ws.handshake(header) + await ws.handshake(request) return ws proc encodeFrame*(f: Frame): seq[byte] = @@ -302,7 +326,7 @@ proc send*( maskKey = genMaskKey(ws.rng) if opcode notin {Opcode.Text, Opcode.Cont, Opcode.Binary}: - discard await ws.tcpSocket.write(encodeFrame(Frame( + await ws.stream.writer.write(encodeFrame(Frame( fin: true, rsv1: false, rsv2: false, @@ -328,7 +352,7 @@ proc send*( data: data[i ..< len], maskKey: maskKey) - discard await ws.tcpSocket.write(encodeFrame(inFrame)) + await ws.stream.writer.write(encodeFrame(inFrame)) i += len proc send*(ws: WebSocket, data: string): Future[void] = @@ -347,7 +371,7 @@ proc handleClose*(ws: WebSocket, frame: Frame) {.async.} = var data = newSeq[byte](frame.length) if frame.length > 0: # Read the data. - await ws.tcpSocket.readExactly(addr data[0], int frame.length) + await ws.stream.reader.readExactly(addr data[0], int frame.length) unmask(data.toOpenArray(0, data.high), frame.maskKey) var code: Status @@ -363,13 +387,13 @@ proc handleClose*(ws: WebSocket, frame: Frame) {.async.} = try: (rcode, reason) = ws.onClose(code, string.fromBytes(data)) except CatchableError as exc: - debug "Exception in Close callback, this is most likelly a bug", exc = exc.msg + debug "Exception in Close callback, this is most likely a bug", exc = exc.msg - # don't respong to a terminated connection + # don't respond to a terminated connection if ws.readyState != ReadyState.Closing: await ws.send(prepareCloseBody(rcode, reason), Opcode.Close) - await ws.tcpSocket.closeWait() + await ws.stream.closeWait() ws.readyState = ReadyState.Closed else: raiseAssert("Invalid state during close!") @@ -405,9 +429,9 @@ proc handleControl*(ws: WebSocket, frame: Frame) {.async.} = else: raiseAssert("Invalid control opcode") except CatchableError as exc: - debug "Exception handling control messages", exc = exc.msg + trace "Exception handling control messages", exc = exc.msg ws.readyState = ReadyState.Closed - await ws.tcpSocket.closeWait() + await ws.stream.closeWait() proc readFrame*(ws: WebSocket): Future[Frame] {.async.} = ## Gets a frame from the WebSocket. @@ -418,8 +442,7 @@ proc readFrame*(ws: WebSocket): Future[Frame] {.async.} = while ws.readyState != ReadyState.Closed: # read until a data frame arrives # Grab the header. var header = newSeq[byte](2) - await ws.tcpSocket.readExactly(addr header[0], 2) - + await ws.stream.reader.readExactly(addr header[0], 2) if header.len != 2: debug "Invalid websocket header length" raise newException(WSMalformedHeaderError, "Invalid websocket header length") @@ -453,12 +476,12 @@ proc readFrame*(ws: WebSocket): Future[Frame] {.async.} = if headerLen == 0x7e: # Length must be 7+16 bits. var length = newSeq[byte](2) - await ws.tcpSocket.readExactly(addr length[0], 2) + await ws.stream.reader.readExactly(addr length[0], 2) finalLen = uint16.fromBytesBE(length) elif headerLen == 0x7f: # Length must be 7+64 bits. var length = newSeq[byte](8) - await ws.tcpSocket.readExactly(addr length[0], 8) + await ws.stream.reader.readExactly(addr length[0], 8) finalLen = uint64.fromBytesBE(length) else: # Length must be 7 bits. @@ -475,7 +498,7 @@ proc readFrame*(ws: WebSocket): Future[Frame] {.async.} = var maskKey = newSeq[byte](4) if frame.mask: # Read the mask. - await ws.tcpSocket.readExactly(addr maskKey[0], 4) + await ws.stream.reader.readExactly(addr maskKey[0], 4) for i in 0..