diff --git a/rpc/network/db/network_db.go b/rpc/network/db/network_db.go index 885a72243..8c3015036 100644 --- a/rpc/network/db/network_db.go +++ b/rpc/network/db/network_db.go @@ -30,6 +30,7 @@ type NetworksPersistenceInterface interface { DeleteAllNetworks() error GetRpcPersistence() RpcProvidersPersistenceInterface + SetEnabled(chainID uint64, enabled bool) error } // NetworksPersistence manages networks and their providers. @@ -255,3 +256,22 @@ func (n *NetworksPersistence) DeleteNetwork(chainID uint64) error { return nil } + +// SetEnabled updates the enabled status of a network. +func (n *NetworksPersistence) SetEnabled(chainID uint64, enabled bool) error { + q := sq.Update("networks"). + Set("enabled", enabled). + Where(sq.Eq{"chain_id": chainID}) + + query, args, err := q.ToSql() + if err != nil { + return fmt.Errorf("failed to build update query: %w", err) + } + + _, err = n.db.Exec(query, args...) + if err != nil { + return fmt.Errorf("failed to execute update query for chain_id %d: %w", chainID, err) + } + + return nil +} diff --git a/rpc/network/db/network_db_test.go b/rpc/network/db/network_db_test.go index 4587c529b..dfe6a51fd 100644 --- a/rpc/network/db/network_db_test.go +++ b/rpc/network/db/network_db_test.go @@ -195,3 +195,28 @@ func (s *NetworksPersistenceTestSuite) TestValidationForNetworksAndProviders() { s.Require().NoError(err) s.Require().Len(allNetworks, 0, "No invalid networks should be saved") } + +func (s *NetworksPersistenceTestSuite) TestSetEnabled() { + network := testutil.CreateNetwork(api.OptimismChainID, "Optimism Mainnet", DefaultProviders(api.OptimismChainID)) + s.addAndVerifyNetworks([]*params.Network{network}) + + // Disable the network + err := s.networksPersistence.SetEnabled(network.ChainID, false) + s.Require().NoError(err) + + // Verify the network is disabled + updatedNetwork, err := s.networksPersistence.GetNetworkByChainID(network.ChainID) + s.Require().NoError(err) + s.Require().Len(updatedNetwork, 1) + s.Require().False(updatedNetwork[0].Enabled) + + // Enable the network + err = s.networksPersistence.SetEnabled(network.ChainID, true) + s.Require().NoError(err) + + // Verify the network is enabled + updatedNetwork, err = s.networksPersistence.GetNetworkByChainID(network.ChainID) + s.Require().NoError(err) + s.Require().Len(updatedNetwork, 1) + s.Require().True(updatedNetwork[0].Enabled) +} diff --git a/rpc/network/network.go b/rpc/network/network.go index 4fffa7ec1..eda5689b7 100644 --- a/rpc/network/network.go +++ b/rpc/network/network.go @@ -36,6 +36,7 @@ type ManagerInterface interface { GetTestNetworksEnabled() (bool, error) SetUserRpcProviders(chainID uint64, providers []params.RpcProvider) error + SetEnabled(chainID uint64, enabled bool) error } type Manager struct { @@ -174,6 +175,15 @@ func (nm *Manager) SetUserRpcProviders(chainID uint64, userProviders []params.Rp return rpcPersistence.SetRpcProviders(chainID, networkhelper.GetUserProviders(userProviders)) } +// SetEnabled updates the enabled status of a network +func (nm *Manager) SetEnabled(chainID uint64, enabled bool) error { + err := nm.networkPersistence.SetEnabled(chainID, enabled) + if err != nil { + return fmt.Errorf("failed to set enabled status: %w", err) + } + return nil +} + // Find locates a network by ChainID. func (nm *Manager) Find(chainID uint64) *params.Network { networks, err := nm.networkPersistence.GetNetworkByChainID(chainID) diff --git a/services/wallet/api.go b/services/wallet/api.go index 8fa0189eb..da603e5bf 100644 --- a/services/wallet/api.go +++ b/services/wallet/api.go @@ -400,6 +400,16 @@ func (api *API) AddEthereumChain(ctx context.Context, network params.Network) er return api.s.rpcClient.NetworkManager.Upsert(&network) } +func (api *API) SetChainUserRpcProviders(ctx context.Context, chainID uint64, rpcProviders []params.RpcProvider) error { + logutils.ZapLogger().Debug("call to SetChainUserRpcProviders") + return api.s.rpcClient.NetworkManager.SetUserRpcProviders(chainID, rpcProviders) +} + +func (api *API) SetChainEnabled(ctx context.Context, chainID uint64, enabled bool) error { + logutils.ZapLogger().Debug("call to SetChainEnabled") + return api.s.rpcClient.NetworkManager.SetEnabled(chainID, enabled) +} + func (api *API) DeleteEthereumChain(ctx context.Context, chainID uint64) error { logutils.ZapLogger().Debug("call to DeleteEthereumChain") return api.s.rpcClient.NetworkManager.Delete(chainID)