add tree recursion approach2

This commit is contained in:
M Alghazwi 2024-12-30 11:40:13 +03:00
parent aac4bfc39e
commit c717ab2770
No known key found for this signature in database
GPG Key ID: 646E567CAD7DB607
11 changed files with 675 additions and 156 deletions

View File

@ -12,7 +12,7 @@ use plonky2::iop::witness::PartialWitness;
use plonky2::plonk::circuit_builder::CircuitBuilder;
use plonky2::plonk::circuit_data::{CircuitConfig, CircuitData};
use plonky2::plonk::proof::ProofWithPublicInputs;
use crate::sponge::{hash_bytes_no_padding, hash_n_with_padding};
use crate::sponge::hash_bytes_no_padding;
use crate::params::{C, D, F};
/// generates circuit input (SampleCircuitInput) from fake data for testing
@ -388,16 +388,16 @@ impl<
}
}
// build the sampling circuit
// returns the proof and verifier ci
/// build the sampling circuit
/// returns the proof and circuit data
pub fn build_circuit(n_samples: usize, slot_index: usize) -> anyhow::Result<(CircuitData<F, C, D>, PartialWitness<F>)>{
let (data, pw, _) = build_circuit_with_targets(n_samples, slot_index).unwrap();
Ok((data, pw))
}
// build the sampling circuit ,
// returns the proof and verifier ci and targets
/// build the sampling circuit ,
/// returns the proof, circuit data, and targets
pub fn build_circuit_with_targets(n_samples: usize, slot_index: usize) -> anyhow::Result<(CircuitData<F, C, D>, PartialWitness<F>, SampleTargets)>{
// get input
let mut params = TestParams::default();
@ -428,7 +428,7 @@ pub fn build_circuit_with_targets(n_samples: usize, slot_index: usize) -> anyhow
Ok((data, pw, targets))
}
// prove the circuit
/// prove the circuit
pub fn prove_circuit(data: &CircuitData<F, C, D>, pw: &PartialWitness<F>) -> anyhow::Result<ProofWithPublicInputs<F, C, D>>{
// Prove the circuit with the assigned witness
let proof_with_pis = data.prove(pw.clone())?;
@ -436,22 +436,26 @@ pub fn prove_circuit(data: &CircuitData<F, C, D>, pw: &PartialWitness<F>) -> any
Ok(proof_with_pis)
}
/// returns exactly M default circuit input
pub fn get_m_default_circ_input<const M: usize>() -> [SampleCircuitInput<codex_plonky2_circuits::recursion::params::F,D>; M]{
let params = TestParams::default();
let one_circ_input = gen_testing_circuit_input::<codex_plonky2_circuits::recursion::params::F,D>(&params);
let circ_input: [SampleCircuitInput<codex_plonky2_circuits::recursion::params::F,D>; M] = (0..M)
.map(|_| one_circ_input.clone())
.collect::<Vec<_>>()
.try_into().unwrap();
circ_input
}
#[cfg(test)]
mod tests {
use std::time::Instant;
use super::*;
use plonky2::plonk::circuit_data::{CircuitConfig, CircuitData};
use plonky2::plonk::config::GenericConfig;
use plonky2::plonk::circuit_data::CircuitConfig;
use plonky2::iop::witness::PartialWitness;
use plonky2::plonk::circuit_builder::CircuitBuilder;
use codex_plonky2_circuits::circuits::params::CircuitParams;
use codex_plonky2_circuits::recursion::simple_recursion::{aggregate_sampling_proofs, aggregate_sampling_proofs_tree, aggregate_sampling_proofs_tree2};
use codex_plonky2_circuits::circuits::sample_cells::{SampleCircuit, SampleTargets};
use codex_plonky2_circuits::recursion::params::RecursionTreeParams;
use plonky2::plonk::proof::ProofWithPublicInputs;
use plonky2_poseidon2::serialization::{DefaultGateSerializer, DefaultGeneratorSerializer};
use crate::json::write_bytes_to_file;
use codex_plonky2_circuits::recursion::simple_recursion2::{SimpleRecursionCircuit, SimpleRecursionInput};
use codex_plonky2_circuits::circuits::sample_cells::SampleCircuit;
// use crate::params::{C, D, F};
// Test sample cells (non-circuit)

View File

@ -5,6 +5,3 @@ pub mod utils;
pub mod json;
pub mod tests;
mod sponge;
pub mod cyclic_recursion;
pub mod tree_recursion;
pub mod simple_recursion;

View File

@ -4,22 +4,18 @@
mod tests {
use std::time::Instant;
use anyhow::Result;
use plonky2::hash::hash_types::{HashOut, HashOutTarget, MerkleCapTarget, RichField};
use plonky2::hash::hash_types::HashOut;
use plonky2::hash::hashing::hash_n_to_hash_no_pad;
use plonky2::hash::poseidon::{PoseidonHash, PoseidonPermutation};
use plonky2::iop::witness::{PartialWitness, PartitionWitness, WitnessWrite};
use plonky2::hash::poseidon::PoseidonPermutation;
use plonky2::plonk::circuit_builder::CircuitBuilder;
use plonky2::plonk::circuit_data::{CircuitConfig, CircuitData, CommonCircuitData, VerifierCircuitData, VerifierCircuitTarget, VerifierOnlyCircuitData};
use plonky2::plonk::config::{AlgebraicHasher, GenericConfig, GenericHashOut, Hasher, PoseidonGoldilocksConfig};
use plonky2::recursion::cyclic_recursion::check_cyclic_proof_verifier_data;
use plonky2::recursion::dummy_circuit::cyclic_base_proof;
use plonky2::plonk::circuit_data::CircuitConfig;
use plonky2::plonk::config::GenericConfig;
use codex_plonky2_circuits::recursion::params::{F, D, C, Plonky2Proof};
use codex_plonky2_circuits::recursion::sampling_inner_circuit::SamplingRecursion;
use codex_plonky2_circuits::recursion::inner_circuit::InnerCircuit;
use plonky2_poseidon2::poseidon2_hash::poseidon2::{Poseidon2, Poseidon2Hash};
use crate::gen_input::gen_testing_circuit_input;
use crate::params::TestParams;
use codex_plonky2_circuits::recursion::cyclic_recursion::{common_data_for_recursion, CyclicCircuit};
use codex_plonky2_circuits::recursion::cyclic_recursion::CyclicCircuit;
/// Uses cyclic recursion to sample the dataset
@ -33,7 +29,6 @@ mod tests {
let mut builder = CircuitBuilder::<F, D>::new(config);
let one = builder.one();
// Circuit that does the sampling
let inner_sampling_circuit = SamplingRecursion::default();
let mut params = TestParams::default();
params.n_samples = 10;
@ -87,7 +82,6 @@ mod tests {
let mut builder = CircuitBuilder::<F, D>::new(config);
let one = builder.one();
// Circuit that does the sampling
let inner_sampling_circuit = SamplingRecursion::default();
let mut params = TestParams::default();
params.n_samples = 10;

View File

@ -1,2 +1,6 @@
pub mod merkle_circuit;
pub mod merkle;
pub mod simple_recursion;
pub mod cyclic_recursion;
pub mod tree_recursion1;
pub mod tree_recursion2;

View File

@ -1,4 +1,4 @@
// tests for simple recursion
// tests for simple recursion approaches
use std::time::Instant;
use plonky2::hash::hash_types::HashOut;
@ -6,17 +6,17 @@ use plonky2::iop::witness::PartialWitness;
use plonky2::plonk::circuit_builder::CircuitBuilder;
use plonky2::plonk::circuit_data::{CircuitConfig, CircuitData};
use plonky2_field::types::Field;
use codex_plonky2_circuits::recursion::params::RecursionTreeParams;
use codex_plonky2_circuits::recursion::simple_recursion::{aggregate_sampling_proofs, aggregate_sampling_proofs_tree};
use codex_plonky2_circuits::recursion::simple_recursion2::{SimpleRecursionCircuit, SimpleRecursionInput};
use codex_plonky2_circuits::recursion::sampling_inner_circuit::SamplingRecursion;
use codex_plonky2_circuits::recursion::simple_recursion::{aggregate_sampling_proofs,SimpleRecursionCircuit, SimpleRecursionInput};
use codex_plonky2_circuits::recursion::simple_tree_recursion::aggregate_sampling_proofs_tree;
use plonky2_poseidon2::serialization::{DefaultGateSerializer, DefaultGeneratorSerializer};
use crate::gen_input::{build_circuit, prove_circuit};
use crate::json::write_bytes_to_file;
use crate::params::{C, F, D};
use crate::params::{C, D, F};
// Test recursion
// Test simple recursion
#[test]
fn test_recursion() -> anyhow::Result<()> {
fn test_simple_recursion() -> anyhow::Result<()> {
// number of samples in each proof
let n_samples = 10;
// number of inner proofs:
@ -66,7 +66,7 @@ fn test_recursion() -> anyhow::Result<()> {
// Test simple tree recursion
#[test]
fn test_tree_recursion() -> anyhow::Result<()> {
fn test_simple_tree_recursion() -> anyhow::Result<()> {
// number of samples in each proof
let n_samples = 10;
// number of inner proofs:
@ -84,30 +84,32 @@ fn test_tree_recursion() -> anyhow::Result<()> {
let data = data.unwrap();
println!("inner circuit size = {:?}", data.common.degree_bits());
let gate_serializer = DefaultGateSerializer;
let generator_serializer =DefaultGeneratorSerializer::<C, D>::default();
let data_bytes = data.to_bytes(&gate_serializer, &generator_serializer).unwrap();
println!("inner proof circuit data size = {} bytes", data_bytes.len());
let file_path = "inner_circ_data.bin";
// Write data to the file
write_bytes_to_file(data_bytes, file_path).unwrap();
println!("Data written to {}", file_path);
// serialization
// let gate_serializer = DefaultGateSerializer;
// let generator_serializer =DefaultGeneratorSerializer::<C, D>::default();
// let data_bytes = data.to_bytes(&gate_serializer, &generator_serializer).unwrap();
// println!("inner proof circuit data size = {} bytes", data_bytes.len());
// let file_path = "inner_circ_data.bin";
// // Write data to the file
// write_bytes_to_file(data_bytes, file_path).unwrap();
// println!("Data written to {}", file_path);
let start_time = Instant::now();
let (proof, vd_agg) = aggregate_sampling_proofs_tree(&proofs_with_pi, data)?;
println!("prove_time = {:?}", start_time.elapsed());
println!("num of public inputs = {}", proof.public_inputs.len());
println!("agg pub input = {:?}", proof.public_inputs);
println!("outer circuit size = {:?}", vd_agg.common.degree_bits());
// let gate_serializer = DefaultGateSerializer;
// let generator_serializer =DefaultGeneratorSerializer::<C, D>::default();
let outer_data_bytes = vd_agg.to_bytes(&gate_serializer, &generator_serializer).unwrap();
println!("outer proof circuit data size = {} bytes", outer_data_bytes.len());
let file_path = "outer_circ_data.bin";
// Write data to the file
write_bytes_to_file(outer_data_bytes, file_path).unwrap();
println!("Data written to {}", file_path);
// serialization
// // let gate_serializer = DefaultGateSerializer;
// // let generator_serializer =DefaultGeneratorSerializer::<C, D>::default();
// let outer_data_bytes = vd_agg.to_bytes(&gate_serializer, &generator_serializer).unwrap();
// println!("outer proof circuit data size = {} bytes", outer_data_bytes.len());
// let file_path = "outer_circ_data.bin";
// // Write data to the file
// write_bytes_to_file(outer_data_bytes, file_path).unwrap();
// println!("Data written to {}", file_path);
// Verify the proof
let verifier_data = vd_agg.verifier_data();
@ -119,13 +121,13 @@ fn test_tree_recursion() -> anyhow::Result<()> {
Ok(())
}
// test another approach of the tree recursion
// test another approach of the simple recursion
#[test]
pub fn test_tree_recursion2()-> anyhow::Result<()>{
pub fn test_simple_recursion_approach2()-> anyhow::Result<()>{
// number of samples in each proof
let n_samples = 10;
let n_samples = 5;
// number of inner proofs:
let n_inner = 4;
const n_inner: usize = 4;
let mut data: Option<CircuitData<F, C, D>> = None;
// get proofs
@ -138,9 +140,10 @@ pub fn test_tree_recursion2()-> anyhow::Result<()>{
}
let data = data.unwrap();
let rt_params = RecursionTreeParams::new(n_inner);
let rec_circuit = SimpleRecursionCircuit::new(rt_params, data.verifier_data());
// careful here, the sampling recursion is the default so proofs should be for circuit
// with default params
let sampling_inner_circ = SamplingRecursion::default();
let rec_circuit = SimpleRecursionCircuit::<_,n_inner>::new(sampling_inner_circ);
// Create the circuit
let config = CircuitConfig::standard_recursion_config();
@ -148,7 +151,7 @@ pub fn test_tree_recursion2()-> anyhow::Result<()>{
// Create a PartialWitness
let mut pw = PartialWitness::new();
let targets = rec_circuit.build_circuit(&mut builder);
let targets = rec_circuit.build_circuit(&mut builder)?;
let start = Instant::now();
let agg_data = builder.build::<C>();

View File

@ -0,0 +1,185 @@
// some tests for approach 1 of the tree recursion
#[cfg(test)]
mod tests {
use std::time::Instant;
use anyhow::{anyhow, Result};
use plonky2::hash::poseidon::PoseidonHash;
use plonky2::plonk::circuit_builder::CircuitBuilder;
use plonky2::plonk::circuit_data::{CircuitConfig, CircuitData, CommonCircuitData, VerifierCircuitData, VerifierCircuitTarget, VerifierOnlyCircuitData};
use plonky2::plonk::config::{AlgebraicHasher, GenericConfig, GenericHashOut, Hasher, PoseidonGoldilocksConfig};
use plonky2_field::types::Field;
use codex_plonky2_circuits::circuits::sample_cells::{SampleCircuit, SampleCircuitInput};
use codex_plonky2_circuits::recursion::params::{F, D, C, Plonky2Proof};
use codex_plonky2_circuits::recursion::sampling_inner_circuit::SamplingRecursion;
use codex_plonky2_circuits::recursion::inner_circuit::InnerCircuit;
use plonky2_poseidon2::poseidon2_hash::poseidon2::{Poseidon2, Poseidon2Hash};
use crate::gen_input::get_m_default_circ_input;
use codex_plonky2_circuits::recursion::tree_recursion::{NodeCircuit, TreeRecursion};
/// Uses node recursion to sample the dataset
#[test]
fn test_node_recursion() -> Result<()> {
// const D: usize = 2;
// type C = PoseidonGoldilocksConfig;
// type F = <C as GenericConfig<D>>::F;
const M: usize = 1;
const N: usize = 2;
let config = CircuitConfig::standard_recursion_config();
let mut builder = CircuitBuilder::<F, D>::new(config);
let one = builder.one();
let inner_sampling_circuit = SamplingRecursion::default();
let mut node = NodeCircuit::<_,M,N>::new(inner_sampling_circuit);
let mut tree_circ = TreeRecursion::new(node);
let circ_input = get_m_default_circ_input::<M>();
let s = Instant::now();
tree_circ.build()?;
println!("build = {:?}", s.elapsed());
let s = Instant::now();
let proof = tree_circ.prove(&circ_input,None, true)?;
println!("prove = {:?}", s.elapsed());
println!("num of pi = {}", proof.public_inputs.len());
println!("pub input: {:?}", proof.public_inputs);
let s = Instant::now();
assert!(
tree_circ.verify_proof(proof).is_ok(),
"proof verification failed"
);
println!("verify = {:?}", s.elapsed());
Ok(())
}
/// Uses node recursion to sample the dataset
#[test]
fn test_tree_recursion_approach1() -> Result<()> {
// const D: usize = 2;
// type C = PoseidonGoldilocksConfig;
// type F = <C as GenericConfig<D>>::F;
const M: usize = 1;
const N: usize = 2;
const DEPTH: usize = 3;
const TOTAL_INPUT: usize = (N.pow(DEPTH as u32) - 1) / (N - 1);
let config = CircuitConfig::standard_recursion_config();
let mut builder = CircuitBuilder::<F, D>::new(config);
let inner_sampling_circuit = SamplingRecursion::default();
let mut node = NodeCircuit::<_,M,N>::new(inner_sampling_circuit);
let mut tree_circ = TreeRecursion::new(node);
let all_circ_input = get_m_default_circ_input::<TOTAL_INPUT>().to_vec();
let s = Instant::now();
tree_circ.build()?;
println!("build = {:?}", s.elapsed());
let s = Instant::now();
let proof = tree_circ.prove_tree(all_circ_input.clone(),DEPTH)?;
println!("prove = {:?}", s.elapsed());
println!("num of pi = {}", proof.public_inputs.len());
println!("pub input: {:?}", proof.public_inputs);
// Extract the final public input hash from the proof
let final_proof_hash = &proof.public_inputs[0..4];
// Recompute the expected final public input hash (outside the circuit)
let expected_hash = compute_expected_pub_input_hash::<SamplingRecursion>(
&all_circ_input,
DEPTH,
M,
N
)?;
// Check that the final hash in the proof matches the expected hash
assert_eq!(final_proof_hash, expected_hash.as_slice(), "Public input hash mismatch");
let s = Instant::now();
assert!(
tree_circ.verify_proof(proof).is_ok(),
"proof verification failed"
);
println!("verify = {:?}", s.elapsed());
Ok(())
}
/// Recursively compute the final public input hash for a single node in the recursion tree.
/// This is the same logic from `NodeCircuit::build_circuit`
/// TODO: optimize this
fn compute_node_hash<I: InnerCircuit<Input = SampleCircuitInput<F,D>>>(
all_circ_inputs: &[I::Input],
depth: usize,
current_depth: usize,
node_idx: usize,
M: usize,
N: usize,
) -> [F; 4] {
// Calculate the index in all_circ_inputs for this node's M inputs.
// Total inputs per layer: sum_{k=0}^{current_depth-1} M*N^k = M * ((N^current_depth - 1)/(N-1))
let offset_for_layer = ((N.pow(current_depth as u32) - 1) / (N - 1)) * M;
let node_start = offset_for_layer + node_idx * M;
let node_inputs = &all_circ_inputs[node_start..node_start + M];
// Compute the outer public input hash:
// public inputs are [slot_index, dataset_root.elements, entropy.elements].
let mut outer_pi_hashes = vec![];
for inp in node_inputs {
let mut pi_vec = vec![inp.slot_index];
pi_vec.extend_from_slice(&inp.dataset_root.elements);
pi_vec.extend_from_slice(&inp.entropy.elements);
let hash_res = PoseidonHash::hash_no_pad(&pi_vec);
outer_pi_hashes.extend_from_slice(&hash_res.elements);
}
// hash all these M hashes into one
let outer_pi_hash = PoseidonHash::hash_no_pad(&outer_pi_hashes);
let is_leaf = current_depth == depth - 1;
// Compute the inner proof hash (or zero hash if leaf)
let inner_pi_hash_or_zero = if is_leaf {
// condition = false at leaf, so inner proofs = zero hash
[F::ZERO; 4]
} else {
// condition = true at non-leaf node -> recursively compute child hashes
let next_depth = current_depth + 1;
let child_start = node_idx * N;
let mut inner_pub_input_hashes = vec![];
for i in child_start..child_start + N {
let child_hash = compute_node_hash::<I>(all_circ_inputs, depth, next_depth, i, M, N);
inner_pub_input_hashes.extend_from_slice(&child_hash);
}
let inner_pub_input_hash = PoseidonHash::hash_no_pad(&inner_pub_input_hashes);
inner_pub_input_hash.elements
};
// Combine outer_pi_hash and inner_pi_hash_or_zero
let mut final_input = vec![];
final_input.extend_from_slice(&outer_pi_hash.elements);
final_input.extend_from_slice(&inner_pi_hash_or_zero);
let final_hash = PoseidonHash::hash_no_pad(&final_input);
final_hash.elements
}
/// Compute the expected public input hash for the entire recursion tree.
/// This function calls `compute_node_hash` starting from the root (layer 0, node 0).
pub fn compute_expected_pub_input_hash<I: InnerCircuit<Input = SampleCircuitInput<F,D>>>(
all_circ_inputs: &[I::Input],
depth: usize,
M: usize,
N: usize,
) -> Result<Vec<F>> {
// The root node is at layer = 0 and node_idx = 0
let final_hash = compute_node_hash::<I>(all_circ_inputs, depth, 0, 0, M, N);
Ok(final_hash.to_vec())
}
}

View File

@ -0,0 +1,307 @@
// some tests for approach 2 of the tree recursion
#[cfg(test)]
mod tests {
use std::time::Instant;
use anyhow::{anyhow, Result};
use plonky2::iop::witness::PartialWitness;
use plonky2::plonk::circuit_builder::CircuitBuilder;
use plonky2::plonk::circuit_data::{CircuitConfig, CircuitData, CommonCircuitData, VerifierCircuitData, VerifierCircuitTarget, VerifierOnlyCircuitData};
use plonky2::plonk::config::{AlgebraicHasher, GenericConfig, GenericHashOut, Hasher, PoseidonGoldilocksConfig};
use plonky2::plonk::proof::ProofWithPublicInputs;
use codex_plonky2_circuits::circuits::params::CircuitParams;
use codex_plonky2_circuits::circuits::sample_cells::{SampleCircuit, SampleCircuitInput};
use codex_plonky2_circuits::recursion::params::{F, D, C, Plonky2Proof};
use codex_plonky2_circuits::recursion::sampling_inner_circuit::SamplingRecursion;
use codex_plonky2_circuits::recursion::inner_circuit::InnerCircuit;
use codex_plonky2_circuits::recursion::leaf_circuit::{LeafCircuit, LeafInput};
use plonky2_poseidon2::poseidon2_hash::poseidon2::{Poseidon2, Poseidon2Hash};
use crate::gen_input::gen_testing_circuit_input;
use crate::params::TestParams;
use codex_plonky2_circuits::recursion::tree_recursion2::{NodeCircuit as nodeC, TreeRecursion as TR};
use codex_plonky2_circuits::recursion::utils::{get_dummy_leaf_proof, get_dummy_node_proof};
use crate::gen_input::get_m_default_circ_input;
/// Uses node recursion to sample the dataset
#[test]
fn test_leaf_circuit() -> Result<()> {
// const D: usize = 2;
// type C = PoseidonGoldilocksConfig;
// type F = <C as GenericConfig<D>>::F;
const M: usize = 1;
const N: usize = 2;
let config = CircuitConfig::standard_recursion_config();
let mut builder = CircuitBuilder::<F, D>::new(config);
let params = TestParams::default();
let one_circ_input = gen_testing_circuit_input::<F,D>(&params);
let samp_circ = SampleCircuit::<F,D>::new(CircuitParams::default());
let inner_tar = samp_circ.sample_slot_circuit_with_public_input(&mut builder);
let mut pw = PartialWitness::<F>::new();
samp_circ.sample_slot_assign_witness(&mut pw,&inner_tar,&one_circ_input);
let inner_d = builder.build::<C>();
let inner_prf = inner_d.prove(pw)?;
let leaf_in = LeafInput{
inner_proof:inner_prf,
verifier_data: inner_d.verifier_data(),
};
let config2 = CircuitConfig::standard_recursion_config();
let mut builder = CircuitBuilder::<F, D>::new(config2);
let inner_circ = SamplingRecursion::default();
let leaf_circuit = LeafCircuit::new(inner_circ);
let s = Instant::now();
let leaf_tar = leaf_circuit.build(&mut builder)?;
let circ_data = builder.build::<C>();
println!("build = {:?}", s.elapsed());
let s = Instant::now();
// let proof = tree_circ.prove(&[leaf_in],None, true)?;
let mut pw = PartialWitness::<F>::new();
leaf_circuit.assign_targets(&mut pw, &leaf_tar, &leaf_in)?;
let proof = circ_data.prove(pw)?;
println!("prove = {:?}", s.elapsed());
println!("num of pi = {}", proof.public_inputs.len());
println!("pub input: {:?}", proof.public_inputs);
let s = Instant::now();
assert!(
circ_data.verify(proof).is_ok(),
"proof verification failed"
);
println!("verify = {:?}", s.elapsed());
Ok(())
}
#[test]
fn test_node_circuit_approach2() -> Result<()> {
// use predefined: C, D, F c
const N: usize = 2; // binary tree
let config = CircuitConfig::standard_recursion_config();
let mut sampling_builder = CircuitBuilder::<F, D>::new(config);
//------------ sampling inner circuit ----------------------
// Circuit that does the sampling - default input
let mut params = TestParams::default();
// params.n_samples = 10;
let one_circ_input = gen_testing_circuit_input::<F,D>(&params);
let samp_circ = SampleCircuit::<F,D>::new(CircuitParams::default());
let inner_tar = samp_circ.sample_slot_circuit_with_public_input(&mut sampling_builder);
// get generate a sampling proof
let mut pw = PartialWitness::<F>::new();
samp_circ.sample_slot_assign_witness(&mut pw,&inner_tar,&one_circ_input);
let inner_data = sampling_builder.build::<C>();
let inner_proof = inner_data.prove(pw)?;
// ------------------- leaf --------------------
// leaf circuit that verifies the sampling proof
let inner_circ = SamplingRecursion::default();
let leaf_circuit = LeafCircuit::new(inner_circ);
let leaf_in = LeafInput{
inner_proof,
verifier_data: inner_data.verifier_data(),
};
let config = CircuitConfig::standard_recursion_config();
let mut leaf_builder = CircuitBuilder::<F, D>::new(config);
// build
let s = Instant::now();
let leaf_targets = leaf_circuit.build(&mut leaf_builder)?;
let leaf_circ_data = leaf_builder.build::<C>();
println!("build = {:?}", s.elapsed());
// prove
let s = Instant::now();
let mut pw = PartialWitness::<F>::new();
leaf_circuit.assign_targets(&mut pw, &leaf_targets, &leaf_in)?;
let leaf_proof = leaf_circ_data.prove(pw)?;
println!("prove = {:?}", s.elapsed());
println!("num of pi = {}", leaf_proof.public_inputs.len());
println!("pub input: {:?}", leaf_proof.public_inputs);
// verify
let s = Instant::now();
assert!(
leaf_circ_data.verify(leaf_proof.clone()).is_ok(),
"proof verification failed"
);
println!("verify = {:?}", s.elapsed());
// ------------- Node circuit ------------------
// node circuit that verifies leafs or itself
// build
let s = Instant::now();
let mut node = nodeC::build_circuit()?;
println!("build = {:?}", s.elapsed());
// prove leaf
let s = Instant::now();
let mut pw = PartialWitness::<F>::new();
let leaf_proofs: [ProofWithPublicInputs<F, C, D>; N] = (0..N)
.map(|_| {
leaf_proof.clone()
})
.collect::<Vec<_>>()
.try_into()
.map_err(|_| anyhow!("Expected exactly M inner circuits"))?;
let dummy_node_proof = get_dummy_node_proof(
&node.node_data.inner_node_common_data,
&node.node_data.node_circuit_data.verifier_only,
);
let dummy_node_proofs: [ProofWithPublicInputs<F, C, D>; N] = (0..N)
.map(|_| {
dummy_node_proof.clone()
})
.collect::<Vec<_>>()
.try_into()
.map_err(|_| anyhow!("Expected exactly M inner circuits"))?;
nodeC::<N>::assign_targets(
node.node_targets.clone(), //targets
Some(leaf_proofs), // leaf proofs
Some(dummy_node_proofs), // node proofs (dummy here)
&node.node_data.leaf_circuit_data.verifier_only, // leaf verifier data
&mut pw, // partial witness
true // is leaf
)?;
let node_proof = node.node_data.node_circuit_data.prove(pw)?;
println!("prove = {:?}", s.elapsed());
println!("num of pi = {}", node_proof.public_inputs.len());
println!("pub input: {:?}", node_proof.public_inputs);
let s = Instant::now();
assert!(
node.node_data.node_circuit_data.verify(node_proof.clone()).is_ok(),
"proof verification failed"
);
println!("verify = {:?}", s.elapsed());
// prove node
let s = Instant::now();
let mut pw = PartialWitness::<F>::new();
let node_proofs: [ProofWithPublicInputs<F, C, D>; N] = (0..N)
.map(|_| {
node_proof.clone()
})
.collect::<Vec<_>>()
.try_into()
.map_err(|_| anyhow!("Expected exactly M inner circuits"))?;
let dummy_leaf_proof = get_dummy_leaf_proof(
&node.node_data.leaf_circuit_data.common
);
let dummy_leaf_proofs: [ProofWithPublicInputs<F, C, D>; N] = (0..N)
.map(|_| {
dummy_leaf_proof.clone()
})
.collect::<Vec<_>>()
.try_into()
.map_err(|_| anyhow!("Expected exactly M inner circuits"))?;
nodeC::<N>::assign_targets(
node.node_targets.clone(), //targets
Some(dummy_leaf_proofs), // leaf proofs
Some(node_proofs), // node proofs (dummy here)
&node.node_data.leaf_circuit_data.verifier_only, // leaf verifier data
&mut pw, // partial witness
false // is leaf
)?;
let node_proof = node.node_data.node_circuit_data.prove(pw)?;
// let node_proof = node_d.prove(pw)?;
println!("prove = {:?}", s.elapsed());
println!("num of pi = {}", node_proof.public_inputs.len());
println!("pub input: {:?}", node_proof.public_inputs);
let s = Instant::now();
assert!(
node.node_data.node_circuit_data.verify(node_proof.clone()).is_ok(),
"proof verification failed"
);
println!("verify = {:?}", s.elapsed());
Ok(())
}
#[test]
fn test_tree_recursion_approach2() -> Result<()> {
// use predefined: C, D, F c
const N: usize = 2; // binary tree
const K: usize = 4; // number of leaves/slots sampled - should be power of 2
let config = CircuitConfig::standard_recursion_config();
let mut sampling_builder = CircuitBuilder::<F, D>::new(config);
//------------ sampling inner circuit ----------------------
// Circuit that does the sampling - default input
let mut params = TestParams::default();
params.n_samples = 10;
let one_circ_input = gen_testing_circuit_input::<F,D>(&params);
let samp_circ = SampleCircuit::<F,D>::new(CircuitParams::default());
let inner_tar = samp_circ.sample_slot_circuit_with_public_input(&mut sampling_builder);
// get generate a sampling proof
let mut pw = PartialWitness::<F>::new();
samp_circ.sample_slot_assign_witness(&mut pw,&inner_tar,&one_circ_input);
let inner_data = sampling_builder.build::<C>();
println!("sampling circuit degree bits = {:?}", inner_data.common.degree_bits());
let inner_proof = inner_data.prove(pw)?;
// ------------------- leaf --------------------
// leaf circuit that verifies the sampling proof
let inner_circ = SamplingRecursion::default();
let leaf_circuit = LeafCircuit::new(inner_circ);
let leaf_in = LeafInput{
inner_proof,
verifier_data: inner_data.verifier_data(),
};
let config = CircuitConfig::standard_recursion_config();
let mut leaf_builder = CircuitBuilder::<F, D>::new(config);
// build
let s = Instant::now();
let leaf_targets = leaf_circuit.build(&mut leaf_builder)?;
let leaf_circ_data = leaf_builder.build::<C>();
println!("build = {:?}", s.elapsed());
println!("leaf circuit degree bits = {:?}", leaf_circ_data.common.degree_bits());
// prove
let s = Instant::now();
let mut pw = PartialWitness::<F>::new();
leaf_circuit.assign_targets(&mut pw, &leaf_targets, &leaf_in)?;
let leaf_proof = leaf_circ_data.prove(pw)?;
println!("prove = {:?}", s.elapsed());
println!("num of pi = {}", leaf_proof.public_inputs.len());
println!("pub input: {:?}", leaf_proof.public_inputs);
// verify
let s = Instant::now();
assert!(
leaf_circ_data.verify(leaf_proof.clone()).is_ok(),
"proof verification failed"
);
println!("verify = {:?}", s.elapsed());
// ------------- tree circuit ------------------
// node circuit that verifies leafs or itself
// build
let s = Instant::now();
let mut tree = TR::<N>::build()?;
println!("build = {:?}", s.elapsed());
println!("tree circuit degree bits = {:?}", tree.node.node_data.node_circuit_data.common.degree_bits());
// prove leaf
let s = Instant::now();
// let mut pw = PartialWitness::<F>::new();
let leaf_proofs: Vec<ProofWithPublicInputs<F, C, D>> = (0..K)
.map(|_| {
leaf_proof.clone()
})
.collect::<Vec<_>>();
let tree_root_proof = tree.prove_tree(leaf_proofs)?;
println!("prove = {:?}", s.elapsed());
println!("num of pi = {}", tree_root_proof.public_inputs.len());
println!("pub input: {:?}", tree_root_proof.public_inputs);
let s = Instant::now();
assert!(
tree.verify_proof(tree_root_proof.clone()).is_ok(),
"proof verification failed"
);
println!("verify = {:?}", s.elapsed());
Ok(())
}
}

View File

@ -1,68 +0,0 @@
// some tests for cyclic recursion
#[cfg(test)]
mod tests {
use std::time::Instant;
use anyhow::{anyhow, Result};
use plonky2::plonk::circuit_builder::CircuitBuilder;
use plonky2::plonk::circuit_data::{CircuitConfig, CircuitData, CommonCircuitData, VerifierCircuitData, VerifierCircuitTarget, VerifierOnlyCircuitData};
use plonky2::plonk::config::{AlgebraicHasher, GenericConfig, GenericHashOut, Hasher, PoseidonGoldilocksConfig};
use codex_plonky2_circuits::circuits::sample_cells::SampleCircuitInput;
use codex_plonky2_circuits::recursion::params::{F, D, C, Plonky2Proof};
use codex_plonky2_circuits::recursion::sampling_inner_circuit::SamplingRecursion;
use codex_plonky2_circuits::recursion::inner_circuit::InnerCircuit;
use plonky2_poseidon2::poseidon2_hash::poseidon2::{Poseidon2, Poseidon2Hash};
use crate::gen_input::gen_testing_circuit_input;
use crate::params::TestParams;
use codex_plonky2_circuits::recursion::tree_recursion::{NodeCircuit, TreeRecursion};
fn get_m_default_circ_input<const M: usize>() -> [SampleCircuitInput<F,D>; M]{
let mut params = TestParams::default();
params.n_samples = 10;
let one_circ_input = gen_testing_circuit_input::<F,D>(&params);
let circ_input: [SampleCircuitInput<F,D>; M] = (0..M)
.map(|_| one_circ_input.clone())
.collect::<Vec<_>>()
.try_into().unwrap();
circ_input
}
/// Uses node recursion to sample the dataset
#[test]
fn test_node_recursion() -> Result<()> {
// const D: usize = 2;
// type C = PoseidonGoldilocksConfig;
// type F = <C as GenericConfig<D>>::F;
const M: usize = 1;
const N: usize = 2;
let config = CircuitConfig::standard_recursion_config();
let mut builder = CircuitBuilder::<F, D>::new(config);
let one = builder.one();
// Circuit that does the sampling
let inner_sampling_circuit = SamplingRecursion::default();
let mut cyclic_circ = NodeCircuit::<_,M,N>::new(inner_sampling_circuit);
let mut tree_circ = TreeRecursion::new(cyclic_circ);
let circ_input = get_m_default_circ_input::<M>();
let s = Instant::now();
tree_circ.build()?;
println!("build = {:?}", s.elapsed());
let s = Instant::now();
let proof = tree_circ.prove(&circ_input,None, true)?;
println!("prove = {:?}", s.elapsed());
println!("num of pi = {}", proof.public_inputs.len());
println!("pub input: {:?}", proof.public_inputs);
let s = Instant::now();
assert!(
tree_circ.verify_proof(proof).is_ok(),
"proof verification failed"
);
println!("verify = {:?}", s.elapsed());
Ok(())
}
}

View File

@ -52,4 +52,8 @@ harness = false
[[bench]]
name = "tree_recursion"
harness = false
[[bench]]
name = "simple_tree_recursion"
harness = false

View File

@ -0,0 +1,58 @@
use criterion::{Criterion, criterion_group, criterion_main};
use plonky2::plonk::circuit_data::VerifierCircuitData;
use plonky2::plonk::config::GenericConfig;
use plonky2::plonk::proof::ProofWithPublicInputs;
use codex_plonky2_circuits::recursion::simple_tree_recursion::aggregate_sampling_proofs_tree2;
use proof_input::params::{C, D, F};
use proof_input::gen_input::{build_circuit, prove_circuit};
/// Benchmark for building, proving, and verifying the Plonky2 recursion circuit.
fn bench_tree_recursion(c: &mut Criterion) {
// num of inner proofs
let num_of_inner_proofs = 4;
// number of samples in each proof
let n_samples = 10;
let (data, pw) = build_circuit(n_samples, 3).unwrap();
// get proofs
let mut proofs_with_pi = vec![];
for i in 0..num_of_inner_proofs{
proofs_with_pi.push(prove_circuit(&data, &pw).unwrap());
}
let vd = data.verifier_data();
let mut group = c.benchmark_group("bench simple tree recursion");
let mut agg_proof_with_pis: Option<ProofWithPublicInputs<F, C, D>> = None;
let mut agg_vd: Option<VerifierCircuitData<F, C, D>> = None;
// Benchmark the Circuit Building Phase
group.bench_function("build & prove Circuit", |b| {
b.iter(|| {
let (agg_p, agg_d) = aggregate_sampling_proofs_tree2(&proofs_with_pi, vd.clone()).unwrap();
agg_proof_with_pis = Some(agg_p);
agg_vd = Some(agg_d);
})
});
let proof = agg_proof_with_pis.unwrap();
println!("Proof size: {} bytes", proof.to_bytes().len());
// Benchmark the Verifying Phase
let loc_vd = agg_vd.unwrap();
group.bench_function("Verify Proof", |b| {
b.iter(|| {
loc_vd.clone().verify(proof.clone()).expect("Failed to verify proof");
})
});
group.finish();
}
/// Criterion benchmark group
criterion_group!{
name = recursion;
config = Criterion::default().sample_size(10);
targets = bench_tree_recursion
}
criterion_main!(recursion);

View File

@ -1,49 +1,80 @@
use anyhow::Result;
use criterion::{criterion_group, criterion_main, Criterion};
use plonky2::plonk::circuit_data::{CircuitConfig, CircuitData, VerifierCircuitData};
use criterion::{Criterion, criterion_group, criterion_main};
use plonky2::plonk::circuit_builder::CircuitBuilder;
use plonky2::plonk::circuit_data::{CircuitConfig, VerifierCircuitData};
use plonky2::plonk::config::GenericConfig;
use codex_plonky2_circuits::recursion::simple_recursion::{aggregate_sampling_proofs_tree, aggregate_sampling_proofs_tree2};
use plonky2::plonk::proof::ProofWithPublicInputs;
use proof_input::params::{D, C, F, Params, TestParams};
use proof_input::gen_input::{build_circuit, prove_circuit};
use codex_plonky2_circuits::circuits::sample_cells::SampleCircuitInput;
use codex_plonky2_circuits::recursion::sampling_inner_circuit::SamplingRecursion;
use codex_plonky2_circuits::recursion::tree_recursion::{NodeCircuit, TreeRecursion};
use proof_input::params::{C, D, F, TestParams};
use proof_input::gen_input::gen_testing_circuit_input;
/// Benchmark for building, proving, and verifying the Plonky2 recursion circuit.
fn get_m_default_circ_input<const M: usize>() -> [SampleCircuitInput<codex_plonky2_circuits::recursion::params::F,D>; M]{
let mut params = TestParams::default();
params.n_samples = 10;
let one_circ_input = gen_testing_circuit_input::<codex_plonky2_circuits::recursion::params::F,D>(&params);
let circ_input: [SampleCircuitInput<codex_plonky2_circuits::recursion::params::F,D>; M] = (0..M)
.map(|_| one_circ_input.clone())
.collect::<Vec<_>>()
.try_into().unwrap();
circ_input
}
/// Benchmark for building, proving, and verifying the Plonky2 tree recursion circuit.
fn bench_tree_recursion(c: &mut Criterion) {
// num of inner proofs
let num_of_inner_proofs = 4;
// number of samples in each proof
let n_samples = 10;
let (data, pw) = build_circuit(n_samples, 3).unwrap();
const M: usize = 1;
const N: usize = 2;
// get proofs
let mut proofs_with_pi = vec![];
for i in 0..num_of_inner_proofs{
proofs_with_pi.push(prove_circuit(&data, &pw).unwrap());
}
let vd = data.verifier_data();
let config = CircuitConfig::standard_recursion_config();
let mut builder = CircuitBuilder::<codex_plonky2_circuits::recursion::params::F, D>::new(config);
// Circuit that does the sampling
let inner_sampling_circuit = SamplingRecursion::default();
let mut cyclic_circ = NodeCircuit::<_,M,N>::new(inner_sampling_circuit);
let mut tree_circ = TreeRecursion::new(cyclic_circ);
let circ_input = get_m_default_circ_input::<M>();
// // get proofs
// let mut proofs_with_pi = vec![];
// for i in 0..num_of_inner_proofs{
// proofs_with_pi.push(prove_circuit(&data, &pw).unwrap());
// }
// let vd = data.verifier_data();
let mut group = c.benchmark_group("bench simple tree recursion");
let mut agg_proof_with_pis: Option<ProofWithPublicInputs<F, C, D>> = None;
let mut agg_vd: Option<VerifierCircuitData<F, C, D>> = None;
// Benchmark the Circuit Building Phase
group.bench_function("build & prove Circuit", |b| {
group.bench_function("build", |b| {
b.iter(|| {
let (agg_p, agg_d) = aggregate_sampling_proofs_tree2(&proofs_with_pi, vd.clone()).unwrap();
agg_proof_with_pis = Some(agg_p);
agg_vd = Some(agg_d);
let mut cyclic_circ = NodeCircuit::<_,M,N>::new(inner_sampling_circuit.clone());
let mut tree_circ = TreeRecursion::new(cyclic_circ);
tree_circ.build();
})
});
let proof = agg_proof_with_pis.unwrap();
println!("Proof size: {} bytes", proof.to_bytes().len());
// let proof = agg_proof_with_pis.unwrap();
// println!("Proof size: {} bytes", proof.to_bytes().len());
let mut proof: Option<ProofWithPublicInputs<F, C, D>> = None;
// Benchmark the Circuit prove Phase
group.bench_function("prove", |b| {
b.iter(|| {
proof = Some(tree_circ.prove(&circ_input,None, true).unwrap());
})
});
// let proof = tree_circ.prove(&circ_input,None, true)?;
// Benchmark the Verifying Phase
let loc_vd = agg_vd.unwrap();
group.bench_function("Verify Proof", |b| {
b.iter(|| {
loc_vd.clone().verify(proof.clone()).expect("Failed to verify proof");
tree_circ.verify_proof(proof.unwrap()).expect("Failed to verify proof");
})
});
@ -53,7 +84,7 @@ fn bench_tree_recursion(c: &mut Criterion) {
/// Criterion benchmark group
criterion_group!{
name = recursion;
config = Criterion::default().sample_size(10);
config = Criterion::default().sample_size(3);
targets = bench_tree_recursion
}
criterion_main!(recursion);