mirror of https://github.com/status-im/go-waku.git
refactor: use WaitGroup for graceful shutdown of worker goroutines (#166)
This commit is contained in:
parent
c8caa46c99
commit
ce417a6486
|
@ -37,6 +37,8 @@ type DiscoveryV5 struct {
|
|||
NAT nat.Interface
|
||||
quit chan struct{}
|
||||
|
||||
wg *sync.WaitGroup
|
||||
|
||||
peerCache peerCache
|
||||
}
|
||||
|
||||
|
@ -142,6 +144,7 @@ func NewDiscoveryV5(host host.Host, ipAddr net.IP, tcpPort int, priv *ecdsa.Priv
|
|||
host: host,
|
||||
params: params,
|
||||
NAT: NAT,
|
||||
wg: &sync.WaitGroup{},
|
||||
peerCache: peerCache{
|
||||
rng: rand.New(rand.NewSource(rand.Int63())),
|
||||
recs: make(map[peer.ID]peerRecord),
|
||||
|
@ -197,7 +200,9 @@ func (d *DiscoveryV5) listen() error {
|
|||
d.udpAddr = conn.LocalAddr().(*net.UDPAddr)
|
||||
|
||||
if d.NAT != nil && !d.udpAddr.IP.IsLoopback() {
|
||||
d.wg.Add(1)
|
||||
go func() {
|
||||
defer d.wg.Done()
|
||||
nat.Map(d.NAT, d.quit, "udp", d.udpAddr.Port, d.udpAddr.Port, "go-waku discv5 discovery")
|
||||
}()
|
||||
|
||||
|
@ -222,13 +227,15 @@ func (d *DiscoveryV5) Start() error {
|
|||
d.Lock()
|
||||
defer d.Unlock()
|
||||
|
||||
d.wg.Wait() // Waiting for other go routines to stop
|
||||
|
||||
d.quit = make(chan struct{}, 1)
|
||||
|
||||
err := d.listen()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
d.quit = make(chan struct{})
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
|
@ -236,12 +243,14 @@ func (d *DiscoveryV5) Stop() {
|
|||
d.Lock()
|
||||
defer d.Unlock()
|
||||
|
||||
close(d.quit)
|
||||
|
||||
d.listener.Close()
|
||||
d.listener = nil
|
||||
|
||||
close(d.quit)
|
||||
|
||||
log.Info("Stopped Discovery V5")
|
||||
|
||||
d.wg.Wait()
|
||||
}
|
||||
|
||||
// IsPrivate reports whether ip is a private address, according to
|
||||
|
@ -354,6 +363,8 @@ func (c *DiscoveryV5) Advertise(ctx context.Context, ns string, opts ...discover
|
|||
}
|
||||
|
||||
func (d *DiscoveryV5) iterate(ctx context.Context, iterator enode.Iterator, limit int, doneCh chan struct{}) {
|
||||
defer d.wg.Done()
|
||||
|
||||
for {
|
||||
if len(d.peerCache.recs) >= limit {
|
||||
break
|
||||
|
@ -435,6 +446,8 @@ func (d *DiscoveryV5) FindPeers(ctx context.Context, topic string, opts ...disco
|
|||
defer iterator.Close()
|
||||
|
||||
doneCh := make(chan struct{})
|
||||
|
||||
d.wg.Add(1)
|
||||
go d.iterate(ctx, iterator, limit, doneCh)
|
||||
|
||||
select {
|
||||
|
|
|
@ -63,6 +63,10 @@ func TestDiscV5(t *testing.T) {
|
|||
d3, err := NewDiscoveryV5(host3, net.IPv4(127, 0, 0, 1), tcpPort3, prvKey3, NewWakuEnrBitfield(true, true, true, true), WithUDPPort(udpPort3), WithBootnodes([]*enode.Node{d2.localnode.Node()}))
|
||||
require.NoError(t, err)
|
||||
|
||||
defer d1.Stop()
|
||||
defer d2.Stop()
|
||||
defer d3.Stop()
|
||||
|
||||
err = d1.Start()
|
||||
require.NoError(t, err)
|
||||
|
||||
|
|
|
@ -84,6 +84,8 @@ func (w *WakuNode) sendConnStatus() {
|
|||
}
|
||||
|
||||
func (w *WakuNode) connectednessListener() {
|
||||
defer w.wg.Done()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-w.quit:
|
||||
|
|
|
@ -2,6 +2,7 @@ package node
|
|||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
|
@ -28,7 +29,10 @@ func TestKeepAlive(t *testing.T) {
|
|||
|
||||
ctx2, cancel2 := context.WithTimeout(ctx, 3*time.Second)
|
||||
defer cancel2()
|
||||
pingPeer(ctx2, host1, host2.ID())
|
||||
|
||||
wg := &sync.WaitGroup{}
|
||||
|
||||
pingPeer(ctx2, wg, host1, host2.ID())
|
||||
|
||||
require.NoError(t, ctx.Err())
|
||||
}
|
||||
|
|
|
@ -5,6 +5,7 @@ import (
|
|||
"fmt"
|
||||
"net"
|
||||
"strconv"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
logging "github.com/ipfs/go-log"
|
||||
|
@ -67,6 +68,7 @@ type WakuNode struct {
|
|||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
quit chan struct{}
|
||||
wg *sync.WaitGroup
|
||||
|
||||
// Channel passed to WakuNode constructor
|
||||
// receiving connection status notifications
|
||||
|
@ -122,6 +124,7 @@ func New(ctx context.Context, opts ...WakuNodeOption) (*WakuNode, error) {
|
|||
w.ctx = ctx
|
||||
w.opts = params
|
||||
w.quit = make(chan struct{})
|
||||
w.wg = &sync.WaitGroup{}
|
||||
w.addrChan = make(chan ma.Multiaddr, 1024)
|
||||
|
||||
if w.protocolEventSub, err = host.EventBus().Subscribe(new(event.EvtPeerProtocolsUpdated)); err != nil {
|
||||
|
@ -143,15 +146,16 @@ func New(ctx context.Context, opts ...WakuNodeOption) (*WakuNode, error) {
|
|||
w.connectionNotif = NewConnectionNotifier(ctx, host)
|
||||
w.host.Network().Notify(w.connectionNotif)
|
||||
|
||||
w.wg.Add(2)
|
||||
go w.connectednessListener()
|
||||
|
||||
if w.opts.keepAliveInterval > time.Duration(0) {
|
||||
w.startKeepAlive(w.opts.keepAliveInterval)
|
||||
}
|
||||
|
||||
go w.checkForAddressChanges()
|
||||
go w.onAddrChange()
|
||||
|
||||
if w.opts.keepAliveInterval > time.Duration(0) {
|
||||
w.wg.Add(1)
|
||||
w.startKeepAlive(w.opts.keepAliveInterval)
|
||||
}
|
||||
|
||||
return w, nil
|
||||
}
|
||||
|
||||
|
@ -190,6 +194,8 @@ func (w *WakuNode) logAddress(addr ma.Multiaddr) {
|
|||
}
|
||||
|
||||
func (w *WakuNode) checkForAddressChanges() {
|
||||
defer w.wg.Done()
|
||||
|
||||
addrs := w.ListenAddresses()
|
||||
first := make(chan struct{}, 1)
|
||||
first <- struct{}{}
|
||||
|
@ -311,6 +317,8 @@ func (w *WakuNode) Stop() {
|
|||
w.store.Stop()
|
||||
|
||||
w.host.Close()
|
||||
|
||||
w.wg.Wait()
|
||||
}
|
||||
|
||||
func (w *WakuNode) Host() host.Host {
|
||||
|
@ -425,7 +433,10 @@ func (w *WakuNode) startStore() {
|
|||
if w.opts.shouldResume {
|
||||
// TODO: extract this to a function and run it when you go offline
|
||||
// TODO: determine if a store is listening to a topic
|
||||
w.wg.Add(1)
|
||||
go func() {
|
||||
defer w.wg.Done()
|
||||
|
||||
ticker := time.NewTicker(time.Second)
|
||||
defer ticker.Stop()
|
||||
|
||||
|
@ -577,6 +588,8 @@ func (w *WakuNode) Peers() ([]*Peer, error) {
|
|||
// This is necessary because TCP connections are automatically closed due to inactivity,
|
||||
// and doing a ping will avoid this (with a small bandwidth cost)
|
||||
func (w *WakuNode) startKeepAlive(t time.Duration) {
|
||||
defer w.wg.Done()
|
||||
|
||||
log.Info("Setting up ping protocol with duration of ", t)
|
||||
|
||||
ticker := time.NewTicker(t)
|
||||
|
@ -594,7 +607,7 @@ func (w *WakuNode) startKeepAlive(t time.Duration) {
|
|||
// through Network's peer collection, as it will be empty
|
||||
for _, p := range w.host.Peerstore().Peers() {
|
||||
if p != w.host.ID() {
|
||||
go pingPeer(w.ctx, w.host, p)
|
||||
go pingPeer(w.ctx, w.wg, w.host, p)
|
||||
}
|
||||
}
|
||||
case <-w.quit:
|
||||
|
@ -604,7 +617,10 @@ func (w *WakuNode) startKeepAlive(t time.Duration) {
|
|||
}()
|
||||
}
|
||||
|
||||
func pingPeer(ctx context.Context, host host.Host, peer peer.ID) {
|
||||
func pingPeer(ctx context.Context, wg *sync.WaitGroup, host host.Host, peer peer.ID) {
|
||||
wg.Add(1)
|
||||
defer wg.Done()
|
||||
|
||||
ctx, cancel := context.WithTimeout(ctx, 3*time.Second)
|
||||
defer cancel()
|
||||
|
||||
|
|
|
@ -5,6 +5,7 @@ import (
|
|||
"encoding/hex"
|
||||
"errors"
|
||||
"fmt"
|
||||
"sync"
|
||||
|
||||
logging "github.com/ipfs/go-log"
|
||||
"github.com/libp2p/go-libp2p-core/host"
|
||||
|
@ -48,6 +49,7 @@ type (
|
|||
h host.Host
|
||||
isFullNode bool
|
||||
MsgC chan *protocol.Envelope
|
||||
wg *sync.WaitGroup
|
||||
|
||||
filters *FilterMap
|
||||
subscribers *Subscribers
|
||||
|
@ -67,13 +69,16 @@ func NewWakuFilter(ctx context.Context, host host.Host, isFullNode bool) *WakuFi
|
|||
|
||||
wf := new(WakuFilter)
|
||||
wf.ctx = ctx
|
||||
wf.MsgC = make(chan *protocol.Envelope)
|
||||
wf.wg = &sync.WaitGroup{}
|
||||
wf.MsgC = make(chan *protocol.Envelope, 1024)
|
||||
wf.h = host
|
||||
wf.isFullNode = isFullNode
|
||||
wf.filters = NewFilterMap()
|
||||
wf.subscribers = NewSubscribers()
|
||||
|
||||
wf.h.SetStreamHandlerMatch(FilterID_v20beta1, protocol.PrefixTextMatch(string(FilterID_v20beta1)), wf.onRequest)
|
||||
|
||||
wf.wg.Add(1)
|
||||
go wf.FilterListener()
|
||||
|
||||
if wf.isFullNode {
|
||||
|
@ -155,6 +160,8 @@ func (wf *WakuFilter) pushMessage(subscriber Subscriber, msg *pb.WakuMessage) er
|
|||
}
|
||||
|
||||
func (wf *WakuFilter) FilterListener() {
|
||||
defer wf.wg.Done()
|
||||
|
||||
// This function is invoked for each message received
|
||||
// on the full node in context of Waku2-Filter
|
||||
handle := func(envelope *protocol.Envelope) error { // async
|
||||
|
@ -189,7 +196,6 @@ func (wf *WakuFilter) FilterListener() {
|
|||
log.Error("failed to handle message", err)
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
// Having a FilterRequest struct,
|
||||
|
@ -281,8 +287,11 @@ func (wf *WakuFilter) Unsubscribe(ctx context.Context, contentFilter ContentFilt
|
|||
}
|
||||
|
||||
func (wf *WakuFilter) Stop() {
|
||||
close(wf.MsgC)
|
||||
|
||||
wf.h.RemoveStreamHandler(FilterID_v20beta1)
|
||||
wf.filters.RemoveAll()
|
||||
wf.wg.Wait()
|
||||
}
|
||||
|
||||
func (wf *WakuFilter) Subscribe(ctx context.Context, f ContentFilter, opts ...FilterSubscribeOption) (filterID string, theFilter Filter, err error) {
|
||||
|
|
|
@ -16,6 +16,7 @@ type MessageQueue struct {
|
|||
maxDuration time.Duration
|
||||
|
||||
quit chan struct{}
|
||||
wg *sync.WaitGroup
|
||||
}
|
||||
|
||||
func (self *MessageQueue) Push(msg IndexedWakuMessage) {
|
||||
|
@ -73,6 +74,8 @@ func (self *MessageQueue) cleanOlderRecords() {
|
|||
}
|
||||
|
||||
func (self *MessageQueue) checkForOlderRecords(d time.Duration) {
|
||||
defer self.wg.Done()
|
||||
|
||||
ticker := time.NewTicker(d)
|
||||
defer ticker.Stop()
|
||||
|
||||
|
@ -98,9 +101,11 @@ func NewMessageQueue(maxMessages int, maxDuration time.Duration) *MessageQueue {
|
|||
maxDuration: maxDuration,
|
||||
seen: make(map[[32]byte]struct{}),
|
||||
quit: make(chan struct{}),
|
||||
wg: &sync.WaitGroup{},
|
||||
}
|
||||
|
||||
if maxDuration != 0 {
|
||||
result.wg.Add(1)
|
||||
go result.checkForOlderRecords(10 * time.Second) // is 10s okay?
|
||||
}
|
||||
|
||||
|
@ -109,4 +114,5 @@ func NewMessageQueue(maxMessages int, maxDuration time.Duration) *MessageQueue {
|
|||
|
||||
func (self *MessageQueue) Stop() {
|
||||
close(self.quit)
|
||||
self.wg.Wait()
|
||||
}
|
||||
|
|
|
@ -8,6 +8,7 @@ import (
|
|||
"fmt"
|
||||
"math"
|
||||
"sort"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
logging "github.com/ipfs/go-log"
|
||||
|
@ -227,6 +228,7 @@ type IndexedWakuMessage struct {
|
|||
type WakuStore struct {
|
||||
ctx context.Context
|
||||
MsgC chan *protocol.Envelope
|
||||
wg *sync.WaitGroup
|
||||
|
||||
started bool
|
||||
|
||||
|
@ -240,6 +242,7 @@ func NewWakuStore(host host.Host, p MessageProvider, maxNumberOfMessages int, ma
|
|||
wakuStore := new(WakuStore)
|
||||
wakuStore.msgProvider = p
|
||||
wakuStore.h = host
|
||||
wakuStore.wg = &sync.WaitGroup{}
|
||||
wakuStore.messageQueue = NewMessageQueue(maxNumberOfMessages, maxRetentionDuration)
|
||||
return wakuStore
|
||||
}
|
||||
|
@ -261,6 +264,7 @@ func (store *WakuStore) Start(ctx context.Context) {
|
|||
|
||||
store.h.SetStreamHandlerMatch(StoreID_v20beta3, protocol.PrefixTextMatch(string(StoreID_v20beta3)), store.onRequest)
|
||||
|
||||
store.wg.Add(1)
|
||||
go store.storeIncomingMessages(ctx)
|
||||
|
||||
if store.msgProvider == nil {
|
||||
|
@ -327,6 +331,7 @@ func (store *WakuStore) storeMessage(env *protocol.Envelope) {
|
|||
}
|
||||
|
||||
func (store *WakuStore) storeIncomingMessages(ctx context.Context) {
|
||||
defer store.wg.Done()
|
||||
for envelope := range store.MsgC {
|
||||
store.storeMessage(envelope)
|
||||
}
|
||||
|
@ -721,4 +726,6 @@ func (store *WakuStore) Stop() {
|
|||
if store.h != nil {
|
||||
store.h.RemoveStreamHandler(StoreID_v20beta3)
|
||||
}
|
||||
|
||||
store.wg.Wait()
|
||||
}
|
||||
|
|
|
@ -79,9 +79,10 @@ func SelectPeerWithLowestRTT(ctx context.Context, host host.Host, protocolId str
|
|||
waitCh := make(chan struct{})
|
||||
pingCh := make(chan pingResult, 1000)
|
||||
|
||||
wg.Add(len(peers))
|
||||
|
||||
go func() {
|
||||
for _, p := range peers {
|
||||
wg.Add(1)
|
||||
go func(p peer.ID) {
|
||||
defer wg.Done()
|
||||
ctx, cancel := context.WithTimeout(ctx, 3*time.Second)
|
||||
|
|
Loading…
Reference in New Issue