diff --git a/waku/v2/discv5/discover.go b/waku/v2/discv5/discover.go index 97d16531..184e84b0 100644 --- a/waku/v2/discv5/discover.go +++ b/waku/v2/discv5/discover.go @@ -256,7 +256,20 @@ func (d *DiscoveryV5) iterate(ctx context.Context) error { return fmt.Errorf("obtaining iterator: %w", err) } - defer iterator.Close() + closeCh := make(chan struct{}, 1) + defer close(closeCh) + + // Closing iterator when context is cancelled or function is returning + d.wg.Add(1) + go func() { + defer d.wg.Done() + select { + case <-ctx.Done(): + iterator.Close() + case <-closeCh: + iterator.Close() + } + }() for { if ctx.Err() != nil { diff --git a/waku/v2/protocol/peer_exchange/waku_peer_exchange.go b/waku/v2/protocol/peer_exchange/waku_peer_exchange.go index 86df733f..ef00fe5f 100644 --- a/waku/v2/protocol/peer_exchange/waku_peer_exchange.go +++ b/waku/v2/protocol/peer_exchange/waku_peer_exchange.go @@ -317,7 +317,21 @@ func (wakuPX *WakuPeerExchange) iterate(ctx context.Context) error { if err != nil { return fmt.Errorf("obtaining iterator: %w", err) } - defer iterator.Close() + + closeCh := make(chan struct{}, 1) + defer close(closeCh) + + // Closing iterator when context is cancelled or function is returning + wakuPX.wg.Add(1) + go func() { + defer wakuPX.wg.Done() + select { + case <-ctx.Done(): + iterator.Close() + case <-closeCh: + iterator.Close() + } + }() for { if ctx.Err() != nil {