Fix partial frame handling and allow extensions to hijack the flow (#56)

* moving files around

* wip

* wip

* move tls example into server example

* add tls functionality

* rename

* rename

* fix tests

* move extension related files to own folder

* use trace instead of debug

* export extensions

* rework partial frame handling and closing

* rework status codes as distincts

* logging

* re-enable extensions processing for frames

* enable all test for non-tls server

* remove tlsserver

* remove offset to mask - don't think we need it

* pass sessions extensions when calling send/recv

* adding encode/decode extensions flow test

* move server/client setup to helpers

* proper frame order execution on decode

* fix tls tests
This commit is contained in:
Dmitriy Ryajov 2021-06-11 14:04:09 -06:00 committed by GitHub
parent e632202037
commit 3e1599d790
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
23 changed files with 1159 additions and 535 deletions

View File

@ -232,8 +232,8 @@ jobs:
kill $pid
cd ..
nim c examples/tlsserver.nim
examples/tlsserver &
nim -d:tls c examples/server.nim
examples/server &
pid=$!
cd autobahn
wstest --mode fuzzingclient --spec fuzzingclient_tls.json

View File

@ -7,6 +7,6 @@
}
],
"cases": ["*"],
"exclude-cases": ["9.*", "12.*", "13.*"],
"exclude-cases": [],
"exclude-agent-cases": {}
}

View File

