From 9eefa78c24bbea12951a45a5b82146e1727c943a Mon Sep 17 00:00:00 2001 From: M Alghazwi Date: Thu, 17 Oct 2024 21:38:14 +0200 Subject: [PATCH] refactor circuits --- codex-plonky2-circuits/benches/prove_cells.rs | 8 +- .../src/circuits/prove_single_cell.rs | 104 ++++++++++++------ .../src/circuits/safe_tree_circuit.rs | 8 +- .../src/circuits/sample_cells.rs | 94 +++++++--------- codex-plonky2-circuits/src/circuits/utils.rs | 35 +++++- 5 files changed, 149 insertions(+), 100 deletions(-) diff --git a/codex-plonky2-circuits/benches/prove_cells.rs b/codex-plonky2-circuits/benches/prove_cells.rs index e52f6b3..aaa60b3 100644 --- a/codex-plonky2-circuits/benches/prove_cells.rs +++ b/codex-plonky2-circuits/benches/prove_cells.rs @@ -15,7 +15,7 @@ use plonky2::hash::hash_types::RichField; use plonky2_poseidon2::poseidon2_hash::poseidon2::Poseidon2; use std::marker::PhantomData; use plonky2::plonk::circuit_builder::CircuitBuilder; -use codex_plonky2_circuits::circuits::prove_single_cell::SlotTree; +use codex_plonky2_circuits::circuits::prove_single_cell::SlotTreeCircuit; macro_rules! pretty_print { ($($arg:tt)*) => { @@ -28,7 +28,7 @@ macro_rules! pretty_print { type HF = PoseidonHash; fn prepare_data(N: usize) -> Result<( - SlotTree, + SlotTreeCircuit, Vec, Vec>, )> @@ -37,7 +37,7 @@ where H: Hasher + AlgebraicHasher + Hasher, { // Initialize the slot tree with default data - let slot_tree = SlotTree::::default(); + let slot_tree = SlotTreeCircuit::::default(); // Select N leaf indices to prove let leaf_indices: Vec = (0..N).collect(); @@ -52,7 +52,7 @@ where } fn build_circuit( - slot_tree: &SlotTree, + slot_tree: &SlotTreeCircuit, leaf_indices: &[usize], proofs: &[MerkleProof], ) -> Result<(CircuitData, PartialWitness)> diff --git a/codex-plonky2-circuits/src/circuits/prove_single_cell.rs b/codex-plonky2-circuits/src/circuits/prove_single_cell.rs index 13ef578..4b6ebc8 100644 --- a/codex-plonky2-circuits/src/circuits/prove_single_cell.rs +++ b/codex-plonky2-circuits/src/circuits/prove_single_cell.rs @@ -35,14 +35,24 @@ use crate::circuits::params::{MAX_DEPTH, BOT_DEPTH, N_FIELD_ELEMS_PER_CELL, N_CE // ------ Slot Tree -------- #[derive(Clone)] -pub struct SlotTree> { - pub tree: MerkleTree, // slot tree - pub block_trees: Vec>, // vec of block trees +pub struct SlotTreeCircuit< + F: RichField + Extendable + Poseidon2, + C: GenericConfig, + const D: usize, + H: Hasher + AlgebraicHasher, +> { + pub tree: MerkleTreeCircuit, // slot tree + pub block_trees: Vec>, // vec of block trees pub cell_data: Vec>, // cell data as field elements pub cell_hash: Vec>, // hash of above } -impl> Default for SlotTree{ +impl< + F: RichField + Extendable + Poseidon2, + C: GenericConfig, + const D: usize, + H: Hasher + AlgebraicHasher, +> Default for SlotTreeCircuit{ /// slot tree with fake data, for testing only fn default() -> Self { // generate fake cell data @@ -69,19 +79,38 @@ impl> Default for SlotTree{ .map(|i| { let start = i * N_CELLS_IN_BLOCKS; let end = (i + 1) * N_CELLS_IN_BLOCKS; - Self::get_block_tree(&leaves[start..end].to_vec()) // use helper function + let b_tree = Self::get_block_tree(&leaves[start..end].to_vec()); // use helper function + MerkleTreeCircuit::{ tree:b_tree, _phantom:Default::default()} }) .collect::>(); // get the roots or block trees let block_roots = block_trees.iter() .map(|t| { - t.root().unwrap() + t.tree.root().unwrap() }) .collect::>(); // create slot tree let slot_tree = MerkleTree::::new(&block_roots, zero).unwrap(); + // let mt = + // MerkleTree::{ + // tree: slot_tree, + // block_trees, + // cell_data, + // cell_hash: leaves, + // } + + // create block circuits + // let block_circuits = block_trees.iter() + // .map(|b_tree| { + // // let start = i * N_CELLS_IN_BLOCKS; + // // let end = (i + 1) * N_CELLS_IN_BLOCKS; + // // Self::get_block_tree(&leaves[start..end].to_vec()) // use helper function + // MerkleTreeCircuit::{ tree:b_tree.clone(), _phantom:Default::default()}, + // }) + // .collect::>(); + Self{ - tree: slot_tree, + tree: MerkleTreeCircuit::{ tree:slot_tree, _phantom:Default::default()}, block_trees, cell_data, cell_hash: leaves, @@ -89,7 +118,12 @@ impl> Default for SlotTree{ } } -impl> SlotTree { +impl< + F: RichField + Extendable + Poseidon2, + C: GenericConfig, + const D: usize, + H: Hasher + AlgebraicHasher, +> SlotTreeCircuit { /// same as default but with supplied cell data pub fn new(cell_data: Vec>) -> Self{ @@ -106,17 +140,18 @@ impl> SlotTree { .map(|i| { let start = i * N_CELLS_IN_BLOCKS; let end = (i + 1) * N_CELLS_IN_BLOCKS; - Self::get_block_tree(&leaves[start..end].to_vec()) + let b_tree = Self::get_block_tree(&leaves[start..end].to_vec()); + MerkleTreeCircuit::{ tree:b_tree, _phantom:Default::default()} }) .collect::>(); let block_roots = block_trees.iter() .map(|t| { - t.root().unwrap() + t.tree.root().unwrap() }) .collect::>(); let slot_tree = MerkleTree::::new(&block_roots, zero).unwrap(); Self{ - tree: slot_tree, + tree: MerkleTreeCircuit::{ tree:slot_tree, _phantom:Default::default()}, block_trees, cell_data, cell_hash: leaves, @@ -128,8 +163,8 @@ impl> SlotTree { pub fn get_proof(&self, index: usize) -> MerkleProof { let block_index = index/ N_CELLS_IN_BLOCKS; let leaf_index = index % N_CELLS_IN_BLOCKS; - let block_proof = self.block_trees[block_index].get_proof(leaf_index).unwrap(); - let slot_proof = self.tree.get_proof(block_index).unwrap(); + let block_proof = self.block_trees[block_index].tree.get_proof(leaf_index).unwrap(); + let slot_proof = self.tree.tree.get_proof(block_index).unwrap(); // Combine the paths from the block and slot proofs let mut combined_path = block_proof.path.clone(); @@ -213,10 +248,10 @@ impl< C: GenericConfig, const D: usize, H: Hasher + AlgebraicHasher + Hasher, -> MerkleTreeCircuit { +> SlotTreeCircuit { - pub fn prove_single_cell2( - &mut self, + pub fn prove_single_cell( + // &mut self, builder: &mut CircuitBuilder:: ) -> SingleCellTargets { @@ -258,7 +293,7 @@ impl< }; // reconstruct block root - let block_root = self.reconstruct_merkle_root_circuit(builder, &mut block_targets); + let block_root = MerkleTreeCircuit::::reconstruct_merkle_root_circuit(builder, &mut block_targets); // create MerkleTreeTargets struct let mut slot_targets = MerkleTreeTargets { @@ -270,7 +305,7 @@ impl< }; // reconstruct slot root with block root as leaf - let slot_root = self.reconstruct_merkle_root_circuit(builder, &mut slot_targets); + let slot_root = MerkleTreeCircuit::::reconstruct_merkle_root_circuit(builder, &mut slot_targets); // check equality with expected root for i in 0..NUM_HASH_OUT_ELTS { @@ -305,7 +340,7 @@ impl< /// this takes leaf_index, leaf, and proof (generated from slot_tree) /// and fills all required circuit targets(circuit inputs) pub fn single_cell_assign_witness( - &mut self, + &self, pw: &mut PartialWitness, targets: &mut SingleCellTargets, leaf_index: usize, @@ -343,7 +378,7 @@ impl< } // assign the expected Merkle root to the target - let expected_root = self.tree.root()?; + let expected_root = self.tree.tree.root()?; // TODO: fix this HashOutTarget later same issue as above let expected_root_hash_out = expected_root.to_vec(); for j in 0..expected_root_hash_out.len() { @@ -367,49 +402,46 @@ mod tests { use plonky2::iop::witness::PartialWitness; //types for tests - type F = GoldilocksField; + const D: usize = 2; + type C = PoseidonGoldilocksConfig; + type F = >::F; type H = PoseidonHash; #[test] fn test_prove_single_cell(){ - let slot_t = SlotTree::::default(); + let slot_t = SlotTreeCircuit::::default(); let index = 8; let proof = slot_t.get_proof(index); - let res = slot_t.verify_cell_proof(proof,slot_t.tree.root().unwrap()).unwrap(); + let res = slot_t.verify_cell_proof(proof,slot_t.tree.tree.root().unwrap()).unwrap(); assert_eq!(res, true); } #[test] fn test_cell_build_circuit() -> Result<()> { - // circuit params - const D: usize = 2; - type C = PoseidonGoldilocksConfig; - type F = >::F; - type H = PoseidonHash; - let slot_t = SlotTree::::default(); + let slot_t = SlotTreeCircuit::::default(); // select leaf index to prove let leaf_index: usize = 8; let proof = slot_t.get_proof(leaf_index); // get the expected Merkle root - let expected_root = slot_t.tree.root().unwrap(); + let expected_root = slot_t.tree.tree.root().unwrap(); let res = slot_t.verify_cell_proof(proof.clone(),expected_root).unwrap(); assert_eq!(res, true); // create the circuit let config = CircuitConfig::standard_recursion_config(); let mut builder = CircuitBuilder::::new(config); - let mut circuit_instance = MerkleTreeCircuit:: { - tree: slot_t.tree.clone(), - _phantom: PhantomData, - }; - let mut targets = circuit_instance.prove_single_cell2(&mut builder); + // let mut circuit_instance = MerkleTreeCircuit:: { + // tree: slot_t.tree.clone(), + // _phantom: PhantomData, + // }; + let mut targets = SlotTreeCircuit::::prove_single_cell(&mut builder); // create a PartialWitness and assign let mut pw = PartialWitness::new(); - circuit_instance.single_cell_assign_witness(&mut pw, &mut targets, leaf_index, &slot_t.cell_data[leaf_index], proof)?; + slot_t.single_cell_assign_witness(&mut pw, &mut targets, leaf_index, &slot_t.cell_data[leaf_index], proof)?; // build the circuit let data = builder.build::(); diff --git a/codex-plonky2-circuits/src/circuits/safe_tree_circuit.rs b/codex-plonky2-circuits/src/circuits/safe_tree_circuit.rs index 96e4bbf..5b7eb98 100644 --- a/codex-plonky2-circuits/src/circuits/safe_tree_circuit.rs +++ b/codex-plonky2-circuits/src/circuits/safe_tree_circuit.rs @@ -28,7 +28,7 @@ use crate::merkle_tree::merkle_safe::{KEY_NONE,KEY_BOTTOM_LAYER}; // note: this omits the mask bits since in plonky2 we can // uses the Plonk's permutation argument to check that two elements are equal. // TODO: double check the need for mask -// #[derive(Clone)] +#[derive(Clone)] pub struct MerkleTreeTargets< F: RichField + Extendable + Poseidon2, C: GenericConfig, @@ -44,7 +44,7 @@ pub struct MerkleTreeTargets< /// Merkle tree circuit contains the tree and functions for /// building, proving and verifying the circuit. -// #[derive(Clone)] +#[derive(Clone)] pub struct MerkleTreeCircuit< F: RichField + Extendable + Poseidon2, C: GenericConfig, @@ -94,7 +94,7 @@ impl< }; // Add Merkle proof verification constraints to the circuit - self.reconstruct_merkle_root_circuit(builder, &mut targets); + Self::reconstruct_merkle_root_circuit(builder, &mut targets); // Return MerkleTreeTargets targets @@ -174,7 +174,7 @@ impl< /// takes the params from the targets struct /// outputs the reconstructed merkle root pub fn reconstruct_merkle_root_circuit( - &self, + // &self, builder: &mut CircuitBuilder, targets: &mut MerkleTreeTargets, ) -> HashOutTarget { diff --git a/codex-plonky2-circuits/src/circuits/sample_cells.rs b/codex-plonky2-circuits/src/circuits/sample_cells.rs index b6d0dd3..c968f32 100644 --- a/codex-plonky2-circuits/src/circuits/sample_cells.rs +++ b/codex-plonky2-circuits/src/circuits/sample_cells.rs @@ -27,17 +27,23 @@ use plonky2::field::goldilocks_field::GoldilocksField; use plonky2::plonk::config::PoseidonGoldilocksConfig; use plonky2::hash::hashing::PlonkyPermutation; -use crate::circuits::prove_single_cell::{SingleCellTargets, SlotTree}; +use crate::circuits::prove_single_cell::{SingleCellTargets, SlotTreeCircuit}; use crate::circuits::params::{MAX_DEPTH, BOT_DEPTH, N_FIELD_ELEMS_PER_CELL, N_CELLS_IN_BLOCKS, N_BLOCKS, N_CELLS, HF, DATASET_DEPTH, N_SAMPLES}; use crate::circuits::safe_tree_circuit::{MerkleTreeCircuit, MerkleTreeTargets}; +use crate::circuits::utils::{bits_le_padded_to_usize, calculate_cell_index_bits}; // ------ Dataset Tree -------- ///dataset tree containing all slot trees #[derive(Clone)] -pub struct DatasetTree> { - pub tree: MerkleTree, // dataset tree - pub slot_trees: Vec>, // vec of slot trees +pub struct DatasetTreeCircuit< + F: RichField + Extendable + Poseidon2, + C: GenericConfig, + const D: usize, + H: Hasher + AlgebraicHasher, +> { + pub tree: MerkleTreeCircuit, // dataset tree + pub slot_trees: Vec>, // vec of slot trees } /// Dataset Merkle proof struct, containing the dataset proof and N_SAMPLES proofs. @@ -49,18 +55,23 @@ pub struct DatasetMerkleProof> { pub slot_proofs: Vec>, // proofs for sampled slot, contains N_SAMPLES proofs } -impl> Default for DatasetTree { +impl< + F: RichField + Extendable + Poseidon2, + C: GenericConfig, + const D: usize, + H: Hasher + AlgebraicHasher, +> Default for DatasetTreeCircuit { /// dataset tree with fake data, for testing only fn default() -> Self { let mut slot_trees = vec![]; let n_slots = 1<::default()); + slot_trees.push(SlotTreeCircuit::::default()); } // get the roots or slot trees let slot_roots = slot_trees.iter() .map(|t| { - t.tree.root().unwrap() + t.tree.tree.root().unwrap() }) .collect::>(); // zero hash @@ -69,20 +80,25 @@ impl> Default for DatasetTree { }; let dataset_tree = MerkleTree::::new(&slot_roots, zero).unwrap(); Self{ - tree: dataset_tree, + tree: MerkleTreeCircuit::{ tree:dataset_tree, _phantom:Default::default()}, slot_trees, } } } -impl> DatasetTree { +impl< + F: RichField + Extendable + Poseidon2, + C: GenericConfig, + const D: usize, + H: Hasher + AlgebraicHasher, +> DatasetTreeCircuit { /// same as default but with supplied slot trees - pub fn new(slot_trees: Vec>) -> Self{ + pub fn new(slot_trees: Vec>) -> Self{ // get the roots or slot trees let slot_roots = slot_trees.iter() .map(|t| { - t.tree.root().unwrap() + t.tree.tree.root().unwrap() }) .collect::>(); // zero hash @@ -91,7 +107,7 @@ impl> DatasetTree { }; let dataset_tree = MerkleTree::::new(&slot_roots, zero).unwrap(); Self{ - tree: dataset_tree, + tree: MerkleTreeCircuit::{ tree:dataset_tree, _phantom:Default::default()}, slot_trees, } } @@ -99,16 +115,16 @@ impl> DatasetTree { /// generates a dataset level proof for given slot index /// just a regular merkle tree proof pub fn get_proof(&self, index: usize) -> MerkleProof { - let dataset_proof = self.tree.get_proof(index).unwrap(); + let dataset_proof = self.tree.tree.get_proof(index).unwrap(); dataset_proof } /// generates a proof for given slot index /// also takes entropy so it can use it sample the slot pub fn sample_slot(&self, index: usize, entropy: usize) -> DatasetMerkleProof { - let dataset_proof = self.get_proof(index); + let dataset_proof = self.tree.tree.get_proof(index).unwrap(); let slot = &self.slot_trees[index]; - let slot_root = slot.tree.root().unwrap(); + let slot_root = slot.tree.tree.root().unwrap(); let mut slot_proofs = vec![]; // get the index for cell from H(slot_root|counter|entropy) for i in 0..N_SAMPLES { @@ -128,9 +144,9 @@ impl> DatasetTree { // verify the sampling - non-circuit version pub fn verify_sampling(&self, proof: DatasetMerkleProof) -> Result{ let slot = &self.slot_trees[proof.slot_index]; - let slot_root = slot.tree.root().unwrap(); + let slot_root = slot.tree.tree.root().unwrap(); // check dataset level proof - let d_res = proof.dataset_proof.verify(slot_root,self.tree.root().unwrap()); + let d_res = proof.dataset_proof.verify(slot_root,self.tree.tree.root().unwrap()); if(d_res.unwrap() == false){ return Ok(false); } @@ -180,7 +196,7 @@ impl< C: GenericConfig, const D: usize, H: Hasher + AlgebraicHasher + Hasher, -> MerkleTreeCircuit { +> DatasetTreeCircuit { // the in-circuit sampling of a slot in a dataset pub fn sample_slot_circuit( @@ -192,7 +208,7 @@ impl< // let slot_root = builder.add_virtual_hash(); let mut slot_proofs =vec![]; for i in 0..N_SAMPLES{ - let proof_i = self.prove_single_cell2(builder); + let proof_i = SlotTreeCircuit::::prove_single_cell(builder); slot_proofs.push(proof_i); } @@ -212,7 +228,7 @@ impl< &mut self, pw: &mut PartialWitness, targets: DatasetTargets, - dataset_tree: DatasetTree, + dataset_tree: DatasetTreeCircuit, slot_index:usize, entropy:usize, ){ @@ -222,38 +238,6 @@ impl< } - -// --------- helper functions -------------- -fn calculate_cell_index_bits(p0: usize, p1: HashOut, p2: usize) -> Vec { - let p0_field = F::from_canonical_u64(p0 as u64); - let p2_field = F::from_canonical_u64(p2 as u64); - let mut inputs = Vec::new(); - inputs.extend_from_slice(&p1.elements); - inputs.push(p0_field); - inputs.push(p2_field); - let p_hash = HF::hash_no_pad(&inputs); - let p_bytes = p_hash.to_bytes(); - - let p_bits = take_n_bits_from_bytes(&p_bytes, MAX_DEPTH); - p_bits -} -fn take_n_bits_from_bytes(bytes: &[u8], n: usize) -> Vec { - bytes.iter() - .flat_map(|byte| (0..8u8).map(move |i| (byte >> i) & 1 == 1)) - .take(n) - .collect() -} -/// Converts a vector of bits (LSB first) into an index (usize). -fn bits_le_padded_to_usize(bits: &[bool]) -> usize { - bits.iter().enumerate().fold(0usize, |acc, (i, &bit)| { - if bit { - acc | (1 << i) - } else { - acc - } - }) -} - #[cfg(test)] mod tests { use std::time::Instant; @@ -263,12 +247,14 @@ mod tests { use plonky2::iop::witness::PartialWitness; //types for tests - type F = GoldilocksField; + const D: usize = 2; + type C = PoseidonGoldilocksConfig; + type F = >::F; type H = PoseidonHash; #[test] fn test_sample_cells() { - let dataset_t = DatasetTree::::default(); + let dataset_t = DatasetTreeCircuit::::default(); let slot_index = 2; let entropy = 123; let proof = dataset_t.sample_slot(slot_index,entropy); diff --git a/codex-plonky2-circuits/src/circuits/utils.rs b/codex-plonky2-circuits/src/circuits/utils.rs index 77f62e6..e493b4a 100644 --- a/codex-plonky2-circuits/src/circuits/utils.rs +++ b/codex-plonky2-circuits/src/circuits/utils.rs @@ -1,5 +1,6 @@ - - +use plonky2::hash::hash_types::{HashOut, RichField}; +use plonky2::plonk::config::{GenericHashOut, Hasher}; +use crate::circuits::params::{HF, MAX_DEPTH}; // --------- helper functions --------- @@ -14,4 +15,34 @@ pub(crate) fn usize_to_bits_le_padded(index: usize, bit_length: usize) -> Vec(p0: usize, p1: HashOut, p2: usize) -> Vec { + let p0_field = F::from_canonical_u64(p0 as u64); + let p2_field = F::from_canonical_u64(p2 as u64); + let mut inputs = Vec::new(); + inputs.extend_from_slice(&p1.elements); + inputs.push(p0_field); + inputs.push(p2_field); + let p_hash = HF::hash_no_pad(&inputs); + let p_bytes = p_hash.to_bytes(); + + let p_bits = take_n_bits_from_bytes(&p_bytes, MAX_DEPTH); + p_bits +} +pub(crate) fn take_n_bits_from_bytes(bytes: &[u8], n: usize) -> Vec { + bytes.iter() + .flat_map(|byte| (0..8u8).map(move |i| (byte >> i) & 1 == 1)) + .take(n) + .collect() +} +/// Converts a vector of bits (LSB first) into an index (usize). +pub(crate) fn bits_le_padded_to_usize(bits: &[bool]) -> usize { + bits.iter().enumerate().fold(0usize, |acc, (i, &bit)| { + if bit { + acc | (1 << i) + } else { + acc + } + }) } \ No newline at end of file