From 90613359964a64c956aa26fa6ff9aba5b55f3de3 Mon Sep 17 00:00:00 2001 From: wborgeaud Date: Thu, 12 Aug 2021 15:46:18 +0200 Subject: [PATCH 1/5] Some more arithmetic optimizations --- src/plonk/vanishing_poly.rs | 62 +++++++++++++++++++++++++++++++++---- 1 file changed, 56 insertions(+), 6 deletions(-) diff --git a/src/plonk/vanishing_poly.rs b/src/plonk/vanishing_poly.rs index 5743d0fa..6a768aed 100644 --- a/src/plonk/vanishing_poly.rs +++ b/src/plonk/vanishing_poly.rs @@ -1,6 +1,9 @@ +use num::Integer; + use crate::field::extension_field::target::ExtensionTarget; use crate::field::extension_field::Extendable; use crate::field::field_types::Field; +use crate::gates::arithmetic::ArithmeticExtensionGate; use crate::gates::gate::PrefixedGate; use crate::iop::target::Target; use crate::plonk::circuit_builder::CircuitBuilder; @@ -274,6 +277,7 @@ pub(crate) fn eval_vanishing_poly_recursively, const D: usize>( gammas: &[Target], alphas: &[Target], ) -> Vec> { + let one = builder.one_extension(); let max_degree = common_data.quotient_degree_factor; let (num_prods, final_num_prod) = common_data.num_partial_products; @@ -297,6 +301,22 @@ pub(crate) fn eval_vanishing_poly_recursively, const D: usize>( let l1_x = eval_l_1_recursively(builder, common_data.degree(), x, x_pow_deg); + let mut s_ids = Vec::new(); + for j in 0..common_data.config.num_routed_wires / 2 { + let k_0 = builder.constant(common_data.k_is[2 * j]); + let k_0_ext = builder.convert_to_ext(k_0); + let k_1 = builder.constant(common_data.k_is[2 * j + 1]); + let k_1_ext = builder.convert_to_ext(k_1); + let tmp = builder.mul_two_extension(k_0_ext, x, k_1_ext, x); + s_ids.push(tmp.0); + s_ids.push(tmp.1); + } + if common_data.config.num_routed_wires.is_odd() { + let k = builder.constant(common_data.k_is[common_data.k_is.len() - 1]); + let k_ext = builder.convert_to_ext(k); + s_ids.push(builder.mul_extension(k_ext, x)); + } + for i in 0..common_data.config.num_challenges { let z_x = local_zs[i]; let z_gz = next_zs[i]; @@ -305,20 +325,50 @@ pub(crate) fn eval_vanishing_poly_recursively, const D: usize>( let numerator_values = (0..common_data.config.num_routed_wires) .map(|j| { let wire_value = vars.local_wires[j]; - let k_i = builder.constant(common_data.k_is[j]); - let s_id = builder.scalar_mul_ext(k_i, x); + let beta_ext = builder.convert_to_ext(betas[i]); let gamma_ext = builder.convert_to_ext(gammas[i]); - let tmp = builder.scalar_mul_add_extension(betas[i], s_id, wire_value); - builder.add_extension(tmp, gamma_ext) + let gate = builder.num_gates(); + let first_out = ExtensionTarget::from_range( + gate, + ArithmeticExtensionGate::::wires_first_output(), + ); + builder + .double_arithmetic_extension( + F::ONE, + F::ONE, + beta_ext, + s_ids[j], + wire_value, + one, + first_out, + gamma_ext, + ) + .1 }) .collect::>(); let denominator_values = (0..common_data.config.num_routed_wires) .map(|j| { let wire_value = vars.local_wires[j]; let s_sigma = s_sigmas[j]; + let beta_ext = builder.convert_to_ext(betas[i]); let gamma_ext = builder.convert_to_ext(gammas[i]); - let tmp = builder.scalar_mul_add_extension(betas[i], s_sigma, wire_value); - builder.add_extension(tmp, gamma_ext) + let gate = builder.num_gates(); + let first_out = ExtensionTarget::from_range( + gate, + ArithmeticExtensionGate::::wires_first_output(), + ); + builder + .double_arithmetic_extension( + F::ONE, + F::ONE, + beta_ext, + s_sigma, + wire_value, + one, + first_out, + gamma_ext, + ) + .1 }) .collect::>(); let quotient_values = (0..common_data.config.num_routed_wires) From 08e457458d5eb4ca8cad0c09db0504d002eed1df Mon Sep 17 00:00:00 2001 From: wborgeaud Date: Thu, 12 Aug 2021 15:48:45 +0200 Subject: [PATCH 2/5] Comments --- src/plonk/vanishing_poly.rs | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/plonk/vanishing_poly.rs b/src/plonk/vanishing_poly.rs index 6a768aed..30101d3f 100644 --- a/src/plonk/vanishing_poly.rs +++ b/src/plonk/vanishing_poly.rs @@ -301,6 +301,7 @@ pub(crate) fn eval_vanishing_poly_recursively, const D: usize>( let l1_x = eval_l_1_recursively(builder, common_data.degree(), x, x_pow_deg); + // Holds `k[i] * x`. let mut s_ids = Vec::new(); for j in 0..common_data.config.num_routed_wires / 2 { let k_0 = builder.constant(common_data.k_is[2 * j]); @@ -332,6 +333,7 @@ pub(crate) fn eval_vanishing_poly_recursively, const D: usize>( gate, ArithmeticExtensionGate::::wires_first_output(), ); + // `beta * s_ids[j] + wire_value + gamma` builder .double_arithmetic_extension( F::ONE, @@ -357,6 +359,7 @@ pub(crate) fn eval_vanishing_poly_recursively, const D: usize>( gate, ArithmeticExtensionGate::::wires_first_output(), ); + // `beta * s_sigma + wire_value + gamma` builder .double_arithmetic_extension( F::ONE, From 702eab158345bfde6931cc641fc59401b88d2b98 Mon Sep 17 00:00:00 2001 From: wborgeaud Date: Thu, 12 Aug 2021 16:03:13 +0200 Subject: [PATCH 3/5] Add `wide_arithmetic` --- src/gadgets/arithmetic_extension.rs | 18 +++++++++---- src/plonk/vanishing_poly.rs | 39 +++-------------------------- 2 files changed, 16 insertions(+), 41 deletions(-) diff --git a/src/gadgets/arithmetic_extension.rs b/src/gadgets/arithmetic_extension.rs index 03395211..8b7c11e3 100644 --- a/src/gadgets/arithmetic_extension.rs +++ b/src/gadgets/arithmetic_extension.rs @@ -152,6 +152,18 @@ impl, const D: usize> CircuitBuilder { None } + /// Returns `a*b + c*d + e`. + pub fn wide_arithmetic_extension( + &mut self, + a: ExtensionTarget, + b: ExtensionTarget, + c: ExtensionTarget, + d: ExtensionTarget, + e: ExtensionTarget, + ) -> ExtensionTarget { + self.inner_product_extension(F::ONE, e, vec![(a, b), (c, d)]) + } + /// Returns `sum_{(a,b) in vecs} constant * a * b`. pub fn inner_product_extension( &mut self, @@ -230,11 +242,7 @@ impl, const D: usize> CircuitBuilder { c: ExtensionTarget, ) -> ExtensionTarget { let one = self.one_extension(); - let gate = self.num_gates(); - let first_out = - ExtensionTarget::from_range(gate, ArithmeticExtensionGate::::wires_first_output()); - self.double_arithmetic_extension(F::ONE, F::ONE, one, a, b, one, c, first_out) - .1 + self.wide_arithmetic_extension(one, a, one, b, c) } /// Add `n` `ExtensionTarget`s with `n/2` `ArithmeticExtensionGate`s. diff --git a/src/plonk/vanishing_poly.rs b/src/plonk/vanishing_poly.rs index 30101d3f..a0e00395 100644 --- a/src/plonk/vanishing_poly.rs +++ b/src/plonk/vanishing_poly.rs @@ -328,50 +328,17 @@ pub(crate) fn eval_vanishing_poly_recursively, const D: usize>( let wire_value = vars.local_wires[j]; let beta_ext = builder.convert_to_ext(betas[i]); let gamma_ext = builder.convert_to_ext(gammas[i]); - let gate = builder.num_gates(); - let first_out = ExtensionTarget::from_range( - gate, - ArithmeticExtensionGate::::wires_first_output(), - ); - // `beta * s_ids[j] + wire_value + gamma` - builder - .double_arithmetic_extension( - F::ONE, - F::ONE, - beta_ext, - s_ids[j], - wire_value, - one, - first_out, - gamma_ext, - ) - .1 + // `beta * s_id + wire_value + gamma` + builder.wide_arithmetic_extension(beta_ext, s_ids[j], one, wire_value, gamma_ext) }) .collect::>(); let denominator_values = (0..common_data.config.num_routed_wires) .map(|j| { let wire_value = vars.local_wires[j]; - let s_sigma = s_sigmas[j]; let beta_ext = builder.convert_to_ext(betas[i]); let gamma_ext = builder.convert_to_ext(gammas[i]); - let gate = builder.num_gates(); - let first_out = ExtensionTarget::from_range( - gate, - ArithmeticExtensionGate::::wires_first_output(), - ); // `beta * s_sigma + wire_value + gamma` - builder - .double_arithmetic_extension( - F::ONE, - F::ONE, - beta_ext, - s_sigma, - wire_value, - one, - first_out, - gamma_ext, - ) - .1 + builder.wide_arithmetic_extension(beta_ext, s_sigmas[j], one, wire_value, gamma_ext) }) .collect::>(); let quotient_values = (0..common_data.config.num_routed_wires) From 7271af823b16edb6d70e64a4a55864b298d2eaa5 Mon Sep 17 00:00:00 2001 From: wborgeaud Date: Thu, 12 Aug 2021 16:48:13 +0200 Subject: [PATCH 4/5] Optimize `evaluate_gate_constraints_recurively` --- src/plonk/vanishing_poly.rs | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/src/plonk/vanishing_poly.rs b/src/plonk/vanishing_poly.rs index a0e00395..73eaadc7 100644 --- a/src/plonk/vanishing_poly.rs +++ b/src/plonk/vanishing_poly.rs @@ -3,7 +3,6 @@ use num::Integer; use crate::field::extension_field::target::ExtensionTarget; use crate::field::extension_field::Extendable; use crate::field::field_types::Field; -use crate::gates::arithmetic::ArithmeticExtensionGate; use crate::gates::gate::PrefixedGate; use crate::iop::target::Target; use crate::plonk::circuit_builder::CircuitBuilder; @@ -242,6 +241,7 @@ pub fn evaluate_gate_constraints_recursively, const D: usize>( vars: EvaluationTargets, ) -> Vec> { let mut constraints = vec![builder.zero_extension(); num_gate_constraints]; + let mut all_gate_constraints: Vec> = Vec::new(); for gate in gates { let gate_constraints = with_context!( builder, @@ -251,9 +251,16 @@ pub fn evaluate_gate_constraints_recursively, const D: usize>( .eval_filtered_recursively(builder, vars, &gate.prefix) ); for (i, c) in gate_constraints.into_iter().enumerate() { - constraints[i] = builder.add_extension(constraints[i], c); + if i < all_gate_constraints.len() { + all_gate_constraints[i].push(c); + } else { + all_gate_constraints.push(vec![c]); + } } } + for (i, v) in all_gate_constraints.into_iter().enumerate() { + constraints[i] = builder.add_many_extension(&v); + } constraints } From 2bfa45447628aa88c70ee6d9fcb9c3f967342c1c Mon Sep 17 00:00:00 2001 From: wborgeaud Date: Thu, 12 Aug 2021 18:21:21 +0200 Subject: [PATCH 5/5] PR feedback --- src/plonk/vanishing_poly.rs | 10 +++------- 1 file changed, 3 insertions(+), 7 deletions(-) diff --git a/src/plonk/vanishing_poly.rs b/src/plonk/vanishing_poly.rs index 73eaadc7..e2180552 100644 --- a/src/plonk/vanishing_poly.rs +++ b/src/plonk/vanishing_poly.rs @@ -240,8 +240,7 @@ pub fn evaluate_gate_constraints_recursively, const D: usize>( num_gate_constraints: usize, vars: EvaluationTargets, ) -> Vec> { - let mut constraints = vec![builder.zero_extension(); num_gate_constraints]; - let mut all_gate_constraints: Vec> = Vec::new(); + let mut all_gate_constraints = vec![vec![]; num_gate_constraints]; for gate in gates { let gate_constraints = with_context!( builder, @@ -251,13 +250,10 @@ pub fn evaluate_gate_constraints_recursively, const D: usize>( .eval_filtered_recursively(builder, vars, &gate.prefix) ); for (i, c) in gate_constraints.into_iter().enumerate() { - if i < all_gate_constraints.len() { - all_gate_constraints[i].push(c); - } else { - all_gate_constraints.push(vec![c]); - } + all_gate_constraints[i].push(c); } } + let mut constraints = vec![builder.zero_extension(); num_gate_constraints]; for (i, v) in all_gate_constraints.into_iter().enumerate() { constraints[i] = builder.add_many_extension(&v); }