refactor(discv5): lock-free via atomics

harsh-98 2023-05-04 11:09:51 +05:30 committed by RichΛrd
parent 46500b0de9
commit e391fe6a2f
1 changed file with 32 additions and 60 deletions
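This commit swaps the struct's RWMutex-plus-bool pair for a single int32 flag driven by sync/atomic compare-and-swap. Below is a minimal, self-contained sketch of that lifecycle pattern, using a hypothetical service type rather than the real DiscoveryV5:

package main

import (
	"context"
	"fmt"
	"sync/atomic"
)

type service struct {
	started int32 // 0 = stopped, 1 = started
	cancel  context.CancelFunc
}

// Start is a no-op unless it wins the 0 -> 1 swap, so concurrent
// callers cannot start the service twice.
func (s *service) Start(ctx context.Context) error {
	if !atomic.CompareAndSwapInt32(&s.started, 0, 1) {
		return nil
	}
	ctx, cancel := context.WithCancel(ctx)
	s.cancel = cancel
	_ = ctx // spawn worker goroutines with ctx here
	return nil
}

// Stop only proceeds if it wins the 1 -> 0 swap, so cancel is
// guaranteed to have been set by the matching Start.
func (s *service) Stop() {
	if !atomic.CompareAndSwapInt32(&s.started, 1, 0) {
		return
	}
	s.cancel()
}

func main() {
	s := &service{}
	_ = s.Start(context.Background())
	_ = s.Start(context.Background()) // second call returns immediately
	s.Stop()
	s.Stop() // second call returns immediately
	fmt.Println("done")
}

A successful 1 -> 0 swap in Stop can only happen after a successful 0 -> 1 swap in Start, which is why the commit's comment below notes that the cancel function can be assumed to be set.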


@@ -7,6 +7,7 @@ import (
 	"fmt"
 	"net"
 	"sync"
+	"sync/atomic"
 	"time"

 	"github.com/libp2p/go-libp2p/core/host"
@@ -26,8 +27,6 @@ import (
 var ErrNoDiscV5Listener = errors.New("no discv5 listener")

 type DiscoveryV5 struct {
-	sync.RWMutex
-
 	params *discV5Parameters
 	host   host.Host
 	config discover.Config
@@ -39,7 +38,7 @@ type DiscoveryV5 struct {
 	log *zap.Logger

-	started bool
+	started int32
 	cancel  context.CancelFunc
 	wg      *sync.WaitGroup
 }
@@ -136,6 +135,7 @@ func (d *DiscoveryV5) listen(ctx context.Context) error {
 	}
 	d.udpAddr = conn.LocalAddr().(*net.UDPAddr)

 	if d.NAT != nil && !d.udpAddr.IP.IsLoopback() {
+		d.wg.Add(1)
 		go func() {
@@ -167,15 +167,16 @@ func (d *DiscoveryV5) SetHost(h host.Host) {
 	d.host = h
 }

+// only works if the discovery v5 hasn't been started yet.
 func (d *DiscoveryV5) Start(ctx context.Context) error {
-	d.Lock()
-	defer d.Unlock()
+	// compare and swap sets the discovery v5 to `started` state
+	// and prevents multiple calls to the start method by being atomic.
+	if !atomic.CompareAndSwapInt32(&d.started, 0, 1) {
+		return nil
+	}
+	d.wg.Wait() // Waiting for any go routines to stop

 	ctx, cancel := context.WithCancel(ctx)
 	d.cancel = cancel
-	d.started = true

 	err := d.listen(ctx)
 	if err != nil {
@@ -183,7 +184,10 @@ func (d *DiscoveryV5) Start(ctx context.Context) error {
 	}

 	d.wg.Add(1)
-	go d.runDiscoveryV5Loop(ctx)
+	go func() {
+		defer d.wg.Done()
+		d.runDiscoveryV5Loop(ctx)
+	}()

 	return nil
 }
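The loop's WaitGroup bookkeeping also moves out of runDiscoveryV5Loop and into this wrapper: Add(1) runs before the goroutine is spawned and Done is deferred inside it. A small sketch of why that ordering matters, with a hypothetical run function standing in for the discovery loop:

package main

import (
	"fmt"
	"sync"
	"time"
)

func run() { time.Sleep(10 * time.Millisecond) }

func main() {
	var wg sync.WaitGroup

	// Correct: Add(1) is sequenced before both the goroutine and Wait,
	// so Wait cannot observe a zero counter while work is still pending.
	wg.Add(1)
	go func() {
		defer wg.Done()
		run()
	}()
	wg.Wait()
	fmt.Println("worker finished")

	// Racy alternative (avoid): if Add ran inside the goroutine instead,
	// Wait could return before the goroutine is ever scheduled.
}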
@@ -196,16 +200,13 @@ func (d *DiscoveryV5) SetBootnodes(nodes []*enode.Node) error {
 	return d.listener.SetFallbackNodes(nodes)
 }

+// only works if the discovery v5 is in running state
+// so we can assume that cancel method is set
 func (d *DiscoveryV5) Stop() {
-	d.Lock()
-	defer d.Unlock()
-
-	if d.cancel == nil {
+	if !atomic.CompareAndSwapInt32(&d.started, 1, 0) { // if Discoveryv5 is running, set started to 0
 		return
 	}

 	d.cancel()
-	d.started = false

 	if d.listener != nil {
 		d.listener.Close()
@@ -267,6 +268,7 @@ func (d *DiscoveryV5) Iterator() (enode.Iterator, error) {
 	return enode.Filter(iterator, evaluateNode), nil
 }

+// iterate over all fetched peer addresses and send them to peerConnector
 func (d *DiscoveryV5) iterate(ctx context.Context) error {
 	iterator, err := d.Iterator()
 	if err != nil {
@@ -274,31 +276,9 @@ func (d *DiscoveryV5) iterate(ctx context.Context) error {
 		return fmt.Errorf("obtaining iterator: %w", err)
 	}

-	closeCh := make(chan struct{}, 1)
-	defer close(closeCh)
-
-	// Closing iterator when context is cancelled or function is returning
-	d.wg.Add(1)
-	go func() {
-		defer d.wg.Done()
-		select {
-		case <-ctx.Done():
-			iterator.Close()
-		case <-closeCh:
-			iterator.Close()
-		}
-	}()
-
-	for {
-		if ctx.Err() != nil {
-			break
-		}
-
-		exists := iterator.Next()
-		if !exists {
-			break
-		}
-
+	defer iterator.Close()
+
+	for iterator.Next() { // while next exists, run for loop
 		_, addresses, err := enr.Multiaddress(iterator.Node())
 		if err != nil {
 			metrics.RecordDiscV5Error(context.Background(), "peer_info_failure")
@@ -314,11 +294,12 @@ func (d *DiscoveryV5) iterate(ctx context.Context) error {
 		}

 		if len(peerAddrs) != 0 {
-			select {
-			case <-ctx.Done():
-				return nil
-			case d.peerConnector.PeerChannel() <- peerAddrs[0]:
-			}
+			d.peerConnector.PeerChannel() <- peerAddrs[0]
 		}
+
+		select {
+		case <-ctx.Done():
+			return nil
+		default:
+		}
 	}
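With the dedicated closing goroutine gone, cancellation inside the loop body is now a non-blocking poll: the default case lets the select fall through whenever the context is still live. A sketch of the idiom in isolation, with a sleep standing in for one unit of work:

package main

import (
	"context"
	"fmt"
	"time"
)

func main() {
	ctx, cancel := context.WithTimeout(context.Background(), 30*time.Millisecond)
	defer cancel()

	for i := 0; ; i++ {
		time.Sleep(10 * time.Millisecond) // stand-in for one unit of work

		// Non-blocking poll: continue with the next iteration
		// unless the context has already been cancelled.
		select {
		case <-ctx.Done():
			fmt.Println("cancelled after", i+1, "iterations")
			return
		default:
		}
	}
}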
@@ -326,32 +307,23 @@ func (d *DiscoveryV5) iterate(ctx context.Context) error {
 }

 func (d *DiscoveryV5) runDiscoveryV5Loop(ctx context.Context) {
-	defer d.wg.Done()
-
-	ch := make(chan struct{}, 1)
-	ch <- struct{}{} // Initial execution
-
 restartLoop:
 	for {
+		err := d.iterate(ctx)
+		if err != nil {
+			d.log.Debug("iterating discv5", zap.Error(err))
+			time.Sleep(2 * time.Second)
+		}
+
 		select {
-		case <-ch:
-			err := d.iterate(ctx)
-			if err != nil {
-				d.log.Debug("iterating discv5", zap.Error(err))
-				time.Sleep(2 * time.Second)
-			}
-			ch <- struct{}{}
 		case <-ctx.Done():
-			close(ch)
 			break restartLoop
+		default:
 		}
 	}
 	d.log.Warn("Discv5 loop stopped")
 }

 func (d *DiscoveryV5) IsStarted() bool {
-	d.RLock()
-	defer d.RUnlock()
-
-	return d.started
+	return atomic.LoadInt32(&d.started) == 1
 }
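IsStarted is now a plain atomic load paired with the CAS writers above. On Go 1.19 and later, the same flag could be written with the typed atomic.Bool instead of the int32-as-bool convention; a possible variant, again on a hypothetical service type:

package main

import (
	"fmt"
	"sync/atomic"
)

type service struct {
	started atomic.Bool // zero value is false (stopped)
}

// Each method reports whether it won its swap.
func (s *service) Start() bool     { return s.started.CompareAndSwap(false, true) }
func (s *service) Stop() bool      { return s.started.CompareAndSwap(true, false) }
func (s *service) IsStarted() bool { return s.started.Load() }

func main() {
	var s service
	fmt.Println(s.Start())     // true: we won the swap
	fmt.Println(s.Start())     // false: already started
	fmt.Println(s.IsStarted()) // true
	fmt.Println(s.Stop())      // true
	fmt.Println(s.IsStarted()) // false
}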