refactor: use WaitGroup for graceful shutdown of worker goroutines (#166)

This commit is contained in:
Richard Ramos 2021-11-23 11:03:12 -04:00 committed by GitHub
parent c8caa46c99
commit ce417a6486
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 77 additions and 15 deletions

View File

@ -37,6 +37,8 @@ type DiscoveryV5 struct {
NAT nat.Interface NAT nat.Interface
quit chan struct{} quit chan struct{}
wg *sync.WaitGroup
peerCache peerCache peerCache peerCache
} }
@ -142,6 +144,7 @@ func NewDiscoveryV5(host host.Host, ipAddr net.IP, tcpPort int, priv *ecdsa.Priv
host: host, host: host,
params: params, params: params,
NAT: NAT, NAT: NAT,
wg: &sync.WaitGroup{},
peerCache: peerCache{ peerCache: peerCache{
rng: rand.New(rand.NewSource(rand.Int63())), rng: rand.New(rand.NewSource(rand.Int63())),
recs: make(map[peer.ID]peerRecord), recs: make(map[peer.ID]peerRecord),
@ -197,7 +200,9 @@ func (d *DiscoveryV5) listen() error {
d.udpAddr = conn.LocalAddr().(*net.UDPAddr) d.udpAddr = conn.LocalAddr().(*net.UDPAddr)
if d.NAT != nil && !d.udpAddr.IP.IsLoopback() { if d.NAT != nil && !d.udpAddr.IP.IsLoopback() {
d.wg.Add(1)
go func() { go func() {
defer d.wg.Done()
nat.Map(d.NAT, d.quit, "udp", d.udpAddr.Port, d.udpAddr.Port, "go-waku discv5 discovery") nat.Map(d.NAT, d.quit, "udp", d.udpAddr.Port, d.udpAddr.Port, "go-waku discv5 discovery")
}() }()
@ -222,13 +227,15 @@ func (d *DiscoveryV5) Start() error {
d.Lock() d.Lock()
defer d.Unlock() defer d.Unlock()
d.wg.Wait() // Waiting for other go routines to stop
d.quit = make(chan struct{}, 1)
err := d.listen() err := d.listen()
if err != nil { if err != nil {
return err return err
} }
d.quit = make(chan struct{})
return nil return nil
} }
@ -236,12 +243,14 @@ func (d *DiscoveryV5) Stop() {
d.Lock() d.Lock()
defer d.Unlock() defer d.Unlock()
close(d.quit)
d.listener.Close() d.listener.Close()
d.listener = nil d.listener = nil
close(d.quit)
log.Info("Stopped Discovery V5") log.Info("Stopped Discovery V5")
d.wg.Wait()
} }
// IsPrivate reports whether ip is a private address, according to // IsPrivate reports whether ip is a private address, according to
@ -354,6 +363,8 @@ func (c *DiscoveryV5) Advertise(ctx context.Context, ns string, opts ...discover
} }
func (d *DiscoveryV5) iterate(ctx context.Context, iterator enode.Iterator, limit int, doneCh chan struct{}) { func (d *DiscoveryV5) iterate(ctx context.Context, iterator enode.Iterator, limit int, doneCh chan struct{}) {
defer d.wg.Done()
for { for {
if len(d.peerCache.recs) >= limit { if len(d.peerCache.recs) >= limit {
break break
@ -435,6 +446,8 @@ func (d *DiscoveryV5) FindPeers(ctx context.Context, topic string, opts ...disco
defer iterator.Close() defer iterator.Close()
doneCh := make(chan struct{}) doneCh := make(chan struct{})
d.wg.Add(1)
go d.iterate(ctx, iterator, limit, doneCh) go d.iterate(ctx, iterator, limit, doneCh)
select { select {

View File

@ -63,6 +63,10 @@ func TestDiscV5(t *testing.T) {
d3, err := NewDiscoveryV5(host3, net.IPv4(127, 0, 0, 1), tcpPort3, prvKey3, NewWakuEnrBitfield(true, true, true, true), WithUDPPort(udpPort3), WithBootnodes([]*enode.Node{d2.localnode.Node()})) d3, err := NewDiscoveryV5(host3, net.IPv4(127, 0, 0, 1), tcpPort3, prvKey3, NewWakuEnrBitfield(true, true, true, true), WithUDPPort(udpPort3), WithBootnodes([]*enode.Node{d2.localnode.Node()}))
require.NoError(t, err) require.NoError(t, err)
defer d1.Stop()
defer d2.Stop()
defer d3.Stop()
err = d1.Start() err = d1.Start()
require.NoError(t, err) require.NoError(t, err)

View File

@ -84,6 +84,8 @@ func (w *WakuNode) sendConnStatus() {
} }
func (w *WakuNode) connectednessListener() { func (w *WakuNode) connectednessListener() {
defer w.wg.Done()
for { for {
select { select {
case <-w.quit: case <-w.quit:

View File

@ -2,6 +2,7 @@ package node
import ( import (
"context" "context"
"sync"
"testing" "testing"
"time" "time"
@ -28,7 +29,10 @@ func TestKeepAlive(t *testing.T) {
ctx2, cancel2 := context.WithTimeout(ctx, 3*time.Second) ctx2, cancel2 := context.WithTimeout(ctx, 3*time.Second)
defer cancel2() defer cancel2()
pingPeer(ctx2, host1, host2.ID())
wg := &sync.WaitGroup{}
pingPeer(ctx2, wg, host1, host2.ID())
require.NoError(t, ctx.Err()) require.NoError(t, ctx.Err())
} }

View File

@ -5,6 +5,7 @@ import (
"fmt" "fmt"
"net" "net"
"strconv" "strconv"
"sync"
"time" "time"
logging "github.com/ipfs/go-log" logging "github.com/ipfs/go-log"
@ -67,6 +68,7 @@ type WakuNode struct {
ctx context.Context ctx context.Context
cancel context.CancelFunc cancel context.CancelFunc
quit chan struct{} quit chan struct{}
wg *sync.WaitGroup
// Channel passed to WakuNode constructor // Channel passed to WakuNode constructor
// receiving connection status notifications // receiving connection status notifications
@ -122,6 +124,7 @@ func New(ctx context.Context, opts ...WakuNodeOption) (*WakuNode, error) {
w.ctx = ctx w.ctx = ctx
w.opts = params w.opts = params
w.quit = make(chan struct{}) w.quit = make(chan struct{})
w.wg = &sync.WaitGroup{}
w.addrChan = make(chan ma.Multiaddr, 1024) w.addrChan = make(chan ma.Multiaddr, 1024)
if w.protocolEventSub, err = host.EventBus().Subscribe(new(event.EvtPeerProtocolsUpdated)); err != nil { if w.protocolEventSub, err = host.EventBus().Subscribe(new(event.EvtPeerProtocolsUpdated)); err != nil {
@ -143,15 +146,16 @@ func New(ctx context.Context, opts ...WakuNodeOption) (*WakuNode, error) {
w.connectionNotif = NewConnectionNotifier(ctx, host) w.connectionNotif = NewConnectionNotifier(ctx, host)
w.host.Network().Notify(w.connectionNotif) w.host.Network().Notify(w.connectionNotif)
w.wg.Add(2)
go w.connectednessListener() go w.connectednessListener()
if w.opts.keepAliveInterval > time.Duration(0) {
w.startKeepAlive(w.opts.keepAliveInterval)
}
go w.checkForAddressChanges() go w.checkForAddressChanges()
go w.onAddrChange() go w.onAddrChange()
if w.opts.keepAliveInterval > time.Duration(0) {
w.wg.Add(1)
w.startKeepAlive(w.opts.keepAliveInterval)
}
return w, nil return w, nil
} }
@ -190,6 +194,8 @@ func (w *WakuNode) logAddress(addr ma.Multiaddr) {
} }
func (w *WakuNode) checkForAddressChanges() { func (w *WakuNode) checkForAddressChanges() {
defer w.wg.Done()
addrs := w.ListenAddresses() addrs := w.ListenAddresses()
first := make(chan struct{}, 1) first := make(chan struct{}, 1)
first <- struct{}{} first <- struct{}{}
@ -311,6 +317,8 @@ func (w *WakuNode) Stop() {
w.store.Stop() w.store.Stop()
w.host.Close() w.host.Close()
w.wg.Wait()
} }
func (w *WakuNode) Host() host.Host { func (w *WakuNode) Host() host.Host {
@ -425,7 +433,10 @@ func (w *WakuNode) startStore() {
if w.opts.shouldResume { if w.opts.shouldResume {
// TODO: extract this to a function and run it when you go offline // TODO: extract this to a function and run it when you go offline
// TODO: determine if a store is listening to a topic // TODO: determine if a store is listening to a topic
w.wg.Add(1)
go func() { go func() {
defer w.wg.Done()
ticker := time.NewTicker(time.Second) ticker := time.NewTicker(time.Second)
defer ticker.Stop() defer ticker.Stop()
@ -577,6 +588,8 @@ func (w *WakuNode) Peers() ([]*Peer, error) {
// This is necessary because TCP connections are automatically closed due to inactivity, // This is necessary because TCP connections are automatically closed due to inactivity,
// and doing a ping will avoid this (with a small bandwidth cost) // and doing a ping will avoid this (with a small bandwidth cost)
func (w *WakuNode) startKeepAlive(t time.Duration) { func (w *WakuNode) startKeepAlive(t time.Duration) {
defer w.wg.Done()
log.Info("Setting up ping protocol with duration of ", t) log.Info("Setting up ping protocol with duration of ", t)
ticker := time.NewTicker(t) ticker := time.NewTicker(t)
@ -594,7 +607,7 @@ func (w *WakuNode) startKeepAlive(t time.Duration) {
// through Network's peer collection, as it will be empty // through Network's peer collection, as it will be empty
for _, p := range w.host.Peerstore().Peers() { for _, p := range w.host.Peerstore().Peers() {
if p != w.host.ID() { if p != w.host.ID() {
go pingPeer(w.ctx, w.host, p) go pingPeer(w.ctx, w.wg, w.host, p)
} }
} }
case <-w.quit: case <-w.quit:
@ -604,7 +617,10 @@ func (w *WakuNode) startKeepAlive(t time.Duration) {
}() }()
} }
func pingPeer(ctx context.Context, host host.Host, peer peer.ID) { func pingPeer(ctx context.Context, wg *sync.WaitGroup, host host.Host, peer peer.ID) {
wg.Add(1)
defer wg.Done()
ctx, cancel := context.WithTimeout(ctx, 3*time.Second) ctx, cancel := context.WithTimeout(ctx, 3*time.Second)
defer cancel() defer cancel()

View File

@ -5,6 +5,7 @@ import (
"encoding/hex" "encoding/hex"
"errors" "errors"
"fmt" "fmt"
"sync"
logging "github.com/ipfs/go-log" logging "github.com/ipfs/go-log"
"github.com/libp2p/go-libp2p-core/host" "github.com/libp2p/go-libp2p-core/host"
@ -48,6 +49,7 @@ type (
h host.Host h host.Host
isFullNode bool isFullNode bool
MsgC chan *protocol.Envelope MsgC chan *protocol.Envelope
wg *sync.WaitGroup
filters *FilterMap filters *FilterMap
subscribers *Subscribers subscribers *Subscribers
@ -67,13 +69,16 @@ func NewWakuFilter(ctx context.Context, host host.Host, isFullNode bool) *WakuFi
wf := new(WakuFilter) wf := new(WakuFilter)
wf.ctx = ctx wf.ctx = ctx
wf.MsgC = make(chan *protocol.Envelope) wf.wg = &sync.WaitGroup{}
wf.MsgC = make(chan *protocol.Envelope, 1024)
wf.h = host wf.h = host
wf.isFullNode = isFullNode wf.isFullNode = isFullNode
wf.filters = NewFilterMap() wf.filters = NewFilterMap()
wf.subscribers = NewSubscribers() wf.subscribers = NewSubscribers()
wf.h.SetStreamHandlerMatch(FilterID_v20beta1, protocol.PrefixTextMatch(string(FilterID_v20beta1)), wf.onRequest) wf.h.SetStreamHandlerMatch(FilterID_v20beta1, protocol.PrefixTextMatch(string(FilterID_v20beta1)), wf.onRequest)
wf.wg.Add(1)
go wf.FilterListener() go wf.FilterListener()
if wf.isFullNode { if wf.isFullNode {
@ -155,6 +160,8 @@ func (wf *WakuFilter) pushMessage(subscriber Subscriber, msg *pb.WakuMessage) er
} }
func (wf *WakuFilter) FilterListener() { func (wf *WakuFilter) FilterListener() {
defer wf.wg.Done()
// This function is invoked for each message received // This function is invoked for each message received
// on the full node in context of Waku2-Filter // on the full node in context of Waku2-Filter
handle := func(envelope *protocol.Envelope) error { // async handle := func(envelope *protocol.Envelope) error { // async
@ -189,7 +196,6 @@ func (wf *WakuFilter) FilterListener() {
log.Error("failed to handle message", err) log.Error("failed to handle message", err)
} }
} }
} }
// Having a FilterRequest struct, // Having a FilterRequest struct,
@ -281,8 +287,11 @@ func (wf *WakuFilter) Unsubscribe(ctx context.Context, contentFilter ContentFilt
} }
func (wf *WakuFilter) Stop() { func (wf *WakuFilter) Stop() {
close(wf.MsgC)
wf.h.RemoveStreamHandler(FilterID_v20beta1) wf.h.RemoveStreamHandler(FilterID_v20beta1)
wf.filters.RemoveAll() wf.filters.RemoveAll()
wf.wg.Wait()
} }
func (wf *WakuFilter) Subscribe(ctx context.Context, f ContentFilter, opts ...FilterSubscribeOption) (filterID string, theFilter Filter, err error) { func (wf *WakuFilter) Subscribe(ctx context.Context, f ContentFilter, opts ...FilterSubscribeOption) (filterID string, theFilter Filter, err error) {

View File

@ -16,6 +16,7 @@ type MessageQueue struct {
maxDuration time.Duration maxDuration time.Duration
quit chan struct{} quit chan struct{}
wg *sync.WaitGroup
} }
func (self *MessageQueue) Push(msg IndexedWakuMessage) { func (self *MessageQueue) Push(msg IndexedWakuMessage) {
@ -73,6 +74,8 @@ func (self *MessageQueue) cleanOlderRecords() {
} }
func (self *MessageQueue) checkForOlderRecords(d time.Duration) { func (self *MessageQueue) checkForOlderRecords(d time.Duration) {
defer self.wg.Done()
ticker := time.NewTicker(d) ticker := time.NewTicker(d)
defer ticker.Stop() defer ticker.Stop()
@ -98,9 +101,11 @@ func NewMessageQueue(maxMessages int, maxDuration time.Duration) *MessageQueue {
maxDuration: maxDuration, maxDuration: maxDuration,
seen: make(map[[32]byte]struct{}), seen: make(map[[32]byte]struct{}),
quit: make(chan struct{}), quit: make(chan struct{}),
wg: &sync.WaitGroup{},
} }
if maxDuration != 0 { if maxDuration != 0 {
result.wg.Add(1)
go result.checkForOlderRecords(10 * time.Second) // is 10s okay? go result.checkForOlderRecords(10 * time.Second) // is 10s okay?
} }
@ -109,4 +114,5 @@ func NewMessageQueue(maxMessages int, maxDuration time.Duration) *MessageQueue {
func (self *MessageQueue) Stop() { func (self *MessageQueue) Stop() {
close(self.quit) close(self.quit)
self.wg.Wait()
} }

View File

@ -8,6 +8,7 @@ import (
"fmt" "fmt"
"math" "math"
"sort" "sort"
"sync"
"time" "time"
logging "github.com/ipfs/go-log" logging "github.com/ipfs/go-log"
@ -227,6 +228,7 @@ type IndexedWakuMessage struct {
type WakuStore struct { type WakuStore struct {
ctx context.Context ctx context.Context
MsgC chan *protocol.Envelope MsgC chan *protocol.Envelope
wg *sync.WaitGroup
started bool started bool
@ -240,6 +242,7 @@ func NewWakuStore(host host.Host, p MessageProvider, maxNumberOfMessages int, ma
wakuStore := new(WakuStore) wakuStore := new(WakuStore)
wakuStore.msgProvider = p wakuStore.msgProvider = p
wakuStore.h = host wakuStore.h = host
wakuStore.wg = &sync.WaitGroup{}
wakuStore.messageQueue = NewMessageQueue(maxNumberOfMessages, maxRetentionDuration) wakuStore.messageQueue = NewMessageQueue(maxNumberOfMessages, maxRetentionDuration)
return wakuStore return wakuStore
} }
@ -261,6 +264,7 @@ func (store *WakuStore) Start(ctx context.Context) {
store.h.SetStreamHandlerMatch(StoreID_v20beta3, protocol.PrefixTextMatch(string(StoreID_v20beta3)), store.onRequest) store.h.SetStreamHandlerMatch(StoreID_v20beta3, protocol.PrefixTextMatch(string(StoreID_v20beta3)), store.onRequest)
store.wg.Add(1)
go store.storeIncomingMessages(ctx) go store.storeIncomingMessages(ctx)
if store.msgProvider == nil { if store.msgProvider == nil {
@ -327,6 +331,7 @@ func (store *WakuStore) storeMessage(env *protocol.Envelope) {
} }
func (store *WakuStore) storeIncomingMessages(ctx context.Context) { func (store *WakuStore) storeIncomingMessages(ctx context.Context) {
defer store.wg.Done()
for envelope := range store.MsgC { for envelope := range store.MsgC {
store.storeMessage(envelope) store.storeMessage(envelope)
} }
@ -721,4 +726,6 @@ func (store *WakuStore) Stop() {
if store.h != nil { if store.h != nil {
store.h.RemoveStreamHandler(StoreID_v20beta3) store.h.RemoveStreamHandler(StoreID_v20beta3)
} }
store.wg.Wait()
} }

View File

@ -79,9 +79,10 @@ func SelectPeerWithLowestRTT(ctx context.Context, host host.Host, protocolId str
waitCh := make(chan struct{}) waitCh := make(chan struct{})
pingCh := make(chan pingResult, 1000) pingCh := make(chan pingResult, 1000)
wg.Add(len(peers))
go func() { go func() {
for _, p := range peers { for _, p := range peers {
wg.Add(1)
go func(p peer.ID) { go func(p peer.ID) {
defer wg.Done() defer wg.Done()
ctx, cancel := context.WithTimeout(ctx, 3*time.Second) ctx, cancel := context.WithTimeout(ctx, 3*time.Second)