Allow passing max message size (#800)
Co-authored-by: Tanguy <tanguy@status.im>
This commit is contained in:
parent
ce371f3bb4
commit
8c2eca18dc
|
@ -492,7 +492,7 @@ proc recvMessage(conn: StreamTransport): Future[seq[byte]] {.async.} =
|
||||||
res = PB.getUVarint(buffer.toOpenArray(0, i), length, size)
|
res = PB.getUVarint(buffer.toOpenArray(0, i), length, size)
|
||||||
if res.isOk():
|
if res.isOk():
|
||||||
break
|
break
|
||||||
if res.isErr() or size > MaxMessageSize:
|
if res.isErr() or size > 1'u shl 22:
|
||||||
buffer.setLen(0)
|
buffer.setLen(0)
|
||||||
result = buffer
|
result = buffer
|
||||||
return
|
return
|
||||||
|
|
|
@ -19,8 +19,7 @@ export results, utility
|
||||||
|
|
||||||
{.push public.}
|
{.push public.}
|
||||||
|
|
||||||
const
|
const MaxMessageSize = 1'u shl 22
|
||||||
MaxMessageSize* = 1'u shl 22
|
|
||||||
|
|
||||||
type
|
type
|
||||||
ProtoFieldKind* = enum
|
ProtoFieldKind* = enum
|
||||||
|
@ -37,6 +36,7 @@ type
|
||||||
buffer*: seq[byte]
|
buffer*: seq[byte]
|
||||||
offset*: int
|
offset*: int
|
||||||
length*: int
|
length*: int
|
||||||
|
maxSize*: uint
|
||||||
|
|
||||||
ProtoHeader* = object
|
ProtoHeader* = object
|
||||||
wire*: ProtoFieldKind
|
wire*: ProtoFieldKind
|
||||||
|
@ -122,23 +122,28 @@ proc vsizeof*(field: ProtoField): int {.inline.} =
|
||||||
0
|
0
|
||||||
|
|
||||||
proc initProtoBuffer*(data: seq[byte], offset = 0,
|
proc initProtoBuffer*(data: seq[byte], offset = 0,
|
||||||
options: set[ProtoFlags] = {}): ProtoBuffer =
|
options: set[ProtoFlags] = {},
|
||||||
|
maxSize = MaxMessageSize): ProtoBuffer =
|
||||||
## Initialize ProtoBuffer with shallow copy of ``data``.
|
## Initialize ProtoBuffer with shallow copy of ``data``.
|
||||||
result.buffer = data
|
result.buffer = data
|
||||||
result.offset = offset
|
result.offset = offset
|
||||||
result.options = options
|
result.options = options
|
||||||
|
result.maxSize = maxSize
|
||||||
|
|
||||||
proc initProtoBuffer*(data: openArray[byte], offset = 0,
|
proc initProtoBuffer*(data: openArray[byte], offset = 0,
|
||||||
options: set[ProtoFlags] = {}): ProtoBuffer =
|
options: set[ProtoFlags] = {},
|
||||||
|
maxSize = MaxMessageSize): ProtoBuffer =
|
||||||
## Initialize ProtoBuffer with copy of ``data``.
|
## Initialize ProtoBuffer with copy of ``data``.
|
||||||
result.buffer = @data
|
result.buffer = @data
|
||||||
result.offset = offset
|
result.offset = offset
|
||||||
result.options = options
|
result.options = options
|
||||||
|
result.maxSize = maxSize
|
||||||
|
|
||||||
proc initProtoBuffer*(options: set[ProtoFlags] = {}): ProtoBuffer =
|
proc initProtoBuffer*(options: set[ProtoFlags] = {}, maxSize = MaxMessageSize): ProtoBuffer =
|
||||||
## Initialize ProtoBuffer with new sequence of capacity ``cap``.
|
## Initialize ProtoBuffer with new sequence of capacity ``cap``.
|
||||||
result.buffer = newSeq[byte]()
|
result.buffer = newSeq[byte]()
|
||||||
result.options = options
|
result.options = options
|
||||||
|
result.maxSize = maxSize
|
||||||
if WithVarintLength in options:
|
if WithVarintLength in options:
|
||||||
# Our buffer will start from position 10, so we can store length of buffer
|
# Our buffer will start from position 10, so we can store length of buffer
|
||||||
# in [0, 9].
|
# in [0, 9].
|
||||||
|
@ -335,7 +340,7 @@ proc skipValue(data: var ProtoBuffer, header: ProtoHeader): ProtoResult[void] =
|
||||||
var bsize = 0'u64
|
var bsize = 0'u64
|
||||||
if PB.getUVarint(data.toOpenArray(), length, bsize).isOk():
|
if PB.getUVarint(data.toOpenArray(), length, bsize).isOk():
|
||||||
data.offset += length
|
data.offset += length
|
||||||
if bsize <= uint64(MaxMessageSize):
|
if bsize <= uint64(data.maxSize):
|
||||||
if data.isEnough(int(bsize)):
|
if data.isEnough(int(bsize)):
|
||||||
data.offset += int(bsize)
|
data.offset += int(bsize)
|
||||||
ok()
|
ok()
|
||||||
|
@ -399,7 +404,7 @@ proc getValue[T:byte|char](data: var ProtoBuffer, header: ProtoHeader,
|
||||||
outLength = 0
|
outLength = 0
|
||||||
if PB.getUVarint(data.toOpenArray(), length, bsize).isOk():
|
if PB.getUVarint(data.toOpenArray(), length, bsize).isOk():
|
||||||
data.offset += length
|
data.offset += length
|
||||||
if bsize <= uint64(MaxMessageSize):
|
if bsize <= uint64(data.maxSize):
|
||||||
if data.isEnough(int(bsize)):
|
if data.isEnough(int(bsize)):
|
||||||
outLength = int(bsize)
|
outLength = int(bsize)
|
||||||
if len(outBytes) >= int(bsize):
|
if len(outBytes) >= int(bsize):
|
||||||
|
@ -427,7 +432,7 @@ proc getValue[T:seq[byte]|string](data: var ProtoBuffer, header: ProtoHeader,
|
||||||
|
|
||||||
if PB.getUVarint(data.toOpenArray(), length, bsize).isOk():
|
if PB.getUVarint(data.toOpenArray(), length, bsize).isOk():
|
||||||
data.offset += length
|
data.offset += length
|
||||||
if bsize <= uint64(MaxMessageSize):
|
if bsize <= uint64(data.maxSize):
|
||||||
if data.isEnough(int(bsize)):
|
if data.isEnough(int(bsize)):
|
||||||
outBytes.setLen(bsize)
|
outBytes.setLen(bsize)
|
||||||
if bsize > 0'u64:
|
if bsize > 0'u64:
|
||||||
|
|
|
@ -623,18 +623,27 @@ suite "MinProtobuf test suite":
|
||||||
|
|
||||||
test "[length] too big message test":
|
test "[length] too big message test":
|
||||||
var pb1 = initProtoBuffer()
|
var pb1 = initProtoBuffer()
|
||||||
var bigString = newString(MaxMessageSize + 1)
|
var bigString = newString(pb1.maxSize + 1)
|
||||||
|
|
||||||
for i in 0 ..< len(bigString):
|
for i in 0 ..< len(bigString):
|
||||||
bigString[i] = 'A'
|
bigString[i] = 'A'
|
||||||
pb1.write(1, bigString)
|
pb1.write(1, bigString)
|
||||||
pb1.finish()
|
pb1.finish()
|
||||||
var pb2 = initProtoBuffer(pb1.buffer)
|
block:
|
||||||
var value = newString(MaxMessageSize + 1)
|
var pb2 = initProtoBuffer(pb1.buffer)
|
||||||
var valueLen = 0
|
var value = newString(pb1.maxSize + 1)
|
||||||
let res = pb2.getField(1, value, valueLen)
|
var valueLen = 0
|
||||||
check:
|
let res = pb2.getField(1, value, valueLen)
|
||||||
res.isErr() == true
|
check:
|
||||||
|
res.isErr() == true
|
||||||
|
|
||||||
|
block:
|
||||||
|
var pb2 = initProtoBuffer(pb1.buffer, maxSize = uint.high)
|
||||||
|
var value = newString(pb1.maxSize + 1)
|
||||||
|
var valueLen = 0
|
||||||
|
let res = pb2.getField(1, value, valueLen)
|
||||||
|
check:
|
||||||
|
res.isErr() == false
|
||||||
|
|
||||||
test "[length] Repeated field test":
|
test "[length] Repeated field test":
|
||||||
var pb1 = initProtoBuffer()
|
var pb1 = initProtoBuffer()
|
||||||
|
|
Loading…
Reference in New Issue