status-go/services/wallet/collectibles/collectible_data_db.go

358 lines
8.9 KiB
Go
Raw Normal View History

package collectibles
//go:generate mockgen -package=mock_collectibles -source=collectible_data_db.go -destination=mock/collectible_data_db.go
import (
"database/sql"
"fmt"
"math/big"
"github.com/status-im/status-go/protocol/communities/token"
"github.com/status-im/status-go/services/wallet/bigint"
"github.com/status-im/status-go/services/wallet/thirdparty"
"github.com/status-im/status-go/sqlite"
)
type CollectibleDataStorage interface {
SetData(collectibles []thirdparty.CollectibleData, allowUpdate bool) error
GetIDsNotInDB(ids []thirdparty.CollectibleUniqueID) ([]thirdparty.CollectibleUniqueID, error)
GetData(ids []thirdparty.CollectibleUniqueID) (map[string]thirdparty.CollectibleData, error)
SetCommunityInfo(id thirdparty.CollectibleUniqueID, communityInfo thirdparty.CollectibleCommunityInfo) error
GetCommunityInfo(id thirdparty.CollectibleUniqueID) (*thirdparty.CollectibleCommunityInfo, error)
}
type CollectibleDataDB struct {
db *sql.DB
}
func NewCollectibleDataDB(sqlDb *sql.DB) *CollectibleDataDB {
return &CollectibleDataDB{
db: sqlDb,
}
}
const collectibleDataColumns = "chain_id, contract_address, token_id, provider, name, description, permalink, image_url, image_payload, animation_url, animation_media_type, background_color, token_uri, community_id, soulbound"
const collectibleCommunityDataColumns = "community_privileges_level"
const collectibleTraitsColumns = "chain_id, contract_address, token_id, trait_type, trait_value, display_type, max_value"
const selectCollectibleTraitsColumns = "trait_type, trait_value, display_type, max_value"
func rowsToCollectibleTraits(rows *sql.Rows) ([]thirdparty.CollectibleTrait, error) {
var traits []thirdparty.CollectibleTrait = make([]thirdparty.CollectibleTrait, 0)
for rows.Next() {
var trait thirdparty.CollectibleTrait
err := rows.Scan(
&trait.TraitType,
&trait.Value,
&trait.DisplayType,
&trait.MaxValue,
)
if err != nil {
return nil, err
}
traits = append(traits, trait)
}
return traits, nil
}
func getCollectibleTraits(creator sqlite.StatementCreator, id thirdparty.CollectibleUniqueID) ([]thirdparty.CollectibleTrait, error) {
// Get traits list
selectTraits, err := creator.Prepare(fmt.Sprintf(`SELECT %s
FROM collectible_traits_cache
WHERE chain_id = ? AND contract_address = ? AND token_id = ?`, selectCollectibleTraitsColumns))
if err != nil {
return nil, err
}
rows, err := selectTraits.Query(
id.ContractID.ChainID,
id.ContractID.Address,
(*bigint.SQLBigIntBytes)(id.TokenID.Int),
)
if err != nil {
return nil, err
}
return rowsToCollectibleTraits(rows)
}
func upsertCollectibleTraits(creator sqlite.StatementCreator, id thirdparty.CollectibleUniqueID, traits []thirdparty.CollectibleTrait) error {
// Remove old traits list
deleteTraits, err := creator.Prepare(`DELETE FROM collectible_traits_cache WHERE chain_id = ? AND contract_address = ? AND token_id = ?`)
if err != nil {
return err
}
_, err = deleteTraits.Exec(
id.ContractID.ChainID,
id.ContractID.Address,
(*bigint.SQLBigIntBytes)(id.TokenID.Int),
)
if err != nil {
return err
}
// Insert new traits list
insertTrait, err := creator.Prepare(fmt.Sprintf(`INSERT INTO collectible_traits_cache (%s)
VALUES (?, ?, ?, ?, ?, ?, ?)`, collectibleTraitsColumns))
if err != nil {
return err
}
for _, t := range traits {
_, err = insertTrait.Exec(
id.ContractID.ChainID,
id.ContractID.Address,
(*bigint.SQLBigIntBytes)(id.TokenID.Int),
t.TraitType,
t.Value,
t.DisplayType,
t.MaxValue,
)
if err != nil {
return err
}
}
return nil
}
func setCollectiblesData(creator sqlite.StatementCreator, collectibles []thirdparty.CollectibleData, allowUpdate bool) error {
insertCollectible, err := creator.Prepare(fmt.Sprintf(`%s INTO collectible_data_cache (%s)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`, insertStatement(allowUpdate), collectibleDataColumns))
if err != nil {
return err
}
for _, c := range collectibles {
_, err = insertCollectible.Exec(
c.ID.ContractID.ChainID,
c.ID.ContractID.Address,
(*bigint.SQLBigIntBytes)(c.ID.TokenID.Int),
c.Provider,
c.Name,
c.Description,
c.Permalink,
c.ImageURL,
c.ImagePayload,
c.AnimationURL,
c.AnimationMediaType,
c.BackgroundColor,
c.TokenURI,
c.CommunityID,
c.Soulbound,
)
if err != nil {
return err
}
err = upsertContractType(creator, c.ID.ContractID, c.ContractType)
if err != nil {
return err
}
if allowUpdate {
err = upsertCollectibleTraits(creator, c.ID, c.Traits)
if err != nil {
return err
}
}
}
return nil
}
func (o *CollectibleDataDB) SetData(collectibles []thirdparty.CollectibleData, allowUpdate bool) (err error) {
tx, err := o.db.Begin()
if err != nil {
return err
}
defer func() {
if err == nil {
err = tx.Commit()
return
}
_ = tx.Rollback()
}()
// Insert new collectibles data
err = setCollectiblesData(tx, collectibles, allowUpdate)
if err != nil {
return err
}
return
}
func scanCollectiblesDataRow(row *sql.Row) (*thirdparty.CollectibleData, error) {
c := thirdparty.CollectibleData{
ID: thirdparty.CollectibleUniqueID{
TokenID: &bigint.BigInt{Int: big.NewInt(0)},
},
Traits: make([]thirdparty.CollectibleTrait, 0),
}
err := row.Scan(
&c.ID.ContractID.ChainID,
&c.ID.ContractID.Address,
(*bigint.SQLBigIntBytes)(c.ID.TokenID.Int),
&c.Provider,
&c.Name,
&c.Description,
&c.Permalink,
&c.ImageURL,
&c.ImagePayload,
&c.AnimationURL,
&c.AnimationMediaType,
&c.BackgroundColor,
&c.TokenURI,
&c.CommunityID,
&c.Soulbound,
)
if err != nil {
return nil, err
}
return &c, nil
}
func (o *CollectibleDataDB) GetIDsNotInDB(ids []thirdparty.CollectibleUniqueID) ([]thirdparty.CollectibleUniqueID, error) {
ret := make([]thirdparty.CollectibleUniqueID, 0, len(ids))
idMap := make(map[string]thirdparty.CollectibleUniqueID, len(ids))
// Ensure we don't have duplicates
for _, id := range ids {
idMap[id.HashKey()] = id
}
exists, err := o.db.Prepare(`SELECT EXISTS (
SELECT 1 FROM collectible_data_cache
WHERE chain_id=? AND contract_address=? AND token_id=?
)`)
if err != nil {
return nil, err
}
for _, id := range idMap {
row := exists.QueryRow(
id.ContractID.ChainID,
id.ContractID.Address,
(*bigint.SQLBigIntBytes)(id.TokenID.Int),
)
var exists bool
err = row.Scan(&exists)
if err != nil {
return nil, err
}
if !exists {
ret = append(ret, id)
}
}
return ret, nil
}
func (o *CollectibleDataDB) GetData(ids []thirdparty.CollectibleUniqueID) (map[string]thirdparty.CollectibleData, error) {
ret := make(map[string]thirdparty.CollectibleData)
getData, err := o.db.Prepare(fmt.Sprintf(`SELECT %s
FROM collectible_data_cache
WHERE chain_id=? AND contract_address=? AND token_id=?`, collectibleDataColumns))
if err != nil {
return nil, err
}
for _, id := range ids {
row := getData.QueryRow(
id.ContractID.ChainID,
id.ContractID.Address,
(*bigint.SQLBigIntBytes)(id.TokenID.Int),
)
c, err := scanCollectiblesDataRow(row)
if err == sql.ErrNoRows {
continue
} else if err != nil {
return nil, err
} else {
// Get traits from different table
c.Traits, err = getCollectibleTraits(o.db, c.ID)
if err != nil {
return nil, err
}
// Get contract type from different table
c.ContractType, err = readContractType(o.db, c.ID.ContractID)
if err != nil {
return nil, err
}
ret[c.ID.HashKey()] = *c
}
}
return ret, nil
}
func (o *CollectibleDataDB) SetCommunityInfo(id thirdparty.CollectibleUniqueID, communityInfo thirdparty.CollectibleCommunityInfo) (err error) {
tx, err := o.db.Begin()
if err != nil {
return err
}
defer func() {
if err == nil {
err = tx.Commit()
return
}
_ = tx.Rollback()
}()
update, err := tx.Prepare(`UPDATE collectible_data_cache
SET community_privileges_level=?
WHERE chain_id=? AND contract_address=? AND token_id=?`)
if err != nil {
return err
}
_, err = update.Exec(
communityInfo.PrivilegesLevel,
id.ContractID.ChainID,
id.ContractID.Address,
(*bigint.SQLBigIntBytes)(id.TokenID.Int),
)
return err
}
func (o *CollectibleDataDB) GetCommunityInfo(id thirdparty.CollectibleUniqueID) (*thirdparty.CollectibleCommunityInfo, error) {
ret := thirdparty.CollectibleCommunityInfo{
PrivilegesLevel: token.CommunityLevel,
}
getData, err := o.db.Prepare(fmt.Sprintf(`SELECT %s
FROM collectible_data_cache
WHERE chain_id=? AND contract_address=? AND token_id=?`, collectibleCommunityDataColumns))
if err != nil {
return nil, err
}
row := getData.QueryRow(
id.ContractID.ChainID,
id.ContractID.Address,
(*bigint.SQLBigIntBytes)(id.TokenID.Int),
)
var dbPrivilegesLevel sql.NullByte
err = row.Scan(
&dbPrivilegesLevel,
)
if err == sql.ErrNoRows {
return nil, nil
} else if err != nil {
return nil, err
}
if dbPrivilegesLevel.Valid {
ret.PrivilegesLevel = token.PrivilegesLevel(dbPrivilegesLevel.Byte)
}
return &ret, nil
}