478 lines
9.9 KiB
Go
478 lines
9.9 KiB
Go
package accounts
|
|
|
|
import (
|
|
"database/sql"
|
|
"errors"
|
|
"fmt"
|
|
"strings"
|
|
|
|
"github.com/status-im/status-go/eth-node/types"
|
|
"github.com/status-im/status-go/protocol/protobuf"
|
|
)
|
|
|
|
var (
	// errKeycardDbTransactionIsNil is returned by the transaction-scoped
	// helpers in this file when they are called with a nil *sql.Tx.
	errKeycardDbTransactionIsNil = errors.New("keycard: database transaction is nil")

	// errCannotAddKeycardForUnknownKeypair is returned by SaveOrUpdateKeycard
	// when no keypair exists for the keycard's key UID.
	errCannotAddKeycardForUnknownKeypair = errors.New("keycard: cannot add keycard for an unknown keypair")

	// ErrNoKeycardForPassedKeycardUID is returned when a lookup by keycard UID
	// matches no stored keycard.
	ErrNoKeycardForPassedKeycardUID = errors.New("keycard: no keycard for the passed keycard uid")
)
|
|
|
|
type Keycard struct {
|
|
KeycardUID string `json:"keycard-uid"`
|
|
KeycardName string `json:"keycard-name"`
|
|
KeycardLocked bool `json:"keycard-locked"`
|
|
AccountsAddresses []types.Address `json:"accounts-addresses"`
|
|
KeyUID string `json:"key-uid"`
|
|
Position uint64
|
|
}
|
|
|
|
func (kp *Keycard) ToSyncKeycard() *protobuf.SyncKeycard {
|
|
kc := &protobuf.SyncKeycard{
|
|
Uid: kp.KeycardUID,
|
|
Name: kp.KeycardName,
|
|
Locked: kp.KeycardLocked,
|
|
KeyUid: kp.KeyUID,
|
|
Position: kp.Position,
|
|
}
|
|
|
|
for _, addr := range kp.AccountsAddresses {
|
|
kc.Addresses = append(kc.Addresses, addr.Bytes())
|
|
}
|
|
|
|
return kc
|
|
}
|
|
|
|
func (kp *Keycard) FromSyncKeycard(kc *protobuf.SyncKeycard) {
|
|
kp.KeycardUID = kc.Uid
|
|
kp.KeycardName = kc.Name
|
|
kp.KeycardLocked = kc.Locked
|
|
kp.KeyUID = kc.KeyUid
|
|
kp.Position = kc.Position
|
|
|
|
for _, addr := range kc.Addresses {
|
|
kp.AccountsAddresses = append(kp.AccountsAddresses, types.BytesToAddress(addr))
|
|
}
|
|
}
|
|
|
|
func containsAddress(addresses []types.Address, address types.Address) bool {
|
|
for _, addr := range addresses {
|
|
if addr == address {
|
|
return true
|
|
}
|
|
}
|
|
return false
|
|
}
|
|
|
|
func (db *Database) processResult(rows *sql.Rows) ([]*Keycard, error) {
|
|
keycards := []*Keycard{}
|
|
for rows.Next() {
|
|
keycard := &Keycard{}
|
|
var accAddress sql.NullString
|
|
err := rows.Scan(&keycard.KeycardUID, &keycard.KeycardName, &keycard.KeycardLocked, &accAddress, &keycard.KeyUID,
|
|
&keycard.Position)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
addr := types.Address{}
|
|
if accAddress.Valid {
|
|
addr = types.BytesToAddress([]byte(accAddress.String))
|
|
}
|
|
|
|
foundAtIndex := -1
|
|
for i := range keycards {
|
|
if keycards[i].KeycardUID == keycard.KeycardUID {
|
|
foundAtIndex = i
|
|
break
|
|
}
|
|
}
|
|
if foundAtIndex == -1 {
|
|
keycard.AccountsAddresses = append(keycard.AccountsAddresses, addr)
|
|
keycards = append(keycards, keycard)
|
|
} else {
|
|
if containsAddress(keycards[foundAtIndex].AccountsAddresses, addr) {
|
|
continue
|
|
}
|
|
keycards[foundAtIndex].AccountsAddresses = append(keycards[foundAtIndex].AccountsAddresses, addr)
|
|
}
|
|
}
|
|
|
|
return keycards, nil
|
|
}
|
|
|
|
func (db *Database) getKeycards(tx *sql.Tx, keyUID string, keycardUID string) ([]*Keycard, error) {
|
|
query := `
|
|
SELECT
|
|
kc.keycard_uid,
|
|
kc.keycard_name,
|
|
kc.keycard_locked,
|
|
ka.account_address,
|
|
kc.key_uid,
|
|
kc.position
|
|
FROM
|
|
keycards AS kc
|
|
LEFT JOIN
|
|
keycards_accounts AS ka
|
|
ON
|
|
kc.keycard_uid = ka.keycard_uid
|
|
LEFT JOIN
|
|
keypairs_accounts AS kpa
|
|
ON
|
|
ka.account_address = kpa.address
|
|
%s
|
|
ORDER BY
|
|
kc.position, kpa.position`
|
|
|
|
var where string
|
|
var args []interface{}
|
|
|
|
if keyUID != "" {
|
|
where = "WHERE kc.key_uid = ?"
|
|
args = append(args, keyUID)
|
|
if keycardUID != "" {
|
|
where += " AND kc.keycard_uid = ?"
|
|
args = append(args, keycardUID)
|
|
}
|
|
} else if keycardUID != "" {
|
|
where = "WHERE kc.keycard_uid = ?"
|
|
args = append(args, keycardUID)
|
|
}
|
|
|
|
query = fmt.Sprintf(query, where)
|
|
|
|
var (
|
|
stmt *sql.Stmt
|
|
err error
|
|
)
|
|
if tx == nil {
|
|
stmt, err = db.db.Prepare(query)
|
|
} else {
|
|
stmt, err = tx.Prepare(query)
|
|
}
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
defer stmt.Close()
|
|
|
|
rows, err := stmt.Query(args...)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
defer rows.Close()
|
|
|
|
return db.processResult(rows)
|
|
}
|
|
|
|
func (db *Database) getKeycardByKeycardUID(tx *sql.Tx, keycardUID string) (*Keycard, error) {
|
|
keycards, err := db.getKeycards(tx, "", keycardUID)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
if len(keycards) == 0 {
|
|
return nil, ErrNoKeycardForPassedKeycardUID
|
|
}
|
|
|
|
return keycards[0], nil
|
|
}
|
|
|
|
func (db *Database) GetAllKnownKeycards() ([]*Keycard, error) {
|
|
return db.getKeycards(nil, "", "")
|
|
}
|
|
|
|
func (db *Database) GetKeycardsWithSameKeyUID(keyUID string) ([]*Keycard, error) {
|
|
return db.getKeycards(nil, keyUID, "")
|
|
}
|
|
|
|
func (db *Database) GetKeycardByKeycardUID(keycardUID string) (*Keycard, error) {
|
|
return db.getKeycardByKeycardUID(nil, keycardUID)
|
|
}
|
|
|
|
func (db *Database) saveOrUpdateKeycardAccounts(tx *sql.Tx, kcUID string, accountsAddresses []types.Address) (err error) {
|
|
if tx == nil {
|
|
return errKeycardDbTransactionIsNil
|
|
}
|
|
|
|
for i := range accountsAddresses {
|
|
addr := accountsAddresses[i]
|
|
|
|
_, err = tx.Exec(`
|
|
INSERT OR IGNORE INTO
|
|
keycards_accounts
|
|
(
|
|
keycard_uid,
|
|
account_address
|
|
)
|
|
VALUES
|
|
(?, ?);
|
|
`, kcUID, addr)
|
|
|
|
if err != nil {
|
|
return err
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (db *Database) deleteKeycard(tx *sql.Tx, kcUID string) (err error) {
|
|
if tx == nil {
|
|
return errKeycardDbTransactionIsNil
|
|
}
|
|
|
|
delete, err := tx.Prepare(`
|
|
DELETE
|
|
FROM
|
|
keycards
|
|
WHERE
|
|
keycard_uid = ?
|
|
`)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
defer delete.Close()
|
|
|
|
_, err = delete.Exec(kcUID)
|
|
|
|
return err
|
|
}
|
|
|
|
func (db *Database) deleteAllKeycardsWithKeyUID(tx *sql.Tx, keyUID string) (err error) {
|
|
if tx == nil {
|
|
return errKeycardDbTransactionIsNil
|
|
}
|
|
|
|
delete, err := tx.Prepare(`
|
|
DELETE
|
|
FROM
|
|
keycards
|
|
WHERE
|
|
key_uid = ?
|
|
`)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
defer delete.Close()
|
|
|
|
_, err = delete.Exec(keyUID)
|
|
return err
|
|
}
|
|
|
|
func (db *Database) deleteKeycardAccounts(tx *sql.Tx, kcUID string, accountAddresses []types.Address) (err error) {
|
|
if tx == nil {
|
|
return errKeycardDbTransactionIsNil
|
|
}
|
|
|
|
inVector := strings.Repeat(",?", len(accountAddresses)-1)
|
|
query := `
|
|
DELETE
|
|
FROM
|
|
keycards_accounts
|
|
WHERE
|
|
keycard_uid = ?
|
|
AND
|
|
account_address IN (?` + inVector + `)`
|
|
|
|
delete, err := tx.Prepare(query)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
defer delete.Close()
|
|
|
|
args := make([]interface{}, len(accountAddresses)+1)
|
|
args[0] = kcUID
|
|
for i, addr := range accountAddresses {
|
|
args[i+1] = addr
|
|
}
|
|
|
|
_, err = delete.Exec(args...)
|
|
|
|
return err
|
|
}
|
|
|
|
func (db *Database) SaveOrUpdateKeycard(keycard Keycard, clock uint64, updateKeypairClock bool) error {
|
|
tx, err := db.db.Begin()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
defer func() {
|
|
if err == nil {
|
|
err = tx.Commit()
|
|
return
|
|
}
|
|
_ = tx.Rollback()
|
|
}()
|
|
|
|
relatedKeypairExists, err := db.keypairExists(tx, keycard.KeyUID)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
if !relatedKeypairExists {
|
|
return errCannotAddKeycardForUnknownKeypair
|
|
}
|
|
|
|
_, err = tx.Exec(`
|
|
INSERT OR IGNORE INTO
|
|
keycards
|
|
(
|
|
keycard_uid,
|
|
keycard_name,
|
|
key_uid
|
|
)
|
|
VALUES
|
|
(?, ?, ?);
|
|
|
|
UPDATE
|
|
keycards
|
|
SET
|
|
keycard_name = ?,
|
|
keycard_locked = ?,
|
|
position = ?
|
|
WHERE
|
|
keycard_uid = ?;
|
|
`, keycard.KeycardUID, keycard.KeycardName, keycard.KeyUID,
|
|
keycard.KeycardName, keycard.KeycardLocked, keycard.Position, keycard.KeycardUID)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
err = db.saveOrUpdateKeycardAccounts(tx, keycard.KeycardUID, keycard.AccountsAddresses)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
if updateKeypairClock {
|
|
return db.updateKeypairClock(tx, keycard.KeyUID, clock)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (db *Database) execKeycardUpdateQuery(kcUID string, clock uint64, field string, value interface{}) (err error) {
|
|
tx, err := db.db.Begin()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
defer func() {
|
|
if err == nil {
|
|
err = tx.Commit()
|
|
return
|
|
}
|
|
_ = tx.Rollback()
|
|
}()
|
|
|
|
keycard, err := db.getKeycardByKeycardUID(tx, kcUID)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
sql := fmt.Sprintf(`UPDATE keycards SET %s = ? WHERE keycard_uid = ?`, field) // nolint: gosec
|
|
_, err = tx.Exec(sql, value, kcUID)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
return db.updateKeypairClock(tx, keycard.KeyUID, clock)
|
|
}
|
|
|
|
func (db *Database) KeycardLocked(kcUID string, clock uint64) (err error) {
|
|
return db.execKeycardUpdateQuery(kcUID, clock, "keycard_locked", true)
|
|
}
|
|
|
|
func (db *Database) KeycardUnlocked(kcUID string, clock uint64) (err error) {
|
|
return db.execKeycardUpdateQuery(kcUID, clock, "keycard_locked", false)
|
|
}
|
|
|
|
func (db *Database) UpdateKeycardUID(oldKcUID string, newKcUID string, clock uint64) (err error) {
|
|
return db.execKeycardUpdateQuery(oldKcUID, clock, "keycard_uid", newKcUID)
|
|
}
|
|
|
|
func (db *Database) SetKeycardName(kcUID string, kpName string, clock uint64) (err error) {
|
|
return db.execKeycardUpdateQuery(kcUID, clock, "keycard_name", kpName)
|
|
}
|
|
|
|
func (db *Database) DeleteKeycardAccounts(kcUID string, accountAddresses []types.Address, clock uint64) (err error) {
|
|
tx, err := db.db.Begin()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
defer func() {
|
|
if err == nil {
|
|
err = tx.Commit()
|
|
return
|
|
}
|
|
_ = tx.Rollback()
|
|
}()
|
|
|
|
keycard, err := db.getKeycardByKeycardUID(tx, kcUID)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
err = db.deleteKeycardAccounts(tx, kcUID, accountAddresses)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
return db.updateKeypairClock(tx, keycard.KeyUID, clock)
|
|
}
|
|
|
|
func (db *Database) DeleteKeycard(kcUID string, clock uint64) (err error) {
|
|
tx, err := db.db.Begin()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
defer func() {
|
|
if err == nil {
|
|
err = tx.Commit()
|
|
return
|
|
}
|
|
_ = tx.Rollback()
|
|
}()
|
|
|
|
keycard, err := db.getKeycardByKeycardUID(tx, kcUID)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
err = db.deleteKeycard(tx, kcUID)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
return db.updateKeypairClock(tx, keycard.KeyUID, clock)
|
|
}
|
|
|
|
func (db *Database) DeleteAllKeycardsWithKeyUID(keyUID string, clock uint64) (err error) {
|
|
tx, err := db.db.Begin()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
defer func() {
|
|
if err == nil {
|
|
err = tx.Commit()
|
|
return
|
|
}
|
|
_ = tx.Rollback()
|
|
}()
|
|
|
|
err = db.deleteAllKeycardsWithKeyUID(tx, keyUID)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
return db.updateKeypairClock(tx, keyUID, clock)
|
|
}
|
|
|
|
func (db *Database) GetPositionForNextNewKeycard() (uint64, error) {
|
|
var pos sql.NullInt64
|
|
err := db.db.QueryRow("SELECT MAX(position) FROM keycards").Scan(&pos)
|
|
if err != nil {
|
|
return 0, err
|
|
}
|
|
if pos.Valid {
|
|
return uint64(pos.Int64) + 1, nil
|
|
}
|
|
return 0, nil
|
|
}
|