diff --git a/key_protocol/src/key_management/key_tree/chain_index.rs b/key_protocol/src/key_management/key_tree/chain_index.rs index d2c9c3b..6dbaf9a 100644 --- a/key_protocol/src/key_management/key_tree/chain_index.rs +++ b/key_protocol/src/key_management/key_tree/chain_index.rs @@ -138,6 +138,22 @@ impl ChainIndex { cumulative_stack.into_iter().unique() } + + pub fn chain_ids_at_depth_rev(depth: usize) -> impl Iterator { + let mut stack = vec![ChainIndex(vec![0; depth])]; + let mut cumulative_stack = vec![ChainIndex(vec![0; depth])]; + + while let Some(id) = stack.pop() { + if let Some(collapsed_id) = id.collapse_back() { + for id in collapsed_id.shuffle_iter() { + stack.push(id.clone()); + cumulative_stack.push(id); + } + } + } + + cumulative_stack.into_iter().rev().unique() + } } #[cfg(test)] diff --git a/key_protocol/src/key_management/key_tree/mod.rs b/key_protocol/src/key_management/key_tree/mod.rs index 324d6fe..389580b 100644 --- a/key_protocol/src/key_management/key_tree/mod.rs +++ b/key_protocol/src/key_management/key_tree/mod.rs @@ -126,8 +126,8 @@ impl KeyTree { let mut depth = 1; 'outer: loop { - for chain_id in ChainIndex::chain_ids_at_depth(depth) { - if self.key_map.get(&chain_id).is_none() { + for chain_id in ChainIndex::chain_ids_at_depth_rev(depth) { + if !self.key_map.contains_key(&chain_id) { break 'outer chain_id; } } @@ -520,13 +520,13 @@ mod tests { let mut tree = KeyTreePublic::new(&seed_holder); + for _ in 0..100 { + tree.generate_new_node_layered().unwrap(); + } + let next_slot = tree.find_next_slot_layered(); - println!("NEXT SLOT {next_slot}"); - - let (acc_id, chain_id) = tree.generate_new_node_layered().unwrap(); - - println!("NEXT ACC {acc_id} at {chain_id}"); + assert_eq!(next_slot, ChainIndex::from_str("/0/0/2/1").unwrap()); } #[test]