add recursion tests and refactor

This commit is contained in:
M Alghazwi 2024-12-13 16:38:05 +03:00
parent b280a5252c
commit 9a8e0000ee
No known key found for this signature in database
GPG Key ID: 646E567CAD7DB607
10 changed files with 382 additions and 22726 deletions

Binary file not shown.

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,118 @@
// some tests for cyclic recursion
#[cfg(test)]
mod tests {
use std::time::Instant;
use anyhow::Result;
use plonky2::hash::hash_types::{HashOut, HashOutTarget, MerkleCapTarget, RichField};
use plonky2::hash::hashing::hash_n_to_hash_no_pad;
use plonky2::hash::poseidon::{PoseidonHash, PoseidonPermutation};
use plonky2::iop::witness::{PartialWitness, PartitionWitness, WitnessWrite};
use plonky2::plonk::circuit_builder::CircuitBuilder;
use plonky2::plonk::circuit_data::{CircuitConfig, CircuitData, CommonCircuitData, VerifierCircuitData, VerifierCircuitTarget, VerifierOnlyCircuitData};
use plonky2::plonk::config::{AlgebraicHasher, GenericConfig, GenericHashOut, Hasher, PoseidonGoldilocksConfig};
use plonky2::recursion::cyclic_recursion::check_cyclic_proof_verifier_data;
use plonky2::recursion::dummy_circuit::cyclic_base_proof;
use codex_plonky2_circuits::recursion::params::{F, D, C, Plonky2Proof};
use codex_plonky2_circuits::recursion::sampling_inner_circuit::SamplingRecursion;
use codex_plonky2_circuits::recursion::inner_circuit::InnerCircuit;
use plonky2_poseidon2::poseidon2_hash::poseidon2::{Poseidon2, Poseidon2Hash};
use crate::gen_input::gen_testing_circuit_input;
use crate::params::TestParams;
use codex_plonky2_circuits::recursion::cyclic_recursion::{common_data_for_recursion, CyclicCircuit};
/// Uses cyclic recursion to sample the dataset
#[test]
fn test_cyclic_recursion() -> Result<()> {
// const D: usize = 2;
// type C = PoseidonGoldilocksConfig;
// type F = <C as GenericConfig<D>>::F;
let config = CircuitConfig::standard_recursion_config();
let mut builder = CircuitBuilder::<F, D>::new(config);
let one = builder.one();
// Circuit that does the sampling
let inner_sampling_circuit = SamplingRecursion::default();
let mut params = TestParams::default();
params.n_samples = 10;
let circ_input = gen_testing_circuit_input::<F,D>(&params);
let mut cyclic_circ = CyclicCircuit::new(inner_sampling_circuit);
let s = Instant::now();
cyclic_circ.build_circuit()?;
println!("build = {:?}", s.elapsed());
let s = Instant::now();
let proof = cyclic_circ.prove_one_layer(&circ_input)?;
println!("prove = {:?}", s.elapsed());
println!("num of pi = {}", proof.public_inputs.len());
println!("pub input: {:?}", proof.public_inputs);
let s = Instant::now();
assert!(
cyclic_circ.verify_latest_proof().is_ok(),
"proof verification failed"
);
println!("verify = {:?}", s.elapsed());
let mut hash_input = vec![];
hash_input.push(circ_input.slot_index);
hash_input.extend_from_slice(&circ_input.dataset_root.elements);
hash_input.extend_from_slice(&circ_input.entropy.elements);
// let hash_res = PoseidonHash::hash_no_pad(&hash_input);
let hash_res = hash_n_to_hash_no_pad::<F, PoseidonPermutation<F>>(&hash_input);
let zero_hash = HashOut::<F>::ZERO;
let mut hash_input2 = vec![];
hash_input2.extend_from_slice(&hash_res.elements);
hash_input2.extend_from_slice(&zero_hash.elements);
let hash_res = hash_n_to_hash_no_pad::<F, PoseidonPermutation<F>>(&hash_input2);
println!("hash input = {:?}", hash_res.elements);
Ok(())
}
/// Uses cyclic recursion to sample the dataset n times
#[test]
fn test_cyclic_recursion_n_layers() -> Result<()> {
// const D: usize = 2;
// type C = PoseidonGoldilocksConfig;
// type F = <C as GenericConfig<D>>::F;
const N : usize = 2;
let config = CircuitConfig::standard_recursion_config();
let mut builder = CircuitBuilder::<F, D>::new(config);
let one = builder.one();
// Circuit that does the sampling
let inner_sampling_circuit = SamplingRecursion::default();
let mut params = TestParams::default();
params.n_samples = 10;
let mut circ_inputs = vec![];
for i in 0..N {
circ_inputs.push(gen_testing_circuit_input::<F, D>(&params));
}
let mut cyclic_circ = CyclicCircuit::new(inner_sampling_circuit);
let s = Instant::now();
cyclic_circ.build_circuit()?;
println!("build = {:?}", s.elapsed());
let s = Instant::now();
let proof = cyclic_circ.prove_n_layers(N,circ_inputs)?;
println!("prove = {:?}", s.elapsed());
println!("num of pi = {}", proof.public_inputs.len());
println!("pub input: {:?}", proof.public_inputs);
let s = Instant::now();
assert!(
cyclic_circ.verify_latest_proof().is_ok(),
"proof verification failed"
);
println!("verify = {:?}", s.elapsed());
Ok(())
}
}

