From 8fb4e78353f5795f28f5b5dd697220efa14d3537 Mon Sep 17 00:00:00 2001 From: Arijit Das Date: Tue, 8 Dec 2020 18:19:22 +0530 Subject: [PATCH] Support Websocket handshake and update the readme. --- Readme.md | 13 +++++++ src/ws.nim | 96 +++++++++++++++++++++++++++++++++++++++++++++---- test/server.nim | 12 +++++-- 3 files changed, 112 insertions(+), 9 deletions(-) diff --git a/Readme.md b/Readme.md index 91429af..b835cd9 100644 --- a/Readme.md +++ b/Readme.md @@ -25,3 +25,16 @@ Testing Server Response: ```bash curl --location --request GET 'http://localhost:8888' ``` + +Testing Websocket Handshake: +```bash +curl --include \ + --no-buffer \ + --header "Connection: Upgrade" \ + --header "Upgrade: websocket" \ + --header "Host: example.com:80" \ + --header "Origin: http://example.com:80" \ + --header "Sec-WebSocket-Key: SGVsbG8sIHdvcmxkIQ==" \ + --header "Sec-WebSocket-Version: 13" \ + http://localhost:8888/ws +``` diff --git a/src/ws.nim b/src/ws.nim index 24e7da1..8baa4d8 100644 --- a/src/ws.nim +++ b/src/ws.nim @@ -1,4 +1,4 @@ -import chronos, chronicles, httputils +import chronos, chronicles, httputils, strutils, base64, std/sha1, random const MaxHttpHeadersSize = 8192 # maximum size of HTTP headers in octets @@ -22,7 +22,8 @@ type readyState*: ReadyState masked*: bool # send masked packets - AsyncCallback = proc (transp: StreamTransport, header: HttpRequestHeader): Future[void] {.closure, gcsafe.} + AsyncCallback = proc (transp: StreamTransport, + header: HttpRequestHeader): Future[void] {.closure, gcsafe.} HttpServer* = ref object of StreamServer callback: AsyncCallback @@ -31,6 +32,89 @@ type WebSocketError* = object of IOError +template `[]`(value: uint8, index: int): bool = + ## Get bits from uint8, uint8[2] gets 2nd bit. + (value and (1 shl (7 - index))) != 0 + +proc nibbleFromChar(c: char): int = + ## Converts hex chars like `0` to 0 and `F` to 15. + case c: + of '0'..'9': (ord(c) - ord('0')) + of 'a'..'f': (ord(c) - ord('a') + 10) + of 'A'..'F': (ord(c) - ord('A') + 10) + else: 255 + +proc nibbleToChar(value: int): char = + ## Converts number like 0 to `0` and 15 to `fg`. + case value: + of 0..9: char(value + ord('0')) + else: char(value + ord('a') - 10) + +proc decodeBase16*(str: string): string = + ## Base16 decode a string. + result = newString(str.len div 2) + for i in 0 ..< result.len: + result[i] = chr( + (nibbleFromChar(str[2 * i]) shl 4) or + nibbleFromChar(str[2 * i + 1])) + +proc encodeBase16*(str: string): string = + ## Base61 encode a string. + result = newString(str.len * 2) + for i, c in str: + result[i * 2] = nibbleToChar(ord(c) shr 4) + result[i * 2 + 1] = nibbleToChar(ord(c) and 0x0f) + +proc genMaskKey(): array[4, char] = + ## Generates a random key of 4 random chars. + proc r(): char = char(rand(255)) + [r(), r(), r(), r()] + +proc handshake*(ws: WebSocket, header: HttpRequestHeader) {.async.} = + ## Handles the websocket handshake. + ws.version = parseInt(header["Sec-WebSocket-Version"]) + ws.key = header["Sec-WebSocket-Key"].strip() + if header.contains("Sec-WebSocket-Protocol"): + let wantProtocol = header["Sec-WebSocket-Protocol"].strip() + if ws.protocol != wantProtocol: + raise newException(WebSocketError, + "Protocol mismatch (expected: " & ws.protocol & ", got: " & wantProtocol & ")") + + let + sh = secureHash(ws.key & "258EAFA5-E914-47DA-95CA-C5AB0DC85B11") + acceptKey = base64.encode(decodeBase16($sh)) + + var response = "HTTP/1.1 101 Web Socket Protocol Handshake\c\L" + response.add("Sec-WebSocket-Accept: " & acceptKey & "\c\L") + response.add("Connection: Upgrade\c\L") + response.add("Upgrade: webSocket\c\L") + + if ws.protocol != "": + response.add("Sec-WebSocket-Protocol: " & ws.protocol & "\c\L") + response.add "\c\L" + + discard await ws.tcpSocket.write(response) + ws.readyState = Open + +proc newWebSocket*(header: HttpRequestHeader, transp: StreamTransport, protocol: string = ""): Future[ + WebSocket] {.async.} = + ## Creates a new socket from a request. + try: + if not header.contains("Sec-WebSocket-Version"): + raise newException(WebSocketError, "Invalid WebSocket handshake") + var ws = WebSocket() + ws.masked = false + ws.protocol = protocol + ws.tcpSocket = transp + await ws.handshake(header) + return ws + except ValueError, KeyError: + # Wrap all exceptions in a WebSocketError so its easy to catch. + raise newException( + WebSocketError, + "Failed to create WebSocket from request: " & getCurrentExceptionMsg() + ) + proc sendHTTPResponse*(transp: StreamTransport, version: HttpVersion, code: HttpCode, data: string = ""): Future[bool] {.async.} = var answer = $version @@ -128,21 +212,21 @@ proc serveClient(server: StreamServer, transp: StreamTransport) {.async.} = return let vres = await validateRequest(transp, header) - if vres == Success: trace "Received valid RPC request", address = $transp.remoteAddress() # Call the user's callback. if httpServer.callback != nil: await httpServer.callback(transp, header) - + await transp.closeWait() elif vres == ErrorFailure: debug "Remote peer disconnected", address = transp.remoteAddress() await transp.closeWait() -proc newHttpServer*(address: string, handler:AsyncCallback, +proc newHttpServer*(address: string, handler: AsyncCallback, flags: set[ServerFlags] = {ReuseAddr}): HttpServer = new result let address = initTAddress(address) result.callback = handler - result = cast[HttpServer](createStreamServer(address, serveClient, flags, child = cast[StreamServer](result))) + result = cast[HttpServer](createStreamServer(address, serveClient, flags, + child = cast[StreamServer](result))) diff --git a/test/server.nim b/test/server.nim index 17716f8..5f85881 100644 --- a/test/server.nim +++ b/test/server.nim @@ -1,10 +1,16 @@ import ws, chronos, chronicles, httputils proc cb(transp: StreamTransport, header: HttpRequestHeader) {.async.} = - info "Header: ", header + info "Header: ", uri = header.uri() + if header.uri() == "/ws": + info "Initiating web socket connection." + try: + var ws = await newWebSocket(header, transp) + info "Websocket handshake completed." + except WebSocketError: + echo "socket closed:", getCurrentExceptionMsg() + let res = await transp.sendHTTPResponse(HttpVersion11, Http200, "Hello World") - debug "Disconnecting client", address = transp.remoteAddress() - await transp.closeWait() when isMainModule: let address = "127.0.0.1:8888"