feat(config)_: Do not store embedded RPC provider credentials in the DB

* add method to make a deepcopy of a network + tests
* improved logging
* improved memory allocation
This commit is contained in:
Andrey Bocharnikov 2025-01-10 19:06:46 +04:00
parent 593cdc0314
commit 27a5515df7
3 changed files with 43 additions and 10 deletions

View File

@ -51,7 +51,7 @@ func ToggleUserProviders(providers []params.RpcProvider, enabled bool) []params.
// GetEmbeddedProviders returns the embedded providers from the list.
func GetEmbeddedProviders(providers []params.RpcProvider) []params.RpcProvider {
var embeddedProviders []params.RpcProvider
embeddedProviders := make([]params.RpcProvider, 0, len(providers))
for _, provider := range providers {
if provider.Type != params.UserProviderType {
embeddedProviders = append(embeddedProviders, provider)
@ -62,7 +62,7 @@ func GetEmbeddedProviders(providers []params.RpcProvider) []params.RpcProvider {
// GetUserProviders returns the user-defined providers from the list.
func GetUserProviders(providers []params.RpcProvider) []params.RpcProvider {
var userProviders []params.RpcProvider
userProviders := make([]params.RpcProvider, 0, len(providers))
for _, provider := range providers {
if provider.Type == params.UserProviderType {
userProviders = append(userProviders, provider)
@ -118,19 +118,25 @@ func OverrideEmbeddedProxyProviders(networks []params.Network, enabled bool, use
return updatedNetworks
}
func deepCopyNetworks(networks []params.Network) []params.Network {
updatedNetworks := make([]params.Network, len(networks))
for i, network := range networks {
func DeepCopyNetwork(network params.Network) params.Network {
updatedNetwork := network
updatedNetwork.RpcProviders = make([]params.RpcProvider, len(network.RpcProviders))
copy(updatedNetwork.RpcProviders, network.RpcProviders)
updatedNetworks[i] = updatedNetwork
updatedNetwork.TokenOverrides = make([]params.TokenOverride, len(network.TokenOverrides))
copy(updatedNetwork.TokenOverrides, network.TokenOverrides)
return updatedNetwork
}
func DeepCopyNetworks(networks []params.Network) []params.Network {
updatedNetworks := make([]params.Network, len(networks))
for i, network := range networks {
updatedNetworks[i] = DeepCopyNetwork(network)
}
return updatedNetworks
}
func OverrideDirectProvidersAuth(networks []params.Network, authTokens map[string]string) []params.Network {
updatedNetworks := deepCopyNetworks(networks)
updatedNetworks := DeepCopyNetworks(networks)
for i := range updatedNetworks {
network := &updatedNetworks[i]

View File

@ -143,3 +143,24 @@ func TestOverrideDirectProvidersAuth(t *testing.T) {
}
}
}
func TestDeepCopyNetwork(t *testing.T) {
originalNetwork := testutil.CreateNetwork(api.MainnetChainID, "Ethereum Mainnet", []params.RpcProvider{
*params.NewUserProvider(api.MainnetChainID, "Provider1", "https://userprovider.example.com", true),
*params.NewDirectProvider(api.MainnetChainID, "Provider2", "https://mainnet.infura.io/v3/", true),
})
originalNetwork.TokenOverrides = []params.TokenOverride{
{Symbol: "token1", Address: common.HexToAddress("0x123")},
}
copiedNetwork := networkhelper.DeepCopyNetwork(*originalNetwork)
assert.True(t, reflect.DeepEqual(originalNetwork, &copiedNetwork), "Copied network should be deeply equal to the original")
// Modify the copied network and verify that the original network remains unchanged
copiedNetwork.RpcProviders[0].Enabled = false
copiedNetwork.TokenOverrides[0].Symbol = "modifiedSymbol"
assert.NotEqual(t, originalNetwork.RpcProviders[0].Enabled, copiedNetwork.RpcProviders[0].Enabled, "Original network should remain unchanged")
assert.NotEqual(t, originalNetwork.TokenOverrides[0].Symbol, copiedNetwork.TokenOverrides[0].Symbol, "Original network should remain unchanged")
}

View File

@ -2,7 +2,9 @@ package db
import (
"database/sql"
"errors"
"fmt"
"github.com/status-im/status-go/params"
)
@ -65,13 +67,17 @@ func ExecuteWithinTransaction(db *sql.DB, fn func(tx *sql.Tx) error) (err error)
}
defer func() {
if p := recover(); p != nil {
err = fmt.Errorf("panic: %v", p)
_ = tx.Rollback()
panic(p)
} else if err != nil {
_ = tx.Rollback()
rollbackErr := tx.Rollback()
if rollbackErr != nil {
err = errors.Join(err, fmt.Errorf("transaction rollback failed: %w", rollbackErr))
}
} else {
if commitErr := tx.Commit(); commitErr != nil {
err = fmt.Errorf("transaction commit failed: %w", commitErr)
err = errors.Join(err, fmt.Errorf("transaction commit failed: %w", commitErr))
}
}
}()