From 563ba77a8c2029635da0e0ffe226d0e775180fcc Mon Sep 17 00:00:00 2001 From: M Alghazwi Date: Thu, 9 Jan 2025 10:36:38 +0100 Subject: [PATCH] add prove_tree fn --- .../src/recursion/tree_recursion.rs | 56 ++++++++++++++++--- 1 file changed, 48 insertions(+), 8 deletions(-) diff --git a/codex-plonky2-circuits/src/recursion/tree_recursion.rs b/codex-plonky2-circuits/src/recursion/tree_recursion.rs index 56f0cfc..33fcde5 100644 --- a/codex-plonky2-circuits/src/recursion/tree_recursion.rs +++ b/codex-plonky2-circuits/src/recursion/tree_recursion.rs @@ -1,4 +1,4 @@ - +use std::array::from_fn; use hashbrown::HashMap; use plonky2::hash::hash_types::{HashOut, HashOutTarget, RichField}; use plonky2::iop::target::{BoolTarget, Target}; @@ -10,7 +10,7 @@ use plonky2::plonk::proof::{ProofWithPublicInputs, ProofWithPublicInputsTarget}; use plonky2::recursion::dummy_circuit::cyclic_base_proof; use plonky2_field::extension::Extendable; use plonky2_poseidon2::poseidon2_hash::poseidon2::Poseidon2; -use crate::recursion::params::RecursionTreeParams; +// use crate::recursion::params::RecursionTreeParams; use crate::recursion::params::{F,D,C,Plonky2Proof,H}; use crate::recursion::inner_circuit::InnerCircuit; use anyhow::{anyhow, Result}; @@ -28,7 +28,7 @@ pub struct TreeRecursion< const M: usize, const N: usize, >{ - node_circ: NodeCircuit + pub node_circ: NodeCircuit } impl< @@ -87,11 +87,51 @@ impl< /// prove n in a tree structure recursively /// the function takes /// - circ_input: vector of circuit inputs - pub fn prove_n_nodes( + pub fn prove_tree( &mut self, circ_input: Vec, + depth: usize, ) -> Result>{ - todo!() + // Total input size check + let total_input = (N.pow(depth as u32) - 1) / (N - 1); + assert_eq!(circ_input.len(), total_input, "Invalid input size for tree depth"); + + let mut cur_proofs: Vec> = vec![]; + + // Iterate from leaf layer to root + for layer in (0..depth).rev() { + let layer_num_nodes = N.pow(layer as u32); // Number of nodes at this layer + let mut next_proofs = Vec::new(); + + for node_idx in 0..layer_num_nodes { + // Get the inputs for the current node + let node_inputs: [I::Input; M] = from_fn(|i| { + circ_input + .get(node_idx * M + i) + .cloned() + .unwrap_or_else(|| panic!("Index out of bounds at node {node_idx}, input {i}")) + }); + + let proof = if layer == depth - 1 { + // Leaf layer: no child proofs + self.prove(&node_inputs, None, true)? + } else { + // Non-leaf layer: collect child proofs + let proofs_array: [ProofWithPublicInputs; N] = cur_proofs + .drain(..N) + .collect::>() + .try_into() + .map_err(|_| anyhow!("Incorrect number of proofs for node"))?; + self.prove(&node_inputs, Some(proofs_array), false)? + }; + next_proofs.push(proof); + } + cur_proofs = next_proofs; + } + + // Final root proof + assert_eq!(cur_proofs.len(), 1, "Final proof count incorrect"); + Ok(cur_proofs.remove(0)) } /// verifies the proof generated @@ -161,9 +201,9 @@ impl< &mut self, ) -> Result<()>{ // if the circuit data is already build then no need to rebuild - if self.cyclic_circuit_data.is_some(){ - return Ok(()); - } + // if self.cyclic_circuit_data.is_some(){ + // return Ok(()); + // } // builder with standard recursion config let config = CircuitConfig::standard_recursion_config();