View File

@ -414,13 +414,13 @@ pub fn build_circuit_with_targets(n_samples: usize, slot_index: usize) -> anyhow
// build the circuit
let circ = SampleCircuit::new(circuit_params.clone());
let mut targets = circ.sample_slot_circuit(&mut builder);
let mut targets = circ.sample_slot_circuit_with_public_input(&mut builder);
// Create a PartialWitness and assign
let mut pw = PartialWitness::new();
// assign a witness
circ.sample_slot_assign_witness(&mut pw, &mut targets, circ_input);
circ.sample_slot_assign_witness(&mut pw, &targets, &circ_input);
// Build the circuit
let data = builder.build::<C>();
@ -479,13 +479,13 @@ mod tests {
// build the circuit
let circ = SampleCircuit::new(circuit_params.clone());
let mut targets = circ.sample_slot_circuit(&mut builder);
let mut targets = circ.sample_slot_circuit_with_public_input(&mut builder);
// Create a PartialWitness and assign
let mut pw = PartialWitness::new();
// assign a witness
circ.sample_slot_assign_witness(&mut pw, &mut targets, circ_input);
circ.sample_slot_assign_witness(&mut pw, &targets, &circ_input);
// Build the circuit
let data = builder.build::<C>();
@ -506,170 +506,4 @@ mod tests {
Ok(())
}
// Test recursion
#[test]
fn test_recursion() -> anyhow::Result<()> {
// number of samples in each proof
let n_samples = 10;
// number of inner proofs:
let n_inner = 4;
let mut data: Option<CircuitData<F, C, D>> = None;
// get proofs
let mut proofs_with_pi = vec![];
for i in 0..n_inner{
// build the circuit
let (data_i, pw) = build_circuit(n_samples, i)?;
// prove
proofs_with_pi.push(prove_circuit(&data_i, &pw)?);
data = Some(data_i);
}
println!("num of public inputs inner proof = {}", proofs_with_pi[0].public_inputs.len());
// Create the circuit
let config = CircuitConfig::standard_recursion_config();
let mut builder = CircuitBuilder::<F, D>::new(config);
// Create a PartialWitness
let mut pw_agg = PartialWitness::new();
// aggregate proofs
aggregate_sampling_proofs(&proofs_with_pi, &data.unwrap().verifier_data(), &mut builder, &mut pw_agg)?;
let data_agg = builder.build::<C>();
// Prove the circuit with the assigned witness
let start_time = Instant::now();
let proof_with_pis_agg = data_agg.prove(pw_agg)?;
println!("prove_time = {:?}", start_time.elapsed());
println!("num of public inputs = {}", proof_with_pis_agg.public_inputs.len());
// Verify the proof
let verifier_data = data_agg.verifier_data();
assert!(
verifier_data.verify(proof_with_pis_agg).is_ok(),
"Merkle proof verification failed"
);
Ok(())
}
// Test tree recursion
#[test]
fn test_tree_recursion() -> anyhow::Result<()> {
// number of samples in each proof
let n_samples = 10;
// number of inner proofs:
let n_inner = 4;
let mut data: Option<CircuitData<F, C, D>> = None;
// get proofs
let mut proofs_with_pi = vec![];
for i in 0..n_inner{
// build the circuit
let (data_i, pw) = build_circuit(n_samples, i)?;
proofs_with_pi.push(prove_circuit(&data_i, &pw)?);
data = Some(data_i);
}
let data = data.unwrap();
println!("inner circuit size = {:?}", data.common.degree_bits());
let gate_serializer = DefaultGateSerializer;
let generator_serializer =DefaultGeneratorSerializer::<C, D>::default();
let data_bytes = data.to_bytes(&gate_serializer, &generator_serializer).unwrap();
println!("inner proof circuit data size = {} bytes", data_bytes.len());
let file_path = "inner_circ_data.bin";
// Write data to the file
write_bytes_to_file(data_bytes, file_path).unwrap();
println!("Data written to {}", file_path);
let start_time = Instant::now();
let (proof, vd_agg) = aggregate_sampling_proofs_tree(&proofs_with_pi, data)?;
println!("prove_time = {:?}", start_time.elapsed());
println!("num of public inputs = {}", proof.public_inputs.len());
println!("agg pub input = {:?}", proof.public_inputs);
println!("outer circuit size = {:?}", vd_agg.common.degree_bits());
// let gate_serializer = DefaultGateSerializer;
// let generator_serializer =DefaultGeneratorSerializer::<C, D>::default();
let outer_data_bytes = vd_agg.to_bytes(&gate_serializer, &generator_serializer).unwrap();
println!("outer proof circuit data size = {} bytes", outer_data_bytes.len());
let file_path = "outer_circ_data.bin";
// Write data to the file
write_bytes_to_file(outer_data_bytes, file_path).unwrap();
println!("Data written to {}", file_path);
// Verify the proof
let verifier_data = vd_agg.verifier_data();
assert!(
verifier_data.verify(proof).is_ok(),
"Merkle proof verification failed"
);
Ok(())
}
// test the tree recursion
#[test]
pub fn test_tree_recursion2()-> anyhow::Result<()>{
// number of samples in each proof
let n_samples = 10;
// number of inner proofs:
let n_inner = 4;
let mut data: Option<CircuitData<F, C, D>> = None;
// get proofs
let mut proofs_with_pi = vec![];
for i in 0..n_inner{
// build the circuit
let (data_i, pw) = build_circuit(n_samples, i)?;
proofs_with_pi.push(prove_circuit(&data_i, &pw)?);
data = Some(data_i);
}
let data = data.unwrap();
let rt_params = RecursionTreeParams::new(n_inner);
let rec_circuit = SimpleRecursionCircuit::new(rt_params, data.verifier_data());
// Create the circuit
let config = CircuitConfig::standard_recursion_config();
let mut builder = CircuitBuilder::<F, D>::new(config);
// Create a PartialWitness
let mut pw = PartialWitness::new();
let targets = rec_circuit.build_circuit(&mut builder);
let start = Instant::now();
let agg_data = builder.build::<C>();
println!("build time = {:?}", start.elapsed());
println!("circuit size = {:?}", data.common.degree_bits());
let mut default_entropy = HashOut::ZERO;
default_entropy.elements[0] = F::from_canonical_u64(1234567);
let w = SimpleRecursionInput{
proofs: proofs_with_pi,
verifier_data: data.verifier_data(),
entropy: default_entropy,
};
rec_circuit.assign_witness(&mut pw,&targets,w)?;
let start = Instant::now();
let proof = agg_data.prove(pw)?;
println!("prove time = {:?}", start.elapsed());
// Verify the proof
let verifier_data = agg_data.verifier_data();
assert!(
verifier_data.verify(proof).is_ok(),
"Merkle proof verification failed"
);
Ok(())
}
}

