Reduce required type arguments for DecodedResource (#18540)

This commit is contained in:
Matt Keeler 2023-08-21 20:20:19 -04:00 committed by GitHub
parent 6d22179625
commit 547f4f8395
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 36 additions and 46 deletions

View File

@ -118,7 +118,7 @@ func SimplifyFailoverPolicy(svc *pbcatalog.Service, failover *pbcatalog.Failover
// FailoverPolicyMapper maintains the bidirectional tracking relationship of a
// FailoverPolicy to the Services related to it.
type FailoverPolicyMapper interface {
TrackFailover(failover *resource.DecodedResource[pbcatalog.FailoverPolicy, *pbcatalog.FailoverPolicy])
TrackFailover(failover *resource.DecodedResource[*pbcatalog.FailoverPolicy])
UntrackFailover(failoverID *pbresource.ID)
FailoverIDsByService(svcID *pbresource.ID) []*pbresource.ID
}

View File

@ -20,7 +20,7 @@ type FailoverMapper interface {
// TrackFailover extracts all Service references from the provided
// FailoverPolicy and indexes them so that MapService can turn Service
// events into FailoverPolicy events properly.
TrackFailover(failover *resource.DecodedResource[pbcatalog.FailoverPolicy, *pbcatalog.FailoverPolicy])
TrackFailover(failover *resource.DecodedResource[*pbcatalog.FailoverPolicy])
// UntrackFailover forgets the links inserted by TrackFailover for the
// provided FailoverPolicyID.
@ -86,7 +86,7 @@ func (r *failoverPolicyReconciler) Reconcile(ctx context.Context, rt controller.
rt.Logger.Error("error retrieving corresponding service", "error", err)
return err
}
destServices := make(map[resource.ReferenceKey]*resource.DecodedResource[pbcatalog.Service, *pbcatalog.Service])
destServices := make(map[resource.ReferenceKey]*resource.DecodedResource[*pbcatalog.Service])
if service != nil {
destServices[resource.NewReferenceKey(serviceID)] = service
}
@ -148,18 +148,18 @@ func (r *failoverPolicyReconciler) Reconcile(ctx context.Context, rt controller.
return nil
}
func getFailoverPolicy(ctx context.Context, rt controller.Runtime, id *pbresource.ID) (*resource.DecodedResource[pbcatalog.FailoverPolicy, *pbcatalog.FailoverPolicy], error) {
return resource.GetDecodedResource[pbcatalog.FailoverPolicy, *pbcatalog.FailoverPolicy](ctx, rt.Client, id)
func getFailoverPolicy(ctx context.Context, rt controller.Runtime, id *pbresource.ID) (*resource.DecodedResource[*pbcatalog.FailoverPolicy], error) {
return resource.GetDecodedResource[*pbcatalog.FailoverPolicy](ctx, rt.Client, id)
}
func getService(ctx context.Context, rt controller.Runtime, id *pbresource.ID) (*resource.DecodedResource[pbcatalog.Service, *pbcatalog.Service], error) {
return resource.GetDecodedResource[pbcatalog.Service, *pbcatalog.Service](ctx, rt.Client, id)
func getService(ctx context.Context, rt controller.Runtime, id *pbresource.ID) (*resource.DecodedResource[*pbcatalog.Service], error) {
return resource.GetDecodedResource[*pbcatalog.Service](ctx, rt.Client, id)
}
func computeNewStatus(
failoverPolicy *resource.DecodedResource[pbcatalog.FailoverPolicy, *pbcatalog.FailoverPolicy],
service *resource.DecodedResource[pbcatalog.Service, *pbcatalog.Service],
destServices map[resource.ReferenceKey]*resource.DecodedResource[pbcatalog.Service, *pbcatalog.Service],
failoverPolicy *resource.DecodedResource[*pbcatalog.FailoverPolicy],
service *resource.DecodedResource[*pbcatalog.Service],
destServices map[resource.ReferenceKey]*resource.DecodedResource[*pbcatalog.Service],
) *pbresource.Status {
if service == nil {
return &pbresource.Status{
@ -238,7 +238,7 @@ func computeNewStatus(
func serviceHasPort(
dest *pbcatalog.FailoverDestination,
destServices map[resource.ReferenceKey]*resource.DecodedResource[pbcatalog.Service, *pbcatalog.Service],
destServices map[resource.ReferenceKey]*resource.DecodedResource[*pbcatalog.Service],
) *pbresource.Condition {
key := resource.NewReferenceKey(dest.Ref)
destService, ok := destServices[key]

View File

@ -31,7 +31,7 @@ func New() *Mapper {
// TrackFailover extracts all Service references from the provided
// FailoverPolicy and indexes them so that MapService can turn Service events
// into FailoverPolicy events properly.
func (m *Mapper) TrackFailover(failover *resource.DecodedResource[pbcatalog.FailoverPolicy, *pbcatalog.FailoverPolicy]) {
func (m *Mapper) TrackFailover(failover *resource.DecodedResource[*pbcatalog.FailoverPolicy]) {
destRefs := failover.Data.GetUnderlyingDestinationRefs()
destRefs = append(destRefs, &pbresource.Reference{
Type: types.ServiceType,

View File

@ -59,7 +59,7 @@ func TestMapper_Tracking(t *testing.T) {
}).
Build()
rtest.ValidateAndNormalize(t, registry, fail1)
failDec1 := rtest.MustDecode[pbcatalog.FailoverPolicy, *pbcatalog.FailoverPolicy](t, fail1)
failDec1 := rtest.MustDecode[*pbcatalog.FailoverPolicy](t, fail1)
fail2 := rtest.Resource(types.FailoverPolicyType, "www").
WithData(t, &pbcatalog.FailoverPolicy{
@ -72,7 +72,7 @@ func TestMapper_Tracking(t *testing.T) {
}).
Build()
rtest.ValidateAndNormalize(t, registry, fail2)
failDec2 := rtest.MustDecode[pbcatalog.FailoverPolicy, *pbcatalog.FailoverPolicy](t, fail2)
failDec2 := rtest.MustDecode[*pbcatalog.FailoverPolicy](t, fail2)
fail1_updated := rtest.Resource(types.FailoverPolicyType, "api").
WithData(t, &pbcatalog.FailoverPolicy{
@ -84,7 +84,7 @@ func TestMapper_Tracking(t *testing.T) {
}).
Build()
rtest.ValidateAndNormalize(t, registry, fail1_updated)
failDec1_updated := rtest.MustDecode[pbcatalog.FailoverPolicy, *pbcatalog.FailoverPolicy](t, fail1_updated)
failDec1_updated := rtest.MustDecode[*pbcatalog.FailoverPolicy](t, fail1_updated)
m := New()

View File

@ -31,7 +31,7 @@ func TestMutateFailoverPolicy(t *testing.T) {
err := MutateFailoverPolicy(res)
got := resourcetest.MustDecode[pbcatalog.FailoverPolicy, *pbcatalog.FailoverPolicy](t, res)
got := resourcetest.MustDecode[*pbcatalog.FailoverPolicy](t, res)
if tc.expectErr == "" {
require.NoError(t, err)
@ -162,13 +162,13 @@ func TestValidateFailoverPolicy(t *testing.T) {
require.NoError(t, MutateFailoverPolicy(res))
// Verify that mutate didn't actually change the object.
got := resourcetest.MustDecode[pbcatalog.FailoverPolicy, *pbcatalog.FailoverPolicy](t, res)
got := resourcetest.MustDecode[*pbcatalog.FailoverPolicy](t, res)
prototest.AssertDeepEqual(t, tc.failover, got.Data)
err := ValidateFailoverPolicy(res)
// Verify that validate didn't actually change the object.
got = resourcetest.MustDecode[pbcatalog.FailoverPolicy, *pbcatalog.FailoverPolicy](t, res)
got = resourcetest.MustDecode[*pbcatalog.FailoverPolicy](t, res)
prototest.AssertDeepEqual(t, tc.failover, got.Data)
if tc.expectErr == "" {
@ -359,9 +359,9 @@ func TestSimplifyFailoverPolicy(t *testing.T) {
resourcetest.ValidateAndNormalize(t, registry, tc.failover)
resourcetest.ValidateAndNormalize(t, registry, tc.expect)
svc := resourcetest.MustDecode[pbcatalog.Service, *pbcatalog.Service](t, tc.svc)
failover := resourcetest.MustDecode[pbcatalog.FailoverPolicy, *pbcatalog.FailoverPolicy](t, tc.failover)
expect := resourcetest.MustDecode[pbcatalog.FailoverPolicy, *pbcatalog.FailoverPolicy](t, tc.expect)
svc := resourcetest.MustDecode[*pbcatalog.Service](t, tc.svc)
failover := resourcetest.MustDecode[*pbcatalog.FailoverPolicy](t, tc.failover)
expect := resourcetest.MustDecode[*pbcatalog.FailoverPolicy](t, tc.expect)
inputFailoverCopy := proto.Clone(failover.Data).(*pbcatalog.FailoverPolicy)

View File

@ -15,27 +15,23 @@ import (
// DecodedResource is a generic holder to contain an original Resource and its
// decoded contents.
type DecodedResource[V any, PV interface {
proto.Message
*V
}] struct {
type DecodedResource[T proto.Message] struct {
Resource *pbresource.Resource
Data PV
Data T
}
// Decode will generically decode the provided resource into a 2-field
// structure that holds onto the original Resource and the decoded contents.
//
// Returns an ErrDataParse on unmarshalling errors.
func Decode[V any, PV interface {
proto.Message
*V
}](res *pbresource.Resource) (*DecodedResource[V, PV], error) {
data := PV(new(V))
func Decode[T proto.Message](res *pbresource.Resource) (*DecodedResource[T], error) {
var zero T
data := zero.ProtoReflect().New().Interface().(T)
if err := res.Data.UnmarshalTo(data); err != nil {
return nil, NewErrDataParse(data, err)
}
return &DecodedResource[V, PV]{
return &DecodedResource[T]{
Resource: res,
Data: data,
}, nil
@ -43,10 +39,7 @@ func Decode[V any, PV interface {
// GetDecodedResource will generically read the requested resource using the
// client and either return nil on a NotFound or decode the response value.
func GetDecodedResource[V any, PV interface {
proto.Message
*V
}](ctx context.Context, client pbresource.ResourceServiceClient, id *pbresource.ID) (*DecodedResource[V, PV], error) {
func GetDecodedResource[T proto.Message](ctx context.Context, client pbresource.ResourceServiceClient, id *pbresource.ID) (*DecodedResource[T], error) {
rsp, err := client.Read(ctx, &pbresource.ReadRequest{Id: id})
switch {
case status.Code(err) == codes.NotFound:
@ -55,5 +48,5 @@ func GetDecodedResource[V any, PV interface {
return nil, err
}
return Decode[V, PV](rsp.Resource)
return Decode[T](rsp.Resource)
}

View File

@ -34,7 +34,7 @@ func TestGetDecodedResource(t *testing.T) {
}
testutil.RunStep(t, "not found", func(t *testing.T) {
got, err := resource.GetDecodedResource[pbdemo.Artist, *pbdemo.Artist](ctx, client, babypantsID)
got, err := resource.GetDecodedResource[*pbdemo.Artist](ctx, client, babypantsID)
require.NoError(t, err)
require.Nil(t, got)
})
@ -47,7 +47,7 @@ func TestGetDecodedResource(t *testing.T) {
WithData(t, data).
Write(t, client)
got, err := resource.GetDecodedResource[pbdemo.Artist, *pbdemo.Artist](ctx, client, babypantsID)
got, err := resource.GetDecodedResource[*pbdemo.Artist](ctx, client, babypantsID)
require.NoError(t, err)
require.NotNil(t, got)
@ -84,7 +84,7 @@ func TestDecode(t *testing.T) {
},
}
dec, err := resource.Decode[pbdemo.Artist, *pbdemo.Artist](foo)
dec, err := resource.Decode[*pbdemo.Artist](foo)
require.NoError(t, err)
prototest.AssertDeepEqual(t, foo, dec.Resource)
@ -107,7 +107,7 @@ func TestDecode(t *testing.T) {
},
}
_, err := resource.Decode[pbdemo.Artist, *pbdemo.Artist](foo)
_, err := resource.Decode[*pbdemo.Artist](foo)
require.Error(t, err)
})
}

View File

@ -13,11 +13,8 @@ import (
"github.com/hashicorp/consul/proto-public/pbresource"
)
func MustDecode[V any, PV interface {
proto.Message
*V
}](t *testing.T, res *pbresource.Resource) *resource.DecodedResource[V, PV] {
dec, err := resource.Decode[V, PV](res)
func MustDecode[T proto.Message](t *testing.T, res *pbresource.Resource) *resource.DecodedResource[T] {
dec, err := resource.Decode[T](res)
require.NoError(t, err)
return dec
}