118 lines
2.9 KiB
Go
Raw Normal View History

// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: BUSL-1.1
package protohcl
import (
"fmt"
"strings"
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/reflect/protoreflect"
"google.golang.org/protobuf/reflect/protoregistry"
"google.golang.org/protobuf/types/known/anypb"
)
const (
	// wellKnownTypeAny is the fully-qualified protobuf name of the
	// google.protobuf.Any well-known type. It is left untyped so it can be
	// compared directly against protoreflect.FullName values.
	wellKnownTypeAny = "google.protobuf.Any"
)
// AnyTypeProvider resolves the concrete message type that a
// google.protobuf.Any value should be decoded into.
//
// AnyType returns the full name of that message type together with the
// MessageDecoder to use when decoding the embedded message's fields.
type AnyTypeProvider interface {
	AnyType(ctx *UnmarshalContext, decoder MessageDecoder) (protoreflect.FullName, MessageDecoder, error)
}
// AnyTypeURLProvider is an AnyTypeProvider that determines the message type
// of an Any from a type URL attribute in the HCL input.
type AnyTypeURLProvider struct {
	// TypeURLFieldName is the name of the attribute holding the type URL.
	// A nil *AnyTypeURLProvider reads the "type_url" attribute.
	TypeURLFieldName string
}
// AnyType implements AnyTypeProvider. It reads the type URL from the attribute
// named by p.TypeURLFieldName and converts it to a protobuf message full name.
//
// A nil receiver, or one whose TypeURLFieldName is empty, reads the
// "type_url" attribute. Everything up to and including the last '/' of the
// URL is stripped, so both "type.googleapis.com/foo.v1.Bar" and "foo.v1.Bar"
// yield foo.v1.Bar. The returned decoder is decoder.SkipFields(name), which
// is presumably a view of the input that excludes the type URL attribute when
// the target message is decoded.
//
// An error is returned if the attribute is missing or null, or if its value
// cannot be converted to a string.
func (p *AnyTypeURLProvider) AnyType(ctx *UnmarshalContext, decoder MessageDecoder) (protoreflect.FullName, MessageDecoder, error) {
	typeURLFieldName := "type_url"
	// An empty configured name could never match an attribute, so treat it
	// the same as an unconfigured (nil) provider.
	if p != nil && p.TypeURLFieldName != "" {
		typeURLFieldName = p.TypeURLFieldName
	}

	var typeURL *IterField
	err := decoder.EachField(FieldIterator{
		Desc: (&anypb.Any{}).ProtoReflect().Descriptor(),
		Func: func(field *IterField) error {
			if field.Name == typeURLFieldName {
				typeURL = field
			}
			return nil
		},
		// The other attributes belong to the target message rather than to
		// Any itself; presumably this keeps them from being rejected here.
		IgnoreUnknown: true,
	})
	if err != nil {
		return "", nil, err
	}
	if typeURL == nil || typeURL.Val == nil {
		return "", nil, fmt.Errorf("%s field is required to decode Any", typeURLFieldName)
	}

	url, err := stringFromCty(*typeURL.Val)
	if err != nil {
		return "", nil, err
	}

	// Strip any "hostname/path" prefix, mirroring anypb.Any.MessageName.
	// Note a slash at index 0 or 1 must also be honored. A URL that ends in
	// '/' is kept whole rather than reduced to an empty name.
	typeName := url
	if slashIdx := strings.LastIndex(url, "/"); slashIdx >= 0 && slashIdx+1 < len(url) {
		typeName = url[slashIdx+1:]
	}

	return protoreflect.FullName(typeName), decoder.SkipFields(typeURLFieldName), nil
}
// decodeAny decodes the HCL input behind decoder into msg, which must have
// the google.protobuf.Any shape (type_url and value fields).
//
// The concrete message type comes from u.AnyTypeProvider. When none is
// configured, an AnyTypeURLProvider reading the "type_url" attribute is used.
// The type is looked up in protoregistry.GlobalTypes and a new instance is
// decoded from the decoder returned by the provider. The result is then
// serialized with proto.Marshal into msg's value field.
//
// msg's type_url is set to the bare message full name (e.g. "foo.v1.Bar"),
// not to a "type.googleapis.com/..." URL.
func (u UnmarshalOptions) decodeAny(ctx *UnmarshalContext, decoder MessageDecoder, msg protoreflect.Message) error {
	// Resolve the Any fields up front. This works for both the generated
	// *anypb.Any and for dynamic messages built from the Any descriptor,
	// whereas asserting msg.Interface().(*anypb.Any) panics on the latter.
	anyFields := msg.Descriptor().Fields()
	typeURLField := anyFields.ByName("type_url")
	valueField := anyFields.ByName("value")
	if typeURLField == nil || valueField == nil {
		return fmt.Errorf("cannot decode Any into message of type %s", msg.Descriptor().FullName())
	}

	var typeProvider AnyTypeProvider = &AnyTypeURLProvider{TypeURLFieldName: "type_url"}
	if u.AnyTypeProvider != nil {
		typeProvider = u.AnyTypeProvider
	}

	// The provider hands back the decoder to use for the embedded message.
	// For AnyTypeURLProvider, that decoder comes from SkipFields on the type URL attribute.
	typeName, fieldDecoder, err := typeProvider.AnyType(ctx, decoder)
	if err != nil {
		return fmt.Errorf("error getting type for Any field: %w", err)
	}

	mt, err := protoregistry.GlobalTypes.FindMessageByName(typeName)
	if err != nil {
		return fmt.Errorf("error looking up type information for %s: %w", typeName, err)
	}

	newMsg := mt.New()
	err = u.decodeMessage(&UnmarshalContext{
		Parent:  ctx.Parent,
		Name:    ctx.Name,
		Message: newMsg,
	}, fieldDecoder, newMsg)
	if err != nil {
		return err
	}

	enc, err := proto.Marshal(newMsg.Interface())
	if err != nil {
		return fmt.Errorf("error marshalling Any data as protobuf value: %w", err)
	}

	// The type URL is the bare <proto package>.<Message> name rather than a
	// full URL with a host/path prefix.
	msg.Set(typeURLField, protoreflect.ValueOfString(string(newMsg.Descriptor().FullName())))
	msg.Set(valueField, protoreflect.ValueOfBytes(enc))
	return nil
}
// isAnyField reports whether desc is a message-kind field whose message type
// is the well-known google.protobuf.Any.
func isAnyField(desc protoreflect.FieldDescriptor) bool {
	// The && short-circuits, so Message() is consulted only for
	// message-kind fields, exactly as the kind guard did before.
	isMessage := desc.Kind() == protoreflect.MessageKind
	return isMessage && desc.Message().FullName() == wellKnownTypeAny
}