fix_: fix switchEthereum api when processing chain input string (#5603)

fix #5587
This commit is contained in:
Godfrain Jacques 2024-07-26 15:12:33 -07:00 committed by Siddarth Kumar
parent 3ca29b87c3
commit 07614f6640
No known key found for this signature in database
GPG Key ID: 599D10112BF518DB
6 changed files with 77 additions and 23 deletions

View File

@ -2,6 +2,8 @@ package chainutils
import (
"errors"
"fmt"
"strconv"
"github.com/status-im/status-go/params"
)
@ -10,7 +12,10 @@ type NetworkManagerInterface interface {
GetActiveNetworks() ([]*params.Network, error)
}
var ErrNoActiveNetworks = errors.New("no active networks available")
var (
ErrNoActiveNetworks = errors.New("no active networks available")
ErrUnsupportedNetwork = errors.New("unsupported network")
)
// GetSupportedChainIDs retrieves the chain IDs from the provided NetworkManager.
func GetSupportedChainIDs(networkManager NetworkManagerInterface) ([]uint64, error) {
@ -39,3 +44,14 @@ func GetDefaultChainID(networkManager NetworkManagerInterface) (uint64, error) {
return chainIDs[0], nil
}
func GetHexChainID(decimalStr string) (string, error) {
decimalValue, err := strconv.ParseInt(decimalStr, 10, 64)
if err != nil {
return "", ErrUnsupportedNetwork
}
hexStr := fmt.Sprintf(`0x%s`, strconv.FormatInt(decimalValue, 16))
return hexStr, nil
}

View File

@ -30,8 +30,19 @@ func (c *ChainIDCommand) Execute(request RPCRequest) (string, error) {
if err != nil {
return "", err
}
return strconv.FormatUint(defaultChainID, 16), nil
chainId, err := chainutils.GetHexChainID(strconv.FormatUint(defaultChainID, 16))
if err != nil {
return "", err
}
return chainId, nil
}
return walletCommon.ChainID(dApp.ChainID).String(), nil
chainId, err := chainutils.GetHexChainID(walletCommon.ChainID(dApp.ChainID).String())
if err != nil {
return "", err
}
return chainId, nil
}

View File

@ -2,6 +2,7 @@ package commands
import (
"database/sql"
"fmt"
"strconv"
"testing"
@ -9,6 +10,7 @@ import (
"github.com/status-im/status-go/eth-node/types"
"github.com/status-im/status-go/params"
"github.com/status-im/status-go/services/connector/chainutils"
walletCommon "github.com/status-im/status-go/services/wallet/common"
)
@ -56,7 +58,9 @@ func TestGetDefaultChainIdForUnpermittedDApp(t *testing.T) {
result, err := cmd.Execute(request)
assert.NoError(t, err)
assert.Equal(t, result, strconv.FormatUint(walletCommon.EthereumMainnet, 16))
chainId, err := chainutils.GetHexChainID(strconv.FormatUint(walletCommon.EthereumMainnet, 16))
assert.NoError(t, err)
assert.Equal(t, result, chainId)
}
func TestGetChainIdForPermittedDApp(t *testing.T) {
@ -76,5 +80,6 @@ func TestGetChainIdForPermittedDApp(t *testing.T) {
response, err := cmd.Execute(request)
assert.NoError(t, err)
assert.Equal(t, walletCommon.ChainID(chainID).String(), response)
chainId := fmt.Sprintf(`0x%s`, strconv.FormatUint(chainID, 16))
assert.Equal(t, chainId, response)
}

View File

@ -4,6 +4,7 @@ import (
"database/sql"
"errors"
"slices"
"strconv"
"github.com/status-im/status-go/services/connector/chainutils"
persistence "github.com/status-im/status-go/services/connector/database"
@ -22,23 +23,29 @@ type SwitchEthereumChainCommand struct {
Db *sql.DB
}
func hexStringToUint64(s string) (uint64, error) {
if len(s) > 2 && s[:2] == "0x" {
value, err := strconv.ParseUint(s[2:], 16, 64)
if err != nil {
return 0, err
}
return value, nil
}
return 0, ErrUnsupportedNetwork
}
func (r *RPCRequest) getChainID() (uint64, error) {
if r.Params == nil || len(r.Params) == 0 {
return 0, ErrEmptyRPCParams
}
switch v := r.Params[0].(type) {
case float64:
return uint64(v), nil
case int:
return uint64(v), nil
case int64:
return uint64(v), nil
case uint64:
return v, nil
default:
return 0, ErrNoChainIDParamsFound
chainIds := r.Params[0].(map[string]interface{})
for _, chainId := range chainIds {
return hexStringToUint64(chainId.(string))
}
return 0, nil
}
func (c *SwitchEthereumChainCommand) getSupportedChainIDs() ([]uint64, error) {
@ -70,6 +77,10 @@ func (c *SwitchEthereumChainCommand) Execute(request RPCRequest) (string, error)
return "", err
}
if dApp == nil {
return "", ErrDAppIsNotPermittedByUser
}
dApp.ChainID = requestedChainID
err = persistence.UpsertDApp(c.Db, dApp)
@ -77,5 +88,10 @@ func (c *SwitchEthereumChainCommand) Execute(request RPCRequest) (string, error)
return "", err
}
return walletCommon.ChainID(dApp.ChainID).String(), nil
chainId, err := chainutils.GetHexChainID(walletCommon.ChainID(dApp.ChainID).String())
if err != nil {
return "", err
}
return chainId, nil
}

View File

@ -1,6 +1,7 @@
package commands
import (
"fmt"
"testing"
"github.com/stretchr/testify/assert"
@ -55,7 +56,9 @@ func TestFailToSwitchEthereumChainWithUnsupportedChainId(t *testing.T) {
}
params := make([]interface{}, 1)
params[0] = walletCommon.BinanceTestChainID // some unrecoginzed chain id
params[0] = map[string]interface{}{
"chainId": "0x1a343",
} // some unrecoginzed chain id
request, err := ConstructRPCRequest("wallet_switchEthereumChain", params, &testDAppData)
assert.NoError(t, err)
@ -84,7 +87,9 @@ func TestSwitchEthereumChain(t *testing.T) {
}
params := make([]interface{}, 1)
params[0] = walletCommon.EthereumMainnet
params[0] = map[string]interface{}{
"chainId": "0x1",
}
request, err := ConstructRPCRequest("wallet_switchEthereumChain", params, &testDAppData)
assert.NoError(t, err)
@ -94,5 +99,6 @@ func TestSwitchEthereumChain(t *testing.T) {
response, err := cmd.Execute(request)
assert.NoError(t, err)
assert.Equal(t, walletCommon.ChainID(walletCommon.EthereumMainnet).String(), response)
chainId := fmt.Sprintf(`0x%s`, walletCommon.ChainID(walletCommon.EthereumMainnet).String())
assert.Equal(t, chainId, response)
}

View File

@ -83,9 +83,9 @@ func TestRequestAccountsSwitchChainAndSendTransactionFlow(t *testing.T) {
assert.Equal(t, expectedResponse, response)
// Request to switch ethereum chain
expectedChainId := 0x5
request = fmt.Sprintf("{\"method\": \"wallet_switchEthereumChain\", \"params\": [%d], \"url\": \"http://testDAppURL123\", \"name\": \"testDAppName\", \"iconUrl\": \"http://testDAppIconUrl\" }", expectedChainId)
expectedResponse = fmt.Sprintf(`%d`, expectedChainId)
expectedChainId := "0x5"
request = fmt.Sprintf("{\"method\": \"wallet_switchEthereumChain\", \"params\": [{\"chainId\": \"%s\"}], \"url\": \"http://testDAppURL123\", \"name\": \"testDAppName\", \"iconUrl\": \"http://testDAppIconUrl\" }", expectedChainId)
expectedResponse = expectedChainId
response, err = api.CallRPC(request)
assert.NoError(t, err)
assert.Equal(t, expectedResponse, response)