From 4c050f35dd1452c6d09709445efca40bb9fa717d Mon Sep 17 00:00:00 2001 From: Sergio Chouhy Date: Thu, 16 Apr 2026 03:03:13 -0300 Subject: [PATCH] refactor private key tree to store a vec<(Identifier, AccountId)> --- .../tests/auth_transfer/private.rs | 2 - integration_tests/tests/keys_restoration.rs | 9 +- integration_tests/tests/token.rs | 1 - .../key_management/key_tree/keys_private.rs | 30 ++-- .../src/key_management/key_tree/mod.rs | 136 ++++++++++++------ key_protocol/src/key_management/mod.rs | 5 +- key_protocol/src/key_protocol_core/mod.rs | 77 +++++----- wallet-ffi/src/account.rs | 11 +- wallet/configs/debug/wallet_config.json | 6 +- wallet/src/chain_storage.rs | 107 +++++++++++--- wallet/src/cli/account.rs | 20 +-- wallet/src/helperfunctions.rs | 22 +-- wallet/src/lib.rs | 20 +-- wallet/src/privacy_preserving_tx.rs | 1 - 14 files changed, 289 insertions(+), 158 deletions(-) diff --git a/integration_tests/tests/auth_transfer/private.rs b/integration_tests/tests/auth_transfer/private.rs index eac73322..f3e70c09 100644 --- a/integration_tests/tests/auth_transfer/private.rs +++ b/integration_tests/tests/auth_transfer/private.rs @@ -178,7 +178,6 @@ async fn private_transfer_to_owned_account_using_claiming_path() -> Result<()> { .storage() .user_data .get_private_account(to_account_id) - .cloned() .context("Failed to get private account")?; // Send to this account using claiming path (using npk and vpk instead of account ID) @@ -347,7 +346,6 @@ async fn private_transfer_to_owned_account_continuous_run_path() -> Result<()> { .storage() .user_data .get_private_account(to_account_id) - .cloned() .context("Failed to get private account")?; // Send transfer using nullifier and viewing public keys diff --git a/integration_tests/tests/keys_restoration.rs b/integration_tests/tests/keys_restoration.rs index 704a133d..2588e7d9 100644 --- a/integration_tests/tests/keys_restoration.rs +++ b/integration_tests/tests/keys_restoration.rs @@ -64,7 +64,6 @@ async fn sync_private_account_with_non_zero_chain_index() -> Result<()> { .storage() .user_data .get_private_account(to_account_id) - .cloned() .context("Failed to get private account")?; // Send to this account using claiming path (using npk and vpk instead of account ID) @@ -264,16 +263,16 @@ async fn restore_keys_from_seed() -> Result<()> { .expect("Acc 4 should be restored"); assert_eq!( - acc1.value.1.program_owner, + acc1.value.1[0].1.program_owner, Program::authenticated_transfer_program().id() ); assert_eq!( - acc2.value.1.program_owner, + acc2.value.1[0].1.program_owner, Program::authenticated_transfer_program().id() ); - assert_eq!(acc1.value.1.balance, 100); - assert_eq!(acc2.value.1.balance, 101); + assert_eq!(acc1.value.1[0].1.balance, 100); + assert_eq!(acc2.value.1[0].1.balance, 101); info!("Tree checks passed, testing restored accounts can transact"); diff --git a/integration_tests/tests/token.rs b/integration_tests/tests/token.rs index 66e5cb84..5ee25a75 100644 --- a/integration_tests/tests/token.rs +++ b/integration_tests/tests/token.rs @@ -1144,7 +1144,6 @@ async fn token_claiming_path_with_private_accounts() -> Result<()> { .storage() .user_data .get_private_account(recipient_account_id) - .cloned() .context("Failed to get private account keys")?; // Mint using claiming path (foreign account) diff --git a/key_protocol/src/key_management/key_tree/keys_private.rs b/key_protocol/src/key_management/key_tree/keys_private.rs index 298adc08..acd25b6b 100644 --- a/key_protocol/src/key_management/key_tree/keys_private.rs +++ b/key_protocol/src/key_management/key_tree/keys_private.rs @@ -10,7 +10,7 @@ use crate::key_management::{ #[derive(Debug, Serialize, Deserialize, Clone)] pub struct ChildKeysPrivate { - pub value: (KeyChain, nssa::Account, Identifier), + pub value: (KeyChain, Vec<(Identifier, nssa::Account)>), pub ccc: [u8; 32], /// Can be [`None`] if root. pub cci: Option, @@ -46,15 +46,14 @@ impl ChildKeysPrivate { viewing_secret_key: vsk, }, }, - nssa::Account::default(), - 0, + vec![], ), ccc, cci: None, } } - pub fn nth_child(&self, cci: u32, identifier: Identifier) -> Self { + pub fn nth_child(&self, cci: u32) -> Self { #[expect(clippy::arithmetic_side_effects, reason = "TODO: fix later")] let parent_pt = Scalar::from_repr(self.value.0.private_key_holder.nullifier_secret_key.into()) @@ -96,8 +95,7 @@ impl ChildKeysPrivate { viewing_secret_key: vsk, }, }, - nssa::Account::default(), - identifier, + vec![], ), ccc, cci: Some(cci), @@ -111,17 +109,13 @@ impl ChildKeysPrivate { pub fn child_index(&self) -> Option { self.cci } - - pub fn account_id(&self) -> nssa::AccountId { - nssa::AccountId::from((&self.value.0.nullifier_public_key, self.value.2)) - } } #[expect( clippy::single_char_lifetime_names, reason = "TODO add meaningful name" )] -impl<'a> From<&'a ChildKeysPrivate> for &'a (KeyChain, nssa::Account, Identifier) { +impl<'a> From<&'a ChildKeysPrivate> for &'a (KeyChain, Vec<(Identifier, nssa::Account)>) { fn from(value: &'a ChildKeysPrivate) -> Self { &value.value } @@ -131,7 +125,7 @@ impl<'a> From<&'a ChildKeysPrivate> for &'a (KeyChain, nssa::Account, Identifier clippy::single_char_lifetime_names, reason = "TODO add meaningful name" )] -impl<'a> From<&'a mut ChildKeysPrivate> for &'a mut (KeyChain, nssa::Account, Identifier) { +impl<'a> From<&'a mut ChildKeysPrivate> for &'a mut (KeyChain, Vec<(Identifier, nssa::Account)>) { fn from(value: &'a mut ChildKeysPrivate) -> Self { &mut value.value } @@ -143,11 +137,17 @@ impl KeyTreeNode for ChildKeysPrivate { } fn derive_child(&self, cci: u32) -> Self { - self.nth_child(cci, 0) + self.nth_child(cci) } fn account_ids(&self) -> Vec { - vec![self.account_id()] + self.value + .1 + .iter() + .map(|(identifier, _)| { + nssa::AccountId::from((&self.value.0.nullifier_public_key, *identifier)) + }) + .collect() } } @@ -217,7 +217,7 @@ mod tests { ]; let root_node = ChildKeysPrivate::root(seed); - let child_node = ChildKeysPrivate::nth_child(&root_node, 42_u32, 0); + let child_node = ChildKeysPrivate::nth_child(&root_node, 42_u32); let expected_ccc: [u8; 32] = [ 27, 73, 133, 213, 214, 63, 217, 184, 164, 17, 172, 140, 223, 95, 255, 157, 11, 0, 58, diff --git a/key_protocol/src/key_management/key_tree/mod.rs b/key_protocol/src/key_management/key_tree/mod.rs index c0812f9c..3f847ebd 100644 --- a/key_protocol/src/key_management/key_tree/mod.rs +++ b/key_protocol/src/key_management/key_tree/mod.rs @@ -63,10 +63,7 @@ impl KeyTree { } } - pub fn generate_new_node( - &mut self, - parent_cci: &ChainIndex, - ) -> Option<(nssa::AccountId, ChainIndex)> { + pub fn generate_new_node(&mut self, parent_cci: &ChainIndex) -> Option { let parent_keys = self.key_map.get(parent_cci)?; let next_child_id = self .find_next_last_child_of_id(parent_cci) @@ -75,30 +72,28 @@ impl KeyTree { let child_keys = parent_keys.derive_child(next_child_id); let account_ids = child_keys.account_ids(); - let primary_account_id = *account_ids.first().expect("account_ids() must be non-empty"); for account_id in account_ids { self.account_id_map.insert(account_id, next_cci.clone()); } self.key_map.insert(next_cci.clone(), child_keys); - Some((primary_account_id, next_cci)) + Some(next_cci) } - pub fn fill_node(&mut self, chain_index: &ChainIndex) -> Option<(nssa::AccountId, ChainIndex)> { + pub fn fill_node(&mut self, chain_index: &ChainIndex) -> Option { let parent_keys = self.key_map.get(&chain_index.parent()?)?; let child_id = *chain_index.chain().last()?; let child_keys = parent_keys.derive_child(child_id); let account_ids = child_keys.account_ids(); - let primary_account_id = *account_ids.first().expect("account_ids() must be non-empty"); for account_id in account_ids { self.account_id_map.insert(account_id, chain_index.clone()); } self.key_map.insert(chain_index.clone(), child_keys); - Some((primary_account_id, chain_index.clone())) + Some(chain_index.clone()) } #[must_use] @@ -155,7 +150,7 @@ impl KeyTree { } } - pub fn generate_new_node_layered(&mut self) -> Option<(nssa::AccountId, ChainIndex)> { + pub fn generate_new_node_layered(&mut self) -> Option { self.fill_node(&self.find_next_slot_layered()) } @@ -204,35 +199,27 @@ impl KeyTree { } } -impl KeyTree { - /// Cleanup of non-initialized accounts in a private tree. - /// - /// If account is default, removes them, stops at first non-default account. - /// - /// Walks through tree in lairs of same depth using `ChainIndex::chain_ids_at_depth()`. - /// - /// Chain must be parsed for accounts beforehand. - /// - /// Slow, maintains tree consistency. - pub fn cleanup_tree_remove_uninit_layered(&mut self, depth: u32) { - let depth = usize::try_from(depth).expect("Depth is expected to fit in usize"); - 'outer: for i in (1..depth).rev() { - println!("Cleanup of tree at depth {i}"); - for id in ChainIndex::chain_ids_at_depth(i) { - if let Some(node) = self.key_map.get(&id) { - if node.value.1 == nssa::Account::default() { - let addr = node.account_id(); - self.remove(addr); - } else { - break 'outer; - } - } - } - } - } -} - impl KeyTree { + /// Generate a new public key node, returning the account ID and chain index. + pub fn generate_new_public_node( + &mut self, + parent_cci: &ChainIndex, + ) -> Option<(nssa::AccountId, ChainIndex)> { + let cci = self.generate_new_node(parent_cci)?; + let node = self.key_map.get(&cci)?; + let account_id = *node.account_ids().first()?; + Some((account_id, cci)) + } + + /// Generate a new public key node using layered placement, returning the account ID and chain + /// index. + pub fn generate_new_public_node_layered(&mut self) -> Option<(nssa::AccountId, ChainIndex)> { + let cci = self.generate_new_node_layered()?; + let node = self.key_map.get(&cci)?; + let account_id = *node.account_ids().first()?; + Some((account_id, cci)) + } + /// Cleanup of non-initialized accounts in a public tree. /// /// If account is default, removes them, stops at first non-default account. @@ -267,6 +254,37 @@ impl KeyTree { } } +impl KeyTree { + /// Cleanup of non-initialized accounts in a private tree. + /// + /// If account has no synced entries, removes it, stops at first initialized account. + /// + /// Walks through tree in layers of same depth using `ChainIndex::chain_ids_at_depth()`. + /// + /// Chain must be parsed for accounts beforehand. + /// + /// Slow, maintains tree consistency. + pub fn cleanup_tree_remove_uninit_layered(&mut self, depth: u32) { + let depth = usize::try_from(depth).expect("Depth is expected to fit in usize"); + 'outer: for i in (1..depth).rev() { + println!("Cleanup of tree at depth {i}"); + for id in ChainIndex::chain_ids_at_depth(i) { + if let Some(node) = self.key_map.get(&id) { + if node.value.1.is_empty() { + let account_ids: Vec<_> = node.account_ids(); + self.key_map.remove(&id); + for addr in account_ids { + self.account_id_map.remove(&addr); + } + } else { + break 'outer; + } + } + } + } + } +} + #[cfg(test)] mod tests { #![expect(clippy::shadow_unrelated, reason = "We don't care about this in tests")] @@ -486,25 +504,51 @@ mod tests { .key_map .get_mut(&ChainIndex::from_str("/1").unwrap()) .unwrap(); - acc.value.1.balance = 2; + acc.value.1.push((0, { + let mut a = nssa::Account::default(); + a.balance = 2; + a + })); let acc = tree .key_map .get_mut(&ChainIndex::from_str("/2").unwrap()) .unwrap(); - acc.value.1.balance = 3; + acc.value.1.push((0, { + let mut a = nssa::Account::default(); + a.balance = 3; + a + })); let acc = tree .key_map .get_mut(&ChainIndex::from_str("/0/1").unwrap()) .unwrap(); - acc.value.1.balance = 5; + acc.value.1.push((0, { + let mut a = nssa::Account::default(); + a.balance = 5; + a + })); let acc = tree .key_map .get_mut(&ChainIndex::from_str("/1/0").unwrap()) .unwrap(); - acc.value.1.balance = 6; + acc.value.1.push((0, { + let mut a = nssa::Account::default(); + a.balance = 6; + a + })); + + // Update account_id_map for nodes that now have entries + for chain_index_str in ["/1", "/2", "/0/1", "/1/0"] { + let id = ChainIndex::from_str(chain_index_str).unwrap(); + if let Some(node) = tree.key_map.get(&id) { + for account_id in node.account_ids() { + tree.account_id_map.insert(account_id, id.clone()); + } + } + } tree.cleanup_tree_remove_uninit_layered(10); @@ -526,15 +570,15 @@ mod tests { assert_eq!(key_set, key_set_res); let acc = &tree.key_map[&ChainIndex::from_str("/1").unwrap()]; - assert_eq!(acc.value.1.balance, 2); + assert_eq!(acc.value.1[0].1.balance, 2); let acc = &tree.key_map[&ChainIndex::from_str("/2").unwrap()]; - assert_eq!(acc.value.1.balance, 3); + assert_eq!(acc.value.1[0].1.balance, 3); let acc = &tree.key_map[&ChainIndex::from_str("/0/1").unwrap()]; - assert_eq!(acc.value.1.balance, 5); + assert_eq!(acc.value.1[0].1.balance, 5); let acc = &tree.key_map[&ChainIndex::from_str("/1/0").unwrap()]; - assert_eq!(acc.value.1.balance, 6); + assert_eq!(acc.value.1[0].1.balance, 6); } } diff --git a/key_protocol/src/key_management/mod.rs b/key_protocol/src/key_management/mod.rs index c038c415..b9c26071 100644 --- a/key_protocol/src/key_management/mod.rs +++ b/key_protocol/src/key_management/mod.rs @@ -172,10 +172,11 @@ mod tests { // /0/0 key_tree_private.generate_new_node_layered().unwrap(); // /2 - let (second_child_id, _) = key_tree_private.generate_new_node_layered().unwrap(); + let second_chain_index = key_tree_private.generate_new_node_layered().unwrap(); key_tree_private - .get_node(second_child_id) + .key_map + .get(&second_chain_index) .unwrap() .value .0 diff --git a/key_protocol/src/key_protocol_core/mod.rs b/key_protocol/src/key_protocol_core/mod.rs index a26f4f1a..449196ea 100644 --- a/key_protocol/src/key_protocol_core/mod.rs +++ b/key_protocol/src/key_protocol_core/mod.rs @@ -20,7 +20,7 @@ pub struct NSSAUserData { pub default_pub_account_signing_keys: BTreeMap, /// Default private accounts. pub default_user_private_accounts: - BTreeMap, + BTreeMap)>, /// Tree of public keys. pub public_key_tree: KeyTreePublic, /// Tree of private keys. @@ -46,15 +46,16 @@ impl NSSAUserData { fn valid_private_key_transaction_pairing_check( accounts_keys_map: &BTreeMap< nssa::AccountId, - (KeyChain, nssa_core::account::Account, Identifier), + (KeyChain, Vec<(Identifier, nssa_core::account::Account)>), >, ) -> bool { let mut check_res = true; - for (account_id, (key, _, identifier)) in accounts_keys_map { - let expected_account_id = - nssa::AccountId::from((&key.nullifier_public_key, *identifier)); - if expected_account_id != *account_id { - println!("{expected_account_id}, {account_id}"); + for (account_id, (key, entries)) in accounts_keys_map { + let any_match = entries.iter().any(|(identifier, _)| { + nssa::AccountId::from((&key.nullifier_public_key, *identifier)) == *account_id + }); + if !any_match { + println!("No matching entry found for account_id {account_id}"); check_res = false; } } @@ -65,7 +66,7 @@ impl NSSAUserData { default_accounts_keys: BTreeMap, default_accounts_key_chains: BTreeMap< nssa::AccountId, - (KeyChain, nssa_core::account::Account, Identifier), + (KeyChain, Vec<(Identifier, nssa_core::account::Account)>), >, public_key_tree: KeyTreePublic, private_key_tree: KeyTreePrivate, @@ -100,11 +101,11 @@ impl NSSAUserData { match parent_cci { Some(parent_cci) => self .public_key_tree - .generate_new_node(&parent_cci) + .generate_new_public_node(&parent_cci) .expect("Parent must be present in a tree"), None => self .public_key_tree - .generate_new_node_layered() + .generate_new_public_node_layered() .expect("Search for new node slot failed"), } } @@ -122,11 +123,11 @@ impl NSSAUserData { /// Generated new private key for privacy preserving transactions. /// - /// Returns the `account_id` of new account. + /// Returns the `ChainIndex` of the new node. pub fn generate_new_privacy_preserving_transaction_key_chain( &mut self, parent_cci: Option, - ) -> (nssa::AccountId, ChainIndex) { + ) -> ChainIndex { match parent_cci { Some(parent_cci) => self .private_key_tree @@ -139,31 +140,35 @@ impl NSSAUserData { } } - /// Returns the signing key for public transaction signatures. + /// Returns the key chain and account data for the given private account ID. #[must_use] pub fn get_private_account( &self, account_id: nssa::AccountId, - ) -> Option<&(KeyChain, nssa_core::account::Account, Identifier)> { - self.default_user_private_accounts - .get(&account_id) - .or_else(|| self.private_key_tree.get_node(account_id).map(Into::into)) - } - - /// Returns the signing key for public transaction signatures. - pub fn get_private_account_mut( - &mut self, - account_id: &nssa::AccountId, - ) -> Option<&mut (KeyChain, nssa_core::account::Account, Identifier)> { - // First seek in defaults - if let Some(key) = self.default_user_private_accounts.get_mut(account_id) { - Some(key) - // Then seek in tree - } else { - self.private_key_tree - .get_node_mut(*account_id) - .map(Into::into) + ) -> Option<(KeyChain, nssa_core::account::Account, Identifier)> { + // Check default accounts + if let Some((key_chain, entries)) = self.default_user_private_accounts.get(&account_id) { + for &(identifier, ref account) in entries { + let expected_id = + nssa::AccountId::from((&key_chain.nullifier_public_key, identifier)); + if expected_id == account_id { + return Some((key_chain.clone(), account.clone(), identifier)); + } + } + return None; } + // Check tree + if let Some(node) = self.private_key_tree.get_node(account_id) { + let key_chain = &node.value.0; + for &(identifier, ref account) in &node.value.1 { + let expected_id = + nssa::AccountId::from((&key_chain.nullifier_public_key, identifier)); + if expected_id == account_id { + return Some((key_chain.clone(), account.clone(), identifier)); + } + } + } + None } pub fn account_ids(&self) -> impl Iterator { @@ -206,16 +211,14 @@ mod tests { fn new_account() { let mut user_data = NSSAUserData::default(); - let (account_id_private, _) = user_data + let chain_index = user_data .generate_new_privacy_preserving_transaction_key_chain(Some(ChainIndex::root())); - let is_key_chain_generated = user_data.get_private_account(account_id_private).is_some(); + let is_key_chain_generated = user_data.private_key_tree.key_map.contains_key(&chain_index); assert!(is_key_chain_generated); - let account_id_private_str = account_id_private.to_string(); - println!("{account_id_private_str:#?}"); - let key_chain = &user_data.get_private_account(account_id_private).unwrap().0; + let key_chain = &user_data.private_key_tree.key_map[&chain_index].value.0; println!("{key_chain:#?}"); } } diff --git a/wallet-ffi/src/account.rs b/wallet-ffi/src/account.rs index 49f6a8de..5bbf8c53 100644 --- a/wallet-ffi/src/account.rs +++ b/wallet-ffi/src/account.rs @@ -98,7 +98,16 @@ pub unsafe extern "C" fn wallet_ffi_create_account_private( } }; - let (account_id, _chain_index) = wallet.create_new_account_private(None); + let chain_index = wallet.create_new_account_private(None); + + let node = wallet + .storage() + .user_data + .private_key_tree + .key_map + .get(&chain_index) + .expect("Node was just inserted"); + let account_id = AccountId::from((&node.value.0.nullifier_public_key, 0_u128)); unsafe { (*out_account_id).data = *account_id.value(); diff --git a/wallet/configs/debug/wallet_config.json b/wallet/configs/debug/wallet_config.json index 6604f65b..94e13ebd 100644 --- a/wallet/configs/debug/wallet_config.json +++ b/wallet/configs/debug/wallet_config.json @@ -19,7 +19,8 @@ }, { "Private": { - "account_id": "9DGDXnrNo4QhUUb2F8WDuDrPESja3eYDkZG5HkzvAvMC", + "account_id": "GoKB6RuE6pT2KxCqDXQqiCuuuYZaGdJNfctzyqRdGBCy", + "identifier": 0, "account": { "program_owner": [ 0, @@ -214,7 +215,8 @@ }, { "Private": { - "account_id": "A6AT9UvsgitUi8w4BH43n6DyX1bK37DtSCfjEWXQQUrQ", + "account_id": "BCdMnPkdH2DrVhe7cGdawkPU9iapsSboRvJpWX8pWnLq", + "identifier": 0, "account": { "program_owner": [ 0, diff --git a/wallet/src/chain_storage.rs b/wallet/src/chain_storage.rs index 36222894..30f312fa 100644 --- a/wallet/src/chain_storage.rs +++ b/wallet/src/chain_storage.rs @@ -77,8 +77,10 @@ impl WalletChainStore { public_init_acc_map.insert(data.account_id, data.pub_sign_key); } InitialAccountData::Private(data) => { - private_init_acc_map - .insert(data.account_id, (data.key_chain, data.account, data.identifier)); + private_init_acc_map.insert( + data.account_id, + (data.key_chain, vec![(data.identifier, data.account)]), + ); } }, } @@ -117,8 +119,10 @@ impl WalletChainStore { // startup. Fix this when program id can be fetched // from the node and queried from the wallet. account.program_owner = Program::authenticated_transfer_program().id(); - private_init_acc_map - .insert(data.account_id, (data.key_chain, account, data.identifier)); + private_init_acc_map.insert( + data.account_id, + (data.key_chain, vec![(data.identifier, account)]), + ); } } } @@ -175,24 +179,86 @@ impl WalletChainStore { ) { debug!("inserting at address {account_id}, this account {account:?}"); - let entry = self - .user_data - .default_user_private_accounts - .entry(account_id) - .and_modify(|data| data.1 = account.clone()); + // Update default accounts if present + if let Entry::Occupied(mut entry) = + self.user_data.default_user_private_accounts.entry(account_id) + { + let (key_chain, entries) = entry.get_mut(); + let identifier = entries + .iter() + .find_map(|(id, _)| { + if nssa::AccountId::from((&key_chain.nullifier_public_key, *id)) == account_id { + Some(*id) + } else { + None + } + }) + .unwrap_or(0); + // Update existing entry or insert new one + if let Some((_, acc)) = entries.iter_mut().find(|(id, _)| *id == identifier) { + *acc = account; + } else { + entries.push((identifier, account)); + } + return; + } - if matches!(entry, Entry::Vacant(_)) { - self.user_data + // Otherwise update the private key tree + // Identifier is hardcoded to 0 until ciphertexts carry the identifier + let identifier: nssa_core::Identifier = 0; + + // Find the node by iterating all tree nodes for this account_id + let chain_index = self + .user_data + .private_key_tree + .account_id_map + .get(&account_id) + .cloned(); + + if let Some(chain_index) = chain_index { + // Node already in account_id_map — update its entry + if let Some(node) = self + .user_data .private_key_tree - .account_id_map - .get(&account_id) - .map(|chain_index| { + .key_map + .get_mut(&chain_index) + { + if let Some((_, acc)) = + node.value.1.iter_mut().find(|(id, _)| *id == identifier) + { + *acc = account; + } else { + node.value.1.push((identifier, account)); + } + } + } else { + // Node not yet in account_id_map — find it by checking all nodes + for (ci, node) in self + .user_data + .private_key_tree + .key_map + .iter_mut() + { + let expected_id = nssa::AccountId::from(( + &node.value.0.nullifier_public_key, + identifier, + )); + if expected_id == account_id { + if let Some((_, acc)) = + node.value.1.iter_mut().find(|(id, _)| *id == identifier) + { + *acc = account; + } else { + node.value.1.push((identifier, account)); + } + // Register in account_id_map self.user_data .private_key_tree - .key_map - .entry(chain_index.clone()) - .and_modify(|data| data.value.1 = account) - }); + .account_id_map + .insert(account_id, ci.clone()); + break; + } + } } } } @@ -229,7 +295,10 @@ mod tests { data: public_data, }), PersistentAccountData::Private(Box::new(PersistentAccountDataPrivate { - account_id: private_data.account_id(), + account_id: nssa::AccountId::from(( + &private_data.value.0.nullifier_public_key, + 0_u128, + )), chain_index: ChainIndex::root(), data: private_data, })), diff --git a/wallet/src/cli/account.rs b/wallet/src/cli/account.rs index c5400878..577f8042 100644 --- a/wallet/src/cli/account.rs +++ b/wallet/src/cli/account.rs @@ -2,7 +2,7 @@ use anyhow::{Context as _, Result}; use clap::Subcommand; use itertools::Itertools as _; use key_protocol::key_management::key_tree::chain_index::ChainIndex; -use nssa::{Account, PublicKey, program::Program}; +use nssa::{Account, AccountId, PublicKey, program::Program}; use sequencer_service_rpc::RpcClient as _; use token_core::{TokenDefinition, TokenHolding}; @@ -147,7 +147,17 @@ impl WalletSubcommand for NewSubcommand { anyhow::bail!("Label '{label}' is already in use by another account"); } - let (account_id, chain_index) = wallet_core.create_new_account_private(cci); + let chain_index = wallet_core.create_new_account_private(cci); + + let node = wallet_core + .storage + .user_data + .private_key_tree + .key_map + .get(&chain_index) + .expect("Node was just inserted"); + let key = &node.value.0; + let account_id = AccountId::from((&key.nullifier_public_key, 0_u128)); if let Some(label) = label { wallet_core @@ -156,12 +166,6 @@ impl WalletSubcommand for NewSubcommand { .insert(account_id.to_string(), Label::new(label)); } - let (key, _, _) = wallet_core - .storage - .user_data - .get_private_account(account_id) - .unwrap(); - println!( "Generated new account with account_id Private/{account_id} at path {chain_index}", ); diff --git a/wallet/src/helperfunctions.rs b/wallet/src/helperfunctions.rs index 6e5d2e7d..e1c5b753 100644 --- a/wallet/src/helperfunctions.rs +++ b/wallet/src/helperfunctions.rs @@ -188,16 +188,18 @@ pub fn produce_data_for_storage( ); } - for (account_id, (key_chain, account, identifier)) in &user_data.default_user_private_accounts { - vec_for_storage.push( - InitialAccountData::Private(Box::new(PrivateAccountPrivateInitialData { - account_id: *account_id, - account: account.clone(), - key_chain: key_chain.clone(), - identifier: *identifier, - })) - .into(), - ); + for (account_id, (key_chain, entries)) in &user_data.default_user_private_accounts { + for (identifier, account) in entries { + vec_for_storage.push( + InitialAccountData::Private(Box::new(PrivateAccountPrivateInitialData { + account_id: *account_id, + account: account.clone(), + key_chain: key_chain.clone(), + identifier: *identifier, + })) + .into(), + ); + } } PersistentStorage { diff --git a/wallet/src/lib.rs b/wallet/src/lib.rs index 5557491a..7cd481a9 100644 --- a/wallet/src/lib.rs +++ b/wallet/src/lib.rs @@ -259,7 +259,7 @@ impl WalletCore { pub fn create_new_account_private( &mut self, chain_index: Option, - ) -> (AccountId, ChainIndex) { + ) -> ChainIndex { self.storage .user_data .generate_new_privacy_preserving_transaction_key_chain(chain_index) @@ -295,14 +295,14 @@ impl WalletCore { self.storage .user_data .get_private_account(account_id) - .map(|value| value.1.clone()) + .map(|(_keys, account, _identifier)| account) } #[must_use] pub fn get_private_account_commitment(&self, account_id: AccountId) -> Option { let (_keys, account, _identifier) = self.storage.user_data.get_private_account(account_id)?; - Some(Commitment::new(&account_id, account)) + Some(Commitment::new(&account_id, &account)) } /// Poll transactions. @@ -485,14 +485,16 @@ impl WalletCore { .user_data .default_user_private_accounts .iter() - .map(|(acc_account_id, (key_chain, _, _))| (*acc_account_id, key_chain, None)) + .map(|(acc_account_id, (key_chain, _))| (*acc_account_id, key_chain, None)) .chain(self.storage.user_data.private_key_tree.key_map.iter().map( |(chain_index, keys_node)| { - ( - keys_node.account_id(), - &keys_node.value.0, - chain_index.index(), - ) + // Use identifier=0 as the expected first account for this node. + // The actual identifier will be confirmed once the account is synced. + let account_id = nssa::AccountId::from(( + &keys_node.value.0.nullifier_public_key, + 0_u128, + )); + (account_id, &keys_node.value.0, chain_index.index()) }, )); diff --git a/wallet/src/privacy_preserving_tx.rs b/wallet/src/privacy_preserving_tx.rs index 4a0d3a03..f06b5bf7 100644 --- a/wallet/src/privacy_preserving_tx.rs +++ b/wallet/src/privacy_preserving_tx.rs @@ -211,7 +211,6 @@ async fn private_acc_preparation( .storage .user_data .get_private_account(account_id) - .cloned() else { return Err(ExecutionFailureKind::KeyNotFoundError); };