// Copyright (c) 2019-2021 Uber Technologies, Inc. // // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal // in the Software without restriction, including without limitation the rights // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell // copies of the Software, and to permit persons to whom the Software is // furnished to do so, subject to the following conditions: // // The above copyright notice and this permission notice shall be included in // all copies or substantial portions of the Software. // // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN // THE SOFTWARE. package dig import ( "fmt" "reflect" "go.uber.org/dig/internal/digerror" "go.uber.org/dig/internal/dot" ) // The result interface represents a result produced by a constructor. // // The following implementations exist: // resultList All values returned by the constructor. // resultSingle A single value produced by a constructor. // resultObject dig.Out struct where each field in the struct can be // another result. // resultGrouped A value produced by a constructor that is part of a value // group. type result interface { // Extracts the values for this result from the provided value and // stores them into the provided containerWriter. // // This MAY panic if the result does not consume a single value. Extract(containerWriter, bool, reflect.Value) // DotResult returns a slice of dot.Result(s). DotResult() []*dot.Result } var ( _ result = resultSingle{} _ result = resultObject{} _ result = resultList{} _ result = resultGrouped{} ) type resultOptions struct { // If set, this is the name of the associated result value. // // For Result Objects, name:".." tags on fields override this. Name string Group string As []interface{} } // newResult builds a result from the given type. func newResult(t reflect.Type, opts resultOptions) (result, error) { switch { case IsIn(t) || (t.Kind() == reflect.Ptr && IsIn(t.Elem())) || embedsType(t, _inPtrType): return nil, newErrInvalidInput(fmt.Sprintf( "cannot provide parameter objects: %v embeds a dig.In", t), nil) case isError(t): return nil, newErrInvalidInput("cannot return an error here, return it from the constructor instead", nil) case IsOut(t): return newResultObject(t, opts) case embedsType(t, _outPtrType): return nil, newErrInvalidInput(fmt.Sprintf( "cannot build a result object by embedding *dig.Out, embed dig.Out instead: %v embeds *dig.Out", t), nil) case t.Kind() == reflect.Ptr && IsOut(t.Elem()): return nil, newErrInvalidInput(fmt.Sprintf( "cannot return a pointer to a result object, use a value instead: %v is a pointer to a struct that embeds dig.Out", t), nil) case len(opts.Group) > 0: g, err := parseGroupString(opts.Group) if err != nil { return nil, newErrInvalidInput( fmt.Sprintf("cannot parse group %q", opts.Group), err) } rg := resultGrouped{Type: t, Group: g.Name, Flatten: g.Flatten} if len(opts.As) > 0 { var asTypes []reflect.Type for _, as := range opts.As { ifaceType := reflect.TypeOf(as).Elem() if ifaceType == t { continue } if !t.Implements(ifaceType) { return nil, newErrInvalidInput( fmt.Sprintf("invalid dig.As: %v does not implement %v", t, ifaceType), nil) } asTypes = append(asTypes, ifaceType) } if len(asTypes) > 0 { rg.Type = asTypes[0] rg.As = asTypes[1:] } } if g.Soft { return nil, newErrInvalidInput(fmt.Sprintf( "cannot use soft with result value groups: soft was used with group:%q", g.Name), nil) } if g.Flatten { if t.Kind() != reflect.Slice { return nil, newErrInvalidInput(fmt.Sprintf( "flatten can be applied to slices only: %v is not a slice", t), nil) } rg.Type = rg.Type.Elem() } return rg, nil default: return newResultSingle(t, opts) } } // resultVisitor visits every result in a result tree, allowing tracking state // at each level. type resultVisitor interface { // Visit is called on the result being visited. // // If Visit returns a non-nil resultVisitor, that resultVisitor visits all // the child results of this result. Visit(result) resultVisitor // AnnotateWithField is called on each field of a resultObject after // visiting it but before walking its descendants. // // The same resultVisitor is used for all fields: the one returned upon // visiting the resultObject. // // For each visited field, if AnnotateWithField returns a non-nil // resultVisitor, it will be used to walk the result of that field. AnnotateWithField(resultObjectField) resultVisitor // AnnotateWithPosition is called with the index of each result of a // resultList after vising it but before walking its descendants. // // The same resultVisitor is used for all results: the one returned upon // visiting the resultList. // // For each position, if AnnotateWithPosition returns a non-nil // resultVisitor, it will be used to walk the result at that index. AnnotateWithPosition(idx int) resultVisitor } // walkResult walks the result tree for the given result with the provided // visitor. // // resultVisitor.Visit will be called on the provided result and if a non-nil // resultVisitor is received, it will be used to walk its descendants. If a // resultObject or resultList was visited, AnnotateWithField and // AnnotateWithPosition respectively will be called before visiting the // descendants of that resultObject/resultList. // // This is very similar to how go/ast.Walk works. func walkResult(r result, v resultVisitor) { v = v.Visit(r) if v == nil { return } switch res := r.(type) { case resultSingle, resultGrouped: // No sub-results case resultObject: w := v for _, f := range res.Fields { if v := w.AnnotateWithField(f); v != nil { walkResult(f.Result, v) } } case resultList: w := v for i, r := range res.Results { if v := w.AnnotateWithPosition(i); v != nil { walkResult(r, v) } } default: digerror.BugPanicf("received unknown result type %T", res) } } // resultList holds all values returned by the constructor as results. type resultList struct { ctype reflect.Type Results []result // For each item at index i returned by the constructor, resultIndexes[i] // is the index in .Results for the corresponding result object. // resultIndexes[i] is -1 for errors returned by constructors. resultIndexes []int } func (rl resultList) DotResult() []*dot.Result { var types []*dot.Result for _, result := range rl.Results { types = append(types, result.DotResult()...) } return types } func newResultList(ctype reflect.Type, opts resultOptions) (resultList, error) { numOut := ctype.NumOut() rl := resultList{ ctype: ctype, Results: make([]result, 0, numOut), resultIndexes: make([]int, numOut), } resultIdx := 0 for i := 0; i < numOut; i++ { t := ctype.Out(i) if isError(t) { rl.resultIndexes[i] = -1 continue } r, err := newResult(t, opts) if err != nil { return rl, newErrInvalidInput(fmt.Sprintf("bad result %d", i+1), err) } rl.Results = append(rl.Results, r) rl.resultIndexes[i] = resultIdx resultIdx++ } return rl, nil } func (resultList) Extract(containerWriter, bool, reflect.Value) { digerror.BugPanicf("resultList.Extract() must never be called") } func (rl resultList) ExtractList(cw containerWriter, decorated bool, values []reflect.Value) error { for i, v := range values { if resultIdx := rl.resultIndexes[i]; resultIdx >= 0 { rl.Results[resultIdx].Extract(cw, decorated, v) continue } if err, _ := v.Interface().(error); err != nil { return err } } return nil } // resultSingle is an explicit value produced by a constructor, optionally // with a name. // // This object will be added to the graph as-is. type resultSingle struct { Name string Type reflect.Type // If specified, this is a list of types which the value will be made // available as, in addition to its own type. As []reflect.Type } func newResultSingle(t reflect.Type, opts resultOptions) (resultSingle, error) { r := resultSingle{ Type: t, Name: opts.Name, } var asTypes []reflect.Type for _, as := range opts.As { ifaceType := reflect.TypeOf(as).Elem() if ifaceType == t { // Special case: // c.Provide(func() io.Reader, As(new(io.Reader))) // Ignore instead of erroring out. continue } if !t.Implements(ifaceType) { return r, newErrInvalidInput( fmt.Sprintf("invalid dig.As: %v does not implement %v", t, ifaceType), nil) } asTypes = append(asTypes, ifaceType) } if len(asTypes) == 0 { return r, nil } return resultSingle{ Type: asTypes[0], Name: opts.Name, As: asTypes[1:], }, nil } func (rs resultSingle) DotResult() []*dot.Result { dotResults := make([]*dot.Result, 0, len(rs.As)+1) dotResults = append(dotResults, &dot.Result{ Node: &dot.Node{ Type: rs.Type, Name: rs.Name, }, }) for _, asType := range rs.As { dotResults = append(dotResults, &dot.Result{ Node: &dot.Node{Type: asType, Name: rs.Name}, }) } return dotResults } func (rs resultSingle) Extract(cw containerWriter, decorated bool, v reflect.Value) { if decorated { cw.setDecoratedValue(rs.Name, rs.Type, v) return } cw.setValue(rs.Name, rs.Type, v) for _, asType := range rs.As { cw.setValue(rs.Name, asType, v) } } // resultObject is a dig.Out struct where each field is another result. // // This object is not added to the graph. Its fields are interpreted as // results and added to the graph if needed. type resultObject struct { Type reflect.Type Fields []resultObjectField } func (ro resultObject) DotResult() []*dot.Result { var types []*dot.Result for _, field := range ro.Fields { types = append(types, field.DotResult()...) } return types } func newResultObject(t reflect.Type, opts resultOptions) (resultObject, error) { ro := resultObject{Type: t} if len(opts.Name) > 0 { return ro, newErrInvalidInput(fmt.Sprintf( "cannot specify a name for result objects: %v embeds dig.Out", t), nil) } if len(opts.Group) > 0 { return ro, newErrInvalidInput(fmt.Sprintf( "cannot specify a group for result objects: %v embeds dig.Out", t), nil) } for i := 0; i < t.NumField(); i++ { f := t.Field(i) if f.Type == _outType { // Skip over the dig.Out embed. continue } rof, err := newResultObjectField(i, f, opts) if err != nil { return ro, newErrInvalidInput(fmt.Sprintf("bad field %q of %v", f.Name, t), err) } ro.Fields = append(ro.Fields, rof) } return ro, nil } func (ro resultObject) Extract(cw containerWriter, decorated bool, v reflect.Value) { for _, f := range ro.Fields { f.Result.Extract(cw, decorated, v.Field(f.FieldIndex)) } } // resultObjectField is a single field inside a dig.Out struct. type resultObjectField struct { // Name of the field in the struct. FieldName string // Index of the field in the struct. // // We need to track this separately because not all fields of the struct // map to results. FieldIndex int // Result produced by this field. Result result } func (rof resultObjectField) DotResult() []*dot.Result { return rof.Result.DotResult() } // newResultObjectField(i, f, opts) builds a resultObjectField from the field // f at index i. func newResultObjectField(idx int, f reflect.StructField, opts resultOptions) (resultObjectField, error) { rof := resultObjectField{ FieldName: f.Name, FieldIndex: idx, } var r result switch { case f.PkgPath != "": return rof, newErrInvalidInput( fmt.Sprintf("unexported fields not allowed in dig.Out, did you mean to export %q (%v)?", f.Name, f.Type), nil) case f.Tag.Get(_groupTag) != "": var err error r, err = newResultGrouped(f) if err != nil { return rof, err } default: var err error if name := f.Tag.Get(_nameTag); len(name) > 0 { // can modify in-place because options are passed-by-value. opts.Name = name } r, err = newResult(f.Type, opts) if err != nil { return rof, err } } rof.Result = r return rof, nil } // resultGrouped is a value produced by a constructor that is part of a result // group. // // These will be produced as fields of a dig.Out struct. type resultGrouped struct { // Name of the group as specified in the `group:".."` tag. Group string // Type of value produced. Type reflect.Type // Indicates elements of a value are to be injected individually, instead of // as a group. Requires the value's slice to be a group. If set, Type will be // the type of individual elements rather than the group. Flatten bool // If specified, this is a list of types which the value will be made // available as, in addition to its own type. As []reflect.Type } func (rt resultGrouped) DotResult() []*dot.Result { dotResults := make([]*dot.Result, 0, len(rt.As)+1) dotResults = append(dotResults, &dot.Result{ Node: &dot.Node{ Type: rt.Type, Group: rt.Group, }, }) for _, asType := range rt.As { dotResults = append(dotResults, &dot.Result{ Node: &dot.Node{Type: asType, Group: rt.Group}, }) } return dotResults } // newResultGrouped(f) builds a new resultGrouped from the provided field. func newResultGrouped(f reflect.StructField) (resultGrouped, error) { g, err := parseGroupString(f.Tag.Get(_groupTag)) if err != nil { return resultGrouped{}, err } rg := resultGrouped{ Group: g.Name, Flatten: g.Flatten, Type: f.Type, } name := f.Tag.Get(_nameTag) optional, _ := isFieldOptional(f) switch { case g.Flatten && f.Type.Kind() != reflect.Slice: return rg, newErrInvalidInput(fmt.Sprintf( "flatten can be applied to slices only: field %q (%v) is not a slice", f.Name, f.Type), nil) case g.Soft: return rg, newErrInvalidInput(fmt.Sprintf( "cannot use soft with result value groups: soft was used with group %q", rg.Group), nil) case name != "": return rg, newErrInvalidInput(fmt.Sprintf( "cannot use named values with value groups: name:%q provided with group:%q", name, rg.Group), nil) case optional: return rg, newErrInvalidInput("value groups cannot be optional", nil) } if g.Flatten { rg.Type = f.Type.Elem() } return rg, nil } func (rt resultGrouped) Extract(cw containerWriter, decorated bool, v reflect.Value) { // Decorated values are always flattened. if !decorated && !rt.Flatten { cw.submitGroupedValue(rt.Group, rt.Type, v) for _, asType := range rt.As { cw.submitGroupedValue(rt.Group, asType, v) } return } if decorated { cw.submitDecoratedGroupedValue(rt.Group, rt.Type, v) return } for i := 0; i < v.Len(); i++ { cw.submitGroupedValue(rt.Group, rt.Type, v.Index(i)) } }