Migrate the Endpoints controller to use the controller cache (#20241)

* Add cache resource decoding helpers

* Implement a common package for workload selection facilities. This includes:

   * Controller cache Index
   * ACL hooks
   * Dependency Mapper to go from workload to list of resources which select it
   * Dependency Mapper to go from a resource which selects workloads to all the workloads it selects.

* Update the endpoints controller to use the cache instead of custom mappers.

Co-authored-by: R.B. Boyer <4903+rboyer@users.noreply.github.com>
This commit is contained in:
Matt Keeler 2024-01-18 17:52:52 -05:00 committed by GitHub
parent d641998641
commit 59cb12c798
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
34 changed files with 1778 additions and 757 deletions

View File

@ -225,7 +225,7 @@ func (m *MockAuthorizer) ServiceReadAll(ctx *AuthorizerContext) EnforcementDecis
} }
func (m *MockAuthorizer) ServiceReadPrefix(prefix string, ctx *AuthorizerContext) EnforcementDecision { func (m *MockAuthorizer) ServiceReadPrefix(prefix string, ctx *AuthorizerContext) EnforcementDecision {
ret := m.Called(ctx) ret := m.Called(prefix, ctx)
return ret.Get(0).(EnforcementDecision) return ret.Get(0).(EnforcementDecision)
} }

View File

