diff --git a/src/gadgets/arithmetic_extension.rs b/src/gadgets/arithmetic_extension.rs index 03395211..8b7c11e3 100644 --- a/src/gadgets/arithmetic_extension.rs +++ b/src/gadgets/arithmetic_extension.rs @@ -152,6 +152,18 @@ impl, const D: usize> CircuitBuilder { None } + /// Returns `a*b + c*d + e`. + pub fn wide_arithmetic_extension( + &mut self, + a: ExtensionTarget, + b: ExtensionTarget, + c: ExtensionTarget, + d: ExtensionTarget, + e: ExtensionTarget, + ) -> ExtensionTarget { + self.inner_product_extension(F::ONE, e, vec![(a, b), (c, d)]) + } + /// Returns `sum_{(a,b) in vecs} constant * a * b`. pub fn inner_product_extension( &mut self, @@ -230,11 +242,7 @@ impl, const D: usize> CircuitBuilder { c: ExtensionTarget, ) -> ExtensionTarget { let one = self.one_extension(); - let gate = self.num_gates(); - let first_out = - ExtensionTarget::from_range(gate, ArithmeticExtensionGate::::wires_first_output()); - self.double_arithmetic_extension(F::ONE, F::ONE, one, a, b, one, c, first_out) - .1 + self.wide_arithmetic_extension(one, a, one, b, c) } /// Add `n` `ExtensionTarget`s with `n/2` `ArithmeticExtensionGate`s. diff --git a/src/plonk/vanishing_poly.rs b/src/plonk/vanishing_poly.rs index 5743d0fa..e2180552 100644 --- a/src/plonk/vanishing_poly.rs +++ b/src/plonk/vanishing_poly.rs @@ -1,3 +1,5 @@ +use num::Integer; + use crate::field::extension_field::target::ExtensionTarget; use crate::field::extension_field::Extendable; use crate::field::field_types::Field; @@ -238,7 +240,7 @@ pub fn evaluate_gate_constraints_recursively, const D: usize>( num_gate_constraints: usize, vars: EvaluationTargets, ) -> Vec> { - let mut constraints = vec![builder.zero_extension(); num_gate_constraints]; + let mut all_gate_constraints = vec![vec![]; num_gate_constraints]; for gate in gates { let gate_constraints = with_context!( builder, @@ -248,9 +250,13 @@ pub fn evaluate_gate_constraints_recursively, const D: usize>( .eval_filtered_recursively(builder, vars, &gate.prefix) ); for (i, c) in gate_constraints.into_iter().enumerate() { - constraints[i] = builder.add_extension(constraints[i], c); + all_gate_constraints[i].push(c); } } + let mut constraints = vec![builder.zero_extension(); num_gate_constraints]; + for (i, v) in all_gate_constraints.into_iter().enumerate() { + constraints[i] = builder.add_many_extension(&v); + } constraints } @@ -274,6 +280,7 @@ pub(crate) fn eval_vanishing_poly_recursively, const D: usize>( gammas: &[Target], alphas: &[Target], ) -> Vec> { + let one = builder.one_extension(); let max_degree = common_data.quotient_degree_factor; let (num_prods, final_num_prod) = common_data.num_partial_products; @@ -297,6 +304,23 @@ pub(crate) fn eval_vanishing_poly_recursively, const D: usize>( let l1_x = eval_l_1_recursively(builder, common_data.degree(), x, x_pow_deg); + // Holds `k[i] * x`. + let mut s_ids = Vec::new(); + for j in 0..common_data.config.num_routed_wires / 2 { + let k_0 = builder.constant(common_data.k_is[2 * j]); + let k_0_ext = builder.convert_to_ext(k_0); + let k_1 = builder.constant(common_data.k_is[2 * j + 1]); + let k_1_ext = builder.convert_to_ext(k_1); + let tmp = builder.mul_two_extension(k_0_ext, x, k_1_ext, x); + s_ids.push(tmp.0); + s_ids.push(tmp.1); + } + if common_data.config.num_routed_wires.is_odd() { + let k = builder.constant(common_data.k_is[common_data.k_is.len() - 1]); + let k_ext = builder.convert_to_ext(k); + s_ids.push(builder.mul_extension(k_ext, x)); + } + for i in 0..common_data.config.num_challenges { let z_x = local_zs[i]; let z_gz = next_zs[i]; @@ -305,20 +329,19 @@ pub(crate) fn eval_vanishing_poly_recursively, const D: usize>( let numerator_values = (0..common_data.config.num_routed_wires) .map(|j| { let wire_value = vars.local_wires[j]; - let k_i = builder.constant(common_data.k_is[j]); - let s_id = builder.scalar_mul_ext(k_i, x); + let beta_ext = builder.convert_to_ext(betas[i]); let gamma_ext = builder.convert_to_ext(gammas[i]); - let tmp = builder.scalar_mul_add_extension(betas[i], s_id, wire_value); - builder.add_extension(tmp, gamma_ext) + // `beta * s_id + wire_value + gamma` + builder.wide_arithmetic_extension(beta_ext, s_ids[j], one, wire_value, gamma_ext) }) .collect::>(); let denominator_values = (0..common_data.config.num_routed_wires) .map(|j| { let wire_value = vars.local_wires[j]; - let s_sigma = s_sigmas[j]; + let beta_ext = builder.convert_to_ext(betas[i]); let gamma_ext = builder.convert_to_ext(gammas[i]); - let tmp = builder.scalar_mul_add_extension(betas[i], s_sigma, wire_value); - builder.add_extension(tmp, gamma_ext) + // `beta * s_sigma + wire_value + gamma` + builder.wide_arithmetic_extension(beta_ext, s_sigmas[j], one, wire_value, gamma_ext) }) .collect::>(); let quotient_values = (0..common_data.config.num_routed_wires)