Perform UTF-8 validation at message boundaries (#90)

* validate UTF-8 at the message level

* move UTF-8 validation to the message level

* rename `recv` to `recvMsg`

* add partial frame validation tests

* use `recvMsg` instead of `recv`
This commit is contained in:
Dmitriy Ryajov 2021-08-04 10:23:56 -06:00 committed by GitHub
parent 00440b6eff
commit 0ec755738c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 143 additions and 85 deletions

View File

@ -43,7 +43,7 @@ proc getCaseCount(): Future[int] {.async.} =
block:
try:
let ws = await connectServer("/getCaseCount")
let buff = await ws.recv()
let buff = await ws.recvMsg()
let dataStr = string.fromBytes(buff)
caseCount = parseInt(dataStr)
await ws.close()
@ -60,7 +60,7 @@ proc generateReport() {.async.} =
trace "request autobahn server to generate report"
let ws = await connectServer("/updateReports?agent=" & agent)
while true:
let buff = await ws.recv()
let buff = await ws.recvMsg()
if buff.len <= 0:
break
await ws.close()
@ -80,7 +80,7 @@ proc main() {.async.} =
while ws.readystate != ReadyState.Closed:
# echo back
let data = await ws.recv()
let data = await ws.recvMsg()
let opCode = if ws.binary:
Opcode.Binary
else:

View File

@ -28,7 +28,7 @@ proc handle(request: HttpRequest) {.async.} =
trace "Websocket handshake completed"
while ws.readyState != ReadyState.Closed:
let recvData = await ws.recv()
let recvData = await ws.recvMsg()
trace "Client Response: ", size = recvData.len, binary = ws.binary
if ws.readyState == ReadyState.Closed:

View File

@ -35,7 +35,7 @@ suite "permessage deflate compression":
let ws = await server.handleRequest(request)
while ws.readyState != ReadyState.Closed:
let recvData = await ws.recv()
let recvData = await ws.recvMsg()
if ws.readyState == ReadyState.Closed:
break
await ws.send(recvData,
@ -58,7 +58,7 @@ suite "permessage deflate compression":
var recvData: seq[byte]
while recvData.len < textData.len:
let res = await client.recv()
let res = await client.recvMsg()
recvData.add res
if client.readyState == ReadyState.Closed:
break
@ -75,7 +75,7 @@ suite "permessage deflate compression":
)
let ws = await server.handleRequest(request)
while ws.readyState != ReadyState.Closed:
let recvData = await ws.recv()
let recvData = await ws.recvMsg()
if ws.readyState == ReadyState.Closed:
break
await ws.send(recvData,
@ -98,7 +98,7 @@ suite "permessage deflate compression":
var recvData: seq[byte]
while recvData.len < binaryData.len:
let res = await client.recv()
let res = await client.recvMsg()
recvData.add res
if client.readyState == ReadyState.Closed:
break

View File

@ -30,7 +30,7 @@ suite "multiple extensions flow":
factories = [hexFactory, base64Factory],
)
let ws = await server.handleRequest(request)
let recvData = await ws.recv()
let recvData = await ws.recvMsg()
await ws.send(recvData,
if ws.binary: Opcode.Binary else: Opcode.Text)
@ -50,7 +50,7 @@ suite "multiple extensions flow":
)
await client.send(testData)
let res = await client.recv()
let res = await client.recvMsg()
check testData.toBytes() == res
await client.close()
@ -62,7 +62,7 @@ suite "multiple extensions flow":
factories = [hexFactory, base64Factory],
)
let ws = await server.handleRequest(request)
let recvData = await ws.recv()
let recvData = await ws.recvMsg()
await ws.send(recvData,
if ws.binary: Opcode.Binary else: Opcode.Text)
@ -82,6 +82,6 @@ suite "multiple extensions flow":
)
await client.send(testData)
let res = await client.recv()
let res = await client.recvMsg()
check testData.toBytes() == res
await client.close()

View File

@ -37,7 +37,7 @@ proc rndBin*(size: int): seq[byte] =
proc waitForClose*(ws: WSSession) {.async.} =
try:
while ws.readystate != ReadyState.Closed:
discard await ws.recv()
discard await ws.recvMsg()
except CatchableError:
trace "Closing websocket"

View File

