fix(dapps)_: don't require chain ID for signing typed data v4
The reused implementation from signing typed data V1 was used in case of signing typed data V4. This implementation required chain ID to be present in the typed data. This change fixes the issue by making chainID optional for signing typed data V4.
This commit is contained in:
parent
4c6ca00520
commit
dc62171219
|
@ -17,11 +17,11 @@ func TestChainIDValidation(t *testing.T) {
|
|||
for _, tc := range []testCase{
|
||||
{
|
||||
"ChainIDMismatch",
|
||||
map[string]json.RawMessage{chainIDKey: json.RawMessage("1")},
|
||||
map[string]json.RawMessage{ChainIDKey: json.RawMessage("1")},
|
||||
},
|
||||
{
|
||||
"ChainIDNotAnInt",
|
||||
map[string]json.RawMessage{chainIDKey: json.RawMessage(`"aa"`)},
|
||||
map[string]json.RawMessage{ChainIDKey: json.RawMessage(`"aa"`)},
|
||||
},
|
||||
{
|
||||
"NoChainIDKey",
|
||||
|
|
|
@ -10,7 +10,7 @@ import (
|
|||
|
||||
const (
|
||||
eip712Domain = "EIP712Domain"
|
||||
chainIDKey = "chainId"
|
||||
ChainIDKey = "chainId"
|
||||
)
|
||||
|
||||
// Types define fields for each composite type.
|
||||
|
@ -73,13 +73,13 @@ func (t TypedData) Validate() error {
|
|||
|
||||
// ValidateChainID accept chain as big integer and verifies if typed data belongs to the same chain.
|
||||
func (t TypedData) ValidateChainID(chain *big.Int) error {
|
||||
if _, exist := t.Domain[chainIDKey]; !exist {
|
||||
return fmt.Errorf("domain misses chain key %s", chainIDKey)
|
||||
if _, exist := t.Domain[ChainIDKey]; !exist {
|
||||
return fmt.Errorf("domain misses chain key %s", ChainIDKey)
|
||||
}
|
||||
var chainID int64
|
||||
if err := json.Unmarshal(t.Domain[chainIDKey], &chainID); err != nil {
|
||||
if err := json.Unmarshal(t.Domain[ChainIDKey], &chainID); err != nil {
|
||||
var chainIDString string
|
||||
if err = json.Unmarshal(t.Domain[chainIDKey], &chainIDString); err != nil {
|
||||
if err = json.Unmarshal(t.Domain[ChainIDKey], &chainIDString); err != nil {
|
||||
return err
|
||||
}
|
||||
if chainID, err = strconv.ParseInt(chainIDString, 0, 64); err != nil {
|
||||
|
|
|
@ -263,14 +263,18 @@ func SafeSignTypedDataForDApps(typedJson string, privateKey *ecdsa.PrivateKey, c
|
|||
}
|
||||
|
||||
chain := new(big.Int).SetUint64(chainID)
|
||||
if err := typed.ValidateChainID(chain); err != nil {
|
||||
return types.HexBytes{}, err
|
||||
}
|
||||
|
||||
var sig hexutil.Bytes
|
||||
if legacy {
|
||||
sig, err = typeddata.Sign(typed, privateKey, chain)
|
||||
} else {
|
||||
// Validate chainID if part of the typed data
|
||||
if _, exist := typed.Domain[typeddata.ChainIDKey]; exist {
|
||||
if err := typed.ValidateChainID(chain); err != nil {
|
||||
return types.HexBytes{}, err
|
||||
}
|
||||
}
|
||||
|
||||
var typedV4 signercore.TypedData
|
||||
err = json.Unmarshal([]byte(typedJson), &typedV4)
|
||||
if err != nil {
|
||||
|
|
|
@ -428,10 +428,28 @@ func Test_AddSession(t *testing.T) {
|
|||
assert.Equal(t, sessions[0].IconURL, dapps[0].IconURL)
|
||||
}
|
||||
|
||||
func generateTypedDataJson(chainID int, skipField bool) string {
|
||||
type typedDataParams struct {
|
||||
chainID int
|
||||
skipField bool
|
||||
excludeChainID bool
|
||||
wrongContractType bool
|
||||
}
|
||||
|
||||
func generateTypedDataJson(p typedDataParams) string {
|
||||
optionalKeyValueField := ""
|
||||
if !skipField {
|
||||
optionalKeyValueField = `,"verifyingContract": "0xCcCCccccCCCCcCCCCCCcCcCccCcCCCcCcccccccC"`
|
||||
if !p.skipField {
|
||||
if p.wrongContractType {
|
||||
optionalKeyValueField = `,"verifyingContract": true`
|
||||
} else {
|
||||
optionalKeyValueField = `,"verifyingContract": "0xCcCCccccCCCCcCCCCCCcCcCccCcCCCcCcccccccC"`
|
||||
}
|
||||
}
|
||||
|
||||
chainIDSchemeEntry := ""
|
||||
chainIDDataEntry := ""
|
||||
if !p.excludeChainID {
|
||||
chainIDSchemeEntry = `{"name": "chainId", "type": "uint256"},`
|
||||
chainIDDataEntry = `,"chainId": ` + strconv.Itoa(p.chainID)
|
||||
}
|
||||
|
||||
typedData := `{
|
||||
|
@ -439,7 +457,7 @@ func generateTypedDataJson(chainID int, skipField bool) string {
|
|||
"EIP712Domain": [
|
||||
{"name": "name", "type": "string"},
|
||||
{"name": "version", "type": "string"},
|
||||
{"name": "chainId", "type": "uint256"},
|
||||
` + chainIDSchemeEntry + `
|
||||
{"name": "verifyingContract", "type": "address"}
|
||||
],
|
||||
"Person": [
|
||||
|
@ -455,8 +473,8 @@ func generateTypedDataJson(chainID int, skipField bool) string {
|
|||
"primaryType": "Mail",
|
||||
"domain": {
|
||||
"name": "Ether Mail",
|
||||
"version": "1",
|
||||
"chainId": ` + strconv.Itoa(chainID) + `
|
||||
"version": "1"
|
||||
` + chainIDDataEntry + `
|
||||
` + optionalKeyValueField + `
|
||||
},
|
||||
"message": {
|
||||
|
@ -492,7 +510,9 @@ func TestSafeSignTypedDataForDApps(t *testing.T) {
|
|||
{
|
||||
name: "sign_typed_data",
|
||||
args: args{
|
||||
typedJson: generateTypedDataJson(1, false),
|
||||
typedJson: generateTypedDataJson(typedDataParams{
|
||||
chainID: 1,
|
||||
}),
|
||||
privateKey: privateKey,
|
||||
chainID: 1,
|
||||
legacy: false,
|
||||
|
@ -502,7 +522,9 @@ func TestSafeSignTypedDataForDApps(t *testing.T) {
|
|||
{
|
||||
name: "sign_typed_data_legacy",
|
||||
args: args{
|
||||
typedJson: generateTypedDataJson(1, false),
|
||||
typedJson: generateTypedDataJson(typedDataParams{
|
||||
chainID: 1,
|
||||
}),
|
||||
privateKey: privateKey,
|
||||
chainID: 1,
|
||||
legacy: true,
|
||||
|
@ -512,17 +534,32 @@ func TestSafeSignTypedDataForDApps(t *testing.T) {
|
|||
{
|
||||
name: "sign_typed_data_invalid_json",
|
||||
args: args{
|
||||
typedJson: `{"invalid": "json"`,
|
||||
typedJson: generateTypedDataJson(typedDataParams{
|
||||
chainID: 1,
|
||||
wrongContractType: true,
|
||||
}),
|
||||
privateKey: privateKey,
|
||||
chainID: 1,
|
||||
legacy: false,
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "sign_typed_data_invalid_json_legacy",
|
||||
args: args{
|
||||
typedJson: `{"invalid": "json"`,
|
||||
privateKey: privateKey,
|
||||
chainID: 1,
|
||||
legacy: true,
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "sign_typed_data_invalid_chain_id",
|
||||
args: args{
|
||||
typedJson: generateTypedDataJson(1, false),
|
||||
typedJson: generateTypedDataJson(typedDataParams{
|
||||
chainID: 1,
|
||||
}),
|
||||
privateKey: privateKey,
|
||||
chainID: 2,
|
||||
legacy: false,
|
||||
|
@ -532,13 +569,29 @@ func TestSafeSignTypedDataForDApps(t *testing.T) {
|
|||
{
|
||||
name: "sign_typed_data_missing_field",
|
||||
args: args{
|
||||
typedJson: generateTypedDataJson(1, true),
|
||||
typedJson: generateTypedDataJson(typedDataParams{
|
||||
chainID: 1,
|
||||
skipField: true,
|
||||
}),
|
||||
privateKey: privateKey,
|
||||
chainID: 1,
|
||||
legacy: false,
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "sign_typed_data_exclude_chain_id",
|
||||
args: args{
|
||||
typedJson: generateTypedDataJson(typedDataParams{
|
||||
chainID: 1,
|
||||
excludeChainID: true,
|
||||
}),
|
||||
privateKey: privateKey,
|
||||
chainID: 1,
|
||||
legacy: false,
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
|
|
Loading…
Reference in New Issue