diff --git a/notify.go b/notify.go index 7965263..34feb7c 100644 --- a/notify.go +++ b/notify.go @@ -22,8 +22,15 @@ func (p *PubSubNotif) Connected(n network.Network, c network.Conn) { return } - p.newPeersMx.Lock() - defer p.newPeersMx.Unlock() + select { + case <-p.newPeersSema: + defer func() { + p.newPeersSema <- struct{}{} + }() + + case <-p.ctx.Done(): + return + } p.newPeersPend[c.RemotePeer()] = struct{}{} select { @@ -52,8 +59,15 @@ func (p *PubSubNotif) Initialize() { return true } - p.newPeersMx.Lock() - defer p.newPeersMx.Unlock() + select { + case <-p.newPeersSema: + defer func() { + p.newPeersSema <- struct{}{} + }() + + case <-p.ctx.Done(): + return + } for _, pid := range p.host.Network().Peers() { if isTransient(pid) { diff --git a/pubsub.go b/pubsub.go index cf0fb3b..3b9c011 100644 --- a/pubsub.go +++ b/pubsub.go @@ -91,7 +91,7 @@ type PubSub struct { // a notification channel for new peer connections accumulated newPeers chan struct{} - newPeersMx sync.Mutex + newPeersSema chan struct{} newPeersPend map[peer.ID]struct{} // a notification channel for new outoging peer streams @@ -234,6 +234,7 @@ func NewPubSub(ctx context.Context, h host.Host, rt PubSubRouter, opts ...Option signPolicy: StrictSign, incoming: make(chan *RPC, 32), newPeers: make(chan struct{}, 1), + newPeersSema: make(chan struct{}, 1), newPeersPend: make(map[peer.ID]struct{}), newPeerStream: make(chan network.Stream), newPeerError: make(chan peer.ID), @@ -263,6 +264,8 @@ func NewPubSub(ctx context.Context, h host.Host, rt PubSubRouter, opts ...Option counter: uint64(time.Now().UnixNano()), } + ps.newPeersSema <- struct{}{} + for _, opt := range opts { err := opt(ps) if err != nil { @@ -612,8 +615,16 @@ func (p *PubSub) processLoop(ctx context.Context) { } func (p *PubSub) handlePendingPeers() { - p.newPeersMx.Lock() - defer p.newPeersMx.Unlock() + select { + case <-p.newPeersSema: + defer func() { + p.newPeersSema <- struct{}{} + }() + + default: + // contention, return and wait for the next notification without blocking the event loop + return + } if len(p.newPeersPend) == 0 { return