Fix partial frame handling and allow extensions to hijack the flow (#56)
* moving files around * wip * wip * move tls example into server example * add tls functionality * rename * rename * fix tests * move extension related files to own folder * use trace instead of debug * export extensions * rework partial frame handling and closing * rework status codes as distincts * logging * re-enable extensions processing for frames * enable all test for non-tls server * remove tlsserver * remove offset to mask - don't think we need it * pass sessions extensions when calling send/recv * adding encode/decode extensions flow test * move server/client setup to helpers * proper frame order execution on decode * fix tls tests
This commit is contained in:
parent
e632202037
commit
3e1599d790
|
@ -232,8 +232,8 @@ jobs:
|
|||
kill $pid
|
||||
cd ..
|
||||
|
||||
nim c examples/tlsserver.nim
|
||||
examples/tlsserver &
|
||||
nim -d:tls c examples/server.nim
|
||||
examples/server &
|
||||
pid=$!
|
||||
cd autobahn
|
||||
wstest --mode fuzzingclient --spec fuzzingclient_tls.json
|
||||
|
|
|
@ -7,6 +7,6 @@
|
|||
}
|
||||
],
|
||||
"cases": ["*"],
|
||||
"exclude-cases": ["9.*", "12.*", "13.*"],
|
||||
"exclude-cases": [],
|
||||
"exclude-agent-cases": {}
|
||||
}
|
||||
|
|
|
@ -6,12 +6,19 @@ import pkg/[
|
|||
import ../ws/ws
|
||||
|
||||
proc main() {.async.} =
|
||||
let ws = await WebSocket.connect(
|
||||
"127.0.0.1",
|
||||
Port(8888),
|
||||
path = "/ws")
|
||||
let ws = when defined tls:
|
||||
await WebSocket.connect(
|
||||
"127.0.0.1",
|
||||
Port(8888),
|
||||
path = "/wss",
|
||||
flags = {TLSFlags.NoVerifyHost, TLSFlags.NoVerifyServerName})
|
||||
else:
|
||||
await WebSocket.connect(
|
||||
"127.0.0.1",
|
||||
Port(8888),
|
||||
path = "/ws")
|
||||
|
||||
debug "Websocket client: ", State = ws.readyState
|
||||
trace "Websocket client: ", State = ws.readyState
|
||||
|
||||
let reqData = "Hello Server"
|
||||
while true:
|
||||
|
@ -22,7 +29,7 @@ proc main() {.async.} =
|
|||
break
|
||||
|
||||
let dataStr = string.fromBytes(buff)
|
||||
debug "Server Response: ", data = dataStr
|
||||
trace "Server Response: ", data = dataStr
|
||||
|
||||
assert dataStr == reqData
|
||||
break
|
||||
|
|
|
@ -5,13 +5,15 @@ import pkg/[chronos,
|
|||
httputils]
|
||||
|
||||
import ../ws/ws
|
||||
import ../tests/keys
|
||||
|
||||
proc handle(request: HttpRequest) {.async.} =
|
||||
debug "Handling request:", uri = request.uri.path
|
||||
if request.uri.path != "/ws":
|
||||
trace "Handling request:", uri = request.uri.path
|
||||
let path = when defined tls: "/wss" else: "/ws"
|
||||
if request.uri.path != path:
|
||||
return
|
||||
|
||||
debug "Initiating web socket connection."
|
||||
trace "Initiating web socket connection."
|
||||
try:
|
||||
let server = WSServer.new()
|
||||
let ws = await server.handleRequest(request)
|
||||
|
@ -19,16 +21,14 @@ proc handle(request: HttpRequest) {.async.} =
|
|||
error "Failed to open websocket connection"
|
||||
return
|
||||
|
||||
debug "Websocket handshake completed"
|
||||
while true:
|
||||
trace "Websocket handshake completed"
|
||||
while ws.readyState != ReadyState.Closed:
|
||||
let recvData = await ws.recv()
|
||||
if ws.readyState == ReadyState.Closed:
|
||||
debug "Websocket closed"
|
||||
break
|
||||
|
||||
debug "Client Response: ", size = recvData.len
|
||||
trace "Client Response: ", size = recvData.len, binary = ws.binary
|
||||
await ws.send(recvData,
|
||||
if ws.binary: Opcode.Binary else: Opcode.Text)
|
||||
|
||||
except WebSocketError as exc:
|
||||
error "WebSocket error:", exception = exc.msg
|
||||
|
||||
|
@ -37,10 +37,18 @@ when isMainModule:
|
|||
let
|
||||
address = initTAddress("127.0.0.1:8888")
|
||||
socketFlags = {ServerFlags.TcpNoDelay, ServerFlags.ReuseAddr}
|
||||
server = HttpServer.create(address, handle, flags = socketFlags)
|
||||
server = when defined tls:
|
||||
TlsHttpServer.create(
|
||||
address = address,
|
||||
handler = handle,
|
||||
tlsPrivateKey = TLSPrivateKey.init(SecureKey),
|
||||
tlsCertificate = TLSCertificate.init(SecureCert),
|
||||
flags = socketFlags)
|
||||
else:
|
||||
HttpServer.create(address, handle, flags = socketFlags)
|
||||
|
||||
server.start()
|
||||
info "Server listening at ", data = $server.localAddress()
|
||||
trace "Server listening on ", data = $server.localAddress()
|
||||
await server.join()
|
||||
|
||||
waitFor(main())
|
||||
|
|
|
@ -1,35 +0,0 @@
|
|||
import pkg/[chronos,
|
||||
chronos/streams/tlsstream,
|
||||
chronicles,
|
||||
stew/byteutils]
|
||||
|
||||
import ../ws/ws
|
||||
|
||||
proc main() {.async.} =
|
||||
let ws = await WebSocket.tlsConnect(
|
||||
"127.0.0.1",
|
||||
Port(8888),
|
||||
path = "/wss",
|
||||
protocols = @["myfancyprotocol"],
|
||||
flags = {NoVerifyHost, NoVerifyServerName})
|
||||
debug "Websocket client: ", State = ws.readyState
|
||||
|
||||
let reqData = "Hello Server"
|
||||
try:
|
||||
debug "sending client "
|
||||
await ws.send(reqData)
|
||||
let buff = await ws.recv()
|
||||
if buff.len <= 0:
|
||||
break
|
||||
let dataStr = string.fromBytes(buff)
|
||||
debug "Server:", data = dataStr
|
||||
|
||||
assert dataStr == reqData
|
||||
return # bail out
|
||||
except WebSocketError as exc:
|
||||
error "WebSocket error:", exception = exc.msg
|
||||
|
||||
# close the websocket
|
||||
await ws.close()
|
||||
|
||||
waitFor(main())
|
|
@ -1,54 +0,0 @@
|
|||
import pkg/[chronos,
|
||||
chronicles,
|
||||
httputils,
|
||||
stew/byteutils]
|
||||
|
||||
import pkg/[chronos/streams/tlsstream]
|
||||
|
||||
import ../ws/ws
|
||||
import ../tests/keys
|
||||
|
||||
proc handle(request: HttpRequest) {.async.} =
|
||||
debug "Handling request:", uri = request.uri.path
|
||||
if request.uri.path != "/wss":
|
||||
debug "Initiating web socket connection."
|
||||
return
|
||||
|
||||
try:
|
||||
let server = WSServer.new(protos = ["myfancyprotocol"])
|
||||
var ws = await server.handleRequest(request)
|
||||
if ws.readyState != Open:
|
||||
error "Failed to open websocket connection."
|
||||
return
|
||||
debug "Websocket handshake completed."
|
||||
# Only reads header for data frame.
|
||||
echo "receiving server "
|
||||
let recvData = await ws.recv()
|
||||
if recvData.len <= 0:
|
||||
debug "Empty messages"
|
||||
break
|
||||
|
||||
if ws.readyState == ReadyState.Closed:
|
||||
return
|
||||
debug "Response: ", data = string.fromBytes(recvData)
|
||||
await ws.send(recvData,
|
||||
if ws.binary: Opcode.Binary else: Opcode.Text)
|
||||
except WebSocketError:
|
||||
error "WebSocket error:", exception = getCurrentExceptionMsg()
|
||||
|
||||
when isMainModule:
|
||||
proc main() {.async.} =
|
||||
let address = initTAddress("127.0.0.1:8888")
|
||||
let socketFlags = {ServerFlags.TcpNoDelay, ServerFlags.ReuseAddr}
|
||||
let server = TlsHttpServer.create(
|
||||
address = address,
|
||||
handler = handle,
|
||||
tlsPrivateKey = TLSPrivateKey.init(SecureKey),
|
||||
tlsCertificate = TLSCertificate.init(SecureCert),
|
||||
flags = socketFlags)
|
||||
|
||||
server.start()
|
||||
info "Server listening at ", data = $server.localAddress()
|
||||
await server.join()
|
||||
|
||||
waitFor(main())
|
|
@ -0,0 +1,276 @@
|
|||
import std/strutils
|
||||
import pkg/[chronos, stew/byteutils]
|
||||
|
||||
import ../../ws/ws
|
||||
import ../asyncunit
|
||||
|
||||
type
|
||||
ExtHandler = proc(ext: Ext, frame: Frame): Future[Frame] {.raises: [Defect].}
|
||||
|
||||
HelperExtension = ref object of Ext
|
||||
handler*: ExtHandler
|
||||
|
||||
proc new*(
|
||||
T: typedesc[HelperExtension],
|
||||
handler: ExtHandler,
|
||||
session: WSSession = nil): HelperExtension =
|
||||
HelperExtension(
|
||||
handler: handler,
|
||||
name: "HelperExtension")
|
||||
|
||||
method decode*(
|
||||
self: HelperExtension,
|
||||
frame: Frame): Future[Frame] {.async.} =
|
||||
return await self.handler(self, frame)
|
||||
|
||||
method encode*(
|
||||
self: HelperExtension,
|
||||
frame: Frame): Future[Frame] {.async.} =
|
||||
return await self.handler(self, frame)
|
||||
|
||||
const TestString = "Hello"
|
||||
|
||||
suite "Encode frame extensions flow":
|
||||
test "should call extension on encode":
|
||||
var data = ""
|
||||
proc toUpper(ext: Ext, frame: Frame): Future[Frame] {.async.} =
|
||||
checkpoint "toUpper executed"
|
||||
data = string.fromBytes(frame.data).toUpper()
|
||||
check TestString.toUpper() == data
|
||||
frame.data = data.toBytes()
|
||||
return frame
|
||||
|
||||
var frame = Frame(
|
||||
fin: false,
|
||||
rsv1: false,
|
||||
rsv2: false,
|
||||
rsv3: false,
|
||||
opcode: Opcode.Text,
|
||||
mask: false,
|
||||
data: TestString.toBytes())
|
||||
|
||||
discard await frame.encode(@[HelperExtension.new(toUpper).Ext])
|
||||
check frame.data == TestString.toUpper().toBytes()
|
||||
|
||||
test "should call extensions in correct order on encode":
|
||||
var count = 0
|
||||
proc first(ext: Ext, frame: Frame): Future[Frame] {.async.} =
|
||||
checkpoint "first executed"
|
||||
check count == 0
|
||||
count.inc
|
||||
|
||||
return frame
|
||||
|
||||
proc second(ext: Ext, frame: Frame): Future[Frame] {.async.} =
|
||||
checkpoint "second executed"
|
||||
check count == 1
|
||||
count.inc
|
||||
|
||||
return frame
|
||||
|
||||
var frame = Frame(
|
||||
fin: false,
|
||||
rsv1: false,
|
||||
rsv2: false,
|
||||
rsv3: false,
|
||||
opcode: Opcode.Text,
|
||||
mask: false,
|
||||
data: TestString.toBytes())
|
||||
|
||||
discard await frame.encode(@[
|
||||
HelperExtension.new(first).Ext,
|
||||
HelperExtension.new(second).Ext])
|
||||
|
||||
check count == 2
|
||||
|
||||
test "should allow modifying frame headers":
|
||||
proc changeHeader(ext: Ext, frame: Frame): Future[Frame] {.async.} =
|
||||
checkpoint "changeHeader executed"
|
||||
frame.rsv1 = true
|
||||
frame.rsv2 = true
|
||||
frame.rsv3 = true
|
||||
frame.opcode = Opcode.Binary
|
||||
return frame
|
||||
|
||||
var frame = Frame(
|
||||
fin: false,
|
||||
rsv1: false,
|
||||
rsv2: false,
|
||||
rsv3: false,
|
||||
opcode: Opcode.Text, # fragments have to be `Continuation` frames
|
||||
mask: false,
|
||||
data: TestString.toBytes())
|
||||
|
||||
discard await frame.encode(@[HelperExtension.new(changeHeader).Ext])
|
||||
check:
|
||||
frame.rsv1 == true
|
||||
frame.rsv2 == true
|
||||
frame.rsv2 == true
|
||||
frame.opcode == Opcode.Binary
|
||||
|
||||
suite "Decode frame extensions flow":
|
||||
var
|
||||
address: TransportAddress
|
||||
server: StreamServer
|
||||
maskKey = genMaskKey(newRng())
|
||||
transport: StreamTransport
|
||||
reader: AsyncStreamReader
|
||||
frame: Frame
|
||||
|
||||
setup:
|
||||
server = createStreamServer(
|
||||
initTAddress("127.0.0.1:0"),
|
||||
flags = {ServerFlags.ReuseAddr})
|
||||
address = server.localAddress()
|
||||
|
||||
teardown:
|
||||
await transport.closeWait()
|
||||
await server.closeWait()
|
||||
server.stop()
|
||||
|
||||
test "should call extension on decode":
|
||||
var data = ""
|
||||
proc toUpper(ext: Ext, frame: Frame): Future[Frame] {.async.} =
|
||||
checkpoint "toUpper executed"
|
||||
try:
|
||||
var buf = newSeq[byte](frame.length)
|
||||
# read data
|
||||
await reader.readExactly(addr buf[0], buf.len)
|
||||
if frame.mask:
|
||||
mask(buf, maskKey)
|
||||
frame.mask = false # we can reset the mask key here
|
||||
|
||||
data = string.fromBytes(buf).toUpper()
|
||||
check:
|
||||
TestString.toUpper() == data
|
||||
|
||||
frame.data = data.toBytes()
|
||||
return frame
|
||||
except CatchableError as exc:
|
||||
checkpoint exc.msg
|
||||
check false
|
||||
|
||||
proc acceptHandler() {.async, gcsafe.} =
|
||||
let transport = await server.accept()
|
||||
reader = newAsyncStreamReader(transport)
|
||||
frame = await Frame.decode(
|
||||
reader,
|
||||
false,
|
||||
@[HelperExtension.new(toUpper).Ext])
|
||||
|
||||
await reader.closeWait()
|
||||
await transport.closeWait()
|
||||
|
||||
let handlerWait = acceptHandler()
|
||||
var encodedFrame = (await Frame(
|
||||
fin: false,
|
||||
rsv1: false,
|
||||
rsv2: false,
|
||||
rsv3: false,
|
||||
opcode: Opcode.Text,
|
||||
mask: true,
|
||||
maskKey: maskKey,
|
||||
data: TestString.toBytes())
|
||||
.encode())
|
||||
|
||||
transport = await connect(address)
|
||||
let wrote = await transport.write(encodedFrame)
|
||||
|
||||
await handlerWait
|
||||
check:
|
||||
wrote == encodedFrame.len
|
||||
frame.data == TestString.toUpper().toBytes()
|
||||
|
||||
test "should call extensions in reverse order on decode":
|
||||
var count = 0
|
||||
proc first(ext: Ext, frame: Frame): Future[Frame] {.async.} =
|
||||
checkpoint "first executed"
|
||||
check count == 1
|
||||
count.inc
|
||||
|
||||
return frame
|
||||
|
||||
proc second(ext: Ext, frame: Frame): Future[Frame] {.async.} =
|
||||
checkpoint "second executed"
|
||||
check count == 0
|
||||
count.inc
|
||||
|
||||
return frame
|
||||
|
||||
proc acceptHandler() {.async, gcsafe.} =
|
||||
let transport = await server.accept()
|
||||
reader = newAsyncStreamReader(transport)
|
||||
frame = await Frame.decode(
|
||||
reader,
|
||||
false,
|
||||
@[HelperExtension.new(first).Ext,
|
||||
HelperExtension.new(second).Ext])
|
||||
|
||||
await reader.closeWait()
|
||||
await transport.closeWait()
|
||||
|
||||
let handlerWait = acceptHandler()
|
||||
var encodedFrame = (await Frame(
|
||||
fin: false,
|
||||
rsv1: false,
|
||||
rsv2: false,
|
||||
rsv3: false,
|
||||
opcode: Opcode.Text,
|
||||
mask: true,
|
||||
maskKey: maskKey,
|
||||
data: TestString.toBytes())
|
||||
.encode())
|
||||
|
||||
let transport = await connect(address)
|
||||
let wrote = await transport.write(encodedFrame)
|
||||
|
||||
await handlerWait
|
||||
check:
|
||||
wrote == encodedFrame.len
|
||||
count == 2
|
||||
|
||||
test "should allow modifying frame headers":
|
||||
proc changeHeader(ext: Ext, frame: Frame): Future[Frame] {.async.} =
|
||||
checkpoint "changeHeader executed"
|
||||
frame.rsv1 = false
|
||||
frame.rsv2 = false
|
||||
frame.rsv3 = false
|
||||
frame.opcode = Opcode.Binary
|
||||
|
||||
return frame
|
||||
|
||||
proc acceptHandler() {.async, gcsafe.} =
|
||||
let transport = await server.accept()
|
||||
reader = newAsyncStreamReader(transport)
|
||||
frame = await Frame.decode(
|
||||
reader,
|
||||
false,
|
||||
@[HelperExtension.new(changeHeader).Ext])
|
||||
|
||||
check:
|
||||
frame.rsv1 == false
|
||||
frame.rsv2 == false
|
||||
frame.rsv2 == false
|
||||
frame.opcode == Opcode.Binary
|
||||
|
||||
await reader.closeWait()
|
||||
await transport.closeWait()
|
||||
|
||||
let handlerWait = acceptHandler()
|
||||
var encodedFrame = (await Frame(
|
||||
fin: false,
|
||||
rsv1: true,
|
||||
rsv2: true,
|
||||
rsv3: true,
|
||||
opcode: Opcode.Text,
|
||||
mask: true,
|
||||
maskKey: maskKey,
|
||||
data: TestString.toBytes())
|
||||
.encode())
|
||||
|
||||
let transport = await connect(address)
|
||||
let wrote = await transport.write(encodedFrame)
|
||||
|
||||
await handlerWait
|
||||
check:
|
||||
wrote == encodedFrame.len
|
|
@ -0,0 +1,81 @@
|
|||
import std/[strutils, random]
|
||||
import pkg/[
|
||||
chronos,
|
||||
chronos/streams/tlsstream,
|
||||
httputils,
|
||||
chronicles,
|
||||
stew/byteutils]
|
||||
|
||||
import ../ws/ws
|
||||
import ./keys
|
||||
|
||||
let
|
||||
WSSecureKey* = TLSPrivateKey.init(SecureKey)
|
||||
WSSecureCert* = TLSCertificate.init(SecureCert)
|
||||
|
||||
const WSPath* = when defined secure: "/wss" else: "/ws"
|
||||
|
||||
proc rndStr*(size: int): string =
|
||||
for _ in 0..<size:
|
||||
add(result, char(rand(int('A') .. int('z'))))
|
||||
|
||||
proc rndBin*(size: int): seq[byte] =
|
||||
for _ in 0..<size:
|
||||
add(result, byte(rand(0 .. 255)))
|
||||
|
||||
proc waitForClose*(ws: WSSession) {.async.} =
|
||||
try:
|
||||
while ws.readystate != ReadyState.Closed:
|
||||
discard await ws.recv()
|
||||
except CatchableError:
|
||||
trace "Closing websocket"
|
||||
|
||||
proc createServer*(
|
||||
address = initTAddress("127.0.0.1:8888"),
|
||||
tlsPrivateKey = WSSecureKey,
|
||||
tlsCertificate = WSSecureCert,
|
||||
handler: HttpAsyncCallback = nil,
|
||||
flags: set[ServerFlags] = {ServerFlags.TcpNoDelay, ServerFlags.ReuseAddr},
|
||||
tlsFlags: set[TLSFlags] = {},
|
||||
tlsMinVersion = TLSVersion.TLS12,
|
||||
tlsMaxVersion = TLSVersion.TLS12): HttpServer =
|
||||
when defined secure:
|
||||
TlsHttpServer.create(
|
||||
address = address,
|
||||
tlsPrivateKey = tlsPrivateKey,
|
||||
tlsCertificate = tlsCertificate,
|
||||
handler = handler,
|
||||
flags = flags,
|
||||
tlsFlags = tlsFlags,
|
||||
tlsMinVersion = tlsMinVersion,
|
||||
tlsMaxVersion = tlsMaxVersion)
|
||||
else:
|
||||
HttpServer.create(
|
||||
address = address,
|
||||
handler = handler,
|
||||
flags = flags)
|
||||
|
||||
proc connectClient*(
|
||||
address = initTAddress("127.0.0.1:8888"),
|
||||
path = WSPath,
|
||||
protocols: seq[string] = @["proto"],
|
||||
flags: set[TLSFlags] = {TLSFlags.NoVerifyHost, TLSFlags.NoVerifyServerName},
|
||||
version = WSDefaultVersion,
|
||||
frameSize = WSDefaultFrameSize,
|
||||
onPing: ControlCb = nil,
|
||||
onPong: ControlCb = nil,
|
||||
onClose: CloseCb = nil,
|
||||
rng: Rng = nil): Future[WSSession] {.async.} =
|
||||
let secure = when defined secure: true else: false
|
||||
return await WebSocket.connect(
|
||||
address = address,
|
||||
flags = flags,
|
||||
path = path,
|
||||
secure = secure,
|
||||
protocols = protocols,
|
||||
version = version,
|
||||
frameSize = frameSize,
|
||||
onPing = onPing,
|
||||
onPong = onPong,
|
||||
onClose = onClose,
|
||||
rng = rng)
|
|
@ -2,4 +2,4 @@
|
|||
|
||||
import ./testframes
|
||||
import ./testutf8
|
||||
import ./test_ext_utils
|
||||
import ./testextutils
|
||||
|
|
|
@ -9,7 +9,7 @@
|
|||
|
||||
import
|
||||
pkg/[asynctest, chronos],
|
||||
../ws/ext_utils
|
||||
../ws/extensions
|
||||
|
||||
suite "extension parser":
|
||||
test "single extension":
|
||||
|
@ -222,47 +222,47 @@ suite "extension parser":
|
|||
var app: seq[AppExt]
|
||||
let res = parseExt("filename=foo.html, filename=bar.html", app)
|
||||
check res == false
|
||||
|
||||
|
||||
test "emptydisposition":
|
||||
var app: seq[AppExt]
|
||||
let res = parseExt(" ; filename=foo.html", app)
|
||||
check res == false
|
||||
|
||||
|
||||
test "doublecolon":
|
||||
var app: seq[AppExt]
|
||||
let res = parseExt(": inline; attachment; filename=foo.html", app)
|
||||
check res == false
|
||||
|
||||
|
||||
test "attbrokenquotedfn":
|
||||
var app: seq[AppExt]
|
||||
let res = parseExt(" attachment; filename=\"foo.html\".txt", app)
|
||||
check res == false
|
||||
|
||||
|
||||
test "attbrokenquotedfn2":
|
||||
var app: seq[AppExt]
|
||||
let res = parseExt("attachment; filename=\"bar", app)
|
||||
check res == false
|
||||
|
||||
|
||||
test "attbrokenquotedfn3":
|
||||
var app: seq[AppExt]
|
||||
let res = parseExt("attachment; filename=foo\"bar;baz\"qux", app)
|
||||
check res == false
|
||||
|
||||
|
||||
test "attmissingdelim":
|
||||
var app: seq[AppExt]
|
||||
let res = parseExt("attachment; foo=foo filename=bar", app)
|
||||
check res == false
|
||||
|
||||
|
||||
test "attmissingdelim2":
|
||||
var app: seq[AppExt]
|
||||
let res = parseExt("attachment; filename=bar foo=foo", app)
|
||||
check res == false
|
||||
|
||||
|
||||
test "attmissingdelim3":
|
||||
var app: seq[AppExt]
|
||||
let res = parseExt("attachment filename=bar", app)
|
||||
check res == false
|
||||
|
||||
|
||||
test "attreversed":
|
||||
var app: seq[AppExt]
|
||||
let res = parseExt("filename=foo.html; attachment", app)
|
|
@ -15,7 +15,7 @@ import
|
|||
chronos,
|
||||
chronicles
|
||||
],
|
||||
../ws/[ws, utf8_dfa]
|
||||
../ws/[ws, utf8dfa]
|
||||
|
||||
suite "UTF-8 DFA validator":
|
||||
test "single octet":
|
||||
|
@ -76,7 +76,7 @@ proc waitForClose(ws: WSSession) {.async.} =
|
|||
while ws.readystate != ReadyState.Closed:
|
||||
discard await ws.recv()
|
||||
except CatchableError:
|
||||
debug "Closing websocket"
|
||||
trace "Closing websocket"
|
||||
|
||||
# TODO: use new test framework from dryajov
|
||||
# if it is ready.
|
||||
|
@ -125,10 +125,10 @@ suite "UTF-8 validator in action":
|
|||
proc handle(request: HttpRequest) {.async.} =
|
||||
check request.uri.path == "/ws"
|
||||
|
||||
proc onClose(status: Status, reason: string):
|
||||
proc onClose(status: StatusCodes, reason: string):
|
||||
CloseResult {.gcsafe, raises: [Defect].} =
|
||||
try:
|
||||
check status == Status.Fulfilled
|
||||
check status == StatusFulfilled
|
||||
check reason == closeReason
|
||||
return (status, reason)
|
||||
except Exception as exc:
|
||||
|
@ -137,6 +137,8 @@ suite "UTF-8 validator in action":
|
|||
let server = WSServer.new(protos = ["proto"], onClose = onClose)
|
||||
let ws = await server.handleRequest(request)
|
||||
let res = await ws.recv()
|
||||
await waitForClose(ws)
|
||||
|
||||
check:
|
||||
string.fromBytes(res) == testData
|
||||
ws.binary == false
|
||||
|
@ -168,6 +170,7 @@ suite "UTF-8 validator in action":
|
|||
let server = WSServer.new(protos = ["proto"])
|
||||
let ws = await server.handleRequest(request)
|
||||
discard await ws.recv()
|
||||
await waitForClose(ws)
|
||||
|
||||
server = HttpServer.create(
|
||||
address,
|
||||
|
@ -183,7 +186,7 @@ suite "UTF-8 validator in action":
|
|||
)
|
||||
|
||||
await session.send(testData)
|
||||
await waitForClose( session)
|
||||
await session.close()
|
||||
check session.readyState == ReadyState.Closed
|
||||
|
||||
test "invalid UTF-8 sequence close code":
|
||||
|
@ -201,6 +204,8 @@ suite "UTF-8 validator in action":
|
|||
string.fromBytes(res) == testData
|
||||
ws.binary == false
|
||||
|
||||
await waitForClose(ws)
|
||||
|
||||
server = HttpServer.create(
|
||||
address,
|
||||
handle,
|
||||
|
@ -216,5 +221,4 @@ suite "UTF-8 validator in action":
|
|||
|
||||
await session.send(testData)
|
||||
await session.close(reason = closeReason)
|
||||
await waitForClose( session)
|
||||
check session.readyState == ReadyState.Closed
|
||||
|
|
|
@ -5,85 +5,16 @@ import pkg/[
|
|||
chronicles,
|
||||
stew/byteutils]
|
||||
|
||||
import ./asynctest
|
||||
import ../ws/ws
|
||||
import ./keys
|
||||
|
||||
var server: HttpServer
|
||||
import ./asynctest
|
||||
import ./helpers
|
||||
|
||||
let
|
||||
address = initTAddress("127.0.0.1:8888")
|
||||
socketFlags = {ServerFlags.TcpNoDelay, ServerFlags.ReuseAddr}
|
||||
clientFlags = {NoVerifyHost, NoVerifyServerName}
|
||||
secureKey = TLSPrivateKey.init(SecureKey)
|
||||
secureCert = TLSCertificate.init(SecureCert)
|
||||
address* = initTAddress("127.0.0.1:8888")
|
||||
|
||||
const WSPath = when defined secure: "/wss" else: "/ws"
|
||||
|
||||
proc rndStr*(size: int): string =
|
||||
for _ in .. size:
|
||||
add(result, char(rand(int('A') .. int('z'))))
|
||||
|
||||
proc rndBin*(size: int): seq[byte] =
|
||||
for _ in .. size:
|
||||
add(result, byte(rand(0 .. 255)))
|
||||
|
||||
proc waitForClose(ws: WSSession) {.async.} =
|
||||
try:
|
||||
while ws.readystate != ReadyState.Closed:
|
||||
discard await ws.recv()
|
||||
except CatchableError:
|
||||
debug "Closing websocket"
|
||||
|
||||
proc createServer(
|
||||
address = initTAddress("127.0.0.1:8888"),
|
||||
tlsPrivateKey = secureKey,
|
||||
tlsCertificate = secureCert,
|
||||
handler: HttpAsyncCallback = nil,
|
||||
flags: set[ServerFlags] = socketFlags,
|
||||
tlsFlags: set[TLSFlags] = {},
|
||||
tlsMinVersion = TLSVersion.TLS12,
|
||||
tlsMaxVersion = TLSVersion.TLS12): HttpServer =
|
||||
when defined secure:
|
||||
TlsHttpServer.create(
|
||||
address = address,
|
||||
tlsPrivateKey = tlsPrivateKey,
|
||||
tlsCertificate = tlsCertificate,
|
||||
handler = handler,
|
||||
flags = flags,
|
||||
tlsFlags = tlsFlags,
|
||||
tlsMinVersion = tlsMinVersion,
|
||||
tlsMaxVersion = tlsMaxVersion)
|
||||
else:
|
||||
HttpServer.create(
|
||||
address = address,
|
||||
handler = handler,
|
||||
flags = flags)
|
||||
|
||||
proc connectClient*(
|
||||
address = initTAddress("127.0.0.1:8888"),
|
||||
path = WSPath,
|
||||
protocols: seq[string] = @["proto"],
|
||||
flags: set[TLSFlags] = clientFlags,
|
||||
version = WSDefaultVersion,
|
||||
frameSize = WSDefaultFrameSize,
|
||||
onPing: ControlCb = nil,
|
||||
onPong: ControlCb = nil,
|
||||
onClose: CloseCb = nil,
|
||||
rng: Rng = nil): Future[WSSession] {.async.} =
|
||||
let secure = when defined secure: true else: false
|
||||
return await WebSocket.connect(
|
||||
address = address,
|
||||
flags = flags,
|
||||
path = path,
|
||||
secure = secure,
|
||||
protocols = protocols,
|
||||
version = version,
|
||||
frameSize = frameSize,
|
||||
onPing = onPing,
|
||||
onPong = onPong,
|
||||
onClose = onClose,
|
||||
rng = rng)
|
||||
var
|
||||
server: HttpServer
|
||||
|
||||
suite "Test handshake":
|
||||
teardown:
|
||||
|
@ -175,25 +106,20 @@ suite "Test handshake":
|
|||
parseUri(uri),
|
||||
protocols = @["proto"])
|
||||
|
||||
# test "AsyncStream leaks test":
|
||||
# check:
|
||||
# getTracker("async.stream.reader").isLeaked() == false
|
||||
# getTracker("async.stream.writer").isLeaked() == false
|
||||
# getTracker("stream.server").isLeaked() == false
|
||||
# getTracker("stream.transport").isLeaked() == false
|
||||
|
||||
suite "Test transmission":
|
||||
teardown:
|
||||
server.stop()
|
||||
await server.closeWait()
|
||||
|
||||
test "Send text message message with payload of length 65535":
|
||||
let testString = rndStr(65535)
|
||||
test "Server - test reading simple frame":
|
||||
let testString = "Hello!"
|
||||
proc handle(request: HttpRequest) {.async.} =
|
||||
check request.uri.path == WSPath
|
||||
|
||||
let server = WSServer.new(protos = ["proto"])
|
||||
let ws = await server.handleRequest(request)
|
||||
let servRes = await ws.recv()
|
||||
|
||||
check string.fromBytes(servRes) == testString
|
||||
await ws.waitForClose()
|
||||
|
||||
|
@ -207,15 +133,13 @@ suite "Test transmission":
|
|||
await session.send(testString)
|
||||
await session.close()
|
||||
|
||||
test "Server - test reading simple frame":
|
||||
let testString = "Hello!"
|
||||
test "Send text message message with payload of length 65535":
|
||||
let testString = rndStr(65535)
|
||||
proc handle(request: HttpRequest) {.async.} =
|
||||
check request.uri.path == WSPath
|
||||
|
||||
let server = WSServer.new(protos = ["proto"])
|
||||
let ws = await server.handleRequest(request)
|
||||
let servRes = await ws.recv()
|
||||
|
||||
check string.fromBytes(servRes) == testString
|
||||
await ws.waitForClose()
|
||||
|
||||
|
@ -250,83 +174,11 @@ suite "Test transmission":
|
|||
check string.fromBytes(clientRes) == testString
|
||||
await waitForClose(session)
|
||||
|
||||
# test "AsyncStream leaks test":
|
||||
# check:
|
||||
# getTracker("async.stream.reader").isLeaked() == false
|
||||
# getTracker("async.stream.writer").isLeaked() == false
|
||||
# getTracker("stream.server").isLeaked() == false
|
||||
# getTracker("stream.transport").isLeaked() == false
|
||||
|
||||
suite "Test ping-pong":
|
||||
teardown:
|
||||
server.stop()
|
||||
await server.closeWait()
|
||||
|
||||
test "Send text Message fragmented into 2 fragments, one ping with payload in-between":
|
||||
var ping, pong = false
|
||||
let testString = "1234567890"
|
||||
let msg = toBytes(testString)
|
||||
let maxFrameSize = 5
|
||||
|
||||
proc handle(request: HttpRequest) {.async.} =
|
||||
check request.uri.path == WSPath
|
||||
let server = WSServer.new(
|
||||
protos = ["proto"],
|
||||
onPing = proc(data: openArray[byte]) =
|
||||
ping = true
|
||||
)
|
||||
|
||||
let ws = await server.handleRequest(request)
|
||||
|
||||
let respData = await ws.recv()
|
||||
check string.fromBytes(respData) == testString
|
||||
await waitForClose(ws)
|
||||
|
||||
server = createServer(
|
||||
address = address,
|
||||
handler = handle,
|
||||
flags = {ReuseAddr})
|
||||
server.start()
|
||||
|
||||
let session = await connectClient(
|
||||
address = initTAddress("127.0.0.1:8888"),
|
||||
frameSize = maxFrameSize,
|
||||
onPong = proc(data: openArray[byte]) =
|
||||
pong = true
|
||||
)
|
||||
|
||||
let maskKey = genMaskKey(newRng())
|
||||
await session.stream.writer.write(
|
||||
(await Frame(
|
||||
fin: false,
|
||||
rsv1: false,
|
||||
rsv2: false,
|
||||
rsv3: false,
|
||||
opcode: Opcode.Text,
|
||||
mask: true,
|
||||
data: msg[0..4],
|
||||
maskKey: maskKey)
|
||||
.encode()))
|
||||
|
||||
await session.ping()
|
||||
|
||||
await session.stream.writer.write(
|
||||
(await Frame(
|
||||
fin: true,
|
||||
rsv1: false,
|
||||
rsv2: false,
|
||||
rsv3: false,
|
||||
opcode: Opcode.Cont,
|
||||
mask: true,
|
||||
data: msg[5..9],
|
||||
maskKey: maskKey)
|
||||
.encode()))
|
||||
|
||||
await session.close()
|
||||
check:
|
||||
ping
|
||||
pong
|
||||
|
||||
test "Server - test ping-pong control messages":
|
||||
var ping, pong = false
|
||||
proc handle(request: HttpRequest) {.async.} =
|
||||
|
@ -389,12 +241,59 @@ suite "Test ping-pong":
|
|||
await session.ping()
|
||||
await session.close()
|
||||
|
||||
# test "AsyncStream leaks test":
|
||||
# check:
|
||||
# getTracker("async.stream.reader").isLeaked() == false
|
||||
# getTracker("async.stream.writer").isLeaked() == false
|
||||
# getTracker("stream.server").isLeaked() == false
|
||||
# getTracker("stream.transport").isLeaked() == false
|
||||
test "Send ping with small text payload":
|
||||
let testData = toBytes("Hello, world!")
|
||||
var ping, pong = false
|
||||
proc handle(request: HttpRequest) {.async.} =
|
||||
check request.uri.path == WSPath
|
||||
|
||||
let server = WSServer.new(
|
||||
protos = ["proto"],
|
||||
onPing = proc(data: openArray[byte]) =
|
||||
ping = data == testData)
|
||||
|
||||
let ws = await server.handleRequest(request)
|
||||
await waitForClose(ws)
|
||||
|
||||
server = createServer(
|
||||
address = address,
|
||||
handler = handle,
|
||||
flags = {ReuseAddr})
|
||||
server.start()
|
||||
|
||||
let session = await connectClient(
|
||||
address = initTAddress("127.0.0.1:8888"),
|
||||
onPong = proc(data: openArray[byte]) =
|
||||
pong = true
|
||||
)
|
||||
|
||||
await session.ping(testData)
|
||||
await session.close()
|
||||
check:
|
||||
ping
|
||||
pong
|
||||
|
||||
test "Test ping payload message length":
|
||||
proc handle(request: HttpRequest) {.async.} =
|
||||
check request.uri.path == WSPath
|
||||
let server = WSServer.new(protos = ["proto"])
|
||||
let ws = await server.handleRequest(request)
|
||||
|
||||
expect WSPayloadTooLarge:
|
||||
discard await ws.recv()
|
||||
|
||||
await waitForClose(ws)
|
||||
|
||||
server = createServer(
|
||||
address = address,
|
||||
handler = handle,
|
||||
flags = {ReuseAddr})
|
||||
server.start()
|
||||
|
||||
let str = rndStr(126)
|
||||
let session = await connectClient()
|
||||
await session.ping(str.toBytes())
|
||||
await session.close()
|
||||
|
||||
suite "Test framing":
|
||||
teardown:
|
||||
|
@ -408,13 +307,13 @@ suite "Test framing":
|
|||
|
||||
let server = WSServer.new(protos = ["proto"])
|
||||
let ws = await server.handleRequest(request)
|
||||
let frame1 = await ws.readFrame()
|
||||
let frame1 = await ws.readFrame(@[])
|
||||
check not isNil(frame1)
|
||||
var data1 = newSeq[byte](frame1.remainder().int)
|
||||
let read1 = await ws.stream.reader.readOnce(addr data1[0], data1.len)
|
||||
check read1 == 5
|
||||
|
||||
let frame2 = await ws.readFrame()
|
||||
let frame2 = await ws.readFrame(@[])
|
||||
check not isNil(frame2)
|
||||
var data2 = newSeq[byte](frame2.remainder().int)
|
||||
let read2 = await ws.stream.reader.readOnce(addr data2[0], data2.len)
|
||||
|
@ -454,16 +353,8 @@ suite "Test framing":
|
|||
|
||||
expect WSMaxMessageSizeError:
|
||||
discard await session.recv(5)
|
||||
|
||||
await waitForClose(session)
|
||||
|
||||
# test "AsyncStream leaks test":
|
||||
# check:
|
||||
# getTracker("async.stream.reader").isLeaked() == false
|
||||
# getTracker("async.stream.writer").isLeaked() == false
|
||||
# getTracker("stream.server").isLeaked() == false
|
||||
# getTracker("stream.transport").isLeaked() == false
|
||||
|
||||
suite "Test Closing":
|
||||
teardown:
|
||||
server.stop()
|
||||
|
@ -490,15 +381,15 @@ suite "Test Closing":
|
|||
proc handle(request: HttpRequest) {.async.} =
|
||||
check request.uri.path == WSPath
|
||||
|
||||
proc closeServer(status: Status, reason: string): CloseResult{.gcsafe,
|
||||
proc closeServer(status: StatusCodes, reason: string): CloseResult{.gcsafe,
|
||||
raises: [Defect].} =
|
||||
try:
|
||||
check status == Status.TooLarge
|
||||
check status == StatusTooLarge
|
||||
check reason == "Message too big!"
|
||||
except Exception as exc:
|
||||
raise newException(Defect, exc.msg)
|
||||
|
||||
return (Status.Fulfilled, "")
|
||||
return (StatusFulfilled, "")
|
||||
|
||||
let server = WSServer.new(
|
||||
protos = ["proto"],
|
||||
|
@ -514,11 +405,11 @@ suite "Test Closing":
|
|||
flags = {ReuseAddr})
|
||||
server.start()
|
||||
|
||||
proc clientClose(status: Status, reason: string): CloseResult {.gcsafe,
|
||||
proc clientClose(status: StatusCodes, reason: string): CloseResult {.gcsafe,
|
||||
raises: [Defect].} =
|
||||
try:
|
||||
check status == Status.Fulfilled
|
||||
return (Status.TooLarge, "Message too big!")
|
||||
check status == StatusFulfilled
|
||||
return (StatusTooLarge, "Message too big!")
|
||||
except Exception as exc:
|
||||
raise newException(Defect, exc.msg)
|
||||
|
||||
|
@ -548,11 +439,11 @@ suite "Test Closing":
|
|||
test "Client closing with status":
|
||||
proc handle(request: HttpRequest) {.async.} =
|
||||
check request.uri.path == WSPath
|
||||
proc closeServer(status: Status, reason: string): CloseResult{.gcsafe,
|
||||
proc closeServer(status: StatusCodes, reason: string): CloseResult{.gcsafe,
|
||||
raises: [Defect].} =
|
||||
try:
|
||||
check status == Status.Fulfilled
|
||||
return (Status.TooLarge, "Message too big!")
|
||||
check status == StatusFulfilled
|
||||
return (StatusTooLarge, "Message too big!")
|
||||
except Exception as exc:
|
||||
raise newException(Defect, exc.msg)
|
||||
|
||||
|
@ -570,12 +461,12 @@ suite "Test Closing":
|
|||
flags = {ReuseAddr})
|
||||
server.start()
|
||||
|
||||
proc clientClose(status: Status, reason: string): CloseResult {.gcsafe,
|
||||
proc clientClose(status: StatusCodes, reason: string): CloseResult {.gcsafe,
|
||||
raises: [Defect].} =
|
||||
try:
|
||||
check status == Status.TooLarge
|
||||
check status == StatusTooLarge
|
||||
check reason == "Message too big!"
|
||||
return (Status.Fulfilled, "")
|
||||
return (StatusFulfilled, "")
|
||||
except Exception as exc:
|
||||
raise newException(Defect, exc.msg)
|
||||
|
||||
|
@ -592,7 +483,7 @@ suite "Test Closing":
|
|||
let server = WSServer.new(protos = ["proto"])
|
||||
let ws = await server.handleRequest(request)
|
||||
|
||||
await ws.close(code = Status.ReservedCode)
|
||||
await ws.close(code = StatusCodes(StatusLibsCodes.high))
|
||||
|
||||
server = createServer(
|
||||
address = address,
|
||||
|
@ -600,11 +491,11 @@ suite "Test Closing":
|
|||
flags = {ReuseAddr})
|
||||
server.start()
|
||||
|
||||
proc closeClient(status: Status, reason: string): CloseResult{.gcsafe,
|
||||
raises: [Defect].} =
|
||||
proc closeClient(status: StatusCodes, reason: string): CloseResult
|
||||
{.gcsafe, raises: [Defect].} =
|
||||
try:
|
||||
check status == Status.ReservedCode
|
||||
return (Status.ReservedCode, "Reserved Status")
|
||||
check status == StatusCodes(StatusLibsCodes.high)
|
||||
return (StatusCodes(StatusLibsCodes.high), "Reserved StatusCodes")
|
||||
except Exception as exc:
|
||||
raise newException(Defect, exc.msg)
|
||||
|
||||
|
@ -617,11 +508,11 @@ suite "Test Closing":
|
|||
test "Client closing with valid close code 3999":
|
||||
proc handle(request: HttpRequest) {.async.} =
|
||||
check request.uri.path == WSPath
|
||||
proc closeServer(status: Status, reason: string): CloseResult{.gcsafe,
|
||||
proc closeServer(status: StatusCodes, reason: string): CloseResult{.gcsafe,
|
||||
raises: [Defect].} =
|
||||
try:
|
||||
check status == Status.ReservedCode
|
||||
return (Status.ReservedCode, "Reserved Status")
|
||||
check status == StatusCodes(3999)
|
||||
return (StatusCodes(3999), "Reserved StatusCodes")
|
||||
except Exception as exc:
|
||||
raise newException(Defect, exc.msg)
|
||||
|
||||
|
@ -640,7 +531,7 @@ suite "Test Closing":
|
|||
server.start()
|
||||
|
||||
let session = await connectClient()
|
||||
await session.close(code = Status.ReservedCode)
|
||||
await session.close(code = StatusCodes(3999))
|
||||
|
||||
test "Server closing with Payload of length 2":
|
||||
proc handle(request: HttpRequest) {.async.} =
|
||||
|
@ -681,41 +572,12 @@ suite "Test Closing":
|
|||
# Close with payload of length 2
|
||||
await session.close(reason = "HH")
|
||||
|
||||
# test "AsyncStream leaks test":
|
||||
# check:
|
||||
# getTracker("async.stream.reader").isLeaked() == false
|
||||
# getTracker("async.stream.writer").isLeaked() == false
|
||||
# getTracker("stream.server").isLeaked() == false
|
||||
# getTracker("stream.transport").isLeaked() == false
|
||||
|
||||
suite "Test Payload":
|
||||
teardown:
|
||||
server.stop()
|
||||
await server.closeWait()
|
||||
|
||||
test "Test payload message length":
|
||||
proc handle(request: HttpRequest) {.async.} =
|
||||
check request.uri.path == WSPath
|
||||
let server = WSServer.new(protos = ["proto"])
|
||||
let ws = await server.handleRequest(request)
|
||||
|
||||
expect WSPayloadTooLarge:
|
||||
discard await ws.recv()
|
||||
|
||||
await waitForClose(ws)
|
||||
|
||||
server = createServer(
|
||||
address = address,
|
||||
handler = handle,
|
||||
flags = {ReuseAddr})
|
||||
server.start()
|
||||
|
||||
let str = rndStr(126)
|
||||
let session = await connectClient()
|
||||
await session.ping(str.toBytes())
|
||||
await session.close()
|
||||
|
||||
test "Test single empty payload":
|
||||
test "Test payload of length 0":
|
||||
let emptyStr = ""
|
||||
proc handle(request: HttpRequest) {.async.} =
|
||||
check request.uri.path == WSPath
|
||||
|
@ -724,8 +586,12 @@ suite "Test Payload":
|
|||
let ws = await server.handleRequest(request)
|
||||
let servRes = await ws.recv()
|
||||
|
||||
check string.fromBytes(servRes) == emptyStr
|
||||
await waitForClose(ws)
|
||||
check:
|
||||
servRes.len == 0
|
||||
string.fromBytes(servRes) == emptyStr
|
||||
|
||||
await ws.send(emptyStr)
|
||||
await ws.waitForClose()
|
||||
|
||||
server = createServer(
|
||||
address = address,
|
||||
|
@ -734,20 +600,32 @@ suite "Test Payload":
|
|||
server.start()
|
||||
|
||||
let session = await connectClient()
|
||||
|
||||
await session.send(emptyStr)
|
||||
let clientRes = await session.recv()
|
||||
|
||||
check:
|
||||
clientRes.len == 0
|
||||
string.fromBytes(clientRes) == emptyStr
|
||||
|
||||
await session.close()
|
||||
|
||||
test "Test multiple empty payload":
|
||||
test "Test multiple payloads of length 0":
|
||||
let emptyStr = ""
|
||||
proc handle(request: HttpRequest) {.async.} =
|
||||
check request.uri.path == WSPath
|
||||
|
||||
let server = WSServer.new(protos = ["proto"])
|
||||
let ws = await server.handleRequest(request)
|
||||
let servRes = await ws.recv()
|
||||
for _ in 0..<3:
|
||||
let servRes = await ws.recv()
|
||||
|
||||
check:
|
||||
servRes.len == 0
|
||||
string.fromBytes(servRes) == emptyStr
|
||||
|
||||
for i in 0..3:
|
||||
await ws.send(emptyStr)
|
||||
|
||||
check string.fromBytes(servRes) == emptyStr
|
||||
await waitForClose(ws)
|
||||
|
||||
server = createServer(
|
||||
|
@ -757,22 +635,35 @@ suite "Test Payload":
|
|||
server.start()
|
||||
|
||||
let session = await connectClient()
|
||||
|
||||
for i in 0..3:
|
||||
await session.send(emptyStr)
|
||||
|
||||
for _ in 0..<3:
|
||||
let clientRes = await session.recv()
|
||||
|
||||
check:
|
||||
clientRes.len == 0
|
||||
string.fromBytes(clientRes) == emptyStr
|
||||
|
||||
await session.close()
|
||||
|
||||
test "Send ping with small text payload":
|
||||
let testData = toBytes("Hello, world!")
|
||||
test "Send two fragments":
|
||||
var ping, pong = false
|
||||
let testString = "1234567890"
|
||||
let msg = toBytes(testString)
|
||||
let maxFrameSize = 5
|
||||
|
||||
proc handle(request: HttpRequest) {.async.} =
|
||||
check request.uri.path == WSPath
|
||||
|
||||
let server = WSServer.new(
|
||||
protos = ["proto"],
|
||||
onPing = proc(data: openArray[byte]) =
|
||||
ping = data == testData)
|
||||
let server = WSServer.new(protos = ["proto"])
|
||||
|
||||
let ws = await server.handleRequest(request)
|
||||
let respData = await ws.recv()
|
||||
|
||||
check:
|
||||
string.fromBytes(respData) == testString
|
||||
|
||||
await waitForClose(ws)
|
||||
|
||||
server = createServer(
|
||||
|
@ -783,22 +674,133 @@ suite "Test Payload":
|
|||
|
||||
let session = await connectClient(
|
||||
address = initTAddress("127.0.0.1:8888"),
|
||||
frameSize = maxFrameSize)
|
||||
|
||||
let maskKey = genMaskKey(newRng())
|
||||
await session.stream.writer.write(
|
||||
(await Frame(
|
||||
fin: false,
|
||||
rsv1: false,
|
||||
rsv2: false,
|
||||
rsv3: false,
|
||||
opcode: Opcode.Text,
|
||||
mask: true,
|
||||
data: msg[0..4],
|
||||
maskKey: maskKey)
|
||||
.encode()))
|
||||
|
||||
await session.stream.writer.write(
|
||||
(await Frame(
|
||||
fin: true,
|
||||
rsv1: false,
|
||||
rsv2: false,
|
||||
rsv3: false,
|
||||
opcode: Opcode.Cont,
|
||||
mask: true,
|
||||
data: msg[5..9],
|
||||
maskKey: maskKey)
|
||||
.encode()))
|
||||
|
||||
await session.close()
|
||||
|
||||
test "Send two fragments with a ping with payload in-between":
|
||||
var ping, pong = false
|
||||
let testString = "1234567890"
|
||||
let msg = toBytes(testString)
|
||||
let maxFrameSize = 5
|
||||
|
||||
proc handle(request: HttpRequest) {.async.} =
|
||||
check request.uri.path == WSPath
|
||||
let server = WSServer.new(
|
||||
protos = ["proto"],
|
||||
onPing = proc(data: openArray[byte]) =
|
||||
ping = true
|
||||
)
|
||||
|
||||
let ws = await server.handleRequest(request)
|
||||
let respData = await ws.recv()
|
||||
check:
|
||||
string.fromBytes(respData) == testString
|
||||
|
||||
await waitForClose(ws)
|
||||
|
||||
server = createServer(
|
||||
address = address,
|
||||
handler = handle,
|
||||
flags = {ReuseAddr})
|
||||
server.start()
|
||||
|
||||
let session = await connectClient(
|
||||
address = initTAddress("127.0.0.1:8888"),
|
||||
frameSize = maxFrameSize,
|
||||
onPong = proc(data: openArray[byte]) =
|
||||
pong = true
|
||||
)
|
||||
|
||||
await session.ping(testData)
|
||||
let maskKey = genMaskKey(newRng())
|
||||
await session.stream.writer.write(
|
||||
(await Frame(
|
||||
fin: false,
|
||||
rsv1: false,
|
||||
rsv2: false,
|
||||
rsv3: false,
|
||||
opcode: Opcode.Text,
|
||||
mask: true,
|
||||
data: msg[0..4],
|
||||
maskKey: maskKey)
|
||||
.encode()))
|
||||
|
||||
await session.ping()
|
||||
|
||||
await session.stream.writer.write(
|
||||
(await Frame(
|
||||
fin: true,
|
||||
rsv1: false,
|
||||
rsv2: false,
|
||||
rsv3: false,
|
||||
opcode: Opcode.Cont,
|
||||
mask: true,
|
||||
data: msg[5..9],
|
||||
maskKey: maskKey)
|
||||
.encode()))
|
||||
|
||||
await session.close()
|
||||
check:
|
||||
ping
|
||||
pong
|
||||
|
||||
# test "AsyncStream leaks test":
|
||||
# check:
|
||||
# getTracker("async.stream.reader").isLeaked() == false
|
||||
# getTracker("async.stream.writer").isLeaked() == false
|
||||
# getTracker("stream.server").isLeaked() == false
|
||||
# getTracker("stream.transport").isLeaked() == false
|
||||
test "Send text message with multiple frames":
|
||||
const FrameSize = 3000
|
||||
let testData = rndStr(FrameSize * 3)
|
||||
proc handle(request: HttpRequest) {.async.} =
|
||||
check request.uri.path == WSPath
|
||||
|
||||
let server = WSServer.new(protos = ["proto"])
|
||||
let ws = await server.handleRequest(request)
|
||||
let res = await ws.recv()
|
||||
|
||||
check ws.binary == false
|
||||
await ws.send(res, Opcode.Text)
|
||||
await waitForClose(ws)
|
||||
|
||||
server = createServer(
|
||||
address = address,
|
||||
handler = handle,
|
||||
flags = {ReuseAddr})
|
||||
server.start()
|
||||
|
||||
let ws = await connectClient(
|
||||
address = address,
|
||||
frameSize = FrameSize
|
||||
)
|
||||
|
||||
await ws.send(testData)
|
||||
let echoed = await ws.recv()
|
||||
await ws.close()
|
||||
|
||||
check:
|
||||
string.fromBytes(echoed) == testData
|
||||
ws.binary == false
|
||||
|
||||
suite "Test Binary message with Payload":
|
||||
teardown:
|
||||
|
@ -860,7 +862,7 @@ suite "Test Binary message with Payload":
|
|||
|
||||
test "Send binary data with small text payload":
|
||||
let testData = rndBin(10)
|
||||
debug "testData", testData = testData
|
||||
trace "testData", testData = testData
|
||||
var ping, pong = false
|
||||
proc handle(request: HttpRequest) {.async.} =
|
||||
check request.uri.path == WSPath
|
||||
|
@ -929,9 +931,42 @@ suite "Test Binary message with Payload":
|
|||
await session.send(testData, Opcode.Binary)
|
||||
await session.close()
|
||||
|
||||
# test "AsyncStream leaks test":
|
||||
# check:
|
||||
# getTracker("async.stream.reader").isLeaked() == false
|
||||
# getTracker("async.stream.writer").isLeaked() == false
|
||||
# getTracker("stream.server").isLeaked() == false
|
||||
# getTracker("stream.transport").isLeaked() == false
|
||||
test "Send binary message with multiple frames":
|
||||
const FrameSize = 3000
|
||||
let testData = rndBin(FrameSize * 3)
|
||||
proc handle(request: HttpRequest) {.async.} =
|
||||
check request.uri.path == WSPath
|
||||
|
||||
let server = WSServer.new(protos = ["proto"])
|
||||
let ws = await server.handleRequest(request)
|
||||
let res = await ws.recv()
|
||||
|
||||
check:
|
||||
ws.binary == true
|
||||
res == testData
|
||||
|
||||
await ws.send(res, Opcode.Binary)
|
||||
await waitForClose(ws)
|
||||
|
||||
server = createServer(
|
||||
address = address,
|
||||
handler = handle,
|
||||
flags = {ReuseAddr})
|
||||
server.start()
|
||||
|
||||
let ws = await connectClient(
|
||||
address = address,
|
||||
frameSize = FrameSize
|
||||
)
|
||||
|
||||
await ws.send(testData, Opcode.Binary)
|
||||
let echoed = await ws.recv()
|
||||
|
||||
check:
|
||||
echoed == testData
|
||||
|
||||
await ws.close()
|
||||
|
||||
check:
|
||||
echoed == testData
|
||||
ws.binary == true
|
||||
|
|
|
@ -0,0 +1,6 @@
|
|||
|
||||
import std/tables
|
||||
import ./extensions/extutils
|
||||
# import ./extensions/compression/compression
|
||||
|
||||
export extutils
|
|
@ -0,0 +1,212 @@
|
|||
## nim-ws
|
||||
## Copyright (c) 2021 Status Research & Development GmbH
|
||||
## Licensed under either of
|
||||
## * Apache License, version 2.0, ([LICENSE-APACHE](LICENSE-APACHE))
|
||||
## * MIT license ([LICENSE-MIT](LICENSE-MIT))
|
||||
## at your option.
|
||||
## This file may not be copied, modified, or distributed except according to
|
||||
## those terms.
|
||||
|
||||
|
||||
import
|
||||
std/[strutils],
|
||||
pkg/[stew/results, chronos],
|
||||
../../types, ../../frame, ./miniz/miniz_api
|
||||
|
||||
type
|
||||
DeflateOpts = object
|
||||
isServer: bool
|
||||
serverNoContextTakeOver: bool
|
||||
clientNoContextTakeOver: bool
|
||||
serverMaxWindowBits: int
|
||||
clientMaxWindowBits: int
|
||||
|
||||
DeflateExt = ref object of Ext
|
||||
paramStr: string
|
||||
opts: DeflateOpts
|
||||
# 'messageCompressed' is a two way flag:
|
||||
# 1. the original message is compressible or not
|
||||
# 2. the frame we received need decompression or not
|
||||
messageCompressed: bool
|
||||
|
||||
const
|
||||
extID = "permessage-deflate"
|
||||
|
||||
proc concatParam(resp: var string, param: string) =
|
||||
resp.add ';'
|
||||
resp.add param
|
||||
|
||||
proc validateWindowBits(arg: ExtParam, res: var int): Result[string, string] =
|
||||
if arg.value.len == 0:
|
||||
return ok("")
|
||||
|
||||
if arg.value.len > 2:
|
||||
return err("window bits expect 2 bytes, got " & $arg.value.len)
|
||||
|
||||
for n in arg.value:
|
||||
if n notin Digits:
|
||||
return err("window bits value contains illegal char: " & $n)
|
||||
|
||||
var winbit = 0
|
||||
for i in 0..<arg.value.len:
|
||||
winbit = winbit * 10 + arg.value[i].int - '0'.int
|
||||
|
||||
if winbit < 8 or winbit > 15:
|
||||
return err("window bits should between 8-15, got " & $winbit)
|
||||
|
||||
res = winbit
|
||||
return ok("=" & arg.value)
|
||||
|
||||
proc validateParams(args: seq[ExtParam],
|
||||
opts: var DeflateOpts): Result[string, string] =
|
||||
# besides validating extensions params, this proc
|
||||
# also constructing extension param for response
|
||||
var resp = ""
|
||||
for arg in args:
|
||||
case arg.name
|
||||
of "server_no_context_takeover":
|
||||
if arg.value.len > 0:
|
||||
return err("'server_no_context_takeover' should have no param")
|
||||
if opts.isServer:
|
||||
concatParam(resp, arg.name)
|
||||
opts.serverNoContextTakeOver = true
|
||||
of "client_no_context_takeover":
|
||||
if arg.value.len > 0:
|
||||
return err("'client_no_context_takeover' should have no param")
|
||||
if opts.isServer:
|
||||
concatParam(resp, arg.name)
|
||||
opts.clientNoContextTakeOver = true
|
||||
of "server_max_window_bits":
|
||||
if opts.isServer:
|
||||
concatParam(resp, arg.name)
|
||||
let res = validateWindowBits(arg, opts.serverMaxWindowBits)
|
||||
if res.isErr:
|
||||
return res
|
||||
resp.add res.get()
|
||||
of "client_max_window_bits":
|
||||
if opts.isServer:
|
||||
concatParam(resp, arg.name)
|
||||
let res = validateWindowBits(arg, opts.clientMaxWindowBits)
|
||||
if res.isErr:
|
||||
return res
|
||||
resp.add res.get()
|
||||
else:
|
||||
return err("unrecognized param: " & arg.name)
|
||||
|
||||
ok(resp)
|
||||
|
||||
method decode*(ext: DeflateExt, data: seq[byte]): Future[seq[byte]] {.async.} =
|
||||
if not ext.messageCompressed:
|
||||
return data
|
||||
|
||||
# TODO: append trailing bytes
|
||||
var mz = MzStream(
|
||||
next_in: cast[ptr cuchar](data[0].unsafeAddr),
|
||||
avail_in: data.len.cuint
|
||||
)
|
||||
|
||||
let windowBits = if ext.opts.serverMaxWindowBits == 0:
|
||||
MZ_DEFAULT_WINDOW_BITS
|
||||
else:
|
||||
MzWindowBits(ext.opts.serverMaxWindowBits)
|
||||
|
||||
doAssert(mz.inflateInit2(windowBits) == MZ_OK)
|
||||
var res: seq[byte]
|
||||
var buf: array[0xFFFF, byte]
|
||||
|
||||
while true:
|
||||
mz.next_out = cast[ptr cuchar](buf[0].addr)
|
||||
mz.avail_out = buf.len.cuint
|
||||
let r = mz.inflate(MZ_SYNC_FLUSH)
|
||||
let outSize = buf.len - mz.avail_out.int
|
||||
res.add toOpenArray(buf, 0, outSize-1)
|
||||
if r == MZ_STREAM_END:
|
||||
break
|
||||
elif r == MZ_OK:
|
||||
continue
|
||||
else:
|
||||
doAssert(false, "decompression error")
|
||||
|
||||
doAssert(mz.inflateEnd() == MZ_OK)
|
||||
return res
|
||||
|
||||
method encode*(ext: DeflateExt, data: seq[byte]): Future[seq[byte]] {.async.} =
|
||||
var mz = MzStream(
|
||||
next_in: cast[ptr cuchar](data[0].unsafeAddr),
|
||||
avail_in: data.len.cuint
|
||||
)
|
||||
|
||||
let windowBits = if ext.opts.serverMaxWindowBits == 0:
|
||||
MZ_DEFAULT_WINDOW_BITS
|
||||
else:
|
||||
MzWindowBits(ext.opts.serverMaxWindowBits)
|
||||
|
||||
doAssert(mz.deflateInit2(
|
||||
level = MZ_DEFAULT_LEVEL,
|
||||
meth = MZ_DEFLATED,
|
||||
windowBits,
|
||||
1,
|
||||
strategy = MZ_DEFAULT_STRATEGY) == MZ_OK
|
||||
)
|
||||
|
||||
let maxSize = mz.deflateBound(data.len.culong).int
|
||||
var res: seq[byte]
|
||||
var buf: array[0xFFFF, byte]
|
||||
|
||||
while true:
|
||||
mz.next_out = cast[ptr cuchar](buf[0].addr)
|
||||
mz.avail_out = buf.len.cuint
|
||||
let r = mz.deflate(MZ_FINISH)
|
||||
let outSize = buf.len - mz.avail_out.int
|
||||
res.add toOpenArray(buf, 0, outSize-1)
|
||||
if r == MZ_STREAM_END:
|
||||
break
|
||||
elif r == MZ_OK:
|
||||
continue
|
||||
else:
|
||||
doAssert(false, "compression error")
|
||||
|
||||
# TODO: cut trailing bytes
|
||||
doAssert(mz.deflateEnd() == MZ_OK)
|
||||
ext.messageCompressed = res.len < data.len
|
||||
if ext.messageCompressed:
|
||||
return res
|
||||
else:
|
||||
return data
|
||||
|
||||
method decode(ext: DeflateExt, frame: Frame): Future[Frame] {.async.} =
|
||||
if frame.opcode in {Opcode.Text, Opcode.Binary}:
|
||||
# only data frame can be compressed
|
||||
# and we want to know if this message is compressed or not
|
||||
# if the frame opcode is text or binary, it should also the first frame
|
||||
ext.messageCompressed = frame.rsv1
|
||||
# clear rsv1 bit because we already done with it
|
||||
frame.rsv1 = false
|
||||
return frame
|
||||
|
||||
method encode(ext: DeflateExt, frame: Frame): Future[Frame] {.async.} =
|
||||
if frame.opcode in {Opcode.Text, Opcode.Binary}:
|
||||
# only data frame can be compressed
|
||||
# and we only set rsv1 bit to true if the message is compressible
|
||||
# if the frame opcode is text or binary, it should also the first frame
|
||||
frame.rsv1 = ext.messageCompressed
|
||||
return frame
|
||||
|
||||
method toHttpOptions(ext: DeflateExt): string =
|
||||
# using paramStr here is a bit clunky
|
||||
extID & "; " & ext.paramStr
|
||||
|
||||
proc deflateExtFactory(isServer: bool, args: seq[ExtParam]): Result[Ext, string] {.raises: [Defect].} =
|
||||
var opts = DeflateOpts(isServer: isServer)
|
||||
let resp = validateParams(args, opts)
|
||||
if resp.isErr:
|
||||
return err(resp.error)
|
||||
let ext = DeflateExt(
|
||||
name: extID,
|
||||
paramStr: resp.get(),
|
||||
opts: opts
|
||||
)
|
||||
ok(ext)
|
||||
|
||||
const
|
||||
deflateFactory* = (extID, deflateExtFactory)
|
35
ws/frame.nim
35
ws/frame.nim
|
@ -15,8 +15,12 @@ import pkg/[
|
|||
stew/byteutils,
|
||||
stew/endians2,
|
||||
stew/results]
|
||||
|
||||
import ./types
|
||||
|
||||
logScope:
|
||||
topics = "ws-frame"
|
||||
|
||||
#[
|
||||
+---------------------------------------------------------------+
|
||||
|0 1 2 3 |
|
||||
|
@ -54,7 +58,6 @@ template remainder*(frame: Frame): uint64 =
|
|||
|
||||
proc encode*(
|
||||
frame: Frame,
|
||||
offset = 0,
|
||||
extensions: seq[Ext] = @[]): Future[seq[byte]] {.async.} =
|
||||
## Encodes a frame into a string buffer.
|
||||
## See https://tools.ietf.org/html/rfc6455#section-5.2
|
||||
|
@ -65,7 +68,7 @@ proc encode*(
|
|||
f = await e.encode(f)
|
||||
|
||||
var ret: seq[byte]
|
||||
var b0 = (f.opcode.uint8 and 0x0F) # 0th byte: opcodes and flags.
|
||||
var b0 = (f.opcode.uint8 and 0x0f) # 0th byte: opcodes and flags.
|
||||
if f.fin:
|
||||
b0 = b0 or 128'u8
|
||||
|
||||
|
@ -77,7 +80,7 @@ proc encode*(
|
|||
|
||||
if f.data.len <= 125:
|
||||
b1 = f.data.len.uint8
|
||||
elif f.data.len > 125 and f.data.len <= 0xFFFF:
|
||||
elif f.data.len > 125 and f.data.len <= 0xffff:
|
||||
b1 = 126'u8
|
||||
else:
|
||||
b1 = 127'u8
|
||||
|
@ -88,12 +91,12 @@ proc encode*(
|
|||
ret.add(uint8 b1)
|
||||
|
||||
# Only need more bytes if data len is 7+16 bits, or 7+64 bits.
|
||||
if f.data.len > 125 and f.data.len <= 0xFFFF:
|
||||
if f.data.len > 125 and f.data.len <= 0xffff:
|
||||
# Data len is 7+16 bits.
|
||||
var len = f.data.len.uint16
|
||||
ret.add ((len shr 8) and 0xFF).uint8
|
||||
ret.add (len and 0xFF).uint8
|
||||
elif f.data.len > 0xFFFF:
|
||||
ret.add ((len shr 8) and 0xff).uint8
|
||||
ret.add (len and 0xff).uint8
|
||||
elif f.data.len > 0xffff:
|
||||
# Data len is 7+64 bits.
|
||||
var len = f.data.len.uint64
|
||||
ret.add(len.toBytesBE())
|
||||
|
@ -101,7 +104,7 @@ proc encode*(
|
|||
var data = f.data
|
||||
if f.mask:
|
||||
# If we need to mask it generate random mask key and mask the data.
|
||||
mask(data, f.maskKey, offset)
|
||||
mask(data, f.maskKey)
|
||||
|
||||
# Write mask key next.
|
||||
ret.add(f.maskKey[0].uint8)
|
||||
|
@ -122,10 +125,10 @@ proc decode*(
|
|||
##
|
||||
|
||||
var header = newSeq[byte](2)
|
||||
debug "Reading new frame"
|
||||
trace "Reading new frame"
|
||||
await reader.readExactly(addr header[0], 2)
|
||||
if header.len != 2:
|
||||
debug "Invalid websocket header length"
|
||||
trace "Invalid websocket header length"
|
||||
raise newException(WSMalformedHeaderError,
|
||||
"Invalid websocket header length")
|
||||
|
||||
|
@ -147,10 +150,6 @@ proc decode*(
|
|||
|
||||
frame.opcode = (opcode).Opcode
|
||||
|
||||
# If any of the rsv are set close the socket.
|
||||
if frame.rsv1 or frame.rsv2 or frame.rsv3:
|
||||
raise newException(WSRsvMismatchError, "WebSocket rsv mismatch")
|
||||
|
||||
# Payload length can be 7 bits, 7+16 bits, or 7+64 bits.
|
||||
var finalLen: uint64 = 0
|
||||
|
||||
|
@ -187,7 +186,11 @@ proc decode*(
|
|||
frame.maskKey[i] = cast[char](maskKey[i])
|
||||
|
||||
if extensions.len > 0:
|
||||
for e in extensions[extensions.high..extensions.low]:
|
||||
frame = await e.decode(frame)
|
||||
for i in countdown(extensions.high, extensions.low):
|
||||
frame = await extensions[i].decode(frame)
|
||||
|
||||
# If any of the rsv are set close the socket.
|
||||
if frame.rsv1 or frame.rsv2 or frame.rsv3:
|
||||
raise newException(WSRsvMismatchError, "WebSocket rsv mismatch")
|
||||
|
||||
return frame
|
||||
|
|
|
@ -9,6 +9,9 @@ import pkg/[
|
|||
|
||||
import ./common
|
||||
|
||||
logScope:
|
||||
topics = "http-client"
|
||||
|
||||
type
|
||||
HttpClient* = ref object of RootObj
|
||||
connected*: bool
|
||||
|
@ -44,7 +47,7 @@ proc readResponse(stream: AsyncStreamReader): Future[HttpResponseHeader] {.async
|
|||
|
||||
return buffer.parseResponse()
|
||||
except CatchableError as exc:
|
||||
debug "Exception reading headers", exc = exc.msg
|
||||
trace "Exception reading headers", exc = exc.msg
|
||||
buffer.setLen(0)
|
||||
raise exc
|
||||
|
||||
|
|
|
@ -53,8 +53,7 @@ proc closeWait*(stream: AsyncStream) {.async.} =
|
|||
await allFutures(
|
||||
stream.reader.tsource.closeTransp(),
|
||||
stream.reader.closeStream(),
|
||||
stream.writer.closeStream()
|
||||
)
|
||||
stream.writer.closeStream())
|
||||
|
||||
proc sendResponse*(
|
||||
request: HttpRequest,
|
||||
|
@ -112,8 +111,7 @@ proc sendError*(
|
|||
response.add(CRLF)
|
||||
|
||||
await stream.write(
|
||||
response.toBytes() &
|
||||
content.toBytes())
|
||||
response.toBytes() & content.toBytes())
|
||||
|
||||
proc sendError*(
|
||||
request: HttpRequest,
|
||||
|
|
|
@ -30,13 +30,13 @@ proc validateRequest(
|
|||
##
|
||||
|
||||
if header.meth notin {MethodGet}:
|
||||
debug "GET method is only allowed", address = stream.tsource.remoteAddress()
|
||||
trace "GET method is only allowed", address = stream.tsource.remoteAddress()
|
||||
await stream.sendError(Http405, version = header.version)
|
||||
return ReqStatus.Error
|
||||
|
||||
var hlen = header.contentLength()
|
||||
if hlen < 0 or hlen > MaxHttpRequestSize:
|
||||
debug "Invalid header length", address = stream.tsource.remoteAddress()
|
||||
trace "Invalid header length", address = stream.tsource.remoteAddress()
|
||||
await stream.sendError(Http413, version = header.version)
|
||||
return ReqStatus.Error
|
||||
|
||||
|
@ -50,14 +50,14 @@ proc handleRequest(
|
|||
|
||||
var buffer = newSeq[byte](MaxHttpHeadersSize)
|
||||
let remoteAddr = stream.reader.tsource.remoteAddress()
|
||||
debug "Received connection", address = $remoteAddr
|
||||
trace "Received connection", address = $remoteAddr
|
||||
try:
|
||||
let hlenfut = stream.reader.readUntil(
|
||||
addr buffer[0], MaxHttpHeadersSize, sep = HeaderSep)
|
||||
let ores = await withTimeout(hlenfut, HttpHeadersTimeout)
|
||||
if not ores:
|
||||
# Timeout
|
||||
debug "Timeout expired while receiving headers", address = $remoteAddr
|
||||
trace "Timeout expired while receiving headers", address = $remoteAddr
|
||||
await stream.writer.sendError(Http408, version = HttpVersion11)
|
||||
return
|
||||
|
||||
|
@ -66,7 +66,7 @@ proc handleRequest(
|
|||
let requestData = buffer.parseRequest()
|
||||
if requestData.failed():
|
||||
# Header could not be parsed
|
||||
debug "Malformed header received", address = $remoteAddr
|
||||
trace "Malformed header received", address = $remoteAddr
|
||||
await stream.writer.sendError(Http400, version = HttpVersion11)
|
||||
return
|
||||
|
||||
|
@ -79,10 +79,10 @@ proc handleRequest(
|
|||
res
|
||||
|
||||
if vres == ReqStatus.ErrorFailure:
|
||||
debug "Remote peer disconnected", address = $remoteAddr
|
||||
trace "Remote peer disconnected", address = $remoteAddr
|
||||
return
|
||||
|
||||
debug "Received valid HTTP request", address = $remoteAddr
|
||||
trace "Received valid HTTP request", address = $remoteAddr
|
||||
# Call the user's handler.
|
||||
if server.handler != nil:
|
||||
await server.handler(
|
||||
|
@ -92,15 +92,15 @@ proc handleRequest(
|
|||
uri: requestData.uri().parseUri()))
|
||||
except TransportLimitError:
|
||||
# size of headers exceeds `MaxHttpHeadersSize`
|
||||
debug "Maximum size of headers limit reached", address = $remoteAddr
|
||||
trace "maximum size of headers limit reached", address = $remoteAddr
|
||||
await stream.writer.sendError(Http413, version = HttpVersion11)
|
||||
except TransportIncompleteError:
|
||||
# remote peer disconnected
|
||||
debug "Remote peer disconnected", address = $remoteAddr
|
||||
trace "Remote peer disconnected", address = $remoteAddr
|
||||
except TransportOsError as exc:
|
||||
debug "Problems with networking", address = $remoteAddr, error = exc.msg
|
||||
trace "Problems with networking", address = $remoteAddr, error = exc.msg
|
||||
except CatchableError as exc:
|
||||
debug "Unknown exception", address = $remoteAddr, error = exc.msg
|
||||
trace "Unknown exception", address = $remoteAddr, error = exc.msg
|
||||
finally:
|
||||
await stream.closeWait()
|
||||
|
||||
|
@ -151,6 +151,8 @@ proc create*(
|
|||
flags,
|
||||
child = StreamServer(server)))
|
||||
|
||||
trace "Created HTTP Server", host = $address
|
||||
|
||||
return server
|
||||
|
||||
proc create*(
|
||||
|
@ -191,6 +193,8 @@ proc create*(
|
|||
flags,
|
||||
child = StreamServer(server)))
|
||||
|
||||
trace "Created TLS HTTP Server", host = $address
|
||||
|
||||
return server
|
||||
|
||||
proc create*(
|
||||
|
|
236
ws/session.nim
236
ws/session.nim
|
@ -9,27 +9,27 @@
|
|||
|
||||
{.push raises: [Defect].}
|
||||
|
||||
import std/strformat
|
||||
import pkg/[chronos, chronicles, stew/byteutils, stew/endians2]
|
||||
import ./types, ./frame, ./utils, ./utf8_dfa, ./http
|
||||
import ./types, ./frame, ./utils, ./utf8dfa, ./http
|
||||
|
||||
import pkg/chronos/[streams/asyncstream]
|
||||
import pkg/chronos/streams/asyncstream
|
||||
|
||||
type
|
||||
WSSession* = ref object of WebSocket
|
||||
stream*: AsyncStream
|
||||
frame*: Frame
|
||||
proto*: string
|
||||
logScope:
|
||||
topics = "ws-session"
|
||||
|
||||
proc prepareCloseBody(code: Status, reason: string): seq[byte] =
|
||||
proc prepareCloseBody(code: StatusCodes, reason: string): seq[byte] =
|
||||
result = reason.toBytes
|
||||
if ord(code) > 999:
|
||||
result = @(ord(code).uint16.toBytesBE()) & result
|
||||
|
||||
proc send*(
|
||||
proc writeMessage*(
|
||||
ws: WSSession,
|
||||
data: seq[byte] = @[],
|
||||
opcode: Opcode) {.async.} =
|
||||
## Send a frame
|
||||
opcode: Opcode,
|
||||
extensions: seq[Ext]) {.async.} =
|
||||
## Send a frame applying the supplied
|
||||
## extensions
|
||||
##
|
||||
|
||||
if ws.readyState == ReadyState.Closed:
|
||||
|
@ -40,7 +40,7 @@ proc send*(
|
|||
dataSize = data.len
|
||||
masked = ws.masked
|
||||
|
||||
debug "Sending data to remote"
|
||||
trace "Sending data to remote"
|
||||
|
||||
var maskKey: array[4, char]
|
||||
if ws.masked:
|
||||
|
@ -61,30 +61,40 @@ proc send*(
|
|||
mask: ws.masked,
|
||||
data: data, # allow sending data with close messages
|
||||
maskKey: maskKey)
|
||||
.encode(extensions = ws.extensions)))
|
||||
.encode()))
|
||||
|
||||
return
|
||||
|
||||
let maxSize = ws.frameSize
|
||||
var i = 0
|
||||
while ws.readyState notin {ReadyState.Closing}:
|
||||
let len = min(data.len, (maxSize + i))
|
||||
await ws.stream.writer.write(
|
||||
(await Frame(
|
||||
fin: if (i + len >= data.len): true else: false,
|
||||
let len = min(data.len, maxSize)
|
||||
let frame = Frame(
|
||||
fin: if (len + i >= data.len): true else: false,
|
||||
rsv1: false,
|
||||
rsv2: false,
|
||||
rsv3: false,
|
||||
opcode: if i > 0: Opcode.Cont else: opcode, # fragments have to be `Continuation` frames
|
||||
mask: ws.masked,
|
||||
data: data[i ..< len],
|
||||
data: data[i ..< len + i],
|
||||
maskKey: maskKey)
|
||||
.encode()))
|
||||
|
||||
let encoded = await frame.encode(extensions)
|
||||
await ws.stream.writer.write(encoded)
|
||||
|
||||
i += len
|
||||
if i >= data.len:
|
||||
break
|
||||
|
||||
proc send*(
|
||||
ws: WSSession,
|
||||
data: seq[byte] = @[],
|
||||
opcode: Opcode): Future[void] =
|
||||
## Send a frame
|
||||
##
|
||||
|
||||
return ws.writeMessage(data, opcode, ws.extensions)
|
||||
|
||||
proc send*(ws: WSSession, data: string): Future[void] =
|
||||
send(ws, data.toBytes(), Opcode.Text)
|
||||
|
||||
|
@ -101,54 +111,62 @@ proc handleClose*(
|
|||
opcode = frame.opcode
|
||||
readyState = ws.readyState
|
||||
|
||||
debug "Handling close"
|
||||
trace "Handling close"
|
||||
|
||||
if ws.readyState notin {ReadyState.Open}:
|
||||
debug "Connection isn't open, abortig close sequence!"
|
||||
if ws.readyState != ReadyState.Open:
|
||||
trace "Connection isn't open, aborting close sequence!"
|
||||
return
|
||||
|
||||
var
|
||||
code = Status.Fulfilled
|
||||
code = StatusFulfilled
|
||||
reason = ""
|
||||
|
||||
if payLoad.len == 1:
|
||||
case payload.len:
|
||||
of 0:
|
||||
code = StatusNoStatus
|
||||
of 1:
|
||||
raise newException(WSPayloadLengthError,
|
||||
"Invalid close frame with payload length 1!")
|
||||
|
||||
if payLoad.len > 1:
|
||||
# first two bytes are the status
|
||||
let ccode = uint16.fromBytesBE(payLoad[0..<2])
|
||||
if ccode <= 999 or ccode > 1015:
|
||||
raise newException(WSInvalidCloseCodeError,
|
||||
"Invalid code in close message!")
|
||||
|
||||
else:
|
||||
try:
|
||||
code = Status(ccode)
|
||||
code = StatusCodes(uint16.fromBytesBE(payLoad[0..<2]))
|
||||
except RangeError:
|
||||
raise newException(WSInvalidCloseCodeError,
|
||||
"Status code out of range!")
|
||||
|
||||
# remining payload bytes are reason for closing
|
||||
if code in StatusNotUsed or
|
||||
code in StatusReservedProtocol:
|
||||
raise newException(WSInvalidCloseCodeError,
|
||||
&"Can't use reserved status code: {code}")
|
||||
|
||||
if code == StatusReserved or
|
||||
code == StatusNoStatus or
|
||||
code == StatusClosedAbnormally:
|
||||
raise newException(WSInvalidCloseCodeError,
|
||||
&"Can't use reserved status code: {code}")
|
||||
|
||||
# remaining payload bytes are reason for closing
|
||||
reason = string.fromBytes(payLoad[2..payLoad.high])
|
||||
|
||||
if not ws.binary and validateUTF8(reason) == false:
|
||||
raise newException(WSInvalidUTF8,
|
||||
"Invalid UTF8 sequence detected in close reason")
|
||||
|
||||
var rcode: Status
|
||||
if code in {Status.Fulfilled}:
|
||||
rcode = Status.Fulfilled
|
||||
|
||||
trace "Handling close message", code, reason
|
||||
if not isNil(ws.onClose):
|
||||
try:
|
||||
(rcode, reason) = ws.onClose(code, reason)
|
||||
(code, reason) = ws.onClose(code, reason)
|
||||
except CatchableError as exc:
|
||||
debug "Exception in Close callback, this is most likely a bug", exc = exc.msg
|
||||
trace "Exception in Close callback, this is most likely a bug", exc = exc.msg
|
||||
else:
|
||||
code = StatusFulfilled
|
||||
reason = ""
|
||||
|
||||
# don't respond to a terminated connection
|
||||
if ws.readyState != ReadyState.Closing:
|
||||
ws.readyState = ReadyState.Closing
|
||||
await ws.send(prepareCloseBody(rcode, reason), Opcode.Close)
|
||||
trace "Sending close", code, reason
|
||||
await ws.send(prepareCloseBody(code, reason), Opcode.Close)
|
||||
|
||||
ws.readyState = ReadyState.Closed
|
||||
await ws.stream.closeWait()
|
||||
|
@ -164,7 +182,7 @@ proc handleControl*(ws: WSSession, frame: Frame) {.async.} =
|
|||
readyState = ws.readyState
|
||||
len = frame.length
|
||||
|
||||
debug "Handling control frame"
|
||||
trace "Handling control frame"
|
||||
|
||||
if not frame.fin:
|
||||
raise newException(WSFragmentedControlFrameError,
|
||||
|
@ -191,7 +209,7 @@ proc handleControl*(ws: WSSession, frame: Frame) {.async.} =
|
|||
try:
|
||||
ws.onPing(payLoad)
|
||||
except CatchableError as exc:
|
||||
debug "Exception in Ping callback, this is most likelly a bug", exc = exc.msg
|
||||
trace "Exception in Ping callback, this is most likely a bug", exc = exc.msg
|
||||
|
||||
# send pong to remote
|
||||
await ws.send(payLoad, Opcode.Pong)
|
||||
|
@ -200,21 +218,28 @@ proc handleControl*(ws: WSSession, frame: Frame) {.async.} =
|
|||
try:
|
||||
ws.onPong(payLoad)
|
||||
except CatchableError as exc:
|
||||
debug "Exception in Pong callback, this is most likelly a bug", exc = exc.msg
|
||||
trace "Exception in Pong callback, this is most likely a bug", exc = exc.msg
|
||||
of Opcode.Close:
|
||||
await ws.handleClose(frame, payLoad)
|
||||
else:
|
||||
raise newException(WSInvalidOpcodeError, "Invalid control opcode!")
|
||||
|
||||
proc readFrame*(ws: WSSession): Future[Frame] {.async.} =
|
||||
proc readFrame*(ws: WSSession, extensions: seq[Ext] = @[]): Future[Frame] {.async.} =
|
||||
## Gets a frame from the WebSocket.
|
||||
## See https://tools.ietf.org/html/rfc6455#section-5.2
|
||||
##
|
||||
|
||||
while ws.readyState != ReadyState.Closed:
|
||||
let frame = await Frame.decode(
|
||||
ws.stream.reader, ws.masked, ws.extensions)
|
||||
debug "Decoded new frame", opcode = frame.opcode, len = frame.length, mask = frame.mask
|
||||
ws.stream.reader, ws.masked, extensions)
|
||||
|
||||
logScope:
|
||||
opcode = frame.opcode
|
||||
len = frame.length
|
||||
mask = frame.mask
|
||||
fin = frame.fin
|
||||
|
||||
trace "Decoded new frame"
|
||||
|
||||
# return the current frame if it's not one of the control frames
|
||||
if frame.opcode notin {Opcode.Text, Opcode.Cont, Opcode.Binary}:
|
||||
|
@ -230,59 +255,70 @@ proc recv*(
|
|||
ws: WSSession,
|
||||
data: pointer,
|
||||
size: int): Future[int] {.async.} =
|
||||
## Attempts to read up to `size` bytes
|
||||
## Attempts to read up to ``size`` bytes
|
||||
##
|
||||
## Will read as many frames as necessary
|
||||
## to fill the buffer until either
|
||||
## the message ends (frame.fin) or
|
||||
## the buffer is full. If no data is on
|
||||
## the pipe will await until at least
|
||||
## one byte is available
|
||||
## If ``size`` is less than the data in
|
||||
## the frame, allow reading partial frames
|
||||
##
|
||||
## If no data is left in the pipe await
|
||||
## until at least one byte is available
|
||||
##
|
||||
## Otherwise, read as many frames as needed
|
||||
## up to ``size`` bytes, note that we do break
|
||||
## at message boundaries (``fin`` flag set).
|
||||
##
|
||||
## Use this to stream data from frames
|
||||
##
|
||||
|
||||
var consumed = 0
|
||||
var pbuffer = cast[ptr UncheckedArray[byte]](data)
|
||||
try:
|
||||
var first = true
|
||||
|
||||
# reset previous frame if nothing is left in it
|
||||
if not isNil(ws.frame) and ws.frame.remainder <= 0:
|
||||
trace "Resetting previous frame"
|
||||
first = ws.frame.fin # set as first frame if last frame was final
|
||||
ws.frame = nil
|
||||
|
||||
if isNil(ws.frame):
|
||||
ws.frame = await ws.readFrame(ws.extensions)
|
||||
|
||||
while consumed < size:
|
||||
# we might have to read more than
|
||||
# one frame to fill the buffer
|
||||
|
||||
# TODO: Figure out a cleaner way to handle
|
||||
# retrieving new frames
|
||||
if isNil(ws.frame):
|
||||
ws.frame = await ws.readFrame()
|
||||
|
||||
if isNil(ws.frame):
|
||||
return consumed
|
||||
|
||||
if ws.frame.opcode == Opcode.Cont:
|
||||
raise newException(WSOpcodeMismatchError,
|
||||
"Expected Text or Binary frame")
|
||||
elif (not ws.frame.fin and ws.frame.remainder() <= 0):
|
||||
ws.frame = await ws.readFrame()
|
||||
# This could happen if the connection is closed.
|
||||
|
||||
if isNil(ws.frame):
|
||||
return consumed
|
||||
|
||||
if ws.frame.opcode != Opcode.Cont:
|
||||
raise newException(WSOpcodeMismatchError,
|
||||
"Expected Continuation frame")
|
||||
|
||||
ws.binary = ws.frame.opcode == Opcode.Binary # set binary flag
|
||||
if ws.frame.fin and ws.frame.remainder() <= 0:
|
||||
ws.frame = nil
|
||||
trace "Empty frame, breaking"
|
||||
break
|
||||
|
||||
let len = min(ws.frame.remainder().int, size - consumed)
|
||||
if len == 0:
|
||||
continue
|
||||
logScope:
|
||||
first = first
|
||||
fin = ws.frame.fin
|
||||
len = ws.frame.length
|
||||
consumed = ws.frame.consumed
|
||||
remainder = ws.frame.remainder
|
||||
opcode = ws.frame.opcode
|
||||
masked = ws.frame.mask
|
||||
|
||||
if first == (ws.frame.opcode == Opcode.Cont):
|
||||
error "Opcode mismatch!"
|
||||
raise newException(WSOpcodeMismatchError,
|
||||
&"Opcode mismatch: first: {first}, opcode: {ws.frame.opcode}")
|
||||
|
||||
if first:
|
||||
ws.binary = ws.frame.opcode == Opcode.Binary # set binary flag
|
||||
trace "Setting binary flag"
|
||||
|
||||
let len = min(ws.frame.remainder.int, size - consumed)
|
||||
if len <= 0:
|
||||
trace "Nothing left to read, breaking!"
|
||||
break
|
||||
|
||||
let read = await ws.stream.reader.readOnce(addr pbuffer[consumed], len)
|
||||
if read <= 0:
|
||||
continue
|
||||
trace "Didn't read any bytes, breaking"
|
||||
break
|
||||
|
||||
if ws.frame.mask:
|
||||
trace "Unmasking frame"
|
||||
# unmask data using offset
|
||||
mask(
|
||||
pbuffer.toOpenArray(consumed, (consumed + read) - 1),
|
||||
|
@ -292,15 +328,31 @@ proc recv*(
|
|||
consumed += read
|
||||
ws.frame.consumed += read.uint64
|
||||
|
||||
trace "Read data from frame", read
|
||||
# all has been consumed from the frame
|
||||
# read the next frame
|
||||
if ws.frame.remainder <= 0:
|
||||
first = false
|
||||
|
||||
if ws.frame.fin: # we're at the end of the message, break
|
||||
trace "Read all frames, breaking"
|
||||
ws.frame = nil
|
||||
break
|
||||
|
||||
ws.frame = await ws.readFrame(ws.extensions)
|
||||
|
||||
if not ws.binary and validateUTF8(pbuffer.toOpenArray(0, consumed - 1)) == false:
|
||||
raise newException(WSInvalidUTF8, "Invalid UTF8 sequence detected")
|
||||
|
||||
return consumed.int
|
||||
return consumed
|
||||
except CatchableError as exc:
|
||||
ws.readyState = ReadyState.Closed
|
||||
await ws.stream.closeWait()
|
||||
debug "Exception reading frames", exc = exc.msg
|
||||
trace "Exception reading frames", exc = exc.msg
|
||||
raise exc
|
||||
finally:
|
||||
if not isNil(ws.frame) and (ws.frame.fin and ws.frame.remainder <= 0):
|
||||
ws.frame = nil
|
||||
|
||||
proc recv*(
|
||||
ws: WSSession,
|
||||
|
@ -318,15 +370,14 @@ proc recv*(
|
|||
##
|
||||
var res: seq[byte]
|
||||
while ws.readyState != ReadyState.Closed:
|
||||
var buf = newSeq[byte](ws.frameSize)
|
||||
var buf = newSeq[byte](min(size, ws.frameSize))
|
||||
let read = await ws.recv(addr buf[0], buf.len)
|
||||
if read <= 0:
|
||||
break
|
||||
|
||||
buf.setLen(read)
|
||||
if res.len + buf.len > size:
|
||||
raise newException(WSMaxMessageSizeError, "Max message size exceeded")
|
||||
|
||||
trace "Read message", size = read
|
||||
res.add(buf)
|
||||
|
||||
# no more frames
|
||||
|
@ -335,13 +386,14 @@ proc recv*(
|
|||
|
||||
# read the entire message, exit
|
||||
if ws.frame.fin and ws.frame.remainder().int <= 0:
|
||||
trace "Read full message, breaking!"
|
||||
break
|
||||
|
||||
return res
|
||||
|
||||
proc close*(
|
||||
ws: WSSession,
|
||||
code: Status = Status.Fulfilled,
|
||||
code = StatusFulfilled,
|
||||
reason: string = "") {.async.} =
|
||||
## Close the Socket, sends close packet.
|
||||
##
|
||||
|
@ -359,4 +411,4 @@ proc close*(
|
|||
while ws.readyState != ReadyState.Closed:
|
||||
discard await ws.recv()
|
||||
except CatchableError as exc:
|
||||
debug "Exception closing", exc = exc.msg
|
||||
trace "Exception closing", exc = exc.msg
|
||||
|
|
78
ws/types.nim
78
ws/types.nim
|
@ -62,41 +62,18 @@ type
|
|||
length*: uint64 ## Message size.
|
||||
consumed*: uint64 ## how much has been consumed from the frame
|
||||
|
||||
Status* {.pure.} = enum
|
||||
# 0-999 not used
|
||||
Fulfilled = 1000
|
||||
GoingAway = 1001
|
||||
ProtocolError = 1002
|
||||
CannotAccept = 1003
|
||||
# 1004 reserved
|
||||
NoStatus = 1005 # use by clients
|
||||
ClosedAbnormally = 1006 # use by clients
|
||||
Inconsistent = 1007
|
||||
PolicyError = 1008
|
||||
TooLarge = 1009
|
||||
NoExtensions = 1010
|
||||
UnexpectedError = 1011
|
||||
ReservedCode = 3999 # use by clients
|
||||
# 3000-3999 reserved for libs
|
||||
# 4000-4999 reserved for applications
|
||||
StatusCodes* = distinct range[0..4999]
|
||||
|
||||
ControlCb* = proc(data: openArray[byte] = [])
|
||||
{.gcsafe, raises: [Defect].}
|
||||
|
||||
CloseResult* = tuple
|
||||
code: Status
|
||||
code: StatusCodes
|
||||
reason: string
|
||||
|
||||
CloseCb* = proc(code: Status, reason: string):
|
||||
CloseCb* = proc(code: StatusCodes, reason: string):
|
||||
CloseResult {.gcsafe, raises: [Defect].}
|
||||
|
||||
Ext* = ref object of RootObj
|
||||
name*: string
|
||||
options*: Table[string, string]
|
||||
|
||||
ExtFactory* = proc(name: string, options: Table[string, string]):
|
||||
Ext {.raises: [Defect].}
|
||||
|
||||
WebSocket* = ref object of RootObj
|
||||
extensions*: seq[Ext]
|
||||
version*: uint
|
||||
|
@ -111,6 +88,21 @@ type
|
|||
onPong*: ControlCb
|
||||
onClose*: CloseCb
|
||||
|
||||
WSSession* = ref object of WebSocket
|
||||
stream*: AsyncStream
|
||||
frame*: Frame
|
||||
proto*: string
|
||||
|
||||
Ext* = ref object of RootObj
|
||||
name*: string
|
||||
options*: Table[string, string]
|
||||
session*: WSSession
|
||||
|
||||
ExtFactory* = proc(
|
||||
name: string,
|
||||
session: WSSession,
|
||||
options: Table[string, string]): Ext {.raises: [Defect].}
|
||||
|
||||
WebSocketError* = object of CatchableError
|
||||
WSMalformedHeaderError* = object of WebSocketError
|
||||
WSFailedUpgradeError* = object of WebSocketError
|
||||
|
@ -125,13 +117,43 @@ type
|
|||
WSClosedError* = object of WebSocketError
|
||||
WSSendError* = object of WebSocketError
|
||||
WSPayloadTooLarge* = object of WebSocketError
|
||||
WSReserverdOpcodeError* = object of WebSocketError
|
||||
WSReservedOpcodeError* = object of WebSocketError
|
||||
WSFragmentedControlFrameError* = object of WebSocketError
|
||||
WSInvalidCloseCodeError* = object of WebSocketError
|
||||
WSPayloadLengthError* = object of WebSocketError
|
||||
WSInvalidOpcodeError* = object of WebSocketError
|
||||
WSInvalidUTF8* = object of WebSocketError
|
||||
|
||||
const
|
||||
StatusNotUsed* = (StatusCodes(0)..StatusCodes(999))
|
||||
StatusFulfilled* = StatusCodes(1000)
|
||||
StatusGoingAway* = StatusCodes(1001)
|
||||
StatusProtocolError* = StatusCodes(1002)
|
||||
StatusCannotAccept* = StatusCodes(1003)
|
||||
StatusReserved* = StatusCodes(1004) # 1004 reserved
|
||||
StatusNoStatus* = StatusCodes(1005) # use by clients
|
||||
StatusClosedAbnormally* = StatusCodes(1006) # use by clients
|
||||
StatusInconsistent* = StatusCodes(1007)
|
||||
StatusPolicyError* = StatusCodes(1008)
|
||||
StatusTooLarge* = StatusCodes(1009)
|
||||
StatusNoExtensions* = StatusCodes(1010)
|
||||
StatusUnexpectedError* = StatusCodes(1011)
|
||||
StatusFailedTls* = StatusCodes(1015) # passed to applications to indicate TLS errors
|
||||
StatusReservedProtocol* = StatusCodes(1016)..StatusCodes(2999) # reserved for this protocol
|
||||
StatusLibsCodes* = (StatusCodes(3000)..StatusCodes(3999)) # 3000-3999 reserved for libs
|
||||
StatusAppsCodes* = (StatusCodes(4000)..StatusCodes(4999)) # 4000-4999 reserved for apps
|
||||
|
||||
proc `<=`*(a, b: StatusCodes): bool = a.uint16 <= b.uint16
|
||||
proc `>=`*(a, b: StatusCodes): bool = a.uint16 >= b.uint16
|
||||
proc `<`*(a, b: StatusCodes): bool = a.uint16 < b.uint16
|
||||
proc `>`*(a, b: StatusCodes): bool = a.uint16 > b.uint16
|
||||
proc `==`*(a, b: StatusCodes): bool = a.uint16 == b.uint16
|
||||
|
||||
proc high*(a: HSlice[StatusCodes, StatusCodes]): uint16 = a.b.uint16
|
||||
proc low*(a: HSlice[StatusCodes, StatusCodes]): uint16 = a.a.uint16
|
||||
|
||||
proc `$`*(a: StatusCodes): string = $(a.int)
|
||||
|
||||
proc `name=`*(self: Ext, name: string) =
|
||||
raiseAssert "Can't change extensions name!"
|
||||
|
||||
|
@ -141,5 +163,5 @@ method decode*(self: Ext, frame: Frame): Future[Frame] {.base, async.} =
|
|||
method encode*(self: Ext, frame: Frame): Future[Frame] {.base, async.} =
|
||||
raiseAssert "Not implemented!"
|
||||
|
||||
method toHttpOptions*(self: Ext): string =
|
||||
method toHttpOptions*(self: Ext): string {.base.} =
|
||||
raiseAssert "Not implemented!"
|
||||
|
|
10
ws/ws.nim
10
ws/ws.nim
|
@ -23,7 +23,6 @@ import pkg/[chronos,
|
|||
chronicles,
|
||||
httputils,
|
||||
stew/byteutils,
|
||||
stew/endians2,
|
||||
stew/base64,
|
||||
stew/base10,
|
||||
nimcrypto/sha]
|
||||
|
@ -32,6 +31,9 @@ import ./utils, ./frame, ./session, /types, ./http
|
|||
|
||||
export utils, session, frame, types, http
|
||||
|
||||
logScope:
|
||||
topics = "ws-server"
|
||||
|
||||
type
|
||||
WSServer* = ref object of WebSocket
|
||||
protocols: seq[string]
|
||||
|
@ -86,7 +88,7 @@ proc connect*(
|
|||
let response = try:
|
||||
await client.request(uri, headers = headers)
|
||||
except CatchableError as exc:
|
||||
debug "Websocket failed during handshake", exc = exc.msg
|
||||
trace "Websocket failed during handshake", exc = exc.msg
|
||||
await client.close()
|
||||
raise exc
|
||||
|
||||
|
@ -207,7 +209,7 @@ proc handleRequest*(
|
|||
|
||||
if ws.version != version:
|
||||
await request.stream.writer.sendError(Http426)
|
||||
debug "Websocket version not supported", version = ws.version
|
||||
trace "Websocket version not supported", version = ws.version
|
||||
|
||||
raise newException(WSVersionError,
|
||||
&"Websocket version not supported, Version: {version}")
|
||||
|
@ -236,7 +238,7 @@ proc handleRequest*(
|
|||
if protocol.len > 0:
|
||||
headers.add("Sec-WebSocket-Protocol", protocol) # send back the first matching proto
|
||||
else:
|
||||
debug "Didn't match any protocol", supported = ws.protocols, requested = wantProtos
|
||||
trace "Didn't match any protocol", supported = ws.protocols, requested = wantProtos
|
||||
|
||||
try:
|
||||
await request.sendResponse(Http101, headers = headers)
|
||||
|
|
Loading…
Reference in New Issue