Fix race condition in subscriptions (#1646)
This commit is contained in:
parent
0571f561f0
commit
203f29b13e
|
@ -38,7 +38,7 @@ func (api *API) SubscribeSignal(method string, args []interface{}) (Subscription
|
||||||
}
|
}
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return SubscriptionID(""), fmt.Errorf("[SubscribeSignal] could not subscribe, failed to call %s: %v", method, err)
|
return "", fmt.Errorf("[SubscribeSignal] could not subscribe, failed to call %s: %v", method, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return api.activeSubscriptions.Create(namespace, filter)
|
return api.activeSubscriptions.Create(namespace, filter)
|
||||||
|
@ -47,7 +47,3 @@ func (api *API) SubscribeSignal(method string, args []interface{}) (Subscription
|
||||||
func (api *API) UnsubscribeSignal(id string) error {
|
func (api *API) UnsubscribeSignal(id string) error {
|
||||||
return api.activeSubscriptions.Remove(SubscriptionID(id))
|
return api.activeSubscriptions.Remove(SubscriptionID(id))
|
||||||
}
|
}
|
||||||
|
|
||||||
func (api *API) shutdown() error {
|
|
||||||
return api.activeSubscriptions.removeAll()
|
|
||||||
}
|
|
||||||
|
|
|
@ -47,5 +47,5 @@ func (s *Service) Start(server *p2p.Server) error {
|
||||||
|
|
||||||
// Stop is run when a service is stopped.
|
// Stop is run when a service is stopped.
|
||||||
func (s *Service) Stop() error {
|
func (s *Service) Stop() error {
|
||||||
return s.api.shutdown()
|
return s.api.activeSubscriptions.removeAll()
|
||||||
}
|
}
|
||||||
|
|
|
@ -3,36 +3,41 @@ package subscriptions
|
||||||
import (
|
import (
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
type SubscriptionID string
|
type SubscriptionID string
|
||||||
|
|
||||||
type Subscription struct {
|
type Subscription struct {
|
||||||
|
mu sync.RWMutex
|
||||||
id SubscriptionID
|
id SubscriptionID
|
||||||
signal *filterSignal
|
signal *filterSignal
|
||||||
quit chan struct{}
|
quit chan struct{}
|
||||||
filter filter
|
filter filter
|
||||||
stopped bool
|
started bool
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewSubscription(namespace string, filter filter) *Subscription {
|
func NewSubscription(namespace string, filter filter) *Subscription {
|
||||||
subscriptionID := NewSubscriptionID(namespace, filter.getID())
|
subscriptionID := NewSubscriptionID(namespace, filter.getID())
|
||||||
|
|
||||||
quit := make(chan struct{})
|
|
||||||
|
|
||||||
return &Subscription{
|
return &Subscription{
|
||||||
id: subscriptionID,
|
id: subscriptionID,
|
||||||
quit: quit,
|
|
||||||
signal: newFilterSignal(string(subscriptionID)),
|
signal: newFilterSignal(string(subscriptionID)),
|
||||||
filter: filter,
|
filter: filter,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Subscription) Start(checkPeriod time.Duration) error {
|
func (s *Subscription) Start(checkPeriod time.Duration) error {
|
||||||
if s.stopped {
|
s.mu.Lock()
|
||||||
return errors.New("it is impossible to start an already stopped subscription")
|
if s.started {
|
||||||
|
s.mu.Unlock()
|
||||||
|
return errors.New("subscription already started or used")
|
||||||
}
|
}
|
||||||
|
s.started = true
|
||||||
|
s.quit = make(chan struct{})
|
||||||
|
quit := s.quit
|
||||||
|
s.mu.Unlock()
|
||||||
|
|
||||||
ticker := time.NewTicker(checkPeriod)
|
ticker := time.NewTicker(checkPeriod)
|
||||||
defer ticker.Stop()
|
defer ticker.Stop()
|
||||||
|
|
||||||
|
@ -45,22 +50,31 @@ func (s *Subscription) Start(checkPeriod time.Duration) error {
|
||||||
} else if len(filterData) > 0 {
|
} else if len(filterData) > 0 {
|
||||||
s.signal.SendData(filterData)
|
s.signal.SendData(filterData)
|
||||||
}
|
}
|
||||||
case <-s.quit:
|
case <-quit:
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Subscription) Stop(uninstall bool) error {
|
func (s *Subscription) Stop(uninstall bool) error {
|
||||||
if s.stopped {
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
if !s.started {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
select {
|
||||||
|
case _, ok := <-s.quit:
|
||||||
|
// handle a case of a closed channel
|
||||||
|
if !ok {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
default:
|
||||||
close(s.quit)
|
close(s.quit)
|
||||||
if uninstall {
|
|
||||||
return s.filter.uninstall()
|
|
||||||
}
|
}
|
||||||
s.stopped = true
|
if !uninstall {
|
||||||
return nil
|
return nil
|
||||||
|
}
|
||||||
|
return s.filter.uninstall()
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewSubscriptionID(namespace, filterID string) SubscriptionID {
|
func NewSubscriptionID(namespace, filterID string) SubscriptionID {
|
||||||
|
|
|
@ -46,7 +46,6 @@ func (s *Subscriptions) Remove(id SubscriptionID) error {
|
||||||
defer s.mu.Unlock()
|
defer s.mu.Unlock()
|
||||||
|
|
||||||
found, err := s.stopSubscription(id, true)
|
found, err := s.stopSubscription(id, true)
|
||||||
|
|
||||||
if found {
|
if found {
|
||||||
delete(s.subs, id)
|
delete(s.subs, id)
|
||||||
}
|
}
|
||||||
|
@ -76,13 +75,10 @@ func (s *Subscriptions) removeAll() error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// stopSubscription isn't thread safe!
|
|
||||||
func (s *Subscriptions) stopSubscription(id SubscriptionID, uninstall bool) (bool, error) {
|
func (s *Subscriptions) stopSubscription(id SubscriptionID, uninstall bool) (bool, error) {
|
||||||
sub, found := s.subs[id]
|
sub, found := s.subs[id]
|
||||||
if !found {
|
if !found {
|
||||||
return false, nil
|
return false, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
return true, sub.Stop(uninstall)
|
return true, sub.Stop(uninstall)
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -4,6 +4,7 @@ import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"sync"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
@ -17,6 +18,7 @@ const (
|
||||||
)
|
)
|
||||||
|
|
||||||
type mockFilter struct {
|
type mockFilter struct {
|
||||||
|
sync.Mutex
|
||||||
filterID string
|
filterID string
|
||||||
data []interface{}
|
data []interface{}
|
||||||
filterError error
|
filterError error
|
||||||
|
@ -31,9 +33,14 @@ func newMockFilter(filterID string) *mockFilter {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (mf *mockFilter) getID() string {
|
func (mf *mockFilter) getID() string {
|
||||||
|
mf.Lock()
|
||||||
|
defer mf.Unlock()
|
||||||
return mf.filterID
|
return mf.filterID
|
||||||
}
|
}
|
||||||
func (mf *mockFilter) getChanges() ([]interface{}, error) {
|
func (mf *mockFilter) getChanges() ([]interface{}, error) {
|
||||||
|
mf.Lock()
|
||||||
|
defer mf.Unlock()
|
||||||
|
|
||||||
if mf.filterError != nil {
|
if mf.filterError != nil {
|
||||||
err := mf.filterError
|
err := mf.filterError
|
||||||
mf.filterError = nil
|
mf.filterError = nil
|
||||||
|
@ -46,15 +53,21 @@ func (mf *mockFilter) getChanges() ([]interface{}, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (mf *mockFilter) uninstall() error {
|
func (mf *mockFilter) uninstall() error {
|
||||||
|
mf.Lock()
|
||||||
|
defer mf.Unlock()
|
||||||
mf.uninstalled = true
|
mf.uninstalled = true
|
||||||
return mf.uninstallError
|
return mf.uninstallError
|
||||||
}
|
}
|
||||||
|
|
||||||
func (mf *mockFilter) setData(data ...interface{}) {
|
func (mf *mockFilter) setData(data ...interface{}) {
|
||||||
|
mf.Lock()
|
||||||
|
defer mf.Unlock()
|
||||||
mf.data = data
|
mf.data = data
|
||||||
}
|
}
|
||||||
|
|
||||||
func (mf *mockFilter) setError(err error) {
|
func (mf *mockFilter) setError(err error) {
|
||||||
|
mf.Lock()
|
||||||
|
defer mf.Unlock()
|
||||||
mf.data = nil
|
mf.data = nil
|
||||||
mf.filterError = err
|
mf.filterError = err
|
||||||
}
|
}
|
||||||
|
@ -121,13 +134,13 @@ func TestSubscriptionGetError(t *testing.T) {
|
||||||
|
|
||||||
func TestSubscriptionRemove(t *testing.T) {
|
func TestSubscriptionRemove(t *testing.T) {
|
||||||
filter := newMockFilter(filterID)
|
filter := newMockFilter(filterID)
|
||||||
|
|
||||||
subs := NewSubscriptions(time.Microsecond)
|
subs := NewSubscriptions(time.Microsecond)
|
||||||
|
|
||||||
subID, _ := subs.Create(filterNS, filter)
|
subID, err := subs.Create(filterNS, filter)
|
||||||
|
require.NoError(t, err)
|
||||||
|
time.Sleep(time.Millisecond * 100) // create starts in a goroutine
|
||||||
|
|
||||||
require.NoError(t, subs.Remove(subID))
|
require.NoError(t, subs.Remove(subID))
|
||||||
|
|
||||||
require.True(t, filter.uninstalled)
|
require.True(t, filter.uninstalled)
|
||||||
require.Empty(t, subs.subs)
|
require.Empty(t, subs.subs)
|
||||||
}
|
}
|
||||||
|
@ -137,11 +150,11 @@ func TestSubscriptionRemoveError(t *testing.T) {
|
||||||
filter.uninstallError = errors.New("uninstall-error-1")
|
filter.uninstallError = errors.New("uninstall-error-1")
|
||||||
|
|
||||||
subs := NewSubscriptions(time.Microsecond)
|
subs := NewSubscriptions(time.Microsecond)
|
||||||
|
subID, err := subs.Create(filterNS, filter)
|
||||||
subID, _ := subs.Create(filterNS, filter)
|
require.NoError(t, err)
|
||||||
|
time.Sleep(time.Millisecond * 100) // create starts in a goroutine
|
||||||
|
|
||||||
require.Equal(t, subs.Remove(subID), filter.uninstallError)
|
require.Equal(t, subs.Remove(subID), filter.uninstallError)
|
||||||
|
|
||||||
require.True(t, filter.uninstalled)
|
require.True(t, filter.uninstalled)
|
||||||
require.Equal(t, len(subs.subs), 0)
|
require.Equal(t, len(subs.subs), 0)
|
||||||
}
|
}
|
||||||
|
@ -155,14 +168,13 @@ func TestSubscriptionRemoveAll(t *testing.T) {
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
_, err = subs.Create(filterNS, filter1)
|
_, err = subs.Create(filterNS, filter1)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
time.Sleep(time.Millisecond * 100) // create starts in a goroutine
|
||||||
|
|
||||||
require.Equal(t, len(subs.subs), 2)
|
require.Equal(t, len(subs.subs), 2)
|
||||||
|
err = subs.removeAll()
|
||||||
require.NoError(t, subs.removeAll())
|
require.NoError(t, err)
|
||||||
|
|
||||||
require.False(t, filter0.uninstalled)
|
require.False(t, filter0.uninstalled)
|
||||||
require.False(t, filter1.uninstalled)
|
require.False(t, filter1.uninstalled)
|
||||||
|
|
||||||
require.Equal(t, len(subs.subs), 0)
|
require.Equal(t, len(subs.subs), 0)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -171,9 +183,7 @@ func validateFilterError(t *testing.T, jsonEvent string, expectedSubID string, e
|
||||||
Event signal.SubscriptionErrorEvent `json:"event"`
|
Event signal.SubscriptionErrorEvent `json:"event"`
|
||||||
Type string `json:"type"`
|
Type string `json:"type"`
|
||||||
}{}
|
}{}
|
||||||
|
|
||||||
require.NoError(t, json.Unmarshal([]byte(jsonEvent), &result))
|
require.NoError(t, json.Unmarshal([]byte(jsonEvent), &result))
|
||||||
|
|
||||||
require.Equal(t, signal.EventSubscriptionsError, result.Type)
|
require.Equal(t, signal.EventSubscriptionsError, result.Type)
|
||||||
require.Equal(t, expectedErrorMessage, result.Event.ErrorMessage)
|
require.Equal(t, expectedErrorMessage, result.Event.ErrorMessage)
|
||||||
}
|
}
|
||||||
|
@ -183,11 +193,8 @@ func validateFilterData(t *testing.T, jsonEvent string, expectedSubID string, ex
|
||||||
Event signal.SubscriptionDataEvent `json:"event"`
|
Event signal.SubscriptionDataEvent `json:"event"`
|
||||||
Type string `json:"type"`
|
Type string `json:"type"`
|
||||||
}{}
|
}{}
|
||||||
|
|
||||||
require.NoError(t, json.Unmarshal([]byte(jsonEvent), &result))
|
require.NoError(t, json.Unmarshal([]byte(jsonEvent), &result))
|
||||||
|
|
||||||
require.Equal(t, signal.EventSubscriptionsData, result.Type)
|
require.Equal(t, signal.EventSubscriptionsData, result.Type)
|
||||||
require.Equal(t, expectedData, result.Event.Data)
|
require.Equal(t, expectedData, result.Event.Data)
|
||||||
require.Equal(t, expectedSubID, result.Event.FilterID)
|
require.Equal(t, expectedSubID, result.Event.FilterID)
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue