2022-04-01 00:23:19 +02:00

454 lines
11 KiB
Go

package resolvable
import (
"context"
"fmt"
"reflect"
"strings"
"github.com/graph-gophers/graphql-go/decode"
"github.com/graph-gophers/graphql-go/internal/exec/packer"
"github.com/graph-gophers/graphql-go/types"
)
type Schema struct {
*Meta
types.Schema
Query Resolvable
Mutation Resolvable
Subscription Resolvable
Resolver reflect.Value
}
type Resolvable interface {
isResolvable()
}
type Object struct {
Name string
Fields map[string]*Field
TypeAssertions map[string]*TypeAssertion
}
type Field struct {
types.FieldDefinition
TypeName string
MethodIndex int
FieldIndex []int
HasContext bool
HasError bool
ArgsPacker *packer.StructPacker
ValueExec Resolvable
TraceLabel string
}
func (f *Field) UseMethodResolver() bool {
return len(f.FieldIndex) == 0
}
type TypeAssertion struct {
MethodIndex int
TypeExec Resolvable
}
type List struct {
Elem Resolvable
}
type Scalar struct{}
func (*Object) isResolvable() {}
func (*List) isResolvable() {}
func (*Scalar) isResolvable() {}
func ApplyResolver(s *types.Schema, resolver interface{}) (*Schema, error) {
if resolver == nil {
return &Schema{Meta: newMeta(s), Schema: *s}, nil
}
b := newBuilder(s)
var query, mutation, subscription Resolvable
if t, ok := s.EntryPoints["query"]; ok {
if err := b.assignExec(&query, t, reflect.TypeOf(resolver)); err != nil {
return nil, err
}
}
if t, ok := s.EntryPoints["mutation"]; ok {
if err := b.assignExec(&mutation, t, reflect.TypeOf(resolver)); err != nil {
return nil, err
}
}
if t, ok := s.EntryPoints["subscription"]; ok {
if err := b.assignExec(&subscription, t, reflect.TypeOf(resolver)); err != nil {
return nil, err
}
}
if err := b.finish(); err != nil {
return nil, err
}
return &Schema{
Meta: newMeta(s),
Schema: *s,
Resolver: reflect.ValueOf(resolver),
Query: query,
Mutation: mutation,
Subscription: subscription,
}, nil
}
type execBuilder struct {
schema *types.Schema
resMap map[typePair]*resMapEntry
packerBuilder *packer.Builder
}
type typePair struct {
graphQLType types.Type
resolverType reflect.Type
}
type resMapEntry struct {
exec Resolvable
targets []*Resolvable
}
func newBuilder(s *types.Schema) *execBuilder {
return &execBuilder{
schema: s,
resMap: make(map[typePair]*resMapEntry),
packerBuilder: packer.NewBuilder(),
}
}
func (b *execBuilder) finish() error {
for _, entry := range b.resMap {
for _, target := range entry.targets {
*target = entry.exec
}
}
return b.packerBuilder.Finish()
}
func (b *execBuilder) assignExec(target *Resolvable, t types.Type, resolverType reflect.Type) error {
k := typePair{t, resolverType}
ref, ok := b.resMap[k]
if !ok {
ref = &resMapEntry{}
b.resMap[k] = ref
var err error
ref.exec, err = b.makeExec(t, resolverType)
if err != nil {
return err
}
}
ref.targets = append(ref.targets, target)
return nil
}
func (b *execBuilder) makeExec(t types.Type, resolverType reflect.Type) (Resolvable, error) {
var nonNull bool
t, nonNull = unwrapNonNull(t)
switch t := t.(type) {
case *types.ObjectTypeDefinition:
return b.makeObjectExec(t.Name, t.Fields, nil, nonNull, resolverType)
case *types.InterfaceTypeDefinition:
return b.makeObjectExec(t.Name, t.Fields, t.PossibleTypes, nonNull, resolverType)
case *types.Union:
return b.makeObjectExec(t.Name, nil, t.UnionMemberTypes, nonNull, resolverType)
}
if !nonNull {
if resolverType.Kind() != reflect.Ptr {
return nil, fmt.Errorf("%s is not a pointer", resolverType)
}
resolverType = resolverType.Elem()
}
switch t := t.(type) {
case *types.ScalarTypeDefinition:
return makeScalarExec(t, resolverType)
case *types.EnumTypeDefinition:
return &Scalar{}, nil
case *types.List:
if resolverType.Kind() != reflect.Slice {
return nil, fmt.Errorf("%s is not a slice", resolverType)
}
e := &List{}
if err := b.assignExec(&e.Elem, t.OfType, resolverType.Elem()); err != nil {
return nil, err
}
return e, nil
default:
panic("invalid type: " + t.String())
}
}
func makeScalarExec(t *types.ScalarTypeDefinition, resolverType reflect.Type) (Resolvable, error) {
implementsType := false
switch r := reflect.New(resolverType).Interface().(type) {
case *int32:
implementsType = t.Name == "Int"
case *float64:
implementsType = t.Name == "Float"
case *string:
implementsType = t.Name == "String"
case *bool:
implementsType = t.Name == "Boolean"
case decode.Unmarshaler:
implementsType = r.ImplementsGraphQLType(t.Name)
}
if !implementsType {
return nil, fmt.Errorf("can not use %s as %s", resolverType, t.Name)
}
return &Scalar{}, nil
}
func (b *execBuilder) makeObjectExec(typeName string, fields types.FieldsDefinition, possibleTypes []*types.ObjectTypeDefinition,
nonNull bool, resolverType reflect.Type) (*Object, error) {
if !nonNull {
if resolverType.Kind() != reflect.Ptr && resolverType.Kind() != reflect.Interface {
return nil, fmt.Errorf("%s is not a pointer or interface", resolverType)
}
}
methodHasReceiver := resolverType.Kind() != reflect.Interface
Fields := make(map[string]*Field)
rt := unwrapPtr(resolverType)
fieldsCount := fieldCount(rt, map[string]int{})
for _, f := range fields {
var fieldIndex []int
methodIndex := findMethod(resolverType, f.Name)
if b.schema.UseFieldResolvers && methodIndex == -1 {
if fieldsCount[strings.ToLower(stripUnderscore(f.Name))] > 1 {
return nil, fmt.Errorf("%s does not resolve %q: ambiguous field %q", resolverType, typeName, f.Name)
}
fieldIndex = findField(rt, f.Name, []int{})
}
if methodIndex == -1 && len(fieldIndex) == 0 {
hint := ""
if findMethod(reflect.PtrTo(resolverType), f.Name) != -1 {
hint = " (hint: the method exists on the pointer type)"
}
return nil, fmt.Errorf("%s does not resolve %q: missing method for field %q%s", resolverType, typeName, f.Name, hint)
}
var m reflect.Method
var sf reflect.StructField
if methodIndex != -1 {
m = resolverType.Method(methodIndex)
} else {
sf = rt.FieldByIndex(fieldIndex)
}
fe, err := b.makeFieldExec(typeName, f, m, sf, methodIndex, fieldIndex, methodHasReceiver)
if err != nil {
var resolverName string
if methodIndex != -1 {
resolverName = m.Name
} else {
resolverName = sf.Name
}
return nil, fmt.Errorf("%s\n\tused by (%s).%s", err, resolverType, resolverName)
}
Fields[f.Name] = fe
}
// Check type assertions when
// 1) using method resolvers
// 2) Or resolver is not an interface type
typeAssertions := make(map[string]*TypeAssertion)
if !b.schema.UseFieldResolvers || resolverType.Kind() != reflect.Interface {
for _, impl := range possibleTypes {
methodIndex := findMethod(resolverType, "To"+impl.Name)
if methodIndex == -1 {
return nil, fmt.Errorf("%s does not resolve %q: missing method %q to convert to %q", resolverType, typeName, "To"+impl.Name, impl.Name)
}
if resolverType.Method(methodIndex).Type.NumOut() != 2 {
return nil, fmt.Errorf("%s does not resolve %q: method %q should return a value and a bool indicating success", resolverType, typeName, "To"+impl.Name)
}
a := &TypeAssertion{
MethodIndex: methodIndex,
}
if err := b.assignExec(&a.TypeExec, impl, resolverType.Method(methodIndex).Type.Out(0)); err != nil {
return nil, err
}
typeAssertions[impl.Name] = a
}
}
return &Object{
Name: typeName,
Fields: Fields,
TypeAssertions: typeAssertions,
}, nil
}
var contextType = reflect.TypeOf((*context.Context)(nil)).Elem()
var errorType = reflect.TypeOf((*error)(nil)).Elem()
func (b *execBuilder) makeFieldExec(typeName string, f *types.FieldDefinition, m reflect.Method, sf reflect.StructField,
methodIndex int, fieldIndex []int, methodHasReceiver bool) (*Field, error) {
var argsPacker *packer.StructPacker
var hasError bool
var hasContext bool
// Validate resolver method only when there is one
if methodIndex != -1 {
in := make([]reflect.Type, m.Type.NumIn())
for i := range in {
in[i] = m.Type.In(i)
}
if methodHasReceiver {
in = in[1:] // first parameter is receiver
}
hasContext = len(in) > 0 && in[0] == contextType
if hasContext {
in = in[1:]
}
if len(f.Arguments) > 0 {
if len(in) == 0 {
return nil, fmt.Errorf("must have parameter for field arguments")
}
var err error
argsPacker, err = b.packerBuilder.MakeStructPacker(f.Arguments, in[0])
if err != nil {
return nil, err
}
in = in[1:]
}
if len(in) > 0 {
return nil, fmt.Errorf("too many parameters")
}
maxNumOfReturns := 2
if m.Type.NumOut() < maxNumOfReturns-1 {
return nil, fmt.Errorf("too few return values")
}
if m.Type.NumOut() > maxNumOfReturns {
return nil, fmt.Errorf("too many return values")
}
hasError = m.Type.NumOut() == maxNumOfReturns
if hasError {
if m.Type.Out(maxNumOfReturns-1) != errorType {
return nil, fmt.Errorf(`must have "error" as its last return value`)
}
}
}
fe := &Field{
FieldDefinition: *f,
TypeName: typeName,
MethodIndex: methodIndex,
FieldIndex: fieldIndex,
HasContext: hasContext,
ArgsPacker: argsPacker,
HasError: hasError,
TraceLabel: fmt.Sprintf("GraphQL field: %s.%s", typeName, f.Name),
}
var out reflect.Type
if methodIndex != -1 {
out = m.Type.Out(0)
sub, ok := b.schema.EntryPoints["subscription"]
if ok && typeName == sub.TypeName() && out.Kind() == reflect.Chan {
out = m.Type.Out(0).Elem()
}
} else {
out = sf.Type
}
if err := b.assignExec(&fe.ValueExec, f.Type, out); err != nil {
return nil, err
}
return fe, nil
}
func findMethod(t reflect.Type, name string) int {
for i := 0; i < t.NumMethod(); i++ {
if strings.EqualFold(stripUnderscore(name), stripUnderscore(t.Method(i).Name)) {
return i
}
}
return -1
}
func findField(t reflect.Type, name string, index []int) []int {
for i := 0; i < t.NumField(); i++ {
field := t.Field(i)
if field.Type.Kind() == reflect.Struct && field.Anonymous {
newIndex := findField(field.Type, name, []int{i})
if len(newIndex) > 1 {
return append(index, newIndex...)
}
}
if strings.EqualFold(stripUnderscore(name), stripUnderscore(field.Name)) {
return append(index, i)
}
}
return index
}
// fieldCount helps resolve ambiguity when more than one embedded struct contains fields with the same name.
func fieldCount(t reflect.Type, count map[string]int) map[string]int {
if t.Kind() != reflect.Struct {
return nil
}
for i := 0; i < t.NumField(); i++ {
field := t.Field(i)
fieldName := strings.ToLower(stripUnderscore(field.Name))
if field.Type.Kind() == reflect.Struct && field.Anonymous {
count = fieldCount(field.Type, count)
} else {
if _, ok := count[fieldName]; !ok {
count[fieldName] = 0
}
count[fieldName]++
}
}
return count
}
func unwrapNonNull(t types.Type) (types.Type, bool) {
if nn, ok := t.(*types.NonNull); ok {
return nn.OfType, true
}
return t, false
}
func stripUnderscore(s string) string {
return strings.Replace(s, "_", "", -1)
}
func unwrapPtr(t reflect.Type) reflect.Type {
if t.Kind() == reflect.Ptr {
return t.Elem()
}
return t
}