From 42efc11b4eaf33551b659a1144d506ee757649e7 Mon Sep 17 00:00:00 2001 From: "R.B. Boyer" <4903+rboyer@users.noreply.github.com> Date: Wed, 9 Aug 2023 11:02:17 -0500 Subject: [PATCH] catalog: adding a controller to reconcile FailoverPolicy resources (#18399) Add most of the semantic cross-resource validation for FailoverPolicy resources using a new controller. --- internal/catalog/exports.go | 25 ++ .../controllers/failover/controller.go | 276 ++++++++++++++++++ .../controllers/failover/controller_test.go | 268 +++++++++++++++++ .../internal/controllers/failover/status.go | 84 ++++++ .../catalog/internal/controllers/register.go | 3 + .../mappers/failovermapper/failover_mapper.go | 60 ++++ .../failovermapper/failover_mapper_test.go | 190 ++++++++++++ internal/resource/decode.go | 21 ++ internal/resource/decode_test.go | 47 +++ proto/private/prototest/testing.go | 1 + 10 files changed, 975 insertions(+) create mode 100644 internal/catalog/internal/controllers/failover/controller.go create mode 100644 internal/catalog/internal/controllers/failover/controller_test.go create mode 100644 internal/catalog/internal/controllers/failover/status.go create mode 100644 internal/catalog/internal/mappers/failovermapper/failover_mapper.go create mode 100644 internal/catalog/internal/mappers/failovermapper/failover_mapper_test.go diff --git a/internal/catalog/exports.go b/internal/catalog/exports.go index 463b11e16e..d5723ddb38 100644 --- a/internal/catalog/exports.go +++ b/internal/catalog/exports.go @@ -6,14 +6,17 @@ package catalog import ( "github.com/hashicorp/consul/internal/catalog/internal/controllers" "github.com/hashicorp/consul/internal/catalog/internal/controllers/endpoints" + "github.com/hashicorp/consul/internal/catalog/internal/controllers/failover" "github.com/hashicorp/consul/internal/catalog/internal/controllers/nodehealth" "github.com/hashicorp/consul/internal/catalog/internal/controllers/workloadhealth" + "github.com/hashicorp/consul/internal/catalog/internal/mappers/failovermapper" "github.com/hashicorp/consul/internal/catalog/internal/mappers/nodemapper" "github.com/hashicorp/consul/internal/catalog/internal/mappers/selectiontracker" "github.com/hashicorp/consul/internal/catalog/internal/types" "github.com/hashicorp/consul/internal/controller" "github.com/hashicorp/consul/internal/resource" pbcatalog "github.com/hashicorp/consul/proto-public/pbcatalog/v1alpha1" + "github.com/hashicorp/consul/proto-public/pbresource" ) var ( @@ -73,6 +76,15 @@ var ( EndpointsStatusConditionEndpointsManaged = endpoints.StatusConditionEndpointsManaged EndpointsStatusConditionManaged = endpoints.ConditionManaged EndpointsStatusConditionUnmanaged = endpoints.ConditionUnmanaged + + FailoverStatusKey = failover.StatusKey + FailoverStatusConditionAccepted = failover.StatusConditionAccepted + FailoverStatusConditionAcceptedOKReason = failover.OKReason + FailoverStatusConditionAcceptedMissingServiceReason = failover.MissingServiceReason + FailoverStatusConditionAcceptedUnknownPortReason = failover.UnknownPortReason + FailoverStatusConditionAcceptedMissingDestinationServiceReason = failover.MissingDestinationServiceReason + FailoverStatusConditionAcceptedUnknownDestinationPortReason = failover.UnknownDestinationPortReason + FailoverStatusConditionAcceptedUsingMeshDestinationPortReason = failover.UsingMeshDestinationPortReason ) // RegisterTypes adds all resource types within the "catalog" API group @@ -87,6 +99,7 @@ func DefaultControllerDependencies() ControllerDependencies { return ControllerDependencies{ WorkloadHealthNodeMapper: nodemapper.New(), EndpointsWorkloadMapper: selectiontracker.New(), + FailoverMapper: failovermapper.New(), } } @@ -101,3 +114,15 @@ func RegisterControllers(mgr *controller.Manager, deps ControllerDependencies) { func SimplifyFailoverPolicy(svc *pbcatalog.Service, failover *pbcatalog.FailoverPolicy) *pbcatalog.FailoverPolicy { return types.SimplifyFailoverPolicy(svc, failover) } + +// FailoverPolicyMapper maintains the bidirectional tracking relationship of a +// FailoverPolicy to the Services related to it. +type FailoverPolicyMapper interface { + TrackFailover(failover *resource.DecodedResource[pbcatalog.FailoverPolicy, *pbcatalog.FailoverPolicy]) + UntrackFailover(failoverID *pbresource.ID) + FailoverIDsByService(svcID *pbresource.ID) []*pbresource.ID +} + +func NewFailoverPolicyMapper() FailoverPolicyMapper { + return failovermapper.New() +} diff --git a/internal/catalog/internal/controllers/failover/controller.go b/internal/catalog/internal/controllers/failover/controller.go new file mode 100644 index 0000000000..ecb04b6d8f --- /dev/null +++ b/internal/catalog/internal/controllers/failover/controller.go @@ -0,0 +1,276 @@ +// Copyright (c) HashiCorp, Inc. +// SPDX-License-Identifier: MPL-2.0 + +package failover + +import ( + "context" + + "github.com/hashicorp/consul/internal/catalog/internal/types" + "github.com/hashicorp/consul/internal/controller" + "github.com/hashicorp/consul/internal/resource" + pbcatalog "github.com/hashicorp/consul/proto-public/pbcatalog/v1alpha1" + "github.com/hashicorp/consul/proto-public/pbresource" +) + +// FailoverMapper tracks the relationship between a FailoverPolicy an a Service +// it references whether due to name-alignment or from a reference in a +// FailoverDestination leg. +type FailoverMapper interface { + // TrackFailover extracts all Service references from the provided + // FailoverPolicy and indexes them so that MapService can turn Service + // events into FailoverPolicy events properly. + TrackFailover(failover *resource.DecodedResource[pbcatalog.FailoverPolicy, *pbcatalog.FailoverPolicy]) + + // UntrackFailover forgets the links inserted by TrackFailover for the + // provided FailoverPolicyID. + UntrackFailover(failoverID *pbresource.ID) + + // MapService will take a Service resource and return controller requests + // for all FailoverPolicies associated with the Service. + MapService(ctx context.Context, rt controller.Runtime, res *pbresource.Resource) ([]controller.Request, error) +} + +func FailoverPolicyController(mapper FailoverMapper) controller.Controller { + if mapper == nil { + panic("No FailoverMapper was provided to the FailoverPolicyController constructor") + } + return controller.ForType(types.FailoverPolicyType). + WithWatch(types.ServiceType, mapper.MapService). + WithReconciler(newFailoverPolicyReconciler(mapper)) +} + +type failoverPolicyReconciler struct { + mapper FailoverMapper +} + +func newFailoverPolicyReconciler(mapper FailoverMapper) *failoverPolicyReconciler { + return &failoverPolicyReconciler{ + mapper: mapper, + } +} + +func (r *failoverPolicyReconciler) Reconcile(ctx context.Context, rt controller.Runtime, req controller.Request) error { + // The runtime is passed by value so replacing it here for the remainder of this + // reconciliation request processing will not affect future invocations. + rt.Logger = rt.Logger.With("resource-id", req.ID, "controller", StatusKey) + + rt.Logger.Trace("reconciling failover policy") + + failoverPolicyID := req.ID + + failoverPolicy, err := getFailoverPolicy(ctx, rt, failoverPolicyID) + if err != nil { + rt.Logger.Error("error retrieving failover policy", "error", err) + return err + } + if failoverPolicy == nil { + r.mapper.UntrackFailover(failoverPolicyID) + + // Either the failover policy was deleted, or it doesn't exist but an + // update to a Service came through and we can ignore it. + return nil + } + + r.mapper.TrackFailover(failoverPolicy) + + // FailoverPolicy is name-aligned with the Service it controls. + serviceID := &pbresource.ID{ + Type: types.ServiceType, + Tenancy: failoverPolicyID.Tenancy, + Name: failoverPolicyID.Name, + } + + service, err := getService(ctx, rt, serviceID) + if err != nil { + rt.Logger.Error("error retrieving corresponding service", "error", err) + return err + } + destServices := make(map[resource.ReferenceKey]*resource.DecodedResource[pbcatalog.Service, *pbcatalog.Service]) + if service != nil { + destServices[resource.NewReferenceKey(serviceID)] = service + } + + // Denorm the ports and stuff. After this we have no empty ports. + if service != nil { + failoverPolicy.Data = types.SimplifyFailoverPolicy( + service.Data, + failoverPolicy.Data, + ) + } + + // Fetch services. + for _, dest := range failoverPolicy.Data.GetUnderlyingDestinations() { + if dest.Ref == nil || !isServiceType(dest.Ref.Type) || dest.Ref.Section != "" { + continue // invalid, not possible due to validation hook + } + + key := resource.NewReferenceKey(dest.Ref) + + if _, ok := destServices[key]; ok { + continue + } + + destID := resource.IDFromReference(dest.Ref) + + destService, err := getService(ctx, rt, destID) + if err != nil { + rt.Logger.Error("error retrieving destination service", "service", key, "error", err) + return err + } + + if destService != nil { + destServices[key] = destService + } + } + + newStatus := computeNewStatus(failoverPolicy, service, destServices) + + if resource.EqualStatus(failoverPolicy.Resource.Status[StatusKey], newStatus, false) { + rt.Logger.Trace("resource's failover policy status is unchanged", + "conditions", newStatus.Conditions) + return nil + } + + _, err = rt.Client.WriteStatus(ctx, &pbresource.WriteStatusRequest{ + Id: failoverPolicy.Resource.Id, + Key: StatusKey, + Status: newStatus, + }) + + if err != nil { + rt.Logger.Error("error encountered when attempting to update the resource's failover policy status", "error", err) + return err + } + + rt.Logger.Trace("resource's failover policy status was updated", + "conditions", newStatus.Conditions) + return nil +} + +func getFailoverPolicy(ctx context.Context, rt controller.Runtime, id *pbresource.ID) (*resource.DecodedResource[pbcatalog.FailoverPolicy, *pbcatalog.FailoverPolicy], error) { + return resource.GetDecodedResource[pbcatalog.FailoverPolicy, *pbcatalog.FailoverPolicy](ctx, rt.Client, id) +} + +func getService(ctx context.Context, rt controller.Runtime, id *pbresource.ID) (*resource.DecodedResource[pbcatalog.Service, *pbcatalog.Service], error) { + return resource.GetDecodedResource[pbcatalog.Service, *pbcatalog.Service](ctx, rt.Client, id) +} + +func computeNewStatus( + failoverPolicy *resource.DecodedResource[pbcatalog.FailoverPolicy, *pbcatalog.FailoverPolicy], + service *resource.DecodedResource[pbcatalog.Service, *pbcatalog.Service], + destServices map[resource.ReferenceKey]*resource.DecodedResource[pbcatalog.Service, *pbcatalog.Service], +) *pbresource.Status { + if service == nil { + return &pbresource.Status{ + ObservedGeneration: failoverPolicy.Resource.Generation, + Conditions: []*pbresource.Condition{ + ConditionMissingService, + }, + } + } + + allowedPortProtocols := make(map[string]pbcatalog.Protocol) + for _, port := range service.Data.Ports { + if port.Protocol == pbcatalog.Protocol_PROTOCOL_MESH { + continue // skip + } + allowedPortProtocols[port.TargetPort] = port.Protocol + } + + var conditions []*pbresource.Condition + + if failoverPolicy.Data.Config != nil { + for _, dest := range failoverPolicy.Data.Config.Destinations { + // We know from validation that a Ref must be set, and the type it + // points to is a Service. + // + // Rather than do additional validation, just do a quick + // belt-and-suspenders check-and-skip if something looks weird. + if dest.Ref == nil || !isServiceType(dest.Ref.Type) { + continue + } + + if cond := serviceHasPort(dest, destServices); cond != nil { + conditions = append(conditions, cond) + } + } + // TODO: validate that referenced sameness groups exist + } + + for port, pc := range failoverPolicy.Data.PortConfigs { + if _, ok := allowedPortProtocols[port]; !ok { + conditions = append(conditions, ConditionUnknownPort(port)) + } + + for _, dest := range pc.Destinations { + // We know from validation that a Ref must be set, and the type it + // points to is a Service. + // + // Rather than do additional validation, just do a quick + // belt-and-suspenders check-and-skip if something looks weird. + if dest.Ref == nil || !isServiceType(dest.Ref.Type) { + continue + } + + if cond := serviceHasPort(dest, destServices); cond != nil { + conditions = append(conditions, cond) + } + } + + // TODO: validate that referenced sameness groups exist + } + + if len(conditions) > 0 { + return &pbresource.Status{ + ObservedGeneration: failoverPolicy.Resource.Generation, + Conditions: conditions, + } + } + + return &pbresource.Status{ + ObservedGeneration: failoverPolicy.Resource.Generation, + Conditions: []*pbresource.Condition{ + ConditionOK, + }, + } +} + +func serviceHasPort( + dest *pbcatalog.FailoverDestination, + destServices map[resource.ReferenceKey]*resource.DecodedResource[pbcatalog.Service, *pbcatalog.Service], +) *pbresource.Condition { + key := resource.NewReferenceKey(dest.Ref) + destService, ok := destServices[key] + if !ok { + return ConditionMissingDestinationService(dest.Ref) + } + + found := false + mesh := false + for _, port := range destService.Data.Ports { + if port.TargetPort == dest.Port { + found = true + if port.Protocol == pbcatalog.Protocol_PROTOCOL_MESH { + mesh = true + } + break + } + } + + if !found { + return ConditionUnknownDestinationPort(dest.Ref, dest.Port) + } else if mesh { + return ConditionUsingMeshDestinationPort(dest.Ref, dest.Port) + } + + return nil +} + +func isServiceType(typ *pbresource.Type) bool { + switch { + case resource.EqualType(typ, types.ServiceType): + return true + } + return false +} diff --git a/internal/catalog/internal/controllers/failover/controller_test.go b/internal/catalog/internal/controllers/failover/controller_test.go new file mode 100644 index 0000000000..a53a9f8af4 --- /dev/null +++ b/internal/catalog/internal/controllers/failover/controller_test.go @@ -0,0 +1,268 @@ +// Copyright (c) HashiCorp, Inc. +// SPDX-License-Identifier: MPL-2.0 + +package failover + +import ( + "context" + "testing" + + "github.com/stretchr/testify/suite" + + svctest "github.com/hashicorp/consul/agent/grpc-external/services/resource/testing" + "github.com/hashicorp/consul/internal/catalog/internal/mappers/failovermapper" + "github.com/hashicorp/consul/internal/catalog/internal/types" + "github.com/hashicorp/consul/internal/controller" + "github.com/hashicorp/consul/internal/resource" + rtest "github.com/hashicorp/consul/internal/resource/resourcetest" + pbcatalog "github.com/hashicorp/consul/proto-public/pbcatalog/v1alpha1" + "github.com/hashicorp/consul/sdk/testutil" +) + +type controllerSuite struct { + suite.Suite + + ctx context.Context + client *rtest.Client + rt controller.Runtime + + failoverMapper FailoverMapper + + ctl failoverPolicyReconciler +} + +func (suite *controllerSuite) SetupTest() { + suite.ctx = testutil.TestContext(suite.T()) + client := svctest.RunResourceService(suite.T(), types.Register) + suite.rt = controller.Runtime{ + Client: client, + Logger: testutil.Logger(suite.T()), + } + suite.client = rtest.NewClient(client) + + suite.failoverMapper = failovermapper.New() +} + +func (suite *controllerSuite) TestController() { + // This test's purpose is to exercise the controller in a halfway realistic + // way, verifying the event triggers work in the live code. + + // Run the controller manager + mgr := controller.NewManager(suite.client, suite.rt.Logger) + mgr.Register(FailoverPolicyController(suite.failoverMapper)) + mgr.SetRaftLeader(true) + go mgr.Run(suite.ctx) + + // Create an advance pointer to some services. + apiServiceRef := resource.Reference(rtest.Resource(types.ServiceType, "api").ID(), "") + otherServiceRef := resource.Reference(rtest.Resource(types.ServiceType, "other").ID(), "") + + // create a failover without any services + failoverData := &pbcatalog.FailoverPolicy{ + Config: &pbcatalog.FailoverConfig{ + Destinations: []*pbcatalog.FailoverDestination{{ + Ref: apiServiceRef, + }}, + }, + } + failover := rtest.Resource(types.FailoverPolicyType, "api"). + WithData(suite.T(), failoverData). + Write(suite.T(), suite.client) + + suite.client.WaitForStatusCondition(suite.T(), failover.Id, StatusKey, ConditionMissingService) + + // Provide the service. + apiServiceData := &pbcatalog.Service{ + Workloads: &pbcatalog.WorkloadSelector{Prefixes: []string{"api-"}}, + Ports: []*pbcatalog.ServicePort{{ + TargetPort: "http", + Protocol: pbcatalog.Protocol_PROTOCOL_HTTP, + }}, + } + _ = rtest.Resource(types.ServiceType, "api"). + WithData(suite.T(), apiServiceData). + Write(suite.T(), suite.client) + suite.client.WaitForStatusCondition(suite.T(), failover.Id, StatusKey, ConditionOK) + + // Update the failover to reference an unknown port + failoverData = &pbcatalog.FailoverPolicy{ + PortConfigs: map[string]*pbcatalog.FailoverConfig{ + "http": { + Destinations: []*pbcatalog.FailoverDestination{{ + Ref: apiServiceRef, + Port: "http", + }}, + }, + "admin": { + Destinations: []*pbcatalog.FailoverDestination{{ + Ref: apiServiceRef, + Port: "admin", + }}, + }, + }, + } + _ = rtest.Resource(types.FailoverPolicyType, "api"). + WithData(suite.T(), failoverData). + Write(suite.T(), suite.client) + suite.client.WaitForStatusCondition(suite.T(), failover.Id, StatusKey, ConditionUnknownPort("admin")) + + // update the service to fix the stray reference, but point to a mesh port + apiServiceData = &pbcatalog.Service{ + Workloads: &pbcatalog.WorkloadSelector{Prefixes: []string{"api-"}}, + Ports: []*pbcatalog.ServicePort{ + { + TargetPort: "http", + Protocol: pbcatalog.Protocol_PROTOCOL_HTTP, + }, + { + TargetPort: "admin", + Protocol: pbcatalog.Protocol_PROTOCOL_MESH, + }, + }, + } + _ = rtest.Resource(types.ServiceType, "api"). + WithData(suite.T(), apiServiceData). + Write(suite.T(), suite.client) + suite.client.WaitForStatusCondition(suite.T(), failover.Id, StatusKey, ConditionUsingMeshDestinationPort(apiServiceRef, "admin")) + + // update the service to fix the stray reference to not be a mesh port + apiServiceData = &pbcatalog.Service{ + Workloads: &pbcatalog.WorkloadSelector{Prefixes: []string{"api-"}}, + Ports: []*pbcatalog.ServicePort{ + { + TargetPort: "http", + Protocol: pbcatalog.Protocol_PROTOCOL_HTTP, + }, + { + TargetPort: "admin", + Protocol: pbcatalog.Protocol_PROTOCOL_HTTP, + }, + }, + } + _ = rtest.Resource(types.ServiceType, "api"). + WithData(suite.T(), apiServiceData). + Write(suite.T(), suite.client) + suite.client.WaitForStatusCondition(suite.T(), failover.Id, StatusKey, ConditionOK) + + // change failover leg to point to missing service + failoverData = &pbcatalog.FailoverPolicy{ + PortConfigs: map[string]*pbcatalog.FailoverConfig{ + "http": { + Destinations: []*pbcatalog.FailoverDestination{{ + Ref: apiServiceRef, + Port: "http", + }}, + }, + "admin": { + Destinations: []*pbcatalog.FailoverDestination{{ + Ref: otherServiceRef, + Port: "admin", + }}, + }, + }, + } + _ = rtest.Resource(types.FailoverPolicyType, "api"). + WithData(suite.T(), failoverData). + Write(suite.T(), suite.client) + suite.client.WaitForStatusCondition(suite.T(), failover.Id, StatusKey, ConditionMissingDestinationService(otherServiceRef)) + + // Create the missing service, but forget the port. + otherServiceData := &pbcatalog.Service{ + Workloads: &pbcatalog.WorkloadSelector{Prefixes: []string{"other-"}}, + Ports: []*pbcatalog.ServicePort{{ + TargetPort: "http", + Protocol: pbcatalog.Protocol_PROTOCOL_HTTP, + }}, + } + _ = rtest.Resource(types.ServiceType, "other"). + WithData(suite.T(), otherServiceData). + Write(suite.T(), suite.client) + suite.client.WaitForStatusCondition(suite.T(), failover.Id, StatusKey, ConditionUnknownDestinationPort(otherServiceRef, "admin")) + + // fix the destination leg's port + otherServiceData = &pbcatalog.Service{ + Workloads: &pbcatalog.WorkloadSelector{Prefixes: []string{"other-"}}, + Ports: []*pbcatalog.ServicePort{ + { + TargetPort: "http", + Protocol: pbcatalog.Protocol_PROTOCOL_HTTP, + }, + { + TargetPort: "admin", + Protocol: pbcatalog.Protocol_PROTOCOL_HTTP, + }, + }, + } + _ = rtest.Resource(types.ServiceType, "other"). + WithData(suite.T(), otherServiceData). + Write(suite.T(), suite.client) + suite.client.WaitForStatusCondition(suite.T(), failover.Id, StatusKey, ConditionOK) + + // Update the two services to use differnet port names so the easy path doesn't work + apiServiceData = &pbcatalog.Service{ + Workloads: &pbcatalog.WorkloadSelector{Prefixes: []string{"api-"}}, + Ports: []*pbcatalog.ServicePort{ + { + TargetPort: "foo", + Protocol: pbcatalog.Protocol_PROTOCOL_HTTP, + }, + { + TargetPort: "bar", + Protocol: pbcatalog.Protocol_PROTOCOL_HTTP, + }, + }, + } + _ = rtest.Resource(types.ServiceType, "api"). + WithData(suite.T(), apiServiceData). + Write(suite.T(), suite.client) + + otherServiceData = &pbcatalog.Service{ + Workloads: &pbcatalog.WorkloadSelector{Prefixes: []string{"other-"}}, + Ports: []*pbcatalog.ServicePort{ + { + TargetPort: "foo", + Protocol: pbcatalog.Protocol_PROTOCOL_HTTP, + }, + { + TargetPort: "baz", + Protocol: pbcatalog.Protocol_PROTOCOL_HTTP, + }, + }, + } + _ = rtest.Resource(types.ServiceType, "other"). + WithData(suite.T(), otherServiceData). + Write(suite.T(), suite.client) + + failoverData = &pbcatalog.FailoverPolicy{ + Config: &pbcatalog.FailoverConfig{ + Destinations: []*pbcatalog.FailoverDestination{{ + Ref: otherServiceRef, + }}, + }, + } + failover = rtest.Resource(types.FailoverPolicyType, "api"). + WithData(suite.T(), failoverData). + Write(suite.T(), suite.client) + + suite.client.WaitForStatusCondition(suite.T(), failover.Id, StatusKey, ConditionUnknownDestinationPort(otherServiceRef, "bar")) + + // and fix it the silly way by removing it from api+failover + apiServiceData = &pbcatalog.Service{ + Workloads: &pbcatalog.WorkloadSelector{Prefixes: []string{"api-"}}, + Ports: []*pbcatalog.ServicePort{ + { + TargetPort: "foo", + Protocol: pbcatalog.Protocol_PROTOCOL_HTTP, + }, + }, + } + _ = rtest.Resource(types.ServiceType, "api"). + WithData(suite.T(), apiServiceData). + Write(suite.T(), suite.client) + + suite.client.WaitForStatusCondition(suite.T(), failover.Id, StatusKey, ConditionOK) +} + +func TestFailoverController(t *testing.T) { + suite.Run(t, new(controllerSuite)) +} diff --git a/internal/catalog/internal/controllers/failover/status.go b/internal/catalog/internal/controllers/failover/status.go new file mode 100644 index 0000000000..10e5a472bd --- /dev/null +++ b/internal/catalog/internal/controllers/failover/status.go @@ -0,0 +1,84 @@ +// Copyright (c) HashiCorp, Inc. +// SPDX-License-Identifier: MPL-2.0 + +package failover + +import ( + "github.com/hashicorp/consul/internal/resource" + "github.com/hashicorp/consul/proto-public/pbresource" +) + +const ( + StatusKey = "consul.io/failover-policy" + StatusConditionAccepted = "accepted" + + OKReason = "Ok" + OKMessage = "failover policy was accepted" + + MissingServiceReason = "MissingService" + MissingServiceMessage = "service for failover policy does not exist" + + UnknownPortReason = "UnknownPort" + UnknownPortMessagePrefix = "port is not defined on service: " + + MissingDestinationServiceReason = "MissingDestinationService" + MissingDestinationServiceMessagePrefix = "destination service for failover policy does not exist: " + + UnknownDestinationPortReason = "UnknownDestinationPort" + UnknownDestinationPortMessagePrefix = "port is not defined on destination service: " + + UsingMeshDestinationPortReason = "UsingMeshDestinationPort" + UsingMeshDestinationPortMessagePrefix = "port is a special unroutable mesh port on destination service: " +) + +var ( + ConditionOK = &pbresource.Condition{ + Type: StatusConditionAccepted, + State: pbresource.Condition_STATE_TRUE, + Reason: OKReason, + Message: OKMessage, + } + + ConditionMissingService = &pbresource.Condition{ + Type: StatusConditionAccepted, + State: pbresource.Condition_STATE_FALSE, + Reason: MissingServiceReason, + Message: MissingServiceMessage, + } +) + +func ConditionUnknownPort(port string) *pbresource.Condition { + return &pbresource.Condition{ + Type: StatusConditionAccepted, + State: pbresource.Condition_STATE_FALSE, + Reason: UnknownPortReason, + Message: UnknownPortMessagePrefix + port, + } +} + +func ConditionMissingDestinationService(ref *pbresource.Reference) *pbresource.Condition { + return &pbresource.Condition{ + Type: StatusConditionAccepted, + State: pbresource.Condition_STATE_FALSE, + Reason: MissingDestinationServiceReason, + Message: MissingDestinationServiceMessagePrefix + resource.ReferenceToString(ref), + } +} + +func ConditionUnknownDestinationPort(ref *pbresource.Reference, port string) *pbresource.Condition { + return &pbresource.Condition{ + Type: StatusConditionAccepted, + State: pbresource.Condition_STATE_FALSE, + Reason: UnknownDestinationPortReason, + Message: UnknownDestinationPortMessagePrefix + port + " on " + resource.ReferenceToString(ref), + } +} + +func ConditionUsingMeshDestinationPort(ref *pbresource.Reference, port string) *pbresource.Condition { + return &pbresource.Condition{ + Type: StatusConditionAccepted, + State: pbresource.Condition_STATE_FALSE, + Reason: UnknownDestinationPortReason, + Message: UnknownDestinationPortMessagePrefix + port + " on " + resource.ReferenceToString(ref), + } +} diff --git a/internal/catalog/internal/controllers/register.go b/internal/catalog/internal/controllers/register.go index 5f7fc631a5..78a0f1316a 100644 --- a/internal/catalog/internal/controllers/register.go +++ b/internal/catalog/internal/controllers/register.go @@ -5,6 +5,7 @@ package controllers import ( "github.com/hashicorp/consul/internal/catalog/internal/controllers/endpoints" + "github.com/hashicorp/consul/internal/catalog/internal/controllers/failover" "github.com/hashicorp/consul/internal/catalog/internal/controllers/nodehealth" "github.com/hashicorp/consul/internal/catalog/internal/controllers/workloadhealth" "github.com/hashicorp/consul/internal/controller" @@ -13,10 +14,12 @@ import ( type Dependencies struct { WorkloadHealthNodeMapper workloadhealth.NodeMapper EndpointsWorkloadMapper endpoints.WorkloadMapper + FailoverMapper failover.FailoverMapper } func Register(mgr *controller.Manager, deps Dependencies) { mgr.Register(nodehealth.NodeHealthController()) mgr.Register(workloadhealth.WorkloadHealthController(deps.WorkloadHealthNodeMapper)) mgr.Register(endpoints.ServiceEndpointsController(deps.EndpointsWorkloadMapper)) + mgr.Register(failover.FailoverPolicyController(deps.FailoverMapper)) } diff --git a/internal/catalog/internal/mappers/failovermapper/failover_mapper.go b/internal/catalog/internal/mappers/failovermapper/failover_mapper.go new file mode 100644 index 0000000000..61da20a348 --- /dev/null +++ b/internal/catalog/internal/mappers/failovermapper/failover_mapper.go @@ -0,0 +1,60 @@ +// Copyright (c) HashiCorp, Inc. +// SPDX-License-Identifier: MPL-2.0 + +package failovermapper + +import ( + "context" + + "github.com/hashicorp/consul/internal/catalog/internal/types" + "github.com/hashicorp/consul/internal/controller" + "github.com/hashicorp/consul/internal/resource" + "github.com/hashicorp/consul/internal/resource/mappers/bimapper" + pbcatalog "github.com/hashicorp/consul/proto-public/pbcatalog/v1alpha1" + "github.com/hashicorp/consul/proto-public/pbresource" +) + +// Mapper tracks the relationship between a FailoverPolicy an a Service it +// references whether due to name-alignment or from a reference in a +// FailoverDestination leg. +type Mapper struct { + b *bimapper.Mapper +} + +// New creates a new Mapper. +func New() *Mapper { + return &Mapper{ + b: bimapper.New(types.FailoverPolicyType, types.ServiceType), + } +} + +// TrackFailover extracts all Service references from the provided +// FailoverPolicy and indexes them so that MapService can turn Service events +// into FailoverPolicy events properly. +func (m *Mapper) TrackFailover(failover *resource.DecodedResource[pbcatalog.FailoverPolicy, *pbcatalog.FailoverPolicy]) { + destRefs := failover.Data.GetUnderlyingDestinationRefs() + destRefs = append(destRefs, &pbresource.Reference{ + Type: types.ServiceType, + Tenancy: failover.Resource.Id.Tenancy, + Name: failover.Resource.Id.Name, + }) + m.trackFailover(failover.Resource.Id, destRefs) +} + +func (m *Mapper) trackFailover(failover *pbresource.ID, services []*pbresource.Reference) { + m.b.TrackItem(failover, services) +} + +// UntrackFailover forgets the links inserted by TrackFailover for the provided +// FailoverPolicyID. +func (m *Mapper) UntrackFailover(failoverID *pbresource.ID) { + m.b.UntrackItem(failoverID) +} + +func (m *Mapper) MapService(ctx context.Context, rt controller.Runtime, res *pbresource.Resource) ([]controller.Request, error) { + return m.b.MapLink(ctx, rt, res) +} + +func (m *Mapper) FailoverIDsByService(svcID *pbresource.ID) []*pbresource.ID { + return m.b.ItemsForLink(svcID) +} diff --git a/internal/catalog/internal/mappers/failovermapper/failover_mapper_test.go b/internal/catalog/internal/mappers/failovermapper/failover_mapper_test.go new file mode 100644 index 0000000000..41621d8d21 --- /dev/null +++ b/internal/catalog/internal/mappers/failovermapper/failover_mapper_test.go @@ -0,0 +1,190 @@ +// Copyright (c) HashiCorp, Inc. +// SPDX-License-Identifier: MPL-2.0 + +package failovermapper + +import ( + "context" + "testing" + + "github.com/stretchr/testify/require" + + "github.com/hashicorp/consul/internal/catalog/internal/types" + "github.com/hashicorp/consul/internal/controller" + "github.com/hashicorp/consul/internal/resource" + rtest "github.com/hashicorp/consul/internal/resource/resourcetest" + pbcatalog "github.com/hashicorp/consul/proto-public/pbcatalog/v1alpha1" + "github.com/hashicorp/consul/proto-public/pbresource" + "github.com/hashicorp/consul/proto/private/prototest" +) + +func TestMapper_Tracking(t *testing.T) { + registry := resource.NewRegistry() + types.Register(registry) + + // Create an advance pointer to some services. + randoSvc := rtest.Resource(types.ServiceType, "rando"). + WithData(t, &pbcatalog.Service{}). + Build() + rtest.ValidateAndNormalize(t, registry, randoSvc) + + apiSvc := rtest.Resource(types.ServiceType, "api"). + WithData(t, &pbcatalog.Service{}). + Build() + rtest.ValidateAndNormalize(t, registry, apiSvc) + + fooSvc := rtest.Resource(types.ServiceType, "foo"). + WithData(t, &pbcatalog.Service{}). + Build() + rtest.ValidateAndNormalize(t, registry, fooSvc) + + barSvc := rtest.Resource(types.ServiceType, "bar"). + WithData(t, &pbcatalog.Service{}). + Build() + rtest.ValidateAndNormalize(t, registry, barSvc) + + wwwSvc := rtest.Resource(types.ServiceType, "www"). + WithData(t, &pbcatalog.Service{}). + Build() + rtest.ValidateAndNormalize(t, registry, wwwSvc) + + fail1 := rtest.Resource(types.FailoverPolicyType, "api"). + WithData(t, &pbcatalog.FailoverPolicy{ + Config: &pbcatalog.FailoverConfig{ + Destinations: []*pbcatalog.FailoverDestination{ + {Ref: newRef(types.ServiceType, "foo")}, + {Ref: newRef(types.ServiceType, "bar")}, + }, + }, + }). + Build() + rtest.ValidateAndNormalize(t, registry, fail1) + failDec1 := rtest.MustDecode[pbcatalog.FailoverPolicy, *pbcatalog.FailoverPolicy](t, fail1) + + fail2 := rtest.Resource(types.FailoverPolicyType, "www"). + WithData(t, &pbcatalog.FailoverPolicy{ + Config: &pbcatalog.FailoverConfig{ + Destinations: []*pbcatalog.FailoverDestination{ + {Ref: newRef(types.ServiceType, "www"), Datacenter: "dc2"}, + {Ref: newRef(types.ServiceType, "foo")}, + }, + }, + }). + Build() + rtest.ValidateAndNormalize(t, registry, fail2) + failDec2 := rtest.MustDecode[pbcatalog.FailoverPolicy, *pbcatalog.FailoverPolicy](t, fail2) + + fail1_updated := rtest.Resource(types.FailoverPolicyType, "api"). + WithData(t, &pbcatalog.FailoverPolicy{ + Config: &pbcatalog.FailoverConfig{ + Destinations: []*pbcatalog.FailoverDestination{ + {Ref: newRef(types.ServiceType, "bar")}, + }, + }, + }). + Build() + rtest.ValidateAndNormalize(t, registry, fail1_updated) + failDec1_updated := rtest.MustDecode[pbcatalog.FailoverPolicy, *pbcatalog.FailoverPolicy](t, fail1_updated) + + m := New() + + // Nothing tracked yet so we assume nothing. + requireServicesTracked(t, m, randoSvc) + requireServicesTracked(t, m, apiSvc) + requireServicesTracked(t, m, fooSvc) + requireServicesTracked(t, m, barSvc) + requireServicesTracked(t, m, wwwSvc) + + // no-ops + m.UntrackFailover(fail1.Id) + + // still nothing + requireServicesTracked(t, m, randoSvc) + requireServicesTracked(t, m, apiSvc) + requireServicesTracked(t, m, fooSvc) + requireServicesTracked(t, m, barSvc) + requireServicesTracked(t, m, wwwSvc) + + // Actually insert some data. + m.TrackFailover(failDec1) + + requireServicesTracked(t, m, randoSvc) + requireServicesTracked(t, m, apiSvc, fail1.Id) + requireServicesTracked(t, m, fooSvc, fail1.Id) + requireServicesTracked(t, m, barSvc, fail1.Id) + requireServicesTracked(t, m, wwwSvc) + + // track it again, no change + m.TrackFailover(failDec1) + + requireServicesTracked(t, m, randoSvc) + requireServicesTracked(t, m, apiSvc, fail1.Id) + requireServicesTracked(t, m, fooSvc, fail1.Id) + requireServicesTracked(t, m, barSvc, fail1.Id) + requireServicesTracked(t, m, wwwSvc) + + // track new one that overlaps slightly + m.TrackFailover(failDec2) + + requireServicesTracked(t, m, randoSvc) + requireServicesTracked(t, m, apiSvc, fail1.Id) + requireServicesTracked(t, m, fooSvc, fail1.Id, fail2.Id) + requireServicesTracked(t, m, barSvc, fail1.Id) + requireServicesTracked(t, m, wwwSvc, fail2.Id) + + // update the original to change it + m.TrackFailover(failDec1_updated) + + requireServicesTracked(t, m, randoSvc) + requireServicesTracked(t, m, apiSvc, fail1.Id) + requireServicesTracked(t, m, fooSvc, fail2.Id) + requireServicesTracked(t, m, barSvc, fail1.Id) + requireServicesTracked(t, m, wwwSvc, fail2.Id) + + // delete the original + m.UntrackFailover(fail1.Id) + + requireServicesTracked(t, m, randoSvc) + requireServicesTracked(t, m, apiSvc) + requireServicesTracked(t, m, fooSvc, fail2.Id) + requireServicesTracked(t, m, barSvc) + requireServicesTracked(t, m, wwwSvc, fail2.Id) + + // delete the other one + m.UntrackFailover(fail2.Id) + + requireServicesTracked(t, m, randoSvc) + requireServicesTracked(t, m, apiSvc) + requireServicesTracked(t, m, fooSvc) + requireServicesTracked(t, m, barSvc) + requireServicesTracked(t, m, wwwSvc) +} + +func requireServicesTracked(t *testing.T, mapper *Mapper, svc *pbresource.Resource, failovers ...*pbresource.ID) { + t.Helper() + + reqs, err := mapper.MapService( + context.Background(), + controller.Runtime{}, + svc, + ) + require.NoError(t, err) + + require.Len(t, reqs, len(failovers)) + + for _, failover := range failovers { + prototest.AssertContainsElement(t, reqs, controller.Request{ID: failover}) + } +} + +func newRef(typ *pbresource.Type, name string) *pbresource.Reference { + return rtest.Resource(typ, name).Reference("") +} + +func defaultTenancy() *pbresource.Tenancy { + return &pbresource.Tenancy{ + Partition: "default", + Namespace: "default", + PeerName: "local", + } +} diff --git a/internal/resource/decode.go b/internal/resource/decode.go index b93b799c52..7d9142ceb7 100644 --- a/internal/resource/decode.go +++ b/internal/resource/decode.go @@ -4,6 +4,10 @@ package resource import ( + "context" + + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" "google.golang.org/protobuf/proto" "github.com/hashicorp/consul/proto-public/pbresource" @@ -36,3 +40,20 @@ func Decode[V any, PV interface { Data: data, }, nil } + +// GetDecodedResource will generically read the requested resource using the +// client and either return nil on a NotFound or decode the response value. +func GetDecodedResource[V any, PV interface { + proto.Message + *V +}](ctx context.Context, client pbresource.ResourceServiceClient, id *pbresource.ID) (*DecodedResource[V, PV], error) { + rsp, err := client.Read(ctx, &pbresource.ReadRequest{Id: id}) + switch { + case status.Code(err) == codes.NotFound: + return nil, nil + case err != nil: + return nil, err + } + + return Decode[V, PV](rsp.Resource) +} diff --git a/internal/resource/decode_test.go b/internal/resource/decode_test.go index 10b61cfa14..e5b079156e 100644 --- a/internal/resource/decode_test.go +++ b/internal/resource/decode_test.go @@ -10,13 +10,60 @@ import ( "github.com/stretchr/testify/require" "google.golang.org/protobuf/types/known/anypb" + svctest "github.com/hashicorp/consul/agent/grpc-external/services/resource/testing" "github.com/hashicorp/consul/internal/resource" "github.com/hashicorp/consul/internal/resource/demo" + rtest "github.com/hashicorp/consul/internal/resource/resourcetest" "github.com/hashicorp/consul/proto-public/pbresource" pbdemo "github.com/hashicorp/consul/proto/private/pbdemo/v2" "github.com/hashicorp/consul/proto/private/prototest" + "github.com/hashicorp/consul/sdk/testutil" ) +func TestGetDecodedResource(t *testing.T) { + var ( + baseClient = svctest.RunResourceService(t, demo.RegisterTypes) + client = rtest.NewClient(baseClient) + ctx = testutil.TestContext(t) + ) + + babypantsID := &pbresource.ID{ + Type: demo.TypeV2Artist, + Tenancy: demo.TenancyDefault, + Name: "babypants", + } + + testutil.RunStep(t, "not found", func(t *testing.T) { + got, err := resource.GetDecodedResource[pbdemo.Artist, *pbdemo.Artist](ctx, client, babypantsID) + require.NoError(t, err) + require.Nil(t, got) + }) + + testutil.RunStep(t, "found", func(t *testing.T) { + data := &pbdemo.Artist{ + Name: "caspar babypants", + } + res := rtest.Resource(demo.TypeV2Artist, "babypants"). + WithData(t, data). + Write(t, client) + + got, err := resource.GetDecodedResource[pbdemo.Artist, *pbdemo.Artist](ctx, client, babypantsID) + require.NoError(t, err) + require.NotNil(t, got) + + // Clone generated fields over. + res.Id.Uid = got.Resource.Id.Uid + res.Version = got.Resource.Version + res.Generation = got.Resource.Generation + + // Clone defaulted fields over + data.Genre = pbdemo.Genre_GENRE_DISCO + + prototest.AssertDeepEqual(t, res, got.Resource) + prototest.AssertDeepEqual(t, data, got.Data) + }) +} + func TestDecode(t *testing.T) { t.Run("good", func(t *testing.T) { fooData := &pbdemo.Artist{ diff --git a/proto/private/prototest/testing.go b/proto/private/prototest/testing.go index b423478155..cc1d1e0141 100644 --- a/proto/private/prototest/testing.go +++ b/proto/private/prototest/testing.go @@ -32,6 +32,7 @@ func AssertDeepEqual(t TestingT, x, y interface{}, opts ...cmp.Option) { func AssertElementsMatch[V any]( t TestingT, listX, listY []V, opts ...cmp.Option, ) { + t.Helper() diff := diffElements(listX, listY, opts...) if diff != "" { t.Fatalf("assertion failed: slices do not have matching elements\n--- expected\n+++ actual\n%v", diff)