refactor(filter): unsubscribe waitgroup, execute async, and guard against calling functions while the protocol is not started (#692)

* refactor(filter): unsubscribe waitgroup and async
* refactor: verify started state for doing filter operations
This commit is contained in:
richΛrd 2023-09-04 09:53:51 -04:00 committed by GitHub
parent e8bd38a023
commit 8aa1c4a39b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 245 additions and 36 deletions

View File

@ -34,6 +34,9 @@ var (
) )
type WakuFilterLightNode struct { type WakuFilterLightNode struct {
sync.RWMutex
started bool
cancel context.CancelFunc cancel context.CancelFunc
ctx context.Context ctx context.Context
h host.Host h host.Host
@ -56,6 +59,9 @@ type WakuFilterPushResult struct {
PeerID peer.ID PeerID peer.ID
} }
var errNotStarted = errors.New("not started")
var errAlreadyStarted = errors.New("already started")
// NewWakuFilterLightnode returns a new instance of Waku Filter struct setup according to the chosen parameter and options // NewWakuFilterLightnode returns a new instance of Waku Filter struct setup according to the chosen parameter and options
// Takes an optional peermanager if WakuFilterLightnode is being created along with WakuNode. // Takes an optional peermanager if WakuFilterLightnode is being created along with WakuNode.
// If using libp2p host, then pass peermanager as nil // If using libp2p host, then pass peermanager as nil
@ -78,12 +84,20 @@ func (wf *WakuFilterLightNode) SetHost(h host.Host) {
} }
func (wf *WakuFilterLightNode) Start(ctx context.Context) error { func (wf *WakuFilterLightNode) Start(ctx context.Context) error {
wf.Lock()
defer wf.Unlock()
if wf.started {
return errAlreadyStarted
}
wf.wg.Wait() // Wait for any goroutines to stop wf.wg.Wait() // Wait for any goroutines to stop
ctx, cancel := context.WithCancel(ctx) ctx, cancel := context.WithCancel(ctx)
wf.cancel = cancel wf.cancel = cancel
wf.ctx = ctx wf.ctx = ctx
wf.subscriptions = NewSubscriptionMap(wf.log) wf.subscriptions = NewSubscriptionMap(wf.log)
wf.started = true
wf.h.SetStreamHandlerMatch(FilterPushID_v20beta1, protocol.PrefixTextMatch(string(FilterPushID_v20beta1)), wf.onRequest(ctx)) wf.h.SetStreamHandlerMatch(FilterPushID_v20beta1, protocol.PrefixTextMatch(string(FilterPushID_v20beta1)), wf.onRequest(ctx))
@ -94,7 +108,10 @@ func (wf *WakuFilterLightNode) Start(ctx context.Context) error {
// Stop unmounts the filter protocol // Stop unmounts the filter protocol
func (wf *WakuFilterLightNode) Stop() { func (wf *WakuFilterLightNode) Stop() {
if wf.cancel == nil { wf.Lock()
defer wf.Unlock()
if !wf.started {
return return
} }
@ -102,10 +119,23 @@ func (wf *WakuFilterLightNode) Stop() {
wf.h.RemoveStreamHandler(FilterPushID_v20beta1) wf.h.RemoveStreamHandler(FilterPushID_v20beta1)
_, _ = wf.UnsubscribeAll(wf.ctx) res, err := wf.unsubscribeAll(wf.ctx)
if err != nil {
wf.log.Warn("unsubscribing from full nodes", zap.Error(err))
}
for r := range res {
if r.Err != nil {
wf.log.Warn("unsubscribing from full nodes", zap.Error(r.Err), logging.HostID("peerID", r.PeerID))
}
}
wf.subscriptions.Clear() wf.subscriptions.Clear()
wf.started = false
wf.cancel = nil
wf.wg.Wait() wf.wg.Wait()
} }
@ -206,6 +236,13 @@ func (wf *WakuFilterLightNode) request(ctx context.Context, params *FilterSubscr
// Subscribe setups a subscription to receive messages that match a specific content filter // Subscribe setups a subscription to receive messages that match a specific content filter
func (wf *WakuFilterLightNode) Subscribe(ctx context.Context, contentFilter ContentFilter, opts ...FilterSubscribeOption) (*SubscriptionDetails, error) { func (wf *WakuFilterLightNode) Subscribe(ctx context.Context, contentFilter ContentFilter, opts ...FilterSubscribeOption) (*SubscriptionDetails, error) {
wf.RLock()
defer wf.RUnlock()
if !wf.started {
return nil, errNotStarted
}
if contentFilter.Topic == "" { if contentFilter.Topic == "" {
return nil, errors.New("topic is required") return nil, errors.New("topic is required")
} }
@ -244,6 +281,13 @@ func (wf *WakuFilterLightNode) Subscribe(ctx context.Context, contentFilter Cont
// FilterSubscription is used to obtain an object from which you could receive messages received via filter protocol // FilterSubscription is used to obtain an object from which you could receive messages received via filter protocol
func (wf *WakuFilterLightNode) FilterSubscription(peerID peer.ID, contentFilter ContentFilter) (*SubscriptionDetails, error) { func (wf *WakuFilterLightNode) FilterSubscription(peerID peer.ID, contentFilter ContentFilter) (*SubscriptionDetails, error) {
wf.RLock()
defer wf.RUnlock()
if !wf.started {
return nil, errNotStarted
}
if !wf.subscriptions.Has(peerID, contentFilter.Topic, contentFilter.ContentTopics...) { if !wf.subscriptions.Has(peerID, contentFilter.Topic, contentFilter.ContentTopics...) {
return nil, errors.New("subscription does not exist") return nil, errors.New("subscription does not exist")
} }
@ -263,6 +307,13 @@ func (wf *WakuFilterLightNode) getUnsubscribeParameters(opts ...FilterUnsubscrib
} }
func (wf *WakuFilterLightNode) Ping(ctx context.Context, peerID peer.ID) error { func (wf *WakuFilterLightNode) Ping(ctx context.Context, peerID peer.ID) error {
wf.RLock()
defer wf.RUnlock()
if !wf.started {
return errNotStarted
}
return wf.request( return wf.request(
ctx, ctx,
&FilterSubscribeParameters{selectedPeer: peerID}, &FilterSubscribeParameters{selectedPeer: peerID},
@ -271,10 +322,24 @@ func (wf *WakuFilterLightNode) Ping(ctx context.Context, peerID peer.ID) error {
} }
func (wf *WakuFilterLightNode) IsSubscriptionAlive(ctx context.Context, subscription *SubscriptionDetails) error { func (wf *WakuFilterLightNode) IsSubscriptionAlive(ctx context.Context, subscription *SubscriptionDetails) error {
wf.RLock()
defer wf.RUnlock()
if !wf.started {
return errNotStarted
}
return wf.Ping(ctx, subscription.PeerID) return wf.Ping(ctx, subscription.PeerID)
} }
func (wf *WakuFilterLightNode) Subscriptions() []*SubscriptionDetails { func (wf *WakuFilterLightNode) Subscriptions() []*SubscriptionDetails {
wf.RLock()
defer wf.RUnlock()
if !wf.started {
return nil
}
wf.subscriptions.RLock() wf.subscriptions.RLock()
defer wf.subscriptions.RUnlock() defer wf.subscriptions.RUnlock()
@ -324,6 +389,13 @@ func (wf *WakuFilterLightNode) cleanupSubscriptions(peerID peer.ID, contentFilte
// Unsubscribe is used to stop receiving messages from a peer that match a content filter // Unsubscribe is used to stop receiving messages from a peer that match a content filter
func (wf *WakuFilterLightNode) Unsubscribe(ctx context.Context, contentFilter ContentFilter, opts ...FilterUnsubscribeOption) (<-chan WakuFilterPushResult, error) { func (wf *WakuFilterLightNode) Unsubscribe(ctx context.Context, contentFilter ContentFilter, opts ...FilterUnsubscribeOption) (<-chan WakuFilterPushResult, error) {
wf.RLock()
defer wf.RUnlock()
if !wf.started {
return nil, errNotStarted
}
if contentFilter.Topic == "" { if contentFilter.Topic == "" {
return nil, errors.New("topic is required") return nil, errors.New("topic is required")
} }
@ -341,17 +413,33 @@ func (wf *WakuFilterLightNode) Unsubscribe(ctx context.Context, contentFilter Co
return nil, err return nil, err
} }
localWg := sync.WaitGroup{}
resultChan := make(chan WakuFilterPushResult, len(wf.subscriptions.items)) resultChan := make(chan WakuFilterPushResult, len(wf.subscriptions.items))
var peersUnsubscribed []peer.ID
for peerID := range wf.subscriptions.items { for peerID := range wf.subscriptions.items {
if params.selectedPeer != "" && peerID != params.selectedPeer { if params.selectedPeer != "" && peerID != params.selectedPeer {
continue continue
} }
peersUnsubscribed = append(peersUnsubscribed, peerID)
localWg.Add(1) subscriptions, ok := wf.subscriptions.items[peerID]
if !ok || subscriptions == nil {
continue
}
wf.cleanupSubscriptions(peerID, contentFilter)
if len(subscriptions.subscriptionsPerTopic) == 0 {
delete(wf.subscriptions.items, peerID)
}
if params.wg != nil {
params.wg.Add(1)
}
go func(peerID peer.ID) { go func(peerID peer.ID) {
defer localWg.Done() defer func() {
if params.wg != nil {
params.wg.Done()
}
}()
err := wf.request( err := wf.request(
ctx, ctx,
&FilterSubscribeParameters{selectedPeer: peerID, requestID: params.requestID}, &FilterSubscribeParameters{selectedPeer: peerID, requestID: params.requestID},
@ -367,27 +455,33 @@ func (wf *WakuFilterLightNode) Unsubscribe(ctx context.Context, contentFilter Co
} }
} }
wf.cleanupSubscriptions(peerID, contentFilter) if params.wg != nil {
resultChan <- WakuFilterPushResult{ resultChan <- WakuFilterPushResult{
Err: err, Err: err,
PeerID: peerID, PeerID: peerID,
} }
}
}(peerID) }(peerID)
} }
localWg.Wait() if params.wg != nil {
params.wg.Wait()
}
close(resultChan) close(resultChan)
for _, peerID := range peersUnsubscribed {
if wf.subscriptions != nil && wf.subscriptions.items != nil && wf.subscriptions.items[peerID] != nil && len(wf.subscriptions.items[peerID].subscriptionsPerTopic) == 0 {
delete(wf.subscriptions.items, peerID)
}
}
return resultChan, nil return resultChan, nil
} }
// Unsubscribe is used to stop receiving messages from a peer that match a content filter // Unsubscribe is used to stop receiving messages from a peer that match a content filter
func (wf *WakuFilterLightNode) UnsubscribeWithSubscription(ctx context.Context, sub *SubscriptionDetails, opts ...FilterUnsubscribeOption) (<-chan WakuFilterPushResult, error) { func (wf *WakuFilterLightNode) UnsubscribeWithSubscription(ctx context.Context, sub *SubscriptionDetails, opts ...FilterUnsubscribeOption) (<-chan WakuFilterPushResult, error) {
wf.RLock()
defer wf.RUnlock()
if !wf.started {
return nil, errNotStarted
}
var contentTopics []string var contentTopics []string
for k := range sub.ContentTopics { for k := range sub.ContentTopics {
contentTopics = append(contentTopics, k) contentTopics = append(contentTopics, k)
@ -398,8 +492,7 @@ func (wf *WakuFilterLightNode) UnsubscribeWithSubscription(ctx context.Context,
return wf.Unsubscribe(ctx, ContentFilter{Topic: sub.PubsubTopic, ContentTopics: contentTopics}, opts...) return wf.Unsubscribe(ctx, ContentFilter{Topic: sub.PubsubTopic, ContentTopics: contentTopics}, opts...)
} }
// UnsubscribeAll is used to stop receiving messages from peer(s). It does not close subscriptions func (wf *WakuFilterLightNode) unsubscribeAll(ctx context.Context, opts ...FilterUnsubscribeOption) (<-chan WakuFilterPushResult, error) {
func (wf *WakuFilterLightNode) UnsubscribeAll(ctx context.Context, opts ...FilterUnsubscribeOption) (<-chan WakuFilterPushResult, error) {
params, err := wf.getUnsubscribeParameters(opts...) params, err := wf.getUnsubscribeParameters(opts...)
if err != nil { if err != nil {
return nil, err return nil, err
@ -408,19 +501,26 @@ func (wf *WakuFilterLightNode) UnsubscribeAll(ctx context.Context, opts ...Filte
wf.subscriptions.Lock() wf.subscriptions.Lock()
defer wf.subscriptions.Unlock() defer wf.subscriptions.Unlock()
localWg := sync.WaitGroup{}
resultChan := make(chan WakuFilterPushResult, len(wf.subscriptions.items)) resultChan := make(chan WakuFilterPushResult, len(wf.subscriptions.items))
var peersUnsubscribed []peer.ID
for peerID := range wf.subscriptions.items { for peerID := range wf.subscriptions.items {
if params.selectedPeer != "" && peerID != params.selectedPeer { if params.selectedPeer != "" && peerID != params.selectedPeer {
continue continue
} }
peersUnsubscribed = append(peersUnsubscribed, peerID)
localWg.Add(1) delete(wf.subscriptions.items, peerID)
if params.wg != nil {
params.wg.Add(1)
}
go func(peerID peer.ID) { go func(peerID peer.ID) {
defer localWg.Done() defer func() {
if params.wg != nil {
params.wg.Done()
}
}()
err := wf.request( err := wf.request(
ctx, ctx,
&FilterSubscribeParameters{selectedPeer: peerID, requestID: params.requestID}, &FilterSubscribeParameters{selectedPeer: peerID, requestID: params.requestID},
@ -429,17 +529,32 @@ func (wf *WakuFilterLightNode) UnsubscribeAll(ctx context.Context, opts ...Filte
if err != nil { if err != nil {
wf.log.Error("could not unsubscribe from peer", logging.HostID("peerID", peerID), zap.Error(err)) wf.log.Error("could not unsubscribe from peer", logging.HostID("peerID", peerID), zap.Error(err))
} }
if params.wg != nil {
resultChan <- WakuFilterPushResult{ resultChan <- WakuFilterPushResult{
Err: err, Err: err,
PeerID: peerID, PeerID: peerID,
} }
}
}(peerID) }(peerID)
} }
localWg.Wait() if params.wg != nil {
close(resultChan) params.wg.Wait()
for _, peerID := range peersUnsubscribed {
delete(wf.subscriptions.items, peerID)
} }
close(resultChan)
return resultChan, nil return resultChan, nil
} }
// UnsubscribeAll is used to stop receiving messages from peer(s). It does not close subscriptions
func (wf *WakuFilterLightNode) UnsubscribeAll(ctx context.Context, opts ...FilterUnsubscribeOption) (<-chan WakuFilterPushResult, error) {
wf.RLock()
defer wf.RUnlock()
if !wf.started {
return nil, errNotStarted
}
return wf.unsubscribeAll(ctx, opts...)
}

View File

@ -3,6 +3,7 @@ package filter
import ( import (
"context" "context"
"crypto/rand" "crypto/rand"
"errors"
"net/http" "net/http"
"sync" "sync"
"testing" "testing"
@ -67,7 +68,7 @@ func (s *FilterTestSuite) makeWakuRelay(topic string) (*relay.WakuRelay, *relay.
return relay, sub, host, broadcaster return relay, sub, host, broadcaster
} }
func (s *FilterTestSuite) makeWakuFilterLightNode() *WakuFilterLightNode { func (s *FilterTestSuite) makeWakuFilterLightNode(start bool) *WakuFilterLightNode {
port, err := tests.FindFreePort(s.T(), "", 5) port, err := tests.FindFreePort(s.T(), "", 5)
s.Require().NoError(err) s.Require().NoError(err)
@ -79,8 +80,10 @@ func (s *FilterTestSuite) makeWakuFilterLightNode() *WakuFilterLightNode {
filterPush := NewWakuFilterLightNode(b, nil, timesource.NewDefaultClock(), prometheus.DefaultRegisterer, s.log) filterPush := NewWakuFilterLightNode(b, nil, timesource.NewDefaultClock(), prometheus.DefaultRegisterer, s.log)
filterPush.SetHost(host) filterPush.SetHost(host)
s.lightNodeHost = host s.lightNodeHost = host
if start {
err = filterPush.Start(context.Background()) err = filterPush.Start(context.Background())
s.Require().NoError(err) s.Require().NoError(err)
}
return filterPush return filterPush
} }
@ -178,7 +181,7 @@ func (s *FilterTestSuite) SetupTest() {
s.testTopic = "/waku/2/go/filter/test" s.testTopic = "/waku/2/go/filter/test"
s.testContentTopic = "TopicA" s.testContentTopic = "TopicA"
s.lightNode = s.makeWakuFilterLightNode() s.lightNode = s.makeWakuFilterLightNode(true)
s.relayNode, s.fullNode = s.makeWakuFilterFullNode(s.testTopic) s.relayNode, s.fullNode = s.makeWakuFilterFullNode(s.testTopic)
@ -333,3 +336,75 @@ func (s *FilterTestSuite) TestMultipleMessages() {
}, s.subDetails.C) }, s.subDetails.C)
} }
func (s *FilterTestSuite) TestRunningGuard() {
s.lightNode.Stop()
contentFilter := ContentFilter{
Topic: "test",
ContentTopics: []string{"test"},
}
_, err := s.lightNode.Subscribe(s.ctx, contentFilter, WithPeer(s.fullNodeHost.ID()))
s.Require().ErrorIs(err, errNotStarted)
err = s.lightNode.Start(s.ctx)
s.Require().NoError(err)
_, err = s.lightNode.Subscribe(s.ctx, contentFilter, WithPeer(s.fullNodeHost.ID()))
s.Require().NoError(err)
}
func (s *FilterTestSuite) TestFireAndForgetAndCustomWg() {
contentFilter := ContentFilter{
Topic: "test",
ContentTopics: []string{"test"},
}
_, err := s.lightNode.Subscribe(s.ctx, contentFilter, WithPeer(s.fullNodeHost.ID()))
s.Require().NoError(err)
ch, err := s.lightNode.Unsubscribe(s.ctx, contentFilter, DontWait())
_, open := <-ch
s.Require().NoError(err)
s.Require().False(open)
_, err = s.lightNode.Subscribe(s.ctx, contentFilter, WithPeer(s.fullNodeHost.ID()))
s.Require().NoError(err)
wg := sync.WaitGroup{}
_, err = s.lightNode.Unsubscribe(s.ctx, contentFilter, WithWaitGroup(&wg))
wg.Wait()
s.Require().NoError(err)
}
func (s *FilterTestSuite) TestStartStop() {
var wg sync.WaitGroup
wg.Add(2)
s.lightNode = s.makeWakuFilterLightNode(false)
stopNode := func() {
for i := 0; i < 100000; i++ {
s.lightNode.Stop()
}
wg.Done()
}
startNode := func() {
for i := 0; i < 100; i++ {
err := s.lightNode.Start(context.Background())
if errors.Is(err, errAlreadyStarted) {
continue
}
s.Require().NoError(err)
}
wg.Done()
}
go startNode()
go stopNode()
wg.Wait()
}

View File

@ -2,6 +2,7 @@ package filter
import ( import (
"context" "context"
"sync"
"time" "time"
"github.com/libp2p/go-libp2p/core/host" "github.com/libp2p/go-libp2p/core/host"
@ -26,6 +27,7 @@ type (
selectedPeer peer.ID selectedPeer peer.ID
requestID []byte requestID []byte
log *zap.Logger log *zap.Logger
wg *sync.WaitGroup
} }
FilterParameters struct { FilterParameters struct {
@ -135,9 +137,26 @@ func AutomaticRequestId() FilterUnsubscribeOption {
} }
} }
// WithWaitGroup allos specigying a waitgroup to wait until all
// unsubscribe requests are complete before the function is complete
func WithWaitGroup(wg *sync.WaitGroup) FilterUnsubscribeOption {
return func(params *FilterUnsubscribeParameters) {
params.wg = wg
}
}
// DontWait is used to fire and forget an unsubscription, and don't
// care about the results of it
func DontWait() FilterUnsubscribeOption {
return func(params *FilterUnsubscribeParameters) {
params.wg = nil
}
}
func DefaultUnsubscribeOptions() []FilterUnsubscribeOption { func DefaultUnsubscribeOptions() []FilterUnsubscribeOption {
return []FilterUnsubscribeOption{ return []FilterUnsubscribeOption{
AutomaticRequestId(), AutomaticRequestId(),
WithWaitGroup(&sync.WaitGroup{}),
} }
} }