diff --git a/libp2p/daemon/daemonapi.nim b/libp2p/daemon/daemonapi.nim index 4e802c434..05f764756 100644 --- a/libp2p/daemon/daemonapi.nim +++ b/libp2p/daemon/daemonapi.nim @@ -492,7 +492,7 @@ proc recvMessage(conn: StreamTransport): Future[seq[byte]] {.async.} = res = PB.getUVarint(buffer.toOpenArray(0, i), length, size) if res.isOk(): break - if res.isErr() or size > MaxMessageSize: + if res.isErr() or size > 1'u shl 22: buffer.setLen(0) result = buffer return diff --git a/libp2p/protobuf/minprotobuf.nim b/libp2p/protobuf/minprotobuf.nim index 57223b0e8..5305905f6 100644 --- a/libp2p/protobuf/minprotobuf.nim +++ b/libp2p/protobuf/minprotobuf.nim @@ -19,8 +19,7 @@ export results, utility {.push public.} -const - MaxMessageSize* = 1'u shl 22 +const MaxMessageSize = 1'u shl 22 type ProtoFieldKind* = enum @@ -37,6 +36,7 @@ type buffer*: seq[byte] offset*: int length*: int + maxSize*: uint ProtoHeader* = object wire*: ProtoFieldKind @@ -122,23 +122,28 @@ proc vsizeof*(field: ProtoField): int {.inline.} = 0 proc initProtoBuffer*(data: seq[byte], offset = 0, - options: set[ProtoFlags] = {}): ProtoBuffer = + options: set[ProtoFlags] = {}, + maxSize = MaxMessageSize): ProtoBuffer = ## Initialize ProtoBuffer with shallow copy of ``data``. result.buffer = data result.offset = offset result.options = options + result.maxSize = maxSize proc initProtoBuffer*(data: openArray[byte], offset = 0, - options: set[ProtoFlags] = {}): ProtoBuffer = + options: set[ProtoFlags] = {}, + maxSize = MaxMessageSize): ProtoBuffer = ## Initialize ProtoBuffer with copy of ``data``. result.buffer = @data result.offset = offset result.options = options + result.maxSize = maxSize -proc initProtoBuffer*(options: set[ProtoFlags] = {}): ProtoBuffer = +proc initProtoBuffer*(options: set[ProtoFlags] = {}, maxSize = MaxMessageSize): ProtoBuffer = ## Initialize ProtoBuffer with new sequence of capacity ``cap``. result.buffer = newSeq[byte]() result.options = options + result.maxSize = maxSize if WithVarintLength in options: # Our buffer will start from position 10, so we can store length of buffer # in [0, 9]. @@ -335,7 +340,7 @@ proc skipValue(data: var ProtoBuffer, header: ProtoHeader): ProtoResult[void] = var bsize = 0'u64 if PB.getUVarint(data.toOpenArray(), length, bsize).isOk(): data.offset += length - if bsize <= uint64(MaxMessageSize): + if bsize <= uint64(data.maxSize): if data.isEnough(int(bsize)): data.offset += int(bsize) ok() @@ -399,7 +404,7 @@ proc getValue[T:byte|char](data: var ProtoBuffer, header: ProtoHeader, outLength = 0 if PB.getUVarint(data.toOpenArray(), length, bsize).isOk(): data.offset += length - if bsize <= uint64(MaxMessageSize): + if bsize <= uint64(data.maxSize): if data.isEnough(int(bsize)): outLength = int(bsize) if len(outBytes) >= int(bsize): @@ -427,7 +432,7 @@ proc getValue[T:seq[byte]|string](data: var ProtoBuffer, header: ProtoHeader, if PB.getUVarint(data.toOpenArray(), length, bsize).isOk(): data.offset += length - if bsize <= uint64(MaxMessageSize): + if bsize <= uint64(data.maxSize): if data.isEnough(int(bsize)): outBytes.setLen(bsize) if bsize > 0'u64: diff --git a/tests/testminprotobuf.nim b/tests/testminprotobuf.nim index fae8fd742..2ddfa4c61 100644 --- a/tests/testminprotobuf.nim +++ b/tests/testminprotobuf.nim @@ -623,18 +623,27 @@ suite "MinProtobuf test suite": test "[length] too big message test": var pb1 = initProtoBuffer() - var bigString = newString(MaxMessageSize + 1) + var bigString = newString(pb1.maxSize + 1) for i in 0 ..< len(bigString): bigString[i] = 'A' pb1.write(1, bigString) pb1.finish() - var pb2 = initProtoBuffer(pb1.buffer) - var value = newString(MaxMessageSize + 1) - var valueLen = 0 - let res = pb2.getField(1, value, valueLen) - check: - res.isErr() == true + block: + var pb2 = initProtoBuffer(pb1.buffer) + var value = newString(pb1.maxSize + 1) + var valueLen = 0 + let res = pb2.getField(1, value, valueLen) + check: + res.isErr() == true + + block: + var pb2 = initProtoBuffer(pb1.buffer, maxSize = uint.high) + var value = newString(pb1.maxSize + 1) + var valueLen = 0 + let res = pb2.getField(1, value, valueLen) + check: + res.isErr() == false test "[length] Repeated field test": var pb1 = initProtoBuffer()