@ -6,14 +6,14 @@ package helpers
import ( import (
"testing" "testing"
"github.com/hashicorp/consul/internal/catalog"
"github.com/hashicorp/consul/internal/catalog/internal/testhelpers" "github.com/hashicorp/consul/internal/catalog/internal/testhelpers"
"github.com/hashicorp/consul/internal/catalog/workloadselector"
"github.com/hashicorp/consul/internal/resource" "github.com/hashicorp/consul/internal/resource"
pbcatalog "github.com/hashicorp/consul/proto-public/pbcatalog/v2beta1" pbcatalog "github.com/hashicorp/consul/proto-public/pbcatalog/v2beta1"
"github.com/hashicorp/consul/proto-public/pbresource" "github.com/hashicorp/consul/proto-public/pbresource"
) )
func RunWorkloadSelectingTypeACLsTests[T catalog.WorkloadSelecting](t *testing.T, typ *pbresource.Type, func RunWorkloadSelectingTypeACLsTests[T workloadselector.WorkloadSelecting](t *testing.T, typ *pbresource.Type,
getData func(selector *pbcatalog.WorkloadSelector) T, getData func(selector *pbcatalog.WorkloadSelector) T,
registerFunc func(registry resource.Registry), registerFunc func(registry resource.Registry),
) { ) {

View File

@ -115,10 +115,10 @@ func VerifyCatalogV2Beta1IntegrationTestResults(t *testing.T, client pbresource.
}) })
testutil.RunStep(t, "service-reconciliation", func(t *testing.T) { testutil.RunStep(t, "service-reconciliation", func(t *testing.T) {
c.WaitForStatusCondition(t, rtest.Resource(pbcatalog.ServiceType, "foo").ID(), endpoints.StatusKey, endpoints.ConditionUnmanaged) c.WaitForStatusCondition(t, rtest.Resource(pbcatalog.ServiceType, "foo").ID(), endpoints.ControllerID, endpoints.ConditionUnmanaged)
c.WaitForStatusCondition(t, rtest.Resource(pbcatalog.ServiceType, "api").ID(), endpoints.StatusKey, endpoints.ConditionManaged) c.WaitForStatusCondition(t, rtest.Resource(pbcatalog.ServiceType, "api").ID(), endpoints.ControllerID, endpoints.ConditionManaged)
c.WaitForStatusCondition(t, rtest.Resource(pbcatalog.ServiceType, "http-api").ID(), endpoints.StatusKey, endpoints.ConditionManaged) c.WaitForStatusCondition(t, rtest.Resource(pbcatalog.ServiceType, "http-api").ID(), endpoints.ControllerID, endpoints.ConditionManaged)
c.WaitForStatusCondition(t, rtest.Resource(pbcatalog.ServiceType, "grpc-api").ID(), endpoints.StatusKey, endpoints.ConditionManaged) c.WaitForStatusCondition(t, rtest.Resource(pbcatalog.ServiceType, "grpc-api").ID(), endpoints.ControllerID, endpoints.ConditionManaged)
}) })
testutil.RunStep(t, "service-endpoints-generation", func(t *testing.T) { testutil.RunStep(t, "service-endpoints-generation", func(t *testing.T) {

View File

@ -13,7 +13,6 @@ import (
"github.com/hashicorp/consul/internal/catalog/internal/types" "github.com/hashicorp/consul/internal/catalog/internal/types"
"github.com/hashicorp/consul/internal/controller" "github.com/hashicorp/consul/internal/controller"
"github.com/hashicorp/consul/internal/resource" "github.com/hashicorp/consul/internal/resource"
"github.com/hashicorp/consul/internal/resource/mappers/selectiontracker"
pbcatalog "github.com/hashicorp/consul/proto-public/pbcatalog/v2beta1" pbcatalog "github.com/hashicorp/consul/proto-public/pbcatalog/v2beta1"
"github.com/hashicorp/consul/proto-public/pbresource" "github.com/hashicorp/consul/proto-public/pbresource"
) )
@ -29,7 +28,7 @@ var (
WorkloadHealthConditions = workloadhealth.WorkloadConditions WorkloadHealthConditions = workloadhealth.WorkloadConditions
WorkloadAndNodeHealthConditions = workloadhealth.NodeAndWorkloadConditions WorkloadAndNodeHealthConditions = workloadhealth.NodeAndWorkloadConditions
EndpointsStatusKey = endpoints.StatusKey EndpointsStatusKey = endpoints.ControllerID
EndpointsStatusConditionEndpointsManaged = endpoints.StatusConditionEndpointsManaged EndpointsStatusConditionEndpointsManaged = endpoints.StatusConditionEndpointsManaged
EndpointsStatusConditionManaged = endpoints.ConditionManaged EndpointsStatusConditionManaged = endpoints.ConditionManaged
EndpointsStatusConditionUnmanaged = endpoints.ConditionUnmanaged EndpointsStatusConditionUnmanaged = endpoints.ConditionUnmanaged
@ -47,12 +46,6 @@ var (
FailoverStatusConditionAcceptedUsingMeshDestinationPortReason = failover.UsingMeshDestinationPortReason FailoverStatusConditionAcceptedUsingMeshDestinationPortReason = failover.UsingMeshDestinationPortReason
) )
type WorkloadSelecting = types.WorkloadSelecting
func ACLHooksForWorkloadSelectingType[T WorkloadSelecting]() *resource.ACLHooks {
return types.ACLHooksForWorkloadSelectingType[T]()
}
// RegisterTypes adds all resource types within the "catalog" API group // RegisterTypes adds all resource types within the "catalog" API group
// to the given type registry // to the given type registry
func RegisterTypes(r resource.Registry) { func RegisterTypes(r resource.Registry) {
@ -63,8 +56,7 @@ type ControllerDependencies = controllers.Dependencies
func DefaultControllerDependencies() ControllerDependencies { func DefaultControllerDependencies() ControllerDependencies {
return ControllerDependencies{ return ControllerDependencies{
EndpointsWorkloadMapper: selectiontracker.New(), FailoverMapper: failovermapper.New(),
FailoverMapper: failovermapper.New(),
} }
} }

View File

@ -11,7 +11,9 @@ import (
"google.golang.org/protobuf/types/known/anypb" "google.golang.org/protobuf/types/known/anypb"
"github.com/hashicorp/consul/internal/catalog/internal/controllers/workloadhealth" "github.com/hashicorp/consul/internal/catalog/internal/controllers/workloadhealth"
"github.com/hashicorp/consul/internal/catalog/workloadselector"
"github.com/hashicorp/consul/internal/controller" "github.com/hashicorp/consul/internal/controller"
"github.com/hashicorp/consul/internal/controller/cache"
"github.com/hashicorp/consul/internal/controller/dependency" "github.com/hashicorp/consul/internal/controller/dependency"
"github.com/hashicorp/consul/internal/resource" "github.com/hashicorp/consul/internal/resource"
pbcatalog "github.com/hashicorp/consul/proto-public/pbcatalog/v2beta1" pbcatalog "github.com/hashicorp/consul/proto-public/pbcatalog/v2beta1"
@ -20,51 +22,52 @@ import (
const ( const (
endpointsMetaManagedBy = "managed-by-controller" endpointsMetaManagedBy = "managed-by-controller"
selectedWorkloadsIndexName = "selected-workloads"
) )
// The WorkloadMapper interface is used to provide an implementation around being able type (
// to map a watch even for a Workload resource and translate it to reconciliation requests DecodedWorkload = resource.DecodedResource[*pbcatalog.Workload]
type WorkloadMapper interface { DecodedService = resource.DecodedResource[*pbcatalog.Service]
// MapWorkload conforms to the controller.DependencyMapper signature. Given a Workload DecodedServiceEndpoints = resource.DecodedResource[*pbcatalog.ServiceEndpoints]
// resource it should report the resource IDs that have selected the workload. )
MapWorkload(context.Context, controller.Runtime, *pbresource.Resource) ([]controller.Request, error)
// TrackIDForSelector should be used to associate the specified WorkloadSelector with
// the given resource ID. Future calls to MapWorkload
TrackIDForSelector(*pbresource.ID, *pbcatalog.WorkloadSelector)
// UntrackID should be used to inform the tracker to forget about the specified ID
UntrackID(*pbresource.ID)
}
// ServiceEndpointsController creates a controller to perform automatic endpoint management for // ServiceEndpointsController creates a controller to perform automatic endpoint management for
// services. // services.
func ServiceEndpointsController(workloadMap WorkloadMapper) *controller.Controller { func ServiceEndpointsController() *controller.Controller {
if workloadMap == nil { return controller.NewController(ControllerID, pbcatalog.ServiceEndpointsType).
panic("No WorkloadMapper was provided to the ServiceEndpointsController constructor") WithWatch(pbcatalog.ServiceType,
} // ServiceEndpoints are name-aligned with the Service type
dependency.ReplaceType(pbcatalog.ServiceEndpointsType),
return controller.NewController(StatusKey, pbcatalog.ServiceEndpointsType). // This cache index keeps track of the relationship between WorkloadSelectors (and the workload names and prefixes
WithWatch(pbcatalog.ServiceType, dependency.ReplaceType(pbcatalog.ServiceEndpointsType)). // they include) and Services. This allows us to efficiently find all services and service endpoints that are
WithWatch(pbcatalog.WorkloadType, workloadMap.MapWorkload). // are affected by the change to a workload.
WithReconciler(newServiceEndpointsReconciler(workloadMap)) workloadselector.Index[*pbcatalog.Service](selectedWorkloadsIndexName)).
WithWatch(pbcatalog.WorkloadType,
// The cache index is kept on the Service type but we need to translate events for ServiceEndpoints.
// Therefore we need to wrap the mapper from the workloadselector package with one which will
// replace the request types of Service with ServiceEndpoints.
dependency.WrapAndReplaceType(
pbcatalog.ServiceEndpointsType,
// This mapper will use the selected-workloads index to find all Services which select this
// workload by exact name or by prefix.
workloadselector.MapWorkloadsToSelectors(pbcatalog.ServiceType, selectedWorkloadsIndexName),
),
).
WithReconciler(newServiceEndpointsReconciler())
} }
type serviceEndpointsReconciler struct { type serviceEndpointsReconciler struct{}
workloadMap WorkloadMapper
}
func newServiceEndpointsReconciler(workloadMap WorkloadMapper) *serviceEndpointsReconciler { func newServiceEndpointsReconciler() *serviceEndpointsReconciler {
return &serviceEndpointsReconciler{ return &serviceEndpointsReconciler{}
workloadMap: workloadMap,
}
} }
// Reconcile will reconcile one ServiceEndpoints resource in response to some event. // Reconcile will reconcile one ServiceEndpoints resource in response to some event.
func (r *serviceEndpointsReconciler) Reconcile(ctx context.Context, rt controller.Runtime, req controller.Request) error { func (r *serviceEndpointsReconciler) Reconcile(ctx context.Context, rt controller.Runtime, req controller.Request) error {
// The runtime is passed by value so replacing it here for the remainder of this // The runtime is passed by value so replacing it here for the remainder of this
// reconciliation request processing will not affect future invocations. // reconciliation request processing will not affect future invocations.
rt.Logger = rt.Logger.With("resource-id", req.ID, "controller", StatusKey) rt.Logger = rt.Logger.With("resource-id", req.ID)
rt.Logger.Trace("reconciling service endpoints") rt.Logger.Trace("reconciling service endpoints")
@ -76,21 +79,16 @@ func (r *serviceEndpointsReconciler) Reconcile(ctx context.Context, rt controlle
} }
// First we read and unmarshal the service // First we read and unmarshal the service
service, err := cache.GetDecoded[*pbcatalog.Service](rt.Cache, pbcatalog.ServiceType, "id", serviceID)
serviceData, err := getServiceData(ctx, rt, serviceID)
if err != nil { if err != nil {
rt.Logger.Error("error retrieving corresponding Service", "error", err) rt.Logger.Error("error retrieving corresponding Service", "error", err)
return err return err
} }
// Check if the service exists. If it doesn't we can avoid a bunch of other work. // Check if the service exists. If it doesn't we can avoid a bunch of other work.
if serviceData == nil { if service == nil {
rt.Logger.Trace("service has been deleted") rt.Logger.Trace("service has been deleted")
// The service was deleted so we need to update the WorkloadMapper to tell it to
// stop tracking this service
r.workloadMap.UntrackID(req.ID)
// Note that because we configured ServiceEndpoints to be owned by the service, // Note that because we configured ServiceEndpoints to be owned by the service,
// the service endpoints object should eventually be automatically deleted. // the service endpoints object should eventually be automatically deleted.
// There is no reason to attempt deletion here. // There is no reason to attempt deletion here.
@ -100,7 +98,7 @@ func (r *serviceEndpointsReconciler) Reconcile(ctx context.Context, rt controlle
// Now read and unmarshal the endpoints. We don't need this data just yet but all // Now read and unmarshal the endpoints. We don't need this data just yet but all
// code paths from this point on will need this regardless of branching so we pull // code paths from this point on will need this regardless of branching so we pull
// it now. // it now.
endpointsData, err := getEndpointsData(ctx, rt, endpointsID) endpoints, err := cache.GetDecoded[*pbcatalog.ServiceEndpoints](rt.Cache, pbcatalog.ServiceEndpointsType, "id", endpointsID)
if err != nil { if err != nil {
rt.Logger.Error("error retrieving existing endpoints", "error", err) rt.Logger.Error("error retrieving existing endpoints", "error", err)
return err return err
@ -108,40 +106,29 @@ func (r *serviceEndpointsReconciler) Reconcile(ctx context.Context, rt controlle
var statusConditions []*pbresource.Condition var statusConditions []*pbresource.Condition
if serviceUnderManagement(serviceData.service) { if serviceUnderManagement(service.Data) {
rt.Logger.Trace("service is enabled for automatic endpoint management") rt.Logger.Trace("service is enabled for automatic endpoint management")
// This service should have its endpoints automatically managed // This service should have its endpoints automatically managed
statusConditions = append(statusConditions, ConditionManaged) statusConditions = append(statusConditions, ConditionManaged)
// Inform the WorkloadMapper to track this service and its selectors. So // Now read and unmarshal all workloads selected by the service.
// future workload updates that would be matched by the services selectors workloads, err := workloadselector.GetWorkloadsWithSelector(rt.Cache, service)
// cause this service to be rereconciled.
r.workloadMap.TrackIDForSelector(req.ID, serviceData.service.GetWorkloads())
// Now read and unmarshal all workloads selected by the service. It is imperative
// that this happens after we notify the selection tracker to be tracking that
// selection criteria. If the order were reversed we could potentially miss
// workload creations that should be selected if they happen after gathering
// the workloads but before tracking the selector. Tracking first ensures that
// any event that happens after that would get mapped to an event for these
// endpoints.
workloadData, err := getWorkloadData(ctx, rt, serviceData)
if err != nil { if err != nil {
rt.Logger.Trace("error retrieving selected workloads", "error", err) rt.Logger.Trace("error retrieving selected workloads", "error", err)
return err return err
} }
// Calculate the latest endpoints from the already gathered workloads // Calculate the latest endpoints from the already gathered workloads
latestEndpoints := workloadsToEndpoints(serviceData.service, workloadData) latestEndpoints := workloadsToEndpoints(service.Data, workloads)
// Add status // Add status
if endpointsData != nil { if endpoints != nil {
statusConditions = append(statusConditions, statusConditions = append(statusConditions,
workloadIdentityStatusFromEndpoints(latestEndpoints)) workloadIdentityStatusFromEndpoints(latestEndpoints))
} }
// Before writing the endpoints actually check to see if they are changed // Before writing the endpoints actually check to see if they are changed
if endpointsData == nil || !proto.Equal(endpointsData.endpoints, latestEndpoints) { if endpoints == nil || !proto.Equal(endpoints.Data, latestEndpoints) {
rt.Logger.Trace("endpoints have changed") rt.Logger.Trace("endpoints have changed")
// First encode the endpoints data as an Any type. // First encode the endpoints data as an Any type.
@ -158,9 +145,9 @@ func (r *serviceEndpointsReconciler) Reconcile(ctx context.Context, rt controlle
_, err = rt.Client.Write(ctx, &pbresource.WriteRequest{ _, err = rt.Client.Write(ctx, &pbresource.WriteRequest{
Resource: &pbresource.Resource{ Resource: &pbresource.Resource{
Id: req.ID, Id: req.ID,
Owner: serviceData.resource.Id, Owner: service.Id,
Metadata: map[string]string{ Metadata: map[string]string{
endpointsMetaManagedBy: StatusKey, endpointsMetaManagedBy: ControllerID,
}, },
Data: endpointData, Data: endpointData,
}, },
@ -177,20 +164,16 @@ func (r *serviceEndpointsReconciler) Reconcile(ctx context.Context, rt controlle
// This service is not having its endpoints automatically managed // This service is not having its endpoints automatically managed
statusConditions = append(statusConditions, ConditionUnmanaged) statusConditions = append(statusConditions, ConditionUnmanaged)
// Inform the WorkloadMapper that it no longer needs to track this service
// as it is no longer under endpoint management
r.workloadMap.UntrackID(req.ID)
// Delete the managed ServiceEndpoints if necessary if the metadata would // Delete the managed ServiceEndpoints if necessary if the metadata would
// indicate that they were previously managed by this controller // indicate that they were previously managed by this controller
if endpointsData != nil && endpointsData.resource.Metadata[endpointsMetaManagedBy] == StatusKey { if endpoints != nil && endpoints.Metadata[endpointsMetaManagedBy] == ControllerID {
rt.Logger.Trace("removing previous managed endpoints") rt.Logger.Trace("removing previous managed endpoints")
// This performs a CAS deletion to protect against the case where the user // This performs a CAS deletion to protect against the case where the user
// has overwritten the endpoints since we fetched them. // has overwritten the endpoints since we fetched them.
_, err := rt.Client.Delete(ctx, &pbresource.DeleteRequest{ _, err := rt.Client.Delete(ctx, &pbresource.DeleteRequest{
Id: endpointsData.resource.Id, Id: endpoints.Id,
Version: endpointsData.resource.Version, Version: endpoints.Version,
}) })
// Potentially we could look for CAS failures by checking if the gRPC // Potentially we could look for CAS failures by checking if the gRPC
@ -209,17 +192,17 @@ func (r *serviceEndpointsReconciler) Reconcile(ctx context.Context, rt controlle
// whether we are automatically managing the endpoints to set expectations // whether we are automatically managing the endpoints to set expectations
// for that object existing or not. // for that object existing or not.
newStatus := &pbresource.Status{ newStatus := &pbresource.Status{
ObservedGeneration: serviceData.resource.Generation, ObservedGeneration: service.Generation,
Conditions: statusConditions, Conditions: statusConditions,
} }
// If the status is unchanged then we should return and avoid the unnecessary write // If the status is unchanged then we should return and avoid the unnecessary write
if resource.EqualStatus(serviceData.resource.Status[StatusKey], newStatus, false) { if resource.EqualStatus(service.Status[ControllerID], newStatus, false) {
return nil return nil
} }
_, err = rt.Client.WriteStatus(ctx, &pbresource.WriteStatusRequest{ _, err = rt.Client.WriteStatus(ctx, &pbresource.WriteStatusRequest{
Id: serviceData.resource.Id, Id: service.Id,
Key: StatusKey, Key: ControllerID,
Status: newStatus, Status: newStatus,
}) })
@ -275,7 +258,7 @@ func serviceUnderManagement(svc *pbcatalog.Service) bool {
} }
// workloadsToEndpoints will translate the Workload resources into a ServiceEndpoints resource // workloadsToEndpoints will translate the Workload resources into a ServiceEndpoints resource
func workloadsToEndpoints(svc *pbcatalog.Service, workloads []*workloadData) *pbcatalog.ServiceEndpoints { func workloadsToEndpoints(svc *pbcatalog.Service, workloads []*DecodedWorkload) *pbcatalog.ServiceEndpoints {
var endpoints []*pbcatalog.Endpoint var endpoints []*pbcatalog.Endpoint
for _, workload := range workloads { for _, workload := range workloads {
@ -300,8 +283,8 @@ func workloadsToEndpoints(svc *pbcatalog.Service, workloads []*workloadData) *pb
// have reconciled the workloads health and stored it within the resources Status field. // have reconciled the workloads health and stored it within the resources Status field.
// Any unreconciled workload health will be represented in the ServiceEndpoints with // Any unreconciled workload health will be represented in the ServiceEndpoints with
// the ANY health status. // the ANY health status.
func workloadToEndpoint(svc *pbcatalog.Service, data *workloadData) *pbcatalog.Endpoint { func workloadToEndpoint(svc *pbcatalog.Service, workload *DecodedWorkload) *pbcatalog.Endpoint {
health := determineWorkloadHealth(data.resource) health := determineWorkloadHealth(workload.Resource)
endpointPorts := make(map[string]*pbcatalog.WorkloadPort) endpointPorts := make(map[string]*pbcatalog.WorkloadPort)
@ -309,7 +292,7 @@ func workloadToEndpoint(svc *pbcatalog.Service, data *workloadData) *pbcatalog.E
// one of the services ports are included. Ports with a protocol mismatch // one of the services ports are included. Ports with a protocol mismatch
// between the service and workload will be excluded as well. // between the service and workload will be excluded as well.
for _, svcPort := range svc.Ports { for _, svcPort := range svc.Ports {
workloadPort, found := data.workload.Ports[svcPort.TargetPort] workloadPort, found := workload.Data.Ports[svcPort.TargetPort]
if !found { if !found {
// this workload doesn't have this port so ignore it // this workload doesn't have this port so ignore it
continue continue
@ -336,7 +319,7 @@ func workloadToEndpoint(svc *pbcatalog.Service, data *workloadData) *pbcatalog.E
// address list. If some but not all of its ports are served, then the list // address list. If some but not all of its ports are served, then the list
// of ports will be reduced to just the intersection of the service ports // of ports will be reduced to just the intersection of the service ports
// and the workload addresses ports // and the workload addresses ports
for _, addr := range data.workload.Addresses { for _, addr := range workload.Data.Addresses {
var ports []string var ports []string
if len(addr.Ports) > 0 { if len(addr.Ports) > 0 {
@ -386,12 +369,12 @@ func workloadToEndpoint(svc *pbcatalog.Service, data *workloadData) *pbcatalog.E
} }
return &pbcatalog.Endpoint{ return &pbcatalog.Endpoint{
TargetRef: data.resource.Id, TargetRef: workload.Id,
HealthStatus: health, HealthStatus: health,
Addresses: workloadAddrs, Addresses: workloadAddrs,
Ports: endpointPorts, Ports: endpointPorts,
Identity: data.workload.Identity, Identity: workload.Data.Identity,
Dns: data.workload.Dns, Dns: workload.Data.Dns,
} }
} }

View File

@ -15,7 +15,6 @@ import (
"github.com/hashicorp/consul/internal/catalog/internal/controllers/workloadhealth" "github.com/hashicorp/consul/internal/catalog/internal/controllers/workloadhealth"
"github.com/hashicorp/consul/internal/catalog/internal/types" "github.com/hashicorp/consul/internal/catalog/internal/types"
"github.com/hashicorp/consul/internal/controller" "github.com/hashicorp/consul/internal/controller"
"github.com/hashicorp/consul/internal/resource/mappers/selectiontracker"
"github.com/hashicorp/consul/internal/resource/resourcetest" "github.com/hashicorp/consul/internal/resource/resourcetest"
rtest "github.com/hashicorp/consul/internal/resource/resourcetest" rtest "github.com/hashicorp/consul/internal/resource/resourcetest"
pbcatalog "github.com/hashicorp/consul/proto-public/pbcatalog/v2beta1" pbcatalog "github.com/hashicorp/consul/proto-public/pbcatalog/v2beta1"
@ -63,26 +62,23 @@ func TestWorkloadsToEndpoints(t *testing.T) {
} }
// Build out the workloads. // Build out the workloads.
workloads := []*workloadData{ workloads := []*DecodedWorkload{
{ rtest.MustDecode[*pbcatalog.Workload](
// this workload should result in an endpoints t,
resource: rtest.Resource(pbcatalog.WorkloadType, "foo"). rtest.Resource(pbcatalog.WorkloadType, "foo").
WithData(t, workloadData1). WithData(t, workloadData1).
Build(), Build()),
workload: workloadData1,
}, rtest.MustDecode[*pbcatalog.Workload](
{ t,
// this workload should be filtered out rtest.Resource(pbcatalog.WorkloadType, "bar").
resource: rtest.Resource(pbcatalog.WorkloadType, "bar").
WithData(t, workloadData2). WithData(t, workloadData2).
Build(), Build()),
workload: workloadData2,
},
} }
endpoints := workloadsToEndpoints(service, workloads) endpoints := workloadsToEndpoints(service, workloads)
require.Len(t, endpoints.Endpoints, 1) require.Len(t, endpoints.Endpoints, 1)
prototest.AssertDeepEqual(t, workloads[0].resource.Id, endpoints.Endpoints[0].TargetRef) prototest.AssertDeepEqual(t, workloads[0].Id, endpoints.Endpoints[0].TargetRef)
} }
func TestWorkloadToEndpoint(t *testing.T) { func TestWorkloadToEndpoint(t *testing.T) {
@ -135,15 +131,12 @@ func TestWorkloadToEndpoint(t *testing.T) {
}, },
} }
data := &workloadData{ data := rtest.MustDecode[*pbcatalog.Workload](t, rtest.Resource(pbcatalog.WorkloadType, "foo").
resource: rtest.Resource(pbcatalog.WorkloadType, "foo"). WithData(t, workload).
WithData(t, workload). Build())
Build(),
workload: workload,
}
expected := &pbcatalog.Endpoint{ expected := &pbcatalog.Endpoint{
TargetRef: data.resource.Id, TargetRef: data.Id,
Addresses: []*pbcatalog.WorkloadAddress{ Addresses: []*pbcatalog.WorkloadAddress{
{Host: "127.0.0.1", Ports: []string{"http"}}, {Host: "127.0.0.1", Ports: []string{"http"}},
{Host: "198.18.1.1", Ports: []string{"http"}}, {Host: "198.18.1.1", Ports: []string{"http"}},
@ -189,12 +182,11 @@ func TestWorkloadToEndpoint_AllAddressesFiltered(t *testing.T) {
}, },
} }
data := &workloadData{ data := rtest.MustDecode[*pbcatalog.Workload](
resource: rtest.Resource(pbcatalog.WorkloadType, "foo"). t,
rtest.Resource(pbcatalog.WorkloadType, "foo").
WithData(t, workload). WithData(t, workload).
Build(), Build())
workload: workload,
}
require.Nil(t, workloadToEndpoint(service, data)) require.Nil(t, workloadToEndpoint(service, data))
} }
@ -218,15 +210,14 @@ func TestWorkloadToEndpoint_MissingWorkloadProtocol(t *testing.T) {
}, },
} }
data := &workloadData{ data := rtest.MustDecode[*pbcatalog.Workload](
resource: rtest.Resource(pbcatalog.WorkloadType, "foo"). t,
rtest.Resource(pbcatalog.WorkloadType, "foo").
WithData(t, workload). WithData(t, workload).
Build(), Build())
workload: workload,
}
expected := &pbcatalog.Endpoint{ expected := &pbcatalog.Endpoint{
TargetRef: data.resource.Id, TargetRef: data.Id,
Addresses: []*pbcatalog.WorkloadAddress{ Addresses: []*pbcatalog.WorkloadAddress{
{Host: "127.0.0.1", Ports: []string{"test-port"}}, {Host: "127.0.0.1", Ports: []string{"test-port"}},
}, },
@ -453,9 +444,8 @@ type controllerSuite struct {
client *rtest.Client client *rtest.Client
rt controller.Runtime rt controller.Runtime
tracker *selectiontracker.WorkloadSelectionTracker ctl *controller.TestController
reconciler *serviceEndpointsReconciler tenancies []*pbresource.Tenancy
tenancies []*pbresource.Tenancy
} }
func (suite *controllerSuite) SetupTest() { func (suite *controllerSuite) SetupTest() {
@ -465,22 +455,10 @@ func (suite *controllerSuite) SetupTest() {
WithRegisterFns(types.Register). WithRegisterFns(types.Register).
WithTenancies(suite.tenancies...). WithTenancies(suite.tenancies...).
Run(suite.T()) Run(suite.T())
suite.rt = controller.Runtime{ suite.ctl = controller.NewTestController(ServiceEndpointsController(), client).
Client: client, WithLogger(testutil.Logger(suite.T()))
Logger: testutil.Logger(suite.T()), suite.rt = suite.ctl.Runtime()
} suite.client = rtest.NewClient(suite.rt.Client)
suite.client = rtest.NewClient(client)
suite.tracker = selectiontracker.New()
suite.reconciler = newServiceEndpointsReconciler(suite.tracker)
}
func (suite *controllerSuite) requireTracking(workload *pbresource.Resource, ids ...*pbresource.ID) {
reqs, err := suite.tracker.MapWorkload(suite.ctx, suite.rt, workload)
require.NoError(suite.T(), err)
require.Len(suite.T(), reqs, len(ids))
for _, id := range ids {
prototest.AssertContainsElement(suite.T(), reqs, controller.Request{ID: id})
}
} }
func (suite *controllerSuite) requireEndpoints(resource *pbresource.Resource, expected ...*pbcatalog.Endpoint) { func (suite *controllerSuite) requireEndpoints(resource *pbresource.Resource, expected ...*pbcatalog.Endpoint) {
@ -491,33 +469,14 @@ func (suite *controllerSuite) requireEndpoints(resource *pbresource.Resource, ex
} }
func (suite *controllerSuite) TestReconcile_ServiceNotFound() { func (suite *controllerSuite) TestReconcile_ServiceNotFound() {
// This test's purpose is to ensure that when we are reconciling // This test really only checks that the Reconcile call will not panic or otherwise error
// endpoints for a service that no longer exists, we stop // when the request is for an endpoints object whose corresponding service does not exist.
// tracking the endpoints resource ID in the selection tracker.
// generate a workload resource to use for checking if it maps
// to a service endpoints object
suite.runTestCaseWithTenancies(func(tenancy *pbresource.Tenancy) { suite.runTestCaseWithTenancies(func(tenancy *pbresource.Tenancy) {
workload := rtest.Resource(pbcatalog.WorkloadType, "foo").WithTenancy(tenancy).Build()
// ensure that the tracker knows about the service prior to
// calling reconcile so that we can ensure it removes tracking
id := rtest.Resource(pbcatalog.ServiceEndpointsType, "not-found").WithTenancy(tenancy).ID() id := rtest.Resource(pbcatalog.ServiceEndpointsType, "not-found").WithTenancy(tenancy).ID()
suite.tracker.TrackIDForSelector(id, &pbcatalog.WorkloadSelector{Prefixes: []string{""}})
// verify that mapping the workload to service endpoints returns a // Because the endpoints don't exist, this reconcile call not error but also shouldn't do anything useful.
// non-empty list prior to reconciliation which should remove the err := suite.ctl.Reconcile(suite.ctx, controller.Request{ID: id})
// tracking.
suite.requireTracking(workload, id)
// Because the endpoints don't exist, this reconcile call should
// cause tracking of the endpoints to be removed
err := suite.reconciler.Reconcile(suite.ctx, suite.rt, controller.Request{ID: id})
require.NoError(suite.T(), err) require.NoError(suite.T(), err)
// Now ensure that the tracking was removed
suite.requireTracking(workload)
}) })
} }
@ -539,10 +498,10 @@ func (suite *controllerSuite) TestReconcile_NoSelector_NoEndpoints() {
endpointsID := rtest.Resource(pbcatalog.ServiceEndpointsType, "test").WithTenancy(tenancy).ID() endpointsID := rtest.Resource(pbcatalog.ServiceEndpointsType, "test").WithTenancy(tenancy).ID()
err := suite.reconciler.Reconcile(suite.ctx, suite.rt, controller.Request{ID: endpointsID}) err := suite.ctl.Reconcile(suite.ctx, controller.Request{ID: endpointsID})
require.NoError(suite.T(), err) require.NoError(suite.T(), err)
suite.client.RequireStatusCondition(suite.T(), service.Id, StatusKey, ConditionUnmanaged) suite.client.RequireStatusCondition(suite.T(), service.Id, ControllerID, ConditionUnmanaged)
}) })
} }
@ -565,13 +524,13 @@ func (suite *controllerSuite) TestReconcile_NoSelector_ManagedEndpoints() {
WithTenancy(tenancy). WithTenancy(tenancy).
WithData(suite.T(), &pbcatalog.ServiceEndpoints{}). WithData(suite.T(), &pbcatalog.ServiceEndpoints{}).
// this marks these endpoints as under management // this marks these endpoints as under management
WithMeta(endpointsMetaManagedBy, StatusKey). WithMeta(endpointsMetaManagedBy, ControllerID).
Write(suite.T(), suite.client) Write(suite.T(), suite.client)
err := suite.reconciler.Reconcile(suite.ctx, suite.rt, controller.Request{ID: endpoints.Id}) err := suite.ctl.Reconcile(suite.ctx, controller.Request{ID: endpoints.Id})
require.NoError(suite.T(), err) require.NoError(suite.T(), err)
// the status should indicate the services endpoints are not being managed // the status should indicate the services endpoints are not being managed
suite.client.RequireStatusCondition(suite.T(), service.Id, StatusKey, ConditionUnmanaged) suite.client.RequireStatusCondition(suite.T(), service.Id, ControllerID, ConditionUnmanaged)
// endpoints under management should be deleted // endpoints under management should be deleted
suite.client.RequireResourceNotFound(suite.T(), endpoints.Id) suite.client.RequireResourceNotFound(suite.T(), endpoints.Id)
}) })
@ -597,10 +556,10 @@ func (suite *controllerSuite) TestReconcile_NoSelector_UnmanagedEndpoints() {
WithData(suite.T(), &pbcatalog.ServiceEndpoints{}). WithData(suite.T(), &pbcatalog.ServiceEndpoints{}).
Write(suite.T(), suite.client) Write(suite.T(), suite.client)
err := suite.reconciler.Reconcile(suite.ctx, suite.rt, controller.Request{ID: endpoints.Id}) err := suite.ctl.Reconcile(suite.ctx, controller.Request{ID: endpoints.Id})
require.NoError(suite.T(), err) require.NoError(suite.T(), err)
// the status should indicate the services endpoints are not being managed // the status should indicate the services endpoints are not being managed
suite.client.RequireStatusCondition(suite.T(), service.Id, StatusKey, ConditionUnmanaged) suite.client.RequireStatusCondition(suite.T(), service.Id, ControllerID, ConditionUnmanaged)
// unmanaged endpoints should not be deleted when the service is unmanaged // unmanaged endpoints should not be deleted when the service is unmanaged
suite.client.RequireResourceExists(suite.T(), endpoints.Id) suite.client.RequireResourceExists(suite.T(), endpoints.Id)
}) })
@ -635,14 +594,14 @@ func (suite *controllerSuite) TestReconcile_Managed_NoPreviousEndpoints() {
}). }).
Write(suite.T(), suite.client) Write(suite.T(), suite.client)
err := suite.reconciler.Reconcile(suite.ctx, suite.rt, controller.Request{ID: endpointsID}) err := suite.ctl.Reconcile(suite.ctx, controller.Request{ID: endpointsID})
require.NoError(suite.T(), err) require.NoError(suite.T(), err)
// Verify that the services status has been set to indicate endpoints are automatically managed. // Verify that the services status has been set to indicate endpoints are automatically managed.
suite.client.RequireStatusCondition(suite.T(), service.Id, StatusKey, ConditionManaged) suite.client.RequireStatusCondition(suite.T(), service.Id, ControllerID, ConditionManaged)
// The service endpoints metadata should include our tag to indcate it was generated by this controller // The service endpoints metadata should include our tag to indcate it was generated by this controller
res := suite.client.RequireResourceMeta(suite.T(), endpointsID, endpointsMetaManagedBy, StatusKey) res := suite.client.RequireResourceMeta(suite.T(), endpointsID, endpointsMetaManagedBy, ControllerID)
var endpoints pbcatalog.ServiceEndpoints var endpoints pbcatalog.ServiceEndpoints
err = res.Data.UnmarshalTo(&endpoints) err = res.Data.UnmarshalTo(&endpoints)
@ -688,11 +647,11 @@ func (suite *controllerSuite) TestReconcile_Managed_ExistingEndpoints() {
}). }).
Write(suite.T(), suite.client) Write(suite.T(), suite.client)
err := suite.reconciler.Reconcile(suite.ctx, suite.rt, controller.Request{ID: endpoints.Id}) err := suite.ctl.Reconcile(suite.ctx, controller.Request{ID: endpoints.Id})
require.NoError(suite.T(), err) require.NoError(suite.T(), err)
suite.client.RequireStatusCondition(suite.T(), service.Id, StatusKey, ConditionManaged) suite.client.RequireStatusCondition(suite.T(), service.Id, ControllerID, ConditionManaged)
res := suite.client.RequireResourceMeta(suite.T(), endpoints.Id, endpointsMetaManagedBy, StatusKey) res := suite.client.RequireResourceMeta(suite.T(), endpoints.Id, endpointsMetaManagedBy, ControllerID)
var newEndpoints pbcatalog.ServiceEndpoints var newEndpoints pbcatalog.ServiceEndpoints
err = res.Data.UnmarshalTo(&newEndpoints) err = res.Data.UnmarshalTo(&newEndpoints)
@ -711,7 +670,7 @@ func (suite *controllerSuite) TestController() {
// Run the controller manager // Run the controller manager
mgr := controller.NewManager(suite.client, suite.rt.Logger) mgr := controller.NewManager(suite.client, suite.rt.Logger)
mgr.Register(ServiceEndpointsController(suite.tracker)) mgr.Register(ServiceEndpointsController())
mgr.SetRaftLeader(true) mgr.SetRaftLeader(true)
go mgr.Run(suite.ctx) go mgr.Run(suite.ctx)
@ -731,10 +690,10 @@ func (suite *controllerSuite) TestController() {
Write(suite.T(), suite.client) Write(suite.T(), suite.client)
// Wait for the controller to record that the endpoints are being managed // Wait for the controller to record that the endpoints are being managed
res := suite.client.WaitForReconciliation(suite.T(), service.Id, StatusKey) res := suite.client.WaitForReconciliation(suite.T(), service.Id, ControllerID)
// Check that the services status was updated accordingly // Check that the services status was updated accordingly
rtest.RequireStatusCondition(suite.T(), res, StatusKey, ConditionManaged) rtest.RequireStatusCondition(suite.T(), res, ControllerID, ConditionManaged)
rtest.RequireStatusCondition(suite.T(), res, StatusKey, ConditionIdentitiesNotFound) rtest.RequireStatusCondition(suite.T(), res, ControllerID, ConditionIdentitiesNotFound)
// Check that the endpoints resource exists and contains 0 endpoints // Check that the endpoints resource exists and contains 0 endpoints
endpointsID := rtest.Resource(pbcatalog.ServiceEndpointsType, "api").WithTenancy(tenancy).ID() endpointsID := rtest.Resource(pbcatalog.ServiceEndpointsType, "api").WithTenancy(tenancy).ID()
@ -755,7 +714,7 @@ func (suite *controllerSuite) TestController() {
}). }).
Write(suite.T(), suite.client) Write(suite.T(), suite.client)
suite.client.WaitForStatusCondition(suite.T(), service.Id, StatusKey, suite.client.WaitForStatusCondition(suite.T(), service.Id, ControllerID,
ConditionIdentitiesFound([]string{"api"})) ConditionIdentitiesFound([]string{"api"}))
// Wait for the endpoints to be regenerated // Wait for the endpoints to be regenerated
@ -818,7 +777,7 @@ func (suite *controllerSuite) TestController() {
}). }).
Write(suite.T(), suite.client) Write(suite.T(), suite.client)
suite.client.WaitForStatusCondition(suite.T(), service.Id, StatusKey, ConditionIdentitiesFound([]string{"endpoints-api-identity"})) suite.client.WaitForStatusCondition(suite.T(), service.Id, ControllerID, ConditionIdentitiesFound([]string{"endpoints-api-identity"}))
// Verify that the generated endpoints now contain the workload // Verify that the generated endpoints now contain the workload
endpoints = suite.client.WaitForNewVersion(suite.T(), endpointsID, endpoints.Version) endpoints = suite.client.WaitForNewVersion(suite.T(), endpointsID, endpoints.Version)
@ -850,7 +809,7 @@ func (suite *controllerSuite) TestController() {
Write(suite.T(), suite.client) Write(suite.T(), suite.client)
// Wait for the service status' observed generation to get bumped // Wait for the service status' observed generation to get bumped
service = suite.client.WaitForReconciliation(suite.T(), service.Id, StatusKey) service = suite.client.WaitForReconciliation(suite.T(), service.Id, ControllerID)
// Verify that the endpoints were not regenerated // Verify that the endpoints were not regenerated
suite.client.RequireVersionUnchanged(suite.T(), endpointsID, endpoints.Version) suite.client.RequireVersionUnchanged(suite.T(), endpointsID, endpoints.Version)
@ -891,8 +850,8 @@ func (suite *controllerSuite) TestController() {
}). }).
Write(suite.T(), suite.client) Write(suite.T(), suite.client)
res = suite.client.WaitForReconciliation(suite.T(), service.Id, StatusKey) res = suite.client.WaitForReconciliation(suite.T(), service.Id, ControllerID)
rtest.RequireStatusCondition(suite.T(), res, StatusKey, ConditionUnmanaged) rtest.RequireStatusCondition(suite.T(), res, ControllerID, ConditionUnmanaged)
// Verify that the endpoints were deleted // Verify that the endpoints were deleted
suite.client.RequireResourceNotFound(suite.T(), endpointsID) suite.client.RequireResourceNotFound(suite.T(), endpointsID)

View File

@ -1,189 +0,0 @@
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: BUSL-1.1
package endpoints
import (
"context"
"fmt"
"sort"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"github.com/hashicorp/consul/internal/controller"
"github.com/hashicorp/consul/internal/resource"
pbcatalog "github.com/hashicorp/consul/proto-public/pbcatalog/v2beta1"
"github.com/hashicorp/consul/proto-public/pbresource"
)
type serviceData struct {
resource *pbresource.Resource
service *pbcatalog.Service
}
type endpointsData struct {
resource *pbresource.Resource
endpoints *pbcatalog.ServiceEndpoints
}
type workloadData struct {
resource *pbresource.Resource
workload *pbcatalog.Workload
}
// getServiceData will read the service with the given ID and unmarshal the
// Data field. The return value is a struct that contains the retrieved
// resource as well as the unmarshalled form. If the resource doesn't
// exist, nil will be returned. Any other error either with retrieving
// the resource or unmarshalling it will cause the error to be returned
// to the caller
func getServiceData(ctx context.Context, rt controller.Runtime, id *pbresource.ID) (*serviceData, error) {
rsp, err := rt.Client.Read(ctx, &pbresource.ReadRequest{Id: id})
switch {
case status.Code(err) == codes.NotFound:
return nil, nil
case err != nil:
return nil, err
}
var service pbcatalog.Service
err = rsp.Resource.Data.UnmarshalTo(&service)
if err != nil {
return nil, resource.NewErrDataParse(&service, err)
}
return &serviceData{resource: rsp.Resource, service: &service}, nil
}
// getEndpointsData will read the endpoints with the given ID and unmarshal the
// Data field. The return value is a struct that contains the retrieved
// resource as well as the unmsashalled form. If the resource doesn't
// exist, nil will be returned. Any other error either with retrieving
// the resource or unmarshalling it will cause the error to be returned
// to the caller
func getEndpointsData(ctx context.Context, rt controller.Runtime, id *pbresource.ID) (*endpointsData, error) {
rsp, err := rt.Client.Read(ctx, &pbresource.ReadRequest{Id: id})
switch {
case status.Code(err) == codes.NotFound:
return nil, nil
case err != nil:
return nil, err
}
var endpoints pbcatalog.ServiceEndpoints
err = rsp.Resource.Data.UnmarshalTo(&endpoints)
if err != nil {
return nil, resource.NewErrDataParse(&endpoints, err)
}
return &endpointsData{resource: rsp.Resource, endpoints: &endpoints}, nil
}
// getWorkloadData will retrieve all workloads for the given services selector
// and unmarhshal them, returning a slic of objects hold both the resource and
// unmarshaled forms. Unmarshalling errors, or other resource service errors
// will be returned to the caller.
func getWorkloadData(ctx context.Context, rt controller.Runtime, svc *serviceData) ([]*workloadData, error) {
workloadResources, err := gatherWorkloadsForService(ctx, rt, svc)
if err != nil {
return nil, err
}
var results []*workloadData
for _, res := range workloadResources {
var workload pbcatalog.Workload
err = res.Data.UnmarshalTo(&workload)
if err != nil {
return nil, resource.NewErrDataParse(&workload, err)
}
results = append(results, &workloadData{resource: res, workload: &workload})
}
return results, nil
}
// gatherWorkloadsForService will retrieve all the unique workloads for a given selector.
// NotFound errors for workloads selected by Name will be ignored. Any other
// resource service errors will be returned to the caller. Prior to returning
// the slice of resources, they will be sorted by name. The consistent ordering
// will allow callers to diff two versions of the data to determine if anything
// has changed but it also will make testing a little easier.
func gatherWorkloadsForService(ctx context.Context, rt controller.Runtime, svc *serviceData) ([]*pbresource.Resource, error) {
var workloads []*pbresource.Resource
sel := svc.service.GetWorkloads()
// this map will track all the gathered workloads by name, this is mainly to deduplicate workloads if they
// are specified multiple times throughout the list of selection criteria
workloadNames := make(map[string]struct{})
// First gather all the prefix matched workloads. We could do this second but by doing
// it first its possible we can avoid some resource service calls to read individual
// workloads selected by name if they are also matched by a prefix.
for _, prefix := range sel.GetPrefixes() {
rsp, err := rt.Client.List(ctx, &pbresource.ListRequest{
Type: pbcatalog.WorkloadType,
Tenancy: svc.resource.Id.Tenancy,
NamePrefix: prefix,
})
if err != nil {
return nil, err
}
// append all workloads in the list response to our list of all selected workloads
for _, workload := range rsp.Resources {
// ignore duplicate workloads
if _, found := workloadNames[workload.Id.Name]; !found {
workloads = append(workloads, workload)
workloadNames[workload.Id.Name] = struct{}{}
}
}
}
// Now gather the exact match selections
for _, name := range sel.GetNames() {
// ignore names we have already fetched
if _, found := workloadNames[name]; found {
continue
}
workloadID := &pbresource.ID{
Type: pbcatalog.WorkloadType,
Tenancy: svc.resource.Id.Tenancy,
Name: name,
}
rsp, err := rt.Client.Read(ctx, &pbresource.ReadRequest{Id: workloadID})
switch {
case status.Code(err) == codes.NotFound:
// Ignore not found errors as services may select workloads that do not
// yet exist. This is not considered an error state or mis-configuration
// as the user could be getting ready to add the workloads.
continue
case err != nil:
return nil, err
}
workloads = append(workloads, rsp.Resource)
workloadNames[rsp.Resource.Id.Name] = struct{}{}
}
if sel.GetFilter() != "" && len(workloads) > 0 {
var err error
workloads, err = resource.FilterResourcesByMetadata(workloads, sel.GetFilter())
if err != nil {
return nil, fmt.Errorf("error filtering results by metadata: %w", err)
}
}
// Sorting ensures deterministic output. This will help for testing but
// the real reason to do this is so we will be able to diff the set of
// workloads endpoints to determine if we need to update them.
sort.Slice(workloads, func(i, j int) bool {
return workloads[i].Id.Name < workloads[j].Id.Name
})
return workloads, nil
}

View File

@ -1,358 +0,0 @@
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: BUSL-1.1
package endpoints
import (
"context"
"fmt"
"testing"
"github.com/stretchr/testify/require"
"github.com/stretchr/testify/suite"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"google.golang.org/protobuf/proto"
svctest "github.com/hashicorp/consul/agent/grpc-external/services/resource/testing"
"github.com/hashicorp/consul/internal/catalog/internal/types"
"github.com/hashicorp/consul/internal/controller"
"github.com/hashicorp/consul/internal/resource"
"github.com/hashicorp/consul/internal/resource/resourcetest"
rtest "github.com/hashicorp/consul/internal/resource/resourcetest"
pbcatalog "github.com/hashicorp/consul/proto-public/pbcatalog/v2beta1"
"github.com/hashicorp/consul/proto-public/pbresource"
"github.com/hashicorp/consul/proto/private/prototest"
"github.com/hashicorp/consul/sdk/testutil"
)
type reconciliationDataSuite struct {
suite.Suite
ctx context.Context
client *resourcetest.Client
rt controller.Runtime
apiServiceData *pbcatalog.Service
apiService *pbresource.Resource
apiServiceSubsetData *pbcatalog.Service
apiServiceSubset *pbresource.Resource
apiEndpoints *pbresource.Resource
api1Workload *pbresource.Resource
api2Workload *pbresource.Resource
api123Workload *pbresource.Resource
web1Workload *pbresource.Resource
web2Workload *pbresource.Resource
tenancies []*pbresource.Tenancy
}
func (suite *reconciliationDataSuite) SetupTest() {
suite.ctx = testutil.TestContext(suite.T())
suite.tenancies = rtest.TestTenancies()
resourceClient := svctest.NewResourceServiceBuilder().
WithRegisterFns(types.Register).
WithTenancies(suite.tenancies...).
Run(suite.T())
suite.client = resourcetest.NewClient(resourceClient)
suite.rt = controller.Runtime{
Client: suite.client,
Logger: testutil.Logger(suite.T()),
}
suite.apiServiceData = &pbcatalog.Service{
Workloads: &pbcatalog.WorkloadSelector{
// This services selectors are specially crafted to exercise both the
// dedeuplication and sorting behaviors of gatherWorkloadsForService
Prefixes: []string{"api-"},
Names: []string{"api-1", "web-2", "web-1", "api-1", "not-found"},
},
Ports: []*pbcatalog.ServicePort{
{
TargetPort: "http",
Protocol: pbcatalog.Protocol_PROTOCOL_HTTP,
},
},
}
suite.apiServiceSubsetData = proto.Clone(suite.apiServiceData).(*pbcatalog.Service)
suite.apiServiceSubsetData.Workloads.Filter = "(zim in metadata) and (metadata.zim matches `^g.`)"
}
func (suite *reconciliationDataSuite) TestGetServiceData_NotFound() {
// This test's purposes is to ensure that NotFound errors when retrieving
// the service data are ignored properly.
suite.runTestCaseWithTenancies(func(tenancy *pbresource.Tenancy) {
data, err := getServiceData(suite.ctx, suite.rt, rtest.Resource(pbcatalog.ServiceType, "not-found").WithTenancy(tenancy).ID())
require.NoError(suite.T(), err)
require.Nil(suite.T(), data)
})
}
func (suite *reconciliationDataSuite) TestGetServiceData_ReadError() {
// This test's purpose is to ensure that Read errors other than NotFound
// are propagated back to the caller. Specifying a resource ID with an
// unregistered type is the easiest way to force a resource service error.
badType := &pbresource.Type{
Group: "not",
Kind: "found",
GroupVersion: "vfake",
}
suite.runTestCaseWithTenancies(func(tenancy *pbresource.Tenancy) {
data, err := getServiceData(suite.ctx, suite.rt, rtest.Resource(badType, "foo").WithTenancy(tenancy).ID())
require.Error(suite.T(), err)
require.Equal(suite.T(), codes.InvalidArgument, status.Code(err))
require.Nil(suite.T(), data)
})
}
func (suite *reconciliationDataSuite) TestGetServiceData_UnmarshalError() {
// This test's purpose is to ensure that unmarshlling errors are returned
// to the caller. We are using a resource id that points to an endpoints
// object instead of a service to ensure that the data will be unmarshallable.
suite.runTestCaseWithTenancies(func(tenancy *pbresource.Tenancy) {
data, err := getServiceData(suite.ctx, suite.rt, rtest.Resource(pbcatalog.ServiceEndpointsType, "api").WithTenancy(tenancy).ID())
require.Error(suite.T(), err)
var parseErr resource.ErrDataParse
require.ErrorAs(suite.T(), err, &parseErr)
require.Nil(suite.T(), data)
})
}
func (suite *reconciliationDataSuite) TestGetServiceData_Ok() {
// This test's purpose is to ensure that the happy path for
// retrieving a service works as expected.
suite.runTestCaseWithTenancies(func(tenancy *pbresource.Tenancy) {
data, err := getServiceData(suite.ctx, suite.rt, suite.apiService.Id)
require.NoError(suite.T(), err)
require.NotNil(suite.T(), data)
require.NotNil(suite.T(), data.resource)
prototest.AssertDeepEqual(suite.T(), suite.apiService.Id, data.resource.Id)
require.Len(suite.T(), data.service.Ports, 1)
})
}
func (suite *reconciliationDataSuite) TestGetEndpointsData_NotFound() {
// This test's purposes is to ensure that NotFound errors when retrieving
// the endpoint data are ignored properly.
suite.runTestCaseWithTenancies(func(tenancy *pbresource.Tenancy) {
data, err := getEndpointsData(suite.ctx, suite.rt, rtest.Resource(pbcatalog.ServiceEndpointsType, "not-found").WithTenancy(tenancy).ID())
require.NoError(suite.T(), err)
require.Nil(suite.T(), data)
})
}
func (suite *reconciliationDataSuite) TestGetEndpointsData_ReadError() {
// This test's purpose is to ensure that Read errors other than NotFound
// are propagated back to the caller. Specifying a resource ID with an
// unregistered type is the easiest way to force a resource service error.
badType := &pbresource.Type{
Group: "not",
Kind: "found",
GroupVersion: "vfake",
}
suite.runTestCaseWithTenancies(func(tenancy *pbresource.Tenancy) {
data, err := getEndpointsData(suite.ctx, suite.rt, rtest.Resource(badType, "foo").WithTenancy(tenancy).ID())
require.Error(suite.T(), err)
require.Equal(suite.T(), codes.InvalidArgument, status.Code(err))
require.Nil(suite.T(), data)
})
}
func (suite *reconciliationDataSuite) TestGetEndpointsData_UnmarshalError() {
// This test's purpose is to ensure that unmarshlling errors are returned
// to the caller. We are using a resource id that points to a service object
// instead of an endpoints object to ensure that the data will be unmarshallable.
suite.runTestCaseWithTenancies(func(tenancy *pbresource.Tenancy) {
data, err := getEndpointsData(suite.ctx, suite.rt, rtest.Resource(pbcatalog.ServiceType, "api").WithTenancy(tenancy).ID())
require.Error(suite.T(), err)
var parseErr resource.ErrDataParse
require.ErrorAs(suite.T(), err, &parseErr)
require.Nil(suite.T(), data)
})
}
func (suite *reconciliationDataSuite) TestGetEndpointsData_Ok() {
// This test's purpose is to ensure that the happy path for
// retrieving an endpoints object works as expected.
suite.runTestCaseWithTenancies(func(tenancy *pbresource.Tenancy) {
data, err := getEndpointsData(suite.ctx, suite.rt, suite.apiEndpoints.Id)
require.NoError(suite.T(), err)
require.NotNil(suite.T(), data)
require.NotNil(suite.T(), data.resource)
prototest.AssertDeepEqual(suite.T(), suite.apiEndpoints.Id, data.resource.Id)
require.Len(suite.T(), data.endpoints.Endpoints, 1)
})
}
func (suite *reconciliationDataSuite) TestGetWorkloadData() {
// This test's purpose is to ensure that gather workloads for
// a service work as expected. The services selector was crafted
// to exercise the deduplication behavior as well as the sorting
// behavior. The assertions in this test will verify that only
// unique workloads are returned and that they are ordered.
suite.runTestCaseWithTenancies(func(tenancy *pbresource.Tenancy) {
require.NotNil(suite.T(), suite.apiService)
data, err := getWorkloadData(suite.ctx, suite.rt, &serviceData{
resource: suite.apiService,
service: suite.apiServiceData,
})
require.NoError(suite.T(), err)
require.Len(suite.T(), data, 5)
prototest.AssertDeepEqual(suite.T(), suite.api1Workload, data[0].resource)
prototest.AssertDeepEqual(suite.T(), suite.api123Workload, data[1].resource)
prototest.AssertDeepEqual(suite.T(), suite.api2Workload, data[2].resource)
prototest.AssertDeepEqual(suite.T(), suite.web1Workload, data[3].resource)
prototest.AssertDeepEqual(suite.T(), suite.web2Workload, data[4].resource)
})
}
func (suite *reconciliationDataSuite) TestGetWorkloadDataWithFilter() {
// This is like TestGetWorkloadData except it exercises the post-read
// filter on the selector.
suite.runTestCaseWithTenancies(func(tenancy *pbresource.Tenancy) {
require.NotNil(suite.T(), suite.apiServiceSubset)
data, err := getWorkloadData(suite.ctx, suite.rt, &serviceData{
resource: suite.apiServiceSubset,
service: suite.apiServiceSubsetData,
})
require.NoError(suite.T(), err)
require.Len(suite.T(), data, 2)
prototest.AssertDeepEqual(suite.T(), suite.api123Workload, data[0].resource)
prototest.AssertDeepEqual(suite.T(), suite.web1Workload, data[1].resource)
})
}
func TestReconciliationData(t *testing.T) {
suite.Run(t, new(reconciliationDataSuite))
}
func (suite *reconciliationDataSuite) setupResourcesWithTenancy(tenancy *pbresource.Tenancy) {
suite.apiService = rtest.Resource(pbcatalog.ServiceType, "api").
WithTenancy(tenancy).
WithData(suite.T(), suite.apiServiceData).
Write(suite.T(), suite.client)
suite.apiServiceSubset = rtest.Resource(pbcatalog.ServiceType, "api-subset").
WithTenancy(tenancy).
WithData(suite.T(), suite.apiServiceSubsetData).
Write(suite.T(), suite.client)
suite.api1Workload = rtest.Resource(pbcatalog.WorkloadType, "api-1").
WithTenancy(tenancy).
WithMeta("zim", "dib").
WithData(suite.T(), &pbcatalog.Workload{
Addresses: []*pbcatalog.WorkloadAddress{
{Host: "127.0.0.1"},
},
Ports: map[string]*pbcatalog.WorkloadPort{
"http": {Port: 8080, Protocol: pbcatalog.Protocol_PROTOCOL_HTTP},
},
Identity: "api",
}).
Write(suite.T(), suite.client)
suite.api2Workload = rtest.Resource(pbcatalog.WorkloadType, "api-2").
WithTenancy(tenancy).
WithData(suite.T(), &pbcatalog.Workload{
Addresses: []*pbcatalog.WorkloadAddress{
{Host: "127.0.0.1"},
},
Ports: map[string]*pbcatalog.WorkloadPort{
"http": {Port: 8080, Protocol: pbcatalog.Protocol_PROTOCOL_HTTP},
},
Identity: "api",
}).
Write(suite.T(), suite.client)
suite.api123Workload = rtest.Resource(pbcatalog.WorkloadType, "api-123").
WithTenancy(tenancy).
WithMeta("zim", "gir").
WithData(suite.T(), &pbcatalog.Workload{
Addresses: []*pbcatalog.WorkloadAddress{
{Host: "127.0.0.1"},
},
Ports: map[string]*pbcatalog.WorkloadPort{
"http": {Port: 8080, Protocol: pbcatalog.Protocol_PROTOCOL_HTTP},
},
Identity: "api",
}).
Write(suite.T(), suite.client)
suite.web1Workload = rtest.Resource(pbcatalog.WorkloadType, "web-1").
WithTenancy(tenancy).
WithMeta("zim", "gaz").
WithData(suite.T(), &pbcatalog.Workload{
Addresses: []*pbcatalog.WorkloadAddress{
{Host: "127.0.0.1"},
},
Ports: map[string]*pbcatalog.WorkloadPort{
"http": {Port: 8080, Protocol: pbcatalog.Protocol_PROTOCOL_HTTP},
},
Identity: "web",
}).
Write(suite.T(), suite.client)
suite.web2Workload = rtest.Resource(pbcatalog.WorkloadType, "web-2").
WithTenancy(tenancy).
WithData(suite.T(), &pbcatalog.Workload{
Addresses: []*pbcatalog.WorkloadAddress{
{Host: "127.0.0.1"},
},
Ports: map[string]*pbcatalog.WorkloadPort{
"http": {Port: 8080, Protocol: pbcatalog.Protocol_PROTOCOL_HTTP},
},
Identity: "web",
}).
Write(suite.T(), suite.client)
suite.apiEndpoints = rtest.Resource(pbcatalog.ServiceEndpointsType, "api").
WithTenancy(tenancy).
WithData(suite.T(), &pbcatalog.ServiceEndpoints{
Endpoints: []*pbcatalog.Endpoint{
{
TargetRef: rtest.Resource(pbcatalog.WorkloadType, "api-1").WithTenancy(tenancy).ID(),
Addresses: []*pbcatalog.WorkloadAddress{
{
Host: "127.0.0.1",
Ports: []string{"http"},
},
},
Ports: map[string]*pbcatalog.WorkloadPort{
"http": {Port: 8080, Protocol: pbcatalog.Protocol_PROTOCOL_HTTP},
},
HealthStatus: pbcatalog.Health_HEALTH_PASSING,
},
},
}).
Write(suite.T(), suite.client)
}
func (suite *reconciliationDataSuite) cleanupResources() {
suite.client.MustDelete(suite.T(), suite.apiService.Id)
suite.client.MustDelete(suite.T(), suite.apiServiceSubset.Id)
suite.client.MustDelete(suite.T(), suite.api1Workload.Id)
suite.client.MustDelete(suite.T(), suite.api2Workload.Id)
suite.client.MustDelete(suite.T(), suite.api123Workload.Id)
suite.client.MustDelete(suite.T(), suite.web1Workload.Id)
suite.client.MustDelete(suite.T(), suite.web2Workload.Id)
suite.client.MustDelete(suite.T(), suite.apiEndpoints.Id)
}
func (suite *reconciliationDataSuite) runTestCaseWithTenancies(testFunc func(*pbresource.Tenancy)) {
for _, tenancy := range suite.tenancies {
suite.Run(suite.appendTenancyInfo(tenancy), func() {
suite.setupResourcesWithTenancy(tenancy)
testFunc(tenancy)
suite.T().Cleanup(suite.cleanupResources)
})
}
}
func (suite *reconciliationDataSuite) appendTenancyInfo(tenancy *pbresource.Tenancy) string {
return fmt.Sprintf("%s_Namespace_%s_Partition", tenancy.Namespace, tenancy.Partition)
}

View File

@ -11,7 +11,7 @@ import (
) )
const ( const (
StatusKey = "consul.io/endpoint-manager" ControllerID = "consul.io/endpoint-manager"
StatusConditionEndpointsManaged = "EndpointsManaged" StatusConditionEndpointsManaged = "EndpointsManaged"
StatusReasonSelectorNotFound = "SelectorNotFound" StatusReasonSelectorNotFound = "SelectorNotFound"

View File

@ -12,13 +12,12 @@ import (
) )
type Dependencies struct { type Dependencies struct {
EndpointsWorkloadMapper endpoints.WorkloadMapper FailoverMapper failover.FailoverMapper
FailoverMapper failover.FailoverMapper
} }
func Register(mgr *controller.Manager, deps Dependencies) { func Register(mgr *controller.Manager, deps Dependencies) {
mgr.Register(nodehealth.NodeHealthController()) mgr.Register(nodehealth.NodeHealthController())
mgr.Register(workloadhealth.WorkloadHealthController()) mgr.Register(workloadhealth.WorkloadHealthController())
mgr.Register(endpoints.ServiceEndpointsController(deps.EndpointsWorkloadMapper)) mgr.Register(endpoints.ServiceEndpointsController())
mgr.Register(failover.FailoverPolicyController(deps.FailoverMapper)) mgr.Register(failover.FailoverPolicyController(deps.FailoverMapper))
} }

View File

@ -6,6 +6,7 @@ package types
import ( import (
"github.com/hashicorp/go-multierror" "github.com/hashicorp/go-multierror"
"github.com/hashicorp/consul/internal/catalog/workloadselector"
"github.com/hashicorp/consul/internal/resource" "github.com/hashicorp/consul/internal/resource"
pbcatalog "github.com/hashicorp/consul/proto-public/pbcatalog/v2beta1" pbcatalog "github.com/hashicorp/consul/proto-public/pbcatalog/v2beta1"
) )
@ -18,7 +19,7 @@ func RegisterHealthChecks(r resource.Registry) {
Proto: &pbcatalog.HealthChecks{}, Proto: &pbcatalog.HealthChecks{},
Scope: resource.ScopeNamespace, Scope: resource.ScopeNamespace,
Validate: ValidateHealthChecks, Validate: ValidateHealthChecks,
ACLs: ACLHooksForWorkloadSelectingType[*pbcatalog.HealthChecks](), ACLs: workloadselector.ACLHooks[*pbcatalog.HealthChecks](),
}) })
} }

View File

@ -8,6 +8,7 @@ import (
"github.com/hashicorp/go-multierror" "github.com/hashicorp/go-multierror"
"github.com/hashicorp/consul/internal/catalog/workloadselector"
"github.com/hashicorp/consul/internal/resource" "github.com/hashicorp/consul/internal/resource"
pbcatalog "github.com/hashicorp/consul/proto-public/pbcatalog/v2beta1" pbcatalog "github.com/hashicorp/consul/proto-public/pbcatalog/v2beta1"
) )
@ -21,7 +22,7 @@ func RegisterService(r resource.Registry) {
Scope: resource.ScopeNamespace, Scope: resource.ScopeNamespace,
Validate: ValidateService, Validate: ValidateService,
Mutate: MutateService, Mutate: MutateService,
ACLs: ACLHooksForWorkloadSelectingType[*pbcatalog.Service](), ACLs: workloadselector.ACLHooks[*pbcatalog.Service](),
}) })
} }

