diff --git a/key_protocol/Cargo.toml b/key_protocol/Cargo.toml index a562515..24b92c0 100644 --- a/key_protocol/Cargo.toml +++ b/key_protocol/Cargo.toml @@ -16,9 +16,11 @@ bip39.workspace = true hmac-sha512.workspace = true thiserror.workspace = true nssa-core = { path = "../nssa/core", features = ["host"] } +itertools.workspace = true [dependencies.common] path = "../common" [dependencies.nssa] path = "../nssa" +features = ["no_docker"] diff --git a/key_protocol/src/key_management/key_tree/chain_index.rs b/key_protocol/src/key_management/key_tree/chain_index.rs index 8b28327..d2c9c3b 100644 --- a/key_protocol/src/key_management/key_tree/chain_index.rs +++ b/key_protocol/src/key_management/key_tree/chain_index.rs @@ -1,8 +1,9 @@ use std::{fmt::Display, str::FromStr}; +use itertools::Itertools; use serde::{Deserialize, Serialize}; -#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Serialize, Deserialize)] +#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Serialize, Deserialize, Hash)] pub struct ChainIndex(Vec); #[derive(thiserror::Error, Debug)] @@ -104,6 +105,39 @@ impl ChainIndex { pub fn depth(&self) -> u32 { self.0.iter().map(|cci| cci + 1).sum() } + + fn collapse_back(&self) -> Option { + let mut res = self.parent()?; + + let last_mut = res.0.last_mut()?; + *last_mut += *(self.0.last()?) + 1; + + Some(res) + } + + fn shuffle_iter(&self) -> impl Iterator { + self.0 + .iter() + .permutations(self.0.len()) + .unique() + .map(|item| ChainIndex(item.into_iter().cloned().collect())) + } + + pub fn chain_ids_at_depth(depth: usize) -> impl Iterator { + let mut stack = vec![ChainIndex(vec![0; depth])]; + let mut cumulative_stack = vec![ChainIndex(vec![0; depth])]; + + while let Some(id) = stack.pop() { + if let Some(collapsed_id) = id.collapse_back() { + for id in collapsed_id.shuffle_iter() { + stack.push(id.clone()); + cumulative_stack.push(id); + } + } + } + + cumulative_stack.into_iter().unique() + } } #[cfg(test)] @@ -211,4 +245,38 @@ mod tests { assert_eq!(parent_chain_id, ChainIndex::root()) } + + #[test] + fn test_collapse_back() { + let chain_id = ChainIndex(vec![1, 1]); + + let collapsed = chain_id.collapse_back().unwrap(); + + assert_eq!(collapsed, ChainIndex(vec![3])) + } + + #[test] + fn test_collapse_back_one() { + let chain_id = ChainIndex(vec![1]); + + let collapsed = chain_id.collapse_back(); + + assert_eq!(collapsed, None) + } + + #[test] + fn test_collapse_back_root() { + let chain_id = ChainIndex(vec![]); + + let collapsed = chain_id.collapse_back(); + + assert_eq!(collapsed, None) + } + + #[test] + fn test_shuffle() { + for id in ChainIndex::chain_ids_at_depth(5) { + println!("{id}"); + } + } } diff --git a/key_protocol/src/key_management/key_tree/mod.rs b/key_protocol/src/key_management/key_tree/mod.rs index b75671e..d964af4 100644 --- a/key_protocol/src/key_management/key_tree/mod.rs +++ b/key_protocol/src/key_management/key_tree/mod.rs @@ -111,7 +111,7 @@ impl KeyTree { } } - fn generate_new_node_unconstrained( + pub fn generate_new_node( &mut self, parent_cci: &ChainIndex, ) -> Option<(nssa::AccountId, ChainIndex)> { @@ -165,7 +165,7 @@ impl KeyTree { let mut next_id = curr_id.nth_child(0); while (next_id.depth()) < depth { - self.generate_new_node_unconstrained(&curr_id); + self.generate_new_node(&curr_id); id_stack.push(next_id.clone()); next_id = next_id.next_in_line(); } @@ -174,45 +174,6 @@ impl KeyTree { } impl KeyTree { - #[allow(clippy::result_large_err)] - pub fn generate_new_node( - &mut self, - parent_cci: &ChainIndex, - ) -> Result<(nssa::AccountId, ChainIndex), KeyTreeGenerationError> { - let parent_keys = - self.key_map - .get(parent_cci) - .ok_or(KeyTreeGenerationError::ParentChainIdNotFound( - parent_cci.clone(), - ))?; - let next_child_id = self - .find_next_last_child_of_id(parent_cci) - .expect("Can be None only if parent is not present"); - let next_cci = parent_cci.nth_child(next_child_id); - - if let Some(prev_cci) = next_cci.previous_in_line() { - let prev_keys = self.key_map.get(&prev_cci).unwrap_or_else(|| { - panic!("Constraint violated, previous child with id {prev_cci} is missing") - }); - - if prev_keys.value.1 == nssa::Account::default() { - return Err(KeyTreeGenerationError::PredecesorsNotInitialized(next_cci)); - } - } else if *parent_cci != ChainIndex::root() - && parent_keys.value.1 == nssa::Account::default() - { - return Err(KeyTreeGenerationError::PredecesorsNotInitialized(next_cci)); - } - - let child_keys = parent_keys.nth_child(next_child_id); - let account_id = child_keys.account_id(); - - self.key_map.insert(next_cci.clone(), child_keys); - self.account_id_map.insert(account_id, next_cci.clone()); - - Ok((account_id, next_cci)) - } - /// Cleanup of all non-initialized accounts in a private tree /// /// For given `depth` checks children to a tree such that their `ChainIndex::depth(&self) < @@ -221,7 +182,9 @@ impl KeyTree { /// If account is default, removes them. /// /// Chain must be parsed for accounts beforehand - pub fn cleanup_tree_remove_ininit_for_depth(&mut self, depth: u32) { + /// + /// Fast, leaves gaps between accounts + pub fn cleanup_tree_remove_uninit_for_depth(&mut self, depth: u32) { let mut id_stack = vec![ChainIndex::root()]; while let Some(curr_id) = id_stack.pop() { @@ -241,64 +204,42 @@ impl KeyTree { } } } + + /// Cleanup of non-initialized accounts in a private tree + /// + /// If account is default, removes them, stops at first non-default account. + /// + /// Walks through tree in lairs of same depth using `ChainIndex::chain_ids_at_depth()` + /// + /// Chain must be parsed for accounts beforehand + /// + /// Slow, maintains tree consistency. + pub fn cleanup_tree_remove_uninit_layered(&mut self, depth: u32) { + 'outer: for i in (1..(depth as usize)).rev() { + println!("Cleanup of tree at depth {i}"); + for id in ChainIndex::chain_ids_at_depth(i) { + if let Some(node) = self.key_map.get(&id) { + if node.value.1 == nssa::Account::default() { + let addr = node.account_id(); + self.remove(addr); + } else { + break 'outer; + } + } + } + } + } } impl KeyTree { - #[allow(clippy::result_large_err)] - pub async fn generate_new_node( - &mut self, - parent_cci: &ChainIndex, - client: Arc, - ) -> Result<(nssa::AccountId, ChainIndex), KeyTreeGenerationError> { - let parent_keys = - self.key_map - .get(parent_cci) - .ok_or(KeyTreeGenerationError::ParentChainIdNotFound( - parent_cci.clone(), - ))?; - let next_child_id = self - .find_next_last_child_of_id(parent_cci) - .expect("Can be None only if parent is not present"); - let next_cci = parent_cci.nth_child(next_child_id); - - if let Some(prev_cci) = next_cci.previous_in_line() { - let prev_keys = self.key_map.get(&prev_cci).unwrap_or_else(|| { - panic!("Constraint violated, previous child with id {prev_cci} is missing") - }); - let prev_acc = client - .get_account(prev_keys.account_id().to_string()) - .await? - .account; - - if prev_acc == nssa::Account::default() { - return Err(KeyTreeGenerationError::PredecesorsNotInitialized(next_cci)); - } - } else if *parent_cci != ChainIndex::root() { - let parent_acc = client - .get_account(parent_keys.account_id().to_string()) - .await? - .account; - - if parent_acc == nssa::Account::default() { - return Err(KeyTreeGenerationError::PredecesorsNotInitialized(next_cci)); - } - } - - let child_keys = parent_keys.nth_child(next_child_id); - let account_id = child_keys.account_id(); - - self.key_map.insert(next_cci.clone(), child_keys); - self.account_id_map.insert(account_id, next_cci.clone()); - - Ok((account_id, next_cci)) - } - /// Cleanup of all non-initialized accounts in a public tree /// /// For given `depth` checks children to a tree such that their `ChainIndex::depth(&self) < /// depth`. /// /// If account is default, removes them. + /// + /// Fast, leaves gaps between accounts pub async fn cleanup_tree_remove_ininit_for_depth( &mut self, depth: u32, @@ -326,6 +267,38 @@ impl KeyTree { Ok(()) } + + /// Cleanup of non-initialized accounts in a public tree + /// + /// If account is default, removes them, stops at first non-default account. + /// + /// Walks through tree in lairs of same depth using `ChainIndex::chain_ids_at_depth()` + /// + /// Slow, maintains tree consistency. + pub async fn cleanup_tree_remove_uninit_layered( + &mut self, + depth: u32, + client: Arc, + ) -> Result<()> { + 'outer: for i in (1..(depth as usize)).rev() { + println!("Cleanup of tree at depth {i}"); + for id in ChainIndex::chain_ids_at_depth(i) { + if let Some(node) = self.key_map.get(&id) { + let address = node.account_id(); + let node_acc = client.get_account(address.to_string()).await?.account; + + if node_acc == nssa::Account::default() { + let addr = node.account_id(); + self.remove(addr); + } else { + break 'outer; + } + } + } + } + + Ok(()) + } } #[cfg(test)] @@ -367,8 +340,7 @@ mod tests { assert_eq!(next_last_child_for_parent_id, 0); - tree.generate_new_node_unconstrained(&ChainIndex::root()) - .unwrap(); + tree.generate_new_node(&ChainIndex::root()).unwrap(); assert!( tree.key_map @@ -381,18 +353,12 @@ mod tests { assert_eq!(next_last_child_for_parent_id, 1); - tree.generate_new_node_unconstrained(&ChainIndex::root()) - .unwrap(); - tree.generate_new_node_unconstrained(&ChainIndex::root()) - .unwrap(); - tree.generate_new_node_unconstrained(&ChainIndex::root()) - .unwrap(); - tree.generate_new_node_unconstrained(&ChainIndex::root()) - .unwrap(); - tree.generate_new_node_unconstrained(&ChainIndex::root()) - .unwrap(); - tree.generate_new_node_unconstrained(&ChainIndex::root()) - .unwrap(); + tree.generate_new_node(&ChainIndex::root()).unwrap(); + tree.generate_new_node(&ChainIndex::root()).unwrap(); + tree.generate_new_node(&ChainIndex::root()).unwrap(); + tree.generate_new_node(&ChainIndex::root()).unwrap(); + tree.generate_new_node(&ChainIndex::root()).unwrap(); + tree.generate_new_node(&ChainIndex::root()).unwrap(); let next_last_child_for_parent_id = tree .find_next_last_child_of_id(&ChainIndex::root()) @@ -413,8 +379,7 @@ mod tests { assert_eq!(next_last_child_for_parent_id, 0); - tree.generate_new_node_unconstrained(&ChainIndex::root()) - .unwrap(); + tree.generate_new_node(&ChainIndex::root()).unwrap(); assert!( tree.key_map @@ -427,7 +392,7 @@ mod tests { assert_eq!(next_last_child_for_parent_id, 1); - let key_opt = tree.generate_new_node_unconstrained(&ChainIndex::from_str("/3").unwrap()); + let key_opt = tree.generate_new_node(&ChainIndex::from_str("/3").unwrap()); assert_eq!(key_opt, None); } @@ -444,8 +409,7 @@ mod tests { assert_eq!(next_last_child_for_parent_id, 0); - tree.generate_new_node_unconstrained(&ChainIndex::root()) - .unwrap(); + tree.generate_new_node(&ChainIndex::root()).unwrap(); assert!( tree.key_map @@ -458,8 +422,7 @@ mod tests { assert_eq!(next_last_child_for_parent_id, 1); - tree.generate_new_node_unconstrained(&ChainIndex::root()) - .unwrap(); + tree.generate_new_node(&ChainIndex::root()).unwrap(); assert!( tree.key_map @@ -472,7 +435,7 @@ mod tests { assert_eq!(next_last_child_for_parent_id, 2); - tree.generate_new_node_unconstrained(&ChainIndex::from_str("/0").unwrap()) + tree.generate_new_node(&ChainIndex::from_str("/0").unwrap()) .unwrap(); let next_last_child_for_parent_id = tree @@ -486,7 +449,7 @@ mod tests { .contains_key(&ChainIndex::from_str("/0/0").unwrap()) ); - tree.generate_new_node_unconstrained(&ChainIndex::from_str("/0").unwrap()) + tree.generate_new_node(&ChainIndex::from_str("/0").unwrap()) .unwrap(); let next_last_child_for_parent_id = tree @@ -500,7 +463,7 @@ mod tests { .contains_key(&ChainIndex::from_str("/0/1").unwrap()) ); - tree.generate_new_node_unconstrained(&ChainIndex::from_str("/0").unwrap()) + tree.generate_new_node(&ChainIndex::from_str("/0").unwrap()) .unwrap(); let next_last_child_for_parent_id = tree @@ -514,7 +477,7 @@ mod tests { .contains_key(&ChainIndex::from_str("/0/2").unwrap()) ); - tree.generate_new_node_unconstrained(&ChainIndex::from_str("/0/1").unwrap()) + tree.generate_new_node(&ChainIndex::from_str("/0/1").unwrap()) .unwrap(); assert!( @@ -529,49 +492,6 @@ mod tests { assert_eq!(next_last_child_for_parent_id, 1); } - #[test] - fn test_key_generation_constraint() { - let seed_holder = seed_holder_for_tests(); - - let mut tree = KeyTreePrivate::new(&seed_holder); - - let (_, chain_id) = tree.generate_new_node(&ChainIndex::root()).unwrap(); - - assert_eq!(chain_id, ChainIndex::from_str("/0").unwrap()); - - let res = tree.generate_new_node(&ChainIndex::from_str("/").unwrap()); - - assert!(matches!( - res, - Err(KeyTreeGenerationError::PredecesorsNotInitialized(_)) - )); - - let res = tree.generate_new_node(&ChainIndex::from_str("/0").unwrap()); - - assert!(matches!( - res, - Err(KeyTreeGenerationError::PredecesorsNotInitialized(_)) - )); - - let acc = tree - .key_map - .get_mut(&ChainIndex::from_str("/0").unwrap()) - .unwrap(); - acc.value.1.balance = 1; - - let (_, chain_id) = tree - .generate_new_node(&ChainIndex::from_str("/").unwrap()) - .unwrap(); - - assert_eq!(chain_id, ChainIndex::from_str("/1").unwrap()); - - let (_, chain_id) = tree - .generate_new_node(&ChainIndex::from_str("/0").unwrap()) - .unwrap(); - - assert_eq!(chain_id, ChainIndex::from_str("/0/0").unwrap()); - } - #[test] fn test_cleanup() { let seed_holder = seed_holder_for_tests(); @@ -579,12 +499,6 @@ mod tests { let mut tree = KeyTreePrivate::new(&seed_holder); tree.generate_tree_for_depth(10); - let acc = tree - .key_map - .get_mut(&ChainIndex::from_str("/0").unwrap()) - .unwrap(); - acc.value.1.balance = 1; - let acc = tree .key_map .get_mut(&ChainIndex::from_str("/1").unwrap()) @@ -597,12 +511,6 @@ mod tests { .unwrap(); acc.value.1.balance = 3; - let acc = tree - .key_map - .get_mut(&ChainIndex::from_str("/0/0").unwrap()) - .unwrap(); - acc.value.1.balance = 4; - let acc = tree .key_map .get_mut(&ChainIndex::from_str("/0/1").unwrap()) @@ -615,7 +523,7 @@ mod tests { .unwrap(); acc.value.1.balance = 6; - tree.cleanup_tree_remove_ininit_for_depth(10); + tree.cleanup_tree_remove_uninit_layered(10); let mut key_set_res = HashSet::new(); key_set_res.insert("/0".to_string()); @@ -634,12 +542,6 @@ mod tests { assert_eq!(key_set, key_set_res); - let acc = tree - .key_map - .get(&ChainIndex::from_str("/0").unwrap()) - .unwrap(); - assert_eq!(acc.value.1.balance, 1); - let acc = tree .key_map .get(&ChainIndex::from_str("/1").unwrap()) @@ -652,12 +554,6 @@ mod tests { .unwrap(); assert_eq!(acc.value.1.balance, 3); - let acc = tree - .key_map - .get(&ChainIndex::from_str("/0/0").unwrap()) - .unwrap(); - assert_eq!(acc.value.1.balance, 4); - let acc = tree .key_map .get(&ChainIndex::from_str("/0/1").unwrap()) diff --git a/key_protocol/src/key_protocol_core/mod.rs b/key_protocol/src/key_protocol_core/mod.rs index c474df0..fc0a393 100644 --- a/key_protocol/src/key_protocol_core/mod.rs +++ b/key_protocol/src/key_protocol_core/mod.rs @@ -1,7 +1,6 @@ -use std::{collections::HashMap, sync::Arc}; +use std::collections::HashMap; use anyhow::Result; -use common::sequencer_client::SequencerClient; use k256::AffinePoint; use serde::{Deserialize, Serialize}; @@ -88,14 +87,12 @@ impl NSSAUserData { /// Generated new private key for public transaction signatures /// /// Returns the account_id of new account - pub async fn generate_new_public_transaction_private_key( + pub fn generate_new_public_transaction_private_key( &mut self, parent_cci: ChainIndex, - sequencer_client: Arc, ) -> nssa::AccountId { self.public_key_tree - .generate_new_node(&parent_cci, sequencer_client) - .await + .generate_new_node(&parent_cci) .unwrap() .0 } diff --git a/wallet/src/cli/mod.rs b/wallet/src/cli/mod.rs index cda599b..e53849e 100644 --- a/wallet/src/cli/mod.rs +++ b/wallet/src/cli/mod.rs @@ -219,7 +219,7 @@ pub async fn execute_keys_restoration(password: String, depth: u32) -> Result<() .storage .user_data .public_key_tree - .cleanup_tree_remove_ininit_for_depth(depth, wallet_core.sequencer_client.clone()) + .cleanup_tree_remove_uninit_layered(depth, wallet_core.sequencer_client.clone()) .await?; println!("Public tree cleaned up"); @@ -240,7 +240,7 @@ pub async fn execute_keys_restoration(password: String, depth: u32) -> Result<() .storage .user_data .private_key_tree - .cleanup_tree_remove_ininit_for_depth(depth); + .cleanup_tree_remove_uninit_layered(depth); println!("Private tree cleaned up"); diff --git a/wallet/src/lib.rs b/wallet/src/lib.rs index 56872db..d0a9014 100644 --- a/wallet/src/lib.rs +++ b/wallet/src/lib.rs @@ -115,8 +115,7 @@ impl WalletCore { pub async fn create_new_account_public(&mut self, chain_index: ChainIndex) -> AccountId { self.storage .user_data - .generate_new_public_transaction_private_key(chain_index, self.sequencer_client.clone()) - .await + .generate_new_public_transaction_private_key(chain_index) } pub fn create_new_account_private(&mut self, chain_index: ChainIndex) -> AccountId {