status-go/rpc/network/db/rpc_provider_db.go
Andrey Bocharnikov e9abf1662d
chore(config)_: rpc providers configuration (#6151)
* chore(config)_: extract rpc_provider_persistence + tests

* Add rpc_providers table, migration
* add RpcProvider type
* deprecate old rpc fields in networks, add RpcProviders list
* add persistence packages for rpc_providers, networks
* Tests
2025-01-11 02:02:09 +04:00

244 lines
6.0 KiB
Go

package db
import (
"fmt"
sq "github.com/Masterminds/squirrel"
"gopkg.in/go-playground/validator.v9"
"github.com/status-im/status-go/params"
"github.com/status-im/status-go/sqlite"
)
// RpcProvidersPersistenceInterface describes the storage operations available
// for RPC provider records.
type RpcProvidersPersistenceInterface interface {
	// Read operations.

	// GetRpcProviders returns the providers stored for chainID.
	GetRpcProviders(chainID uint64) ([]params.RpcProvider, error)
	// GetRpcProvidersByType returns the providers stored for chainID whose
	// type equals providerType.
	GetRpcProvidersByType(chainID uint64, providerType params.RpcProviderType) ([]params.RpcProvider, error)

	// Write operations.

	// AddRpcProvider stores a new provider.
	AddRpcProvider(provider params.RpcProvider) error
	// UpdateRpcProvider modifies an already stored provider.
	UpdateRpcProvider(provider params.RpcProvider) error
	// SetRpcProviders replaces the providers of chainID with newProviders.
	SetRpcProviders(chainID uint64, newProviders []params.RpcProvider) error
	// DeleteRpcProviders removes the providers of chainID.
	DeleteRpcProviders(chainID uint64) error
	// DeleteAllRpcProviders removes every stored provider.
	DeleteAllRpcProviders() error
}
// RpcProvidersPersistence reads and writes RPC provider records in the
// rpc_providers table through a statement executor.
type RpcProvidersPersistence struct {
	// db runs the generated SQL statements.
	db sqlite.StatementExecutor
	// validator checks provider structs before they are inserted or updated.
	validator *validator.Validate
}
// NewRpcProvidersPersistence builds a RpcProvidersPersistence that executes
// its statements on db and owns a freshly created struct validator.
func NewRpcProvidersPersistence(db sqlite.StatementExecutor) *RpcProvidersPersistence {
	persistence := new(RpcProvidersPersistence)
	persistence.db = db
	persistence.validator = validator.New()
	return persistence
}
// GetRpcProviders returns every provider row stored for chainID, ordered by
// ascending id.
//
// The returned slice is nil when no row matches. Failures while building the
// statement, executing it, scanning a row, or iterating the result set are
// wrapped and returned together with a nil slice.
func (p *RpcProvidersPersistence) GetRpcProviders(chainID uint64) ([]params.RpcProvider, error) {
	// The column order must mirror the Scan destinations further down.
	columns := []string{
		"id",
		"chain_id",
		"name",
		"url",
		"enable_rps_limiter",
		"type",
		"enabled",
		"auth_type",
		"auth_login",
		"auth_password",
		"auth_token",
	}

	stmt, stmtArgs, err := sq.Select(columns...).
		From("rpc_providers").
		Where(sq.Eq{"chain_id": chainID}).
		OrderBy("id ASC").
		ToSql()
	if err != nil {
		return nil, fmt.Errorf("failed to build query: %w", err)
	}

	rows, err := p.db.Query(stmt, stmtArgs...)
	if err != nil {
		return nil, fmt.Errorf("failed to execute query: %w", err)
	}
	defer rows.Close()

	var found []params.RpcProvider
	for rows.Next() {
		var row params.RpcProvider
		if scanErr := rows.Scan(
			&row.ID,
			&row.ChainID,
			&row.Name,
			&row.URL,
			&row.EnableRPSLimiter,
			&row.Type,
			&row.Enabled,
			&row.AuthType,
			&row.AuthLogin,
			&row.AuthPassword,
			&row.AuthToken,
		); scanErr != nil {
			return nil, fmt.Errorf("failed to scan row: %w", scanErr)
		}
		found = append(found, row)
	}

	// Next returning false can also mean the iteration broke off early.
	if iterErr := rows.Err(); iterErr != nil {
		return nil, fmt.Errorf("row iteration error: %w", iterErr)
	}
	return found, nil
}
// GetRpcProvidersByType returns the providers of chainID whose Type equals
// providerType.
//
// All providers of the chain are loaded through GetRpcProviders and filtered
// in memory. On success the result is always a non-nil slice, possibly empty;
// an error from GetRpcProviders is passed through unchanged.
func (p *RpcProvidersPersistence) GetRpcProvidersByType(chainID uint64, providerType params.RpcProviderType) ([]params.RpcProvider, error) {
	candidates, err := p.GetRpcProviders(chainID)
	if err != nil {
		return nil, err
	}

	matching := make([]params.RpcProvider, 0, len(candidates))
	for i := range candidates {
		if candidates[i].Type != providerType {
			continue
		}
		matching = append(matching, candidates[i])
	}
	return matching, nil
}
// AddRpcProvider validates provider and inserts it as a new row of the
// rpc_providers table.
//
// The id column is not part of the insert, so provider.ID is not written.
// Validation, statement-building and execution failures are returned wrapped.
func (p *RpcProvidersPersistence) AddRpcProvider(provider params.RpcProvider) error {
	// Reject invalid input before any SQL is generated.
	if err := p.validator.Struct(provider); err != nil {
		return fmt.Errorf("validation failed: %w", err)
	}

	stmt, stmtArgs, err := sq.Insert("rpc_providers").
		Columns(
			"chain_id",
			"name",
			"url",
			"enable_rps_limiter",
			"type",
			"enabled",
			"auth_type",
			"auth_login",
			"auth_password",
			"auth_token",
		).
		Values(
			provider.ChainID,
			provider.Name,
			provider.URL,
			provider.EnableRPSLimiter,
			provider.Type,
			provider.Enabled,
			provider.AuthType,
			provider.AuthLogin,
			provider.AuthPassword,
			provider.AuthToken,
		).
		ToSql()
	if err != nil {
		return fmt.Errorf("failed to build insert query: %w", err)
	}

	if _, err := p.db.Exec(stmt, stmtArgs...); err != nil {
		return fmt.Errorf("failed to execute insert query: %w", err)
	}
	return nil
}
// DeleteRpcProviders removes every rpc_providers row whose chain_id equals
// chainID. Statement-building and execution failures are returned wrapped.
func (p *RpcProvidersPersistence) DeleteRpcProviders(chainID uint64) error {
	stmt, stmtArgs, err := sq.Delete("rpc_providers").
		Where(sq.Eq{"chain_id": chainID}).
		ToSql()
	if err != nil {
		return fmt.Errorf("failed to build delete query: %w", err)
	}

	if _, err := p.db.Exec(stmt, stmtArgs...); err != nil {
		return fmt.Errorf("failed to execute delete query: %w", err)
	}
	return nil
}
// DeleteAllRpcProviders removes every row of the rpc_providers table,
// regardless of chain. Statement-building and execution failures are returned
// wrapped.
func (p *RpcProvidersPersistence) DeleteAllRpcProviders() error {
	// No WHERE clause: the statement targets the whole table.
	stmt, stmtArgs, err := sq.Delete("rpc_providers").ToSql()
	if err != nil {
		return fmt.Errorf("failed to build delete query: %w", err)
	}

	if _, err := p.db.Exec(stmt, stmtArgs...); err != nil {
		return fmt.Errorf("failed to execute delete query: %w", err)
	}
	return nil
}
// UpdateRpcProvider validates provider and overwrites the stored row whose id
// equals provider.ID.
//
// Only url, enable_rps_limiter, type, enabled and the auth_* columns are
// written. Validation, statement-building and execution failures are returned
// wrapped.
//
// NOTE(review): name and chain_id are not part of the update even though the
// whole struct is validated — confirm this is intentional.
// NOTE(review): the result of Exec is discarded, so an id that matches no row
// is not reported as an error — confirm callers do not rely on that.
func (p *RpcProvidersPersistence) UpdateRpcProvider(provider params.RpcProvider) error {
	// Reject invalid input before any SQL is generated.
	if err := p.validator.Struct(provider); err != nil {
		return fmt.Errorf("validation failed: %w", err)
	}

	assignments := map[string]interface{}{
		"url":                provider.URL,
		"enable_rps_limiter": provider.EnableRPSLimiter,
		"type":               provider.Type,
		"enabled":            provider.Enabled,
		"auth_type":          provider.AuthType,
		"auth_login":         provider.AuthLogin,
		"auth_password":      provider.AuthPassword,
		"auth_token":         provider.AuthToken,
	}

	stmt, stmtArgs, err := sq.Update("rpc_providers").
		SetMap(assignments).
		Where(sq.Eq{"id": provider.ID}).
		ToSql()
	if err != nil {
		return fmt.Errorf("failed to build update query: %w", err)
	}

	if _, err := p.db.Exec(stmt, stmtArgs...); err != nil {
		return fmt.Errorf("failed to execute update query: %w", err)
	}
	return nil
}
// SetRpcProviders replaces the providers stored for chainID with newProviders.
//
// Every entry of newProviders is validated before anything is deleted, so an
// invalid entry leaves the currently stored providers untouched. After that
// the existing rows of chainID are removed and the new providers are inserted
// one by one in slice order.
//
// Each provider is inserted with its own ChainID field; chainID only selects
// which rows are deleted.
//
// NOTE(review): the delete and the inserts are separate statements. Unless the
// underlying executor is already a transaction, a database failure part-way
// through can leave the chain with a partial provider list — confirm how
// callers construct this persistence.
func (p *RpcProvidersPersistence) SetRpcProviders(chainID uint64, newProviders []params.RpcProvider) error {
	// Validate up front: discovering a bad provider only after the delete
	// would have wiped the existing configuration for this chain.
	for _, provider := range newProviders {
		if err := p.validator.Struct(provider); err != nil {
			// Same message and wrapped error that the insert path reports.
			return fmt.Errorf("failed to add new provider: validation failed: %w", err)
		}
	}

	if err := p.DeleteRpcProviders(chainID); err != nil {
		return fmt.Errorf("failed to delete existing providers: %w", err)
	}

	for _, provider := range newProviders {
		if err := p.AddRpcProvider(provider); err != nil {
			return fmt.Errorf("failed to add new provider: %w", err)
		}
	}
	return nil
}