View File

@ -1,7 +1,7 @@
// Copyright (c) HashiCorp, Inc. // Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: BUSL-1.1 // SPDX-License-Identifier: BUSL-1.1
package types package workloadselector
import ( import (
"github.com/hashicorp/consul/acl" "github.com/hashicorp/consul/acl"
@ -38,7 +38,7 @@ func aclWriteHookResourceWithWorkloadSelector[T WorkloadSelecting](authorizer ac
return nil return nil
} }
func ACLHooksForWorkloadSelectingType[T WorkloadSelecting]() *resource.ACLHooks { func ACLHooks[T WorkloadSelecting]() *resource.ACLHooks {
return &resource.ACLHooks{ return &resource.ACLHooks{
Read: aclReadHookResourceWithWorkloadSelector, Read: aclReadHookResourceWithWorkloadSelector,
Write: resource.DecodeAndAuthorizeWrite(aclWriteHookResourceWithWorkloadSelector[T]), Write: resource.DecodeAndAuthorizeWrite(aclWriteHookResourceWithWorkloadSelector[T]),

View File

@ -0,0 +1,123 @@
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: BUSL-1.1
package workloadselector
import (
"testing"
"github.com/stretchr/testify/require"
"github.com/stretchr/testify/suite"
"github.com/hashicorp/consul/acl"
"github.com/hashicorp/consul/internal/resource"
"github.com/hashicorp/consul/internal/resource/resourcetest"
pbcatalog "github.com/hashicorp/consul/proto-public/pbcatalog/v2beta1"
"github.com/hashicorp/consul/proto-public/pbresource"
)
func TestACLHooks(t *testing.T) {
suite.Run(t, new(aclHookSuite))
}
type aclHookSuite struct {
suite.Suite
hooks *resource.ACLHooks
authz *acl.MockAuthorizer
ctx *acl.AuthorizerContext
res *pbresource.Resource
}
func (suite *aclHookSuite) SetupTest() {
suite.authz = new(acl.MockAuthorizer)
suite.authz.On("ToAllowAuthorizer").Return(acl.AllowAuthorizer{Authorizer: suite.authz, AccessorID: "862270e5-7d7b-4583-98bc-4d14810cc158"})
suite.ctx = &acl.AuthorizerContext{}
acl.DefaultEnterpriseMeta().FillAuthzContext(suite.ctx)
suite.hooks = ACLHooks[*pbcatalog.Service]()
suite.res = resourcetest.Resource(pbcatalog.ServiceType, "foo").
WithData(suite.T(), &pbcatalog.Service{
Workloads: &pbcatalog.WorkloadSelector{
Prefixes: []string{"api-"},
Names: []string{"bar"},
},
}).
WithTenancy(resource.DefaultNamespacedTenancy()).
Build()
}
func (suite *aclHookSuite) TeardownTest() {
suite.authz.AssertExpectations(suite.T())
}
func (suite *aclHookSuite) TestReadHook_Allowed() {
suite.authz.On("ServiceRead", "foo", suite.ctx).
Return(acl.Allow).
Once()
require.NoError(suite.T(), suite.hooks.Read(suite.authz, suite.ctx, suite.res.Id, nil))
}
func (suite *aclHookSuite) TestReadHook_Denied() {
suite.authz.On("ServiceRead", "foo", suite.ctx).
Return(acl.Deny).
Once()
require.Error(suite.T(), suite.hooks.Read(suite.authz, suite.ctx, suite.res.Id, nil))
}
func (suite *aclHookSuite) TestWriteHook_ServiceWriteDenied() {
suite.authz.On("ServiceWrite", "foo", suite.ctx).
Return(acl.Deny).
Once()
require.Error(suite.T(), suite.hooks.Write(suite.authz, suite.ctx, suite.res))
}
func (suite *aclHookSuite) TestWriteHook_ServiceReadNameDenied() {
suite.authz.On("ServiceWrite", "foo", suite.ctx).
Return(acl.Allow).
Once()
suite.authz.On("ServiceRead", "bar", suite.ctx).
Return(acl.Deny).
Once()
require.Error(suite.T(), suite.hooks.Write(suite.authz, suite.ctx, suite.res))
}
func (suite *aclHookSuite) TestWriteHook_ServiceReadPrefixDenied() {
suite.authz.On("ServiceWrite", "foo", suite.ctx).
Return(acl.Allow).
Once()
suite.authz.On("ServiceRead", "bar", suite.ctx).
Return(acl.Allow).
Once()
suite.authz.On("ServiceReadPrefix", "api-", suite.ctx).
Return(acl.Deny).
Once()
require.Error(suite.T(), suite.hooks.Write(suite.authz, suite.ctx, suite.res))
}
func (suite *aclHookSuite) TestWriteHook_Allowed() {
suite.authz.On("ServiceWrite", "foo", suite.ctx).
Return(acl.Allow).
Once()
suite.authz.On("ServiceRead", "bar", suite.ctx).
Return(acl.Allow).
Once()
suite.authz.On("ServiceReadPrefix", "api-", suite.ctx).
Return(acl.Allow).
Once()
require.NoError(suite.T(), suite.hooks.Write(suite.authz, suite.ctx, suite.res))
}

View File

@ -0,0 +1,114 @@
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: BUSL-1.1
package workloadselector
import (
"fmt"
"sort"
"github.com/hashicorp/consul/internal/controller/cache"
"github.com/hashicorp/consul/internal/controller/cache/index"
"github.com/hashicorp/consul/internal/resource"
pbcatalog "github.com/hashicorp/consul/proto-public/pbcatalog/v2beta1"
"github.com/hashicorp/consul/proto-public/pbresource"
)
// GetWorkloadsWithSelector will retrieve all workloads for the given resources selector
// and unmarhshal them, returning a slice of objects hold both the resource and
// unmarshaled forms. Unmarshalling errors, or other cache errors
// will be returned to the caller.
func GetWorkloadsWithSelector[T WorkloadSelecting](c cache.ReadOnlyCache, res *resource.DecodedResource[T]) ([]*resource.DecodedResource[*pbcatalog.Workload], error) {
if res == nil {
return nil, nil
}
sel := res.Data.GetWorkloads()
if sel == nil || (len(sel.GetNames()) < 1 && len(sel.GetPrefixes()) < 1) {
return nil, nil
}
// this map will track all workloads by name which is needed to deduplicate workloads if they
// are specified multiple times throughout the list of selection criteria
workloadNames := make(map[string]struct{})
var workloads []*resource.DecodedResource[*pbcatalog.Workload]
// First gather all the prefix matched workloads. We could do this second but by doing
// it first its possible we can avoid some operations to get individual
// workloads selected by name if they are also matched by a prefix.
for _, prefix := range sel.GetPrefixes() {
iter, err := cache.ListIteratorDecoded[*pbcatalog.Workload](
c,
pbcatalog.WorkloadType,
"id",
&pbresource.ID{
Type: pbcatalog.WorkloadType,
Tenancy: res.Id.Tenancy,
Name: prefix,
},
index.IndexQueryOptions{Prefix: true})
if err != nil {
return nil, err
}
// append all workloads in the list response to our list of all selected workloads
for workload, err := iter.Next(); workload != nil || err != nil; workload, err = iter.Next() {
if err != nil {
return nil, err
}
// ignore duplicate workloads
if _, found := workloadNames[workload.Id.Name]; !found {
workloads = append(workloads, workload)
workloadNames[workload.Id.Name] = struct{}{}
}
}
}
// Now gather the exact match selections
for _, name := range sel.GetNames() {
// ignore names we have already fetched
if _, found := workloadNames[name]; found {
continue
}
workloadID := &pbresource.ID{
Type: pbcatalog.WorkloadType,
Tenancy: res.Id.Tenancy,
Name: name,
}
res, err := cache.GetDecoded[*pbcatalog.Workload](c, pbcatalog.WorkloadType, "id", workloadID)
if err != nil {
return nil, err
}
// ignore workloads that don't exist as it is fine for a Service to select them. If they exist in the
// future then the ServiceEndpoints will be regenerated to include them.
if res == nil {
continue
}
workloads = append(workloads, res)
workloadNames[res.Id.Name] = struct{}{}
}
if sel.GetFilter() != "" && len(workloads) > 0 {
var err error
workloads, err = resource.FilterResourcesByMetadata(workloads, sel.GetFilter())
if err != nil {
return nil, fmt.Errorf("error filtering results by metadata: %w", err)
}
}
// Sorting ensures deterministic output. This will help for testing but
// the real reason to do this is so we will be able to diff the set of
// workloads endpoints to determine if we need to update them.
sort.Slice(workloads, func(i, j int) bool {
return workloads[i].Id.Name < workloads[j].Id.Name
})
return workloads, nil
}

View File

@ -0,0 +1,258 @@
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: BUSL-1.1
package workloadselector
import (
"context"
"fmt"
"testing"
"github.com/stretchr/testify/require"
"github.com/stretchr/testify/suite"
"google.golang.org/protobuf/proto"
"github.com/hashicorp/consul/internal/controller/cache"
"github.com/hashicorp/consul/internal/resource"
rtest "github.com/hashicorp/consul/internal/resource/resourcetest"
pbcatalog "github.com/hashicorp/consul/proto-public/pbcatalog/v2beta1"
"github.com/hashicorp/consul/proto-public/pbresource"
"github.com/hashicorp/consul/proto/private/prototest"
"github.com/hashicorp/consul/sdk/testutil"
)
type gatherWorkloadsDataSuite struct {
suite.Suite
ctx context.Context
cache cache.Cache
apiServiceData *pbcatalog.Service
apiService *resource.DecodedResource[*pbcatalog.Service]
apiServiceSubsetData *pbcatalog.Service
apiServiceSubset *resource.DecodedResource[*pbcatalog.Service]
apiEndpoints *resource.DecodedResource[*pbcatalog.ServiceEndpoints]
api1Workload *resource.DecodedResource[*pbcatalog.Workload]
api2Workload *resource.DecodedResource[*pbcatalog.Workload]
api123Workload *resource.DecodedResource[*pbcatalog.Workload]
web1Workload *resource.DecodedResource[*pbcatalog.Workload]
web2Workload *resource.DecodedResource[*pbcatalog.Workload]
tenancies []*pbresource.Tenancy
}
func (suite *gatherWorkloadsDataSuite) SetupTest() {
suite.ctx = testutil.TestContext(suite.T())
suite.tenancies = rtest.TestTenancies()
suite.cache = cache.New()
suite.cache.AddType(pbcatalog.WorkloadType)
suite.apiServiceData = &pbcatalog.Service{
Workloads: &pbcatalog.WorkloadSelector{
// This services selectors are specially crafted to exercise both the
// dedeuplication and sorting behaviors of gatherWorkloadsForService
Prefixes: []string{"api-"},
Names: []string{"api-1", "web-2", "web-1", "api-1", "not-found"},
},
Ports: []*pbcatalog.ServicePort{
{
TargetPort: "http",
Protocol: pbcatalog.Protocol_PROTOCOL_HTTP,
},
},
}
suite.apiServiceSubsetData = proto.Clone(suite.apiServiceData).(*pbcatalog.Service)
suite.apiServiceSubsetData.Workloads.Filter = "(zim in metadata) and (metadata.zim matches `^g.`)"
}
func (suite *gatherWorkloadsDataSuite) TestGetWorkloadData() {
// This test's purpose is to ensure that gather workloads for
// a service work as expected. The services selector was crafted
// to exercise the deduplication behavior as well as the sorting
// behavior. The assertions in this test will verify that only
// unique workloads are returned and that they are ordered.
suite.runTestCaseWithTenancies(func(tenancy *pbresource.Tenancy) {
require.NotNil(suite.T(), suite.apiService)
data, err := GetWorkloadsWithSelector(suite.cache, suite.apiService)
require.NoError(suite.T(), err)
require.Len(suite.T(), data, 5)
requireDecodedWorkloadEquals(suite.T(), suite.api1Workload, data[0])
requireDecodedWorkloadEquals(suite.T(), suite.api1Workload, data[0])
requireDecodedWorkloadEquals(suite.T(), suite.api123Workload, data[1])
requireDecodedWorkloadEquals(suite.T(), suite.api2Workload, data[2])
requireDecodedWorkloadEquals(suite.T(), suite.web1Workload, data[3])
requireDecodedWorkloadEquals(suite.T(), suite.web2Workload, data[4])
})
}
func (suite *gatherWorkloadsDataSuite) TestGetWorkloadDataWithFilter() {
// This is like TestGetWorkloadData except it exercises the post-read
// filter on the selector.
suite.runTestCaseWithTenancies(func(tenancy *pbresource.Tenancy) {
require.NotNil(suite.T(), suite.apiServiceSubset)
data, err := GetWorkloadsWithSelector(suite.cache, suite.apiServiceSubset)
require.NoError(suite.T(), err)
require.Len(suite.T(), data, 2)
requireDecodedWorkloadEquals(suite.T(), suite.api123Workload, data[0])
requireDecodedWorkloadEquals(suite.T(), suite.web1Workload, data[1])
})
}
func TestReconciliationData(t *testing.T) {
suite.Run(t, new(gatherWorkloadsDataSuite))
}
func (suite *gatherWorkloadsDataSuite) setupResourcesWithTenancy(tenancy *pbresource.Tenancy) {
suite.apiService = rtest.MustDecode[*pbcatalog.Service](
suite.T(),
rtest.Resource(pbcatalog.ServiceType, "api").
WithTenancy(tenancy).
WithData(suite.T(), suite.apiServiceData).
Build())
suite.apiServiceSubset = rtest.MustDecode[*pbcatalog.Service](
suite.T(),
rtest.Resource(pbcatalog.ServiceType, "api-subset").
WithTenancy(tenancy).
WithData(suite.T(), suite.apiServiceSubsetData).
Build())
suite.api1Workload = rtest.MustDecode[*pbcatalog.Workload](
suite.T(),
rtest.Resource(pbcatalog.WorkloadType, "api-1").
WithTenancy(tenancy).
WithMeta("zim", "dib").
WithData(suite.T(), &pbcatalog.Workload{
Addresses: []*pbcatalog.WorkloadAddress{
{Host: "127.0.0.1"},
},
Ports: map[string]*pbcatalog.WorkloadPort{
"http": {Port: 8080, Protocol: pbcatalog.Protocol_PROTOCOL_HTTP},
},
Identity: "api",
}).
Build())
suite.cache.Insert(suite.api1Workload.Resource)
suite.api2Workload = rtest.MustDecode[*pbcatalog.Workload](
suite.T(),
rtest.Resource(pbcatalog.WorkloadType, "api-2").
WithTenancy(tenancy).
WithData(suite.T(), &pbcatalog.Workload{
Addresses: []*pbcatalog.WorkloadAddress{
{Host: "127.0.0.1"},
},
Ports: map[string]*pbcatalog.WorkloadPort{
"http": {Port: 8080, Protocol: pbcatalog.Protocol_PROTOCOL_HTTP},
},
Identity: "api",
}).
Build())
suite.cache.Insert(suite.api2Workload.Resource)
suite.api123Workload = rtest.MustDecode[*pbcatalog.Workload](
suite.T(),
rtest.Resource(pbcatalog.WorkloadType, "api-123").
WithTenancy(tenancy).
WithMeta("zim", "gir").
WithData(suite.T(), &pbcatalog.Workload{
Addresses: []*pbcatalog.WorkloadAddress{
{Host: "127.0.0.1"},
},
Ports: map[string]*pbcatalog.WorkloadPort{
"http": {Port: 8080, Protocol: pbcatalog.Protocol_PROTOCOL_HTTP},
},
Identity: "api",
}).
Build())
suite.cache.Insert(suite.api123Workload.Resource)
suite.web1Workload = rtest.MustDecode[*pbcatalog.Workload](
suite.T(),
rtest.Resource(pbcatalog.WorkloadType, "web-1").
WithTenancy(tenancy).
WithMeta("zim", "gaz").
WithData(suite.T(), &pbcatalog.Workload{
Addresses: []*pbcatalog.WorkloadAddress{
{Host: "127.0.0.1"},
},
Ports: map[string]*pbcatalog.WorkloadPort{
"http": {Port: 8080, Protocol: pbcatalog.Protocol_PROTOCOL_HTTP},
},
Identity: "web",
}).
Build())
suite.cache.Insert(suite.web1Workload.Resource)
suite.web2Workload = rtest.MustDecode[*pbcatalog.Workload](
suite.T(),
rtest.Resource(pbcatalog.WorkloadType, "web-2").
WithTenancy(tenancy).
WithData(suite.T(), &pbcatalog.Workload{
Addresses: []*pbcatalog.WorkloadAddress{
{Host: "127.0.0.1"},
},
Ports: map[string]*pbcatalog.WorkloadPort{
"http": {Port: 8080, Protocol: pbcatalog.Protocol_PROTOCOL_HTTP},
},
Identity: "web",
}).
Build())
suite.cache.Insert(suite.web2Workload.Resource)
suite.apiEndpoints = rtest.MustDecode[*pbcatalog.ServiceEndpoints](
suite.T(),
rtest.Resource(pbcatalog.ServiceEndpointsType, "api").
WithTenancy(tenancy).
WithData(suite.T(), &pbcatalog.ServiceEndpoints{
Endpoints: []*pbcatalog.Endpoint{
{
TargetRef: rtest.Resource(pbcatalog.WorkloadType, "api-1").WithTenancy(tenancy).ID(),
Addresses: []*pbcatalog.WorkloadAddress{
{
Host: "127.0.0.1",
Ports: []string{"http"},
},
},
Ports: map[string]*pbcatalog.WorkloadPort{
"http": {Port: 8080, Protocol: pbcatalog.Protocol_PROTOCOL_HTTP},
},
HealthStatus: pbcatalog.Health_HEALTH_PASSING,
},
},
}).
Build())
}
func (suite *gatherWorkloadsDataSuite) cleanupResources() {
require.NoError(suite.T(), suite.cache.Delete(suite.api1Workload.Resource))
require.NoError(suite.T(), suite.cache.Delete(suite.api2Workload.Resource))
require.NoError(suite.T(), suite.cache.Delete(suite.api123Workload.Resource))
require.NoError(suite.T(), suite.cache.Delete(suite.web1Workload.Resource))
require.NoError(suite.T(), suite.cache.Delete(suite.web2Workload.Resource))
}
func (suite *gatherWorkloadsDataSuite) runTestCaseWithTenancies(testFunc func(*pbresource.Tenancy)) {
for _, tenancy := range suite.tenancies {
suite.Run(suite.appendTenancyInfo(tenancy), func() {
suite.setupResourcesWithTenancy(tenancy)
testFunc(tenancy)
suite.T().Cleanup(suite.cleanupResources)
})
}
}
func (suite *gatherWorkloadsDataSuite) appendTenancyInfo(tenancy *pbresource.Tenancy) string {
return fmt.Sprintf("%s_Namespace_%s_Partition", tenancy.Namespace, tenancy.Partition)
}
func requireDecodedWorkloadEquals(t testutil.TestingTB, expected, actual *resource.DecodedResource[*pbcatalog.Workload]) {
prototest.AssertDeepEqual(t, expected.Resource, actual.Resource)
require.Equal(t, expected.Data, actual.Data)
}

View File

@ -0,0 +1,72 @@
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: BUSL-1.1
package workloadselector
import (
"github.com/hashicorp/consul/internal/controller/cache/index"
"github.com/hashicorp/consul/internal/controller/cache/indexers"
"github.com/hashicorp/consul/internal/resource"
pbcatalog "github.com/hashicorp/consul/proto-public/pbcatalog/v2beta1"
"github.com/hashicorp/consul/proto-public/pbresource"
)
const (
IndexName = "selected-workloads"
)
func Index[T WorkloadSelecting](name string) *index.Index {
return indexers.DecodedMultiIndexer[T](
name,
index.SingleValueFromOneOrTwoArgs[resource.ReferenceOrID, index.IndexQueryOptions](fromArgs),
fromResource[T],
)
}
func fromArgs(r resource.ReferenceOrID, opts index.IndexQueryOptions) ([]byte, error) {
workloadRef := &pbresource.Reference{
Type: pbcatalog.WorkloadType,
Tenancy: r.GetTenancy(),
Name: r.GetName(),
}
if opts.Prefix {
return index.PrefixIndexFromRefOrID(workloadRef), nil
} else {
return index.IndexFromRefOrID(workloadRef), nil
}
}
func fromResource[T WorkloadSelecting](res *resource.DecodedResource[T]) (bool, [][]byte, error) {
sel := res.Data.GetWorkloads()
if sel == nil || (len(sel.Prefixes) == 0 && len(sel.Names) == 0) {
return false, nil, nil
}
var indexes [][]byte
for _, name := range sel.Names {
ref := &pbresource.Reference{
Type: pbcatalog.WorkloadType,
Tenancy: res.Id.Tenancy,
Name: name,
}
indexes = append(indexes, index.IndexFromRefOrID(ref))
}
for _, name := range sel.Prefixes {
ref := &pbresource.Reference{
Type: pbcatalog.WorkloadType,
Tenancy: res.Id.Tenancy,
Name: name,
}
b := index.IndexFromRefOrID(ref)
// need to remove the path separator to be compatible with prefix matching
indexes = append(indexes, b[:len(b)-1])
}
return true, indexes, nil
}

View File

@ -0,0 +1,135 @@
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: BUSL-1.1
package workloadselector
import (
"testing"
"github.com/hashicorp/consul/internal/controller/cache"
"github.com/hashicorp/consul/internal/controller/cache/index"
"github.com/hashicorp/consul/internal/resource"
"github.com/hashicorp/consul/internal/resource/resourcetest"
rtest "github.com/hashicorp/consul/internal/resource/resourcetest"
pbcatalog "github.com/hashicorp/consul/proto-public/pbcatalog/v2beta1"
"github.com/hashicorp/consul/proto-public/pbresource"
"github.com/hashicorp/consul/proto/private/prototest"
"github.com/stretchr/testify/require"
)
func TestServiceWorkloadIndexer(t *testing.T) {
c := cache.New()
i := Index[*pbcatalog.Service]("selected-workloads")
require.NoError(t, c.AddIndex(pbcatalog.ServiceType, i))
foo := rtest.Resource(pbcatalog.ServiceType, "foo").
WithData(t, &pbcatalog.Service{
Workloads: &pbcatalog.WorkloadSelector{
Names: []string{
"api-2",
},
Prefixes: []string{
"api-1",
},
},
}).
WithTenancy(&pbresource.Tenancy{
Partition: "default",
Namespace: "default",
PeerName: "local",
}).
Build()
require.NoError(t, c.Insert(foo))
bar := rtest.Resource(pbcatalog.ServiceType, "bar").
WithData(t, &pbcatalog.Service{
Workloads: &pbcatalog.WorkloadSelector{
Names: []string{
"api-3",
},
Prefixes: []string{
"api-2",
},
},
}).
WithTenancy(&pbresource.Tenancy{
Partition: "default",
Namespace: "default",
PeerName: "local",
}).
Build()
require.NoError(t, c.Insert(bar))
api123 := rtest.Resource(pbcatalog.WorkloadType, "api-123").
WithTenancy(&pbresource.Tenancy{
Partition: "default",
Namespace: "default",
PeerName: "local",
}).
Reference("")
api2 := rtest.Resource(pbcatalog.WorkloadType, "api-2").
WithTenancy(&pbresource.Tenancy{
Partition: "default",
Namespace: "default",
PeerName: "local",
}).
Reference("")
resources, err := c.Parents(pbcatalog.ServiceType, i.Name(), api123)
require.NoError(t, err)
require.Len(t, resources, 1)
prototest.AssertDeepEqual(t, foo, resources[0])
resources, err = c.Parents(pbcatalog.ServiceType, i.Name(), api2)
require.NoError(t, err)
require.Len(t, resources, 2)
prototest.AssertElementsMatch(t, []*pbresource.Resource{foo, bar}, resources)
refPrefix := &pbresource.Reference{
Type: pbcatalog.WorkloadType,
Tenancy: &pbresource.Tenancy{
Partition: "default",
Namespace: "default",
},
}
resources, err = c.List(pbcatalog.ServiceType, i.Name(), refPrefix, index.IndexQueryOptions{Prefix: true})
require.NoError(t, err)
// because foo and bar both have 2 index values they will appear in the output twice
require.Len(t, resources, 4)
prototest.AssertElementsMatch(t, []*pbresource.Resource{foo, bar, foo, bar}, resources)
}
func TestServiceWorkloadIndexer_FromResource_Errors(t *testing.T) {
t.Run("nil-selector", func(t *testing.T) {
res := resourcetest.MustDecode[*pbcatalog.Service](
t,
resourcetest.Resource(pbcatalog.ServiceType, "foo").
WithData(t, &pbcatalog.Service{}).
WithTenancy(resource.DefaultNamespacedTenancy()).
Build())
indexed, vals, err := fromResource(res)
require.False(t, indexed)
require.Nil(t, vals)
require.NoError(t, err)
})
t.Run("no-selections", func(t *testing.T) {
res := resourcetest.MustDecode[*pbcatalog.Service](
t,
resourcetest.Resource(pbcatalog.ServiceType, "foo").
WithData(t, &pbcatalog.Service{
Workloads: &pbcatalog.WorkloadSelector{},
}).
WithTenancy(resource.DefaultNamespacedTenancy()).
Build())
indexed, vals, err := fromResource(res)
require.False(t, indexed)
require.Nil(t, vals)
require.NoError(t, err)
})
}

View File

@ -0,0 +1,151 @@
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: BUSL-1.1
package workloadselector_test
import (
"context"
"testing"
"github.com/hashicorp/consul/internal/catalog/workloadselector"
"github.com/hashicorp/consul/internal/controller"
"github.com/hashicorp/consul/internal/controller/cache"
"github.com/hashicorp/consul/internal/resource"
"github.com/hashicorp/consul/internal/resource/resourcetest"
pbcatalog "github.com/hashicorp/consul/proto-public/pbcatalog/v2beta1"
"github.com/hashicorp/consul/proto/private/prototest"
"github.com/hashicorp/consul/sdk/testutil"
"github.com/stretchr/testify/require"
)
func TestWorkloadSelectorCacheIntegration(t *testing.T) {
c := cache.New()
i := workloadselector.Index[*pbcatalog.Service]("selected-workloads")
c.AddType(pbcatalog.WorkloadType)
c.AddIndex(pbcatalog.ServiceType, i)
rt := controller.Runtime{
Cache: c,
Logger: testutil.Logger(t),
}
svcFoo := resourcetest.Resource(pbcatalog.ServiceType, "foo").
WithData(t, &pbcatalog.Service{
Workloads: &pbcatalog.WorkloadSelector{
Names: []string{"foo"},
Prefixes: []string{"api-", "other-"},
},
}).
WithTenancy(resource.DefaultNamespacedTenancy()).
Build()
svcBar := resourcetest.Resource(pbcatalog.ServiceType, "bar").
WithData(t, &pbcatalog.Service{
Workloads: &pbcatalog.WorkloadSelector{
Names: []string{"bar"},
Prefixes: []string{"api-1", "something-else-"},
},
}).
WithTenancy(resource.DefaultNamespacedTenancy()).
Build()
workloadBar := resourcetest.Resource(pbcatalog.WorkloadType, "bar").
WithTenancy(resource.DefaultNamespacedTenancy()).
Build()
workloadAPIFoo := resourcetest.Resource(pbcatalog.WorkloadType, "api-foo").
WithTenancy(resource.DefaultNamespacedTenancy()).
Build()
workloadAPI12 := resourcetest.Resource(pbcatalog.WorkloadType, "api-1").
WithTenancy(resource.DefaultNamespacedTenancy()).
Build()
workloadFoo := resourcetest.Resource(pbcatalog.WorkloadType, "foo").
WithTenancy(resource.DefaultNamespacedTenancy()).
Build()
workloadSomethingElse12 := resourcetest.Resource(pbcatalog.WorkloadType, "something-else-12").
WithTenancy(resource.DefaultNamespacedTenancy()).
Build()
// prime the cache with all of our services and workloads
require.NoError(t, c.Insert(svcFoo))
require.NoError(t, c.Insert(svcBar))
require.NoError(t, c.Insert(workloadAPIFoo))
require.NoError(t, c.Insert(workloadAPI12))
require.NoError(t, c.Insert(workloadFoo))
require.NoError(t, c.Insert(workloadSomethingElse12))
// check that mapping a selecting resource to the list of currently selected workloads works as expected
reqs, err := workloadselector.MapSelectorToWorkloads[*pbcatalog.Service](context.Background(), rt, svcFoo)
require.NoError(t, err)
// in particular workloadSomethingElse12 should not show up here
expected := []controller.Request{
{ID: workloadFoo.Id},
{ID: workloadAPI12.Id},
{ID: workloadAPIFoo.Id},
}
prototest.AssertElementsMatch(t, expected, reqs)
reqs, err = workloadselector.MapSelectorToWorkloads[*pbcatalog.Service](context.Background(), rt, svcBar)
require.NoError(t, err)
// workloadFoo and workloadAPIFoo should not show up here as they don't meet the selection critiera
// workloadBar should not show up here because it hasn't been inserted into the cache yet.
expected = []controller.Request{
{ID: workloadSomethingElse12.Id},
{ID: workloadAPI12.Id},
}
prototest.AssertElementsMatch(t, expected, reqs)
// insert workloadBar into the cache so that future calls to MapSelectorToWorkloads for svcBar show
// the workload in the output
require.NoError(t, c.Insert(workloadBar))
// now validate that workloadBar shows up in the svcBar mapping
reqs, err = workloadselector.MapSelectorToWorkloads[*pbcatalog.Service](context.Background(), rt, svcBar)
require.NoError(t, err)
expected = []controller.Request{
{ID: workloadSomethingElse12.Id},
{ID: workloadAPI12.Id},
{ID: workloadBar.Id},
}
prototest.AssertElementsMatch(t, expected, reqs)
// create the mapper to verify that finding services that select workloads functions correctly
mapper := workloadselector.MapWorkloadsToSelectors(pbcatalog.ServiceType, i.Name())
// check that workloadAPIFoo only returns a request for serviceFoo
reqs, err = mapper(context.Background(), rt, workloadAPIFoo)
require.NoError(t, err)
expected = []controller.Request{
{ID: svcFoo.Id},
}
prototest.AssertElementsMatch(t, expected, reqs)
// check that workloadAPI12 returns both services
reqs, err = mapper(context.Background(), rt, workloadAPI12)
require.NoError(t, err)
expected = []controller.Request{
{ID: svcFoo.Id},
{ID: svcBar.Id},
}
prototest.AssertElementsMatch(t, expected, reqs)
// check that workloadSomethingElse12 returns only svcBar
reqs, err = mapper(context.Background(), rt, workloadSomethingElse12)
require.NoError(t, err)
expected = []controller.Request{
{ID: svcBar.Id},
}
prototest.AssertElementsMatch(t, expected, reqs)
// check that workloadFoo returns only svcFoo
reqs, err = mapper(context.Background(), rt, workloadFoo)
require.NoError(t, err)
expected = []controller.Request{
{ID: svcFoo.Id},
}
prototest.AssertElementsMatch(t, expected, reqs)
}

View File

@ -0,0 +1,45 @@
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: BUSL-1.1
package workloadselector
import (
"context"
"github.com/hashicorp/consul/internal/controller"
"github.com/hashicorp/consul/internal/controller/dependency"
"github.com/hashicorp/consul/internal/resource"
"github.com/hashicorp/consul/proto-public/pbresource"
)
// MapSelectorToWorkloads will use the "id" index on watched Workload type to find all current
// workloads selected by the resource.
func MapSelectorToWorkloads[T WorkloadSelecting](_ context.Context, rt controller.Runtime, r *pbresource.Resource) ([]controller.Request, error) {
res, err := resource.Decode[T](r)
if err != nil {
return nil, err
}
workloads, err := GetWorkloadsWithSelector[T](rt.Cache, res)
if err != nil {
return nil, err
}
reqs := make([]controller.Request, len(workloads))
for i, workload := range workloads {
reqs[i] = controller.Request{
ID: workload.Id,
}
}
return reqs, nil
}
// MapWorkloadsToSelectors returns a DependencyMapper that will use the specified index to map a workload
// to resources that select it.
//
// This mapper can only be used on watches for the Workload type and works in conjunction with the Index
// created by this package.
func MapWorkloadsToSelectors(indexType *pbresource.Type, indexName string) controller.DependencyMapper {
return dependency.CacheParentsMapper(indexType, indexName)
}

View File

@ -0,0 +1,180 @@
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: BUSL-1.1
package workloadselector
import (
"context"
"errors"
"testing"
"github.com/hashicorp/consul/internal/controller"
"github.com/hashicorp/consul/internal/controller/cache/cachemock"
"github.com/hashicorp/consul/internal/controller/cache/index"
"github.com/hashicorp/consul/internal/controller/cache/index/indexmock"
"github.com/hashicorp/consul/internal/resource"
"github.com/hashicorp/consul/internal/resource/resourcetest"
pbcatalog "github.com/hashicorp/consul/proto-public/pbcatalog/v2beta1"
"github.com/hashicorp/consul/proto-public/pbresource"
"github.com/hashicorp/consul/proto/private/prototest"
"github.com/hashicorp/go-hclog"
"github.com/stretchr/testify/require"
)
var injectedError = errors.New("injected error")
func TestMapSelectorToWorkloads(t *testing.T) {
cache := cachemock.NewReadOnlyCache(t)
rt := controller.Runtime{
Cache: cache,
}
mres := indexmock.NewResourceIterator(t)
svc := resourcetest.Resource(pbcatalog.ServiceType, "api").
WithData(t, &pbcatalog.Service{
Workloads: &pbcatalog.WorkloadSelector{
Prefixes: []string{"api-"},
Names: []string{"foo"},
},
}).
WithTenancy(resource.DefaultNamespacedTenancy()).
Build()
api1 := resourcetest.Resource(pbcatalog.WorkloadType, "api-1").
WithTenancy(resource.DefaultNamespacedTenancy()).
Build()
api2 := resourcetest.Resource(pbcatalog.WorkloadType, "api-2").
WithTenancy(resource.DefaultNamespacedTenancy()).
Build()
fooRes := resourcetest.Resource(pbcatalog.WorkloadType, "foo").
WithTenancy(resource.DefaultNamespacedTenancy()).
Build()
cache.EXPECT().
ListIterator(pbcatalog.WorkloadType, "id", &pbresource.ID{
Type: pbcatalog.WorkloadType,
Name: "api-",
Tenancy: resource.DefaultNamespacedTenancy(),
}, index.IndexQueryOptions{Prefix: true}).
Return(mres, nil).
Once()
cache.EXPECT().
Get(pbcatalog.WorkloadType, "id", &pbresource.ID{
Type: pbcatalog.WorkloadType,
Name: "foo",
Tenancy: resource.DefaultNamespacedTenancy(),
}).
Return(fooRes, nil).
Once()
mres.EXPECT().Next().Return(api1).Once()
mres.EXPECT().Next().Return(api2).Once()
mres.EXPECT().Next().Return(nil).Once()
expected := []controller.Request{
{ID: fooRes.Id},
{ID: api1.Id},
{ID: api2.Id},
}
reqs, err := MapSelectorToWorkloads[*pbcatalog.Service](context.Background(), rt, svc)
require.NoError(t, err)
prototest.AssertElementsMatch(t, expected, reqs)
}
func TestMapSelectorToWorkloads_DecodeError(t *testing.T) {
res := resourcetest.Resource(pbcatalog.ServiceType, "foo").
WithData(t, &pbcatalog.DNSPolicy{}).
Build()
reqs, err := MapSelectorToWorkloads[*pbcatalog.Service](context.Background(), controller.Runtime{}, res)
require.Nil(t, reqs)
require.Error(t, err)
require.ErrorAs(t, err, &resource.ErrDataParse{})
}
func TestMapSelectorToWorkloads_CacheError(t *testing.T) {
cache := cachemock.NewReadOnlyCache(t)
rt := controller.Runtime{
Cache: cache,
}
svc := resourcetest.Resource(pbcatalog.ServiceType, "api").
WithData(t, &pbcatalog.Service{
Workloads: &pbcatalog.WorkloadSelector{
Prefixes: []string{"api-"},
},
}).
WithTenancy(resource.DefaultNamespacedTenancy()).
Build()
cache.EXPECT().
ListIterator(pbcatalog.WorkloadType, "id", &pbresource.ID{
Type: pbcatalog.WorkloadType,
Name: "api-",
Tenancy: resource.DefaultNamespacedTenancy(),
}, index.IndexQueryOptions{Prefix: true}).
Return(nil, injectedError).
Once()
reqs, err := MapSelectorToWorkloads[*pbcatalog.Service](context.Background(), rt, svc)
require.ErrorIs(t, err, injectedError)
require.Nil(t, reqs)
}
func TestMapWorkloadsToSelectors(t *testing.T) {
cache := cachemock.NewReadOnlyCache(t)
rt := controller.Runtime{
Cache: cache,
Logger: hclog.NewNullLogger(),
}
dm := MapWorkloadsToSelectors(pbcatalog.ServiceType, "selected-workloads")
workload := resourcetest.Resource(pbcatalog.WorkloadType, "api-123").
WithTenancy(resource.DefaultNamespacedTenancy()).
Build()
svc1 := resourcetest.Resource(pbcatalog.ServiceType, "foo").
WithData(t, &pbcatalog.Service{
Workloads: &pbcatalog.WorkloadSelector{
Prefixes: []string{"api-"},
},
}).
WithTenancy(resource.DefaultNamespacedTenancy()).
Build()
svc2 := resourcetest.Resource(pbcatalog.ServiceType, "bar").
WithData(t, &pbcatalog.Service{
Workloads: &pbcatalog.WorkloadSelector{
Prefixes: []string{"api-"},
},
}).
WithTenancy(resource.DefaultNamespacedTenancy()).
Build()
mres := indexmock.NewResourceIterator(t)
cache.EXPECT().
ParentsIterator(pbcatalog.ServiceType, "selected-workloads", workload.Id).
Return(mres, nil).
Once()
mres.EXPECT().Next().Return(svc1).Once()
mres.EXPECT().Next().Return(svc2).Once()
mres.EXPECT().Next().Return(nil).Once()
reqs, err := dm(context.Background(), rt, workload)
require.NoError(t, err)
expected := []controller.Request{
{ID: svc1.Id},
{ID: svc2.Id},
}
prototest.AssertElementsMatch(t, expected, reqs)
}

View File

@ -1,7 +1,7 @@
// Copyright (c) HashiCorp, Inc. // Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: BUSL-1.1 // SPDX-License-Identifier: BUSL-1.1
package types package workloadselector
import ( import (
"google.golang.org/protobuf/proto" "google.golang.org/protobuf/proto"

107
internal/controller/cache/decoded.go vendored Normal file
View File

@ -0,0 +1,107 @@
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: BUSL-1.1
package cache
import (
"google.golang.org/protobuf/proto"
"github.com/hashicorp/consul/internal/resource"
"github.com/hashicorp/consul/proto-public/pbresource"
)
// Get retrieves a single resource from the specified index that matches the provided args.
// If more than one match is found the first is returned.
func GetDecoded[T proto.Message](c ReadOnlyCache, it *pbresource.Type, indexName string, args ...any) (*resource.DecodedResource[T], error) {
res, err := c.Get(it, indexName, args...)
if err != nil {
return nil, err
}
if res == nil {
return nil, nil
}
return resource.Decode[T](res)
}
// List retrieves all the resources from the specified index matching the provided args.
func ListDecoded[T proto.Message](c ReadOnlyCache, it *pbresource.Type, indexName string, args ...any) ([]*resource.DecodedResource[T], error) {
resources, err := c.List(it, indexName, args...)
if err != nil {
return nil, err
}
return resource.DecodeList[T](resources)
}
// ListIterator retrieves an iterator over all resources from the specified index matching the provided args.
func ListIteratorDecoded[T proto.Message](c ReadOnlyCache, it *pbresource.Type, indexName string, args ...any) (DecodedResourceIterator[T], error) {
iter, err := c.ListIterator(it, indexName, args...)
if err != nil {
return nil, err
}
if iter == nil {
return nil, nil
}
return decodedResourceIterator[T]{iter}, nil
}
// Parents retrieves all resources whos index value is a parent (or prefix) of the value calculated
// from the provided args.
func ParentsDecoded[T proto.Message](c ReadOnlyCache, it *pbresource.Type, indexName string, args ...any) ([]*resource.DecodedResource[T], error) {
resources, err := c.Parents(it, indexName, args...)
if err != nil {
return nil, err
}
return resource.DecodeList[T](resources)
}
// ParentsIterator retrieves an iterator over all resources whos index value is a parent (or prefix)
// of the value calculated from the provided args.
func ParentsIteratorDecoded[T proto.Message](c ReadOnlyCache, it *pbresource.Type, indexName string, args ...any) (DecodedResourceIterator[T], error) {
iter, err := c.ParentsIterator(it, indexName, args...)
if err != nil {
return nil, err
}
if iter == nil {
return nil, nil
}
return decodedResourceIterator[T]{iter}, nil
}
// Query will execute a named query against the cache and return an interator over its results
func QueryDecoded[T proto.Message](c ReadOnlyCache, name string, args ...any) (DecodedResourceIterator[T], error) {
iter, err := c.Query(name, args...)
if err != nil {
return nil, err
}
if iter == nil {
return nil, nil
}
return decodedResourceIterator[T]{iter}, nil
}
type DecodedResourceIterator[T proto.Message] interface {
Next() (*resource.DecodedResource[T], error)
}
type decodedResourceIterator[T proto.Message] struct {
ResourceIterator
}
func (iter decodedResourceIterator[T]) Next() (*resource.DecodedResource[T], error) {
res := iter.ResourceIterator.Next()
if res == nil {
return nil, nil
}
return resource.Decode[T](res)
}

View File

@ -0,0 +1,360 @@
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: BUSL-1.1
package cache_test
import (
"testing"
"github.com/stretchr/testify/require"
"github.com/stretchr/testify/suite"
"github.com/hashicorp/consul/internal/controller/cache"
"github.com/hashicorp/consul/internal/controller/cache/cachemock"
"github.com/hashicorp/consul/internal/resource"
"github.com/hashicorp/consul/internal/resource/demo"
"github.com/hashicorp/consul/proto-public/pbresource"
pbdemo "github.com/hashicorp/consul/proto/private/pbdemo/v2"
"github.com/hashicorp/consul/proto/private/prototest"
)
type decodedSuite struct {
suite.Suite
rc *cachemock.ReadOnlyCache
iter *cachemock.ResourceIterator
artistGood *resource.DecodedResource[*pbdemo.Artist]
artistGood2 *resource.DecodedResource[*pbdemo.Artist]
artistBad *pbresource.Resource
}
func (suite *decodedSuite) SetupTest() {
suite.rc = cachemock.NewReadOnlyCache(suite.T())
suite.iter = cachemock.NewResourceIterator(suite.T())
artist, err := demo.GenerateV2Artist()
require.NoError(suite.T(), err)
suite.artistGood, err = resource.Decode[*pbdemo.Artist](artist)
require.NoError(suite.T(), err)
artist2, err := demo.GenerateV2Artist()
require.NoError(suite.T(), err)
suite.artistGood2, err = resource.Decode[*pbdemo.Artist](artist2)
require.NoError(suite.T(), err)
suite.artistBad, err = demo.GenerateV2Album(artist.Id)
require.NoError(suite.T(), err)
}
func (suite *decodedSuite) TestGetDecoded_Ok() {
suite.rc.EXPECT().Get(pbdemo.ArtistType, "id", suite.artistGood.Id).Return(suite.artistGood.Resource, nil)
dec, err := cache.GetDecoded[*pbdemo.Artist](suite.rc, pbdemo.ArtistType, "id", suite.artistGood.Id)
require.NoError(suite.T(), err)
prototest.AssertDeepEqual(suite.T(), suite.artistGood.Resource, dec.Resource)
prototest.AssertDeepEqual(suite.T(), suite.artistGood.Data, dec.Data)
}
func (suite *decodedSuite) TestGetDecoded_DecodeError() {
suite.rc.EXPECT().Get(pbdemo.ArtistType, "id", suite.artistGood.Id).Return(suite.artistBad, nil)
dec, err := cache.GetDecoded[*pbdemo.Artist](suite.rc, pbdemo.ArtistType, "id", suite.artistGood.Id)
require.Error(suite.T(), err)
require.Nil(suite.T(), dec)
}
func (suite *decodedSuite) TestGetDecoded_CacheError() {
suite.rc.EXPECT().Get(pbdemo.ArtistType, "id", suite.artistGood.Id).Return(nil, injectedError)
dec, err := cache.GetDecoded[*pbdemo.Artist](suite.rc, pbdemo.ArtistType, "id", suite.artistGood.Id)
require.ErrorIs(suite.T(), err, injectedError)
require.Nil(suite.T(), dec)
}
func (suite *decodedSuite) TestGetDecoded_Nil() {
suite.rc.EXPECT().Get(pbdemo.ArtistType, "id", suite.artistGood.Id).Return(nil, nil)
dec, err := cache.GetDecoded[*pbdemo.Artist](suite.rc, pbdemo.ArtistType, "id", suite.artistGood.Id)
require.NoError(suite.T(), err)
require.Nil(suite.T(), dec)
}
func (suite *decodedSuite) TestListDecoded_Ok() {
suite.rc.EXPECT().List(pbdemo.ArtistType, "id", suite.artistGood.Id).
Return([]*pbresource.Resource{suite.artistGood.Resource, suite.artistGood2.Resource}, nil)
dec, err := cache.ListDecoded[*pbdemo.Artist](suite.rc, pbdemo.ArtistType, "id", suite.artistGood.Id)
require.NoError(suite.T(), err)
require.Len(suite.T(), dec, 2)
prototest.AssertDeepEqual(suite.T(), suite.artistGood.Resource, dec[0].Resource)
prototest.AssertDeepEqual(suite.T(), suite.artistGood.Data, dec[0].Data)
prototest.AssertDeepEqual(suite.T(), suite.artistGood2.Resource, dec[1].Resource)
prototest.AssertDeepEqual(suite.T(), suite.artistGood2.Data, dec[1].Data)
}
func (suite *decodedSuite) TestListDecoded_DecodeError() {
suite.rc.EXPECT().List(pbdemo.ArtistType, "id", suite.artistGood.Id).
Return([]*pbresource.Resource{suite.artistGood.Resource, suite.artistBad}, nil)
dec, err := cache.ListDecoded[*pbdemo.Artist](suite.rc, pbdemo.ArtistType, "id", suite.artistGood.Id)
require.Error(suite.T(), err)
require.Nil(suite.T(), dec)
}
func (suite *decodedSuite) TestListDecoded_CacheError() {
suite.rc.EXPECT().List(pbdemo.ArtistType, "id", suite.artistGood.Id).Return(nil, injectedError)
dec, err := cache.ListDecoded[*pbdemo.Artist](suite.rc, pbdemo.ArtistType, "id", suite.artistGood.Id)
require.ErrorIs(suite.T(), err, injectedError)
require.Nil(suite.T(), dec)
}
func (suite *decodedSuite) TestListDecoded_Nil() {
suite.rc.EXPECT().List(pbdemo.ArtistType, "id", suite.artistGood.Id).Return(nil, nil)
dec, err := cache.ListDecoded[*pbdemo.Artist](suite.rc, pbdemo.ArtistType, "id", suite.artistGood.Id)
require.NoError(suite.T(), err)
require.Nil(suite.T(), dec)
}
func (suite *decodedSuite) TestListIteratorDecoded_Ok() {
suite.iter.EXPECT().Next().Return(suite.artistGood.Resource).Once()
suite.iter.EXPECT().Next().Return(suite.artistGood2.Resource).Once()
suite.iter.EXPECT().Next().Return(nil).Times(0)
suite.rc.EXPECT().ListIterator(pbdemo.ArtistType, "id", suite.artistGood.Id).
Return(suite.iter, nil)
iter, err := cache.ListIteratorDecoded[*pbdemo.Artist](suite.rc, pbdemo.ArtistType, "id", suite.artistGood.Id)
require.NoError(suite.T(), err)
require.NotNil(suite.T(), iter)
dec, err := iter.Next()
require.NoError(suite.T(), err)
prototest.AssertDeepEqual(suite.T(), suite.artistGood.Resource, dec.Resource)
prototest.AssertDeepEqual(suite.T(), suite.artistGood.Data, dec.Data)
dec, err = iter.Next()
require.NoError(suite.T(), err)
prototest.AssertDeepEqual(suite.T(), suite.artistGood2.Resource, dec.Resource)
prototest.AssertDeepEqual(suite.T(), suite.artistGood2.Data, dec.Data)
dec, err = iter.Next()
require.NoError(suite.T(), err)
require.Nil(suite.T(), dec)
}
func (suite *decodedSuite) TestListIteratorDecoded_DecodeError() {
suite.iter.EXPECT().Next().Return(suite.artistGood.Resource).Once()
suite.iter.EXPECT().Next().Return(suite.artistBad).Once()
suite.iter.EXPECT().Next().Return(nil).Times(0)
suite.rc.EXPECT().ListIterator(pbdemo.ArtistType, "id", suite.artistGood.Id).
Return(suite.iter, nil)
iter, err := cache.ListIteratorDecoded[*pbdemo.Artist](suite.rc, pbdemo.ArtistType, "id", suite.artistGood.Id)
require.NoError(suite.T(), err)
require.NotNil(suite.T(), iter)
dec, err := iter.Next()
require.NoError(suite.T(), err)
prototest.AssertDeepEqual(suite.T(), suite.artistGood.Resource, dec.Resource)
prototest.AssertDeepEqual(suite.T(), suite.artistGood.Data, dec.Data)
dec, err = iter.Next()
require.Error(suite.T(), err)
require.Nil(suite.T(), dec)
dec, err = iter.Next()
require.NoError(suite.T(), err)
require.Nil(suite.T(), dec)
}
func (suite *decodedSuite) TestListIteratorDecoded_CacheError() {
suite.rc.EXPECT().ListIterator(pbdemo.ArtistType, "id", suite.artistGood.Id).Return(nil, injectedError)
iter, err := cache.ListIteratorDecoded[*pbdemo.Artist](suite.rc, pbdemo.ArtistType, "id", suite.artistGood.Id)
require.ErrorIs(suite.T(), err, injectedError)
require.Nil(suite.T(), iter)
}
func (suite *decodedSuite) TestListIteratorDecoded_Nil() {
suite.rc.EXPECT().ListIterator(pbdemo.ArtistType, "id", suite.artistGood.Id).Return(nil, nil)
dec, err := cache.ListIteratorDecoded[*pbdemo.Artist](suite.rc, pbdemo.ArtistType, "id", suite.artistGood.Id)
require.NoError(suite.T(), err)
require.Nil(suite.T(), dec)
}
func (suite *decodedSuite) TestParentsDecoded_Ok() {
suite.rc.EXPECT().Parents(pbdemo.ArtistType, "id", suite.artistGood.Id).
Return([]*pbresource.Resource{suite.artistGood.Resource, suite.artistGood2.Resource}, nil)
dec, err := cache.ParentsDecoded[*pbdemo.Artist](suite.rc, pbdemo.ArtistType, "id", suite.artistGood.Id)
require.NoError(suite.T(), err)
require.Len(suite.T(), dec, 2)
prototest.AssertDeepEqual(suite.T(), suite.artistGood.Resource, dec[0].Resource)
prototest.AssertDeepEqual(suite.T(), suite.artistGood.Data, dec[0].Data)
prototest.AssertDeepEqual(suite.T(), suite.artistGood2.Resource, dec[1].Resource)
prototest.AssertDeepEqual(suite.T(), suite.artistGood2.Data, dec[1].Data)
}
func (suite *decodedSuite) TestParentsDecoded_DecodeError() {
suite.rc.EXPECT().Parents(pbdemo.ArtistType, "id", suite.artistGood.Id).
Return([]*pbresource.Resource{suite.artistGood.Resource, suite.artistBad}, nil)
dec, err := cache.ParentsDecoded[*pbdemo.Artist](suite.rc, pbdemo.ArtistType, "id", suite.artistGood.Id)
require.Error(suite.T(), err)
require.Nil(suite.T(), dec)
}
func (suite *decodedSuite) TestParentsDecoded_CacheError() {
suite.rc.EXPECT().Parents(pbdemo.ArtistType, "id", suite.artistGood.Id).Return(nil, injectedError)
dec, err := cache.ParentsDecoded[*pbdemo.Artist](suite.rc, pbdemo.ArtistType, "id", suite.artistGood.Id)
require.ErrorIs(suite.T(), err, injectedError)
require.Nil(suite.T(), dec)
}
func (suite *decodedSuite) TestParentsDecoded_Nil() {
suite.rc.EXPECT().Parents(pbdemo.ArtistType, "id", suite.artistGood.Id).Return(nil, nil)
dec, err := cache.ParentsDecoded[*pbdemo.Artist](suite.rc, pbdemo.ArtistType, "id", suite.artistGood.Id)
require.NoError(suite.T(), err)
require.Nil(suite.T(), dec)
}
func (suite *decodedSuite) TestParentsIteratorDecoded_Ok() {
suite.iter.EXPECT().Next().Return(suite.artistGood.Resource).Once()
suite.iter.EXPECT().Next().Return(suite.artistGood2.Resource).Once()
suite.iter.EXPECT().Next().Return(nil).Times(0)
suite.rc.EXPECT().ParentsIterator(pbdemo.ArtistType, "id", suite.artistGood.Id).
Return(suite.iter, nil)
iter, err := cache.ParentsIteratorDecoded[*pbdemo.Artist](suite.rc, pbdemo.ArtistType, "id", suite.artistGood.Id)
require.NoError(suite.T(), err)
require.NotNil(suite.T(), iter)
dec, err := iter.Next()
require.NoError(suite.T(), err)
prototest.AssertDeepEqual(suite.T(), suite.artistGood.Resource, dec.Resource)
prototest.AssertDeepEqual(suite.T(), suite.artistGood.Data, dec.Data)
dec, err = iter.Next()
require.NoError(suite.T(), err)
prototest.AssertDeepEqual(suite.T(), suite.artistGood2.Resource, dec.Resource)
prototest.AssertDeepEqual(suite.T(), suite.artistGood2.Data, dec.Data)
dec, err = iter.Next()
require.NoError(suite.T(), err)
require.Nil(suite.T(), dec)
}
func (suite *decodedSuite) TestParentsIteratorDecoded_DecodeError() {
suite.iter.EXPECT().Next().Return(suite.artistGood.Resource).Once()
suite.iter.EXPECT().Next().Return(suite.artistBad).Once()
suite.iter.EXPECT().Next().Return(nil).Times(0)
suite.rc.EXPECT().ParentsIterator(pbdemo.ArtistType, "id", suite.artistGood.Id).
Return(suite.iter, nil)
iter, err := cache.ParentsIteratorDecoded[*pbdemo.Artist](suite.rc, pbdemo.ArtistType, "id", suite.artistGood.Id)
require.NoError(suite.T(), err)
require.NotNil(suite.T(), iter)
dec, err := iter.Next()
require.NoError(suite.T(), err)
prototest.AssertDeepEqual(suite.T(), suite.artistGood.Resource, dec.Resource)
prototest.AssertDeepEqual(suite.T(), suite.artistGood.Data, dec.Data)
dec, err = iter.Next()
require.Error(suite.T(), err)
require.Nil(suite.T(), dec)
dec, err = iter.Next()
require.NoError(suite.T(), err)
require.Nil(suite.T(), dec)
}
func (suite *decodedSuite) TestParentsIteratorDecoded_CacheError() {
suite.rc.EXPECT().ParentsIterator(pbdemo.ArtistType, "id", suite.artistGood.Id).Return(nil, injectedError)
iter, err := cache.ParentsIteratorDecoded[*pbdemo.Artist](suite.rc, pbdemo.ArtistType, "id", suite.artistGood.Id)
require.ErrorIs(suite.T(), err, injectedError)
require.Nil(suite.T(), iter)
}
func (suite *decodedSuite) TestParentsIteratorDecoded_Nil() {
suite.rc.EXPECT().ParentsIterator(pbdemo.ArtistType, "id", suite.artistGood.Id).Return(nil, nil)
dec, err := cache.ParentsIteratorDecoded[*pbdemo.Artist](suite.rc, pbdemo.ArtistType, "id", suite.artistGood.Id)
require.NoError(suite.T(), err)
require.Nil(suite.T(), dec)
}
func (suite *decodedSuite) TestQueryDecoded_Ok() {
suite.iter.EXPECT().Next().Return(suite.artistGood.Resource).Once()
suite.iter.EXPECT().Next().Return(suite.artistGood2.Resource).Once()
suite.iter.EXPECT().Next().Return(nil).Times(0)
suite.rc.EXPECT().Query("query", "blah").
Return(suite.iter, nil)
iter, err := cache.QueryDecoded[*pbdemo.Artist](suite.rc, "query", "blah")
require.NoError(suite.T(), err)
require.NotNil(suite.T(), iter)
dec, err := iter.Next()
require.NoError(suite.T(), err)
prototest.AssertDeepEqual(suite.T(), suite.artistGood.Resource, dec.Resource)
prototest.AssertDeepEqual(suite.T(), suite.artistGood.Data, dec.Data)
dec, err = iter.Next()
require.NoError(suite.T(), err)
prototest.AssertDeepEqual(suite.T(), suite.artistGood2.Resource, dec.Resource)
prototest.AssertDeepEqual(suite.T(), suite.artistGood2.Data, dec.Data)
dec, err = iter.Next()
require.NoError(suite.T(), err)
require.Nil(suite.T(), dec)
}
func (suite *decodedSuite) TestQueryDecoded_DecodeError() {
suite.iter.EXPECT().Next().Return(suite.artistGood.Resource).Once()
suite.iter.EXPECT().Next().Return(suite.artistBad).Once()
suite.iter.EXPECT().Next().Return(nil).Times(0)
suite.rc.EXPECT().Query("query", "blah").
Return(suite.iter, nil)
iter, err := cache.QueryDecoded[*pbdemo.Artist](suite.rc, "query", "blah")
require.NoError(suite.T(), err)
require.NotNil(suite.T(), iter)
dec, err := iter.Next()
require.NoError(suite.T(), err)
prototest.AssertDeepEqual(suite.T(), suite.artistGood.Resource, dec.Resource)
prototest.AssertDeepEqual(suite.T(), suite.artistGood.Data, dec.Data)
dec, err = iter.Next()
require.Error(suite.T(), err)
require.Nil(suite.T(), dec)
dec, err = iter.Next()
require.NoError(suite.T(), err)
require.Nil(suite.T(), dec)
}
func (suite *decodedSuite) TestQueryDecoded_CacheError() {
suite.rc.EXPECT().Query("query", "blah").Return(nil, injectedError)
dec, err := cache.QueryDecoded[*pbdemo.Artist](suite.rc, "query", "blah")
require.ErrorIs(suite.T(), err, injectedError)
require.Nil(suite.T(), dec)
}
func (suite *decodedSuite) TestQueryDecoded_Nil() {
suite.rc.EXPECT().Query("query", "blah").Return(nil, nil)
dec, err := cache.QueryDecoded[*pbdemo.Artist](suite.rc, "query", "blah")
require.NoError(suite.T(), err)
require.Nil(suite.T(), dec)
}
func TestDecodedCache(t *testing.T) {
suite.Run(t, new(decodedSuite))
}

View File

@ -11,6 +11,10 @@ import (
"github.com/hashicorp/consul/proto-public/pbresource" "github.com/hashicorp/consul/proto-public/pbresource"
) )
type IndexQueryOptions struct {
Prefix bool
}
func IndexFromID(id *pbresource.ID, includeUid bool) []byte { func IndexFromID(id *pbresource.ID, includeUid bool) []byte {
var b Builder var b Builder
b.Raw(IndexFromType(id.Type)) b.Raw(IndexFromType(id.Type))
@ -87,6 +91,14 @@ var PrefixReferenceOrIDFromArgs = SingleValueFromArgs[resource.ReferenceOrID](fu
return PrefixIndexFromRefOrID(r), nil return PrefixIndexFromRefOrID(r), nil
}) })
var MaybePrefixReferenceOrIDFromArgs = SingleValueFromOneOrTwoArgs[resource.ReferenceOrID, IndexQueryOptions](func(r resource.ReferenceOrID, opts IndexQueryOptions) ([]byte, error) {
if opts.Prefix {
return PrefixIndexFromRefOrID(r), nil
} else {
return IndexFromRefOrID(r), nil
}
})
func SingleValueFromArgs[T any](indexer func(value T) ([]byte, error)) func(args ...any) ([]byte, error) { func SingleValueFromArgs[T any](indexer func(value T) ([]byte, error)) func(args ...any) ([]byte, error) {
return func(args ...any) ([]byte, error) { return func(args ...any) ([]byte, error) {
var zero T var zero T

View File

@ -38,7 +38,7 @@ type idOrRefIndexer struct {
// FromArgs constructs a radix tree key from an ID for lookup. // FromArgs constructs a radix tree key from an ID for lookup.
func (i idOrRefIndexer) FromArgs(args ...any) ([]byte, error) { func (i idOrRefIndexer) FromArgs(args ...any) ([]byte, error) {
return index.ReferenceOrIDFromArgs(args...) return index.MaybePrefixReferenceOrIDFromArgs(args...)
} }
// FromObject constructs a radix tree key from a Resource at write-time, or an // FromObject constructs a radix tree key from a Resource at write-time, or an

View File

@ -6,7 +6,7 @@ package workloadselectionmapper
import ( import (
"context" "context"
"github.com/hashicorp/consul/internal/catalog" "github.com/hashicorp/consul/internal/catalog/workloadselector"
"github.com/hashicorp/consul/internal/controller" "github.com/hashicorp/consul/internal/controller"
"github.com/hashicorp/consul/internal/mesh/internal/mappers/common" "github.com/hashicorp/consul/internal/mesh/internal/mappers/common"
"github.com/hashicorp/consul/internal/resource" "github.com/hashicorp/consul/internal/resource"
@ -15,12 +15,12 @@ import (
"github.com/hashicorp/consul/proto-public/pbresource" "github.com/hashicorp/consul/proto-public/pbresource"
) )
type Mapper[T catalog.WorkloadSelecting] struct { type Mapper[T workloadselector.WorkloadSelecting] struct {
workloadSelectionTracker *selectiontracker.WorkloadSelectionTracker workloadSelectionTracker *selectiontracker.WorkloadSelectionTracker
computedType *pbresource.Type computedType *pbresource.Type
} }
func New[T catalog.WorkloadSelecting](computedType *pbresource.Type) *Mapper[T] { func New[T workloadselector.WorkloadSelecting](computedType *pbresource.Type) *Mapper[T] {
if computedType == nil { if computedType == nil {
panic("computed type is required") panic("computed type is required")
} }

View File

@ -10,6 +10,7 @@ import (
"google.golang.org/protobuf/proto" "google.golang.org/protobuf/proto"
"github.com/hashicorp/consul/internal/catalog" "github.com/hashicorp/consul/internal/catalog"
"github.com/hashicorp/consul/internal/catalog/workloadselector"
"github.com/hashicorp/consul/internal/resource" "github.com/hashicorp/consul/internal/resource"
pbmesh "github.com/hashicorp/consul/proto-public/pbmesh/v2beta1" pbmesh "github.com/hashicorp/consul/proto-public/pbmesh/v2beta1"
"github.com/hashicorp/consul/proto-public/pbresource" "github.com/hashicorp/consul/proto-public/pbresource"
@ -22,7 +23,7 @@ func RegisterDestinations(r resource.Registry) {
Scope: resource.ScopeNamespace, Scope: resource.ScopeNamespace,
Mutate: MutateDestinations, Mutate: MutateDestinations,
Validate: ValidateDestinations, Validate: ValidateDestinations,
ACLs: catalog.ACLHooksForWorkloadSelectingType[*pbmesh.Destinations](), ACLs: workloadselector.ACLHooks[*pbmesh.Destinations](),
}) })
} }

View File

@ -7,6 +7,7 @@ import (
"github.com/hashicorp/go-multierror" "github.com/hashicorp/go-multierror"
"github.com/hashicorp/consul/internal/catalog" "github.com/hashicorp/consul/internal/catalog"
"github.com/hashicorp/consul/internal/catalog/workloadselector"
"github.com/hashicorp/consul/internal/resource" "github.com/hashicorp/consul/internal/resource"
pbmesh "github.com/hashicorp/consul/proto-public/pbmesh/v2beta1" pbmesh "github.com/hashicorp/consul/proto-public/pbmesh/v2beta1"
@ -18,7 +19,7 @@ func RegisterDestinationsConfiguration(r resource.Registry) {
Proto: &pbmesh.DestinationsConfiguration{}, Proto: &pbmesh.DestinationsConfiguration{},
Scope: resource.ScopeNamespace, Scope: resource.ScopeNamespace,
Validate: ValidateDestinationsConfiguration, Validate: ValidateDestinationsConfiguration,
ACLs: catalog.ACLHooksForWorkloadSelectingType[*pbmesh.DestinationsConfiguration](), ACLs: workloadselector.ACLHooks[*pbmesh.DestinationsConfiguration](),
}) })
} }

View File

@ -7,6 +7,7 @@ import (
"math" "math"
"github.com/hashicorp/consul/internal/catalog" "github.com/hashicorp/consul/internal/catalog"
"github.com/hashicorp/consul/internal/catalog/workloadselector"
"github.com/hashicorp/go-multierror" "github.com/hashicorp/go-multierror"
@ -22,7 +23,7 @@ func RegisterProxyConfiguration(r resource.Registry) {
Scope: resource.ScopeNamespace, Scope: resource.ScopeNamespace,
Mutate: MutateProxyConfiguration, Mutate: MutateProxyConfiguration,
Validate: ValidateProxyConfiguration, Validate: ValidateProxyConfiguration,
ACLs: catalog.ACLHooksForWorkloadSelectingType[*pbmesh.ProxyConfiguration](), ACLs: workloadselector.ACLHooks[*pbmesh.ProxyConfiguration](),
}) })
} }

View File

@ -58,6 +58,22 @@ func Decode[T proto.Message](res *pbresource.Resource) (*DecodedResource[T], err
}, nil }, nil
} }
// DecodeList will generically decode the provided resource list into a list of 2-field
// structures that holds onto the original Resource and the decoded contents.
//
// Returns an ErrDataParse on unmarshalling errors.
func DecodeList[T proto.Message](resources []*pbresource.Resource) ([]*DecodedResource[T], error) {
var decoded []*DecodedResource[T]
for _, res := range resources {
d, err := Decode[T](res)
if err != nil {
return nil, err
}
decoded = append(decoded, d)
}
return decoded, nil
}
// GetDecodedResource will generically read the requested resource using the // GetDecodedResource will generically read the requested resource using the
// client and either return nil on a NotFound or decode the response value. // client and either return nil on a NotFound or decode the response value.
func GetDecodedResource[T proto.Message](ctx context.Context, client pbresource.ResourceServiceClient, id *pbresource.ID) (*DecodedResource[T], error) { func GetDecodedResource[T proto.Message](ctx context.Context, client pbresource.ResourceServiceClient, id *pbresource.ID) (*DecodedResource[T], error) {

View File

@ -112,3 +112,50 @@ func TestDecode(t *testing.T) {
require.Error(t, err) require.Error(t, err)
}) })
} }
func TestDecodeList(t *testing.T) {
t.Run("good", func(t *testing.T) {
artist1, err := demo.GenerateV2Artist()
require.NoError(t, err)
artist2, err := demo.GenerateV2Artist()
require.NoError(t, err)
dec1, err := resource.Decode[*pbdemo.Artist](artist1)
require.NoError(t, err)
dec2, err := resource.Decode[*pbdemo.Artist](artist2)
require.NoError(t, err)
resources := []*pbresource.Resource{artist1, artist2}
decList, err := resource.DecodeList[*pbdemo.Artist](resources)
require.NoError(t, err)
require.Len(t, decList, 2)
prototest.AssertDeepEqual(t, dec1.Resource, decList[0].Resource)
prototest.AssertDeepEqual(t, dec1.Data, decList[0].Data)
prototest.AssertDeepEqual(t, dec2.Resource, decList[1].Resource)
prototest.AssertDeepEqual(t, dec2.Data, decList[1].Data)
})
t.Run("bad", func(t *testing.T) {
artist1, err := demo.GenerateV2Artist()
require.NoError(t, err)
foo := &pbresource.Resource{
Id: &pbresource.ID{
Type: demo.TypeV2Artist,
Tenancy: resource.DefaultNamespacedTenancy(),
Name: "babypants",
},
Data: &anypb.Any{
TypeUrl: "garbage",
Value: []byte("more garbage"),
},
Metadata: map[string]string{
"generated_at": time.Now().Format(time.RFC3339),
},
}
_, err = resource.DecodeList[*pbdemo.Artist]([]*pbresource.Resource{artist1, foo})
require.Error(t, err)
})
}

