Add trigger for doing reconciliation based on watch sets (#16052)

* Add trigger for doing reconciliation based on watch sets

* update doc string

* Fix my grammar fail
This commit is contained in:
Andrew Stucki 2023-01-26 15:20:37 -05:00 committed by GitHub
parent 44c608706b
commit 3febdbff39
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 250 additions and 6 deletions

View File

@ -4,12 +4,14 @@ import (
"context"
"errors"
"fmt"
"sync"
"sync/atomic"
"time"
"github.com/hashicorp/consul/agent/consul/state"
"github.com/hashicorp/consul/agent/consul/stream"
"github.com/hashicorp/consul/agent/structs"
"github.com/hashicorp/go-hclog"
"golang.org/x/sync/errgroup"
)
@ -39,6 +41,9 @@ type Controller interface {
// Request retry rate limiter. This should only ever be called prior to
// running Run.
WithBackoff(base, max time.Duration) Controller
// WithLogger sets the logger for the controller, it should be called prior to Start
// being invoked.
WithLogger(logger hclog.Logger) Controller
// WithWorkers sets the number of worker goroutines used to process the queue
// this defaults to 1 goroutine.
WithWorkers(i int) Controller
@ -46,6 +51,12 @@ type Controller interface {
// implementation. This is most useful for testing. This should only ever be called
// prior to running Run.
WithQueueFactory(fn func(ctx context.Context, baseBackoff time.Duration, maxBackoff time.Duration) WorkQueue) Controller
// AddTrigger allows for triggering a reconciliation request when a
// triggering function returns, when the passed in context is canceled
// the trigger must return
AddTrigger(request Request, trigger func(ctx context.Context) error)
// RemoveTrigger removes the triggering function associated with the Request object
RemoveTrigger(request Request)
}
var _ Controller = &controller{}
@ -78,8 +89,27 @@ type controller struct {
// publisher is the event publisher that should be subscribed to for any updates
publisher state.EventPublisher
// waitOnce ensures we wait until the controller has started
waitOnce sync.Once
// started signals when the controller has started
started chan struct{}
// group is the error group used in our main start up worker routines
group *errgroup.Group
// groupCtx is the context of the error group to use in spinning up our
// worker routines
groupCtx context.Context
// triggers is a map of cancel functions for out-of-band Request triggers
triggers map[Request]func()
// triggerMutex is used for accessing the above map
triggerMutex sync.Mutex
// running ensures that we are only calling Run a single time
running int32
// logger is the logger for the controller
logger hclog.Logger
}
// New returns a new Controller associated with the given state store and reconciler.
@ -91,6 +121,9 @@ func New(publisher state.EventPublisher, reconciler Reconciler) Controller {
baseBackoff: 5 * time.Millisecond,
maxBackoff: 1000 * time.Second,
makeQueue: RunWorkQueue,
started: make(chan struct{}),
triggers: make(map[Request]func()),
logger: hclog.NewNullLogger(),
}
}
@ -130,6 +163,14 @@ func (c *controller) WithWorkers(i int) Controller {
return c
}
// WithLogger sets the internal logger for the controller.
func (c *controller) WithLogger(logger hclog.Logger) Controller {
c.ensureNotRunning()
c.logger = logger
return c
}
// WithQueueFactory changes the initialization method for the Controller's work
// queue, this is predominantly just used for testing. This should only ever be called
// prior to running Start.
@ -157,15 +198,18 @@ func (c *controller) Run(ctx context.Context) error {
panic("Run cannot be called more than once")
}
group, groupCtx := errgroup.WithContext(ctx)
c.group, c.groupCtx = errgroup.WithContext(ctx)
// set up our queue
c.work = c.makeQueue(groupCtx, c.baseBackoff, c.maxBackoff)
c.work = c.makeQueue(c.groupCtx, c.baseBackoff, c.maxBackoff)
// we can now add stuff to the queue from other contexts
close(c.started)
for _, sub := range c.subscriptions {
// store a reference for the closure
sub := sub
group.Go(func() error {
c.group.Go(func() error {
var index uint64
subscription, err := c.publisher.Subscribe(sub.request)
@ -201,14 +245,14 @@ func (c *controller) Run(ctx context.Context) error {
}
for i := 0; i < c.workers; i++ {
group.Go(func() error {
c.group.Go(func() error {
for {
request, shutdown := c.work.Get()
if shutdown {
// Stop working
return nil
}
c.reconcileHandler(groupCtx, request)
c.reconcileHandler(c.groupCtx, request)
// Done is called here because it is required to be called
// when we've finished processing each request
c.work.Done(request)
@ -216,10 +260,57 @@ func (c *controller) Run(ctx context.Context) error {
})
}
<-groupCtx.Done()
<-c.groupCtx.Done()
return nil
}
// AddTrigger allows for triggering a reconciliation request every time that the
// triggering function returns, when the passed in context is canceled
// the trigger must return
func (c *controller) AddTrigger(request Request, trigger func(ctx context.Context) error) {
c.wait()
ctx, cancel := context.WithCancel(c.groupCtx)
c.triggerMutex.Lock()
oldCancel, ok := c.triggers[request]
if ok {
oldCancel()
}
c.triggers[request] = cancel
c.triggerMutex.Unlock()
c.group.Go(func() error {
if err := trigger(ctx); err != nil {
c.logger.Error("error while running trigger, adding re-reconcilation anyway", "error", err)
}
select {
case <-ctx.Done():
return nil
default:
c.work.Add(request)
return nil
}
})
}
// RemoveTrigger removes the triggering function associated with the Request object
func (c *controller) RemoveTrigger(request Request) {
c.triggerMutex.Lock()
cancel, ok := c.triggers[request]
if ok {
cancel()
delete(c.triggers, request)
}
c.triggerMutex.Unlock()
}
func (c *controller) wait() {
c.waitOnce.Do(func() {
<-c.started
})
}
func (c *controller) processEvent(sub subscription, event stream.Event) error {
switch payload := event.Payload.(type) {
case state.EventPayloadConfigEntry:

View File

@ -12,6 +12,7 @@ import (
"github.com/hashicorp/consul/agent/consul/stream"
"github.com/hashicorp/consul/agent/structs"
"github.com/hashicorp/go-hclog"
"github.com/hashicorp/go-memdb"
"github.com/stretchr/testify/require"
)
@ -415,3 +416,148 @@ func TestConfigEntrySubscriptions(t *testing.T) {
})
}
}
func TestBasicController_Triggers(t *testing.T) {
t.Parallel()
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
reconciler := newTestReconciler(true)
publisher := stream.NewEventPublisher(0)
go publisher.Run(ctx)
controller := New(publisher, reconciler)
go func() {
require.NoError(t, controller.Run(ctx))
}()
ensureCalled := func(request chan Request, name string) bool {
select {
case req := <-request:
require.Equal(t, structs.IngressGateway, req.Kind)
require.Equal(t, name, req.Name)
return true
case <-time.After(10 * time.Millisecond):
return false
}
}
request := Request{
Kind: structs.IngressGateway,
Name: "foo-1",
}
triggerOneChan := make(chan struct{}, 3)
triggerOne := func(ctx context.Context) error {
select {
case <-triggerOneChan:
return nil
case <-ctx.Done():
return nil
}
}
controller.AddTrigger(request, triggerOne)
require.False(t, ensureCalled(reconciler.received, "foo-1"))
triggerOneChan <- struct{}{}
reconciler.stepFor(10 * time.Millisecond)
require.True(t, ensureCalled(reconciler.received, "foo-1"))
// do it again
require.False(t, ensureCalled(reconciler.received, "foo-1"))
controller.AddTrigger(request, triggerOne)
triggerOneChan <- struct{}{}
reconciler.stepFor(10 * time.Millisecond)
require.True(t, ensureCalled(reconciler.received, "foo-1"))
// check with the overwritten trigger
controller.AddTrigger(request, triggerOne)
triggerTwoChan := make(chan struct{}, 2)
triggerTwo := func(ctx context.Context) error {
select {
case <-triggerTwoChan:
return nil
case <-ctx.Done():
return nil
}
}
controller.AddTrigger(request, triggerTwo)
triggerOneChan <- struct{}{}
reconciler.stepFor(10 * time.Millisecond)
require.False(t, ensureCalled(reconciler.received, "foo-1"))
triggerTwoChan <- struct{}{}
reconciler.stepFor(10 * time.Millisecond)
require.True(t, ensureCalled(reconciler.received, "foo-1"))
// remove the trigger and make sure we're not called again
controller.RemoveTrigger(request)
triggerTwoChan <- struct{}{}
reconciler.stepFor(10 * time.Millisecond)
require.False(t, ensureCalled(reconciler.received, "foo-1"))
}
func TestDiscoveryChainController(t *testing.T) {
t.Parallel()
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()
reconciler := newTestReconciler(false)
publisher := stream.NewEventPublisher(1 * time.Millisecond)
go publisher.Run(ctx)
// get the store through the FSM since the publisher handlers get registered through it
store := fsm.NewFromDeps(fsm.Deps{
Logger: hclog.New(nil),
NewStateStore: func() *state.Store {
return state.NewStateStoreWithEventPublisher(nil, publisher)
},
Publisher: publisher,
}).State()
controller := New(publisher, reconciler)
go controller.Subscribe(&stream.SubscribeRequest{
Topic: state.EventTopicIngressGateway,
Subject: stream.SubjectWildcard,
}).WithWorkers(10).Run(ctx)
request := Request{
Kind: structs.IngressGateway,
Name: "foo-1",
}
ensureCalled := func(request chan Request, name string) bool {
select {
case req := <-request:
require.Equal(t, structs.IngressGateway, req.Kind)
require.Equal(t, name, req.Name)
return true
case <-time.After(10 * time.Millisecond):
return false
}
}
require.NoError(t, store.EnsureConfigEntry(1, &structs.IngressGatewayConfigEntry{
Kind: structs.IngressGateway,
Name: "foo-1",
}))
require.True(t, ensureCalled(reconciler.received, "foo-1"))
// create the trigger and something that changes in its upstream discovery chain and ensure that we've
// fired the reconciler
ws := memdb.NewWatchSet()
ws.Add(store.AbandonCh())
_, _, err := store.ReadDiscoveryChainConfigEntries(ws, "foo-2", nil)
require.NoError(t, err)
controller.AddTrigger(request, ws.WatchCtx)
require.False(t, ensureCalled(reconciler.received, "foo-1"))
require.NoError(t, store.EnsureConfigEntry(1, &structs.ServiceResolverConfigEntry{
Kind: structs.ServiceResolver,
Name: "foo-2",
}))
require.True(t, ensureCalled(reconciler.received, "foo-1"))
}

View File

@ -3,6 +3,7 @@ package controller
import (
"context"
"sync"
"time"
)
type testReconciler struct {
@ -43,6 +44,12 @@ func (r *testReconciler) setResponse(err error) {
func (r *testReconciler) step() {
r.stepChan <- struct{}{}
}
func (r *testReconciler) stepFor(duration time.Duration) {
select {
case r.stepChan <- struct{}{}:
case <-time.After(duration):
}
}
func (r *testReconciler) stop() {
close(r.stopChan)