status-go/services/typeddata/types.go

94 lines
2.5 KiB
Go

package typeddata
import (
"encoding/json"
"errors"
"fmt"
"math/big"
"strconv"
)
const (
// eip712Domain is the name of the type that TypedData.Validate requires to be present in `types`.
eip712Domain = "EIP712Domain"
// ChainIDKey is the key under which TypedData.ValidateChainID looks up the chain id in the domain.
ChainIDKey = "chainId"
)
// Types defines the fields of each composite type, keyed by the type's name.
type Types map[string][]Field
// Field describes one member of a composite type: the member's name and its
// solidity type, both carried as plain strings.
type Field struct {
	Name string `json:"name"`
	Type string `json:"type"`
}

// Validate reports an error when the field is incomplete.
//
// The name is checked before the type, so a field missing both is reported
// as missing its name. A nil error means both are non-empty.
func (f Field) Validate() error {
	switch {
	case f.Name == "":
		return errors.New("`name` is required")
	case f.Type == "":
		return errors.New("`type` is required")
	default:
		return nil
	}
}
// TypedData defines typed data according to eip-712.
//
// Domain and Message keep their values as raw JSON keyed by name, so the
// individual values are not decoded when the struct itself is unmarshalled.
type TypedData struct {
// Types holds the field definitions of every composite type, keyed by type name.
// Validate requires an `EIP712Domain` entry to be present.
Types Types `json:"types"`
// PrimaryType names the main type; Validate requires it to be non-empty and defined in Types.
PrimaryType string `json:"primaryType"`
// Domain holds the domain values as raw JSON; ValidateChainID reads its `chainId` entry.
Domain map[string]json.RawMessage `json:"domain"`
// Message holds the message values as raw JSON; Validate only requires the map to be non-nil.
Message map[string]json.RawMessage `json:"message"`
}
// Validate checks that the typed data is structurally complete.
//
// The checks run in a fixed order and the first failure is returned:
// the domain type must be declared in `types`, `primaryType` must be set and
// declared in `types`, `domain` and `message` must be present, and every field
// of every declared type must pass Field.Validate.
//
// Whether the types the primary type depends on are declared is not checked
// here; that is validated when the type string is computed.
func (t TypedData) Validate() error {
	// declared reports whether a type with the given name exists in `types`.
	declared := func(name string) bool {
		_, ok := t.Types[name]
		return ok
	}

	// Cases are evaluated top to bottom, so the order below is the order in
	// which problems are reported.
	switch {
	case !declared(eip712Domain):
		return fmt.Errorf("`%s` must be in `types`", eip712Domain)
	case len(t.PrimaryType) == 0:
		return errors.New("`primaryType` is required")
	case !declared(t.PrimaryType):
		return fmt.Errorf("primary type `%s` not defined in types", t.PrimaryType)
	case t.Domain == nil:
		return errors.New("`domain` is required")
	case t.Message == nil:
		return errors.New("`message` is required")
	}

	for name, members := range t.Types {
		for idx, member := range members {
			if err := member.Validate(); err != nil {
				return fmt.Errorf("field %d from type `%s` is invalid: %v", idx, name, err)
			}
		}
	}
	return nil
}
// ValidateChainID accepts the selected chain as a big integer and verifies
// that the typed data belongs to the same chain.
//
// The domain's `chainId` entry may be a JSON integer, or a JSON string holding
// an integer literal. Strings are parsed with base 0, so a prefix such as
// "0x" selects the base. In both cases the value must fit in an int64.
//
// An error is returned when the domain has no `chainId` entry, when the entry
// cannot be decoded as described above, when chain is nil, or when the decoded
// value differs from chain.
func (t TypedData) ValidateChainID(chain *big.Int) error {
	raw, exist := t.Domain[ChainIDKey]
	if !exist {
		return fmt.Errorf("domain misses chain key %s", ChainIDKey)
	}

	var chainID int64
	if err := json.Unmarshal(raw, &chainID); err != nil {
		// Not a plain JSON integer: fall back to a string-encoded integer.
		var chainIDString string
		if err = json.Unmarshal(raw, &chainIDString); err != nil {
			return err
		}
		if chainID, err = strconv.ParseInt(chainIDString, 0, 64); err != nil {
			return err
		}
	}

	// A nil chain would panic in the comparison below; report it instead.
	if chain == nil {
		return errors.New("selected chain is required")
	}
	// Int64 is undefined for values outside the int64 range, so a chain that
	// does not fit can never equal the int64 chainID and must not be compared
	// through a truncated value.
	if !chain.IsInt64() || chainID != chain.Int64() {
		return fmt.Errorf("chainId %d doesn't match selected chain %s", chainID, chain)
	}
	return nil
}