diff --git a/waku/v2/discv5/discover.go b/waku/v2/discv5/discover.go index 3db89479..d4a7d28b 100644 --- a/waku/v2/discv5/discover.go +++ b/waku/v2/discv5/discover.go @@ -37,6 +37,8 @@ type DiscoveryV5 struct { NAT nat.Interface quit chan struct{} + wg *sync.WaitGroup + peerCache peerCache } @@ -142,6 +144,7 @@ func NewDiscoveryV5(host host.Host, ipAddr net.IP, tcpPort int, priv *ecdsa.Priv host: host, params: params, NAT: NAT, + wg: &sync.WaitGroup{}, peerCache: peerCache{ rng: rand.New(rand.NewSource(rand.Int63())), recs: make(map[peer.ID]peerRecord), @@ -197,7 +200,9 @@ func (d *DiscoveryV5) listen() error { d.udpAddr = conn.LocalAddr().(*net.UDPAddr) if d.NAT != nil && !d.udpAddr.IP.IsLoopback() { + d.wg.Add(1) go func() { + defer d.wg.Done() nat.Map(d.NAT, d.quit, "udp", d.udpAddr.Port, d.udpAddr.Port, "go-waku discv5 discovery") }() @@ -222,13 +227,15 @@ func (d *DiscoveryV5) Start() error { d.Lock() defer d.Unlock() + d.wg.Wait() // Waiting for other go routines to stop + + d.quit = make(chan struct{}, 1) + err := d.listen() if err != nil { return err } - d.quit = make(chan struct{}) - return nil } @@ -236,12 +243,14 @@ func (d *DiscoveryV5) Stop() { d.Lock() defer d.Unlock() + close(d.quit) + d.listener.Close() d.listener = nil - close(d.quit) - log.Info("Stopped Discovery V5") + + d.wg.Wait() } // IsPrivate reports whether ip is a private address, according to @@ -354,6 +363,8 @@ func (c *DiscoveryV5) Advertise(ctx context.Context, ns string, opts ...discover } func (d *DiscoveryV5) iterate(ctx context.Context, iterator enode.Iterator, limit int, doneCh chan struct{}) { + defer d.wg.Done() + for { if len(d.peerCache.recs) >= limit { break @@ -435,6 +446,8 @@ func (d *DiscoveryV5) FindPeers(ctx context.Context, topic string, opts ...disco defer iterator.Close() doneCh := make(chan struct{}) + + d.wg.Add(1) go d.iterate(ctx, iterator, limit, doneCh) select { diff --git a/waku/v2/discv5/discover_test.go b/waku/v2/discv5/discover_test.go index f9d6d448..70a72902 100644 --- a/waku/v2/discv5/discover_test.go +++ b/waku/v2/discv5/discover_test.go @@ -63,6 +63,10 @@ func TestDiscV5(t *testing.T) { d3, err := NewDiscoveryV5(host3, net.IPv4(127, 0, 0, 1), tcpPort3, prvKey3, NewWakuEnrBitfield(true, true, true, true), WithUDPPort(udpPort3), WithBootnodes([]*enode.Node{d2.localnode.Node()})) require.NoError(t, err) + defer d1.Stop() + defer d2.Stop() + defer d3.Stop() + err = d1.Start() require.NoError(t, err) diff --git a/waku/v2/node/connectedness.go b/waku/v2/node/connectedness.go index 0a20d0e8..503024f0 100644 --- a/waku/v2/node/connectedness.go +++ b/waku/v2/node/connectedness.go @@ -84,6 +84,8 @@ func (w *WakuNode) sendConnStatus() { } func (w *WakuNode) connectednessListener() { + defer w.wg.Done() + for { select { case <-w.quit: diff --git a/waku/v2/node/keepalive_test.go b/waku/v2/node/keepalive_test.go index 8d72e193..f348a72e 100644 --- a/waku/v2/node/keepalive_test.go +++ b/waku/v2/node/keepalive_test.go @@ -2,6 +2,7 @@ package node import ( "context" + "sync" "testing" "time" @@ -28,7 +29,10 @@ func TestKeepAlive(t *testing.T) { ctx2, cancel2 := context.WithTimeout(ctx, 3*time.Second) defer cancel2() - pingPeer(ctx2, host1, host2.ID()) + + wg := &sync.WaitGroup{} + + pingPeer(ctx2, wg, host1, host2.ID()) require.NoError(t, ctx.Err()) } diff --git a/waku/v2/node/wakunode2.go b/waku/v2/node/wakunode2.go index 614861c4..d5ba5462 100644 --- a/waku/v2/node/wakunode2.go +++ b/waku/v2/node/wakunode2.go @@ -5,6 +5,7 @@ import ( "fmt" "net" "strconv" + "sync" "time" logging "github.com/ipfs/go-log" @@ -67,6 +68,7 @@ type WakuNode struct { ctx context.Context cancel context.CancelFunc quit chan struct{} + wg *sync.WaitGroup // Channel passed to WakuNode constructor // receiving connection status notifications @@ -122,6 +124,7 @@ func New(ctx context.Context, opts ...WakuNodeOption) (*WakuNode, error) { w.ctx = ctx w.opts = params w.quit = make(chan struct{}) + w.wg = &sync.WaitGroup{} w.addrChan = make(chan ma.Multiaddr, 1024) if w.protocolEventSub, err = host.EventBus().Subscribe(new(event.EvtPeerProtocolsUpdated)); err != nil { @@ -143,15 +146,16 @@ func New(ctx context.Context, opts ...WakuNodeOption) (*WakuNode, error) { w.connectionNotif = NewConnectionNotifier(ctx, host) w.host.Network().Notify(w.connectionNotif) + w.wg.Add(2) go w.connectednessListener() - - if w.opts.keepAliveInterval > time.Duration(0) { - w.startKeepAlive(w.opts.keepAliveInterval) - } - go w.checkForAddressChanges() go w.onAddrChange() + if w.opts.keepAliveInterval > time.Duration(0) { + w.wg.Add(1) + w.startKeepAlive(w.opts.keepAliveInterval) + } + return w, nil } @@ -190,6 +194,8 @@ func (w *WakuNode) logAddress(addr ma.Multiaddr) { } func (w *WakuNode) checkForAddressChanges() { + defer w.wg.Done() + addrs := w.ListenAddresses() first := make(chan struct{}, 1) first <- struct{}{} @@ -311,6 +317,8 @@ func (w *WakuNode) Stop() { w.store.Stop() w.host.Close() + + w.wg.Wait() } func (w *WakuNode) Host() host.Host { @@ -425,7 +433,10 @@ func (w *WakuNode) startStore() { if w.opts.shouldResume { // TODO: extract this to a function and run it when you go offline // TODO: determine if a store is listening to a topic + w.wg.Add(1) go func() { + defer w.wg.Done() + ticker := time.NewTicker(time.Second) defer ticker.Stop() @@ -577,6 +588,8 @@ func (w *WakuNode) Peers() ([]*Peer, error) { // This is necessary because TCP connections are automatically closed due to inactivity, // and doing a ping will avoid this (with a small bandwidth cost) func (w *WakuNode) startKeepAlive(t time.Duration) { + defer w.wg.Done() + log.Info("Setting up ping protocol with duration of ", t) ticker := time.NewTicker(t) @@ -594,7 +607,7 @@ func (w *WakuNode) startKeepAlive(t time.Duration) { // through Network's peer collection, as it will be empty for _, p := range w.host.Peerstore().Peers() { if p != w.host.ID() { - go pingPeer(w.ctx, w.host, p) + go pingPeer(w.ctx, w.wg, w.host, p) } } case <-w.quit: @@ -604,7 +617,10 @@ func (w *WakuNode) startKeepAlive(t time.Duration) { }() } -func pingPeer(ctx context.Context, host host.Host, peer peer.ID) { +func pingPeer(ctx context.Context, wg *sync.WaitGroup, host host.Host, peer peer.ID) { + wg.Add(1) + defer wg.Done() + ctx, cancel := context.WithTimeout(ctx, 3*time.Second) defer cancel() diff --git a/waku/v2/protocol/filter/waku_filter.go b/waku/v2/protocol/filter/waku_filter.go index efafb5ed..c98a180a 100644 --- a/waku/v2/protocol/filter/waku_filter.go +++ b/waku/v2/protocol/filter/waku_filter.go @@ -5,6 +5,7 @@ import ( "encoding/hex" "errors" "fmt" + "sync" logging "github.com/ipfs/go-log" "github.com/libp2p/go-libp2p-core/host" @@ -48,6 +49,7 @@ type ( h host.Host isFullNode bool MsgC chan *protocol.Envelope + wg *sync.WaitGroup filters *FilterMap subscribers *Subscribers @@ -67,13 +69,16 @@ func NewWakuFilter(ctx context.Context, host host.Host, isFullNode bool) *WakuFi wf := new(WakuFilter) wf.ctx = ctx - wf.MsgC = make(chan *protocol.Envelope) + wf.wg = &sync.WaitGroup{} + wf.MsgC = make(chan *protocol.Envelope, 1024) wf.h = host wf.isFullNode = isFullNode wf.filters = NewFilterMap() wf.subscribers = NewSubscribers() wf.h.SetStreamHandlerMatch(FilterID_v20beta1, protocol.PrefixTextMatch(string(FilterID_v20beta1)), wf.onRequest) + + wf.wg.Add(1) go wf.FilterListener() if wf.isFullNode { @@ -155,6 +160,8 @@ func (wf *WakuFilter) pushMessage(subscriber Subscriber, msg *pb.WakuMessage) er } func (wf *WakuFilter) FilterListener() { + defer wf.wg.Done() + // This function is invoked for each message received // on the full node in context of Waku2-Filter handle := func(envelope *protocol.Envelope) error { // async @@ -189,7 +196,6 @@ func (wf *WakuFilter) FilterListener() { log.Error("failed to handle message", err) } } - } // Having a FilterRequest struct, @@ -281,8 +287,11 @@ func (wf *WakuFilter) Unsubscribe(ctx context.Context, contentFilter ContentFilt } func (wf *WakuFilter) Stop() { + close(wf.MsgC) + wf.h.RemoveStreamHandler(FilterID_v20beta1) wf.filters.RemoveAll() + wf.wg.Wait() } func (wf *WakuFilter) Subscribe(ctx context.Context, f ContentFilter, opts ...FilterSubscribeOption) (filterID string, theFilter Filter, err error) { diff --git a/waku/v2/protocol/store/message_queue.go b/waku/v2/protocol/store/message_queue.go index 300be1cf..fa7a52d6 100644 --- a/waku/v2/protocol/store/message_queue.go +++ b/waku/v2/protocol/store/message_queue.go @@ -16,6 +16,7 @@ type MessageQueue struct { maxDuration time.Duration quit chan struct{} + wg *sync.WaitGroup } func (self *MessageQueue) Push(msg IndexedWakuMessage) { @@ -73,6 +74,8 @@ func (self *MessageQueue) cleanOlderRecords() { } func (self *MessageQueue) checkForOlderRecords(d time.Duration) { + defer self.wg.Done() + ticker := time.NewTicker(d) defer ticker.Stop() @@ -98,9 +101,11 @@ func NewMessageQueue(maxMessages int, maxDuration time.Duration) *MessageQueue { maxDuration: maxDuration, seen: make(map[[32]byte]struct{}), quit: make(chan struct{}), + wg: &sync.WaitGroup{}, } if maxDuration != 0 { + result.wg.Add(1) go result.checkForOlderRecords(10 * time.Second) // is 10s okay? } @@ -109,4 +114,5 @@ func NewMessageQueue(maxMessages int, maxDuration time.Duration) *MessageQueue { func (self *MessageQueue) Stop() { close(self.quit) + self.wg.Wait() } diff --git a/waku/v2/protocol/store/waku_store.go b/waku/v2/protocol/store/waku_store.go index 90ec1a5a..546f3172 100644 --- a/waku/v2/protocol/store/waku_store.go +++ b/waku/v2/protocol/store/waku_store.go @@ -8,6 +8,7 @@ import ( "fmt" "math" "sort" + "sync" "time" logging "github.com/ipfs/go-log" @@ -227,6 +228,7 @@ type IndexedWakuMessage struct { type WakuStore struct { ctx context.Context MsgC chan *protocol.Envelope + wg *sync.WaitGroup started bool @@ -240,6 +242,7 @@ func NewWakuStore(host host.Host, p MessageProvider, maxNumberOfMessages int, ma wakuStore := new(WakuStore) wakuStore.msgProvider = p wakuStore.h = host + wakuStore.wg = &sync.WaitGroup{} wakuStore.messageQueue = NewMessageQueue(maxNumberOfMessages, maxRetentionDuration) return wakuStore } @@ -261,6 +264,7 @@ func (store *WakuStore) Start(ctx context.Context) { store.h.SetStreamHandlerMatch(StoreID_v20beta3, protocol.PrefixTextMatch(string(StoreID_v20beta3)), store.onRequest) + store.wg.Add(1) go store.storeIncomingMessages(ctx) if store.msgProvider == nil { @@ -327,6 +331,7 @@ func (store *WakuStore) storeMessage(env *protocol.Envelope) { } func (store *WakuStore) storeIncomingMessages(ctx context.Context) { + defer store.wg.Done() for envelope := range store.MsgC { store.storeMessage(envelope) } @@ -721,4 +726,6 @@ func (store *WakuStore) Stop() { if store.h != nil { store.h.RemoveStreamHandler(StoreID_v20beta3) } + + store.wg.Wait() } diff --git a/waku/v2/utils/peer.go b/waku/v2/utils/peer.go index 28064b3a..3700f5b1 100644 --- a/waku/v2/utils/peer.go +++ b/waku/v2/utils/peer.go @@ -79,9 +79,10 @@ func SelectPeerWithLowestRTT(ctx context.Context, host host.Host, protocolId str waitCh := make(chan struct{}) pingCh := make(chan pingResult, 1000) + wg.Add(len(peers)) + go func() { for _, p := range peers { - wg.Add(1) go func(p peer.ID) { defer wg.Done() ctx, cancel := context.WithTimeout(ctx, 3*time.Second)