@ -74,7 +74,7 @@ suite "UTF-8 DFA validator":
proc waitForClose(ws: WSSession) {.async.} =
try:
while ws.readystate != ReadyState.Closed:
discard await ws.recv()
discard await ws.recvMsg()
except CatchableError:
trace "Closing websocket"
@ -96,7 +96,7 @@ suite "UTF-8 validator in action":
let server = WSServer.new(protos = ["proto"])
let ws = await server.handleRequest(request)
let res = await ws.recv()
let res = await ws.recvMsg()
check:
string.fromBytes(res) == testData
ws.binary == false
@ -135,7 +135,7 @@ suite "UTF-8 validator in action":
let server = WSServer.new(protos = ["proto"], onClose = onClose)
let ws = await server.handleRequest(request)
let res = await ws.recv()
let res = await ws.recvMsg()
await waitForClose(ws)
check:
@ -182,7 +182,7 @@ suite "UTF-8 validator in action":
)
expect WSInvalidUTF8:
let data = await session.recv()
let data = await session.recvMsg()
test "invalid UTF-8 sequence close code":
let closeReason = "i want to close\xc0\xaf"
@ -207,4 +207,4 @@ suite "UTF-8 validator in action":
)
expect WSInvalidUTF8:
let data = await session.recv()
let data = await session.recvMsg()

View File

@ -109,7 +109,7 @@ suite "Test transmission":
let server = WSServer.new(protos = ["proto"])
let ws = await server.handleRequest(request)
let servRes = await ws.recv()
let servRes = await ws.recvMsg()
check string.fromBytes(servRes) == testString
await ws.waitForClose()
@ -129,7 +129,7 @@ suite "Test transmission":
check request.uri.path == WSPath
let server = WSServer.new(protos = ["proto"])
let ws = await server.handleRequest(request)
let servRes = await ws.recv()
let servRes = await ws.recvMsg()
check string.fromBytes(servRes) == testString
await ws.waitForClose()
@ -158,7 +158,7 @@ suite "Test transmission":
flags = {ReuseAddr})
let session = await connectClient()
var clientRes = await session.recv()
var clientRes = await session.recvMsg()
check string.fromBytes(clientRes) == testString
await waitForClose(session)
@ -265,7 +265,7 @@ suite "Test ping-pong":
let ws = await server.handleRequest(request)
expect WSPayloadTooLarge:
discard await ws.recv()
discard await ws.recvMsg()
await waitForClose(ws)
@ -334,7 +334,7 @@ suite "Test framing":
let session = await connectClient()
expect WSMaxMessageSizeError:
discard await session.recv(5)
discard await session.recvMsg(5)
await waitForClose(session)
suite "Test Closing":
@ -575,7 +575,7 @@ suite "Test Payload":
let server = WSServer.new(protos = ["proto"])
let ws = await server.handleRequest(request)
let servRes = await ws.recv()
let servRes = await ws.recvMsg()
check:
servRes.len == 0
@ -591,7 +591,7 @@ suite "Test Payload":
let session = await connectClient()
await session.send(emptyStr)
let clientRes = await session.recv()
let clientRes = await session.recvMsg()
check:
clientRes.len == 0
@ -607,7 +607,7 @@ suite "Test Payload":
let server = WSServer.new(protos = ["proto"])
let ws = await server.handleRequest(request)
for _ in 0..<3:
let servRes = await ws.recv()
let servRes = await ws.recvMsg()
check:
servRes.len == 0
@ -629,7 +629,7 @@ suite "Test Payload":
await session.send(emptyStr)
for _ in 0..<3:
let clientRes = await session.recv()
let clientRes = await session.recvMsg()
check:
clientRes.len == 0
@ -648,7 +648,7 @@ suite "Test Payload":
let server = WSServer.new(protos = ["proto"])
let ws = await server.handleRequest(request)
let respData = await ws.recv()
let respData = await ws.recvMsg()
check:
string.fromBytes(respData) == testString
@ -706,7 +706,7 @@ suite "Test Payload":
)
let ws = await server.handleRequest(request)
let respData = await ws.recv()
let respData = await ws.recvMsg()
check:
string.fromBytes(respData) == testString
@ -764,7 +764,7 @@ suite "Test Payload":
let server = WSServer.new(protos = ["proto"])
let ws = await server.handleRequest(request)
let res = await ws.recv()
let res = await ws.recvMsg()
check ws.binary == false
await ws.send(res, Opcode.Text)
@ -781,7 +781,7 @@ suite "Test Payload":
)
await ws.send(testData)
let echoed = await ws.recv()
let echoed = await ws.recvMsg()
await ws.close()
check:
@ -800,7 +800,7 @@ suite "Test Binary message with Payload":
let server = WSServer.new(protos = ["proto"])
let ws = await server.handleRequest(request)
let servRes = await ws.recv()
let servRes = await ws.recvMsg()
check:
servRes == emptyData
@ -825,7 +825,7 @@ suite "Test Binary message with Payload":
let server = WSServer.new(protos = ["proto"])
let ws = await server.handleRequest(request)
let servRes = await ws.recv()
let servRes = await ws.recvMsg()
check:
servRes == emptyData
@ -858,7 +858,7 @@ suite "Test Binary message with Payload":
)
let ws = await server.handleRequest(request)
let res = await ws.recv()
let res = await ws.recvMsg()
check:
res == testData
ws.binary == true
@ -892,7 +892,7 @@ suite "Test Binary message with Payload":
)
let ws = await server.handleRequest(request)
let res = await ws.recv()
let res = await ws.recvMsg()
check:
res == testData
ws.binary == true
@ -921,7 +921,7 @@ suite "Test Binary message with Payload":
let server = WSServer.new(protos = ["proto"])
let ws = await server.handleRequest(request)
let res = await ws.recv()
let res = await ws.recvMsg()
check:
ws.binary == true
@ -941,7 +941,7 @@ suite "Test Binary message with Payload":
)
await ws.send(testData, Opcode.Binary)
let echoed = await ws.recv()
let echoed = await ws.recvMsg()
check:
echoed == testData
@ -951,3 +951,61 @@ suite "Test Binary message with Payload":
check:
echoed == testData
ws.binary == true
# Exercises the pointer-based `recv(data, size)` overload: one text message
# is read back in small chunks while the sender and receiver are configured
# with different `frameSize` values.
suite "Partial frames":
teardown:
# Runs after every test: stop the listener and wait until it is closed.
# NOTE(review): `server` is assigned but not declared inside this suite —
# presumably a file-level variable; confirm against the full test file.
server.stop()
await server.closeWait()
# Shared test driver.
#   senderFrameSize   - `frameSize` passed to `connectClient` (sending side)
#   receiverFrameSize - `frameSize` passed to `WSServer.new` (receiving side)
#   readChunkSize     - upper bound on the bytes requested per `recv` call
proc lowLevelRecv(
senderFrameSize, receiverFrameSize, readChunkSize: int) {.async.} =
const
howMuchWood = "How much wood could a wood chuck chuck ..."
# Server-side handler: accepts the upgrade, then reads the message into a
# buffer pre-sized to the expected payload length.
proc handle(request: HttpRequest) {.async.} =
check request.uri.path == WSPath
let
server = WSServer.new(frameSize = receiverFrameSize)
ws = await server.handleRequest(request)
var
res = newSeq[byte](howMuchWood.len)
pos = 0
# Keep reading until the buffer is full or the session reports Closed.
while ws.readyState != ReadyState.Closed:
# Never request more than the space left in `res`, nor more than
# `readChunkSize`, so `addr res[pos]` always stays inside the buffer.
let read = await ws.recv(addr res[pos], min(res.len - pos, readChunkSize))
pos += read
if pos >= res.len:
break
# Trim to the bytes actually received before comparing with the payload.
res.setlen(pos)
check res.len == howMuchWood.toBytes().len
check res == howMuchWood.toBytes()
await ws.waitForClose()
# Start the listener with the handler above.
# NOTE(review): `address`, `createServer` and `connectClient` are defined
# outside this excerpt.
server = createServer(
address = address,
handler = handle,
flags = {ReuseAddr})
let session = await connectClient(
address = address,
frameSize = senderFrameSize)
# Send the whole payload, then close the client session.
await session.send(howMuchWood)
await session.close()
# Arguments below are (senderFrameSize, receiverFrameSize, readChunkSize).
test "read in chunks less than sender frameSize":
await lowLevelRecv(7, 7, 5)
test "read in chunks greater than sender frameSize":
await lowLevelRecv(3, 7, 5)
test "sender frameSize greater than receiver":
await lowLevelRecv(7, 5, 5)
test "receiver frameSize greater than sender":
await lowLevelRecv(7, 10, 5)

