mirror of
https://github.com/status-im/status-go.git
synced 2025-01-10 14:47:06 +00:00
192 lines
4.3 KiB
Go
192 lines
4.3 KiB
Go
|
package storenodes
|
||
|
|
||
|
import (
|
||
|
"bytes"
|
||
|
"database/sql"
|
||
|
"fmt"
|
||
|
"time"
|
||
|
|
||
|
"github.com/status-im/status-go/eth-node/types"
|
||
|
)
|
||
|
|
||
|
type Database struct {
|
||
|
db *sql.DB
|
||
|
}
|
||
|
|
||
|
func NewDB(db *sql.DB) *Database {
|
||
|
return &Database{db: db}
|
||
|
}
|
||
|
|
||
|
// syncSave will sync the storenodes in the DB from the snode slice
|
||
|
// - if a storenode is not in the provided list, it will be soft-deleted
|
||
|
// - if a storenode is in the provided list, it will be inserted or updated
|
||
|
func (d *Database) syncSave(communityID types.HexBytes, snode []Storenode, clock uint64) (err error) {
|
||
|
var tx *sql.Tx
|
||
|
tx, err = d.db.Begin()
|
||
|
if err != nil {
|
||
|
return err
|
||
|
}
|
||
|
defer func() {
|
||
|
if err == nil {
|
||
|
err = tx.Commit()
|
||
|
return
|
||
|
}
|
||
|
_ = tx.Rollback()
|
||
|
}()
|
||
|
|
||
|
now := time.Now().Unix()
|
||
|
dbNodes, err := d.getByCommunityID(communityID, tx)
|
||
|
if err != nil {
|
||
|
return fmt.Errorf("getting storenodes by community id: %w", err)
|
||
|
}
|
||
|
// Soft-delete db nodes that are not in the provided list
|
||
|
for _, dbN := range dbNodes {
|
||
|
if find(dbN, snode) != nil {
|
||
|
continue
|
||
|
}
|
||
|
if clock != 0 && dbN.Clock >= clock {
|
||
|
continue
|
||
|
}
|
||
|
if err := d.softDelete(communityID, dbN.StorenodeID, now, tx); err != nil {
|
||
|
return fmt.Errorf("soft deleting existing storenodes: %w", err)
|
||
|
}
|
||
|
|
||
|
}
|
||
|
// Insert or update the nodes in the provided list
|
||
|
for _, n := range snode {
|
||
|
// defensively validate the communityID
|
||
|
if len(n.CommunityID) == 0 || !bytes.Equal(communityID, n.CommunityID) {
|
||
|
err = fmt.Errorf("communityID mismatch %v != %v", communityID, n.CommunityID)
|
||
|
return err
|
||
|
}
|
||
|
dbN := find(n, dbNodes)
|
||
|
if dbN != nil && n.Clock != 0 && dbN.Clock >= n.Clock {
|
||
|
continue
|
||
|
}
|
||
|
if err := d.upsert(n, tx); err != nil {
|
||
|
return fmt.Errorf("upserting storenodes: %w", err)
|
||
|
}
|
||
|
}
|
||
|
// TODO for now only allow one storenode per community
|
||
|
count, err := d.countByCommunity(communityID, tx)
|
||
|
if err != nil {
|
||
|
return err
|
||
|
}
|
||
|
if count > 1 {
|
||
|
err = fmt.Errorf("only one storenode per community is allowed")
|
||
|
return err
|
||
|
}
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
func (d *Database) getAll() ([]Storenode, error) {
|
||
|
rows, err := d.db.Query(`
|
||
|
SELECT community_id, storenode_id, name, address, fleet, version, clock, removed, deleted_at
|
||
|
FROM community_storenodes
|
||
|
WHERE removed = 0
|
||
|
`)
|
||
|
if err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
defer rows.Close()
|
||
|
return toStorenodes(rows)
|
||
|
}
|
||
|
|
||
|
func (d *Database) getByCommunityID(communityID types.HexBytes, tx ...*sql.Tx) ([]Storenode, error) {
|
||
|
var rows *sql.Rows
|
||
|
var err error
|
||
|
q := `
|
||
|
SELECT community_id, storenode_id, name, address, fleet, version, clock, removed, deleted_at
|
||
|
FROM community_storenodes
|
||
|
WHERE community_id = ? AND removed = 0
|
||
|
`
|
||
|
if len(tx) > 0 {
|
||
|
rows, err = tx[0].Query(q, communityID)
|
||
|
} else {
|
||
|
rows, err = d.db.Query(q, communityID)
|
||
|
}
|
||
|
if err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
defer rows.Close()
|
||
|
return toStorenodes(rows)
|
||
|
}
|
||
|
|
||
|
func (d *Database) softDelete(communityID types.HexBytes, storenodeID string, deletedAt int64, tx *sql.Tx) error {
|
||
|
_, err := tx.Exec("UPDATE community_storenodes SET removed = 1, deleted_at = ? WHERE community_id = ? AND storenode_id = ?", deletedAt, communityID, storenodeID)
|
||
|
if err != nil {
|
||
|
return err
|
||
|
}
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
func (d *Database) upsert(n Storenode, tx *sql.Tx) error {
|
||
|
_, err := tx.Exec(`INSERT OR REPLACE INTO community_storenodes(
|
||
|
community_id,
|
||
|
storenode_id,
|
||
|
name,
|
||
|
address,
|
||
|
fleet,
|
||
|
version,
|
||
|
clock,
|
||
|
removed,
|
||
|
deleted_at
|
||
|
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)`,
|
||
|
n.CommunityID,
|
||
|
n.StorenodeID,
|
||
|
n.Name,
|
||
|
n.Address,
|
||
|
n.Fleet,
|
||
|
n.Version,
|
||
|
n.Clock,
|
||
|
n.Removed,
|
||
|
n.DeletedAt,
|
||
|
)
|
||
|
if err != nil {
|
||
|
return err
|
||
|
}
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
func (d *Database) countByCommunity(communityID types.HexBytes, tx *sql.Tx) (int, error) {
|
||
|
var count int
|
||
|
err := tx.QueryRow(`SELECT COUNT(*) FROM community_storenodes WHERE community_id = ? AND removed = 0`, communityID).Scan(&count)
|
||
|
if err != nil {
|
||
|
return 0, err
|
||
|
}
|
||
|
return count, nil
|
||
|
}
|
||
|
|
||
|
func toStorenodes(rows *sql.Rows) ([]Storenode, error) {
|
||
|
var result []Storenode
|
||
|
|
||
|
for rows.Next() {
|
||
|
var m Storenode
|
||
|
if err := rows.Scan(
|
||
|
&m.CommunityID,
|
||
|
&m.StorenodeID,
|
||
|
&m.Name,
|
||
|
&m.Address,
|
||
|
&m.Fleet,
|
||
|
&m.Version,
|
||
|
&m.Clock,
|
||
|
&m.Removed,
|
||
|
&m.DeletedAt,
|
||
|
); err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
result = append(result, m)
|
||
|
}
|
||
|
|
||
|
return result, nil
|
||
|
}
|
||
|
|
||
|
func find(n Storenode, nodes []Storenode) *Storenode {
|
||
|
for i, node := range nodes {
|
||
|
if node.StorenodeID == n.StorenodeID && bytes.Equal(node.CommunityID, n.CommunityID) {
|
||
|
return &nodes[i]
|
||
|
}
|
||
|
}
|
||
|
return nil
|
||
|
}
|