refactor: shard LRU for storing peers by shard (#840)
Co-authored-by: Richard Ramos <info@richardramos.me>
This commit is contained in:
parent 48ab9e6ce7
commit d7c7255aa4
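For orientation, a minimal standalone sketch (illustrative only, not part of this commit) of the idea the new shardLRU below implements: one bounded recency list per (cluster, shard) pair instead of a single global LRU, so filling one shard's bucket never evicts peers that serve other shards. The names shardKey and miniShardLRU are invented for this sketch and do not appear in the diff.

package main

import (
	"container/list"
	"fmt"
)

// shardKey mirrors the (clusterID, shard) pair used to bucket peers.
type shardKey struct {
	cluster uint16
	shard   uint16
}

// miniShardLRU keeps one bounded recency list per shard instead of a single global LRU.
type miniShardLRU struct {
	size   int // max peers kept per shard
	shards map[shardKey]*list.List
}

func (m *miniShardLRU) add(key shardKey, peer string) {
	l, ok := m.shards[key]
	if !ok {
		l = list.New()
		m.shards[key] = l
	}
	if l.Len() >= m.size {
		l.Remove(l.Back()) // evict the oldest peer of this shard only
	}
	l.PushFront(peer)
}

func main() {
	m := &miniShardLRU{size: 2, shards: map[shardKey]*list.List{}}
	m.add(shardKey{1, 1}, "peerA")
	m.add(shardKey{1, 1}, "peerB")
	m.add(shardKey{1, 1}, "peerC") // evicts peerA from shard (1,1) only
	m.add(shardKey{1, 2}, "peerA") // shard (1,2) has its own budget
	fmt.Println(m.shards[shardKey{1, 1}].Len(), m.shards[shardKey{1, 2}].Len()) // prints: 2 1
}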
@@ -3,69 +3,45 @@ package peer_exchange
 import (
 	"bufio"
 	"bytes"
-	"math/rand"
-	"sync"
 
 	"github.com/ethereum/go-ethereum/p2p/enode"
-	"github.com/hashicorp/golang-lru/simplelru"
 
 	"github.com/waku-org/go-waku/waku/v2/protocol/peer_exchange/pb"
-	"go.uber.org/zap"
 )
 
-// simpleLRU internal uses container/list, which is ring buffer(double linked list)
+// simpleLRU internally uses container/list, which is a ring buffer (doubly linked list)
 type enrCache struct {
-	// using lru, saves us from periodically cleaning the cache to maintain a certain size
-	data *simplelru.LRU
-	rng  *rand.Rand
-	mu   sync.RWMutex
-	log  *zap.Logger
+	// using an LRU saves us from periodically cleaning the cache to maintain a certain size
+	data *shardLRU
 }
 
 // err on negative size
-func newEnrCache(size int, log *zap.Logger) (*enrCache, error) {
-	inner, err := simplelru.NewLRU(size, nil)
+func newEnrCache(size int) *enrCache {
+	inner := newShardLRU(int(size))
 	return &enrCache{
 		data: inner,
-		rng:  rand.New(rand.NewSource(rand.Int63())),
-		log:  log.Named("enr-cache"),
-	}, err
+	}
 }
 
 // updating cache
-func (c *enrCache) updateCache(node *enode.Node) {
-	c.mu.Lock()
-	defer c.mu.Unlock()
-	currNode, ok := c.data.Get(node.ID())
-	if !ok || node.Seq() > currNode.(*enode.Node).Seq() {
-		c.data.Add(node.ID(), node)
-		c.log.Debug("discovered px peer via discv5", zap.Stringer("enr", node))
+func (c *enrCache) updateCache(node *enode.Node) error {
+	currNode := c.data.Get(node.ID())
+	if currNode == nil || node.Seq() > currNode.Seq() {
+		return c.data.Add(node)
 	}
+	return nil
 }
 
 // get `numPeers` records of enr
-func (c *enrCache) getENRs(neededPeers int) ([]*pb.PeerInfo, error) {
-	c.mu.RLock()
-	defer c.mu.RUnlock()
-	availablePeers := c.data.Len()
-	if availablePeers == 0 {
-		return nil, nil
-	}
-	if availablePeers < neededPeers {
-		neededPeers = availablePeers
-	}
-
-	perm := c.rng.Perm(availablePeers)[0:neededPeers]
-	keys := c.data.Keys()
+func (c *enrCache) getENRs(neededPeers int, clusterIndex *ShardInfo) ([]*pb.PeerInfo, error) {
+	//
+	nodes := c.data.GetRandomNodes(clusterIndex, neededPeers)
 	result := []*pb.PeerInfo{}
-	for _, ind := range perm {
-		node, ok := c.data.Get(keys[ind])
-		if !ok {
-			continue
-		}
+	for _, node := range nodes {
+		//
 		var b bytes.Buffer
 		writer := bufio.NewWriter(&b)
-		err := node.(*enode.Node).Record().EncodeRLP(writer)
+		err := node.Record().EncodeRLP(writer)
 		if err != nil {
 			return nil, err
 		}
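The caller-visible change in this hunk is the extra clusterIndex parameter on getENRs. A hedged sketch of how a caller inside the peer_exchange package might use it (exampleGetENRs is a hypothetical helper, not part of the diff): passing nil keeps the old any-shard behaviour, while a concrete ShardInfo restricts replies to that cluster/shard.

package peer_exchange

// exampleGetENRs is illustrative only and assumes it compiles alongside the code above.
func exampleGetENRs(cache *enrCache) error {
	// nil filter: records may come from any cached shard (previous behaviour)
	if _, err := cache.getENRs(5, nil); err != nil {
		return err
	}
	// only peers advertising cluster 1, shard 2
	_, err := cache.getENRs(5, &ShardInfo{clusterID: 1, shard: 2})
	return err
}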
@@ -57,17 +57,11 @@ func NewWakuPeerExchange(disc *discv5.DiscoveryV5, peerConnector PeerConnector,
 	wakuPX.disc = disc
 	wakuPX.metrics = newMetrics(reg)
 	wakuPX.log = log.Named("wakupx")
+	wakuPX.enrCache = newEnrCache(MaxCacheSize)
 	wakuPX.peerConnector = peerConnector
 	wakuPX.pm = pm
 	wakuPX.CommonService = service.NewCommonService()
 
-	newEnrCache, err := newEnrCache(MaxCacheSize, wakuPX.log)
-	if err != nil {
-		return nil, err
-	}
-
-	wakuPX.enrCache = newEnrCache
-
 	return wakuPX, nil
 }
 
@@ -108,7 +102,7 @@ func (wakuPX *WakuPeerExchange) onRequest() func(network.Stream) {
 		if requestRPC.Query != nil {
 			logger.Info("request received")
 
-			records, err := wakuPX.enrCache.getENRs(int(requestRPC.Query.NumPeers))
+			records, err := wakuPX.enrCache.getENRs(int(requestRPC.Query.NumPeers), nil)
 			if err != nil {
 				logger.Error("obtaining enrs from cache", zap.Error(err))
 				wakuPX.metrics.RecordError(pxFailure)
@@ -161,7 +155,11 @@ func (wakuPX *WakuPeerExchange) iterate(ctx context.Context) error {
 			continue
 		}
 
-		wakuPX.enrCache.updateCache(iterator.Node())
+		err = wakuPX.enrCache.updateCache(iterator.Node())
+		if err != nil {
+			wakuPX.log.Error("adding peer to cache", zap.Error(err))
+			continue
+		}
 
 		select {
 		case <-ctx.Done():
waku/v2/protocol/peer_exchange/shard_lru.go (new file, 178 lines)
@@ -0,0 +1,178 @@
package peer_exchange

import (
	"container/list"
	"fmt"
	"math/rand"
	"sync"

	"github.com/ethereum/go-ethereum/p2p/enode"
	"github.com/waku-org/go-waku/waku/v2/protocol"
	wenr "github.com/waku-org/go-waku/waku/v2/protocol/enr"
	"github.com/waku-org/go-waku/waku/v2/utils"
)

type ShardInfo struct {
	clusterID uint16
	shard     uint16
}
type shardLRU struct {
	size       int // number of nodes allowed per shard
	idToNode   map[enode.ID][]*list.Element
	shardNodes map[ShardInfo]*list.List
	rng        *rand.Rand
	mu         sync.RWMutex
}

func newShardLRU(size int) *shardLRU {
	return &shardLRU{
		idToNode:   map[enode.ID][]*list.Element{},
		shardNodes: map[ShardInfo]*list.List{},
		size:       size,
		rng:        rand.New(rand.NewSource(rand.Int63())),
	}
}

type nodeWithShardInfo struct {
	key  ShardInfo
	node *enode.Node
}

// time complexity: O(number of previous indexes present for node.ID)
func (l *shardLRU) remove(node *enode.Node) {
	elements := l.idToNode[node.ID()]
	for _, element := range elements {
		key := element.Value.(nodeWithShardInfo).key
		l.shardNodes[key].Remove(element)
	}
	delete(l.idToNode, node.ID())
}

// if a node is removed from a shard list, remove it from idToNode too
func (l *shardLRU) removeFromIdToNode(ele *list.Element) {
	nodeID := ele.Value.(nodeWithShardInfo).node.ID()
	for ind, entries := range l.idToNode[nodeID] {
		if entries == ele {
			l.idToNode[nodeID] = append(l.idToNode[nodeID][:ind], l.idToNode[nodeID][ind+1:]...)
			break
		}
	}
	if len(l.idToNode[nodeID]) == 0 {
		delete(l.idToNode, nodeID)
	}
}

func nodeToRelayShard(node *enode.Node) (*protocol.RelayShards, error) {
	shard, err := wenr.RelaySharding(node.Record())
	if err != nil {
		return nil, err
	}

	if shard == nil { // if there is no shard info, add the node to cluster 0, index 0
		shard = &protocol.RelayShards{
			ClusterID: 0,
			ShardIDs:  []uint16{0},
		}
	}

	return shard, nil
}

// time complexity: O(number of new indexes in the node's shard)
func (l *shardLRU) add(node *enode.Node) error {
	shard, err := nodeToRelayShard(node)
	if err != nil {
		return err
	}

	elements := []*list.Element{}
	for _, index := range shard.ShardIDs {
		key := ShardInfo{
			shard.ClusterID,
			index,
		}
		if l.shardNodes[key] == nil {
			l.shardNodes[key] = list.New()
		}
		if l.shardNodes[key].Len() >= l.size {
			oldest := l.shardNodes[key].Back()
			l.removeFromIdToNode(oldest)
			l.shardNodes[key].Remove(oldest)
		}
		entry := l.shardNodes[key].PushFront(nodeWithShardInfo{
			key:  key,
			node: node,
		})
		elements = append(elements, entry)
	}
	l.idToNode[node.ID()] = elements

	return nil
}

// this is called when the seq number of the node is greater than the one in the cache
func (l *shardLRU) Add(node *enode.Node) error {
	l.mu.Lock()
	defer l.mu.Unlock()
	// remove first, because the previously cached version of the node might be subscribed to different shards
	l.remove(node)
	return l.add(node)
}

// clusterIndex is nil when peers for no specific shard are requested
func (l *shardLRU) GetRandomNodes(clusterIndex *ShardInfo, neededPeers int) (nodes []*enode.Node) {
	l.mu.Lock()
	defer l.mu.Unlock()

	availablePeers := l.len(clusterIndex)
	if availablePeers < neededPeers {
		neededPeers = availablePeers
	}
	// if clusterIndex is nil, then return all nodes
	var elements []*list.Element
	if clusterIndex == nil {
		elements = make([]*list.Element, 0, len(l.idToNode))
		for _, entries := range l.idToNode {
			elements = append(elements, entries[0])
		}
	} else if entries := l.shardNodes[*clusterIndex]; entries != nil && entries.Len() != 0 {
		elements = make([]*list.Element, 0, entries.Len())
		for ent := entries.Back(); ent != nil; ent = ent.Prev() {
			elements = append(elements, ent)
		}
	}
	utils.Logger().Info(fmt.Sprintf("%d", len(elements)))
	indexes := l.rng.Perm(len(elements))[0:neededPeers]
	for _, ind := range indexes {
		node := elements[ind].Value.(nodeWithShardInfo).node
		nodes = append(nodes, node)
		// remove the node from every cluster/shard list it belongs to and re-add it at the front
		l.remove(node)
		_ = l.add(node)
	}
	return nodes
}

// if clusterIndex is not nil, return the number of nodes maintained for that shard;
// if clusterIndex is nil, return the count of all nodes maintained
func (l *shardLRU) len(clusterIndex *ShardInfo) int {
	if clusterIndex == nil {
		return len(l.idToNode)
	}
	if entries := l.shardNodes[*clusterIndex]; entries != nil {
		return entries.Len()
	}
	return 0
}

// get the node with the given id, if it is present in the cache
func (l *shardLRU) Get(id enode.ID) *enode.Node {
	l.mu.RLock()
	defer l.mu.RUnlock()

	if elements, ok := l.idToNode[id]; ok && len(elements) > 0 {
		return elements[0].Value.(nodeWithShardInfo).node
	}
	return nil
}
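Since Add, Get, len and GetRandomNodes are the whole surface of shardLRU inside this package, a short test-style sketch may help; it is illustrative only (TestShardLRUSketchPerShardEviction is not part of the commit) and borrows the getEnode helper from the test file that follows. It shows that eviction happens per (cluster, shard) bucket: a node advertising several shards stays cached until it has been evicted from every bucket it belongs to.

package peer_exchange

import (
	"testing"

	"github.com/stretchr/testify/require"
)

// Illustrative sketch only; getEnode is defined in shard_lru_test.go below.
func TestShardLRUSketchPerShardEviction(t *testing.T) {
	lru := newShardLRU(1) // one slot per (cluster, shard) bucket

	multi := getEnode(t, nil, 1, 1, 2) // advertises cluster 1, shards 1 and 2
	require.NoError(t, lru.Add(multi))

	// a newer peer on shard 1 evicts multi from the (1,1) bucket only...
	require.NoError(t, lru.Add(getEnode(t, nil, 1, 1)))

	// ...so multi is still reachable through its (1,2) entry
	require.Equal(t, multi, lru.Get(multi.ID()))
	require.Equal(t, 1, lru.len(&ShardInfo{clusterID: 1, shard: 2}))

	// once it is evicted from shard 2 as well, it is gone from the cache entirely
	require.NoError(t, lru.Add(getEnode(t, nil, 1, 2)))
	require.Nil(t, lru.Get(multi.ID()))
}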
waku/v2/protocol/peer_exchange/shard_lru_test.go (new file, 156 lines)
@@ -0,0 +1,156 @@
package peer_exchange

import (
	"crypto/ecdsa"
	"testing"

	gcrypto "github.com/ethereum/go-ethereum/crypto"
	"github.com/ethereum/go-ethereum/p2p/enode"
	"github.com/ethereum/go-ethereum/p2p/enr"
	"github.com/stretchr/testify/require"
	"github.com/waku-org/go-waku/waku/v2/protocol"
	wenr "github.com/waku-org/go-waku/waku/v2/protocol/enr"
)

func getEnode(t *testing.T, key *ecdsa.PrivateKey, cluster uint16, indexes ...uint16) *enode.Node {
	if key == nil {
		var err error
		key, err = gcrypto.GenerateKey()
		require.NoError(t, err)
	}
	myNode, err := wenr.NewLocalnode(key)
	require.NoError(t, err)
	if cluster != 0 {
		shard, err := protocol.NewRelayShards(cluster, indexes...)
		require.NoError(t, err)
		myNode.Set(enr.WithEntry(wenr.ShardingBitVectorEnrField, shard.BitVector()))
	}
	return myNode.Node()
}

func TestLruMoreThanSize(t *testing.T) {
	node1 := getEnode(t, nil, 1, 1)
	node2 := getEnode(t, nil, 1, 1, 2)
	node3 := getEnode(t, nil, 1, 1)

	lru := newShardLRU(2)

	err := lru.Add(node1)
	require.NoError(t, err)

	err = lru.Add(node2)
	require.NoError(t, err)

	err = lru.Add(node3)
	require.NoError(t, err)

	nodes := lru.GetRandomNodes(&ShardInfo{1, 1}, 1)

	require.Equal(t, 1, len(nodes))
	if nodes[0].ID() != node2.ID() && nodes[0].ID() != node3.ID() {
		t.Fatalf("different node found %v", nodes)
	}
	// check that the evicted node is deleted from the lru
	require.Nil(t, lru.Get(node1.ID()))

	// node2 is evicted from the lru for cluster 1/shard 1, but it is still maintained for cluster 1/shard 2
	{
		err = lru.Add(getEnode(t, nil, 1, 1)) // add two more nodes to the cluster 1/shard 1 bucket
		require.NoError(t, err)

		err = lru.Add(getEnode(t, nil, 1, 1))
		require.NoError(t, err)
	}

	// node2 is still present in the lru
	require.Equal(t, node2, lru.Get(node2.ID()))

	// now node2 is removed from all shards' cache
	{
		err = lru.Add(getEnode(t, nil, 1, 2)) // add two more nodes to the cluster 1/shard 2 bucket
		require.NoError(t, err)

		err = lru.Add(getEnode(t, nil, 1, 2))
		require.NoError(t, err)
	}

	// node2 is no longer present in the lru
	require.Nil(t, lru.Get(node2.ID()))
}

func TestLruNodeWithNewSeq(t *testing.T) {
	lru := newShardLRU(2)
	//
	key, err := gcrypto.GenerateKey()
	require.NoError(t, err)

	node1 := getEnode(t, key, 1, 1)
	err = lru.Add(node1)
	require.NoError(t, err)

	node1 = getEnode(t, key, 1, 2, 3)
	err = lru.Add(node1)
	require.NoError(t, err)

	//
	nodes := lru.GetRandomNodes(&ShardInfo{1, 1}, 2)
	require.Equal(t, 0, len(nodes))
	//
	nodes = lru.GetRandomNodes(&ShardInfo{1, 2}, 2)
	require.Equal(t, 1, len(nodes))
	//
	nodes = lru.GetRandomNodes(&ShardInfo{1, 3}, 2)
	require.Equal(t, 1, len(nodes))
}

func TestLruNoShard(t *testing.T) {
	lru := newShardLRU(2)

	node1 := getEnode(t, nil, 0)
	node2 := getEnode(t, nil, 0)

	err := lru.Add(node1)
	require.NoError(t, err)

	err = lru.Add(node2)
	require.NoError(t, err)

	// check returned nodes
	require.Equal(t, 2, lru.len(nil))
	for _, node := range lru.GetRandomNodes(nil, 2) {
		if node.ID() != node1.ID() && node.ID() != node2.ID() {
			t.Fatalf("different node found %v", node)
		}
	}
}

// checks that the lru is able to handle nodes with and without shards together
func TestLruMixedNodes(t *testing.T) {
	lru := newShardLRU(2)

	node1 := getEnode(t, nil, 0)
	err := lru.Add(node1)
	require.NoError(t, err)

	node2 := getEnode(t, nil, 1, 1)
	err = lru.Add(node2)
	require.NoError(t, err)

	node3 := getEnode(t, nil, 1, 2)
	err = lru.Add(node3)
	require.NoError(t, err)

	// check that len(nil) counts nodes across all shards
	require.Equal(t, 3, lru.len(nil))

	//
	nodes := lru.GetRandomNodes(&ShardInfo{1, 1}, 2)
	require.Equal(t, 1, len(nodes))
	require.Equal(t, node2.ID(), nodes[0].ID())

	//
	nodes = lru.GetRandomNodes(&ShardInfo{1, 2}, 2)
	require.Equal(t, 1, len(nodes))
	require.Equal(t, node3.ID(), nodes[0].ID())
}