Improve memory efficiency of seen cache (#1073)

This commit is contained in:
Jacek Sieka 2024-05-01 18:38:24 +02:00 committed by GitHub
parent c4da9be32c
commit 02c96fc003
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 109 additions and 42 deletions

View File

@ -16,6 +16,7 @@ import ./pubsub,
./timedcache, ./timedcache,
./peertable, ./peertable,
./rpc/[message, messages, protobuf], ./rpc/[message, messages, protobuf],
nimcrypto/[hash, sha2],
../../crypto/crypto, ../../crypto/crypto,
../../stream/connection, ../../stream/connection,
../../peerid, ../../peerid,
@ -32,20 +33,29 @@ const FloodSubCodec* = "/floodsub/1.0.0"
type type
FloodSub* {.public.} = ref object of PubSub FloodSub* {.public.} = ref object of PubSub
floodsub*: PeerTable # topic to remote peer map floodsub*: PeerTable # topic to remote peer map
seen*: TimedCache[MessageId] # message id:s already seen on the network seen*: TimedCache[SaltedId]
seenSalt*: seq[byte] # Early filter for messages recently observed on the network
# We use a salted id because the messages in this cache have not yet
# been validated meaning that an attacker has greater control over the
# hash key and therefore could poison the table
seenSalt*: sha256
# The salt in this case is a partially updated SHA256 context pre-seeded
# with some random data
proc hasSeen*(f: FloodSub, msgId: MessageId): bool = proc salt*(f: FloodSub, msgId: MessageId): SaltedId =
f.seenSalt & msgId in f.seen var tmp = f.seenSalt
tmp.update(msgId)
SaltedId(data: tmp.finish())
proc addSeen*(f: FloodSub, msgId: MessageId): bool = proc hasSeen*(f: FloodSub, saltedId: SaltedId): bool =
# Salting the seen hash helps avoid attacks against the hash function used saltedId in f.seen
# in the nim hash table
proc addSeen*(f: FloodSub, saltedId: SaltedId): bool =
# Return true if the message has already been seen # Return true if the message has already been seen
f.seen.put(f.seenSalt & msgId) f.seen.put(saltedId)
proc firstSeen*(f: FloodSub, msgId: MessageId): Moment = proc firstSeen*(f: FloodSub, saltedId: SaltedId): Moment =
f.seen.addedAt(f.seenSalt & msgId) f.seen.addedAt(saltedId)
proc handleSubscribe*(f: FloodSub, proc handleSubscribe*(f: FloodSub,
peer: PubSubPeer, peer: PubSubPeer,
@ -117,9 +127,11 @@ method rpcHandler*(f: FloodSub,
# TODO: descore peers due to error during message validation (malicious?) # TODO: descore peers due to error during message validation (malicious?)
continue continue
let msgId = msgIdResult.get let
msgId = msgIdResult.get
saltedId = f.salt(msgId)
if f.addSeen(msgId): if f.addSeen(saltedId):
trace "Dropping already-seen message", msgId, peer trace "Dropping already-seen message", msgId, peer
continue continue
@ -216,7 +228,7 @@ method publish*(f: FloodSub,
trace "Created new message", trace "Created new message",
msg = shortLog(msg), peers = peers.len, topic, msgId msg = shortLog(msg), peers = peers.len, topic, msgId
if f.addSeen(msgId): if f.addSeen(f.salt(msgId)):
# custom msgid providers might cause this # custom msgid providers might cause this
trace "Dropping already-seen message", msgId, topic trace "Dropping already-seen message", msgId, topic
return 0 return 0
@ -234,8 +246,11 @@ method publish*(f: FloodSub,
method initPubSub*(f: FloodSub) method initPubSub*(f: FloodSub)
{.raises: [InitializationError].} = {.raises: [InitializationError].} =
procCall PubSub(f).initPubSub() procCall PubSub(f).initPubSub()
f.seen = TimedCache[MessageId].init(2.minutes) f.seen = TimedCache[SaltedId].init(2.minutes)
f.seenSalt = newSeqUninitialized[byte](sizeof(Hash)) f.seenSalt.init()
hmacDrbgGenerate(f.rng[], f.seenSalt)
var tmp: array[32, byte]
hmacDrbgGenerate(f.rng[], tmp)
f.seenSalt.update(tmp)
f.init() f.init()

View File

@ -360,7 +360,7 @@ proc handleControl(g: GossipSub, peer: PubSubPeer, control: ControlMessage) =
proc validateAndRelay(g: GossipSub, proc validateAndRelay(g: GossipSub,
msg: Message, msg: Message,
msgId, msgIdSalted: MessageId, msgId: MessageId, msgIdSalted: SaltedId,
peer: PubSubPeer) {.async.} = peer: PubSubPeer) {.async.} =
try: try:
let validation = await g.validate(msg) let validation = await g.validate(msg)
@ -508,12 +508,12 @@ method rpcHandler*(g: GossipSub,
let let
msgId = msgIdResult.get msgId = msgIdResult.get
msgIdSalted = msgId & g.seenSalt msgIdSalted = g.salt(msgId)
topic = msg.topic topic = msg.topic
# addSeen adds salt to msgId to avoid # addSeen adds salt to msgId to avoid
# remote attacking the hash function # remote attacking the hash function
if g.addSeen(msgId): if g.addSeen(msgIdSalted):
trace "Dropping already-seen message", msgId = shortLog(msgId), peer trace "Dropping already-seen message", msgId = shortLog(msgId), peer
var alreadyReceived = false var alreadyReceived = false
@ -523,7 +523,7 @@ method rpcHandler*(g: GossipSub,
alreadyReceived = true alreadyReceived = true
if not alreadyReceived: if not alreadyReceived:
let delay = Moment.now() - g.firstSeen(msgId) let delay = Moment.now() - g.firstSeen(msgIdSalted)
g.rewardDelivered(peer, topic, false, delay) g.rewardDelivered(peer, topic, false, delay)
libp2p_gossipsub_duplicate.inc() libp2p_gossipsub_duplicate.inc()
@ -690,7 +690,7 @@ method publish*(g: GossipSub,
trace "Created new message", msg = shortLog(msg), peers = peers.len trace "Created new message", msg = shortLog(msg), peers = peers.len
if g.addSeen(msgId): if g.addSeen(g.salt(msgId)):
# custom msgid providers might cause this # custom msgid providers might cause this
trace "Dropping already-seen message" trace "Dropping already-seen message"
return 0 return 0
@ -779,7 +779,7 @@ method initPubSub*(g: GossipSub)
raise newException(InitializationError, $validationRes.error) raise newException(InitializationError, $validationRes.error)
# init the floodsub stuff here, we customize timedcache in gossip! # init the floodsub stuff here, we customize timedcache in gossip!
g.seen = TimedCache[MessageId].init(g.parameters.seenTTL) g.seen = TimedCache[SaltedId].init(g.parameters.seenTTL)
# init gossip stuff # init gossip stuff
g.mcache = MCache.init(g.parameters.historyGossip, g.parameters.historyLength) g.mcache = MCache.init(g.parameters.historyGossip, g.parameters.historyLength)

View File

@ -251,7 +251,7 @@ proc handleIHave*(g: GossipSub,
peer, topicID = ihave.topicID, msgs = ihave.messageIDs peer, topicID = ihave.topicID, msgs = ihave.messageIDs
if ihave.topicID in g.topics: if ihave.topicID in g.topics:
for msgId in ihave.messageIDs: for msgId in ihave.messageIDs:
if not g.hasSeen(msgId): if not g.hasSeen(g.salt(msgId)):
if peer.iHaveBudget <= 0: if peer.iHaveBudget <= 0:
break break
elif msgId notin res.messageIDs: elif msgId notin res.messageIDs:

View File

@ -156,7 +156,7 @@ type
maxNumElementsInNonPriorityQueue*: int maxNumElementsInNonPriorityQueue*: int
BackoffTable* = Table[string, Table[PeerId, Moment]] BackoffTable* = Table[string, Table[PeerId, Moment]]
ValidationSeenTable* = Table[MessageId, HashSet[PubSubPeer]] ValidationSeenTable* = Table[SaltedId, HashSet[PubSubPeer]]
RoutingRecordsPair* = tuple[id: PeerId, record: Option[PeerRecord]] RoutingRecordsPair* = tuple[id: PeerId, record: Option[PeerRecord]]
RoutingRecordsHandler* = RoutingRecordsHandler* =

View File

@ -37,6 +37,12 @@ type
MessageId* = seq[byte] MessageId* = seq[byte]
SaltedId* = object
# Salted hash of message ID - used instead of the ordinary message ID to
# avoid hash poisoning attacks and to make memory usage more predictable
# with respect to the variable-length message id
data*: MDigest[256]
Message* = object Message* = object
fromPeer*: PeerId fromPeer*: PeerId
data*: seq[byte] data*: seq[byte]

View File

@ -9,12 +9,13 @@
{.push raises: [].} {.push raises: [].}
import std/[tables] import std/[hashes, sets]
import chronos/timer, stew/results import chronos/timer, stew/results
import ../../utility import ../../utility
export results
const Timeout* = 10.seconds # default timeout const Timeout* = 10.seconds # default timeout
type type
@ -26,20 +27,38 @@ type
TimedCache*[K] = object of RootObj TimedCache*[K] = object of RootObj
head, tail: TimedEntry[K] # nim linked list doesn't allow inserting at pos head, tail: TimedEntry[K] # nim linked list doesn't allow inserting at pos
entries: Table[K, TimedEntry[K]] entries: HashSet[TimedEntry[K]]
timeout: Duration timeout: Duration
func `==`*[E](a, b: TimedEntry[E]): bool =
  ## Equality for cache entries, decided by `key` alone.
  ##
  ## Two nil entries compare equal, a nil and a non-nil entry never do, and
  ## two non-nil entries are equal exactly when their keys are equal. The
  ## other fields of the entry take no part in the comparison.
  let
    aMissing = isNil(a)
    bMissing = isNil(b)
  if aMissing != bMissing:
    # Exactly one side is nil.
    false
  else:
    # Either both are nil (equal) or both are set (compare keys).
    aMissing or a.key == b.key
func hash*(a: TimedEntry): Hash =
  ## Hash of a cache entry, derived from its `key` only so that it agrees
  ## with the key-based `==` on entries.
  ##
  ## A nil entry hashes to `default(Hash)`.
  if not isNil(a):
    hash(a[].key)
  else:
    default(Hash)
func expire*(t: var TimedCache, now: Moment = Moment.now()) = func expire*(t: var TimedCache, now: Moment = Moment.now()) =
while t.head != nil and t.head.expiresAt < now: while t.head != nil and t.head.expiresAt < now:
t.entries.del(t.head.key) t.entries.excl(t.head)
t.head.prev = nil t.head.prev = nil
t.head = t.head.next t.head = t.head.next
if t.head == nil: t.tail = nil if t.head == nil: t.tail = nil
func del*[K](t: var TimedCache[K], key: K): Opt[TimedEntry[K]] = func del*[K](t: var TimedCache[K], key: K): Opt[TimedEntry[K]] =
# Removes existing key from cache, returning the previous value if present # Removes existing key from cache, returning the previous value if present
var item: TimedEntry[K] let tmp = TimedEntry[K](key: key)
if t.entries.pop(key, item): if tmp in t.entries:
let item = try:
t.entries[tmp] # use the shared instance in the set
except KeyError:
raiseAssert "just checked"
t.entries.excl(item)
if t.head == item: t.head = item.next if t.head == item: t.head = item.next
if t.tail == item: t.tail = item.prev if t.tail == item: t.tail = item.prev
@ -55,14 +74,14 @@ func put*[K](t: var TimedCache[K], k: K, now = Moment.now()): bool =
# refreshed. # refreshed.
t.expire(now) t.expire(now)
var previous = t.del(k) # Refresh existing item let
previous = t.del(k) # Refresh existing item
var addedAt = now addedAt = if previous.isSome():
previous.withValue(previous): previous[].addedAt
addedAt = previous.addedAt else:
now
let node = TimedEntry[K](key: k, addedAt: addedAt, expiresAt: now + t.timeout) let node = TimedEntry[K](key: k, addedAt: addedAt, expiresAt: now + t.timeout)
if t.head == nil: if t.head == nil:
t.tail = node t.tail = node
t.head = t.tail t.head = t.tail
@ -83,16 +102,24 @@ func put*[K](t: var TimedCache[K], k: K, now = Moment.now()): bool =
if cur == t.tail: if cur == t.tail:
t.tail = node t.tail = node
t.entries[k] = node t.entries.incl(node)
previous.isSome() previous.isSome()
func contains*[K](t: TimedCache[K], k: K): bool = func contains*[K](t: TimedCache[K], k: K): bool =
k in t.entries let tmp = TimedEntry[K](key: k)
tmp in t.entries
func addedAt*[K](t: TimedCache[K], k: K): Moment = func addedAt*[K](t: var TimedCache[K], k: K): Moment =
t.entries.getOrDefault(k).addedAt let tmp = TimedEntry[K](key: k)
try:
if tmp in t.entries: # raising is slow
# Use shared instance from entries
return t.entries[tmp][].addedAt
except KeyError:
raiseAssert "just checked"
default(Moment)
func init*[K](T: type TimedCache[K], timeout: Duration = Timeout): T = func init*[K](T: type TimedCache[K], timeout: Duration = Timeout): T =
T( T(

View File

@ -569,8 +569,8 @@ suite "GossipSub":
proc slowValidator(topic: string, message: Message): Future[ValidationResult] {.async.} = proc slowValidator(topic: string, message: Message): Future[ValidationResult] {.async.} =
await cRelayed await cRelayed
# Empty A & C caches to detect duplicates # Empty A & C caches to detect duplicates
gossip1.seen = TimedCache[MessageId].init() gossip1.seen = TimedCache[SaltedId].init()
gossip3.seen = TimedCache[MessageId].init() gossip3.seen = TimedCache[SaltedId].init()
let msgId = toSeq(gossip2.validationSeen.keys)[0] let msgId = toSeq(gossip2.validationSeen.keys)[0]
checkUntilTimeout(try: gossip2.validationSeen[msgId].len > 0 except: false) checkUntilTimeout(try: gossip2.validationSeen[msgId].len > 0 except: false)
result = ValidationResult.Accept result = ValidationResult.Accept

View File

@ -24,6 +24,8 @@ suite "TimedCache":
2 in cache 2 in cache
3 in cache 3 in cache
cache.addedAt(2) == now + 3.seconds
check: check:
cache.put(2, now + 7.seconds) # refreshes 2 cache.put(2, now + 7.seconds) # refreshes 2
not cache.put(4, now + 12.seconds) # expires 3 not cache.put(4, now + 12.seconds) # expires 3
@ -33,6 +35,23 @@ suite "TimedCache":
3 notin cache 3 notin cache
4 in cache 4 in cache
check:
cache.del(4).isSome()
4 notin cache
check: check:
not cache.put(100, now + 100.seconds) # expires everything not cache.put(100, now + 100.seconds) # expires everything
100 in cache 100 in cache
2 notin cache
test "enough items to force cache heap storage growth":
  # Every key is inserted with the same timestamp, so no entry expires
  # while the cache fills up; afterwards each key must still be found.
  const
    firstKey = 101
    lastKey = 100000
  var cache = TimedCache[int].init(5.seconds)
  let now = Moment.now()

  for key in firstKey .. lastKey:
    # `put` returns true only when the key was already present.
    check not cache.put(key, now)

  for key in firstKey .. lastKey:
    check key in cache