mirror of https://github.com/status-im/go-waku.git
179 lines
4.7 KiB
Go
179 lines
4.7 KiB
Go
package peer_exchange
|
|
|
|
import (
|
|
"container/list"
|
|
"fmt"
|
|
"math/rand"
|
|
"sync"
|
|
|
|
"github.com/ethereum/go-ethereum/p2p/enode"
|
|
"github.com/waku-org/go-waku/waku/v2/protocol"
|
|
wenr "github.com/waku-org/go-waku/waku/v2/protocol/enr"
|
|
"github.com/waku-org/go-waku/waku/v2/utils"
|
|
)
|
|
|
|
// ShardInfo identifies one relay shard: the cluster it belongs to together
// with the shard index inside that cluster. It is a comparable value type and
// is used as the key of shardLRU.shardNodes.
type ShardInfo struct {
	clusterID uint16 // cluster identifier
	shard     uint16 // shard index within the cluster
}
|
|
type shardLRU struct {
|
|
size int // number of nodes allowed per shard
|
|
idToNode map[enode.ID][]*list.Element
|
|
shardNodes map[ShardInfo]*list.List
|
|
rng *rand.Rand
|
|
mu sync.RWMutex
|
|
}
|
|
|
|
func newShardLRU(size int) *shardLRU {
|
|
return &shardLRU{
|
|
idToNode: map[enode.ID][]*list.Element{},
|
|
shardNodes: map[ShardInfo]*list.List{},
|
|
size: size,
|
|
rng: rand.New(rand.NewSource(rand.Int63())),
|
|
}
|
|
}
|
|
|
|
type nodeWithShardInfo struct {
|
|
key ShardInfo
|
|
node *enode.Node
|
|
}
|
|
|
|
// time complexity: O(number of previous indexes present for node.ID)
|
|
func (l *shardLRU) remove(node *enode.Node) {
|
|
elements := l.idToNode[node.ID()]
|
|
for _, element := range elements {
|
|
key := element.Value.(nodeWithShardInfo).key
|
|
l.shardNodes[key].Remove(element)
|
|
}
|
|
delete(l.idToNode, node.ID())
|
|
}
|
|
|
|
// if a node is removed for a list, remove it from idToNode too
|
|
func (l *shardLRU) removeFromIdToNode(ele *list.Element) {
|
|
nodeID := ele.Value.(nodeWithShardInfo).node.ID()
|
|
for ind, entries := range l.idToNode[nodeID] {
|
|
if entries == ele {
|
|
l.idToNode[nodeID] = append(l.idToNode[nodeID][:ind], l.idToNode[nodeID][ind+1:]...)
|
|
break
|
|
}
|
|
}
|
|
if len(l.idToNode[nodeID]) == 0 {
|
|
delete(l.idToNode, nodeID)
|
|
}
|
|
}
|
|
|
|
func nodeToRelayShard(node *enode.Node) (*protocol.RelayShards, error) {
|
|
shard, err := wenr.RelaySharding(node.Record())
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
if shard == nil { // if no shard info, then add to node to Cluster 0, Index 0
|
|
shard = &protocol.RelayShards{
|
|
ClusterID: 0,
|
|
ShardIDs: []uint16{0},
|
|
}
|
|
}
|
|
|
|
return shard, nil
|
|
}
|
|
|
|
// time complexity: O(new number of indexes in node's shard)
|
|
func (l *shardLRU) add(node *enode.Node) error {
|
|
shard, err := nodeToRelayShard(node)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
elements := []*list.Element{}
|
|
for _, index := range shard.ShardIDs {
|
|
key := ShardInfo{
|
|
shard.ClusterID,
|
|
index,
|
|
}
|
|
if l.shardNodes[key] == nil {
|
|
l.shardNodes[key] = list.New()
|
|
}
|
|
if l.shardNodes[key].Len() >= l.size {
|
|
oldest := l.shardNodes[key].Back()
|
|
l.removeFromIdToNode(oldest)
|
|
l.shardNodes[key].Remove(oldest)
|
|
}
|
|
entry := l.shardNodes[key].PushFront(nodeWithShardInfo{
|
|
key: key,
|
|
node: node,
|
|
})
|
|
elements = append(elements, entry)
|
|
|
|
}
|
|
l.idToNode[node.ID()] = elements
|
|
|
|
return nil
|
|
}
|
|
|
|
// this will be called when the seq number of node is more than the one in cache
|
|
func (l *shardLRU) Add(node *enode.Node) error {
|
|
l.mu.Lock()
|
|
defer l.mu.Unlock()
|
|
// removing bcz previous node might be subscribed to different shards, we need to remove node from those shards
|
|
l.remove(node)
|
|
return l.add(node)
|
|
}
|
|
|
|
// clusterIndex is nil when peers for no specific shard are requested
|
|
func (l *shardLRU) GetRandomNodes(clusterIndex *ShardInfo, neededPeers int) (nodes []*enode.Node) {
|
|
l.mu.Lock()
|
|
defer l.mu.Unlock()
|
|
|
|
availablePeers := l.len(clusterIndex)
|
|
if availablePeers < neededPeers {
|
|
neededPeers = availablePeers
|
|
}
|
|
// if clusterIndex is nil, then return all nodes
|
|
var elements []*list.Element
|
|
if clusterIndex == nil {
|
|
elements = make([]*list.Element, 0, len(l.idToNode))
|
|
for _, entries := range l.idToNode {
|
|
elements = append(elements, entries[0])
|
|
}
|
|
} else if entries := l.shardNodes[*clusterIndex]; entries != nil && entries.Len() != 0 {
|
|
elements = make([]*list.Element, 0, entries.Len())
|
|
for ent := entries.Back(); ent != nil; ent = ent.Prev() {
|
|
elements = append(elements, ent)
|
|
}
|
|
}
|
|
utils.Logger().Info(fmt.Sprintf("%d", len(elements)))
|
|
indexes := l.rng.Perm(len(elements))[0:neededPeers]
|
|
for _, ind := range indexes {
|
|
node := elements[ind].Value.(nodeWithShardInfo).node
|
|
nodes = append(nodes, node)
|
|
// this removes the node from all list (all cluster/shard pair that the node has) and adds it to the front
|
|
l.remove(node)
|
|
_ = l.add(node)
|
|
}
|
|
return nodes
|
|
}
|
|
|
|
// if clusterIndex is not nil, return len of nodes maintained for a given shard
|
|
// if clusterIndex is nil, return count of all nodes maintained
|
|
func (l *shardLRU) len(clusterIndex *ShardInfo) int {
|
|
if clusterIndex == nil {
|
|
return len(l.idToNode)
|
|
}
|
|
if entries := l.shardNodes[*clusterIndex]; entries != nil {
|
|
return entries.Len()
|
|
}
|
|
return 0
|
|
}
|
|
|
|
// get the node with the given id, if it is present in cache
|
|
func (l *shardLRU) Get(id enode.ID) *enode.Node {
|
|
l.mu.RLock()
|
|
defer l.mu.RUnlock()
|
|
|
|
if elements, ok := l.idToNode[id]; ok && len(elements) > 0 {
|
|
return elements[0].Value.(nodeWithShardInfo).node
|
|
}
|
|
return nil
|
|
}
|