diff --git a/bencode/decode.go b/bencode/decode.go index 8b22fa73..51804614 100644 --- a/bencode/decode.go +++ b/bencode/decode.go @@ -205,50 +205,40 @@ func (d *Decoder) parseString(v reflect.Value) error { // Info for parsing a dict value. type dictField struct { - Value reflect.Value // Storage for the parsed value. - // True if field value should be parsed into Value. If false, the value - // should be parsed and discarded. - Ok bool - Set func() // Call this after parsing into Value. - IgnoreUnmarshalTypeError bool + Type reflect.Type + Get func(value reflect.Value) func(reflect.Value) + Tags tag } // Returns specifics for parsing a dict field value. -func getDictField(dict reflect.Value, key string) dictField { +func getDictField(dict reflect.Type, key string) dictField { // get valuev as a map value or as a struct field switch dict.Kind() { case reflect.Map: - value := reflect.New(dict.Type().Elem()).Elem() return dictField{ - Value: value, - Ok: true, - Set: func() { - if dict.IsNil() { - dict.Set(reflect.MakeMap(dict.Type())) + Type: dict.Elem(), + Get: func(mapValue reflect.Value) func(reflect.Value) { + return func(value reflect.Value) { + if mapValue.IsNil() { + mapValue.Set(reflect.MakeMap(dict)) + } + // Assigns the value into the map. + //log.Printf("map type: %v", mapValue.Type()) + mapValue.SetMapIndex(reflect.ValueOf(key).Convert(dict.Key()), value) } - // Assigns the value into the map. - dict.SetMapIndex(reflect.ValueOf(key).Convert(dict.Type().Key()), value) }, } case reflect.Struct: - sf, ok := getStructFieldForKey(dict.Type(), key) - if !ok { - return dictField{} - } - if sf.r.PkgPath != "" { - panic(&UnmarshalFieldError{ - Key: key, - Type: dict.Type(), - Field: sf.r, - }) - } - return dictField{ - Value: dict.FieldByIndex(sf.r.Index), - Ok: true, - Set: func() {}, - IgnoreUnmarshalTypeError: sf.tag.IgnoreUnmarshalTypeError(), - } + return getStructFieldForKey(dict, key) + //if sf.r.PkgPath != "" { + // panic(&UnmarshalFieldError{ + // Key: key, + // Type: dict.Type(), + // Field: sf.r, + // }) + //} default: + panic("unimplemented") return dictField{} } } @@ -260,11 +250,12 @@ type structField struct { var ( structFieldsMu sync.Mutex - structFields = map[reflect.Type]map[string]structField{} + structFields = map[reflect.Type]map[string]dictField{} ) -func parseStructFields(struct_ reflect.Type, each func(string, structField)) { - for i, n := 0, struct_.NumField(); i < n; i++ { +func parseStructFields(struct_ reflect.Type, each func(string, dictField)) { + for _i, n := 0, struct_.NumField(); _i < n; _i++ { + i := _i f := struct_.Field(i) if f.Anonymous { continue @@ -278,25 +269,35 @@ func parseStructFields(struct_ reflect.Type, each func(string, structField)) { if key == "" { key = f.Name } - each(key, structField{f, tag}) + each(key, dictField{f.Type, func(value reflect.Value) func(reflect.Value) { + return value.Field(i).Set + }, tag}) } } func saveStructFields(struct_ reflect.Type) { - m := make(map[string]structField) - parseStructFields(struct_, func(key string, sf structField) { + m := make(map[string]dictField) + parseStructFields(struct_, func(key string, sf dictField) { m[key] = sf }) structFields[struct_] = m } -func getStructFieldForKey(struct_ reflect.Type, key string) (f structField, ok bool) { +func getStructFieldForKey(struct_ reflect.Type, key string) (f dictField) { structFieldsMu.Lock() if _, ok := structFields[struct_]; !ok { saveStructFields(struct_) } - f, ok = structFields[struct_][key] + f, ok := structFields[struct_][key] structFieldsMu.Unlock() + if !ok { + var discard interface{} + return dictField{ + Type: reflect.TypeOf(discard), + Get: func(reflect.Value) func(reflect.Value) { return func(reflect.Value) {} }, + Tags: nil, + } + } return } @@ -314,31 +315,33 @@ func (d *Decoder) parseDict(v reflect.Value) error { return nil } - df := getDictField(v, keyStr) + df := getDictField(v.Type(), keyStr) // now we need to actually parse it - if df.Ok { - // log.Printf("parsing ok struct field for key %q", keyStr) - ok, err = d.parseValue(df.Value) - } else { + if df.Type == nil { // Discard the value, there's nowhere to put it. var if_ interface{} if_, ok = d.parseValueInterface() if if_ == nil { - err = fmt.Errorf("error parsing value for key %q", keyStr) + return fmt.Errorf("error parsing value for key %q", keyStr) } + if !ok { + return fmt.Errorf("missing value for key %q", keyStr) + } + continue } + setValue := reflect.New(df.Type).Elem() + //log.Printf("parsing into %v", setValue.Type()) + ok, err = d.parseValue(setValue) if err != nil { - if _, ok := err.(*UnmarshalTypeError); !ok || !df.IgnoreUnmarshalTypeError { + if _, ok := err.(*UnmarshalTypeError); !ok || !df.Tags.IgnoreUnmarshalTypeError() { return fmt.Errorf("parsing value for key %q: %s", keyStr, err) } } if !ok { return fmt.Errorf("missing value for key %q", keyStr) } - if df.Ok { - df.Set() - } + df.Get(v)(setValue) } } diff --git a/bencode/decode_test.go b/bencode/decode_test.go index 4b72edbb..056a399a 100644 --- a/bencode/decode_test.go +++ b/bencode/decode_test.go @@ -7,6 +7,7 @@ import ( "reflect" "testing" + qt "github.com/frankban/quicktest" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -144,7 +145,7 @@ func TestIgnoreUnmarshalTypeError(t *testing.T) { }{} require.Error(t, Unmarshal([]byte("d6:Normal5:helloe"), &s)) assert.NoError(t, Unmarshal([]byte("d6:Ignore5:helloe"), &s)) - require.Nil(t, Unmarshal([]byte("d6:Ignorei42ee"), &s)) + qt.Assert(t, Unmarshal([]byte("d6:Ignorei42ee"), &s), qt.IsNil) assert.EqualValues(t, 42, s.Ignore) } diff --git a/bencode/tags.go b/bencode/tags.go index 50bdc72b..d4adeb24 100644 --- a/bencode/tags.go +++ b/bencode/tags.go @@ -24,6 +24,9 @@ func (me tag) Key() string { } func (me tag) HasOpt(opt string) bool { + if len(me) < 1 { + return false + } for _, s := range me[1:] { if s == opt { return true