Perform utf-8 validation at message boundaries (#90)
* validate utf8 at the message level * move utf-8 validation to message * rename `recv` to `recvMsg` * add partial frame validation tests * use `recvMsg` instead of `recv`
This commit is contained in:
parent
00440b6eff
commit
0ec755738c
|
@ -43,7 +43,7 @@ proc getCaseCount(): Future[int] {.async.} =
|
|||
block:
|
||||
try:
|
||||
let ws = await connectServer("/getCaseCount")
|
||||
let buff = await ws.recv()
|
||||
let buff = await ws.recvMsg()
|
||||
let dataStr = string.fromBytes(buff)
|
||||
caseCount = parseInt(dataStr)
|
||||
await ws.close()
|
||||
|
@ -60,7 +60,7 @@ proc generateReport() {.async.} =
|
|||
trace "request autobahn server to generate report"
|
||||
let ws = await connectServer("/updateReports?agent=" & agent)
|
||||
while true:
|
||||
let buff = await ws.recv()
|
||||
let buff = await ws.recvMsg()
|
||||
if buff.len <= 0:
|
||||
break
|
||||
await ws.close()
|
||||
|
@ -80,7 +80,7 @@ proc main() {.async.} =
|
|||
|
||||
while ws.readystate != ReadyState.Closed:
|
||||
# echo back
|
||||
let data = await ws.recv()
|
||||
let data = await ws.recvMsg()
|
||||
let opCode = if ws.binary:
|
||||
Opcode.Binary
|
||||
else:
|
||||
|
|
|
@ -28,7 +28,7 @@ proc handle(request: HttpRequest) {.async.} =
|
|||
|
||||
trace "Websocket handshake completed"
|
||||
while ws.readyState != ReadyState.Closed:
|
||||
let recvData = await ws.recv()
|
||||
let recvData = await ws.recvMsg()
|
||||
trace "Client Response: ", size = recvData.len, binary = ws.binary
|
||||
|
||||
if ws.readyState == ReadyState.Closed:
|
||||
|
|
|
@ -35,7 +35,7 @@ suite "permessage deflate compression":
|
|||
let ws = await server.handleRequest(request)
|
||||
|
||||
while ws.readyState != ReadyState.Closed:
|
||||
let recvData = await ws.recv()
|
||||
let recvData = await ws.recvMsg()
|
||||
if ws.readyState == ReadyState.Closed:
|
||||
break
|
||||
await ws.send(recvData,
|
||||
|
@ -58,7 +58,7 @@ suite "permessage deflate compression":
|
|||
|
||||
var recvData: seq[byte]
|
||||
while recvData.len < textData.len:
|
||||
let res = await client.recv()
|
||||
let res = await client.recvMsg()
|
||||
recvData.add res
|
||||
if client.readyState == ReadyState.Closed:
|
||||
break
|
||||
|
@ -75,7 +75,7 @@ suite "permessage deflate compression":
|
|||
)
|
||||
let ws = await server.handleRequest(request)
|
||||
while ws.readyState != ReadyState.Closed:
|
||||
let recvData = await ws.recv()
|
||||
let recvData = await ws.recvMsg()
|
||||
if ws.readyState == ReadyState.Closed:
|
||||
break
|
||||
await ws.send(recvData,
|
||||
|
@ -98,7 +98,7 @@ suite "permessage deflate compression":
|
|||
|
||||
var recvData: seq[byte]
|
||||
while recvData.len < binaryData.len:
|
||||
let res = await client.recv()
|
||||
let res = await client.recvMsg()
|
||||
recvData.add res
|
||||
if client.readyState == ReadyState.Closed:
|
||||
break
|
||||
|
|
|
@ -30,7 +30,7 @@ suite "multiple extensions flow":
|
|||
factories = [hexFactory, base64Factory],
|
||||
)
|
||||
let ws = await server.handleRequest(request)
|
||||
let recvData = await ws.recv()
|
||||
let recvData = await ws.recvMsg()
|
||||
await ws.send(recvData,
|
||||
if ws.binary: Opcode.Binary else: Opcode.Text)
|
||||
|
||||
|
@ -50,7 +50,7 @@ suite "multiple extensions flow":
|
|||
)
|
||||
|
||||
await client.send(testData)
|
||||
let res = await client.recv()
|
||||
let res = await client.recvMsg()
|
||||
check testData.toBytes() == res
|
||||
await client.close()
|
||||
|
||||
|
@ -62,7 +62,7 @@ suite "multiple extensions flow":
|
|||
factories = [hexFactory, base64Factory],
|
||||
)
|
||||
let ws = await server.handleRequest(request)
|
||||
let recvData = await ws.recv()
|
||||
let recvData = await ws.recvMsg()
|
||||
await ws.send(recvData,
|
||||
if ws.binary: Opcode.Binary else: Opcode.Text)
|
||||
|
||||
|
@ -82,6 +82,6 @@ suite "multiple extensions flow":
|
|||
)
|
||||
|
||||
await client.send(testData)
|
||||
let res = await client.recv()
|
||||
let res = await client.recvMsg()
|
||||
check testData.toBytes() == res
|
||||
await client.close()
|
||||
|
|
|
@ -37,7 +37,7 @@ proc rndBin*(size: int): seq[byte] =
|
|||
proc waitForClose*(ws: WSSession) {.async.} =
|
||||
try:
|
||||
while ws.readystate != ReadyState.Closed:
|
||||
discard await ws.recv()
|
||||
discard await ws.recvMsg()
|
||||
except CatchableError:
|
||||
trace "Closing websocket"
|
||||
|
||||
|
|
|
@ -74,7 +74,7 @@ suite "UTF-8 DFA validator":
|
|||
proc waitForClose(ws: WSSession) {.async.} =
|
||||
try:
|
||||
while ws.readystate != ReadyState.Closed:
|
||||
discard await ws.recv()
|
||||
discard await ws.recvMsg()
|
||||
except CatchableError:
|
||||
trace "Closing websocket"
|
||||
|
||||
|
@ -96,7 +96,7 @@ suite "UTF-8 validator in action":
|
|||
let server = WSServer.new(protos = ["proto"])
|
||||
let ws = await server.handleRequest(request)
|
||||
|
||||
let res = await ws.recv()
|
||||
let res = await ws.recvMsg()
|
||||
check:
|
||||
string.fromBytes(res) == testData
|
||||
ws.binary == false
|
||||
|
@ -135,7 +135,7 @@ suite "UTF-8 validator in action":
|
|||
|
||||
let server = WSServer.new(protos = ["proto"], onClose = onClose)
|
||||
let ws = await server.handleRequest(request)
|
||||
let res = await ws.recv()
|
||||
let res = await ws.recvMsg()
|
||||
await waitForClose(ws)
|
||||
|
||||
check:
|
||||
|
@ -182,7 +182,7 @@ suite "UTF-8 validator in action":
|
|||
)
|
||||
|
||||
expect WSInvalidUTF8:
|
||||
let data = await session.recv()
|
||||
let data = await session.recvMsg()
|
||||
|
||||
test "invalid UTF-8 sequence close code":
|
||||
let closeReason = "i want to close\xc0\xaf"
|
||||
|
@ -207,4 +207,4 @@ suite "UTF-8 validator in action":
|
|||
)
|
||||
|
||||
expect WSInvalidUTF8:
|
||||
let data = await session.recv()
|
||||
let data = await session.recvMsg()
|
||||
|
|
|
@ -109,7 +109,7 @@ suite "Test transmission":
|
|||
|
||||
let server = WSServer.new(protos = ["proto"])
|
||||
let ws = await server.handleRequest(request)
|
||||
let servRes = await ws.recv()
|
||||
let servRes = await ws.recvMsg()
|
||||
|
||||
check string.fromBytes(servRes) == testString
|
||||
await ws.waitForClose()
|
||||
|
@ -129,7 +129,7 @@ suite "Test transmission":
|
|||
check request.uri.path == WSPath
|
||||
let server = WSServer.new(protos = ["proto"])
|
||||
let ws = await server.handleRequest(request)
|
||||
let servRes = await ws.recv()
|
||||
let servRes = await ws.recvMsg()
|
||||
check string.fromBytes(servRes) == testString
|
||||
await ws.waitForClose()
|
||||
|
||||
|
@ -158,7 +158,7 @@ suite "Test transmission":
|
|||
flags = {ReuseAddr})
|
||||
|
||||
let session = await connectClient()
|
||||
var clientRes = await session.recv()
|
||||
var clientRes = await session.recvMsg()
|
||||
check string.fromBytes(clientRes) == testString
|
||||
await waitForClose(session)
|
||||
|
||||
|
@ -265,7 +265,7 @@ suite "Test ping-pong":
|
|||
let ws = await server.handleRequest(request)
|
||||
|
||||
expect WSPayloadTooLarge:
|
||||
discard await ws.recv()
|
||||
discard await ws.recvMsg()
|
||||
|
||||
await waitForClose(ws)
|
||||
|
||||
|
@ -334,7 +334,7 @@ suite "Test framing":
|
|||
let session = await connectClient()
|
||||
|
||||
expect WSMaxMessageSizeError:
|
||||
discard await session.recv(5)
|
||||
discard await session.recvMsg(5)
|
||||
await waitForClose(session)
|
||||
|
||||
suite "Test Closing":
|
||||
|
@ -575,7 +575,7 @@ suite "Test Payload":
|
|||
|
||||
let server = WSServer.new(protos = ["proto"])
|
||||
let ws = await server.handleRequest(request)
|
||||
let servRes = await ws.recv()
|
||||
let servRes = await ws.recvMsg()
|
||||
|
||||
check:
|
||||
servRes.len == 0
|
||||
|
@ -591,7 +591,7 @@ suite "Test Payload":
|
|||
|
||||
let session = await connectClient()
|
||||
await session.send(emptyStr)
|
||||
let clientRes = await session.recv()
|
||||
let clientRes = await session.recvMsg()
|
||||
|
||||
check:
|
||||
clientRes.len == 0
|
||||
|
@ -607,7 +607,7 @@ suite "Test Payload":
|
|||
let server = WSServer.new(protos = ["proto"])
|
||||
let ws = await server.handleRequest(request)
|
||||
for _ in 0..<3:
|
||||
let servRes = await ws.recv()
|
||||
let servRes = await ws.recvMsg()
|
||||
|
||||
check:
|
||||
servRes.len == 0
|
||||
|
@ -629,7 +629,7 @@ suite "Test Payload":
|
|||
await session.send(emptyStr)
|
||||
|
||||
for _ in 0..<3:
|
||||
let clientRes = await session.recv()
|
||||
let clientRes = await session.recvMsg()
|
||||
|
||||
check:
|
||||
clientRes.len == 0
|
||||
|
@ -648,7 +648,7 @@ suite "Test Payload":
|
|||
let server = WSServer.new(protos = ["proto"])
|
||||
|
||||
let ws = await server.handleRequest(request)
|
||||
let respData = await ws.recv()
|
||||
let respData = await ws.recvMsg()
|
||||
|
||||
check:
|
||||
string.fromBytes(respData) == testString
|
||||
|
@ -706,7 +706,7 @@ suite "Test Payload":
|
|||
)
|
||||
|
||||
let ws = await server.handleRequest(request)
|
||||
let respData = await ws.recv()
|
||||
let respData = await ws.recvMsg()
|
||||
check:
|
||||
string.fromBytes(respData) == testString
|
||||
|
||||
|
@ -764,7 +764,7 @@ suite "Test Payload":
|
|||
|
||||
let server = WSServer.new(protos = ["proto"])
|
||||
let ws = await server.handleRequest(request)
|
||||
let res = await ws.recv()
|
||||
let res = await ws.recvMsg()
|
||||
|
||||
check ws.binary == false
|
||||
await ws.send(res, Opcode.Text)
|
||||
|
@ -781,7 +781,7 @@ suite "Test Payload":
|
|||
)
|
||||
|
||||
await ws.send(testData)
|
||||
let echoed = await ws.recv()
|
||||
let echoed = await ws.recvMsg()
|
||||
await ws.close()
|
||||
|
||||
check:
|
||||
|
@ -800,7 +800,7 @@ suite "Test Binary message with Payload":
|
|||
|
||||
let server = WSServer.new(protos = ["proto"])
|
||||
let ws = await server.handleRequest(request)
|
||||
let servRes = await ws.recv()
|
||||
let servRes = await ws.recvMsg()
|
||||
|
||||
check:
|
||||
servRes == emptyData
|
||||
|
@ -825,7 +825,7 @@ suite "Test Binary message with Payload":
|
|||
let server = WSServer.new(protos = ["proto"])
|
||||
let ws = await server.handleRequest(request)
|
||||
|
||||
let servRes = await ws.recv()
|
||||
let servRes = await ws.recvMsg()
|
||||
|
||||
check:
|
||||
servRes == emptyData
|
||||
|
@ -858,7 +858,7 @@ suite "Test Binary message with Payload":
|
|||
)
|
||||
let ws = await server.handleRequest(request)
|
||||
|
||||
let res = await ws.recv()
|
||||
let res = await ws.recvMsg()
|
||||
check:
|
||||
res == testData
|
||||
ws.binary == true
|
||||
|
@ -892,7 +892,7 @@ suite "Test Binary message with Payload":
|
|||
)
|
||||
let ws = await server.handleRequest(request)
|
||||
|
||||
let res = await ws.recv()
|
||||
let res = await ws.recvMsg()
|
||||
check:
|
||||
res == testData
|
||||
ws.binary == true
|
||||
|
@ -921,7 +921,7 @@ suite "Test Binary message with Payload":
|
|||
|
||||
let server = WSServer.new(protos = ["proto"])
|
||||
let ws = await server.handleRequest(request)
|
||||
let res = await ws.recv()
|
||||
let res = await ws.recvMsg()
|
||||
|
||||
check:
|
||||
ws.binary == true
|
||||
|
@ -941,7 +941,7 @@ suite "Test Binary message with Payload":
|
|||
)
|
||||
|
||||
await ws.send(testData, Opcode.Binary)
|
||||
let echoed = await ws.recv()
|
||||
let echoed = await ws.recvMsg()
|
||||
|
||||
check:
|
||||
echoed == testData
|
||||
|
@ -951,3 +951,61 @@ suite "Test Binary message with Payload":
|
|||
check:
|
||||
echoed == testData
|
||||
ws.binary == true
|
||||
|
||||
suite "Partial frames":
|
||||
teardown:
|
||||
server.stop()
|
||||
await server.closeWait()
|
||||
|
||||
proc lowLevelRecv(
|
||||
senderFrameSize, receiverFrameSize, readChunkSize: int) {.async.} =
|
||||
|
||||
const
|
||||
howMuchWood = "How much wood could a wood chuck chuck ..."
|
||||
|
||||
proc handle(request: HttpRequest) {.async.} =
|
||||
check request.uri.path == WSPath
|
||||
|
||||
let
|
||||
server = WSServer.new(frameSize = receiverFrameSize)
|
||||
ws = await server.handleRequest(request)
|
||||
|
||||
var
|
||||
res = newSeq[byte](howMuchWood.len)
|
||||
pos = 0
|
||||
|
||||
while ws.readyState != ReadyState.Closed:
|
||||
let read = await ws.recv(addr res[pos], min(res.len - pos, readChunkSize))
|
||||
pos += read
|
||||
|
||||
if pos >= res.len:
|
||||
break
|
||||
|
||||
res.setlen(pos)
|
||||
check res.len == howMuchWood.toBytes().len
|
||||
check res == howMuchWood.toBytes()
|
||||
await ws.waitForClose()
|
||||
|
||||
server = createServer(
|
||||
address = address,
|
||||
handler = handle,
|
||||
flags = {ReuseAddr})
|
||||
|
||||
let session = await connectClient(
|
||||
address = address,
|
||||
frameSize = senderFrameSize)
|
||||
|
||||
await session.send(howMuchWood)
|
||||
await session.close()
|
||||
|
||||
test "read in chunks less than sender frameSize":
|
||||
await lowLevelRecv(7, 7, 5)
|
||||
|
||||
test "read in chunks greater than sender frameSize":
|
||||
await lowLevelRecv(3, 7, 5)
|
||||
|
||||
test "sender frameSize greater than receiver":
|
||||
await lowLevelRecv(7, 5, 5)
|
||||
|
||||
test "receiver frameSize greater than sender":
|
||||
await lowLevelRecv(7, 10, 5)
|
||||
|
|
|
@ -315,23 +315,9 @@ proc recv*(
|
|||
var consumed = 0
|
||||
var pbuffer = cast[ptr UncheckedArray[byte]](data)
|
||||
try:
|
||||
var first = true
|
||||
if not isNil(ws.frame):
|
||||
if ws.frame.fin and ws.frame.remainder > 0:
|
||||
trace "Continue reading from the same frame"
|
||||
first = true
|
||||
elif not ws.frame.fin and ws.frame.remainder > 0:
|
||||
trace "Restarting reads in the middle of a frame in a multiframe message"
|
||||
first = false
|
||||
elif ws.frame.fin and ws.frame.remainder <= 0:
|
||||
trace "Resetting an already consumed frame"
|
||||
ws.frame = nil
|
||||
elif not ws.frame.fin and ws.frame.remainder <= 0:
|
||||
trace "No more bytes left and message EOF, resetting frame"
|
||||
ws.frame = nil
|
||||
|
||||
if isNil(ws.frame):
|
||||
ws.frame = await ws.readFrame(ws.extensions)
|
||||
ws.first = true
|
||||
|
||||
while consumed < size:
|
||||
if isNil(ws.frame):
|
||||
|
@ -339,7 +325,7 @@ proc recv*(
|
|||
break
|
||||
|
||||
logScope:
|
||||
first = first
|
||||
first = ws.first
|
||||
fin = ws.frame.fin
|
||||
len = ws.frame.length
|
||||
consumed = ws.frame.consumed
|
||||
|
@ -347,12 +333,12 @@ proc recv*(
|
|||
opcode = ws.frame.opcode
|
||||
masked = ws.frame.mask
|
||||
|
||||
if first == (ws.frame.opcode == Opcode.Cont):
|
||||
if ws.first == (ws.frame.opcode == Opcode.Cont):
|
||||
error "Opcode mismatch!"
|
||||
raise newException(WSOpcodeMismatchError,
|
||||
&"Opcode mismatch: first: {first}, opcode: {ws.frame.opcode}")
|
||||
&"Opcode mismatch: first: {ws.first}, opcode: {ws.frame.opcode}")
|
||||
|
||||
if first:
|
||||
if ws.first:
|
||||
ws.binary = ws.frame.opcode == Opcode.Binary # set binary flag
|
||||
trace "Setting binary flag"
|
||||
|
||||
|
@ -370,18 +356,15 @@ proc recv*(
|
|||
# all has been consumed from the frame
|
||||
# read the next frame
|
||||
if ws.frame.remainder <= 0:
|
||||
first = false
|
||||
ws.first = false
|
||||
|
||||
if ws.frame.fin: # we're at the end of the message, break
|
||||
trace "Read all frames, breaking"
|
||||
ws.frame = nil
|
||||
break
|
||||
|
||||
# read next frame
|
||||
ws.frame = await ws.readFrame(ws.extensions)
|
||||
|
||||
if not ws.binary and validateUTF8(pbuffer.toOpenArray(0, consumed - 1)) == false:
|
||||
raise newException(WSInvalidUTF8, "Invalid UTF8 sequence detected")
|
||||
|
||||
except CatchableError as exc:
|
||||
trace "Exception reading frames", exc = exc.msg
|
||||
ws.readyState = ReadyState.Closed
|
||||
|
@ -396,42 +379,58 @@ proc recv*(
|
|||
|
||||
return consumed
|
||||
|
||||
proc recv*(
|
||||
proc recvMsg*(
|
||||
ws: WSSession,
|
||||
size = WSMaxMessageSize): Future[seq[byte]] {.async.} =
|
||||
## Attempt to read a full message up to max `size`
|
||||
## bytes in `frameSize` chunks.
|
||||
##
|
||||
## If no `fin` flag arrives await until either
|
||||
## cancelled or the `fin` flag arrives.
|
||||
## If no `fin` flag arrives await until cancelled or
|
||||
## closed.
|
||||
##
|
||||
## If message is larger than `size` a `WSMaxMessageSizeError`
|
||||
## exception is thrown.
|
||||
##
|
||||
## In all other cases it awaits a full message.
|
||||
##
|
||||
var res: seq[byte]
|
||||
while ws.readyState != ReadyState.Closed:
|
||||
var buf = newSeq[byte](min(size, ws.frameSize))
|
||||
let read = await ws.recv(addr buf[0], buf.len)
|
||||
try:
|
||||
var res: seq[byte]
|
||||
while ws.readyState != ReadyState.Closed:
|
||||
var buf = newSeq[byte](min(size, ws.frameSize))
|
||||
let read = await ws.recv(addr buf[0], buf.len)
|
||||
|
||||
buf.setLen(read)
|
||||
if res.len + buf.len > size:
|
||||
raise newException(WSMaxMessageSizeError, "Max message size exceeded")
|
||||
buf.setLen(read)
|
||||
if res.len + buf.len > size:
|
||||
raise newException(WSMaxMessageSizeError, "Max message size exceeded")
|
||||
|
||||
trace "Read message", size = read
|
||||
res.add(buf)
|
||||
trace "Read message", size = read
|
||||
res.add(buf)
|
||||
|
||||
# no more frames
|
||||
if isNil(ws.frame):
|
||||
break
|
||||
# no more frames
|
||||
if isNil(ws.frame):
|
||||
break
|
||||
|
||||
# read the entire message, exit
|
||||
if ws.frame.fin and ws.frame.remainder <= 0:
|
||||
trace "Read full message, breaking!"
|
||||
break
|
||||
# read the entire message, exit
|
||||
if ws.frame.fin and ws.frame.remainder <= 0:
|
||||
trace "Read full message, breaking!"
|
||||
break
|
||||
|
||||
return res
|
||||
if not ws.binary and validateUTF8(res.toOpenArray(0, res.high)) == false:
|
||||
raise newException(WSInvalidUTF8, "Invalid UTF8 sequence detected")
|
||||
|
||||
return res
|
||||
except CatchableError as exc:
|
||||
trace "Exception reading message", exc = exc.msg
|
||||
ws.readyState = ReadyState.Closed
|
||||
await ws.stream.closeWait()
|
||||
|
||||
raise exc
|
||||
|
||||
proc recv*(
|
||||
ws: WSSession,
|
||||
size = WSMaxMessageSize): Future[seq[byte]]
|
||||
{.deprecated: "deprecated in favor of recvMsg()".} =
|
||||
ws.recvMsg(size)
|
||||
|
||||
proc close*(
|
||||
ws: WSSession,
|
||||
|
@ -451,6 +450,6 @@ proc close*(
|
|||
|
||||
# read frames until closed
|
||||
while ws.readyState != ReadyState.Closed:
|
||||
discard await ws.recv()
|
||||
discard await ws.recvMsg()
|
||||
except CatchableError as exc:
|
||||
trace "Exception closing", exc = exc.msg
|
||||
|
|
|
@ -79,11 +79,11 @@ type
|
|||
version*: uint
|
||||
key*: string
|
||||
readyState*: ReadyState
|
||||
masked*: bool # send masked packets
|
||||
binary*: bool # is payload binary?
|
||||
masked*: bool # send masked packets
|
||||
binary*: bool # is payload binary?
|
||||
flags*: set[TLSFlags]
|
||||
rng*: Rng
|
||||
frameSize*: int
|
||||
frameSize*: int # max frame buffer size
|
||||
onPing*: ControlCb
|
||||
onPong*: ControlCb
|
||||
onClose*: CloseCb
|
||||
|
@ -91,6 +91,7 @@ type
|
|||
WSSession* = ref object of WebSocket
|
||||
stream*: AsyncStream
|
||||
frame*: Frame
|
||||
first*: bool
|
||||
proto*: string
|
||||
|
||||
Ext* = ref object of RootObj
|
||||
|
|
Loading…
Reference in New Issue