chore(config)_: rpc providers configuration (#6151)

* chore(config)_: extract rpc_provider_persistence + tests

* Add rpc_providers table, migration
* add RpcProvider type
* deprecate old rpc fields in networks, add RpcProviders list
* add persistence packages for rpc_providers, networks
* Tests
This commit is contained in:
Andrey Bocharnikov 2025-01-11 05:02:09 +07:00 committed by GitHub
parent 90ce72a2d5
commit e9abf1662d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
18 changed files with 1759 additions and 70 deletions

View File

@ -1579,12 +1579,12 @@ func TestWalletConfigOnLoginAccount(t *testing.T) {
}
require.Equal(t, b.config.WalletConfig.InfuraAPIKey, infuraToken)
require.Equal(t, b.config.WalletConfig.AlchemyAPIKeys[mainnetChainID], alchemyEthereumMainnetToken)
require.Equal(t, b.config.WalletConfig.AlchemyAPIKeys[sepoliaChainID], alchemyEthereumSepoliaToken)
require.Equal(t, b.config.WalletConfig.AlchemyAPIKeys[arbitrumChainID], alchemyArbitrumMainnetToken)
require.Equal(t, b.config.WalletConfig.AlchemyAPIKeys[arbitrumSepoliaChainID], alchemyArbitrumSepoliaToken)
require.Equal(t, b.config.WalletConfig.AlchemyAPIKeys[optimismChainID], alchemyOptimismMainnetToken)
require.Equal(t, b.config.WalletConfig.AlchemyAPIKeys[optimismSepoliaChainID], alchemyOptimismSepoliaToken)
require.Equal(t, b.config.WalletConfig.AlchemyAPIKeys[MainnetChainID], alchemyEthereumMainnetToken)
require.Equal(t, b.config.WalletConfig.AlchemyAPIKeys[SepoliaChainID], alchemyEthereumSepoliaToken)
require.Equal(t, b.config.WalletConfig.AlchemyAPIKeys[ArbitrumChainID], alchemyArbitrumMainnetToken)
require.Equal(t, b.config.WalletConfig.AlchemyAPIKeys[ArbitrumSepoliaChainID], alchemyArbitrumSepoliaToken)
require.Equal(t, b.config.WalletConfig.AlchemyAPIKeys[OptimismChainID], alchemyOptimismMainnetToken)
require.Equal(t, b.config.WalletConfig.AlchemyAPIKeys[OptimismSepoliaChainID], alchemyOptimismSepoliaToken)
require.Equal(t, b.config.WalletConfig.RaribleMainnetAPIKey, raribleMainnetAPIKey)
require.Equal(t, b.config.WalletConfig.RaribleTestnetAPIKey, raribleTestnetAPIKey)

View File

@ -10,12 +10,12 @@ import (
)
const (
mainnetChainID uint64 = 1
sepoliaChainID uint64 = 11155111
optimismChainID uint64 = 10
optimismSepoliaChainID uint64 = 11155420
arbitrumChainID uint64 = 42161
arbitrumSepoliaChainID uint64 = 421614
MainnetChainID uint64 = 1
SepoliaChainID uint64 = 11155111
OptimismChainID uint64 = 10
OptimismSepoliaChainID uint64 = 11155420
ArbitrumChainID uint64 = 42161
ArbitrumSepoliaChainID uint64 = 421614
sntSymbol = "SNT"
sttSymbol = "STT"
)
@ -24,7 +24,7 @@ var ganacheTokenAddress = common.HexToAddress("0x8571Ddc46b10d31EF963aF49b6C7799
func mainnet(stageName string) params.Network {
return params.Network{
ChainID: mainnetChainID,
ChainID: MainnetChainID,
ChainName: "Mainnet",
DefaultRPCURL: fmt.Sprintf("https://%s.api.status.im/nodefleet/ethereum/mainnet/", stageName),
DefaultFallbackURL: fmt.Sprintf("https://%s.api.status.im/infura/ethereum/mainnet/", stageName),
@ -41,13 +41,13 @@ func mainnet(stageName string) params.Network {
IsTest: false,
Layer: 1,
Enabled: true,
RelatedChainID: sepoliaChainID,
RelatedChainID: SepoliaChainID,
}
}
func sepolia(stageName string) params.Network {
return params.Network{
ChainID: sepoliaChainID,
ChainID: SepoliaChainID,
ChainName: "Mainnet",
DefaultRPCURL: fmt.Sprintf("https://%s.api.status.im/nodefleet/ethereum/sepolia/", stageName),
DefaultFallbackURL: fmt.Sprintf("https://%s.api.status.im/infura/ethereum/sepolia/", stageName),
@ -64,13 +64,13 @@ func sepolia(stageName string) params.Network {
IsTest: true,
Layer: 1,
Enabled: true,
RelatedChainID: mainnetChainID,
RelatedChainID: MainnetChainID,
}
}
func optimism(stageName string) params.Network {
return params.Network{
ChainID: optimismChainID,
ChainID: OptimismChainID,
ChainName: "Optimism",
DefaultRPCURL: fmt.Sprintf("https://%s.api.status.im/nodefleet/optimism/mainnet/", stageName),
DefaultFallbackURL: fmt.Sprintf("https://%s.api.status.im/infura/optimism/mainnet/", stageName),
@ -87,13 +87,13 @@ func optimism(stageName string) params.Network {
IsTest: false,
Layer: 2,
Enabled: true,
RelatedChainID: optimismSepoliaChainID,
RelatedChainID: OptimismSepoliaChainID,
}
}
func optimismSepolia(stageName string) params.Network {
return params.Network{
ChainID: optimismSepoliaChainID,
ChainID: OptimismSepoliaChainID,
ChainName: "Optimism",
DefaultRPCURL: fmt.Sprintf("https://%s.api.status.im/nodefleet/optimism/sepolia/", stageName),
DefaultFallbackURL: fmt.Sprintf("https://%s.api.status.im/infura/optimism/sepolia/", stageName),
@ -110,13 +110,13 @@ func optimismSepolia(stageName string) params.Network {
IsTest: true,
Layer: 2,
Enabled: false,
RelatedChainID: optimismChainID,
RelatedChainID: OptimismChainID,
}
}
func arbitrum(stageName string) params.Network {
return params.Network{
ChainID: arbitrumChainID,
ChainID: ArbitrumChainID,
ChainName: "Arbitrum",
DefaultRPCURL: fmt.Sprintf("https://%s.api.status.im/nodefleet/arbitrum/mainnet/", stageName),
DefaultFallbackURL: fmt.Sprintf("https://%s.api.status.im/infura/arbitrum/mainnet/", stageName),
@ -133,13 +133,13 @@ func arbitrum(stageName string) params.Network {
IsTest: false,
Layer: 2,
Enabled: true,
RelatedChainID: arbitrumSepoliaChainID,
RelatedChainID: ArbitrumSepoliaChainID,
}
}
func arbitrumSepolia(stageName string) params.Network {
return params.Network{
ChainID: arbitrumSepoliaChainID,
ChainID: ArbitrumSepoliaChainID,
ChainName: "Arbitrum",
DefaultRPCURL: fmt.Sprintf("https://%s.api.status.im/nodefleet/arbitrum/sepolia/", stageName),
DefaultFallbackURL: fmt.Sprintf("https://%s.api.status.im/infura/arbitrum/sepolia/", stageName),
@ -156,7 +156,7 @@ func arbitrumSepolia(stageName string) params.Network {
IsTest: true,
Layer: 2,
Enabled: false,
RelatedChainID: arbitrumChainID,
RelatedChainID: ArbitrumChainID,
}
}
@ -204,7 +204,7 @@ func setRPCs(networks []params.Network, request *requests.WalletSecretsConfig) [
if request.GanacheURL != "" {
n.RPCURL = request.GanacheURL
n.FallbackURL = request.GanacheURL
if n.ChainID == mainnetChainID {
if n.ChainID == MainnetChainID {
n.TokenOverrides = []params.TokenOverride{
mainnetGanacheTokenOverrides,
}

View File

@ -28,12 +28,12 @@ func TestBuildDefaultNetworks(t *testing.T) {
for _, n := range actualNetworks {
var err error
switch n.ChainID {
case mainnetChainID:
case sepoliaChainID:
case optimismChainID:
case optimismSepoliaChainID:
case arbitrumChainID:
case arbitrumSepoliaChainID:
case MainnetChainID:
case SepoliaChainID:
case OptimismChainID:
case OptimismSepoliaChainID:
case ArbitrumChainID:
case ArbitrumSepoliaChainID:
default:
err = errors.Errorf("unexpected chain id: %d", n.ChainID)
}
@ -70,7 +70,7 @@ func TestBuildDefaultNetworksGanache(t *testing.T) {
require.True(t, strings.Contains(n.FallbackURL, ganacheURL))
}
require.Equal(t, mainnetChainID, actualNetworks[0].ChainID)
require.Equal(t, MainnetChainID, actualNetworks[0].ChainID)
require.NotNil(t, actualNetworks[0].TokenOverrides)
require.Len(t, actualNetworks[0].TokenOverrides, 1)

View File

@ -198,22 +198,22 @@ func buildWalletConfig(request *requests.WalletSecretsConfig, statusProxyEnabled
}
if request.AlchemyEthereumMainnetToken != "" {
walletConfig.AlchemyAPIKeys[mainnetChainID] = request.AlchemyEthereumMainnetToken
walletConfig.AlchemyAPIKeys[MainnetChainID] = request.AlchemyEthereumMainnetToken
}
if request.AlchemyEthereumSepoliaToken != "" {
walletConfig.AlchemyAPIKeys[sepoliaChainID] = request.AlchemyEthereumSepoliaToken
walletConfig.AlchemyAPIKeys[SepoliaChainID] = request.AlchemyEthereumSepoliaToken
}
if request.AlchemyArbitrumMainnetToken != "" {
walletConfig.AlchemyAPIKeys[arbitrumChainID] = request.AlchemyArbitrumMainnetToken
walletConfig.AlchemyAPIKeys[ArbitrumChainID] = request.AlchemyArbitrumMainnetToken
}
if request.AlchemyArbitrumSepoliaToken != "" {
walletConfig.AlchemyAPIKeys[arbitrumSepoliaChainID] = request.AlchemyArbitrumSepoliaToken
walletConfig.AlchemyAPIKeys[ArbitrumSepoliaChainID] = request.AlchemyArbitrumSepoliaToken
}
if request.AlchemyOptimismMainnetToken != "" {
walletConfig.AlchemyAPIKeys[optimismChainID] = request.AlchemyOptimismMainnetToken
walletConfig.AlchemyAPIKeys[OptimismChainID] = request.AlchemyOptimismMainnetToken
}
if request.AlchemyOptimismSepoliaToken != "" {
walletConfig.AlchemyAPIKeys[optimismSepoliaChainID] = request.AlchemyOptimismSepoliaToken
walletConfig.AlchemyAPIKeys[OptimismSepoliaChainID] = request.AlchemyOptimismSepoliaToken
}
if request.StatusProxyMarketUser != "" {
walletConfig.StatusProxyMarketUser = request.StatusProxyMarketUser

View File

@ -0,0 +1,13 @@
CREATE TABLE IF NOT EXISTS rpc_providers (
id INTEGER PRIMARY KEY AUTOINCREMENT, -- Unique provider ID (sorting)
chain_id INTEGER NOT NULL CHECK (chain_id > 0), -- Chain ID for the network
name TEXT NOT NULL CHECK (LENGTH(name) > 0), -- Provider name
url TEXT NOT NULL CHECK (LENGTH(url) > 0), -- Provider URL
enable_rps_limiter BOOLEAN NOT NULL DEFAULT FALSE, -- Enable RPS limiter
type TEXT NOT NULL DEFAULT 'user', -- Provider type: embedded-proxy, embedded-direct, user
enabled BOOLEAN NOT NULL DEFAULT TRUE, -- Whether the provider is active or not
auth_type TEXT NOT NULL DEFAULT 'no-auth', -- Authentication type: no-auth, basic-auth, token-auth
auth_login TEXT, -- BasicAuth login (nullable)
auth_password TEXT, -- Password for BasicAuth (nullable)
auth_token TEXT -- Token for TokenAuth (nullable)
);

View File

@ -14,7 +14,6 @@ import (
"go.uber.org/zap"
validator "gopkg.in/go-playground/validator.v9"
"github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/p2p/discv5"
"github.com/ethereum/go-ethereum/params"
@ -482,7 +481,7 @@ type NodeConfig struct {
// (persistent storage of user's mailserver records).
MailserversConfig MailserversConfig
// Web3ProviderConfig extra configuration for provider.Service
// Web3ProviderConfig extra configuration for provider.Service.
// (desktop provider API)
Web3ProviderConfig Web3ProviderConfig
@ -514,35 +513,6 @@ type NodeConfig struct {
ProcessBackedupMessages bool
}
type TokenOverride struct {
Symbol string `json:"symbol"`
Address common.Address `json:"address"`
}
type Network struct {
ChainID uint64 `json:"chainId"`
ChainName string `json:"chainName"`
DefaultRPCURL string `json:"defaultRpcUrl"` // proxy rpc url
DefaultFallbackURL string `json:"defaultFallbackURL"` // proxy fallback url
DefaultFallbackURL2 string `json:"defaultFallbackURL2"` // second proxy fallback url
RPCURL string `json:"rpcUrl"`
OriginalRPCURL string `json:"originalRpcUrl"`
FallbackURL string `json:"fallbackURL"`
OriginalFallbackURL string `json:"originalFallbackURL"`
BlockExplorerURL string `json:"blockExplorerUrl,omitempty"`
IconURL string `json:"iconUrl,omitempty"`
NativeCurrencyName string `json:"nativeCurrencyName,omitempty"`
NativeCurrencySymbol string `json:"nativeCurrencySymbol,omitempty"`
NativeCurrencyDecimals uint64 `json:"nativeCurrencyDecimals"`
IsTest bool `json:"isTest"`
Layer uint64 `json:"layer"`
Enabled bool `json:"enabled"`
ChainColor string `json:"chainColor"`
ShortName string `json:"shortName"`
TokenOverrides []TokenOverride `json:"tokenOverrides"`
RelatedChainID uint64 `json:"relatedChainId"`
}
// WalletConfig extra configuration for wallet.Service.
type WalletConfig struct {
Enabled bool
@ -598,7 +568,7 @@ type MailserversConfig struct {
Enabled bool
}
// ProviderConfig extra configuration for provider.Service
// ProviderAuthConfig extra configuration for provider.Service
type Web3ProviderConfig struct {
Enabled bool
}

95
params/network_config.go Normal file
View File

@ -0,0 +1,95 @@
package params
import "github.com/ethereum/go-ethereum/common"
// RpcProviderAuthType defines the different types of authentication for RPC providers
type RpcProviderAuthType string
const (
NoAuth RpcProviderAuthType = "no-auth" // No authentication
BasicAuth RpcProviderAuthType = "basic-auth" // HTTP Header "Authorization: Basic base64(username:password)"
TokenAuth RpcProviderAuthType = "token-auth" // URL Token-based authentication "https://api.example.com/YOUR_TOKEN"
)
// RpcProviderType defines the type of RPC provider
type RpcProviderType string
const (
EmbeddedProxyProviderType RpcProviderType = "embedded-proxy" // Proxy-based RPC provider
EmbeddedDirectProviderType RpcProviderType = "embedded-direct" // Direct RPC provider
UserProviderType RpcProviderType = "user" // User-defined RPC provider
)
// RpcProvider represents an RPC provider configuration with various options
type RpcProvider struct {
ID int64 `json:"id" validate:"omitempty"` // Auto-increment ID (for sorting order)
ChainID uint64 `json:"chainId" validate:"required,gt=0"` // Chain ID of the network
Name string `json:"name" validate:"required,min=1"` // Provider name for identification
URL string `json:"url" validate:"required,url"` // Current Provider URL
EnableRPSLimiter bool `json:"enableRpsLimiter"` // Enable RPC rate limiting for this provider
Type RpcProviderType `json:"type" validate:"required,oneof=embedded-proxy embedded-direct user"`
Enabled bool `json:"enabled"` // Whether the provider is enabled
// Authentication
AuthType RpcProviderAuthType `json:"authType" validate:"required,oneof=no-auth basic-auth token-auth"` // Type of authentication
AuthLogin string `json:"authLogin" validate:"omitempty,min=1"` // Login for BasicAuth (empty string if not used)
AuthPassword string `json:"authPassword" validate:"omitempty,min=1"` // Password for BasicAuth (empty string if not used)
AuthToken string `json:"authToken" validate:"omitempty,min=1"` // Token for TokenAuth (empty string if not used)
}
type TokenOverride struct {
Symbol string `json:"symbol"`
Address common.Address `json:"address"`
}
type Network struct {
ChainID uint64 `json:"chainId" validate:"required,gt=0"`
ChainName string `json:"chainName" validate:"required,min=1"`
RpcProviders []RpcProvider `json:"rpcProviders" validate:"dive,required"` // List of RPC providers, in the order in which they are accessed
// Deprecated fields (kept for backward compatibility)
// FIXME: Removal of deprecated fields in integration PR https://github.com/status-im/status-go/pull/6178
DefaultRPCURL string `json:"defaultRpcUrl" validate:"omitempty,url"` // Deprecated: proxy rpc url
DefaultFallbackURL string `json:"defaultFallbackURL" validate:"omitempty,url"` // Deprecated: proxy fallback url
DefaultFallbackURL2 string `json:"defaultFallbackURL2" validate:"omitempty,url"` // Deprecated: second proxy fallback url
RPCURL string `json:"rpcUrl" validate:"omitempty,url"` // Deprecated: direct rpc url
OriginalRPCURL string `json:"originalRpcUrl" validate:"omitempty,url"` // Deprecated: direct rpc url if user overrides RPCURL
FallbackURL string `json:"fallbackURL" validate:"omitempty,url"` // Deprecated
OriginalFallbackURL string `json:"originalFallbackURL" validate:"omitempty,url"` // Deprecated
BlockExplorerURL string `json:"blockExplorerUrl,omitempty" validate:"omitempty,url"`
IconURL string `json:"iconUrl,omitempty" validate:"omitempty"`
NativeCurrencyName string `json:"nativeCurrencyName,omitempty" validate:"omitempty,min=1"`
NativeCurrencySymbol string `json:"nativeCurrencySymbol,omitempty" validate:"omitempty,min=1"`
NativeCurrencyDecimals uint64 `json:"nativeCurrencyDecimals" validate:"omitempty"`
IsTest bool `json:"isTest"`
Layer uint64 `json:"layer" validate:"omitempty"`
Enabled bool `json:"enabled"`
ChainColor string `json:"chainColor" validate:"omitempty"`
ShortName string `json:"shortName" validate:"omitempty,min=1"`
TokenOverrides []TokenOverride `json:"tokenOverrides" validate:"omitempty,dive"`
RelatedChainID uint64 `json:"relatedChainId" validate:"omitempty"`
}
func newRpcProvider(chainID uint64, name, url string, enableRpsLimiter bool, providerType RpcProviderType) *RpcProvider {
return &RpcProvider{
ChainID: chainID,
Name: name,
URL: url,
EnableRPSLimiter: enableRpsLimiter,
Type: providerType,
Enabled: true,
AuthType: NoAuth,
}
}
func NewUserProvider(chainID uint64, name, url string, enableRpsLimiter bool) *RpcProvider {
return newRpcProvider(chainID, name, url, enableRpsLimiter, UserProviderType)
}
func NewProxyProvider(chainID uint64, name, url string, enableRpsLimiter bool) *RpcProvider {
return newRpcProvider(chainID, name, url, enableRpsLimiter, EmbeddedProxyProviderType)
}
func NewDirectProvider(chainID uint64, name, url string, enableRpsLimiter bool) *RpcProvider {
return newRpcProvider(chainID, name, url, enableRpsLimiter, EmbeddedDirectProviderType)
}

View File

@ -0,0 +1,186 @@
package networkhelper
import (
"net/url"
"strings"
"github.com/status-im/status-go/params"
)
// MergeProvidersPreservingUsersAndEnabledState merges new embedded providers with the current ones,
// preserving user-defined providers and maintaining the Enabled state.
func MergeProvidersPreservingUsersAndEnabledState(currentProviders, newProviders []params.RpcProvider) []params.RpcProvider {
// Create a map for quick lookup of the Enabled state by Name
enabledState := make(map[string]bool, len(currentProviders))
for _, provider := range currentProviders {
enabledState[provider.Name] = provider.Enabled
}
// Update the Enabled field in newProviders if the Name matches
for i := range newProviders {
if enabled, exists := enabledState[newProviders[i].Name]; exists {
newProviders[i].Enabled = enabled
}
}
// Retain current providers of type UserProviderType and add them to the beginning of the list
mergedProviders := make([]params.RpcProvider, 0, len(currentProviders)+len(newProviders))
for _, provider := range currentProviders {
if provider.Type == params.UserProviderType {
mergedProviders = append(mergedProviders, provider)
}
}
// Add the updated newProviders
mergedProviders = append(mergedProviders, newProviders...)
return mergedProviders
}
// ToggleUserProviders enables or disables all user-defined providers and disables other types.
func ToggleUserProviders(providers []params.RpcProvider, enabled bool) []params.RpcProvider {
for i := range providers {
if providers[i].Type == params.UserProviderType {
providers[i].Enabled = enabled
} else {
providers[i].Enabled = !enabled
}
}
return providers
}
// GetEmbeddedProviders returns the embedded providers from the list.
func GetEmbeddedProviders(providers []params.RpcProvider) []params.RpcProvider {
var embeddedProviders []params.RpcProvider
for _, provider := range providers {
if provider.Type != params.UserProviderType {
embeddedProviders = append(embeddedProviders, provider)
}
}
return embeddedProviders
}
// GetUserProviders returns the user-defined providers from the list.
func GetUserProviders(providers []params.RpcProvider) []params.RpcProvider {
var userProviders []params.RpcProvider
for _, provider := range providers {
if provider.Type == params.UserProviderType {
userProviders = append(userProviders, provider)
}
}
return userProviders
}
// ReplaceUserProviders replaces user-defined providers with new ones, retaining the rest of the providers.
func ReplaceUserProviders(currentProviders, newUserProviders []params.RpcProvider) []params.RpcProvider {
// Extract embedded providers from the current list
embeddedProviders := GetEmbeddedProviders(currentProviders)
userProviders := GetUserProviders(newUserProviders)
// Combine new user providers with the existing embedded providers
return append(userProviders, embeddedProviders...)
}
// ReplaceEmbeddedProviders replaces embedded providers with new ones, retaining user-defined providers.
func ReplaceEmbeddedProviders(currentProviders, newEmbeddedProviders []params.RpcProvider) []params.RpcProvider {
// Extract user-defined providers from the current list
userProviders := GetUserProviders(currentProviders)
embeddedProviders := GetEmbeddedProviders(newEmbeddedProviders)
// Combine existing user-defined providers with the new embedded providers
return append(userProviders, embeddedProviders...)
}
// OverrideEmbeddedProxyProviders updates all embedded-proxy providers in the given networks.
// It sets the `Enabled` flag and configures the `AuthLogin` and `AuthPassword` for each provider.
func OverrideEmbeddedProxyProviders(networks []params.Network, enabled bool, user, password string) []params.Network {
updatedNetworks := make([]params.Network, len(networks))
for i, network := range networks {
// Deep copy the network to avoid mutating the input slice
updatedNetwork := network
updatedProviders := make([]params.RpcProvider, len(network.RpcProviders))
// Update the embedded-proxy providers
for j, provider := range network.RpcProviders {
if provider.Type == params.EmbeddedProxyProviderType {
provider.Enabled = enabled
provider.AuthLogin = user
provider.AuthPassword = password
}
updatedProviders[j] = provider
}
updatedNetwork.RpcProviders = updatedProviders
updatedNetworks[i] = updatedNetwork
}
return updatedNetworks
}
func deepCopyNetworks(networks []params.Network) []params.Network {
updatedNetworks := make([]params.Network, len(networks))
for i, network := range networks {
updatedNetwork := network
updatedNetwork.RpcProviders = make([]params.RpcProvider, len(network.RpcProviders))
copy(updatedNetwork.RpcProviders, network.RpcProviders)
updatedNetworks[i] = updatedNetwork
}
return updatedNetworks
}
func OverrideDirectProvidersAuth(networks []params.Network, authTokens map[string]string) []params.Network {
updatedNetworks := deepCopyNetworks(networks)
for i := range updatedNetworks {
network := &updatedNetworks[i]
for j := range network.RpcProviders {
provider := &network.RpcProviders[j]
if provider.Type != params.EmbeddedDirectProviderType {
continue
}
host, err := extractHost(provider.URL)
if err != nil {
continue
}
for suffix, token := range authTokens {
if strings.HasSuffix(host, suffix) && token != "" {
provider.AuthType = params.TokenAuth
provider.AuthToken = token
break
}
}
}
}
return updatedNetworks
}
func OverrideGanacheToken(networks []params.Network, ganacheURL string, chainID uint64, tokenOverride params.TokenOverride) []params.Network {
updatedNetworks := deepCopyNetworks(networks)
for i := range updatedNetworks {
network := &updatedNetworks[i]
if network.ChainID != chainID {
continue
}
for j := range network.RpcProviders {
if ganacheURL != "" {
network.RpcProviders[j].URL = ganacheURL
}
}
network.TokenOverrides = []params.TokenOverride{tokenOverride}
}
return updatedNetworks
}
func extractHost(providerURL string) (string, error) {
parsedURL, err := url.Parse(providerURL)
if err != nil {
return "", err
}
return parsedURL.Host, nil
}

View File

@ -0,0 +1,171 @@
package networkhelper_test
import (
"reflect"
"strings"
"testing"
"github.com/ethereum/go-ethereum/common"
"github.com/brianvoe/gofakeit/v6"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/status-im/status-go/api"
"github.com/status-im/status-go/params"
"github.com/status-im/status-go/params/networkhelper"
"github.com/status-im/status-go/rpc/network/testutil"
)
func TestMergeProvidersPreserveEnabledAndOrder(t *testing.T) {
chainID := uint64(1)
// Current providers with mixed types
currentProviders := []params.RpcProvider{
*params.NewUserProvider(chainID, "UserProvider1", "https://userprovider1.example.com", true),
*params.NewUserProvider(chainID, "UserProvider2", "https://userprovider2.example.com", true),
*params.NewDirectProvider(chainID, "EmbeddedProvider1", "https://embeddedprovider1.example.com", true),
*params.NewProxyProvider(chainID, "EmbeddedProvider2", "https://embeddedprovider2.example.com", true),
}
currentProviders[1].Enabled = false // UserProvider2 is disabled
currentProviders[2].Enabled = false // EmbeddedProvider1 is disabled
// New providers to merge
newProviders := []params.RpcProvider{
*params.NewDirectProvider(chainID, "EmbeddedProvider1", "https://embeddedprovider1-new.example.com", true), // Should retain Enabled: false
*params.NewProxyProvider(chainID, "EmbeddedProvider3", "https://embeddedprovider3.example.com", true), // New embedded provider
*params.NewDirectProvider(chainID, "EmbeddedProvider4", "https://embeddedprovider4.example.com", true), // New embedded provider
}
// Call MergeProviders
mergedProviders := networkhelper.MergeProvidersPreservingUsersAndEnabledState(currentProviders, newProviders)
expectedEmbeddedProvider1 := newProviders[0]
expectedEmbeddedProvider1.Enabled = false // Should retain Enabled: false
// Expected providers after merging
expectedProviders := []params.RpcProvider{
currentProviders[0], // UserProvider1
currentProviders[1], // UserProvider2
expectedEmbeddedProvider1, // EmbeddedProvider1 (should retain Enabled: false)
newProviders[1], // EmbeddedProvider3
newProviders[2], // EmbeddedProvider4
}
// Assertions
require.True(t, reflect.DeepEqual(mergedProviders, expectedProviders), "Merged providers should match the expected providers")
}
func TestUpdateEmbeddedProxyProviders(t *testing.T) {
// Arrange: Create a sample list of networks with various provider types
networks := []params.Network{
*testutil.CreateNetwork(api.MainnetChainID, "Ethereum Mainnet", []params.RpcProvider{
*params.NewUserProvider(api.MainnetChainID, "Provider1", "https://userprovider.example.com", true),
*params.NewProxyProvider(api.MainnetChainID, "Provider2", "https://proxyprovider.example.com", true),
}),
*testutil.CreateNetwork(api.OptimismChainID, "Optimism", []params.RpcProvider{
*params.NewDirectProvider(api.OptimismChainID, "Provider3", "https://directprovider.example.com", true),
*params.NewProxyProvider(api.OptimismChainID, "Provider4", "https://proxyprovider2.example.com", true),
}),
}
networks[0].RpcProviders[1].Enabled = false
networks[1].RpcProviders[1].Enabled = false
user := gofakeit.Username()
password := gofakeit.LetterN(5)
// Call the function to update embedded-proxy providers
updatedNetworks := networkhelper.OverrideEmbeddedProxyProviders(networks, true, user, password)
// Verify the networks
for i, network := range updatedNetworks {
networkCopy := network
expectedNetwork := &networks[i]
testutil.CompareNetworks(t, expectedNetwork, &networkCopy)
for j, provider := range networkCopy.RpcProviders {
expectedProvider := expectedNetwork.RpcProviders[j]
if provider.Type == params.EmbeddedProxyProviderType {
assert.True(t, provider.Enabled, "Provider Enabled state should be overridden")
assert.Equal(t, user, provider.AuthLogin, "Provider AuthLogin should be overridden")
assert.Equal(t, password, provider.AuthPassword, "Provider AuthPassword should be overridden")
} else {
assert.Equal(t, expectedProvider.Enabled, provider.Enabled, "Provider Enabled state should remain unchanged")
assert.Equal(t, expectedProvider.AuthLogin, provider.AuthLogin, "Provider AuthLogin should remain unchanged")
assert.Equal(t, expectedProvider.AuthPassword, provider.AuthPassword, "Provider AuthPassword should remain unchanged")
}
}
}
}
func TestOverrideDirectProvidersAuth(t *testing.T) {
// Create a sample list of networks with various provider types
networks := []params.Network{
*testutil.CreateNetwork(api.MainnetChainID, "Ethereum Mainnet", []params.RpcProvider{
*params.NewUserProvider(api.MainnetChainID, "Provider1", "https://user.example.com/", true),
*params.NewDirectProvider(api.MainnetChainID, "Provider2", "https://mainnet.infura.io/v3/", true),
*params.NewDirectProvider(api.MainnetChainID, "Provider3", "https://eth-archival.rpc.grove.city/v1/", true),
}),
*testutil.CreateNetwork(api.OptimismChainID, "Optimism", []params.RpcProvider{
*params.NewDirectProvider(api.OptimismChainID, "Provider4", "https://optimism.infura.io/v3/", true),
*params.NewDirectProvider(api.OptimismChainID, "Provider5", "https://op.grove.city/v1/", true),
}),
}
authTokens := map[string]string{
"infura.io": gofakeit.UUID(),
"grove.city": gofakeit.UUID(),
"example.com": gofakeit.UUID(),
}
// Call OverrideDirectProvidersAuth
updatedNetworks := networkhelper.OverrideDirectProvidersAuth(networks, authTokens)
// Verify the networks have updated auth tokens correctly
for i, network := range updatedNetworks {
for j, provider := range network.RpcProviders {
expectedProvider := networks[i].RpcProviders[j]
switch {
case strings.Contains(provider.URL, "infura.io"):
assert.Equal(t, params.TokenAuth, provider.AuthType)
assert.Equal(t, authTokens["infura.io"], provider.AuthToken)
assert.NotEqual(t, expectedProvider.AuthToken, provider.AuthToken)
case strings.Contains(provider.URL, "grove.city"):
assert.Equal(t, params.TokenAuth, provider.AuthType)
assert.Equal(t, authTokens["grove.city"], provider.AuthToken)
assert.NotEqual(t, expectedProvider.AuthToken, provider.AuthToken)
case strings.Contains(provider.URL, "example.com"):
assert.Equal(t, params.NoAuth, provider.AuthType) // should not update user providers
default:
assert.Equal(t, expectedProvider.AuthType, provider.AuthType)
assert.Equal(t, expectedProvider.AuthToken, provider.AuthToken)
}
}
}
}
func TestOverrideGanacheTokenOverrides(t *testing.T) {
// Create a sample list of networks with various ChainIDs
networks := []params.Network{
*testutil.CreateNetwork(api.MainnetChainID, "Ethereum Mainnet", nil),
*testutil.CreateNetwork(api.OptimismChainID, "Optimism", nil),
*testutil.CreateNetwork(api.MainnetChainID, "Mainnet Duplicate", nil),
}
ganacheTokenOverride := params.TokenOverride{
Symbol: "SNT",
Address: common.HexToAddress("0x8571Ddc46b10d31EF963aF49b6C7799Ea7eff818"),
}
// Call OverrideGanacheTokenOverrides
updatedNetworks := networkhelper.OverrideGanacheToken(networks, "url", api.MainnetChainID, ganacheTokenOverride)
// Verify that only networks with the specified ChainID have the token override applied
for _, network := range updatedNetworks {
if network.ChainID == api.MainnetChainID {
require.NotNil(t, network.TokenOverrides, "TokenOverrides should not be nil for ChainID %d", network.ChainID)
assert.Contains(t, network.TokenOverrides, ganacheTokenOverride, "TokenOverrides should contain the ganache token")
} else {
assert.Nil(t, network.TokenOverrides, "TokenOverrides should be nil for ChainID %d", network.ChainID)
}
}
}

View File

@ -0,0 +1,49 @@
package networkhelper
import (
"gopkg.in/go-playground/validator.v9"
"github.com/status-im/status-go/params"
)
func GetValidator() *validator.Validate {
validate := validator.New()
// Register struct-level validation for RpcProvider
validate.RegisterStructValidation(rpcProviderStructLevelValidation, params.RpcProvider{})
return validate
}
func rpcProviderStructLevelValidation(sl validator.StructLevel) {
provider := sl.Current().Interface().(params.RpcProvider)
switch provider.AuthType {
case params.NoAuth:
if provider.AuthLogin != "" || provider.AuthPassword != "" || provider.AuthToken != "" {
sl.ReportError(provider.AuthLogin, "AuthLogin", "authLogin", "noauth_fields_empty", "")
sl.ReportError(provider.AuthPassword, "AuthPassword", "authPassword", "noauth_fields_empty", "")
sl.ReportError(provider.AuthToken, "AuthToken", "authToken", "noauth_fields_empty", "")
}
case params.BasicAuth:
if provider.AuthLogin == "" {
sl.ReportError(provider.AuthLogin, "AuthLogin", "authLogin", "required", "")
}
if provider.AuthPassword == "" {
sl.ReportError(provider.AuthPassword, "AuthPassword", "authPassword", "required", "")
}
if provider.AuthToken != "" {
sl.ReportError(provider.AuthToken, "AuthToken", "authToken", "basic_auth_token_empty", "")
}
case params.TokenAuth:
if provider.AuthToken == "" {
sl.ReportError(provider.AuthToken, "AuthToken", "authToken", "required", "")
}
if provider.AuthLogin != "" || provider.AuthPassword != "" {
sl.ReportError(provider.AuthLogin, "AuthLogin", "authLogin", "tokenauth_fields_empty", "")
sl.ReportError(provider.AuthPassword, "AuthPassword", "authPassword", "tokenauth_fields_empty", "")
}
default:
sl.ReportError(provider.AuthType, "AuthType", "authType", "invalid_auth_type", "")
}
}

View File

@ -0,0 +1,184 @@
package networkhelper
import (
"testing"
"github.com/stretchr/testify/require"
"gopkg.in/go-playground/validator.v9"
"github.com/status-im/status-go/params"
)
func TestValidation(t *testing.T) {
validate := GetValidator()
// Test cases for RpcProvider
providerTests := []struct {
name string
provider params.RpcProvider
expectErr bool
}{
{
name: "Valid Provider",
provider: params.RpcProvider{
ChainID: 1,
Name: "Mainnet Provider",
URL: "https://provider.example.com",
Type: params.UserProviderType,
Enabled: true,
AuthType: params.NoAuth,
EnableRPSLimiter: false,
},
expectErr: false,
},
{
name: "Missing Provider Name",
provider: params.RpcProvider{
ChainID: 1,
URL: "https://provider.example.com",
Type: params.UserProviderType,
},
expectErr: true,
},
{
name: "Invalid AuthType",
provider: params.RpcProvider{
ChainID: 1,
Name: "Invalid Auth Provider",
URL: "https://provider.example.com",
Type: params.UserProviderType,
AuthType: "invalid-auth-type",
},
expectErr: true,
},
{
name: "BasicAuth without Login",
provider: params.RpcProvider{
ChainID: 1,
Name: "BasicAuth Provider",
URL: "https://provider.example.com",
Type: params.UserProviderType,
AuthType: params.BasicAuth,
},
expectErr: true,
},
{
name: "TokenAuth without Token",
provider: params.RpcProvider{
ChainID: 1,
Name: "TokenAuth Provider",
URL: "https://provider.example.com",
Type: params.UserProviderType,
AuthType: params.TokenAuth,
},
expectErr: true,
},
{
name: "NoAuth with Login",
provider: params.RpcProvider{
ChainID: 1,
Name: "NoAuth Provider",
URL: "https://provider.example.com",
Type: params.UserProviderType,
AuthType: params.NoAuth,
AuthLogin: "user",
},
expectErr: true,
},
}
for _, test := range providerTests {
t.Run("RpcProvider: "+test.name, func(t *testing.T) {
err := validate.Struct(test.provider)
if test.expectErr {
require.Error(t, err, "Expected error but got nil for test case '%s'", test.name)
} else {
require.NoError(t, err, "Did not expect error but got '%v' for test case '%s'", err, test.name)
}
})
}
}
// Test cases for Network
func TestNetworkValidation(t *testing.T) {
validate := validator.New()
networkTests := []struct {
name string
network params.Network
expectErr bool
}{
{
name: "Valid Network",
network: params.Network{
ChainID: 1,
ChainName: "Ethereum Mainnet",
BlockExplorerURL: "https://etherscan.io",
NativeCurrencyName: "Ether",
NativeCurrencySymbol: "ETH",
NativeCurrencyDecimals: 18,
IsTest: false,
Layer: 1,
Enabled: true,
ChainColor: "#E90101",
ShortName: "eth",
RpcProviders: []params.RpcProvider{
{
ChainID: 1,
Name: "Mainnet Provider",
URL: "https://provider.example.com",
Type: params.UserProviderType,
Enabled: true,
AuthType: params.NoAuth,
},
},
},
expectErr: false,
},
{
name: "Missing Chain Name",
network: params.Network{
ChainID: 1,
RpcProviders: []params.RpcProvider{
{
ChainID: 1,
Name: "Mainnet Provider",
URL: "https://provider.example.com",
Type: params.UserProviderType,
Enabled: true,
},
},
},
expectErr: true,
},
{
name: "Invalid Provider in Network",
network: params.Network{
ChainID: 1,
ChainName: "Ethereum Mainnet",
RpcProviders: []params.RpcProvider{
{
ChainID: 1,
Name: "",
URL: "https://provider.example.com",
Type: params.UserProviderType,
Enabled: true,
},
},
},
expectErr: true,
},
}
for _, test := range networkTests {
t.Run("Network: "+test.name, func(t *testing.T) {
err := validate.Struct(test.network)
if test.expectErr {
require.Error(t, err, "Expected error but got nil for test case '%s'", test.name)
} else {
require.NoError(t, err, "Did not expect error but got '%v' for test case '%s'", err, test.name)
}
})
}
}

View File

@ -0,0 +1,257 @@
package db
import (
"fmt"
sq "github.com/Masterminds/squirrel"
"gopkg.in/go-playground/validator.v9"
"github.com/status-im/status-go/params"
"github.com/status-im/status-go/sqlite"
)
// NetworksPersistenceInterface describes the interface for managing networks and providers.
type NetworksPersistenceInterface interface {
// GetNetworks returns networks based on filters.
GetNetworks(onlyEnabled bool, chainID *uint64) ([]*params.Network, error)
// GetAllNetworks returns all networks.
GetAllNetworks() ([]*params.Network, error)
// GetNetworkByChainID returns a network by ChainID.
GetNetworkByChainID(chainID uint64) ([]*params.Network, error)
// GetEnabledNetworks returns enabled networks.
GetEnabledNetworks() ([]*params.Network, error)
// SetNetworks replaces all networks with new ones and their providers.
SetNetworks(networks []params.Network) error
// UpsertNetwork adds or updates a network and its providers.
UpsertNetwork(network *params.Network) error
// DeleteNetwork deletes a network by ChainID and its providers.
DeleteNetwork(chainID uint64) error
// DeleteAllNetworks deletes all networks and their providers.
DeleteAllNetworks() error
GetRpcPersistence() RpcProvidersPersistenceInterface
}
// NetworksPersistence manages networks and their providers.
type NetworksPersistence struct {
db sqlite.StatementExecutor
rpcPersistence RpcProvidersPersistenceInterface
validator *validator.Validate
}
// NewNetworksPersistence creates a new instance of NetworksPersistence.
func NewNetworksPersistence(db sqlite.StatementExecutor) *NetworksPersistence {
return &NetworksPersistence{
db: db,
rpcPersistence: NewRpcProvidersPersistence(db),
validator: validator.New(),
}
}
// GetRpcPersistence returns an instance of RpcProvidersPersistenceInterface.
func (n *NetworksPersistence) GetRpcPersistence() RpcProvidersPersistenceInterface {
return n.rpcPersistence
}
// GetNetworks returns networks based on filters.
func (n *NetworksPersistence) GetNetworks(onlyEnabled bool, chainID *uint64) ([]*params.Network, error) {
q := sq.Select(
"chain_id", "chain_name", "rpc_url", "fallback_url",
"block_explorer_url", "icon_url", "native_currency_name", "native_currency_symbol", "native_currency_decimals",
"is_test", "layer", "enabled", "chain_color", "short_name", "related_chain_id",
).
From("networks").
OrderBy("chain_id ASC")
if onlyEnabled {
q = q.Where(sq.Eq{"enabled": true})
}
if chainID != nil {
q = q.Where(sq.Eq{"chain_id": *chainID})
}
query, args, err := q.ToSql()
if err != nil {
return nil, err
}
rows, err := n.db.Query(query, args...)
if err != nil {
return nil, err
}
defer rows.Close()
result := make([]*params.Network, 0, 10)
for rows.Next() {
network := &params.Network{}
err := rows.Scan(
&network.ChainID, &network.ChainName, &network.RPCURL, &network.FallbackURL,
&network.BlockExplorerURL, &network.IconURL, &network.NativeCurrencyName, &network.NativeCurrencySymbol,
&network.NativeCurrencyDecimals, &network.IsTest, &network.Layer, &network.Enabled, &network.ChainColor,
&network.ShortName, &network.RelatedChainID,
)
if err != nil {
return nil, err
}
// Fetch RPC providers for the network
providers, err := n.rpcPersistence.GetRpcProviders(network.ChainID)
if err != nil {
return nil, fmt.Errorf("failed to fetch RPC providers for chain_id %d: %w", network.ChainID, err)
}
network.RpcProviders = providers
// Fill deprecated URLs if necessary (assuming this is a function you have)
FillDeprecatedURLs(network, providers)
result = append(result, network)
}
if err = rows.Err(); err != nil {
return nil, err
}
return result, nil
}
// GetAllNetworks returns all networks.
func (n *NetworksPersistence) GetAllNetworks() ([]*params.Network, error) {
return n.GetNetworks(false, nil)
}
// GetNetworkByChainID returns a network by ChainID.
func (n *NetworksPersistence) GetNetworkByChainID(chainID uint64) ([]*params.Network, error) {
return n.GetNetworks(false, &chainID)
}
// GetEnabledNetworks returns enabled networks.
func (n *NetworksPersistence) GetEnabledNetworks() ([]*params.Network, error) {
return n.GetNetworks(true, nil)
}
// SetNetworks replaces all networks with new ones and their providers.
// Note: Transaction management should be handled by the caller.
func (n *NetworksPersistence) SetNetworks(networks []params.Network) error {
// Delete all networks and their providers
err := n.DeleteAllNetworks()
if err != nil {
return err
}
// Upsert networks and their providers
for i := range networks {
err := n.UpsertNetwork(&networks[i])
if err != nil {
return fmt.Errorf("failed to upsert network with chain_id %d: %w", networks[i].ChainID, err)
}
}
return nil
}
// DeleteAllNetworks deletes all networks and their RPC providers.
// Note: Transaction management should be handled by the caller.
func (n *NetworksPersistence) DeleteAllNetworks() error {
// Delete all RPC providers
err := n.rpcPersistence.DeleteAllRpcProviders()
if err != nil {
return fmt.Errorf("failed to delete all RPC providers: %w", err)
}
// Delete all networks
q := sq.Delete("networks")
query, args, err := q.ToSql()
if err != nil {
return fmt.Errorf("failed to build delete query: %w", err)
}
_, err = n.db.Exec(query, args...)
if err != nil {
return fmt.Errorf("failed to execute delete query: %w", err)
}
return nil
}
// UpsertNetwork adds or updates a network and its providers.
// Note: Transaction management should be handled by the caller.
func (n *NetworksPersistence) UpsertNetwork(network *params.Network) error {
// Validate the network
if err := n.validator.Struct(network); err != nil {
return fmt.Errorf("network validation failed: %w", err)
}
// Upsert the network
err := n.upsertNetwork(network)
if err != nil {
return fmt.Errorf("failed to upsert network for chain_id %d: %w", network.ChainID, err)
}
// Set the RPC providers
err = n.rpcPersistence.SetRpcProviders(network.ChainID, network.RpcProviders)
if err != nil {
return fmt.Errorf("failed to set RPC providers for chain_id %d: %w", network.ChainID, err)
}
return nil
}
// upsertNetwork handles the logic for inserting or updating a network record.
func (n *NetworksPersistence) upsertNetwork(network *params.Network) error {
q := sq.Insert("networks").
Columns(
"chain_id", "chain_name", "rpc_url", "original_rpc_url", "fallback_url", "original_fallback_url",
"block_explorer_url", "icon_url", "native_currency_name", "native_currency_symbol", "native_currency_decimals",
"is_test", "layer", "enabled", "chain_color", "short_name", "related_chain_id",
).
Values(
network.ChainID, network.ChainName, network.RPCURL, network.OriginalRPCURL, network.FallbackURL, network.OriginalFallbackURL,
network.BlockExplorerURL, network.IconURL, network.NativeCurrencyName, network.NativeCurrencySymbol, network.NativeCurrencyDecimals,
network.IsTest, network.Layer, network.Enabled, network.ChainColor, network.ShortName, network.RelatedChainID,
).
Suffix("ON CONFLICT(chain_id) DO UPDATE SET " +
"chain_name = excluded.chain_name, rpc_url = excluded.rpc_url, original_rpc_url = excluded.original_rpc_url, " +
"fallback_url = excluded.fallback_url, original_fallback_url = excluded.original_fallback_url, " +
"block_explorer_url = excluded.block_explorer_url, icon_url = excluded.icon_url, " +
"native_currency_name = excluded.native_currency_name, native_currency_symbol = excluded.native_currency_symbol, " +
"native_currency_decimals = excluded.native_currency_decimals, is_test = excluded.is_test, " +
"layer = excluded.layer, enabled = excluded.enabled, chain_color = excluded.chain_color, " +
"short_name = excluded.short_name, related_chain_id = excluded.related_chain_id")
query, args, err := q.ToSql()
if err != nil {
return fmt.Errorf("failed to build upsert query: %w", err)
}
_, err = n.db.Exec(query, args...)
if err != nil {
return fmt.Errorf("failed to execute upsert query: %w", err)
}
return nil
}
// DeleteNetwork deletes a network and its associated RPC providers.
// Note: Transaction management should be handled by the caller.
func (n *NetworksPersistence) DeleteNetwork(chainID uint64) error {
// Delete RPC providers
err := n.rpcPersistence.DeleteRpcProviders(chainID)
if err != nil {
return fmt.Errorf("failed to delete RPC providers for chain_id %d: %w", chainID, err)
}
// Delete the network record
q := sq.Delete("networks").Where(sq.Eq{"chain_id": chainID})
query, args, err := q.ToSql()
if err != nil {
return fmt.Errorf("failed to build delete query: %w", err)
}
_, err = n.db.Exec(query, args...)
if err != nil {
return fmt.Errorf("failed to execute delete query for chain_id %d: %w", chainID, err)
}
return nil
}

View File

@ -0,0 +1,197 @@
package db_test
import (
"database/sql"
"testing"
"github.com/stretchr/testify/require"
"github.com/stretchr/testify/suite"
"github.com/status-im/status-go/api"
"github.com/status-im/status-go/appdatabase"
"github.com/status-im/status-go/params"
"github.com/status-im/status-go/rpc/network/db"
"github.com/status-im/status-go/rpc/network/testutil"
"github.com/status-im/status-go/t/helpers"
)
type NetworksPersistenceTestSuite struct {
suite.Suite
db *sql.DB
cleanup func() error
networksPersistence db.NetworksPersistenceInterface
}
func (s *NetworksPersistenceTestSuite) SetupTest() {
memDb, cleanup, err := helpers.SetupTestSQLDB(appdatabase.DbInitializer{}, "networks-tests")
s.Require().NoError(err)
s.db = memDb
s.cleanup = cleanup
s.networksPersistence = db.NewNetworksPersistence(memDb)
}
func (s *NetworksPersistenceTestSuite) TearDownTest() {
if s.cleanup != nil {
err := s.cleanup()
require.NoError(s.T(), err)
}
}
func TestNetworksPersistenceTestSuite(t *testing.T) {
suite.Run(t, new(NetworksPersistenceTestSuite))
}
// Helper function to create default providers for a given chainID
func DefaultProviders(chainID uint64) []params.RpcProvider {
return []params.RpcProvider{
{
Name: "Provider1",
ChainID: chainID,
URL: "https://rpc.provider1.io",
Type: params.UserProviderType,
Enabled: true,
AuthType: params.NoAuth,
},
{
Name: "Provider2",
ChainID: chainID,
URL: "https://rpc.provider2.io",
Type: params.EmbeddedProxyProviderType,
Enabled: true,
AuthType: params.BasicAuth,
AuthLogin: "user1",
AuthPassword: "password1",
},
}
}
// Helper function to add and verify networks
func (s *NetworksPersistenceTestSuite) addAndVerifyNetworks(networks []*params.Network) {
networkValues := make([]params.Network, 0, len(networks))
for _, network := range networks {
networkValues = append(networkValues, *network)
}
err := s.networksPersistence.SetNetworks(networkValues)
s.Require().NoError(err)
s.verifyNetworks(networks)
}
// Helper function to verify networks against the database
func (s *NetworksPersistenceTestSuite) verifyNetworks(networks []*params.Network) {
allNetworks, err := s.networksPersistence.GetAllNetworks()
s.Require().NoError(err)
testutil.CompareNetworksList(s.T(), networks, allNetworks)
}
// Helper function to verify network deletion
func (s *NetworksPersistenceTestSuite) verifyNetworkDeletion(chainID uint64) {
nets, err := s.networksPersistence.GetNetworkByChainID(chainID)
s.Require().NoError(err)
s.Require().Len(nets, 0)
providers, err := s.networksPersistence.GetRpcPersistence().GetRpcProviders(chainID)
s.Require().NoError(err)
s.Require().Len(providers, 0)
}
// Tests
func (s *NetworksPersistenceTestSuite) TestAddAndGetNetworkWithProviders() {
network := testutil.CreateNetwork(api.OptimismChainID, "Optimism Mainnet", []params.RpcProvider{
testutil.CreateProvider(api.OptimismChainID, "Provider1", params.UserProviderType, true, "https://rpc.optimism.io"),
testutil.CreateProvider(api.OptimismChainID, "Provider2", params.EmbeddedProxyProviderType, false, "https://backup.optimism.io"),
})
s.addAndVerifyNetworks([]*params.Network{network})
}
func (s *NetworksPersistenceTestSuite) TestDeleteNetworkWithProviders() {
network := testutil.CreateNetwork(api.OptimismChainID, "Optimism Mainnet", DefaultProviders(api.OptimismChainID))
s.addAndVerifyNetworks([]*params.Network{network})
err := s.networksPersistence.DeleteNetwork(network.ChainID)
s.Require().NoError(err)
s.verifyNetworkDeletion(network.ChainID)
}
func (s *NetworksPersistenceTestSuite) TestUpdateNetworkAndProviders() {
network := testutil.CreateNetwork(api.OptimismChainID, "Optimism Mainnet", DefaultProviders(api.OptimismChainID))
s.addAndVerifyNetworks([]*params.Network{network})
// Update fields
network.ChainName = "Updated Optimism Mainnet"
network.RpcProviders = []params.RpcProvider{
testutil.CreateProvider(api.OptimismChainID, "UpdatedProvider", params.UserProviderType, true, "https://rpc.optimism.updated.io"),
}
s.addAndVerifyNetworks([]*params.Network{network})
}
func (s *NetworksPersistenceTestSuite) TestDeleteAllNetworks() {
networks := []*params.Network{
testutil.CreateNetwork(api.MainnetChainID, "Ethereum Mainnet", DefaultProviders(api.MainnetChainID)),
testutil.CreateNetwork(api.SepoliaChainID, "Sepolia Testnet", DefaultProviders(api.SepoliaChainID)),
}
s.addAndVerifyNetworks(networks)
err := s.networksPersistence.DeleteAllNetworks()
s.Require().NoError(err)
allNetworks, err := s.networksPersistence.GetAllNetworks()
s.Require().NoError(err)
s.Require().Len(allNetworks, 0)
}
func (s *NetworksPersistenceTestSuite) TestSetNetworks() {
initialNetworks := []*params.Network{
testutil.CreateNetwork(api.MainnetChainID, "Ethereum Mainnet", DefaultProviders(api.MainnetChainID)),
testutil.CreateNetwork(api.SepoliaChainID, "Sepolia Testnet", DefaultProviders(api.SepoliaChainID)),
}
newNetworks := []*params.Network{
testutil.CreateNetwork(api.OptimismChainID, "Optimism Mainnet", DefaultProviders(api.OptimismChainID)),
}
// Add initial networks
s.addAndVerifyNetworks(initialNetworks)
// Replace with new networks
s.addAndVerifyNetworks(newNetworks)
// Verify old networks are removed
s.verifyNetworkDeletion(api.MainnetChainID)
s.verifyNetworkDeletion(api.SepoliaChainID)
}
func (s *NetworksPersistenceTestSuite) TestValidationForNetworksAndProviders() {
// Invalid Network: Missing required ChainName
invalidNetwork := testutil.CreateNetwork(api.MainnetChainID, "", DefaultProviders(api.MainnetChainID))
// Invalid Provider: Missing URL
invalidProvider := params.RpcProvider{
Name: "InvalidProvider",
ChainID: api.MainnetChainID,
URL: "", // Invalid
Type: params.UserProviderType,
Enabled: true,
}
// Add invalid provider to a valid network
validNetworkWithInvalidProvider := testutil.CreateNetwork(api.OptimismChainID, "Optimism Mainnet", []params.RpcProvider{invalidProvider})
// Invalid networks and providers should fail validation
networksToValidate := []*params.Network{
invalidNetwork,
validNetworkWithInvalidProvider,
}
for _, network := range networksToValidate {
err := s.networksPersistence.UpsertNetwork(network)
s.Require().Error(err, "Expected validation to fail for invalid network or provider")
}
// Ensure no invalid data is saved in the database
allNetworks, err := s.networksPersistence.GetAllNetworks()
s.Require().NoError(err)
s.Require().Len(allNetworks, 0, "No invalid networks should be saved")
}

View File

@ -0,0 +1,243 @@
package db
import (
"fmt"
sq "github.com/Masterminds/squirrel"
"gopkg.in/go-playground/validator.v9"
"github.com/status-im/status-go/params"
"github.com/status-im/status-go/sqlite"
)
// Interface for managing RPC providers
type RpcProvidersPersistenceInterface interface {
GetRpcProviders(chainID uint64) ([]params.RpcProvider, error)
GetRpcProvidersByType(chainID uint64, providerType params.RpcProviderType) ([]params.RpcProvider, error)
AddRpcProvider(provider params.RpcProvider) error
DeleteRpcProviders(chainID uint64) error
DeleteAllRpcProviders() error
UpdateRpcProvider(provider params.RpcProvider) error
SetRpcProviders(chainID uint64, newProviders []params.RpcProvider) error
}
// Struct for managing RPC providers
type RpcProvidersPersistence struct {
db sqlite.StatementExecutor
validator *validator.Validate
}
// Constructor for RpcProvidersPersistence with validator
func NewRpcProvidersPersistence(db sqlite.StatementExecutor) *RpcProvidersPersistence {
return &RpcProvidersPersistence{
db: db,
validator: validator.New(),
}
}
// Retrieve all providers for a specific ChainID
func (p *RpcProvidersPersistence) GetRpcProviders(chainID uint64) ([]params.RpcProvider, error) {
q := sq.Select(
"id",
"chain_id",
"name",
"url",
"enable_rps_limiter",
"type",
"enabled",
"auth_type",
"auth_login",
"auth_password",
"auth_token",
).
From("rpc_providers").
Where(sq.Eq{"chain_id": chainID}).
OrderBy("id ASC")
query, args, err := q.ToSql()
if err != nil {
return nil, fmt.Errorf("failed to build query: %w", err)
}
rows, err := p.db.Query(query, args...)
if err != nil {
return nil, fmt.Errorf("failed to execute query: %w", err)
}
defer rows.Close()
var providers []params.RpcProvider
for rows.Next() {
var provider params.RpcProvider
err := rows.Scan(
&provider.ID,
&provider.ChainID,
&provider.Name,
&provider.URL,
&provider.EnableRPSLimiter,
&provider.Type,
&provider.Enabled,
&provider.AuthType,
&provider.AuthLogin,
&provider.AuthPassword,
&provider.AuthToken,
)
if err != nil {
return nil, fmt.Errorf("failed to scan row: %w", err)
}
providers = append(providers, provider)
}
if err = rows.Err(); err != nil {
return nil, fmt.Errorf("row iteration error: %w", err)
}
return providers, nil
}
// Retrieve providers of a specific type
func (p *RpcProvidersPersistence) GetRpcProvidersByType(chainID uint64, providerType params.RpcProviderType) ([]params.RpcProvider, error) {
allProviders, err := p.GetRpcProviders(chainID)
if err != nil {
return nil, err
}
result := make([]params.RpcProvider, 0, len(allProviders))
for _, provider := range allProviders {
if provider.Type == providerType {
result = append(result, provider)
}
}
return result, nil
}
// Add a new provider
func (p *RpcProvidersPersistence) AddRpcProvider(provider params.RpcProvider) error {
// Validate the provider
if err := p.validator.Struct(provider); err != nil {
return fmt.Errorf("validation failed: %w", err)
}
// Proceed with adding the provider to the database
q := sq.Insert("rpc_providers").
Columns(
"chain_id",
"name",
"url",
"enable_rps_limiter",
"type",
"enabled",
"auth_type",
"auth_login",
"auth_password",
"auth_token",
).
Values(
provider.ChainID,
provider.Name,
provider.URL,
provider.EnableRPSLimiter,
provider.Type,
provider.Enabled,
provider.AuthType,
provider.AuthLogin,
provider.AuthPassword,
provider.AuthToken,
)
query, args, err := q.ToSql()
if err != nil {
return fmt.Errorf("failed to build insert query: %w", err)
}
_, err = p.db.Exec(query, args...)
if err != nil {
return fmt.Errorf("failed to execute insert query: %w", err)
}
return nil
}
// Delete providers for a specific ChainID
func (p *RpcProvidersPersistence) DeleteRpcProviders(chainID uint64) error {
q := sq.Delete("rpc_providers").
Where(sq.Eq{"chain_id": chainID})
query, args, err := q.ToSql()
if err != nil {
return fmt.Errorf("failed to build delete query: %w", err)
}
_, err = p.db.Exec(query, args...)
if err != nil {
return fmt.Errorf("failed to execute delete query: %w", err)
}
return nil
}
// Delete all providers
func (p *RpcProvidersPersistence) DeleteAllRpcProviders() error {
q := sq.Delete("rpc_providers")
query, args, err := q.ToSql()
if err != nil {
return fmt.Errorf("failed to build delete query: %w", err)
}
_, err = p.db.Exec(query, args...)
if err != nil {
return fmt.Errorf("failed to execute delete query: %w", err)
}
return nil
}
// Update an existing provider
func (p *RpcProvidersPersistence) UpdateRpcProvider(provider params.RpcProvider) error {
// Validate the provider
if err := p.validator.Struct(provider); err != nil {
return fmt.Errorf("validation failed: %w", err)
}
// Proceed with updating the provider in the database
q := sq.Update("rpc_providers").
SetMap(sq.Eq{
"url": provider.URL,
"enable_rps_limiter": provider.EnableRPSLimiter,
"type": provider.Type,
"enabled": provider.Enabled,
"auth_type": provider.AuthType,
"auth_login": provider.AuthLogin,
"auth_password": provider.AuthPassword,
"auth_token": provider.AuthToken,
}).
Where(sq.Eq{"id": provider.ID})
query, args, err := q.ToSql()
if err != nil {
return fmt.Errorf("failed to build update query: %w", err)
}
_, err = p.db.Exec(query, args...)
if err != nil {
return fmt.Errorf("failed to execute update query: %w", err)
}
return nil
}
// Set the list of providers for a ChainID, replacing any existing providers
func (p *RpcProvidersPersistence) SetRpcProviders(chainID uint64, newProviders []params.RpcProvider) error {
if err := p.DeleteRpcProviders(chainID); err != nil {
return fmt.Errorf("failed to delete existing providers: %w", err)
}
for _, provider := range newProviders {
if err := p.AddRpcProvider(provider); err != nil {
return fmt.Errorf("failed to add new provider: %w", err)
}
}
return nil
}

View File

@ -0,0 +1,157 @@
package db_test
import (
"database/sql"
"testing"
"github.com/stretchr/testify/require"
"github.com/stretchr/testify/suite"
"github.com/status-im/status-go/api"
"github.com/status-im/status-go/appdatabase"
"github.com/status-im/status-go/params"
"github.com/status-im/status-go/rpc/network/db"
"github.com/status-im/status-go/rpc/network/testutil"
"github.com/status-im/status-go/t/helpers"
)
type RpcProviderPersistenceTestSuite struct {
suite.Suite
db *sql.DB
rpcPersistence *db.RpcProvidersPersistence
}
func (s *RpcProviderPersistenceTestSuite) SetupTest() {
testDb := setupTestNetworkDB(s.T())
s.db = testDb
s.rpcPersistence = db.NewRpcProvidersPersistence(testDb)
}
func setupTestNetworkDB(t *testing.T) *sql.DB {
testDb, cleanup, err := helpers.SetupTestSQLDB(appdatabase.DbInitializer{}, "rpc-providers-tests")
require.NoError(t, err)
t.Cleanup(func() { require.NoError(t, cleanup()) })
return testDb
}
func TestRpcProviderPersistenceTestSuite(t *testing.T) {
suite.Run(t, new(RpcProviderPersistenceTestSuite))
}
// Test cases
func (s *RpcProviderPersistenceTestSuite) TestAddAndGetRpcProvider() {
provider := testutil.CreateProvider(api.MainnetChainID, "Provider1", params.UserProviderType, true, "https://provider1.example.com")
err := s.rpcPersistence.AddRpcProvider(provider)
s.Require().NoError(err)
// Verify the added provider
providers, err := s.rpcPersistence.GetRpcProviders(api.MainnetChainID)
s.Require().NoError(err)
testutil.CompareProvidersList(s.T(), []params.RpcProvider{provider}, providers)
}
func (s *RpcProviderPersistenceTestSuite) TestGetRpcProvidersByType() {
providers := []params.RpcProvider{
testutil.CreateProvider(api.MainnetChainID, "UserProvider1", params.UserProviderType, true, "https://provider1.example.com"),
testutil.CreateProvider(api.MainnetChainID, "EmbeddedDirect1", params.EmbeddedDirectProviderType, false, "https://provider2.example.com"),
testutil.CreateProvider(api.MainnetChainID, "UserProvider2", params.UserProviderType, false, "https://provider3.example.com"),
testutil.CreateProvider(api.MainnetChainID, "EmbeddedProxy1", params.EmbeddedProxyProviderType, true, "https://provider4.example.com"),
}
for _, provider := range providers {
err := s.rpcPersistence.AddRpcProvider(provider)
s.Require().NoError(err)
}
// Verify by type
userProviders, err := s.rpcPersistence.GetRpcProvidersByType(api.MainnetChainID, params.UserProviderType)
s.Require().NoError(err)
testutil.CompareProvidersList(s.T(), []params.RpcProvider{providers[0], providers[2]}, userProviders)
embeddedDirectProviders, err := s.rpcPersistence.GetRpcProvidersByType(api.MainnetChainID, params.EmbeddedDirectProviderType)
s.Require().NoError(err)
testutil.CompareProvidersList(s.T(), []params.RpcProvider{providers[1]}, embeddedDirectProviders)
embeddedProxyProviders, err := s.rpcPersistence.GetRpcProvidersByType(api.MainnetChainID, params.EmbeddedProxyProviderType)
s.Require().NoError(err)
testutil.CompareProvidersList(s.T(), []params.RpcProvider{providers[3]}, embeddedProxyProviders)
}
func (s *RpcProviderPersistenceTestSuite) TestDeleteRpcProviders() {
provider := testutil.CreateProvider(api.MainnetChainID, "Provider1", params.UserProviderType, true, "https://provider1.example.com")
err := s.rpcPersistence.AddRpcProvider(provider)
s.Require().NoError(err)
err = s.rpcPersistence.DeleteRpcProviders(api.MainnetChainID)
s.Require().NoError(err)
// Verify deletion
providers, err := s.rpcPersistence.GetRpcProviders(api.MainnetChainID)
s.Require().NoError(err)
s.Require().Empty(providers)
}
func (s *RpcProviderPersistenceTestSuite) TestUpdateRpcProvider() {
provider := testutil.CreateProvider(api.MainnetChainID, "Provider1", params.UserProviderType, true, "https://provider1.example.com")
err := s.rpcPersistence.AddRpcProvider(provider)
s.Require().NoError(err)
// Retrieve provider to get the ID
providers, err := s.rpcPersistence.GetRpcProviders(api.MainnetChainID)
s.Require().NoError(err)
s.Require().Len(providers, 1)
provider.ID = providers[0].ID
provider.URL = "https://provider1-updated.example.com"
provider.EnableRPSLimiter = false
err = s.rpcPersistence.UpdateRpcProvider(provider)
s.Require().NoError(err)
// Verify update
updatedProviders, err := s.rpcPersistence.GetRpcProviders(api.MainnetChainID)
s.Require().NoError(err)
testutil.CompareProvidersList(s.T(), []params.RpcProvider{provider}, updatedProviders)
}
func (s *RpcProviderPersistenceTestSuite) TestSetRpcProviders() {
initialProviders := []params.RpcProvider{
testutil.CreateProvider(api.MainnetChainID, "Provider1", params.UserProviderType, true, "https://provider1.example.com"),
testutil.CreateProvider(api.MainnetChainID, "Provider2", params.EmbeddedDirectProviderType, false, "https://provider2.example.com"),
}
for _, provider := range initialProviders {
err := s.rpcPersistence.AddRpcProvider(provider)
s.Require().NoError(err)
}
newProviders := []params.RpcProvider{
testutil.CreateProvider(api.MainnetChainID, "NewProvider1", params.UserProviderType, true, "https://newprovider1.example.com"),
testutil.CreateProvider(api.MainnetChainID, "NewProvider2", params.EmbeddedProxyProviderType, true, "https://newprovider2.example.com"),
}
err := s.rpcPersistence.SetRpcProviders(api.MainnetChainID, newProviders)
s.Require().NoError(err)
// Verify replacement
providers, err := s.rpcPersistence.GetRpcProviders(api.MainnetChainID)
s.Require().NoError(err)
testutil.CompareProvidersList(s.T(), newProviders, providers)
}
func (s *RpcProviderPersistenceTestSuite) TestAddRpcProviderValidation() {
invalidProvider := params.RpcProvider{
ChainID: 0, // Invalid: must be greater than 0
Name: "", // Invalid: cannot be empty
URL: "invalid-url", // Invalid: not a valid URL
Type: "invalid-type", // Invalid: not in allowed values
}
err := s.rpcPersistence.AddRpcProvider(invalidProvider)
s.Require().Error(err)
s.Contains(err.Error(), "validation failed")
}

57
rpc/network/db/utils.go Normal file
View File

@ -0,0 +1,57 @@
package db
import (
"github.com/status-im/status-go/params"
)
// Deprecated: fillDeprecatedURLs populates the `original_rpc_url`, `original_fallback_url`, `rpc_url`,
// `fallback_url`, `defaultRpcUrl`, `defaultFallbackURL`, and `defaultFallbackURL2` fields.
// Keep for backwrad compatibility until it's fully integrated
func FillDeprecatedURLs(network *params.Network, providers []params.RpcProvider) {
var embeddedDirect []params.RpcProvider
var embeddedProxy []params.RpcProvider
var userProviders []params.RpcProvider
// Categorize providers
for _, provider := range providers {
switch provider.Type {
case params.EmbeddedDirectProviderType:
embeddedDirect = append(embeddedDirect, provider)
case params.EmbeddedProxyProviderType:
embeddedProxy = append(embeddedProxy, provider)
case params.UserProviderType:
userProviders = append(userProviders, provider)
}
}
// Set original_*_url fields based on EmbeddedDirectProviderType providers
if len(embeddedDirect) > 0 {
network.OriginalRPCURL = embeddedDirect[0].URL
if len(embeddedDirect) > 1 {
network.OriginalFallbackURL = embeddedDirect[1].URL
}
}
// Set rpc_url and fallback_url based on User providers or EmbeddedDirectProviderType if no User providers exist
if len(userProviders) > 0 {
network.RPCURL = userProviders[0].URL
if len(userProviders) > 1 {
network.FallbackURL = userProviders[1].URL
}
} else {
// Default to EmbeddedDirectProviderType providers if no User providers exist
network.RPCURL = network.OriginalRPCURL
network.FallbackURL = network.OriginalFallbackURL
}
// Set default_*_url fields based on EmbeddedProxyProviderType providers
if len(embeddedProxy) > 0 {
network.DefaultRPCURL = embeddedProxy[0].URL
if len(embeddedProxy) > 1 {
network.DefaultFallbackURL = embeddedProxy[1].URL
}
if len(embeddedProxy) > 2 {
network.DefaultFallbackURL2 = embeddedProxy[2].URL
}
}
}

View File

@ -0,0 +1,104 @@
package testutil
import (
"github.com/stretchr/testify/require"
"github.com/status-im/status-go/api"
"github.com/status-im/status-go/params"
)
// Helper function to create a provider
func CreateProvider(chainID uint64, name string, providerType params.RpcProviderType, enabled bool, url string) params.RpcProvider {
return params.RpcProvider{
ChainID: chainID,
Name: name,
URL: url,
EnableRPSLimiter: true,
Type: providerType,
Enabled: enabled,
AuthType: params.BasicAuth,
AuthLogin: "user1",
AuthPassword: "password1",
AuthToken: "",
}
}
// Helper function to create a network
func CreateNetwork(chainID uint64, chainName string, providers []params.RpcProvider) *params.Network {
return &params.Network{
ChainID: chainID,
ChainName: chainName,
BlockExplorerURL: "https://explorer.example.com",
IconURL: "network/Network=" + chainName,
NativeCurrencyName: "Ether",
NativeCurrencySymbol: "ETH",
NativeCurrencyDecimals: 18,
IsTest: false,
Layer: 2,
Enabled: true,
ChainColor: "#E90101",
ShortName: "eth",
RelatedChainID: api.OptimismSepoliaChainID,
RpcProviders: providers,
}
}
// Helper function to compare two providers
func CompareProviders(t require.TestingT, expected, actual params.RpcProvider) {
require.Equal(t, expected.ChainID, actual.ChainID)
require.Equal(t, expected.Name, actual.Name)
require.Equal(t, expected.URL, actual.URL)
require.Equal(t, expected.EnableRPSLimiter, actual.EnableRPSLimiter)
require.Equal(t, expected.Type, actual.Type)
require.Equal(t, expected.Enabled, actual.Enabled)
require.Equal(t, expected.AuthType, actual.AuthType)
require.Equal(t, expected.AuthLogin, actual.AuthLogin)
require.Equal(t, expected.AuthPassword, actual.AuthPassword)
require.Equal(t, expected.AuthToken, actual.AuthToken)
}
// Helper function to compare two networks
func CompareNetworks(t require.TestingT, expected, actual *params.Network) {
require.Equal(t, expected.ChainID, actual.ChainID, "ChainID does not match")
require.Equal(t, expected.ChainName, actual.ChainName, "ChainName does not match for ChainID %d", actual.ChainID)
require.Equal(t, expected.BlockExplorerURL, actual.BlockExplorerURL)
require.Equal(t, expected.NativeCurrencyName, actual.NativeCurrencyName)
require.Equal(t, expected.NativeCurrencySymbol, actual.NativeCurrencySymbol)
require.Equal(t, expected.NativeCurrencyDecimals, actual.NativeCurrencyDecimals)
require.Equal(t, expected.IsTest, actual.IsTest)
require.Equal(t, expected.Layer, actual.Layer)
require.Equal(t, expected.Enabled, actual.Enabled)
require.Equal(t, expected.ChainColor, actual.ChainColor)
require.Equal(t, expected.ShortName, actual.ShortName)
require.Equal(t, expected.RelatedChainID, actual.RelatedChainID)
}
// Helper function to compare lists of providers
func CompareProvidersList(t require.TestingT, expectedProviders, actualProviders []params.RpcProvider) {
require.Len(t, actualProviders, len(expectedProviders))
expectedMap := make(map[string]params.RpcProvider, len(expectedProviders))
for _, provider := range expectedProviders {
expectedMap[provider.Name] = provider
}
for _, provider := range actualProviders {
expectedProvider, exists := expectedMap[provider.Name]
require.True(t, exists, "Unexpected provider '%s'", provider.Name)
CompareProviders(t, expectedProvider, provider)
}
}
// Helper function to compare lists of networks
func CompareNetworksList(t require.TestingT, expectedNetworks, actualNetworks []*params.Network) {
require.Len(t, actualNetworks, len(expectedNetworks))
expectedMap := make(map[uint64]*params.Network, len(expectedNetworks))
for _, network := range expectedNetworks {
expectedMap[network.ChainID] = network
}
for _, network := range actualNetworks {
expectedNetwork, exists := expectedMap[network.ChainID]
require.True(t, exists, "Unexpected network with ChainID %d", network.ChainID)
CompareNetworks(t, expectedNetwork, network)
}
}

View File

@ -6,3 +6,9 @@ import "database/sql"
type StatementCreator interface {
Prepare(query string) (*sql.Stmt, error)
}
type StatementExecutor interface {
StatementCreator
Exec(query string, args ...interface{}) (sql.Result, error)
Query(query string, args ...interface{}) (*sql.Rows, error)
}