From 0ecec7016edc01543a0cd8440eaf71a330ff1774 Mon Sep 17 00:00:00 2001 From: Sergio Chouhy Date: Wed, 15 Apr 2026 23:34:49 -0300 Subject: [PATCH] refactor key trees --- .../key_management/key_tree/keys_private.rs | 34 +++-- .../key_management/key_tree/keys_public.rs | 74 +++++---- .../src/key_management/key_tree/mod.rs | 140 +++++++++--------- .../src/key_management/key_tree/traits.rs | 18 +-- wallet/src/chain_storage.rs | 2 +- wallet/src/lib.rs | 2 +- 6 files changed, 147 insertions(+), 123 deletions(-) diff --git a/key_protocol/src/key_management/key_tree/keys_private.rs b/key_protocol/src/key_management/key_tree/keys_private.rs index 8b19231a..298adc08 100644 --- a/key_protocol/src/key_management/key_tree/keys_private.rs +++ b/key_protocol/src/key_management/key_tree/keys_private.rs @@ -4,7 +4,7 @@ use serde::{Deserialize, Serialize}; use crate::key_management::{ KeyChain, - key_tree::traits::KeyNode, + key_tree::traits::KeyTreeNode, secret_holders::{PrivateKeyHolder, SecretSpendingKey}, }; @@ -16,8 +16,8 @@ pub struct ChildKeysPrivate { pub cci: Option, } -impl KeyNode for ChildKeysPrivate { - fn root(seed: [u8; 64]) -> Self { +impl ChildKeysPrivate { + pub fn root(seed: [u8; 64]) -> Self { let hash_value = hmac_sha512::HMAC::mac(seed, b"LEE_master_priv"); let ssk = SecretSpendingKey( @@ -54,7 +54,7 @@ impl KeyNode for ChildKeysPrivate { } } - fn nth_child(&self, cci: u32) -> Self { + pub fn nth_child(&self, cci: u32, identifier: Identifier) -> Self { #[expect(clippy::arithmetic_side_effects, reason = "TODO: fix later")] let parent_pt = Scalar::from_repr(self.value.0.private_key_holder.nullifier_secret_key.into()) @@ -97,23 +97,23 @@ impl KeyNode for ChildKeysPrivate { }, }, nssa::Account::default(), - 0, + identifier, ), ccc, cci: Some(cci), } } - fn chain_code(&self) -> &[u8; 32] { + pub fn chain_code(&self) -> &[u8; 32] { &self.ccc } - fn child_index(&self) -> Option { + pub fn child_index(&self) -> Option { self.cci } - fn account_id(&self) -> nssa::AccountId { - nssa::AccountId::from((&self.value.0.nullifier_public_key, 0)) + pub fn account_id(&self) -> nssa::AccountId { + nssa::AccountId::from((&self.value.0.nullifier_public_key, self.value.2)) } } @@ -137,6 +137,20 @@ impl<'a> From<&'a mut ChildKeysPrivate> for &'a mut (KeyChain, nssa::Account, Id } } +impl KeyTreeNode for ChildKeysPrivate { + fn from_seed(seed: [u8; 64]) -> Self { + Self::root(seed) + } + + fn derive_child(&self, cci: u32) -> Self { + self.nth_child(cci, 0) + } + + fn account_ids(&self) -> Vec { + vec![self.account_id()] + } +} + #[cfg(test)] mod tests { use nssa_core::{NullifierPublicKey, NullifierSecretKey}; @@ -203,7 +217,7 @@ mod tests { ]; let root_node = ChildKeysPrivate::root(seed); - let child_node = ChildKeysPrivate::nth_child(&root_node, 42_u32); + let child_node = ChildKeysPrivate::nth_child(&root_node, 42_u32, 0); let expected_ccc: [u8; 32] = [ 27, 73, 133, 213, 214, 63, 217, 184, 164, 17, 172, 140, 223, 95, 255, 157, 11, 0, 58, diff --git a/key_protocol/src/key_management/key_tree/keys_public.rs b/key_protocol/src/key_management/key_tree/keys_public.rs index d4c32b4a..bc656ed3 100644 --- a/key_protocol/src/key_management/key_tree/keys_public.rs +++ b/key_protocol/src/key_management/key_tree/keys_public.rs @@ -1,7 +1,7 @@ use k256::elliptic_curve::{PrimeField as _, sec1::ToEncodedPoint as _}; use serde::{Deserialize, Serialize}; -use crate::key_management::key_tree::traits::KeyNode; +use crate::key_management::key_tree::traits::KeyTreeNode; #[derive(Debug, Serialize, Deserialize, Clone)] pub struct ChildKeysPublic { @@ -13,32 +13,7 @@ pub struct ChildKeysPublic { } impl ChildKeysPublic { - fn compute_hash_value(&self, cci: u32) -> [u8; 64] { - let mut hash_input = vec![]; - - if ((2_u32).pow(31)).cmp(&cci) == std::cmp::Ordering::Greater { - // Non-harden. - // BIP-032 compatibility requires 1-byte header from the public_key; - // Not stored in `self.cpk.value()`. - let sk = k256::SecretKey::from_bytes(self.csk.value().into()) - .expect("32 bytes, within curve order"); - let pk = sk.public_key(); - hash_input.extend_from_slice(pk.to_encoded_point(true).as_bytes()); - } else { - // Harden. - hash_input.extend_from_slice(&[0_u8]); - hash_input.extend_from_slice(self.csk.value()); - } - - #[expect(clippy::big_endian_bytes, reason = "BIP-032 uses big endian")] - hash_input.extend_from_slice(&cci.to_be_bytes()); - - hmac_sha512::HMAC::mac(hash_input, self.ccc) - } -} - -impl KeyNode for ChildKeysPublic { - fn root(seed: [u8; 64]) -> Self { + pub fn root(seed: [u8; 64]) -> Self { let hash_value = hmac_sha512::HMAC::mac(seed, "LEE_master_pub"); let csk = nssa::PrivateKey::try_new( @@ -58,7 +33,7 @@ impl KeyNode for ChildKeysPublic { } } - fn nth_child(&self, cci: u32) -> Self { + pub fn nth_child(&self, cci: u32) -> Self { let hash_value = self.compute_hash_value(cci); let csk = nssa::PrivateKey::try_new({ @@ -90,17 +65,40 @@ impl KeyNode for ChildKeysPublic { } } - fn chain_code(&self) -> &[u8; 32] { + pub fn chain_code(&self) -> &[u8; 32] { &self.ccc } - fn child_index(&self) -> Option { + pub fn child_index(&self) -> Option { self.cci } - fn account_id(&self) -> nssa::AccountId { + pub fn account_id(&self) -> nssa::AccountId { nssa::AccountId::from(&self.cpk) } + + fn compute_hash_value(&self, cci: u32) -> [u8; 64] { + let mut hash_input = vec![]; + + if ((2_u32).pow(31)).cmp(&cci) == std::cmp::Ordering::Greater { + // Non-harden. + // BIP-032 compatibility requires 1-byte header from the public_key; + // Not stored in `self.cpk.value()`. + let sk = k256::SecretKey::from_bytes(self.csk.value().into()) + .expect("32 bytes, within curve order"); + let pk = sk.public_key(); + hash_input.extend_from_slice(pk.to_encoded_point(true).as_bytes()); + } else { + // Harden. + hash_input.extend_from_slice(&[0_u8]); + hash_input.extend_from_slice(self.csk.value()); + } + + #[expect(clippy::big_endian_bytes, reason = "BIP-032 uses big endian")] + hash_input.extend_from_slice(&cci.to_be_bytes()); + + hmac_sha512::HMAC::mac(hash_input, self.ccc) + } } #[expect( @@ -113,6 +111,20 @@ impl<'a> From<&'a ChildKeysPublic> for &'a nssa::PrivateKey { } } +impl KeyTreeNode for ChildKeysPublic { + fn from_seed(seed: [u8; 64]) -> Self { + Self::root(seed) + } + + fn derive_child(&self, cci: u32) -> Self { + self.nth_child(cci) + } + + fn account_ids(&self) -> Vec { + vec![self.account_id()] + } +} + #[cfg(test)] mod tests { use nssa::{PrivateKey, PublicKey}; diff --git a/key_protocol/src/key_management/key_tree/mod.rs b/key_protocol/src/key_management/key_tree/mod.rs index 08a576e5..c0812f9c 100644 --- a/key_protocol/src/key_management/key_tree/mod.rs +++ b/key_protocol/src/key_management/key_tree/mod.rs @@ -7,7 +7,7 @@ use serde::{Deserialize, Serialize}; use crate::key_management::{ key_tree::{ chain_index::ChainIndex, keys_private::ChildKeysPrivate, keys_public::ChildKeysPublic, - traits::KeyNode, + traits::KeyTreeNode, }, secret_holders::SeedHolder, }; @@ -20,7 +20,7 @@ pub mod traits; pub const DEPTH_SOFT_CAP: u32 = 20; #[derive(Debug, Serialize, Deserialize, Clone)] -pub struct KeyTree { +pub struct KeyTree { pub key_map: BTreeMap, pub account_id_map: BTreeMap, } @@ -28,7 +28,7 @@ pub struct KeyTree { pub type KeyTreePublic = KeyTree; pub type KeyTreePrivate = KeyTree; -impl KeyTree { +impl KeyTree { #[must_use] pub fn new(seed: &SeedHolder) -> Self { let seed_fit: [u8; 64] = seed @@ -37,29 +37,69 @@ impl KeyTree { .try_into() .expect("SeedHolder seed is 64 bytes long"); - let root_keys = N::root(seed_fit); - let account_id = root_keys.account_id(); - - let key_map = BTreeMap::from_iter([(ChainIndex::root(), root_keys)]); - let account_id_map = BTreeMap::from_iter([(account_id, ChainIndex::root())]); + let root_keys = N::from_seed(seed_fit); + let account_id_map = root_keys + .account_ids() + .into_iter() + .map(|id| (id, ChainIndex::root())) + .collect(); Self { - key_map, + key_map: BTreeMap::from_iter([(ChainIndex::root(), root_keys)]), account_id_map, } } pub fn new_from_root(root: N) -> Self { - let account_id_map = BTreeMap::from_iter([(root.account_id(), ChainIndex::root())]); - let key_map = BTreeMap::from_iter([(ChainIndex::root(), root)]); + let account_id_map = root + .account_ids() + .into_iter() + .map(|id| (id, ChainIndex::root())) + .collect(); Self { - key_map, + key_map: BTreeMap::from_iter([(ChainIndex::root(), root)]), account_id_map, } } - // ToDo: Add function to create a tree from list of nodes with consistency check. + pub fn generate_new_node( + &mut self, + parent_cci: &ChainIndex, + ) -> Option<(nssa::AccountId, ChainIndex)> { + let parent_keys = self.key_map.get(parent_cci)?; + let next_child_id = self + .find_next_last_child_of_id(parent_cci) + .expect("Can be None only if parent is not present"); + let next_cci = parent_cci.nth_child(next_child_id); + + let child_keys = parent_keys.derive_child(next_child_id); + let account_ids = child_keys.account_ids(); + let primary_account_id = *account_ids.first().expect("account_ids() must be non-empty"); + + for account_id in account_ids { + self.account_id_map.insert(account_id, next_cci.clone()); + } + self.key_map.insert(next_cci.clone(), child_keys); + + Some((primary_account_id, next_cci)) + } + + pub fn fill_node(&mut self, chain_index: &ChainIndex) -> Option<(nssa::AccountId, ChainIndex)> { + let parent_keys = self.key_map.get(&chain_index.parent()?)?; + let child_id = *chain_index.chain().last()?; + + let child_keys = parent_keys.derive_child(child_id); + let account_ids = child_keys.account_ids(); + let primary_account_id = *account_ids.first().expect("account_ids() must be non-empty"); + + for account_id in account_ids { + self.account_id_map.insert(account_id, chain_index.clone()); + } + self.key_map.insert(chain_index.clone(), child_keys); + + Some((primary_account_id, chain_index.clone())) + } #[must_use] pub fn find_next_last_child_of_id(&self, parent_id: &ChainIndex) -> Option { @@ -102,25 +142,6 @@ impl KeyTree { } } - pub fn generate_new_node( - &mut self, - parent_cci: &ChainIndex, - ) -> Option<(nssa::AccountId, ChainIndex)> { - let parent_keys = self.key_map.get(parent_cci)?; - let next_child_id = self - .find_next_last_child_of_id(parent_cci) - .expect("Can be None only if parent is not present"); - let next_cci = parent_cci.nth_child(next_child_id); - - let child_keys = parent_keys.nth_child(next_child_id); - let account_id = child_keys.account_id(); - - self.key_map.insert(next_cci.clone(), child_keys); - self.account_id_map.insert(account_id, next_cci.clone()); - - Some((account_id, next_cci)) - } - fn find_next_slot_layered(&self) -> ChainIndex { let mut depth = 1; @@ -134,44 +155,10 @@ impl KeyTree { } } - pub fn fill_node(&mut self, chain_index: &ChainIndex) -> Option<(nssa::AccountId, ChainIndex)> { - let parent_keys = self.key_map.get(&chain_index.parent()?)?; - let child_id = *chain_index.chain().last()?; - - let child_keys = parent_keys.nth_child(child_id); - let account_id = child_keys.account_id(); - - self.key_map.insert(chain_index.clone(), child_keys); - self.account_id_map.insert(account_id, chain_index.clone()); - - Some((account_id, chain_index.clone())) - } - pub fn generate_new_node_layered(&mut self) -> Option<(nssa::AccountId, ChainIndex)> { self.fill_node(&self.find_next_slot_layered()) } - #[must_use] - pub fn get_node(&self, account_id: nssa::AccountId) -> Option<&N> { - let chain_id = self.account_id_map.get(&account_id)?; - self.key_map.get(chain_id) - } - - pub fn get_node_mut(&mut self, account_id: nssa::AccountId) -> Option<&mut N> { - let chain_id = self.account_id_map.get(&account_id)?; - self.key_map.get_mut(chain_id) - } - - pub fn insert(&mut self, account_id: nssa::AccountId, chain_index: ChainIndex, node: N) { - self.account_id_map.insert(account_id, chain_index.clone()); - self.key_map.insert(chain_index, node); - } - - pub fn remove(&mut self, addr: nssa::AccountId) -> Option { - let chain_index = self.account_id_map.remove(&addr)?; - self.key_map.remove(&chain_index) - } - /// Populates tree with children. /// /// For given `depth` adds children to a tree such that their `ChainIndex::depth(&self) < @@ -194,6 +181,27 @@ impl KeyTree { } } } + + #[must_use] + pub fn get_node(&self, account_id: nssa::AccountId) -> Option<&N> { + let chain_id = self.account_id_map.get(&account_id)?; + self.key_map.get(chain_id) + } + + pub fn get_node_mut(&mut self, account_id: nssa::AccountId) -> Option<&mut N> { + let chain_id = self.account_id_map.get(&account_id)?; + self.key_map.get_mut(chain_id) + } + + pub fn insert(&mut self, account_id: nssa::AccountId, chain_index: ChainIndex, node: N) { + self.account_id_map.insert(account_id, chain_index.clone()); + self.key_map.insert(chain_index, node); + } + + pub fn remove(&mut self, addr: nssa::AccountId) -> Option { + let chain_index = self.account_id_map.remove(&addr)?; + self.key_map.remove(&chain_index) + } } impl KeyTree { diff --git a/key_protocol/src/key_management/key_tree/traits.rs b/key_protocol/src/key_management/key_tree/traits.rs index 65e8fae0..fb1549bc 100644 --- a/key_protocol/src/key_management/key_tree/traits.rs +++ b/key_protocol/src/key_management/key_tree/traits.rs @@ -1,15 +1,5 @@ -/// Trait, that reperesents a Node in hierarchical key tree. -pub trait KeyNode { - /// Tree root node. - fn root(seed: [u8; 64]) -> Self; - - /// `cci`'s child of node. - #[must_use] - fn nth_child(&self, cci: u32) -> Self; - - fn chain_code(&self) -> &[u8; 32]; - - fn child_index(&self) -> Option; - - fn account_id(&self) -> nssa::AccountId; +pub trait KeyTreeNode: Sized { + fn from_seed(seed: [u8; 64]) -> Self; + fn derive_child(&self, cci: u32) -> Self; + fn account_ids(&self) -> Vec; } diff --git a/wallet/src/chain_storage.rs b/wallet/src/chain_storage.rs index c176f7d6..36222894 100644 --- a/wallet/src/chain_storage.rs +++ b/wallet/src/chain_storage.rs @@ -200,7 +200,7 @@ impl WalletChainStore { #[cfg(test)] mod tests { use key_protocol::key_management::key_tree::{ - keys_private::ChildKeysPrivate, keys_public::ChildKeysPublic, traits::KeyNode as _, + keys_private::ChildKeysPrivate, keys_public::ChildKeysPublic, }; use super::*; diff --git a/wallet/src/lib.rs b/wallet/src/lib.rs index 77a09156..5557491a 100644 --- a/wallet/src/lib.rs +++ b/wallet/src/lib.rs @@ -15,7 +15,7 @@ use bip39::Mnemonic; use chain_storage::WalletChainStore; use common::{HashType, transaction::NSSATransaction}; use config::WalletConfig; -use key_protocol::key_management::key_tree::{chain_index::ChainIndex, traits::KeyNode as _}; +use key_protocol::key_management::key_tree::chain_index::ChainIndex; use log::info; use nssa::{ Account, AccountId, PrivacyPreservingTransaction,