diff --git a/Cargo.toml b/Cargo.toml index 8e142ec..05c9ef2 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -11,6 +11,7 @@ plonky2_maybe_rayon = { version = "1.0.0", default-features = false } itertools = { version = "0.12.1", default-features = false } serde = { version = "1.0", features = ["derive"] } serde_json = "1.0" +thiserror = "2.0.10" [workspace.package] diff --git a/codex-plonky2-circuits/Cargo.toml b/codex-plonky2-circuits/Cargo.toml index c7c6e73..e42a5ff 100644 --- a/codex-plonky2-circuits/Cargo.toml +++ b/codex-plonky2-circuits/Cargo.toml @@ -13,6 +13,7 @@ serde = { workplace = true } serde_json = { workplace = true } plonky2 = { workplace = true } plonky2_field = { workplace = true } +thiserror = { workplace = true } plonky2_poseidon2 = { path = "../plonky2_poseidon2" } itertools = { workplace = true } plonky2_maybe_rayon = { workplace = true } diff --git a/codex-plonky2-circuits/src/circuits/keyed_compress.rs b/codex-plonky2-circuits/src/circuits/keyed_compress.rs index b609b49..59acf8d 100644 --- a/codex-plonky2-circuits/src/circuits/keyed_compress.rs +++ b/codex-plonky2-circuits/src/circuits/keyed_compress.rs @@ -1,52 +1,28 @@ -use plonky2::hash::hash_types::{HashOut, HashOutTarget, RichField, NUM_HASH_OUT_ELTS}; +use plonky2::hash::hash_types::{ HashOutTarget, RichField, NUM_HASH_OUT_ELTS}; use plonky2::hash::hashing::PlonkyPermutation; use plonky2::iop::target::Target; use plonky2::plonk::circuit_builder::CircuitBuilder; -use plonky2::plonk::config::{AlgebraicHasher, Hasher}; +use plonky2::plonk::config::AlgebraicHasher; use plonky2_field::extension::Extendable; use plonky2_poseidon2::poseidon2_hash::poseidon2::Poseidon2; -/// Compression function which takes two 256 bit inputs (HashOut) and u64 key (which is converted to field element in the function) -/// and returns a 256 bit output (HashOut / 4 Goldilocks field elems). -pub fn key_compress< - F: RichField + Extendable + Poseidon2, - const D: usize, - H:Hasher ->(x: HashOut, y: HashOut, key: u64) -> HashOut { - - debug_assert_eq!(x.elements.len(), NUM_HASH_OUT_ELTS); - debug_assert_eq!(y.elements.len(), NUM_HASH_OUT_ELTS); - - let key_field = F::from_canonical_u64(key); - - let mut perm = H::Permutation::new(core::iter::repeat(F::ZERO)); - perm.set_from_slice(&x.elements, 0); - perm.set_from_slice(&y.elements, NUM_HASH_OUT_ELTS); - perm.set_elt(key_field,NUM_HASH_OUT_ELTS*2); - - perm.permute(); - - HashOut { - elements: perm.squeeze()[..NUM_HASH_OUT_ELTS].try_into().unwrap(), - } -} - -/// same as above but in-circuit +/// Compression function which takes two 256 bit inputs (HashOutTarget) and key Target +/// and returns a 256 bit output (HashOutTarget / 4 Targets). pub fn key_compress_circuit< F: RichField + Extendable + Poseidon2, const D: usize, H: AlgebraicHasher, >( builder: &mut CircuitBuilder, - x: Vec, - y: Vec, + x: HashOutTarget, + y: HashOutTarget, key: Target, ) -> HashOutTarget { let zero = builder.zero(); let mut state = H::AlgebraicPermutation::new(core::iter::repeat(zero)); - state.set_from_slice(&x, 0); - state.set_from_slice(&y, NUM_HASH_OUT_ELTS); + state.set_from_slice(&x.elements, 0); + state.set_from_slice(&y.elements, NUM_HASH_OUT_ELTS); state.set_elt(key, NUM_HASH_OUT_ELTS*2); state = builder.permute::(state); @@ -56,86 +32,3 @@ pub fn key_compress_circuit< } } -#[cfg(test)] -mod tests { - // use plonky2::hash::poseidon::PoseidonHash; - use plonky2::plonk::config::{GenericConfig, PoseidonGoldilocksConfig}; - use plonky2_field::types::Field; - use plonky2_poseidon2::poseidon2_hash::poseidon2::Poseidon2Hash; - use super::*; - // types - pub const D: usize = 2; - pub type C = PoseidonGoldilocksConfig; - pub type F = >::F; - pub type H = Poseidon2Hash; - - /// tests the non-circuit key_compress with concrete cases - #[test] - pub fn test_key_compress(){ - let ref_inp_1: [F; 4] = [ - F::from_canonical_u64(0x0000000000000001), - F::from_canonical_u64(0x0000000000000002), - F::from_canonical_u64(0x0000000000000003), - F::from_canonical_u64(0x0000000000000004), - ]; - - let ref_inp_2: [F; 4] = [ - F::from_canonical_u64(0x0000000000000005), - F::from_canonical_u64(0x0000000000000006), - F::from_canonical_u64(0x0000000000000007), - F::from_canonical_u64(0x0000000000000008), - ]; - - let ref_out_key_0: [F; 4] = [ - F::from_canonical_u64(0xc4a4082f411ba790), - F::from_canonical_u64(0x98c2ed7546c44cce), - F::from_canonical_u64(0xc9404f373b78c979), - F::from_canonical_u64(0x65d6b3c998920f59), - ]; - - let ref_out_key_1: [F; 4] = [ - F::from_canonical_u64(0xca47449a05283778), - F::from_canonical_u64(0x08d3ced2020391ac), - F::from_canonical_u64(0xda461ea45670fb12), - F::from_canonical_u64(0x57f2c0b6c98a05c5), - ]; - - let ref_out_key_2: [F; 4] = [ - F::from_canonical_u64(0xe6fcec96a7a7f4b0), - F::from_canonical_u64(0x3002a22356daa551), - F::from_canonical_u64(0x899e2c1075a45f3f), - F::from_canonical_u64(0xf07e38ccb3ade312), - ]; - - let ref_out_key_3: [F; 4] = [ - F::from_canonical_u64(0x9930cff752b046fb), - F::from_canonical_u64(0x41570687cadcea0b), - F::from_canonical_u64(0x3ac093a5a92066c7), - F::from_canonical_u64(0xc45c75a3911cde87), - ]; - - // `HashOut` for inputs - let inp1 = HashOut { elements: ref_inp_1 }; - let inp2 = HashOut { elements: ref_inp_2 }; - - // Expected outputs - let expected_outputs = [ - ref_out_key_0, - ref_out_key_1, - ref_out_key_2, - ref_out_key_3, - ]; - - // Iterate over each key and test key_compress output - for (key, &expected) in expected_outputs.iter().enumerate() { - let output = key_compress::(inp1, inp2, key as u64); - - // Assert that output matches the expected result - assert_eq!(output.elements, expected, "Output mismatch for key: {}", key); - - println!("Test passed for key {}", key); - } - - } -} - diff --git a/codex-plonky2-circuits/src/circuits/merkle_circuit.rs b/codex-plonky2-circuits/src/circuits/merkle_circuit.rs index cd98e4e..ac24c12 100644 --- a/codex-plonky2-circuits/src/circuits/merkle_circuit.rs +++ b/codex-plonky2-circuits/src/circuits/merkle_circuit.rs @@ -2,7 +2,7 @@ // consistent with the one in codex: // https://github.com/codex-storage/codex-storage-proofs-circuits/blob/master/circuit/codex/merkle.circom -use anyhow::Result; +// use anyhow::Result; use plonky2::{ field::{extension::Extendable, types::Field}, hash::hash_types::{HashOutTarget, RichField, NUM_HASH_OUT_ELTS}, @@ -16,7 +16,14 @@ use plonky2_poseidon2::poseidon2_hash::poseidon2::Poseidon2; use crate::circuits::keyed_compress::key_compress_circuit; use crate::circuits::params::HF; use crate::circuits::utils::{add_assign_hash_out_target, mul_hash_out_target}; -use crate::merkle_tree::merkle_safe::{KEY_NONE,KEY_BOTTOM_LAYER}; +use crate::Result; +use crate::error::CircuitError; + +// Constants for the keys used in compression +pub const KEY_NONE: u64 = 0x0; +pub const KEY_BOTTOM_LAYER: u64 = 0x1; +pub const KEY_ODD: u64 = 0x2; +pub const KEY_ODD_AND_BOTTOM_LAYER: u64 = 0x3; /// Merkle tree targets representing the input to the circuit #[derive(Clone)] @@ -72,13 +79,34 @@ impl< builder: &mut CircuitBuilder, targets: &mut MerkleTreeTargets, max_depth: usize, - ) -> HashOutTarget { + ) -> Result { let mut state: Vec = Vec::with_capacity(max_depth+1); state.push(targets.leaf); let zero = builder.zero(); let one = builder.one(); let two = builder.two(); - debug_assert_eq!(targets.path_bits.len(), targets.merkle_path.path.len()); + + // --- Basic checks on input sizes. + let path_len = targets.path_bits.len(); + let proof_len = targets.merkle_path.path.len(); + let mask_len = targets.mask_bits.len(); + let last_len = targets.last_bits.len(); + + if path_len != proof_len { + return Err(CircuitError::PathBitsLengthMismatch(path_len, proof_len)); + } + + if mask_len != path_len + 1 { + return Err(CircuitError::MaskBitsLengthMismatch(mask_len, path_len+1)); + } + + if last_len != path_len { + return Err(CircuitError::LastBitsLengthMismatch(last_len, path_len)); + } + + if path_len != max_depth { + return Err(CircuitError::PathBitsMaxDepthMismatch(path_len, max_depth)); + } // compute is_last let mut is_last = vec![BoolTarget::new_unsafe(zero); max_depth + 1]; @@ -115,7 +143,11 @@ impl< } // Compress them with a keyed-hash function - let combined_hash = key_compress_circuit::(builder, left, right, key); + let combined_hash = key_compress_circuit:: + (builder, + HashOutTarget::from_vec(left), + HashOutTarget::from_vec(right), + key); state.push(combined_hash); i += 1; @@ -129,7 +161,7 @@ impl< add_assign_hash_out_target(builder,&mut reconstructed_root, &mul_result); } - reconstructed_root + Ok(reconstructed_root) } } diff --git a/codex-plonky2-circuits/src/circuits/sample_cells.rs b/codex-plonky2-circuits/src/circuits/sample_cells.rs index afdd64d..d44abc1 100644 --- a/codex-plonky2-circuits/src/circuits/sample_cells.rs +++ b/codex-plonky2-circuits/src/circuits/sample_cells.rs @@ -5,20 +5,32 @@ // - reconstruct the dataset merkle root using the slot root as leaf // - samples multiple cells by calling the sample_cells -use plonky2::field::extension::Extendable; -use plonky2::hash::hash_types::{HashOut, HashOutTarget, NUM_HASH_OUT_ELTS, RichField}; -use plonky2::iop::target::{BoolTarget, Target}; -use plonky2::iop::witness::{PartialWitness, WitnessWrite}; -use plonky2::plonk::circuit_builder::CircuitBuilder; -use plonky2::plonk::config::GenericConfig; use std::marker::PhantomData; -use plonky2_poseidon2::poseidon2_hash::poseidon2::Poseidon2; -use plonky2::hash::hashing::PlonkyPermutation; -use crate::circuits::params::{CircuitParams, HF}; -use crate::circuits::merkle_circuit::{MerkleProofTarget, MerkleTreeCircuit, MerkleTreeTargets}; -use crate::circuits::sponge::{hash_n_no_padding, hash_n_with_padding}; -use crate::circuits::utils::{assign_hash_out_targets, ceiling_log2}; +use plonky2::{ + field::extension::Extendable, + hash::{ + hash_types::{HashOut, HashOutTarget, NUM_HASH_OUT_ELTS, RichField}, + hashing::PlonkyPermutation, + }, + iop::{ + target::{BoolTarget, Target}, + witness::{PartialWitness, WitnessWrite}, + }, + plonk::circuit_builder::CircuitBuilder, +}; +use plonky2_poseidon2::poseidon2_hash::poseidon2::Poseidon2; + +use crate::{ + circuits::{ + merkle_circuit::{MerkleProofTarget, MerkleTreeCircuit, MerkleTreeTargets}, + params::{CircuitParams, HF}, + sponge::{hash_n_no_padding, hash_n_with_padding}, + utils::{assign_hash_out_targets, ceiling_log2}, + }, + Result, + error::CircuitError, +}; /// circuit for sampling a slot in a dataset merkle tree #[derive(Clone, Debug)] @@ -116,14 +128,14 @@ impl< pub fn sample_slot_circuit_with_public_input( &self, builder: &mut CircuitBuilder::, - )-> SampleTargets { - let targets = self.sample_slot_circuit(builder); + ) -> Result { + let targets = self.sample_slot_circuit(builder)?; let mut pub_targets = vec![]; pub_targets.push(targets.slot_index); pub_targets.extend_from_slice(&targets.dataset_root.elements); pub_targets.extend_from_slice(&targets.entropy.elements); builder.register_public_inputs(&pub_targets); - targets + Ok(targets) } /// in-circuit sampling @@ -131,7 +143,7 @@ impl< pub fn sample_slot_circuit( &self, builder: &mut CircuitBuilder::, - )-> SampleTargets { + ) -> Result { // circuit params let CircuitParams { max_depth, @@ -144,7 +156,6 @@ impl< // constants let zero = builder.zero(); let one = builder.one(); - let two = builder.two(); // ***** prove slot root is in dataset tree ********* @@ -179,7 +190,7 @@ impl< // dataset reconstructed root let d_reconstructed_root = - MerkleTreeCircuit::::reconstruct_merkle_root_circuit_with_mask(builder, &mut d_targets, max_log2_n_slots); + MerkleTreeCircuit::::reconstruct_merkle_root_circuit_with_mask(builder, &mut d_targets, max_log2_n_slots)?; // expected Merkle root let d_expected_root = builder.add_virtual_hash(); // public input @@ -226,7 +237,7 @@ impl< let mut hash_inputs:Vec= Vec::new(); hash_inputs.extend_from_slice(&data_i); // let data_i_hash = builder.hash_n_to_hash_no_pad::(hash_inputs); - let data_i_hash = hash_n_no_padding::(builder, hash_inputs); + let data_i_hash = hash_n_no_padding::(builder, hash_inputs)?; // make the counter into hash digest let ctr_target = builder.constant(F::from_canonical_u64((i+1) as u64)); let mut ctr = builder.add_virtual_hash(); @@ -238,7 +249,7 @@ impl< } } // paths for block and slot - let mut b_path_bits = self.calculate_cell_index_bits(builder, &entropy_target, &d_targets.leaf, &ctr, mask_bits.clone()); + let mut b_path_bits = self.calculate_cell_index_bits(builder, &entropy_target, &d_targets.leaf, &ctr, mask_bits.clone())?; let mut s_path_bits = b_path_bits.split_off(block_tree_depth); let mut b_merkle_path = MerkleProofTarget { @@ -258,7 +269,7 @@ impl< }; // reconstruct block root - let b_root = MerkleTreeCircuit::::reconstruct_merkle_root_circuit_with_mask(builder, &mut block_targets, block_tree_depth); + let b_root = MerkleTreeCircuit::::reconstruct_merkle_root_circuit_with_mask(builder, &mut block_targets, block_tree_depth)?; let mut slot_targets = MerkleTreeTargets { leaf: b_root, @@ -269,7 +280,7 @@ impl< }; // reconstruct slot root with block root as leaf - let slot_reconstructed_root = MerkleTreeCircuit::::reconstruct_merkle_root_circuit_with_mask(builder, &mut slot_targets, max_depth-block_tree_depth); + let slot_reconstructed_root = MerkleTreeCircuit::::reconstruct_merkle_root_circuit_with_mask(builder, &mut slot_targets, max_depth-block_tree_depth)?; // check equality with expected root for i in 0..NUM_HASH_OUT_ELTS { @@ -290,7 +301,7 @@ impl< } - SampleTargets { + let st = SampleTargets { entropy: entropy_target, dataset_root: d_expected_root, slot_index, @@ -300,17 +311,19 @@ impl< slot_proof: d_targets.merkle_path, cell_data: data_targets, merkle_paths: slot_sample_proofs, - } + }; + + Ok(st) } /// calculate the cell index = H( entropy | slotRoot | counter ) `mod` nCells - pub fn calculate_cell_index_bits(&self, builder: &mut CircuitBuilder::, entropy: &HashOutTarget, slot_root: &HashOutTarget, ctr: &HashOutTarget, mask_bits: Vec) -> Vec { + pub fn calculate_cell_index_bits(&self, builder: &mut CircuitBuilder::, entropy: &HashOutTarget, slot_root: &HashOutTarget, ctr: &HashOutTarget, mask_bits: Vec) -> Result> { let mut hash_inputs:Vec= Vec::new(); hash_inputs.extend_from_slice(&entropy.elements); hash_inputs.extend_from_slice(&slot_root.elements); hash_inputs.extend_from_slice(&ctr.elements); - // let hash_out = builder.hash_n_to_hash_no_pad::(hash_inputs); - let hash_out = hash_n_with_padding::(builder, hash_inputs); + + let hash_out = hash_n_with_padding::(builder, hash_inputs)?; let cell_index_bits = builder.low_bits(hash_out.elements[0], self.params.max_depth, 64); let mut masked_cell_index_bits = vec![]; @@ -320,7 +333,7 @@ impl< masked_cell_index_bits.push(BoolTarget::new_unsafe(builder.mul(mask_bits[i].target, cell_index_bits[i].target))); } - masked_cell_index_bits + Ok(masked_cell_index_bits) } /// helper method to assign the targets in the circuit to actual field elems @@ -329,7 +342,7 @@ impl< pw: &mut PartialWitness, targets: &SampleTargets, witnesses: &SampleCircuitInput, - ){ + ) -> Result<()>{ // circuit params let CircuitParams { max_depth, @@ -340,41 +353,66 @@ impl< } = self.params; // assign n_cells_per_slot - pw.set_target(targets.n_cells_per_slot, witnesses.n_cells_per_slot); + pw.set_target(targets.n_cells_per_slot, witnesses.n_cells_per_slot) + .map_err(|e| { + CircuitError::TargetAssignmentError("n_cells_per_slot".to_string(), e.to_string()) + })?; // assign n_slots_per_dataset - pw.set_target(targets.n_slots_per_dataset, witnesses.n_slots_per_dataset); + pw.set_target(targets.n_slots_per_dataset, witnesses.n_slots_per_dataset) + .map_err(|e| { + CircuitError::TargetAssignmentError("n_slots_per_dataset".to_string(), e.to_string()) + })?; // assign dataset proof for (i, sibling_hash) in witnesses.slot_proof.iter().enumerate() { - pw.set_hash_target(targets.slot_proof.path[i], *sibling_hash); + pw.set_hash_target(targets.slot_proof.path[i], *sibling_hash) + .map_err(|e| { + CircuitError::HashTargetAssignmentError("slot_proof".to_string(), e.to_string()) + })?; } // assign slot index - pw.set_target(targets.slot_index, witnesses.slot_index); + pw.set_target(targets.slot_index, witnesses.slot_index) + .map_err(|e| { + CircuitError::TargetAssignmentError("slot_index".to_string(), e.to_string()) + })?; // assign the expected Merkle root of dataset to the target - pw.set_hash_target(targets.dataset_root, witnesses.dataset_root); + pw.set_hash_target(targets.dataset_root, witnesses.dataset_root) + .map_err(|e| { + CircuitError::HashTargetAssignmentError("dataset_root".to_string(), e.to_string()) + })?; // assign the sampled slot - pw.set_hash_target(targets.slot_root, witnesses.slot_root); + pw.set_hash_target(targets.slot_root, witnesses.slot_root) + .map_err(|e| { + CircuitError::HashTargetAssignmentError("slot_root".to_string(), e.to_string()) + })?; // assign entropy - assign_hash_out_targets(pw, &targets.entropy.elements, &witnesses.entropy.elements); + assign_hash_out_targets(pw, &targets.entropy, &witnesses.entropy)?; // do the sample N times for i in 0..n_samples { // assign cell data let leaf = witnesses.cell_data[i].data.clone(); for j in 0..n_field_elems_per_cell{ - pw.set_target(targets.cell_data[i].data[j], leaf[j]); + pw.set_target(targets.cell_data[i].data[j], leaf[j]) + .map_err(|e| { + CircuitError::TargetAssignmentError("cell_data".to_string(), e.to_string()) + })?; } // assign proof for that cell let cell_proof = witnesses.merkle_paths[i].path.clone(); for k in 0..max_depth { - pw.set_hash_target(targets.merkle_paths[i].path[k], cell_proof[k]); + pw.set_hash_target(targets.merkle_paths[i].path[k], cell_proof[k]) + .map_err(|e| { + CircuitError::HashTargetAssignmentError("merkle_paths".to_string(), e.to_string()) + })?; } } + Ok(()) } } diff --git a/codex-plonky2-circuits/src/circuits/sponge.rs b/codex-plonky2-circuits/src/circuits/sponge.rs index dee3d58..d87a175 100644 --- a/codex-plonky2-circuits/src/circuits/sponge.rs +++ b/codex-plonky2-circuits/src/circuits/sponge.rs @@ -5,6 +5,8 @@ use plonky2::plonk::circuit_builder::CircuitBuilder; use plonky2::plonk::config::AlgebraicHasher; use plonky2_field::extension::Extendable; use plonky2_poseidon2::poseidon2_hash::poseidon2::Poseidon2; +use crate::error::CircuitError; +use crate::Result; /// hash n targets (field elements) into hash digest / HashOutTarget (4 Goldilocks field elements) /// this function uses the 10* padding @@ -15,8 +17,12 @@ pub fn hash_n_with_padding< >( builder: &mut CircuitBuilder, inputs: Vec, -) -> HashOutTarget { - HashOutTarget::from_vec( hash_n_to_m_with_padding::(builder, inputs, NUM_HASH_OUT_ELTS)) +) -> Result { + Ok( + HashOutTarget::from_vec( + hash_n_to_m_with_padding::(builder, inputs, NUM_HASH_OUT_ELTS)? + ) + ) } pub fn hash_n_to_m_with_padding< @@ -27,7 +33,7 @@ pub fn hash_n_to_m_with_padding< builder: &mut CircuitBuilder, inputs: Vec, num_outputs: usize, -) -> Vec { +) -> Result> { let rate = H::AlgebraicPermutation::RATE; let width = H::AlgebraicPermutation::WIDTH; // rate + capacity let zero = builder.zero(); @@ -51,7 +57,7 @@ pub fn hash_n_to_m_with_padding< chunk.push(input); } else { // should not happen here - panic!("Insufficient input elements for chunk; expected more elements."); + return Err(CircuitError::InsufficientInputs(rate,chunk.len())); } } // Add the chunk to the state @@ -96,7 +102,7 @@ pub fn hash_n_to_m_with_padding< for &s in state.squeeze() { outputs.push(s); if outputs.len() == num_outputs { - return outputs; + return Ok(outputs); } } state = builder.permute::(state); @@ -113,8 +119,12 @@ pub fn hash_n_no_padding< >( builder: &mut CircuitBuilder, inputs: Vec, -) -> HashOutTarget { - HashOutTarget::from_vec( hash_n_to_m_no_padding::(builder, inputs, NUM_HASH_OUT_ELTS)) +) -> Result { + Ok( + HashOutTarget::from_vec( + hash_n_to_m_no_padding::(builder, inputs, NUM_HASH_OUT_ELTS)? + ) + ) } pub fn hash_n_to_m_no_padding< @@ -125,11 +135,10 @@ pub fn hash_n_to_m_no_padding< builder: &mut CircuitBuilder, inputs: Vec, num_outputs: usize, -) -> Vec { +) -> Result> { let rate = H::AlgebraicPermutation::RATE; let width = H::AlgebraicPermutation::WIDTH; // rate + capacity let zero = builder.zero(); - let one = builder.one(); let mut state = H::AlgebraicPermutation::new(core::iter::repeat(zero).take(width)); // Set the domain separator at index 8 @@ -138,7 +147,9 @@ pub fn hash_n_to_m_no_padding< state.set_elt(dom_sep, 8); let n = inputs.len(); - assert_eq!(n % rate, 0, "Input length ({}) must be divisible by rate ({})", n, rate); + if n % rate != 0 { + return Err(CircuitError::SpongeInputLengthMismatch(n, rate)); + } let num_chunks = n / rate; // 10* padding let mut input_iter = inputs.iter(); @@ -150,7 +161,7 @@ pub fn hash_n_to_m_no_padding< chunk.push(input); } else { // should not happen here - panic!("Insufficient input elements for chunk; expected more elements."); + return Err(CircuitError::InsufficientInputs(rate,chunk.len())); } } // Add the chunk to the state @@ -166,9 +177,9 @@ pub fn hash_n_to_m_no_padding< for &s in state.squeeze() { outputs.push(s); if outputs.len() == num_outputs { - return outputs; + return Ok(outputs); } } state = builder.permute::(state); } -} +} \ No newline at end of file diff --git a/codex-plonky2-circuits/src/circuits/utils.rs b/codex-plonky2-circuits/src/circuits/utils.rs index a9ff178..49f95e9 100644 --- a/codex-plonky2-circuits/src/circuits/utils.rs +++ b/codex-plonky2-circuits/src/circuits/utils.rs @@ -1,11 +1,13 @@ use std::{fs, io}; use std::path::Path; -use plonky2::hash::hash_types::{HashOutTarget, NUM_HASH_OUT_ELTS, RichField}; +use plonky2::hash::hash_types::{HashOut, HashOutTarget, NUM_HASH_OUT_ELTS, RichField}; use plonky2::iop::witness::{PartialWitness, WitnessWrite}; use plonky2_field::extension::Extendable; use plonky2_poseidon2::poseidon2_hash::poseidon2::Poseidon2; use plonky2::iop::target::{BoolTarget, Target}; use plonky2::plonk::circuit_builder::CircuitBuilder; +use crate::Result; +use crate::error::CircuitError; // --------- helper functions --------- @@ -44,24 +46,42 @@ pub fn assign_bool_targets< pw: &mut PartialWitness, bool_targets: &Vec, bools: Vec, -){ - for (i, bit) in bools.iter().enumerate() { - pw.set_bool_target(bool_targets[i], *bit); +) -> Result<()>{ + if bools.len() > bool_targets.len() { + return Err(CircuitError::AssignmentLengthMismatch ( + bool_targets.len(), + bools.len(), + ) + ); } + for (i, bit) in bools.iter().enumerate() { + pw.set_bool_target(bool_targets[i], *bit) + .map_err(|e| + CircuitError::ArrayBoolTargetAssignmentError(i, e.to_string()), + )?; + } + Ok(()) } /// assign a vec of field elems to hash out target elements +/// TODO: change to HashOut pub fn assign_hash_out_targets< F: RichField + Extendable + Poseidon2, const D: usize, >( pw: &mut PartialWitness, - hash_out_elements_targets: &[Target], - hash_out_elements: &[F], -){ - for j in 0..NUM_HASH_OUT_ELTS { - pw.set_target(hash_out_elements_targets[j], hash_out_elements[j]); + hash_out_elements_targets: &HashOutTarget, + hash_out_elements: &HashOut, +) -> Result<()>{ + + // Assign each field element to its corresponding target + for (j, (&target, &element)) in hash_out_elements_targets.elements.iter().zip(hash_out_elements.elements.iter()).enumerate() { + pw.set_target(target, element).map_err(|e| { + CircuitError::ArrayTargetAssignmentError(j, e.to_string()) + })?; } + + Ok(()) } /// helper fn to multiply a HashOutTarget by a Target diff --git a/codex-plonky2-circuits/src/error.rs b/codex-plonky2-circuits/src/error.rs new file mode 100644 index 0000000..bc62762 --- /dev/null +++ b/codex-plonky2-circuits/src/error.rs @@ -0,0 +1,47 @@ +use thiserror::Error; + +/// Custom error types for the Circuits. +#[derive(Error, Debug)] +pub enum CircuitError { + #[error("Path bits length mismatch: expected {0}, found {1}")] + PathBitsLengthMismatch(usize, usize), + + #[error("Mask bits length mismatch: expected {0}, found {1}")] + MaskBitsLengthMismatch(usize, usize), + + #[error("Last bits length mismatch: expected {0}, found {1}")] + LastBitsLengthMismatch(usize, usize), + + #[error("Path bits and max depth mismatch: path bits length {0}, max depth {1}")] + PathBitsMaxDepthMismatch(usize, usize), + + #[error("Sibling hash at depth {0} has invalid length: expected {1}, found {2}")] + SiblingHashInvalidLength(usize, usize, usize), + + #[error("Invalid path bits: expected {0}, found {1}")] + InvalidPathBits(usize, usize), + + #[error("Insufficient input elements for chunk; expected {0}, found {1}")] + InsufficientInputs (usize, usize), + + #[error("Sponge: Input length ({0}) must be divisible by rate ({1}) for no padding")] + SpongeInputLengthMismatch(usize, usize), + + #[error("Assignment length mismatch: expected at least {0}, found {1}")] + AssignmentLengthMismatch(usize, usize), + + #[error("Failed to assign Target at index {0}: {1}")] + ArrayTargetAssignmentError(usize, String), + + #[error("Failed to assign Target {0}: {1}")] + TargetAssignmentError(String, String), + + #[error("Failed to assign BoolTarget at index {0}: {1}")] + ArrayBoolTargetAssignmentError(usize, String), + + #[error("Failed to assign BoolTarget {0}: {1}")] + BoolTargetAssignmentError(String, String), + + #[error("Failed to assign HashTarget {0}: {1}")] + HashTargetAssignmentError(String, String), +} \ No newline at end of file diff --git a/codex-plonky2-circuits/src/lib.rs b/codex-plonky2-circuits/src/lib.rs index e10d8b9..ea4f9b8 100644 --- a/codex-plonky2-circuits/src/lib.rs +++ b/codex-plonky2-circuits/src/lib.rs @@ -1,3 +1,7 @@ pub mod circuits; -pub mod merkle_tree; -pub mod recursion; +// pub mod merkle_tree; +// pub mod recursion; +pub mod error; +pub mod params; + +pub type Result = core::result::Result; diff --git a/codex-plonky2-circuits/src/merkle_tree/merkle_safe.rs b/codex-plonky2-circuits/src/merkle_tree/merkle_safe.rs deleted file mode 100644 index 112477b..0000000 --- a/codex-plonky2-circuits/src/merkle_tree/merkle_safe.rs +++ /dev/null @@ -1,552 +0,0 @@ -// Implementation of "safe" merkle tree -// consistent with the one in codex: -// https://github.com/codex-storage/nim-codex/blob/master/codex/merkletree/merkletree.nim - -use std::marker::PhantomData; -use anyhow::{ensure, Result}; -use plonky2::field::goldilocks_field::GoldilocksField; -use plonky2::hash::hash_types::{HashOut, RichField}; -use plonky2::hash::poseidon::PoseidonHash; -use plonky2::plonk::config::Hasher; -use std::ops::Shr; -use plonky2_field::extension::Extendable; -use plonky2_field::types::Field; -use plonky2_poseidon2::poseidon2_hash::poseidon2::Poseidon2; -use crate::circuits::keyed_compress::key_compress; -use crate::circuits::params::HF; - -// Constants for the keys used in compression -pub const KEY_NONE: u64 = 0x0; -pub const KEY_BOTTOM_LAYER: u64 = 0x1; -pub const KEY_ODD: u64 = 0x2; -pub const KEY_ODD_AND_BOTTOM_LAYER: u64 = 0x3; - -/// Merkle tree struct, containing the layers, compression function, and zero hash. -#[derive(Clone)] -pub struct MerkleTree< - F: RichField + Extendable + Poseidon2, - const D: usize, -> { - pub layers: Vec>>, - pub zero: HashOut, -} - -impl< - F: RichField + Extendable + Poseidon2, - const D: usize, -> MerkleTree { - /// Constructs a new Merkle tree from the given leaves. - pub fn new( - leaves: &[HashOut], - zero: HashOut, - ) -> Result { - let layers = merkle_tree_worker::(leaves, zero, true)?; - Ok(Self { - layers, - zero, - }) - } - - /// Returns the depth of the Merkle tree. - pub fn depth(&self) -> usize { - self.layers.len() - 1 - } - - /// Returns the number of leaves in the Merkle tree. - pub fn leaves_count(&self) -> usize { - self.layers[0].len() - } - - /// Returns the root hash of the Merkle tree. - pub fn root(&self) -> Result> { - let last_layer = self.layers.last().ok_or_else(|| anyhow::anyhow!("Empty tree"))?; - ensure!(last_layer.len() == 1, "Invalid Merkle tree"); - Ok(last_layer[0]) - } - - /// Generates a Merkle proof for a given leaf index. - pub fn get_proof(&self, index: usize) -> Result> { - let depth = self.depth(); - let nleaves = self.leaves_count(); - - ensure!(index < nleaves, "Index out of bounds"); - - let mut path = Vec::with_capacity(depth); - let mut k = index; - let mut m = nleaves; - - for i in 0..depth { - let j = k ^ 1; - let sibling = if j < m { - self.layers[i][j] - } else { - self.zero - }; - path.push(sibling); - k = k >> 1; - m = (m + 1) >> 1; - } - - Ok(MerkleProof { - index, - path, - nleaves, - zero: self.zero, - }) - } -} - -/// Build the Merkle tree layers. -fn merkle_tree_worker< - F: RichField + Extendable + Poseidon2, - const D: usize, ->( - xs: &[HashOut], - zero: HashOut, - is_bottom_layer: bool, -) -> Result>>> { - let m = xs.len(); - if !is_bottom_layer && m == 1 { - return Ok(vec![xs.to_vec()]); - } - - let halfn = m / 2; - let n = 2 * halfn; - let is_odd = n != m; - - let mut ys = Vec::with_capacity(halfn + if is_odd { 1 } else { 0 }); - - for i in 0..halfn { - let key = if is_bottom_layer { KEY_BOTTOM_LAYER } else { KEY_NONE }; - let h = key_compress::(xs[2 * i], xs[2 * i + 1], key); - ys.push(h); - } - - if is_odd { - let key = if is_bottom_layer { - KEY_ODD_AND_BOTTOM_LAYER - } else { - KEY_ODD - }; - let h = key_compress::(xs[n], zero, key); - ys.push(h); - } - - let mut layers = vec![xs.to_vec()]; - let mut upper_layers = merkle_tree_worker::(&ys, zero, false)?; - layers.append(&mut upper_layers); - - Ok(layers) -} - -/// Merkle proof struct, containing the index, path, and other necessary data. -#[derive(Clone)] -pub struct MerkleProof< - F: RichField + Extendable + Poseidon2, - const D: usize, -> { - pub index: usize, // Index of the leaf - pub path: Vec>, // Sibling hashes from the leaf to the root - pub nleaves: usize, // Total number of leaves - pub zero: HashOut, -} - -impl< - F: RichField + Extendable + Poseidon2, - const D: usize, -> MerkleProof { - /// Reconstructs the root hash from the proof and the given leaf. - pub fn reconstruct_root(&self, leaf: HashOut) -> Result> { - let mut m = self.nleaves; - let mut j = self.index; - let mut h = leaf; - let mut bottom_flag = KEY_BOTTOM_LAYER; - - for p in &self.path { - let odd_index = (j & 1) != 0; - if odd_index { - // The index of the child is odd - h = key_compress::(*p, h, bottom_flag); - } else { - if j == m - 1 { - // Single child -> so odd node - h = key_compress::(h, *p, bottom_flag + 2); - } else { - // Even node - h = key_compress::(h, *p, bottom_flag); - } - } - bottom_flag = KEY_NONE; - j = j.shr(1); - m = (m + 1).shr(1); - } - - Ok(h) - } - - /// reconstruct the root using path_bits and last_bits in similar way as the circuit - /// this is used for testing - sanity check - pub fn reconstruct_root2(leaf: HashOut, path_bits: Vec, last_bits:Vec, path: Vec>, mask_bits:Vec, depth: usize) -> Result> { - let is_last = compute_is_last(path_bits.clone(),last_bits); - - let mut h = vec![]; - h.push(leaf); - let mut i = 0; - - for p in &path { - let bottom = if(i==0){ - KEY_BOTTOM_LAYER - }else{ - KEY_NONE - }; - - let odd = (is_last[i] as usize) * (1-(path_bits[i] as usize)); - - let key = bottom + (2 * (odd as u64)); - let odd_index = path_bits[i]; - if odd_index { - h.push(key_compress::(*p, h[i], key)); - } else { - h.push(key_compress::(h[i], *p, key)); - } - i += 1; - } - - let mut reconstructed_root = HashOut::::ZERO; - for k in 0..depth{ - let diff = (mask_bits[k] as u64) - (mask_bits[k+1] as u64); - let mul_res: Vec = h[k+1].elements.iter().map(|e| e.mul(F::from_canonical_u64(diff))).collect(); - reconstructed_root = HashOut::::from_vec( - mul_res.iter().zip(reconstructed_root.elements).map(|(e1,e2)| e1.add(e2)).collect() - ); - } - - Ok(reconstructed_root) - } - - /// Verifies the proof against a given root and leaf. - pub fn verify(&self, leaf: HashOut, root: HashOut) -> Result { - let reconstructed_root = self.reconstruct_root(leaf)?; - Ok(reconstructed_root == root) - } -} - -///helper function to compute is_last -fn compute_is_last(path_bits: Vec, last_bits: Vec) -> Vec { - let max_depth = path_bits.len(); - - // Initialize isLast vector - let mut is_last = vec![false; max_depth + 1]; - is_last[max_depth] = true; // Set isLast[max_depth] to 1 (true) - - // Iterate over eq and isLast in reverse order - for i in (0..max_depth).rev() { - let eq_out = path_bits[i] == last_bits[i]; // eq[i].out - is_last[i] = is_last[i + 1] && eq_out; // isLast[i] = isLast[i+1] * eq[i].out - } - - is_last -} - -#[cfg(test)] -mod tests { - use super::*; - use plonky2::field::types::Field; - use crate::circuits::keyed_compress::key_compress; - - // types used in all tests - type F = GoldilocksField; - const D: usize = 2; - type H = PoseidonHash; - - fn compress( - x: HashOut, - y: HashOut, - key: u64, - ) -> HashOut { - key_compress::(x,y,key) - } - - fn make_tree( - data: &[F], - zero: HashOut, - ) -> Result> { - // Hash the data to obtain leaf hashes - let leaves: Vec> = data - .iter() - .map(|&element| { - // Hash each field element to get the leaf hash - H::hash_no_pad(&[element]) - }) - .collect(); - - MerkleTree::::new(&leaves, zero) - } - - #[test] - fn single_proof_test() -> Result<()> { - let data = (1u64..=8) - .map(|i| F::from_canonical_u64(i)) - .collect::>(); - - // Hash the data to obtain leaf hashes - let leaves: Vec> = data - .iter() - .map(|&element| { - // Hash each field element to get the leaf hash - H::hash_no_pad(&[element]) - }) - .collect(); - - let zero = HashOut { - elements: [F::ZERO; 4], - }; - - // Build the Merkle tree - let tree = MerkleTree::::new(&leaves, zero)?; - - // Get the root - let root = tree.root()?; - - // Get a proof for the first leaf - let proof = tree.get_proof(0)?; - - // Verify the proof - let is_valid = proof.verify(leaves[0], root)?; - assert!(is_valid, "Merkle proof verification failed"); - - Ok(()) - } - - #[test] - fn test_correctness_even_bottom_layer() -> Result<()> { - // Data for the test (field elements) - let data = (1u64..=8) - .map(|i| F::from_canonical_u64(i)) - .collect::>(); - - // Hash the data to get leaf hashes - let leaf_hashes: Vec> = data - .iter() - .map(|&element| H::hash_no_pad(&[element])) - .collect(); - - // zero hash - let zero = HashOut { - elements: [F::ZERO; 4], - }; - - let expected_root = - compress( - compress( - compress( - leaf_hashes[0], - leaf_hashes[1], - KEY_BOTTOM_LAYER, - ), - compress( - leaf_hashes[2], - leaf_hashes[3], - KEY_BOTTOM_LAYER, - ), - KEY_NONE, - ), - compress( - compress( - leaf_hashes[4], - leaf_hashes[5], - KEY_BOTTOM_LAYER, - ), - compress( - leaf_hashes[6], - leaf_hashes[7], - KEY_BOTTOM_LAYER, - ), - KEY_NONE, - ), - KEY_NONE, - ); - - // Build the tree - let tree = make_tree(&data, zero)?; - - // Get the computed root - let computed_root = tree.root()?; - - // Check that the computed root matches the expected root - assert_eq!(computed_root, expected_root); - - Ok(()) - } - - #[test] - fn test_correctness_odd_bottom_layer() -> Result<()> { - // Data for the test (field elements) - let data = (1u64..=7) - .map(|i| F::from_canonical_u64(i)) - .collect::>(); - - // Hash the data to get leaf hashes - let leaf_hashes: Vec> = data - .iter() - .map(|&element| H::hash_no_pad(&[element])) - .collect(); - - // zero hash - let zero = HashOut { - elements: [F::ZERO; 4], - }; - - let expected_root = - compress( - compress( - compress( - leaf_hashes[0], - leaf_hashes[1], - KEY_BOTTOM_LAYER, - ), - compress( - leaf_hashes[2], - leaf_hashes[3], - KEY_BOTTOM_LAYER, - ), - KEY_NONE, - ), - compress( - compress( - leaf_hashes[4], - leaf_hashes[5], - KEY_BOTTOM_LAYER, - ), - compress( - leaf_hashes[6], - zero, - KEY_ODD_AND_BOTTOM_LAYER, - ), - KEY_NONE, - ), - KEY_NONE, - ); - - // Build the tree - let tree = make_tree(&data, zero)?; - - // Get the computed root - let computed_root = tree.root()?; - - // Check that the computed root matches the expected root - assert_eq!(computed_root, expected_root); - - Ok(()) - } - - #[test] - fn test_correctness_even_bottom_odd_upper_layers() -> Result<()> { - // Data for the test (field elements) - let data = (1u64..=10) - .map(|i| F::from_canonical_u64(i)) - .collect::>(); - - // Hash the data to get leaf hashes - let leaf_hashes: Vec> = data - .iter() - .map(|&element| H::hash_no_pad(&[element])) - .collect(); - - // zero hash - let zero = HashOut { - elements: [F::ZERO; 4], - }; - - let expected_root = compress( - compress( - compress( - compress( - leaf_hashes[0], - leaf_hashes[1], - KEY_BOTTOM_LAYER, - ), - compress( - leaf_hashes[2], - leaf_hashes[3], - KEY_BOTTOM_LAYER, - ), - KEY_NONE, - ), - compress( - compress( - leaf_hashes[4], - leaf_hashes[5], - KEY_BOTTOM_LAYER, - ), - compress( - leaf_hashes[6], - leaf_hashes[7], - KEY_BOTTOM_LAYER, - ), - KEY_NONE, - ), - KEY_NONE, - ), - compress( - compress( - compress( - leaf_hashes[8], - leaf_hashes[9], - KEY_BOTTOM_LAYER, - ), - zero, - KEY_ODD, - ), - zero, - KEY_ODD, - ), - KEY_NONE, - ); - - // Build the tree - let tree = make_tree(&data, zero)?; - - // Get the computed root - let computed_root = tree.root()?; - - // Check that the computed root matches the expected root - assert_eq!(computed_root, expected_root); - - Ok(()) - } - - #[test] - fn test_proofs() -> Result<()> { - // Data for the test (field elements) - let data = (1u64..=10) - .map(|i| F::from_canonical_u64(i)) - .collect::>(); - - // Hash the data to get leaf hashes - let leaf_hashes: Vec> = data - .iter() - .map(|&element| H::hash_no_pad(&[element])) - .collect(); - - // zero hash - let zero = HashOut { - elements: [F::ZERO; 4], - }; - - // Build the tree - let tree = MerkleTree::::new(&leaf_hashes, zero)?; - - // Get the root - let expected_root = tree.root()?; - - // Verify proofs for all leaves - for (i, &leaf_hash) in leaf_hashes.iter().enumerate() { - let proof = tree.get_proof(i)?; - let is_valid = proof.verify(leaf_hash, expected_root)?; - assert!(is_valid, "Proof verification failed for leaf {}", i); - } - - Ok(()) - } -} \ No newline at end of file diff --git a/codex-plonky2-circuits/src/merkle_tree/mod.rs b/codex-plonky2-circuits/src/merkle_tree/mod.rs deleted file mode 100644 index 53eca8e..0000000 --- a/codex-plonky2-circuits/src/merkle_tree/mod.rs +++ /dev/null @@ -1 +0,0 @@ -pub mod merkle_safe; \ No newline at end of file diff --git a/codex-plonky2-circuits/src/recursion/params.rs b/codex-plonky2-circuits/src/params.rs similarity index 100% rename from codex-plonky2-circuits/src/recursion/params.rs rename to codex-plonky2-circuits/src/params.rs diff --git a/codex-plonky2-circuits/src/recursion/inner_circuit.rs b/codex-plonky2-circuits/src/recursion/circuits/inner_circuit.rs similarity index 82% rename from codex-plonky2-circuits/src/recursion/inner_circuit.rs rename to codex-plonky2-circuits/src/recursion/circuits/inner_circuit.rs index 801b06f..2d37720 100644 --- a/codex-plonky2-circuits/src/recursion/inner_circuit.rs +++ b/codex-plonky2-circuits/src/recursion/circuits/inner_circuit.rs @@ -1,8 +1,9 @@ -use plonky2::iop::target::{BoolTarget, Target}; +use plonky2::iop::target::Target; use plonky2::iop::witness::PartialWitness; use plonky2::plonk::circuit_builder::CircuitBuilder; use plonky2::plonk::circuit_data::CommonCircuitData; -use crate::recursion::params::{F,C,D}; +use crate::Result; +use crate::params::{F, D}; /// InnerCircuit is the trait used to define the logic of the circuit and assign witnesses /// to that circuit instance. @@ -16,7 +17,7 @@ pub trait InnerCircuit< fn build( &self, builder: &mut CircuitBuilder, - ) -> anyhow::Result; + ) -> Result; /// assign the actual witness values for the current instance of the circuit. fn assign_targets( @@ -24,17 +25,17 @@ pub trait InnerCircuit< pw: &mut PartialWitness, targets: &Self::Targets, input: &Self::Input, - ) -> anyhow::Result<()>; + ) -> Result<()>; /// from the set of the targets, return only the targets which are public /// TODO: this can probably be replaced with enum for Public/Private targets fn get_pub_input_targets( targets: &Self::Targets, - ) -> anyhow::Result<(Vec)>; + ) -> Vec; /// from the set of the targets, return only the targets which are public /// TODO: this can probably be replaced with enum for Public/Private targets fn get_common_data( &self - ) -> anyhow::Result<(CommonCircuitData)>; + ) -> Result<(CommonCircuitData)>; } diff --git a/codex-plonky2-circuits/src/recursion/circuits/mod.rs b/codex-plonky2-circuits/src/recursion/circuits/mod.rs new file mode 100644 index 0000000..d803ff2 --- /dev/null +++ b/codex-plonky2-circuits/src/recursion/circuits/mod.rs @@ -0,0 +1,2 @@ +pub mod inner_circuit; +pub mod sampling_inner_circuit; diff --git a/codex-plonky2-circuits/src/recursion/sampling_inner_circuit.rs b/codex-plonky2-circuits/src/recursion/circuits/sampling_inner_circuit.rs similarity index 75% rename from codex-plonky2-circuits/src/recursion/sampling_inner_circuit.rs rename to codex-plonky2-circuits/src/recursion/circuits/sampling_inner_circuit.rs index 0497779..fd4cd7b 100644 --- a/codex-plonky2-circuits/src/recursion/sampling_inner_circuit.rs +++ b/codex-plonky2-circuits/src/recursion/circuits/sampling_inner_circuit.rs @@ -4,8 +4,9 @@ use plonky2::plonk::circuit_builder::CircuitBuilder; use plonky2::plonk::circuit_data::{CircuitConfig, CommonCircuitData}; use crate::circuits::params::CircuitParams; use crate::circuits::sample_cells::{SampleCircuit, SampleCircuitInput, SampleTargets}; -use crate::recursion::params::{D, F, C}; -use crate::recursion::inner_circuit::InnerCircuit; +use crate::params::{D, F, C}; +use crate::recursion::circuits::inner_circuit::InnerCircuit; +use crate::Result; /// recursion Inner circuit for the sampling circuit #[derive(Clone, Debug)] @@ -36,33 +37,33 @@ impl InnerCircuit for SamplingRecursion{ type Input = SampleCircuitInput; /// build the circuit - fn build(&self, builder: &mut CircuitBuilder) -> anyhow::Result { - Ok(self.sampling_circ.sample_slot_circuit(builder)) + fn build(&self, builder: &mut CircuitBuilder) -> Result { + self.sampling_circ.sample_slot_circuit(builder) } - fn assign_targets(&self, pw: &mut PartialWitness, targets: &Self::Targets, input: &Self::Input) -> anyhow::Result<()> { - Ok(self.sampling_circ.sample_slot_assign_witness(pw, targets, input)) + fn assign_targets(&self, pw: &mut PartialWitness, targets: &Self::Targets, input: &Self::Input) -> Result<()> { + self.sampling_circ.sample_slot_assign_witness(pw, targets, input) } /// returns the public input specific for this circuit which are: /// `[slot_index, dataset_root, entropy]` - fn get_pub_input_targets(targets: &Self::Targets) -> anyhow::Result<(Vec)> { + fn get_pub_input_targets(targets: &Self::Targets) -> Vec { let mut pub_targets = vec![]; pub_targets.push(targets.slot_index.clone()); pub_targets.extend_from_slice(&targets.dataset_root.elements); pub_targets.extend_from_slice(&targets.entropy.elements); - Ok(pub_targets) + pub_targets } /// return the common circuit data for the sampling circuit /// uses the `standard_recursion_config` - fn get_common_data(&self) -> anyhow::Result<(CommonCircuitData)> { + fn get_common_data(&self) -> Result<(CommonCircuitData)> { let config = CircuitConfig::standard_recursion_config(); let mut builder = CircuitBuilder::::new(config); // build the inner circuit - self.sampling_circ.sample_slot_circuit_with_public_input(&mut builder); + self.sampling_circ.sample_slot_circuit_with_public_input(&mut builder)?; let circ_data = builder.build::(); diff --git a/codex-plonky2-circuits/src/recursion/cyclic_recursion.rs b/codex-plonky2-circuits/src/recursion/cyclic/mod.rs similarity index 95% rename from codex-plonky2-circuits/src/recursion/cyclic_recursion.rs rename to codex-plonky2-circuits/src/recursion/cyclic/mod.rs index a19236f..cab57e9 100644 --- a/codex-plonky2-circuits/src/recursion/cyclic_recursion.rs +++ b/codex-plonky2-circuits/src/recursion/cyclic/mod.rs @@ -7,18 +7,17 @@ use plonky2::hash::hash_types::{HashOut, HashOutTarget, RichField}; use plonky2::iop::target::{BoolTarget, Target}; use plonky2::iop::witness::{PartialWitness, WitnessWrite}; use plonky2::plonk::circuit_builder::CircuitBuilder; -use plonky2::plonk::circuit_data::{CircuitConfig, CircuitData, CommonCircuitData, VerifierCircuitData, VerifierCircuitTarget}; +use plonky2::plonk::circuit_data::{CircuitConfig, CircuitData, CommonCircuitData, VerifierCircuitTarget}; use plonky2::plonk::config::{AlgebraicHasher, GenericConfig}; use plonky2::plonk::proof::{ProofWithPublicInputs, ProofWithPublicInputsTarget}; use plonky2::recursion::dummy_circuit::cyclic_base_proof; -use plonky2_field::extension::Extendable; use plonky2_poseidon2::poseidon2_hash::poseidon2::Poseidon2; -use crate::recursion::params::{F,D,C,Plonky2Proof,H}; -use crate::recursion::inner_circuit::InnerCircuit; -use anyhow::Result; +use crate::params::{F,D,C,Plonky2Proof,H}; +use crate::recursion::circuits::inner_circuit::InnerCircuit; use plonky2::gates::noop::NoopGate; use plonky2::recursion::cyclic_recursion::check_cyclic_proof_verifier_data; use crate::circuits::utils::select_hash; +use crate::Result; /// cyclic circuit struct /// contains necessary data @@ -79,7 +78,7 @@ impl< let inner_t = self.circ.build(& mut builder)?; // common data for recursion - let mut common_data = common_data_for_recursion(); + let mut common_data = common_data_for_cyclic_recursion(); // the hash of the public input let pub_input_hash = builder.add_virtual_hash_public_input(); // verifier data for inner proofs @@ -198,7 +197,7 @@ impl< ) -> Result>{ // asserts that n equals the number of input - assert_eq!(n, circ_input.len()); + assert_eq!(n, circ_input.len()); // TODO: replace with err for i in 0..n { self.prove_one_layer(&circ_input[i])?; @@ -212,7 +211,7 @@ impl< &mut self, ) -> Result<()>{ if(self.cyclic_circuit_data.is_none() || self.latest_proof.is_none()){ - panic!("no circuit data or proof found"); + panic!("no circuit data or proof found"); // TODO: replace with err } let circ_data = self.cyclic_circuit_data.as_ref().unwrap(); let proof = self.latest_proof.clone().unwrap(); @@ -224,7 +223,7 @@ impl< } /// Generates `CommonCircuitData` usable for recursion. -pub fn common_data_for_recursion() -> CommonCircuitData +pub fn common_data_for_cyclic_recursion() -> CommonCircuitData { // layer 1 let config = CircuitConfig::standard_recursion_config(); diff --git a/codex-plonky2-circuits/src/recursion/mod.rs b/codex-plonky2-circuits/src/recursion/mod.rs index 59badb9..51b767b 100644 --- a/codex-plonky2-circuits/src/recursion/mod.rs +++ b/codex-plonky2-circuits/src/recursion/mod.rs @@ -1,12 +1,5 @@ -pub mod inner_circuit; -pub mod simple_recursion; -// pub mod simple_recursion2; -pub mod tree_recursion; -pub mod params; -pub mod sampling_inner_circuit; -pub mod cyclic_recursion; -pub mod leaf_circuit; - -pub mod tree_recursion2; -pub mod utils; -pub mod simple_tree_recursion; \ No newline at end of file +pub mod cyclic; +pub mod circuits; +pub mod simple; +pub mod tree1; +pub mod tree2; \ No newline at end of file diff --git a/codex-plonky2-circuits/src/recursion/simple/mod.rs b/codex-plonky2-circuits/src/recursion/simple/mod.rs new file mode 100644 index 0000000..98e64d5 --- /dev/null +++ b/codex-plonky2-circuits/src/recursion/simple/mod.rs @@ -0,0 +1,2 @@ +pub mod simple_recursion; +pub mod simple_tree_recursion; diff --git a/codex-plonky2-circuits/src/recursion/simple_recursion.rs b/codex-plonky2-circuits/src/recursion/simple/simple_recursion.rs similarity index 98% rename from codex-plonky2-circuits/src/recursion/simple_recursion.rs rename to codex-plonky2-circuits/src/recursion/simple/simple_recursion.rs index 5e069e3..50737b1 100644 --- a/codex-plonky2-circuits/src/recursion/simple_recursion.rs +++ b/codex-plonky2-circuits/src/recursion/simple/simple_recursion.rs @@ -8,8 +8,9 @@ use plonky2::plonk::circuit_data::{VerifierCircuitData, VerifierCircuitTarget}; use plonky2::plonk::config::GenericConfig; use plonky2::plonk::proof::{ProofWithPublicInputs, ProofWithPublicInputsTarget}; use plonky2_poseidon2::poseidon2_hash::poseidon2::Poseidon2; -use crate::recursion::inner_circuit::InnerCircuit; -use crate::recursion::params::{C, D, F, Plonky2Proof}; +use crate::recursion::circuits::inner_circuit::InnerCircuit; +use crate::params::{C, D, F, Plonky2Proof}; +use crate::Result; /// aggregate sampling proofs /// This function takes: @@ -25,7 +26,7 @@ pub fn aggregate_sampling_proofs< verifier_data: &VerifierCircuitData, builder: &mut CircuitBuilder::, pw: &mut PartialWitness, -)-> anyhow::Result<()>{ +)-> Result<()>{ // the proof virtual targets let mut proof_targets = vec![]; let mut inner_entropy_targets = vec![]; diff --git a/codex-plonky2-circuits/src/recursion/simple_tree_recursion.rs b/codex-plonky2-circuits/src/recursion/simple/simple_tree_recursion.rs similarity index 97% rename from codex-plonky2-circuits/src/recursion/simple_tree_recursion.rs rename to codex-plonky2-circuits/src/recursion/simple/simple_tree_recursion.rs index 0c27af6..1bf71ed 100644 --- a/codex-plonky2-circuits/src/recursion/simple_tree_recursion.rs +++ b/codex-plonky2-circuits/src/recursion/simple/simple_tree_recursion.rs @@ -2,8 +2,8 @@ use plonky2::plonk::proof::ProofWithPublicInputs; use plonky2::plonk::circuit_data::{CircuitConfig, CircuitData, VerifierCircuitData}; use plonky2::plonk::circuit_builder::CircuitBuilder; use plonky2::iop::witness::PartialWitness; -use crate::recursion::params::{C, D, F}; -use crate::recursion::simple_recursion; +use crate::params::{C, D, F}; +use crate::recursion::simple::simple_recursion; // recursion tree width or the number of proofs in each node in the tree const RECURSION_TREE_WIDTH: usize = 2; diff --git a/codex-plonky2-circuits/src/recursion/tree1/mod.rs b/codex-plonky2-circuits/src/recursion/tree1/mod.rs new file mode 100644 index 0000000..ca40b68 --- /dev/null +++ b/codex-plonky2-circuits/src/recursion/tree1/mod.rs @@ -0,0 +1 @@ +pub mod tree_recursion; diff --git a/codex-plonky2-circuits/src/recursion/tree_recursion.rs b/codex-plonky2-circuits/src/recursion/tree1/tree_recursion.rs similarity index 99% rename from codex-plonky2-circuits/src/recursion/tree_recursion.rs rename to codex-plonky2-circuits/src/recursion/tree1/tree_recursion.rs index 33fcde5..6ae3469 100644 --- a/codex-plonky2-circuits/src/recursion/tree_recursion.rs +++ b/codex-plonky2-circuits/src/recursion/tree1/tree_recursion.rs @@ -11,8 +11,8 @@ use plonky2::recursion::dummy_circuit::cyclic_base_proof; use plonky2_field::extension::Extendable; use plonky2_poseidon2::poseidon2_hash::poseidon2::Poseidon2; // use crate::recursion::params::RecursionTreeParams; -use crate::recursion::params::{F,D,C,Plonky2Proof,H}; -use crate::recursion::inner_circuit::InnerCircuit; +use crate::params::{F, D, C, Plonky2Proof, H}; +use crate::recursion::circuits::inner_circuit::InnerCircuit; use anyhow::{anyhow, Result}; use plonky2::gates::noop::NoopGate; use plonky2::recursion::cyclic_recursion::check_cyclic_proof_verifier_data; diff --git a/codex-plonky2-circuits/src/recursion/leaf_circuit.rs b/codex-plonky2-circuits/src/recursion/tree2/leaf_circuit.rs similarity index 95% rename from codex-plonky2-circuits/src/recursion/leaf_circuit.rs rename to codex-plonky2-circuits/src/recursion/tree2/leaf_circuit.rs index b464d47..4650b45 100644 --- a/codex-plonky2-circuits/src/recursion/leaf_circuit.rs +++ b/codex-plonky2-circuits/src/recursion/tree2/leaf_circuit.rs @@ -4,9 +4,9 @@ use plonky2::plonk::circuit_data::{CircuitConfig, CircuitData, CommonCircuitData use plonky2::plonk::proof::{ProofWithPublicInputs, ProofWithPublicInputsTarget}; use crate::circuits::params::CircuitParams; use crate::circuits::sample_cells::SampleCircuit; -use crate::recursion::params::{C, D, F, H}; -use crate::recursion::inner_circuit::InnerCircuit; -use crate::recursion::sampling_inner_circuit::SamplingRecursion; +use crate::params::{C, D, F, H}; +use crate::recursion::circuits::inner_circuit::InnerCircuit; +use crate::recursion::circuits::sampling_inner_circuit::SamplingRecursion; /// recursion Inner circuit for the sampling circuit #[derive(Clone, Debug)] diff --git a/codex-plonky2-circuits/src/recursion/tree2/mod.rs b/codex-plonky2-circuits/src/recursion/tree2/mod.rs new file mode 100644 index 0000000..5d322f9 --- /dev/null +++ b/codex-plonky2-circuits/src/recursion/tree2/mod.rs @@ -0,0 +1,3 @@ +pub mod leaf_circuit; +pub mod tree_recursion2; +pub mod utils; diff --git a/codex-plonky2-circuits/src/recursion/tree_recursion2.rs b/codex-plonky2-circuits/src/recursion/tree2/tree_recursion2.rs similarity index 98% rename from codex-plonky2-circuits/src/recursion/tree_recursion2.rs rename to codex-plonky2-circuits/src/recursion/tree2/tree_recursion2.rs index 9de97a0..561e4c0 100644 --- a/codex-plonky2-circuits/src/recursion/tree_recursion2.rs +++ b/codex-plonky2-circuits/src/recursion/tree2/tree_recursion2.rs @@ -5,14 +5,15 @@ use plonky2::plonk::circuit_data::{CircuitConfig, CircuitData, CommonCircuitData use plonky2::plonk::config::GenericConfig; use plonky2::plonk::proof::{ProofWithPublicInputs, ProofWithPublicInputsTarget}; use plonky2_poseidon2::poseidon2_hash::poseidon2::Poseidon2; -use crate::recursion::params::{C, D, F, H}; -use crate::recursion::inner_circuit::InnerCircuit; +use crate::params::{C, D, F, H}; +use crate::recursion::circuits::inner_circuit::InnerCircuit; use anyhow::{anyhow, Result}; use plonky2::recursion::cyclic_recursion::check_cyclic_proof_verifier_data; // use serde::de::Unexpected::Option; use crate::circuits::utils::select_hash; -use crate::recursion::{leaf_circuit, utils}; -use crate::recursion::utils::{get_dummy_leaf_proof, get_dummy_node_proof}; +use crate::recursion::tree2::leaf_circuit; +use crate::recursion::tree2::utils; +use crate::recursion::tree2::utils::{get_dummy_leaf_proof, get_dummy_node_proof}; /// the tree recursion struct simplifies the process /// of building, proving and verifying diff --git a/codex-plonky2-circuits/src/recursion/utils.rs b/codex-plonky2-circuits/src/recursion/tree2/utils.rs similarity index 98% rename from codex-plonky2-circuits/src/recursion/utils.rs rename to codex-plonky2-circuits/src/recursion/tree2/utils.rs index 62a18ee..e8a4949 100644 --- a/codex-plonky2-circuits/src/recursion/utils.rs +++ b/codex-plonky2-circuits/src/recursion/tree2/utils.rs @@ -4,7 +4,7 @@ use plonky2::gates::noop::NoopGate; use plonky2::plonk::proof::ProofWithPublicInputs; use plonky2::recursion::dummy_circuit::{cyclic_base_proof, dummy_circuit, dummy_proof}; use hashbrown::HashMap; -use crate::recursion::params::{C, D, F}; +use crate::params::{C, D, F}; /// Generates `CommonCircuitData` usable for node recursion. /// the circuit being built here depends on M and N so must be re-generated diff --git a/proof-input/src/tests/merkle_circuit.rs b/proof-input/src/merkle_tree/merkle_circuit.rs similarity index 90% rename from proof-input/src/tests/merkle_circuit.rs rename to proof-input/src/merkle_tree/merkle_circuit.rs index eee19fe..67d415d 100644 --- a/proof-input/src/tests/merkle_circuit.rs +++ b/proof-input/src/merkle_tree/merkle_circuit.rs @@ -1,4 +1,4 @@ -use anyhow::Result; +use codex_plonky2_circuits::Result; use plonky2::field::extension::Extendable; use plonky2::field::goldilocks_field::GoldilocksField; use plonky2::field::types::Field; @@ -13,9 +13,10 @@ use plonky2_poseidon2::poseidon2_hash::poseidon2::Poseidon2; use serde::Serialize; use codex_plonky2_circuits::circuits::merkle_circuit::{MerkleProofTarget, MerkleTreeCircuit, MerkleTreeTargets}; use codex_plonky2_circuits::circuits::utils::{assign_bool_targets, assign_hash_out_targets}; +use codex_plonky2_circuits::error::CircuitError; use crate::utils::usize_to_bits_le; -use codex_plonky2_circuits::merkle_tree::merkle_safe::MerkleTree; +use crate::merkle_tree::merkle_safe::MerkleTree; /// the input to the merkle tree circuit #[derive(Clone)] @@ -38,7 +39,7 @@ pub fn build_circuit< >( builder: &mut CircuitBuilder::, depth: usize, -) -> (MerkleTreeTargets, HashOutTarget) { +) -> Result<(MerkleTreeTargets, HashOutTarget)> { // Create virtual targets let leaf = builder.add_virtual_hash(); @@ -67,10 +68,10 @@ pub fn build_circuit< }; // Add Merkle proof verification constraints to the circuit - let reconstructed_root_target = MerkleTreeCircuit::reconstruct_merkle_root_circuit_with_mask(builder, &mut targets, depth); + let reconstructed_root_target = MerkleTreeCircuit::reconstruct_merkle_root_circuit_with_mask(builder, &mut targets, depth)?; // Return MerkleTreeTargets - (targets, reconstructed_root_target) + Ok((targets, reconstructed_root_target)) } /// assign the witness values in the circuit targets @@ -84,24 +85,36 @@ pub fn assign_witness< witnesses: MerkleTreeCircuitInput )-> Result<()> { // Assign the leaf hash to the leaf target - pw.set_hash_target(targets.leaf, witnesses.leaf); + pw.set_hash_target(targets.leaf, witnesses.leaf) + .map_err(|e| { + CircuitError::HashTargetAssignmentError("leaf".to_string(), e.to_string()) + })?; // Assign path bits - assign_bool_targets(pw, &targets.path_bits, witnesses.path_bits); + assign_bool_targets(pw, &targets.path_bits, witnesses.path_bits) + .map_err(|e| { + CircuitError::BoolTargetAssignmentError("path_bits".to_string(), e.to_string()) + })?; // Assign last bits - assign_bool_targets(pw, &targets.last_bits, witnesses.last_bits); + assign_bool_targets(pw, &targets.last_bits, witnesses.last_bits) + .map_err(|e| { + CircuitError::BoolTargetAssignmentError("last_bits".to_string(), e.to_string()) + })?; // Assign mask bits - assign_bool_targets(pw, &targets.mask_bits, witnesses.mask_bits); + assign_bool_targets(pw, &targets.mask_bits, witnesses.mask_bits) + .map_err(|e| { + CircuitError::BoolTargetAssignmentError("mask_bits".to_string(), e.to_string()) + })?; // assign the Merkle path (sibling hashes) to the targets for i in 0..targets.merkle_path.path.len() { if i>=witnesses.merkle_path.len() { // pad with zeros - assign_hash_out_targets(pw, &targets.merkle_path.path[i].elements, &[F::ZERO; NUM_HASH_OUT_ELTS]); + assign_hash_out_targets(pw, &targets.merkle_path.path[i], &HashOut::from_vec([F::ZERO; NUM_HASH_OUT_ELTS].to_vec()))?; continue } - assign_hash_out_targets(pw, &targets.merkle_path.path[i].elements, &witnesses.merkle_path[i].elements) + assign_hash_out_targets(pw, &targets.merkle_path.path[i], &witnesses.merkle_path[i])?; } Ok(()) } @@ -117,10 +130,7 @@ mod tests { use plonky2::iop::witness::PartialWitness; use plonky2::plonk::circuit_builder::CircuitBuilder; use plonky2_field::goldilocks_field::GoldilocksField; - // use crate::merkle_tree::merkle_safe::MerkleTree; - // NOTE: for now these tests don't check the reconstructed root is equal to expected_root - // will be fixed later, but for that test check the other tests in this crate #[test] fn test_build_circuit() -> anyhow::Result<()> { // circuit params @@ -165,7 +175,7 @@ mod tests { // create the circuit let config = CircuitConfig::standard_recursion_config(); let mut builder = CircuitBuilder::::new(config); - let (mut targets, reconstructed_root_target) = build_circuit(&mut builder, max_depth); + let (mut targets, reconstructed_root_target) = build_circuit(&mut builder, max_depth)?; // expected Merkle root let expected_root = builder.add_virtual_hash(); @@ -241,7 +251,7 @@ mod tests { let config = CircuitConfig::standard_recursion_config(); let mut builder = CircuitBuilder::::new(config); - let (mut targets, reconstructed_root_target) = build_circuit(&mut builder, max_depth); + let (mut targets, reconstructed_root_target) = build_circuit(&mut builder, max_depth)?; // expected Merkle root let expected_root_target = builder.add_virtual_hash(); diff --git a/proof-input/src/tests/cyclic_recursion.rs b/proof-input/src/recursion/cyclic_recursion.rs similarity index 94% rename from proof-input/src/tests/cyclic_recursion.rs rename to proof-input/src/recursion/cyclic_recursion.rs index b535738..1f6bf6e 100644 --- a/proof-input/src/tests/cyclic_recursion.rs +++ b/proof-input/src/recursion/cyclic_recursion.rs @@ -10,12 +10,12 @@ mod tests { use plonky2::plonk::circuit_builder::CircuitBuilder; use plonky2::plonk::circuit_data::CircuitConfig; use plonky2::plonk::config::GenericConfig; - use codex_plonky2_circuits::recursion::params::{F, D, C, Plonky2Proof}; - use codex_plonky2_circuits::recursion::sampling_inner_circuit::SamplingRecursion; + use codex_plonky2_circuits::params::{F, D, C, Plonky2Proof}; + use codex_plonky2_circuits::recursion::circuits::sampling_inner_circuit::SamplingRecursion; use plonky2_poseidon2::poseidon2_hash::poseidon2::{Poseidon2, Poseidon2Hash}; use crate::gen_input::gen_testing_circuit_input; use crate::params::TestParams; - use codex_plonky2_circuits::recursion::cyclic_recursion::CyclicCircuit; + use codex_plonky2_circuits::recursion::cyclic::CyclicCircuit; /// Uses cyclic recursion to sample the dataset diff --git a/proof-input/src/tests/simple_recursion.rs b/proof-input/src/recursion/simple_recursion.rs similarity index 94% rename from proof-input/src/tests/simple_recursion.rs rename to proof-input/src/recursion/simple_recursion.rs index e0fc346..6c40935 100644 --- a/proof-input/src/tests/simple_recursion.rs +++ b/proof-input/src/recursion/simple_recursion.rs @@ -6,9 +6,9 @@ use plonky2::iop::witness::PartialWitness; use plonky2::plonk::circuit_builder::CircuitBuilder; use plonky2::plonk::circuit_data::{CircuitConfig, CircuitData}; use plonky2_field::types::Field; -use codex_plonky2_circuits::recursion::sampling_inner_circuit::SamplingRecursion; -use codex_plonky2_circuits::recursion::simple_recursion::{aggregate_sampling_proofs,SimpleRecursionCircuit, SimpleRecursionInput}; -use codex_plonky2_circuits::recursion::simple_tree_recursion::aggregate_sampling_proofs_tree; +use codex_plonky2_circuits::recursion::circuits::sampling_inner_circuit::SamplingRecursion; +use codex_plonky2_circuits::recursion::simple::simple_recursion::{aggregate_sampling_proofs, SimpleRecursionCircuit, SimpleRecursionInput}; +use codex_plonky2_circuits::recursion::simple::simple_tree_recursion::aggregate_sampling_proofs_tree; use plonky2_poseidon2::serialization::{DefaultGateSerializer, DefaultGeneratorSerializer}; use crate::gen_input::{build_circuit, prove_circuit}; use crate::json::write_bytes_to_file; diff --git a/workflow/benches/simple_recursion.rs b/workflow/benches/simple_recursion.rs index 1c6990f..070d30f 100644 --- a/workflow/benches/simple_recursion.rs +++ b/workflow/benches/simple_recursion.rs @@ -4,7 +4,7 @@ use plonky2::iop::witness::PartialWitness; use plonky2::plonk::circuit_builder::CircuitBuilder; use plonky2::plonk::circuit_data::{CircuitConfig, CircuitData}; use plonky2::plonk::config::GenericConfig; -use codex_plonky2_circuits::recursion::simple_recursion::aggregate_sampling_proofs; +use codex_plonky2_circuits::recursion::simple::simple_recursion::aggregate_sampling_proofs; use proof_input::params::{D, C, F, Params, TestParams}; use proof_input::gen_input::{build_circuit, prove_circuit}; diff --git a/workflow/benches/simple_tree_recursion.rs b/workflow/benches/simple_tree_recursion.rs index e76bd24..504a1e4 100644 --- a/workflow/benches/simple_tree_recursion.rs +++ b/workflow/benches/simple_tree_recursion.rs @@ -2,7 +2,7 @@ use criterion::{Criterion, criterion_group, criterion_main}; use plonky2::plonk::circuit_data::VerifierCircuitData; use plonky2::plonk::config::GenericConfig; use plonky2::plonk::proof::ProofWithPublicInputs; -use codex_plonky2_circuits::recursion::simple_tree_recursion::aggregate_sampling_proofs_tree2; +use codex_plonky2_circuits::recursion::simple::simple_tree_recursion::aggregate_sampling_proofs_tree2; use proof_input::params::{C, D, F}; use proof_input::gen_input::{build_circuit, prove_circuit};