diff --git a/key_protocol/Cargo.toml b/key_protocol/Cargo.toml index a562515..24b92c0 100644 --- a/key_protocol/Cargo.toml +++ b/key_protocol/Cargo.toml @@ -16,9 +16,11 @@ bip39.workspace = true hmac-sha512.workspace = true thiserror.workspace = true nssa-core = { path = "../nssa/core", features = ["host"] } +itertools.workspace = true [dependencies.common] path = "../common" [dependencies.nssa] path = "../nssa" +features = ["no_docker"] diff --git a/key_protocol/src/key_management/key_tree/chain_index.rs b/key_protocol/src/key_management/key_tree/chain_index.rs index 31a6a07..d2c9c3b 100644 --- a/key_protocol/src/key_management/key_tree/chain_index.rs +++ b/key_protocol/src/key_management/key_tree/chain_index.rs @@ -1,8 +1,9 @@ use std::{fmt::Display, str::FromStr}; +use itertools::Itertools; use serde::{Deserialize, Serialize}; -#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Serialize, Deserialize)] +#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Serialize, Deserialize, Hash)] pub struct ChainIndex(Vec); #[derive(thiserror::Error, Debug)] @@ -77,6 +78,23 @@ impl ChainIndex { ChainIndex(chain) } + pub fn previous_in_line(&self) -> Option { + let mut chain = self.0.clone(); + if let Some(last_p) = chain.last_mut() { + *last_p = last_p.checked_sub(1)?; + } + + Some(ChainIndex(chain)) + } + + pub fn parent(&self) -> Option { + if self.0.is_empty() { + None + } else { + Some(ChainIndex(self.0[..(self.0.len() - 1)].to_vec())) + } + } + pub fn nth_child(&self, child_id: u32) -> ChainIndex { let mut chain = self.0.clone(); chain.push(child_id); @@ -85,13 +103,40 @@ impl ChainIndex { } pub fn depth(&self) -> u32 { - let mut res = 0; + self.0.iter().map(|cci| cci + 1).sum() + } - for cci in &self.0 { - res += cci + 1; + fn collapse_back(&self) -> Option { + let mut res = self.parent()?; + + let last_mut = res.0.last_mut()?; + *last_mut += *(self.0.last()?) + 1; + + Some(res) + } + + fn shuffle_iter(&self) -> impl Iterator { + self.0 + .iter() + .permutations(self.0.len()) + .unique() + .map(|item| ChainIndex(item.into_iter().cloned().collect())) + } + + pub fn chain_ids_at_depth(depth: usize) -> impl Iterator { + let mut stack = vec![ChainIndex(vec![0; depth])]; + let mut cumulative_stack = vec![ChainIndex(vec![0; depth])]; + + while let Some(id) = stack.pop() { + if let Some(collapsed_id) = id.collapse_back() { + for id in collapsed_id.shuffle_iter() { + stack.push(id.clone()); + cumulative_stack.push(id); + } + } } - res + cumulative_stack.into_iter().unique() } } @@ -155,4 +200,83 @@ mod tests { assert_eq!(string_index, "/5/7/8".to_string()); } + + #[test] + fn test_prev_in_line() { + let chain_id = ChainIndex(vec![1, 7, 3]); + + let prev_chain_id = chain_id.previous_in_line().unwrap(); + + assert_eq!(prev_chain_id, ChainIndex(vec![1, 7, 2])) + } + + #[test] + fn test_prev_in_line_no_prev() { + let chain_id = ChainIndex(vec![1, 7, 0]); + + let prev_chain_id = chain_id.previous_in_line(); + + assert_eq!(prev_chain_id, None) + } + + #[test] + fn test_parent() { + let chain_id = ChainIndex(vec![1, 7, 3]); + + let parent_chain_id = chain_id.parent().unwrap(); + + assert_eq!(parent_chain_id, ChainIndex(vec![1, 7])) + } + + #[test] + fn test_parent_no_parent() { + let chain_id = ChainIndex(vec![]); + + let parent_chain_id = chain_id.parent(); + + assert_eq!(parent_chain_id, None) + } + + #[test] + fn test_parent_root() { + let chain_id = ChainIndex(vec![1]); + + let parent_chain_id = chain_id.parent().unwrap(); + + assert_eq!(parent_chain_id, ChainIndex::root()) + } + + #[test] + fn test_collapse_back() { + let chain_id = ChainIndex(vec![1, 1]); + + let collapsed = chain_id.collapse_back().unwrap(); + + assert_eq!(collapsed, ChainIndex(vec![3])) + } + + #[test] + fn test_collapse_back_one() { + let chain_id = ChainIndex(vec![1]); + + let collapsed = chain_id.collapse_back(); + + assert_eq!(collapsed, None) + } + + #[test] + fn test_collapse_back_root() { + let chain_id = ChainIndex(vec![]); + + let collapsed = chain_id.collapse_back(); + + assert_eq!(collapsed, None) + } + + #[test] + fn test_shuffle() { + for id in ChainIndex::chain_ids_at_depth(5) { + println!("{id}"); + } + } } diff --git a/key_protocol/src/key_management/key_tree/mod.rs b/key_protocol/src/key_management/key_tree/mod.rs index beb0f95..9f72ecf 100644 --- a/key_protocol/src/key_management/key_tree/mod.rs +++ b/key_protocol/src/key_management/key_tree/mod.rs @@ -107,14 +107,13 @@ impl KeyTree { &mut self, parent_cci: &ChainIndex, ) -> Option<(nssa::AccountId, ChainIndex)> { - let father_keys = self.key_map.get(parent_cci)?; + let parent_keys = self.key_map.get(parent_cci)?; let next_child_id = self .find_next_last_child_of_id(parent_cci) .expect("Can be None only if parent is not present"); let next_cci = parent_cci.nth_child(next_child_id); - let child_keys = father_keys.nth_child(next_child_id); - + let child_keys = parent_keys.nth_child(next_child_id); let account_id = child_keys.account_id(); self.key_map.insert(next_cci.clone(), child_keys); @@ -189,11 +188,10 @@ impl KeyTree { let mut id_stack = vec![ChainIndex::root()]; while let Some(curr_id) = id_stack.pop() { - self.generate_new_node(&curr_id); - let mut next_id = curr_id.nth_child(0); - while (next_id.depth()) < depth - 1 { + while (next_id.depth()) < depth { + self.generate_new_node(&curr_id); id_stack.push(next_id.clone()); next_id = next_id.next_in_line(); } @@ -210,7 +208,9 @@ impl KeyTree { /// If account is default, removes them. /// /// Chain must be parsed for accounts beforehand - pub fn cleanup_tree_for_depth(&mut self, depth: u32) { + /// + /// Fast, leaves gaps between accounts + pub fn cleanup_tree_remove_uninit_for_depth(&mut self, depth: u32) { let mut id_stack = vec![ChainIndex::root()]; while let Some(curr_id) = id_stack.pop() { @@ -224,12 +224,37 @@ impl KeyTree { let mut next_id = curr_id.nth_child(0); - while (next_id.depth()) < depth - 1 { + while (next_id.depth()) < depth { id_stack.push(next_id.clone()); next_id = next_id.next_in_line(); } } } + + /// Cleanup of non-initialized accounts in a private tree + /// + /// If account is default, removes them, stops at first non-default account. + /// + /// Walks through tree in lairs of same depth using `ChainIndex::chain_ids_at_depth()` + /// + /// Chain must be parsed for accounts beforehand + /// + /// Slow, maintains tree consistency. + pub fn cleanup_tree_remove_uninit_layered(&mut self, depth: u32) { + 'outer: for i in (1..(depth as usize)).rev() { + println!("Cleanup of tree at depth {i}"); + for id in ChainIndex::chain_ids_at_depth(i) { + if let Some(node) = self.key_map.get(&id) { + if node.value.1 == nssa::Account::default() { + let addr = node.account_id(); + self.remove(addr); + } else { + break 'outer; + } + } + } + } + } } impl KeyTree { @@ -239,7 +264,9 @@ impl KeyTree { /// depth`. /// /// If account is default, removes them. - pub async fn cleanup_tree_for_depth( + /// + /// Fast, leaves gaps between accounts + pub async fn cleanup_tree_remove_ininit_for_depth( &mut self, depth: u32, client: Arc, @@ -258,7 +285,7 @@ impl KeyTree { let mut next_id = curr_id.nth_child(0); - while (next_id.depth()) < depth - 1 { + while (next_id.depth()) < depth { id_stack.push(next_id.clone()); next_id = next_id.next_in_line(); } @@ -266,11 +293,43 @@ impl KeyTree { Ok(()) } + + /// Cleanup of non-initialized accounts in a public tree + /// + /// If account is default, removes them, stops at first non-default account. + /// + /// Walks through tree in lairs of same depth using `ChainIndex::chain_ids_at_depth()` + /// + /// Slow, maintains tree consistency. + pub async fn cleanup_tree_remove_uninit_layered( + &mut self, + depth: u32, + client: Arc, + ) -> Result<()> { + 'outer: for i in (1..(depth as usize)).rev() { + println!("Cleanup of tree at depth {i}"); + for id in ChainIndex::chain_ids_at_depth(i) { + if let Some(node) = self.key_map.get(&id) { + let address = node.account_id(); + let node_acc = client.get_account(address.to_string()).await?.account; + + if node_acc == nssa::Account::default() { + let addr = node.account_id(); + self.remove(addr); + } else { + break 'outer; + } + } + } + } + + Ok(()) + } } #[cfg(test)] mod tests { - use std::str::FromStr; + use std::{collections::HashSet, str::FromStr}; use nssa::AccountId; @@ -299,7 +358,7 @@ mod tests { fn test_small_key_tree() { let seed_holder = seed_holder_for_tests(); - let mut tree = KeyTreePublic::new(&seed_holder); + let mut tree = KeyTreePrivate::new(&seed_holder); let next_last_child_for_parent_id = tree .find_next_last_child_of_id(&ChainIndex::root()) @@ -338,7 +397,7 @@ mod tests { fn test_key_tree_can_not_make_child_keys() { let seed_holder = seed_holder_for_tests(); - let mut tree = KeyTreePublic::new(&seed_holder); + let mut tree = KeyTreePrivate::new(&seed_holder); let next_last_child_for_parent_id = tree .find_next_last_child_of_id(&ChainIndex::root()) @@ -489,4 +548,79 @@ mod tests { assert_eq!(next_suitable_parent, ChainIndex::from_str("/2").unwrap()); } + + #[test] + fn test_cleanup() { + let seed_holder = seed_holder_for_tests(); + + let mut tree = KeyTreePrivate::new(&seed_holder); + tree.generate_tree_for_depth(10); + + let acc = tree + .key_map + .get_mut(&ChainIndex::from_str("/1").unwrap()) + .unwrap(); + acc.value.1.balance = 2; + + let acc = tree + .key_map + .get_mut(&ChainIndex::from_str("/2").unwrap()) + .unwrap(); + acc.value.1.balance = 3; + + let acc = tree + .key_map + .get_mut(&ChainIndex::from_str("/0/1").unwrap()) + .unwrap(); + acc.value.1.balance = 5; + + let acc = tree + .key_map + .get_mut(&ChainIndex::from_str("/1/0").unwrap()) + .unwrap(); + acc.value.1.balance = 6; + + tree.cleanup_tree_remove_uninit_layered(10); + + let mut key_set_res = HashSet::new(); + key_set_res.insert("/0".to_string()); + key_set_res.insert("/1".to_string()); + key_set_res.insert("/2".to_string()); + key_set_res.insert("/".to_string()); + key_set_res.insert("/0/0".to_string()); + key_set_res.insert("/0/1".to_string()); + key_set_res.insert("/1/0".to_string()); + + let mut key_set = HashSet::new(); + + for key in tree.key_map.keys() { + key_set.insert(key.to_string()); + } + + assert_eq!(key_set, key_set_res); + + let acc = tree + .key_map + .get(&ChainIndex::from_str("/1").unwrap()) + .unwrap(); + assert_eq!(acc.value.1.balance, 2); + + let acc = tree + .key_map + .get(&ChainIndex::from_str("/2").unwrap()) + .unwrap(); + assert_eq!(acc.value.1.balance, 3); + + let acc = tree + .key_map + .get(&ChainIndex::from_str("/0/1").unwrap()) + .unwrap(); + assert_eq!(acc.value.1.balance, 5); + + let acc = tree + .key_map + .get(&ChainIndex::from_str("/1/0").unwrap()) + .unwrap(); + assert_eq!(acc.value.1.balance, 6); + } } diff --git a/key_protocol/src/key_protocol_core/mod.rs b/key_protocol/src/key_protocol_core/mod.rs index 41a686b..ce41d38 100644 --- a/key_protocol/src/key_protocol_core/mod.rs +++ b/key_protocol/src/key_protocol_core/mod.rs @@ -187,15 +187,8 @@ mod tests { fn test_new_account() { let mut user_data = NSSAUserData::default(); - let (account_id_pub, _) = user_data.generate_new_public_transaction_private_key(None); - let (account_id_private, _) = - user_data.generate_new_privacy_preserving_transaction_key_chain(None); - - let is_private_key_generated = user_data - .get_pub_account_signing_key(&account_id_pub) - .is_some(); - - assert!(is_private_key_generated); + let (account_id_private, _) = user_data + .generate_new_privacy_preserving_transaction_key_chain(Some(ChainIndex::root())); let is_key_chain_generated = user_data.get_private_account(&account_id_private).is_some(); diff --git a/wallet/src/cli/mod.rs b/wallet/src/cli/mod.rs index d119faa..3ea6a84 100644 --- a/wallet/src/cli/mod.rs +++ b/wallet/src/cli/mod.rs @@ -231,7 +231,7 @@ pub async fn execute_keys_restoration(password: String, depth: u32) -> Result<() .storage .user_data .public_key_tree - .cleanup_tree_for_depth(depth, wallet_core.sequencer_client.clone()) + .cleanup_tree_remove_uninit_layered(depth, wallet_core.sequencer_client.clone()) .await?; println!("Public tree cleaned up"); @@ -252,7 +252,7 @@ pub async fn execute_keys_restoration(password: String, depth: u32) -> Result<() .storage .user_data .private_key_tree - .cleanup_tree_for_depth(depth); + .cleanup_tree_remove_uninit_layered(depth); println!("Private tree cleaned up");