diff --git a/acl/MockAuthorizer.go b/acl/MockAuthorizer.go index 9941f81e3f..e3a97ceec9 100644 --- a/acl/MockAuthorizer.go +++ b/acl/MockAuthorizer.go @@ -225,7 +225,7 @@ func (m *MockAuthorizer) ServiceReadAll(ctx *AuthorizerContext) EnforcementDecis } func (m *MockAuthorizer) ServiceReadPrefix(prefix string, ctx *AuthorizerContext) EnforcementDecision { - ret := m.Called(ctx) + ret := m.Called(prefix, ctx) return ret.Get(0).(EnforcementDecision) } diff --git a/internal/catalog/catalogtest/helpers/acl_hooks_test_helpers.go b/internal/catalog/catalogtest/helpers/acl_hooks_test_helpers.go index 097647ed08..9679ef0436 100644 --- a/internal/catalog/catalogtest/helpers/acl_hooks_test_helpers.go +++ b/internal/catalog/catalogtest/helpers/acl_hooks_test_helpers.go @@ -6,14 +6,14 @@ package helpers import ( "testing" - "github.com/hashicorp/consul/internal/catalog" "github.com/hashicorp/consul/internal/catalog/internal/testhelpers" + "github.com/hashicorp/consul/internal/catalog/workloadselector" "github.com/hashicorp/consul/internal/resource" pbcatalog "github.com/hashicorp/consul/proto-public/pbcatalog/v2beta1" "github.com/hashicorp/consul/proto-public/pbresource" ) -func RunWorkloadSelectingTypeACLsTests[T catalog.WorkloadSelecting](t *testing.T, typ *pbresource.Type, +func RunWorkloadSelectingTypeACLsTests[T workloadselector.WorkloadSelecting](t *testing.T, typ *pbresource.Type, getData func(selector *pbcatalog.WorkloadSelector) T, registerFunc func(registry resource.Registry), ) { diff --git a/internal/catalog/catalogtest/test_integration_v2beta1.go b/internal/catalog/catalogtest/test_integration_v2beta1.go index 5a7ec6a82a..0ffca6972d 100644 --- a/internal/catalog/catalogtest/test_integration_v2beta1.go +++ b/internal/catalog/catalogtest/test_integration_v2beta1.go @@ -115,10 +115,10 @@ func VerifyCatalogV2Beta1IntegrationTestResults(t *testing.T, client pbresource. }) testutil.RunStep(t, "service-reconciliation", func(t *testing.T) { - c.WaitForStatusCondition(t, rtest.Resource(pbcatalog.ServiceType, "foo").ID(), endpoints.StatusKey, endpoints.ConditionUnmanaged) - c.WaitForStatusCondition(t, rtest.Resource(pbcatalog.ServiceType, "api").ID(), endpoints.StatusKey, endpoints.ConditionManaged) - c.WaitForStatusCondition(t, rtest.Resource(pbcatalog.ServiceType, "http-api").ID(), endpoints.StatusKey, endpoints.ConditionManaged) - c.WaitForStatusCondition(t, rtest.Resource(pbcatalog.ServiceType, "grpc-api").ID(), endpoints.StatusKey, endpoints.ConditionManaged) + c.WaitForStatusCondition(t, rtest.Resource(pbcatalog.ServiceType, "foo").ID(), endpoints.ControllerID, endpoints.ConditionUnmanaged) + c.WaitForStatusCondition(t, rtest.Resource(pbcatalog.ServiceType, "api").ID(), endpoints.ControllerID, endpoints.ConditionManaged) + c.WaitForStatusCondition(t, rtest.Resource(pbcatalog.ServiceType, "http-api").ID(), endpoints.ControllerID, endpoints.ConditionManaged) + c.WaitForStatusCondition(t, rtest.Resource(pbcatalog.ServiceType, "grpc-api").ID(), endpoints.ControllerID, endpoints.ConditionManaged) }) testutil.RunStep(t, "service-endpoints-generation", func(t *testing.T) { diff --git a/internal/catalog/exports.go b/internal/catalog/exports.go index 85ce182cbf..546fe30e2d 100644 --- a/internal/catalog/exports.go +++ b/internal/catalog/exports.go @@ -13,7 +13,6 @@ import ( "github.com/hashicorp/consul/internal/catalog/internal/types" "github.com/hashicorp/consul/internal/controller" "github.com/hashicorp/consul/internal/resource" - "github.com/hashicorp/consul/internal/resource/mappers/selectiontracker" pbcatalog "github.com/hashicorp/consul/proto-public/pbcatalog/v2beta1" "github.com/hashicorp/consul/proto-public/pbresource" ) @@ -29,7 +28,7 @@ var ( WorkloadHealthConditions = workloadhealth.WorkloadConditions WorkloadAndNodeHealthConditions = workloadhealth.NodeAndWorkloadConditions - EndpointsStatusKey = endpoints.StatusKey + EndpointsStatusKey = endpoints.ControllerID EndpointsStatusConditionEndpointsManaged = endpoints.StatusConditionEndpointsManaged EndpointsStatusConditionManaged = endpoints.ConditionManaged EndpointsStatusConditionUnmanaged = endpoints.ConditionUnmanaged @@ -47,12 +46,6 @@ var ( FailoverStatusConditionAcceptedUsingMeshDestinationPortReason = failover.UsingMeshDestinationPortReason ) -type WorkloadSelecting = types.WorkloadSelecting - -func ACLHooksForWorkloadSelectingType[T WorkloadSelecting]() *resource.ACLHooks { - return types.ACLHooksForWorkloadSelectingType[T]() -} - // RegisterTypes adds all resource types within the "catalog" API group // to the given type registry func RegisterTypes(r resource.Registry) { @@ -63,8 +56,7 @@ type ControllerDependencies = controllers.Dependencies func DefaultControllerDependencies() ControllerDependencies { return ControllerDependencies{ - EndpointsWorkloadMapper: selectiontracker.New(), - FailoverMapper: failovermapper.New(), + FailoverMapper: failovermapper.New(), } } diff --git a/internal/catalog/internal/controllers/endpoints/controller.go b/internal/catalog/internal/controllers/endpoints/controller.go index 1dacc30292..391155a6a0 100644 --- a/internal/catalog/internal/controllers/endpoints/controller.go +++ b/internal/catalog/internal/controllers/endpoints/controller.go @@ -11,7 +11,9 @@ import ( "google.golang.org/protobuf/types/known/anypb" "github.com/hashicorp/consul/internal/catalog/internal/controllers/workloadhealth" + "github.com/hashicorp/consul/internal/catalog/workloadselector" "github.com/hashicorp/consul/internal/controller" + "github.com/hashicorp/consul/internal/controller/cache" "github.com/hashicorp/consul/internal/controller/dependency" "github.com/hashicorp/consul/internal/resource" pbcatalog "github.com/hashicorp/consul/proto-public/pbcatalog/v2beta1" @@ -20,51 +22,52 @@ import ( const ( endpointsMetaManagedBy = "managed-by-controller" + + selectedWorkloadsIndexName = "selected-workloads" ) -// The WorkloadMapper interface is used to provide an implementation around being able -// to map a watch even for a Workload resource and translate it to reconciliation requests -type WorkloadMapper interface { - // MapWorkload conforms to the controller.DependencyMapper signature. Given a Workload - // resource it should report the resource IDs that have selected the workload. - MapWorkload(context.Context, controller.Runtime, *pbresource.Resource) ([]controller.Request, error) - - // TrackIDForSelector should be used to associate the specified WorkloadSelector with - // the given resource ID. Future calls to MapWorkload - TrackIDForSelector(*pbresource.ID, *pbcatalog.WorkloadSelector) - - // UntrackID should be used to inform the tracker to forget about the specified ID - UntrackID(*pbresource.ID) -} +type ( + DecodedWorkload = resource.DecodedResource[*pbcatalog.Workload] + DecodedService = resource.DecodedResource[*pbcatalog.Service] + DecodedServiceEndpoints = resource.DecodedResource[*pbcatalog.ServiceEndpoints] +) // ServiceEndpointsController creates a controller to perform automatic endpoint management for // services. -func ServiceEndpointsController(workloadMap WorkloadMapper) *controller.Controller { - if workloadMap == nil { - panic("No WorkloadMapper was provided to the ServiceEndpointsController constructor") - } - - return controller.NewController(StatusKey, pbcatalog.ServiceEndpointsType). - WithWatch(pbcatalog.ServiceType, dependency.ReplaceType(pbcatalog.ServiceEndpointsType)). - WithWatch(pbcatalog.WorkloadType, workloadMap.MapWorkload). - WithReconciler(newServiceEndpointsReconciler(workloadMap)) +func ServiceEndpointsController() *controller.Controller { + return controller.NewController(ControllerID, pbcatalog.ServiceEndpointsType). + WithWatch(pbcatalog.ServiceType, + // ServiceEndpoints are name-aligned with the Service type + dependency.ReplaceType(pbcatalog.ServiceEndpointsType), + // This cache index keeps track of the relationship between WorkloadSelectors (and the workload names and prefixes + // they include) and Services. This allows us to efficiently find all services and service endpoints that are + // are affected by the change to a workload. + workloadselector.Index[*pbcatalog.Service](selectedWorkloadsIndexName)). + WithWatch(pbcatalog.WorkloadType, + // The cache index is kept on the Service type but we need to translate events for ServiceEndpoints. + // Therefore we need to wrap the mapper from the workloadselector package with one which will + // replace the request types of Service with ServiceEndpoints. + dependency.WrapAndReplaceType( + pbcatalog.ServiceEndpointsType, + // This mapper will use the selected-workloads index to find all Services which select this + // workload by exact name or by prefix. + workloadselector.MapWorkloadsToSelectors(pbcatalog.ServiceType, selectedWorkloadsIndexName), + ), + ). + WithReconciler(newServiceEndpointsReconciler()) } -type serviceEndpointsReconciler struct { - workloadMap WorkloadMapper -} +type serviceEndpointsReconciler struct{} -func newServiceEndpointsReconciler(workloadMap WorkloadMapper) *serviceEndpointsReconciler { - return &serviceEndpointsReconciler{ - workloadMap: workloadMap, - } +func newServiceEndpointsReconciler() *serviceEndpointsReconciler { + return &serviceEndpointsReconciler{} } // Reconcile will reconcile one ServiceEndpoints resource in response to some event. func (r *serviceEndpointsReconciler) Reconcile(ctx context.Context, rt controller.Runtime, req controller.Request) error { // The runtime is passed by value so replacing it here for the remainder of this // reconciliation request processing will not affect future invocations. - rt.Logger = rt.Logger.With("resource-id", req.ID, "controller", StatusKey) + rt.Logger = rt.Logger.With("resource-id", req.ID) rt.Logger.Trace("reconciling service endpoints") @@ -76,21 +79,16 @@ func (r *serviceEndpointsReconciler) Reconcile(ctx context.Context, rt controlle } // First we read and unmarshal the service - - serviceData, err := getServiceData(ctx, rt, serviceID) + service, err := cache.GetDecoded[*pbcatalog.Service](rt.Cache, pbcatalog.ServiceType, "id", serviceID) if err != nil { rt.Logger.Error("error retrieving corresponding Service", "error", err) return err } // Check if the service exists. If it doesn't we can avoid a bunch of other work. - if serviceData == nil { + if service == nil { rt.Logger.Trace("service has been deleted") - // The service was deleted so we need to update the WorkloadMapper to tell it to - // stop tracking this service - r.workloadMap.UntrackID(req.ID) - // Note that because we configured ServiceEndpoints to be owned by the service, // the service endpoints object should eventually be automatically deleted. // There is no reason to attempt deletion here. @@ -100,7 +98,7 @@ func (r *serviceEndpointsReconciler) Reconcile(ctx context.Context, rt controlle // Now read and unmarshal the endpoints. We don't need this data just yet but all // code paths from this point on will need this regardless of branching so we pull // it now. - endpointsData, err := getEndpointsData(ctx, rt, endpointsID) + endpoints, err := cache.GetDecoded[*pbcatalog.ServiceEndpoints](rt.Cache, pbcatalog.ServiceEndpointsType, "id", endpointsID) if err != nil { rt.Logger.Error("error retrieving existing endpoints", "error", err) return err @@ -108,40 +106,29 @@ func (r *serviceEndpointsReconciler) Reconcile(ctx context.Context, rt controlle var statusConditions []*pbresource.Condition - if serviceUnderManagement(serviceData.service) { + if serviceUnderManagement(service.Data) { rt.Logger.Trace("service is enabled for automatic endpoint management") // This service should have its endpoints automatically managed statusConditions = append(statusConditions, ConditionManaged) - // Inform the WorkloadMapper to track this service and its selectors. So - // future workload updates that would be matched by the services selectors - // cause this service to be rereconciled. - r.workloadMap.TrackIDForSelector(req.ID, serviceData.service.GetWorkloads()) - - // Now read and unmarshal all workloads selected by the service. It is imperative - // that this happens after we notify the selection tracker to be tracking that - // selection criteria. If the order were reversed we could potentially miss - // workload creations that should be selected if they happen after gathering - // the workloads but before tracking the selector. Tracking first ensures that - // any event that happens after that would get mapped to an event for these - // endpoints. - workloadData, err := getWorkloadData(ctx, rt, serviceData) + // Now read and unmarshal all workloads selected by the service. + workloads, err := workloadselector.GetWorkloadsWithSelector(rt.Cache, service) if err != nil { rt.Logger.Trace("error retrieving selected workloads", "error", err) return err } // Calculate the latest endpoints from the already gathered workloads - latestEndpoints := workloadsToEndpoints(serviceData.service, workloadData) + latestEndpoints := workloadsToEndpoints(service.Data, workloads) // Add status - if endpointsData != nil { + if endpoints != nil { statusConditions = append(statusConditions, workloadIdentityStatusFromEndpoints(latestEndpoints)) } // Before writing the endpoints actually check to see if they are changed - if endpointsData == nil || !proto.Equal(endpointsData.endpoints, latestEndpoints) { + if endpoints == nil || !proto.Equal(endpoints.Data, latestEndpoints) { rt.Logger.Trace("endpoints have changed") // First encode the endpoints data as an Any type. @@ -158,9 +145,9 @@ func (r *serviceEndpointsReconciler) Reconcile(ctx context.Context, rt controlle _, err = rt.Client.Write(ctx, &pbresource.WriteRequest{ Resource: &pbresource.Resource{ Id: req.ID, - Owner: serviceData.resource.Id, + Owner: service.Id, Metadata: map[string]string{ - endpointsMetaManagedBy: StatusKey, + endpointsMetaManagedBy: ControllerID, }, Data: endpointData, }, @@ -177,20 +164,16 @@ func (r *serviceEndpointsReconciler) Reconcile(ctx context.Context, rt controlle // This service is not having its endpoints automatically managed statusConditions = append(statusConditions, ConditionUnmanaged) - // Inform the WorkloadMapper that it no longer needs to track this service - // as it is no longer under endpoint management - r.workloadMap.UntrackID(req.ID) - // Delete the managed ServiceEndpoints if necessary if the metadata would // indicate that they were previously managed by this controller - if endpointsData != nil && endpointsData.resource.Metadata[endpointsMetaManagedBy] == StatusKey { + if endpoints != nil && endpoints.Metadata[endpointsMetaManagedBy] == ControllerID { rt.Logger.Trace("removing previous managed endpoints") // This performs a CAS deletion to protect against the case where the user // has overwritten the endpoints since we fetched them. _, err := rt.Client.Delete(ctx, &pbresource.DeleteRequest{ - Id: endpointsData.resource.Id, - Version: endpointsData.resource.Version, + Id: endpoints.Id, + Version: endpoints.Version, }) // Potentially we could look for CAS failures by checking if the gRPC @@ -209,17 +192,17 @@ func (r *serviceEndpointsReconciler) Reconcile(ctx context.Context, rt controlle // whether we are automatically managing the endpoints to set expectations // for that object existing or not. newStatus := &pbresource.Status{ - ObservedGeneration: serviceData.resource.Generation, + ObservedGeneration: service.Generation, Conditions: statusConditions, } // If the status is unchanged then we should return and avoid the unnecessary write - if resource.EqualStatus(serviceData.resource.Status[StatusKey], newStatus, false) { + if resource.EqualStatus(service.Status[ControllerID], newStatus, false) { return nil } _, err = rt.Client.WriteStatus(ctx, &pbresource.WriteStatusRequest{ - Id: serviceData.resource.Id, - Key: StatusKey, + Id: service.Id, + Key: ControllerID, Status: newStatus, }) @@ -275,7 +258,7 @@ func serviceUnderManagement(svc *pbcatalog.Service) bool { } // workloadsToEndpoints will translate the Workload resources into a ServiceEndpoints resource -func workloadsToEndpoints(svc *pbcatalog.Service, workloads []*workloadData) *pbcatalog.ServiceEndpoints { +func workloadsToEndpoints(svc *pbcatalog.Service, workloads []*DecodedWorkload) *pbcatalog.ServiceEndpoints { var endpoints []*pbcatalog.Endpoint for _, workload := range workloads { @@ -300,8 +283,8 @@ func workloadsToEndpoints(svc *pbcatalog.Service, workloads []*workloadData) *pb // have reconciled the workloads health and stored it within the resources Status field. // Any unreconciled workload health will be represented in the ServiceEndpoints with // the ANY health status. -func workloadToEndpoint(svc *pbcatalog.Service, data *workloadData) *pbcatalog.Endpoint { - health := determineWorkloadHealth(data.resource) +func workloadToEndpoint(svc *pbcatalog.Service, workload *DecodedWorkload) *pbcatalog.Endpoint { + health := determineWorkloadHealth(workload.Resource) endpointPorts := make(map[string]*pbcatalog.WorkloadPort) @@ -309,7 +292,7 @@ func workloadToEndpoint(svc *pbcatalog.Service, data *workloadData) *pbcatalog.E // one of the services ports are included. Ports with a protocol mismatch // between the service and workload will be excluded as well. for _, svcPort := range svc.Ports { - workloadPort, found := data.workload.Ports[svcPort.TargetPort] + workloadPort, found := workload.Data.Ports[svcPort.TargetPort] if !found { // this workload doesn't have this port so ignore it continue @@ -336,7 +319,7 @@ func workloadToEndpoint(svc *pbcatalog.Service, data *workloadData) *pbcatalog.E // address list. If some but not all of its ports are served, then the list // of ports will be reduced to just the intersection of the service ports // and the workload addresses ports - for _, addr := range data.workload.Addresses { + for _, addr := range workload.Data.Addresses { var ports []string if len(addr.Ports) > 0 { @@ -386,12 +369,12 @@ func workloadToEndpoint(svc *pbcatalog.Service, data *workloadData) *pbcatalog.E } return &pbcatalog.Endpoint{ - TargetRef: data.resource.Id, + TargetRef: workload.Id, HealthStatus: health, Addresses: workloadAddrs, Ports: endpointPorts, - Identity: data.workload.Identity, - Dns: data.workload.Dns, + Identity: workload.Data.Identity, + Dns: workload.Data.Dns, } } diff --git a/internal/catalog/internal/controllers/endpoints/controller_test.go b/internal/catalog/internal/controllers/endpoints/controller_test.go index 743b9755fe..c079179ffc 100644 --- a/internal/catalog/internal/controllers/endpoints/controller_test.go +++ b/internal/catalog/internal/controllers/endpoints/controller_test.go @@ -15,7 +15,6 @@ import ( "github.com/hashicorp/consul/internal/catalog/internal/controllers/workloadhealth" "github.com/hashicorp/consul/internal/catalog/internal/types" "github.com/hashicorp/consul/internal/controller" - "github.com/hashicorp/consul/internal/resource/mappers/selectiontracker" "github.com/hashicorp/consul/internal/resource/resourcetest" rtest "github.com/hashicorp/consul/internal/resource/resourcetest" pbcatalog "github.com/hashicorp/consul/proto-public/pbcatalog/v2beta1" @@ -63,26 +62,23 @@ func TestWorkloadsToEndpoints(t *testing.T) { } // Build out the workloads. - workloads := []*workloadData{ - { - // this workload should result in an endpoints - resource: rtest.Resource(pbcatalog.WorkloadType, "foo"). + workloads := []*DecodedWorkload{ + rtest.MustDecode[*pbcatalog.Workload]( + t, + rtest.Resource(pbcatalog.WorkloadType, "foo"). WithData(t, workloadData1). - Build(), - workload: workloadData1, - }, - { - // this workload should be filtered out - resource: rtest.Resource(pbcatalog.WorkloadType, "bar"). + Build()), + + rtest.MustDecode[*pbcatalog.Workload]( + t, + rtest.Resource(pbcatalog.WorkloadType, "bar"). WithData(t, workloadData2). - Build(), - workload: workloadData2, - }, + Build()), } endpoints := workloadsToEndpoints(service, workloads) require.Len(t, endpoints.Endpoints, 1) - prototest.AssertDeepEqual(t, workloads[0].resource.Id, endpoints.Endpoints[0].TargetRef) + prototest.AssertDeepEqual(t, workloads[0].Id, endpoints.Endpoints[0].TargetRef) } func TestWorkloadToEndpoint(t *testing.T) { @@ -135,15 +131,12 @@ func TestWorkloadToEndpoint(t *testing.T) { }, } - data := &workloadData{ - resource: rtest.Resource(pbcatalog.WorkloadType, "foo"). - WithData(t, workload). - Build(), - workload: workload, - } + data := rtest.MustDecode[*pbcatalog.Workload](t, rtest.Resource(pbcatalog.WorkloadType, "foo"). + WithData(t, workload). + Build()) expected := &pbcatalog.Endpoint{ - TargetRef: data.resource.Id, + TargetRef: data.Id, Addresses: []*pbcatalog.WorkloadAddress{ {Host: "127.0.0.1", Ports: []string{"http"}}, {Host: "198.18.1.1", Ports: []string{"http"}}, @@ -189,12 +182,11 @@ func TestWorkloadToEndpoint_AllAddressesFiltered(t *testing.T) { }, } - data := &workloadData{ - resource: rtest.Resource(pbcatalog.WorkloadType, "foo"). + data := rtest.MustDecode[*pbcatalog.Workload]( + t, + rtest.Resource(pbcatalog.WorkloadType, "foo"). WithData(t, workload). - Build(), - workload: workload, - } + Build()) require.Nil(t, workloadToEndpoint(service, data)) } @@ -218,15 +210,14 @@ func TestWorkloadToEndpoint_MissingWorkloadProtocol(t *testing.T) { }, } - data := &workloadData{ - resource: rtest.Resource(pbcatalog.WorkloadType, "foo"). + data := rtest.MustDecode[*pbcatalog.Workload]( + t, + rtest.Resource(pbcatalog.WorkloadType, "foo"). WithData(t, workload). - Build(), - workload: workload, - } + Build()) expected := &pbcatalog.Endpoint{ - TargetRef: data.resource.Id, + TargetRef: data.Id, Addresses: []*pbcatalog.WorkloadAddress{ {Host: "127.0.0.1", Ports: []string{"test-port"}}, }, @@ -453,9 +444,8 @@ type controllerSuite struct { client *rtest.Client rt controller.Runtime - tracker *selectiontracker.WorkloadSelectionTracker - reconciler *serviceEndpointsReconciler - tenancies []*pbresource.Tenancy + ctl *controller.TestController + tenancies []*pbresource.Tenancy } func (suite *controllerSuite) SetupTest() { @@ -465,22 +455,10 @@ func (suite *controllerSuite) SetupTest() { WithRegisterFns(types.Register). WithTenancies(suite.tenancies...). Run(suite.T()) - suite.rt = controller.Runtime{ - Client: client, - Logger: testutil.Logger(suite.T()), - } - suite.client = rtest.NewClient(client) - suite.tracker = selectiontracker.New() - suite.reconciler = newServiceEndpointsReconciler(suite.tracker) -} - -func (suite *controllerSuite) requireTracking(workload *pbresource.Resource, ids ...*pbresource.ID) { - reqs, err := suite.tracker.MapWorkload(suite.ctx, suite.rt, workload) - require.NoError(suite.T(), err) - require.Len(suite.T(), reqs, len(ids)) - for _, id := range ids { - prototest.AssertContainsElement(suite.T(), reqs, controller.Request{ID: id}) - } + suite.ctl = controller.NewTestController(ServiceEndpointsController(), client). + WithLogger(testutil.Logger(suite.T())) + suite.rt = suite.ctl.Runtime() + suite.client = rtest.NewClient(suite.rt.Client) } func (suite *controllerSuite) requireEndpoints(resource *pbresource.Resource, expected ...*pbcatalog.Endpoint) { @@ -491,33 +469,14 @@ func (suite *controllerSuite) requireEndpoints(resource *pbresource.Resource, ex } func (suite *controllerSuite) TestReconcile_ServiceNotFound() { - // This test's purpose is to ensure that when we are reconciling - // endpoints for a service that no longer exists, we stop - // tracking the endpoints resource ID in the selection tracker. - - // generate a workload resource to use for checking if it maps - // to a service endpoints object - + // This test really only checks that the Reconcile call will not panic or otherwise error + // when the request is for an endpoints object whose corresponding service does not exist. suite.runTestCaseWithTenancies(func(tenancy *pbresource.Tenancy) { - workload := rtest.Resource(pbcatalog.WorkloadType, "foo").WithTenancy(tenancy).Build() - - // ensure that the tracker knows about the service prior to - // calling reconcile so that we can ensure it removes tracking id := rtest.Resource(pbcatalog.ServiceEndpointsType, "not-found").WithTenancy(tenancy).ID() - suite.tracker.TrackIDForSelector(id, &pbcatalog.WorkloadSelector{Prefixes: []string{""}}) - // verify that mapping the workload to service endpoints returns a - // non-empty list prior to reconciliation which should remove the - // tracking. - suite.requireTracking(workload, id) - - // Because the endpoints don't exist, this reconcile call should - // cause tracking of the endpoints to be removed - err := suite.reconciler.Reconcile(suite.ctx, suite.rt, controller.Request{ID: id}) + // Because the endpoints don't exist, this reconcile call not error but also shouldn't do anything useful. + err := suite.ctl.Reconcile(suite.ctx, controller.Request{ID: id}) require.NoError(suite.T(), err) - - // Now ensure that the tracking was removed - suite.requireTracking(workload) }) } @@ -539,10 +498,10 @@ func (suite *controllerSuite) TestReconcile_NoSelector_NoEndpoints() { endpointsID := rtest.Resource(pbcatalog.ServiceEndpointsType, "test").WithTenancy(tenancy).ID() - err := suite.reconciler.Reconcile(suite.ctx, suite.rt, controller.Request{ID: endpointsID}) + err := suite.ctl.Reconcile(suite.ctx, controller.Request{ID: endpointsID}) require.NoError(suite.T(), err) - suite.client.RequireStatusCondition(suite.T(), service.Id, StatusKey, ConditionUnmanaged) + suite.client.RequireStatusCondition(suite.T(), service.Id, ControllerID, ConditionUnmanaged) }) } @@ -565,13 +524,13 @@ func (suite *controllerSuite) TestReconcile_NoSelector_ManagedEndpoints() { WithTenancy(tenancy). WithData(suite.T(), &pbcatalog.ServiceEndpoints{}). // this marks these endpoints as under management - WithMeta(endpointsMetaManagedBy, StatusKey). + WithMeta(endpointsMetaManagedBy, ControllerID). Write(suite.T(), suite.client) - err := suite.reconciler.Reconcile(suite.ctx, suite.rt, controller.Request{ID: endpoints.Id}) + err := suite.ctl.Reconcile(suite.ctx, controller.Request{ID: endpoints.Id}) require.NoError(suite.T(), err) // the status should indicate the services endpoints are not being managed - suite.client.RequireStatusCondition(suite.T(), service.Id, StatusKey, ConditionUnmanaged) + suite.client.RequireStatusCondition(suite.T(), service.Id, ControllerID, ConditionUnmanaged) // endpoints under management should be deleted suite.client.RequireResourceNotFound(suite.T(), endpoints.Id) }) @@ -597,10 +556,10 @@ func (suite *controllerSuite) TestReconcile_NoSelector_UnmanagedEndpoints() { WithData(suite.T(), &pbcatalog.ServiceEndpoints{}). Write(suite.T(), suite.client) - err := suite.reconciler.Reconcile(suite.ctx, suite.rt, controller.Request{ID: endpoints.Id}) + err := suite.ctl.Reconcile(suite.ctx, controller.Request{ID: endpoints.Id}) require.NoError(suite.T(), err) // the status should indicate the services endpoints are not being managed - suite.client.RequireStatusCondition(suite.T(), service.Id, StatusKey, ConditionUnmanaged) + suite.client.RequireStatusCondition(suite.T(), service.Id, ControllerID, ConditionUnmanaged) // unmanaged endpoints should not be deleted when the service is unmanaged suite.client.RequireResourceExists(suite.T(), endpoints.Id) }) @@ -635,14 +594,14 @@ func (suite *controllerSuite) TestReconcile_Managed_NoPreviousEndpoints() { }). Write(suite.T(), suite.client) - err := suite.reconciler.Reconcile(suite.ctx, suite.rt, controller.Request{ID: endpointsID}) + err := suite.ctl.Reconcile(suite.ctx, controller.Request{ID: endpointsID}) require.NoError(suite.T(), err) // Verify that the services status has been set to indicate endpoints are automatically managed. - suite.client.RequireStatusCondition(suite.T(), service.Id, StatusKey, ConditionManaged) + suite.client.RequireStatusCondition(suite.T(), service.Id, ControllerID, ConditionManaged) // The service endpoints metadata should include our tag to indcate it was generated by this controller - res := suite.client.RequireResourceMeta(suite.T(), endpointsID, endpointsMetaManagedBy, StatusKey) + res := suite.client.RequireResourceMeta(suite.T(), endpointsID, endpointsMetaManagedBy, ControllerID) var endpoints pbcatalog.ServiceEndpoints err = res.Data.UnmarshalTo(&endpoints) @@ -688,11 +647,11 @@ func (suite *controllerSuite) TestReconcile_Managed_ExistingEndpoints() { }). Write(suite.T(), suite.client) - err := suite.reconciler.Reconcile(suite.ctx, suite.rt, controller.Request{ID: endpoints.Id}) + err := suite.ctl.Reconcile(suite.ctx, controller.Request{ID: endpoints.Id}) require.NoError(suite.T(), err) - suite.client.RequireStatusCondition(suite.T(), service.Id, StatusKey, ConditionManaged) - res := suite.client.RequireResourceMeta(suite.T(), endpoints.Id, endpointsMetaManagedBy, StatusKey) + suite.client.RequireStatusCondition(suite.T(), service.Id, ControllerID, ConditionManaged) + res := suite.client.RequireResourceMeta(suite.T(), endpoints.Id, endpointsMetaManagedBy, ControllerID) var newEndpoints pbcatalog.ServiceEndpoints err = res.Data.UnmarshalTo(&newEndpoints) @@ -711,7 +670,7 @@ func (suite *controllerSuite) TestController() { // Run the controller manager mgr := controller.NewManager(suite.client, suite.rt.Logger) - mgr.Register(ServiceEndpointsController(suite.tracker)) + mgr.Register(ServiceEndpointsController()) mgr.SetRaftLeader(true) go mgr.Run(suite.ctx) @@ -731,10 +690,10 @@ func (suite *controllerSuite) TestController() { Write(suite.T(), suite.client) // Wait for the controller to record that the endpoints are being managed - res := suite.client.WaitForReconciliation(suite.T(), service.Id, StatusKey) + res := suite.client.WaitForReconciliation(suite.T(), service.Id, ControllerID) // Check that the services status was updated accordingly - rtest.RequireStatusCondition(suite.T(), res, StatusKey, ConditionManaged) - rtest.RequireStatusCondition(suite.T(), res, StatusKey, ConditionIdentitiesNotFound) + rtest.RequireStatusCondition(suite.T(), res, ControllerID, ConditionManaged) + rtest.RequireStatusCondition(suite.T(), res, ControllerID, ConditionIdentitiesNotFound) // Check that the endpoints resource exists and contains 0 endpoints endpointsID := rtest.Resource(pbcatalog.ServiceEndpointsType, "api").WithTenancy(tenancy).ID() @@ -755,7 +714,7 @@ func (suite *controllerSuite) TestController() { }). Write(suite.T(), suite.client) - suite.client.WaitForStatusCondition(suite.T(), service.Id, StatusKey, + suite.client.WaitForStatusCondition(suite.T(), service.Id, ControllerID, ConditionIdentitiesFound([]string{"api"})) // Wait for the endpoints to be regenerated @@ -818,7 +777,7 @@ func (suite *controllerSuite) TestController() { }). Write(suite.T(), suite.client) - suite.client.WaitForStatusCondition(suite.T(), service.Id, StatusKey, ConditionIdentitiesFound([]string{"endpoints-api-identity"})) + suite.client.WaitForStatusCondition(suite.T(), service.Id, ControllerID, ConditionIdentitiesFound([]string{"endpoints-api-identity"})) // Verify that the generated endpoints now contain the workload endpoints = suite.client.WaitForNewVersion(suite.T(), endpointsID, endpoints.Version) @@ -850,7 +809,7 @@ func (suite *controllerSuite) TestController() { Write(suite.T(), suite.client) // Wait for the service status' observed generation to get bumped - service = suite.client.WaitForReconciliation(suite.T(), service.Id, StatusKey) + service = suite.client.WaitForReconciliation(suite.T(), service.Id, ControllerID) // Verify that the endpoints were not regenerated suite.client.RequireVersionUnchanged(suite.T(), endpointsID, endpoints.Version) @@ -891,8 +850,8 @@ func (suite *controllerSuite) TestController() { }). Write(suite.T(), suite.client) - res = suite.client.WaitForReconciliation(suite.T(), service.Id, StatusKey) - rtest.RequireStatusCondition(suite.T(), res, StatusKey, ConditionUnmanaged) + res = suite.client.WaitForReconciliation(suite.T(), service.Id, ControllerID) + rtest.RequireStatusCondition(suite.T(), res, ControllerID, ConditionUnmanaged) // Verify that the endpoints were deleted suite.client.RequireResourceNotFound(suite.T(), endpointsID) diff --git a/internal/catalog/internal/controllers/endpoints/reconciliation_data.go b/internal/catalog/internal/controllers/endpoints/reconciliation_data.go deleted file mode 100644 index 186354eda9..0000000000 --- a/internal/catalog/internal/controllers/endpoints/reconciliation_data.go +++ /dev/null @@ -1,189 +0,0 @@ -// Copyright (c) HashiCorp, Inc. -// SPDX-License-Identifier: BUSL-1.1 - -package endpoints - -import ( - "context" - "fmt" - "sort" - - "google.golang.org/grpc/codes" - "google.golang.org/grpc/status" - - "github.com/hashicorp/consul/internal/controller" - "github.com/hashicorp/consul/internal/resource" - pbcatalog "github.com/hashicorp/consul/proto-public/pbcatalog/v2beta1" - "github.com/hashicorp/consul/proto-public/pbresource" -) - -type serviceData struct { - resource *pbresource.Resource - service *pbcatalog.Service -} - -type endpointsData struct { - resource *pbresource.Resource - endpoints *pbcatalog.ServiceEndpoints -} - -type workloadData struct { - resource *pbresource.Resource - workload *pbcatalog.Workload -} - -// getServiceData will read the service with the given ID and unmarshal the -// Data field. The return value is a struct that contains the retrieved -// resource as well as the unmarshalled form. If the resource doesn't -// exist, nil will be returned. Any other error either with retrieving -// the resource or unmarshalling it will cause the error to be returned -// to the caller -func getServiceData(ctx context.Context, rt controller.Runtime, id *pbresource.ID) (*serviceData, error) { - rsp, err := rt.Client.Read(ctx, &pbresource.ReadRequest{Id: id}) - switch { - case status.Code(err) == codes.NotFound: - return nil, nil - case err != nil: - return nil, err - } - - var service pbcatalog.Service - err = rsp.Resource.Data.UnmarshalTo(&service) - if err != nil { - return nil, resource.NewErrDataParse(&service, err) - } - - return &serviceData{resource: rsp.Resource, service: &service}, nil -} - -// getEndpointsData will read the endpoints with the given ID and unmarshal the -// Data field. The return value is a struct that contains the retrieved -// resource as well as the unmsashalled form. If the resource doesn't -// exist, nil will be returned. Any other error either with retrieving -// the resource or unmarshalling it will cause the error to be returned -// to the caller -func getEndpointsData(ctx context.Context, rt controller.Runtime, id *pbresource.ID) (*endpointsData, error) { - rsp, err := rt.Client.Read(ctx, &pbresource.ReadRequest{Id: id}) - switch { - case status.Code(err) == codes.NotFound: - return nil, nil - case err != nil: - return nil, err - } - - var endpoints pbcatalog.ServiceEndpoints - err = rsp.Resource.Data.UnmarshalTo(&endpoints) - if err != nil { - return nil, resource.NewErrDataParse(&endpoints, err) - } - - return &endpointsData{resource: rsp.Resource, endpoints: &endpoints}, nil -} - -// getWorkloadData will retrieve all workloads for the given services selector -// and unmarhshal them, returning a slic of objects hold both the resource and -// unmarshaled forms. Unmarshalling errors, or other resource service errors -// will be returned to the caller. -func getWorkloadData(ctx context.Context, rt controller.Runtime, svc *serviceData) ([]*workloadData, error) { - workloadResources, err := gatherWorkloadsForService(ctx, rt, svc) - if err != nil { - return nil, err - } - - var results []*workloadData - for _, res := range workloadResources { - var workload pbcatalog.Workload - err = res.Data.UnmarshalTo(&workload) - if err != nil { - return nil, resource.NewErrDataParse(&workload, err) - } - - results = append(results, &workloadData{resource: res, workload: &workload}) - } - - return results, nil -} - -// gatherWorkloadsForService will retrieve all the unique workloads for a given selector. -// NotFound errors for workloads selected by Name will be ignored. Any other -// resource service errors will be returned to the caller. Prior to returning -// the slice of resources, they will be sorted by name. The consistent ordering -// will allow callers to diff two versions of the data to determine if anything -// has changed but it also will make testing a little easier. -func gatherWorkloadsForService(ctx context.Context, rt controller.Runtime, svc *serviceData) ([]*pbresource.Resource, error) { - var workloads []*pbresource.Resource - - sel := svc.service.GetWorkloads() - - // this map will track all the gathered workloads by name, this is mainly to deduplicate workloads if they - // are specified multiple times throughout the list of selection criteria - workloadNames := make(map[string]struct{}) - - // First gather all the prefix matched workloads. We could do this second but by doing - // it first its possible we can avoid some resource service calls to read individual - // workloads selected by name if they are also matched by a prefix. - for _, prefix := range sel.GetPrefixes() { - rsp, err := rt.Client.List(ctx, &pbresource.ListRequest{ - Type: pbcatalog.WorkloadType, - Tenancy: svc.resource.Id.Tenancy, - NamePrefix: prefix, - }) - if err != nil { - return nil, err - } - - // append all workloads in the list response to our list of all selected workloads - for _, workload := range rsp.Resources { - // ignore duplicate workloads - if _, found := workloadNames[workload.Id.Name]; !found { - workloads = append(workloads, workload) - workloadNames[workload.Id.Name] = struct{}{} - } - } - } - - // Now gather the exact match selections - for _, name := range sel.GetNames() { - // ignore names we have already fetched - if _, found := workloadNames[name]; found { - continue - } - - workloadID := &pbresource.ID{ - Type: pbcatalog.WorkloadType, - Tenancy: svc.resource.Id.Tenancy, - Name: name, - } - - rsp, err := rt.Client.Read(ctx, &pbresource.ReadRequest{Id: workloadID}) - switch { - case status.Code(err) == codes.NotFound: - // Ignore not found errors as services may select workloads that do not - // yet exist. This is not considered an error state or mis-configuration - // as the user could be getting ready to add the workloads. - continue - case err != nil: - return nil, err - } - - workloads = append(workloads, rsp.Resource) - workloadNames[rsp.Resource.Id.Name] = struct{}{} - } - - if sel.GetFilter() != "" && len(workloads) > 0 { - var err error - workloads, err = resource.FilterResourcesByMetadata(workloads, sel.GetFilter()) - if err != nil { - return nil, fmt.Errorf("error filtering results by metadata: %w", err) - } - } - - // Sorting ensures deterministic output. This will help for testing but - // the real reason to do this is so we will be able to diff the set of - // workloads endpoints to determine if we need to update them. - sort.Slice(workloads, func(i, j int) bool { - return workloads[i].Id.Name < workloads[j].Id.Name - }) - - return workloads, nil -} diff --git a/internal/catalog/internal/controllers/endpoints/reconciliation_data_test.go b/internal/catalog/internal/controllers/endpoints/reconciliation_data_test.go deleted file mode 100644 index 14c729e2cf..0000000000 --- a/internal/catalog/internal/controllers/endpoints/reconciliation_data_test.go +++ /dev/null @@ -1,358 +0,0 @@ -// Copyright (c) HashiCorp, Inc. -// SPDX-License-Identifier: BUSL-1.1 - -package endpoints - -import ( - "context" - "fmt" - "testing" - - "github.com/stretchr/testify/require" - "github.com/stretchr/testify/suite" - "google.golang.org/grpc/codes" - "google.golang.org/grpc/status" - "google.golang.org/protobuf/proto" - - svctest "github.com/hashicorp/consul/agent/grpc-external/services/resource/testing" - "github.com/hashicorp/consul/internal/catalog/internal/types" - "github.com/hashicorp/consul/internal/controller" - "github.com/hashicorp/consul/internal/resource" - "github.com/hashicorp/consul/internal/resource/resourcetest" - rtest "github.com/hashicorp/consul/internal/resource/resourcetest" - pbcatalog "github.com/hashicorp/consul/proto-public/pbcatalog/v2beta1" - "github.com/hashicorp/consul/proto-public/pbresource" - "github.com/hashicorp/consul/proto/private/prototest" - "github.com/hashicorp/consul/sdk/testutil" -) - -type reconciliationDataSuite struct { - suite.Suite - - ctx context.Context - client *resourcetest.Client - rt controller.Runtime - - apiServiceData *pbcatalog.Service - apiService *pbresource.Resource - apiServiceSubsetData *pbcatalog.Service - apiServiceSubset *pbresource.Resource - apiEndpoints *pbresource.Resource - api1Workload *pbresource.Resource - api2Workload *pbresource.Resource - api123Workload *pbresource.Resource - web1Workload *pbresource.Resource - web2Workload *pbresource.Resource - - tenancies []*pbresource.Tenancy -} - -func (suite *reconciliationDataSuite) SetupTest() { - suite.ctx = testutil.TestContext(suite.T()) - suite.tenancies = rtest.TestTenancies() - resourceClient := svctest.NewResourceServiceBuilder(). - WithRegisterFns(types.Register). - WithTenancies(suite.tenancies...). - Run(suite.T()) - suite.client = resourcetest.NewClient(resourceClient) - suite.rt = controller.Runtime{ - Client: suite.client, - Logger: testutil.Logger(suite.T()), - } - - suite.apiServiceData = &pbcatalog.Service{ - Workloads: &pbcatalog.WorkloadSelector{ - // This services selectors are specially crafted to exercise both the - // dedeuplication and sorting behaviors of gatherWorkloadsForService - Prefixes: []string{"api-"}, - Names: []string{"api-1", "web-2", "web-1", "api-1", "not-found"}, - }, - Ports: []*pbcatalog.ServicePort{ - { - TargetPort: "http", - Protocol: pbcatalog.Protocol_PROTOCOL_HTTP, - }, - }, - } - suite.apiServiceSubsetData = proto.Clone(suite.apiServiceData).(*pbcatalog.Service) - suite.apiServiceSubsetData.Workloads.Filter = "(zim in metadata) and (metadata.zim matches `^g.`)" -} - -func (suite *reconciliationDataSuite) TestGetServiceData_NotFound() { - // This test's purposes is to ensure that NotFound errors when retrieving - // the service data are ignored properly. - suite.runTestCaseWithTenancies(func(tenancy *pbresource.Tenancy) { - data, err := getServiceData(suite.ctx, suite.rt, rtest.Resource(pbcatalog.ServiceType, "not-found").WithTenancy(tenancy).ID()) - require.NoError(suite.T(), err) - require.Nil(suite.T(), data) - }) -} - -func (suite *reconciliationDataSuite) TestGetServiceData_ReadError() { - // This test's purpose is to ensure that Read errors other than NotFound - // are propagated back to the caller. Specifying a resource ID with an - // unregistered type is the easiest way to force a resource service error. - badType := &pbresource.Type{ - Group: "not", - Kind: "found", - GroupVersion: "vfake", - } - suite.runTestCaseWithTenancies(func(tenancy *pbresource.Tenancy) { - data, err := getServiceData(suite.ctx, suite.rt, rtest.Resource(badType, "foo").WithTenancy(tenancy).ID()) - require.Error(suite.T(), err) - require.Equal(suite.T(), codes.InvalidArgument, status.Code(err)) - require.Nil(suite.T(), data) - }) -} - -func (suite *reconciliationDataSuite) TestGetServiceData_UnmarshalError() { - // This test's purpose is to ensure that unmarshlling errors are returned - // to the caller. We are using a resource id that points to an endpoints - // object instead of a service to ensure that the data will be unmarshallable. - suite.runTestCaseWithTenancies(func(tenancy *pbresource.Tenancy) { - data, err := getServiceData(suite.ctx, suite.rt, rtest.Resource(pbcatalog.ServiceEndpointsType, "api").WithTenancy(tenancy).ID()) - require.Error(suite.T(), err) - var parseErr resource.ErrDataParse - require.ErrorAs(suite.T(), err, &parseErr) - require.Nil(suite.T(), data) - }) -} - -func (suite *reconciliationDataSuite) TestGetServiceData_Ok() { - // This test's purpose is to ensure that the happy path for - // retrieving a service works as expected. - suite.runTestCaseWithTenancies(func(tenancy *pbresource.Tenancy) { - data, err := getServiceData(suite.ctx, suite.rt, suite.apiService.Id) - require.NoError(suite.T(), err) - require.NotNil(suite.T(), data) - require.NotNil(suite.T(), data.resource) - prototest.AssertDeepEqual(suite.T(), suite.apiService.Id, data.resource.Id) - require.Len(suite.T(), data.service.Ports, 1) - }) -} - -func (suite *reconciliationDataSuite) TestGetEndpointsData_NotFound() { - // This test's purposes is to ensure that NotFound errors when retrieving - // the endpoint data are ignored properly. - suite.runTestCaseWithTenancies(func(tenancy *pbresource.Tenancy) { - data, err := getEndpointsData(suite.ctx, suite.rt, rtest.Resource(pbcatalog.ServiceEndpointsType, "not-found").WithTenancy(tenancy).ID()) - require.NoError(suite.T(), err) - require.Nil(suite.T(), data) - }) -} - -func (suite *reconciliationDataSuite) TestGetEndpointsData_ReadError() { - // This test's purpose is to ensure that Read errors other than NotFound - // are propagated back to the caller. Specifying a resource ID with an - // unregistered type is the easiest way to force a resource service error. - badType := &pbresource.Type{ - Group: "not", - Kind: "found", - GroupVersion: "vfake", - } - suite.runTestCaseWithTenancies(func(tenancy *pbresource.Tenancy) { - data, err := getEndpointsData(suite.ctx, suite.rt, rtest.Resource(badType, "foo").WithTenancy(tenancy).ID()) - require.Error(suite.T(), err) - require.Equal(suite.T(), codes.InvalidArgument, status.Code(err)) - require.Nil(suite.T(), data) - }) -} - -func (suite *reconciliationDataSuite) TestGetEndpointsData_UnmarshalError() { - // This test's purpose is to ensure that unmarshlling errors are returned - // to the caller. We are using a resource id that points to a service object - // instead of an endpoints object to ensure that the data will be unmarshallable. - suite.runTestCaseWithTenancies(func(tenancy *pbresource.Tenancy) { - data, err := getEndpointsData(suite.ctx, suite.rt, rtest.Resource(pbcatalog.ServiceType, "api").WithTenancy(tenancy).ID()) - require.Error(suite.T(), err) - var parseErr resource.ErrDataParse - require.ErrorAs(suite.T(), err, &parseErr) - require.Nil(suite.T(), data) - }) -} - -func (suite *reconciliationDataSuite) TestGetEndpointsData_Ok() { - // This test's purpose is to ensure that the happy path for - // retrieving an endpoints object works as expected. - suite.runTestCaseWithTenancies(func(tenancy *pbresource.Tenancy) { - data, err := getEndpointsData(suite.ctx, suite.rt, suite.apiEndpoints.Id) - require.NoError(suite.T(), err) - require.NotNil(suite.T(), data) - require.NotNil(suite.T(), data.resource) - prototest.AssertDeepEqual(suite.T(), suite.apiEndpoints.Id, data.resource.Id) - require.Len(suite.T(), data.endpoints.Endpoints, 1) - }) -} - -func (suite *reconciliationDataSuite) TestGetWorkloadData() { - // This test's purpose is to ensure that gather workloads for - // a service work as expected. The services selector was crafted - // to exercise the deduplication behavior as well as the sorting - // behavior. The assertions in this test will verify that only - // unique workloads are returned and that they are ordered. - - suite.runTestCaseWithTenancies(func(tenancy *pbresource.Tenancy) { - require.NotNil(suite.T(), suite.apiService) - - data, err := getWorkloadData(suite.ctx, suite.rt, &serviceData{ - resource: suite.apiService, - service: suite.apiServiceData, - }) - - require.NoError(suite.T(), err) - require.Len(suite.T(), data, 5) - prototest.AssertDeepEqual(suite.T(), suite.api1Workload, data[0].resource) - prototest.AssertDeepEqual(suite.T(), suite.api123Workload, data[1].resource) - prototest.AssertDeepEqual(suite.T(), suite.api2Workload, data[2].resource) - prototest.AssertDeepEqual(suite.T(), suite.web1Workload, data[3].resource) - prototest.AssertDeepEqual(suite.T(), suite.web2Workload, data[4].resource) - }) -} - -func (suite *reconciliationDataSuite) TestGetWorkloadDataWithFilter() { - // This is like TestGetWorkloadData except it exercises the post-read - // filter on the selector. - suite.runTestCaseWithTenancies(func(tenancy *pbresource.Tenancy) { - require.NotNil(suite.T(), suite.apiServiceSubset) - - data, err := getWorkloadData(suite.ctx, suite.rt, &serviceData{ - resource: suite.apiServiceSubset, - service: suite.apiServiceSubsetData, - }) - - require.NoError(suite.T(), err) - require.Len(suite.T(), data, 2) - prototest.AssertDeepEqual(suite.T(), suite.api123Workload, data[0].resource) - prototest.AssertDeepEqual(suite.T(), suite.web1Workload, data[1].resource) - }) -} - -func TestReconciliationData(t *testing.T) { - suite.Run(t, new(reconciliationDataSuite)) -} - -func (suite *reconciliationDataSuite) setupResourcesWithTenancy(tenancy *pbresource.Tenancy) { - suite.apiService = rtest.Resource(pbcatalog.ServiceType, "api"). - WithTenancy(tenancy). - WithData(suite.T(), suite.apiServiceData). - Write(suite.T(), suite.client) - - suite.apiServiceSubset = rtest.Resource(pbcatalog.ServiceType, "api-subset"). - WithTenancy(tenancy). - WithData(suite.T(), suite.apiServiceSubsetData). - Write(suite.T(), suite.client) - - suite.api1Workload = rtest.Resource(pbcatalog.WorkloadType, "api-1"). - WithTenancy(tenancy). - WithMeta("zim", "dib"). - WithData(suite.T(), &pbcatalog.Workload{ - Addresses: []*pbcatalog.WorkloadAddress{ - {Host: "127.0.0.1"}, - }, - Ports: map[string]*pbcatalog.WorkloadPort{ - "http": {Port: 8080, Protocol: pbcatalog.Protocol_PROTOCOL_HTTP}, - }, - Identity: "api", - }). - Write(suite.T(), suite.client) - - suite.api2Workload = rtest.Resource(pbcatalog.WorkloadType, "api-2"). - WithTenancy(tenancy). - WithData(suite.T(), &pbcatalog.Workload{ - Addresses: []*pbcatalog.WorkloadAddress{ - {Host: "127.0.0.1"}, - }, - Ports: map[string]*pbcatalog.WorkloadPort{ - "http": {Port: 8080, Protocol: pbcatalog.Protocol_PROTOCOL_HTTP}, - }, - Identity: "api", - }). - Write(suite.T(), suite.client) - - suite.api123Workload = rtest.Resource(pbcatalog.WorkloadType, "api-123"). - WithTenancy(tenancy). - WithMeta("zim", "gir"). - WithData(suite.T(), &pbcatalog.Workload{ - Addresses: []*pbcatalog.WorkloadAddress{ - {Host: "127.0.0.1"}, - }, - Ports: map[string]*pbcatalog.WorkloadPort{ - "http": {Port: 8080, Protocol: pbcatalog.Protocol_PROTOCOL_HTTP}, - }, - Identity: "api", - }). - Write(suite.T(), suite.client) - - suite.web1Workload = rtest.Resource(pbcatalog.WorkloadType, "web-1"). - WithTenancy(tenancy). - WithMeta("zim", "gaz"). - WithData(suite.T(), &pbcatalog.Workload{ - Addresses: []*pbcatalog.WorkloadAddress{ - {Host: "127.0.0.1"}, - }, - Ports: map[string]*pbcatalog.WorkloadPort{ - "http": {Port: 8080, Protocol: pbcatalog.Protocol_PROTOCOL_HTTP}, - }, - Identity: "web", - }). - Write(suite.T(), suite.client) - - suite.web2Workload = rtest.Resource(pbcatalog.WorkloadType, "web-2"). - WithTenancy(tenancy). - WithData(suite.T(), &pbcatalog.Workload{ - Addresses: []*pbcatalog.WorkloadAddress{ - {Host: "127.0.0.1"}, - }, - Ports: map[string]*pbcatalog.WorkloadPort{ - "http": {Port: 8080, Protocol: pbcatalog.Protocol_PROTOCOL_HTTP}, - }, - Identity: "web", - }). - Write(suite.T(), suite.client) - - suite.apiEndpoints = rtest.Resource(pbcatalog.ServiceEndpointsType, "api"). - WithTenancy(tenancy). - WithData(suite.T(), &pbcatalog.ServiceEndpoints{ - Endpoints: []*pbcatalog.Endpoint{ - { - TargetRef: rtest.Resource(pbcatalog.WorkloadType, "api-1").WithTenancy(tenancy).ID(), - Addresses: []*pbcatalog.WorkloadAddress{ - { - Host: "127.0.0.1", - Ports: []string{"http"}, - }, - }, - Ports: map[string]*pbcatalog.WorkloadPort{ - "http": {Port: 8080, Protocol: pbcatalog.Protocol_PROTOCOL_HTTP}, - }, - HealthStatus: pbcatalog.Health_HEALTH_PASSING, - }, - }, - }). - Write(suite.T(), suite.client) -} - -func (suite *reconciliationDataSuite) cleanupResources() { - suite.client.MustDelete(suite.T(), suite.apiService.Id) - suite.client.MustDelete(suite.T(), suite.apiServiceSubset.Id) - suite.client.MustDelete(suite.T(), suite.api1Workload.Id) - suite.client.MustDelete(suite.T(), suite.api2Workload.Id) - suite.client.MustDelete(suite.T(), suite.api123Workload.Id) - suite.client.MustDelete(suite.T(), suite.web1Workload.Id) - suite.client.MustDelete(suite.T(), suite.web2Workload.Id) - suite.client.MustDelete(suite.T(), suite.apiEndpoints.Id) -} - -func (suite *reconciliationDataSuite) runTestCaseWithTenancies(testFunc func(*pbresource.Tenancy)) { - for _, tenancy := range suite.tenancies { - suite.Run(suite.appendTenancyInfo(tenancy), func() { - suite.setupResourcesWithTenancy(tenancy) - testFunc(tenancy) - suite.T().Cleanup(suite.cleanupResources) - }) - } -} - -func (suite *reconciliationDataSuite) appendTenancyInfo(tenancy *pbresource.Tenancy) string { - return fmt.Sprintf("%s_Namespace_%s_Partition", tenancy.Namespace, tenancy.Partition) -} diff --git a/internal/catalog/internal/controllers/endpoints/status.go b/internal/catalog/internal/controllers/endpoints/status.go index 078d5e9a5f..daf1428b51 100644 --- a/internal/catalog/internal/controllers/endpoints/status.go +++ b/internal/catalog/internal/controllers/endpoints/status.go @@ -11,7 +11,7 @@ import ( ) const ( - StatusKey = "consul.io/endpoint-manager" + ControllerID = "consul.io/endpoint-manager" StatusConditionEndpointsManaged = "EndpointsManaged" StatusReasonSelectorNotFound = "SelectorNotFound" diff --git a/internal/catalog/internal/controllers/register.go b/internal/catalog/internal/controllers/register.go index 94e7740260..f3352e6de6 100644 --- a/internal/catalog/internal/controllers/register.go +++ b/internal/catalog/internal/controllers/register.go @@ -12,13 +12,12 @@ import ( ) type Dependencies struct { - EndpointsWorkloadMapper endpoints.WorkloadMapper - FailoverMapper failover.FailoverMapper + FailoverMapper failover.FailoverMapper } func Register(mgr *controller.Manager, deps Dependencies) { mgr.Register(nodehealth.NodeHealthController()) mgr.Register(workloadhealth.WorkloadHealthController()) - mgr.Register(endpoints.ServiceEndpointsController(deps.EndpointsWorkloadMapper)) + mgr.Register(endpoints.ServiceEndpointsController()) mgr.Register(failover.FailoverPolicyController(deps.FailoverMapper)) } diff --git a/internal/catalog/internal/types/health_checks.go b/internal/catalog/internal/types/health_checks.go index 3d819e1288..06afab81f7 100644 --- a/internal/catalog/internal/types/health_checks.go +++ b/internal/catalog/internal/types/health_checks.go @@ -6,6 +6,7 @@ package types import ( "github.com/hashicorp/go-multierror" + "github.com/hashicorp/consul/internal/catalog/workloadselector" "github.com/hashicorp/consul/internal/resource" pbcatalog "github.com/hashicorp/consul/proto-public/pbcatalog/v2beta1" ) @@ -18,7 +19,7 @@ func RegisterHealthChecks(r resource.Registry) { Proto: &pbcatalog.HealthChecks{}, Scope: resource.ScopeNamespace, Validate: ValidateHealthChecks, - ACLs: ACLHooksForWorkloadSelectingType[*pbcatalog.HealthChecks](), + ACLs: workloadselector.ACLHooks[*pbcatalog.HealthChecks](), }) } diff --git a/internal/catalog/internal/types/service.go b/internal/catalog/internal/types/service.go index bb56fe10a5..4b243bf152 100644 --- a/internal/catalog/internal/types/service.go +++ b/internal/catalog/internal/types/service.go @@ -8,6 +8,7 @@ import ( "github.com/hashicorp/go-multierror" + "github.com/hashicorp/consul/internal/catalog/workloadselector" "github.com/hashicorp/consul/internal/resource" pbcatalog "github.com/hashicorp/consul/proto-public/pbcatalog/v2beta1" ) @@ -21,7 +22,7 @@ func RegisterService(r resource.Registry) { Scope: resource.ScopeNamespace, Validate: ValidateService, Mutate: MutateService, - ACLs: ACLHooksForWorkloadSelectingType[*pbcatalog.Service](), + ACLs: workloadselector.ACLHooks[*pbcatalog.Service](), }) } diff --git a/internal/catalog/internal/types/acl_hooks.go b/internal/catalog/workloadselector/acls.go similarity index 93% rename from internal/catalog/internal/types/acl_hooks.go rename to internal/catalog/workloadselector/acls.go index d9ddcb8e93..ebad0d47be 100644 --- a/internal/catalog/internal/types/acl_hooks.go +++ b/internal/catalog/workloadselector/acls.go @@ -1,7 +1,7 @@ // Copyright (c) HashiCorp, Inc. // SPDX-License-Identifier: BUSL-1.1 -package types +package workloadselector import ( "github.com/hashicorp/consul/acl" @@ -38,7 +38,7 @@ func aclWriteHookResourceWithWorkloadSelector[T WorkloadSelecting](authorizer ac return nil } -func ACLHooksForWorkloadSelectingType[T WorkloadSelecting]() *resource.ACLHooks { +func ACLHooks[T WorkloadSelecting]() *resource.ACLHooks { return &resource.ACLHooks{ Read: aclReadHookResourceWithWorkloadSelector, Write: resource.DecodeAndAuthorizeWrite(aclWriteHookResourceWithWorkloadSelector[T]), diff --git a/internal/catalog/workloadselector/acls_test.go b/internal/catalog/workloadselector/acls_test.go new file mode 100644 index 0000000000..9303460053 --- /dev/null +++ b/internal/catalog/workloadselector/acls_test.go @@ -0,0 +1,123 @@ +// Copyright (c) HashiCorp, Inc. +// SPDX-License-Identifier: BUSL-1.1 + +package workloadselector + +import ( + "testing" + + "github.com/stretchr/testify/require" + "github.com/stretchr/testify/suite" + + "github.com/hashicorp/consul/acl" + "github.com/hashicorp/consul/internal/resource" + "github.com/hashicorp/consul/internal/resource/resourcetest" + pbcatalog "github.com/hashicorp/consul/proto-public/pbcatalog/v2beta1" + "github.com/hashicorp/consul/proto-public/pbresource" +) + +func TestACLHooks(t *testing.T) { + suite.Run(t, new(aclHookSuite)) +} + +type aclHookSuite struct { + suite.Suite + + hooks *resource.ACLHooks + authz *acl.MockAuthorizer + ctx *acl.AuthorizerContext + res *pbresource.Resource +} + +func (suite *aclHookSuite) SetupTest() { + suite.authz = new(acl.MockAuthorizer) + + suite.authz.On("ToAllowAuthorizer").Return(acl.AllowAuthorizer{Authorizer: suite.authz, AccessorID: "862270e5-7d7b-4583-98bc-4d14810cc158"}) + + suite.ctx = &acl.AuthorizerContext{} + acl.DefaultEnterpriseMeta().FillAuthzContext(suite.ctx) + + suite.hooks = ACLHooks[*pbcatalog.Service]() + + suite.res = resourcetest.Resource(pbcatalog.ServiceType, "foo"). + WithData(suite.T(), &pbcatalog.Service{ + Workloads: &pbcatalog.WorkloadSelector{ + Prefixes: []string{"api-"}, + Names: []string{"bar"}, + }, + }). + WithTenancy(resource.DefaultNamespacedTenancy()). + Build() +} + +func (suite *aclHookSuite) TeardownTest() { + suite.authz.AssertExpectations(suite.T()) +} + +func (suite *aclHookSuite) TestReadHook_Allowed() { + suite.authz.On("ServiceRead", "foo", suite.ctx). + Return(acl.Allow). + Once() + + require.NoError(suite.T(), suite.hooks.Read(suite.authz, suite.ctx, suite.res.Id, nil)) +} + +func (suite *aclHookSuite) TestReadHook_Denied() { + suite.authz.On("ServiceRead", "foo", suite.ctx). + Return(acl.Deny). + Once() + + require.Error(suite.T(), suite.hooks.Read(suite.authz, suite.ctx, suite.res.Id, nil)) +} + +func (suite *aclHookSuite) TestWriteHook_ServiceWriteDenied() { + suite.authz.On("ServiceWrite", "foo", suite.ctx). + Return(acl.Deny). + Once() + + require.Error(suite.T(), suite.hooks.Write(suite.authz, suite.ctx, suite.res)) +} + +func (suite *aclHookSuite) TestWriteHook_ServiceReadNameDenied() { + suite.authz.On("ServiceWrite", "foo", suite.ctx). + Return(acl.Allow). + Once() + + suite.authz.On("ServiceRead", "bar", suite.ctx). + Return(acl.Deny). + Once() + + require.Error(suite.T(), suite.hooks.Write(suite.authz, suite.ctx, suite.res)) +} + +func (suite *aclHookSuite) TestWriteHook_ServiceReadPrefixDenied() { + suite.authz.On("ServiceWrite", "foo", suite.ctx). + Return(acl.Allow). + Once() + + suite.authz.On("ServiceRead", "bar", suite.ctx). + Return(acl.Allow). + Once() + + suite.authz.On("ServiceReadPrefix", "api-", suite.ctx). + Return(acl.Deny). + Once() + + require.Error(suite.T(), suite.hooks.Write(suite.authz, suite.ctx, suite.res)) +} + +func (suite *aclHookSuite) TestWriteHook_Allowed() { + suite.authz.On("ServiceWrite", "foo", suite.ctx). + Return(acl.Allow). + Once() + + suite.authz.On("ServiceRead", "bar", suite.ctx). + Return(acl.Allow). + Once() + + suite.authz.On("ServiceReadPrefix", "api-", suite.ctx). + Return(acl.Allow). + Once() + + require.NoError(suite.T(), suite.hooks.Write(suite.authz, suite.ctx, suite.res)) +} diff --git a/internal/catalog/workloadselector/gather.go b/internal/catalog/workloadselector/gather.go new file mode 100644 index 0000000000..b4b4492698 --- /dev/null +++ b/internal/catalog/workloadselector/gather.go @@ -0,0 +1,114 @@ +// Copyright (c) HashiCorp, Inc. +// SPDX-License-Identifier: BUSL-1.1 + +package workloadselector + +import ( + "fmt" + "sort" + + "github.com/hashicorp/consul/internal/controller/cache" + "github.com/hashicorp/consul/internal/controller/cache/index" + "github.com/hashicorp/consul/internal/resource" + pbcatalog "github.com/hashicorp/consul/proto-public/pbcatalog/v2beta1" + "github.com/hashicorp/consul/proto-public/pbresource" +) + +// GetWorkloadsWithSelector will retrieve all workloads for the given resources selector +// and unmarhshal them, returning a slice of objects hold both the resource and +// unmarshaled forms. Unmarshalling errors, or other cache errors +// will be returned to the caller. +func GetWorkloadsWithSelector[T WorkloadSelecting](c cache.ReadOnlyCache, res *resource.DecodedResource[T]) ([]*resource.DecodedResource[*pbcatalog.Workload], error) { + if res == nil { + return nil, nil + } + + sel := res.Data.GetWorkloads() + + if sel == nil || (len(sel.GetNames()) < 1 && len(sel.GetPrefixes()) < 1) { + return nil, nil + } + + // this map will track all workloads by name which is needed to deduplicate workloads if they + // are specified multiple times throughout the list of selection criteria + workloadNames := make(map[string]struct{}) + + var workloads []*resource.DecodedResource[*pbcatalog.Workload] + + // First gather all the prefix matched workloads. We could do this second but by doing + // it first its possible we can avoid some operations to get individual + // workloads selected by name if they are also matched by a prefix. + for _, prefix := range sel.GetPrefixes() { + iter, err := cache.ListIteratorDecoded[*pbcatalog.Workload]( + c, + pbcatalog.WorkloadType, + "id", + &pbresource.ID{ + Type: pbcatalog.WorkloadType, + Tenancy: res.Id.Tenancy, + Name: prefix, + }, + index.IndexQueryOptions{Prefix: true}) + if err != nil { + return nil, err + } + + // append all workloads in the list response to our list of all selected workloads + for workload, err := iter.Next(); workload != nil || err != nil; workload, err = iter.Next() { + if err != nil { + return nil, err + } + + // ignore duplicate workloads + if _, found := workloadNames[workload.Id.Name]; !found { + workloads = append(workloads, workload) + workloadNames[workload.Id.Name] = struct{}{} + } + } + } + + // Now gather the exact match selections + for _, name := range sel.GetNames() { + // ignore names we have already fetched + if _, found := workloadNames[name]; found { + continue + } + + workloadID := &pbresource.ID{ + Type: pbcatalog.WorkloadType, + Tenancy: res.Id.Tenancy, + Name: name, + } + + res, err := cache.GetDecoded[*pbcatalog.Workload](c, pbcatalog.WorkloadType, "id", workloadID) + if err != nil { + return nil, err + } + + // ignore workloads that don't exist as it is fine for a Service to select them. If they exist in the + // future then the ServiceEndpoints will be regenerated to include them. + if res == nil { + continue + } + + workloads = append(workloads, res) + workloadNames[res.Id.Name] = struct{}{} + } + + if sel.GetFilter() != "" && len(workloads) > 0 { + var err error + workloads, err = resource.FilterResourcesByMetadata(workloads, sel.GetFilter()) + if err != nil { + return nil, fmt.Errorf("error filtering results by metadata: %w", err) + } + } + + // Sorting ensures deterministic output. This will help for testing but + // the real reason to do this is so we will be able to diff the set of + // workloads endpoints to determine if we need to update them. + sort.Slice(workloads, func(i, j int) bool { + return workloads[i].Id.Name < workloads[j].Id.Name + }) + + return workloads, nil +} diff --git a/internal/catalog/workloadselector/gather_test.go b/internal/catalog/workloadselector/gather_test.go new file mode 100644 index 0000000000..4671698a8a --- /dev/null +++ b/internal/catalog/workloadselector/gather_test.go @@ -0,0 +1,258 @@ +// Copyright (c) HashiCorp, Inc. +// SPDX-License-Identifier: BUSL-1.1 + +package workloadselector + +import ( + "context" + "fmt" + "testing" + + "github.com/stretchr/testify/require" + "github.com/stretchr/testify/suite" + "google.golang.org/protobuf/proto" + + "github.com/hashicorp/consul/internal/controller/cache" + "github.com/hashicorp/consul/internal/resource" + rtest "github.com/hashicorp/consul/internal/resource/resourcetest" + pbcatalog "github.com/hashicorp/consul/proto-public/pbcatalog/v2beta1" + "github.com/hashicorp/consul/proto-public/pbresource" + "github.com/hashicorp/consul/proto/private/prototest" + "github.com/hashicorp/consul/sdk/testutil" +) + +type gatherWorkloadsDataSuite struct { + suite.Suite + + ctx context.Context + cache cache.Cache + + apiServiceData *pbcatalog.Service + apiService *resource.DecodedResource[*pbcatalog.Service] + apiServiceSubsetData *pbcatalog.Service + apiServiceSubset *resource.DecodedResource[*pbcatalog.Service] + apiEndpoints *resource.DecodedResource[*pbcatalog.ServiceEndpoints] + api1Workload *resource.DecodedResource[*pbcatalog.Workload] + api2Workload *resource.DecodedResource[*pbcatalog.Workload] + api123Workload *resource.DecodedResource[*pbcatalog.Workload] + web1Workload *resource.DecodedResource[*pbcatalog.Workload] + web2Workload *resource.DecodedResource[*pbcatalog.Workload] + + tenancies []*pbresource.Tenancy +} + +func (suite *gatherWorkloadsDataSuite) SetupTest() { + suite.ctx = testutil.TestContext(suite.T()) + suite.tenancies = rtest.TestTenancies() + + suite.cache = cache.New() + suite.cache.AddType(pbcatalog.WorkloadType) + + suite.apiServiceData = &pbcatalog.Service{ + Workloads: &pbcatalog.WorkloadSelector{ + // This services selectors are specially crafted to exercise both the + // dedeuplication and sorting behaviors of gatherWorkloadsForService + Prefixes: []string{"api-"}, + Names: []string{"api-1", "web-2", "web-1", "api-1", "not-found"}, + }, + Ports: []*pbcatalog.ServicePort{ + { + TargetPort: "http", + Protocol: pbcatalog.Protocol_PROTOCOL_HTTP, + }, + }, + } + suite.apiServiceSubsetData = proto.Clone(suite.apiServiceData).(*pbcatalog.Service) + suite.apiServiceSubsetData.Workloads.Filter = "(zim in metadata) and (metadata.zim matches `^g.`)" +} + +func (suite *gatherWorkloadsDataSuite) TestGetWorkloadData() { + // This test's purpose is to ensure that gather workloads for + // a service work as expected. The services selector was crafted + // to exercise the deduplication behavior as well as the sorting + // behavior. The assertions in this test will verify that only + // unique workloads are returned and that they are ordered. + + suite.runTestCaseWithTenancies(func(tenancy *pbresource.Tenancy) { + require.NotNil(suite.T(), suite.apiService) + + data, err := GetWorkloadsWithSelector(suite.cache, suite.apiService) + + require.NoError(suite.T(), err) + require.Len(suite.T(), data, 5) + requireDecodedWorkloadEquals(suite.T(), suite.api1Workload, data[0]) + requireDecodedWorkloadEquals(suite.T(), suite.api1Workload, data[0]) + requireDecodedWorkloadEquals(suite.T(), suite.api123Workload, data[1]) + requireDecodedWorkloadEquals(suite.T(), suite.api2Workload, data[2]) + requireDecodedWorkloadEquals(suite.T(), suite.web1Workload, data[3]) + requireDecodedWorkloadEquals(suite.T(), suite.web2Workload, data[4]) + }) +} + +func (suite *gatherWorkloadsDataSuite) TestGetWorkloadDataWithFilter() { + // This is like TestGetWorkloadData except it exercises the post-read + // filter on the selector. + suite.runTestCaseWithTenancies(func(tenancy *pbresource.Tenancy) { + require.NotNil(suite.T(), suite.apiServiceSubset) + + data, err := GetWorkloadsWithSelector(suite.cache, suite.apiServiceSubset) + + require.NoError(suite.T(), err) + require.Len(suite.T(), data, 2) + requireDecodedWorkloadEquals(suite.T(), suite.api123Workload, data[0]) + requireDecodedWorkloadEquals(suite.T(), suite.web1Workload, data[1]) + }) +} + +func TestReconciliationData(t *testing.T) { + suite.Run(t, new(gatherWorkloadsDataSuite)) +} + +func (suite *gatherWorkloadsDataSuite) setupResourcesWithTenancy(tenancy *pbresource.Tenancy) { + suite.apiService = rtest.MustDecode[*pbcatalog.Service]( + suite.T(), + rtest.Resource(pbcatalog.ServiceType, "api"). + WithTenancy(tenancy). + WithData(suite.T(), suite.apiServiceData). + Build()) + + suite.apiServiceSubset = rtest.MustDecode[*pbcatalog.Service]( + suite.T(), + rtest.Resource(pbcatalog.ServiceType, "api-subset"). + WithTenancy(tenancy). + WithData(suite.T(), suite.apiServiceSubsetData). + Build()) + + suite.api1Workload = rtest.MustDecode[*pbcatalog.Workload]( + suite.T(), + rtest.Resource(pbcatalog.WorkloadType, "api-1"). + WithTenancy(tenancy). + WithMeta("zim", "dib"). + WithData(suite.T(), &pbcatalog.Workload{ + Addresses: []*pbcatalog.WorkloadAddress{ + {Host: "127.0.0.1"}, + }, + Ports: map[string]*pbcatalog.WorkloadPort{ + "http": {Port: 8080, Protocol: pbcatalog.Protocol_PROTOCOL_HTTP}, + }, + Identity: "api", + }). + Build()) + suite.cache.Insert(suite.api1Workload.Resource) + + suite.api2Workload = rtest.MustDecode[*pbcatalog.Workload]( + suite.T(), + rtest.Resource(pbcatalog.WorkloadType, "api-2"). + WithTenancy(tenancy). + WithData(suite.T(), &pbcatalog.Workload{ + Addresses: []*pbcatalog.WorkloadAddress{ + {Host: "127.0.0.1"}, + }, + Ports: map[string]*pbcatalog.WorkloadPort{ + "http": {Port: 8080, Protocol: pbcatalog.Protocol_PROTOCOL_HTTP}, + }, + Identity: "api", + }). + Build()) + suite.cache.Insert(suite.api2Workload.Resource) + + suite.api123Workload = rtest.MustDecode[*pbcatalog.Workload]( + suite.T(), + rtest.Resource(pbcatalog.WorkloadType, "api-123"). + WithTenancy(tenancy). + WithMeta("zim", "gir"). + WithData(suite.T(), &pbcatalog.Workload{ + Addresses: []*pbcatalog.WorkloadAddress{ + {Host: "127.0.0.1"}, + }, + Ports: map[string]*pbcatalog.WorkloadPort{ + "http": {Port: 8080, Protocol: pbcatalog.Protocol_PROTOCOL_HTTP}, + }, + Identity: "api", + }). + Build()) + suite.cache.Insert(suite.api123Workload.Resource) + + suite.web1Workload = rtest.MustDecode[*pbcatalog.Workload]( + suite.T(), + rtest.Resource(pbcatalog.WorkloadType, "web-1"). + WithTenancy(tenancy). + WithMeta("zim", "gaz"). + WithData(suite.T(), &pbcatalog.Workload{ + Addresses: []*pbcatalog.WorkloadAddress{ + {Host: "127.0.0.1"}, + }, + Ports: map[string]*pbcatalog.WorkloadPort{ + "http": {Port: 8080, Protocol: pbcatalog.Protocol_PROTOCOL_HTTP}, + }, + Identity: "web", + }). + Build()) + suite.cache.Insert(suite.web1Workload.Resource) + + suite.web2Workload = rtest.MustDecode[*pbcatalog.Workload]( + suite.T(), + rtest.Resource(pbcatalog.WorkloadType, "web-2"). + WithTenancy(tenancy). + WithData(suite.T(), &pbcatalog.Workload{ + Addresses: []*pbcatalog.WorkloadAddress{ + {Host: "127.0.0.1"}, + }, + Ports: map[string]*pbcatalog.WorkloadPort{ + "http": {Port: 8080, Protocol: pbcatalog.Protocol_PROTOCOL_HTTP}, + }, + Identity: "web", + }). + Build()) + suite.cache.Insert(suite.web2Workload.Resource) + + suite.apiEndpoints = rtest.MustDecode[*pbcatalog.ServiceEndpoints]( + suite.T(), + rtest.Resource(pbcatalog.ServiceEndpointsType, "api"). + WithTenancy(tenancy). + WithData(suite.T(), &pbcatalog.ServiceEndpoints{ + Endpoints: []*pbcatalog.Endpoint{ + { + TargetRef: rtest.Resource(pbcatalog.WorkloadType, "api-1").WithTenancy(tenancy).ID(), + Addresses: []*pbcatalog.WorkloadAddress{ + { + Host: "127.0.0.1", + Ports: []string{"http"}, + }, + }, + Ports: map[string]*pbcatalog.WorkloadPort{ + "http": {Port: 8080, Protocol: pbcatalog.Protocol_PROTOCOL_HTTP}, + }, + HealthStatus: pbcatalog.Health_HEALTH_PASSING, + }, + }, + }). + Build()) +} + +func (suite *gatherWorkloadsDataSuite) cleanupResources() { + require.NoError(suite.T(), suite.cache.Delete(suite.api1Workload.Resource)) + require.NoError(suite.T(), suite.cache.Delete(suite.api2Workload.Resource)) + require.NoError(suite.T(), suite.cache.Delete(suite.api123Workload.Resource)) + require.NoError(suite.T(), suite.cache.Delete(suite.web1Workload.Resource)) + require.NoError(suite.T(), suite.cache.Delete(suite.web2Workload.Resource)) +} + +func (suite *gatherWorkloadsDataSuite) runTestCaseWithTenancies(testFunc func(*pbresource.Tenancy)) { + for _, tenancy := range suite.tenancies { + suite.Run(suite.appendTenancyInfo(tenancy), func() { + suite.setupResourcesWithTenancy(tenancy) + testFunc(tenancy) + suite.T().Cleanup(suite.cleanupResources) + }) + } +} + +func (suite *gatherWorkloadsDataSuite) appendTenancyInfo(tenancy *pbresource.Tenancy) string { + return fmt.Sprintf("%s_Namespace_%s_Partition", tenancy.Namespace, tenancy.Partition) +} + +func requireDecodedWorkloadEquals(t testutil.TestingTB, expected, actual *resource.DecodedResource[*pbcatalog.Workload]) { + prototest.AssertDeepEqual(t, expected.Resource, actual.Resource) + require.Equal(t, expected.Data, actual.Data) +} diff --git a/internal/catalog/workloadselector/index.go b/internal/catalog/workloadselector/index.go new file mode 100644 index 0000000000..1928318583 --- /dev/null +++ b/internal/catalog/workloadselector/index.go @@ -0,0 +1,72 @@ +// Copyright (c) HashiCorp, Inc. +// SPDX-License-Identifier: BUSL-1.1 + +package workloadselector + +import ( + "github.com/hashicorp/consul/internal/controller/cache/index" + "github.com/hashicorp/consul/internal/controller/cache/indexers" + "github.com/hashicorp/consul/internal/resource" + pbcatalog "github.com/hashicorp/consul/proto-public/pbcatalog/v2beta1" + "github.com/hashicorp/consul/proto-public/pbresource" +) + +const ( + IndexName = "selected-workloads" +) + +func Index[T WorkloadSelecting](name string) *index.Index { + return indexers.DecodedMultiIndexer[T]( + name, + index.SingleValueFromOneOrTwoArgs[resource.ReferenceOrID, index.IndexQueryOptions](fromArgs), + fromResource[T], + ) +} + +func fromArgs(r resource.ReferenceOrID, opts index.IndexQueryOptions) ([]byte, error) { + workloadRef := &pbresource.Reference{ + Type: pbcatalog.WorkloadType, + Tenancy: r.GetTenancy(), + Name: r.GetName(), + } + + if opts.Prefix { + return index.PrefixIndexFromRefOrID(workloadRef), nil + } else { + return index.IndexFromRefOrID(workloadRef), nil + } +} + +func fromResource[T WorkloadSelecting](res *resource.DecodedResource[T]) (bool, [][]byte, error) { + sel := res.Data.GetWorkloads() + if sel == nil || (len(sel.Prefixes) == 0 && len(sel.Names) == 0) { + return false, nil, nil + } + + var indexes [][]byte + + for _, name := range sel.Names { + ref := &pbresource.Reference{ + Type: pbcatalog.WorkloadType, + Tenancy: res.Id.Tenancy, + Name: name, + } + + indexes = append(indexes, index.IndexFromRefOrID(ref)) + } + + for _, name := range sel.Prefixes { + ref := &pbresource.Reference{ + Type: pbcatalog.WorkloadType, + Tenancy: res.Id.Tenancy, + Name: name, + } + + b := index.IndexFromRefOrID(ref) + + // need to remove the path separator to be compatible with prefix matching + indexes = append(indexes, b[:len(b)-1]) + } + + return true, indexes, nil +} diff --git a/internal/catalog/workloadselector/index_test.go b/internal/catalog/workloadselector/index_test.go new file mode 100644 index 0000000000..ec61c69eb7 --- /dev/null +++ b/internal/catalog/workloadselector/index_test.go @@ -0,0 +1,135 @@ +// Copyright (c) HashiCorp, Inc. +// SPDX-License-Identifier: BUSL-1.1 + +package workloadselector + +import ( + "testing" + + "github.com/hashicorp/consul/internal/controller/cache" + "github.com/hashicorp/consul/internal/controller/cache/index" + "github.com/hashicorp/consul/internal/resource" + "github.com/hashicorp/consul/internal/resource/resourcetest" + rtest "github.com/hashicorp/consul/internal/resource/resourcetest" + pbcatalog "github.com/hashicorp/consul/proto-public/pbcatalog/v2beta1" + "github.com/hashicorp/consul/proto-public/pbresource" + "github.com/hashicorp/consul/proto/private/prototest" + "github.com/stretchr/testify/require" +) + +func TestServiceWorkloadIndexer(t *testing.T) { + c := cache.New() + i := Index[*pbcatalog.Service]("selected-workloads") + require.NoError(t, c.AddIndex(pbcatalog.ServiceType, i)) + + foo := rtest.Resource(pbcatalog.ServiceType, "foo"). + WithData(t, &pbcatalog.Service{ + Workloads: &pbcatalog.WorkloadSelector{ + Names: []string{ + "api-2", + }, + Prefixes: []string{ + "api-1", + }, + }, + }). + WithTenancy(&pbresource.Tenancy{ + Partition: "default", + Namespace: "default", + PeerName: "local", + }). + Build() + + require.NoError(t, c.Insert(foo)) + + bar := rtest.Resource(pbcatalog.ServiceType, "bar"). + WithData(t, &pbcatalog.Service{ + Workloads: &pbcatalog.WorkloadSelector{ + Names: []string{ + "api-3", + }, + Prefixes: []string{ + "api-2", + }, + }, + }). + WithTenancy(&pbresource.Tenancy{ + Partition: "default", + Namespace: "default", + PeerName: "local", + }). + Build() + + require.NoError(t, c.Insert(bar)) + + api123 := rtest.Resource(pbcatalog.WorkloadType, "api-123"). + WithTenancy(&pbresource.Tenancy{ + Partition: "default", + Namespace: "default", + PeerName: "local", + }). + Reference("") + + api2 := rtest.Resource(pbcatalog.WorkloadType, "api-2"). + WithTenancy(&pbresource.Tenancy{ + Partition: "default", + Namespace: "default", + PeerName: "local", + }). + Reference("") + + resources, err := c.Parents(pbcatalog.ServiceType, i.Name(), api123) + require.NoError(t, err) + require.Len(t, resources, 1) + prototest.AssertDeepEqual(t, foo, resources[0]) + + resources, err = c.Parents(pbcatalog.ServiceType, i.Name(), api2) + require.NoError(t, err) + require.Len(t, resources, 2) + prototest.AssertElementsMatch(t, []*pbresource.Resource{foo, bar}, resources) + + refPrefix := &pbresource.Reference{ + Type: pbcatalog.WorkloadType, + Tenancy: &pbresource.Tenancy{ + Partition: "default", + Namespace: "default", + }, + } + resources, err = c.List(pbcatalog.ServiceType, i.Name(), refPrefix, index.IndexQueryOptions{Prefix: true}) + require.NoError(t, err) + // because foo and bar both have 2 index values they will appear in the output twice + require.Len(t, resources, 4) + prototest.AssertElementsMatch(t, []*pbresource.Resource{foo, bar, foo, bar}, resources) +} + +func TestServiceWorkloadIndexer_FromResource_Errors(t *testing.T) { + t.Run("nil-selector", func(t *testing.T) { + res := resourcetest.MustDecode[*pbcatalog.Service]( + t, + resourcetest.Resource(pbcatalog.ServiceType, "foo"). + WithData(t, &pbcatalog.Service{}). + WithTenancy(resource.DefaultNamespacedTenancy()). + Build()) + + indexed, vals, err := fromResource(res) + require.False(t, indexed) + require.Nil(t, vals) + require.NoError(t, err) + }) + + t.Run("no-selections", func(t *testing.T) { + res := resourcetest.MustDecode[*pbcatalog.Service]( + t, + resourcetest.Resource(pbcatalog.ServiceType, "foo"). + WithData(t, &pbcatalog.Service{ + Workloads: &pbcatalog.WorkloadSelector{}, + }). + WithTenancy(resource.DefaultNamespacedTenancy()). + Build()) + + indexed, vals, err := fromResource(res) + require.False(t, indexed) + require.Nil(t, vals) + require.NoError(t, err) + }) +} diff --git a/internal/catalog/workloadselector/integ_test.go b/internal/catalog/workloadselector/integ_test.go new file mode 100644 index 0000000000..866f83def6 --- /dev/null +++ b/internal/catalog/workloadselector/integ_test.go @@ -0,0 +1,151 @@ +// Copyright (c) HashiCorp, Inc. +// SPDX-License-Identifier: BUSL-1.1 + +package workloadselector_test + +import ( + "context" + "testing" + + "github.com/hashicorp/consul/internal/catalog/workloadselector" + "github.com/hashicorp/consul/internal/controller" + "github.com/hashicorp/consul/internal/controller/cache" + "github.com/hashicorp/consul/internal/resource" + "github.com/hashicorp/consul/internal/resource/resourcetest" + pbcatalog "github.com/hashicorp/consul/proto-public/pbcatalog/v2beta1" + "github.com/hashicorp/consul/proto/private/prototest" + "github.com/hashicorp/consul/sdk/testutil" + "github.com/stretchr/testify/require" +) + +func TestWorkloadSelectorCacheIntegration(t *testing.T) { + c := cache.New() + i := workloadselector.Index[*pbcatalog.Service]("selected-workloads") + c.AddType(pbcatalog.WorkloadType) + c.AddIndex(pbcatalog.ServiceType, i) + + rt := controller.Runtime{ + Cache: c, + Logger: testutil.Logger(t), + } + + svcFoo := resourcetest.Resource(pbcatalog.ServiceType, "foo"). + WithData(t, &pbcatalog.Service{ + Workloads: &pbcatalog.WorkloadSelector{ + Names: []string{"foo"}, + Prefixes: []string{"api-", "other-"}, + }, + }). + WithTenancy(resource.DefaultNamespacedTenancy()). + Build() + + svcBar := resourcetest.Resource(pbcatalog.ServiceType, "bar"). + WithData(t, &pbcatalog.Service{ + Workloads: &pbcatalog.WorkloadSelector{ + Names: []string{"bar"}, + Prefixes: []string{"api-1", "something-else-"}, + }, + }). + WithTenancy(resource.DefaultNamespacedTenancy()). + Build() + + workloadBar := resourcetest.Resource(pbcatalog.WorkloadType, "bar"). + WithTenancy(resource.DefaultNamespacedTenancy()). + Build() + + workloadAPIFoo := resourcetest.Resource(pbcatalog.WorkloadType, "api-foo"). + WithTenancy(resource.DefaultNamespacedTenancy()). + Build() + + workloadAPI12 := resourcetest.Resource(pbcatalog.WorkloadType, "api-1"). + WithTenancy(resource.DefaultNamespacedTenancy()). + Build() + + workloadFoo := resourcetest.Resource(pbcatalog.WorkloadType, "foo"). + WithTenancy(resource.DefaultNamespacedTenancy()). + Build() + + workloadSomethingElse12 := resourcetest.Resource(pbcatalog.WorkloadType, "something-else-12"). + WithTenancy(resource.DefaultNamespacedTenancy()). + Build() + + // prime the cache with all of our services and workloads + require.NoError(t, c.Insert(svcFoo)) + require.NoError(t, c.Insert(svcBar)) + require.NoError(t, c.Insert(workloadAPIFoo)) + require.NoError(t, c.Insert(workloadAPI12)) + require.NoError(t, c.Insert(workloadFoo)) + require.NoError(t, c.Insert(workloadSomethingElse12)) + + // check that mapping a selecting resource to the list of currently selected workloads works as expected + reqs, err := workloadselector.MapSelectorToWorkloads[*pbcatalog.Service](context.Background(), rt, svcFoo) + require.NoError(t, err) + // in particular workloadSomethingElse12 should not show up here + expected := []controller.Request{ + {ID: workloadFoo.Id}, + {ID: workloadAPI12.Id}, + {ID: workloadAPIFoo.Id}, + } + prototest.AssertElementsMatch(t, expected, reqs) + + reqs, err = workloadselector.MapSelectorToWorkloads[*pbcatalog.Service](context.Background(), rt, svcBar) + require.NoError(t, err) + // workloadFoo and workloadAPIFoo should not show up here as they don't meet the selection critiera + // workloadBar should not show up here because it hasn't been inserted into the cache yet. + expected = []controller.Request{ + {ID: workloadSomethingElse12.Id}, + {ID: workloadAPI12.Id}, + } + prototest.AssertElementsMatch(t, expected, reqs) + + // insert workloadBar into the cache so that future calls to MapSelectorToWorkloads for svcBar show + // the workload in the output + require.NoError(t, c.Insert(workloadBar)) + + // now validate that workloadBar shows up in the svcBar mapping + reqs, err = workloadselector.MapSelectorToWorkloads[*pbcatalog.Service](context.Background(), rt, svcBar) + require.NoError(t, err) + expected = []controller.Request{ + {ID: workloadSomethingElse12.Id}, + {ID: workloadAPI12.Id}, + {ID: workloadBar.Id}, + } + prototest.AssertElementsMatch(t, expected, reqs) + + // create the mapper to verify that finding services that select workloads functions correctly + mapper := workloadselector.MapWorkloadsToSelectors(pbcatalog.ServiceType, i.Name()) + + // check that workloadAPIFoo only returns a request for serviceFoo + reqs, err = mapper(context.Background(), rt, workloadAPIFoo) + require.NoError(t, err) + expected = []controller.Request{ + {ID: svcFoo.Id}, + } + prototest.AssertElementsMatch(t, expected, reqs) + + // check that workloadAPI12 returns both services + reqs, err = mapper(context.Background(), rt, workloadAPI12) + require.NoError(t, err) + expected = []controller.Request{ + {ID: svcFoo.Id}, + {ID: svcBar.Id}, + } + prototest.AssertElementsMatch(t, expected, reqs) + + // check that workloadSomethingElse12 returns only svcBar + reqs, err = mapper(context.Background(), rt, workloadSomethingElse12) + require.NoError(t, err) + expected = []controller.Request{ + {ID: svcBar.Id}, + } + prototest.AssertElementsMatch(t, expected, reqs) + + // check that workloadFoo returns only svcFoo + reqs, err = mapper(context.Background(), rt, workloadFoo) + require.NoError(t, err) + expected = []controller.Request{ + {ID: svcFoo.Id}, + } + prototest.AssertElementsMatch(t, expected, reqs) + +} diff --git a/internal/catalog/workloadselector/mapper.go b/internal/catalog/workloadselector/mapper.go new file mode 100644 index 0000000000..fac38ae55a --- /dev/null +++ b/internal/catalog/workloadselector/mapper.go @@ -0,0 +1,45 @@ +// Copyright (c) HashiCorp, Inc. +// SPDX-License-Identifier: BUSL-1.1 + +package workloadselector + +import ( + "context" + + "github.com/hashicorp/consul/internal/controller" + "github.com/hashicorp/consul/internal/controller/dependency" + "github.com/hashicorp/consul/internal/resource" + "github.com/hashicorp/consul/proto-public/pbresource" +) + +// MapSelectorToWorkloads will use the "id" index on watched Workload type to find all current +// workloads selected by the resource. +func MapSelectorToWorkloads[T WorkloadSelecting](_ context.Context, rt controller.Runtime, r *pbresource.Resource) ([]controller.Request, error) { + res, err := resource.Decode[T](r) + if err != nil { + return nil, err + } + + workloads, err := GetWorkloadsWithSelector[T](rt.Cache, res) + if err != nil { + return nil, err + } + + reqs := make([]controller.Request, len(workloads)) + for i, workload := range workloads { + reqs[i] = controller.Request{ + ID: workload.Id, + } + } + + return reqs, nil +} + +// MapWorkloadsToSelectors returns a DependencyMapper that will use the specified index to map a workload +// to resources that select it. +// +// This mapper can only be used on watches for the Workload type and works in conjunction with the Index +// created by this package. +func MapWorkloadsToSelectors(indexType *pbresource.Type, indexName string) controller.DependencyMapper { + return dependency.CacheParentsMapper(indexType, indexName) +} diff --git a/internal/catalog/workloadselector/mapper_test.go b/internal/catalog/workloadselector/mapper_test.go new file mode 100644 index 0000000000..c0bd853190 --- /dev/null +++ b/internal/catalog/workloadselector/mapper_test.go @@ -0,0 +1,180 @@ +// Copyright (c) HashiCorp, Inc. +// SPDX-License-Identifier: BUSL-1.1 + +package workloadselector + +import ( + "context" + "errors" + "testing" + + "github.com/hashicorp/consul/internal/controller" + "github.com/hashicorp/consul/internal/controller/cache/cachemock" + "github.com/hashicorp/consul/internal/controller/cache/index" + "github.com/hashicorp/consul/internal/controller/cache/index/indexmock" + "github.com/hashicorp/consul/internal/resource" + "github.com/hashicorp/consul/internal/resource/resourcetest" + pbcatalog "github.com/hashicorp/consul/proto-public/pbcatalog/v2beta1" + "github.com/hashicorp/consul/proto-public/pbresource" + "github.com/hashicorp/consul/proto/private/prototest" + "github.com/hashicorp/go-hclog" + "github.com/stretchr/testify/require" +) + +var injectedError = errors.New("injected error") + +func TestMapSelectorToWorkloads(t *testing.T) { + cache := cachemock.NewReadOnlyCache(t) + + rt := controller.Runtime{ + Cache: cache, + } + + mres := indexmock.NewResourceIterator(t) + + svc := resourcetest.Resource(pbcatalog.ServiceType, "api"). + WithData(t, &pbcatalog.Service{ + Workloads: &pbcatalog.WorkloadSelector{ + Prefixes: []string{"api-"}, + Names: []string{"foo"}, + }, + }). + WithTenancy(resource.DefaultNamespacedTenancy()). + Build() + + api1 := resourcetest.Resource(pbcatalog.WorkloadType, "api-1"). + WithTenancy(resource.DefaultNamespacedTenancy()). + Build() + + api2 := resourcetest.Resource(pbcatalog.WorkloadType, "api-2"). + WithTenancy(resource.DefaultNamespacedTenancy()). + Build() + + fooRes := resourcetest.Resource(pbcatalog.WorkloadType, "foo"). + WithTenancy(resource.DefaultNamespacedTenancy()). + Build() + + cache.EXPECT(). + ListIterator(pbcatalog.WorkloadType, "id", &pbresource.ID{ + Type: pbcatalog.WorkloadType, + Name: "api-", + Tenancy: resource.DefaultNamespacedTenancy(), + }, index.IndexQueryOptions{Prefix: true}). + Return(mres, nil). + Once() + cache.EXPECT(). + Get(pbcatalog.WorkloadType, "id", &pbresource.ID{ + Type: pbcatalog.WorkloadType, + Name: "foo", + Tenancy: resource.DefaultNamespacedTenancy(), + }). + Return(fooRes, nil). + Once() + + mres.EXPECT().Next().Return(api1).Once() + mres.EXPECT().Next().Return(api2).Once() + mres.EXPECT().Next().Return(nil).Once() + + expected := []controller.Request{ + {ID: fooRes.Id}, + {ID: api1.Id}, + {ID: api2.Id}, + } + + reqs, err := MapSelectorToWorkloads[*pbcatalog.Service](context.Background(), rt, svc) + require.NoError(t, err) + prototest.AssertElementsMatch(t, expected, reqs) +} + +func TestMapSelectorToWorkloads_DecodeError(t *testing.T) { + res := resourcetest.Resource(pbcatalog.ServiceType, "foo"). + WithData(t, &pbcatalog.DNSPolicy{}). + Build() + + reqs, err := MapSelectorToWorkloads[*pbcatalog.Service](context.Background(), controller.Runtime{}, res) + require.Nil(t, reqs) + require.Error(t, err) + require.ErrorAs(t, err, &resource.ErrDataParse{}) +} + +func TestMapSelectorToWorkloads_CacheError(t *testing.T) { + cache := cachemock.NewReadOnlyCache(t) + + rt := controller.Runtime{ + Cache: cache, + } + + svc := resourcetest.Resource(pbcatalog.ServiceType, "api"). + WithData(t, &pbcatalog.Service{ + Workloads: &pbcatalog.WorkloadSelector{ + Prefixes: []string{"api-"}, + }, + }). + WithTenancy(resource.DefaultNamespacedTenancy()). + Build() + + cache.EXPECT(). + ListIterator(pbcatalog.WorkloadType, "id", &pbresource.ID{ + Type: pbcatalog.WorkloadType, + Name: "api-", + Tenancy: resource.DefaultNamespacedTenancy(), + }, index.IndexQueryOptions{Prefix: true}). + Return(nil, injectedError). + Once() + + reqs, err := MapSelectorToWorkloads[*pbcatalog.Service](context.Background(), rt, svc) + require.ErrorIs(t, err, injectedError) + require.Nil(t, reqs) +} + +func TestMapWorkloadsToSelectors(t *testing.T) { + cache := cachemock.NewReadOnlyCache(t) + rt := controller.Runtime{ + Cache: cache, + Logger: hclog.NewNullLogger(), + } + + dm := MapWorkloadsToSelectors(pbcatalog.ServiceType, "selected-workloads") + + workload := resourcetest.Resource(pbcatalog.WorkloadType, "api-123"). + WithTenancy(resource.DefaultNamespacedTenancy()). + Build() + + svc1 := resourcetest.Resource(pbcatalog.ServiceType, "foo"). + WithData(t, &pbcatalog.Service{ + Workloads: &pbcatalog.WorkloadSelector{ + Prefixes: []string{"api-"}, + }, + }). + WithTenancy(resource.DefaultNamespacedTenancy()). + Build() + + svc2 := resourcetest.Resource(pbcatalog.ServiceType, "bar"). + WithData(t, &pbcatalog.Service{ + Workloads: &pbcatalog.WorkloadSelector{ + Prefixes: []string{"api-"}, + }, + }). + WithTenancy(resource.DefaultNamespacedTenancy()). + Build() + + mres := indexmock.NewResourceIterator(t) + + cache.EXPECT(). + ParentsIterator(pbcatalog.ServiceType, "selected-workloads", workload.Id). + Return(mres, nil). + Once() + + mres.EXPECT().Next().Return(svc1).Once() + mres.EXPECT().Next().Return(svc2).Once() + mres.EXPECT().Next().Return(nil).Once() + + reqs, err := dm(context.Background(), rt, workload) + require.NoError(t, err) + expected := []controller.Request{ + {ID: svc1.Id}, + {ID: svc2.Id}, + } + prototest.AssertElementsMatch(t, expected, reqs) + +} diff --git a/internal/catalog/internal/types/workload_selecting.go b/internal/catalog/workloadselector/selecting.go similarity index 93% rename from internal/catalog/internal/types/workload_selecting.go rename to internal/catalog/workloadselector/selecting.go index 6d129bfaa6..d243b4c6ff 100644 --- a/internal/catalog/internal/types/workload_selecting.go +++ b/internal/catalog/workloadselector/selecting.go @@ -1,7 +1,7 @@ // Copyright (c) HashiCorp, Inc. // SPDX-License-Identifier: BUSL-1.1 -package types +package workloadselector import ( "google.golang.org/protobuf/proto" diff --git a/internal/controller/cache/decoded.go b/internal/controller/cache/decoded.go new file mode 100644 index 0000000000..76c6b2d7c9 --- /dev/null +++ b/internal/controller/cache/decoded.go @@ -0,0 +1,107 @@ +// Copyright (c) HashiCorp, Inc. +// SPDX-License-Identifier: BUSL-1.1 + +package cache + +import ( + "google.golang.org/protobuf/proto" + + "github.com/hashicorp/consul/internal/resource" + "github.com/hashicorp/consul/proto-public/pbresource" +) + +// Get retrieves a single resource from the specified index that matches the provided args. +// If more than one match is found the first is returned. +func GetDecoded[T proto.Message](c ReadOnlyCache, it *pbresource.Type, indexName string, args ...any) (*resource.DecodedResource[T], error) { + res, err := c.Get(it, indexName, args...) + if err != nil { + return nil, err + } + + if res == nil { + return nil, nil + } + + return resource.Decode[T](res) +} + +// List retrieves all the resources from the specified index matching the provided args. +func ListDecoded[T proto.Message](c ReadOnlyCache, it *pbresource.Type, indexName string, args ...any) ([]*resource.DecodedResource[T], error) { + resources, err := c.List(it, indexName, args...) + if err != nil { + return nil, err + } + + return resource.DecodeList[T](resources) +} + +// ListIterator retrieves an iterator over all resources from the specified index matching the provided args. +func ListIteratorDecoded[T proto.Message](c ReadOnlyCache, it *pbresource.Type, indexName string, args ...any) (DecodedResourceIterator[T], error) { + iter, err := c.ListIterator(it, indexName, args...) + if err != nil { + return nil, err + } + + if iter == nil { + return nil, nil + } + + return decodedResourceIterator[T]{iter}, nil +} + +// Parents retrieves all resources whos index value is a parent (or prefix) of the value calculated +// from the provided args. +func ParentsDecoded[T proto.Message](c ReadOnlyCache, it *pbresource.Type, indexName string, args ...any) ([]*resource.DecodedResource[T], error) { + resources, err := c.Parents(it, indexName, args...) + if err != nil { + return nil, err + } + + return resource.DecodeList[T](resources) +} + +// ParentsIterator retrieves an iterator over all resources whos index value is a parent (or prefix) +// of the value calculated from the provided args. +func ParentsIteratorDecoded[T proto.Message](c ReadOnlyCache, it *pbresource.Type, indexName string, args ...any) (DecodedResourceIterator[T], error) { + iter, err := c.ParentsIterator(it, indexName, args...) + if err != nil { + return nil, err + } + + if iter == nil { + return nil, nil + } + + return decodedResourceIterator[T]{iter}, nil +} + +// Query will execute a named query against the cache and return an interator over its results +func QueryDecoded[T proto.Message](c ReadOnlyCache, name string, args ...any) (DecodedResourceIterator[T], error) { + iter, err := c.Query(name, args...) + if err != nil { + return nil, err + } + + if iter == nil { + return nil, nil + } + + return decodedResourceIterator[T]{iter}, nil +} + +type DecodedResourceIterator[T proto.Message] interface { + Next() (*resource.DecodedResource[T], error) +} + +type decodedResourceIterator[T proto.Message] struct { + ResourceIterator +} + +func (iter decodedResourceIterator[T]) Next() (*resource.DecodedResource[T], error) { + res := iter.ResourceIterator.Next() + if res == nil { + return nil, nil + } + + return resource.Decode[T](res) +} diff --git a/internal/controller/cache/decoded_test.go b/internal/controller/cache/decoded_test.go new file mode 100644 index 0000000000..81ab62d80f --- /dev/null +++ b/internal/controller/cache/decoded_test.go @@ -0,0 +1,360 @@ +// Copyright (c) HashiCorp, Inc. +// SPDX-License-Identifier: BUSL-1.1 + +package cache_test + +import ( + "testing" + + "github.com/stretchr/testify/require" + "github.com/stretchr/testify/suite" + + "github.com/hashicorp/consul/internal/controller/cache" + "github.com/hashicorp/consul/internal/controller/cache/cachemock" + "github.com/hashicorp/consul/internal/resource" + "github.com/hashicorp/consul/internal/resource/demo" + "github.com/hashicorp/consul/proto-public/pbresource" + pbdemo "github.com/hashicorp/consul/proto/private/pbdemo/v2" + "github.com/hashicorp/consul/proto/private/prototest" +) + +type decodedSuite struct { + suite.Suite + rc *cachemock.ReadOnlyCache + iter *cachemock.ResourceIterator + + artistGood *resource.DecodedResource[*pbdemo.Artist] + artistGood2 *resource.DecodedResource[*pbdemo.Artist] + artistBad *pbresource.Resource +} + +func (suite *decodedSuite) SetupTest() { + suite.rc = cachemock.NewReadOnlyCache(suite.T()) + suite.iter = cachemock.NewResourceIterator(suite.T()) + artist, err := demo.GenerateV2Artist() + require.NoError(suite.T(), err) + suite.artistGood, err = resource.Decode[*pbdemo.Artist](artist) + require.NoError(suite.T(), err) + + artist2, err := demo.GenerateV2Artist() + require.NoError(suite.T(), err) + suite.artistGood2, err = resource.Decode[*pbdemo.Artist](artist2) + require.NoError(suite.T(), err) + + suite.artistBad, err = demo.GenerateV2Album(artist.Id) + require.NoError(suite.T(), err) +} + +func (suite *decodedSuite) TestGetDecoded_Ok() { + suite.rc.EXPECT().Get(pbdemo.ArtistType, "id", suite.artistGood.Id).Return(suite.artistGood.Resource, nil) + + dec, err := cache.GetDecoded[*pbdemo.Artist](suite.rc, pbdemo.ArtistType, "id", suite.artistGood.Id) + require.NoError(suite.T(), err) + prototest.AssertDeepEqual(suite.T(), suite.artistGood.Resource, dec.Resource) + prototest.AssertDeepEqual(suite.T(), suite.artistGood.Data, dec.Data) +} + +func (suite *decodedSuite) TestGetDecoded_DecodeError() { + suite.rc.EXPECT().Get(pbdemo.ArtistType, "id", suite.artistGood.Id).Return(suite.artistBad, nil) + + dec, err := cache.GetDecoded[*pbdemo.Artist](suite.rc, pbdemo.ArtistType, "id", suite.artistGood.Id) + require.Error(suite.T(), err) + require.Nil(suite.T(), dec) +} + +func (suite *decodedSuite) TestGetDecoded_CacheError() { + suite.rc.EXPECT().Get(pbdemo.ArtistType, "id", suite.artistGood.Id).Return(nil, injectedError) + + dec, err := cache.GetDecoded[*pbdemo.Artist](suite.rc, pbdemo.ArtistType, "id", suite.artistGood.Id) + require.ErrorIs(suite.T(), err, injectedError) + require.Nil(suite.T(), dec) +} + +func (suite *decodedSuite) TestGetDecoded_Nil() { + suite.rc.EXPECT().Get(pbdemo.ArtistType, "id", suite.artistGood.Id).Return(nil, nil) + + dec, err := cache.GetDecoded[*pbdemo.Artist](suite.rc, pbdemo.ArtistType, "id", suite.artistGood.Id) + require.NoError(suite.T(), err) + require.Nil(suite.T(), dec) +} + +func (suite *decodedSuite) TestListDecoded_Ok() { + suite.rc.EXPECT().List(pbdemo.ArtistType, "id", suite.artistGood.Id). + Return([]*pbresource.Resource{suite.artistGood.Resource, suite.artistGood2.Resource}, nil) + + dec, err := cache.ListDecoded[*pbdemo.Artist](suite.rc, pbdemo.ArtistType, "id", suite.artistGood.Id) + require.NoError(suite.T(), err) + require.Len(suite.T(), dec, 2) + prototest.AssertDeepEqual(suite.T(), suite.artistGood.Resource, dec[0].Resource) + prototest.AssertDeepEqual(suite.T(), suite.artistGood.Data, dec[0].Data) + prototest.AssertDeepEqual(suite.T(), suite.artistGood2.Resource, dec[1].Resource) + prototest.AssertDeepEqual(suite.T(), suite.artistGood2.Data, dec[1].Data) +} + +func (suite *decodedSuite) TestListDecoded_DecodeError() { + suite.rc.EXPECT().List(pbdemo.ArtistType, "id", suite.artistGood.Id). + Return([]*pbresource.Resource{suite.artistGood.Resource, suite.artistBad}, nil) + + dec, err := cache.ListDecoded[*pbdemo.Artist](suite.rc, pbdemo.ArtistType, "id", suite.artistGood.Id) + require.Error(suite.T(), err) + require.Nil(suite.T(), dec) +} + +func (suite *decodedSuite) TestListDecoded_CacheError() { + suite.rc.EXPECT().List(pbdemo.ArtistType, "id", suite.artistGood.Id).Return(nil, injectedError) + + dec, err := cache.ListDecoded[*pbdemo.Artist](suite.rc, pbdemo.ArtistType, "id", suite.artistGood.Id) + require.ErrorIs(suite.T(), err, injectedError) + require.Nil(suite.T(), dec) +} + +func (suite *decodedSuite) TestListDecoded_Nil() { + suite.rc.EXPECT().List(pbdemo.ArtistType, "id", suite.artistGood.Id).Return(nil, nil) + + dec, err := cache.ListDecoded[*pbdemo.Artist](suite.rc, pbdemo.ArtistType, "id", suite.artistGood.Id) + require.NoError(suite.T(), err) + require.Nil(suite.T(), dec) +} + +func (suite *decodedSuite) TestListIteratorDecoded_Ok() { + suite.iter.EXPECT().Next().Return(suite.artistGood.Resource).Once() + suite.iter.EXPECT().Next().Return(suite.artistGood2.Resource).Once() + suite.iter.EXPECT().Next().Return(nil).Times(0) + suite.rc.EXPECT().ListIterator(pbdemo.ArtistType, "id", suite.artistGood.Id). + Return(suite.iter, nil) + + iter, err := cache.ListIteratorDecoded[*pbdemo.Artist](suite.rc, pbdemo.ArtistType, "id", suite.artistGood.Id) + require.NoError(suite.T(), err) + require.NotNil(suite.T(), iter) + + dec, err := iter.Next() + require.NoError(suite.T(), err) + prototest.AssertDeepEqual(suite.T(), suite.artistGood.Resource, dec.Resource) + prototest.AssertDeepEqual(suite.T(), suite.artistGood.Data, dec.Data) + + dec, err = iter.Next() + require.NoError(suite.T(), err) + prototest.AssertDeepEqual(suite.T(), suite.artistGood2.Resource, dec.Resource) + prototest.AssertDeepEqual(suite.T(), suite.artistGood2.Data, dec.Data) + + dec, err = iter.Next() + require.NoError(suite.T(), err) + require.Nil(suite.T(), dec) +} + +func (suite *decodedSuite) TestListIteratorDecoded_DecodeError() { + suite.iter.EXPECT().Next().Return(suite.artistGood.Resource).Once() + suite.iter.EXPECT().Next().Return(suite.artistBad).Once() + suite.iter.EXPECT().Next().Return(nil).Times(0) + suite.rc.EXPECT().ListIterator(pbdemo.ArtistType, "id", suite.artistGood.Id). + Return(suite.iter, nil) + + iter, err := cache.ListIteratorDecoded[*pbdemo.Artist](suite.rc, pbdemo.ArtistType, "id", suite.artistGood.Id) + require.NoError(suite.T(), err) + require.NotNil(suite.T(), iter) + + dec, err := iter.Next() + require.NoError(suite.T(), err) + prototest.AssertDeepEqual(suite.T(), suite.artistGood.Resource, dec.Resource) + prototest.AssertDeepEqual(suite.T(), suite.artistGood.Data, dec.Data) + + dec, err = iter.Next() + require.Error(suite.T(), err) + require.Nil(suite.T(), dec) + + dec, err = iter.Next() + require.NoError(suite.T(), err) + require.Nil(suite.T(), dec) +} + +func (suite *decodedSuite) TestListIteratorDecoded_CacheError() { + suite.rc.EXPECT().ListIterator(pbdemo.ArtistType, "id", suite.artistGood.Id).Return(nil, injectedError) + + iter, err := cache.ListIteratorDecoded[*pbdemo.Artist](suite.rc, pbdemo.ArtistType, "id", suite.artistGood.Id) + require.ErrorIs(suite.T(), err, injectedError) + require.Nil(suite.T(), iter) +} + +func (suite *decodedSuite) TestListIteratorDecoded_Nil() { + suite.rc.EXPECT().ListIterator(pbdemo.ArtistType, "id", suite.artistGood.Id).Return(nil, nil) + + dec, err := cache.ListIteratorDecoded[*pbdemo.Artist](suite.rc, pbdemo.ArtistType, "id", suite.artistGood.Id) + require.NoError(suite.T(), err) + require.Nil(suite.T(), dec) +} + +func (suite *decodedSuite) TestParentsDecoded_Ok() { + suite.rc.EXPECT().Parents(pbdemo.ArtistType, "id", suite.artistGood.Id). + Return([]*pbresource.Resource{suite.artistGood.Resource, suite.artistGood2.Resource}, nil) + + dec, err := cache.ParentsDecoded[*pbdemo.Artist](suite.rc, pbdemo.ArtistType, "id", suite.artistGood.Id) + require.NoError(suite.T(), err) + require.Len(suite.T(), dec, 2) + prototest.AssertDeepEqual(suite.T(), suite.artistGood.Resource, dec[0].Resource) + prototest.AssertDeepEqual(suite.T(), suite.artistGood.Data, dec[0].Data) + prototest.AssertDeepEqual(suite.T(), suite.artistGood2.Resource, dec[1].Resource) + prototest.AssertDeepEqual(suite.T(), suite.artistGood2.Data, dec[1].Data) +} + +func (suite *decodedSuite) TestParentsDecoded_DecodeError() { + suite.rc.EXPECT().Parents(pbdemo.ArtistType, "id", suite.artistGood.Id). + Return([]*pbresource.Resource{suite.artistGood.Resource, suite.artistBad}, nil) + + dec, err := cache.ParentsDecoded[*pbdemo.Artist](suite.rc, pbdemo.ArtistType, "id", suite.artistGood.Id) + require.Error(suite.T(), err) + require.Nil(suite.T(), dec) +} + +func (suite *decodedSuite) TestParentsDecoded_CacheError() { + suite.rc.EXPECT().Parents(pbdemo.ArtistType, "id", suite.artistGood.Id).Return(nil, injectedError) + + dec, err := cache.ParentsDecoded[*pbdemo.Artist](suite.rc, pbdemo.ArtistType, "id", suite.artistGood.Id) + require.ErrorIs(suite.T(), err, injectedError) + require.Nil(suite.T(), dec) +} + +func (suite *decodedSuite) TestParentsDecoded_Nil() { + suite.rc.EXPECT().Parents(pbdemo.ArtistType, "id", suite.artistGood.Id).Return(nil, nil) + + dec, err := cache.ParentsDecoded[*pbdemo.Artist](suite.rc, pbdemo.ArtistType, "id", suite.artistGood.Id) + require.NoError(suite.T(), err) + require.Nil(suite.T(), dec) +} + +func (suite *decodedSuite) TestParentsIteratorDecoded_Ok() { + suite.iter.EXPECT().Next().Return(suite.artistGood.Resource).Once() + suite.iter.EXPECT().Next().Return(suite.artistGood2.Resource).Once() + suite.iter.EXPECT().Next().Return(nil).Times(0) + suite.rc.EXPECT().ParentsIterator(pbdemo.ArtistType, "id", suite.artistGood.Id). + Return(suite.iter, nil) + + iter, err := cache.ParentsIteratorDecoded[*pbdemo.Artist](suite.rc, pbdemo.ArtistType, "id", suite.artistGood.Id) + require.NoError(suite.T(), err) + require.NotNil(suite.T(), iter) + + dec, err := iter.Next() + require.NoError(suite.T(), err) + prototest.AssertDeepEqual(suite.T(), suite.artistGood.Resource, dec.Resource) + prototest.AssertDeepEqual(suite.T(), suite.artistGood.Data, dec.Data) + + dec, err = iter.Next() + require.NoError(suite.T(), err) + prototest.AssertDeepEqual(suite.T(), suite.artistGood2.Resource, dec.Resource) + prototest.AssertDeepEqual(suite.T(), suite.artistGood2.Data, dec.Data) + + dec, err = iter.Next() + require.NoError(suite.T(), err) + require.Nil(suite.T(), dec) +} + +func (suite *decodedSuite) TestParentsIteratorDecoded_DecodeError() { + suite.iter.EXPECT().Next().Return(suite.artistGood.Resource).Once() + suite.iter.EXPECT().Next().Return(suite.artistBad).Once() + suite.iter.EXPECT().Next().Return(nil).Times(0) + suite.rc.EXPECT().ParentsIterator(pbdemo.ArtistType, "id", suite.artistGood.Id). + Return(suite.iter, nil) + + iter, err := cache.ParentsIteratorDecoded[*pbdemo.Artist](suite.rc, pbdemo.ArtistType, "id", suite.artistGood.Id) + require.NoError(suite.T(), err) + require.NotNil(suite.T(), iter) + + dec, err := iter.Next() + require.NoError(suite.T(), err) + prototest.AssertDeepEqual(suite.T(), suite.artistGood.Resource, dec.Resource) + prototest.AssertDeepEqual(suite.T(), suite.artistGood.Data, dec.Data) + + dec, err = iter.Next() + require.Error(suite.T(), err) + require.Nil(suite.T(), dec) + + dec, err = iter.Next() + require.NoError(suite.T(), err) + require.Nil(suite.T(), dec) +} + +func (suite *decodedSuite) TestParentsIteratorDecoded_CacheError() { + suite.rc.EXPECT().ParentsIterator(pbdemo.ArtistType, "id", suite.artistGood.Id).Return(nil, injectedError) + + iter, err := cache.ParentsIteratorDecoded[*pbdemo.Artist](suite.rc, pbdemo.ArtistType, "id", suite.artistGood.Id) + require.ErrorIs(suite.T(), err, injectedError) + require.Nil(suite.T(), iter) +} + +func (suite *decodedSuite) TestParentsIteratorDecoded_Nil() { + suite.rc.EXPECT().ParentsIterator(pbdemo.ArtistType, "id", suite.artistGood.Id).Return(nil, nil) + + dec, err := cache.ParentsIteratorDecoded[*pbdemo.Artist](suite.rc, pbdemo.ArtistType, "id", suite.artistGood.Id) + require.NoError(suite.T(), err) + require.Nil(suite.T(), dec) +} + +func (suite *decodedSuite) TestQueryDecoded_Ok() { + suite.iter.EXPECT().Next().Return(suite.artistGood.Resource).Once() + suite.iter.EXPECT().Next().Return(suite.artistGood2.Resource).Once() + suite.iter.EXPECT().Next().Return(nil).Times(0) + suite.rc.EXPECT().Query("query", "blah"). + Return(suite.iter, nil) + + iter, err := cache.QueryDecoded[*pbdemo.Artist](suite.rc, "query", "blah") + require.NoError(suite.T(), err) + require.NotNil(suite.T(), iter) + + dec, err := iter.Next() + require.NoError(suite.T(), err) + prototest.AssertDeepEqual(suite.T(), suite.artistGood.Resource, dec.Resource) + prototest.AssertDeepEqual(suite.T(), suite.artistGood.Data, dec.Data) + + dec, err = iter.Next() + require.NoError(suite.T(), err) + prototest.AssertDeepEqual(suite.T(), suite.artistGood2.Resource, dec.Resource) + prototest.AssertDeepEqual(suite.T(), suite.artistGood2.Data, dec.Data) + + dec, err = iter.Next() + require.NoError(suite.T(), err) + require.Nil(suite.T(), dec) +} + +func (suite *decodedSuite) TestQueryDecoded_DecodeError() { + suite.iter.EXPECT().Next().Return(suite.artistGood.Resource).Once() + suite.iter.EXPECT().Next().Return(suite.artistBad).Once() + suite.iter.EXPECT().Next().Return(nil).Times(0) + suite.rc.EXPECT().Query("query", "blah"). + Return(suite.iter, nil) + + iter, err := cache.QueryDecoded[*pbdemo.Artist](suite.rc, "query", "blah") + require.NoError(suite.T(), err) + require.NotNil(suite.T(), iter) + + dec, err := iter.Next() + require.NoError(suite.T(), err) + prototest.AssertDeepEqual(suite.T(), suite.artistGood.Resource, dec.Resource) + prototest.AssertDeepEqual(suite.T(), suite.artistGood.Data, dec.Data) + + dec, err = iter.Next() + require.Error(suite.T(), err) + require.Nil(suite.T(), dec) + + dec, err = iter.Next() + require.NoError(suite.T(), err) + require.Nil(suite.T(), dec) +} + +func (suite *decodedSuite) TestQueryDecoded_CacheError() { + suite.rc.EXPECT().Query("query", "blah").Return(nil, injectedError) + + dec, err := cache.QueryDecoded[*pbdemo.Artist](suite.rc, "query", "blah") + require.ErrorIs(suite.T(), err, injectedError) + require.Nil(suite.T(), dec) +} + +func (suite *decodedSuite) TestQueryDecoded_Nil() { + suite.rc.EXPECT().Query("query", "blah").Return(nil, nil) + + dec, err := cache.QueryDecoded[*pbdemo.Artist](suite.rc, "query", "blah") + require.NoError(suite.T(), err) + require.Nil(suite.T(), dec) +} + +func TestDecodedCache(t *testing.T) { + suite.Run(t, new(decodedSuite)) +} diff --git a/internal/controller/cache/index/convenience.go b/internal/controller/cache/index/convenience.go index fded6930a0..b3c9dbf86b 100644 --- a/internal/controller/cache/index/convenience.go +++ b/internal/controller/cache/index/convenience.go @@ -11,6 +11,10 @@ import ( "github.com/hashicorp/consul/proto-public/pbresource" ) +type IndexQueryOptions struct { + Prefix bool +} + func IndexFromID(id *pbresource.ID, includeUid bool) []byte { var b Builder b.Raw(IndexFromType(id.Type)) @@ -87,6 +91,14 @@ var PrefixReferenceOrIDFromArgs = SingleValueFromArgs[resource.ReferenceOrID](fu return PrefixIndexFromRefOrID(r), nil }) +var MaybePrefixReferenceOrIDFromArgs = SingleValueFromOneOrTwoArgs[resource.ReferenceOrID, IndexQueryOptions](func(r resource.ReferenceOrID, opts IndexQueryOptions) ([]byte, error) { + if opts.Prefix { + return PrefixIndexFromRefOrID(r), nil + } else { + return IndexFromRefOrID(r), nil + } +}) + func SingleValueFromArgs[T any](indexer func(value T) ([]byte, error)) func(args ...any) ([]byte, error) { return func(args ...any) ([]byte, error) { var zero T diff --git a/internal/controller/cache/indexers/id_indexer.go b/internal/controller/cache/indexers/id_indexer.go index f45358f40c..f50d609446 100644 --- a/internal/controller/cache/indexers/id_indexer.go +++ b/internal/controller/cache/indexers/id_indexer.go @@ -38,7 +38,7 @@ type idOrRefIndexer struct { // FromArgs constructs a radix tree key from an ID for lookup. func (i idOrRefIndexer) FromArgs(args ...any) ([]byte, error) { - return index.ReferenceOrIDFromArgs(args...) + return index.MaybePrefixReferenceOrIDFromArgs(args...) } // FromObject constructs a radix tree key from a Resource at write-time, or an diff --git a/internal/mesh/internal/mappers/workloadselectionmapper/workload_selection_mapper.go b/internal/mesh/internal/mappers/workloadselectionmapper/workload_selection_mapper.go index 7b06424841..7cbc35e57f 100644 --- a/internal/mesh/internal/mappers/workloadselectionmapper/workload_selection_mapper.go +++ b/internal/mesh/internal/mappers/workloadselectionmapper/workload_selection_mapper.go @@ -6,7 +6,7 @@ package workloadselectionmapper import ( "context" - "github.com/hashicorp/consul/internal/catalog" + "github.com/hashicorp/consul/internal/catalog/workloadselector" "github.com/hashicorp/consul/internal/controller" "github.com/hashicorp/consul/internal/mesh/internal/mappers/common" "github.com/hashicorp/consul/internal/resource" @@ -15,12 +15,12 @@ import ( "github.com/hashicorp/consul/proto-public/pbresource" ) -type Mapper[T catalog.WorkloadSelecting] struct { +type Mapper[T workloadselector.WorkloadSelecting] struct { workloadSelectionTracker *selectiontracker.WorkloadSelectionTracker computedType *pbresource.Type } -func New[T catalog.WorkloadSelecting](computedType *pbresource.Type) *Mapper[T] { +func New[T workloadselector.WorkloadSelecting](computedType *pbresource.Type) *Mapper[T] { if computedType == nil { panic("computed type is required") } diff --git a/internal/mesh/internal/types/destinations.go b/internal/mesh/internal/types/destinations.go index 7de3011e3e..a128631195 100644 --- a/internal/mesh/internal/types/destinations.go +++ b/internal/mesh/internal/types/destinations.go @@ -10,6 +10,7 @@ import ( "google.golang.org/protobuf/proto" "github.com/hashicorp/consul/internal/catalog" + "github.com/hashicorp/consul/internal/catalog/workloadselector" "github.com/hashicorp/consul/internal/resource" pbmesh "github.com/hashicorp/consul/proto-public/pbmesh/v2beta1" "github.com/hashicorp/consul/proto-public/pbresource" @@ -22,7 +23,7 @@ func RegisterDestinations(r resource.Registry) { Scope: resource.ScopeNamespace, Mutate: MutateDestinations, Validate: ValidateDestinations, - ACLs: catalog.ACLHooksForWorkloadSelectingType[*pbmesh.Destinations](), + ACLs: workloadselector.ACLHooks[*pbmesh.Destinations](), }) } diff --git a/internal/mesh/internal/types/destinations_configuration.go b/internal/mesh/internal/types/destinations_configuration.go index 7d46d93ed9..2b6e1f2c75 100644 --- a/internal/mesh/internal/types/destinations_configuration.go +++ b/internal/mesh/internal/types/destinations_configuration.go @@ -7,6 +7,7 @@ import ( "github.com/hashicorp/go-multierror" "github.com/hashicorp/consul/internal/catalog" + "github.com/hashicorp/consul/internal/catalog/workloadselector" "github.com/hashicorp/consul/internal/resource" pbmesh "github.com/hashicorp/consul/proto-public/pbmesh/v2beta1" @@ -18,7 +19,7 @@ func RegisterDestinationsConfiguration(r resource.Registry) { Proto: &pbmesh.DestinationsConfiguration{}, Scope: resource.ScopeNamespace, Validate: ValidateDestinationsConfiguration, - ACLs: catalog.ACLHooksForWorkloadSelectingType[*pbmesh.DestinationsConfiguration](), + ACLs: workloadselector.ACLHooks[*pbmesh.DestinationsConfiguration](), }) } diff --git a/internal/mesh/internal/types/proxy_configuration.go b/internal/mesh/internal/types/proxy_configuration.go index 9a4388a40f..4ab849a6f9 100644 --- a/internal/mesh/internal/types/proxy_configuration.go +++ b/internal/mesh/internal/types/proxy_configuration.go @@ -7,6 +7,7 @@ import ( "math" "github.com/hashicorp/consul/internal/catalog" + "github.com/hashicorp/consul/internal/catalog/workloadselector" "github.com/hashicorp/go-multierror" @@ -22,7 +23,7 @@ func RegisterProxyConfiguration(r resource.Registry) { Scope: resource.ScopeNamespace, Mutate: MutateProxyConfiguration, Validate: ValidateProxyConfiguration, - ACLs: catalog.ACLHooksForWorkloadSelectingType[*pbmesh.ProxyConfiguration](), + ACLs: workloadselector.ACLHooks[*pbmesh.ProxyConfiguration](), }) } diff --git a/internal/resource/decode.go b/internal/resource/decode.go index 2b96853f1e..461fd9376e 100644 --- a/internal/resource/decode.go +++ b/internal/resource/decode.go @@ -58,6 +58,22 @@ func Decode[T proto.Message](res *pbresource.Resource) (*DecodedResource[T], err }, nil } +// DecodeList will generically decode the provided resource list into a list of 2-field +// structures that holds onto the original Resource and the decoded contents. +// +// Returns an ErrDataParse on unmarshalling errors. +func DecodeList[T proto.Message](resources []*pbresource.Resource) ([]*DecodedResource[T], error) { + var decoded []*DecodedResource[T] + for _, res := range resources { + d, err := Decode[T](res) + if err != nil { + return nil, err + } + decoded = append(decoded, d) + } + return decoded, nil +} + // GetDecodedResource will generically read the requested resource using the // client and either return nil on a NotFound or decode the response value. func GetDecodedResource[T proto.Message](ctx context.Context, client pbresource.ResourceServiceClient, id *pbresource.ID) (*DecodedResource[T], error) { diff --git a/internal/resource/decode_test.go b/internal/resource/decode_test.go index 31db247c70..9df601c901 100644 --- a/internal/resource/decode_test.go +++ b/internal/resource/decode_test.go @@ -112,3 +112,50 @@ func TestDecode(t *testing.T) { require.Error(t, err) }) } + +func TestDecodeList(t *testing.T) { + t.Run("good", func(t *testing.T) { + artist1, err := demo.GenerateV2Artist() + require.NoError(t, err) + artist2, err := demo.GenerateV2Artist() + require.NoError(t, err) + dec1, err := resource.Decode[*pbdemo.Artist](artist1) + require.NoError(t, err) + dec2, err := resource.Decode[*pbdemo.Artist](artist2) + require.NoError(t, err) + + resources := []*pbresource.Resource{artist1, artist2} + + decList, err := resource.DecodeList[*pbdemo.Artist](resources) + require.NoError(t, err) + require.Len(t, decList, 2) + + prototest.AssertDeepEqual(t, dec1.Resource, decList[0].Resource) + prototest.AssertDeepEqual(t, dec1.Data, decList[0].Data) + prototest.AssertDeepEqual(t, dec2.Resource, decList[1].Resource) + prototest.AssertDeepEqual(t, dec2.Data, decList[1].Data) + }) + + t.Run("bad", func(t *testing.T) { + artist1, err := demo.GenerateV2Artist() + require.NoError(t, err) + + foo := &pbresource.Resource{ + Id: &pbresource.ID{ + Type: demo.TypeV2Artist, + Tenancy: resource.DefaultNamespacedTenancy(), + Name: "babypants", + }, + Data: &anypb.Any{ + TypeUrl: "garbage", + Value: []byte("more garbage"), + }, + Metadata: map[string]string{ + "generated_at": time.Now().Format(time.RFC3339), + }, + } + + _, err = resource.DecodeList[*pbdemo.Artist]([]*pbresource.Resource{artist1, foo}) + require.Error(t, err) + }) +} diff --git a/internal/resource/filter.go b/internal/resource/filter.go index 44a3689293..09e251e218 100644 --- a/internal/resource/filter.go +++ b/internal/resource/filter.go @@ -11,6 +11,10 @@ import ( "github.com/hashicorp/consul/proto-public/pbresource" ) +type MetadataFilterableResources interface { + GetMetadata() map[string]string +} + // FilterResourcesByMetadata will use the provided go-bexpr based filter to // retain matching items from the provided slice. // @@ -18,7 +22,7 @@ import ( // by "metadata." // // If no filter is provided, then this does nothing and returns the input. -func FilterResourcesByMetadata(resources []*pbresource.Resource, filter string) ([]*pbresource.Resource, error) { +func FilterResourcesByMetadata[T MetadataFilterableResources](resources []T, filter string) ([]T, error) { if filter == "" || len(resources) == 0 { return resources, nil } @@ -28,10 +32,10 @@ func FilterResourcesByMetadata(resources []*pbresource.Resource, filter string) return nil, err } - filtered := make([]*pbresource.Resource, 0, len(resources)) + filtered := make([]T, 0, len(resources)) for _, res := range resources { vars := &metadataFilterFieldDetails{ - Meta: res.Metadata, + Meta: res.GetMetadata(), } match, err := eval.Evaluate(vars) if err != nil { diff --git a/internal/resource/resourcetest/decode.go b/internal/resource/resourcetest/decode.go index 109ad39ceb..c84f56df4e 100644 --- a/internal/resource/resourcetest/decode.go +++ b/internal/resource/resourcetest/decode.go @@ -16,3 +16,9 @@ func MustDecode[Tp proto.Message](t T, res *pbresource.Resource) *resource.Decod require.NoError(t, err) return dec } + +func MustDecodeList[Tp proto.Message](t T, resources []*pbresource.Resource) []*resource.DecodedResource[Tp] { + dec, err := resource.DecodeList[Tp](resources) + require.NoError(t, err) + return dec +}