diff --git a/nssa/core/Cargo.toml b/nssa/core/Cargo.toml index e74eb90..1cdc78f 100644 --- a/nssa/core/Cargo.toml +++ b/nssa/core/Cargo.toml @@ -7,7 +7,8 @@ edition = "2024" risc0-zkvm = "2.3.1" serde = { version = "1.0", default-features = false } thiserror = { version = "2.0.12", optional = true } +bytemuck = { version = "1.13", optional = true } [features] default = [] -host = ["thiserror"] +host = ["thiserror", "bytemuck"] diff --git a/nssa/core/src/lib.rs b/nssa/core/src/lib.rs index 7f717f0..a616237 100644 --- a/nssa/core/src/lib.rs +++ b/nssa/core/src/lib.rs @@ -107,12 +107,7 @@ pub struct PrivacyPreservingCircuitOutput { #[cfg(feature = "host")] impl PrivacyPreservingCircuitOutput { pub fn to_bytes(&self) -> Vec { - let words = to_vec(&self).unwrap(); - let mut result = Vec::with_capacity(4 * words.len()); - for word in &words { - result.extend_from_slice(&word.to_le_bytes()); - } - result + bytemuck::cast_slice(&to_vec(&self).unwrap()).to_vec() } } diff --git a/nssa/src/lib.rs b/nssa/src/lib.rs index 2327514..2156b51 100644 --- a/nssa/src/lib.rs +++ b/nssa/src/lib.rs @@ -5,6 +5,7 @@ pub mod program; pub mod public_transaction; mod signature; mod state; +mod merkle_tree; pub use address::Address; pub use public_transaction::PublicTransaction; diff --git a/nssa/src/merkle_tree.rs b/nssa/src/merkle_tree.rs new file mode 100644 index 0000000..5c2821b --- /dev/null +++ b/nssa/src/merkle_tree.rs @@ -0,0 +1,135 @@ +use std::collections::HashMap; + +use sha2::{Digest, Sha256}; + +type Value = [u8; 32]; +type Node = [u8; 32]; + +/// Compute parent as the hash of two child nodes +fn hash_two(left: &Node, right: &Node) -> Node { + let mut hasher = Sha256::new(); + hasher.update(left); + hasher.update(right); + hasher.finalize().into() +} + +fn hash_value(value: &Value) -> Node { + let mut hasher = Sha256::new(); + hasher.update(value); + hasher.finalize().into() +} + +#[derive(Debug)] +pub struct MerkleTree { + index_map: HashMap, + node_map: HashMap, + capacity: usize, + length: usize, +} + +impl MerkleTree { + pub fn root(&self) -> Node { + *self.node_map.get(&0).unwrap() + } + + pub fn new(mut values: Vec) -> Self { + Self::deduplicate_values(&mut values); + + let capacity = values.len().next_power_of_two(); + let length = values.len(); + + let base_length = capacity; + + let mut node_map: HashMap = values + .iter() + .enumerate() + .map(|(index, value)| (index + base_length - 1, hash_value(value))) + .collect(); + node_map.extend( + (values.len()..base_length) + .map(|index| (index + base_length - 1, [0; 32])) + .collect::>(), + ); + + let mut current_layer_length = base_length; + let mut current_layer_first_index = base_length - 1; + + while current_layer_length > 1 { + let next_layer_length = current_layer_length >> 1; + let next_layer_first_index = current_layer_first_index >> 1; + + let next_layer = (next_layer_first_index..(next_layer_first_index + next_layer_length)) + .map(|index| { + let left_child = node_map.get(&((index << 1) + 1)).unwrap(); + let right_child = node_map.get(&((index << 1) + 2)).unwrap(); + (index, hash_two(&left_child, &right_child)) + }) + .collect::>(); + + node_map.extend(&next_layer); + + current_layer_length = next_layer_length; + current_layer_first_index = next_layer_first_index; + } + + let index_map = values + .into_iter() + .enumerate() + .map(|(index, value)| (value, index)) + .collect(); + + Self { + index_map, + node_map, + capacity, + length, + } + } + + fn deduplicate_values(values: &mut [Value]) { + // TODO: implement + } +} + +#[cfg(test)] +mod tests { + use nssa_core::account::{Account, NullifierPublicKey}; + + use super::*; + + #[test] + fn test_merkle_tree_1() { + let values = vec![[1; 32], [2; 32], [3; 32], [4; 32]]; + let tree = MerkleTree::new(values); + let expected_root = [ + 72, 199, 63, 120, 33, 165, 138, 141, 42, 112, 62, 91, 57, 197, 113, 192, 170, 32, 207, + 20, 171, 205, 10, 248, 242, 185, 85, 188, 32, 41, 152, 222, + ]; + + assert_eq!(tree.root(), expected_root); + assert_eq!(*tree.index_map.get(&[1; 32]).unwrap(), 0); + assert_eq!(*tree.index_map.get(&[2; 32]).unwrap(), 1); + assert_eq!(*tree.index_map.get(&[3; 32]).unwrap(), 2); + assert_eq!(*tree.index_map.get(&[4; 32]).unwrap(), 3); + assert_eq!(tree.capacity, 4); + assert_eq!(tree.length, 4); + } + + #[test] + fn test_merkle_tree_2() { + let values = vec![[1; 32], [2; 32], [3; 32], [0; 32]]; + let tree = MerkleTree::new(values); + let expected_root = [ + 201, 187, 184, 48, 150, 223, 133, 21, 122, 20, 110, 125, 119, 4, 85, 169, 132, 18, 222, + 224, 99, 49, 135, 238, 134, 254, 230, 200, 164, 91, 131, 26, + ]; + + assert_eq!(tree.root(), expected_root); + assert_eq!(*tree.index_map.get(&[1; 32]).unwrap(), 0); + assert_eq!(*tree.index_map.get(&[2; 32]).unwrap(), 1); + assert_eq!(*tree.index_map.get(&[3; 32]).unwrap(), 2); + assert_eq!(*tree.index_map.get(&[0; 32]).unwrap(), 3); + assert_eq!(tree.capacity, 4); + assert_eq!(tree.length, 4); + } +}