diff --git a/cmd/waku/node.go b/cmd/waku/node.go
index 695f08b0..2327bbbe 100644
--- a/cmd/waku/node.go
+++ b/cmd/waku/node.go
@@ -134,7 +134,7 @@ func Execute(options NodeOptions) error {
 		node.WithLogLevel(lvl),
 		node.WithPrivateKey(prvKey),
 		node.WithHostAddress(hostAddr),
-		node.WithKeepAlive(options.KeepAlive),
+		node.WithKeepAlive(10*time.Second, options.KeepAlive),
 		node.WithMaxPeerConnections(options.MaxPeerConnections),
 		node.WithPrometheusRegisterer(prometheus.DefaultRegisterer),
 		node.WithPeerStoreCapacity(options.PeerStoreCapacity),
diff --git a/library/node.go b/library/node.go
index f1e49b2c..f084b85a 100644
--- a/library/node.go
+++ b/library/node.go
@@ -163,7 +163,7 @@ func NewNode(instance *WakuInstance, configJSON string) error {
 	opts := []node.WakuNodeOption{
 		node.WithPrivateKey(prvKey),
 		node.WithHostAddress(hostAddr),
-		node.WithKeepAlive(time.Duration(*config.KeepAliveInterval) * time.Second),
+		node.WithKeepAlive(10*time.Second, time.Duration(*config.KeepAliveInterval)*time.Second),
 	}
 
 	if *config.EnableRelay {
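The two hunks above update existing call sites to the new two-argument signature. A minimal sketch of how an embedding application might adopt it; the interval values below are illustrative and not taken from this change:

```go
package main

import (
	"context"
	"time"

	"github.com/waku-org/go-waku/waku/v2/node"
)

func main() {
	// Illustrative values: ping mesh peers plus a few random relay peers every
	// 10 seconds, and sweep every connected relay peer every 5 minutes.
	wakuNode, err := node.New(
		node.WithWakuRelay(),
		node.WithKeepAlive(10*time.Second, 5*time.Minute),
	)
	if err != nil {
		panic(err)
	}

	if err := wakuNode.Start(context.Background()); err != nil {
		panic(err)
	}
	defer wakuNode.Stop()
}
```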
diff --git a/waku/v2/node/keepalive.go b/waku/v2/node/keepalive.go
index a2a8256e..2cb03317 100644
--- a/waku/v2/node/keepalive.go
+++ b/waku/v2/node/keepalive.go
@@ -2,6 +2,8 @@ package node
 
 import (
 	"context"
+	"errors"
+	"math/rand"
 	"sync"
 	"time"
 
@@ -10,6 +12,7 @@ import (
 	"github.com/libp2p/go-libp2p/p2p/protocol/ping"
 	"github.com/waku-org/go-waku/logging"
 	"go.uber.org/zap"
+	"golang.org/x/exp/maps"
 )
 
 const maxAllowedPingFailures = 2
@@ -19,86 +22,155 @@ const maxAllowedPingFailures = 2
 // the peers if they don't reply back
 const sleepDetectionIntervalFactor = 3
 
+const maxPeersToPing = 10
+
 // startKeepAlive creates a go routine that periodically pings connected peers.
 // This is necessary because TCP connections are automatically closed due to inactivity,
 // and doing a ping will avoid this (with a small bandwidth cost)
-func (w *WakuNode) startKeepAlive(ctx context.Context, t time.Duration) {
+func (w *WakuNode) startKeepAlive(ctx context.Context, randomPeersPingDuration time.Duration, allPeersPingDuration time.Duration) {
 	defer w.wg.Done()
-	w.log.Info("setting up ping protocol", zap.Duration("duration", t))
-	ticker := time.NewTicker(t)
-	defer ticker.Stop()
+
+	if !w.opts.enableRelay {
+		return
+	}
+
+	w.log.Info("setting up ping protocol", zap.Duration("randomPeersPingDuration", randomPeersPingDuration), zap.Duration("allPeersPingDuration", allPeersPingDuration))
+
+	randomPeersTickerC := make(<-chan time.Time)
+	if randomPeersPingDuration != 0 {
+		randomPeersTicker := time.NewTicker(randomPeersPingDuration)
+		defer randomPeersTicker.Stop()
+		randomPeersTickerC = randomPeersTicker.C
+	}
+
+	allPeersTickerC := make(<-chan time.Time)
+	if allPeersPingDuration != 0 {
+		allPeersTicker := time.NewTicker(allPeersPingDuration)
+		defer allPeersTicker.Stop()
+		allPeersTickerC = allPeersTicker.C
+	}
 
 	lastTimeExecuted := w.timesource.Now()
 
-	sleepDetectionInterval := int64(t) * sleepDetectionIntervalFactor
+	sleepDetectionInterval := int64(randomPeersPingDuration) * sleepDetectionIntervalFactor
 
 	for {
+		peersToPing := []peer.ID{}
+
 		select {
-		case <-ticker.C:
+		case <-allPeersTickerC:
+			relayPeersSet := make(map[peer.ID]struct{})
+			for _, t := range w.Relay().Topics() {
+				for _, p := range w.Relay().PubSub().ListPeers(t) {
+					relayPeersSet[p] = struct{}{}
+				}
+			}
+			peersToPing = maps.Keys(relayPeersSet)
+
+		case <-randomPeersTickerC:
 			difference := w.timesource.Now().UnixNano() - lastTimeExecuted.UnixNano()
-			forceDisconnectOnPingFailure := false
 			if difference > sleepDetectionInterval {
-				forceDisconnectOnPingFailure = true
 				lastTimeExecuted = w.timesource.Now()
-				w.log.Warn("keep alive hasnt been executed recently. Killing connections to peers if ping fails")
+				w.log.Warn("keep alive hasn't been executed recently. Killing all connections")
+				for _, p := range w.host.Network().Peers() {
+					err := w.host.Network().ClosePeer(p)
+					if err != nil {
+						w.log.Debug("closing conn to peer", zap.Error(err))
+					}
+				}
 				continue
 			}
 
-			// Network's peers collection,
-			// contains only currently active peers
-			pingWg := sync.WaitGroup{}
-			peersToPing := w.host.Network().Peers()
-			pingWg.Add(len(peersToPing))
-			for _, p := range peersToPing {
-				if p != w.host.ID() {
-					go w.pingPeer(ctx, &pingWg, p, forceDisconnectOnPingFailure)
+			// Prioritize mesh peers
+			meshPeersSet := make(map[peer.ID]struct{})
+			for _, t := range w.Relay().Topics() {
+				for _, p := range w.Relay().PubSub().MeshPeers(t) {
+					meshPeersSet[p] = struct{}{}
 				}
 			}
-			pingWg.Wait()
+			peersToPing = append(peersToPing, maps.Keys(meshPeersSet)...)
+
+			// Also ping some random relay peers
+			if maxPeersToPing-len(peersToPing) > 0 {
+				relayPeersSet := make(map[peer.ID]struct{})
+				for _, t := range w.Relay().Topics() {
+					for _, p := range w.Relay().PubSub().ListPeers(t) {
+						if _, ok := meshPeersSet[p]; !ok {
+							relayPeersSet[p] = struct{}{}
+						}
+					}
+				}
+
+				relayPeers := maps.Keys(relayPeersSet)
+				rand.Shuffle(len(relayPeers), func(i, j int) { relayPeers[i], relayPeers[j] = relayPeers[j], relayPeers[i] })
+
+				peerLen := maxPeersToPing - len(peersToPing)
+				if peerLen > len(relayPeers) {
+					peerLen = len(relayPeers)
+				}
+				peersToPing = append(peersToPing, relayPeers[0:peerLen]...)
+			}
 
-			lastTimeExecuted = w.timesource.Now()
 		case <-ctx.Done():
			w.log.Info("stopping ping protocol")
 			return
 		}
+
+		pingWg := sync.WaitGroup{}
+		pingWg.Add(len(peersToPing))
+		for _, p := range peersToPing {
+			go w.pingPeer(ctx, &pingWg, p)
+		}
+		pingWg.Wait()
+
+		lastTimeExecuted = w.timesource.Now()
 	}
 }
 
-func (w *WakuNode) pingPeer(ctx context.Context, wg *sync.WaitGroup, peerID peer.ID, forceDisconnectOnFail bool) {
+func (w *WakuNode) pingPeer(ctx context.Context, wg *sync.WaitGroup, peerID peer.ID) {
 	defer wg.Done()
 
+	logger := w.log.With(logging.HostID("peer", peerID))
+
+	for i := 0; i < maxAllowedPingFailures; i++ {
+		if w.host.Network().Connectedness(peerID) != network.Connected {
+			// Peer is no longer connected. No need to ping
+			return
+		}
+
+		logger.Debug("pinging")
+
+		if w.tryPing(ctx, peerID, logger) {
+			return
+		}
+	}
+
+	if w.host.Network().Connectedness(peerID) != network.Connected {
+		return
+	}
+
+	logger.Info("disconnecting dead peer")
+	if err := w.host.Network().ClosePeer(peerID); err != nil {
+		logger.Debug("closing conn to peer", zap.Error(err))
+	}
+}
+
+func (w *WakuNode) tryPing(ctx context.Context, peerID peer.ID, logger *zap.Logger) bool {
 	ctx, cancel := context.WithTimeout(ctx, 7*time.Second)
 	defer cancel()
 
-	logger := w.log.With(logging.HostID("peer", peerID))
-	logger.Debug("pinging")
 	pr := ping.Ping(ctx, w.host, peerID)
 	select {
 	case res := <-pr:
 		if res.Error != nil {
-			w.keepAliveMutex.Lock()
-			w.keepAliveFails[peerID]++
-			w.keepAliveMutex.Unlock()
 			logger.Debug("could not ping", zap.Error(res.Error))
-		} else {
-			w.keepAliveMutex.Lock()
-			delete(w.keepAliveFails, peerID)
-			w.keepAliveMutex.Unlock()
+			return false
 		}
 	case <-ctx.Done():
-		w.keepAliveMutex.Lock()
-		w.keepAliveFails[peerID]++
-		w.keepAliveMutex.Unlock()
-		logger.Debug("could not ping (context done)", zap.Error(ctx.Err()))
-	}
-
-	w.keepAliveMutex.Lock()
-	if (forceDisconnectOnFail || w.keepAliveFails[peerID] > maxAllowedPingFailures) && w.host.Network().Connectedness(peerID) == network.Connected {
-		logger.Info("disconnecting peer")
-		if err := w.host.Network().ClosePeer(peerID); err != nil {
-			logger.Debug("closing conn to peer", zap.Error(err))
+		if !errors.Is(ctx.Err(), context.Canceled) {
+			logger.Debug("could not ping (context)", zap.Error(ctx.Err()))
 		}
-		w.keepAliveFails[peerID] = 0
+		return false
 	}
-	w.keepAliveMutex.Unlock()
+	return true
 }
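For reference, the ticker wiring in startKeepAlive relies on the fact that a receive-only channel created with make(<-chan time.Time) never fires, so a select case bound to it is effectively disabled when its interval is zero. A standalone sketch of that pattern, with illustrative names and intervals (not part of this change):

```go
package main

import (
	"fmt"
	"time"
)

// keepAliveLoop sketches the two-ticker select used in startKeepAlive.
// A zero interval leaves the corresponding channel as one that never fires,
// which disables that branch of the select.
func keepAliveLoop(randomPeersInterval, allPeersInterval time.Duration, stop <-chan struct{}) {
	randomPeersTickerC := make(<-chan time.Time) // never fires unless replaced below
	if randomPeersInterval != 0 {
		t := time.NewTicker(randomPeersInterval)
		defer t.Stop()
		randomPeersTickerC = t.C
	}

	allPeersTickerC := make(<-chan time.Time)
	if allPeersInterval != 0 {
		t := time.NewTicker(allPeersInterval)
		defer t.Stop()
		allPeersTickerC = t.C
	}

	for {
		select {
		case <-randomPeersTickerC:
			fmt.Println("ping mesh peers plus a few random relay peers")
		case <-allPeersTickerC:
			fmt.Println("ping every connected relay peer")
		case <-stop:
			return
		}
	}
}

func main() {
	stop := make(chan struct{})
	go keepAliveLoop(100*time.Millisecond, time.Second, stop)
	time.Sleep(1200 * time.Millisecond)
	close(stop)
}
```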
diff --git a/waku/v2/node/keepalive_test.go b/waku/v2/node/keepalive_test.go
index a778f5c9..0508fd79 100644
--- a/waku/v2/node/keepalive_test.go
+++ b/waku/v2/node/keepalive_test.go
@@ -9,7 +9,6 @@ import (
 	"github.com/ethereum/go-ethereum/crypto"
 	"github.com/libp2p/go-libp2p"
-	"github.com/libp2p/go-libp2p/core/peer"
 	"github.com/libp2p/go-libp2p/core/peerstore"
 	"github.com/multiformats/go-multiaddr"
 	"github.com/stretchr/testify/require"
@@ -40,15 +39,13 @@ func TestKeepAlive(t *testing.T) {
 	wg := &sync.WaitGroup{}
 
 	w := &WakuNode{
-		host:           host1,
-		wg:             wg,
-		log:            utils.Logger(),
-		keepAliveMutex: sync.Mutex{},
-		keepAliveFails: make(map[peer.ID]int),
+		host: host1,
+		wg:   wg,
+		log:  utils.Logger(),
 	}
 
 	w.wg.Add(1)
-	w.pingPeer(ctx2, w.wg, peerID2, false)
+	w.pingPeer(ctx2, w.wg, peerID2)
 
 	require.NoError(t, ctx.Err())
 }
@@ -70,7 +67,7 @@ func TestPeriodicKeepAlive(t *testing.T) {
 		WithPrivateKey(prvKey),
 		WithHostAddress(hostAddr),
 		WithWakuRelay(),
-		WithKeepAlive(time.Second),
+		WithKeepAlive(time.Minute, time.Second),
 	)
 	require.NoError(t, err)
diff --git a/waku/v2/node/wakunode2.go b/waku/v2/node/wakunode2.go
index cf7a51ad..5032c823 100644
--- a/waku/v2/node/wakunode2.go
+++ b/waku/v2/node/wakunode2.go
@@ -116,9 +116,6 @@ type WakuNode struct {
 	addressChangesSub event.Subscription
 	enrChangeCh       chan struct{}
 
-	keepAliveMutex sync.Mutex
-	keepAliveFails map[peer.ID]int
-
 	cancel context.CancelFunc
 	wg     *sync.WaitGroup
 
@@ -193,7 +190,6 @@ func New(opts ...WakuNodeOption) (*WakuNode, error) {
 	w.opts = params
 	w.log = params.logger.Named("node2")
 	w.wg = &sync.WaitGroup{}
-	w.keepAliveFails = make(map[peer.ID]int)
 	w.wakuFlag = enr.NewWakuEnrBitfield(w.opts.enableLightPush, w.opts.enableFilterFullNode, w.opts.enableStore, w.opts.enableRelay)
 	w.circuitRelayNodes = make(chan peer.AddrInfo)
 	w.metrics = newMetrics(params.prometheusReg)
@@ -382,9 +378,9 @@ func (w *WakuNode) Start(ctx context.Context) error {
 		return err
 	}
 
-	if w.opts.keepAliveInterval > time.Duration(0) {
+	if w.opts.keepAliveRandomPeersInterval > time.Duration(0) || w.opts.keepAliveAllPeersInterval > time.Duration(0) {
 		w.wg.Add(1)
-		go w.startKeepAlive(ctx, w.opts.keepAliveInterval)
+		go w.startKeepAlive(ctx, w.opts.keepAliveRandomPeersInterval, w.opts.keepAliveAllPeersInterval)
 	}
 
 	w.metadata.SetHost(host)
diff --git a/waku/v2/node/wakuoptions.go b/waku/v2/node/wakuoptions.go
index 26a82d0d..82d96461 100644
--- a/waku/v2/node/wakuoptions.go
+++ b/waku/v2/node/wakuoptions.go
@@ -114,7 +114,8 @@ type WakuNodeParameters struct {
 	rlnTreePath                  string
 	rlnMembershipContractAddress common.Address
 
-	keepAliveInterval time.Duration
+	keepAliveRandomPeersInterval time.Duration
+	keepAliveAllPeersInterval    time.Duration
 
 	enableLightPush bool
 
@@ -476,10 +477,14 @@ func WithLightPush(lightpushOpts ...lightpush.Option) WakuNodeOption {
 }
 
 // WithKeepAlive is a WakuNodeOption used to set the interval of time when
-// each peer will be ping to keep the TCP connection alive
-func WithKeepAlive(t time.Duration) WakuNodeOption {
+// each peer will be pinged to keep the TCP connection alive. The option accepts
+// two intervals: `randomPeersInterval`, which is used to ping full mesh peers
+// (if using relay) and a few random connected peers, and `allPeersInterval`,
+// which is used to ping all connected peers.
+func WithKeepAlive(randomPeersInterval time.Duration, allPeersInterval time.Duration) WakuNodeOption {
 	return func(params *WakuNodeParameters) error {
-		params.keepAliveInterval = t
+		params.keepAliveRandomPeersInterval = randomPeersInterval
+		params.keepAliveAllPeersInterval = allPeersInterval
 		return nil
 	}
 }
diff --git a/waku/v2/node/wakuoptions_test.go b/waku/v2/node/wakuoptions_test.go
index 751c7158..9d4ed4f9 100644
--- a/waku/v2/node/wakuoptions_test.go
+++ b/waku/v2/node/wakuoptions_test.go
@@ -58,7 +58,7 @@ func TestWakuOptions(t *testing.T) {
 		WithWakuStore(),
 		WithMessageProvider(&persistence.DBStore{}),
 		WithLightPush(),
-		WithKeepAlive(time.Hour),
+		WithKeepAlive(time.Minute, time.Hour),
 		WithTopicHealthStatusChannel(topicHealthStatusChan),
 		WithWakuStoreFactory(storeFactory),
 	}
@@ -107,7 +107,7 @@ func TestWakuRLNOptions(t *testing.T) {
 		WithWakuStore(),
 		WithMessageProvider(&persistence.DBStore{}),
 		WithLightPush(),
-		WithKeepAlive(time.Hour),
+		WithKeepAlive(time.Minute, time.Hour),
 		WithTopicHealthStatusChannel(topicHealthStatusChan),
 		WithWakuStoreFactory(storeFactory),
 		WithStaticRLNRelay(&index, handleSpam),
@@ -147,7 +147,7 @@ func TestWakuRLNOptions(t *testing.T) {
 		WithWakuStore(),
 		WithMessageProvider(&persistence.DBStore{}),
 		WithLightPush(),
-		WithKeepAlive(time.Hour),
+		WithKeepAlive(time.Minute, time.Hour),
 		WithTopicHealthStatusChannel(topicHealthStatusChan),
 		WithWakuStoreFactory(storeFactory),
 		WithDynamicRLNRelay(keystorePath, keystorePassword, rlnTreePath, common.HexToAddress(contractAddress), &index, handleSpam, ethClientAddress),
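The random-peer selection added in keepalive.go above boils down to deduplicating candidates into a set, shuffling, and capping the result at maxPeersToPing. A self-contained sketch of that step; the helper name selectRandomPeers is hypothetical and not part of this change:

```go
package main

import (
	"fmt"
	"math/rand"

	"github.com/libp2p/go-libp2p/core/peer"
	"golang.org/x/exp/maps"
)

const maxPeersToPing = 10

// selectRandomPeers returns at most maxPeersToPing-alreadySelected peers,
// picked uniformly at random from the candidate set.
func selectRandomPeers(candidates map[peer.ID]struct{}, alreadySelected int) []peer.ID {
	peers := maps.Keys(candidates)
	rand.Shuffle(len(peers), func(i, j int) { peers[i], peers[j] = peers[j], peers[i] })

	n := maxPeersToPing - alreadySelected
	if n < 0 {
		n = 0
	}
	if n > len(peers) {
		n = len(peers)
	}
	return peers[:n]
}

func main() {
	candidates := map[peer.ID]struct{}{"peerA": {}, "peerB": {}, "peerC": {}}
	fmt.Println(selectRandomPeers(candidates, 8)) // prints at most 2 peer IDs
}
```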