Fix race condition in subscriptions (#1646)
This commit is contained in:
parent
0571f561f0
commit
203f29b13e
|
@ -38,7 +38,7 @@ func (api *API) SubscribeSignal(method string, args []interface{}) (Subscription
|
|||
}
|
||||
|
||||
if err != nil {
|
||||
return SubscriptionID(""), fmt.Errorf("[SubscribeSignal] could not subscribe, failed to call %s: %v", method, err)
|
||||
return "", fmt.Errorf("[SubscribeSignal] could not subscribe, failed to call %s: %v", method, err)
|
||||
}
|
||||
|
||||
return api.activeSubscriptions.Create(namespace, filter)
|
||||
|
@ -47,7 +47,3 @@ func (api *API) SubscribeSignal(method string, args []interface{}) (Subscription
|
|||
func (api *API) UnsubscribeSignal(id string) error {
|
||||
return api.activeSubscriptions.Remove(SubscriptionID(id))
|
||||
}
|
||||
|
||||
func (api *API) shutdown() error {
|
||||
return api.activeSubscriptions.removeAll()
|
||||
}
|
||||
|
|
|
@ -47,5 +47,5 @@ func (s *Service) Start(server *p2p.Server) error {
|
|||
|
||||
// Stop is run when a service is stopped.
|
||||
func (s *Service) Stop() error {
|
||||
return s.api.shutdown()
|
||||
return s.api.activeSubscriptions.removeAll()
|
||||
}
|
||||
|
|
|
@ -3,36 +3,41 @@ package subscriptions
|
|||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
type SubscriptionID string
|
||||
|
||||
type Subscription struct {
|
||||
mu sync.RWMutex
|
||||
id SubscriptionID
|
||||
signal *filterSignal
|
||||
quit chan struct{}
|
||||
filter filter
|
||||
stopped bool
|
||||
started bool
|
||||
}
|
||||
|
||||
func NewSubscription(namespace string, filter filter) *Subscription {
|
||||
subscriptionID := NewSubscriptionID(namespace, filter.getID())
|
||||
|
||||
quit := make(chan struct{})
|
||||
|
||||
return &Subscription{
|
||||
id: subscriptionID,
|
||||
quit: quit,
|
||||
signal: newFilterSignal(string(subscriptionID)),
|
||||
filter: filter,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Subscription) Start(checkPeriod time.Duration) error {
|
||||
if s.stopped {
|
||||
return errors.New("it is impossible to start an already stopped subscription")
|
||||
s.mu.Lock()
|
||||
if s.started {
|
||||
s.mu.Unlock()
|
||||
return errors.New("subscription already started or used")
|
||||
}
|
||||
s.started = true
|
||||
s.quit = make(chan struct{})
|
||||
quit := s.quit
|
||||
s.mu.Unlock()
|
||||
|
||||
ticker := time.NewTicker(checkPeriod)
|
||||
defer ticker.Stop()
|
||||
|
||||
|
@ -45,22 +50,31 @@ func (s *Subscription) Start(checkPeriod time.Duration) error {
|
|||
} else if len(filterData) > 0 {
|
||||
s.signal.SendData(filterData)
|
||||
}
|
||||
case <-s.quit:
|
||||
case <-quit:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Subscription) Stop(uninstall bool) error {
|
||||
if s.stopped {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
if !s.started {
|
||||
return nil
|
||||
}
|
||||
close(s.quit)
|
||||
if uninstall {
|
||||
return s.filter.uninstall()
|
||||
select {
|
||||
case _, ok := <-s.quit:
|
||||
// handle a case of a closed channel
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
default:
|
||||
close(s.quit)
|
||||
}
|
||||
s.stopped = true
|
||||
return nil
|
||||
if !uninstall {
|
||||
return nil
|
||||
}
|
||||
return s.filter.uninstall()
|
||||
}
|
||||
|
||||
func NewSubscriptionID(namespace, filterID string) SubscriptionID {
|
||||
|
|
|
@ -46,7 +46,6 @@ func (s *Subscriptions) Remove(id SubscriptionID) error {
|
|||
defer s.mu.Unlock()
|
||||
|
||||
found, err := s.stopSubscription(id, true)
|
||||
|
||||
if found {
|
||||
delete(s.subs, id)
|
||||
}
|
||||
|
@ -76,13 +75,10 @@ func (s *Subscriptions) removeAll() error {
|
|||
return nil
|
||||
}
|
||||
|
||||
// stopSubscription isn't thread safe!
|
||||
func (s *Subscriptions) stopSubscription(id SubscriptionID, uninstall bool) (bool, error) {
|
||||
sub, found := s.subs[id]
|
||||
if !found {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
return true, sub.Stop(uninstall)
|
||||
|
||||
}
|
||||
|
|
|
@ -4,6 +4,7 @@ import (
|
|||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
|
@ -17,6 +18,7 @@ const (
|
|||
)
|
||||
|
||||
type mockFilter struct {
|
||||
sync.Mutex
|
||||
filterID string
|
||||
data []interface{}
|
||||
filterError error
|
||||
|
@ -31,9 +33,14 @@ func newMockFilter(filterID string) *mockFilter {
|
|||
}
|
||||
|
||||
func (mf *mockFilter) getID() string {
|
||||
mf.Lock()
|
||||
defer mf.Unlock()
|
||||
return mf.filterID
|
||||
}
|
||||
func (mf *mockFilter) getChanges() ([]interface{}, error) {
|
||||
mf.Lock()
|
||||
defer mf.Unlock()
|
||||
|
||||
if mf.filterError != nil {
|
||||
err := mf.filterError
|
||||
mf.filterError = nil
|
||||
|
@ -46,15 +53,21 @@ func (mf *mockFilter) getChanges() ([]interface{}, error) {
|
|||
}
|
||||
|
||||
func (mf *mockFilter) uninstall() error {
|
||||
mf.Lock()
|
||||
defer mf.Unlock()
|
||||
mf.uninstalled = true
|
||||
return mf.uninstallError
|
||||
}
|
||||
|
||||
func (mf *mockFilter) setData(data ...interface{}) {
|
||||
mf.Lock()
|
||||
defer mf.Unlock()
|
||||
mf.data = data
|
||||
}
|
||||
|
||||
func (mf *mockFilter) setError(err error) {
|
||||
mf.Lock()
|
||||
defer mf.Unlock()
|
||||
mf.data = nil
|
||||
mf.filterError = err
|
||||
}
|
||||
|
@ -121,13 +134,13 @@ func TestSubscriptionGetError(t *testing.T) {
|
|||
|
||||
func TestSubscriptionRemove(t *testing.T) {
|
||||
filter := newMockFilter(filterID)
|
||||
|
||||
subs := NewSubscriptions(time.Microsecond)
|
||||
|
||||
subID, _ := subs.Create(filterNS, filter)
|
||||
subID, err := subs.Create(filterNS, filter)
|
||||
require.NoError(t, err)
|
||||
time.Sleep(time.Millisecond * 100) // create starts in a goroutine
|
||||
|
||||
require.NoError(t, subs.Remove(subID))
|
||||
|
||||
require.True(t, filter.uninstalled)
|
||||
require.Empty(t, subs.subs)
|
||||
}
|
||||
|
@ -137,11 +150,11 @@ func TestSubscriptionRemoveError(t *testing.T) {
|
|||
filter.uninstallError = errors.New("uninstall-error-1")
|
||||
|
||||
subs := NewSubscriptions(time.Microsecond)
|
||||
|
||||
subID, _ := subs.Create(filterNS, filter)
|
||||
subID, err := subs.Create(filterNS, filter)
|
||||
require.NoError(t, err)
|
||||
time.Sleep(time.Millisecond * 100) // create starts in a goroutine
|
||||
|
||||
require.Equal(t, subs.Remove(subID), filter.uninstallError)
|
||||
|
||||
require.True(t, filter.uninstalled)
|
||||
require.Equal(t, len(subs.subs), 0)
|
||||
}
|
||||
|
@ -155,14 +168,13 @@ func TestSubscriptionRemoveAll(t *testing.T) {
|
|||
require.NoError(t, err)
|
||||
_, err = subs.Create(filterNS, filter1)
|
||||
require.NoError(t, err)
|
||||
time.Sleep(time.Millisecond * 100) // create starts in a goroutine
|
||||
|
||||
require.Equal(t, len(subs.subs), 2)
|
||||
|
||||
require.NoError(t, subs.removeAll())
|
||||
|
||||
err = subs.removeAll()
|
||||
require.NoError(t, err)
|
||||
require.False(t, filter0.uninstalled)
|
||||
require.False(t, filter1.uninstalled)
|
||||
|
||||
require.Equal(t, len(subs.subs), 0)
|
||||
}
|
||||
|
||||
|
@ -171,9 +183,7 @@ func validateFilterError(t *testing.T, jsonEvent string, expectedSubID string, e
|
|||
Event signal.SubscriptionErrorEvent `json:"event"`
|
||||
Type string `json:"type"`
|
||||
}{}
|
||||
|
||||
require.NoError(t, json.Unmarshal([]byte(jsonEvent), &result))
|
||||
|
||||
require.Equal(t, signal.EventSubscriptionsError, result.Type)
|
||||
require.Equal(t, expectedErrorMessage, result.Event.ErrorMessage)
|
||||
}
|
||||
|
@ -183,11 +193,8 @@ func validateFilterData(t *testing.T, jsonEvent string, expectedSubID string, ex
|
|||
Event signal.SubscriptionDataEvent `json:"event"`
|
||||
Type string `json:"type"`
|
||||
}{}
|
||||
|
||||
require.NoError(t, json.Unmarshal([]byte(jsonEvent), &result))
|
||||
|
||||
require.Equal(t, signal.EventSubscriptionsData, result.Type)
|
||||
require.Equal(t, expectedData, result.Event.Data)
|
||||
require.Equal(t, expectedSubID, result.Event.FilterID)
|
||||
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue