diff --git a/internal/mesh/internal/controllers/sidecarproxy/builder/local_app.go b/internal/mesh/internal/controllers/sidecarproxy/builder/local_app.go index 1160321ac7..03e45f0428 100644 --- a/internal/mesh/internal/controllers/sidecarproxy/builder/local_app.go +++ b/internal/mesh/internal/controllers/sidecarproxy/builder/local_app.go @@ -268,7 +268,7 @@ func (l *ListenerBuilder) addInboundRouter(clusterName string, port *pbcatalog.W return l } - if port.Protocol == pbcatalog.Protocol_PROTOCOL_TCP || port.Protocol == pbcatalog.Protocol_PROTOCOL_UNSPECIFIED { + if port.Protocol == pbcatalog.Protocol_PROTOCOL_TCP { r := &pbproxystate.Router{ Destination: &pbproxystate.Router_L4{ L4: &pbproxystate.L4Destination{ diff --git a/internal/mesh/internal/controllers/sidecarproxy/cache/cache.go b/internal/mesh/internal/controllers/sidecarproxy/cache/cache.go index e2bf82d974..b2007d5b0d 100644 --- a/internal/mesh/internal/controllers/sidecarproxy/cache/cache.go +++ b/internal/mesh/internal/controllers/sidecarproxy/cache/cache.go @@ -10,6 +10,7 @@ import ( "github.com/hashicorp/consul/internal/mesh/internal/types" "github.com/hashicorp/consul/internal/resource" "github.com/hashicorp/consul/internal/resource/mappers/bimapper" + "github.com/hashicorp/consul/internal/resource/mappers/selectiontracker" "github.com/hashicorp/consul/internal/storage" pbauth "github.com/hashicorp/consul/proto-public/pbauth/v2beta1" pbcatalog "github.com/hashicorp/consul/proto-public/pbcatalog/v2beta1" @@ -18,16 +19,26 @@ import ( ) type Cache struct { - computedRoutes *bimapper.Mapper - identities *bimapper.Mapper + // computedRoutes keeps track of computed routes IDs to service references it applies to. + computedRoutes *bimapper.Mapper + + // identities keeps track of which identity a workload is mapped to. + identities *bimapper.Mapper + + // computedDestinations keeps track of the computed explicit destinations IDs to service references that are + // referenced in that resource. computedDestinations *bimapper.Mapper + + // serviceSelectorTracker keeps track of which workload selectors a service is currently using. + serviceSelectorTracker *selectiontracker.WorkloadSelectionTracker } func New() *Cache { return &Cache{ - computedRoutes: bimapper.New(pbmesh.ComputedRoutesType, pbcatalog.ServiceType), - identities: bimapper.New(pbcatalog.WorkloadType, pbauth.WorkloadIdentityType), - computedDestinations: bimapper.New(pbmesh.ComputedExplicitDestinationsType, pbcatalog.ServiceType), + computedRoutes: bimapper.New(pbmesh.ComputedRoutesType, pbcatalog.ServiceType), + identities: bimapper.New(pbcatalog.WorkloadType, pbauth.WorkloadIdentityType), + computedDestinations: bimapper.New(pbmesh.ComputedExplicitDestinationsType, pbcatalog.ServiceType), + serviceSelectorTracker: selectiontracker.New(), } } @@ -87,6 +98,14 @@ func (c *Cache) WorkloadsByWorkloadIdentity(id *pbresource.ID) []*pbresource.ID return c.identities.ItemIDsForLink(id) } +func (c *Cache) ServicesForWorkload(id *pbresource.ID) []*pbresource.ID { + return c.serviceSelectorTracker.GetIDsForWorkload(id) +} + +func (c *Cache) UntrackService(id *pbresource.ID) { + c.serviceSelectorTracker.UntrackID(id) +} + func (c *Cache) MapComputedRoutes(ctx context.Context, rt controller.Runtime, res *pbresource.Resource) ([]controller.Request, error) { computedRoutes, err := resource.Decode[*pbmesh.ComputedRoutes](res) if err != nil { @@ -111,7 +130,18 @@ func (c *Cache) mapComputedRoutesToProxyStateTemplate(ctx context.Context, rt co return c.mapServiceThroughDestinations(ctx, rt, serviceRef) } +func (c *Cache) TrackService(svc *types.DecodedService) { + c.serviceSelectorTracker.TrackIDForSelector(svc.Resource.GetId(), svc.GetData().GetWorkloads()) +} + func (c *Cache) MapService(ctx context.Context, rt controller.Runtime, res *pbresource.Resource) ([]controller.Request, error) { + // Record workload selector in the cache every time we see an event for a service. + decodedService, err := resource.Decode[*pbcatalog.Service](res) + if err != nil { + return nil, err + } + c.TrackService(decodedService) + serviceRef := resource.Reference(res.Id, "") pstIDs, err := c.mapServiceThroughDestinations(ctx, rt, serviceRef) diff --git a/internal/mesh/internal/controllers/sidecarproxy/cache/cache_test.go b/internal/mesh/internal/controllers/sidecarproxy/cache/cache_test.go index e630ed22ec..2aa5484db4 100644 --- a/internal/mesh/internal/controllers/sidecarproxy/cache/cache_test.go +++ b/internal/mesh/internal/controllers/sidecarproxy/cache/cache_test.go @@ -136,6 +136,9 @@ func TestUnified_AllMappingsToProxyStateTemplate(t *testing.T) { ) anyServiceData := &pbcatalog.Service{ + Workloads: &pbcatalog.WorkloadSelector{ + Prefixes: []string{"src-workload"}, + }, Ports: []*pbcatalog.ServicePort{ { TargetPort: "tcp1", @@ -315,6 +318,26 @@ func TestUnified_AllMappingsToProxyStateTemplate(t *testing.T) { } prototest.AssertElementsMatch(t, expRequests, requests) + + // Check that service's workload selector is tracked. + prototest.AssertElementsMatch(t, + []*pbresource.ID{destService.Id}, + cache.serviceSelectorTracker.GetIDsForWorkload(resource.ReplaceType(pbcatalog.WorkloadType, sourceProxy1))) + prototest.AssertElementsMatch(t, + []*pbresource.ID{destService.Id}, + cache.serviceSelectorTracker.GetIDsForWorkload(resource.ReplaceType(pbcatalog.WorkloadType, sourceProxy2))) + prototest.AssertElementsMatch(t, + []*pbresource.ID{destService.Id}, + cache.serviceSelectorTracker.GetIDsForWorkload(resource.ReplaceType(pbcatalog.WorkloadType, sourceProxy3))) + prototest.AssertElementsMatch(t, + []*pbresource.ID{destService.Id}, + cache.serviceSelectorTracker.GetIDsForWorkload(resource.ReplaceType(pbcatalog.WorkloadType, sourceProxy4))) + prototest.AssertElementsMatch(t, + []*pbresource.ID{destService.Id}, + cache.serviceSelectorTracker.GetIDsForWorkload(resource.ReplaceType(pbcatalog.WorkloadType, sourceProxy5))) + prototest.AssertElementsMatch(t, + []*pbresource.ID{destService.Id}, + cache.serviceSelectorTracker.GetIDsForWorkload(resource.ReplaceType(pbcatalog.WorkloadType, sourceProxy6))) }) t.Run("map target endpoints (TCPRoute)", func(t *testing.T) { diff --git a/internal/mesh/internal/controllers/sidecarproxy/controller.go b/internal/mesh/internal/controllers/sidecarproxy/controller.go index 4ab6496f03..6b81b04b8c 100644 --- a/internal/mesh/internal/controllers/sidecarproxy/controller.go +++ b/internal/mesh/internal/controllers/sidecarproxy/controller.go @@ -6,6 +6,7 @@ package sidecarproxy import ( "context" + "github.com/hashicorp/go-hclog" "google.golang.org/protobuf/proto" "google.golang.org/protobuf/types/known/anypb" @@ -179,8 +180,16 @@ func (r *reconciler) Reconcile(ctx context.Context, rt controller.Runtime, req c ctp = trafficPermissions.Data } + workloadPorts, err := r.workloadPortProtocolsFromService(ctx, dataFetcher, workload, rt.Logger) + if err != nil { + rt.Logger.Error("error determining workload ports", "error", err) + return err + } + workloadDataWithInheritedPorts := proto.Clone(workload.Data).(*pbcatalog.Workload) + workloadDataWithInheritedPorts.Ports = workloadPorts + b := builder.New(req.ID, identityRefFromWorkload(workload), trustDomain, r.dc, r.defaultAllow, proxyCfg.GetData()). - BuildLocalApp(workload.Data, ctp) + BuildLocalApp(workloadDataWithInheritedPorts, ctp) // Get all destinationsData. destinationsData, err := dataFetcher.FetchExplicitDestinationsData(ctx, req.ID) @@ -230,6 +239,100 @@ func (r *reconciler) Reconcile(ctx context.Context, rt controller.Runtime, req c return nil } +func (r *reconciler) workloadPortProtocolsFromService( + ctx context.Context, + fetcher *fetcher.Fetcher, + workload *types.DecodedWorkload, + logger hclog.Logger, +) (map[string]*pbcatalog.WorkloadPort, error) { + + // Fetch all services for this workload. + serviceIDs := r.cache.ServicesForWorkload(workload.GetResource().GetId()) + + var services []*types.DecodedService + + for _, serviceID := range serviceIDs { + svc, err := fetcher.FetchService(ctx, serviceID) + if err != nil { + return nil, err + } + + // If service is not found, we should untrack it. + if svc == nil { + r.cache.UntrackService(serviceID) + continue + } + + services = append(services, svc) + } + + // Now walk through all workload ports. + // For ports that don't have a protocol explicitly specified, inherit it from the service. + + result := make(map[string]*pbcatalog.WorkloadPort) + + for portName, port := range workload.GetData().GetPorts() { + if port.GetProtocol() != pbcatalog.Protocol_PROTOCOL_UNSPECIFIED { + // Add any specified protocols as is. + result[portName] = port + continue + } + + // Check if we have any service IDs or fetched services. + if len(serviceIDs) == 0 || len(services) == 0 { + logger.Trace("found no services for this workload's port; using default TCP protocol", "port", portName) + result[portName] = &pbcatalog.WorkloadPort{ + Port: port.GetPort(), + Protocol: pbcatalog.Protocol_PROTOCOL_TCP, + } + continue + } + + // Otherwise, look for port protocol in the service. + inheritedProtocol := pbcatalog.Protocol_PROTOCOL_UNSPECIFIED + for _, svc := range services { + // Find workload's port as the target port. + svcPort := svc.GetData().FindServicePort(portName) + + // If this service doesn't select this port, go to the next service. + if svcPort == nil { + continue + } + + // Check for conflicts. + // If protocols between services selecting this workload on this port do not match, + // we use the default protocol (tcp) instead. + if inheritedProtocol != pbcatalog.Protocol_PROTOCOL_UNSPECIFIED && + svcPort.GetProtocol() != inheritedProtocol { + + logger.Trace("found conflicting service protocols that select this workload port; using default TCP protocol", "port", portName) + inheritedProtocol = pbcatalog.Protocol_PROTOCOL_TCP + + // We won't check any remaining services as there's already a conflict. + break + } + + inheritedProtocol = svcPort.GetProtocol() + } + + // If after going through all services, we haven't found a protocol, use the default. + if inheritedProtocol == pbcatalog.Protocol_PROTOCOL_UNSPECIFIED { + logger.Trace("no services select this workload port; using default TCP protocol", "port", portName) + result[portName] = &pbcatalog.WorkloadPort{ + Port: port.GetPort(), + Protocol: pbcatalog.Protocol_PROTOCOL_TCP, + } + } else { + result[portName] = &pbcatalog.WorkloadPort{ + Port: port.GetPort(), + Protocol: inheritedProtocol, + } + } + } + + return result, nil +} + func identityRefFromWorkload(w *types.DecodedWorkload) *pbresource.Reference { return &pbresource.Reference{ Name: w.Data.Identity, diff --git a/internal/mesh/internal/controllers/sidecarproxy/controller_test.go b/internal/mesh/internal/controllers/sidecarproxy/controller_test.go index b081dc7a25..33db69e9f5 100644 --- a/internal/mesh/internal/controllers/sidecarproxy/controller_test.go +++ b/internal/mesh/internal/controllers/sidecarproxy/controller_test.go @@ -19,6 +19,7 @@ import ( "github.com/hashicorp/consul/internal/mesh/internal/controllers/routes/routestest" "github.com/hashicorp/consul/internal/mesh/internal/controllers/sidecarproxy/builder" "github.com/hashicorp/consul/internal/mesh/internal/controllers/sidecarproxy/cache" + "github.com/hashicorp/consul/internal/mesh/internal/controllers/sidecarproxy/fetcher" "github.com/hashicorp/consul/internal/mesh/internal/types" "github.com/hashicorp/consul/internal/resource" "github.com/hashicorp/consul/internal/resource/resourcetest" @@ -32,7 +33,7 @@ import ( "github.com/hashicorp/consul/sdk/testutil/retry" ) -type meshControllerTestSuite struct { +type controllerTestSuite struct { suite.Suite client *resourcetest.Client @@ -60,7 +61,7 @@ type meshControllerTestSuite struct { proxyStateTemplate *pbmesh.ProxyStateTemplate } -func (suite *meshControllerTestSuite) SetupTest() { +func (suite *controllerTestSuite) SetupTest() { resourceClient := svctest.RunResourceService(suite.T(), types.Register, catalog.RegisterTypes, auth.RegisterTypes) suite.client = resourcetest.NewClient(resourceClient) suite.runtime = controller.Runtime{Client: resourceClient, Logger: testutil.Logger(suite.T())} @@ -234,7 +235,152 @@ func (suite *meshControllerTestSuite) SetupTest() { Build() } -func (suite *meshControllerTestSuite) TestReconcile_NoWorkload() { +func (suite *controllerTestSuite) TestWorkloadPortProtocolsFromService_NoServicesInCache() { + dataFetcher := fetcher.New(suite.client, suite.ctl.cache) + + workload := resourcetest.Resource(pbcatalog.WorkloadType, "api-workload"). + WithData(suite.T(), &pbcatalog.Workload{ + Ports: map[string]*pbcatalog.WorkloadPort{ + "tcp": {Port: 8080}, + }, + }). + Build() + + decWorkload := resourcetest.MustDecode[*pbcatalog.Workload](suite.T(), workload) + workloadPorts, err := suite.ctl.workloadPortProtocolsFromService(suite.ctx, dataFetcher, decWorkload, suite.runtime.Logger) + require.NoError(suite.T(), err) + prototest.AssertDeepEqual(suite.T(), pbcatalog.Protocol_PROTOCOL_TCP, workloadPorts["tcp"].GetProtocol()) +} + +func (suite *controllerTestSuite) TestWorkloadPortProtocolsFromService_ServiceNotFound() { + c := cache.New() + dataFetcher := fetcher.New(suite.client, c) + ctrl := &reconciler{ + cache: c, + getTrustDomain: func() (string, error) { + return "test.consul", nil + }, + } + svc := resourcetest.Resource(pbcatalog.ServiceType, "not-found"). + WithData(suite.T(), &pbcatalog.Service{ + Workloads: &pbcatalog.WorkloadSelector{ + Names: []string{"api-workload"}, + }, + }). + Build() + + decSvc := resourcetest.MustDecode[*pbcatalog.Service](suite.T(), svc) + c.TrackService(decSvc) + + workload := resourcetest.Resource(pbcatalog.WorkloadType, "api-workload"). + WithData(suite.T(), &pbcatalog.Workload{ + Ports: map[string]*pbcatalog.WorkloadPort{ + "tcp": {Port: 8080}, + }, + }). + Build() + + decWorkload := resourcetest.MustDecode[*pbcatalog.Workload](suite.T(), workload) + + workloadPorts, err := ctrl.workloadPortProtocolsFromService(suite.ctx, dataFetcher, decWorkload, suite.runtime.Logger) + require.NoError(suite.T(), err) + prototest.AssertDeepEqual(suite.T(), pbcatalog.Protocol_PROTOCOL_TCP, workloadPorts["tcp"].GetProtocol()) + // Check that the service is no longer in cache. + require.Nil(suite.T(), c.ServicesForWorkload(workload.Id)) +} + +func (suite *controllerTestSuite) TestWorkloadPortProtocolsFromService() { + c := cache.New() + dataFetcher := fetcher.New(suite.client, c) + ctrl := &reconciler{ + cache: c, + getTrustDomain: func() (string, error) { + return "test.consul", nil + }, + } + svc1 := resourcetest.Resource(pbcatalog.ServiceType, "api-1"). + WithData(suite.T(), &pbcatalog.Service{ + Workloads: &pbcatalog.WorkloadSelector{ + Names: []string{"api-workload"}, + }, + Ports: []*pbcatalog.ServicePort{ + { + TargetPort: "http1", + Protocol: pbcatalog.Protocol_PROTOCOL_HTTP, + }, + { + TargetPort: "conflict", + Protocol: pbcatalog.Protocol_PROTOCOL_HTTP, + }, + }, + }). + Write(suite.T(), suite.client) + + decSvc := resourcetest.MustDecode[*pbcatalog.Service](suite.T(), svc1) + c.TrackService(decSvc) + + svc2 := resourcetest.Resource(pbcatalog.ServiceType, "api-2"). + WithData(suite.T(), &pbcatalog.Service{ + Workloads: &pbcatalog.WorkloadSelector{ + Names: []string{"api-workload"}, + }, + Ports: []*pbcatalog.ServicePort{ + { + TargetPort: "http2", + Protocol: pbcatalog.Protocol_PROTOCOL_HTTP2, + }, + { + TargetPort: "conflict", + Protocol: pbcatalog.Protocol_PROTOCOL_GRPC, + }, + }, + }). + Write(suite.T(), suite.client) + + decSvc = resourcetest.MustDecode[*pbcatalog.Service](suite.T(), svc2) + c.TrackService(decSvc) + + workload := resourcetest.Resource(pbcatalog.WorkloadType, "api-workload"). + WithData(suite.T(), &pbcatalog.Workload{ + Ports: map[string]*pbcatalog.WorkloadPort{ + "http1": {Port: 8080}, + "http2": {Port: 9090}, + "conflict": {Port: 9091}, + "not-selected": {Port: 8081}, + "specified-protocol": {Port: 8082, Protocol: pbcatalog.Protocol_PROTOCOL_GRPC}, + "mesh": {Port: 20000, Protocol: pbcatalog.Protocol_PROTOCOL_MESH}, + }, + }). + WithTenancy(resource.DefaultNamespacedTenancy()). + Build() + + decWorkload := resourcetest.MustDecode[*pbcatalog.Workload](suite.T(), workload) + + expWorkloadPorts := map[string]*pbcatalog.WorkloadPort{ + // This protocol should be inherited from service 1. + "http1": {Port: 8080, Protocol: pbcatalog.Protocol_PROTOCOL_HTTP}, + + // this protocol should be inherited from service 2. + "http2": {Port: 9090, Protocol: pbcatalog.Protocol_PROTOCOL_HTTP2}, + + // This port is not selected by the service and should default to tcp. + "not-selected": {Port: 8081, Protocol: pbcatalog.Protocol_PROTOCOL_TCP}, + + // This port has conflicting protocols in each service and so it should default to tcp. + "conflict": {Port: 9091, Protocol: pbcatalog.Protocol_PROTOCOL_TCP}, + + // These port should keep its existing protocol. + "specified-protocol": {Port: 8082, Protocol: pbcatalog.Protocol_PROTOCOL_GRPC}, + "mesh": {Port: 20000, Protocol: pbcatalog.Protocol_PROTOCOL_MESH}, + } + + workloadPorts, err := ctrl.workloadPortProtocolsFromService(suite.ctx, dataFetcher, decWorkload, suite.runtime.Logger) + require.NoError(suite.T(), err) + + prototest.AssertDeepEqual(suite.T(), expWorkloadPorts, workloadPorts) +} + +func (suite *controllerTestSuite) TestReconcile_NoWorkload() { // This test ensures that removed workloads are ignored and don't result // in the creation of the proxy state template. err := suite.ctl.Reconcile(context.Background(), suite.runtime, controller.Request{ @@ -245,7 +391,7 @@ func (suite *meshControllerTestSuite) TestReconcile_NoWorkload() { suite.client.RequireResourceNotFound(suite.T(), resourceID(pbmesh.ProxyStateTemplateType, "not-found")) } -func (suite *meshControllerTestSuite) TestReconcile_NonMeshWorkload() { +func (suite *controllerTestSuite) TestReconcile_NonMeshWorkload() { // This test ensures that non-mesh workloads are ignored by the controller. nonMeshWorkload := &pbcatalog.Workload{ @@ -271,7 +417,7 @@ func (suite *meshControllerTestSuite) TestReconcile_NonMeshWorkload() { suite.client.RequireResourceNotFound(suite.T(), resourceID(pbmesh.ProxyStateTemplateType, "test-non-mesh-api-workload")) } -func (suite *meshControllerTestSuite) TestReconcile_NoExistingProxyStateTemplate() { +func (suite *controllerTestSuite) TestReconcile_NoExistingProxyStateTemplate() { err := suite.ctl.Reconcile(context.Background(), suite.runtime, controller.Request{ ID: resourceID(pbmesh.ProxyStateTemplateType, suite.apiWorkloadID.Name), }) @@ -283,7 +429,7 @@ func (suite *meshControllerTestSuite) TestReconcile_NoExistingProxyStateTemplate prototest.AssertDeepEqual(suite.T(), suite.apiWorkloadID, res.Owner) } -func (suite *meshControllerTestSuite) TestReconcile_ExistingProxyStateTemplate_WithUpdates() { +func (suite *controllerTestSuite) TestReconcile_ExistingProxyStateTemplate_WithUpdates() { // This test ensures that we write a new proxy state template when there are changes. // Write the original. @@ -292,8 +438,9 @@ func (suite *meshControllerTestSuite) TestReconcile_ExistingProxyStateTemplate_W WithOwner(suite.apiWorkloadID). Write(suite.T(), suite.client.ResourceServiceClient) - // Update the apiWorkload. - suite.apiWorkload.Ports["mesh"].Port = 21000 + // Update the apiWorkload and check that we default the port to tcp if it's unspecified. + suite.apiWorkload.Ports["tcp"].Protocol = pbcatalog.Protocol_PROTOCOL_UNSPECIFIED + updatedWorkloadID := resourcetest.Resource(pbcatalog.WorkloadType, "api-abc"). WithData(suite.T(), suite.apiWorkload). Write(suite.T(), suite.client.ResourceServiceClient).Id @@ -313,12 +460,15 @@ func (suite *meshControllerTestSuite) TestReconcile_ExistingProxyStateTemplate_W require.NoError(suite.T(), err) // Check that our value is updated in the proxy state template. - inboundListenerPort := updatedProxyStateTemplate.ProxyState.Listeners[0]. - BindAddress.(*pbproxystate.Listener_HostPort).HostPort.Port - require.Equal(suite.T(), uint32(21000), inboundListenerPort) + require.Len(suite.T(), updatedProxyStateTemplate.ProxyState.Listeners, 1) + require.Len(suite.T(), updatedProxyStateTemplate.ProxyState.Listeners[0].Routers, 1) + + l4InboundRouter := updatedProxyStateTemplate.ProxyState.Listeners[0]. + Routers[0].GetL4() + require.NotNil(suite.T(), l4InboundRouter) } -func (suite *meshControllerTestSuite) TestReconcile_ExistingProxyStateTemplate_NoUpdates() { +func (suite *controllerTestSuite) TestReconcile_ExistingProxyStateTemplate_NoUpdates() { // This test ensures that we skip writing of the proxy state template when there are no changes to it. // Write the original. @@ -342,7 +492,7 @@ func (suite *meshControllerTestSuite) TestReconcile_ExistingProxyStateTemplate_N resourcetest.RequireVersionUnchanged(suite.T(), updatedProxyState, originalProxyState.Version) } -func (suite *meshControllerTestSuite) TestController() { +func (suite *controllerTestSuite) TestController() { // This is a comprehensive test that checks the overall controller behavior as various resources change state. // This should test interactions between the reconciler, the mappers, and the destinationsCache to ensure they work // together and produce expected result. @@ -611,7 +761,7 @@ func (suite *meshControllerTestSuite) TestController() { }) } -func (suite *meshControllerTestSuite) TestControllerDefaultAllow() { +func (suite *controllerTestSuite) TestControllerDefaultAllow() { // Run the controller manager mgr := controller.NewManager(suite.client, suite.runtime.Logger) @@ -640,7 +790,7 @@ func (suite *meshControllerTestSuite) TestControllerDefaultAllow() { } func TestMeshController(t *testing.T) { - suite.Run(t, new(meshControllerTestSuite)) + suite.Run(t, new(controllerTestSuite)) } func requireExplicitDestinationsFound(t *testing.T, name string, tmplResource *pbresource.Resource) {