Extract session (#31)

* extract websocket session

* fix tests

* fix frame tests
This commit is contained in:
Dmitriy Ryajov 2021-05-25 16:39:10 -06:00 committed by GitHub
parent 0f48b62eb9
commit eb62ec1725
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
14 changed files with 600 additions and 577 deletions

View File

@ -3,7 +3,7 @@ import pkg/[
chronicles,
stew/byteutils]
import ../ws/ws, ../ws/errors
import ../ws/ws
proc main() {.async.} =
let ws = await WebSocket.connect(

View File

@ -3,7 +3,7 @@ import pkg/[chronos,
chronicles,
httputils]
import ../ws/[ws, frame, errors]
import ../ws/ws
proc process(r: RequestFence): Future[HttpResponseRef] {.async.} =
if r.isOk():

View File

@ -3,7 +3,7 @@ import pkg/[chronos,
chronicles,
stew/byteutils]
import ../ws/ws, ../ws/errors
import ../ws/ws
proc main() {.async.} =
let ws = await WebSocket.tlsConnect(
@ -31,4 +31,5 @@ proc main() {.async.} =
# close the websocket
await ws.close()
waitFor(main())

View File

@ -4,7 +4,7 @@ import pkg/[chronos,
httputils,
stew/byteutils]
import ../ws/[ws, frame, errors]
import ../ws/ws
import ../tests/keys
let secureKey = TLSPrivateKey.init(SecureKey)

View File

@ -1,4 +1,4 @@
import unittest, stew/byteutils
import pkg/[asynctest, stew/byteutils]
include ../ws/frame
include ../ws/utils
@ -9,7 +9,7 @@ var maskKey: array[4, char]
suite "Test data frames":
test "# 7bit length text":
check Frame(
check (await Frame(
fin: false,
rsv1: false,
rsv2: false,
@ -17,10 +17,10 @@ suite "Test data frames":
opcode: Opcode.Text,
mask: false,
data: toBytes("hi there")
).encode() == toBytes("\1\8hi there")
).encode()) == toBytes("\1\8hi there")
test "# 7bit length text fin bit":
check Frame(
check (await Frame(
fin: true,
rsv1: false,
rsv2: false,
@ -29,10 +29,10 @@ suite "Test data frames":
mask: false,
data: toBytes("hi there"),
maskKey: maskKey
).encode() == toBytes("\129\8hi there")
).encode()) == toBytes("\129\8hi there")
test "# 7bit length binary":
check Frame(
check (await Frame(
fin: false,
rsv1: false,
rsv2: false,
@ -41,10 +41,10 @@ suite "Test data frames":
mask: false,
data: toBytes("hi there"),
maskKey: maskKey
).encode() == toBytes("\2\8hi there")
).encode()) == toBytes("\2\8hi there")
test "# 7bit length binary fin bit":
check Frame(
check (await Frame(
fin: true,
rsv1: false,
rsv2: false,
@ -53,10 +53,10 @@ suite "Test data frames":
mask: false,
data: toBytes("hi there"),
maskKey: maskKey
).encode() == toBytes("\130\8hi there")
).encode()) == toBytes("\130\8hi there")
test "# 7bit length continuation":
check Frame(
check (await Frame(
fin: false,
rsv1: false,
rsv2: false,
@ -65,14 +65,14 @@ suite "Test data frames":
mask: false,
data: toBytes("hi there"),
maskKey: maskKey
).encode() == toBytes("\0\8hi there")
).encode()) == toBytes("\0\8hi there")
test "# 7+16 length text":
var data = ""
for i in 0..32:
data.add "How are you this is the payload!!!"
check Frame(
check (await Frame(
fin: false,
rsv1: false,
rsv2: false,
@ -81,14 +81,14 @@ suite "Test data frames":
mask: false,
data: toBytes(data),
maskKey: maskKey
).encode() == toBytes("\1\126\4\98" & data)
).encode()) == toBytes("\1\126\4\98" & data)
test "# 7+16 length text fin bit":
var data = ""
for i in 0..32:
data.add "How are you this is the payload!!!"
check Frame(
check (await Frame(
fin: false,
rsv1: false,
rsv2: false,
@ -97,14 +97,14 @@ suite "Test data frames":
mask: false,
data: toBytes(data),
maskKey: maskKey
).encode() == toBytes("\1\126\4\98" & data)
).encode()) == toBytes("\1\126\4\98" & data)
test "# 7+16 length binary":
var data = ""
for i in 0..32:
data.add "How are you this is the payload!!!"
check Frame(
check (await Frame(
fin: false,
rsv1: false,
rsv2: false,
@ -113,14 +113,14 @@ suite "Test data frames":
mask: false,
data: toBytes(data),
maskKey: maskKey
).encode() == toBytes("\2\126\4\98" & data)
).encode()) == toBytes("\2\126\4\98" & data)
test "# 7+16 length binary fin bit":
var data = ""
for i in 0..32:
data.add "How are you this is the payload!!!"
check Frame(
check (await Frame(
fin: true,
rsv1: false,
rsv2: false,
@ -129,14 +129,14 @@ suite "Test data frames":
mask: false,
data: toBytes(data),
maskKey: maskKey
).encode() == toBytes("\130\126\4\98" & data)
).encode()) == toBytes("\130\126\4\98" & data)
test "# 7+16 length continuation":
var data = ""
for i in 0..32:
data.add "How are you this is the payload!!!"
check Frame(
check (await Frame(
fin: false,
rsv1: false,
rsv2: false,
@ -145,14 +145,14 @@ suite "Test data frames":
mask: false,
data: toBytes(data),
maskKey: maskKey
).encode() == toBytes("\0\126\4\98" & data)
).encode()) == toBytes("\0\126\4\98" & data)
test "# 7+64 length text":
var data = ""
for i in 0..3200:
data.add "How are you this is the payload!!!"
check Frame(
check (await Frame(
fin: false,
rsv1: false,
rsv2: false,
@ -161,14 +161,14 @@ suite "Test data frames":
mask: false,
data: toBytes(data),
maskKey: maskKey
).encode() == toBytes("\1\127\0\0\0\0\0\1\169\34" & data)
).encode()) == toBytes("\1\127\0\0\0\0\0\1\169\34" & data)
test "# 7+64 length fin bit":
var data = ""
for i in 0..3200:
data.add "How are you this is the payload!!!"
check Frame(
check (await Frame(
fin: true,
rsv1: false,
rsv2: false,
@ -177,14 +177,14 @@ suite "Test data frames":
mask: false,
data: toBytes(data),
maskKey: maskKey
).encode() == toBytes("\129\127\0\0\0\0\0\1\169\34" & data)
).encode()) == toBytes("\129\127\0\0\0\0\0\1\169\34" & data)
test "# 7+64 length binary":
var data = ""
for i in 0..3200:
data.add "How are you this is the payload!!!"
check Frame(
check (await Frame(
fin: false,
rsv1: false,
rsv2: false,
@ -193,14 +193,14 @@ suite "Test data frames":
mask: false,
data: toBytes(data),
maskKey: maskKey
).encode() == toBytes("\2\127\0\0\0\0\0\1\169\34" & data)
).encode()) == toBytes("\2\127\0\0\0\0\0\1\169\34" & data)
test "# 7+64 length binary fin bit":
var data = ""
for i in 0..3200:
data.add "How are you this is the payload!!!"
check Frame(
check (await Frame(
fin: true,
rsv1: false,
rsv2: false,
@ -209,14 +209,14 @@ suite "Test data frames":
mask: false,
data: toBytes(data),
maskKey: maskKey
).encode() == toBytes("\130\127\0\0\0\0\0\1\169\34" & data)
).encode()) == toBytes("\130\127\0\0\0\0\0\1\169\34" & data)
test "# 7+64 length binary":
var data = ""
for i in 0..3200:
data.add "How are you this is the payload!!!"
check Frame(
check (await Frame(
fin: false,
rsv1: false,
rsv2: false,
@ -225,10 +225,10 @@ suite "Test data frames":
mask: false,
data: toBytes(data),
maskKey: maskKey
).encode() == toBytes("\0\127\0\0\0\0\0\1\169\34" & data)
).encode()) == toBytes("\0\127\0\0\0\0\0\1\169\34" & data)
test "# masking":
let data = Frame(
let data = (await Frame(
fin: true,
rsv1: false,
rsv2: false,
@ -237,14 +237,14 @@ suite "Test data frames":
mask: true,
data: toBytes("hi there"),
maskKey: ['\xCF', '\xD8', '\x05', 'e']
).encode()
).encode())
check data == toBytes("\129\136\207\216\5e\167\177%\17\167\189w\0")
suite "Test control frames":
test "Close":
check Frame(
check (await Frame(
fin: true,
rsv1: false,
rsv2: false,
@ -253,10 +253,10 @@ suite "Test control frames":
mask: false,
data: @[3'u8, 232'u8] & toBytes("hi there"),
maskKey: maskKey
).encode() == toBytes("\136\10\3\232hi there")
).encode()) == toBytes("\136\10\3\232hi there")
test "Ping":
check Frame(
check (await Frame(
fin: false,
rsv1: false,
rsv2: false,
@ -264,10 +264,10 @@ suite "Test control frames":
opcode: Opcode.Ping,
mask: false,
maskKey: maskKey
).encode() == toBytes("\9\0")
).encode()) == toBytes("\9\0")
test "Pong":
check Frame(
check (await Frame(
fin: false,
rsv1: false,
rsv2: false,
@ -275,4 +275,4 @@ suite "Test control frames":
opcode: Opcode.Pong,
mask: false,
maskKey: maskKey
).encode() == toBytes("\10\0")
).encode()) == toBytes("\10\0")

View File

@ -6,8 +6,7 @@ import pkg/[asynctest,
chronos/apps/http/shttpserver,
stew/byteutils]
import ../ws/[ws, stream, errors],
../examples/tlsserver
import ../ws/ws, ../examples/tlsserver
import ./keys

View File

@ -1,12 +1,13 @@
import std/[strutils, random], httputils
import pkg/[asynctest,
chronos,
chronos/apps/http/httpserver,
chronicles,
stew/byteutils]
import pkg/[
asynctest,
chronos,
chronos/apps/http/httpserver,
chronicles,
stew/byteutils]
import ../ws/[ws, stream, utils, frame, errors]
import ../ws/ws
var server: HttpServerRef
let address = initTAddress("127.0.0.1:8888")
@ -138,7 +139,6 @@ suite "Test handshake":
suite "Test transmission":
teardown:
await server.stop()
await server.closeWait()
test "Send text message message with payload of length 65535":
@ -279,7 +279,7 @@ suite "Test ping-pong":
let maskKey = genMaskKey(newRng())
await wsClient.stream.writer.write(
Frame(
(await Frame(
fin: false,
rsv1: false,
rsv2: false,
@ -288,12 +288,12 @@ suite "Test ping-pong":
mask: true,
data: msg[0..4],
maskKey: maskKey)
.encode())
.encode()))
await wsClient.ping()
await wsClient.stream.writer.write(
Frame(
(await Frame(
fin: true,
rsv1: false,
rsv2: false,
@ -302,7 +302,7 @@ suite "Test ping-pong":
mask: true,
data: msg[5..9],
maskKey: maskKey)
.encode())
.encode()))
await wsClient.close()
check:

View File

@ -1,31 +0,0 @@
## Nim-Libp2p
## Copyright (c) 2021 Status Research & Development GmbH
## Licensed under either of
## * Apache License, version 2.0, ([LICENSE-APACHE](LICENSE-APACHE))
## * MIT license ([LICENSE-MIT](LICENSE-MIT))
## at your option.
## This file may not be copied, modified, or distributed except according to
## those terms.
{.push raises: [Defect].}
type
WebSocketError* = object of CatchableError
WSMalformedHeaderError* = object of WebSocketError
WSFailedUpgradeError* = object of WebSocketError
WSVersionError* = object of WebSocketError
WSProtoMismatchError* = object of WebSocketError
WSMaskMismatchError* = object of WebSocketError
WSHandshakeError* = object of WebSocketError
WSOpcodeMismatchError* = object of WebSocketError
WSRsvMismatchError* = object of WebSocketError
WSWrongUriSchemeError* = object of WebSocketError
WSMaxMessageSizeError* = object of WebSocketError
WSClosedError* = object of WebSocketError
WSSendError* = object of WebSocketError
WSPayloadTooLarge* = object of WebSocketError
WSReserverdOpcodeError* = object of WebSocketError
WSFragmentedControlFrameError* = object of WebSocketError
WSInvalidCloseCodeError* = object of WebSocketError
WSPayloadLengthError* = object of WebSocketError
WSInvalidOpcodeError* = object of WebSocketError

View File

@ -1,26 +0,0 @@
## Nim-Libp2p
## Copyright (c) 2021 Status Research & Development GmbH
## Licensed under either of
## * Apache License, version 2.0, ([LICENSE-APACHE](LICENSE-APACHE))
## * MIT license ([LICENSE-MIT](LICENSE-MIT))
## at your option.
## This file may not be copied, modified, or distributed except according to
## those terms.
{.push raises: [Defect].}
import pkg/[chronos, chronicles]
import ./frame
type
Extension* = ref object of RootObj
name*: string
proc `name=`*(self: Extension, name: string) =
raiseAssert "Can't change extensions name!"
method decode*(self: Extension, frame: Frame): Future[Frame] {.base, async.} =
raiseAssert "Not implemented!"
method encode*(self: Extension, frame: Frame): Future[Frame] {.base, async.} =
raiseAssert "Not implemented!"

View File

@ -1,5 +1,5 @@
## Nim-Libp2p
## Copyright (c) 2020 Status Research & Development GmbH
## Copyright (c) 2021 Status Research & Development GmbH
## Licensed under either of
## * Apache License, version 2.0, ([LICENSE-APACHE](LICENSE-APACHE))
## * MIT license ([LICENSE-MIT](LICENSE-MIT))
@ -10,7 +10,7 @@
{.push raises: [Defect].}
import pkg/[chronos, chronicles, stew/endians2, stew/results]
import ./errors
import ./types
#[
+---------------------------------------------------------------+
@ -34,40 +34,6 @@ import ./errors
+---------------------------------------------------------------+
]#
type
Opcode* {.pure.} = enum
## 4 bits. Defines the interpretation of the "Payload data".
Cont = 0x0 ## Denotes a continuation frame.
Text = 0x1 ## Denotes a text frame.
Binary = 0x2 ## Denotes a binary frame.
# 3-7 are reserved for further non-control frames.
Close = 0x8 ## Denotes a connection close.
Ping = 0x9 ## Denotes a ping.
Pong = 0xa ## Denotes a pong.
# B-F are reserved for further control frames.
HeaderFlag* {.pure, size: sizeof(uint8).} = enum
rsv3
rsv2
rsv1
fin
HeaderFlags = set[HeaderFlag]
MaskKey = array[4, char]
Frame* = ref object
fin*: bool ## Indicates that this is the final fragment in a message.
rsv1*: bool ## MUST be 0 unless negotiated that defines meanings
rsv2*: bool ## MUST be 0
rsv3*: bool ## MUST be 0
opcode*: Opcode ## Defines the interpretation of the "Payload data".
mask*: bool ## Defines whether the "Payload data" is masked.
data*: seq[byte] ## Payload data
maskKey*: MaskKey ## Masking key
length*: uint64 ## Message size.
consumed*: uint64 ## how much has been consumed from the frame
proc mask*(
data: var openArray[byte],
maskKey: MaskKey,
@ -78,10 +44,21 @@ proc mask*(
for i in 0 ..< data.len:
data[i] = (data[i].uint8 xor maskKey[(offset + i) mod 4].uint8)
proc encode*(f: Frame, offset = 0): seq[byte] =
template remainder*(frame: Frame): uint64 =
frame.length - frame.consumed
proc encode*(
frame: Frame,
offset = 0,
extensions: seq[Extension] = @[]): Future[seq[byte]] {.async.} =
## Encodes a frame into a string buffer.
## See https://tools.ietf.org/html/rfc6455#section-5.2
var f = frame
if extensions.len > 0:
for e in extensions:
f = await e.encode(f)
var ret: seq[byte]
var b0 = (f.opcode.uint8 and 0x0F) # 0th byte: opcodes and flags.
if f.fin:
@ -135,8 +112,8 @@ proc encode*(f: Frame, offset = 0): seq[byte] =
proc decode*(
_: typedesc[Frame],
reader: AsyncStreamReader,
masked: bool):
Future[Frame] {.async.} =
masked: bool,
extensions: seq[Extension] = @[]): Future[Frame] {.async.} =
## Read and Decode incoming header
##
@ -204,4 +181,8 @@ proc decode*(
for i in 0..<maskKey.len:
frame.maskKey[i] = cast[char](maskKey[i])
if extensions.len > 0:
for e in extensions[extensions.high..extensions.low]:
frame = await e.decode(frame)
return frame

380
ws/session.nim Normal file
View File

@ -0,0 +1,380 @@
## Nim-Libp2p
## Copyright (c) 2021 Status Research & Development GmbH
## Licensed under either of
## * Apache License, version 2.0, ([LICENSE-APACHE](LICENSE-APACHE))
## * MIT license ([LICENSE-MIT](LICENSE-MIT))
## at your option.
## This file may not be copied, modified, or distributed except according to
## those terms.
{.push raises: [Defect].}
import pkg/[chronos, chronicles, stew/byteutils, stew/endians2]
import ./types, ./frame, ./utils, ./stream
import pkg/chronos/[
streams/asyncstream,
streams/tlsstream]
type
WSSession* = ref object of WebSocket
stream*: AsyncStream
frame*: Frame
proc prepareCloseBody(code: Status, reason: string): seq[byte] =
result = reason.toBytes
if ord(code) > 999:
result = @(ord(code).uint16.toBytesBE()) & result
proc send*(
ws: WSSession,
data: seq[byte] = @[],
opcode: Opcode) {.async.} =
## Send a frame
##
if ws.readyState == ReadyState.Closed:
raise newException(WSClosedError, "Socket is closed!")
logScope:
opcode = opcode
dataSize = data.len
masked = ws.masked
debug "Sending data to remote"
var maskKey: array[4, char]
if ws.masked:
maskKey = genMaskKey(ws.rng)
if opcode notin {Opcode.Text, Opcode.Cont, Opcode.Binary}:
if ws.readyState in {ReadyState.Closing} and opcode notin {Opcode.Close}:
return
await ws.stream.writer.write(
(await Frame(
fin: true,
rsv1: false,
rsv2: false,
rsv3: false,
opcode: opcode,
mask: ws.masked,
data: data, # allow sending data with close messages
maskKey: maskKey)
.encode()))
return
let maxSize = ws.frameSize
var i = 0
while ws.readyState notin {ReadyState.Closing}:
let len = min(data.len, (maxSize + i))
await ws.stream.writer.write(
(await Frame(
fin: if (i + len >= data.len): true else: false,
rsv1: false,
rsv2: false,
rsv3: false,
opcode: if i > 0: Opcode.Cont else: opcode, # fragments have to be `Continuation` frames
mask: ws.masked,
data: data[i ..< len],
maskKey: maskKey)
.encode()))
i += len
if i >= data.len:
break
proc send*(ws: WSSession, data: string): Future[void] =
send(ws, data.toBytes(), Opcode.Text)
proc handleClose*(
ws: WSSession,
frame: Frame,
payLoad: seq[byte] = @[]) {.async.} =
## Handle close sequence
##
logScope:
fin = frame.fin
masked = frame.mask
opcode = frame.opcode
readyState = ws.readyState
debug "Handling close sequence"
if ws.readyState notin {ReadyState.Open}:
debug "Connection isn't open, abortig close sequence!"
return
var
code = Status.Fulfilled
reason = ""
if payLoad.len == 1:
raise newException(WSPayloadLengthError,
"Invalid close frame with payload length 1!")
if payLoad.len > 1:
# first two bytes are the status
let ccode = uint16.fromBytesBE(payLoad[0..<2])
if ccode <= 999 or ccode > 1015:
raise newException(WSInvalidCloseCodeError,
"Invalid code in close message!")
try:
code = Status(ccode)
except RangeError:
raise newException(WSInvalidCloseCodeError,
"Status code out of range!")
# remining payload bytes are reason for closing
reason = string.fromBytes(payLoad[2..payLoad.high])
var rcode: Status
if code in {Status.Fulfilled}:
rcode = Status.Fulfilled
if not isNil(ws.onClose):
try:
(rcode, reason) = ws.onClose(code, reason)
except CatchableError as exc:
debug "Exception in Close callback, this is most likely a bug", exc = exc.msg
# don't respond to a terminated connection
if ws.readyState != ReadyState.Closing:
ws.readyState = ReadyState.Closing
await ws.send(prepareCloseBody(rcode, reason), Opcode.Close)
ws.readyState = ReadyState.Closed
await ws.stream.closeWait()
proc handleControl*(ws: WSSession, frame: Frame) {.async.} =
## Handle control frames
##
if not frame.fin:
raise newException(WSFragmentedControlFrameError,
"Control frame cannot be fragmented!")
if frame.length > 125:
raise newException(WSPayloadTooLarge,
"Control message payload is greater than 125 bytes!")
try:
var payLoad = newSeq[byte](frame.length.int)
if frame.length > 0:
payLoad.setLen(frame.length.int)
# Read control frame payload.
await ws.stream.reader.readExactly(addr payLoad[0], frame.length.int)
if frame.mask:
mask(
payLoad.toOpenArray(0, payLoad.high),
frame.maskKey)
# Process control frame payload.
case frame.opcode:
of Opcode.Ping:
if not isNil(ws.onPing):
try:
ws.onPing(payLoad)
except CatchableError as exc:
debug "Exception in Ping callback, this is most likelly a bug", exc = exc.msg
# send pong to remote
await ws.send(payLoad, Opcode.Pong)
of Opcode.Pong:
if not isNil(ws.onPong):
try:
ws.onPong(payLoad)
except CatchableError as exc:
debug "Exception in Pong callback, this is most likelly a bug", exc = exc.msg
of Opcode.Close:
await ws.handleClose(frame, payLoad)
else:
raise newException(WSInvalidOpcodeError, "Invalid control opcode!")
except WebSocketError as exc:
debug "Handled websocket exception", exc = exc.msg
raise exc
except CatchableError as exc:
trace "Exception handling control messages", exc = exc.msg
ws.readyState = ReadyState.Closed
await ws.stream.closeWait()
proc readFrame*(ws: WSSession): Future[Frame] {.async.} =
## Gets a frame from the WebSocket.
## See https://tools.ietf.org/html/rfc6455#section-5.2
##
try:
while ws.readyState != ReadyState.Closed:
let frame = await Frame.decode(ws.stream.reader, ws.masked)
debug "Decoded new frame", opcode = frame.opcode, len = frame.length, mask = frame.mask
# return the current frame if it's not one of the control frames
if frame.opcode notin {Opcode.Text, Opcode.Cont, Opcode.Binary}:
await ws.handleControl(frame) # process control frames# process control frames
continue
return frame
except WebSocketError as exc:
trace "Websocket error", exc = exc.msg
raise exc
except CatchableError as exc:
debug "Exception reading frame, dropping socket", exc = exc.msg
ws.readyState = ReadyState.Closed
await ws.stream.closeWait()
raise exc
proc ping*(ws: WSSession, data: seq[byte] = @[]): Future[void] =
ws.send(data, opcode = Opcode.Ping)
proc recv*(
ws: WSSession,
data: pointer,
size: int): Future[int] {.async.} =
## Attempts to read up to `size` bytes
##
## Will read as many frames as necessary
## to fill the buffer until either
## the message ends (frame.fin) or
## the buffer is full. If no data is on
## the pipe will await until at least
## one byte is available
##
var consumed = 0
var pbuffer = cast[ptr UncheckedArray[byte]](data)
try:
while consumed < size:
# we might have to read more than
# one frame to fill the buffer
# TODO: Figure out a cleaner way to handle
# retrieving new frames
if isNil(ws.frame):
ws.frame = await ws.readFrame()
if isNil(ws.frame):
return consumed
if ws.frame.opcode == Opcode.Cont:
raise newException(WSOpcodeMismatchError,
"Expected Text or Binary frame")
elif (not ws.frame.fin and ws.frame.remainder() <= 0):
ws.frame = await ws.readFrame()
# This could happen if the connection is closed.
if isNil(ws.frame):
return consumed
if ws.frame.opcode != Opcode.Cont:
raise newException(WSOpcodeMismatchError,
"Expected Continuation frame")
ws.binary = ws.frame.opcode == Opcode.Binary # set binary flag
if ws.frame.fin and ws.frame.remainder() <= 0:
ws.frame = nil
break
let len = min(ws.frame.remainder().int, size - consumed)
if len == 0:
continue
let read = await ws.stream.reader.readOnce(addr pbuffer[consumed], len)
if read <= 0:
continue
if ws.frame.mask:
# unmask data using offset
mask(
pbuffer.toOpenArray(consumed, (consumed + read) - 1),
ws.frame.maskKey,
ws.frame.consumed.int)
consumed += read
ws.frame.consumed += read.uint64
return consumed.int
except WebSocketError as exc:
debug "Websocket error", exc = exc.msg
ws.readyState = ReadyState.Closed
await ws.stream.closeWait()
raise exc
except CancelledError as exc:
debug "Cancelling reading", exc = exc.msg
raise exc
except CatchableError as exc:
debug "Exception reading frames", exc = exc.msg
proc recv*(
ws: WSSession,
size = WSMaxMessageSize): Future[seq[byte]] {.async.} =
## Attempt to read a full message up to max `size`
## bytes in `frameSize` chunks.
##
## If no `fin` flag arrives await until either
## cancelled or the `fin` flag arrives.
##
## If message is larger than `size` a `WSMaxMessageSizeError`
## exception is thrown.
##
## In all other cases it awaits a full message.
##
var res: seq[byte]
try:
while ws.readyState != ReadyState.Closed:
var buf = newSeq[byte](ws.frameSize)
let read = await ws.recv(addr buf[0], buf.len)
if read <= 0:
break
buf.setLen(read)
if res.len + buf.len > size:
raise newException(WSMaxMessageSizeError, "Max message size exceeded")
res.add(buf)
# no more frames
if isNil(ws.frame):
break
# read the entire message, exit
if ws.frame.fin and ws.frame.remainder().int <= 0:
break
except WebSocketError as exc:
debug "Websocket error", exc = exc.msg
raise exc
except CancelledError as exc:
debug "Cancelling reading", exc = exc.msg
raise exc
except CatchableError as exc:
debug "Exception reading frames", exc = exc.msg
return res
proc close*(
ws: WSSession,
code: Status = Status.Fulfilled,
reason: string = "") {.async.} =
## Close the Socket, sends close packet.
##
if ws.readyState != ReadyState.Open:
return
try:
ws.readyState = ReadyState.Closing
await ws.send(
prepareCloseBody(code, reason),
opcode = Opcode.Close)
# read frames until closed
while ws.readyState != ReadyState.Closed:
discard await ws.recv()
except CatchableError as exc:
debug "Exception closing", exc = exc.msg

136
ws/types.nim Normal file
View File

@ -0,0 +1,136 @@
## Nim-Libp2p
## Copyright (c) 2021 Status Research & Development GmbH
## Licensed under either of
## * Apache License, version 2.0, ([LICENSE-APACHE](LICENSE-APACHE))
## * MIT license ([LICENSE-MIT](LICENSE-MIT))
## at your option.
## This file may not be copied, modified, or distributed except according to
## those terms.
{.push raises: [Defect].}
import chronos
import ./utils
const
SHA1DigestSize* = 20
WSHeaderSize* = 12
WSDefaultVersion* = 13
WSDefaultFrameSize* = 1 shl 20 # 1mb
WSMaxMessageSize* = 20 shl 20 # 20mb
WSGuid* = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"
CRLF* = "\r\n"
type
ReadyState* {.pure.} = enum
Connecting = 0 # The connection is not yet open.
Open = 1 # The connection is open and ready to communicate.
Closing = 2 # The connection is in the process of closing.
Closed = 3 # The connection is closed or couldn't be opened.
Opcode* {.pure.} = enum
## 4 bits. Defines the interpretation of the "Payload data".
Cont = 0x0 ## Denotes a continuation frame.
Text = 0x1 ## Denotes a text frame.
Binary = 0x2 ## Denotes a binary frame.
# 3-7 are reserved for further non-control frames.
Close = 0x8 ## Denotes a connection close.
Ping = 0x9 ## Denotes a ping.
Pong = 0xa ## Denotes a pong.
# B-F are reserved for further control frames.
HeaderFlag* {.pure, size: sizeof(uint8).} = enum
rsv3
rsv2
rsv1
fin
HeaderFlags* = set[HeaderFlag]
MaskKey* = array[4, char]
Frame* = ref object
fin*: bool ## Indicates that this is the final fragment in a message.
rsv1*: bool ## MUST be 0 unless negotiated that defines meanings
rsv2*: bool ## MUST be 0
rsv3*: bool ## MUST be 0
opcode*: Opcode ## Defines the interpretation of the "Payload data".
mask*: bool ## Defines whether the "Payload data" is masked.
data*: seq[byte] ## Payload data
maskKey*: MaskKey ## Masking key
length*: uint64 ## Message size.
consumed*: uint64 ## how much has been consumed from the frame
Status* {.pure.} = enum
# 0-999 not used
Fulfilled = 1000
GoingAway = 1001
ProtocolError = 1002
CannotAccept = 1003
# 1004 reserved
NoStatus = 1005 # use by clients
ClosedAbnormally = 1006 # use by clients
Inconsistent = 1007
PolicyError = 1008
TooLarge = 1009
NoExtensions = 1010
UnexpectedError = 1011
ReservedCode = 3999 # use by clients
# 3000-3999 reserved for libs
# 4000-4999 reserved for applications
ControlCb* = proc(data: openArray[byte] = [])
{.gcsafe, raises: [Defect].}
CloseResult* = tuple
code: Status
reason: string
CloseCb* = proc(code: Status, reason: string):
CloseResult {.gcsafe, raises: [Defect].}
Extension* = ref object of RootObj
name*: string
WebSocket* = ref object of RootObj
extensions: seq[Extension] # extension active for this session
version*: uint
key*: string
proto*: string
readyState*: ReadyState
masked*: bool # send masked packets
binary*: bool # is payload binary?
rng*: Rng
frameSize*: int
onPing*: ControlCb
onPong*: ControlCb
onClose*: CloseCb
WebSocketError* = object of CatchableError
WSMalformedHeaderError* = object of WebSocketError
WSFailedUpgradeError* = object of WebSocketError
WSVersionError* = object of WebSocketError
WSProtoMismatchError* = object of WebSocketError
WSMaskMismatchError* = object of WebSocketError
WSHandshakeError* = object of WebSocketError
WSOpcodeMismatchError* = object of WebSocketError
WSRsvMismatchError* = object of WebSocketError
WSWrongUriSchemeError* = object of WebSocketError
WSMaxMessageSizeError* = object of WebSocketError
WSClosedError* = object of WebSocketError
WSSendError* = object of WebSocketError
WSPayloadTooLarge* = object of WebSocketError
WSReserverdOpcodeError* = object of WebSocketError
WSFragmentedControlFrameError* = object of WebSocketError
WSInvalidCloseCodeError* = object of WebSocketError
WSPayloadLengthError* = object of WebSocketError
WSInvalidOpcodeError* = object of WebSocketError
proc `name=`*(self: Extension, name: string) =
raiseAssert "Can't change extensions name!"
method decode*(self: Extension, frame: Frame): Future[Frame] {.base, async.} =
raiseAssert "Not implemented!"
method encode*(self: Extension, frame: Frame): Future[Frame] {.base, async.} =
raiseAssert "Not implemented!"

View File

@ -21,22 +21,22 @@ proc newRng*(): ref BrHmacDrbgContext =
return nil
rng
proc rand*(rng: var BrHmacDrbgContext, max: Natural): int =
proc rand*(rng: Rng, max: Natural): int =
if max == 0: return 0
var x: uint64
while true:
brHmacDrbgGenerate(addr rng, addr x, csize_t(sizeof(x)))
brHmacDrbgGenerate(addr rng[], addr x, csize_t(sizeof(x)))
if x < randMax - (randMax mod (uint64(max) + 1'u64)): # against modulo bias
return int(x mod (uint64(max) + 1'u64))
proc genMaskKey*(rng: ref BrHmacDrbgContext): array[4, char] =
proc genMaskKey*(rng: Rng): array[4, char] =
## Generates a random key of 4 random chars.
proc r(): char = char(rand(rng[], 255))
proc r(): char = char(rand(rng, 255))
return [r(), r(), r(), r()]
proc genWebSecKey*(rng: ref BrHmacDrbgContext): seq[byte] =
proc genWebSecKey*(rng: Rng): seq[byte] =
var key = newSeq[byte](16)
proc r(): byte = byte(rand(rng[], 255))
proc r(): byte = byte(rand(rng, 255))
## Generates a random key of 16 random chars.
for i in 0..15:
key.add(r())

421
ws/ws.nim
View File

@ -7,7 +7,6 @@
## This file may not be copied, modified, or distributed except according to
## those terms.
{.push raises: [Defect].}
import std/[tables,
@ -29,79 +28,17 @@ import pkg/[chronos,
stew/base10,
nimcrypto/sha]
import ./utils, ./stream, ./frame, ./errors, ./extension
import ./utils, ./stream, ./frame, ./session, /types
const
SHA1DigestSize* = 20
WSHeaderSize* = 12
WSDefaultVersion* = 13
WSDefaultFrameSize* = 1 shl 20 # 1mb
WSMaxMessageSize* = 20 shl 20 # 20mb
WSGuid* = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"
CRLF* = "\r\n"
export utils, session, frame, stream, types
type
ReadyState* {.pure.} = enum
Connecting = 0 # The connection is not yet open.
Open = 1 # The connection is open and ready to communicate.
Closing = 2 # The connection is in the process of closing.
Closed = 3 # The connection is closed or couldn't be opened.
HttpCode* = enum
Http101 = 101 # Switching Protocols
Status* {.pure.} = enum
# 0-999 not used
Fulfilled = 1000
GoingAway = 1001
ProtocolError = 1002
CannotAccept = 1003
# 1004 reserved
NoStatus = 1005 # use by clients
ClosedAbnormally = 1006 # use by clients
Inconsistent = 1007
PolicyError = 1008
TooLarge = 1009
NoExtensions = 1010
UnexpectedError = 1011
ReservedCode = 3999 # use by clients
# 3000-3999 reserved for libs
# 4000-4999 reserved for applications
ControlCb* = proc(data: openArray[byte] = [])
{.gcsafe, raises: [Defect].}
CloseResult* = tuple
code: Status
reason: string
CloseCb* = proc(code: Status, reason: string):
CloseResult {.gcsafe, raises: [Defect].}
WebSocket* = ref object of RootObj
extensions: seq[Extension] # extension active for this session
version*: uint
key*: string
proto*: string
readyState*: ReadyState
masked*: bool # send masked packets
binary*: bool # is payload binary?
rng*: ref BrHmacDrbgContext
frameSize: int
onPing: ControlCb
onPong: ControlCb
onClose: CloseCb
WSServer* = ref object of WebSocket
protocols: seq[string]
WSSession* = ref object of WebSocket
stream*: AsyncStream
frame*: Frame
template remainder*(frame: Frame): uint64 =
frame.length - frame.consumed
proc `$`(ht: HttpTables): string =
## Returns string representation of HttpTable/Ref.
var res = ""
@ -115,11 +52,6 @@ proc `$`(ht: HttpTables): string =
res.add(CRLF)
res
proc prepareCloseBody(code: Status, reason: string): seq[byte] =
result = reason.toBytes
if ord(code) > 999:
result = @(ord(code).uint16.toBytesBE()) & result
proc handshake*(
ws: WSServer,
request: HttpRequestRef,
@ -187,355 +119,6 @@ proc handshake*(
onPong: ws.onPong,
onClose: ws.onClose)
proc send*(
ws: WSSession,
data: seq[byte] = @[],
opcode: Opcode) {.async.} =
## Send a frame
##
if ws.readyState == ReadyState.Closed:
raise newException(WSClosedError, "Socket is closed!")
logScope:
opcode = opcode
dataSize = data.len
masked = ws.masked
debug "Sending data to remote"
var maskKey: array[4, char]
if ws.masked:
maskKey = genMaskKey(ws.rng)
if opcode notin {Opcode.Text, Opcode.Cont, Opcode.Binary}:
if ws.readyState in {ReadyState.Closing} and opcode notin {Opcode.Close}:
return
await ws.stream.writer.write(
Frame(
fin: true,
rsv1: false,
rsv2: false,
rsv3: false,
opcode: opcode,
mask: ws.masked,
data: data, # allow sending data with close messages
maskKey: maskKey)
.encode())
return
let maxSize = ws.frameSize
var i = 0
while ws.readyState notin {ReadyState.Closing}:
let len = min(data.len, (maxSize + i))
await ws.stream.writer.write(
Frame(
fin: if (i + len >= data.len): true else: false,
rsv1: false,
rsv2: false,
rsv3: false,
opcode: if i > 0: Opcode.Cont else: opcode, # fragments have to be `Continuation` frames
mask: ws.masked,
data: data[i ..< len],
maskKey: maskKey)
.encode())
i += len
if i >= data.len:
break
proc send*(ws: WSSession, data: string): Future[void] =
send(ws, toBytes(data), Opcode.Text)
proc handleClose*(ws: WSSession, frame: Frame, payLoad: seq[byte] = @[]) {.async.} =
## Handle close sequence
##
logScope:
fin = frame.fin
masked = frame.mask
opcode = frame.opcode
readyState = ws.readyState
debug "Handling close sequence"
if ws.readyState notin {ReadyState.Open}:
debug "Connection isn't open, abortig close sequence!"
return
var
code = Status.Fulfilled
reason = ""
if payLoad.len == 1:
raise newException(WSPayloadLengthError,
"Invalid close frame with payload length 1!")
if payLoad.len > 1:
# first two bytes are the status
let ccode = uint16.fromBytesBE(payLoad[0..<2])
if ccode <= 999 or ccode > 1015:
raise newException(WSInvalidCloseCodeError,
"Invalid code in close message!")
try:
code = Status(ccode)
except RangeError:
raise newException(WSInvalidCloseCodeError,
"Status code out of range!")
# remining payload bytes are reason for closing
reason = string.fromBytes(payLoad[2..payLoad.high])
var rcode: Status
if code in {Status.Fulfilled}:
rcode = Status.Fulfilled
if not isNil(ws.onClose):
try:
(rcode, reason) = ws.onClose(code, reason)
except CatchableError as exc:
debug "Exception in Close callback, this is most likely a bug", exc = exc.msg
# don't respond to a terminated connection
if ws.readyState != ReadyState.Closing:
ws.readyState = ReadyState.Closing
await ws.send(prepareCloseBody(rcode, reason), Opcode.Close)
ws.readyState = ReadyState.Closed
await ws.stream.closeWait()
proc handleControl*(ws: WSSession, frame: Frame) {.async.} =
## Handle control frames
##
if not frame.fin:
raise newException(WSFragmentedControlFrameError,
"Control frame cannot be fragmented!")
if frame.length > 125:
raise newException(WSPayloadTooLarge,
"Control message payload is greater than 125 bytes!")
try:
var payLoad = newSeq[byte](frame.length.int)
if frame.length > 0:
payLoad.setLen(frame.length.int)
# Read control frame payload.
await ws.stream.reader.readExactly(addr payLoad[0], frame.length.int)
if frame.mask:
mask(
payLoad.toOpenArray(0, payLoad.high),
frame.maskKey)
# Process control frame payload.
case frame.opcode:
of Opcode.Ping:
if not isNil(ws.onPing):
try:
ws.onPing(payLoad)
except CatchableError as exc:
debug "Exception in Ping callback, this is most likelly a bug", exc = exc.msg
# send pong to remote
await ws.send(payLoad, Opcode.Pong)
of Opcode.Pong:
if not isNil(ws.onPong):
try:
ws.onPong(payLoad)
except CatchableError as exc:
debug "Exception in Pong callback, this is most likelly a bug", exc = exc.msg
of Opcode.Close:
await ws.handleClose(frame, payLoad)
else:
raise newException(WSInvalidOpcodeError, "Invalid control opcode!")
except WebSocketError as exc:
debug "Handled websocket exception", exc = exc.msg
raise exc
except CatchableError as exc:
trace "Exception handling control messages", exc = exc.msg
ws.readyState = ReadyState.Closed
await ws.stream.closeWait()
proc readFrame*(ws: WSSession): Future[Frame] {.async.} =
## Gets a frame from the WebSocket.
## See https://tools.ietf.org/html/rfc6455#section-5.2
##
try:
while ws.readyState != ReadyState.Closed:
let frame = await Frame.decode(ws.stream.reader, ws.masked)
debug "Decoded new frame", opcode = frame.opcode, len = frame.length, mask = frame.mask
# return the current frame if it's not one of the control frames
if frame.opcode notin {Opcode.Text, Opcode.Cont, Opcode.Binary}:
await ws.handleControl(frame) # process control frames# process control frames
continue
return frame
except WebSocketError as exc:
trace "Websocket error", exc = exc.msg
raise exc
except CatchableError as exc:
debug "Exception reading frame, dropping socket", exc = exc.msg
ws.readyState = ReadyState.Closed
await ws.stream.closeWait()
raise exc
proc ping*(ws: WSSession, data: seq[byte] = @[]): Future[void] =
ws.send(data, opcode = Opcode.Ping)
proc recv*(
ws: WSSession,
data: pointer,
size: int): Future[int] {.async.} =
## Attempts to read up to `size` bytes
##
## Will read as many frames as necessary
## to fill the buffer until either
## the message ends (frame.fin) or
## the buffer is full. If no data is on
## the pipe will await until at least
## one byte is available
##
var consumed = 0
var pbuffer = cast[ptr UncheckedArray[byte]](data)
try:
while consumed < size:
# we might have to read more than
# one frame to fill the buffer
# TODO: Figure out a cleaner way to handle
# retrieving new frames
if isNil(ws.frame):
ws.frame = await ws.readFrame()
if isNil(ws.frame):
return consumed
if ws.frame.opcode == Opcode.Cont:
raise newException(WSOpcodeMismatchError,
"Expected Text or Binary frame")
elif (not ws.frame.fin and ws.frame.remainder() <= 0):
ws.frame = await ws.readFrame()
# This could happen if the connection is closed.
if isNil(ws.frame):
return consumed
if ws.frame.opcode != Opcode.Cont:
raise newException(WSOpcodeMismatchError,
"Expected Continuation frame")
ws.binary = ws.frame.opcode == Opcode.Binary # set binary flag
if ws.frame.fin and ws.frame.remainder() <= 0:
ws.frame = nil
break
let len = min(ws.frame.remainder().int, size - consumed)
if len == 0:
continue
let read = await ws.stream.reader.readOnce(addr pbuffer[consumed], len)
if read <= 0:
continue
if ws.frame.mask:
# unmask data using offset
mask(
pbuffer.toOpenArray(consumed, (consumed + read) - 1),
ws.frame.maskKey,
ws.frame.consumed.int)
consumed += read
ws.frame.consumed += read.uint64
return consumed.int
except WebSocketError as exc:
debug "Websocket error", exc = exc.msg
ws.readyState = ReadyState.Closed
await ws.stream.closeWait()
raise exc
except CancelledError as exc:
debug "Cancelling reading", exc = exc.msg
raise exc
except CatchableError as exc:
debug "Exception reading frames", exc = exc.msg
proc recv*(
ws: WSSession,
size = WSMaxMessageSize): Future[seq[byte]] {.async.} =
## Attempt to read a full message up to max `size`
## bytes in `frameSize` chunks.
##
## If no `fin` flag arrives await until either
## cancelled or the `fin` flag arrives.
##
## If message is larger than `size` a `WSMaxMessageSizeError`
## exception is thrown.
##
## In all other cases it awaits a full message.
##
var res: seq[byte]
try:
while ws.readyState != ReadyState.Closed:
var buf = newSeq[byte](ws.frameSize)
let read = await ws.recv(addr buf[0], buf.len)
if read <= 0:
break
buf.setLen(read)
if res.len + buf.len > size:
raise newException(WSMaxMessageSizeError, "Max message size exceeded")
res.add(buf)
# no more frames
if isNil(ws.frame):
break
# read the entire message, exit
if ws.frame.fin and ws.frame.remainder().int <= 0:
break
except WebSocketError as exc:
debug "Websocket error", exc = exc.msg
raise exc
except CancelledError as exc:
debug "Cancelling reading", exc = exc.msg
raise exc
except CatchableError as exc:
debug "Exception reading frames", exc = exc.msg
return res
proc close*(
ws: WSSession,
code: Status = Status.Fulfilled,
reason: string = "") {.async.} =
## Close the Socket, sends close packet.
##
if ws.readyState != ReadyState.Open:
return
try:
ws.readyState = ReadyState.Closing
await ws.send(
prepareCloseBody(code, reason),
opcode = Opcode.Close)
# read frames until closed
while ws.readyState != ReadyState.Closed:
discard await ws.recv()
except CatchableError as exc:
debug "Exception closing", exc = exc.msg
proc initiateHandshake(
uri: Uri,
address: TransportAddress,