mirror of
https://github.com/codex-storage/nim-websock.git
synced 2025-03-01 02:00:33 +00:00
[WIP] Web socket client implementation. (#2)
* Implement websocket server. * Implement websocket client. * Run nimpretty. * Remove commented code. * Address comments. * Address comments on websocket server. * Use seq[byte] to store data. * Working bytes conversion. * Remove result from return * Refactor the code. * Minor change. * Add test. * Add websocket test and fix closing handshake. * Add MsgReader to read data in external buffer.
This commit is contained in:
parent
8e34e0f138
commit
a1ae7d2c70
21
examples/client.nim
Normal file
21
examples/client.nim
Normal file
@ -0,0 +1,21 @@
|
|||||||
|
import ../src/ws, nativesockets, chronos, os, chronicles, stew/byteutils
|
||||||
|
|
||||||
|
let wsClient = waitFor newWebsocketClient("127.0.0.1", Port(8888), path = "/ws",
|
||||||
|
protocols = @["myfancyprotocol"])
|
||||||
|
info "Websocket client: ", State = wsClient.readyState
|
||||||
|
|
||||||
|
let reqData = "Hello Server"
|
||||||
|
for idx in 1 .. 5:
|
||||||
|
try:
|
||||||
|
waitFor wsClient.sendStr(reqData)
|
||||||
|
let recvData = waitFor wsClient.receiveStrPacket()
|
||||||
|
let dataStr = string.fromBytes(recvData)
|
||||||
|
info "Server:", data = dataStr
|
||||||
|
assert dataStr == reqData
|
||||||
|
except WebSocketError:
|
||||||
|
error "WebSocket error:", exception = getCurrentExceptionMsg()
|
||||||
|
os.sleep(1000)
|
||||||
|
|
||||||
|
# close the websocket
|
||||||
|
waitFor wsClient.close()
|
||||||
|
|
41
examples/server.nim
Normal file
41
examples/server.nim
Normal file
@ -0,0 +1,41 @@
|
|||||||
|
import ../src/ws, ../src/http, chronos, chronicles, httputils, stew/byteutils
|
||||||
|
|
||||||
|
proc cb(transp: StreamTransport, header: HttpRequestHeader) {.async.} =
|
||||||
|
info "Handling request:", uri = header.uri()
|
||||||
|
if header.uri() == "/ws":
|
||||||
|
info "Initiating web socket connection."
|
||||||
|
try:
|
||||||
|
var ws = await newWebSocket(header, transp, "myfancyprotocol")
|
||||||
|
if ws.readyState == Open:
|
||||||
|
info "Websocket handshake completed."
|
||||||
|
else:
|
||||||
|
error "Failed to open websocket connection."
|
||||||
|
return
|
||||||
|
|
||||||
|
while true:
|
||||||
|
# Only reads header for data frame.
|
||||||
|
let msgReader = await ws.nextMessageReader()
|
||||||
|
|
||||||
|
# Read the frame payload in buffer.
|
||||||
|
let buffer = newSeq[byte](100)
|
||||||
|
var recvData :seq[byte]
|
||||||
|
while msgReader.error != EOFError:
|
||||||
|
msgReader.readMessage(buffer)
|
||||||
|
recvData.add buffer
|
||||||
|
if ws.readyState == ReadyState.Closed:
|
||||||
|
return
|
||||||
|
info "Response: ", data = recvData
|
||||||
|
await ws.send(recvData)
|
||||||
|
|
||||||
|
except WebSocketError:
|
||||||
|
error "WebSocket error:", exception = getCurrentExceptionMsg()
|
||||||
|
|
||||||
|
discard await transp.sendHTTPResponse(HttpVersion11, Http200, "Hello World")
|
||||||
|
await transp.closeWait()
|
||||||
|
|
||||||
|
when isMainModule:
|
||||||
|
let address = "127.0.0.1:8888"
|
||||||
|
var httpServer = newHttpServer(address, cb)
|
||||||
|
httpServer.start()
|
||||||
|
echo "Server started..."
|
||||||
|
waitFor httpServer.join()
|
250
src/http.nim
Normal file
250
src/http.nim
Normal file
@ -0,0 +1,250 @@
|
|||||||
|
import chronos, chronos/timer, httputils, chronicles, uri, tables, strutils
|
||||||
|
|
||||||
|
const
|
||||||
|
MaxHttpHeadersSize = 8192 # maximum size of HTTP headers in octets
|
||||||
|
MaxHttpRequestSize = 128 * 1024 # maximum size of HTTP body in octets
|
||||||
|
HttpHeadersTimeout = timer.seconds(120) # timeout for receiving headers (120 sec)
|
||||||
|
CRLF* = "\r\n"
|
||||||
|
HeaderSep = @[byte('\c'), byte('\L'), byte('\c'), byte('\L')]
|
||||||
|
|
||||||
|
type
|
||||||
|
HttpClient* = ref object
|
||||||
|
connected: bool
|
||||||
|
currentURL: Uri ## Where we are currently connected.
|
||||||
|
headers: HttpHeaders ## Headers to send in requests.
|
||||||
|
transp*: StreamTransport
|
||||||
|
buf: seq[byte]
|
||||||
|
|
||||||
|
HttpHeaders* = object
|
||||||
|
table*: TableRef[string, seq[string]]
|
||||||
|
|
||||||
|
ReqStatus = enum
|
||||||
|
Success, Error, ErrorFailure
|
||||||
|
|
||||||
|
AsyncCallback = proc (transp: StreamTransport,
|
||||||
|
header: HttpRequestHeader): Future[void] {.closure, gcsafe.}
|
||||||
|
HttpServer* = ref object of StreamServer
|
||||||
|
callback: AsyncCallback
|
||||||
|
|
||||||
|
proc recvData(transp: StreamTransport): Future[seq[byte]] {.async.} =
|
||||||
|
var buffer = newSeq[byte](MaxHttpHeadersSize)
|
||||||
|
var error = false
|
||||||
|
try:
|
||||||
|
let hlenfut = transp.readUntil(addr buffer[0], MaxHttpHeadersSize,
|
||||||
|
sep = HeaderSep)
|
||||||
|
let ores = await withTimeout(hlenfut, HttpHeadersTimeout)
|
||||||
|
if not ores:
|
||||||
|
# Timeout
|
||||||
|
debug "Timeout expired while receiving headers",
|
||||||
|
address = transp.remoteAddress()
|
||||||
|
error = true
|
||||||
|
else:
|
||||||
|
let hlen = hlenfut.read()
|
||||||
|
buffer.setLen(hlen)
|
||||||
|
except TransportLimitError:
|
||||||
|
# size of headers exceeds `MaxHttpHeadersSize`
|
||||||
|
debug "Maximum size of headers limit reached",
|
||||||
|
address = transp.remoteAddress()
|
||||||
|
error = true
|
||||||
|
except TransportIncompleteError:
|
||||||
|
# remote peer disconnected
|
||||||
|
debug "Remote peer disconnected", address = transp.remoteAddress()
|
||||||
|
error = true
|
||||||
|
except TransportOsError as exc:
|
||||||
|
debug "Problems with networking", address = transp.remoteAddress(),
|
||||||
|
error = exc.msg
|
||||||
|
error = true
|
||||||
|
|
||||||
|
if error:
|
||||||
|
buffer.setLen(0)
|
||||||
|
return buffer
|
||||||
|
|
||||||
|
proc newConnection(client: HttpClient, url: Uri) {.async.} =
|
||||||
|
if client.connected:
|
||||||
|
return
|
||||||
|
|
||||||
|
let port =
|
||||||
|
if url.port == "": 80
|
||||||
|
else: url.port.parseInt
|
||||||
|
|
||||||
|
client.transp = await connect(initTAddress(url.hostname, port))
|
||||||
|
|
||||||
|
# May be connected through proxy but remember actual URL being accessed
|
||||||
|
client.currentURL = url
|
||||||
|
client.connected = true
|
||||||
|
|
||||||
|
proc generateHeaders(requestUrl: Uri, httpMethod: string,
|
||||||
|
additionalHeaders: HttpHeaders): string =
|
||||||
|
# GET
|
||||||
|
var headers = httpMethod.toUpperAscii()
|
||||||
|
headers.add ' '
|
||||||
|
|
||||||
|
if not requestUrl.path.startsWith("/"): headers.add '/'
|
||||||
|
headers.add(requestUrl.path)
|
||||||
|
|
||||||
|
# HTTP/1.1\c\l
|
||||||
|
headers.add(" HTTP/1.1" & CRLF)
|
||||||
|
|
||||||
|
for key, val in additionalHeaders.table:
|
||||||
|
headers.add(key & ": " & val.join(", ") & CRLF)
|
||||||
|
headers.add(CRLF)
|
||||||
|
return headers
|
||||||
|
|
||||||
|
# Send request to the client. Currently only supports HTTP get method.
|
||||||
|
proc request*(client: HttpClient, url, httpMethod: string,
|
||||||
|
body = "", headers: HttpHeaders): Future[seq[byte]]
|
||||||
|
{.async.} =
|
||||||
|
# Helper that actually makes the request. Does not handle redirects.
|
||||||
|
let requestUrl = parseUri(url)
|
||||||
|
if requestUrl.scheme == "":
|
||||||
|
raise newException(ValueError, "No uri scheme supplied.")
|
||||||
|
|
||||||
|
await newConnection(client, requestUrl)
|
||||||
|
|
||||||
|
let headerString = generateHeaders(requestUrl, httpMethod, headers)
|
||||||
|
let res = await client.transp.write(headerString)
|
||||||
|
if res != len(headerString):
|
||||||
|
raise newException(ValueError, "Error while send request to client")
|
||||||
|
|
||||||
|
var value = await client.transp.recvData()
|
||||||
|
if value.len == 0:
|
||||||
|
raise newException(ValueError, "Empty response from server")
|
||||||
|
return value
|
||||||
|
|
||||||
|
proc sendHTTPResponse*(transp: StreamTransport, version: HttpVersion, code: HttpCode,
|
||||||
|
data: string = ""): Future[bool] {.async.} =
|
||||||
|
var answer = $version
|
||||||
|
answer.add(" ")
|
||||||
|
answer.add($code)
|
||||||
|
answer.add(CRLF)
|
||||||
|
answer.add("Date: " & httpDate() & CRLF)
|
||||||
|
if len(data) > 0:
|
||||||
|
answer.add("Content-Type: application/json" & CRLF)
|
||||||
|
answer.add("Content-Length: " & $len(data) & CRLF)
|
||||||
|
answer.add(CRLF)
|
||||||
|
if len(data) > 0:
|
||||||
|
answer.add(data)
|
||||||
|
|
||||||
|
let res = await transp.write(answer)
|
||||||
|
if res == len(answer):
|
||||||
|
return true
|
||||||
|
raise newException(IOError, "Failed to send http request.")
|
||||||
|
|
||||||
|
proc validateRequest(transp: StreamTransport,
|
||||||
|
header: HttpRequestHeader): Future[ReqStatus] {.async.} =
|
||||||
|
if header.meth notin {MethodGet}:
|
||||||
|
debug "GET method is only allowed", address = transp.remoteAddress()
|
||||||
|
if await transp.sendHTTPResponse(header.version, Http405):
|
||||||
|
return Error
|
||||||
|
else:
|
||||||
|
return ErrorFailure
|
||||||
|
|
||||||
|
var hlen = header.contentLength()
|
||||||
|
if hlen < 0 or hlen > MaxHttpRequestSize:
|
||||||
|
debug "Invalid header length", address = transp.remoteAddress()
|
||||||
|
if await transp.sendHTTPResponse(header.version, Http413):
|
||||||
|
return Error
|
||||||
|
else:
|
||||||
|
return ErrorFailure
|
||||||
|
|
||||||
|
return Success
|
||||||
|
|
||||||
|
proc serveClient(server: StreamServer, transp: StreamTransport) {.async.} =
|
||||||
|
## Process transport data to the RPC server
|
||||||
|
var httpServer = cast[HttpServer](server)
|
||||||
|
var buffer = newSeq[byte](MaxHttpHeadersSize)
|
||||||
|
var header: HttpRequestHeader
|
||||||
|
|
||||||
|
info "Received connection", address = $transp.remoteAddress()
|
||||||
|
try:
|
||||||
|
let hlenfut = transp.readUntil(addr buffer[0], MaxHttpHeadersSize,
|
||||||
|
sep = HeaderSep)
|
||||||
|
let ores = await withTimeout(hlenfut, HttpHeadersTimeout)
|
||||||
|
if not ores:
|
||||||
|
# Timeout
|
||||||
|
debug "Timeout expired while receiving headers",
|
||||||
|
address = transp.remoteAddress()
|
||||||
|
discard await transp.sendHTTPResponse(HttpVersion11, Http408)
|
||||||
|
await transp.closeWait()
|
||||||
|
return
|
||||||
|
else:
|
||||||
|
let hlen = hlenfut.read()
|
||||||
|
buffer.setLen(hlen)
|
||||||
|
header = buffer.parseRequest()
|
||||||
|
if header.failed():
|
||||||
|
# Header could not be parsed
|
||||||
|
debug "Malformed header received",
|
||||||
|
address = transp.remoteAddress()
|
||||||
|
discard await transp.sendHTTPResponse(HttpVersion11, Http400)
|
||||||
|
await transp.closeWait()
|
||||||
|
return
|
||||||
|
var vres = await validateRequest(transp, header)
|
||||||
|
if vres == Success:
|
||||||
|
info "Received valid RPC request", address = $transp.remoteAddress()
|
||||||
|
# Call the user's callback.
|
||||||
|
if httpServer.callback != nil:
|
||||||
|
await httpServer.callback(transp, header)
|
||||||
|
elif vres == ErrorFailure:
|
||||||
|
debug "Remote peer disconnected", address = transp.remoteAddress()
|
||||||
|
except TransportLimitError:
|
||||||
|
# size of headers exceeds `MaxHttpHeadersSize`
|
||||||
|
debug "Maximum size of headers limit reached",
|
||||||
|
address = transp.remoteAddress()
|
||||||
|
discard await transp.sendHTTPResponse(HttpVersion11, Http413)
|
||||||
|
except TransportIncompleteError:
|
||||||
|
# remote peer disconnected
|
||||||
|
debug "Remote peer disconnected", address = transp.remoteAddress()
|
||||||
|
except TransportOsError as exc:
|
||||||
|
debug "Problems with networking", address = transp.remoteAddress(),
|
||||||
|
error = exc.msg
|
||||||
|
except CatchableError as exc:
|
||||||
|
debug "Unknown exception", address = transp.remoteAddress(),
|
||||||
|
error = exc.msg
|
||||||
|
await transp.closeWait()
|
||||||
|
|
||||||
|
proc newHttpServer*(address: string, handler: AsyncCallback,
|
||||||
|
flags: set[ServerFlags] = {ReuseAddr}): HttpServer =
|
||||||
|
let address = initTAddress(address)
|
||||||
|
var server = HttpServer(callback: handler)
|
||||||
|
server = cast[HttpServer](createStreamServer(address, serveClient, flags,
|
||||||
|
child = cast[StreamServer](server)))
|
||||||
|
return server
|
||||||
|
|
||||||
|
func toTitleCase(s: string): string =
|
||||||
|
var tcstr = newString(len(s))
|
||||||
|
var upper = true
|
||||||
|
for i in 0..len(s) - 1:
|
||||||
|
tcstr[i] = if upper: toUpperAscii(s[i]) else: toLowerAscii(s[i])
|
||||||
|
upper = s[i] == '-'
|
||||||
|
return tcstr
|
||||||
|
|
||||||
|
func toCaseInsensitive*(headers: HttpHeaders, s: string): string {.inline.} =
|
||||||
|
return toTitleCase(s)
|
||||||
|
|
||||||
|
func newHttpHeaders*(): HttpHeaders =
|
||||||
|
## Returns a new ``HttpHeaders`` object. if ``titleCase`` is set to true,
|
||||||
|
## headers are passed to the server in title case (e.g. "Content-Length")
|
||||||
|
return HttpHeaders(table: newTable[string, seq[string]]())
|
||||||
|
|
||||||
|
func newHttpHeaders*(keyValuePairs:
|
||||||
|
openArray[tuple[key: string, val: string]]): HttpHeaders =
|
||||||
|
## Returns a new ``HttpHeaders`` object from an array. if ``titleCase`` is set to true,
|
||||||
|
## headers are passed to the server in title case (e.g. "Content-Length")
|
||||||
|
var headers = newHttpHeaders()
|
||||||
|
|
||||||
|
for pair in keyValuePairs:
|
||||||
|
let key = headers.toCaseInsensitive(pair.key)
|
||||||
|
if key in headers.table:
|
||||||
|
headers.table[key].add(pair.val)
|
||||||
|
else:
|
||||||
|
headers.table[key] = @[pair.val]
|
||||||
|
return headers
|
||||||
|
|
||||||
|
proc newHttpClient*(headers = newHttpHeaders()): HttpClient =
|
||||||
|
return HttpClient(headers: headers)
|
||||||
|
|
||||||
|
proc close*(client: HttpClient) =
|
||||||
|
## Closes any connections held by the HTTP client.
|
||||||
|
if client.connected:
|
||||||
|
client.transp.close()
|
||||||
|
client.connected = false
|
25
src/random.nim
Normal file
25
src/random.nim
Normal file
@ -0,0 +1,25 @@
|
|||||||
|
import bearssl
|
||||||
|
|
||||||
|
## Random helpers: similar as in stdlib, but with BrHmacDrbgContext rng
|
||||||
|
const randMax = 18_446_744_073_709_551_615'u64
|
||||||
|
|
||||||
|
proc rand*(rng: var BrHmacDrbgContext, max: Natural): int =
|
||||||
|
if max == 0: return 0
|
||||||
|
var x: uint64
|
||||||
|
while true:
|
||||||
|
brHmacDrbgGenerate(addr rng, addr x, csize_t(sizeof(x)))
|
||||||
|
if x < randMax - (randMax mod (uint64(max) + 1'u64)): # against modulo bias
|
||||||
|
return int(x mod (uint64(max) + 1'u64))
|
||||||
|
|
||||||
|
proc genMaskKey*(rng: ref BrHmacDrbgContext): array[4, char] =
|
||||||
|
## Generates a random key of 4 random chars.
|
||||||
|
proc r(): char = char(rand(rng[], 255))
|
||||||
|
return [r(), r(), r(), r()]
|
||||||
|
|
||||||
|
proc genWebSecKey*(rng: ref BrHmacDrbgContext): seq[char] =
|
||||||
|
var key = newSeq[char](16)
|
||||||
|
proc r(): char = char(rand(rng[], 255))
|
||||||
|
## Generates a random key of 16 random chars.
|
||||||
|
for i in 0..15:
|
||||||
|
key.add(r())
|
||||||
|
return key
|
451
src/ws.nim
Normal file
451
src/ws.nim
Normal file
@ -0,0 +1,451 @@
|
|||||||
|
import httputils, strutils, base64, std/sha1, ./random, http, uri,
|
||||||
|
chronos/timer, tables, stew/byteutils, eth/[keys], stew/endians2,
|
||||||
|
parseutils, stew/base64 as stewBase,chronos
|
||||||
|
|
||||||
|
const
|
||||||
|
SHA1DigestSize = 20
|
||||||
|
WSHeaderSize = 12
|
||||||
|
WSOpCode = {0x00, 0x01, 0x02, 0x08, 0x09, 0x0a}
|
||||||
|
|
||||||
|
type
|
||||||
|
ReadyState* = enum
|
||||||
|
Connecting = 0 # The connection is not yet open.
|
||||||
|
Open = 1 # The connection is open and ready to communicate.
|
||||||
|
Closing = 2 # The connection is in the process of closing.
|
||||||
|
Closed = 3 # The connection is closed or couldn't be opened.
|
||||||
|
|
||||||
|
WebSocket* = ref object
|
||||||
|
tcpSocket*: StreamTransport
|
||||||
|
version*: int
|
||||||
|
key*: string
|
||||||
|
protocol*: string
|
||||||
|
readyState*: ReadyState
|
||||||
|
masked*: bool # send masked packets
|
||||||
|
rng*: ref BrHmacDrbgContext
|
||||||
|
|
||||||
|
WebSocketError* = object of IOError
|
||||||
|
|
||||||
|
Base16Error* = object of CatchableError
|
||||||
|
## Base16 specific exception type
|
||||||
|
|
||||||
|
HeaderFlag* {.size: sizeof(uint8).} = enum
|
||||||
|
rsv3
|
||||||
|
rsv2
|
||||||
|
rsv1
|
||||||
|
fin
|
||||||
|
HeaderFlags = set[HeaderFlag]
|
||||||
|
|
||||||
|
HttpCode* = enum
|
||||||
|
Http101 = 101 # Switching Protocols
|
||||||
|
|
||||||
|
# Forward declare
|
||||||
|
proc close*(ws: WebSocket, initiator: bool = true) {.async.}
|
||||||
|
|
||||||
|
proc handshake*(ws: WebSocket, header: HttpRequestHeader) {.async.} =
|
||||||
|
## Handles the websocket handshake.
|
||||||
|
discard parseSaturatedNatural(header["Sec-WebSocket-Version"], ws.version)
|
||||||
|
if ws.version != 13:
|
||||||
|
raise newException(WebSocketError, "Websocket version not supported, Version: " &
|
||||||
|
header["Sec-WebSocket-Version"])
|
||||||
|
|
||||||
|
ws.key = header["Sec-WebSocket-Key"].strip()
|
||||||
|
if header.contains("Sec-WebSocket-Protocol"):
|
||||||
|
let wantProtocol = header["Sec-WebSocket-Protocol"].strip()
|
||||||
|
if ws.protocol != wantProtocol:
|
||||||
|
raise newException(WebSocketError,
|
||||||
|
"Protocol mismatch (expected: " & ws.protocol & ", got: " &
|
||||||
|
wantProtocol & ")")
|
||||||
|
|
||||||
|
var acceptKey: string
|
||||||
|
try:
|
||||||
|
let sh = secureHash(ws.key & "258EAFA5-E914-47DA-95CA-C5AB0DC85B11")
|
||||||
|
acceptKey = stewBase.Base64.encode(hexToByteArray[SHA1DigestSize]($sh))
|
||||||
|
except ValueError:
|
||||||
|
raise newException(
|
||||||
|
WebSocketError, "Failed to generate accept key: " & getCurrentExceptionMsg())
|
||||||
|
|
||||||
|
var response = "HTTP/1.1 101 Web Socket Protocol Handshake" & CRLF
|
||||||
|
response.add("Sec-WebSocket-Accept: " & acceptKey & CRLF)
|
||||||
|
response.add("Connection: Upgrade" & CRLF)
|
||||||
|
response.add("Upgrade: webSocket" & CRLF)
|
||||||
|
|
||||||
|
if ws.protocol != "":
|
||||||
|
response.add("Sec-WebSocket-Protocol: " & ws.protocol & CRLF)
|
||||||
|
response.add CRLF
|
||||||
|
|
||||||
|
let res = await ws.tcpSocket.write(response)
|
||||||
|
if res != len(response):
|
||||||
|
raise newException(WebSocketError, "Failed to send handshake response to client")
|
||||||
|
ws.readyState = Open
|
||||||
|
|
||||||
|
proc newWebSocket*(header: HttpRequestHeader, transp: StreamTransport,
|
||||||
|
protocol: string = ""): Future[WebSocket] {.async.} =
|
||||||
|
## Creates a new socket from a request.
|
||||||
|
try:
|
||||||
|
if not header.contains("Sec-WebSocket-Version"):
|
||||||
|
raise newException(WebSocketError, "Invalid WebSocket handshake")
|
||||||
|
var ws = WebSocket(tcpSocket: transp, protocol: protocol, masked: false,
|
||||||
|
rng: newRng())
|
||||||
|
await ws.handshake(header)
|
||||||
|
return ws
|
||||||
|
except ValueError, KeyError:
|
||||||
|
# Wrap all exceptions in a WebSocketError so its easy to catch.
|
||||||
|
raise newException(
|
||||||
|
WebSocketError,
|
||||||
|
"Failed to create WebSocket from request: " & getCurrentExceptionMsg()
|
||||||
|
)
|
||||||
|
|
||||||
|
type
|
||||||
|
Opcode* = enum
|
||||||
|
## 4 bits. Defines the interpretation of the "Payload data".
|
||||||
|
Cont = 0x0 ## Denotes a continuation frame.
|
||||||
|
Text = 0x1 ## Denotes a text frame.
|
||||||
|
Binary = 0x2 ## Denotes a binary frame.
|
||||||
|
# 3-7 are reserved for further non-control frames.
|
||||||
|
Close = 0x8 ## Denotes a connection close.
|
||||||
|
Ping = 0x9 ## Denotes a ping.
|
||||||
|
Pong = 0xa ## Denotes a pong.
|
||||||
|
# B-F are reserved for further control frames.
|
||||||
|
|
||||||
|
#[
|
||||||
|
+---------------------------------------------------------------+
|
||||||
|
|0 1 2 3 |
|
||||||
|
|0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1|
|
||||||
|
+-+-+-+-+-------+-+-------------+-------------------------------+
|
||||||
|
|F|R|R|R| opcode|M| Payload len | Extended payload length |
|
||||||
|
|I|S|S|S| (4) |A| (7) | (16/64) |
|
||||||
|
|N|V|V|V| |S| | (if payload len==126/127) |
|
||||||
|
| |1|2|3| |K| | |
|
||||||
|
+-+-+-+-+-------+-+-------------+ - - - - - - - - - - - - - - - +
|
||||||
|
| Extended payload length continued, if payload len == 127 |
|
||||||
|
+ - - - - - - - - - - - - - - - +-------------------------------+
|
||||||
|
| |Masking-key, if MASK set to 1 |
|
||||||
|
+-------------------------------+-------------------------------+
|
||||||
|
| Masking-key (continued) | Payload Data |
|
||||||
|
+-------------------------------- - - - - - - - - - - - - - - - +
|
||||||
|
: Payload Data continued ... :
|
||||||
|
+ - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - +
|
||||||
|
| Payload Data continued ... |
|
||||||
|
+---------------------------------------------------------------+
|
||||||
|
]#
|
||||||
|
|
||||||
|
MsgReader = ref object
|
||||||
|
tcpSocket: StreamTransport
|
||||||
|
readErr: IOError
|
||||||
|
readLen: uint64
|
||||||
|
readRemaining: uint64
|
||||||
|
readFinal: bool ## true the current message has more frames.
|
||||||
|
opcode: Opcode ## Defines the interpretation of the "Payload data".
|
||||||
|
maskKey: array[4, char] ## Masking key
|
||||||
|
mask: bool ## Defines whether the "Payload data" is masked.
|
||||||
|
|
||||||
|
Frame = ref object
|
||||||
|
fin: bool ## Indicates that this is the final fragment in a message.
|
||||||
|
rsv1: bool ## MUST be 0 unless negotiated that defines meanings
|
||||||
|
rsv2: bool ## MUST be 0
|
||||||
|
rsv3: bool ## MUST be 0
|
||||||
|
opcode: Opcode ## Defines the interpretation of the "Payload data".
|
||||||
|
mask: bool ## Defines whether the "Payload data" is masked.
|
||||||
|
data: seq[byte] ## Payload data
|
||||||
|
maskKey: array[4, char] ## Masking key
|
||||||
|
length: uint64 ## Message size.
|
||||||
|
|
||||||
|
proc encodeFrame(f: Frame): seq[byte] =
|
||||||
|
## Encodes a frame into a string buffer.
|
||||||
|
## See https://tools.ietf.org/html/rfc6455#section-5.2
|
||||||
|
|
||||||
|
var ret = newSeqOfCap[byte](f.data.len + WSHeaderSize)
|
||||||
|
|
||||||
|
var b0 = (f.opcode.uint8 and 0x0f) # 0th byte: opcodes and flags.
|
||||||
|
if f.fin:
|
||||||
|
b0 = b0 or 128u8
|
||||||
|
|
||||||
|
ret.add(b0)
|
||||||
|
|
||||||
|
# Payload length can be 7 bits, 7+16 bits, or 7+64 bits.
|
||||||
|
# 1st byte: payload len start and mask bit.
|
||||||
|
var b1 = 0u8
|
||||||
|
|
||||||
|
if f.data.len <= 125:
|
||||||
|
b1 = f.data.len.uint8
|
||||||
|
elif f.data.len > 125 and f.data.len <= 0xffff:
|
||||||
|
b1 = 126u8
|
||||||
|
else:
|
||||||
|
b1 = 127u8
|
||||||
|
|
||||||
|
if f.mask:
|
||||||
|
b1 = b1 or (1 shl 7)
|
||||||
|
|
||||||
|
ret.add(uint8 b1)
|
||||||
|
|
||||||
|
# Only need more bytes if data len is 7+16 bits, or 7+64 bits.
|
||||||
|
if f.data.len > 125 and f.data.len <= 0xffff:
|
||||||
|
# Data len is 7+16 bits.
|
||||||
|
var len = f.data.len.uint16
|
||||||
|
ret.add ((len shr 8) and 255).uint8
|
||||||
|
ret.add (len and 255).uint8
|
||||||
|
elif f.data.len > 0xffff:
|
||||||
|
# Data len is 7+64 bits.
|
||||||
|
var len = f.data.len
|
||||||
|
ret.add(f.data.len.uint64.toBE().toBytesBE())
|
||||||
|
|
||||||
|
var data = f.data
|
||||||
|
|
||||||
|
if f.mask:
|
||||||
|
# If we need to mask it generate random mask key and mask the data.
|
||||||
|
for i in 0..<data.len:
|
||||||
|
data[i] = (data[i].uint8 xor f.maskKey[i mod 4].uint8)
|
||||||
|
# Write mask key next.
|
||||||
|
ret.add(f.maskKey[0].uint8)
|
||||||
|
ret.add(f.maskKey[1].uint8)
|
||||||
|
ret.add(f.maskKey[2].uint8)
|
||||||
|
ret.add(f.maskKey[3].uint8)
|
||||||
|
|
||||||
|
# Write the data.
|
||||||
|
ret.add(data)
|
||||||
|
return ret
|
||||||
|
|
||||||
|
proc send*(ws: WebSocket, data: seq[byte], opcode = Opcode.Text): Future[
|
||||||
|
void] {.async.} =
|
||||||
|
try:
|
||||||
|
var maskKey: array[4, char]
|
||||||
|
if ws.masked:
|
||||||
|
maskKey = genMaskKey(ws.rng)
|
||||||
|
|
||||||
|
var inFrame = Frame(
|
||||||
|
fin: true,
|
||||||
|
rsv1: false,
|
||||||
|
rsv2: false,
|
||||||
|
rsv3: false,
|
||||||
|
opcode: opcode,
|
||||||
|
mask: ws.masked,
|
||||||
|
data: data,
|
||||||
|
maskKey: maskKey)
|
||||||
|
var frame = encodeFrame(inFrame)
|
||||||
|
const maxSize = 1024*1024
|
||||||
|
# Send stuff in 1 megabyte chunks to prevent IOErrors.
|
||||||
|
# This really large packets.
|
||||||
|
var i = 0
|
||||||
|
while i < frame.len:
|
||||||
|
let frameSize = min(frame.len, i + maxSize)
|
||||||
|
let res = await ws.tcpSocket.write(frame[i ..< frameSize])
|
||||||
|
if res != frameSize:
|
||||||
|
raise newException(ValueError, "Error while send websocket frame")
|
||||||
|
i += maxSize
|
||||||
|
except OSError, ValueError:
|
||||||
|
# Wrap all exceptions in a WebSocketError so its easy to catch
|
||||||
|
raise newException(WebSocketError, "Failed to send data: " &
|
||||||
|
getCurrentExceptionMsg())
|
||||||
|
|
||||||
|
proc sendStr*(ws: WebSocket, data: string, opcode = Opcode.Text): Future[void] =
|
||||||
|
send(ws, toBytes(data), opcode)
|
||||||
|
|
||||||
|
proc readFrame(ws: WebSocket): Future[Frame] {.async.} =
|
||||||
|
## Gets a frame from the WebSocket.
|
||||||
|
## See https://tools.ietf.org/html/rfc6455#section-5.2
|
||||||
|
|
||||||
|
# Grab the header.
|
||||||
|
var header = newSeq[byte](2)
|
||||||
|
try:
|
||||||
|
await ws.tcpSocket.readExactly(addr header[0], 2)
|
||||||
|
except TransportUseClosedError:
|
||||||
|
ws.readyState = Closed
|
||||||
|
raise newException(WebSocketError, "Socket closed")
|
||||||
|
|
||||||
|
if header.len != 2:
|
||||||
|
ws.readyState = Closed
|
||||||
|
raise newException(WebSocketError, "Invalid websocket header length")
|
||||||
|
|
||||||
|
let b0 = header[0].uint8
|
||||||
|
let b1 = header[1].uint8
|
||||||
|
|
||||||
|
var frame: Frame
|
||||||
|
# Read the flags and fin from the header.
|
||||||
|
|
||||||
|
var hf = cast[HeaderFlags](b0 shr 4)
|
||||||
|
frame.fin = fin in hf
|
||||||
|
frame.rsv1 = rsv1 in hf
|
||||||
|
frame.rsv2 = rsv2 in hf
|
||||||
|
frame.rsv3 = rsv3 in hf
|
||||||
|
|
||||||
|
var opcode = b0 and 0x0f
|
||||||
|
if opcode notin WSOpCode:
|
||||||
|
raise newException(WebSocketError, "Unexpected websocket opcode")
|
||||||
|
frame.opcode = (opcode).Opcode
|
||||||
|
|
||||||
|
# If any of the rsv are set close the socket.
|
||||||
|
if frame.rsv1 or frame.rsv2 or frame.rsv3:
|
||||||
|
ws.readyState = Closed
|
||||||
|
raise newException(WebSocketError, "WebSocket rsv mismatch")
|
||||||
|
|
||||||
|
# Payload length can be 7 bits, 7+16 bits, or 7+64 bits.
|
||||||
|
var finalLen: uint64 = 0
|
||||||
|
|
||||||
|
let headerLen = uint(b1 and 0x7f)
|
||||||
|
if headerLen == 0x7e:
|
||||||
|
# Length must be 7+16 bits.
|
||||||
|
var length = newSeq[byte](2)
|
||||||
|
await ws.tcpSocket.readExactly(addr length[0], 2)
|
||||||
|
finalLen = cast[ptr uint16](length[0].addr)[].toBE
|
||||||
|
elif headerLen == 0x7f:
|
||||||
|
# Length must be 7+64 bits.
|
||||||
|
var length = newSeq[byte](8)
|
||||||
|
await ws.tcpSocket.readExactly(addr length[0], 8)
|
||||||
|
finalLen = cast[ptr uint64](length[0].addr)[].toBE
|
||||||
|
else:
|
||||||
|
# Length must be 7 bits.
|
||||||
|
finalLen = headerLen
|
||||||
|
frame.length = finalLen
|
||||||
|
|
||||||
|
# Do we need to apply mask?
|
||||||
|
frame.mask = (b1 and 0x80) == 0x80
|
||||||
|
|
||||||
|
if ws.masked == frame.mask:
|
||||||
|
# Server sends unmasked but accepts only masked.
|
||||||
|
# Client sends masked but accepts only unmasked.
|
||||||
|
raise newException(WebSocketError, "Socket mask mismatch")
|
||||||
|
|
||||||
|
var maskKey = newSeq[byte](4)
|
||||||
|
if frame.mask:
|
||||||
|
# Read the mask.
|
||||||
|
await ws.tcpSocket.readExactly(addr maskKey[0], 4)
|
||||||
|
for i in 0..<maskKey.len:
|
||||||
|
frame.maskKey[i] = cast[char](maskKey[i])
|
||||||
|
|
||||||
|
if (frame.opcode == Text) or (frame.opcode == Opcode.Cont) or (frame.opcode == Opcode.Binary) :
|
||||||
|
return frame
|
||||||
|
|
||||||
|
# TODO: Add check for max size for control frames.
|
||||||
|
var data = newSeq[byte](finalLen)
|
||||||
|
|
||||||
|
# Read control frame payload.
|
||||||
|
if frame.length > 0 :
|
||||||
|
# Read the data.
|
||||||
|
await ws.tcpSocket.readExactly(addr data[0], int finalLen)
|
||||||
|
frame.data = data
|
||||||
|
|
||||||
|
# Process control frame payload.
|
||||||
|
if frame.opcode == Ping:
|
||||||
|
await ws.send(data, Pong)
|
||||||
|
elif frame.opcode == Pong:
|
||||||
|
discard
|
||||||
|
elif frame.opcode == Close:
|
||||||
|
await ws.close(false)
|
||||||
|
|
||||||
|
return frame
|
||||||
|
|
||||||
|
proc close*(ws: WebSocket, initiator: bool = true) {.async.} =
|
||||||
|
## Close the Socket, sends close packet.
|
||||||
|
if ws.readyState == Closed:
|
||||||
|
discard ws.tcpSocket.closeWait()
|
||||||
|
return
|
||||||
|
ws.readyState = Closed
|
||||||
|
await ws.send(@[], Close)
|
||||||
|
if initiator == true:
|
||||||
|
let frame = await ws.readFrame()
|
||||||
|
if frame.opcode != Close:
|
||||||
|
echo "Different packet type"
|
||||||
|
await ws.close()
|
||||||
|
|
||||||
|
proc readMessage*(msgReader: MsgReader,data: seq[byte]): MsgReader {.async.} =
|
||||||
|
while msgReader.readErr == nil:
|
||||||
|
if msgReader.readRemaining > 0 :
|
||||||
|
len = size(data)
|
||||||
|
if len > msgReader.readRemaining:
|
||||||
|
len = msgReader.readRemaining
|
||||||
|
|
||||||
|
await msgReader.tcpSocket.readExactly(addr data, len)
|
||||||
|
msgReader.readRemaining = msgReader.readRemaining - len
|
||||||
|
msgReader.readLen = len
|
||||||
|
|
||||||
|
if msgReader.mask:
|
||||||
|
# Apply mask, if we need too.
|
||||||
|
for i in 0 ..< len:
|
||||||
|
data[i] = (data[i].uint8 xor msgReader.maskKey[i mod 4].uint8)
|
||||||
|
|
||||||
|
if msgReader.readRemaining == 0:
|
||||||
|
msgReader.readErr = EOFError
|
||||||
|
|
||||||
|
return msgReader
|
||||||
|
|
||||||
|
if msgReader.readFinal:
|
||||||
|
msgReader.readLen = 0
|
||||||
|
msgReader.readErr = EOFError
|
||||||
|
return msgReader
|
||||||
|
|
||||||
|
var frame = await ws.readFrame()
|
||||||
|
if frame.fin:
|
||||||
|
msgReader.readFinal = true
|
||||||
|
msgReader.readRemaining = frame.length
|
||||||
|
|
||||||
|
# Non-control frames cannot occur in the middle of a fragmented non-control frame.
|
||||||
|
if frame.Opcode in Text || Binary:
|
||||||
|
raise newException("websocket: internal error, unexpected text or binary in Reader")
|
||||||
|
return msgReader
|
||||||
|
|
||||||
|
proc nextMessageReader*(ws: WebSocket): MsgReader =
|
||||||
|
while true:
|
||||||
|
# Handle control frames and return only on non control frames.
|
||||||
|
var frame = await ws.readFrame()
|
||||||
|
if frame.Opcode in Text || Binary:
|
||||||
|
var msgReader: MsgReader
|
||||||
|
msgReader.readFinal = frame.fin
|
||||||
|
msgReader.readRemaining = frame.readRemaining
|
||||||
|
msgReader.tcpSocket = ws.tcpSocket
|
||||||
|
msgReader.mask = frame.mask
|
||||||
|
msgReader.maskKey = frame.maskKey
|
||||||
|
return msgReader
|
||||||
|
|
||||||
|
proc receiveStrPacket*(ws: WebSocket): Future[seq[byte]] {.async.} =
|
||||||
|
# TODO: remove this once PR is approved.
|
||||||
|
return nil
|
||||||
|
|
||||||
|
proc validateWSClientHandshake*(transp: StreamTransport,
|
||||||
|
header: HttpResponseHeader): void =
|
||||||
|
if header.code != ord(Http101):
|
||||||
|
raise newException(WebSocketError, "Server did not reply with a websocket upgrade: " &
|
||||||
|
"Header code: " & $header.code &
|
||||||
|
"Header reason: " & header.reason() &
|
||||||
|
"Address: " & $transp.remoteAddress())
|
||||||
|
|
||||||
|
proc newWebsocketClient*(uri: Uri, protocols: seq[string] = @[]): Future[
|
||||||
|
WebSocket] {.async.} =
|
||||||
|
var key = encode(genWebSecKey(newRng()))
|
||||||
|
var uri = uri
|
||||||
|
case uri.scheme
|
||||||
|
of "ws":
|
||||||
|
uri.scheme = "http"
|
||||||
|
else:
|
||||||
|
raise newException(WebSocketError, "uri scheme has to be 'ws'")
|
||||||
|
|
||||||
|
var headers = newHttpHeaders({
|
||||||
|
"Connection": "Upgrade",
|
||||||
|
"Upgrade": "websocket",
|
||||||
|
"Cache-Control": "no-cache",
|
||||||
|
"Sec-WebSocket-Version": "13",
|
||||||
|
"Sec-WebSocket-Key": key
|
||||||
|
})
|
||||||
|
if protocols.len != 0:
|
||||||
|
headers.table["Sec-WebSocket-Protocol"] = @[protocols.join(", ")]
|
||||||
|
|
||||||
|
let client = newHttpClient(headers)
|
||||||
|
var response = await client.request($uri, "GET", headers = headers)
|
||||||
|
var header = response.parseResponse()
|
||||||
|
if header.failed():
|
||||||
|
# Header could not be parsed
|
||||||
|
raise newException(WebSocketError, "Malformed header received: " &
|
||||||
|
$client.transp.remoteAddress())
|
||||||
|
client.transp.validateWSClientHandshake(header)
|
||||||
|
|
||||||
|
# Client data should be masked.
|
||||||
|
return WebSocket(tcpSocket: client.transp, readyState: Open, masked: true,
|
||||||
|
rng: newRng())
|
||||||
|
|
||||||
|
proc newWebsocketClient*(host: string, port: Port, path: string,
|
||||||
|
protocols: seq[string] = @[]): Future[WebSocket] {.async.} =
|
||||||
|
var uri = "ws://" & host & ":" & $port
|
||||||
|
if path.startsWith("/"):
|
||||||
|
uri.add path
|
||||||
|
else:
|
||||||
|
uri.add "/" & path
|
||||||
|
return await newWebsocketClient(parseUri(uri), protocols)
|
@ -1,7 +0,0 @@
|
|||||||
import ws, nativesockets, chronos
|
|
||||||
|
|
||||||
discard waitFor newAsyncWebsocketClient("localhost", Port(8080), path = "/", protocols = @["myfancyprotocol"])
|
|
||||||
echo "connected"
|
|
||||||
|
|
||||||
runForever()
|
|
||||||
|
|
@ -1 +0,0 @@
|
|||||||
switch("path", "$projectDir/../src")
|
|
@ -1,20 +0,0 @@
|
|||||||
import ws, chronos, chronicles, httputils
|
|
||||||
|
|
||||||
proc cb(transp: StreamTransport, header: HttpRequestHeader) {.async.} =
|
|
||||||
info "Header: ", uri = header.uri()
|
|
||||||
if header.uri() == "/ws":
|
|
||||||
info "Initiating web socket connection."
|
|
||||||
try:
|
|
||||||
var ws = await newWebSocket(header, transp)
|
|
||||||
echo await ws.receivePacket()
|
|
||||||
info "Websocket handshake completed."
|
|
||||||
except WebSocketError:
|
|
||||||
echo "socket closed:", getCurrentExceptionMsg()
|
|
||||||
|
|
||||||
let res = await transp.sendHTTPResponse(HttpVersion11, Http200, "Hello World")
|
|
||||||
|
|
||||||
when isMainModule:
|
|
||||||
let address = "127.0.0.1:8888"
|
|
||||||
var httpServer = newHttpServer(address, cb)
|
|
||||||
httpServer.start()
|
|
||||||
waitFor httpServer.join()
|
|
76
tests/frame.nim
Normal file
76
tests/frame.nim
Normal file
@ -0,0 +1,76 @@
|
|||||||
|
include ../src/ws
|
||||||
|
include ../src/http
|
||||||
|
include ../src/random
|
||||||
|
#import chronos, chronicles, httputils, strutils, base64, std/sha1,
|
||||||
|
# streams, nativesockets, uri, times, chronos/timer, tables
|
||||||
|
|
||||||
|
import unittest
|
||||||
|
|
||||||
|
# TODO: Fix Test.
|
||||||
|
|
||||||
|
var maskKey: array[4, char]
|
||||||
|
|
||||||
|
suite "tests for encodeFrame()":
|
||||||
|
test "# 7bit length":
|
||||||
|
block: # 7bit length
|
||||||
|
assert encodeFrame((
|
||||||
|
fin: true,
|
||||||
|
rsv1: false,
|
||||||
|
rsv2: false,
|
||||||
|
rsv3: false,
|
||||||
|
opcode: Opcode.Text,
|
||||||
|
mask: false,
|
||||||
|
data: toBytes("hi there"),
|
||||||
|
maskKey: maskKey
|
||||||
|
)) == toBytes("\129\8hi there")
|
||||||
|
test "# 7bit length":
|
||||||
|
block: # 7+16 bits length
|
||||||
|
var data = ""
|
||||||
|
for i in 0..32:
|
||||||
|
data.add "How are you this is the payload!!!"
|
||||||
|
assert encodeFrame((
|
||||||
|
fin: true,
|
||||||
|
rsv1: false,
|
||||||
|
rsv2: false,
|
||||||
|
rsv3: false,
|
||||||
|
opcode: Opcode.Text,
|
||||||
|
mask: false,
|
||||||
|
data: toBytes(data),
|
||||||
|
maskKey: maskKey
|
||||||
|
))[0..32] == toBytes("\129~\4bHow are you this is the paylo")
|
||||||
|
test "# 7+64 bits length":
|
||||||
|
block: # 7+64 bits length
|
||||||
|
var data = ""
|
||||||
|
for i in 0..3200:
|
||||||
|
data.add "How are you this is the payload!!!"
|
||||||
|
assert encodeFrame((
|
||||||
|
fin: true,
|
||||||
|
rsv1: false,
|
||||||
|
rsv2: false,
|
||||||
|
rsv3: false,
|
||||||
|
opcode: Opcode.Text,
|
||||||
|
mask: false,
|
||||||
|
data: toBytes(data),
|
||||||
|
maskKey: maskKey
|
||||||
|
))[0..32] == toBytes("\129\127\0\0\0\0\0\1\169\"How are you this is the")
|
||||||
|
test "# masking":
|
||||||
|
block: # masking
|
||||||
|
let data = encodeFrame((
|
||||||
|
fin: true,
|
||||||
|
rsv1: false,
|
||||||
|
rsv2: false,
|
||||||
|
rsv3: false,
|
||||||
|
opcode: Opcode.Text,
|
||||||
|
mask: true,
|
||||||
|
data: toBytes("hi there"),
|
||||||
|
maskKey: ['\xCF', '\xD8', '\x05', 'e']
|
||||||
|
))
|
||||||
|
assert data == toBytes("\129\136\207\216\5e\167\177%\17\167\189w\0")
|
||||||
|
|
||||||
|
suite "tests for toTitleCase()":
|
||||||
|
block:
|
||||||
|
let val = toTitleCase("webSocket")
|
||||||
|
assert val == "Websocket"
|
||||||
|
|
||||||
|
|
||||||
|
|
55
tests/helpers.nim
Normal file
55
tests/helpers.nim
Normal file
@ -0,0 +1,55 @@
|
|||||||
|
import ../src/ws, chronos, chronicles, httputils, stew/byteutils,
|
||||||
|
../src/http, unittest, strutils
|
||||||
|
|
||||||
|
proc cb*(transp: StreamTransport, header: HttpRequestHeader) {.async.} =
|
||||||
|
info "Handling request:", uri = header.uri()
|
||||||
|
if header.uri() == "/ws":
|
||||||
|
info "Initiating web socket connection."
|
||||||
|
try:
|
||||||
|
var ws = await newWebSocket(header, transp, "myfancyprotocol")
|
||||||
|
if ws.readyState == Open:
|
||||||
|
info "Websocket handshake completed."
|
||||||
|
else:
|
||||||
|
error "Failed to open websocket connection."
|
||||||
|
return
|
||||||
|
|
||||||
|
while ws.readyState == Open:
|
||||||
|
let recvData = await ws.receiveStrPacket()
|
||||||
|
info "Server:", state = ws.readyState
|
||||||
|
await ws.send(recvData)
|
||||||
|
except WebSocketError:
|
||||||
|
error "WebSocket error:", exception = getCurrentExceptionMsg()
|
||||||
|
discard await transp.sendHTTPResponse(HttpVersion11, Http200, "Connection established")
|
||||||
|
|
||||||
|
proc sendRecvClientData*(wsClient: WebSocket, msg: string) {.async.} =
|
||||||
|
try:
|
||||||
|
waitFor wsClient.sendStr(msg)
|
||||||
|
let recvData = waitFor wsClient.receiveStrPacket()
|
||||||
|
info "Websocket client state: ", state = wsClient.readyState
|
||||||
|
let dataStr = string.fromBytes(recvData)
|
||||||
|
require dataStr == msg
|
||||||
|
|
||||||
|
except WebSocketError:
|
||||||
|
error "WebSocket error:", exception = getCurrentExceptionMsg()
|
||||||
|
|
||||||
|
proc incorrectProtocolCB*(transp: StreamTransport, header: HttpRequestHeader) {.async.} =
|
||||||
|
info "Handling request:", uri = header.uri()
|
||||||
|
var isErr = false;
|
||||||
|
if header.uri() == "/ws":
|
||||||
|
info "Initiating web socket connection."
|
||||||
|
try:
|
||||||
|
var ws = await newWebSocket(header, transp, "myfancyprotocol")
|
||||||
|
require ws.readyState == ReadyState.Closed
|
||||||
|
except WebSocketError:
|
||||||
|
isErr = true;
|
||||||
|
require contains(getCurrentExceptionMsg(), "Protocol mismatch")
|
||||||
|
finally:
|
||||||
|
require isErr == true
|
||||||
|
discard await transp.sendHTTPResponse(HttpVersion11, Http200, "Connection established")
|
||||||
|
|
||||||
|
|
||||||
|
proc generateData*(num: int64): seq[byte] =
|
||||||
|
var str = newSeqOfCap[byte](num)
|
||||||
|
for i in 0 ..< num:
|
||||||
|
str.add(65)
|
||||||
|
return str
|
87
tests/websocket.nim
Normal file
87
tests/websocket.nim
Normal file
@ -0,0 +1,87 @@
|
|||||||
|
import helpers, unittest, ../src/http, chronos, ../src/ws,../src/random,
|
||||||
|
stew/byteutils, os, strutils
|
||||||
|
|
||||||
|
var httpServer: HttpServer
|
||||||
|
|
||||||
|
proc startServer() {.async, gcsafe.} =
|
||||||
|
httpServer = newHttpServer("127.0.0.1:8888", cb)
|
||||||
|
httpServer.start()
|
||||||
|
|
||||||
|
proc closeServer() {.async, gcsafe.} =
|
||||||
|
httpServer.stop()
|
||||||
|
waitFor httpServer.closeWait()
|
||||||
|
|
||||||
|
suite "Test websocket error cases":
|
||||||
|
teardown:
|
||||||
|
httpServer.stop()
|
||||||
|
waitFor httpServer.closeWait()
|
||||||
|
|
||||||
|
test "Test for incorrect protocol":
|
||||||
|
httpServer = newHttpServer("127.0.0.1:8888", incorrectProtocolCB)
|
||||||
|
httpServer.start()
|
||||||
|
try:
|
||||||
|
let wsClient = waitFor newWebsocketClient("127.0.0.1", Port(8888),
|
||||||
|
path = "/ws", protocols = @["mywrongprotocol"])
|
||||||
|
except WebSocketError:
|
||||||
|
require contains(getCurrentExceptionMsg(), "Server did not reply with a websocket upgrade")
|
||||||
|
|
||||||
|
test "Test for incorrect port":
|
||||||
|
httpServer = newHttpServer("127.0.0.1:8888", cb)
|
||||||
|
httpServer.start()
|
||||||
|
try:
|
||||||
|
let wsClient = waitFor newWebsocketClient("127.0.0.1", Port(8889),
|
||||||
|
path = "/ws", protocols = @["myfancyprotocol"])
|
||||||
|
except:
|
||||||
|
require contains(getCurrentExceptionMsg(), "Connection refused")
|
||||||
|
|
||||||
|
test "Test for incorrect path":
|
||||||
|
httpServer = newHttpServer("127.0.0.1:8888", cb)
|
||||||
|
httpServer.start()
|
||||||
|
try:
|
||||||
|
let wsClient = waitFor newWebsocketClient("127.0.0.1", Port(8888),
|
||||||
|
path = "/gg", protocols = @["myfancyprotocol"])
|
||||||
|
except:
|
||||||
|
require contains(getCurrentExceptionMsg(), "Server did not reply with a websocket upgrade")
|
||||||
|
|
||||||
|
suite "Misc Test":
|
||||||
|
setup:
|
||||||
|
waitFor startServer()
|
||||||
|
teardown:
|
||||||
|
waitFor closeServer()
|
||||||
|
|
||||||
|
test "Test for maskKey":
|
||||||
|
let wsClient = waitFor newWebsocketClient("127.0.0.1", Port(8888), path = "/ws",
|
||||||
|
protocols = @["myfancyprotocol"])
|
||||||
|
let maskKey = genMaskKey(wsClient.rng)
|
||||||
|
require maskKey.len == 4
|
||||||
|
|
||||||
|
test "Test for toCaseInsensitive":
|
||||||
|
let headers = newHttpHeaders()
|
||||||
|
require toCaseInsensitive(headers, "webSocket") == "Websocket"
|
||||||
|
|
||||||
|
|
||||||
|
suite "Test web socket communication":
|
||||||
|
|
||||||
|
setup:
|
||||||
|
waitFor startServer()
|
||||||
|
let wsClient = waitFor newWebsocketClient("127.0.0.1", Port(8888),
|
||||||
|
path = "/ws", protocols = @["myfancyprotocol"])
|
||||||
|
|
||||||
|
teardown:
|
||||||
|
waitFor closeServer()
|
||||||
|
|
||||||
|
test "Websocket conversation between client and server":
|
||||||
|
waitFor sendRecvClientData(wsClient, "Hello Server")
|
||||||
|
|
||||||
|
test "Test for small message ":
|
||||||
|
let msg = string.fromBytes(generateData(100))
|
||||||
|
waitFor sendRecvClientData(wsClient, msg)
|
||||||
|
|
||||||
|
test "Test for medium message ":
|
||||||
|
let msg = string.fromBytes(generateData(1000))
|
||||||
|
waitFor sendRecvClientData(wsClient, msg)
|
||||||
|
|
||||||
|
test "Test for large message ":
|
||||||
|
let msg = string.fromBytes(generateData(10000))
|
||||||
|
waitFor sendRecvClientData(wsClient, msg)
|
||||||
|
|
@ -4,12 +4,15 @@ author = "Status Research & Development GmbH"
|
|||||||
description = "WS protocol implementation"
|
description = "WS protocol implementation"
|
||||||
license = "MIT"
|
license = "MIT"
|
||||||
|
|
||||||
requires "nim >= 1.2.6"
|
requires "nim == 1.2.6"
|
||||||
requires "chronos >= 2.5.2 & < 3.0.0"
|
requires "chronos >= 2.5.2"
|
||||||
requires "httputils >= 0.2.0"
|
requires "httputils >= 0.2.0"
|
||||||
requires "chronicles >= 0.10.0"
|
requires "chronicles >= 0.10.0"
|
||||||
requires "urlly >= 0.2.0"
|
requires "urlly >= 0.2.0"
|
||||||
requires "uri"
|
requires "stew >= 0.1.0"
|
||||||
|
requires "eth"
|
||||||
|
requires "asynctest >= 0.2.0 & < 0.3.0"
|
||||||
|
requires "nimcrypto"
|
||||||
|
|
||||||
task lint, "format source files according to the official style guide":
|
task lint, "format source files according to the official style guide":
|
||||||
exec "./lint.nims"
|
exec "./lint.nims"
|
||||||
|
Loading…
x
Reference in New Issue
Block a user