@ -6,12 +6,19 @@ import pkg/[
import ../ws/ws
proc main() {.async.} =
let ws = await WebSocket.connect(
"127.0.0.1",
Port(8888),
path = "/ws")
let ws = when defined tls:
await WebSocket.connect(
"127.0.0.1",
Port(8888),
path = "/wss",
flags = {TLSFlags.NoVerifyHost, TLSFlags.NoVerifyServerName})
else:
await WebSocket.connect(
"127.0.0.1",
Port(8888),
path = "/ws")
debug "Websocket client: ", State = ws.readyState
trace "Websocket client: ", State = ws.readyState
let reqData = "Hello Server"
while true:
@ -22,7 +29,7 @@ proc main() {.async.} =
break
let dataStr = string.fromBytes(buff)
debug "Server Response: ", data = dataStr
trace "Server Response: ", data = dataStr
assert dataStr == reqData
break

View File

@ -5,13 +5,15 @@ import pkg/[chronos,
httputils]
import ../ws/ws
import ../tests/keys
proc handle(request: HttpRequest) {.async.} =
debug "Handling request:", uri = request.uri.path
if request.uri.path != "/ws":
trace "Handling request:", uri = request.uri.path
let path = when defined tls: "/wss" else: "/ws"
if request.uri.path != path:
return
debug "Initiating web socket connection."
trace "Initiating web socket connection."
try:
let server = WSServer.new()
let ws = await server.handleRequest(request)
@ -19,16 +21,14 @@ proc handle(request: HttpRequest) {.async.} =
error "Failed to open websocket connection"
return
debug "Websocket handshake completed"
while true:
trace "Websocket handshake completed"
while ws.readyState != ReadyState.Closed:
let recvData = await ws.recv()
if ws.readyState == ReadyState.Closed:
debug "Websocket closed"
break
debug "Client Response: ", size = recvData.len
trace "Client Response: ", size = recvData.len, binary = ws.binary
await ws.send(recvData,
if ws.binary: Opcode.Binary else: Opcode.Text)
except WebSocketError as exc:
error "WebSocket error:", exception = exc.msg
@ -37,10 +37,18 @@ when isMainModule:
let
address = initTAddress("127.0.0.1:8888")
socketFlags = {ServerFlags.TcpNoDelay, ServerFlags.ReuseAddr}
server = HttpServer.create(address, handle, flags = socketFlags)
server = when defined tls:
TlsHttpServer.create(
address = address,
handler = handle,
tlsPrivateKey = TLSPrivateKey.init(SecureKey),
tlsCertificate = TLSCertificate.init(SecureCert),
flags = socketFlags)
else:
HttpServer.create(address, handle, flags = socketFlags)
server.start()
info "Server listening at ", data = $server.localAddress()
trace "Server listening on ", data = $server.localAddress()
await server.join()
waitFor(main())

View File

@ -1,35 +0,0 @@
import pkg/[chronos,
chronos/streams/tlsstream,
chronicles,
stew/byteutils]
import ../ws/ws
proc main() {.async.} =
let ws = await WebSocket.tlsConnect(
"127.0.0.1",
Port(8888),
path = "/wss",
protocols = @["myfancyprotocol"],
flags = {NoVerifyHost, NoVerifyServerName})
debug "Websocket client: ", State = ws.readyState
let reqData = "Hello Server"
try:
debug "sending client "
await ws.send(reqData)
let buff = await ws.recv()
if buff.len <= 0:
break
let dataStr = string.fromBytes(buff)
debug "Server:", data = dataStr
assert dataStr == reqData
return # bail out
except WebSocketError as exc:
error "WebSocket error:", exception = exc.msg
# close the websocket
await ws.close()
waitFor(main())

View File

@ -1,54 +0,0 @@
import pkg/[chronos,
chronicles,
httputils,
stew/byteutils]
import pkg/[chronos/streams/tlsstream]
import ../ws/ws
import ../tests/keys
proc handle(request: HttpRequest) {.async.} =
debug "Handling request:", uri = request.uri.path
if request.uri.path != "/wss":
debug "Initiating web socket connection."
return
try:
let server = WSServer.new(protos = ["myfancyprotocol"])
var ws = await server.handleRequest(request)
if ws.readyState != Open:
error "Failed to open websocket connection."
return
debug "Websocket handshake completed."
# Only reads header for data frame.
echo "receiving server "
let recvData = await ws.recv()
if recvData.len <= 0:
debug "Empty messages"
break
if ws.readyState == ReadyState.Closed:
return
debug "Response: ", data = string.fromBytes(recvData)
await ws.send(recvData,
if ws.binary: Opcode.Binary else: Opcode.Text)
except WebSocketError:
error "WebSocket error:", exception = getCurrentExceptionMsg()
when isMainModule:
proc main() {.async.} =
let address = initTAddress("127.0.0.1:8888")
let socketFlags = {ServerFlags.TcpNoDelay, ServerFlags.ReuseAddr}
let server = TlsHttpServer.create(
address = address,
handler = handle,
tlsPrivateKey = TLSPrivateKey.init(SecureKey),
tlsCertificate = TLSCertificate.init(SecureCert),
flags = socketFlags)
server.start()
info "Server listening at ", data = $server.localAddress()
await server.join()
waitFor(main())

View File

@ -0,0 +1,276 @@
import std/strutils
import pkg/[chronos, stew/byteutils]
import ../../ws/ws
import ../asyncunit
type
ExtHandler = proc(ext: Ext, frame: Frame): Future[Frame] {.raises: [Defect].}
HelperExtension = ref object of Ext
handler*: ExtHandler
proc new*(
T: typedesc[HelperExtension],
handler: ExtHandler,
session: WSSession = nil): HelperExtension =
HelperExtension(
handler: handler,
name: "HelperExtension")
method decode*(
self: HelperExtension,
frame: Frame): Future[Frame] {.async.} =
return await self.handler(self, frame)
method encode*(
self: HelperExtension,
frame: Frame): Future[Frame] {.async.} =
return await self.handler(self, frame)
const TestString = "Hello"
suite "Encode frame extensions flow":
test "should call extension on encode":
var data = ""
proc toUpper(ext: Ext, frame: Frame): Future[Frame] {.async.} =
checkpoint "toUpper executed"
data = string.fromBytes(frame.data).toUpper()
check TestString.toUpper() == data
frame.data = data.toBytes()
return frame
var frame = Frame(
fin: false,
rsv1: false,
rsv2: false,
rsv3: false,
opcode: Opcode.Text,
mask: false,
data: TestString.toBytes())
discard await frame.encode(@[HelperExtension.new(toUpper).Ext])
check frame.data == TestString.toUpper().toBytes()
test "should call extensions in correct order on encode":
var count = 0
proc first(ext: Ext, frame: Frame): Future[Frame] {.async.} =
checkpoint "first executed"
check count == 0
count.inc
return frame
proc second(ext: Ext, frame: Frame): Future[Frame] {.async.} =
checkpoint "second executed"
check count == 1
count.inc
return frame
var frame = Frame(
fin: false,
rsv1: false,
rsv2: false,
rsv3: false,
opcode: Opcode.Text,
mask: false,
data: TestString.toBytes())
discard await frame.encode(@[
HelperExtension.new(first).Ext,
HelperExtension.new(second).Ext])
check count == 2
test "should allow modifying frame headers":
proc changeHeader(ext: Ext, frame: Frame): Future[Frame] {.async.} =
checkpoint "changeHeader executed"
frame.rsv1 = true
frame.rsv2 = true
frame.rsv3 = true
frame.opcode = Opcode.Binary
return frame
var frame = Frame(
fin: false,
rsv1: false,
rsv2: false,
rsv3: false,
opcode: Opcode.Text, # fragments have to be `Continuation` frames
mask: false,
data: TestString.toBytes())
discard await frame.encode(@[HelperExtension.new(changeHeader).Ext])
check:
frame.rsv1 == true
frame.rsv2 == true
frame.rsv2 == true
frame.opcode == Opcode.Binary
suite "Decode frame extensions flow":
var
address: TransportAddress
server: StreamServer
maskKey = genMaskKey(newRng())
transport: StreamTransport
reader: AsyncStreamReader
frame: Frame
setup:
server = createStreamServer(
initTAddress("127.0.0.1:0"),
flags = {ServerFlags.ReuseAddr})
address = server.localAddress()
teardown:
await transport.closeWait()
await server.closeWait()
server.stop()
test "should call extension on decode":
var data = ""
proc toUpper(ext: Ext, frame: Frame): Future[Frame] {.async.} =
checkpoint "toUpper executed"
try:
var buf = newSeq[byte](frame.length)
# read data
await reader.readExactly(addr buf[0], buf.len)
if frame.mask:
mask(buf, maskKey)
frame.mask = false # we can reset the mask key here
data = string.fromBytes(buf).toUpper()
check:
TestString.toUpper() == data
frame.data = data.toBytes()
return frame
except CatchableError as exc:
checkpoint exc.msg
check false
proc acceptHandler() {.async, gcsafe.} =
let transport = await server.accept()
reader = newAsyncStreamReader(transport)
frame = await Frame.decode(
reader,
false,
@[HelperExtension.new(toUpper).Ext])
await reader.closeWait()
await transport.closeWait()
let handlerWait = acceptHandler()
var encodedFrame = (await Frame(
fin: false,
rsv1: false,
rsv2: false,
rsv3: false,
opcode: Opcode.Text,
mask: true,
maskKey: maskKey,
data: TestString.toBytes())
.encode())
transport = await connect(address)
let wrote = await transport.write(encodedFrame)
await handlerWait
check:
wrote == encodedFrame.len
frame.data == TestString.toUpper().toBytes()
test "should call extensions in reverse order on decode":
var count = 0
proc first(ext: Ext, frame: Frame): Future[Frame] {.async.} =
checkpoint "first executed"
check count == 1
count.inc
return frame
proc second(ext: Ext, frame: Frame): Future[Frame] {.async.} =
checkpoint "second executed"
check count == 0
count.inc
return frame
proc acceptHandler() {.async, gcsafe.} =
let transport = await server.accept()
reader = newAsyncStreamReader(transport)
frame = await Frame.decode(
reader,
false,
@[HelperExtension.new(first).Ext,
HelperExtension.new(second).Ext])
await reader.closeWait()
await transport.closeWait()
let handlerWait = acceptHandler()
var encodedFrame = (await Frame(
fin: false,
rsv1: false,
rsv2: false,
rsv3: false,
opcode: Opcode.Text,
mask: true,
maskKey: maskKey,
data: TestString.toBytes())
.encode())
let transport = await connect(address)
let wrote = await transport.write(encodedFrame)
await handlerWait
check:
wrote == encodedFrame.len
count == 2
test "should allow modifying frame headers":
proc changeHeader(ext: Ext, frame: Frame): Future[Frame] {.async.} =
checkpoint "changeHeader executed"
frame.rsv1 = false
frame.rsv2 = false
frame.rsv3 = false
frame.opcode = Opcode.Binary
return frame
proc acceptHandler() {.async, gcsafe.} =
let transport = await server.accept()
reader = newAsyncStreamReader(transport)
frame = await Frame.decode(
reader,
false,
@[HelperExtension.new(changeHeader).Ext])
check:
frame.rsv1 == false
frame.rsv2 == false
frame.rsv2 == false
frame.opcode == Opcode.Binary
await reader.closeWait()
await transport.closeWait()
let handlerWait = acceptHandler()
var encodedFrame = (await Frame(
fin: false,
rsv1: true,
rsv2: true,
rsv3: true,
opcode: Opcode.Text,
mask: true,
maskKey: maskKey,
data: TestString.toBytes())
.encode())
let transport = await connect(address)
let wrote = await transport.write(encodedFrame)
await handlerWait
check:
wrote == encodedFrame.len

81
tests/helpers.nim Normal file
View File

@ -0,0 +1,81 @@
import std/[strutils, random]
import pkg/[
chronos,
chronos/streams/tlsstream,
httputils,
chronicles,
stew/byteutils]
import ../ws/ws
import ./keys
let
WSSecureKey* = TLSPrivateKey.init(SecureKey)
WSSecureCert* = TLSCertificate.init(SecureCert)
const WSPath* = when defined secure: "/wss" else: "/ws"
proc rndStr*(size: int): string =
for _ in 0..<size:
add(result, char(rand(int('A') .. int('z'))))
proc rndBin*(size: int): seq[byte] =
for _ in 0..<size:
add(result, byte(rand(0 .. 255)))
proc waitForClose*(ws: WSSession) {.async.} =
try:
while ws.readystate != ReadyState.Closed:
discard await ws.recv()
except CatchableError:
trace "Closing websocket"
proc createServer*(
address = initTAddress("127.0.0.1:8888"),
tlsPrivateKey = WSSecureKey,
tlsCertificate = WSSecureCert,
handler: HttpAsyncCallback = nil,
flags: set[ServerFlags] = {ServerFlags.TcpNoDelay, ServerFlags.ReuseAddr},
tlsFlags: set[TLSFlags] = {},
tlsMinVersion = TLSVersion.TLS12,
tlsMaxVersion = TLSVersion.TLS12): HttpServer =
when defined secure:
TlsHttpServer.create(
address = address,
tlsPrivateKey = tlsPrivateKey,
tlsCertificate = tlsCertificate,
handler = handler,
flags = flags,
tlsFlags = tlsFlags,
tlsMinVersion = tlsMinVersion,
tlsMaxVersion = tlsMaxVersion)
else:
HttpServer.create(
address = address,
handler = handler,
flags = flags)
proc connectClient*(
address = initTAddress("127.0.0.1:8888"),
path = WSPath,
protocols: seq[string] = @["proto"],
flags: set[TLSFlags] = {TLSFlags.NoVerifyHost, TLSFlags.NoVerifyServerName},
version = WSDefaultVersion,
frameSize = WSDefaultFrameSize,
onPing: ControlCb = nil,
onPong: ControlCb = nil,
onClose: CloseCb = nil,
rng: Rng = nil): Future[WSSession] {.async.} =
let secure = when defined secure: true else: false
return await WebSocket.connect(
address = address,
flags = flags,
path = path,
secure = secure,
protocols = protocols,
version = version,
frameSize = frameSize,
onPing = onPing,
onPong = onPong,
onClose = onClose,
rng = rng)

View File

@ -2,4 +2,4 @@
import ./testframes
import ./testutf8
import ./test_ext_utils
import ./testextutils

View File

@ -9,7 +9,7 @@
import
pkg/[asynctest, chronos],
../ws/ext_utils
../ws/extensions
suite "extension parser":
test "single extension":
@ -222,47 +222,47 @@ suite "extension parser":
var app: seq[AppExt]
let res = parseExt("filename=foo.html, filename=bar.html", app)
check res == false
test "emptydisposition":
var app: seq[AppExt]
let res = parseExt(" ; filename=foo.html", app)
check res == false
test "doublecolon":
var app: seq[AppExt]
let res = parseExt(": inline; attachment; filename=foo.html", app)
check res == false
test "attbrokenquotedfn":
var app: seq[AppExt]
let res = parseExt(" attachment; filename=\"foo.html\".txt", app)
check res == false
test "attbrokenquotedfn2":
var app: seq[AppExt]
let res = parseExt("attachment; filename=\"bar", app)
check res == false
test "attbrokenquotedfn3":
var app: seq[AppExt]
let res = parseExt("attachment; filename=foo\"bar;baz\"qux", app)
check res == false
test "attmissingdelim":
var app: seq[AppExt]
let res = parseExt("attachment; foo=foo filename=bar", app)
check res == false
test "attmissingdelim2":
var app: seq[AppExt]
let res = parseExt("attachment; filename=bar foo=foo", app)
check res == false
test "attmissingdelim3":
var app: seq[AppExt]
let res = parseExt("attachment filename=bar", app)
check res == false
test "attreversed":
var app: seq[AppExt]
let res = parseExt("filename=foo.html; attachment", app)

View File

@ -15,7 +15,7 @@ import
chronos,
chronicles
],
../ws/[ws, utf8_dfa]
../ws/[ws, utf8dfa]
suite "UTF-8 DFA validator":
test "single octet":
@ -76,7 +76,7 @@ proc waitForClose(ws: WSSession) {.async.} =
while ws.readystate != ReadyState.Closed:
discard await ws.recv()
except CatchableError:
debug "Closing websocket"
trace "Closing websocket"
# TODO: use new test framework from dryajov
# if it is ready.
@ -125,10 +125,10 @@ suite "UTF-8 validator in action":
proc handle(request: HttpRequest) {.async.} =
check request.uri.path == "/ws"
proc onClose(status: Status, reason: string):
proc onClose(status: StatusCodes, reason: string):
CloseResult {.gcsafe, raises: [Defect].} =
try:
check status == Status.Fulfilled
check status == StatusFulfilled
check reason == closeReason
return (status, reason)
except Exception as exc:
@ -137,6 +137,8 @@ suite "UTF-8 validator in action":
let server = WSServer.new(protos = ["proto"], onClose = onClose)
let ws = await server.handleRequest(request)
let res = await ws.recv()
await waitForClose(ws)
check:
string.fromBytes(res) == testData
ws.binary == false
@ -168,6 +170,7 @@ suite "UTF-8 validator in action":
let server = WSServer.new(protos = ["proto"])
let ws = await server.handleRequest(request)
discard await ws.recv()
await waitForClose(ws)
server = HttpServer.create(
address,
@ -183,7 +186,7 @@ suite "UTF-8 validator in action":
)
await session.send(testData)
await waitForClose( session)
await session.close()
check session.readyState == ReadyState.Closed
test "invalid UTF-8 sequence close code":
@ -201,6 +204,8 @@ suite "UTF-8 validator in action":
string.fromBytes(res) == testData
ws.binary == false
await waitForClose(ws)
server = HttpServer.create(
address,
handle,
@ -216,5 +221,4 @@ suite "UTF-8 validator in action":
await session.send(testData)
await session.close(reason = closeReason)
await waitForClose( session)
check session.readyState == ReadyState.Closed

View File

@ -5,85 +5,16 @@ import pkg/[
chronicles,
stew/byteutils]
import ./asynctest
import ../ws/ws
import ./keys
var server: HttpServer
import ./asynctest
import ./helpers
let
address = initTAddress("127.0.0.1:8888")
socketFlags = {ServerFlags.TcpNoDelay, ServerFlags.ReuseAddr}
clientFlags = {NoVerifyHost, NoVerifyServerName}
secureKey = TLSPrivateKey.init(SecureKey)
secureCert = TLSCertificate.init(SecureCert)
address* = initTAddress("127.0.0.1:8888")
const WSPath = when defined secure: "/wss" else: "/ws"
proc rndStr*(size: int): string =
for _ in .. size:
add(result, char(rand(int('A') .. int('z'))))
proc rndBin*(size: int): seq[byte] =
for _ in .. size:
add(result, byte(rand(0 .. 255)))
proc waitForClose(ws: WSSession) {.async.} =
try:
while ws.readystate != ReadyState.Closed:
discard await ws.recv()
except CatchableError:
debug "Closing websocket"
proc createServer(
address = initTAddress("127.0.0.1:8888"),
tlsPrivateKey = secureKey,
tlsCertificate = secureCert,
handler: HttpAsyncCallback = nil,
flags: set[ServerFlags] = socketFlags,
tlsFlags: set[TLSFlags] = {},
tlsMinVersion = TLSVersion.TLS12,
tlsMaxVersion = TLSVersion.TLS12): HttpServer =
when defined secure:
TlsHttpServer.create(
address = address,
tlsPrivateKey = tlsPrivateKey,
tlsCertificate = tlsCertificate,
handler = handler,
flags = flags,
tlsFlags = tlsFlags,
tlsMinVersion = tlsMinVersion,
tlsMaxVersion = tlsMaxVersion)
else:
HttpServer.create(
address = address,
handler = handler,
flags = flags)
proc connectClient*(
address = initTAddress("127.0.0.1:8888"),
path = WSPath,
protocols: seq[string] = @["proto"],
flags: set[TLSFlags] = clientFlags,
version = WSDefaultVersion,
frameSize = WSDefaultFrameSize,
onPing: ControlCb = nil,
onPong: ControlCb = nil,
onClose: CloseCb = nil,
rng: Rng = nil): Future[WSSession] {.async.} =
let secure = when defined secure: true else: false
return await WebSocket.connect(
address = address,
flags = flags,
path = path,
secure = secure,
protocols = protocols,
version = version,
frameSize = frameSize,
onPing = onPing,
onPong = onPong,
onClose = onClose,
rng = rng)
var
server: HttpServer
suite "Test handshake":
teardown:
@ -175,25 +106,20 @@ suite "Test handshake":
parseUri(uri),
protocols = @["proto"])
# test "AsyncStream leaks test":
# check:
# getTracker("async.stream.reader").isLeaked() == false
# getTracker("async.stream.writer").isLeaked() == false
# getTracker("stream.server").isLeaked() == false
# getTracker("stream.transport").isLeaked() == false
suite "Test transmission":
teardown:
server.stop()
await server.closeWait()
test "Send text message message with payload of length 65535":
let testString = rndStr(65535)
test "Server - test reading simple frame":
let testString = "Hello!"
proc handle(request: HttpRequest) {.async.} =
check request.uri.path == WSPath
let server = WSServer.new(protos = ["proto"])
let ws = await server.handleRequest(request)
let servRes = await ws.recv()
check string.fromBytes(servRes) == testString
await ws.waitForClose()
@ -207,15 +133,13 @@ suite "Test transmission":
await session.send(testString)
await session.close()
test "Server - test reading simple frame":
let testString = "Hello!"
test "Send text message message with payload of length 65535":
let testString = rndStr(65535)
proc handle(request: HttpRequest) {.async.} =
check request.uri.path == WSPath
let server = WSServer.new(protos = ["proto"])
let ws = await server.handleRequest(request)
let servRes = await ws.recv()
check string.fromBytes(servRes) == testString
await ws.waitForClose()
@ -250,83 +174,11 @@ suite "Test transmission":
check string.fromBytes(clientRes) == testString
await waitForClose(session)
# test "AsyncStream leaks test":
# check:
# getTracker("async.stream.reader").isLeaked() == false
# getTracker("async.stream.writer").isLeaked() == false
# getTracker("stream.server").isLeaked() == false
# getTracker("stream.transport").isLeaked() == false
suite "Test ping-pong":
teardown:
server.stop()
await server.closeWait()
test "Send text Message fragmented into 2 fragments, one ping with payload in-between":
var ping, pong = false
let testString = "1234567890"
let msg = toBytes(testString)
let maxFrameSize = 5
proc handle(request: HttpRequest) {.async.} =
check request.uri.path == WSPath
let server = WSServer.new(
protos = ["proto"],
onPing = proc(data: openArray[byte]) =
ping = true
)
let ws = await server.handleRequest(request)
let respData = await ws.recv()
check string.fromBytes(respData) == testString
await waitForClose(ws)
server = createServer(
address = address,
handler = handle,
flags = {ReuseAddr})
server.start()
let session = await connectClient(
address = initTAddress("127.0.0.1:8888"),
frameSize = maxFrameSize,
onPong = proc(data: openArray[byte]) =
pong = true
)
let maskKey = genMaskKey(newRng())
await session.stream.writer.write(
(await Frame(
fin: false,
rsv1: false,
rsv2: false,
rsv3: false,
opcode: Opcode.Text,
mask: true,
data: msg[0..4],
maskKey: maskKey)
.encode()))
await session.ping()
await session.stream.writer.write(
(await Frame(
fin: true,
rsv1: false,
rsv2: false,
rsv3: false,
opcode: Opcode.Cont,
mask: true,
data: msg[5..9],
maskKey: maskKey)
.encode()))
await session.close()
check:
ping
pong
test "Server - test ping-pong control messages":
var ping, pong = false
proc handle(request: HttpRequest) {.async.} =
@ -389,12 +241,59 @@ suite "Test ping-pong":
await session.ping()
await session.close()
# test "AsyncStream leaks test":
# check:
# getTracker("async.stream.reader").isLeaked() == false
# getTracker("async.stream.writer").isLeaked() == false
# getTracker("stream.server").isLeaked() == false
# getTracker("stream.transport").isLeaked() == false
test "Send ping with small text payload":
let testData = toBytes("Hello, world!")
var ping, pong = false
proc handle(request: HttpRequest) {.async.} =
check request.uri.path == WSPath
let server = WSServer.new(
protos = ["proto"],
onPing = proc(data: openArray[byte]) =
ping = data == testData)
let ws = await server.handleRequest(request)
await waitForClose(ws)
server = createServer(
address = address,
handler = handle,
flags = {ReuseAddr})
server.start()
let session = await connectClient(
address = initTAddress("127.0.0.1:8888"),
onPong = proc(data: openArray[byte]) =
pong = true
)
await session.ping(testData)
await session.close()
check:
ping
pong
test "Test ping payload message length":
proc handle(request: HttpRequest) {.async.} =
check request.uri.path == WSPath
let server = WSServer.new(protos = ["proto"])
let ws = await server.handleRequest(request)
expect WSPayloadTooLarge:
discard await ws.recv()
await waitForClose(ws)
server = createServer(
address = address,
handler = handle,
flags = {ReuseAddr})
server.start()
let str = rndStr(126)
let session = await connectClient()
await session.ping(str.toBytes())
await session.close()
suite "Test framing":
teardown:
@ -408,13 +307,13 @@ suite "Test framing":
let server = WSServer.new(protos = ["proto"])
let ws = await server.handleRequest(request)
let frame1 = await ws.readFrame()
let frame1 = await ws.readFrame(@[])
check not isNil(frame1)
var data1 = newSeq[byte](frame1.remainder().int)
let read1 = await ws.stream.reader.readOnce(addr data1[0], data1.len)
check read1 == 5
let frame2 = await ws.readFrame()
let frame2 = await ws.readFrame(@[])
check not isNil(frame2)
var data2 = newSeq[byte](frame2.remainder().int)
let read2 = await ws.stream.reader.readOnce(addr data2[0], data2.len)
@ -454,16 +353,8 @@ suite "Test framing":
expect WSMaxMessageSizeError:
discard await session.recv(5)
await waitForClose(session)
# test "AsyncStream leaks test":
# check:
# getTracker("async.stream.reader").isLeaked() == false
# getTracker("async.stream.writer").isLeaked() == false
# getTracker("stream.server").isLeaked() == false
# getTracker("stream.transport").isLeaked() == false
suite "Test Closing":
teardown:
server.stop()
@ -490,15 +381,15 @@ suite "Test Closing":
proc handle(request: HttpRequest) {.async.} =
check request.uri.path == WSPath
proc closeServer(status: Status, reason: string): CloseResult{.gcsafe,
proc closeServer(status: StatusCodes, reason: string): CloseResult{.gcsafe,
raises: [Defect].} =
try:
check status == Status.TooLarge
check status == StatusTooLarge
check reason == "Message too big!"
except Exception as exc:
raise newException(Defect, exc.msg)
return (Status.Fulfilled, "")
return (StatusFulfilled, "")
let server = WSServer.new(
protos = ["proto"],
@ -514,11 +405,11 @@ suite "Test Closing":
flags = {ReuseAddr})
server.start()
proc clientClose(status: Status, reason: string): CloseResult {.gcsafe,
proc clientClose(status: StatusCodes, reason: string): CloseResult {.gcsafe,
raises: [Defect].} =
try:
check status == Status.Fulfilled
return (Status.TooLarge, "Message too big!")
check status == StatusFulfilled
return (StatusTooLarge, "Message too big!")
except Exception as exc:
raise newException(Defect, exc.msg)
@ -548,11 +439,11 @@ suite "Test Closing":
test "Client closing with status":
proc handle(request: HttpRequest) {.async.} =
check request.uri.path == WSPath
proc closeServer(status: Status, reason: string): CloseResult{.gcsafe,
proc closeServer(status: StatusCodes, reason: string): CloseResult{.gcsafe,
raises: [Defect].} =
try:
check status == Status.Fulfilled
return (Status.TooLarge, "Message too big!")
check status == StatusFulfilled
return (StatusTooLarge, "Message too big!")
except Exception as exc:
raise newException(Defect, exc.msg)
@ -570,12 +461,12 @@ suite "Test Closing":
flags = {ReuseAddr})
server.start()
proc clientClose(status: Status, reason: string): CloseResult {.gcsafe,
proc clientClose(status: StatusCodes, reason: string): CloseResult {.gcsafe,
raises: [Defect].} =
try:
check status == Status.TooLarge
check status == StatusTooLarge
check reason == "Message too big!"
return (Status.Fulfilled, "")
return (StatusFulfilled, "")
except Exception as exc:
raise newException(Defect, exc.msg)
@ -592,7 +483,7 @@ suite "Test Closing":
let server = WSServer.new(protos = ["proto"])
let ws = await server.handleRequest(request)
await ws.close(code = Status.ReservedCode)
await ws.close(code = StatusCodes(StatusLibsCodes.high))
server = createServer(
address = address,
@ -600,11 +491,11 @@ suite "Test Closing":
flags = {ReuseAddr})
server.start()
proc closeClient(status: Status, reason: string): CloseResult{.gcsafe,
raises: [Defect].} =
proc closeClient(status: StatusCodes, reason: string): CloseResult
{.gcsafe, raises: [Defect].} =
try:
check status == Status.ReservedCode
return (Status.ReservedCode, "Reserved Status")
check status == StatusCodes(StatusLibsCodes.high)
return (StatusCodes(StatusLibsCodes.high), "Reserved StatusCodes")
except Exception as exc:
raise newException(Defect, exc.msg)
@ -617,11 +508,11 @@ suite "Test Closing":
test "Client closing with valid close code 3999":
proc handle(request: HttpRequest) {.async.} =
check request.uri.path == WSPath
proc closeServer(status: Status, reason: string): CloseResult{.gcsafe,
proc closeServer(status: StatusCodes, reason: string): CloseResult{.gcsafe,
raises: [Defect].} =
try:
check status == Status.ReservedCode
return (Status.ReservedCode, "Reserved Status")
check status == StatusCodes(3999)
return (StatusCodes(3999), "Reserved StatusCodes")
except Exception as exc:
raise newException(Defect, exc.msg)
@ -640,7 +531,7 @@ suite "Test Closing":
server.start()
let session = await connectClient()
await session.close(code = Status.ReservedCode)
await session.close(code = StatusCodes(3999))
test "Server closing with Payload of length 2":
proc handle(request: HttpRequest) {.async.} =
@ -681,41 +572,12 @@ suite "Test Closing":
# Close with payload of length 2
await session.close(reason = "HH")
# test "AsyncStream leaks test":
# check:
# getTracker("async.stream.reader").isLeaked() == false
# getTracker("async.stream.writer").isLeaked() == false
# getTracker("stream.server").isLeaked() == false
# getTracker("stream.transport").isLeaked() == false
suite "Test Payload":
teardown:
server.stop()
await server.closeWait()
test "Test payload message length":
proc handle(request: HttpRequest) {.async.} =
check request.uri.path == WSPath
let server = WSServer.new(protos = ["proto"])
let ws = await server.handleRequest(request)
expect WSPayloadTooLarge:
discard await ws.recv()
await waitForClose(ws)
server = createServer(
address = address,
handler = handle,
flags = {ReuseAddr})
server.start()
let str = rndStr(126)
let session = await connectClient()
await session.ping(str.toBytes())
await session.close()
test "Test single empty payload":
test "Test payload of length 0":
let emptyStr = ""
proc handle(request: HttpRequest) {.async.} =
check request.uri.path == WSPath
@ -724,8 +586,12 @@ suite "Test Payload":
let ws = await server.handleRequest(request)
let servRes = await ws.recv()
check string.fromBytes(servRes) == emptyStr
await waitForClose(ws)
check:
servRes.len == 0
string.fromBytes(servRes) == emptyStr
await ws.send(emptyStr)
await ws.waitForClose()
server = createServer(
address = address,
@ -734,20 +600,32 @@ suite "Test Payload":
server.start()
let session = await connectClient()
await session.send(emptyStr)
let clientRes = await session.recv()
check:
clientRes.len == 0
string.fromBytes(clientRes) == emptyStr
await session.close()
test "Test multiple empty payload":
test "Test multiple payloads of length 0":
let emptyStr = ""
proc handle(request: HttpRequest) {.async.} =
check request.uri.path == WSPath
let server = WSServer.new(protos = ["proto"])
let ws = await server.handleRequest(request)
let servRes = await ws.recv()
for _ in 0..<3:
let servRes = await ws.recv()
check:
servRes.len == 0
string.fromBytes(servRes) == emptyStr
for i in 0..3:
await ws.send(emptyStr)
check string.fromBytes(servRes) == emptyStr
await waitForClose(ws)
server = createServer(
@ -757,22 +635,35 @@ suite "Test Payload":
server.start()
let session = await connectClient()
for i in 0..3:
await session.send(emptyStr)
for _ in 0..<3:
let clientRes = await session.recv()
check:
clientRes.len == 0
string.fromBytes(clientRes) == emptyStr
await session.close()
test "Send ping with small text payload":
let testData = toBytes("Hello, world!")
test "Send two fragments":
var ping, pong = false
let testString = "1234567890"
let msg = toBytes(testString)
let maxFrameSize = 5
proc handle(request: HttpRequest) {.async.} =
check request.uri.path == WSPath
let server = WSServer.new(
protos = ["proto"],
onPing = proc(data: openArray[byte]) =
ping = data == testData)
let server = WSServer.new(protos = ["proto"])
let ws = await server.handleRequest(request)
let respData = await ws.recv()
check:
string.fromBytes(respData) == testString
await waitForClose(ws)
server = createServer(
@ -783,22 +674,133 @@ suite "Test Payload":
let session = await connectClient(
address = initTAddress("127.0.0.1:8888"),
frameSize = maxFrameSize)
let maskKey = genMaskKey(newRng())
await session.stream.writer.write(
(await Frame(
fin: false,
rsv1: false,
rsv2: false,
rsv3: false,
opcode: Opcode.Text,
mask: true,
data: msg[0..4],
maskKey: maskKey)
.encode()))
await session.stream.writer.write(
(await Frame(
fin: true,
rsv1: false,
rsv2: false,
rsv3: false,
opcode: Opcode.Cont,
mask: true,
data: msg[5..9],
maskKey: maskKey)
.encode()))
await session.close()
test "Send two fragments with a ping with payload in-between":
var ping, pong = false
let testString = "1234567890"
let msg = toBytes(testString)
let maxFrameSize = 5
proc handle(request: HttpRequest) {.async.} =
check request.uri.path == WSPath
let server = WSServer.new(
protos = ["proto"],
onPing = proc(data: openArray[byte]) =
ping = true
)
let ws = await server.handleRequest(request)
let respData = await ws.recv()
check:
string.fromBytes(respData) == testString
await waitForClose(ws)
server = createServer(
address = address,
handler = handle,
flags = {ReuseAddr})
server.start()
let session = await connectClient(
address = initTAddress("127.0.0.1:8888"),
frameSize = maxFrameSize,
onPong = proc(data: openArray[byte]) =
pong = true
)
await session.ping(testData)
let maskKey = genMaskKey(newRng())
await session.stream.writer.write(
(await Frame(
fin: false,
rsv1: false,
rsv2: false,
rsv3: false,
opcode: Opcode.Text,
mask: true,
data: msg[0..4],
maskKey: maskKey)
.encode()))
await session.ping()
await session.stream.writer.write(
(await Frame(
fin: true,
rsv1: false,
rsv2: false,
rsv3: false,
opcode: Opcode.Cont,
mask: true,
data: msg[5..9],
maskKey: maskKey)
.encode()))
await session.close()
check:
ping
pong
# test "AsyncStream leaks test":
# check:
# getTracker("async.stream.reader").isLeaked() == false
# getTracker("async.stream.writer").isLeaked() == false
# getTracker("stream.server").isLeaked() == false
# getTracker("stream.transport").isLeaked() == false
test "Send text message with multiple frames":
const FrameSize = 3000
let testData = rndStr(FrameSize * 3)
proc handle(request: HttpRequest) {.async.} =
check request.uri.path == WSPath
let server = WSServer.new(protos = ["proto"])
let ws = await server.handleRequest(request)
let res = await ws.recv()
check ws.binary == false
await ws.send(res, Opcode.Text)
await waitForClose(ws)
server = createServer(
address = address,
handler = handle,
flags = {ReuseAddr})
server.start()
let ws = await connectClient(
address = address,
frameSize = FrameSize
)
await ws.send(testData)
let echoed = await ws.recv()
await ws.close()
check:
string.fromBytes(echoed) == testData
ws.binary == false
suite "Test Binary message with Payload":
teardown:
@ -860,7 +862,7 @@ suite "Test Binary message with Payload":
test "Send binary data with small text payload":
let testData = rndBin(10)
debug "testData", testData = testData
trace "testData", testData = testData
var ping, pong = false
proc handle(request: HttpRequest) {.async.} =
check request.uri.path == WSPath
@ -929,9 +931,42 @@ suite "Test Binary message with Payload":
await session.send(testData, Opcode.Binary)
await session.close()
# test "AsyncStream leaks test":
# check:
# getTracker("async.stream.reader").isLeaked() == false
# getTracker("async.stream.writer").isLeaked() == false
# getTracker("stream.server").isLeaked() == false
# getTracker("stream.transport").isLeaked() == false
test "Send binary message with multiple frames":
const FrameSize = 3000
let testData = rndBin(FrameSize * 3)
proc handle(request: HttpRequest) {.async.} =
check request.uri.path == WSPath
let server = WSServer.new(protos = ["proto"])
let ws = await server.handleRequest(request)
let res = await ws.recv()
check:
ws.binary == true
res == testData
await ws.send(res, Opcode.Binary)
await waitForClose(ws)
server = createServer(
address = address,
handler = handle,
flags = {ReuseAddr})
server.start()
let ws = await connectClient(
address = address,
frameSize = FrameSize
)
await ws.send(testData, Opcode.Binary)
let echoed = await ws.recv()
check:
echoed == testData
await ws.close()
check:
echoed == testData
ws.binary == true

6
ws/extensions.nim Normal file
View File

@ -0,0 +1,6 @@
import std/tables
import ./extensions/extutils
# import ./extensions/compression/compression
export extutils

View File

@ -0,0 +1,212 @@
## nim-ws
## Copyright (c) 2021 Status Research & Development GmbH
## Licensed under either of
## * Apache License, version 2.0, ([LICENSE-APACHE](LICENSE-APACHE))
## * MIT license ([LICENSE-MIT](LICENSE-MIT))
## at your option.
## This file may not be copied, modified, or distributed except according to
## those terms.
import
std/[strutils],
pkg/[stew/results, chronos],
../../types, ../../frame, ./miniz/miniz_api
type
DeflateOpts = object
isServer: bool
serverNoContextTakeOver: bool
clientNoContextTakeOver: bool
serverMaxWindowBits: int
clientMaxWindowBits: int
DeflateExt = ref object of Ext
paramStr: string
opts: DeflateOpts
# 'messageCompressed' is a two way flag:
# 1. the original message is compressible or not
# 2. the frame we received need decompression or not
messageCompressed: bool
const
extID = "permessage-deflate"
proc concatParam(resp: var string, param: string) =
resp.add ';'
resp.add param
proc validateWindowBits(arg: ExtParam, res: var int): Result[string, string] =
if arg.value.len == 0:
return ok("")
if arg.value.len > 2:
return err("window bits expect 2 bytes, got " & $arg.value.len)
for n in arg.value:
if n notin Digits:
return err("window bits value contains illegal char: " & $n)
var winbit = 0
for i in 0..<arg.value.len:
winbit = winbit * 10 + arg.value[i].int - '0'.int
if winbit < 8 or winbit > 15:
return err("window bits should between 8-15, got " & $winbit)
res = winbit
return ok("=" & arg.value)
proc validateParams(args: seq[ExtParam],
opts: var DeflateOpts): Result[string, string] =
# besides validating extensions params, this proc
# also constructing extension param for response
var resp = ""
for arg in args:
case arg.name
of "server_no_context_takeover":
if arg.value.len > 0:
return err("'server_no_context_takeover' should have no param")
if opts.isServer:
concatParam(resp, arg.name)
opts.serverNoContextTakeOver = true
of "client_no_context_takeover":
if arg.value.len > 0:
return err("'client_no_context_takeover' should have no param")
if opts.isServer:
concatParam(resp, arg.name)
opts.clientNoContextTakeOver = true
of "server_max_window_bits":
if opts.isServer:
concatParam(resp, arg.name)
let res = validateWindowBits(arg, opts.serverMaxWindowBits)
if res.isErr:
return res
resp.add res.get()
of "client_max_window_bits":
if opts.isServer:
concatParam(resp, arg.name)
let res = validateWindowBits(arg, opts.clientMaxWindowBits)
if res.isErr:
return res
resp.add res.get()
else:
return err("unrecognized param: " & arg.name)
ok(resp)
method decode*(ext: DeflateExt, data: seq[byte]): Future[seq[byte]] {.async.} =
if not ext.messageCompressed:
return data
# TODO: append trailing bytes
var mz = MzStream(
next_in: cast[ptr cuchar](data[0].unsafeAddr),
avail_in: data.len.cuint
)
let windowBits = if ext.opts.serverMaxWindowBits == 0:
MZ_DEFAULT_WINDOW_BITS
else:
MzWindowBits(ext.opts.serverMaxWindowBits)
doAssert(mz.inflateInit2(windowBits) == MZ_OK)
var res: seq[byte]
var buf: array[0xFFFF, byte]
while true:
mz.next_out = cast[ptr cuchar](buf[0].addr)
mz.avail_out = buf.len.cuint
let r = mz.inflate(MZ_SYNC_FLUSH)
let outSize = buf.len - mz.avail_out.int
res.add toOpenArray(buf, 0, outSize-1)
if r == MZ_STREAM_END:
break
elif r == MZ_OK:
continue
else:
doAssert(false, "decompression error")
doAssert(mz.inflateEnd() == MZ_OK)
return res
method encode*(ext: DeflateExt, data: seq[byte]): Future[seq[byte]] {.async.} =
var mz = MzStream(
next_in: cast[ptr cuchar](data[0].unsafeAddr),
avail_in: data.len.cuint
)
let windowBits = if ext.opts.serverMaxWindowBits == 0:
MZ_DEFAULT_WINDOW_BITS
else:
MzWindowBits(ext.opts.serverMaxWindowBits)
doAssert(mz.deflateInit2(
level = MZ_DEFAULT_LEVEL,
meth = MZ_DEFLATED,
windowBits,
1,
strategy = MZ_DEFAULT_STRATEGY) == MZ_OK
)
let maxSize = mz.deflateBound(data.len.culong).int
var res: seq[byte]
var buf: array[0xFFFF, byte]
while true:
mz.next_out = cast[ptr cuchar](buf[0].addr)
mz.avail_out = buf.len.cuint
let r = mz.deflate(MZ_FINISH)
let outSize = buf.len - mz.avail_out.int
res.add toOpenArray(buf, 0, outSize-1)
if r == MZ_STREAM_END:
break
elif r == MZ_OK:
continue
else:
doAssert(false, "compression error")
# TODO: cut trailing bytes
doAssert(mz.deflateEnd() == MZ_OK)
ext.messageCompressed = res.len < data.len
if ext.messageCompressed:
return res
else:
return data
method decode(ext: DeflateExt, frame: Frame): Future[Frame] {.async.} =
if frame.opcode in {Opcode.Text, Opcode.Binary}:
# only data frame can be compressed
# and we want to know if this message is compressed or not
# if the frame opcode is text or binary, it should also the first frame
ext.messageCompressed = frame.rsv1
# clear rsv1 bit because we already done with it
frame.rsv1 = false
return frame
method encode(ext: DeflateExt, frame: Frame): Future[Frame] {.async.} =
if frame.opcode in {Opcode.Text, Opcode.Binary}:
# only data frame can be compressed
# and we only set rsv1 bit to true if the message is compressible
# if the frame opcode is text or binary, it should also the first frame
frame.rsv1 = ext.messageCompressed
return frame
method toHttpOptions(ext: DeflateExt): string =
# using paramStr here is a bit clunky
extID & "; " & ext.paramStr
proc deflateExtFactory(isServer: bool, args: seq[ExtParam]): Result[Ext, string] {.raises: [Defect].} =
var opts = DeflateOpts(isServer: isServer)
let resp = validateParams(args, opts)
if resp.isErr:
return err(resp.error)
let ext = DeflateExt(
name: extID,
paramStr: resp.get(),
opts: opts
)
ok(ext)
const
deflateFactory* = (extID, deflateExtFactory)

View File

@ -15,8 +15,12 @@ import pkg/[
stew/byteutils,
stew/endians2,
stew/results]
import ./types
logScope:
topics = "ws-frame"
#[
+---------------------------------------------------------------+
|0 1 2 3 |
@ -54,7 +58,6 @@ template remainder*(frame: Frame): uint64 =
proc encode*(
frame: Frame,
offset = 0,
extensions: seq[Ext] = @[]): Future[seq[byte]] {.async.} =
## Encodes a frame into a string buffer.
## See https://tools.ietf.org/html/rfc6455#section-5.2
@ -65,7 +68,7 @@ proc encode*(
f = await e.encode(f)
var ret: seq[byte]
var b0 = (f.opcode.uint8 and 0x0F) # 0th byte: opcodes and flags.
var b0 = (f.opcode.uint8 and 0x0f) # 0th byte: opcodes and flags.
if f.fin:
b0 = b0 or 128'u8
@ -77,7 +80,7 @@ proc encode*(
if f.data.len <= 125:
b1 = f.data.len.uint8
elif f.data.len > 125 and f.data.len <= 0xFFFF:
elif f.data.len > 125 and f.data.len <= 0xffff:
b1 = 126'u8
else:
b1 = 127'u8
@ -88,12 +91,12 @@ proc encode*(
ret.add(uint8 b1)
# Only need more bytes if data len is 7+16 bits, or 7+64 bits.
if f.data.len > 125 and f.data.len <= 0xFFFF:
if f.data.len > 125 and f.data.len <= 0xffff:
# Data len is 7+16 bits.
var len = f.data.len.uint16
ret.add ((len shr 8) and 0xFF).uint8
ret.add (len and 0xFF).uint8
elif f.data.len > 0xFFFF:
ret.add ((len shr 8) and 0xff).uint8
ret.add (len and 0xff).uint8
elif f.data.len > 0xffff:
# Data len is 7+64 bits.
var len = f.data.len.uint64
ret.add(len.toBytesBE())
@ -101,7 +104,7 @@ proc encode*(
var data = f.data
if f.mask:
# If we need to mask it generate random mask key and mask the data.
mask(data, f.maskKey, offset)
mask(data, f.maskKey)
# Write mask key next.
ret.add(f.maskKey[0].uint8)
@ -122,10 +125,10 @@ proc decode*(
##
var header = newSeq[byte](2)
debug "Reading new frame"
trace "Reading new frame"
await reader.readExactly(addr header[0], 2)
if header.len != 2:
debug "Invalid websocket header length"
trace "Invalid websocket header length"
raise newException(WSMalformedHeaderError,
"Invalid websocket header length")
@ -147,10 +150,6 @@ proc decode*(
frame.opcode = (opcode).Opcode
# If any of the rsv are set close the socket.
if frame.rsv1 or frame.rsv2 or frame.rsv3:
raise newException(WSRsvMismatchError, "WebSocket rsv mismatch")
# Payload length can be 7 bits, 7+16 bits, or 7+64 bits.
var finalLen: uint64 = 0
@ -187,7 +186,11 @@ proc decode*(
frame.maskKey[i] = cast[char](maskKey[i])
if extensions.len > 0:
for e in extensions[extensions.high..extensions.low]:
frame = await e.decode(frame)
for i in countdown(extensions.high, extensions.low):
frame = await extensions[i].decode(frame)
# If any of the rsv are set close the socket.
if frame.rsv1 or frame.rsv2 or frame.rsv3:
raise newException(WSRsvMismatchError, "WebSocket rsv mismatch")
return frame

View File

@ -9,6 +9,9 @@ import pkg/[
import ./common
logScope:
topics = "http-client"
type
HttpClient* = ref object of RootObj
connected*: bool
@ -44,7 +47,7 @@ proc readResponse(stream: AsyncStreamReader): Future[HttpResponseHeader] {.async
return buffer.parseResponse()
except CatchableError as exc:
debug "Exception reading headers", exc = exc.msg
trace "Exception reading headers", exc = exc.msg
buffer.setLen(0)
raise exc

View File

@ -53,8 +53,7 @@ proc closeWait*(stream: AsyncStream) {.async.} =
await allFutures(
stream.reader.tsource.closeTransp(),
stream.reader.closeStream(),
stream.writer.closeStream()
)
stream.writer.closeStream())
proc sendResponse*(
request: HttpRequest,
@ -112,8 +111,7 @@ proc sendError*(
response.add(CRLF)
await stream.write(
response.toBytes() &
content.toBytes())
response.toBytes() & content.toBytes())
proc sendError*(
request: HttpRequest,

View File

@ -30,13 +30,13 @@ proc validateRequest(
##
if header.meth notin {MethodGet}:
debug "GET method is only allowed", address = stream.tsource.remoteAddress()
trace "GET method is only allowed", address = stream.tsource.remoteAddress()
await stream.sendError(Http405, version = header.version)
return ReqStatus.Error
var hlen = header.contentLength()
if hlen < 0 or hlen > MaxHttpRequestSize:
debug "Invalid header length", address = stream.tsource.remoteAddress()
trace "Invalid header length", address = stream.tsource.remoteAddress()
await stream.sendError(Http413, version = header.version)
return ReqStatus.Error
@ -50,14 +50,14 @@ proc handleRequest(
var buffer = newSeq[byte](MaxHttpHeadersSize)
let remoteAddr = stream.reader.tsource.remoteAddress()
debug "Received connection", address = $remoteAddr
trace "Received connection", address = $remoteAddr
try:
let hlenfut = stream.reader.readUntil(
addr buffer[0], MaxHttpHeadersSize, sep = HeaderSep)
let ores = await withTimeout(hlenfut, HttpHeadersTimeout)
if not ores:
# Timeout
debug "Timeout expired while receiving headers", address = $remoteAddr
trace "Timeout expired while receiving headers", address = $remoteAddr
await stream.writer.sendError(Http408, version = HttpVersion11)
return
@ -66,7 +66,7 @@ proc handleRequest(
let requestData = buffer.parseRequest()
if requestData.failed():
# Header could not be parsed
debug "Malformed header received", address = $remoteAddr
trace "Malformed header received", address = $remoteAddr
await stream.writer.sendError(Http400, version = HttpVersion11)
return
@ -79,10 +79,10 @@ proc handleRequest(
res
if vres == ReqStatus.ErrorFailure:
debug "Remote peer disconnected", address = $remoteAddr
trace "Remote peer disconnected", address = $remoteAddr
return
debug "Received valid HTTP request", address = $remoteAddr
trace "Received valid HTTP request", address = $remoteAddr
# Call the user's handler.
if server.handler != nil:
await server.handler(
@ -92,15 +92,15 @@ proc handleRequest(
uri: requestData.uri().parseUri()))
except TransportLimitError:
# size of headers exceeds `MaxHttpHeadersSize`
debug "Maximum size of headers limit reached", address = $remoteAddr
trace "maximum size of headers limit reached", address = $remoteAddr
await stream.writer.sendError(Http413, version = HttpVersion11)
except TransportIncompleteError:
# remote peer disconnected
debug "Remote peer disconnected", address = $remoteAddr
trace "Remote peer disconnected", address = $remoteAddr
except TransportOsError as exc:
debug "Problems with networking", address = $remoteAddr, error = exc.msg
trace "Problems with networking", address = $remoteAddr, error = exc.msg
except CatchableError as exc:
debug "Unknown exception", address = $remoteAddr, error = exc.msg
trace "Unknown exception", address = $remoteAddr, error = exc.msg
finally:
await stream.closeWait()
@ -151,6 +151,8 @@ proc create*(
flags,
child = StreamServer(server)))
trace "Created HTTP Server", host = $address
return server
proc create*(
@ -191,6 +193,8 @@ proc create*(
flags,
child = StreamServer(server)))
trace "Created TLS HTTP Server", host = $address
return server
proc create*(

View File

@ -9,27 +9,27 @@
{.push raises: [Defect].}
import std/strformat
import pkg/[chronos, chronicles, stew/byteutils, stew/endians2]
import ./types, ./frame, ./utils, ./utf8_dfa, ./http
import ./types, ./frame, ./utils, ./utf8dfa, ./http
import pkg/chronos/[streams/asyncstream]
import pkg/chronos/streams/asyncstream
type
WSSession* = ref object of WebSocket
stream*: AsyncStream
frame*: Frame
proto*: string
logScope:
topics = "ws-session"
proc prepareCloseBody(code: Status, reason: string): seq[byte] =
proc prepareCloseBody(code: StatusCodes, reason: string): seq[byte] =
result = reason.toBytes
if ord(code) > 999:
result = @(ord(code).uint16.toBytesBE()) & result
proc send*(
proc writeMessage*(
ws: WSSession,
data: seq[byte] = @[],
opcode: Opcode) {.async.} =
## Send a frame
opcode: Opcode,
extensions: seq[Ext]) {.async.} =
## Send a frame applying the supplied
## extensions
##
if ws.readyState == ReadyState.Closed:
@ -40,7 +40,7 @@ proc send*(
dataSize = data.len
masked = ws.masked
debug "Sending data to remote"
trace "Sending data to remote"
var maskKey: array[4, char]
if ws.masked:
@ -61,30 +61,40 @@ proc send*(
mask: ws.masked,
data: data, # allow sending data with close messages
maskKey: maskKey)
.encode(extensions = ws.extensions)))
.encode()))
return
let maxSize = ws.frameSize
var i = 0
while ws.readyState notin {ReadyState.Closing}:
let len = min(data.len, (maxSize + i))
await ws.stream.writer.write(
(await Frame(
fin: if (i + len >= data.len): true else: false,
let len = min(data.len, maxSize)
let frame = Frame(
fin: if (len + i >= data.len): true else: false,
rsv1: false,
rsv2: false,
rsv3: false,
opcode: if i > 0: Opcode.Cont else: opcode, # fragments have to be `Continuation` frames
mask: ws.masked,
data: data[i ..< len],
data: data[i ..< len + i],
maskKey: maskKey)
.encode()))
let encoded = await frame.encode(extensions)
await ws.stream.writer.write(encoded)
i += len
if i >= data.len:
break
proc send*(
ws: WSSession,
data: seq[byte] = @[],
opcode: Opcode): Future[void] =
## Send a frame
##
return ws.writeMessage(data, opcode, ws.extensions)
proc send*(ws: WSSession, data: string): Future[void] =
send(ws, data.toBytes(), Opcode.Text)
@ -101,54 +111,62 @@ proc handleClose*(
opcode = frame.opcode
readyState = ws.readyState
debug "Handling close"
trace "Handling close"
if ws.readyState notin {ReadyState.Open}:
debug "Connection isn't open, abortig close sequence!"
if ws.readyState != ReadyState.Open:
trace "Connection isn't open, aborting close sequence!"
return
var
code = Status.Fulfilled
code = StatusFulfilled
reason = ""
if payLoad.len == 1:
case payload.len:
of 0:
code = StatusNoStatus
of 1:
raise newException(WSPayloadLengthError,
"Invalid close frame with payload length 1!")
if payLoad.len > 1:
# first two bytes are the status
let ccode = uint16.fromBytesBE(payLoad[0..<2])
if ccode <= 999 or ccode > 1015:
raise newException(WSInvalidCloseCodeError,
"Invalid code in close message!")
else:
try:
code = Status(ccode)
code = StatusCodes(uint16.fromBytesBE(payLoad[0..<2]))
except RangeError:
raise newException(WSInvalidCloseCodeError,
"Status code out of range!")
# remining payload bytes are reason for closing
if code in StatusNotUsed or
code in StatusReservedProtocol:
raise newException(WSInvalidCloseCodeError,
&"Can't use reserved status code: {code}")
if code == StatusReserved or
code == StatusNoStatus or
code == StatusClosedAbnormally:
raise newException(WSInvalidCloseCodeError,
&"Can't use reserved status code: {code}")
# remaining payload bytes are reason for closing
reason = string.fromBytes(payLoad[2..payLoad.high])
if not ws.binary and validateUTF8(reason) == false:
raise newException(WSInvalidUTF8,
"Invalid UTF8 sequence detected in close reason")
var rcode: Status
if code in {Status.Fulfilled}:
rcode = Status.Fulfilled
trace "Handling close message", code, reason
if not isNil(ws.onClose):
try:
(rcode, reason) = ws.onClose(code, reason)
(code, reason) = ws.onClose(code, reason)
except CatchableError as exc:
debug "Exception in Close callback, this is most likely a bug", exc = exc.msg
trace "Exception in Close callback, this is most likely a bug", exc = exc.msg
else:
code = StatusFulfilled
reason = ""
# don't respond to a terminated connection
if ws.readyState != ReadyState.Closing:
ws.readyState = ReadyState.Closing
await ws.send(prepareCloseBody(rcode, reason), Opcode.Close)
trace "Sending close", code, reason
await ws.send(prepareCloseBody(code, reason), Opcode.Close)
ws.readyState = ReadyState.Closed
await ws.stream.closeWait()
@ -164,7 +182,7 @@ proc handleControl*(ws: WSSession, frame: Frame) {.async.} =
readyState = ws.readyState
len = frame.length
debug "Handling control frame"
trace "Handling control frame"
if not frame.fin:
raise newException(WSFragmentedControlFrameError,
@ -191,7 +209,7 @@ proc handleControl*(ws: WSSession, frame: Frame) {.async.} =
try:
ws.onPing(payLoad)
except CatchableError as exc:
debug "Exception in Ping callback, this is most likelly a bug", exc = exc.msg
trace "Exception in Ping callback, this is most likely a bug", exc = exc.msg
# send pong to remote
await ws.send(payLoad, Opcode.Pong)
@ -200,21 +218,28 @@ proc handleControl*(ws: WSSession, frame: Frame) {.async.} =
try:
ws.onPong(payLoad)
except CatchableError as exc:
debug "Exception in Pong callback, this is most likelly a bug", exc = exc.msg
trace "Exception in Pong callback, this is most likely a bug", exc = exc.msg
of Opcode.Close:
await ws.handleClose(frame, payLoad)
else:
raise newException(WSInvalidOpcodeError, "Invalid control opcode!")
proc readFrame*(ws: WSSession): Future[Frame] {.async.} =
proc readFrame*(ws: WSSession, extensions: seq[Ext] = @[]): Future[Frame] {.async.} =
## Gets a frame from the WebSocket.
## See https://tools.ietf.org/html/rfc6455#section-5.2
##
while ws.readyState != ReadyState.Closed:
let frame = await Frame.decode(
ws.stream.reader, ws.masked, ws.extensions)
debug "Decoded new frame", opcode = frame.opcode, len = frame.length, mask = frame.mask
ws.stream.reader, ws.masked, extensions)
logScope:
opcode = frame.opcode
len = frame.length
mask = frame.mask
fin = frame.fin
trace "Decoded new frame"
# return the current frame if it's not one of the control frames
if frame.opcode notin {Opcode.Text, Opcode.Cont, Opcode.Binary}:
@ -230,59 +255,70 @@ proc recv*(
ws: WSSession,
data: pointer,
size: int): Future[int] {.async.} =
## Attempts to read up to `size` bytes
## Attempts to read up to ``size`` bytes
##
## Will read as many frames as necessary
## to fill the buffer until either
## the message ends (frame.fin) or
## the buffer is full. If no data is on
## the pipe will await until at least
## one byte is available
## If ``size`` is less than the data in
## the frame, allow reading partial frames
##
## If no data is left in the pipe await
## until at least one byte is available
##
## Otherwise, read as many frames as needed
## up to ``size`` bytes, note that we do break
## at message boundaries (``fin`` flag set).
##
## Use this to stream data from frames
##
var consumed = 0
var pbuffer = cast[ptr UncheckedArray[byte]](data)
try:
var first = true
# reset previous frame if nothing is left in it
if not isNil(ws.frame) and ws.frame.remainder <= 0:
trace "Resetting previous frame"
first = ws.frame.fin # set as first frame if last frame was final
ws.frame = nil
if isNil(ws.frame):
ws.frame = await ws.readFrame(ws.extensions)
while consumed < size:
# we might have to read more than
# one frame to fill the buffer
# TODO: Figure out a cleaner way to handle
# retrieving new frames
if isNil(ws.frame):
ws.frame = await ws.readFrame()
if isNil(ws.frame):
return consumed
if ws.frame.opcode == Opcode.Cont:
raise newException(WSOpcodeMismatchError,
"Expected Text or Binary frame")
elif (not ws.frame.fin and ws.frame.remainder() <= 0):
ws.frame = await ws.readFrame()
# This could happen if the connection is closed.
if isNil(ws.frame):
return consumed
if ws.frame.opcode != Opcode.Cont:
raise newException(WSOpcodeMismatchError,
"Expected Continuation frame")
ws.binary = ws.frame.opcode == Opcode.Binary # set binary flag
if ws.frame.fin and ws.frame.remainder() <= 0:
ws.frame = nil
trace "Empty frame, breaking"
break
let len = min(ws.frame.remainder().int, size - consumed)
if len == 0:
continue
logScope:
first = first
fin = ws.frame.fin
len = ws.frame.length
consumed = ws.frame.consumed
remainder = ws.frame.remainder
opcode = ws.frame.opcode
masked = ws.frame.mask
if first == (ws.frame.opcode == Opcode.Cont):
error "Opcode mismatch!"
raise newException(WSOpcodeMismatchError,
&"Opcode mismatch: first: {first}, opcode: {ws.frame.opcode}")
if first:
ws.binary = ws.frame.opcode == Opcode.Binary # set binary flag
trace "Setting binary flag"
let len = min(ws.frame.remainder.int, size - consumed)
if len <= 0:
trace "Nothing left to read, breaking!"
break
let read = await ws.stream.reader.readOnce(addr pbuffer[consumed], len)
if read <= 0:
continue
trace "Didn't read any bytes, breaking"
break
if ws.frame.mask:
trace "Unmasking frame"
# unmask data using offset
mask(
pbuffer.toOpenArray(consumed, (consumed + read) - 1),
@ -292,15 +328,31 @@ proc recv*(
consumed += read
ws.frame.consumed += read.uint64
trace "Read data from frame", read
# all has been consumed from the frame
# read the next frame
if ws.frame.remainder <= 0:
first = false
if ws.frame.fin: # we're at the end of the message, break
trace "Read all frames, breaking"
ws.frame = nil
break
ws.frame = await ws.readFrame(ws.extensions)
if not ws.binary and validateUTF8(pbuffer.toOpenArray(0, consumed - 1)) == false:
raise newException(WSInvalidUTF8, "Invalid UTF8 sequence detected")
return consumed.int
return consumed
except CatchableError as exc:
ws.readyState = ReadyState.Closed
await ws.stream.closeWait()
debug "Exception reading frames", exc = exc.msg
trace "Exception reading frames", exc = exc.msg
raise exc
finally:
if not isNil(ws.frame) and (ws.frame.fin and ws.frame.remainder <= 0):
ws.frame = nil
proc recv*(
ws: WSSession,
@ -318,15 +370,14 @@ proc recv*(
##
var res: seq[byte]
while ws.readyState != ReadyState.Closed:
var buf = newSeq[byte](ws.frameSize)
var buf = newSeq[byte](min(size, ws.frameSize))
let read = await ws.recv(addr buf[0], buf.len)
if read <= 0:
break
buf.setLen(read)
if res.len + buf.len > size:
raise newException(WSMaxMessageSizeError, "Max message size exceeded")
trace "Read message", size = read
res.add(buf)
# no more frames
@ -335,13 +386,14 @@ proc recv*(
# read the entire message, exit
if ws.frame.fin and ws.frame.remainder().int <= 0:
trace "Read full message, breaking!"
break
return res
proc close*(
ws: WSSession,
code: Status = Status.Fulfilled,
code = StatusFulfilled,
reason: string = "") {.async.} =
## Close the Socket, sends close packet.
##
@ -359,4 +411,4 @@ proc close*(
while ws.readyState != ReadyState.Closed:
discard await ws.recv()
except CatchableError as exc:
debug "Exception closing", exc = exc.msg
trace "Exception closing", exc = exc.msg

View File

@ -62,41 +62,18 @@ type
length*: uint64 ## Message size.
consumed*: uint64 ## how much has been consumed from the frame
Status* {.pure.} = enum
# 0-999 not used
Fulfilled = 1000
GoingAway = 1001
ProtocolError = 1002
CannotAccept = 1003
# 1004 reserved
NoStatus = 1005 # use by clients
ClosedAbnormally = 1006 # use by clients
Inconsistent = 1007
PolicyError = 1008
TooLarge = 1009
NoExtensions = 1010
UnexpectedError = 1011
ReservedCode = 3999 # use by clients
# 3000-3999 reserved for libs
# 4000-4999 reserved for applications
StatusCodes* = distinct range[0..4999]
ControlCb* = proc(data: openArray[byte] = [])
{.gcsafe, raises: [Defect].}
CloseResult* = tuple
code: Status
code: StatusCodes
reason: string
CloseCb* = proc(code: Status, reason: string):
CloseCb* = proc(code: StatusCodes, reason: string):
CloseResult {.gcsafe, raises: [Defect].}
Ext* = ref object of RootObj
name*: string
options*: Table[string, string]
ExtFactory* = proc(name: string, options: Table[string, string]):
Ext {.raises: [Defect].}
WebSocket* = ref object of RootObj
extensions*: seq[Ext]
version*: uint
@ -111,6 +88,21 @@ type
onPong*: ControlCb
onClose*: CloseCb
WSSession* = ref object of WebSocket
stream*: AsyncStream
frame*: Frame
proto*: string
Ext* = ref object of RootObj
name*: string
options*: Table[string, string]
session*: WSSession
ExtFactory* = proc(
name: string,
session: WSSession,
options: Table[string, string]): Ext {.raises: [Defect].}
WebSocketError* = object of CatchableError
WSMalformedHeaderError* = object of WebSocketError
WSFailedUpgradeError* = object of WebSocketError
@ -125,13 +117,43 @@ type
WSClosedError* = object of WebSocketError
WSSendError* = object of WebSocketError
WSPayloadTooLarge* = object of WebSocketError
WSReserverdOpcodeError* = object of WebSocketError
WSReservedOpcodeError* = object of WebSocketError
WSFragmentedControlFrameError* = object of WebSocketError
WSInvalidCloseCodeError* = object of WebSocketError
WSPayloadLengthError* = object of WebSocketError
WSInvalidOpcodeError* = object of WebSocketError
WSInvalidUTF8* = object of WebSocketError
const
StatusNotUsed* = (StatusCodes(0)..StatusCodes(999))
StatusFulfilled* = StatusCodes(1000)
StatusGoingAway* = StatusCodes(1001)
StatusProtocolError* = StatusCodes(1002)
StatusCannotAccept* = StatusCodes(1003)
StatusReserved* = StatusCodes(1004) # 1004 reserved
StatusNoStatus* = StatusCodes(1005) # use by clients
StatusClosedAbnormally* = StatusCodes(1006) # use by clients
StatusInconsistent* = StatusCodes(1007)
StatusPolicyError* = StatusCodes(1008)
StatusTooLarge* = StatusCodes(1009)
StatusNoExtensions* = StatusCodes(1010)
StatusUnexpectedError* = StatusCodes(1011)
StatusFailedTls* = StatusCodes(1015) # passed to applications to indicate TLS errors
StatusReservedProtocol* = StatusCodes(1016)..StatusCodes(2999) # reserved for this protocol
StatusLibsCodes* = (StatusCodes(3000)..StatusCodes(3999)) # 3000-3999 reserved for libs
StatusAppsCodes* = (StatusCodes(4000)..StatusCodes(4999)) # 4000-4999 reserved for apps
proc `<=`*(a, b: StatusCodes): bool = a.uint16 <= b.uint16
proc `>=`*(a, b: StatusCodes): bool = a.uint16 >= b.uint16
proc `<`*(a, b: StatusCodes): bool = a.uint16 < b.uint16
proc `>`*(a, b: StatusCodes): bool = a.uint16 > b.uint16
proc `==`*(a, b: StatusCodes): bool = a.uint16 == b.uint16
proc high*(a: HSlice[StatusCodes, StatusCodes]): uint16 = a.b.uint16
proc low*(a: HSlice[StatusCodes, StatusCodes]): uint16 = a.a.uint16
proc `$`*(a: StatusCodes): string = $(a.int)
proc `name=`*(self: Ext, name: string) =
raiseAssert "Can't change extensions name!"
@ -141,5 +163,5 @@ method decode*(self: Ext, frame: Frame): Future[Frame] {.base, async.} =
method encode*(self: Ext, frame: Frame): Future[Frame] {.base, async.} =
raiseAssert "Not implemented!"
method toHttpOptions*(self: Ext): string =
method toHttpOptions*(self: Ext): string {.base.} =
raiseAssert "Not implemented!"

View File

@ -23,7 +23,6 @@ import pkg/[chronos,
chronicles,
httputils,
stew/byteutils,
stew/endians2,
stew/base64,
stew/base10,
nimcrypto/sha]
@ -32,6 +31,9 @@ import ./utils, ./frame, ./session, /types, ./http
export utils, session, frame, types, http
logScope:
topics = "ws-server"
type
WSServer* = ref object of WebSocket
protocols: seq[string]
@ -86,7 +88,7 @@ proc connect*(
let response = try:
await client.request(uri, headers = headers)
except CatchableError as exc:
debug "Websocket failed during handshake", exc = exc.msg
trace "Websocket failed during handshake", exc = exc.msg
await client.close()
raise exc
@ -207,7 +209,7 @@ proc handleRequest*(
if ws.version != version:
await request.stream.writer.sendError(Http426)
debug "Websocket version not supported", version = ws.version
trace "Websocket version not supported", version = ws.version
raise newException(WSVersionError,
&"Websocket version not supported, Version: {version}")
@ -236,7 +238,7 @@ proc handleRequest*(
if protocol.len > 0:
headers.add("Sec-WebSocket-Protocol", protocol) # send back the first matching proto
else:
debug "Didn't match any protocol", supported = ws.protocols, requested = wantProtos
trace "Didn't match any protocol", supported = ws.protocols, requested = wantProtos
try:
await request.sendResponse(Http101, headers = headers)