status-go/rpc/network/network.go

362 lines
10 KiB
Go
Raw Normal View History

package network
//go:generate mockgen -package=mock_network -source=network.go -destination=mock/network.go
import (
"bytes"
"database/sql"
"fmt"
2023-10-03 15:27:42 +02:00
"github.com/status-im/status-go/multiaccounts/accounts"
"github.com/status-im/status-go/params"
)
2024-01-29 11:15:28 +01:00
// SepoliaChainIDs lists the chain IDs that this package groups under the
// Sepolia name.
var SepoliaChainIDs = []uint64{
	11155111,
	421614,
	11155420,
}
2023-10-03 15:27:42 +02:00
// CombinedNetwork pairs a production network with its test counterpart.
// Either side may be nil when no matching network has been assigned.
type CombinedNetwork struct {
	Prod, Test *params.Network
}
2023-09-15 10:06:44 +02:00
// baseQuery selects every column of the networks table in the order that
// networksQuery.exec scans them. Filters are appended to it as WHERE clauses.
const baseQuery = "SELECT " +
	"chain_id, chain_name, " +
	"rpc_url, original_rpc_url, fallback_url, original_fallback_url, " +
	"block_explorer_url, icon_url, " +
	"native_currency_name, native_currency_symbol, native_currency_decimals, " +
	"is_test, layer, enabled, chain_color, short_name, related_chain_id " +
	"FROM networks"
