mirror of
https://github.com/status-im/nim-websock.git
synced 2025-01-20 22:38:57 +00:00
[WIP] Web socket client implementation. (#2)
* Implement websocket server. * Implement websocket client. * Run nimpretty. * Remove commented code. * Address comments. * Address comments on websocket server. * Use seq[byte] to store data. * Working bytes conversion. * Remove result from return * Refactor the code. * Minor change. * Add test. * Add websocket test and fix closing handshake. * Add MsgReader to read data in external buffer.
This commit is contained in:
parent
8e34e0f138
commit
a1ae7d2c70
21
examples/client.nim
Normal file
21
examples/client.nim
Normal file
@ -0,0 +1,21 @@
|
||||
import ../src/ws, nativesockets, chronos, os, chronicles, stew/byteutils
|
||||
|
||||
let wsClient = waitFor newWebsocketClient("127.0.0.1", Port(8888), path = "/ws",
|
||||
protocols = @["myfancyprotocol"])
|
||||
info "Websocket client: ", State = wsClient.readyState
|
||||
|
||||
let reqData = "Hello Server"
|
||||
for idx in 1 .. 5:
|
||||
try:
|
||||
waitFor wsClient.sendStr(reqData)
|
||||
let recvData = waitFor wsClient.receiveStrPacket()
|
||||
let dataStr = string.fromBytes(recvData)
|
||||
info "Server:", data = dataStr
|
||||
assert dataStr == reqData
|
||||
except WebSocketError:
|
||||
error "WebSocket error:", exception = getCurrentExceptionMsg()
|
||||
os.sleep(1000)
|
||||
|
||||
# close the websocket
|
||||
waitFor wsClient.close()
|
||||
|
41
examples/server.nim
Normal file
41
examples/server.nim
Normal file
@ -0,0 +1,41 @@
|
||||
import ../src/ws, ../src/http, chronos, chronicles, httputils, stew/byteutils
|
||||
|
||||
proc cb(transp: StreamTransport, header: HttpRequestHeader) {.async.} =
|
||||
info "Handling request:", uri = header.uri()
|
||||
if header.uri() == "/ws":
|
||||
info "Initiating web socket connection."
|
||||
try:
|
||||
var ws = await newWebSocket(header, transp, "myfancyprotocol")
|
||||
if ws.readyState == Open:
|
||||
info "Websocket handshake completed."
|
||||
else:
|
||||
error "Failed to open websocket connection."
|
||||
return
|
||||
|
||||
while true:
|
||||
# Only reads header for data frame.
|
||||
let msgReader = await ws.nextMessageReader()
|
||||
|
||||
# Read the frame payload in buffer.
|
||||
let buffer = newSeq[byte](100)
|
||||
var recvData :seq[byte]
|
||||
while msgReader.error != EOFError:
|
||||
msgReader.readMessage(buffer)
|
||||
recvData.add buffer
|
||||
if ws.readyState == ReadyState.Closed:
|
||||
return
|
||||
info "Response: ", data = recvData
|
||||
await ws.send(recvData)
|
||||
|
||||
except WebSocketError:
|
||||
error "WebSocket error:", exception = getCurrentExceptionMsg()
|
||||
|
||||
discard await transp.sendHTTPResponse(HttpVersion11, Http200, "Hello World")
|
||||
await transp.closeWait()
|
||||
|
||||
when isMainModule:
|
||||
let address = "127.0.0.1:8888"
|
||||
var httpServer = newHttpServer(address, cb)
|
||||
httpServer.start()
|
||||
echo "Server started..."
|
||||
waitFor httpServer.join()
|
250
src/http.nim
Normal file
250
src/http.nim
Normal file
@ -0,0 +1,250 @@
|
||||
import chronos, chronos/timer, httputils, chronicles, uri, tables, strutils
|
||||
|
||||
const
|
||||
MaxHttpHeadersSize = 8192 # maximum size of HTTP headers in octets
|
||||
MaxHttpRequestSize = 128 * 1024 # maximum size of HTTP body in octets
|
||||
HttpHeadersTimeout = timer.seconds(120) # timeout for receiving headers (120 sec)
|
||||
CRLF* = "\r\n"
|
||||
HeaderSep = @[byte('\c'), byte('\L'), byte('\c'), byte('\L')]
|
||||
|
||||
type
|
||||
HttpClient* = ref object
|
||||
connected: bool
|
||||
currentURL: Uri ## Where we are currently connected.
|
||||
headers: HttpHeaders ## Headers to send in requests.
|
||||
transp*: StreamTransport
|
||||
buf: seq[byte]
|
||||
|
||||
HttpHeaders* = object
|
||||
table*: TableRef[string, seq[string]]
|
||||
|
||||
ReqStatus = enum
|
||||
Success, Error, ErrorFailure
|
||||
|
||||
AsyncCallback = proc (transp: StreamTransport,
|
||||
header: HttpRequestHeader): Future[void] {.closure, gcsafe.}
|
||||
HttpServer* = ref object of StreamServer
|
||||
callback: AsyncCallback
|
||||
|
||||
proc recvData(transp: StreamTransport): Future[seq[byte]] {.async.} =
|
||||
var buffer = newSeq[byte](MaxHttpHeadersSize)
|
||||
var error = false
|
||||
try:
|
||||
let hlenfut = transp.readUntil(addr buffer[0], MaxHttpHeadersSize,
|
||||
sep = HeaderSep)
|
||||
let ores = await withTimeout(hlenfut, HttpHeadersTimeout)
|
||||
if not ores:
|
||||
# Timeout
|
||||
debug "Timeout expired while receiving headers",
|
||||
address = transp.remoteAddress()
|
||||
error = true
|
||||
else:
|
||||
let hlen = hlenfut.read()
|
||||
buffer.setLen(hlen)
|
||||
except TransportLimitError:
|
||||
# size of headers exceeds `MaxHttpHeadersSize`
|
||||
debug "Maximum size of headers limit reached",
|
||||
address = transp.remoteAddress()
|
||||
error = true
|
||||
except TransportIncompleteError:
|
||||
# remote peer disconnected
|
||||
debug "Remote peer disconnected", address = transp.remoteAddress()
|
||||
error = true
|
||||
except TransportOsError as exc:
|
||||
debug "Problems with networking", address = transp.remoteAddress(),
|
||||
error = exc.msg
|
||||
error = true
|
||||
|
||||
if error:
|
||||
buffer.setLen(0)
|
||||
return buffer
|
||||
|
||||
proc newConnection(client: HttpClient, url: Uri) {.async.} =
|
||||
if client.connected:
|
||||
return
|
||||
|
||||
let port =
|
||||
if url.port == "": 80
|
||||
else: url.port.parseInt
|
||||
|
||||
client.transp = await connect(initTAddress(url.hostname, port))
|
||||
|
||||
# May be connected through proxy but remember actual URL being accessed
|
||||
client.currentURL = url
|
||||
client.connected = true
|
||||
|
||||
proc generateHeaders(requestUrl: Uri, httpMethod: string,
|
||||
additionalHeaders: HttpHeaders): string =
|
||||
# GET
|
||||
var headers = httpMethod.toUpperAscii()
|
||||
headers.add ' '
|
||||
|
||||
if not requestUrl.path.startsWith("/"): headers.add '/'
|
||||
headers.add(requestUrl.path)
|
||||
|
||||
# HTTP/1.1\c\l
|
||||
headers.add(" HTTP/1.1" & CRLF)
|
||||
|
||||
for key, val in additionalHeaders.table:
|
||||
headers.add(key & ": " & val.join(", ") & CRLF)
|
||||
headers.add(CRLF)
|
||||
return headers
|
||||
|
||||
# Send request to the client. Currently only supports HTTP get method.
|
||||
proc request*(client: HttpClient, url, httpMethod: string,
|
||||
body = "", headers: HttpHeaders): Future[seq[byte]]
|
||||
{.async.} =
|
||||
# Helper that actually makes the request. Does not handle redirects.
|
||||
let requestUrl = parseUri(url)
|
||||
if requestUrl.scheme == "":
|
||||
raise newException(ValueError, "No uri scheme supplied.")
|
||||
|
||||
await newConnection(client, requestUrl)
|
||||
|
||||
let headerString = generateHeaders(requestUrl, httpMethod, headers)
|
||||
let res = await client.transp.write(headerString)
|
||||
if res != len(headerString):
|
||||
raise newException(ValueError, "Error while send request to client")
|
||||
|
||||
var value = await client.transp.recvData()
|
||||
if value.len == 0:
|
||||
raise newException(ValueError, "Empty response from server")
|
||||
return value
|
||||
|
||||
proc sendHTTPResponse*(transp: StreamTransport, version: HttpVersion, code: HttpCode,
|
||||
data: string = ""): Future[bool] {.async.} =
|
||||
var answer = $version
|
||||
answer.add(" ")
|
||||
answer.add($code)
|
||||
answer.add(CRLF)
|
||||
answer.add("Date: " & httpDate() & CRLF)
|
||||
if len(data) > 0:
|
||||
answer.add("Content-Type: application/json" & CRLF)
|
||||
answer.add("Content-Length: " & $len(data) & CRLF)
|
||||
answer.add(CRLF)
|
||||
if len(data) > 0:
|
||||
answer.add(data)
|
||||
|
||||
let res = await transp.write(answer)
|
||||
if res == len(answer):
|
||||
return true
|
||||
raise newException(IOError, "Failed to send http request.")
|
||||
|
||||
proc validateRequest(transp: StreamTransport,
|
||||
header: HttpRequestHeader): Future[ReqStatus] {.async.} =
|
||||
if header.meth notin {MethodGet}:
|
||||
debug "GET method is only allowed", address = transp.remoteAddress()
|
||||
if await transp.sendHTTPResponse(header.version, Http405):
|
||||
return Error
|
||||
else:
|
||||
return ErrorFailure
|
||||
|
||||
var hlen = header.contentLength()
|
||||
if hlen < 0 or hlen > MaxHttpRequestSize:
|
||||
debug "Invalid header length", address = transp.remoteAddress()
|
||||
if await transp.sendHTTPResponse(header.version, Http413):
|
||||
return Error
|
||||
else:
|
||||
return ErrorFailure
|
||||
|
||||
return Success
|
||||
|
||||
proc serveClient(server: StreamServer, transp: StreamTransport) {.async.} =
|
||||
## Process transport data to the RPC server
|
||||
var httpServer = cast[HttpServer](server)
|
||||
var buffer = newSeq[byte](MaxHttpHeadersSize)
|
||||
var header: HttpRequestHeader
|
||||
|
||||
info "Received connection", address = $transp.remoteAddress()
|
||||
try:
|
||||
let hlenfut = transp.readUntil(addr buffer[0], MaxHttpHeadersSize,
|
||||
sep = HeaderSep)
|
||||
let ores = await withTimeout(hlenfut, HttpHeadersTimeout)
|
||||
if not ores:
|
||||
# Timeout
|
||||
debug "Timeout expired while receiving headers",
|
||||
address = transp.remoteAddress()
|
||||
discard await transp.sendHTTPResponse(HttpVersion11, Http408)
|
||||
await transp.closeWait()
|
||||
return
|
||||
else:
|
||||
let hlen = hlenfut.read()
|
||||
buffer.setLen(hlen)
|
||||
header = buffer.parseRequest()
|
||||
if header.failed():
|
||||
# Header could not be parsed
|
||||
debug "Malformed header received",
|
||||
address = transp.remoteAddress()
|
||||
discard await transp.sendHTTPResponse(HttpVersion11, Http400)
|
||||
await transp.closeWait()
|
||||
return
|
||||
var vres = await validateRequest(transp, header)
|
||||
if vres == Success:
|
||||
info "Received valid RPC request", address = $transp.remoteAddress()
|
||||
# Call the user's callback.
|
||||
if httpServer.callback != nil:
|
||||
await httpServer.callback(transp, header)
|
||||
elif vres == ErrorFailure:
|
||||
debug "Remote peer disconnected", address = transp.remoteAddress()
|
||||
except TransportLimitError:
|
||||
# size of headers exceeds `MaxHttpHeadersSize`
|
||||
debug "Maximum size of headers limit reached",
|
||||
address = transp.remoteAddress()
|
||||
discard await transp.sendHTTPResponse(HttpVersion11, Http413)
|
||||
except TransportIncompleteError:
|
||||
# remote peer disconnected
|
||||
debug "Remote peer disconnected", address = transp.remoteAddress()
|
||||
except TransportOsError as exc:
|
||||
debug "Problems with networking", address = transp.remoteAddress(),
|
||||
error = exc.msg
|
||||
except CatchableError as exc:
|
||||
debug "Unknown exception", address = transp.remoteAddress(),
|
||||
error = exc.msg
|
||||
await transp.closeWait()
|
||||
|
||||
proc newHttpServer*(address: string, handler: AsyncCallback,
|
||||
flags: set[ServerFlags] = {ReuseAddr}): HttpServer =
|
||||
let address = initTAddress(address)
|
||||
var server = HttpServer(callback: handler)
|
||||
server = cast[HttpServer](createStreamServer(address, serveClient, flags,
|
||||
child = cast[StreamServer](server)))
|
||||
return server
|
||||
|
||||
func toTitleCase(s: string): string =
|
||||
var tcstr = newString(len(s))
|
||||
var upper = true
|
||||
for i in 0..len(s) - 1:
|
||||
tcstr[i] = if upper: toUpperAscii(s[i]) else: toLowerAscii(s[i])
|
||||
upper = s[i] == '-'
|
||||
return tcstr
|
||||
|
||||
func toCaseInsensitive*(headers: HttpHeaders, s: string): string {.inline.} =
|
||||
return toTitleCase(s)
|
||||
|
||||
func newHttpHeaders*(): HttpHeaders =
|
||||
## Returns a new ``HttpHeaders`` object. if ``titleCase`` is set to true,
|
||||
## headers are passed to the server in title case (e.g. "Content-Length")
|
||||
return HttpHeaders(table: newTable[string, seq[string]]())
|
||||
|
||||
func newHttpHeaders*(keyValuePairs:
|
||||
openArray[tuple[key: string, val: string]]): HttpHeaders =
|
||||
## Returns a new ``HttpHeaders`` object from an array. if ``titleCase`` is set to true,
|
||||
## headers are passed to the server in title case (e.g. "Content-Length")
|
||||
var headers = newHttpHeaders()
|
||||
|
||||
for pair in keyValuePairs:
|
||||
let key = headers.toCaseInsensitive(pair.key)
|
||||
if key in headers.table:
|
||||
headers.table[key].add(pair.val)
|
||||
else:
|
||||
headers.table[key] = @[pair.val]
|
||||
return headers
|
||||
|
||||
proc newHttpClient*(headers = newHttpHeaders()): HttpClient =
|
||||
return HttpClient(headers: headers)
|
||||
|
||||
proc close*(client: HttpClient) =
|
||||
## Closes any connections held by the HTTP client.
|
||||
if client.connected:
|
||||
client.transp.close()
|
||||
client.connected = false
|
25
src/random.nim
Normal file
25
src/random.nim
Normal file
@ -0,0 +1,25 @@
|
||||
import bearssl
|
||||
|
||||
## Random helpers: similar as in stdlib, but with BrHmacDrbgContext rng
|
||||
const randMax = 18_446_744_073_709_551_615'u64
|
||||
|
||||
proc rand*(rng: var BrHmacDrbgContext, max: Natural): int =
|
||||
if max == 0: return 0
|
||||
var x: uint64
|
||||
while true:
|
||||
brHmacDrbgGenerate(addr rng, addr x, csize_t(sizeof(x)))
|
||||
if x < randMax - (randMax mod (uint64(max) + 1'u64)): # against modulo bias
|
||||
return int(x mod (uint64(max) + 1'u64))
|
||||
|
||||
proc genMaskKey*(rng: ref BrHmacDrbgContext): array[4, char] =
|
||||
## Generates a random key of 4 random chars.
|
||||
proc r(): char = char(rand(rng[], 255))
|
||||
return [r(), r(), r(), r()]
|
||||
|
||||
proc genWebSecKey*(rng: ref BrHmacDrbgContext): seq[char] =
|
||||
var key = newSeq[char](16)
|
||||
proc r(): char = char(rand(rng[], 255))
|
||||
## Generates a random key of 16 random chars.
|
||||
for i in 0..15:
|
||||
key.add(r())
|
||||
return key
|
451
src/ws.nim
Normal file
451
src/ws.nim
Normal file
@ -0,0 +1,451 @@
|
||||
import httputils, strutils, base64, std/sha1, ./random, http, uri,
|
||||
chronos/timer, tables, stew/byteutils, eth/[keys], stew/endians2,
|
||||
parseutils, stew/base64 as stewBase,chronos
|
||||
|
||||
const
|
||||
SHA1DigestSize = 20
|
||||
WSHeaderSize = 12
|
||||
WSOpCode = {0x00, 0x01, 0x02, 0x08, 0x09, 0x0a}
|
||||
|
||||
type
|
||||
ReadyState* = enum
|
||||
Connecting = 0 # The connection is not yet open.
|
||||
Open = 1 # The connection is open and ready to communicate.
|
||||
Closing = 2 # The connection is in the process of closing.
|
||||
Closed = 3 # The connection is closed or couldn't be opened.
|
||||
|
||||
WebSocket* = ref object
|
||||
tcpSocket*: StreamTransport
|
||||
version*: int
|
||||
key*: string
|
||||
protocol*: string
|
||||
readyState*: ReadyState
|
||||
masked*: bool # send masked packets
|
||||
rng*: ref BrHmacDrbgContext
|
||||
|
||||
WebSocketError* = object of IOError
|
||||
|
||||
Base16Error* = object of CatchableError
|
||||
## Base16 specific exception type
|
||||
|
||||
HeaderFlag* {.size: sizeof(uint8).} = enum
|
||||
rsv3
|
||||
rsv2
|
||||
rsv1
|
||||
fin
|
||||
HeaderFlags = set[HeaderFlag]
|
||||
|
||||
HttpCode* = enum
|
||||
Http101 = 101 # Switching Protocols
|
||||
|
||||
# Forward declare
|
||||
proc close*(ws: WebSocket, initiator: bool = true) {.async.}
|
||||
|
||||
proc handshake*(ws: WebSocket, header: HttpRequestHeader) {.async.} =
|
||||
## Handles the websocket handshake.
|
||||
discard parseSaturatedNatural(header["Sec-WebSocket-Version"], ws.version)
|
||||
if ws.version != 13:
|
||||
raise newException(WebSocketError, "Websocket version not supported, Version: " &
|
||||
header["Sec-WebSocket-Version"])
|
||||
|
||||
ws.key = header["Sec-WebSocket-Key"].strip()
|
||||
if header.contains("Sec-WebSocket-Protocol"):
|
||||
let wantProtocol = header["Sec-WebSocket-Protocol"].strip()
|
||||
if ws.protocol != wantProtocol:
|
||||
raise newException(WebSocketError,
|
||||
"Protocol mismatch (expected: " & ws.protocol & ", got: " &
|
||||
wantProtocol & ")")
|
||||
|
||||
var acceptKey: string
|
||||
try:
|
||||
let sh = secureHash(ws.key & "258EAFA5-E914-47DA-95CA-C5AB0DC85B11")
|
||||
acceptKey = stewBase.Base64.encode(hexToByteArray[SHA1DigestSize]($sh))
|
||||
except ValueError:
|
||||
raise newException(
|
||||
WebSocketError, "Failed to generate accept key: " & getCurrentExceptionMsg())
|
||||
|
||||
var response = "HTTP/1.1 101 Web Socket Protocol Handshake" & CRLF
|
||||
response.add("Sec-WebSocket-Accept: " & acceptKey & CRLF)
|
||||
response.add("Connection: Upgrade" & CRLF)
|
||||
response.add("Upgrade: webSocket" & CRLF)
|
||||
|
||||
if ws.protocol != "":
|
||||
response.add("Sec-WebSocket-Protocol: " & ws.protocol & CRLF)
|
||||
response.add CRLF
|
||||
|
||||
let res = await ws.tcpSocket.write(response)
|
||||
if res != len(response):
|
||||
raise newException(WebSocketError, "Failed to send handshake response to client")
|
||||
ws.readyState = Open
|
||||
|
||||
proc newWebSocket*(header: HttpRequestHeader, transp: StreamTransport,
|
||||
protocol: string = ""): Future[WebSocket] {.async.} =
|
||||
## Creates a new socket from a request.
|
||||
try:
|
||||
if not header.contains("Sec-WebSocket-Version"):
|
||||
raise newException(WebSocketError, "Invalid WebSocket handshake")
|
||||
var ws = WebSocket(tcpSocket: transp, protocol: protocol, masked: false,
|
||||
rng: newRng())
|
||||
await ws.handshake(header)
|
||||
return ws
|
||||
except ValueError, KeyError:
|
||||
# Wrap all exceptions in a WebSocketError so its easy to catch.
|
||||
raise newException(
|
||||
WebSocketError,
|
||||
"Failed to create WebSocket from request: " & getCurrentExceptionMsg()
|
||||
)
|
||||
|
||||
type
|
||||
Opcode* = enum
|
||||
## 4 bits. Defines the interpretation of the "Payload data".
|
||||
Cont = 0x0 ## Denotes a continuation frame.
|
||||
Text = 0x1 ## Denotes a text frame.
|
||||
Binary = 0x2 ## Denotes a binary frame.
|
||||
# 3-7 are reserved for further non-control frames.
|
||||
Close = 0x8 ## Denotes a connection close.
|
||||
Ping = 0x9 ## Denotes a ping.
|
||||
Pong = 0xa ## Denotes a pong.
|
||||
# B-F are reserved for further control frames.
|
||||
|
||||
#[
|
||||
+---------------------------------------------------------------+
|
||||
|0 1 2 3 |
|
||||
|0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1|
|
||||
+-+-+-+-+-------+-+-------------+-------------------------------+
|
||||
|F|R|R|R| opcode|M| Payload len | Extended payload length |
|
||||
|I|S|S|S| (4) |A| (7) | (16/64) |
|
||||
|N|V|V|V| |S| | (if payload len==126/127) |
|
||||
| |1|2|3| |K| | |
|
||||
+-+-+-+-+-------+-+-------------+ - - - - - - - - - - - - - - - +
|
||||
| Extended payload length continued, if payload len == 127 |
|
||||
+ - - - - - - - - - - - - - - - +-------------------------------+
|
||||
| |Masking-key, if MASK set to 1 |
|
||||
+-------------------------------+-------------------------------+
|
||||
| Masking-key (continued) | Payload Data |
|
||||
+-------------------------------- - - - - - - - - - - - - - - - +
|
||||
: Payload Data continued ... :
|
||||
+ - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - +
|
||||
| Payload Data continued ... |
|
||||
+---------------------------------------------------------------+
|
||||
]#
|
||||
|
||||
MsgReader = ref object
|
||||
tcpSocket: StreamTransport
|
||||
readErr: IOError
|
||||
readLen: uint64
|
||||
readRemaining: uint64
|
||||
readFinal: bool ## true the current message has more frames.
|
||||
opcode: Opcode ## Defines the interpretation of the "Payload data".
|
||||
maskKey: array[4, char] ## Masking key
|
||||
mask: bool ## Defines whether the "Payload data" is masked.
|
||||
|
||||
Frame = ref object
|
||||
fin: bool ## Indicates that this is the final fragment in a message.
|
||||
rsv1: bool ## MUST be 0 unless negotiated that defines meanings
|
||||
rsv2: bool ## MUST be 0
|
||||
rsv3: bool ## MUST be 0
|
||||
opcode: Opcode ## Defines the interpretation of the "Payload data".
|
||||
mask: bool ## Defines whether the "Payload data" is masked.
|
||||
data: seq[byte] ## Payload data
|
||||
maskKey: array[4, char] ## Masking key
|
||||
length: uint64 ## Message size.
|
||||
|
||||
proc encodeFrame(f: Frame): seq[byte] =
|
||||
## Encodes a frame into a string buffer.
|
||||
## See https://tools.ietf.org/html/rfc6455#section-5.2
|
||||
|
||||
var ret = newSeqOfCap[byte](f.data.len + WSHeaderSize)
|
||||
|
||||
var b0 = (f.opcode.uint8 and 0x0f) # 0th byte: opcodes and flags.
|
||||
if f.fin:
|
||||
b0 = b0 or 128u8
|
||||
|
||||
ret.add(b0)
|
||||
|
||||
# Payload length can be 7 bits, 7+16 bits, or 7+64 bits.
|
||||
# 1st byte: payload len start and mask bit.
|
||||
var b1 = 0u8
|
||||
|
||||
if f.data.len <= 125:
|
||||
b1 = f.data.len.uint8
|
||||
elif f.data.len > 125 and f.data.len <= 0xffff:
|
||||
b1 = 126u8
|
||||
else:
|
||||
b1 = 127u8
|
||||
|
||||
if f.mask:
|
||||
b1 = b1 or (1 shl 7)
|
||||
|
||||
ret.add(uint8 b1)
|
||||
|
||||
# Only need more bytes if data len is 7+16 bits, or 7+64 bits.
|
||||
if f.data.len > 125 and f.data.len <= 0xffff:
|
||||
# Data len is 7+16 bits.
|
||||
var len = f.data.len.uint16
|
||||
ret.add ((len shr 8) and 255).uint8
|
||||
ret.add (len and 255).uint8
|
||||
elif f.data.len > 0xffff:
|
||||
# Data len is 7+64 bits.
|
||||
var len = f.data.len
|
||||
ret.add(f.data.len.uint64.toBE().toBytesBE())
|
||||
|
||||
var data = f.data
|
||||
|
||||
if f.mask:
|
||||
# If we need to mask it generate random mask key and mask the data.
|
||||
for i in 0..<data.len:
|
||||
data[i] = (data[i].uint8 xor f.maskKey[i mod 4].uint8)
|
||||
# Write mask key next.
|
||||
ret.add(f.maskKey[0].uint8)
|
||||
ret.add(f.maskKey[1].uint8)
|
||||
ret.add(f.maskKey[2].uint8)
|
||||
ret.add(f.maskKey[3].uint8)
|
||||
|
||||
# Write the data.
|
||||
ret.add(data)
|
||||
return ret
|
||||
|
||||
proc send*(ws: WebSocket, data: seq[byte], opcode = Opcode.Text): Future[
|
||||
void] {.async.} =
|
||||
try:
|
||||
var maskKey: array[4, char]
|
||||
if ws.masked:
|
||||
maskKey = genMaskKey(ws.rng)
|
||||
|
||||
var inFrame = Frame(
|
||||
fin: true,
|
||||
rsv1: false,
|
||||
rsv2: false,
|
||||
rsv3: false,
|
||||
opcode: opcode,
|
||||
mask: ws.masked,
|
||||
data: data,
|
||||
maskKey: maskKey)
|
||||
var frame = encodeFrame(inFrame)
|
||||
const maxSize = 1024*1024
|
||||
# Send stuff in 1 megabyte chunks to prevent IOErrors.
|
||||
# This really large packets.
|
||||
var i = 0
|
||||
while i < frame.len:
|
||||
let frameSize = min(frame.len, i + maxSize)
|
||||
let res = await ws.tcpSocket.write(frame[i ..< frameSize])
|
||||
if res != frameSize:
|
||||
raise newException(ValueError, "Error while send websocket frame")
|
||||
i += maxSize
|
||||
except OSError, ValueError:
|
||||
# Wrap all exceptions in a WebSocketError so its easy to catch
|
||||
raise newException(WebSocketError, "Failed to send data: " &
|
||||
getCurrentExceptionMsg())
|
||||
|
||||
proc sendStr*(ws: WebSocket, data: string, opcode = Opcode.Text): Future[void] =
|
||||
send(ws, toBytes(data), opcode)
|
||||
|
||||
proc readFrame(ws: WebSocket): Future[Frame] {.async.} =
|
||||
## Gets a frame from the WebSocket.
|
||||
## See https://tools.ietf.org/html/rfc6455#section-5.2
|
||||
|
||||
# Grab the header.
|
||||
var header = newSeq[byte](2)
|
||||
try:
|
||||
await ws.tcpSocket.readExactly(addr header[0], 2)
|
||||
except TransportUseClosedError:
|
||||
ws.readyState = Closed
|
||||
raise newException(WebSocketError, "Socket closed")
|
||||
|
||||
if header.len != 2:
|
||||
ws.readyState = Closed
|
||||
raise newException(WebSocketError, "Invalid websocket header length")
|
||||
|
||||
let b0 = header[0].uint8
|
||||
let b1 = header[1].uint8
|
||||
|
||||
var frame: Frame
|
||||
# Read the flags and fin from the header.
|
||||
|
||||
var hf = cast[HeaderFlags](b0 shr 4)
|
||||
frame.fin = fin in hf
|
||||
frame.rsv1 = rsv1 in hf
|
||||
frame.rsv2 = rsv2 in hf
|
||||
frame.rsv3 = rsv3 in hf
|
||||
|
||||
var opcode = b0 and 0x0f
|
||||
if opcode notin WSOpCode:
|
||||
raise newException(WebSocketError, "Unexpected websocket opcode")
|
||||
frame.opcode = (opcode).Opcode
|
||||
|
||||
# If any of the rsv are set close the socket.
|
||||
if frame.rsv1 or frame.rsv2 or frame.rsv3:
|
||||
ws.readyState = Closed
|
||||
raise newException(WebSocketError, "WebSocket rsv mismatch")
|
||||
|
||||
# Payload length can be 7 bits, 7+16 bits, or 7+64 bits.
|
||||
var finalLen: uint64 = 0
|
||||
|
||||
let headerLen = uint(b1 and 0x7f)
|
||||
if headerLen == 0x7e:
|
||||
# Length must be 7+16 bits.
|
||||
var length = newSeq[byte](2)
|
||||
await ws.tcpSocket.readExactly(addr length[0], 2)
|
||||
finalLen = cast[ptr uint16](length[0].addr)[].toBE
|
||||
elif headerLen == 0x7f:
|
||||
# Length must be 7+64 bits.
|
||||
var length = newSeq[byte](8)
|
||||
await ws.tcpSocket.readExactly(addr length[0], 8)
|
||||
finalLen = cast[ptr uint64](length[0].addr)[].toBE
|
||||
else:
|
||||
# Length must be 7 bits.
|
||||
finalLen = headerLen
|
||||
frame.length = finalLen
|
||||
|
||||
# Do we need to apply mask?
|
||||
frame.mask = (b1 and 0x80) == 0x80
|
||||
|
||||
if ws.masked == frame.mask:
|
||||
# Server sends unmasked but accepts only masked.
|
||||
# Client sends masked but accepts only unmasked.
|
||||
raise newException(WebSocketError, "Socket mask mismatch")
|
||||
|
||||
var maskKey = newSeq[byte](4)
|
||||
if frame.mask:
|
||||
# Read the mask.
|
||||
await ws.tcpSocket.readExactly(addr maskKey[0], 4)
|
||||
for i in 0..<maskKey.len:
|
||||
frame.maskKey[i] = cast[char](maskKey[i])
|
||||
|
||||
if (frame.opcode == Text) or (frame.opcode == Opcode.Cont) or (frame.opcode == Opcode.Binary) :
|
||||
return frame
|
||||
|
||||
# TODO: Add check for max size for control frames.
|
||||
var data = newSeq[byte](finalLen)
|
||||
|
||||
# Read control frame payload.
|
||||
if frame.length > 0 :
|
||||
# Read the data.
|
||||
await ws.tcpSocket.readExactly(addr data[0], int finalLen)
|
||||
frame.data = data
|
||||
|
||||
# Process control frame payload.
|
||||
if frame.opcode == Ping:
|
||||
await ws.send(data, Pong)
|
||||
elif frame.opcode == Pong:
|
||||
discard
|
||||
elif frame.opcode == Close:
|
||||
await ws.close(false)
|
||||
|
||||
return frame
|
||||
|
||||
proc close*(ws: WebSocket, initiator: bool = true) {.async.} =
|
||||
## Close the Socket, sends close packet.
|
||||
if ws.readyState == Closed:
|
||||
discard ws.tcpSocket.closeWait()
|
||||
return
|
||||
ws.readyState = Closed
|
||||
await ws.send(@[], Close)
|
||||
if initiator == true:
|
||||
let frame = await ws.readFrame()
|
||||
if frame.opcode != Close:
|
||||
echo "Different packet type"
|
||||
await ws.close()
|
||||
|
||||
proc readMessage*(msgReader: MsgReader,data: seq[byte]): MsgReader {.async.} =
|
||||
while msgReader.readErr == nil:
|
||||
if msgReader.readRemaining > 0 :
|
||||
len = size(data)
|
||||
if len > msgReader.readRemaining:
|
||||
len = msgReader.readRemaining
|
||||
|
||||
await msgReader.tcpSocket.readExactly(addr data, len)
|
||||
msgReader.readRemaining = msgReader.readRemaining - len
|
||||
msgReader.readLen = len
|
||||
|
||||
if msgReader.mask:
|
||||
# Apply mask, if we need too.
|
||||
for i in 0 ..< len:
|
||||
data[i] = (data[i].uint8 xor msgReader.maskKey[i mod 4].uint8)
|
||||
|
||||
if msgReader.readRemaining == 0:
|
||||
msgReader.readErr = EOFError
|
||||
|
||||
return msgReader
|
||||
|
||||
if msgReader.readFinal:
|
||||
msgReader.readLen = 0
|
||||
msgReader.readErr = EOFError
|
||||
return msgReader
|
||||
|
||||
var frame = await ws.readFrame()
|
||||
if frame.fin:
|
||||
msgReader.readFinal = true
|
||||
msgReader.readRemaining = frame.length
|
||||
|
||||
# Non-control frames cannot occur in the middle of a fragmented non-control frame.
|
||||
if frame.Opcode in Text || Binary:
|
||||
raise newException("websocket: internal error, unexpected text or binary in Reader")
|
||||
return msgReader
|
||||
|
||||
proc nextMessageReader*(ws: WebSocket): MsgReader =
|
||||
while true:
|
||||
# Handle control frames and return only on non control frames.
|
||||
var frame = await ws.readFrame()
|
||||
if frame.Opcode in Text || Binary:
|
||||
var msgReader: MsgReader
|
||||
msgReader.readFinal = frame.fin
|
||||
msgReader.readRemaining = frame.readRemaining
|
||||
msgReader.tcpSocket = ws.tcpSocket
|
||||
msgReader.mask = frame.mask
|
||||
msgReader.maskKey = frame.maskKey
|
||||
return msgReader
|
||||
|
||||
proc receiveStrPacket*(ws: WebSocket): Future[seq[byte]] {.async.} =
|
||||
# TODO: remove this once PR is approved.
|
||||
return nil
|
||||
|
||||
proc validateWSClientHandshake*(transp: StreamTransport,
|
||||
header: HttpResponseHeader): void =
|
||||
if header.code != ord(Http101):
|
||||
raise newException(WebSocketError, "Server did not reply with a websocket upgrade: " &
|
||||
"Header code: " & $header.code &
|
||||
"Header reason: " & header.reason() &
|
||||
"Address: " & $transp.remoteAddress())
|
||||
|
||||
proc newWebsocketClient*(uri: Uri, protocols: seq[string] = @[]): Future[
|
||||
WebSocket] {.async.} =
|
||||
var key = encode(genWebSecKey(newRng()))
|
||||
var uri = uri
|
||||
case uri.scheme
|
||||
of "ws":
|
||||
uri.scheme = "http"
|
||||
else:
|
||||
raise newException(WebSocketError, "uri scheme has to be 'ws'")
|
||||
|
||||
var headers = newHttpHeaders({
|
||||
"Connection": "Upgrade",
|
||||
"Upgrade": "websocket",
|
||||
"Cache-Control": "no-cache",
|
||||
"Sec-WebSocket-Version": "13",
|
||||
"Sec-WebSocket-Key": key
|
||||
})
|
||||
if protocols.len != 0:
|
||||
headers.table["Sec-WebSocket-Protocol"] = @[protocols.join(", ")]
|
||||
|
||||
let client = newHttpClient(headers)
|
||||
var response = await client.request($uri, "GET", headers = headers)
|
||||
var header = response.parseResponse()
|
||||
if header.failed():
|
||||
# Header could not be parsed
|
||||
raise newException(WebSocketError, "Malformed header received: " &
|
||||
$client.transp.remoteAddress())
|
||||
client.transp.validateWSClientHandshake(header)
|
||||
|
||||
# Client data should be masked.
|
||||
return WebSocket(tcpSocket: client.transp, readyState: Open, masked: true,
|
||||
rng: newRng())
|
||||
|
||||
proc newWebsocketClient*(host: string, port: Port, path: string,
|
||||
protocols: seq[string] = @[]): Future[WebSocket] {.async.} =
|
||||
var uri = "ws://" & host & ":" & $port
|
||||
if path.startsWith("/"):
|
||||
uri.add path
|
||||
else:
|
||||
uri.add "/" & path
|
||||
return await newWebsocketClient(parseUri(uri), protocols)
|
@ -1,7 +0,0 @@
|
||||
import ws, nativesockets, chronos
|
||||
|
||||
discard waitFor newAsyncWebsocketClient("localhost", Port(8080), path = "/", protocols = @["myfancyprotocol"])
|
||||
echo "connected"
|
||||
|
||||
runForever()
|
||||
|
@ -1 +0,0 @@
|
||||
switch("path", "$projectDir/../src")
|
@ -1,20 +0,0 @@
|
||||
import ws, chronos, chronicles, httputils
|
||||
|
||||
proc cb(transp: StreamTransport, header: HttpRequestHeader) {.async.} =
|
||||
info "Header: ", uri = header.uri()
|
||||
if header.uri() == "/ws":
|
||||
info "Initiating web socket connection."
|
||||
try:
|
||||
var ws = await newWebSocket(header, transp)
|
||||
echo await ws.receivePacket()
|
||||
info "Websocket handshake completed."
|
||||
except WebSocketError:
|
||||
echo "socket closed:", getCurrentExceptionMsg()
|
||||
|
||||
let res = await transp.sendHTTPResponse(HttpVersion11, Http200, "Hello World")
|
||||
|
||||
when isMainModule:
|
||||
let address = "127.0.0.1:8888"
|
||||
var httpServer = newHttpServer(address, cb)
|
||||
httpServer.start()
|
||||
waitFor httpServer.join()
|
76
tests/frame.nim
Normal file
76
tests/frame.nim
Normal file
@ -0,0 +1,76 @@
|
||||
include ../src/ws
|
||||
include ../src/http
|
||||
include ../src/random
|
||||
#import chronos, chronicles, httputils, strutils, base64, std/sha1,
|
||||
# streams, nativesockets, uri, times, chronos/timer, tables
|
||||
|
||||
import unittest
|
||||
|
||||
# TODO: Fix Test.
|
||||
|
||||
var maskKey: array[4, char]
|
||||
|
||||
suite "tests for encodeFrame()":
|
||||
test "# 7bit length":
|
||||
block: # 7bit length
|
||||
assert encodeFrame((
|
||||
fin: true,
|
||||
rsv1: false,
|
||||
rsv2: false,
|
||||
rsv3: false,
|
||||
opcode: Opcode.Text,
|
||||
mask: false,
|
||||
data: toBytes("hi there"),
|
||||
maskKey: maskKey
|
||||
)) == toBytes("\129\8hi there")
|
||||
test "# 7bit length":
|
||||
block: # 7+16 bits length
|
||||
var data = ""
|
||||
for i in 0..32:
|
||||
data.add "How are you this is the payload!!!"
|
||||
assert encodeFrame((
|
||||
fin: true,
|
||||
rsv1: false,
|
||||
rsv2: false,
|
||||
rsv3: false,
|
||||
opcode: Opcode.Text,
|
||||
mask: false,
|
||||
data: toBytes(data),
|
||||
maskKey: maskKey
|
||||
))[0..32] == toBytes("\129~\4bHow are you this is the paylo")
|
||||
test "# 7+64 bits length":
|
||||
block: # 7+64 bits length
|
||||
var data = ""
|
||||
for i in 0..3200:
|
||||
data.add "How are you this is the payload!!!"
|
||||
assert encodeFrame((
|
||||
fin: true,
|
||||
rsv1: false,
|
||||
rsv2: false,
|
||||
rsv3: false,
|
||||
opcode: Opcode.Text,
|
||||
mask: false,
|
||||
data: toBytes(data),
|
||||
maskKey: maskKey
|
||||
))[0..32] == toBytes("\129\127\0\0\0\0\0\1\169\"How are you this is the")
|
||||
test "# masking":
|
||||
block: # masking
|
||||
let data = encodeFrame((
|
||||
fin: true,
|
||||
rsv1: false,
|
||||
rsv2: false,
|
||||
rsv3: false,
|
||||
opcode: Opcode.Text,
|
||||
mask: true,
|
||||
data: toBytes("hi there"),
|
||||
maskKey: ['\xCF', '\xD8', '\x05', 'e']
|
||||
))
|
||||
assert data == toBytes("\129\136\207\216\5e\167\177%\17\167\189w\0")
|
||||
|
||||
suite "tests for toTitleCase()":
|
||||
block:
|
||||
let val = toTitleCase("webSocket")
|
||||
assert val == "Websocket"
|
||||
|
||||
|
||||
|
55
tests/helpers.nim
Normal file
55
tests/helpers.nim
Normal file
@ -0,0 +1,55 @@
|
||||
import ../src/ws, chronos, chronicles, httputils, stew/byteutils,
|
||||
../src/http, unittest, strutils
|
||||
|
||||
proc cb*(transp: StreamTransport, header: HttpRequestHeader) {.async.} =
|
||||
info "Handling request:", uri = header.uri()
|
||||
if header.uri() == "/ws":
|
||||
info "Initiating web socket connection."
|
||||
try:
|
||||
var ws = await newWebSocket(header, transp, "myfancyprotocol")
|
||||
if ws.readyState == Open:
|
||||
info "Websocket handshake completed."
|
||||
else:
|
||||
error "Failed to open websocket connection."
|
||||
return
|
||||
|
||||
while ws.readyState == Open:
|
||||
let recvData = await ws.receiveStrPacket()
|
||||
info "Server:", state = ws.readyState
|
||||
await ws.send(recvData)
|
||||
except WebSocketError:
|
||||
error "WebSocket error:", exception = getCurrentExceptionMsg()
|
||||
discard await transp.sendHTTPResponse(HttpVersion11, Http200, "Connection established")
|
||||
|
||||
proc sendRecvClientData*(wsClient: WebSocket, msg: string) {.async.} =
|
||||
try:
|
||||
waitFor wsClient.sendStr(msg)
|
||||
let recvData = waitFor wsClient.receiveStrPacket()
|
||||
info "Websocket client state: ", state = wsClient.readyState
|
||||
let dataStr = string.fromBytes(recvData)
|
||||
require dataStr == msg
|
||||
|
||||
except WebSocketError:
|
||||
error "WebSocket error:", exception = getCurrentExceptionMsg()
|
||||
|
||||
proc incorrectProtocolCB*(transp: StreamTransport, header: HttpRequestHeader) {.async.} =
|
||||
info "Handling request:", uri = header.uri()
|
||||
var isErr = false;
|
||||
if header.uri() == "/ws":
|
||||
info "Initiating web socket connection."
|
||||
try:
|
||||
var ws = await newWebSocket(header, transp, "myfancyprotocol")
|
||||
require ws.readyState == ReadyState.Closed
|
||||
except WebSocketError:
|
||||
isErr = true;
|
||||
require contains(getCurrentExceptionMsg(), "Protocol mismatch")
|
||||
finally:
|
||||
require isErr == true
|
||||
discard await transp.sendHTTPResponse(HttpVersion11, Http200, "Connection established")
|
||||
|
||||
|
||||
proc generateData*(num: int64): seq[byte] =
|
||||
var str = newSeqOfCap[byte](num)
|
||||
for i in 0 ..< num:
|
||||
str.add(65)
|
||||
return str
|
87
tests/websocket.nim
Normal file
87
tests/websocket.nim
Normal file
@ -0,0 +1,87 @@
|
||||
import helpers, unittest, ../src/http, chronos, ../src/ws,../src/random,
|
||||
stew/byteutils, os, strutils
|
||||
|
||||
var httpServer: HttpServer
|
||||
|
||||
proc startServer() {.async, gcsafe.} =
|
||||
httpServer = newHttpServer("127.0.0.1:8888", cb)
|
||||
httpServer.start()
|
||||
|
||||
proc closeServer() {.async, gcsafe.} =
|
||||
httpServer.stop()
|
||||
waitFor httpServer.closeWait()
|
||||
|
||||
suite "Test websocket error cases":
|
||||
teardown:
|
||||
httpServer.stop()
|
||||
waitFor httpServer.closeWait()
|
||||
|
||||
test "Test for incorrect protocol":
|
||||
httpServer = newHttpServer("127.0.0.1:8888", incorrectProtocolCB)
|
||||
httpServer.start()
|
||||
try:
|
||||
let wsClient = waitFor newWebsocketClient("127.0.0.1", Port(8888),
|
||||
path = "/ws", protocols = @["mywrongprotocol"])
|
||||
except WebSocketError:
|
||||
require contains(getCurrentExceptionMsg(), "Server did not reply with a websocket upgrade")
|
||||
|
||||
test "Test for incorrect port":
|
||||
httpServer = newHttpServer("127.0.0.1:8888", cb)
|
||||
httpServer.start()
|
||||
try:
|
||||
let wsClient = waitFor newWebsocketClient("127.0.0.1", Port(8889),
|
||||
path = "/ws", protocols = @["myfancyprotocol"])
|
||||
except:
|
||||
require contains(getCurrentExceptionMsg(), "Connection refused")
|
||||
|
||||
test "Test for incorrect path":
|
||||
httpServer = newHttpServer("127.0.0.1:8888", cb)
|
||||
httpServer.start()
|
||||
try:
|
||||
let wsClient = waitFor newWebsocketClient("127.0.0.1", Port(8888),
|
||||
path = "/gg", protocols = @["myfancyprotocol"])
|
||||
except:
|
||||
require contains(getCurrentExceptionMsg(), "Server did not reply with a websocket upgrade")
|
||||
|
||||
suite "Misc Test":
|
||||
setup:
|
||||
waitFor startServer()
|
||||
teardown:
|
||||
waitFor closeServer()
|
||||
|
||||
test "Test for maskKey":
|
||||
let wsClient = waitFor newWebsocketClient("127.0.0.1", Port(8888), path = "/ws",
|
||||
protocols = @["myfancyprotocol"])
|
||||
let maskKey = genMaskKey(wsClient.rng)
|
||||
require maskKey.len == 4
|
||||
|
||||
test "Test for toCaseInsensitive":
|
||||
let headers = newHttpHeaders()
|
||||
require toCaseInsensitive(headers, "webSocket") == "Websocket"
|
||||
|
||||
|
||||
suite "Test web socket communication":
|
||||
|
||||
setup:
|
||||
waitFor startServer()
|
||||
let wsClient = waitFor newWebsocketClient("127.0.0.1", Port(8888),
|
||||
path = "/ws", protocols = @["myfancyprotocol"])
|
||||
|
||||
teardown:
|
||||
waitFor closeServer()
|
||||
|
||||
test "Websocket conversation between client and server":
|
||||
waitFor sendRecvClientData(wsClient, "Hello Server")
|
||||
|
||||
test "Test for small message ":
|
||||
let msg = string.fromBytes(generateData(100))
|
||||
waitFor sendRecvClientData(wsClient, msg)
|
||||
|
||||
test "Test for medium message ":
|
||||
let msg = string.fromBytes(generateData(1000))
|
||||
waitFor sendRecvClientData(wsClient, msg)
|
||||
|
||||
test "Test for large message ":
|
||||
let msg = string.fromBytes(generateData(10000))
|
||||
waitFor sendRecvClientData(wsClient, msg)
|
||||
|
@ -4,12 +4,15 @@ author = "Status Research & Development GmbH"
|
||||
description = "WS protocol implementation"
|
||||
license = "MIT"
|
||||
|
||||
requires "nim >= 1.2.6"
|
||||
requires "chronos >= 2.5.2 & < 3.0.0"
|
||||
requires "nim == 1.2.6"
|
||||
requires "chronos >= 2.5.2"
|
||||
requires "httputils >= 0.2.0"
|
||||
requires "chronicles >= 0.10.0"
|
||||
requires "urlly >= 0.2.0"
|
||||
requires "uri"
|
||||
requires "stew >= 0.1.0"
|
||||
requires "eth"
|
||||
requires "asynctest >= 0.2.0 & < 0.3.0"
|
||||
requires "nimcrypto"
|
||||
|
||||
task lint, "format source files according to the official style guide":
|
||||
exec "./lint.nims"
|
||||
|
Loading…
x
Reference in New Issue
Block a user