mirror of
https://github.com/vacp2p/nim-libp2p.git
synced 2025-01-27 17:05:59 +00:00
Improve memory efficiency of seen cache (#1073)
This commit is contained in:
parent
c4da9be32c
commit
02c96fc003
@ -16,6 +16,7 @@ import ./pubsub,
|
|||||||
./timedcache,
|
./timedcache,
|
||||||
./peertable,
|
./peertable,
|
||||||
./rpc/[message, messages, protobuf],
|
./rpc/[message, messages, protobuf],
|
||||||
|
nimcrypto/[hash, sha2],
|
||||||
../../crypto/crypto,
|
../../crypto/crypto,
|
||||||
../../stream/connection,
|
../../stream/connection,
|
||||||
../../peerid,
|
../../peerid,
|
||||||
@ -32,20 +33,29 @@ const FloodSubCodec* = "/floodsub/1.0.0"
|
|||||||
type
|
type
|
||||||
FloodSub* {.public.} = ref object of PubSub
|
FloodSub* {.public.} = ref object of PubSub
|
||||||
floodsub*: PeerTable # topic to remote peer map
|
floodsub*: PeerTable # topic to remote peer map
|
||||||
seen*: TimedCache[MessageId] # message id:s already seen on the network
|
seen*: TimedCache[SaltedId]
|
||||||
seenSalt*: seq[byte]
|
# Early filter for messages recently observed on the network
|
||||||
|
# We use a salted id because the messages in this cache have not yet
|
||||||
|
# been validated meaning that an attacker has greater control over the
|
||||||
|
# hash key and therefore could poison the table
|
||||||
|
seenSalt*: sha256
|
||||||
|
# The salt in this case is a partially updated SHA256 context pre-seeded
|
||||||
|
# with some random data
|
||||||
|
|
||||||
proc hasSeen*(f: FloodSub, msgId: MessageId): bool =
|
proc salt*(f: FloodSub, msgId: MessageId): SaltedId =
|
||||||
f.seenSalt & msgId in f.seen
|
var tmp = f.seenSalt
|
||||||
|
tmp.update(msgId)
|
||||||
|
SaltedId(data: tmp.finish())
|
||||||
|
|
||||||
proc addSeen*(f: FloodSub, msgId: MessageId): bool =
|
proc hasSeen*(f: FloodSub, saltedId: SaltedId): bool =
|
||||||
# Salting the seen hash helps avoid attacks against the hash function used
|
saltedId in f.seen
|
||||||
# in the nim hash table
|
|
||||||
|
proc addSeen*(f: FloodSub, saltedId: SaltedId): bool =
|
||||||
# Return true if the message has already been seen
|
# Return true if the message has already been seen
|
||||||
f.seen.put(f.seenSalt & msgId)
|
f.seen.put(saltedId)
|
||||||
|
|
||||||
proc firstSeen*(f: FloodSub, msgId: MessageId): Moment =
|
proc firstSeen*(f: FloodSub, saltedId: SaltedId): Moment =
|
||||||
f.seen.addedAt(f.seenSalt & msgId)
|
f.seen.addedAt(saltedId)
|
||||||
|
|
||||||
proc handleSubscribe*(f: FloodSub,
|
proc handleSubscribe*(f: FloodSub,
|
||||||
peer: PubSubPeer,
|
peer: PubSubPeer,
|
||||||
@ -117,9 +127,11 @@ method rpcHandler*(f: FloodSub,
|
|||||||
# TODO: descore peers due to error during message validation (malicious?)
|
# TODO: descore peers due to error during message validation (malicious?)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
let msgId = msgIdResult.get
|
let
|
||||||
|
msgId = msgIdResult.get
|
||||||
|
saltedId = f.salt(msgId)
|
||||||
|
|
||||||
if f.addSeen(msgId):
|
if f.addSeen(saltedId):
|
||||||
trace "Dropping already-seen message", msgId, peer
|
trace "Dropping already-seen message", msgId, peer
|
||||||
continue
|
continue
|
||||||
|
|
||||||
@ -216,7 +228,7 @@ method publish*(f: FloodSub,
|
|||||||
trace "Created new message",
|
trace "Created new message",
|
||||||
msg = shortLog(msg), peers = peers.len, topic, msgId
|
msg = shortLog(msg), peers = peers.len, topic, msgId
|
||||||
|
|
||||||
if f.addSeen(msgId):
|
if f.addSeen(f.salt(msgId)):
|
||||||
# custom msgid providers might cause this
|
# custom msgid providers might cause this
|
||||||
trace "Dropping already-seen message", msgId, topic
|
trace "Dropping already-seen message", msgId, topic
|
||||||
return 0
|
return 0
|
||||||
@ -234,8 +246,11 @@ method publish*(f: FloodSub,
|
|||||||
method initPubSub*(f: FloodSub)
|
method initPubSub*(f: FloodSub)
|
||||||
{.raises: [InitializationError].} =
|
{.raises: [InitializationError].} =
|
||||||
procCall PubSub(f).initPubSub()
|
procCall PubSub(f).initPubSub()
|
||||||
f.seen = TimedCache[MessageId].init(2.minutes)
|
f.seen = TimedCache[SaltedId].init(2.minutes)
|
||||||
f.seenSalt = newSeqUninitialized[byte](sizeof(Hash))
|
f.seenSalt.init()
|
||||||
hmacDrbgGenerate(f.rng[], f.seenSalt)
|
|
||||||
|
var tmp: array[32, byte]
|
||||||
|
hmacDrbgGenerate(f.rng[], tmp)
|
||||||
|
f.seenSalt.update(tmp)
|
||||||
|
|
||||||
f.init()
|
f.init()
|
||||||
|
@ -360,7 +360,7 @@ proc handleControl(g: GossipSub, peer: PubSubPeer, control: ControlMessage) =
|
|||||||
|
|
||||||
proc validateAndRelay(g: GossipSub,
|
proc validateAndRelay(g: GossipSub,
|
||||||
msg: Message,
|
msg: Message,
|
||||||
msgId, msgIdSalted: MessageId,
|
msgId: MessageId, msgIdSalted: SaltedId,
|
||||||
peer: PubSubPeer) {.async.} =
|
peer: PubSubPeer) {.async.} =
|
||||||
try:
|
try:
|
||||||
let validation = await g.validate(msg)
|
let validation = await g.validate(msg)
|
||||||
@ -508,12 +508,12 @@ method rpcHandler*(g: GossipSub,
|
|||||||
|
|
||||||
let
|
let
|
||||||
msgId = msgIdResult.get
|
msgId = msgIdResult.get
|
||||||
msgIdSalted = msgId & g.seenSalt
|
msgIdSalted = g.salt(msgId)
|
||||||
topic = msg.topic
|
topic = msg.topic
|
||||||
|
|
||||||
# addSeen adds salt to msgId to avoid
|
# msgIdSalted was salted by g.salt above to keep
|
||||||
# remote attacking the hash function
|
# a remote peer from attacking the hash function
|
||||||
if g.addSeen(msgId):
|
if g.addSeen(msgIdSalted):
|
||||||
trace "Dropping already-seen message", msgId = shortLog(msgId), peer
|
trace "Dropping already-seen message", msgId = shortLog(msgId), peer
|
||||||
|
|
||||||
var alreadyReceived = false
|
var alreadyReceived = false
|
||||||
@ -523,7 +523,7 @@ method rpcHandler*(g: GossipSub,
|
|||||||
alreadyReceived = true
|
alreadyReceived = true
|
||||||
|
|
||||||
if not alreadyReceived:
|
if not alreadyReceived:
|
||||||
let delay = Moment.now() - g.firstSeen(msgId)
|
let delay = Moment.now() - g.firstSeen(msgIdSalted)
|
||||||
g.rewardDelivered(peer, topic, false, delay)
|
g.rewardDelivered(peer, topic, false, delay)
|
||||||
|
|
||||||
libp2p_gossipsub_duplicate.inc()
|
libp2p_gossipsub_duplicate.inc()
|
||||||
@ -690,7 +690,7 @@ method publish*(g: GossipSub,
|
|||||||
|
|
||||||
trace "Created new message", msg = shortLog(msg), peers = peers.len
|
trace "Created new message", msg = shortLog(msg), peers = peers.len
|
||||||
|
|
||||||
if g.addSeen(msgId):
|
if g.addSeen(g.salt(msgId)):
|
||||||
# custom msgid providers might cause this
|
# custom msgid providers might cause this
|
||||||
trace "Dropping already-seen message"
|
trace "Dropping already-seen message"
|
||||||
return 0
|
return 0
|
||||||
@ -779,7 +779,7 @@ method initPubSub*(g: GossipSub)
|
|||||||
raise newException(InitializationError, $validationRes.error)
|
raise newException(InitializationError, $validationRes.error)
|
||||||
|
|
||||||
# init the floodsub stuff here, we customize timedcache in gossip!
|
# init the floodsub stuff here, we customize timedcache in gossip!
|
||||||
g.seen = TimedCache[MessageId].init(g.parameters.seenTTL)
|
g.seen = TimedCache[SaltedId].init(g.parameters.seenTTL)
|
||||||
|
|
||||||
# init gossip stuff
|
# init gossip stuff
|
||||||
g.mcache = MCache.init(g.parameters.historyGossip, g.parameters.historyLength)
|
g.mcache = MCache.init(g.parameters.historyGossip, g.parameters.historyLength)
|
||||||
|
@ -251,7 +251,7 @@ proc handleIHave*(g: GossipSub,
|
|||||||
peer, topicID = ihave.topicID, msgs = ihave.messageIDs
|
peer, topicID = ihave.topicID, msgs = ihave.messageIDs
|
||||||
if ihave.topicID in g.topics:
|
if ihave.topicID in g.topics:
|
||||||
for msgId in ihave.messageIDs:
|
for msgId in ihave.messageIDs:
|
||||||
if not g.hasSeen(msgId):
|
if not g.hasSeen(g.salt(msgId)):
|
||||||
if peer.iHaveBudget <= 0:
|
if peer.iHaveBudget <= 0:
|
||||||
break
|
break
|
||||||
elif msgId notin res.messageIDs:
|
elif msgId notin res.messageIDs:
|
||||||
|
@ -156,7 +156,7 @@ type
|
|||||||
maxNumElementsInNonPriorityQueue*: int
|
maxNumElementsInNonPriorityQueue*: int
|
||||||
|
|
||||||
BackoffTable* = Table[string, Table[PeerId, Moment]]
|
BackoffTable* = Table[string, Table[PeerId, Moment]]
|
||||||
ValidationSeenTable* = Table[MessageId, HashSet[PubSubPeer]]
|
ValidationSeenTable* = Table[SaltedId, HashSet[PubSubPeer]]
|
||||||
|
|
||||||
RoutingRecordsPair* = tuple[id: PeerId, record: Option[PeerRecord]]
|
RoutingRecordsPair* = tuple[id: PeerId, record: Option[PeerRecord]]
|
||||||
RoutingRecordsHandler* =
|
RoutingRecordsHandler* =
|
||||||
|
@ -37,6 +37,12 @@ type
|
|||||||
|
|
||||||
MessageId* = seq[byte]
|
MessageId* = seq[byte]
|
||||||
|
|
||||||
|
SaltedId* = object
|
||||||
|
# Salted hash of message ID - used instead of the ordinary message ID to
|
||||||
|
# avoid hash poisoning attacks and to make memory usage more predictable
|
||||||
|
# with respect to the variable-length message id
|
||||||
|
data*: MDigest[256]
|
||||||
|
|
||||||
Message* = object
|
Message* = object
|
||||||
fromPeer*: PeerId
|
fromPeer*: PeerId
|
||||||
data*: seq[byte]
|
data*: seq[byte]
|
||||||
|
@ -9,12 +9,13 @@
|
|||||||
|
|
||||||
{.push raises: [].}
|
{.push raises: [].}
|
||||||
|
|
||||||
import std/[tables]
|
import std/[hashes, sets]
|
||||||
|
|
||||||
import chronos/timer, stew/results
|
import chronos/timer, stew/results
|
||||||
|
|
||||||
import ../../utility
|
import ../../utility
|
||||||
|
|
||||||
|
export results
|
||||||
|
|
||||||
const Timeout* = 10.seconds # default entry timeout used by TimedCache.init
|
const Timeout* = 10.seconds # default entry timeout used by TimedCache.init
|
||||||
|
|
||||||
type
|
type
|
||||||
@ -26,20 +27,38 @@ type
|
|||||||
|
|
||||||
TimedCache*[K] = object of RootObj
|
TimedCache*[K] = object of RootObj
|
||||||
head, tail: TimedEntry[K] # nim linked list doesn't allow inserting at pos
|
head, tail: TimedEntry[K] # nim linked list doesn't allow inserting at pos
|
||||||
entries: Table[K, TimedEntry[K]]
|
entries: HashSet[TimedEntry[K]]
|
||||||
timeout: Duration
|
timeout: Duration
|
||||||
|
|
||||||
|
func `==`*[E](a, b: TimedEntry[E]): bool =
|
||||||
|
if isNil(a) == isNil(b):
|
||||||
|
isNil(a) or a.key == b.key
|
||||||
|
else:
|
||||||
|
false
|
||||||
|
|
||||||
|
func hash*(a: TimedEntry): Hash =
|
||||||
|
if isNil(a):
|
||||||
|
default(Hash)
|
||||||
|
else:
|
||||||
|
hash(a[].key)
|
||||||
|
|
||||||
func expire*(t: var TimedCache, now: Moment = Moment.now()) =
|
func expire*(t: var TimedCache, now: Moment = Moment.now()) =
|
||||||
while t.head != nil and t.head.expiresAt < now:
|
while t.head != nil and t.head.expiresAt < now:
|
||||||
t.entries.del(t.head.key)
|
t.entries.excl(t.head)
|
||||||
t.head.prev = nil
|
t.head.prev = nil
|
||||||
t.head = t.head.next
|
t.head = t.head.next
|
||||||
if t.head == nil: t.tail = nil
|
if t.head == nil: t.tail = nil
|
||||||
|
|
||||||
func del*[K](t: var TimedCache[K], key: K): Opt[TimedEntry[K]] =
|
func del*[K](t: var TimedCache[K], key: K): Opt[TimedEntry[K]] =
|
||||||
# Removes existing key from cache, returning the previous value if present
|
# Removes existing key from cache, returning the previous value if present
|
||||||
var item: TimedEntry[K]
|
let tmp = TimedEntry[K](key: key)
|
||||||
if t.entries.pop(key, item):
|
if tmp in t.entries:
|
||||||
|
let item = try:
|
||||||
|
t.entries[tmp] # use the shared instance in the set
|
||||||
|
except KeyError:
|
||||||
|
raiseAssert "just checked"
|
||||||
|
t.entries.excl(item)
|
||||||
|
|
||||||
if t.head == item: t.head = item.next
|
if t.head == item: t.head = item.next
|
||||||
if t.tail == item: t.tail = item.prev
|
if t.tail == item: t.tail = item.prev
|
||||||
|
|
||||||
@ -55,14 +74,14 @@ func put*[K](t: var TimedCache[K], k: K, now = Moment.now()): bool =
|
|||||||
# refreshed.
|
# refreshed.
|
||||||
t.expire(now)
|
t.expire(now)
|
||||||
|
|
||||||
var previous = t.del(k) # Refresh existing item
|
let
|
||||||
|
previous = t.del(k) # Refresh existing item
|
||||||
var addedAt = now
|
addedAt = if previous.isSome():
|
||||||
previous.withValue(previous):
|
previous[].addedAt
|
||||||
addedAt = previous.addedAt
|
else:
|
||||||
|
now
|
||||||
|
|
||||||
let node = TimedEntry[K](key: k, addedAt: addedAt, expiresAt: now + t.timeout)
|
let node = TimedEntry[K](key: k, addedAt: addedAt, expiresAt: now + t.timeout)
|
||||||
|
|
||||||
if t.head == nil:
|
if t.head == nil:
|
||||||
t.tail = node
|
t.tail = node
|
||||||
t.head = t.tail
|
t.head = t.tail
|
||||||
@ -83,16 +102,24 @@ func put*[K](t: var TimedCache[K], k: K, now = Moment.now()): bool =
|
|||||||
if cur == t.tail:
|
if cur == t.tail:
|
||||||
t.tail = node
|
t.tail = node
|
||||||
|
|
||||||
t.entries[k] = node
|
t.entries.incl(node)
|
||||||
|
|
||||||
previous.isSome()
|
previous.isSome()
|
||||||
|
|
||||||
func contains*[K](t: TimedCache[K], k: K): bool =
|
func contains*[K](t: TimedCache[K], k: K): bool =
|
||||||
k in t.entries
|
let tmp = TimedEntry[K](key: k)
|
||||||
|
tmp in t.entries
|
||||||
|
|
||||||
func addedAt*[K](t: TimedCache[K], k: K): Moment =
|
func addedAt*[K](t: var TimedCache[K], k: K): Moment =
|
||||||
t.entries.getOrDefault(k).addedAt
|
let tmp = TimedEntry[K](key: k)
|
||||||
|
try:
|
||||||
|
if tmp in t.entries: # raising is slow
|
||||||
|
# Use shared instance from entries
|
||||||
|
return t.entries[tmp][].addedAt
|
||||||
|
except KeyError:
|
||||||
|
raiseAssert "just checked"
|
||||||
|
|
||||||
|
default(Moment)
|
||||||
|
|
||||||
func init*[K](T: type TimedCache[K], timeout: Duration = Timeout): T =
|
func init*[K](T: type TimedCache[K], timeout: Duration = Timeout): T =
|
||||||
T(
|
T(
|
||||||
|
@ -569,8 +569,8 @@ suite "GossipSub":
|
|||||||
proc slowValidator(topic: string, message: Message): Future[ValidationResult] {.async.} =
|
proc slowValidator(topic: string, message: Message): Future[ValidationResult] {.async.} =
|
||||||
await cRelayed
|
await cRelayed
|
||||||
# Empty A & C caches to detect duplicates
|
# Empty A & C caches to detect duplicates
|
||||||
gossip1.seen = TimedCache[MessageId].init()
|
gossip1.seen = TimedCache[SaltedId].init()
|
||||||
gossip3.seen = TimedCache[MessageId].init()
|
gossip3.seen = TimedCache[SaltedId].init()
|
||||||
let msgId = toSeq(gossip2.validationSeen.keys)[0]
|
let msgId = toSeq(gossip2.validationSeen.keys)[0]
|
||||||
checkUntilTimeout(try: gossip2.validationSeen[msgId].len > 0 except: false)
|
checkUntilTimeout(try: gossip2.validationSeen[msgId].len > 0 except: false)
|
||||||
result = ValidationResult.Accept
|
result = ValidationResult.Accept
|
||||||
|
@ -24,6 +24,8 @@ suite "TimedCache":
|
|||||||
2 in cache
|
2 in cache
|
||||||
3 in cache
|
3 in cache
|
||||||
|
|
||||||
|
cache.addedAt(2) == now + 3.seconds
|
||||||
|
|
||||||
check:
|
check:
|
||||||
cache.put(2, now + 7.seconds) # refreshes 2
|
cache.put(2, now + 7.seconds) # refreshes 2
|
||||||
not cache.put(4, now + 12.seconds) # expires 3
|
not cache.put(4, now + 12.seconds) # expires 3
|
||||||
@ -33,6 +35,23 @@ suite "TimedCache":
|
|||||||
3 notin cache
|
3 notin cache
|
||||||
4 in cache
|
4 in cache
|
||||||
|
|
||||||
|
check:
|
||||||
|
cache.del(4).isSome()
|
||||||
|
4 notin cache
|
||||||
|
|
||||||
check:
|
check:
|
||||||
not cache.put(100, now + 100.seconds) # expires everything
|
not cache.put(100, now + 100.seconds) # expires everything
|
||||||
100 in cache
|
100 in cache
|
||||||
|
2 notin cache
|
||||||
|
|
||||||
|
test "enough items to force cache heap storage growth":
|
||||||
|
var cache = TimedCache[int].init(5.seconds)
|
||||||
|
|
||||||
|
let now = Moment.now()
|
||||||
|
for i in 101..100000:
|
||||||
|
check:
|
||||||
|
not cache.put(i, now)
|
||||||
|
|
||||||
|
for i in 101..100000:
|
||||||
|
check:
|
||||||
|
i in cache
|
||||||
|
Loading…
x
Reference in New Issue
Block a user