From e59b1dd2f0e08c626905365232f64a935999a3d7 Mon Sep 17 00:00:00 2001 From: Vitaly Vlasov Date: Mon, 12 Feb 2024 13:11:25 +0200 Subject: [PATCH] feat: previous pings are stored in a ttl map --- waku/v2/protocol/filter/client.go | 14 ++++++- waku/v2/utils/ttl_map.go | 63 +++++++++++++++++++++++++++++++ waku/v2/utils/ttl_map_test.go | 33 ++++++++++++++++ 3 files changed, 108 insertions(+), 2 deletions(-) create mode 100644 waku/v2/utils/ttl_map.go create mode 100644 waku/v2/utils/ttl_map_test.go diff --git a/waku/v2/protocol/filter/client.go b/waku/v2/protocol/filter/client.go index 451ef51c..41c7f213 100644 --- a/waku/v2/protocol/filter/client.go +++ b/waku/v2/protocol/filter/client.go @@ -26,6 +26,7 @@ import ( "github.com/waku-org/go-waku/waku/v2/protocol/subscription" "github.com/waku-org/go-waku/waku/v2/service" "github.com/waku-org/go-waku/waku/v2/timesource" + "github.com/waku-org/go-waku/waku/v2/utils" "go.uber.org/zap" "golang.org/x/exp/maps" "golang.org/x/exp/slices" @@ -50,6 +51,7 @@ type WakuFilterLightNode struct { log *zap.Logger subscriptions *subscription.SubscriptionsMap pm *peermanager.PeerManager + peerPings *utils.TtlMap[peer.ID, error] } type WakuFilterPushError struct { @@ -86,7 +88,6 @@ func NewWakuFilterLightNode(broadcaster relay.Broadcaster, pm *peermanager.PeerM wf.pm = pm wf.CommonService = service.NewCommonService() wf.metrics = newMetrics(reg) - return wf } @@ -96,6 +97,7 @@ func (wf *WakuFilterLightNode) SetHost(h host.Host) { } func (wf *WakuFilterLightNode) Start(ctx context.Context) error { + wf.peerPings = utils.NewTtlMap[peer.ID, error](ctx, 5) return wf.CommonService.Start(ctx, wf.start) } @@ -445,6 +447,11 @@ func (wf *WakuFilterLightNode) Ping(ctx context.Context, peerID peer.ID, opts .. return err } + pingResult, found := wf.peerPings.Get(peerID) + if found { + return pingResult + } + params := &FilterPingParameters{} for _, opt := range opts { opt(params) @@ -453,12 +460,15 @@ func (wf *WakuFilterLightNode) Ping(ctx context.Context, peerID peer.ID, opts .. params.requestID = protocol.GenerateRequestID() } - return wf.request( + result := wf.request( ctx, params.requestID, pb.FilterSubscribeRequest_SUBSCRIBER_PING, protocol.ContentFilter{}, peerID) + + wf.peerPings.Put(peerID, result) + return result } func (wf *WakuFilterLightNode) IsSubscriptionAlive(ctx context.Context, subscription *subscription.SubscriptionDetails) error { diff --git a/waku/v2/utils/ttl_map.go b/waku/v2/utils/ttl_map.go new file mode 100644 index 00000000..7dc6cd41 --- /dev/null +++ b/waku/v2/utils/ttl_map.go @@ -0,0 +1,63 @@ +package utils + +import ( + "context" + "sync" + "time" +) + +type elem[V any] struct { + v V + lastAccess int64 +} + +type TtlMap[K comparable, V any] struct { + sync.RWMutex + + m map[K]elem[V] +} + +func NewTtlMap[K comparable, V any](ctx context.Context, maxTtl uint) *TtlMap[K, V] { + m := &TtlMap[K, V]{m: make(map[K]elem[V])} + go func() { + ticker := time.NewTicker(1 * time.Second) + defer ticker.Stop() + for { + select { + case <-ctx.Done(): + return + case <-ticker.C: + func() { + m.Lock() + defer m.Unlock() + now := time.Now().Unix() + for k, v := range m.m { + if now-v.lastAccess > int64(maxTtl) { + delete(m.m, k) + } + } + }() + } + } + }() + return m +} + +func (m *TtlMap[K, V]) Put(k K, v V) { + m.Lock() + defer m.Unlock() + m.m[k] = elem[V]{v, time.Now().Unix()} +} + +func (m *TtlMap[K, V]) Get(k K) (V, bool) { + m.RLock() + defer m.RUnlock() + v, ok := m.m[k] + return v.v, ok +} + +func (m *TtlMap[K, V]) Len() int { + m.RLock() + defer m.RUnlock() + return len(m.m) +} diff --git a/waku/v2/utils/ttl_map_test.go b/waku/v2/utils/ttl_map_test.go new file mode 100644 index 00000000..9123df2b --- /dev/null +++ b/waku/v2/utils/ttl_map_test.go @@ -0,0 +1,33 @@ +package utils + +import ( + "context" + "testing" + "time" + + "github.com/stretchr/testify/require" +) + +func TestTtlMap(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + ttlMap := NewTtlMap[string, bool](ctx, 3) + ttlMap.Put("a", true) + ttlMap.Put("b", false) + + v, ok := ttlMap.Get("a") + require.Equal(t, v, true) + require.Equal(t, ok, true) + + v, ok = ttlMap.Get("b") + require.Equal(t, v, false) + require.Equal(t, ok, true) + + time.Sleep(5 * time.Second) + + require.Equal(t, ttlMap.Len(), 0) + ttlMap.Put("c", true) + cancel() + time.Sleep(5 * time.Second) + require.Equal(t, ttlMap.Len(), 1) + +}