diff --git a/waku/v2/node/keepalive_test.go b/waku/v2/node/keepalive_test.go
index f348a72e..f9300e9a 100644
--- a/waku/v2/node/keepalive_test.go
+++ b/waku/v2/node/keepalive_test.go
@@ -7,6 +7,7 @@ import (
 	"time"
 
 	"github.com/libp2p/go-libp2p"
+	"github.com/libp2p/go-libp2p-core/peer"
 	"github.com/libp2p/go-libp2p-core/peerstore"
 	"github.com/stretchr/testify/require"
 )
@@ -29,10 +30,18 @@ func TestKeepAlive(t *testing.T) {
 	ctx2, cancel2 := context.WithTimeout(ctx, 3*time.Second)
 	defer cancel2()
 
 	wg := &sync.WaitGroup{}
-	pingPeer(ctx2, wg, host1, host2.ID())
+
+	w := &WakuNode{
+		host:           host1,
+		ctx:            ctx2,
+		wg:             wg,
+		keepAliveMutex: sync.Mutex{},
+		keepAliveFails: make(map[peer.ID]int),
+	}
+
+	w.wg.Add(1)
+	w.pingPeer(host2.ID())
 
 	require.NoError(t, ctx.Err())
 }
diff --git a/waku/v2/node/wakunode2.go b/waku/v2/node/wakunode2.go
index d5ba5462..c8fb6df7 100644
--- a/waku/v2/node/wakunode2.go
+++ b/waku/v2/node/wakunode2.go
@@ -35,6 +35,8 @@ import (
 
 var log = logging.Logger("wakunode")
 
+const maxAllowedPingFailures = 2
+
 type Message []byte
 
 type Peer struct {
@@ -65,6 +67,9 @@ type WakuNode struct {
 	identificationEventSub event.Subscription
 	addressChangesSub      event.Subscription
 
+	keepAliveMutex sync.Mutex
+	keepAliveFails map[peer.ID]int
+
 	ctx    context.Context
 	cancel context.CancelFunc
 	quit   chan struct{}
@@ -126,6 +131,7 @@ func New(ctx context.Context, opts ...WakuNodeOption) (*WakuNode, error) {
 	w.quit = make(chan struct{})
 	w.wg = &sync.WaitGroup{}
 	w.addrChan = make(chan ma.Multiaddr, 1024)
+	w.keepAliveFails = make(map[peer.ID]int)
 
 	if w.protocolEventSub, err = host.EventBus().Subscribe(new(event.EvtPeerProtocolsUpdated)); err != nil {
 		return nil, err
@@ -588,14 +594,11 @@ func (w *WakuNode) Peers() ([]*Peer, error) {
 // This is necessary because TCP connections are automatically closed due to inactivity,
 // and doing a ping will avoid this (with a small bandwidth cost)
 func (w *WakuNode) startKeepAlive(t time.Duration) {
-	defer w.wg.Done()
-
-	log.Info("Setting up ping protocol with duration of ", t)
-
-	ticker := time.NewTicker(t)
-	defer ticker.Stop()
-
 	go func() {
+		defer w.wg.Done()
+		log.Info("Setting up ping protocol with duration of ", t)
+		ticker := time.NewTicker(t)
+		defer ticker.Stop()
 		for {
 			select {
 			case <-ticker.C:
@@ -607,7 +610,8 @@ func (w *WakuNode) startKeepAlive(t time.Duration) {
 			// through Network's peer collection, as it will be empty
 			for _, p := range w.host.Peerstore().Peers() {
 				if p != w.host.ID() {
-					go pingPeer(w.ctx, w.wg, w.host, p)
+					w.wg.Add(1)
+					go w.pingPeer(p)
 				}
 			}
 		case <-w.quit:
@@ -617,21 +621,34 @@
 	}()
 }
 
-func pingPeer(ctx context.Context, wg *sync.WaitGroup, host host.Host, peer peer.ID) {
-	wg.Add(1)
-	defer wg.Done()
+func (w *WakuNode) pingPeer(peer peer.ID) {
+	w.keepAliveMutex.Lock()
+	defer w.keepAliveMutex.Unlock()
+	defer w.wg.Done()
 
-	ctx, cancel := context.WithTimeout(ctx, 3*time.Second)
+	ctx, cancel := context.WithTimeout(w.ctx, 3*time.Second)
 	defer cancel()
 
 	log.Debug("Pinging ", peer)
-	pr := ping.Ping(ctx, host, peer)
+	pr := ping.Ping(ctx, w.host, peer)
 	select {
 	case res := <-pr:
 		if res.Error != nil {
+			w.keepAliveFails[peer]++
 			log.Debug(fmt.Sprintf("Could not ping %s: %s", peer, res.Error.Error()))
+		} else {
+			w.keepAliveFails[peer] = 0
 		}
 	case <-ctx.Done():
+		w.keepAliveFails[peer]++
 		log.Debug(fmt.Sprintf("Could not ping %s: %s", peer, ctx.Err()))
 	}
+
+	if w.keepAliveFails[peer] > maxAllowedPingFailures && w.host.Network().Connectedness(peer) == network.Connected {
log.Info("Disconnecting peer ", peer) + if err := w.host.Network().ClosePeer(peer); err != nil { + log.Debug(fmt.Sprintf("Could not close conn to peer %s: %s", peer, err)) + } + w.keepAliveFails[peer] = 0 + } }