From e76813eb13bc55cfbe1d7487c6fc65b3cbe64d46 Mon Sep 17 00:00:00 2001 From: 6xiaowu9 <736518585@qq.com> Date: Thu, 24 Nov 2022 18:45:20 +0800 Subject: [PATCH] signer/core/apitypes: deep convert types in slice (#26203) --- .../apitypes/signed_data_internal_test.go | 36 +++++++++++++++++++ signer/core/apitypes/types.go | 19 ++++++++-- 2 files changed, 52 insertions(+), 3 deletions(-) diff --git a/signer/core/apitypes/signed_data_internal_test.go b/signer/core/apitypes/signed_data_internal_test.go index 8379c0a7f..af7fc93ed 100644 --- a/signer/core/apitypes/signed_data_internal_test.go +++ b/signer/core/apitypes/signed_data_internal_test.go @@ -23,6 +23,7 @@ import ( "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/common/hexutil" + "github.com/ethereum/go-ethereum/common/math" ) func TestBytesPadding(t *testing.T) { @@ -197,3 +198,38 @@ func TestParseInteger(t *testing.T) { } } } + +func TestConvertStringDataToSlice(t *testing.T) { + slice := []string{"a", "b", "c"} + var it interface{} = slice + _, err := convertDataToSlice(it) + if err != nil { + t.Fatal(err) + } +} + +func TestConvertUint256DataToSlice(t *testing.T) { + slice := []*math.HexOrDecimal256{ + math.NewHexOrDecimal256(1), + math.NewHexOrDecimal256(2), + math.NewHexOrDecimal256(3), + } + var it interface{} = slice + _, err := convertDataToSlice(it) + if err != nil { + t.Fatal(err) + } +} + +func TestConvertAddressDataToSlice(t *testing.T) { + slice := []common.Address{ + common.HexToAddress("0x0000000000000000000000000000000000000001"), + common.HexToAddress("0x0000000000000000000000000000000000000002"), + common.HexToAddress("0x0000000000000000000000000000000000000003"), + } + var it interface{} = slice + _, err := convertDataToSlice(it) + if err != nil { + t.Fatal(err) + } +} diff --git a/signer/core/apitypes/types.go b/signer/core/apitypes/types.go index 6e883b27c..3e099feaa 100644 --- a/signer/core/apitypes/types.go +++ b/signer/core/apitypes/types.go @@ -367,8 +367,8 @@ func (typedData *TypedData) EncodeData(primaryType string, data map[string]inter encType := field.Type encValue := data[field.Name] if encType[len(encType)-1:] == "]" { - arrayValue, ok := encValue.([]interface{}) - if !ok { + arrayValue, err := convertDataToSlice(encValue) + if err != nil { return nil, dataMismatchError(encType, encValue) } @@ -573,6 +573,19 @@ func dataMismatchError(encType string, encValue interface{}) error { return fmt.Errorf("provided data '%v' doesn't match type '%s'", encValue, encType) } +func convertDataToSlice(encValue interface{}) ([]interface{}, error) { + var outEncValue []interface{} + rv := reflect.ValueOf(encValue) + if rv.Kind() == reflect.Slice { + for i := 0; i < rv.Len(); i++ { + outEncValue = append(outEncValue, rv.Index(i).Interface()) + } + } else { + return outEncValue, fmt.Errorf("provided data '%v' is not slice", encValue) + } + return outEncValue, nil +} + // validate makes sure the types are sound func (typedData *TypedData) validate() error { if err := typedData.Types.validate(); err != nil { @@ -632,7 +645,7 @@ func (typedData *TypedData) formatData(primaryType string, data map[string]inter Typ: field.Type, } if field.isArray() { - arrayValue, _ := encValue.([]interface{}) + arrayValue, _ := convertDataToSlice(encValue) parsedType := field.typeName() for _, v := range arrayValue { if typedData.Types[parsedType] != nil {