refactor: shard LRU for storing peers by shard (#840)

Co-authored-by: Richard Ramos <info@richardramos.me>
harsh jain 2023-12-07 01:34:58 +07:00 committed by GitHub
parent 48ab9e6ce7
commit d7c7255aa4
4 changed files with 357 additions and 49 deletions
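
In short: instead of a single hashicorp simplelru cache holding all discovered peers, the peer-exchange ENR cache now keeps one bounded, independently evicted LRU list per (cluster, shard) pair read from each node's ENR, so requests can be answered with peers of a specific shard. Below is a minimal, self-contained sketch of that idea only; the committed shardLRU further down additionally indexes entries by enode.ID (one node can sit in several shard lists and is evicted from each separately) and guards everything with a mutex. The shardKey/miniShardLRU names here are hypothetical.

package main

import (
	"container/list"
	"fmt"
)

// shardKey and miniShardLRU are hypothetical, simplified stand-ins for the
// ShardInfo / shardLRU types added by this commit.
type shardKey struct{ cluster, shard uint16 }

type miniShardLRU struct {
	size  int                     // capacity of each shard's list, not of the whole cache
	lists map[shardKey]*list.List // front = most recently added or used
}

func newMiniShardLRU(size int) *miniShardLRU {
	return &miniShardLRU{size: size, lists: map[shardKey]*list.List{}}
}

// add files a peer under one cluster/shard pair, evicting the oldest entry
// of that pair's list when it is full; other shards are untouched.
func (m *miniShardLRU) add(key shardKey, peer string) {
	l := m.lists[key]
	if l == nil {
		l = list.New()
		m.lists[key] = l
	}
	if l.Len() >= m.size {
		l.Remove(l.Back())
	}
	l.PushFront(peer)
}

// peers returns the entries cached for one cluster/shard pair, newest first.
func (m *miniShardLRU) peers(key shardKey) []string {
	var out []string
	if l := m.lists[key]; l != nil {
		for e := l.Front(); e != nil; e = e.Next() {
			out = append(out, e.Value.(string))
		}
	}
	return out
}

func main() {
	cache := newMiniShardLRU(2)
	cache.add(shardKey{1, 1}, "peerA")
	cache.add(shardKey{1, 1}, "peerB")
	cache.add(shardKey{1, 1}, "peerC") // evicts peerA, but only from the 1/1 list
	cache.add(shardKey{1, 2}, "peerA") // the 1/2 list has its own capacity

	fmt.Println(cache.peers(shardKey{1, 1})) // [peerC peerB]
	fmt.Println(cache.peers(shardKey{1, 2})) // [peerA]
}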


@@ -3,69 +3,45 @@ package peer_exchange
import (
"bufio"
"bytes"
-"math/rand"
-"sync"
"github.com/ethereum/go-ethereum/p2p/enode"
-"github.com/hashicorp/golang-lru/simplelru"
"github.com/waku-org/go-waku/waku/v2/protocol/peer_exchange/pb"
-"go.uber.org/zap"
)
+// simpleLRU internally uses container/list, which is a ring buffer (doubly linked list)
type enrCache struct {
-// using lru, saves us from periodically cleaning the cache to maintain a certain size
-data *simplelru.LRU
-rng *rand.Rand
-mu sync.RWMutex
-log *zap.Logger
+// using lru, saves us from periodically cleaning the cache to maintain a certain size
+data *shardLRU
}
-// err on negative size
-func newEnrCache(size int, log *zap.Logger) (*enrCache, error) {
-inner, err := simplelru.NewLRU(size, nil)
+func newEnrCache(size int) *enrCache {
+inner := newShardLRU(int(size))
return &enrCache{
data: inner,
-rng: rand.New(rand.NewSource(rand.Int63())),
-log: log.Named("enr-cache"),
-}, err
+}
}
// updating cache
-func (c *enrCache) updateCache(node *enode.Node) {
-c.mu.Lock()
-defer c.mu.Unlock()
-currNode, ok := c.data.Get(node.ID())
-if !ok || node.Seq() > currNode.(*enode.Node).Seq() {
-c.data.Add(node.ID(), node)
-c.log.Debug("discovered px peer via discv5", zap.Stringer("enr", node))
+func (c *enrCache) updateCache(node *enode.Node) error {
+currNode := c.data.Get(node.ID())
+if currNode == nil || node.Seq() > currNode.Seq() {
+return c.data.Add(node)
}
+return nil
}
// get `numPeers` records of enr
-func (c *enrCache) getENRs(neededPeers int) ([]*pb.PeerInfo, error) {
-c.mu.RLock()
-defer c.mu.RUnlock()
+func (c *enrCache) getENRs(neededPeers int, clusterIndex *ShardInfo) ([]*pb.PeerInfo, error) {
+//
-availablePeers := c.data.Len()
-if availablePeers == 0 {
-return nil, nil
-}
-if availablePeers < neededPeers {
-neededPeers = availablePeers
-}
-perm := c.rng.Perm(availablePeers)[0:neededPeers]
-keys := c.data.Keys()
+nodes := c.data.GetRandomNodes(clusterIndex, neededPeers)
result := []*pb.PeerInfo{}
-for _, ind := range perm {
-node, ok := c.data.Get(keys[ind])
-if !ok {
-continue
-}
+for _, node := range nodes {
+//
var b bytes.Buffer
writer := bufio.NewWriter(&b)
-err := node.(*enode.Node).Record().EncodeRLP(writer)
+err := node.Record().EncodeRLP(writer)
if err != nil {
return nil, err
}
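
The reworked enrCache surface is small: updateCache only replaces a cached record when the new one has a higher seq (and now returns an error, since decoding a record's shard information can fail), and getENRs takes an optional *ShardInfo to scope the answer. A sketch of how it is driven, written as a fragment that would live inside the peer_exchange package (the function name and the ShardInfo value are hypothetical):

func fillAndServe(discovered []*enode.Node) ([]*pb.PeerInfo, error) {
	cache := newEnrCache(MaxCacheSize)

	for _, n := range discovered {
		// only a record with a newer seq replaces what is already cached;
		// the error comes from decoding the record's shard field
		if err := cache.updateCache(n); err != nil {
			continue
		}
	}

	// nil means "peers from any shard" (what onRequest passes below);
	// a *ShardInfo restricts the answer to one cluster/shard pair
	return cache.getENRs(5, &ShardInfo{clusterID: 1, shard: 2})
}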


@@ -57,17 +57,11 @@ func NewWakuPeerExchange(disc *discv5.DiscoveryV5, peerConnector PeerConnector,
wakuPX.disc = disc
wakuPX.metrics = newMetrics(reg)
wakuPX.log = log.Named("wakupx")
+wakuPX.enrCache = newEnrCache(MaxCacheSize)
wakuPX.peerConnector = peerConnector
wakuPX.pm = pm
wakuPX.CommonService = service.NewCommonService()
-newEnrCache, err := newEnrCache(MaxCacheSize, wakuPX.log)
-if err != nil {
-return nil, err
-}
-wakuPX.enrCache = newEnrCache
return wakuPX, nil
}
@@ -108,7 +102,7 @@ func (wakuPX *WakuPeerExchange) onRequest() func(network.Stream) {
if requestRPC.Query != nil {
logger.Info("request received")
-records, err := wakuPX.enrCache.getENRs(int(requestRPC.Query.NumPeers))
+records, err := wakuPX.enrCache.getENRs(int(requestRPC.Query.NumPeers), nil)
if err != nil {
logger.Error("obtaining enrs from cache", zap.Error(err))
wakuPX.metrics.RecordError(pxFailure)
@@ -161,7 +155,11 @@ func (wakuPX *WakuPeerExchange) iterate(ctx context.Context) error {
continue
}
-wakuPX.enrCache.updateCache(iterator.Node())
+err = wakuPX.enrCache.updateCache(iterator.Node())
+if err != nil {
+wakuPX.log.Error("adding peer to cache", zap.Error(err))
+continue
+}
select {
case <-ctx.Done():
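
One behavioural consequence of the wiring above: MaxCacheSize is now the capacity of each (cluster, shard) list rather than of the cache as a whole, so the total number of cached peers can exceed it when nodes are spread over several shards. A test-style sketch of that, reusing the getEnode helper defined in the tests further down (the test name is hypothetical):

func TestCapacityIsPerShard(t *testing.T) {
	lru := newShardLRU(2) // 2 nodes per cluster/shard list, not 2 overall

	require.NoError(t, lru.Add(getEnode(t, nil, 1, 1))) // cluster 1, shard 1
	require.NoError(t, lru.Add(getEnode(t, nil, 1, 2))) // cluster 1, shard 2
	require.NoError(t, lru.Add(getEnode(t, nil, 1, 3))) // cluster 1, shard 3

	// all three survive because each lives in a different per-shard list
	require.Equal(t, 3, lru.len(nil))
}

This mirrors TestLruMixedNodes below, where three nodes on three different shards all fit into a shardLRU of size 2.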


@@ -0,0 +1,178 @@
package peer_exchange
import (
"container/list"
"fmt"
"math/rand"
"sync"
"github.com/ethereum/go-ethereum/p2p/enode"
"github.com/waku-org/go-waku/waku/v2/protocol"
wenr "github.com/waku-org/go-waku/waku/v2/protocol/enr"
"github.com/waku-org/go-waku/waku/v2/utils"
)
type ShardInfo struct {
clusterID uint16
shard uint16
}
type shardLRU struct {
size int // number of nodes allowed per shard
idToNode map[enode.ID][]*list.Element
shardNodes map[ShardInfo]*list.List
rng *rand.Rand
mu sync.RWMutex
}
func newShardLRU(size int) *shardLRU {
return &shardLRU{
idToNode: map[enode.ID][]*list.Element{},
shardNodes: map[ShardInfo]*list.List{},
size: size,
rng: rand.New(rand.NewSource(rand.Int63())),
}
}
type nodeWithShardInfo struct {
key ShardInfo
node *enode.Node
}
// time complexity: O(number of previous indexes present for node.ID)
func (l *shardLRU) remove(node *enode.Node) {
elements := l.idToNode[node.ID()]
for _, element := range elements {
key := element.Value.(nodeWithShardInfo).key
l.shardNodes[key].Remove(element)
}
delete(l.idToNode, node.ID())
}
// if a node is removed from a list, remove it from idToNode too
func (l *shardLRU) removeFromIdToNode(ele *list.Element) {
nodeID := ele.Value.(nodeWithShardInfo).node.ID()
for ind, entries := range l.idToNode[nodeID] {
if entries == ele {
l.idToNode[nodeID] = append(l.idToNode[nodeID][:ind], l.idToNode[nodeID][ind+1:]...)
break
}
}
if len(l.idToNode[nodeID]) == 0 {
delete(l.idToNode, nodeID)
}
}
func nodeToRelayShard(node *enode.Node) (*protocol.RelayShards, error) {
shard, err := wenr.RelaySharding(node.Record())
if err != nil {
return nil, err
}
if shard == nil { // if there is no shard info, add the node to cluster 0, index 0
shard = &protocol.RelayShards{
ClusterID: 0,
ShardIDs: []uint16{0},
}
}
return shard, nil
}
// time complexity: O(number of shard indexes in the node's new record)
func (l *shardLRU) add(node *enode.Node) error {
shard, err := nodeToRelayShard(node)
if err != nil {
return err
}
elements := []*list.Element{}
for _, index := range shard.ShardIDs {
key := ShardInfo{
shard.ClusterID,
index,
}
if l.shardNodes[key] == nil {
l.shardNodes[key] = list.New()
}
if l.shardNodes[key].Len() >= l.size {
oldest := l.shardNodes[key].Back()
l.removeFromIdToNode(oldest)
l.shardNodes[key].Remove(oldest)
}
entry := l.shardNodes[key].PushFront(nodeWithShardInfo{
key: key,
node: node,
})
elements = append(elements, entry)
}
l.idToNode[node.ID()] = elements
return nil
}
// Add is called when the node's seq number is higher than the one already in the cache
func (l *shardLRU) Add(node *enode.Node) error {
l.mu.Lock()
defer l.mu.Unlock()
// remove first, because the previous version of the node might be subscribed to different shards; it has to be removed from those shard lists
l.remove(node)
return l.add(node)
}
// clusterIndex is nil when peers are requested without a specific shard
func (l *shardLRU) GetRandomNodes(clusterIndex *ShardInfo, neededPeers int) (nodes []*enode.Node) {
l.mu.Lock()
defer l.mu.Unlock()
availablePeers := l.len(clusterIndex)
if availablePeers < neededPeers {
neededPeers = availablePeers
}
// if clusterIndex is nil, then return all nodes
var elements []*list.Element
if clusterIndex == nil {
elements = make([]*list.Element, 0, len(l.idToNode))
for _, entries := range l.idToNode {
elements = append(elements, entries[0])
}
} else if entries := l.shardNodes[*clusterIndex]; entries != nil && entries.Len() != 0 {
elements = make([]*list.Element, 0, entries.Len())
for ent := entries.Back(); ent != nil; ent = ent.Prev() {
elements = append(elements, ent)
}
}
utils.Logger().Info(fmt.Sprintf("%d", len(elements)))
indexes := l.rng.Perm(len(elements))[0:neededPeers]
for _, ind := range indexes {
node := elements[ind].Value.(nodeWithShardInfo).node
nodes = append(nodes, node)
// this removes the node from all lists (every cluster/shard pair the node has) and re-adds it at the front
l.remove(node)
_ = l.add(node)
}
return nodes
}
// if clusterIndex is not nil, return the number of nodes maintained for the given shard
// if clusterIndex is nil, return the count of all nodes maintained
func (l *shardLRU) len(clusterIndex *ShardInfo) int {
if clusterIndex == nil {
return len(l.idToNode)
}
if entries := l.shardNodes[*clusterIndex]; entries != nil {
return entries.Len()
}
return 0
}
// get the node with the given id, if it is present in the cache
func (l *shardLRU) Get(id enode.ID) *enode.Node {
l.mu.RLock()
defer l.mu.RUnlock()
if elements, ok := l.idToNode[id]; ok && len(elements) > 0 {
return elements[0].Value.(nodeWithShardInfo).node
}
return nil
}
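
For completeness, a usage sketch of shardLRU itself, again as a package-internal fragment with hypothetical names and values. One design choice worth calling out: GetRandomNodes removes and re-adds every node it returns, which moves it to the front of all of its shard lists, so peers that keep getting served are the last ones to be evicted.

func exampleShardLRUUsage(discovered []*enode.Node) {
	lru := newShardLRU(100) // hypothetical capacity, enforced per cluster/shard list

	for _, n := range discovered {
		// filed under every cluster/shard pair in the node's ENR,
		// or under cluster 0/shard 0 if it advertises none
		_ = lru.Add(n)
	}

	// up to 10 peers regardless of shard (clusterIndex == nil)...
	anyShard := lru.GetRandomNodes(nil, 10)

	// ...or only peers advertising cluster 1, shard 2
	shardPeers := lru.GetRandomNodes(&ShardInfo{clusterID: 1, shard: 2}, 10)

	// Get looks a node up by enode.ID; it returns nil once the node
	// has been evicted from every list it was filed under
	for _, n := range append(anyShard, shardPeers...) {
		_ = lru.Get(n.ID())
	}
}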


@@ -0,0 +1,156 @@
package peer_exchange
import (
"crypto/ecdsa"
"testing"
gcrypto "github.com/ethereum/go-ethereum/crypto"
"github.com/ethereum/go-ethereum/p2p/enode"
"github.com/ethereum/go-ethereum/p2p/enr"
"github.com/stretchr/testify/require"
"github.com/waku-org/go-waku/waku/v2/protocol"
wenr "github.com/waku-org/go-waku/waku/v2/protocol/enr"
)
func getEnode(t *testing.T, key *ecdsa.PrivateKey, cluster uint16, indexes ...uint16) *enode.Node {
if key == nil {
var err error
key, err = gcrypto.GenerateKey()
require.NoError(t, err)
}
myNode, err := wenr.NewLocalnode(key)
require.NoError(t, err)
if cluster != 0 {
shard, err := protocol.NewRelayShards(cluster, indexes...)
require.NoError(t, err)
myNode.Set(enr.WithEntry(wenr.ShardingBitVectorEnrField, shard.BitVector()))
}
return myNode.Node()
}
func TestLruMoreThanSize(t *testing.T) {
node1 := getEnode(t, nil, 1, 1)
node2 := getEnode(t, nil, 1, 1, 2)
node3 := getEnode(t, nil, 1, 1)
lru := newShardLRU(2)
err := lru.Add(node1)
require.NoError(t, err)
err = lru.Add(node2)
require.NoError(t, err)
err = lru.Add(node3)
require.NoError(t, err)
nodes := lru.GetRandomNodes(&ShardInfo{1, 1}, 1)
require.Equal(t, 1, len(nodes))
if nodes[0].ID() != node2.ID() && nodes[0].ID() != node3.ID() {
t.Fatalf("different node found %v", nodes)
}
// checks if removed node is deleted from lru or not
require.Nil(t, lru.Get(node1.ID()))
// node2 is evicted from the cluster 1/shard 1 list but is still maintained for cluster 1/shard 2
{
err = lru.Add(getEnode(t, nil, 1, 1)) // add two more nodes to cluster 1/shard 1
require.NoError(t, err)
err = lru.Add(getEnode(t, nil, 1, 1))
require.NoError(t, err)
}
// node2 still present in lru
require.Equal(t, node2, lru.Get(node2.ID()))
// now node2 is removed from all shards' cache
{
err = lru.Add(getEnode(t, nil, 1, 2)) // add two more nodes to cluster 1/shard 2
require.NoError(t, err)
err = lru.Add(getEnode(t, nil, 1, 2))
require.NoError(t, err)
}
// node2 is no longer present in the lru
require.Nil(t, lru.Get(node2.ID()))
}
func TestLruNodeWithNewSeq(t *testing.T) {
lru := newShardLRU(2)
//
key, err := gcrypto.GenerateKey()
require.NoError(t, err)
node1 := getEnode(t, key, 1, 1)
err = lru.Add(node1)
require.NoError(t, err)
node1 = getEnode(t, key, 1, 2, 3)
err = lru.Add(node1)
require.NoError(t, err)
//
nodes := lru.GetRandomNodes(&ShardInfo{1, 1}, 2)
require.Equal(t, 0, len(nodes))
//
nodes = lru.GetRandomNodes(&ShardInfo{1, 2}, 2)
require.Equal(t, 1, len(nodes))
//
nodes = lru.GetRandomNodes(&ShardInfo{1, 3}, 2)
require.Equal(t, 1, len(nodes))
}
func TestLruNoShard(t *testing.T) {
lru := newShardLRU(2)
node1 := getEnode(t, nil, 0)
node2 := getEnode(t, nil, 0)
err := lru.Add(node1)
require.NoError(t, err)
err = lru.Add(node2)
require.NoError(t, err)
// check returned nodes
require.Equal(t, 2, lru.len(nil))
for _, node := range lru.GetRandomNodes(nil, 2) {
if node.ID() != node1.ID() && node.ID() != node2.ID() {
t.Fatalf("different node found %v", node)
}
}
}
// checks if lru is able to handle nodes with/without shards together
func TestLruMixedNodes(t *testing.T) {
lru := newShardLRU(2)
node1 := getEnode(t, nil, 0)
err := lru.Add(node1)
require.NoError(t, err)
node2 := getEnode(t, nil, 1, 1)
err = lru.Add(node2)
require.NoError(t, err)
node3 := getEnode(t, nil, 1, 2)
err = lru.Add(node3)
require.NoError(t, err)
// check that the default-shard node and the sharded nodes are all counted
require.Equal(t, 3, lru.len(nil))
//
nodes := lru.GetRandomNodes(&ShardInfo{1, 1}, 2)
require.Equal(t, 1, len(nodes))
require.Equal(t, node2.ID(), nodes[0].ID())
//
nodes = lru.GetRandomNodes(&ShardInfo{1, 2}, 2)
require.Equal(t, 1, len(nodes))
require.Equal(t, node3.ID(), nodes[0].ID())
}