mirror of https://github.com/status-im/go-waku.git
refactor(filter): unsubscribe waitgroup, execute async, and guard against calling functions while the protocol is not started (#692)
* refactor(filter): unsubscribe waitgroup and async * refactor: verify started state for doing filter operations
This commit is contained in:
parent
e8bd38a023
commit
8aa1c4a39b
|
@ -34,6 +34,9 @@ var (
|
|||
)
|
||||
|
||||
type WakuFilterLightNode struct {
|
||||
sync.RWMutex
|
||||
started bool
|
||||
|
||||
cancel context.CancelFunc
|
||||
ctx context.Context
|
||||
h host.Host
|
||||
|
@ -56,6 +59,9 @@ type WakuFilterPushResult struct {
|
|||
PeerID peer.ID
|
||||
}
|
||||
|
||||
var errNotStarted = errors.New("not started")
|
||||
var errAlreadyStarted = errors.New("already started")
|
||||
|
||||
// NewWakuFilterLightnode returns a new instance of Waku Filter struct setup according to the chosen parameter and options
|
||||
// Takes an optional peermanager if WakuFilterLightnode is being created along with WakuNode.
|
||||
// If using libp2p host, then pass peermanager as nil
|
||||
|
@ -78,12 +84,20 @@ func (wf *WakuFilterLightNode) SetHost(h host.Host) {
|
|||
}
|
||||
|
||||
func (wf *WakuFilterLightNode) Start(ctx context.Context) error {
|
||||
wf.Lock()
|
||||
defer wf.Unlock()
|
||||
|
||||
if wf.started {
|
||||
return errAlreadyStarted
|
||||
}
|
||||
|
||||
wf.wg.Wait() // Wait for any goroutines to stop
|
||||
|
||||
ctx, cancel := context.WithCancel(ctx)
|
||||
wf.cancel = cancel
|
||||
wf.ctx = ctx
|
||||
wf.subscriptions = NewSubscriptionMap(wf.log)
|
||||
wf.started = true
|
||||
|
||||
wf.h.SetStreamHandlerMatch(FilterPushID_v20beta1, protocol.PrefixTextMatch(string(FilterPushID_v20beta1)), wf.onRequest(ctx))
|
||||
|
||||
|
@ -94,7 +108,10 @@ func (wf *WakuFilterLightNode) Start(ctx context.Context) error {
|
|||
|
||||
// Stop unmounts the filter protocol
|
||||
func (wf *WakuFilterLightNode) Stop() {
|
||||
if wf.cancel == nil {
|
||||
wf.Lock()
|
||||
defer wf.Unlock()
|
||||
|
||||
if !wf.started {
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -102,10 +119,23 @@ func (wf *WakuFilterLightNode) Stop() {
|
|||
|
||||
wf.h.RemoveStreamHandler(FilterPushID_v20beta1)
|
||||
|
||||
_, _ = wf.UnsubscribeAll(wf.ctx)
|
||||
res, err := wf.unsubscribeAll(wf.ctx)
|
||||
if err != nil {
|
||||
wf.log.Warn("unsubscribing from full nodes", zap.Error(err))
|
||||
}
|
||||
|
||||
for r := range res {
|
||||
if r.Err != nil {
|
||||
wf.log.Warn("unsubscribing from full nodes", zap.Error(r.Err), logging.HostID("peerID", r.PeerID))
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
wf.subscriptions.Clear()
|
||||
|
||||
wf.started = false
|
||||
wf.cancel = nil
|
||||
|
||||
wf.wg.Wait()
|
||||
}
|
||||
|
||||
|
@ -206,6 +236,13 @@ func (wf *WakuFilterLightNode) request(ctx context.Context, params *FilterSubscr
|
|||
|
||||
// Subscribe setups a subscription to receive messages that match a specific content filter
|
||||
func (wf *WakuFilterLightNode) Subscribe(ctx context.Context, contentFilter ContentFilter, opts ...FilterSubscribeOption) (*SubscriptionDetails, error) {
|
||||
wf.RLock()
|
||||
defer wf.RUnlock()
|
||||
|
||||
if !wf.started {
|
||||
return nil, errNotStarted
|
||||
}
|
||||
|
||||
if contentFilter.Topic == "" {
|
||||
return nil, errors.New("topic is required")
|
||||
}
|
||||
|
@ -244,6 +281,13 @@ func (wf *WakuFilterLightNode) Subscribe(ctx context.Context, contentFilter Cont
|
|||
|
||||
// FilterSubscription is used to obtain an object from which you could receive messages received via filter protocol
|
||||
func (wf *WakuFilterLightNode) FilterSubscription(peerID peer.ID, contentFilter ContentFilter) (*SubscriptionDetails, error) {
|
||||
wf.RLock()
|
||||
defer wf.RUnlock()
|
||||
|
||||
if !wf.started {
|
||||
return nil, errNotStarted
|
||||
}
|
||||
|
||||
if !wf.subscriptions.Has(peerID, contentFilter.Topic, contentFilter.ContentTopics...) {
|
||||
return nil, errors.New("subscription does not exist")
|
||||
}
|
||||
|
@ -263,6 +307,13 @@ func (wf *WakuFilterLightNode) getUnsubscribeParameters(opts ...FilterUnsubscrib
|
|||
}
|
||||
|
||||
func (wf *WakuFilterLightNode) Ping(ctx context.Context, peerID peer.ID) error {
|
||||
wf.RLock()
|
||||
defer wf.RUnlock()
|
||||
|
||||
if !wf.started {
|
||||
return errNotStarted
|
||||
}
|
||||
|
||||
return wf.request(
|
||||
ctx,
|
||||
&FilterSubscribeParameters{selectedPeer: peerID},
|
||||
|
@ -271,10 +322,24 @@ func (wf *WakuFilterLightNode) Ping(ctx context.Context, peerID peer.ID) error {
|
|||
}
|
||||
|
||||
func (wf *WakuFilterLightNode) IsSubscriptionAlive(ctx context.Context, subscription *SubscriptionDetails) error {
|
||||
wf.RLock()
|
||||
defer wf.RUnlock()
|
||||
|
||||
if !wf.started {
|
||||
return errNotStarted
|
||||
}
|
||||
|
||||
return wf.Ping(ctx, subscription.PeerID)
|
||||
}
|
||||
|
||||
func (wf *WakuFilterLightNode) Subscriptions() []*SubscriptionDetails {
|
||||
wf.RLock()
|
||||
defer wf.RUnlock()
|
||||
|
||||
if !wf.started {
|
||||
return nil
|
||||
}
|
||||
|
||||
wf.subscriptions.RLock()
|
||||
defer wf.subscriptions.RUnlock()
|
||||
|
||||
|
@ -324,6 +389,13 @@ func (wf *WakuFilterLightNode) cleanupSubscriptions(peerID peer.ID, contentFilte
|
|||
|
||||
// Unsubscribe is used to stop receiving messages from a peer that match a content filter
|
||||
func (wf *WakuFilterLightNode) Unsubscribe(ctx context.Context, contentFilter ContentFilter, opts ...FilterUnsubscribeOption) (<-chan WakuFilterPushResult, error) {
|
||||
wf.RLock()
|
||||
defer wf.RUnlock()
|
||||
|
||||
if !wf.started {
|
||||
return nil, errNotStarted
|
||||
}
|
||||
|
||||
if contentFilter.Topic == "" {
|
||||
return nil, errors.New("topic is required")
|
||||
}
|
||||
|
@ -341,17 +413,33 @@ func (wf *WakuFilterLightNode) Unsubscribe(ctx context.Context, contentFilter Co
|
|||
return nil, err
|
||||
}
|
||||
|
||||
localWg := sync.WaitGroup{}
|
||||
resultChan := make(chan WakuFilterPushResult, len(wf.subscriptions.items))
|
||||
var peersUnsubscribed []peer.ID
|
||||
for peerID := range wf.subscriptions.items {
|
||||
if params.selectedPeer != "" && peerID != params.selectedPeer {
|
||||
continue
|
||||
}
|
||||
peersUnsubscribed = append(peersUnsubscribed, peerID)
|
||||
localWg.Add(1)
|
||||
|
||||
subscriptions, ok := wf.subscriptions.items[peerID]
|
||||
if !ok || subscriptions == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
wf.cleanupSubscriptions(peerID, contentFilter)
|
||||
if len(subscriptions.subscriptionsPerTopic) == 0 {
|
||||
delete(wf.subscriptions.items, peerID)
|
||||
}
|
||||
|
||||
if params.wg != nil {
|
||||
params.wg.Add(1)
|
||||
}
|
||||
|
||||
go func(peerID peer.ID) {
|
||||
defer localWg.Done()
|
||||
defer func() {
|
||||
if params.wg != nil {
|
||||
params.wg.Done()
|
||||
}
|
||||
}()
|
||||
|
||||
err := wf.request(
|
||||
ctx,
|
||||
&FilterSubscribeParameters{selectedPeer: peerID, requestID: params.requestID},
|
||||
|
@ -367,27 +455,33 @@ func (wf *WakuFilterLightNode) Unsubscribe(ctx context.Context, contentFilter Co
|
|||
}
|
||||
}
|
||||
|
||||
wf.cleanupSubscriptions(peerID, contentFilter)
|
||||
|
||||
if params.wg != nil {
|
||||
resultChan <- WakuFilterPushResult{
|
||||
Err: err,
|
||||
PeerID: peerID,
|
||||
}
|
||||
}
|
||||
}(peerID)
|
||||
}
|
||||
|
||||
localWg.Wait()
|
||||
if params.wg != nil {
|
||||
params.wg.Wait()
|
||||
}
|
||||
|
||||
close(resultChan)
|
||||
for _, peerID := range peersUnsubscribed {
|
||||
if wf.subscriptions != nil && wf.subscriptions.items != nil && wf.subscriptions.items[peerID] != nil && len(wf.subscriptions.items[peerID].subscriptionsPerTopic) == 0 {
|
||||
delete(wf.subscriptions.items, peerID)
|
||||
}
|
||||
}
|
||||
|
||||
return resultChan, nil
|
||||
}
|
||||
|
||||
// Unsubscribe is used to stop receiving messages from a peer that match a content filter
|
||||
func (wf *WakuFilterLightNode) UnsubscribeWithSubscription(ctx context.Context, sub *SubscriptionDetails, opts ...FilterUnsubscribeOption) (<-chan WakuFilterPushResult, error) {
|
||||
wf.RLock()
|
||||
defer wf.RUnlock()
|
||||
|
||||
if !wf.started {
|
||||
return nil, errNotStarted
|
||||
}
|
||||
|
||||
var contentTopics []string
|
||||
for k := range sub.ContentTopics {
|
||||
contentTopics = append(contentTopics, k)
|
||||
|
@ -398,8 +492,7 @@ func (wf *WakuFilterLightNode) UnsubscribeWithSubscription(ctx context.Context,
|
|||
return wf.Unsubscribe(ctx, ContentFilter{Topic: sub.PubsubTopic, ContentTopics: contentTopics}, opts...)
|
||||
}
|
||||
|
||||
// UnsubscribeAll is used to stop receiving messages from peer(s). It does not close subscriptions
|
||||
func (wf *WakuFilterLightNode) UnsubscribeAll(ctx context.Context, opts ...FilterUnsubscribeOption) (<-chan WakuFilterPushResult, error) {
|
||||
func (wf *WakuFilterLightNode) unsubscribeAll(ctx context.Context, opts ...FilterUnsubscribeOption) (<-chan WakuFilterPushResult, error) {
|
||||
params, err := wf.getUnsubscribeParameters(opts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
@ -408,19 +501,26 @@ func (wf *WakuFilterLightNode) UnsubscribeAll(ctx context.Context, opts ...Filte
|
|||
wf.subscriptions.Lock()
|
||||
defer wf.subscriptions.Unlock()
|
||||
|
||||
localWg := sync.WaitGroup{}
|
||||
resultChan := make(chan WakuFilterPushResult, len(wf.subscriptions.items))
|
||||
var peersUnsubscribed []peer.ID
|
||||
|
||||
for peerID := range wf.subscriptions.items {
|
||||
if params.selectedPeer != "" && peerID != params.selectedPeer {
|
||||
continue
|
||||
}
|
||||
peersUnsubscribed = append(peersUnsubscribed, peerID)
|
||||
|
||||
localWg.Add(1)
|
||||
delete(wf.subscriptions.items, peerID)
|
||||
|
||||
if params.wg != nil {
|
||||
params.wg.Add(1)
|
||||
}
|
||||
|
||||
go func(peerID peer.ID) {
|
||||
defer localWg.Done()
|
||||
defer func() {
|
||||
if params.wg != nil {
|
||||
params.wg.Done()
|
||||
}
|
||||
}()
|
||||
|
||||
err := wf.request(
|
||||
ctx,
|
||||
&FilterSubscribeParameters{selectedPeer: peerID, requestID: params.requestID},
|
||||
|
@ -429,17 +529,32 @@ func (wf *WakuFilterLightNode) UnsubscribeAll(ctx context.Context, opts ...Filte
|
|||
if err != nil {
|
||||
wf.log.Error("could not unsubscribe from peer", logging.HostID("peerID", peerID), zap.Error(err))
|
||||
}
|
||||
if params.wg != nil {
|
||||
resultChan <- WakuFilterPushResult{
|
||||
Err: err,
|
||||
PeerID: peerID,
|
||||
}
|
||||
}
|
||||
}(peerID)
|
||||
}
|
||||
|
||||
localWg.Wait()
|
||||
close(resultChan)
|
||||
for _, peerID := range peersUnsubscribed {
|
||||
delete(wf.subscriptions.items, peerID)
|
||||
if params.wg != nil {
|
||||
params.wg.Wait()
|
||||
}
|
||||
|
||||
close(resultChan)
|
||||
|
||||
return resultChan, nil
|
||||
}
|
||||
|
||||
// UnsubscribeAll is used to stop receiving messages from peer(s). It does not close subscriptions
|
||||
func (wf *WakuFilterLightNode) UnsubscribeAll(ctx context.Context, opts ...FilterUnsubscribeOption) (<-chan WakuFilterPushResult, error) {
|
||||
wf.RLock()
|
||||
defer wf.RUnlock()
|
||||
|
||||
if !wf.started {
|
||||
return nil, errNotStarted
|
||||
}
|
||||
|
||||
return wf.unsubscribeAll(ctx, opts...)
|
||||
}
|
||||
|
|
|
@ -3,6 +3,7 @@ package filter
|
|||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"errors"
|
||||
"net/http"
|
||||
"sync"
|
||||
"testing"
|
||||
|
@ -67,7 +68,7 @@ func (s *FilterTestSuite) makeWakuRelay(topic string) (*relay.WakuRelay, *relay.
|
|||
return relay, sub, host, broadcaster
|
||||
}
|
||||
|
||||
func (s *FilterTestSuite) makeWakuFilterLightNode() *WakuFilterLightNode {
|
||||
func (s *FilterTestSuite) makeWakuFilterLightNode(start bool) *WakuFilterLightNode {
|
||||
port, err := tests.FindFreePort(s.T(), "", 5)
|
||||
s.Require().NoError(err)
|
||||
|
||||
|
@ -79,8 +80,10 @@ func (s *FilterTestSuite) makeWakuFilterLightNode() *WakuFilterLightNode {
|
|||
filterPush := NewWakuFilterLightNode(b, nil, timesource.NewDefaultClock(), prometheus.DefaultRegisterer, s.log)
|
||||
filterPush.SetHost(host)
|
||||
s.lightNodeHost = host
|
||||
if start {
|
||||
err = filterPush.Start(context.Background())
|
||||
s.Require().NoError(err)
|
||||
}
|
||||
|
||||
return filterPush
|
||||
}
|
||||
|
@ -178,7 +181,7 @@ func (s *FilterTestSuite) SetupTest() {
|
|||
s.testTopic = "/waku/2/go/filter/test"
|
||||
s.testContentTopic = "TopicA"
|
||||
|
||||
s.lightNode = s.makeWakuFilterLightNode()
|
||||
s.lightNode = s.makeWakuFilterLightNode(true)
|
||||
|
||||
s.relayNode, s.fullNode = s.makeWakuFilterFullNode(s.testTopic)
|
||||
|
||||
|
@ -333,3 +336,75 @@ func (s *FilterTestSuite) TestMultipleMessages() {
|
|||
|
||||
}, s.subDetails.C)
|
||||
}
|
||||
|
||||
func (s *FilterTestSuite) TestRunningGuard() {
|
||||
s.lightNode.Stop()
|
||||
|
||||
contentFilter := ContentFilter{
|
||||
Topic: "test",
|
||||
ContentTopics: []string{"test"},
|
||||
}
|
||||
|
||||
_, err := s.lightNode.Subscribe(s.ctx, contentFilter, WithPeer(s.fullNodeHost.ID()))
|
||||
|
||||
s.Require().ErrorIs(err, errNotStarted)
|
||||
|
||||
err = s.lightNode.Start(s.ctx)
|
||||
s.Require().NoError(err)
|
||||
|
||||
_, err = s.lightNode.Subscribe(s.ctx, contentFilter, WithPeer(s.fullNodeHost.ID()))
|
||||
|
||||
s.Require().NoError(err)
|
||||
}
|
||||
|
||||
func (s *FilterTestSuite) TestFireAndForgetAndCustomWg() {
|
||||
contentFilter := ContentFilter{
|
||||
Topic: "test",
|
||||
ContentTopics: []string{"test"},
|
||||
}
|
||||
|
||||
_, err := s.lightNode.Subscribe(s.ctx, contentFilter, WithPeer(s.fullNodeHost.ID()))
|
||||
s.Require().NoError(err)
|
||||
|
||||
ch, err := s.lightNode.Unsubscribe(s.ctx, contentFilter, DontWait())
|
||||
_, open := <-ch
|
||||
s.Require().NoError(err)
|
||||
s.Require().False(open)
|
||||
|
||||
_, err = s.lightNode.Subscribe(s.ctx, contentFilter, WithPeer(s.fullNodeHost.ID()))
|
||||
s.Require().NoError(err)
|
||||
|
||||
wg := sync.WaitGroup{}
|
||||
_, err = s.lightNode.Unsubscribe(s.ctx, contentFilter, WithWaitGroup(&wg))
|
||||
wg.Wait()
|
||||
s.Require().NoError(err)
|
||||
}
|
||||
|
||||
func (s *FilterTestSuite) TestStartStop() {
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(2)
|
||||
s.lightNode = s.makeWakuFilterLightNode(false)
|
||||
|
||||
stopNode := func() {
|
||||
for i := 0; i < 100000; i++ {
|
||||
s.lightNode.Stop()
|
||||
}
|
||||
wg.Done()
|
||||
}
|
||||
|
||||
startNode := func() {
|
||||
for i := 0; i < 100; i++ {
|
||||
err := s.lightNode.Start(context.Background())
|
||||
if errors.Is(err, errAlreadyStarted) {
|
||||
continue
|
||||
}
|
||||
s.Require().NoError(err)
|
||||
}
|
||||
wg.Done()
|
||||
}
|
||||
|
||||
go startNode()
|
||||
go stopNode()
|
||||
|
||||
wg.Wait()
|
||||
}
|
||||
|
|
|
@ -2,6 +2,7 @@ package filter
|
|||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/libp2p/go-libp2p/core/host"
|
||||
|
@ -26,6 +27,7 @@ type (
|
|||
selectedPeer peer.ID
|
||||
requestID []byte
|
||||
log *zap.Logger
|
||||
wg *sync.WaitGroup
|
||||
}
|
||||
|
||||
FilterParameters struct {
|
||||
|
@ -135,9 +137,26 @@ func AutomaticRequestId() FilterUnsubscribeOption {
|
|||
}
|
||||
}
|
||||
|
||||
// WithWaitGroup allos specigying a waitgroup to wait until all
|
||||
// unsubscribe requests are complete before the function is complete
|
||||
func WithWaitGroup(wg *sync.WaitGroup) FilterUnsubscribeOption {
|
||||
return func(params *FilterUnsubscribeParameters) {
|
||||
params.wg = wg
|
||||
}
|
||||
}
|
||||
|
||||
// DontWait is used to fire and forget an unsubscription, and don't
|
||||
// care about the results of it
|
||||
func DontWait() FilterUnsubscribeOption {
|
||||
return func(params *FilterUnsubscribeParameters) {
|
||||
params.wg = nil
|
||||
}
|
||||
}
|
||||
|
||||
func DefaultUnsubscribeOptions() []FilterUnsubscribeOption {
|
||||
return []FilterUnsubscribeOption{
|
||||
AutomaticRequestId(),
|
||||
WithWaitGroup(&sync.WaitGroup{}),
|
||||
}
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue