Fix webscoket close and fix test cases.
This commit is contained in:
parent
0feac12a67
commit
51d834a0a1
|
@ -35,7 +35,7 @@ jobs:
|
|||
- target:
|
||||
os: windows
|
||||
builder: windows-2019
|
||||
name: '${{ matrix.target.os }}-${{ matrix.target.cpu }} (${{ matrix.branch }})'
|
||||
name: "${{ matrix.target.os }}-${{ matrix.target.cpu }} (${{ matrix.branch }})"
|
||||
runs-on: ${{ matrix.builder }}
|
||||
steps:
|
||||
- name: Checkout nim-ws
|
||||
|
|
|
@ -1,8 +1,7 @@
|
|||
import pkg/[chronos,
|
||||
import pkg/[chronos,
|
||||
chronos/apps/http/httpserver,
|
||||
chronicles,
|
||||
httputils,
|
||||
stew/byteutils]
|
||||
httputils]
|
||||
|
||||
import ../ws/ws
|
||||
|
||||
|
@ -13,22 +12,18 @@ proc process(r: RequestFence): Future[HttpResponseRef] {.async.} =
|
|||
if request.uri.path == "/ws":
|
||||
debug "Initiating web socket connection."
|
||||
try:
|
||||
var ws = await createServer(request,"")
|
||||
let ws = await createServer(request, "")
|
||||
if ws.readyState != Open:
|
||||
error "Failed to open websocket connection."
|
||||
return
|
||||
debug "Websocket handshake completed."
|
||||
while ws.readyState != ReadyState.Closed:
|
||||
# Only reads header for data frame.
|
||||
var recvData = await ws.recv()
|
||||
if recvData.len <= 0:
|
||||
debug "Empty messages"
|
||||
while true:
|
||||
let recvData = await ws.recv()
|
||||
if ws.readyState == ReadyState.Closed:
|
||||
debug "Websocket closed."
|
||||
break
|
||||
|
||||
# debug "Client Response: ", data = string.fromBytes(recvData), size = recvData.len
|
||||
debug "Client Response: ", size = recvData.len
|
||||
await ws.send(recvData)
|
||||
# await ws.close()
|
||||
|
||||
except WebSocketError as exc:
|
||||
error "WebSocket error:", exception = exc.msg
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
import unittest
|
||||
|
||||
include ../ws/ws
|
||||
include ../ws/random
|
||||
include ../ws/utils
|
||||
|
||||
# TODO: Fix Test.
|
||||
|
||||
|
|
|
@ -2,22 +2,32 @@ import std/strutils, httputils
|
|||
|
||||
import pkg/[asynctest,
|
||||
chronos,
|
||||
chronicles,
|
||||
chronos/apps/http/shttpserver,
|
||||
stew/byteutils]
|
||||
|
||||
import ../ws/[ws, stream],
|
||||
import ../ws/[ws, stream],
|
||||
../examples/tlsserver
|
||||
|
||||
import ./keys
|
||||
|
||||
var server: SecureHttpServerRef
|
||||
let address = initTAddress("127.0.0.1:8888")
|
||||
let serverFlags = {HttpServerFlags.Secure, HttpServerFlags.NotifyDisconnect}
|
||||
let socketFlags = {ServerFlags.TcpNoDelay, ServerFlags.ReuseAddr}
|
||||
let clientFlags = {NoVerifyHost, NoVerifyServerName}
|
||||
proc waitForClose(ws: WebSocket) {.async.} =
|
||||
try:
|
||||
while ws.readystate != ReadyState.Closed:
|
||||
discard await ws.recv()
|
||||
except CatchableError:
|
||||
debug "Closing websocket"
|
||||
|
||||
var server: SecureHttpServerRef
|
||||
|
||||
let
|
||||
address = initTAddress("127.0.0.1:8888")
|
||||
serverFlags = {HttpServerFlags.Secure, HttpServerFlags.NotifyDisconnect}
|
||||
socketFlags = {ServerFlags.TcpNoDelay, ServerFlags.ReuseAddr}
|
||||
clientFlags = {NoVerifyHost, NoVerifyServerName}
|
||||
secureKey = TLSPrivateKey.init(SecureKey)
|
||||
secureCert = TLSCertificate.init(SecureCert)
|
||||
|
||||
let secureKey = TLSPrivateKey.init(SecureKey)
|
||||
let secureCert = TLSCertificate.init(SecureCert)
|
||||
|
||||
suite "Test websocket TLS handshake":
|
||||
teardown:
|
||||
|
@ -31,10 +41,7 @@ suite "Test websocket TLS handshake":
|
|||
let request = r.get()
|
||||
check request.uri.path == "/wss"
|
||||
expect WSProtoMismatchError:
|
||||
var ws = await createServer(request, "proto")
|
||||
check ws.readyState == ReadyState.Closed
|
||||
|
||||
return await request.respond(Http200, "Connection established")
|
||||
discard await createServer(request, "proto")
|
||||
|
||||
let res = SecureHttpServerRef.new(
|
||||
address, cb,
|
||||
|
@ -62,10 +69,7 @@ suite "Test websocket TLS handshake":
|
|||
let request = r.get()
|
||||
check request.uri.path == "/wss"
|
||||
expect WSVersionError:
|
||||
var ws = await createServer(request, "proto")
|
||||
check ws.readyState == ReadyState.Closed
|
||||
|
||||
return await request.respond(Http200, "Connection established")
|
||||
discard await createServer(request, "proto")
|
||||
|
||||
let res = SecureHttpServerRef.new(
|
||||
address, cb,
|
||||
|
@ -91,14 +95,17 @@ suite "Test websocket TLS handshake":
|
|||
check r.isOk()
|
||||
let request = r.get()
|
||||
check request.uri.path == "/wss"
|
||||
check request.headers.getString("Connection").toUpperAscii() == "Upgrade".toUpperAscii()
|
||||
check request.headers.getString("Upgrade").toUpperAscii() == "websocket".toUpperAscii()
|
||||
check request.headers.getString("Cache-Control").toUpperAscii() == "no-cache".toUpperAscii()
|
||||
check request.headers.getString("Connection").toUpperAscii() ==
|
||||
"Upgrade".toUpperAscii()
|
||||
check request.headers.getString("Upgrade").toUpperAscii() ==
|
||||
"websocket".toUpperAscii()
|
||||
check request.headers.getString("Cache-Control").toUpperAscii() ==
|
||||
"no-cache".toUpperAscii()
|
||||
check request.headers.getString("Sec-WebSocket-Version") == $WSDefaultVersion
|
||||
|
||||
check request.headers.contains("Sec-WebSocket-Key")
|
||||
|
||||
discard await request.respond( Http200,"Connection established")
|
||||
discard await request.respond(Http200, "Connection established")
|
||||
|
||||
let res = SecureHttpServerRef.new(
|
||||
address, cb,
|
||||
|
@ -133,7 +140,7 @@ suite "Test websocket TLS transmission":
|
|||
let ws = await createServer(request, "proto")
|
||||
let servRes = await ws.recv()
|
||||
check string.fromBytes(servRes) == testString
|
||||
await ws.close()
|
||||
await waitForClose(ws)
|
||||
return dumbResponse()
|
||||
|
||||
let res = SecureHttpServerRef.new(
|
||||
|
@ -158,41 +165,7 @@ suite "Test websocket TLS transmission":
|
|||
|
||||
test "Client - test reading simple frame":
|
||||
let testString = "Hello!"
|
||||
proc cb(r: RequestFence): Future[HttpResponseRef] {.async.} =
|
||||
if r.isErr():
|
||||
return dumbResponse()
|
||||
|
||||
let request = r.get()
|
||||
check request.uri.path == "/wss"
|
||||
let ws = await createServer(request, "proto")
|
||||
let servRes = await ws.recv()
|
||||
check string.fromBytes(servRes) == testString
|
||||
await ws.close()
|
||||
return dumbResponse()
|
||||
|
||||
let res = SecureHttpServerRef.new(
|
||||
address, cb,
|
||||
serverFlags = serverFlags,
|
||||
socketFlags = socketFlags,
|
||||
tlsPrivateKey = secureKey,
|
||||
tlsCertificate = secureCert)
|
||||
|
||||
server = res.get()
|
||||
server.start()
|
||||
|
||||
let wsClient = await WebSocket.tlsConnect(
|
||||
"127.0.0.1",
|
||||
Port(8888),
|
||||
path = "/wss",
|
||||
protocols = @["proto"],
|
||||
clientFlags)
|
||||
|
||||
await wsClient.send(testString)
|
||||
await wsClient.close()
|
||||
|
||||
test "Client - test reading simple frame":
|
||||
let testString = "Hello!"
|
||||
proc cb(r: RequestFence): Future[HttpResponseRef] {.async.} =
|
||||
proc cb(r: RequestFence): Future[HttpResponseRef] {.async.} =
|
||||
if r.isErr():
|
||||
return dumbResponse()
|
||||
|
||||
|
@ -222,4 +195,4 @@ suite "Test websocket TLS transmission":
|
|||
|
||||
var clientRes = await wsClient.recv()
|
||||
check string.fromBytes(clientRes) == testString
|
||||
await wsClient.close()
|
||||
await waitForClose(wsClient)
|
||||
|
|
File diff suppressed because it is too large
Load Diff
|
@ -1,4 +1,4 @@
|
|||
import pkg/[chronos,
|
||||
import pkg/[chronos,
|
||||
chronos/apps/http/httpserver,
|
||||
chronos/timer,
|
||||
chronicles,
|
||||
|
@ -8,7 +8,7 @@ import strutils
|
|||
const
|
||||
HeaderSep = @[byte('\c'), byte('\L'), byte('\c'), byte('\L')]
|
||||
HttpHeadersTimeout = timer.seconds(120) # timeout for receiving headers (120 sec)
|
||||
MaxHttpHeadersSize = 8192 # maximum size of HTTP headers in octets
|
||||
MaxHttpHeadersSize = 8192 # maximum size of HTTP headers in octets
|
||||
|
||||
proc readHeaders*(rstream: AsyncStreamReader): Future[seq[byte]] {.async.} =
|
||||
var buffer = newSeq[byte](MaxHttpHeadersSize)
|
||||
|
@ -45,7 +45,11 @@ proc readHeaders*(rstream: AsyncStreamReader): Future[seq[byte]] {.async.} =
|
|||
|
||||
return buffer
|
||||
|
||||
proc closeWait*(wsStream : AsyncStream) {.async.} =
|
||||
proc closeWait*(wsStream: AsyncStream) {.async.} =
|
||||
|
||||
await allFutures(
|
||||
wsStream.writer.closeWait(),
|
||||
wsStream.reader.closeWait())
|
||||
await allFutures(
|
||||
wsStream.writer.tsource.closeWait(),
|
||||
wsStream.reader.tsource.closeWait())
|
||||
|
|
309
ws/ws.nim
309
ws/ws.nim
|
@ -18,7 +18,7 @@ import pkg/[chronos,
|
|||
stew/base10,
|
||||
nimcrypto/sha]
|
||||
|
||||
import ./random, ./stream
|
||||
import ./utils, ./stream
|
||||
|
||||
#[
|
||||
+---------------------------------------------------------------+
|
||||
|
@ -47,7 +47,7 @@ const
|
|||
WSHeaderSize* = 12
|
||||
WSDefaultVersion* = 13
|
||||
WSDefaultFrameSize* = 1 shl 20 # 1mb
|
||||
WSMaxMessageSize* = 20 shl 20 # 20mb
|
||||
WSMaxMessageSize* = 20 shl 20 # 20mb
|
||||
WSGuid* = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"
|
||||
CRLF* = "\r\n"
|
||||
|
||||
|
@ -71,7 +71,13 @@ type
|
|||
WSMaxMessageSizeError* = object of WebSocketError
|
||||
WSClosedError* = object of WebSocketError
|
||||
WSSendError* = object of WebSocketError
|
||||
WSPayloadTooLarge = object of WebSocketError
|
||||
WSPayloadTooLarge* = object of WebSocketError
|
||||
WSReserverdOpcodeError* = object of WebSocketError
|
||||
WSFragmentedControlFrameError* = object of WebSocketError
|
||||
WSInvalidCloseCodeError* = object of WebSocketError
|
||||
WSPayloadLengthError* = object of WebSocketError
|
||||
WSInvalidOpcodeError* = object of WebSocketError
|
||||
|
||||
|
||||
Base16Error* = object of CatchableError
|
||||
## Base16 specific exception type
|
||||
|
@ -111,21 +117,21 @@ type
|
|||
TooLarge = 1009
|
||||
NoExtensions = 1010
|
||||
UnexpectedError = 1011
|
||||
TlsError # use by clients
|
||||
# 3000-3999 reserved for libs
|
||||
# 4000-4999 reserved for applications
|
||||
ReservedCode = 3999 # use by clients
|
||||
# 3000-3999 reserved for libs
|
||||
# 4000-4999 reserved for applications
|
||||
|
||||
Frame = ref object
|
||||
fin: bool ## Indicates that this is the final fragment in a message.
|
||||
rsv1: bool ## MUST be 0 unless negotiated that defines meanings
|
||||
rsv2: bool ## MUST be 0
|
||||
rsv3: bool ## MUST be 0
|
||||
opcode: Opcode ## Defines the interpretation of the "Payload data".
|
||||
mask: bool ## Defines whether the "Payload data" is masked.
|
||||
data: seq[byte] ## Payload data
|
||||
maskKey: array[4, char] ## Masking key
|
||||
length: uint64 ## Message size.
|
||||
consumed: uint64 ## how much has been consumed from the frame
|
||||
fin: bool ## Indicates that this is the final fragment in a message.
|
||||
rsv1: bool ## MUST be 0 unless negotiated that defines meanings
|
||||
rsv2: bool ## MUST be 0
|
||||
rsv3: bool ## MUST be 0
|
||||
opcode: Opcode ## Defines the interpretation of the "Payload data".
|
||||
mask: bool ## Defines whether the "Payload data" is masked.
|
||||
data: seq[byte] ## Payload data
|
||||
maskKey: array[4, char] ## Masking key
|
||||
length: uint64 ## Message size.
|
||||
consumed: uint64 ## how much has been consumed from the frame
|
||||
|
||||
ControlCb* = proc() {.gcsafe, raises: [Defect].}
|
||||
|
||||
|
@ -156,11 +162,11 @@ template remainder*(frame: Frame): uint64 =
|
|||
proc `$`(ht: HttpTables): string =
|
||||
## Returns string representation of HttpTable/Ref.
|
||||
var res = ""
|
||||
for key,value in ht.stringItems(true):
|
||||
res.add(key.normalizeHeaderName())
|
||||
res.add(": ")
|
||||
res.add(value)
|
||||
res.add(CRLF)
|
||||
for key, value in ht.stringItems(true):
|
||||
res.add(key.normalizeHeaderName())
|
||||
res.add(": ")
|
||||
res.add(value)
|
||||
res.add(CRLF)
|
||||
|
||||
## add for end of header mark
|
||||
res.add(CRLF)
|
||||
|
@ -209,10 +215,11 @@ proc handshake*(
|
|||
wantProtocol & ")")
|
||||
|
||||
let cKey = ws.key & WSGuid
|
||||
let acceptKey = Base64Pad.encode(sha1.digest(cKey.toOpenArray(0, cKey.high)).data)
|
||||
let acceptKey = Base64Pad.encode(sha1.digest(cKey.toOpenArray(0,
|
||||
cKey.high)).data)
|
||||
var headerData = [
|
||||
("Connection", "Upgrade"),
|
||||
("Upgrade", "webSocket" ),
|
||||
("Upgrade", "webSocket"),
|
||||
("Sec-WebSocket-Accept", acceptKey)]
|
||||
|
||||
var headers = HttpTable.init(headerData)
|
||||
|
@ -222,7 +229,8 @@ proc handshake*(
|
|||
try:
|
||||
discard await request.respond(httputils.Http101, "", headers)
|
||||
except CatchableError as exc:
|
||||
raise newException(WSHandshakeError, "Failed to sent handshake response. Error: " & exc.msg)
|
||||
raise newException(WSHandshakeError,
|
||||
"Failed to sent handshake response. Error: " & exc.msg)
|
||||
ws.readyState = ReadyState.Open
|
||||
|
||||
proc createServer*(
|
||||
|
@ -330,39 +338,47 @@ proc send*(
|
|||
maskKey = genMaskKey(ws.rng)
|
||||
|
||||
if opcode notin {Opcode.Text, Opcode.Cont, Opcode.Binary}:
|
||||
if ws.readyState in {ReadyState.Closing} and opcode notin {Opcode.Close}:
|
||||
return
|
||||
await ws.stream.writer.write(encodeFrame(Frame(
|
||||
fin: true,
|
||||
rsv1: false,
|
||||
rsv2: false,
|
||||
rsv3: false,
|
||||
opcode: opcode,
|
||||
mask: ws.masked,
|
||||
data: data, # allow sending data with close messages
|
||||
maskKey: maskKey)))
|
||||
fin: true,
|
||||
rsv1: false,
|
||||
rsv2: false,
|
||||
rsv3: false,
|
||||
opcode: opcode,
|
||||
mask: ws.masked,
|
||||
data: data, # allow sending data with close messages
|
||||
maskKey: maskKey)))
|
||||
|
||||
return
|
||||
|
||||
let maxSize = ws.frameSize
|
||||
var i = 0
|
||||
while i < data.len:
|
||||
while ws.readyState notin {ReadyState.Closing}:
|
||||
let len = min(data.len, (maxSize + i))
|
||||
let inFrame = Frame(
|
||||
fin: if (i + len >= data.len): true else: false,
|
||||
rsv1: false,
|
||||
rsv2: false,
|
||||
rsv3: false,
|
||||
opcode: if i > 0: Opcode.Cont else: opcode, # fragments have to be `Continuation` frames
|
||||
mask: ws.masked,
|
||||
data: data[i ..< len],
|
||||
maskKey: maskKey)
|
||||
let encFrame = encodeFrame(Frame(
|
||||
fin: if (i + len >= data.len): true else: false,
|
||||
rsv1: false,
|
||||
rsv2: false,
|
||||
rsv3: false,
|
||||
opcode: if i > 0: Opcode.Cont else: opcode, # fragments have to be `Continuation` frames
|
||||
mask: ws.masked,
|
||||
data: data[i ..< len],
|
||||
maskKey: maskKey))
|
||||
|
||||
await ws.stream.writer.write(encodeFrame(inFrame))
|
||||
await ws.stream.writer.write(encFrame)
|
||||
i += len
|
||||
|
||||
if i >= data.len:
|
||||
break
|
||||
|
||||
proc send*(ws: WebSocket, data: string): Future[void] =
|
||||
send(ws, toBytes(data), Opcode.Text)
|
||||
|
||||
proc handleClose*(ws: WebSocket, frame: Frame) {.async.} =
|
||||
proc handleClose*(ws: WebSocket, frame: Frame, payLoad: seq[byte] = @[]) {.async.} =
|
||||
|
||||
if ws.readyState notin {ReadyState.Open}:
|
||||
return
|
||||
logScope:
|
||||
fin = frame.fin
|
||||
masked = frame.mask
|
||||
|
@ -370,46 +386,46 @@ proc handleClose*(ws: WebSocket, frame: Frame) {.async.} =
|
|||
serverState = ws.readyState
|
||||
|
||||
debug "Handling close sequence"
|
||||
if ws.readyState == ReadyState.Open or ws.readyState == ReadyState.Closing:
|
||||
# Read control frame payload.
|
||||
var data = newSeq[byte](frame.length)
|
||||
if frame.length > 0:
|
||||
# Read the data.
|
||||
await ws.stream.reader.readExactly(addr data[0], int frame.length)
|
||||
unmask(data.toOpenArray(0, data.high), frame.maskKey)
|
||||
var
|
||||
code = Status.Fulfilled
|
||||
reason = ""
|
||||
|
||||
var code: Status
|
||||
if data.len > 0:
|
||||
let ccode = uint16.fromBytesBE(data[0..<2]) # first two bytes are the status
|
||||
doAssert(ccode > 999, "No valid code in close message!")
|
||||
if payLoad.len == 1:
|
||||
raise newException(WSPayloadLengthError, "Invalid close frame with payload length 1!")
|
||||
elif payLoad.len > 1:
|
||||
# first two bytes are the status
|
||||
let ccode = uint16.fromBytesBE(payLoad[0..<2])
|
||||
if ccode <= 999 or ccode > 1015:
|
||||
raise newException(WSInvalidCloseCodeError, "Invalid code in close message!")
|
||||
try:
|
||||
code = Status(ccode)
|
||||
data = data[2..data.high]
|
||||
except RangeError:
|
||||
code = Status.Fulfilled
|
||||
# remining payload bytes are reason for closing
|
||||
reason = string.fromBytes(payLoad[2..payLoad.high])
|
||||
|
||||
var rcode = Status.Fulfilled
|
||||
var reason = ""
|
||||
if not isNil(ws.onClose):
|
||||
try:
|
||||
(rcode, reason) = ws.onClose(code, string.fromBytes(data))
|
||||
except CatchableError as exc:
|
||||
debug "Exception in Close callback, this is most likely a bug", exc = exc.msg
|
||||
var rcode: Status
|
||||
if code in {Status.Fulfilled}:
|
||||
rcode = Status.Fulfilled
|
||||
|
||||
# don't respond to a terminated connection
|
||||
if ws.readyState != ReadyState.Closing:
|
||||
await ws.send(prepareCloseBody(rcode, reason), Opcode.Close)
|
||||
if not isNil(ws.onClose):
|
||||
try:
|
||||
(rcode, reason) = ws.onClose(code, reason)
|
||||
except CatchableError as exc:
|
||||
debug "Exception in Close callback, this is most likely a bug", exc = exc.msg
|
||||
|
||||
# don't respond to a terminated connection
|
||||
if ws.readyState != ReadyState.Closing:
|
||||
ws.readyState = ReadyState.Closing
|
||||
await ws.send(prepareCloseBody(rcode, reason), Opcode.Close)
|
||||
|
||||
await ws.stream.closeWait()
|
||||
ws.readyState = ReadyState.Closed
|
||||
else:
|
||||
raiseAssert("Invalid state during close!")
|
||||
await ws.stream.closeWait()
|
||||
|
||||
proc handleControl*(ws: WebSocket, frame: Frame) {.async.} =
|
||||
proc handleControl*(ws: WebSocket, frame: Frame, payLoad: seq[byte] = @[]) {.async.} =
|
||||
## handle control frames
|
||||
##
|
||||
|
||||
if frame.length > 125:
|
||||
raise newException(WSPayloadTooLarge,
|
||||
"Control message payload is greater than 125 bytes!")
|
||||
|
||||
try:
|
||||
# Process control frame payload.
|
||||
case frame.opcode:
|
||||
|
@ -421,7 +437,7 @@ proc handleControl*(ws: WebSocket, frame: Frame) {.async.} =
|
|||
debug "Exception in Ping callback, this is most likelly a bug", exc = exc.msg
|
||||
|
||||
# send pong to remote
|
||||
await ws.send(@[], Opcode.Pong)
|
||||
await ws.send(payLoad, Opcode.Pong)
|
||||
of Opcode.Pong:
|
||||
if not isNil(ws.onPong):
|
||||
try:
|
||||
|
@ -429,9 +445,12 @@ proc handleControl*(ws: WebSocket, frame: Frame) {.async.} =
|
|||
except CatchableError as exc:
|
||||
debug "Exception in Pong callback, this is most likelly a bug", exc = exc.msg
|
||||
of Opcode.Close:
|
||||
await ws.handleClose(frame)
|
||||
await ws.handleClose(frame, payLoad)
|
||||
else:
|
||||
raiseAssert("Invalid control opcode")
|
||||
raise newException(WSInvalidOpcodeError, "Invalid control opcode!")
|
||||
except WebSocketError as exc:
|
||||
debug "Handled websocket exception", exc = exc.msg
|
||||
raise exc
|
||||
except CatchableError as exc:
|
||||
trace "Exception handling control messages", exc = exc.msg
|
||||
ws.readyState = ReadyState.Closed
|
||||
|
@ -467,8 +486,12 @@ proc readFrame*(ws: WebSocket): Future[Frame] {.async.} =
|
|||
if opcode > ord(Opcode.high):
|
||||
raise newException(WSOpcodeMismatchError, "Wrong opcode!")
|
||||
|
||||
frame.opcode = (opcode).Opcode
|
||||
let frameOpcode = (opcode).Opcode
|
||||
if frameOpcode notin {Opcode.Text, Opcode.Cont, Opcode.Binary,
|
||||
Opcode.Ping, Opcode.Pong, Opcode.Close}:
|
||||
raise newException(WSReserverdOpcodeError, "Unknown opcode received!")
|
||||
|
||||
frame.opcode = frameOpcode
|
||||
# If any of the rsv are set close the socket.
|
||||
if frame.rsv1 or frame.rsv2 or frame.rsv3:
|
||||
raise newException(WSRsvMismatchError, "WebSocket rsv mismatch")
|
||||
|
@ -508,12 +531,33 @@ proc readFrame*(ws: WebSocket): Future[Frame] {.async.} =
|
|||
|
||||
# return the current frame if it's not one of the control frames
|
||||
if frame.opcode notin {Opcode.Text, Opcode.Cont, Opcode.Binary}:
|
||||
asyncSpawn ws.handleControl(frame) # process control frames
|
||||
if not frame.fin:
|
||||
raise newException(WSFragmentedControlFrameError, "Control frame cannot be fragmented!")
|
||||
if frame.length > 125:
|
||||
raise newException(WSPayloadTooLarge,
|
||||
"Control message payload is greater than 125 bytes!")
|
||||
var payLoad = newSeq[byte](frame.length)
|
||||
if frame.length > 0:
|
||||
# Read control frame payload.
|
||||
await ws.stream.reader.readExactly(addr payLoad[0], frame.length.int)
|
||||
unmask(payLoad.toOpenArray(0, payLoad.high), frame.maskKey)
|
||||
await ws.handleControl(frame, payLoad) # process control frames# process control frames
|
||||
continue
|
||||
|
||||
debug "Decoded new frame", opcode = frame.opcode, len = frame.length, mask = frame.mask
|
||||
debug "Decoded new frame", opcode = frame.opcode, len = frame.length,
|
||||
mask = frame.mask
|
||||
|
||||
return frame
|
||||
|
||||
except WSReserverdOpcodeError as exc:
|
||||
trace "Handled websocket opcode exception", exc = exc.msg
|
||||
raise exc
|
||||
except WSPayloadTooLarge as exc:
|
||||
debug "Handled payload too large exception", exc = exc.msg
|
||||
raise exc
|
||||
except WebSocketError as exc:
|
||||
debug "Handled websocket exception", exc = exc.msg
|
||||
raise exc
|
||||
except CatchableError as exc:
|
||||
debug "Exception reading frame, dropping socket", exc = exc.msg
|
||||
ws.readyState = ReadyState.Closed
|
||||
|
@ -543,17 +587,32 @@ proc recv*(
|
|||
while consumed < size:
|
||||
# we might have to read more than
|
||||
# one frame to fill the buffer
|
||||
if isNil(ws.frame):
|
||||
ws.frame = await ws.readFrame()
|
||||
|
||||
# all has been consumed from the frame
|
||||
# read the next frame
|
||||
if ws.frame.remainder() <= 0:
|
||||
if isNil(ws.frame):
|
||||
ws.frame = await ws.readFrame()
|
||||
# This could happen if the connection is closed.
|
||||
if isNil(ws.frame):
|
||||
return consumed.int
|
||||
if ws.frame.opcode == Opcode.Cont:
|
||||
raise newException(WSOpcodeMismatchError, "First frame cannot be continue frame")
|
||||
|
||||
elif (not ws.frame.fin and ws.frame.remainder() <= 0):
|
||||
ws.frame = await ws.readFrame()
|
||||
# This could happen if the connection is closed.
|
||||
if isNil(ws.frame):
|
||||
return consumed.int
|
||||
|
||||
if ws.frame.opcode != Opcode.Cont:
|
||||
raise newException(WSOpcodeMismatchError, "expected continue frame")
|
||||
if ws.frame.fin and ws.frame.remainder().int <= 0:
|
||||
ws.frame = nil
|
||||
break
|
||||
|
||||
let len = min(ws.frame.remainder().int, size - consumed)
|
||||
if len == 0:
|
||||
continue
|
||||
let read = await ws.stream.reader.readOnce(addr pbuffer[consumed], len)
|
||||
|
||||
if read <= 0:
|
||||
continue
|
||||
|
||||
|
@ -562,14 +621,18 @@ proc recv*(
|
|||
unmask(
|
||||
pbuffer.toOpenArray(consumed, (consumed + read) - 1),
|
||||
ws.frame.maskKey,
|
||||
consumed)
|
||||
ws.frame.consumed.int)
|
||||
|
||||
consumed += read
|
||||
ws.frame.consumed += read.uint64
|
||||
if ws.frame.fin and ws.frame.remainder().int <= 0:
|
||||
break
|
||||
|
||||
return consumed.int
|
||||
|
||||
except WebSocketError as exc:
|
||||
debug "Websocket error", exc = exc.msg
|
||||
ws.readyState = ReadyState.Closed
|
||||
await ws.stream.closeWait()
|
||||
raise exc
|
||||
except CancelledError as exc:
|
||||
debug "Cancelling reading", exc = exc.msg
|
||||
raise exc
|
||||
|
@ -611,7 +674,8 @@ proc recv*(
|
|||
# read the entire message, exit
|
||||
if ws.frame.fin and ws.frame.remainder().int <= 0:
|
||||
break
|
||||
except WSMaxMessageSizeError as exc:
|
||||
except WebSocketError as exc:
|
||||
debug "Websocket error", exc = exc.msg
|
||||
raise exc
|
||||
except CancelledError as exc:
|
||||
debug "Cancelling reading", exc = exc.msg
|
||||
|
@ -659,41 +723,48 @@ proc initiateHandshake(
|
|||
TransportError,
|
||||
"Cannot connect to " & $transp.remoteAddress() & " Error: " & exc.msg)
|
||||
|
||||
let requestHeader = "GET " & uri.path & " HTTP/1.1" & CRLF & $headers
|
||||
let reader = newAsyncStreamReader(transp)
|
||||
let writer = newAsyncStreamWriter(transp)
|
||||
let
|
||||
requestHeader = "GET " & uri.path & " HTTP/1.1" & CRLF & $headers
|
||||
reader = newAsyncStreamReader(transp)
|
||||
writer = newAsyncStreamWriter(transp)
|
||||
|
||||
var stream: AsyncStream
|
||||
|
||||
var res: seq[byte]
|
||||
if uri.scheme == "https":
|
||||
let tlsstream = newTLSClientAsyncStream(reader, writer, "", flags = flags)
|
||||
stream = AsyncStream(
|
||||
reader: tlsstream.reader,
|
||||
writer: tlsstream.writer)
|
||||
try:
|
||||
var res: seq[byte]
|
||||
if uri.scheme == "https":
|
||||
let tlsstream = newTLSClientAsyncStream(reader, writer, "", flags = flags)
|
||||
stream = AsyncStream(
|
||||
reader: tlsstream.reader,
|
||||
writer: tlsstream.writer)
|
||||
|
||||
await tlsstream.writer.write(requestHeader)
|
||||
res = await tlsstream.reader.readHeaders()
|
||||
else:
|
||||
stream = AsyncStream(
|
||||
reader: reader,
|
||||
writer: writer)
|
||||
await stream.writer.write(requestHeader)
|
||||
res = await stream.reader.readHeaders()
|
||||
await tlsstream.writer.write(requestHeader)
|
||||
res = await tlsstream.reader.readHeaders()
|
||||
else:
|
||||
stream = AsyncStream(
|
||||
reader: reader,
|
||||
writer: writer)
|
||||
await stream.writer.write(requestHeader)
|
||||
res = await stream.reader.readHeaders()
|
||||
|
||||
if res.len == 0:
|
||||
raise newException(ValueError, "Empty response from server")
|
||||
if res.len == 0:
|
||||
raise newException(ValueError, "Empty response from server")
|
||||
|
||||
let resHeader = res.parseResponse()
|
||||
if resHeader.failed():
|
||||
# Header could not be parsed
|
||||
raise newException(WSMalformedHeaderError, "Malformed header received.")
|
||||
let resHeader = res.parseResponse()
|
||||
if resHeader.failed():
|
||||
# Header could not be parsed
|
||||
raise newException(WSMalformedHeaderError, "Malformed header received.")
|
||||
|
||||
if resHeader.code != ord(Http101):
|
||||
raise newException(WSFailedUpgradeError,
|
||||
"Server did not reply with a websocket upgrade:" &
|
||||
" Header code: " & $resHeader.code &
|
||||
" Header reason: " & resHeader.reason() &
|
||||
" Address: " & $transp.remoteAddress())
|
||||
if resHeader.code != ord(Http101):
|
||||
raise newException(WSFailedUpgradeError,
|
||||
"Server did not reply with a websocket upgrade:" &
|
||||
" Header code: " & $resHeader.code &
|
||||
" Header reason: " & resHeader.reason() &
|
||||
" Address: " & $transp.remoteAddress())
|
||||
except CatchableError as exc:
|
||||
debug "Websocket failed during handshake", exc = exc.msg
|
||||
await stream.closeWait()
|
||||
raise exc
|
||||
|
||||
return stream
|
||||
|
||||
|
@ -738,7 +809,7 @@ proc connect*(
|
|||
# Client data should be masked.
|
||||
return WebSocket(
|
||||
stream: stream,
|
||||
readyState: Open,
|
||||
readyState: ReadyState.Open,
|
||||
masked: true,
|
||||
rng: newRng(),
|
||||
frameSize: frameSize,
|
||||
|
|
Loading…
Reference in New Issue