diff --git a/internal/catalog/exports.go b/internal/catalog/exports.go index 4019fbeb51..566c2e2b6e 100644 --- a/internal/catalog/exports.go +++ b/internal/catalog/exports.go @@ -118,7 +118,7 @@ func SimplifyFailoverPolicy(svc *pbcatalog.Service, failover *pbcatalog.Failover // FailoverPolicyMapper maintains the bidirectional tracking relationship of a // FailoverPolicy to the Services related to it. type FailoverPolicyMapper interface { - TrackFailover(failover *resource.DecodedResource[pbcatalog.FailoverPolicy, *pbcatalog.FailoverPolicy]) + TrackFailover(failover *resource.DecodedResource[*pbcatalog.FailoverPolicy]) UntrackFailover(failoverID *pbresource.ID) FailoverIDsByService(svcID *pbresource.ID) []*pbresource.ID } diff --git a/internal/catalog/internal/controllers/failover/controller.go b/internal/catalog/internal/controllers/failover/controller.go index ea6efa992d..9accb62aa4 100644 --- a/internal/catalog/internal/controllers/failover/controller.go +++ b/internal/catalog/internal/controllers/failover/controller.go @@ -20,7 +20,7 @@ type FailoverMapper interface { // TrackFailover extracts all Service references from the provided // FailoverPolicy and indexes them so that MapService can turn Service // events into FailoverPolicy events properly. - TrackFailover(failover *resource.DecodedResource[pbcatalog.FailoverPolicy, *pbcatalog.FailoverPolicy]) + TrackFailover(failover *resource.DecodedResource[*pbcatalog.FailoverPolicy]) // UntrackFailover forgets the links inserted by TrackFailover for the // provided FailoverPolicyID. @@ -86,7 +86,7 @@ func (r *failoverPolicyReconciler) Reconcile(ctx context.Context, rt controller. rt.Logger.Error("error retrieving corresponding service", "error", err) return err } - destServices := make(map[resource.ReferenceKey]*resource.DecodedResource[pbcatalog.Service, *pbcatalog.Service]) + destServices := make(map[resource.ReferenceKey]*resource.DecodedResource[*pbcatalog.Service]) if service != nil { destServices[resource.NewReferenceKey(serviceID)] = service } @@ -148,18 +148,18 @@ func (r *failoverPolicyReconciler) Reconcile(ctx context.Context, rt controller. return nil } -func getFailoverPolicy(ctx context.Context, rt controller.Runtime, id *pbresource.ID) (*resource.DecodedResource[pbcatalog.FailoverPolicy, *pbcatalog.FailoverPolicy], error) { - return resource.GetDecodedResource[pbcatalog.FailoverPolicy, *pbcatalog.FailoverPolicy](ctx, rt.Client, id) +func getFailoverPolicy(ctx context.Context, rt controller.Runtime, id *pbresource.ID) (*resource.DecodedResource[*pbcatalog.FailoverPolicy], error) { + return resource.GetDecodedResource[*pbcatalog.FailoverPolicy](ctx, rt.Client, id) } -func getService(ctx context.Context, rt controller.Runtime, id *pbresource.ID) (*resource.DecodedResource[pbcatalog.Service, *pbcatalog.Service], error) { - return resource.GetDecodedResource[pbcatalog.Service, *pbcatalog.Service](ctx, rt.Client, id) +func getService(ctx context.Context, rt controller.Runtime, id *pbresource.ID) (*resource.DecodedResource[*pbcatalog.Service], error) { + return resource.GetDecodedResource[*pbcatalog.Service](ctx, rt.Client, id) } func computeNewStatus( - failoverPolicy *resource.DecodedResource[pbcatalog.FailoverPolicy, *pbcatalog.FailoverPolicy], - service *resource.DecodedResource[pbcatalog.Service, *pbcatalog.Service], - destServices map[resource.ReferenceKey]*resource.DecodedResource[pbcatalog.Service, *pbcatalog.Service], + failoverPolicy *resource.DecodedResource[*pbcatalog.FailoverPolicy], + service *resource.DecodedResource[*pbcatalog.Service], + destServices map[resource.ReferenceKey]*resource.DecodedResource[*pbcatalog.Service], ) *pbresource.Status { if service == nil { return &pbresource.Status{ @@ -238,7 +238,7 @@ func computeNewStatus( func serviceHasPort( dest *pbcatalog.FailoverDestination, - destServices map[resource.ReferenceKey]*resource.DecodedResource[pbcatalog.Service, *pbcatalog.Service], + destServices map[resource.ReferenceKey]*resource.DecodedResource[*pbcatalog.Service], ) *pbresource.Condition { key := resource.NewReferenceKey(dest.Ref) destService, ok := destServices[key] diff --git a/internal/catalog/internal/mappers/failovermapper/failover_mapper.go b/internal/catalog/internal/mappers/failovermapper/failover_mapper.go index 5c23a1bfe3..4ae6776cb6 100644 --- a/internal/catalog/internal/mappers/failovermapper/failover_mapper.go +++ b/internal/catalog/internal/mappers/failovermapper/failover_mapper.go @@ -31,7 +31,7 @@ func New() *Mapper { // TrackFailover extracts all Service references from the provided // FailoverPolicy and indexes them so that MapService can turn Service events // into FailoverPolicy events properly. -func (m *Mapper) TrackFailover(failover *resource.DecodedResource[pbcatalog.FailoverPolicy, *pbcatalog.FailoverPolicy]) { +func (m *Mapper) TrackFailover(failover *resource.DecodedResource[*pbcatalog.FailoverPolicy]) { destRefs := failover.Data.GetUnderlyingDestinationRefs() destRefs = append(destRefs, &pbresource.Reference{ Type: types.ServiceType, diff --git a/internal/catalog/internal/mappers/failovermapper/failover_mapper_test.go b/internal/catalog/internal/mappers/failovermapper/failover_mapper_test.go index 8a4ac2d722..048f444eca 100644 --- a/internal/catalog/internal/mappers/failovermapper/failover_mapper_test.go +++ b/internal/catalog/internal/mappers/failovermapper/failover_mapper_test.go @@ -59,7 +59,7 @@ func TestMapper_Tracking(t *testing.T) { }). Build() rtest.ValidateAndNormalize(t, registry, fail1) - failDec1 := rtest.MustDecode[pbcatalog.FailoverPolicy, *pbcatalog.FailoverPolicy](t, fail1) + failDec1 := rtest.MustDecode[*pbcatalog.FailoverPolicy](t, fail1) fail2 := rtest.Resource(types.FailoverPolicyType, "www"). WithData(t, &pbcatalog.FailoverPolicy{ @@ -72,7 +72,7 @@ func TestMapper_Tracking(t *testing.T) { }). Build() rtest.ValidateAndNormalize(t, registry, fail2) - failDec2 := rtest.MustDecode[pbcatalog.FailoverPolicy, *pbcatalog.FailoverPolicy](t, fail2) + failDec2 := rtest.MustDecode[*pbcatalog.FailoverPolicy](t, fail2) fail1_updated := rtest.Resource(types.FailoverPolicyType, "api"). WithData(t, &pbcatalog.FailoverPolicy{ @@ -84,7 +84,7 @@ func TestMapper_Tracking(t *testing.T) { }). Build() rtest.ValidateAndNormalize(t, registry, fail1_updated) - failDec1_updated := rtest.MustDecode[pbcatalog.FailoverPolicy, *pbcatalog.FailoverPolicy](t, fail1_updated) + failDec1_updated := rtest.MustDecode[*pbcatalog.FailoverPolicy](t, fail1_updated) m := New() diff --git a/internal/catalog/internal/types/failover_policy_test.go b/internal/catalog/internal/types/failover_policy_test.go index 8f2ad97172..41bfd3d827 100644 --- a/internal/catalog/internal/types/failover_policy_test.go +++ b/internal/catalog/internal/types/failover_policy_test.go @@ -31,7 +31,7 @@ func TestMutateFailoverPolicy(t *testing.T) { err := MutateFailoverPolicy(res) - got := resourcetest.MustDecode[pbcatalog.FailoverPolicy, *pbcatalog.FailoverPolicy](t, res) + got := resourcetest.MustDecode[*pbcatalog.FailoverPolicy](t, res) if tc.expectErr == "" { require.NoError(t, err) @@ -162,13 +162,13 @@ func TestValidateFailoverPolicy(t *testing.T) { require.NoError(t, MutateFailoverPolicy(res)) // Verify that mutate didn't actually change the object. - got := resourcetest.MustDecode[pbcatalog.FailoverPolicy, *pbcatalog.FailoverPolicy](t, res) + got := resourcetest.MustDecode[*pbcatalog.FailoverPolicy](t, res) prototest.AssertDeepEqual(t, tc.failover, got.Data) err := ValidateFailoverPolicy(res) // Verify that validate didn't actually change the object. - got = resourcetest.MustDecode[pbcatalog.FailoverPolicy, *pbcatalog.FailoverPolicy](t, res) + got = resourcetest.MustDecode[*pbcatalog.FailoverPolicy](t, res) prototest.AssertDeepEqual(t, tc.failover, got.Data) if tc.expectErr == "" { @@ -359,9 +359,9 @@ func TestSimplifyFailoverPolicy(t *testing.T) { resourcetest.ValidateAndNormalize(t, registry, tc.failover) resourcetest.ValidateAndNormalize(t, registry, tc.expect) - svc := resourcetest.MustDecode[pbcatalog.Service, *pbcatalog.Service](t, tc.svc) - failover := resourcetest.MustDecode[pbcatalog.FailoverPolicy, *pbcatalog.FailoverPolicy](t, tc.failover) - expect := resourcetest.MustDecode[pbcatalog.FailoverPolicy, *pbcatalog.FailoverPolicy](t, tc.expect) + svc := resourcetest.MustDecode[*pbcatalog.Service](t, tc.svc) + failover := resourcetest.MustDecode[*pbcatalog.FailoverPolicy](t, tc.failover) + expect := resourcetest.MustDecode[*pbcatalog.FailoverPolicy](t, tc.expect) inputFailoverCopy := proto.Clone(failover.Data).(*pbcatalog.FailoverPolicy) diff --git a/internal/resource/decode.go b/internal/resource/decode.go index 7b1fb7b364..c610898ca3 100644 --- a/internal/resource/decode.go +++ b/internal/resource/decode.go @@ -15,27 +15,23 @@ import ( // DecodedResource is a generic holder to contain an original Resource and its // decoded contents. -type DecodedResource[V any, PV interface { - proto.Message - *V -}] struct { +type DecodedResource[T proto.Message] struct { Resource *pbresource.Resource - Data PV + Data T } // Decode will generically decode the provided resource into a 2-field // structure that holds onto the original Resource and the decoded contents. // // Returns an ErrDataParse on unmarshalling errors. -func Decode[V any, PV interface { - proto.Message - *V -}](res *pbresource.Resource) (*DecodedResource[V, PV], error) { - data := PV(new(V)) +func Decode[T proto.Message](res *pbresource.Resource) (*DecodedResource[T], error) { + var zero T + data := zero.ProtoReflect().New().Interface().(T) + if err := res.Data.UnmarshalTo(data); err != nil { return nil, NewErrDataParse(data, err) } - return &DecodedResource[V, PV]{ + return &DecodedResource[T]{ Resource: res, Data: data, }, nil @@ -43,10 +39,7 @@ func Decode[V any, PV interface { // GetDecodedResource will generically read the requested resource using the // client and either return nil on a NotFound or decode the response value. -func GetDecodedResource[V any, PV interface { - proto.Message - *V -}](ctx context.Context, client pbresource.ResourceServiceClient, id *pbresource.ID) (*DecodedResource[V, PV], error) { +func GetDecodedResource[T proto.Message](ctx context.Context, client pbresource.ResourceServiceClient, id *pbresource.ID) (*DecodedResource[T], error) { rsp, err := client.Read(ctx, &pbresource.ReadRequest{Id: id}) switch { case status.Code(err) == codes.NotFound: @@ -55,5 +48,5 @@ func GetDecodedResource[V any, PV interface { return nil, err } - return Decode[V, PV](rsp.Resource) + return Decode[T](rsp.Resource) } diff --git a/internal/resource/decode_test.go b/internal/resource/decode_test.go index 31ebe47c64..17c1bd7f1b 100644 --- a/internal/resource/decode_test.go +++ b/internal/resource/decode_test.go @@ -34,7 +34,7 @@ func TestGetDecodedResource(t *testing.T) { } testutil.RunStep(t, "not found", func(t *testing.T) { - got, err := resource.GetDecodedResource[pbdemo.Artist, *pbdemo.Artist](ctx, client, babypantsID) + got, err := resource.GetDecodedResource[*pbdemo.Artist](ctx, client, babypantsID) require.NoError(t, err) require.Nil(t, got) }) @@ -47,7 +47,7 @@ func TestGetDecodedResource(t *testing.T) { WithData(t, data). Write(t, client) - got, err := resource.GetDecodedResource[pbdemo.Artist, *pbdemo.Artist](ctx, client, babypantsID) + got, err := resource.GetDecodedResource[*pbdemo.Artist](ctx, client, babypantsID) require.NoError(t, err) require.NotNil(t, got) @@ -84,7 +84,7 @@ func TestDecode(t *testing.T) { }, } - dec, err := resource.Decode[pbdemo.Artist, *pbdemo.Artist](foo) + dec, err := resource.Decode[*pbdemo.Artist](foo) require.NoError(t, err) prototest.AssertDeepEqual(t, foo, dec.Resource) @@ -107,7 +107,7 @@ func TestDecode(t *testing.T) { }, } - _, err := resource.Decode[pbdemo.Artist, *pbdemo.Artist](foo) + _, err := resource.Decode[*pbdemo.Artist](foo) require.Error(t, err) }) } diff --git a/internal/resource/resourcetest/decode.go b/internal/resource/resourcetest/decode.go index 077bbc0dd5..d68fff8655 100644 --- a/internal/resource/resourcetest/decode.go +++ b/internal/resource/resourcetest/decode.go @@ -13,11 +13,8 @@ import ( "github.com/hashicorp/consul/proto-public/pbresource" ) -func MustDecode[V any, PV interface { - proto.Message - *V -}](t *testing.T, res *pbresource.Resource) *resource.DecodedResource[V, PV] { - dec, err := resource.Decode[V, PV](res) +func MustDecode[T proto.Message](t *testing.T, res *pbresource.Resource) *resource.DecodedResource[T] { + dec, err := resource.Decode[T](res) require.NoError(t, err) return dec }