diff --git a/waku/v2/discv5/discover.go b/waku/v2/discv5/discover.go index f21267eb..239bb47b 100644 --- a/waku/v2/discv5/discover.go +++ b/waku/v2/discv5/discover.go @@ -7,6 +7,7 @@ import ( "fmt" "net" "sync" + "sync/atomic" "time" "github.com/libp2p/go-libp2p/core/host" @@ -26,8 +27,6 @@ import ( var ErrNoDiscV5Listener = errors.New("no discv5 listener") type DiscoveryV5 struct { - sync.RWMutex - params *discV5Parameters host host.Host config discover.Config @@ -39,7 +38,7 @@ type DiscoveryV5 struct { log *zap.Logger - started bool + started int32 cancel context.CancelFunc wg *sync.WaitGroup } @@ -136,6 +135,7 @@ func (d *DiscoveryV5) listen(ctx context.Context) error { } d.udpAddr = conn.LocalAddr().(*net.UDPAddr) + if d.NAT != nil && !d.udpAddr.IP.IsLoopback() { d.wg.Add(1) go func() { @@ -167,15 +167,16 @@ func (d *DiscoveryV5) SetHost(h host.Host) { d.host = h } +// only works if the discovery v5 hasn't been started yet. func (d *DiscoveryV5) Start(ctx context.Context) error { - d.Lock() - defer d.Unlock() + // compare and swap sets the discovery v5 to `started` state + // and prevents multiple calls to the start method by being atomic. 
+ if !atomic.CompareAndSwapInt32(&d.started, 0, 1) { + return nil + } - d.wg.Wait() // Waiting for any go routines to stop ctx, cancel := context.WithCancel(ctx) - d.cancel = cancel - d.started = true err := d.listen(ctx) if err != nil { @@ -183,7 +184,10 @@ func (d *DiscoveryV5) Start(ctx context.Context) error { } d.wg.Add(1) - go d.runDiscoveryV5Loop(ctx) + go func() { + defer d.wg.Done() + d.runDiscoveryV5Loop(ctx) + }() return nil } @@ -196,16 +200,13 @@ func (d *DiscoveryV5) SetBootnodes(nodes []*enode.Node) error { return d.listener.SetFallbackNodes(nodes) } +// only works if the discovery v5 is in running state +// so we can assume that cancel method is set func (d *DiscoveryV5) Stop() { - d.Lock() - defer d.Unlock() - - if d.cancel == nil { + if !atomic.CompareAndSwapInt32(&d.started, 1, 0) { // if Discoveryv5 is running, set started to 0 return } - d.cancel() - d.started = false if d.listener != nil { d.listener.Close() @@ -267,6 +268,7 @@ func (d *DiscoveryV5) Iterator() (enode.Iterator, error) { return enode.Filter(iterator, evaluateNode), nil } +// iterate over all fetched peer addresses and send them to peerConnector func (d *DiscoveryV5) iterate(ctx context.Context) error { iterator, err := d.Iterator() if err != nil { @@ -274,31 +276,9 @@ func (d *DiscoveryV5) iterate(ctx context.Context) error { return fmt.Errorf("obtaining iterator: %w", err) } - closeCh := make(chan struct{}, 1) - defer close(closeCh) - - // Closing iterator when context is cancelled or function is returning - d.wg.Add(1) - go func() { - defer d.wg.Done() - select { - case <-ctx.Done(): - iterator.Close() - case <-closeCh: - iterator.Close() - } - }() - - for { - if ctx.Err() != nil { - break - } - - exists := iterator.Next() - if !exists { - break - } + defer iterator.Close() + for iterator.Next() { // while next exists, run for loop _, addresses, err := enr.Multiaddress(iterator.Node()) if err != nil { metrics.RecordDiscV5Error(context.Background(), "peer_info_failure") @@ 
-314,11 +294,12 @@ func (d *DiscoveryV5) iterate(ctx context.Context) error { } if len(peerAddrs) != 0 { - select { - case <-ctx.Done(): - return nil - case d.peerConnector.PeerChannel() <- peerAddrs[0]: - } + d.peerConnector.PeerChannel() <- peerAddrs[0] + } + select { + case <-ctx.Done(): + return nil + default: } } @@ -326,32 +307,23 @@ func (d *DiscoveryV5) iterate(ctx context.Context) error { } func (d *DiscoveryV5) runDiscoveryV5Loop(ctx context.Context) { - defer d.wg.Done() - - ch := make(chan struct{}, 1) - ch <- struct{}{} // Initial execution restartLoop: for { + err := d.iterate(ctx) + if err != nil { + d.log.Debug("iterating discv5", zap.Error(err)) + time.Sleep(2 * time.Second) + } select { - case <-ch: - err := d.iterate(ctx) - if err != nil { - d.log.Debug("iterating discv5", zap.Error(err)) - time.Sleep(2 * time.Second) - } - ch <- struct{}{} case <-ctx.Done(): - close(ch) break restartLoop + default: } } d.log.Warn("Discv5 loop stopped") } func (d *DiscoveryV5) IsStarted() bool { - d.RLock() - defer d.RUnlock() - - return d.started + return atomic.LoadInt32(&d.started) == 1 }