refactor: use context instead of quit channel in wakuv2/waku.go

This commit is contained in:
Richard Ramos 2023-07-17 13:20:55 -04:00 committed by richΛrd
parent 25ff1dd758
commit b9b86712e7
2 changed files with 39 additions and 27 deletions

View File

@ -1 +1 @@
0.163.8 0.163.9

View File

@ -120,8 +120,10 @@ type Waku struct {
sendQueue chan *protocol.Envelope sendQueue chan *protocol.Envelope
msgQueue chan *common.ReceivedMessage // Message queue for waku messages that havent been decoded msgQueue chan *common.ReceivedMessage // Message queue for waku messages that havent been decoded
quit chan struct{} // Channel used for graceful exit
wg sync.WaitGroup ctx context.Context
cancel context.CancelFunc
wg sync.WaitGroup
cfg *Config cfg *Config
settings settings // Holds configuration settings that can be dynamically changed settings settings // Holds configuration settings that can be dynamically changed
@ -219,9 +221,10 @@ func New(nodeKey string, fleet string, cfg *Config, logger *zap.Logger, appDB *s
onHistoricMessagesRequestFailed: onHistoricMessagesRequestFailed, onHistoricMessagesRequestFailed: onHistoricMessagesRequestFailed,
onPeerStats: onPeerStats, onPeerStats: onPeerStats,
} }
// This fn is being mocked in test // This fn is being mocked in test
waku.isFilterSubAlive = func(sub *filter.SubscriptionDetails) error { waku.isFilterSubAlive = func(sub *filter.SubscriptionDetails) error {
return waku.node.FilterLightnode().IsSubscriptionAlive(context.Background(), sub) return waku.node.FilterLightnode().IsSubscriptionAlive(waku.ctx, sub)
} }
waku.settings = settings{ waku.settings = settings{
@ -497,7 +500,7 @@ func (w *Waku) telemetryBandwidthStats(telemetryServerURL string) {
for { for {
select { select {
case <-w.quit: case <-w.ctx.Done():
return return
case now := <-ticker.C: case now := <-ticker.C:
// Reset totals when day changes // Reset totals when day changes
@ -535,7 +538,7 @@ func (w *Waku) runPeerExchangeLoop() {
for { for {
select { select {
case <-w.quit: case <-w.ctx.Done():
return return
case <-ticker.C: case <-ticker.C:
w.logger.Debug("Running peer exchange loop") w.logger.Debug("Running peer exchange loop")
@ -591,7 +594,7 @@ func (w *Waku) runPeerExchangeLoop() {
continue // No peers with peer exchange have been discovered via DNS Discovery so far, skip this iteration continue // No peers with peer exchange have been discovered via DNS Discovery so far, skip this iteration
} }
err := w.node.PeerExchange().Request(context.Background(), peersToDiscover, peer_exchange.WithAutomaticPeerSelection(withThesePeers...)) err := w.node.PeerExchange().Request(w.ctx, peersToDiscover, peer_exchange.WithAutomaticPeerSelection(withThesePeers...))
if err != nil { if err != nil {
w.logger.Error("couldnt request peers via peer exchange", zap.Error(err)) w.logger.Error("couldnt request peers via peer exchange", zap.Error(err))
} }
@ -606,7 +609,7 @@ func (w *Waku) runRelayMsgLoop() {
return return
} }
sub, err := w.node.Relay().Subscribe(context.Background()) sub, err := w.node.Relay().Subscribe(w.ctx)
if err != nil { if err != nil {
fmt.Println("Could not subscribe:", err) fmt.Println("Could not subscribe:", err)
return return
@ -614,7 +617,7 @@ func (w *Waku) runRelayMsgLoop() {
for { for {
select { select {
case <-w.quit: case <-w.ctx.Done():
sub.Unsubscribe() sub.Unsubscribe()
return return
case env := <-sub.Ch: case env := <-sub.Ch:
@ -632,7 +635,7 @@ func (w *Waku) runRelayMsgLoop() {
func (w *Waku) runFilterSubscriptionLoop(sub *filter.SubscriptionDetails) { func (w *Waku) runFilterSubscriptionLoop(sub *filter.SubscriptionDetails) {
for { for {
select { select {
case <-w.quit: case <-w.ctx.Done():
return return
case env, ok := <-sub.C: case env, ok := <-sub.C:
if ok { if ok {
@ -660,7 +663,7 @@ func (w *Waku) runFilterMsgLoop() {
for { for {
select { select {
case <-w.quit: case <-w.ctx.Done():
return return
case <-ticker.C: case <-ticker.C:
for f, subMap := range w.filterSubscriptions { for f, subMap := range w.filterSubscriptions {
@ -679,7 +682,7 @@ func (w *Waku) runFilterMsgLoop() {
// Unsubscribe on light node // Unsubscribe on light node
contentFilter := w.buildContentFilter(f.Topics) contentFilter := w.buildContentFilter(f.Topics)
// TODO Better return value handling for WakuFilterPushResult // TODO Better return value handling for WakuFilterPushResult
_, err := w.node.FilterLightnode().Unsubscribe(context.Background(), contentFilter, filter.Peer(sub.PeerID)) _, err := w.node.FilterLightnode().Unsubscribe(w.ctx, contentFilter, filter.Peer(sub.PeerID))
if err != nil { if err != nil {
w.logger.Warn("could not unsubscribe wakuv2 filter for peer", zap.Any("peer", sub.PeerID)) w.logger.Warn("could not unsubscribe wakuv2 filter for peer", zap.Any("peer", sub.PeerID))
continue continue
@ -692,7 +695,7 @@ func (w *Waku) runFilterMsgLoop() {
// Re-subscribe // Re-subscribe
peers := w.findFilterPeers() peers := w.findFilterPeers()
if len(peers) > 0 && len(subMap) < w.settings.MinPeersForFilter { if len(peers) > 0 && len(subMap) < w.settings.MinPeersForFilter {
subDetails, err := w.node.FilterLightnode().Subscribe(context.Background(), contentFilter, filter.WithPeer(peers[0])) subDetails, err := w.node.FilterLightnode().Subscribe(w.ctx, contentFilter, filter.WithPeer(peers[0]))
if err != nil { if err != nil {
w.logger.Warn("could not add wakuv2 filter for peer", zap.Any("peer", peers[0])) w.logger.Warn("could not add wakuv2 filter for peer", zap.Any("peer", peers[0]))
break break
@ -1067,10 +1070,10 @@ func (w *Waku) broadcast() {
var err error var err error
if w.settings.LightClient { if w.settings.LightClient {
w.logger.Info("publishing message via lightpush", zap.String("envelopeHash", hexutil.Encode(envelope.Hash()))) w.logger.Info("publishing message via lightpush", zap.String("envelopeHash", hexutil.Encode(envelope.Hash())))
_, err = w.node.Lightpush().Publish(context.Background(), envelope.Message()) _, err = w.node.Lightpush().Publish(w.ctx, envelope.Message())
} else { } else {
w.logger.Info("publishing message via relay", zap.String("envelopeHash", hexutil.Encode(envelope.Hash()))) w.logger.Info("publishing message via relay", zap.String("envelopeHash", hexutil.Encode(envelope.Hash())))
_, err = w.node.Relay().Publish(context.Background(), envelope.Message()) _, err = w.node.Relay().Publish(w.ctx, envelope.Message())
} }
if err != nil { if err != nil {
@ -1090,7 +1093,7 @@ func (w *Waku) broadcast() {
w.SendEnvelopeEvent(event) w.SendEnvelopeEvent(event)
case <-w.quit: case <-w.ctx.Done():
return return
} }
} }
@ -1174,10 +1177,12 @@ func (w *Waku) Start() error {
return fmt.Errorf("failed to create a go-waku node: %v", err) return fmt.Errorf("failed to create a go-waku node: %v", err)
} }
w.quit = make(chan struct{})
w.connectionChanged = make(chan struct{}) w.connectionChanged = make(chan struct{})
ctx := context.Background() ctx, cancel := context.WithCancel(context.Background())
w.ctx = ctx
w.cancel = cancel
if err = w.node.Start(ctx); err != nil { if err = w.node.Start(ctx); err != nil {
return fmt.Errorf("failed to start go-waku node: %v", err) return fmt.Errorf("failed to start go-waku node: %v", err)
} }
@ -1208,7 +1213,7 @@ func (w *Waku) Start() error {
isConnected := false isConnected := false
for { for {
select { select {
case <-w.quit: case <-w.ctx.Done():
return return
case c := <-w.connStatusChan: case c := <-w.connStatusChan:
w.connStatusMu.Lock() w.connStatusMu.Lock()
@ -1265,7 +1270,7 @@ func (w *Waku) Start() error {
// Stop implements node.Service, stopping the background data propagation thread // Stop implements node.Service, stopping the background data propagation thread
// of the Waku protocol. // of the Waku protocol.
func (w *Waku) Stop() error { func (w *Waku) Stop() error {
close(w.quit) w.cancel()
w.identifyService.Close() w.identifyService.Close()
w.node.Stop() w.node.Stop()
close(w.connectionChanged) close(w.connectionChanged)
@ -1348,7 +1353,7 @@ func (w *Waku) postEvent(envelope *common.ReceivedMessage) {
func (w *Waku) processQueue() { func (w *Waku) processQueue() {
for { for {
select { select {
case <-w.quit: case <-w.ctx.Done():
return return
case e := <-w.msgQueue: case e := <-w.msgQueue:
if e.MsgType == common.StoreMessageType { if e.MsgType == common.StoreMessageType {
@ -1429,7 +1434,7 @@ func (w *Waku) StartDiscV5() error {
return errors.New("discv5 is not setup") return errors.New("discv5 is not setup")
} }
return w.node.DiscV5().Start(context.Background()) return w.node.DiscV5().Start(w.ctx)
} }
func (w *Waku) StopDiscV5() error { func (w *Waku) StopDiscV5() error {
@ -1518,7 +1523,7 @@ func (w *Waku) seedBootnodesForDiscV5() {
retries = 0 retries = 0
lastTry = now() lastTry = now()
case <-w.quit: case <-w.ctx.Done():
return return
} }
} }
@ -1526,7 +1531,7 @@ func (w *Waku) seedBootnodesForDiscV5() {
// Restart discv5, re-retrieving bootstrap nodes // Restart discv5, re-retrieving bootstrap nodes
func (w *Waku) restartDiscV5() error { func (w *Waku) restartDiscV5() error {
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) ctx, cancel := context.WithTimeout(w.ctx, 30*time.Second)
defer cancel() defer cancel()
bootnodes, err := w.getDiscV5BootstrapNodes(ctx, w.discV5BootstrapNodes) bootnodes, err := w.getDiscV5BootstrapNodes(ctx, w.discV5BootstrapNodes)
if err != nil { if err != nil {
@ -1545,6 +1550,13 @@ func (w *Waku) restartDiscV5() error {
} else { } else {
w.node.DiscV5().Stop() w.node.DiscV5().Stop()
w.logger.Info("is started restarting") w.logger.Info("is started restarting")
select {
case <-w.ctx.Done(): // Don't start discv5 if we are stopping waku
return nil
default:
}
err := w.node.DiscV5().Start(ctx) err := w.node.DiscV5().Start(ctx)
if err != nil { if err != nil {
w.logger.Error("Could not start DiscV5", zap.Error(err)) w.logger.Error("Could not start DiscV5", zap.Error(err))
@ -1586,13 +1598,13 @@ func (w *Waku) AddRelayPeer(address string) (peer.ID, error) {
} }
func (w *Waku) DialPeer(address string) error { func (w *Waku) DialPeer(address string) error {
ctx, cancel := context.WithTimeout(context.Background(), requestTimeout) ctx, cancel := context.WithTimeout(w.ctx, requestTimeout)
defer cancel() defer cancel()
return w.node.DialPeer(ctx, address) return w.node.DialPeer(ctx, address)
} }
func (w *Waku) DialPeerByID(peerID string) error { func (w *Waku) DialPeerByID(peerID string) error {
ctx, cancel := context.WithTimeout(context.Background(), requestTimeout) ctx, cancel := context.WithTimeout(w.ctx, requestTimeout)
defer cancel() defer cancel()
pid, err := peer.Decode(peerID) pid, err := peer.Decode(peerID)
if err != nil { if err != nil {
@ -1717,7 +1729,7 @@ func (w *Waku) subscribeToFilter(f *common.Filter) error {
if len(peers) > 0 { if len(peers) > 0 {
contentFilter := w.buildContentFilter(f.Topics) contentFilter := w.buildContentFilter(f.Topics)
for i := 0; i < len(peers) && i < w.settings.MinPeersForFilter; i++ { for i := 0; i < len(peers) && i < w.settings.MinPeersForFilter; i++ {
subDetails, err := w.node.FilterLightnode().Subscribe(context.Background(), contentFilter, filter.WithPeer(peers[i])) subDetails, err := w.node.FilterLightnode().Subscribe(w.ctx, contentFilter, filter.WithPeer(peers[i]))
if err != nil { if err != nil {
w.logger.Warn("could not add wakuv2 filter for peer", zap.Any("peer", peers[i])) w.logger.Warn("could not add wakuv2 filter for peer", zap.Any("peer", peers[i]))
continue continue