View File

@ -384,7 +384,7 @@ mod tests {
let circuit_params = CircuitParams::default();
let circ = SampleCircuit::new(circuit_params.clone());
let mut targets = circ.sample_slot_circuit(&mut builder);
let mut targets = circ.sample_slot_circuit_with_public_input(&mut builder);
// Create a PartialWitness and assign
let mut pw = PartialWitness::new();
@ -393,7 +393,7 @@ mod tests {
let imported_circ_input: SampleCircuitInput<F, D> = import_circ_input_from_json("input.json")?;
println!("circuit input imported from input.json");
circ.sample_slot_assign_witness(&mut pw, &mut targets, imported_circ_input);
circ.sample_slot_assign_witness(&mut pw, &targets, &imported_circ_input);
// Build the circuit
let data = builder.build::<C>();
@ -445,14 +445,14 @@ mod tests {
let circuit_params = CircuitParams::default();
let circ = SampleCircuit::new(circuit_params.clone());
let mut targets = circ.sample_slot_circuit(&mut builder);
let mut targets = circ.sample_slot_circuit_with_public_input(&mut builder);
// Create a PartialWitness and assign
let mut pw = PartialWitness::new();
// gen circ input
let imported_circ_input: SampleCircuitInput<F, D> = gen_testing_circuit_input::<F,D>(&params);
circ.sample_slot_assign_witness(&mut pw, &mut targets, imported_circ_input);
circ.sample_slot_assign_witness(&mut pw, &targets, &imported_circ_input);
// Build the circuit
let data = builder.build::<C>();
@ -497,14 +497,14 @@ mod tests {
let circuit_params = CircuitParams::default();
let circ = SampleCircuit::new(circuit_params.clone());
let mut targets = circ.sample_slot_circuit(&mut builder);
let mut targets = circ.sample_slot_circuit_with_public_input(&mut builder);
// Create a PartialWitness and assign
let mut pw = PartialWitness::new();
// gen circ input
let imported_circ_input: SampleCircuitInput<F, D> = gen_testing_circuit_input::<F,D>(&params);
circ.sample_slot_assign_witness(&mut pw, &mut targets, imported_circ_input);
circ.sample_slot_assign_witness(&mut pw, &targets, &imported_circ_input);
// Build the circuit
let data = builder.build::<C>();

View File

@ -4,4 +4,7 @@ pub mod params;
pub mod utils;
pub mod json;
pub mod tests;
mod sponge;
mod sponge;
pub mod cyclic_recursion;
pub mod tree_recursion;
pub mod simple_recursion;

View File

@ -10,7 +10,7 @@ use plonky2_poseidon2::config::Poseidon2GoldilocksConfig;
// test types
pub const D: usize = 2;
pub type C = Poseidon2GoldilocksConfig;
pub type C = PoseidonGoldilocksConfig;
pub type F = <C as GenericConfig<D>>::F; // this is the goldilocks field
// pub type H = PoseidonHash;
// pub type HP = <PoseidonHash as plonky2::plonk::config::Hasher<F>>::Permutation;

View File

@ -0,0 +1,181 @@
// tests for simple recursion
use std::time::Instant;
use plonky2::hash::hash_types::HashOut;
use plonky2::iop::witness::PartialWitness;
use plonky2::plonk::circuit_builder::CircuitBuilder;
use plonky2::plonk::circuit_data::{CircuitConfig, CircuitData};
use plonky2_field::types::Field;
use codex_plonky2_circuits::recursion::params::RecursionTreeParams;
use codex_plonky2_circuits::recursion::simple_recursion::{aggregate_sampling_proofs, aggregate_sampling_proofs_tree};
use codex_plonky2_circuits::recursion::simple_recursion2::{SimpleRecursionCircuit, SimpleRecursionInput};
use plonky2_poseidon2::serialization::{DefaultGateSerializer, DefaultGeneratorSerializer};
use crate::gen_input::{build_circuit, prove_circuit};
use crate::json::write_bytes_to_file;
use crate::params::{C, F, D};
// Test recursion
#[test]
fn test_recursion() -> anyhow::Result<()> {
// number of samples in each proof
let n_samples = 10;
// number of inner proofs:
let n_inner = 4;
let mut data: Option<CircuitData<F, C, D>> = None;
// get proofs
let mut proofs_with_pi = vec![];
for i in 0..n_inner{
// build the circuit
let (data_i, pw) = build_circuit(n_samples, i)?;
// prove
proofs_with_pi.push(prove_circuit(&data_i, &pw)?);
data = Some(data_i);
}
println!("num of public inputs inner proof = {}", proofs_with_pi[0].public_inputs.len());
// Create the circuit
let config = CircuitConfig::standard_recursion_config();
let mut builder = CircuitBuilder::<F, D>::new(config);
// Create a PartialWitness
let mut pw_agg = PartialWitness::new();
// aggregate proofs
aggregate_sampling_proofs(&proofs_with_pi, &data.unwrap().verifier_data(), &mut builder, &mut pw_agg)?;
let data_agg = builder.build::<C>();
// Prove the circuit with the assigned witness
let start_time = Instant::now();
let proof_with_pis_agg = data_agg.prove(pw_agg)?;
println!("prove_time = {:?}", start_time.elapsed());
println!("num of public inputs = {}", proof_with_pis_agg.public_inputs.len());
// Verify the proof
let verifier_data = data_agg.verifier_data();
assert!(
verifier_data.verify(proof_with_pis_agg).is_ok(),
"Merkle proof verification failed"
);
Ok(())
}
// Test simple tree recursion
#[test]
fn test_tree_recursion() -> anyhow::Result<()> {
// number of samples in each proof
let n_samples = 10;
// number of inner proofs:
let n_inner = 4;
let mut data: Option<CircuitData<F, C, D>> = None;
// get proofs
let mut proofs_with_pi = vec![];
for i in 0..n_inner{
// build the circuit
let (data_i, pw) = build_circuit(n_samples, i)?;
proofs_with_pi.push(prove_circuit(&data_i, &pw)?);
data = Some(data_i);
}
let data = data.unwrap();
println!("inner circuit size = {:?}", data.common.degree_bits());
let gate_serializer = DefaultGateSerializer;
let generator_serializer =DefaultGeneratorSerializer::<C, D>::default();
let data_bytes = data.to_bytes(&gate_serializer, &generator_serializer).unwrap();
println!("inner proof circuit data size = {} bytes", data_bytes.len());
let file_path = "inner_circ_data.bin";
// Write data to the file
write_bytes_to_file(data_bytes, file_path).unwrap();
println!("Data written to {}", file_path);
let start_time = Instant::now();
let (proof, vd_agg) = aggregate_sampling_proofs_tree(&proofs_with_pi, data)?;
println!("prove_time = {:?}", start_time.elapsed());
println!("num of public inputs = {}", proof.public_inputs.len());
println!("agg pub input = {:?}", proof.public_inputs);
println!("outer circuit size = {:?}", vd_agg.common.degree_bits());
// let gate_serializer = DefaultGateSerializer;
// let generator_serializer =DefaultGeneratorSerializer::<C, D>::default();
let outer_data_bytes = vd_agg.to_bytes(&gate_serializer, &generator_serializer).unwrap();
println!("outer proof circuit data size = {} bytes", outer_data_bytes.len());
let file_path = "outer_circ_data.bin";
// Write data to the file
write_bytes_to_file(outer_data_bytes, file_path).unwrap();
println!("Data written to {}", file_path);
// Verify the proof
let verifier_data = vd_agg.verifier_data();
assert!(
verifier_data.verify(proof).is_ok(),
"Merkle proof verification failed"
);
Ok(())
}
// test another approach of the tree recursion
#[test]
pub fn test_tree_recursion2()-> anyhow::Result<()>{
// number of samples in each proof
let n_samples = 10;
// number of inner proofs:
let n_inner = 4;
let mut data: Option<CircuitData<F, C, D>> = None;
// get proofs
let mut proofs_with_pi = vec![];
for i in 0..n_inner{
// build the circuit
let (data_i, pw) = build_circuit(n_samples, i)?;
proofs_with_pi.push(prove_circuit(&data_i, &pw)?);
data = Some(data_i);
}
let data = data.unwrap();
let rt_params = RecursionTreeParams::new(n_inner);
let rec_circuit = SimpleRecursionCircuit::new(rt_params, data.verifier_data());
// Create the circuit
let config = CircuitConfig::standard_recursion_config();
let mut builder = CircuitBuilder::<F, D>::new(config);
// Create a PartialWitness
let mut pw = PartialWitness::new();
let targets = rec_circuit.build_circuit(&mut builder);
let start = Instant::now();
let agg_data = builder.build::<C>();
println!("build time = {:?}", start.elapsed());
println!("circuit size = {:?}", data.common.degree_bits());
let mut default_entropy = HashOut::ZERO;
default_entropy.elements[0] = F::from_canonical_u64(1234567);
let w = SimpleRecursionInput{
proofs: proofs_with_pi,
verifier_data: data.verifier_data(),
entropy: default_entropy,
};
rec_circuit.assign_witness(&mut pw,&targets,w)?;
let start = Instant::now();
let proof = agg_data.prove(pw)?;
println!("prove time = {:?}", start.elapsed());
// Verify the proof
let verifier_data = agg_data.verifier_data();
assert!(
verifier_data.verify(proof).is_ok(),
"Merkle proof verification failed"
);
Ok(())
}

View File

@ -0,0 +1,68 @@
// some tests for cyclic recursion
#[cfg(test)]
mod tests {
use std::time::Instant;
use anyhow::{anyhow, Result};
use plonky2::plonk::circuit_builder::CircuitBuilder;
use plonky2::plonk::circuit_data::{CircuitConfig, CircuitData, CommonCircuitData, VerifierCircuitData, VerifierCircuitTarget, VerifierOnlyCircuitData};
use plonky2::plonk::config::{AlgebraicHasher, GenericConfig, GenericHashOut, Hasher, PoseidonGoldilocksConfig};
use codex_plonky2_circuits::circuits::sample_cells::SampleCircuitInput;
use codex_plonky2_circuits::recursion::params::{F, D, C, Plonky2Proof};
use codex_plonky2_circuits::recursion::sampling_inner_circuit::SamplingRecursion;
use codex_plonky2_circuits::recursion::inner_circuit::InnerCircuit;
use plonky2_poseidon2::poseidon2_hash::poseidon2::{Poseidon2, Poseidon2Hash};
use crate::gen_input::gen_testing_circuit_input;
use crate::params::TestParams;
use codex_plonky2_circuits::recursion::tree_recursion::{NodeCircuit, TreeRecursion};
fn get_m_default_circ_input<const M: usize>() -> [SampleCircuitInput<F,D>; M]{
let mut params = TestParams::default();
params.n_samples = 10;
let one_circ_input = gen_testing_circuit_input::<F,D>(&params);
let circ_input: [SampleCircuitInput<F,D>; M] = (0..M)
.map(|_| one_circ_input.clone())
.collect::<Vec<_>>()
.try_into().unwrap();
circ_input
}
/// Uses node recursion to sample the dataset
#[test]
fn test_node_recursion() -> Result<()> {
// const D: usize = 2;
// type C = PoseidonGoldilocksConfig;
// type F = <C as GenericConfig<D>>::F;
const M: usize = 1;
const N: usize = 2;
let config = CircuitConfig::standard_recursion_config();
let mut builder = CircuitBuilder::<F, D>::new(config);
let one = builder.one();
// Circuit that does the sampling
let inner_sampling_circuit = SamplingRecursion::default();
let mut cyclic_circ = NodeCircuit::<_,M,N>::new(inner_sampling_circuit);
let mut tree_circ = TreeRecursion::new(cyclic_circ);
let circ_input = get_m_default_circ_input::<M>();
let s = Instant::now();
tree_circ.build()?;
println!("build = {:?}", s.elapsed());
let s = Instant::now();
let proof = tree_circ.prove(&circ_input,None, true)?;
println!("prove = {:?}", s.elapsed());
println!("num of pi = {}", proof.public_inputs.len());
println!("pub input: {:?}", proof.public_inputs);
let s = Instant::now();
assert!(
tree_circ.verify_proof(proof).is_ok(),
"proof verification failed"
);
println!("verify = {:?}", s.elapsed());
Ok(())
}
}