diff --git a/src/gadgets/arithmetic_extension.rs b/src/gadgets/arithmetic_extension.rs index 9aedc2fe..e6efb451 100644 --- a/src/gadgets/arithmetic_extension.rs +++ b/src/gadgets/arithmetic_extension.rs @@ -140,21 +140,24 @@ impl, const D: usize> CircuitBuilder { .1 } - /// Add `n` `ExtensionTarget`s with `ceil(n/2) + 1` `ArithmeticExtensionGate`s. + /// Add `n` `ExtensionTarget`s with `n/2` `ArithmeticExtensionGate`s. pub fn add_many_extension(&mut self, terms: &[ExtensionTarget]) -> ExtensionTarget { let zero = self.zero_extension(); let mut terms = terms.to_vec(); - if terms.len().is_odd() { + if terms.is_empty() { + return zero; + } else if terms.len() < 3 { + terms.resize(3, zero); + } else if terms.len().is_even() { terms.push(zero); } - // We maintain two accumulators, one for the sum of even elements, and one for odd elements. - let mut acc0 = zero; - let mut acc1 = zero; + + let mut acc = self.add_three_extension(terms[0], terms[1], terms[2]); + terms.drain(0..3); for chunk in terms.chunks_exact(2) { - (acc0, acc1) = self.add_two_extension(acc0, chunk[0], acc1, chunk[1]); + acc = self.add_three_extension(acc, chunk[0], chunk[1]); } - // We sum both accumulators to get the final result. - self.add_extension(acc0, acc1) + acc } pub fn sub_extension( @@ -277,21 +280,23 @@ impl, const D: usize> CircuitBuilder { .1 } - /// Multiply `n` `ExtensionTarget`s with `ceil(n/2) + 1` `ArithmeticExtensionGate`s. + /// Multiply `n` `ExtensionTarget`s with `n/2` `ArithmeticExtensionGate`s. pub fn mul_many_extension(&mut self, terms: &[ExtensionTarget]) -> ExtensionTarget { let one = self.one_extension(); let mut terms = terms.to_vec(); - if terms.len().is_odd() { + if terms.is_empty() { + return one; + } else if terms.len() < 3 { + terms.resize(3, one); + } else if terms.len().is_even() { terms.push(one); } - // We maintain two accumulators, one for the product of even elements, and one for odd elements. - let mut acc0 = one; - let mut acc1 = one; + let mut acc = self.mul_three_extension(terms[0], terms[1], terms[2]); + terms.drain(0..3); for chunk in terms.chunks_exact(2) { - (acc0, acc1) = self.mul_two_extension(acc0, chunk[0], acc1, chunk[1]); + acc = self.mul_three_extension(acc, chunk[0], chunk[1]); } - // We multiply both accumulators to get the final result. - self.mul_extension(acc0, acc1) + acc } /// Like `mul_add`, but for `ExtensionTarget`s.