From d1414a44c5e088bbf23523f032a001122e35f0c2 Mon Sep 17 00:00:00 2001 From: Aaryamann Challani <43716372+rymnc@users.noreply.github.com> Date: Tue, 1 Aug 2023 18:06:52 +0530 Subject: [PATCH] fix(rln): atomic operation edge case (#195) * fix(rln): atomic operation edge case * fmt * fix: bug * test: new batching mechanism * Revert "test: new batching mechanism" This reverts commit 396c2ec342fbd0a776e494ed46de1e4f27d8f93e. * fix: end should be max index + 1 * fix: optimization * fix: apply cleanup * fix: idiomatic leaf setting * fix: abstract out logic * fix: type aliasing for verbose types * fix: remove_indices_and_set_leaves fn --- Cargo.toml | 1 + rln/src/pm_tree_adapter.rs | 148 ++++++++++++++++++++++++------------- rln/src/public.rs | 105 +++++++++++++++++++++++++- rln/tests/ffi.rs | 2 +- 4 files changed, 200 insertions(+), 56 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index fbba3a9..af31d21 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -8,6 +8,7 @@ members = [ "rln-wasm", "utils", ] +resolver = "2" # Compilation profile for any non-workspace member. # Dependencies are optimized, even in a dev build. This improves dev performance diff --git a/rln/src/pm_tree_adapter.rs b/rln/src/pm_tree_adapter.rs index 8cd660b..6db5f4d 100644 --- a/rln/src/pm_tree_adapter.rs +++ b/rln/src/pm_tree_adapter.rs @@ -1,15 +1,17 @@ -use crate::circuit::Fr; -use crate::hashers::{poseidon_hash, PoseidonHash}; -use crate::utils::{bytes_le_to_fr, fr_to_bytes_le}; -use color_eyre::{Report, Result}; -use serde_json::Value; -use std::collections::HashSet; use std::fmt::Debug; use std::path::PathBuf; use std::str::FromStr; + +use color_eyre::{Report, Result}; +use serde_json::Value; + use utils::pmtree::{Database, Hasher}; use utils::*; +use crate::circuit::Fr; +use crate::hashers::{poseidon_hash, PoseidonHash}; +use crate::utils::{bytes_le_to_fr, fr_to_bytes_le}; + const METADATA_KEY: [u8; 8] = *b"metadata"; pub struct PmTree { @@ -25,13 +27,9 @@ pub struct PmTreeProof { pub type FrOf = ::Fr; // The pmtree Hasher trait used by pmtree Merkle tree -impl pmtree::Hasher for PoseidonHash { +impl Hasher for PoseidonHash { type Fr = Fr; - fn default_leaf() -> Self::Fr { - Fr::from(0) - } - fn serialize(value: Self::Fr) -> pmtree::Value { fr_to_bytes_le(&value) } @@ -41,12 +39,16 @@ impl pmtree::Hasher for PoseidonHash { fr } + fn default_leaf() -> Self::Fr { + Fr::from(0) + } + fn hash(inputs: &[Self::Fr]) -> Self::Fr { poseidon_hash(inputs) } } -fn get_tmp_path() -> std::path::PathBuf { +fn get_tmp_path() -> PathBuf { std::env::temp_dir().join(format!("pmtree-{}", rand::random::())) } @@ -54,7 +56,7 @@ fn get_tmp() -> bool { true } -pub struct PmtreeConfig(pm_tree::Config); +pub struct PmtreeConfig(Config); impl FromStr for PmtreeConfig { type Err = Report; @@ -85,7 +87,7 @@ impl FromStr for PmtreeConfig { ))); } - let config = pm_tree::Config::new() + let config = Config::new() .temporary(temporary.unwrap_or(get_tmp())) .path(path.unwrap_or(get_tmp_path())) .cache_capacity(cache_capacity.unwrap_or(1024 * 1024 * 1024)) @@ -100,7 +102,7 @@ impl Default for PmtreeConfig { fn default() -> Self { let tmp_path = get_tmp_path(); PmtreeConfig( - pm_tree::Config::new() + Config::new() .temporary(true) .path(tmp_path) .cache_capacity(150_000) @@ -145,10 +147,6 @@ impl ZerokitMerkleTree for PmTree { }) } - fn close_db_connection(&mut self) -> Result<()> { - self.tree.db.close().map_err(|e| Report::msg(e.to_string())) - } - fn depth(&self) -> usize { self.tree.depth() } @@ -165,16 +163,16 @@ impl ZerokitMerkleTree for PmTree { self.tree.root() } + fn compute_root(&mut self) -> Result> { + Ok(self.tree.root()) + } + fn set(&mut self, index: usize, leaf: FrOf) -> Result<()> { self.tree .set(index, leaf) .map_err(|e| Report::msg(e.to_string())) } - fn get(&self, index: usize) -> Result> { - self.tree.get(index).map_err(|e| Report::msg(e.to_string())) - } - fn set_range>>( &mut self, start: usize, @@ -185,6 +183,10 @@ impl ZerokitMerkleTree for PmTree { .map_err(|e| Report::msg(e.to_string())) } + fn get(&self, index: usize) -> Result> { + self.tree.get(index).map_err(|e| Report::msg(e.to_string())) + } + fn override_range>, J: IntoIterator>( &mut self, start: usize, @@ -192,33 +194,15 @@ impl ZerokitMerkleTree for PmTree { indices: J, ) -> Result<()> { let leaves = leaves.into_iter().collect::>(); - let indices = indices.into_iter().collect::>(); - let end = start + leaves.len(); + let mut indices = indices.into_iter().collect::>(); + indices.sort(); - if leaves.len() + start - indices.len() > self.capacity() { - return Err(Report::msg("index out of bounds")); + match (leaves.is_empty(), indices.is_empty()) { + (true, true) => Err(Report::msg("no leaves or indices to be removed")), + (false, true) => self.set_range_with_leaves(start, leaves), + (true, false) => self.remove_indices(indices), + (false, false) => self.remove_indices_and_set_leaves(start, leaves, indices), } - - // extend the range to include indices to be removed - let min_index = indices.iter().min().unwrap_or(&start); - let max_index = indices.iter().max().unwrap_or(&end); - - let mut new_leaves = Vec::new(); - - // insert leaves into new_leaves - for i in *min_index..*max_index { - if indices.contains(&i) { - // insert 0 - new_leaves.push(Self::Hasher::default_leaf()); - } else { - // insert leaf - new_leaves.push(leaves[i - start]); - } - } - - self.tree - .set_range(start, new_leaves) - .map_err(|e| Report::msg(e.to_string())) } fn update_next(&mut self, leaf: FrOf) -> Result<()> { @@ -246,10 +230,6 @@ impl ZerokitMerkleTree for PmTree { } } - fn compute_root(&mut self) -> Result> { - Ok(self.tree.root()) - } - fn set_metadata(&mut self, metadata: &[u8]) -> Result<()> { self.tree.db.put(METADATA_KEY, metadata.to_vec())?; self.metadata = metadata.to_vec(); @@ -268,6 +248,70 @@ impl ZerokitMerkleTree for PmTree { } Ok(data.unwrap()) } + + fn close_db_connection(&mut self) -> Result<()> { + self.tree.db.close().map_err(|e| Report::msg(e.to_string())) + } +} + +type PmTreeHasher = ::Hasher; +type FrOfPmTreeHasher = FrOf; + +impl PmTree { + fn set_range_with_leaves(&mut self, start: usize, leaves: Vec) -> Result<()> { + self.tree + .set_range(start, leaves) + .map_err(|e| Report::msg(e.to_string())) + } + + fn remove_indices(&mut self, indices: Vec) -> Result<()> { + let start = indices[0]; + let end = indices.last().unwrap() + 1; + + let mut new_leaves: Vec<_> = (start..end) + .map(|i| self.tree.get(i)) + .collect::>()?; + + new_leaves + .iter_mut() + .take(indices.len()) + .for_each(|leaf| *leaf = PmTreeHasher::default_leaf()); + + self.tree + .set_range(start, new_leaves) + .map_err(|e| Report::msg(e.to_string())) + } + + fn remove_indices_and_set_leaves( + &mut self, + start: usize, + leaves: Vec, + indices: Vec, + ) -> Result<()> { + let min_index = *indices.first().unwrap(); + let max_index = start + leaves.len(); + + // Generated a placeholder with the exact size needed, + // Initiated with default values to be overridden throughout the method + let mut set_values = vec![PmTreeHasher::default_leaf(); max_index - min_index]; + + // If the index is not in indices list, keep the original value + for i in min_index..start { + if !indices.contains(&i) { + let value = self.tree.get(i)?; + set_values[i - min_index] = value; + } + } + + // Insert new leaves after 'start' position + for (i, &leaf) in leaves.iter().enumerate() { + set_values[start - min_index + i] = leaf; + } + + self.tree + .set_range(min_index, set_values) + .map_err(|e| Report::msg(e.to_string())) + } } impl ZerokitMerkleProof for PmTreeProof { diff --git a/rln/src/public.rs b/rln/src/public.rs index e54d157..499711c 100644 --- a/rln/src/public.rs +++ b/rln/src/public.rs @@ -362,7 +362,7 @@ impl RLN<'_> { // We set the leaves self.tree .override_range(index, leaves, indices) - .map_err(|_| Report::msg("Could not perform the batch operation"))?; + .map_err(|e| Report::msg(format!("Could not perform the batch operation: {e}")))?; Ok(()) } @@ -1387,7 +1387,7 @@ mod test { assert_eq!(root_batch_with_init, root_single_additions); - rln.flush(); + rln.flush().unwrap(); } #[test] @@ -1430,7 +1430,7 @@ mod test { let indices_buffer = Cursor::new(vec_u8_to_bytes_le(&indices).unwrap()); let leaves_buffer = Cursor::new(vec_fr_to_bytes_le(&last_leaf).unwrap()); - rln.atomic_operation(no_of_leaves, leaves_buffer, indices_buffer) + rln.atomic_operation(last_leaf_index, leaves_buffer, indices_buffer) .unwrap(); // We get the root of the tree obtained after a no-op @@ -1441,6 +1441,105 @@ mod test { assert_eq!(root_after_insertion, root_after_noop); } + #[test] + fn test_atomic_operation_zero_indexed() { + // Test duplicated from https://github.com/waku-org/go-zerokit-rln/pull/12/files + let tree_height = TEST_TREE_HEIGHT; + let no_of_leaves = 256; + + // We generate a vector of random leaves + let mut leaves: Vec = Vec::new(); + let mut rng = thread_rng(); + for _ in 0..no_of_leaves { + leaves.push(Fr::rand(&mut rng)); + } + + // We create a new tree + let input_buffer = + Cursor::new(json!({ "resources_folder": TEST_RESOURCES_FOLDER }).to_string()); + let mut rln = RLN::new(tree_height, input_buffer).unwrap(); + + // We add leaves in a batch into the tree + let mut buffer = Cursor::new(vec_fr_to_bytes_le(&leaves).unwrap()); + rln.init_tree_with_leaves(&mut buffer).unwrap(); + + // We check if number of leaves set is consistent + assert_eq!(rln.tree.leaves_set(), no_of_leaves); + + // We get the root of the tree obtained adding leaves in batch + let mut buffer = Cursor::new(Vec::::new()); + rln.get_root(&mut buffer).unwrap(); + let (root_after_insertion, _) = bytes_le_to_fr(&buffer.into_inner()); + + let zero_index = 0; + let indices = vec![zero_index as u8]; + let zero_leaf: Vec = vec![]; + let indices_buffer = Cursor::new(vec_u8_to_bytes_le(&indices).unwrap()); + let leaves_buffer = Cursor::new(vec_fr_to_bytes_le(&zero_leaf).unwrap()); + rln.atomic_operation(0, leaves_buffer, indices_buffer) + .unwrap(); + + // We get the root of the tree obtained after a deletion + let mut buffer = Cursor::new(Vec::::new()); + rln.get_root(&mut buffer).unwrap(); + let (root_after_deletion, _) = bytes_le_to_fr(&buffer.into_inner()); + + assert_ne!(root_after_insertion, root_after_deletion); + } + + #[test] + fn test_atomic_operation_consistency() { + // Test duplicated from https://github.com/waku-org/go-zerokit-rln/pull/12/files + let tree_height = TEST_TREE_HEIGHT; + let no_of_leaves = 256; + + // We generate a vector of random leaves + let mut leaves: Vec = Vec::new(); + let mut rng = thread_rng(); + for _ in 0..no_of_leaves { + leaves.push(Fr::rand(&mut rng)); + } + + // We create a new tree + let input_buffer = + Cursor::new(json!({ "resources_folder": TEST_RESOURCES_FOLDER }).to_string()); + let mut rln = RLN::new(tree_height, input_buffer).unwrap(); + + // We add leaves in a batch into the tree + let mut buffer = Cursor::new(vec_fr_to_bytes_le(&leaves).unwrap()); + rln.init_tree_with_leaves(&mut buffer).unwrap(); + + // We check if number of leaves set is consistent + assert_eq!(rln.tree.leaves_set(), no_of_leaves); + + // We get the root of the tree obtained adding leaves in batch + let mut buffer = Cursor::new(Vec::::new()); + rln.get_root(&mut buffer).unwrap(); + let (root_after_insertion, _) = bytes_le_to_fr(&buffer.into_inner()); + + let set_index = rng.gen_range(0..no_of_leaves) as usize; + let indices = vec![set_index as u8]; + let zero_leaf: Vec = vec![]; + let indices_buffer = Cursor::new(vec_u8_to_bytes_le(&indices).unwrap()); + let leaves_buffer = Cursor::new(vec_fr_to_bytes_le(&zero_leaf).unwrap()); + rln.atomic_operation(0, leaves_buffer, indices_buffer) + .unwrap(); + + // We get the root of the tree obtained after a deletion + let mut buffer = Cursor::new(Vec::::new()); + rln.get_root(&mut buffer).unwrap(); + let (root_after_deletion, _) = bytes_le_to_fr(&buffer.into_inner()); + + assert_ne!(root_after_insertion, root_after_deletion); + + // We get the leaf + let mut output_buffer = Cursor::new(Vec::::new()); + rln.get_leaf(set_index, &mut output_buffer).unwrap(); + let (received_leaf, _) = bytes_le_to_fr(output_buffer.into_inner().as_ref()); + + assert_eq!(received_leaf, Fr::from(0)); + } + #[allow(unused_must_use)] #[test] // This test checks if `set_leaves_from` throws an error when the index is out of bounds diff --git a/rln/tests/ffi.rs b/rln/tests/ffi.rs index ec5e0b6..0e3dd82 100644 --- a/rln/tests/ffi.rs +++ b/rln/tests/ffi.rs @@ -265,7 +265,7 @@ mod test { let success = atomic_operation( rln_pointer, - no_of_leaves as usize, + last_leaf_index as usize, leaves_buffer, indices_buffer, );