Reduce required type arguments for DecodedResource (#18540)

This commit is contained in:
Matt Keeler 2023-08-21 20:20:19 -04:00 committed by GitHub
parent 6d22179625
commit 547f4f8395
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 36 additions and 46 deletions

View File

@ -118,7 +118,7 @@ func SimplifyFailoverPolicy(svc *pbcatalog.Service, failover *pbcatalog.Failover
// FailoverPolicyMapper maintains the bidirectional tracking relationship of a // FailoverPolicyMapper maintains the bidirectional tracking relationship of a
// FailoverPolicy to the Services related to it. // FailoverPolicy to the Services related to it.
type FailoverPolicyMapper interface { type FailoverPolicyMapper interface {
TrackFailover(failover *resource.DecodedResource[pbcatalog.FailoverPolicy, *pbcatalog.FailoverPolicy]) TrackFailover(failover *resource.DecodedResource[*pbcatalog.FailoverPolicy])
UntrackFailover(failoverID *pbresource.ID) UntrackFailover(failoverID *pbresource.ID)
FailoverIDsByService(svcID *pbresource.ID) []*pbresource.ID FailoverIDsByService(svcID *pbresource.ID) []*pbresource.ID
} }

View File

@ -20,7 +20,7 @@ type FailoverMapper interface {
// TrackFailover extracts all Service references from the provided // TrackFailover extracts all Service references from the provided
// FailoverPolicy and indexes them so that MapService can turn Service // FailoverPolicy and indexes them so that MapService can turn Service
// events into FailoverPolicy events properly. // events into FailoverPolicy events properly.
TrackFailover(failover *resource.DecodedResource[pbcatalog.FailoverPolicy, *pbcatalog.FailoverPolicy]) TrackFailover(failover *resource.DecodedResource[*pbcatalog.FailoverPolicy])
// UntrackFailover forgets the links inserted by TrackFailover for the // UntrackFailover forgets the links inserted by TrackFailover for the
// provided FailoverPolicyID. // provided FailoverPolicyID.
@ -86,7 +86,7 @@ func (r *failoverPolicyReconciler) Reconcile(ctx context.Context, rt controller.
rt.Logger.Error("error retrieving corresponding service", "error", err) rt.Logger.Error("error retrieving corresponding service", "error", err)
return err return err
} }
destServices := make(map[resource.ReferenceKey]*resource.DecodedResource[pbcatalog.Service, *pbcatalog.Service]) destServices := make(map[resource.ReferenceKey]*resource.DecodedResource[*pbcatalog.Service])
if service != nil { if service != nil {
destServices[resource.NewReferenceKey(serviceID)] = service destServices[resource.NewReferenceKey(serviceID)] = service
} }
@ -148,18 +148,18 @@ func (r *failoverPolicyReconciler) Reconcile(ctx context.Context, rt controller.
return nil return nil
} }
func getFailoverPolicy(ctx context.Context, rt controller.Runtime, id *pbresource.ID) (*resource.DecodedResource[pbcatalog.FailoverPolicy, *pbcatalog.FailoverPolicy], error) { func getFailoverPolicy(ctx context.Context, rt controller.Runtime, id *pbresource.ID) (*resource.DecodedResource[*pbcatalog.FailoverPolicy], error) {
return resource.GetDecodedResource[pbcatalog.FailoverPolicy, *pbcatalog.FailoverPolicy](ctx, rt.Client, id) return resource.GetDecodedResource[*pbcatalog.FailoverPolicy](ctx, rt.Client, id)
} }
func getService(ctx context.Context, rt controller.Runtime, id *pbresource.ID) (*resource.DecodedResource[pbcatalog.Service, *pbcatalog.Service], error) { func getService(ctx context.Context, rt controller.Runtime, id *pbresource.ID) (*resource.DecodedResource[*pbcatalog.Service], error) {
return resource.GetDecodedResource[pbcatalog.Service, *pbcatalog.Service](ctx, rt.Client, id) return resource.GetDecodedResource[*pbcatalog.Service](ctx, rt.Client, id)
} }
func computeNewStatus( func computeNewStatus(
failoverPolicy *resource.DecodedResource[pbcatalog.FailoverPolicy, *pbcatalog.FailoverPolicy], failoverPolicy *resource.DecodedResource[*pbcatalog.FailoverPolicy],
service *resource.DecodedResource[pbcatalog.Service, *pbcatalog.Service], service *resource.DecodedResource[*pbcatalog.Service],
destServices map[resource.ReferenceKey]*resource.DecodedResource[pbcatalog.Service, *pbcatalog.Service], destServices map[resource.ReferenceKey]*resource.DecodedResource[*pbcatalog.Service],
) *pbresource.Status { ) *pbresource.Status {
if service == nil { if service == nil {
return &pbresource.Status{ return &pbresource.Status{
@ -238,7 +238,7 @@ func computeNewStatus(
func serviceHasPort( func serviceHasPort(
dest *pbcatalog.FailoverDestination, dest *pbcatalog.FailoverDestination,
destServices map[resource.ReferenceKey]*resource.DecodedResource[pbcatalog.Service, *pbcatalog.Service], destServices map[resource.ReferenceKey]*resource.DecodedResource[*pbcatalog.Service],
) *pbresource.Condition { ) *pbresource.Condition {
key := resource.NewReferenceKey(dest.Ref) key := resource.NewReferenceKey(dest.Ref)
destService, ok := destServices[key] destService, ok := destServices[key]

View File

@ -31,7 +31,7 @@ func New() *Mapper {
// TrackFailover extracts all Service references from the provided // TrackFailover extracts all Service references from the provided
// FailoverPolicy and indexes them so that MapService can turn Service events // FailoverPolicy and indexes them so that MapService can turn Service events
// into FailoverPolicy events properly. // into FailoverPolicy events properly.
func (m *Mapper) TrackFailover(failover *resource.DecodedResource[pbcatalog.FailoverPolicy, *pbcatalog.FailoverPolicy]) { func (m *Mapper) TrackFailover(failover *resource.DecodedResource[*pbcatalog.FailoverPolicy]) {
destRefs := failover.Data.GetUnderlyingDestinationRefs() destRefs := failover.Data.GetUnderlyingDestinationRefs()
destRefs = append(destRefs, &pbresource.Reference{ destRefs = append(destRefs, &pbresource.Reference{
Type: types.ServiceType, Type: types.ServiceType,

View File

@ -59,7 +59,7 @@ func TestMapper_Tracking(t *testing.T) {
}). }).
Build() Build()
rtest.ValidateAndNormalize(t, registry, fail1) rtest.ValidateAndNormalize(t, registry, fail1)
failDec1 := rtest.MustDecode[pbcatalog.FailoverPolicy, *pbcatalog.FailoverPolicy](t, fail1) failDec1 := rtest.MustDecode[*pbcatalog.FailoverPolicy](t, fail1)
fail2 := rtest.Resource(types.FailoverPolicyType, "www"). fail2 := rtest.Resource(types.FailoverPolicyType, "www").
WithData(t, &pbcatalog.FailoverPolicy{ WithData(t, &pbcatalog.FailoverPolicy{
@ -72,7 +72,7 @@ func TestMapper_Tracking(t *testing.T) {
}). }).
Build() Build()
rtest.ValidateAndNormalize(t, registry, fail2) rtest.ValidateAndNormalize(t, registry, fail2)
failDec2 := rtest.MustDecode[pbcatalog.FailoverPolicy, *pbcatalog.FailoverPolicy](t, fail2) failDec2 := rtest.MustDecode[*pbcatalog.FailoverPolicy](t, fail2)
fail1_updated := rtest.Resource(types.FailoverPolicyType, "api"). fail1_updated := rtest.Resource(types.FailoverPolicyType, "api").
WithData(t, &pbcatalog.FailoverPolicy{ WithData(t, &pbcatalog.FailoverPolicy{
@ -84,7 +84,7 @@ func TestMapper_Tracking(t *testing.T) {
}). }).
Build() Build()
rtest.ValidateAndNormalize(t, registry, fail1_updated) rtest.ValidateAndNormalize(t, registry, fail1_updated)
failDec1_updated := rtest.MustDecode[pbcatalog.FailoverPolicy, *pbcatalog.FailoverPolicy](t, fail1_updated) failDec1_updated := rtest.MustDecode[*pbcatalog.FailoverPolicy](t, fail1_updated)
m := New() m := New()

View File

@ -31,7 +31,7 @@ func TestMutateFailoverPolicy(t *testing.T) {
err := MutateFailoverPolicy(res) err := MutateFailoverPolicy(res)
got := resourcetest.MustDecode[pbcatalog.FailoverPolicy, *pbcatalog.FailoverPolicy](t, res) got := resourcetest.MustDecode[*pbcatalog.FailoverPolicy](t, res)
if tc.expectErr == "" { if tc.expectErr == "" {
require.NoError(t, err) require.NoError(t, err)
@ -162,13 +162,13 @@ func TestValidateFailoverPolicy(t *testing.T) {
require.NoError(t, MutateFailoverPolicy(res)) require.NoError(t, MutateFailoverPolicy(res))
// Verify that mutate didn't actually change the object. // Verify that mutate didn't actually change the object.
got := resourcetest.MustDecode[pbcatalog.FailoverPolicy, *pbcatalog.FailoverPolicy](t, res) got := resourcetest.MustDecode[*pbcatalog.FailoverPolicy](t, res)
prototest.AssertDeepEqual(t, tc.failover, got.Data) prototest.AssertDeepEqual(t, tc.failover, got.Data)
err := ValidateFailoverPolicy(res) err := ValidateFailoverPolicy(res)
// Verify that validate didn't actually change the object. // Verify that validate didn't actually change the object.
got = resourcetest.MustDecode[pbcatalog.FailoverPolicy, *pbcatalog.FailoverPolicy](t, res) got = resourcetest.MustDecode[*pbcatalog.FailoverPolicy](t, res)
prototest.AssertDeepEqual(t, tc.failover, got.Data) prototest.AssertDeepEqual(t, tc.failover, got.Data)
if tc.expectErr == "" { if tc.expectErr == "" {
@ -359,9 +359,9 @@ func TestSimplifyFailoverPolicy(t *testing.T) {
resourcetest.ValidateAndNormalize(t, registry, tc.failover) resourcetest.ValidateAndNormalize(t, registry, tc.failover)
resourcetest.ValidateAndNormalize(t, registry, tc.expect) resourcetest.ValidateAndNormalize(t, registry, tc.expect)
svc := resourcetest.MustDecode[pbcatalog.Service, *pbcatalog.Service](t, tc.svc) svc := resourcetest.MustDecode[*pbcatalog.Service](t, tc.svc)
failover := resourcetest.MustDecode[pbcatalog.FailoverPolicy, *pbcatalog.FailoverPolicy](t, tc.failover) failover := resourcetest.MustDecode[*pbcatalog.FailoverPolicy](t, tc.failover)
expect := resourcetest.MustDecode[pbcatalog.FailoverPolicy, *pbcatalog.FailoverPolicy](t, tc.expect) expect := resourcetest.MustDecode[*pbcatalog.FailoverPolicy](t, tc.expect)
inputFailoverCopy := proto.Clone(failover.Data).(*pbcatalog.FailoverPolicy) inputFailoverCopy := proto.Clone(failover.Data).(*pbcatalog.FailoverPolicy)

View File

@ -15,27 +15,23 @@ import (
// DecodedResource is a generic holder to contain an original Resource and its // DecodedResource is a generic holder to contain an original Resource and its
// decoded contents. // decoded contents.
type DecodedResource[V any, PV interface { type DecodedResource[T proto.Message] struct {
proto.Message
*V
}] struct {
Resource *pbresource.Resource Resource *pbresource.Resource
Data PV Data T
} }
// Decode will generically decode the provided resource into a 2-field // Decode will generically decode the provided resource into a 2-field
// structure that holds onto the original Resource and the decoded contents. // structure that holds onto the original Resource and the decoded contents.
// //
// Returns an ErrDataParse on unmarshalling errors. // Returns an ErrDataParse on unmarshalling errors.
func Decode[V any, PV interface { func Decode[T proto.Message](res *pbresource.Resource) (*DecodedResource[T], error) {
proto.Message var zero T
*V data := zero.ProtoReflect().New().Interface().(T)
}](res *pbresource.Resource) (*DecodedResource[V, PV], error) {
data := PV(new(V))
if err := res.Data.UnmarshalTo(data); err != nil { if err := res.Data.UnmarshalTo(data); err != nil {
return nil, NewErrDataParse(data, err) return nil, NewErrDataParse(data, err)
} }
return &DecodedResource[V, PV]{ return &DecodedResource[T]{
Resource: res, Resource: res,
Data: data, Data: data,
}, nil }, nil
@ -43,10 +39,7 @@ func Decode[V any, PV interface {
// GetDecodedResource will generically read the requested resource using the // GetDecodedResource will generically read the requested resource using the
// client and either return nil on a NotFound or decode the response value. // client and either return nil on a NotFound or decode the response value.
func GetDecodedResource[V any, PV interface { func GetDecodedResource[T proto.Message](ctx context.Context, client pbresource.ResourceServiceClient, id *pbresource.ID) (*DecodedResource[T], error) {
proto.Message
*V
}](ctx context.Context, client pbresource.ResourceServiceClient, id *pbresource.ID) (*DecodedResource[V, PV], error) {
rsp, err := client.Read(ctx, &pbresource.ReadRequest{Id: id}) rsp, err := client.Read(ctx, &pbresource.ReadRequest{Id: id})
switch { switch {
case status.Code(err) == codes.NotFound: case status.Code(err) == codes.NotFound:
@ -55,5 +48,5 @@ func GetDecodedResource[V any, PV interface {
return nil, err return nil, err
} }
return Decode[V, PV](rsp.Resource) return Decode[T](rsp.Resource)
} }

View File

@ -34,7 +34,7 @@ func TestGetDecodedResource(t *testing.T) {
} }
testutil.RunStep(t, "not found", func(t *testing.T) { testutil.RunStep(t, "not found", func(t *testing.T) {
got, err := resource.GetDecodedResource[pbdemo.Artist, *pbdemo.Artist](ctx, client, babypantsID) got, err := resource.GetDecodedResource[*pbdemo.Artist](ctx, client, babypantsID)
require.NoError(t, err) require.NoError(t, err)
require.Nil(t, got) require.Nil(t, got)
}) })
@ -47,7 +47,7 @@ func TestGetDecodedResource(t *testing.T) {
WithData(t, data). WithData(t, data).
Write(t, client) Write(t, client)
got, err := resource.GetDecodedResource[pbdemo.Artist, *pbdemo.Artist](ctx, client, babypantsID) got, err := resource.GetDecodedResource[*pbdemo.Artist](ctx, client, babypantsID)
require.NoError(t, err) require.NoError(t, err)
require.NotNil(t, got) require.NotNil(t, got)
@ -84,7 +84,7 @@ func TestDecode(t *testing.T) {
}, },
} }
dec, err := resource.Decode[pbdemo.Artist, *pbdemo.Artist](foo) dec, err := resource.Decode[*pbdemo.Artist](foo)
require.NoError(t, err) require.NoError(t, err)
prototest.AssertDeepEqual(t, foo, dec.Resource) prototest.AssertDeepEqual(t, foo, dec.Resource)
@ -107,7 +107,7 @@ func TestDecode(t *testing.T) {
}, },
} }
_, err := resource.Decode[pbdemo.Artist, *pbdemo.Artist](foo) _, err := resource.Decode[*pbdemo.Artist](foo)
require.Error(t, err) require.Error(t, err)
}) })
} }

View File

@ -13,11 +13,8 @@ import (
"github.com/hashicorp/consul/proto-public/pbresource" "github.com/hashicorp/consul/proto-public/pbresource"
) )
func MustDecode[V any, PV interface { func MustDecode[T proto.Message](t *testing.T, res *pbresource.Resource) *resource.DecodedResource[T] {
proto.Message dec, err := resource.Decode[T](res)
*V
}](t *testing.T, res *pbresource.Resource) *resource.DecodedResource[V, PV] {
dec, err := resource.Decode[V, PV](res)
require.NoError(t, err) require.NoError(t, err)
return dec return dec
} }