diff --git a/waku/v2/protocol/filter/client.go b/waku/v2/protocol/filter/client.go index 28721b18..5b1fdd74 100644 --- a/waku/v2/protocol/filter/client.go +++ b/waku/v2/protocol/filter/client.go @@ -8,6 +8,7 @@ import ( "math" "net/http" "sync" + "sync/atomic" "github.com/libp2p/go-libp2p/core/host" "github.com/libp2p/go-libp2p/core/network" @@ -44,6 +45,7 @@ type WakuFilterLightNode struct { log *zap.Logger subscriptions *SubscriptionsMap pm *peermanager.PeerManager + started atomic.Bool } type ContentFilter struct { @@ -56,6 +58,8 @@ type WakuFilterPushResult struct { PeerID peer.ID } +var errNotStarted = errors.New("filter is not started") + // NewWakuFilterLightnode returns a new instance of Waku Filter struct setup according to the chosen parameter and options // Takes an optional peermanager if WakuFilterLightnode is being created along with WakuNode. // If using libp2p host, then pass peermanager as nil @@ -78,6 +82,10 @@ func (wf *WakuFilterLightNode) SetHost(h host.Host) { } func (wf *WakuFilterLightNode) Start(ctx context.Context) error { + if !wf.started.CompareAndSwap(false, true) { + return nil // Already started + } + wf.wg.Wait() // Wait for any goroutines to stop ctx, cancel := context.WithCancel(ctx) @@ -94,7 +102,7 @@ func (wf *WakuFilterLightNode) Start(ctx context.Context) error { // Stop unmounts the filter protocol func (wf *WakuFilterLightNode) Stop() { - if wf.cancel == nil { + if !wf.started.CompareAndSwap(true, false) { return } @@ -206,6 +214,10 @@ func (wf *WakuFilterLightNode) request(ctx context.Context, params *FilterSubscr // Subscribe setups a subscription to receive messages that match a specific content filter func (wf *WakuFilterLightNode) Subscribe(ctx context.Context, contentFilter ContentFilter, opts ...FilterSubscribeOption) (*SubscriptionDetails, error) { + if !wf.isStarted() { + return nil, errNotStarted + } + if contentFilter.Topic == "" { return nil, errors.New("topic is required") } @@ -244,6 +256,10 @@ func (wf *WakuFilterLightNode) Subscribe(ctx context.Context, contentFilter Cont // FilterSubscription is used to obtain an object from which you could receive messages received via filter protocol func (wf *WakuFilterLightNode) FilterSubscription(peerID peer.ID, contentFilter ContentFilter) (*SubscriptionDetails, error) { + if !wf.isStarted() { + return nil, errNotStarted + } + if !wf.subscriptions.Has(peerID, contentFilter.Topic, contentFilter.ContentTopics...) { return nil, errors.New("subscription does not exist") } @@ -263,6 +279,10 @@ func (wf *WakuFilterLightNode) getUnsubscribeParameters(opts ...FilterUnsubscrib } func (wf *WakuFilterLightNode) Ping(ctx context.Context, peerID peer.ID) error { + if !wf.isStarted() { + return errNotStarted + } + return wf.request( ctx, &FilterSubscribeParameters{selectedPeer: peerID}, @@ -271,10 +291,18 @@ func (wf *WakuFilterLightNode) Ping(ctx context.Context, peerID peer.ID) error { } func (wf *WakuFilterLightNode) IsSubscriptionAlive(ctx context.Context, subscription *SubscriptionDetails) error { + if !wf.isStarted() { + return errNotStarted + } + return wf.Ping(ctx, subscription.PeerID) } func (wf *WakuFilterLightNode) Subscriptions() []*SubscriptionDetails { + if !wf.isStarted() { + return nil + } + wf.subscriptions.RLock() defer wf.subscriptions.RUnlock() @@ -324,6 +352,10 @@ func (wf *WakuFilterLightNode) cleanupSubscriptions(peerID peer.ID, contentFilte // Unsubscribe is used to stop receiving messages from a peer that match a content filter func (wf *WakuFilterLightNode) Unsubscribe(ctx context.Context, contentFilter ContentFilter, opts ...FilterUnsubscribeOption) (<-chan WakuFilterPushResult, error) { + if !wf.isStarted() { + return nil, errNotStarted + } + if contentFilter.Topic == "" { return nil, errors.New("topic is required") } @@ -396,11 +428,17 @@ func (wf *WakuFilterLightNode) Unsubscribe(ctx context.Context, contentFilter Co params.wg.Wait() } + close(resultChan) + return resultChan, nil } // Unsubscribe is used to stop receiving messages from a peer that match a content filter func (wf *WakuFilterLightNode) UnsubscribeWithSubscription(ctx context.Context, sub *SubscriptionDetails, opts ...FilterUnsubscribeOption) (<-chan WakuFilterPushResult, error) { + if !wf.isStarted() { + return nil, errNotStarted + } + var contentTopics []string for k := range sub.ContentTopics { contentTopics = append(contentTopics, k) @@ -413,6 +451,10 @@ func (wf *WakuFilterLightNode) UnsubscribeWithSubscription(ctx context.Context, // UnsubscribeAll is used to stop receiving messages from peer(s). It does not close subscriptions func (wf *WakuFilterLightNode) UnsubscribeAll(ctx context.Context, opts ...FilterUnsubscribeOption) (<-chan WakuFilterPushResult, error) { + if !wf.isStarted() { + return nil, errNotStarted + } + params, err := wf.getUnsubscribeParameters(opts...) if err != nil { return nil, err @@ -466,3 +508,7 @@ func (wf *WakuFilterLightNode) UnsubscribeAll(ctx context.Context, opts ...Filte return resultChan, nil } + +func (wf *WakuFilterLightNode) isStarted() bool { + return wf.started.Load() +} diff --git a/waku/v2/protocol/filter/filter_test.go b/waku/v2/protocol/filter/filter_test.go index 5e676bad..e0b5dc67 100644 --- a/waku/v2/protocol/filter/filter_test.go +++ b/waku/v2/protocol/filter/filter_test.go @@ -333,3 +333,46 @@ func (s *FilterTestSuite) TestMultipleMessages() { }, s.subDetails.C) } + +func (s *FilterTestSuite) TestRunningGuard() { + s.lightNode.Stop() + + contentFilter := ContentFilter{ + Topic: "test", + ContentTopics: []string{"test"}, + } + + _, err := s.lightNode.Subscribe(s.ctx, contentFilter, WithPeer(s.fullNodeHost.ID())) + + s.Require().ErrorIs(err, errNotStarted) + + err = s.lightNode.Start(s.ctx) + s.Require().NoError(err) + + _, err = s.lightNode.Subscribe(s.ctx, contentFilter, WithPeer(s.fullNodeHost.ID())) + + s.Require().NoError(err) +} + +func (s *FilterTestSuite) TestFireAndForgetAndCustomWg() { + contentFilter := ContentFilter{ + Topic: "test", + ContentTopics: []string{"test"}, + } + + _, err := s.lightNode.Subscribe(s.ctx, contentFilter, WithPeer(s.fullNodeHost.ID())) + s.Require().NoError(err) + + ch, err := s.lightNode.Unsubscribe(s.ctx, contentFilter, Async()) + _, open := <-ch + s.Require().NoError(err) + s.Require().False(open) + + _, err = s.lightNode.Subscribe(s.ctx, contentFilter, WithPeer(s.fullNodeHost.ID())) + s.Require().NoError(err) + + wg := sync.WaitGroup{} + _, err = s.lightNode.Unsubscribe(s.ctx, contentFilter, WithWaitGroup(&wg)) + wg.Wait() + s.Require().NoError(err) +}