fix_: fix switchEthereum api when processing chain input string (#5589)

fix #5587
This commit is contained in:
Godfrain Jacques 2024-07-26 13:00:12 -07:00 committed by GitHub
parent 3792d37df4
commit 4ddf9f2727
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 77 additions and 23 deletions

View File

@ -2,6 +2,8 @@ package chainutils
import ( import (
"errors" "errors"
"fmt"
"strconv"
"github.com/status-im/status-go/params" "github.com/status-im/status-go/params"
) )
@ -10,7 +12,10 @@ type NetworkManagerInterface interface {
GetActiveNetworks() ([]*params.Network, error) GetActiveNetworks() ([]*params.Network, error)
} }
var ErrNoActiveNetworks = errors.New("no active networks available") var (
ErrNoActiveNetworks = errors.New("no active networks available")
ErrUnsupportedNetwork = errors.New("unsupported network")
)
// GetSupportedChainIDs retrieves the chain IDs from the provided NetworkManager. // GetSupportedChainIDs retrieves the chain IDs from the provided NetworkManager.
func GetSupportedChainIDs(networkManager NetworkManagerInterface) ([]uint64, error) { func GetSupportedChainIDs(networkManager NetworkManagerInterface) ([]uint64, error) {
@ -39,3 +44,14 @@ func GetDefaultChainID(networkManager NetworkManagerInterface) (uint64, error) {
return chainIDs[0], nil return chainIDs[0], nil
} }
func GetHexChainID(decimalStr string) (string, error) {
decimalValue, err := strconv.ParseInt(decimalStr, 10, 64)
if err != nil {
return "", ErrUnsupportedNetwork
}
hexStr := fmt.Sprintf(`0x%s`, strconv.FormatInt(decimalValue, 16))
return hexStr, nil
}

View File

@ -30,8 +30,19 @@ func (c *ChainIDCommand) Execute(request RPCRequest) (string, error) {
if err != nil { if err != nil {
return "", err return "", err
} }
return strconv.FormatUint(defaultChainID, 16), nil
chainId, err := chainutils.GetHexChainID(strconv.FormatUint(defaultChainID, 16))
if err != nil {
return "", err
}
return chainId, nil
} }
return walletCommon.ChainID(dApp.ChainID).String(), nil chainId, err := chainutils.GetHexChainID(walletCommon.ChainID(dApp.ChainID).String())
if err != nil {
return "", err
}
return chainId, nil
} }

View File

@ -2,6 +2,7 @@ package commands
import ( import (
"database/sql" "database/sql"
"fmt"
"strconv" "strconv"
"testing" "testing"
@ -9,6 +10,7 @@ import (
"github.com/status-im/status-go/eth-node/types" "github.com/status-im/status-go/eth-node/types"
"github.com/status-im/status-go/params" "github.com/status-im/status-go/params"
"github.com/status-im/status-go/services/connector/chainutils"
walletCommon "github.com/status-im/status-go/services/wallet/common" walletCommon "github.com/status-im/status-go/services/wallet/common"
) )
@ -56,7 +58,9 @@ func TestGetDefaultChainIdForUnpermittedDApp(t *testing.T) {
result, err := cmd.Execute(request) result, err := cmd.Execute(request)
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, result, strconv.FormatUint(walletCommon.EthereumMainnet, 16)) chainId, err := chainutils.GetHexChainID(strconv.FormatUint(walletCommon.EthereumMainnet, 16))
assert.NoError(t, err)
assert.Equal(t, result, chainId)
} }
func TestGetChainIdForPermittedDApp(t *testing.T) { func TestGetChainIdForPermittedDApp(t *testing.T) {
@ -76,5 +80,6 @@ func TestGetChainIdForPermittedDApp(t *testing.T) {
response, err := cmd.Execute(request) response, err := cmd.Execute(request)
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, walletCommon.ChainID(chainID).String(), response) chainId := fmt.Sprintf(`0x%s`, strconv.FormatUint(chainID, 16))
assert.Equal(t, chainId, response)
} }

View File

