From fba905f45dd407f1bbdc8a3d1d144e7825b6ba1b Mon Sep 17 00:00:00 2001 From: Aaryamann Challani <43716372+rymnc@users.noreply.github.com> Date: Thu, 10 Nov 2022 22:41:44 +0530 Subject: [PATCH] fix(rln): throw if attempting to insert out of bounds (#67) * fix(rln): throw if attempting to insert out of bounds * chore(rln): better error msg * fix(merkle-tree): make it mimic OptimalMerkleTree impl * chore(rln): return result as is --- rln/src/ffi.rs | 50 ++++++++++++++++++++++++++++ rln/src/public.rs | 45 ++++++++++++++++++++++--- utils/src/merkle_tree/merkle_tree.rs | 38 ++++++++++++++++++--- 3 files changed, 124 insertions(+), 9 deletions(-) diff --git a/rln/src/ffi.rs b/rln/src/ffi.rs index 4e12f87..5c0f4db 100644 --- a/rln/src/ffi.rs +++ b/rln/src/ffi.rs @@ -534,6 +534,56 @@ mod test { assert_eq!(root_batch_with_init, root_single_additions); } + + #[test] + // This test is similar to the one in public.rs but it uses the RLN object as a pointer + fn test_set_leaves_bad_index_ffi() { + let tree_height = TEST_TREE_HEIGHT; + let no_of_leaves = 256; + + // We generate a vector of random leaves + let mut leaves: Vec = Vec::new(); + let mut rng = thread_rng(); + for _ in 0..no_of_leaves { + leaves.push(Fr::rand(&mut rng)); + } + + let bad_index = (1 << tree_height) - rng.gen_range(0..no_of_leaves) as usize; + + // We create a RLN instance + let mut rln_pointer = MaybeUninit::<*mut RLN>::uninit(); + let input_buffer = &Buffer::from(TEST_RESOURCES_FOLDER.as_bytes()); + let success = new(tree_height, input_buffer, rln_pointer.as_mut_ptr()); + assert!(success, "RLN object creation failed"); + let rln_pointer = unsafe { &mut *rln_pointer.assume_init() }; + + // Get root of empty tree + let mut output_buffer = MaybeUninit::::uninit(); + let success = get_root(rln_pointer, output_buffer.as_mut_ptr()); + assert!(success, "get root call failed"); + + let output_buffer = unsafe { output_buffer.assume_init() }; + let result_data = <&[u8]>::from(&output_buffer).to_vec(); + let (root_empty, _) = bytes_le_to_fr(&result_data); + + // We add leaves in a batch into the tree + let leaves = vec_fr_to_bytes_le(&leaves); + let buffer = &Buffer::from(leaves.as_ref()); + let success = set_leaves_from(rln_pointer, bad_index, buffer); + assert!(!success, "set leaves from call succeeded"); + + // Get root of tree after attempted set + let mut output_buffer = MaybeUninit::::uninit(); + let success = get_root(rln_pointer, output_buffer.as_mut_ptr()); + assert!(success, "get root call failed"); + + let output_buffer = unsafe { output_buffer.assume_init() }; + let result_data = <&[u8]>::from(&output_buffer).to_vec(); + let (root_after_bad_set, _) = bytes_le_to_fr(&result_data); + + assert_eq!(root_empty, root_after_bad_set); + } + #[test] // This test is similar to the one in lib, but uses only public C API fn test_merkle_proof_ffi() { diff --git a/rln/src/public.rs b/rln/src/public.rs index ad35f4e..00e91f9 100644 --- a/rln/src/public.rs +++ b/rln/src/public.rs @@ -126,11 +126,7 @@ impl RLN<'_> { let (leaves, _) = bytes_le_to_vec_fr(&leaves_byte); // We set the leaves - for (i, leaf) in leaves.iter().enumerate() { - self.tree.set(i + index, *leaf)?; - } - - Ok(()) + return self.tree.set_range(index, leaves); } pub fn init_tree_with_leaves(&mut self, input_data: R) -> io::Result<()> { @@ -631,6 +627,45 @@ mod test { assert_eq!(root_batch_with_init, root_single_additions); } + #[test] + // This test checks if `set_leaves_from` throws an error when the index is out of bounds + fn test_set_leaves_bad_index() { + let tree_height = TEST_TREE_HEIGHT; + let no_of_leaves = 256; + + // We generate a vector of random leaves + let mut leaves: Vec = Vec::new(); + let mut rng = thread_rng(); + for _ in 0..no_of_leaves { + leaves.push(Fr::rand(&mut rng)); + } + let bad_index = (1 << tree_height) - rng.gen_range(0..no_of_leaves) as usize; + + // We create a new tree + let input_buffer = Cursor::new(TEST_RESOURCES_FOLDER); + let mut rln = RLN::new(tree_height, input_buffer); + + // Get root of empty tree + let mut buffer = Cursor::new(Vec::::new()); + rln.get_root(&mut buffer).unwrap(); + let (root_empty, _) = bytes_le_to_fr(&buffer.into_inner()); + + // We add leaves in a batch into the tree + let mut buffer = Cursor::new(vec_fr_to_bytes_le(&leaves)); + rln.set_leaves_from(bad_index, &mut buffer) + .expect_err("Should throw an error"); + + // We check if number of leaves set is consistent + assert_eq!(rln.tree.leaves_set(), 0); + + // Get the root of the tree + let mut buffer = Cursor::new(Vec::::new()); + rln.get_root(&mut buffer).unwrap(); + let (root_after_bad_set, _) = bytes_le_to_fr(&buffer.into_inner()); + + assert_eq!(root_empty, root_after_bad_set); + } + #[test] // This test is similar to the one in lib, but uses only public API fn test_merkle_proof() { diff --git a/utils/src/merkle_tree/merkle_tree.rs b/utils/src/merkle_tree/merkle_tree.rs index f1e18e5..4053920 100644 --- a/utils/src/merkle_tree/merkle_tree.rs +++ b/utils/src/merkle_tree/merkle_tree.rs @@ -127,6 +127,28 @@ impl OptimalMerkleTree { Ok(()) } + // Sets multiple leaves from the specified tree index + pub fn set_range>( + &mut self, + start: usize, + leaves: I, + ) -> io::Result<()> { + let leaves = leaves.into_iter().collect::>(); + // check if the range is valid + if start + leaves.len() > self.capacity() { + return Err(io::Error::new( + io::ErrorKind::InvalidInput, + "provided range exceeds set size", + )); + } + for (i, leaf) in leaves.iter().enumerate() { + self.nodes.insert((self.depth, start + i), *leaf); + self.recalculate_from(start + i); + } + self.next_index = max(self.next_index, start + leaves.len()); + Ok(()) + } + // Sets a leaf at the next available index pub fn update_next(&mut self, leaf: H::Fr) -> io::Result<()> { self.set(self.next_index, leaf)?; @@ -380,11 +402,19 @@ impl FullMerkleTree { ) -> io::Result<()> { let index = self.capacity() + start - 1; let mut count = 0; - // TODO: Error/panic when hashes is longer than available leafs - for (leaf, hash) in self.nodes[index..].iter_mut().zip(hashes) { - *leaf = hash; - count += 1; + // first count number of hashes, and check that they fit in the tree + // then insert into the tree + let hashes = hashes.into_iter().collect::>(); + if hashes.len() + start > self.capacity() { + return Err(io::Error::new( + io::ErrorKind::InvalidInput, + "provided hashes do not fit in the tree", + )); } + hashes.into_iter().for_each(|hash| { + self.nodes[index + count] = hash; + count += 1; + }); if count != 0 { self.update_nodes(index, index + (count - 1)); self.next_index = max(self.next_index, start + count);