status-go/services/wallet/collectibles/collection_data_db.go

230 lines
4.8 KiB
Go
Raw Normal View History

package collectibles
import (
	"database/sql"
	"errors"
	"fmt"

	"github.com/status-im/status-go/services/wallet/thirdparty"
	"github.com/status-im/status-go/sqlite"
)
// CollectionDataDB reads and writes cached collection metadata and collection
// trait ranges (tables collection_data_cache and collection_traits_cache).
type CollectionDataDB struct {
	db *sql.DB
}

// NewCollectionDataDB returns a CollectionDataDB that issues its queries
// against the given database handle. The handle is stored as-is; no queries
// are run here.
func NewCollectionDataDB(sqlDb *sql.DB) *CollectionDataDB {
	store := new(CollectionDataDB)
	store.db = sqlDb
	return store
}
// Column lists shared by the INSERT and SELECT statements in this file. The
// order of each list must match the order of the corresponding Exec/Scan
// arguments.
const (
	// collectionDataColumns lists the collection_data_cache columns.
	collectionDataColumns = "chain_id, contract_address, provider, name, slug, image_url"
	// collectionTraitsColumns lists the collection_traits_cache columns written on insert.
	collectionTraitsColumns = "chain_id, contract_address, trait_type, min, max"
	// selectCollectionTraitsColumns lists the collection_traits_cache columns read back per collection.
	selectCollectionTraitsColumns = "trait_type, min, max"
)
// rowsToCollectionTraits reads every row of a query that selected
// selectCollectionTraitsColumns (trait_type, min, max) and returns the traits
// keyed by trait type. If two rows share a trait type, the later row wins.
//
// It does not close rows; the caller should do so.
func rowsToCollectionTraits(rows *sql.Rows) (map[string]thirdparty.CollectionTrait, error) {
	traits := make(map[string]thirdparty.CollectionTrait)
	for rows.Next() {
		var (
			traitType string
			trait     thirdparty.CollectionTrait
		)
		if err := rows.Scan(&traitType, &trait.Min, &trait.Max); err != nil {
			return nil, err
		}
		traits[traitType] = trait
	}
	// rows.Next returns false both at the end of the result set and when
	// iteration fails; only rows.Err tells the two apart. Without this check a
	// failed read would be reported as a successful, truncated result.
	if err := rows.Err(); err != nil {
		return nil, err
	}
	return traits, nil
}
// getCollectionTraits loads all trait rows stored in collection_traits_cache
// for the given collection (chain_id, contract_address), keyed by trait type.
// A collection with no stored traits yields an empty, non-nil map.
func getCollectionTraits(creator sqlite.StatementCreator, id thirdparty.ContractID) (map[string]thirdparty.CollectionTrait, error) {
	// Get traits list
	selectTraits, err := creator.Prepare(fmt.Sprintf(`SELECT %s
FROM collection_traits_cache
WHERE chain_id = ? AND contract_address = ?`, selectCollectionTraitsColumns))
	if err != nil {
		return nil, err
	}
	// GetData passes a *sql.DB here; statements prepared on a DB stay
	// allocated until explicitly closed, so release it when done.
	defer selectTraits.Close()

	rows, err := selectTraits.Query(id.ChainID, id.Address)
	if err != nil {
		return nil, err
	}
	// Rows only auto-close when iteration runs to completion; a Scan error
	// part-way through would otherwise leak the underlying connection.
	defer rows.Close()

	return rowsToCollectionTraits(rows)
}
// upsertCollectionTraits replaces the complete trait list stored for the
// collection id. It first deletes every collection_traits_cache row for
// (chain_id, contract_address), then inserts one row per entry in traits. An
// empty or nil traits map therefore clears the collection's traits.
//
// The delete and the inserts are separate statements. For the replacement to
// be atomic, creator must be a transaction; SetData runs this inside one.
func upsertCollectionTraits(creator sqlite.StatementCreator, id thirdparty.ContractID, traits map[string]thirdparty.CollectionTrait) error {
	// Remove old traits list
	deleteTraits, err := creator.Prepare(`DELETE FROM collection_traits_cache WHERE chain_id = ? AND contract_address = ?`)
	if err != nil {
		return err
	}
	// Release each statement when done instead of relying on the creator
	// (e.g. a transaction ending) to do it; this is called once per collection.
	defer deleteTraits.Close()

	if _, err = deleteTraits.Exec(id.ChainID, id.Address); err != nil {
		return err
	}

	// Insert new traits list
	insertTrait, err := creator.Prepare(fmt.Sprintf(`INSERT OR REPLACE INTO collection_traits_cache (%s)
VALUES (?, ?, ?, ?, ?)`, collectionTraitsColumns))
	if err != nil {
		return err
	}
	defer insertTrait.Close()

	for traitType, trait := range traits {
		if _, err = insertTrait.Exec(id.ChainID, id.Address, traitType, trait.Min, trait.Max); err != nil {
			return err
		}
	}
	return nil
}
// upsertCollectionsData writes each collection's metadata to
// collection_data_cache, using INSERT OR REPLACE so existing rows are
// overwritten. It then replaces that collection's trait list via
// upsertCollectionTraits. It stops at the first error. Earlier writes are not
// undone here; run it inside a transaction (as SetData does) for all-or-nothing
// behaviour.
func upsertCollectionsData(creator sqlite.StatementCreator, collections []thirdparty.CollectionData) error {
	insertCollection, err := creator.Prepare(fmt.Sprintf(`INSERT OR REPLACE INTO collection_data_cache (%s)
VALUES (?, ?, ?, ?, ?, ?)`, collectionDataColumns))
	if err != nil {
		return err
	}
	// Release the prepared statement when done instead of relying on the
	// creator to do it.
	defer insertCollection.Close()

	for _, c := range collections {
		// Argument order must match collectionDataColumns.
		_, err = insertCollection.Exec(
			c.ID.ChainID,
			c.ID.Address,
			c.Provider,
			c.Name,
			c.Slug,
			c.ImageURL,
		)
		if err != nil {
			return err
		}
		if err = upsertCollectionTraits(creator, c.ID, c.Traits); err != nil {
			return err
		}
	}
	return nil
}
// SetData upserts the given collections, together with a full replacement of
// each one's trait list, in a single database transaction. If any write fails
// the transaction is rolled back and that error is returned. Otherwise the
// transaction is committed and the commit's error (if any) is returned.
func (o *CollectionDataDB) SetData(collections []thirdparty.CollectionData) (err error) {
	tx, err := o.db.Begin()
	if err != nil {
		return err
	}

	// The deferred closure sees the final value of the named result err, so it
	// can choose between rollback and commit. A commit failure overwrites the
	// result.
	defer func() {
		if err != nil {
			_ = tx.Rollback() // the original error is more useful than a rollback error
			return
		}
		err = tx.Commit()
	}()

	// Insert new collections data
	return upsertCollectionsData(tx, collections)
}
// scanCollectionsDataRow decodes a single row selected with
// collectionDataColumns into a CollectionData. Traits is initialised to an
// empty map because it lives in a separate table and is not part of the row.
// Scan errors are returned unchanged, including sql.ErrNoRows when the query
// matched nothing.
func scanCollectionsDataRow(row *sql.Row) (*thirdparty.CollectionData, error) {
	var c thirdparty.CollectionData
	c.Traits = make(map[string]thirdparty.CollectionTrait)

	// Destination order must match collectionDataColumns.
	if err := row.Scan(&c.ID.ChainID, &c.ID.Address, &c.Provider, &c.Name, &c.Slug, &c.ImageURL); err != nil {
		return nil, err
	}
	return &c, nil
}
// GetIDsNotInDB returns the subset of ids that have no row in
// collection_data_cache, in the same order as the input. Duplicates in ids are
// kept. The result is non-nil even when every id is already stored.
func (o *CollectionDataDB) GetIDsNotInDB(ids []thirdparty.ContractID) ([]thirdparty.ContractID, error) {
	ret := make([]thirdparty.ContractID, 0, len(ids))

	existsStmt, err := o.db.Prepare(`SELECT EXISTS (
SELECT 1 FROM collection_data_cache
WHERE chain_id=? AND contract_address=?
)`)
	if err != nil {
		return nil, err
	}
	// Statements prepared on a *sql.DB are not freed until Close is called;
	// previously every call to this method leaked one.
	defer existsStmt.Close()

	for _, id := range ids {
		var exists bool
		if err := existsStmt.QueryRow(id.ChainID, id.Address).Scan(&exists); err != nil {
			return nil, err
		}
		if !exists {
			ret = append(ret, id)
		}
	}
	return ret, nil
}
// GetData loads the cached collection metadata and trait lists for ids. It
// returns a map keyed by ContractID.HashKey(). IDs without a row in
// collection_data_cache are simply absent from the result. Any other database
// error aborts the lookup and is returned.
func (o *CollectionDataDB) GetData(ids []thirdparty.ContractID) (map[string]thirdparty.CollectionData, error) {
	ret := make(map[string]thirdparty.CollectionData, len(ids))

	getData, err := o.db.Prepare(fmt.Sprintf(`SELECT %s
FROM collection_data_cache
WHERE chain_id=? AND contract_address=?`, collectionDataColumns))
	if err != nil {
		return nil, err
	}
	// Statements prepared on a *sql.DB are not freed until Close is called.
	defer getData.Close()

	for _, id := range ids {
		c, err := scanCollectionsDataRow(getData.QueryRow(id.ChainID, id.Address))
		if errors.Is(err, sql.ErrNoRows) {
			// Not cached: omit it from the result.
			continue
		}
		if err != nil {
			return nil, err
		}

		// Traits are stored in a separate table.
		c.Traits, err = getCollectionTraits(o.db, c.ID)
		if err != nil {
			return nil, err
		}
		ret[c.ID.HashKey()] = *c
	}
	return ret, nil
}