diff --git a/services/connector/api.go b/services/connector/api.go index aa6478e06..2fee8aaca 100644 --- a/services/connector/api.go +++ b/services/connector/api.go @@ -28,7 +28,10 @@ func NewAPI(s *Service) *API { }) // Active chain per dapp management - r.Register("eth_chainId", &commands.ChainIDCommand{Db: s.db}) + r.Register("eth_chainId", &commands.ChainIDCommand{ + Db: s.db, + NetworkManager: s.nm, + }) r.Register("wallet_switchEthereumChain", &commands.SwitchEthereumChainCommand{ Db: s.db, NetworkManager: s.nm, diff --git a/services/connector/chainutils/utils.go b/services/connector/chainutils/utils.go new file mode 100644 index 000000000..a1dfc090e --- /dev/null +++ b/services/connector/chainutils/utils.go @@ -0,0 +1,41 @@ +package chainutils + +import ( + "errors" + + "github.com/status-im/status-go/params" +) + +type NetworkManagerInterface interface { + GetActiveNetworks() ([]*params.Network, error) +} + +var ErrNoActiveNetworks = errors.New("no active networks available") + +// GetSupportedChainIDs retrieves the chain IDs from the provided NetworkManager. +func GetSupportedChainIDs(networkManager NetworkManagerInterface) ([]uint64, error) { + activeNetworks, err := networkManager.GetActiveNetworks() + if err != nil { + return nil, err + } + + if len(activeNetworks) < 1 { + return nil, ErrNoActiveNetworks + } + + chainIDs := make([]uint64, len(activeNetworks)) + for i, network := range activeNetworks { + chainIDs[i] = network.ChainID + } + + return chainIDs, nil +} + +func GetDefaultChainID(networkManager NetworkManagerInterface) (uint64, error) { + chainIDs, err := GetSupportedChainIDs(networkManager) + if err != nil { + return 0, err + } + + return chainIDs[0], nil +} diff --git a/services/connector/commands/chain_id.go b/services/connector/commands/chain_id.go index ed5bdbec6..6dad4bcb7 100644 --- a/services/connector/commands/chain_id.go +++ b/services/connector/commands/chain_id.go @@ -2,13 +2,16 @@ package commands import ( "database/sql" + "strconv" + "github.com/status-im/status-go/services/connector/chainutils" persistence "github.com/status-im/status-go/services/connector/database" walletCommon "github.com/status-im/status-go/services/wallet/common" ) type ChainIDCommand struct { - Db *sql.DB + NetworkManager NetworkManagerInterface + Db *sql.DB } func (c *ChainIDCommand) Execute(request RPCRequest) (string, error) { @@ -23,7 +26,11 @@ func (c *ChainIDCommand) Execute(request RPCRequest) (string, error) { } if dApp == nil { - return "", ErrDAppIsNotPermittedByUser + defaultChainID, err := chainutils.GetDefaultChainID(c.NetworkManager) + if err != nil { + return "", err + } + return strconv.FormatUint(defaultChainID, 16), nil } return walletCommon.ChainID(dApp.ChainID).String(), nil diff --git a/services/connector/commands/chain_id_test.go b/services/connector/commands/chain_id_test.go index 77061f39f..f10030608 100644 --- a/services/connector/commands/chain_id_test.go +++ b/services/connector/commands/chain_id_test.go @@ -1,19 +1,40 @@ package commands import ( + "database/sql" + "strconv" "testing" "github.com/stretchr/testify/assert" "github.com/status-im/status-go/eth-node/types" + "github.com/status-im/status-go/params" walletCommon "github.com/status-im/status-go/services/wallet/common" ) +func setupNetworks(db *sql.DB) *ChainIDCommand { + nm := NetworkManagerMock{} + nm.SetNetworks([]*params.Network{ + { + ChainID: walletCommon.EthereumMainnet, + }, + { + ChainID: walletCommon.EthereumGoerli, + }, + }) + cmd := &ChainIDCommand{ + Db: db, + NetworkManager: &nm, + } + + return cmd +} + func TestFailToGetChainIdWithMissingDAppFields(t *testing.T) { db, close := SetupTestDB(t) defer close() - cmd := &ChainIDCommand{Db: db} + cmd := setupNetworks(db) // Missing DApp fields request, err := ConstructRPCRequest("eth_chainId", []interface{}{}, nil) @@ -24,25 +45,25 @@ func TestFailToGetChainIdWithMissingDAppFields(t *testing.T) { assert.Empty(t, result) } -func TestFailToGetChainIdForUnpermittedDApp(t *testing.T) { +func TestGetDefaultChainIdForUnpermittedDApp(t *testing.T) { db, close := SetupTestDB(t) defer close() - cmd := &ChainIDCommand{Db: db} + cmd := setupNetworks(db) request, err := ConstructRPCRequest("eth_chainId", []interface{}{}, &testDAppData) assert.NoError(t, err) result, err := cmd.Execute(request) - assert.Equal(t, ErrDAppIsNotPermittedByUser, err) - assert.Empty(t, result) + assert.NoError(t, err) + assert.Equal(t, result, strconv.FormatUint(walletCommon.EthereumMainnet, 16)) } func TestGetChainIdForPermittedDApp(t *testing.T) { db, close := SetupTestDB(t) defer close() - cmd := &ChainIDCommand{Db: db} + cmd := setupNetworks(db) sharedAccount := types.HexToAddress("0x6d0aa2a774b74bb1d36f97700315adf962c69fcg") chainID := uint64(0x123) diff --git a/services/connector/commands/switch_ethereum_chain.go b/services/connector/commands/switch_ethereum_chain.go index ce0650650..0bfb89cf8 100644 --- a/services/connector/commands/switch_ethereum_chain.go +++ b/services/connector/commands/switch_ethereum_chain.go @@ -5,6 +5,7 @@ import ( "errors" "slices" + "github.com/status-im/status-go/services/connector/chainutils" persistence "github.com/status-im/status-go/services/connector/database" walletCommon "github.com/status-im/status-go/services/wallet/common" ) @@ -41,21 +42,7 @@ func (r *RPCRequest) getChainID() (uint64, error) { } func (c *SwitchEthereumChainCommand) getSupportedChainIDs() ([]uint64, error) { - activeNetworks, err := c.NetworkManager.GetActiveNetworks() - if err != nil { - return nil, err - } - - if len(activeNetworks) < 1 { - return nil, ErrNoActiveNetworks - } - - chainIDs := make([]uint64, len(activeNetworks)) - for i, network := range activeNetworks { - chainIDs[i] = network.ChainID - } - - return chainIDs, nil + return chainutils.GetSupportedChainIDs(c.NetworkManager) } func (c *SwitchEthereumChainCommand) Execute(request RPCRequest) (string, error) {