diff --git a/src/gadgets/arithmetic.rs b/src/gadgets/arithmetic.rs index 52b1d660..fd61804b 100644 --- a/src/gadgets/arithmetic.rs +++ b/src/gadgets/arithmetic.rs @@ -19,8 +19,7 @@ impl, const D: usize> CircuitBuilder { /// Computes `x^3`. pub fn cube(&mut self, x: Target) -> Target { - let xe = self.convert_to_ext(x); - self.mul_three_extension(xe, xe, xe).to_target_array()[0] + self.mul_many(&[x, x, x]) } /// Computes `const_0 * multiplicand_0 * multiplicand_1 + const_1 * addend`. @@ -63,8 +62,8 @@ impl, const D: usize> CircuitBuilder { self.arithmetic(F::ONE, x, one, F::ONE, y) } - /// Add `n` `Target`s with `ceil(n/2) + 1` `ArithmeticExtensionGate`s. - // TODO: Can be made `2*D` times more efficient by using all wires of an `ArithmeticExtensionGate`. + /// Add `n` `Target`s. + // TODO: Can be made `D` times more efficient by using all wires of an `ArithmeticExtensionGate`. pub fn add_many(&mut self, terms: &[Target]) -> Target { let terms_ext = terms .iter() @@ -86,7 +85,7 @@ impl, const D: usize> CircuitBuilder { self.arithmetic(F::ONE, x, y, F::ZERO, x) } - /// Multiply `n` `Target`s with `ceil(n/2) + 1` `ArithmeticExtensionGate`s. + /// Multiply `n` `Target`s. pub fn mul_many(&mut self, terms: &[Target]) -> Target { let terms_ext = terms .iter() diff --git a/src/gadgets/arithmetic_extension.rs b/src/gadgets/arithmetic_extension.rs index b11447df..1471b031 100644 --- a/src/gadgets/arithmetic_extension.rs +++ b/src/gadgets/arithmetic_extension.rs @@ -1,7 +1,5 @@ use std::convert::TryInto; -use num::Integer; - use crate::field::extension_field::target::{ExtensionAlgebraTarget, ExtensionTarget}; use crate::field::extension_field::{Extendable, OEF}; use crate::field::field_types::Field; @@ -179,35 +177,13 @@ impl, const D: usize> CircuitBuilder { a } - /// Add 3 `ExtensionTarget`s with 1 `ArithmeticExtensionGate`s. - pub fn add_three_extension( - &mut self, - a: ExtensionTarget, - b: ExtensionTarget, - c: ExtensionTarget, - ) -> ExtensionTarget { - let one = self.one_extension(); - self.wide_arithmetic_extension(one, a, one, b, c) - } - - /// Add `n` `ExtensionTarget`s with `n/2` `ArithmeticExtensionGate`s. + /// Add `n` `ExtensionTarget`s. pub fn add_many_extension(&mut self, terms: &[ExtensionTarget]) -> ExtensionTarget { - let zero = self.zero_extension(); - let mut terms = terms.to_vec(); - if terms.is_empty() { - return zero; - } else if terms.len() < 3 { - terms.resize(3, zero); - } else if terms.len().is_even() { - terms.push(zero); + let mut sum = self.zero_extension(); + for &term in terms { + sum = self.add_extension(sum, term); } - - let mut acc = self.add_three_extension(terms[0], terms[1], terms[2]); - terms.drain(0..3); - for chunk in terms.chunks_exact(2) { - acc = self.add_three_extension(acc, chunk[0], chunk[1]); - } - acc + sum } pub fn sub_extension( @@ -255,7 +231,7 @@ impl, const D: usize> CircuitBuilder { /// Computes `x^3`. pub fn cube_extension(&mut self, x: ExtensionTarget) -> ExtensionTarget { - self.mul_three_extension(x, x, x) + self.mul_many_extension(&[x, x, x]) } /// Returns `a * b + c`. @@ -298,34 +274,13 @@ impl, const D: usize> CircuitBuilder { self.mul_add_ext_algebra(a, b, zero) } - /// Multiply 3 `ExtensionTarget`s with 1 `ArithmeticExtensionGate`s. - pub fn mul_three_extension( - &mut self, - a: ExtensionTarget, - b: ExtensionTarget, - c: ExtensionTarget, - ) -> ExtensionTarget { - let tmp = self.mul_extension(a, b); - self.mul_extension(tmp, c) - } - - /// Multiply `n` `ExtensionTarget`s with `n/2` `ArithmeticExtensionGate`s. + /// Multiply `n` `ExtensionTarget`s. pub fn mul_many_extension(&mut self, terms: &[ExtensionTarget]) -> ExtensionTarget { - let one = self.one_extension(); - let mut terms = terms.to_vec(); - if terms.is_empty() { - return one; - } else if terms.len() < 3 { - terms.resize(3, one); - } else if terms.len().is_even() { - terms.push(one); + let mut product = self.one_extension(); + for &term in terms { + product = self.mul_extension(product, term); } - let mut acc = self.mul_three_extension(terms[0], terms[1], terms[2]); - terms.drain(0..3); - for chunk in terms.chunks_exact(2) { - acc = self.mul_three_extension(acc, chunk[0], chunk[1]); - } - acc + product } /// Like `mul_add`, but for `ExtensionTarget`s. @@ -576,12 +531,10 @@ mod tests { } acc }; - let mul2 = builder.mul_three_extension(ts[0], ts[1], ts[2]); - let mul3 = builder.constant_extension(vs.into_iter().product()); + let mul2 = builder.constant_extension(vs.into_iter().product()); builder.assert_equal_extension(mul0, mul1); builder.assert_equal_extension(mul1, mul2); - builder.assert_equal_extension(mul2, mul3); let data = builder.build(); let proof = data.prove(pw)?; diff --git a/src/gadgets/interpolation.rs b/src/gadgets/interpolation.rs index 6db70b6e..8901b9f8 100644 --- a/src/gadgets/interpolation.rs +++ b/src/gadgets/interpolation.rs @@ -32,8 +32,6 @@ impl, const D: usize> CircuitBuilder { #[cfg(test)] mod tests { - use std::convert::TryInto; - use anyhow::Result; use crate::field::crandall_field::CrandallField; diff --git a/src/gates/gmimc.rs b/src/gates/gmimc.rs index 5e8540ae..1830b9c7 100644 --- a/src/gates/gmimc.rs +++ b/src/gates/gmimc.rs @@ -191,7 +191,7 @@ impl, const D: usize, const R: usize> Gate for GMiMCGate< let constant = builder.constant_extension(self.constants[r].into()); let cubing_input = - builder.add_three_extension(state[active], addition_buffer, constant); + builder.add_many_extension(&[state[active], addition_buffer, constant]); let cubing_input_wire = vars.local_wires[Self::wire_cubing_input(r)]; constraints.push(builder.sub_extension(cubing_input, cubing_input_wire)); let f = builder.cube_extension(cubing_input_wire);