From 4b44578ffa3deda557e05938d43b52788dd36ca1 Mon Sep 17 00:00:00 2001 From: wborgeaud Date: Mon, 9 Aug 2021 12:39:37 +0200 Subject: [PATCH] More optimizations --- src/field/extension_field/target.rs | 4 ++ src/gadgets/arithmetic_extension.rs | 57 +++++++++++++++++++++++------ src/gadgets/polynomial.rs | 6 ++- src/gates/gmimc.rs | 16 +++++++- src/gates/reducing.rs | 3 +- 5 files changed, 69 insertions(+), 17 deletions(-) diff --git a/src/field/extension_field/target.rs b/src/field/extension_field/target.rs index c7f9dc39..683afe0b 100644 --- a/src/field/extension_field/target.rs +++ b/src/field/extension_field/target.rs @@ -105,6 +105,10 @@ impl, const D: usize> CircuitBuilder { self.constant_extension(F::Extension::TWO) } + pub fn neg_one_extension(&mut self) -> ExtensionTarget { + self.constant_extension(F::Extension::NEG_ONE) + } + pub fn zero_ext_algebra(&mut self) -> ExtensionAlgebraTarget { self.constant_ext_algebra(ExtensionAlgebra::ZERO) } diff --git a/src/gadgets/arithmetic_extension.rs b/src/gadgets/arithmetic_extension.rs index 9aaaf857..3a494cbc 100644 --- a/src/gadgets/arithmetic_extension.rs +++ b/src/gadgets/arithmetic_extension.rs @@ -347,12 +347,13 @@ impl, const D: usize> CircuitBuilder { self.mul_three_extension(x, x, x) } - pub fn mul_ext_algebra( + /// Returns `a*b + c`. + pub fn mul_add_ext_algebra( &mut self, a: ExtensionAlgebraTarget, b: ExtensionAlgebraTarget, + c: ExtensionAlgebraTarget, ) -> ExtensionAlgebraTarget { - let zero = self.zero_extension(); let mut inner = vec![vec![]; D]; let mut inner_w = vec![vec![]; D]; for i in 0..D { @@ -366,14 +367,23 @@ impl, const D: usize> CircuitBuilder { let res = inner_w .into_iter() .zip(inner) - .map(|(vecs_w, vecs)| { - let acc = self.inner_product_extension(F::Extension::W, zero, vecs_w); - self.inner_product_extension(F::ONE, acc, vecs) + .zip(c.0) + .map(|((pairs_w, pairs), ci)| { + let acc = self.inner_product_extension(F::Extension::W, ci, pairs_w); + self.inner_product_extension(F::ONE, acc, pairs) }) .collect::>(); ExtensionAlgebraTarget(res.try_into().unwrap()) } + pub fn mul_ext_algebra( + &mut self, + a: ExtensionAlgebraTarget, + b: ExtensionAlgebraTarget, + ) -> ExtensionAlgebraTarget { + let zero = self.zero_ext_algebra(); + self.mul_add_ext_algebra(a, b, zero) + } /// Multiply 3 `ExtensionTarget`s with 1 `ArithmeticExtensionGate`s. pub fn mul_three_extension( @@ -457,17 +467,42 @@ impl, const D: usize> CircuitBuilder { self.mul_extension(a_ext, b) } - /// Returns `a * b`, where `b` is in the extension of the extension field, and `a` is in the + /// Returns `a * b + c`, where `b,c` are in the extension algebra and `a` in the extension field. /// extension field. + pub fn scalar_mul_add_ext_algebra( + &mut self, + a: ExtensionTarget, + b: ExtensionAlgebraTarget, + mut c: ExtensionAlgebraTarget, + ) -> ExtensionAlgebraTarget { + for i in 0..D / 2 { + let res = self.double_arithmetic_extension( + F::ONE, + F::ONE, + a, + b.0[2 * i], + c.0[2 * i], + a, + b.0[2 * i + 1], + c.0[2 * i + 1], + ); + c.0[2 * i] = res.0; + c.0[2 * i + 1] = res.1; + } + if D.is_odd() { + c.0[D - 1] = self.arithmetic_extension(F::ONE, F::ONE, a, b.0[D - 1], c.0[D - 1]); + } + c + } + + /// Returns `a * b`, where `b,c` are in the extension algebra and `a` in the extension field. pub fn scalar_mul_ext_algebra( &mut self, a: ExtensionTarget, - mut b: ExtensionAlgebraTarget, + b: ExtensionAlgebraTarget, ) -> ExtensionAlgebraTarget { - for i in 0..D { - b.0[i] = self.mul_extension(a, b.0[i]); - } - b + let zero = self.zero_ext_algebra(); + self.scalar_mul_add_ext_algebra(a, b, zero) } /// Exponentiate `base` to the power of `2^power_log`. diff --git a/src/gadgets/polynomial.rs b/src/gadgets/polynomial.rs index a83cbcd4..089046e6 100644 --- a/src/gadgets/polynomial.rs +++ b/src/gadgets/polynomial.rs @@ -50,8 +50,10 @@ impl PolynomialCoeffsExtAlgebraTarget { { let mut acc = builder.zero_ext_algebra(); for &c in self.0.iter().rev() { - let tmp = builder.scalar_mul_ext_algebra(point, acc); - acc = builder.add_ext_algebra(tmp, c); + // let tmp = builder.scalar_mul_ext_algebra(point, acc); + // acc = builder.add_ext_algebra(tmp, c); + acc = builder.scalar_mul_add_ext_algebra(point, acc, c); + // acc = builder.add_ext_algebra(tmp, c); } acc } diff --git a/src/gates/gmimc.rs b/src/gates/gmimc.rs index 311ba841..ea03f8a7 100644 --- a/src/gates/gmimc.rs +++ b/src/gates/gmimc.rs @@ -160,6 +160,8 @@ impl, const D: usize, const R: usize> Gate for GMiMCGate< builder: &mut CircuitBuilder, vars: EvaluationTargets, ) -> Vec> { + let one = builder.one_extension(); + let neg_one = builder.neg_one_extension(); let mut constraints = Vec::with_capacity(self.num_constraints()); let swap = vars.local_wires[Self::WIRE_SWAP]; @@ -195,8 +197,18 @@ impl, const D: usize, const R: usize> Gate for GMiMCGate< let cubing_input_wire = vars.local_wires[Self::wire_cubing_input(r)]; constraints.push(builder.sub_extension(cubing_input, cubing_input_wire)); let f = builder.cube_extension(cubing_input_wire); - addition_buffer = builder.add_extension(addition_buffer, f); - state[active] = builder.sub_extension(state[active], f); + let tmp = builder.double_arithmetic_extension( + F::ONE, + F::ONE, + one, + addition_buffer, + f, + neg_one, + f, + state[active], + ); + addition_buffer = tmp.0; + state[active] = tmp.1; } for i in 0..W { diff --git a/src/gates/reducing.rs b/src/gates/reducing.rs index cfdcaf17..f3799b10 100644 --- a/src/gates/reducing.rs +++ b/src/gates/reducing.rs @@ -121,9 +121,8 @@ impl, const D: usize> Gate for ReducingGate { let mut constraints = Vec::new(); let mut acc = old_acc; for i in 0..self.num_coeffs { - let mut tmp = builder.mul_ext_algebra(acc, alpha); let coeff = builder.convert_to_ext_algebra(coeffs[i]); - tmp = builder.add_ext_algebra(tmp, coeff); + let mut tmp = builder.mul_add_ext_algebra(acc, alpha, coeff); tmp = builder.sub_ext_algebra(tmp, accs[i]); constraints.push(tmp); acc = accs[i];