From 21669be2463a77bcaf5e23bd40cafa1cca5250fc Mon Sep 17 00:00:00 2001 From: wborgeaud Date: Fri, 13 Aug 2021 10:40:31 +0200 Subject: [PATCH] Some arithm optims --- src/fri/recursive_verifier.rs | 49 +++++++++++++++++++++++---------- src/gadgets/interpolation.rs | 5 ++-- src/gates/arithmetic.rs | 8 +++--- src/plonk/copy_constraint.rs | 1 + src/plonk/prover.rs | 20 ++++++++++++++ src/plonk/recursive_verifier.rs | 8 +++--- src/util/reducing.rs | 18 ++++++++++-- 7 files changed, 83 insertions(+), 26 deletions(-) diff --git a/src/fri/recursive_verifier.rs b/src/fri/recursive_verifier.rs index 1e028255..cf07dfad 100644 --- a/src/fri/recursive_verifier.rs +++ b/src/fri/recursive_verifier.rs @@ -139,7 +139,13 @@ impl, const D: usize> CircuitBuilder { let precomputed_reduced_evals = with_context!( self, "precompute reduced evaluations", - PrecomputedReducedEvalsTarget::from_os_and_alpha(os, alpha, self) + PrecomputedReducedEvalsTarget::from_os_and_alpha( + os, + alpha, + common_data.degree_bits, + zeta, + self + ) ); for (i, round_proof) in proof.query_round_proofs.iter().enumerate() { @@ -209,6 +215,7 @@ impl, const D: usize> CircuitBuilder { precomputed_reduced_evals: PrecomputedReducedEvalsTarget, common_data: &CommonCircuitData, ) -> ExtensionTarget { + println!("combine initial: {}", self.num_gates()); assert!(D > 1, "Not implemented for D=1."); let config = self.config.clone(); let degree_log = common_data.degree_bits; @@ -246,6 +253,7 @@ impl, const D: usize> CircuitBuilder { sum = self.div_add_extension(single_numerator, vanish_zeta, sum); alpha.reset(); + println!("done single: {}", self.num_gates()); // Polynomials opened at `x` and `g x`, i.e., the Zs polynomials. let zs_evals = proof .unsalted_evals(PlonkPolynomials::ZS_PARTIAL_PRODUCTS, config.zero_knowledge) @@ -255,20 +263,21 @@ impl, const D: usize> CircuitBuilder { .collect::>(); let zs_composition_eval = alpha.reduce_base(&zs_evals, self); - let g = self.constant_extension(F::Extension::primitive_root_of_unity(degree_log)); - let zeta_right = self.mul_extension(g, zeta); - let interpol_val = self.interpolate2( - [ - (zeta, precomputed_reduced_evals.zs), - (zeta_right, precomputed_reduced_evals.zs_right), - ], - subgroup_x, + let interpol_val = self.mul_add_extension( + vanish_zeta, + precomputed_reduced_evals.slope, + precomputed_reduced_evals.zs, ); - let (zs_numerator, vanish_zeta_right) = - self.sub_two_extension(zs_composition_eval, interpol_val, subgroup_x, zeta_right); - let zs_denominator = self.mul_extension(vanish_zeta, vanish_zeta_right); - sum = alpha.shift(sum, self); + let (zs_numerator, vanish_zeta_right) = self.sub_two_extension( + zs_composition_eval, + interpol_val, + subgroup_x, + precomputed_reduced_evals.zeta_right, + ); + let (mut sum, zs_denominator) = + alpha.shift_and_mul(sum, vanish_zeta, vanish_zeta_right, self); sum = self.div_add_extension(zs_numerator, zs_denominator, sum); + println!("done doubles: {}", self.num_gates()); sum } @@ -286,6 +295,7 @@ impl, const D: usize> CircuitBuilder { round_proof: &FriQueryRoundTarget, common_data: &CommonCircuitData, ) { + println!("query round: {}", self.num_gates()); let config = &common_data.config.fri_config; let n_log = log2_strict(n); // TODO: Do we need to range check `x_index` to a target smaller than `p`? @@ -308,7 +318,7 @@ impl, const D: usize> CircuitBuilder { // `subgroup_x` is `subgroup[x_index]`, i.e., the actual field element in the domain. let mut subgroup_x = with_context!(self, "compute x from its index", { - let g = self.constant(F::MULTIPLICATIVE_GROUP_GENERATOR); + let g = self.constant(F::coset_shift()); let phi = self.constant(F::primitive_root_of_unity(n_log)); let phi = self.exp_from_bits(phi, x_index_bits.iter().rev()); @@ -331,6 +341,7 @@ impl, const D: usize> CircuitBuilder { ); for (i, &arity_bits) in config.reduction_arity_bits.iter().enumerate() { + println!("query round, {}-th arity: {}", i, self.num_gates()); let evals = &round_proof.steps[i].evals; // Split x_index into the index of the coset x is in, and the index of x within that coset. @@ -393,12 +404,16 @@ struct PrecomputedReducedEvalsTarget { pub single: ExtensionTarget, pub zs: ExtensionTarget, pub zs_right: ExtensionTarget, + pub slope: ExtensionTarget, + pub zeta_right: ExtensionTarget, } impl PrecomputedReducedEvalsTarget { fn from_os_and_alpha>( os: &OpeningSetTarget, alpha: ExtensionTarget, + degree_log: usize, + zeta: ExtensionTarget, builder: &mut CircuitBuilder, ) -> Self { let mut alpha = ReducingFactorTarget::new(alpha); @@ -416,10 +431,16 @@ impl PrecomputedReducedEvalsTarget { let zs = alpha.reduce(&os.plonk_zs, builder); let zs_right = alpha.reduce(&os.plonk_zs_right, builder); + let g = builder.constant_extension(F::Extension::primitive_root_of_unity(degree_log)); + let zeta_right = builder.mul_extension(g, zeta); + let (numerator, denominator) = builder.sub_two_extension(zs_right, zs, zeta_right, zeta); + Self { single, zs, zs_right, + slope: builder.div_extension(numerator, denominator), + zeta_right, } } } diff --git a/src/gadgets/interpolation.rs b/src/gadgets/interpolation.rs index 92090d90..dd8a1334 100644 --- a/src/gadgets/interpolation.rs +++ b/src/gadgets/interpolation.rs @@ -12,9 +12,10 @@ impl, const D: usize> CircuitBuilder { interpolation_points: [(ExtensionTarget, ExtensionTarget); 2], evaluation_point: ExtensionTarget, ) -> ExtensionTarget { - // a0 -> a1 - // b0 -> b1 + // a0 -> a1 : zeta -> precomp0 + // b0 -> b1 : g*zeta -> precomp1 // x -> a1 + (x-a0)*(b1-a1)/(b0-a0) + // x -> precomp0 + (x-zeta)*(precomp1-precomp0)/(g*zeta - zeta) let (x_m_a0, b1_m_a1) = self.sub_two_extension( evaluation_point, diff --git a/src/gates/arithmetic.rs b/src/gates/arithmetic.rs index 499bd4a7..81c9bc9a 100644 --- a/src/gates/arithmetic.rs +++ b/src/gates/arithmetic.rs @@ -24,16 +24,16 @@ impl ArithmeticExtensionGate { pub fn wires_first_addend() -> Range { 2 * D..3 * D } - pub fn wires_second_multiplicand_0() -> Range { + pub fn wires_first_output() -> Range { 3 * D..4 * D } - pub fn wires_second_multiplicand_1() -> Range { + pub fn wires_second_multiplicand_0() -> Range { 4 * D..5 * D } - pub fn wires_second_addend() -> Range { + pub fn wires_second_multiplicand_1() -> Range { 5 * D..6 * D } - pub fn wires_first_output() -> Range { + pub fn wires_second_addend() -> Range { 6 * D..7 * D } pub fn wires_second_output() -> Range { diff --git a/src/plonk/copy_constraint.rs b/src/plonk/copy_constraint.rs index a838ed37..a04fc75b 100644 --- a/src/plonk/copy_constraint.rs +++ b/src/plonk/copy_constraint.rs @@ -1,6 +1,7 @@ use crate::iop::target::Target; /// A named copy constraint. +#[derive(Debug)] pub struct CopyConstraint { pub pair: (Target, Target), pub name: String, diff --git a/src/plonk/prover.rs b/src/plonk/prover.rs index 0d66daa8..026fc4b7 100644 --- a/src/plonk/prover.rs +++ b/src/plonk/prover.rs @@ -68,6 +68,26 @@ pub(crate) fn prove, const D: usize>( partial_witness.full_witness(degree, num_wires) ); + { + let mut count = 0; + let mut count_bad = 0; + for i in 0..degree { + if prover_data.gate_instances[i].gate_ref.0.id() + != "ArithmeticExtensionGate".to_string() + { + continue; + } + count += 1; + let row = witness.wire_values.iter().map(|c| c[i]).collect::>(); + // println!("{} {:?}", i, &row); + if row[16..].iter().all(|x| x.is_zero()) { + println!("{} {:?}", i, row); + count_bad += 1; + } + } + println!("{} {}", count, count_bad); + } + let wires_values: Vec> = timed!( timing, "compute wire polynomials", diff --git a/src/plonk/recursive_verifier.rs b/src/plonk/recursive_verifier.rs index 3ab1cdec..369fbcb1 100644 --- a/src/plonk/recursive_verifier.rs +++ b/src/plonk/recursive_verifier.rs @@ -426,7 +426,7 @@ mod tests { const D: usize = 4; let config = CircuitConfig { num_wires: 126, - num_routed_wires: 37, + num_routed_wires: 64, security_bits: 128, rate_bits: 3, num_challenges: 3, @@ -434,7 +434,7 @@ mod tests { cap_height: 3, fri_config: FriConfig { proof_of_work_bits: 1, - reduction_arity_bits: vec![2, 2, 2, 2, 2, 2], + reduction_arity_bits: vec![3, 3, 3], num_query_rounds: 40, cap_height: 3, }, @@ -443,9 +443,9 @@ mod tests { let (proof_with_pis, vd, cd) = { let mut builder = CircuitBuilder::::new(config.clone()); let _two = builder.two(); - let _two = builder.hash_n_to_hash(vec![_two], true).elements[0]; + let mut _two = builder.hash_n_to_hash(vec![_two], true).elements[0]; for _ in 0..10000 { - let _two = builder.mul(_two, _two); + _two = builder.mul(_two, _two); } let data = builder.build(); ( diff --git a/src/util/reducing.rs b/src/util/reducing.rs index ce6827ba..8e8fa661 100644 --- a/src/util/reducing.rs +++ b/src/util/reducing.rs @@ -211,9 +211,23 @@ impl ReducingFactorTarget { F: Extendable, { let exp = builder.exp_u64_extension(self.base, self.count); - let tmp = builder.mul_extension(exp, x); self.count = 0; - tmp + builder.mul_extension(exp, x) + } + + pub fn shift_and_mul( + &mut self, + x: ExtensionTarget, + a: ExtensionTarget, + b: ExtensionTarget, + builder: &mut CircuitBuilder, + ) -> (ExtensionTarget, ExtensionTarget) + where + F: Extendable, + { + let exp = builder.exp_u64_extension(self.base, self.count); + self.count = 0; + builder.mul_two_extension(exp, x, a, b) } pub fn reset(&mut self) {