fix_: fix switchEthereum api when processing chain input string (#5589)
fix #5587
This commit is contained in:
parent
3792d37df4
commit
4ddf9f2727
|
@ -2,6 +2,8 @@ package chainutils
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"errors"
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"strconv"
|
||||||
|
|
||||||
"github.com/status-im/status-go/params"
|
"github.com/status-im/status-go/params"
|
||||||
)
|
)
|
||||||
|
@ -10,7 +12,10 @@ type NetworkManagerInterface interface {
|
||||||
GetActiveNetworks() ([]*params.Network, error)
|
GetActiveNetworks() ([]*params.Network, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
var ErrNoActiveNetworks = errors.New("no active networks available")
|
var (
|
||||||
|
ErrNoActiveNetworks = errors.New("no active networks available")
|
||||||
|
ErrUnsupportedNetwork = errors.New("unsupported network")
|
||||||
|
)
|
||||||
|
|
||||||
// GetSupportedChainIDs retrieves the chain IDs from the provided NetworkManager.
|
// GetSupportedChainIDs retrieves the chain IDs from the provided NetworkManager.
|
||||||
func GetSupportedChainIDs(networkManager NetworkManagerInterface) ([]uint64, error) {
|
func GetSupportedChainIDs(networkManager NetworkManagerInterface) ([]uint64, error) {
|
||||||
|
@ -39,3 +44,14 @@ func GetDefaultChainID(networkManager NetworkManagerInterface) (uint64, error) {
|
||||||
|
|
||||||
return chainIDs[0], nil
|
return chainIDs[0], nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func GetHexChainID(decimalStr string) (string, error) {
|
||||||
|
decimalValue, err := strconv.ParseInt(decimalStr, 10, 64)
|
||||||
|
if err != nil {
|
||||||
|
return "", ErrUnsupportedNetwork
|
||||||
|
}
|
||||||
|
|
||||||
|
hexStr := fmt.Sprintf(`0x%s`, strconv.FormatInt(decimalValue, 16))
|
||||||
|
|
||||||
|
return hexStr, nil
|
||||||
|
}
|
||||||
|
|
|
@ -30,8 +30,19 @@ func (c *ChainIDCommand) Execute(request RPCRequest) (string, error) {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
return strconv.FormatUint(defaultChainID, 16), nil
|
|
||||||
|
chainId, err := chainutils.GetHexChainID(strconv.FormatUint(defaultChainID, 16))
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
return chainId, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
return walletCommon.ChainID(dApp.ChainID).String(), nil
|
chainId, err := chainutils.GetHexChainID(walletCommon.ChainID(dApp.ChainID).String())
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
return chainId, nil
|
||||||
}
|
}
|
||||||
|
|
|
@ -2,6 +2,7 @@ package commands
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"database/sql"
|
"database/sql"
|
||||||
|
"fmt"
|
||||||
"strconv"
|
"strconv"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
@ -9,6 +10,7 @@ import (
|
||||||
|
|
||||||
"github.com/status-im/status-go/eth-node/types"
|
"github.com/status-im/status-go/eth-node/types"
|
||||||
"github.com/status-im/status-go/params"
|
"github.com/status-im/status-go/params"
|
||||||
|
"github.com/status-im/status-go/services/connector/chainutils"
|
||||||
walletCommon "github.com/status-im/status-go/services/wallet/common"
|
walletCommon "github.com/status-im/status-go/services/wallet/common"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -56,7 +58,9 @@ func TestGetDefaultChainIdForUnpermittedDApp(t *testing.T) {
|
||||||
|
|
||||||
result, err := cmd.Execute(request)
|
result, err := cmd.Execute(request)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
assert.Equal(t, result, strconv.FormatUint(walletCommon.EthereumMainnet, 16))
|
chainId, err := chainutils.GetHexChainID(strconv.FormatUint(walletCommon.EthereumMainnet, 16))
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, result, chainId)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestGetChainIdForPermittedDApp(t *testing.T) {
|
func TestGetChainIdForPermittedDApp(t *testing.T) {
|
||||||
|
@ -76,5 +80,6 @@ func TestGetChainIdForPermittedDApp(t *testing.T) {
|
||||||
|
|
||||||
response, err := cmd.Execute(request)
|
response, err := cmd.Execute(request)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
assert.Equal(t, walletCommon.ChainID(chainID).String(), response)
|
chainId := fmt.Sprintf(`0x%s`, strconv.FormatUint(chainID, 16))
|
||||||
|
assert.Equal(t, chainId, response)
|
||||||
}
|
}
|
||||||
|
|
|
@ -4,6 +4,7 @@ import (
|
||||||
"database/sql"
|
"database/sql"
|
||||||
"errors"
|
"errors"
|
||||||
"slices"
|
"slices"
|
||||||
|
"strconv"
|
||||||
|
|
||||||
"github.com/status-im/status-go/services/connector/chainutils"
|
"github.com/status-im/status-go/services/connector/chainutils"
|
||||||
persistence "github.com/status-im/status-go/services/connector/database"
|
persistence "github.com/status-im/status-go/services/connector/database"
|
||||||
|
@ -22,23 +23,29 @@ type SwitchEthereumChainCommand struct {
|
||||||
Db *sql.DB
|
Db *sql.DB
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func hexStringToUint64(s string) (uint64, error) {
|
||||||
|
if len(s) > 2 && s[:2] == "0x" {
|
||||||
|
value, err := strconv.ParseUint(s[2:], 16, 64)
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
return value, nil
|
||||||
|
}
|
||||||
|
return 0, ErrUnsupportedNetwork
|
||||||
|
}
|
||||||
|
|
||||||
func (r *RPCRequest) getChainID() (uint64, error) {
|
func (r *RPCRequest) getChainID() (uint64, error) {
|
||||||
if r.Params == nil || len(r.Params) == 0 {
|
if r.Params == nil || len(r.Params) == 0 {
|
||||||
return 0, ErrEmptyRPCParams
|
return 0, ErrEmptyRPCParams
|
||||||
}
|
}
|
||||||
|
|
||||||
switch v := r.Params[0].(type) {
|
chainIds := r.Params[0].(map[string]interface{})
|
||||||
case float64:
|
|
||||||
return uint64(v), nil
|
for _, chainId := range chainIds {
|
||||||
case int:
|
return hexStringToUint64(chainId.(string))
|
||||||
return uint64(v), nil
|
|
||||||
case int64:
|
|
||||||
return uint64(v), nil
|
|
||||||
case uint64:
|
|
||||||
return v, nil
|
|
||||||
default:
|
|
||||||
return 0, ErrNoChainIDParamsFound
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
return 0, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *SwitchEthereumChainCommand) getSupportedChainIDs() ([]uint64, error) {
|
func (c *SwitchEthereumChainCommand) getSupportedChainIDs() ([]uint64, error) {
|
||||||
|
@ -70,6 +77,10 @@ func (c *SwitchEthereumChainCommand) Execute(request RPCRequest) (string, error)
|
||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if dApp == nil {
|
||||||
|
return "", ErrDAppIsNotPermittedByUser
|
||||||
|
}
|
||||||
|
|
||||||
dApp.ChainID = requestedChainID
|
dApp.ChainID = requestedChainID
|
||||||
|
|
||||||
err = persistence.UpsertDApp(c.Db, dApp)
|
err = persistence.UpsertDApp(c.Db, dApp)
|
||||||
|
@ -77,5 +88,10 @@ func (c *SwitchEthereumChainCommand) Execute(request RPCRequest) (string, error)
|
||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
|
|
||||||
return walletCommon.ChainID(dApp.ChainID).String(), nil
|
chainId, err := chainutils.GetHexChainID(walletCommon.ChainID(dApp.ChainID).String())
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
return chainId, nil
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
package commands
|
package commands
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"fmt"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
|
@ -55,7 +56,9 @@ func TestFailToSwitchEthereumChainWithUnsupportedChainId(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
params := make([]interface{}, 1)
|
params := make([]interface{}, 1)
|
||||||
params[0] = walletCommon.BinanceTestChainID // some unrecoginzed chain id
|
params[0] = map[string]interface{}{
|
||||||
|
"chainId": "0x1a343",
|
||||||
|
} // some unrecoginzed chain id
|
||||||
|
|
||||||
request, err := ConstructRPCRequest("wallet_switchEthereumChain", params, &testDAppData)
|
request, err := ConstructRPCRequest("wallet_switchEthereumChain", params, &testDAppData)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
@ -84,7 +87,9 @@ func TestSwitchEthereumChain(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
params := make([]interface{}, 1)
|
params := make([]interface{}, 1)
|
||||||
params[0] = walletCommon.EthereumMainnet
|
params[0] = map[string]interface{}{
|
||||||
|
"chainId": "0x1",
|
||||||
|
}
|
||||||
|
|
||||||
request, err := ConstructRPCRequest("wallet_switchEthereumChain", params, &testDAppData)
|
request, err := ConstructRPCRequest("wallet_switchEthereumChain", params, &testDAppData)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
@ -94,5 +99,6 @@ func TestSwitchEthereumChain(t *testing.T) {
|
||||||
|
|
||||||
response, err := cmd.Execute(request)
|
response, err := cmd.Execute(request)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
assert.Equal(t, walletCommon.ChainID(walletCommon.EthereumMainnet).String(), response)
|
chainId := fmt.Sprintf(`0x%s`, walletCommon.ChainID(walletCommon.EthereumMainnet).String())
|
||||||
|
assert.Equal(t, chainId, response)
|
||||||
}
|
}
|
||||||
|
|
|
@ -83,9 +83,9 @@ func TestRequestAccountsSwitchChainAndSendTransactionFlow(t *testing.T) {
|
||||||
assert.Equal(t, expectedResponse, response)
|
assert.Equal(t, expectedResponse, response)
|
||||||
|
|
||||||
// Request to switch ethereum chain
|
// Request to switch ethereum chain
|
||||||
expectedChainId := 0x5
|
expectedChainId := "0x5"
|
||||||
request = fmt.Sprintf("{\"method\": \"wallet_switchEthereumChain\", \"params\": [%d], \"url\": \"http://testDAppURL123\", \"name\": \"testDAppName\", \"iconUrl\": \"http://testDAppIconUrl\" }", expectedChainId)
|
request = fmt.Sprintf("{\"method\": \"wallet_switchEthereumChain\", \"params\": [{\"chainId\": \"%s\"}], \"url\": \"http://testDAppURL123\", \"name\": \"testDAppName\", \"iconUrl\": \"http://testDAppIconUrl\" }", expectedChainId)
|
||||||
expectedResponse = fmt.Sprintf(`%d`, expectedChainId)
|
expectedResponse = expectedChainId
|
||||||
response, err = api.CallRPC(request)
|
response, err = api.CallRPC(request)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
assert.Equal(t, expectedResponse, response)
|
assert.Equal(t, expectedResponse, response)
|
||||||
|
|
Loading…
Reference in New Issue