diff --git a/storage/blockexchange/engine/downloadcontext.nim b/storage/blockexchange/engine/downloadcontext.nim index 67a8452c..346b530d 100644 --- a/storage/blockexchange/engine/downloadcontext.nim +++ b/storage/blockexchange/engine/downloadcontext.nim @@ -28,6 +28,8 @@ const PresenceWindowBytes*: uint64 = 1024 * 1024 * 1024 PresenceWindowBlocks*: uint64 = PresenceWindowBytes div DefaultBlockSize.uint64 MaxPresenceWindowBlocks*: uint64 = PresenceWindowBytes div MinBlockSize + MaxRangeIterationsPerMessage*: uint64 = + PresenceWindowBytes div DefaultBlockSize.uint64 PresenceWindowThreshold*: float = 0.75 PresenceBroadcastIntervalMin*: Duration = 5.seconds PresenceBroadcastIntervalMax*: Duration = 10.seconds diff --git a/storage/blockexchange/engine/engine.nim b/storage/blockexchange/engine/engine.nim index cc190844..4b48e420 100644 --- a/storage/blockexchange/engine/engine.nim +++ b/storage/blockexchange/engine/engine.nim @@ -941,7 +941,17 @@ proc wantListHandler*( if peerCtx.isNil: return - var presence: seq[BlockPresence] + if peerCtx.wantListBusy: + debug "Dropping want list, handler already in flight for peer", peer + return + + peerCtx.wantListBusy = true + defer: + peerCtx.wantListBusy = false + + var + presence: seq[BlockPresence] + iterBudget: uint64 = MaxRangeIterationsPerMessage try: for e in wantList.entries: @@ -958,8 +968,10 @@ proc wantListHandler*( peer = peer, treeCid = treeCid, count = count, max = MaxPresenceWindowBlocks continue + let effectiveCount = min(count, iterBudget) + trace "Processing range query", - treeCid = treeCid, start = startIdx, count = count + treeCid = treeCid, start = startIdx, count = effectiveCount let runtimeQuota = 100.milliseconds var @@ -968,7 +980,7 @@ proc wantListHandler*( inRange = false lastIdle = Moment.now() - for i in 0'u64 ..< count: + for i in 0'u64 ..< effectiveCount: if (Moment.now() - lastIdle) >= runtimeQuota: await idleAsync() lastIdle = Moment.now() @@ -990,7 +1002,9 @@ proc wantListHandler*( inRange = false if inRange: - ranges.add((rangeStart, (startIdx + count) - rangeStart)) + ranges.add((rangeStart, (startIdx + effectiveCount) - rangeStart)) + + iterBudget -= effectiveCount if ranges.len > 0: trace "Have blocks in range", treeCid = treeCid, ranges = ranges @@ -1003,7 +1017,8 @@ proc wantListHandler*( ) ) else: - trace "Don't have range", treeCid = treeCid, start = startIdx, count = count + trace "Don't have range", + treeCid = treeCid, start = startIdx, count = effectiveCount if e.sendDontHave: presence.add( BlockPresence( diff --git a/storage/blockexchange/peers/peercontext.nim b/storage/blockexchange/peers/peercontext.nim index d81a4161..fcfaf054 100644 --- a/storage/blockexchange/peers/peercontext.nim +++ b/storage/blockexchange/peers/peercontext.nim @@ -48,6 +48,7 @@ static: type PeerContext* = ref object of RootObj id*: PeerId stats*: PeerPerfStats + wantListBusy*: bool proc new*(T: type PeerContext, id: PeerId): PeerContext = PeerContext(id: id, stats: PeerPerfStats.new()) diff --git a/storage/blockexchange/protocol/constants.nim b/storage/blockexchange/protocol/constants.nim index d0448bbb..6e6ebf50 100644 --- a/storage/blockexchange/protocol/constants.nim +++ b/storage/blockexchange/protocol/constants.nim @@ -22,6 +22,11 @@ const TargetBatchBytes*: uint32 = 1024 * 1024 MinBatchSize*: uint32 = 1 + # caps the number of entries decoded from a single WantList/blockPresences + # repeated field, independent of MaxMessageSize, to bound per-message CPU/disk work + MaxWantListEntries*: int = 1024 + MaxBlockPresenceEntries*: int = 1024 + MaxMetadataSize*: uint32 = 4 * 1024 * 1024 MaxWantBlocksResponseBytes*: uint32 = 4 + MaxMetadataSize + TargetBatchBytes MaxBlocksPerBatch*: uint32 = TargetBatchBytes div MinBlockSize.uint32 diff --git a/storage/blockexchange/protocol/message.nim b/storage/blockexchange/protocol/message.nim index 0c49cd4b..ffd2a2b6 100644 --- a/storage/blockexchange/protocol/message.nim +++ b/storage/blockexchange/protocol/message.nim @@ -11,6 +11,7 @@ import pkg/questionable import ../../merkletree import ../../blocktype +import ./constants type WantType* = enum @@ -147,6 +148,8 @@ proc decode*(_: type WantList, pb: ProtoBuffer): ProtoResult[WantList] = field: uint64 sublist: seq[seq[byte]] if ?pb.getRepeatedField(1, sublist): + if sublist.len > MaxWantListEntries: + return err(ProtoError.BufferOverflow) for item in sublist: value.entries.add(?WantListEntry.decode(initProtoBuffer(item))) if ?pb.getField(2, field): @@ -183,6 +186,8 @@ proc protobufDecode*(_: type Message, msg: seq[byte]): ProtoResult[Message] = if ?pb.getField(1, ipb): value.wantList = ?WantList.decode(ipb) if ?pb.getRepeatedField(4, sublist): + if sublist.len > MaxBlockPresenceEntries: + return err(ProtoError.BufferOverflow) for item in sublist: value.blockPresences.add(?BlockPresence.decode(initProtoBuffer(item))) ok(value) diff --git a/tests/storage/blockexchange/protocol/testmessage.nim b/tests/storage/blockexchange/protocol/testmessage.nim index 98e0a27d..942e2fa0 100644 --- a/tests/storage/blockexchange/protocol/testmessage.nim +++ b/tests/storage/blockexchange/protocol/testmessage.nim @@ -1,5 +1,8 @@ +import std/sequtils + import pkg/unittest2 +import pkg/storage/blockexchange/protocol/constants import pkg/storage/blockexchange/protocol/message import ../../examples @@ -148,6 +151,28 @@ suite "WantList protobuf encoding": check res.get.entries[1].sendDontHave == true check res.get.full == true + test "Should reject WantList with too many entries": + let + treeCid = Cid.example + wantList = WantList( + entries: newSeqWith( + MaxWantListEntries + 1, + WantListEntry(address: BlockAddress(treeCid: treeCid, index: 0)), + ), + full: false, + ) + + var buffer = initProtoBuffer() + buffer.write(1, wantList) + buffer.finish() + + var decoded: ProtoBuffer + check buffer.getField(1, decoded).isOk + + let res = WantList.decode(decoded) + check res.isErr + check res.error == ProtoError.BufferOverflow + suite "BlockPresence protobuf encoding": test "Should encode and decode BlockPresence with DontHave": let @@ -273,3 +298,23 @@ suite "Full Message protobuf encoding": check decoded.get.blockPresences[0].kind == BlockPresenceType.HaveRange check decoded.get.blockPresences[0].ranges.len == 1 check decoded.get.blockPresences[0].ranges[0].count == 500 + + test "Should reject Message with too many blockPresences": + let + treeCid = Cid.example + msg = Message( + wantList: WantList(entries: @[], full: false), + blockPresences: newSeqWith( + MaxBlockPresenceEntries + 1, + BlockPresence( + address: BlockAddress(treeCid: treeCid, index: 0), + kind: BlockPresenceType.DontHave, + ranges: @[], + ), + ), + ) + encoded = msg.protobufEncode() + decoded = Message.protobufDecode(encoded) + + check decoded.isErr + check decoded.error == ProtoError.BufferOverflow