major refactor

This commit is contained in:
M Alghazwi 2024-11-05 12:57:49 +01:00
parent 4c3f2043ee
commit 645b30fa96
22 changed files with 1184 additions and 2149 deletions

View File

@ -5,15 +5,9 @@ This crate is an implementation of the [codex storage proofs circuits](https://g
## Code organization
- [`capped_tree`](./src/merkle_tree/capped_tree.rs) is an adapted implementation of Merkle tree based on the original plonky2 merkle tree implementation.
- [`capped_tree_circuit`](./src/circuits/capped_tree_circuit.rs) is the circuit implementation for regular merkle tree implementation (non-safe version) based on the above merkle tree implementation.
- [`merkle_safe`](./src/merkle_tree/merkle_safe.rs) is the implementation of "safe" merkle tree used in codex, consistent with the one [here](https://github.com/codex-storage/nim-codex/blob/master/codex/merkletree/merkletree.nim).
- [`safe_tree_circuit`](./src/circuits/safe_tree_circuit.rs) is the Plonky2 Circuit implementation of "safe" merkle tree above.
- [`prove_single_cell`](./src/circuits/prove_single_cell.rs) is the Plonky2 Circuit implementation for proving a single cell in slot merkle tree.
- [`merkle_circuit`](./src/circuits/merkle_circuit) is the Plonky2 Circuit implementation of "safe" merkle tree above.
- [`sample_cells`](./src/circuits/sample_cells.rs) is the Plonky2 Circuit implementation for sampling cells in dataset merkle tree.

View File

@ -4,7 +4,7 @@ use std::time::{Duration, Instant};
use codex_plonky2_circuits::{
merkle_tree::merkle_safe::MerkleProof,
circuits::safe_tree_circuit::MerkleTreeCircuit,
circuits::merkle_circuit::MerkleTreeCircuit,
};
use plonky2::plonk::circuit_data::{CircuitConfig, CircuitData};
use plonky2::plonk::config::{AlgebraicHasher, GenericConfig, Hasher, PoseidonGoldilocksConfig};

View File

@ -1,7 +1,7 @@
use criterion::{criterion_group, criterion_main, Criterion};
use anyhow::Result;
use codex_plonky2_circuits::{merkle_tree::merkle_safe::MerkleTree, circuits::safe_tree_circuit::MerkleTreeCircuit};
use codex_plonky2_circuits::{merkle_tree::merkle_safe::MerkleTree, circuits::merkle_circuit::MerkleTreeCircuit};
use plonky2::field::types::Field;
use plonky2::plonk::circuit_data::{CircuitConfig, CircuitData};
use plonky2::plonk::config::{AlgebraicHasher, GenericConfig, Hasher, PoseidonGoldilocksConfig};

View File

@ -1,446 +0,0 @@
// circuit for regular merkle tree implementation (non-safe version)
// the circuit uses caps in similar way as in Plonky2 Merkle tree implementation
// NOTE: this might be deleted at later time, since we don't use it for codex
use anyhow::Result;
use plonky2::field::extension::Extendable;
use plonky2::hash::hash_types::RichField;
use plonky2::iop::target::{BoolTarget, Target};
use plonky2::iop::witness::{PartialWitness, WitnessWrite, Witness};
use plonky2::plonk::circuit_builder::CircuitBuilder;
use plonky2::plonk::circuit_data::{CircuitConfig, CircuitData, VerifierCircuitData};
use plonky2::plonk::config::{AlgebraicHasher, GenericConfig, Hasher, GenericHashOut};
use plonky2::plonk::proof::ProofWithPublicInputs;
use std::marker::PhantomData;
use itertools::Itertools;
use crate::merkle_tree::capped_tree::MerkleTree;
use plonky2::hash::poseidon::PoseidonHash;
use plonky2::hash::hash_types::{HashOutTarget, MerkleCapTarget, NUM_HASH_OUT_ELTS};
use crate::merkle_tree::capped_tree::MerkleProofTarget;
use plonky2_poseidon2::poseidon2_hash::poseidon2::Poseidon2;
use plonky2::plonk::config::PoseidonGoldilocksConfig;
use plonky2::plonk::proof::Proof;
use plonky2::hash::hashing::PlonkyPermutation;
use plonky2::plonk::circuit_data::VerifierCircuitTarget;
// size of leaf data (in number of field elements)
pub const LEAF_LEN: usize = 4;
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct MerkleTreeTargets<
F: RichField + Extendable<D> + Poseidon2,
C: GenericConfig<D, F = F>,
const D: usize,
H: Hasher<F> + AlgebraicHasher<F>,
> {
pub proof_target: MerkleProofTarget,
pub cap_target: MerkleCapTarget,
pub leaf: Vec<Target>,
pub leaf_index_target: Target,
_phantom: PhantomData<(C,H)>,
}
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct MerkleTreeCircuit<
F: RichField + Extendable<D> + Poseidon2,
C: GenericConfig<D, F = F>,
const D: usize,
H: Hasher<F> + AlgebraicHasher<F>,
> {
pub tree: MerkleTree<F, H>,
pub _phantom: PhantomData<C>,
}
impl<
F: RichField + Extendable<D> + Poseidon2,
C: GenericConfig<D, F = F>,
const D: usize,
H: Hasher<F> + AlgebraicHasher<F>,
> MerkleTreeCircuit<F, C, D, H>{
pub fn tree_height(&self) -> usize {
self.tree.leaves.len().trailing_zeros() as usize
}
// build the circuit and returns the circuit data
pub fn build_circuit(&mut self, builder: &mut CircuitBuilder::<F, D>) -> MerkleTreeTargets<F, C, D, H>{
let proof_t = MerkleProofTarget {
siblings: builder.add_virtual_hashes(self.tree_height()-self.tree.cap.height()),
};
let cap_t = builder.add_virtual_cap(self.tree.cap.height());
let leaf_index_t = builder.add_virtual_target();
let leaf_index_bits = builder.split_le(leaf_index_t, self.tree_height());
// NOTE: takes the length from const LEAF_LEN and assume all lengths are the same
let leaf_t: [Target; LEAF_LEN] = builder.add_virtual_targets(LEAF_LEN).try_into().unwrap();
let zero = builder.zero();
// let mut mt = MT(self.tree.clone());
self.verify_merkle_proof_to_cap_circuit(
builder, leaf_t.to_vec(), &leaf_index_bits, &cap_t, &proof_t,
);
MerkleTreeTargets{
proof_target: proof_t,
cap_target: cap_t,
leaf: leaf_t.to_vec(),
leaf_index_target: leaf_index_t,
_phantom: Default::default(),
}
}
pub fn fill_targets(
&self,
pw: &mut PartialWitness<F>,
// leaf_data: Vec<F>,
leaf_index: usize,
targets: MerkleTreeTargets<F, C, D, H>,
) {
let proof = self.tree.prove(leaf_index);
for i in 0..proof.siblings.len() {
pw.set_hash_target(targets.proof_target.siblings[i], proof.siblings[i]);
}
// set cap target manually
// pw.set_cap_target(&cap_t, &tree.cap);
for (ht, h) in targets.cap_target.0.iter().zip(&self.tree.cap.0) {
pw.set_hash_target(*ht, *h);
}
pw.set_target(
targets.leaf_index_target,
F::from_canonical_usize(leaf_index),
);
for j in 0..targets.leaf.len() {
pw.set_target(targets.leaf[j], self.tree.leaves[leaf_index][j]);
}
}
pub fn prove(
&self,
data: CircuitData<F, C, D>,
pw: PartialWitness<F>
) -> Result<ProofWithPublicInputs<F, C, D>> {
let proof = data.prove(pw);
return proof
}
// function to automate build and prove, useful for quick testing
pub fn build_and_prove(
&mut self,
// builder: &mut CircuitBuilder::<F, D>,
config: CircuitConfig,
// pw: &mut PartialWitness<F>,
leaf_index: usize,
// data: CircuitData<F, C, D>,
) -> Result<(CircuitData<F, C, D>,ProofWithPublicInputs<F, C, D>)> {
let mut builder = CircuitBuilder::<F, D>::new(config);
let mut pw = PartialWitness::new();
// merkle proof
let merkle_proof = self.tree.prove(leaf_index);
let proof_t = MerkleProofTarget {
siblings: builder.add_virtual_hashes(merkle_proof.siblings.len()),
};
for i in 0..merkle_proof.siblings.len() {
pw.set_hash_target(proof_t.siblings[i], merkle_proof.siblings[i]);
}
// merkle cap target
let cap_t = builder.add_virtual_cap(self.tree.cap.height());
// set cap target manually
// pw.set_cap_target(&cap_t, &tree.cap);
for (ht, h) in cap_t.0.iter().zip(&self.tree.cap.0) {
pw.set_hash_target(*ht, *h);
}
// leaf index target
let leaf_index_t = builder.constant(F::from_canonical_usize(leaf_index));
let leaf_index_bits = builder.split_le(leaf_index_t, self.tree_height());
// leaf targets
// NOTE: takes the length from const LEAF_LEN and assume all lengths are the same
// let leaf_t = builder.add_virtual_targets(LEAF_LEN);
let leaf_t = builder.add_virtual_targets(self.tree.leaves[leaf_index].len());
for j in 0..leaf_t.len() {
pw.set_target(leaf_t[j], self.tree.leaves[leaf_index][j]);
}
// let mut mt = MT(self.tree.clone());
self.verify_merkle_proof_to_cap_circuit(
&mut builder, leaf_t.to_vec(), &leaf_index_bits, &cap_t, &proof_t,
);
let data = builder.build::<C>();
let proof = data.prove(pw).unwrap();
Ok((data, proof))
}
pub fn verify(
&self,
verifier_data: &VerifierCircuitData<F, C, D>,
public_inputs: Vec<F>,
proof: Proof<F, C, D>
) -> Result<()> {
verifier_data.verify(ProofWithPublicInputs {
proof,
public_inputs,
})
}
}
impl<F: RichField + Extendable<D> + Poseidon2, const D: usize, C: GenericConfig<D, F = F>, H: Hasher<F> + AlgebraicHasher<F>,> MerkleTreeCircuit<F, C, D, H> {
pub fn verify_merkle_proof_circuit(
&mut self,
builder: &mut CircuitBuilder<F, D>,
leaf_data: Vec<Target>,
leaf_index_bits: &[BoolTarget],
merkle_root: HashOutTarget,
proof: &MerkleProofTarget,
) {
let merkle_cap = MerkleCapTarget(vec![merkle_root]);
self.verify_merkle_proof_to_cap_circuit(builder, leaf_data, leaf_index_bits, &merkle_cap, proof);
}
pub fn verify_merkle_proof_to_cap_circuit(
&mut self,
builder: &mut CircuitBuilder<F, D>,
leaf_data: Vec<Target>,
leaf_index_bits: &[BoolTarget],
merkle_cap: &MerkleCapTarget,
proof: &MerkleProofTarget,
) {
let cap_index = builder.le_sum(leaf_index_bits[proof.siblings.len()..].iter().copied());
self.verify_merkle_proof_to_cap_with_cap_index_circuit(
builder,
leaf_data,
leaf_index_bits,
cap_index,
merkle_cap,
proof,
);
}
pub fn verify_merkle_proof_to_cap_with_cap_index_circuit(
&mut self,
builder: &mut CircuitBuilder<F, D>,
leaf_data: Vec<Target>,
leaf_index_bits: &[BoolTarget],
cap_index: Target,
merkle_cap: &MerkleCapTarget,
proof: &MerkleProofTarget,
) {
debug_assert!(H::AlgebraicPermutation::RATE >= NUM_HASH_OUT_ELTS);
let zero = builder.zero();
let mut state: HashOutTarget = builder.hash_or_noop::<H>(leaf_data);
debug_assert_eq!(state.elements.len(), NUM_HASH_OUT_ELTS);
for (&bit, &sibling) in leaf_index_bits.iter().zip(&proof.siblings) {
debug_assert_eq!(sibling.elements.len(), NUM_HASH_OUT_ELTS);
let mut perm_inputs = H::AlgebraicPermutation::default();
perm_inputs.set_from_slice(&state.elements, 0);
perm_inputs.set_from_slice(&sibling.elements, NUM_HASH_OUT_ELTS);
// Ensure the rest of the state, if any, is zero:
perm_inputs.set_from_iter(core::iter::repeat(zero), 2 * NUM_HASH_OUT_ELTS);
// let perm_outs = builder.permute_swapped::<H>(perm_inputs, bit);
let perm_outs = H::permute_swapped(perm_inputs, bit, builder);
let hash_outs = perm_outs.squeeze()[0..NUM_HASH_OUT_ELTS]
.try_into()
.unwrap();
state = HashOutTarget {
elements: hash_outs,
};
}
for i in 0..NUM_HASH_OUT_ELTS {
let result = builder.random_access(
cap_index,
merkle_cap.0.iter().map(|h| h.elements[i]).collect(),
);
builder.connect(result, state.elements[i]);
}
}
pub fn verify_batch_merkle_proof_to_cap_with_cap_index_circuit(
&mut self,
builder: &mut CircuitBuilder<F, D>,
leaf_data: &[Vec<Target>],
leaf_heights: &[usize],
leaf_index_bits: &[BoolTarget],
cap_index: Target,
merkle_cap: &MerkleCapTarget,
proof: &MerkleProofTarget,
) {
debug_assert!(H::AlgebraicPermutation::RATE >= NUM_HASH_OUT_ELTS);
let zero = builder.zero();
let mut state: HashOutTarget = builder.hash_or_noop::<H>(leaf_data[0].clone());
debug_assert_eq!(state.elements.len(), NUM_HASH_OUT_ELTS);
let mut current_height = leaf_heights[0];
let mut leaf_data_index = 1;
for (&bit, &sibling) in leaf_index_bits.iter().zip(&proof.siblings) {
debug_assert_eq!(sibling.elements.len(), NUM_HASH_OUT_ELTS);
let mut perm_inputs = H::AlgebraicPermutation::default();
perm_inputs.set_from_slice(&state.elements, 0);
perm_inputs.set_from_slice(&sibling.elements, NUM_HASH_OUT_ELTS);
// Ensure the rest of the state, if any, is zero:
perm_inputs.set_from_iter(core::iter::repeat(zero), 2 * NUM_HASH_OUT_ELTS);
// let perm_outs = builder.permute_swapped::<H>(perm_inputs, bit);
let perm_outs = H::permute_swapped(perm_inputs, bit, builder);
let hash_outs = perm_outs.squeeze()[0..NUM_HASH_OUT_ELTS]
.try_into()
.unwrap();
state = HashOutTarget {
elements: hash_outs,
};
current_height -= 1;
if leaf_data_index < leaf_heights.len()
&& current_height == leaf_heights[leaf_data_index]
{
let mut new_leaves = state.elements.to_vec();
new_leaves.extend_from_slice(&leaf_data[leaf_data_index]);
state = builder.hash_or_noop::<H>(new_leaves);
leaf_data_index += 1;
}
}
for i in 0..NUM_HASH_OUT_ELTS {
let result = builder.random_access(
cap_index,
merkle_cap.0.iter().map(|h| h.elements[i]).collect(),
);
builder.connect(result, state.elements[i]);
}
}
pub fn connect_hashes(&mut self, builder: &mut CircuitBuilder<F, D>, x: HashOutTarget, y: HashOutTarget) {
for i in 0..NUM_HASH_OUT_ELTS {
builder.connect(x.elements[i], y.elements[i]);
}
}
pub fn connect_merkle_caps(&mut self, builder: &mut CircuitBuilder<F, D>, x: &MerkleCapTarget, y: &MerkleCapTarget) {
for (h0, h1) in x.0.iter().zip_eq(&y.0) {
self.connect_hashes(builder, *h0, *h1);
}
}
pub fn connect_verifier_data(&mut self, builder: &mut CircuitBuilder<F, D>, x: &VerifierCircuitTarget, y: &VerifierCircuitTarget) {
self.connect_merkle_caps(builder, &x.constants_sigmas_cap, &y.constants_sigmas_cap);
self.connect_hashes(builder, x.circuit_digest, y.circuit_digest);
}
}
#[cfg(test)]
pub mod tests {
use std::time::Instant;
use rand::rngs::OsRng;
use rand::Rng;
use super::*;
use plonky2::field::types::Field;
use crate::merkle_tree::capped_tree::MerkleTree;
use plonky2::iop::witness::PartialWitness;
use plonky2::plonk::circuit_data::CircuitConfig;
use plonky2::plonk::config::{GenericConfig, PoseidonGoldilocksConfig};
pub fn random_data<F: Field>(n: usize, k: usize) -> Vec<Vec<F>> {
(0..n).map(|_| F::rand_vec(k)).collect()
}
#[test]
fn test_merkle_circuit() -> Result<()> {
const D: usize = 2;
type C = PoseidonGoldilocksConfig;
type F = <C as GenericConfig<D>>::F;
type H = PoseidonHash;
// create Merkle tree
let log_n = 8;
let n = 1 << log_n;
let cap_height = 1;
let leaves = random_data::<F>(n, LEAF_LEN);
let tree = MerkleTree::<F, <C as GenericConfig<D>>::Hasher>::new(leaves, cap_height);
// ---- prover zone ----
// Build and prove
let start_build = Instant::now();
let mut mt_circuit = MerkleTreeCircuit::<F,C,D,H>{ tree: tree.clone(), _phantom: Default::default() };
let leaf_index: usize = OsRng.gen_range(0..n);
let config = CircuitConfig::standard_recursion_config();
let (data, proof_with_pub_input) = mt_circuit.build_and_prove(config,leaf_index).unwrap();
println!("build and prove time is: {:?}", start_build.elapsed());
let vd = data.verifier_data();
let pub_input = proof_with_pub_input.public_inputs;
let proof = proof_with_pub_input.proof;
// ---- verifier zone ----
let start_verifier = Instant::now();
assert!(mt_circuit.verify(&vd,pub_input,proof).is_ok());
println!("verify time is: {:?}", start_verifier.elapsed());
Ok(())
}
#[test]
fn mod_test_merkle_circuit() -> Result<()> {
const D: usize = 2;
type C = PoseidonGoldilocksConfig;
type F = <C as GenericConfig<D>>::F;
// create Merkle tree
let log_n = 8;
let n = 1 << log_n;
let cap_height = 0;
let leaves = random_data::<F>(n, LEAF_LEN);
let tree = MerkleTree::<F, <C as GenericConfig<D>>::Hasher>::new(leaves, cap_height);
// Build circuit
let start_build = Instant::now();
let mut mt_circuit = MerkleTreeCircuit{ tree: tree.clone(), _phantom: Default::default() };
let config = CircuitConfig::standard_recursion_config();
let mut builder = CircuitBuilder::<F, D>::new(config);
let targets = mt_circuit.build_circuit(&mut builder);
let data = builder.build::<C>();
let vd = data.verifier_data();
println!("build time is: {:?}", start_build.elapsed());
// Prover Zone
let start_prover = Instant::now();
let mut pw = PartialWitness::new();
let leaf_index: usize = OsRng.gen_range(0..n);
let proof = tree.prove(leaf_index);
mt_circuit.fill_targets(&mut pw, leaf_index, targets);
let proof_with_pub_input = mt_circuit.prove(data,pw).unwrap();
let pub_input = proof_with_pub_input.public_inputs;
let proof = proof_with_pub_input.proof;
println!("prove time is: {:?}", start_prover.elapsed());
// Verifier zone
let start_verifier = Instant::now();
assert!(mt_circuit.verify(&vd,pub_input,proof).is_ok());
println!("verify time is: {:?}", start_verifier.elapsed());
Ok(())
}
}

View File

@ -8,7 +8,10 @@ use plonky2_poseidon2::poseidon2_hash::poseidon2::Poseidon2;
/// Compression function which takes two 256 bit inputs (HashOut) and u64 key (which is converted to field element in the function)
/// and returns a 256 bit output (HashOut).
pub fn key_compress<F: RichField, H:Hasher<F> >(x: HashOut<F>, y: HashOut<F>, key: u64) -> HashOut<F> {
pub fn key_compress<
F: RichField,
H:Hasher<F>
>(x: HashOut<F>, y: HashOut<F>, key: u64) -> HashOut<F> {
debug_assert_eq!(x.elements.len(), NUM_HASH_OUT_ELTS);
debug_assert_eq!(y.elements.len(), NUM_HASH_OUT_ELTS);

View File

@ -0,0 +1,176 @@
// Plonky2 Circuit implementation of "safe" merkle tree
// consistent with the one in codex:
// https://github.com/codex-storage/codex-storage-proofs-circuits/blob/master/circuit/codex/merkle.circom
use anyhow::Result;
use plonky2::field::extension::Extendable;
use plonky2::field::types::Field;
use plonky2::hash::hash_types::{HashOutTarget, RichField, NUM_HASH_OUT_ELTS};
use plonky2::iop::target::{BoolTarget, Target};
use plonky2::plonk::circuit_builder::CircuitBuilder;
use plonky2::plonk::config::{AlgebraicHasher, GenericConfig, GenericHashOut, Hasher, PoseidonGoldilocksConfig};
use std::marker::PhantomData;
use plonky2_poseidon2::poseidon2_hash::poseidon2::Poseidon2;
use crate::circuits::keyed_compress::key_compress_circuit;
use crate::circuits::params::HF;
use crate::circuits::utils::{add_assign_hash_out_target, assign_bool_targets, assign_hash_out_targets, mul_hash_out_target, usize_to_bits_le_padded};
use crate::merkle_tree::merkle_safe::{KEY_NONE,KEY_BOTTOM_LAYER};
/// Merkle tree targets representing the input to the circuit
#[derive(Clone)]
pub struct MerkleTreeTargets{
pub leaf: HashOutTarget,
pub path_bits: Vec<BoolTarget>,
pub last_bits: Vec<BoolTarget>,
pub mask_bits: Vec<BoolTarget>,
pub merkle_path: MerkleProofTarget,
}
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct MerkleProofTarget {
/// The Merkle digest of each sibling subtree, staying from the bottommost layer.
pub path: Vec<HashOutTarget>,
}
/// Merkle tree circuit contains the functions for
/// building, proving and verifying the circuit.
#[derive(Clone)]
pub struct MerkleTreeCircuit<
F: RichField + Extendable<D> + Poseidon2,
const D: usize,
> {
pub phantom_data: PhantomData<F>,
}
impl<
F: RichField + Extendable<D> + Poseidon2,
const D: usize,
> MerkleTreeCircuit<F, D> {
pub fn new() -> Self{
Self{
phantom_data: Default::default(),
}
}
/// takes the params from the targets struct
/// outputs the reconstructed merkle root
pub fn reconstruct_merkle_root_circuit(
builder: &mut CircuitBuilder<F, D>,
targets: &mut MerkleTreeTargets,
max_depth: usize,
) -> HashOutTarget {
let mut state: HashOutTarget = targets.leaf;
let zero = builder.zero();
let one = builder.one();
let two = builder.two();
debug_assert_eq!(targets.path_bits.len(), targets.merkle_path.path.len());
// compute is_last
let mut is_last = vec![BoolTarget::new_unsafe(zero); max_depth + 1];
is_last[max_depth] = BoolTarget::new_unsafe(one); // set isLast[max_depth] to 1 (true)
for i in (0..max_depth).rev() {
let eq_out = builder.is_equal(targets.path_bits[i].target , targets.last_bits[i].target);
is_last[i] = builder.and( is_last[i + 1] , eq_out);
}
let mut i: usize = 0;
for (&bit, &sibling) in targets.path_bits.iter().zip(&targets.merkle_path.path) {
debug_assert_eq!(sibling.elements.len(), NUM_HASH_OUT_ELTS);
let bottom = if i == 0 {
builder.constant(F::from_canonical_u64(KEY_BOTTOM_LAYER))
} else {
builder.constant(F::from_canonical_u64(KEY_NONE))
};
// compute: odd = isLast[i] * (1-pathBits[i]);
// compute: key = bottom + 2*odd
let mut odd = builder.sub(one, targets.path_bits[i].target);
odd = builder.mul(is_last[i].target, odd);
odd = builder.mul(two, odd);
let key = builder.add(bottom,odd);
// select left and right based on path_bit
let mut left = vec![];
let mut right = vec![];
for i in 0..NUM_HASH_OUT_ELTS {
left.push( builder.select(bit, sibling.elements[i], state.elements[i]));
right.push( builder.select(bit, state.elements[i], sibling.elements[i]));
}
state = key_compress_circuit::<F,D,HF>(builder,left,right,key);
i += 1;
}
return state;
}
/// takes the params from the targets struct
/// outputs the reconstructed merkle root
/// this one uses the mask bits
pub fn reconstruct_merkle_root_circuit_with_mask(
builder: &mut CircuitBuilder<F, D>,
targets: &mut MerkleTreeTargets,
max_depth: usize,
) -> HashOutTarget {
let mut state: Vec<HashOutTarget> = Vec::with_capacity(max_depth+1);
state.push(targets.leaf);
let zero = builder.zero();
let one = builder.one();
let two = builder.two();
debug_assert_eq!(targets.path_bits.len(), targets.merkle_path.path.len());
// compute is_last
let mut is_last = vec![BoolTarget::new_unsafe(zero); max_depth + 1];
is_last[max_depth] = BoolTarget::new_unsafe(one); // set isLast[max_depth] to 1 (true)
for i in (0..max_depth).rev() {
let eq_out = builder.is_equal(targets.path_bits[i].target , targets.last_bits[i].target);
is_last[i] = builder.and( is_last[i + 1] , eq_out);
}
let mut i: usize = 0;
for (&bit, &sibling) in targets.path_bits.iter().zip(&targets.merkle_path.path) {
debug_assert_eq!(sibling.elements.len(), NUM_HASH_OUT_ELTS);
let bottom = if i == 0 {
builder.constant(F::from_canonical_u64(KEY_BOTTOM_LAYER))
} else {
builder.constant(F::from_canonical_u64(KEY_NONE))
};
// compute: odd = isLast[i] * (1-pathBits[i]);
// compute: key = bottom + 2*odd
let mut odd = builder.sub(one, targets.path_bits[i].target);
odd = builder.mul(is_last[i].target, odd);
odd = builder.mul(two, odd);
let key = builder.add(bottom,odd);
// select left and right based on path_bit
let mut left = vec![];
let mut right = vec![];
for j in 0..NUM_HASH_OUT_ELTS {
left.push( builder.select(bit, sibling.elements[j], state[i].elements[j]));
right.push( builder.select(bit, state[i].elements[j], sibling.elements[j]));
}
state.push(key_compress_circuit::<F,D,HF>(builder,left,right,key));
i += 1;
}
// another way to do this is to use builder.select
// but that might be less efficient & more constraints
let mut reconstructed_root = HashOutTarget::from_vec([builder.zero();4].to_vec());
for k in 0..max_depth {
let diff = builder.sub(targets.mask_bits[k].target, targets.mask_bits[k+1].target);
let mul_result = mul_hash_out_target(builder,&diff,&mut state[k+1]);
add_assign_hash_out_target(builder,&mut reconstructed_root, &mul_result);
}
// reconstructed_root
state[max_depth]
}
}

View File

@ -1,6 +1,6 @@
pub mod capped_tree_circuit;
pub mod safe_tree_circuit;
pub mod prove_single_cell;
// pub mod capped_tree_circuit;
pub mod merkle_circuit;
// pub mod prove_single_cell;
pub mod sample_cells;
pub mod utils;
pub mod params;

View File

@ -1,25 +1,10 @@
// global params for the circuits
// only change params here
use plonky2::hash::poseidon::PoseidonHash;
// constants and types used throughout the circuit
pub const N_FIELD_ELEMS_PER_CELL: usize = 256;
pub const BOT_DEPTH: usize = 5; // block depth - depth of the block merkle tree
pub const MAX_DEPTH: usize = 16; // depth of big tree (slot tree depth + block tree depth)
pub const N_CELLS_IN_BLOCKS: usize = 1<<BOT_DEPTH; //2^BOT_DEPTH
pub const N_BLOCKS: usize = 1<<(MAX_DEPTH - BOT_DEPTH); // 2^(MAX_DEPTH - BOT_DEPTH)
pub const N_CELLS: usize = N_CELLS_IN_BLOCKS * N_BLOCKS;
//the index of the slot to be sampled
// this is fixed to speed up creating fake dataset
// otherwise it would take lots of time
pub const TESTING_SLOT_INDEX: usize = 2;
pub const DATASET_DEPTH: usize = 5;
pub const N_SAMPLES: usize = 10;
use plonky2_poseidon2::poseidon2_hash::poseidon2::Poseidon2Hash;
// hash function used. this is hackish way of doing it because
// H::Hash is not consistent with HashOut<F> and causing a lot of headache
// will look into this later.
pub type HF = PoseidonHash;
pub type HF = PoseidonHash;

View File

@ -1,475 +0,0 @@
// prove single cell
// consistent with:
// https://github.com/codex-storage/codex-storage-proofs-circuits/blob/master/circuit/codex/single_cell.circom
// circuit consists of:
// - reconstruct the block merkle root
// - use merkle root as leaf and reconstruct slot root
// - check equality with given slot root
use anyhow::Result;
use plonky2::field::extension::Extendable;
use plonky2::hash::hash_types::{HashOut, RichField};
use plonky2::iop::target::{BoolTarget, Target};
use plonky2::iop::witness::{PartialWitness, WitnessWrite, Witness};
use plonky2::plonk::circuit_builder::CircuitBuilder;
use plonky2::plonk::circuit_data::CircuitConfig;
use plonky2::plonk::config::{AlgebraicHasher, GenericConfig, Hasher, GenericHashOut};
use std::marker::PhantomData;
use itertools::Itertools;
use crate::merkle_tree::merkle_safe::MerkleTree;
use plonky2::hash::poseidon::PoseidonHash;
use plonky2::hash::hash_types::{HashOutTarget, NUM_HASH_OUT_ELTS};
use crate::merkle_tree::merkle_safe::{MerkleProof, MerkleProofTarget};
use plonky2_poseidon2::poseidon2_hash::poseidon2::Poseidon2;
use plonky2::plonk::config::PoseidonGoldilocksConfig;
use plonky2::hash::hashing::PlonkyPermutation;
use crate::circuits::safe_tree_circuit::{MerkleTreeCircuit, MerkleTreeTargets};
use crate::circuits::utils::usize_to_bits_le_padded;
use crate::circuits::params::{MAX_DEPTH, BOT_DEPTH, N_FIELD_ELEMS_PER_CELL, N_CELLS_IN_BLOCKS, N_BLOCKS, N_CELLS, HF};
// ------ Slot Tree --------
#[derive(Clone)]
pub struct SlotTreeCircuit<
F: RichField + Extendable<D> + Poseidon2,
const D: usize,
> {
pub tree: MerkleTreeCircuit<F,D>, // slot tree
pub block_trees: Vec<MerkleTreeCircuit<F,D>>, // vec of block trees
pub cell_data: Vec<Vec<F>>, // cell data as field elements
}
impl<
F: RichField + Extendable<D> + Poseidon2,
const D: usize,
> Default for SlotTreeCircuit<F,D>{
/// slot tree with fake data, for testing only
fn default() -> Self {
// generate fake cell data
let mut cell_data = (0..N_CELLS)
.map(|i|{
(0..N_FIELD_ELEMS_PER_CELL)
.map(|j| F::from_canonical_u64((j+i) as u64))
.collect::<Vec<_>>()
})
.collect::<Vec<_>>();
// hash it
let leaves: Vec<HashOut<F>> = cell_data
.iter()
.map(|element| {
HF::hash_no_pad(&element)
})
.collect();
// zero hash
let zero = HashOut {
elements: [F::ZERO; 4],
};
// create block tree
let block_trees = (0..N_BLOCKS)
.map(|i| {
let start = i * N_CELLS_IN_BLOCKS;
let end = (i + 1) * N_CELLS_IN_BLOCKS;
let b_tree = Self::get_block_tree(&leaves[start..end].to_vec()); // use helper function
MerkleTreeCircuit::<F,D>{ tree:b_tree}
})
.collect::<Vec<_>>();
// get the roots or block trees
let block_roots = block_trees.iter()
.map(|t| {
t.tree.root().unwrap()
})
.collect::<Vec<_>>();
// create slot tree
let slot_tree = MerkleTree::<F>::new(&block_roots, zero).unwrap();
Self{
tree: MerkleTreeCircuit::<F,D>{ tree:slot_tree},
block_trees,
cell_data,
}
}
}
impl<
F: RichField + Extendable<D> + Poseidon2,
const D: usize,
> SlotTreeCircuit<F,D> {
/// Slot tree with fake data, for testing only
pub fn new_for_testing() -> Self {
// Generate fake cell data for one block
let cell_data_block = (0..N_CELLS_IN_BLOCKS)
.map(|i| {
(0..N_FIELD_ELEMS_PER_CELL)
.map(|j| F::from_canonical_u64((j + i) as u64))
.collect::<Vec<_>>()
})
.collect::<Vec<_>>();
// Hash the cell data block to create leaves for one block
let leaves_block: Vec<HashOut<F>> = cell_data_block
.iter()
.map(|element| {
HF::hash_no_pad(&element)
})
.collect();
// Zero hash
let zero = HashOut {
elements: [F::ZERO; 4],
};
// Create a block tree from the leaves of one block
let b_tree = Self::get_block_tree(&leaves_block);
// Create a block tree circuit
let block_tree_circuit = MerkleTreeCircuit::<F, D> {
tree: b_tree,
// _phantom: Default::default(),
};
// Now replicate this block tree for all N_BLOCKS blocks
let block_trees = vec![block_tree_circuit.clone(); N_BLOCKS];
// Get the roots of block trees
let block_roots = block_trees
.iter()
.map(|t| t.tree.root().unwrap())
.collect::<Vec<_>>();
// Create the slot tree from block roots
let slot_tree = MerkleTree::<F>::new(&block_roots, zero).unwrap();
// Create the full cell data and cell hash by repeating the block data
let cell_data = vec![cell_data_block.clone(); N_BLOCKS].concat();
// Return the constructed Self
Self {
tree: MerkleTreeCircuit::<F, D> {
tree: slot_tree,
},
block_trees,
cell_data,
}
}
/// same as default but with supplied cell data
pub fn new(cell_data: Vec<Vec<F>>) -> Self{
let leaves: Vec<HashOut<F>> = cell_data
.iter()
.map(|element| {
HF::hash_no_pad(element)
})
.collect();
let zero = HashOut {
elements: [F::ZERO; 4],
};
let block_trees = (0..N_BLOCKS as usize)
.map(|i| {
let start = i * N_CELLS_IN_BLOCKS;
let end = (i + 1) * N_CELLS_IN_BLOCKS;
let b_tree = Self::get_block_tree(&leaves[start..end].to_vec());
MerkleTreeCircuit::<F,D>{ tree:b_tree}
})
.collect::<Vec<_>>();
let block_roots = block_trees.iter()
.map(|t| {
t.tree.root().unwrap()
})
.collect::<Vec<_>>();
let slot_tree = MerkleTree::<F>::new(&block_roots, zero).unwrap();
Self{
tree: MerkleTreeCircuit::<F,D>{ tree:slot_tree},
block_trees,
cell_data,
}
}
/// generates a proof for given leaf index
/// the path in the proof is a combined block and slot path to make up the full path
pub fn get_proof(&self, index: usize) -> MerkleProof<F> {
let block_index = index/ N_CELLS_IN_BLOCKS;
let leaf_index = index % N_CELLS_IN_BLOCKS;
let block_proof = self.block_trees[block_index].tree.get_proof(leaf_index).unwrap();
let slot_proof = self.tree.tree.get_proof(block_index).unwrap();
// Combine the paths from the block and slot proofs
let mut combined_path = block_proof.path.clone();
combined_path.extend(slot_proof.path.clone());
MerkleProof::<F> {
index: index,
path: combined_path,
nleaves: self.cell_data.len(),
zero: block_proof.zero.clone(),
}
}
/// verify the given proof for slot tree, checks equality with given root
pub fn verify_cell_proof(&self, proof: MerkleProof<F>, root: HashOut<F>) -> Result<bool>{
let mut block_path_bits = usize_to_bits_le_padded(proof.index, MAX_DEPTH);
let last_index = N_CELLS - 1;
let mut block_last_bits = usize_to_bits_le_padded(last_index, MAX_DEPTH);
let split_point = BOT_DEPTH;
let slot_last_bits = block_last_bits.split_off(split_point);
let slot_path_bits = block_path_bits.split_off(split_point);
let leaf_hash = HF::hash_no_pad(&self.cell_data[proof.index]);
let mut block_path = proof.path;
let slot_path = block_path.split_off(split_point);
let block_res = MerkleProof::<F>::reconstruct_root2(leaf_hash,block_path_bits.clone(),block_last_bits.clone(),block_path);
let reconstructed_root = MerkleProof::<F>::reconstruct_root2(block_res.unwrap(),slot_path_bits,slot_last_bits,slot_path);
Ok(reconstructed_root.unwrap() == root)
}
fn get_block_tree(leaves: &Vec<HashOut<F>>) -> MerkleTree<F> {
let zero = HashOut {
elements: [F::ZERO; 4],
};
// Build the Merkle tree
let block_tree = MerkleTree::<F>::new(leaves, zero).unwrap();
block_tree
}
}
//------- single cell struct ------
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct SingleCellTargets{
pub expected_slot_root_target: HashOutTarget,
pub proof_target: MerkleProofTarget,
pub leaf_target: Vec<Target>,
pub path_bits: Vec<BoolTarget>,
pub last_bits: Vec<BoolTarget>,
}
//------- circuit impl --------
impl<
F: RichField + Extendable<D> + Poseidon2,
const D: usize,
> SlotTreeCircuit<F, D> {
pub fn prove_single_cell(
builder: &mut CircuitBuilder::<F, D>
) -> SingleCellTargets {
// Retrieve tree depth
let depth = MAX_DEPTH;
// Create virtual targets
let mut leaf = (0..N_FIELD_ELEMS_PER_CELL).map(|_| builder.add_virtual_target()).collect::<Vec<_>>();
let mut hash_inputs:Vec<Target>= Vec::new();
hash_inputs.extend_from_slice(&leaf);
let leaf_hash = builder.hash_n_to_hash_no_pad::<HF>(hash_inputs);
// path bits (binary decomposition of leaf_index)
let mut block_path_bits = (0..BOT_DEPTH).map(|_| builder.add_virtual_bool_target_safe()).collect::<Vec<_>>();
let mut slot_path_bits = (0..(depth - BOT_DEPTH)).map(|_| builder.add_virtual_bool_target_safe()).collect::<Vec<_>>();
// last bits (binary decomposition of last_index = nleaves - 1)
let block_last_bits = (0..BOT_DEPTH).map(|_| builder.add_virtual_bool_target_safe()).collect::<Vec<_>>();
let slot_last_bits = (0..(depth-BOT_DEPTH)).map(|_| builder.add_virtual_bool_target_safe()).collect::<Vec<_>>();
// mask bits (binary decomposition of last_index = nleaves - 1)
let block_mask_bits = (0..BOT_DEPTH+1).map(|_| builder.add_virtual_bool_target_safe()).collect::<Vec<_>>();
let slot_mask_bits = (0..(depth-BOT_DEPTH)+1).map(|_| builder.add_virtual_bool_target_safe()).collect::<Vec<_>>();
// Merkle path (sibling hashes from leaf to root)
let mut block_merkle_path = MerkleProofTarget {
path: (0..BOT_DEPTH).map(|_| builder.add_virtual_hash()).collect(),
};
let mut slot_merkle_path = MerkleProofTarget {
path: (0..(depth - BOT_DEPTH)).map(|_| builder.add_virtual_hash()).collect(),
};
// expected Merkle root
let slot_expected_root = builder.add_virtual_hash();
let mut block_targets = MerkleTreeTargets {
leaf: leaf_hash,
path_bits:block_path_bits,
last_bits: block_last_bits,
mask_bits: block_mask_bits,
merkle_path: block_merkle_path,
};
// reconstruct block root
let block_root = MerkleTreeCircuit::<F,D>::reconstruct_merkle_root_circuit(builder, &mut block_targets);
// create MerkleTreeTargets struct
let mut slot_targets = MerkleTreeTargets {
leaf: block_root,
path_bits:slot_path_bits,
last_bits:slot_last_bits,
mask_bits:slot_mask_bits,
merkle_path:slot_merkle_path,
};
// reconstruct slot root with block root as leaf
let slot_root = MerkleTreeCircuit::<F,D>::reconstruct_merkle_root_circuit(builder, &mut slot_targets);
// check equality with expected root
for i in 0..NUM_HASH_OUT_ELTS {
builder.connect(slot_expected_root.elements[i], slot_root.elements[i]);
}
let mut proof_target = MerkleProofTarget{
path: block_targets.merkle_path.path,
};
proof_target.path.extend_from_slice(&slot_targets.merkle_path.path);
let mut path_bits = block_targets.path_bits;
path_bits.extend_from_slice(&slot_targets.path_bits);
let mut last_bits = block_targets.last_bits;
last_bits.extend_from_slice(&slot_targets.last_bits);
let mut cell_targets = SingleCellTargets {
expected_slot_root_target: slot_expected_root,
proof_target,
leaf_target: leaf,
path_bits,
last_bits,
};
// Return MerkleTreeTargets
cell_targets
}
/// assign the witness values in the circuit targets
/// this takes leaf_index, leaf, and proof (generated from slot_tree)
/// and fills all required circuit targets(circuit inputs)
pub fn single_cell_assign_witness(
&self,
pw: &mut PartialWitness<F>,
targets: &mut SingleCellTargets,
leaf_index: usize,
leaf: &Vec<F>,
proof: MerkleProof<F>,
)-> Result<()> {
// Assign the leaf to the leaf target
for i in 0..targets.leaf_target.len(){
pw.set_target(targets.leaf_target[i], leaf[i]);
}
// Convert `leaf_index` to binary bits and assign as path_bits
let path_bits = usize_to_bits_le_padded(leaf_index, MAX_DEPTH);
for (i, bit) in path_bits.iter().enumerate() {
pw.set_bool_target(targets.path_bits[i], *bit);
}
// get `last_index` (nleaves - 1) in binary bits and assign
let last_index = N_CELLS - 1;
let last_bits = usize_to_bits_le_padded(last_index, MAX_DEPTH);
for (i, bit) in last_bits.iter().enumerate() {
pw.set_bool_target(targets.last_bits[i], *bit);
}
// assign the Merkle path (sibling hashes) to the targets
for (i, sibling_hash) in proof.path.iter().enumerate() {
// This is a bit hacky because it should be HashOutTarget, but it is H:Hash
// pw.set_hash_target(targets.merkle_path.path[i],sibling_hash);
// TODO: fix this HashOutTarget later
let sibling_hash_out = sibling_hash.to_vec();
for j in 0..sibling_hash_out.len() {
pw.set_target(targets.proof_target.path[i].elements[j], sibling_hash_out[j]);
}
}
// assign the expected Merkle root to the target
let expected_root = self.tree.tree.root()?;
// TODO: fix this HashOutTarget later same issue as above
let expected_root_hash_out = expected_root.to_vec();
for j in 0..expected_root_hash_out.len() {
pw.set_target(targets.expected_slot_root_target.elements[j], expected_root_hash_out[j]);
}
Ok(())
}
fn hash_leaf(builder: &mut CircuitBuilder<F, D >, leaf: &mut Vec<Target>){
builder.hash_n_to_hash_no_pad::<HF>(leaf.to_owned());
}
}
#[cfg(test)]
mod tests {
use std::time::Instant;
use super::*;
use plonky2::plonk::circuit_data::CircuitConfig;
use plonky2::plonk::config::{GenericConfig, PoseidonGoldilocksConfig};
use plonky2::iop::witness::PartialWitness;
//types for tests
const D: usize = 2;
type C = PoseidonGoldilocksConfig;
type F = <C as GenericConfig<D>>::F;
type H = PoseidonHash;
#[test]
fn test_prove_single_cell(){
let slot_t = SlotTreeCircuit::<F,D>::default();
let index = 8;
let proof = slot_t.get_proof(index);
let res = slot_t.verify_cell_proof(proof,slot_t.tree.tree.root().unwrap()).unwrap();
assert_eq!(res, true);
}
#[test]
fn test_cell_build_circuit() -> Result<()> {
let slot_t = SlotTreeCircuit::<F,D>::default();
// select leaf index to prove
let leaf_index: usize = 8;
let proof = slot_t.get_proof(leaf_index);
// get the expected Merkle root
let expected_root = slot_t.tree.tree.root().unwrap();
let res = slot_t.verify_cell_proof(proof.clone(),expected_root).unwrap();
assert_eq!(res, true);
// create the circuit
let config = CircuitConfig::standard_recursion_config();
let mut builder = CircuitBuilder::<F, D>::new(config);
let mut targets = SlotTreeCircuit::<F,D>::prove_single_cell(&mut builder);
// create a PartialWitness and assign
let mut pw = PartialWitness::new();
slot_t.single_cell_assign_witness(&mut pw, &mut targets, leaf_index, &slot_t.cell_data[leaf_index], proof)?;
// build the circuit
let data = builder.build::<C>();
println!("circuit size = {:?}", data.common.degree_bits());
// Prove the circuit with the assigned witness
let start_time = Instant::now();
let proof_with_pis = data.prove(pw)?;
println!("prove_time = {:?}", start_time.elapsed());
// verify the proof
let verifier_data = data.verifier_data();
assert!(
verifier_data.verify(proof_with_pis).is_ok(),
"Merkle proof verification failed"
);
Ok(())
}
}

View File

@ -1,439 +0,0 @@
// Plonky2 Circuit implementation of "safe" merkle tree
// consistent with the one in codex:
// https://github.com/codex-storage/codex-storage-proofs-circuits/blob/master/circuit/codex/merkle.circom
use anyhow::Result;
use plonky2::field::extension::Extendable;
use plonky2::field::goldilocks_field::GoldilocksField;
use plonky2::field::types::Field;
use plonky2::hash::hash_types::{HashOut, HashOutTarget, RichField, NUM_HASH_OUT_ELTS};
use plonky2::hash::hashing::PlonkyPermutation;
use plonky2::hash::poseidon::PoseidonHash;
use plonky2::iop::target::{BoolTarget, Target};
use plonky2::iop::witness::{PartialWitness, Witness, WitnessWrite};
use plonky2::plonk::circuit_builder::CircuitBuilder;
use plonky2::plonk::circuit_data::{CircuitConfig, CircuitData, VerifierCircuitData};
use plonky2::plonk::config::{AlgebraicHasher, GenericConfig, GenericHashOut, Hasher, PoseidonGoldilocksConfig};
use plonky2::plonk::proof::{Proof, ProofWithPublicInputs};
use std::marker::PhantomData;
use std::os::macos::raw::stat;
use plonky2_poseidon2::poseidon2_hash::poseidon2::Poseidon2;
use serde::Serialize;
use crate::circuits::keyed_compress::key_compress_circuit;
use crate::circuits::params::{HF, MAX_DEPTH};
use crate::circuits::utils::usize_to_bits_le_padded;
use crate::merkle_tree::merkle_safe::{MerkleTree, MerkleProofTarget};
use crate::merkle_tree::merkle_safe::{KEY_NONE,KEY_BOTTOM_LAYER};
/// Merkle tree targets representing the input to the circuit
// note: this omits the mask bits since in plonky2 we can
// uses the Plonk's permutation argument to check that two elements are equal.
// TODO: double check the need for mask
#[derive(Clone)]
pub struct MerkleTreeTargets{
pub leaf: HashOutTarget,
pub path_bits: Vec<BoolTarget>,
pub last_bits: Vec<BoolTarget>,
pub mask_bits: Vec<BoolTarget>,
pub merkle_path: MerkleProofTarget,
}
/// Merkle tree circuit contains the tree and functions for
/// building, proving and verifying the circuit.
#[derive(Clone)]
pub struct MerkleTreeCircuit<
F: RichField + Extendable<D> + Poseidon2,
const D: usize,
> {
pub tree: MerkleTree<F>,
}
impl<
F: RichField + Extendable<D> + Poseidon2,
const D: usize,
> MerkleTreeCircuit<F, D> {
/// defines the computations inside the circuit and returns the targets used
pub fn build_circuit(
&mut self,
builder: &mut CircuitBuilder::<F, D>
) -> (MerkleTreeTargets, HashOutTarget) {
// Retrieve tree depth
let depth = self.tree.depth();
// Create virtual targets
let leaf = builder.add_virtual_hash();
// path bits (binary decomposition of leaf_index)
let path_bits = (0..MAX_DEPTH).map(|_| builder.add_virtual_bool_target_safe()).collect::<Vec<_>>();
// last bits (binary decomposition of last_index = nleaves - 1)
let last_bits = (0..MAX_DEPTH).map(|_| builder.add_virtual_bool_target_safe()).collect::<Vec<_>>();
// last bits (binary decomposition of last_index = nleaves - 1)
let mask_bits = (0..MAX_DEPTH+1).map(|_| builder.add_virtual_bool_target_safe()).collect::<Vec<_>>();
// Merkle path (sibling hashes from leaf to root)
let merkle_path = MerkleProofTarget {
path: (0..MAX_DEPTH).map(|_| builder.add_virtual_hash()).collect(),
};
// create MerkleTreeTargets struct
let mut targets = MerkleTreeTargets{
leaf,
path_bits,
last_bits,
mask_bits,
merkle_path,
};
// Add Merkle proof verification constraints to the circuit
let expected_root_target = Self::reconstruct_merkle_root_circuit_with_mask(builder, &mut targets);
// Return MerkleTreeTargets
(targets, expected_root_target)
}
/// assign the witness values in the circuit targets
/// this takes leaf_index and fills all required circuit targets(inputs)
pub fn assign_witness(
&mut self,
pw: &mut PartialWitness<F>,
targets: &mut MerkleTreeTargets,
leaf_index: usize,
)-> Result<()> {
// Get the total number of leaves and tree depth
let nleaves = self.tree.leaves_count();
let depth = self.tree.depth();
// get the Merkle proof for the specified leaf index
let proof = self.tree.get_proof(leaf_index)?;
// get the leaf hash from the Merkle tree
let leaf_hash = self.tree.layers[0][leaf_index].clone();
// Assign the leaf hash to the leaf target
pw.set_hash_target(targets.leaf, leaf_hash);
// Convert `leaf_index` to binary bits and assign as path_bits
let path_bits = usize_to_bits_le_padded(leaf_index, MAX_DEPTH);
for (i, bit) in path_bits.iter().enumerate() {
pw.set_bool_target(targets.path_bits[i], *bit);
}
// get `last_index` (nleaves - 1) in binary bits and assign
let last_index = nleaves - 1;
let last_bits = usize_to_bits_le_padded(last_index, MAX_DEPTH);
for (i, bit) in last_bits.iter().enumerate() {
pw.set_bool_target(targets.last_bits[i], *bit);
}
// get mask bits
let mask_bits = usize_to_bits_le_padded(last_index, MAX_DEPTH+1);
for (i, bit) in mask_bits.iter().enumerate() {
pw.set_bool_target(targets.mask_bits[i], *bit);
}
// assign the Merkle path (sibling hashes) to the targets
for i in 0..MAX_DEPTH {
if i>=proof.path.len() {
for j in 0..NUM_HASH_OUT_ELTS {
pw.set_target(targets.merkle_path.path[i].elements[j], F::ZERO);
}
continue
}
// This is a bit hacky because it should be HashOutTarget, but it is H:Hash
// pw.set_hash_target(targets.merkle_path.path[i],sibling_hash);
// TODO: fix this HashOutTarget later
let sibling_hash = proof.path[i];
let sibling_hash_out = sibling_hash.to_vec();
for j in 0..sibling_hash_out.len() {
pw.set_target(targets.merkle_path.path[i].elements[j], sibling_hash_out[j]);
}
}
Ok(())
}
/// takes the params from the targets struct
/// outputs the reconstructed merkle root
pub fn reconstruct_merkle_root_circuit(
builder: &mut CircuitBuilder<F, D>,
targets: &mut MerkleTreeTargets,
) -> HashOutTarget {
let max_depth = targets.path_bits.len();
let mut state: HashOutTarget = targets.leaf;
let zero = builder.zero();
let one = builder.one();
let two = builder.two();
debug_assert_eq!(targets.path_bits.len(), targets.merkle_path.path.len());
// compute is_last
let mut is_last = vec![BoolTarget::new_unsafe(zero); max_depth + 1];
is_last[max_depth] = BoolTarget::new_unsafe(one); // set isLast[max_depth] to 1 (true)
for i in (0..max_depth).rev() {
let eq_out = builder.is_equal(targets.path_bits[i].target , targets.last_bits[i].target);
is_last[i] = builder.and( is_last[i + 1] , eq_out);
}
let mut i: usize = 0;
for (&bit, &sibling) in targets.path_bits.iter().zip(&targets.merkle_path.path) {
debug_assert_eq!(sibling.elements.len(), NUM_HASH_OUT_ELTS);
let bottom = if i == 0 {
builder.constant(F::from_canonical_u64(KEY_BOTTOM_LAYER))
} else {
builder.constant(F::from_canonical_u64(KEY_NONE))
};
// compute: odd = isLast[i] * (1-pathBits[i]);
// compute: key = bottom + 2*odd
let mut odd = builder.sub(one, targets.path_bits[i].target);
odd = builder.mul(is_last[i].target, odd);
odd = builder.mul(two, odd);
let key = builder.add(bottom,odd);
// select left and right based on path_bit
let mut left = vec![];
let mut right = vec![];
for i in 0..NUM_HASH_OUT_ELTS {
left.push( builder.select(bit, sibling.elements[i], state.elements[i]));
right.push( builder.select(bit, state.elements[i], sibling.elements[i]));
}
state = key_compress_circuit::<F,D,HF>(builder,left,right,key);
i += 1;
}
return state;
}
/// takes the params from the targets struct
/// outputs the reconstructed merkle root
/// this one uses the mask bits
pub fn reconstruct_merkle_root_circuit_with_mask(
builder: &mut CircuitBuilder<F, D>,
targets: &mut MerkleTreeTargets,
) -> HashOutTarget {
let max_depth = targets.path_bits.len();
let mut state: Vec<HashOutTarget> = Vec::with_capacity(max_depth+1);
state.push(targets.leaf);
let zero = builder.zero();
let one = builder.one();
let two = builder.two();
debug_assert_eq!(targets.path_bits.len(), targets.merkle_path.path.len());
// compute is_last
let mut is_last = vec![BoolTarget::new_unsafe(zero); max_depth + 1];
is_last[max_depth] = BoolTarget::new_unsafe(one); // set isLast[max_depth] to 1 (true)
for i in (0..max_depth).rev() {
let eq_out = builder.is_equal(targets.path_bits[i].target , targets.last_bits[i].target);
is_last[i] = builder.and( is_last[i + 1] , eq_out);
}
let mut i: usize = 0;
for (&bit, &sibling) in targets.path_bits.iter().zip(&targets.merkle_path.path) {
debug_assert_eq!(sibling.elements.len(), NUM_HASH_OUT_ELTS);
let bottom = if i == 0 {
builder.constant(F::from_canonical_u64(KEY_BOTTOM_LAYER))
} else {
builder.constant(F::from_canonical_u64(KEY_NONE))
};
// compute: odd = isLast[i] * (1-pathBits[i]);
// compute: key = bottom + 2*odd
let mut odd = builder.sub(one, targets.path_bits[i].target);
odd = builder.mul(is_last[i].target, odd);
odd = builder.mul(two, odd);
let key = builder.add(bottom,odd);
// select left and right based on path_bit
let mut left = vec![];
let mut right = vec![];
for j in 0..NUM_HASH_OUT_ELTS {
left.push( builder.select(bit, sibling.elements[j], state[i].elements[j]));
right.push( builder.select(bit, state[i].elements[j], sibling.elements[j]));
}
state.push(key_compress_circuit::<F,D,HF>(builder,left,right,key));
i += 1;
}
// println!("mask = {}, last = {}", targets.mask_bits.len(), targets.last_bits.len(), );
// another way to do this is to use builder.select
// but that might be less efficient & more constraints
let mut reconstructed_root = HashOutTarget::from_vec([builder.zero();4].to_vec());
for k in 0..MAX_DEPTH {
let diff = builder.sub(targets.mask_bits[k].target, targets.mask_bits[k+1].target);
let mul_result = Self::mul_hash_out_target(builder,&diff,&mut state[k+1]);
Self::add_assign_hash_out_target(builder,&mut reconstructed_root, &mul_result);
}
reconstructed_root
}
/// helper fn to multiply a HashOutTarget by a Target
pub(crate) fn mul_hash_out_target(builder: &mut CircuitBuilder<F, D>, t: &Target, hash_target: &mut HashOutTarget) -> HashOutTarget {
let mut mul_elements = vec![];
for i in 0..NUM_HASH_OUT_ELTS {
mul_elements.push(builder.mul(hash_target.elements[i], *t));
}
HashOutTarget::from_vec(mul_elements)
}
/// helper fn to add AND assign a HashOutTarget (hot) to a mutable HashOutTarget (mut_hot)
pub(crate) fn add_assign_hash_out_target(builder: &mut CircuitBuilder<F, D>, mut_hot: &mut HashOutTarget, hot: &HashOutTarget) {
for i in 0..NUM_HASH_OUT_ELTS {
mut_hot.elements[i] = (builder.add(mut_hot.elements[i], hot.elements[i]));
}
}
}
// NOTE: for now these tests don't check the reconstructed root is equal to expected_root
// will be fixed later, but for that test check the prove_single_cell tests
#[cfg(test)]
mod tests {
use super::*;
use plonky2::field::types::Field;
use plonky2::plonk::circuit_data::CircuitConfig;
use plonky2::plonk::config::{GenericConfig, PoseidonGoldilocksConfig};
use plonky2::iop::witness::PartialWitness;
#[test]
fn test_build_circuit() -> Result<()> {
// circuit params
const D: usize = 2;
type C = PoseidonGoldilocksConfig;
type F = <C as GenericConfig<D>>::F;
type H = PoseidonHash;
// Generate random leaf data
let nleaves = 10; // Number of leaves
let data = (0..nleaves)
.map(|i| GoldilocksField::from_canonical_u64(i))
.collect::<Vec<_>>();
// Hash the data to obtain leaf hashes
let leaves: Vec<HashOut<GoldilocksField>> = data
.iter()
.map(|&element| {
// Hash each field element to get the leaf hash
PoseidonHash::hash_no_pad(&[element])
})
.collect();
//initialize the Merkle tree
let zero_hash = HashOut {
elements: [GoldilocksField::ZERO; 4],
};
let tree = MerkleTree::<F>::new(&leaves, zero_hash)?;
// select leaf index to prove
let leaf_index: usize = 8;
// get the Merkle proof for the selected leaf
let proof = tree.get_proof(leaf_index)?;
// sanity check:
let check = proof.verify(tree.layers[0][leaf_index],tree.root().unwrap()).unwrap();
assert_eq!(check, true);
// get the expected Merkle root
let expected_root = tree.root()?;
// create the circuit
let config = CircuitConfig::standard_recursion_config();
let mut builder = CircuitBuilder::<F, D>::new(config);
let mut circuit_instance = MerkleTreeCircuit::<F, D> {
tree: tree.clone(),
// _phantom: PhantomData,
};
let (mut targets, expected_root_target) = circuit_instance.build_circuit(&mut builder);
// create a PartialWitness and assign
let mut pw = PartialWitness::new();
circuit_instance.assign_witness(&mut pw, &mut targets, leaf_index)?;
// build the circuit
let data = builder.build::<C>();
// Prove the circuit with the assigned witness
let proof_with_pis = data.prove(pw)?;
// verify the proof
let verifier_data = data.verifier_data();
assert!(
verifier_data.verify(proof_with_pis).is_ok(),
"Merkle proof verification failed"
);
Ok(())
}
// same as test above but for all leaves
#[test]
fn test_verify_all_leaves() -> Result<()> {
const D: usize = 2;
type C = PoseidonGoldilocksConfig;
type F = <C as GenericConfig<D>>::F;
type H = PoseidonHash;
let nleaves = 10; // Number of leaves
let data = (0..nleaves)
.map(|i| GoldilocksField::from_canonical_u64(i as u64))
.collect::<Vec<_>>();
// Hash the data to obtain leaf hashes
let leaves: Vec<HashOut<GoldilocksField>> = data
.iter()
.map(|&element| {
// Hash each field element to get the leaf hash
PoseidonHash::hash_no_pad(&[element])
})
.collect();
let zero_hash = HashOut {
elements: [GoldilocksField::ZERO; 4],
};
let tree = MerkleTree::<F>::new(&leaves, zero_hash)?;
let expected_root = tree.root()?;
let config = CircuitConfig::standard_recursion_config();
let mut builder = CircuitBuilder::<F, D>::new(config);
let mut circuit_instance = MerkleTreeCircuit::<F, D> {
tree: tree.clone(),
};
let (mut targets, expected_root_target) = circuit_instance.build_circuit(&mut builder);
let data = builder.build::<C>();
for leaf_index in 0..nleaves {
let proof = tree.get_proof(leaf_index)?;
let check = proof.verify(tree.layers[0][leaf_index], expected_root)?;
assert!(
check,
"Merkle proof verification failed for leaf index {}",
leaf_index
);
let mut pw = PartialWitness::new();
circuit_instance.assign_witness(&mut pw, &mut targets, leaf_index)?;
let proof_with_pis = data.prove(pw)?;
let verifier_data = data.verifier_data();
assert!(
verifier_data.verify(proof_with_pis).is_ok(),
"Merkle proof verification failed in circuit for leaf index {}",
leaf_index
);
}
Ok(())
}
}

View File

@ -11,25 +11,14 @@ use plonky2::hash::hash_types::{HashOut, HashOutTarget, NUM_HASH_OUT_ELTS, RichF
use plonky2::iop::target::{BoolTarget, Target};
use plonky2::iop::witness::{PartialWitness, WitnessWrite, Witness};
use plonky2::plonk::circuit_builder::CircuitBuilder;
use plonky2::plonk::circuit_data::CircuitConfig;
use plonky2::plonk::config::{AlgebraicHasher, GenericConfig, Hasher, GenericHashOut};
use std::marker::PhantomData;
use itertools::Itertools;
use crate::merkle_tree::merkle_safe::MerkleTree;
use plonky2::hash::poseidon::PoseidonHash;
use crate::merkle_tree::merkle_safe::{MerkleProof, MerkleProofTarget};
use plonky2_poseidon2::poseidon2_hash::poseidon2::Poseidon2;
use plonky2::plonk::config::PoseidonGoldilocksConfig;
use plonky2::hash::hashing::PlonkyPermutation;
use crate::circuits::prove_single_cell::{SingleCellTargets, SlotTreeCircuit};
use crate::circuits::params::{BOT_DEPTH, DATASET_DEPTH, HF, MAX_DEPTH, N_FIELD_ELEMS_PER_CELL, N_SAMPLES, TESTING_SLOT_INDEX};
use crate::circuits::params::HF;
use crate::circuits::safe_tree_circuit::{MerkleTreeCircuit, MerkleTreeTargets};
use crate::circuits::utils::{bits_le_padded_to_usize, calculate_cell_index_bits};
use crate::circuits::merkle_circuit::{MerkleTreeCircuit, MerkleTreeTargets, MerkleProofTarget};
use crate::circuits::utils::{assign_hash_out_targets, bits_le_padded_to_usize, calculate_cell_index_bits};
// ------ Dataset Tree --------
///dataset tree containing all slot trees
@ -38,175 +27,81 @@ pub struct DatasetTreeCircuit<
F: RichField + Extendable<D> + Poseidon2,
const D: usize,
> {
pub tree: MerkleTreeCircuit<F, D>, // dataset tree
pub slot_trees: Vec<SlotTreeCircuit<F,D>>, // vec of slot trees
}
/// Dataset Merkle proof struct, containing the dataset proof and N_SAMPLES proofs.
#[derive(Clone)]
pub struct DatasetMerkleProof<F: RichField> {
pub slot_index: usize,
pub entropy: usize,
pub dataset_proof: MerkleProof<F>, // proof for dataset level tree
pub slot_proofs: Vec<MerkleProof<F>>, // proofs for sampled slot, contains N_SAMPLES proofs
params: CircuitParams,
phantom_data: PhantomData<F>,
}
impl<
F: RichField + Extendable<D> + Poseidon2,
const D: usize,
> Default for DatasetTreeCircuit<F,D> {
/// dataset tree with fake data, for testing only
fn default() -> Self {
let mut slot_trees = vec![];
let n_slots = 1<<DATASET_DEPTH;
for i in 0..n_slots {
slot_trees.push(SlotTreeCircuit::<F,D>::default());
}
// get the roots or slot trees
let slot_roots = slot_trees.iter()
.map(|t| {
t.tree.tree.root().unwrap()
})
.collect::<Vec<_>>();
// zero hash
let zero = HashOut {
elements: [F::ZERO; 4],
};
let dataset_tree = MerkleTree::<F>::new(&slot_roots, zero).unwrap();
> DatasetTreeCircuit<F, D> {
pub fn new(params: CircuitParams) -> Self{
Self{
tree: MerkleTreeCircuit::<F,D>{ tree:dataset_tree},
slot_trees,
params,
phantom_data: Default::default(),
}
}
}
impl<
F: RichField + Extendable<D> + Poseidon2,
const D: usize,
> DatasetTreeCircuit<F,D> {
/// dataset tree with fake data, for testing only
/// create data for only the TESTING_SLOT_INDEX in params file
pub fn new_for_testing() -> Self {
let mut slot_trees = vec![];
let n_slots = 1<<DATASET_DEPTH;
// zero hash
let zero = HashOut {
elements: [F::ZERO; 4],
};
let zero_slot = SlotTreeCircuit::<F,D>{
tree: MerkleTreeCircuit {
tree: MerkleTree::<F>::new(&[zero.clone()], zero.clone()).unwrap(),
},
block_trees: vec![],
cell_data: vec![],
};
for i in 0..n_slots {
if(i == TESTING_SLOT_INDEX) {
slot_trees.push(SlotTreeCircuit::<F, D>::new_for_testing());
}else{
slot_trees.push(zero_slot.clone());
}
}
// get the roots or slot trees
let slot_roots = slot_trees.iter()
.map(|t| {
t.tree.tree.root().unwrap()
})
.collect::<Vec<_>>();
let dataset_tree = MerkleTree::<F>::new(&slot_roots, zero).unwrap();
Self{
tree: MerkleTreeCircuit::<F,D>{ tree:dataset_tree},
slot_trees,
}
}
/// same as default but with supplied slot trees
pub fn new(slot_trees: Vec<SlotTreeCircuit<F,D>>) -> Self{
// get the roots or slot trees
let slot_roots = slot_trees.iter()
.map(|t| {
t.tree.tree.root().unwrap()
})
.collect::<Vec<_>>();
// zero hash
let zero = HashOut {
elements: [F::ZERO; 4],
};
let dataset_tree = MerkleTree::<F>::new(&slot_roots, zero).unwrap();
Self{
tree: MerkleTreeCircuit::<F,D>{ tree:dataset_tree},
slot_trees,
}
}
/// generates a dataset level proof for given slot index
/// just a regular merkle tree proof
pub fn get_proof(&self, index: usize) -> MerkleProof<F> {
let dataset_proof = self.tree.tree.get_proof(index).unwrap();
dataset_proof
}
/// generates a proof for given slot index
/// also takes entropy so it can use it sample the slot
pub fn sample_slot(&self, index: usize, entropy: usize) -> DatasetMerkleProof<F> {
let dataset_proof = self.tree.tree.get_proof(index).unwrap();
let slot = &self.slot_trees[index];
let slot_root = slot.tree.tree.root().unwrap();
let mut slot_proofs = vec![];
// get the index for cell from H(slot_root|counter|entropy)
for i in 0..N_SAMPLES {
let cell_index_bits = calculate_cell_index_bits(entropy, slot_root, i);
let cell_index = bits_le_padded_to_usize(&cell_index_bits);
slot_proofs.push(slot.get_proof(cell_index));
}
DatasetMerkleProof{
slot_index: index,
entropy,
dataset_proof,
slot_proofs,
}
}
// verify the sampling - non-circuit version
pub fn verify_sampling(&self, proof: DatasetMerkleProof<F>) -> Result<bool>{
let slot = &self.slot_trees[proof.slot_index];
let slot_root = slot.tree.tree.root().unwrap();
// check dataset level proof
let d_res = proof.dataset_proof.verify(slot_root,self.tree.tree.root().unwrap());
if(d_res.unwrap() == false){
return Ok(false);
}
// sanity check
assert_eq!(N_SAMPLES, proof.slot_proofs.len());
// get the index for cell from H(slot_root|counter|entropy)
for i in 0..N_SAMPLES {
let cell_index_bits = calculate_cell_index_bits(proof.entropy, slot_root, i);
let cell_index = bits_le_padded_to_usize(&cell_index_bits);
//check the cell_index is the same as one in the proof
assert_eq!(cell_index, proof.slot_proofs[i].index);
let s_res = slot.verify_cell_proof(proof.slot_proofs[i].clone(),slot_root);
if(s_res.unwrap() == false){
return Ok(false);
}
}
Ok(true)
}
// params used for the circuits
// should be defined prior to building the circuit
#[derive(Clone)]
pub struct CircuitParams{
pub max_depth: usize,
pub max_log2_n_slots: usize,
pub block_tree_depth: usize,
pub n_field_elems_per_cell: usize,
pub n_samples: usize,
}
#[derive(Clone)]
pub struct DatasetTargets{
pub dataset_proof: MerkleProofTarget, // proof that slot_root in dataset tree
pub struct SampleTargets {
pub entropy: HashOutTarget,
pub dataset_root: HashOutTarget,
pub slot_index: Target,
pub slot_root: HashOutTarget,
pub n_cells_per_slot: Target,
pub n_slots_per_dataset: Target,
pub slot_proof: MerkleProofTarget, // proof that slot_root in dataset tree
pub cell_data: Vec<Vec<Target>>,
pub entropy: HashOutTarget,
pub slot_index: Target,
pub slot_root: HashOutTarget,
pub slot_proofs: Vec<MerkleProofTarget>,
pub merkle_paths: Vec<MerkleProofTarget>,
}
#[derive(Clone)]
pub struct SampleCircuitInput<
F: RichField + Extendable<D> + Poseidon2,
const D: usize,
>{
pub entropy: Vec<F>,
pub dataset_root: HashOut<F>,
pub slot_index: F,
pub slot_root: HashOut<F>,
pub n_cells_per_slot: F,
pub n_slots_per_dataset: F,
pub slot_proof: Vec<HashOut<F>>, // proof that slot_root in dataset tree
pub cell_data: Vec<Vec<F>>,
pub merkle_paths: Vec<Vec<HashOut<F>>>,
}
#[derive(Clone)]
pub struct MerklePath<
F: RichField + Extendable<D> + Poseidon2,
const D: usize,
> {
path: Vec<HashOut<F>>
}
#[derive(Clone)]
pub struct CellTarget {
pub data: Vec<Target>
}
//------- circuit impl --------
@ -218,9 +113,17 @@ impl<
// in-circuit sampling
// TODO: make it more modular
pub fn sample_slot_circuit(
&mut self,
&self,
builder: &mut CircuitBuilder::<F, D>,
)-> DatasetTargets {
)-> SampleTargets {
// circuit params
let CircuitParams {
max_depth,
max_log2_n_slots,
block_tree_depth,
n_field_elems_per_cell,
n_samples,
} = self.params;
// constants
let zero = builder.zero();
@ -229,27 +132,24 @@ impl<
// ***** prove slot root is in dataset tree *********
// Retrieve dataset tree depth
let d_depth = DATASET_DEPTH;
// Create virtual target for slot root and index
let slot_root = builder.add_virtual_hash();
let slot_index = builder.add_virtual_target();
// dataset path bits (binary decomposition of leaf_index)
let d_path_bits = builder.split_le(slot_index,d_depth);
let d_path_bits = builder.split_le(slot_index,max_log2_n_slots);
// create virtual target for n_slots_per_dataset
let n_slots_per_dataset = builder.add_virtual_target();
// dataset last bits (binary decomposition of last_index = nleaves - 1)
let depth_target = builder.constant(F::from_canonical_u64(d_depth as u64));
let mut d_last_index = builder.exp(two,depth_target,d_depth);
d_last_index = builder.sub(d_last_index, one);
let d_last_bits = builder.split_le(d_last_index,d_depth);
let d_mask_bits = builder.split_le(d_last_index,d_depth+1);
let dataset_last_index = builder.sub(n_slots_per_dataset, one);
let d_last_bits = builder.split_le(dataset_last_index,max_log2_n_slots);
let d_mask_bits = builder.split_le(dataset_last_index,max_log2_n_slots+1);
// dataset Merkle path (sibling hashes from leaf to root)
let d_merkle_path = MerkleProofTarget {
path: (0..d_depth).map(|_| builder.add_virtual_hash()).collect(),
path: (0..max_log2_n_slots).map(|_| builder.add_virtual_hash()).collect(),
};
// create MerkleTreeTargets struct
@ -263,7 +163,7 @@ impl<
// dataset reconstructed root
let d_reconstructed_root =
MerkleTreeCircuit::<F,D>::reconstruct_merkle_root_circuit(builder, &mut d_targets);
MerkleTreeCircuit::<F,D>::reconstruct_merkle_root_circuit_with_mask(builder, &mut d_targets, max_log2_n_slots);
// expected Merkle root
let d_expected_root = builder.add_virtual_hash();
@ -279,24 +179,23 @@ impl<
let mut slot_sample_proofs = vec![];
let entropy_target = builder.add_virtual_hash();
//TODO: this can probably be optimized by supplying nCellsPerSlot as input to the circuit
let b_depth_target = builder.constant(F::from_canonical_u64(BOT_DEPTH as u64));
let mut b_last_index = builder.exp(two,b_depth_target,BOT_DEPTH);
b_last_index = builder.sub(b_last_index, one);
let b_last_bits = builder.split_le(b_last_index,BOT_DEPTH);
// virtual target for n_cells_per_slot
let n_cells_per_slot = builder.add_virtual_target();
let b_mask_bits = builder.split_le(b_last_index,BOT_DEPTH+1);
let slot_last_index = builder.sub(n_cells_per_slot, one);
let mut b_last_bits = builder.split_le(slot_last_index,max_depth);
let mut b_mask_bits = builder.split_le(slot_last_index,max_depth);
let s_depth_target = builder.constant(F::from_canonical_u64(MAX_DEPTH as u64));
let mut s_last_index = builder.exp(two,s_depth_target,MAX_DEPTH);
s_last_index = builder.sub(s_last_index, one);
let s_last_bits = builder.split_le(s_last_index,MAX_DEPTH);
let s_mask_bits = builder.split_le(b_last_index,BOT_DEPTH+1);
let mut s_last_bits = b_last_bits.split_off(block_tree_depth);
let mut s_mask_bits = b_mask_bits.split_off(block_tree_depth);
for i in 0..N_SAMPLES{
b_mask_bits.push(BoolTarget::new_unsafe(zero.clone()));
s_mask_bits.push(BoolTarget::new_unsafe(zero.clone()));
for i in 0..n_samples{
// cell data targets
let mut data_i = (0..N_FIELD_ELEMS_PER_CELL).map(|_| builder.add_virtual_target()).collect::<Vec<_>>();
let mut data_i = (0..n_field_elems_per_cell).map(|_| builder.add_virtual_target()).collect::<Vec<_>>();
let mut hash_inputs:Vec<Target>= Vec::new();
hash_inputs.extend_from_slice(&data_i);
@ -312,15 +211,15 @@ impl<
}
}
// paths
let mut b_path_bits = Self::calculate_cell_index_bits(builder, &entropy_target, &d_targets.leaf, &ctr);
let mut s_path_bits = b_path_bits.split_off(BOT_DEPTH);
let mut b_path_bits = self.calculate_cell_index_bits(builder, &entropy_target, &d_targets.leaf, &ctr);
let mut s_path_bits = b_path_bits.split_off(block_tree_depth);
let mut b_merkle_path = MerkleProofTarget {
path: (0..BOT_DEPTH).map(|_| builder.add_virtual_hash()).collect(),
path: (0..block_tree_depth).map(|_| builder.add_virtual_hash()).collect(),
};
let mut s_merkle_path = MerkleProofTarget {
path: (0..(MAX_DEPTH - BOT_DEPTH)).map(|_| builder.add_virtual_hash()).collect(),
path: (0..(max_depth - block_tree_depth)).map(|_| builder.add_virtual_hash()).collect(),
};
let mut block_targets = MerkleTreeTargets {
@ -332,7 +231,7 @@ impl<
};
// reconstruct block root
let b_root = MerkleTreeCircuit::<F,D>::reconstruct_merkle_root_circuit_with_mask(builder, &mut block_targets);
let b_root = MerkleTreeCircuit::<F,D>::reconstruct_merkle_root_circuit_with_mask(builder, &mut block_targets, block_tree_depth);
let mut slot_targets = MerkleTreeTargets {
leaf: b_root,
@ -343,7 +242,7 @@ impl<
};
// reconstruct slot root with block root as leaf
let slot_reconstructed_root = MerkleTreeCircuit::<F,D>::reconstruct_merkle_root_circuit_with_mask(builder, &mut slot_targets);
let slot_reconstructed_root = MerkleTreeCircuit::<F,D>::reconstruct_merkle_root_circuit_with_mask(builder, &mut slot_targets, max_depth-block_tree_depth);
// check equality with expected root
for i in 0..NUM_HASH_OUT_ELTS {
@ -361,204 +260,84 @@ impl<
}
DatasetTargets{
dataset_proof: d_targets.merkle_path,
dataset_root: d_expected_root,
cell_data: data_targets,
SampleTargets {
entropy: entropy_target,
dataset_root: d_expected_root,
slot_index,
slot_root: d_targets.leaf,
slot_proofs: slot_sample_proofs,
n_cells_per_slot,
n_slots_per_dataset,
slot_proof: d_targets.merkle_path,
cell_data: data_targets,
merkle_paths: slot_sample_proofs,
}
}
pub fn calculate_cell_index_bits(builder: &mut CircuitBuilder::<F, D>, entropy: &HashOutTarget, slot_root: &HashOutTarget, ctr: &HashOutTarget) -> Vec<BoolTarget> {
pub fn calculate_cell_index_bits(&self, builder: &mut CircuitBuilder::<F, D>, entropy: &HashOutTarget, slot_root: &HashOutTarget, ctr: &HashOutTarget) -> Vec<BoolTarget> {
let mut hash_inputs:Vec<Target>= Vec::new();
hash_inputs.extend_from_slice(&entropy.elements);
hash_inputs.extend_from_slice(&slot_root.elements);
hash_inputs.extend_from_slice(&ctr.elements);
let hash_out = builder.hash_n_to_hash_no_pad::<HF>(hash_inputs);
let cell_index_bits = builder.low_bits(hash_out.elements[0], MAX_DEPTH, 64);
let cell_index_bits = builder.low_bits(hash_out.elements[0], self.params.max_depth, 64);
cell_index_bits
}
pub fn sample_slot_assign_witness(
&mut self,
&self,
pw: &mut PartialWitness<F>,
targets: &mut DatasetTargets,
slot_index:usize,
entropy:usize,
targets: &mut SampleTargets,
witnesses: SampleCircuitInput<F, D>,
){
// dataset proof
let d_proof = self.tree.tree.get_proof(slot_index).unwrap();
// circuit params
let CircuitParams {
max_depth,
max_log2_n_slots,
block_tree_depth,
n_field_elems_per_cell,
n_samples,
} = self.params;
// assign n_cells_per_slot
pw.set_target(targets.n_cells_per_slot, witnesses.n_cells_per_slot);
// assign n_slots_per_dataset
pw.set_target(targets.n_slots_per_dataset, witnesses.n_slots_per_dataset);
// assign dataset proof
for (i, sibling_hash) in d_proof.path.iter().enumerate() {
// TODO: fix this HashOutTarget later
let sibling_hash_out = sibling_hash.to_vec();
for j in 0..sibling_hash_out.len() {
pw.set_target(targets.dataset_proof.path[i].elements[j], sibling_hash_out[j]);
}
for (i, sibling_hash) in witnesses.slot_proof.iter().enumerate() {
pw.set_hash_target(targets.slot_proof.path[i], *sibling_hash);
}
// assign slot index
pw.set_target(targets.slot_index, F::from_canonical_u64(slot_index as u64));
pw.set_target(targets.slot_index, witnesses.slot_index);
// assign the expected Merkle root of dataset to the target
let expected_root = self.tree.tree.root().unwrap();
let expected_root_hash_out = expected_root.to_vec();
for j in 0..expected_root_hash_out.len() {
pw.set_target(targets.dataset_root.elements[j], expected_root_hash_out[j]);
}
pw.set_hash_target(targets.dataset_root, witnesses.dataset_root);
// the sampled slot
let slot = &self.slot_trees[slot_index];
let slot_root = slot.tree.tree.root().unwrap();
pw.set_hash_target(targets.slot_root, slot_root);
// assign the sampled slot
pw.set_hash_target(targets.slot_root, witnesses.slot_root);
// assign entropy
for (i, element) in targets.entropy.elements.iter().enumerate() {
if(i==0) {
pw.set_target(*element, F::from_canonical_u64(entropy as u64));
}else {
pw.set_target(*element, F::from_canonical_u64(0));
}
}
// pw.set_target(targets.entropy, F::from_canonical_u64(entropy as u64));
assign_hash_out_targets(pw, &targets.entropy.elements, &witnesses.entropy);
// do the sample N times
for i in 0..N_SAMPLES {
let cell_index_bits = calculate_cell_index_bits(entropy,slot_root,i+1);
for i in 0..n_samples {
let cell_index_bits = calculate_cell_index_bits(&witnesses.entropy,witnesses.slot_root,i+1,max_depth);
let cell_index = bits_le_padded_to_usize(&cell_index_bits);
// assign cell data
let leaf = &slot.cell_data[cell_index];
for j in 0..leaf.len(){
let leaf = witnesses.cell_data[i].clone();
for j in 0..n_field_elems_per_cell{
pw.set_target(targets.cell_data[i][j], leaf[j]);
}
// assign proof for that cell
let cell_proof = slot.get_proof(cell_index);
for (k, sibling_hash) in cell_proof.path.iter().enumerate() {
let sibling_hash_out = sibling_hash.to_vec();
for j in 0..sibling_hash_out.len() {
pw.set_target(targets.slot_proofs[i].path[k].elements[j], sibling_hash_out[j]);
}
let cell_proof = witnesses.merkle_paths[i].clone();
for k in 0..max_depth {
pw.set_hash_target(targets.merkle_paths[i].path[k], cell_proof[k])
}
}
}
}
#[cfg(test)]
mod tests {
use std::time::Instant;
use super::*;
use plonky2::plonk::circuit_data::CircuitConfig;
use plonky2::plonk::config::{GenericConfig, PoseidonGoldilocksConfig};
use plonky2::iop::witness::PartialWitness;
//types for tests
const D: usize = 2;
type C = PoseidonGoldilocksConfig;
type F = <C as GenericConfig<D>>::F;
type H = PoseidonHash;
#[test]
fn test_sample_cells() {
let dataset_t = DatasetTreeCircuit::<F,D>::default();
let slot_index = 2;
let entropy = 123;
let proof = dataset_t.sample_slot(slot_index,entropy);
let res = dataset_t.verify_sampling(proof).unwrap();
assert_eq!(res, true);
}
// sample cells with full set of fake data
// this test takes too long, see next test
#[test]
fn test_sample_cells_circuit() -> Result<()> {
let mut dataset_t = DatasetTreeCircuit::<F,D>::default();
let slot_index = 2;
let entropy = 123;
// sanity check
let proof = dataset_t.sample_slot(slot_index,entropy);
let slot_root = dataset_t.slot_trees[slot_index].tree.tree.root().unwrap();
let res = dataset_t.verify_sampling(proof).unwrap();
assert_eq!(res, true);
// create the circuit
let config = CircuitConfig::standard_recursion_config();
let mut builder = CircuitBuilder::<F, D>::new(config);
let mut targets = dataset_t.sample_slot_circuit(&mut builder);
// create a PartialWitness and assign
let mut pw = PartialWitness::new();
dataset_t.sample_slot_assign_witness(&mut pw, &mut targets,slot_index,entropy);
// build the circuit
let data = builder.build::<C>();
println!("circuit size = {:?}", data.common.degree_bits());
// Prove the circuit with the assigned witness
let start_time = Instant::now();
let proof_with_pis = data.prove(pw)?;
println!("prove_time = {:?}", start_time.elapsed());
// verify the proof
let verifier_data = data.verifier_data();
assert!(
verifier_data.verify(proof_with_pis).is_ok(),
"Merkle proof verification failed"
);
Ok(())
}
// same as above but with fake data for the specific slot to be sampled
#[test]
fn test_sample_cells_circuit_from_selected_slot() -> Result<()> {
let mut dataset_t = DatasetTreeCircuit::<F,D>::new_for_testing();
let slot_index = TESTING_SLOT_INDEX;
let entropy = 123;
// sanity check
let proof = dataset_t.sample_slot(slot_index,entropy);
let slot_root = dataset_t.slot_trees[slot_index].tree.tree.root().unwrap();
let res = dataset_t.verify_sampling(proof).unwrap();
assert_eq!(res, true);
// create the circuit
let config = CircuitConfig::standard_recursion_config();
let mut builder = CircuitBuilder::<F, D>::new(config);
let mut targets = dataset_t.sample_slot_circuit(&mut builder);
// create a PartialWitness and assign
let mut pw = PartialWitness::new();
dataset_t.sample_slot_assign_witness(&mut pw, &mut targets,slot_index,entropy);
// build the circuit
let data = builder.build::<C>();
println!("circuit size = {:?}", data.common.degree_bits());
// Prove the circuit with the assigned witness
let start_time = Instant::now();
let proof_with_pis = data.prove(pw)?;
println!("prove_time = {:?}", start_time.elapsed());
// verify the proof
let verifier_data = data.verifier_data();
assert!(
verifier_data.verify(proof_with_pis).is_ok(),
"Merkle proof verification failed"
);
Ok(())
}
}

View File

@ -1,12 +1,14 @@
use plonky2::hash::hash_types::{HashOut, NUM_HASH_OUT_ELTS, RichField};
use plonky2::iop::witness::PartialWitness;
use plonky2::hash::hash_types::{HashOut, HashOutTarget, NUM_HASH_OUT_ELTS, RichField};
use plonky2::iop::witness::{PartialWitness, WitnessWrite};
use plonky2::plonk::circuit_data::{CircuitData, VerifierCircuitData};
use plonky2::plonk::config::{AlgebraicHasher, GenericConfig, GenericHashOut, Hasher};
use plonky2::plonk::config::{AlgebraicHasher, GenericConfig, Hasher};
use plonky2::plonk::proof::{Proof, ProofWithPublicInputs};
use plonky2_field::extension::Extendable;
use plonky2_poseidon2::poseidon2_hash::poseidon2::Poseidon2;
use crate::circuits::params::{HF, MAX_DEPTH};
use crate::circuits::params::HF;
use anyhow::Result;
use plonky2::iop::target::{BoolTarget, Target};
use plonky2::plonk::circuit_builder::CircuitBuilder;
// --------- helper functions ---------
@ -23,30 +25,28 @@ pub(crate) fn usize_to_bits_le_padded(index: usize, bit_length: usize) -> Vec<bo
bits
}
/// calculate the sampled cell index from entropy, slot root, and counter
pub(crate) fn calculate_cell_index_bits<F: RichField>(entropy: usize, slot_root: HashOut<F>, ctr: usize) -> Vec<bool> {
let entropy_field = F::from_canonical_u64(entropy as u64);
let mut entropy_as_digest = HashOut::<F>::ZERO;
entropy_as_digest.elements[0] = entropy_field;
pub(crate) fn calculate_cell_index_bits<F: RichField>(entropy: &Vec<F>, slot_root: HashOut<F>, ctr: usize, depth: usize) -> Vec<bool> {
let ctr_field = F::from_canonical_u64(ctr as u64);
let mut ctr_as_digest = HashOut::<F>::ZERO;
ctr_as_digest.elements[0] = ctr_field;
let mut hash_inputs = Vec::new();
hash_inputs.extend_from_slice(&entropy_as_digest.elements);
hash_inputs.extend_from_slice(&entropy);
hash_inputs.extend_from_slice(&slot_root.elements);
hash_inputs.extend_from_slice(&ctr_as_digest.elements);
let hash_output = HF::hash_no_pad(&hash_inputs);
let cell_index_bytes = hash_output.elements[0].to_canonical_u64();
// let p_bits = take_n_bits_from_bytes(&p_bytes, MAX_DEPTH);
let cell_index_bits = usize_to_bits_le_padded(cell_index_bytes as usize, MAX_DEPTH);
let cell_index_bits = usize_to_bits_le_padded(cell_index_bytes as usize, depth);
cell_index_bits
}
pub(crate) fn take_n_bits_from_bytes(bytes: &[u8], n: usize) -> Vec<bool> {
bytes.iter()
.flat_map(|byte| (0..8u8).map(move |i| (byte >> i) & 1 == 1))
.take(n)
.collect()
}
/// Converts a vector of bits (LSB first) into an index (usize).
pub(crate) fn bits_le_padded_to_usize(bits: &[bool]) -> usize {
bits.iter().enumerate().fold(0usize, |acc, (i, &bit)| {
@ -87,4 +87,54 @@ pub fn verify<
proof,
public_inputs,
})
}
/// assign a vec of bool values to a vec of BoolTargets
pub(crate) fn assign_bool_targets<
F: RichField + Extendable<D> + Poseidon2,
const D: usize,
>(
pw: &mut PartialWitness<F>,
bool_targets: &Vec<BoolTarget>,
bools: Vec<bool>,
){
for (i, bit) in bools.iter().enumerate() {
pw.set_bool_target(bool_targets[i], *bit);
}
}
/// assign a vec of field elems to hash out target elements
pub(crate) fn assign_hash_out_targets<
F: RichField + Extendable<D> + Poseidon2,
const D: usize,
>(
pw: &mut PartialWitness<F>,
hash_out_elements_targets: &[Target],
hash_out_elements: &[F],
){
for j in 0..NUM_HASH_OUT_ELTS {
pw.set_target(hash_out_elements_targets[j], hash_out_elements[j]);
}
}
/// helper fn to multiply a HashOutTarget by a Target
pub(crate) fn mul_hash_out_target<
F: RichField + Extendable<D> + Poseidon2,
const D: usize,
>(builder: &mut CircuitBuilder<F, D>, t: &Target, hash_target: &mut HashOutTarget) -> HashOutTarget {
let mut mul_elements = vec![];
for i in 0..NUM_HASH_OUT_ELTS {
mul_elements.push(builder.mul(hash_target.elements[i], *t));
}
HashOutTarget::from_vec(mul_elements)
}
/// helper fn to add AND assign a HashOutTarget (hot) to a mutable HashOutTarget (mut_hot)
pub(crate) fn add_assign_hash_out_target<
F: RichField + Extendable<D> + Poseidon2,
const D: usize,
>(builder: &mut CircuitBuilder<F, D>, mut_hot: &mut HashOutTarget, hot: &HashOutTarget) {
for i in 0..NUM_HASH_OUT_ELTS {
mut_hot.elements[i] = (builder.add(mut_hot.elements[i], hot.elements[i]));
}
}

View File

@ -1,2 +1,4 @@
pub mod circuits;
pub mod merkle_tree;
pub mod merkle_tree;
pub mod proof_input;
pub mod tests;

View File

@ -1,378 +0,0 @@
// An adapted implementation of Merkle tree
// based on the original plonky2 merkle tree implementation
use core::mem::MaybeUninit;
use core::slice;
use anyhow::{ensure, Result};
use plonky2_maybe_rayon::*;
use serde::{Deserialize, Serialize};
use plonky2::hash::hash_types::{HashOutTarget, RichField};
use plonky2::plonk::config::{GenericHashOut, Hasher};
use plonky2::util::log2_strict;
#[derive(Clone, Debug, Serialize, Deserialize, Eq, PartialEq)]
#[serde(bound = "")]
pub struct MerkleCap<F: RichField, H: Hasher<F>>(pub Vec<H::Hash>);
impl<F: RichField, H: Hasher<F>> Default for MerkleCap<F, H> {
fn default() -> Self {
Self(Vec::new())
}
}
impl<F: RichField, H: Hasher<F>> MerkleCap<F, H> {
pub fn len(&self) -> usize {
self.0.len()
}
pub fn is_empty(&self) -> bool {
self.len() == 0
}
pub fn height(&self) -> usize {
log2_strict(self.len())
}
pub fn flatten(&self) -> Vec<F> {
self.0.iter().flat_map(|&h| h.to_vec()).collect()
}
}
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct MerkleTree<F: RichField, H: Hasher<F>> {
pub leaves: Vec<Vec<F>>,
pub digests: Vec<H::Hash>,
pub cap: MerkleCap<F, H>,
}
#[derive(Clone, Debug, Serialize, Deserialize, Eq, PartialEq)]
#[serde(bound = "")]
pub struct MerkleProof<F: RichField, H: Hasher<F>> {
/// The Merkle digest of each sibling subtree, staying from the bottommost layer.
pub siblings: Vec<H::Hash>,
}
impl<F: RichField, H: Hasher<F>> MerkleProof<F, H> {
pub fn len(&self) -> usize {
self.siblings.len()
}
pub fn is_empty(&self) -> bool {
self.len() == 0
}
}
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct MerkleProofTarget {
/// The Merkle digest of each sibling subtree, staying from the bottommost layer.
pub siblings: Vec<HashOutTarget>,
}
impl<F: RichField, H: Hasher<F>> Default for MerkleTree<F, H> {
fn default() -> Self {
Self {
leaves: Vec::new(),
digests: Vec::new(),
cap: MerkleCap::default(),
}
}
}
pub(crate) fn capacity_up_to_mut<T>(v: &mut Vec<T>, len: usize) -> &mut [MaybeUninit<T>] {
assert!(v.capacity() >= len);
let v_ptr = v.as_mut_ptr().cast::<MaybeUninit<T>>();
unsafe {
slice::from_raw_parts_mut(v_ptr, len)
}
}
pub(crate) fn fill_subtree<F: RichField, H: Hasher<F>>(
digests_buf: &mut [MaybeUninit<H::Hash>],
leaves: &[Vec<F>],
) -> H::Hash {
assert_eq!(leaves.len(), digests_buf.len() / 2 + 1);
if digests_buf.is_empty() {
H::hash_or_noop(&leaves[0])
} else {
let (left_digests_buf, right_digests_buf) = digests_buf.split_at_mut(digests_buf.len() / 2);
let (left_digest_mem, left_digests_buf) = left_digests_buf.split_last_mut().unwrap();
let (right_digest_mem, right_digests_buf) = right_digests_buf.split_first_mut().unwrap();
let (left_leaves, right_leaves) = leaves.split_at(leaves.len() / 2);
let (left_digest, right_digest) = plonky2_maybe_rayon::join(
|| fill_subtree::<F, H>(left_digests_buf, left_leaves),
|| fill_subtree::<F, H>(right_digests_buf, right_leaves),
);
left_digest_mem.write(left_digest);
right_digest_mem.write(right_digest);
H::two_to_one(left_digest, right_digest)
}
}
pub(crate) fn fill_digests_buf<F: RichField, H: Hasher<F>>(
digests_buf: &mut [MaybeUninit<H::Hash>],
cap_buf: &mut [MaybeUninit<H::Hash>],
leaves: &[Vec<F>],
cap_height: usize,
) {
if digests_buf.is_empty() {
debug_assert_eq!(cap_buf.len(), leaves.len());
cap_buf
.par_iter_mut()
.zip(leaves)
.for_each(|(cap_buf, leaf)| {
cap_buf.write(H::hash_or_noop(leaf));
});
return;
}
let subtree_digests_len = digests_buf.len() >> cap_height;
let subtree_leaves_len = leaves.len() >> cap_height;
let digests_chunks = digests_buf.par_chunks_exact_mut(subtree_digests_len);
let leaves_chunks = leaves.par_chunks_exact(subtree_leaves_len);
assert_eq!(digests_chunks.len(), cap_buf.len());
assert_eq!(digests_chunks.len(), leaves_chunks.len());
digests_chunks.zip(cap_buf).zip(leaves_chunks).for_each(
|((subtree_digests, subtree_cap), subtree_leaves)| {
subtree_cap.write(fill_subtree::<F, H>(subtree_digests, subtree_leaves));
},
);
}
pub(crate) fn merkle_tree_prove<F: RichField, H: Hasher<F>>(
leaf_index: usize,
leaves_len: usize,
cap_height: usize,
digests: &[H::Hash],
) -> Vec<H::Hash> {
let num_layers = log2_strict(leaves_len) - cap_height;
debug_assert_eq!(leaf_index >> (cap_height + num_layers), 0);
let digest_len = 2 * (leaves_len - (1 << cap_height));
assert_eq!(digest_len, digests.len());
let digest_tree: &[H::Hash] = {
let tree_index = leaf_index >> num_layers;
let tree_len = digest_len >> cap_height;
&digests[tree_len * tree_index..tree_len * (tree_index + 1)]
};
// Mask out high bits to get the index within the sub-tree.
let mut pair_index = leaf_index & ((1 << num_layers) - 1);
(0..num_layers)
.map(|i| {
let parity = pair_index & 1;
pair_index >>= 1;
// The layers' data is interleaved as follows:
// [layer 0, layer 1, layer 0, layer 2, layer 0, layer 1, layer 0, layer 3, ...].
// Each of the above is a pair of siblings.
// `pair_index` is the index of the pair within layer `i`.
// The index of that the pair within `digests` is
// `pair_index * 2 ** (i + 1) + (2 ** i - 1)`.
let siblings_index = (pair_index << (i + 1)) + (1 << i) - 1;
// We have an index for the _pair_, but we want the index of the _sibling_.
// Double the pair index to get the index of the left sibling. Conditionally add `1`
// if we are to retrieve the right sibling.
let sibling_index = 2 * siblings_index + (1 - parity);
digest_tree[sibling_index]
})
.collect()
}
impl<F: RichField, H: Hasher<F>> MerkleTree<F, H> {
pub fn new(leaves: Vec<Vec<F>>, cap_height: usize) -> Self {
let log2_leaves_len = log2_strict(leaves.len());
assert!(
cap_height <= log2_leaves_len,
"cap_height={} should be at most log2(leaves.len())={}",
cap_height,
log2_leaves_len
);
let num_digests = 2 * (leaves.len() - (1 << cap_height));
let mut digests = Vec::with_capacity(num_digests);
let len_cap = 1 << cap_height;
let mut cap = Vec::with_capacity(len_cap);
let digests_buf = capacity_up_to_mut(&mut digests, num_digests);
let cap_buf = capacity_up_to_mut(&mut cap, len_cap);
fill_digests_buf::<F, H>(digests_buf, cap_buf, &leaves[..], cap_height);
unsafe {
// SAFETY: `fill_digests_buf` and `cap` initialized the spare capacity up to
// `num_digests` and `len_cap`, resp.
digests.set_len(num_digests);
cap.set_len(len_cap);
}
Self {
leaves,
digests,
cap: MerkleCap(cap),
}
}
pub fn get(&self, i: usize) -> &[F] {
&self.leaves[i]
}
// Create a Merkle proof from a leaf index.
pub fn prove(&self, leaf_index: usize) -> MerkleProof<F, H> {
let cap_height = log2_strict(self.cap.len());
let siblings =
merkle_tree_prove::<F, H>(leaf_index, self.leaves.len(), cap_height, &self.digests);
MerkleProof { siblings }
}
}
/// Verifies that the given leaf data is present at the given index in the Merkle tree with the
/// given root.
pub fn verify_merkle_proof<F: RichField, H: Hasher<F>>(
leaf_data: Vec<F>,
leaf_index: usize,
merkle_root: H::Hash,
proof: &MerkleProof<F, H>,
) -> Result<()> {
let merkle_cap = MerkleCap(vec![merkle_root]);
verify_merkle_proof_to_cap(leaf_data, leaf_index, &merkle_cap, proof)
}
/// Verifies that the given leaf data is present at the given index in the Merkle tree with the
/// given cap.
pub fn verify_merkle_proof_to_cap<F: RichField, H: Hasher<F>>(
leaf_data: Vec<F>,
leaf_index: usize,
merkle_cap: &MerkleCap<F, H>,
proof: &MerkleProof<F, H>,
) -> Result<()> {
verify_batch_merkle_proof_to_cap(
&[leaf_data.clone()],
&[proof.siblings.len()],
leaf_index,
merkle_cap,
proof,
)
}
/// Verifies that the given leaf data is present at the given index in the Field Merkle tree with the
/// given cap.
pub fn verify_batch_merkle_proof_to_cap<F: RichField, H: Hasher<F>>(
leaf_data: &[Vec<F>],
leaf_heights: &[usize],
mut leaf_index: usize,
merkle_cap: &MerkleCap<F, H>,
proof: &MerkleProof<F, H>,
) -> Result<()> {
assert_eq!(leaf_data.len(), leaf_heights.len());
let mut current_digest = H::hash_or_noop(&leaf_data[0]);
let mut current_height = leaf_heights[0];
let mut leaf_data_index = 1;
for &sibling_digest in &proof.siblings {
let bit = leaf_index & 1;
leaf_index >>= 1;
current_digest = if bit == 1 {
H::two_to_one(sibling_digest, current_digest)
} else {
H::two_to_one(current_digest, sibling_digest)
};
current_height -= 1;
if leaf_data_index < leaf_heights.len() && current_height == leaf_heights[leaf_data_index] {
let mut new_leaves = current_digest.to_vec();
new_leaves.extend_from_slice(&leaf_data[leaf_data_index]);
current_digest = H::hash_or_noop(&new_leaves);
leaf_data_index += 1;
}
}
assert_eq!(leaf_data_index, leaf_data.len());
ensure!(
current_digest == merkle_cap.0[leaf_index],
"Invalid Merkle proof."
);
Ok(())
}
#[cfg(test)]
pub(crate) mod tests {
use anyhow::Result;
use super::*;
use plonky2::field::extension::Extendable;
use crate::merkle_tree::capped_tree::verify_merkle_proof_to_cap;
use plonky2::plonk::config::{GenericConfig, PoseidonGoldilocksConfig};
pub(crate) fn random_data<F: RichField>(n: usize, k: usize) -> Vec<Vec<F>> {
(0..n).map(|_| F::rand_vec(k)).collect()
}
fn verify_all_leaves<
F: RichField + Extendable<D>,
C: GenericConfig<D, F = F>,
const D: usize,
>(
leaves: Vec<Vec<F>>,
cap_height: usize,
) -> Result<()> {
let tree = MerkleTree::<F, C::Hasher>::new(leaves.clone(), cap_height);
for (i, leaf) in leaves.into_iter().enumerate() {
let proof = tree.prove(i);
verify_merkle_proof_to_cap(leaf, i, &tree.cap, &proof)?;
}
Ok(())
}
#[test]
#[should_panic]
fn test_cap_height_too_big() {
const D: usize = 2;
type C = PoseidonGoldilocksConfig;
type F = <C as GenericConfig<D>>::F;
let log_n = 8;
let cap_height = log_n + 1; // Should panic if `cap_height > len_n`.
let leaves = random_data::<F>(1 << log_n, 7);
let _ = MerkleTree::<F, <C as GenericConfig<D>>::Hasher>::new(leaves, cap_height);
}
#[test]
fn test_cap_height_eq_log2_len() -> Result<()> {
const D: usize = 2;
type C = PoseidonGoldilocksConfig;
type F = <C as GenericConfig<D>>::F;
let log_n = 8;
let n = 1 << log_n;
let leaves = random_data::<F>(n, 7);
verify_all_leaves::<F, C, D>(leaves, log_n)?;
Ok(())
}
#[test]
fn test_merkle_trees() -> Result<()> {
const D: usize = 2;
type C = PoseidonGoldilocksConfig;
type F = <C as GenericConfig<D>>::F;
let log_n = 8;
let n = 1 << log_n;
let leaves = random_data::<F>(n, 7);
verify_all_leaves::<F, C, D>(leaves, 1)?;
Ok(())
}
}

View File

@ -5,7 +5,7 @@
use std::marker::PhantomData;
use anyhow::{ensure, Result};
use plonky2::field::goldilocks_field::GoldilocksField;
use plonky2::hash::hash_types::{HashOut, HashOutTarget, RichField};
use plonky2::hash::hash_types::{HashOut, RichField};
use plonky2::hash::poseidon::PoseidonHash;
use plonky2::plonk::config::Hasher;
use std::ops::Shr;
@ -137,12 +137,6 @@ pub struct MerkleProof<F: RichField> {
pub zero: HashOut<F>,
}
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct MerkleProofTarget {
/// The Merkle digest of each sibling subtree, staying from the bottommost layer.
pub path: Vec<HashOutTarget>,
}
impl<F: RichField> MerkleProof<F> {
/// Reconstructs the root hash from the proof and the given leaf.
pub fn reconstruct_root(&self, leaf: HashOut<F>) -> Result<HashOut<F>> {

View File

@ -1,2 +1 @@
pub mod capped_tree;
pub mod merkle_safe;

View File

@ -0,0 +1,440 @@
use anyhow::Result;
use plonky2::hash::hash_types::{HashOut, RichField};
use plonky2::plonk::config::{GenericConfig, Hasher};
use plonky2_field::extension::Extendable;
use plonky2_field::types::Field;
use plonky2_poseidon2::poseidon2_hash::poseidon2::Poseidon2;
use crate::circuits::params::HF;
use crate::proof_input::test_params::{BOT_DEPTH, DATASET_DEPTH, MAX_DEPTH, N_BLOCKS, N_CELLS, N_CELLS_IN_BLOCKS, N_FIELD_ELEMS_PER_CELL, N_SAMPLES, TESTING_SLOT_INDEX};
use crate::circuits::utils::{bits_le_padded_to_usize, calculate_cell_index_bits, usize_to_bits_le_padded};
use crate::merkle_tree::merkle_safe::{MerkleProof, MerkleTree};
#[derive(Clone)]
pub struct SlotTree<
F: RichField + Extendable<D> + Poseidon2,
const D: usize,
> {
pub tree: MerkleTree<F>, // slot tree
pub block_trees: Vec<MerkleTree<F>>, // vec of block trees
pub cell_data: Vec<Cell<F,D>>, // cell data as field elements
}
#[derive(Clone)]
pub struct Cell<
F: RichField + Extendable<D> + Poseidon2,
const D: usize,
> {
pub data: Vec<F>, // cell data as field elements
}
impl<
F: RichField + Extendable<D> + Poseidon2,
const D: usize,
> Default for Cell<F, D> {
/// default cell with random data
fn default() -> Self {
let data = (0..N_FIELD_ELEMS_PER_CELL)
.map(|j| F::rand())
.collect::<Vec<_>>();
Self{
data,
}
}
}
impl<
F: RichField + Extendable<D> + Poseidon2,
const D: usize,
> Default for SlotTree<F, D> {
/// slot tree with fake data, for testing only
fn default() -> Self {
// generate fake cell data
let mut cell_data = (0..N_CELLS)
.map(|i|{
Cell::<F,D>::default()
})
.collect::<Vec<_>>();
Self::new(cell_data)
}
}
impl<
F: RichField + Extendable<D> + Poseidon2,
const D: usize,
> SlotTree<F, D> {
/// Slot tree with fake data, for testing only
pub fn new_for_testing(cells: Vec<Cell<F, D>>) -> Self {
// Hash the cell data block to create leaves for one block
let leaves_block: Vec<HashOut<F>> = cells
.iter()
.map(|element| {
HF::hash_no_pad(&element.data)
})
.collect();
// Zero hash
let zero = HashOut {
elements: [F::ZERO; 4],
};
// Create a block tree from the leaves of one block
let b_tree = Self::get_block_tree(&leaves_block);
// Now replicate this block tree for all N_BLOCKS blocks
let block_trees = vec![b_tree; N_BLOCKS];
// Get the roots of block trees
let block_roots = block_trees
.iter()
.map(|t| t.root().unwrap())
.collect::<Vec<_>>();
// Create the slot tree from block roots
let slot_tree = MerkleTree::<F>::new(&block_roots, zero).unwrap();
// Create the full cell data and cell hash by repeating the block data
let cell_data = vec![cells.clone(); N_BLOCKS].concat();
// Return the constructed Self
Self {
tree: slot_tree,
block_trees,
cell_data,
}
}
/// same as default but with supplied cell data
pub fn new(cells: Vec<Cell<F, D>>) -> Self {
let leaves: Vec<HashOut<F>> = cells
.iter()
.map(|element| {
HF::hash_no_pad(&element.data)
})
.collect();
let zero = HashOut {
elements: [F::ZERO; 4],
};
let block_trees = (0..N_BLOCKS as usize)
.map(|i| {
let start = i * N_CELLS_IN_BLOCKS;
let end = (i + 1) * N_CELLS_IN_BLOCKS;
Self::get_block_tree(&leaves[start..end].to_vec())
// MerkleTree::<F> { tree: b_tree }
})
.collect::<Vec<_>>();
let block_roots = block_trees.iter()
.map(|t| {
t.root().unwrap()
})
.collect::<Vec<_>>();
let slot_tree = MerkleTree::<F>::new(&block_roots, zero).unwrap();
Self {
tree: slot_tree,
block_trees,
cell_data: cells,
}
}
/// generates a proof for given leaf index
/// the path in the proof is a combined block and slot path to make up the full path
pub fn get_proof(&self, index: usize) -> MerkleProof<F> {
let block_index = index / N_CELLS_IN_BLOCKS;
let leaf_index = index % N_CELLS_IN_BLOCKS;
let block_proof = self.block_trees[block_index].get_proof(leaf_index).unwrap();
let slot_proof = self.tree.get_proof(block_index).unwrap();
// Combine the paths from the block and slot proofs
let mut combined_path = block_proof.path.clone();
combined_path.extend(slot_proof.path.clone());
MerkleProof::<F> {
index: index,
path: combined_path,
nleaves: self.cell_data.len(),
zero: block_proof.zero.clone(),
}
}
/// verify the given proof for slot tree, checks equality with given root
pub fn verify_cell_proof(&self, proof: MerkleProof<F>, root: HashOut<F>) -> anyhow::Result<bool> {
let mut block_path_bits = usize_to_bits_le_padded(proof.index, MAX_DEPTH);
let last_index = N_CELLS - 1;
let mut block_last_bits = usize_to_bits_le_padded(last_index, MAX_DEPTH);
let split_point = BOT_DEPTH;
let slot_last_bits = block_last_bits.split_off(split_point);
let slot_path_bits = block_path_bits.split_off(split_point);
let leaf_hash = HF::hash_no_pad(&self.cell_data[proof.index].data);
let mut block_path = proof.path;
let slot_path = block_path.split_off(split_point);
let block_res = MerkleProof::<F>::reconstruct_root2(leaf_hash, block_path_bits.clone(), block_last_bits.clone(), block_path);
let reconstructed_root = MerkleProof::<F>::reconstruct_root2(block_res.unwrap(), slot_path_bits, slot_last_bits, slot_path);
Ok(reconstructed_root.unwrap() == root)
}
fn get_block_tree(leaves: &Vec<HashOut<F>>) -> MerkleTree<F> {
let zero = HashOut {
elements: [F::ZERO; 4],
};
// Build the Merkle tree
let block_tree = MerkleTree::<F>::new(leaves, zero).unwrap();
block_tree
}
}
// ------ Dataset Tree --------
///dataset tree containing all slot trees
#[derive(Clone)]
pub struct DatasetTree<
F: RichField + Extendable<D> + Poseidon2,
const D: usize,
> {
pub tree: MerkleTree<F>, // dataset tree
pub slot_trees: Vec<SlotTree<F, D>>, // vec of slot trees
}
/// Dataset Merkle proof struct, containing the dataset proof and N_SAMPLES proofs.
#[derive(Clone)]
pub struct DatasetProof<F: RichField> {
pub slot_index: F,
pub entropy: HashOut<F>,
pub dataset_proof: MerkleProof<F>, // proof for dataset level tree
pub slot_proofs: Vec<MerkleProof<F>>, // proofs for sampled slot, contains N_SAMPLES proofs
pub cell_data: Vec<Vec<F>>,
}
impl<
F: RichField + Extendable<D> + Poseidon2,
const D: usize,
> Default for DatasetTree<F, D> {
/// dataset tree with fake data, for testing only
fn default() -> Self {
let mut slot_trees = vec![];
let n_slots = 1 << DATASET_DEPTH;
for i in 0..n_slots {
slot_trees.push(SlotTree::<F, D>::default());
}
Self::new(slot_trees)
}
}
impl<
F: RichField + Extendable<D> + Poseidon2,
const D: usize,
> DatasetTree<F, D> {
/// dataset tree with fake data, for testing only
/// create data for only the TESTING_SLOT_INDEX in params file
pub fn new_for_testing() -> Self {
let mut slot_trees = vec![];
let n_slots = 1 << DATASET_DEPTH;
// zero hash
let zero = HashOut {
elements: [F::ZERO; 4],
};
let zero_slot = SlotTree::<F, D> {
tree: MerkleTree::<F>::new(&[zero.clone()], zero.clone()).unwrap(),
block_trees: vec![],
cell_data: vec![],
};
for i in 0..n_slots {
if (i == TESTING_SLOT_INDEX) {
slot_trees.push(SlotTree::<F, D>::default());
} else {
slot_trees.push(zero_slot.clone());
}
}
// get the roots or slot trees
let slot_roots = slot_trees.iter()
.map(|t| {
t.tree.root().unwrap()
})
.collect::<Vec<_>>();
let dataset_tree = MerkleTree::<F>::new(&slot_roots, zero).unwrap();
Self {
tree: dataset_tree,
slot_trees,
}
}
/// same as default but with supplied slot trees
pub fn new(slot_trees: Vec<SlotTree<F, D>>) -> Self {
// get the roots or slot trees
let slot_roots = slot_trees.iter()
.map(|t| {
t.tree.root().unwrap()
})
.collect::<Vec<_>>();
// zero hash
let zero = HashOut {
elements: [F::ZERO; 4],
};
let dataset_tree = MerkleTree::<F>::new(&slot_roots, zero).unwrap();
Self {
tree: dataset_tree,
slot_trees,
}
}
/// generates a dataset level proof for given slot index
/// just a regular merkle tree proof
pub fn get_proof(&self, index: usize) -> MerkleProof<F> {
let dataset_proof = self.tree.get_proof(index).unwrap();
dataset_proof
}
/// generates a proof for given slot index
/// also takes entropy so it can use it sample the slot
pub fn sample_slot(&self, index: usize, entropy: usize) -> DatasetProof<F> {
let dataset_proof = self.tree.get_proof(index).unwrap();
let slot = &self.slot_trees[index];
let slot_root = slot.tree.root().unwrap();
let mut slot_proofs = vec![];
let mut cell_data = vec![];
let entropy_field = F::from_canonical_u64(entropy as u64);
let mut entropy_as_digest = HashOut::<F>::ZERO;
entropy_as_digest.elements[0] = entropy_field;
// get the index for cell from H(slot_root|counter|entropy)
for i in 0..N_SAMPLES {
let cell_index_bits = calculate_cell_index_bits(&entropy_as_digest.elements.to_vec(), slot_root, i+1, MAX_DEPTH);
let cell_index = bits_le_padded_to_usize(&cell_index_bits);
let s_proof = slot.get_proof(cell_index);
slot_proofs.push(s_proof);
cell_data.push(slot.cell_data[cell_index].data.clone());
}
DatasetProof {
slot_index: F::from_canonical_u64(index as u64),
entropy: entropy_as_digest,
dataset_proof,
slot_proofs,
cell_data,
}
}
// verify the sampling - non-circuit version
pub fn verify_sampling(&self, proof: DatasetProof<F>) -> bool {
let slot = &self.slot_trees[proof.slot_index.to_canonical_u64() as usize];
let slot_root = slot.tree.root().unwrap();
// check dataset level proof
let d_res = proof.dataset_proof.verify(slot_root, self.tree.root().unwrap());
if (d_res.unwrap() == false) {
return false;
}
// sanity check
assert_eq!(N_SAMPLES, proof.slot_proofs.len());
// get the index for cell from H(slot_root|counter|entropy)
for i in 0..N_SAMPLES {
// let entropy_field = F::from_canonical_u64(proof.entropy as u64);
// let mut entropy_as_digest = HashOut::<F>::ZERO;
// entropy_as_digest.elements[0] = entropy_field;
let cell_index_bits = calculate_cell_index_bits(&proof.entropy.elements.to_vec(), slot_root, i+1, MAX_DEPTH);
let cell_index = bits_le_padded_to_usize(&cell_index_bits);
//check the cell_index is the same as one in the proof
assert_eq!(cell_index, proof.slot_proofs[i].index);
let s_res = slot.verify_cell_proof(proof.slot_proofs[i].clone(), slot_root);
if (s_res.unwrap() == false) {
return false;
}
}
true
}
}
#[cfg(test)]
mod tests {
use std::time::Instant;
use super::*;
use plonky2::plonk::circuit_data::CircuitConfig;
use plonky2::plonk::config::GenericConfig;
use plonky2::iop::witness::PartialWitness;
use plonky2::plonk::circuit_builder::CircuitBuilder;
use crate::circuits::sample_cells::{CircuitParams, DatasetTreeCircuit, SampleCircuitInput};
use crate::proof_input::test_params::{D, C, F, H, N_SLOTS};
#[test]
fn test_sample_cells() {
let dataset_t = DatasetTree::<F, D>::new_for_testing();
let slot_index = 2;
let entropy = 2;
let proof = dataset_t.sample_slot(slot_index,entropy);
let res = dataset_t.verify_sampling(proof);
assert_eq!(res, true);
}
#[test]
fn test_sample_cells_circuit_from_selected_slot() -> anyhow::Result<()> {
let mut dataset_t = DatasetTree::<F, D>::new_for_testing();
let slot_index = TESTING_SLOT_INDEX;
let entropy = 123;
// sanity check
let proof = dataset_t.sample_slot(slot_index,entropy);
let slot_root = dataset_t.slot_trees[slot_index].tree.root().unwrap();
let res = dataset_t.verify_sampling(proof.clone());
assert_eq!(res, true);
// create the circuit
let config = CircuitConfig::standard_recursion_config();
let mut builder = CircuitBuilder::<F, D>::new(config);
let circuit_params = CircuitParams{
max_depth: MAX_DEPTH,
max_log2_n_slots: DATASET_DEPTH,
block_tree_depth: BOT_DEPTH,
n_field_elems_per_cell: N_FIELD_ELEMS_PER_CELL,
n_samples: N_SAMPLES,
};
let circ = DatasetTreeCircuit::new(circuit_params);
let mut targets = circ.sample_slot_circuit(&mut builder);
// create a PartialWitness and assign
let mut pw = PartialWitness::new();
let mut slot_paths = vec![];
for i in 0..N_SAMPLES{
let path = proof.slot_proofs[i].path.clone();
slot_paths.push(path);
//TODO: need to be padded
}
let witness = SampleCircuitInput::<F,D>{
entropy: proof.entropy.elements.clone().to_vec(),
dataset_root: dataset_t.tree.root().unwrap(),
slot_index: proof.slot_index.clone(),
slot_root,
n_cells_per_slot: F::from_canonical_u64((2_u32.pow(MAX_DEPTH as u32)) as u64),
n_slots_per_dataset: F::from_canonical_u64((2_u32.pow(DATASET_DEPTH as u32)) as u64),
slot_proof: proof.dataset_proof.path.clone(),
cell_data: proof.cell_data.clone(),
merkle_paths: slot_paths,
};
println!("dataset ={:?}",dataset_t.slot_trees[0].tree.layers);
circ.sample_slot_assign_witness(&mut pw, &mut targets,witness);
// build the circuit
let data = builder.build::<C>();
println!("circuit size = {:?}", data.common.degree_bits());
// Prove the circuit with the assigned witness
let start_time = Instant::now();
let proof_with_pis = data.prove(pw)?;
println!("prove_time = {:?}", start_time.elapsed());
// verify the proof
let verifier_data = data.verifier_data();
assert!(
verifier_data.verify(proof_with_pis).is_ok(),
"Merkle proof verification failed"
);
Ok(())
}
}

View File

@ -0,0 +1,3 @@
pub mod gen_input;
pub mod test_params;
pub mod utils;

View File

@ -0,0 +1,39 @@
// config for generating input for proof circuit
use plonky2::hash::poseidon::PoseidonHash;
use plonky2::plonk::config::{GenericConfig, PoseidonGoldilocksConfig};
// fake input params
// types
pub const D: usize = 2;
pub type C = PoseidonGoldilocksConfig;
pub type F = <C as GenericConfig<D>>::F; // this is the goldilocks field
pub type H = PoseidonHash;
// hardcoded params for generating proof input
pub const MAX_DEPTH: usize = 8; // depth of big tree (slot tree depth, includes block tree depth)
pub const MAX_SLOTS: usize = 256; // maximum number of slots
pub const CELL_SIZE: usize = 2048; // cell size in bytes
pub const BLOCK_SIZE: usize = 65536; // block size in bytes
pub const N_SAMPLES: usize = 5; // number of samples to prove
pub const ENTROPY: usize = 1234567; // external randomness
pub const SEED: usize = 12345; // seed for creating fake data TODO: not used now
pub const N_SLOTS: usize = 8; // number of slots in the dataset
pub const TESTING_SLOT_INDEX: usize = 2; //the index of the slot to be sampled
pub const N_CELLS: usize = 512; // number of cells in each slot
// computed constants
pub const GOLDILOCKS_F_SIZE: usize = 64;
pub const N_FIELD_ELEMS_PER_CELL: usize = CELL_SIZE * 8 / GOLDILOCKS_F_SIZE;
pub const BOT_DEPTH: usize = (BLOCK_SIZE/CELL_SIZE).ilog2() as usize; // block tree depth
pub const N_CELLS_IN_BLOCKS: usize = 1<< BOT_DEPTH; //2^BOT_DEPTH
pub const N_BLOCKS: usize = 1<<(MAX_DEPTH - BOT_DEPTH); // 2^(MAX_DEPTH - BOT_DEPTH)
pub const DATASET_DEPTH: usize = MAX_SLOTS.ilog2() as usize;
// TODO: load params

View File

@ -0,0 +1,308 @@
use anyhow::Result;
use plonky2::field::extension::Extendable;
use plonky2::field::goldilocks_field::GoldilocksField;
use plonky2::field::types::Field;
use plonky2::hash::hash_types::{HashOut, HashOutTarget, RichField, NUM_HASH_OUT_ELTS};
use plonky2::hash::hashing::PlonkyPermutation;
use plonky2::hash::poseidon::PoseidonHash;
use plonky2::iop::target::{BoolTarget, Target};
use plonky2::iop::witness::{PartialWitness, Witness, WitnessWrite};
use plonky2::plonk::circuit_builder::CircuitBuilder;
use plonky2::plonk::circuit_data::{CircuitConfig, CircuitData, VerifierCircuitData};
use plonky2::plonk::config::{AlgebraicHasher, GenericConfig, GenericHashOut, Hasher, PoseidonGoldilocksConfig};
use plonky2::plonk::proof::{Proof, ProofWithPublicInputs};
use std::marker::PhantomData;
use std::os::macos::raw::stat;
use plonky2_poseidon2::poseidon2_hash::poseidon2::Poseidon2;
use serde::Serialize;
use crate::circuits::keyed_compress::key_compress_circuit;
use crate::circuits::params::HF;
use crate::circuits::merkle_circuit::{MerkleProofTarget, MerkleTreeCircuit, MerkleTreeTargets};
use crate::circuits::utils::{add_assign_hash_out_target, assign_bool_targets, assign_hash_out_targets, mul_hash_out_target, usize_to_bits_le_padded};
use crate::merkle_tree::merkle_safe::MerkleTree;
use crate::merkle_tree::merkle_safe::{KEY_NONE,KEY_BOTTOM_LAYER};
/// the input to the merkle tree circuit
#[derive(Clone)]
pub struct MerkleTreeCircuitInput<
F: RichField + Extendable<D> + Poseidon2,
const D: usize,
>{
pub leaf: HashOut<F>,
pub path_bits: Vec<bool>,
pub last_bits: Vec<bool>,
pub mask_bits: Vec<bool>,
pub merkle_path: Vec<HashOut<F>>,
}
/// defines the computations inside the circuit and returns the targets used
/// NOTE: this is not used in the sampling circuit, see reconstruct_merkle_root_circuit_with_mask
pub fn build_circuit<
F: RichField + Extendable<D> + Poseidon2,
const D: usize,
>(
builder: &mut CircuitBuilder::<F, D>,
depth: usize,
) -> (MerkleTreeTargets, HashOutTarget) {
// Create virtual targets
let leaf = builder.add_virtual_hash();
// path bits (binary decomposition of leaf_index)
let path_bits = (0..depth).map(|_| builder.add_virtual_bool_target_safe()).collect::<Vec<_>>();
// last bits (binary decomposition of last_index = nleaves - 1)
let last_bits = (0..depth).map(|_| builder.add_virtual_bool_target_safe()).collect::<Vec<_>>();
// last bits (binary decomposition of last_index = nleaves - 1)
let mask_bits = (0..depth+1).map(|_| builder.add_virtual_bool_target_safe()).collect::<Vec<_>>();
// Merkle path (sibling hashes from leaf to root)
let merkle_path = MerkleProofTarget {
path: (0..depth).map(|_| builder.add_virtual_hash()).collect(),
};
// create MerkleTreeTargets struct
let mut targets = MerkleTreeTargets{
leaf,
path_bits,
last_bits,
mask_bits,
merkle_path,
};
// Add Merkle proof verification constraints to the circuit
let reconstructed_root_target = MerkleTreeCircuit::reconstruct_merkle_root_circuit_with_mask(builder, &mut targets, depth);
// Return MerkleTreeTargets
(targets, reconstructed_root_target)
}
/// assign the witness values in the circuit targets
/// this takes MerkleTreeCircuitInput and fills all required circuit targets
pub fn assign_witness<
F: RichField + Extendable<D> + Poseidon2,
const D: usize,
>(
pw: &mut PartialWitness<F>,
targets: &mut MerkleTreeTargets,
witnesses: MerkleTreeCircuitInput<F, D>
)-> Result<()> {
// Assign the leaf hash to the leaf target
pw.set_hash_target(targets.leaf, witnesses.leaf);
// Assign path bits
assign_bool_targets(pw, &targets.path_bits, witnesses.path_bits);
// Assign last bits
assign_bool_targets(pw, &targets.last_bits, witnesses.last_bits);
// Assign mask bits
assign_bool_targets(pw, &targets.mask_bits, witnesses.mask_bits);
// assign the Merkle path (sibling hashes) to the targets
for i in 0..targets.merkle_path.path.len() {
if i>=witnesses.merkle_path.len() { // pad with zeros
assign_hash_out_targets(pw, &targets.merkle_path.path[i].elements, &[F::ZERO; NUM_HASH_OUT_ELTS]);
continue
}
assign_hash_out_targets(pw, &targets.merkle_path.path[i].elements, &witnesses.merkle_path[i].elements)
}
Ok(())
}
#[cfg(test)]
mod tests {
use std::time::Instant;
use plonky2::hash::hash_types::HashOut;
use plonky2::hash::poseidon::PoseidonHash;
use super::*;
use plonky2::plonk::circuit_data::CircuitConfig;
use plonky2::plonk::config::{GenericConfig, PoseidonGoldilocksConfig};
use plonky2::iop::witness::PartialWitness;
use plonky2::plonk::circuit_builder::CircuitBuilder;
use plonky2_field::goldilocks_field::GoldilocksField;
use crate::circuits::merkle_circuit::{MerkleTreeCircuit, };
use crate::circuits::sample_cells::{CircuitParams, DatasetTreeCircuit, SampleCircuitInput};
use crate::circuits::utils::usize_to_bits_le_padded;
use crate::merkle_tree::merkle_safe::MerkleTree;
use crate::proof_input::test_params::{D, C, F, H, N_SLOTS, MAX_DEPTH};
// NOTE: for now these tests don't check the reconstructed root is equal to expected_root
// will be fixed later, but for that test check the prove_single_cell tests
#[test]
fn test_build_circuit() -> anyhow::Result<()> {
// circuit params
const D: usize = 2;
type C = PoseidonGoldilocksConfig;
type F = <C as GenericConfig<D>>::F;
type H = PoseidonHash;
// Generate random leaf data
let nleaves = 16; // Number of leaves
let max_depth = 4;
let data = (0..nleaves)
.map(|i| GoldilocksField::from_canonical_u64(i))
.collect::<Vec<_>>();
// Hash the data to obtain leaf hashes
let leaves: Vec<HashOut<GoldilocksField>> = data
.iter()
.map(|&element| {
// Hash each field element to get the leaf hash
PoseidonHash::hash_no_pad(&[element])
})
.collect();
//initialize the Merkle tree
let zero_hash = HashOut {
elements: [GoldilocksField::ZERO; 4],
};
let tree = MerkleTree::<F>::new(&leaves, zero_hash)?;
// select leaf index to prove
let leaf_index: usize = 8;
// get the Merkle proof for the selected leaf
let proof = tree.get_proof(leaf_index)?;
// sanity check:
let check = proof.verify(tree.layers[0][leaf_index],tree.root().unwrap()).unwrap();
assert_eq!(check, true);
// get the expected Merkle root
let expected_root = tree.root()?;
// create the circuit
let config = CircuitConfig::standard_recursion_config();
let mut builder = CircuitBuilder::<F, D>::new(config);
let (mut targets, reconstructed_root_target) = build_circuit(&mut builder, max_depth);
// expected Merkle root
let expected_root = builder.add_virtual_hash();
// check equality with expected root
for i in 0..NUM_HASH_OUT_ELTS {
builder.connect(expected_root.elements[i], reconstructed_root_target.elements[i]);
}
let path_bits = usize_to_bits_le_padded(leaf_index, max_depth);
let last_index = (nleaves - 1) as usize;
let last_bits = usize_to_bits_le_padded(last_index, max_depth);
let mask_bits = usize_to_bits_le_padded(last_index, max_depth+1);
// circuit input
let circuit_input = MerkleTreeCircuitInput::<F, D>{
leaf: tree.layers[0][leaf_index],
path_bits,
last_bits,
mask_bits,
merkle_path: proof.path,
};
// create a PartialWitness and assign
let mut pw = PartialWitness::new();
assign_witness(&mut pw, &mut targets, circuit_input)?;
pw.set_hash_target(expected_root, tree.root().unwrap());
// build the circuit
let data = builder.build::<C>();
// Prove the circuit with the assigned witness
let proof_with_pis = data.prove(pw)?;
// verify the proof
let verifier_data = data.verifier_data();
assert!(
verifier_data.verify(proof_with_pis).is_ok(),
"Merkle proof verification failed"
);
Ok(())
}
// same as test above but for all leaves
#[test]
fn test_verify_all_leaves() -> anyhow::Result<()> {
const D: usize = 2;
type C = PoseidonGoldilocksConfig;
type F = <C as GenericConfig<D>>::F;
type H = PoseidonHash;
let nleaves = 16; // Number of leaves
let max_depth = 4;
let data = (0..nleaves)
.map(|i| GoldilocksField::from_canonical_u64(i as u64))
.collect::<Vec<_>>();
// Hash the data to obtain leaf hashes
let leaves: Vec<HashOut<GoldilocksField>> = data
.iter()
.map(|&element| {
// Hash each field element to get the leaf hash
PoseidonHash::hash_no_pad(&[element])
})
.collect();
let zero_hash = HashOut {
elements: [GoldilocksField::ZERO; 4],
};
let tree = MerkleTree::<F>::new(&leaves, zero_hash)?;
let expected_root = tree.root()?;
let config = CircuitConfig::standard_recursion_config();
let mut builder = CircuitBuilder::<F, D>::new(config);
let (mut targets, reconstructed_root_target) = build_circuit(&mut builder, max_depth);
// expected Merkle root
let expected_root_target = builder.add_virtual_hash();
// check equality with expected root
for i in 0..NUM_HASH_OUT_ELTS {
builder.connect(expected_root_target.elements[i], reconstructed_root_target.elements[i]);
}
let data = builder.build::<C>();
for leaf_index in 0..nleaves {
let proof = tree.get_proof(leaf_index)?;
let check = proof.verify(tree.layers[0][leaf_index], expected_root)?;
assert!(
check,
"Merkle proof verification failed for leaf index {}",
leaf_index
);
let mut pw = PartialWitness::new();
let path_bits = usize_to_bits_le_padded(leaf_index, max_depth);
let last_index = (nleaves - 1) as usize;
let last_bits = usize_to_bits_le_padded(last_index, max_depth);
let mask_bits = usize_to_bits_le_padded(last_index, max_depth+1);
// circuit input
let circuit_input = MerkleTreeCircuitInput::<F, D>{
leaf: tree.layers[0][leaf_index],
path_bits,
last_bits,
mask_bits,
merkle_path: proof.path,
};
assign_witness(&mut pw, &mut targets, circuit_input)?;
pw.set_hash_target(expected_root_target, expected_root);
let proof_with_pis = data.prove(pw)?;
let verifier_data = data.verifier_data();
assert!(
verifier_data.verify(proof_with_pis).is_ok(),
"Merkle proof verification failed in circuit for leaf index {}",
leaf_index
);
}
Ok(())
}
}

View File

@ -0,0 +1 @@
pub mod merkle_circuit;