test: new batching mechanism

This commit is contained in:
rymnc 2023-08-01 00:23:28 +05:30
parent d8796e33b0
commit 396c2ec342
No known key found for this signature in database
GPG Key ID: AAA088D5C68ECD34
7 changed files with 135 additions and 165 deletions

View File

@ -3,11 +3,11 @@ use crate::hashers::{poseidon_hash, PoseidonHash};
use crate::utils::{bytes_le_to_fr, fr_to_bytes_le};
use color_eyre::{Report, Result};
use serde_json::Value;
use std::collections::HashSet;
use std::collections::{HashMap, HashSet};
use std::fmt::Debug;
use std::path::PathBuf;
use std::str::FromStr;
use utils::pmtree::{Database, Hasher};
use utils::pmtree::{DBKey, Database, Hasher};
use utils::*;
const METADATA_KEY: [u8; 8] = *b"metadata";
@ -192,57 +192,42 @@ impl ZerokitMerkleTree for PmTree {
indices: J,
) -> Result<()> {
let leaves = leaves.into_iter().collect::<Vec<_>>();
let indices = indices.into_iter().collect::<Vec<_>>();
let end = start + leaves.len() + indices.len();
let indices = indices.into_iter().collect::<HashSet<_>>();
let mut subtree = HashMap::<usize, Fr>::new();
let leaves_len = leaves.len();
let leaves_set = self.tree.leaves_set();
// handle each case appropriately -
// case 1: both leaves and indices to be removed are passed in
// case 2: only leaves are passed in
// case 3: only indices are passed in
// case 4: neither leaves nor indices are passed in
match (leaves.len(), indices.len()) {
(0, 0) => Err(Report::msg("no leaves or indices to be removed")),
(0, _) => {
// case 3
// remove indices
let mut new_leaves = Vec::new();
let start = start + indices[0];
let end = start + indices.len();
for _ in start..end {
// Insert 0
new_leaves.push(Self::Hasher::default_leaf());
}
self.tree
.set_range(start, new_leaves)
.map_err(|e| Report::msg(e.to_string()))
}
(_, 0) => {
// case 2
// insert leaves
self.tree
.set_range(start, leaves)
.map_err(|e| Report::msg(e.to_string()))
}
(_, _) => {
// case 1
// remove indices
let mut new_leaves = Vec::new();
let indices = indices.into_iter().collect::<HashSet<_>>();
let new_start = start + leaves.len();
for i in new_start..=end {
if indices.contains(&i) {
// Insert 0
new_leaves.push(Self::Hasher::default_leaf());
} else if let Some(leaf) = leaves.get(i - new_start) {
// Insert leaf
new_leaves.push(*leaf);
}
}
self.tree
.set_range(start, new_leaves)
.map_err(|e| Report::msg(e.to_string()))
}
dbg!(self.tree.root());
// insert the old leaves
for i in 0..leaves_set {
let leaf = self.tree.get(i)?;
subtree.insert(i, leaf);
}
// zero out the leaves to be removed
for index in indices {
if index >= leaves_set {
return Err(Report::msg(format!(
"Index {} is out of bounds, leaves_set: {}",
index, leaves_set
)));
}
subtree.insert(index, Self::Hasher::default_leaf());
}
// insert the new leaves from start
for i in start..(start + leaves_len) {
let leaf = leaves[i - start];
subtree.insert(i, leaf);
}
// Use set_range with the values collected in `subtree` to update the tree.
let res = self
.tree
.set_range(0, subtree.into_iter().map(|(_, v)| v))
.map_err(|e| Report::msg(e.to_string()));
dbg!(self.tree.root());
return res;
}
fn update_next(&mut self, leaf: FrOf<Self::Hasher>) -> Result<()> {

View File

@ -1438,6 +1438,12 @@ mod test {
rln.get_root(&mut buffer).unwrap();
let (root_after_noop, _) = bytes_le_to_fr(&buffer.into_inner());
let mut output_buffer = Cursor::new(Vec::<u8>::new());
rln.get_leaf(last_leaf_index, &mut output_buffer).unwrap();
let (received_leaf, _) = bytes_le_to_fr(output_buffer.into_inner().as_ref());
assert_eq!(received_leaf, last_leaf[0]);
assert_eq!(root_after_insertion, root_after_noop);
}

View File

@ -3,7 +3,7 @@ use hex_literal::hex;
use tiny_keccak::{Hasher as _, Keccak};
use zerokit_utils::{
FullMerkleConfig, FullMerkleTree, Hasher, OptimalMerkleConfig, OptimalMerkleTree,
ZerokitMerkleTree,
ZerokitMerkleTree, BatchOf,
};
#[derive(Clone, Copy, Eq, PartialEq)]
@ -50,9 +50,16 @@ pub fn optimal_merkle_tree_benchmark(c: &mut Criterion) {
})
});
c.bench_function("OptimalMerkleTree::override_range", |b| {
c.bench_function("OptimalMerkleTree::set_range", |b| {
b.iter(|| {
tree.override_range(0, leaves, [0, 1, 2, 3]).unwrap();
let mut batch = BatchOf::<OptimalMerkleTree<Keccak256>>::new();
for i in 0..leaves.len() {
batch.insert(i, leaves[i]);
}
for i in [0, 1, 2, 3] {
batch.remove(&i);
}
tree.set_range(&batch).unwrap();
})
});
@ -94,7 +101,14 @@ pub fn full_merkle_tree_benchmark(c: &mut Criterion) {
c.bench_function("FullMerkleTree::override_range", |b| {
b.iter(|| {
tree.override_range(0, leaves, [0, 1, 2, 3]).unwrap();
let mut batch = BatchOf::<FullMerkleTree<Keccak256>>::new();
for i in 0..leaves.len() {
batch.insert(i, leaves[i]);
}
for i in [0, 1, 2, 3] {
batch.remove(&i);
}
tree.set_range(&batch).unwrap();
})
});

View File

@ -1,10 +1,10 @@
use crate::merkle_tree::{FrOf, Hasher, ZerokitMerkleProof, ZerokitMerkleTree};
use crate::{merkle_tree::{FrOf, Hasher, ZerokitMerkleProof, ZerokitMerkleTree}, merkle_tree::Batch};
use color_eyre::{Report, Result};
use std::{
cmp::max,
fmt::Debug,
iter::{once, repeat, successors},
str::FromStr,
iter::{repeat, successors},
str::FromStr, collections::HashMap,
};
////////////////////////////////////////////////////////////
@ -59,6 +59,30 @@ impl FromStr for FullMerkleConfig {
}
}
/// `HashMap`-backed batch: each entry maps a leaf index to the value queued for that index.
impl<H> Batch<H> for HashMap<usize, FrOf<H>>
where
    H: Hasher,
{
    type Key = usize;

    /// Queues `value` for index `key`, replacing any value already queued there.
    fn insert(&mut self, key: usize, value: FrOf<H>) {
        // Fully qualified so it is clear this is the inherent `HashMap::insert`
        // and not a recursive call to this trait method.
        HashMap::insert(self, key, value);
    }

    /// Drops the entry queued for index `key`, if there is one.
    fn remove(&mut self, key: usize) {
        // Same reasoning as in `insert`: call the inherent `HashMap::remove` explicitly.
        HashMap::remove(self, &key);
    }

    /// Returns the largest index in the batch, or 0 when the batch is empty.
    fn max_index(&self) -> usize {
        self.keys().copied().max().unwrap_or(0)
    }

    /// Returns the smallest index in the batch, or 0 when the batch is empty.
    fn min_index(&self) -> usize {
        self.keys().copied().min().unwrap_or(0)
    }
}
/// Implementations
impl<H: Hasher> ZerokitMerkleTree for FullMerkleTree<H>
where
@ -67,6 +91,7 @@ where
type Proof = FullMerkleProof<H>;
type Hasher = H;
type Config = FullMerkleConfig;
type Batch = HashMap<usize, FrOf<Self::Hasher>>;
fn default(depth: usize) -> Result<Self> {
FullMerkleTree::<H>::new(depth, Self::Hasher::default_leaf(), Self::Config::default())
@ -128,7 +153,12 @@ where
// Sets a leaf at the specified tree index
fn set(&mut self, leaf: usize, hash: FrOf<Self::Hasher>) -> Result<()> {
self.set_range(leaf, once(hash))?;
if leaf >= self.capacity() {
return Err(Report::msg("leaf index out of bounds"));
}
let capacity = self.capacity();
self.nodes[capacity + leaf - 1] = hash;
self.update_nodes(capacity + leaf - 1, capacity + leaf - 1)?;
self.next_index = max(self.next_index, leaf + 1);
Ok(())
}
@ -143,59 +173,18 @@ where
// Sets tree nodes, starting from start index
// Function proper of FullMerkleTree implementation
fn set_range<I: IntoIterator<Item = FrOf<Self::Hasher>>>(
fn set_range(
&mut self,
start: usize,
hashes: I,
batch: &Self::Batch,
) -> Result<()> {
let index = self.capacity() + start - 1;
let mut count = 0;
// first count number of hashes, and check that they fit in the tree
// then insert into the tree
let hashes = hashes.into_iter().collect::<Vec<_>>();
if hashes.len() + start > self.capacity() {
return Err(Report::msg("provided hashes do not fit in the tree"));
}
hashes.into_iter().for_each(|hash| {
self.nodes[index + count] = hash;
count += 1;
});
if count != 0 {
self.update_nodes(index, index + (count - 1))?;
self.next_index = max(self.next_index, start + count);
}
Ok(())
}
fn override_range<I, J>(&mut self, start: usize, leaves: I, to_remove_indices: J) -> Result<()>
where
I: IntoIterator<Item = FrOf<Self::Hasher>>,
J: IntoIterator<Item = usize>,
{
let index = self.capacity() + start - 1;
let mut count = 0;
let leaves = leaves.into_iter().collect::<Vec<_>>();
let to_remove_indices = to_remove_indices.into_iter().collect::<Vec<_>>();
// first count number of hashes, and check that they fit in the tree
// then insert into the tree
if leaves.len() + start - to_remove_indices.len() > self.capacity() {
if batch.len() > self.capacity() {
return Err(Report::msg("provided hashes do not fit in the tree"));
}
// remove leaves
for i in &to_remove_indices {
self.delete(*i)?;
}
// insert new leaves
for hash in leaves {
self.nodes[index + count] = hash;
count += 1;
}
if count != 0 {
self.update_nodes(index, index + (count - 1))?;
self.next_index = max(self.next_index, start + count - to_remove_indices.len());
for (key, value) in batch {
self.set(*key, *value)?;
}
Ok(())
}

View File

@ -30,7 +30,17 @@ pub trait Hasher {
fn hash(input: &[Self::Fr]) -> Self::Fr;
}
pub trait Batch<H: Hasher> where H:Hasher {
type Key;
fn insert(&mut self, key: usize, value: H::Fr);
fn remove(&mut self, key: usize);
fn max_index(&self) -> usize;
fn min_index(&self) -> usize;
}
pub type FrOf<H> = <H as Hasher>::Fr;
pub type BatchOf<Tree> = <Tree as ZerokitMerkleTree>::Batch;
/// The ZerokitMerkleTree trait defines the methods that every Merkle tree implementation must provide,
/// including OptimalMerkleTree and FullMerkleTree.
@ -38,6 +48,7 @@ pub trait ZerokitMerkleTree {
type Proof: ZerokitMerkleProof;
type Hasher: Hasher;
type Config: Default + FromStr;
type Batch: Batch<Self::Hasher>;
fn default(depth: usize) -> Result<Self>
where
@ -51,14 +62,8 @@ pub trait ZerokitMerkleTree {
fn root(&self) -> FrOf<Self::Hasher>;
fn compute_root(&mut self) -> Result<FrOf<Self::Hasher>>;
fn set(&mut self, index: usize, leaf: FrOf<Self::Hasher>) -> Result<()>;
fn set_range<I>(&mut self, start: usize, leaves: I) -> Result<()>
where
I: IntoIterator<Item = FrOf<Self::Hasher>>;
fn set_range(&mut self, batch: &Self::Batch) -> Result<()>;
fn get(&self, index: usize) -> Result<FrOf<Self::Hasher>>;
fn override_range<I, J>(&mut self, start: usize, leaves: I, to_remove_indices: J) -> Result<()>
where
I: IntoIterator<Item = FrOf<Self::Hasher>>,
J: IntoIterator<Item = usize>;
fn update_next(&mut self, leaf: FrOf<Self::Hasher>) -> Result<()>;
fn delete(&mut self, index: usize) -> Result<()>;
fn proof(&self, index: usize) -> Result<Self::Proof>;

View File

@ -60,6 +60,7 @@ where
type Proof = OptimalMerkleProof<H>;
type Hasher = H;
type Config = OptimalMerkleConfig;
type Batch = HashMap<usize, FrOf<H>>;
fn default(depth: usize) -> Result<Self> {
OptimalMerkleTree::<H>::new(depth, H::default_leaf(), Self::Config::default())
@ -128,47 +129,15 @@ where
}
// Sets multiple leaves from the specified tree index
fn set_range<I: IntoIterator<Item = H::Fr>>(&mut self, start: usize, leaves: I) -> Result<()> {
let leaves = leaves.into_iter().collect::<Vec<_>>();
fn set_range(&mut self, batch: &Self::Batch) -> Result<()> {
// check if the range is valid
if start + leaves.len() > self.capacity() {
return Err(Report::msg("provided range exceeds set size"));
}
for (i, leaf) in leaves.iter().enumerate() {
self.nodes.insert((self.depth, start + i), *leaf);
self.recalculate_from(start + i)?;
}
self.next_index = max(self.next_index, start + leaves.len());
Ok(())
}
fn override_range<I, J>(&mut self, start: usize, leaves: I, to_remove_indices: J) -> Result<()>
where
I: IntoIterator<Item = FrOf<Self::Hasher>>,
J: IntoIterator<Item = usize>,
{
let leaves = leaves.into_iter().collect::<Vec<_>>();
let to_remove_indices = to_remove_indices.into_iter().collect::<Vec<_>>();
// check if the range is valid
if leaves.len() + start - to_remove_indices.len() > self.capacity() {
if batch.len() > self.capacity() {
return Err(Report::msg("provided range exceeds set size"));
}
// remove leaves
for i in &to_remove_indices {
self.delete(*i)?;
for (key, value) in batch {
self.set(*key, *value)?;
}
// add leaves
for (i, leaf) in leaves.iter().enumerate() {
self.nodes.insert((self.depth, start + i), *leaf);
self.recalculate_from(start + i)?;
}
self.next_index = max(
self.next_index,
start + leaves.len() - to_remove_indices.len(),
);
Ok(())
}

View File

@ -1,11 +1,13 @@
// Tests adapted from https://github.com/worldcoin/semaphore-rs/blob/d462a4372f1fd9c27610f2acfe4841fab1d396aa/src/merkle_tree.rs
#[cfg(test)]
mod test {
use std::collections::HashMap;
use hex_literal::hex;
use tiny_keccak::{Hasher as _, Keccak};
use zerokit_utils::{
FullMerkleConfig, FullMerkleTree, Hasher, OptimalMerkleConfig, OptimalMerkleTree,
ZerokitMerkleProof, ZerokitMerkleTree,
ZerokitMerkleProof, ZerokitMerkleTree, BatchOf,
};
#[derive(Clone, Copy, Eq, PartialEq)]
struct Keccak256;
@ -139,27 +141,27 @@ mod test {
OptimalMerkleTree::<Keccak256>::new(2, [0; 32], OptimalMerkleConfig::default())
.unwrap();
// We set the leaves
tree.set_range(0, initial_leaves.iter().cloned()).unwrap();
// We set the leaves in a batch
// Batch = HashMap<index, leaf>
let batch = initial_leaves
.iter()
.enumerate()
.map(|(i, leaf)| (i, *leaf))
.collect::<HashMap<_, _>>();
tree.set_range(&batch).unwrap();
let new_leaves = [
hex!("0000000000000000000000000000000000000000000000000000000000000005"),
hex!("0000000000000000000000000000000000000000000000000000000000000006"),
];
let to_delete_indices: [usize; 2] = [0, 1];
let mut new_batch = BatchOf::<OptimalMerkleTree<Keccak256>>::new();
new_batch.remove(&0);
new_batch.remove(&1);
new_batch.insert(tree.leaves_set() - 2, hex!("0000000000000000000000000000000000000000000000000000000000000005"));
new_batch.insert(tree.leaves_set() - 1, hex!("0000000000000000000000000000000000000000000000000000000000000006"));
// We override the leaves
tree.override_range(
0, // start from the end of the initial leaves
new_leaves.iter().cloned(),
to_delete_indices.iter().cloned(),
)
.unwrap();
tree.set_range(&new_batch).unwrap();
// ensure that the leaves are set correctly
for i in 0..new_leaves.len() {
assert_eq!(tree.get_leaf(i), new_leaves[i]);
for (i, leaf) in new_batch {
assert_eq!(tree.get_leaf(i), leaf);
}
}
}