diff --git a/codex-plonky2-circuits/src/circuits/utils.rs b/codex-plonky2-circuits/src/circuits/utils.rs index 0fc910d..cd46e54 100644 --- a/codex-plonky2-circuits/src/circuits/utils.rs +++ b/codex-plonky2-circuits/src/circuits/utils.rs @@ -1,5 +1,6 @@ use std::{fs, io}; use std::path::Path; +use itertools::Itertools; use plonky2::hash::hash_types::{HashOut, HashOutTarget, NUM_HASH_OUT_ELTS, RichField}; use plonky2::iop::witness::{PartialWitness, WitnessWrite}; use plonky2_field::extension::Extendable; @@ -135,3 +136,14 @@ pub fn vec_to_array(vec: Vec) -> Result<[T; N]> { v.len() ))) } + +/// Computes `if b { v0 } else { v1 }`. +pub fn select_vec< + F: RichField + Extendable + Poseidon2, + const D: usize, +>(builder: &mut CircuitBuilder, b: BoolTarget, v0: &[Target], v1: &[Target]) -> Vec { + v0.iter() + .zip_eq(v1) + .map(|(t0, t1)| builder.select(b, *t0, *t1)) + .collect() +} diff --git a/codex-plonky2-circuits/src/recursion/tree2/node_circuit.rs b/codex-plonky2-circuits/src/recursion/tree2/node_circuit.rs index 36668cb..3148476 100644 --- a/codex-plonky2-circuits/src/recursion/tree2/node_circuit.rs +++ b/codex-plonky2-circuits/src/recursion/tree2/node_circuit.rs @@ -10,7 +10,7 @@ use plonky2::plonk::proof::{ProofWithPublicInputs, ProofWithPublicInputsTarget}; use plonky2_poseidon2::poseidon2_hash::poseidon2::Poseidon2; use crate::recursion::circuits::inner_circuit::InnerCircuit; use plonky2_field::extension::Extendable; -use crate::circuits::utils::{select_hash, vec_to_array}; +use crate::circuits::utils::{select_hash, select_vec, vec_to_array}; use crate::{error::CircuitError, Result}; use crate::recursion::tree2::leaf_circuit::LeafCircuit; @@ -78,6 +78,7 @@ impl< // circuit data for leaf let leaf_circ_data = leaf_circuit.get_circuit_data::()?; + // common data for leaf let leaf_common = leaf_circ_data.common.clone(); @@ -94,8 +95,6 @@ impl< let inner_cyclic_pis = &leaf_proofs[i].public_inputs; leaf_pub_input_hashes.extend_from_slice(&inner_cyclic_pis[0..4]); } - // hash the public input so H(H_0, ..., H_N) - let leaf_pub_input_hash = builder.hash_n_to_hash_no_pad::(leaf_pub_input_hashes); // leaf verifier data // TODO: double check that it is ok for this verifier data to be private/witness @@ -140,11 +139,11 @@ impl< let inner_cyclic_pis = &inner_cyclic_proof_with_pis[i].public_inputs; inner_pub_input_hashes.extend_from_slice(&inner_cyclic_pis[0..4]); } - // hash all the node public input h = H(h_1 | h_2 | ... | h_N) - // TODO: optimize by removing the need for 2 hashes and instead select then hash - let inner_pub_input_hash = builder.hash_n_to_hash_no_pad::(inner_pub_input_hashes); - let node_hash_or_leaf_hash = select_hash(&mut builder, condition, leaf_pub_input_hash, inner_pub_input_hash); + // select the public input - either leaf or node + let pub_input_to_be_hashed = select_vec(&mut builder, condition, &leaf_pub_input_hashes ,&inner_pub_input_hashes); + // hash all the node public input h = H(h_1 | h_2 | ... | h_N) + let node_hash_or_leaf_hash= builder.hash_n_to_hash_no_pad::(pub_input_to_be_hashed); builder.connect_hashes(pub_input_hash,node_hash_or_leaf_hash);