fix(rln): atomic operation edge case (#195)

* fix(rln): atomic operation edge case

* fmt

* fix: bug

* test: new batching mechanism

* Revert "test: new batching mechanism"

This reverts commit 396c2ec342.

* fix: end should be max index + 1

* fix: optimization

* fix: apply cleanup

* fix: idiomatic leaf setting

* fix: abstract out logic

* fix: type aliasing for verbose types

* fix: remove_indices_and_set_leaves fn
This commit is contained in:
Aaryamann Challani 2023-08-01 18:06:52 +05:30 committed by GitHub
parent 6d58320077
commit d1414a44c5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 200 additions and 56 deletions

View File

@ -8,6 +8,7 @@ members = [
"rln-wasm", "rln-wasm",
"utils", "utils",
] ]
resolver = "2"
# Compilation profile for any non-workspace member. # Compilation profile for any non-workspace member.
# Dependencies are optimized, even in a dev build. This improves dev performance # Dependencies are optimized, even in a dev build. This improves dev performance

View File

@ -1,15 +1,17 @@
use crate::circuit::Fr;
use crate::hashers::{poseidon_hash, PoseidonHash};
use crate::utils::{bytes_le_to_fr, fr_to_bytes_le};
use color_eyre::{Report, Result};
use serde_json::Value;
use std::collections::HashSet;
use std::fmt::Debug; use std::fmt::Debug;
use std::path::PathBuf; use std::path::PathBuf;
use std::str::FromStr; use std::str::FromStr;
use color_eyre::{Report, Result};
use serde_json::Value;
use utils::pmtree::{Database, Hasher}; use utils::pmtree::{Database, Hasher};
use utils::*; use utils::*;
use crate::circuit::Fr;
use crate::hashers::{poseidon_hash, PoseidonHash};
use crate::utils::{bytes_le_to_fr, fr_to_bytes_le};
const METADATA_KEY: [u8; 8] = *b"metadata"; const METADATA_KEY: [u8; 8] = *b"metadata";
pub struct PmTree { pub struct PmTree {
@ -25,13 +27,9 @@ pub struct PmTreeProof {
pub type FrOf<H> = <H as Hasher>::Fr; pub type FrOf<H> = <H as Hasher>::Fr;
// The pmtree Hasher trait used by pmtree Merkle tree // The pmtree Hasher trait used by pmtree Merkle tree
impl pmtree::Hasher for PoseidonHash { impl Hasher for PoseidonHash {
type Fr = Fr; type Fr = Fr;
fn default_leaf() -> Self::Fr {
Fr::from(0)
}
fn serialize(value: Self::Fr) -> pmtree::Value { fn serialize(value: Self::Fr) -> pmtree::Value {
fr_to_bytes_le(&value) fr_to_bytes_le(&value)
} }
@ -41,12 +39,16 @@ impl pmtree::Hasher for PoseidonHash {
fr fr
} }
fn default_leaf() -> Self::Fr {
Fr::from(0)
}
fn hash(inputs: &[Self::Fr]) -> Self::Fr { fn hash(inputs: &[Self::Fr]) -> Self::Fr {
poseidon_hash(inputs) poseidon_hash(inputs)
} }
} }
fn get_tmp_path() -> std::path::PathBuf { fn get_tmp_path() -> PathBuf {
std::env::temp_dir().join(format!("pmtree-{}", rand::random::<u64>())) std::env::temp_dir().join(format!("pmtree-{}", rand::random::<u64>()))
} }
@ -54,7 +56,7 @@ fn get_tmp() -> bool {
true true
} }
pub struct PmtreeConfig(pm_tree::Config); pub struct PmtreeConfig(Config);
impl FromStr for PmtreeConfig { impl FromStr for PmtreeConfig {
type Err = Report; type Err = Report;
@ -85,7 +87,7 @@ impl FromStr for PmtreeConfig {
))); )));
} }
let config = pm_tree::Config::new() let config = Config::new()
.temporary(temporary.unwrap_or(get_tmp())) .temporary(temporary.unwrap_or(get_tmp()))
.path(path.unwrap_or(get_tmp_path())) .path(path.unwrap_or(get_tmp_path()))
.cache_capacity(cache_capacity.unwrap_or(1024 * 1024 * 1024)) .cache_capacity(cache_capacity.unwrap_or(1024 * 1024 * 1024))
@ -100,7 +102,7 @@ impl Default for PmtreeConfig {
fn default() -> Self { fn default() -> Self {
let tmp_path = get_tmp_path(); let tmp_path = get_tmp_path();
PmtreeConfig( PmtreeConfig(
pm_tree::Config::new() Config::new()
.temporary(true) .temporary(true)
.path(tmp_path) .path(tmp_path)
.cache_capacity(150_000) .cache_capacity(150_000)
@ -145,10 +147,6 @@ impl ZerokitMerkleTree for PmTree {
}) })
} }
fn close_db_connection(&mut self) -> Result<()> {
self.tree.db.close().map_err(|e| Report::msg(e.to_string()))
}
fn depth(&self) -> usize { fn depth(&self) -> usize {
self.tree.depth() self.tree.depth()
} }
@ -165,16 +163,16 @@ impl ZerokitMerkleTree for PmTree {
self.tree.root() self.tree.root()
} }
fn compute_root(&mut self) -> Result<FrOf<Self::Hasher>> {
Ok(self.tree.root())
}
fn set(&mut self, index: usize, leaf: FrOf<Self::Hasher>) -> Result<()> { fn set(&mut self, index: usize, leaf: FrOf<Self::Hasher>) -> Result<()> {
self.tree self.tree
.set(index, leaf) .set(index, leaf)
.map_err(|e| Report::msg(e.to_string())) .map_err(|e| Report::msg(e.to_string()))
} }
fn get(&self, index: usize) -> Result<FrOf<Self::Hasher>> {
self.tree.get(index).map_err(|e| Report::msg(e.to_string()))
}
fn set_range<I: IntoIterator<Item = FrOf<Self::Hasher>>>( fn set_range<I: IntoIterator<Item = FrOf<Self::Hasher>>>(
&mut self, &mut self,
start: usize, start: usize,
@ -185,6 +183,10 @@ impl ZerokitMerkleTree for PmTree {
.map_err(|e| Report::msg(e.to_string())) .map_err(|e| Report::msg(e.to_string()))
} }
fn get(&self, index: usize) -> Result<FrOf<Self::Hasher>> {
self.tree.get(index).map_err(|e| Report::msg(e.to_string()))
}
fn override_range<I: IntoIterator<Item = FrOf<Self::Hasher>>, J: IntoIterator<Item = usize>>( fn override_range<I: IntoIterator<Item = FrOf<Self::Hasher>>, J: IntoIterator<Item = usize>>(
&mut self, &mut self,
start: usize, start: usize,
@ -192,33 +194,15 @@ impl ZerokitMerkleTree for PmTree {
indices: J, indices: J,
) -> Result<()> { ) -> Result<()> {
let leaves = leaves.into_iter().collect::<Vec<_>>(); let leaves = leaves.into_iter().collect::<Vec<_>>();
let indices = indices.into_iter().collect::<HashSet<_>>(); let mut indices = indices.into_iter().collect::<Vec<_>>();
let end = start + leaves.len(); indices.sort();
if leaves.len() + start - indices.len() > self.capacity() { match (leaves.is_empty(), indices.is_empty()) {
return Err(Report::msg("index out of bounds")); (true, true) => Err(Report::msg("no leaves or indices to be removed")),
(false, true) => self.set_range_with_leaves(start, leaves),
(true, false) => self.remove_indices(indices),
(false, false) => self.remove_indices_and_set_leaves(start, leaves, indices),
} }
// extend the range to include indices to be removed
let min_index = indices.iter().min().unwrap_or(&start);
let max_index = indices.iter().max().unwrap_or(&end);
let mut new_leaves = Vec::new();
// insert leaves into new_leaves
for i in *min_index..*max_index {
if indices.contains(&i) {
// insert 0
new_leaves.push(Self::Hasher::default_leaf());
} else {
// insert leaf
new_leaves.push(leaves[i - start]);
}
}
self.tree
.set_range(start, new_leaves)
.map_err(|e| Report::msg(e.to_string()))
} }
fn update_next(&mut self, leaf: FrOf<Self::Hasher>) -> Result<()> { fn update_next(&mut self, leaf: FrOf<Self::Hasher>) -> Result<()> {
@ -246,10 +230,6 @@ impl ZerokitMerkleTree for PmTree {
} }
} }
fn compute_root(&mut self) -> Result<FrOf<Self::Hasher>> {
Ok(self.tree.root())
}
fn set_metadata(&mut self, metadata: &[u8]) -> Result<()> { fn set_metadata(&mut self, metadata: &[u8]) -> Result<()> {
self.tree.db.put(METADATA_KEY, metadata.to_vec())?; self.tree.db.put(METADATA_KEY, metadata.to_vec())?;
self.metadata = metadata.to_vec(); self.metadata = metadata.to_vec();
@ -268,6 +248,70 @@ impl ZerokitMerkleTree for PmTree {
} }
Ok(data.unwrap()) Ok(data.unwrap())
} }
fn close_db_connection(&mut self) -> Result<()> {
self.tree.db.close().map_err(|e| Report::msg(e.to_string()))
}
}
type PmTreeHasher = <PmTree as ZerokitMerkleTree>::Hasher;
type FrOfPmTreeHasher = FrOf<PmTreeHasher>;
impl PmTree {
fn set_range_with_leaves(&mut self, start: usize, leaves: Vec<FrOfPmTreeHasher>) -> Result<()> {
self.tree
.set_range(start, leaves)
.map_err(|e| Report::msg(e.to_string()))
}
fn remove_indices(&mut self, indices: Vec<usize>) -> Result<()> {
let start = indices[0];
let end = indices.last().unwrap() + 1;
let mut new_leaves: Vec<_> = (start..end)
.map(|i| self.tree.get(i))
.collect::<Result<_, _>>()?;
new_leaves
.iter_mut()
.take(indices.len())
.for_each(|leaf| *leaf = PmTreeHasher::default_leaf());
self.tree
.set_range(start, new_leaves)
.map_err(|e| Report::msg(e.to_string()))
}
fn remove_indices_and_set_leaves(
&mut self,
start: usize,
leaves: Vec<FrOfPmTreeHasher>,
indices: Vec<usize>,
) -> Result<()> {
let min_index = *indices.first().unwrap();
let max_index = start + leaves.len();
// Generated a placeholder with the exact size needed,
// Initiated with default values to be overridden throughout the method
let mut set_values = vec![PmTreeHasher::default_leaf(); max_index - min_index];
// If the index is not in indices list, keep the original value
for i in min_index..start {
if !indices.contains(&i) {
let value = self.tree.get(i)?;
set_values[i - min_index] = value;
}
}
// Insert new leaves after 'start' position
for (i, &leaf) in leaves.iter().enumerate() {
set_values[start - min_index + i] = leaf;
}
self.tree
.set_range(min_index, set_values)
.map_err(|e| Report::msg(e.to_string()))
}
} }
impl ZerokitMerkleProof for PmTreeProof { impl ZerokitMerkleProof for PmTreeProof {

View File

@ -362,7 +362,7 @@ impl RLN<'_> {
// We set the leaves // We set the leaves
self.tree self.tree
.override_range(index, leaves, indices) .override_range(index, leaves, indices)
.map_err(|_| Report::msg("Could not perform the batch operation"))?; .map_err(|e| Report::msg(format!("Could not perform the batch operation: {e}")))?;
Ok(()) Ok(())
} }
@ -1387,7 +1387,7 @@ mod test {
assert_eq!(root_batch_with_init, root_single_additions); assert_eq!(root_batch_with_init, root_single_additions);
rln.flush(); rln.flush().unwrap();
} }
#[test] #[test]
@ -1430,7 +1430,7 @@ mod test {
let indices_buffer = Cursor::new(vec_u8_to_bytes_le(&indices).unwrap()); let indices_buffer = Cursor::new(vec_u8_to_bytes_le(&indices).unwrap());
let leaves_buffer = Cursor::new(vec_fr_to_bytes_le(&last_leaf).unwrap()); let leaves_buffer = Cursor::new(vec_fr_to_bytes_le(&last_leaf).unwrap());
rln.atomic_operation(no_of_leaves, leaves_buffer, indices_buffer) rln.atomic_operation(last_leaf_index, leaves_buffer, indices_buffer)
.unwrap(); .unwrap();
// We get the root of the tree obtained after a no-op // We get the root of the tree obtained after a no-op
@ -1441,6 +1441,105 @@ mod test {
assert_eq!(root_after_insertion, root_after_noop); assert_eq!(root_after_insertion, root_after_noop);
} }
#[test]
fn test_atomic_operation_zero_indexed() {
// Test duplicated from https://github.com/waku-org/go-zerokit-rln/pull/12/files
let tree_height = TEST_TREE_HEIGHT;
let no_of_leaves = 256;
// We generate a vector of random leaves
let mut leaves: Vec<Fr> = Vec::new();
let mut rng = thread_rng();
for _ in 0..no_of_leaves {
leaves.push(Fr::rand(&mut rng));
}
// We create a new tree
let input_buffer =
Cursor::new(json!({ "resources_folder": TEST_RESOURCES_FOLDER }).to_string());
let mut rln = RLN::new(tree_height, input_buffer).unwrap();
// We add leaves in a batch into the tree
let mut buffer = Cursor::new(vec_fr_to_bytes_le(&leaves).unwrap());
rln.init_tree_with_leaves(&mut buffer).unwrap();
// We check if number of leaves set is consistent
assert_eq!(rln.tree.leaves_set(), no_of_leaves);
// We get the root of the tree obtained adding leaves in batch
let mut buffer = Cursor::new(Vec::<u8>::new());
rln.get_root(&mut buffer).unwrap();
let (root_after_insertion, _) = bytes_le_to_fr(&buffer.into_inner());
let zero_index = 0;
let indices = vec![zero_index as u8];
let zero_leaf: Vec<Fr> = vec![];
let indices_buffer = Cursor::new(vec_u8_to_bytes_le(&indices).unwrap());
let leaves_buffer = Cursor::new(vec_fr_to_bytes_le(&zero_leaf).unwrap());
rln.atomic_operation(0, leaves_buffer, indices_buffer)
.unwrap();
// We get the root of the tree obtained after a deletion
let mut buffer = Cursor::new(Vec::<u8>::new());
rln.get_root(&mut buffer).unwrap();
let (root_after_deletion, _) = bytes_le_to_fr(&buffer.into_inner());
assert_ne!(root_after_insertion, root_after_deletion);
}
#[test]
fn test_atomic_operation_consistency() {
// Test duplicated from https://github.com/waku-org/go-zerokit-rln/pull/12/files
let tree_height = TEST_TREE_HEIGHT;
let no_of_leaves = 256;
// We generate a vector of random leaves
let mut leaves: Vec<Fr> = Vec::new();
let mut rng = thread_rng();
for _ in 0..no_of_leaves {
leaves.push(Fr::rand(&mut rng));
}
// We create a new tree
let input_buffer =
Cursor::new(json!({ "resources_folder": TEST_RESOURCES_FOLDER }).to_string());
let mut rln = RLN::new(tree_height, input_buffer).unwrap();
// We add leaves in a batch into the tree
let mut buffer = Cursor::new(vec_fr_to_bytes_le(&leaves).unwrap());
rln.init_tree_with_leaves(&mut buffer).unwrap();
// We check if number of leaves set is consistent
assert_eq!(rln.tree.leaves_set(), no_of_leaves);
// We get the root of the tree obtained adding leaves in batch
let mut buffer = Cursor::new(Vec::<u8>::new());
rln.get_root(&mut buffer).unwrap();
let (root_after_insertion, _) = bytes_le_to_fr(&buffer.into_inner());
let set_index = rng.gen_range(0..no_of_leaves) as usize;
let indices = vec![set_index as u8];
let zero_leaf: Vec<Fr> = vec![];
let indices_buffer = Cursor::new(vec_u8_to_bytes_le(&indices).unwrap());
let leaves_buffer = Cursor::new(vec_fr_to_bytes_le(&zero_leaf).unwrap());
rln.atomic_operation(0, leaves_buffer, indices_buffer)
.unwrap();
// We get the root of the tree obtained after a deletion
let mut buffer = Cursor::new(Vec::<u8>::new());
rln.get_root(&mut buffer).unwrap();
let (root_after_deletion, _) = bytes_le_to_fr(&buffer.into_inner());
assert_ne!(root_after_insertion, root_after_deletion);
// We get the leaf
let mut output_buffer = Cursor::new(Vec::<u8>::new());
rln.get_leaf(set_index, &mut output_buffer).unwrap();
let (received_leaf, _) = bytes_le_to_fr(output_buffer.into_inner().as_ref());
assert_eq!(received_leaf, Fr::from(0));
}
#[allow(unused_must_use)] #[allow(unused_must_use)]
#[test] #[test]
// This test checks if `set_leaves_from` throws an error when the index is out of bounds // This test checks if `set_leaves_from` throws an error when the index is out of bounds

View File

@ -265,7 +265,7 @@ mod test {
let success = atomic_operation( let success = atomic_operation(
rln_pointer, rln_pointer,
no_of_leaves as usize, last_leaf_index as usize,
leaves_buffer, leaves_buffer,
indices_buffer, indices_buffer,
); );