From 02c96fc003fe34510a9447412808c534b6a79c87 Mon Sep 17 00:00:00 2001 From: Jacek Sieka Date: Wed, 1 May 2024 18:38:24 +0200 Subject: [PATCH] Improve memory efficiency of seen cache (#1073) --- libp2p/protocols/pubsub/floodsub.nim | 47 ++++++++++----- libp2p/protocols/pubsub/gossipsub.nim | 12 ++-- .../protocols/pubsub/gossipsub/behavior.nim | 2 +- libp2p/protocols/pubsub/gossipsub/types.nim | 2 +- libp2p/protocols/pubsub/rpc/messages.nim | 6 ++ libp2p/protocols/pubsub/timedcache.nim | 59 ++++++++++++++----- tests/pubsub/testgossipsub.nim | 4 +- tests/pubsub/testtimedcache.nim | 19 ++++++ 8 files changed, 109 insertions(+), 42 deletions(-) diff --git a/libp2p/protocols/pubsub/floodsub.nim b/libp2p/protocols/pubsub/floodsub.nim index 4a7d1a13b..0b4d886f4 100644 --- a/libp2p/protocols/pubsub/floodsub.nim +++ b/libp2p/protocols/pubsub/floodsub.nim @@ -16,6 +16,7 @@ import ./pubsub, ./timedcache, ./peertable, ./rpc/[message, messages, protobuf], + nimcrypto/[hash, sha2], ../../crypto/crypto, ../../stream/connection, ../../peerid, @@ -32,20 +33,29 @@ const FloodSubCodec* = "/floodsub/1.0.0" type FloodSub* {.public.} = ref object of PubSub floodsub*: PeerTable # topic to remote peer map - seen*: TimedCache[MessageId] # message id:s already seen on the network - seenSalt*: seq[byte] + seen*: TimedCache[SaltedId] + # Early filter for messages recently observed on the network + # We use a salted id because the messages in this cache have not yet + # been validated meaning that an attacker has greater control over the + # hash key and therefore could poison the table + seenSalt*: sha256 + # The salt in this case is a partially updated SHA256 context pre-seeded + # with some random data -proc hasSeen*(f: FloodSub, msgId: MessageId): bool = - f.seenSalt & msgId in f.seen +proc salt*(f: FloodSub, msgId: MessageId): SaltedId = + var tmp = f.seenSalt + tmp.update(msgId) + SaltedId(data: tmp.finish()) -proc addSeen*(f: FloodSub, msgId: MessageId): bool = - # Salting the seen hash helps avoid attacks against the hash function used - # in the nim hash table +proc hasSeen*(f: FloodSub, saltedId: SaltedId): bool = + saltedId in f.seen + +proc addSeen*(f: FloodSub, saltedId: SaltedId): bool = # Return true if the message has already been seen - f.seen.put(f.seenSalt & msgId) + f.seen.put(saltedId) -proc firstSeen*(f: FloodSub, msgId: MessageId): Moment = - f.seen.addedAt(f.seenSalt & msgId) +proc firstSeen*(f: FloodSub, saltedId: SaltedId): Moment = + f.seen.addedAt(saltedId) proc handleSubscribe*(f: FloodSub, peer: PubSubPeer, @@ -117,9 +127,11 @@ method rpcHandler*(f: FloodSub, # TODO: descore peers due to error during message validation (malicious?) continue - let msgId = msgIdResult.get + let + msgId = msgIdResult.get + saltedId = f.salt(msgId) - if f.addSeen(msgId): + if f.addSeen(saltedId): trace "Dropping already-seen message", msgId, peer continue @@ -216,7 +228,7 @@ method publish*(f: FloodSub, trace "Created new message", msg = shortLog(msg), peers = peers.len, topic, msgId - if f.addSeen(msgId): + if f.addSeen(f.salt(msgId)): # custom msgid providers might cause this trace "Dropping already-seen message", msgId, topic return 0 @@ -234,8 +246,11 @@ method publish*(f: FloodSub, method initPubSub*(f: FloodSub) {.raises: [InitializationError].} = procCall PubSub(f).initPubSub() - f.seen = TimedCache[MessageId].init(2.minutes) - f.seenSalt = newSeqUninitialized[byte](sizeof(Hash)) - hmacDrbgGenerate(f.rng[], f.seenSalt) + f.seen = TimedCache[SaltedId].init(2.minutes) + f.seenSalt.init() + + var tmp: array[32, byte] + hmacDrbgGenerate(f.rng[], tmp) + f.seenSalt.update(tmp) f.init() diff --git a/libp2p/protocols/pubsub/gossipsub.nim b/libp2p/protocols/pubsub/gossipsub.nim index c8deca4fb..948ca6382 100644 --- a/libp2p/protocols/pubsub/gossipsub.nim +++ b/libp2p/protocols/pubsub/gossipsub.nim @@ -360,7 +360,7 @@ proc handleControl(g: GossipSub, peer: PubSubPeer, control: ControlMessage) = proc validateAndRelay(g: GossipSub, msg: Message, - msgId, msgIdSalted: MessageId, + msgId: MessageId, msgIdSalted: SaltedId, peer: PubSubPeer) {.async.} = try: let validation = await g.validate(msg) @@ -508,12 +508,12 @@ method rpcHandler*(g: GossipSub, let msgId = msgIdResult.get - msgIdSalted = msgId & g.seenSalt + msgIdSalted = g.salt(msgId) topic = msg.topic # addSeen adds salt to msgId to avoid # remote attacking the hash function - if g.addSeen(msgId): + if g.addSeen(msgIdSalted): trace "Dropping already-seen message", msgId = shortLog(msgId), peer var alreadyReceived = false @@ -523,7 +523,7 @@ method rpcHandler*(g: GossipSub, alreadyReceived = true if not alreadyReceived: - let delay = Moment.now() - g.firstSeen(msgId) + let delay = Moment.now() - g.firstSeen(msgIdSalted) g.rewardDelivered(peer, topic, false, delay) libp2p_gossipsub_duplicate.inc() @@ -690,7 +690,7 @@ method publish*(g: GossipSub, trace "Created new message", msg = shortLog(msg), peers = peers.len - if g.addSeen(msgId): + if g.addSeen(g.salt(msgId)): # custom msgid providers might cause this trace "Dropping already-seen message" return 0 @@ -779,7 +779,7 @@ method initPubSub*(g: GossipSub) raise newException(InitializationError, $validationRes.error) # init the floodsub stuff here, we customize timedcache in gossip! - g.seen = TimedCache[MessageId].init(g.parameters.seenTTL) + g.seen = TimedCache[SaltedId].init(g.parameters.seenTTL) # init gossip stuff g.mcache = MCache.init(g.parameters.historyGossip, g.parameters.historyLength) diff --git a/libp2p/protocols/pubsub/gossipsub/behavior.nim b/libp2p/protocols/pubsub/gossipsub/behavior.nim index 302695abf..372a87921 100644 --- a/libp2p/protocols/pubsub/gossipsub/behavior.nim +++ b/libp2p/protocols/pubsub/gossipsub/behavior.nim @@ -251,7 +251,7 @@ proc handleIHave*(g: GossipSub, peer, topicID = ihave.topicID, msgs = ihave.messageIDs if ihave.topicID in g.topics: for msgId in ihave.messageIDs: - if not g.hasSeen(msgId): + if not g.hasSeen(g.salt(msgId)): if peer.iHaveBudget <= 0: break elif msgId notin res.messageIDs: diff --git a/libp2p/protocols/pubsub/gossipsub/types.nim b/libp2p/protocols/pubsub/gossipsub/types.nim index e4efb8d7c..ef8a48745 100644 --- a/libp2p/protocols/pubsub/gossipsub/types.nim +++ b/libp2p/protocols/pubsub/gossipsub/types.nim @@ -156,7 +156,7 @@ type maxNumElementsInNonPriorityQueue*: int BackoffTable* = Table[string, Table[PeerId, Moment]] - ValidationSeenTable* = Table[MessageId, HashSet[PubSubPeer]] + ValidationSeenTable* = Table[SaltedId, HashSet[PubSubPeer]] RoutingRecordsPair* = tuple[id: PeerId, record: Option[PeerRecord]] RoutingRecordsHandler* = diff --git a/libp2p/protocols/pubsub/rpc/messages.nim b/libp2p/protocols/pubsub/rpc/messages.nim index 0dae3ed34..12423fab6 100644 --- a/libp2p/protocols/pubsub/rpc/messages.nim +++ b/libp2p/protocols/pubsub/rpc/messages.nim @@ -37,6 +37,12 @@ type MessageId* = seq[byte] + SaltedId* = object + # Salted hash of message ID - used instead of the ordinary message ID to + # avoid hash poisoning attacks and to make memory usage more predictable + # with respect to the variable-length message id + data*: MDigest[256] + Message* = object fromPeer*: PeerId data*: seq[byte] diff --git a/libp2p/protocols/pubsub/timedcache.nim b/libp2p/protocols/pubsub/timedcache.nim index fbac8db6b..ca08f0aef 100644 --- a/libp2p/protocols/pubsub/timedcache.nim +++ b/libp2p/protocols/pubsub/timedcache.nim @@ -9,12 +9,13 @@ {.push raises: [].} -import std/[tables] - +import std/[hashes, sets] import chronos/timer, stew/results import ../../utility +export results + const Timeout* = 10.seconds # default timeout in ms type @@ -26,20 +27,38 @@ type TimedCache*[K] = object of RootObj head, tail: TimedEntry[K] # nim linked list doesn't allow inserting at pos - entries: Table[K, TimedEntry[K]] + entries: HashSet[TimedEntry[K]] timeout: Duration +func `==`*[E](a, b: TimedEntry[E]): bool = + if isNil(a) == isNil(b): + isNil(a) or a.key == b.key + else: + false + +func hash*(a: TimedEntry): Hash = + if isNil(a): + default(Hash) + else: + hash(a[].key) + func expire*(t: var TimedCache, now: Moment = Moment.now()) = while t.head != nil and t.head.expiresAt < now: - t.entries.del(t.head.key) + t.entries.excl(t.head) t.head.prev = nil t.head = t.head.next if t.head == nil: t.tail = nil func del*[K](t: var TimedCache[K], key: K): Opt[TimedEntry[K]] = # Removes existing key from cache, returning the previous value if present - var item: TimedEntry[K] - if t.entries.pop(key, item): + let tmp = TimedEntry[K](key: key) + if tmp in t.entries: + let item = try: + t.entries[tmp] # use the shared instance in the set + except KeyError: + raiseAssert "just checked" + t.entries.excl(item) + if t.head == item: t.head = item.next if t.tail == item: t.tail = item.prev @@ -55,14 +74,14 @@ func put*[K](t: var TimedCache[K], k: K, now = Moment.now()): bool = # refreshed. t.expire(now) - var previous = t.del(k) # Refresh existing item - - var addedAt = now - previous.withValue(previous): - addedAt = previous.addedAt + let + previous = t.del(k) # Refresh existing item + addedAt = if previous.isSome(): + previous[].addedAt + else: + now let node = TimedEntry[K](key: k, addedAt: addedAt, expiresAt: now + t.timeout) - if t.head == nil: t.tail = node t.head = t.tail @@ -83,16 +102,24 @@ func put*[K](t: var TimedCache[K], k: K, now = Moment.now()): bool = if cur == t.tail: t.tail = node - t.entries[k] = node + t.entries.incl(node) previous.isSome() func contains*[K](t: TimedCache[K], k: K): bool = - k in t.entries + let tmp = TimedEntry[K](key: k) + tmp in t.entries -func addedAt*[K](t: TimedCache[K], k: K): Moment = - t.entries.getOrDefault(k).addedAt +func addedAt*[K](t: var TimedCache[K], k: K): Moment = + let tmp = TimedEntry[K](key: k) + try: + if tmp in t.entries: # raising is slow + # Use shared instance from entries + return t.entries[tmp][].addedAt + except KeyError: + raiseAssert "just checked" + default(Moment) func init*[K](T: type TimedCache[K], timeout: Duration = Timeout): T = T( diff --git a/tests/pubsub/testgossipsub.nim b/tests/pubsub/testgossipsub.nim index 514e64dab..fada9f81f 100644 --- a/tests/pubsub/testgossipsub.nim +++ b/tests/pubsub/testgossipsub.nim @@ -569,8 +569,8 @@ suite "GossipSub": proc slowValidator(topic: string, message: Message): Future[ValidationResult] {.async.} = await cRelayed # Empty A & C caches to detect duplicates - gossip1.seen = TimedCache[MessageId].init() - gossip3.seen = TimedCache[MessageId].init() + gossip1.seen = TimedCache[SaltedId].init() + gossip3.seen = TimedCache[SaltedId].init() let msgId = toSeq(gossip2.validationSeen.keys)[0] checkUntilTimeout(try: gossip2.validationSeen[msgId].len > 0 except: false) result = ValidationResult.Accept diff --git a/tests/pubsub/testtimedcache.nim b/tests/pubsub/testtimedcache.nim index 3fcbf28f2..917ddfe26 100644 --- a/tests/pubsub/testtimedcache.nim +++ b/tests/pubsub/testtimedcache.nim @@ -24,6 +24,8 @@ suite "TimedCache": 2 in cache 3 in cache + cache.addedAt(2) == now + 3.seconds + check: cache.put(2, now + 7.seconds) # refreshes 2 not cache.put(4, now + 12.seconds) # expires 3 @@ -33,6 +35,23 @@ suite "TimedCache": 3 notin cache 4 in cache + check: + cache.del(4).isSome() + 4 notin cache + check: not cache.put(100, now + 100.seconds) # expires everything 100 in cache + 2 notin cache + + test "enough items to force cache heap storage growth": + var cache = TimedCache[int].init(5.seconds) + + let now = Moment.now() + for i in 101..100000: + check: + not cache.put(i, now) + + for i in 101..100000: + check: + i in cache