feat: find discv5 peers with shards

This commit is contained in:
Richard Ramos 2023-06-20 16:39:20 -04:00 committed by richΛrd
parent 0381b92531
commit eba4aa43e5
2 changed files with 67 additions and 17 deletions

View File

@ -291,7 +291,7 @@ func (d *DiscoveryV5) Iterator() (enode.Iterator, error) {
}
}
func (d *DiscoveryV5) FindPeers(ctx context.Context, predicate func(*enode.Node) bool) (enode.Iterator, error) {
func (d *DiscoveryV5) FindPeersWithPredicate(ctx context.Context, predicate func(*enode.Node) bool) (enode.Iterator, error) {
if d.listener == nil {
return nil, ErrNoDiscV5Listener
}
@ -304,6 +304,24 @@ func (d *DiscoveryV5) FindPeers(ctx context.Context, predicate func(*enode.Node)
return iterator, nil
}
func (d *DiscoveryV5) FindPeersWithShard(ctx context.Context, cluster, index uint16) (enode.Iterator, error) {
if d.listener == nil {
return nil, ErrNoDiscV5Listener
}
iterator := enode.Filter(d.listener.RandomNodes(), evaluateNode)
predicate := func(node *enode.Node) bool {
rs, err := enr.RelaySharding(node.Record())
if err != nil || rs == nil {
return false
}
return rs.Contains(cluster, index)
}
return enode.Filter(iterator, predicate), nil
}
func (d *DiscoveryV5) Iterate(ctx context.Context, iterator enode.Iterator, onNode func(*enode.Node, peer.AddrInfo) error) {
defer iterator.Close()
@ -337,7 +355,7 @@ func (d *DiscoveryV5) Iterate(ctx context.Context, iterator enode.Iterator, onNo
}
}
// Iterates over the nodes found via discv5, and sends them to peerConnector
// Iterates over the nodes found via discv5 belonging to the node's current shard, and sends them to peerConnector
func (d *DiscoveryV5) peerLoop(ctx context.Context) error {
iterator, err := d.Iterator()
if err != nil {
@ -345,6 +363,32 @@ func (d *DiscoveryV5) peerLoop(ctx context.Context) error {
return fmt.Errorf("obtaining iterator: %w", err)
}
iterator = enode.Filter(iterator, func(n *enode.Node) bool {
// TODO: might make sense to extract the next line outside of the iterator
localRS, err := enr.RelaySharding(d.localnode.Node().Record())
if err != nil || localRS == nil {
return false
}
nodeRS, err := enr.RelaySharding(d.localnode.Node().Record())
if err != nil || nodeRS == nil {
return false
}
if nodeRS.Cluster != localRS.Cluster {
return false
}
// Contains any
for _, idx := range nodeRS.Indices {
if nodeRS.Contains(localRS.Cluster, idx) {
return true
}
}
return false
})
defer iterator.Close()
d.Iterate(ctx, iterator, func(n *enode.Node, p peer.AddrInfo) error {

View File

@ -53,9 +53,9 @@ func WithWakuRelayShardingTopics(topics ...string) ENROption {
// ENR record accessors
func RelayShardingIndicesList(localnode *enode.LocalNode) (*protocol.RelayShards, error) {
func RelayShardingIndicesList(record *enr.Record) (*protocol.RelayShards, error) {
var field []byte
if err := localnode.Node().Record().Load(enr.WithEntry(ShardingIndicesListEnrField, field)); err != nil {
if err := record.Load(enr.WithEntry(ShardingIndicesListEnrField, field)); err != nil {
return nil, nil
}
@ -67,10 +67,13 @@ func RelayShardingIndicesList(localnode *enode.LocalNode) (*protocol.RelayShards
return &res, nil
}
func RelayShardingBitVector(localnode *enode.LocalNode) (*protocol.RelayShards, error) {
func RelayShardingBitVector(record *enr.Record) (*protocol.RelayShards, error) {
var field []byte
if err := localnode.Node().Record().Load(enr.WithEntry(ShardingBitVectorEnrField, field)); err != nil {
return nil, nil
if err := record.Load(enr.WithEntry(ShardingBitVectorEnrField, field)); err != nil {
if enr.IsNotFound(err) {
return nil, nil
}
return nil, err
}
res, err := protocol.FromBitVector(field)
@ -81,8 +84,8 @@ func RelayShardingBitVector(localnode *enode.LocalNode) (*protocol.RelayShards,
return &res, nil
}
func RelaySharding(localnode *enode.LocalNode) (*protocol.RelayShards, error) {
res, err := RelayShardingIndicesList(localnode)
func RelaySharding(record *enr.Record) (*protocol.RelayShards, error) {
res, err := RelayShardingIndicesList(record)
if err != nil {
return nil, err
}
@ -91,17 +94,17 @@ func RelaySharding(localnode *enode.LocalNode) (*protocol.RelayShards, error) {
return res, nil
}
return RelayShardingBitVector(localnode)
return RelayShardingBitVector(record)
}
// Utils
func ContainsShard(localnode *enode.LocalNode, cluster uint16, index uint16) bool {
func ContainsShard(record *enr.Record, cluster uint16, index uint16) bool {
if index > protocol.MaxShardIndex {
return false
}
rs, err := RelaySharding(localnode)
rs, err := RelaySharding(record)
if err != nil {
return false
}
@ -109,19 +112,22 @@ func ContainsShard(localnode *enode.LocalNode, cluster uint16, index uint16) boo
return rs.Contains(cluster, index)
}
func ContainsShardWithNsTopic(localnode *enode.LocalNode, topic protocol.NamespacedPubsubTopic) bool {
func ContainsShardWithNsTopic(record *enr.Record, topic protocol.NamespacedPubsubTopic) bool {
if topic.Kind() != protocol.StaticSharding {
return false
}
shardTopic := topic.(protocol.StaticShardingPubsubTopic)
return ContainsShard(localnode, shardTopic.Cluster(), shardTopic.Shard())
return ContainsShard(record, shardTopic.Cluster(), shardTopic.Shard())
}
func ContainsShardTopic(localnode *enode.LocalNode, topic string) bool {
func ContainsRelayShard(record *enr.Record, topic protocol.StaticShardingPubsubTopic) bool {
return ContainsShardWithNsTopic(record, topic)
}
func ContainsShardTopic(record *enr.Record, topic string) bool {
shardTopic, err := protocol.ToShardedPubsubTopic(topic)
if err != nil {
return false
}
return ContainsShardWithNsTopic(localnode, shardTopic)
return ContainsShardWithNsTopic(record, shardTopic)
}