Improve memory efficiency of seen cache (#1073)
parent c4da9be32c
commit 02c96fc003

@@ -16,6 +16,7 @@ import ./pubsub,
   ./timedcache,
   ./peertable,
   ./rpc/[message, messages, protobuf],
+  nimcrypto/[hash, sha2],
   ../../crypto/crypto,
   ../../stream/connection,
   ../../peerid,
@@ -32,20 +33,29 @@ const FloodSubCodec* = "/floodsub/1.0.0"
 type
   FloodSub* {.public.} = ref object of PubSub
     floodsub*: PeerTable # topic to remote peer map
-    seen*: TimedCache[MessageId] # message id:s already seen on the network
-    seenSalt*: seq[byte]
+    seen*: TimedCache[SaltedId]
+      # Early filter for messages recently observed on the network
+      # We use a salted id because the messages in this cache have not yet
+      # been validated meaning that an attacker has greater control over the
+      # hash key and therefore could poison the table
+    seenSalt*: sha256
+      # The salt in this case is a partially updated SHA256 context pre-seeded
+      # with some random data
 
-proc hasSeen*(f: FloodSub, msgId: MessageId): bool =
-  f.seenSalt & msgId in f.seen
+proc salt*(f: FloodSub, msgId: MessageId): SaltedId =
+  var tmp = f.seenSalt
+  tmp.update(msgId)
+  SaltedId(data: tmp.finish())
 
-proc addSeen*(f: FloodSub, msgId: MessageId): bool =
-  # Salting the seen hash helps avoid attacks against the hash function used
-  # in the nim hash table
-  # Return true if the message has already been seen
-  f.seen.put(f.seenSalt & msgId)
+proc hasSeen*(f: FloodSub, saltedId: SaltedId): bool =
+  saltedId in f.seen
 
-proc firstSeen*(f: FloodSub, msgId: MessageId): Moment =
-  f.seen.addedAt(f.seenSalt & msgId)
+proc addSeen*(f: FloodSub, saltedId: SaltedId): bool =
+  # Return true if the message has already been seen
+  f.seen.put(saltedId)
+
+proc firstSeen*(f: FloodSub, saltedId: SaltedId): Moment =
+  f.seen.addedAt(saltedId)
 
 proc handleSubscribe*(f: FloodSub,
                       peer: PubSubPeer,
@@ -117,9 +127,11 @@ method rpcHandler*(f: FloodSub,
       # TODO: descore peers due to error during message validation (malicious?)
       continue
 
-    let msgId = msgIdResult.get
+    let
+      msgId = msgIdResult.get
+      saltedId = f.salt(msgId)
 
-    if f.addSeen(msgId):
+    if f.addSeen(saltedId):
       trace "Dropping already-seen message", msgId, peer
       continue
 
@@ -216,7 +228,7 @@ method publish*(f: FloodSub,
   trace "Created new message",
     msg = shortLog(msg), peers = peers.len, topic, msgId
 
-  if f.addSeen(msgId):
+  if f.addSeen(f.salt(msgId)):
     # custom msgid providers might cause this
     trace "Dropping already-seen message", msgId, topic
     return 0
@@ -234,8 +246,11 @@ method publish*(f: FloodSub,
 method initPubSub*(f: FloodSub)
   {.raises: [InitializationError].} =
   procCall PubSub(f).initPubSub()
-  f.seen = TimedCache[MessageId].init(2.minutes)
-  f.seenSalt = newSeqUninitialized[byte](sizeof(Hash))
-  hmacDrbgGenerate(f.rng[], f.seenSalt)
+  f.seen = TimedCache[SaltedId].init(2.minutes)
+  f.seenSalt.init()
+
+  var tmp: array[32, byte]
+  hmacDrbgGenerate(f.rng[], tmp)
+  f.seenSalt.update(tmp)
 
   f.init()

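As a standalone illustration of the salting scheme above (not part of the commit; `seenSalt` and `saltedIdOf` are stand-in names for the FloodSub field and procs in the diff), the same derivation can be written with nimcrypto directly:

import nimcrypto/[hash, sha2]

# Pre-seed a SHA256 context with secret random bytes once at startup.
var seenSalt: sha256
seenSalt.init()
seenSalt.update([byte 0xde, 0xad, 0xbe, 0xef])  # the real code feeds rng output here

proc saltedIdOf(msgId: seq[byte]): MDigest[256] =
  # Copy the pre-seeded context, absorb the variable-length message id and
  # finalize: the result is a fixed 32-byte key a remote peer cannot predict.
  var tmp = seenSalt
  tmp.update(msgId)
  tmp.finish()

echo saltedIdOf(@[byte 1, 2, 3])

Because the salt is unknown to remote peers, they cannot craft message ids that collide in the cache, and every key has the same fixed size regardless of how long the original message id is.
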
@@ -360,7 +360,7 @@ proc handleControl(g: GossipSub, peer: PubSubPeer, control: ControlMessage) =
 
 proc validateAndRelay(g: GossipSub,
                       msg: Message,
-                      msgId, msgIdSalted: MessageId,
+                      msgId: MessageId, msgIdSalted: SaltedId,
                       peer: PubSubPeer) {.async.} =
   try:
     let validation = await g.validate(msg)
@@ -508,12 +508,12 @@ method rpcHandler*(g: GossipSub,
 
     let
       msgId = msgIdResult.get
-      msgIdSalted = msgId & g.seenSalt
+      msgIdSalted = g.salt(msgId)
       topic = msg.topic
 
     # addSeen adds salt to msgId to avoid
     # remote attacking the hash function
-    if g.addSeen(msgId):
+    if g.addSeen(msgIdSalted):
       trace "Dropping already-seen message", msgId = shortLog(msgId), peer
 
       var alreadyReceived = false
@@ -523,7 +523,7 @@ method rpcHandler*(g: GossipSub,
           alreadyReceived = true
 
       if not alreadyReceived:
-        let delay = Moment.now() - g.firstSeen(msgId)
+        let delay = Moment.now() - g.firstSeen(msgIdSalted)
         g.rewardDelivered(peer, topic, false, delay)
 
       libp2p_gossipsub_duplicate.inc()
@@ -690,7 +690,7 @@ method publish*(g: GossipSub,
 
   trace "Created new message", msg = shortLog(msg), peers = peers.len
 
-  if g.addSeen(msgId):
+  if g.addSeen(g.salt(msgId)):
     # custom msgid providers might cause this
     trace "Dropping already-seen message"
     return 0
@@ -779,7 +779,7 @@ method initPubSub*(g: GossipSub)
     raise newException(InitializationError, $validationRes.error)
 
   # init the floodsub stuff here, we customize timedcache in gossip!
-  g.seen = TimedCache[MessageId].init(g.parameters.seenTTL)
+  g.seen = TimedCache[SaltedId].init(g.parameters.seenTTL)
 
   # init gossip stuff
   g.mcache = MCache.init(g.parameters.historyGossip, g.parameters.historyLength)

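The duplicate branch above rewards a peer based on how long after the first sighting its copy arrives. A minimal sketch of that delay computation, driving the TimedCache directly with plain int keys (the import path and the int key type are assumptions for the example; the real code keys by SaltedId):

import chronos/timer
import libp2p/protocols/pubsub/timedcache  # path assumed for an installed nim-libp2p

var seen = TimedCache[int].init(2.minutes)

let firstArrival = Moment.now()
discard seen.put(7, firstArrival)            # first copy of message "7" is recorded

# ... later, a duplicate of the same message shows up ...
let delay = Moment.now() - seen.addedAt(7)   # time since the first sighting
echo "duplicate arrived after ", delay
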
@@ -251,7 +251,7 @@ proc handleIHave*(g: GossipSub,
       peer, topicID = ihave.topicID, msgs = ihave.messageIDs
     if ihave.topicID in g.topics:
       for msgId in ihave.messageIDs:
-        if not g.hasSeen(msgId):
+        if not g.hasSeen(g.salt(msgId)):
           if peer.iHaveBudget <= 0:
             break
           elif msgId notin res.messageIDs:

@@ -156,7 +156,7 @@ type
     maxNumElementsInNonPriorityQueue*: int
 
   BackoffTable* = Table[string, Table[PeerId, Moment]]
-  ValidationSeenTable* = Table[MessageId, HashSet[PubSubPeer]]
+  ValidationSeenTable* = Table[SaltedId, HashSet[PubSubPeer]]
 
   RoutingRecordsPair* = tuple[id: PeerId, record: Option[PeerRecord]]
   RoutingRecordsHandler* =

@@ -37,6 +37,12 @@ type
 
   MessageId* = seq[byte]
 
+  SaltedId* = object
+    # Salted hash of message ID - used instead of the ordinary message ID to
+    # avoid hash poisoning attacks and to make memory usage more predictable
+    # with respect to the variable-length message id
+    data*: MDigest[256]
+
   Message* = object
     fromPeer*: PeerId
     data*: seq[byte]

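As a rough, self-contained illustration of the memory point made in the comment above (local stand-in types, not the library definitions): a digest-based key occupies a fixed 32 bytes per cache entry, while a raw message id is a heap-allocated sequence whose length the sender chooses.

import nimcrypto/hash

type
  MessageId = seq[byte]          # variable length, chosen by the sender
  SaltedId = object
    data: MDigest[256]           # always 32 bytes

echo sizeof(SaltedId)            # 32: every key has the same footprint
let oversized: MessageId = newSeq[byte](64 * 1024)
echo oversized.len               # 65536 bytes a cache keyed by MessageId would retain
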
@@ -9,12 +9,13 @@
 {.push raises: [].}
 
-import std/[tables]
+import std/[hashes, sets]
+
 import chronos/timer, stew/results
 
 import ../../utility
 
 export results
 
 const Timeout* = 10.seconds # default timeout in ms
 
 type
@@ -26,20 +27,38 @@ type
   TimedCache*[K] = object of RootObj
     head, tail: TimedEntry[K] # nim linked list doesn't allow inserting at pos
-    entries: Table[K, TimedEntry[K]]
+    entries: HashSet[TimedEntry[K]]
     timeout: Duration
 
+func `==`*[E](a, b: TimedEntry[E]): bool =
+  if isNil(a) == isNil(b):
+    isNil(a) or a.key == b.key
+  else:
+    false
+
+func hash*(a: TimedEntry): Hash =
+  if isNil(a):
+    default(Hash)
+  else:
+    hash(a[].key)
+
 func expire*(t: var TimedCache, now: Moment = Moment.now()) =
   while t.head != nil and t.head.expiresAt < now:
-    t.entries.del(t.head.key)
+    t.entries.excl(t.head)
     t.head.prev = nil
     t.head = t.head.next
     if t.head == nil: t.tail = nil
 
 func del*[K](t: var TimedCache[K], key: K): Opt[TimedEntry[K]] =
   # Removes existing key from cache, returning the previous value if present
-  var item: TimedEntry[K]
-  if t.entries.pop(key, item):
+  let tmp = TimedEntry[K](key: key)
+  if tmp in t.entries:
+    let item = try:
+      t.entries[tmp] # use the shared instance in the set
+    except KeyError:
+      raiseAssert "just checked"
+    t.entries.excl(item)
 
     if t.head == item: t.head = item.next
     if t.tail == item: t.tail = item.prev
@@ -55,14 +74,14 @@ func put*[K](t: var TimedCache[K], k: K, now = Moment.now()): bool =
   # refreshed.
   t.expire(now)
 
-  var previous = t.del(k) # Refresh existing item
-
-  var addedAt = now
-  previous.withValue(previous):
-    addedAt = previous.addedAt
+  let
+    previous = t.del(k) # Refresh existing item
+    addedAt = if previous.isSome():
+      previous[].addedAt
+    else:
+      now
 
   let node = TimedEntry[K](key: k, addedAt: addedAt, expiresAt: now + t.timeout)
 
   if t.head == nil:
     t.tail = node
     t.head = t.tail
@@ -83,16 +102,24 @@ func put*[K](t: var TimedCache[K], k: K, now = Moment.now()): bool =
     if cur == t.tail:
       t.tail = node
 
-  t.entries[k] = node
+  t.entries.incl(node)
 
   previous.isSome()
 
 func contains*[K](t: TimedCache[K], k: K): bool =
-  k in t.entries
+  let tmp = TimedEntry[K](key: k)
+  tmp in t.entries
 
-func addedAt*[K](t: TimedCache[K], k: K): Moment =
-  t.entries.getOrDefault(k).addedAt
+func addedAt*[K](t: var TimedCache[K], k: K): Moment =
+  let tmp = TimedEntry[K](key: k)
+  try:
+    if tmp in t.entries: # raising is slow
+      # Use shared instance from entries
+      return t.entries[tmp][].addedAt
+  except KeyError:
+    raiseAssert "just checked"
+
+  default(Moment)
 
 func init*[K](T: type TimedCache[K], timeout: Duration = Timeout): T =
   T(

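The switch from Table[K, TimedEntry[K]] to HashSet[TimedEntry[K]] relies on key-only `hash`/`==` overloads, so a throwaway entry carrying just the key can be used to probe the set. A minimal standalone sketch of that pattern (illustrative `Entry` type, not the library code):

import std/[hashes, sets]

type Entry = ref object
  key: int
  payload: string

func hash(e: Entry): Hash =
  # hash only the key so a bare probe entry hashes like the stored one
  if e.isNil: default(Hash) else: hash(e.key)

func `==`(a, b: Entry): bool =
  # two entries are "equal" when their keys match, regardless of payload
  if a.isNil == b.isNil:
    a.isNil or a.key == b.key
  else:
    false

var entries = initHashSet[Entry]()
entries.incl(Entry(key: 42, payload: "stored once"))

let probe = Entry(key: 42)     # key-only probe, no payload needed
doAssert probe in entries      # found via the key-based hash/==
echo entries[probe].payload    # retrieves the single shared stored instance

This is the lookup style the new `contains`, `del`, and `addedAt` use: only one allocation per cached id lives in the set, and probes never copy or concatenate the key material.
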
@@ -569,8 +569,8 @@ suite "GossipSub":
     proc slowValidator(topic: string, message: Message): Future[ValidationResult] {.async.} =
       await cRelayed
       # Empty A & C caches to detect duplicates
-      gossip1.seen = TimedCache[MessageId].init()
-      gossip3.seen = TimedCache[MessageId].init()
+      gossip1.seen = TimedCache[SaltedId].init()
+      gossip3.seen = TimedCache[SaltedId].init()
       let msgId = toSeq(gossip2.validationSeen.keys)[0]
       checkUntilTimeout(try: gossip2.validationSeen[msgId].len > 0 except: false)
       result = ValidationResult.Accept

@@ -24,6 +24,8 @@ suite "TimedCache":
       2 in cache
       3 in cache
 
+      cache.addedAt(2) == now + 3.seconds
+
     check:
       cache.put(2, now + 7.seconds) # refreshes 2
       not cache.put(4, now + 12.seconds) # expires 3
@@ -33,6 +35,23 @@ suite "TimedCache":
       3 notin cache
       4 in cache
 
+    check:
+      cache.del(4).isSome()
+      4 notin cache
+
     check:
       not cache.put(100, now + 100.seconds) # expires everything
       100 in cache
       2 notin cache
+
+  test "enough items to force cache heap storage growth":
+    var cache = TimedCache[int].init(5.seconds)
+
+    let now = Moment.now()
+    for i in 101..100000:
+      check:
+        not cache.put(i, now)
+
+    for i in 101..100000:
+      check:
+        i in cache