diff --git a/api/backend_test.go b/api/backend_test.go index c3cfb7007..aa0e03cc6 100644 --- a/api/backend_test.go +++ b/api/backend_test.go @@ -1579,12 +1579,12 @@ func TestWalletConfigOnLoginAccount(t *testing.T) { } require.Equal(t, b.config.WalletConfig.InfuraAPIKey, infuraToken) - require.Equal(t, b.config.WalletConfig.AlchemyAPIKeys[mainnetChainID], alchemyEthereumMainnetToken) - require.Equal(t, b.config.WalletConfig.AlchemyAPIKeys[sepoliaChainID], alchemyEthereumSepoliaToken) - require.Equal(t, b.config.WalletConfig.AlchemyAPIKeys[arbitrumChainID], alchemyArbitrumMainnetToken) - require.Equal(t, b.config.WalletConfig.AlchemyAPIKeys[arbitrumSepoliaChainID], alchemyArbitrumSepoliaToken) - require.Equal(t, b.config.WalletConfig.AlchemyAPIKeys[optimismChainID], alchemyOptimismMainnetToken) - require.Equal(t, b.config.WalletConfig.AlchemyAPIKeys[optimismSepoliaChainID], alchemyOptimismSepoliaToken) + require.Equal(t, b.config.WalletConfig.AlchemyAPIKeys[MainnetChainID], alchemyEthereumMainnetToken) + require.Equal(t, b.config.WalletConfig.AlchemyAPIKeys[SepoliaChainID], alchemyEthereumSepoliaToken) + require.Equal(t, b.config.WalletConfig.AlchemyAPIKeys[ArbitrumChainID], alchemyArbitrumMainnetToken) + require.Equal(t, b.config.WalletConfig.AlchemyAPIKeys[ArbitrumSepoliaChainID], alchemyArbitrumSepoliaToken) + require.Equal(t, b.config.WalletConfig.AlchemyAPIKeys[OptimismChainID], alchemyOptimismMainnetToken) + require.Equal(t, b.config.WalletConfig.AlchemyAPIKeys[OptimismSepoliaChainID], alchemyOptimismSepoliaToken) require.Equal(t, b.config.WalletConfig.RaribleMainnetAPIKey, raribleMainnetAPIKey) require.Equal(t, b.config.WalletConfig.RaribleTestnetAPIKey, raribleTestnetAPIKey) diff --git a/api/default_networks.go b/api/default_networks.go index 32cd654f1..e807cc6c9 100644 --- a/api/default_networks.go +++ b/api/default_networks.go @@ -10,12 +10,12 @@ import ( ) const ( - mainnetChainID uint64 = 1 - sepoliaChainID uint64 = 11155111 - optimismChainID uint64 = 10 - optimismSepoliaChainID uint64 = 11155420 - arbitrumChainID uint64 = 42161 - arbitrumSepoliaChainID uint64 = 421614 + MainnetChainID uint64 = 1 + SepoliaChainID uint64 = 11155111 + OptimismChainID uint64 = 10 + OptimismSepoliaChainID uint64 = 11155420 + ArbitrumChainID uint64 = 42161 + ArbitrumSepoliaChainID uint64 = 421614 sntSymbol = "SNT" sttSymbol = "STT" ) @@ -24,7 +24,7 @@ var ganacheTokenAddress = common.HexToAddress("0x8571Ddc46b10d31EF963aF49b6C7799 func mainnet(stageName string) params.Network { return params.Network{ - ChainID: mainnetChainID, + ChainID: MainnetChainID, ChainName: "Mainnet", DefaultRPCURL: fmt.Sprintf("https://%s.api.status.im/nodefleet/ethereum/mainnet/", stageName), DefaultFallbackURL: fmt.Sprintf("https://%s.api.status.im/infura/ethereum/mainnet/", stageName), @@ -41,13 +41,13 @@ func mainnet(stageName string) params.Network { IsTest: false, Layer: 1, Enabled: true, - RelatedChainID: sepoliaChainID, + RelatedChainID: SepoliaChainID, } } func sepolia(stageName string) params.Network { return params.Network{ - ChainID: sepoliaChainID, + ChainID: SepoliaChainID, ChainName: "Mainnet", DefaultRPCURL: fmt.Sprintf("https://%s.api.status.im/nodefleet/ethereum/sepolia/", stageName), DefaultFallbackURL: fmt.Sprintf("https://%s.api.status.im/infura/ethereum/sepolia/", stageName), @@ -64,13 +64,13 @@ func sepolia(stageName string) params.Network { IsTest: true, Layer: 1, Enabled: true, - RelatedChainID: mainnetChainID, + RelatedChainID: MainnetChainID, } } func optimism(stageName string) params.Network { return params.Network{ - ChainID: optimismChainID, + ChainID: OptimismChainID, ChainName: "Optimism", DefaultRPCURL: fmt.Sprintf("https://%s.api.status.im/nodefleet/optimism/mainnet/", stageName), DefaultFallbackURL: fmt.Sprintf("https://%s.api.status.im/infura/optimism/mainnet/", stageName), @@ -87,13 +87,13 @@ func optimism(stageName string) params.Network { IsTest: false, Layer: 2, Enabled: true, - RelatedChainID: optimismSepoliaChainID, + RelatedChainID: OptimismSepoliaChainID, } } func optimismSepolia(stageName string) params.Network { return params.Network{ - ChainID: optimismSepoliaChainID, + ChainID: OptimismSepoliaChainID, ChainName: "Optimism", DefaultRPCURL: fmt.Sprintf("https://%s.api.status.im/nodefleet/optimism/sepolia/", stageName), DefaultFallbackURL: fmt.Sprintf("https://%s.api.status.im/infura/optimism/sepolia/", stageName), @@ -110,13 +110,13 @@ func optimismSepolia(stageName string) params.Network { IsTest: true, Layer: 2, Enabled: false, - RelatedChainID: optimismChainID, + RelatedChainID: OptimismChainID, } } func arbitrum(stageName string) params.Network { return params.Network{ - ChainID: arbitrumChainID, + ChainID: ArbitrumChainID, ChainName: "Arbitrum", DefaultRPCURL: fmt.Sprintf("https://%s.api.status.im/nodefleet/arbitrum/mainnet/", stageName), DefaultFallbackURL: fmt.Sprintf("https://%s.api.status.im/infura/arbitrum/mainnet/", stageName), @@ -133,13 +133,13 @@ func arbitrum(stageName string) params.Network { IsTest: false, Layer: 2, Enabled: true, - RelatedChainID: arbitrumSepoliaChainID, + RelatedChainID: ArbitrumSepoliaChainID, } } func arbitrumSepolia(stageName string) params.Network { return params.Network{ - ChainID: arbitrumSepoliaChainID, + ChainID: ArbitrumSepoliaChainID, ChainName: "Arbitrum", DefaultRPCURL: fmt.Sprintf("https://%s.api.status.im/nodefleet/arbitrum/sepolia/", stageName), DefaultFallbackURL: fmt.Sprintf("https://%s.api.status.im/infura/arbitrum/sepolia/", stageName), @@ -156,7 +156,7 @@ func arbitrumSepolia(stageName string) params.Network { IsTest: true, Layer: 2, Enabled: false, - RelatedChainID: arbitrumChainID, + RelatedChainID: ArbitrumChainID, } } @@ -204,7 +204,7 @@ func setRPCs(networks []params.Network, request *requests.WalletSecretsConfig) [ if request.GanacheURL != "" { n.RPCURL = request.GanacheURL n.FallbackURL = request.GanacheURL - if n.ChainID == mainnetChainID { + if n.ChainID == MainnetChainID { n.TokenOverrides = []params.TokenOverride{ mainnetGanacheTokenOverrides, } diff --git a/api/default_networks_test.go b/api/default_networks_test.go index 556a677ab..4d3ecba80 100644 --- a/api/default_networks_test.go +++ b/api/default_networks_test.go @@ -28,12 +28,12 @@ func TestBuildDefaultNetworks(t *testing.T) { for _, n := range actualNetworks { var err error switch n.ChainID { - case mainnetChainID: - case sepoliaChainID: - case optimismChainID: - case optimismSepoliaChainID: - case arbitrumChainID: - case arbitrumSepoliaChainID: + case MainnetChainID: + case SepoliaChainID: + case OptimismChainID: + case OptimismSepoliaChainID: + case ArbitrumChainID: + case ArbitrumSepoliaChainID: default: err = errors.Errorf("unexpected chain id: %d", n.ChainID) } @@ -70,7 +70,7 @@ func TestBuildDefaultNetworksGanache(t *testing.T) { require.True(t, strings.Contains(n.FallbackURL, ganacheURL)) } - require.Equal(t, mainnetChainID, actualNetworks[0].ChainID) + require.Equal(t, MainnetChainID, actualNetworks[0].ChainID) require.NotNil(t, actualNetworks[0].TokenOverrides) require.Len(t, actualNetworks[0].TokenOverrides, 1) diff --git a/api/defaults.go b/api/defaults.go index 4982bda98..653189988 100644 --- a/api/defaults.go +++ b/api/defaults.go @@ -198,22 +198,22 @@ func buildWalletConfig(request *requests.WalletSecretsConfig, statusProxyEnabled } if request.AlchemyEthereumMainnetToken != "" { - walletConfig.AlchemyAPIKeys[mainnetChainID] = request.AlchemyEthereumMainnetToken + walletConfig.AlchemyAPIKeys[MainnetChainID] = request.AlchemyEthereumMainnetToken } if request.AlchemyEthereumSepoliaToken != "" { - walletConfig.AlchemyAPIKeys[sepoliaChainID] = request.AlchemyEthereumSepoliaToken + walletConfig.AlchemyAPIKeys[SepoliaChainID] = request.AlchemyEthereumSepoliaToken } if request.AlchemyArbitrumMainnetToken != "" { - walletConfig.AlchemyAPIKeys[arbitrumChainID] = request.AlchemyArbitrumMainnetToken + walletConfig.AlchemyAPIKeys[ArbitrumChainID] = request.AlchemyArbitrumMainnetToken } if request.AlchemyArbitrumSepoliaToken != "" { - walletConfig.AlchemyAPIKeys[arbitrumSepoliaChainID] = request.AlchemyArbitrumSepoliaToken + walletConfig.AlchemyAPIKeys[ArbitrumSepoliaChainID] = request.AlchemyArbitrumSepoliaToken } if request.AlchemyOptimismMainnetToken != "" { - walletConfig.AlchemyAPIKeys[optimismChainID] = request.AlchemyOptimismMainnetToken + walletConfig.AlchemyAPIKeys[OptimismChainID] = request.AlchemyOptimismMainnetToken } if request.AlchemyOptimismSepoliaToken != "" { - walletConfig.AlchemyAPIKeys[optimismSepoliaChainID] = request.AlchemyOptimismSepoliaToken + walletConfig.AlchemyAPIKeys[OptimismSepoliaChainID] = request.AlchemyOptimismSepoliaToken } if request.StatusProxyMarketUser != "" { walletConfig.StatusProxyMarketUser = request.StatusProxyMarketUser diff --git a/appdatabase/migrations/sql/1733400346_add_rpc_providers.up.sql b/appdatabase/migrations/sql/1733400346_add_rpc_providers.up.sql new file mode 100644 index 000000000..26e19d65c --- /dev/null +++ b/appdatabase/migrations/sql/1733400346_add_rpc_providers.up.sql @@ -0,0 +1,13 @@ +CREATE TABLE IF NOT EXISTS rpc_providers ( + id INTEGER PRIMARY KEY AUTOINCREMENT, -- Unique provider ID (sorting) + chain_id INTEGER NOT NULL CHECK (chain_id > 0), -- Chain ID for the network + name TEXT NOT NULL CHECK (LENGTH(name) > 0), -- Provider name + url TEXT NOT NULL CHECK (LENGTH(url) > 0), -- Provider URL + enable_rps_limiter BOOLEAN NOT NULL DEFAULT FALSE, -- Enable RPS limiter + type TEXT NOT NULL DEFAULT 'user', -- Provider type: embedded-proxy, embedded-direct, user + enabled BOOLEAN NOT NULL DEFAULT TRUE, -- Whether the provider is active or not + auth_type TEXT NOT NULL DEFAULT 'no-auth', -- Authentication type: no-auth, basic-auth, token-auth + auth_login TEXT, -- BasicAuth login (nullable) + auth_password TEXT, -- Password for BasicAuth (nullable) + auth_token TEXT -- Token for TokenAuth (nullable) +); diff --git a/params/config.go b/params/config.go index ca3ee5bf9..043988345 100644 --- a/params/config.go +++ b/params/config.go @@ -14,7 +14,6 @@ import ( "go.uber.org/zap" validator "gopkg.in/go-playground/validator.v9" - "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/p2p/discv5" "github.com/ethereum/go-ethereum/params" @@ -482,7 +481,7 @@ type NodeConfig struct { // (persistent storage of user's mailserver records). MailserversConfig MailserversConfig - // Web3ProviderConfig extra configuration for provider.Service + // Web3ProviderConfig extra configuration for provider.Service. // (desktop provider API) Web3ProviderConfig Web3ProviderConfig @@ -514,35 +513,6 @@ type NodeConfig struct { ProcessBackedupMessages bool } -type TokenOverride struct { - Symbol string `json:"symbol"` - Address common.Address `json:"address"` -} - -type Network struct { - ChainID uint64 `json:"chainId"` - ChainName string `json:"chainName"` - DefaultRPCURL string `json:"defaultRpcUrl"` // proxy rpc url - DefaultFallbackURL string `json:"defaultFallbackURL"` // proxy fallback url - DefaultFallbackURL2 string `json:"defaultFallbackURL2"` // second proxy fallback url - RPCURL string `json:"rpcUrl"` - OriginalRPCURL string `json:"originalRpcUrl"` - FallbackURL string `json:"fallbackURL"` - OriginalFallbackURL string `json:"originalFallbackURL"` - BlockExplorerURL string `json:"blockExplorerUrl,omitempty"` - IconURL string `json:"iconUrl,omitempty"` - NativeCurrencyName string `json:"nativeCurrencyName,omitempty"` - NativeCurrencySymbol string `json:"nativeCurrencySymbol,omitempty"` - NativeCurrencyDecimals uint64 `json:"nativeCurrencyDecimals"` - IsTest bool `json:"isTest"` - Layer uint64 `json:"layer"` - Enabled bool `json:"enabled"` - ChainColor string `json:"chainColor"` - ShortName string `json:"shortName"` - TokenOverrides []TokenOverride `json:"tokenOverrides"` - RelatedChainID uint64 `json:"relatedChainId"` -} - // WalletConfig extra configuration for wallet.Service. type WalletConfig struct { Enabled bool @@ -598,7 +568,7 @@ type MailserversConfig struct { Enabled bool } -// ProviderConfig extra configuration for provider.Service +// ProviderAuthConfig extra configuration for provider.Service type Web3ProviderConfig struct { Enabled bool } diff --git a/params/network_config.go b/params/network_config.go new file mode 100644 index 000000000..d9c9ad600 --- /dev/null +++ b/params/network_config.go @@ -0,0 +1,95 @@ +package params + +import "github.com/ethereum/go-ethereum/common" + +// RpcProviderAuthType defines the different types of authentication for RPC providers +type RpcProviderAuthType string + +const ( + NoAuth RpcProviderAuthType = "no-auth" // No authentication + BasicAuth RpcProviderAuthType = "basic-auth" // HTTP Header "Authorization: Basic base64(username:password)" + TokenAuth RpcProviderAuthType = "token-auth" // URL Token-based authentication "https://api.example.com/YOUR_TOKEN" +) + +// RpcProviderType defines the type of RPC provider +type RpcProviderType string + +const ( + EmbeddedProxyProviderType RpcProviderType = "embedded-proxy" // Proxy-based RPC provider + EmbeddedDirectProviderType RpcProviderType = "embedded-direct" // Direct RPC provider + UserProviderType RpcProviderType = "user" // User-defined RPC provider +) + +// RpcProvider represents an RPC provider configuration with various options +type RpcProvider struct { + ID int64 `json:"id" validate:"omitempty"` // Auto-increment ID (for sorting order) + ChainID uint64 `json:"chainId" validate:"required,gt=0"` // Chain ID of the network + Name string `json:"name" validate:"required,min=1"` // Provider name for identification + URL string `json:"url" validate:"required,url"` // Current Provider URL + EnableRPSLimiter bool `json:"enableRpsLimiter"` // Enable RPC rate limiting for this provider + Type RpcProviderType `json:"type" validate:"required,oneof=embedded-proxy embedded-direct user"` + Enabled bool `json:"enabled"` // Whether the provider is enabled + // Authentication + AuthType RpcProviderAuthType `json:"authType" validate:"required,oneof=no-auth basic-auth token-auth"` // Type of authentication + AuthLogin string `json:"authLogin" validate:"omitempty,min=1"` // Login for BasicAuth (empty string if not used) + AuthPassword string `json:"authPassword" validate:"omitempty,min=1"` // Password for BasicAuth (empty string if not used) + AuthToken string `json:"authToken" validate:"omitempty,min=1"` // Token for TokenAuth (empty string if not used) +} + +type TokenOverride struct { + Symbol string `json:"symbol"` + Address common.Address `json:"address"` +} + +type Network struct { + ChainID uint64 `json:"chainId" validate:"required,gt=0"` + ChainName string `json:"chainName" validate:"required,min=1"` + RpcProviders []RpcProvider `json:"rpcProviders" validate:"dive,required"` // List of RPC providers, in the order in which they are accessed + + // Deprecated fields (kept for backward compatibility) + // FIXME: Removal of deprecated fields in integration PR https://github.com/status-im/status-go/pull/6178 + DefaultRPCURL string `json:"defaultRpcUrl" validate:"omitempty,url"` // Deprecated: proxy rpc url + DefaultFallbackURL string `json:"defaultFallbackURL" validate:"omitempty,url"` // Deprecated: proxy fallback url + DefaultFallbackURL2 string `json:"defaultFallbackURL2" validate:"omitempty,url"` // Deprecated: second proxy fallback url + RPCURL string `json:"rpcUrl" validate:"omitempty,url"` // Deprecated: direct rpc url + OriginalRPCURL string `json:"originalRpcUrl" validate:"omitempty,url"` // Deprecated: direct rpc url if user overrides RPCURL + FallbackURL string `json:"fallbackURL" validate:"omitempty,url"` // Deprecated + OriginalFallbackURL string `json:"originalFallbackURL" validate:"omitempty,url"` // Deprecated + + BlockExplorerURL string `json:"blockExplorerUrl,omitempty" validate:"omitempty,url"` + IconURL string `json:"iconUrl,omitempty" validate:"omitempty"` + NativeCurrencyName string `json:"nativeCurrencyName,omitempty" validate:"omitempty,min=1"` + NativeCurrencySymbol string `json:"nativeCurrencySymbol,omitempty" validate:"omitempty,min=1"` + NativeCurrencyDecimals uint64 `json:"nativeCurrencyDecimals" validate:"omitempty"` + IsTest bool `json:"isTest"` + Layer uint64 `json:"layer" validate:"omitempty"` + Enabled bool `json:"enabled"` + ChainColor string `json:"chainColor" validate:"omitempty"` + ShortName string `json:"shortName" validate:"omitempty,min=1"` + TokenOverrides []TokenOverride `json:"tokenOverrides" validate:"omitempty,dive"` + RelatedChainID uint64 `json:"relatedChainId" validate:"omitempty"` +} + +func newRpcProvider(chainID uint64, name, url string, enableRpsLimiter bool, providerType RpcProviderType) *RpcProvider { + return &RpcProvider{ + ChainID: chainID, + Name: name, + URL: url, + EnableRPSLimiter: enableRpsLimiter, + Type: providerType, + Enabled: true, + AuthType: NoAuth, + } +} + +func NewUserProvider(chainID uint64, name, url string, enableRpsLimiter bool) *RpcProvider { + return newRpcProvider(chainID, name, url, enableRpsLimiter, UserProviderType) +} + +func NewProxyProvider(chainID uint64, name, url string, enableRpsLimiter bool) *RpcProvider { + return newRpcProvider(chainID, name, url, enableRpsLimiter, EmbeddedProxyProviderType) +} + +func NewDirectProvider(chainID uint64, name, url string, enableRpsLimiter bool) *RpcProvider { + return newRpcProvider(chainID, name, url, enableRpsLimiter, EmbeddedDirectProviderType) +} diff --git a/params/networkhelper/provider_utils.go b/params/networkhelper/provider_utils.go new file mode 100644 index 000000000..02cac0355 --- /dev/null +++ b/params/networkhelper/provider_utils.go @@ -0,0 +1,186 @@ +package networkhelper + +import ( + "net/url" + "strings" + + "github.com/status-im/status-go/params" +) + +// MergeProvidersPreservingUsersAndEnabledState merges new embedded providers with the current ones, +// preserving user-defined providers and maintaining the Enabled state. +func MergeProvidersPreservingUsersAndEnabledState(currentProviders, newProviders []params.RpcProvider) []params.RpcProvider { + // Create a map for quick lookup of the Enabled state by Name + enabledState := make(map[string]bool, len(currentProviders)) + for _, provider := range currentProviders { + enabledState[provider.Name] = provider.Enabled + } + + // Update the Enabled field in newProviders if the Name matches + for i := range newProviders { + if enabled, exists := enabledState[newProviders[i].Name]; exists { + newProviders[i].Enabled = enabled + } + } + + // Retain current providers of type UserProviderType and add them to the beginning of the list + mergedProviders := make([]params.RpcProvider, 0, len(currentProviders)+len(newProviders)) + for _, provider := range currentProviders { + if provider.Type == params.UserProviderType { + mergedProviders = append(mergedProviders, provider) + } + } + + // Add the updated newProviders + mergedProviders = append(mergedProviders, newProviders...) + + return mergedProviders +} + +// ToggleUserProviders enables or disables all user-defined providers and disables other types. +func ToggleUserProviders(providers []params.RpcProvider, enabled bool) []params.RpcProvider { + for i := range providers { + if providers[i].Type == params.UserProviderType { + providers[i].Enabled = enabled + } else { + providers[i].Enabled = !enabled + } + } + return providers +} + +// GetEmbeddedProviders returns the embedded providers from the list. +func GetEmbeddedProviders(providers []params.RpcProvider) []params.RpcProvider { + var embeddedProviders []params.RpcProvider + for _, provider := range providers { + if provider.Type != params.UserProviderType { + embeddedProviders = append(embeddedProviders, provider) + } + } + return embeddedProviders +} + +// GetUserProviders returns the user-defined providers from the list. +func GetUserProviders(providers []params.RpcProvider) []params.RpcProvider { + var userProviders []params.RpcProvider + for _, provider := range providers { + if provider.Type == params.UserProviderType { + userProviders = append(userProviders, provider) + } + } + return userProviders +} + +// ReplaceUserProviders replaces user-defined providers with new ones, retaining the rest of the providers. +func ReplaceUserProviders(currentProviders, newUserProviders []params.RpcProvider) []params.RpcProvider { + // Extract embedded providers from the current list + embeddedProviders := GetEmbeddedProviders(currentProviders) + userProviders := GetUserProviders(newUserProviders) + + // Combine new user providers with the existing embedded providers + return append(userProviders, embeddedProviders...) +} + +// ReplaceEmbeddedProviders replaces embedded providers with new ones, retaining user-defined providers. +func ReplaceEmbeddedProviders(currentProviders, newEmbeddedProviders []params.RpcProvider) []params.RpcProvider { + // Extract user-defined providers from the current list + userProviders := GetUserProviders(currentProviders) + embeddedProviders := GetEmbeddedProviders(newEmbeddedProviders) + + // Combine existing user-defined providers with the new embedded providers + return append(userProviders, embeddedProviders...) +} + +// OverrideEmbeddedProxyProviders updates all embedded-proxy providers in the given networks. +// It sets the `Enabled` flag and configures the `AuthLogin` and `AuthPassword` for each provider. +func OverrideEmbeddedProxyProviders(networks []params.Network, enabled bool, user, password string) []params.Network { + updatedNetworks := make([]params.Network, len(networks)) + for i, network := range networks { + // Deep copy the network to avoid mutating the input slice + updatedNetwork := network + updatedProviders := make([]params.RpcProvider, len(network.RpcProviders)) + + // Update the embedded-proxy providers + for j, provider := range network.RpcProviders { + if provider.Type == params.EmbeddedProxyProviderType { + provider.Enabled = enabled + provider.AuthLogin = user + provider.AuthPassword = password + } + updatedProviders[j] = provider + } + + updatedNetwork.RpcProviders = updatedProviders + updatedNetworks[i] = updatedNetwork + } + + return updatedNetworks +} + +func deepCopyNetworks(networks []params.Network) []params.Network { + updatedNetworks := make([]params.Network, len(networks)) + for i, network := range networks { + updatedNetwork := network + updatedNetwork.RpcProviders = make([]params.RpcProvider, len(network.RpcProviders)) + copy(updatedNetwork.RpcProviders, network.RpcProviders) + updatedNetworks[i] = updatedNetwork + } + return updatedNetworks +} + +func OverrideDirectProvidersAuth(networks []params.Network, authTokens map[string]string) []params.Network { + updatedNetworks := deepCopyNetworks(networks) + + for i := range updatedNetworks { + network := &updatedNetworks[i] + + for j := range network.RpcProviders { + provider := &network.RpcProviders[j] + + if provider.Type != params.EmbeddedDirectProviderType { + continue + } + + host, err := extractHost(provider.URL) + if err != nil { + continue + } + + for suffix, token := range authTokens { + if strings.HasSuffix(host, suffix) && token != "" { + provider.AuthType = params.TokenAuth + provider.AuthToken = token + break + } + } + } + } + return updatedNetworks +} + +func OverrideGanacheToken(networks []params.Network, ganacheURL string, chainID uint64, tokenOverride params.TokenOverride) []params.Network { + updatedNetworks := deepCopyNetworks(networks) + + for i := range updatedNetworks { + network := &updatedNetworks[i] + + if network.ChainID != chainID { + continue + } + for j := range network.RpcProviders { + if ganacheURL != "" { + network.RpcProviders[j].URL = ganacheURL + } + } + network.TokenOverrides = []params.TokenOverride{tokenOverride} + } + return updatedNetworks +} + +func extractHost(providerURL string) (string, error) { + parsedURL, err := url.Parse(providerURL) + if err != nil { + return "", err + } + return parsedURL.Host, nil +} diff --git a/params/networkhelper/provider_utils_test.go b/params/networkhelper/provider_utils_test.go new file mode 100644 index 000000000..c6f9caa35 --- /dev/null +++ b/params/networkhelper/provider_utils_test.go @@ -0,0 +1,171 @@ +package networkhelper_test + +import ( + "reflect" + "strings" + "testing" + + "github.com/ethereum/go-ethereum/common" + + "github.com/brianvoe/gofakeit/v6" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/status-im/status-go/api" + "github.com/status-im/status-go/params" + "github.com/status-im/status-go/params/networkhelper" + "github.com/status-im/status-go/rpc/network/testutil" +) + +func TestMergeProvidersPreserveEnabledAndOrder(t *testing.T) { + chainID := uint64(1) + + // Current providers with mixed types + currentProviders := []params.RpcProvider{ + *params.NewUserProvider(chainID, "UserProvider1", "https://userprovider1.example.com", true), + *params.NewUserProvider(chainID, "UserProvider2", "https://userprovider2.example.com", true), + *params.NewDirectProvider(chainID, "EmbeddedProvider1", "https://embeddedprovider1.example.com", true), + *params.NewProxyProvider(chainID, "EmbeddedProvider2", "https://embeddedprovider2.example.com", true), + } + currentProviders[1].Enabled = false // UserProvider2 is disabled + currentProviders[2].Enabled = false // EmbeddedProvider1 is disabled + + // New providers to merge + newProviders := []params.RpcProvider{ + *params.NewDirectProvider(chainID, "EmbeddedProvider1", "https://embeddedprovider1-new.example.com", true), // Should retain Enabled: false + *params.NewProxyProvider(chainID, "EmbeddedProvider3", "https://embeddedprovider3.example.com", true), // New embedded provider + *params.NewDirectProvider(chainID, "EmbeddedProvider4", "https://embeddedprovider4.example.com", true), // New embedded provider + } + + // Call MergeProviders + mergedProviders := networkhelper.MergeProvidersPreservingUsersAndEnabledState(currentProviders, newProviders) + + expectedEmbeddedProvider1 := newProviders[0] + expectedEmbeddedProvider1.Enabled = false // Should retain Enabled: false + // Expected providers after merging + expectedProviders := []params.RpcProvider{ + currentProviders[0], // UserProvider1 + currentProviders[1], // UserProvider2 + expectedEmbeddedProvider1, // EmbeddedProvider1 (should retain Enabled: false) + newProviders[1], // EmbeddedProvider3 + newProviders[2], // EmbeddedProvider4 + } + + // Assertions + require.True(t, reflect.DeepEqual(mergedProviders, expectedProviders), "Merged providers should match the expected providers") +} +func TestUpdateEmbeddedProxyProviders(t *testing.T) { + // Arrange: Create a sample list of networks with various provider types + networks := []params.Network{ + *testutil.CreateNetwork(api.MainnetChainID, "Ethereum Mainnet", []params.RpcProvider{ + *params.NewUserProvider(api.MainnetChainID, "Provider1", "https://userprovider.example.com", true), + *params.NewProxyProvider(api.MainnetChainID, "Provider2", "https://proxyprovider.example.com", true), + }), + *testutil.CreateNetwork(api.OptimismChainID, "Optimism", []params.RpcProvider{ + *params.NewDirectProvider(api.OptimismChainID, "Provider3", "https://directprovider.example.com", true), + *params.NewProxyProvider(api.OptimismChainID, "Provider4", "https://proxyprovider2.example.com", true), + }), + } + networks[0].RpcProviders[1].Enabled = false + networks[1].RpcProviders[1].Enabled = false + + user := gofakeit.Username() + password := gofakeit.LetterN(5) + + // Call the function to update embedded-proxy providers + updatedNetworks := networkhelper.OverrideEmbeddedProxyProviders(networks, true, user, password) + + // Verify the networks + for i, network := range updatedNetworks { + networkCopy := network + expectedNetwork := &networks[i] + testutil.CompareNetworks(t, expectedNetwork, &networkCopy) + + for j, provider := range networkCopy.RpcProviders { + expectedProvider := expectedNetwork.RpcProviders[j] + if provider.Type == params.EmbeddedProxyProviderType { + assert.True(t, provider.Enabled, "Provider Enabled state should be overridden") + assert.Equal(t, user, provider.AuthLogin, "Provider AuthLogin should be overridden") + assert.Equal(t, password, provider.AuthPassword, "Provider AuthPassword should be overridden") + } else { + assert.Equal(t, expectedProvider.Enabled, provider.Enabled, "Provider Enabled state should remain unchanged") + assert.Equal(t, expectedProvider.AuthLogin, provider.AuthLogin, "Provider AuthLogin should remain unchanged") + assert.Equal(t, expectedProvider.AuthPassword, provider.AuthPassword, "Provider AuthPassword should remain unchanged") + } + } + } + +} + +func TestOverrideDirectProvidersAuth(t *testing.T) { + // Create a sample list of networks with various provider types + networks := []params.Network{ + *testutil.CreateNetwork(api.MainnetChainID, "Ethereum Mainnet", []params.RpcProvider{ + *params.NewUserProvider(api.MainnetChainID, "Provider1", "https://user.example.com/", true), + *params.NewDirectProvider(api.MainnetChainID, "Provider2", "https://mainnet.infura.io/v3/", true), + *params.NewDirectProvider(api.MainnetChainID, "Provider3", "https://eth-archival.rpc.grove.city/v1/", true), + }), + *testutil.CreateNetwork(api.OptimismChainID, "Optimism", []params.RpcProvider{ + *params.NewDirectProvider(api.OptimismChainID, "Provider4", "https://optimism.infura.io/v3/", true), + *params.NewDirectProvider(api.OptimismChainID, "Provider5", "https://op.grove.city/v1/", true), + }), + } + + authTokens := map[string]string{ + "infura.io": gofakeit.UUID(), + "grove.city": gofakeit.UUID(), + "example.com": gofakeit.UUID(), + } + + // Call OverrideDirectProvidersAuth + updatedNetworks := networkhelper.OverrideDirectProvidersAuth(networks, authTokens) + + // Verify the networks have updated auth tokens correctly + for i, network := range updatedNetworks { + for j, provider := range network.RpcProviders { + expectedProvider := networks[i].RpcProviders[j] + switch { + case strings.Contains(provider.URL, "infura.io"): + assert.Equal(t, params.TokenAuth, provider.AuthType) + assert.Equal(t, authTokens["infura.io"], provider.AuthToken) + assert.NotEqual(t, expectedProvider.AuthToken, provider.AuthToken) + case strings.Contains(provider.URL, "grove.city"): + assert.Equal(t, params.TokenAuth, provider.AuthType) + assert.Equal(t, authTokens["grove.city"], provider.AuthToken) + assert.NotEqual(t, expectedProvider.AuthToken, provider.AuthToken) + case strings.Contains(provider.URL, "example.com"): + assert.Equal(t, params.NoAuth, provider.AuthType) // should not update user providers + default: + assert.Equal(t, expectedProvider.AuthType, provider.AuthType) + assert.Equal(t, expectedProvider.AuthToken, provider.AuthToken) + } + } + } +} + +func TestOverrideGanacheTokenOverrides(t *testing.T) { + // Create a sample list of networks with various ChainIDs + networks := []params.Network{ + *testutil.CreateNetwork(api.MainnetChainID, "Ethereum Mainnet", nil), + *testutil.CreateNetwork(api.OptimismChainID, "Optimism", nil), + *testutil.CreateNetwork(api.MainnetChainID, "Mainnet Duplicate", nil), + } + + ganacheTokenOverride := params.TokenOverride{ + Symbol: "SNT", + Address: common.HexToAddress("0x8571Ddc46b10d31EF963aF49b6C7799Ea7eff818"), + } + + // Call OverrideGanacheTokenOverrides + updatedNetworks := networkhelper.OverrideGanacheToken(networks, "url", api.MainnetChainID, ganacheTokenOverride) + + // Verify that only networks with the specified ChainID have the token override applied + for _, network := range updatedNetworks { + if network.ChainID == api.MainnetChainID { + require.NotNil(t, network.TokenOverrides, "TokenOverrides should not be nil for ChainID %d", network.ChainID) + assert.Contains(t, network.TokenOverrides, ganacheTokenOverride, "TokenOverrides should contain the ganache token") + } else { + assert.Nil(t, network.TokenOverrides, "TokenOverrides should be nil for ChainID %d", network.ChainID) + } + } +} diff --git a/params/networkhelper/validate.go b/params/networkhelper/validate.go new file mode 100644 index 000000000..2b925773c --- /dev/null +++ b/params/networkhelper/validate.go @@ -0,0 +1,49 @@ +package networkhelper + +import ( + "gopkg.in/go-playground/validator.v9" + + "github.com/status-im/status-go/params" +) + +func GetValidator() *validator.Validate { + validate := validator.New() + + // Register struct-level validation for RpcProvider + validate.RegisterStructValidation(rpcProviderStructLevelValidation, params.RpcProvider{}) + + return validate +} + +func rpcProviderStructLevelValidation(sl validator.StructLevel) { + provider := sl.Current().Interface().(params.RpcProvider) + + switch provider.AuthType { + case params.NoAuth: + if provider.AuthLogin != "" || provider.AuthPassword != "" || provider.AuthToken != "" { + sl.ReportError(provider.AuthLogin, "AuthLogin", "authLogin", "noauth_fields_empty", "") + sl.ReportError(provider.AuthPassword, "AuthPassword", "authPassword", "noauth_fields_empty", "") + sl.ReportError(provider.AuthToken, "AuthToken", "authToken", "noauth_fields_empty", "") + } + case params.BasicAuth: + if provider.AuthLogin == "" { + sl.ReportError(provider.AuthLogin, "AuthLogin", "authLogin", "required", "") + } + if provider.AuthPassword == "" { + sl.ReportError(provider.AuthPassword, "AuthPassword", "authPassword", "required", "") + } + if provider.AuthToken != "" { + sl.ReportError(provider.AuthToken, "AuthToken", "authToken", "basic_auth_token_empty", "") + } + case params.TokenAuth: + if provider.AuthToken == "" { + sl.ReportError(provider.AuthToken, "AuthToken", "authToken", "required", "") + } + if provider.AuthLogin != "" || provider.AuthPassword != "" { + sl.ReportError(provider.AuthLogin, "AuthLogin", "authLogin", "tokenauth_fields_empty", "") + sl.ReportError(provider.AuthPassword, "AuthPassword", "authPassword", "tokenauth_fields_empty", "") + } + default: + sl.ReportError(provider.AuthType, "AuthType", "authType", "invalid_auth_type", "") + } +} diff --git a/params/networkhelper/validate_test.go b/params/networkhelper/validate_test.go new file mode 100644 index 000000000..95d0f1c03 --- /dev/null +++ b/params/networkhelper/validate_test.go @@ -0,0 +1,184 @@ +package networkhelper + +import ( + "testing" + + "github.com/stretchr/testify/require" + + "gopkg.in/go-playground/validator.v9" + + "github.com/status-im/status-go/params" +) + +func TestValidation(t *testing.T) { + validate := GetValidator() + + // Test cases for RpcProvider + providerTests := []struct { + name string + provider params.RpcProvider + expectErr bool + }{ + { + name: "Valid Provider", + provider: params.RpcProvider{ + ChainID: 1, + Name: "Mainnet Provider", + URL: "https://provider.example.com", + Type: params.UserProviderType, + Enabled: true, + AuthType: params.NoAuth, + EnableRPSLimiter: false, + }, + expectErr: false, + }, + { + name: "Missing Provider Name", + provider: params.RpcProvider{ + ChainID: 1, + URL: "https://provider.example.com", + Type: params.UserProviderType, + }, + expectErr: true, + }, + { + name: "Invalid AuthType", + provider: params.RpcProvider{ + ChainID: 1, + Name: "Invalid Auth Provider", + URL: "https://provider.example.com", + Type: params.UserProviderType, + AuthType: "invalid-auth-type", + }, + expectErr: true, + }, + { + name: "BasicAuth without Login", + provider: params.RpcProvider{ + ChainID: 1, + Name: "BasicAuth Provider", + URL: "https://provider.example.com", + Type: params.UserProviderType, + AuthType: params.BasicAuth, + }, + expectErr: true, + }, + { + name: "TokenAuth without Token", + provider: params.RpcProvider{ + ChainID: 1, + Name: "TokenAuth Provider", + URL: "https://provider.example.com", + Type: params.UserProviderType, + AuthType: params.TokenAuth, + }, + expectErr: true, + }, + { + name: "NoAuth with Login", + provider: params.RpcProvider{ + ChainID: 1, + Name: "NoAuth Provider", + URL: "https://provider.example.com", + Type: params.UserProviderType, + AuthType: params.NoAuth, + AuthLogin: "user", + }, + expectErr: true, + }, + } + + for _, test := range providerTests { + t.Run("RpcProvider: "+test.name, func(t *testing.T) { + err := validate.Struct(test.provider) + if test.expectErr { + require.Error(t, err, "Expected error but got nil for test case '%s'", test.name) + } else { + require.NoError(t, err, "Did not expect error but got '%v' for test case '%s'", err, test.name) + } + }) + } +} + +// Test cases for Network +func TestNetworkValidation(t *testing.T) { + validate := validator.New() + + networkTests := []struct { + name string + network params.Network + expectErr bool + }{ + { + name: "Valid Network", + network: params.Network{ + ChainID: 1, + ChainName: "Ethereum Mainnet", + BlockExplorerURL: "https://etherscan.io", + NativeCurrencyName: "Ether", + NativeCurrencySymbol: "ETH", + NativeCurrencyDecimals: 18, + IsTest: false, + Layer: 1, + Enabled: true, + ChainColor: "#E90101", + ShortName: "eth", + RpcProviders: []params.RpcProvider{ + { + ChainID: 1, + Name: "Mainnet Provider", + URL: "https://provider.example.com", + Type: params.UserProviderType, + Enabled: true, + AuthType: params.NoAuth, + }, + }, + }, + expectErr: false, + }, + { + name: "Missing Chain Name", + network: params.Network{ + ChainID: 1, + RpcProviders: []params.RpcProvider{ + { + ChainID: 1, + Name: "Mainnet Provider", + URL: "https://provider.example.com", + Type: params.UserProviderType, + Enabled: true, + }, + }, + }, + expectErr: true, + }, + { + name: "Invalid Provider in Network", + network: params.Network{ + ChainID: 1, + ChainName: "Ethereum Mainnet", + RpcProviders: []params.RpcProvider{ + { + ChainID: 1, + Name: "", + URL: "https://provider.example.com", + Type: params.UserProviderType, + Enabled: true, + }, + }, + }, + expectErr: true, + }, + } + + for _, test := range networkTests { + t.Run("Network: "+test.name, func(t *testing.T) { + err := validate.Struct(test.network) + if test.expectErr { + require.Error(t, err, "Expected error but got nil for test case '%s'", test.name) + } else { + require.NoError(t, err, "Did not expect error but got '%v' for test case '%s'", err, test.name) + } + }) + } +} diff --git a/rpc/network/db/network_db.go b/rpc/network/db/network_db.go new file mode 100644 index 000000000..885a72243 --- /dev/null +++ b/rpc/network/db/network_db.go @@ -0,0 +1,257 @@ +package db + +import ( + "fmt" + + sq "github.com/Masterminds/squirrel" + "gopkg.in/go-playground/validator.v9" + + "github.com/status-im/status-go/params" + "github.com/status-im/status-go/sqlite" +) + +// NetworksPersistenceInterface describes the interface for managing networks and providers. +type NetworksPersistenceInterface interface { + // GetNetworks returns networks based on filters. + GetNetworks(onlyEnabled bool, chainID *uint64) ([]*params.Network, error) + // GetAllNetworks returns all networks. + GetAllNetworks() ([]*params.Network, error) + // GetNetworkByChainID returns a network by ChainID. + GetNetworkByChainID(chainID uint64) ([]*params.Network, error) + // GetEnabledNetworks returns enabled networks. + GetEnabledNetworks() ([]*params.Network, error) + // SetNetworks replaces all networks with new ones and their providers. + SetNetworks(networks []params.Network) error + // UpsertNetwork adds or updates a network and its providers. + UpsertNetwork(network *params.Network) error + // DeleteNetwork deletes a network by ChainID and its providers. + DeleteNetwork(chainID uint64) error + // DeleteAllNetworks deletes all networks and their providers. + DeleteAllNetworks() error + + GetRpcPersistence() RpcProvidersPersistenceInterface +} + +// NetworksPersistence manages networks and their providers. +type NetworksPersistence struct { + db sqlite.StatementExecutor + rpcPersistence RpcProvidersPersistenceInterface + validator *validator.Validate +} + +// NewNetworksPersistence creates a new instance of NetworksPersistence. +func NewNetworksPersistence(db sqlite.StatementExecutor) *NetworksPersistence { + return &NetworksPersistence{ + db: db, + rpcPersistence: NewRpcProvidersPersistence(db), + validator: validator.New(), + } +} + +// GetRpcPersistence returns an instance of RpcProvidersPersistenceInterface. +func (n *NetworksPersistence) GetRpcPersistence() RpcProvidersPersistenceInterface { + return n.rpcPersistence +} + +// GetNetworks returns networks based on filters. +func (n *NetworksPersistence) GetNetworks(onlyEnabled bool, chainID *uint64) ([]*params.Network, error) { + q := sq.Select( + "chain_id", "chain_name", "rpc_url", "fallback_url", + "block_explorer_url", "icon_url", "native_currency_name", "native_currency_symbol", "native_currency_decimals", + "is_test", "layer", "enabled", "chain_color", "short_name", "related_chain_id", + ). + From("networks"). + OrderBy("chain_id ASC") + + if onlyEnabled { + q = q.Where(sq.Eq{"enabled": true}) + } + if chainID != nil { + q = q.Where(sq.Eq{"chain_id": *chainID}) + } + + query, args, err := q.ToSql() + if err != nil { + return nil, err + } + + rows, err := n.db.Query(query, args...) + if err != nil { + return nil, err + } + defer rows.Close() + + result := make([]*params.Network, 0, 10) + for rows.Next() { + network := ¶ms.Network{} + err := rows.Scan( + &network.ChainID, &network.ChainName, &network.RPCURL, &network.FallbackURL, + &network.BlockExplorerURL, &network.IconURL, &network.NativeCurrencyName, &network.NativeCurrencySymbol, + &network.NativeCurrencyDecimals, &network.IsTest, &network.Layer, &network.Enabled, &network.ChainColor, + &network.ShortName, &network.RelatedChainID, + ) + if err != nil { + return nil, err + } + + // Fetch RPC providers for the network + providers, err := n.rpcPersistence.GetRpcProviders(network.ChainID) + if err != nil { + return nil, fmt.Errorf("failed to fetch RPC providers for chain_id %d: %w", network.ChainID, err) + } + network.RpcProviders = providers + + // Fill deprecated URLs if necessary (assuming this is a function you have) + FillDeprecatedURLs(network, providers) + + result = append(result, network) + } + + if err = rows.Err(); err != nil { + return nil, err + } + + return result, nil +} + +// GetAllNetworks returns all networks. +func (n *NetworksPersistence) GetAllNetworks() ([]*params.Network, error) { + return n.GetNetworks(false, nil) +} + +// GetNetworkByChainID returns a network by ChainID. +func (n *NetworksPersistence) GetNetworkByChainID(chainID uint64) ([]*params.Network, error) { + return n.GetNetworks(false, &chainID) +} + +// GetEnabledNetworks returns enabled networks. +func (n *NetworksPersistence) GetEnabledNetworks() ([]*params.Network, error) { + return n.GetNetworks(true, nil) +} + +// SetNetworks replaces all networks with new ones and their providers. +// Note: Transaction management should be handled by the caller. +func (n *NetworksPersistence) SetNetworks(networks []params.Network) error { + // Delete all networks and their providers + err := n.DeleteAllNetworks() + if err != nil { + return err + } + + // Upsert networks and their providers + for i := range networks { + err := n.UpsertNetwork(&networks[i]) + if err != nil { + return fmt.Errorf("failed to upsert network with chain_id %d: %w", networks[i].ChainID, err) + } + } + + return nil +} + +// DeleteAllNetworks deletes all networks and their RPC providers. +// Note: Transaction management should be handled by the caller. +func (n *NetworksPersistence) DeleteAllNetworks() error { + // Delete all RPC providers + err := n.rpcPersistence.DeleteAllRpcProviders() + if err != nil { + return fmt.Errorf("failed to delete all RPC providers: %w", err) + } + + // Delete all networks + q := sq.Delete("networks") + + query, args, err := q.ToSql() + if err != nil { + return fmt.Errorf("failed to build delete query: %w", err) + } + + _, err = n.db.Exec(query, args...) + if err != nil { + return fmt.Errorf("failed to execute delete query: %w", err) + } + + return nil +} + +// UpsertNetwork adds or updates a network and its providers. +// Note: Transaction management should be handled by the caller. +func (n *NetworksPersistence) UpsertNetwork(network *params.Network) error { + // Validate the network + if err := n.validator.Struct(network); err != nil { + return fmt.Errorf("network validation failed: %w", err) + } + + // Upsert the network + err := n.upsertNetwork(network) + if err != nil { + return fmt.Errorf("failed to upsert network for chain_id %d: %w", network.ChainID, err) + } + + // Set the RPC providers + err = n.rpcPersistence.SetRpcProviders(network.ChainID, network.RpcProviders) + if err != nil { + return fmt.Errorf("failed to set RPC providers for chain_id %d: %w", network.ChainID, err) + } + + return nil +} + +// upsertNetwork handles the logic for inserting or updating a network record. +func (n *NetworksPersistence) upsertNetwork(network *params.Network) error { + q := sq.Insert("networks"). + Columns( + "chain_id", "chain_name", "rpc_url", "original_rpc_url", "fallback_url", "original_fallback_url", + "block_explorer_url", "icon_url", "native_currency_name", "native_currency_symbol", "native_currency_decimals", + "is_test", "layer", "enabled", "chain_color", "short_name", "related_chain_id", + ). + Values( + network.ChainID, network.ChainName, network.RPCURL, network.OriginalRPCURL, network.FallbackURL, network.OriginalFallbackURL, + network.BlockExplorerURL, network.IconURL, network.NativeCurrencyName, network.NativeCurrencySymbol, network.NativeCurrencyDecimals, + network.IsTest, network.Layer, network.Enabled, network.ChainColor, network.ShortName, network.RelatedChainID, + ). + Suffix("ON CONFLICT(chain_id) DO UPDATE SET " + + "chain_name = excluded.chain_name, rpc_url = excluded.rpc_url, original_rpc_url = excluded.original_rpc_url, " + + "fallback_url = excluded.fallback_url, original_fallback_url = excluded.original_fallback_url, " + + "block_explorer_url = excluded.block_explorer_url, icon_url = excluded.icon_url, " + + "native_currency_name = excluded.native_currency_name, native_currency_symbol = excluded.native_currency_symbol, " + + "native_currency_decimals = excluded.native_currency_decimals, is_test = excluded.is_test, " + + "layer = excluded.layer, enabled = excluded.enabled, chain_color = excluded.chain_color, " + + "short_name = excluded.short_name, related_chain_id = excluded.related_chain_id") + + query, args, err := q.ToSql() + if err != nil { + return fmt.Errorf("failed to build upsert query: %w", err) + } + + _, err = n.db.Exec(query, args...) + if err != nil { + return fmt.Errorf("failed to execute upsert query: %w", err) + } + + return nil +} + +// DeleteNetwork deletes a network and its associated RPC providers. +// Note: Transaction management should be handled by the caller. +func (n *NetworksPersistence) DeleteNetwork(chainID uint64) error { + // Delete RPC providers + err := n.rpcPersistence.DeleteRpcProviders(chainID) + if err != nil { + return fmt.Errorf("failed to delete RPC providers for chain_id %d: %w", chainID, err) + } + + // Delete the network record + q := sq.Delete("networks").Where(sq.Eq{"chain_id": chainID}) + query, args, err := q.ToSql() + if err != nil { + return fmt.Errorf("failed to build delete query: %w", err) + } + + _, err = n.db.Exec(query, args...) + if err != nil { + return fmt.Errorf("failed to execute delete query for chain_id %d: %w", chainID, err) + } + + return nil +} diff --git a/rpc/network/db/network_db_test.go b/rpc/network/db/network_db_test.go new file mode 100644 index 000000000..4587c529b --- /dev/null +++ b/rpc/network/db/network_db_test.go @@ -0,0 +1,197 @@ +package db_test + +import ( + "database/sql" + "testing" + + "github.com/stretchr/testify/require" + "github.com/stretchr/testify/suite" + + "github.com/status-im/status-go/api" + "github.com/status-im/status-go/appdatabase" + "github.com/status-im/status-go/params" + "github.com/status-im/status-go/rpc/network/db" + "github.com/status-im/status-go/rpc/network/testutil" + "github.com/status-im/status-go/t/helpers" +) + +type NetworksPersistenceTestSuite struct { + suite.Suite + db *sql.DB + cleanup func() error + networksPersistence db.NetworksPersistenceInterface +} + +func (s *NetworksPersistenceTestSuite) SetupTest() { + memDb, cleanup, err := helpers.SetupTestSQLDB(appdatabase.DbInitializer{}, "networks-tests") + s.Require().NoError(err) + s.db = memDb + s.cleanup = cleanup + s.networksPersistence = db.NewNetworksPersistence(memDb) +} + +func (s *NetworksPersistenceTestSuite) TearDownTest() { + if s.cleanup != nil { + err := s.cleanup() + require.NoError(s.T(), err) + } +} + +func TestNetworksPersistenceTestSuite(t *testing.T) { + suite.Run(t, new(NetworksPersistenceTestSuite)) +} + +// Helper function to create default providers for a given chainID +func DefaultProviders(chainID uint64) []params.RpcProvider { + return []params.RpcProvider{ + { + Name: "Provider1", + ChainID: chainID, + URL: "https://rpc.provider1.io", + Type: params.UserProviderType, + Enabled: true, + AuthType: params.NoAuth, + }, + { + Name: "Provider2", + ChainID: chainID, + URL: "https://rpc.provider2.io", + Type: params.EmbeddedProxyProviderType, + Enabled: true, + AuthType: params.BasicAuth, + AuthLogin: "user1", + AuthPassword: "password1", + }, + } +} + +// Helper function to add and verify networks +func (s *NetworksPersistenceTestSuite) addAndVerifyNetworks(networks []*params.Network) { + networkValues := make([]params.Network, 0, len(networks)) + for _, network := range networks { + networkValues = append(networkValues, *network) + } + err := s.networksPersistence.SetNetworks(networkValues) + s.Require().NoError(err) + + s.verifyNetworks(networks) +} + +// Helper function to verify networks against the database +func (s *NetworksPersistenceTestSuite) verifyNetworks(networks []*params.Network) { + allNetworks, err := s.networksPersistence.GetAllNetworks() + s.Require().NoError(err) + testutil.CompareNetworksList(s.T(), networks, allNetworks) +} + +// Helper function to verify network deletion +func (s *NetworksPersistenceTestSuite) verifyNetworkDeletion(chainID uint64) { + nets, err := s.networksPersistence.GetNetworkByChainID(chainID) + s.Require().NoError(err) + s.Require().Len(nets, 0) + + providers, err := s.networksPersistence.GetRpcPersistence().GetRpcProviders(chainID) + s.Require().NoError(err) + s.Require().Len(providers, 0) +} + +// Tests + +func (s *NetworksPersistenceTestSuite) TestAddAndGetNetworkWithProviders() { + network := testutil.CreateNetwork(api.OptimismChainID, "Optimism Mainnet", []params.RpcProvider{ + testutil.CreateProvider(api.OptimismChainID, "Provider1", params.UserProviderType, true, "https://rpc.optimism.io"), + testutil.CreateProvider(api.OptimismChainID, "Provider2", params.EmbeddedProxyProviderType, false, "https://backup.optimism.io"), + }) + s.addAndVerifyNetworks([]*params.Network{network}) +} + +func (s *NetworksPersistenceTestSuite) TestDeleteNetworkWithProviders() { + network := testutil.CreateNetwork(api.OptimismChainID, "Optimism Mainnet", DefaultProviders(api.OptimismChainID)) + s.addAndVerifyNetworks([]*params.Network{network}) + + err := s.networksPersistence.DeleteNetwork(network.ChainID) + s.Require().NoError(err) + + s.verifyNetworkDeletion(network.ChainID) +} + +func (s *NetworksPersistenceTestSuite) TestUpdateNetworkAndProviders() { + network := testutil.CreateNetwork(api.OptimismChainID, "Optimism Mainnet", DefaultProviders(api.OptimismChainID)) + s.addAndVerifyNetworks([]*params.Network{network}) + + // Update fields + network.ChainName = "Updated Optimism Mainnet" + network.RpcProviders = []params.RpcProvider{ + testutil.CreateProvider(api.OptimismChainID, "UpdatedProvider", params.UserProviderType, true, "https://rpc.optimism.updated.io"), + } + + s.addAndVerifyNetworks([]*params.Network{network}) +} + +func (s *NetworksPersistenceTestSuite) TestDeleteAllNetworks() { + networks := []*params.Network{ + testutil.CreateNetwork(api.MainnetChainID, "Ethereum Mainnet", DefaultProviders(api.MainnetChainID)), + testutil.CreateNetwork(api.SepoliaChainID, "Sepolia Testnet", DefaultProviders(api.SepoliaChainID)), + } + s.addAndVerifyNetworks(networks) + + err := s.networksPersistence.DeleteAllNetworks() + s.Require().NoError(err) + + allNetworks, err := s.networksPersistence.GetAllNetworks() + s.Require().NoError(err) + s.Require().Len(allNetworks, 0) +} + +func (s *NetworksPersistenceTestSuite) TestSetNetworks() { + initialNetworks := []*params.Network{ + testutil.CreateNetwork(api.MainnetChainID, "Ethereum Mainnet", DefaultProviders(api.MainnetChainID)), + testutil.CreateNetwork(api.SepoliaChainID, "Sepolia Testnet", DefaultProviders(api.SepoliaChainID)), + } + newNetworks := []*params.Network{ + testutil.CreateNetwork(api.OptimismChainID, "Optimism Mainnet", DefaultProviders(api.OptimismChainID)), + } + + // Add initial networks + s.addAndVerifyNetworks(initialNetworks) + + // Replace with new networks + s.addAndVerifyNetworks(newNetworks) + + // Verify old networks are removed + s.verifyNetworkDeletion(api.MainnetChainID) + s.verifyNetworkDeletion(api.SepoliaChainID) +} + +func (s *NetworksPersistenceTestSuite) TestValidationForNetworksAndProviders() { + // Invalid Network: Missing required ChainName + invalidNetwork := testutil.CreateNetwork(api.MainnetChainID, "", DefaultProviders(api.MainnetChainID)) + + // Invalid Provider: Missing URL + invalidProvider := params.RpcProvider{ + Name: "InvalidProvider", + ChainID: api.MainnetChainID, + URL: "", // Invalid + Type: params.UserProviderType, + Enabled: true, + } + + // Add invalid provider to a valid network + validNetworkWithInvalidProvider := testutil.CreateNetwork(api.OptimismChainID, "Optimism Mainnet", []params.RpcProvider{invalidProvider}) + + // Invalid networks and providers should fail validation + networksToValidate := []*params.Network{ + invalidNetwork, + validNetworkWithInvalidProvider, + } + + for _, network := range networksToValidate { + err := s.networksPersistence.UpsertNetwork(network) + s.Require().Error(err, "Expected validation to fail for invalid network or provider") + } + + // Ensure no invalid data is saved in the database + allNetworks, err := s.networksPersistence.GetAllNetworks() + s.Require().NoError(err) + s.Require().Len(allNetworks, 0, "No invalid networks should be saved") +} diff --git a/rpc/network/db/rpc_provider_db.go b/rpc/network/db/rpc_provider_db.go new file mode 100644 index 000000000..b9eb775f3 --- /dev/null +++ b/rpc/network/db/rpc_provider_db.go @@ -0,0 +1,243 @@ +package db + +import ( + "fmt" + + sq "github.com/Masterminds/squirrel" + "gopkg.in/go-playground/validator.v9" + + "github.com/status-im/status-go/params" + "github.com/status-im/status-go/sqlite" +) + +// Interface for managing RPC providers +type RpcProvidersPersistenceInterface interface { + GetRpcProviders(chainID uint64) ([]params.RpcProvider, error) + GetRpcProvidersByType(chainID uint64, providerType params.RpcProviderType) ([]params.RpcProvider, error) + AddRpcProvider(provider params.RpcProvider) error + DeleteRpcProviders(chainID uint64) error + DeleteAllRpcProviders() error + UpdateRpcProvider(provider params.RpcProvider) error + SetRpcProviders(chainID uint64, newProviders []params.RpcProvider) error +} + +// Struct for managing RPC providers +type RpcProvidersPersistence struct { + db sqlite.StatementExecutor + validator *validator.Validate +} + +// Constructor for RpcProvidersPersistence with validator +func NewRpcProvidersPersistence(db sqlite.StatementExecutor) *RpcProvidersPersistence { + return &RpcProvidersPersistence{ + db: db, + validator: validator.New(), + } +} + +// Retrieve all providers for a specific ChainID +func (p *RpcProvidersPersistence) GetRpcProviders(chainID uint64) ([]params.RpcProvider, error) { + q := sq.Select( + "id", + "chain_id", + "name", + "url", + "enable_rps_limiter", + "type", + "enabled", + "auth_type", + "auth_login", + "auth_password", + "auth_token", + ). + From("rpc_providers"). + Where(sq.Eq{"chain_id": chainID}). + OrderBy("id ASC") + + query, args, err := q.ToSql() + if err != nil { + return nil, fmt.Errorf("failed to build query: %w", err) + } + + rows, err := p.db.Query(query, args...) + if err != nil { + return nil, fmt.Errorf("failed to execute query: %w", err) + } + defer rows.Close() + + var providers []params.RpcProvider + for rows.Next() { + var provider params.RpcProvider + err := rows.Scan( + &provider.ID, + &provider.ChainID, + &provider.Name, + &provider.URL, + &provider.EnableRPSLimiter, + &provider.Type, + &provider.Enabled, + &provider.AuthType, + &provider.AuthLogin, + &provider.AuthPassword, + &provider.AuthToken, + ) + if err != nil { + return nil, fmt.Errorf("failed to scan row: %w", err) + } + providers = append(providers, provider) + } + + if err = rows.Err(); err != nil { + return nil, fmt.Errorf("row iteration error: %w", err) + } + + return providers, nil +} + +// Retrieve providers of a specific type +func (p *RpcProvidersPersistence) GetRpcProvidersByType(chainID uint64, providerType params.RpcProviderType) ([]params.RpcProvider, error) { + allProviders, err := p.GetRpcProviders(chainID) + if err != nil { + return nil, err + } + + result := make([]params.RpcProvider, 0, len(allProviders)) + for _, provider := range allProviders { + if provider.Type == providerType { + result = append(result, provider) + } + } + + return result, nil +} + +// Add a new provider +func (p *RpcProvidersPersistence) AddRpcProvider(provider params.RpcProvider) error { + // Validate the provider + if err := p.validator.Struct(provider); err != nil { + return fmt.Errorf("validation failed: %w", err) + } + + // Proceed with adding the provider to the database + q := sq.Insert("rpc_providers"). + Columns( + "chain_id", + "name", + "url", + "enable_rps_limiter", + "type", + "enabled", + "auth_type", + "auth_login", + "auth_password", + "auth_token", + ). + Values( + provider.ChainID, + provider.Name, + provider.URL, + provider.EnableRPSLimiter, + provider.Type, + provider.Enabled, + provider.AuthType, + provider.AuthLogin, + provider.AuthPassword, + provider.AuthToken, + ) + + query, args, err := q.ToSql() + if err != nil { + return fmt.Errorf("failed to build insert query: %w", err) + } + + _, err = p.db.Exec(query, args...) + if err != nil { + return fmt.Errorf("failed to execute insert query: %w", err) + } + + return nil +} + +// Delete providers for a specific ChainID +func (p *RpcProvidersPersistence) DeleteRpcProviders(chainID uint64) error { + q := sq.Delete("rpc_providers"). + Where(sq.Eq{"chain_id": chainID}) + + query, args, err := q.ToSql() + if err != nil { + return fmt.Errorf("failed to build delete query: %w", err) + } + + _, err = p.db.Exec(query, args...) + if err != nil { + return fmt.Errorf("failed to execute delete query: %w", err) + } + + return nil +} + +// Delete all providers +func (p *RpcProvidersPersistence) DeleteAllRpcProviders() error { + q := sq.Delete("rpc_providers") + + query, args, err := q.ToSql() + if err != nil { + return fmt.Errorf("failed to build delete query: %w", err) + } + + _, err = p.db.Exec(query, args...) + if err != nil { + return fmt.Errorf("failed to execute delete query: %w", err) + } + + return nil +} + +// Update an existing provider +func (p *RpcProvidersPersistence) UpdateRpcProvider(provider params.RpcProvider) error { + // Validate the provider + if err := p.validator.Struct(provider); err != nil { + return fmt.Errorf("validation failed: %w", err) + } + + // Proceed with updating the provider in the database + q := sq.Update("rpc_providers"). + SetMap(sq.Eq{ + "url": provider.URL, + "enable_rps_limiter": provider.EnableRPSLimiter, + "type": provider.Type, + "enabled": provider.Enabled, + "auth_type": provider.AuthType, + "auth_login": provider.AuthLogin, + "auth_password": provider.AuthPassword, + "auth_token": provider.AuthToken, + }). + Where(sq.Eq{"id": provider.ID}) + + query, args, err := q.ToSql() + if err != nil { + return fmt.Errorf("failed to build update query: %w", err) + } + + _, err = p.db.Exec(query, args...) + if err != nil { + return fmt.Errorf("failed to execute update query: %w", err) + } + + return nil +} + +// Set the list of providers for a ChainID, replacing any existing providers +func (p *RpcProvidersPersistence) SetRpcProviders(chainID uint64, newProviders []params.RpcProvider) error { + if err := p.DeleteRpcProviders(chainID); err != nil { + return fmt.Errorf("failed to delete existing providers: %w", err) + } + + for _, provider := range newProviders { + if err := p.AddRpcProvider(provider); err != nil { + return fmt.Errorf("failed to add new provider: %w", err) + } + } + + return nil +} diff --git a/rpc/network/db/rpc_provider_db_test.go b/rpc/network/db/rpc_provider_db_test.go new file mode 100644 index 000000000..624e8f59b --- /dev/null +++ b/rpc/network/db/rpc_provider_db_test.go @@ -0,0 +1,157 @@ +package db_test + +import ( + "database/sql" + "testing" + + "github.com/stretchr/testify/require" + "github.com/stretchr/testify/suite" + + "github.com/status-im/status-go/api" + "github.com/status-im/status-go/appdatabase" + "github.com/status-im/status-go/params" + "github.com/status-im/status-go/rpc/network/db" + "github.com/status-im/status-go/rpc/network/testutil" + "github.com/status-im/status-go/t/helpers" +) + +type RpcProviderPersistenceTestSuite struct { + suite.Suite + db *sql.DB + rpcPersistence *db.RpcProvidersPersistence +} + +func (s *RpcProviderPersistenceTestSuite) SetupTest() { + testDb := setupTestNetworkDB(s.T()) + s.db = testDb + s.rpcPersistence = db.NewRpcProvidersPersistence(testDb) +} + +func setupTestNetworkDB(t *testing.T) *sql.DB { + testDb, cleanup, err := helpers.SetupTestSQLDB(appdatabase.DbInitializer{}, "rpc-providers-tests") + require.NoError(t, err) + t.Cleanup(func() { require.NoError(t, cleanup()) }) + return testDb +} + +func TestRpcProviderPersistenceTestSuite(t *testing.T) { + suite.Run(t, new(RpcProviderPersistenceTestSuite)) +} + +// Test cases + +func (s *RpcProviderPersistenceTestSuite) TestAddAndGetRpcProvider() { + provider := testutil.CreateProvider(api.MainnetChainID, "Provider1", params.UserProviderType, true, "https://provider1.example.com") + + err := s.rpcPersistence.AddRpcProvider(provider) + s.Require().NoError(err) + + // Verify the added provider + providers, err := s.rpcPersistence.GetRpcProviders(api.MainnetChainID) + s.Require().NoError(err) + testutil.CompareProvidersList(s.T(), []params.RpcProvider{provider}, providers) +} + +func (s *RpcProviderPersistenceTestSuite) TestGetRpcProvidersByType() { + providers := []params.RpcProvider{ + testutil.CreateProvider(api.MainnetChainID, "UserProvider1", params.UserProviderType, true, "https://provider1.example.com"), + testutil.CreateProvider(api.MainnetChainID, "EmbeddedDirect1", params.EmbeddedDirectProviderType, false, "https://provider2.example.com"), + testutil.CreateProvider(api.MainnetChainID, "UserProvider2", params.UserProviderType, false, "https://provider3.example.com"), + testutil.CreateProvider(api.MainnetChainID, "EmbeddedProxy1", params.EmbeddedProxyProviderType, true, "https://provider4.example.com"), + } + + for _, provider := range providers { + err := s.rpcPersistence.AddRpcProvider(provider) + s.Require().NoError(err) + } + + // Verify by type + userProviders, err := s.rpcPersistence.GetRpcProvidersByType(api.MainnetChainID, params.UserProviderType) + s.Require().NoError(err) + testutil.CompareProvidersList(s.T(), []params.RpcProvider{providers[0], providers[2]}, userProviders) + + embeddedDirectProviders, err := s.rpcPersistence.GetRpcProvidersByType(api.MainnetChainID, params.EmbeddedDirectProviderType) + s.Require().NoError(err) + testutil.CompareProvidersList(s.T(), []params.RpcProvider{providers[1]}, embeddedDirectProviders) + + embeddedProxyProviders, err := s.rpcPersistence.GetRpcProvidersByType(api.MainnetChainID, params.EmbeddedProxyProviderType) + s.Require().NoError(err) + testutil.CompareProvidersList(s.T(), []params.RpcProvider{providers[3]}, embeddedProxyProviders) +} + +func (s *RpcProviderPersistenceTestSuite) TestDeleteRpcProviders() { + provider := testutil.CreateProvider(api.MainnetChainID, "Provider1", params.UserProviderType, true, "https://provider1.example.com") + + err := s.rpcPersistence.AddRpcProvider(provider) + s.Require().NoError(err) + + err = s.rpcPersistence.DeleteRpcProviders(api.MainnetChainID) + s.Require().NoError(err) + + // Verify deletion + providers, err := s.rpcPersistence.GetRpcProviders(api.MainnetChainID) + s.Require().NoError(err) + s.Require().Empty(providers) +} + +func (s *RpcProviderPersistenceTestSuite) TestUpdateRpcProvider() { + provider := testutil.CreateProvider(api.MainnetChainID, "Provider1", params.UserProviderType, true, "https://provider1.example.com") + + err := s.rpcPersistence.AddRpcProvider(provider) + s.Require().NoError(err) + + // Retrieve provider to get the ID + providers, err := s.rpcPersistence.GetRpcProviders(api.MainnetChainID) + s.Require().NoError(err) + s.Require().Len(providers, 1) + + provider.ID = providers[0].ID + provider.URL = "https://provider1-updated.example.com" + provider.EnableRPSLimiter = false + + err = s.rpcPersistence.UpdateRpcProvider(provider) + s.Require().NoError(err) + + // Verify update + updatedProviders, err := s.rpcPersistence.GetRpcProviders(api.MainnetChainID) + s.Require().NoError(err) + testutil.CompareProvidersList(s.T(), []params.RpcProvider{provider}, updatedProviders) +} + +func (s *RpcProviderPersistenceTestSuite) TestSetRpcProviders() { + initialProviders := []params.RpcProvider{ + testutil.CreateProvider(api.MainnetChainID, "Provider1", params.UserProviderType, true, "https://provider1.example.com"), + testutil.CreateProvider(api.MainnetChainID, "Provider2", params.EmbeddedDirectProviderType, false, "https://provider2.example.com"), + } + + for _, provider := range initialProviders { + err := s.rpcPersistence.AddRpcProvider(provider) + s.Require().NoError(err) + } + + newProviders := []params.RpcProvider{ + testutil.CreateProvider(api.MainnetChainID, "NewProvider1", params.UserProviderType, true, "https://newprovider1.example.com"), + testutil.CreateProvider(api.MainnetChainID, "NewProvider2", params.EmbeddedProxyProviderType, true, "https://newprovider2.example.com"), + } + + err := s.rpcPersistence.SetRpcProviders(api.MainnetChainID, newProviders) + s.Require().NoError(err) + + // Verify replacement + providers, err := s.rpcPersistence.GetRpcProviders(api.MainnetChainID) + s.Require().NoError(err) + testutil.CompareProvidersList(s.T(), newProviders, providers) +} + +func (s *RpcProviderPersistenceTestSuite) TestAddRpcProviderValidation() { + invalidProvider := params.RpcProvider{ + ChainID: 0, // Invalid: must be greater than 0 + Name: "", // Invalid: cannot be empty + URL: "invalid-url", // Invalid: not a valid URL + Type: "invalid-type", // Invalid: not in allowed values + } + + err := s.rpcPersistence.AddRpcProvider(invalidProvider) + s.Require().Error(err) + s.Contains(err.Error(), "validation failed") +} diff --git a/rpc/network/db/utils.go b/rpc/network/db/utils.go new file mode 100644 index 000000000..38745f331 --- /dev/null +++ b/rpc/network/db/utils.go @@ -0,0 +1,57 @@ +package db + +import ( + "github.com/status-im/status-go/params" +) + +// Deprecated: fillDeprecatedURLs populates the `original_rpc_url`, `original_fallback_url`, `rpc_url`, +// `fallback_url`, `defaultRpcUrl`, `defaultFallbackURL`, and `defaultFallbackURL2` fields. +// Keep for backwrad compatibility until it's fully integrated +func FillDeprecatedURLs(network *params.Network, providers []params.RpcProvider) { + var embeddedDirect []params.RpcProvider + var embeddedProxy []params.RpcProvider + var userProviders []params.RpcProvider + + // Categorize providers + for _, provider := range providers { + switch provider.Type { + case params.EmbeddedDirectProviderType: + embeddedDirect = append(embeddedDirect, provider) + case params.EmbeddedProxyProviderType: + embeddedProxy = append(embeddedProxy, provider) + case params.UserProviderType: + userProviders = append(userProviders, provider) + } + } + + // Set original_*_url fields based on EmbeddedDirectProviderType providers + if len(embeddedDirect) > 0 { + network.OriginalRPCURL = embeddedDirect[0].URL + if len(embeddedDirect) > 1 { + network.OriginalFallbackURL = embeddedDirect[1].URL + } + } + + // Set rpc_url and fallback_url based on User providers or EmbeddedDirectProviderType if no User providers exist + if len(userProviders) > 0 { + network.RPCURL = userProviders[0].URL + if len(userProviders) > 1 { + network.FallbackURL = userProviders[1].URL + } + } else { + // Default to EmbeddedDirectProviderType providers if no User providers exist + network.RPCURL = network.OriginalRPCURL + network.FallbackURL = network.OriginalFallbackURL + } + + // Set default_*_url fields based on EmbeddedProxyProviderType providers + if len(embeddedProxy) > 0 { + network.DefaultRPCURL = embeddedProxy[0].URL + if len(embeddedProxy) > 1 { + network.DefaultFallbackURL = embeddedProxy[1].URL + } + if len(embeddedProxy) > 2 { + network.DefaultFallbackURL2 = embeddedProxy[2].URL + } + } +} diff --git a/rpc/network/testutil/testutil.go b/rpc/network/testutil/testutil.go new file mode 100644 index 000000000..c3730027f --- /dev/null +++ b/rpc/network/testutil/testutil.go @@ -0,0 +1,104 @@ +package testutil + +import ( + "github.com/stretchr/testify/require" + + "github.com/status-im/status-go/api" + "github.com/status-im/status-go/params" +) + +// Helper function to create a provider +func CreateProvider(chainID uint64, name string, providerType params.RpcProviderType, enabled bool, url string) params.RpcProvider { + return params.RpcProvider{ + ChainID: chainID, + Name: name, + URL: url, + EnableRPSLimiter: true, + Type: providerType, + Enabled: enabled, + AuthType: params.BasicAuth, + AuthLogin: "user1", + AuthPassword: "password1", + AuthToken: "", + } +} + +// Helper function to create a network +func CreateNetwork(chainID uint64, chainName string, providers []params.RpcProvider) *params.Network { + return ¶ms.Network{ + ChainID: chainID, + ChainName: chainName, + BlockExplorerURL: "https://explorer.example.com", + IconURL: "network/Network=" + chainName, + NativeCurrencyName: "Ether", + NativeCurrencySymbol: "ETH", + NativeCurrencyDecimals: 18, + IsTest: false, + Layer: 2, + Enabled: true, + ChainColor: "#E90101", + ShortName: "eth", + RelatedChainID: api.OptimismSepoliaChainID, + RpcProviders: providers, + } +} + +// Helper function to compare two providers +func CompareProviders(t require.TestingT, expected, actual params.RpcProvider) { + require.Equal(t, expected.ChainID, actual.ChainID) + require.Equal(t, expected.Name, actual.Name) + require.Equal(t, expected.URL, actual.URL) + require.Equal(t, expected.EnableRPSLimiter, actual.EnableRPSLimiter) + require.Equal(t, expected.Type, actual.Type) + require.Equal(t, expected.Enabled, actual.Enabled) + require.Equal(t, expected.AuthType, actual.AuthType) + require.Equal(t, expected.AuthLogin, actual.AuthLogin) + require.Equal(t, expected.AuthPassword, actual.AuthPassword) + require.Equal(t, expected.AuthToken, actual.AuthToken) +} + +// Helper function to compare two networks +func CompareNetworks(t require.TestingT, expected, actual *params.Network) { + require.Equal(t, expected.ChainID, actual.ChainID, "ChainID does not match") + require.Equal(t, expected.ChainName, actual.ChainName, "ChainName does not match for ChainID %d", actual.ChainID) + require.Equal(t, expected.BlockExplorerURL, actual.BlockExplorerURL) + require.Equal(t, expected.NativeCurrencyName, actual.NativeCurrencyName) + require.Equal(t, expected.NativeCurrencySymbol, actual.NativeCurrencySymbol) + require.Equal(t, expected.NativeCurrencyDecimals, actual.NativeCurrencyDecimals) + require.Equal(t, expected.IsTest, actual.IsTest) + require.Equal(t, expected.Layer, actual.Layer) + require.Equal(t, expected.Enabled, actual.Enabled) + require.Equal(t, expected.ChainColor, actual.ChainColor) + require.Equal(t, expected.ShortName, actual.ShortName) + require.Equal(t, expected.RelatedChainID, actual.RelatedChainID) +} + +// Helper function to compare lists of providers +func CompareProvidersList(t require.TestingT, expectedProviders, actualProviders []params.RpcProvider) { + require.Len(t, actualProviders, len(expectedProviders)) + expectedMap := make(map[string]params.RpcProvider, len(expectedProviders)) + for _, provider := range expectedProviders { + expectedMap[provider.Name] = provider + } + + for _, provider := range actualProviders { + expectedProvider, exists := expectedMap[provider.Name] + require.True(t, exists, "Unexpected provider '%s'", provider.Name) + CompareProviders(t, expectedProvider, provider) + } +} + +// Helper function to compare lists of networks +func CompareNetworksList(t require.TestingT, expectedNetworks, actualNetworks []*params.Network) { + require.Len(t, actualNetworks, len(expectedNetworks)) + expectedMap := make(map[uint64]*params.Network, len(expectedNetworks)) + for _, network := range expectedNetworks { + expectedMap[network.ChainID] = network + } + + for _, network := range actualNetworks { + expectedNetwork, exists := expectedMap[network.ChainID] + require.True(t, exists, "Unexpected network with ChainID %d", network.ChainID) + CompareNetworks(t, expectedNetwork, network) + } +} diff --git a/sqlite/driver.go b/sqlite/driver.go index 8f178cfce..e1cfd1318 100644 --- a/sqlite/driver.go +++ b/sqlite/driver.go @@ -6,3 +6,9 @@ import "database/sql" type StatementCreator interface { Prepare(query string) (*sql.Stmt, error) } + +type StatementExecutor interface { + StatementCreator + Exec(query string, args ...interface{}) (sql.Result, error) + Query(query string, args ...interface{}) (*sql.Rows, error) +}