fix_: fix switchEthereum api when processing chain input string (#5603)
fix #5587
This commit is contained in:
parent
3ca29b87c3
commit
07614f6640
|
@ -2,6 +2,8 @@ package chainutils
|
|||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"strconv"
|
||||
|
||||
"github.com/status-im/status-go/params"
|
||||
)
|
||||
|
@ -10,7 +12,10 @@ type NetworkManagerInterface interface {
|
|||
GetActiveNetworks() ([]*params.Network, error)
|
||||
}
|
||||
|
||||
var ErrNoActiveNetworks = errors.New("no active networks available")
|
||||
var (
|
||||
ErrNoActiveNetworks = errors.New("no active networks available")
|
||||
ErrUnsupportedNetwork = errors.New("unsupported network")
|
||||
)
|
||||
|
||||
// GetSupportedChainIDs retrieves the chain IDs from the provided NetworkManager.
|
||||
func GetSupportedChainIDs(networkManager NetworkManagerInterface) ([]uint64, error) {
|
||||
|
@ -39,3 +44,14 @@ func GetDefaultChainID(networkManager NetworkManagerInterface) (uint64, error) {
|
|||
|
||||
return chainIDs[0], nil
|
||||
}
|
||||
|
||||
func GetHexChainID(decimalStr string) (string, error) {
|
||||
decimalValue, err := strconv.ParseInt(decimalStr, 10, 64)
|
||||
if err != nil {
|
||||
return "", ErrUnsupportedNetwork
|
||||
}
|
||||
|
||||
hexStr := fmt.Sprintf(`0x%s`, strconv.FormatInt(decimalValue, 16))
|
||||
|
||||
return hexStr, nil
|
||||
}
|
||||
|
|
|
@ -30,8 +30,19 @@ func (c *ChainIDCommand) Execute(request RPCRequest) (string, error) {
|
|||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return strconv.FormatUint(defaultChainID, 16), nil
|
||||
|
||||
chainId, err := chainutils.GetHexChainID(strconv.FormatUint(defaultChainID, 16))
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return walletCommon.ChainID(dApp.ChainID).String(), nil
|
||||
return chainId, nil
|
||||
}
|
||||
|
||||
chainId, err := chainutils.GetHexChainID(walletCommon.ChainID(dApp.ChainID).String())
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return chainId, nil
|
||||
}
|
||||
|
|
|
@ -2,6 +2,7 @@ package commands
|
|||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"testing"
|
||||
|
||||
|
@ -9,6 +10,7 @@ import (
|
|||
|
||||
"github.com/status-im/status-go/eth-node/types"
|
||||
"github.com/status-im/status-go/params"
|
||||
"github.com/status-im/status-go/services/connector/chainutils"
|
||||
walletCommon "github.com/status-im/status-go/services/wallet/common"
|
||||
)
|
||||
|
||||
|
@ -56,7 +58,9 @@ func TestGetDefaultChainIdForUnpermittedDApp(t *testing.T) {
|
|||
|
||||
result, err := cmd.Execute(request)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, result, strconv.FormatUint(walletCommon.EthereumMainnet, 16))
|
||||
chainId, err := chainutils.GetHexChainID(strconv.FormatUint(walletCommon.EthereumMainnet, 16))
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, result, chainId)
|
||||
}
|
||||
|
||||
func TestGetChainIdForPermittedDApp(t *testing.T) {
|
||||
|
@ -76,5 +80,6 @@ func TestGetChainIdForPermittedDApp(t *testing.T) {
|
|||
|
||||
response, err := cmd.Execute(request)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, walletCommon.ChainID(chainID).String(), response)
|
||||
chainId := fmt.Sprintf(`0x%s`, strconv.FormatUint(chainID, 16))
|
||||
assert.Equal(t, chainId, response)
|
||||
}
|
||||
|
|
|
@ -4,6 +4,7 @@ import (
|
|||
"database/sql"
|
||||
"errors"
|
||||
"slices"
|
||||
"strconv"
|
||||
|
||||
"github.com/status-im/status-go/services/connector/chainutils"
|
||||
persistence "github.com/status-im/status-go/services/connector/database"
|
||||
|
@ -22,23 +23,29 @@ type SwitchEthereumChainCommand struct {
|
|||
Db *sql.DB
|
||||
}
|
||||
|
||||
func hexStringToUint64(s string) (uint64, error) {
|
||||
if len(s) > 2 && s[:2] == "0x" {
|
||||
value, err := strconv.ParseUint(s[2:], 16, 64)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return value, nil
|
||||
}
|
||||
return 0, ErrUnsupportedNetwork
|
||||
}
|
||||
|
||||
func (r *RPCRequest) getChainID() (uint64, error) {
|
||||
if r.Params == nil || len(r.Params) == 0 {
|
||||
return 0, ErrEmptyRPCParams
|
||||
}
|
||||
|
||||
switch v := r.Params[0].(type) {
|
||||
case float64:
|
||||
return uint64(v), nil
|
||||
case int:
|
||||
return uint64(v), nil
|
||||
case int64:
|
||||
return uint64(v), nil
|
||||
case uint64:
|
||||
return v, nil
|
||||
default:
|
||||
return 0, ErrNoChainIDParamsFound
|
||||
chainIds := r.Params[0].(map[string]interface{})
|
||||
|
||||
for _, chainId := range chainIds {
|
||||
return hexStringToUint64(chainId.(string))
|
||||
}
|
||||
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
func (c *SwitchEthereumChainCommand) getSupportedChainIDs() ([]uint64, error) {
|
||||
|
@ -70,6 +77,10 @@ func (c *SwitchEthereumChainCommand) Execute(request RPCRequest) (string, error)
|
|||
return "", err
|
||||
}
|
||||
|
||||
if dApp == nil {
|
||||
return "", ErrDAppIsNotPermittedByUser
|
||||
}
|
||||
|
||||
dApp.ChainID = requestedChainID
|
||||
|
||||
err = persistence.UpsertDApp(c.Db, dApp)
|
||||
|
@ -77,5 +88,10 @@ func (c *SwitchEthereumChainCommand) Execute(request RPCRequest) (string, error)
|
|||
return "", err
|
||||
}
|
||||
|
||||
return walletCommon.ChainID(dApp.ChainID).String(), nil
|
||||
chainId, err := chainutils.GetHexChainID(walletCommon.ChainID(dApp.ChainID).String())
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return chainId, nil
|
||||
}
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
package commands
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
@ -55,7 +56,9 @@ func TestFailToSwitchEthereumChainWithUnsupportedChainId(t *testing.T) {
|
|||
}
|
||||
|
||||
params := make([]interface{}, 1)
|
||||
params[0] = walletCommon.BinanceTestChainID // some unrecoginzed chain id
|
||||
params[0] = map[string]interface{}{
|
||||
"chainId": "0x1a343",
|
||||
} // some unrecoginzed chain id
|
||||
|
||||
request, err := ConstructRPCRequest("wallet_switchEthereumChain", params, &testDAppData)
|
||||
assert.NoError(t, err)
|
||||
|
@ -84,7 +87,9 @@ func TestSwitchEthereumChain(t *testing.T) {
|
|||
}
|
||||
|
||||
params := make([]interface{}, 1)
|
||||
params[0] = walletCommon.EthereumMainnet
|
||||
params[0] = map[string]interface{}{
|
||||
"chainId": "0x1",
|
||||
}
|
||||
|
||||
request, err := ConstructRPCRequest("wallet_switchEthereumChain", params, &testDAppData)
|
||||
assert.NoError(t, err)
|
||||
|
@ -94,5 +99,6 @@ func TestSwitchEthereumChain(t *testing.T) {
|
|||
|
||||
response, err := cmd.Execute(request)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, walletCommon.ChainID(walletCommon.EthereumMainnet).String(), response)
|
||||
chainId := fmt.Sprintf(`0x%s`, walletCommon.ChainID(walletCommon.EthereumMainnet).String())
|
||||
assert.Equal(t, chainId, response)
|
||||
}
|
||||
|
|
|
@ -83,9 +83,9 @@ func TestRequestAccountsSwitchChainAndSendTransactionFlow(t *testing.T) {
|
|||
assert.Equal(t, expectedResponse, response)
|
||||
|
||||
// Request to switch ethereum chain
|
||||
expectedChainId := 0x5
|
||||
request = fmt.Sprintf("{\"method\": \"wallet_switchEthereumChain\", \"params\": [%d], \"url\": \"http://testDAppURL123\", \"name\": \"testDAppName\", \"iconUrl\": \"http://testDAppIconUrl\" }", expectedChainId)
|
||||
expectedResponse = fmt.Sprintf(`%d`, expectedChainId)
|
||||
expectedChainId := "0x5"
|
||||
request = fmt.Sprintf("{\"method\": \"wallet_switchEthereumChain\", \"params\": [{\"chainId\": \"%s\"}], \"url\": \"http://testDAppURL123\", \"name\": \"testDAppName\", \"iconUrl\": \"http://testDAppIconUrl\" }", expectedChainId)
|
||||
expectedResponse = expectedChainId
|
||||
response, err = api.CallRPC(request)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, expectedResponse, response)
|
||||
|
|
Loading…
Reference in New Issue