[NET-7265] Panic when passing an incorrect type to the data fetcher for gatewayproxy (#20238)

* panic when passing an incorrect type to the data fetcher

* Add assertions for sidecarproxy datafetcher as well

* rename assertion function

* Add in comments to ensure devs know about potential panics for using
invalid types

* fix method call
This commit is contained in:
John Maguire 2024-01-24 14:16:56 -05:00 committed by GitHub
parent 1eca44aef9
commit cfe4d59938
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 100 additions and 21 deletions

View File

@ -120,7 +120,7 @@ func (r *reconciler) Reconcile(ctx context.Context, rt controller.Runtime, req c
}
var exportedServices []*pbmulticluster.ComputedExportedService
dec, err := dataFetcher.FetchExportedServices(ctx, exportedServicesID)
dec, err := dataFetcher.FetchComputedExportedServices(ctx, exportedServicesID)
if err != nil {
rt.Logger.Error("error reading the associated exported services", "error", err)
} else if dec == nil {

View File

@ -5,6 +5,7 @@ package fetcher
import (
"context"
"fmt"
"github.com/hashicorp/consul/internal/mesh/internal/controllers/sidecarproxy/cache"
"github.com/hashicorp/consul/internal/mesh/internal/types"
@ -13,6 +14,7 @@ import (
pbmesh "github.com/hashicorp/consul/proto-public/pbmesh/v2beta1"
pbmulticluster "github.com/hashicorp/consul/proto-public/pbmulticluster/v2beta1"
"github.com/hashicorp/consul/proto-public/pbresource"
"google.golang.org/protobuf/proto"
)
type Fetcher struct {
@ -27,7 +29,11 @@ func New(client pbresource.ResourceServiceClient, cache *cache.Cache) *Fetcher {
}
}
// FetchMeshGateway fetches a service resource from the resource service.
// This will panic if the type field in the ID argument is not a MeshGateway type.
func (f *Fetcher) FetchMeshGateway(ctx context.Context, id *pbresource.ID) (*types.DecodedMeshGateway, error) {
assertResourceType(pbmesh.MeshGatewayType, id.Type)
dec, err := resource.GetDecodedResource[*pbmesh.MeshGateway](ctx, f.client, id)
if err != nil {
return nil, err
@ -38,7 +44,11 @@ func (f *Fetcher) FetchMeshGateway(ctx context.Context, id *pbresource.ID) (*typ
return dec, nil
}
// FetchProxyStateTemplate fetches a service resource from the resource service.
// This will panic if the type field in the ID argument is not a ProxyStateTemplate type.
func (f *Fetcher) FetchProxyStateTemplate(ctx context.Context, id *pbresource.ID) (*types.DecodedProxyStateTemplate, error) {
assertResourceType(pbmesh.ProxyStateTemplateType, id.Type)
dec, err := resource.GetDecodedResource[*pbmesh.ProxyStateTemplate](ctx, f.client, id)
if err != nil {
return nil, err
@ -49,7 +59,11 @@ func (f *Fetcher) FetchProxyStateTemplate(ctx context.Context, id *pbresource.ID
return dec, nil
}
// FetchWorkload fetches a service resource from the resource service.
// This will panic if the type field in the ID argument is not a Workload type.
func (f *Fetcher) FetchWorkload(ctx context.Context, id *pbresource.ID) (*types.DecodedWorkload, error) {
assertResourceType(pbcatalog.WorkloadType, id.Type)
dec, err := resource.GetDecodedResource[*pbcatalog.Workload](ctx, f.client, id)
if err != nil {
return nil, err
@ -60,7 +74,11 @@ func (f *Fetcher) FetchWorkload(ctx context.Context, id *pbresource.ID) (*types.
return dec, nil
}
func (f *Fetcher) FetchExportedServices(ctx context.Context, id *pbresource.ID) (*types.DecodedComputedExportedServices, error) {
// FetchComputedExportedServices fetches a service resource from the resource service.
// This will panic if the type field in the ID argument is not a ComputedExportedServices type.
func (f *Fetcher) FetchComputedExportedServices(ctx context.Context, id *pbresource.ID) (*types.DecodedComputedExportedServices, error) {
assertResourceType(pbmulticluster.ComputedExportedServicesType, id.Type)
dec, err := resource.GetDecodedResource[*pbmulticluster.ComputedExportedServices](ctx, f.client, id)
if err != nil {
return nil, err
@ -71,7 +89,11 @@ func (f *Fetcher) FetchExportedServices(ctx context.Context, id *pbresource.ID)
return dec, nil
}
// FetchService fetches a service resource from the resource service.
// This will panic if the type field in the ID argument is not a Service type.
func (f *Fetcher) FetchService(ctx context.Context, id *pbresource.ID) (*types.DecodedService, error) {
assertResourceType(pbcatalog.ServiceType, id.Type)
dec, err := resource.GetDecodedResource[*pbcatalog.Service](ctx, f.client, id)
if err != nil {
return nil, err
@ -81,3 +103,11 @@ func (f *Fetcher) FetchService(ctx context.Context, id *pbresource.ID) (*types.D
return dec, nil
}
// this is a helper function to ensure that the resource type we are querying for is the type we expect
func assertResourceType(expected, actual *pbresource.Type) {
if !proto.Equal(expected, actual) {
// this is always a programmer error so safe to panic
panic(fmt.Sprintf("expected a query for a type of %q, you provided a type of %q", expected, actual))
}
}

View File

@ -124,6 +124,15 @@ func (suite *dataFetcherSuite) TestFetcher_FetchMeshGateway() {
require.NoError(t, err)
require.NotNil(t, gtw)
})
testutil.RunStep(suite.T(), "incorrect type is passed", func(t *testing.T) {
incorrectID := resourcetest.Resource(pbcatalog.ServiceType, "api-1").ID()
defer func() {
err := recover()
require.NotNil(t, err)
}()
f.FetchMeshGateway(suite.ctx, incorrectID)
})
})
}
@ -148,6 +157,15 @@ func (suite *dataFetcherSuite) TestFetcher_FetchProxyStateTemplate() {
require.NoError(t, err)
require.NotNil(t, tmpl)
})
testutil.RunStep(suite.T(), "incorrect type is passed", func(t *testing.T) {
incorrectID := resourcetest.Resource(pbcatalog.ServiceType, "api-1").ID()
defer func() {
err := recover()
require.NotNil(t, err)
}()
f.FetchProxyStateTemplate(suite.ctx, incorrectID)
})
})
}
@ -186,13 +204,13 @@ func (suite *dataFetcherSuite) TestFetcher_FetchExportedServices() {
testutil.RunStep(suite.T(), "exported services do not exist", func(t *testing.T) {
nonExistantID := resourcetest.Resource(pbmulticluster.ComputedExportedServicesType, "not-found").ID()
svcs, err := f.FetchExportedServices(suite.ctx, nonExistantID)
svcs, err := f.FetchComputedExportedServices(suite.ctx, nonExistantID)
require.NoError(t, err)
require.Nil(t, svcs)
})
testutil.RunStep(suite.T(), "workload exists", func(t *testing.T) {
svcs, err := f.FetchExportedServices(suite.ctx, suite.exportedServices.Id)
svcs, err := f.FetchComputedExportedServices(suite.ctx, suite.exportedServices.Id)
require.NoError(t, err)
require.NotNil(t, svcs)
})
@ -220,6 +238,15 @@ func (suite *dataFetcherSuite) TestFetcher_FetchService() {
require.NoError(t, err)
require.NotNil(t, svc)
})
testutil.RunStep(suite.T(), "incorrect type is passed", func(t *testing.T) {
incorrectID := resourcetest.Resource(pbmesh.ProxyStateTemplateType, "api-1").ID()
defer func() {
err := recover()
require.NotNil(t, err)
}()
f.FetchService(suite.ctx, incorrectID)
})
})
}

View File

@ -200,7 +200,7 @@ func (r *reconciler) Reconcile(ctx context.Context, rt controller.Runtime, req c
BuildLocalApp(workloadDataWithInheritedPorts, ctp)
// Get all destinationsData.
destinationsData, err := dataFetcher.FetchExplicitDestinationsData(ctx, req.ID)
destinationsData, err := dataFetcher.FetchComputedExplicitDestinationsData(ctx, req.ID)
if err != nil {
rt.Logger.Error("error fetching explicit destinations for this proxy", "error", err)
return err
@ -253,7 +253,6 @@ func (r *reconciler) workloadPortProtocolsFromService(
workload *types.DecodedWorkload,
logger hclog.Logger,
) (map[string]*pbcatalog.WorkloadPort, error) {
// Fetch all services for this workload.
serviceIDs := r.cache.ServicesForWorkload(workload.GetResource().GetId())

View File

@ -33,7 +33,10 @@ func New(client pbresource.ResourceServiceClient, cache *cache.Cache) *Fetcher {
}
}
// FetchWorkload fetches a service resource from the resource service.
// This will panic if the type field in the ID argument is not a Workload type.
func (f *Fetcher) FetchWorkload(ctx context.Context, id *pbresource.ID) (*types.DecodedWorkload, error) {
assertResourceType(pbcatalog.WorkloadType, id.Type)
dec, err := resource.GetDecodedResource[*pbcatalog.Workload](ctx, f.client, id)
if err != nil {
return nil, err
@ -48,27 +51,45 @@ func (f *Fetcher) FetchWorkload(ctx context.Context, id *pbresource.ID) (*types.
return dec, err
}
// FetchProxyStateTemplate fetches a service resource from the resource service.
// This will panic if the type field in the ID argument is not a ProxyStateTemplate type.
func (f *Fetcher) FetchProxyStateTemplate(ctx context.Context, id *pbresource.ID) (*types.DecodedProxyStateTemplate, error) {
assertResourceType(pbmesh.ProxyStateTemplateType, id.Type)
return resource.GetDecodedResource[*pbmesh.ProxyStateTemplate](ctx, f.client, id)
}
// FetchComputedTrafficPermissions fetches a service resource from the resource service.
// This will panic if the type field in the ID argument is not a ComputedTrafficPermissons type.
func (f *Fetcher) FetchComputedTrafficPermissions(ctx context.Context, id *pbresource.ID) (*types.DecodedComputedTrafficPermissions, error) {
assertResourceType(pbauth.ComputedTrafficPermissionsType, id.Type)
return resource.GetDecodedResource[*pbauth.ComputedTrafficPermissions](ctx, f.client, id)
}
// FetchServiceEndpoints fetches a service resource from the resource service.
// This will panic if the type field in the ID argument is not a ServiceEndpoints type.
func (f *Fetcher) FetchServiceEndpoints(ctx context.Context, id *pbresource.ID) (*types.DecodedServiceEndpoints, error) {
assertResourceType(pbcatalog.ServiceEndpointsType, id.Type)
return resource.GetDecodedResource[*pbcatalog.ServiceEndpoints](ctx, f.client, id)
}
// FetchService fetches a service resource from the resource service.
// This will panic if the type field in the ID argument is not a Service type.
func (f *Fetcher) FetchService(ctx context.Context, id *pbresource.ID) (*types.DecodedService, error) {
assertResourceType(pbcatalog.ServiceType, id.Type)
return resource.GetDecodedResource[*pbcatalog.Service](ctx, f.client, id)
}
// FetchDestinations fetches a service resource from the resource service.
// This will panic if the type field in the ID argument is not a Destinations type.
func (f *Fetcher) FetchDestinations(ctx context.Context, id *pbresource.ID) (*types.DecodedDestinations, error) {
assertResourceType(pbmesh.DestinationsType, id.Type)
return resource.GetDecodedResource[*pbmesh.Destinations](ctx, f.client, id)
}
// FetchComputedRoutes fetches a service resource from the resource service.
// This will panic if the type field in the ID argument is not a ComputedRoutes type.
func (f *Fetcher) FetchComputedRoutes(ctx context.Context, id *pbresource.ID) (*types.DecodedComputedRoutes, error) {
assertResourceType(pbmesh.ComputedRoutesType, id.Type)
if !types.IsComputedRoutesType(id.Type) {
return nil, fmt.Errorf("id must be a ComputedRoutes type")
}
@ -83,11 +104,10 @@ func (f *Fetcher) FetchComputedRoutes(ctx context.Context, id *pbresource.ID) (*
return dec, err
}
func (f *Fetcher) FetchExplicitDestinationsData(
func (f *Fetcher) FetchComputedExplicitDestinationsData(
ctx context.Context,
proxyID *pbresource.ID,
) ([]*intermediateTypes.Destination, error) {
var destinations []*intermediateTypes.Destination
// Fetch computed explicit destinations first.
@ -107,9 +127,7 @@ func (f *Fetcher) FetchExplicitDestinationsData(
for _, dest := range cd.GetData().GetDestinations() {
d := &intermediateTypes.Destination{}
var (
serviceID = resource.IDFromReference(dest.DestinationRef)
)
serviceID := resource.IDFromReference(dest.DestinationRef)
// Fetch Service
svc, err := f.FetchService(ctx, serviceID)
@ -372,3 +390,10 @@ func isPartOfService(workloadID *pbresource.ID, svc *types.DecodedService) bool
}
return false
}
func assertResourceType(expected, actual *pbresource.Type) {
if !proto.Equal(expected, actual) {
// this is always a programmer error so safe to panic
panic(fmt.Sprintf("expected a query for a type of %q, you provided a type of %q", expected, actual))
}
}

View File

@ -227,9 +227,7 @@ func (suite *dataFetcherSuite) TestFetcher_FetchExplicitDestinationsData() {
suite.runTestCaseWithTenancies(func(tenancy *pbresource.Tenancy) {
c := cache.New()
var (
api1ServiceRef = resource.Reference(suite.api1Service.Id, "")
)
api1ServiceRef := resource.Reference(suite.api1Service.Id, "")
f := Fetcher{
cache: c,
@ -252,7 +250,7 @@ func (suite *dataFetcherSuite) TestFetcher_FetchExplicitDestinationsData() {
c.TrackComputedDestinations(resourcetest.MustDecode[*pbmesh.ComputedExplicitDestinations](t, compDest))
// We will try to fetch explicit destinations for a proxy that doesn't have one.
destinations, err := f.FetchExplicitDestinationsData(suite.ctx, suite.webProxy.Id)
destinations, err := f.FetchComputedExplicitDestinationsData(suite.ctx, suite.webProxy.Id)
require.NoError(t, err)
require.Nil(t, destinations)
@ -277,7 +275,7 @@ func (suite *dataFetcherSuite) TestFetcher_FetchExplicitDestinationsData() {
WithTenancy(tenancy).
Write(t, suite.client)
destinations, err := f.FetchExplicitDestinationsData(suite.ctx, suite.webProxy.Id)
destinations, err := f.FetchComputedExplicitDestinationsData(suite.ctx, suite.webProxy.Id)
require.NoError(t, err)
require.Nil(t, destinations)
cachedCompDestIDs := c.ComputedDestinationsByService(resource.IDFromReference(notFoundServiceRef))
@ -307,7 +305,7 @@ func (suite *dataFetcherSuite) TestFetcher_FetchExplicitDestinationsData() {
WithTenancy(tenancy).
Write(t, suite.client)
destinations, err := f.FetchExplicitDestinationsData(suite.ctx, suite.webProxy.Id)
destinations, err := f.FetchComputedExplicitDestinationsData(suite.ctx, suite.webProxy.Id)
require.NoError(t, err)
require.Nil(t, destinations)
cachedCompDestIDs := c.ComputedDestinationsByService(resource.IDFromReference(api1ServiceRef))
@ -337,7 +335,7 @@ func (suite *dataFetcherSuite) TestFetcher_FetchExplicitDestinationsData() {
WithTenancy(tenancy).
Write(t, suite.client)
destinations, err := f.FetchExplicitDestinationsData(suite.ctx, suite.webProxy.Id)
destinations, err := f.FetchComputedExplicitDestinationsData(suite.ctx, suite.webProxy.Id)
require.NoError(t, err)
require.Nil(t, destinations)
cachedCompDestIDs := c.ComputedDestinationsByService(resource.IDFromReference(api1ServiceRef))
@ -369,7 +367,7 @@ func (suite *dataFetcherSuite) TestFetcher_FetchExplicitDestinationsData() {
WithTenancy(tenancy).
Write(t, suite.client)
destinations, err := f.FetchExplicitDestinationsData(suite.ctx, suite.webProxy.Id)
destinations, err := f.FetchComputedExplicitDestinationsData(suite.ctx, suite.webProxy.Id)
require.NoError(t, err)
require.Empty(t, destinations)
@ -402,7 +400,7 @@ func (suite *dataFetcherSuite) TestFetcher_FetchExplicitDestinationsData() {
require.NotNil(suite.T(), api1ComputedRoutes)
// This destination points to TCP, but the computed routes is stale and only knows about HTTP.
destinations, err := f.FetchExplicitDestinationsData(suite.ctx, suite.webProxy.Id)
destinations, err := f.FetchComputedExplicitDestinationsData(suite.ctx, suite.webProxy.Id)
require.NoError(t, err)
// Check that we didn't return any destinations.
@ -483,7 +481,7 @@ func (suite *dataFetcherSuite) TestFetcher_FetchExplicitDestinationsData() {
},
}
actualDestinations, err := f.FetchExplicitDestinationsData(suite.ctx, suite.webProxy.Id)
actualDestinations, err := f.FetchComputedExplicitDestinationsData(suite.ctx, suite.webProxy.Id)
require.NoError(t, err)
// Check that we've computed expanded destinations correctly.