package accounts import ( "database/sql" "encoding/json" "errors" "fmt" "strconv" "strings" "github.com/status-im/status-go/eth-node/types" "github.com/status-im/status-go/multiaccounts/common" "github.com/status-im/status-go/multiaccounts/settings" notificationssettings "github.com/status-im/status-go/multiaccounts/settings_notifications" sociallinkssettings "github.com/status-im/status-go/multiaccounts/settings_social_links" "github.com/status-im/status-go/nodecfg" "github.com/status-im/status-go/params" ) const ( statusChatPath = "m/43'/60'/1581'/0'/0" statusWalletRootPath = "m/44'/60'/0'/0/" zeroAddress = "0x0000000000000000000000000000000000000000" SyncedFromBackup = "backup" // means a account is coming from backed up data SyncedFromLocalPairing = "local-pairing" // means a account is coming from another device when user is reocovering Status account ) var ( errDbPassedParameterIsNil = errors.New("accounts: passed parameter is nil") errDbTransactionIsNil = errors.New("accounts: database transaction is nil") ErrDbKeypairNotFound = errors.New("accounts: keypair is not found") ErrDbAccountNotFound = errors.New("accounts: account is not found") ErrAccountWrongPosition = errors.New("accounts: trying to set wrong position to account") ErrNotTheSameNumberOdAccountsToApplyReordering = errors.New("accounts: there is different number of accounts between received sync message and db accounts") ErrNotTheSameAccountsToApplyReordering = errors.New("accounts: there are differences between accounts in received sync message and db accounts") ErrMovingAccountToWrongPosition = errors.New("accounts: trying to move account to a wrong position") ErrKeypairDifferentAccountsKeyUID = errors.New("cannot store keypair with different accounts' key uid than keypair's key uid") ErrKeypairWithoutAccounts = errors.New("cannot store keypair without accounts") ) type Keypair struct { KeyUID string `json:"key-uid"` Name string `json:"name"` Type KeypairType `json:"type"` DerivedFrom string `json:"derived-from"` LastUsedDerivationIndex uint64 `json:"last-used-derivation-index,omitempty"` SyncedFrom string `json:"synced-from,omitempty"` // keeps an info which device this keypair is added from can be one of two values defined in constants or device name (custom) Clock uint64 `json:"clock,omitempty"` Accounts []*Account `json:"accounts,omitempty"` Keycards []*Keycard `json:"keycards,omitempty"` Removed bool `json:"removed,omitempty"` } type Account struct { Address types.Address `json:"address"` KeyUID string `json:"key-uid"` Wallet bool `json:"wallet"` Chat bool `json:"chat"` Type AccountType `json:"type,omitempty"` Path string `json:"path,omitempty"` PublicKey types.HexBytes `json:"public-key,omitempty"` Name string `json:"name"` Emoji string `json:"emoji"` ColorID common.CustomizationColor `json:"colorId,omitempty"` Hidden bool `json:"hidden"` Clock uint64 `json:"clock,omitempty"` Removed bool `json:"removed,omitempty"` Operable AccountOperable `json:"operable"` // describes an account's operability (read an explanation at the top of this file) CreatedAt int64 `json:"createdAt"` Position int64 `json:"position"` } type KeypairType string type AccountType string type AccountOperable string func (a KeypairType) String() string { return string(a) } func (a AccountType) String() string { return string(a) } func (a AccountOperable) String() string { return string(a) } const ( KeypairTypeProfile KeypairType = "profile" KeypairTypeKey KeypairType = "key" KeypairTypeSeed KeypairType = "seed" ) const ( AccountTypeGenerated AccountType = "generated" AccountTypeKey AccountType = "key" AccountTypeSeed AccountType = "seed" AccountTypeWatch AccountType = "watch" ) const ( AccountNonOperable AccountOperable = "no" // an account is non operable it is not a keycard account and there is no keystore file for it and no keystore file for the address it is derived from AccountPartiallyOperable AccountOperable = "partially" // an account is partially operable if it is not a keycard account and there is created keystore file for the address it is derived from AccountFullyOperable AccountOperable = "fully" // an account is fully operable if it is not a keycard account and there is a keystore file for it ) // IsOwnAccount returns true if this is an account we have the private key for // NOTE: Wallet flag can't be used as it actually indicates that it's the default // Wallet func (a *Account) IsOwnAccount() bool { return a.Wallet || a.Type == AccountTypeSeed || a.Type == AccountTypeGenerated || a.Type == AccountTypeKey } func (a *Account) MarshalJSON() ([]byte, error) { item := struct { Address types.Address `json:"address"` MixedcaseAddress string `json:"mixedcase-address"` KeyUID string `json:"key-uid"` Wallet bool `json:"wallet"` Chat bool `json:"chat"` Type AccountType `json:"type"` Path string `json:"path"` PublicKey types.HexBytes `json:"public-key"` Name string `json:"name"` Emoji string `json:"emoji"` ColorID common.CustomizationColor `json:"colorId"` Hidden bool `json:"hidden"` Clock uint64 `json:"clock"` Removed bool `json:"removed"` Operable AccountOperable `json:"operable"` CreatedAt int64 `json:"createdAt"` Position int64 `json:"position"` }{ Address: a.Address, MixedcaseAddress: a.Address.Hex(), KeyUID: a.KeyUID, Wallet: a.Wallet, Chat: a.Chat, Type: a.Type, Path: a.Path, PublicKey: a.PublicKey, Name: a.Name, Emoji: a.Emoji, ColorID: a.ColorID, Hidden: a.Hidden, Clock: a.Clock, Removed: a.Removed, Operable: a.Operable, CreatedAt: a.CreatedAt, Position: a.Position, } return json.Marshal(item) } func (a *Keypair) MarshalJSON() ([]byte, error) { item := struct { KeyUID string `json:"key-uid"` Name string `json:"name"` Type KeypairType `json:"type"` DerivedFrom string `json:"derived-from"` LastUsedDerivationIndex uint64 `json:"last-used-derivation-index"` SyncedFrom string `json:"synced-from"` Clock uint64 `json:"clock"` Accounts []*Account `json:"accounts"` Keycards []*Keycard `json:"keycards"` Removed bool `json:"removed"` }{ KeyUID: a.KeyUID, Name: a.Name, Type: a.Type, DerivedFrom: a.DerivedFrom, LastUsedDerivationIndex: a.LastUsedDerivationIndex, SyncedFrom: a.SyncedFrom, Clock: a.Clock, Accounts: a.Accounts, Keycards: a.Keycards, Removed: a.Removed, } return json.Marshal(item) } func (a *Keypair) CopyKeypair() *Keypair { kp := &Keypair{ Clock: a.Clock, KeyUID: a.KeyUID, Name: a.Name, Type: a.Type, DerivedFrom: a.DerivedFrom, LastUsedDerivationIndex: a.LastUsedDerivationIndex, SyncedFrom: a.SyncedFrom, Accounts: make([]*Account, len(a.Accounts)), Keycards: make([]*Keycard, len(a.Keycards)), Removed: a.Removed, } for i, acc := range a.Accounts { kp.Accounts[i] = &Account{ Address: acc.Address, KeyUID: acc.KeyUID, Wallet: acc.Wallet, Chat: acc.Chat, Type: acc.Type, Path: acc.Path, PublicKey: acc.PublicKey, Name: acc.Name, Emoji: acc.Emoji, ColorID: acc.ColorID, Hidden: acc.Hidden, Clock: acc.Clock, Removed: acc.Removed, Operable: acc.Operable, CreatedAt: acc.CreatedAt, Position: acc.Position, } } for i, kc := range a.Keycards { kp.Keycards[i] = &Keycard{ KeycardUID: kc.KeycardUID, KeycardName: kc.KeycardName, KeycardLocked: kc.KeycardLocked, AccountsAddresses: kc.AccountsAddresses, KeyUID: kc.KeyUID, } } return kp } func (a *Keypair) GetChatPublicKey() types.HexBytes { for _, acc := range a.Accounts { if acc.Chat { return acc.PublicKey } } return nil } // Database sql wrapper for operations with browser objects. type Database struct { *settings.Database *notificationssettings.NotificationsSettings *sociallinkssettings.SocialLinksSettings db *sql.DB } // NewDB returns a new instance of *Database func NewDB(db *sql.DB) (*Database, error) { sDB, err := settings.MakeNewDB(db) if err != nil { return nil, err } sn := notificationssettings.NewNotificationsSettings(db) ssl := sociallinkssettings.NewSocialLinksSettings(db) return &Database{sDB, sn, ssl, db}, nil } // DB Gets db sql.DB func (db *Database) DB() *sql.DB { return db.db } // Close closes database. func (db *Database) Close() error { return db.db.Close() } func GetAccountTypeForKeypairType(kpType KeypairType) AccountType { switch kpType { case KeypairTypeProfile: return AccountTypeGenerated case KeypairTypeKey: return AccountTypeKey case KeypairTypeSeed: return AccountTypeSeed default: return AccountTypeWatch } } func (db *Database) processRows(rows *sql.Rows) ([]*Keypair, []*Account, error) { keypairMap := make(map[string]*Keypair) allAccounts := []*Account{} var ( kpKeyUID sql.NullString kpName sql.NullString kpType sql.NullString kpDerivedFrom sql.NullString kpLastUsedDerivationIndex sql.NullInt64 kpSyncedFrom sql.NullString kpClock sql.NullInt64 ) var ( accAddress sql.NullString accKeyUID sql.NullString accPath sql.NullString accName sql.NullString accColorID sql.NullString accEmoji sql.NullString accWallet sql.NullBool accChat sql.NullBool accHidden sql.NullBool accOperable sql.NullString accClock sql.NullInt64 accCreatedAt sql.NullTime accPosition sql.NullInt64 ) for rows.Next() { kp := &Keypair{} acc := &Account{} pubkey := []byte{} err := rows.Scan( &kpKeyUID, &kpName, &kpType, &kpDerivedFrom, &kpLastUsedDerivationIndex, &kpSyncedFrom, &kpClock, &accAddress, &accKeyUID, &pubkey, &accPath, &accName, &accColorID, &accEmoji, &accWallet, &accChat, &accHidden, &accOperable, &accClock, &accCreatedAt, &accPosition) if err != nil { return nil, nil, err } // check keypair fields if kpKeyUID.Valid { kp.KeyUID = kpKeyUID.String } if kpName.Valid { kp.Name = kpName.String } if kpType.Valid { kp.Type = KeypairType(kpType.String) } if kpDerivedFrom.Valid { kp.DerivedFrom = kpDerivedFrom.String } if kpLastUsedDerivationIndex.Valid { kp.LastUsedDerivationIndex = uint64(kpLastUsedDerivationIndex.Int64) } if kpSyncedFrom.Valid { kp.SyncedFrom = kpSyncedFrom.String } if kpClock.Valid { kp.Clock = uint64(kpClock.Int64) } // check keypair accounts fields if accAddress.Valid { acc.Address = types.BytesToAddress([]byte(accAddress.String)) } if accKeyUID.Valid { acc.KeyUID = accKeyUID.String } if accPath.Valid { acc.Path = accPath.String } if accName.Valid { acc.Name = accName.String } if accColorID.Valid { acc.ColorID = common.CustomizationColor(accColorID.String) } if accEmoji.Valid { acc.Emoji = accEmoji.String } if accWallet.Valid { acc.Wallet = accWallet.Bool } if accChat.Valid { acc.Chat = accChat.Bool } if accHidden.Valid { acc.Hidden = accHidden.Bool } if accOperable.Valid { acc.Operable = AccountOperable(accOperable.String) } if accClock.Valid { acc.Clock = uint64(accClock.Int64) } if accCreatedAt.Valid { acc.CreatedAt = accCreatedAt.Time.UnixMilli() } if accPosition.Valid { acc.Position = accPosition.Int64 } if lth := len(pubkey); lth > 0 { acc.PublicKey = make(types.HexBytes, lth) copy(acc.PublicKey, pubkey) } acc.Type = GetAccountTypeForKeypairType(kp.Type) if kp.KeyUID != "" { if _, ok := keypairMap[kp.KeyUID]; !ok { keypairMap[kp.KeyUID] = kp } keypairMap[kp.KeyUID].Accounts = append(keypairMap[kp.KeyUID].Accounts, acc) } allAccounts = append(allAccounts, acc) } if err := rows.Err(); err != nil { return nil, nil, err } // Convert map to list keypairs := make([]*Keypair, 0, len(keypairMap)) for _, keypair := range keypairMap { keypairs = append(keypairs, keypair) } return keypairs, allAccounts, nil } // If `keyUID` is passed only keypairs which match the passed `keyUID` will be returned, if `keyUID` is empty, all keypairs will be returned. func (db *Database) getKeypairs(tx *sql.Tx, keyUID string) ([]*Keypair, error) { var ( rows *sql.Rows err error where string ) if tx == nil { tx, err = db.db.Begin() defer func() { if err == nil { err = tx.Commit() return } _ = tx.Rollback() }() if err != nil { return nil, err } } if keyUID != "" { where = "WHERE k.key_uid = ?" } query := fmt.Sprintf( // nolint: gosec ` SELECT k.*, ka.address, ka.key_uid, ka.pubkey, ka.path, ka.name, ka.color, ka.emoji, ka.wallet, ka.chat, ka.hidden, ka.operable, ka.clock, ka.created_at, ka.position FROM keypairs k LEFT JOIN keypairs_accounts ka ON k.key_uid = ka.key_uid %s ORDER BY ka.position`, where) stmt, err := tx.Prepare(query) if err != nil { return nil, err } defer stmt.Close() if where != "" { rows, err = stmt.Query(keyUID) } else { rows, err = stmt.Query() } if err != nil { return nil, err } defer rows.Close() keypairs, _, err := db.processRows(rows) if err != nil { return nil, err } for _, kp := range keypairs { keycards, err := db.getKeycards(tx, kp.KeyUID, "") if err != nil { return nil, err } kp.Keycards = keycards } return keypairs, nil } func (db *Database) getKeypairByKeyUID(tx *sql.Tx, keyUID string) (*Keypair, error) { keypairs, err := db.getKeypairs(tx, keyUID) if err != nil && err != sql.ErrNoRows { return nil, err } if len(keypairs) == 0 { return nil, ErrDbKeypairNotFound } return keypairs[0], nil } // If `address` is passed only accounts which match the passed `address` will be returned, if `address` is empty, all accounts will be returned. func (db *Database) getAccounts(tx *sql.Tx, address types.Address) ([]*Account, error) { var ( rows *sql.Rows err error where string ) if address.String() != zeroAddress { where = "WHERE ka.address = ?" } query := fmt.Sprintf( // nolint: gosec ` SELECT k.*, ka.address, ka.key_uid, ka.pubkey, ka.path, ka.name, ka.color, ka.emoji, ka.wallet, ka.chat, ka.hidden, ka.operable, ka.clock, ka.created_at, ka.position FROM keypairs_accounts ka LEFT JOIN keypairs k ON ka.key_uid = k.key_uid %s ORDER BY ka.position`, where) if tx == nil { if where != "" { rows, err = db.db.Query(query, address) } else { rows, err = db.db.Query(query) } if err != nil { return nil, err } } else { stmt, err := tx.Prepare(query) if err != nil { return nil, err } defer stmt.Close() if where != "" { rows, err = stmt.Query(address) } else { rows, err = stmt.Query() } if err != nil { return nil, err } } defer rows.Close() _, allAccounts, err := db.processRows(rows) if err != nil { return nil, err } return allAccounts, nil } func (db *Database) getAccountByAddress(tx *sql.Tx, address types.Address) (*Account, error) { accounts, err := db.getAccounts(tx, address) if err != nil && err != sql.ErrNoRows { return nil, err } if len(accounts) == 0 { return nil, ErrDbAccountNotFound } return accounts[0], nil } func (db *Database) deleteKeypair(tx *sql.Tx, keyUID string) error { keypairs, err := db.getKeypairs(tx, keyUID) if err != nil && err != sql.ErrNoRows { return err } if len(keypairs) == 0 { return ErrDbKeypairNotFound } query := ` DELETE FROM keypairs WHERE key_uid = ? ` if tx == nil { _, err := db.db.Exec(query, keyUID) return err } stmt, err := tx.Prepare(query) if err != nil { return err } defer stmt.Close() _, err = stmt.Exec(keyUID) return err } func (db *Database) GetKeypairs() ([]*Keypair, error) { return db.getKeypairs(nil, "") } func (db *Database) GetKeypairByKeyUID(keyUID string) (*Keypair, error) { return db.getKeypairByKeyUID(nil, keyUID) } func (db *Database) GetAccounts() ([]*Account, error) { return db.getAccounts(nil, types.Address{}) } func (db *Database) GetAccountByAddress(address types.Address) (*Account, error) { return db.getAccountByAddress(nil, address) } func (db *Database) GetWatchOnlyAccounts() (res []*Account, err error) { accounts, err := db.getAccounts(nil, types.Address{}) if err != nil { return nil, err } for _, acc := range accounts { if acc.Type == AccountTypeWatch { res = append(res, acc) } } return } func (db *Database) IsAnyAccountPartiallyOrFullyOperableForKeyUID(keyUID string) (bool, error) { kp, err := db.getKeypairByKeyUID(nil, keyUID) if err != nil { return false, err } for _, acc := range kp.Accounts { if acc.Operable != AccountNonOperable { return true, nil } } return false, nil } func (db *Database) DeleteKeypair(keyUID string) error { return db.deleteKeypair(nil, keyUID) } func (db *Database) DeleteAccount(address types.Address, clock uint64) error { tx, err := db.db.Begin() defer func() { if err == nil { err = tx.Commit() return } _ = tx.Rollback() }() if err != nil { return err } acc, err := db.getAccountByAddress(tx, address) if err != nil { return err } kp, err := db.getKeypairByKeyUID(tx, acc.KeyUID) if err != nil && err != ErrDbKeypairNotFound { return err } if kp != nil && len(kp.Accounts) == 1 && kp.Accounts[0].Address == address { return db.deleteKeypair(tx, acc.KeyUID) } delete, err := tx.Prepare(` DELETE FROM keypairs_accounts WHERE address = ? `) if err != nil { return err } defer delete.Close() _, err = delete.Exec(address) if err != nil { return err } // Update keypair clock if any but the watch only account was deleted. if kp != nil { err = db.updateKeypairClock(tx, acc.KeyUID, clock) return err } return nil } func updateKeypairLastUsedIndex(tx *sql.Tx, keyUID string, index uint64, clock uint64, updateKeypairClock bool) error { if tx == nil { return errDbTransactionIsNil } var ( err error setClock string ) if updateKeypairClock { setClock = ", clock = ?" } query := fmt.Sprintf( // nolint: gosec ` UPDATE keypairs SET last_used_derivation_index = ? %s WHERE key_uid = ?`, setClock) if setClock != "" { _, err = tx.Exec(query, index, clock, keyUID) } else { _, err = tx.Exec(query, index, keyUID) } return err } func (db *Database) updateKeypairClock(tx *sql.Tx, keyUID string, clock uint64) error { if tx == nil { return errDbTransactionIsNil } _, err := tx.Exec(` UPDATE keypairs SET clock = ? WHERE key_uid = ?`, clock, keyUID) return err } func (db *Database) saveOrUpdateAccounts(tx *sql.Tx, accounts []*Account, updateKeypairClock bool) (err error) { if tx == nil { return errDbTransactionIsNil } for _, acc := range accounts { var relatedKeypair *Keypair // only watch only accounts have an empty `KeyUID` field var keyUID *string if acc.KeyUID != "" { relatedKeypair, err = db.getKeypairByKeyUID(tx, acc.KeyUID) if err != nil { if err == sql.ErrNoRows { // all accounts, except watch only accounts, must have a row in `keypairs` table with the same key uid continue } return err } keyUID = &acc.KeyUID } _, err = tx.Exec(` INSERT OR IGNORE INTO keypairs_accounts (address, key_uid, pubkey, path, wallet, chat, created_at, updated_at) VALUES (?, ?, ?, ?, ?, ?, datetime('now'), datetime('now')); UPDATE keypairs_accounts SET name = ?, color = ?, emoji = ?, hidden = ?, operable = ?, clock = ?, position = ?, updated_at = datetime('now') WHERE address = ?; `, acc.Address, keyUID, acc.PublicKey, acc.Path, acc.Wallet, acc.Chat, acc.Name, acc.ColorID, acc.Emoji, acc.Hidden, acc.Operable, acc.Clock, acc.Position, acc.Address) if err != nil { return err } // Update positions change clock when adding new/updating account err = db.setClockOfLastAccountsPositionChange(tx, acc.Clock) if err != nil { return err } // Update keypair clock if any but the watch only account has changed. if relatedKeypair != nil && updateKeypairClock { err = db.updateKeypairClock(tx, acc.KeyUID, acc.Clock) if err != nil { return err } } if strings.HasPrefix(acc.Path, statusWalletRootPath) { accIndex, err := strconv.ParseUint(acc.Path[len(statusWalletRootPath):], 0, 64) if err != nil { return err } accountsContainPath := func(accounts []*Account, path string) bool { for _, acc := range accounts { if acc.Path == path { return true } } return false } expectedNewKeypairIndex := uint64(0) if relatedKeypair != nil { expectedNewKeypairIndex = relatedKeypair.LastUsedDerivationIndex for { expectedNewKeypairIndex++ if !accountsContainPath(relatedKeypair.Accounts, statusWalletRootPath+strconv.FormatUint(expectedNewKeypairIndex, 10)) { break } } } if accIndex == expectedNewKeypairIndex { err = updateKeypairLastUsedIndex(tx, acc.KeyUID, accIndex, acc.Clock, updateKeypairClock) if err != nil { return err } } } } return nil } // Saves accounts, if an account already exists, it will be updated. func (db *Database) SaveOrUpdateAccounts(accounts []*Account, updateKeypairClock bool) error { if len(accounts) == 0 { return errors.New("no provided accounts to save/update") } tx, err := db.db.Begin() if err != nil { return err } defer func() { if err == nil { err = tx.Commit() return } _ = tx.Rollback() }() err = db.saveOrUpdateAccounts(tx, accounts, updateKeypairClock) return err } // Saves a keypair and its accounts, if a keypair with `key_uid` already exists, it will be updated, // if any of its accounts exists it will be updated as well, otherwise it will be added. // Since keypair type contains `Keycards` as well, they are excluded from the saving/updating this way regardless they // are set or not. func (db *Database) SaveOrUpdateKeypair(keypair *Keypair) error { if keypair == nil { return errDbPassedParameterIsNil } tx, err := db.db.Begin() if err != nil { return err } defer func() { if err == nil { err = tx.Commit() return } _ = tx.Rollback() }() // If keypair is being saved, not updated, then it must be at least one account and all accounts must have the same key uid. dbKeypair, err := db.getKeypairByKeyUID(tx, keypair.KeyUID) if err != nil && err != ErrDbKeypairNotFound { return err } if dbKeypair == nil { if len(keypair.Accounts) == 0 { return ErrKeypairWithoutAccounts } for _, acc := range keypair.Accounts { if acc.KeyUID == "" || acc.KeyUID != keypair.KeyUID { return ErrKeypairDifferentAccountsKeyUID } } } _, err = tx.Exec(` INSERT OR IGNORE INTO keypairs (key_uid, type, derived_from) VALUES (?, ?, ?); UPDATE keypairs SET name = ?, last_used_derivation_index = ?, synced_from = ?, clock = ? WHERE key_uid = ?; `, keypair.KeyUID, keypair.Type, keypair.DerivedFrom, keypair.Name, keypair.LastUsedDerivationIndex, keypair.SyncedFrom, keypair.Clock, keypair.KeyUID) if err != nil { return err } return db.saveOrUpdateAccounts(tx, keypair.Accounts, false) } func (db *Database) UpdateKeypairName(keyUID string, name string, clock uint64, updateChatAccountName bool) error { tx, err := db.db.Begin() if err != nil { return err } defer func() { if err == nil { err = tx.Commit() return } _ = tx.Rollback() }() _, err = db.getKeypairByKeyUID(tx, keyUID) if err != nil { return err } _, err = tx.Exec(` UPDATE keypairs SET name = ?, clock = ? WHERE key_uid = ?; `, name, clock, keyUID) if err != nil { return err } if updateChatAccountName { _, err = tx.Exec(` UPDATE keypairs_accounts SET name = ?, clock = ? WHERE key_uid = ? AND path = ?; `, name, clock, keyUID, statusChatPath) return err } return nil } func (db *Database) GetWalletAddress() (rst types.Address, err error) { err = db.db.QueryRow("SELECT address FROM keypairs_accounts WHERE wallet = 1").Scan(&rst) return } func (db *Database) GetWalletAddresses() (rst []types.Address, err error) { rows, err := db.db.Query("SELECT address FROM keypairs_accounts WHERE chat = 0 ORDER BY created_at") if err != nil { return nil, err } defer rows.Close() for rows.Next() { addr := types.Address{} err = rows.Scan(&addr) if err != nil { return nil, err } rst = append(rst, addr) } if err := rows.Err(); err != nil { return nil, err } return rst, nil } func (db *Database) GetChatAddress() (rst types.Address, err error) { err = db.db.QueryRow("SELECT address FROM keypairs_accounts WHERE chat = 1").Scan(&rst) return } func (db *Database) GetAddresses() (rst []types.Address, err error) { rows, err := db.db.Query("SELECT address FROM keypairs_accounts ORDER BY created_at") if err != nil { return nil, err } defer rows.Close() for rows.Next() { addr := types.Address{} err = rows.Scan(&addr) if err != nil { return nil, err } rst = append(rst, addr) } if err := rows.Err(); err != nil { return nil, err } return rst, nil } func (db *Database) keypairExists(tx *sql.Tx, keyUID string) (exists bool, err error) { query := `SELECT EXISTS (SELECT 1 FROM keypairs WHERE key_uid = ?)` if tx == nil { err = db.db.QueryRow(query, keyUID).Scan(&exists) } else { err = tx.QueryRow(query, keyUID).Scan(&exists) } return exists, err } // KeypairExists returns true if given address is stored in database. func (db *Database) KeypairExists(keyUID string) (exists bool, err error) { return db.keypairExists(nil, keyUID) } // AddressExists returns true if given address is stored in database. func (db *Database) AddressExists(address types.Address) (exists bool, err error) { err = db.db.QueryRow("SELECT EXISTS (SELECT 1 FROM keypairs_accounts WHERE address = ?)", address).Scan(&exists) return exists, err } // GetPath returns true if account with given address was recently key and doesn't have a key yet func (db *Database) GetPath(address types.Address) (path string, err error) { err = db.db.QueryRow("SELECT path FROM keypairs_accounts WHERE address = ?", address).Scan(&path) return path, err } func (db *Database) GetNodeConfig() (*params.NodeConfig, error) { return nodecfg.GetNodeConfigFromDB(db.db) } // This function should not update the clock, cause it marks accounts locally. func (db *Database) SetAccountOperability(address types.Address, operable AccountOperable) (err error) { tx, err := db.db.Begin() defer func() { if err == nil { err = tx.Commit() return } _ = tx.Rollback() }() if err != nil { return err } _, err = db.getAccountByAddress(tx, address) if err != nil { return err } _, err = tx.Exec(`UPDATE keypairs_accounts SET operable = ? WHERE address = ?`, operable, address) return err } func (db *Database) GetPositionForNextNewAccount() (int64, error) { var pos sql.NullInt64 err := db.db.QueryRow("SELECT MAX(position) FROM keypairs_accounts").Scan(&pos) if err != nil { return 0, err } if pos.Valid { return pos.Int64 + 1, nil } return 0, nil } // This function should not be used directly, it is called from the functions which reorders accounts. func (db *Database) setClockOfLastAccountsPositionChange(tx *sql.Tx, clock uint64) error { if tx == nil { return nil } _, err := tx.Exec("UPDATE settings SET wallet_accounts_position_change_clock = ? WHERE synthetic_id = 'id'", clock) return err } func (db *Database) GetClockOfLastAccountsPositionChange() (result uint64, err error) { query := "SELECT wallet_accounts_position_change_clock FROM settings WHERE synthetic_id = 'id'" err = db.db.QueryRow(query).Scan(&result) if err != nil { return 0, err } return result, err } // Updates positions of accounts respecting current order. func (db *Database) ResolveAccountsPositions(clock uint64) (err error) { tx, err := db.db.Begin() defer func() { if err == nil { err = tx.Commit() return } _ = tx.Rollback() }() // returns all accounts ordered by position dbAccounts, err := db.getAccounts(tx, types.Address{}) if err != nil { return err } // starting from -1, cause `getAccounts` returns chat account as well for i := 0; i < len(dbAccounts); i++ { expectedPosition := int64(i - 1) if dbAccounts[i].Position != expectedPosition { _, err = tx.Exec("UPDATE keypairs_accounts SET position = ? WHERE address = ?", expectedPosition, dbAccounts[i].Address) if err != nil { return err } } } return db.setClockOfLastAccountsPositionChange(tx, clock) } // Sets positions for passed accounts. func (db *Database) SetWalletAccountsPositions(accounts []*Account, clock uint64) (err error) { if len(accounts) == 0 { return nil } for _, acc := range accounts { if acc.Position < 0 { return ErrAccountWrongPosition } } tx, err := db.db.Begin() defer func() { if err == nil { err = tx.Commit() return } _ = tx.Rollback() }() dbAccounts, err := db.getAccounts(tx, types.Address{}) if err != nil { return err } // we need to subtract 1, because of the chat account if len(dbAccounts)-1 != len(accounts) { return ErrNotTheSameNumberOdAccountsToApplyReordering } for _, dbAcc := range dbAccounts { if dbAcc.Chat { continue } found := false for _, acc := range accounts { if dbAcc.Address == acc.Address { found = true break } } if !found { return ErrNotTheSameAccountsToApplyReordering } } for _, acc := range accounts { _, err = tx.Exec("UPDATE keypairs_accounts SET position = ? WHERE address = ?", acc.Position, acc.Address) if err != nil { return err } } return db.setClockOfLastAccountsPositionChange(tx, clock) } // Moves wallet account fromPosition to toPosition. func (db *Database) MoveWalletAccount(fromPosition int64, toPosition int64, clock uint64) (err error) { if fromPosition < 0 || toPosition < 0 || fromPosition == toPosition { return ErrMovingAccountToWrongPosition } tx, err := db.db.Begin() defer func() { if err == nil { err = tx.Commit() return } _ = tx.Rollback() }() var ( newMaxPosition int64 newMinPosition int64 ) err = tx.QueryRow("SELECT MAX(position), MIN(position) FROM keypairs_accounts").Scan(&newMaxPosition, &newMinPosition) if err != nil { return err } newMaxPosition++ newMinPosition-- if toPosition > fromPosition { _, err = tx.Exec("UPDATE keypairs_accounts SET position = ? WHERE position = ?", newMaxPosition, fromPosition) if err != nil { return err } for i := fromPosition + 1; i <= toPosition; i++ { _, err = tx.Exec("UPDATE keypairs_accounts SET position = ? WHERE position = ?", i-1, i) if err != nil { return err } } _, err = tx.Exec("UPDATE keypairs_accounts SET position = ? WHERE position = ?", toPosition, newMaxPosition) if err != nil { return err } } else { _, err = tx.Exec("UPDATE keypairs_accounts SET position = ? WHERE position = ?", newMinPosition, fromPosition) if err != nil { return err } for i := fromPosition - 1; i >= toPosition; i-- { _, err = tx.Exec("UPDATE keypairs_accounts SET position = ? WHERE position = ?", i+1, i) if err != nil { return err } } _, err = tx.Exec("UPDATE keypairs_accounts SET position = ? WHERE position = ?", toPosition, newMinPosition) if err != nil { return err } } return db.setClockOfLastAccountsPositionChange(tx, clock) }