2023-08-22 17:23:54 -05:00
|
|
|
// Copyright (c) HashiCorp, Inc.
|
|
|
|
// SPDX-License-Identifier: BUSL-1.1
|
|
|
|
|
2023-08-11 15:52:51 -04:00
|
|
|
package protohcl
|
|
|
|
|
|
|
|
import (
|
|
|
|
"fmt"
|
|
|
|
"strings"
|
|
|
|
|
|
|
|
"google.golang.org/protobuf/proto"
|
|
|
|
"google.golang.org/protobuf/reflect/protoreflect"
|
|
|
|
"google.golang.org/protobuf/reflect/protoregistry"
|
|
|
|
"google.golang.org/protobuf/types/known/anypb"
|
|
|
|
)
|
|
|
|
|
|
|
|
// wellKnownTypeAny is the protobuf full name of the google.protobuf.Any
// well-known type; isAnyField compares field message types against it.
const wellKnownTypeAny = "google.protobuf.Any"
|
|
|
|
|
|
|
|
type AnyTypeProvider interface {
|
|
|
|
AnyType(*UnmarshalContext, MessageDecoder) (protoreflect.FullName, MessageDecoder, error)
|
|
|
|
}
|
|
|
|
|
|
|
|
type AnyTypeURLProvider struct {
|
|
|
|
TypeURLFieldName string
|
|
|
|
}
|
|
|
|
|
|
|
|
func (p *AnyTypeURLProvider) AnyType(ctx *UnmarshalContext, decoder MessageDecoder) (protoreflect.FullName, MessageDecoder, error) {
|
|
|
|
typeURLFieldName := "type_url"
|
|
|
|
if p != nil {
|
|
|
|
typeURLFieldName = p.TypeURLFieldName
|
|
|
|
}
|
|
|
|
|
|
|
|
var typeURL *IterField
|
|
|
|
err := decoder.EachField(FieldIterator{
|
|
|
|
Desc: (&anypb.Any{}).ProtoReflect().Descriptor(),
|
|
|
|
Func: func(field *IterField) error {
|
|
|
|
if field.Name == typeURLFieldName {
|
|
|
|
typeURL = field
|
|
|
|
}
|
|
|
|
return nil
|
|
|
|
},
|
|
|
|
IgnoreUnknown: true,
|
|
|
|
})
|
|
|
|
if err != nil {
|
|
|
|
return "", nil, err
|
|
|
|
}
|
|
|
|
|
|
|
|
if typeURL == nil || typeURL.Val == nil {
|
|
|
|
return "", nil, fmt.Errorf("%s field is required to decode Any", typeURLFieldName)
|
|
|
|
}
|
|
|
|
|
|
|
|
url, err := stringFromCty(*typeURL.Val)
|
|
|
|
if err != nil {
|
|
|
|
return "", nil, err
|
|
|
|
}
|
|
|
|
|
|
|
|
slashIdx := strings.LastIndex(url, "/")
|
|
|
|
typeName := url
|
|
|
|
// strip all "hostname" parts of the URL path
|
|
|
|
if slashIdx > 1 && slashIdx+1 < len(url) {
|
|
|
|
typeName = url[slashIdx+1:]
|
|
|
|
}
|
|
|
|
|
|
|
|
return protoreflect.FullName(typeName), decoder.SkipFields(typeURLFieldName), nil
|
|
|
|
}
|
|
|
|
|
|
|
|
func (u UnmarshalOptions) decodeAny(ctx *UnmarshalContext, decoder MessageDecoder, msg protoreflect.Message) error {
|
|
|
|
var typeProvider AnyTypeProvider = &AnyTypeURLProvider{TypeURLFieldName: "type_url"}
|
|
|
|
if u.AnyTypeProvider != nil {
|
|
|
|
typeProvider = u.AnyTypeProvider
|
|
|
|
}
|
|
|
|
|
|
|
|
var (
|
|
|
|
typeName protoreflect.FullName
|
|
|
|
err error
|
|
|
|
)
|
|
|
|
typeName, decoder, err = typeProvider.AnyType(ctx, decoder)
|
|
|
|
if err != nil {
|
|
|
|
return fmt.Errorf("error getting type for Any field: %w", err)
|
|
|
|
}
|
|
|
|
|
|
|
|
// the type.googleapis.come/ should be optional
|
|
|
|
mt, err := protoregistry.GlobalTypes.FindMessageByName(typeName)
|
|
|
|
if err != nil {
|
|
|
|
return fmt.Errorf("error looking up type information for %s: %w", typeName, err)
|
|
|
|
}
|
|
|
|
|
|
|
|
newMsg := mt.New()
|
|
|
|
|
|
|
|
err = u.decodeMessage(&UnmarshalContext{
|
|
|
|
Parent: ctx.Parent,
|
|
|
|
Name: ctx.Name,
|
|
|
|
Message: newMsg,
|
|
|
|
}, decoder, newMsg)
|
|
|
|
if err != nil {
|
|
|
|
return err
|
|
|
|
}
|
|
|
|
|
|
|
|
enc, err := proto.Marshal(newMsg.Interface())
|
|
|
|
if err != nil {
|
|
|
|
return fmt.Errorf("error marshalling Any data as protobuf value: %w", err)
|
|
|
|
}
|
|
|
|
|
|
|
|
anyValue := msg.Interface().(*anypb.Any)
|
|
|
|
|
|
|
|
// This will look like <proto package>.<proto Message name> and not quite like a full URL with a path
|
|
|
|
anyValue.TypeUrl = string(newMsg.Descriptor().FullName())
|
|
|
|
anyValue.Value = enc
|
|
|
|
|
|
|
|
return nil
|
|
|
|
}
|
|
|
|
|
|
|
|
func isAnyField(desc protoreflect.FieldDescriptor) bool {
|
|
|
|
if desc.Kind() != protoreflect.MessageKind {
|
|
|
|
return false
|
|
|
|
}
|
|
|
|
return desc.Message().FullName() == wellKnownTypeAny
|
|
|
|
}
|