From 4ddf9f272745d0cd3a31bcd7e89d70c4ecee096b Mon Sep 17 00:00:00 2001 From: Godfrain Jacques Date: Fri, 26 Jul 2024 13:00:12 -0700 Subject: [PATCH] fix_: fix switchEthereum api when processing chain input string (#5589) fix #5587 --- services/connector/chainutils/utils.go | 18 ++++++++- services/connector/commands/chain_id.go | 15 ++++++- services/connector/commands/chain_id_test.go | 9 ++++- .../commands/switch_ethereum_chain.go | 40 +++++++++++++------ .../commands/switch_ethereum_chain_test.go | 12 ++++-- services/connector/connector_flows_test.go | 6 +-- 6 files changed, 77 insertions(+), 23 deletions(-) diff --git a/services/connector/chainutils/utils.go b/services/connector/chainutils/utils.go index a1dfc090e..4e454a9b8 100644 --- a/services/connector/chainutils/utils.go +++ b/services/connector/chainutils/utils.go @@ -2,6 +2,8 @@ package chainutils import ( "errors" + "fmt" + "strconv" "github.com/status-im/status-go/params" ) @@ -10,7 +12,10 @@ type NetworkManagerInterface interface { GetActiveNetworks() ([]*params.Network, error) } -var ErrNoActiveNetworks = errors.New("no active networks available") +var ( + ErrNoActiveNetworks = errors.New("no active networks available") + ErrUnsupportedNetwork = errors.New("unsupported network") +) // GetSupportedChainIDs retrieves the chain IDs from the provided NetworkManager. func GetSupportedChainIDs(networkManager NetworkManagerInterface) ([]uint64, error) { @@ -39,3 +44,14 @@ func GetDefaultChainID(networkManager NetworkManagerInterface) (uint64, error) { return chainIDs[0], nil } + +func GetHexChainID(decimalStr string) (string, error) { + decimalValue, err := strconv.ParseInt(decimalStr, 10, 64) + if err != nil { + return "", ErrUnsupportedNetwork + } + + hexStr := fmt.Sprintf(`0x%s`, strconv.FormatInt(decimalValue, 16)) + + return hexStr, nil +} diff --git a/services/connector/commands/chain_id.go b/services/connector/commands/chain_id.go index 6dad4bcb7..d1b9d496d 100644 --- a/services/connector/commands/chain_id.go +++ b/services/connector/commands/chain_id.go @@ -30,8 +30,19 @@ func (c *ChainIDCommand) Execute(request RPCRequest) (string, error) { if err != nil { return "", err } - return strconv.FormatUint(defaultChainID, 16), nil + + chainId, err := chainutils.GetHexChainID(strconv.FormatUint(defaultChainID, 16)) + if err != nil { + return "", err + } + + return chainId, nil } - return walletCommon.ChainID(dApp.ChainID).String(), nil + chainId, err := chainutils.GetHexChainID(walletCommon.ChainID(dApp.ChainID).String()) + if err != nil { + return "", err + } + + return chainId, nil } diff --git a/services/connector/commands/chain_id_test.go b/services/connector/commands/chain_id_test.go index f10030608..d07559e54 100644 --- a/services/connector/commands/chain_id_test.go +++ b/services/connector/commands/chain_id_test.go @@ -2,6 +2,7 @@ package commands import ( "database/sql" + "fmt" "strconv" "testing" @@ -9,6 +10,7 @@ import ( "github.com/status-im/status-go/eth-node/types" "github.com/status-im/status-go/params" + "github.com/status-im/status-go/services/connector/chainutils" walletCommon "github.com/status-im/status-go/services/wallet/common" ) @@ -56,7 +58,9 @@ func TestGetDefaultChainIdForUnpermittedDApp(t *testing.T) { result, err := cmd.Execute(request) assert.NoError(t, err) - assert.Equal(t, result, strconv.FormatUint(walletCommon.EthereumMainnet, 16)) + chainId, err := chainutils.GetHexChainID(strconv.FormatUint(walletCommon.EthereumMainnet, 16)) + assert.NoError(t, err) + assert.Equal(t, result, chainId) } func TestGetChainIdForPermittedDApp(t *testing.T) { @@ -76,5 +80,6 @@ func TestGetChainIdForPermittedDApp(t *testing.T) { response, err := cmd.Execute(request) assert.NoError(t, err) - assert.Equal(t, walletCommon.ChainID(chainID).String(), response) + chainId := fmt.Sprintf(`0x%s`, strconv.FormatUint(chainID, 16)) + assert.Equal(t, chainId, response) } diff --git a/services/connector/commands/switch_ethereum_chain.go b/services/connector/commands/switch_ethereum_chain.go index 0bfb89cf8..dee8ef7cd 100644 --- a/services/connector/commands/switch_ethereum_chain.go +++ b/services/connector/commands/switch_ethereum_chain.go @@ -4,6 +4,7 @@ import ( "database/sql" "errors" "slices" + "strconv" "github.com/status-im/status-go/services/connector/chainutils" persistence "github.com/status-im/status-go/services/connector/database" @@ -22,23 +23,29 @@ type SwitchEthereumChainCommand struct { Db *sql.DB } +func hexStringToUint64(s string) (uint64, error) { + if len(s) > 2 && s[:2] == "0x" { + value, err := strconv.ParseUint(s[2:], 16, 64) + if err != nil { + return 0, err + } + return value, nil + } + return 0, ErrUnsupportedNetwork +} + func (r *RPCRequest) getChainID() (uint64, error) { if r.Params == nil || len(r.Params) == 0 { return 0, ErrEmptyRPCParams } - switch v := r.Params[0].(type) { - case float64: - return uint64(v), nil - case int: - return uint64(v), nil - case int64: - return uint64(v), nil - case uint64: - return v, nil - default: - return 0, ErrNoChainIDParamsFound + chainIds := r.Params[0].(map[string]interface{}) + + for _, chainId := range chainIds { + return hexStringToUint64(chainId.(string)) } + + return 0, nil } func (c *SwitchEthereumChainCommand) getSupportedChainIDs() ([]uint64, error) { @@ -70,6 +77,10 @@ func (c *SwitchEthereumChainCommand) Execute(request RPCRequest) (string, error) return "", err } + if dApp == nil { + return "", ErrDAppIsNotPermittedByUser + } + dApp.ChainID = requestedChainID err = persistence.UpsertDApp(c.Db, dApp) @@ -77,5 +88,10 @@ func (c *SwitchEthereumChainCommand) Execute(request RPCRequest) (string, error) return "", err } - return walletCommon.ChainID(dApp.ChainID).String(), nil + chainId, err := chainutils.GetHexChainID(walletCommon.ChainID(dApp.ChainID).String()) + if err != nil { + return "", err + } + + return chainId, nil } diff --git a/services/connector/commands/switch_ethereum_chain_test.go b/services/connector/commands/switch_ethereum_chain_test.go index 0a3e2478f..479ca8bed 100644 --- a/services/connector/commands/switch_ethereum_chain_test.go +++ b/services/connector/commands/switch_ethereum_chain_test.go @@ -1,6 +1,7 @@ package commands import ( + "fmt" "testing" "github.com/stretchr/testify/assert" @@ -55,7 +56,9 @@ func TestFailToSwitchEthereumChainWithUnsupportedChainId(t *testing.T) { } params := make([]interface{}, 1) - params[0] = walletCommon.BinanceTestChainID // some unrecoginzed chain id + params[0] = map[string]interface{}{ + "chainId": "0x1a343", + } // some unrecoginzed chain id request, err := ConstructRPCRequest("wallet_switchEthereumChain", params, &testDAppData) assert.NoError(t, err) @@ -84,7 +87,9 @@ func TestSwitchEthereumChain(t *testing.T) { } params := make([]interface{}, 1) - params[0] = walletCommon.EthereumMainnet + params[0] = map[string]interface{}{ + "chainId": "0x1", + } request, err := ConstructRPCRequest("wallet_switchEthereumChain", params, &testDAppData) assert.NoError(t, err) @@ -94,5 +99,6 @@ func TestSwitchEthereumChain(t *testing.T) { response, err := cmd.Execute(request) assert.NoError(t, err) - assert.Equal(t, walletCommon.ChainID(walletCommon.EthereumMainnet).String(), response) + chainId := fmt.Sprintf(`0x%s`, walletCommon.ChainID(walletCommon.EthereumMainnet).String()) + assert.Equal(t, chainId, response) } diff --git a/services/connector/connector_flows_test.go b/services/connector/connector_flows_test.go index 0615cb8bb..228c6fc00 100644 --- a/services/connector/connector_flows_test.go +++ b/services/connector/connector_flows_test.go @@ -83,9 +83,9 @@ func TestRequestAccountsSwitchChainAndSendTransactionFlow(t *testing.T) { assert.Equal(t, expectedResponse, response) // Request to switch ethereum chain - expectedChainId := 0x5 - request = fmt.Sprintf("{\"method\": \"wallet_switchEthereumChain\", \"params\": [%d], \"url\": \"http://testDAppURL123\", \"name\": \"testDAppName\", \"iconUrl\": \"http://testDAppIconUrl\" }", expectedChainId) - expectedResponse = fmt.Sprintf(`%d`, expectedChainId) + expectedChainId := "0x5" + request = fmt.Sprintf("{\"method\": \"wallet_switchEthereumChain\", \"params\": [{\"chainId\": \"%s\"}], \"url\": \"http://testDAppURL123\", \"name\": \"testDAppName\", \"iconUrl\": \"http://testDAppIconUrl\" }", expectedChainId) + expectedResponse = expectedChainId response, err = api.CallRPC(request) assert.NoError(t, err) assert.Equal(t, expectedResponse, response)