From 8a45365f68147dd13c1232bc574bd54e4a12e478 Mon Sep 17 00:00:00 2001 From: Frank Schroeder Date: Thu, 19 Oct 2017 16:39:16 +0200 Subject: [PATCH] ae: refactor StateSyncer to state machine for better testing --- agent/ae/ae.go | 235 ++++++++++++++++++++++++++++++-------------- agent/ae/ae_test.go | 200 ++++++++++++++++++++++++++++++++----- 2 files changed, 336 insertions(+), 99 deletions(-) diff --git a/agent/ae/ae.go b/agent/ae/ae.go index 83895f4d2b..a6851e0198 100644 --- a/agent/ae/ae.go +++ b/agent/ae/ae.go @@ -2,7 +2,7 @@ package ae import ( - "errors" + "fmt" "log" "math" "sync" @@ -37,7 +37,7 @@ func scaleFactor(nodes int) int { return int(math.Ceil(math.Log2(float64(nodes))-math.Log2(float64(scaleThreshold))) + 1.0) } -type State interface { +type SyncState interface { SyncChanges() error SyncFull() error } @@ -51,7 +51,7 @@ type State interface { // for the cluster which is also called anti-entropy. type StateSyncer struct { // State contains the data that needs to be synchronized. - State State + State SyncState // Interval is the time between two full sync runs. Interval time.Duration @@ -79,15 +79,23 @@ type StateSyncer struct { pauseLock sync.Mutex paused int - // stagger randomly picks a duration between 0s and the given duration. - stagger func(time.Duration) time.Duration - // serverUpInterval is the max time after which a full sync is // performed when a server has been added to the cluster. serverUpInterval time.Duration // retryFailInterval is the time after which a failed full sync is retried. retryFailInterval time.Duration + + // stagger randomly picks a duration between 0s and the given duration. + stagger func(time.Duration) time.Duration + + // retrySyncFullEvent generates an event based on multiple conditions + // when the state machine is trying to retry a full state sync. 
+ retrySyncFullEvent func() event + + // syncChangesEvent generates an event based on multiple conditions + // when the state machine is performing partial state syncs. + syncChangesEvent func() event } const ( @@ -99,7 +107,7 @@ const ( retryFailIntv = 15 * time.Second ) -func NewStateSyncer(state State, intv time.Duration, shutdownCh chan struct{}, logger *log.Logger) *StateSyncer { +func NewStateSyncer(state SyncState, intv time.Duration, shutdownCh chan struct{}, logger *log.Logger) *StateSyncer { s := &StateSyncer{ State: state, Interval: intv, @@ -110,14 +118,25 @@ func NewStateSyncer(state State, intv time.Duration, shutdownCh chan struct{}, l serverUpInterval: serverUpIntv, retryFailInterval: retryFailIntv, } - s.stagger = func(d time.Duration) time.Duration { - f := scaleFactor(s.ClusterSize()) - return lib.RandomStagger(time.Duration(f) * d) - } + + // retain these methods as member variables so that + // we can mock them for testing. + s.retrySyncFullEvent = s.retrySyncFullEventFn + s.syncChangesEvent = s.syncChangesEventFn + s.stagger = s.staggerFn + return s } -var errPaused = errors.New("paused") +// fsmState defines states for the state machine. +type fsmState string + +const ( + doneState fsmState = "done" + fullSyncState fsmState = "fullSync" + partialSyncState fsmState = "partialSync" + retryFullSyncState fsmState = "retryFullSync" +) // Run is the long running method to perform state synchronization // between local and remote servers. @@ -125,77 +144,141 @@ func (s *StateSyncer) Run() { if s.ClusterSize == nil { panic("ClusterSize not set") } + s.runFSM(fullSyncState, s.nextFSMState) +} -FullSync: +// runFSM runs the state machine. +func (s *StateSyncer) runFSM(fs fsmState, next func(fsmState) fsmState) { for { - // attempt a full sync - err := s.ifNotPausedRun(s.State.SyncFull) - if err != nil { - if err != errPaused { - s.Logger.Printf("[ERR] agent: failed to sync remote state: %v", err) - } - - select { - // trigger a full sync immediately. 
- // this is usually called when a consul server was added to the cluster. - // stagger the delay to avoid a thundering herd. - case <-s.SyncFull.Notif(): - select { - case <-time.After(s.stagger(s.serverUpInterval)): - continue FullSync - case <-s.ShutdownCh: - return - } - - // retry full sync after some time - // todo(fs): why don't we use s.Interval here? - case <-time.After(s.retryFailInterval + s.stagger(s.retryFailInterval)): - continue FullSync - - case <-s.ShutdownCh: - return - } - } - - // do partial syncs until it is time for a full sync again - for { - select { - // trigger a full sync immediately - // this is usually called when a consul server was added to the cluster. - // stagger the delay to avoid a thundering herd. - case <-s.SyncFull.Notif(): - select { - case <-time.After(s.stagger(s.serverUpInterval)): - continue FullSync - case <-s.ShutdownCh: - return - } - - // time for a full sync again - case <-time.After(s.Interval + s.stagger(s.Interval)): - continue FullSync - - // do partial syncs on demand - case <-s.SyncChanges.Notif(): - err := s.ifNotPausedRun(s.State.SyncChanges) - if err != nil && err != errPaused { - s.Logger.Printf("[ERR] agent: failed to sync changes: %v", err) - } - - case <-s.ShutdownCh: - return - } + if fs = next(fs); fs == doneState { + return } } } -func (s *StateSyncer) ifNotPausedRun(f func() error) error { - s.pauseLock.Lock() - defer s.pauseLock.Unlock() - if s.paused != 0 { - return errPaused +// nextFSMState determines the next state based on the current state. 
+func (s *StateSyncer) nextFSMState(fs fsmState) fsmState { + switch fs { + case fullSyncState: + if s.Paused() { + return retryFullSyncState + } + + err := s.State.SyncFull() + if err != nil { + s.Logger.Printf("[ERR] agent: failed to sync remote state: %v", err) + return retryFullSyncState + } + + return partialSyncState + + case retryFullSyncState: + e := s.retrySyncFullEvent() + switch e { + case syncFullNotifEvent, syncFullTimerEvent: + return fullSyncState + case shutdownEvent: + return doneState + default: + panic(fmt.Sprintf("invalid event: %s", e)) + } + + case partialSyncState: + e := s.syncChangesEvent() + switch e { + case syncFullNotifEvent, syncFullTimerEvent: + return fullSyncState + + case syncChangesNotifEvent: + if s.Paused() { + return partialSyncState + } + + err := s.State.SyncChanges() + if err != nil { + s.Logger.Printf("[ERR] agent: failed to sync changes: %v", err) + } + return partialSyncState + + case shutdownEvent: + return doneState + + default: + panic(fmt.Sprintf("invalid event: %s", e)) + } + + default: + panic(fmt.Sprintf("invalid state: %s", fs)) } - return f() +} + +// event defines a timing or notification event from multiple +// timers and channels. +type event string + +const ( + shutdownEvent event = "shutdown" + syncFullNotifEvent event = "syncFullNotif" + syncFullTimerEvent event = "syncFullTimer" + syncChangesNotifEvent event = "syncChangesNotif" +) + +// retrySyncFullEventFn waits for an event which triggers a retry +// of a full sync or a termination signal. +func (s *StateSyncer) retrySyncFullEventFn() event { + select { + // trigger a full sync immediately. + // this is usually called when a consul server was added to the cluster. + // stagger the delay to avoid a thundering herd. 
+ case <-s.SyncFull.Notif(): + select { + case <-time.After(s.stagger(s.serverUpInterval)): + return syncFullNotifEvent + case <-s.ShutdownCh: + return shutdownEvent + } + + // retry full sync after some time + // todo(fs): why don't we use s.Interval here? + case <-time.After(s.retryFailInterval + s.stagger(s.retryFailInterval)): + return syncFullTimerEvent + + case <-s.ShutdownCh: + return shutdownEvent + } +} + +// syncChangesEventFn waits for an event which either triggers a full +// or a partial sync or a termination signal. +func (s *StateSyncer) syncChangesEventFn() event { + select { + // trigger a full sync immediately + // this is usually called when a consul server was added to the cluster. + // stagger the delay to avoid a thundering herd. + case <-s.SyncFull.Notif(): + select { + case <-time.After(s.stagger(s.serverUpInterval)): + return syncFullNotifEvent + case <-s.ShutdownCh: + return shutdownEvent + } + + // time for a full sync again + case <-time.After(s.Interval + s.stagger(s.Interval)): + return syncFullTimerEvent + + // do partial syncs on demand + case <-s.SyncChanges.Notif(): + return syncChangesNotifEvent + + case <-s.ShutdownCh: + return shutdownEvent + } +} + +func (s *StateSyncer) staggerFn(d time.Duration) time.Duration { + f := scaleFactor(s.ClusterSize()) + return lib.RandomStagger(time.Duration(f) * d) } // Pause temporarily disables sync runs. 
diff --git a/agent/ae/ae_test.go b/agent/ae/ae_test.go index 8d0afb5cb1..eb33bcb0b7 100644 --- a/agent/ae/ae_test.go +++ b/agent/ae/ae_test.go @@ -57,7 +57,7 @@ func TestAE_Pause_nestedPauseResume(t *testing.T) { defer func() { err := recover() if err == nil { - t.Fatal("unbalanced Resume() should cause a panic()") + t.Fatal("unbalanced Resume() should panic") } }() l.Resume() @@ -77,25 +77,6 @@ func TestAE_Pause_ResumeTriggersSyncChanges(t *testing.T) { } } -func TestAE_Pause_ifNotPausedRun(t *testing.T) { - l := NewStateSyncer(nil, 0, nil, nil) - - errCalled := errors.New("f called") - f := func() error { return errCalled } - - l.Pause() - err := l.ifNotPausedRun(f) - if got, want := err, errPaused; !reflect.DeepEqual(got, want) { - t.Fatalf("got error %q want %q", got, want) - } - l.Resume() - - err = l.ifNotPausedRun(f) - if got, want := err, errCalled; got != want { - t.Fatalf("got error %q want %q", got, want) - } -} - func TestAE_Run_SyncFullBeforeChanges(t *testing.T) { shutdownCh := make(chan struct{}) state := &mock{ @@ -106,7 +87,9 @@ func TestAE_Run_SyncFullBeforeChanges(t *testing.T) { } // indicate that we have partial changes before starting Run - l := testSyncer(state, shutdownCh) + l := testSyncer() + l.State = state + l.ShutdownCh = shutdownCh l.SyncChanges.Trigger() var wg sync.WaitGroup @@ -122,6 +105,177 @@ func TestAE_Run_SyncFullBeforeChanges(t *testing.T) { } } +func TestAE_Run_Quit(t *testing.T) { + // start timer which explodes if runFSM does not quit + tm := time.AfterFunc(time.Second, func() { panic("timeout") }) + + l := testSyncer() + l.runFSM(fullSyncState, func(fsmState) fsmState { return doneState }) + // should just quit + tm.Stop() +} + +func TestAE_FSM(t *testing.T) { + + t.Run("fullSyncState", func(t *testing.T) { + t.Run("Paused -> retryFullSyncState", func(t *testing.T) { + l := testSyncer() + l.Pause() + fs := l.nextFSMState(fullSyncState) + if got, want := fs, retryFullSyncState; got != want { + t.Fatalf("got state %v want 
%v", got, want) + } + }) + t.Run("SyncFull() error -> retryFullSyncState", func(t *testing.T) { + l := testSyncer() + l.State = &mock{syncFull: func() error { return errors.New("boom") }} + fs := l.nextFSMState(fullSyncState) + if got, want := fs, retryFullSyncState; got != want { + t.Fatalf("got state %v want %v", got, want) + } + }) + t.Run("SyncFull() OK -> partialSyncState", func(t *testing.T) { + l := testSyncer() + l.State = &mock{} + fs := l.nextFSMState(fullSyncState) + if got, want := fs, partialSyncState; got != want { + t.Fatalf("got state %v want %v", got, want) + } + }) + }) + + t.Run("retryFullSyncState", func(t *testing.T) { + // helper for testing state transitions from retryFullSyncState + test := func(ev event, to fsmState) { + l := testSyncer() + l.retrySyncFullEvent = func() event { return ev } + fs := l.nextFSMState(retryFullSyncState) + if got, want := fs, to; got != want { + t.Fatalf("got state %v want %v", got, want) + } + } + t.Run("shutdownEvent -> doneState", func(t *testing.T) { + test(shutdownEvent, doneState) + }) + t.Run("syncFullNotifEvent -> fullSyncState", func(t *testing.T) { + test(syncFullNotifEvent, fullSyncState) + }) + t.Run("syncFullTimerEvent -> fullSyncState", func(t *testing.T) { + test(syncFullTimerEvent, fullSyncState) + }) + t.Run("invalid event -> panic ", func(t *testing.T) { + defer func() { + err := recover() + if err == nil { + t.Fatal("invalid event should panic") + } + }() + test(event("invalid"), fsmState("")) + }) + }) + + t.Run("partialSyncState", func(t *testing.T) { + // helper for testing state transitions from partialSyncState + test := func(ev event, to fsmState) { + l := testSyncer() + l.syncChangesEvent = func() event { return ev } + fs := l.nextFSMState(partialSyncState) + if got, want := fs, to; got != want { + t.Fatalf("got state %v want %v", got, want) + } + } + t.Run("shutdownEvent -> doneState", func(t *testing.T) { + test(shutdownEvent, doneState) + }) + t.Run("syncFullNotifEvent -> 
fullSyncState", func(t *testing.T) { + test(syncFullNotifEvent, fullSyncState) + }) + t.Run("syncFullTimerEvent -> fullSyncState", func(t *testing.T) { + test(syncFullTimerEvent, fullSyncState) + }) + t.Run("syncChangesEvent+Paused -> partialSyncState", func(t *testing.T) { + l := testSyncer() + l.Pause() + l.syncChangesEvent = func() event { return syncChangesNotifEvent } + fs := l.nextFSMState(partialSyncState) + if got, want := fs, partialSyncState; got != want { + t.Fatalf("got state %v want %v", got, want) + } + }) + t.Run("syncChangesEvent+SyncChanges() error -> partialSyncState", func(t *testing.T) { + l := testSyncer() + l.State = &mock{syncChanges: func() error { return errors.New("boom") }} + l.syncChangesEvent = func() event { return syncChangesNotifEvent } + fs := l.nextFSMState(partialSyncState) + if got, want := fs, partialSyncState; got != want { + t.Fatalf("got state %v want %v", got, want) + } + }) + t.Run("syncChangesEvent+SyncChanges() OK -> partialSyncState", func(t *testing.T) { + l := testSyncer() + l.State = &mock{} + l.syncChangesEvent = func() event { return syncChangesNotifEvent } + fs := l.nextFSMState(partialSyncState) + if got, want := fs, partialSyncState; got != want { + t.Fatalf("got state %v want %v", got, want) + } + }) + }) +} + +func TestAE_SyncChangesEvent(t *testing.T) { + t.Run("trigger shutdownEvent", func(t *testing.T) { + l := testSyncer() + l.ShutdownCh = make(chan struct{}) + evch := make(chan event) + go func() { evch <- l.syncChangesEvent() }() + close(l.ShutdownCh) + if got, want := <-evch, shutdownEvent; got != want { + t.Fatalf("got event %q want %q", got, want) + } + }) + t.Run("trigger shutdownEvent during FullNotif", func(t *testing.T) { + l := testSyncer() + l.ShutdownCh = make(chan struct{}) + evch := make(chan event) + go func() { evch <- l.syncChangesEvent() }() + l.SyncFull.Trigger() + time.Sleep(100 * time.Millisecond) + close(l.ShutdownCh) + if got, want := <-evch, shutdownEvent; got != want { + 
t.Fatalf("got event %q want %q", got, want) + } + }) + t.Run("trigger syncFullNotifEvent", func(t *testing.T) { + l := testSyncer() + l.serverUpInterval = 10 * time.Millisecond + evch := make(chan event) + go func() { evch <- l.syncChangesEvent() }() + l.SyncFull.Trigger() + if got, want := <-evch, syncFullNotifEvent; got != want { + t.Fatalf("got event %q want %q", got, want) + } + }) + t.Run("trigger syncFullTimerEvent", func(t *testing.T) { + l := testSyncer() + l.Interval = 10 * time.Millisecond + evch := make(chan event) + go func() { evch <- l.syncChangesEvent() }() + if got, want := <-evch, syncFullTimerEvent; got != want { + t.Fatalf("got event %q want %q", got, want) + } + }) + t.Run("trigger syncChangesNotifEvent", func(t *testing.T) { + l := testSyncer() + evch := make(chan event) + go func() { evch <- l.syncChangesEvent() }() + l.SyncChanges.Trigger() + if got, want := <-evch, syncChangesNotifEvent; got != want { + t.Fatalf("got event %q want %q", got, want) + } + }) +} + type mock struct { seq []string syncFull, syncChanges func() error @@ -143,9 +297,9 @@ func (m *mock) SyncChanges() error { return nil } -func testSyncer(state State, shutdownCh chan struct{}) *StateSyncer { +func testSyncer() *StateSyncer { logger := log.New(os.Stderr, "", 0) - l := NewStateSyncer(state, 0, shutdownCh, logger) + l := NewStateSyncer(nil, time.Second, nil, logger) l.stagger = func(d time.Duration) time.Duration { return d } l.ClusterSize = func() int { return 1 } return l