2023-06-30 14:23:08 -04:00

666 lines
19 KiB
Go

// Copyright (c) 2021 Uber Technologies, Inc.
//
// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
// in the Software without restriction, including without limitation the rights
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to permit persons to whom the Software is
// furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in
// all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
// THE SOFTWARE.
package dig
import (
"bytes"
"fmt"
"reflect"
"strings"
"go.uber.org/dig/internal/digreflect"
"go.uber.org/dig/internal/dot"
"go.uber.org/dig/internal/graph"
)
// A ProvideOption modifies the default behavior of Provide.
type ProvideOption interface {
applyProvideOption(*provideOptions)
}
type provideOptions struct {
Name string
Group string
Info *ProvideInfo
As []interface{}
Location *digreflect.Func
Exported bool
Callback Callback
}
func (o *provideOptions) Validate() error {
if len(o.Group) > 0 {
if len(o.Name) > 0 {
return newErrInvalidInput(
fmt.Sprintf("cannot use named values with value groups: name:%q provided with group:%q", o.Name, o.Group), nil)
}
}
// Names must be representable inside a backquoted string. The only
// limitation for raw string literals as per
// https://golang.org/ref/spec#raw_string_lit is that they cannot contain
// backquotes.
if strings.ContainsRune(o.Name, '`') {
return newErrInvalidInput(
fmt.Sprintf("invalid dig.Name(%q): names cannot contain backquotes", o.Name), nil)
}
if strings.ContainsRune(o.Group, '`') {
return newErrInvalidInput(
fmt.Sprintf("invalid dig.Group(%q): group names cannot contain backquotes", o.Group), nil)
}
for _, i := range o.As {
t := reflect.TypeOf(i)
if t == nil {
return newErrInvalidInput("invalid dig.As(nil): argument must be a pointer to an interface", nil)
}
if t.Kind() != reflect.Ptr {
return newErrInvalidInput(
fmt.Sprintf("invalid dig.As(%v): argument must be a pointer to an interface", t), nil)
}
pointingTo := t.Elem()
if pointingTo.Kind() != reflect.Interface {
return newErrInvalidInput(
fmt.Sprintf("invalid dig.As(*%v): argument must be a pointer to an interface", pointingTo), nil)
}
}
return nil
}
// Name is a ProvideOption that specifies that all values produced by a
// constructor should have the given name. See also the package documentation
// about Named Values.
//
// Given,
//
// func NewReadOnlyConnection(...) (*Connection, error)
// func NewReadWriteConnection(...) (*Connection, error)
//
// The following will provide two connections to the container: one under the
// name "ro" and the other under the name "rw".
//
// c.Provide(NewReadOnlyConnection, dig.Name("ro"))
// c.Provide(NewReadWriteConnection, dig.Name("rw"))
//
// This option cannot be provided for constructors which produce result
// objects.
func Name(name string) ProvideOption {
return provideNameOption(name)
}
type provideNameOption string
func (o provideNameOption) String() string {
return fmt.Sprintf("Name(%q)", string(o))
}
func (o provideNameOption) applyProvideOption(opt *provideOptions) {
opt.Name = string(o)
}
// Group is a ProvideOption that specifies that all values produced by a
// constructor should be added to the specified group. See also the package
// documentation about Value Groups.
//
// This option cannot be provided for constructors which produce result
// objects.
func Group(group string) ProvideOption {
return provideGroupOption(group)
}
type provideGroupOption string
func (o provideGroupOption) String() string {
return fmt.Sprintf("Group(%q)", string(o))
}
func (o provideGroupOption) applyProvideOption(opt *provideOptions) {
opt.Group = string(o)
}
// ID is a unique integer representing the constructor node in the dependency graph.
type ID int
// ProvideInfo provides information about the constructor's inputs and outputs
// types as strings, as well as the ID of the constructor supplied to the Container.
// It contains ID for the constructor, as well as slices of Input and Output types,
// which are Stringers that report the types of the parameters and results respectively.
type ProvideInfo struct {
ID ID
Inputs []*Input
Outputs []*Output
}
// Input contains information on an input parameter of a function.
type Input struct {
t reflect.Type
optional bool
name, group string
}
func (i *Input) String() string {
toks := make([]string, 0, 3)
t := i.t.String()
if i.optional {
toks = append(toks, "optional")
}
if i.name != "" {
toks = append(toks, fmt.Sprintf("name = %q", i.name))
}
if i.group != "" {
toks = append(toks, fmt.Sprintf("group = %q", i.group))
}
if len(toks) == 0 {
return t
}
return fmt.Sprintf("%v[%v]", t, strings.Join(toks, ", "))
}
// Output contains information on an output produced by a function.
type Output struct {
t reflect.Type
name, group string
}
func (o *Output) String() string {
toks := make([]string, 0, 2)
t := o.t.String()
if o.name != "" {
toks = append(toks, fmt.Sprintf("name = %q", o.name))
}
if o.group != "" {
toks = append(toks, fmt.Sprintf("group = %q", o.group))
}
if len(toks) == 0 {
return t
}
return fmt.Sprintf("%v[%v]", t, strings.Join(toks, ", "))
}
// FillProvideInfo is a ProvideOption that writes info on what Dig was able to get
// out of the provided constructor into the provided ProvideInfo.
func FillProvideInfo(info *ProvideInfo) ProvideOption {
return fillProvideInfoOption{info: info}
}
type fillProvideInfoOption struct{ info *ProvideInfo }
func (o fillProvideInfoOption) String() string {
return fmt.Sprintf("FillProvideInfo(%p)", o.info)
}
func (o fillProvideInfoOption) applyProvideOption(opts *provideOptions) {
opts.Info = o.info
}
// As is a ProvideOption that specifies that the value produced by the
// constructor implements one or more other interfaces and is provided
// to the container as those interfaces.
//
// As expects one or more pointers to the implemented interfaces. Values
// produced by constructors will be then available in the container as
// implementations of all of those interfaces, but not as the value itself.
//
// For example, the following will make io.Reader and io.Writer available
// in the container, but not buffer.
//
// c.Provide(newBuffer, dig.As(new(io.Reader), new(io.Writer)))
//
// That is, the above is equivalent to the following.
//
// c.Provide(func(...) (io.Reader, io.Writer) {
// b := newBuffer(...)
// return b, b
// })
//
// If used with dig.Name, the type produced by the constructor and the types
// specified with dig.As will all use the same name. For example,
//
// c.Provide(newFile, dig.As(new(io.Reader)), dig.Name("temp"))
//
// The above is equivalent to the following.
//
// type Result struct {
// dig.Out
//
// Reader io.Reader `name:"temp"`
// }
//
// c.Provide(func(...) Result {
// f := newFile(...)
// return Result{
// Reader: f,
// }
// })
//
// This option cannot be provided for constructors which produce result
// objects.
func As(i ...interface{}) ProvideOption {
return provideAsOption(i)
}
type provideAsOption []interface{}
func (o provideAsOption) String() string {
buf := bytes.NewBufferString("As(")
for i, iface := range o {
if i > 0 {
buf.WriteString(", ")
}
buf.WriteString(reflect.TypeOf(iface).Elem().String())
}
buf.WriteString(")")
return buf.String()
}
func (o provideAsOption) applyProvideOption(opts *provideOptions) {
opts.As = append(opts.As, o...)
}
// LocationForPC is a ProvideOption which specifies an alternate function program
// counter address to be used for debug information. The package, name, file and
// line number of this alternate function address will be used in error messages
// and DOT graphs. This option is intended to be used with functions created
// with the reflect.MakeFunc method whose error messages are otherwise hard to
// understand
func LocationForPC(pc uintptr) ProvideOption {
return provideLocationOption{
loc: digreflect.InspectFuncPC(pc),
}
}
type provideLocationOption struct{ loc *digreflect.Func }
func (o provideLocationOption) String() string {
return fmt.Sprintf("LocationForPC(%v)", o.loc)
}
func (o provideLocationOption) applyProvideOption(opts *provideOptions) {
opts.Location = o.loc
}
// Export is a ProvideOption which specifies that the provided function should
// be made available to all Scopes available in the application, regardless
// of which Scope it was provided from. By default, it is false.
//
// For example,
//
// c := New()
// s1 := c.Scope("child 1")
// s2:= c.Scope("child 2")
// s1.Provide(func() *bytes.Buffer { ... })
//
// does not allow the constructor returning *bytes.Buffer to be made available to
// the root Container c or its sibling Scope s2.
//
// With Export, you can make this constructor available to all the Scopes:
//
// s1.Provide(func() *bytes.Buffer { ... }, Export(true))
func Export(export bool) ProvideOption {
return provideExportOption{exported: export}
}
type provideExportOption struct{ exported bool }
func (o provideExportOption) String() string {
return fmt.Sprintf("Export(%v)", o.exported)
}
func (o provideExportOption) applyProvideOption(opts *provideOptions) {
opts.Exported = o.exported
}
// provider encapsulates a user-provided constructor.
type provider interface {
// ID is a unique numerical identifier for this provider.
ID() dot.CtorID
// Order reports the order of this provider in the graphHolder.
// This value is usually returned by the graphHolder.NewNode method.
Order(*Scope) int
// Location returns where this constructor was defined.
Location() *digreflect.Func
// ParamList returns information about the direct dependencies of this
// constructor.
ParamList() paramList
// ResultList returns information about the values produced by this
// constructor.
ResultList() resultList
// Calls the underlying constructor, reading values from the
// containerStore as needed.
//
// The values produced by this provider should be submitted into the
// containerStore.
Call(containerStore) error
CType() reflect.Type
OrigScope() *Scope
}
// Provide teaches the container how to build values of one or more types and
// expresses their dependencies.
//
// The first argument of Provide is a function that accepts zero or more
// parameters and returns one or more results. The function may optionally
// return an error to indicate that it failed to build the value. This
// function will be treated as the constructor for all the types it returns.
// This function will be called AT MOST ONCE when a type produced by it, or a
// type that consumes this function's output, is requested via Invoke. If the
// same types are requested multiple times, the previously produced value will
// be reused.
//
// Provide accepts argument types or dig.In structs as dependencies, and
// separate return values or dig.Out structs for results.
func (c *Container) Provide(constructor interface{}, opts ...ProvideOption) error {
return c.scope.Provide(constructor, opts...)
}
// Provide teaches the Scope how to build values of one or more types and
// expresses their dependencies.
//
// The first argument of Provide is a function that accepts zero or more
// parameters and returns one or more results. The function may optionally
// return an error to indicate that it failed to build the value. This
// function will be treated as the constructor for all the types it returns.
// This function will be called AT MOST ONCE when a type produced by it, or a
// type that consumes this function's output, is requested via Invoke. If the
// same types are requested multiple times, the previously produced value will
// be reused.
//
// Provide accepts argument types or dig.In structs as dependencies, and
// separate return values or dig.Out structs for results.
//
// When a constructor is Provided to a Scope, it will propagate this to any
// Scopes that are descendents, but not ancestors of this Scope.
// To provide a constructor to all the Scopes available, provide it to
// Container, which is the root Scope.
func (s *Scope) Provide(constructor interface{}, opts ...ProvideOption) error {
ctype := reflect.TypeOf(constructor)
if ctype == nil {
return newErrInvalidInput("can't provide an untyped nil", nil)
}
if ctype.Kind() != reflect.Func {
return newErrInvalidInput(
fmt.Sprintf("must provide constructor function, got %v (type %v)", constructor, ctype), nil)
}
var options provideOptions
for _, o := range opts {
o.applyProvideOption(&options)
}
if err := options.Validate(); err != nil {
return err
}
if err := s.provide(constructor, options); err != nil {
var errFunc *digreflect.Func
if options.Location == nil {
errFunc = digreflect.InspectFunc(constructor)
} else {
errFunc = options.Location
}
return errProvide{
Func: errFunc,
Reason: err,
}
}
return nil
}
func (s *Scope) provide(ctor interface{}, opts provideOptions) (err error) {
// If Export option is provided to the constructor, this should be injected to the
// root-level Scope (Container) to allow it to propagate to all other Scopes.
origScope := s
if opts.Exported {
s = s.rootScope()
}
// For all scopes affected by this change,
// take a snapshot of the current graph state before
// we start making changes to it as we may need to
// undo them upon encountering errors.
allScopes := s.appendSubscopes(nil)
for _, s := range allScopes {
s := s
s.gh.Snapshot()
defer func() {
if err != nil {
s.gh.Rollback()
}
}()
}
n, err := newConstructorNode(
ctor,
s,
origScope,
constructorOptions{
ResultName: opts.Name,
ResultGroup: opts.Group,
ResultAs: opts.As,
Location: opts.Location,
Callback: opts.Callback,
},
)
if err != nil {
return err
}
keys, err := s.findAndValidateResults(n.ResultList())
if err != nil {
return err
}
ctype := reflect.TypeOf(ctor)
if len(keys) == 0 {
return newErrInvalidInput(
fmt.Sprintf("%v must provide at least one non-error type", ctype), nil)
}
oldProviders := make(map[key][]*constructorNode)
for k := range keys {
// Cache old providers before running cycle detection.
oldProviders[k] = s.providers[k]
s.providers[k] = append(s.providers[k], n)
}
for _, s := range allScopes {
s.isVerifiedAcyclic = false
if s.deferAcyclicVerification {
continue
}
if ok, cycle := graph.IsAcyclic(s.gh); !ok {
// When a cycle is detected, recover the old providers to reset
// the providers map back to what it was before this node was
// introduced.
for k, ops := range oldProviders {
s.providers[k] = ops
}
return newErrInvalidInput("this function introduces a cycle", s.cycleDetectedError(cycle))
}
s.isVerifiedAcyclic = true
}
s.nodes = append(s.nodes, n)
// Record introspection info for caller if Info option is specified
if info := opts.Info; info != nil {
params := n.ParamList().DotParam()
results := n.ResultList().DotResult()
info.ID = (ID)(n.id)
info.Inputs = make([]*Input, len(params))
info.Outputs = make([]*Output, len(results))
for i, param := range params {
info.Inputs[i] = &Input{
t: param.Type,
optional: param.Optional,
name: param.Name,
group: param.Group,
}
}
for i, res := range results {
info.Outputs[i] = &Output{
t: res.Type,
name: res.Name,
group: res.Group,
}
}
}
return nil
}
// Builds a collection of all result types produced by this constructor.
func (s *Scope) findAndValidateResults(rl resultList) (map[key]struct{}, error) {
var err error
keyPaths := make(map[key]string)
walkResult(rl, connectionVisitor{
s: s,
err: &err,
keyPaths: keyPaths,
})
if err != nil {
return nil, err
}
keys := make(map[key]struct{}, len(keyPaths))
for k := range keyPaths {
keys[k] = struct{}{}
}
return keys, nil
}
// Visits the results of a node and compiles a collection of all the keys
// produced by that node.
type connectionVisitor struct {
s *Scope
// If this points to a non-nil value, we've already encountered an error
// and should stop traversing.
err *error
// Map of keys provided to path that provided this. The path is a string
// documenting which positional return value or dig.Out attribute is
// providing this particular key.
//
// For example, "[0].Foo" indicates that the value was provided by the Foo
// attribute of the dig.Out returned as the first result of the
// constructor.
keyPaths map[key]string
// We track the path to the current result here. For example, this will
// be, ["[1]", "Foo", "Bar"] when we're visiting Bar in,
//
// func() (io.Writer, struct {
// dig.Out
//
// Foo struct {
// dig.Out
//
// Bar io.Reader
// }
// })
currentResultPath []string
}
func (cv connectionVisitor) AnnotateWithField(f resultObjectField) resultVisitor {
cv.currentResultPath = append(cv.currentResultPath, f.FieldName)
return cv
}
func (cv connectionVisitor) AnnotateWithPosition(i int) resultVisitor {
cv.currentResultPath = append(cv.currentResultPath, fmt.Sprintf("[%d]", i))
return cv
}
func (cv connectionVisitor) Visit(res result) resultVisitor {
// Already failed. Stop looking.
if *cv.err != nil {
return nil
}
path := strings.Join(cv.currentResultPath, ".")
switch r := res.(type) {
case resultSingle:
k := key{name: r.Name, t: r.Type}
if err := cv.checkKey(k, path); err != nil {
*cv.err = err
return nil
}
for _, asType := range r.As {
k := key{name: r.Name, t: asType}
if err := cv.checkKey(k, path); err != nil {
*cv.err = err
return nil
}
}
case resultGrouped:
// we don't really care about the path for this since conflicts are
// okay for group results. We'll track it for the sake of having a
// value there.
k := key{group: r.Group, t: r.Type}
cv.keyPaths[k] = path
for _, asType := range r.As {
k := key{group: r.Group, t: asType}
cv.keyPaths[k] = path
}
}
return cv
}
func (cv connectionVisitor) checkKey(k key, path string) error {
defer func() { cv.keyPaths[k] = path }()
if conflict, ok := cv.keyPaths[k]; ok {
return newErrInvalidInput(fmt.Sprintf("cannot provide %v from %v", k, path),
newErrInvalidInput(fmt.Sprintf("already provided by %v", conflict), nil))
}
if ps := cv.s.providers[k]; len(ps) > 0 {
cons := make([]string, len(ps))
for i, p := range ps {
cons[i] = fmt.Sprint(p.Location())
}
return newErrInvalidInput(fmt.Sprintf("cannot provide %v from %v", k, path),
newErrInvalidInput(fmt.Sprintf("already provided by %v", strings.Join(cons, "; ")), nil))
}
return nil
}