status-go/multiaccounts/accounts/keycard_database.go

478 lines
9.9 KiB
Go
Raw Normal View History

package accounts
import (
"database/sql"
"errors"
"fmt"
"strings"
"github.com/status-im/status-go/eth-node/types"
"github.com/status-im/status-go/protocol/protobuf"
)
var (
	// errKeycardDbTransactionIsNil is returned by the transaction-scoped
	// helpers in this file when they are called with a nil *sql.Tx.
	errKeycardDbTransactionIsNil = errors.New("keycard: database transaction is nil")
	// errCannotAddKeycardForUnknownKeypair is returned by SaveOrUpdateKeycard
	// when no keypair exists for the keycard's key uid.
	errCannotAddKeycardForUnknownKeypair = errors.New("keycard: cannot add keycard for an unknown keypair")
	// ErrNoKeycardForPassedKeycardUID is returned when a lookup by keycard uid
	// matches no stored keycard.
	ErrNoKeycardForPassedKeycardUID = errors.New("keycard: no keycard for the passed keycard uid")
)
// Keycard is the in-memory form of one keycard together with the account
// addresses linked to it. The fields are filled by processResult from the
// columns selected in getKeycards, or from a protobuf.SyncKeycard via
// FromSyncKeycard.
type Keycard struct {
// KeycardUID is read from the keycard_uid column.
KeycardUID string `json:"keycard-uid"`
// KeycardName is read from the keycard_name column.
KeycardName string `json:"keycard-name"`
// KeycardLocked is read from the keycard_locked column.
KeycardLocked bool `json:"keycard-locked"`
// AccountsAddresses collects the account_address values joined to this keycard.
AccountsAddresses []types.Address `json:"accounts-addresses"`
// KeyUID is read from the key_uid column.
KeyUID string `json:"key-uid"`
// Position is read from the position column. Unlike the other fields it
// carries no json tag.
Position uint64
}
// ToSyncKeycard converts the keycard into its protobuf.SyncKeycard form.
// Scalar fields are copied one-to-one and every account address is appended
// to Addresses as raw bytes, in the order they appear in AccountsAddresses.
func (kp *Keycard) ToSyncKeycard() *protobuf.SyncKeycard {
	synced := new(protobuf.SyncKeycard)
	synced.Uid = kp.KeycardUID
	synced.Name = kp.KeycardName
	synced.Locked = kp.KeycardLocked
	synced.KeyUid = kp.KeyUID
	synced.Position = kp.Position

	for _, address := range kp.AccountsAddresses {
		raw := address.Bytes()
		synced.Addresses = append(synced.Addresses, raw)
	}

	return synced
}
// FromSyncKeycard populates the receiver from a protobuf.SyncKeycard.
// Scalar fields are overwritten; the addresses carried by kc are appended to
// whatever AccountsAddresses already holds.
func (kp *Keycard) FromSyncKeycard(kc *protobuf.SyncKeycard) {
	kp.KeycardUID, kp.KeycardName = kc.Uid, kc.Name
	kp.KeycardLocked = kc.Locked
	kp.KeyUID = kc.KeyUid
	kp.Position = kc.Position

	for i := range kc.Addresses {
		converted := types.BytesToAddress(kc.Addresses[i])
		kp.AccountsAddresses = append(kp.AccountsAddresses, converted)
	}
}
// containsAddress reports whether address is one of the elements of addresses.
func containsAddress(addresses []types.Address, address types.Address) bool {
	for i := range addresses {
		if addresses[i] != address {
			continue
		}
		return true
	}
	return false
}
// processResult folds the rows produced by the getKeycards query into one
// Keycard per distinct keycard uid, preserving the order in which the uids
// first appear.
//
// Each row is expected to carry, in order: keycard uid, keycard name, locked
// flag, account address (nullable), key uid and position. Addresses of rows
// sharing a keycard uid are accumulated, without duplicates, on the first
// Keycard created for that uid.
//
// It returns the first scan error, or the error that ended row iteration
// early, if any. The caller remains responsible for closing rows.
func (db *Database) processResult(rows *sql.Rows) ([]*Keycard, error) {
	keycards := []*Keycard{}
	for rows.Next() {
		keycard := &Keycard{}
		var accAddress sql.NullString
		err := rows.Scan(&keycard.KeycardUID, &keycard.KeycardName, &keycard.KeycardLocked, &accAddress, &keycard.KeyUID,
			&keycard.Position)
		if err != nil {
			return nil, err
		}

		// A NULL address stays the zero address and is still recorded below.
		addr := types.Address{}
		if accAddress.Valid {
			addr = types.BytesToAddress([]byte(accAddress.String))
		}

		foundAtIndex := -1
		for i := range keycards {
			if keycards[i].KeycardUID == keycard.KeycardUID {
				foundAtIndex = i
				break
			}
		}

		if foundAtIndex == -1 {
			keycard.AccountsAddresses = append(keycard.AccountsAddresses, addr)
			keycards = append(keycards, keycard)
			continue
		}
		if containsAddress(keycards[foundAtIndex].AccountsAddresses, addr) {
			continue
		}
		keycards[foundAtIndex].AccountsAddresses = append(keycards[foundAtIndex].AccountsAddresses, addr)
	}

	// rows.Next returns false both at the end of the result set and when
	// iteration fails; only rows.Err tells the two apart.
	if err := rows.Err(); err != nil {
		return nil, err
	}

	return keycards, nil
}
// getKeycards loads keycards together with their linked account addresses.
//
// A non-empty keyUID restricts the result to keycards with that key uid, a
// non-empty keycardUID restricts it to that keycard uid, and both filters are
// combined with AND. With both empty, every keycard is returned. Rows are
// ordered by keycard position and then by the position of the joined
// keypairs_accounts row.
//
// When tx is nil the statement is prepared on the database handle, otherwise
// on the given transaction.
func (db *Database) getKeycards(tx *sql.Tx, keyUID string, keycardUID string) ([]*Keycard, error) {
	// The %s placeholder receives the optional WHERE clause built below.
	query := `
SELECT
kc.keycard_uid,
kc.keycard_name,
kc.keycard_locked,
ka.account_address,
kc.key_uid,
kc.position
FROM
keycards AS kc
LEFT JOIN
keycards_accounts AS ka
ON
kc.keycard_uid = ka.keycard_uid
LEFT JOIN
keypairs_accounts AS kpa
ON
ka.account_address = kpa.address
%s
ORDER BY
kc.position, kpa.position`

	var where string
	var args []interface{}

	if keyUID != "" {
		where = "WHERE kc.key_uid = ?"
		args = append(args, keyUID)
		if keycardUID != "" {
			where += " AND kc.keycard_uid = ?"
			args = append(args, keycardUID)
		}
	} else if keycardUID != "" {
		where = "WHERE kc.keycard_uid = ?"
		args = append(args, keycardUID)
	}

	// Only the fixed clause text is interpolated; filter values stay bound
	// parameters.
	query = fmt.Sprintf(query, where)

	var (
		stmt *sql.Stmt
		err  error
	)
	if tx == nil {
		stmt, err = db.db.Prepare(query)
	} else {
		stmt, err = tx.Prepare(query)
	}
	if err != nil {
		return nil, err
	}
	defer stmt.Close()

	rows, err := stmt.Query(args...)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	return db.processResult(rows)
}
// getKeycardByKeycardUID returns the first keycard matching keycardUID, using
// tx for the lookup (getKeycards falls back to the database handle when tx is
// nil). ErrNoKeycardForPassedKeycardUID is returned when nothing matches.
func (db *Database) getKeycardByKeycardUID(tx *sql.Tx, keycardUID string) (*Keycard, error) {
	matches, err := db.getKeycards(tx, "", keycardUID)
	switch {
	case err != nil:
		return nil, err
	case len(matches) == 0:
		return nil, ErrNoKeycardForPassedKeycardUID
	default:
		return matches[0], nil
	}
}
// GetAllKnownKeycards returns every stored keycard, outside any transaction
// and without any key uid or keycard uid filter.
func (db *Database) GetAllKnownKeycards() ([]*Keycard, error) {
	const noFilter = ""
	return db.getKeycards(nil, noFilter, noFilter)
}
// GetKeycardsWithSameKeyUID returns the keycards selected by getKeycards for
// the given keyUID, outside any transaction and with no keycard uid filter.
func (db *Database) GetKeycardsWithSameKeyUID(keyUID string) ([]*Keycard, error) {
	const anyKeycardUID = ""
	return db.getKeycards(nil, keyUID, anyKeycardUID)
}
// GetKeycardByKeycardUID looks a keycard up by its keycard uid outside any
// transaction. It returns ErrNoKeycardForPassedKeycardUID when no keycard
// matches.
func (db *Database) GetKeycardByKeycardUID(keycardUID string) (*Keycard, error) {
	keycard, err := db.getKeycardByKeycardUID(nil, keycardUID)
	return keycard, err
}
// saveOrUpdateKeycardAccounts links each of accountsAddresses to the keycard
// kcUID inside tx. Rows are written with INSERT OR IGNORE, one statement per
// address, and the first failing statement aborts the loop.
//
// A nil tx yields errKeycardDbTransactionIsNil.
func (db *Database) saveOrUpdateKeycardAccounts(tx *sql.Tx, kcUID string, accountsAddresses []types.Address) (err error) {
	if tx == nil {
		return errKeycardDbTransactionIsNil
	}

	const insertLink = `
INSERT OR IGNORE INTO
keycards_accounts
(
keycard_uid,
account_address
)
VALUES
(?, ?);
`

	for _, address := range accountsAddresses {
		if _, err = tx.Exec(insertLink, kcUID, address); err != nil {
			return err
		}
	}

	return nil
}
// deleteKeycard removes the keycards row whose keycard_uid equals kcUID,
// inside tx. A nil tx yields errKeycardDbTransactionIsNil.
func (db *Database) deleteKeycard(tx *sql.Tx, kcUID string) (err error) {
	if tx == nil {
		return errKeycardDbTransactionIsNil
	}

	stmt, err := tx.Prepare(`
DELETE
FROM
keycards
WHERE
keycard_uid = ?
`)
	if err != nil {
		return err
	}
	defer stmt.Close()

	if _, err = stmt.Exec(kcUID); err != nil {
		return err
	}
	return nil
}
// deleteAllKeycardsWithKeyUID removes every keycards row whose key_uid equals
// keyUID, inside tx. A nil tx yields errKeycardDbTransactionIsNil.
func (db *Database) deleteAllKeycardsWithKeyUID(tx *sql.Tx, keyUID string) (err error) {
	if tx == nil {
		return errKeycardDbTransactionIsNil
	}

	stmt, err := tx.Prepare(`
DELETE
FROM
keycards
WHERE
key_uid = ?
`)
	if err != nil {
		return err
	}
	defer stmt.Close()

	if _, err = stmt.Exec(keyUID); err != nil {
		return err
	}
	return nil
}
// deleteKeycardAccounts removes, inside tx, the keycards_accounts rows that
// link keycard kcUID to any of accountAddresses.
//
// A nil tx yields errKeycardDbTransactionIsNil. An empty accountAddresses is
// a no-op that returns nil.
func (db *Database) deleteKeycardAccounts(tx *sql.Tx, kcUID string, accountAddresses []types.Address) (err error) {
	if tx == nil {
		return errKeycardDbTransactionIsNil
	}
	// Nothing to delete. This also avoids strings.Repeat panicking on the
	// negative count that len-1 would produce for an empty slice.
	if len(accountAddresses) == 0 {
		return nil
	}

	// One "?" is already in the query text; add ",?" for each further address.
	inVector := strings.Repeat(",?", len(accountAddresses)-1)
	query := `
DELETE
FROM
keycards_accounts
WHERE
keycard_uid = ?
AND
account_address IN (?` + inVector + `)`

	stmt, err := tx.Prepare(query)
	if err != nil {
		return err
	}
	defer stmt.Close()

	// Bind order: the keycard uid first, then every address.
	args := make([]interface{}, len(accountAddresses)+1)
	args[0] = kcUID
	for i, addr := range accountAddresses {
		args[i+1] = addr
	}

	_, err = stmt.Exec(args...)
	return err
}
// SaveOrUpdateKeycard stores keycard in a single transaction.
//
// The keycard row is inserted if its keycard uid is not present yet, then its
// name, locked flag and position are updated. The keycard's account addresses
// are linked via saveOrUpdateKeycardAccounts. When updateKeypairClock is true
// the clock of the related keypair is updated to clock as well.
//
// It returns errCannotAddKeycardForUnknownKeypair when no keypair exists for
// keycard.KeyUID. Any returned error rolls the transaction back; otherwise
// the transaction is committed and a commit failure is returned.
func (db *Database) SaveOrUpdateKeycard(keycard Keycard, clock uint64, updateKeypairClock bool) (err error) {
	tx, err := db.db.Begin()
	if err != nil {
		return err
	}
	// err is the named result, so this closure sees the value actually being
	// returned, including errors returned directly by the statements below.
	defer func() {
		if err == nil {
			err = tx.Commit()
			return
		}
		// The original error is the one worth reporting; a rollback failure
		// is deliberately ignored.
		_ = tx.Rollback()
	}()

	relatedKeypairExists, err := db.keypairExists(tx, keycard.KeyUID)
	if err != nil {
		return err
	}
	if !relatedKeypairExists {
		return errCannotAddKeycardForUnknownKeypair
	}

	// Two statements in one Exec: the first three arguments feed the INSERT,
	// the remaining four feed the UPDATE.
	_, err = tx.Exec(`
INSERT OR IGNORE INTO
keycards
(
keycard_uid,
keycard_name,
key_uid
)
VALUES
(?, ?, ?);
UPDATE
keycards
SET
keycard_name = ?,
keycard_locked = ?,
position = ?
WHERE
keycard_uid = ?;
`, keycard.KeycardUID, keycard.KeycardName, keycard.KeyUID,
		keycard.KeycardName, keycard.KeycardLocked, keycard.Position, keycard.KeycardUID)
	if err != nil {
		return err
	}

	err = db.saveOrUpdateKeycardAccounts(tx, keycard.KeycardUID, keycard.AccountsAddresses)
	if err != nil {
		return err
	}

	if updateKeypairClock {
		return db.updateKeypairClock(tx, keycard.KeyUID, clock)
	}

	return nil
}
// execKeycardUpdateQuery sets a single column of the keycard identified by
// kcUID to value and then updates the clock of the keypair that keycard
// belongs to, all in one transaction.
//
// field is interpolated into the SQL text as the column name, so callers must
// only pass trusted column names; value is bound as a parameter. The keycard
// is looked up first, so ErrNoKeycardForPassedKeycardUID is returned for an
// unknown kcUID. The transaction is committed when the returned error is nil
// and rolled back otherwise.
func (db *Database) execKeycardUpdateQuery(kcUID string, clock uint64, field string, value interface{}) (err error) {
	tx, err := db.db.Begin()
	if err != nil {
		return err
	}
	defer func() {
		if err != nil {
			_ = tx.Rollback()
			return
		}
		err = tx.Commit()
	}()

	// Resolved before the update so the owning key uid is known afterwards.
	keycard, err := db.getKeycardByKeycardUID(tx, kcUID)
	if err != nil {
		return err
	}

	query := fmt.Sprintf(`UPDATE keycards SET %s = ? WHERE keycard_uid = ?`, field) // nolint: gosec
	if _, err = tx.Exec(query, value, kcUID); err != nil {
		return err
	}

	err = db.updateKeypairClock(tx, keycard.KeyUID, clock)
	return err
}
// KeycardLocked sets the keycard_locked column of keycard kcUID to true and
// updates the related keypair's clock to clock.
func (db *Database) KeycardLocked(kcUID string, clock uint64) (err error) {
	const locked = true
	err = db.execKeycardUpdateQuery(kcUID, clock, "keycard_locked", locked)
	return err
}
// KeycardUnlocked sets the keycard_locked column of keycard kcUID to false
// and updates the related keypair's clock to clock.
func (db *Database) KeycardUnlocked(kcUID string, clock uint64) (err error) {
	const locked = false
	err = db.execKeycardUpdateQuery(kcUID, clock, "keycard_locked", locked)
	return err
}
// UpdateKeycardUID replaces the keycard_uid of the keycard currently stored
// under oldKcUID with newKcUID and updates the related keypair's clock to
// clock.
func (db *Database) UpdateKeycardUID(oldKcUID string, newKcUID string, clock uint64) (err error) {
	const column = "keycard_uid"
	err = db.execKeycardUpdateQuery(oldKcUID, clock, column, newKcUID)
	return err
}
// SetKeycardName sets the keycard_name column of keycard kcUID to kpName and
// updates the related keypair's clock to clock.
func (db *Database) SetKeycardName(kcUID string, kpName string, clock uint64) (err error) {
	const column = "keycard_name"
	err = db.execKeycardUpdateQuery(kcUID, clock, column, kpName)
	return err
}
// DeleteKeycardAccounts unlinks accountAddresses from keycard kcUID and then
// updates the clock of the keypair the keycard belongs to, in one
// transaction.
//
// The keycard is looked up first, so ErrNoKeycardForPassedKeycardUID is
// returned for an unknown kcUID. The transaction is committed when the
// returned error is nil and rolled back otherwise.
func (db *Database) DeleteKeycardAccounts(kcUID string, accountAddresses []types.Address, clock uint64) (err error) {
	tx, err := db.db.Begin()
	if err != nil {
		return err
	}
	defer func() {
		if err != nil {
			_ = tx.Rollback()
			return
		}
		err = tx.Commit()
	}()

	keycard, err := db.getKeycardByKeycardUID(tx, kcUID)
	if err != nil {
		return err
	}

	if err = db.deleteKeycardAccounts(tx, kcUID, accountAddresses); err != nil {
		return err
	}

	err = db.updateKeypairClock(tx, keycard.KeyUID, clock)
	return err
}
// DeleteKeycard removes keycard kcUID and then updates the clock of the
// keypair it belonged to, in one transaction.
//
// The keycard is looked up before deletion so its key uid is still available
// for the clock update; ErrNoKeycardForPassedKeycardUID is returned for an
// unknown kcUID. The transaction is committed when the returned error is nil
// and rolled back otherwise.
func (db *Database) DeleteKeycard(kcUID string, clock uint64) (err error) {
	tx, err := db.db.Begin()
	if err != nil {
		return err
	}
	defer func() {
		if err != nil {
			_ = tx.Rollback()
			return
		}
		err = tx.Commit()
	}()

	keycard, err := db.getKeycardByKeycardUID(tx, kcUID)
	if err != nil {
		return err
	}

	if err = db.deleteKeycard(tx, kcUID); err != nil {
		return err
	}

	err = db.updateKeypairClock(tx, keycard.KeyUID, clock)
	return err
}
// DeleteAllKeycardsWithKeyUID removes every keycard stored for keyUID and
// then updates that keypair's clock to clock, in one transaction. The
// transaction is committed when the returned error is nil and rolled back
// otherwise.
func (db *Database) DeleteAllKeycardsWithKeyUID(keyUID string, clock uint64) (err error) {
	tx, err := db.db.Begin()
	if err != nil {
		return err
	}
	defer func() {
		if err != nil {
			_ = tx.Rollback()
			return
		}
		err = tx.Commit()
	}()

	if err = db.deleteAllKeycardsWithKeyUID(tx, keyUID); err != nil {
		return err
	}

	err = db.updateKeypairClock(tx, keyUID, clock)
	return err
}
// GetPositionForNextNewKeycard returns the position a newly added keycard
// should take: one past the largest stored position, or 0 when MAX(position)
// is NULL.
func (db *Database) GetPositionForNextNewKeycard() (uint64, error) {
	var maxPosition sql.NullInt64

	row := db.db.QueryRow("SELECT MAX(position) FROM keycards")
	if err := row.Scan(&maxPosition); err != nil {
		return 0, err
	}

	if !maxPosition.Valid {
		return 0, nil
	}
	return uint64(maxPosition.Int64) + 1, nil
}