@ -4,6 +4,7 @@ import (
"database/sql" "database/sql"
"errors" "errors"
"slices" "slices"
"strconv"
"github.com/status-im/status-go/services/connector/chainutils" "github.com/status-im/status-go/services/connector/chainutils"
persistence "github.com/status-im/status-go/services/connector/database" persistence "github.com/status-im/status-go/services/connector/database"
@ -22,23 +23,29 @@ type SwitchEthereumChainCommand struct {
Db *sql.DB Db *sql.DB
} }
func hexStringToUint64(s string) (uint64, error) {
if len(s) > 2 && s[:2] == "0x" {
value, err := strconv.ParseUint(s[2:], 16, 64)
if err != nil {
return 0, err
}
return value, nil
}
return 0, ErrUnsupportedNetwork
}
func (r *RPCRequest) getChainID() (uint64, error) { func (r *RPCRequest) getChainID() (uint64, error) {
if r.Params == nil || len(r.Params) == 0 { if r.Params == nil || len(r.Params) == 0 {
return 0, ErrEmptyRPCParams return 0, ErrEmptyRPCParams
} }
switch v := r.Params[0].(type) { chainIds := r.Params[0].(map[string]interface{})
case float64:
return uint64(v), nil for _, chainId := range chainIds {
case int: return hexStringToUint64(chainId.(string))
return uint64(v), nil
case int64:
return uint64(v), nil
case uint64:
return v, nil
default:
return 0, ErrNoChainIDParamsFound
} }
return 0, nil
} }
func (c *SwitchEthereumChainCommand) getSupportedChainIDs() ([]uint64, error) { func (c *SwitchEthereumChainCommand) getSupportedChainIDs() ([]uint64, error) {
@ -70,6 +77,10 @@ func (c *SwitchEthereumChainCommand) Execute(request RPCRequest) (string, error)
return "", err return "", err
} }
if dApp == nil {
return "", ErrDAppIsNotPermittedByUser
}
dApp.ChainID = requestedChainID dApp.ChainID = requestedChainID
err = persistence.UpsertDApp(c.Db, dApp) err = persistence.UpsertDApp(c.Db, dApp)
@ -77,5 +88,10 @@ func (c *SwitchEthereumChainCommand) Execute(request RPCRequest) (string, error)
return "", err return "", err
} }
return walletCommon.ChainID(dApp.ChainID).String(), nil chainId, err := chainutils.GetHexChainID(walletCommon.ChainID(dApp.ChainID).String())
if err != nil {
return "", err
}
return chainId, nil
} }

View File

@ -1,6 +1,7 @@
package commands package commands
import ( import (
"fmt"
"testing" "testing"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
@ -55,7 +56,9 @@ func TestFailToSwitchEthereumChainWithUnsupportedChainId(t *testing.T) {
} }
params := make([]interface{}, 1) params := make([]interface{}, 1)
params[0] = walletCommon.BinanceTestChainID // some unrecoginzed chain id params[0] = map[string]interface{}{
"chainId": "0x1a343",
} // some unrecoginzed chain id
request, err := ConstructRPCRequest("wallet_switchEthereumChain", params, &testDAppData) request, err := ConstructRPCRequest("wallet_switchEthereumChain", params, &testDAppData)
assert.NoError(t, err) assert.NoError(t, err)
@ -84,7 +87,9 @@ func TestSwitchEthereumChain(t *testing.T) {
} }
params := make([]interface{}, 1) params := make([]interface{}, 1)
params[0] = walletCommon.EthereumMainnet params[0] = map[string]interface{}{
"chainId": "0x1",
}
request, err := ConstructRPCRequest("wallet_switchEthereumChain", params, &testDAppData) request, err := ConstructRPCRequest("wallet_switchEthereumChain", params, &testDAppData)
assert.NoError(t, err) assert.NoError(t, err)
@ -94,5 +99,6 @@ func TestSwitchEthereumChain(t *testing.T) {
response, err := cmd.Execute(request) response, err := cmd.Execute(request)
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, walletCommon.ChainID(walletCommon.EthereumMainnet).String(), response) chainId := fmt.Sprintf(`0x%s`, walletCommon.ChainID(walletCommon.EthereumMainnet).String())
assert.Equal(t, chainId, response)
} }

View File

@ -83,9 +83,9 @@ func TestRequestAccountsSwitchChainAndSendTransactionFlow(t *testing.T) {
assert.Equal(t, expectedResponse, response) assert.Equal(t, expectedResponse, response)
// Request to switch ethereum chain // Request to switch ethereum chain
expectedChainId := 0x5 expectedChainId := "0x5"
request = fmt.Sprintf("{\"method\": \"wallet_switchEthereumChain\", \"params\": [%d], \"url\": \"http://testDAppURL123\", \"name\": \"testDAppName\", \"iconUrl\": \"http://testDAppIconUrl\" }", expectedChainId) request = fmt.Sprintf("{\"method\": \"wallet_switchEthereumChain\", \"params\": [{\"chainId\": \"%s\"}], \"url\": \"http://testDAppURL123\", \"name\": \"testDAppName\", \"iconUrl\": \"http://testDAppIconUrl\" }", expectedChainId)
expectedResponse = fmt.Sprintf(`%d`, expectedChainId) expectedResponse = expectedChainId
response, err = api.CallRPC(request) response, err = api.CallRPC(request)
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, expectedResponse, response) assert.Equal(t, expectedResponse, response)