use std::marker::PhantomData; use plonky2::hash::hash_types::{HashOut, RichField}; use plonky2::iop::witness::PartialWitness; use plonky2::plonk::circuit_builder::CircuitBuilder; use plonky2::plonk::circuit_data::{CircuitConfig, CircuitData, CommonCircuitData, VerifierCircuitData, VerifierOnlyCircuitData}; use plonky2::plonk::config::{AlgebraicHasher, GenericConfig}; use plonky2::plonk::proof::ProofWithPublicInputs; use plonky2_poseidon2::poseidon2_hash::poseidon2::Poseidon2; use crate::recursion::circuits::inner_circuit::InnerCircuit; use plonky2_field::extension::Extendable; use crate::{error::CircuitError, Result}; use crate::recursion::uniform::{leaf::{LeafTargets,LeafCircuit},node::{NodeTargets,NodeCircuit}}; /// tree recursion pub struct TreeRecursion< F: RichField + Extendable + Poseidon2, const D: usize, C: GenericConfig, H: AlgebraicHasher, const N: usize, const M: usize, > where >::Hasher: AlgebraicHasher { leaf: LeafCircuit, node: NodeCircuit, leaf_circ_data: CircuitData, node_circ_data: CircuitData, leaf_targets: LeafTargets, node_targets: NodeTargets, phantom_data: PhantomData<(H)> } impl< F: RichField + Extendable + Poseidon2, const D: usize, C: GenericConfig, H: AlgebraicHasher, const N: usize, const M: usize, > TreeRecursion where >::Hasher: AlgebraicHasher { pub fn build( inner_common_data: CommonCircuitData ) -> Result { // build leaf with standard recursion config let config = CircuitConfig::standard_recursion_config(); let mut builder = CircuitBuilder::::new(config); let leaf = LeafCircuit::<_,D,_,_,N>::new(inner_common_data.clone()); let leaf_targets = leaf.build(&mut builder)?; let leaf_circ_data = builder.build::(); // println!("leaf circuit size = {:?}", leaf_circ_data.common.degree_bits()); // build node with standard recursion config let config = CircuitConfig::standard_recursion_config(); let mut builder = CircuitBuilder::::new(config); let node = NodeCircuit::<_,D,_,_,M>::new(leaf_circ_data.common.clone()); let node_targets = node.build(&mut builder)?; let node_circ_data = builder.build::(); // println!("node circuit size = {:?}", node_circ_data.common.degree_bits()); Ok(Self{ leaf, node, leaf_circ_data, node_circ_data, leaf_targets, node_targets, phantom_data: Default::default(), }) } pub fn get_leaf_verifier_data(&self) -> VerifierCircuitData{ self.leaf_circ_data.verifier_data() } pub fn get_node_verifier_data(&self) -> VerifierCircuitData{ self.node_circ_data.verifier_data() } pub fn prove_tree ( &mut self, proofs_with_pi: &[ProofWithPublicInputs], inner_verifier_only_data: &VerifierOnlyCircuitData, ) -> Result<(ProofWithPublicInputs)> { if proofs_with_pi.len() % 2 != 0 { return Err(CircuitError::RecursionTreeError(format!( "input proofs must be divisible by {}, got {}", 2, proofs_with_pi.len()) )) } // process leaves let leaf_proofs = self.get_leaf_proofs( proofs_with_pi, inner_verifier_only_data, )?; // process nodes let (root_proof, vd) = self.prove(&leaf_proofs,&self.leaf_circ_data.verifier_only)?; Ok(root_proof) } fn get_leaf_proofs ( &mut self, proofs_with_pi: &[ProofWithPublicInputs], inner_verifier_only_data: &VerifierOnlyCircuitData, ) -> Result<(Vec>)> { let mut leaf_proofs = vec![]; for proof in proofs_with_pi.chunks(N){ let mut pw = PartialWitness::::new(); self.leaf.assign_targets(&mut pw,&self.leaf_targets,proof,inner_verifier_only_data)?; let proof = self.leaf_circ_data.prove(pw).unwrap(); leaf_proofs.push(proof); } Ok(leaf_proofs) } /// generates a proof fn prove( &self, proofs_with_pi: &[ProofWithPublicInputs], verifier_only_data: &VerifierOnlyCircuitData, ) -> Result<(ProofWithPublicInputs, VerifierOnlyCircuitData)> where >::Hasher: AlgebraicHasher { if proofs_with_pi.len() == 1 { return Ok((proofs_with_pi[0].clone(), verifier_only_data.clone())); } let mut new_proofs = vec![]; for chunk in proofs_with_pi.chunks(M) { let mut inner_pw = PartialWitness::new(); self.node.assign_targets( &mut inner_pw, &self.node_targets, chunk, verifier_only_data, )?; let proof = self.node_circ_data.prove(inner_pw) .map_err(|e| CircuitError::ProofGenerationError(e.to_string()))?; new_proofs.push(proof); } self.prove(&new_proofs, &self.node_circ_data.verifier_only) } pub fn verify_proof_and_public_input( &self, proof: ProofWithPublicInputs, inner_public_input: Vec>, inner_verifier_data: &VerifierCircuitData) -> Result<()> { let public_input = proof.public_inputs.clone(); self.node_circ_data.verify(proof) .map_err(|e| CircuitError::InvalidProofError(e.to_string()))?; self.verify_public_input(public_input, inner_public_input, inner_verifier_data) } pub fn verify_public_input(&self, public_input: Vec, inner_public_input: Vec>, inner_verifier_data: &VerifierCircuitData) -> Result<()>{ assert_eq!(public_input.len(), 8); let given_input_hash = &public_input[0..4]; let given_vd_hash = &public_input[4..8]; let inner_hash = Self::get_hash_of_verifier_data(&inner_verifier_data); let leaf_hash = Self::get_hash_of_verifier_data(&self.leaf_circ_data.verifier_data()); let node_hash = Self::get_hash_of_verifier_data(&self.node_circ_data.verifier_data()); let mut pub_in_hashes = vec![]; let mut inner_vd_hashes = vec![]; for pub_in in inner_public_input.chunks(N){ let pub_in_flat: Vec = pub_in .iter() .flat_map(|v| v.iter().cloned()) .collect(); let hash = H::hash_no_pad(&pub_in_flat); pub_in_hashes.push(hash); inner_vd_hashes.push(inner_hash.clone()); } let mut level = 0; while pub_in_hashes.len() > 1 { let mut next_level_pi_hashes = Vec::new(); let mut next_level_vd_hashes = Vec::new(); for (pi_chunk, vd_chunk) in pub_in_hashes.chunks(M).zip(inner_vd_hashes.chunks(M)) { // collect field elements let pi_chunk_f: Vec = pi_chunk.iter() .flat_map(|h| h.elements.iter().cloned()) .collect(); let mut vd_chunk_f: Vec = vd_chunk.iter() .flat_map(|h| h.elements.iter().cloned()) .collect(); let hash_n = if level == 0 {leaf_hash} else{node_hash}; vd_chunk_f.extend_from_slice(&hash_n.elements); // Compute Poseidon2 hash of the concatenated chunk let pi_hash = H::hash_no_pad(&pi_chunk_f); let vd_hash = H::hash_no_pad(&vd_chunk_f); next_level_pi_hashes.push(pi_hash); next_level_vd_hashes.push(vd_hash); } pub_in_hashes = next_level_pi_hashes; inner_vd_hashes = next_level_vd_hashes; level +=1; } //check expected hash let expected_pi_hash = pub_in_hashes[0]; let expected_vd_hash = inner_vd_hashes[0]; assert_eq!(given_input_hash, expected_pi_hash.elements); assert_eq!(given_vd_hash, expected_vd_hash.elements); Ok(()) } /// helper fn to generate hash of verifier data fn get_hash_of_verifier_data(verifier_data: &VerifierCircuitData) -> HashOut{ let mut vd = vec![]; let digest: &HashOut = &verifier_data.verifier_only.circuit_digest; let caps = &verifier_data.verifier_only.constants_sigmas_cap; vd.extend_from_slice(&digest.elements); for i in 0..verifier_data.common.config.fri_config.num_cap_elements() { let cap_hash = caps.0[i] as HashOut; vd.extend_from_slice(&cap_hash.elements); } H::hash_no_pad(&vd) } }