package collectibles import ( "database/sql" "fmt" "github.com/status-im/status-go/services/wallet/thirdparty" "github.com/status-im/status-go/sqlite" ) type CollectionDataDB struct { db *sql.DB } func NewCollectionDataDB(sqlDb *sql.DB) *CollectionDataDB { return &CollectionDataDB{ db: sqlDb, } } const collectionDataColumns = "chain_id, contract_address, provider, name, slug, image_url, community_id" const collectionTraitsColumns = "chain_id, contract_address, trait_type, min, max" const selectCollectionTraitsColumns = "trait_type, min, max" func rowsToCollectionTraits(rows *sql.Rows) (map[string]thirdparty.CollectionTrait, error) { traits := make(map[string]thirdparty.CollectionTrait) for rows.Next() { var traitType string var trait thirdparty.CollectionTrait err := rows.Scan( &traitType, &trait.Min, &trait.Max, ) if err != nil { return nil, err } traits[traitType] = trait } return traits, nil } func getCollectionTraits(creator sqlite.StatementCreator, id thirdparty.ContractID) (map[string]thirdparty.CollectionTrait, error) { // Get traits list selectTraits, err := creator.Prepare(fmt.Sprintf(`SELECT %s FROM collection_traits_cache WHERE chain_id = ? AND contract_address = ?`, selectCollectionTraitsColumns)) if err != nil { return nil, err } rows, err := selectTraits.Query( id.ChainID, id.Address, ) if err != nil { return nil, err } return rowsToCollectionTraits(rows) } func upsertCollectionTraits(creator sqlite.StatementCreator, id thirdparty.ContractID, traits map[string]thirdparty.CollectionTrait) error { // Rremove old traits list deleteTraits, err := creator.Prepare(`DELETE FROM collection_traits_cache WHERE chain_id = ? AND contract_address = ?`) if err != nil { return err } _, err = deleteTraits.Exec( id.ChainID, id.Address, ) if err != nil { return err } // Insert new traits list insertTrait, err := creator.Prepare(fmt.Sprintf(`INSERT OR REPLACE INTO collection_traits_cache (%s) VALUES (?, ?, ?, ?, ?)`, collectionTraitsColumns)) if err != nil { return err } for traitType, trait := range traits { _, err = insertTrait.Exec( id.ChainID, id.Address, traitType, trait.Min, trait.Max, ) if err != nil { return err } } return nil } func upsertCollectionsData(creator sqlite.StatementCreator, collections []thirdparty.CollectionData) error { insertCollection, err := creator.Prepare(fmt.Sprintf(`INSERT OR REPLACE INTO collection_data_cache (%s) VALUES (?, ?, ?, ?, ?, ?, ?)`, collectionDataColumns)) if err != nil { return err } for _, c := range collections { _, err = insertCollection.Exec( c.ID.ChainID, c.ID.Address, c.Provider, c.Name, c.Slug, c.ImageURL, c.CommunityID, ) if err != nil { return err } err = upsertCollectionTraits(creator, c.ID, c.Traits) if err != nil { return err } } return nil } func (o *CollectionDataDB) SetData(collections []thirdparty.CollectionData) (err error) { tx, err := o.db.Begin() if err != nil { return err } defer func() { if err == nil { err = tx.Commit() return } _ = tx.Rollback() }() // Insert new collections data err = upsertCollectionsData(tx, collections) if err != nil { return err } return } func scanCollectionsDataRow(row *sql.Row) (*thirdparty.CollectionData, error) { c := thirdparty.CollectionData{ Traits: make(map[string]thirdparty.CollectionTrait), } err := row.Scan( &c.ID.ChainID, &c.ID.Address, &c.Provider, &c.Name, &c.Slug, &c.ImageURL, &c.CommunityID, ) if err != nil { return nil, err } return &c, nil } func (o *CollectionDataDB) GetIDsNotInDB(ids []thirdparty.ContractID) ([]thirdparty.ContractID, error) { ret := make([]thirdparty.ContractID, 0, len(ids)) exists, err := o.db.Prepare(`SELECT EXISTS ( SELECT 1 FROM collection_data_cache WHERE chain_id=? AND contract_address=? )`) if err != nil { return nil, err } for _, id := range ids { row := exists.QueryRow( id.ChainID, id.Address, ) var exists bool err = row.Scan(&exists) if err != nil { return nil, err } if !exists { ret = append(ret, id) } } return ret, nil } func (o *CollectionDataDB) GetData(ids []thirdparty.ContractID) (map[string]thirdparty.CollectionData, error) { ret := make(map[string]thirdparty.CollectionData) getData, err := o.db.Prepare(fmt.Sprintf(`SELECT %s FROM collection_data_cache WHERE chain_id=? AND contract_address=?`, collectionDataColumns)) if err != nil { return nil, err } for _, id := range ids { row := getData.QueryRow( id.ChainID, id.Address, ) c, err := scanCollectionsDataRow(row) if err == sql.ErrNoRows { continue } else if err != nil { return nil, err } else { // Get traits from different table c.Traits, err = getCollectionTraits(o.db, c.ID) if err != nil { return nil, err } ret[c.ID.HashKey()] = *c } } return ret, nil }