View File

@ -11,6 +11,10 @@ import (
"github.com/hashicorp/consul/proto-public/pbresource" "github.com/hashicorp/consul/proto-public/pbresource"
) )
type MetadataFilterableResources interface {
GetMetadata() map[string]string
}
// FilterResourcesByMetadata will use the provided go-bexpr based filter to // FilterResourcesByMetadata will use the provided go-bexpr based filter to
// retain matching items from the provided slice. // retain matching items from the provided slice.
// //
@ -18,7 +22,7 @@ import (
// by "metadata." // by "metadata."
// //
// If no filter is provided, then this does nothing and returns the input. // If no filter is provided, then this does nothing and returns the input.
func FilterResourcesByMetadata(resources []*pbresource.Resource, filter string) ([]*pbresource.Resource, error) { func FilterResourcesByMetadata[T MetadataFilterableResources](resources []T, filter string) ([]T, error) {
if filter == "" || len(resources) == 0 { if filter == "" || len(resources) == 0 {
return resources, nil return resources, nil
} }
@ -28,10 +32,10 @@ func FilterResourcesByMetadata(resources []*pbresource.Resource, filter string)
return nil, err return nil, err
} }
filtered := make([]*pbresource.Resource, 0, len(resources)) filtered := make([]T, 0, len(resources))
for _, res := range resources { for _, res := range resources {
vars := &metadataFilterFieldDetails{ vars := &metadataFilterFieldDetails{
Meta: res.Metadata, Meta: res.GetMetadata(),
} }
match, err := eval.Evaluate(vars) match, err := eval.Evaluate(vars)
if err != nil { if err != nil {

View File

@ -16,3 +16,9 @@ func MustDecode[Tp proto.Message](t T, res *pbresource.Resource) *resource.Decod
require.NoError(t, err) require.NoError(t, err)
return dec return dec
} }
func MustDecodeList[Tp proto.Message](t T, resources []*pbresource.Resource) []*resource.DecodedResource[Tp] {
dec, err := resource.DecodeList[Tp](resources)
require.NoError(t, err)
return dec
}