diff --git a/waku/v2/protocol/filter/client.go b/waku/v2/protocol/filter/client.go index f2787269..d9448965 100644 --- a/waku/v2/protocol/filter/client.go +++ b/waku/v2/protocol/filter/client.go @@ -34,6 +34,9 @@ var ( ) type WakuFilterLightNode struct { + sync.RWMutex + started bool + cancel context.CancelFunc ctx context.Context h host.Host @@ -56,6 +59,9 @@ type WakuFilterPushResult struct { PeerID peer.ID } +var errNotStarted = errors.New("not started") +var errAlreadyStarted = errors.New("already started") + // NewWakuFilterLightnode returns a new instance of Waku Filter struct setup according to the chosen parameter and options // Takes an optional peermanager if WakuFilterLightnode is being created along with WakuNode. // If using libp2p host, then pass peermanager as nil @@ -78,12 +84,20 @@ func (wf *WakuFilterLightNode) SetHost(h host.Host) { } func (wf *WakuFilterLightNode) Start(ctx context.Context) error { + wf.Lock() + defer wf.Unlock() + + if wf.started { + return errAlreadyStarted + } + wf.wg.Wait() // Wait for any goroutines to stop ctx, cancel := context.WithCancel(ctx) wf.cancel = cancel wf.ctx = ctx wf.subscriptions = NewSubscriptionMap(wf.log) + wf.started = true wf.h.SetStreamHandlerMatch(FilterPushID_v20beta1, protocol.PrefixTextMatch(string(FilterPushID_v20beta1)), wf.onRequest(ctx)) @@ -94,7 +108,10 @@ func (wf *WakuFilterLightNode) Start(ctx context.Context) error { // Stop unmounts the filter protocol func (wf *WakuFilterLightNode) Stop() { - if wf.cancel == nil { + wf.Lock() + defer wf.Unlock() + + if !wf.started { return } @@ -102,10 +119,23 @@ func (wf *WakuFilterLightNode) Stop() { wf.h.RemoveStreamHandler(FilterPushID_v20beta1) - _, _ = wf.UnsubscribeAll(wf.ctx) + res, err := wf.unsubscribeAll(wf.ctx) + if err != nil { + wf.log.Warn("unsubscribing from full nodes", zap.Error(err)) + } + + for r := range res { + if r.Err != nil { + wf.log.Warn("unsubscribing from full nodes", zap.Error(r.Err), logging.HostID("peerID", r.PeerID)) + } + + } wf.subscriptions.Clear() + wf.started = false + wf.cancel = nil + wf.wg.Wait() } @@ -206,6 +236,13 @@ func (wf *WakuFilterLightNode) request(ctx context.Context, params *FilterSubscr // Subscribe setups a subscription to receive messages that match a specific content filter func (wf *WakuFilterLightNode) Subscribe(ctx context.Context, contentFilter ContentFilter, opts ...FilterSubscribeOption) (*SubscriptionDetails, error) { + wf.RLock() + defer wf.RUnlock() + + if !wf.started { + return nil, errNotStarted + } + if contentFilter.Topic == "" { return nil, errors.New("topic is required") } @@ -244,6 +281,13 @@ func (wf *WakuFilterLightNode) Subscribe(ctx context.Context, contentFilter Cont // FilterSubscription is used to obtain an object from which you could receive messages received via filter protocol func (wf *WakuFilterLightNode) FilterSubscription(peerID peer.ID, contentFilter ContentFilter) (*SubscriptionDetails, error) { + wf.RLock() + defer wf.RUnlock() + + if !wf.started { + return nil, errNotStarted + } + if !wf.subscriptions.Has(peerID, contentFilter.Topic, contentFilter.ContentTopics...) { return nil, errors.New("subscription does not exist") } @@ -263,6 +307,13 @@ func (wf *WakuFilterLightNode) getUnsubscribeParameters(opts ...FilterUnsubscrib } func (wf *WakuFilterLightNode) Ping(ctx context.Context, peerID peer.ID) error { + wf.RLock() + defer wf.RUnlock() + + if !wf.started { + return errNotStarted + } + return wf.request( ctx, &FilterSubscribeParameters{selectedPeer: peerID}, @@ -271,10 +322,24 @@ func (wf *WakuFilterLightNode) Ping(ctx context.Context, peerID peer.ID) error { } func (wf *WakuFilterLightNode) IsSubscriptionAlive(ctx context.Context, subscription *SubscriptionDetails) error { + wf.RLock() + defer wf.RUnlock() + + if !wf.started { + return errNotStarted + } + return wf.Ping(ctx, subscription.PeerID) } func (wf *WakuFilterLightNode) Subscriptions() []*SubscriptionDetails { + wf.RLock() + defer wf.RUnlock() + + if !wf.started { + return nil + } + wf.subscriptions.RLock() defer wf.subscriptions.RUnlock() @@ -324,6 +389,13 @@ func (wf *WakuFilterLightNode) cleanupSubscriptions(peerID peer.ID, contentFilte // Unsubscribe is used to stop receiving messages from a peer that match a content filter func (wf *WakuFilterLightNode) Unsubscribe(ctx context.Context, contentFilter ContentFilter, opts ...FilterUnsubscribeOption) (<-chan WakuFilterPushResult, error) { + wf.RLock() + defer wf.RUnlock() + + if !wf.started { + return nil, errNotStarted + } + if contentFilter.Topic == "" { return nil, errors.New("topic is required") } @@ -341,17 +413,33 @@ func (wf *WakuFilterLightNode) Unsubscribe(ctx context.Context, contentFilter Co return nil, err } - localWg := sync.WaitGroup{} resultChan := make(chan WakuFilterPushResult, len(wf.subscriptions.items)) - var peersUnsubscribed []peer.ID for peerID := range wf.subscriptions.items { if params.selectedPeer != "" && peerID != params.selectedPeer { continue } - peersUnsubscribed = append(peersUnsubscribed, peerID) - localWg.Add(1) + + subscriptions, ok := wf.subscriptions.items[peerID] + if !ok || subscriptions == nil { + continue + } + + wf.cleanupSubscriptions(peerID, contentFilter) + if len(subscriptions.subscriptionsPerTopic) == 0 { + delete(wf.subscriptions.items, peerID) + } + + if params.wg != nil { + params.wg.Add(1) + } + go func(peerID peer.ID) { - defer localWg.Done() + defer func() { + if params.wg != nil { + params.wg.Done() + } + }() + err := wf.request( ctx, &FilterSubscribeParameters{selectedPeer: peerID, requestID: params.requestID}, @@ -367,27 +455,33 @@ func (wf *WakuFilterLightNode) Unsubscribe(ctx context.Context, contentFilter Co } } - wf.cleanupSubscriptions(peerID, contentFilter) - - resultChan <- WakuFilterPushResult{ - Err: err, - PeerID: peerID, + if params.wg != nil { + resultChan <- WakuFilterPushResult{ + Err: err, + PeerID: peerID, + } } }(peerID) } - localWg.Wait() - close(resultChan) - for _, peerID := range peersUnsubscribed { - if wf.subscriptions != nil && wf.subscriptions.items != nil && wf.subscriptions.items[peerID] != nil && len(wf.subscriptions.items[peerID].subscriptionsPerTopic) == 0 { - delete(wf.subscriptions.items, peerID) - } + if params.wg != nil { + params.wg.Wait() } + + close(resultChan) + return resultChan, nil } // Unsubscribe is used to stop receiving messages from a peer that match a content filter func (wf *WakuFilterLightNode) UnsubscribeWithSubscription(ctx context.Context, sub *SubscriptionDetails, opts ...FilterUnsubscribeOption) (<-chan WakuFilterPushResult, error) { + wf.RLock() + defer wf.RUnlock() + + if !wf.started { + return nil, errNotStarted + } + var contentTopics []string for k := range sub.ContentTopics { contentTopics = append(contentTopics, k) @@ -398,8 +492,7 @@ func (wf *WakuFilterLightNode) UnsubscribeWithSubscription(ctx context.Context, return wf.Unsubscribe(ctx, ContentFilter{Topic: sub.PubsubTopic, ContentTopics: contentTopics}, opts...) } -// UnsubscribeAll is used to stop receiving messages from peer(s). It does not close subscriptions -func (wf *WakuFilterLightNode) UnsubscribeAll(ctx context.Context, opts ...FilterUnsubscribeOption) (<-chan WakuFilterPushResult, error) { +func (wf *WakuFilterLightNode) unsubscribeAll(ctx context.Context, opts ...FilterUnsubscribeOption) (<-chan WakuFilterPushResult, error) { params, err := wf.getUnsubscribeParameters(opts...) if err != nil { return nil, err @@ -408,19 +501,26 @@ func (wf *WakuFilterLightNode) UnsubscribeAll(ctx context.Context, opts ...Filte wf.subscriptions.Lock() defer wf.subscriptions.Unlock() - localWg := sync.WaitGroup{} resultChan := make(chan WakuFilterPushResult, len(wf.subscriptions.items)) - var peersUnsubscribed []peer.ID for peerID := range wf.subscriptions.items { if params.selectedPeer != "" && peerID != params.selectedPeer { continue } - peersUnsubscribed = append(peersUnsubscribed, peerID) - localWg.Add(1) + delete(wf.subscriptions.items, peerID) + + if params.wg != nil { + params.wg.Add(1) + } + go func(peerID peer.ID) { - defer localWg.Done() + defer func() { + if params.wg != nil { + params.wg.Done() + } + }() + err := wf.request( ctx, &FilterSubscribeParameters{selectedPeer: peerID, requestID: params.requestID}, @@ -429,17 +529,32 @@ func (wf *WakuFilterLightNode) UnsubscribeAll(ctx context.Context, opts ...Filte if err != nil { wf.log.Error("could not unsubscribe from peer", logging.HostID("peerID", peerID), zap.Error(err)) } - resultChan <- WakuFilterPushResult{ - Err: err, - PeerID: peerID, + if params.wg != nil { + resultChan <- WakuFilterPushResult{ + Err: err, + PeerID: peerID, + } } }(peerID) } - localWg.Wait() - close(resultChan) - for _, peerID := range peersUnsubscribed { - delete(wf.subscriptions.items, peerID) + if params.wg != nil { + params.wg.Wait() } + + close(resultChan) + return resultChan, nil } + +// UnsubscribeAll is used to stop receiving messages from peer(s). It does not close subscriptions +func (wf *WakuFilterLightNode) UnsubscribeAll(ctx context.Context, opts ...FilterUnsubscribeOption) (<-chan WakuFilterPushResult, error) { + wf.RLock() + defer wf.RUnlock() + + if !wf.started { + return nil, errNotStarted + } + + return wf.unsubscribeAll(ctx, opts...) +} diff --git a/waku/v2/protocol/filter/filter_test.go b/waku/v2/protocol/filter/filter_test.go index 5e676bad..2d685371 100644 --- a/waku/v2/protocol/filter/filter_test.go +++ b/waku/v2/protocol/filter/filter_test.go @@ -3,6 +3,7 @@ package filter import ( "context" "crypto/rand" + "errors" "net/http" "sync" "testing" @@ -67,7 +68,7 @@ func (s *FilterTestSuite) makeWakuRelay(topic string) (*relay.WakuRelay, *relay. return relay, sub, host, broadcaster } -func (s *FilterTestSuite) makeWakuFilterLightNode() *WakuFilterLightNode { +func (s *FilterTestSuite) makeWakuFilterLightNode(start bool) *WakuFilterLightNode { port, err := tests.FindFreePort(s.T(), "", 5) s.Require().NoError(err) @@ -79,8 +80,10 @@ func (s *FilterTestSuite) makeWakuFilterLightNode() *WakuFilterLightNode { filterPush := NewWakuFilterLightNode(b, nil, timesource.NewDefaultClock(), prometheus.DefaultRegisterer, s.log) filterPush.SetHost(host) s.lightNodeHost = host - err = filterPush.Start(context.Background()) - s.Require().NoError(err) + if start { + err = filterPush.Start(context.Background()) + s.Require().NoError(err) + } return filterPush } @@ -178,7 +181,7 @@ func (s *FilterTestSuite) SetupTest() { s.testTopic = "/waku/2/go/filter/test" s.testContentTopic = "TopicA" - s.lightNode = s.makeWakuFilterLightNode() + s.lightNode = s.makeWakuFilterLightNode(true) s.relayNode, s.fullNode = s.makeWakuFilterFullNode(s.testTopic) @@ -333,3 +336,75 @@ func (s *FilterTestSuite) TestMultipleMessages() { }, s.subDetails.C) } + +func (s *FilterTestSuite) TestRunningGuard() { + s.lightNode.Stop() + + contentFilter := ContentFilter{ + Topic: "test", + ContentTopics: []string{"test"}, + } + + _, err := s.lightNode.Subscribe(s.ctx, contentFilter, WithPeer(s.fullNodeHost.ID())) + + s.Require().ErrorIs(err, errNotStarted) + + err = s.lightNode.Start(s.ctx) + s.Require().NoError(err) + + _, err = s.lightNode.Subscribe(s.ctx, contentFilter, WithPeer(s.fullNodeHost.ID())) + + s.Require().NoError(err) +} + +func (s *FilterTestSuite) TestFireAndForgetAndCustomWg() { + contentFilter := ContentFilter{ + Topic: "test", + ContentTopics: []string{"test"}, + } + + _, err := s.lightNode.Subscribe(s.ctx, contentFilter, WithPeer(s.fullNodeHost.ID())) + s.Require().NoError(err) + + ch, err := s.lightNode.Unsubscribe(s.ctx, contentFilter, DontWait()) + _, open := <-ch + s.Require().NoError(err) + s.Require().False(open) + + _, err = s.lightNode.Subscribe(s.ctx, contentFilter, WithPeer(s.fullNodeHost.ID())) + s.Require().NoError(err) + + wg := sync.WaitGroup{} + _, err = s.lightNode.Unsubscribe(s.ctx, contentFilter, WithWaitGroup(&wg)) + wg.Wait() + s.Require().NoError(err) +} + +func (s *FilterTestSuite) TestStartStop() { + var wg sync.WaitGroup + wg.Add(2) + s.lightNode = s.makeWakuFilterLightNode(false) + + stopNode := func() { + for i := 0; i < 100000; i++ { + s.lightNode.Stop() + } + wg.Done() + } + + startNode := func() { + for i := 0; i < 100; i++ { + err := s.lightNode.Start(context.Background()) + if errors.Is(err, errAlreadyStarted) { + continue + } + s.Require().NoError(err) + } + wg.Done() + } + + go startNode() + go stopNode() + + wg.Wait() +} diff --git a/waku/v2/protocol/filter/options.go b/waku/v2/protocol/filter/options.go index 188638b9..249db284 100644 --- a/waku/v2/protocol/filter/options.go +++ b/waku/v2/protocol/filter/options.go @@ -2,6 +2,7 @@ package filter import ( "context" + "sync" "time" "github.com/libp2p/go-libp2p/core/host" @@ -26,6 +27,7 @@ type ( selectedPeer peer.ID requestID []byte log *zap.Logger + wg *sync.WaitGroup } FilterParameters struct { @@ -135,9 +137,26 @@ func AutomaticRequestId() FilterUnsubscribeOption { } } +// WithWaitGroup allos specigying a waitgroup to wait until all +// unsubscribe requests are complete before the function is complete +func WithWaitGroup(wg *sync.WaitGroup) FilterUnsubscribeOption { + return func(params *FilterUnsubscribeParameters) { + params.wg = wg + } +} + +// DontWait is used to fire and forget an unsubscription, and don't +// care about the results of it +func DontWait() FilterUnsubscribeOption { + return func(params *FilterUnsubscribeParameters) { + params.wg = nil + } +} + func DefaultUnsubscribeOptions() []FilterUnsubscribeOption { return []FilterUnsubscribeOption{ AutomaticRequestId(), + WithWaitGroup(&sync.WaitGroup{}), } }