From ac1a699171cfb97abf617ab0041fd4d9177260f1 Mon Sep 17 00:00:00 2001 From: Prem Chaitanya Prathi Date: Fri, 1 Dec 2023 06:27:13 +0530 Subject: [PATCH] fix: return appropriate errors in filter unsubscribe (#941) --- waku/v2/protocol/filter/client.go | 41 +++++++++++++------ waku/v2/protocol/filter/filter_test.go | 2 +- .../subscription/subscriptions_map.go | 6 +++ 3 files changed, 36 insertions(+), 13 deletions(-) diff --git a/waku/v2/protocol/filter/client.go b/waku/v2/protocol/filter/client.go index 08f06b71..38449ae1 100644 --- a/waku/v2/protocol/filter/client.go +++ b/waku/v2/protocol/filter/client.go @@ -36,7 +36,8 @@ import ( const FilterPushID_v20beta1 = libp2pProtocol.ID("/vac/waku/filter-push/2.0.0-beta1") var ( - ErrNoPeersAvailable = errors.New("no suitable remote peers") + ErrNoPeersAvailable = errors.New("no suitable remote peers") + ErrSubscriptionNotFound = errors.New("subscription not found") ) type WakuFilterLightNode struct { @@ -110,19 +111,21 @@ func (wf *WakuFilterLightNode) start() error { func (wf *WakuFilterLightNode) Stop() { wf.CommonService.Stop(func() { wf.h.RemoveStreamHandler(FilterPushID_v20beta1) - res, err := wf.unsubscribeAll(wf.Context()) - if err != nil { - wf.log.Warn("unsubscribing from full nodes", zap.Error(err)) - } - - for _, r := range res.Errors() { - if r.Err != nil { - wf.log.Warn("unsubscribing from full nodes", zap.Error(r.Err), logging.HostID("peerID", r.PeerID)) + if wf.subscriptions.Count() > 0 { + res, err := wf.unsubscribeAll(wf.Context()) + if err != nil { + wf.log.Warn("unsubscribing from full nodes", zap.Error(err)) } + for _, r := range res.Errors() { + if r.Err != nil { + wf.log.Warn("unsubscribing from full nodes", zap.Error(r.Err), logging.HostID("peerID", r.PeerID)) + } + + } + // + wf.subscriptions.Clear() } - // - wf.subscriptions.Clear() }) } @@ -485,6 +488,13 @@ func (wf *WakuFilterLightNode) Unsubscribe(ctx context.Context, contentFilter pr peers := make(map[peer.ID]struct{}) subs := wf.subscriptions.GetSubscription(params.selectedPeer, cFilter) + if len(subs) == 0 { + result.Add(WakuFilterPushError{ + Err: ErrSubscriptionNotFound, + PeerID: params.selectedPeer, + }) + continue + } for _, sub := range subs { sub.Remove(cTopics...) peers[sub.PeerID] = struct{}{} @@ -583,14 +593,21 @@ func (wf *WakuFilterLightNode) unsubscribeAll(ctx context.Context, opts ...Filte if err != nil { return nil, err } + result := &WakuFilterPushResult{} peers := make(map[peer.ID]struct{}) subs := wf.subscriptions.GetSubscription(params.selectedPeer, protocol.ContentFilter{}) + if len(subs) == 0 && params.selectedPeer != "" { + result.Add(WakuFilterPushError{ + Err: err, + PeerID: params.selectedPeer, + }) + return result, ErrSubscriptionNotFound + } for _, sub := range subs { sub.Close() peers[sub.PeerID] = struct{}{} } - result := &WakuFilterPushResult{} if params.wg != nil { params.wg.Add(len(peers)) } diff --git a/waku/v2/protocol/filter/filter_test.go b/waku/v2/protocol/filter/filter_test.go index 5143a773..01664844 100644 --- a/waku/v2/protocol/filter/filter_test.go +++ b/waku/v2/protocol/filter/filter_test.go @@ -190,7 +190,7 @@ func (s *FilterTestSuite) waitForMessages(fn func(), subs []*subscription.Subscr contentTopic: env.Message().GetContentTopic(), payload: string(env.Message().GetPayload()), } - s.log.Info("received message ", zap.String("pubSubTopic", received.pubSubTopic), zap.String("contentTopic", received.contentTopic), zap.String("payload", received.payload)) + s.log.Debug("received message ", zap.String("pubSubTopic", received.pubSubTopic), zap.String("contentTopic", received.contentTopic), zap.String("payload", received.payload)) if matchOneOfManyMsg(received, expected) { found++ } diff --git a/waku/v2/protocol/subscription/subscriptions_map.go b/waku/v2/protocol/subscription/subscriptions_map.go index c007f623..540fe807 100644 --- a/waku/v2/protocol/subscription/subscriptions_map.go +++ b/waku/v2/protocol/subscription/subscriptions_map.go @@ -28,6 +28,12 @@ func NewSubscriptionMap(logger *zap.Logger) *SubscriptionsMap { } } +func (m *SubscriptionsMap) Count() int { + m.RLock() + defer m.RUnlock() + return len(m.items) +} + func (m *SubscriptionsMap) IsListening(pubsubTopic, contentTopic string) bool { m.RLock() defer m.RUnlock()