From d63a309e02984d73dc0fb364e5f6ac3d7f12af67 Mon Sep 17 00:00:00 2001 From: M Alghazwi Date: Fri, 10 Jan 2025 12:32:45 +0100 Subject: [PATCH] remove the hardcoded circuit params and refactor --- codex-plonky2-circuits/src/circuits/params.rs | 28 ++------ codex-plonky2-circuits/src/lib.rs | 1 - codex-plonky2-circuits/src/recursion/mod.rs | 3 +- .../src/{ => recursion}/params.rs | 0 proof-input/src/data_structs.rs | 18 ++--- proof-input/src/gen_input.rs | 38 ++++++----- proof-input/src/params.rs | 67 +++++++------------ .../src/serialization/circuit_input.rs | 4 +- proof-input/src/serialization/json.rs | 28 ++++---- workflow/src/bin/gen_input.rs | 2 +- 10 files changed, 77 insertions(+), 112 deletions(-) rename codex-plonky2-circuits/src/{ => recursion}/params.rs (100%) diff --git a/codex-plonky2-circuits/src/circuits/params.rs b/codex-plonky2-circuits/src/circuits/params.rs index a613a71..3dc9b60 100644 --- a/codex-plonky2-circuits/src/circuits/params.rs +++ b/codex-plonky2-circuits/src/circuits/params.rs @@ -14,26 +14,6 @@ pub struct CircuitParams{ pub n_samples: usize, } -// hardcoded default constants -const DEFAULT_MAX_DEPTH:usize = 32; -const DEFAULT_MAX_LOG2_N_SLOTS:usize = 8; -const DEFAULT_BLOCK_TREE_DEPTH:usize = 5; -const DEFAULT_N_FIELD_ELEMS_PER_CELL:usize = 272; -const DEFAULT_N_SAMPLES:usize = 5; - -/// Implement the Default trait for Params using the hardcoded constants -impl Default for CircuitParams { - fn default() -> Self { - Self{ - max_depth: DEFAULT_MAX_DEPTH, - max_log2_n_slots: DEFAULT_MAX_LOG2_N_SLOTS, - block_tree_depth: DEFAULT_BLOCK_TREE_DEPTH, - n_field_elems_per_cell: DEFAULT_N_FIELD_ELEMS_PER_CELL, - n_samples: DEFAULT_N_SAMPLES, - } - } -} - impl CircuitParams { /// Creates a new `CircuitParams` struct from environment. /// @@ -45,12 +25,12 @@ impl CircuitParams { /// /// Returns an error if any environment variable is missing or fails to parse. pub fn from_env() -> Result { - let MAX_DEPTH = env::var("MAX_DEPTH") + let max_depth = env::var("MAX_DEPTH") .context("MAX_DEPTH is not set")? .parse::() .context("MAX_DEPTH must be a valid usize")?; - let MAX_LOG2_N_SLOTS = env::var("MAX_LOG2_N_SLOTS") + let max_log2_n_slots = env::var("MAX_LOG2_N_SLOTS") .context("MAX_LOG2_N_SLOTS is not set")? .parse::() .context("MAX_LOG2_N_SLOTS must be a valid usize")?; @@ -71,8 +51,8 @@ impl CircuitParams { .context("N_SAMPLES must be a valid usize")?; Ok(CircuitParams { - max_depth: MAX_DEPTH, - max_log2_n_slots: MAX_LOG2_N_SLOTS, + max_depth, + max_log2_n_slots, block_tree_depth, n_field_elems_per_cell, n_samples, diff --git a/codex-plonky2-circuits/src/lib.rs b/codex-plonky2-circuits/src/lib.rs index ea4f9b8..27a298c 100644 --- a/codex-plonky2-circuits/src/lib.rs +++ b/codex-plonky2-circuits/src/lib.rs @@ -2,6 +2,5 @@ pub mod circuits; // pub mod merkle_tree; // pub mod recursion; pub mod error; -pub mod params; pub type Result = core::result::Result; diff --git a/codex-plonky2-circuits/src/recursion/mod.rs b/codex-plonky2-circuits/src/recursion/mod.rs index 51b767b..cf65683 100644 --- a/codex-plonky2-circuits/src/recursion/mod.rs +++ b/codex-plonky2-circuits/src/recursion/mod.rs @@ -2,4 +2,5 @@ pub mod cyclic; pub mod circuits; pub mod simple; pub mod tree1; -pub mod tree2; \ No newline at end of file +pub mod tree2; +pub mod params; diff --git a/codex-plonky2-circuits/src/params.rs b/codex-plonky2-circuits/src/recursion/params.rs similarity index 100% rename from codex-plonky2-circuits/src/params.rs rename to codex-plonky2-circuits/src/recursion/params.rs diff --git a/proof-input/src/data_structs.rs b/proof-input/src/data_structs.rs index 342add9..55e8134 100644 --- a/proof-input/src/data_structs.rs +++ b/proof-input/src/data_structs.rs @@ -6,7 +6,7 @@ use codex_plonky2_circuits::circuits::sample_cells::Cell; use plonky2_field::types::Sample; use plonky2_poseidon2::poseidon2_hash::poseidon2::Poseidon2; use crate::merkle_tree::merkle_safe::{MerkleProof, MerkleTree}; -use crate::params::{TestParams, HF}; +use crate::params::{InputParams,Params, HF}; use crate::sponge::hash_bytes_no_padding; use crate::utils::{bits_le_padded_to_usize, calculate_cell_index_bits, usize_to_bits_le}; @@ -19,7 +19,7 @@ pub struct SlotTree< pub tree: MerkleTree, // slot tree pub block_trees: Vec>, // vec of block trees pub cell_data: Vec>, // cell data as field elements - pub params: TestParams, // parameters + pub params: InputParams, // parameters } impl< @@ -27,7 +27,7 @@ impl< const D: usize, > SlotTree { /// Create a slot tree with fake data, for testing only - pub fn new_default(params: &TestParams) -> Self { + pub fn new_default(params: &InputParams) -> Self { // generate fake cell data let cell_data = (0..params.n_cells) .map(|_| new_random_cell(params)) @@ -36,7 +36,7 @@ impl< } /// Create a new slot tree with the supplied cell data and parameters - pub fn new(cells: Vec>, params: TestParams) -> Self { + pub fn new(cells: Vec>, params: InputParams) -> Self { let leaves: Vec> = cells .iter() .map(|element| hash_bytes_no_padding::(&element.data)) @@ -106,7 +106,7 @@ pub struct DatasetTree< > { pub tree: MerkleTree, // dataset tree pub slot_trees: Vec>, // vec of slot trees - pub params: TestParams, // parameters + pub params: InputParams, // parameters } /// Dataset Merkle proof struct, containing the dataset proof and sampled proofs. @@ -127,7 +127,7 @@ impl< const D: usize, > DatasetTree { /// Dataset tree with fake data, for testing only - pub fn new_default(params: &TestParams) -> Self { + pub fn new_default(params: &InputParams) -> Self { let mut slot_trees = vec![]; let n_slots = 1 << params.dataset_depth_test(); for _ in 0..n_slots { @@ -137,7 +137,7 @@ impl< } /// Create data for only the specified slot index in params - pub fn new_for_testing(params: &TestParams) -> Self { + pub fn new_for_testing(params: &InputParams) -> Self { let mut slot_trees = vec![]; // let n_slots = 1 << params.dataset_depth(); let n_slots = params.n_slots; @@ -172,7 +172,7 @@ impl< } /// Same as default but with supplied slot trees - pub fn new(slot_trees: Vec>, params: TestParams) -> Self { + pub fn new(slot_trees: Vec>, params: InputParams) -> Self { // get the roots of slot trees let slot_roots = slot_trees .iter() @@ -248,7 +248,7 @@ impl< pub fn new_random_cell< F: RichField + Extendable + Poseidon2, const D: usize, ->(params: &TestParams) -> Cell { +>(params: &InputParams) -> Cell { let data = (0..params.n_field_elems_per_cell()) .map(|_| F::rand()) .collect::>(); diff --git a/proof-input/src/gen_input.rs b/proof-input/src/gen_input.rs index 2f94c57..041ab8b 100644 --- a/proof-input/src/gen_input.rs +++ b/proof-input/src/gen_input.rs @@ -4,7 +4,7 @@ use plonky2_field::extension::Extendable; use plonky2_field::types::Field; use plonky2_poseidon2::poseidon2_hash::poseidon2::Poseidon2; use codex_plonky2_circuits::circuits::params::CircuitParams; -use crate::params::TestParams; +use crate::params::{Params,InputParams}; use crate::utils::{bits_le_padded_to_usize, calculate_cell_index_bits, ceiling_log2, usize_to_bits_le}; use crate::merkle_tree::merkle_safe::MerkleProof; use codex_plonky2_circuits::circuits::sample_cells::{MerklePath, SampleCircuit, SampleCircuitInput, SampleTargets}; @@ -21,7 +21,7 @@ use crate::params::{C, D, F, HF}; pub fn gen_testing_circuit_input< F: RichField + Extendable + Poseidon2, const D: usize, ->(params: &TestParams) -> SampleCircuitInput{ +>(params: &InputParams) -> SampleCircuitInput{ let dataset_t = DatasetTree::::new_for_testing(¶ms); let slot_index = params.testing_slot_index; // samples the specified slot @@ -57,7 +57,7 @@ pub fn gen_testing_circuit_input< pub fn verify_circuit_input< F: RichField + Extendable + Poseidon2, const D: usize, ->(circ_input: SampleCircuitInput, params: &TestParams) -> bool{ +>(circ_input: SampleCircuitInput, params: &InputParams) -> bool{ let slot_index = circ_input.slot_index.to_canonical_u64(); let slot_root = circ_input.slot_root.clone(); // check dataset level proof @@ -102,7 +102,7 @@ pub fn verify_circuit_input< pub fn verify_cell_proof< F: RichField + Extendable + Poseidon2, const D: usize, ->(circ_input: &SampleCircuitInput, params: &TestParams, cell_index: usize, ctr: usize) -> anyhow::Result { +>(circ_input: &SampleCircuitInput, params: &InputParams, cell_index: usize, ctr: usize) -> anyhow::Result { let mut block_path_bits = usize_to_bits_le(cell_index, params.max_depth); let last_index = params.n_cells - 1; let mut block_last_bits = usize_to_bits_le(last_index, params.max_depth); @@ -156,16 +156,17 @@ pub fn build_circuit(n_samples: usize, slot_index: usize) -> anyhow::Result<(Cir /// returns the proof, circuit data, and targets pub fn build_circuit_with_targets(n_samples: usize, slot_index: usize) -> anyhow::Result<(CircuitData, PartialWitness, SampleTargets)>{ // get input - let mut params = TestParams::default(); - params.n_samples = n_samples; - params.testing_slot_index = slot_index; - let circ_input = gen_testing_circuit_input::(¶ms); + let mut params = Params::default(); + let mut input_params = params.input_params; + input_params.n_samples = n_samples; + input_params.testing_slot_index = slot_index; + let circ_input = gen_testing_circuit_input::(&input_params); // Create the circuit let config = CircuitConfig::standard_recursion_config(); let mut builder = CircuitBuilder::::new(config); - let mut circuit_params = CircuitParams::default(); + let mut circuit_params = params.circuit_params; circuit_params.n_samples = n_samples; // build the circuit @@ -193,10 +194,10 @@ pub fn prove_circuit(data: &CircuitData, pw: &PartialWitness) -> any } /// returns exactly M default circuit input -pub fn get_m_default_circ_input() -> [SampleCircuitInput; M]{ - let params = TestParams::default(); - let one_circ_input = gen_testing_circuit_input::(¶ms); - let circ_input: [SampleCircuitInput; M] = (0..M) +pub fn get_m_default_circ_input() -> [SampleCircuitInput; M]{ + let params = Params::default().input_params; + let one_circ_input = gen_testing_circuit_input::(¶ms); + let circ_input: [SampleCircuitInput; M] = (0..M) .map(|_| one_circ_input.clone()) .collect::>() .try_into().unwrap(); @@ -217,7 +218,7 @@ mod tests { // Test sample cells (non-circuit) #[test] fn test_gen_verify_proof(){ - let params = TestParams::default(); + let params = Params::default().input_params; let w = gen_testing_circuit_input::(¶ms); assert!(verify_circuit_input::(w, ¶ms)); } @@ -226,15 +227,16 @@ mod tests { #[test] fn test_proof_in_circuit() -> anyhow::Result<()> { // get input - let mut params = TestParams::default(); - params.n_samples = 10; - let circ_input = gen_testing_circuit_input::(¶ms); + let mut params = Params::default(); + let mut input_params = params.input_params; + input_params.n_samples = 10; + let circ_input = gen_testing_circuit_input::(&input_params); // Create the circuit let config = CircuitConfig::standard_recursion_config(); let mut builder = CircuitBuilder::::new(config); - let mut circuit_params = CircuitParams::default(); + let mut circuit_params = params.circuit_params; circuit_params.n_samples = 10; // build the circuit diff --git a/proof-input/src/params.rs b/proof-input/src/params.rs index 1244352..47a0796 100644 --- a/proof-input/src/params.rs +++ b/proof-input/src/params.rs @@ -13,7 +13,6 @@ pub const D: usize = 2; pub type C = PoseidonGoldilocksConfig; pub type F = >::F; // this is the goldilocks field pub type HF = PoseidonHash; -// pub type HP = >::Permutation; // hardcoded default params for generating proof input const DEFAULT_MAX_DEPTH: usize = 32; // depth of big tree (slot tree depth, includes block tree depth) @@ -33,12 +32,12 @@ const DEFAULT_N_CELLS: usize = 512; // number of cells in each slot #[derive(Clone)] pub struct Params { pub circuit_params: CircuitParams, - pub test: TestParams, + pub input_params: InputParams, } /// test params #[derive(Clone)] -pub struct TestParams{ +pub struct InputParams{ pub max_depth: usize, pub max_slots: usize, pub cell_size: usize, @@ -52,9 +51,9 @@ pub struct TestParams{ } /// Implement the Default trait for Params using the hardcoded constants -impl Default for TestParams { +impl Default for Params { fn default() -> Self { - TestParams { + let input_params = InputParams { max_depth: DEFAULT_MAX_DEPTH, max_slots: DEFAULT_MAX_SLOTS, cell_size: DEFAULT_CELL_SIZE, @@ -65,37 +64,18 @@ impl Default for TestParams { n_slots: DEFAULT_N_SLOTS, testing_slot_index: DEFAULT_SLOT_INDEX, n_cells: DEFAULT_N_CELLS, + }; + let circuit_params = input_params.get_circuit_params(); + + Params{ + circuit_params, + input_params, } } } /// Implement a new function to create Params with custom values -impl TestParams { - pub fn new( - max_depth: usize, - max_slots: usize, - cell_size: usize, - block_size: usize, - n_samples: usize, - entropy: usize, - seed: usize, - n_slots: usize, - testing_slot_index: usize, - n_cells: usize, - ) -> Self { - TestParams { - max_depth, - max_slots, - cell_size, - block_size, - n_samples, - entropy, - seed, - n_slots, - testing_slot_index, - n_cells, - } - } +impl InputParams { // GOLDILOCKS_F_SIZE pub fn goldilocks_f_size(&self) -> usize { 64 @@ -141,6 +121,15 @@ impl TestParams { ceiling_log2(self.n_slots) } + pub fn get_circuit_params(&self) -> CircuitParams{ + CircuitParams{ + max_depth: self.max_depth, + max_log2_n_slots: self.dataset_max_depth(), + block_tree_depth: self.bot_depth(), + n_field_elems_per_cell: self.n_field_elems_per_cell(), + n_samples:self.n_samples, + } + } } pub fn log2(x: usize) -> usize { @@ -156,7 +145,7 @@ pub fn ceiling_log2(x: usize) -> usize { } /// load test params from env -impl TestParams { +impl InputParams { pub fn from_env() -> Result { let max_depth = env::var("MAXDEPTH") .context("MAXDEPTH not set")? @@ -208,7 +197,7 @@ impl TestParams { .parse::() .context("Invalid NCELLS")?; - Ok(TestParams { + Ok(InputParams { max_depth, max_slots, cell_size, @@ -226,18 +215,12 @@ impl TestParams { /// load params from env impl Params { pub fn from_env() -> Result { - let test_params = TestParams::from_env()?; - let circuit_params = CircuitParams{ - max_depth: test_params.max_depth, - max_log2_n_slots: test_params.dataset_max_depth(), - block_tree_depth: test_params.bot_depth(), - n_field_elems_per_cell: test_params.n_field_elems_per_cell(), - n_samples:test_params.n_samples, - }; + let input_params = InputParams::from_env()?; + let circuit_params = input_params.get_circuit_params(); Ok(Params{ circuit_params, - test: test_params, + input_params, }) } } \ No newline at end of file diff --git a/proof-input/src/serialization/circuit_input.rs b/proof-input/src/serialization/circuit_input.rs index cd8b639..e30c3a1 100644 --- a/proof-input/src/serialization/circuit_input.rs +++ b/proof-input/src/serialization/circuit_input.rs @@ -8,7 +8,7 @@ use std::fs::File; use std::io::{BufReader, Write}; use plonky2_field::types::{Field, PrimeField64}; use crate::gen_input::gen_testing_circuit_input; -use crate::params::TestParams; +use crate::params::InputParams; /// export circuit input to json file pub fn export_circ_input_to_json< @@ -32,7 +32,7 @@ pub fn export_circ_input_to_json< pub fn generate_and_export_circ_input_to_json< F: RichField + Extendable + Poseidon2 + Serialize, const D: usize, ->(params: &TestParams, filename: &str) -> anyhow::Result<()> { +>(params: &InputParams, filename: &str) -> anyhow::Result<()> { let circ_input = gen_testing_circuit_input::(params); diff --git a/proof-input/src/serialization/json.rs b/proof-input/src/serialization/json.rs index 25f2022..be58739 100644 --- a/proof-input/src/serialization/json.rs +++ b/proof-input/src/serialization/json.rs @@ -12,7 +12,7 @@ use plonky2_poseidon2::poseidon2_hash::poseidon2::Poseidon2; use codex_plonky2_circuits::circuits::sample_cells::SampleCircuitInput; use plonky2::plonk::proof::CompressedProofWithPublicInputs; use serde_json::to_writer_pretty; -use crate::params::TestParams; +use crate::params::InputParams; // Function to export proof with public input to json file fn export_proof_with_pi_to_json( @@ -46,7 +46,7 @@ pub fn read_bytes_from_file>(path: P) -> io::Result> { #[cfg(test)] mod tests { use super::*; - use crate::params::{C, D, F, HF}; + use crate::params::{C, D, F, HF, InputParams, Params}; use std::time::Instant; use codex_plonky2_circuits::circuits::params::CircuitParams; use codex_plonky2_circuits::circuits::sample_cells::SampleCircuit; @@ -61,7 +61,7 @@ mod tests { #[test] fn test_export_circ_input_to_json() -> anyhow::Result<()> { // Create Params - let params = TestParams::default(); + let params = Params::default().input_params; // Export the circuit input to JSON generate_and_export_circ_input_to_json::(¶ms, "input.json")?; @@ -84,7 +84,7 @@ mod tests { #[test] fn test_export_import_circ_input() -> anyhow::Result<()> { // Create Params instance - let params = TestParams::default(); + let params = Params::default().input_params; // Export the circuit input to JSON let original_circ_input = gen_testing_circuit_input(¶ms); @@ -109,13 +109,11 @@ mod tests { // reads the json input from file and runs the circuit #[test] fn test_read_json_and_run_circuit() -> anyhow::Result<()> { - let params = TestParams::default(); - // Create the circuit let config = CircuitConfig::standard_recursion_config(); let mut builder = CircuitBuilder::::new(config); - let circuit_params = CircuitParams::default(); + let circuit_params = Params::default().circuit_params; let circ = SampleCircuit::::new(circuit_params.clone()); let mut targets = circ.sample_slot_circuit_with_public_input(&mut builder)?; @@ -152,7 +150,7 @@ mod tests { // NOTE: expects that the json input proof uses the default params #[test] fn test_read_json_and_verify() -> anyhow::Result<()> { - let params = TestParams::default(); + let params = Params::default().input_params; // Import the circuit input from JSON let imported_circ_input: SampleCircuitInput = import_circ_input_from_json("input.json")?; @@ -171,13 +169,14 @@ mod tests { // test out custom default gate and generator serializers to export/import circuit data #[test] fn test_circuit_data_serializer() -> anyhow::Result<()> { - let params = TestParams::default(); + let params = Params::default(); + let input_params = params.input_params; // Create the circuit let config = CircuitConfig::standard_recursion_config(); let mut builder = CircuitBuilder::::new(config); - let circuit_params = CircuitParams::default(); + let circuit_params = params.circuit_params; let circ = SampleCircuit::::new(circuit_params.clone()); let mut targets = circ.sample_slot_circuit_with_public_input(&mut builder)?; @@ -185,7 +184,7 @@ mod tests { let mut pw = PartialWitness::new(); // gen circ input - let imported_circ_input: SampleCircuitInput = gen_testing_circuit_input::(¶ms); + let imported_circ_input: SampleCircuitInput = gen_testing_circuit_input::(&input_params); circ.sample_slot_assign_witness(&mut pw, &targets, &imported_circ_input)?; // Build the circuit @@ -223,13 +222,14 @@ mod tests { // test proof with public input serialization #[test] fn test_proof_with_pi_serializer() -> anyhow::Result<()> { - let params = TestParams::default(); + let params = Params::default(); + let input_params = params.input_params; // Create the circuit let config = CircuitConfig::standard_recursion_config(); let mut builder = CircuitBuilder::::new(config); - let circuit_params = CircuitParams::default(); + let circuit_params = params.circuit_params; let circ = SampleCircuit::::new(circuit_params.clone()); let mut targets = circ.sample_slot_circuit_with_public_input(&mut builder)?; @@ -237,7 +237,7 @@ mod tests { let mut pw = PartialWitness::new(); // gen circ input - let imported_circ_input: SampleCircuitInput = gen_testing_circuit_input::(¶ms); + let imported_circ_input: SampleCircuitInput = gen_testing_circuit_input::(&input_params); circ.sample_slot_assign_witness(&mut pw, &targets, &imported_circ_input)?; // Build the circuit diff --git a/workflow/src/bin/gen_input.rs b/workflow/src/bin/gen_input.rs index 1a8bd84..ff5c26e 100644 --- a/workflow/src/bin/gen_input.rs +++ b/workflow/src/bin/gen_input.rs @@ -10,7 +10,7 @@ fn main() -> Result<()> { let params = Params::from_env()?; // generate circuit input with given parameters - let circ_input = gen_testing_circuit_input::(¶ms.test); + let circ_input = gen_testing_circuit_input::(¶ms.input_params); // export circuit parameters to json file let filename= "input.json";