From 83fd53d432ed1277536c8f47475b7e1ee54b6c38 Mon Sep 17 00:00:00 2001 From: Richard Ramos Date: Sat, 10 Dec 2022 11:38:18 -0400 Subject: [PATCH] refactor: use context instead of quit channel --- waku/node.go | 2 +- waku/v2/discv5/discover.go | 59 +++------ waku/v2/discv5/discover_test.go | 14 +- waku/v2/node/connectedness.go | 2 +- waku/v2/node/keepalive.go | 3 +- waku/v2/node/localnode.go | 2 +- waku/v2/node/wakunode2.go | 18 +-- waku/v2/node/wakunode2_rln.go | 3 +- .../peer_exchange/waku_peer_exchange.go | 125 ++++++++---------- .../peer_exchange/waku_peer_exchange_test.go | 16 +-- waku/v2/protocol/rln/web3.go | 2 +- waku/v2/timesource/ntp.go | 26 ++-- waku/v2/timesource/ntp_test.go | 8 +- waku/v2/timesource/timesource.go | 7 +- waku/v2/timesource/wall.go | 7 +- 15 files changed, 137 insertions(+), 157 deletions(-) diff --git a/waku/node.go b/waku/node.go index fcf5a0e5..031e27b6 100644 --- a/waku/node.go +++ b/waku/node.go @@ -297,7 +297,7 @@ func Execute(options Options) { } if options.DiscV5.Enable { - if err = wakuNode.DiscV5().Start(); err != nil { + if err = wakuNode.DiscV5().Start(ctx); err != nil { logger.Fatal("starting discovery v5", zap.Error(err)) } } diff --git a/waku/v2/discv5/discover.go b/waku/v2/discv5/discover.go index f29de566..a0b19b1a 100644 --- a/waku/v2/discv5/discover.go +++ b/waku/v2/discv5/discover.go @@ -27,23 +27,20 @@ type DiscoveryV5 struct { discovery.Discovery params *discV5Parameters - ctx context.Context host host.Host config discover.Config udpAddr *net.UDPAddr listener *discover.UDPv5 localnode *enode.LocalNode NAT nat.Interface - quit chan struct{} - started bool log *zap.Logger - wg *sync.WaitGroup + started bool + cancel context.CancelFunc + wg *sync.WaitGroup - peerCache peerCache - discoverCtx context.Context - discoverCancelFunc context.CancelFunc + peerCache peerCache } type peerCache struct { @@ -101,7 +98,7 @@ func DefaultOptions() []DiscoveryV5Option { const MaxPeersToDiscover = 600 -func NewDiscoveryV5(ctx context.Context, host host.Host, priv *ecdsa.PrivateKey, localnode *enode.LocalNode, log *zap.Logger, opts ...DiscoveryV5Option) (*DiscoveryV5, error) { +func NewDiscoveryV5(host host.Host, priv *ecdsa.PrivateKey, localnode *enode.LocalNode, log *zap.Logger, opts ...DiscoveryV5Option) (*DiscoveryV5, error) { params := new(discV5Parameters) optList := DefaultOptions() optList = append(optList, opts...) @@ -117,7 +114,6 @@ func NewDiscoveryV5(ctx context.Context, host host.Host, priv *ecdsa.PrivateKey, } return &DiscoveryV5{ - ctx: ctx, host: host, params: params, NAT: NAT, @@ -150,7 +146,7 @@ func (d *DiscoveryV5) Node() *enode.Node { return d.localnode.Node() } -func (d *DiscoveryV5) listen() error { +func (d *DiscoveryV5) listen(ctx context.Context) error { conn, err := net.ListenUDP("udp", d.udpAddr) if err != nil { return err @@ -161,7 +157,7 @@ func (d *DiscoveryV5) listen() error { d.wg.Add(1) go func() { defer d.wg.Done() - nat.Map(d.NAT, d.quit, "udp", d.udpAddr.Port, d.udpAddr.Port, "go-waku discv5 discovery") + nat.Map(d.NAT, ctx.Done(), "udp", d.udpAddr.Port, d.udpAddr.Port, "go-waku discv5 discovery") }() } @@ -183,27 +179,22 @@ func (d *DiscoveryV5) listen() error { return nil } -func (d *DiscoveryV5) Start() error { +func (d *DiscoveryV5) Start(ctx context.Context) error { d.Lock() defer d.Unlock() - if d.started { - return nil - } + d.wg.Wait() // Waiting for any go routines to stop + ctx, cancel := context.WithCancel(ctx) - d.wg.Wait() // Waiting for other go routines to stop - - d.quit = make(chan struct{}, 1) + d.cancel = cancel d.started = true - err := d.listen() + err := d.listen(ctx) if err != nil { return err } - // create cancellable - d.discoverCtx, d.discoverCancelFunc = context.WithCancel(d.ctx) - go d.runDiscoveryV5Loop() + go d.runDiscoveryV5Loop(ctx) return nil } @@ -216,12 +207,7 @@ func (d *DiscoveryV5) Stop() { d.Lock() defer d.Unlock() - if !d.started { - return - } - - close(d.quit) - d.discoverCancelFunc() + d.cancel() d.listener.Close() d.listener = nil @@ -295,7 +281,7 @@ func (d *DiscoveryV5) Advertise(ctx context.Context, ns string, opts ...discover return 20 * time.Minute, nil } -func (d *DiscoveryV5) iterate(iterator enode.Iterator, limit int, doneCh chan struct{}) { +func (d *DiscoveryV5) iterate(ctx context.Context, iterator enode.Iterator, limit int) { defer d.wg.Done() for { @@ -303,7 +289,7 @@ func (d *DiscoveryV5) iterate(iterator enode.Iterator, limit int, doneCh chan st break } - if d.discoverCtx.Err() != nil { + if ctx.Err() != nil { break } @@ -334,8 +320,6 @@ func (d *DiscoveryV5) iterate(iterator enode.Iterator, limit int, doneCh chan st } d.peerCache.Unlock() } - - close(doneCh) } func (d *DiscoveryV5) removeExpiredPeers() int { @@ -354,21 +338,16 @@ func (d *DiscoveryV5) removeExpiredPeers() int { return newCacheSize } -func (d *DiscoveryV5) runDiscoveryV5Loop() { +func (d *DiscoveryV5) runDiscoveryV5Loop(ctx context.Context) { iterator := d.listener.RandomNodes() iterator = enode.Filter(iterator, evaluateNode) defer iterator.Close() - doneCh := make(chan struct{}) - d.wg.Add(1) - go d.iterate(iterator, MaxPeersToDiscover, doneCh) + go d.iterate(ctx, iterator, MaxPeersToDiscover) - select { - case <-d.discoverCtx.Done(): - case <-doneCh: - } + <-ctx.Done() d.log.Warn("Discv5 loop stopped") } diff --git a/waku/v2/discv5/discover_test.go b/waku/v2/discv5/discover_test.go index d707da2a..504e86c5 100644 --- a/waku/v2/discv5/discover_test.go +++ b/waku/v2/discv5/discover_test.go @@ -106,7 +106,7 @@ func TestDiscV5(t *testing.T) { ip1, _ := extractIP(host1.Addrs()[0]) l1, err := newLocalnode(prvKey1, ip1, udpPort1, utils.NewWakuEnrBitfield(true, true, true, true), nil, utils.Logger()) require.NoError(t, err) - d1, err := NewDiscoveryV5(context.Background(), host1, prvKey1, l1, utils.Logger(), WithUDPPort(udpPort1)) + d1, err := NewDiscoveryV5(host1, prvKey1, l1, utils.Logger(), WithUDPPort(udpPort1)) require.NoError(t, err) // H2 @@ -116,7 +116,7 @@ func TestDiscV5(t *testing.T) { require.NoError(t, err) l2, err := newLocalnode(prvKey2, ip2, udpPort2, utils.NewWakuEnrBitfield(true, true, true, true), nil, utils.Logger()) require.NoError(t, err) - d2, err := NewDiscoveryV5(context.Background(), host2, prvKey2, l2, utils.Logger(), WithUDPPort(udpPort2), WithBootnodes([]*enode.Node{d1.localnode.Node()})) + d2, err := NewDiscoveryV5(host2, prvKey2, l2, utils.Logger(), WithUDPPort(udpPort2), WithBootnodes([]*enode.Node{d1.localnode.Node()})) require.NoError(t, err) // H3 @@ -126,20 +126,20 @@ func TestDiscV5(t *testing.T) { require.NoError(t, err) l3, err := newLocalnode(prvKey3, ip3, udpPort3, utils.NewWakuEnrBitfield(true, true, true, true), nil, utils.Logger()) require.NoError(t, err) - d3, err := NewDiscoveryV5(context.Background(), host3, prvKey3, l3, utils.Logger(), WithUDPPort(udpPort3), WithBootnodes([]*enode.Node{d2.localnode.Node()})) + d3, err := NewDiscoveryV5(host3, prvKey3, l3, utils.Logger(), WithUDPPort(udpPort3), WithBootnodes([]*enode.Node{d2.localnode.Node()})) require.NoError(t, err) defer d1.Stop() defer d2.Stop() defer d3.Stop() - err = d1.Start() + err = d1.Start(context.Background()) require.NoError(t, err) - err = d2.Start() + err = d2.Start(context.Background()) require.NoError(t, err) - err = d3.Start() + err = d3.Start(context.Background()) require.NoError(t, err) time.Sleep(3 * time.Second) // Wait for nodes to be discovered @@ -205,7 +205,7 @@ func TestDiscV5(t *testing.T) { } // Restart peer search - err = d3.Start() + err = d3.Start(context.Background()) require.NoError(t, err) time.Sleep(3 * time.Second) // Wait for nodes to be discovered diff --git a/waku/v2/node/connectedness.go b/waku/v2/node/connectedness.go index cf1a5ea3..79dbd957 100644 --- a/waku/v2/node/connectedness.go +++ b/waku/v2/node/connectedness.go @@ -96,7 +96,7 @@ func (w *WakuNode) connectednessListener() { for { select { - case <-w.quit: + case <-w.ctx.Done(): return case <-w.protocolEventSub.Out(): case <-w.identificationEventSub.Out(): diff --git a/waku/v2/node/keepalive.go b/waku/v2/node/keepalive.go index 415944ae..fd6cbff5 100644 --- a/waku/v2/node/keepalive.go +++ b/waku/v2/node/keepalive.go @@ -54,7 +54,8 @@ func (w *WakuNode) startKeepAlive(t time.Duration) { } lastTimeExecuted = w.timesource.Now() - case <-w.quit: + case <-w.ctx.Done(): + w.log.Info("stopping ping protocol") return } } diff --git a/waku/v2/node/localnode.go b/waku/v2/node/localnode.go index 4ae3683c..8b0be2e7 100644 --- a/waku/v2/node/localnode.go +++ b/waku/v2/node/localnode.go @@ -241,7 +241,7 @@ func (w *WakuNode) setupENR(addrs []ma.Multiaddr) error { if w.discoveryV5 != nil && w.discoveryV5.IsStarted() { w.log.Info("restarting discv5") w.discoveryV5.Stop() - err = w.discoveryV5.Start() + err = w.discoveryV5.Start(w.ctx) if err != nil { w.log.Error("could not restart discv5", zap.Error(err)) return err diff --git a/waku/v2/node/wakunode2.go b/waku/v2/node/wakunode2.go index 92bd4661..9313be43 100644 --- a/waku/v2/node/wakunode2.go +++ b/waku/v2/node/wakunode2.go @@ -95,9 +95,8 @@ type WakuNode struct { keepAliveMutex sync.Mutex keepAliveFails map[peer.ID]int - ctx context.Context + ctx context.Context // TODO: remove this cancel context.CancelFunc - quit chan struct{} wg *sync.WaitGroup // Channel passed to WakuNode constructor @@ -171,7 +170,6 @@ func New(ctx context.Context, opts ...WakuNodeOption) (*WakuNode, error) { w.ctx = ctx w.opts = params w.log = params.logger.Named("node2") - w.quit = make(chan struct{}) w.wg = &sync.WaitGroup{} w.addrChan = make(chan ma.Multiaddr, 1024) w.keepAliveFails = make(map[peer.ID]int) @@ -236,7 +234,7 @@ func (w *WakuNode) checkForAddressChanges() { first <- struct{}{} for { select { - case <-w.quit: + case <-w.ctx.Done(): close(w.addrChan) return case <-first: @@ -269,7 +267,7 @@ func (w *WakuNode) checkForAddressChanges() { // Start initializes all the protocols that were setup in the WakuNode func (w *WakuNode) Start() error { if w.opts.enableNTP { - err := w.timesource.Start() + err := w.timesource.Start(w.ctx) if err != nil { return err } @@ -358,9 +356,7 @@ func (w *WakuNode) Start() error { // Stop stops the WakuNode and closess all connections to the host func (w *WakuNode) Stop() { - defer w.cancel() - - close(w.quit) + w.cancel() w.bcaster.Close() @@ -524,14 +520,14 @@ func (w *WakuNode) mountDiscV5() error { } var err error - w.discoveryV5, err = discv5.NewDiscoveryV5(w.ctx, w.Host(), w.opts.privKey, w.localNode, w.log, discV5Options...) + w.discoveryV5, err = discv5.NewDiscoveryV5(w.Host(), w.opts.privKey, w.localNode, w.log, discV5Options...) return err } func (w *WakuNode) mountPeerExchange() error { - w.peerExchange = peer_exchange.NewWakuPeerExchange(w.ctx, w.host, w.discoveryV5, w.log) - return w.peerExchange.Start() + w.peerExchange = peer_exchange.NewWakuPeerExchange(w.host, w.discoveryV5, w.log) + return w.peerExchange.Start(w.ctx) } func (w *WakuNode) startStore() error { diff --git a/waku/v2/node/wakunode2_rln.go b/waku/v2/node/wakunode2_rln.go index a1a35af4..86f792f5 100644 --- a/waku/v2/node/wakunode2_rln.go +++ b/waku/v2/node/wakunode2_rln.go @@ -4,7 +4,6 @@ package node import ( - "context" "encoding/hex" "errors" @@ -81,7 +80,7 @@ func (w *WakuNode) mountRlnRelay() error { // mount the rln relay protocol in the on-chain/dynamic mode var err error - w.rlnRelay, err = rln.RlnRelayDynamic(context.Background(), w.relay, w.opts.rlnETHClientAddress, w.opts.rlnETHPrivateKey, w.opts.rlnMembershipContractAddress, memKeyPair, w.opts.rlnRelayMemIndex, w.opts.rlnRelayPubsubTopic, w.opts.rlnRelayContentTopic, w.opts.rlnSpamHandler, w.opts.rlnRegistrationHandler, w.timesource, w.log) + w.rlnRelay, err = rln.RlnRelayDynamic(w.ctx, w.relay, w.opts.rlnETHClientAddress, w.opts.rlnETHPrivateKey, w.opts.rlnMembershipContractAddress, memKeyPair, w.opts.rlnRelayMemIndex, w.opts.rlnRelayPubsubTopic, w.opts.rlnRelayContentTopic, w.opts.rlnSpamHandler, w.opts.rlnRegistrationHandler, w.timesource, w.log) if err != nil { return err } diff --git a/waku/v2/protocol/peer_exchange/waku_peer_exchange.go b/waku/v2/protocol/peer_exchange/waku_peer_exchange.go index 06c84546..e5c3732d 100644 --- a/waku/v2/protocol/peer_exchange/waku_peer_exchange.go +++ b/waku/v2/protocol/peer_exchange/waku_peer_exchange.go @@ -46,24 +46,21 @@ type peerRecord struct { type WakuPeerExchange struct { h host.Host - ctx context.Context disc *discv5.DiscoveryV5 - log *zap.Logger - quit chan struct{} - wg sync.WaitGroup + log *zap.Logger + + cancel context.CancelFunc + wg sync.WaitGroup enrCache map[enode.ID]peerRecord // todo: next step: ring buffer; future: implement cache satisfying https://rfc.vac.dev/spec/34/ enrCacheMutex sync.RWMutex rng *rand.Rand - - started bool } // NewWakuPeerExchange returns a new instance of WakuPeerExchange struct -func NewWakuPeerExchange(ctx context.Context, h host.Host, disc *discv5.DiscoveryV5, log *zap.Logger) *WakuPeerExchange { +func NewWakuPeerExchange(h host.Host, disc *discv5.DiscoveryV5, log *zap.Logger) *WakuPeerExchange { wakuPX := new(WakuPeerExchange) - wakuPX.ctx = ctx wakuPX.h = h wakuPX.disc = disc wakuPX.log = log.Named("wakupx") @@ -73,19 +70,21 @@ func NewWakuPeerExchange(ctx context.Context, h host.Host, disc *discv5.Discover } // Start inits the peer exchange protocol -func (wakuPX *WakuPeerExchange) Start() error { - wakuPX.h.SetStreamHandlerMatch(PeerExchangeID_v20alpha1, protocol.PrefixTextMatch(string(PeerExchangeID_v20alpha1)), wakuPX.onRequest) +func (wakuPX *WakuPeerExchange) Start(ctx context.Context) error { + wakuPX.wg.Wait() // Waiting for any go routines to stop + ctx, cancel := context.WithCancel(ctx) + wakuPX.cancel = cancel + + wakuPX.h.SetStreamHandlerMatch(PeerExchangeID_v20alpha1, protocol.PrefixTextMatch(string(PeerExchangeID_v20alpha1)), wakuPX.onRequest(ctx)) wakuPX.log.Info("Peer exchange protocol started") - wakuPX.started = true - wakuPX.quit = make(chan struct{}, 1) wakuPX.wg.Add(1) - go wakuPX.runPeerExchangeDiscv5Loop() + go wakuPX.runPeerExchangeDiscv5Loop(ctx) return nil } -func (wakuPX *WakuPeerExchange) handleResponse(response *pb.PeerExchangeResponse) error { +func (wakuPX *WakuPeerExchange) handleResponse(ctx context.Context, response *pb.PeerExchangeResponse) error { var peers []peer.AddrInfo for _, p := range response.PeerInfos { enrRecord := &enr.Record{} @@ -118,7 +117,7 @@ func (wakuPX *WakuPeerExchange) handleResponse(response *pb.PeerExchangeResponse log.Info("connecting to newly discovered peers", zap.Int("count", len(peers))) for _, p := range peers { func(p peer.AddrInfo) { - ctx, cancel := context.WithTimeout(wakuPX.ctx, dialTimeout) + ctx, cancel := context.WithTimeout(ctx, dialTimeout) defer cancel() err := wakuPX.h.Connect(ctx, p) if err != nil { @@ -131,35 +130,37 @@ func (wakuPX *WakuPeerExchange) handleResponse(response *pb.PeerExchangeResponse return nil } -func (wakuPX *WakuPeerExchange) onRequest(s network.Stream) { - defer s.Close() - logger := wakuPX.log.With(logging.HostID("peer", s.Conn().RemotePeer())) - requestRPC := &pb.PeerExchangeRPC{} - reader := protoio.NewDelimitedReader(s, math.MaxInt32) - err := reader.ReadMsg(requestRPC) - if err != nil { - logger.Error("reading request", zap.Error(err)) - metrics.RecordPeerExchangeError(wakuPX.ctx, "decodeRpcFailure") - return - } - - if requestRPC.Query != nil { - logger.Info("request received") - err := wakuPX.respond(requestRPC.Query.NumPeers, s.Conn().RemotePeer()) +func (wakuPX *WakuPeerExchange) onRequest(ctx context.Context) func(s network.Stream) { + return func(s network.Stream) { + defer s.Close() + logger := wakuPX.log.With(logging.HostID("peer", s.Conn().RemotePeer())) + requestRPC := &pb.PeerExchangeRPC{} + reader := protoio.NewDelimitedReader(s, math.MaxInt32) + err := reader.ReadMsg(requestRPC) if err != nil { - logger.Error("responding", zap.Error(err)) - metrics.RecordPeerExchangeError(wakuPX.ctx, "pxFailure") + logger.Error("reading request", zap.Error(err)) + metrics.RecordPeerExchangeError(ctx, "decodeRpcFailure") return } - } - if requestRPC.Response != nil { - logger.Info("response received") - err := wakuPX.handleResponse(requestRPC.Response) - if err != nil { - logger.Error("handling response", zap.Error(err)) - metrics.RecordPeerExchangeError(wakuPX.ctx, "pxFailure") - return + if requestRPC.Query != nil { + logger.Info("request received") + err := wakuPX.respond(ctx, requestRPC.Query.NumPeers, s.Conn().RemotePeer()) + if err != nil { + logger.Error("responding", zap.Error(err)) + metrics.RecordPeerExchangeError(ctx, "pxFailure") + return + } + } + + if requestRPC.Response != nil { + logger.Info("response received") + err := wakuPX.handleResponse(ctx, requestRPC.Response) + if err != nil { + logger.Error("handling response", zap.Error(err)) + metrics.RecordPeerExchangeError(ctx, "pxFailure") + return + } } } } @@ -176,7 +177,7 @@ func (wakuPX *WakuPeerExchange) Request(ctx context.Context, numPeers int, opts } if params.selectedPeer == "" { - metrics.RecordPeerExchangeError(wakuPX.ctx, "dialError") + metrics.RecordPeerExchangeError(ctx, "dialError") return ErrNoPeersAvailable } @@ -186,35 +187,27 @@ func (wakuPX *WakuPeerExchange) Request(ctx context.Context, numPeers int, opts }, } - return wakuPX.sendPeerExchangeRPCToPeer(requestRPC, params.selectedPeer) -} - -// IsStarted returns if the peer exchange protocol has been mounted or not -func (wakuPX *WakuPeerExchange) IsStarted() bool { - return wakuPX.started + return wakuPX.sendPeerExchangeRPCToPeer(ctx, requestRPC, params.selectedPeer) } // Stop unmounts the peer exchange protocol func (wakuPX *WakuPeerExchange) Stop() { - if wakuPX.started { - wakuPX.h.RemoveStreamHandler(PeerExchangeID_v20alpha1) - wakuPX.started = false - close(wakuPX.quit) - wakuPX.wg.Wait() - } + wakuPX.cancel() + wakuPX.h.RemoveStreamHandler(PeerExchangeID_v20alpha1) + wakuPX.wg.Wait() } -func (wakuPX *WakuPeerExchange) sendPeerExchangeRPCToPeer(rpc *pb.PeerExchangeRPC, peerID peer.ID) error { +func (wakuPX *WakuPeerExchange) sendPeerExchangeRPCToPeer(ctx context.Context, rpc *pb.PeerExchangeRPC, peerID peer.ID) error { logger := wakuPX.log.With(logging.HostID("peer", peerID)) // We connect first so dns4 addresses are resolved (NewStream does not do it) - err := wakuPX.h.Connect(wakuPX.ctx, wakuPX.h.Peerstore().PeerInfo(peerID)) + err := wakuPX.h.Connect(ctx, wakuPX.h.Peerstore().PeerInfo(peerID)) if err != nil { logger.Error("connecting peer", zap.Error(err)) return err } - connOpt, err := wakuPX.h.NewStream(wakuPX.ctx, peerID, PeerExchangeID_v20alpha1) + connOpt, err := wakuPX.h.NewStream(ctx, peerID, PeerExchangeID_v20alpha1) if err != nil { logger.Error("creating stream to peer", zap.Error(err)) return err @@ -231,7 +224,7 @@ func (wakuPX *WakuPeerExchange) sendPeerExchangeRPCToPeer(rpc *pb.PeerExchangeRP return nil } -func (wakuPX *WakuPeerExchange) respond(numPeers uint64, peerID peer.ID) error { +func (wakuPX *WakuPeerExchange) respond(ctx context.Context, numPeers uint64, peerID peer.ID) error { records, err := wakuPX.getENRsFromCache(numPeers) if err != nil { return err @@ -241,7 +234,7 @@ func (wakuPX *WakuPeerExchange) respond(numPeers uint64, peerID peer.ID) error { responseRPC.Response = new(pb.PeerExchangeResponse) responseRPC.Response.PeerInfos = records - return wakuPX.sendPeerExchangeRPCToPeer(responseRPC, peerID) + return wakuPX.sendPeerExchangeRPCToPeer(ctx, responseRPC, peerID) } func (wakuPX *WakuPeerExchange) getENRsFromCache(numPeers uint64) ([]*pb.PeerInfo, error) { @@ -304,12 +297,8 @@ func (wakuPX *WakuPeerExchange) cleanCache() { wakuPX.enrCache = r } -func (wakuPX *WakuPeerExchange) findPeers() { - if !wakuPX.disc.IsStarted() { - return - } - - ctx, cancel := context.WithTimeout(wakuPX.ctx, 2*time.Second) +func (wakuPX *WakuPeerExchange) findPeers(ctx context.Context) { + ctx, cancel := context.WithTimeout(ctx, 2*time.Second) defer cancel() peerRecords, err := wakuPX.disc.FindNodes(ctx, "") if err != nil { @@ -332,7 +321,7 @@ func (wakuPX *WakuPeerExchange) findPeers() { wakuPX.cleanCache() } -func (wakuPX *WakuPeerExchange) runPeerExchangeDiscv5Loop() { +func (wakuPX *WakuPeerExchange) runPeerExchangeDiscv5Loop(ctx context.Context) { defer wakuPX.wg.Done() // Runs a discv5 loop adding new peers to the px peer cache @@ -349,15 +338,15 @@ func (wakuPX *WakuPeerExchange) runPeerExchangeDiscv5Loop() { // This loop "competes" with the loop in wakunode2 // For the purpose of collecting px peers, 30 sec intervals should be enough - wakuPX.findPeers() + wakuPX.findPeers(ctx) for { select { - case <-wakuPX.quit: + case <-ctx.Done(): return case <-ticker.C: - wakuPX.findPeers() + wakuPX.findPeers(ctx) } } diff --git a/waku/v2/protocol/peer_exchange/waku_peer_exchange_test.go b/waku/v2/protocol/peer_exchange/waku_peer_exchange_test.go index e5a5be26..e473cc24 100644 --- a/waku/v2/protocol/peer_exchange/waku_peer_exchange_test.go +++ b/waku/v2/protocol/peer_exchange/waku_peer_exchange_test.go @@ -105,7 +105,7 @@ func TestRetrieveProvidePeerExchangePeers(t *testing.T) { ip1, _ := extractIP(host1.Addrs()[0]) l1, err := newLocalnode(prvKey1, ip1, udpPort1, utils.NewWakuEnrBitfield(false, false, false, true), nil, utils.Logger()) require.NoError(t, err) - d1, err := discv5.NewDiscoveryV5(context.Background(), host1, prvKey1, l1, utils.Logger(), discv5.WithUDPPort(udpPort1)) + d1, err := discv5.NewDiscoveryV5(host1, prvKey1, l1, utils.Logger(), discv5.WithUDPPort(udpPort1)) require.NoError(t, err) // H2 @@ -115,7 +115,7 @@ func TestRetrieveProvidePeerExchangePeers(t *testing.T) { require.NoError(t, err) l2, err := newLocalnode(prvKey2, ip2, udpPort2, utils.NewWakuEnrBitfield(false, false, false, true), nil, utils.Logger()) require.NoError(t, err) - d2, err := discv5.NewDiscoveryV5(context.Background(), host2, prvKey2, l2, utils.Logger(), discv5.WithUDPPort(udpPort2), discv5.WithBootnodes([]*enode.Node{d1.Node()})) + d2, err := discv5.NewDiscoveryV5(host2, prvKey2, l2, utils.Logger(), discv5.WithUDPPort(udpPort2), discv5.WithBootnodes([]*enode.Node{d1.Node()})) require.NoError(t, err) // H3 @@ -127,22 +127,22 @@ func TestRetrieveProvidePeerExchangePeers(t *testing.T) { defer host2.Close() defer host3.Close() - err = d1.Start() + err = d1.Start(context.Background()) require.NoError(t, err) - err = d2.Start() + err = d2.Start(context.Background()) require.NoError(t, err) time.Sleep(3 * time.Second) // Wait some time for peers to be discovered // mount peer exchange - px1 := NewWakuPeerExchange(context.Background(), host1, d1, utils.Logger()) - px3 := NewWakuPeerExchange(context.Background(), host3, nil, utils.Logger()) + px1 := NewWakuPeerExchange(host1, d1, utils.Logger()) + px3 := NewWakuPeerExchange(host3, nil, utils.Logger()) - err = px1.Start() + err = px1.Start(context.Background()) require.NoError(t, err) - err = px3.Start() + err = px3.Start(context.Background()) require.NoError(t, err) host3.Peerstore().AddAddrs(host1.ID(), host1.Addrs(), peerstore.PermanentAddrTTL) diff --git a/waku/v2/protocol/rln/web3.go b/waku/v2/protocol/rln/web3.go index d6d85450..055ec611 100644 --- a/waku/v2/protocol/rln/web3.go +++ b/waku/v2/protocol/rln/web3.go @@ -33,7 +33,7 @@ func register(ctx context.Context, idComm r.IDCommitment, ethAccountPrivateKey * } defer backend.Close() - chainID, err := backend.ChainID(context.Background()) + chainID, err := backend.ChainID(ctx) if err != nil { return nil, err } diff --git a/waku/v2/timesource/ntp.go b/waku/v2/timesource/ntp.go index b019d8cf..8f7fbe08 100644 --- a/waku/v2/timesource/ntp.go +++ b/waku/v2/timesource/ntp.go @@ -2,6 +2,7 @@ package timesource import ( "bytes" + "context" "errors" "sort" "sync" @@ -133,8 +134,8 @@ type NTPTimeSource struct { timeQuery ntpQuery // for ease of testing log *zap.Logger - quit chan struct{} - wg sync.WaitGroup + cancel context.CancelFunc + wg sync.WaitGroup mu sync.RWMutex latestOffset time.Duration @@ -162,9 +163,11 @@ func (s *NTPTimeSource) updateOffset() error { // runPeriodically runs periodically the given function based on NTPTimeSource // synchronization limits (fastNTPSyncPeriod / slowNTPSyncPeriod) -func (s *NTPTimeSource) runPeriodically(fn func() error) error { +func (s *NTPTimeSource) runPeriodically(ctx context.Context, fn func() error) error { var period time.Duration - s.quit = make(chan struct{}) + + s.log.Info("starting service") + // we try to do it synchronously so that user can have reliable messages right away s.wg.Add(1) go func() { @@ -177,7 +180,8 @@ func (s *NTPTimeSource) runPeriodically(fn func() error) error { period = s.fastNTPSyncPeriod } - case <-s.quit: + case <-ctx.Done(): + s.log.Info("stopping service") s.wg.Done() return } @@ -188,16 +192,16 @@ func (s *NTPTimeSource) runPeriodically(fn func() error) error { } // Start runs a goroutine that updates local offset every updatePeriod. -func (s *NTPTimeSource) Start() error { - return s.runPeriodically(s.updateOffset) +func (s *NTPTimeSource) Start(ctx context.Context) error { + s.wg.Wait() // Waiting for other go routines to stop + ctx, cancel := context.WithCancel(ctx) + s.cancel = cancel + return s.runPeriodically(ctx, s.updateOffset) } // Stop goroutine that updates time source. func (s *NTPTimeSource) Stop() error { - if s.quit == nil { - return nil - } - close(s.quit) + s.cancel() s.wg.Wait() return nil } diff --git a/waku/v2/timesource/ntp_test.go b/waku/v2/timesource/ntp_test.go index 68a990f6..e5987179 100644 --- a/waku/v2/timesource/ntp_test.go +++ b/waku/v2/timesource/ntp_test.go @@ -1,6 +1,7 @@ package timesource import ( + "context" "errors" "sync" "testing" @@ -174,12 +175,15 @@ func TestComputeOffset(t *testing.T) { func TestNTPTimeSource(t *testing.T) { for _, tc := range newTestCases() { t.Run(tc.description, func(t *testing.T) { + _, cancel := context.WithCancel(context.Background()) source := &NTPTimeSource{ servers: tc.servers, allowedFailures: tc.allowedFailures, timeQuery: tc.query, log: utils.Logger(), + cancel: cancel, } + assert.WithinDuration(t, time.Now(), source.Now(), clockCompareDelta) err := source.updateOffset() if tc.expectError { @@ -202,6 +206,7 @@ func TestRunningPeriodically(t *testing.T) { slowHits := 1 t.Run(tc.description, func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) source := &NTPTimeSource{ servers: tc.servers, allowedFailures: tc.allowedFailures, @@ -209,11 +214,12 @@ func TestRunningPeriodically(t *testing.T) { fastNTPSyncPeriod: time.Duration(fastHits*10) * time.Millisecond, slowNTPSyncPeriod: time.Duration(slowHits*10) * time.Millisecond, log: utils.Logger(), + cancel: cancel, } lastCall := time.Now() // we're simulating a calls to updateOffset, testing ntp calls happens // on NTPTimeSource specified periods (fastNTPSyncPeriod & slowNTPSyncPeriod) - err := source.runPeriodically(func() error { + err := source.runPeriodically(ctx, func() error { mu.Lock() periods = append(periods, time.Since(lastCall)) mu.Unlock() diff --git a/waku/v2/timesource/timesource.go b/waku/v2/timesource/timesource.go index 25ce9cbc..667a227a 100644 --- a/waku/v2/timesource/timesource.go +++ b/waku/v2/timesource/timesource.go @@ -1,9 +1,12 @@ package timesource -import "time" +import ( + "context" + "time" +) type Timesource interface { Now() time.Time - Start() error + Start(ctx context.Context) error Stop() error } diff --git a/waku/v2/timesource/wall.go b/waku/v2/timesource/wall.go index 939a21ff..67b8b343 100644 --- a/waku/v2/timesource/wall.go +++ b/waku/v2/timesource/wall.go @@ -1,6 +1,9 @@ package timesource -import "time" +import ( + "context" + "time" +) type WallClockTimeSource struct { } @@ -13,7 +16,7 @@ func (t *WallClockTimeSource) Now() time.Time { return time.Now() } -func (t *WallClockTimeSource) Start() error { +func (t *WallClockTimeSource) Start(ctx context.Context) error { // Do nothing return nil }