diff --git a/src/field/cosets.rs b/src/field/cosets.rs index 58c8b838..0f43a12b 100644 --- a/src/field/cosets.rs +++ b/src/field/cosets.rs @@ -39,7 +39,7 @@ mod tests { let generator = F::primitive_root_of_unity(SUBGROUP_BITS); let subgroup_size = 1 << SUBGROUP_BITS; - let shifts = get_unique_coset_shifts::(SUBGROUP_BITS, NUM_SHIFTS); + let shifts = get_unique_coset_shifts::(subgroup_size, NUM_SHIFTS); let mut union = HashSet::new(); for shift in shifts { diff --git a/src/fri/recursive_verifier.rs b/src/fri/recursive_verifier.rs index eda618cd..e6f8474f 100644 --- a/src/fri/recursive_verifier.rs +++ b/src/fri/recursive_verifier.rs @@ -118,14 +118,16 @@ impl, const D: usize> CircuitBuilder { "Number of reductions should be non-zero." ); + let precomputed_reduced_evals = + PrecomputedReducedEvalsTarget::from_os_and_alpha(os, alpha, self); for (i, round_proof) in proof.query_round_proofs.iter().enumerate() { context!( self, &format!("verify {}'th FRI query", i), self.fri_verifier_query_round( - os, zeta, alpha, + precomputed_reduced_evals, initial_merkle_roots, proof, challenger, @@ -162,9 +164,9 @@ impl, const D: usize> CircuitBuilder { &mut self, proof: &FriInitialTreeProofTarget, alpha: ExtensionTarget, - os: &OpeningSetTarget, zeta: ExtensionTarget, subgroup_x: Target, + precomputed_reduced_evals: PrecomputedReducedEvalsTarget, common_data: &CommonCircuitData, ) -> ExtensionTarget { assert!(D > 1, "Not implemented for D=1."); @@ -192,19 +194,9 @@ impl, const D: usize> CircuitBuilder { ) .map(|&e| self.convert_to_ext(e)) .collect::>(); - let single_openings = os - .constants - .iter() - .chain(&os.plonk_sigmas) - .chain(&os.wires) - .chain(&os.quotient_polys) - .chain(&os.partial_products) - .copied() - .collect::>(); - let mut single_numerator = alpha.reduce(&single_evals, self); - // TODO: Precompute the rhs as it is the same in all FRI rounds. - let rhs = alpha.reduce(&single_openings, self); - single_numerator = self.sub_extension(single_numerator, rhs); + let single_composition_eval = alpha.reduce(&single_evals, self); + let single_numerator = + self.sub_extension(single_composition_eval, precomputed_reduced_evals.single); let single_denominator = self.sub_extension(subgroup_x, zeta); let quotient = self.div_unsafe_extension(single_numerator, single_denominator); sum = self.add_extension(sum, quotient); @@ -217,14 +209,15 @@ impl, const D: usize> CircuitBuilder { .take(common_data.zs_range().end) .map(|&e| self.convert_to_ext(e)) .collect::>(); - let zs_composition_eval = alpha.clone().reduce(&zs_evals, self); + let zs_composition_eval = alpha.reduce(&zs_evals, self); let g = self.constant_extension(F::Extension::primitive_root_of_unity(degree_log)); let zeta_right = self.mul_extension(g, zeta); - let zs_ev_zeta = alpha.clone().reduce(&os.plonk_zs, self); - let zs_ev_zeta_right = alpha.reduce(&os.plonk_zs_right, self); let interpol_val = self.interpolate2( - [(zeta, zs_ev_zeta), (zeta_right, zs_ev_zeta_right)], + [ + (zeta, precomputed_reduced_evals.zs), + (zeta_right, precomputed_reduced_evals.zs_right), + ], subgroup_x, ); let zs_numerator = self.sub_extension(zs_composition_eval, interpol_val); @@ -240,9 +233,9 @@ impl, const D: usize> CircuitBuilder { fn fri_verifier_query_round( &mut self, - os: &OpeningSetTarget, zeta: ExtensionTarget, alpha: ExtensionTarget, + precomputed_reduced_evals: PrecomputedReducedEvalsTarget, initial_merkle_roots: &[HashTarget], proof: &FriProofTarget, challenger: &mut RecursiveChallenger, @@ -253,7 +246,6 @@ impl, const D: usize> CircuitBuilder { ) { let config = &common_data.config.fri_config; let n_log = log2_strict(n); - let mut evaluations: Vec>> = Vec::new(); // TODO: Do we need to range check `x_index` to a target smaller than `p`? let mut x_index = challenger.get_challenge(self); x_index = self.split_low_high(x_index, n_log, 64).0; @@ -280,6 +272,7 @@ impl, const D: usize> CircuitBuilder { self.mul(g, phi) }); + let mut evaluations: Vec>> = Vec::new(); for (i, &arity_bits) in config.reduction_arity_bits.iter().enumerate() { let next_domain_size = domain_size >> arity_bits; let e_x = if i == 0 { @@ -289,9 +282,9 @@ impl, const D: usize> CircuitBuilder { self.fri_combine_initial( &round_proof.initial_trees_proof, alpha, - os, zeta, subgroup_x, + precomputed_reduced_evals, common_data, ) ) @@ -315,23 +308,21 @@ impl, const D: usize> CircuitBuilder { let (low_x_index, high_x_index) = self.split_low_high(x_index, arity_bits, x_index_num_bits); evals = self.insert(low_x_index, e_x, evals); - evaluations.push(evals); context!( self, "verify FRI round Merkle proof.", self.verify_merkle_proof( - flatten_target(&evaluations[i]), + flatten_target(&evals), high_x_index, proof.commit_phase_merkle_roots[i], &round_proof.steps[i].merkle_proof, ) ); + evaluations.push(evals); if i > 0 { // Update the point x to x^arity. - for _ in 0..config.reduction_arity_bits[i - 1] { - subgroup_x = self.square(subgroup_x); - } + subgroup_x = self.exp_power_of_2(subgroup_x, config.reduction_arity_bits[i - 1]); } domain_size = next_domain_size; old_x_index = low_x_index; @@ -352,9 +343,7 @@ impl, const D: usize> CircuitBuilder { *betas.last().unwrap(), ) ); - for _ in 0..final_arity_bits { - subgroup_x = self.square(subgroup_x); - } + subgroup_x = self.exp_power_of_2(subgroup_x, final_arity_bits); // Final check of FRI. After all the reductions, we check that the final polynomial is equal // to the one sent by the prover. @@ -366,3 +355,39 @@ impl, const D: usize> CircuitBuilder { self.assert_equal_extension(eval, purported_eval); } } + +#[derive(Copy, Clone)] +struct PrecomputedReducedEvalsTarget { + pub single: ExtensionTarget, + pub zs: ExtensionTarget, + pub zs_right: ExtensionTarget, +} + +impl PrecomputedReducedEvalsTarget { + fn from_os_and_alpha>( + os: &OpeningSetTarget, + alpha: ExtensionTarget, + builder: &mut CircuitBuilder, + ) -> Self { + let mut alpha = ReducingFactorTarget::new(alpha); + let single = alpha.reduce( + &os.constants + .iter() + .chain(&os.plonk_sigmas) + .chain(&os.wires) + .chain(&os.quotient_polys) + .chain(&os.partial_products) + .copied() + .collect::>(), + builder, + ); + let zs = alpha.reduce(&os.plonk_zs, builder); + let zs_right = alpha.reduce(&os.plonk_zs_right, builder); + + Self { + single, + zs, + zs_right, + } + } +} diff --git a/src/fri/verifier.rs b/src/fri/verifier.rs index 796c2694..712288fd 100644 --- a/src/fri/verifier.rs +++ b/src/fri/verifier.rs @@ -112,11 +112,12 @@ pub fn verify_fri_proof, const D: usize>( "Number of reductions should be non-zero." ); + let precomputed_reduced_evals = PrecomputedReducedEvals::from_os_and_alpha(os, alpha); for round_proof in &proof.query_round_proofs { fri_verifier_query_round( - os, zeta, alpha, + precomputed_reduced_evals, initial_merkle_roots, &proof, challenger, @@ -142,12 +143,43 @@ fn fri_verify_initial_proof( Ok(()) } +/// Holds the reduced (by `alpha`) evaluations at `zeta` for the polynomial opened just at +/// zeta, for `Z` at zeta and for `Z` at `g*zeta`. +#[derive(Copy, Clone)] +struct PrecomputedReducedEvals, const D: usize> { + pub single: F::Extension, + pub zs: F::Extension, + pub zs_right: F::Extension, +} + +impl, const D: usize> PrecomputedReducedEvals { + fn from_os_and_alpha(os: &OpeningSet, alpha: F::Extension) -> Self { + let mut alpha = ReducingFactor::new(alpha); + let single = alpha.reduce( + os.constants + .iter() + .chain(&os.plonk_sigmas) + .chain(&os.wires) + .chain(&os.quotient_polys) + .chain(&os.partial_products), + ); + let zs = alpha.reduce(os.plonk_zs.iter()); + let zs_right = alpha.reduce(os.plonk_zs_right.iter()); + + Self { + single, + zs, + zs_right, + } + } +} + fn fri_combine_initial, const D: usize>( proof: &FriInitialTreeProof, alpha: F::Extension, - os: &OpeningSet, zeta: F::Extension, subgroup_x: F, + precomputed_reduced_evals: PrecomputedReducedEvals, common_data: &CommonCircuitData, ) -> F::Extension { let config = &common_data.config; @@ -174,19 +206,8 @@ fn fri_combine_initial, const D: usize>( [common_data.partial_products_range()], ) .map(|&e| F::Extension::from_basefield(e)); - let single_openings = os - .constants - .iter() - .chain(&os.plonk_sigmas) - .chain(&os.wires) - .chain(&os.quotient_polys) - .chain(&os.partial_products); - let single_diffs = single_evals - .into_iter() - .zip(single_openings) - .map(|(e, &o)| e - o) - .collect::>(); - let single_numerator = alpha.reduce(single_diffs.iter()); + let single_composition_eval = alpha.reduce(single_evals); + let single_numerator = single_composition_eval - precomputed_reduced_evals.single; let single_denominator = subgroup_x - zeta; sum += single_numerator / single_denominator; alpha.reset(); @@ -197,12 +218,12 @@ fn fri_combine_initial, const D: usize>( .iter() .map(|&e| F::Extension::from_basefield(e)) .take(common_data.zs_range().end); - let zs_composition_eval = alpha.clone().reduce(zs_evals); + let zs_composition_eval = alpha.reduce(zs_evals); let zeta_right = F::Extension::primitive_root_of_unity(degree_log) * zeta; let zs_interpol = interpolate2( [ - (zeta, alpha.clone().reduce(os.plonk_zs.iter())), - (zeta_right, alpha.reduce(os.plonk_zs_right.iter())), + (zeta, precomputed_reduced_evals.zs), + (zeta_right, precomputed_reduced_evals.zs_right), ], subgroup_x, ); @@ -215,9 +236,9 @@ fn fri_combine_initial, const D: usize>( } fn fri_verifier_query_round, const D: usize>( - os: &OpeningSet, zeta: F::Extension, alpha: F::Extension, + precomputed_reduced_evals: PrecomputedReducedEvals, initial_merkle_roots: &[Hash], proof: &FriProof, challenger: &mut Challenger, @@ -227,7 +248,6 @@ fn fri_verifier_query_round, const D: usize>( common_data: &CommonCircuitData, ) -> Result<()> { let config = &common_data.config.fri_config; - let mut evaluations: Vec> = Vec::new(); let x = challenger.get_challenge(); let mut domain_size = n; let mut x_index = x.to_canonical_u64() as usize % n; @@ -241,6 +261,8 @@ fn fri_verifier_query_round, const D: usize>( let log_n = log2_strict(n); let mut subgroup_x = F::MULTIPLICATIVE_GROUP_GENERATOR * F::primitive_root_of_unity(log_n).exp(reverse_bits(x_index, log_n) as u64); + + let mut evaluations: Vec> = Vec::new(); for (i, &arity_bits) in config.reduction_arity_bits.iter().enumerate() { let arity = 1 << arity_bits; let next_domain_size = domain_size >> arity_bits; @@ -248,9 +270,9 @@ fn fri_verifier_query_round, const D: usize>( fri_combine_initial( &round_proof.initial_trees_proof, alpha, - os, zeta, subgroup_x, + precomputed_reduced_evals, common_data, ) } else { @@ -267,20 +289,18 @@ fn fri_verifier_query_round, const D: usize>( let mut evals = round_proof.steps[i].evals.clone(); // Insert P(y) into the evaluation vector, since it wasn't included by the prover. evals.insert(x_index & (arity - 1), e_x); - evaluations.push(evals); verify_merkle_proof( - flatten(&evaluations[i]), + flatten(&evals), x_index >> arity_bits, proof.commit_phase_merkle_roots[i], &round_proof.steps[i].merkle_proof, false, )?; + evaluations.push(evals); if i > 0 { // Update the point x to x^arity. - for _ in 0..config.reduction_arity_bits[i - 1] { - subgroup_x = subgroup_x.square(); - } + subgroup_x = subgroup_x.exp_power_of_2(config.reduction_arity_bits[i - 1]); } domain_size = next_domain_size; old_x_index = x_index & (arity - 1); @@ -296,9 +316,7 @@ fn fri_verifier_query_round, const D: usize>( last_evals, *betas.last().unwrap(), ); - for _ in 0..final_arity_bits { - subgroup_x = subgroup_x.square(); - } + subgroup_x = subgroup_x.exp_power_of_2(final_arity_bits); // Final check of FRI. After all the reductions, we check that the final polynomial is equal // to the one sent by the prover. diff --git a/src/gadgets/arithmetic.rs b/src/gadgets/arithmetic.rs index 6f85cdcf..3e2b8c4b 100644 --- a/src/gadgets/arithmetic.rs +++ b/src/gadgets/arithmetic.rs @@ -153,6 +153,15 @@ impl, const D: usize> CircuitBuilder { product } + /// Exponentiate `base` to the power of `2^power_log`. + // TODO: Test + pub fn exp_power_of_2(&mut self, mut base: Target, power_log: usize) -> Target { + for _ in 0..power_log { + base = self.square(base); + } + base + } + // TODO: Optimize this, maybe with a new gate. // TODO: Test /// Exponentiate `base` to the power of `exponent`, where `exponent < 2^num_bits`. diff --git a/src/gadgets/arithmetic_extension.rs b/src/gadgets/arithmetic_extension.rs index 10b60dcd..c7951d78 100644 --- a/src/gadgets/arithmetic_extension.rs +++ b/src/gadgets/arithmetic_extension.rs @@ -292,7 +292,7 @@ impl, const D: usize> CircuitBuilder { /// Exponentiate `base` to the power of `2^power_log`. // TODO: Test - pub fn exp_power_of_2( + pub fn exp_power_of_2_extension( &mut self, mut base: ExtensionTarget, power_log: usize, diff --git a/src/recursive_verifier.rs b/src/recursive_verifier.rs index 51bfe6ba..3dd135af 100644 --- a/src/recursive_verifier.rs +++ b/src/recursive_verifier.rs @@ -59,7 +59,7 @@ impl, const D: usize> CircuitBuilder { let s_sigmas = &proof.openings.plonk_sigmas; let partial_products = &proof.openings.partial_products; - let zeta_pow_deg = self.exp_power_of_2(zeta, inner_common_data.degree_bits); + let zeta_pow_deg = self.exp_power_of_2_extension(zeta, inner_common_data.degree_bits); let vanishing_polys_zeta = context!( self, "evaluate the vanishing polynomial at our challenge point, zeta.", @@ -89,7 +89,7 @@ impl, const D: usize> CircuitBuilder { { let recombined_quotient = scale.reduce(chunk, self); let computed_vanishing_poly = self.mul_extension(z_h_zeta, recombined_quotient); - self.named_route_extension( + self.named_assert_equal_extension( vanishing_polys_zeta[i], computed_vanishing_poly, format!("Vanishing polynomial == Z_H * quotient, challenge {}", i), diff --git a/src/util/scaling.rs b/src/util/scaling.rs index be339784..9409d49c 100644 --- a/src/util/scaling.rs +++ b/src/util/scaling.rs @@ -16,7 +16,7 @@ use crate::polynomial::polynomial::PolynomialCoeffs; /// This struct abstract away these operations by implementing Horner's method and keeping track /// of the number of multiplications by `a` to compute the scaling factor. /// See https://github.com/mir-protocol/plonky2/pull/69 for more details and discussions. -#[derive(Debug, Copy, Clone)] +#[derive(Debug, Clone)] pub struct ReducingFactor { base: F, count: u64, @@ -79,7 +79,7 @@ impl ReducingFactor { } } -#[derive(Debug, Copy, Clone)] +#[derive(Debug, Clone)] pub struct ReducingFactorTarget { base: ExtensionTarget, count: u64, diff --git a/src/vanishing_poly.rs b/src/vanishing_poly.rs index b82a23c1..540ab86b 100644 --- a/src/vanishing_poly.rs +++ b/src/vanishing_poly.rs @@ -236,10 +236,13 @@ pub fn evaluate_gate_constraints_recursively, const D: usize>( ) -> Vec> { let mut constraints = vec![builder.zero_extension(); num_gate_constraints]; for gate in gates { - let gate_constraints = gate - .gate - .0 - .eval_filtered_recursively(builder, vars, &gate.prefix); + let gate_constraints = context!( + builder, + &format!("evaluate {} constraints", gate.gate.0.id()), + gate.gate + .0 + .eval_filtered_recursively(builder, vars, &gate.prefix) + ); for (i, c) in gate_constraints.into_iter().enumerate() { constraints[i] = builder.add_extension(constraints[i], c); }