mirror of
https://github.com/status-im/status-go.git
synced 2025-01-17 18:22:13 +00:00
chore(config)_: rpc providers configuration (#6151)
* chore(config)_: extract rpc_provider_persistence + tests * Add rpc_providers table, migration * add RpcProvider type * deprecate old rpc fields in networks, add RpcProviders list * add persistence packages for rpc_providers, networks * Tests
This commit is contained in:
parent
90ce72a2d5
commit
e9abf1662d
@ -1579,12 +1579,12 @@ func TestWalletConfigOnLoginAccount(t *testing.T) {
|
||||
}
|
||||
|
||||
require.Equal(t, b.config.WalletConfig.InfuraAPIKey, infuraToken)
|
||||
require.Equal(t, b.config.WalletConfig.AlchemyAPIKeys[mainnetChainID], alchemyEthereumMainnetToken)
|
||||
require.Equal(t, b.config.WalletConfig.AlchemyAPIKeys[sepoliaChainID], alchemyEthereumSepoliaToken)
|
||||
require.Equal(t, b.config.WalletConfig.AlchemyAPIKeys[arbitrumChainID], alchemyArbitrumMainnetToken)
|
||||
require.Equal(t, b.config.WalletConfig.AlchemyAPIKeys[arbitrumSepoliaChainID], alchemyArbitrumSepoliaToken)
|
||||
require.Equal(t, b.config.WalletConfig.AlchemyAPIKeys[optimismChainID], alchemyOptimismMainnetToken)
|
||||
require.Equal(t, b.config.WalletConfig.AlchemyAPIKeys[optimismSepoliaChainID], alchemyOptimismSepoliaToken)
|
||||
require.Equal(t, b.config.WalletConfig.AlchemyAPIKeys[MainnetChainID], alchemyEthereumMainnetToken)
|
||||
require.Equal(t, b.config.WalletConfig.AlchemyAPIKeys[SepoliaChainID], alchemyEthereumSepoliaToken)
|
||||
require.Equal(t, b.config.WalletConfig.AlchemyAPIKeys[ArbitrumChainID], alchemyArbitrumMainnetToken)
|
||||
require.Equal(t, b.config.WalletConfig.AlchemyAPIKeys[ArbitrumSepoliaChainID], alchemyArbitrumSepoliaToken)
|
||||
require.Equal(t, b.config.WalletConfig.AlchemyAPIKeys[OptimismChainID], alchemyOptimismMainnetToken)
|
||||
require.Equal(t, b.config.WalletConfig.AlchemyAPIKeys[OptimismSepoliaChainID], alchemyOptimismSepoliaToken)
|
||||
require.Equal(t, b.config.WalletConfig.RaribleMainnetAPIKey, raribleMainnetAPIKey)
|
||||
require.Equal(t, b.config.WalletConfig.RaribleTestnetAPIKey, raribleTestnetAPIKey)
|
||||
|
||||
|
@ -10,12 +10,12 @@ import (
|
||||
)
|
||||
|
||||
const (
|
||||
mainnetChainID uint64 = 1
|
||||
sepoliaChainID uint64 = 11155111
|
||||
optimismChainID uint64 = 10
|
||||
optimismSepoliaChainID uint64 = 11155420
|
||||
arbitrumChainID uint64 = 42161
|
||||
arbitrumSepoliaChainID uint64 = 421614
|
||||
MainnetChainID uint64 = 1
|
||||
SepoliaChainID uint64 = 11155111
|
||||
OptimismChainID uint64 = 10
|
||||
OptimismSepoliaChainID uint64 = 11155420
|
||||
ArbitrumChainID uint64 = 42161
|
||||
ArbitrumSepoliaChainID uint64 = 421614
|
||||
sntSymbol = "SNT"
|
||||
sttSymbol = "STT"
|
||||
)
|
||||
@ -24,7 +24,7 @@ var ganacheTokenAddress = common.HexToAddress("0x8571Ddc46b10d31EF963aF49b6C7799
|
||||
|
||||
func mainnet(stageName string) params.Network {
|
||||
return params.Network{
|
||||
ChainID: mainnetChainID,
|
||||
ChainID: MainnetChainID,
|
||||
ChainName: "Mainnet",
|
||||
DefaultRPCURL: fmt.Sprintf("https://%s.api.status.im/nodefleet/ethereum/mainnet/", stageName),
|
||||
DefaultFallbackURL: fmt.Sprintf("https://%s.api.status.im/infura/ethereum/mainnet/", stageName),
|
||||
@ -41,13 +41,13 @@ func mainnet(stageName string) params.Network {
|
||||
IsTest: false,
|
||||
Layer: 1,
|
||||
Enabled: true,
|
||||
RelatedChainID: sepoliaChainID,
|
||||
RelatedChainID: SepoliaChainID,
|
||||
}
|
||||
}
|
||||
|
||||
func sepolia(stageName string) params.Network {
|
||||
return params.Network{
|
||||
ChainID: sepoliaChainID,
|
||||
ChainID: SepoliaChainID,
|
||||
ChainName: "Mainnet",
|
||||
DefaultRPCURL: fmt.Sprintf("https://%s.api.status.im/nodefleet/ethereum/sepolia/", stageName),
|
||||
DefaultFallbackURL: fmt.Sprintf("https://%s.api.status.im/infura/ethereum/sepolia/", stageName),
|
||||
@ -64,13 +64,13 @@ func sepolia(stageName string) params.Network {
|
||||
IsTest: true,
|
||||
Layer: 1,
|
||||
Enabled: true,
|
||||
RelatedChainID: mainnetChainID,
|
||||
RelatedChainID: MainnetChainID,
|
||||
}
|
||||
}
|
||||
|
||||
func optimism(stageName string) params.Network {
|
||||
return params.Network{
|
||||
ChainID: optimismChainID,
|
||||
ChainID: OptimismChainID,
|
||||
ChainName: "Optimism",
|
||||
DefaultRPCURL: fmt.Sprintf("https://%s.api.status.im/nodefleet/optimism/mainnet/", stageName),
|
||||
DefaultFallbackURL: fmt.Sprintf("https://%s.api.status.im/infura/optimism/mainnet/", stageName),
|
||||
@ -87,13 +87,13 @@ func optimism(stageName string) params.Network {
|
||||
IsTest: false,
|
||||
Layer: 2,
|
||||
Enabled: true,
|
||||
RelatedChainID: optimismSepoliaChainID,
|
||||
RelatedChainID: OptimismSepoliaChainID,
|
||||
}
|
||||
}
|
||||
|
||||
func optimismSepolia(stageName string) params.Network {
|
||||
return params.Network{
|
||||
ChainID: optimismSepoliaChainID,
|
||||
ChainID: OptimismSepoliaChainID,
|
||||
ChainName: "Optimism",
|
||||
DefaultRPCURL: fmt.Sprintf("https://%s.api.status.im/nodefleet/optimism/sepolia/", stageName),
|
||||
DefaultFallbackURL: fmt.Sprintf("https://%s.api.status.im/infura/optimism/sepolia/", stageName),
|
||||
@ -110,13 +110,13 @@ func optimismSepolia(stageName string) params.Network {
|
||||
IsTest: true,
|
||||
Layer: 2,
|
||||
Enabled: false,
|
||||
RelatedChainID: optimismChainID,
|
||||
RelatedChainID: OptimismChainID,
|
||||
}
|
||||
}
|
||||
|
||||
func arbitrum(stageName string) params.Network {
|
||||
return params.Network{
|
||||
ChainID: arbitrumChainID,
|
||||
ChainID: ArbitrumChainID,
|
||||
ChainName: "Arbitrum",
|
||||
DefaultRPCURL: fmt.Sprintf("https://%s.api.status.im/nodefleet/arbitrum/mainnet/", stageName),
|
||||
DefaultFallbackURL: fmt.Sprintf("https://%s.api.status.im/infura/arbitrum/mainnet/", stageName),
|
||||
@ -133,13 +133,13 @@ func arbitrum(stageName string) params.Network {
|
||||
IsTest: false,
|
||||
Layer: 2,
|
||||
Enabled: true,
|
||||
RelatedChainID: arbitrumSepoliaChainID,
|
||||
RelatedChainID: ArbitrumSepoliaChainID,
|
||||
}
|
||||
}
|
||||
|
||||
func arbitrumSepolia(stageName string) params.Network {
|
||||
return params.Network{
|
||||
ChainID: arbitrumSepoliaChainID,
|
||||
ChainID: ArbitrumSepoliaChainID,
|
||||
ChainName: "Arbitrum",
|
||||
DefaultRPCURL: fmt.Sprintf("https://%s.api.status.im/nodefleet/arbitrum/sepolia/", stageName),
|
||||
DefaultFallbackURL: fmt.Sprintf("https://%s.api.status.im/infura/arbitrum/sepolia/", stageName),
|
||||
@ -156,7 +156,7 @@ func arbitrumSepolia(stageName string) params.Network {
|
||||
IsTest: true,
|
||||
Layer: 2,
|
||||
Enabled: false,
|
||||
RelatedChainID: arbitrumChainID,
|
||||
RelatedChainID: ArbitrumChainID,
|
||||
}
|
||||
}
|
||||
|
||||
@ -204,7 +204,7 @@ func setRPCs(networks []params.Network, request *requests.WalletSecretsConfig) [
|
||||
if request.GanacheURL != "" {
|
||||
n.RPCURL = request.GanacheURL
|
||||
n.FallbackURL = request.GanacheURL
|
||||
if n.ChainID == mainnetChainID {
|
||||
if n.ChainID == MainnetChainID {
|
||||
n.TokenOverrides = []params.TokenOverride{
|
||||
mainnetGanacheTokenOverrides,
|
||||
}
|
||||
|
@ -28,12 +28,12 @@ func TestBuildDefaultNetworks(t *testing.T) {
|
||||
for _, n := range actualNetworks {
|
||||
var err error
|
||||
switch n.ChainID {
|
||||
case mainnetChainID:
|
||||
case sepoliaChainID:
|
||||
case optimismChainID:
|
||||
case optimismSepoliaChainID:
|
||||
case arbitrumChainID:
|
||||
case arbitrumSepoliaChainID:
|
||||
case MainnetChainID:
|
||||
case SepoliaChainID:
|
||||
case OptimismChainID:
|
||||
case OptimismSepoliaChainID:
|
||||
case ArbitrumChainID:
|
||||
case ArbitrumSepoliaChainID:
|
||||
default:
|
||||
err = errors.Errorf("unexpected chain id: %d", n.ChainID)
|
||||
}
|
||||
@ -70,7 +70,7 @@ func TestBuildDefaultNetworksGanache(t *testing.T) {
|
||||
require.True(t, strings.Contains(n.FallbackURL, ganacheURL))
|
||||
}
|
||||
|
||||
require.Equal(t, mainnetChainID, actualNetworks[0].ChainID)
|
||||
require.Equal(t, MainnetChainID, actualNetworks[0].ChainID)
|
||||
|
||||
require.NotNil(t, actualNetworks[0].TokenOverrides)
|
||||
require.Len(t, actualNetworks[0].TokenOverrides, 1)
|
||||
|
@ -198,22 +198,22 @@ func buildWalletConfig(request *requests.WalletSecretsConfig, statusProxyEnabled
|
||||
}
|
||||
|
||||
if request.AlchemyEthereumMainnetToken != "" {
|
||||
walletConfig.AlchemyAPIKeys[mainnetChainID] = request.AlchemyEthereumMainnetToken
|
||||
walletConfig.AlchemyAPIKeys[MainnetChainID] = request.AlchemyEthereumMainnetToken
|
||||
}
|
||||
if request.AlchemyEthereumSepoliaToken != "" {
|
||||
walletConfig.AlchemyAPIKeys[sepoliaChainID] = request.AlchemyEthereumSepoliaToken
|
||||
walletConfig.AlchemyAPIKeys[SepoliaChainID] = request.AlchemyEthereumSepoliaToken
|
||||
}
|
||||
if request.AlchemyArbitrumMainnetToken != "" {
|
||||
walletConfig.AlchemyAPIKeys[arbitrumChainID] = request.AlchemyArbitrumMainnetToken
|
||||
walletConfig.AlchemyAPIKeys[ArbitrumChainID] = request.AlchemyArbitrumMainnetToken
|
||||
}
|
||||
if request.AlchemyArbitrumSepoliaToken != "" {
|
||||
walletConfig.AlchemyAPIKeys[arbitrumSepoliaChainID] = request.AlchemyArbitrumSepoliaToken
|
||||
walletConfig.AlchemyAPIKeys[ArbitrumSepoliaChainID] = request.AlchemyArbitrumSepoliaToken
|
||||
}
|
||||
if request.AlchemyOptimismMainnetToken != "" {
|
||||
walletConfig.AlchemyAPIKeys[optimismChainID] = request.AlchemyOptimismMainnetToken
|
||||
walletConfig.AlchemyAPIKeys[OptimismChainID] = request.AlchemyOptimismMainnetToken
|
||||
}
|
||||
if request.AlchemyOptimismSepoliaToken != "" {
|
||||
walletConfig.AlchemyAPIKeys[optimismSepoliaChainID] = request.AlchemyOptimismSepoliaToken
|
||||
walletConfig.AlchemyAPIKeys[OptimismSepoliaChainID] = request.AlchemyOptimismSepoliaToken
|
||||
}
|
||||
if request.StatusProxyMarketUser != "" {
|
||||
walletConfig.StatusProxyMarketUser = request.StatusProxyMarketUser
|
||||
|
@ -0,0 +1,13 @@
|
||||
CREATE TABLE IF NOT EXISTS rpc_providers (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT, -- Unique provider ID (sorting)
|
||||
chain_id INTEGER NOT NULL CHECK (chain_id > 0), -- Chain ID for the network
|
||||
name TEXT NOT NULL CHECK (LENGTH(name) > 0), -- Provider name
|
||||
url TEXT NOT NULL CHECK (LENGTH(url) > 0), -- Provider URL
|
||||
enable_rps_limiter BOOLEAN NOT NULL DEFAULT FALSE, -- Enable RPS limiter
|
||||
type TEXT NOT NULL DEFAULT 'user', -- Provider type: embedded-proxy, embedded-direct, user
|
||||
enabled BOOLEAN NOT NULL DEFAULT TRUE, -- Whether the provider is active or not
|
||||
auth_type TEXT NOT NULL DEFAULT 'no-auth', -- Authentication type: no-auth, basic-auth, token-auth
|
||||
auth_login TEXT, -- BasicAuth login (nullable)
|
||||
auth_password TEXT, -- Password for BasicAuth (nullable)
|
||||
auth_token TEXT -- Token for TokenAuth (nullable)
|
||||
);
|
@ -14,7 +14,6 @@ import (
|
||||
"go.uber.org/zap"
|
||||
validator "gopkg.in/go-playground/validator.v9"
|
||||
|
||||
"github.com/ethereum/go-ethereum/common"
|
||||
"github.com/ethereum/go-ethereum/p2p/discv5"
|
||||
"github.com/ethereum/go-ethereum/params"
|
||||
|
||||
@ -482,7 +481,7 @@ type NodeConfig struct {
|
||||
// (persistent storage of user's mailserver records).
|
||||
MailserversConfig MailserversConfig
|
||||
|
||||
// Web3ProviderConfig extra configuration for provider.Service
|
||||
// Web3ProviderConfig extra configuration for provider.Service.
|
||||
// (desktop provider API)
|
||||
Web3ProviderConfig Web3ProviderConfig
|
||||
|
||||
@ -514,35 +513,6 @@ type NodeConfig struct {
|
||||
ProcessBackedupMessages bool
|
||||
}
|
||||
|
||||
type TokenOverride struct {
|
||||
Symbol string `json:"symbol"`
|
||||
Address common.Address `json:"address"`
|
||||
}
|
||||
|
||||
type Network struct {
|
||||
ChainID uint64 `json:"chainId"`
|
||||
ChainName string `json:"chainName"`
|
||||
DefaultRPCURL string `json:"defaultRpcUrl"` // proxy rpc url
|
||||
DefaultFallbackURL string `json:"defaultFallbackURL"` // proxy fallback url
|
||||
DefaultFallbackURL2 string `json:"defaultFallbackURL2"` // second proxy fallback url
|
||||
RPCURL string `json:"rpcUrl"`
|
||||
OriginalRPCURL string `json:"originalRpcUrl"`
|
||||
FallbackURL string `json:"fallbackURL"`
|
||||
OriginalFallbackURL string `json:"originalFallbackURL"`
|
||||
BlockExplorerURL string `json:"blockExplorerUrl,omitempty"`
|
||||
IconURL string `json:"iconUrl,omitempty"`
|
||||
NativeCurrencyName string `json:"nativeCurrencyName,omitempty"`
|
||||
NativeCurrencySymbol string `json:"nativeCurrencySymbol,omitempty"`
|
||||
NativeCurrencyDecimals uint64 `json:"nativeCurrencyDecimals"`
|
||||
IsTest bool `json:"isTest"`
|
||||
Layer uint64 `json:"layer"`
|
||||
Enabled bool `json:"enabled"`
|
||||
ChainColor string `json:"chainColor"`
|
||||
ShortName string `json:"shortName"`
|
||||
TokenOverrides []TokenOverride `json:"tokenOverrides"`
|
||||
RelatedChainID uint64 `json:"relatedChainId"`
|
||||
}
|
||||
|
||||
// WalletConfig extra configuration for wallet.Service.
|
||||
type WalletConfig struct {
|
||||
Enabled bool
|
||||
@ -598,7 +568,7 @@ type MailserversConfig struct {
|
||||
Enabled bool
|
||||
}
|
||||
|
||||
// ProviderConfig extra configuration for provider.Service
|
||||
// ProviderAuthConfig extra configuration for provider.Service
|
||||
type Web3ProviderConfig struct {
|
||||
Enabled bool
|
||||
}
|
||||
|
95
params/network_config.go
Normal file
95
params/network_config.go
Normal file
@ -0,0 +1,95 @@
|
||||
package params
|
||||
|
||||
import "github.com/ethereum/go-ethereum/common"
|
||||
|
||||
// RpcProviderAuthType defines the different types of authentication for RPC providers
|
||||
type RpcProviderAuthType string
|
||||
|
||||
const (
|
||||
NoAuth RpcProviderAuthType = "no-auth" // No authentication
|
||||
BasicAuth RpcProviderAuthType = "basic-auth" // HTTP Header "Authorization: Basic base64(username:password)"
|
||||
TokenAuth RpcProviderAuthType = "token-auth" // URL Token-based authentication "https://api.example.com/YOUR_TOKEN"
|
||||
)
|
||||
|
||||
// RpcProviderType defines the type of RPC provider
|
||||
type RpcProviderType string
|
||||
|
||||
const (
|
||||
EmbeddedProxyProviderType RpcProviderType = "embedded-proxy" // Proxy-based RPC provider
|
||||
EmbeddedDirectProviderType RpcProviderType = "embedded-direct" // Direct RPC provider
|
||||
UserProviderType RpcProviderType = "user" // User-defined RPC provider
|
||||
)
|
||||
|
||||
// RpcProvider represents an RPC provider configuration with various options
|
||||
type RpcProvider struct {
|
||||
ID int64 `json:"id" validate:"omitempty"` // Auto-increment ID (for sorting order)
|
||||
ChainID uint64 `json:"chainId" validate:"required,gt=0"` // Chain ID of the network
|
||||
Name string `json:"name" validate:"required,min=1"` // Provider name for identification
|
||||
URL string `json:"url" validate:"required,url"` // Current Provider URL
|
||||
EnableRPSLimiter bool `json:"enableRpsLimiter"` // Enable RPC rate limiting for this provider
|
||||
Type RpcProviderType `json:"type" validate:"required,oneof=embedded-proxy embedded-direct user"`
|
||||
Enabled bool `json:"enabled"` // Whether the provider is enabled
|
||||
// Authentication
|
||||
AuthType RpcProviderAuthType `json:"authType" validate:"required,oneof=no-auth basic-auth token-auth"` // Type of authentication
|
||||
AuthLogin string `json:"authLogin" validate:"omitempty,min=1"` // Login for BasicAuth (empty string if not used)
|
||||
AuthPassword string `json:"authPassword" validate:"omitempty,min=1"` // Password for BasicAuth (empty string if not used)
|
||||
AuthToken string `json:"authToken" validate:"omitempty,min=1"` // Token for TokenAuth (empty string if not used)
|
||||
}
|
||||
|
||||
type TokenOverride struct {
|
||||
Symbol string `json:"symbol"`
|
||||
Address common.Address `json:"address"`
|
||||
}
|
||||
|
||||
type Network struct {
|
||||
ChainID uint64 `json:"chainId" validate:"required,gt=0"`
|
||||
ChainName string `json:"chainName" validate:"required,min=1"`
|
||||
RpcProviders []RpcProvider `json:"rpcProviders" validate:"dive,required"` // List of RPC providers, in the order in which they are accessed
|
||||
|
||||
// Deprecated fields (kept for backward compatibility)
|
||||
// FIXME: Removal of deprecated fields in integration PR https://github.com/status-im/status-go/pull/6178
|
||||
DefaultRPCURL string `json:"defaultRpcUrl" validate:"omitempty,url"` // Deprecated: proxy rpc url
|
||||
DefaultFallbackURL string `json:"defaultFallbackURL" validate:"omitempty,url"` // Deprecated: proxy fallback url
|
||||
DefaultFallbackURL2 string `json:"defaultFallbackURL2" validate:"omitempty,url"` // Deprecated: second proxy fallback url
|
||||
RPCURL string `json:"rpcUrl" validate:"omitempty,url"` // Deprecated: direct rpc url
|
||||
OriginalRPCURL string `json:"originalRpcUrl" validate:"omitempty,url"` // Deprecated: direct rpc url if user overrides RPCURL
|
||||
FallbackURL string `json:"fallbackURL" validate:"omitempty,url"` // Deprecated
|
||||
OriginalFallbackURL string `json:"originalFallbackURL" validate:"omitempty,url"` // Deprecated
|
||||
|
||||
BlockExplorerURL string `json:"blockExplorerUrl,omitempty" validate:"omitempty,url"`
|
||||
IconURL string `json:"iconUrl,omitempty" validate:"omitempty"`
|
||||
NativeCurrencyName string `json:"nativeCurrencyName,omitempty" validate:"omitempty,min=1"`
|
||||
NativeCurrencySymbol string `json:"nativeCurrencySymbol,omitempty" validate:"omitempty,min=1"`
|
||||
NativeCurrencyDecimals uint64 `json:"nativeCurrencyDecimals" validate:"omitempty"`
|
||||
IsTest bool `json:"isTest"`
|
||||
Layer uint64 `json:"layer" validate:"omitempty"`
|
||||
Enabled bool `json:"enabled"`
|
||||
ChainColor string `json:"chainColor" validate:"omitempty"`
|
||||
ShortName string `json:"shortName" validate:"omitempty,min=1"`
|
||||
TokenOverrides []TokenOverride `json:"tokenOverrides" validate:"omitempty,dive"`
|
||||
RelatedChainID uint64 `json:"relatedChainId" validate:"omitempty"`
|
||||
}
|
||||
|
||||
func newRpcProvider(chainID uint64, name, url string, enableRpsLimiter bool, providerType RpcProviderType) *RpcProvider {
|
||||
return &RpcProvider{
|
||||
ChainID: chainID,
|
||||
Name: name,
|
||||
URL: url,
|
||||
EnableRPSLimiter: enableRpsLimiter,
|
||||
Type: providerType,
|
||||
Enabled: true,
|
||||
AuthType: NoAuth,
|
||||
}
|
||||
}
|
||||
|
||||
func NewUserProvider(chainID uint64, name, url string, enableRpsLimiter bool) *RpcProvider {
|
||||
return newRpcProvider(chainID, name, url, enableRpsLimiter, UserProviderType)
|
||||
}
|
||||
|
||||
func NewProxyProvider(chainID uint64, name, url string, enableRpsLimiter bool) *RpcProvider {
|
||||
return newRpcProvider(chainID, name, url, enableRpsLimiter, EmbeddedProxyProviderType)
|
||||
}
|
||||
|
||||
func NewDirectProvider(chainID uint64, name, url string, enableRpsLimiter bool) *RpcProvider {
|
||||
return newRpcProvider(chainID, name, url, enableRpsLimiter, EmbeddedDirectProviderType)
|
||||
}
|
186
params/networkhelper/provider_utils.go
Normal file
186
params/networkhelper/provider_utils.go
Normal file
@ -0,0 +1,186 @@
|
||||
package networkhelper
|
||||
|
||||
import (
|
||||
"net/url"
|
||||
"strings"
|
||||
|
||||
"github.com/status-im/status-go/params"
|
||||
)
|
||||
|
||||
// MergeProvidersPreservingUsersAndEnabledState merges new embedded providers with the current ones,
|
||||
// preserving user-defined providers and maintaining the Enabled state.
|
||||
func MergeProvidersPreservingUsersAndEnabledState(currentProviders, newProviders []params.RpcProvider) []params.RpcProvider {
|
||||
// Create a map for quick lookup of the Enabled state by Name
|
||||
enabledState := make(map[string]bool, len(currentProviders))
|
||||
for _, provider := range currentProviders {
|
||||
enabledState[provider.Name] = provider.Enabled
|
||||
}
|
||||
|
||||
// Update the Enabled field in newProviders if the Name matches
|
||||
for i := range newProviders {
|
||||
if enabled, exists := enabledState[newProviders[i].Name]; exists {
|
||||
newProviders[i].Enabled = enabled
|
||||
}
|
||||
}
|
||||
|
||||
// Retain current providers of type UserProviderType and add them to the beginning of the list
|
||||
mergedProviders := make([]params.RpcProvider, 0, len(currentProviders)+len(newProviders))
|
||||
for _, provider := range currentProviders {
|
||||
if provider.Type == params.UserProviderType {
|
||||
mergedProviders = append(mergedProviders, provider)
|
||||
}
|
||||
}
|
||||
|
||||
// Add the updated newProviders
|
||||
mergedProviders = append(mergedProviders, newProviders...)
|
||||
|
||||
return mergedProviders
|
||||
}
|
||||
|
||||
// ToggleUserProviders enables or disables all user-defined providers and disables other types.
|
||||
func ToggleUserProviders(providers []params.RpcProvider, enabled bool) []params.RpcProvider {
|
||||
for i := range providers {
|
||||
if providers[i].Type == params.UserProviderType {
|
||||
providers[i].Enabled = enabled
|
||||
} else {
|
||||
providers[i].Enabled = !enabled
|
||||
}
|
||||
}
|
||||
return providers
|
||||
}
|
||||
|
||||
// GetEmbeddedProviders returns the embedded providers from the list.
|
||||
func GetEmbeddedProviders(providers []params.RpcProvider) []params.RpcProvider {
|
||||
var embeddedProviders []params.RpcProvider
|
||||
for _, provider := range providers {
|
||||
if provider.Type != params.UserProviderType {
|
||||
embeddedProviders = append(embeddedProviders, provider)
|
||||
}
|
||||
}
|
||||
return embeddedProviders
|
||||
}
|
||||
|
||||
// GetUserProviders returns the user-defined providers from the list.
|
||||
func GetUserProviders(providers []params.RpcProvider) []params.RpcProvider {
|
||||
var userProviders []params.RpcProvider
|
||||
for _, provider := range providers {
|
||||
if provider.Type == params.UserProviderType {
|
||||
userProviders = append(userProviders, provider)
|
||||
}
|
||||
}
|
||||
return userProviders
|
||||
}
|
||||
|
||||
// ReplaceUserProviders replaces user-defined providers with new ones, retaining the rest of the providers.
|
||||
func ReplaceUserProviders(currentProviders, newUserProviders []params.RpcProvider) []params.RpcProvider {
|
||||
// Extract embedded providers from the current list
|
||||
embeddedProviders := GetEmbeddedProviders(currentProviders)
|
||||
userProviders := GetUserProviders(newUserProviders)
|
||||
|
||||
// Combine new user providers with the existing embedded providers
|
||||
return append(userProviders, embeddedProviders...)
|
||||
}
|
||||
|
||||
// ReplaceEmbeddedProviders replaces embedded providers with new ones, retaining user-defined providers.
|
||||
func ReplaceEmbeddedProviders(currentProviders, newEmbeddedProviders []params.RpcProvider) []params.RpcProvider {
|
||||
// Extract user-defined providers from the current list
|
||||
userProviders := GetUserProviders(currentProviders)
|
||||
embeddedProviders := GetEmbeddedProviders(newEmbeddedProviders)
|
||||
|
||||
// Combine existing user-defined providers with the new embedded providers
|
||||
return append(userProviders, embeddedProviders...)
|
||||
}
|
||||
|
||||
// OverrideEmbeddedProxyProviders updates all embedded-proxy providers in the given networks.
|
||||
// It sets the `Enabled` flag and configures the `AuthLogin` and `AuthPassword` for each provider.
|
||||
func OverrideEmbeddedProxyProviders(networks []params.Network, enabled bool, user, password string) []params.Network {
|
||||
updatedNetworks := make([]params.Network, len(networks))
|
||||
for i, network := range networks {
|
||||
// Deep copy the network to avoid mutating the input slice
|
||||
updatedNetwork := network
|
||||
updatedProviders := make([]params.RpcProvider, len(network.RpcProviders))
|
||||
|
||||
// Update the embedded-proxy providers
|
||||
for j, provider := range network.RpcProviders {
|
||||
if provider.Type == params.EmbeddedProxyProviderType {
|
||||
provider.Enabled = enabled
|
||||
provider.AuthLogin = user
|
||||
provider.AuthPassword = password
|
||||
}
|
||||
updatedProviders[j] = provider
|
||||
}
|
||||
|
||||
updatedNetwork.RpcProviders = updatedProviders
|
||||
updatedNetworks[i] = updatedNetwork
|
||||
}
|
||||
|
||||
return updatedNetworks
|
||||
}
|
||||
|
||||
func deepCopyNetworks(networks []params.Network) []params.Network {
|
||||
updatedNetworks := make([]params.Network, len(networks))
|
||||
for i, network := range networks {
|
||||
updatedNetwork := network
|
||||
updatedNetwork.RpcProviders = make([]params.RpcProvider, len(network.RpcProviders))
|
||||
copy(updatedNetwork.RpcProviders, network.RpcProviders)
|
||||
updatedNetworks[i] = updatedNetwork
|
||||
}
|
||||
return updatedNetworks
|
||||
}
|
||||
|
||||
func OverrideDirectProvidersAuth(networks []params.Network, authTokens map[string]string) []params.Network {
|
||||
updatedNetworks := deepCopyNetworks(networks)
|
||||
|
||||
for i := range updatedNetworks {
|
||||
network := &updatedNetworks[i]
|
||||
|
||||
for j := range network.RpcProviders {
|
||||
provider := &network.RpcProviders[j]
|
||||
|
||||
if provider.Type != params.EmbeddedDirectProviderType {
|
||||
continue
|
||||
}
|
||||
|
||||
host, err := extractHost(provider.URL)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
for suffix, token := range authTokens {
|
||||
if strings.HasSuffix(host, suffix) && token != "" {
|
||||
provider.AuthType = params.TokenAuth
|
||||
provider.AuthToken = token
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return updatedNetworks
|
||||
}
|
||||
|
||||
func OverrideGanacheToken(networks []params.Network, ganacheURL string, chainID uint64, tokenOverride params.TokenOverride) []params.Network {
|
||||
updatedNetworks := deepCopyNetworks(networks)
|
||||
|
||||
for i := range updatedNetworks {
|
||||
network := &updatedNetworks[i]
|
||||
|
||||
if network.ChainID != chainID {
|
||||
continue
|
||||
}
|
||||
for j := range network.RpcProviders {
|
||||
if ganacheURL != "" {
|
||||
network.RpcProviders[j].URL = ganacheURL
|
||||
}
|
||||
}
|
||||
network.TokenOverrides = []params.TokenOverride{tokenOverride}
|
||||
}
|
||||
return updatedNetworks
|
||||
}
|
||||
|
||||
func extractHost(providerURL string) (string, error) {
|
||||
parsedURL, err := url.Parse(providerURL)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return parsedURL.Host, nil
|
||||
}
|
171
params/networkhelper/provider_utils_test.go
Normal file
171
params/networkhelper/provider_utils_test.go
Normal file
@ -0,0 +1,171 @@
|
||||
package networkhelper_test
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/ethereum/go-ethereum/common"
|
||||
|
||||
"github.com/brianvoe/gofakeit/v6"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/status-im/status-go/api"
|
||||
"github.com/status-im/status-go/params"
|
||||
"github.com/status-im/status-go/params/networkhelper"
|
||||
"github.com/status-im/status-go/rpc/network/testutil"
|
||||
)
|
||||
|
||||
func TestMergeProvidersPreserveEnabledAndOrder(t *testing.T) {
|
||||
chainID := uint64(1)
|
||||
|
||||
// Current providers with mixed types
|
||||
currentProviders := []params.RpcProvider{
|
||||
*params.NewUserProvider(chainID, "UserProvider1", "https://userprovider1.example.com", true),
|
||||
*params.NewUserProvider(chainID, "UserProvider2", "https://userprovider2.example.com", true),
|
||||
*params.NewDirectProvider(chainID, "EmbeddedProvider1", "https://embeddedprovider1.example.com", true),
|
||||
*params.NewProxyProvider(chainID, "EmbeddedProvider2", "https://embeddedprovider2.example.com", true),
|
||||
}
|
||||
currentProviders[1].Enabled = false // UserProvider2 is disabled
|
||||
currentProviders[2].Enabled = false // EmbeddedProvider1 is disabled
|
||||
|
||||
// New providers to merge
|
||||
newProviders := []params.RpcProvider{
|
||||
*params.NewDirectProvider(chainID, "EmbeddedProvider1", "https://embeddedprovider1-new.example.com", true), // Should retain Enabled: false
|
||||
*params.NewProxyProvider(chainID, "EmbeddedProvider3", "https://embeddedprovider3.example.com", true), // New embedded provider
|
||||
*params.NewDirectProvider(chainID, "EmbeddedProvider4", "https://embeddedprovider4.example.com", true), // New embedded provider
|
||||
}
|
||||
|
||||
// Call MergeProviders
|
||||
mergedProviders := networkhelper.MergeProvidersPreservingUsersAndEnabledState(currentProviders, newProviders)
|
||||
|
||||
expectedEmbeddedProvider1 := newProviders[0]
|
||||
expectedEmbeddedProvider1.Enabled = false // Should retain Enabled: false
|
||||
// Expected providers after merging
|
||||
expectedProviders := []params.RpcProvider{
|
||||
currentProviders[0], // UserProvider1
|
||||
currentProviders[1], // UserProvider2
|
||||
expectedEmbeddedProvider1, // EmbeddedProvider1 (should retain Enabled: false)
|
||||
newProviders[1], // EmbeddedProvider3
|
||||
newProviders[2], // EmbeddedProvider4
|
||||
}
|
||||
|
||||
// Assertions
|
||||
require.True(t, reflect.DeepEqual(mergedProviders, expectedProviders), "Merged providers should match the expected providers")
|
||||
}
|
||||
func TestUpdateEmbeddedProxyProviders(t *testing.T) {
|
||||
// Arrange: Create a sample list of networks with various provider types
|
||||
networks := []params.Network{
|
||||
*testutil.CreateNetwork(api.MainnetChainID, "Ethereum Mainnet", []params.RpcProvider{
|
||||
*params.NewUserProvider(api.MainnetChainID, "Provider1", "https://userprovider.example.com", true),
|
||||
*params.NewProxyProvider(api.MainnetChainID, "Provider2", "https://proxyprovider.example.com", true),
|
||||
}),
|
||||
*testutil.CreateNetwork(api.OptimismChainID, "Optimism", []params.RpcProvider{
|
||||
*params.NewDirectProvider(api.OptimismChainID, "Provider3", "https://directprovider.example.com", true),
|
||||
*params.NewProxyProvider(api.OptimismChainID, "Provider4", "https://proxyprovider2.example.com", true),
|
||||
}),
|
||||
}
|
||||
networks[0].RpcProviders[1].Enabled = false
|
||||
networks[1].RpcProviders[1].Enabled = false
|
||||
|
||||
user := gofakeit.Username()
|
||||
password := gofakeit.LetterN(5)
|
||||
|
||||
// Call the function to update embedded-proxy providers
|
||||
updatedNetworks := networkhelper.OverrideEmbeddedProxyProviders(networks, true, user, password)
|
||||
|
||||
// Verify the networks
|
||||
for i, network := range updatedNetworks {
|
||||
networkCopy := network
|
||||
expectedNetwork := &networks[i]
|
||||
testutil.CompareNetworks(t, expectedNetwork, &networkCopy)
|
||||
|
||||
for j, provider := range networkCopy.RpcProviders {
|
||||
expectedProvider := expectedNetwork.RpcProviders[j]
|
||||
if provider.Type == params.EmbeddedProxyProviderType {
|
||||
assert.True(t, provider.Enabled, "Provider Enabled state should be overridden")
|
||||
assert.Equal(t, user, provider.AuthLogin, "Provider AuthLogin should be overridden")
|
||||
assert.Equal(t, password, provider.AuthPassword, "Provider AuthPassword should be overridden")
|
||||
} else {
|
||||
assert.Equal(t, expectedProvider.Enabled, provider.Enabled, "Provider Enabled state should remain unchanged")
|
||||
assert.Equal(t, expectedProvider.AuthLogin, provider.AuthLogin, "Provider AuthLogin should remain unchanged")
|
||||
assert.Equal(t, expectedProvider.AuthPassword, provider.AuthPassword, "Provider AuthPassword should remain unchanged")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func TestOverrideDirectProvidersAuth(t *testing.T) {
|
||||
// Create a sample list of networks with various provider types
|
||||
networks := []params.Network{
|
||||
*testutil.CreateNetwork(api.MainnetChainID, "Ethereum Mainnet", []params.RpcProvider{
|
||||
*params.NewUserProvider(api.MainnetChainID, "Provider1", "https://user.example.com/", true),
|
||||
*params.NewDirectProvider(api.MainnetChainID, "Provider2", "https://mainnet.infura.io/v3/", true),
|
||||
*params.NewDirectProvider(api.MainnetChainID, "Provider3", "https://eth-archival.rpc.grove.city/v1/", true),
|
||||
}),
|
||||
*testutil.CreateNetwork(api.OptimismChainID, "Optimism", []params.RpcProvider{
|
||||
*params.NewDirectProvider(api.OptimismChainID, "Provider4", "https://optimism.infura.io/v3/", true),
|
||||
*params.NewDirectProvider(api.OptimismChainID, "Provider5", "https://op.grove.city/v1/", true),
|
||||
}),
|
||||
}
|
||||
|
||||
authTokens := map[string]string{
|
||||
"infura.io": gofakeit.UUID(),
|
||||
"grove.city": gofakeit.UUID(),
|
||||
"example.com": gofakeit.UUID(),
|
||||
}
|
||||
|
||||
// Call OverrideDirectProvidersAuth
|
||||
updatedNetworks := networkhelper.OverrideDirectProvidersAuth(networks, authTokens)
|
||||
|
||||
// Verify the networks have updated auth tokens correctly
|
||||
for i, network := range updatedNetworks {
|
||||
for j, provider := range network.RpcProviders {
|
||||
expectedProvider := networks[i].RpcProviders[j]
|
||||
switch {
|
||||
case strings.Contains(provider.URL, "infura.io"):
|
||||
assert.Equal(t, params.TokenAuth, provider.AuthType)
|
||||
assert.Equal(t, authTokens["infura.io"], provider.AuthToken)
|
||||
assert.NotEqual(t, expectedProvider.AuthToken, provider.AuthToken)
|
||||
case strings.Contains(provider.URL, "grove.city"):
|
||||
assert.Equal(t, params.TokenAuth, provider.AuthType)
|
||||
assert.Equal(t, authTokens["grove.city"], provider.AuthToken)
|
||||
assert.NotEqual(t, expectedProvider.AuthToken, provider.AuthToken)
|
||||
case strings.Contains(provider.URL, "example.com"):
|
||||
assert.Equal(t, params.NoAuth, provider.AuthType) // should not update user providers
|
||||
default:
|
||||
assert.Equal(t, expectedProvider.AuthType, provider.AuthType)
|
||||
assert.Equal(t, expectedProvider.AuthToken, provider.AuthToken)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestOverrideGanacheTokenOverrides(t *testing.T) {
|
||||
// Create a sample list of networks with various ChainIDs
|
||||
networks := []params.Network{
|
||||
*testutil.CreateNetwork(api.MainnetChainID, "Ethereum Mainnet", nil),
|
||||
*testutil.CreateNetwork(api.OptimismChainID, "Optimism", nil),
|
||||
*testutil.CreateNetwork(api.MainnetChainID, "Mainnet Duplicate", nil),
|
||||
}
|
||||
|
||||
ganacheTokenOverride := params.TokenOverride{
|
||||
Symbol: "SNT",
|
||||
Address: common.HexToAddress("0x8571Ddc46b10d31EF963aF49b6C7799Ea7eff818"),
|
||||
}
|
||||
|
||||
// Call OverrideGanacheTokenOverrides
|
||||
updatedNetworks := networkhelper.OverrideGanacheToken(networks, "url", api.MainnetChainID, ganacheTokenOverride)
|
||||
|
||||
// Verify that only networks with the specified ChainID have the token override applied
|
||||
for _, network := range updatedNetworks {
|
||||
if network.ChainID == api.MainnetChainID {
|
||||
require.NotNil(t, network.TokenOverrides, "TokenOverrides should not be nil for ChainID %d", network.ChainID)
|
||||
assert.Contains(t, network.TokenOverrides, ganacheTokenOverride, "TokenOverrides should contain the ganache token")
|
||||
} else {
|
||||
assert.Nil(t, network.TokenOverrides, "TokenOverrides should be nil for ChainID %d", network.ChainID)
|
||||
}
|
||||
}
|
||||
}
|
49
params/networkhelper/validate.go
Normal file
49
params/networkhelper/validate.go
Normal file
@ -0,0 +1,49 @@
|
||||
package networkhelper
|
||||
|
||||
import (
|
||||
"gopkg.in/go-playground/validator.v9"
|
||||
|
||||
"github.com/status-im/status-go/params"
|
||||
)
|
||||
|
||||
func GetValidator() *validator.Validate {
|
||||
validate := validator.New()
|
||||
|
||||
// Register struct-level validation for RpcProvider
|
||||
validate.RegisterStructValidation(rpcProviderStructLevelValidation, params.RpcProvider{})
|
||||
|
||||
return validate
|
||||
}
|
||||
|
||||
func rpcProviderStructLevelValidation(sl validator.StructLevel) {
|
||||
provider := sl.Current().Interface().(params.RpcProvider)
|
||||
|
||||
switch provider.AuthType {
|
||||
case params.NoAuth:
|
||||
if provider.AuthLogin != "" || provider.AuthPassword != "" || provider.AuthToken != "" {
|
||||
sl.ReportError(provider.AuthLogin, "AuthLogin", "authLogin", "noauth_fields_empty", "")
|
||||
sl.ReportError(provider.AuthPassword, "AuthPassword", "authPassword", "noauth_fields_empty", "")
|
||||
sl.ReportError(provider.AuthToken, "AuthToken", "authToken", "noauth_fields_empty", "")
|
||||
}
|
||||
case params.BasicAuth:
|
||||
if provider.AuthLogin == "" {
|
||||
sl.ReportError(provider.AuthLogin, "AuthLogin", "authLogin", "required", "")
|
||||
}
|
||||
if provider.AuthPassword == "" {
|
||||
sl.ReportError(provider.AuthPassword, "AuthPassword", "authPassword", "required", "")
|
||||
}
|
||||
if provider.AuthToken != "" {
|
||||
sl.ReportError(provider.AuthToken, "AuthToken", "authToken", "basic_auth_token_empty", "")
|
||||
}
|
||||
case params.TokenAuth:
|
||||
if provider.AuthToken == "" {
|
||||
sl.ReportError(provider.AuthToken, "AuthToken", "authToken", "required", "")
|
||||
}
|
||||
if provider.AuthLogin != "" || provider.AuthPassword != "" {
|
||||
sl.ReportError(provider.AuthLogin, "AuthLogin", "authLogin", "tokenauth_fields_empty", "")
|
||||
sl.ReportError(provider.AuthPassword, "AuthPassword", "authPassword", "tokenauth_fields_empty", "")
|
||||
}
|
||||
default:
|
||||
sl.ReportError(provider.AuthType, "AuthType", "authType", "invalid_auth_type", "")
|
||||
}
|
||||
}
|
184
params/networkhelper/validate_test.go
Normal file
184
params/networkhelper/validate_test.go
Normal file
@ -0,0 +1,184 @@
|
||||
package networkhelper
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"gopkg.in/go-playground/validator.v9"
|
||||
|
||||
"github.com/status-im/status-go/params"
|
||||
)
|
||||
|
||||
func TestValidation(t *testing.T) {
|
||||
validate := GetValidator()
|
||||
|
||||
// Test cases for RpcProvider
|
||||
providerTests := []struct {
|
||||
name string
|
||||
provider params.RpcProvider
|
||||
expectErr bool
|
||||
}{
|
||||
{
|
||||
name: "Valid Provider",
|
||||
provider: params.RpcProvider{
|
||||
ChainID: 1,
|
||||
Name: "Mainnet Provider",
|
||||
URL: "https://provider.example.com",
|
||||
Type: params.UserProviderType,
|
||||
Enabled: true,
|
||||
AuthType: params.NoAuth,
|
||||
EnableRPSLimiter: false,
|
||||
},
|
||||
expectErr: false,
|
||||
},
|
||||
{
|
||||
name: "Missing Provider Name",
|
||||
provider: params.RpcProvider{
|
||||
ChainID: 1,
|
||||
URL: "https://provider.example.com",
|
||||
Type: params.UserProviderType,
|
||||
},
|
||||
expectErr: true,
|
||||
},
|
||||
{
|
||||
name: "Invalid AuthType",
|
||||
provider: params.RpcProvider{
|
||||
ChainID: 1,
|
||||
Name: "Invalid Auth Provider",
|
||||
URL: "https://provider.example.com",
|
||||
Type: params.UserProviderType,
|
||||
AuthType: "invalid-auth-type",
|
||||
},
|
||||
expectErr: true,
|
||||
},
|
||||
{
|
||||
name: "BasicAuth without Login",
|
||||
provider: params.RpcProvider{
|
||||
ChainID: 1,
|
||||
Name: "BasicAuth Provider",
|
||||
URL: "https://provider.example.com",
|
||||
Type: params.UserProviderType,
|
||||
AuthType: params.BasicAuth,
|
||||
},
|
||||
expectErr: true,
|
||||
},
|
||||
{
|
||||
name: "TokenAuth without Token",
|
||||
provider: params.RpcProvider{
|
||||
ChainID: 1,
|
||||
Name: "TokenAuth Provider",
|
||||
URL: "https://provider.example.com",
|
||||
Type: params.UserProviderType,
|
||||
AuthType: params.TokenAuth,
|
||||
},
|
||||
expectErr: true,
|
||||
},
|
||||
{
|
||||
name: "NoAuth with Login",
|
||||
provider: params.RpcProvider{
|
||||
ChainID: 1,
|
||||
Name: "NoAuth Provider",
|
||||
URL: "https://provider.example.com",
|
||||
Type: params.UserProviderType,
|
||||
AuthType: params.NoAuth,
|
||||
AuthLogin: "user",
|
||||
},
|
||||
expectErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range providerTests {
|
||||
t.Run("RpcProvider: "+test.name, func(t *testing.T) {
|
||||
err := validate.Struct(test.provider)
|
||||
if test.expectErr {
|
||||
require.Error(t, err, "Expected error but got nil for test case '%s'", test.name)
|
||||
} else {
|
||||
require.NoError(t, err, "Did not expect error but got '%v' for test case '%s'", err, test.name)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Test cases for Network
|
||||
func TestNetworkValidation(t *testing.T) {
|
||||
validate := validator.New()
|
||||
|
||||
networkTests := []struct {
|
||||
name string
|
||||
network params.Network
|
||||
expectErr bool
|
||||
}{
|
||||
{
|
||||
name: "Valid Network",
|
||||
network: params.Network{
|
||||
ChainID: 1,
|
||||
ChainName: "Ethereum Mainnet",
|
||||
BlockExplorerURL: "https://etherscan.io",
|
||||
NativeCurrencyName: "Ether",
|
||||
NativeCurrencySymbol: "ETH",
|
||||
NativeCurrencyDecimals: 18,
|
||||
IsTest: false,
|
||||
Layer: 1,
|
||||
Enabled: true,
|
||||
ChainColor: "#E90101",
|
||||
ShortName: "eth",
|
||||
RpcProviders: []params.RpcProvider{
|
||||
{
|
||||
ChainID: 1,
|
||||
Name: "Mainnet Provider",
|
||||
URL: "https://provider.example.com",
|
||||
Type: params.UserProviderType,
|
||||
Enabled: true,
|
||||
AuthType: params.NoAuth,
|
||||
},
|
||||
},
|
||||
},
|
||||
expectErr: false,
|
||||
},
|
||||
{
|
||||
name: "Missing Chain Name",
|
||||
network: params.Network{
|
||||
ChainID: 1,
|
||||
RpcProviders: []params.RpcProvider{
|
||||
{
|
||||
ChainID: 1,
|
||||
Name: "Mainnet Provider",
|
||||
URL: "https://provider.example.com",
|
||||
Type: params.UserProviderType,
|
||||
Enabled: true,
|
||||
},
|
||||
},
|
||||
},
|
||||
expectErr: true,
|
||||
},
|
||||
{
|
||||
name: "Invalid Provider in Network",
|
||||
network: params.Network{
|
||||
ChainID: 1,
|
||||
ChainName: "Ethereum Mainnet",
|
||||
RpcProviders: []params.RpcProvider{
|
||||
{
|
||||
ChainID: 1,
|
||||
Name: "",
|
||||
URL: "https://provider.example.com",
|
||||
Type: params.UserProviderType,
|
||||
Enabled: true,
|
||||
},
|
||||
},
|
||||
},
|
||||
expectErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range networkTests {
|
||||
t.Run("Network: "+test.name, func(t *testing.T) {
|
||||
err := validate.Struct(test.network)
|
||||
if test.expectErr {
|
||||
require.Error(t, err, "Expected error but got nil for test case '%s'", test.name)
|
||||
} else {
|
||||
require.NoError(t, err, "Did not expect error but got '%v' for test case '%s'", err, test.name)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
257
rpc/network/db/network_db.go
Normal file
257
rpc/network/db/network_db.go
Normal file
@ -0,0 +1,257 @@
|
||||
package db
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
sq "github.com/Masterminds/squirrel"
|
||||
"gopkg.in/go-playground/validator.v9"
|
||||
|
||||
"github.com/status-im/status-go/params"
|
||||
"github.com/status-im/status-go/sqlite"
|
||||
)
|
||||
|
||||
// NetworksPersistenceInterface describes the interface for managing networks and providers.
|
||||
type NetworksPersistenceInterface interface {
|
||||
// GetNetworks returns networks based on filters.
|
||||
GetNetworks(onlyEnabled bool, chainID *uint64) ([]*params.Network, error)
|
||||
// GetAllNetworks returns all networks.
|
||||
GetAllNetworks() ([]*params.Network, error)
|
||||
// GetNetworkByChainID returns a network by ChainID.
|
||||
GetNetworkByChainID(chainID uint64) ([]*params.Network, error)
|
||||
// GetEnabledNetworks returns enabled networks.
|
||||
GetEnabledNetworks() ([]*params.Network, error)
|
||||
// SetNetworks replaces all networks with new ones and their providers.
|
||||
SetNetworks(networks []params.Network) error
|
||||
// UpsertNetwork adds or updates a network and its providers.
|
||||
UpsertNetwork(network *params.Network) error
|
||||
// DeleteNetwork deletes a network by ChainID and its providers.
|
||||
DeleteNetwork(chainID uint64) error
|
||||
// DeleteAllNetworks deletes all networks and their providers.
|
||||
DeleteAllNetworks() error
|
||||
|
||||
GetRpcPersistence() RpcProvidersPersistenceInterface
|
||||
}
|
||||
|
||||
// NetworksPersistence manages networks and their providers.
|
||||
type NetworksPersistence struct {
|
||||
db sqlite.StatementExecutor
|
||||
rpcPersistence RpcProvidersPersistenceInterface
|
||||
validator *validator.Validate
|
||||
}
|
||||
|
||||
// NewNetworksPersistence creates a new instance of NetworksPersistence.
|
||||
func NewNetworksPersistence(db sqlite.StatementExecutor) *NetworksPersistence {
|
||||
return &NetworksPersistence{
|
||||
db: db,
|
||||
rpcPersistence: NewRpcProvidersPersistence(db),
|
||||
validator: validator.New(),
|
||||
}
|
||||
}
|
||||
|
||||
// GetRpcPersistence returns an instance of RpcProvidersPersistenceInterface.
|
||||
func (n *NetworksPersistence) GetRpcPersistence() RpcProvidersPersistenceInterface {
|
||||
return n.rpcPersistence
|
||||
}
|
||||
|
||||
// GetNetworks returns networks based on filters.
|
||||
func (n *NetworksPersistence) GetNetworks(onlyEnabled bool, chainID *uint64) ([]*params.Network, error) {
|
||||
q := sq.Select(
|
||||
"chain_id", "chain_name", "rpc_url", "fallback_url",
|
||||
"block_explorer_url", "icon_url", "native_currency_name", "native_currency_symbol", "native_currency_decimals",
|
||||
"is_test", "layer", "enabled", "chain_color", "short_name", "related_chain_id",
|
||||
).
|
||||
From("networks").
|
||||
OrderBy("chain_id ASC")
|
||||
|
||||
if onlyEnabled {
|
||||
q = q.Where(sq.Eq{"enabled": true})
|
||||
}
|
||||
if chainID != nil {
|
||||
q = q.Where(sq.Eq{"chain_id": *chainID})
|
||||
}
|
||||
|
||||
query, args, err := q.ToSql()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
rows, err := n.db.Query(query, args...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
result := make([]*params.Network, 0, 10)
|
||||
for rows.Next() {
|
||||
network := ¶ms.Network{}
|
||||
err := rows.Scan(
|
||||
&network.ChainID, &network.ChainName, &network.RPCURL, &network.FallbackURL,
|
||||
&network.BlockExplorerURL, &network.IconURL, &network.NativeCurrencyName, &network.NativeCurrencySymbol,
|
||||
&network.NativeCurrencyDecimals, &network.IsTest, &network.Layer, &network.Enabled, &network.ChainColor,
|
||||
&network.ShortName, &network.RelatedChainID,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Fetch RPC providers for the network
|
||||
providers, err := n.rpcPersistence.GetRpcProviders(network.ChainID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to fetch RPC providers for chain_id %d: %w", network.ChainID, err)
|
||||
}
|
||||
network.RpcProviders = providers
|
||||
|
||||
// Fill deprecated URLs if necessary (assuming this is a function you have)
|
||||
FillDeprecatedURLs(network, providers)
|
||||
|
||||
result = append(result, network)
|
||||
}
|
||||
|
||||
if err = rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// GetAllNetworks returns all networks.
|
||||
func (n *NetworksPersistence) GetAllNetworks() ([]*params.Network, error) {
|
||||
return n.GetNetworks(false, nil)
|
||||
}
|
||||
|
||||
// GetNetworkByChainID returns a network by ChainID.
|
||||
func (n *NetworksPersistence) GetNetworkByChainID(chainID uint64) ([]*params.Network, error) {
|
||||
return n.GetNetworks(false, &chainID)
|
||||
}
|
||||
|
||||
// GetEnabledNetworks returns enabled networks.
|
||||
func (n *NetworksPersistence) GetEnabledNetworks() ([]*params.Network, error) {
|
||||
return n.GetNetworks(true, nil)
|
||||
}
|
||||
|
||||
// SetNetworks replaces all networks with new ones and their providers.
|
||||
// Note: Transaction management should be handled by the caller.
|
||||
func (n *NetworksPersistence) SetNetworks(networks []params.Network) error {
|
||||
// Delete all networks and their providers
|
||||
err := n.DeleteAllNetworks()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Upsert networks and their providers
|
||||
for i := range networks {
|
||||
err := n.UpsertNetwork(&networks[i])
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to upsert network with chain_id %d: %w", networks[i].ChainID, err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// DeleteAllNetworks deletes all networks and their RPC providers.
|
||||
// Note: Transaction management should be handled by the caller.
|
||||
func (n *NetworksPersistence) DeleteAllNetworks() error {
|
||||
// Delete all RPC providers
|
||||
err := n.rpcPersistence.DeleteAllRpcProviders()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to delete all RPC providers: %w", err)
|
||||
}
|
||||
|
||||
// Delete all networks
|
||||
q := sq.Delete("networks")
|
||||
|
||||
query, args, err := q.ToSql()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to build delete query: %w", err)
|
||||
}
|
||||
|
||||
_, err = n.db.Exec(query, args...)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to execute delete query: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// UpsertNetwork adds or updates a network and its providers.
|
||||
// Note: Transaction management should be handled by the caller.
|
||||
func (n *NetworksPersistence) UpsertNetwork(network *params.Network) error {
|
||||
// Validate the network
|
||||
if err := n.validator.Struct(network); err != nil {
|
||||
return fmt.Errorf("network validation failed: %w", err)
|
||||
}
|
||||
|
||||
// Upsert the network
|
||||
err := n.upsertNetwork(network)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to upsert network for chain_id %d: %w", network.ChainID, err)
|
||||
}
|
||||
|
||||
// Set the RPC providers
|
||||
err = n.rpcPersistence.SetRpcProviders(network.ChainID, network.RpcProviders)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to set RPC providers for chain_id %d: %w", network.ChainID, err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// upsertNetwork handles the logic for inserting or updating a network record.
|
||||
func (n *NetworksPersistence) upsertNetwork(network *params.Network) error {
|
||||
q := sq.Insert("networks").
|
||||
Columns(
|
||||
"chain_id", "chain_name", "rpc_url", "original_rpc_url", "fallback_url", "original_fallback_url",
|
||||
"block_explorer_url", "icon_url", "native_currency_name", "native_currency_symbol", "native_currency_decimals",
|
||||
"is_test", "layer", "enabled", "chain_color", "short_name", "related_chain_id",
|
||||
).
|
||||
Values(
|
||||
network.ChainID, network.ChainName, network.RPCURL, network.OriginalRPCURL, network.FallbackURL, network.OriginalFallbackURL,
|
||||
network.BlockExplorerURL, network.IconURL, network.NativeCurrencyName, network.NativeCurrencySymbol, network.NativeCurrencyDecimals,
|
||||
network.IsTest, network.Layer, network.Enabled, network.ChainColor, network.ShortName, network.RelatedChainID,
|
||||
).
|
||||
Suffix("ON CONFLICT(chain_id) DO UPDATE SET " +
|
||||
"chain_name = excluded.chain_name, rpc_url = excluded.rpc_url, original_rpc_url = excluded.original_rpc_url, " +
|
||||
"fallback_url = excluded.fallback_url, original_fallback_url = excluded.original_fallback_url, " +
|
||||
"block_explorer_url = excluded.block_explorer_url, icon_url = excluded.icon_url, " +
|
||||
"native_currency_name = excluded.native_currency_name, native_currency_symbol = excluded.native_currency_symbol, " +
|
||||
"native_currency_decimals = excluded.native_currency_decimals, is_test = excluded.is_test, " +
|
||||
"layer = excluded.layer, enabled = excluded.enabled, chain_color = excluded.chain_color, " +
|
||||
"short_name = excluded.short_name, related_chain_id = excluded.related_chain_id")
|
||||
|
||||
query, args, err := q.ToSql()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to build upsert query: %w", err)
|
||||
}
|
||||
|
||||
_, err = n.db.Exec(query, args...)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to execute upsert query: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// DeleteNetwork deletes a network and its associated RPC providers.
|
||||
// Note: Transaction management should be handled by the caller.
|
||||
func (n *NetworksPersistence) DeleteNetwork(chainID uint64) error {
|
||||
// Delete RPC providers
|
||||
err := n.rpcPersistence.DeleteRpcProviders(chainID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to delete RPC providers for chain_id %d: %w", chainID, err)
|
||||
}
|
||||
|
||||
// Delete the network record
|
||||
q := sq.Delete("networks").Where(sq.Eq{"chain_id": chainID})
|
||||
query, args, err := q.ToSql()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to build delete query: %w", err)
|
||||
}
|
||||
|
||||
_, err = n.db.Exec(query, args...)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to execute delete query for chain_id %d: %w", chainID, err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
197
rpc/network/db/network_db_test.go
Normal file
197
rpc/network/db/network_db_test.go
Normal file
@ -0,0 +1,197 @@
|
||||
package db_test
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/stretchr/testify/suite"
|
||||
|
||||
"github.com/status-im/status-go/api"
|
||||
"github.com/status-im/status-go/appdatabase"
|
||||
"github.com/status-im/status-go/params"
|
||||
"github.com/status-im/status-go/rpc/network/db"
|
||||
"github.com/status-im/status-go/rpc/network/testutil"
|
||||
"github.com/status-im/status-go/t/helpers"
|
||||
)
|
||||
|
||||
type NetworksPersistenceTestSuite struct {
|
||||
suite.Suite
|
||||
db *sql.DB
|
||||
cleanup func() error
|
||||
networksPersistence db.NetworksPersistenceInterface
|
||||
}
|
||||
|
||||
func (s *NetworksPersistenceTestSuite) SetupTest() {
|
||||
memDb, cleanup, err := helpers.SetupTestSQLDB(appdatabase.DbInitializer{}, "networks-tests")
|
||||
s.Require().NoError(err)
|
||||
s.db = memDb
|
||||
s.cleanup = cleanup
|
||||
s.networksPersistence = db.NewNetworksPersistence(memDb)
|
||||
}
|
||||
|
||||
func (s *NetworksPersistenceTestSuite) TearDownTest() {
|
||||
if s.cleanup != nil {
|
||||
err := s.cleanup()
|
||||
require.NoError(s.T(), err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNetworksPersistenceTestSuite(t *testing.T) {
|
||||
suite.Run(t, new(NetworksPersistenceTestSuite))
|
||||
}
|
||||
|
||||
// Helper function to create default providers for a given chainID
|
||||
func DefaultProviders(chainID uint64) []params.RpcProvider {
|
||||
return []params.RpcProvider{
|
||||
{
|
||||
Name: "Provider1",
|
||||
ChainID: chainID,
|
||||
URL: "https://rpc.provider1.io",
|
||||
Type: params.UserProviderType,
|
||||
Enabled: true,
|
||||
AuthType: params.NoAuth,
|
||||
},
|
||||
{
|
||||
Name: "Provider2",
|
||||
ChainID: chainID,
|
||||
URL: "https://rpc.provider2.io",
|
||||
Type: params.EmbeddedProxyProviderType,
|
||||
Enabled: true,
|
||||
AuthType: params.BasicAuth,
|
||||
AuthLogin: "user1",
|
||||
AuthPassword: "password1",
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// Helper function to add and verify networks
|
||||
func (s *NetworksPersistenceTestSuite) addAndVerifyNetworks(networks []*params.Network) {
|
||||
networkValues := make([]params.Network, 0, len(networks))
|
||||
for _, network := range networks {
|
||||
networkValues = append(networkValues, *network)
|
||||
}
|
||||
err := s.networksPersistence.SetNetworks(networkValues)
|
||||
s.Require().NoError(err)
|
||||
|
||||
s.verifyNetworks(networks)
|
||||
}
|
||||
|
||||
// Helper function to verify networks against the database
|
||||
func (s *NetworksPersistenceTestSuite) verifyNetworks(networks []*params.Network) {
|
||||
allNetworks, err := s.networksPersistence.GetAllNetworks()
|
||||
s.Require().NoError(err)
|
||||
testutil.CompareNetworksList(s.T(), networks, allNetworks)
|
||||
}
|
||||
|
||||
// Helper function to verify network deletion
|
||||
func (s *NetworksPersistenceTestSuite) verifyNetworkDeletion(chainID uint64) {
|
||||
nets, err := s.networksPersistence.GetNetworkByChainID(chainID)
|
||||
s.Require().NoError(err)
|
||||
s.Require().Len(nets, 0)
|
||||
|
||||
providers, err := s.networksPersistence.GetRpcPersistence().GetRpcProviders(chainID)
|
||||
s.Require().NoError(err)
|
||||
s.Require().Len(providers, 0)
|
||||
}
|
||||
|
||||
// Tests
|
||||
|
||||
func (s *NetworksPersistenceTestSuite) TestAddAndGetNetworkWithProviders() {
|
||||
network := testutil.CreateNetwork(api.OptimismChainID, "Optimism Mainnet", []params.RpcProvider{
|
||||
testutil.CreateProvider(api.OptimismChainID, "Provider1", params.UserProviderType, true, "https://rpc.optimism.io"),
|
||||
testutil.CreateProvider(api.OptimismChainID, "Provider2", params.EmbeddedProxyProviderType, false, "https://backup.optimism.io"),
|
||||
})
|
||||
s.addAndVerifyNetworks([]*params.Network{network})
|
||||
}
|
||||
|
||||
func (s *NetworksPersistenceTestSuite) TestDeleteNetworkWithProviders() {
|
||||
network := testutil.CreateNetwork(api.OptimismChainID, "Optimism Mainnet", DefaultProviders(api.OptimismChainID))
|
||||
s.addAndVerifyNetworks([]*params.Network{network})
|
||||
|
||||
err := s.networksPersistence.DeleteNetwork(network.ChainID)
|
||||
s.Require().NoError(err)
|
||||
|
||||
s.verifyNetworkDeletion(network.ChainID)
|
||||
}
|
||||
|
||||
func (s *NetworksPersistenceTestSuite) TestUpdateNetworkAndProviders() {
|
||||
network := testutil.CreateNetwork(api.OptimismChainID, "Optimism Mainnet", DefaultProviders(api.OptimismChainID))
|
||||
s.addAndVerifyNetworks([]*params.Network{network})
|
||||
|
||||
// Update fields
|
||||
network.ChainName = "Updated Optimism Mainnet"
|
||||
network.RpcProviders = []params.RpcProvider{
|
||||
testutil.CreateProvider(api.OptimismChainID, "UpdatedProvider", params.UserProviderType, true, "https://rpc.optimism.updated.io"),
|
||||
}
|
||||
|
||||
s.addAndVerifyNetworks([]*params.Network{network})
|
||||
}
|
||||
|
||||
func (s *NetworksPersistenceTestSuite) TestDeleteAllNetworks() {
|
||||
networks := []*params.Network{
|
||||
testutil.CreateNetwork(api.MainnetChainID, "Ethereum Mainnet", DefaultProviders(api.MainnetChainID)),
|
||||
testutil.CreateNetwork(api.SepoliaChainID, "Sepolia Testnet", DefaultProviders(api.SepoliaChainID)),
|
||||
}
|
||||
s.addAndVerifyNetworks(networks)
|
||||
|
||||
err := s.networksPersistence.DeleteAllNetworks()
|
||||
s.Require().NoError(err)
|
||||
|
||||
allNetworks, err := s.networksPersistence.GetAllNetworks()
|
||||
s.Require().NoError(err)
|
||||
s.Require().Len(allNetworks, 0)
|
||||
}
|
||||
|
||||
func (s *NetworksPersistenceTestSuite) TestSetNetworks() {
|
||||
initialNetworks := []*params.Network{
|
||||
testutil.CreateNetwork(api.MainnetChainID, "Ethereum Mainnet", DefaultProviders(api.MainnetChainID)),
|
||||
testutil.CreateNetwork(api.SepoliaChainID, "Sepolia Testnet", DefaultProviders(api.SepoliaChainID)),
|
||||
}
|
||||
newNetworks := []*params.Network{
|
||||
testutil.CreateNetwork(api.OptimismChainID, "Optimism Mainnet", DefaultProviders(api.OptimismChainID)),
|
||||
}
|
||||
|
||||
// Add initial networks
|
||||
s.addAndVerifyNetworks(initialNetworks)
|
||||
|
||||
// Replace with new networks
|
||||
s.addAndVerifyNetworks(newNetworks)
|
||||
|
||||
// Verify old networks are removed
|
||||
s.verifyNetworkDeletion(api.MainnetChainID)
|
||||
s.verifyNetworkDeletion(api.SepoliaChainID)
|
||||
}
|
||||
|
||||
func (s *NetworksPersistenceTestSuite) TestValidationForNetworksAndProviders() {
|
||||
// Invalid Network: Missing required ChainName
|
||||
invalidNetwork := testutil.CreateNetwork(api.MainnetChainID, "", DefaultProviders(api.MainnetChainID))
|
||||
|
||||
// Invalid Provider: Missing URL
|
||||
invalidProvider := params.RpcProvider{
|
||||
Name: "InvalidProvider",
|
||||
ChainID: api.MainnetChainID,
|
||||
URL: "", // Invalid
|
||||
Type: params.UserProviderType,
|
||||
Enabled: true,
|
||||
}
|
||||
|
||||
// Add invalid provider to a valid network
|
||||
validNetworkWithInvalidProvider := testutil.CreateNetwork(api.OptimismChainID, "Optimism Mainnet", []params.RpcProvider{invalidProvider})
|
||||
|
||||
// Invalid networks and providers should fail validation
|
||||
networksToValidate := []*params.Network{
|
||||
invalidNetwork,
|
||||
validNetworkWithInvalidProvider,
|
||||
}
|
||||
|
||||
for _, network := range networksToValidate {
|
||||
err := s.networksPersistence.UpsertNetwork(network)
|
||||
s.Require().Error(err, "Expected validation to fail for invalid network or provider")
|
||||
}
|
||||
|
||||
// Ensure no invalid data is saved in the database
|
||||
allNetworks, err := s.networksPersistence.GetAllNetworks()
|
||||
s.Require().NoError(err)
|
||||
s.Require().Len(allNetworks, 0, "No invalid networks should be saved")
|
||||
}
|
243
rpc/network/db/rpc_provider_db.go
Normal file
243
rpc/network/db/rpc_provider_db.go
Normal file
@ -0,0 +1,243 @@
|
||||
package db
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
sq "github.com/Masterminds/squirrel"
|
||||
"gopkg.in/go-playground/validator.v9"
|
||||
|
||||
"github.com/status-im/status-go/params"
|
||||
"github.com/status-im/status-go/sqlite"
|
||||
)
|
||||
|
||||
// Interface for managing RPC providers
|
||||
type RpcProvidersPersistenceInterface interface {
|
||||
GetRpcProviders(chainID uint64) ([]params.RpcProvider, error)
|
||||
GetRpcProvidersByType(chainID uint64, providerType params.RpcProviderType) ([]params.RpcProvider, error)
|
||||
AddRpcProvider(provider params.RpcProvider) error
|
||||
DeleteRpcProviders(chainID uint64) error
|
||||
DeleteAllRpcProviders() error
|
||||
UpdateRpcProvider(provider params.RpcProvider) error
|
||||
SetRpcProviders(chainID uint64, newProviders []params.RpcProvider) error
|
||||
}
|
||||
|
||||
// Struct for managing RPC providers
|
||||
type RpcProvidersPersistence struct {
|
||||
db sqlite.StatementExecutor
|
||||
validator *validator.Validate
|
||||
}
|
||||
|
||||
// Constructor for RpcProvidersPersistence with validator
|
||||
func NewRpcProvidersPersistence(db sqlite.StatementExecutor) *RpcProvidersPersistence {
|
||||
return &RpcProvidersPersistence{
|
||||
db: db,
|
||||
validator: validator.New(),
|
||||
}
|
||||
}
|
||||
|
||||
// Retrieve all providers for a specific ChainID
|
||||
func (p *RpcProvidersPersistence) GetRpcProviders(chainID uint64) ([]params.RpcProvider, error) {
|
||||
q := sq.Select(
|
||||
"id",
|
||||
"chain_id",
|
||||
"name",
|
||||
"url",
|
||||
"enable_rps_limiter",
|
||||
"type",
|
||||
"enabled",
|
||||
"auth_type",
|
||||
"auth_login",
|
||||
"auth_password",
|
||||
"auth_token",
|
||||
).
|
||||
From("rpc_providers").
|
||||
Where(sq.Eq{"chain_id": chainID}).
|
||||
OrderBy("id ASC")
|
||||
|
||||
query, args, err := q.ToSql()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to build query: %w", err)
|
||||
}
|
||||
|
||||
rows, err := p.db.Query(query, args...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to execute query: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var providers []params.RpcProvider
|
||||
for rows.Next() {
|
||||
var provider params.RpcProvider
|
||||
err := rows.Scan(
|
||||
&provider.ID,
|
||||
&provider.ChainID,
|
||||
&provider.Name,
|
||||
&provider.URL,
|
||||
&provider.EnableRPSLimiter,
|
||||
&provider.Type,
|
||||
&provider.Enabled,
|
||||
&provider.AuthType,
|
||||
&provider.AuthLogin,
|
||||
&provider.AuthPassword,
|
||||
&provider.AuthToken,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to scan row: %w", err)
|
||||
}
|
||||
providers = append(providers, provider)
|
||||
}
|
||||
|
||||
if err = rows.Err(); err != nil {
|
||||
return nil, fmt.Errorf("row iteration error: %w", err)
|
||||
}
|
||||
|
||||
return providers, nil
|
||||
}
|
||||
|
||||
// Retrieve providers of a specific type
|
||||
func (p *RpcProvidersPersistence) GetRpcProvidersByType(chainID uint64, providerType params.RpcProviderType) ([]params.RpcProvider, error) {
|
||||
allProviders, err := p.GetRpcProviders(chainID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
result := make([]params.RpcProvider, 0, len(allProviders))
|
||||
for _, provider := range allProviders {
|
||||
if provider.Type == providerType {
|
||||
result = append(result, provider)
|
||||
}
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// Add a new provider
|
||||
func (p *RpcProvidersPersistence) AddRpcProvider(provider params.RpcProvider) error {
|
||||
// Validate the provider
|
||||
if err := p.validator.Struct(provider); err != nil {
|
||||
return fmt.Errorf("validation failed: %w", err)
|
||||
}
|
||||
|
||||
// Proceed with adding the provider to the database
|
||||
q := sq.Insert("rpc_providers").
|
||||
Columns(
|
||||
"chain_id",
|
||||
"name",
|
||||
"url",
|
||||
"enable_rps_limiter",
|
||||
"type",
|
||||
"enabled",
|
||||
"auth_type",
|
||||
"auth_login",
|
||||
"auth_password",
|
||||
"auth_token",
|
||||
).
|
||||
Values(
|
||||
provider.ChainID,
|
||||
provider.Name,
|
||||
provider.URL,
|
||||
provider.EnableRPSLimiter,
|
||||
provider.Type,
|
||||
provider.Enabled,
|
||||
provider.AuthType,
|
||||
provider.AuthLogin,
|
||||
provider.AuthPassword,
|
||||
provider.AuthToken,
|
||||
)
|
||||
|
||||
query, args, err := q.ToSql()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to build insert query: %w", err)
|
||||
}
|
||||
|
||||
_, err = p.db.Exec(query, args...)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to execute insert query: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Delete providers for a specific ChainID
|
||||
func (p *RpcProvidersPersistence) DeleteRpcProviders(chainID uint64) error {
|
||||
q := sq.Delete("rpc_providers").
|
||||
Where(sq.Eq{"chain_id": chainID})
|
||||
|
||||
query, args, err := q.ToSql()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to build delete query: %w", err)
|
||||
}
|
||||
|
||||
_, err = p.db.Exec(query, args...)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to execute delete query: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Delete all providers
|
||||
func (p *RpcProvidersPersistence) DeleteAllRpcProviders() error {
|
||||
q := sq.Delete("rpc_providers")
|
||||
|
||||
query, args, err := q.ToSql()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to build delete query: %w", err)
|
||||
}
|
||||
|
||||
_, err = p.db.Exec(query, args...)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to execute delete query: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Update an existing provider
|
||||
func (p *RpcProvidersPersistence) UpdateRpcProvider(provider params.RpcProvider) error {
|
||||
// Validate the provider
|
||||
if err := p.validator.Struct(provider); err != nil {
|
||||
return fmt.Errorf("validation failed: %w", err)
|
||||
}
|
||||
|
||||
// Proceed with updating the provider in the database
|
||||
q := sq.Update("rpc_providers").
|
||||
SetMap(sq.Eq{
|
||||
"url": provider.URL,
|
||||
"enable_rps_limiter": provider.EnableRPSLimiter,
|
||||
"type": provider.Type,
|
||||
"enabled": provider.Enabled,
|
||||
"auth_type": provider.AuthType,
|
||||
"auth_login": provider.AuthLogin,
|
||||
"auth_password": provider.AuthPassword,
|
||||
"auth_token": provider.AuthToken,
|
||||
}).
|
||||
Where(sq.Eq{"id": provider.ID})
|
||||
|
||||
query, args, err := q.ToSql()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to build update query: %w", err)
|
||||
}
|
||||
|
||||
_, err = p.db.Exec(query, args...)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to execute update query: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Set the list of providers for a ChainID, replacing any existing providers
|
||||
func (p *RpcProvidersPersistence) SetRpcProviders(chainID uint64, newProviders []params.RpcProvider) error {
|
||||
if err := p.DeleteRpcProviders(chainID); err != nil {
|
||||
return fmt.Errorf("failed to delete existing providers: %w", err)
|
||||
}
|
||||
|
||||
for _, provider := range newProviders {
|
||||
if err := p.AddRpcProvider(provider); err != nil {
|
||||
return fmt.Errorf("failed to add new provider: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
157
rpc/network/db/rpc_provider_db_test.go
Normal file
157
rpc/network/db/rpc_provider_db_test.go
Normal file
@ -0,0 +1,157 @@
|
||||
package db_test
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/stretchr/testify/suite"
|
||||
|
||||
"github.com/status-im/status-go/api"
|
||||
"github.com/status-im/status-go/appdatabase"
|
||||
"github.com/status-im/status-go/params"
|
||||
"github.com/status-im/status-go/rpc/network/db"
|
||||
"github.com/status-im/status-go/rpc/network/testutil"
|
||||
"github.com/status-im/status-go/t/helpers"
|
||||
)
|
||||
|
||||
type RpcProviderPersistenceTestSuite struct {
|
||||
suite.Suite
|
||||
db *sql.DB
|
||||
rpcPersistence *db.RpcProvidersPersistence
|
||||
}
|
||||
|
||||
func (s *RpcProviderPersistenceTestSuite) SetupTest() {
|
||||
testDb := setupTestNetworkDB(s.T())
|
||||
s.db = testDb
|
||||
s.rpcPersistence = db.NewRpcProvidersPersistence(testDb)
|
||||
}
|
||||
|
||||
func setupTestNetworkDB(t *testing.T) *sql.DB {
|
||||
testDb, cleanup, err := helpers.SetupTestSQLDB(appdatabase.DbInitializer{}, "rpc-providers-tests")
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() { require.NoError(t, cleanup()) })
|
||||
return testDb
|
||||
}
|
||||
|
||||
func TestRpcProviderPersistenceTestSuite(t *testing.T) {
|
||||
suite.Run(t, new(RpcProviderPersistenceTestSuite))
|
||||
}
|
||||
|
||||
// Test cases
|
||||
|
||||
func (s *RpcProviderPersistenceTestSuite) TestAddAndGetRpcProvider() {
|
||||
provider := testutil.CreateProvider(api.MainnetChainID, "Provider1", params.UserProviderType, true, "https://provider1.example.com")
|
||||
|
||||
err := s.rpcPersistence.AddRpcProvider(provider)
|
||||
s.Require().NoError(err)
|
||||
|
||||
// Verify the added provider
|
||||
providers, err := s.rpcPersistence.GetRpcProviders(api.MainnetChainID)
|
||||
s.Require().NoError(err)
|
||||
testutil.CompareProvidersList(s.T(), []params.RpcProvider{provider}, providers)
|
||||
}
|
||||
|
||||
func (s *RpcProviderPersistenceTestSuite) TestGetRpcProvidersByType() {
|
||||
providers := []params.RpcProvider{
|
||||
testutil.CreateProvider(api.MainnetChainID, "UserProvider1", params.UserProviderType, true, "https://provider1.example.com"),
|
||||
testutil.CreateProvider(api.MainnetChainID, "EmbeddedDirect1", params.EmbeddedDirectProviderType, false, "https://provider2.example.com"),
|
||||
testutil.CreateProvider(api.MainnetChainID, "UserProvider2", params.UserProviderType, false, "https://provider3.example.com"),
|
||||
testutil.CreateProvider(api.MainnetChainID, "EmbeddedProxy1", params.EmbeddedProxyProviderType, true, "https://provider4.example.com"),
|
||||
}
|
||||
|
||||
for _, provider := range providers {
|
||||
err := s.rpcPersistence.AddRpcProvider(provider)
|
||||
s.Require().NoError(err)
|
||||
}
|
||||
|
||||
// Verify by type
|
||||
userProviders, err := s.rpcPersistence.GetRpcProvidersByType(api.MainnetChainID, params.UserProviderType)
|
||||
s.Require().NoError(err)
|
||||
testutil.CompareProvidersList(s.T(), []params.RpcProvider{providers[0], providers[2]}, userProviders)
|
||||
|
||||
embeddedDirectProviders, err := s.rpcPersistence.GetRpcProvidersByType(api.MainnetChainID, params.EmbeddedDirectProviderType)
|
||||
s.Require().NoError(err)
|
||||
testutil.CompareProvidersList(s.T(), []params.RpcProvider{providers[1]}, embeddedDirectProviders)
|
||||
|
||||
embeddedProxyProviders, err := s.rpcPersistence.GetRpcProvidersByType(api.MainnetChainID, params.EmbeddedProxyProviderType)
|
||||
s.Require().NoError(err)
|
||||
testutil.CompareProvidersList(s.T(), []params.RpcProvider{providers[3]}, embeddedProxyProviders)
|
||||
}
|
||||
|
||||
func (s *RpcProviderPersistenceTestSuite) TestDeleteRpcProviders() {
|
||||
provider := testutil.CreateProvider(api.MainnetChainID, "Provider1", params.UserProviderType, true, "https://provider1.example.com")
|
||||
|
||||
err := s.rpcPersistence.AddRpcProvider(provider)
|
||||
s.Require().NoError(err)
|
||||
|
||||
err = s.rpcPersistence.DeleteRpcProviders(api.MainnetChainID)
|
||||
s.Require().NoError(err)
|
||||
|
||||
// Verify deletion
|
||||
providers, err := s.rpcPersistence.GetRpcProviders(api.MainnetChainID)
|
||||
s.Require().NoError(err)
|
||||
s.Require().Empty(providers)
|
||||
}
|
||||
|
||||
func (s *RpcProviderPersistenceTestSuite) TestUpdateRpcProvider() {
|
||||
provider := testutil.CreateProvider(api.MainnetChainID, "Provider1", params.UserProviderType, true, "https://provider1.example.com")
|
||||
|
||||
err := s.rpcPersistence.AddRpcProvider(provider)
|
||||
s.Require().NoError(err)
|
||||
|
||||
// Retrieve provider to get the ID
|
||||
providers, err := s.rpcPersistence.GetRpcProviders(api.MainnetChainID)
|
||||
s.Require().NoError(err)
|
||||
s.Require().Len(providers, 1)
|
||||
|
||||
provider.ID = providers[0].ID
|
||||
provider.URL = "https://provider1-updated.example.com"
|
||||
provider.EnableRPSLimiter = false
|
||||
|
||||
err = s.rpcPersistence.UpdateRpcProvider(provider)
|
||||
s.Require().NoError(err)
|
||||
|
||||
// Verify update
|
||||
updatedProviders, err := s.rpcPersistence.GetRpcProviders(api.MainnetChainID)
|
||||
s.Require().NoError(err)
|
||||
testutil.CompareProvidersList(s.T(), []params.RpcProvider{provider}, updatedProviders)
|
||||
}
|
||||
|
||||
func (s *RpcProviderPersistenceTestSuite) TestSetRpcProviders() {
|
||||
initialProviders := []params.RpcProvider{
|
||||
testutil.CreateProvider(api.MainnetChainID, "Provider1", params.UserProviderType, true, "https://provider1.example.com"),
|
||||
testutil.CreateProvider(api.MainnetChainID, "Provider2", params.EmbeddedDirectProviderType, false, "https://provider2.example.com"),
|
||||
}
|
||||
|
||||
for _, provider := range initialProviders {
|
||||
err := s.rpcPersistence.AddRpcProvider(provider)
|
||||
s.Require().NoError(err)
|
||||
}
|
||||
|
||||
newProviders := []params.RpcProvider{
|
||||
testutil.CreateProvider(api.MainnetChainID, "NewProvider1", params.UserProviderType, true, "https://newprovider1.example.com"),
|
||||
testutil.CreateProvider(api.MainnetChainID, "NewProvider2", params.EmbeddedProxyProviderType, true, "https://newprovider2.example.com"),
|
||||
}
|
||||
|
||||
err := s.rpcPersistence.SetRpcProviders(api.MainnetChainID, newProviders)
|
||||
s.Require().NoError(err)
|
||||
|
||||
// Verify replacement
|
||||
providers, err := s.rpcPersistence.GetRpcProviders(api.MainnetChainID)
|
||||
s.Require().NoError(err)
|
||||
testutil.CompareProvidersList(s.T(), newProviders, providers)
|
||||
}
|
||||
|
||||
func (s *RpcProviderPersistenceTestSuite) TestAddRpcProviderValidation() {
|
||||
invalidProvider := params.RpcProvider{
|
||||
ChainID: 0, // Invalid: must be greater than 0
|
||||
Name: "", // Invalid: cannot be empty
|
||||
URL: "invalid-url", // Invalid: not a valid URL
|
||||
Type: "invalid-type", // Invalid: not in allowed values
|
||||
}
|
||||
|
||||
err := s.rpcPersistence.AddRpcProvider(invalidProvider)
|
||||
s.Require().Error(err)
|
||||
s.Contains(err.Error(), "validation failed")
|
||||
}
|
57
rpc/network/db/utils.go
Normal file
57
rpc/network/db/utils.go
Normal file
@ -0,0 +1,57 @@
|
||||
package db
|
||||
|
||||
import (
|
||||
"github.com/status-im/status-go/params"
|
||||
)
|
||||
|
||||
// Deprecated: fillDeprecatedURLs populates the `original_rpc_url`, `original_fallback_url`, `rpc_url`,
|
||||
// `fallback_url`, `defaultRpcUrl`, `defaultFallbackURL`, and `defaultFallbackURL2` fields.
|
||||
// Keep for backwrad compatibility until it's fully integrated
|
||||
func FillDeprecatedURLs(network *params.Network, providers []params.RpcProvider) {
|
||||
var embeddedDirect []params.RpcProvider
|
||||
var embeddedProxy []params.RpcProvider
|
||||
var userProviders []params.RpcProvider
|
||||
|
||||
// Categorize providers
|
||||
for _, provider := range providers {
|
||||
switch provider.Type {
|
||||
case params.EmbeddedDirectProviderType:
|
||||
embeddedDirect = append(embeddedDirect, provider)
|
||||
case params.EmbeddedProxyProviderType:
|
||||
embeddedProxy = append(embeddedProxy, provider)
|
||||
case params.UserProviderType:
|
||||
userProviders = append(userProviders, provider)
|
||||
}
|
||||
}
|
||||
|
||||
// Set original_*_url fields based on EmbeddedDirectProviderType providers
|
||||
if len(embeddedDirect) > 0 {
|
||||
network.OriginalRPCURL = embeddedDirect[0].URL
|
||||
if len(embeddedDirect) > 1 {
|
||||
network.OriginalFallbackURL = embeddedDirect[1].URL
|
||||
}
|
||||
}
|
||||
|
||||
// Set rpc_url and fallback_url based on User providers or EmbeddedDirectProviderType if no User providers exist
|
||||
if len(userProviders) > 0 {
|
||||
network.RPCURL = userProviders[0].URL
|
||||
if len(userProviders) > 1 {
|
||||
network.FallbackURL = userProviders[1].URL
|
||||
}
|
||||
} else {
|
||||
// Default to EmbeddedDirectProviderType providers if no User providers exist
|
||||
network.RPCURL = network.OriginalRPCURL
|
||||
network.FallbackURL = network.OriginalFallbackURL
|
||||
}
|
||||
|
||||
// Set default_*_url fields based on EmbeddedProxyProviderType providers
|
||||
if len(embeddedProxy) > 0 {
|
||||
network.DefaultRPCURL = embeddedProxy[0].URL
|
||||
if len(embeddedProxy) > 1 {
|
||||
network.DefaultFallbackURL = embeddedProxy[1].URL
|
||||
}
|
||||
if len(embeddedProxy) > 2 {
|
||||
network.DefaultFallbackURL2 = embeddedProxy[2].URL
|
||||
}
|
||||
}
|
||||
}
|
104
rpc/network/testutil/testutil.go
Normal file
104
rpc/network/testutil/testutil.go
Normal file
@ -0,0 +1,104 @@
|
||||
package testutil
|
||||
|
||||
import (
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/status-im/status-go/api"
|
||||
"github.com/status-im/status-go/params"
|
||||
)
|
||||
|
||||
// Helper function to create a provider
|
||||
func CreateProvider(chainID uint64, name string, providerType params.RpcProviderType, enabled bool, url string) params.RpcProvider {
|
||||
return params.RpcProvider{
|
||||
ChainID: chainID,
|
||||
Name: name,
|
||||
URL: url,
|
||||
EnableRPSLimiter: true,
|
||||
Type: providerType,
|
||||
Enabled: enabled,
|
||||
AuthType: params.BasicAuth,
|
||||
AuthLogin: "user1",
|
||||
AuthPassword: "password1",
|
||||
AuthToken: "",
|
||||
}
|
||||
}
|
||||
|
||||
// Helper function to create a network
|
||||
func CreateNetwork(chainID uint64, chainName string, providers []params.RpcProvider) *params.Network {
|
||||
return ¶ms.Network{
|
||||
ChainID: chainID,
|
||||
ChainName: chainName,
|
||||
BlockExplorerURL: "https://explorer.example.com",
|
||||
IconURL: "network/Network=" + chainName,
|
||||
NativeCurrencyName: "Ether",
|
||||
NativeCurrencySymbol: "ETH",
|
||||
NativeCurrencyDecimals: 18,
|
||||
IsTest: false,
|
||||
Layer: 2,
|
||||
Enabled: true,
|
||||
ChainColor: "#E90101",
|
||||
ShortName: "eth",
|
||||
RelatedChainID: api.OptimismSepoliaChainID,
|
||||
RpcProviders: providers,
|
||||
}
|
||||
}
|
||||
|
||||
// Helper function to compare two providers
|
||||
func CompareProviders(t require.TestingT, expected, actual params.RpcProvider) {
|
||||
require.Equal(t, expected.ChainID, actual.ChainID)
|
||||
require.Equal(t, expected.Name, actual.Name)
|
||||
require.Equal(t, expected.URL, actual.URL)
|
||||
require.Equal(t, expected.EnableRPSLimiter, actual.EnableRPSLimiter)
|
||||
require.Equal(t, expected.Type, actual.Type)
|
||||
require.Equal(t, expected.Enabled, actual.Enabled)
|
||||
require.Equal(t, expected.AuthType, actual.AuthType)
|
||||
require.Equal(t, expected.AuthLogin, actual.AuthLogin)
|
||||
require.Equal(t, expected.AuthPassword, actual.AuthPassword)
|
||||
require.Equal(t, expected.AuthToken, actual.AuthToken)
|
||||
}
|
||||
|
||||
// Helper function to compare two networks
|
||||
func CompareNetworks(t require.TestingT, expected, actual *params.Network) {
|
||||
require.Equal(t, expected.ChainID, actual.ChainID, "ChainID does not match")
|
||||
require.Equal(t, expected.ChainName, actual.ChainName, "ChainName does not match for ChainID %d", actual.ChainID)
|
||||
require.Equal(t, expected.BlockExplorerURL, actual.BlockExplorerURL)
|
||||
require.Equal(t, expected.NativeCurrencyName, actual.NativeCurrencyName)
|
||||
require.Equal(t, expected.NativeCurrencySymbol, actual.NativeCurrencySymbol)
|
||||
require.Equal(t, expected.NativeCurrencyDecimals, actual.NativeCurrencyDecimals)
|
||||
require.Equal(t, expected.IsTest, actual.IsTest)
|
||||
require.Equal(t, expected.Layer, actual.Layer)
|
||||
require.Equal(t, expected.Enabled, actual.Enabled)
|
||||
require.Equal(t, expected.ChainColor, actual.ChainColor)
|
||||
require.Equal(t, expected.ShortName, actual.ShortName)
|
||||
require.Equal(t, expected.RelatedChainID, actual.RelatedChainID)
|
||||
}
|
||||
|
||||
// Helper function to compare lists of providers
|
||||
func CompareProvidersList(t require.TestingT, expectedProviders, actualProviders []params.RpcProvider) {
|
||||
require.Len(t, actualProviders, len(expectedProviders))
|
||||
expectedMap := make(map[string]params.RpcProvider, len(expectedProviders))
|
||||
for _, provider := range expectedProviders {
|
||||
expectedMap[provider.Name] = provider
|
||||
}
|
||||
|
||||
for _, provider := range actualProviders {
|
||||
expectedProvider, exists := expectedMap[provider.Name]
|
||||
require.True(t, exists, "Unexpected provider '%s'", provider.Name)
|
||||
CompareProviders(t, expectedProvider, provider)
|
||||
}
|
||||
}
|
||||
|
||||
// Helper function to compare lists of networks
|
||||
func CompareNetworksList(t require.TestingT, expectedNetworks, actualNetworks []*params.Network) {
|
||||
require.Len(t, actualNetworks, len(expectedNetworks))
|
||||
expectedMap := make(map[uint64]*params.Network, len(expectedNetworks))
|
||||
for _, network := range expectedNetworks {
|
||||
expectedMap[network.ChainID] = network
|
||||
}
|
||||
|
||||
for _, network := range actualNetworks {
|
||||
expectedNetwork, exists := expectedMap[network.ChainID]
|
||||
require.True(t, exists, "Unexpected network with ChainID %d", network.ChainID)
|
||||
CompareNetworks(t, expectedNetwork, network)
|
||||
}
|
||||
}
|
@ -6,3 +6,9 @@ import "database/sql"
|
||||
type StatementCreator interface {
|
||||
Prepare(query string) (*sql.Stmt, error)
|
||||
}
|
||||
|
||||
type StatementExecutor interface {
|
||||
StatementCreator
|
||||
Exec(query string, args ...interface{}) (sql.Result, error)
|
||||
Query(query string, args ...interface{}) (*sql.Rows, error)
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user