diff --git a/waku/v2/discv5/discover.go b/waku/v2/discv5/discover.go index f862e7f5..493c502b 100644 --- a/waku/v2/discv5/discover.go +++ b/waku/v2/discv5/discover.go @@ -3,6 +3,7 @@ package discv5 import ( "context" "crypto/ecdsa" + "errors" "math/rand" "net" "sync" @@ -26,6 +27,7 @@ type DiscoveryV5 struct { discovery.Discovery params *discV5Parameters + ctx context.Context host host.Host config discover.Config udpAddr *net.UDPAddr @@ -39,7 +41,9 @@ type DiscoveryV5 struct { wg *sync.WaitGroup - peerCache peerCache + peerCache peerCache + discoverCtx context.Context + discoverCancelFunc context.CancelFunc } type peerCache struct { @@ -95,7 +99,9 @@ func DefaultOptions() []DiscoveryV5Option { } } -func NewDiscoveryV5(host host.Host, priv *ecdsa.PrivateKey, localnode *enode.LocalNode, log *zap.Logger, opts ...DiscoveryV5Option) (*DiscoveryV5, error) { +const MaxPeersToDiscover = 600 + +func NewDiscoveryV5(ctx context.Context, host host.Host, priv *ecdsa.PrivateKey, localnode *enode.LocalNode, log *zap.Logger, opts ...DiscoveryV5Option) (*DiscoveryV5, error) { params := new(discV5Parameters) optList := DefaultOptions() optList = append(optList, opts...) @@ -111,6 +117,7 @@ func NewDiscoveryV5(host host.Host, priv *ecdsa.PrivateKey, localnode *enode.Loc } return &DiscoveryV5{ + ctx: ctx, host: host, params: params, NAT: NAT, @@ -194,6 +201,10 @@ func (d *DiscoveryV5) Start() error { return err } + // create cancellable + d.discoverCtx, d.discoverCancelFunc = context.WithCancel(d.ctx) + go d.runDiscoveryV5Loop() + return nil } @@ -206,6 +217,7 @@ func (d *DiscoveryV5) Stop() { } close(d.quit) + d.discoverCancelFunc() d.listener.Close() d.listener = nil @@ -279,7 +291,7 @@ func (d *DiscoveryV5) Advertise(ctx context.Context, ns string, opts ...discover return 20 * time.Minute, nil } -func (d *DiscoveryV5) iterate(ctx context.Context, iterator enode.Iterator, limit int, doneCh chan struct{}) { +func (d *DiscoveryV5) iterate(iterator enode.Iterator, limit int, doneCh chan struct{}) { defer d.wg.Done() for { @@ -287,7 +299,7 @@ func (d *DiscoveryV5) iterate(ctx context.Context, iterator enode.Iterator, limi break } - if ctx.Err() != nil { + if d.discoverCtx.Err() != nil { break } @@ -308,6 +320,7 @@ func (d *DiscoveryV5) iterate(ctx context.Context, iterator enode.Iterator, limi continue } + d.peerCache.Lock() for _, p := range peerAddrs { d.peerCache.recs[p.ID] = PeerRecord{ expire: time.Now().Unix() + 3600, // Expires in 1hr @@ -315,7 +328,7 @@ func (d *DiscoveryV5) iterate(ctx context.Context, iterator enode.Iterator, limi Node: *iterator.Node(), } } - + d.peerCache.Unlock() } close(doneCh) @@ -337,6 +350,25 @@ func (d *DiscoveryV5) removeExpiredPeers() int { return newCacheSize } +func (d *DiscoveryV5) runDiscoveryV5Loop() { + iterator := d.listener.RandomNodes() + iterator = enode.Filter(iterator, evaluateNode) + defer iterator.Close() + + doneCh := make(chan struct{}) + + d.wg.Add(1) + + go d.iterate(iterator, MaxPeersToDiscover, doneCh) + + select { + case <-d.discoverCtx.Done(): + case <-doneCh: + } + + d.log.Warn("Discv5 loop stopped") +} + func (d *DiscoveryV5) FindNodes(ctx context.Context, topic string, opts ...discovery.Option) ([]PeerRecord, error) { // Get options var options discovery.Options @@ -345,10 +377,13 @@ func (d *DiscoveryV5) FindNodes(ctx context.Context, topic string, opts ...disco return nil, err } - const maxLimit = 600 limit := options.Limit - if limit == 0 || limit > maxLimit { - limit = maxLimit + if limit == 0 || limit > MaxPeersToDiscover { + limit = MaxPeersToDiscover + } + + if limit > MaxPeersToDiscover { + return nil, errors.New("limit should be less than allowed maximum") } // We are ignoring the topic. Future versions might use a map[string]*peerCache instead where the string represents the pubsub topic @@ -356,28 +391,7 @@ func (d *DiscoveryV5) FindNodes(ctx context.Context, topic string, opts ...disco d.peerCache.Lock() defer d.peerCache.Unlock() - cacheSize := d.removeExpiredPeers() - - // Discover new records if we don't have enough - if cacheSize < limit && d.listener != nil { - d.Lock() - - iterator := d.listener.RandomNodes() - iterator = enode.Filter(iterator, evaluateNode) - defer iterator.Close() - - doneCh := make(chan struct{}) - - d.wg.Add(1) - go d.iterate(ctx, iterator, limit, doneCh) - - select { - case <-ctx.Done(): - case <-doneCh: - } - - d.Unlock() - } + d.removeExpiredPeers() // Randomize and fill channel with available records count := len(d.peerCache.recs) diff --git a/waku/v2/discv5/discover_test.go b/waku/v2/discv5/discover_test.go index cf20858b..d707da2a 100644 --- a/waku/v2/discv5/discover_test.go +++ b/waku/v2/discv5/discover_test.go @@ -106,7 +106,7 @@ func TestDiscV5(t *testing.T) { ip1, _ := extractIP(host1.Addrs()[0]) l1, err := newLocalnode(prvKey1, ip1, udpPort1, utils.NewWakuEnrBitfield(true, true, true, true), nil, utils.Logger()) require.NoError(t, err) - d1, err := NewDiscoveryV5(host1, prvKey1, l1, utils.Logger(), WithUDPPort(udpPort1)) + d1, err := NewDiscoveryV5(context.Background(), host1, prvKey1, l1, utils.Logger(), WithUDPPort(udpPort1)) require.NoError(t, err) // H2 @@ -116,7 +116,7 @@ func TestDiscV5(t *testing.T) { require.NoError(t, err) l2, err := newLocalnode(prvKey2, ip2, udpPort2, utils.NewWakuEnrBitfield(true, true, true, true), nil, utils.Logger()) require.NoError(t, err) - d2, err := NewDiscoveryV5(host2, prvKey2, l2, utils.Logger(), WithUDPPort(udpPort2), WithBootnodes([]*enode.Node{d1.localnode.Node()})) + d2, err := NewDiscoveryV5(context.Background(), host2, prvKey2, l2, utils.Logger(), WithUDPPort(udpPort2), WithBootnodes([]*enode.Node{d1.localnode.Node()})) require.NoError(t, err) // H3 @@ -126,7 +126,7 @@ func TestDiscV5(t *testing.T) { require.NoError(t, err) l3, err := newLocalnode(prvKey3, ip3, udpPort3, utils.NewWakuEnrBitfield(true, true, true, true), nil, utils.Logger()) require.NoError(t, err) - d3, err := NewDiscoveryV5(host3, prvKey3, l3, utils.Logger(), WithUDPPort(udpPort3), WithBootnodes([]*enode.Node{d2.localnode.Node()})) + d3, err := NewDiscoveryV5(context.Background(), host3, prvKey3, l3, utils.Logger(), WithUDPPort(udpPort3), WithBootnodes([]*enode.Node{d2.localnode.Node()})) require.NoError(t, err) defer d1.Stop() @@ -142,6 +142,8 @@ func TestDiscV5(t *testing.T) { err = d3.Start() require.NoError(t, err) + time.Sleep(3 * time.Second) // Wait for nodes to be discovered + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) defer cancel() @@ -206,6 +208,8 @@ func TestDiscV5(t *testing.T) { err = d3.Start() require.NoError(t, err) + time.Sleep(3 * time.Second) // Wait for nodes to be discovered + foundHost1 = false foundHost2 = false diff --git a/waku/v2/node/wakunode2.go b/waku/v2/node/wakunode2.go index d3691c81..d958dd59 100644 --- a/waku/v2/node/wakunode2.go +++ b/waku/v2/node/wakunode2.go @@ -489,7 +489,7 @@ func (w *WakuNode) mountDiscV5() error { } var err error - w.discoveryV5, err = discv5.NewDiscoveryV5(w.Host(), w.opts.privKey, w.localNode, w.log, discV5Options...) + w.discoveryV5, err = discv5.NewDiscoveryV5(w.ctx, w.Host(), w.opts.privKey, w.localNode, w.log, discV5Options...) return err } diff --git a/waku/v2/protocol/peer_exchange/waku_peer_exchange_test.go b/waku/v2/protocol/peer_exchange/waku_peer_exchange_test.go index fbb257bf..e5a5be26 100644 --- a/waku/v2/protocol/peer_exchange/waku_peer_exchange_test.go +++ b/waku/v2/protocol/peer_exchange/waku_peer_exchange_test.go @@ -105,7 +105,7 @@ func TestRetrieveProvidePeerExchangePeers(t *testing.T) { ip1, _ := extractIP(host1.Addrs()[0]) l1, err := newLocalnode(prvKey1, ip1, udpPort1, utils.NewWakuEnrBitfield(false, false, false, true), nil, utils.Logger()) require.NoError(t, err) - d1, err := discv5.NewDiscoveryV5(host1, prvKey1, l1, utils.Logger(), discv5.WithUDPPort(udpPort1)) + d1, err := discv5.NewDiscoveryV5(context.Background(), host1, prvKey1, l1, utils.Logger(), discv5.WithUDPPort(udpPort1)) require.NoError(t, err) // H2 @@ -115,7 +115,7 @@ func TestRetrieveProvidePeerExchangePeers(t *testing.T) { require.NoError(t, err) l2, err := newLocalnode(prvKey2, ip2, udpPort2, utils.NewWakuEnrBitfield(false, false, false, true), nil, utils.Logger()) require.NoError(t, err) - d2, err := discv5.NewDiscoveryV5(host2, prvKey2, l2, utils.Logger(), discv5.WithUDPPort(udpPort2), discv5.WithBootnodes([]*enode.Node{d1.Node()})) + d2, err := discv5.NewDiscoveryV5(context.Background(), host2, prvKey2, l2, utils.Logger(), discv5.WithUDPPort(udpPort2), discv5.WithBootnodes([]*enode.Node{d1.Node()})) require.NoError(t, err) // H3 @@ -133,6 +133,8 @@ func TestRetrieveProvidePeerExchangePeers(t *testing.T) { err = d2.Start() require.NoError(t, err) + time.Sleep(3 * time.Second) // Wait some time for peers to be discovered + // mount peer exchange px1 := NewWakuPeerExchange(context.Background(), host1, d1, utils.Logger()) px3 := NewWakuPeerExchange(context.Background(), host3, nil, utils.Logger()) @@ -143,8 +145,6 @@ func TestRetrieveProvidePeerExchangePeers(t *testing.T) { err = px3.Start() require.NoError(t, err) - time.Sleep(3 * time.Second) // Give the algorithm some time to work its magic - host3.Peerstore().AddAddrs(host1.ID(), host1.Addrs(), peerstore.PermanentAddrTTL) err = host3.Peerstore().AddProtocols(host1.ID(), string(PeerExchangeID_v20alpha1)) require.NoError(t, err)