From cb01702cd227692db37f24a207ab3ca0785cbc37 Mon Sep 17 00:00:00 2001 From: Matt Keeler Date: Tue, 21 Jun 2022 13:36:49 -0400 Subject: [PATCH] Add server local blocking queries and watches (#13438) Co-authored-by: Dan Upton --- agent/consul/watch/mock_StateStore_test.go | 40 ++ agent/consul/watch/server_local.go | 332 ++++++++++++++++ agent/consul/watch/server_local_test.go | 424 +++++++++++++++++++++ 3 files changed, 796 insertions(+) create mode 100644 agent/consul/watch/mock_StateStore_test.go create mode 100644 agent/consul/watch/server_local.go create mode 100644 agent/consul/watch/server_local_test.go diff --git a/agent/consul/watch/mock_StateStore_test.go b/agent/consul/watch/mock_StateStore_test.go new file mode 100644 index 0000000000..08d58e2f04 --- /dev/null +++ b/agent/consul/watch/mock_StateStore_test.go @@ -0,0 +1,40 @@ +// Code generated by mockery v2.12.2. DO NOT EDIT. + +package watch + +import ( + testing "testing" + + mock "github.com/stretchr/testify/mock" +) + +// MockStateStore is an autogenerated mock type for the StateStore type +type MockStateStore struct { + mock.Mock +} + +// AbandonCh provides a mock function with given fields: +func (_m *MockStateStore) AbandonCh() <-chan struct{} { + ret := _m.Called() + + var r0 <-chan struct{} + if rf, ok := ret.Get(0).(func() <-chan struct{}); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(<-chan struct{}) + } + } + + return r0 +} + +// NewMockStateStore creates a new instance of MockStateStore. It also registers the testing.TB interface on the mock and a cleanup function to assert the mocks expectations. 
+func NewMockStateStore(t testing.TB) *MockStateStore { + mock := &MockStateStore{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} diff --git a/agent/consul/watch/server_local.go b/agent/consul/watch/server_local.go new file mode 100644 index 0000000000..8085396fe2 --- /dev/null +++ b/agent/consul/watch/server_local.go @@ -0,0 +1,332 @@ +package watch + +import ( + "context" + "errors" + "fmt" + "time" + + "github.com/hashicorp/consul/lib/retry" + "github.com/hashicorp/go-memdb" + hashstructure_v2 "github.com/mitchellh/hashstructure/v2" +) + +var ( + ErrorNotFound = errors.New("no data found for query") + ErrorNotChanged = errors.New("data did not change for query") + + errNilContext = errors.New("cannot call ServerLocalNotify with a nil context") + errNilGetStore = errors.New("cannot call ServerLocalNotify without a callback to get a StateStore") + errNilQuery = errors.New("cannot call ServerLocalNotify without a callback to perform the query") + errNilNotify = errors.New("cannot call ServerLocalNotify without a callback to send notifications") +) + +//go:generate mockery --name StateStore --inpackage --testonly +type StateStore interface { + AbandonCh() <-chan struct{} +} + +const ( + defaultWaiterMinFailures uint = 1 + defaultWaiterMinWait = time.Second + defaultWaiterMaxWait = 60 * time.Second + defaultWaiterFactor = 2 * time.Second +) + +var ( + defaultWaiterJitter = retry.NewJitter(100) +) + +func defaultWaiter() *retry.Waiter { + return &retry.Waiter{ + MinFailures: defaultWaiterMinFailures, + MinWait: defaultWaiterMinWait, + MaxWait: defaultWaiterMaxWait, + Jitter: defaultWaiterJitter, + Factor: defaultWaiterFactor, + } +} + +// noopDone can be passed to serverLocalNotifyWithWaiter +func noopDone() {} + +// ServerLocalBlockingQuery performs a blocking query similar to the pre-existing blockingQuery +// method on the agent/consul.Server type. There are a few key differences. +// +// 1. 
This function makes use of Go 1.18 generics. The function is parameterized with two +// types. The first is the ResultType which can be anything. Having this be parameterized +// instead of using interface{} allows us to simplify the call sites so that no type +// coercion from interface{} to the real type is necessary. The second parameterized type +// is something that VERY loosely resembles an agent/consul/state.Store type. The StateStore +// interface in this package has a single method to get the store's abandon channel so we +// know when a snapshot restore is occurring and can act accordingly. We could have not +// parameterized this type and used a real *state.Store instead but then we would have +// concrete dependencies on the state package and it would make it a little harder to +// test this function. +// +// We could have also avoided the need to use a ResultType parameter by taking the route +// the original blockingQuery method did and just assuming all callers close around +// a pointer to their results and can modify it as necessary. That way of doing things +// feels a little gross so I have taken this one in a different direction. The old way +// also gets especially gross with how we have to push concerns of spurious wakeup +// suppression down into every call site. +// +// 2. This method has no internal timeout and can potentially run forever until a state +// change is observed. If there is a desire to have a timeout, that should be built into +// the context.Context passed as the first argument. +// +// 3. This method bakes in some newer functionality around hashing of results to prevent sending +// back data when nothing has actually changed. With the old blockingQuery method this has to +// be done within the closure passed to the method which means the same bit of code is duplicated +// in many places. As this functionality isn't necessary in many scenarios, whether to opt in to +// that behavior is an argument to this function. 
+// +// Similar to the older method: +// +// 1. Errors returned from the query will be propagated back to the caller. +// +// The query function must follow these rules: +// +// 1. To access data it must use the passed-in StoreType (which will be a state.Store when +// everything gets stitched together outside of unit tests). +// 2. It must return an index greater than the minIndex if the results returned by the query +// have changed. +// 3. Any channels added to the memdb.WatchSet must unblock when the results +// returned by the query have changed. +// +// To ensure optimal performance of the query, the query function should make a +// best-effort attempt to follow these guidelines: +// +// 1. Only return an index greater than the minIndex. +// 2. Any channels added to the memdb.WatchSet should only unblock when the +// results returned by the query have changed. This might be difficult +// to do when blocking on non-existent data. +// +func ServerLocalBlockingQuery[ResultType any, StoreType StateStore]( + ctx context.Context, + getStore func() StoreType, + minIndex uint64, + suppressSpuriousWakeup bool, + query func(memdb.WatchSet, StoreType) (uint64, ResultType, error), +) (uint64, ResultType, error) { + var ( + notFound bool + ranOnce bool + priorHash uint64 + ) + + var zeroResult ResultType + if getStore == nil { + return 0, zeroResult, fmt.Errorf("no getStore function was provided to ServerLocalBlockingQuery") + } + if query == nil { + return 0, zeroResult, fmt.Errorf("no query function was provided to ServerLocalBlockingQuery") + } + + for { + state := getStore() + + ws := memdb.NewWatchSet() + + // Adding the AbandonCh to the WatchSet allows us to detect when + // a snapshot restore happens that would otherwise not modify anything + // within the individual state store. If we didn't do this then we + // could end up blocking indefinitely. 
+ ws.Add(state.AbandonCh()) + + index, result, err := query(ws, state) + + switch { + case errors.Is(err, ErrorNotFound): + // if minIndex is 0 then we should never block but we + // also should not propagate the error + if minIndex == 0 { + return index, result, nil + } + + // update the min index if the previous result was not found. This + // is an attempt to not return data unnecessarily when we end up + // watching the root of a memdb Radix tree because the data being + // watched doesn't exist yet. + if notFound { + minIndex = index + } + + notFound = true + case err != nil: + return index, result, err + } + + // when enabled we can prevent sending back data that hasn't changed. + if suppressSpuriousWakeup { + newHash, err := hashstructure_v2.Hash(result, hashstructure_v2.FormatV2, nil) + if err != nil { + return index, result, fmt.Errorf("error hashing data for spurious wakeup suppression: %w", err) + } + + // set minIndex to the returned index to prevent sending back identical data + if ranOnce && priorHash == newHash { + minIndex = index + } + ranOnce = true + priorHash = newHash + } + + // one final check if we should be considered unblocked and + // return the value. Some conditions in the switch above + // alter the minIndex and prevent this return if it would + // be desirable. One such case is when the actual data has + // not changed since the last round through the query and + // we would rather not do any further processing for unchanged + // data. This mostly protects against watches for data that + // doesn't exist from return the non-existant value constantly. + if index > minIndex { + return index, result, nil + } + + // Block until something changes. Because we have added the state + // stores AbandonCh to this watch set, a snapshot restore will + // cause things to unblock in addition to changes to the actual + // queried data. 
+ if err := ws.WatchCtx(ctx); err != nil { + // exit if the context was cancelled + return index, result, nil + } + + select { + case <-state.AbandonCh(): + return index, result, nil + default: + } + } +} + +// ServerLocalNotify will watch for changes in the State Store using the provided +// query function and invoke the notify callback whenever the results of that query +// function have changed. This function will return an error if parameter validations +// fail but otherwise the background go routine to process the notifications will +// be spawned and nil will be returned. Just like ServerLocalBlockingQuery this makes +// use of Go Generics and for the same reasons as outlined in the documentation for +// that function. +func ServerLocalNotify[ResultType any, StoreType StateStore]( + ctx context.Context, + correlationID string, + getStore func() StoreType, + query func(memdb.WatchSet, StoreType) (uint64, ResultType, error), + notify func(ctx context.Context, correlationID string, result ResultType, err error), +) error { + return serverLocalNotify( + ctx, + correlationID, + getStore, + query, + notify, + // Public callers should not need to know when the internal go routines are finished. + // Being able to provide a done function to the internal version of this function is + // to allow our tests to be more determinstic and to eliminate arbitrary sleeps. + noopDone, + // Public callers do not get to override the error backoff configuration. Internally + // we want to allow for this to enable our unit tests to run much more quickly. + defaultWaiter(), + ) +} + +// serverLocalNotify is the internal version of ServerLocalNotify. 
It takes +// two additional arguments of the waiter to use and a function to call +// when the notification go routine has finished +func serverLocalNotify[ResultType any, StoreType StateStore]( + ctx context.Context, + correlationID string, + getStore func() StoreType, + query func(memdb.WatchSet, StoreType) (uint64, ResultType, error), + notify func(ctx context.Context, correlationID string, result ResultType, err error), + done func(), + waiter *retry.Waiter, +) error { + if ctx == nil { + return errNilContext + } + + if getStore == nil { + return errNilGetStore + } + + if query == nil { + return errNilQuery + } + + if notify == nil { + return errNilNotify + } + + go serverLocalNotifyRoutine( + ctx, + correlationID, + getStore, + query, + notify, + done, + waiter, + ) + return nil +} + +// serverLocalNotifyRoutine is the function intended to be run within a new +// go routine to process the updates. It will not check to ensure callbacks +// are non-nil nor perform other parameter validation. It is assumed that +// the in-package caller of this method will have already done that. It also +// takes the backoff waiter in as an argument so that unit tests within this +// package can override the default values that the exported ServerLocalNotify +// function would have set up. +func serverLocalNotifyRoutine[ResultType any, StoreType StateStore]( + ctx context.Context, + correlationID string, + getStore func() StoreType, + query func(memdb.WatchSet, StoreType) (uint64, ResultType, error), + notify func(ctx context.Context, correlationID string, result ResultType, err error), + done func(), + waiter *retry.Waiter, +) { + defer done() + + var minIndex uint64 + + for { + // Check if the context has been cancelled. Do not issue + // more queries if it has been cancelled. 
+ if ctx.Err() != nil { + return + } + + // Perform the blocking query + index, result, err := ServerLocalBlockingQuery(ctx, getStore, minIndex, true, query) + + // Check if the context has been cancelled. If it has we should not send more + // notifications. + if ctx.Err() != nil { + return + } + + // Check the index to see if we should call notify + if minIndex == 0 || minIndex < index { + notify(ctx, correlationID, result, err) + minIndex = index + } + + // Handle errors with backoff. Badly behaved blocking calls that returned + // a zero index are considered as failures since we need to not get stuck + // in a busy loop. + if err == nil && index > 0 { + waiter.Reset() + } else { + if waiter.Wait(ctx) != nil { + return + } + } + + // ensure we don't use zero indexes + if err == nil && minIndex < 1 { + minIndex = 1 + } + } +} diff --git a/agent/consul/watch/server_local_test.go b/agent/consul/watch/server_local_test.go new file mode 100644 index 0000000000..6fa440979b --- /dev/null +++ b/agent/consul/watch/server_local_test.go @@ -0,0 +1,424 @@ +package watch + +import ( + "context" + "fmt" + "testing" + "time" + + "github.com/hashicorp/consul/lib/retry" + "github.com/hashicorp/go-memdb" + mock "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" +) + +type mockStoreProvider struct { + mock.Mock +} + +func newMockStoreProvider(t *testing.T) *mockStoreProvider { + t.Helper() + provider := &mockStoreProvider{} + t.Cleanup(func() { + provider.AssertExpectations(t) + }) + return provider +} + +func (m *mockStoreProvider) getStore() *MockStateStore { + return m.Called().Get(0).(*MockStateStore) +} + +type testResult struct { + value string +} + +func (m *mockStoreProvider) query(ws memdb.WatchSet, store *MockStateStore) (uint64, *testResult, error) { + ret := m.Called(ws, store) + + index := ret.Get(0).(uint64) + result := ret.Get(1).(*testResult) + err := ret.Error(2) + + return index, result, err +} + +func (m *mockStoreProvider) notify(ctx 
context.Context, correlationID string, result *testResult, err error) { + m.Called(ctx, correlationID, result, err) +} + +func TestServerLocalBlockingQuery_getStoreNotProvided(t *testing.T) { + _, _, err := ServerLocalBlockingQuery( + context.Background(), + nil, + 0, + true, + func(memdb.WatchSet, *MockStateStore) (uint64, struct{}, error) { + return 0, struct{}{}, nil + }, + ) + + require.Error(t, err) + require.Contains(t, err.Error(), "no getStore function was provided") +} + +func TestServerLocalBlockingQuery_queryNotProvided(t *testing.T) { + var query func(memdb.WatchSet, *MockStateStore) (uint64, struct{}, error) + _, _, err := ServerLocalBlockingQuery( + context.Background(), + func() *MockStateStore { return nil }, + 0, + true, + query, + ) + + require.Error(t, err) + require.Contains(t, err.Error(), "no query function was provided") +} + +func TestServerLocalBlockingQuery_NonBlocking(t *testing.T) { + abandonCh := make(chan struct{}) + t.Cleanup(func() { close(abandonCh) }) + + store := NewMockStateStore(t) + store.On("AbandonCh"). + Return(closeChan(abandonCh)). + Once() + + provider := newMockStoreProvider(t) + provider.On("getStore").Return(store).Once() + provider.On("query", mock.Anything, store). + Return(uint64(1), &testResult{value: "foo"}, nil). + Once() + + idx, result, err := ServerLocalBlockingQuery( + context.Background(), + provider.getStore, + 0, + true, + provider.query, + ) + require.NoError(t, err) + require.EqualValues(t, 1, idx) + require.Equal(t, &testResult{value: "foo"}, result) +} + +func TestServerLocalBlockingQuery_NotFound(t *testing.T) { + abandonCh := make(chan struct{}) + t.Cleanup(func() { close(abandonCh) }) + + store := NewMockStateStore(t) + store.On("AbandonCh"). + Return(closeChan(abandonCh)). + Once() + + provider := newMockStoreProvider(t) + provider.On("getStore"). + Return(store). + Once() + + var nilResult *testResult + provider.On("query", mock.Anything, store). + Return(uint64(1), nilResult, ErrorNotFound). 
+ Once() + + idx, result, err := ServerLocalBlockingQuery( + context.Background(), + provider.getStore, + 0, + true, + provider.query, + ) + require.NoError(t, err) + require.EqualValues(t, 1, idx) + require.Nil(t, result) +} + +func TestServerLocalBlockingQuery_NotFoundBlocks(t *testing.T) { + abandonCh := make(chan struct{}) + t.Cleanup(func() { close(abandonCh) }) + + store := NewMockStateStore(t) + store.On("AbandonCh"). + Return(closeChan(abandonCh)). + Times(5) + + provider := newMockStoreProvider(t) + provider.On("getStore"). + Return(store). + Times(3) + + var nilResult *testResult + // Initial data returned is not found and has an index less than the original + // blocking index. This should not return data to the caller. + provider.On("query", mock.Anything, store). + Return(uint64(4), nilResult, ErrorNotFound). + Run(addReadyWatchSet). + Once() + // There is an update to the data but the value still doesn't exist. Therefore + // we should not return data to the caller. + provider.On("query", mock.Anything, store). + Return(uint64(6), nilResult, ErrorNotFound). + Run(addReadyWatchSet). + Once() + // Finally we have some real data and can return it to the caller. + provider.On("query", mock.Anything, store). + Return(uint64(7), &testResult{value: "foo"}, nil). + Once() + + idx, result, err := ServerLocalBlockingQuery( + context.Background(), + provider.getStore, + 5, + true, + provider.query, + ) + require.NoError(t, err) + require.EqualValues(t, 7, idx) + require.Equal(t, &testResult{value: "foo"}, result) +} + +func TestServerLocalBlockingQuery_Error(t *testing.T) { + abandonCh := make(chan struct{}) + t.Cleanup(func() { close(abandonCh) }) + + store := NewMockStateStore(t) + store.On("AbandonCh"). + Return(closeChan(abandonCh)). + Once() + + provider := newMockStoreProvider(t) + provider.On("getStore"). + Return(store). + Once() + + var nilResult *testResult + provider.On("query", mock.Anything, store). 
+ Return(uint64(10), nilResult, fmt.Errorf("synthetic error")). + Once() + + idx, result, err := ServerLocalBlockingQuery( + context.Background(), + provider.getStore, + 4, + true, + provider.query, + ) + require.Error(t, err) + require.Contains(t, err.Error(), "synthetic error") + require.EqualValues(t, 10, idx) + require.Nil(t, result) +} + +func TestServerLocalBlockingQuery_ContextCancellation(t *testing.T) { + abandonCh := make(chan struct{}) + t.Cleanup(func() { close(abandonCh) }) + + ctx, cancel := context.WithCancel(context.Background()) + t.Cleanup(cancel) + + store := NewMockStateStore(t) + store.On("AbandonCh"). + Return(closeChan(abandonCh)). + Once() + + provider := newMockStoreProvider(t) + provider.On("getStore"). + Return(store). + Once() + provider.On("query", mock.Anything, store). + // Return an index that should not cause the blocking query to return. + Return(uint64(4), &testResult{value: "foo"}, nil). + Once(). + Run(func(_ mock.Arguments) { + // Cancel the context so that the memdb WatchCtx call will error. + cancel() + }) + + idx, result, err := ServerLocalBlockingQuery( + ctx, + provider.getStore, + 8, + true, + provider.query, + ) + // The internal cancellation error should not be propagated. + require.NoError(t, err) + require.EqualValues(t, 4, idx) + require.Equal(t, &testResult{value: "foo"}, result) +} + +func TestServerLocalBlockingQuery_StateAbandoned(t *testing.T) { + abandonCh := make(chan struct{}) + + store := NewMockStateStore(t) + store.On("AbandonCh"). + Return(closeChan(abandonCh)). + Twice() + + provider := newMockStoreProvider(t) + provider.On("getStore"). + Return(store). + Once() + provider.On("query", mock.Anything, store). + // Return an index that should not cause the blocking query to return. + Return(uint64(4), &testResult{value: "foo"}, nil). + Once(). + Run(func(_ mock.Arguments) { + // Cancel the context so that the memdb WatchCtx call will error. 
+ close(abandonCh) + }) + + idx, result, err := ServerLocalBlockingQuery( + context.Background(), + provider.getStore, + 8, + true, + provider.query, + ) + // The internal cancellation error should not be propagated. + require.NoError(t, err) + require.EqualValues(t, 4, idx) + require.Equal(t, &testResult{value: "foo"}, result) +} + +func TestServerLocalNotify_Validations(t *testing.T) { + provider := newMockStoreProvider(t) + + type testCase struct { + ctx context.Context + getStore func() *MockStateStore + query func(memdb.WatchSet, *MockStateStore) (uint64, *testResult, error) + notify func(context.Context, string, *testResult, error) + err error + } + + cases := map[string]testCase{ + "nil-context": { + getStore: provider.getStore, + query: provider.query, + notify: provider.notify, + err: errNilContext, + }, + "nil-getStore": { + ctx: context.Background(), + query: provider.query, + notify: provider.notify, + err: errNilGetStore, + }, + "nil-query": { + ctx: context.Background(), + getStore: provider.getStore, + notify: provider.notify, + err: errNilQuery, + }, + "nil-notify": { + ctx: context.Background(), + getStore: provider.getStore, + query: provider.query, + err: errNilNotify, + }, + } + + for name, tcase := range cases { + t.Run(name, func(t *testing.T) { + err := ServerLocalNotify(tcase.ctx, "test", tcase.getStore, tcase.query, tcase.notify) + require.ErrorIs(t, err, tcase.err) + }) + } +} + +func TestServerLocalNotify(t *testing.T) { + notifyCtx, notifyCancel := context.WithCancel(context.Background()) + t.Cleanup(notifyCancel) + + abandonCh := make(chan struct{}) + + store := NewMockStateStore(t) + store.On("AbandonCh"). + Return(closeChan(abandonCh)). + Times(3) + + provider := newMockStoreProvider(t) + provider.On("getStore"). + Return(store). + Times(3) + provider.On("query", mock.Anything, store). + Return(uint64(4), &testResult{value: "foo"}, nil). 
+ Once() + provider.On("notify", notifyCtx, t.Name(), &testResult{value: "foo"}, nil).Once() + provider.On("query", mock.Anything, store). + Return(uint64(6), &testResult{value: "bar"}, nil). + Once() + provider.On("notify", notifyCtx, t.Name(), &testResult{value: "bar"}, nil).Once() + provider.On("query", mock.Anything, store). + Return(uint64(7), &testResult{value: "baz"}, context.Canceled). + Run(func(mock.Arguments) { + notifyCancel() + }) + + doneCtx, routineDone := context.WithCancel(context.Background()) + err := serverLocalNotify(notifyCtx, t.Name(), provider.getStore, provider.query, provider.notify, routineDone, defaultWaiter()) + require.NoError(t, err) + + // Wait for the context cancellation which will happen when the "query" func is run the third time. The doneCtx gets "cancelled" + // by the backgrounded go routine when it is actually finished. We need to wait for this to ensure that all mocked calls have been + // made and that no extra calls get made. + <-doneCtx.Done() +} + +func TestServerLocalNotify_internal(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + t.Cleanup(cancel) + + abandonCh := make(chan struct{}) + + store := NewMockStateStore(t) + store.On("AbandonCh"). + Return(closeChan(abandonCh)). + Times(4) + + var nilResult *testResult + + provider := newMockStoreProvider(t) + provider.On("getStore"). + Return(store). + Times(4) + provider.On("query", mock.Anything, store). + Return(uint64(0), nilResult, fmt.Errorf("injected error")). + Times(3) + provider.On("notify", ctx, "test", nilResult, fmt.Errorf("injected error")). + Times(3) + provider.On("query", mock.Anything, store). + Return(uint64(7), &testResult{value: "foo"}, nil). + Once() + provider.On("notify", ctx, "test", &testResult{value: "foo"}, nil). + Once(). 
+ Run(func(mock.Arguments) { + cancel() + }) + waiter := retry.Waiter{ + MinFailures: 1, + MinWait: time.Millisecond, + MaxWait: 50 * time.Millisecond, + Jitter: retry.NewJitter(100), + Factor: 2 * time.Millisecond, + } + + // all the mock expectations should ensure things are working properly + serverLocalNotifyRoutine(ctx, "test", provider.getStore, provider.query, provider.notify, noopDone, &waiter) +} + +func addReadyWatchSet(args mock.Arguments) { + ws := args.Get(0).(memdb.WatchSet) + ch := make(chan struct{}) + ws.Add(ch) + close(ch) +} + +// small convenience to make this more readable. The alternative in a few +// cases would be to do something like (<-chan struct{})(ch). I find that +// syntax very difficult to read. +func closeChan(ch chan struct{}) <-chan struct{} { + return ch +}