Fix race condition in subscriptions (#1646)

This commit is contained in:
Adam Babik 2019-12-11 09:44:57 +01:00 committed by GitHub
parent 0571f561f0
commit 203f29b13e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 52 additions and 39 deletions

View File

@ -38,7 +38,7 @@ func (api *API) SubscribeSignal(method string, args []interface{}) (Subscription
}
if err != nil {
return SubscriptionID(""), fmt.Errorf("[SubscribeSignal] could not subscribe, failed to call %s: %v", method, err)
return "", fmt.Errorf("[SubscribeSignal] could not subscribe, failed to call %s: %v", method, err)
}
return api.activeSubscriptions.Create(namespace, filter)
@ -47,7 +47,3 @@ func (api *API) SubscribeSignal(method string, args []interface{}) (Subscription
func (api *API) UnsubscribeSignal(id string) error {
return api.activeSubscriptions.Remove(SubscriptionID(id))
}
func (api *API) shutdown() error {
return api.activeSubscriptions.removeAll()
}

View File

@ -47,5 +47,5 @@ func (s *Service) Start(server *p2p.Server) error {
// Stop is run when a service is stopped.
func (s *Service) Stop() error {
return s.api.shutdown()
return s.api.activeSubscriptions.removeAll()
}

View File

@ -3,36 +3,41 @@ package subscriptions
import (
"errors"
"fmt"
"sync"
"time"
)
type SubscriptionID string
type Subscription struct {
mu sync.RWMutex
id SubscriptionID
signal *filterSignal
quit chan struct{}
filter filter
stopped bool
started bool
}
func NewSubscription(namespace string, filter filter) *Subscription {
subscriptionID := NewSubscriptionID(namespace, filter.getID())
quit := make(chan struct{})
return &Subscription{
id: subscriptionID,
quit: quit,
signal: newFilterSignal(string(subscriptionID)),
filter: filter,
}
}
func (s *Subscription) Start(checkPeriod time.Duration) error {
if s.stopped {
return errors.New("it is impossible to start an already stopped subscription")
s.mu.Lock()
if s.started {
s.mu.Unlock()
return errors.New("subscription already started or used")
}
s.started = true
s.quit = make(chan struct{})
quit := s.quit
s.mu.Unlock()
ticker := time.NewTicker(checkPeriod)
defer ticker.Stop()
@ -45,22 +50,31 @@ func (s *Subscription) Start(checkPeriod time.Duration) error {
} else if len(filterData) > 0 {
s.signal.SendData(filterData)
}
case <-s.quit:
case <-quit:
return nil
}
}
}
func (s *Subscription) Stop(uninstall bool) error {
if s.stopped {
s.mu.Lock()
defer s.mu.Unlock()
if !s.started {
return nil
}
close(s.quit)
if uninstall {
return s.filter.uninstall()
select {
case _, ok := <-s.quit:
// handle a case of a closed channel
if !ok {
return nil
}
default:
close(s.quit)
}
s.stopped = true
return nil
if !uninstall {
return nil
}
return s.filter.uninstall()
}
func NewSubscriptionID(namespace, filterID string) SubscriptionID {

View File

@ -46,7 +46,6 @@ func (s *Subscriptions) Remove(id SubscriptionID) error {
defer s.mu.Unlock()
found, err := s.stopSubscription(id, true)
if found {
delete(s.subs, id)
}
@ -76,13 +75,10 @@ func (s *Subscriptions) removeAll() error {
return nil
}
// stopSubscription isn't thread safe!
func (s *Subscriptions) stopSubscription(id SubscriptionID, uninstall bool) (bool, error) {
sub, found := s.subs[id]
if !found {
return false, nil
}
return true, sub.Stop(uninstall)
}

View File

@ -4,6 +4,7 @@ import (
"encoding/json"
"errors"
"fmt"
"sync"
"testing"
"time"
@ -17,6 +18,7 @@ const (
)
type mockFilter struct {
sync.Mutex
filterID string
data []interface{}
filterError error
@ -31,9 +33,14 @@ func newMockFilter(filterID string) *mockFilter {
}
func (mf *mockFilter) getID() string {
mf.Lock()
defer mf.Unlock()
return mf.filterID
}
func (mf *mockFilter) getChanges() ([]interface{}, error) {
mf.Lock()
defer mf.Unlock()
if mf.filterError != nil {
err := mf.filterError
mf.filterError = nil
@ -46,15 +53,21 @@ func (mf *mockFilter) getChanges() ([]interface{}, error) {
}
func (mf *mockFilter) uninstall() error {
mf.Lock()
defer mf.Unlock()
mf.uninstalled = true
return mf.uninstallError
}
func (mf *mockFilter) setData(data ...interface{}) {
mf.Lock()
defer mf.Unlock()
mf.data = data
}
func (mf *mockFilter) setError(err error) {
mf.Lock()
defer mf.Unlock()
mf.data = nil
mf.filterError = err
}
@ -121,13 +134,13 @@ func TestSubscriptionGetError(t *testing.T) {
func TestSubscriptionRemove(t *testing.T) {
filter := newMockFilter(filterID)
subs := NewSubscriptions(time.Microsecond)
subID, _ := subs.Create(filterNS, filter)
subID, err := subs.Create(filterNS, filter)
require.NoError(t, err)
time.Sleep(time.Millisecond * 100) // create starts in a goroutine
require.NoError(t, subs.Remove(subID))
require.True(t, filter.uninstalled)
require.Empty(t, subs.subs)
}
@ -137,11 +150,11 @@ func TestSubscriptionRemoveError(t *testing.T) {
filter.uninstallError = errors.New("uninstall-error-1")
subs := NewSubscriptions(time.Microsecond)
subID, _ := subs.Create(filterNS, filter)
subID, err := subs.Create(filterNS, filter)
require.NoError(t, err)
time.Sleep(time.Millisecond * 100) // create starts in a goroutine
require.Equal(t, subs.Remove(subID), filter.uninstallError)
require.True(t, filter.uninstalled)
require.Equal(t, len(subs.subs), 0)
}
@ -155,14 +168,13 @@ func TestSubscriptionRemoveAll(t *testing.T) {
require.NoError(t, err)
_, err = subs.Create(filterNS, filter1)
require.NoError(t, err)
time.Sleep(time.Millisecond * 100) // create starts in a goroutine
require.Equal(t, len(subs.subs), 2)
require.NoError(t, subs.removeAll())
err = subs.removeAll()
require.NoError(t, err)
require.False(t, filter0.uninstalled)
require.False(t, filter1.uninstalled)
require.Equal(t, len(subs.subs), 0)
}
@ -171,9 +183,7 @@ func validateFilterError(t *testing.T, jsonEvent string, expectedSubID string, e
Event signal.SubscriptionErrorEvent `json:"event"`
Type string `json:"type"`
}{}
require.NoError(t, json.Unmarshal([]byte(jsonEvent), &result))
require.Equal(t, signal.EventSubscriptionsError, result.Type)
require.Equal(t, expectedErrorMessage, result.Event.ErrorMessage)
}
@ -183,11 +193,8 @@ func validateFilterData(t *testing.T, jsonEvent string, expectedSubID string, ex
Event signal.SubscriptionDataEvent `json:"event"`
Type string `json:"type"`
}{}
require.NoError(t, json.Unmarshal([]byte(jsonEvent), &result))
require.Equal(t, signal.EventSubscriptionsData, result.Type)
require.Equal(t, expectedData, result.Event.Data)
require.Equal(t, expectedSubID, result.Event.FilterID)
}