diff --git a/waku/v2/protocol/peer_exchange/enr_cache.go b/waku/v2/protocol/peer_exchange/enr_cache.go
index 06538d2b..90fb096a 100644
--- a/waku/v2/protocol/peer_exchange/enr_cache.go
+++ b/waku/v2/protocol/peer_exchange/enr_cache.go
@@ -3,69 +3,45 @@ package peer_exchange
 import (
 	"bufio"
 	"bytes"
-	"math/rand"
-	"sync"

 	"github.com/ethereum/go-ethereum/p2p/enode"
-	"github.com/hashicorp/golang-lru/simplelru"
+
 	"github.com/waku-org/go-waku/waku/v2/protocol/peer_exchange/pb"
-	"go.uber.org/zap"
 )

 // simpleLRU internal uses container/list, which is ring buffer(double linked list)
 type enrCache struct {
-	// using lru, saves us from periodically cleaning the cache to maintain a certain size
-	data *simplelru.LRU
-	rng  *rand.Rand
-	mu   sync.RWMutex
-	log  *zap.Logger
+	// using lru, saves us from periodically cleaning the cache to maintain a certain size
+	data *shardLRU
 }

 // err on negative size
-func newEnrCache(size int, log *zap.Logger) (*enrCache, error) {
-	inner, err := simplelru.NewLRU(size, nil)
+func newEnrCache(size int) *enrCache {
+	inner := newShardLRU(int(size))
 	return &enrCache{
 		data: inner,
-		rng:  rand.New(rand.NewSource(rand.Int63())),
-		log:  log.Named("enr-cache"),
-	}, err
+	}
 }

 // updating cache
-func (c *enrCache) updateCache(node *enode.Node) {
-	c.mu.Lock()
-	defer c.mu.Unlock()
-	currNode, ok := c.data.Get(node.ID())
-	if !ok || node.Seq() > currNode.(*enode.Node).Seq() {
-		c.data.Add(node.ID(), node)
-		c.log.Debug("discovered px peer via discv5", zap.Stringer("enr", node))
+func (c *enrCache) updateCache(node *enode.Node) error {
+	currNode := c.data.Get(node.ID())
+	if currNode == nil || node.Seq() > currNode.Seq() {
+		return c.data.Add(node)
 	}
+	return nil
 }

 // get `numPeers` records of enr
-func (c *enrCache) getENRs(neededPeers int) ([]*pb.PeerInfo, error) {
-	c.mu.RLock()
-	defer c.mu.RUnlock()
+func (c *enrCache) getENRs(neededPeers int, clusterIndex *ShardInfo) ([]*pb.PeerInfo, error) {
 	//
-	availablePeers := c.data.Len()
-	if availablePeers == 0 {
-		return nil, nil
-	}
-	if availablePeers < neededPeers {
-		neededPeers = availablePeers
-	}
-
-	perm := c.rng.Perm(availablePeers)[0:neededPeers]
-	keys := c.data.Keys()
+	nodes := c.data.GetRandomNodes(clusterIndex, neededPeers)
 	result := []*pb.PeerInfo{}
-	for _, ind := range perm {
-		node, ok := c.data.Get(keys[ind])
-		if !ok {
-			continue
-		}
+	for _, node := range nodes {
+		//
 		var b bytes.Buffer
 		writer := bufio.NewWriter(&b)
-		err := node.(*enode.Node).Record().EncodeRLP(writer)
+		err := node.Record().EncodeRLP(writer)
 		if err != nil {
 			return nil, err
 		}
diff --git a/waku/v2/protocol/peer_exchange/protocol.go b/waku/v2/protocol/peer_exchange/protocol.go
index e8af91c5..374cde3a 100644
--- a/waku/v2/protocol/peer_exchange/protocol.go
+++ b/waku/v2/protocol/peer_exchange/protocol.go
@@ -57,17 +57,11 @@ func NewWakuPeerExchange(disc *discv5.DiscoveryV5, peerConnector PeerConnector,
 	wakuPX.disc = disc
 	wakuPX.metrics = newMetrics(reg)
 	wakuPX.log = log.Named("wakupx")
+	wakuPX.enrCache = newEnrCache(MaxCacheSize)
 	wakuPX.peerConnector = peerConnector
 	wakuPX.pm = pm
 	wakuPX.CommonService = service.NewCommonService()

-	newEnrCache, err := newEnrCache(MaxCacheSize, wakuPX.log)
-	if err != nil {
-		return nil, err
-	}
-
-	wakuPX.enrCache = newEnrCache
-
 	return wakuPX, nil
 }

@@ -108,7 +102,7 @@ func (wakuPX *WakuPeerExchange) onRequest() func(network.Stream) {
 		if requestRPC.Query != nil {
 			logger.Info("request received")

-			records, err := wakuPX.enrCache.getENRs(int(requestRPC.Query.NumPeers))
+			records, err := wakuPX.enrCache.getENRs(int(requestRPC.Query.NumPeers), nil)
 			if err != nil {
 				logger.Error("obtaining enrs from cache", zap.Error(err))
 				wakuPX.metrics.RecordError(pxFailure)
@@ -161,7 +155,11 @@ func (wakuPX *WakuPeerExchange) iterate(ctx context.Context) error {
 			continue
 		}

-		wakuPX.enrCache.updateCache(iterator.Node())
+		err = wakuPX.enrCache.updateCache(iterator.Node())
+		if err != nil {
+			wakuPX.log.Error("adding peer to cache", zap.Error(err))
+			continue
+		}

 		select {
 		case <-ctx.Done():
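Illustration (not part of the diff): onRequest passes a nil *ShardInfo, so getENRs falls back to the unfiltered node set; a shard-scoped lookup would pass a concrete ShardInfo instead. A minimal sketch, assuming it lives inside the peer_exchange package; the helper name and the cluster/shard values are illustrative:

package peer_exchange

// exampleGetENRs sketches the two call shapes of the new getENRs signature
// (hypothetical helper, for illustration only).
func exampleGetENRs(cache *enrCache) error {
	// Unfiltered, as onRequest does today: any cached node may be returned.
	if _, err := cache.getENRs(10, nil); err != nil {
		return err
	}

	// Scoped to cluster 1 / shard 2: only nodes whose ENR advertises that
	// cluster/shard pair are candidates. Nodes without shard info live in
	// the cluster 0 / shard 0 bucket and would not be returned here.
	_, err := cache.getENRs(10, &ShardInfo{clusterID: 1, shard: 2})
	return err
}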
diff --git a/waku/v2/protocol/peer_exchange/shard_lru.go b/waku/v2/protocol/peer_exchange/shard_lru.go
new file mode 100644
index 00000000..fade60a9
--- /dev/null
+++ b/waku/v2/protocol/peer_exchange/shard_lru.go
@@ -0,0 +1,178 @@
+package peer_exchange
+
+import (
+	"container/list"
+	"fmt"
+	"math/rand"
+	"sync"
+
+	"github.com/ethereum/go-ethereum/p2p/enode"
+	"github.com/waku-org/go-waku/waku/v2/protocol"
+	wenr "github.com/waku-org/go-waku/waku/v2/protocol/enr"
+	"github.com/waku-org/go-waku/waku/v2/utils"
+)
+
+type ShardInfo struct {
+	clusterID uint16
+	shard     uint16
+}
+type shardLRU struct {
+	size       int // number of nodes allowed per shard
+	idToNode   map[enode.ID][]*list.Element
+	shardNodes map[ShardInfo]*list.List
+	rng        *rand.Rand
+	mu         sync.RWMutex
+}
+
+func newShardLRU(size int) *shardLRU {
+	return &shardLRU{
+		idToNode:   map[enode.ID][]*list.Element{},
+		shardNodes: map[ShardInfo]*list.List{},
+		size:       size,
+		rng:        rand.New(rand.NewSource(rand.Int63())),
+	}
+}
+
+type nodeWithShardInfo struct {
+	key  ShardInfo
+	node *enode.Node
+}
+
+// time complexity: O(number of previous indexes present for node.ID)
+func (l *shardLRU) remove(node *enode.Node) {
+	elements := l.idToNode[node.ID()]
+	for _, element := range elements {
+		key := element.Value.(nodeWithShardInfo).key
+		l.shardNodes[key].Remove(element)
+	}
+	delete(l.idToNode, node.ID())
+}
+
+// if a node is removed from a list, remove it from idToNode too
+func (l *shardLRU) removeFromIdToNode(ele *list.Element) {
+	nodeID := ele.Value.(nodeWithShardInfo).node.ID()
+	for ind, entries := range l.idToNode[nodeID] {
+		if entries == ele {
+			l.idToNode[nodeID] = append(l.idToNode[nodeID][:ind], l.idToNode[nodeID][ind+1:]...)
+			break
+		}
+	}
+	if len(l.idToNode[nodeID]) == 0 {
+		delete(l.idToNode, nodeID)
+	}
+}
+
+func nodeToRelayShard(node *enode.Node) (*protocol.RelayShards, error) {
+	shard, err := wenr.RelaySharding(node.Record())
+	if err != nil {
+		return nil, err
+	}
+
+	if shard == nil { // if there is no shard info, add the node to cluster 0, index 0
+		shard = &protocol.RelayShards{
+			ClusterID: 0,
+			ShardIDs:  []uint16{0},
+		}
+	}
+
+	return shard, nil
+}
+
+// time complexity: O(new number of indexes in node's shard)
+func (l *shardLRU) add(node *enode.Node) error {
+	shard, err := nodeToRelayShard(node)
+	if err != nil {
+		return err
+	}
+
+	elements := []*list.Element{}
+	for _, index := range shard.ShardIDs {
+		key := ShardInfo{
+			shard.ClusterID,
+			index,
+		}
+		if l.shardNodes[key] == nil {
+			l.shardNodes[key] = list.New()
+		}
+		if l.shardNodes[key].Len() >= l.size {
+			oldest := l.shardNodes[key].Back()
+			l.removeFromIdToNode(oldest)
+			l.shardNodes[key].Remove(oldest)
+		}
+		entry := l.shardNodes[key].PushFront(nodeWithShardInfo{
+			key:  key,
+			node: node,
+		})
+		elements = append(elements, entry)
+
+	}
+	l.idToNode[node.ID()] = elements
+
+	return nil
+}
+
+// this is called when the seq number of the node is greater than the one in the cache
+func (l *shardLRU) Add(node *enode.Node) error {
+	l.mu.Lock()
+	defer l.mu.Unlock()
+	// remove first: the previous version of the node might be subscribed to different shards, and it must be dropped from those shard lists
+	l.remove(node)
+	return l.add(node)
+}
+
+// clusterIndex is nil when peers for no specific shard are requested
+func (l *shardLRU) GetRandomNodes(clusterIndex *ShardInfo, neededPeers int) (nodes []*enode.Node) {
+	l.mu.Lock()
+	defer l.mu.Unlock()
+
+	availablePeers := l.len(clusterIndex)
+	if availablePeers < neededPeers {
+		neededPeers = availablePeers
+	}
+	// if clusterIndex is nil, then return all nodes
+	var elements []*list.Element
+	if clusterIndex == nil {
+		elements = make([]*list.Element, 0, len(l.idToNode))
+		for _, entries := range l.idToNode {
+			elements = append(elements, entries[0])
+		}
+	} else if entries := l.shardNodes[*clusterIndex]; entries != nil && entries.Len() != 0 {
+		elements = make([]*list.Element, 0, entries.Len())
+		for ent := entries.Back(); ent != nil; ent = ent.Prev() {
+			elements = append(elements, ent)
+		}
+	}
+	utils.Logger().Info(fmt.Sprintf("%d", len(elements)))
+	indexes := l.rng.Perm(len(elements))[0:neededPeers]
+	for _, ind := range indexes {
+		node := elements[ind].Value.(nodeWithShardInfo).node
+		nodes = append(nodes, node)
+		// remove the node from all lists (every cluster/shard pair it has) and re-add it to the front
+		l.remove(node)
+		_ = l.add(node)
+	}
+	return nodes
+}
+
+// if clusterIndex is not nil, return the number of nodes maintained for the given shard
+// if clusterIndex is nil, return the count of all nodes maintained
+func (l *shardLRU) len(clusterIndex *ShardInfo) int {
+	if clusterIndex == nil {
+		return len(l.idToNode)
+	}
+	if entries := l.shardNodes[*clusterIndex]; entries != nil {
+		return entries.Len()
+	}
+	return 0
+}
+
+// get the node with the given id, if it is present in the cache
+func (l *shardLRU) Get(id enode.ID) *enode.Node {
+	l.mu.RLock()
+	defer l.mu.RUnlock()
+
+	if elements, ok := l.idToNode[id]; ok && len(elements) > 0 {
+		return elements[0].Value.(nodeWithShardInfo).node
+	}
+	return nil
+}
diff --git a/waku/v2/protocol/peer_exchange/shard_lru_test.go b/waku/v2/protocol/peer_exchange/shard_lru_test.go
new file mode 100644
index 00000000..3bbd7f46
--- /dev/null
+++ b/waku/v2/protocol/peer_exchange/shard_lru_test.go
@@ -0,0 +1,156 @@
+package peer_exchange
+
+import (
+	"crypto/ecdsa"
+	"testing"
+
+	gcrypto "github.com/ethereum/go-ethereum/crypto"
+	"github.com/ethereum/go-ethereum/p2p/enode"
+	"github.com/ethereum/go-ethereum/p2p/enr"
+	"github.com/stretchr/testify/require"
+	"github.com/waku-org/go-waku/waku/v2/protocol"
+	wenr "github.com/waku-org/go-waku/waku/v2/protocol/enr"
+)
+
+func getEnode(t *testing.T, key *ecdsa.PrivateKey, cluster uint16, indexes ...uint16) *enode.Node {
+	if key == nil {
+		var err error
+		key, err = gcrypto.GenerateKey()
+		require.NoError(t, err)
+	}
+	myNode, err := wenr.NewLocalnode(key)
+	require.NoError(t, err)
+	if cluster != 0 {
+		shard, err := protocol.NewRelayShards(cluster, indexes...)
+		require.NoError(t, err)
+		myNode.Set(enr.WithEntry(wenr.ShardingBitVectorEnrField, shard.BitVector()))
+	}
+	return myNode.Node()
+}
+
+func TestLruMoreThanSize(t *testing.T) {
+	node1 := getEnode(t, nil, 1, 1)
+	node2 := getEnode(t, nil, 1, 1, 2)
+	node3 := getEnode(t, nil, 1, 1)
+
+	lru := newShardLRU(2)
+
+	err := lru.Add(node1)
+	require.NoError(t, err)
+
+	err = lru.Add(node2)
+	require.NoError(t, err)
+
+	err = lru.Add(node3)
+	require.NoError(t, err)
+
+	nodes := lru.GetRandomNodes(&ShardInfo{1, 1}, 1)
+
+	require.Equal(t, 1, len(nodes))
+	if nodes[0].ID() != node2.ID() && nodes[0].ID() != node3.ID() {
+		t.Fatalf("different node found %v", nodes)
+	}
+	// check that the evicted node (node1) is deleted from the lru
+	require.Nil(t, lru.Get(node1.ID()))
+
+	// node2 is evicted from the cluster 1/shard 1 list, but it is still maintained for cluster 1/shard 2
+	{
+		err = lru.Add(getEnode(t, nil, 1, 1)) // add two more nodes to the cluster 1/shard 1 list
+		require.NoError(t, err)
+
+		err = lru.Add(getEnode(t, nil, 1, 1))
+		require.NoError(t, err)
+
+	}
+
+	// node2 is still present in the lru
+	require.Equal(t, node2, lru.Get(node2.ID()))
+
+	// now node2 is removed from all shards' cache
+	{
+		err = lru.Add(getEnode(t, nil, 1, 2)) // add two more nodes to the cluster 1/shard 2 list
+		require.NoError(t, err)
+
+		err = lru.Add(getEnode(t, nil, 1, 2))
+		require.NoError(t, err)
+	}
+
+	// node2 is no longer present in the lru
+	require.Nil(t, lru.Get(node2.ID()))
+}
+
+func TestLruNodeWithNewSeq(t *testing.T) {
+	lru := newShardLRU(2)
+	//
+	key, err := gcrypto.GenerateKey()
+	require.NoError(t, err)
+
+	node1 := getEnode(t, key, 1, 1)
+	err = lru.Add(node1)
+	require.NoError(t, err)
+
+	node1 = getEnode(t, key, 1, 2, 3)
+	err = lru.Add(node1)
+	require.NoError(t, err)
+
+	//
+	nodes := lru.GetRandomNodes(&ShardInfo{1, 1}, 2)
+	require.Equal(t, 0, len(nodes))
+	//
+	nodes = lru.GetRandomNodes(&ShardInfo{1, 2}, 2)
+	require.Equal(t, 1, len(nodes))
+	//
+	nodes = lru.GetRandomNodes(&ShardInfo{1, 3}, 2)
+	require.Equal(t, 1, len(nodes))
+}
+
+func TestLruNoShard(t *testing.T) {
+	lru := newShardLRU(2)
+
+	node1 := getEnode(t, nil, 0)
+	node2 := getEnode(t, nil, 0)
+
+	err := lru.Add(node1)
+	require.NoError(t, err)
+
+	err = lru.Add(node2)
+	require.NoError(t, err)
+
+	// check returned nodes
+	require.Equal(t, 2, lru.len(nil))
+	for _, node := range lru.GetRandomNodes(nil, 2) {
+		if node.ID() != node1.ID() && node.ID() != node2.ID() {
+			t.Fatalf("different node found %v", node)
+		}
+	}
+}
+
+// checks that the lru is able to handle nodes with and without shards together
+func TestLruMixedNodes(t *testing.T) {
+	lru := newShardLRU(2)
+
+	node1 := getEnode(t, nil, 0)
+	err := lru.Add(node1)
+	require.NoError(t, err)
+
+	node2 := getEnode(t, nil, 1, 1)
+	err = lru.Add(node2)
+	require.NoError(t, err)
+
+	node3 := getEnode(t, nil, 1, 2)
+	err = lru.Add(node3)
+	require.NoError(t, err)
+
+	// check that all three nodes are counted, including the shard-less one
+	require.Equal(t, 3, lru.len(nil))
+
+	//
+	nodes := lru.GetRandomNodes(&ShardInfo{1, 1}, 2)
+	require.Equal(t, 1, len(nodes))
+	require.Equal(t, node2.ID(), nodes[0].ID())
+
+	//
+	nodes = lru.GetRandomNodes(&ShardInfo{1, 2}, 2)
+	require.Equal(t, 1, len(nodes))
+	require.Equal(t, node3.ID(), nodes[0].ID())
+}
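Illustration (not part of the diff): the new cache is a set of per-shard LRU lists cross-linked by node ID, so a node that advertises several shards occupies one element per list and only drops out of the cache once every one of its entries has been evicted or superseded. A minimal sketch of the intended call pattern, assuming it runs inside the peer_exchange package; the helper name is hypothetical and node construction is omitted (see getEnode above for one way to build test nodes):

package peer_exchange

import "github.com/ethereum/go-ethereum/p2p/enode"

// exampleShardLRU sketches how the sharded cache is meant to be driven.
// nodeA and nodeB stand for *enode.Node values obtained from discv5.
func exampleShardLRU(nodeA, nodeB *enode.Node) ([]*enode.Node, error) {
	lru := newShardLRU(2) // keep at most 2 nodes per cluster/shard pair

	// Add links a node into every per-shard list its ENR advertises;
	// re-adding the same node ID first drops its previous entries.
	if err := lru.Add(nodeA); err != nil {
		return nil, err
	}
	if err := lru.Add(nodeB); err != nil {
		return nil, err
	}

	// Selection can be scoped to one cluster/shard pair (illustrative values)...
	scoped := lru.GetRandomNodes(&ShardInfo{clusterID: 1, shard: 1}, 1)
	_ = scoped

	// ...or left unscoped with nil, which draws from all cached node IDs.
	return lru.GetRandomNodes(nil, 2), nil
}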