582 lines
12 KiB
Go
582 lines
12 KiB
Go
package accounts
|
|
|
|
import (
|
|
"database/sql"
|
|
"errors"
|
|
"fmt"
|
|
"strings"
|
|
|
|
"github.com/status-im/status-go/eth-node/types"
|
|
"github.com/status-im/status-go/protocol/protobuf"
|
|
)
|
|
|
|
// errKeycardDbTransactionIsNil is returned by the transaction-scoped helpers
// in this file when they are handed a nil *sql.Tx.
var errKeycardDbTransactionIsNil = errors.New("keycard: database transaction is nil")
|
|
|
|
// Keycard is the in-memory form of a row of the keycards table together with
// the account addresses stored for it in keycards_accounts.
type Keycard struct {
	// KeycardUID is read from / written to the keycard_uid column.
	KeycardUID string `json:"keycard-uid"`
	// KeycardName is read from / written to the keycard_name column.
	KeycardName string `json:"keycard-name"`
	// KeycardLocked is read from / written to the keycard_locked column.
	KeycardLocked bool `json:"keycard-locked"`
	// AccountsAddresses collects the account_address values joined to this
	// keycard (or to its key UID, depending on how the rows were grouped).
	AccountsAddresses []types.Address `json:"accounts-addresses"`
	// KeyUID is read from / written to the key_uid column.
	KeyUID string `json:"key-uid"`
	// LastUpdateClock is compared against incoming clocks to decide whether
	// an update is applied. It carries no json tag.
	LastUpdateClock uint64
}
|
|
|
|
// KeycardAction pairs an action name with the keycard it refers to.
// NOTE(review): the type is not used in this part of the file; the set of
// valid Action values and the meaning of OldKeycardUID should be confirmed
// against its callers.
type KeycardAction struct {
	Action string `json:"action"`
	// OldKeycardUID is omitted from the JSON encoding when empty.
	OldKeycardUID string   `json:"old-keycard-uid,omitempty"`
	Keycard       *Keycard `json:"keycard"`
}
|
|
|
|
func (kp *Keycard) ToSyncKeycard() *protobuf.SyncKeycard {
|
|
kc := &protobuf.SyncKeycard{
|
|
Uid: kp.KeycardUID,
|
|
Name: kp.KeycardName,
|
|
Locked: kp.KeycardLocked,
|
|
KeyUid: kp.KeyUID,
|
|
Clock: kp.LastUpdateClock,
|
|
}
|
|
|
|
for _, addr := range kp.AccountsAddresses {
|
|
kc.Addresses = append(kc.Addresses, addr.Bytes())
|
|
}
|
|
|
|
return kc
|
|
}
|
|
|
|
func (kp *Keycard) FromSyncKeycard(kc *protobuf.SyncKeycard) {
|
|
kp.KeycardUID = kc.Uid
|
|
kp.KeycardName = kc.Name
|
|
kp.KeycardLocked = kc.Locked
|
|
kp.KeyUID = kc.KeyUid
|
|
kp.LastUpdateClock = kc.Clock
|
|
|
|
for _, addr := range kc.Addresses {
|
|
kp.AccountsAddresses = append(kp.AccountsAddresses, types.BytesToAddress(addr))
|
|
}
|
|
}
|
|
|
|
// removeElementAtIndex returns s without the element at index. The removal is
// done in place: the elements after index are shifted one position to the
// left inside the backing array of s, so the caller must use the returned
// slice. It panics when index is not a valid position in s.
func removeElementAtIndex[T any](s []T, index int) []T {
	last := len(s) - 1
	if index < 0 || index > last {
		panic("keycard: index out of the range")
	}

	copy(s[index:], s[index+1:])
	return s[:last]
}
|
|
|
|
// Keycards provides access to the keycards and keycards_accounts tables
// through the wrapped database handle.
type Keycards struct {
	db *sql.DB
}
|
|
|
|
func NewKeycards(db *sql.DB) *Keycards {
|
|
return &Keycards{
|
|
db: db,
|
|
}
|
|
}
|
|
|
|
func containsAddress(addresses []types.Address, address types.Address) bool {
|
|
for _, addr := range addresses {
|
|
if addr == address {
|
|
return true
|
|
}
|
|
}
|
|
return false
|
|
}
|
|
|
|
func (kp *Keycards) processResult(rows *sql.Rows, groupByKeycard bool) ([]*Keycard, error) {
|
|
keycards := []*Keycard{}
|
|
for rows.Next() {
|
|
keycard := &Keycard{}
|
|
addr := types.Address{}
|
|
err := rows.Scan(&keycard.KeycardUID, &keycard.KeycardName, &keycard.KeycardLocked, &addr, &keycard.KeyUID,
|
|
&keycard.LastUpdateClock)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
foundAtIndex := -1
|
|
for i := range keycards {
|
|
if groupByKeycard {
|
|
if keycards[i].KeycardUID == keycard.KeycardUID {
|
|
foundAtIndex = i
|
|
break
|
|
}
|
|
} else {
|
|
if keycards[i].KeyUID == keycard.KeyUID {
|
|
foundAtIndex = i
|
|
break
|
|
}
|
|
}
|
|
}
|
|
if foundAtIndex == -1 {
|
|
keycard.AccountsAddresses = append(keycard.AccountsAddresses, addr)
|
|
keycards = append(keycards, keycard)
|
|
} else {
|
|
if containsAddress(keycards[foundAtIndex].AccountsAddresses, addr) {
|
|
continue
|
|
}
|
|
keycards[foundAtIndex].AccountsAddresses = append(keycards[foundAtIndex].AccountsAddresses, addr)
|
|
}
|
|
}
|
|
|
|
return keycards, nil
|
|
}
|
|
|
|
func (kp *Keycards) getAllRows(groupByKeycard bool) ([]*Keycard, error) {
|
|
rows, err := kp.db.Query(`
|
|
SELECT
|
|
k.keycard_uid,
|
|
k.keycard_name,
|
|
k.keycard_locked,
|
|
ka.account_address,
|
|
k.key_uid,
|
|
k.last_update_clock
|
|
FROM
|
|
keycards AS k
|
|
LEFT JOIN
|
|
keycards_accounts AS ka
|
|
ON
|
|
k.keycard_uid = ka.keycard_uid
|
|
ORDER BY
|
|
key_uid
|
|
`)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
defer rows.Close()
|
|
return kp.processResult(rows, groupByKeycard)
|
|
}
|
|
|
|
func (kp *Keycards) GetAllKnownKeycards() ([]*Keycard, error) {
|
|
return kp.getAllRows(true)
|
|
}
|
|
|
|
func (kp *Keycards) GetAllKnownKeycardsGroupedByKeyUID() ([]*Keycard, error) {
|
|
return kp.getAllRows(false)
|
|
}
|
|
|
|
func (kp *Keycards) GetKeycardByKeyUID(keyUID string) ([]*Keycard, error) {
|
|
rows, err := kp.db.Query(`
|
|
SELECT
|
|
k.keycard_uid,
|
|
k.keycard_name,
|
|
k.keycard_locked,
|
|
ka.account_address,
|
|
k.key_uid,
|
|
k.last_update_clock
|
|
FROM
|
|
keycards AS k
|
|
LEFT JOIN
|
|
keycards_accounts AS ka
|
|
ON
|
|
k.keycard_uid = ka.keycard_uid
|
|
WHERE
|
|
k.key_uid = ?
|
|
ORDER BY
|
|
k.keycard_uid
|
|
`, keyUID)
|
|
if err != nil {
|
|
if err == sql.ErrNoRows {
|
|
return []*Keycard{}, nil
|
|
}
|
|
return nil, err
|
|
}
|
|
|
|
defer rows.Close()
|
|
return kp.processResult(rows, false)
|
|
}
|
|
|
|
func (kp *Keycards) startTransactionAndCheckIfNeedToProceed(kcUID string, clock uint64) (tx *sql.Tx, proceed bool, err error) {
|
|
tx, err = kp.db.Begin()
|
|
if err != nil {
|
|
return nil, false, err
|
|
}
|
|
var dbLastUpdateClock uint64
|
|
err = tx.QueryRow(`SELECT last_update_clock FROM keycards WHERE keycard_uid = ?`, kcUID).Scan(&dbLastUpdateClock)
|
|
if err != nil {
|
|
return tx, err == sql.ErrNoRows, err
|
|
}
|
|
|
|
return tx, dbLastUpdateClock <= clock, nil
|
|
}
|
|
|
|
func (kp *Keycards) setLastUpdateClock(tx *sql.Tx, kcUID string, clock uint64) (err error) {
|
|
if tx == nil {
|
|
return errKeycardDbTransactionIsNil
|
|
}
|
|
|
|
_, err = tx.Exec(`
|
|
UPDATE
|
|
keycards
|
|
SET
|
|
last_update_clock = ?
|
|
WHERE
|
|
keycard_uid = ?`,
|
|
clock, kcUID)
|
|
|
|
return err
|
|
}
|
|
|
|
func (kp *Keycards) getAccountsForKeycard(tx *sql.Tx, kcUID string) ([]types.Address, error) {
|
|
var accountAddresses []types.Address
|
|
if tx == nil {
|
|
return accountAddresses, errKeycardDbTransactionIsNil
|
|
}
|
|
|
|
rows, err := tx.Query(`SELECT account_address FROM keycards_accounts WHERE keycard_uid = ?`, kcUID)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
defer rows.Close()
|
|
|
|
for rows.Next() {
|
|
var accAddress types.Address
|
|
err = rows.Scan(&accAddress)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
accountAddresses = append(accountAddresses, accAddress)
|
|
}
|
|
|
|
return accountAddresses, nil
|
|
}
|
|
|
|
func (kp *Keycards) addAccounts(tx *sql.Tx, kcUID string, accountsAddresses []types.Address) (err error) {
|
|
if tx == nil {
|
|
return errKeycardDbTransactionIsNil
|
|
}
|
|
|
|
insertKcAcc, err := tx.Prepare(`
|
|
INSERT INTO
|
|
keycards_accounts
|
|
(
|
|
keycard_uid,
|
|
account_address
|
|
)
|
|
VALUES
|
|
(?, ?);
|
|
`)
|
|
|
|
if err != nil {
|
|
return err
|
|
}
|
|
defer insertKcAcc.Close()
|
|
|
|
for i := range accountsAddresses {
|
|
addr := accountsAddresses[i]
|
|
|
|
_, err = insertKcAcc.Exec(kcUID, addr)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (kp *Keycards) deleteKeycard(tx *sql.Tx, kcUID string) (err error) {
|
|
if tx == nil {
|
|
return errKeycardDbTransactionIsNil
|
|
}
|
|
|
|
delete, err := tx.Prepare(`
|
|
DELETE
|
|
FROM
|
|
keycards
|
|
WHERE
|
|
keycard_uid = ?
|
|
`)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
defer delete.Close()
|
|
|
|
_, err = delete.Exec(kcUID)
|
|
|
|
return err
|
|
}
|
|
|
|
func (kp *Keycards) AddKeycardOrAddAccountsIfKeycardIsAdded(keycard Keycard) (addedKc bool, addedAccs bool, err error) {
|
|
tx, proceed, err := kp.startTransactionAndCheckIfNeedToProceed(keycard.KeycardUID, keycard.LastUpdateClock)
|
|
defer func() {
|
|
if err == nil {
|
|
err = tx.Commit()
|
|
return
|
|
}
|
|
_ = tx.Rollback()
|
|
}()
|
|
|
|
if proceed {
|
|
// insert only if there is no such keycard, otherwise just add accounts
|
|
if err != nil && err == sql.ErrNoRows {
|
|
_, err = tx.Exec(`
|
|
INSERT INTO
|
|
keycards
|
|
(
|
|
keycard_uid,
|
|
keycard_name,
|
|
keycard_locked,
|
|
key_uid,
|
|
last_update_clock
|
|
)
|
|
VALUES
|
|
(?, ?, ?, ?, ?);`,
|
|
keycard.KeycardUID, keycard.KeycardName, keycard.KeycardLocked, keycard.KeyUID, keycard.LastUpdateClock)
|
|
|
|
if err != nil {
|
|
return false, false, err
|
|
}
|
|
|
|
err = kp.addAccounts(tx, keycard.KeycardUID, keycard.AccountsAddresses)
|
|
return err == nil, false, err
|
|
}
|
|
|
|
err = kp.setLastUpdateClock(tx, keycard.KeycardUID, keycard.LastUpdateClock)
|
|
if err != nil {
|
|
return false, false, err
|
|
}
|
|
|
|
err = kp.addAccounts(tx, keycard.KeycardUID, keycard.AccountsAddresses)
|
|
return false, err == nil, err
|
|
}
|
|
|
|
return false, false, err
|
|
}
|
|
|
|
func (kp *Keycards) ApplyKeycardsForKeypairWithKeyUID(keyUID string, keycardsToSync []*Keycard) (err error) {
|
|
tx, err := kp.db.Begin()
|
|
if err != nil {
|
|
return
|
|
}
|
|
defer func() {
|
|
if err == nil {
|
|
err = tx.Commit()
|
|
return
|
|
}
|
|
_ = tx.Rollback()
|
|
}()
|
|
|
|
rows, err := tx.Query(`SELECT * FROM keycards WHERE key_uid = ?`, keyUID)
|
|
if err != nil && err != sql.ErrNoRows {
|
|
return err
|
|
}
|
|
defer rows.Close()
|
|
|
|
var dbKeycards []*Keycard
|
|
for rows.Next() {
|
|
keycard := &Keycard{}
|
|
err := rows.Scan(&keycard.KeycardUID, &keycard.KeycardName, &keycard.KeycardLocked, &keycard.KeyUID,
|
|
&keycard.LastUpdateClock)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
dbKeycards = append(dbKeycards, keycard)
|
|
}
|
|
|
|
// apply those from `keycardsToSync` which are newer
|
|
for _, syncKc := range keycardsToSync {
|
|
foundAtIndex := -1
|
|
for i := range dbKeycards {
|
|
if dbKeycards[i].KeycardUID == syncKc.KeycardUID {
|
|
foundAtIndex = i
|
|
break
|
|
}
|
|
}
|
|
|
|
if foundAtIndex > -1 {
|
|
dbClock := dbKeycards[foundAtIndex].LastUpdateClock
|
|
dbKeycards = removeElementAtIndex(dbKeycards, foundAtIndex)
|
|
|
|
if dbClock > syncKc.LastUpdateClock {
|
|
continue
|
|
}
|
|
err = kp.deleteKeycard(tx, syncKc.KeycardUID)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
}
|
|
|
|
_, err = tx.Exec(`
|
|
INSERT OR REPLACE INTO
|
|
keycards
|
|
(
|
|
keycard_uid,
|
|
keycard_name,
|
|
keycard_locked,
|
|
key_uid,
|
|
last_update_clock
|
|
)
|
|
VALUES
|
|
(?, ?, ?, ?, ?);`,
|
|
syncKc.KeycardUID, syncKc.KeycardName, syncKc.KeycardLocked, syncKc.KeyUID, syncKc.LastUpdateClock)
|
|
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
err = kp.addAccounts(tx, syncKc.KeycardUID, syncKc.AccountsAddresses)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
}
|
|
|
|
// remove those from the db if they are not in `keycardsToSync`
|
|
for _, dbKp := range dbKeycards {
|
|
err = kp.deleteKeycard(tx, dbKp.KeycardUID)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (kp *Keycards) RemoveMigratedAccountsForKeycard(kcUID string, accountAddresses []types.Address,
|
|
clock uint64) (err error) {
|
|
tx, proceed, err := kp.startTransactionAndCheckIfNeedToProceed(kcUID, clock)
|
|
defer func() {
|
|
if err == nil {
|
|
err = tx.Commit()
|
|
return
|
|
}
|
|
_ = tx.Rollback()
|
|
}()
|
|
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
if proceed {
|
|
err = kp.setLastUpdateClock(tx, kcUID, clock)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
dbAccountAddresses, err := kp.getAccountsForKeycard(tx, kcUID)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
deleteKeycard := true
|
|
for _, dbAddr := range dbAccountAddresses {
|
|
found := false
|
|
for _, addr := range accountAddresses {
|
|
if dbAddr == addr {
|
|
found = true
|
|
}
|
|
}
|
|
if !found {
|
|
deleteKeycard = false
|
|
}
|
|
}
|
|
|
|
if deleteKeycard {
|
|
return kp.deleteKeycard(tx, kcUID)
|
|
}
|
|
|
|
inVector := strings.Repeat(",?", len(accountAddresses)-1)
|
|
query := `
|
|
DELETE
|
|
FROM
|
|
keycards_accounts
|
|
WHERE
|
|
keycard_uid = ?
|
|
AND
|
|
account_address IN (?` + inVector + `)
|
|
`
|
|
delete, err := tx.Prepare(query)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
args := make([]interface{}, len(accountAddresses)+1)
|
|
args[0] = kcUID
|
|
for i, addr := range accountAddresses {
|
|
args[i+1] = addr
|
|
}
|
|
|
|
defer delete.Close()
|
|
|
|
_, err = delete.Exec(args...)
|
|
|
|
return err
|
|
}
|
|
|
|
return err
|
|
}
|
|
|
|
func (kp *Keycards) execUpdateQuery(kcUID string, clock uint64, field string, value interface{}) (err error) {
|
|
tx, proceed, err := kp.startTransactionAndCheckIfNeedToProceed(kcUID, clock)
|
|
defer func() {
|
|
if err == nil {
|
|
err = tx.Commit()
|
|
return
|
|
}
|
|
_ = tx.Rollback()
|
|
}()
|
|
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
if proceed {
|
|
sql := fmt.Sprintf(`UPDATE keycards SET %s = ?, last_update_clock = ? WHERE keycard_uid = ?`, field) // nolint: gosec
|
|
_, err = tx.Exec(sql, value, clock, kcUID)
|
|
return err
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (kp *Keycards) KeycardLocked(kcUID string, clock uint64) (err error) {
|
|
return kp.execUpdateQuery(kcUID, clock, "keycard_locked", true)
|
|
}
|
|
|
|
func (kp *Keycards) KeycardUnlocked(kcUID string, clock uint64) (err error) {
|
|
return kp.execUpdateQuery(kcUID, clock, "keycard_locked", false)
|
|
}
|
|
|
|
func (kp *Keycards) UpdateKeycardUID(oldKcUID string, newKcUID string, clock uint64) (err error) {
|
|
return kp.execUpdateQuery(oldKcUID, clock, "keycard_uid", newKcUID)
|
|
}
|
|
|
|
func (kp *Keycards) SetKeycardName(kcUID string, kpName string, clock uint64) (err error) {
|
|
return kp.execUpdateQuery(kcUID, clock, "keycard_name", kpName)
|
|
}
|
|
|
|
func (kp *Keycards) DeleteKeycard(kcUID string, clock uint64) (err error) {
|
|
tx, proceed, err := kp.startTransactionAndCheckIfNeedToProceed(kcUID, clock)
|
|
defer func() {
|
|
if err == nil {
|
|
err = tx.Commit()
|
|
return
|
|
}
|
|
_ = tx.Rollback()
|
|
}()
|
|
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
if proceed {
|
|
return kp.deleteKeycard(tx, kcUID)
|
|
}
|
|
|
|
return err
|
|
}
|
|
|
|
func (kp *Keycards) DeleteAllKeycardsWithKeyUID(keyUID string) (err error) {
|
|
delete, err := kp.db.Prepare(`
|
|
DELETE
|
|
FROM
|
|
keycards
|
|
WHERE
|
|
key_uid = ?
|
|
`)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
defer delete.Close()
|
|
_, err = delete.Exec(keyUID)
|
|
return err
|
|
}
|