diff --git a/src/bin/bench_recursion.rs b/src/bin/bench_recursion.rs index 650bee11..5c082d86 100644 --- a/src/bin/bench_recursion.rs +++ b/src/bin/bench_recursion.rs @@ -4,6 +4,7 @@ use log::info; use plonky2::field::crandall_field::CrandallField; use plonky2::field::extension_field::Extendable; use plonky2::field::field_types::RichField; +use plonky2::fri::reduction_strategies::FriReductionStrategy; use plonky2::fri::FriConfig; use plonky2::hash::hashing::SPONGE_WIDTH; use plonky2::iop::witness::PartialWitness; @@ -31,7 +32,7 @@ fn bench_prove, const D: usize>() -> Result<()> { cap_height: 1, fri_config: FriConfig { proof_of_work_bits: 15, - reduction_arity_bits: vec![2, 2, 2, 2, 2, 2], + reduction_strategy: FriReductionStrategy::ConstantArityBits(3, 5), num_query_rounds: 35, }, }; diff --git a/src/fri/commitment.rs b/src/fri/commitment.rs index c704c33c..8233a293 100644 --- a/src/fri/commitment.rs +++ b/src/fri/commitment.rs @@ -216,7 +216,7 @@ impl PolynomialBatchCommitment { lde_final_poly, lde_final_values, challenger, - &config, + &common_data, timing, ); diff --git a/src/fri/mod.rs b/src/fri/mod.rs index 5e8936fa..4aebbeec 100644 --- a/src/fri/mod.rs +++ b/src/fri/mod.rs @@ -1,24 +1,34 @@ +use crate::fri::reduction_strategies::FriReductionStrategy; + pub mod commitment; pub mod proof; pub mod prover; pub mod recursive_verifier; +pub mod reduction_strategies; pub mod verifier; #[derive(Debug, Clone, Eq, PartialEq)] pub struct FriConfig { pub proof_of_work_bits: u32, - /// The arity of each FRI reduction step, expressed (i.e. the log2 of the actual arity). - /// For example, `[3, 2, 1]` would describe a FRI reduction tree with 8-to-1 reduction, then - /// a 4-to-1 reduction, then a 2-to-1 reduction. After these reductions, the reduced polynomial - /// is sent directly. - pub reduction_arity_bits: Vec, + pub reduction_strategy: FriReductionStrategy, /// Number of query rounds to perform. pub num_query_rounds: usize, } -impl FriConfig { +/// Parameters which are generated during preprocessing, in contrast to `FriConfig` which is +/// user-specified. +#[derive(Debug)] +pub(crate) struct FriParams { + /// The arity of each FRI reduction step, expressed as the log2 of the actual arity. + /// For example, `[3, 2, 1]` would describe a FRI reduction tree with 8-to-1 reduction, then + /// a 4-to-1 reduction, then a 2-to-1 reduction. After these reductions, the reduced polynomial + /// is sent directly. + pub reduction_arity_bits: Vec, +} + +impl FriParams { pub(crate) fn total_arities(&self) -> usize { self.reduction_arity_bits.iter().sum() } diff --git a/src/fri/proof.rs b/src/fri/proof.rs index 62e836af..793d0798 100644 --- a/src/fri/proof.rs +++ b/src/fri/proof.rs @@ -139,7 +139,7 @@ impl, const D: usize> FriProof { .. } = self; let cap_height = common_data.config.cap_height; - let reduction_arity_bits = &common_data.config.fri_config.reduction_arity_bits; + let reduction_arity_bits = &common_data.fri_params.reduction_arity_bits; let num_reductions = reduction_arity_bits.len(); let num_initial_trees = query_round_proofs[0].initial_trees_proof.evals_proofs.len(); @@ -241,7 +241,7 @@ impl, const D: usize> CompressedFriProof { .. } = self; let cap_height = common_data.config.cap_height; - let reduction_arity_bits = &common_data.config.fri_config.reduction_arity_bits; + let reduction_arity_bits = &common_data.fri_params.reduction_arity_bits; let num_reductions = reduction_arity_bits.len(); let num_initial_trees = query_round_proofs .initial_trees_proofs diff --git a/src/fri/prover.rs b/src/fri/prover.rs index a0d71d98..d0fc1f0c 100644 --- a/src/fri/prover.rs +++ b/src/fri/prover.rs @@ -8,7 +8,7 @@ use crate::hash::hash_types::HashOut; use crate::hash::hashing::hash_n_to_1; use crate::hash::merkle_tree::MerkleTree; use crate::iop::challenger::Challenger; -use crate::plonk::circuit_data::CircuitConfig; +use crate::plonk::circuit_data::CommonCircuitData; use crate::plonk::plonk_common::reduce_with_powers; use crate::polynomial::polynomial::{PolynomialCoeffs, PolynomialValues}; use crate::timed; @@ -23,7 +23,7 @@ pub fn fri_proof, const D: usize>( // Evaluation of the polynomial on the large domain. lde_polynomial_values: PolynomialValues, challenger: &mut Challenger, - config: &CircuitConfig, + common_data: &CommonCircuitData, timing: &mut TimingTree, ) -> FriProof { let n = lde_polynomial_values.values.len(); @@ -37,7 +37,7 @@ pub fn fri_proof, const D: usize>( lde_polynomial_coeffs, lde_polynomial_values, challenger, - config, + common_data, ) ); @@ -46,17 +46,12 @@ pub fn fri_proof, const D: usize>( let pow_witness = timed!( timing, "find for proof-of-work witness", - fri_proof_of_work(current_hash, &config.fri_config) + fri_proof_of_work(current_hash, &common_data.config.fri_config) ); // Query phase - let query_round_proofs = fri_prover_query_rounds( - initial_merkle_trees, - &trees, - challenger, - n, - &config.fri_config, - ); + let query_round_proofs = + fri_prover_query_rounds(initial_merkle_trees, &trees, challenger, n, common_data); FriProof { commit_phase_merkle_caps: trees.iter().map(|t| t.cap.clone()).collect(), @@ -70,14 +65,15 @@ fn fri_committed_trees, const D: usize>( mut coeffs: PolynomialCoeffs, mut values: PolynomialValues, challenger: &mut Challenger, - config: &CircuitConfig, + common_data: &CommonCircuitData, ) -> (Vec>, PolynomialCoeffs) { + let config = &common_data.config; let mut trees = Vec::new(); let mut shift = F::MULTIPLICATIVE_GROUP_GENERATOR; - let num_reductions = config.fri_config.reduction_arity_bits.len(); + let num_reductions = common_data.fri_params.reduction_arity_bits.len(); for i in 0..num_reductions { - let arity = 1 << config.fri_config.reduction_arity_bits[i]; + let arity = 1 << common_data.fri_params.reduction_arity_bits[i]; reverse_index_bits_in_place(&mut values.values); let chunked_values = values @@ -136,10 +132,10 @@ fn fri_prover_query_rounds, const D: usize>( trees: &[MerkleTree], challenger: &mut Challenger, n: usize, - config: &FriConfig, + common_data: &CommonCircuitData, ) -> Vec> { - (0..config.num_query_rounds) - .map(|_| fri_prover_query_round(initial_merkle_trees, trees, challenger, n, config)) + (0..common_data.config.fri_config.num_query_rounds) + .map(|_| fri_prover_query_round(initial_merkle_trees, trees, challenger, n, common_data)) .collect() } @@ -148,7 +144,7 @@ fn fri_prover_query_round, const D: usize>( trees: &[MerkleTree], challenger: &mut Challenger, n: usize, - config: &FriConfig, + common_data: &CommonCircuitData, ) -> FriQueryRound { let mut query_steps = Vec::new(); let x = challenger.get_challenge(); @@ -158,7 +154,7 @@ fn fri_prover_query_round, const D: usize>( .map(|t| (t.get(x_index).to_vec(), t.prove(x_index))) .collect::>(); for (i, tree) in trees.iter().enumerate() { - let arity_bits = config.reduction_arity_bits[i]; + let arity_bits = common_data.fri_params.reduction_arity_bits[i]; let evals = unflatten(tree.get(x_index >> arity_bits)); let merkle_proof = tree.prove(x_index >> arity_bits); diff --git a/src/fri/recursive_verifier.rs b/src/fri/recursive_verifier.rs index 78d5f007..51135351 100644 --- a/src/fri/recursive_verifier.rs +++ b/src/fri/recursive_verifier.rs @@ -122,7 +122,7 @@ impl, const D: usize> CircuitBuilder { "Number of query rounds does not match config." ); debug_assert!( - !config.fri_config.reduction_arity_bits.is_empty(), + !common_data.fri_params.reduction_arity_bits.is_empty(), "Number of reductions should be non-zero." ); @@ -325,7 +325,12 @@ impl, const D: usize> CircuitBuilder { ) ); - for (i, &arity_bits) in config.fri_config.reduction_arity_bits.iter().enumerate() { + for (i, &arity_bits) in common_data + .fri_params + .reduction_arity_bits + .iter() + .enumerate() + { let evals = &round_proof.steps[i].evals; // Split x_index into the index of the coset x is in, and the index of x within that coset. @@ -376,7 +381,10 @@ impl, const D: usize> CircuitBuilder { // to the one sent by the prover. let eval = with_context!( self, - "evaluate final polynomial", + &format!( + "evaluate final polynomial of length {}", + proof.final_poly.len() + ), proof.final_poly.eval_scalar(self, subgroup_x) ); self.connect_extension(eval, old_eval); diff --git a/src/fri/reduction_strategies.rs b/src/fri/reduction_strategies.rs new file mode 100644 index 00000000..8bc708c6 --- /dev/null +++ b/src/fri/reduction_strategies.rs @@ -0,0 +1,136 @@ +use std::time::Instant; + +use log::debug; + +/// A method for deciding what arity to use at each reduction layer. +#[derive(Debug, Clone, Eq, PartialEq)] +pub enum FriReductionStrategy { + /// `ConstantArityBits(arity_bits, final_poly_bits)` applies reductions of arity `2^arity_bits` + /// until the polynomial degree is `2^final_poly_bits` or less. This tends to work well in the + /// recursive setting, as it avoids needing multiple configurations of gates used in FRI + /// verification, such as `InterpolationGate`. + ConstantArityBits(usize, usize), + + /// Optimize for size. + MinSize, +} + +impl FriReductionStrategy { + /// The arity of each FRI reduction step, expressed as the log2 of the actual arity. + pub(crate) fn reduction_arity_bits( + &self, + mut degree_bits: usize, + rate_bits: usize, + num_queries: usize, + ) -> Vec { + match self { + &FriReductionStrategy::ConstantArityBits(arity_bits, final_poly_bits) => { + let mut result = Vec::new(); + while result.is_empty() || degree_bits > final_poly_bits { + result.push(arity_bits); + assert!(degree_bits >= arity_bits); + degree_bits -= arity_bits; + } + result.shrink_to_fit(); + result + } + + &FriReductionStrategy::MinSize => { + min_size_arity_bits(degree_bits, rate_bits, num_queries) + } + } + } +} + +fn min_size_arity_bits(degree_bits: usize, rate_bits: usize, num_queries: usize) -> Vec { + let start = Instant::now(); + let (mut arity_bits, fri_proof_size) = + min_size_arity_bits_helper(degree_bits, rate_bits, num_queries, vec![]); + arity_bits.shrink_to_fit(); + + debug!( + "min_size_arity_bits took {:.3}s", + start.elapsed().as_secs_f32() + ); + debug!( + "Smallest arity_bits {:?} results in estimated FRI proof size of {} elements", + arity_bits, fri_proof_size + ); + + arity_bits +} + +/// Return `(arity_bits, fri_proof_size)`. +fn min_size_arity_bits_helper( + degree_bits: usize, + rate_bits: usize, + num_queries: usize, + prefix: Vec, +) -> (Vec, usize) { + // 2^4 is the largest arity we see in optimal reduction sequences in practice. For 2^5 to occur + // in an optimal sequence, we would need a really massive polynomial. + const MAX_ARITY_BITS: usize = 4; + + let sum_of_arities: usize = prefix.iter().sum(); + let current_layer_bits = degree_bits + rate_bits - sum_of_arities; + assert!(current_layer_bits >= rate_bits); + + let mut best_arity_bits = prefix.clone(); + let mut best_size = relative_proof_size(degree_bits, rate_bits, num_queries, &prefix); + + // The largest next_arity_bits to search. Note that any optimal arity sequence will be + // monotonically non-increasing, as a larger arity will shrink more Merkle proofs if it occurs + // earlier in the sequence. + let max_arity_bits = prefix + .last() + .copied() + .unwrap_or(MAX_ARITY_BITS) + .min(current_layer_bits - rate_bits); + + for next_arity_bits in 1..=max_arity_bits { + let mut extended_prefix = prefix.clone(); + extended_prefix.push(next_arity_bits); + + let (arity_bits, size) = + min_size_arity_bits_helper(degree_bits, rate_bits, num_queries, extended_prefix); + if size < best_size { + best_arity_bits = arity_bits; + best_size = size; + } + } + + (best_arity_bits, best_size) +} + +/// Compute the approximate size of a FRI proof with the given reduction arities. Note that this +/// ignores initial evaluations, which aren't affected by arities, and some other minor +/// contributions. The result is measured in field elements. +fn relative_proof_size( + degree_bits: usize, + rate_bits: usize, + num_queries: usize, + arity_bits: &[usize], +) -> usize { + const D: usize = 4; + + let mut current_layer_bits = degree_bits + rate_bits; + + let mut total_elems = 0; + for arity_bits in arity_bits { + let arity = 1 << arity_bits; + + // Add neighboring evaluations, which are extension field elements. + total_elems += (arity - 1) * D * num_queries; + // Add siblings in the Merkle path. + total_elems += current_layer_bits * 4 * num_queries; + + current_layer_bits -= arity_bits; + } + + // Add the final polynomial's coefficients. + assert!(current_layer_bits >= rate_bits); + let final_poly_len = 1 << (current_layer_bits - rate_bits); + total_elems += D * final_poly_len; + + total_elems +} diff --git a/src/fri/verifier.rs b/src/fri/verifier.rs index f9c1d998..93d10723 100644 --- a/src/fri/verifier.rs +++ b/src/fri/verifier.rs @@ -81,7 +81,7 @@ pub(crate) fn verify_fri_proof, const D: usize>( "Number of query rounds does not match config." ); ensure!( - !config.fri_config.reduction_arity_bits.is_empty(), + !common_data.fri_params.reduction_arity_bits.is_empty(), "Number of reductions should be non-zero." ); @@ -225,7 +225,6 @@ fn fri_verifier_query_round, const D: usize>( round_proof: &FriQueryRound, common_data: &CommonCircuitData, ) -> Result<()> { - let config = &common_data.config.fri_config; fri_verify_initial_proof( x_index, &round_proof.initial_trees_proof, @@ -247,7 +246,12 @@ fn fri_verifier_query_round, const D: usize>( common_data, ); - for (i, &arity_bits) in config.reduction_arity_bits.iter().enumerate() { + for (i, &arity_bits) in common_data + .fri_params + .reduction_arity_bits + .iter() + .enumerate() + { let arity = 1 << arity_bits; let evals = &round_proof.steps[i].evals; diff --git a/src/plonk/circuit_builder.rs b/src/plonk/circuit_builder.rs index 18f84681..eeef44da 100644 --- a/src/plonk/circuit_builder.rs +++ b/src/plonk/circuit_builder.rs @@ -11,6 +11,7 @@ use crate::field::extension_field::{Extendable, FieldExtension}; use crate::field::fft::fft_root_table; use crate::field::field_types::RichField; use crate::fri::commitment::PolynomialBatchCommitment; +use crate::fri::FriParams; use crate::gates::arithmetic::{ArithmeticExtensionGate, NUM_ARITHMETIC_OPS}; use crate::gates::constant::ConstantGate; use crate::gates::gate::{Gate, GateInstance, GateRef, PrefixedGate}; @@ -320,15 +321,27 @@ impl, const D: usize> CircuitBuilder { }) } + fn fri_params(&self, degree_bits_estimate: usize) -> FriParams { + let fri_config = &self.config.fri_config; + let reduction_arity_bits = fri_config.reduction_strategy.reduction_arity_bits( + degree_bits_estimate, + self.config.rate_bits, + fri_config.num_query_rounds, + ); + FriParams { + reduction_arity_bits, + } + } + /// The number of polynomial values that will be revealed per opening, both for the "regular" /// polynomials and for the Z polynomials. Because calculating these values involves a recursive /// dependence (the amount of blinding depends on the degree, which depends on the blinding), /// this function takes in an estimate of the degree. fn num_blinding_gates(&self, degree_estimate: usize) -> (usize, usize) { + let degree_bits_estimate = log2_strict(degree_estimate); let fri_queries = self.config.fri_config.num_query_rounds; let arities: Vec = self - .config - .fri_config + .fri_params(degree_bits_estimate) .reduction_arity_bits .iter() .map(|x| 1 << x) @@ -578,9 +591,10 @@ impl, const D: usize> CircuitBuilder { let degree = self.gate_instances.len(); info!("Degree after blinding & padding: {}", degree); let degree_bits = log2_strict(degree); + let fri_params = self.fri_params(degree_bits); assert!( - self.config.fri_config.total_arities() <= degree_bits, - "FRI total reduction arity is too large." + fri_params.total_arities() <= degree_bits, + "FRI total reduction arity is too large.", ); let gates = self.gates.iter().cloned().collect(); @@ -673,6 +687,7 @@ impl, const D: usize> CircuitBuilder { let common = CommonCircuitData { config: self.config, + fri_params, degree_bits, gates: prefixed_gates, quotient_degree_factor, diff --git a/src/plonk/circuit_data.rs b/src/plonk/circuit_data.rs index f8bb0b6a..9cb7ba61 100644 --- a/src/plonk/circuit_data.rs +++ b/src/plonk/circuit_data.rs @@ -7,7 +7,8 @@ use crate::field::extension_field::Extendable; use crate::field::fft::FftRootTable; use crate::field::field_types::{Field, RichField}; use crate::fri::commitment::PolynomialBatchCommitment; -use crate::fri::FriConfig; +use crate::fri::reduction_strategies::FriReductionStrategy; +use crate::fri::{FriConfig, FriParams}; use crate::gates::gate::PrefixedGate; use crate::hash::hash_types::{HashOut, MerkleCapTarget}; use crate::hash::merkle_tree::MerkleCap; @@ -47,7 +48,7 @@ impl Default for CircuitConfig { cap_height: 1, fri_config: FriConfig { proof_of_work_bits: 1, - reduction_arity_bits: vec![1, 1, 1, 1], + reduction_strategy: FriReductionStrategy::ConstantArityBits(3, 5), num_query_rounds: 1, }, } @@ -71,7 +72,7 @@ impl CircuitConfig { cap_height: 1, fri_config: FriConfig { proof_of_work_bits: 1, - reduction_arity_bits: vec![1], + reduction_strategy: FriReductionStrategy::ConstantArityBits(3, 5), num_query_rounds: 1, }, } @@ -84,7 +85,7 @@ impl CircuitConfig { cap_height: 1, fri_config: FriConfig { proof_of_work_bits: 1, - reduction_arity_bits: vec![1, 1, 1, 1], + reduction_strategy: FriReductionStrategy::ConstantArityBits(3, 5), num_query_rounds: 1, }, ..Self::large_config() @@ -175,6 +176,8 @@ pub(crate) struct VerifierOnlyCircuitData { pub struct CommonCircuitData, const D: usize> { pub(crate) config: CircuitConfig, + pub(crate) fri_params: FriParams, + pub(crate) degree_bits: usize, /// The types of gates used in this circuit, along with their prefixes. @@ -254,7 +257,7 @@ impl, const D: usize> CommonCircuitData { } pub fn final_poly_len(&self) -> usize { - 1 << (self.degree_bits - self.config.fri_config.total_arities()) + 1 << (self.degree_bits - self.fri_params.total_arities()) } } diff --git a/src/plonk/recursive_verifier.rs b/src/plonk/recursive_verifier.rs index 40f22708..7e95bc6d 100644 --- a/src/plonk/recursive_verifier.rs +++ b/src/plonk/recursive_verifier.rs @@ -127,10 +127,12 @@ mod tests { use log::info; use super::*; + use crate::field::field_types::Field; use crate::field::goldilocks_field::GoldilocksField; use crate::fri::proof::{ FriInitialTreeProofTarget, FriProofTarget, FriQueryRoundTarget, FriQueryStepTarget, }; + use crate::fri::reduction_strategies::FriReductionStrategy; use crate::fri::FriConfig; use crate::gadgets::polynomial::PolynomialCoeffsExtTarget; use crate::hash::merkle_proofs::MerkleProofTarget; @@ -372,9 +374,9 @@ mod tests { zero_knowledge: false, cap_height: 2, fri_config: FriConfig { - proof_of_work_bits: 1, - reduction_arity_bits: vec![2, 2, 2, 2, 2, 2], - num_query_rounds: 40, + proof_of_work_bits: 15, + reduction_strategy: FriReductionStrategy::ConstantArityBits(3, 5), + num_query_rounds: 27, }, }; let (proof_with_pis, vd, cd) = { @@ -428,17 +430,15 @@ mod tests { cap_height: 3, fri_config: FriConfig { proof_of_work_bits: 15, - reduction_arity_bits: vec![3, 3, 3], + reduction_strategy: FriReductionStrategy::ConstantArityBits(3, 5), num_query_rounds: 27, }, }; let (proof_with_pis, vd, cd) = { let (proof_with_pis, vd, cd) = { let mut builder = CircuitBuilder::::new(config.clone()); - let _two = builder.two(); - let mut _two = builder.hash_n_to_hash(vec![_two], true).elements[0]; - for _ in 0..10000 { - _two = builder.mul(_two, _two); + for i in 0..8_000 { + builder.constant(F::from_canonical_u64(i)); } let data = builder.build(); ( diff --git a/src/plonk/verifier.rs b/src/plonk/verifier.rs index 2dbf7989..0ce9e3d0 100644 --- a/src/plonk/verifier.rs +++ b/src/plonk/verifier.rs @@ -10,7 +10,7 @@ use crate::plonk::vanishing_poly::eval_vanishing_poly; use crate::plonk::vars::EvaluationVars; pub(crate) fn verify, const D: usize>( - mut proof_with_pis: ProofWithPublicInputs, + proof_with_pis: ProofWithPublicInputs, verifier_data: &VerifierOnlyCircuitData, common_data: &CommonCircuitData, ) -> Result<()> { diff --git a/src/util/serialization.rs b/src/util/serialization.rs index 21b12c11..d4a4ea1e 100644 --- a/src/util/serialization.rs +++ b/src/util/serialization.rs @@ -277,8 +277,8 @@ impl Buffer { let mut fqrs = Vec::with_capacity(config.fri_config.num_query_rounds); for _ in 0..config.fri_config.num_query_rounds { let initial_trees_proof = self.read_fri_initial_proof(common_data)?; - let steps = config - .fri_config + let steps = common_data + .fri_params .reduction_arity_bits .iter() .map(|&ar| self.read_fri_query_step(1 << ar)) @@ -307,7 +307,7 @@ impl Buffer { common_data: &CommonCircuitData, ) -> Result> { let config = &common_data.config; - let commit_phase_merkle_caps = (0..config.fri_config.reduction_arity_bits.len()) + let commit_phase_merkle_caps = (0..common_data.fri_params.reduction_arity_bits.len()) .map(|_| self.read_merkle_cap(config.cap_height)) .collect::>>()?; let query_round_proofs = self.read_fri_query_rounds(common_data)?; @@ -417,8 +417,8 @@ impl Buffer { } let initial_trees_proofs = HashMap::from_iter(pairs); - let mut steps = Vec::with_capacity(config.fri_config.reduction_arity_bits.len()); - for &a in &config.fri_config.reduction_arity_bits { + let mut steps = Vec::with_capacity(common_data.fri_params.reduction_arity_bits.len()); + for &a in &common_data.fri_params.reduction_arity_bits { indices.iter_mut().for_each(|x| { *x >>= a; }); @@ -458,7 +458,7 @@ impl Buffer { common_data: &CommonCircuitData, ) -> Result> { let config = &common_data.config; - let commit_phase_merkle_caps = (0..config.fri_config.reduction_arity_bits.len()) + let commit_phase_merkle_caps = (0..common_data.fri_params.reduction_arity_bits.len()) .map(|_| self.read_merkle_cap(config.cap_height)) .collect::>>()?; let query_round_proofs = self.read_compressed_fri_query_rounds(common_data)?;