diff --git a/agent/consul/controller/controller.go b/agent/consul/controller/controller.go index 03a95e2122..22c9ba8272 100644 --- a/agent/consul/controller/controller.go +++ b/agent/consul/controller/controller.go @@ -4,12 +4,14 @@ import ( "context" "errors" "fmt" + "sync" "sync/atomic" "time" "github.com/hashicorp/consul/agent/consul/state" "github.com/hashicorp/consul/agent/consul/stream" "github.com/hashicorp/consul/agent/structs" + "github.com/hashicorp/go-hclog" "golang.org/x/sync/errgroup" ) @@ -39,6 +41,9 @@ type Controller interface { // Request retry rate limiter. This should only ever be called prior to // running Run. WithBackoff(base, max time.Duration) Controller + // WithLogger sets the logger for the controller, it should be called prior to Start + // being invoked. + WithLogger(logger hclog.Logger) Controller // WithWorkers sets the number of worker goroutines used to process the queue // this defaults to 1 goroutine. WithWorkers(i int) Controller @@ -46,6 +51,12 @@ type Controller interface { // implementation. This is most useful for testing. This should only ever be called // prior to running Run. WithQueueFactory(fn func(ctx context.Context, baseBackoff time.Duration, maxBackoff time.Duration) WorkQueue) Controller + // AddTrigger allows for triggering a reconciliation request when a + // triggering function returns, when the passed in context is canceled + // the trigger must return + AddTrigger(request Request, trigger func(ctx context.Context) error) + // RemoveTrigger removes the triggering function associated with the Request object + RemoveTrigger(request Request) } var _ Controller = &controller{} @@ -78,8 +89,27 @@ type controller struct { // publisher is the event publisher that should be subscribed to for any updates publisher state.EventPublisher + // waitOnce ensures we wait until the controller has started + waitOnce sync.Once + // started signals when the controller has started + started chan struct{} + + // group is the error group used in our main start up worker routines + group *errgroup.Group + // groupCtx is the context of the error group to use in spinning up our + // worker routines + groupCtx context.Context + + // triggers is a map of cancel functions for out-of-band Request triggers + triggers map[Request]func() + // triggerMutex is used for accessing the above map + triggerMutex sync.Mutex + // running ensures that we are only calling Run a single time running int32 + + // logger is the logger for the controller + logger hclog.Logger } // New returns a new Controller associated with the given state store and reconciler. @@ -91,6 +121,9 @@ func New(publisher state.EventPublisher, reconciler Reconciler) Controller { baseBackoff: 5 * time.Millisecond, maxBackoff: 1000 * time.Second, makeQueue: RunWorkQueue, + started: make(chan struct{}), + triggers: make(map[Request]func()), + logger: hclog.NewNullLogger(), } } @@ -130,6 +163,14 @@ func (c *controller) WithWorkers(i int) Controller { return c } +// WithLogger sets the internal logger for the controller. +func (c *controller) WithLogger(logger hclog.Logger) Controller { + c.ensureNotRunning() + + c.logger = logger + return c +} + // WithQueueFactory changes the initialization method for the Controller's work // queue, this is predominantly just used for testing. This should only ever be called // prior to running Start. @@ -157,15 +198,18 @@ func (c *controller) Run(ctx context.Context) error { panic("Run cannot be called more than once") } - group, groupCtx := errgroup.WithContext(ctx) + c.group, c.groupCtx = errgroup.WithContext(ctx) // set up our queue - c.work = c.makeQueue(groupCtx, c.baseBackoff, c.maxBackoff) + c.work = c.makeQueue(c.groupCtx, c.baseBackoff, c.maxBackoff) + + // we can now add stuff to the queue from other contexts + close(c.started) for _, sub := range c.subscriptions { // store a reference for the closure sub := sub - group.Go(func() error { + c.group.Go(func() error { var index uint64 subscription, err := c.publisher.Subscribe(sub.request) @@ -201,14 +245,14 @@ func (c *controller) Run(ctx context.Context) error { } for i := 0; i < c.workers; i++ { - group.Go(func() error { + c.group.Go(func() error { for { request, shutdown := c.work.Get() if shutdown { // Stop working return nil } - c.reconcileHandler(groupCtx, request) + c.reconcileHandler(c.groupCtx, request) // Done is called here because it is required to be called // when we've finished processing each request c.work.Done(request) @@ -216,10 +260,57 @@ func (c *controller) Run(ctx context.Context) error { }) } - <-groupCtx.Done() + <-c.groupCtx.Done() return nil } +// AddTrigger allows for triggering a reconciliation request every time that the +// triggering function returns, when the passed in context is canceled +// the trigger must return +func (c *controller) AddTrigger(request Request, trigger func(ctx context.Context) error) { + c.wait() + + ctx, cancel := context.WithCancel(c.groupCtx) + + c.triggerMutex.Lock() + oldCancel, ok := c.triggers[request] + if ok { + oldCancel() + } + c.triggers[request] = cancel + c.triggerMutex.Unlock() + + c.group.Go(func() error { + if err := trigger(ctx); err != nil { + c.logger.Error("error while running trigger, adding re-reconcilation anyway", "error", err) + } + select { + case <-ctx.Done(): + return nil + default: + c.work.Add(request) + return nil + } + }) +} + +// RemoveTrigger removes the triggering function associated with the Request object +func (c *controller) RemoveTrigger(request Request) { + c.triggerMutex.Lock() + cancel, ok := c.triggers[request] + if ok { + cancel() + delete(c.triggers, request) + } + c.triggerMutex.Unlock() +} + +func (c *controller) wait() { + c.waitOnce.Do(func() { + <-c.started + }) +} + func (c *controller) processEvent(sub subscription, event stream.Event) error { switch payload := event.Payload.(type) { case state.EventPayloadConfigEntry: diff --git a/agent/consul/controller/controller_test.go b/agent/consul/controller/controller_test.go index fc270a6564..ea42ed87cd 100644 --- a/agent/consul/controller/controller_test.go +++ b/agent/consul/controller/controller_test.go @@ -12,6 +12,7 @@ import ( "github.com/hashicorp/consul/agent/consul/stream" "github.com/hashicorp/consul/agent/structs" "github.com/hashicorp/go-hclog" + "github.com/hashicorp/go-memdb" "github.com/stretchr/testify/require" ) @@ -415,3 +416,148 @@ func TestConfigEntrySubscriptions(t *testing.T) { }) } } + +func TestBasicController_Triggers(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + reconciler := newTestReconciler(true) + + publisher := stream.NewEventPublisher(0) + go publisher.Run(ctx) + + controller := New(publisher, reconciler) + + go func() { + require.NoError(t, controller.Run(ctx)) + }() + + ensureCalled := func(request chan Request, name string) bool { + select { + case req := <-request: + require.Equal(t, structs.IngressGateway, req.Kind) + require.Equal(t, name, req.Name) + return true + case <-time.After(10 * time.Millisecond): + return false + } + } + + request := Request{ + Kind: structs.IngressGateway, + Name: "foo-1", + } + + triggerOneChan := make(chan struct{}, 3) + triggerOne := func(ctx context.Context) error { + select { + case <-triggerOneChan: + return nil + case <-ctx.Done(): + return nil + } + } + controller.AddTrigger(request, triggerOne) + require.False(t, ensureCalled(reconciler.received, "foo-1")) + triggerOneChan <- struct{}{} + reconciler.stepFor(10 * time.Millisecond) + require.True(t, ensureCalled(reconciler.received, "foo-1")) + + // do it again + require.False(t, ensureCalled(reconciler.received, "foo-1")) + controller.AddTrigger(request, triggerOne) + triggerOneChan <- struct{}{} + reconciler.stepFor(10 * time.Millisecond) + require.True(t, ensureCalled(reconciler.received, "foo-1")) + + // check with the overwritten trigger + controller.AddTrigger(request, triggerOne) + triggerTwoChan := make(chan struct{}, 2) + triggerTwo := func(ctx context.Context) error { + select { + case <-triggerTwoChan: + return nil + case <-ctx.Done(): + return nil + } + } + controller.AddTrigger(request, triggerTwo) + triggerOneChan <- struct{}{} + reconciler.stepFor(10 * time.Millisecond) + require.False(t, ensureCalled(reconciler.received, "foo-1")) + triggerTwoChan <- struct{}{} + reconciler.stepFor(10 * time.Millisecond) + require.True(t, ensureCalled(reconciler.received, "foo-1")) + + // remove the trigger and make sure we're not called again + controller.RemoveTrigger(request) + triggerTwoChan <- struct{}{} + reconciler.stepFor(10 * time.Millisecond) + require.False(t, ensureCalled(reconciler.received, "foo-1")) +} + +func TestDiscoveryChainController(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + + reconciler := newTestReconciler(false) + + publisher := stream.NewEventPublisher(1 * time.Millisecond) + go publisher.Run(ctx) + + // get the store through the FSM since the publisher handlers get registered through it + store := fsm.NewFromDeps(fsm.Deps{ + Logger: hclog.New(nil), + NewStateStore: func() *state.Store { + return state.NewStateStoreWithEventPublisher(nil, publisher) + }, + Publisher: publisher, + }).State() + + controller := New(publisher, reconciler) + go controller.Subscribe(&stream.SubscribeRequest{ + Topic: state.EventTopicIngressGateway, + Subject: stream.SubjectWildcard, + }).WithWorkers(10).Run(ctx) + + request := Request{ + Kind: structs.IngressGateway, + Name: "foo-1", + } + + ensureCalled := func(request chan Request, name string) bool { + select { + case req := <-request: + require.Equal(t, structs.IngressGateway, req.Kind) + require.Equal(t, name, req.Name) + return true + case <-time.After(10 * time.Millisecond): + return false + } + } + + require.NoError(t, store.EnsureConfigEntry(1, &structs.IngressGatewayConfigEntry{ + Kind: structs.IngressGateway, + Name: "foo-1", + })) + require.True(t, ensureCalled(reconciler.received, "foo-1")) + + // create the trigger and something that changes in its upstream discovery chain and ensure that we've + // fired the reconciler + ws := memdb.NewWatchSet() + ws.Add(store.AbandonCh()) + _, _, err := store.ReadDiscoveryChainConfigEntries(ws, "foo-2", nil) + require.NoError(t, err) + controller.AddTrigger(request, ws.WatchCtx) + + require.False(t, ensureCalled(reconciler.received, "foo-1")) + require.NoError(t, store.EnsureConfigEntry(1, &structs.ServiceResolverConfigEntry{ + Kind: structs.ServiceResolver, + Name: "foo-2", + })) + require.True(t, ensureCalled(reconciler.received, "foo-1")) +} diff --git a/agent/consul/controller/reconciler_test.go b/agent/consul/controller/reconciler_test.go index d2f0567533..5b08494ae6 100644 --- a/agent/consul/controller/reconciler_test.go +++ b/agent/consul/controller/reconciler_test.go @@ -3,6 +3,7 @@ package controller import ( "context" "sync" + "time" ) type testReconciler struct { @@ -43,6 +44,12 @@ func (r *testReconciler) setResponse(err error) { func (r *testReconciler) step() { r.stepChan <- struct{}{} } +func (r *testReconciler) stepFor(duration time.Duration) { + select { + case r.stepChan <- struct{}{}: + case <-time.After(duration): + } +} func (r *testReconciler) stop() { close(r.stopChan)