Fix race condition in subscriptions (#1646)

This commit is contained in:
Adam Babik 2019-12-11 09:44:57 +01:00 committed by GitHub
parent 0571f561f0
commit 203f29b13e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 52 additions and 39 deletions

View File

@ -38,7 +38,7 @@ func (api *API) SubscribeSignal(method string, args []interface{}) (Subscription
} }
if err != nil { if err != nil {
return SubscriptionID(""), fmt.Errorf("[SubscribeSignal] could not subscribe, failed to call %s: %v", method, err) return "", fmt.Errorf("[SubscribeSignal] could not subscribe, failed to call %s: %v", method, err)
} }
return api.activeSubscriptions.Create(namespace, filter) return api.activeSubscriptions.Create(namespace, filter)
@ -47,7 +47,3 @@ func (api *API) SubscribeSignal(method string, args []interface{}) (Subscription
func (api *API) UnsubscribeSignal(id string) error { func (api *API) UnsubscribeSignal(id string) error {
return api.activeSubscriptions.Remove(SubscriptionID(id)) return api.activeSubscriptions.Remove(SubscriptionID(id))
} }
func (api *API) shutdown() error {
return api.activeSubscriptions.removeAll()
}

View File

@ -47,5 +47,5 @@ func (s *Service) Start(server *p2p.Server) error {
// Stop is run when a service is stopped. // Stop is run when a service is stopped.
func (s *Service) Stop() error { func (s *Service) Stop() error {
return s.api.shutdown() return s.api.activeSubscriptions.removeAll()
} }

View File

@ -3,36 +3,41 @@ package subscriptions
import ( import (
"errors" "errors"
"fmt" "fmt"
"sync"
"time" "time"
) )
type SubscriptionID string type SubscriptionID string
type Subscription struct { type Subscription struct {
mu sync.RWMutex
id SubscriptionID id SubscriptionID
signal *filterSignal signal *filterSignal
quit chan struct{} quit chan struct{}
filter filter filter filter
stopped bool started bool
} }
func NewSubscription(namespace string, filter filter) *Subscription { func NewSubscription(namespace string, filter filter) *Subscription {
subscriptionID := NewSubscriptionID(namespace, filter.getID()) subscriptionID := NewSubscriptionID(namespace, filter.getID())
quit := make(chan struct{})
return &Subscription{ return &Subscription{
id: subscriptionID, id: subscriptionID,
quit: quit,
signal: newFilterSignal(string(subscriptionID)), signal: newFilterSignal(string(subscriptionID)),
filter: filter, filter: filter,
} }
} }
func (s *Subscription) Start(checkPeriod time.Duration) error { func (s *Subscription) Start(checkPeriod time.Duration) error {
if s.stopped { s.mu.Lock()
return errors.New("it is impossible to start an already stopped subscription") if s.started {
s.mu.Unlock()
return errors.New("subscription already started or used")
} }
s.started = true
s.quit = make(chan struct{})
quit := s.quit
s.mu.Unlock()
ticker := time.NewTicker(checkPeriod) ticker := time.NewTicker(checkPeriod)
defer ticker.Stop() defer ticker.Stop()
@ -45,22 +50,31 @@ func (s *Subscription) Start(checkPeriod time.Duration) error {
} else if len(filterData) > 0 { } else if len(filterData) > 0 {
s.signal.SendData(filterData) s.signal.SendData(filterData)
} }
case <-s.quit: case <-quit:
return nil return nil
} }
} }
} }
func (s *Subscription) Stop(uninstall bool) error { func (s *Subscription) Stop(uninstall bool) error {
if s.stopped { s.mu.Lock()
defer s.mu.Unlock()
if !s.started {
return nil return nil
} }
close(s.quit) select {
if uninstall { case _, ok := <-s.quit:
return s.filter.uninstall() // handle a case of a closed channel
if !ok {
return nil
}
default:
close(s.quit)
} }
s.stopped = true if !uninstall {
return nil return nil
}
return s.filter.uninstall()
} }
func NewSubscriptionID(namespace, filterID string) SubscriptionID { func NewSubscriptionID(namespace, filterID string) SubscriptionID {

View File

@ -46,7 +46,6 @@ func (s *Subscriptions) Remove(id SubscriptionID) error {
defer s.mu.Unlock() defer s.mu.Unlock()
found, err := s.stopSubscription(id, true) found, err := s.stopSubscription(id, true)
if found { if found {
delete(s.subs, id) delete(s.subs, id)
} }
@ -76,13 +75,10 @@ func (s *Subscriptions) removeAll() error {
return nil return nil
} }
// stopSubscription isn't thread safe!
func (s *Subscriptions) stopSubscription(id SubscriptionID, uninstall bool) (bool, error) { func (s *Subscriptions) stopSubscription(id SubscriptionID, uninstall bool) (bool, error) {
sub, found := s.subs[id] sub, found := s.subs[id]
if !found { if !found {
return false, nil return false, nil
} }
return true, sub.Stop(uninstall) return true, sub.Stop(uninstall)
} }

View File

@ -4,6 +4,7 @@ import (
"encoding/json" "encoding/json"
"errors" "errors"
"fmt" "fmt"
"sync"
"testing" "testing"
"time" "time"
@ -17,6 +18,7 @@ const (
) )
type mockFilter struct { type mockFilter struct {
sync.Mutex
filterID string filterID string
data []interface{} data []interface{}
filterError error filterError error
@ -31,9 +33,14 @@ func newMockFilter(filterID string) *mockFilter {
} }
func (mf *mockFilter) getID() string { func (mf *mockFilter) getID() string {
mf.Lock()
defer mf.Unlock()
return mf.filterID return mf.filterID
} }
func (mf *mockFilter) getChanges() ([]interface{}, error) { func (mf *mockFilter) getChanges() ([]interface{}, error) {
mf.Lock()
defer mf.Unlock()
if mf.filterError != nil { if mf.filterError != nil {
err := mf.filterError err := mf.filterError
mf.filterError = nil mf.filterError = nil
@ -46,15 +53,21 @@ func (mf *mockFilter) getChanges() ([]interface{}, error) {
} }
func (mf *mockFilter) uninstall() error { func (mf *mockFilter) uninstall() error {
mf.Lock()
defer mf.Unlock()
mf.uninstalled = true mf.uninstalled = true
return mf.uninstallError return mf.uninstallError
} }
func (mf *mockFilter) setData(data ...interface{}) { func (mf *mockFilter) setData(data ...interface{}) {
mf.Lock()
defer mf.Unlock()
mf.data = data mf.data = data
} }
func (mf *mockFilter) setError(err error) { func (mf *mockFilter) setError(err error) {
mf.Lock()
defer mf.Unlock()
mf.data = nil mf.data = nil
mf.filterError = err mf.filterError = err
} }
@ -121,13 +134,13 @@ func TestSubscriptionGetError(t *testing.T) {
func TestSubscriptionRemove(t *testing.T) { func TestSubscriptionRemove(t *testing.T) {
filter := newMockFilter(filterID) filter := newMockFilter(filterID)
subs := NewSubscriptions(time.Microsecond) subs := NewSubscriptions(time.Microsecond)
subID, _ := subs.Create(filterNS, filter) subID, err := subs.Create(filterNS, filter)
require.NoError(t, err)
time.Sleep(time.Millisecond * 100) // create starts in a goroutine
require.NoError(t, subs.Remove(subID)) require.NoError(t, subs.Remove(subID))
require.True(t, filter.uninstalled) require.True(t, filter.uninstalled)
require.Empty(t, subs.subs) require.Empty(t, subs.subs)
} }
@ -137,11 +150,11 @@ func TestSubscriptionRemoveError(t *testing.T) {
filter.uninstallError = errors.New("uninstall-error-1") filter.uninstallError = errors.New("uninstall-error-1")
subs := NewSubscriptions(time.Microsecond) subs := NewSubscriptions(time.Microsecond)
subID, err := subs.Create(filterNS, filter)
subID, _ := subs.Create(filterNS, filter) require.NoError(t, err)
time.Sleep(time.Millisecond * 100) // create starts in a goroutine
require.Equal(t, subs.Remove(subID), filter.uninstallError) require.Equal(t, subs.Remove(subID), filter.uninstallError)
require.True(t, filter.uninstalled) require.True(t, filter.uninstalled)
require.Equal(t, len(subs.subs), 0) require.Equal(t, len(subs.subs), 0)
} }
@ -155,14 +168,13 @@ func TestSubscriptionRemoveAll(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
_, err = subs.Create(filterNS, filter1) _, err = subs.Create(filterNS, filter1)
require.NoError(t, err) require.NoError(t, err)
time.Sleep(time.Millisecond * 100) // create starts in a goroutine
require.Equal(t, len(subs.subs), 2) require.Equal(t, len(subs.subs), 2)
err = subs.removeAll()
require.NoError(t, subs.removeAll()) require.NoError(t, err)
require.False(t, filter0.uninstalled) require.False(t, filter0.uninstalled)
require.False(t, filter1.uninstalled) require.False(t, filter1.uninstalled)
require.Equal(t, len(subs.subs), 0) require.Equal(t, len(subs.subs), 0)
} }
@ -171,9 +183,7 @@ func validateFilterError(t *testing.T, jsonEvent string, expectedSubID string, e
Event signal.SubscriptionErrorEvent `json:"event"` Event signal.SubscriptionErrorEvent `json:"event"`
Type string `json:"type"` Type string `json:"type"`
}{} }{}
require.NoError(t, json.Unmarshal([]byte(jsonEvent), &result)) require.NoError(t, json.Unmarshal([]byte(jsonEvent), &result))
require.Equal(t, signal.EventSubscriptionsError, result.Type) require.Equal(t, signal.EventSubscriptionsError, result.Type)
require.Equal(t, expectedErrorMessage, result.Event.ErrorMessage) require.Equal(t, expectedErrorMessage, result.Event.ErrorMessage)
} }
@ -183,11 +193,8 @@ func validateFilterData(t *testing.T, jsonEvent string, expectedSubID string, ex
Event signal.SubscriptionDataEvent `json:"event"` Event signal.SubscriptionDataEvent `json:"event"`
Type string `json:"type"` Type string `json:"type"`
}{} }{}
require.NoError(t, json.Unmarshal([]byte(jsonEvent), &result)) require.NoError(t, json.Unmarshal([]byte(jsonEvent), &result))
require.Equal(t, signal.EventSubscriptionsData, result.Type) require.Equal(t, signal.EventSubscriptionsData, result.Type)
require.Equal(t, expectedData, result.Event.Data) require.Equal(t, expectedData, result.Event.Data)
require.Equal(t, expectedSubID, result.Event.FilterID) require.Equal(t, expectedSubID, result.Event.FilterID)
} }