Fix first frame (#28)

* split out frame

* use new api

* fix import
This commit is contained in:
Dmitriy Ryajov 2021-05-24 18:47:27 -06:00 committed by GitHub
parent cdd5224905
commit 0a4121c29d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 370 additions and 315 deletions

View File

@ -3,7 +3,7 @@ import pkg/[chronos,
chronicles,
httputils]
import ../ws/ws
import ../ws/[ws, frame, errors]
proc process(r: RequestFence): Future[HttpResponseRef] {.async.} =
if r.isOk():
@ -12,7 +12,7 @@ proc process(r: RequestFence): Future[HttpResponseRef] {.async.} =
if request.uri.path == "/ws":
debug "Initiating web socket connection."
try:
let ws = await createServer(request, "")
let ws = await WebSocket.createServer(request, "")
if ws.readyState != Open:
error "Failed to open websocket connection."
return

View File

@ -4,7 +4,7 @@ import pkg/[chronos,
httputils,
stew/byteutils]
import ../ws/ws
import ../ws/[ws, frame, errors]
import ../tests/keys
let secureKey = TLSPrivateKey.init(SecureKey)
@ -18,7 +18,7 @@ proc process(r: RequestFence): Future[HttpResponseRef] {.async.} =
if request.uri.path == "/wss":
debug "Initiating web socket connection."
try:
var ws = await createServer(request, "myfancyprotocol")
var ws = await WebSocket.createServer(request, "myfancyprotocol")
if ws.readyState != Open:
error "Failed to open websocket connection."
return

View File

@ -1,6 +1,6 @@
import unittest
import unittest, stew/byteutils
include ../ws/ws
include ../ws/frame
include ../ws/utils
# TODO: Fix Test.
@ -9,19 +9,18 @@ var maskKey: array[4, char]
suite "Test data frames":
test "# 7bit length text":
check encodeFrame(Frame(
check Frame(
fin: false,
rsv1: false,
rsv2: false,
rsv3: false,
opcode: Opcode.Text,
mask: false,
data: toBytes("hi there"),
maskKey: maskKey
)) == toBytes("\1\8hi there")
data: toBytes("hi there")
).encode() == toBytes("\1\8hi there")
test "# 7bit length text fin bit":
check encodeFrame(Frame(
check Frame(
fin: true,
rsv1: false,
rsv2: false,
@ -30,10 +29,10 @@ suite "Test data frames":
mask: false,
data: toBytes("hi there"),
maskKey: maskKey
)) == toBytes("\129\8hi there")
).encode() == toBytes("\129\8hi there")
test "# 7bit length binary":
check encodeFrame(Frame(
check Frame(
fin: false,
rsv1: false,
rsv2: false,
@ -42,10 +41,10 @@ suite "Test data frames":
mask: false,
data: toBytes("hi there"),
maskKey: maskKey
)) == toBytes("\2\8hi there")
).encode() == toBytes("\2\8hi there")
test "# 7bit length binary fin bit":
check encodeFrame(Frame(
check Frame(
fin: true,
rsv1: false,
rsv2: false,
@ -54,10 +53,10 @@ suite "Test data frames":
mask: false,
data: toBytes("hi there"),
maskKey: maskKey
)) == toBytes("\130\8hi there")
).encode() == toBytes("\130\8hi there")
test "# 7bit length continuation":
check encodeFrame(Frame(
check Frame(
fin: false,
rsv1: false,
rsv2: false,
@ -66,14 +65,14 @@ suite "Test data frames":
mask: false,
data: toBytes("hi there"),
maskKey: maskKey
)) == toBytes("\0\8hi there")
).encode() == toBytes("\0\8hi there")
test "# 7+16 length text":
var data = ""
for i in 0..32:
data.add "How are you this is the payload!!!"
check encodeFrame(Frame(
check Frame(
fin: false,
rsv1: false,
rsv2: false,
@ -82,14 +81,14 @@ suite "Test data frames":
mask: false,
data: toBytes(data),
maskKey: maskKey
)) == toBytes("\1\126\4\98" & data)
).encode() == toBytes("\1\126\4\98" & data)
test "# 7+16 length text fin bit":
var data = ""
for i in 0..32:
data.add "How are you this is the payload!!!"
check encodeFrame(Frame(
check Frame(
fin: false,
rsv1: false,
rsv2: false,
@ -98,14 +97,14 @@ suite "Test data frames":
mask: false,
data: toBytes(data),
maskKey: maskKey
)) == toBytes("\1\126\4\98" & data)
).encode() == toBytes("\1\126\4\98" & data)
test "# 7+16 length binary":
var data = ""
for i in 0..32:
data.add "How are you this is the payload!!!"
check encodeFrame(Frame(
check Frame(
fin: false,
rsv1: false,
rsv2: false,
@ -114,14 +113,14 @@ suite "Test data frames":
mask: false,
data: toBytes(data),
maskKey: maskKey
)) == toBytes("\2\126\4\98" & data)
).encode() == toBytes("\2\126\4\98" & data)
test "# 7+16 length binary fin bit":
var data = ""
for i in 0..32:
data.add "How are you this is the payload!!!"
check encodeFrame(Frame(
check Frame(
fin: true,
rsv1: false,
rsv2: false,
@ -130,14 +129,14 @@ suite "Test data frames":
mask: false,
data: toBytes(data),
maskKey: maskKey
)) == toBytes("\130\126\4\98" & data)
).encode() == toBytes("\130\126\4\98" & data)
test "# 7+16 length continuation":
var data = ""
for i in 0..32:
data.add "How are you this is the payload!!!"
check encodeFrame(Frame(
check Frame(
fin: false,
rsv1: false,
rsv2: false,
@ -146,14 +145,14 @@ suite "Test data frames":
mask: false,
data: toBytes(data),
maskKey: maskKey
)) == toBytes("\0\126\4\98" & data)
).encode() == toBytes("\0\126\4\98" & data)
test "# 7+64 length text":
var data = ""
for i in 0..3200:
data.add "How are you this is the payload!!!"
check encodeFrame(Frame(
check Frame(
fin: false,
rsv1: false,
rsv2: false,
@ -162,14 +161,14 @@ suite "Test data frames":
mask: false,
data: toBytes(data),
maskKey: maskKey
)) == toBytes("\1\127\0\0\0\0\0\1\169\34" & data)
).encode() == toBytes("\1\127\0\0\0\0\0\1\169\34" & data)
test "# 7+64 length fin bit":
var data = ""
for i in 0..3200:
data.add "How are you this is the payload!!!"
check encodeFrame(Frame(
check Frame(
fin: true,
rsv1: false,
rsv2: false,
@ -178,14 +177,14 @@ suite "Test data frames":
mask: false,
data: toBytes(data),
maskKey: maskKey
)) == toBytes("\129\127\0\0\0\0\0\1\169\34" & data)
).encode() == toBytes("\129\127\0\0\0\0\0\1\169\34" & data)
test "# 7+64 length binary":
var data = ""
for i in 0..3200:
data.add "How are you this is the payload!!!"
check encodeFrame(Frame(
check Frame(
fin: false,
rsv1: false,
rsv2: false,
@ -194,14 +193,14 @@ suite "Test data frames":
mask: false,
data: toBytes(data),
maskKey: maskKey
)) == toBytes("\2\127\0\0\0\0\0\1\169\34" & data)
).encode() == toBytes("\2\127\0\0\0\0\0\1\169\34" & data)
test "# 7+64 length binary fin bit":
var data = ""
for i in 0..3200:
data.add "How are you this is the payload!!!"
check encodeFrame(Frame(
check Frame(
fin: true,
rsv1: false,
rsv2: false,
@ -210,14 +209,14 @@ suite "Test data frames":
mask: false,
data: toBytes(data),
maskKey: maskKey
)) == toBytes("\130\127\0\0\0\0\0\1\169\34" & data)
).encode() == toBytes("\130\127\0\0\0\0\0\1\169\34" & data)
test "# 7+64 length binary":
var data = ""
for i in 0..3200:
data.add "How are you this is the payload!!!"
check encodeFrame(Frame(
check Frame(
fin: false,
rsv1: false,
rsv2: false,
@ -226,10 +225,10 @@ suite "Test data frames":
mask: false,
data: toBytes(data),
maskKey: maskKey
)) == toBytes("\0\127\0\0\0\0\0\1\169\34" & data)
).encode() == toBytes("\0\127\0\0\0\0\0\1\169\34" & data)
test "# masking":
let data = encodeFrame(Frame(
let data = Frame(
fin: true,
rsv1: false,
rsv2: false,
@ -238,14 +237,14 @@ suite "Test data frames":
mask: true,
data: toBytes("hi there"),
maskKey: ['\xCF', '\xD8', '\x05', 'e']
))
).encode()
check data == toBytes("\129\136\207\216\5e\167\177%\17\167\189w\0")
suite "Test control frames":
test "Close":
check encodeFrame(Frame(
check Frame(
fin: true,
rsv1: false,
rsv2: false,
@ -254,10 +253,10 @@ suite "Test control frames":
mask: false,
data: @[3'u8, 232'u8] & toBytes("hi there"),
maskKey: maskKey
)) == toBytes("\136\10\3\232hi there")
).encode() == toBytes("\136\10\3\232hi there")
test "Ping":
check encodeFrame(Frame(
check Frame(
fin: false,
rsv1: false,
rsv2: false,
@ -265,10 +264,10 @@ suite "Test control frames":
opcode: Opcode.Ping,
mask: false,
maskKey: maskKey
)) == toBytes("\9\0")
).encode() == toBytes("\9\0")
test "Pong":
check encodeFrame(Frame(
check Frame(
fin: false,
rsv1: false,
rsv2: false,
@ -276,4 +275,4 @@ suite "Test control frames":
opcode: Opcode.Pong,
mask: false,
maskKey: maskKey
)) == toBytes("\10\0")
).encode() == toBytes("\10\0")

View File

@ -6,7 +6,7 @@ import pkg/[asynctest,
chronos/apps/http/shttpserver,
stew/byteutils]
import ../ws/[ws, stream],
import ../ws/[ws, stream, errors],
../examples/tlsserver
import ./keys
@ -40,7 +40,7 @@ suite "Test websocket TLS handshake":
let request = r.get()
check request.uri.path == "/wss"
expect WSProtoMismatchError:
discard await createServer(request, "proto")
discard await WebSocket.createServer(request, "proto")
let res = SecureHttpServerRef.new(
address, cb,
@ -68,7 +68,7 @@ suite "Test websocket TLS handshake":
let request = r.get()
check request.uri.path == "/wss"
expect WSVersionError:
discard await createServer(request, "proto")
discard await WebSocket.createServer(request, "proto")
let res = SecureHttpServerRef.new(
address, cb,
@ -135,7 +135,7 @@ suite "Test websocket TLS transmission":
let request = r.get()
check request.uri.path == "/wss"
let ws = await createServer(request, "proto")
let ws = await WebSocket.createServer(request, "proto")
let servRes = await ws.recv()
check string.fromBytes(servRes) == testString
await waitForClose(ws)
@ -169,7 +169,7 @@ suite "Test websocket TLS transmission":
let request = r.get()
check request.uri.path == "/wss"
let ws = await createServer(request, "proto")
let ws = await WebSocket.createServer(request, "proto")
await ws.send(testString)
await ws.close()
return dumbResponse()

View File

@ -6,7 +6,7 @@ import pkg/[asynctest,
chronicles,
stew/byteutils]
import ../ws/[ws, stream, utils]
import ../ws/[ws, stream, utils, frame, errors]
var server: HttpServerRef
let address = initTAddress("127.0.0.1:8888")
@ -39,7 +39,7 @@ suite "Test handshake":
let request = r.get()
check request.uri.path == "/ws"
expect WSProtoMismatchError:
discard await createServer(request, "proto")
discard await WebSocket.createServer(request, "proto")
let res = HttpServerRef.new(address, cb)
server = res.get()
@ -59,7 +59,7 @@ suite "Test handshake":
let request = r.get()
check request.uri.path == "/ws"
expect WSVersionError:
discard await createServer(request, "proto")
discard await WebSocket.createServer(request, "proto")
let res = HttpServerRef.new(address, cb)
server = res.get()
@ -111,7 +111,7 @@ suite "Test handshake":
check request.uri.path == "/ws"
expect WSProtoMismatchError:
var ws = await createServer(request, "proto")
var ws = await WebSocket.createServer(request, "proto")
check ws.readyState == ReadyState.Closed
return await request.respond(Http200, "Connection established")
@ -142,14 +142,14 @@ suite "Test transmission":
proc cb(r: RequestFence): Future[HttpResponseRef] {.async.} =
if r.isErr():
return dumbResponse()
let request = r.get()
check request.uri.path == "/ws"
let ws = await createServer(request, "proto")
let ws = await WebSocket.createServer(request, "proto")
let servRes = await ws.recv()
check string.fromBytes(servRes) == testString
let res = HttpServerRef.new(
address, cb)
let res = HttpServerRef.new(address, cb)
server = res.get()
server.start()
@ -158,6 +158,7 @@ suite "Test transmission":
Port(8888),
path = "/ws",
protocols = @["proto"])
await wsClient.send(testString)
await wsClient.close()
@ -168,7 +169,7 @@ suite "Test transmission":
return dumbResponse()
let request = r.get()
check request.uri.path == "/ws"
let ws = await createServer(request, "proto")
let ws = await WebSocket.createServer(request, "proto")
let servRes = await ws.recv()
check string.fromBytes(servRes) == testString
await waitForClose(ws)
@ -198,7 +199,7 @@ suite "Test transmission":
let request = r.get()
check request.uri.path == "/ws"
let ws = await createServer(request, "proto")
let ws = await WebSocket.createServer(request, "proto")
await ws.send(testString)
await ws.close()
@ -238,7 +239,7 @@ suite "Test ping-pong":
let request = r.get()
check request.uri.path == "/ws"
let ws = await createServer(
let ws = await WebSocket.createServer(
request,
"proto",
onPing = proc(data: openArray[byte]) =
@ -266,7 +267,7 @@ suite "Test ping-pong":
let maskKey = genMaskKey(newRng())
await wsClient.stream.writer.write(
encodeFrame(Frame(
Frame(
fin: false,
rsv1: false,
rsv2: false,
@ -274,12 +275,13 @@ suite "Test ping-pong":
opcode: Opcode.Text,
mask: true,
data: msg[0..4],
maskKey: maskKey)))
maskKey: maskKey)
.encode())
await wsClient.ping()
await wsClient.stream.writer.write(
encodeFrame(Frame(
Frame(
fin: true,
rsv1: false,
rsv2: false,
@ -287,7 +289,8 @@ suite "Test ping-pong":
opcode: Opcode.Cont,
mask: true,
data: msg[5..9],
maskKey: maskKey)))
maskKey: maskKey)
.encode())
await wsClient.close()
check:
@ -307,7 +310,7 @@ suite "Test ping-pong":
let request = r.get()
check request.uri.path == "/ws"
let ws = await createServer(
let ws = await WebSocket.createServer(
request,
"proto",
onPong = proc(data: openArray[byte]) =
@ -343,7 +346,7 @@ suite "Test ping-pong":
let request = r.get()
check request.uri.path == "/ws"
let ws = await createServer(
let ws = await WebSocket.createServer(
request,
"proto",
onPing = proc(data: openArray[byte]) =
@ -390,7 +393,7 @@ suite "Test framing":
let request = r.get()
check request.uri.path == "/ws"
let ws = await createServer(request, "proto")
let ws = await WebSocket.createServer(request, "proto")
let frame1 = await ws.readFrame()
check not isNil(frame1)
var data1 = newSeq[byte](frame1.remainder().int)
@ -432,7 +435,7 @@ suite "Test framing":
let request = r.get()
check request.uri.path == "/ws"
let ws = await createServer(request, "proto")
let ws = await WebSocket.createServer(request, "proto")
await ws.send(testString)
await ws.close()
@ -473,7 +476,7 @@ suite "Test Closing":
let request = r.get()
check request.uri.path == "/ws"
let ws = await createServer(request, "proto")
let ws = await WebSocket.createServer(request, "proto")
await ws.close()
let res = HttpServerRef.new(address, cb)
@ -512,7 +515,7 @@ suite "Test Closing":
return (Status.Fulfilled, "")
let ws = await createServer(
let ws = await WebSocket.createServer(
request,
"proto",
onClose = closeServer)
@ -548,7 +551,7 @@ suite "Test Closing":
let request = r.get()
check request.uri.path == "/ws"
let ws = await createServer(request, "proto")
let ws = await WebSocket.createServer(request, "proto")
await waitForClose(ws)
let res = HttpServerRef.new(address, cb)
@ -577,7 +580,7 @@ suite "Test Closing":
except Exception as exc:
raise newException(Defect, exc.msg)
let ws = await createServer(
let ws = await WebSocket.createServer(
request,
"proto",
onClose = closeServer)
@ -612,7 +615,7 @@ suite "Test Closing":
return dumbResponse()
let request = r.get()
check request.uri.path == "/ws"
let ws = await createServer(
let ws = await WebSocket.createServer(
request,
"proto")
@ -655,7 +658,7 @@ suite "Test Closing":
except Exception as exc:
raise newException(Defect, exc.msg)
let ws = await createServer(
let ws = await WebSocket.createServer(
request,
"proto",
onClose = closeServer)
@ -686,7 +689,7 @@ suite "Test Closing":
return dumbResponse()
let request = r.get()
check request.uri.path == "/ws"
let ws = await createServer(request, "proto")
let ws = await WebSocket.createServer(request, "proto")
# Close with payload of length 2
await ws.close(reason = "HH")
@ -708,7 +711,7 @@ suite "Test Closing":
return dumbResponse()
let request = r.get()
check request.uri.path == "/ws"
let ws = await createServer(request, "proto")
let ws = await WebSocket.createServer(request, "proto")
await waitForClose(ws)
let res = HttpServerRef.new(
@ -731,7 +734,6 @@ suite "Test Closing":
getTracker("stream.server").isLeaked() == false
getTracker("stream.transport").isLeaked() == false
suite "Test Payload":
teardown:
await server.closeWait()
@ -742,7 +744,7 @@ suite "Test Payload":
return dumbResponse()
let request = r.get()
check request.uri.path == "/ws"
let ws = await createServer(
let ws = await WebSocket.createServer(
request,
"proto")
@ -773,7 +775,7 @@ suite "Test Payload":
return dumbResponse()
let request = r.get()
check request.uri.path == "/ws"
let ws = await createServer(request, "proto")
let ws = await WebSocket.createServer(request, "proto")
let servRes = await ws.recv()
check string.fromBytes(servRes) == emptyStr
await waitForClose(ws)
@ -799,7 +801,7 @@ suite "Test Payload":
return dumbResponse()
let request = r.get()
check request.uri.path == "/ws"
let ws = await createServer(request, "proto")
let ws = await WebSocket.createServer(request, "proto")
let servRes = await ws.recv()
check string.fromBytes(servRes) == emptyStr
await waitForClose(ws)
@ -827,7 +829,7 @@ suite "Test Payload":
return dumbResponse()
let request = r.get()
check request.uri.path == "/ws"
let ws = await createServer(
let ws = await WebSocket.createServer(
request,
"proto",
onPing = proc(data: openArray[byte]) =
@ -874,7 +876,7 @@ suite "Test Binary message with Payload":
return dumbResponse()
let request = r.get()
check request.uri.path == "/ws"
let ws = await createServer(request, "proto")
let ws = await WebSocket.createServer(request, "proto")
let servRes = await ws.recv()
check:
@ -904,7 +906,7 @@ suite "Test Binary message with Payload":
return dumbResponse()
let request = r.get()
check request.uri.path == "/ws"
let ws = await createServer(request, "proto")
let ws = await WebSocket.createServer(request, "proto")
let servRes = await ws.recv()
check:
@ -937,7 +939,7 @@ suite "Test Binary message with Payload":
return dumbResponse()
let request = r.get()
check request.uri.path == "/ws"
let ws = await createServer(
let ws = await WebSocket.createServer(
request,
"proto",
onPing = proc(data: openArray[byte]) =
@ -976,7 +978,7 @@ suite "Test Binary message with Payload":
return dumbResponse()
let request = r.get()
check request.uri.path == "/ws"
let ws = await createServer(
let ws = await WebSocket.createServer(
request,
"proto",
onPing = proc(data: openArray[byte]) =

31
ws/errors.nim Normal file
View File

@ -0,0 +1,31 @@
## Nim-Libp2p
## Copyright (c) 2021 Status Research & Development GmbH
## Licensed under either of
## * Apache License, version 2.0, ([LICENSE-APACHE](LICENSE-APACHE))
## * MIT license ([LICENSE-MIT](LICENSE-MIT))
## at your option.
## This file may not be copied, modified, or distributed except according to
## those terms.
{.push raises: [Defect].}
type
WebSocketError* = object of CatchableError
WSMalformedHeaderError* = object of WebSocketError
WSFailedUpgradeError* = object of WebSocketError
WSVersionError* = object of WebSocketError
WSProtoMismatchError* = object of WebSocketError
WSMaskMismatchError* = object of WebSocketError
WSHandshakeError* = object of WebSocketError
WSOpcodeMismatchError* = object of WebSocketError
WSRsvMismatchError* = object of WebSocketError
WSWrongUriSchemeError* = object of WebSocketError
WSMaxMessageSizeError* = object of WebSocketError
WSClosedError* = object of WebSocketError
WSSendError* = object of WebSocketError
WSPayloadTooLarge* = object of WebSocketError
WSReserverdOpcodeError* = object of WebSocketError
WSFragmentedControlFrameError* = object of WebSocketError
WSInvalidCloseCodeError* = object of WebSocketError
WSPayloadLengthError* = object of WebSocketError
WSInvalidOpcodeError* = object of WebSocketError

207
ws/frame.nim Normal file
View File

@ -0,0 +1,207 @@
## Nim-Libp2p
## Copyright (c) 2020 Status Research & Development GmbH
## Licensed under either of
## * Apache License, version 2.0, ([LICENSE-APACHE](LICENSE-APACHE))
## * MIT license ([LICENSE-MIT](LICENSE-MIT))
## at your option.
## This file may not be copied, modified, or distributed except according to
## those terms.
{.push raises: [Defect].}
import pkg/[chronos, chronicles, stew/endians2, stew/results]
import ./errors
#[
+---------------------------------------------------------------+
|0 1 2 3 |
|0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1|
+-+-+-+-+-------+-+-------------+-------------------------------+
|F|R|R|R| opcode|M| Payload len | Extended payload length |
|I|S|S|S| (4) |A| (7) | (16/64) |
|N|V|V|V| |S| | (if payload len==126/127) |
| |1|2|3| |K| | |
+-+-+-+-+-------+-+-------------+ - - - - - - - - - - - - - - - +
| Extended payload length continued, if payload len == 127 |
+ - - - - - - - - - - - - - - - +-------------------------------+
| |Masking-key, if MASK set to 1 |
+-------------------------------+-------------------------------+
| Masking-key (continued) | Payload Data |
+-------------------------------- - - - - - - - - - - - - - - - +
: Payload Data continued ... :
+ - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - +
| Payload Data continued ... |
+---------------------------------------------------------------+
]#
type
Opcode* {.pure.} = enum
## 4 bits. Defines the interpretation of the "Payload data".
Cont = 0x0 ## Denotes a continuation frame.
Text = 0x1 ## Denotes a text frame.
Binary = 0x2 ## Denotes a binary frame.
# 3-7 are reserved for further non-control frames.
Close = 0x8 ## Denotes a connection close.
Ping = 0x9 ## Denotes a ping.
Pong = 0xa ## Denotes a pong.
# B-F are reserved for further control frames.
HeaderFlag* {.pure, size: sizeof(uint8).} = enum
rsv3
rsv2
rsv1
fin
HeaderFlags = set[HeaderFlag]
MaskKey = array[4, char]
Frame* = ref object
fin*: bool ## Indicates that this is the final fragment in a message.
rsv1*: bool ## MUST be 0 unless negotiated that defines meanings
rsv2*: bool ## MUST be 0
rsv3*: bool ## MUST be 0
opcode*: Opcode ## Defines the interpretation of the "Payload data".
mask*: bool ## Defines whether the "Payload data" is masked.
data*: seq[byte] ## Payload data
maskKey*: MaskKey ## Masking key
length*: uint64 ## Message size.
consumed*: uint64 ## how much has been consumed from the frame
proc mask*(
data: var openArray[byte],
maskKey: MaskKey,
offset = 0) =
## Unmask a data payload using key
##
for i in 0 ..< data.len:
data[i] = (data[i].uint8 xor maskKey[(offset + i) mod 4].uint8)
proc encode*(f: Frame, offset = 0): seq[byte] =
## Encodes a frame into a string buffer.
## See https://tools.ietf.org/html/rfc6455#section-5.2
var ret: seq[byte]
var b0 = (f.opcode.uint8 and 0x0F) # 0th byte: opcodes and flags.
if f.fin:
b0 = b0 or 128'u8
ret.add(b0)
# Payload length can be 7 bits, 7+16 bits, or 7+64 bits.
# 1st byte: payload len start and mask bit.
var b1 = 0'u8
if f.data.len <= 125:
b1 = f.data.len.uint8
elif f.data.len > 125 and f.data.len <= 0xFFFF:
b1 = 126'u8
else:
b1 = 127'u8
if f.mask:
b1 = b1 or (1 shl 7)
ret.add(uint8 b1)
# Only need more bytes if data len is 7+16 bits, or 7+64 bits.
if f.data.len > 125 and f.data.len <= 0xFFFF:
# Data len is 7+16 bits.
var len = f.data.len.uint16
ret.add ((len shr 8) and 0xFF).uint8
ret.add (len and 0xFF).uint8
elif f.data.len > 0xFFFF:
# Data len is 7+64 bits.
var len = f.data.len.uint64
ret.add(len.toBytesBE())
var data = f.data
if f.mask:
# If we need to mask it generate random mask key and mask the data.
mask(data, f.maskKey, offset)
# Write mask key next.
ret.add(f.maskKey[0].uint8)
ret.add(f.maskKey[1].uint8)
ret.add(f.maskKey[2].uint8)
ret.add(f.maskKey[3].uint8)
# Write the data.
ret.add(data)
return ret
proc decode*(
_: typedesc[Frame],
reader: AsyncStreamReader,
masked: bool):
Future[Frame] {.async.} =
## Read and Decode incoming header
##
var header = newSeq[byte](2)
await reader.readExactly(addr header[0], 2)
if header.len != 2:
debug "Invalid websocket header length"
raise newException(WSMalformedHeaderError,
"Invalid websocket header length")
let b0 = header[0].uint8
let b1 = header[1].uint8
var frame = Frame()
# Read the flags and fin from the header.
var hf = cast[HeaderFlags](b0 shr 4)
frame.fin = HeaderFlag.fin in hf
frame.rsv1 = HeaderFlag.rsv1 in hf
frame.rsv2 = HeaderFlag.rsv2 in hf
frame.rsv3 = HeaderFlag.rsv3 in hf
let opcode = (b0 and 0x0f)
if opcode > ord(Opcode.high):
raise newException(WSOpcodeMismatchError, "Wrong opcode!")
frame.opcode = (opcode).Opcode
# If any of the rsv are set close the socket.
if frame.rsv1 or frame.rsv2 or frame.rsv3:
raise newException(WSRsvMismatchError, "WebSocket rsv mismatch")
# Payload length can be 7 bits, 7+16 bits, or 7+64 bits.
var finalLen: uint64 = 0
let headerLen = uint(b1 and 0x7f)
if headerLen == 0x7e:
# Length must be 7+16 bits.
var length = newSeq[byte](2)
await reader.readExactly(addr length[0], 2)
finalLen = uint16.fromBytesBE(length)
elif headerLen == 0x7f:
# Length must be 7+64 bits.
var length = newSeq[byte](8)
await reader.readExactly(addr length[0], 8)
finalLen = uint64.fromBytesBE(length)
else:
# Length must be 7 bits.
finalLen = headerLen
frame.length = finalLen
# Do we need to apply mask?
frame.mask = (b1 and 0x80) == 0x80
if masked == frame.mask:
# Server sends unmasked but accepts only masked.
# Client sends masked but accepts only unmasked.
raise newException(WSMaskMismatchError,
"Socket mask mismatch")
var maskKey = newSeq[byte](4)
if frame.mask:
# Read the mask.
await reader.readExactly(addr maskKey[0], 4)
for i in 0..<maskKey.len:
frame.maskKey[i] = cast[char](maskKey[i])
return frame

274
ws/ws.nim
View File

@ -1,3 +1,13 @@
## Nim-Libp2p
## Copyright (c) 2021 Status Research & Development GmbH
## Licensed under either of
## * Apache License, version 2.0, ([LICENSE-APACHE](LICENSE-APACHE))
## * MIT license ([LICENSE-MIT](LICENSE-MIT))
## at your option.
## This file may not be copied, modified, or distributed except according to
## those terms.
{.push raises: [Defect].}
import std/[tables,
@ -18,29 +28,7 @@ import pkg/[chronos,
stew/base10,
nimcrypto/sha]
import ./utils, ./stream
#[
+---------------------------------------------------------------+
|0 1 2 3 |
|0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1|
+-+-+-+-+-------+-+-------------+-------------------------------+
|F|R|R|R| opcode|M| Payload len | Extended payload length |
|I|S|S|S| (4) |A| (7) | (16/64) |
|N|V|V|V| |S| | (if payload len==126/127) |
| |1|2|3| |K| | |
+-+-+-+-+-------+-+-------------+ - - - - - - - - - - - - - - - +
| Extended payload length continued, if payload len == 127 |
+ - - - - - - - - - - - - - - - +-------------------------------+
| |Masking-key, if MASK set to 1 |
+-------------------------------+-------------------------------+
| Masking-key (continued) | Payload Data |
+-------------------------------- - - - - - - - - - - - - - - - +
: Payload Data continued ... :
+ - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - +
| Payload Data continued ... |
+---------------------------------------------------------------+
]#
import ./utils, ./stream, ./frame, ./errors
const
SHA1DigestSize* = 20
@ -58,48 +46,9 @@ type
Closing = 2 # The connection is in the process of closing.
Closed = 3 # The connection is closed or couldn't be opened.
WebSocketError* = object of CatchableError
WSMalformedHeaderError* = object of WebSocketError
WSFailedUpgradeError* = object of WebSocketError
WSVersionError* = object of WebSocketError
WSProtoMismatchError* = object of WebSocketError
WSMaskMismatchError* = object of WebSocketError
WSHandshakeError* = object of WebSocketError
WSOpcodeMismatchError* = object of WebSocketError
WSRsvMismatchError* = object of WebSocketError
WSWrongUriSchemeError* = object of WebSocketError
WSMaxMessageSizeError* = object of WebSocketError
WSClosedError* = object of WebSocketError
WSSendError* = object of WebSocketError
WSPayloadTooLarge* = object of WebSocketError
WSReserverdOpcodeError* = object of WebSocketError
WSFragmentedControlFrameError* = object of WebSocketError
WSInvalidCloseCodeError* = object of WebSocketError
WSPayloadLengthError* = object of WebSocketError
WSInvalidOpcodeError* = object of WebSocketError
HeaderFlag* {.size: sizeof(uint8).} = enum
rsv3
rsv2
rsv1
fin
HeaderFlags = set[HeaderFlag]
HttpCode* = enum
Http101 = 101 # Switching Protocols
Opcode* {.pure.} = enum
## 4 bits. Defines the interpretation of the "Payload data".
Cont = 0x0 ## Denotes a continuation frame.
Text = 0x1 ## Denotes a text frame.
Binary = 0x2 ## Denotes a binary frame.
# 3-7 are reserved for further non-control frames.
Close = 0x8 ## Denotes a connection close.
Ping = 0x9 ## Denotes a ping.
Pong = 0xa ## Denotes a pong.
# B-F are reserved for further control frames.
Status* {.pure.} = enum
# 0-999 not used
Fulfilled = 1000
@ -118,19 +67,8 @@ type
# 3000-3999 reserved for libs
# 4000-4999 reserved for applications
Frame* = ref object
fin*: bool ## Indicates that this is the final fragment in a message.
rsv1*: bool ## MUST be 0 unless negotiated that defines meanings
rsv2*: bool ## MUST be 0
rsv3*: bool ## MUST be 0
opcode*: Opcode ## Defines the interpretation of the "Payload data".
mask*: bool ## Defines whether the "Payload data" is masked.
data*: seq[byte] ## Payload data
maskKey*: array[4, char] ## Masking key
length*: uint64 ## Message size.
consumed*: uint64 ## how much has been consumed from the frame
ControlCb* = proc(data: openArray[byte] = []) {.gcsafe, raises: [Defect].}
ControlCb* = proc(data: openArray[byte] = [])
{.gcsafe, raises: [Defect].}
CloseResult* = tuple
code: Status
@ -170,16 +108,6 @@ proc `$`(ht: HttpTables): string =
res.add(CRLF)
res
proc mask*(
data: var openArray[byte],
maskKey: array[4, char],
offset = 0) =
## Unmask a data payload using key
##
for i in 0 ..< data.len:
data[i] = (data[i].uint8 xor maskKey[(offset + i) mod 4].uint8)
proc prepareCloseBody(code: Status, reason: string): seq[byte] =
result = reason.toBytes
if ord(code) > 999:
@ -234,6 +162,7 @@ proc handshake*(
ws.readyState = ReadyState.Open
proc createServer*(
_: typedesc[WebSocket],
request: HttpRequestRef,
protocol: string = "",
frameSize = WSDefaultFrameSize,
@ -263,60 +192,6 @@ proc createServer*(
await ws.handshake(request)
return ws
proc encodeFrame*(f: Frame, offset = 0): seq[byte] =
## Encodes a frame into a string buffer.
## See https://tools.ietf.org/html/rfc6455#section-5.2
var ret: seq[byte]
var b0 = (f.opcode.uint8 and 0x0F) # 0th byte: opcodes and flags.
if f.fin:
b0 = b0 or 128'u8
ret.add(b0)
# Payload length can be 7 bits, 7+16 bits, or 7+64 bits.
# 1st byte: payload len start and mask bit.
var b1 = 0'u8
if f.data.len <= 125:
b1 = f.data.len.uint8
elif f.data.len > 125 and f.data.len <= 0xFFFF:
b1 = 126'u8
else:
b1 = 127'u8
if f.mask:
b1 = b1 or (1 shl 7)
ret.add(uint8 b1)
# Only need more bytes if data len is 7+16 bits, or 7+64 bits.
if f.data.len > 125 and f.data.len <= 0xFFFF:
# Data len is 7+16 bits.
var len = f.data.len.uint16
ret.add ((len shr 8) and 0xFF).uint8
ret.add (len and 0xFF).uint8
elif f.data.len > 0xFFFF:
# Data len is 7+64 bits.
var len = f.data.len.uint64
ret.add(len.toBytesBE())
var data = f.data
if f.mask:
# If we need to mask it generate random mask key and mask the data.
mask(data, f.maskKey, offset)
# Write mask key next.
ret.add(f.maskKey[0].uint8)
ret.add(f.maskKey[1].uint8)
ret.add(f.maskKey[2].uint8)
ret.add(f.maskKey[3].uint8)
# Write the data.
ret.add(data)
return ret
proc send*(
ws: WebSocket,
data: seq[byte] = @[],
@ -339,19 +214,21 @@ proc send*(
maskKey = genMaskKey(ws.rng)
if opcode notin {Opcode.Text, Opcode.Cont, Opcode.Binary}:
if ws.readyState in {ReadyState.Closing} and opcode notin {Opcode.Close}:
return
await ws.stream.writer.write(
encodeFrame(Frame(
fin: true,
rsv1: false,
rsv2: false,
rsv3: false,
opcode: opcode,
mask: ws.masked,
data: data, # allow sending data with close messages
maskKey: maskKey)))
Frame(
fin: true,
rsv1: false,
rsv2: false,
rsv3: false,
opcode: opcode,
mask: ws.masked,
data: data, # allow sending data with close messages
maskKey: maskKey)
.encode())
return
@ -360,7 +237,7 @@ proc send*(
while ws.readyState notin {ReadyState.Closing}:
let len = min(data.len, (maxSize + i))
await ws.stream.writer.write(
encodeFrame(Frame(
Frame(
fin: if (i + len >= data.len): true else: false,
rsv1: false,
rsv2: false,
@ -368,7 +245,8 @@ proc send*(
opcode: if i > 0: Opcode.Cont else: opcode, # fragments have to be `Continuation` frames
mask: ws.masked,
data: data[i ..< len],
maskKey: maskKey)))
maskKey: maskKey)
.encode())
i += len
if i >= data.len:
@ -491,69 +369,7 @@ proc readFrame*(ws: WebSocket): Future[Frame] {.async.} =
try:
while ws.readyState != ReadyState.Closed:
# Grab the header.
var header = newSeq[byte](2)
await ws.stream.reader.readExactly(addr header[0], 2)
if header.len != 2:
debug "Invalid websocket header length"
raise newException(WSMalformedHeaderError,
"Invalid websocket header length")
let b0 = header[0].uint8
let b1 = header[1].uint8
var frame = Frame()
# Read the flags and fin from the header.
var hf = cast[HeaderFlags](b0 shr 4)
frame.fin = fin in hf
frame.rsv1 = rsv1 in hf
frame.rsv2 = rsv2 in hf
frame.rsv3 = rsv3 in hf
let opcode = (b0 and 0x0f)
if opcode > ord(Opcode.high):
raise newException(WSOpcodeMismatchError, "Wrong opcode!")
frame.opcode = (opcode).Opcode
# If any of the rsv are set close the socket.
if frame.rsv1 or frame.rsv2 or frame.rsv3:
raise newException(WSRsvMismatchError, "WebSocket rsv mismatch")
# Payload length can be 7 bits, 7+16 bits, or 7+64 bits.
var finalLen: uint64 = 0
let headerLen = uint(b1 and 0x7f)
if headerLen == 0x7e:
# Length must be 7+16 bits.
var length = newSeq[byte](2)
await ws.stream.reader.readExactly(addr length[0], 2)
finalLen = uint16.fromBytesBE(length)
elif headerLen == 0x7f:
# Length must be 7+64 bits.
var length = newSeq[byte](8)
await ws.stream.reader.readExactly(addr length[0], 8)
finalLen = uint64.fromBytesBE(length)
else:
# Length must be 7 bits.
finalLen = headerLen
frame.length = finalLen
# Do we need to apply mask?
frame.mask = (b1 and 0x80) == 0x80
if ws.masked == frame.mask:
# Server sends unmasked but accepts only masked.
# Client sends masked but accepts only unmasked.
raise newException(WSMaskMismatchError, "Socket mask mismatch")
var maskKey = newSeq[byte](4)
if frame.mask:
# Read the mask.
await ws.stream.reader.readExactly(addr maskKey[0], 4)
for i in 0..<maskKey.len:
frame.maskKey[i] = cast[char](maskKey[i])
let frame = await Frame.decode(ws.stream.reader, ws.masked)
debug "Decoded new frame", opcode = frame.opcode, len = frame.length, mask = frame.mask
# return the current frame if it's not one of the control frames
@ -591,34 +407,34 @@ proc recv*(
var consumed = 0
var pbuffer = cast[ptr UncheckedArray[byte]](data)
try:
# read the first frame
ws.frame = await ws.readFrame()
# This could happen if the connection is closed.
if isNil(ws.frame):
return consumed
if ws.frame.opcode == Opcode.Cont:
raise newException(WSOpcodeMismatchError,
"First frame cannot be continue frame")
while consumed < size:
# we might have to read more than
# one frame to fill the buffer
if (not ws.frame.fin and ws.frame.remainder() <= 0):
# TODO: Figure out a cleaner way to handle
# retrieving new frames
if isNil(ws.frame):
ws.frame = await ws.readFrame()
if isNil(ws.frame):
return consumed
if ws.frame.opcode == Opcode.Cont:
raise newException(WSOpcodeMismatchError,
"Expected Text or Binary frame")
elif (not ws.frame.fin and ws.frame.remainder() <= 0):
ws.frame = await ws.readFrame()
# This could happen if the connection is closed.
if isNil(ws.frame):
return consumed
if ws.frame.opcode != Opcode.Cont:
raise newException(WSOpcodeMismatchError,
"expected continue frame")
"Expected Continuation frame")
ws.binary = ws.frame.opcode == Opcode.Binary
if ws.frame.fin and ws.frame.remainder().int <= 0:
ws.binary = ws.frame.opcode == Opcode.Binary # set binary flag
if ws.frame.fin and ws.frame.remainder() <= 0:
ws.frame = nil
break