diff --git a/internal/mesh/internal/controllers/gatewayproxy/controller.go b/internal/mesh/internal/controllers/gatewayproxy/controller.go index 3e9a4b45c5..5868947e35 100644 --- a/internal/mesh/internal/controllers/gatewayproxy/controller.go +++ b/internal/mesh/internal/controllers/gatewayproxy/controller.go @@ -120,7 +120,7 @@ func (r *reconciler) Reconcile(ctx context.Context, rt controller.Runtime, req c } var exportedServices []*pbmulticluster.ComputedExportedService - dec, err := dataFetcher.FetchExportedServices(ctx, exportedServicesID) + dec, err := dataFetcher.FetchComputedExportedServices(ctx, exportedServicesID) if err != nil { rt.Logger.Error("error reading the associated exported services", "error", err) } else if dec == nil { diff --git a/internal/mesh/internal/controllers/gatewayproxy/fetcher/data_fetcher.go b/internal/mesh/internal/controllers/gatewayproxy/fetcher/data_fetcher.go index 0586435e50..6d9f274c2d 100644 --- a/internal/mesh/internal/controllers/gatewayproxy/fetcher/data_fetcher.go +++ b/internal/mesh/internal/controllers/gatewayproxy/fetcher/data_fetcher.go @@ -5,6 +5,7 @@ package fetcher import ( "context" + "fmt" "github.com/hashicorp/consul/internal/mesh/internal/controllers/sidecarproxy/cache" "github.com/hashicorp/consul/internal/mesh/internal/types" @@ -13,6 +14,7 @@ import ( pbmesh "github.com/hashicorp/consul/proto-public/pbmesh/v2beta1" pbmulticluster "github.com/hashicorp/consul/proto-public/pbmulticluster/v2beta1" "github.com/hashicorp/consul/proto-public/pbresource" + "google.golang.org/protobuf/proto" ) type Fetcher struct { @@ -27,7 +29,11 @@ func New(client pbresource.ResourceServiceClient, cache *cache.Cache) *Fetcher { } } +// FetchMeshGateway fetches a service resource from the resource service. +// This will panic if the type field in the ID argument is not a MeshGateway type. func (f *Fetcher) FetchMeshGateway(ctx context.Context, id *pbresource.ID) (*types.DecodedMeshGateway, error) { + assertResourceType(pbmesh.MeshGatewayType, id.Type) + dec, err := resource.GetDecodedResource[*pbmesh.MeshGateway](ctx, f.client, id) if err != nil { return nil, err @@ -38,7 +44,11 @@ func (f *Fetcher) FetchMeshGateway(ctx context.Context, id *pbresource.ID) (*typ return dec, nil } +// FetchProxyStateTemplate fetches a service resource from the resource service. +// This will panic if the type field in the ID argument is not a ProxyStateTemplate type. func (f *Fetcher) FetchProxyStateTemplate(ctx context.Context, id *pbresource.ID) (*types.DecodedProxyStateTemplate, error) { + assertResourceType(pbmesh.ProxyStateTemplateType, id.Type) + dec, err := resource.GetDecodedResource[*pbmesh.ProxyStateTemplate](ctx, f.client, id) if err != nil { return nil, err @@ -49,7 +59,11 @@ func (f *Fetcher) FetchProxyStateTemplate(ctx context.Context, id *pbresource.ID return dec, nil } +// FetchWorkload fetches a service resource from the resource service. +// This will panic if the type field in the ID argument is not a Workload type. func (f *Fetcher) FetchWorkload(ctx context.Context, id *pbresource.ID) (*types.DecodedWorkload, error) { + assertResourceType(pbcatalog.WorkloadType, id.Type) + dec, err := resource.GetDecodedResource[*pbcatalog.Workload](ctx, f.client, id) if err != nil { return nil, err @@ -60,7 +74,11 @@ func (f *Fetcher) FetchWorkload(ctx context.Context, id *pbresource.ID) (*types. return dec, nil } -func (f *Fetcher) FetchExportedServices(ctx context.Context, id *pbresource.ID) (*types.DecodedComputedExportedServices, error) { +// FetchComputedExportedServices fetches a service resource from the resource service. +// This will panic if the type field in the ID argument is not a ComputedExportedServices type. +func (f *Fetcher) FetchComputedExportedServices(ctx context.Context, id *pbresource.ID) (*types.DecodedComputedExportedServices, error) { + assertResourceType(pbmulticluster.ComputedExportedServicesType, id.Type) + dec, err := resource.GetDecodedResource[*pbmulticluster.ComputedExportedServices](ctx, f.client, id) if err != nil { return nil, err @@ -71,7 +89,11 @@ func (f *Fetcher) FetchExportedServices(ctx context.Context, id *pbresource.ID) return dec, nil } +// FetchService fetches a service resource from the resource service. +// This will panic if the type field in the ID argument is not a Service type. func (f *Fetcher) FetchService(ctx context.Context, id *pbresource.ID) (*types.DecodedService, error) { + assertResourceType(pbcatalog.ServiceType, id.Type) + dec, err := resource.GetDecodedResource[*pbcatalog.Service](ctx, f.client, id) if err != nil { return nil, err @@ -81,3 +103,11 @@ func (f *Fetcher) FetchService(ctx context.Context, id *pbresource.ID) (*types.D return dec, nil } + +// this is a helper function to ensure that the resource type we are querying for is the type we expect +func assertResourceType(expected, actual *pbresource.Type) { + if !proto.Equal(expected, actual) { + // this is always a programmer error so safe to panic + panic(fmt.Sprintf("expected a query for a type of %q, you provided a type of %q", expected, actual)) + } +} diff --git a/internal/mesh/internal/controllers/gatewayproxy/fetcher/data_fetcher_test.go b/internal/mesh/internal/controllers/gatewayproxy/fetcher/data_fetcher_test.go index 895e7470df..e3ec69b6bd 100644 --- a/internal/mesh/internal/controllers/gatewayproxy/fetcher/data_fetcher_test.go +++ b/internal/mesh/internal/controllers/gatewayproxy/fetcher/data_fetcher_test.go @@ -124,6 +124,15 @@ func (suite *dataFetcherSuite) TestFetcher_FetchMeshGateway() { require.NoError(t, err) require.NotNil(t, gtw) }) + + testutil.RunStep(suite.T(), "incorrect type is passed", func(t *testing.T) { + incorrectID := resourcetest.Resource(pbcatalog.ServiceType, "api-1").ID() + defer func() { + err := recover() + require.NotNil(t, err) + }() + f.FetchMeshGateway(suite.ctx, incorrectID) + }) }) } @@ -148,6 +157,15 @@ func (suite *dataFetcherSuite) TestFetcher_FetchProxyStateTemplate() { require.NoError(t, err) require.NotNil(t, tmpl) }) + + testutil.RunStep(suite.T(), "incorrect type is passed", func(t *testing.T) { + incorrectID := resourcetest.Resource(pbcatalog.ServiceType, "api-1").ID() + defer func() { + err := recover() + require.NotNil(t, err) + }() + f.FetchProxyStateTemplate(suite.ctx, incorrectID) + }) }) } @@ -186,13 +204,13 @@ func (suite *dataFetcherSuite) TestFetcher_FetchExportedServices() { testutil.RunStep(suite.T(), "exported services do not exist", func(t *testing.T) { nonExistantID := resourcetest.Resource(pbmulticluster.ComputedExportedServicesType, "not-found").ID() - svcs, err := f.FetchExportedServices(suite.ctx, nonExistantID) + svcs, err := f.FetchComputedExportedServices(suite.ctx, nonExistantID) require.NoError(t, err) require.Nil(t, svcs) }) testutil.RunStep(suite.T(), "workload exists", func(t *testing.T) { - svcs, err := f.FetchExportedServices(suite.ctx, suite.exportedServices.Id) + svcs, err := f.FetchComputedExportedServices(suite.ctx, suite.exportedServices.Id) require.NoError(t, err) require.NotNil(t, svcs) }) @@ -220,6 +238,15 @@ func (suite *dataFetcherSuite) TestFetcher_FetchService() { require.NoError(t, err) require.NotNil(t, svc) }) + + testutil.RunStep(suite.T(), "incorrect type is passed", func(t *testing.T) { + incorrectID := resourcetest.Resource(pbmesh.ProxyStateTemplateType, "api-1").ID() + defer func() { + err := recover() + require.NotNil(t, err) + }() + f.FetchService(suite.ctx, incorrectID) + }) }) } diff --git a/internal/mesh/internal/controllers/sidecarproxy/controller.go b/internal/mesh/internal/controllers/sidecarproxy/controller.go index bf93f79d16..c73d84a823 100644 --- a/internal/mesh/internal/controllers/sidecarproxy/controller.go +++ b/internal/mesh/internal/controllers/sidecarproxy/controller.go @@ -200,7 +200,7 @@ func (r *reconciler) Reconcile(ctx context.Context, rt controller.Runtime, req c BuildLocalApp(workloadDataWithInheritedPorts, ctp) // Get all destinationsData. - destinationsData, err := dataFetcher.FetchExplicitDestinationsData(ctx, req.ID) + destinationsData, err := dataFetcher.FetchComputedExplicitDestinationsData(ctx, req.ID) if err != nil { rt.Logger.Error("error fetching explicit destinations for this proxy", "error", err) return err @@ -253,7 +253,6 @@ func (r *reconciler) workloadPortProtocolsFromService( workload *types.DecodedWorkload, logger hclog.Logger, ) (map[string]*pbcatalog.WorkloadPort, error) { - // Fetch all services for this workload. serviceIDs := r.cache.ServicesForWorkload(workload.GetResource().GetId()) diff --git a/internal/mesh/internal/controllers/sidecarproxy/fetcher/data_fetcher.go b/internal/mesh/internal/controllers/sidecarproxy/fetcher/data_fetcher.go index 7a31798a24..f71198a9f2 100644 --- a/internal/mesh/internal/controllers/sidecarproxy/fetcher/data_fetcher.go +++ b/internal/mesh/internal/controllers/sidecarproxy/fetcher/data_fetcher.go @@ -33,7 +33,10 @@ func New(client pbresource.ResourceServiceClient, cache *cache.Cache) *Fetcher { } } +// FetchWorkload fetches a service resource from the resource service. +// This will panic if the type field in the ID argument is not a Workload type. func (f *Fetcher) FetchWorkload(ctx context.Context, id *pbresource.ID) (*types.DecodedWorkload, error) { + assertResourceType(pbcatalog.WorkloadType, id.Type) dec, err := resource.GetDecodedResource[*pbcatalog.Workload](ctx, f.client, id) if err != nil { return nil, err @@ -48,27 +51,45 @@ func (f *Fetcher) FetchWorkload(ctx context.Context, id *pbresource.ID) (*types. return dec, err } +// FetchProxyStateTemplate fetches a service resource from the resource service. +// This will panic if the type field in the ID argument is not a ProxyStateTemplate type. func (f *Fetcher) FetchProxyStateTemplate(ctx context.Context, id *pbresource.ID) (*types.DecodedProxyStateTemplate, error) { + assertResourceType(pbmesh.ProxyStateTemplateType, id.Type) return resource.GetDecodedResource[*pbmesh.ProxyStateTemplate](ctx, f.client, id) } +// FetchComputedTrafficPermissions fetches a service resource from the resource service. +// This will panic if the type field in the ID argument is not a ComputedTrafficPermissons type. func (f *Fetcher) FetchComputedTrafficPermissions(ctx context.Context, id *pbresource.ID) (*types.DecodedComputedTrafficPermissions, error) { + assertResourceType(pbauth.ComputedTrafficPermissionsType, id.Type) return resource.GetDecodedResource[*pbauth.ComputedTrafficPermissions](ctx, f.client, id) } +// FetchServiceEndpoints fetches a service resource from the resource service. +// This will panic if the type field in the ID argument is not a ServiceEndpoints type. func (f *Fetcher) FetchServiceEndpoints(ctx context.Context, id *pbresource.ID) (*types.DecodedServiceEndpoints, error) { + assertResourceType(pbcatalog.ServiceEndpointsType, id.Type) return resource.GetDecodedResource[*pbcatalog.ServiceEndpoints](ctx, f.client, id) } +// FetchService fetches a service resource from the resource service. +// This will panic if the type field in the ID argument is not a Service type. func (f *Fetcher) FetchService(ctx context.Context, id *pbresource.ID) (*types.DecodedService, error) { + assertResourceType(pbcatalog.ServiceType, id.Type) return resource.GetDecodedResource[*pbcatalog.Service](ctx, f.client, id) } +// FetchDestinations fetches a service resource from the resource service. +// This will panic if the type field in the ID argument is not a Destinations type. func (f *Fetcher) FetchDestinations(ctx context.Context, id *pbresource.ID) (*types.DecodedDestinations, error) { + assertResourceType(pbmesh.DestinationsType, id.Type) return resource.GetDecodedResource[*pbmesh.Destinations](ctx, f.client, id) } +// FetchComputedRoutes fetches a service resource from the resource service. +// This will panic if the type field in the ID argument is not a ComputedRoutes type. func (f *Fetcher) FetchComputedRoutes(ctx context.Context, id *pbresource.ID) (*types.DecodedComputedRoutes, error) { + assertResourceType(pbmesh.ComputedRoutesType, id.Type) if !types.IsComputedRoutesType(id.Type) { return nil, fmt.Errorf("id must be a ComputedRoutes type") } @@ -83,11 +104,10 @@ func (f *Fetcher) FetchComputedRoutes(ctx context.Context, id *pbresource.ID) (* return dec, err } -func (f *Fetcher) FetchExplicitDestinationsData( +func (f *Fetcher) FetchComputedExplicitDestinationsData( ctx context.Context, proxyID *pbresource.ID, ) ([]*intermediateTypes.Destination, error) { - var destinations []*intermediateTypes.Destination // Fetch computed explicit destinations first. @@ -107,9 +127,7 @@ func (f *Fetcher) FetchExplicitDestinationsData( for _, dest := range cd.GetData().GetDestinations() { d := &intermediateTypes.Destination{} - var ( - serviceID = resource.IDFromReference(dest.DestinationRef) - ) + serviceID := resource.IDFromReference(dest.DestinationRef) // Fetch Service svc, err := f.FetchService(ctx, serviceID) @@ -372,3 +390,10 @@ func isPartOfService(workloadID *pbresource.ID, svc *types.DecodedService) bool } return false } + +func assertResourceType(expected, actual *pbresource.Type) { + if !proto.Equal(expected, actual) { + // this is always a programmer error so safe to panic + panic(fmt.Sprintf("expected a query for a type of %q, you provided a type of %q", expected, actual)) + } +} diff --git a/internal/mesh/internal/controllers/sidecarproxy/fetcher/data_fetcher_test.go b/internal/mesh/internal/controllers/sidecarproxy/fetcher/data_fetcher_test.go index c9cdd54751..d16600e4bc 100644 --- a/internal/mesh/internal/controllers/sidecarproxy/fetcher/data_fetcher_test.go +++ b/internal/mesh/internal/controllers/sidecarproxy/fetcher/data_fetcher_test.go @@ -227,9 +227,7 @@ func (suite *dataFetcherSuite) TestFetcher_FetchExplicitDestinationsData() { suite.runTestCaseWithTenancies(func(tenancy *pbresource.Tenancy) { c := cache.New() - var ( - api1ServiceRef = resource.Reference(suite.api1Service.Id, "") - ) + api1ServiceRef := resource.Reference(suite.api1Service.Id, "") f := Fetcher{ cache: c, @@ -252,7 +250,7 @@ func (suite *dataFetcherSuite) TestFetcher_FetchExplicitDestinationsData() { c.TrackComputedDestinations(resourcetest.MustDecode[*pbmesh.ComputedExplicitDestinations](t, compDest)) // We will try to fetch explicit destinations for a proxy that doesn't have one. - destinations, err := f.FetchExplicitDestinationsData(suite.ctx, suite.webProxy.Id) + destinations, err := f.FetchComputedExplicitDestinationsData(suite.ctx, suite.webProxy.Id) require.NoError(t, err) require.Nil(t, destinations) @@ -277,7 +275,7 @@ func (suite *dataFetcherSuite) TestFetcher_FetchExplicitDestinationsData() { WithTenancy(tenancy). Write(t, suite.client) - destinations, err := f.FetchExplicitDestinationsData(suite.ctx, suite.webProxy.Id) + destinations, err := f.FetchComputedExplicitDestinationsData(suite.ctx, suite.webProxy.Id) require.NoError(t, err) require.Nil(t, destinations) cachedCompDestIDs := c.ComputedDestinationsByService(resource.IDFromReference(notFoundServiceRef)) @@ -307,7 +305,7 @@ func (suite *dataFetcherSuite) TestFetcher_FetchExplicitDestinationsData() { WithTenancy(tenancy). Write(t, suite.client) - destinations, err := f.FetchExplicitDestinationsData(suite.ctx, suite.webProxy.Id) + destinations, err := f.FetchComputedExplicitDestinationsData(suite.ctx, suite.webProxy.Id) require.NoError(t, err) require.Nil(t, destinations) cachedCompDestIDs := c.ComputedDestinationsByService(resource.IDFromReference(api1ServiceRef)) @@ -337,7 +335,7 @@ func (suite *dataFetcherSuite) TestFetcher_FetchExplicitDestinationsData() { WithTenancy(tenancy). Write(t, suite.client) - destinations, err := f.FetchExplicitDestinationsData(suite.ctx, suite.webProxy.Id) + destinations, err := f.FetchComputedExplicitDestinationsData(suite.ctx, suite.webProxy.Id) require.NoError(t, err) require.Nil(t, destinations) cachedCompDestIDs := c.ComputedDestinationsByService(resource.IDFromReference(api1ServiceRef)) @@ -369,7 +367,7 @@ func (suite *dataFetcherSuite) TestFetcher_FetchExplicitDestinationsData() { WithTenancy(tenancy). Write(t, suite.client) - destinations, err := f.FetchExplicitDestinationsData(suite.ctx, suite.webProxy.Id) + destinations, err := f.FetchComputedExplicitDestinationsData(suite.ctx, suite.webProxy.Id) require.NoError(t, err) require.Empty(t, destinations) @@ -402,7 +400,7 @@ func (suite *dataFetcherSuite) TestFetcher_FetchExplicitDestinationsData() { require.NotNil(suite.T(), api1ComputedRoutes) // This destination points to TCP, but the computed routes is stale and only knows about HTTP. - destinations, err := f.FetchExplicitDestinationsData(suite.ctx, suite.webProxy.Id) + destinations, err := f.FetchComputedExplicitDestinationsData(suite.ctx, suite.webProxy.Id) require.NoError(t, err) // Check that we didn't return any destinations. @@ -483,7 +481,7 @@ func (suite *dataFetcherSuite) TestFetcher_FetchExplicitDestinationsData() { }, } - actualDestinations, err := f.FetchExplicitDestinationsData(suite.ctx, suite.webProxy.Id) + actualDestinations, err := f.FetchComputedExplicitDestinationsData(suite.ctx, suite.webProxy.Id) require.NoError(t, err) // Check that we've computed expanded destinations correctly.