diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml new file mode 100644 index 0000000..899c1ee --- /dev/null +++ b/.github/workflows/ci.yml @@ -0,0 +1,202 @@ +name: nim-ws CI +on: [push, pull_request] + +jobs: + build: + strategy: + fail-fast: false + max-parallel: 20 + matrix: + branch: [v1.2.6] + target: + # Unit tests + - os: linux + cpu: amd64 + TEST_KIND: unit-tests + - os: linux + cpu: i386 + TEST_KIND: unit-tests + - os: macos + cpu: amd64 + TEST_KIND: unit-tests + - os: windows + cpu: i386 + TEST_KIND: unit-tests + - os: windows + cpu: amd64 + TEST_KIND: unit-tests + include: + - target: + os: linux + builder: ubuntu-20.04 + - target: + os: macos + builder: macos-10.15 + - target: + os: windows + builder: windows-2019 + name: '${{ matrix.target.os }}-${{ matrix.target.cpu }} (${{ matrix.branch }})' + runs-on: ${{ matrix.builder }} + steps: + - name: Checkout nim-ws + uses: actions/checkout@v2 + with: + path: nim-ws + submodules: true + + - name: Derive environment variables + shell: bash + run: | + if [[ '${{ matrix.target.cpu }}' == 'amd64' ]]; then + ARCH=64 + PLATFORM=x64 + else + ARCH=32 + PLATFORM=x86 + fi + echo "ARCH=$ARCH" >> $GITHUB_ENV + echo "PLATFORM=$PLATFORM" >> $GITHUB_ENV + + ncpu= + case '${{ runner.os }}' in + 'Linux') + ncpu=$(nproc) + ;; + 'macOS') + ncpu=$(sysctl -n hw.ncpu) + ;; + 'Windows') + ncpu=$NUMBER_OF_PROCESSORS + ;; + esac + [[ -z "$ncpu" || $ncpu -le 0 ]] && ncpu=1 + echo "ncpu=$ncpu" >> $GITHUB_ENV + + - name: Install build dependencies (Linux i386) + if: runner.os == 'Linux' && matrix.target.cpu == 'i386' + run: | + sudo dpkg --add-architecture i386 + sudo rm /etc/apt/sources.list.d/devel:kubic:libcontainers:stable.list + sudo apt-get update -qq + sudo DEBIAN_FRONTEND='noninteractive' apt-get install \ + --no-install-recommends -yq gcc-multilib g++-multilib + mkdir -p external/bin + cat << EOF > external/bin/gcc + #!/bin/bash + exec $(which gcc) -m32 -mno-adx "\$@" + EOF + cat << EOF > external/bin/g++ + #!/bin/bash + exec $(which g++) -m32 -mno-adx "\$@" + EOF + chmod 755 external/bin/gcc external/bin/g++ + echo "${{ github.workspace }}/external/bin" >> $GITHUB_PATH + + - name: Install build dependencies (Windows) + if: runner.os == 'Windows' + shell: bash + run: | + mkdir external + if [[ '${{ matrix.target.cpu }}' == 'amd64' ]]; then + arch=64 + else + arch=32 + fi + curl -L "https://nim-lang.org/download/mingw$arch-6.3.0.7z" -o "external/mingw$arch.7z" + curl -L "https://nim-lang.org/download/windeps.zip" -o external/windeps.zip + 7z x "external/mingw$arch.7z" -oexternal/ + 7z x external/windeps.zip -oexternal/dlls + echo '${{ github.workspace }}'"/external/mingw$arch/bin" >> $GITHUB_PATH + echo '${{ github.workspace }}'"/external/dlls" >> $GITHUB_PATH + + - name: Setup environment + shell: bash + run: echo '${{ github.workspace }}/nim/bin' >> $GITHUB_PATH + + - name: Get latest Nim commit hash + id: versions + shell: bash + run: | + getHash() { + git ls-remote "https://github.com/$1" "${2:-HEAD}" | cut -f 1 + } + nimHash=$(getHash nim-lang/Nim '${{ matrix.branch }}') + csourcesHash=$(getHash nim-lang/csources) + echo "::set-output name=nim::$nimHash" + echo "::set-output name=csources::$csourcesHash" + + - name: Restore prebuilt Nim from cache + id: nim-cache + uses: actions/cache@v1 + with: + path: nim + key: "nim-${{ matrix.target.os }}-${{ matrix.target.cpu }}-${{ steps.versions.outputs.nim }}" + + - name: Restore prebuilt csources from cache + if: steps.nim-cache.outputs.cache-hit != 'true' + id: csources-cache + uses: actions/cache@v1 + with: + path: csources/bin + key: "csources-${{ matrix.target.os }}-${{ matrix.target.cpu }}-${{ steps.versions.outputs.csources }}" + + - name: Checkout Nim csources + if: > + steps.csources-cache.outputs.cache-hit != 'true' && + steps.nim-cache.outputs.cache-hit != 'true' + uses: actions/checkout@v2 + with: + repository: nim-lang/csources + path: csources + ref: ${{ steps.versions.outputs.csources }} + + - name: Checkout Nim + if: steps.nim-cache.outputs.cache-hit != 'true' + uses: actions/checkout@v2 + with: + repository: nim-lang/Nim + path: nim + ref: ${{ steps.versions.outputs.nim }} + + - name: Build Nim and associated tools + if: steps.nim-cache.outputs.cache-hit != 'true' + shell: bash + run: | + ncpu= + ext= + case '${{ runner.os }}' in + 'Linux') + ncpu=$(nproc) + ;; + 'macOS') + ncpu=$(sysctl -n hw.ncpu) + ;; + 'Windows') + ncpu=$NUMBER_OF_PROCESSORS + ext=.exe + ;; + esac + [[ -z "$ncpu" || $ncpu -le 0 ]] && ncpu=1 + if [[ ! -e csources/bin/nim$ext ]]; then + make -C csources -j $ncpu CC=gcc ucpu='${{ matrix.target.cpu }}' + else + echo 'Using prebuilt csources' + fi + cp -v csources/bin/nim$ext nim/bin + cd nim + nim c koch + ./koch boot -d:release + ./koch tools -d:release + # clean up to save cache space + rm koch + rm -rf nimcache + rm -rf dist + rm -rf .git + + - name: Run nim-ws tests + shell: bash + run: | + export UCPU="$cpu" + cd nim-ws + nimble install -y --depsOnly + nimble test diff --git a/examples/client.nim b/examples/client.nim index 730fee5..fb7c109 100644 --- a/examples/client.nim +++ b/examples/client.nim @@ -1,21 +1,31 @@ import ../src/ws, nativesockets, chronos, os, chronicles, stew/byteutils -let wsClient = waitFor newWebsocketClient("127.0.0.1", Port(8888), path = "/ws", - protocols = @["myfancyprotocol"]) -info "Websocket client: ", State = wsClient.readyState +proc main() {.async.} = + let ws = await connect( + "127.0.0.1", Port(8888), + path = "/ws") -let reqData = "Hello Server" -for idx in 1 .. 5: - try: - waitFor wsClient.sendStr(reqData) - let recvData = waitFor wsClient.receiveStrPacket() - let dataStr = string.fromBytes(recvData) - info "Server:", data = dataStr - assert dataStr == reqData - except WebSocketError: - error "WebSocket error:", exception = getCurrentExceptionMsg() - os.sleep(1000) + debug "Websocket client: ", State = ws.readyState -# close the websocket -waitFor wsClient.close() + let reqData = "Hello Server" + while true: + try: + await ws.send(reqData) + let buff = await ws.recv() + if buff.len <= 0: + break + let dataStr = string.fromBytes(buff) + debug "Server:", data = dataStr + + assert dataStr == reqData + return # bail out + except WebSocketError as exc: + error "WebSocket error:", exception = exc.msg + + await sleepAsync(100.millis) + + # close the websocket + await ws.close() + +waitFor(main()) diff --git a/examples/server.nim b/examples/server.nim index d9ba10a..b388881 100644 --- a/examples/server.nim +++ b/examples/server.nim @@ -1,34 +1,30 @@ import ../src/ws, ../src/http, chronos, chronicles, httputils, stew/byteutils proc cb(transp: StreamTransport, header: HttpRequestHeader) {.async.} = - info "Handling request:", uri = header.uri() + debug "Handling request:", uri = header.uri() if header.uri() == "/ws": - info "Initiating web socket connection." + debug "Initiating web socket connection." try: - var ws = await newWebSocket(header, transp, "myfancyprotocol") - if ws.readyState == Open: - info "Websocket handshake completed." - else: + var ws = await createServer(header, transp, "") + if ws.readyState != Open: error "Failed to open websocket connection." return - while true: + debug "Websocket handshake completed." + while ws.readyState != ReadyState.Closed: # Only reads header for data frame. - let msgReader = await ws.nextMessageReader() + var recvData = await ws.recv() + if recvData.len <= 0: + debug "Empty messages" + break - # Read the frame payload in buffer. - let buffer = newSeq[byte](100) - var recvData :seq[byte] - while msgReader.error != EOFError: - msgReader.readMessage(buffer) - recvData.add buffer - if ws.readyState == ReadyState.Closed: - return - info "Response: ", data = recvData + # debug "Response: ", data = string.fromBytes(recvData), size = recvData.len + debug "Response: ", size = recvData.len await ws.send(recvData) + # await ws.close() - except WebSocketError: - error "WebSocket error:", exception = getCurrentExceptionMsg() + except WebSocketError as exc: + error "WebSocket error:", exception = exc.msg discard await transp.sendHTTPResponse(HttpVersion11, Http200, "Hello World") await transp.closeWait() diff --git a/lint.nims b/lint.nims deleted file mode 100644 index 08b8c83..0000000 --- a/lint.nims +++ /dev/null @@ -1,14 +0,0 @@ -#!/usr/bin/env nim -import std/strutils - -proc lintFile*(file: string) = - if file.endsWith(".nim"): - exec "nimpretty " & file - -proc lintDir*(dir: string) = - for file in listFiles(dir): - lintFile(file) - for subdir in listDirs(dir): - lintDir(subdir) - -lintDir(projectDir()) \ No newline at end of file diff --git a/src/http.nim b/src/http.nim index f16cd29..0ccfc49 100644 --- a/src/http.nim +++ b/src/http.nim @@ -23,6 +23,7 @@ type AsyncCallback = proc (transp: StreamTransport, header: HttpRequestHeader): Future[void] {.closure, gcsafe.} + HttpServer* = ref object of StreamServer callback: AsyncCallback @@ -59,7 +60,7 @@ proc recvData(transp: StreamTransport): Future[seq[byte]] {.async.} = buffer.setLen(0) return buffer -proc newConnection(client: HttpClient, url: Uri) {.async.} = +proc connect(client: HttpClient, url: Uri) {.async.} = if client.connected: return @@ -73,8 +74,10 @@ proc newConnection(client: HttpClient, url: Uri) {.async.} = client.currentURL = url client.connected = true -proc generateHeaders(requestUrl: Uri, httpMethod: string, - additionalHeaders: HttpHeaders): string = +proc generateHeaders( + requestUrl: Uri, + httpMethod: string, + additionalHeaders: HttpHeaders): string = # GET var headers = httpMethod.toUpperAscii() headers.add ' ' @@ -91,15 +94,19 @@ proc generateHeaders(requestUrl: Uri, httpMethod: string, return headers # Send request to the client. Currently only supports HTTP get method. -proc request*(client: HttpClient, url, httpMethod: string, - body = "", headers: HttpHeaders): Future[seq[byte]] - {.async.} = +proc request*( + client: HttpClient, + url, + httpMethod: string, + body = "", + headers: HttpHeaders): Future[seq[byte]] {.async.} = # Helper that actually makes the request. Does not handle redirects. + let requestUrl = parseUri(url) if requestUrl.scheme == "": raise newException(ValueError, "No uri scheme supplied.") - await newConnection(client, requestUrl) + await connect(client, requestUrl) let headerString = generateHeaders(requestUrl, httpMethod, headers) let res = await client.transp.write(headerString) @@ -111,8 +118,12 @@ proc request*(client: HttpClient, url, httpMethod: string, raise newException(ValueError, "Empty response from server") return value -proc sendHTTPResponse*(transp: StreamTransport, version: HttpVersion, code: HttpCode, - data: string = ""): Future[bool] {.async.} = +proc sendHTTPResponse*( + transp: StreamTransport, + version: HttpVersion, + code: HttpCode, + data: string = ""): Future[bool] {.async.} = + var answer = $version answer.add(" ") answer.add($code) @@ -128,10 +139,13 @@ proc sendHTTPResponse*(transp: StreamTransport, version: HttpVersion, code: Http let res = await transp.write(answer) if res == len(answer): return true + raise newException(IOError, "Failed to send http request.") -proc validateRequest(transp: StreamTransport, - header: HttpRequestHeader): Future[ReqStatus] {.async.} = +proc validateRequest( + transp: StreamTransport, + header: HttpRequestHeader): Future[ReqStatus] {.async.} = + if header.meth notin {MethodGet}: debug "GET method is only allowed", address = transp.remoteAddress() if await transp.sendHTTPResponse(header.version, Http405): @@ -150,12 +164,14 @@ proc validateRequest(transp: StreamTransport, return Success proc serveClient(server: StreamServer, transp: StreamTransport) {.async.} = - ## Process transport data to the RPC server + ## Process transport data to the HTTP server + ## + var httpServer = cast[HttpServer](server) var buffer = newSeq[byte](MaxHttpHeadersSize) var header: HttpRequestHeader - info "Received connection", address = $transp.remoteAddress() + debug "Received connection", address = $transp.remoteAddress() try: let hlenfut = transp.readUntil(addr buffer[0], MaxHttpHeadersSize, sep = HeaderSep) @@ -180,7 +196,7 @@ proc serveClient(server: StreamServer, transp: StreamTransport) {.async.} = return var vres = await validateRequest(transp, header) if vres == Success: - info "Received valid RPC request", address = $transp.remoteAddress() + debug "Received valid HTTP request", address = $transp.remoteAddress() # Call the user's callback. if httpServer.callback != nil: await httpServer.callback(transp, header) @@ -210,17 +226,6 @@ proc newHttpServer*(address: string, handler: AsyncCallback, child = cast[StreamServer](server))) return server -func toTitleCase(s: string): string = - var tcstr = newString(len(s)) - var upper = true - for i in 0..len(s) - 1: - tcstr[i] = if upper: toUpperAscii(s[i]) else: toLowerAscii(s[i]) - upper = s[i] == '-' - return tcstr - -func toCaseInsensitive*(headers: HttpHeaders, s: string): string {.inline.} = - return toTitleCase(s) - func newHttpHeaders*(): HttpHeaders = ## Returns a new ``HttpHeaders`` object. if ``titleCase`` is set to true, ## headers are passed to the server in title case (e.g. "Content-Length") @@ -233,7 +238,7 @@ func newHttpHeaders*(keyValuePairs: var headers = newHttpHeaders() for pair in keyValuePairs: - let key = headers.toCaseInsensitive(pair.key) + let key = toUpperAscii(pair.key) if key in headers.table: headers.table[key].add(pair.val) else: diff --git a/src/random.nim b/src/random.nim index f61ec97..8b6906e 100644 --- a/src/random.nim +++ b/src/random.nim @@ -3,6 +3,7 @@ import bearssl ## Random helpers: similar as in stdlib, but with BrHmacDrbgContext rng const randMax = 18_446_744_073_709_551_615'u64 + proc rand*(rng: var BrHmacDrbgContext, max: Natural): int = if max == 0: return 0 var x: uint64 @@ -16,9 +17,9 @@ proc genMaskKey*(rng: ref BrHmacDrbgContext): array[4, char] = proc r(): char = char(rand(rng[], 255)) return [r(), r(), r(), r()] -proc genWebSecKey*(rng: ref BrHmacDrbgContext): seq[char] = - var key = newSeq[char](16) - proc r(): char = char(rand(rng[], 255)) +proc genWebSecKey*(rng: ref BrHmacDrbgContext): seq[byte] = + var key = newSeq[byte](16) + proc r(): byte = byte(rand(rng[], 255)) ## Generates a random key of 16 random chars. for i in 0..15: key.add(r()) diff --git a/src/ws.nim b/src/ws.nim index 08a856b..51e519f 100644 --- a/src/ws.nim +++ b/src/ws.nim @@ -1,111 +1,19 @@ -import httputils, strutils, base64, std/sha1, ./random, http, uri, - chronos/timer, tables, stew/byteutils, eth/[keys], stew/endians2, - parseutils, stew/base64 as stewBase,chronos +import std/[tables, + strutils, + uri, + parseutils] -const - SHA1DigestSize = 20 - WSHeaderSize = 12 - WSOpCode = {0x00, 0x01, 0x02, 0x08, 0x09, 0x0a} +import pkg/[chronos, + chronicles, + httputils, + stew/byteutils, + stew/endians2, + stew/base64, + eth/keys] -type - ReadyState* = enum - Connecting = 0 # The connection is not yet open. - Open = 1 # The connection is open and ready to communicate. - Closing = 2 # The connection is in the process of closing. - Closed = 3 # The connection is closed or couldn't be opened. +import pkg/nimcrypto/sha - WebSocket* = ref object - tcpSocket*: StreamTransport - version*: int - key*: string - protocol*: string - readyState*: ReadyState - masked*: bool # send masked packets - rng*: ref BrHmacDrbgContext - - WebSocketError* = object of IOError - - Base16Error* = object of CatchableError - ## Base16 specific exception type - - HeaderFlag* {.size: sizeof(uint8).} = enum - rsv3 - rsv2 - rsv1 - fin - HeaderFlags = set[HeaderFlag] - - HttpCode* = enum - Http101 = 101 # Switching Protocols - -# Forward declare -proc close*(ws: WebSocket, initiator: bool = true) {.async.} - -proc handshake*(ws: WebSocket, header: HttpRequestHeader) {.async.} = - ## Handles the websocket handshake. - discard parseSaturatedNatural(header["Sec-WebSocket-Version"], ws.version) - if ws.version != 13: - raise newException(WebSocketError, "Websocket version not supported, Version: " & - header["Sec-WebSocket-Version"]) - - ws.key = header["Sec-WebSocket-Key"].strip() - if header.contains("Sec-WebSocket-Protocol"): - let wantProtocol = header["Sec-WebSocket-Protocol"].strip() - if ws.protocol != wantProtocol: - raise newException(WebSocketError, - "Protocol mismatch (expected: " & ws.protocol & ", got: " & - wantProtocol & ")") - - var acceptKey: string - try: - let sh = secureHash(ws.key & "258EAFA5-E914-47DA-95CA-C5AB0DC85B11") - acceptKey = stewBase.Base64.encode(hexToByteArray[SHA1DigestSize]($sh)) - except ValueError: - raise newException( - WebSocketError, "Failed to generate accept key: " & getCurrentExceptionMsg()) - - var response = "HTTP/1.1 101 Web Socket Protocol Handshake" & CRLF - response.add("Sec-WebSocket-Accept: " & acceptKey & CRLF) - response.add("Connection: Upgrade" & CRLF) - response.add("Upgrade: webSocket" & CRLF) - - if ws.protocol != "": - response.add("Sec-WebSocket-Protocol: " & ws.protocol & CRLF) - response.add CRLF - - let res = await ws.tcpSocket.write(response) - if res != len(response): - raise newException(WebSocketError, "Failed to send handshake response to client") - ws.readyState = Open - -proc newWebSocket*(header: HttpRequestHeader, transp: StreamTransport, - protocol: string = ""): Future[WebSocket] {.async.} = - ## Creates a new socket from a request. - try: - if not header.contains("Sec-WebSocket-Version"): - raise newException(WebSocketError, "Invalid WebSocket handshake") - var ws = WebSocket(tcpSocket: transp, protocol: protocol, masked: false, - rng: newRng()) - await ws.handshake(header) - return ws - except ValueError, KeyError: - # Wrap all exceptions in a WebSocketError so its easy to catch. - raise newException( - WebSocketError, - "Failed to create WebSocket from request: " & getCurrentExceptionMsg() - ) - -type - Opcode* = enum - ## 4 bits. Defines the interpretation of the "Payload data". - Cont = 0x0 ## Denotes a continuation frame. - Text = 0x1 ## Denotes a text frame. - Binary = 0x2 ## Denotes a binary frame. - # 3-7 are reserved for further non-control frames. - Close = 0x8 ## Denotes a connection close. - Ping = 0x9 ## Denotes a ping. - Pong = 0xa ## Denotes a pong. - # B-F are reserved for further control frames. +import ./random, ./http #[ +---------------------------------------------------------------+ @@ -129,49 +37,217 @@ type +---------------------------------------------------------------+ ]# - MsgReader = ref object - tcpSocket: StreamTransport - readErr: IOError - readLen: uint64 - readRemaining: uint64 - readFinal: bool ## true the current message has more frames. - opcode: Opcode ## Defines the interpretation of the "Payload data". - maskKey: array[4, char] ## Masking key - mask: bool ## Defines whether the "Payload data" is masked. +const + SHA1DigestSize* = 20 + WSHeaderSize* = 12 + WSDefaultVersion* = 13 + WSDefaultFrameSize* = 1 shl 20 # 1mb + WSMaxMessageSize* = 20 shl 20 # 20mb + WSGuid* = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11" + +type + ReadyState* {.pure.} = enum + Connecting = 0 # The connection is not yet open. + Open = 1 # The connection is open and ready to communicate. + Closing = 2 # The connection is in the process of closing. + Closed = 3 # The connection is closed or couldn't be opened. + + WebSocketError* = object of CatchableError + WSMalformedHeaderError* = object of WebSocketError + WSFailedUpgradeError* = object of WebSocketError + WSVersionError* = object of WebSocketError + WSProtoMismatchError* = object of WebSocketError + WSMaskMismatchError* = object of WebSocketError + WSHandshakeError* = object of WebSocketError + WSOpcodeMismatchError* = object of WebSocketError + WSRsvMismatchError* = object of WebSocketError + WSWrongUriSchemeError* = object of WebSocketError + WSMaxMessageSizeError* = object of WebSocketError + WSClosedError* = object of WebSocketError + WSSendError* = object of WebSocketError + WSPayloadTooLarge = object of WebSocketError + + Base16Error* = object of CatchableError + ## Base16 specific exception type + + HeaderFlag* {.size: sizeof(uint8).} = enum + rsv3 + rsv2 + rsv1 + fin + HeaderFlags = set[HeaderFlag] + + HttpCode* = enum + Http101 = 101 # Switching Protocols + + Opcode* {.pure.} = enum + ## 4 bits. Defines the interpretation of the "Payload data". + Cont = 0x0 ## Denotes a continuation frame. + Text = 0x1 ## Denotes a text frame. + Binary = 0x2 ## Denotes a binary frame. + # 3-7 are reserved for further non-control frames. + Close = 0x8 ## Denotes a connection close. + Ping = 0x9 ## Denotes a ping. + Pong = 0xa ## Denotes a pong. + # B-F are reserved for further control frames. + + Status* {.pure.} = enum + # 0-999 not used + Fulfilled = 1000 + GoingAway = 1001 + ProtocolError = 1002 + CannotAccept = 1003 + # 1004 reserved + NoStatus = 1005 # use by clients + ClosedAbnormally = 1006 # use by clients + Inconsistent = 1007 + PolicyError = 1008 + TooLarge = 1009 + NoExtensions = 1010 + UnexpectedError = 1011 + TlsError # use by clients + # 3000-3999 reserved for libs + # 4000-4999 reserved for applications Frame = ref object - fin: bool ## Indicates that this is the final fragment in a message. - rsv1: bool ## MUST be 0 unless negotiated that defines meanings - rsv2: bool ## MUST be 0 - rsv3: bool ## MUST be 0 - opcode: Opcode ## Defines the interpretation of the "Payload data". - mask: bool ## Defines whether the "Payload data" is masked. - data: seq[byte] ## Payload data - maskKey: array[4, char] ## Masking key - length: uint64 ## Message size. + fin: bool ## Indicates that this is the final fragment in a message. + rsv1: bool ## MUST be 0 unless negotiated that defines meanings + rsv2: bool ## MUST be 0 + rsv3: bool ## MUST be 0 + opcode: Opcode ## Defines the interpretation of the "Payload data". + mask: bool ## Defines whether the "Payload data" is masked. + data: seq[byte] ## Payload data + maskKey: array[4, char] ## Masking key + length: uint64 ## Message size. + consumed: uint64 ## how much has been consumed from the frame -proc encodeFrame(f: Frame): seq[byte] = + ControlCb* = proc() {.gcsafe.} + + CloseResult* = tuple + code: Status + reason: string + + CloseCb* = proc(code: Status, reason: string): + CloseResult {.gcsafe.} + + WebSocket* = ref object + tcpSocket*: StreamTransport + version*: int + key*: string + protocol*: string + readyState*: ReadyState + masked*: bool # send masked packets + rng*: ref BrHmacDrbgContext + frameSize: int + frame: Frame + onPing: ControlCb + onPong: ControlCb + onClose: CloseCb + +template remainder*(frame: Frame): uint64 = + frame.length - frame.consumed + +proc unmask*( + data: var openArray[byte], + maskKey: array[4, char], + offset = 0) = + ## Unmask a data payload using key + ## + + for i in 0 ..< data.len: + data[i] = (data[i].uint8 xor maskKey[(offset + i) mod 4].uint8) + +proc prepareCloseBody(code: Status, reason: string): seq[byte] = + result = reason.toBytes + if ord(code) > 999: + result = @(ord(code).uint16.toBytesBE()) & result + +proc handshake*( + ws: WebSocket, + header: HttpRequestHeader, + version = WSDefaultVersion) {.async.} = + ## Handles the websocket handshake. + ## + + discard parseSaturatedNatural(header["Sec-WebSocket-Version"], ws.version) + if ws.version != version: + raise newException(WSVersionError, + "Websocket version not supported, Version: " & + header["Sec-WebSocket-Version"]) + + ws.key = header["Sec-WebSocket-Key"].strip() + if header.contains("Sec-WebSocket-Protocol"): + let wantProtocol = header["Sec-WebSocket-Protocol"].strip() + if ws.protocol != wantProtocol: + raise newException(WSProtoMismatchError, + "Protocol mismatch (expected: " & ws.protocol & ", got: " & + wantProtocol & ")") + + let cKey = ws.key & WSGuid + let acceptKey = Base64Pad.encode(sha1.digest(cKey.toOpenArray(0, cKey.high)).data) + + var response = "HTTP/1.1 101 Web Socket Protocol Handshake" & CRLF + response.add("Sec-WebSocket-Accept: " & acceptKey & CRLF) + response.add("Connection: Upgrade" & CRLF) + response.add("Upgrade: webSocket" & CRLF) + + if ws.protocol != "": + response.add("Sec-WebSocket-Protocol: " & ws.protocol & CRLF) + response.add CRLF + + let res = await ws.tcpSocket.write(response) + if res != len(response): + raise newException(WSSendError, "Failed to send handshake response to client") + ws.readyState = ReadyState.Open + +proc createServer*( + header: HttpRequestHeader, + transp: StreamTransport, + protocol: string = "", + frameSize = WSDefaultFrameSize, + onPing: ControlCb = nil, + onPong: ControlCb = nil, + onClose: CloseCb = nil): Future[WebSocket] {.async.} = + ## Creates a new socket from a request. + ## + + if not header.contains("Sec-WebSocket-Version"): + raise newException(WSHandshakeError, "Missing version header") + + var ws = WebSocket( + tcpSocket: transp, + protocol: protocol, + masked: false, + rng: newRng(), + frameSize: frameSize, + onPing: onPing, + onPong: onPong, + onClose: onClose) + + await ws.handshake(header) + return ws + +proc encodeFrame*(f: Frame): seq[byte] = ## Encodes a frame into a string buffer. ## See https://tools.ietf.org/html/rfc6455#section-5.2 - var ret = newSeqOfCap[byte](f.data.len + WSHeaderSize) - + var ret: seq[byte] var b0 = (f.opcode.uint8 and 0x0f) # 0th byte: opcodes and flags. if f.fin: - b0 = b0 or 128u8 + b0 = b0 or 128'u8 ret.add(b0) # Payload length can be 7 bits, 7+16 bits, or 7+64 bits. # 1st byte: payload len start and mask bit. - var b1 = 0u8 + var b1 = 0'u8 if f.data.len <= 125: b1 = f.data.len.uint8 elif f.data.len > 125 and f.data.len <= 0xffff: - b1 = 126u8 + b1 = 126'u8 else: - b1 = 127u8 + b1 = 127'u8 if f.mask: b1 = b1 or (1 shl 7) @@ -186,8 +262,8 @@ proc encodeFrame(f: Frame): seq[byte] = ret.add (len and 255).uint8 elif f.data.len > 0xffff: # Data len is 7+64 bits. - var len = f.data.len - ret.add(f.data.len.uint64.toBE().toBytesBE()) + var len = f.data.len.uint64 + ret.add(len.toBytesBE()) var data = f.data @@ -205,226 +281,369 @@ proc encodeFrame(f: Frame): seq[byte] = ret.add(data) return ret -proc send*(ws: WebSocket, data: seq[byte], opcode = Opcode.Text): Future[ - void] {.async.} = +proc send*( + ws: WebSocket, + data: seq[byte] = @[], + opcode = Opcode.Text): Future[void] {.async.} = + ## Send a frame + ## + + if ws.readyState == ReadyState.Closed: + raise newException(WSClosedError, "Socket is closed!") + + logScope: + opcode = opcode + dataSize = data.len + + debug "Sending data to remote" + + var maskKey: array[4, char] + if ws.masked: + maskKey = genMaskKey(ws.rng) + + if opcode notin {Opcode.Text, Opcode.Cont, Opcode.Binary}: + discard await ws.tcpSocket.write(encodeFrame(Frame( + fin: true, + rsv1: false, + rsv2: false, + rsv3: false, + opcode: opcode, + mask: ws.masked, + data: data, # allow sending data with close messages + maskKey: maskKey))) + + return + + let maxSize = ws.frameSize + var i = 0 + while i < data.len: + let len = min(data.len, (maxSize + i)) + let inFrame = Frame( + fin: if (i + len >= data.len): true else: false, + rsv1: false, + rsv2: false, + rsv3: false, + opcode: if i > 0: Opcode.Cont else: opcode, # fragments have to be `Continuation` frames + mask: ws.masked, + data: data[i ..< len], + maskKey: maskKey) + + discard await ws.tcpSocket.write(encodeFrame(inFrame)) + i += len + +proc send*(ws: WebSocket, data: string): Future[void] = + send(ws, toBytes(data), Opcode.Text) + +proc handleClose*(ws: WebSocket, frame: Frame) {.async.} = + logScope: + fin = frame.fin + masked = frame.mask + opcode = frame.opcode + serverState = ws.readyState + + debug "Handling close sequence" + if ws.readyState == ReadyState.Open or ws.readyState == ReadyState.Closing: + # Read control frame payload. + var data = newSeq[byte](frame.length) + if frame.length > 0: + # Read the data. + await ws.tcpSocket.readExactly(addr data[0], int frame.length) + unmask(data.toOpenArray(0, data.high), frame.maskKey) + + var code: Status + if data.len > 0: + let ccode = uint16.fromBytesBE(data[0..<2]) # first two bytes are the status + doAssert(ccode > 999, "No valid code in close message!") + code = Status(ccode) + data = data[2..data.high] + + var rcode = Status.Fulfilled + var reason = "" + if not isNil(ws.onClose): + try: + (rcode, reason) = ws.onClose(code, string.fromBytes(data)) + except CatchableError as exc: + debug "Exception in Close callback, this is most likelly a bug", exc = exc.msg + + # don't respong to a terminated connection + if ws.readyState != ReadyState.Closing: + await ws.send(prepareCloseBody(rcode, reason), Opcode.Close) + + await ws.tcpSocket.closeWait() + ws.readyState = ReadyState.Closed + else: + raiseAssert("Invalid state during close!") + +proc handleControl*(ws: WebSocket, frame: Frame) {.async.} = + ## handle control frames + ## + + if frame.length > 125: + raise newException(WSPayloadTooLarge, + "Control message payload is freater than 125 bytes!") + try: - var maskKey: array[4, char] - if ws.masked: - maskKey = genMaskKey(ws.rng) + # Process control frame payload. + case frame.opcode: + of Opcode.Ping: + if not isNil(ws.onPing): + try: + ws.onPing() + except CatchableError as exc: + debug "Exception in Ping callback, this is most likelly a bug", exc = exc.msg - var inFrame = Frame( - fin: true, - rsv1: false, - rsv2: false, - rsv3: false, - opcode: opcode, - mask: ws.masked, - data: data, - maskKey: maskKey) - var frame = encodeFrame(inFrame) - const maxSize = 1024*1024 - # Send stuff in 1 megabyte chunks to prevent IOErrors. - # This really large packets. - var i = 0 - while i < frame.len: - let frameSize = min(frame.len, i + maxSize) - let res = await ws.tcpSocket.write(frame[i ..< frameSize]) - if res != frameSize: - raise newException(ValueError, "Error while send websocket frame") - i += maxSize - except OSError, ValueError: - # Wrap all exceptions in a WebSocketError so its easy to catch - raise newException(WebSocketError, "Failed to send data: " & - getCurrentExceptionMsg()) + # send pong to remote + await ws.send(@[], Opcode.Pong) + of Opcode.Pong: + if not isNil(ws.onPong): + try: + ws.onPong() + except CatchableError as exc: + debug "Exception in Pong callback, this is most likelly a bug", exc = exc.msg + of Opcode.Close: + await ws.handleClose(frame) + else: + raiseAssert("Invalid control opcode") + except CatchableError as exc: + debug "Exception handling control messages", exc = exc.msg + ws.readyState = ReadyState.Closed + await ws.tcpSocket.closeWait() -proc sendStr*(ws: WebSocket, data: string, opcode = Opcode.Text): Future[void] = - send(ws, toBytes(data), opcode) - -proc readFrame(ws: WebSocket): Future[Frame] {.async.} = +proc readFrame*(ws: WebSocket): Future[Frame] {.async.} = ## Gets a frame from the WebSocket. ## See https://tools.ietf.org/html/rfc6455#section-5.2 + ## - # Grab the header. - var header = newSeq[byte](2) try: - await ws.tcpSocket.readExactly(addr header[0], 2) - except TransportUseClosedError: - ws.readyState = Closed - raise newException(WebSocketError, "Socket closed") + while ws.readyState != ReadyState.Closed: # read until a data frame arrives + # Grab the header. + var header = newSeq[byte](2) + await ws.tcpSocket.readExactly(addr header[0], 2) - if header.len != 2: - ws.readyState = Closed - raise newException(WebSocketError, "Invalid websocket header length") + if header.len != 2: + debug "Invalid websocket header length" + raise newException(WSMalformedHeaderError, "Invalid websocket header length") - let b0 = header[0].uint8 - let b1 = header[1].uint8 + let b0 = header[0].uint8 + let b1 = header[1].uint8 - var frame: Frame - # Read the flags and fin from the header. + var frame = Frame() + # Read the flags and fin from the header. - var hf = cast[HeaderFlags](b0 shr 4) - frame.fin = fin in hf - frame.rsv1 = rsv1 in hf - frame.rsv2 = rsv2 in hf - frame.rsv3 = rsv3 in hf + var hf = cast[HeaderFlags](b0 shr 4) + frame.fin = fin in hf + frame.rsv1 = rsv1 in hf + frame.rsv2 = rsv2 in hf + frame.rsv3 = rsv3 in hf - var opcode = b0 and 0x0f - if opcode notin WSOpCode: - raise newException(WebSocketError, "Unexpected websocket opcode") - frame.opcode = (opcode).Opcode + let opcode = (b0 and 0x0f) + if opcode > ord(Opcode.high): + raise newException(WSOpcodeMismatchError, "Wrong opcode!") - # If any of the rsv are set close the socket. - if frame.rsv1 or frame.rsv2 or frame.rsv3: - ws.readyState = Closed - raise newException(WebSocketError, "WebSocket rsv mismatch") + frame.opcode = (opcode).Opcode - # Payload length can be 7 bits, 7+16 bits, or 7+64 bits. - var finalLen: uint64 = 0 + # If any of the rsv are set close the socket. + if frame.rsv1 or frame.rsv2 or frame.rsv3: + raise newException(WSRsvMismatchError, "WebSocket rsv mismatch") - let headerLen = uint(b1 and 0x7f) - if headerLen == 0x7e: - # Length must be 7+16 bits. - var length = newSeq[byte](2) - await ws.tcpSocket.readExactly(addr length[0], 2) - finalLen = cast[ptr uint16](length[0].addr)[].toBE - elif headerLen == 0x7f: - # Length must be 7+64 bits. - var length = newSeq[byte](8) - await ws.tcpSocket.readExactly(addr length[0], 8) - finalLen = cast[ptr uint64](length[0].addr)[].toBE - else: - # Length must be 7 bits. - finalLen = headerLen - frame.length = finalLen + # Payload length can be 7 bits, 7+16 bits, or 7+64 bits. + var finalLen: uint64 = 0 - # Do we need to apply mask? - frame.mask = (b1 and 0x80) == 0x80 + let headerLen = uint(b1 and 0x7f) + if headerLen == 0x7e: + # Length must be 7+16 bits. + var length = newSeq[byte](2) + await ws.tcpSocket.readExactly(addr length[0], 2) + finalLen = uint16.fromBytesBE(length) + elif headerLen == 0x7f: + # Length must be 7+64 bits. + var length = newSeq[byte](8) + await ws.tcpSocket.readExactly(addr length[0], 8) + finalLen = uint64.fromBytesBE(length) + else: + # Length must be 7 bits. + finalLen = headerLen + frame.length = finalLen - if ws.masked == frame.mask: - # Server sends unmasked but accepts only masked. - # Client sends masked but accepts only unmasked. - raise newException(WebSocketError, "Socket mask mismatch") + # Do we need to apply mask? + frame.mask = (b1 and 0x80) == 0x80 + if ws.masked == frame.mask: + # Server sends unmasked but accepts only masked. + # Client sends masked but accepts only unmasked. + raise newException(WSMaskMismatchError, "Socket mask mismatch") - var maskKey = newSeq[byte](4) - if frame.mask: - # Read the mask. - await ws.tcpSocket.readExactly(addr maskKey[0], 4) - for i in 0.. 0 : - # Read the data. - await ws.tcpSocket.readExactly(addr data[0], int finalLen) - frame.data = data + return frame + except CatchableError as exc: + debug "Exception reading frame, dropping socket", exc = exc.msg + ws.readyState = ReadyState.Closed + await ws.tcpSocket.closeWait() + raise exc - # Process control frame payload. - if frame.opcode == Ping: - await ws.send(data, Pong) - elif frame.opcode == Pong: - discard - elif frame.opcode == Close: - await ws.close(false) +proc ping*(ws: WebSocket): Future[void] = + ws.send(opcode = Opcode.Ping) - return frame +proc recv*( + ws: WebSocket, + data: pointer, + size: int): Future[int] {.async.} = + ## Attempts to read up to `size` bytes + ## + ## Will read as many frames as necesary + ## to fill the buffer until either + ## the message ends (frame.fin) or + ## the buffer is full. If no data is on + ## the pipe will await until at least + ## one byte is available + ## -proc close*(ws: WebSocket, initiator: bool = true) {.async.} = + var consumed = 0 + var pbuffer = cast[ptr UncheckedArray[byte]](data) + try: + while consumed < size: + # we might have to read more than + # one frame to fill the buffer + if isNil(ws.frame): + ws.frame = await ws.readFrame() + + # all has been consumed from the frame + # read the next frame + if ws.frame.remainder() <= 0: + ws.frame = await ws.readFrame() + + let len = min(ws.frame.remainder().int, size - consumed) + let read = await ws.tcpSocket.readOnce(addr pbuffer[consumed], len) + + if read <= 0: + continue + + if ws.frame.mask: + # unmask data using offset + unmask( + pbuffer.toOpenArray(consumed, (consumed + read) - 1), + ws.frame.maskKey, + consumed) + + consumed += read + ws.frame.consumed += read.uint64 + if ws.frame.fin and ws.frame.remainder().int <= 0: + break + + return consumed.int + except CancelledError as exc: + debug "Cancelling reading", exc = exc.msg + raise exc + except CatchableError as exc: + debug "Exception reading frames", exc = exc.msg + +proc recv*( + ws: WebSocket, + size = WSMaxMessageSize): Future[seq[byte]] {.async.} = + ## Attempt to read a full message up to max `size` + ## bytes in `frameSize` chunks. + ## + ## If no `fin` flag ever arrives it will await until + ## either cancelled or the `fin` flag arrives. + ## + ## If message is larger than `size` a `WSMaxMessageSizeError` + ## exception is thrown. + ## + ## In all other cases it awaits a full message. + ## + var res: seq[byte] + try: + while ws.readyState != ReadyState.Closed: + var buf = newSeq[byte](ws.frameSize) + let read = await ws.recv(addr buf[0], buf.len) + if read <= 0: + break + + buf.setLen(read) + if res.len + buf.len > size: + raise newException(WSMaxMessageSizeError, "Max message size exceeded") + + res.add(buf) + + # no more frames + if isNil(ws.frame): + break + + # read the entire message, exit + if ws.frame.fin and ws.frame.remainder().int <= 0: + break + except WSMaxMessageSizeError as exc: + raise exc + except CancelledError as exc: + debug "Cancelling reading", exc = exc.msg + raise exc + except CatchableError as exc: + debug "Exception reading frames", exc = exc.msg + + return res + +proc close*( + ws: WebSocket, + code: Status = Status.Fulfilled, + reason: string = "") {.async.} = ## Close the Socket, sends close packet. - if ws.readyState == Closed: - discard ws.tcpSocket.closeWait() + ## + + if ws.readyState != ReadyState.Open: return - ws.readyState = Closed - await ws.send(@[], Close) - if initiator == true: - let frame = await ws.readFrame() - if frame.opcode != Close: - echo "Different packet type" - await ws.close() -proc readMessage*(msgReader: MsgReader,data: seq[byte]): MsgReader {.async.} = - while msgReader.readErr == nil: - if msgReader.readRemaining > 0 : - len = size(data) - if len > msgReader.readRemaining: - len = msgReader.readRemaining + try: + ws.readyState = ReadyState.Closing + await ws.send( + prepareCloseBody(code, reason), + opcode = Opcode.Close) - await msgReader.tcpSocket.readExactly(addr data, len) - msgReader.readRemaining = msgReader.readRemaining - len - msgReader.readLen = len + # read frames until closed + while ws.readyState != ReadyState.Closed: + discard await ws.recv() - if msgReader.mask: - # Apply mask, if we need too. - for i in 0 ..< len: - data[i] = (data[i].uint8 xor msgReader.maskKey[i mod 4].uint8) + except CatchableError as exc: + debug "Exception closing", exc = exc.msg - if msgReader.readRemaining == 0: - msgReader.readErr = EOFError +proc connect*( + uri: Uri, + protocols: seq[string] = @[], + version = WSDefaultVersion, + frameSize = WSDefaultFrameSize, + onPing: ControlCb = nil, + onPong: ControlCb = nil, + onClose: CloseCb = nil): Future[WebSocket] {.async.} = + ## create a new websockets client + ## - return msgReader - - if msgReader.readFinal: - msgReader.readLen = 0 - msgReader.readErr = EOFError - return msgReader - - var frame = await ws.readFrame() - if frame.fin: - msgReader.readFinal = true - msgReader.readRemaining = frame.length - - # Non-control frames cannot occur in the middle of a fragmented non-control frame. - if frame.Opcode in Text || Binary: - raise newException("websocket: internal error, unexpected text or binary in Reader") - return msgReader - -proc nextMessageReader*(ws: WebSocket): MsgReader = - while true: - # Handle control frames and return only on non control frames. - var frame = await ws.readFrame() - if frame.Opcode in Text || Binary: - var msgReader: MsgReader - msgReader.readFinal = frame.fin - msgReader.readRemaining = frame.readRemaining - msgReader.tcpSocket = ws.tcpSocket - msgReader.mask = frame.mask - msgReader.maskKey = frame.maskKey - return msgReader - -proc receiveStrPacket*(ws: WebSocket): Future[seq[byte]] {.async.} = - # TODO: remove this once PR is approved. - return nil - -proc validateWSClientHandshake*(transp: StreamTransport, - header: HttpResponseHeader): void = - if header.code != ord(Http101): - raise newException(WebSocketError, "Server did not reply with a websocket upgrade: " & - "Header code: " & $header.code & - "Header reason: " & header.reason() & - "Address: " & $transp.remoteAddress()) - -proc newWebsocketClient*(uri: Uri, protocols: seq[string] = @[]): Future[ - WebSocket] {.async.} = - var key = encode(genWebSecKey(newRng())) + var key = Base64.encode(genWebSecKey(newRng())) var uri = uri case uri.scheme of "ws": uri.scheme = "http" else: - raise newException(WebSocketError, "uri scheme has to be 'ws'") + raise newException(WSWrongUriSchemeError, "uri scheme has to be 'ws'") var headers = newHttpHeaders({ "Connection": "Upgrade", "Upgrade": "websocket", "Cache-Control": "no-cache", - "Sec-WebSocket-Version": "13", + "Sec-WebSocket-Version": $version, "Sec-WebSocket-Key": key }) + if protocols.len != 0: headers.table["Sec-WebSocket-Protocol"] = @[protocols.join(", ")] @@ -433,19 +652,52 @@ proc newWebsocketClient*(uri: Uri, protocols: seq[string] = @[]): Future[ var header = response.parseResponse() if header.failed(): # Header could not be parsed - raise newException(WebSocketError, "Malformed header received: " & + raise newException(WSMalformedHeaderError, "Malformed header received: " & $client.transp.remoteAddress()) - client.transp.validateWSClientHandshake(header) + + if header.code != ord(Http101): + raise newException(WSFailedUpgradeError, + "Server did not reply with a websocket upgrade: " & + "Header code: " & $header.code & + "Header reason: " & header.reason() & + "Address: " & $client.transp.remoteAddress()) # Client data should be masked. - return WebSocket(tcpSocket: client.transp, readyState: Open, masked: true, - rng: newRng()) + return WebSocket( + tcpSocket: client.transp, + readyState: Open, + masked: true, + rng: newRng(), + frameSize: frameSize, + onPing: onPing, + onPong: onPong, + onClose: onClose) + +proc connect*( + host: string, + port: Port, + path: string, + protocols: seq[string] = @[], + version = WSDefaultVersion, + frameSize = WSDefaultFrameSize, + onPing: ControlCb = nil, + onPong: ControlCb = nil, + onClose: CloseCb = nil): Future[WebSocket] {.async.} = + ## Create a new websockets client + ## using a string path + ## -proc newWebsocketClient*(host: string, port: Port, path: string, - protocols: seq[string] = @[]): Future[WebSocket] {.async.} = var uri = "ws://" & host & ":" & $port if path.startsWith("/"): uri.add path else: uri.add "/" & path - return await newWebsocketClient(parseUri(uri), protocols) + + return await connect( + parseUri(uri), + protocols, + version, + frameSize, + onPing, + onPong, + onClose) diff --git a/tests/frame.nim b/tests/frame.nim deleted file mode 100644 index 5f9f1f8..0000000 --- a/tests/frame.nim +++ /dev/null @@ -1,76 +0,0 @@ -include ../src/ws -include ../src/http -include ../src/random -#import chronos, chronicles, httputils, strutils, base64, std/sha1, -# streams, nativesockets, uri, times, chronos/timer, tables - -import unittest - -# TODO: Fix Test. - -var maskKey: array[4, char] - -suite "tests for encodeFrame()": - test "# 7bit length": - block: # 7bit length - assert encodeFrame(( - fin: true, - rsv1: false, - rsv2: false, - rsv3: false, - opcode: Opcode.Text, - mask: false, - data: toBytes("hi there"), - maskKey: maskKey - )) == toBytes("\129\8hi there") - test "# 7bit length": - block: # 7+16 bits length - var data = "" - for i in 0..32: - data.add "How are you this is the payload!!!" - assert encodeFrame(( - fin: true, - rsv1: false, - rsv2: false, - rsv3: false, - opcode: Opcode.Text, - mask: false, - data: toBytes(data), - maskKey: maskKey - ))[0..32] == toBytes("\129~\4bHow are you this is the paylo") - test "# 7+64 bits length": - block: # 7+64 bits length - var data = "" - for i in 0..3200: - data.add "How are you this is the payload!!!" - assert encodeFrame(( - fin: true, - rsv1: false, - rsv2: false, - rsv3: false, - opcode: Opcode.Text, - mask: false, - data: toBytes(data), - maskKey: maskKey - ))[0..32] == toBytes("\129\127\0\0\0\0\0\1\169\"How are you this is the") - test "# masking": - block: # masking - let data = encodeFrame(( - fin: true, - rsv1: false, - rsv2: false, - rsv3: false, - opcode: Opcode.Text, - mask: true, - data: toBytes("hi there"), - maskKey: ['\xCF', '\xD8', '\x05', 'e'] - )) - assert data == toBytes("\129\136\207\216\5e\167\177%\17\167\189w\0") - -suite "tests for toTitleCase()": - block: - let val = toTitleCase("webSocket") - assert val == "Websocket" - - - diff --git a/tests/helpers.nim b/tests/helpers.nim index a43a4df..caf7eb4 100644 --- a/tests/helpers.nim +++ b/tests/helpers.nim @@ -1,31 +1,30 @@ import ../src/ws, chronos, chronicles, httputils, stew/byteutils, ../src/http, unittest, strutils -proc cb*(transp: StreamTransport, header: HttpRequestHeader) {.async.} = - info "Handling request:", uri = header.uri() +proc echoCb*(transp: StreamTransport, header: HttpRequestHeader) {.async.} = + debug "Handling request:", uri = header.uri() if header.uri() == "/ws": - info "Initiating web socket connection." + debug "Initiating web socket connection." try: - var ws = await newWebSocket(header, transp, "myfancyprotocol") + var ws = await createServer(header, transp, "myfancyprotocol") if ws.readyState == Open: - info "Websocket handshake completed." + debug "Websocket handshake completed." else: error "Failed to open websocket connection." return - while ws.readyState == Open: - let recvData = await ws.receiveStrPacket() - info "Server:", state = ws.readyState - await ws.send(recvData) + let recvData = await ws.recv() + debug "Server:", state = ws.readyState + await ws.send(recvData) except WebSocketError: error "WebSocket error:", exception = getCurrentExceptionMsg() discard await transp.sendHTTPResponse(HttpVersion11, Http200, "Connection established") proc sendRecvClientData*(wsClient: WebSocket, msg: string) {.async.} = try: - waitFor wsClient.sendStr(msg) - let recvData = waitFor wsClient.receiveStrPacket() - info "Websocket client state: ", state = wsClient.readyState + await wsClient.send(msg) + let recvData = await wsClient.recv() + debug "Websocket client state: ", state = wsClient.readyState let dataStr = string.fromBytes(recvData) require dataStr == msg @@ -33,12 +32,12 @@ proc sendRecvClientData*(wsClient: WebSocket, msg: string) {.async.} = error "WebSocket error:", exception = getCurrentExceptionMsg() proc incorrectProtocolCB*(transp: StreamTransport, header: HttpRequestHeader) {.async.} = - info "Handling request:", uri = header.uri() + debug "Handling request:", uri = header.uri() var isErr = false; if header.uri() == "/ws": - info "Initiating web socket connection." + debug "Initiating web socket connection." try: - var ws = await newWebSocket(header, transp, "myfancyprotocol") + var ws = await createServer(header, transp, "myfancyprotocol") require ws.readyState == ReadyState.Closed except WebSocketError: isErr = true; diff --git a/tests/testall.nim b/tests/testall.nim new file mode 100644 index 0000000..51d4c49 --- /dev/null +++ b/tests/testall.nim @@ -0,0 +1,2 @@ +import ./testframes +import ./testwebsockets diff --git a/tests/testframes.nim b/tests/testframes.nim new file mode 100644 index 0000000..38e2ec0 --- /dev/null +++ b/tests/testframes.nim @@ -0,0 +1,280 @@ +import unittest + +include ../src/ws +include ../src/http +include ../src/random + +# TODO: Fix Test. + +var maskKey: array[4, char] + +suite "Test data frames": + test "# 7bit length text": + check encodeFrame(Frame( + fin: false, + rsv1: false, + rsv2: false, + rsv3: false, + opcode: Opcode.Text, + mask: false, + data: toBytes("hi there"), + maskKey: maskKey + )) == toBytes("\1\8hi there") + + test "# 7bit length text fin bit": + check encodeFrame(Frame( + fin: true, + rsv1: false, + rsv2: false, + rsv3: false, + opcode: Opcode.Text, + mask: false, + data: toBytes("hi there"), + maskKey: maskKey + )) == toBytes("\129\8hi there") + + test "# 7bit length binary": + check encodeFrame(Frame( + fin: false, + rsv1: false, + rsv2: false, + rsv3: false, + opcode: Opcode.Binary, + mask: false, + data: toBytes("hi there"), + maskKey: maskKey + )) == toBytes("\2\8hi there") + + test "# 7bit length binary fin bit": + check encodeFrame(Frame( + fin: true, + rsv1: false, + rsv2: false, + rsv3: false, + opcode: Opcode.Binary, + mask: false, + data: toBytes("hi there"), + maskKey: maskKey + )) == toBytes("\130\8hi there") + + test "# 7bit length continuation": + check encodeFrame(Frame( + fin: false, + rsv1: false, + rsv2: false, + rsv3: false, + opcode: Opcode.Cont, + mask: false, + data: toBytes("hi there"), + maskKey: maskKey + )) == toBytes("\0\8hi there") + + test "# 7+16 length text": + var data = "" + for i in 0..32: + data.add "How are you this is the payload!!!" + + check encodeFrame(Frame( + fin: false, + rsv1: false, + rsv2: false, + rsv3: false, + opcode: Opcode.Text, + mask: false, + data: toBytes(data), + maskKey: maskKey + )) == toBytes("\1\126\4\98" & data) + + test "# 7+16 length text fin bit": + var data = "" + for i in 0..32: + data.add "How are you this is the payload!!!" + + check encodeFrame(Frame( + fin: false, + rsv1: false, + rsv2: false, + rsv3: false, + opcode: Opcode.Text, + mask: false, + data: toBytes(data), + maskKey: maskKey + )) == toBytes("\1\126\4\98" & data) + + test "# 7+16 length binary": + var data = "" + for i in 0..32: + data.add "How are you this is the payload!!!" + + check encodeFrame(Frame( + fin: false, + rsv1: false, + rsv2: false, + rsv3: false, + opcode: Opcode.Binary, + mask: false, + data: toBytes(data), + maskKey: maskKey + )) == toBytes("\2\126\4\98" & data) + + test "# 7+16 length binary fin bit": + var data = "" + for i in 0..32: + data.add "How are you this is the payload!!!" + + check encodeFrame(Frame( + fin: true, + rsv1: false, + rsv2: false, + rsv3: false, + opcode: Opcode.Binary, + mask: false, + data: toBytes(data), + maskKey: maskKey + )) == toBytes("\130\126\4\98" & data) + + test "# 7+16 length continuation": + var data = "" + for i in 0..32: + data.add "How are you this is the payload!!!" + + check encodeFrame(Frame( + fin: false, + rsv1: false, + rsv2: false, + rsv3: false, + opcode: Opcode.Cont, + mask: false, + data: toBytes(data), + maskKey: maskKey + )) == toBytes("\0\126\4\98" & data) + + test "# 7+64 length text": + var data = "" + for i in 0..3200: + data.add "How are you this is the payload!!!" + + check encodeFrame(Frame( + fin: false, + rsv1: false, + rsv2: false, + rsv3: false, + opcode: Opcode.Text, + mask: false, + data: toBytes(data), + maskKey: maskKey + )) == toBytes("\1\127\0\0\0\0\0\1\169\34" & data) + + test "# 7+64 length fin bit": + var data = "" + for i in 0..3200: + data.add "How are you this is the payload!!!" + + check encodeFrame(Frame( + fin: true, + rsv1: false, + rsv2: false, + rsv3: false, + opcode: Opcode.Text, + mask: false, + data: toBytes(data), + maskKey: maskKey + )) == toBytes("\129\127\0\0\0\0\0\1\169\34" & data) + + test "# 7+64 length binary": + var data = "" + for i in 0..3200: + data.add "How are you this is the payload!!!" + + check encodeFrame(Frame( + fin: false, + rsv1: false, + rsv2: false, + rsv3: false, + opcode: Opcode.Binary, + mask: false, + data: toBytes(data), + maskKey: maskKey + )) == toBytes("\2\127\0\0\0\0\0\1\169\34" & data) + + test "# 7+64 length binary fin bit": + var data = "" + for i in 0..3200: + data.add "How are you this is the payload!!!" + + check encodeFrame(Frame( + fin: true, + rsv1: false, + rsv2: false, + rsv3: false, + opcode: Opcode.Binary, + mask: false, + data: toBytes(data), + maskKey: maskKey + )) == toBytes("\130\127\0\0\0\0\0\1\169\34" & data) + + test "# 7+64 length binary": + var data = "" + for i in 0..3200: + data.add "How are you this is the payload!!!" + + check encodeFrame(Frame( + fin: false, + rsv1: false, + rsv2: false, + rsv3: false, + opcode: Opcode.Cont, + mask: false, + data: toBytes(data), + maskKey: maskKey + )) == toBytes("\0\127\0\0\0\0\0\1\169\34" & data) + + test "# masking": + let data = encodeFrame(Frame( + fin: true, + rsv1: false, + rsv2: false, + rsv3: false, + opcode: Opcode.Text, + mask: true, + data: toBytes("hi there"), + maskKey: ['\xCF', '\xD8', '\x05', 'e'] + )) + + check data == toBytes("\129\136\207\216\5e\167\177%\17\167\189w\0") + +suite "Test control frames": + + test "Close": + check encodeFrame(Frame( + fin: true, + rsv1: false, + rsv2: false, + rsv3: false, + opcode: Opcode.Close, + mask: false, + data: @[3'u8, 232'u8] & toBytes("hi there"), + maskKey: maskKey + )) == toBytes("\136\10\3\232hi there") + + test "Ping": + check encodeFrame(Frame( + fin: false, + rsv1: false, + rsv2: false, + rsv3: false, + opcode: Opcode.Ping, + mask: false, + maskKey: maskKey + )) == toBytes("\9\0") + + test "Pong": + check encodeFrame(Frame( + fin: false, + rsv1: false, + rsv2: false, + rsv3: false, + opcode: Opcode.Pong, + mask: false, + maskKey: maskKey + )) == toBytes("\10\0") diff --git a/tests/testwebsockets.nim b/tests/testwebsockets.nim new file mode 100644 index 0000000..cbf64f3 --- /dev/null +++ b/tests/testwebsockets.nim @@ -0,0 +1,387 @@ +import std/strutils +import pkg/[asynctest, chronos, httputils] +import pkg/stew/byteutils + +import ../src/http, + ../src/ws, + ../src/random + +import ./helpers + +var httpServer: HttpServer + +suite "Test handshake": + teardown: + httpServer.stop() + await httpServer.closeWait() + + test "Test for incorrect protocol": + proc cb(transp: StreamTransport, header: HttpRequestHeader) {.async.} = + check header.uri() == "/ws" + + expect WSProtoMismatchError: + var ws = await createServer(header, transp, "proto") + check ws.readyState == ReadyState.Closed + + check await transp.sendHTTPResponse( + HttpVersion11, + Http200, + "Connection established") + + await transp.closeWait() + + httpServer = newHttpServer("127.0.0.1:8888", cb) + httpServer.start() + + expect WSFailedUpgradeError: + discard await connect( + "127.0.0.1", + Port(8888), + path = "/ws", + protocols = @["wrongproto"]) + + test "Test for incorrect version": + proc cb(transp: StreamTransport, header: HttpRequestHeader) {.async.} = + check header.uri() == "/ws" + + expect WSVersionError: + var ws = await createServer(header, transp, "proto") + check ws.readyState == ReadyState.Closed + + check await transp.sendHTTPResponse( + HttpVersion11, + Http200, + "Connection established") + + await transp.closeWait() + + httpServer = newHttpServer("127.0.0.1:8888", cb) + httpServer.start() + + expect WSFailedUpgradeError: + discard await connect( + "127.0.0.1", + Port(8888), + path = "/ws", + protocols = @["wrongproto"], + version = 14) + + test "Test for client headers": + proc cb(transp: StreamTransport, header: HttpRequestHeader) {.async.} = + check header.uri() == "/ws" + + check header["Connection"].toUpperAscii() == "Upgrade".toUpperAscii() + check header["Upgrade"].toUpperAscii() == "websocket".toUpperAscii() + check header["Cache-Control"].toUpperAscii() == "no-cache".toUpperAscii() + check header["Sec-WebSocket-Version"] == $WSDefaultVersion + + check "Sec-WebSocket-Key" in header + + await transp.closeWait() + + httpServer = newHttpServer("127.0.0.1:8888", cb) + httpServer.start() + + expect ValueError: + discard await connect( + "127.0.0.1", + Port(8888), + path = "/ws", + protocols = @["proto"]) + +suite "Test transmission": + teardown: + httpServer.stop() + await httpServer.closeWait() + + test "Server - test reading simple frame": + let testString = "Hello!" + proc cb(transp: StreamTransport, header: HttpRequestHeader) {.async.} = + check header.uri() == "/ws" + + let ws = await createServer(header, transp, "proto") + let res = await ws.recv() + + check string.fromBytes(res) == testString + await transp.closeWait() + + httpServer = newHttpServer("127.0.0.1:8888", cb) + httpServer.start() + + let ws = await connect( + "127.0.0.1", + Port(8888), + path = "/ws", + protocols = @["proto"]) + + await ws.send(testString) + + test "Client - test reading simple frame": + let testString = "Hello!" + proc cb(transp: StreamTransport, header: HttpRequestHeader) {.async.} = + check header.uri() == "/ws" + + let ws = await createServer(header, transp, "proto") + await ws.send(testString) + await transp.closeWait() + + httpServer = newHttpServer("127.0.0.1:8888", cb) + httpServer.start() + + let ws = await connect( + "127.0.0.1", + Port(8888), + path = "/ws", + protocols = @["proto"]) + + let res = await ws.recv() + check string.fromBytes(res) == testString + +suite "Test ping-pong": + teardown: + httpServer.stop() + await httpServer.closeWait() + + test "Server - test ping-pong control messages": + var ping, pong = false + proc cb(transp: StreamTransport, header: HttpRequestHeader) {.async.} = + check header.uri() == "/ws" + + let ws = await createServer( + header, + transp, + "proto", + onPong = proc() = + pong = true + ) + + await ws.ping() + await ws.close() + + httpServer = newHttpServer("127.0.0.1:8888", cb) + httpServer.start() + + let ws = await connect( + "127.0.0.1", + Port(8888), + path = "/ws", + protocols = @["proto"], + onPing = proc() = + ping = true + ) + + discard await ws.recv() + + check: + ping + pong + + test "Client - test ping-pong control messages": + var ping, pong = false + proc cb(transp: StreamTransport, header: HttpRequestHeader) {.async.} = + check header.uri() == "/ws" + + let ws = await createServer( + header, + transp, + "proto", + onPing = proc() = + ping = true + ) + + discard await ws.recv() + + httpServer = newHttpServer("127.0.0.1:8888", cb) + httpServer.start() + + let ws = await connect( + "127.0.0.1", + Port(8888), + path = "/ws", + protocols = @["proto"], + onPong = proc() = + pong = true + ) + + await ws.ping() + await ws.close() + + check: + ping + pong + +suite "Test framing": + teardown: + httpServer.stop() + await httpServer.closeWait() + + test "should split message into frames": + let testString = "1234567890" + var done = newFuture[void]() + proc cb(transp: StreamTransport, header: HttpRequestHeader) {.async.} = + check header.uri() == "/ws" + + let ws = await createServer(header, transp, "proto") + + let frame1 = await ws.readFrame() + check not isNil(frame1) + var data1 = newSeq[byte](frame1.remainder().int) + let read1 = await ws.tcpSocket.readOnce(addr data1[0], data1.len) + check read1 == 5 + + let frame2 = await ws.readFrame() + check not isNil(frame2) + var data2 = newSeq[byte](frame2.remainder().int) + let read2 = await ws.tcpSocket.readOnce(addr data2[0], data2.len) + check read2 == 5 + + await transp.closeWait() + done.complete() + + httpServer = newHttpServer("127.0.0.1:8888", cb) + httpServer.start() + + let ws = await connect( + "127.0.0.1", + Port(8888), + path = "/ws", + protocols = @["proto"], + frameSize = 5) + + await ws.send(testString) + await done + + test "should fail to read past max message size": + let testString = "1234567890" + proc cb(transp: StreamTransport, header: HttpRequestHeader) {.async.} = + check header.uri() == "/ws" + + let ws = await createServer(header, transp, "proto") + await ws.send(testString) + await transp.closeWait() + + httpServer = newHttpServer("127.0.0.1:8888", cb) + httpServer.start() + + let ws = await connect( + "127.0.0.1", + Port(8888), + path = "/ws", + protocols = @["proto"]) + + expect WSMaxMessageSizeError: + discard await ws.recv(5) + +suite "Test Closing": + teardown: + httpServer.stop() + await httpServer.closeWait() + + test "Server closing": + proc cb(transp: StreamTransport, header: HttpRequestHeader) {.async.} = + check header.uri() == "/ws" + + let ws = await createServer(header, transp, "proto") + await ws.close() + + httpServer = newHttpServer("127.0.0.1:8888", cb) + httpServer.start() + + let ws = await connect( + "127.0.0.1", + Port(8888), + path = "/ws", + protocols = @["proto"]) + + discard await ws.recv() + check ws.readyState == ReadyState.Closed + + test "Server closing with status": + proc cb(transp: StreamTransport, header: HttpRequestHeader) {.async.} = + check header.uri() == "/ws" + + proc closeServer(status: Status, reason: string): CloseResult {.gcsafe.} = + check status == Status.TooLarge + check reason == "Message too big!" + + return (Status.Fulfilled, "") + + let ws = await createServer( + header, + transp, + "proto", + onClose = closeServer) + + await ws.close() + + httpServer = newHttpServer("127.0.0.1:8888", cb) + httpServer.start() + + proc clientClose(status: Status, reason: string): CloseResult {.gcsafe.} = + check status == Status.Fulfilled + + return (Status.TooLarge, "Message too big!") + + let ws = await connect( + "127.0.0.1", + Port(8888), + path = "/ws", + protocols = @["proto"], + onClose = clientClose) + + discard await ws.recv() + check ws.readyState == ReadyState.Closed + + test "Client closing": + proc cb(transp: StreamTransport, header: HttpRequestHeader) {.async.} = + check header.uri() == "/ws" + + let ws = await createServer(header, transp, "proto") + discard await ws.recv() + + httpServer = newHttpServer("127.0.0.1:8888", cb) + httpServer.start() + + let ws = await connect( + "127.0.0.1", + Port(8888), + path = "/ws", + protocols = @["proto"]) + + await ws.close() + + test "Client closing with status": + proc cb(transp: StreamTransport, header: HttpRequestHeader) {.async.} = + check header.uri() == "/ws" + + proc closeServer(status: Status, reason: string): CloseResult {.gcsafe.} = + check status == Status.Fulfilled + + return (Status.TooLarge, "Message too big!") + + let ws = await createServer( + header, + transp, + "proto", + onClose = closeServer) + + discard await ws.recv() + + httpServer = newHttpServer("127.0.0.1:8888", cb) + httpServer.start() + + proc clientClose(status: Status, reason: string): CloseResult {.gcsafe.} = + check status == Status.TooLarge + check reason == "Message too big!" + + return (Status.Fulfilled, "") + + let ws = await connect( + "127.0.0.1", + Port(8888), + path = "/ws", + protocols = @["proto"], + onClose = clientClose) + + await ws.close() + check ws.readyState == ReadyState.Closed diff --git a/tests/websocket.nim b/tests/websocket.nim deleted file mode 100644 index 6abdbfc..0000000 --- a/tests/websocket.nim +++ /dev/null @@ -1,87 +0,0 @@ -import helpers, unittest, ../src/http, chronos, ../src/ws,../src/random, - stew/byteutils, os, strutils - -var httpServer: HttpServer - -proc startServer() {.async, gcsafe.} = - httpServer = newHttpServer("127.0.0.1:8888", cb) - httpServer.start() - -proc closeServer() {.async, gcsafe.} = - httpServer.stop() - waitFor httpServer.closeWait() - -suite "Test websocket error cases": - teardown: - httpServer.stop() - waitFor httpServer.closeWait() - - test "Test for incorrect protocol": - httpServer = newHttpServer("127.0.0.1:8888", incorrectProtocolCB) - httpServer.start() - try: - let wsClient = waitFor newWebsocketClient("127.0.0.1", Port(8888), - path = "/ws", protocols = @["mywrongprotocol"]) - except WebSocketError: - require contains(getCurrentExceptionMsg(), "Server did not reply with a websocket upgrade") - - test "Test for incorrect port": - httpServer = newHttpServer("127.0.0.1:8888", cb) - httpServer.start() - try: - let wsClient = waitFor newWebsocketClient("127.0.0.1", Port(8889), - path = "/ws", protocols = @["myfancyprotocol"]) - except: - require contains(getCurrentExceptionMsg(), "Connection refused") - - test "Test for incorrect path": - httpServer = newHttpServer("127.0.0.1:8888", cb) - httpServer.start() - try: - let wsClient = waitFor newWebsocketClient("127.0.0.1", Port(8888), - path = "/gg", protocols = @["myfancyprotocol"]) - except: - require contains(getCurrentExceptionMsg(), "Server did not reply with a websocket upgrade") - -suite "Misc Test": - setup: - waitFor startServer() - teardown: - waitFor closeServer() - - test "Test for maskKey": - let wsClient = waitFor newWebsocketClient("127.0.0.1", Port(8888), path = "/ws", - protocols = @["myfancyprotocol"]) - let maskKey = genMaskKey(wsClient.rng) - require maskKey.len == 4 - - test "Test for toCaseInsensitive": - let headers = newHttpHeaders() - require toCaseInsensitive(headers, "webSocket") == "Websocket" - - -suite "Test web socket communication": - - setup: - waitFor startServer() - let wsClient = waitFor newWebsocketClient("127.0.0.1", Port(8888), - path = "/ws", protocols = @["myfancyprotocol"]) - - teardown: - waitFor closeServer() - - test "Websocket conversation between client and server": - waitFor sendRecvClientData(wsClient, "Hello Server") - - test "Test for small message ": - let msg = string.fromBytes(generateData(100)) - waitFor sendRecvClientData(wsClient, msg) - - test "Test for medium message ": - let msg = string.fromBytes(generateData(1000)) - waitFor sendRecvClientData(wsClient, msg) - - test "Test for large message ": - let msg = string.fromBytes(generateData(10000)) - waitFor sendRecvClientData(wsClient, msg) - diff --git a/ws.nimble b/ws.nimble index 86f8ad7..b86b81f 100644 --- a/ws.nimble +++ b/ws.nimble @@ -14,5 +14,5 @@ requires "eth" requires "asynctest >= 0.2.0 & < 0.3.0" requires "nimcrypto" -task lint, "format source files according to the official style guide": - exec "./lint.nims" +task test, "run tests": + exec "nim c -r --opt:speed -d:debug --verbosity:0 --hints:off ./tests/testall.nim"