status-go/multiaccounts/keypairs/database.go

580 lines
12 KiB
Go

package keypairs
import (
"database/sql"
"errors"
"fmt"
"strings"
"github.com/status-im/status-go/eth-node/types"
"github.com/status-im/status-go/protocol/protobuf"
)
var errDbTransactionIsNil = errors.New("keycard: database transaction is nil")
type KeyPair struct {
KeycardUID string `json:"keycard-uid"`
KeycardName string `json:"keycard-name"`
KeycardLocked bool `json:"keycard-locked"`
AccountsAddresses []types.Address `json:"accounts-addresses"`
KeyUID string `json:"key-uid"`
LastUpdateClock uint64
}
type KeycardAction struct {
Action string `json:"action"`
OldKeycardUID string `json:"old-keycard-uid,omitempty"`
Keycard *KeyPair `json:"keycard"`
}
func (kp *KeyPair) ToSyncKeycard() *protobuf.SyncKeycard {
kc := &protobuf.SyncKeycard{
Uid: kp.KeycardUID,
Name: kp.KeycardName,
Locked: kp.KeycardLocked,
KeyUid: kp.KeyUID,
Clock: kp.LastUpdateClock,
}
for _, addr := range kp.AccountsAddresses {
kc.Addresses = append(kc.Addresses, addr.Bytes())
}
return kc
}
func (kp *KeyPair) FromSyncKeycard(kc *protobuf.SyncKeycard) {
kp.KeycardUID = kc.Uid
kp.KeycardName = kc.Name
kp.KeycardLocked = kc.Locked
kp.KeyUID = kc.KeyUid
kp.LastUpdateClock = kc.Clock
for _, addr := range kc.Addresses {
kp.AccountsAddresses = append(kp.AccountsAddresses, types.BytesToAddress(addr))
}
}
func removeElementAtIndex[T any](s []T, index int) []T {
if index < 0 || index >= len(s) {
panic("keycard: index out of the range")
}
return append(s[:index], s[index+1:]...)
}
type KeyPairs struct {
db *sql.DB
}
func NewKeyPairs(db *sql.DB) *KeyPairs {
return &KeyPairs{
db: db,
}
}
func containsAddress(addresses []types.Address, address types.Address) bool {
for _, addr := range addresses {
if addr == address {
return true
}
}
return false
}
func (kp *KeyPairs) processResult(rows *sql.Rows, groupByKeycard bool) ([]*KeyPair, error) {
keyPairs := []*KeyPair{}
for rows.Next() {
keyPair := &KeyPair{}
addr := types.Address{}
err := rows.Scan(&keyPair.KeycardUID, &keyPair.KeycardName, &keyPair.KeycardLocked, &addr, &keyPair.KeyUID,
&keyPair.LastUpdateClock)
if err != nil {
return nil, err
}
foundAtIndex := -1
for i := range keyPairs {
if groupByKeycard {
if keyPairs[i].KeycardUID == keyPair.KeycardUID {
foundAtIndex = i
break
}
} else {
if keyPairs[i].KeyUID == keyPair.KeyUID {
foundAtIndex = i
break
}
}
}
if foundAtIndex == -1 {
keyPair.AccountsAddresses = append(keyPair.AccountsAddresses, addr)
keyPairs = append(keyPairs, keyPair)
} else {
if containsAddress(keyPairs[foundAtIndex].AccountsAddresses, addr) {
continue
}
keyPairs[foundAtIndex].AccountsAddresses = append(keyPairs[foundAtIndex].AccountsAddresses, addr)
}
}
return keyPairs, nil
}
func (kp *KeyPairs) getAllRows(groupByKeycard bool) ([]*KeyPair, error) {
rows, err := kp.db.Query(`
SELECT
k.keycard_uid,
k.keycard_name,
k.keycard_locked,
ka.account_address,
k.key_uid,
k.last_update_clock
FROM
keycards AS k
LEFT JOIN
keycards_accounts AS ka
ON
k.keycard_uid = ka.keycard_uid
ORDER BY
key_uid
`)
if err != nil {
return nil, err
}
defer rows.Close()
return kp.processResult(rows, groupByKeycard)
}
func (kp *KeyPairs) GetAllKnownKeycards() ([]*KeyPair, error) {
return kp.getAllRows(true)
}
func (kp *KeyPairs) GetAllMigratedKeyPairs() ([]*KeyPair, error) {
return kp.getAllRows(false)
}
func (kp *KeyPairs) GetMigratedKeyPairByKeyUID(keyUID string) ([]*KeyPair, error) {
rows, err := kp.db.Query(`
SELECT
k.keycard_uid,
k.keycard_name,
k.keycard_locked,
ka.account_address,
k.key_uid,
k.last_update_clock
FROM
keycards AS k
LEFT JOIN
keycards_accounts AS ka
ON
k.keycard_uid = ka.keycard_uid
WHERE
k.key_uid = ?
ORDER BY
k.keycard_uid
`, keyUID)
if err != nil {
return nil, err
}
defer rows.Close()
return kp.processResult(rows, false)
}
func (kp *KeyPairs) startTransactionAndCheckIfNeedToProceed(kcUID string, clock uint64) (tx *sql.Tx, proceed bool, err error) {
tx, err = kp.db.Begin()
if err != nil {
return nil, false, err
}
var dbLastUpdateClock uint64
err = tx.QueryRow(`SELECT last_update_clock FROM keycards WHERE keycard_uid = ?`, kcUID).Scan(&dbLastUpdateClock)
if err != nil {
return tx, err == sql.ErrNoRows, err
}
return tx, dbLastUpdateClock <= clock, nil
}
func (kp *KeyPairs) setLastUpdateClock(tx *sql.Tx, kcUID string, clock uint64) (err error) {
if tx == nil {
return errDbTransactionIsNil
}
_, err = tx.Exec(`
UPDATE
keycards
SET
last_update_clock = ?
WHERE
keycard_uid = ?`,
clock, kcUID)
return err
}
func (kp *KeyPairs) getAccountsForKeycard(tx *sql.Tx, kcUID string) ([]types.Address, error) {
var accountAddresses []types.Address
if tx == nil {
return accountAddresses, errDbTransactionIsNil
}
rows, err := tx.Query(`SELECT account_address FROM keycards_accounts WHERE keycard_uid = ?`, kcUID)
if err != nil {
return nil, err
}
defer rows.Close()
for rows.Next() {
var accAddress types.Address
err = rows.Scan(&accAddress)
if err != nil {
return nil, err
}
accountAddresses = append(accountAddresses, accAddress)
}
return accountAddresses, nil
}
func (kp *KeyPairs) addAccounts(tx *sql.Tx, kcUID string, accountsAddresses []types.Address) (err error) {
if tx == nil {
return errDbTransactionIsNil
}
insertKcAcc, err := tx.Prepare(`
INSERT INTO
keycards_accounts
(
keycard_uid,
account_address
)
VALUES
(?, ?);
`)
if err != nil {
return err
}
defer insertKcAcc.Close()
for i := range accountsAddresses {
addr := accountsAddresses[i]
_, err = insertKcAcc.Exec(kcUID, addr)
if err != nil {
return err
}
}
return nil
}
func (kp *KeyPairs) deleteKeycard(tx *sql.Tx, kcUID string) (err error) {
if tx == nil {
return errDbTransactionIsNil
}
delete, err := tx.Prepare(`
DELETE
FROM
keycards
WHERE
keycard_uid = ?
`)
if err != nil {
return err
}
defer delete.Close()
_, err = delete.Exec(kcUID)
return err
}
func (kp *KeyPairs) AddMigratedKeyPairOrAddAccountsIfKeyPairIsAdded(keyPair KeyPair) (addedKc bool, addedAccs bool, err error) {
tx, proceed, err := kp.startTransactionAndCheckIfNeedToProceed(keyPair.KeycardUID, keyPair.LastUpdateClock)
defer func() {
if err == nil {
err = tx.Commit()
return
}
_ = tx.Rollback()
}()
if proceed {
// insert only if there is no such keycard, otherwise just add accounts
if err != nil && err == sql.ErrNoRows {
_, err = tx.Exec(`
INSERT INTO
keycards
(
keycard_uid,
keycard_name,
keycard_locked,
key_uid,
last_update_clock
)
VALUES
(?, ?, ?, ?, ?);`,
keyPair.KeycardUID, keyPair.KeycardName, keyPair.KeycardLocked, keyPair.KeyUID, keyPair.LastUpdateClock)
if err != nil {
return false, false, err
}
err = kp.addAccounts(tx, keyPair.KeycardUID, keyPair.AccountsAddresses)
return err == nil, false, err
}
err = kp.setLastUpdateClock(tx, keyPair.KeycardUID, keyPair.LastUpdateClock)
if err != nil {
return false, false, err
}
err = kp.addAccounts(tx, keyPair.KeycardUID, keyPair.AccountsAddresses)
return false, err == nil, err
}
return false, false, err
}
func (kp *KeyPairs) SyncKeycards(syncingClock uint64, keypairsToSync []*KeyPair) (err error) {
tx, err := kp.db.Begin()
if err != nil {
return
}
defer func() {
if err == nil {
err = tx.Commit()
return
}
_ = tx.Rollback()
}()
rows, err := tx.Query(`SELECT * FROM keycards`)
if err != nil && err != sql.ErrNoRows {
return err
}
defer rows.Close()
var dbKeypairs []*KeyPair
for rows.Next() {
keyPair := &KeyPair{}
err := rows.Scan(&keyPair.KeycardUID, &keyPair.KeycardName, &keyPair.KeycardLocked, &keyPair.KeyUID,
&keyPair.LastUpdateClock)
if err != nil {
return err
}
dbKeypairs = append(dbKeypairs, keyPair)
}
// apply those from `keypairsToSync` which are newer
for _, syncKp := range keypairsToSync {
foundAtIndex := -1
for i := range dbKeypairs {
if dbKeypairs[i].KeycardUID == syncKp.KeycardUID {
foundAtIndex = i
break
}
}
doInsertOrReplace := true
if foundAtIndex > -1 {
if dbKeypairs[foundAtIndex].LastUpdateClock > syncKp.LastUpdateClock {
doInsertOrReplace = false
}
dbKeypairs = removeElementAtIndex(dbKeypairs, foundAtIndex)
}
if doInsertOrReplace {
_, err = tx.Exec(`
INSERT OR REPLACE INTO
keycards
(
keycard_uid,
keycard_name,
keycard_locked,
key_uid,
last_update_clock
)
VALUES
(?, ?, ?, ?, ?);`,
syncKp.KeycardUID, syncKp.KeycardName, syncKp.KeycardLocked, syncKp.KeyUID, syncKp.LastUpdateClock)
if err != nil {
return err
}
err = kp.addAccounts(tx, syncKp.KeycardUID, syncKp.AccountsAddresses)
if err != nil {
return err
}
}
}
// remove those from the db if they are not in `keypairsToSync` and if they are older than the moment `keypairsToSync` was created at
for _, dbKp := range dbKeypairs {
if dbKp.LastUpdateClock > syncingClock {
continue
}
err = kp.deleteKeycard(tx, dbKp.KeycardUID)
if err != nil {
return err
}
}
return nil
}
func (kp *KeyPairs) RemoveMigratedAccountsForKeycard(kcUID string, accountAddresses []types.Address,
clock uint64) (err error) {
tx, proceed, err := kp.startTransactionAndCheckIfNeedToProceed(kcUID, clock)
defer func() {
if err == nil {
err = tx.Commit()
return
}
_ = tx.Rollback()
}()
if err != nil {
return err
}
if proceed {
err = kp.setLastUpdateClock(tx, kcUID, clock)
if err != nil {
return err
}
dbAccountAddresses, err := kp.getAccountsForKeycard(tx, kcUID)
if err != nil {
return err
}
deleteKeycard := true
for _, dbAddr := range dbAccountAddresses {
found := false
for _, addr := range accountAddresses {
if dbAddr == addr {
found = true
}
}
if !found {
deleteKeycard = false
}
}
if deleteKeycard {
return kp.deleteKeycard(tx, kcUID)
}
inVector := strings.Repeat(",?", len(accountAddresses)-1)
query := `
DELETE
FROM
keycards_accounts
WHERE
keycard_uid = ?
AND
account_address IN (?` + inVector + `)
`
delete, err := tx.Prepare(query)
if err != nil {
return err
}
args := make([]interface{}, len(accountAddresses)+1)
args[0] = kcUID
for i, addr := range accountAddresses {
args[i+1] = addr
}
defer delete.Close()
_, err = delete.Exec(args...)
return err
}
return err
}
func (kp *KeyPairs) execUpdateQuery(kcUID string, clock uint64, field string, value interface{}) (err error) {
tx, proceed, err := kp.startTransactionAndCheckIfNeedToProceed(kcUID, clock)
defer func() {
if err == nil {
err = tx.Commit()
return
}
_ = tx.Rollback()
}()
if err != nil {
return err
}
if proceed {
sql := fmt.Sprintf(`UPDATE keycards SET %s = ?, last_update_clock = ? WHERE keycard_uid = ?`, field) // nolint: gosec
_, err = tx.Exec(sql, value, clock, kcUID)
return err
}
return nil
}
func (kp *KeyPairs) KeycardLocked(kcUID string, clock uint64) (err error) {
return kp.execUpdateQuery(kcUID, clock, "keycard_locked", true)
}
func (kp *KeyPairs) KeycardUnlocked(kcUID string, clock uint64) (err error) {
return kp.execUpdateQuery(kcUID, clock, "keycard_locked", false)
}
func (kp *KeyPairs) UpdateKeycardUID(oldKcUID string, newKcUID string, clock uint64) (err error) {
return kp.execUpdateQuery(oldKcUID, clock, "keycard_uid", newKcUID)
}
func (kp *KeyPairs) SetKeycardName(kcUID string, kpName string, clock uint64) (err error) {
return kp.execUpdateQuery(kcUID, clock, "keycard_name", kpName)
}
func (kp *KeyPairs) DeleteKeycard(kcUID string, clock uint64) (err error) {
tx, proceed, err := kp.startTransactionAndCheckIfNeedToProceed(kcUID, clock)
defer func() {
if err == nil {
err = tx.Commit()
return
}
_ = tx.Rollback()
}()
if err != nil {
return err
}
if proceed {
return kp.deleteKeycard(tx, kcUID)
}
return err
}
func (kp *KeyPairs) DeleteKeypair(keyUID string) (err error) {
delete, err := kp.db.Prepare(`
DELETE
FROM
keycards
WHERE
key_uid = ?
`)
if err != nil {
return err
}
defer delete.Close()
_, err = delete.Exec(keyUID)
return err
}