// newNetworksQuery returns a query builder whose statement starts out as
// baseQuery with no filters and no bound arguments.
func newNetworksQuery() *networksQuery {
	return &networksQuery{
		buf: bytes.NewBufferString(baseQuery),
	}
}
// networksQuery incrementally builds a SELECT statement over the networks
// table together with its positional arguments.
type networksQuery struct {
	// buf holds the SQL text assembled so far.
	buf *bytes.Buffer
	// args are the values bound to the "?" placeholders, in order.
	args []interface{}
	// added reports whether at least one filter has been appended, i.e.
	// whether the next filter must be joined with AND instead of WHERE.
	added bool
}
// andOrWhere appends the keyword that introduces the next condition:
// " WHERE" for the first filter, " AND" for every later one.
func (nq *networksQuery) andOrWhere() {
	keyword := " WHERE"
	if nq.added {
		keyword = " AND"
	}
	nq.buf.WriteString(keyword)
}
// filterEnabled restricts the query to rows whose enabled column equals the
// given value. It returns the receiver so that calls can be chained.
func (nq *networksQuery) filterEnabled(enabled bool) *networksQuery {
	nq.andOrWhere()
	nq.buf.WriteString(" enabled = ?")
	nq.args = append(nq.args, enabled)
	nq.added = true
	return nq
}
// filterChainID restricts the query to the row with the given chain_id.
// It returns the receiver so that calls can be chained.
func (nq *networksQuery) filterChainID(chainID uint64) *networksQuery {
	nq.andOrWhere()
	nq.buf.WriteString(" chain_id = ?")
	nq.args = append(nq.args, chainID)
	nq.added = true
	return nq
}
2023-09-15 10:06:44 +02:00
func (nq *networksQuery) exec(db *sql.DB) ([]*params.Network, error) {
rows, err := db.Query(nq.buf.String(), nq.args...)
if err != nil {
return nil, err
}
var res []*params.Network
defer rows.Close()
for rows.Next() {
network := params.Network{}
err := rows.Scan(
2023-09-15 10:06:44 +02:00
&network.ChainID, &network.ChainName, &network.RPCURL, &network.OriginalRPCURL, &network.FallbackURL, &network.OriginalFallbackURL,
&network.BlockExplorerURL, &network.IconURL, &network.NativeCurrencyName, &network.NativeCurrencySymbol,
&network.NativeCurrencyDecimals, &network.IsTest, &network.Layer, &network.Enabled, &network.ChainColor, &network.ShortName,
&network.RelatedChainID,
)
if err != nil {
return nil, err
}
2023-08-08 11:05:17 +02:00
res = append(res, &network)
}
return res, err
}
// ManagerInterface is the read-only view of a network manager.
type ManagerInterface interface {
	// Get returns the stored networks, only the enabled ones when
	// onlyEnabled is true.
	Get(onlyEnabled bool) ([]*params.Network, error)

	// GetAll returns every stored network.
	GetAll() ([]*params.Network, error)

	// Find returns the stored network with the given chain ID.
	Find(chainID uint64) *params.Network

	// GetConfiguredNetworks returns the networks the manager was
	// initialised with.
	GetConfiguredNetworks() []params.Network

	// GetTestNetworksEnabled reports whether test networks are enabled.
	GetTestNetworksEnabled() (bool, error)
}
type Manager struct {
2023-08-08 11:05:17 +02:00
db *sql.DB
configuredNetworks []params.Network
2023-10-03 15:27:42 +02:00
accountsDB *accounts.Database
}
func NewManager(db *sql.DB) *Manager {
2023-10-03 15:27:42 +02:00
accountsDB, err := accounts.NewDB(db)
if err != nil {
return nil
}
return &Manager{
2023-10-03 15:27:42 +02:00
db: db,
accountsDB: accountsDB,
}
}
// find returns the index of the first element of networks whose ChainID
// equals chainID, or -1 when there is none.
func find(chainID uint64, networks []params.Network) int {
	for idx := 0; idx < len(networks); idx++ {
		if networks[idx].ChainID != chainID {
			continue
		}
		return idx
	}
	return -1
}
func (nm *Manager) Init(networks []params.Network) error {
if networks == nil {
return nil
}
2023-08-08 11:05:17 +02:00
nm.configuredNetworks = networks
var errors string
currentNetworks, _ := nm.Get(false)
// Delete networks which are not supported any more
for i := range currentNetworks {
if find(currentNetworks[i].ChainID, networks) == -1 {
err := nm.Delete(currentNetworks[i].ChainID)
if err != nil {
errors += fmt.Sprintf("error deleting network with ChainID: %d, %s", currentNetworks[i].ChainID, err.Error())
}
}
}
2023-08-08 11:05:17 +02:00
// Add new networks and update related chain id for the old ones
for i := range networks {
found := false
networks[i].OriginalRPCURL = networks[i].RPCURL
networks[i].OriginalFallbackURL = networks[i].FallbackURL
for j := range currentNetworks {
if currentNetworks[j].ChainID == networks[i].ChainID {
found = true
if currentNetworks[j].RelatedChainID != networks[i].RelatedChainID {
// Update fallback_url if it's different
err := nm.UpdateRelatedChainID(currentNetworks[j].ChainID, networks[i].RelatedChainID)
if err != nil {
errors += fmt.Sprintf("error updating network fallback_url for ChainID: %d, %s", currentNetworks[j].ChainID, err.Error())
}
}
if networks[i].OriginalRPCURL != currentNetworks[j].OriginalRPCURL && currentNetworks[j].RPCURL == currentNetworks[j].OriginalRPCURL {
err := nm.updateRPCURL(networks[i].ChainID, networks[i].OriginalRPCURL)
if err != nil {
errors += fmt.Sprintf("error updating rpc url for ChainID: %d, %s", currentNetworks[j].ChainID, err.Error())
}
}
if networks[i].OriginalFallbackURL != currentNetworks[j].OriginalFallbackURL && currentNetworks[j].FallbackURL == currentNetworks[j].OriginalFallbackURL {
err := nm.updateFallbackURL(networks[i].ChainID, networks[i].OriginalFallbackURL)
if err != nil {
errors += fmt.Sprintf("error updating rpc url for ChainID: %d, %s", currentNetworks[j].ChainID, err.Error())
}
}
err := nm.updateOriginalURLs(networks[i].ChainID, networks[i].OriginalRPCURL, networks[i].OriginalFallbackURL)
2023-09-15 10:06:44 +02:00
if err != nil {
errors += fmt.Sprintf("error updating network original url for ChainID: %d, %s", currentNetworks[j].ChainID, err.Error())
}
break
}
}
if !found {
// Insert new network
err := nm.Upsert(&networks[i])
if err != nil {
errors += fmt.Sprintf("error inserting network with ChainID: %d, %s", networks[i].ChainID, err.Error())
}
}
}
if len(errors) > 0 {
return fmt.Errorf(errors)
}
return nil
}
func (nm *Manager) Upsert(network *params.Network) error {
_, err := nm.db.Exec(
2023-09-15 10:06:44 +02:00
"INSERT OR REPLACE INTO networks (chain_id, chain_name, rpc_url, original_rpc_url, fallback_url, original_fallback_url, block_explorer_url, icon_url, native_currency_name, native_currency_symbol, native_currency_decimals, is_test, layer, enabled, chain_color, short_name, related_chain_id) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
network.ChainID, network.ChainName, network.RPCURL, network.OriginalRPCURL, network.FallbackURL, network.OriginalFallbackURL, network.BlockExplorerURL, network.IconURL,
network.NativeCurrencyName, network.NativeCurrencySymbol, network.NativeCurrencyDecimals,
network.IsTest, network.Layer, network.Enabled, network.ChainColor, network.ShortName,
network.RelatedChainID,
)
return err
}
// Delete removes the row with the given chain ID from the networks table.
func (nm *Manager) Delete(chainID uint64) error {
	if _, err := nm.db.Exec("DELETE FROM networks WHERE chain_id = ?", chainID); err != nil {
		return err
	}
	return nil
}
// UpdateRelatedChainID sets related_chain_id to relatedChainID on the row
// identified by chainID.
func (nm *Manager) UpdateRelatedChainID(chainID uint64, relatedChainID uint64) error {
	if _, err := nm.db.Exec(`UPDATE networks SET related_chain_id = ? WHERE chain_id = ?`, relatedChainID, chainID); err != nil {
		return err
	}
	return nil
}
// updateRPCURL sets rpc_url to rpcURL on the row identified by chainID.
func (nm *Manager) updateRPCURL(chainID uint64, rpcURL string) error {
	if _, err := nm.db.Exec(`UPDATE networks SET rpc_url = ? WHERE chain_id = ?`, rpcURL, chainID); err != nil {
		return err
	}
	return nil
}
// updateFallbackURL sets fallback_url to fallbackURL on the row identified
// by chainID.
func (nm *Manager) updateFallbackURL(chainID uint64, fallbackURL string) error {
	if _, err := nm.db.Exec(`UPDATE networks SET fallback_url = ? WHERE chain_id = ?`, fallbackURL, chainID); err != nil {
		return err
	}
	return nil
}
func (nm *Manager) updateOriginalURLs(chainID uint64, originalRPCURL, OriginalFallbackURL string) error {
_, err := nm.db.Exec(`UPDATE networks SET original_rpc_url = ?, original_fallback_url = ? WHERE chain_id = ?`, originalRPCURL, OriginalFallbackURL, chainID)
2023-09-15 10:06:44 +02:00
return err
}
func (nm *Manager) Find(chainID uint64) *params.Network {
2023-09-15 10:06:44 +02:00
networks, err := newNetworksQuery().filterChainID(chainID).exec(nm.db)
if len(networks) != 1 || err != nil {
return nil
}
setDefaultRPCURL(networks, nm.configuredNetworks)
return networks[0]
}
2023-10-03 15:27:42 +02:00
func (nm *Manager) GetAll() ([]*params.Network, error) {
query := newNetworksQuery()
networks, err := query.exec(nm.db)
setDefaultRPCURL(networks, nm.configuredNetworks)
return networks, err
2023-10-03 15:27:42 +02:00
}
func (nm *Manager) Get(onlyEnabled bool) ([]*params.Network, error) {
query := newNetworksQuery()
if onlyEnabled {
query.filterEnabled(true)
}
2023-10-03 15:27:42 +02:00
networks, err := query.exec(nm.db)
if err != nil {
return nil, err
}
var results []*params.Network
for _, network := range networks {
configuredNetwork, err := findNetwork(nm.configuredNetworks, network.ChainID)
if err != nil {
addDefaultRPCURL(network, configuredNetwork)
}
2023-10-03 15:27:42 +02:00
results = append(results, network)
}
return results, nil
}
func (nm *Manager) GetCombinedNetworks() ([]*CombinedNetwork, error) {
2023-10-03 15:27:42 +02:00
networks, err := nm.Get(false)
if err != nil {
return nil, err
}
var combinedNetworks []*CombinedNetwork
for _, network := range networks {
found := false
for _, n := range combinedNetworks {
2023-10-03 15:27:42 +02:00
if (n.Test != nil && (network.ChainID == n.Test.RelatedChainID || n.Test.ChainID == network.RelatedChainID)) || (n.Prod != nil && (network.ChainID == n.Prod.RelatedChainID || n.Prod.ChainID == network.RelatedChainID)) {
found = true
if network.IsTest {
n.Test = network
break
} else {
n.Prod = network
break
}
}
}
if found {
continue
}
newCombined := &CombinedNetwork{}
if network.IsTest {
newCombined.Test = network
} else {
newCombined.Prod = network
}
combinedNetworks = append(combinedNetworks, newCombined)
}
return combinedNetworks, nil
}
func (nm *Manager) GetConfiguredNetworks() []params.Network {
2023-08-08 11:05:17 +02:00
return nm.configuredNetworks
}
// GetTestNetworksEnabled reports the value returned by the accounts
// database's GetTestNetworksEnabled.
func (nm *Manager) GetTestNetworksEnabled() (result bool, err error) {
	result, err = nm.accountsDB.GetTestNetworksEnabled()
	return result, err
}
// GetActiveNetworks returns the stored networks that belong to the active
// mode: the test networks when test networks are enabled, the non-test ones
// otherwise. The returned slice is never nil on success.
func (nm *Manager) GetActiveNetworks() ([]*params.Network, error) {
	testMode, err := nm.GetTestNetworksEnabled()
	if err != nil {
		return nil, err
	}

	all, err := nm.Get(false)
	if err != nil {
		return nil, err
	}

	active := []*params.Network{}
	for _, candidate := range all {
		if candidate.IsTest == testMode {
			active = append(active, candidate)
		}
	}
	return active, nil
}
// findNetwork returns a copy of the first element of networks whose ChainID
// equals chainID. When there is none it returns a zero params.Network and a
// "network not found" error.
func findNetwork(networks []params.Network, chainID uint64) (params.Network, error) {
	for idx := range networks {
		if networks[idx].ChainID != chainID {
			continue
		}
		return networks[idx], nil
	}
	return params.Network{}, fmt.Errorf("network not found")
}
// addDefaultRPCURL copies the three Default* URL fields (DefaultRPCURL,
// DefaultFallbackURL, DefaultFallbackURL2) from source onto target.
func addDefaultRPCURL(target *params.Network, source params.Network) {
	target.DefaultFallbackURL2 = source.DefaultFallbackURL2
	target.DefaultFallbackURL = source.DefaultFallbackURL
	target.DefaultRPCURL = source.DefaultRPCURL
}
// setDefaultRPCURL copies the Default* URLs onto every element of target
// from the first element of source with the same ChainID. Elements of target
// without a match in source are left untouched.
func setDefaultRPCURL(target []*params.Network, source []params.Network) {
	for _, dst := range target {
		for idx := range source {
			if dst.ChainID != source[idx].ChainID {
				continue
			}
			addDefaultRPCURL(dst, source[idx])
			break
		}
	}
}