status-go/services/wallet/history/balance_db.go

153 lines
4.5 KiB
Go

package history
import (
"database/sql"
"encoding/hex"
"fmt"
"math/big"
"github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/log"
"github.com/status-im/status-go/services/wallet/bigint"
)
// BalanceDB groups the queries used by the wallet balance history: it writes
// and reads the balance_history table and also reads the transfers table.
type BalanceDB struct {
	// db is the handle handed to NewBalanceDB; it is used as-is.
	db *sql.DB
}

// NewBalanceDB returns a BalanceDB backed by sqlDb. No queries or schema
// checks are run here; the handle is simply stored.
func NewBalanceDB(sqlDb *sql.DB) *BalanceDB {
	balanceDB := new(BalanceDB)
	balanceDB.db = sqlDb
	return balanceDB
}
// entry represents a single row in the balance_history table.
//
// add persists every field except tokenAddress. tokenAddress is only populated
// (within this file) by getEntriesWithoutBalances, from transfers.token_address.
type entry struct {
	// chainID matches balance_history.chain_id (and transfers.network_id).
	chainID uint64
	// address is the account the balance belongs to.
	address common.Address
	// tokenSymbol is stored in the balance_history.currency column.
	tokenSymbol string
	// tokenAddress is not written by add; it is left as the zero address
	// when transfers.token_address is NULL.
	tokenAddress common.Address
	// block is the block number, (de)serialized via bigint.SQLBigInt.
	block *big.Int
	// timestamp of the row; NOTE(review): units are not visible here,
	// presumably Unix seconds — confirm against the writer of transfers.
	timestamp int64
	// balance is (de)serialized via bigint.SQLBigIntBytes.
	balance *big.Int
}
// assetIdentity selects balance_history rows for one token on one chain
// across a set of accounts; getNewerThan uses it to build its query.
type assetIdentity struct {
	// ChainID is matched against balance_history.chain_id.
	ChainID uint64
	// Addresses is rendered into an SQL "IN (...)" list by addressesToString.
	Addresses []common.Address
	// TokenSymbol is matched against balance_history.currency.
	TokenSymbol string
}
// addressesToString renders a.Addresses as a comma-separated list of SQLite
// blob literals, e.g. "X'ab..', X'cd..'", for splicing into an "IN (...)"
// clause. hex.EncodeToString emits only [0-9a-f], so the literals cannot break
// out of their quotes. An empty address list yields "".
func (a *assetIdentity) addressesToString() string {
	var out []byte
	for idx, addr := range a.Addresses {
		if idx > 0 {
			out = append(out, ", "...)
		}
		out = append(out, "X'"...)
		out = append(out, hex.EncodeToString(addr[:])...)
		out = append(out, '\'')
	}
	return string(out)
}
// String formats every field of e on one line; add uses it when logging the
// entry it is about to insert.
func (e *entry) String() string {
	const layout = "chainID: %v, address: %v, tokenSymbol: %v, tokenAddress: %v, block: %v, timestamp: %v, balance: %v"
	return fmt.Sprintf(layout,
		e.chainID,
		e.address,
		e.tokenSymbol,
		e.tokenAddress,
		e.block,
		e.timestamp,
		e.balance,
	)
}
// add inserts entry into balance_history.
//
// It writes chain_id, address, currency (from tokenSymbol), block, timestamp
// and balance; tokenAddress is not persisted. Because the statement is
// INSERT OR IGNORE, a row that would violate a table constraint (e.g. a
// duplicate key) is skipped without returning an error.
func (b *BalanceDB) add(entry *entry) error {
	const insertQuery = "INSERT OR IGNORE INTO balance_history (chain_id, address, currency, block, timestamp, balance) VALUES (?, ?, ?, ?, ?, ?)"

	log.Debug("Adding entry to balance_history", "entry", entry)

	args := []interface{}{
		entry.chainID,
		entry.address,
		entry.tokenSymbol,
		(*bigint.SQLBigInt)(entry.block),
		entry.timestamp,
		(*bigint.SQLBigIntBytes)(entry.balance),
	}
	if _, err := b.db.Exec(insertQuery, args...); err != nil {
		return err
	}
	return nil
}
// getEntriesWithoutBalances returns one entry per non-ERC721 transfer of
// address on chainID whose block has no balance_history row for the same
// chain and address.
//
// Each entry carries chainID, address, the transfer's block number and
// timestamp, and its token address (the zero address when token_address is
// NULL). tokenSymbol and balance are left unset. An empty, non-nil slice is
// returned when nothing matches.
func (b *BalanceDB) getEntriesWithoutBalances(chainID uint64, address common.Address) ([]*entry, error) {
	// The join is scoped to the transfer's chain and account. Joining on the
	// block number alone would let a balance_history row belonging to another
	// chain or account that happens to share the block number hide this
	// transfer. NOTE(review): the currency is still not part of the match, as
	// transfers only carry a token address, not the symbol.
	const query = "SELECT blk_number, tr.timestamp, token_address FROM transfers tr " +
		"LEFT JOIN balance_history bh ON bh.block = tr.blk_number AND bh.chain_id = tr.network_id AND bh.address = tr.address " +
		"WHERE tr.network_id = ? AND tr.address = ? AND tr.type != 'erc721' AND bh.block IS NULL"

	// Query never reports sql.ErrNoRows; an empty result simply yields no rows.
	rows, err := b.db.Query(query, chainID, address)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	entries := make([]*entry, 0)
	for rows.Next() {
		e := &entry{
			chainID: chainID,
			address: address,
			block:   new(big.Int),
		}
		// token_address can be NULL, which cannot be scanned into a
		// common.Address directly. Scan raw bytes instead: NULL becomes a nil
		// slice, which BytesToAddress turns into the zero address.
		var tokenAddr []byte
		if err := rows.Scan((*bigint.SQLBigInt)(e.block), &e.timestamp, &tokenAddr); err != nil {
			return nil, err
		}
		e.tokenAddress = common.BytesToAddress(tokenAddr)
		entries = append(entries, e)
	}
	// rows.Next also returns false when iteration fails; report that instead
	// of returning a silently truncated list.
	if err := rows.Err(); err != nil {
		return nil, err
	}
	return entries, nil
}
// getNewerThan returns the balance_history rows for identity whose timestamp
// is strictly greater than timestamp, ordered by ascending timestamp.
//
// Rows match when chain_id equals identity.ChainID, currency equals
// identity.TokenSymbol, and address is any of identity.Addresses. Each entry
// gets chainID and tokenSymbol from identity; tokenAddress is left unset.
// An empty, non-nil slice is returned when nothing matches.
func (b *BalanceDB) getNewerThan(identity *assetIdentity, timestamp uint64) ([]*entry, error) {
	// DISTINCT removes duplicates that can happen when a block has multiple
	// transfers of the same token. The address list is spliced in as hex blob
	// literals by addressesToString; hex output contains only [0-9a-f], so it
	// cannot escape the literal.
	rawQueryStr := "SELECT DISTINCT block, timestamp, balance, address FROM balance_history WHERE chain_id = ? AND address IN (%s) AND currency = ? AND timestamp > ? ORDER BY timestamp"
	queryString := fmt.Sprintf(rawQueryStr, identity.addressesToString())

	// Query never reports sql.ErrNoRows; an empty result simply yields no rows.
	rows, err := b.db.Query(queryString, identity.ChainID, identity.TokenSymbol, timestamp)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	result := make([]*entry, 0)
	for rows.Next() {
		e := &entry{
			chainID:     identity.ChainID,
			tokenSymbol: identity.TokenSymbol,
			block:       new(big.Int),
			balance:     new(big.Int),
		}
		if err := rows.Scan((*bigint.SQLBigInt)(e.block), &e.timestamp, (*bigint.SQLBigIntBytes)(e.balance), &e.address); err != nil {
			return nil, err
		}
		result = append(result, e)
	}
	// rows.Next also returns false when iteration fails; report that instead
	// of returning a silently truncated list.
	if err := rows.Err(); err != nil {
		return nil, err
	}
	return result, nil
}
// getEntryPreviousTo returns the latest balance_history row that has the same
// chain, address and currency (tokenSymbol) as item and a timestamp strictly
// before item.timestamp.
//
// The result copies chainID, address and tokenSymbol from item and fills
// block, timestamp and balance from the row; tokenAddress is left unset.
// It returns (nil, nil) when no such row exists.
func (b *BalanceDB) getEntryPreviousTo(item *entry) (res *entry, err error) {
	const query = "SELECT block, timestamp, balance FROM balance_history WHERE chain_id = ? AND address = ? AND currency = ? AND timestamp < ? ORDER BY timestamp DESC LIMIT 1"

	prev := &entry{
		chainID:     item.chainID,
		address:     item.address,
		tokenSymbol: item.tokenSymbol,
		block:       new(big.Int),
		balance:     new(big.Int),
	}
	scanErr := b.db.QueryRow(query, item.chainID, item.address, item.tokenSymbol, item.timestamp).
		Scan((*bigint.SQLBigInt)(prev.block), &prev.timestamp, (*bigint.SQLBigIntBytes)(prev.balance))

	switch {
	case scanErr == sql.ErrNoRows:
		// No earlier row: "not found" is reported as a nil entry, not an error.
		return nil, nil
	case scanErr != nil:
		return nil, scanErr
	}
	return prev, nil
}