536 lines
15 KiB
Go
Raw Normal View History

// Copyright (c) 2019-2021 Uber Technologies, Inc.
//
// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
// in the Software without restriction, including without limitation the rights
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to permit persons to whom the Software is
// furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in
// all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
// THE SOFTWARE.
package dig
import (
"fmt"
"reflect"
"go.uber.org/dig/internal/digerror"
"go.uber.org/dig/internal/dot"
)
// The result interface represents a result produced by a constructor.
//
// The following implementations exist:
//
//	resultList     All values returned by the constructor.
//	resultSingle   A single value produced by a constructor.
//	resultObject   dig.Out struct where each field in the struct can be
//	               another result.
//	resultGrouped  A value produced by a constructor that is part of a value
//	               group.
type result interface {
	// Extract extracts the values for this result from the provided value
	// and stores them into the provided containerWriter.
	//
	// The bool argument is named `decorated` by every implementation; when
	// it is true, values are written through the containerWriter's
	// decorated-value methods instead of the regular ones.
	//
	// This MAY panic if the result does not consume a single value.
	Extract(containerWriter, bool, reflect.Value)

	// DotResult returns the dot.Result entries describing this result
	// (presumably for graph visualization via the internal/dot package).
	DotResult() []*dot.Result
}

// Compile-time checks that every result implementation satisfies the
// result interface.
var (
	_ result = resultSingle{}
	_ result = resultObject{}
	_ result = resultList{}
	_ result = resultGrouped{}
)
// resultOptions holds the caller-supplied options that control how the
// values returned by a constructor are turned into results.
type resultOptions struct {
	// If set, this is the name of the associated result value.
	//
	// Result objects (dig.Out structs) cannot be named as a whole:
	// newResultObject rejects a non-empty Name. Their fields use name:".."
	// tags instead.
	Name string

	// If set, the value is submitted to this value group instead of being
	// stored as a single value. The string is parsed by parseGroupString and
	// may carry flags such as flatten (allowed only on slices) or soft
	// (always rejected for results).
	Group string

	// If set, the value is made available as these types. Each element is
	// resolved with reflect.TypeOf(as).Elem(), so it is presumably a pointer
	// to an interface type (e.g. new(io.Reader)). The first resolved type
	// replaces the value's own type. Entries equal to the value's own type
	// are ignored.
	As []interface{}
}
// newResult builds a result from the given type.
func newResult(t reflect.Type, opts resultOptions) (result, error) {
switch {
case IsIn(t) || (t.Kind() == reflect.Ptr && IsIn(t.Elem())) || embedsType(t, _inPtrType):
2023-05-19 16:23:55 -04:00
return nil, newErrInvalidInput(fmt.Sprintf(
"cannot provide parameter objects: %v embeds a dig.In", t), nil)
case isError(t):
2023-05-19 16:23:55 -04:00
return nil, newErrInvalidInput("cannot return an error here, return it from the constructor instead", nil)
case IsOut(t):
return newResultObject(t, opts)
case embedsType(t, _outPtrType):
2023-05-19 16:23:55 -04:00
return nil, newErrInvalidInput(fmt.Sprintf(
"cannot build a result object by embedding *dig.Out, embed dig.Out instead: %v embeds *dig.Out", t), nil)
case t.Kind() == reflect.Ptr && IsOut(t.Elem()):
2023-05-19 16:23:55 -04:00
return nil, newErrInvalidInput(fmt.Sprintf(
"cannot return a pointer to a result object, use a value instead: %v is a pointer to a struct that embeds dig.Out", t), nil)
case len(opts.Group) > 0:
g, err := parseGroupString(opts.Group)
if err != nil {
2023-05-19 16:23:55 -04:00
return nil, newErrInvalidInput(
fmt.Sprintf("cannot parse group %q", opts.Group), err)
}
rg := resultGrouped{Type: t, Group: g.Name, Flatten: g.Flatten}
if len(opts.As) > 0 {
var asTypes []reflect.Type
for _, as := range opts.As {
ifaceType := reflect.TypeOf(as).Elem()
if ifaceType == t {
continue
}
if !t.Implements(ifaceType) {
return nil, newErrInvalidInput(
fmt.Sprintf("invalid dig.As: %v does not implement %v", t, ifaceType), nil)
}
asTypes = append(asTypes, ifaceType)
}
if len(asTypes) > 0 {
rg.Type = asTypes[0]
rg.As = asTypes[1:]
}
}
if g.Soft {
2023-05-19 16:23:55 -04:00
return nil, newErrInvalidInput(fmt.Sprintf(
"cannot use soft with result value groups: soft was used with group:%q", g.Name), nil)
}
if g.Flatten {
if t.Kind() != reflect.Slice {
2023-05-19 16:23:55 -04:00
return nil, newErrInvalidInput(fmt.Sprintf(
"flatten can be applied to slices only: %v is not a slice", t), nil)
}
rg.Type = rg.Type.Elem()
}
return rg, nil
default:
return newResultSingle(t, opts)
}
}
// resultVisitor visits every result in a result tree, allowing state to be
// tracked at each level. It is driven by walkResult.
type resultVisitor interface {
	// Visit is called on the result being visited.
	//
	// If Visit returns a non-nil resultVisitor, that resultVisitor visits all
	// the child results of this result. A nil return stops the walk for this
	// result's subtree.
	Visit(result) resultVisitor

	// AnnotateWithField is called on each field of a resultObject after
	// visiting the resultObject but before walking the field's result.
	//
	// The same resultVisitor is used for all fields: the one returned upon
	// visiting the resultObject.
	//
	// For each visited field, if AnnotateWithField returns a non-nil
	// resultVisitor, it will be used to walk the result of that field.
	AnnotateWithField(resultObjectField) resultVisitor

	// AnnotateWithPosition is called with the index of each result of a
	// resultList after visiting the resultList but before walking the result
	// at that index.
	//
	// The same resultVisitor is used for all results: the one returned upon
	// visiting the resultList.
	//
	// For each position, if AnnotateWithPosition returns a non-nil
	// resultVisitor, it will be used to walk the result at that index.
	AnnotateWithPosition(idx int) resultVisitor
}
// walkResult traverses the result tree rooted at r, driving v much like
// go/ast.Walk drives an ast.Visitor.
//
// v.Visit(r) is called first. A nil return stops the walk for r's subtree.
// Otherwise the returned visitor is used to descend:
//   - for a resultObject, AnnotateWithField is called for each field, and a
//     non-nil return walks that field's result;
//   - for a resultList, AnnotateWithPosition is called for each index, and a
//     non-nil return walks the result at that index.
//
// resultSingle and resultGrouped are leaves. Any other result type is a bug
// and panics.
func walkResult(r result, v resultVisitor) {
	child := v.Visit(r)
	if child == nil {
		return
	}

	switch node := r.(type) {
	case resultSingle, resultGrouped:
		// Leaves: nothing to descend into.
	case resultObject:
		for _, field := range node.Fields {
			fieldVisitor := child.AnnotateWithField(field)
			if fieldVisitor == nil {
				continue
			}
			walkResult(field.Result, fieldVisitor)
		}
	case resultList:
		for idx, sub := range node.Results {
			posVisitor := child.AnnotateWithPosition(idx)
			if posVisitor == nil {
				continue
			}
			walkResult(sub, posVisitor)
		}
	default:
		digerror.BugPanicf("received unknown result type %T", node)
	}
}
// resultList holds all values returned by the constructor as results.
type resultList struct {
	// ctype is the constructor's function type; its outputs produced Results.
	ctype reflect.Type

	// Results holds one result per non-error output of the constructor, in
	// output order.
	Results []result

	// For each item at index i returned by the constructor, resultIndexes[i]
	// is the index in .Results for the corresponding result object.
	// resultIndexes[i] is -1 for errors returned by constructors.
	resultIndexes []int
}
// DotResult concatenates the dot.Result entries of every result in the list,
// in order. The returned slice is nil when no entries are produced.
func (rl resultList) DotResult() []*dot.Result {
	var all []*dot.Result
	for i := range rl.Results {
		all = append(all, rl.Results[i].DotResult()...)
	}
	return all
}
func newResultList(ctype reflect.Type, opts resultOptions) (resultList, error) {
numOut := ctype.NumOut()
rl := resultList{
ctype: ctype,
Results: make([]result, 0, numOut),
resultIndexes: make([]int, numOut),
}
resultIdx := 0
for i := 0; i < numOut; i++ {
t := ctype.Out(i)
if isError(t) {
rl.resultIndexes[i] = -1
continue
}
r, err := newResult(t, opts)
if err != nil {
2023-05-19 16:23:55 -04:00
return rl, newErrInvalidInput(fmt.Sprintf("bad result %d", i+1), err)
}
rl.Results = append(rl.Results, r)
rl.resultIndexes[i] = resultIdx
resultIdx++
}
return rl, nil
}
// Extract exists only to satisfy the result interface. A resultList is
// extracted with ExtractList instead, so reaching this method is a bug and
// it panics.
func (resultList) Extract(_ containerWriter, _ bool, _ reflect.Value) {
	const msg = "resultList.Extract() must never be called"
	digerror.BugPanicf(msg)
}
// ExtractList walks values alongside rl.resultIndexes. Each non-error slot
// is extracted into cw through its corresponding result. The first non-nil
// value in an error slot is returned immediately, and values after it are
// not extracted.
//
// values must not be longer than rl.resultIndexes, or indexing panics.
// Presumably values are the outputs of calling a constructor of type
// rl.ctype.
func (rl resultList) ExtractList(cw containerWriter, decorated bool, values []reflect.Value) error {
	for pos, val := range values {
		idx := rl.resultIndexes[pos]
		if idx < 0 {
			// Error slot: stop at the first non-nil error.
			if err, _ := val.Interface().(error); err != nil {
				return err
			}
			continue
		}
		rl.Results[idx].Extract(cw, decorated, val)
	}
	return nil
}
// resultSingle is an explicit value produced by a constructor, optionally
// with a name.
//
// This object will be added to the graph as-is.
type resultSingle struct {
	// Name under which the value is stored; empty for unnamed values.
	Name string

	// Type under which the value is stored. When dig.As was used, this is
	// the first requested As type rather than the constructor's own return
	// type.
	Type reflect.Type

	// If specified, this is a list of types which the value will be made
	// available as, in addition to Type.
	As []reflect.Type
}
func newResultSingle(t reflect.Type, opts resultOptions) (resultSingle, error) {
r := resultSingle{
Type: t,
Name: opts.Name,
}
var asTypes []reflect.Type
for _, as := range opts.As {
ifaceType := reflect.TypeOf(as).Elem()
if ifaceType == t {
// Special case:
// c.Provide(func() io.Reader, As(new(io.Reader)))
// Ignore instead of erroring out.
continue
}
if !t.Implements(ifaceType) {
2023-05-19 16:23:55 -04:00
return r, newErrInvalidInput(
fmt.Sprintf("invalid dig.As: %v does not implement %v", t, ifaceType), nil)
}
asTypes = append(asTypes, ifaceType)
}
if len(asTypes) == 0 {
return r, nil
}
return resultSingle{
Type: asTypes[0],
Name: opts.Name,
As: asTypes[1:],
}, nil
}
// DotResult returns one dot.Result for the primary Type followed by one for
// each type in As. Every entry carries the result's Name.
func (rs resultSingle) DotResult() []*dot.Result {
	types := append([]reflect.Type{rs.Type}, rs.As...)
	out := make([]*dot.Result, len(types))
	for i, typ := range types {
		out[i] = &dot.Result{Node: &dot.Node{Type: typ, Name: rs.Name}}
	}
	return out
}
// Extract stores v in cw under the result's Name and Type.
//
// Non-decorated values are also registered under every type in As.
// Decorated values only update the primary Type, via setDecoratedValue;
// the As types are left untouched.
func (rs resultSingle) Extract(cw containerWriter, decorated bool, v reflect.Value) {
	if !decorated {
		cw.setValue(rs.Name, rs.Type, v)
		for _, alias := range rs.As {
			cw.setValue(rs.Name, alias, v)
		}
		return
	}
	cw.setDecoratedValue(rs.Name, rs.Type, v)
}
// resultObject is a dig.Out struct where each field is another result.
//
// This object is not added to the graph. Its fields are interpreted as
// results and added to the graph if needed.
type resultObject struct {
	// Type is the dig.Out struct type itself.
	Type reflect.Type

	// Fields holds one entry per struct field other than the embedded
	// dig.Out, in declaration order.
	Fields []resultObjectField
}
// DotResult concatenates the dot.Result entries of every field, in field
// order. The returned slice is nil when no entries are produced.
func (ro resultObject) DotResult() []*dot.Result {
	var collected []*dot.Result
	for i := range ro.Fields {
		collected = append(collected, ro.Fields[i].DotResult()...)
	}
	return collected
}
func newResultObject(t reflect.Type, opts resultOptions) (resultObject, error) {
ro := resultObject{Type: t}
if len(opts.Name) > 0 {
2023-05-19 16:23:55 -04:00
return ro, newErrInvalidInput(fmt.Sprintf(
"cannot specify a name for result objects: %v embeds dig.Out", t), nil)
}
if len(opts.Group) > 0 {
2023-05-19 16:23:55 -04:00
return ro, newErrInvalidInput(fmt.Sprintf(
"cannot specify a group for result objects: %v embeds dig.Out", t), nil)
}
for i := 0; i < t.NumField(); i++ {
f := t.Field(i)
if f.Type == _outType {
// Skip over the dig.Out embed.
continue
}
rof, err := newResultObjectField(i, f, opts)
if err != nil {
2023-05-19 16:23:55 -04:00
return ro, newErrInvalidInput(fmt.Sprintf("bad field %q of %v", f.Name, t), err)
}
ro.Fields = append(ro.Fields, rof)
}
return ro, nil
}
// Extract reads each field listed in ro.Fields out of the struct value v
// and passes it to that field's result for extraction into cw.
func (ro resultObject) Extract(cw containerWriter, decorated bool, v reflect.Value) {
	for i := range ro.Fields {
		field := ro.Fields[i]
		fieldValue := v.Field(field.FieldIndex)
		field.Result.Extract(cw, decorated, fieldValue)
	}
}
// resultObjectField is a single field inside a dig.Out struct.
type resultObjectField struct {
	// Name of the field in the struct.
	FieldName string

	// Index of the field in the struct, as passed to reflect.Value.Field.
	//
	// We need to track this separately because not all fields of the struct
	// map to results (the embedded dig.Out is skipped).
	FieldIndex int

	// Result produced by this field.
	Result result
}
// DotResult delegates to the result produced by this field.
func (rof resultObjectField) DotResult() []*dot.Result {
	inner := rof.Result
	return inner.DotResult()
}
// newResultObjectField(i, f, opts) builds a resultObjectField from the field
// f at index i.
func newResultObjectField(idx int, f reflect.StructField, opts resultOptions) (resultObjectField, error) {
rof := resultObjectField{
FieldName: f.Name,
FieldIndex: idx,
}
var r result
switch {
case f.PkgPath != "":
2023-05-19 16:23:55 -04:00
return rof, newErrInvalidInput(
fmt.Sprintf("unexported fields not allowed in dig.Out, did you mean to export %q (%v)?", f.Name, f.Type), nil)
case f.Tag.Get(_groupTag) != "":
var err error
r, err = newResultGrouped(f)
if err != nil {
return rof, err
}
default:
var err error
if name := f.Tag.Get(_nameTag); len(name) > 0 {
// can modify in-place because options are passed-by-value.
opts.Name = name
}
r, err = newResult(f.Type, opts)
if err != nil {
return rof, err
}
}
rof.Result = r
return rof, nil
}
// resultGrouped is a value produced by a constructor that is part of a result
// group.
//
// These will be produced as fields of a dig.Out struct, or by a constructor
// provided with a group option.
type resultGrouped struct {
	// Name of the group as specified in the `group:".."` tag.
	Group string

	// Type of value produced.
	Type reflect.Type

	// Flatten indicates that the elements of the produced slice are submitted
	// to the group individually, instead of the slice being submitted as a
	// single value. It requires the produced value to be a slice. When set,
	// Type is the slice's element type rather than the slice type.
	Flatten bool

	// If specified, this is a list of types which the value will be made
	// available as, in addition to Type.
	As []reflect.Type
}
// DotResult returns one dot.Result for Type followed by one for each type in
// As. Every entry is tagged with the result's Group.
func (rt resultGrouped) DotResult() []*dot.Result {
	out := make([]*dot.Result, 0, 1+len(rt.As))
	for _, typ := range append([]reflect.Type{rt.Type}, rt.As...) {
		out = append(out, &dot.Result{Node: &dot.Node{Type: typ, Group: rt.Group}})
	}
	return out
}
// newResultGrouped(f) builds a new resultGrouped from the provided field.
func newResultGrouped(f reflect.StructField) (resultGrouped, error) {
g, err := parseGroupString(f.Tag.Get(_groupTag))
if err != nil {
return resultGrouped{}, err
}
rg := resultGrouped{
Group: g.Name,
Flatten: g.Flatten,
Type: f.Type,
}
name := f.Tag.Get(_nameTag)
optional, _ := isFieldOptional(f)
switch {
case g.Flatten && f.Type.Kind() != reflect.Slice:
2023-05-19 16:23:55 -04:00
return rg, newErrInvalidInput(fmt.Sprintf(
"flatten can be applied to slices only: field %q (%v) is not a slice", f.Name, f.Type), nil)
case g.Soft:
2023-05-19 16:23:55 -04:00
return rg, newErrInvalidInput(fmt.Sprintf(
"cannot use soft with result value groups: soft was used with group %q", rg.Group), nil)
case name != "":
2023-05-19 16:23:55 -04:00
return rg, newErrInvalidInput(fmt.Sprintf(
"cannot use named values with value groups: name:%q provided with group:%q", name, rg.Group), nil)
case optional:
2023-05-19 16:23:55 -04:00
return rg, newErrInvalidInput("value groups cannot be optional", nil)
}
if g.Flatten {
rg.Type = f.Type.Elem()
}
return rg, nil
}
// Extract submits v to the result's group.
//
//   - Decorated values are always flattened, so the whole value is passed
//     to submitDecoratedGroupedValue and neither Flatten nor As is consulted.
//   - Flattened values are slices whose elements are submitted one by one
//     under Type.
//   - Otherwise v is submitted under Type and under every type in As.
func (rt resultGrouped) Extract(cw containerWriter, decorated bool, v reflect.Value) {
	switch {
	case decorated:
		cw.submitDecoratedGroupedValue(rt.Group, rt.Type, v)
	case rt.Flatten:
		for idx := 0; idx < v.Len(); idx++ {
			cw.submitGroupedValue(rt.Group, rt.Type, v.Index(idx))
		}
	default:
		cw.submitGroupedValue(rt.Group, rt.Type, v)
		for _, alias := range rt.As {
			cw.submitGroupedValue(rt.Group, alias, v)
		}
	}
}