From eba4aa43e5cb82ac6a6d315f615b3514554ad469 Mon Sep 17 00:00:00 2001 From: Richard Ramos Date: Tue, 20 Jun 2023 16:39:20 -0400 Subject: [PATCH] feat: find discv5 peers with shards --- waku/v2/discv5/discover.go | 48 ++++++++++++++++++++++++++++++++-- waku/v2/protocol/enr/shards.go | 36 ++++++++++++++----------- 2 files changed, 67 insertions(+), 17 deletions(-) diff --git a/waku/v2/discv5/discover.go b/waku/v2/discv5/discover.go index 1ebda7e9..bd397af5 100644 --- a/waku/v2/discv5/discover.go +++ b/waku/v2/discv5/discover.go @@ -291,7 +291,7 @@ func (d *DiscoveryV5) Iterator() (enode.Iterator, error) { } } -func (d *DiscoveryV5) FindPeers(ctx context.Context, predicate func(*enode.Node) bool) (enode.Iterator, error) { +func (d *DiscoveryV5) FindPeersWithPredicate(ctx context.Context, predicate func(*enode.Node) bool) (enode.Iterator, error) { if d.listener == nil { return nil, ErrNoDiscV5Listener } @@ -304,6 +304,24 @@ func (d *DiscoveryV5) FindPeers(ctx context.Context, predicate func(*enode.Node) return iterator, nil } +func (d *DiscoveryV5) FindPeersWithShard(ctx context.Context, cluster, index uint16) (enode.Iterator, error) { + if d.listener == nil { + return nil, ErrNoDiscV5Listener + } + + iterator := enode.Filter(d.listener.RandomNodes(), evaluateNode) + + predicate := func(node *enode.Node) bool { + rs, err := enr.RelaySharding(node.Record()) + if err != nil || rs == nil { + return false + } + return rs.Contains(cluster, index) + } + + return enode.Filter(iterator, predicate), nil +} + func (d *DiscoveryV5) Iterate(ctx context.Context, iterator enode.Iterator, onNode func(*enode.Node, peer.AddrInfo) error) { defer iterator.Close() @@ -337,7 +355,7 @@ func (d *DiscoveryV5) Iterate(ctx context.Context, iterator enode.Iterator, onNo } } -// Iterates over the nodes found via discv5, and sends them to peerConnector +// Iterates over the nodes found via discv5 belonging to the node's current shard, and sends them to peerConnector func (d *DiscoveryV5) peerLoop(ctx context.Context) error { iterator, err := d.Iterator() if err != nil { @@ -345,6 +363,32 @@ func (d *DiscoveryV5) peerLoop(ctx context.Context) error { return fmt.Errorf("obtaining iterator: %w", err) } + iterator = enode.Filter(iterator, func(n *enode.Node) bool { + // TODO: might make sense to extract the next line outside of the iterator + localRS, err := enr.RelaySharding(d.localnode.Node().Record()) + if err != nil || localRS == nil { + return false + } + + nodeRS, err := enr.RelaySharding(d.localnode.Node().Record()) + if err != nil || nodeRS == nil { + return false + } + + if nodeRS.Cluster != localRS.Cluster { + return false + } + + // Contains any + for _, idx := range nodeRS.Indices { + if nodeRS.Contains(localRS.Cluster, idx) { + return true + } + } + + return false + }) + defer iterator.Close() d.Iterate(ctx, iterator, func(n *enode.Node, p peer.AddrInfo) error { diff --git a/waku/v2/protocol/enr/shards.go b/waku/v2/protocol/enr/shards.go index d5f47b64..ede1fb97 100644 --- a/waku/v2/protocol/enr/shards.go +++ b/waku/v2/protocol/enr/shards.go @@ -53,9 +53,9 @@ func WithWakuRelayShardingTopics(topics ...string) ENROption { // ENR record accessors -func RelayShardingIndicesList(localnode *enode.LocalNode) (*protocol.RelayShards, error) { +func RelayShardingIndicesList(record *enr.Record) (*protocol.RelayShards, error) { var field []byte - if err := localnode.Node().Record().Load(enr.WithEntry(ShardingIndicesListEnrField, field)); err != nil { + if err := record.Load(enr.WithEntry(ShardingIndicesListEnrField, field)); err != nil { return nil, nil } @@ -67,10 +67,13 @@ func RelayShardingIndicesList(localnode *enode.LocalNode) (*protocol.RelayShards return &res, nil } -func RelayShardingBitVector(localnode *enode.LocalNode) (*protocol.RelayShards, error) { +func RelayShardingBitVector(record *enr.Record) (*protocol.RelayShards, error) { var field []byte - if err := localnode.Node().Record().Load(enr.WithEntry(ShardingBitVectorEnrField, field)); err != nil { - return nil, nil + if err := record.Load(enr.WithEntry(ShardingBitVectorEnrField, field)); err != nil { + if enr.IsNotFound(err) { + return nil, nil + } + return nil, err } res, err := protocol.FromBitVector(field) @@ -81,8 +84,8 @@ func RelayShardingBitVector(localnode *enode.LocalNode) (*protocol.RelayShards, return &res, nil } -func RelaySharding(localnode *enode.LocalNode) (*protocol.RelayShards, error) { - res, err := RelayShardingIndicesList(localnode) +func RelaySharding(record *enr.Record) (*protocol.RelayShards, error) { + res, err := RelayShardingIndicesList(record) if err != nil { return nil, err } @@ -91,17 +94,17 @@ func RelaySharding(localnode *enode.LocalNode) (*protocol.RelayShards, error) { return res, nil } - return RelayShardingBitVector(localnode) + return RelayShardingBitVector(record) } // Utils -func ContainsShard(localnode *enode.LocalNode, cluster uint16, index uint16) bool { +func ContainsShard(record *enr.Record, cluster uint16, index uint16) bool { if index > protocol.MaxShardIndex { return false } - rs, err := RelaySharding(localnode) + rs, err := RelaySharding(record) if err != nil { return false } @@ -109,19 +112,22 @@ func ContainsShard(localnode *enode.LocalNode, cluster uint16, index uint16) boo return rs.Contains(cluster, index) } -func ContainsShardWithNsTopic(localnode *enode.LocalNode, topic protocol.NamespacedPubsubTopic) bool { +func ContainsShardWithNsTopic(record *enr.Record, topic protocol.NamespacedPubsubTopic) bool { if topic.Kind() != protocol.StaticSharding { return false } shardTopic := topic.(protocol.StaticShardingPubsubTopic) - return ContainsShard(localnode, shardTopic.Cluster(), shardTopic.Shard()) - + return ContainsShard(record, shardTopic.Cluster(), shardTopic.Shard()) } -func ContainsShardTopic(localnode *enode.LocalNode, topic string) bool { +func ContainsRelayShard(record *enr.Record, topic protocol.StaticShardingPubsubTopic) bool { + return ContainsShardWithNsTopic(record, topic) +} + +func ContainsShardTopic(record *enr.Record, topic string) bool { shardTopic, err := protocol.ToShardedPubsubTopic(topic) if err != nil { return false } - return ContainsShardWithNsTopic(localnode, shardTopic) + return ContainsShardWithNsTopic(record, shardTopic) }