From 20b0b382be08366e16356bc006a8118894496f28 Mon Sep 17 00:00:00 2001 From: Andrey Bocharnikov Date: Fri, 10 Jan 2025 19:07:23 +0400 Subject: [PATCH] feat(config)_: Do not store embedded RPC provider credentials in the DB * +tests --- rpc/network/network.go | 126 ++++++++++++++++++++++----------- rpc/network/network_test.go | 124 +++++++++++++++++++++++++------- services/wallet/token/token.go | 2 +- 3 files changed, 181 insertions(+), 71 deletions(-) diff --git a/rpc/network/network.go b/rpc/network/network.go index 97b756c43..4fffa7ec1 100644 --- a/rpc/network/network.go +++ b/rpc/network/network.go @@ -6,8 +6,13 @@ import ( "github.com/status-im/status-go/multiaccounts/accounts" + "go.uber.org/zap" + + "github.com/status-im/status-go/logutils" + "github.com/status-im/status-go/multiaccounts/accounts" "github.com/status-im/status-go/params" "github.com/status-im/status-go/params/networkhelper" + persistence "github.com/status-im/status-go/rpc/network/db" ) @@ -27,7 +32,7 @@ type ManagerInterface interface { GetAll() ([]*params.Network, error) GetActiveNetworks() ([]*params.Network, error) GetCombinedNetworks() ([]*CombinedNetwork, error) - GetConfiguredNetworks() []params.Network + GetEmbeddedNetworks() []params.Network // Networks that are embedded in the app binary code GetTestNetworksEnabled() (bool, error) SetUserRpcProviders(chainID uint64, providers []params.RpcProvider) error @@ -37,7 +42,9 @@ type Manager struct { db *sql.DB accountsDB *accounts.Database networkPersistence persistence.NetworksPersistenceInterface - configuredNetworks []params.Network + embeddedNetworks []params.Network + + logger *zap.Logger } // NewManager creates a new instance of Manager. @@ -46,19 +53,30 @@ func NewManager(db *sql.DB) *Manager { if err != nil { return nil } + + logger := logutils.ZapLogger().Named("NetworkManager") + return &Manager{ db: db, accountsDB: accountsDB, networkPersistence: persistence.NewNetworksPersistence(db), + logger: logger, } } -// Init initializes the networks, merging with existing ones and wrapping the operation in a transaction. -func (nm *Manager) InitEmbeddedNetworks(networks []params.Network) error { - if networks == nil { +// Init initializes the nets, merges them with existing ones, and wraps the operation in a transaction. +// We should store the following information in the DB: +// - User's RPC providers +// - Enabled state of the network +// Embedded RPC providers should only be stored in memory +func (nm *Manager) InitEmbeddedNetworks(embeddedNetworks []params.Network) error { + if embeddedNetworks == nil { return nil } + // Update embedded networks + nm.embeddedNetworks = embeddedNetworks + // Begin a transaction return persistence.ExecuteWithinTransaction(nm.db, func(tx *sql.Tx) error { // Create temporary persistence instances with the transaction @@ -70,39 +88,67 @@ func (nm *Manager) InitEmbeddedNetworks(networks []params.Network) error { } // Create a map for quick access to current networks - currentNetworkMap := make(map[uint64]params.Network) + currentNetworskMap := make(map[uint64]params.Network) for _, currentNetwork := range currentNetworks { - currentNetworkMap[currentNetwork.ChainID] = *currentNetwork + currentNetworskMap[currentNetwork.ChainID] = *currentNetwork } - // Process new networks + // Keep user's rpc providers and enabled state var updatedNetworks []params.Network - for _, newNetwork := range networks { - if existingNetwork, exists := currentNetworkMap[newNetwork.ChainID]; exists { - // If network already exists, merge providers - newNetwork.RpcProviders = networkhelper.ReplaceEmbeddedProviders(existingNetwork.RpcProviders, newNetwork.RpcProviders) + for _, newNetwork := range embeddedNetworks { + if existingNetwork, exists := currentNetworskMap[newNetwork.ChainID]; exists { + newNetwork.RpcProviders = networkhelper.GetUserProviders(existingNetwork.RpcProviders) + newNetwork.Enabled = existingNetwork.Enabled + } else { + newNetwork.RpcProviders = networkhelper.GetUserProviders(newNetwork.RpcProviders) } updatedNetworks = append(updatedNetworks, newNetwork) } - // Use SetNetworks to replace all networks in the database + // Use SetNetworks to replace all networks in the database without embedded RPC providers err = txNetworksPersistence.SetNetworks(updatedNetworks) if err != nil { return fmt.Errorf("error setting networks: %w", err) } - // Update configured networks - nm.configuredNetworks = networks - return nil }) } +// GetEmbeddedProviders returns embedded providers for a given chainID. +func (nm *Manager) getEmbeddedProviders(chainID uint64) []params.RpcProvider { + for _, network := range nm.embeddedNetworks { + if network.ChainID == chainID { + return networkhelper.GetEmbeddedProviders(network.RpcProviders) + } + } + return nil +} + +// setEmbeddedProviders adds embedded providers to a network. +func (nm *Manager) setNetworkEmbeddedProviders(network *params.Network) { + network.RpcProviders = networkhelper.ReplaceEmbeddedProviders( + network.RpcProviders, nm.getEmbeddedProviders(network.ChainID)) +} + +func (nm *Manager) setEmbeddedProviders(networks []*params.Network) { + for _, network := range networks { + nm.setNetworkEmbeddedProviders(network) + } +} + +// networkWithoutEmbeddedProviders returns a copy of the given network without embedded RPC providers. +func (nm *Manager) networkWithoutEmbeddedProviders(network *params.Network) *params.Network { + networkCopy := networkhelper.DeepCopyNetwork(*network) + networkCopy.RpcProviders = networkhelper.GetUserProviders(network.RpcProviders) + return &networkCopy +} + // Upsert adds or updates a network, synchronizing RPC providers, wrapped in a transaction. func (nm *Manager) Upsert(network *params.Network) error { return persistence.ExecuteWithinTransaction(nm.db, func(tx *sql.Tx) error { txNetworksPersistence := persistence.NewNetworksPersistence(tx) - err := txNetworksPersistence.UpsertNetwork(network) + err := txNetworksPersistence.UpsertNetwork(nm.networkWithoutEmbeddedProviders(network)) if err != nil { return fmt.Errorf("failed to upsert network: %w", err) } @@ -124,51 +170,45 @@ func (nm *Manager) Delete(chainID uint64) error { // SetUserRpcProviders updates user RPC providers, wrapped in a transaction. func (nm *Manager) SetUserRpcProviders(chainID uint64, userProviders []params.RpcProvider) error { - return persistence.ExecuteWithinTransaction(nm.db, func(tx *sql.Tx) error { - // Create temporary persistence instances with the transaction - txRpcPersistence := persistence.NewRpcProvidersPersistence(tx) - - // Get all providers using the transactional RPC persistence - allProviders, err := txRpcPersistence.GetRpcProviders(chainID) - if err != nil { - return fmt.Errorf("failed to get all providers: %w", err) - } - - // Replace user providers - providers := networkhelper.ReplaceUserProviders(allProviders, userProviders) - - // Set RPC providers using the transactional RPC persistence - err = txRpcPersistence.SetRpcProviders(chainID, providers) - if err != nil { - return fmt.Errorf("failed to set RPC providers: %w", err) - } - - return nil - }) + rpcPersistence := nm.networkPersistence.GetRpcPersistence() + return rpcPersistence.SetRpcProviders(chainID, networkhelper.GetUserProviders(userProviders)) } // Find locates a network by ChainID. func (nm *Manager) Find(chainID uint64) *params.Network { networks, err := nm.networkPersistence.GetNetworkByChainID(chainID) if len(networks) != 1 || err != nil { + nm.logger.Warn("Failed to find network", zap.Uint64("chainID", chainID), zap.Error(err)) return nil } - return networks[0] + result := networks[0] + nm.setNetworkEmbeddedProviders(result) + return result } // GetAll returns all networks. func (nm *Manager) GetAll() ([]*params.Network, error) { - return nm.networkPersistence.GetAllNetworks() + networks, err := nm.networkPersistence.GetAllNetworks() + if err != nil { + return nil, err + } + nm.setEmbeddedProviders(networks) + return networks, nil } // Get returns networks filtered by the enabled status. func (nm *Manager) Get(onlyEnabled bool) ([]*params.Network, error) { - return nm.networkPersistence.GetNetworks(onlyEnabled, nil) + networks, err := nm.networkPersistence.GetNetworks(onlyEnabled, nil) + if err != nil { + return nil, err + } + nm.setEmbeddedProviders(networks) + return networks, nil } // GetConfiguredNetworks returns the configured networks. -func (nm *Manager) GetConfiguredNetworks() []params.Network { - return nm.configuredNetworks +func (nm *Manager) GetEmbeddedNetworks() []params.Network { + return nm.embeddedNetworks } // GetTestNetworksEnabled checks if test networks are enabled. diff --git a/rpc/network/network_test.go b/rpc/network/network_test.go index 96eb8a2f3..951c8f048 100644 --- a/rpc/network/network_test.go +++ b/rpc/network/network_test.go @@ -2,9 +2,10 @@ package network_test import ( "database/sql" - "github.com/status-im/status-go/api" "testing" + "github.com/status-im/status-go/api" + "github.com/status-im/status-go/appdatabase" "github.com/status-im/status-go/params" "github.com/status-im/status-go/params/networkhelper" @@ -36,14 +37,12 @@ func (s *NetworkManagerTestSuite) SetupTest() { initNetworks := []params.Network{ *testutil.CreateNetwork(api.MainnetChainID, "Ethereum Mainnet", []params.RpcProvider{ testutil.CreateProvider(api.MainnetChainID, "Infura Mainnet", params.UserProviderType, true, "https://mainnet.infura.io"), - testutil.CreateProvider(api.MainnetChainID, "Backup Mainnet", params.EmbeddedProxyProviderType, false, "https://backup.mainnet.provider.com"), }), *testutil.CreateNetwork(api.SepoliaChainID, "Sepolia Testnet", []params.RpcProvider{ testutil.CreateProvider(api.SepoliaChainID, "Infura Sepolia", params.UserProviderType, true, "https://sepolia.infura.io"), }), *testutil.CreateNetwork(api.OptimismChainID, "Optimistic Ethereum", []params.RpcProvider{ testutil.CreateProvider(api.OptimismChainID, "Infura Optimism", params.UserProviderType, true, "https://optimism.infura.io"), - testutil.CreateProvider(api.OptimismChainID, "Backup Optimism", params.EmbeddedDirectProviderType, false, "https://backup.optimism.provider.com"), }), } err = persistence.SetNetworks(initNetworks) @@ -75,29 +74,6 @@ func (s *NetworkManagerTestSuite) assertDbNetworks(expectedNetworks []params.Net testutil.CompareNetworksList(s.T(), expectedNetworksPtr, savedNetworks) } -func (s *NetworkManagerTestSuite) TestInitNetworksWithChangedAuth() { - // Change auth token for a provider - updatedNetworks := []params.Network{ - *testutil.CreateNetwork(api.MainnetChainID, "Ethereum Mainnet", []params.RpcProvider{ - testutil.CreateProvider(api.MainnetChainID, "Infura Mainnet", params.UserProviderType, true, "https://mainnet.infura.io"), - { - Name: "Backup Mainnet", - ChainID: api.MainnetChainID, - Type: params.EmbeddedProxyProviderType, - Enabled: false, - URL: "https://backup.mainnet.provider.com", - AuthType: params.TokenAuth, - AuthToken: "new-token", - }, - }), - } - - // Re-initialize and assert - err := s.manager.InitEmbeddedNetworks(updatedNetworks) - s.Require().NoError(err) - s.assertDbNetworks(updatedNetworks) -} - func (s *NetworkManagerTestSuite) TestUserAddsCustomProviders() { // Adding custom providers customProviders := []params.RpcProvider{ @@ -126,7 +102,7 @@ func (s *NetworkManagerTestSuite) TestInitNetworksKeepsUserProviders() { // Re-initialize networks initNetworks := []params.Network{ *testutil.CreateNetwork(api.MainnetChainID, "Ethereum Mainnet", []params.RpcProvider{ - testutil.CreateProvider(api.MainnetChainID, "Infura Mainnet", params.UserProviderType, true, "https://mainnet.infura.io"), + testutil.CreateProvider(api.MainnetChainID, "Infura Mainnet", params.EmbeddedProxyProviderType, true, "https://mainnet.infura.io"), }), } err = s.manager.InitEmbeddedNetworks(initNetworks) @@ -139,6 +115,81 @@ func (s *NetworkManagerTestSuite) TestInitNetworksKeepsUserProviders() { testutil.CompareProvidersList(s.T(), expectedProviders, foundNetwork.RpcProviders) } +func (s *NetworkManagerTestSuite) TestInitNetworksDoesNotSaveEmbeddedProviders() { + persistence := db.NewNetworksPersistence(s.db) + s.Require().NoError(persistence.DeleteAllNetworks()) + + // Re-initialize networks + initNetworks := []params.Network{ + *testutil.CreateNetwork(api.MainnetChainID, "Ethereum Mainnet", []params.RpcProvider{ + testutil.CreateProvider(api.MainnetChainID, "Infura Mainnet", params.EmbeddedProxyProviderType, true, "https://mainnet.infura.io"), + }), + } + err := s.manager.InitEmbeddedNetworks(initNetworks) + s.Require().NoError(err) + + // Check that embedded providers are not saved using persistence + networks, err := persistence.GetNetworks(false, nil) + s.Require().NoError(err) + s.Require().Len(networks, 1) + s.Require().Len(networks[0].RpcProviders, 0) +} + +func (s *NetworkManagerTestSuite) TestInitEmbeddedNetworks() { + // Re-initialize networks + initNetworks := []params.Network{ + *testutil.CreateNetwork(api.MainnetChainID, "Ethereum Mainnet", []params.RpcProvider{ + testutil.CreateProvider(api.MainnetChainID, "Infura Mainnet", params.EmbeddedProxyProviderType, true, "https://mainnet.infura.io"), + }), + } + expectedProviders := networkhelper.GetEmbeddedProviders(initNetworks[0].RpcProviders) + err := s.manager.InitEmbeddedNetworks(initNetworks) + s.Require().NoError(err) + + // functor tests if embedded providers are present in the networks + expectEmbeddedProviders := func(networks []*params.Network) { + for _, network := range networks { + if network.ChainID == api.MainnetChainID { + storedEmbeddedProviders := networkhelper.GetEmbeddedProviders(network.RpcProviders) + testutil.CompareProvidersList(s.T(), expectedProviders, storedEmbeddedProviders) + } + } + } + + // GetAll + networks, err := s.manager.GetAll() + s.Require().NoError(err) + expectEmbeddedProviders(networks) + + // Get + networks, err = s.manager.Get(false) + s.Require().NoError(err) + expectEmbeddedProviders(networks) + + // GetActiveNetworks + networks, err = s.manager.GetActiveNetworks() + s.Require().NoError(err) + expectEmbeddedProviders(networks) + + // GetCombinedNetworks + combinedNetworks, err := s.manager.GetCombinedNetworks() + s.Require().NoError(err) + for _, combinedNetwork := range combinedNetworks { + if combinedNetwork.Test != nil && combinedNetwork.Test.ChainID == api.MainnetChainID { + storedEmbeddedProviders := networkhelper.GetEmbeddedProviders(combinedNetwork.Test.RpcProviders) + testutil.CompareProvidersList(s.T(), expectedProviders, storedEmbeddedProviders) + } + if combinedNetwork.Prod != nil && combinedNetwork.Prod.ChainID == api.MainnetChainID { + storedEmbeddedProviders := networkhelper.GetEmbeddedProviders(combinedNetwork.Prod.RpcProviders) + testutil.CompareProvidersList(s.T(), expectedProviders, storedEmbeddedProviders) + } + } + + // GetEmbeddedNetworks + embeddedNetworks := s.manager.GetEmbeddedNetworks() + expectEmbeddedProviders(testutil.ConvertNetworksToPointers(embeddedNetworks)) +} + func (s *NetworkManagerTestSuite) TestLegacyFieldPopulation() { // Create initial test networks with various providers initNetworks := []params.Network{ @@ -206,3 +257,22 @@ func (s *NetworkManagerTestSuite) TestLegacyFieldPopulationWithoutUserProviders( s.Equal("https://proxy2.sepolia.io", network.DefaultFallbackURL) s.Empty(network.DefaultFallbackURL2) // No third Proxy provider } + +func (s *NetworkManagerTestSuite) TestUpsertNetwork() { + // Create a new network + newNetwork := testutil.CreateNetwork(api.MainnetChainID, "Ethereum Mainnet", []params.RpcProvider{ + testutil.CreateProvider(api.MainnetChainID, "Infura Mainnet", params.EmbeddedProxyProviderType, true, "https://mainnet.infura.io"), + }) + + // Upsert the network + err := s.manager.Upsert(newNetwork) + s.Require().NoError(err) + + // Verify the network was upserted without embedded providers + persistence := db.NewNetworksPersistence(s.db) + chainID := api.MainnetChainID + networks, err := persistence.GetNetworks(false, &chainID) + s.Require().NoError(err) + s.Require().Len(networks, 1) + s.Require().Len(networkhelper.GetEmbeddedProviders(networks[0].RpcProviders), 0) +} diff --git a/services/wallet/token/token.go b/services/wallet/token/token.go index fa47b7b56..e1f3be613 100644 --- a/services/wallet/token/token.go +++ b/services/wallet/token/token.go @@ -501,7 +501,7 @@ func (tm *Manager) GetAllTokens() ([]*Token, error) { allTokens = append(tm.getTokens(), allTokens...) - overrideTokensInPlace(tm.networkManager.GetConfiguredNetworks(), allTokens) + overrideTokensInPlace(tm.networkManager.GetEmbeddedNetworks(), allTokens) native, err := tm.getNativeTokens() if err != nil {