fix(dapps)_: don't require chain ID for signing typed data v4

The reused implementation from signing typed data V1 was used
in case of signing typed data V4. This implementation required
chain ID to be present in the typed data. This change fixes
the issue by making chainID optional for signing typed data V4.
This commit is contained in:
Stefan 2024-07-15 17:43:57 +03:00 committed by Stefan Dunca
parent 4c6ca00520
commit dc62171219
4 changed files with 78 additions and 21 deletions

View File

@ -17,11 +17,11 @@ func TestChainIDValidation(t *testing.T) {
for _, tc := range []testCase{
{
"ChainIDMismatch",
map[string]json.RawMessage{chainIDKey: json.RawMessage("1")},
map[string]json.RawMessage{ChainIDKey: json.RawMessage("1")},
},
{
"ChainIDNotAnInt",
map[string]json.RawMessage{chainIDKey: json.RawMessage(`"aa"`)},
map[string]json.RawMessage{ChainIDKey: json.RawMessage(`"aa"`)},
},
{
"NoChainIDKey",

View File

@ -10,7 +10,7 @@ import (
const (
eip712Domain = "EIP712Domain"
chainIDKey = "chainId"
ChainIDKey = "chainId"
)
// Types define fields for each composite type.
@ -73,13 +73,13 @@ func (t TypedData) Validate() error {
// ValidateChainID accept chain as big integer and verifies if typed data belongs to the same chain.
func (t TypedData) ValidateChainID(chain *big.Int) error {
if _, exist := t.Domain[chainIDKey]; !exist {
return fmt.Errorf("domain misses chain key %s", chainIDKey)
if _, exist := t.Domain[ChainIDKey]; !exist {
return fmt.Errorf("domain misses chain key %s", ChainIDKey)
}
var chainID int64
if err := json.Unmarshal(t.Domain[chainIDKey], &chainID); err != nil {
if err := json.Unmarshal(t.Domain[ChainIDKey], &chainID); err != nil {
var chainIDString string
if err = json.Unmarshal(t.Domain[chainIDKey], &chainIDString); err != nil {
if err = json.Unmarshal(t.Domain[ChainIDKey], &chainIDString); err != nil {
return err
}
if chainID, err = strconv.ParseInt(chainIDString, 0, 64); err != nil {

View File

@ -263,14 +263,18 @@ func SafeSignTypedDataForDApps(typedJson string, privateKey *ecdsa.PrivateKey, c
}
chain := new(big.Int).SetUint64(chainID)
if err := typed.ValidateChainID(chain); err != nil {
return types.HexBytes{}, err
}
var sig hexutil.Bytes
if legacy {
sig, err = typeddata.Sign(typed, privateKey, chain)
} else {
// Validate chainID if part of the typed data
if _, exist := typed.Domain[typeddata.ChainIDKey]; exist {
if err := typed.ValidateChainID(chain); err != nil {
return types.HexBytes{}, err
}
}
var typedV4 signercore.TypedData
err = json.Unmarshal([]byte(typedJson), &typedV4)
if err != nil {

View File

@ -428,10 +428,28 @@ func Test_AddSession(t *testing.T) {
assert.Equal(t, sessions[0].IconURL, dapps[0].IconURL)
}
func generateTypedDataJson(chainID int, skipField bool) string {
type typedDataParams struct {
chainID int
skipField bool
excludeChainID bool
wrongContractType bool
}
func generateTypedDataJson(p typedDataParams) string {
optionalKeyValueField := ""
if !skipField {
optionalKeyValueField = `,"verifyingContract": "0xCcCCccccCCCCcCCCCCCcCcCccCcCCCcCcccccccC"`
if !p.skipField {
if p.wrongContractType {
optionalKeyValueField = `,"verifyingContract": true`
} else {
optionalKeyValueField = `,"verifyingContract": "0xCcCCccccCCCCcCCCCCCcCcCccCcCCCcCcccccccC"`
}
}
chainIDSchemeEntry := ""
chainIDDataEntry := ""
if !p.excludeChainID {
chainIDSchemeEntry = `{"name": "chainId", "type": "uint256"},`
chainIDDataEntry = `,"chainId": ` + strconv.Itoa(p.chainID)
}
typedData := `{
@ -439,7 +457,7 @@ func generateTypedDataJson(chainID int, skipField bool) string {
"EIP712Domain": [
{"name": "name", "type": "string"},
{"name": "version", "type": "string"},
{"name": "chainId", "type": "uint256"},
` + chainIDSchemeEntry + `
{"name": "verifyingContract", "type": "address"}
],
"Person": [
@ -455,8 +473,8 @@ func generateTypedDataJson(chainID int, skipField bool) string {
"primaryType": "Mail",
"domain": {
"name": "Ether Mail",
"version": "1",
"chainId": ` + strconv.Itoa(chainID) + `
"version": "1"
` + chainIDDataEntry + `
` + optionalKeyValueField + `
},
"message": {
@ -492,7 +510,9 @@ func TestSafeSignTypedDataForDApps(t *testing.T) {
{
name: "sign_typed_data",
args: args{
typedJson: generateTypedDataJson(1, false),
typedJson: generateTypedDataJson(typedDataParams{
chainID: 1,
}),
privateKey: privateKey,
chainID: 1,
legacy: false,
@ -502,7 +522,9 @@ func TestSafeSignTypedDataForDApps(t *testing.T) {
{
name: "sign_typed_data_legacy",
args: args{
typedJson: generateTypedDataJson(1, false),
typedJson: generateTypedDataJson(typedDataParams{
chainID: 1,
}),
privateKey: privateKey,
chainID: 1,
legacy: true,
@ -512,17 +534,32 @@ func TestSafeSignTypedDataForDApps(t *testing.T) {
{
name: "sign_typed_data_invalid_json",
args: args{
typedJson: `{"invalid": "json"`,
typedJson: generateTypedDataJson(typedDataParams{
chainID: 1,
wrongContractType: true,
}),
privateKey: privateKey,
chainID: 1,
legacy: false,
},
wantErr: true,
},
{
name: "sign_typed_data_invalid_json_legacy",
args: args{
typedJson: `{"invalid": "json"`,
privateKey: privateKey,
chainID: 1,
legacy: true,
},
wantErr: true,
},
{
name: "sign_typed_data_invalid_chain_id",
args: args{
typedJson: generateTypedDataJson(1, false),
typedJson: generateTypedDataJson(typedDataParams{
chainID: 1,
}),
privateKey: privateKey,
chainID: 2,
legacy: false,
@ -532,13 +569,29 @@ func TestSafeSignTypedDataForDApps(t *testing.T) {
{
name: "sign_typed_data_missing_field",
args: args{
typedJson: generateTypedDataJson(1, true),
typedJson: generateTypedDataJson(typedDataParams{
chainID: 1,
skipField: true,
}),
privateKey: privateKey,
chainID: 1,
legacy: false,
},
wantErr: true,
},
{
name: "sign_typed_data_exclude_chain_id",
args: args{
typedJson: generateTypedDataJson(typedDataParams{
chainID: 1,
excludeChainID: true,
}),
privateKey: privateKey,
chainID: 1,
legacy: false,
},
wantErr: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {