Allow passing max message size (#800)

Co-authored-by: Tanguy <tanguy@status.im>
This commit is contained in:
Dmitriy Ryajov 2022-11-15 07:01:14 -06:00 committed by GitHub
parent ce371f3bb4
commit 8c2eca18dc
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 30 additions and 16 deletions

View File

@ -492,7 +492,7 @@ proc recvMessage(conn: StreamTransport): Future[seq[byte]] {.async.} =
res = PB.getUVarint(buffer.toOpenArray(0, i), length, size) res = PB.getUVarint(buffer.toOpenArray(0, i), length, size)
if res.isOk(): if res.isOk():
break break
if res.isErr() or size > MaxMessageSize: if res.isErr() or size > 1'u shl 22:
buffer.setLen(0) buffer.setLen(0)
result = buffer result = buffer
return return

View File

@ -19,8 +19,7 @@ export results, utility
{.push public.} {.push public.}
const const MaxMessageSize = 1'u shl 22
MaxMessageSize* = 1'u shl 22
type type
ProtoFieldKind* = enum ProtoFieldKind* = enum
@ -37,6 +36,7 @@ type
buffer*: seq[byte] buffer*: seq[byte]
offset*: int offset*: int
length*: int length*: int
maxSize*: uint
ProtoHeader* = object ProtoHeader* = object
wire*: ProtoFieldKind wire*: ProtoFieldKind
@ -122,23 +122,28 @@ proc vsizeof*(field: ProtoField): int {.inline.} =
0 0
proc initProtoBuffer*(data: seq[byte], offset = 0, proc initProtoBuffer*(data: seq[byte], offset = 0,
options: set[ProtoFlags] = {}): ProtoBuffer = options: set[ProtoFlags] = {},
maxSize = MaxMessageSize): ProtoBuffer =
## Initialize ProtoBuffer with shallow copy of ``data``. ## Initialize ProtoBuffer with shallow copy of ``data``.
result.buffer = data result.buffer = data
result.offset = offset result.offset = offset
result.options = options result.options = options
result.maxSize = maxSize
proc initProtoBuffer*(data: openArray[byte], offset = 0, proc initProtoBuffer*(data: openArray[byte], offset = 0,
options: set[ProtoFlags] = {}): ProtoBuffer = options: set[ProtoFlags] = {},
maxSize = MaxMessageSize): ProtoBuffer =
## Initialize ProtoBuffer with copy of ``data``. ## Initialize ProtoBuffer with copy of ``data``.
result.buffer = @data result.buffer = @data
result.offset = offset result.offset = offset
result.options = options result.options = options
result.maxSize = maxSize
proc initProtoBuffer*(options: set[ProtoFlags] = {}): ProtoBuffer = proc initProtoBuffer*(options: set[ProtoFlags] = {}, maxSize = MaxMessageSize): ProtoBuffer =
## Initialize ProtoBuffer with new sequence of capacity ``cap``. ## Initialize ProtoBuffer with new sequence of capacity ``cap``.
result.buffer = newSeq[byte]() result.buffer = newSeq[byte]()
result.options = options result.options = options
result.maxSize = maxSize
if WithVarintLength in options: if WithVarintLength in options:
# Our buffer will start from position 10, so we can store length of buffer # Our buffer will start from position 10, so we can store length of buffer
# in [0, 9]. # in [0, 9].
@ -335,7 +340,7 @@ proc skipValue(data: var ProtoBuffer, header: ProtoHeader): ProtoResult[void] =
var bsize = 0'u64 var bsize = 0'u64
if PB.getUVarint(data.toOpenArray(), length, bsize).isOk(): if PB.getUVarint(data.toOpenArray(), length, bsize).isOk():
data.offset += length data.offset += length
if bsize <= uint64(MaxMessageSize): if bsize <= uint64(data.maxSize):
if data.isEnough(int(bsize)): if data.isEnough(int(bsize)):
data.offset += int(bsize) data.offset += int(bsize)
ok() ok()
@ -399,7 +404,7 @@ proc getValue[T:byte|char](data: var ProtoBuffer, header: ProtoHeader,
outLength = 0 outLength = 0
if PB.getUVarint(data.toOpenArray(), length, bsize).isOk(): if PB.getUVarint(data.toOpenArray(), length, bsize).isOk():
data.offset += length data.offset += length
if bsize <= uint64(MaxMessageSize): if bsize <= uint64(data.maxSize):
if data.isEnough(int(bsize)): if data.isEnough(int(bsize)):
outLength = int(bsize) outLength = int(bsize)
if len(outBytes) >= int(bsize): if len(outBytes) >= int(bsize):
@ -427,7 +432,7 @@ proc getValue[T:seq[byte]|string](data: var ProtoBuffer, header: ProtoHeader,
if PB.getUVarint(data.toOpenArray(), length, bsize).isOk(): if PB.getUVarint(data.toOpenArray(), length, bsize).isOk():
data.offset += length data.offset += length
if bsize <= uint64(MaxMessageSize): if bsize <= uint64(data.maxSize):
if data.isEnough(int(bsize)): if data.isEnough(int(bsize)):
outBytes.setLen(bsize) outBytes.setLen(bsize)
if bsize > 0'u64: if bsize > 0'u64:

View File

@ -623,18 +623,27 @@ suite "MinProtobuf test suite":
test "[length] too big message test": test "[length] too big message test":
var pb1 = initProtoBuffer() var pb1 = initProtoBuffer()
var bigString = newString(MaxMessageSize + 1) var bigString = newString(pb1.maxSize + 1)
for i in 0 ..< len(bigString): for i in 0 ..< len(bigString):
bigString[i] = 'A' bigString[i] = 'A'
pb1.write(1, bigString) pb1.write(1, bigString)
pb1.finish() pb1.finish()
var pb2 = initProtoBuffer(pb1.buffer) block:
var value = newString(MaxMessageSize + 1) var pb2 = initProtoBuffer(pb1.buffer)
var valueLen = 0 var value = newString(pb1.maxSize + 1)
let res = pb2.getField(1, value, valueLen) var valueLen = 0
check: let res = pb2.getField(1, value, valueLen)
res.isErr() == true check:
res.isErr() == true
block:
var pb2 = initProtoBuffer(pb1.buffer, maxSize = uint.high)
var value = newString(pb1.maxSize + 1)
var valueLen = 0
let res = pb2.getField(1, value, valueLen)
check:
res.isErr() == false
test "[length] Repeated field test": test "[length] Repeated field test":
var pb1 = initProtoBuffer() var pb1 = initProtoBuffer()