Remove *_three methods (#182)

* Remove *_three methods

Since there's no longer a performance reason for them, and I think the *_many methods are about as short etc.

* PR feedback
This commit is contained in:
Daniel Lubarov 2021-08-17 00:38:41 -07:00 committed by GitHub
parent 81e0acfca4
commit 69193a8dc2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 17 additions and 67 deletions

View File

@ -19,8 +19,7 @@ impl<F: Extendable<D>, const D: usize> CircuitBuilder<F, D> {
/// Computes `x^3`.
pub fn cube(&mut self, x: Target) -> Target {
let xe = self.convert_to_ext(x);
self.mul_three_extension(xe, xe, xe).to_target_array()[0]
self.mul_many(&[x, x, x])
}
/// Computes `const_0 * multiplicand_0 * multiplicand_1 + const_1 * addend`.
@ -63,8 +62,8 @@ impl<F: Extendable<D>, const D: usize> CircuitBuilder<F, D> {
self.arithmetic(F::ONE, x, one, F::ONE, y)
}
/// Add `n` `Target`s with `ceil(n/2) + 1` `ArithmeticExtensionGate`s.
// TODO: Can be made `2*D` times more efficient by using all wires of an `ArithmeticExtensionGate`.
/// Add `n` `Target`s.
// TODO: Can be made `D` times more efficient by using all wires of an `ArithmeticExtensionGate`.
pub fn add_many(&mut self, terms: &[Target]) -> Target {
let terms_ext = terms
.iter()
@ -86,7 +85,7 @@ impl<F: Extendable<D>, const D: usize> CircuitBuilder<F, D> {
self.arithmetic(F::ONE, x, y, F::ZERO, x)
}
/// Multiply `n` `Target`s with `ceil(n/2) + 1` `ArithmeticExtensionGate`s.
/// Multiply `n` `Target`s.
pub fn mul_many(&mut self, terms: &[Target]) -> Target {
let terms_ext = terms
.iter()

View File

@ -1,7 +1,5 @@
use std::convert::TryInto;
use num::Integer;
use crate::field::extension_field::target::{ExtensionAlgebraTarget, ExtensionTarget};
use crate::field::extension_field::{Extendable, OEF};
use crate::field::field_types::Field;
@ -179,35 +177,13 @@ impl<F: Extendable<D>, const D: usize> CircuitBuilder<F, D> {
a
}
/// Add 3 `ExtensionTarget`s with 1 `ArithmeticExtensionGate`s.
pub fn add_three_extension(
&mut self,
a: ExtensionTarget<D>,
b: ExtensionTarget<D>,
c: ExtensionTarget<D>,
) -> ExtensionTarget<D> {
let one = self.one_extension();
self.wide_arithmetic_extension(one, a, one, b, c)
}
/// Add `n` `ExtensionTarget`s with `n/2` `ArithmeticExtensionGate`s.
/// Add `n` `ExtensionTarget`s.
pub fn add_many_extension(&mut self, terms: &[ExtensionTarget<D>]) -> ExtensionTarget<D> {
let zero = self.zero_extension();
let mut terms = terms.to_vec();
if terms.is_empty() {
return zero;
} else if terms.len() < 3 {
terms.resize(3, zero);
} else if terms.len().is_even() {
terms.push(zero);
let mut sum = self.zero_extension();
for &term in terms {
sum = self.add_extension(sum, term);
}
let mut acc = self.add_three_extension(terms[0], terms[1], terms[2]);
terms.drain(0..3);
for chunk in terms.chunks_exact(2) {
acc = self.add_three_extension(acc, chunk[0], chunk[1]);
}
acc
sum
}
pub fn sub_extension(
@ -255,7 +231,7 @@ impl<F: Extendable<D>, const D: usize> CircuitBuilder<F, D> {
/// Computes `x^3`.
pub fn cube_extension(&mut self, x: ExtensionTarget<D>) -> ExtensionTarget<D> {
self.mul_three_extension(x, x, x)
self.mul_many_extension(&[x, x, x])
}
/// Returns `a * b + c`.
@ -298,34 +274,13 @@ impl<F: Extendable<D>, const D: usize> CircuitBuilder<F, D> {
self.mul_add_ext_algebra(a, b, zero)
}
/// Multiply 3 `ExtensionTarget`s with 1 `ArithmeticExtensionGate`s.
pub fn mul_three_extension(
&mut self,
a: ExtensionTarget<D>,
b: ExtensionTarget<D>,
c: ExtensionTarget<D>,
) -> ExtensionTarget<D> {
let tmp = self.mul_extension(a, b);
self.mul_extension(tmp, c)
}
/// Multiply `n` `ExtensionTarget`s with `n/2` `ArithmeticExtensionGate`s.
/// Multiply `n` `ExtensionTarget`s.
pub fn mul_many_extension(&mut self, terms: &[ExtensionTarget<D>]) -> ExtensionTarget<D> {
let one = self.one_extension();
let mut terms = terms.to_vec();
if terms.is_empty() {
return one;
} else if terms.len() < 3 {
terms.resize(3, one);
} else if terms.len().is_even() {
terms.push(one);
let mut product = self.one_extension();
for &term in terms {
product = self.mul_extension(product, term);
}
let mut acc = self.mul_three_extension(terms[0], terms[1], terms[2]);
terms.drain(0..3);
for chunk in terms.chunks_exact(2) {
acc = self.mul_three_extension(acc, chunk[0], chunk[1]);
}
acc
product
}
/// Like `mul_add`, but for `ExtensionTarget`s.
@ -576,12 +531,10 @@ mod tests {
}
acc
};
let mul2 = builder.mul_three_extension(ts[0], ts[1], ts[2]);
let mul3 = builder.constant_extension(vs.into_iter().product());
let mul2 = builder.constant_extension(vs.into_iter().product());
builder.assert_equal_extension(mul0, mul1);
builder.assert_equal_extension(mul1, mul2);
builder.assert_equal_extension(mul2, mul3);
let data = builder.build();
let proof = data.prove(pw)?;

View File

@ -32,8 +32,6 @@ impl<F: Extendable<D>, const D: usize> CircuitBuilder<F, D> {
#[cfg(test)]
mod tests {
use std::convert::TryInto;
use anyhow::Result;
use crate::field::crandall_field::CrandallField;

View File

@ -191,7 +191,7 @@ impl<F: Extendable<D>, const D: usize, const R: usize> Gate<F, D> for GMiMCGate<
let constant = builder.constant_extension(self.constants[r].into());
let cubing_input =
builder.add_three_extension(state[active], addition_buffer, constant);
builder.add_many_extension(&[state[active], addition_buffer, constant]);
let cubing_input_wire = vars.local_wires[Self::wire_cubing_input(r)];
constraints.push(builder.sub_extension(cubing_input, cubing_input_wire));
let f = builder.cube_extension(cubing_input_wire);