From dc08c448049afd26801b99e7827358554c20da36 Mon Sep 17 00:00:00 2001 From: Richard Ramos Date: Thu, 27 Jul 2023 14:14:14 -0400 Subject: [PATCH] feat: add filters for discv5 --- waku/v2/discv5/discover.go | 69 +++++++--------------- waku/v2/discv5/filters.go | 52 ++++++++++++++++ waku/v2/protocol/peer_exchange/protocol.go | 2 +- 3 files changed, 75 insertions(+), 48 deletions(-) create mode 100644 waku/v2/discv5/filters.go diff --git a/waku/v2/discv5/discover.go b/waku/v2/discv5/discover.go index d81aa594..79faf8c7 100644 --- a/waku/v2/discv5/discover.go +++ b/waku/v2/discv5/discover.go @@ -18,7 +18,8 @@ import ( v2 "github.com/waku-org/go-waku/waku/v2" "github.com/waku-org/go-waku/waku/v2/metrics" "github.com/waku-org/go-waku/waku/v2/peers" - "github.com/waku-org/go-waku/waku/v2/protocol/enr" + wenr "github.com/waku-org/go-waku/waku/v2/protocol/enr" + "github.com/waku-org/go-waku/waku/v2/utils" "go.uber.org/zap" @@ -276,7 +277,7 @@ func evaluateNode(node *enode.Node) bool { return false }*/ - _, err := enr.EnodeToPeerInfo(node) + _, err := wenr.EnodeToPeerInfo(node) if err != nil { metrics.RecordDiscV5Error(context.Background(), "peer_info_failure") @@ -287,53 +288,29 @@ func evaluateNode(node *enode.Node) bool { return true } -// get random nodes from DHT via discv5 listender -// used for caching enr address in peerExchange -// used for connecting to peers in discovery_connector -func (d *DiscoveryV5) Iterator() (enode.Iterator, error) { +// Predicate is a function that is applied to an iterator to filter the nodes to be retrieved according to some logic +type Predicate func(enode.Iterator) enode.Iterator + +// PeerIterator gets random nodes from DHT via discv5 listener. +// Used for caching enr address in peerExchange +// Used for connecting to peers in discovery_connector +func (d *DiscoveryV5) PeerIterator(predicate ...Predicate) (enode.Iterator, error) { if d.listener == nil { return nil, ErrNoDiscV5Listener } iterator := enode.Filter(d.listener.RandomNodes(), evaluateNode) if d.params.loopPredicate != nil { - return enode.Filter(iterator, d.params.loopPredicate), nil - } else { - return iterator, nil - } -} - -func (d *DiscoveryV5) FindPeersWithPredicate(ctx context.Context, predicate func(*enode.Node) bool) (enode.Iterator, error) { - if d.listener == nil { - return nil, ErrNoDiscV5Listener + iterator = enode.Filter(iterator, d.params.loopPredicate) } - iterator := enode.Filter(d.listener.RandomNodes(), evaluateNode) - if predicate != nil { - iterator = enode.Filter(iterator, predicate) + for _, p := range predicate { + iterator = p(iterator) } return iterator, nil } -func (d *DiscoveryV5) FindPeersWithShard(ctx context.Context, cluster, index uint16) (enode.Iterator, error) { - if d.listener == nil { - return nil, ErrNoDiscV5Listener - } - - iterator := enode.Filter(d.listener.RandomNodes(), evaluateNode) - - predicate := func(node *enode.Node) bool { - rs, err := enr.RelaySharding(node.Record()) - if err != nil || rs == nil { - return false - } - return rs.Contains(cluster, index) - } - - return enode.Filter(iterator, predicate), nil -} - func (d *DiscoveryV5) Iterate(ctx context.Context, iterator enode.Iterator, onNode func(*enode.Node, peer.AddrInfo) error) { defer iterator.Close() @@ -356,7 +333,7 @@ func (d *DiscoveryV5) Iterate(ctx context.Context, iterator enode.Iterator, onNo } } - _, addresses, err := enr.Multiaddress(iterator.Node()) + _, addresses, err := wenr.Multiaddress(iterator.Node()) if err != nil { metrics.RecordDiscV5Error(context.Background(), "peer_info_failure") d.log.Error("extracting multiaddrs from enr", zap.Error(err)) @@ -409,14 +386,8 @@ func delayedHasNext(ctx context.Context, iterator enode.Iterator) bool { // Iterates over the nodes found via discv5 belonging to the node's current shard, and sends them to peerConnector func (d *DiscoveryV5) peerLoop(ctx context.Context) error { - iterator, err := d.Iterator() - if err != nil { - metrics.RecordDiscV5Error(context.Background(), "iterator_failure") - return fmt.Errorf("obtaining iterator: %w", err) - } - - iterator = enode.Filter(iterator, func(n *enode.Node) bool { - localRS, err := enr.RelaySharding(d.localnode.Node().Record()) + iterator, err := d.PeerIterator(FilterPredicate(func(n *enode.Node) bool { + localRS, err := wenr.RelaySharding(d.localnode.Node().Record()) if err != nil { return false } @@ -425,7 +396,7 @@ func (d *DiscoveryV5) peerLoop(ctx context.Context) error { return true } - nodeRS, err := enr.RelaySharding(n.Record()) + nodeRS, err := wenr.RelaySharding(n.Record()) if err != nil || nodeRS == nil { return false } @@ -442,7 +413,11 @@ func (d *DiscoveryV5) peerLoop(ctx context.Context) error { } return false - }) + })) + if err != nil { + metrics.RecordDiscV5Error(context.Background(), "iterator_failure") + return fmt.Errorf("obtaining iterator: %w", err) + } defer iterator.Close() diff --git a/waku/v2/discv5/filters.go b/waku/v2/discv5/filters.go new file mode 100644 index 00000000..9316b568 --- /dev/null +++ b/waku/v2/discv5/filters.go @@ -0,0 +1,52 @@ +package discv5 + +import ( + wenr "github.com/waku-org/go-waku/waku/v2/protocol/enr" + + "github.com/ethereum/go-ethereum/p2p/enode" + "github.com/ethereum/go-ethereum/p2p/enr" +) + +// FilterPredicate is to create a Predicate using a custom function +func FilterPredicate(predicate func(*enode.Node) bool) Predicate { + return func(iterator enode.Iterator) enode.Iterator { + if predicate != nil { + iterator = enode.Filter(iterator, predicate) + } + + return iterator + } +} + +// FilterShard creates a Predicate that filters nodes that belong to a specific shard +func FilterShard(iterator enode.Iterator, cluster, index uint16) Predicate { + return func(iterator enode.Iterator) enode.Iterator { + predicate := func(node *enode.Node) bool { + rs, err := wenr.RelaySharding(node.Record()) + if err != nil || rs == nil { + return false + } + return rs.Contains(cluster, index) + } + return enode.Filter(iterator, predicate) + } +} + +// FilterCapabilities creates a Predicate to filter nodes that support specific protocols +func FilterCapabilities(iterator enode.Iterator, flags wenr.WakuEnrBitfield) Predicate { + return func(iterator enode.Iterator) enode.Iterator { + predicate := func(node *enode.Node) bool { + enrField := new(wenr.WakuEnrBitfield) + if err := node.Record().Load(enr.WithEntry(wenr.WakuENRField, &enrField)); err != nil { + return false + } + + if enrField == nil { + return false + } + + return *enrField&flags == flags + } + return enode.Filter(iterator, predicate) + } +} diff --git a/waku/v2/protocol/peer_exchange/protocol.go b/waku/v2/protocol/peer_exchange/protocol.go index 51f157aa..abb92602 100644 --- a/waku/v2/protocol/peer_exchange/protocol.go +++ b/waku/v2/protocol/peer_exchange/protocol.go @@ -136,7 +136,7 @@ func (wakuPX *WakuPeerExchange) Stop() { } func (wakuPX *WakuPeerExchange) iterate(ctx context.Context) error { - iterator, err := wakuPX.disc.Iterator() + iterator, err := wakuPX.disc.PeerIterator() if err != nil { return fmt.Errorf("obtaining iterator: %w", err) }