fix_: return default chainID instead of throwing error for unregistered dApp (#5584) (#5598)

fix #5583
This commit is contained in:
Godfrain Jacques 2024-07-26 13:19:03 -07:00 committed by GitHub
parent 2013b65c58
commit 13b78d2679
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 83 additions and 24 deletions

View File

@ -28,7 +28,10 @@ func NewAPI(s *Service) *API {
}) })
// Active chain per dapp management // Active chain per dapp management
r.Register("eth_chainId", &commands.ChainIDCommand{Db: s.db}) r.Register("eth_chainId", &commands.ChainIDCommand{
Db: s.db,
NetworkManager: s.nm,
})
r.Register("wallet_switchEthereumChain", &commands.SwitchEthereumChainCommand{ r.Register("wallet_switchEthereumChain", &commands.SwitchEthereumChainCommand{
Db: s.db, Db: s.db,
NetworkManager: s.nm, NetworkManager: s.nm,

View File

@ -0,0 +1,41 @@
package chainutils
import (
"errors"
"github.com/status-im/status-go/params"
)
type NetworkManagerInterface interface {
GetActiveNetworks() ([]*params.Network, error)
}
var ErrNoActiveNetworks = errors.New("no active networks available")
// GetSupportedChainIDs retrieves the chain IDs from the provided NetworkManager.
func GetSupportedChainIDs(networkManager NetworkManagerInterface) ([]uint64, error) {
activeNetworks, err := networkManager.GetActiveNetworks()
if err != nil {
return nil, err
}
if len(activeNetworks) < 1 {
return nil, ErrNoActiveNetworks
}
chainIDs := make([]uint64, len(activeNetworks))
for i, network := range activeNetworks {
chainIDs[i] = network.ChainID
}
return chainIDs, nil
}
func GetDefaultChainID(networkManager NetworkManagerInterface) (uint64, error) {
chainIDs, err := GetSupportedChainIDs(networkManager)
if err != nil {
return 0, err
}
return chainIDs[0], nil
}

View File

@ -2,13 +2,16 @@ package commands
import ( import (
"database/sql" "database/sql"
"strconv"
"github.com/status-im/status-go/services/connector/chainutils"
persistence "github.com/status-im/status-go/services/connector/database" persistence "github.com/status-im/status-go/services/connector/database"
walletCommon "github.com/status-im/status-go/services/wallet/common" walletCommon "github.com/status-im/status-go/services/wallet/common"
) )
type ChainIDCommand struct { type ChainIDCommand struct {
Db *sql.DB NetworkManager NetworkManagerInterface
Db *sql.DB
} }
func (c *ChainIDCommand) Execute(request RPCRequest) (string, error) { func (c *ChainIDCommand) Execute(request RPCRequest) (string, error) {
@ -23,7 +26,11 @@ func (c *ChainIDCommand) Execute(request RPCRequest) (string, error) {
} }
if dApp == nil { if dApp == nil {
return "", ErrDAppIsNotPermittedByUser defaultChainID, err := chainutils.GetDefaultChainID(c.NetworkManager)
if err != nil {
return "", err
}
return strconv.FormatUint(defaultChainID, 16), nil
} }
return walletCommon.ChainID(dApp.ChainID).String(), nil return walletCommon.ChainID(dApp.ChainID).String(), nil

View File

@ -1,19 +1,40 @@
package commands package commands
import ( import (
"database/sql"
"strconv"
"testing" "testing"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/status-im/status-go/eth-node/types" "github.com/status-im/status-go/eth-node/types"
"github.com/status-im/status-go/params"
walletCommon "github.com/status-im/status-go/services/wallet/common" walletCommon "github.com/status-im/status-go/services/wallet/common"
) )
func setupNetworks(db *sql.DB) *ChainIDCommand {
nm := NetworkManagerMock{}
nm.SetNetworks([]*params.Network{
{
ChainID: walletCommon.EthereumMainnet,
},
{
ChainID: walletCommon.EthereumGoerli,
},
})
cmd := &ChainIDCommand{
Db: db,
NetworkManager: &nm,
}
return cmd
}
func TestFailToGetChainIdWithMissingDAppFields(t *testing.T) { func TestFailToGetChainIdWithMissingDAppFields(t *testing.T) {
db, close := SetupTestDB(t) db, close := SetupTestDB(t)
defer close() defer close()
cmd := &ChainIDCommand{Db: db} cmd := setupNetworks(db)
// Missing DApp fields // Missing DApp fields
request, err := ConstructRPCRequest("eth_chainId", []interface{}{}, nil) request, err := ConstructRPCRequest("eth_chainId", []interface{}{}, nil)
@ -24,25 +45,25 @@ func TestFailToGetChainIdWithMissingDAppFields(t *testing.T) {
assert.Empty(t, result) assert.Empty(t, result)
} }
func TestFailToGetChainIdForUnpermittedDApp(t *testing.T) { func TestGetDefaultChainIdForUnpermittedDApp(t *testing.T) {
db, close := SetupTestDB(t) db, close := SetupTestDB(t)
defer close() defer close()
cmd := &ChainIDCommand{Db: db} cmd := setupNetworks(db)
request, err := ConstructRPCRequest("eth_chainId", []interface{}{}, &testDAppData) request, err := ConstructRPCRequest("eth_chainId", []interface{}{}, &testDAppData)
assert.NoError(t, err) assert.NoError(t, err)
result, err := cmd.Execute(request) result, err := cmd.Execute(request)
assert.Equal(t, ErrDAppIsNotPermittedByUser, err) assert.NoError(t, err)
assert.Empty(t, result) assert.Equal(t, result, strconv.FormatUint(walletCommon.EthereumMainnet, 16))
} }
func TestGetChainIdForPermittedDApp(t *testing.T) { func TestGetChainIdForPermittedDApp(t *testing.T) {
db, close := SetupTestDB(t) db, close := SetupTestDB(t)
defer close() defer close()
cmd := &ChainIDCommand{Db: db} cmd := setupNetworks(db)
sharedAccount := types.HexToAddress("0x6d0aa2a774b74bb1d36f97700315adf962c69fcg") sharedAccount := types.HexToAddress("0x6d0aa2a774b74bb1d36f97700315adf962c69fcg")
chainID := uint64(0x123) chainID := uint64(0x123)

View File

@ -5,6 +5,7 @@ import (
"errors" "errors"
"slices" "slices"
"github.com/status-im/status-go/services/connector/chainutils"
persistence "github.com/status-im/status-go/services/connector/database" persistence "github.com/status-im/status-go/services/connector/database"
walletCommon "github.com/status-im/status-go/services/wallet/common" walletCommon "github.com/status-im/status-go/services/wallet/common"
) )
@ -41,21 +42,7 @@ func (r *RPCRequest) getChainID() (uint64, error) {
} }
func (c *SwitchEthereumChainCommand) getSupportedChainIDs() ([]uint64, error) { func (c *SwitchEthereumChainCommand) getSupportedChainIDs() ([]uint64, error) {
activeNetworks, err := c.NetworkManager.GetActiveNetworks() return chainutils.GetSupportedChainIDs(c.NetworkManager)
if err != nil {
return nil, err
}
if len(activeNetworks) < 1 {
return nil, ErrNoActiveNetworks
}
chainIDs := make([]uint64, len(activeNetworks))
for i, network := range activeNetworks {
chainIDs[i] = network.ChainID
}
return chainIDs, nil
} }
func (c *SwitchEthereumChainCommand) Execute(request RPCRequest) (string, error) { func (c *SwitchEthereumChainCommand) Execute(request RPCRequest) (string, error) {