View File

@ -315,23 +315,9 @@ proc recv*(
var consumed = 0
var pbuffer = cast[ptr UncheckedArray[byte]](data)
try:
var first = true
if not isNil(ws.frame):
if ws.frame.fin and ws.frame.remainder > 0:
trace "Continue reading from the same frame"
first = true
elif not ws.frame.fin and ws.frame.remainder > 0:
trace "Restarting reads in the middle of a frame in a multiframe message"
first = false
elif ws.frame.fin and ws.frame.remainder <= 0:
trace "Resetting an already consumed frame"
ws.frame = nil
elif not ws.frame.fin and ws.frame.remainder <= 0:
trace "No more bytes left and message EOF, resetting frame"
ws.frame = nil
if isNil(ws.frame):
ws.frame = await ws.readFrame(ws.extensions)
ws.first = true
while consumed < size:
if isNil(ws.frame):
@ -339,7 +325,7 @@ proc recv*(
break
logScope:
first = first
first = ws.first
fin = ws.frame.fin
len = ws.frame.length
consumed = ws.frame.consumed
@ -347,12 +333,12 @@ proc recv*(
opcode = ws.frame.opcode
masked = ws.frame.mask
if first == (ws.frame.opcode == Opcode.Cont):
if ws.first == (ws.frame.opcode == Opcode.Cont):
error "Opcode mismatch!"
raise newException(WSOpcodeMismatchError,
&"Opcode mismatch: first: {first}, opcode: {ws.frame.opcode}")
&"Opcode mismatch: first: {ws.first}, opcode: {ws.frame.opcode}")
if first:
if ws.first:
ws.binary = ws.frame.opcode == Opcode.Binary # set binary flag
trace "Setting binary flag"
@ -370,18 +356,15 @@ proc recv*(
# all has been consumed from the frame
# read the next frame
if ws.frame.remainder <= 0:
first = false
ws.first = false
if ws.frame.fin: # we're at the end of the message, break
trace "Read all frames, breaking"
ws.frame = nil
break
# read next frame
ws.frame = await ws.readFrame(ws.extensions)
if not ws.binary and validateUTF8(pbuffer.toOpenArray(0, consumed - 1)) == false:
raise newException(WSInvalidUTF8, "Invalid UTF8 sequence detected")
except CatchableError as exc:
trace "Exception reading frames", exc = exc.msg
ws.readyState = ReadyState.Closed
@ -396,42 +379,58 @@ proc recv*(
return consumed
proc recv*(
proc recvMsg*(
ws: WSSession,
size = WSMaxMessageSize): Future[seq[byte]] {.async.} =
## Attempt to read a full message up to max `size`
## bytes in `frameSize` chunks.
##
## If no `fin` flag arrives await until either
## cancelled or the `fin` flag arrives.
## If no `fin` flag arrives await until cancelled or
## closed.
##
## If message is larger than `size` a `WSMaxMessageSizeError`
## exception is thrown.
##
## In all other cases it awaits a full message.
##
var res: seq[byte]
while ws.readyState != ReadyState.Closed:
var buf = newSeq[byte](min(size, ws.frameSize))
let read = await ws.recv(addr buf[0], buf.len)
try:
var res: seq[byte]
while ws.readyState != ReadyState.Closed:
var buf = newSeq[byte](min(size, ws.frameSize))
let read = await ws.recv(addr buf[0], buf.len)
buf.setLen(read)
if res.len + buf.len > size:
raise newException(WSMaxMessageSizeError, "Max message size exceeded")
buf.setLen(read)
if res.len + buf.len > size:
raise newException(WSMaxMessageSizeError, "Max message size exceeded")
trace "Read message", size = read
res.add(buf)
trace "Read message", size = read
res.add(buf)
# no more frames
if isNil(ws.frame):
break
# no more frames
if isNil(ws.frame):
break
# read the entire message, exit
if ws.frame.fin and ws.frame.remainder <= 0:
trace "Read full message, breaking!"
break
# read the entire message, exit
if ws.frame.fin and ws.frame.remainder <= 0:
trace "Read full message, breaking!"
break
return res
if not ws.binary and validateUTF8(res.toOpenArray(0, res.high)) == false:
raise newException(WSInvalidUTF8, "Invalid UTF8 sequence detected")
return res
except CatchableError as exc:
trace "Exception reading message", exc = exc.msg
ws.readyState = ReadyState.Closed
await ws.stream.closeWait()
raise exc
proc recv*(
ws: WSSession,
size = WSMaxMessageSize): Future[seq[byte]]
{.deprecated: "deprecated in favor of recvMsg()".} =
## Deprecated alias kept so existing callers of the seq-returning `recv`
## keep compiling: forwards `size` to `recvMsg` and returns its future
## as-is (this proc is not itself `{.async.}`).
ws.recvMsg(size)
proc close*(
ws: WSSession,
@ -451,6 +450,6 @@ proc close*(
# read frames until closed
while ws.readyState != ReadyState.Closed:
discard await ws.recv()
discard await ws.recvMsg()
except CatchableError as exc:
trace "Exception closing", exc = exc.msg

View File

@ -79,11 +79,11 @@ type
version*: uint
key*: string
readyState*: ReadyState
masked*: bool # send masked packets
binary*: bool # is payload binary?
masked*: bool # send masked packets
binary*: bool # is payload binary?
flags*: set[TLSFlags]
rng*: Rng
frameSize*: int
frameSize*: int # max frame buffer size
onPing*: ControlCb
onPong*: ControlCb
onClose*: CloseCb
@ -91,6 +91,7 @@ type
WSSession* = ref object of WebSocket
stream*: AsyncStream
frame*: Frame
first*: bool
proto*: string
Ext* = ref object of RootObj