fix: race condition in unsubscribe (#197)

This commit is contained in:
Richard Ramos 2022-02-23 11:08:27 -04:00 committed by GitHub
parent 8d155fb51e
commit df66ef5bb5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 102 additions and 34 deletions

View File

@ -7,10 +7,17 @@ import (
// Adapted from https://github.com/dustin/go-broadcast/commit/f664265f5a662fb4d1df7f3533b1e8d0e0277120 // Adapted from https://github.com/dustin/go-broadcast/commit/f664265f5a662fb4d1df7f3533b1e8d0e0277120
// by Dustin Sallings (c) 2013, which was released under MIT license // by Dustin Sallings (c) 2013, which was released under MIT license
type doneCh chan struct{}
type chOperation struct {
ch chan<- *protocol.Envelope
done doneCh
}
type broadcaster struct { type broadcaster struct {
input chan *protocol.Envelope input chan *protocol.Envelope
reg chan chan<- *protocol.Envelope reg chan chOperation
unreg chan chan<- *protocol.Envelope unreg chan chOperation
outputs map[chan<- *protocol.Envelope]bool outputs map[chan<- *protocol.Envelope]bool
} }
@ -20,8 +27,12 @@ type broadcaster struct {
type Broadcaster interface { type Broadcaster interface {
// Register a new channel to receive broadcasts // Register a new channel to receive broadcasts
Register(chan<- *protocol.Envelope) Register(chan<- *protocol.Envelope)
// Register a new channel to receive broadcasts and return a channel to wait until this operation is complete
WaitRegister(newch chan<- *protocol.Envelope) doneCh
// Unregister a channel so that it no longer receives broadcasts. // Unregister a channel so that it no longer receives broadcasts.
Unregister(chan<- *protocol.Envelope) Unregister(chan<- *protocol.Envelope)
// Unregister a subscriptor channel and return a channel to wait until this operation is done
WaitUnregister(newch chan<- *protocol.Envelope) doneCh
// Shut this broadcaster down. // Shut this broadcaster down.
Close() Close()
// Submit a new object to all subscribers // Submit a new object to all subscribers
@ -39,14 +50,23 @@ func (b *broadcaster) run() {
select { select {
case m := <-b.input: case m := <-b.input:
b.broadcast(m) b.broadcast(m)
case ch, ok := <-b.reg: case broadcastee, ok := <-b.reg:
if ok { if ok {
b.outputs[ch] = true b.outputs[broadcastee.ch] = true
if broadcastee.done != nil {
broadcastee.done <- struct{}{}
}
} else { } else {
if broadcastee.done != nil {
broadcastee.done <- struct{}{}
}
return return
} }
case ch := <-b.unreg: case broadcastee := <-b.unreg:
delete(b.outputs, ch) delete(b.outputs, broadcastee.ch)
if broadcastee.done != nil {
broadcastee.done <- struct{}{}
}
} }
} }
} }
@ -57,8 +77,8 @@ func (b *broadcaster) run() {
func NewBroadcaster(buflen int) Broadcaster { func NewBroadcaster(buflen int) Broadcaster {
b := &broadcaster{ b := &broadcaster{
input: make(chan *protocol.Envelope, buflen), input: make(chan *protocol.Envelope, buflen),
reg: make(chan chan<- *protocol.Envelope), reg: make(chan chOperation),
unreg: make(chan chan<- *protocol.Envelope), unreg: make(chan chOperation),
outputs: make(map[chan<- *protocol.Envelope]bool), outputs: make(map[chan<- *protocol.Envelope]bool),
} }
@ -67,14 +87,40 @@ func NewBroadcaster(buflen int) Broadcaster {
return b return b
} }
// Register a subscriptor channel and return a channel to wait until this operation is done
func (b *broadcaster) WaitRegister(newch chan<- *protocol.Envelope) doneCh {
d := make(doneCh)
b.reg <- chOperation{
ch: newch,
done: d,
}
return d
}
// Register a subscriptor channel // Register a subscriptor channel
func (b *broadcaster) Register(newch chan<- *protocol.Envelope) { func (b *broadcaster) Register(newch chan<- *protocol.Envelope) {
b.reg <- newch b.reg <- chOperation{
ch: newch,
done: nil,
}
}
// Unregister a subscriptor channel and return a channel to wait until this operation is done
func (b *broadcaster) WaitUnregister(newch chan<- *protocol.Envelope) doneCh {
d := make(doneCh)
b.unreg <- chOperation{
ch: newch,
done: d,
}
return d
} }
// Unregister a subscriptor channel // Unregister a subscriptor channel
func (b *broadcaster) Unregister(newch chan<- *protocol.Envelope) { func (b *broadcaster) Unregister(newch chan<- *protocol.Envelope) {
b.unreg <- newch b.unreg <- chOperation{
ch: newch,
done: nil,
}
} }
// Closes the broadcaster. Used to stop receiving new subscribers // Closes the broadcaster. Used to stop receiving new subscribers

View File

@ -37,6 +37,33 @@ func TestBroadcast(t *testing.T) {
wg.Wait() wg.Wait()
} }
func TestBroadcastWait(t *testing.T) {
wg := sync.WaitGroup{}
b := NewBroadcaster(100)
defer b.Close()
for i := 0; i < 5; i++ {
wg.Add(1)
cch := make(chan *protocol.Envelope)
<-b.WaitRegister(cch)
go func() {
defer wg.Done()
<-cch
<-b.WaitUnregister(cch)
}()
}
env := new(protocol.Envelope)
b.Submit(env)
wg.Wait()
}
func TestBroadcastCleanup(t *testing.T) { func TestBroadcastCleanup(t *testing.T) {
b := NewBroadcaster(100) b := NewBroadcaster(100)
b.Register(make(chan *protocol.Envelope)) b.Register(make(chan *protocol.Envelope))

View File

@ -8,6 +8,8 @@ import (
// Subscription handles the subscrition to a particular pubsub topic // Subscription handles the subscrition to a particular pubsub topic
type Subscription struct { type Subscription struct {
sync.RWMutex
// C is channel used for receiving envelopes // C is channel used for receiving envelopes
C chan *protocol.Envelope C chan *protocol.Envelope
@ -19,14 +21,14 @@ type Subscription struct {
// Unsubscribe will close a subscription from a pubsub topic. Will close the message channel // Unsubscribe will close a subscription from a pubsub topic. Will close the message channel
func (subs *Subscription) Unsubscribe() { func (subs *Subscription) Unsubscribe() {
subs.once.Do(func() { subs.once.Do(func() {
subs.closed = true
close(subs.quit) close(subs.quit)
close(subs.C)
}) })
} }
// IsClosed determine whether a Subscription is still open for receiving messages // IsClosed determine whether a Subscription is still open for receiving messages
func (subs *Subscription) IsClosed() bool { func (subs *Subscription) IsClosed() bool {
subs.RLock()
defer subs.RUnlock()
return subs.closed return subs.closed
} }

View File

@ -1,18 +0,0 @@
package relay
import (
"testing"
waku_proto "github.com/status-im/go-waku/waku/v2/protocol"
"github.com/stretchr/testify/require"
)
func TestSubscription(t *testing.T) {
e := Subscription{
closed: false,
C: make(chan *waku_proto.Envelope, 10),
quit: make(chan struct{}),
}
e.Unsubscribe()
require.True(t, e.closed)
}

View File

@ -300,9 +300,20 @@ func (w *WakuRelay) subscribeToTopic(t string, subscription *Subscription, sub *
for { for {
select { select {
case <-subscription.quit: case <-subscription.quit:
if w.bcaster != nil { func() {
w.bcaster.Unregister(subscription.C) // Remove from broadcast list subscription.Lock()
} defer subscription.Unlock()
if subscription.closed {
return
}
subscription.closed = true
if w.bcaster != nil {
<-w.bcaster.WaitUnregister(subscription.C) // Remove from broadcast list
}
close(subscription.C)
}()
// TODO: if there are no more relay subscriptions, close the pubsub subscription // TODO: if there are no more relay subscriptions, close the pubsub subscription
case msg := <-subChannel: case msg := <-subChannel:
if msg == nil { if msg == nil {