From 4b1c188cf02ca65493f707946b8b6f131cc2f49d Mon Sep 17 00:00:00 2001 From: harsh jain Date: Wed, 13 Sep 2023 12:18:44 +0700 Subject: [PATCH] feat: add common protocol design (#724) * feat: add common protocol design * fix: remove redundant vars * fix: use AppDesign's ctx * refactor: relay, add AppDesign * feat: changes for suggestions * test: commonService start/stop execution * fix: lint error * nit: add comments --- waku/v2/protocol/common_service.go | 73 +++++++++++ waku/v2/protocol/common_service_test.go | 28 +++++ waku/v2/protocol/filter/client.go | 115 ++++++------------ waku/v2/protocol/filter/filter_test.go | 4 +- waku/v2/protocol/filter/server.go | 44 +++---- waku/v2/protocol/legacy_filter/waku_filter.go | 52 +++----- waku/v2/protocol/peer_exchange/client.go | 4 +- waku/v2/protocol/peer_exchange/protocol.go | 33 ++--- waku/v2/protocol/relay/waku_relay.go | 37 +++--- waku/v2/rendezvous/rendezvous.go | 37 ++---- 10 files changed, 218 insertions(+), 209 deletions(-) create mode 100644 waku/v2/protocol/common_service.go create mode 100644 waku/v2/protocol/common_service_test.go diff --git a/waku/v2/protocol/common_service.go b/waku/v2/protocol/common_service.go new file mode 100644 index 00000000..65746961 --- /dev/null +++ b/waku/v2/protocol/common_service.go @@ -0,0 +1,73 @@ +package protocol + +import ( + "context" + "errors" + "sync" +) + +// this is common layout for all the services that require mutex protection and a guarantee that all running goroutines will be finished before stop finishes execution. This guarantee comes from waitGroup all one has to use CommonService.WaitGroup() in the goroutines that should finish by the end of stop function. +type CommonService struct { + sync.RWMutex + cancel context.CancelFunc + ctx context.Context + wg sync.WaitGroup + started bool +} + +func NewCommonService() *CommonService { + return &CommonService{ + wg: sync.WaitGroup{}, + RWMutex: sync.RWMutex{}, + } +} + +// mutex protected start function +// creates internal context over provided context and runs fn safely +// fn is excerpt to be executed to start the protocol +func (sp *CommonService) Start(ctx context.Context, fn func() error) error { + sp.Lock() + defer sp.Unlock() + if sp.started { + return ErrAlreadyStarted + } + sp.started = true + sp.ctx, sp.cancel = context.WithCancel(ctx) + if err := fn(); err != nil { + sp.started = false + sp.cancel() + return err + } + return nil +} + +var ErrAlreadyStarted = errors.New("already started") +var ErrNotStarted = errors.New("not started") + +// mutex protected stop function +func (sp *CommonService) Stop(fn func()) { + sp.Lock() + defer sp.Unlock() + if !sp.started { + return + } + sp.cancel() + fn() + sp.wg.Wait() + sp.started = false +} + +// This is not a mutex protected function, it is up to the caller to use it in a mutex protected context +func (sp *CommonService) ErrOnNotRunning() error { + if !sp.started { + return ErrNotStarted + } + return nil +} + +func (sp *CommonService) Context() context.Context { + return sp.ctx +} +func (sp *CommonService) WaitGroup() *sync.WaitGroup { + return &sp.wg +} diff --git a/waku/v2/protocol/common_service_test.go b/waku/v2/protocol/common_service_test.go new file mode 100644 index 00000000..cd707e11 --- /dev/null +++ b/waku/v2/protocol/common_service_test.go @@ -0,0 +1,28 @@ +package protocol + +import ( + "context" + "sync" + "testing" +) + +// check if start and stop on common service works in random order +func TestCommonService(t *testing.T) { + s := NewCommonService() + wg := &sync.WaitGroup{} + for i := 0; i < 1000; i++ { + wg.Add(1) + if i%2 == 0 { + go func() { + wg.Done() + _ = s.Start(context.TODO(), func() error { return nil }) + }() + } else { + go func() { + wg.Done() + go s.Stop(func() {}) + }() + } + } + wg.Wait() +} diff --git a/waku/v2/protocol/filter/client.go b/waku/v2/protocol/filter/client.go index c8ed6620..200fa1fc 100644 --- a/waku/v2/protocol/filter/client.go +++ b/waku/v2/protocol/filter/client.go @@ -7,7 +7,6 @@ import ( "fmt" "math" "net/http" - "sync" "github.com/libp2p/go-libp2p/core/host" "github.com/libp2p/go-libp2p/core/network" @@ -34,16 +33,11 @@ var ( ) type WakuFilterLightNode struct { - sync.RWMutex - started bool - - cancel context.CancelFunc - ctx context.Context + *protocol.CommonService h host.Host broadcaster relay.Broadcaster //TODO: Move the broadcast functionality outside of relay client to a higher SDK layer.s timesource timesource.Timesource metrics Metrics - wg *sync.WaitGroup log *zap.Logger subscriptions *SubscriptionsMap pm *peermanager.PeerManager @@ -59,9 +53,6 @@ type WakuFilterPushResult struct { PeerID peer.ID } -var errNotStarted = errors.New("not started") -var errAlreadyStarted = errors.New("already started") - // NewWakuFilterLightnode returns a new instance of Waku Filter struct setup according to the chosen parameter and options // Note that broadcaster is optional. // Takes an optional peermanager if WakuFilterLightnode is being created along with WakuNode. @@ -72,8 +63,8 @@ func NewWakuFilterLightNode(broadcaster relay.Broadcaster, pm *peermanager.PeerM wf.log = log.Named("filterv2-lightnode") wf.broadcaster = broadcaster wf.timesource = timesource - wf.wg = &sync.WaitGroup{} wf.pm = pm + wf.CommonService = protocol.NewCommonService() wf.metrics = newMetrics(reg) return wf @@ -85,59 +76,36 @@ func (wf *WakuFilterLightNode) SetHost(h host.Host) { } func (wf *WakuFilterLightNode) Start(ctx context.Context) error { - wf.Lock() - defer wf.Unlock() + return wf.CommonService.Start(ctx, wf.start) - if wf.started { - return errAlreadyStarted - } +} - wf.wg.Wait() // Wait for any goroutines to stop - - ctx, cancel := context.WithCancel(ctx) - wf.cancel = cancel - wf.ctx = ctx +func (wf *WakuFilterLightNode) start() error { wf.subscriptions = NewSubscriptionMap(wf.log) - wf.started = true - - wf.h.SetStreamHandlerMatch(FilterPushID_v20beta1, protocol.PrefixTextMatch(string(FilterPushID_v20beta1)), wf.onRequest(ctx)) + wf.h.SetStreamHandlerMatch(FilterPushID_v20beta1, protocol.PrefixTextMatch(string(FilterPushID_v20beta1)), wf.onRequest(wf.Context())) wf.log.Info("filter-push protocol started") - return nil } // Stop unmounts the filter protocol func (wf *WakuFilterLightNode) Stop() { - wf.Lock() - defer wf.Unlock() - - if !wf.started { - return - } - - wf.cancel() - - wf.h.RemoveStreamHandler(FilterPushID_v20beta1) - - res, err := wf.unsubscribeAll(wf.ctx) - if err != nil { - wf.log.Warn("unsubscribing from full nodes", zap.Error(err)) - } - - for r := range res { - if r.Err != nil { - wf.log.Warn("unsubscribing from full nodes", zap.Error(r.Err), logging.HostID("peerID", r.PeerID)) + wf.CommonService.Stop(func() { + wf.h.RemoveStreamHandler(FilterPushID_v20beta1) + res, err := wf.unsubscribeAll(wf.Context()) + if err != nil { + wf.log.Warn("unsubscribing from full nodes", zap.Error(err)) } - } + for r := range res { + if r.Err != nil { + wf.log.Warn("unsubscribing from full nodes", zap.Error(r.Err), logging.HostID("peerID", r.PeerID)) + } - wf.subscriptions.Clear() - - wf.started = false - wf.cancel = nil - - wf.wg.Wait() + } + // + wf.subscriptions.Clear() + }) } func (wf *WakuFilterLightNode) onRequest(ctx context.Context) func(s network.Stream) { @@ -248,9 +216,8 @@ func (wf *WakuFilterLightNode) request(ctx context.Context, params *FilterSubscr func (wf *WakuFilterLightNode) Subscribe(ctx context.Context, contentFilter ContentFilter, opts ...FilterSubscribeOption) (*SubscriptionDetails, error) { wf.RLock() defer wf.RUnlock() - - if !wf.started { - return nil, errNotStarted + if err := wf.ErrOnNotRunning(); err != nil { + return nil, err } if contentFilter.Topic == "" { @@ -285,7 +252,6 @@ func (wf *WakuFilterLightNode) Subscribe(ctx context.Context, contentFilter Cont if err != nil { return nil, err } - return wf.subscriptions.NewSubscription(params.selectedPeer, contentFilter.Topic, contentFilter.ContentTopics), nil } @@ -293,9 +259,8 @@ func (wf *WakuFilterLightNode) Subscribe(ctx context.Context, contentFilter Cont func (wf *WakuFilterLightNode) FilterSubscription(peerID peer.ID, contentFilter ContentFilter) (*SubscriptionDetails, error) { wf.RLock() defer wf.RUnlock() - - if !wf.started { - return nil, errNotStarted + if err := wf.ErrOnNotRunning(); err != nil { + return nil, err } if !wf.subscriptions.Has(peerID, contentFilter.Topic, contentFilter.ContentTopics...) { @@ -319,9 +284,8 @@ func (wf *WakuFilterLightNode) getUnsubscribeParameters(opts ...FilterUnsubscrib func (wf *WakuFilterLightNode) Ping(ctx context.Context, peerID peer.ID) error { wf.RLock() defer wf.RUnlock() - - if !wf.started { - return errNotStarted + if err := wf.ErrOnNotRunning(); err != nil { + return err } return wf.request( @@ -334,9 +298,8 @@ func (wf *WakuFilterLightNode) Ping(ctx context.Context, peerID peer.ID) error { func (wf *WakuFilterLightNode) IsSubscriptionAlive(ctx context.Context, subscription *SubscriptionDetails) error { wf.RLock() defer wf.RUnlock() - - if !wf.started { - return errNotStarted + if err := wf.ErrOnNotRunning(); err != nil { + return err } return wf.Ping(ctx, subscription.PeerID) @@ -345,8 +308,7 @@ func (wf *WakuFilterLightNode) IsSubscriptionAlive(ctx context.Context, subscrip func (wf *WakuFilterLightNode) Subscriptions() []*SubscriptionDetails { wf.RLock() defer wf.RUnlock() - - if !wf.started { + if err := wf.ErrOnNotRunning(); err != nil { return nil } @@ -398,13 +360,11 @@ func (wf *WakuFilterLightNode) cleanupSubscriptions(peerID peer.ID, contentFilte } // Unsubscribe is used to stop receiving messages from a peer that match a content filter -func (wf *WakuFilterLightNode) Unsubscribe(ctx context.Context, contentFilter ContentFilter, - opts ...FilterUnsubscribeOption) (<-chan WakuFilterPushResult, error) { +func (wf *WakuFilterLightNode) Unsubscribe(ctx context.Context, contentFilter ContentFilter, opts ...FilterUnsubscribeOption) (<-chan WakuFilterPushResult, error) { wf.RLock() defer wf.RUnlock() - - if !wf.started { - return nil, errNotStarted + if err := wf.ErrOnNotRunning(); err != nil { + return nil, err } if contentFilter.Topic == "" { @@ -485,13 +445,11 @@ func (wf *WakuFilterLightNode) Unsubscribe(ctx context.Context, contentFilter Co } // Unsubscribe is used to stop receiving messages from a peer that match a content filter -func (wf *WakuFilterLightNode) UnsubscribeWithSubscription(ctx context.Context, sub *SubscriptionDetails, - opts ...FilterUnsubscribeOption) (<-chan WakuFilterPushResult, error) { +func (wf *WakuFilterLightNode) UnsubscribeWithSubscription(ctx context.Context, sub *SubscriptionDetails, opts ...FilterUnsubscribeOption) (<-chan WakuFilterPushResult, error) { wf.RLock() defer wf.RUnlock() - - if !wf.started { - return nil, errNotStarted + if err := wf.ErrOnNotRunning(); err != nil { + return nil, err } var contentTopics []string @@ -563,9 +521,8 @@ func (wf *WakuFilterLightNode) unsubscribeAll(ctx context.Context, opts ...Filte func (wf *WakuFilterLightNode) UnsubscribeAll(ctx context.Context, opts ...FilterUnsubscribeOption) (<-chan WakuFilterPushResult, error) { wf.RLock() defer wf.RUnlock() - - if !wf.started { - return nil, errNotStarted + if err := wf.ErrOnNotRunning(); err != nil { + return nil, err } return wf.unsubscribeAll(ctx, opts...) diff --git a/waku/v2/protocol/filter/filter_test.go b/waku/v2/protocol/filter/filter_test.go index 9e046780..39cefd5a 100644 --- a/waku/v2/protocol/filter/filter_test.go +++ b/waku/v2/protocol/filter/filter_test.go @@ -350,7 +350,7 @@ func (s *FilterTestSuite) TestRunningGuard() { _, err := s.lightNode.Subscribe(s.ctx, contentFilter, WithPeer(s.fullNodeHost.ID())) - s.Require().ErrorIs(err, errNotStarted) + s.Require().ErrorIs(err, protocol.ErrNotStarted) err = s.lightNode.Start(s.ctx) s.Require().NoError(err) @@ -398,7 +398,7 @@ func (s *FilterTestSuite) TestStartStop() { startNode := func() { for i := 0; i < 100; i++ { err := s.lightNode.Start(context.Background()) - if errors.Is(err, errAlreadyStarted) { + if errors.Is(err, protocol.ErrAlreadyStarted) { continue } s.Require().NoError(err) diff --git a/waku/v2/protocol/filter/server.go b/waku/v2/protocol/filter/server.go index 25803ae6..060ea3e3 100644 --- a/waku/v2/protocol/filter/server.go +++ b/waku/v2/protocol/filter/server.go @@ -6,7 +6,6 @@ import ( "fmt" "math" "net/http" - "sync" "time" "github.com/libp2p/go-libp2p/core/host" @@ -31,13 +30,11 @@ const peerHasNoSubscription = "peer has no subscriptions" type ( WakuFilterFullNode struct { - cancel context.CancelFunc h host.Host msgSub relay.Subscription metrics Metrics - wg *sync.WaitGroup log *zap.Logger - + *protocol.CommonService subscriptions *SubscribersMap maxSubscriptions int @@ -56,7 +53,7 @@ func NewWakuFilterFullNode(timesource timesource.Timesource, reg prometheus.Regi opt(params) } - wf.wg = &sync.WaitGroup{} + wf.CommonService = protocol.NewCommonService() wf.metrics = newMetrics(reg) wf.subscriptions = NewSubscribersMap(params.Timeout) wf.maxSubscriptions = params.MaxSubscribers @@ -70,19 +67,19 @@ func (wf *WakuFilterFullNode) SetHost(h host.Host) { } func (wf *WakuFilterFullNode) Start(ctx context.Context, sub relay.Subscription) error { - wf.wg.Wait() // Wait for any goroutines to stop + return wf.CommonService.Start(ctx, func() error { + return wf.start(sub) + }) +} - ctx, cancel := context.WithCancel(ctx) +func (wf *WakuFilterFullNode) start(sub relay.Subscription) error { + wf.h.SetStreamHandlerMatch(FilterSubscribeID_v20beta1, protocol.PrefixTextMatch(string(FilterSubscribeID_v20beta1)), wf.onRequest(wf.Context())) - wf.h.SetStreamHandlerMatch(FilterSubscribeID_v20beta1, protocol.PrefixTextMatch(string(FilterSubscribeID_v20beta1)), wf.onRequest(ctx)) - - wf.cancel = cancel wf.msgSub = sub - wf.wg.Add(1) - go wf.filterListener(ctx) + wf.WaitGroup().Add(1) + go wf.filterListener(wf.Context()) wf.log.Info("filter-subscriber protocol started") - return nil } @@ -227,7 +224,7 @@ func (wf *WakuFilterFullNode) unsubscribeAll(ctx context.Context, s network.Stre } func (wf *WakuFilterFullNode) filterListener(ctx context.Context) { - defer wf.wg.Done() + defer wf.WaitGroup().Done() // This function is invoked for each message received // on the full node in context of Waku2-Filter @@ -243,9 +240,9 @@ func (wf *WakuFilterFullNode) filterListener(ctx context.Context) { subscriber := subscriber // https://golang.org/doc/faq#closures_and_goroutines // Do a message push to light node logger.Info("pushing message to light node") - wf.wg.Add(1) + wf.WaitGroup().Add(1) go func(subscriber peer.ID) { - defer wf.wg.Done() + defer wf.WaitGroup().Done() start := time.Now() err := wf.pushMessage(ctx, subscriber, envelope) if err != nil { @@ -317,15 +314,8 @@ func (wf *WakuFilterFullNode) pushMessage(ctx context.Context, peerID peer.ID, e // Stop unmounts the filter protocol func (wf *WakuFilterFullNode) Stop() { - if wf.cancel == nil { - return - } - - wf.h.RemoveStreamHandler(FilterSubscribeID_v20beta1) - - wf.cancel() - - wf.msgSub.Unsubscribe() - - wf.wg.Wait() + wf.CommonService.Stop(func() { + wf.h.RemoveStreamHandler(FilterSubscribeID_v20beta1) + wf.msgSub.Unsubscribe() + }) } diff --git a/waku/v2/protocol/legacy_filter/waku_filter.go b/waku/v2/protocol/legacy_filter/waku_filter.go index 19685876..e28a0729 100644 --- a/waku/v2/protocol/legacy_filter/waku_filter.go +++ b/waku/v2/protocol/legacy_filter/waku_filter.go @@ -5,7 +5,6 @@ import ( "encoding/hex" "errors" "math" - "sync" "github.com/libp2p/go-libp2p/core/host" "github.com/libp2p/go-libp2p/core/network" @@ -47,12 +46,11 @@ type ( } WakuFilter struct { - cancel context.CancelFunc + *protocol.CommonService h host.Host isFullNode bool msgSub relay.Subscription metrics Metrics - wg *sync.WaitGroup log *zap.Logger filters *FilterMap @@ -75,8 +73,8 @@ func NewWakuFilter(broadcaster relay.Broadcaster, isFullNode bool, timesource ti opt(params) } - wf.wg = &sync.WaitGroup{} wf.isFullNode = isFullNode + wf.CommonService = protocol.NewCommonService() wf.filters = NewFilterMap(broadcaster, timesource) wf.subscribers = NewSubscribers(params.Timeout) wf.metrics = newMetrics(reg) @@ -90,23 +88,19 @@ func (wf *WakuFilter) SetHost(h host.Host) { } func (wf *WakuFilter) Start(ctx context.Context, sub relay.Subscription) error { - wf.wg.Wait() // Wait for any goroutines to stop - - ctx, cancel := context.WithCancel(ctx) - - wf.h.SetStreamHandlerMatch(FilterID_v20beta1, protocol.PrefixTextMatch(string(FilterID_v20beta1)), wf.onRequest(ctx)) - - wf.cancel = cancel - wf.msgSub = sub - - wf.wg.Add(1) - go wf.filterListener(ctx) - - wf.log.Info("filter protocol started") - - return nil + return wf.CommonService.Start(ctx, func() error { + return wf.start(sub) + }) } +func (wf *WakuFilter) start(sub relay.Subscription) error { + wf.h.SetStreamHandlerMatch(FilterID_v20beta1, protocol.PrefixTextMatch(string(FilterID_v20beta1)), wf.onRequest(wf.Context())) + wf.msgSub = sub + wf.WaitGroup().Add(1) + go wf.filterListener(wf.Context()) + wf.log.Info("filter protocol started") + return nil +} func (wf *WakuFilter) onRequest(ctx context.Context) func(s network.Stream) { return func(s network.Stream) { defer s.Close() @@ -188,7 +182,7 @@ func (wf *WakuFilter) pushMessage(ctx context.Context, subscriber Subscriber, ms } func (wf *WakuFilter) filterListener(ctx context.Context) { - defer wf.wg.Done() + defer wf.WaitGroup().Done() // This function is invoked for each message received // on the full node in context of Waku2-Filter @@ -327,19 +321,13 @@ func (wf *WakuFilter) Unsubscribe(ctx context.Context, contentFilter ContentFilt // Stop unmounts the filter protocol func (wf *WakuFilter) Stop() { - if wf.cancel == nil { - return - } + wf.CommonService.Stop(func() { + wf.msgSub.Unsubscribe() - wf.cancel() - - wf.msgSub.Unsubscribe() - - wf.h.RemoveStreamHandler(FilterID_v20beta1) - wf.filters.RemoveAll() - wf.subscribers.Clear() - - wf.wg.Wait() + wf.h.RemoveStreamHandler(FilterID_v20beta1) + wf.filters.RemoveAll() + wf.subscribers.Clear() + }) } // Subscribe setups a subscription to receive messages that match a specific content filter diff --git a/waku/v2/protocol/peer_exchange/client.go b/waku/v2/protocol/peer_exchange/client.go index 0cdb5a5e..1466b845 100644 --- a/waku/v2/protocol/peer_exchange/client.go +++ b/waku/v2/protocol/peer_exchange/client.go @@ -100,9 +100,9 @@ func (wakuPX *WakuPeerExchange) handleResponse(ctx context.Context, response *pb if len(discoveredPeers) != 0 { wakuPX.log.Info("connecting to newly discovered peers", zap.Int("count", len(discoveredPeers))) - wakuPX.wg.Add(1) + wakuPX.WaitGroup().Add(1) go func() { - defer wakuPX.wg.Done() + defer wakuPX.WaitGroup().Done() peerCh := make(chan peermanager.PeerData) defer close(peerCh) diff --git a/waku/v2/protocol/peer_exchange/protocol.go b/waku/v2/protocol/peer_exchange/protocol.go index 1f905b30..af9f47e8 100644 --- a/waku/v2/protocol/peer_exchange/protocol.go +++ b/waku/v2/protocol/peer_exchange/protocol.go @@ -5,7 +5,6 @@ import ( "errors" "fmt" "math" - "sync" "time" "github.com/libp2p/go-libp2p/core/host" @@ -43,9 +42,8 @@ type WakuPeerExchange struct { metrics Metrics log *zap.Logger - cancel context.CancelFunc + *protocol.CommonService - wg sync.WaitGroup peerConnector PeerConnector enrCache *enrCache } @@ -65,6 +63,7 @@ func NewWakuPeerExchange(disc *discv5.DiscoveryV5, peerConnector PeerConnector, wakuPX.enrCache = newEnrCache wakuPX.peerConnector = peerConnector wakuPX.pm = pm + wakuPX.CommonService = protocol.NewCommonService() return wakuPX, nil } @@ -76,20 +75,15 @@ func (wakuPX *WakuPeerExchange) SetHost(h host.Host) { // Start inits the peer exchange protocol func (wakuPX *WakuPeerExchange) Start(ctx context.Context) error { - if wakuPX.cancel != nil { - return errors.New("peer exchange already started") - } - - wakuPX.wg.Wait() // Waiting for any go routines to stop - - ctx, cancel := context.WithCancel(ctx) - wakuPX.cancel = cancel + return wakuPX.CommonService.Start(ctx, wakuPX.start) +} +func (wakuPX *WakuPeerExchange) start() error { wakuPX.h.SetStreamHandlerMatch(PeerExchangeID_v20alpha1, protocol.PrefixTextMatch(string(PeerExchangeID_v20alpha1)), wakuPX.onRequest()) - wakuPX.log.Info("Peer exchange protocol started") - wakuPX.wg.Add(1) - go wakuPX.runPeerExchangeDiscv5Loop(ctx) + wakuPX.WaitGroup().Add(1) + go wakuPX.runPeerExchangeDiscv5Loop(wakuPX.Context()) + wakuPX.log.Info("Peer exchange protocol started") return nil } @@ -133,12 +127,9 @@ func (wakuPX *WakuPeerExchange) onRequest() func(s network.Stream) { // Stop unmounts the peer exchange protocol func (wakuPX *WakuPeerExchange) Stop() { - if wakuPX.cancel == nil { - return - } - wakuPX.h.RemoveStreamHandler(PeerExchangeID_v20alpha1) - wakuPX.cancel() - wakuPX.wg.Wait() + wakuPX.CommonService.Stop(func() { + wakuPX.h.RemoveStreamHandler(PeerExchangeID_v20alpha1) + }) } func (wakuPX *WakuPeerExchange) iterate(ctx context.Context) error { @@ -173,7 +164,7 @@ func (wakuPX *WakuPeerExchange) iterate(ctx context.Context) error { } func (wakuPX *WakuPeerExchange) runPeerExchangeDiscv5Loop(ctx context.Context) { - defer wakuPX.wg.Done() + defer wakuPX.WaitGroup().Done() // Runs a discv5 loop adding new peers to the px peer cache if wakuPX.disc == nil { diff --git a/waku/v2/protocol/relay/waku_relay.go b/waku/v2/protocol/relay/waku_relay.go index 49f64f02..51efae51 100644 --- a/waku/v2/protocol/relay/waku_relay.go +++ b/waku/v2/protocol/relay/waku_relay.go @@ -60,9 +60,7 @@ type WakuRelay struct { EvtRelayUnsubscribed event.Emitter } - ctx context.Context - cancel context.CancelFunc - wg sync.WaitGroup + *waku_proto.CommonService } // EvtRelaySubscribed is an event emitted when a new subscription to a pubsub topic is created @@ -87,7 +85,7 @@ func NewWakuRelay(bcaster Broadcaster, minPeersToPublish int, timesource timesou w.relaySubs = make(map[string]*pubsub.Subscription) w.bcaster = bcaster w.minPeersToPublish = minPeersToPublish - w.wg = sync.WaitGroup{} + w.CommonService = waku_proto.NewCommonService() w.log = log.Named("relay") w.events = eventbus.NewBus() w.metrics = newMetrics(reg, w.log) @@ -213,12 +211,11 @@ func (w *WakuRelay) SetHost(h host.Host) { // Start initiates the WakuRelay protocol func (w *WakuRelay) Start(ctx context.Context) error { - w.wg.Wait() - ctx, cancel := context.WithCancel(ctx) - w.ctx = ctx // TODO: create worker for creating subscriptions instead of storing context - w.cancel = cancel + return w.CommonService.Start(ctx, w.start) +} - ps, err := pubsub.NewGossipSub(ctx, w.host, w.opts...) +func (w *WakuRelay) start() error { + ps, err := pubsub.NewGossipSub(w.Context(), w.host, w.opts...) if err != nil { return err } @@ -310,7 +307,7 @@ func (w *WakuRelay) subscribe(topic string) (subs *pubsub.Subscription, err erro } if w.bcaster != nil { - w.wg.Add(1) + w.WaitGroup().Add(1) go w.subscribeToTopic(topic, sub) } w.log.Info("subscribing to topic", zap.String("topic", sub.Topic())) @@ -364,15 +361,11 @@ func (w *WakuRelay) Publish(ctx context.Context, message *pb.WakuMessage) ([]byt // Stop unmounts the relay protocol and stops all subscriptions func (w *WakuRelay) Stop() { - if w.cancel == nil { - return // Not started - } - - w.host.RemoveStreamHandler(WakuRelayID_v200) - w.emitters.EvtRelaySubscribed.Close() - w.emitters.EvtRelayUnsubscribed.Close() - w.cancel() - w.wg.Wait() + w.CommonService.Stop(func() { + w.host.RemoveStreamHandler(WakuRelayID_v200) + w.emitters.EvtRelaySubscribed.Close() + w.emitters.EvtRelayUnsubscribed.Close() + }) } // EnoughPeersToPublish returns whether there are enough peers connected in the default waku pubsub topic @@ -454,12 +447,12 @@ func (w *WakuRelay) nextMessage(ctx context.Context, sub *pubsub.Subscription) < } func (w *WakuRelay) subscribeToTopic(pubsubTopic string, sub *pubsub.Subscription) { - defer w.wg.Done() + defer w.WaitGroup().Done() - subChannel := w.nextMessage(w.ctx, sub) + subChannel := w.nextMessage(w.Context(), sub) for { select { - case <-w.ctx.Done(): + case <-w.Context().Done(): return // TODO: if there are no more relay subscriptions, close the pubsub subscription case msg, ok := <-subChannel: diff --git a/waku/v2/rendezvous/rendezvous.go b/waku/v2/rendezvous/rendezvous.go index 4d4550b3..eecab0d3 100644 --- a/waku/v2/rendezvous/rendezvous.go +++ b/waku/v2/rendezvous/rendezvous.go @@ -2,10 +2,8 @@ package rendezvous import ( "context" - "errors" "fmt" "math" - "sync" "time" "github.com/libp2p/go-libp2p/core/host" @@ -32,9 +30,8 @@ type Rendezvous struct { peerConnector PeerConnector - log *zap.Logger - wg sync.WaitGroup - cancel context.CancelFunc + log *zap.Logger + *protocol.CommonService } // PeerConnector will subscribe to a channel containing the information for all peers found by this discovery protocol @@ -49,6 +46,7 @@ func NewRendezvous(db *DB, peerConnector PeerConnector, log *zap.Logger) *Rendez db: db, peerConnector: peerConnector, log: logger, + CommonService: protocol.NewCommonService(), } } @@ -58,19 +56,14 @@ func (r *Rendezvous) SetHost(h host.Host) { } func (r *Rendezvous) Start(ctx context.Context) error { - if r.cancel != nil { - return errors.New("already started") - } + return r.CommonService.Start(ctx, r.start) +} - ctx, cancel := context.WithCancel(ctx) - r.cancel = cancel - - err := r.db.Start(ctx) +func (r *Rendezvous) start() error { + err := r.db.Start(r.Context()) if err != nil { - cancel() return err } - r.rendezvousSvc = rvs.NewRendezvousService(r.host, r.db) r.log.Info("rendezvous protocol started") @@ -161,9 +154,9 @@ func (r *Rendezvous) RegisterRelayShards(ctx context.Context, rs protocol.RelayS // RegisterWithNamespace registers the node in the rendezvous point by using an specific namespace (usually a pubsub topic) func (r *Rendezvous) RegisterWithNamespace(ctx context.Context, namespace string, rendezvousPoints []*RendezvousPoint) { for _, m := range rendezvousPoints { - r.wg.Add(1) + r.WaitGroup().Add(1) go func(m *RendezvousPoint) { - r.wg.Done() + r.WaitGroup().Done() rendezvousClient := rvs.NewRendezvousClient(r.host, m.id) retries := 0 @@ -186,14 +179,10 @@ func (r *Rendezvous) RegisterWithNamespace(ctx context.Context, namespace string } func (r *Rendezvous) Stop() { - if r.cancel == nil { - return - } - - r.cancel() - r.wg.Wait() - r.host.RemoveStreamHandler(rvs.RendezvousProto) - r.rendezvousSvc = nil + r.CommonService.Stop(func() { + r.host.RemoveStreamHandler(rvs.RendezvousProto) + r.rendezvousSvc = nil + }) } // ShardToNamespace translates a cluster and shard index into a rendezvous namespace