From 26a0a6b6753c6ea7a6d2910a8e893f2e860dab49 Mon Sep 17 00:00:00 2001 From: M Alghazwi Date: Mon, 23 Jun 2025 15:43:26 +0200 Subject: [PATCH] improve circuit input generation and refactor --- proof-input/src/data_structs.rs | 69 ++-- proof-input/src/gen_input.rs | 345 +++++++++--------- proof-input/src/merkle_tree/key_compress.rs | 1 - proof-input/src/merkle_tree/merkle_circuit.rs | 12 +- proof-input/src/merkle_tree/merkle_safe.rs | 109 +++--- proof-input/src/merkle_tree/test.rs | 19 +- proof-input/src/recursion/mod.rs | 5 +- proof-input/src/recursion/node_test.rs | 8 +- proof-input/src/recursion/tree_test.rs | 3 +- proof-input/src/recursion/wrap_test.rs | 2 +- .../src/serialization/circuit_input.rs | 41 +-- proof-input/src/sponge.rs | 11 +- proof-input/src/utils.rs | 4 + workflow/src/gen_input.rs | 7 +- 14 files changed, 300 insertions(+), 336 deletions(-) diff --git a/proof-input/src/data_structs.rs b/proof-input/src/data_structs.rs index 23552ab..a0ea9de 100755 --- a/proof-input/src/data_structs.rs +++ b/proof-input/src/data_structs.rs @@ -1,6 +1,7 @@ // Data structure used to generate the proof input use plonky2::hash::hash_types::{HashOut, RichField}; +use plonky2::plonk::config::AlgebraicHasher; use plonky2_field::extension::Extendable; use codex_plonky2_circuits::circuits::sample_cells::Cell; use plonky2_poseidon2::poseidon2_hash::poseidon2::Poseidon2; @@ -14,9 +15,10 @@ use crate::utils::{bits_le_padded_to_usize, calculate_cell_index_bits, usize_to_ pub struct SlotTree< F: RichField + Extendable + Poseidon2, const D: usize, + H: AlgebraicHasher, > { - pub tree: MerkleTree, // slot tree - pub block_trees: Vec>, // vec of block trees + pub tree: MerkleTree, // slot tree + pub block_trees: Vec>, // vec of block trees pub cell_data: Vec>, // cell data as field elements pub params: InputParams, // parameters } @@ -24,7 +26,8 @@ pub struct SlotTree< impl< F: RichField + Extendable + Poseidon2, const D: usize, -> SlotTree { + H: AlgebraicHasher, +> SlotTree { /// Create a slot tree with fake data, for testing only pub fn new_default(params: &InputParams) -> Self { // generate fake cell data @@ -40,9 +43,7 @@ impl< .iter() .map(|element| hash_bytes_no_padding::(&element.data)) .collect(); - let zero = HashOut { - elements: [F::ZERO; 4], - }; + let n_blocks = params.n_blocks_test(); let n_cells_in_blocks = params.n_cells_in_blocks(); @@ -57,7 +58,7 @@ impl< .iter() .map(|t| t.root().unwrap()) .collect::>(); - let slot_tree = MerkleTree::::new(&block_roots, zero).unwrap(); + let slot_tree = MerkleTree::::new(&block_roots).unwrap(); Self { tree: slot_tree, block_trees, @@ -68,7 +69,7 @@ impl< /// Generates a proof for the given leaf index /// The path in the proof is a combined block and slot path to make up the full path - pub fn get_proof(&self, index: usize) -> MerkleProof { + pub fn get_proof(&self, index: usize) -> MerkleProof { let block_index = index / self.params.n_cells_in_blocks(); let leaf_index = index % self.params.n_cells_in_blocks(); let block_proof = self.block_trees[block_index].get_proof(leaf_index).unwrap(); @@ -78,20 +79,16 @@ impl< let mut combined_path = block_proof.path.clone(); combined_path.extend(slot_proof.path.clone()); - MerkleProof:: { + MerkleProof::::new( index, - path: combined_path, - nleaves: self.cell_data.len(), - zero: block_proof.zero.clone(), - } + combined_path, + self.cell_data.len(), + ) } - fn get_block_tree(leaves: &Vec>) -> MerkleTree { - let zero = HashOut { - elements: [F::ZERO; 4], - }; + fn get_block_tree(leaves: &Vec>) -> MerkleTree { // Build the Merkle tree - let block_tree = MerkleTree::::new(leaves, zero).unwrap(); + let block_tree = MerkleTree::::new(leaves).unwrap(); block_tree } } @@ -102,9 +99,10 @@ impl< pub struct DatasetTree< F: RichField + Extendable + Poseidon2, const D: usize, + H: AlgebraicHasher, > { - pub tree: MerkleTree, // dataset tree - pub slot_trees: Vec>, // vec of slot trees + pub tree: MerkleTree, // dataset tree + pub slot_trees: Vec>, // vec of slot trees pub params: InputParams, // parameters } @@ -113,24 +111,26 @@ pub struct DatasetTree< pub struct DatasetProof< F: RichField + Extendable + Poseidon2, const D: usize, + H: AlgebraicHasher, > { pub slot_index: F, pub entropy: HashOut, - pub dataset_proof: MerkleProof, // proof for dataset level tree - pub slot_proofs: Vec>, // proofs for sampled slot + pub dataset_proof: MerkleProof, // proof for dataset level tree + pub slot_proofs: Vec>, // proofs for sampled slot pub cell_data: Vec>, } impl< F: RichField + Extendable + Poseidon2, const D: usize, -> DatasetTree { + H: AlgebraicHasher, +> DatasetTree { /// Dataset tree with fake data, for testing only pub fn new_default(params: &InputParams) -> Self { let mut slot_trees = vec![]; let n_slots = 1 << params.dataset_depth_test(); for _ in 0..n_slots { - slot_trees.push(SlotTree::::new_default(params)); + slot_trees.push(SlotTree::::new_default(params)); } Self::new(slot_trees, params.clone()) } @@ -144,15 +144,15 @@ impl< let zero = HashOut { elements: [F::ZERO; 4], }; - let zero_slot = SlotTree:: { - tree: MerkleTree::::new(&[zero.clone()], zero.clone()).unwrap(), + let zero_slot = SlotTree:: { + tree: MerkleTree::::new(&[zero.clone()]).unwrap(), block_trees: vec![], cell_data: vec![], params: params.clone(), }; for i in 0..n_slots { if i == params.testing_slot_index { - slot_trees.push(SlotTree::::new_default(params)); + slot_trees.push(SlotTree::::new_default(params)); } else { slot_trees.push(zero_slot.clone()); } @@ -162,7 +162,7 @@ impl< .iter() .map(|t| t.tree.root().unwrap()) .collect::>(); - let dataset_tree = MerkleTree::::new(&slot_roots, zero).unwrap(); + let dataset_tree = MerkleTree::::new(&slot_roots).unwrap(); Self { tree: dataset_tree, slot_trees, @@ -171,17 +171,14 @@ impl< } /// Same as default but with supplied slot trees - pub fn new(slot_trees: Vec>, params: InputParams) -> Self { + pub fn new(slot_trees: Vec>, params: InputParams) -> Self { // get the roots of slot trees let slot_roots = slot_trees .iter() .map(|t| t.tree.root().unwrap()) .collect::>(); - // zero hash - let zero = HashOut { - elements: [F::ZERO; 4], - }; - let dataset_tree = MerkleTree::::new(&slot_roots, zero).unwrap(); + + let dataset_tree = MerkleTree::::new(&slot_roots).unwrap(); Self { tree: dataset_tree, slot_trees, @@ -192,7 +189,7 @@ impl< /// Generates a proof for the given slot index /// Also takes entropy so it can use it to sample the slot /// note: proofs are padded based on the params in self - pub fn sample_slot(&self, index: usize, entropy: usize) -> DatasetProof { + pub fn sample_slot(&self, index: usize, entropy: usize) -> DatasetProof { let mut dataset_proof = self.tree.get_proof(index).unwrap(); Self::pad_proof(&mut dataset_proof, self.params.dataset_max_depth()); @@ -234,7 +231,7 @@ impl< } } /// pad the proof with 0s until max_depth - pub fn pad_proof(merkle_proof: &mut MerkleProof, max_depth: usize){ + pub fn pad_proof(merkle_proof: &mut MerkleProof, max_depth: usize){ for _i in merkle_proof.path.len()..max_depth{ merkle_proof.path.push(HashOut::::ZERO); } diff --git a/proof-input/src/gen_input.rs b/proof-input/src/gen_input.rs index 0a16248..706dd83 100755 --- a/proof-input/src/gen_input.rs +++ b/proof-input/src/gen_input.rs @@ -1,230 +1,224 @@ +use std::marker::PhantomData; +use std::path::Path; use plonky2::hash::hash_types::RichField; use plonky2_field::extension::Extendable; use plonky2_poseidon2::poseidon2_hash::poseidon2::Poseidon2; use crate::params::{Params,InputParams}; use crate::utils::{bits_le_padded_to_usize, calculate_cell_index_bits, ceiling_log2, usize_to_bits_le}; use crate::merkle_tree::merkle_safe::MerkleProof; -use codex_plonky2_circuits::circuits::sample_cells::{MerklePath, SampleCircuit, SampleCircuitInput, SampleTargets}; -use plonky2::iop::witness::PartialWitness; -use plonky2::plonk::circuit_builder::CircuitBuilder; -use plonky2::plonk::circuit_data::{CircuitConfig, CircuitData}; -use plonky2::plonk::proof::ProofWithPublicInputs; +use codex_plonky2_circuits::circuits::sample_cells::{MerklePath, SampleCircuitInput}; +use plonky2::plonk::config::AlgebraicHasher; use crate::data_structs::DatasetTree; +use crate::serialization::circuit_input::export_circ_input_to_json; use crate::sponge::hash_bytes_no_padding; -use crate::params::{C, D, F, HF}; -/// generates circuit input (SampleCircuitInput) from fake data for testing +/// Input Generator to generates circuit input (SampleCircuitInput) /// which can be later stored into json see json.rs -pub fn gen_testing_circuit_input< +/// For now it generates input from fake data for testing +pub struct InputGenerator< F: RichField + Extendable + Poseidon2, const D: usize, ->(params: &InputParams) -> SampleCircuitInput{ - let dataset_t = DatasetTree::::new_for_testing(¶ms); - - let slot_index = params.testing_slot_index; // samples the specified slot - let entropy = params.entropy; // Use the entropy from Params - - let proof = dataset_t.sample_slot(slot_index, entropy); - let slot_root = dataset_t.slot_trees[slot_index].tree.root().unwrap(); - - let mut slot_paths = vec![]; - for i in 0..params.n_samples { - let path = proof.slot_proofs[i].path.clone(); - let mp = MerklePath::{ - path, - }; - slot_paths.push(mp); - } - - SampleCircuitInput:: { - entropy: proof.entropy, - dataset_root: dataset_t.tree.root().unwrap(), - slot_index: proof.slot_index.clone(), - slot_root, - n_cells_per_slot: F::from_canonical_usize(params.n_cells), - n_slots_per_dataset: F::from_canonical_usize(params.n_slots), - slot_proof: proof.dataset_proof.path.clone(), - cell_data: proof.cell_data.clone(), - merkle_paths: slot_paths, - } + H: AlgebraicHasher, +>{ + pub input_params: InputParams, + phantom_data: PhantomData<(F,H)> } -/// verifies the given circuit input. -/// this is non circuit version for sanity check -pub fn verify_circuit_input< +impl< F: RichField + Extendable + Poseidon2, const D: usize, ->(circ_input: SampleCircuitInput, params: &InputParams) -> bool{ - let slot_index = circ_input.slot_index.to_canonical_u64(); - let slot_root = circ_input.slot_root.clone(); - // check dataset level proof - let slot_proof = circ_input.slot_proof.clone(); - let dataset_path_bits = usize_to_bits_le(slot_index as usize, params.dataset_max_depth()); - let (dataset_last_bits, dataset_mask_bits) = ceiling_log2(params.n_slots, params.dataset_max_depth()); - let reconstructed_slot_root = MerkleProof::::reconstruct_root2( - slot_root, - dataset_path_bits, - dataset_last_bits, - slot_proof, - dataset_mask_bits, - params.max_slots.trailing_zeros() as usize, - ).unwrap(); - // assert reconstructed equals dataset root - assert_eq!(reconstructed_slot_root, circ_input.dataset_root.clone()); + H: AlgebraicHasher, +> InputGenerator { - // check each sampled cell - // get the index for cell from H(slot_root|counter|entropy) - let mask_bits = usize_to_bits_le(params.n_cells -1, params.max_depth); - for i in 0..params.n_samples { - let cell_index_bits = calculate_cell_index_bits( - &circ_input.entropy.elements.to_vec(), - slot_root, - i + 1, - params.max_depth, - mask_bits.clone(), - ); - - let cell_index = bits_le_padded_to_usize(&cell_index_bits); - - let s_res = verify_cell_proof(&circ_input, ¶ms, cell_index, i); - if s_res.unwrap() == false { - println!("call {} is false", i); - return false; + pub fn new(input_params: InputParams) -> Self{ + Self{ + input_params, + phantom_data: PhantomData::default(), } } - true -} -/// Verify the given proof for slot tree, checks equality with the given root -pub fn verify_cell_proof< - F: RichField + Extendable + Poseidon2, - const D: usize, ->(circ_input: &SampleCircuitInput, params: &InputParams, cell_index: usize, ctr: usize) -> anyhow::Result { - let mut block_path_bits = usize_to_bits_le(cell_index, params.max_depth); - let last_index = params.n_cells - 1; - let mut block_last_bits = usize_to_bits_le(last_index, params.max_depth); + pub fn default() -> Self{ + Self{ + input_params: Params::default().input_params, + phantom_data: PhantomData::default(), + } + } - let split_point = params.bot_depth(); + /// Generate circuit input and export to JSON + pub fn generate_and_export_circ_input_to_json< + P: AsRef, + >( + &self, + base_path: P, + ) -> anyhow::Result<()> { + let circ_input = self.gen_testing_circuit_input(); + export_circ_input_to_json(circ_input, base_path)?; - let slot_last_bits = block_last_bits.split_off(split_point); - let slot_path_bits = block_path_bits.split_off(split_point); + Ok(()) + } - // pub type HP = >::Permutation; - let leaf_hash = hash_bytes_no_padding::(&circ_input.cell_data[ctr].data); + /// returns exactly M default circuit input of all same circuit input + pub fn get_m_testing_circ_input(&self) -> [SampleCircuitInput; M]{ + let one_circ_input = self.gen_testing_circuit_input(); + let circ_input: [SampleCircuitInput; M] = (0..M) + .map(|_| one_circ_input.clone()) + .collect::>() + .try_into().unwrap(); + circ_input + } - let mut block_path = circ_input.merkle_paths[ctr].path.clone(); - let slot_path = block_path.split_off(split_point); + /// returns exactly M default circuit input of different circuit input + pub fn get_m_unique_testing_circ_input(&self) -> [SampleCircuitInput; M]{ + todo!() + } - let mut block_mask_bits = usize_to_bits_le(last_index, params.max_depth+1); - let mut slot_mask_bits = block_mask_bits.split_off(split_point); + /// generates circuit input (SampleCircuitInput) from fake data for testing + pub fn gen_testing_circuit_input(&self) -> SampleCircuitInput{ + let params = &self.input_params; + let dataset_t = DatasetTree::::new_for_testing(params); - block_mask_bits.push(false); - slot_mask_bits.push(false); + let slot_index = params.testing_slot_index; // samples the specified slot + let entropy = params.entropy; // Use the entropy from Params - let block_res = MerkleProof::::reconstruct_root2( - leaf_hash, - block_path_bits.clone(), - block_last_bits.clone(), - block_path, - block_mask_bits, - params.bot_depth(), - ); - let reconstructed_root = MerkleProof::::reconstruct_root2( - block_res.unwrap(), - slot_path_bits, - slot_last_bits, - slot_path, - slot_mask_bits, - params.max_depth - params.bot_depth(), - ); + let proof = dataset_t.sample_slot(slot_index, entropy); + let slot_root = dataset_t.slot_trees[slot_index].tree.root().unwrap(); - Ok(reconstructed_root.unwrap() == circ_input.slot_root) -} + let mut slot_paths = vec![]; + for i in 0..params.n_samples { + let path = proof.slot_proofs[i].path.clone(); + let mp = MerklePath::{ + path, + }; + slot_paths.push(mp); + } -/// build the sampling circuit -/// returns the proof and circuit data -pub fn build_circuit(n_samples: usize, slot_index: usize) -> anyhow::Result<(CircuitData, PartialWitness)>{ - let (data, pw, _) = build_circuit_with_targets(n_samples, slot_index).unwrap(); + SampleCircuitInput:: { + entropy: proof.entropy, + dataset_root: dataset_t.tree.root().unwrap(), + slot_index: proof.slot_index.clone(), + slot_root, + n_cells_per_slot: F::from_canonical_usize(params.n_cells), + n_slots_per_dataset: F::from_canonical_usize(params.n_slots), + slot_proof: proof.dataset_proof.path.clone(), + cell_data: proof.cell_data.clone(), + merkle_paths: slot_paths, + } + } - Ok((data, pw)) -} + /// verifies the given circuit input. + /// this is non circuit version for sanity check + pub fn verify_circuit_input< + >(&self, circ_input: SampleCircuitInput) -> bool{ + let params = self.input_params.clone(); + let slot_index = circ_input.slot_index.to_canonical_u64(); + let slot_root = circ_input.slot_root.clone(); + // check dataset level proof + let slot_proof = circ_input.slot_proof.clone(); + let dataset_path_bits = usize_to_bits_le(slot_index as usize, params.dataset_max_depth()); + let (dataset_last_bits, dataset_mask_bits) = ceiling_log2(params.n_slots, params.dataset_max_depth()); + let reconstructed_slot_root = MerkleProof::::reconstruct_root2( + slot_root, + dataset_path_bits, + dataset_last_bits, + slot_proof, + dataset_mask_bits, + params.max_slots.trailing_zeros() as usize, + ).unwrap(); + // assert reconstructed equals dataset root + assert_eq!(reconstructed_slot_root, circ_input.dataset_root.clone()); -/// build the sampling circuit , -/// returns the proof, circuit data, and targets -pub fn build_circuit_with_targets(n_samples: usize, slot_index: usize) -> anyhow::Result<(CircuitData, PartialWitness, SampleTargets)>{ - // get input - let mut params = Params::default(); - params.set_n_samples(n_samples); - let mut input_params = params.input_params; - input_params.testing_slot_index = slot_index; - let circ_input = gen_testing_circuit_input::(&input_params); + // check each sampled cell + // get the index for cell from H(slot_root|counter|entropy) + let mask_bits = usize_to_bits_le(params.n_cells -1, params.max_depth); + for i in 0..params.n_samples { + let cell_index_bits = calculate_cell_index_bits( + &circ_input.entropy.elements.to_vec(), + slot_root, + i + 1, + params.max_depth, + mask_bits.clone(), + ); - // Create the circuit - let config = CircuitConfig::standard_recursion_config(); - let mut builder = CircuitBuilder::::new(config); + let cell_index = bits_le_padded_to_usize(&cell_index_bits); - let circuit_params = params.circuit_params; + let s_res = self.verify_cell_proof(&circ_input, cell_index, i); + if s_res.unwrap() == false { + println!("call {} is false", i); + return false; + } + } + true + } - // build the circuit - let circ = SampleCircuit::::new(circuit_params.clone()); - let targets = circ.sample_slot_circuit_with_public_input(&mut builder)?; + /// Verify the given proof for slot tree, checks equality with the given root + fn verify_cell_proof< + >(&self, circ_input: &SampleCircuitInput, cell_index: usize, ctr: usize) -> anyhow::Result { + let params = self.input_params.clone(); + let mut block_path_bits = usize_to_bits_le(cell_index, params.max_depth); + let last_index = params.n_cells - 1; + let mut block_last_bits = usize_to_bits_le(last_index, params.max_depth); - // Create a PartialWitness and assign - let mut pw = PartialWitness::new(); + let split_point = params.bot_depth(); - // assign a witness - circ.sample_slot_assign_witness(&mut pw, &targets, &circ_input)?; + let slot_last_bits = block_last_bits.split_off(split_point); + let slot_path_bits = block_path_bits.split_off(split_point); - // Build the circuit - let data = builder.build::(); + // pub type HP = >::Permutation; + let leaf_hash = hash_bytes_no_padding::(&circ_input.cell_data[ctr].data); - Ok((data, pw, targets)) -} + let mut block_path = circ_input.merkle_paths[ctr].path.clone(); + let slot_path = block_path.split_off(split_point); -/// prove the circuit -pub fn prove_circuit(data: &CircuitData, pw: &PartialWitness) -> anyhow::Result>{ - // Prove the circuit with the assigned witness - let proof_with_pis = data.prove(pw.clone())?; + let mut block_mask_bits = usize_to_bits_le(last_index, params.max_depth+1); + let mut slot_mask_bits = block_mask_bits.split_off(split_point); - Ok(proof_with_pis) -} + block_mask_bits.push(false); + slot_mask_bits.push(false); -/// returns exactly M default circuit input -pub fn get_m_default_circ_input() -> [SampleCircuitInput; M]{ - let params = Params::default().input_params; - // let one_circ_input = gen_testing_circuit_input::(¶ms); - // let circ_input: [SampleCircuitInput; M] = (0..M) - // .map(|_| one_circ_input.clone()) - // .collect::>() - // .try_into().unwrap(); - // circ_input - get_m_circ_input::(params) -} + let block_res = MerkleProof::::reconstruct_root2( + leaf_hash, + block_path_bits.clone(), + block_last_bits.clone(), + block_path, + block_mask_bits, + params.bot_depth(), + ); + let reconstructed_root = MerkleProof::::reconstruct_root2( + block_res.unwrap(), + slot_path_bits, + slot_last_bits, + slot_path, + slot_mask_bits, + params.max_depth - params.bot_depth(), + ); -/// returns exactly M default circuit input -pub fn get_m_circ_input(params: InputParams) -> [SampleCircuitInput; M]{ - // let params = Params::default().input_params; - let one_circ_input = gen_testing_circuit_input::(¶ms); - let circ_input: [SampleCircuitInput; M] = (0..M) - .map(|_| one_circ_input.clone()) - .collect::>() - .try_into().unwrap(); - circ_input + Ok(reconstructed_root.unwrap() == circ_input.slot_root) + } } #[cfg(test)] mod tests { use std::time::Instant; + use plonky2::hash::poseidon::PoseidonHash; + use plonky2::plonk::config::PoseidonGoldilocksConfig; + use plonky2::plonk::proof::ProofWithPublicInputs; + use plonky2_field::goldilocks_field::GoldilocksField; use super::*; use codex_plonky2_circuits::circuit_helper::Plonky2Circuit; use codex_plonky2_circuits::circuits::sample_cells::SampleCircuit; + // types used in all tests + type F = GoldilocksField; + const D: usize = 2; + type H = PoseidonHash; + type C = PoseidonGoldilocksConfig; + // Test sample cells (non-circuit) #[test] fn test_gen_verify_proof(){ - let params = Params::default().input_params; - let w = gen_testing_circuit_input::(¶ms); - assert!(verify_circuit_input::(w, ¶ms)); + let input_gen = InputGenerator::::default(); + let w = input_gen.gen_testing_circuit_input(); + assert!(input_gen.verify_circuit_input(w)); } // Test sample cells in-circuit for a selected slot @@ -235,10 +229,11 @@ mod tests { params.set_n_samples(10); let input_params = params.input_params; let circuit_params = params.circuit_params; - let circ_input = gen_testing_circuit_input::(&input_params); + let input_gen = InputGenerator::::new(input_params); + let circ_input = input_gen.gen_testing_circuit_input(); // build the circuit - let circ = SampleCircuit::::new(circuit_params.clone()); + let circ = SampleCircuit::::new(circuit_params.clone()); let (targets, data) = circ.build_with_standard_config()?; println!("circuit size = {:?}", data.common.degree_bits()); diff --git a/proof-input/src/merkle_tree/key_compress.rs b/proof-input/src/merkle_tree/key_compress.rs index b831478..7bf3ad3 100755 --- a/proof-input/src/merkle_tree/key_compress.rs +++ b/proof-input/src/merkle_tree/key_compress.rs @@ -28,7 +28,6 @@ pub fn key_compress< #[cfg(test)] mod tests { - // use plonky2::hash::poseidon::PoseidonHash; use plonky2::plonk::config::{GenericConfig, PoseidonGoldilocksConfig}; use plonky2_field::types::Field; use plonky2_poseidon2::poseidon2_hash::poseidon2::Poseidon2Hash; diff --git a/proof-input/src/merkle_tree/merkle_circuit.rs b/proof-input/src/merkle_tree/merkle_circuit.rs index 0bfecef..726d4c2 100755 --- a/proof-input/src/merkle_tree/merkle_circuit.rs +++ b/proof-input/src/merkle_tree/merkle_circuit.rs @@ -129,7 +129,7 @@ mod tests { use plonky2::field::types::Field; #[test] - fn test_build_circuit() -> anyhow::Result<()> { + fn test_mt_build_circuit() -> anyhow::Result<()> { // circuit params const D: usize = 2; type C = PoseidonGoldilocksConfig; @@ -152,10 +152,7 @@ mod tests { .collect(); //initialize the Merkle tree - let zero_hash = HashOut { - elements: [GoldilocksField::ZERO; 4], - }; - let tree = MerkleTree::::new(&leaves, zero_hash)?; + let tree = MerkleTree::::new(&leaves)?; // select leaf index to prove let leaf_index: usize = 8; @@ -236,10 +233,7 @@ mod tests { }) .collect(); - let zero_hash = HashOut { - elements: [GoldilocksField::ZERO; 4], - }; - let tree = MerkleTree::::new(&leaves, zero_hash)?; + let tree = MerkleTree::::new(&leaves)?; let expected_root = tree.root()?; diff --git a/proof-input/src/merkle_tree/merkle_safe.rs b/proof-input/src/merkle_tree/merkle_safe.rs index b7d8765..c4b39de 100755 --- a/proof-input/src/merkle_tree/merkle_safe.rs +++ b/proof-input/src/merkle_tree/merkle_safe.rs @@ -1,14 +1,16 @@ -// Implementation of "safe" merkle tree +// Implementation of Codex specific "safe" merkle tree // consistent with the one in codex: // https://github.com/codex-storage/nim-codex/blob/master/codex/merkletree/merkletree.nim +use std::marker::PhantomData; use anyhow::{ensure, Result}; use plonky2::hash::hash_types::{HashOut, RichField}; use std::ops::Shr; +use plonky2::plonk::config::AlgebraicHasher; use plonky2_field::extension::Extendable; use plonky2_poseidon2::poseidon2_hash::poseidon2::Poseidon2; use crate::merkle_tree::key_compress::key_compress; -use crate::params::HF; +use crate::utils::zero; // Constants for the keys used in compression pub const KEY_NONE: u64 = 0x0; @@ -21,24 +23,25 @@ pub const KEY_ODD_AND_BOTTOM_LAYER: u64 = 0x3; pub struct MerkleTree< F: RichField + Extendable + Poseidon2, const D: usize, + H: AlgebraicHasher, > { pub layers: Vec>>, - pub zero: HashOut, + phantom_data: PhantomData } impl< F: RichField + Extendable + Poseidon2, const D: usize, -> MerkleTree { + H: AlgebraicHasher, +> MerkleTree { /// Constructs a new Merkle tree from the given leaves. pub fn new( leaves: &[HashOut], - zero: HashOut, ) -> Result { - let layers = merkle_tree_worker::(leaves, zero, true)?; + let layers = merkle_tree_worker::(leaves, true)?; Ok(Self { layers, - zero, + phantom_data: PhantomData::default(), }) } @@ -60,7 +63,7 @@ impl< } /// Generates a Merkle proof for a given leaf index. - pub fn get_proof(&self, index: usize) -> Result> { + pub fn get_proof(&self, index: usize) -> Result> { let depth = self.depth(); let nleaves = self.leaves_count(); @@ -75,19 +78,18 @@ impl< let sibling = if j < m { self.layers[i][j] } else { - self.zero + zero::() }; path.push(sibling); k = k >> 1; m = (m + 1) >> 1; } - Ok(MerkleProof { + Ok(MerkleProof::new( index, path, nleaves, - zero: self.zero, - }) + )) } } @@ -95,9 +97,9 @@ impl< fn merkle_tree_worker< F: RichField + Extendable + Poseidon2, const D: usize, + H: AlgebraicHasher, >( xs: &[HashOut], - zero: HashOut, is_bottom_layer: bool, ) -> Result>>> { let m = xs.len(); @@ -113,7 +115,7 @@ fn merkle_tree_worker< for i in 0..halfn { let key = if is_bottom_layer { KEY_BOTTOM_LAYER } else { KEY_NONE }; - let h = key_compress::(xs[2 * i], xs[2 * i + 1], key); + let h = key_compress::(xs[2 * i], xs[2 * i + 1], key); ys.push(h); } @@ -123,12 +125,12 @@ fn merkle_tree_worker< } else { KEY_ODD }; - let h = key_compress::(xs[n], zero, key); + let h = key_compress::(xs[n], zero::(), key); ys.push(h); } let mut layers = vec![xs.to_vec()]; - let mut upper_layers = merkle_tree_worker::(&ys, zero, false)?; + let mut upper_layers = merkle_tree_worker::(&ys, false)?; layers.append(&mut upper_layers); Ok(layers) @@ -139,17 +141,31 @@ fn merkle_tree_worker< pub struct MerkleProof< F: RichField + Extendable + Poseidon2, const D: usize, + H: AlgebraicHasher, > { pub index: usize, // Index of the leaf pub path: Vec>, // Sibling hashes from the leaf to the root pub nleaves: usize, // Total number of leaves - pub zero: HashOut, + phantom_data: PhantomData } impl< F: RichField + Extendable + Poseidon2, const D: usize, -> MerkleProof { + H: AlgebraicHasher, +> MerkleProof { + pub fn new( + index: usize, + path: Vec>, + nleaves: usize, + ) -> Self{ + Self{ + index, + path, + nleaves, + phantom_data: PhantomData::default(), + } + } /// Reconstructs the root hash from the proof and the given leaf. pub fn reconstruct_root(&self, leaf: HashOut) -> Result> { let mut m = self.nleaves; @@ -161,14 +177,14 @@ impl< let odd_index = (j & 1) != 0; if odd_index { // The index of the child is odd - h = key_compress::(*p, h, bottom_flag); + h = key_compress::(*p, h, bottom_flag); } else { if j == m - 1 { // Single child -> so odd node - h = key_compress::(h, *p, bottom_flag + 2); + h = key_compress::(h, *p, bottom_flag + 2); } else { // Even node - h = key_compress::(h, *p, bottom_flag); + h = key_compress::(h, *p, bottom_flag); } } bottom_flag = KEY_NONE; @@ -200,9 +216,9 @@ impl< let key = bottom + (2 * (odd as u64)); let odd_index = path_bits[i]; if odd_index { - h.push(key_compress::(*p, h[i], key)); + h.push(key_compress::(*p, h[i], key)); } else { - h.push(key_compress::(h[i], *p, key)); + h.push(key_compress::(h[i], *p, key)); } i += 1; } @@ -262,13 +278,12 @@ mod tests { y: HashOut, key: u64, ) -> HashOut { - key_compress::(x,y,key) + key_compress::(x,y,key) } fn make_tree( data: &[F], - zero: HashOut, - ) -> Result> { + ) -> Result> { // Hash the data to obtain leaf hashes let leaves: Vec> = data .iter() @@ -278,7 +293,7 @@ mod tests { }) .collect(); - MerkleTree::::new(&leaves, zero) + MerkleTree::::new(&leaves) } #[test] @@ -296,12 +311,8 @@ mod tests { }) .collect(); - let zero = HashOut { - elements: [F::ZERO; 4], - }; - // Build the Merkle tree - let tree = MerkleTree::::new(&leaves, zero)?; + let tree = MerkleTree::::new(&leaves)?; // Get the root let root = tree.root()?; @@ -329,11 +340,6 @@ mod tests { .map(|&element| H::hash_no_pad(&[element])) .collect(); - // zero hash - let zero = HashOut { - elements: [F::ZERO; 4], - }; - let expected_root = compress( compress( @@ -366,7 +372,7 @@ mod tests { ); // Build the tree - let tree = make_tree(&data, zero)?; + let tree = make_tree(&data)?; // Get the computed root let computed_root = tree.root()?; @@ -390,11 +396,6 @@ mod tests { .map(|&element| H::hash_no_pad(&[element])) .collect(); - // zero hash - let zero = HashOut { - elements: [F::ZERO; 4], - }; - let expected_root = compress( compress( @@ -418,7 +419,7 @@ mod tests { ), compress( leaf_hashes[6], - zero, + zero::(), KEY_ODD_AND_BOTTOM_LAYER, ), KEY_NONE, @@ -427,7 +428,7 @@ mod tests { ); // Build the tree - let tree = make_tree(&data, zero)?; + let tree = make_tree(&data)?; // Get the computed root let computed_root = tree.root()?; @@ -451,11 +452,6 @@ mod tests { .map(|&element| H::hash_no_pad(&[element])) .collect(); - // zero hash - let zero = HashOut { - elements: [F::ZERO; 4], - }; - let expected_root = compress( compress( compress( @@ -493,17 +489,17 @@ mod tests { leaf_hashes[9], KEY_BOTTOM_LAYER, ), - zero, + zero::(), KEY_ODD, ), - zero, + zero::(), KEY_ODD, ), KEY_NONE, ); // Build the tree - let tree = make_tree(&data, zero)?; + let tree = make_tree(&data)?; // Get the computed root let computed_root = tree.root()?; @@ -527,13 +523,8 @@ mod tests { .map(|&element| H::hash_no_pad(&[element])) .collect(); - // zero hash - let zero = HashOut { - elements: [F::ZERO; 4], - }; - // Build the tree - let tree = MerkleTree::::new(&leaf_hashes, zero)?; + let tree = MerkleTree::::new(&leaf_hashes)?; // Get the root let expected_root = tree.root()?; diff --git a/proof-input/src/merkle_tree/test.rs b/proof-input/src/merkle_tree/test.rs index ec01e32..cdb824f 100755 --- a/proof-input/src/merkle_tree/test.rs +++ b/proof-input/src/merkle_tree/test.rs @@ -3,13 +3,14 @@ mod tests { use plonky2::hash::hash_types::{HashOut, RichField}; use plonky2_field::extension::Extendable; - use plonky2_poseidon2::poseidon2_hash::poseidon2::Poseidon2; + use plonky2_poseidon2::poseidon2_hash::poseidon2::{Poseidon2, Poseidon2Hash}; use anyhow::Result; use crate::merkle_tree::merkle_safe::{MerkleProof, MerkleTree}; use plonky2::field::goldilocks_field::GoldilocksField; use plonky2::field::types::Field; type F = GoldilocksField; + type H = Poseidon2Hash; const D: usize = 2; struct TestCase { @@ -35,9 +36,6 @@ mod tests { #[test] fn test_merkle_roots() -> Result<()> { - let zero = HashOut { - elements: [F::ZERO; 4], - }; let test_cases: Vec = vec![ TestCase { n: 1, digest: [0x232f21acc9d346d8, 0x2eba96d3a73822c1, 0x4163308f6d0eff64, 0x5190c2b759734aff] }, @@ -53,7 +51,7 @@ mod tests { let inputs = digest_seq::(n); // Build the Merkle tree - let tree = MerkleTree::::new(&inputs, zero.clone())?; + let tree = MerkleTree::::new(&inputs)?; // Get the computed root let computed_root = tree.root()?; @@ -160,14 +158,11 @@ mod tests { let mut found = false; for index in 0..num_indices { - let proof = MerkleProof:: { + let proof = MerkleProof::::new( index, - path: path_hashes.clone(), - nleaves: num_indices, - zero: HashOut { - elements: [F::ZERO; 4], - }, - }; + path_hashes.clone(), + num_indices, + ); // Reconstruct the root let reconstructed_root = proof.reconstruct_root(leaf.clone())?; diff --git a/proof-input/src/recursion/mod.rs b/proof-input/src/recursion/mod.rs index 54b106a..5d859d7 100755 --- a/proof-input/src/recursion/mod.rs +++ b/proof-input/src/recursion/mod.rs @@ -2,7 +2,7 @@ use plonky2::plonk::circuit_data::{ProverCircuitData, VerifierCircuitData}; use plonky2::plonk::proof::ProofWithPublicInputs; use codex_plonky2_circuits::circuit_helper::Plonky2Circuit; use codex_plonky2_circuits::circuits::sample_cells::SampleCircuit; -use crate::gen_input::gen_testing_circuit_input; +use crate::gen_input::InputGenerator; use crate::params::{C, D, F, HF, Params}; pub mod tree_test; @@ -16,7 +16,8 @@ pub fn run_sampling_circ() -> anyhow::Result<(ProofWithPublicInputs, Pr // Circuit that does the sampling - 100 samples let mut params = Params::default(); params.set_n_samples(100); - let one_circ_input = gen_testing_circuit_input::(¶ms.input_params); + let input_gen = InputGenerator::::new(params.input_params.clone()); + let one_circ_input = input_gen.gen_testing_circuit_input(); let samp_circ = SampleCircuit::::new(params.circuit_params); let (inner_tar, inner_data) = samp_circ.build_with_standard_config()?; diff --git a/proof-input/src/recursion/node_test.rs b/proof-input/src/recursion/node_test.rs index 96cb3c8..46fd71a 100755 --- a/proof-input/src/recursion/node_test.rs +++ b/proof-input/src/recursion/node_test.rs @@ -10,7 +10,7 @@ mod tests { use crate::recursion::leaf_test::tests::run_leaf_circ; use crate::recursion::run_sampling_circ; - fn run_node_circ(leaf_proofs: Vec>, leaf_verifier_data: VerifierCircuitData, flag: bool, index: usize) -> anyhow::Result<()> { + fn run_node_circ(leaf_proofs: Vec>, leaf_verifier_data: VerifierCircuitData, _flag: bool, index: usize) -> anyhow::Result<()> { // ------------------- Node -------------------- // N leaf proofs @@ -53,10 +53,10 @@ mod tests { fn test_real_node_circ() -> anyhow::Result<()> { let (inner_proof, _, inner_verifier) = run_sampling_circ()?; // this is a bit wasteful to build leaf twice, TODO: fix this - let (leaf_proof_1, _, leaf_verifier) = run_leaf_circ::<128>(inner_proof.clone(), inner_verifier.clone(), true, 0)?; - let (leaf_proof_2, _, leaf_verifier) = run_leaf_circ::<128>(inner_proof, inner_verifier, true, 1)?; + let (leaf_proof_1, _, _leaf_verifier_1) = run_leaf_circ::<128>(inner_proof.clone(), inner_verifier.clone(), true, 0)?; + let (leaf_proof_2, _, leaf_verifier_2) = run_leaf_circ::<128>(inner_proof, inner_verifier, true, 1)?; let leaf_proofs = vec![leaf_proof_1,leaf_proof_2]; - run_node_circ::<2,128>(leaf_proofs, leaf_verifier, true, 0) + run_node_circ::<2,128>(leaf_proofs, leaf_verifier_2, true, 0) } } \ No newline at end of file diff --git a/proof-input/src/recursion/tree_test.rs b/proof-input/src/recursion/tree_test.rs index e8b4a81..002ebf6 100755 --- a/proof-input/src/recursion/tree_test.rs +++ b/proof-input/src/recursion/tree_test.rs @@ -3,7 +3,6 @@ #[cfg(test)] mod tests { use plonky2::plonk::proof::{ProofWithPublicInputs}; - use codex_plonky2_circuits::circuit_helper::Plonky2Circuit; use crate::params::{F, D, C, HF}; use codex_plonky2_circuits::recursion::{tree::TreeRecursion}; use crate::recursion::run_sampling_circ; @@ -12,7 +11,7 @@ mod tests { //------------ sampling inner circuit ---------------------- // Circuit that does the sampling - 100 samples - let (inner_proof, inner_prover_data, inner_verifier_data) = run_sampling_circ()?; + let (inner_proof, _inner_prover_data, inner_verifier_data) = run_sampling_circ()?; let proofs: Vec> = (0..T).map(|_i| inner_proof.clone()).collect(); diff --git a/proof-input/src/recursion/wrap_test.rs b/proof-input/src/recursion/wrap_test.rs index cd6b4d8..f8b2019 100755 --- a/proof-input/src/recursion/wrap_test.rs +++ b/proof-input/src/recursion/wrap_test.rs @@ -69,7 +69,7 @@ mod tests { //------------ sampling inner circuit ---------------------- // Circuit that does the sampling - 100 samples - let (inner_proof, inner_prover_data, inner_verifier_data) = run_sampling_circ()?; + let (inner_proof, _inner_prover_data, inner_verifier_data) = run_sampling_circ()?; let proofs: Vec> = (0..T).map(|_i| inner_proof.clone()).collect(); diff --git a/proof-input/src/serialization/circuit_input.rs b/proof-input/src/serialization/circuit_input.rs index 3f42787..a368c32 100755 --- a/proof-input/src/serialization/circuit_input.rs +++ b/proof-input/src/serialization/circuit_input.rs @@ -7,8 +7,6 @@ use codex_plonky2_circuits::circuits::sample_cells::{Cell, MerklePath, SampleCir use std::fs::File; use std::io::{BufReader, Write}; use std::path::Path; -use crate::gen_input::gen_testing_circuit_input; -use crate::params::InputParams; use codex_plonky2_circuits::serialization::ensure_parent_directory_exists; pub const CIRC_INPUT_JSON: &str = "prover_data/input.json"; @@ -261,24 +259,6 @@ pub fn export_circ_input_to_json< Ok(()) } - -/// Function to generate circuit input and export to JSON -pub fn generate_and_export_circ_input_to_json< - F: RichField + Extendable + Poseidon2 + Serialize, - const D: usize, - P: AsRef, ->( - params: &InputParams, - base_path: P, -) -> anyhow::Result<()> { - - let circ_input = gen_testing_circuit_input::(params); - - export_circ_input_to_json(circ_input, base_path)?; - - Ok(()) -} - /// reads the json file, converts it to circuit input (SampleCircuitInput) and returns it pub fn import_circ_input_from_json< F: RichField + Extendable + Poseidon2, @@ -303,16 +283,16 @@ mod tests { use codex_plonky2_circuits::circuits::sample_cells::{SampleCircuit, SampleCircuitInput}; use plonky2::plonk::circuit_data::{ProverCircuitData, VerifierCircuitData}; use codex_plonky2_circuits::circuit_helper::Plonky2Circuit; - use crate::gen_input::{gen_testing_circuit_input, verify_circuit_input}; - use crate::serialization::circuit_input::{export_circ_input_to_json, generate_and_export_circ_input_to_json, import_circ_input_from_json}; + use crate::gen_input::InputGenerator; + use crate::serialization::circuit_input::{export_circ_input_to_json, import_circ_input_from_json}; // Test to generate the JSON file #[test] fn test_export_circ_input_to_json() -> anyhow::Result<()> { - // Create Params - let params = Params::default().input_params; + // Create InputGenerator + let input_gen = InputGenerator::::default(); // Export the circuit input to JSON - generate_and_export_circ_input_to_json::(¶ms, "../output/test/")?; + input_gen.generate_and_export_circ_input_to_json( "../output/test/")?; println!("Circuit input exported to input.json"); @@ -332,11 +312,11 @@ mod tests { // export the circuit input and then import it and checks equality #[test] fn test_export_import_circ_input() -> anyhow::Result<()> { - // Create Params instance - let params = Params::default().input_params; + // Create InputGenerator + let input_gen = InputGenerator::::default(); // Export the circuit input to JSON - let original_circ_input = gen_testing_circuit_input(¶ms); + let original_circ_input = input_gen.gen_testing_circuit_input(); export_circ_input_to_json(original_circ_input.clone(), "../output/test/")?; println!("circuit input exported to input.json"); @@ -387,14 +367,15 @@ mod tests { // NOTE: expects that the json input proof uses the default params #[test] fn test_read_json_and_verify() -> anyhow::Result<()> { - let params = Params::default().input_params; + // Create InputGenerator + let input_gen = InputGenerator::::default(); // Import the circuit input from JSON let imported_circ_input: SampleCircuitInput = import_circ_input_from_json("../output/test/")?; println!("circuit input imported from input.json"); // Verify the proof - let ver = verify_circuit_input(imported_circ_input, ¶ms); + let ver = input_gen.verify_circuit_input(imported_circ_input); assert!( ver, "Merkle proof verification failed" diff --git a/proof-input/src/sponge.rs b/proof-input/src/sponge.rs index 54471f1..d85a477 100755 --- a/proof-input/src/sponge.rs +++ b/proof-input/src/sponge.rs @@ -168,8 +168,15 @@ pub fn hash_bytes_to_m_no_padding< #[cfg(test)] mod tests { use plonky2::field::types::Field; + use plonky2::plonk::config::{GenericConfig, PoseidonGoldilocksConfig}; + use plonky2_poseidon2::poseidon2_hash::poseidon2::Poseidon2Hash; use crate::sponge::hash_n_with_padding; - use crate::params::{D, F, HF}; + + // test types + pub const D: usize = 2; + pub type C = PoseidonGoldilocksConfig; + pub type F = >::F; + pub type H = Poseidon2Hash; #[test] fn test_sponge_hash_rate_8() { @@ -273,7 +280,7 @@ mod tests { .collect(); // Call the sponge function - let output = hash_n_with_padding::(&inputs); + let output = hash_n_with_padding::(&inputs); // Compare the outputs for (i, &out_elem) in output.elements.iter().enumerate() { diff --git a/proof-input/src/utils.rs b/proof-input/src/utils.rs index 40f23a1..f0f2702 100755 --- a/proof-input/src/utils.rs +++ b/proof-input/src/utils.rs @@ -100,3 +100,7 @@ pub fn ceiling_log2( (last_bits, mask) } + +pub fn zero + Poseidon2, const D: usize>() -> HashOut{ + HashOut { elements: [F::ZERO; 4],} +} diff --git a/workflow/src/gen_input.rs b/workflow/src/gen_input.rs index ca92e49..0d1a3b5 100755 --- a/workflow/src/gen_input.rs +++ b/workflow/src/gen_input.rs @@ -1,9 +1,9 @@ use std::time::Instant; use anyhow::Result; use proof_input::serialization::circuit_input::export_circ_input_to_json; -use proof_input::gen_input::gen_testing_circuit_input; +use proof_input::gen_input::InputGenerator; use proof_input::params::Params; -use proof_input::params::{D, F}; +use proof_input::params::{D, F, HF}; use crate::file_paths::SAMPLING_CIRC_BASE_PATH; pub fn run() -> Result<()> { @@ -12,7 +12,8 @@ pub fn run() -> Result<()> { // generate circuit input with given parameters let start_time = Instant::now(); - let circ_input = gen_testing_circuit_input::(¶ms.input_params); + let input_gen = InputGenerator::::new(params.input_params); + let circ_input = input_gen.gen_testing_circuit_input(); println!("Generating input time: {:?}", start_time.elapsed()); // export circuit parameters to json file