Fix webscoket close and fix test cases.
This commit is contained in:
parent
0feac12a67
commit
51d834a0a1
|
@ -35,7 +35,7 @@ jobs:
|
||||||
- target:
|
- target:
|
||||||
os: windows
|
os: windows
|
||||||
builder: windows-2019
|
builder: windows-2019
|
||||||
name: '${{ matrix.target.os }}-${{ matrix.target.cpu }} (${{ matrix.branch }})'
|
name: "${{ matrix.target.os }}-${{ matrix.target.cpu }} (${{ matrix.branch }})"
|
||||||
runs-on: ${{ matrix.builder }}
|
runs-on: ${{ matrix.builder }}
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout nim-ws
|
- name: Checkout nim-ws
|
||||||
|
|
|
@ -1,8 +1,7 @@
|
||||||
import pkg/[chronos,
|
import pkg/[chronos,
|
||||||
chronos/apps/http/httpserver,
|
chronos/apps/http/httpserver,
|
||||||
chronicles,
|
chronicles,
|
||||||
httputils,
|
httputils]
|
||||||
stew/byteutils]
|
|
||||||
|
|
||||||
import ../ws/ws
|
import ../ws/ws
|
||||||
|
|
||||||
|
@ -13,22 +12,18 @@ proc process(r: RequestFence): Future[HttpResponseRef] {.async.} =
|
||||||
if request.uri.path == "/ws":
|
if request.uri.path == "/ws":
|
||||||
debug "Initiating web socket connection."
|
debug "Initiating web socket connection."
|
||||||
try:
|
try:
|
||||||
var ws = await createServer(request,"")
|
let ws = await createServer(request, "")
|
||||||
if ws.readyState != Open:
|
if ws.readyState != Open:
|
||||||
error "Failed to open websocket connection."
|
error "Failed to open websocket connection."
|
||||||
return
|
return
|
||||||
debug "Websocket handshake completed."
|
debug "Websocket handshake completed."
|
||||||
while ws.readyState != ReadyState.Closed:
|
while true:
|
||||||
# Only reads header for data frame.
|
let recvData = await ws.recv()
|
||||||
var recvData = await ws.recv()
|
if ws.readyState == ReadyState.Closed:
|
||||||
if recvData.len <= 0:
|
debug "Websocket closed."
|
||||||
debug "Empty messages"
|
|
||||||
break
|
break
|
||||||
|
|
||||||
# debug "Client Response: ", data = string.fromBytes(recvData), size = recvData.len
|
|
||||||
debug "Client Response: ", size = recvData.len
|
debug "Client Response: ", size = recvData.len
|
||||||
await ws.send(recvData)
|
await ws.send(recvData)
|
||||||
# await ws.close()
|
|
||||||
|
|
||||||
except WebSocketError as exc:
|
except WebSocketError as exc:
|
||||||
error "WebSocket error:", exception = exc.msg
|
error "WebSocket error:", exception = exc.msg
|
||||||
|
|
|
@ -1,7 +1,7 @@
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
include ../ws/ws
|
include ../ws/ws
|
||||||
include ../ws/random
|
include ../ws/utils
|
||||||
|
|
||||||
# TODO: Fix Test.
|
# TODO: Fix Test.
|
||||||
|
|
||||||
|
|
|
@ -2,6 +2,7 @@ import std/strutils, httputils
|
||||||
|
|
||||||
import pkg/[asynctest,
|
import pkg/[asynctest,
|
||||||
chronos,
|
chronos,
|
||||||
|
chronicles,
|
||||||
chronos/apps/http/shttpserver,
|
chronos/apps/http/shttpserver,
|
||||||
stew/byteutils]
|
stew/byteutils]
|
||||||
|
|
||||||
|
@ -10,14 +11,23 @@ import ../ws/[ws, stream],
|
||||||
|
|
||||||
import ./keys
|
import ./keys
|
||||||
|
|
||||||
var server: SecureHttpServerRef
|
proc waitForClose(ws: WebSocket) {.async.} =
|
||||||
let address = initTAddress("127.0.0.1:8888")
|
try:
|
||||||
let serverFlags = {HttpServerFlags.Secure, HttpServerFlags.NotifyDisconnect}
|
while ws.readystate != ReadyState.Closed:
|
||||||
let socketFlags = {ServerFlags.TcpNoDelay, ServerFlags.ReuseAddr}
|
discard await ws.recv()
|
||||||
let clientFlags = {NoVerifyHost, NoVerifyServerName}
|
except CatchableError:
|
||||||
|
debug "Closing websocket"
|
||||||
|
|
||||||
|
var server: SecureHttpServerRef
|
||||||
|
|
||||||
|
let
|
||||||
|
address = initTAddress("127.0.0.1:8888")
|
||||||
|
serverFlags = {HttpServerFlags.Secure, HttpServerFlags.NotifyDisconnect}
|
||||||
|
socketFlags = {ServerFlags.TcpNoDelay, ServerFlags.ReuseAddr}
|
||||||
|
clientFlags = {NoVerifyHost, NoVerifyServerName}
|
||||||
|
secureKey = TLSPrivateKey.init(SecureKey)
|
||||||
|
secureCert = TLSCertificate.init(SecureCert)
|
||||||
|
|
||||||
let secureKey = TLSPrivateKey.init(SecureKey)
|
|
||||||
let secureCert = TLSCertificate.init(SecureCert)
|
|
||||||
|
|
||||||
suite "Test websocket TLS handshake":
|
suite "Test websocket TLS handshake":
|
||||||
teardown:
|
teardown:
|
||||||
|
@ -31,10 +41,7 @@ suite "Test websocket TLS handshake":
|
||||||
let request = r.get()
|
let request = r.get()
|
||||||
check request.uri.path == "/wss"
|
check request.uri.path == "/wss"
|
||||||
expect WSProtoMismatchError:
|
expect WSProtoMismatchError:
|
||||||
var ws = await createServer(request, "proto")
|
discard await createServer(request, "proto")
|
||||||
check ws.readyState == ReadyState.Closed
|
|
||||||
|
|
||||||
return await request.respond(Http200, "Connection established")
|
|
||||||
|
|
||||||
let res = SecureHttpServerRef.new(
|
let res = SecureHttpServerRef.new(
|
||||||
address, cb,
|
address, cb,
|
||||||
|
@ -62,10 +69,7 @@ suite "Test websocket TLS handshake":
|
||||||
let request = r.get()
|
let request = r.get()
|
||||||
check request.uri.path == "/wss"
|
check request.uri.path == "/wss"
|
||||||
expect WSVersionError:
|
expect WSVersionError:
|
||||||
var ws = await createServer(request, "proto")
|
discard await createServer(request, "proto")
|
||||||
check ws.readyState == ReadyState.Closed
|
|
||||||
|
|
||||||
return await request.respond(Http200, "Connection established")
|
|
||||||
|
|
||||||
let res = SecureHttpServerRef.new(
|
let res = SecureHttpServerRef.new(
|
||||||
address, cb,
|
address, cb,
|
||||||
|
@ -91,9 +95,12 @@ suite "Test websocket TLS handshake":
|
||||||
check r.isOk()
|
check r.isOk()
|
||||||
let request = r.get()
|
let request = r.get()
|
||||||
check request.uri.path == "/wss"
|
check request.uri.path == "/wss"
|
||||||
check request.headers.getString("Connection").toUpperAscii() == "Upgrade".toUpperAscii()
|
check request.headers.getString("Connection").toUpperAscii() ==
|
||||||
check request.headers.getString("Upgrade").toUpperAscii() == "websocket".toUpperAscii()
|
"Upgrade".toUpperAscii()
|
||||||
check request.headers.getString("Cache-Control").toUpperAscii() == "no-cache".toUpperAscii()
|
check request.headers.getString("Upgrade").toUpperAscii() ==
|
||||||
|
"websocket".toUpperAscii()
|
||||||
|
check request.headers.getString("Cache-Control").toUpperAscii() ==
|
||||||
|
"no-cache".toUpperAscii()
|
||||||
check request.headers.getString("Sec-WebSocket-Version") == $WSDefaultVersion
|
check request.headers.getString("Sec-WebSocket-Version") == $WSDefaultVersion
|
||||||
|
|
||||||
check request.headers.contains("Sec-WebSocket-Key")
|
check request.headers.contains("Sec-WebSocket-Key")
|
||||||
|
@ -133,41 +140,7 @@ suite "Test websocket TLS transmission":
|
||||||
let ws = await createServer(request, "proto")
|
let ws = await createServer(request, "proto")
|
||||||
let servRes = await ws.recv()
|
let servRes = await ws.recv()
|
||||||
check string.fromBytes(servRes) == testString
|
check string.fromBytes(servRes) == testString
|
||||||
await ws.close()
|
await waitForClose(ws)
|
||||||
return dumbResponse()
|
|
||||||
|
|
||||||
let res = SecureHttpServerRef.new(
|
|
||||||
address, cb,
|
|
||||||
serverFlags = serverFlags,
|
|
||||||
socketFlags = socketFlags,
|
|
||||||
tlsPrivateKey = secureKey,
|
|
||||||
tlsCertificate = secureCert)
|
|
||||||
|
|
||||||
server = res.get()
|
|
||||||
server.start()
|
|
||||||
|
|
||||||
let wsClient = await WebSocket.tlsConnect(
|
|
||||||
"127.0.0.1",
|
|
||||||
Port(8888),
|
|
||||||
path = "/wss",
|
|
||||||
protocols = @["proto"],
|
|
||||||
clientFlags)
|
|
||||||
|
|
||||||
await wsClient.send(testString)
|
|
||||||
await wsClient.close()
|
|
||||||
|
|
||||||
test "Client - test reading simple frame":
|
|
||||||
let testString = "Hello!"
|
|
||||||
proc cb(r: RequestFence): Future[HttpResponseRef] {.async.} =
|
|
||||||
if r.isErr():
|
|
||||||
return dumbResponse()
|
|
||||||
|
|
||||||
let request = r.get()
|
|
||||||
check request.uri.path == "/wss"
|
|
||||||
let ws = await createServer(request, "proto")
|
|
||||||
let servRes = await ws.recv()
|
|
||||||
check string.fromBytes(servRes) == testString
|
|
||||||
await ws.close()
|
|
||||||
return dumbResponse()
|
return dumbResponse()
|
||||||
|
|
||||||
let res = SecureHttpServerRef.new(
|
let res = SecureHttpServerRef.new(
|
||||||
|
@ -222,4 +195,4 @@ suite "Test websocket TLS transmission":
|
||||||
|
|
||||||
var clientRes = await wsClient.recv()
|
var clientRes = await wsClient.recv()
|
||||||
check string.fromBytes(clientRes) == testString
|
check string.fromBytes(clientRes) == testString
|
||||||
await wsClient.close()
|
await waitForClose(wsClient)
|
||||||
|
|
|
@ -1,15 +1,29 @@
|
||||||
import std/strutils,httputils
|
import std/[strutils, random], httputils
|
||||||
|
|
||||||
import pkg/[asynctest,
|
import pkg/[asynctest,
|
||||||
chronos,
|
chronos,
|
||||||
chronos/apps/http/httpserver,
|
chronos/apps/http/httpserver,
|
||||||
|
chronicles,
|
||||||
stew/byteutils]
|
stew/byteutils]
|
||||||
|
|
||||||
import ../ws/[ws, stream]
|
import ../ws/[ws, stream]
|
||||||
|
|
||||||
|
include ../ws/ws
|
||||||
|
|
||||||
var server: HttpServerRef
|
var server: HttpServerRef
|
||||||
let address = initTAddress("127.0.0.1:8888")
|
let address = initTAddress("127.0.0.1:8888")
|
||||||
|
|
||||||
|
proc rndStr*(size: int): string =
|
||||||
|
for _ in .. size:
|
||||||
|
add(result, char(rand(int('A') .. int('z'))))
|
||||||
|
|
||||||
|
proc waitForClose(ws: WebSocket) {.async.} =
|
||||||
|
try:
|
||||||
|
while ws.readystate != ReadyState.Closed:
|
||||||
|
discard await ws.recv()
|
||||||
|
except CatchableError:
|
||||||
|
debug "Closing websocket"
|
||||||
|
|
||||||
suite "Test handshake":
|
suite "Test handshake":
|
||||||
teardown:
|
teardown:
|
||||||
await server.stop()
|
await server.stop()
|
||||||
|
@ -18,13 +32,12 @@ suite "Test handshake":
|
||||||
test "Test for incorrect protocol":
|
test "Test for incorrect protocol":
|
||||||
proc cb(r: RequestFence): Future[HttpResponseRef] {.async.} =
|
proc cb(r: RequestFence): Future[HttpResponseRef] {.async.} =
|
||||||
if r.isErr():
|
if r.isErr():
|
||||||
return
|
return dumbResponse()
|
||||||
|
|
||||||
let request = r.get()
|
let request = r.get()
|
||||||
check request.uri.path == "/ws"
|
check request.uri.path == "/ws"
|
||||||
expect WSProtoMismatchError:
|
expect WSProtoMismatchError:
|
||||||
var ws = await createServer(request, "proto")
|
discard await createServer(request, "proto")
|
||||||
check ws.readyState == ReadyState.Closed
|
|
||||||
|
|
||||||
let res = HttpServerRef.new(address, cb)
|
let res = HttpServerRef.new(address, cb)
|
||||||
server = res.get()
|
server = res.get()
|
||||||
|
@ -45,8 +58,7 @@ suite "Test handshake":
|
||||||
let request = r.get()
|
let request = r.get()
|
||||||
check request.uri.path == "/ws"
|
check request.uri.path == "/ws"
|
||||||
expect WSVersionError:
|
expect WSVersionError:
|
||||||
var ws = await createServer(request, "proto")
|
discard await createServer(request, "proto")
|
||||||
check ws.readyState == ReadyState.Closed
|
|
||||||
|
|
||||||
let res = HttpServerRef.new(address, cb)
|
let res = HttpServerRef.new(address, cb)
|
||||||
server = res.get()
|
server = res.get()
|
||||||
|
@ -63,17 +75,22 @@ suite "Test handshake":
|
||||||
test "Test for client headers":
|
test "Test for client headers":
|
||||||
proc cb(r: RequestFence): Future[HttpResponseRef] {.async.} =
|
proc cb(r: RequestFence): Future[HttpResponseRef] {.async.} =
|
||||||
if r.isErr():
|
if r.isErr():
|
||||||
return
|
return dumbResponse()
|
||||||
|
|
||||||
let request = r.get()
|
let request = r.get()
|
||||||
check request.uri.path == "/ws"
|
check request.uri.path == "/ws"
|
||||||
check request.headers.getString("Connection").toUpperAscii() == "Upgrade".toUpperAscii()
|
check request.headers.getString("Connection").toUpperAscii() ==
|
||||||
check request.headers.getString("Upgrade").toUpperAscii() == "websocket".toUpperAscii()
|
"Upgrade".toUpperAscii()
|
||||||
check request.headers.getString("Cache-Control").toUpperAscii() == "no-cache".toUpperAscii()
|
check request.headers.getString("Upgrade").toUpperAscii() ==
|
||||||
|
"websocket".toUpperAscii()
|
||||||
|
check request.headers.getString("Cache-Control").toUpperAscii() ==
|
||||||
|
"no-cache".toUpperAscii()
|
||||||
check request.headers.getString("Sec-WebSocket-Version") == $WSDefaultVersion
|
check request.headers.getString("Sec-WebSocket-Version") == $WSDefaultVersion
|
||||||
|
|
||||||
check request.headers.contains("Sec-WebSocket-Key")
|
check request.headers.contains("Sec-WebSocket-Key")
|
||||||
|
|
||||||
|
discard await request.respond(Http200, "Connection established")
|
||||||
|
|
||||||
let res = HttpServerRef.new(address, cb)
|
let res = HttpServerRef.new(address, cb)
|
||||||
server = res.get()
|
server = res.get()
|
||||||
server.start()
|
server.start()
|
||||||
|
@ -88,7 +105,7 @@ suite "Test handshake":
|
||||||
test "Test for incorrect scheme":
|
test "Test for incorrect scheme":
|
||||||
proc cb(r: RequestFence): Future[HttpResponseRef] {.async.} =
|
proc cb(r: RequestFence): Future[HttpResponseRef] {.async.} =
|
||||||
if r.isErr():
|
if r.isErr():
|
||||||
return
|
return dumbResponse()
|
||||||
|
|
||||||
let request = r.get()
|
let request = r.get()
|
||||||
check request.uri.path == "/ws"
|
check request.uri.path == "/ws"
|
||||||
|
@ -109,23 +126,52 @@ suite "Test handshake":
|
||||||
parseUri(uri),
|
parseUri(uri),
|
||||||
protocols = @["proto"])
|
protocols = @["proto"])
|
||||||
|
|
||||||
|
test "AsyncStream leaks test":
|
||||||
|
check:
|
||||||
|
getTracker("async.stream.reader").isLeaked() == false
|
||||||
|
getTracker("async.stream.writer").isLeaked() == false
|
||||||
|
getTracker("stream.server").isLeaked() == false
|
||||||
|
getTracker("stream.transport").isLeaked() == false
|
||||||
|
|
||||||
suite "Test transmission":
|
suite "Test transmission":
|
||||||
teardown:
|
teardown:
|
||||||
await server.closeWait()
|
await server.closeWait()
|
||||||
|
|
||||||
|
test "Send text message message with payload of length 65535":
|
||||||
|
let testString = rndStr(65535)
|
||||||
|
proc cb(r: RequestFence): Future[HttpResponseRef] {.async.} =
|
||||||
|
if r.isErr():
|
||||||
|
return dumbResponse()
|
||||||
|
let request = r.get()
|
||||||
|
check request.uri.path == "/ws"
|
||||||
|
let ws = await createServer(request, "proto")
|
||||||
|
let servRes = await ws.recv()
|
||||||
|
check string.fromBytes(servRes) == testString
|
||||||
|
|
||||||
|
let res = HttpServerRef.new(
|
||||||
|
address, cb)
|
||||||
|
server = res.get()
|
||||||
|
server.start()
|
||||||
|
|
||||||
|
let wsClient = await WebSocket.connect(
|
||||||
|
"127.0.0.1",
|
||||||
|
Port(8888),
|
||||||
|
path = "/ws",
|
||||||
|
protocols = @["proto"])
|
||||||
|
await wsClient.send(testString)
|
||||||
|
await wsClient.close()
|
||||||
|
|
||||||
test "Server - test reading simple frame":
|
test "Server - test reading simple frame":
|
||||||
let testString = "Hello!"
|
let testString = "Hello!"
|
||||||
proc cb(r: RequestFence): Future[HttpResponseRef] {.async.} =
|
proc cb(r: RequestFence): Future[HttpResponseRef] {.async.} =
|
||||||
if r.isErr():
|
if r.isErr():
|
||||||
return
|
return dumbResponse()
|
||||||
|
|
||||||
let request = r.get()
|
let request = r.get()
|
||||||
check request.uri.path == "/ws"
|
check request.uri.path == "/ws"
|
||||||
let ws = await createServer(request, "proto")
|
let ws = await createServer(request, "proto")
|
||||||
let servRes = await ws.recv()
|
let servRes = await ws.recv()
|
||||||
|
|
||||||
check string.fromBytes(servRes) == testString
|
check string.fromBytes(servRes) == testString
|
||||||
await ws.close()
|
await waitForClose(ws)
|
||||||
|
|
||||||
let res = HttpServerRef.new(address, cb)
|
let res = HttpServerRef.new(address, cb)
|
||||||
server = res.get()
|
server = res.get()
|
||||||
|
@ -143,7 +189,7 @@ suite "Test transmission":
|
||||||
let testString = "Hello!"
|
let testString = "Hello!"
|
||||||
proc cb(r: RequestFence): Future[HttpResponseRef] {.async.} =
|
proc cb(r: RequestFence): Future[HttpResponseRef] {.async.} =
|
||||||
if r.isErr():
|
if r.isErr():
|
||||||
return
|
return dumbResponse()
|
||||||
|
|
||||||
let request = r.get()
|
let request = r.get()
|
||||||
check request.uri.path == "/ws"
|
check request.uri.path == "/ws"
|
||||||
|
@ -162,18 +208,89 @@ suite "Test transmission":
|
||||||
protocols = @["proto"])
|
protocols = @["proto"])
|
||||||
|
|
||||||
var clientRes = await wsClient.recv()
|
var clientRes = await wsClient.recv()
|
||||||
await wsClient.close()
|
|
||||||
check string.fromBytes(clientRes) == testString
|
check string.fromBytes(clientRes) == testString
|
||||||
|
await waitForClose(wsClient)
|
||||||
|
|
||||||
|
test "AsyncStream leaks test":
|
||||||
|
check:
|
||||||
|
getTracker("async.stream.reader").isLeaked() == false
|
||||||
|
getTracker("async.stream.writer").isLeaked() == false
|
||||||
|
getTracker("stream.server").isLeaked() == false
|
||||||
|
getTracker("stream.transport").isLeaked() == false
|
||||||
|
|
||||||
suite "Test ping-pong":
|
suite "Test ping-pong":
|
||||||
teardown:
|
teardown:
|
||||||
await server.closeWait()
|
await server.closeWait()
|
||||||
|
test "Send text Message fragmented into 2 fragments, one ping with payload in-between":
|
||||||
|
var ping, pong = false
|
||||||
|
let testString = "1234567890"
|
||||||
|
let msg = toBytes(testString)
|
||||||
|
let maxFrameSize = 5
|
||||||
|
|
||||||
|
proc cb(r: RequestFence): Future[HttpResponseRef] {.async.} =
|
||||||
|
if r.isErr():
|
||||||
|
return dumbResponse()
|
||||||
|
let request = r.get()
|
||||||
|
check request.uri.path == "/ws"
|
||||||
|
let ws = await createServer(
|
||||||
|
request,
|
||||||
|
"proto",
|
||||||
|
onPing = proc() =
|
||||||
|
ping = true
|
||||||
|
)
|
||||||
|
|
||||||
|
let respData = await ws.recv()
|
||||||
|
check string.fromBytes(respData) == testString
|
||||||
|
await waitForClose(ws)
|
||||||
|
|
||||||
|
let res = HttpServerRef.new(
|
||||||
|
address, cb)
|
||||||
|
server = res.get()
|
||||||
|
server.start()
|
||||||
|
|
||||||
|
let wsClient = await WebSocket.connect(
|
||||||
|
"127.0.0.1",
|
||||||
|
Port(8888),
|
||||||
|
path = "/ws",
|
||||||
|
protocols = @["proto"],
|
||||||
|
frameSize = maxFrameSize,
|
||||||
|
onPong = proc() =
|
||||||
|
pong = true
|
||||||
|
)
|
||||||
|
|
||||||
|
let maskKey = genMaskKey(newRng())
|
||||||
|
let encframe = encodeFrame(Frame(
|
||||||
|
fin: false,
|
||||||
|
rsv1: false,
|
||||||
|
rsv2: false,
|
||||||
|
rsv3: false,
|
||||||
|
opcode: Opcode.Text,
|
||||||
|
mask: true,
|
||||||
|
data: msg[0..4],
|
||||||
|
maskKey: maskKey))
|
||||||
|
|
||||||
|
await wsClient.stream.writer.write(encframe)
|
||||||
|
await wsClient.ping()
|
||||||
|
let encframe1 = encodeFrame(Frame(
|
||||||
|
fin: true,
|
||||||
|
rsv1: false,
|
||||||
|
rsv2: false,
|
||||||
|
rsv3: false,
|
||||||
|
opcode: Opcode.Cont,
|
||||||
|
mask: true,
|
||||||
|
data: msg[5..9],
|
||||||
|
maskKey: maskKey))
|
||||||
|
|
||||||
|
await wsClient.stream.writer.write(encframe1)
|
||||||
|
await wsClient.close()
|
||||||
|
check:
|
||||||
|
ping
|
||||||
|
pong
|
||||||
test "Server - test ping-pong control messages":
|
test "Server - test ping-pong control messages":
|
||||||
var ping, pong = false
|
var ping, pong = false
|
||||||
proc cb(r: RequestFence): Future[HttpResponseRef] {.async.} =
|
proc cb(r: RequestFence): Future[HttpResponseRef] {.async.} =
|
||||||
if r.isErr():
|
if r.isErr():
|
||||||
return
|
return dumbResponse()
|
||||||
|
|
||||||
let request = r.get()
|
let request = r.get()
|
||||||
check request.uri.path == "/ws"
|
check request.uri.path == "/ws"
|
||||||
|
@ -200,7 +317,7 @@ suite "Test ping-pong":
|
||||||
ping = true
|
ping = true
|
||||||
)
|
)
|
||||||
|
|
||||||
discard await wsClient.recv()
|
await waitForClose(wsClient)
|
||||||
check:
|
check:
|
||||||
ping
|
ping
|
||||||
pong
|
pong
|
||||||
|
@ -209,7 +326,7 @@ suite "Test ping-pong":
|
||||||
var ping, pong = false
|
var ping, pong = false
|
||||||
proc cb(r: RequestFence): Future[HttpResponseRef] {.async.} =
|
proc cb(r: RequestFence): Future[HttpResponseRef] {.async.} =
|
||||||
if r.isErr():
|
if r.isErr():
|
||||||
return
|
return dumbResponse()
|
||||||
|
|
||||||
let request = r.get()
|
let request = r.get()
|
||||||
check request.uri.path == "/ws"
|
check request.uri.path == "/ws"
|
||||||
|
@ -219,10 +336,10 @@ suite "Test ping-pong":
|
||||||
onPing = proc() =
|
onPing = proc() =
|
||||||
ping = true
|
ping = true
|
||||||
)
|
)
|
||||||
|
await waitForClose(ws)
|
||||||
discard await ws.recv()
|
check:
|
||||||
await ws.close()
|
ping
|
||||||
|
pong
|
||||||
let res = HttpServerRef.new(address, cb)
|
let res = HttpServerRef.new(address, cb)
|
||||||
server = res.get()
|
server = res.get()
|
||||||
server.start()
|
server.start()
|
||||||
|
@ -238,9 +355,13 @@ suite "Test ping-pong":
|
||||||
|
|
||||||
await wsClient.ping()
|
await wsClient.ping()
|
||||||
await wsClient.close()
|
await wsClient.close()
|
||||||
|
|
||||||
|
test "AsyncStream leaks test":
|
||||||
check:
|
check:
|
||||||
ping
|
getTracker("async.stream.reader").isLeaked() == false
|
||||||
pong
|
getTracker("async.stream.writer").isLeaked() == false
|
||||||
|
getTracker("stream.server").isLeaked() == false
|
||||||
|
getTracker("stream.transport").isLeaked() == false
|
||||||
|
|
||||||
suite "Test framing":
|
suite "Test framing":
|
||||||
teardown:
|
teardown:
|
||||||
|
@ -268,8 +389,7 @@ suite "Test framing":
|
||||||
let read2 = await ws.stream.reader.readOnce(addr data2[0], data2.len)
|
let read2 = await ws.stream.reader.readOnce(addr data2[0], data2.len)
|
||||||
check read2 == 5
|
check read2 == 5
|
||||||
|
|
||||||
await ws.close()
|
await waitForClose(ws)
|
||||||
return dumbResponse()
|
|
||||||
|
|
||||||
let res = HttpServerRef.new(address, cb)
|
let res = HttpServerRef.new(address, cb)
|
||||||
server = res.get()
|
server = res.get()
|
||||||
|
@ -296,7 +416,6 @@ suite "Test framing":
|
||||||
let ws = await createServer(request, "proto")
|
let ws = await createServer(request, "proto")
|
||||||
await ws.send(testString)
|
await ws.send(testString)
|
||||||
await ws.close()
|
await ws.close()
|
||||||
return dumbResponse()
|
|
||||||
|
|
||||||
let res = HttpServerRef.new(address, cb)
|
let res = HttpServerRef.new(address, cb)
|
||||||
server = res.get()
|
server = res.get()
|
||||||
|
@ -310,8 +429,14 @@ suite "Test framing":
|
||||||
|
|
||||||
expect WSMaxMessageSizeError:
|
expect WSMaxMessageSizeError:
|
||||||
discard await wsClient.recv(5)
|
discard await wsClient.recv(5)
|
||||||
|
await waitForClose(wsClient)
|
||||||
|
|
||||||
await wsClient.close()
|
test "AsyncStream leaks test":
|
||||||
|
check:
|
||||||
|
getTracker("async.stream.reader").isLeaked() == false
|
||||||
|
getTracker("async.stream.writer").isLeaked() == false
|
||||||
|
getTracker("stream.server").isLeaked() == false
|
||||||
|
getTracker("stream.transport").isLeaked() == false
|
||||||
|
|
||||||
suite "Test Closing":
|
suite "Test Closing":
|
||||||
teardown:
|
teardown:
|
||||||
|
@ -326,7 +451,6 @@ suite "Test Closing":
|
||||||
check request.uri.path == "/ws"
|
check request.uri.path == "/ws"
|
||||||
let ws = await createServer(request, "proto")
|
let ws = await createServer(request, "proto")
|
||||||
await ws.close()
|
await ws.close()
|
||||||
return dumbResponse()
|
|
||||||
|
|
||||||
let res = HttpServerRef.new(address, cb)
|
let res = HttpServerRef.new(address, cb)
|
||||||
server = res.get()
|
server = res.get()
|
||||||
|
@ -338,7 +462,7 @@ suite "Test Closing":
|
||||||
path = "/ws",
|
path = "/ws",
|
||||||
protocols = @["proto"])
|
protocols = @["proto"])
|
||||||
|
|
||||||
discard await wsClient.recv()
|
await waitForClose(wsClient)
|
||||||
check wsClient.readyState == ReadyState.Closed
|
check wsClient.readyState == ReadyState.Closed
|
||||||
|
|
||||||
test "Server closing with status":
|
test "Server closing with status":
|
||||||
|
@ -348,8 +472,8 @@ suite "Test Closing":
|
||||||
|
|
||||||
let request = r.get()
|
let request = r.get()
|
||||||
check request.uri.path == "/ws"
|
check request.uri.path == "/ws"
|
||||||
proc closeServer(status: Status, reason: string): CloseResult
|
proc closeServer(status: Status, reason: string): CloseResult{.gcsafe,
|
||||||
{.gcsafe, raises: [Defect].} =
|
raises: [Defect].} =
|
||||||
try:
|
try:
|
||||||
check status == Status.TooLarge
|
check status == Status.TooLarge
|
||||||
check reason == "Message too big!"
|
check reason == "Message too big!"
|
||||||
|
@ -364,14 +488,13 @@ suite "Test Closing":
|
||||||
onClose = closeServer)
|
onClose = closeServer)
|
||||||
|
|
||||||
await ws.close()
|
await ws.close()
|
||||||
return dumbResponse()
|
|
||||||
|
|
||||||
let res = HttpServerRef.new(address, cb)
|
let res = HttpServerRef.new(address, cb)
|
||||||
server = res.get()
|
server = res.get()
|
||||||
server.start()
|
server.start()
|
||||||
|
|
||||||
proc clientClose(status: Status, reason: string): CloseResult
|
proc clientClose(status: Status, reason: string): CloseResult {.gcsafe,
|
||||||
{.gcsafe, raises: [Defect].} =
|
raises: [Defect].} =
|
||||||
try:
|
try:
|
||||||
check status == Status.Fulfilled
|
check status == Status.Fulfilled
|
||||||
return (Status.TooLarge, "Message too big!")
|
return (Status.TooLarge, "Message too big!")
|
||||||
|
@ -385,7 +508,7 @@ suite "Test Closing":
|
||||||
protocols = @["proto"],
|
protocols = @["proto"],
|
||||||
onClose = clientClose)
|
onClose = clientClose)
|
||||||
|
|
||||||
discard await wsClient.recv()
|
await waitForClose(wsClient)
|
||||||
check wsClient.readyState == ReadyState.Closed
|
check wsClient.readyState == ReadyState.Closed
|
||||||
|
|
||||||
test "Client closing":
|
test "Client closing":
|
||||||
|
@ -396,9 +519,7 @@ suite "Test Closing":
|
||||||
let request = r.get()
|
let request = r.get()
|
||||||
check request.uri.path == "/ws"
|
check request.uri.path == "/ws"
|
||||||
let ws = await createServer(request, "proto")
|
let ws = await createServer(request, "proto")
|
||||||
discard await ws.recv()
|
await waitForClose(ws)
|
||||||
await ws.close()
|
|
||||||
return dumbResponse()
|
|
||||||
|
|
||||||
let res = HttpServerRef.new(address, cb)
|
let res = HttpServerRef.new(address, cb)
|
||||||
server = res.get()
|
server = res.get()
|
||||||
|
@ -418,8 +539,8 @@ suite "Test Closing":
|
||||||
|
|
||||||
let request = r.get()
|
let request = r.get()
|
||||||
check request.uri.path == "/ws"
|
check request.uri.path == "/ws"
|
||||||
proc closeServer(status: Status, reason: string): CloseResult
|
proc closeServer(status: Status, reason: string): CloseResult{.gcsafe,
|
||||||
{.gcsafe, raises: [Defect].} =
|
raises: [Defect].} =
|
||||||
try:
|
try:
|
||||||
check status == Status.Fulfilled
|
check status == Status.Fulfilled
|
||||||
return (Status.TooLarge, "Message too big!")
|
return (Status.TooLarge, "Message too big!")
|
||||||
|
@ -430,16 +551,14 @@ suite "Test Closing":
|
||||||
request,
|
request,
|
||||||
"proto",
|
"proto",
|
||||||
onClose = closeServer)
|
onClose = closeServer)
|
||||||
discard await ws.recv()
|
await waitForClose(ws)
|
||||||
await ws.close()
|
|
||||||
return dumbResponse()
|
|
||||||
|
|
||||||
let res = HttpServerRef.new(address, cb)
|
let res = HttpServerRef.new(address, cb)
|
||||||
server = res.get()
|
server = res.get()
|
||||||
server.start()
|
server.start()
|
||||||
|
|
||||||
proc clientClose(status: Status, reason: string): CloseResult
|
proc clientClose(status: Status, reason: string): CloseResult {.gcsafe,
|
||||||
{.gcsafe, raises: [Defect].} =
|
raises: [Defect].} =
|
||||||
try:
|
try:
|
||||||
check status == Status.TooLarge
|
check status == Status.TooLarge
|
||||||
check reason == "Message too big!"
|
check reason == "Message too big!"
|
||||||
|
@ -456,3 +575,253 @@ suite "Test Closing":
|
||||||
|
|
||||||
await wsClient.close()
|
await wsClient.close()
|
||||||
check wsClient.readyState == ReadyState.Closed
|
check wsClient.readyState == ReadyState.Closed
|
||||||
|
|
||||||
|
test "Server closing with valid close code 3999":
|
||||||
|
proc cb(r: RequestFence): Future[HttpResponseRef] {.async, gcsafe.} =
|
||||||
|
if r.isErr():
|
||||||
|
return dumbResponse()
|
||||||
|
let request = r.get()
|
||||||
|
check request.uri.path == "/ws"
|
||||||
|
let ws = await createServer(
|
||||||
|
request,
|
||||||
|
"proto")
|
||||||
|
|
||||||
|
await ws.close(code = Status.ReservedCode)
|
||||||
|
|
||||||
|
let res = HttpServerRef.new(
|
||||||
|
address, cb)
|
||||||
|
server = res.get()
|
||||||
|
server.start()
|
||||||
|
|
||||||
|
proc closeClient(status: Status, reason: string): CloseResult{.gcsafe,
|
||||||
|
raises: [Defect].} =
|
||||||
|
try:
|
||||||
|
check status == Status.ReservedCode
|
||||||
|
return (Status.ReservedCode, "Reserved Status")
|
||||||
|
except Exception as exc:
|
||||||
|
raise newException(Defect, exc.msg)
|
||||||
|
|
||||||
|
let wsClient = await WebSocket.connect(
|
||||||
|
"127.0.0.1",
|
||||||
|
Port(8888),
|
||||||
|
path = "/ws",
|
||||||
|
protocols = @["proto"],
|
||||||
|
onClose = closeClient)
|
||||||
|
|
||||||
|
await waitForClose(wsClient)
|
||||||
|
|
||||||
|
test "Client closing with valid close code 3999":
|
||||||
|
proc cb(r: RequestFence): Future[HttpResponseRef] {.async, gcsafe.} =
|
||||||
|
if r.isErr():
|
||||||
|
return dumbResponse()
|
||||||
|
let request = r.get()
|
||||||
|
check request.uri.path == "/ws"
|
||||||
|
|
||||||
|
proc closeServer(status: Status, reason: string): CloseResult{.gcsafe,
|
||||||
|
raises: [Defect].} =
|
||||||
|
try:
|
||||||
|
check status == Status.ReservedCode
|
||||||
|
return (Status.ReservedCode, "Reserved Status")
|
||||||
|
except Exception as exc:
|
||||||
|
raise newException(Defect, exc.msg)
|
||||||
|
|
||||||
|
let ws = await createServer(
|
||||||
|
request,
|
||||||
|
"proto",
|
||||||
|
onClose = closeServer)
|
||||||
|
|
||||||
|
await waitForClose(ws)
|
||||||
|
return dumbResponse()
|
||||||
|
|
||||||
|
let res = HttpServerRef.new(
|
||||||
|
address, cb)
|
||||||
|
server = res.get()
|
||||||
|
server.start()
|
||||||
|
|
||||||
|
let wsClient = await WebSocket.connect(
|
||||||
|
"127.0.0.1",
|
||||||
|
Port(8888),
|
||||||
|
path = "/ws",
|
||||||
|
protocols = @["proto"])
|
||||||
|
await wsClient.close(code = Status.ReservedCode)
|
||||||
|
|
||||||
|
test "Server closing with Payload of length 2":
|
||||||
|
proc cb(r: RequestFence): Future[HttpResponseRef] {.async.} =
|
||||||
|
if r.isErr():
|
||||||
|
return dumbResponse()
|
||||||
|
let request = r.get()
|
||||||
|
check request.uri.path == "/ws"
|
||||||
|
let ws = await createServer(request, "proto")
|
||||||
|
# Close with payload of length 2
|
||||||
|
await ws.close(reason = "HH")
|
||||||
|
|
||||||
|
let res = HttpServerRef.new(
|
||||||
|
address, cb)
|
||||||
|
server = res.get()
|
||||||
|
server.start()
|
||||||
|
|
||||||
|
let wsClient = await WebSocket.connect(
|
||||||
|
"127.0.0.1",
|
||||||
|
Port(8888),
|
||||||
|
path = "/ws",
|
||||||
|
protocols = @["proto"])
|
||||||
|
await waitForClose(wsClient)
|
||||||
|
|
||||||
|
test "Client closing with Payload of length 2":
|
||||||
|
proc cb(r: RequestFence): Future[HttpResponseRef] {.async.} =
|
||||||
|
if r.isErr():
|
||||||
|
return dumbResponse()
|
||||||
|
let request = r.get()
|
||||||
|
check request.uri.path == "/ws"
|
||||||
|
let ws = await createServer(request, "proto")
|
||||||
|
await waitForClose(ws)
|
||||||
|
|
||||||
|
let res = HttpServerRef.new(
|
||||||
|
address, cb)
|
||||||
|
server = res.get()
|
||||||
|
server.start()
|
||||||
|
|
||||||
|
let wsClient = await WebSocket.connect(
|
||||||
|
"127.0.0.1",
|
||||||
|
Port(8888),
|
||||||
|
path = "/ws",
|
||||||
|
protocols = @["proto"])
|
||||||
|
# Close with payload of length 2
|
||||||
|
await wsClient.close(reason = "HH")
|
||||||
|
|
||||||
|
|
||||||
|
test "AsyncStream leaks test":
|
||||||
|
check:
|
||||||
|
getTracker("async.stream.reader").isLeaked() == false
|
||||||
|
getTracker("async.stream.writer").isLeaked() == false
|
||||||
|
getTracker("stream.server").isLeaked() == false
|
||||||
|
getTracker("stream.transport").isLeaked() == false
|
||||||
|
|
||||||
|
suite "Test Payload":
|
||||||
|
teardown:
|
||||||
|
await server.closeWait()
|
||||||
|
|
||||||
|
test "Test payload message length":
|
||||||
|
proc cb(r: RequestFence): Future[HttpResponseRef] {.async.} =
|
||||||
|
if r.isErr():
|
||||||
|
return dumbResponse()
|
||||||
|
let request = r.get()
|
||||||
|
check request.uri.path == "/ws"
|
||||||
|
let ws = await createServer(
|
||||||
|
request,
|
||||||
|
"proto")
|
||||||
|
|
||||||
|
expect WSPayloadTooLarge:
|
||||||
|
discard await ws.recv()
|
||||||
|
await waitForClose(ws)
|
||||||
|
|
||||||
|
let res = HttpServerRef.new(
|
||||||
|
address, cb)
|
||||||
|
server = res.get()
|
||||||
|
server.start()
|
||||||
|
|
||||||
|
let str = rndStr(126)
|
||||||
|
let wsClient = await WebSocket.connect(
|
||||||
|
"127.0.0.1",
|
||||||
|
Port(8888),
|
||||||
|
path = "/ws",
|
||||||
|
protocols = @["proto"])
|
||||||
|
|
||||||
|
await wsClient.send(toBytes(str), Opcode.Ping)
|
||||||
|
await wsClient.close()
|
||||||
|
|
||||||
|
test "Test single empty payload":
|
||||||
|
let emptyStr = ""
|
||||||
|
proc cb(r: RequestFence): Future[HttpResponseRef] {.async.} =
|
||||||
|
if r.isErr():
|
||||||
|
return dumbResponse()
|
||||||
|
let request = r.get()
|
||||||
|
check request.uri.path == "/ws"
|
||||||
|
let ws = await createServer(request, "proto")
|
||||||
|
let servRes = await ws.recv()
|
||||||
|
check string.fromBytes(servRes) == emptyStr
|
||||||
|
await waitForClose(ws)
|
||||||
|
|
||||||
|
let res = HttpServerRef.new(
|
||||||
|
address, cb)
|
||||||
|
server = res.get()
|
||||||
|
server.start()
|
||||||
|
|
||||||
|
let wsClient = await WebSocket.connect(
|
||||||
|
"127.0.0.1",
|
||||||
|
Port(8888),
|
||||||
|
path = "/ws",
|
||||||
|
protocols = @["proto"])
|
||||||
|
|
||||||
|
await wsClient.send(emptyStr)
|
||||||
|
await wsClient.close()
|
||||||
|
|
||||||
|
test "Test multiple empty payload":
|
||||||
|
let emptyStr = ""
|
||||||
|
proc cb(r: RequestFence): Future[HttpResponseRef] {.async.} =
|
||||||
|
if r.isErr():
|
||||||
|
return dumbResponse()
|
||||||
|
let request = r.get()
|
||||||
|
check request.uri.path == "/ws"
|
||||||
|
let ws = await createServer(request, "proto")
|
||||||
|
let servRes = await ws.recv()
|
||||||
|
check string.fromBytes(servRes) == emptyStr
|
||||||
|
await waitForClose(ws)
|
||||||
|
|
||||||
|
let res = HttpServerRef.new(
|
||||||
|
address, cb)
|
||||||
|
server = res.get()
|
||||||
|
server.start()
|
||||||
|
|
||||||
|
let wsClient = await WebSocket.connect(
|
||||||
|
"127.0.0.1",
|
||||||
|
Port(8888),
|
||||||
|
path = "/ws",
|
||||||
|
protocols = @["proto"])
|
||||||
|
|
||||||
|
for i in 0..3:
|
||||||
|
await wsClient.send(emptyStr)
|
||||||
|
await wsClient.close()
|
||||||
|
|
||||||
|
test "Send ping with small text payload":
|
||||||
|
let testData = toBytes("Hello, world!")
|
||||||
|
var ping, pong = false
|
||||||
|
proc process(r: RequestFence): Future[HttpResponseRef] {.async.} =
|
||||||
|
if r.isErr():
|
||||||
|
return dumbResponse()
|
||||||
|
let request = r.get()
|
||||||
|
check request.uri.path == "/ws"
|
||||||
|
let ws = await createServer(
|
||||||
|
request,
|
||||||
|
"proto",
|
||||||
|
onPing = proc() =
|
||||||
|
ping = true
|
||||||
|
)
|
||||||
|
|
||||||
|
await waitForClose(ws)
|
||||||
|
|
||||||
|
let res = HttpServerRef.new(
|
||||||
|
address, process)
|
||||||
|
server = res.get()
|
||||||
|
server.start()
|
||||||
|
|
||||||
|
let wsClient = await WebSocket.connect(
|
||||||
|
"127.0.0.1",
|
||||||
|
Port(8888),
|
||||||
|
path = "/ws",
|
||||||
|
protocols = @["proto"],
|
||||||
|
onPong = proc() =
|
||||||
|
pong = true
|
||||||
|
)
|
||||||
|
|
||||||
|
await wsClient.send(testData, Opcode.Ping)
|
||||||
|
await wsClient.close()
|
||||||
|
check:
|
||||||
|
ping
|
||||||
|
pong
|
||||||
|
test "AsyncStream leaks test":
|
||||||
|
check:
|
||||||
|
getTracker("async.stream.reader").isLeaked() == false
|
||||||
|
getTracker("async.stream.writer").isLeaked() == false
|
||||||
|
getTracker("stream.server").isLeaked() == false
|
||||||
|
getTracker("stream.transport").isLeaked() == false
|
||||||
|
|
|
@ -46,6 +46,10 @@ proc readHeaders*(rstream: AsyncStreamReader): Future[seq[byte]] {.async.} =
|
||||||
return buffer
|
return buffer
|
||||||
|
|
||||||
proc closeWait*(wsStream: AsyncStream) {.async.} =
|
proc closeWait*(wsStream: AsyncStream) {.async.} =
|
||||||
|
|
||||||
|
await allFutures(
|
||||||
|
wsStream.writer.closeWait(),
|
||||||
|
wsStream.reader.closeWait())
|
||||||
await allFutures(
|
await allFutures(
|
||||||
wsStream.writer.tsource.closeWait(),
|
wsStream.writer.tsource.closeWait(),
|
||||||
wsStream.reader.tsource.closeWait())
|
wsStream.reader.tsource.closeWait())
|
||||||
|
|
175
ws/ws.nim
175
ws/ws.nim
|
@ -18,7 +18,7 @@ import pkg/[chronos,
|
||||||
stew/base10,
|
stew/base10,
|
||||||
nimcrypto/sha]
|
nimcrypto/sha]
|
||||||
|
|
||||||
import ./random, ./stream
|
import ./utils, ./stream
|
||||||
|
|
||||||
#[
|
#[
|
||||||
+---------------------------------------------------------------+
|
+---------------------------------------------------------------+
|
||||||
|
@ -71,7 +71,13 @@ type
|
||||||
WSMaxMessageSizeError* = object of WebSocketError
|
WSMaxMessageSizeError* = object of WebSocketError
|
||||||
WSClosedError* = object of WebSocketError
|
WSClosedError* = object of WebSocketError
|
||||||
WSSendError* = object of WebSocketError
|
WSSendError* = object of WebSocketError
|
||||||
WSPayloadTooLarge = object of WebSocketError
|
WSPayloadTooLarge* = object of WebSocketError
|
||||||
|
WSReserverdOpcodeError* = object of WebSocketError
|
||||||
|
WSFragmentedControlFrameError* = object of WebSocketError
|
||||||
|
WSInvalidCloseCodeError* = object of WebSocketError
|
||||||
|
WSPayloadLengthError* = object of WebSocketError
|
||||||
|
WSInvalidOpcodeError* = object of WebSocketError
|
||||||
|
|
||||||
|
|
||||||
Base16Error* = object of CatchableError
|
Base16Error* = object of CatchableError
|
||||||
## Base16 specific exception type
|
## Base16 specific exception type
|
||||||
|
@ -111,7 +117,7 @@ type
|
||||||
TooLarge = 1009
|
TooLarge = 1009
|
||||||
NoExtensions = 1010
|
NoExtensions = 1010
|
||||||
UnexpectedError = 1011
|
UnexpectedError = 1011
|
||||||
TlsError # use by clients
|
ReservedCode = 3999 # use by clients
|
||||||
# 3000-3999 reserved for libs
|
# 3000-3999 reserved for libs
|
||||||
# 4000-4999 reserved for applications
|
# 4000-4999 reserved for applications
|
||||||
|
|
||||||
|
@ -209,7 +215,8 @@ proc handshake*(
|
||||||
wantProtocol & ")")
|
wantProtocol & ")")
|
||||||
|
|
||||||
let cKey = ws.key & WSGuid
|
let cKey = ws.key & WSGuid
|
||||||
let acceptKey = Base64Pad.encode(sha1.digest(cKey.toOpenArray(0, cKey.high)).data)
|
let acceptKey = Base64Pad.encode(sha1.digest(cKey.toOpenArray(0,
|
||||||
|
cKey.high)).data)
|
||||||
var headerData = [
|
var headerData = [
|
||||||
("Connection", "Upgrade"),
|
("Connection", "Upgrade"),
|
||||||
("Upgrade", "webSocket"),
|
("Upgrade", "webSocket"),
|
||||||
|
@ -222,7 +229,8 @@ proc handshake*(
|
||||||
try:
|
try:
|
||||||
discard await request.respond(httputils.Http101, "", headers)
|
discard await request.respond(httputils.Http101, "", headers)
|
||||||
except CatchableError as exc:
|
except CatchableError as exc:
|
||||||
raise newException(WSHandshakeError, "Failed to sent handshake response. Error: " & exc.msg)
|
raise newException(WSHandshakeError,
|
||||||
|
"Failed to sent handshake response. Error: " & exc.msg)
|
||||||
ws.readyState = ReadyState.Open
|
ws.readyState = ReadyState.Open
|
||||||
|
|
||||||
proc createServer*(
|
proc createServer*(
|
||||||
|
@ -330,6 +338,8 @@ proc send*(
|
||||||
maskKey = genMaskKey(ws.rng)
|
maskKey = genMaskKey(ws.rng)
|
||||||
|
|
||||||
if opcode notin {Opcode.Text, Opcode.Cont, Opcode.Binary}:
|
if opcode notin {Opcode.Text, Opcode.Cont, Opcode.Binary}:
|
||||||
|
if ws.readyState in {ReadyState.Closing} and opcode notin {Opcode.Close}:
|
||||||
|
return
|
||||||
await ws.stream.writer.write(encodeFrame(Frame(
|
await ws.stream.writer.write(encodeFrame(Frame(
|
||||||
fin: true,
|
fin: true,
|
||||||
rsv1: false,
|
rsv1: false,
|
||||||
|
@ -344,9 +354,9 @@ proc send*(
|
||||||
|
|
||||||
let maxSize = ws.frameSize
|
let maxSize = ws.frameSize
|
||||||
var i = 0
|
var i = 0
|
||||||
while i < data.len:
|
while ws.readyState notin {ReadyState.Closing}:
|
||||||
let len = min(data.len, (maxSize + i))
|
let len = min(data.len, (maxSize + i))
|
||||||
let inFrame = Frame(
|
let encFrame = encodeFrame(Frame(
|
||||||
fin: if (i + len >= data.len): true else: false,
|
fin: if (i + len >= data.len): true else: false,
|
||||||
rsv1: false,
|
rsv1: false,
|
||||||
rsv2: false,
|
rsv2: false,
|
||||||
|
@ -354,15 +364,21 @@ proc send*(
|
||||||
opcode: if i > 0: Opcode.Cont else: opcode, # fragments have to be `Continuation` frames
|
opcode: if i > 0: Opcode.Cont else: opcode, # fragments have to be `Continuation` frames
|
||||||
mask: ws.masked,
|
mask: ws.masked,
|
||||||
data: data[i ..< len],
|
data: data[i ..< len],
|
||||||
maskKey: maskKey)
|
maskKey: maskKey))
|
||||||
|
|
||||||
await ws.stream.writer.write(encodeFrame(inFrame))
|
await ws.stream.writer.write(encFrame)
|
||||||
i += len
|
i += len
|
||||||
|
|
||||||
|
if i >= data.len:
|
||||||
|
break
|
||||||
|
|
||||||
proc send*(ws: WebSocket, data: string): Future[void] =
|
proc send*(ws: WebSocket, data: string): Future[void] =
|
||||||
send(ws, toBytes(data), Opcode.Text)
|
send(ws, toBytes(data), Opcode.Text)
|
||||||
|
|
||||||
proc handleClose*(ws: WebSocket, frame: Frame) {.async.} =
|
proc handleClose*(ws: WebSocket, frame: Frame, payLoad: seq[byte] = @[]) {.async.} =
|
||||||
|
|
||||||
|
if ws.readyState notin {ReadyState.Open}:
|
||||||
|
return
|
||||||
logScope:
|
logScope:
|
||||||
fin = frame.fin
|
fin = frame.fin
|
||||||
masked = frame.mask
|
masked = frame.mask
|
||||||
|
@ -370,46 +386,46 @@ proc handleClose*(ws: WebSocket, frame: Frame) {.async.} =
|
||||||
serverState = ws.readyState
|
serverState = ws.readyState
|
||||||
|
|
||||||
debug "Handling close sequence"
|
debug "Handling close sequence"
|
||||||
if ws.readyState == ReadyState.Open or ws.readyState == ReadyState.Closing:
|
var
|
||||||
# Read control frame payload.
|
code = Status.Fulfilled
|
||||||
var data = newSeq[byte](frame.length)
|
reason = ""
|
||||||
if frame.length > 0:
|
|
||||||
# Read the data.
|
|
||||||
await ws.stream.reader.readExactly(addr data[0], int frame.length)
|
|
||||||
unmask(data.toOpenArray(0, data.high), frame.maskKey)
|
|
||||||
|
|
||||||
var code: Status
|
if payLoad.len == 1:
|
||||||
if data.len > 0:
|
raise newException(WSPayloadLengthError, "Invalid close frame with payload length 1!")
|
||||||
let ccode = uint16.fromBytesBE(data[0..<2]) # first two bytes are the status
|
elif payLoad.len > 1:
|
||||||
doAssert(ccode > 999, "No valid code in close message!")
|
# first two bytes are the status
|
||||||
|
let ccode = uint16.fromBytesBE(payLoad[0..<2])
|
||||||
|
if ccode <= 999 or ccode > 1015:
|
||||||
|
raise newException(WSInvalidCloseCodeError, "Invalid code in close message!")
|
||||||
|
try:
|
||||||
code = Status(ccode)
|
code = Status(ccode)
|
||||||
data = data[2..data.high]
|
except RangeError:
|
||||||
|
code = Status.Fulfilled
|
||||||
|
# remining payload bytes are reason for closing
|
||||||
|
reason = string.fromBytes(payLoad[2..payLoad.high])
|
||||||
|
|
||||||
|
var rcode: Status
|
||||||
|
if code in {Status.Fulfilled}:
|
||||||
|
rcode = Status.Fulfilled
|
||||||
|
|
||||||
var rcode = Status.Fulfilled
|
|
||||||
var reason = ""
|
|
||||||
if not isNil(ws.onClose):
|
if not isNil(ws.onClose):
|
||||||
try:
|
try:
|
||||||
(rcode, reason) = ws.onClose(code, string.fromBytes(data))
|
(rcode, reason) = ws.onClose(code, reason)
|
||||||
except CatchableError as exc:
|
except CatchableError as exc:
|
||||||
debug "Exception in Close callback, this is most likely a bug", exc = exc.msg
|
debug "Exception in Close callback, this is most likely a bug", exc = exc.msg
|
||||||
|
|
||||||
# don't respond to a terminated connection
|
# don't respond to a terminated connection
|
||||||
if ws.readyState != ReadyState.Closing:
|
if ws.readyState != ReadyState.Closing:
|
||||||
|
ws.readyState = ReadyState.Closing
|
||||||
await ws.send(prepareCloseBody(rcode, reason), Opcode.Close)
|
await ws.send(prepareCloseBody(rcode, reason), Opcode.Close)
|
||||||
|
|
||||||
await ws.stream.closeWait()
|
|
||||||
ws.readyState = ReadyState.Closed
|
ws.readyState = ReadyState.Closed
|
||||||
else:
|
await ws.stream.closeWait()
|
||||||
raiseAssert("Invalid state during close!")
|
|
||||||
|
|
||||||
proc handleControl*(ws: WebSocket, frame: Frame) {.async.} =
|
proc handleControl*(ws: WebSocket, frame: Frame, payLoad: seq[byte] = @[]) {.async.} =
|
||||||
## handle control frames
|
## handle control frames
|
||||||
##
|
##
|
||||||
|
|
||||||
if frame.length > 125:
|
|
||||||
raise newException(WSPayloadTooLarge,
|
|
||||||
"Control message payload is greater than 125 bytes!")
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Process control frame payload.
|
# Process control frame payload.
|
||||||
case frame.opcode:
|
case frame.opcode:
|
||||||
|
@ -421,7 +437,7 @@ proc handleControl*(ws: WebSocket, frame: Frame) {.async.} =
|
||||||
debug "Exception in Ping callback, this is most likelly a bug", exc = exc.msg
|
debug "Exception in Ping callback, this is most likelly a bug", exc = exc.msg
|
||||||
|
|
||||||
# send pong to remote
|
# send pong to remote
|
||||||
await ws.send(@[], Opcode.Pong)
|
await ws.send(payLoad, Opcode.Pong)
|
||||||
of Opcode.Pong:
|
of Opcode.Pong:
|
||||||
if not isNil(ws.onPong):
|
if not isNil(ws.onPong):
|
||||||
try:
|
try:
|
||||||
|
@ -429,9 +445,12 @@ proc handleControl*(ws: WebSocket, frame: Frame) {.async.} =
|
||||||
except CatchableError as exc:
|
except CatchableError as exc:
|
||||||
debug "Exception in Pong callback, this is most likelly a bug", exc = exc.msg
|
debug "Exception in Pong callback, this is most likelly a bug", exc = exc.msg
|
||||||
of Opcode.Close:
|
of Opcode.Close:
|
||||||
await ws.handleClose(frame)
|
await ws.handleClose(frame, payLoad)
|
||||||
else:
|
else:
|
||||||
raiseAssert("Invalid control opcode")
|
raise newException(WSInvalidOpcodeError, "Invalid control opcode!")
|
||||||
|
except WebSocketError as exc:
|
||||||
|
debug "Handled websocket exception", exc = exc.msg
|
||||||
|
raise exc
|
||||||
except CatchableError as exc:
|
except CatchableError as exc:
|
||||||
trace "Exception handling control messages", exc = exc.msg
|
trace "Exception handling control messages", exc = exc.msg
|
||||||
ws.readyState = ReadyState.Closed
|
ws.readyState = ReadyState.Closed
|
||||||
|
@ -467,8 +486,12 @@ proc readFrame*(ws: WebSocket): Future[Frame] {.async.} =
|
||||||
if opcode > ord(Opcode.high):
|
if opcode > ord(Opcode.high):
|
||||||
raise newException(WSOpcodeMismatchError, "Wrong opcode!")
|
raise newException(WSOpcodeMismatchError, "Wrong opcode!")
|
||||||
|
|
||||||
frame.opcode = (opcode).Opcode
|
let frameOpcode = (opcode).Opcode
|
||||||
|
if frameOpcode notin {Opcode.Text, Opcode.Cont, Opcode.Binary,
|
||||||
|
Opcode.Ping, Opcode.Pong, Opcode.Close}:
|
||||||
|
raise newException(WSReserverdOpcodeError, "Unknown opcode received!")
|
||||||
|
|
||||||
|
frame.opcode = frameOpcode
|
||||||
# If any of the rsv are set close the socket.
|
# If any of the rsv are set close the socket.
|
||||||
if frame.rsv1 or frame.rsv2 or frame.rsv3:
|
if frame.rsv1 or frame.rsv2 or frame.rsv3:
|
||||||
raise newException(WSRsvMismatchError, "WebSocket rsv mismatch")
|
raise newException(WSRsvMismatchError, "WebSocket rsv mismatch")
|
||||||
|
@ -508,12 +531,33 @@ proc readFrame*(ws: WebSocket): Future[Frame] {.async.} =
|
||||||
|
|
||||||
# return the current frame if it's not one of the control frames
|
# return the current frame if it's not one of the control frames
|
||||||
if frame.opcode notin {Opcode.Text, Opcode.Cont, Opcode.Binary}:
|
if frame.opcode notin {Opcode.Text, Opcode.Cont, Opcode.Binary}:
|
||||||
asyncSpawn ws.handleControl(frame) # process control frames
|
if not frame.fin:
|
||||||
|
raise newException(WSFragmentedControlFrameError, "Control frame cannot be fragmented!")
|
||||||
|
if frame.length > 125:
|
||||||
|
raise newException(WSPayloadTooLarge,
|
||||||
|
"Control message payload is greater than 125 bytes!")
|
||||||
|
var payLoad = newSeq[byte](frame.length)
|
||||||
|
if frame.length > 0:
|
||||||
|
# Read control frame payload.
|
||||||
|
await ws.stream.reader.readExactly(addr payLoad[0], frame.length.int)
|
||||||
|
unmask(payLoad.toOpenArray(0, payLoad.high), frame.maskKey)
|
||||||
|
await ws.handleControl(frame, payLoad) # process control frames# process control frames
|
||||||
continue
|
continue
|
||||||
|
|
||||||
debug "Decoded new frame", opcode = frame.opcode, len = frame.length, mask = frame.mask
|
debug "Decoded new frame", opcode = frame.opcode, len = frame.length,
|
||||||
|
mask = frame.mask
|
||||||
|
|
||||||
return frame
|
return frame
|
||||||
|
|
||||||
|
except WSReserverdOpcodeError as exc:
|
||||||
|
trace "Handled websocket opcode exception", exc = exc.msg
|
||||||
|
raise exc
|
||||||
|
except WSPayloadTooLarge as exc:
|
||||||
|
debug "Handled payload too large exception", exc = exc.msg
|
||||||
|
raise exc
|
||||||
|
except WebSocketError as exc:
|
||||||
|
debug "Handled websocket exception", exc = exc.msg
|
||||||
|
raise exc
|
||||||
except CatchableError as exc:
|
except CatchableError as exc:
|
||||||
debug "Exception reading frame, dropping socket", exc = exc.msg
|
debug "Exception reading frame, dropping socket", exc = exc.msg
|
||||||
ws.readyState = ReadyState.Closed
|
ws.readyState = ReadyState.Closed
|
||||||
|
@ -543,17 +587,32 @@ proc recv*(
|
||||||
while consumed < size:
|
while consumed < size:
|
||||||
# we might have to read more than
|
# we might have to read more than
|
||||||
# one frame to fill the buffer
|
# one frame to fill the buffer
|
||||||
if isNil(ws.frame):
|
|
||||||
ws.frame = await ws.readFrame()
|
|
||||||
|
|
||||||
# all has been consumed from the frame
|
# all has been consumed from the frame
|
||||||
# read the next frame
|
# read the next frame
|
||||||
if ws.frame.remainder() <= 0:
|
if isNil(ws.frame):
|
||||||
ws.frame = await ws.readFrame()
|
ws.frame = await ws.readFrame()
|
||||||
|
# This could happen if the connection is closed.
|
||||||
|
if isNil(ws.frame):
|
||||||
|
return consumed.int
|
||||||
|
if ws.frame.opcode == Opcode.Cont:
|
||||||
|
raise newException(WSOpcodeMismatchError, "First frame cannot be continue frame")
|
||||||
|
|
||||||
|
elif (not ws.frame.fin and ws.frame.remainder() <= 0):
|
||||||
|
ws.frame = await ws.readFrame()
|
||||||
|
# This could happen if the connection is closed.
|
||||||
|
if isNil(ws.frame):
|
||||||
|
return consumed.int
|
||||||
|
|
||||||
|
if ws.frame.opcode != Opcode.Cont:
|
||||||
|
raise newException(WSOpcodeMismatchError, "expected continue frame")
|
||||||
|
if ws.frame.fin and ws.frame.remainder().int <= 0:
|
||||||
|
ws.frame = nil
|
||||||
|
break
|
||||||
|
|
||||||
let len = min(ws.frame.remainder().int, size - consumed)
|
let len = min(ws.frame.remainder().int, size - consumed)
|
||||||
|
if len == 0:
|
||||||
|
continue
|
||||||
let read = await ws.stream.reader.readOnce(addr pbuffer[consumed], len)
|
let read = await ws.stream.reader.readOnce(addr pbuffer[consumed], len)
|
||||||
|
|
||||||
if read <= 0:
|
if read <= 0:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
@ -562,14 +621,18 @@ proc recv*(
|
||||||
unmask(
|
unmask(
|
||||||
pbuffer.toOpenArray(consumed, (consumed + read) - 1),
|
pbuffer.toOpenArray(consumed, (consumed + read) - 1),
|
||||||
ws.frame.maskKey,
|
ws.frame.maskKey,
|
||||||
consumed)
|
ws.frame.consumed.int)
|
||||||
|
|
||||||
consumed += read
|
consumed += read
|
||||||
ws.frame.consumed += read.uint64
|
ws.frame.consumed += read.uint64
|
||||||
if ws.frame.fin and ws.frame.remainder().int <= 0:
|
|
||||||
break
|
|
||||||
|
|
||||||
return consumed.int
|
return consumed.int
|
||||||
|
|
||||||
|
except WebSocketError as exc:
|
||||||
|
debug "Websocket error", exc = exc.msg
|
||||||
|
ws.readyState = ReadyState.Closed
|
||||||
|
await ws.stream.closeWait()
|
||||||
|
raise exc
|
||||||
except CancelledError as exc:
|
except CancelledError as exc:
|
||||||
debug "Cancelling reading", exc = exc.msg
|
debug "Cancelling reading", exc = exc.msg
|
||||||
raise exc
|
raise exc
|
||||||
|
@ -611,7 +674,8 @@ proc recv*(
|
||||||
# read the entire message, exit
|
# read the entire message, exit
|
||||||
if ws.frame.fin and ws.frame.remainder().int <= 0:
|
if ws.frame.fin and ws.frame.remainder().int <= 0:
|
||||||
break
|
break
|
||||||
except WSMaxMessageSizeError as exc:
|
except WebSocketError as exc:
|
||||||
|
debug "Websocket error", exc = exc.msg
|
||||||
raise exc
|
raise exc
|
||||||
except CancelledError as exc:
|
except CancelledError as exc:
|
||||||
debug "Cancelling reading", exc = exc.msg
|
debug "Cancelling reading", exc = exc.msg
|
||||||
|
@ -659,11 +723,14 @@ proc initiateHandshake(
|
||||||
TransportError,
|
TransportError,
|
||||||
"Cannot connect to " & $transp.remoteAddress() & " Error: " & exc.msg)
|
"Cannot connect to " & $transp.remoteAddress() & " Error: " & exc.msg)
|
||||||
|
|
||||||
let requestHeader = "GET " & uri.path & " HTTP/1.1" & CRLF & $headers
|
let
|
||||||
let reader = newAsyncStreamReader(transp)
|
requestHeader = "GET " & uri.path & " HTTP/1.1" & CRLF & $headers
|
||||||
let writer = newAsyncStreamWriter(transp)
|
reader = newAsyncStreamReader(transp)
|
||||||
|
writer = newAsyncStreamWriter(transp)
|
||||||
|
|
||||||
var stream: AsyncStream
|
var stream: AsyncStream
|
||||||
|
|
||||||
|
try:
|
||||||
var res: seq[byte]
|
var res: seq[byte]
|
||||||
if uri.scheme == "https":
|
if uri.scheme == "https":
|
||||||
let tlsstream = newTLSClientAsyncStream(reader, writer, "", flags = flags)
|
let tlsstream = newTLSClientAsyncStream(reader, writer, "", flags = flags)
|
||||||
|
@ -694,6 +761,10 @@ proc initiateHandshake(
|
||||||
" Header code: " & $resHeader.code &
|
" Header code: " & $resHeader.code &
|
||||||
" Header reason: " & resHeader.reason() &
|
" Header reason: " & resHeader.reason() &
|
||||||
" Address: " & $transp.remoteAddress())
|
" Address: " & $transp.remoteAddress())
|
||||||
|
except CatchableError as exc:
|
||||||
|
debug "Websocket failed during handshake", exc = exc.msg
|
||||||
|
await stream.closeWait()
|
||||||
|
raise exc
|
||||||
|
|
||||||
return stream
|
return stream
|
||||||
|
|
||||||
|
@ -738,7 +809,7 @@ proc connect*(
|
||||||
# Client data should be masked.
|
# Client data should be masked.
|
||||||
return WebSocket(
|
return WebSocket(
|
||||||
stream: stream,
|
stream: stream,
|
||||||
readyState: Open,
|
readyState: ReadyState.Open,
|
||||||
masked: true,
|
masked: true,
|
||||||
rng: newRng(),
|
rng: newRng(),
|
||||||
frameSize: frameSize,
|
frameSize: frameSize,
|
||||||
|
|
Loading…
Reference in New Issue