add safe MT

This commit is contained in:
M Alghazwi 2024-10-08 14:21:12 +02:00
parent 8c1489b273
commit e326630e7b
4 changed files with 898 additions and 382 deletions

View File

@ -11,11 +11,11 @@ use plonky2::plonk::proof::ProofWithPublicInputs;
use std::marker::PhantomData;
use itertools::Itertools;
use crate::merkle_tree::MerkleTree;
use crate::merkle_tree::capped_tree::MerkleTree;
use plonky2::hash::poseidon::PoseidonHash;
use plonky2::hash::hash_types::{HashOutTarget, MerkleCapTarget, NUM_HASH_OUT_ELTS};
use crate::merkle_tree::{MerkleProof, MerkleProofTarget};
use crate::merkle_tree::capped_tree::{MerkleProof, MerkleProofTarget};
use plonky2_poseidon2::poseidon2_hash::poseidon2::{Poseidon2, Poseidon2Hash};
use plonky2::field::goldilocks_field::GoldilocksField;
@ -24,7 +24,7 @@ use plonky2::plonk::proof::Proof;
use plonky2::hash::hashing::PlonkyPermutation;
use plonky2::plonk::circuit_data::VerifierCircuitTarget;
use crate::merkle_tree::MerkleCap;
use crate::merkle_tree::capped_tree::MerkleCap;
// size of leaf data (in number of field elements)
pub const LEAF_LEN: usize = 4;
@ -359,7 +359,7 @@ pub mod tests {
use super::*;
use plonky2::field::types::Field;
use crate::merkle_tree::MerkleTree;
use crate::merkle_tree::capped_tree::MerkleTree;
use plonky2::iop::witness::{PartialWitness, WitnessWrite};
use plonky2::plonk::circuit_data::CircuitConfig;
use plonky2::plonk::config::{GenericConfig, PoseidonGoldilocksConfig};

View File

@ -0,0 +1,378 @@
// An adapted implementation of Merkle tree
// based on the original plonky2 merkle tree implementation
use core::mem::MaybeUninit;
use core::slice;
use anyhow::{ensure, Result};
use plonky2_maybe_rayon::*;
use serde::{Deserialize, Serialize};
use plonky2::hash::hash_types::{HashOutTarget, RichField};
use plonky2::plonk::config::{GenericHashOut, Hasher};
use plonky2::util::log2_strict;
#[derive(Clone, Debug, Serialize, Deserialize, Eq, PartialEq)]
#[serde(bound = "")]
pub struct MerkleCap<F: RichField, H: Hasher<F>>(pub Vec<H::Hash>);
impl<F: RichField, H: Hasher<F>> Default for MerkleCap<F, H> {
fn default() -> Self {
Self(Vec::new())
}
}
impl<F: RichField, H: Hasher<F>> MerkleCap<F, H> {
pub fn len(&self) -> usize {
self.0.len()
}
pub fn is_empty(&self) -> bool {
self.len() == 0
}
pub fn height(&self) -> usize {
log2_strict(self.len())
}
pub fn flatten(&self) -> Vec<F> {
self.0.iter().flat_map(|&h| h.to_vec()).collect()
}
}
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct MerkleTree<F: RichField, H: Hasher<F>> {
pub leaves: Vec<Vec<F>>,
pub digests: Vec<H::Hash>,
pub cap: MerkleCap<F, H>,
}
#[derive(Clone, Debug, Serialize, Deserialize, Eq, PartialEq)]
#[serde(bound = "")]
pub struct MerkleProof<F: RichField, H: Hasher<F>> {
/// The Merkle digest of each sibling subtree, staying from the bottommost layer.
pub siblings: Vec<H::Hash>,
}
impl<F: RichField, H: Hasher<F>> MerkleProof<F, H> {
pub fn len(&self) -> usize {
self.siblings.len()
}
pub fn is_empty(&self) -> bool {
self.len() == 0
}
}
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct MerkleProofTarget {
/// The Merkle digest of each sibling subtree, staying from the bottommost layer.
pub siblings: Vec<HashOutTarget>,
}
impl<F: RichField, H: Hasher<F>> Default for MerkleTree<F, H> {
fn default() -> Self {
Self {
leaves: Vec::new(),
digests: Vec::new(),
cap: MerkleCap::default(),
}
}
}
pub(crate) fn capacity_up_to_mut<T>(v: &mut Vec<T>, len: usize) -> &mut [MaybeUninit<T>] {
assert!(v.capacity() >= len);
let v_ptr = v.as_mut_ptr().cast::<MaybeUninit<T>>();
unsafe {
slice::from_raw_parts_mut(v_ptr, len)
}
}
pub(crate) fn fill_subtree<F: RichField, H: Hasher<F>>(
digests_buf: &mut [MaybeUninit<H::Hash>],
leaves: &[Vec<F>],
) -> H::Hash {
assert_eq!(leaves.len(), digests_buf.len() / 2 + 1);
if digests_buf.is_empty() {
H::hash_or_noop(&leaves[0])
} else {
let (left_digests_buf, right_digests_buf) = digests_buf.split_at_mut(digests_buf.len() / 2);
let (left_digest_mem, left_digests_buf) = left_digests_buf.split_last_mut().unwrap();
let (right_digest_mem, right_digests_buf) = right_digests_buf.split_first_mut().unwrap();
let (left_leaves, right_leaves) = leaves.split_at(leaves.len() / 2);
let (left_digest, right_digest) = plonky2_maybe_rayon::join(
|| fill_subtree::<F, H>(left_digests_buf, left_leaves),
|| fill_subtree::<F, H>(right_digests_buf, right_leaves),
);
left_digest_mem.write(left_digest);
right_digest_mem.write(right_digest);
H::two_to_one(left_digest, right_digest)
}
}
pub(crate) fn fill_digests_buf<F: RichField, H: Hasher<F>>(
digests_buf: &mut [MaybeUninit<H::Hash>],
cap_buf: &mut [MaybeUninit<H::Hash>],
leaves: &[Vec<F>],
cap_height: usize,
) {
if digests_buf.is_empty() {
debug_assert_eq!(cap_buf.len(), leaves.len());
cap_buf
.par_iter_mut()
.zip(leaves)
.for_each(|(cap_buf, leaf)| {
cap_buf.write(H::hash_or_noop(leaf));
});
return;
}
let subtree_digests_len = digests_buf.len() >> cap_height;
let subtree_leaves_len = leaves.len() >> cap_height;
let digests_chunks = digests_buf.par_chunks_exact_mut(subtree_digests_len);
let leaves_chunks = leaves.par_chunks_exact(subtree_leaves_len);
assert_eq!(digests_chunks.len(), cap_buf.len());
assert_eq!(digests_chunks.len(), leaves_chunks.len());
digests_chunks.zip(cap_buf).zip(leaves_chunks).for_each(
|((subtree_digests, subtree_cap), subtree_leaves)| {
subtree_cap.write(fill_subtree::<F, H>(subtree_digests, subtree_leaves));
},
);
}
pub(crate) fn merkle_tree_prove<F: RichField, H: Hasher<F>>(
leaf_index: usize,
leaves_len: usize,
cap_height: usize,
digests: &[H::Hash],
) -> Vec<H::Hash> {
let num_layers = log2_strict(leaves_len) - cap_height;
debug_assert_eq!(leaf_index >> (cap_height + num_layers), 0);
let digest_len = 2 * (leaves_len - (1 << cap_height));
assert_eq!(digest_len, digests.len());
let digest_tree: &[H::Hash] = {
let tree_index = leaf_index >> num_layers;
let tree_len = digest_len >> cap_height;
&digests[tree_len * tree_index..tree_len * (tree_index + 1)]
};
// Mask out high bits to get the index within the sub-tree.
let mut pair_index = leaf_index & ((1 << num_layers) - 1);
(0..num_layers)
.map(|i| {
let parity = pair_index & 1;
pair_index >>= 1;
// The layers' data is interleaved as follows:
// [layer 0, layer 1, layer 0, layer 2, layer 0, layer 1, layer 0, layer 3, ...].
// Each of the above is a pair of siblings.
// `pair_index` is the index of the pair within layer `i`.
// The index of that the pair within `digests` is
// `pair_index * 2 ** (i + 1) + (2 ** i - 1)`.
let siblings_index = (pair_index << (i + 1)) + (1 << i) - 1;
// We have an index for the _pair_, but we want the index of the _sibling_.
// Double the pair index to get the index of the left sibling. Conditionally add `1`
// if we are to retrieve the right sibling.
let sibling_index = 2 * siblings_index + (1 - parity);
digest_tree[sibling_index]
})
.collect()
}
impl<F: RichField, H: Hasher<F>> MerkleTree<F, H> {
pub fn new(leaves: Vec<Vec<F>>, cap_height: usize) -> Self {
let log2_leaves_len = log2_strict(leaves.len());
assert!(
cap_height <= log2_leaves_len,
"cap_height={} should be at most log2(leaves.len())={}",
cap_height,
log2_leaves_len
);
let num_digests = 2 * (leaves.len() - (1 << cap_height));
let mut digests = Vec::with_capacity(num_digests);
let len_cap = 1 << cap_height;
let mut cap = Vec::with_capacity(len_cap);
let digests_buf = capacity_up_to_mut(&mut digests, num_digests);
let cap_buf = capacity_up_to_mut(&mut cap, len_cap);
fill_digests_buf::<F, H>(digests_buf, cap_buf, &leaves[..], cap_height);
unsafe {
// SAFETY: `fill_digests_buf` and `cap` initialized the spare capacity up to
// `num_digests` and `len_cap`, resp.
digests.set_len(num_digests);
cap.set_len(len_cap);
}
Self {
leaves,
digests,
cap: MerkleCap(cap),
}
}
pub fn get(&self, i: usize) -> &[F] {
&self.leaves[i]
}
// Create a Merkle proof from a leaf index.
pub fn prove(&self, leaf_index: usize) -> MerkleProof<F, H> {
let cap_height = log2_strict(self.cap.len());
let siblings =
merkle_tree_prove::<F, H>(leaf_index, self.leaves.len(), cap_height, &self.digests);
MerkleProof { siblings }
}
}
/// Verifies that the given leaf data is present at the given index in the Merkle tree with the
/// given root.
pub fn verify_merkle_proof<F: RichField, H: Hasher<F>>(
leaf_data: Vec<F>,
leaf_index: usize,
merkle_root: H::Hash,
proof: &MerkleProof<F, H>,
) -> Result<()> {
let merkle_cap = MerkleCap(vec![merkle_root]);
verify_merkle_proof_to_cap(leaf_data, leaf_index, &merkle_cap, proof)
}
/// Verifies that the given leaf data is present at the given index in the Merkle tree with the
/// given cap.
pub fn verify_merkle_proof_to_cap<F: RichField, H: Hasher<F>>(
leaf_data: Vec<F>,
leaf_index: usize,
merkle_cap: &MerkleCap<F, H>,
proof: &MerkleProof<F, H>,
) -> Result<()> {
verify_batch_merkle_proof_to_cap(
&[leaf_data.clone()],
&[proof.siblings.len()],
leaf_index,
merkle_cap,
proof,
)
}
/// Verifies that the given leaf data is present at the given index in the Field Merkle tree with the
/// given cap.
pub fn verify_batch_merkle_proof_to_cap<F: RichField, H: Hasher<F>>(
leaf_data: &[Vec<F>],
leaf_heights: &[usize],
mut leaf_index: usize,
merkle_cap: &MerkleCap<F, H>,
proof: &MerkleProof<F, H>,
) -> Result<()> {
assert_eq!(leaf_data.len(), leaf_heights.len());
let mut current_digest = H::hash_or_noop(&leaf_data[0]);
let mut current_height = leaf_heights[0];
let mut leaf_data_index = 1;
for &sibling_digest in &proof.siblings {
let bit = leaf_index & 1;
leaf_index >>= 1;
current_digest = if bit == 1 {
H::two_to_one(sibling_digest, current_digest)
} else {
H::two_to_one(current_digest, sibling_digest)
};
current_height -= 1;
if leaf_data_index < leaf_heights.len() && current_height == leaf_heights[leaf_data_index] {
let mut new_leaves = current_digest.to_vec();
new_leaves.extend_from_slice(&leaf_data[leaf_data_index]);
current_digest = H::hash_or_noop(&new_leaves);
leaf_data_index += 1;
}
}
assert_eq!(leaf_data_index, leaf_data.len());
ensure!(
current_digest == merkle_cap.0[leaf_index],
"Invalid Merkle proof."
);
Ok(())
}
#[cfg(test)]
pub(crate) mod tests {
use anyhow::Result;
use super::*;
use plonky2::field::extension::Extendable;
use crate::merkle_tree::capped_tree::verify_merkle_proof_to_cap;
use plonky2::plonk::config::{GenericConfig, PoseidonGoldilocksConfig};
pub(crate) fn random_data<F: RichField>(n: usize, k: usize) -> Vec<Vec<F>> {
(0..n).map(|_| F::rand_vec(k)).collect()
}
fn verify_all_leaves<
F: RichField + Extendable<D>,
C: GenericConfig<D, F = F>,
const D: usize,
>(
leaves: Vec<Vec<F>>,
cap_height: usize,
) -> Result<()> {
let tree = MerkleTree::<F, C::Hasher>::new(leaves.clone(), cap_height);
for (i, leaf) in leaves.into_iter().enumerate() {
let proof = tree.prove(i);
verify_merkle_proof_to_cap(leaf, i, &tree.cap, &proof)?;
}
Ok(())
}
#[test]
#[should_panic]
fn test_cap_height_too_big() {
const D: usize = 2;
type C = PoseidonGoldilocksConfig;
type F = <C as GenericConfig<D>>::F;
let log_n = 8;
let cap_height = log_n + 1; // Should panic if `cap_height > len_n`.
let leaves = random_data::<F>(1 << log_n, 7);
let _ = MerkleTree::<F, <C as GenericConfig<D>>::Hasher>::new(leaves, cap_height);
}
#[test]
fn test_cap_height_eq_log2_len() -> Result<()> {
const D: usize = 2;
type C = PoseidonGoldilocksConfig;
type F = <C as GenericConfig<D>>::F;
let log_n = 8;
let n = 1 << log_n;
let leaves = random_data::<F>(n, 7);
verify_all_leaves::<F, C, D>(leaves, log_n)?;
Ok(())
}
#[test]
fn test_merkle_trees() -> Result<()> {
const D: usize = 2;
type C = PoseidonGoldilocksConfig;
type F = <C as GenericConfig<D>>::F;
let log_n = 8;
let n = 1 << log_n;
let leaves = random_data::<F>(n, 7);
verify_all_leaves::<F, C, D>(leaves, 1)?;
Ok(())
}
}

View File

@ -0,0 +1,514 @@
// Implementation of "safe" merkle tree
// consistent with the one in codex:
// https://github.com/codex-storage/nim-codex/blob/master/codex/merkletree/merkletree.nim
use anyhow::{ensure, Result};
use plonky2::field::goldilocks_field::GoldilocksField;
use plonky2::hash::hash_types::{HashOut, HashOutTarget, RichField};
use plonky2::hash::poseidon::PoseidonHash;
use plonky2::plonk::config::Hasher;
use std::ops::Shr;
use plonky2_field::types::Field;
// Constants for the keys used in compression
const KEY_NONE: u64 = 0x0;
const KEY_BOTTOM_LAYER: u64 = 0x1;
const KEY_ODD: u64 = 0x2;
const KEY_ODD_AND_BOTTOM_LAYER: u64 = 0x3;
/// Trait for a hash function that supports keyed compression.
pub trait KeyedHasher<F: RichField>: Hasher<F> {
fn compress(x: Self::Hash, y: Self::Hash, key: u64) -> Self::Hash;
}
impl KeyedHasher<GoldilocksField> for PoseidonHash {
fn compress(x: Self::Hash, y: Self::Hash, key: u64) -> Self::Hash {
let key_field = GoldilocksField::from_canonical_u64(key);
let mut inputs = Vec::new();
inputs.extend_from_slice(&x.elements);
inputs.extend_from_slice(&y.elements);
inputs.push(key_field);
PoseidonHash::hash_no_pad(&inputs) // TODO: double-check this function
}
}
/// Merkle tree struct, containing the layers, compression function, and zero hash.
#[derive(Clone)]
pub struct MerkleTree<F: RichField, H: KeyedHasher<F>> {
layers: Vec<Vec<H::Hash>>,
compress: fn(H::Hash, H::Hash, u64) -> H::Hash,
zero: H::Hash,
}
impl<F: RichField, H: KeyedHasher<F>> MerkleTree<F, H> {
/// Constructs a new Merkle tree from the given leaves.
pub fn new(
leaves: &[H::Hash],
zero: H::Hash,
compress: fn(H::Hash, H::Hash, u64) -> H::Hash,
) -> Result<Self> {
let layers = merkle_tree_worker::<F,H>(leaves, zero, compress, true)?;
Ok(Self {
layers,
compress,
zero,
})
}
/// Returns the depth of the Merkle tree.
pub fn depth(&self) -> usize {
self.layers.len() - 1
}
/// Returns the number of leaves in the Merkle tree.
pub fn leaves_count(&self) -> usize {
self.layers[0].len()
}
/// Returns the root hash of the Merkle tree.
pub fn root(&self) -> Result<H::Hash> {
let last_layer = self.layers.last().ok_or_else(|| anyhow::anyhow!("Empty tree"))?;
ensure!(last_layer.len() == 1, "Invalid Merkle tree");
Ok(last_layer[0])
}
/// Generates a Merkle proof for a given leaf index.
pub fn get_proof(&self, index: usize) -> Result<MerkleProof<F, H>> {
let depth = self.depth();
let nleaves = self.leaves_count();
ensure!(index < nleaves, "Index out of bounds");
let mut path = Vec::with_capacity(depth);
let mut k = index;
let mut m = nleaves;
for i in 0..depth {
let j = k ^ 1;
let sibling = if j < m {
self.layers[i][j]
} else {
self.zero
};
path.push(sibling);
k = k >> 1;
m = (m + 1) >> 1;
}
Ok(MerkleProof {
index,
path,
nleaves,
compress: self.compress,
zero: self.zero,
})
}
}
/// Build the Merkle tree layers.
fn merkle_tree_worker<F: RichField, H: KeyedHasher<F>>(
xs: &[H::Hash],
zero: H::Hash,
compress: fn(H::Hash, H::Hash, u64) -> H::Hash,
is_bottom_layer: bool,
) -> Result<Vec<Vec<H::Hash>>> {
let m = xs.len();
if !is_bottom_layer && m == 1 {
return Ok(vec![xs.to_vec()]);
}
let halfn = m / 2;
let n = 2 * halfn;
let is_odd = n != m;
let mut ys = Vec::with_capacity(halfn + if is_odd { 1 } else { 0 });
for i in 0..halfn {
let key = if is_bottom_layer { KEY_BOTTOM_LAYER } else { KEY_NONE };
let h = compress(xs[2 * i], xs[2 * i + 1], key);
ys.push(h);
}
if is_odd {
let key = if is_bottom_layer {
KEY_ODD_AND_BOTTOM_LAYER
} else {
KEY_ODD
};
let h = compress(xs[n], zero, key);
ys.push(h);
}
let mut layers = vec![xs.to_vec()];
let mut upper_layers = merkle_tree_worker::<F,H>(&ys, zero, compress, false)?;
layers.append(&mut upper_layers);
Ok(layers)
}
/// Merkle proof struct, containing the index, path, and other necessary data.
#[derive(Clone)]
pub struct MerkleProof<F: RichField, H: KeyedHasher<F>> {
pub index: usize, // Index of the leaf
pub path: Vec<H::Hash>, // Sibling hashes from the leaf to the root
pub nleaves: usize, // Total number of leaves
pub compress: fn(H::Hash, H::Hash, u64) -> H::Hash, // compression function - TODO: make it generic instead
pub zero: H::Hash,
}
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct MerkleProofTarget {
/// The Merkle digest of each sibling subtree, staying from the bottommost layer.
pub path: Vec<HashOutTarget>,
}
impl<F: RichField, H: KeyedHasher<F>> MerkleProof<F, H> {
/// Reconstructs the root hash from the proof and the given leaf.
pub fn reconstruct_root(&self, leaf: H::Hash) -> Result<H::Hash> {
let mut m = self.nleaves;
let mut j = self.index;
let mut h = leaf;
let mut bottom_flag = KEY_BOTTOM_LAYER;
for p in &self.path {
let odd_index = (j & 1) != 0;
if odd_index {
// The index of the child is odd
h = (self.compress)(*p, h, bottom_flag);
} else {
if j == m - 1 {
// Single child -> so odd node
h = (self.compress)(h, *p, bottom_flag + 2);
} else {
// Even node
h = (self.compress)(h, *p, bottom_flag);
}
}
bottom_flag = KEY_NONE;
j = j.shr(1);
m = (m + 1).shr(1);
}
Ok(h)
}
/// Verifies the proof against a given root and leaf.
pub fn verify(&self, leaf: H::Hash, root: H::Hash) -> Result<bool> {
let reconstructed_root = self.reconstruct_root(leaf)?;
Ok(reconstructed_root == root)
}
}
#[cfg(test)]
mod tests {
use super::*;
use plonky2::field::types::Field;
// Constants for the keys used in compression
// const KEY_NONE: u64 = 0x0;
// const KEY_BOTTOM_LAYER: u64 = 0x1;
// const KEY_ODD: u64 = 0x2;
// const KEY_ODD_AND_BOTTOM_LAYER: u64 = 0x3;
fn compress(
x: HashOut<GoldilocksField>,
y: HashOut<GoldilocksField>,
key: u64,
) -> HashOut<GoldilocksField> {
let key_field = GoldilocksField::from_canonical_u64(key);
let mut inputs = Vec::new();
inputs.extend_from_slice(&x.elements);
inputs.extend_from_slice(&y.elements);
inputs.push(key_field);
PoseidonHash::hash_no_pad(&inputs)
}
fn make_tree(
data: &[GoldilocksField],
zero: HashOut<GoldilocksField>,
) -> Result<MerkleTree<GoldilocksField, PoseidonHash>> {
let compress_fn = PoseidonHash::compress;
// Hash the data to obtain leaf hashes
let leaves: Vec<HashOut<GoldilocksField>> = data
.iter()
.map(|&element| {
// Hash each field element to get the leaf hash
PoseidonHash::hash_no_pad(&[element])
})
.collect();
MerkleTree::<GoldilocksField, PoseidonHash>::new(&leaves, zero, compress_fn)
}
#[test]
fn single_proof_test() -> Result<()> {
let data = (1u64..=8)
.map(|i| GoldilocksField::from_canonical_u64(i))
.collect::<Vec<_>>();
// Hash the data to obtain leaf hashes
let leaves: Vec<HashOut<GoldilocksField>> = data
.iter()
.map(|&element| {
// Hash each field element to get the leaf hash
PoseidonHash::hash_no_pad(&[element])
})
.collect();
let zero = HashOut {
elements: [GoldilocksField::ZERO; 4],
};
let compress_fn = PoseidonHash::compress;
// Build the Merkle tree
let tree = MerkleTree::<GoldilocksField, PoseidonHash>::new(&leaves, zero, compress_fn)?;
// Get the root
let root = tree.root()?;
// Get a proof for the first leaf
let proof = tree.get_proof(0)?;
// Verify the proof
let is_valid = proof.verify(leaves[0], root)?;
assert!(is_valid, "Merkle proof verification failed");
Ok(())
}
#[test]
fn test_correctness_even_bottom_layer() -> Result<()> {
// Data for the test (field elements)
let data = (1u64..=8)
.map(|i| GoldilocksField::from_canonical_u64(i))
.collect::<Vec<_>>();
// Hash the data to get leaf hashes
let leaf_hashes: Vec<HashOut<GoldilocksField>> = data
.iter()
.map(|&element| PoseidonHash::hash_no_pad(&[element]))
.collect();
// zero hash
let zero = HashOut {
elements: [GoldilocksField::ZERO; 4],
};
let expected_root =
compress(
compress(
compress(
leaf_hashes[0],
leaf_hashes[1],
KEY_BOTTOM_LAYER,
),
compress(
leaf_hashes[2],
leaf_hashes[3],
KEY_BOTTOM_LAYER,
),
KEY_NONE,
),
compress(
compress(
leaf_hashes[4],
leaf_hashes[5],
KEY_BOTTOM_LAYER,
),
compress(
leaf_hashes[6],
leaf_hashes[7],
KEY_BOTTOM_LAYER,
),
KEY_NONE,
),
KEY_NONE,
);
// Build the tree
let tree = make_tree(&data, zero)?;
// Get the computed root
let computed_root = tree.root()?;
// Check that the computed root matches the expected root
assert_eq!(computed_root, expected_root);
Ok(())
}
#[test]
fn test_correctness_odd_bottom_layer() -> Result<()> {
// Data for the test (field elements)
let data = (1u64..=7)
.map(|i| GoldilocksField::from_canonical_u64(i))
.collect::<Vec<_>>();
// Hash the data to get leaf hashes
let leaf_hashes: Vec<HashOut<GoldilocksField>> = data
.iter()
.map(|&element| PoseidonHash::hash_no_pad(&[element]))
.collect();
// zero hash
let zero = HashOut {
elements: [GoldilocksField::ZERO; 4],
};
let expected_root =
compress(
compress(
compress(
leaf_hashes[0],
leaf_hashes[1],
KEY_BOTTOM_LAYER,
),
compress(
leaf_hashes[2],
leaf_hashes[3],
KEY_BOTTOM_LAYER,
),
KEY_NONE,
),
compress(
compress(
leaf_hashes[4],
leaf_hashes[5],
KEY_BOTTOM_LAYER,
),
compress(
leaf_hashes[6],
zero,
KEY_ODD_AND_BOTTOM_LAYER,
),
KEY_NONE,
),
KEY_NONE,
);
// Build the tree
let tree = make_tree(&data, zero)?;
// Get the computed root
let computed_root = tree.root()?;
// Check that the computed root matches the expected root
assert_eq!(computed_root, expected_root);
Ok(())
}
#[test]
fn test_correctness_even_bottom_odd_upper_layers() -> Result<()> {
// Data for the test (field elements)
let data = (1u64..=10)
.map(|i| GoldilocksField::from_canonical_u64(i))
.collect::<Vec<_>>();
// Hash the data to get leaf hashes
let leaf_hashes: Vec<HashOut<GoldilocksField>> = data
.iter()
.map(|&element| PoseidonHash::hash_no_pad(&[element]))
.collect();
// zero hash
let zero = HashOut {
elements: [GoldilocksField::ZERO; 4],
};
let expected_root = compress(
compress(
compress(
compress(
leaf_hashes[0],
leaf_hashes[1],
KEY_BOTTOM_LAYER,
),
compress(
leaf_hashes[2],
leaf_hashes[3],
KEY_BOTTOM_LAYER,
),
KEY_NONE,
),
compress(
compress(
leaf_hashes[4],
leaf_hashes[5],
KEY_BOTTOM_LAYER,
),
compress(
leaf_hashes[6],
leaf_hashes[7],
KEY_BOTTOM_LAYER,
),
KEY_NONE,
),
KEY_NONE,
),
compress(
compress(
compress(
leaf_hashes[8],
leaf_hashes[9],
KEY_BOTTOM_LAYER,
),
zero,
KEY_ODD,
),
zero,
KEY_ODD,
),
KEY_NONE,
);
// Build the tree
let tree = make_tree(&data, zero)?;
// Get the computed root
let computed_root = tree.root()?;
// Check that the computed root matches the expected root
assert_eq!(computed_root, expected_root);
Ok(())
}
#[test]
fn test_proofs() -> Result<()> {
// Data for the test (field elements)
let data = (1u64..=10)
.map(|i| GoldilocksField::from_canonical_u64(i))
.collect::<Vec<_>>();
// Hash the data to get leaf hashes
let leaf_hashes: Vec<HashOut<GoldilocksField>> = data
.iter()
.map(|&element| PoseidonHash::hash_no_pad(&[element]))
.collect();
// zero hash
let zero = HashOut {
elements: [GoldilocksField::ZERO; 4],
};
let compress_fn = PoseidonHash::compress;
// Build the tree
let tree = MerkleTree::<GoldilocksField, PoseidonHash>::new(&leaf_hashes, zero, compress_fn)?;
// Get the root
let expected_root = tree.root()?;
// Verify proofs for all leaves
for (i, &leaf_hash) in leaf_hashes.iter().enumerate() {
let proof = tree.get_proof(i)?;
let is_valid = proof.verify(leaf_hash, expected_root)?;
assert!(is_valid, "Proof verification failed for leaf {}", i);
}
Ok(())
}
}

View File

@ -1,378 +1,2 @@
// An adapted implementation of Merkle tree
// based on the original plonky2 merkle tree implementation
use core::mem::MaybeUninit;
use core::slice;
use anyhow::{ensure, Result};
use plonky2_maybe_rayon::*;
use serde::{Deserialize, Serialize};
use plonky2::hash::hash_types::{HashOutTarget, RichField};
use plonky2::plonk::config::{GenericHashOut, Hasher};
use plonky2::util::log2_strict;
#[derive(Clone, Debug, Serialize, Deserialize, Eq, PartialEq)]
#[serde(bound = "")]
pub struct MerkleCap<F: RichField, H: Hasher<F>>(pub Vec<H::Hash>);
impl<F: RichField, H: Hasher<F>> Default for MerkleCap<F, H> {
fn default() -> Self {
Self(Vec::new())
}
}
impl<F: RichField, H: Hasher<F>> MerkleCap<F, H> {
pub fn len(&self) -> usize {
self.0.len()
}
pub fn is_empty(&self) -> bool {
self.len() == 0
}
pub fn height(&self) -> usize {
log2_strict(self.len())
}
pub fn flatten(&self) -> Vec<F> {
self.0.iter().flat_map(|&h| h.to_vec()).collect()
}
}
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct MerkleTree<F: RichField, H: Hasher<F>> {
pub leaves: Vec<Vec<F>>,
pub digests: Vec<H::Hash>,
pub cap: MerkleCap<F, H>,
}
#[derive(Clone, Debug, Serialize, Deserialize, Eq, PartialEq)]
#[serde(bound = "")]
pub struct MerkleProof<F: RichField, H: Hasher<F>> {
/// The Merkle digest of each sibling subtree, staying from the bottommost layer.
pub siblings: Vec<H::Hash>,
}
impl<F: RichField, H: Hasher<F>> MerkleProof<F, H> {
pub fn len(&self) -> usize {
self.siblings.len()
}
pub fn is_empty(&self) -> bool {
self.len() == 0
}
}
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct MerkleProofTarget {
/// The Merkle digest of each sibling subtree, staying from the bottommost layer.
pub siblings: Vec<HashOutTarget>,
}
impl<F: RichField, H: Hasher<F>> Default for MerkleTree<F, H> {
fn default() -> Self {
Self {
leaves: Vec::new(),
digests: Vec::new(),
cap: MerkleCap::default(),
}
}
}
pub(crate) fn capacity_up_to_mut<T>(v: &mut Vec<T>, len: usize) -> &mut [MaybeUninit<T>] {
assert!(v.capacity() >= len);
let v_ptr = v.as_mut_ptr().cast::<MaybeUninit<T>>();
unsafe {
slice::from_raw_parts_mut(v_ptr, len)
}
}
pub(crate) fn fill_subtree<F: RichField, H: Hasher<F>>(
digests_buf: &mut [MaybeUninit<H::Hash>],
leaves: &[Vec<F>],
) -> H::Hash {
assert_eq!(leaves.len(), digests_buf.len() / 2 + 1);
if digests_buf.is_empty() {
H::hash_or_noop(&leaves[0])
} else {
let (left_digests_buf, right_digests_buf) = digests_buf.split_at_mut(digests_buf.len() / 2);
let (left_digest_mem, left_digests_buf) = left_digests_buf.split_last_mut().unwrap();
let (right_digest_mem, right_digests_buf) = right_digests_buf.split_first_mut().unwrap();
let (left_leaves, right_leaves) = leaves.split_at(leaves.len() / 2);
let (left_digest, right_digest) = plonky2_maybe_rayon::join(
|| fill_subtree::<F, H>(left_digests_buf, left_leaves),
|| fill_subtree::<F, H>(right_digests_buf, right_leaves),
);
left_digest_mem.write(left_digest);
right_digest_mem.write(right_digest);
H::two_to_one(left_digest, right_digest)
}
}
pub(crate) fn fill_digests_buf<F: RichField, H: Hasher<F>>(
digests_buf: &mut [MaybeUninit<H::Hash>],
cap_buf: &mut [MaybeUninit<H::Hash>],
leaves: &[Vec<F>],
cap_height: usize,
) {
if digests_buf.is_empty() {
debug_assert_eq!(cap_buf.len(), leaves.len());
cap_buf
.par_iter_mut()
.zip(leaves)
.for_each(|(cap_buf, leaf)| {
cap_buf.write(H::hash_or_noop(leaf));
});
return;
}
let subtree_digests_len = digests_buf.len() >> cap_height;
let subtree_leaves_len = leaves.len() >> cap_height;
let digests_chunks = digests_buf.par_chunks_exact_mut(subtree_digests_len);
let leaves_chunks = leaves.par_chunks_exact(subtree_leaves_len);
assert_eq!(digests_chunks.len(), cap_buf.len());
assert_eq!(digests_chunks.len(), leaves_chunks.len());
digests_chunks.zip(cap_buf).zip(leaves_chunks).for_each(
|((subtree_digests, subtree_cap), subtree_leaves)| {
subtree_cap.write(fill_subtree::<F, H>(subtree_digests, subtree_leaves));
},
);
}
pub(crate) fn merkle_tree_prove<F: RichField, H: Hasher<F>>(
leaf_index: usize,
leaves_len: usize,
cap_height: usize,
digests: &[H::Hash],
) -> Vec<H::Hash> {
let num_layers = log2_strict(leaves_len) - cap_height;
debug_assert_eq!(leaf_index >> (cap_height + num_layers), 0);
let digest_len = 2 * (leaves_len - (1 << cap_height));
assert_eq!(digest_len, digests.len());
let digest_tree: &[H::Hash] = {
let tree_index = leaf_index >> num_layers;
let tree_len = digest_len >> cap_height;
&digests[tree_len * tree_index..tree_len * (tree_index + 1)]
};
// Mask out high bits to get the index within the sub-tree.
let mut pair_index = leaf_index & ((1 << num_layers) - 1);
(0..num_layers)
.map(|i| {
let parity = pair_index & 1;
pair_index >>= 1;
// The layers' data is interleaved as follows:
// [layer 0, layer 1, layer 0, layer 2, layer 0, layer 1, layer 0, layer 3, ...].
// Each of the above is a pair of siblings.
// `pair_index` is the index of the pair within layer `i`.
// The index of that the pair within `digests` is
// `pair_index * 2 ** (i + 1) + (2 ** i - 1)`.
let siblings_index = (pair_index << (i + 1)) + (1 << i) - 1;
// We have an index for the _pair_, but we want the index of the _sibling_.
// Double the pair index to get the index of the left sibling. Conditionally add `1`
// if we are to retrieve the right sibling.
let sibling_index = 2 * siblings_index + (1 - parity);
digest_tree[sibling_index]
})
.collect()
}
impl<F: RichField, H: Hasher<F>> MerkleTree<F, H> {
pub fn new(leaves: Vec<Vec<F>>, cap_height: usize) -> Self {
let log2_leaves_len = log2_strict(leaves.len());
assert!(
cap_height <= log2_leaves_len,
"cap_height={} should be at most log2(leaves.len())={}",
cap_height,
log2_leaves_len
);
let num_digests = 2 * (leaves.len() - (1 << cap_height));
let mut digests = Vec::with_capacity(num_digests);
let len_cap = 1 << cap_height;
let mut cap = Vec::with_capacity(len_cap);
let digests_buf = capacity_up_to_mut(&mut digests, num_digests);
let cap_buf = capacity_up_to_mut(&mut cap, len_cap);
fill_digests_buf::<F, H>(digests_buf, cap_buf, &leaves[..], cap_height);
unsafe {
// SAFETY: `fill_digests_buf` and `cap` initialized the spare capacity up to
// `num_digests` and `len_cap`, resp.
digests.set_len(num_digests);
cap.set_len(len_cap);
}
Self {
leaves,
digests,
cap: MerkleCap(cap),
}
}
pub fn get(&self, i: usize) -> &[F] {
&self.leaves[i]
}
// Create a Merkle proof from a leaf index.
pub fn prove(&self, leaf_index: usize) -> MerkleProof<F, H> {
let cap_height = log2_strict(self.cap.len());
let siblings =
merkle_tree_prove::<F, H>(leaf_index, self.leaves.len(), cap_height, &self.digests);
MerkleProof { siblings }
}
}
/// Verifies that the given leaf data is present at the given index in the Merkle tree with the
/// given root.
pub fn verify_merkle_proof<F: RichField, H: Hasher<F>>(
leaf_data: Vec<F>,
leaf_index: usize,
merkle_root: H::Hash,
proof: &MerkleProof<F, H>,
) -> Result<()> {
let merkle_cap = MerkleCap(vec![merkle_root]);
verify_merkle_proof_to_cap(leaf_data, leaf_index, &merkle_cap, proof)
}
/// Verifies that the given leaf data is present at the given index in the Merkle tree with the
/// given cap.
pub fn verify_merkle_proof_to_cap<F: RichField, H: Hasher<F>>(
leaf_data: Vec<F>,
leaf_index: usize,
merkle_cap: &MerkleCap<F, H>,
proof: &MerkleProof<F, H>,
) -> Result<()> {
verify_batch_merkle_proof_to_cap(
&[leaf_data.clone()],
&[proof.siblings.len()],
leaf_index,
merkle_cap,
proof,
)
}
/// Verifies that the given leaf data is present at the given index in the Field Merkle tree with the
/// given cap.
pub fn verify_batch_merkle_proof_to_cap<F: RichField, H: Hasher<F>>(
leaf_data: &[Vec<F>],
leaf_heights: &[usize],
mut leaf_index: usize,
merkle_cap: &MerkleCap<F, H>,
proof: &MerkleProof<F, H>,
) -> Result<()> {
assert_eq!(leaf_data.len(), leaf_heights.len());
let mut current_digest = H::hash_or_noop(&leaf_data[0]);
let mut current_height = leaf_heights[0];
let mut leaf_data_index = 1;
for &sibling_digest in &proof.siblings {
let bit = leaf_index & 1;
leaf_index >>= 1;
current_digest = if bit == 1 {
H::two_to_one(sibling_digest, current_digest)
} else {
H::two_to_one(current_digest, sibling_digest)
};
current_height -= 1;
if leaf_data_index < leaf_heights.len() && current_height == leaf_heights[leaf_data_index] {
let mut new_leaves = current_digest.to_vec();
new_leaves.extend_from_slice(&leaf_data[leaf_data_index]);
current_digest = H::hash_or_noop(&new_leaves);
leaf_data_index += 1;
}
}
assert_eq!(leaf_data_index, leaf_data.len());
ensure!(
current_digest == merkle_cap.0[leaf_index],
"Invalid Merkle proof."
);
Ok(())
}
#[cfg(test)]
pub(crate) mod tests {
use anyhow::Result;
use super::*;
use plonky2::field::extension::Extendable;
use crate::merkle_tree::verify_merkle_proof_to_cap;
use plonky2::plonk::config::{GenericConfig, PoseidonGoldilocksConfig};
pub(crate) fn random_data<F: RichField>(n: usize, k: usize) -> Vec<Vec<F>> {
(0..n).map(|_| F::rand_vec(k)).collect()
}
fn verify_all_leaves<
F: RichField + Extendable<D>,
C: GenericConfig<D, F = F>,
const D: usize,
>(
leaves: Vec<Vec<F>>,
cap_height: usize,
) -> Result<()> {
let tree = MerkleTree::<F, C::Hasher>::new(leaves.clone(), cap_height);
for (i, leaf) in leaves.into_iter().enumerate() {
let proof = tree.prove(i);
verify_merkle_proof_to_cap(leaf, i, &tree.cap, &proof)?;
}
Ok(())
}
#[test]
#[should_panic]
fn test_cap_height_too_big() {
const D: usize = 2;
type C = PoseidonGoldilocksConfig;
type F = <C as GenericConfig<D>>::F;
let log_n = 8;
let cap_height = log_n + 1; // Should panic if `cap_height > len_n`.
let leaves = random_data::<F>(1 << log_n, 7);
let _ = MerkleTree::<F, <C as GenericConfig<D>>::Hasher>::new(leaves, cap_height);
}
#[test]
fn test_cap_height_eq_log2_len() -> Result<()> {
const D: usize = 2;
type C = PoseidonGoldilocksConfig;
type F = <C as GenericConfig<D>>::F;
let log_n = 8;
let n = 1 << log_n;
let leaves = random_data::<F>(n, 7);
verify_all_leaves::<F, C, D>(leaves, log_n)?;
Ok(())
}
#[test]
fn test_merkle_trees() -> Result<()> {
const D: usize = 2;
type C = PoseidonGoldilocksConfig;
type F = <C as GenericConfig<D>>::F;
let log_n = 8;
let n = 1 << log_n;
let leaves = random_data::<F>(n, 7);
verify_all_leaves::<F, C, D>(leaves, 1)?;
Ok(())
}
}
pub mod capped_tree;
pub mod merkle_safe;