diff --git a/waku/v2/discv5/discover.go b/waku/v2/discv5/discover.go index 4188814e..1ebda7e9 100644 --- a/waku/v2/discv5/discover.go +++ b/waku/v2/discv5/discover.go @@ -47,9 +47,11 @@ type DiscoveryV5 struct { type discV5Parameters struct { autoUpdate bool + autoFindPeers bool bootnodes []*enode.Node udpPort uint advertiseAddr []multiaddr.Multiaddr + loopPredicate func(*enode.Node) bool } type DiscoveryV5Option func(*discV5Parameters) @@ -80,9 +82,22 @@ func WithUDPPort(port uint) DiscoveryV5Option { } } +func WithPredicate(predicate func(*enode.Node) bool) DiscoveryV5Option { + return func(params *discV5Parameters) { + params.loopPredicate = predicate + } +} + +func WithAutoFindPeers(find bool) DiscoveryV5Option { + return func(params *discV5Parameters) { + params.autoFindPeers = find + } +} + func DefaultOptions() []DiscoveryV5Option { return []DiscoveryV5Option{ WithUDPPort(9000), + WithAutoFindPeers(true), } } @@ -185,11 +200,13 @@ func (d *DiscoveryV5) Start(ctx context.Context) error { return err } - d.wg.Add(1) - go func() { - defer d.wg.Done() - d.runDiscoveryV5Loop(ctx) - }() + if d.params.autoFindPeers { + d.wg.Add(1) + go func() { + defer d.wg.Done() + d.runDiscoveryV5Loop(ctx) + }() + } return nil } @@ -266,18 +283,28 @@ func (d *DiscoveryV5) Iterator() (enode.Iterator, error) { return nil, ErrNoDiscV5Listener } - iterator := d.listener.RandomNodes() - return enode.Filter(iterator, evaluateNode), nil + iterator := enode.Filter(d.listener.RandomNodes(), evaluateNode) + if d.params.loopPredicate != nil { + return enode.Filter(iterator, d.params.loopPredicate), nil + } else { + return iterator, nil + } } -// iterate over all fecthed peer addresses and send them to peerConnector -func (d *DiscoveryV5) iterate(ctx context.Context) error { - iterator, err := d.Iterator() - if err != nil { - metrics.RecordDiscV5Error(context.Background(), "iterator_failure") - return fmt.Errorf("obtaining iterator: %w", err) +func (d *DiscoveryV5) FindPeers(ctx context.Context, predicate func(*enode.Node) bool) (enode.Iterator, error) { + if d.listener == nil { + return nil, ErrNoDiscV5Listener } + iterator := enode.Filter(d.listener.RandomNodes(), evaluateNode) + if predicate != nil { + iterator = enode.Filter(iterator, predicate) + } + + return iterator, nil +} + +func (d *DiscoveryV5) Iterate(ctx context.Context, iterator enode.Iterator, onNode func(*enode.Node, peer.AddrInfo) error) { defer iterator.Close() for iterator.Next() { // while next exists, run for loop @@ -296,25 +323,45 @@ func (d *DiscoveryV5) iterate(ctx context.Context) error { } if len(peerAddrs) != 0 { - peer := v2.PeerData{ - Origin: peers.Discv5, - AddrInfo: peerAddrs[0], - ENR: iterator.Node(), - } - - select { - case d.peerConnector.PeerChannel() <- peer: - case <-ctx.Done(): - return nil + err := onNode(iterator.Node(), peerAddrs[0]) + if err != nil { + d.log.Error("processing node", zap.Error(err)) } } select { case <-ctx.Done(): - return nil + return default: } } +} + +// Iterates over the nodes found via discv5, and sends them to peerConnector +func (d *DiscoveryV5) peerLoop(ctx context.Context) error { + iterator, err := d.Iterator() + if err != nil { + metrics.RecordDiscV5Error(context.Background(), "iterator_failure") + return fmt.Errorf("obtaining iterator: %w", err) + } + + defer iterator.Close() + + d.Iterate(ctx, iterator, func(n *enode.Node, p peer.AddrInfo) error { + peer := v2.PeerData{ + Origin: peers.Discv5, + AddrInfo: p, + ENR: n, + } + + select { + case d.peerConnector.PeerChannel() <- peer: + case <-ctx.Done(): + return nil + } + + return nil + }) return nil } @@ -323,7 +370,7 @@ func (d *DiscoveryV5) runDiscoveryV5Loop(ctx context.Context) { restartLoop: for { - err := d.iterate(ctx) + err := d.peerLoop(ctx) if err != nil { d.log.Debug("iterating discv5", zap.Error(err)) }