diff --git a/waku/v2/node/keepalive.go b/waku/v2/node/keepalive.go index 0fccf964..e6693a62 100644 --- a/waku/v2/node/keepalive.go +++ b/waku/v2/node/keepalive.go @@ -2,6 +2,7 @@ package node import ( "context" + "sync" "time" "github.com/libp2p/go-libp2p/core/host" @@ -50,12 +51,15 @@ func (w *WakuNode) startKeepAlive(ctx context.Context, t time.Duration) { // Network's peers collection, // contains only currently active peers - for _, p := range w.host.Network().Peers() { + pingWg := sync.WaitGroup{} + peersToPing := w.host.Network().Peers() + pingWg.Add(len(peersToPing)) + for _, p := range peersToPing { if p != w.host.ID() { - w.wg.Add(1) - go w.pingPeer(ctx, p) + go w.pingPeer(ctx, &pingWg, p) } } + pingWg.Wait() lastTimeExecuted = w.timesource.Now() case <-ctx.Done(): @@ -65,10 +69,8 @@ func (w *WakuNode) startKeepAlive(ctx context.Context, t time.Duration) { } } -func (w *WakuNode) pingPeer(ctx context.Context, peer peer.ID) { - w.keepAliveMutex.Lock() - defer w.keepAliveMutex.Unlock() - defer w.wg.Done() +func (w *WakuNode) pingPeer(ctx context.Context, wg *sync.WaitGroup, peer peer.ID) { + defer wg.Done() ctx, cancel := context.WithTimeout(ctx, 7*time.Second) defer cancel() @@ -79,16 +81,21 @@ func (w *WakuNode) pingPeer(ctx context.Context, peer peer.ID) { select { case res := <-pr: if res.Error != nil { + w.keepAliveMutex.Lock() w.keepAliveFails[peer]++ + w.keepAliveMutex.Unlock() logger.Debug("could not ping", zap.Error(res.Error)) } else { delete(w.keepAliveFails, peer) } case <-ctx.Done(): + w.keepAliveMutex.Lock() w.keepAliveFails[peer]++ + w.keepAliveMutex.Unlock() logger.Debug("could not ping (context done)", zap.Error(ctx.Err())) } + w.keepAliveMutex.Lock() if w.keepAliveFails[peer] > maxAllowedPingFailures && w.host.Network().Connectedness(peer) == network.Connected { logger.Info("disconnecting peer") if err := w.host.Network().ClosePeer(peer); err != nil { @@ -96,4 +103,5 @@ func (w *WakuNode) pingPeer(ctx context.Context, peer peer.ID) { } w.keepAliveFails[peer] = 0 } + w.keepAliveMutex.Unlock() } diff --git a/waku/v2/node/keepalive_test.go b/waku/v2/node/keepalive_test.go index 44ac952b..f8a56b30 100644 --- a/waku/v2/node/keepalive_test.go +++ b/waku/v2/node/keepalive_test.go @@ -42,7 +42,7 @@ func TestKeepAlive(t *testing.T) { } w.wg.Add(1) - w.pingPeer(ctx2, host2.ID()) + w.pingPeer(ctx2, w.wg, host2.ID()) require.NoError(t, ctx.Err()) }