fix_: return default chainID instead of throwing error for unregistered dApp (#5584)

fix #5583
This commit is contained in:
Godfrain Jacques 2024-07-26 09:18:01 -07:00 committed by GitHub
parent 5d113071db
commit f6d7d1429c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 83 additions and 24 deletions

View File

@ -28,7 +28,10 @@ func NewAPI(s *Service) *API {
})
// Active chain per dapp management
r.Register("eth_chainId", &commands.ChainIDCommand{Db: s.db})
r.Register("eth_chainId", &commands.ChainIDCommand{
Db: s.db,
NetworkManager: s.nm,
})
r.Register("wallet_switchEthereumChain", &commands.SwitchEthereumChainCommand{
Db: s.db,
NetworkManager: s.nm,

View File

@ -0,0 +1,41 @@
package chainutils
import (
"errors"
"github.com/status-im/status-go/params"
)
type NetworkManagerInterface interface {
GetActiveNetworks() ([]*params.Network, error)
}
var ErrNoActiveNetworks = errors.New("no active networks available")
// GetSupportedChainIDs retrieves the chain IDs from the provided NetworkManager.
func GetSupportedChainIDs(networkManager NetworkManagerInterface) ([]uint64, error) {
activeNetworks, err := networkManager.GetActiveNetworks()
if err != nil {
return nil, err
}
if len(activeNetworks) < 1 {
return nil, ErrNoActiveNetworks
}
chainIDs := make([]uint64, len(activeNetworks))
for i, network := range activeNetworks {
chainIDs[i] = network.ChainID
}
return chainIDs, nil
}
func GetDefaultChainID(networkManager NetworkManagerInterface) (uint64, error) {
chainIDs, err := GetSupportedChainIDs(networkManager)
if err != nil {
return 0, err
}
return chainIDs[0], nil
}

View File

@ -2,13 +2,16 @@ package commands
import (
"database/sql"
"strconv"
"github.com/status-im/status-go/services/connector/chainutils"
persistence "github.com/status-im/status-go/services/connector/database"
walletCommon "github.com/status-im/status-go/services/wallet/common"
)
type ChainIDCommand struct {
Db *sql.DB
NetworkManager NetworkManagerInterface
Db *sql.DB
}
func (c *ChainIDCommand) Execute(request RPCRequest) (string, error) {
@ -23,7 +26,11 @@ func (c *ChainIDCommand) Execute(request RPCRequest) (string, error) {
}
if dApp == nil {
return "", ErrDAppIsNotPermittedByUser
defaultChainID, err := chainutils.GetDefaultChainID(c.NetworkManager)
if err != nil {
return "", err
}
return strconv.FormatUint(defaultChainID, 16), nil
}
return walletCommon.ChainID(dApp.ChainID).String(), nil

View File

@ -1,19 +1,40 @@
package commands
import (
"database/sql"
"strconv"
"testing"
"github.com/stretchr/testify/assert"
"github.com/status-im/status-go/eth-node/types"
"github.com/status-im/status-go/params"
walletCommon "github.com/status-im/status-go/services/wallet/common"
)
func setupNetworks(db *sql.DB) *ChainIDCommand {
nm := NetworkManagerMock{}
nm.SetNetworks([]*params.Network{
{
ChainID: walletCommon.EthereumMainnet,
},
{
ChainID: walletCommon.EthereumGoerli,
},
})
cmd := &ChainIDCommand{
Db: db,
NetworkManager: &nm,
}
return cmd
}
func TestFailToGetChainIdWithMissingDAppFields(t *testing.T) {
db, close := SetupTestDB(t)
defer close()
cmd := &ChainIDCommand{Db: db}
cmd := setupNetworks(db)
// Missing DApp fields
request, err := ConstructRPCRequest("eth_chainId", []interface{}{}, nil)
@ -24,25 +45,25 @@ func TestFailToGetChainIdWithMissingDAppFields(t *testing.T) {
assert.Empty(t, result)
}
func TestFailToGetChainIdForUnpermittedDApp(t *testing.T) {
func TestGetDefaultChainIdForUnpermittedDApp(t *testing.T) {
db, close := SetupTestDB(t)
defer close()
cmd := &ChainIDCommand{Db: db}
cmd := setupNetworks(db)
request, err := ConstructRPCRequest("eth_chainId", []interface{}{}, &testDAppData)
assert.NoError(t, err)
result, err := cmd.Execute(request)
assert.Equal(t, ErrDAppIsNotPermittedByUser, err)
assert.Empty(t, result)
assert.NoError(t, err)
assert.Equal(t, result, strconv.FormatUint(walletCommon.EthereumMainnet, 16))
}
func TestGetChainIdForPermittedDApp(t *testing.T) {
db, close := SetupTestDB(t)
defer close()
cmd := &ChainIDCommand{Db: db}
cmd := setupNetworks(db)
sharedAccount := types.HexToAddress("0x6d0aa2a774b74bb1d36f97700315adf962c69fcg")
chainID := uint64(0x123)

View File

@ -5,6 +5,7 @@ import (
"errors"
"slices"
"github.com/status-im/status-go/services/connector/chainutils"
persistence "github.com/status-im/status-go/services/connector/database"
walletCommon "github.com/status-im/status-go/services/wallet/common"
)
@ -41,21 +42,7 @@ func (r *RPCRequest) getChainID() (uint64, error) {
}
func (c *SwitchEthereumChainCommand) getSupportedChainIDs() ([]uint64, error) {
activeNetworks, err := c.NetworkManager.GetActiveNetworks()
if err != nil {
return nil, err
}
if len(activeNetworks) < 1 {
return nil, ErrNoActiveNetworks
}
chainIDs := make([]uint64, len(activeNetworks))
for i, network := range activeNetworks {
chainIDs[i] = network.ChainID
}
return chainIDs, nil
return chainutils.GetSupportedChainIDs(c.NetworkManager)
}
func (c *SwitchEthereumChainCommand) Execute(request RPCRequest) (string, error) {