package wallet import ( "database/sql" "database/sql/driver" "encoding/json" "errors" "math/big" "reflect" "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/core/types" ) // DBHeader fields from header that are stored in database. type DBHeader struct { Number *big.Int Hash common.Hash Timestamp uint64 // Head is true if the block was a head at the time it was pulled from chain. Head bool } func toDBHeader(header *types.Header) *DBHeader { return &DBHeader{ Hash: header.Hash(), Number: header.Number, Timestamp: header.Time, } } func toHead(header *types.Header) *DBHeader { dbheader := toDBHeader(header) dbheader.Head = true return dbheader } // SyncOption is used to specify that application processed transfers for that block. type SyncOption uint const ( // sync options ethSync SyncOption = 1 erc20Sync SyncOption = 2 ) // SQLBigInt type for storing uint256 in the databse. // FIXME(dshulyak) SQL big int is max 64 bits. Maybe store as bytes in big endian and hope // that lexographical sorting will work. type SQLBigInt big.Int // Scan implements interface. func (i *SQLBigInt) Scan(value interface{}) error { val, ok := value.(int64) if !ok { return errors.New("not an integer") } (*big.Int)(i).SetInt64(val) return nil } // Value implements interface. func (i *SQLBigInt) Value() (driver.Value, error) { if !(*big.Int)(i).IsInt64() { return nil, errors.New("not an int64") } return (*big.Int)(i).Int64(), nil } // JSONBlob type for marshaling/unmarshaling inner type to json. type JSONBlob struct { data interface{} } // Scan implements interface. func (blob *JSONBlob) Scan(value interface{}) error { if value == nil || reflect.ValueOf(blob.data).IsNil() { return nil } bytes, ok := value.([]byte) if !ok { return errors.New("not a byte slice") } if len(bytes) == 0 { return nil } err := json.Unmarshal(bytes, blob.data) return err } // Value implements interface. func (blob *JSONBlob) Value() (driver.Value, error) { if blob.data == nil || reflect.ValueOf(blob.data).IsNil() { return nil, nil } return json.Marshal(blob.data) } func NewDB(db *sql.DB, network uint64) *Database { return &Database{db: db, network: network} } // Database sql wrapper for operations with wallet objects. type Database struct { db *sql.DB network uint64 } // Close closes database. func (db Database) Close() error { return db.db.Close() } // ProcessTranfers atomically adds/removes blocks and adds new tranfers. func (db Database) ProcessTranfers(transfers []Transfer, accounts []common.Address, added, removed []*DBHeader, option SyncOption) (err error) { var ( tx *sql.Tx ) tx, err = db.db.Begin() if err != nil { return err } defer func() { if err == nil { err = tx.Commit() return } _ = tx.Rollback() }() err = deleteHeaders(tx, removed) if err != nil { return } err = insertHeaders(tx, db.network, added) if err != nil { return } err = insertTransfers(tx, db.network, transfers) if err != nil { return } err = updateAccounts(tx, db.network, accounts, added, option) return } // GetTransfersByAddress loads transfers for a given address between two blocks. func (db *Database) GetTransfersByAddress(address common.Address, start, end *big.Int) (rst []Transfer, err error) { query := newTransfersQuery().FilterNetwork(db.network).FilterAddress(address).FilterStart(start).FilterEnd(end) rows, err := db.db.Query(query.String(), query.Args()...) if err != nil { return } defer rows.Close() return query.Scan(rows) } // GetTransfers load transfers transfer betweeen two blocks. func (db *Database) GetTransfers(start, end *big.Int) (rst []Transfer, err error) { query := newTransfersQuery().FilterNetwork(db.network).FilterStart(start).FilterEnd(end) rows, err := db.db.Query(query.String(), query.Args()...) if err != nil { return } defer rows.Close() return query.Scan(rows) } // SaveHeaders stores a list of headers atomically. func (db *Database) SaveHeaders(headers []*types.Header) (err error) { var ( tx *sql.Tx insert *sql.Stmt ) tx, err = db.db.Begin() if err != nil { return } insert, err = tx.Prepare("INSERT INTO blocks(network_id, number, hash, timestamp) VALUES (?, ?, ?, ?)") if err != nil { return } defer func() { if err == nil { err = tx.Commit() } else { _ = tx.Rollback() } }() for _, h := range headers { _, err = insert.Exec(db.network, (*SQLBigInt)(h.Number), h.Hash(), h.Time) if err != nil { return } } return } func (db *Database) SaveSyncedHeader(address common.Address, header *types.Header, option SyncOption) (err error) { var ( tx *sql.Tx insert *sql.Stmt ) tx, err = db.db.Begin() if err != nil { return } insert, err = tx.Prepare("INSERT INTO accounts_to_blocks(network_id, address, blk_number, sync) VALUES (?, ?,?,?)") if err != nil { return } defer func() { if err == nil { err = tx.Commit() } else { _ = tx.Rollback() } }() _, err = insert.Exec(db.network, address, (*SQLBigInt)(header.Number), option) if err != nil { return } return err } // HeaderExists checks if header with hash exists in db. func (db *Database) HeaderExists(hash common.Hash) (bool, error) { var val sql.NullBool err := db.db.QueryRow("SELECT EXISTS (SELECT hash FROM blocks WHERE hash = ? AND network_id = ?)", hash, db.network).Scan(&val) if err != nil { return false, err } return val.Bool, nil } // GetHeaderByNumber selects header using block number. func (db *Database) GetHeaderByNumber(number *big.Int) (header *DBHeader, err error) { header = &DBHeader{Hash: common.Hash{}, Number: new(big.Int)} err = db.db.QueryRow("SELECT hash,number FROM blocks WHERE number = ? AND network_id = ?", (*SQLBigInt)(number), db.network).Scan(&header.Hash, (*SQLBigInt)(header.Number)) if err == nil { return header, nil } if err == sql.ErrNoRows { return nil, nil } return nil, err } func (db *Database) GetLastHead() (header *DBHeader, err error) { header = &DBHeader{Hash: common.Hash{}, Number: new(big.Int)} err = db.db.QueryRow("SELECT hash,number FROM blocks WHERE network_id = $1 AND head = 1 AND number = (SELECT MAX(number) FROM blocks WHERE network_id = $1)", db.network).Scan(&header.Hash, (*SQLBigInt)(header.Number)) if err == nil { return header, nil } if err == sql.ErrNoRows { return nil, nil } return nil, err } // GetLatestSynced downloads last synced block with a given option. func (db *Database) GetLatestSynced(address common.Address, option SyncOption) (header *DBHeader, err error) { header = &DBHeader{Hash: common.Hash{}, Number: new(big.Int)} err = db.db.QueryRow(` SELECT blocks.hash, blk_number FROM accounts_to_blocks JOIN blocks ON blk_number = blocks.number WHERE blocks.network_id = $1 AND address = $2 AND blk_number = (SELECT MAX(blk_number) FROM accounts_to_blocks WHERE network_id = $1 AND address = $2 AND sync & $3 = $3)`, db.network, address, option).Scan(&header.Hash, (*SQLBigInt)(header.Number)) if err == nil { return header, nil } if err == sql.ErrNoRows { return nil, nil } return nil, err } type Token struct { Address common.Address `json:"address"` Name string `json:"name"` Symbol string `json:"symbol"` Color string `json:"color"` // Decimals defines how divisible the token is. For example, 0 would be // indivisible, whereas 18 would allow very small amounts of the token // to be traded. Decimals uint `json:"decimals"` } func (db *Database) GetCustomTokens() ([]*Token, error) { rows, err := db.db.Query(`SELECT address, name, symbol, decimals, color FROM tokens WHERE network_id = ?`, db.network) if err != nil { return nil, err } defer rows.Close() var rst []*Token for rows.Next() { token := &Token{} err := rows.Scan(&token.Address, &token.Name, &token.Symbol, &token.Decimals, &token.Color) if err != nil { return nil, err } rst = append(rst, token) } return rst, nil } func (db *Database) AddCustomToken(token Token) error { insert, err := db.db.Prepare("INSERT OR REPLACE INTO TOKENS (network_id, address, name, symbol, decimals, color) VALUES (?, ?, ?, ?, ?, ?)") if err != nil { return err } _, err = insert.Exec(db.network, token.Address, token.Name, token.Symbol, token.Decimals, token.Color) return err } func (db *Database) DeleteCustomToken(address common.Address) error { _, err := db.db.Exec(`DELETE FROM TOKENS WHERE address = ?`, address) return err } // statementCreator allows to pass transaction or database to use in consumer. type statementCreator interface { Prepare(query string) (*sql.Stmt, error) } func deleteHeaders(creator statementCreator, headers []*DBHeader) error { delete, err := creator.Prepare("DELETE FROM blocks WHERE hash = ?") if err != nil { return err } for _, h := range headers { _, err = delete.Exec(h.Hash) if err != nil { return err } } return nil } func insertHeaders(creator statementCreator, network uint64, headers []*DBHeader) error { insert, err := creator.Prepare("INSERT OR IGNORE INTO blocks(network_id, hash, number, timestamp, head) VALUES (?, ?, ?, ?, ?)") if err != nil { return err } for _, h := range headers { _, err = insert.Exec(network, h.Hash, (*SQLBigInt)(h.Number), h.Timestamp, h.Head) if err != nil { return err } } return nil } func insertTransfers(creator statementCreator, network uint64, transfers []Transfer) error { insert, err := creator.Prepare("INSERT OR IGNORE INTO transfers(network_id, hash, blk_hash, address, tx, sender, receipt, log, type) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)") if err != nil { return err } for _, t := range transfers { _, err = insert.Exec(network, t.ID, t.BlockHash, t.Address, &JSONBlob{t.Transaction}, t.From, &JSONBlob{t.Receipt}, &JSONBlob{t.Log}, t.Type) if err != nil { return err } } return nil } func updateAccounts(creator statementCreator, network uint64, accounts []common.Address, headers []*DBHeader, option SyncOption) error { update, err := creator.Prepare("UPDATE accounts_to_blocks SET sync=sync|? WHERE address=? AND blk_number=? AND network_id=?") if err != nil { return err } insert, err := creator.Prepare("INSERT OR IGNORE INTO accounts_to_blocks(network_id,address,blk_number,sync) VALUES(?,?,?,?)") if err != nil { return err } for _, acc := range accounts { for _, h := range headers { rst, err := update.Exec(option, acc, (*SQLBigInt)(h.Number), network) if err != nil { return err } affected, err := rst.RowsAffected() if err != nil { return err } if affected > 0 { continue } _, err = insert.Exec(network, acc, (*SQLBigInt)(h.Number), option) if err != nil { return err } } } return nil }