From a1ae7d2c708daa486d4debd4e8457e05ad8b5d8b Mon Sep 17 00:00:00 2001 From: Arijit Das Date: Thu, 11 Mar 2021 09:04:14 +0530 Subject: [PATCH] [WIP] Web socket client implementation. (#2) * Implement websocket server. * Implement websocket client. * Run nimpretty. * Remove commented code. * Address comments. * Address comments on websocket server. * Use seq[byte] to store data. * Working bytes conversion. * Remove result from return * Refactor the code. * Minor change. * Add test. * Add websocket test and fix closing handshake. * Add MsgReader to read data in external buffer. --- examples/client.nim | 21 +++ examples/server.nim | 41 ++++ src/http.nim | 250 ++++++++++++++++++++++++ src/random.nim | 25 +++ src/ws.nim | 451 ++++++++++++++++++++++++++++++++++++++++++++ test/client.nim | 7 - test/config.nims | 1 - test/server.nim | 20 -- tests/frame.nim | 76 ++++++++ tests/helpers.nim | 55 ++++++ tests/websocket.nim | 87 +++++++++ ws.nimble | 9 +- 12 files changed, 1012 insertions(+), 31 deletions(-) create mode 100644 examples/client.nim create mode 100644 examples/server.nim create mode 100644 src/http.nim create mode 100644 src/random.nim create mode 100644 src/ws.nim delete mode 100644 test/client.nim delete mode 100644 test/config.nims delete mode 100644 test/server.nim create mode 100644 tests/frame.nim create mode 100644 tests/helpers.nim create mode 100644 tests/websocket.nim diff --git a/examples/client.nim b/examples/client.nim new file mode 100644 index 0000000000..730fee558f --- /dev/null +++ b/examples/client.nim @@ -0,0 +1,21 @@ +import ../src/ws, nativesockets, chronos, os, chronicles, stew/byteutils + +let wsClient = waitFor newWebsocketClient("127.0.0.1", Port(8888), path = "/ws", + protocols = @["myfancyprotocol"]) +info "Websocket client: ", State = wsClient.readyState + +let reqData = "Hello Server" +for idx in 1 .. 5: + try: + waitFor wsClient.sendStr(reqData) + let recvData = waitFor wsClient.receiveStrPacket() + let dataStr = string.fromBytes(recvData) + info "Server:", data = dataStr + assert dataStr == reqData + except WebSocketError: + error "WebSocket error:", exception = getCurrentExceptionMsg() + os.sleep(1000) + +# close the websocket +waitFor wsClient.close() + diff --git a/examples/server.nim b/examples/server.nim new file mode 100644 index 0000000000..d9ba10a652 --- /dev/null +++ b/examples/server.nim @@ -0,0 +1,41 @@ +import ../src/ws, ../src/http, chronos, chronicles, httputils, stew/byteutils + +proc cb(transp: StreamTransport, header: HttpRequestHeader) {.async.} = + info "Handling request:", uri = header.uri() + if header.uri() == "/ws": + info "Initiating web socket connection." + try: + var ws = await newWebSocket(header, transp, "myfancyprotocol") + if ws.readyState == Open: + info "Websocket handshake completed." + else: + error "Failed to open websocket connection." + return + + while true: + # Only reads header for data frame. + let msgReader = await ws.nextMessageReader() + + # Read the frame payload in buffer. + let buffer = newSeq[byte](100) + var recvData :seq[byte] + while msgReader.error != EOFError: + msgReader.readMessage(buffer) + recvData.add buffer + if ws.readyState == ReadyState.Closed: + return + info "Response: ", data = recvData + await ws.send(recvData) + + except WebSocketError: + error "WebSocket error:", exception = getCurrentExceptionMsg() + + discard await transp.sendHTTPResponse(HttpVersion11, Http200, "Hello World") + await transp.closeWait() + +when isMainModule: + let address = "127.0.0.1:8888" + var httpServer = newHttpServer(address, cb) + httpServer.start() + echo "Server started..." + waitFor httpServer.join() diff --git a/src/http.nim b/src/http.nim new file mode 100644 index 0000000000..f16cd29f57 --- /dev/null +++ b/src/http.nim @@ -0,0 +1,250 @@ +import chronos, chronos/timer, httputils, chronicles, uri, tables, strutils + +const + MaxHttpHeadersSize = 8192 # maximum size of HTTP headers in octets + MaxHttpRequestSize = 128 * 1024 # maximum size of HTTP body in octets + HttpHeadersTimeout = timer.seconds(120) # timeout for receiving headers (120 sec) + CRLF* = "\r\n" + HeaderSep = @[byte('\c'), byte('\L'), byte('\c'), byte('\L')] + +type + HttpClient* = ref object + connected: bool + currentURL: Uri ## Where we are currently connected. + headers: HttpHeaders ## Headers to send in requests. + transp*: StreamTransport + buf: seq[byte] + + HttpHeaders* = object + table*: TableRef[string, seq[string]] + + ReqStatus = enum + Success, Error, ErrorFailure + + AsyncCallback = proc (transp: StreamTransport, + header: HttpRequestHeader): Future[void] {.closure, gcsafe.} + HttpServer* = ref object of StreamServer + callback: AsyncCallback + +proc recvData(transp: StreamTransport): Future[seq[byte]] {.async.} = + var buffer = newSeq[byte](MaxHttpHeadersSize) + var error = false + try: + let hlenfut = transp.readUntil(addr buffer[0], MaxHttpHeadersSize, + sep = HeaderSep) + let ores = await withTimeout(hlenfut, HttpHeadersTimeout) + if not ores: + # Timeout + debug "Timeout expired while receiving headers", + address = transp.remoteAddress() + error = true + else: + let hlen = hlenfut.read() + buffer.setLen(hlen) + except TransportLimitError: + # size of headers exceeds `MaxHttpHeadersSize` + debug "Maximum size of headers limit reached", + address = transp.remoteAddress() + error = true + except TransportIncompleteError: + # remote peer disconnected + debug "Remote peer disconnected", address = transp.remoteAddress() + error = true + except TransportOsError as exc: + debug "Problems with networking", address = transp.remoteAddress(), + error = exc.msg + error = true + + if error: + buffer.setLen(0) + return buffer + +proc newConnection(client: HttpClient, url: Uri) {.async.} = + if client.connected: + return + + let port = + if url.port == "": 80 + else: url.port.parseInt + + client.transp = await connect(initTAddress(url.hostname, port)) + + # May be connected through proxy but remember actual URL being accessed + client.currentURL = url + client.connected = true + +proc generateHeaders(requestUrl: Uri, httpMethod: string, + additionalHeaders: HttpHeaders): string = + # GET + var headers = httpMethod.toUpperAscii() + headers.add ' ' + + if not requestUrl.path.startsWith("/"): headers.add '/' + headers.add(requestUrl.path) + + # HTTP/1.1\c\l + headers.add(" HTTP/1.1" & CRLF) + + for key, val in additionalHeaders.table: + headers.add(key & ": " & val.join(", ") & CRLF) + headers.add(CRLF) + return headers + +# Send request to the client. Currently only supports HTTP get method. +proc request*(client: HttpClient, url, httpMethod: string, + body = "", headers: HttpHeaders): Future[seq[byte]] + {.async.} = + # Helper that actually makes the request. Does not handle redirects. + let requestUrl = parseUri(url) + if requestUrl.scheme == "": + raise newException(ValueError, "No uri scheme supplied.") + + await newConnection(client, requestUrl) + + let headerString = generateHeaders(requestUrl, httpMethod, headers) + let res = await client.transp.write(headerString) + if res != len(headerString): + raise newException(ValueError, "Error while send request to client") + + var value = await client.transp.recvData() + if value.len == 0: + raise newException(ValueError, "Empty response from server") + return value + +proc sendHTTPResponse*(transp: StreamTransport, version: HttpVersion, code: HttpCode, + data: string = ""): Future[bool] {.async.} = + var answer = $version + answer.add(" ") + answer.add($code) + answer.add(CRLF) + answer.add("Date: " & httpDate() & CRLF) + if len(data) > 0: + answer.add("Content-Type: application/json" & CRLF) + answer.add("Content-Length: " & $len(data) & CRLF) + answer.add(CRLF) + if len(data) > 0: + answer.add(data) + + let res = await transp.write(answer) + if res == len(answer): + return true + raise newException(IOError, "Failed to send http request.") + +proc validateRequest(transp: StreamTransport, + header: HttpRequestHeader): Future[ReqStatus] {.async.} = + if header.meth notin {MethodGet}: + debug "GET method is only allowed", address = transp.remoteAddress() + if await transp.sendHTTPResponse(header.version, Http405): + return Error + else: + return ErrorFailure + + var hlen = header.contentLength() + if hlen < 0 or hlen > MaxHttpRequestSize: + debug "Invalid header length", address = transp.remoteAddress() + if await transp.sendHTTPResponse(header.version, Http413): + return Error + else: + return ErrorFailure + + return Success + +proc serveClient(server: StreamServer, transp: StreamTransport) {.async.} = + ## Process transport data to the RPC server + var httpServer = cast[HttpServer](server) + var buffer = newSeq[byte](MaxHttpHeadersSize) + var header: HttpRequestHeader + + info "Received connection", address = $transp.remoteAddress() + try: + let hlenfut = transp.readUntil(addr buffer[0], MaxHttpHeadersSize, + sep = HeaderSep) + let ores = await withTimeout(hlenfut, HttpHeadersTimeout) + if not ores: + # Timeout + debug "Timeout expired while receiving headers", + address = transp.remoteAddress() + discard await transp.sendHTTPResponse(HttpVersion11, Http408) + await transp.closeWait() + return + else: + let hlen = hlenfut.read() + buffer.setLen(hlen) + header = buffer.parseRequest() + if header.failed(): + # Header could not be parsed + debug "Malformed header received", + address = transp.remoteAddress() + discard await transp.sendHTTPResponse(HttpVersion11, Http400) + await transp.closeWait() + return + var vres = await validateRequest(transp, header) + if vres == Success: + info "Received valid RPC request", address = $transp.remoteAddress() + # Call the user's callback. + if httpServer.callback != nil: + await httpServer.callback(transp, header) + elif vres == ErrorFailure: + debug "Remote peer disconnected", address = transp.remoteAddress() + except TransportLimitError: + # size of headers exceeds `MaxHttpHeadersSize` + debug "Maximum size of headers limit reached", + address = transp.remoteAddress() + discard await transp.sendHTTPResponse(HttpVersion11, Http413) + except TransportIncompleteError: + # remote peer disconnected + debug "Remote peer disconnected", address = transp.remoteAddress() + except TransportOsError as exc: + debug "Problems with networking", address = transp.remoteAddress(), + error = exc.msg + except CatchableError as exc: + debug "Unknown exception", address = transp.remoteAddress(), + error = exc.msg + await transp.closeWait() + +proc newHttpServer*(address: string, handler: AsyncCallback, + flags: set[ServerFlags] = {ReuseAddr}): HttpServer = + let address = initTAddress(address) + var server = HttpServer(callback: handler) + server = cast[HttpServer](createStreamServer(address, serveClient, flags, + child = cast[StreamServer](server))) + return server + +func toTitleCase(s: string): string = + var tcstr = newString(len(s)) + var upper = true + for i in 0..len(s) - 1: + tcstr[i] = if upper: toUpperAscii(s[i]) else: toLowerAscii(s[i]) + upper = s[i] == '-' + return tcstr + +func toCaseInsensitive*(headers: HttpHeaders, s: string): string {.inline.} = + return toTitleCase(s) + +func newHttpHeaders*(): HttpHeaders = + ## Returns a new ``HttpHeaders`` object. if ``titleCase`` is set to true, + ## headers are passed to the server in title case (e.g. "Content-Length") + return HttpHeaders(table: newTable[string, seq[string]]()) + +func newHttpHeaders*(keyValuePairs: + openArray[tuple[key: string, val: string]]): HttpHeaders = + ## Returns a new ``HttpHeaders`` object from an array. if ``titleCase`` is set to true, + ## headers are passed to the server in title case (e.g. "Content-Length") + var headers = newHttpHeaders() + + for pair in keyValuePairs: + let key = headers.toCaseInsensitive(pair.key) + if key in headers.table: + headers.table[key].add(pair.val) + else: + headers.table[key] = @[pair.val] + return headers + +proc newHttpClient*(headers = newHttpHeaders()): HttpClient = + return HttpClient(headers: headers) + +proc close*(client: HttpClient) = + ## Closes any connections held by the HTTP client. + if client.connected: + client.transp.close() + client.connected = false diff --git a/src/random.nim b/src/random.nim new file mode 100644 index 0000000000..f61ec97caf --- /dev/null +++ b/src/random.nim @@ -0,0 +1,25 @@ +import bearssl + +## Random helpers: similar as in stdlib, but with BrHmacDrbgContext rng +const randMax = 18_446_744_073_709_551_615'u64 + +proc rand*(rng: var BrHmacDrbgContext, max: Natural): int = + if max == 0: return 0 + var x: uint64 + while true: + brHmacDrbgGenerate(addr rng, addr x, csize_t(sizeof(x))) + if x < randMax - (randMax mod (uint64(max) + 1'u64)): # against modulo bias + return int(x mod (uint64(max) + 1'u64)) + +proc genMaskKey*(rng: ref BrHmacDrbgContext): array[4, char] = + ## Generates a random key of 4 random chars. + proc r(): char = char(rand(rng[], 255)) + return [r(), r(), r(), r()] + +proc genWebSecKey*(rng: ref BrHmacDrbgContext): seq[char] = + var key = newSeq[char](16) + proc r(): char = char(rand(rng[], 255)) + ## Generates a random key of 16 random chars. + for i in 0..15: + key.add(r()) + return key diff --git a/src/ws.nim b/src/ws.nim new file mode 100644 index 0000000000..08a856bb9c --- /dev/null +++ b/src/ws.nim @@ -0,0 +1,451 @@ +import httputils, strutils, base64, std/sha1, ./random, http, uri, + chronos/timer, tables, stew/byteutils, eth/[keys], stew/endians2, + parseutils, stew/base64 as stewBase,chronos + +const + SHA1DigestSize = 20 + WSHeaderSize = 12 + WSOpCode = {0x00, 0x01, 0x02, 0x08, 0x09, 0x0a} + +type + ReadyState* = enum + Connecting = 0 # The connection is not yet open. + Open = 1 # The connection is open and ready to communicate. + Closing = 2 # The connection is in the process of closing. + Closed = 3 # The connection is closed or couldn't be opened. + + WebSocket* = ref object + tcpSocket*: StreamTransport + version*: int + key*: string + protocol*: string + readyState*: ReadyState + masked*: bool # send masked packets + rng*: ref BrHmacDrbgContext + + WebSocketError* = object of IOError + + Base16Error* = object of CatchableError + ## Base16 specific exception type + + HeaderFlag* {.size: sizeof(uint8).} = enum + rsv3 + rsv2 + rsv1 + fin + HeaderFlags = set[HeaderFlag] + + HttpCode* = enum + Http101 = 101 # Switching Protocols + +# Forward declare +proc close*(ws: WebSocket, initiator: bool = true) {.async.} + +proc handshake*(ws: WebSocket, header: HttpRequestHeader) {.async.} = + ## Handles the websocket handshake. + discard parseSaturatedNatural(header["Sec-WebSocket-Version"], ws.version) + if ws.version != 13: + raise newException(WebSocketError, "Websocket version not supported, Version: " & + header["Sec-WebSocket-Version"]) + + ws.key = header["Sec-WebSocket-Key"].strip() + if header.contains("Sec-WebSocket-Protocol"): + let wantProtocol = header["Sec-WebSocket-Protocol"].strip() + if ws.protocol != wantProtocol: + raise newException(WebSocketError, + "Protocol mismatch (expected: " & ws.protocol & ", got: " & + wantProtocol & ")") + + var acceptKey: string + try: + let sh = secureHash(ws.key & "258EAFA5-E914-47DA-95CA-C5AB0DC85B11") + acceptKey = stewBase.Base64.encode(hexToByteArray[SHA1DigestSize]($sh)) + except ValueError: + raise newException( + WebSocketError, "Failed to generate accept key: " & getCurrentExceptionMsg()) + + var response = "HTTP/1.1 101 Web Socket Protocol Handshake" & CRLF + response.add("Sec-WebSocket-Accept: " & acceptKey & CRLF) + response.add("Connection: Upgrade" & CRLF) + response.add("Upgrade: webSocket" & CRLF) + + if ws.protocol != "": + response.add("Sec-WebSocket-Protocol: " & ws.protocol & CRLF) + response.add CRLF + + let res = await ws.tcpSocket.write(response) + if res != len(response): + raise newException(WebSocketError, "Failed to send handshake response to client") + ws.readyState = Open + +proc newWebSocket*(header: HttpRequestHeader, transp: StreamTransport, + protocol: string = ""): Future[WebSocket] {.async.} = + ## Creates a new socket from a request. + try: + if not header.contains("Sec-WebSocket-Version"): + raise newException(WebSocketError, "Invalid WebSocket handshake") + var ws = WebSocket(tcpSocket: transp, protocol: protocol, masked: false, + rng: newRng()) + await ws.handshake(header) + return ws + except ValueError, KeyError: + # Wrap all exceptions in a WebSocketError so its easy to catch. + raise newException( + WebSocketError, + "Failed to create WebSocket from request: " & getCurrentExceptionMsg() + ) + +type + Opcode* = enum + ## 4 bits. Defines the interpretation of the "Payload data". + Cont = 0x0 ## Denotes a continuation frame. + Text = 0x1 ## Denotes a text frame. + Binary = 0x2 ## Denotes a binary frame. + # 3-7 are reserved for further non-control frames. + Close = 0x8 ## Denotes a connection close. + Ping = 0x9 ## Denotes a ping. + Pong = 0xa ## Denotes a pong. + # B-F are reserved for further control frames. + + #[ + +---------------------------------------------------------------+ + |0 1 2 3 | + |0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1| + +-+-+-+-+-------+-+-------------+-------------------------------+ + |F|R|R|R| opcode|M| Payload len | Extended payload length | + |I|S|S|S| (4) |A| (7) | (16/64) | + |N|V|V|V| |S| | (if payload len==126/127) | + | |1|2|3| |K| | | + +-+-+-+-+-------+-+-------------+ - - - - - - - - - - - - - - - + + | Extended payload length continued, if payload len == 127 | + + - - - - - - - - - - - - - - - +-------------------------------+ + | |Masking-key, if MASK set to 1 | + +-------------------------------+-------------------------------+ + | Masking-key (continued) | Payload Data | + +-------------------------------- - - - - - - - - - - - - - - - + + : Payload Data continued ... : + + - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - + + | Payload Data continued ... | + +---------------------------------------------------------------+ + ]# + + MsgReader = ref object + tcpSocket: StreamTransport + readErr: IOError + readLen: uint64 + readRemaining: uint64 + readFinal: bool ## true the current message has more frames. + opcode: Opcode ## Defines the interpretation of the "Payload data". + maskKey: array[4, char] ## Masking key + mask: bool ## Defines whether the "Payload data" is masked. + + Frame = ref object + fin: bool ## Indicates that this is the final fragment in a message. + rsv1: bool ## MUST be 0 unless negotiated that defines meanings + rsv2: bool ## MUST be 0 + rsv3: bool ## MUST be 0 + opcode: Opcode ## Defines the interpretation of the "Payload data". + mask: bool ## Defines whether the "Payload data" is masked. + data: seq[byte] ## Payload data + maskKey: array[4, char] ## Masking key + length: uint64 ## Message size. + +proc encodeFrame(f: Frame): seq[byte] = + ## Encodes a frame into a string buffer. + ## See https://tools.ietf.org/html/rfc6455#section-5.2 + + var ret = newSeqOfCap[byte](f.data.len + WSHeaderSize) + + var b0 = (f.opcode.uint8 and 0x0f) # 0th byte: opcodes and flags. + if f.fin: + b0 = b0 or 128u8 + + ret.add(b0) + + # Payload length can be 7 bits, 7+16 bits, or 7+64 bits. + # 1st byte: payload len start and mask bit. + var b1 = 0u8 + + if f.data.len <= 125: + b1 = f.data.len.uint8 + elif f.data.len > 125 and f.data.len <= 0xffff: + b1 = 126u8 + else: + b1 = 127u8 + + if f.mask: + b1 = b1 or (1 shl 7) + + ret.add(uint8 b1) + + # Only need more bytes if data len is 7+16 bits, or 7+64 bits. + if f.data.len > 125 and f.data.len <= 0xffff: + # Data len is 7+16 bits. + var len = f.data.len.uint16 + ret.add ((len shr 8) and 255).uint8 + ret.add (len and 255).uint8 + elif f.data.len > 0xffff: + # Data len is 7+64 bits. + var len = f.data.len + ret.add(f.data.len.uint64.toBE().toBytesBE()) + + var data = f.data + + if f.mask: + # If we need to mask it generate random mask key and mask the data. + for i in 0.. 0 : + # Read the data. + await ws.tcpSocket.readExactly(addr data[0], int finalLen) + frame.data = data + + # Process control frame payload. + if frame.opcode == Ping: + await ws.send(data, Pong) + elif frame.opcode == Pong: + discard + elif frame.opcode == Close: + await ws.close(false) + + return frame + +proc close*(ws: WebSocket, initiator: bool = true) {.async.} = + ## Close the Socket, sends close packet. + if ws.readyState == Closed: + discard ws.tcpSocket.closeWait() + return + ws.readyState = Closed + await ws.send(@[], Close) + if initiator == true: + let frame = await ws.readFrame() + if frame.opcode != Close: + echo "Different packet type" + await ws.close() + +proc readMessage*(msgReader: MsgReader,data: seq[byte]): MsgReader {.async.} = + while msgReader.readErr == nil: + if msgReader.readRemaining > 0 : + len = size(data) + if len > msgReader.readRemaining: + len = msgReader.readRemaining + + await msgReader.tcpSocket.readExactly(addr data, len) + msgReader.readRemaining = msgReader.readRemaining - len + msgReader.readLen = len + + if msgReader.mask: + # Apply mask, if we need too. + for i in 0 ..< len: + data[i] = (data[i].uint8 xor msgReader.maskKey[i mod 4].uint8) + + if msgReader.readRemaining == 0: + msgReader.readErr = EOFError + + return msgReader + + if msgReader.readFinal: + msgReader.readLen = 0 + msgReader.readErr = EOFError + return msgReader + + var frame = await ws.readFrame() + if frame.fin: + msgReader.readFinal = true + msgReader.readRemaining = frame.length + + # Non-control frames cannot occur in the middle of a fragmented non-control frame. + if frame.Opcode in Text || Binary: + raise newException("websocket: internal error, unexpected text or binary in Reader") + return msgReader + +proc nextMessageReader*(ws: WebSocket): MsgReader = + while true: + # Handle control frames and return only on non control frames. + var frame = await ws.readFrame() + if frame.Opcode in Text || Binary: + var msgReader: MsgReader + msgReader.readFinal = frame.fin + msgReader.readRemaining = frame.readRemaining + msgReader.tcpSocket = ws.tcpSocket + msgReader.mask = frame.mask + msgReader.maskKey = frame.maskKey + return msgReader + +proc receiveStrPacket*(ws: WebSocket): Future[seq[byte]] {.async.} = + # TODO: remove this once PR is approved. + return nil + +proc validateWSClientHandshake*(transp: StreamTransport, + header: HttpResponseHeader): void = + if header.code != ord(Http101): + raise newException(WebSocketError, "Server did not reply with a websocket upgrade: " & + "Header code: " & $header.code & + "Header reason: " & header.reason() & + "Address: " & $transp.remoteAddress()) + +proc newWebsocketClient*(uri: Uri, protocols: seq[string] = @[]): Future[ + WebSocket] {.async.} = + var key = encode(genWebSecKey(newRng())) + var uri = uri + case uri.scheme + of "ws": + uri.scheme = "http" + else: + raise newException(WebSocketError, "uri scheme has to be 'ws'") + + var headers = newHttpHeaders({ + "Connection": "Upgrade", + "Upgrade": "websocket", + "Cache-Control": "no-cache", + "Sec-WebSocket-Version": "13", + "Sec-WebSocket-Key": key + }) + if protocols.len != 0: + headers.table["Sec-WebSocket-Protocol"] = @[protocols.join(", ")] + + let client = newHttpClient(headers) + var response = await client.request($uri, "GET", headers = headers) + var header = response.parseResponse() + if header.failed(): + # Header could not be parsed + raise newException(WebSocketError, "Malformed header received: " & + $client.transp.remoteAddress()) + client.transp.validateWSClientHandshake(header) + + # Client data should be masked. + return WebSocket(tcpSocket: client.transp, readyState: Open, masked: true, + rng: newRng()) + +proc newWebsocketClient*(host: string, port: Port, path: string, + protocols: seq[string] = @[]): Future[WebSocket] {.async.} = + var uri = "ws://" & host & ":" & $port + if path.startsWith("/"): + uri.add path + else: + uri.add "/" & path + return await newWebsocketClient(parseUri(uri), protocols) diff --git a/test/client.nim b/test/client.nim deleted file mode 100644 index 358b63370b..0000000000 --- a/test/client.nim +++ /dev/null @@ -1,7 +0,0 @@ -import ws, nativesockets, chronos - -discard waitFor newAsyncWebsocketClient("localhost", Port(8080), path = "/", protocols = @["myfancyprotocol"]) -echo "connected" - -runForever() - diff --git a/test/config.nims b/test/config.nims deleted file mode 100644 index 80091ff6c7..0000000000 --- a/test/config.nims +++ /dev/null @@ -1 +0,0 @@ -switch("path", "$projectDir/../src") diff --git a/test/server.nim b/test/server.nim deleted file mode 100644 index f89c23e5f0..0000000000 --- a/test/server.nim +++ /dev/null @@ -1,20 +0,0 @@ -import ws, chronos, chronicles, httputils - -proc cb(transp: StreamTransport, header: HttpRequestHeader) {.async.} = - info "Header: ", uri = header.uri() - if header.uri() == "/ws": - info "Initiating web socket connection." - try: - var ws = await newWebSocket(header, transp) - echo await ws.receivePacket() - info "Websocket handshake completed." - except WebSocketError: - echo "socket closed:", getCurrentExceptionMsg() - - let res = await transp.sendHTTPResponse(HttpVersion11, Http200, "Hello World") - -when isMainModule: - let address = "127.0.0.1:8888" - var httpServer = newHttpServer(address, cb) - httpServer.start() - waitFor httpServer.join() diff --git a/tests/frame.nim b/tests/frame.nim new file mode 100644 index 0000000000..5f9f1f814f --- /dev/null +++ b/tests/frame.nim @@ -0,0 +1,76 @@ +include ../src/ws +include ../src/http +include ../src/random +#import chronos, chronicles, httputils, strutils, base64, std/sha1, +# streams, nativesockets, uri, times, chronos/timer, tables + +import unittest + +# TODO: Fix Test. + +var maskKey: array[4, char] + +suite "tests for encodeFrame()": + test "# 7bit length": + block: # 7bit length + assert encodeFrame(( + fin: true, + rsv1: false, + rsv2: false, + rsv3: false, + opcode: Opcode.Text, + mask: false, + data: toBytes("hi there"), + maskKey: maskKey + )) == toBytes("\129\8hi there") + test "# 7bit length": + block: # 7+16 bits length + var data = "" + for i in 0..32: + data.add "How are you this is the payload!!!" + assert encodeFrame(( + fin: true, + rsv1: false, + rsv2: false, + rsv3: false, + opcode: Opcode.Text, + mask: false, + data: toBytes(data), + maskKey: maskKey + ))[0..32] == toBytes("\129~\4bHow are you this is the paylo") + test "# 7+64 bits length": + block: # 7+64 bits length + var data = "" + for i in 0..3200: + data.add "How are you this is the payload!!!" + assert encodeFrame(( + fin: true, + rsv1: false, + rsv2: false, + rsv3: false, + opcode: Opcode.Text, + mask: false, + data: toBytes(data), + maskKey: maskKey + ))[0..32] == toBytes("\129\127\0\0\0\0\0\1\169\"How are you this is the") + test "# masking": + block: # masking + let data = encodeFrame(( + fin: true, + rsv1: false, + rsv2: false, + rsv3: false, + opcode: Opcode.Text, + mask: true, + data: toBytes("hi there"), + maskKey: ['\xCF', '\xD8', '\x05', 'e'] + )) + assert data == toBytes("\129\136\207\216\5e\167\177%\17\167\189w\0") + +suite "tests for toTitleCase()": + block: + let val = toTitleCase("webSocket") + assert val == "Websocket" + + + diff --git a/tests/helpers.nim b/tests/helpers.nim new file mode 100644 index 0000000000..a43a4df0e0 --- /dev/null +++ b/tests/helpers.nim @@ -0,0 +1,55 @@ +import ../src/ws, chronos, chronicles, httputils, stew/byteutils, + ../src/http, unittest, strutils + +proc cb*(transp: StreamTransport, header: HttpRequestHeader) {.async.} = + info "Handling request:", uri = header.uri() + if header.uri() == "/ws": + info "Initiating web socket connection." + try: + var ws = await newWebSocket(header, transp, "myfancyprotocol") + if ws.readyState == Open: + info "Websocket handshake completed." + else: + error "Failed to open websocket connection." + return + + while ws.readyState == Open: + let recvData = await ws.receiveStrPacket() + info "Server:", state = ws.readyState + await ws.send(recvData) + except WebSocketError: + error "WebSocket error:", exception = getCurrentExceptionMsg() + discard await transp.sendHTTPResponse(HttpVersion11, Http200, "Connection established") + +proc sendRecvClientData*(wsClient: WebSocket, msg: string) {.async.} = + try: + waitFor wsClient.sendStr(msg) + let recvData = waitFor wsClient.receiveStrPacket() + info "Websocket client state: ", state = wsClient.readyState + let dataStr = string.fromBytes(recvData) + require dataStr == msg + + except WebSocketError: + error "WebSocket error:", exception = getCurrentExceptionMsg() + +proc incorrectProtocolCB*(transp: StreamTransport, header: HttpRequestHeader) {.async.} = + info "Handling request:", uri = header.uri() + var isErr = false; + if header.uri() == "/ws": + info "Initiating web socket connection." + try: + var ws = await newWebSocket(header, transp, "myfancyprotocol") + require ws.readyState == ReadyState.Closed + except WebSocketError: + isErr = true; + require contains(getCurrentExceptionMsg(), "Protocol mismatch") + finally: + require isErr == true + discard await transp.sendHTTPResponse(HttpVersion11, Http200, "Connection established") + + +proc generateData*(num: int64): seq[byte] = + var str = newSeqOfCap[byte](num) + for i in 0 ..< num: + str.add(65) + return str diff --git a/tests/websocket.nim b/tests/websocket.nim new file mode 100644 index 0000000000..6abdbfcf54 --- /dev/null +++ b/tests/websocket.nim @@ -0,0 +1,87 @@ +import helpers, unittest, ../src/http, chronos, ../src/ws,../src/random, + stew/byteutils, os, strutils + +var httpServer: HttpServer + +proc startServer() {.async, gcsafe.} = + httpServer = newHttpServer("127.0.0.1:8888", cb) + httpServer.start() + +proc closeServer() {.async, gcsafe.} = + httpServer.stop() + waitFor httpServer.closeWait() + +suite "Test websocket error cases": + teardown: + httpServer.stop() + waitFor httpServer.closeWait() + + test "Test for incorrect protocol": + httpServer = newHttpServer("127.0.0.1:8888", incorrectProtocolCB) + httpServer.start() + try: + let wsClient = waitFor newWebsocketClient("127.0.0.1", Port(8888), + path = "/ws", protocols = @["mywrongprotocol"]) + except WebSocketError: + require contains(getCurrentExceptionMsg(), "Server did not reply with a websocket upgrade") + + test "Test for incorrect port": + httpServer = newHttpServer("127.0.0.1:8888", cb) + httpServer.start() + try: + let wsClient = waitFor newWebsocketClient("127.0.0.1", Port(8889), + path = "/ws", protocols = @["myfancyprotocol"]) + except: + require contains(getCurrentExceptionMsg(), "Connection refused") + + test "Test for incorrect path": + httpServer = newHttpServer("127.0.0.1:8888", cb) + httpServer.start() + try: + let wsClient = waitFor newWebsocketClient("127.0.0.1", Port(8888), + path = "/gg", protocols = @["myfancyprotocol"]) + except: + require contains(getCurrentExceptionMsg(), "Server did not reply with a websocket upgrade") + +suite "Misc Test": + setup: + waitFor startServer() + teardown: + waitFor closeServer() + + test "Test for maskKey": + let wsClient = waitFor newWebsocketClient("127.0.0.1", Port(8888), path = "/ws", + protocols = @["myfancyprotocol"]) + let maskKey = genMaskKey(wsClient.rng) + require maskKey.len == 4 + + test "Test for toCaseInsensitive": + let headers = newHttpHeaders() + require toCaseInsensitive(headers, "webSocket") == "Websocket" + + +suite "Test web socket communication": + + setup: + waitFor startServer() + let wsClient = waitFor newWebsocketClient("127.0.0.1", Port(8888), + path = "/ws", protocols = @["myfancyprotocol"]) + + teardown: + waitFor closeServer() + + test "Websocket conversation between client and server": + waitFor sendRecvClientData(wsClient, "Hello Server") + + test "Test for small message ": + let msg = string.fromBytes(generateData(100)) + waitFor sendRecvClientData(wsClient, msg) + + test "Test for medium message ": + let msg = string.fromBytes(generateData(1000)) + waitFor sendRecvClientData(wsClient, msg) + + test "Test for large message ": + let msg = string.fromBytes(generateData(10000)) + waitFor sendRecvClientData(wsClient, msg) + diff --git a/ws.nimble b/ws.nimble index 0777e4658c..86f8ad7448 100644 --- a/ws.nimble +++ b/ws.nimble @@ -4,12 +4,15 @@ author = "Status Research & Development GmbH" description = "WS protocol implementation" license = "MIT" -requires "nim >= 1.2.6" -requires "chronos >= 2.5.2 & < 3.0.0" +requires "nim == 1.2.6" +requires "chronos >= 2.5.2" requires "httputils >= 0.2.0" requires "chronicles >= 0.10.0" requires "urlly >= 0.2.0" -requires "uri" +requires "stew >= 0.1.0" +requires "eth" +requires "asynctest >= 0.2.0 & < 0.3.0" +requires "nimcrypto" task lint, "format source files according to the official style guide": exec "./lint.nims"