refactor: use WaitGroup for graceful shutdown of worker goroutines (#166)

Richard Ramos 2021-11-23 11:03:12 -04:00 committed by GitHub
parent c8caa46c99
commit ce417a6486
9 changed files with 77 additions and 15 deletions
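
The diffs below apply one pattern throughout: every long-running goroutine is registered on a sync.WaitGroup before (or as) it starts, signals completion with a deferred Done, and each component's Stop closes its quit channel and then blocks on Wait until all workers have exited. A minimal sketch of that pattern, using a hypothetical Worker type (the names are illustrative, not taken from go-waku):

package worker

import "sync"

// Worker mirrors the lifecycle of the components touched in this
// commit (DiscoveryV5, WakuNode, WakuFilter, MessageQueue, WakuStore).
type Worker struct {
	quit chan struct{}
	wg   *sync.WaitGroup
}

func New() *Worker {
	return &Worker{
		quit: make(chan struct{}),
		wg:   &sync.WaitGroup{},
	}
}

func (w *Worker) Start() {
	w.wg.Add(1) // register before spawning, so Stop cannot miss this goroutine
	go func() {
		defer w.wg.Done() // always deregister, even on early return
		for {
			select {
			case <-w.quit:
				return // graceful exit requested
			default:
				// periodic work goes here
			}
		}
	}()
}

func (w *Worker) Stop() {
	close(w.quit) // broadcast the shutdown signal to all workers
	w.wg.Wait()   // block until every registered goroutine has returned
}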

View File

@@ -37,6 +37,8 @@ type DiscoveryV5 struct {
 	NAT  nat.Interface
 	quit chan struct{}
+	wg   *sync.WaitGroup
 	peerCache peerCache
 }
@@ -142,6 +144,7 @@ func NewDiscoveryV5(host host.Host, ipAddr net.IP, tcpPort int, priv *ecdsa.Priv
 		host:   host,
 		params: params,
 		NAT:    NAT,
+		wg:     &sync.WaitGroup{},
 		peerCache: peerCache{
 			rng:  rand.New(rand.NewSource(rand.Int63())),
 			recs: make(map[peer.ID]peerRecord),
@@ -197,7 +200,9 @@ func (d *DiscoveryV5) listen() error {
 	d.udpAddr = conn.LocalAddr().(*net.UDPAddr)
 	if d.NAT != nil && !d.udpAddr.IP.IsLoopback() {
+		d.wg.Add(1)
 		go func() {
+			defer d.wg.Done()
 			nat.Map(d.NAT, d.quit, "udp", d.udpAddr.Port, d.udpAddr.Port, "go-waku discv5 discovery")
 		}()
@@ -222,13 +227,15 @@ func (d *DiscoveryV5) Start() error {
 	d.Lock()
 	defer d.Unlock()
+	d.wg.Wait() // Waiting for other go routines to stop
+	d.quit = make(chan struct{}, 1)
 	err := d.listen()
 	if err != nil {
 		return err
 	}
-	d.quit = make(chan struct{})
 	return nil
 }
@@ -236,12 +243,14 @@ func (d *DiscoveryV5) Stop() {
 	d.Lock()
 	defer d.Unlock()
+	close(d.quit)
 	d.listener.Close()
 	d.listener = nil
-	close(d.quit)
 	log.Info("Stopped Discovery V5")
+	d.wg.Wait()
 }
 // IsPrivate reports whether ip is a private address, according to
@@ -354,6 +363,8 @@ func (c *DiscoveryV5) Advertise(ctx context.Context, ns string, opts ...discover
 }
 func (d *DiscoveryV5) iterate(ctx context.Context, iterator enode.Iterator, limit int, doneCh chan struct{}) {
+	defer d.wg.Done()
 	for {
 		if len(d.peerCache.recs) >= limit {
 			break
@@ -435,6 +446,8 @@ func (d *DiscoveryV5) FindPeers(ctx context.Context, topic string, opts ...disco
 	defer iterator.Close()
 	doneCh := make(chan struct{})
+	d.wg.Add(1)
 	go d.iterate(ctx, iterator, limit, doneCh)
 	select {

View File

@@ -63,6 +63,10 @@ func TestDiscV5(t *testing.T) {
 	d3, err := NewDiscoveryV5(host3, net.IPv4(127, 0, 0, 1), tcpPort3, prvKey3, NewWakuEnrBitfield(true, true, true, true), WithUDPPort(udpPort3), WithBootnodes([]*enode.Node{d2.localnode.Node()}))
 	require.NoError(t, err)
+	defer d1.Stop()
+	defer d2.Stop()
+	defer d3.Stop()
 	err = d1.Start()
 	require.NoError(t, err)

View File

@@ -84,6 +84,8 @@ func (w *WakuNode) sendConnStatus() {
 }
 func (w *WakuNode) connectednessListener() {
+	defer w.wg.Done()
 	for {
 		select {
 		case <-w.quit:

View File

@@ -2,6 +2,7 @@ package node
 import (
 	"context"
+	"sync"
 	"testing"
 	"time"
@@ -28,7 +29,10 @@ func TestKeepAlive(t *testing.T) {
 	ctx2, cancel2 := context.WithTimeout(ctx, 3*time.Second)
 	defer cancel2()
-	pingPeer(ctx2, host1, host2.ID())
+	wg := &sync.WaitGroup{}
+	pingPeer(ctx2, wg, host1, host2.ID())
 	require.NoError(t, ctx.Err())
 }

View File

@@ -5,6 +5,7 @@ import (
 	"fmt"
 	"net"
 	"strconv"
+	"sync"
 	"time"
 	logging "github.com/ipfs/go-log"
@@ -67,6 +68,7 @@ type WakuNode struct {
 	ctx    context.Context
 	cancel context.CancelFunc
 	quit   chan struct{}
+	wg     *sync.WaitGroup
 	// Channel passed to WakuNode constructor
 	// receiving connection status notifications
@@ -122,6 +124,7 @@ func New(ctx context.Context, opts ...WakuNodeOption) (*WakuNode, error) {
 	w.ctx = ctx
 	w.opts = params
 	w.quit = make(chan struct{})
+	w.wg = &sync.WaitGroup{}
 	w.addrChan = make(chan ma.Multiaddr, 1024)
 	if w.protocolEventSub, err = host.EventBus().Subscribe(new(event.EvtPeerProtocolsUpdated)); err != nil {
@@ -143,15 +146,16 @@ func New(ctx context.Context, opts ...WakuNodeOption) (*WakuNode, error) {
 	w.connectionNotif = NewConnectionNotifier(ctx, host)
 	w.host.Network().Notify(w.connectionNotif)
+	w.wg.Add(2)
 	go w.connectednessListener()
-	if w.opts.keepAliveInterval > time.Duration(0) {
-		w.startKeepAlive(w.opts.keepAliveInterval)
-	}
 	go w.checkForAddressChanges()
 	go w.onAddrChange()
+	if w.opts.keepAliveInterval > time.Duration(0) {
+		w.wg.Add(1)
+		w.startKeepAlive(w.opts.keepAliveInterval)
+	}
 	return w, nil
 }
@@ -190,6 +194,8 @@ func (w *WakuNode) logAddress(addr ma.Multiaddr) {
 }
 func (w *WakuNode) checkForAddressChanges() {
+	defer w.wg.Done()
 	addrs := w.ListenAddresses()
 	first := make(chan struct{}, 1)
 	first <- struct{}{}
@@ -311,6 +317,8 @@ func (w *WakuNode) Stop() {
 	w.store.Stop()
 	w.host.Close()
+	w.wg.Wait()
 }
 func (w *WakuNode) Host() host.Host {
@@ -425,7 +433,10 @@ func (w *WakuNode) startStore() {
 	if w.opts.shouldResume {
 		// TODO: extract this to a function and run it when you go offline
 		// TODO: determine if a store is listening to a topic
+		w.wg.Add(1)
 		go func() {
+			defer w.wg.Done()
 			ticker := time.NewTicker(time.Second)
 			defer ticker.Stop()
@@ -577,6 +588,8 @@ func (w *WakuNode) Peers() ([]*Peer, error) {
 // This is necessary because TCP connections are automatically closed due to inactivity,
 // and doing a ping will avoid this (with a small bandwidth cost)
 func (w *WakuNode) startKeepAlive(t time.Duration) {
+	defer w.wg.Done()
 	log.Info("Setting up ping protocol with duration of ", t)
 	ticker := time.NewTicker(t)
@@ -594,7 +607,7 @@ func (w *WakuNode) startKeepAlive(t time.Duration) {
 			// through Network's peer collection, as it will be empty
 			for _, p := range w.host.Peerstore().Peers() {
 				if p != w.host.ID() {
-					go pingPeer(w.ctx, w.host, p)
+					go pingPeer(w.ctx, w.wg, w.host, p)
 				}
 			}
 		case <-w.quit:
@@ -604,7 +617,10 @@ func (w *WakuNode) startKeepAlive(t time.Duration) {
 	}()
 }
-func pingPeer(ctx context.Context, host host.Host, peer peer.ID) {
+func pingPeer(ctx context.Context, wg *sync.WaitGroup, host host.Host, peer peer.ID) {
+	wg.Add(1)
+	defer wg.Done()
 	ctx, cancel := context.WithTimeout(ctx, 3*time.Second)
 	defer cancel()
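
A side note on the new pingPeer: it registers itself (wg.Add(1) on entry, deferred wg.Done()), so call sites can simply `go pingPeer(...)` while the node's wg.Wait() still covers pings in flight. The sync.WaitGroup documentation notes that Adds which restart the counter from zero must happen before the Wait they should block, so this self-registering style assumes Stop is not racing with fresh spawns. The more defensive shape registers on the caller's side; a minimal, purely illustrative sketch (the peer IDs and the print stand in for real ping logic):

package main

import (
	"fmt"
	"sync"
)

func main() {
	peers := []string{"a", "b", "c"} // hypothetical peer IDs
	wg := &sync.WaitGroup{}
	for _, p := range peers {
		wg.Add(1) // register before go, not inside the goroutine
		go func(p string) {
			defer wg.Done()
			fmt.Println("ping", p) // stands in for the actual ping
		}(p)
	}
	wg.Wait() // cannot return before all three workers finish
}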

View File

@@ -5,6 +5,7 @@ import (
 	"encoding/hex"
 	"errors"
 	"fmt"
+	"sync"
 	logging "github.com/ipfs/go-log"
 	"github.com/libp2p/go-libp2p-core/host"
@@ -48,6 +49,7 @@ type (
 		h          host.Host
 		isFullNode bool
 		MsgC       chan *protocol.Envelope
+		wg         *sync.WaitGroup
 		filters     *FilterMap
 		subscribers *Subscribers
@@ -67,13 +69,16 @@ func NewWakuFilter(ctx context.Context, host host.Host, isFullNode bool) *WakuFi
 	wf := new(WakuFilter)
 	wf.ctx = ctx
-	wf.MsgC = make(chan *protocol.Envelope)
+	wf.wg = &sync.WaitGroup{}
+	wf.MsgC = make(chan *protocol.Envelope, 1024)
 	wf.h = host
 	wf.isFullNode = isFullNode
 	wf.filters = NewFilterMap()
 	wf.subscribers = NewSubscribers()
 	wf.h.SetStreamHandlerMatch(FilterID_v20beta1, protocol.PrefixTextMatch(string(FilterID_v20beta1)), wf.onRequest)
+	wf.wg.Add(1)
 	go wf.FilterListener()
 	if wf.isFullNode {
@@ -155,6 +160,8 @@ func (wf *WakuFilter) pushMessage(subscriber Subscriber, msg *pb.WakuMessage) er
 }
 func (wf *WakuFilter) FilterListener() {
+	defer wf.wg.Done()
 	// This function is invoked for each message received
 	// on the full node in context of Waku2-Filter
 	handle := func(envelope *protocol.Envelope) error { // async
@@ -189,7 +196,6 @@
 			log.Error("failed to handle message", err)
 		}
 	}
-
 }
 // Having a FilterRequest struct,
@@ -281,8 +287,11 @@ func (wf *WakuFilter) Unsubscribe(ctx context.Context, contentFilter ContentFilt
 }
 func (wf *WakuFilter) Stop() {
+	close(wf.MsgC)
 	wf.h.RemoveStreamHandler(FilterID_v20beta1)
 	wf.filters.RemoveAll()
+	wf.wg.Wait()
 }
 func (wf *WakuFilter) Subscribe(ctx context.Context, f ContentFilter, opts ...FilterSubscribeOption) (filterID string, theFilter Filter, err error) {
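
Note the shutdown handshake here: Stop now closes wf.MsgC before waiting, and since FilterListener consumes from MsgC, closing the channel lets the listener drain, return, and fire its deferred wg.Done(), so wg.Wait() can unblock. A minimal, runnable sketch of that close-to-drain pattern (names illustrative):

package main

import "sync"

func main() {
	msgC := make(chan int, 1024)
	wg := &sync.WaitGroup{}

	wg.Add(1)
	go func() {
		defer wg.Done()
		for m := range msgC { // loop ends once msgC is closed and drained
			_ = m // handle the message
		}
	}()

	close(msgC) // Stop: no more sends; the consumer drains and exits
	wg.Wait()   // returns only after the consumer has finished
}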

View File

@@ -16,6 +16,7 @@ type MessageQueue struct {
 	maxDuration time.Duration
 	quit        chan struct{}
+	wg          *sync.WaitGroup
 }
 func (self *MessageQueue) Push(msg IndexedWakuMessage) {
@@ -73,6 +74,8 @@ func (self *MessageQueue) cleanOlderRecords() {
 }
 func (self *MessageQueue) checkForOlderRecords(d time.Duration) {
+	defer self.wg.Done()
 	ticker := time.NewTicker(d)
 	defer ticker.Stop()
@@ -98,9 +101,11 @@ func NewMessageQueue(maxMessages int, maxDuration time.Duration) *MessageQueue {
 		maxDuration: maxDuration,
 		seen:        make(map[[32]byte]struct{}),
 		quit:        make(chan struct{}),
+		wg:          &sync.WaitGroup{},
 	}
 	if maxDuration != 0 {
+		result.wg.Add(1)
 		go result.checkForOlderRecords(10 * time.Second) // is 10s okay?
 	}
@@ -109,4 +114,5 @@ func NewMessageQueue(maxMessages int, maxDuration time.Duration) *MessageQueue {
 func (self *MessageQueue) Stop() {
 	close(self.quit)
+	self.wg.Wait()
 }

View File

@@ -8,6 +8,7 @@ import (
 	"fmt"
 	"math"
 	"sort"
+	"sync"
 	"time"
 	logging "github.com/ipfs/go-log"
@@ -227,6 +228,7 @@ type IndexedWakuMessage struct {
 type WakuStore struct {
 	ctx  context.Context
 	MsgC chan *protocol.Envelope
+	wg   *sync.WaitGroup
 	started bool
@@ -240,6 +242,7 @@ func NewWakuStore(host host.Host, p MessageProvider, maxNumberOfMessages int, ma
 	wakuStore := new(WakuStore)
 	wakuStore.msgProvider = p
 	wakuStore.h = host
+	wakuStore.wg = &sync.WaitGroup{}
 	wakuStore.messageQueue = NewMessageQueue(maxNumberOfMessages, maxRetentionDuration)
 	return wakuStore
 }
@@ -261,6 +264,7 @@ func (store *WakuStore) Start(ctx context.Context) {
 	store.h.SetStreamHandlerMatch(StoreID_v20beta3, protocol.PrefixTextMatch(string(StoreID_v20beta3)), store.onRequest)
+	store.wg.Add(1)
 	go store.storeIncomingMessages(ctx)
 	if store.msgProvider == nil {
@@ -327,6 +331,7 @@ func (store *WakuStore) storeMessage(env *protocol.Envelope) {
 }
 func (store *WakuStore) storeIncomingMessages(ctx context.Context) {
+	defer store.wg.Done()
 	for envelope := range store.MsgC {
 		store.storeMessage(envelope)
 	}
@@ -721,4 +726,6 @@ func (store *WakuStore) Stop() {
 	if store.h != nil {
 		store.h.RemoveStreamHandler(StoreID_v20beta3)
 	}
+	store.wg.Wait()
 }

View File

@@ -79,9 +79,10 @@ func SelectPeerWithLowestRTT(ctx context.Context, host host.Host, protocolId str
 	waitCh := make(chan struct{})
 	pingCh := make(chan pingResult, 1000)
+	wg.Add(len(peers))
 	go func() {
 		for _, p := range peers {
-			wg.Add(1)
 			go func(p peer.ID) {
 				defer wg.Done()
 				ctx, cancel := context.WithTimeout(ctx, 3*time.Second)
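
This last hunk moves the registration out of the spawning goroutine entirely: the worker count is known up front, so a single wg.Add(len(peers)) on the caller's side replaces the per-iteration wg.Add(1) that previously ran inside the outer go func(), where a concurrent wg.Wait() could in principle run before any Add had executed. A condensed, runnable sketch of the resulting shape (the ping bookkeeping is elided and the peer IDs are placeholders):

package main

import "sync"

func main() {
	peers := []string{"a", "b", "c"} // hypothetical peer IDs
	wg := &sync.WaitGroup{}

	wg.Add(len(peers)) // one batch registration, before anything is spawned
	go func() {
		for _, p := range peers {
			go func(p string) {
				defer wg.Done()
				_ = p // ping p and report its RTT (elided)
			}(p)
		}
	}()

	wg.Wait() // safe: the full count was registered before the spawner started
}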