PR feedback

This commit is contained in:
wborgeaud 2021-07-21 19:23:26 +02:00
parent 59494ff8d1
commit 6cc871d30c

View File

@ -140,21 +140,24 @@ impl<F: Extendable<D>, const D: usize> CircuitBuilder<F, D> {
.1
}
/// Add `n` `ExtensionTarget`s with `ceil(n/2) + 1` `ArithmeticExtensionGate`s.
/// Add `n` `ExtensionTarget`s with `n/2` `ArithmeticExtensionGate`s.
pub fn add_many_extension(&mut self, terms: &[ExtensionTarget<D>]) -> ExtensionTarget<D> {
let zero = self.zero_extension();
let mut terms = terms.to_vec();
if terms.len().is_odd() {
if terms.is_empty() {
return zero;
} else if terms.len() < 3 {
terms.resize(3, zero);
} else if terms.len().is_even() {
terms.push(zero);
}
// We maintain two accumulators, one for the sum of even elements, and one for odd elements.
let mut acc0 = zero;
let mut acc1 = zero;
let mut acc = self.add_three_extension(terms[0], terms[1], terms[2]);
terms.drain(0..3);
for chunk in terms.chunks_exact(2) {
(acc0, acc1) = self.add_two_extension(acc0, chunk[0], acc1, chunk[1]);
acc = self.add_three_extension(acc, chunk[0], chunk[1]);
}
// We sum both accumulators to get the final result.
self.add_extension(acc0, acc1)
acc
}
pub fn sub_extension(
@ -277,21 +280,23 @@ impl<F: Extendable<D>, const D: usize> CircuitBuilder<F, D> {
.1
}
/// Multiply `n` `ExtensionTarget`s with `ceil(n/2) + 1` `ArithmeticExtensionGate`s.
/// Multiply `n` `ExtensionTarget`s with `n/2` `ArithmeticExtensionGate`s.
pub fn mul_many_extension(&mut self, terms: &[ExtensionTarget<D>]) -> ExtensionTarget<D> {
let one = self.one_extension();
let mut terms = terms.to_vec();
if terms.len().is_odd() {
if terms.is_empty() {
return one;
} else if terms.len() < 3 {
terms.resize(3, one);
} else if terms.len().is_even() {
terms.push(one);
}
// We maintain two accumulators, one for the product of even elements, and one for odd elements.
let mut acc0 = one;
let mut acc1 = one;
let mut acc = self.mul_three_extension(terms[0], terms[1], terms[2]);
terms.drain(0..3);
for chunk in terms.chunks_exact(2) {
(acc0, acc1) = self.mul_two_extension(acc0, chunk[0], acc1, chunk[1]);
acc = self.mul_three_extension(acc, chunk[0], chunk[1]);
}
// We multiply both accumulators to get the final result.
self.mul_extension(acc0, acc1)
acc
}
/// Like `mul_add`, but for `ExtensionTarget`s.