Merge pull request #166 from mir-protocol/optimize_arithmetic_ops

Optimize some arithmetic operations
This commit is contained in:
wborgeaud 2021-08-10 09:25:59 +02:00 committed by GitHub
commit 2cf82636f8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
11 changed files with 243 additions and 67 deletions

View File

@ -7,7 +7,7 @@ use crate::field::extension_field::OEF;
/// Let `F_D` be the optimal extension field `F[X]/(X^D-W)`. Then `ExtensionAlgebra<F_D>` is the quotient `F_D[X]/(X^D-W)`.
/// It's a `D`-dimensional algebra over `F_D` useful to lift the multiplication over `F_D` to a multiplication over `(F_D)^D`.
#[derive(Copy, Clone)]
pub struct ExtensionAlgebra<F: OEF<D>, const D: usize>([F; D]);
pub struct ExtensionAlgebra<F: OEF<D>, const D: usize>(pub [F; D]);
impl<F: OEF<D>, const D: usize> ExtensionAlgebra<F, D> {
pub const ZERO: Self = Self([F::ZERO; D]);

View File

@ -105,6 +105,10 @@ impl<F: Extendable<D>, const D: usize> CircuitBuilder<F, D> {
self.constant_extension(F::Extension::TWO)
}
pub fn neg_one_extension(&mut self) -> ExtensionTarget<D> {
self.constant_extension(F::Extension::NEG_ONE)
}
pub fn zero_ext_algebra(&mut self) -> ExtensionAlgebraTarget<D> {
self.constant_ext_algebra(ExtensionAlgebra::ZERO)
}

View File

@ -40,14 +40,23 @@ impl<F: Extendable<D>, const D: usize> CircuitBuilder<F, D> {
let coset_start = self.mul(start, x);
// The answer is gotten by interpolating {(x*g^i, P(x*g^i))} and evaluating at beta.
let points = g
let g_powers = g
.powers()
.map(|y| {
let yt = self.constant(y);
self.mul(coset_start, yt)
})
.zip(evals)
.take(arity)
.map(|y| self.constant(y))
.collect::<Vec<_>>();
let mut coset = Vec::new();
for i in 0..arity / 2 {
let res = self.mul_two(
coset_start,
g_powers[2 * i],
coset_start,
g_powers[2 * i + 1],
);
coset.push(res.0);
coset.push(res.1);
}
let points = coset.into_iter().zip(evals).collect::<Vec<_>>();
self.interpolate(&points, beta)
}
@ -195,6 +204,7 @@ impl<F: Extendable<D>, const D: usize> CircuitBuilder<F, D> {
assert!(D > 1, "Not implemented for D=1.");
let config = self.config.clone();
let degree_log = proof.evals_proofs[0].1.siblings.len() - config.rate_bits;
let one = self.one_extension();
let subgroup_x = self.convert_to_ext(subgroup_x);
let vanish_zeta = self.sub_extension(subgroup_x, zeta);
let mut alpha = ReducingFactorTarget::new(alpha);
@ -223,8 +233,7 @@ impl<F: Extendable<D>, const D: usize> CircuitBuilder<F, D> {
self.sub_extension(single_composition_eval, precomputed_reduced_evals.single);
// This division is safe because the denominator will be nonzero unless zeta is in the
// codeword domain, which occurs with negligible probability given a large extension field.
let quotient = self.div_unsafe_extension(single_numerator, vanish_zeta);
sum = self.add_extension(sum, quotient);
sum = self.div_add_extension(single_numerator, vanish_zeta, sum);
alpha.reset();
// Polynomials opened at `x` and `g x`, i.e., the Zs polynomials.
@ -245,14 +254,13 @@ impl<F: Extendable<D>, const D: usize> CircuitBuilder<F, D> {
],
subgroup_x,
);
let zs_numerator = self.sub_extension(zs_composition_eval, interpol_val);
let vanish_zeta_right = self.sub_extension(subgroup_x, zeta_right);
let (zs_numerator, vanish_zeta_right) =
self.sub_two_extension(zs_composition_eval, interpol_val, subgroup_x, zeta_right);
let zs_denominator = self.mul_extension(vanish_zeta, vanish_zeta_right);
sum = alpha.shift(sum, self);
// This division is safe because the denominator will be nonzero unless zeta is in the
// codeword domain, which occurs with negligible probability given a large extension field.
let zs_quotient = self.div_unsafe_extension(zs_numerator, zs_denominator);
sum = alpha.shift(sum, self);
sum = self.add_extension(sum, zs_quotient);
sum = self.div_add_extension(zs_numerator, zs_denominator, sum);
sum
}

View File

@ -86,6 +86,16 @@ impl<F: Extendable<D>, const D: usize> CircuitBuilder<F, D> {
self.arithmetic(F::ONE, x, y, F::ZERO, x)
}
/// Computes `x * y`.
pub fn mul_two(&mut self, a0: Target, b0: Target, a1: Target, b1: Target) -> (Target, Target) {
let a0_ext = self.convert_to_ext(a0);
let b0_ext = self.convert_to_ext(b0);
let a1_ext = self.convert_to_ext(a1);
let b1_ext = self.convert_to_ext(b1);
let res = self.mul_two_extension(a0_ext, b0_ext, a1_ext, b1_ext);
(res.0 .0[0], res.1 .0[0])
}
/// Multiply `n` `Target`s with `ceil(n/2) + 1` `ArithmeticExtensionGate`s.
pub fn mul_many(&mut self, terms: &[Target]) -> Target {
let terms_ext = terms

View File

@ -152,6 +152,33 @@ impl<F: Extendable<D>, const D: usize> CircuitBuilder<F, D> {
None
}
/// Returns `sum_{(a,b) in vecs} constant * a * b`.
pub fn inner_product_extension(
&mut self,
constant: F,
starting_acc: ExtensionTarget<D>,
pairs: Vec<(ExtensionTarget<D>, ExtensionTarget<D>)>,
) -> ExtensionTarget<D> {
let mut acc = starting_acc;
for chunk in pairs.chunks_exact(2) {
let (a0, b0) = chunk[0];
let (a1, b1) = chunk[1];
let gate = self.num_gates();
let first_out = ExtensionTarget::from_range(
gate,
ArithmeticExtensionGate::<D>::wires_first_output(),
);
acc = self
.double_arithmetic_extension(constant, F::ONE, a0, b0, acc, a1, b1, first_out)
.1;
}
if pairs.len().is_odd() {
let n = pairs.len() - 1;
acc = self.arithmetic_extension(constant, F::ONE, pairs[n].0, pairs[n].1, acc);
}
acc
}
pub fn add_extension(
&mut self,
a: ExtensionTarget<D>,
@ -320,24 +347,44 @@ impl<F: Extendable<D>, const D: usize> CircuitBuilder<F, D> {
self.mul_three_extension(x, x, x)
}
/// Returns `a * b + c`.
pub fn mul_add_ext_algebra(
&mut self,
a: ExtensionAlgebraTarget<D>,
b: ExtensionAlgebraTarget<D>,
c: ExtensionAlgebraTarget<D>,
) -> ExtensionAlgebraTarget<D> {
let mut inner = vec![vec![]; D];
let mut inner_w = vec![vec![]; D];
for i in 0..D {
for j in 0..D - i {
inner[(i + j) % D].push((a.0[i], b.0[j]));
}
for j in D - i..D {
inner_w[(i + j) % D].push((a.0[i], b.0[j]));
}
}
let res = inner_w
.into_iter()
.zip(inner)
.zip(c.0)
.map(|((pairs_w, pairs), ci)| {
let acc = self.inner_product_extension(F::Extension::W, ci, pairs_w);
self.inner_product_extension(F::ONE, acc, pairs)
})
.collect::<Vec<_>>();
ExtensionAlgebraTarget(res.try_into().unwrap())
}
/// Returns `a * b`.
pub fn mul_ext_algebra(
&mut self,
a: ExtensionAlgebraTarget<D>,
b: ExtensionAlgebraTarget<D>,
) -> ExtensionAlgebraTarget<D> {
let mut res = [self.zero_extension(); D];
let w = self.constant(F::Extension::W);
for i in 0..D {
for j in 0..D {
res[(i + j) % D] = if i + j < D {
self.mul_add_extension(a.0[i], b.0[j], res[(i + j) % D])
} else {
let ai_bi = self.mul_extension(a.0[i], b.0[j]);
self.scalar_mul_add_extension(w, ai_bi, res[(i + j) % D])
}
}
}
ExtensionAlgebraTarget(res)
let zero = self.zero_ext_algebra();
self.mul_add_ext_algebra(a, b, zero)
}
/// Multiply 3 `ExtensionTarget`s with 1 `ArithmeticExtensionGate`s.
@ -422,17 +469,41 @@ impl<F: Extendable<D>, const D: usize> CircuitBuilder<F, D> {
self.mul_extension(a_ext, b)
}
/// Returns `a * b`, where `b` is in the extension of the extension field, and `a` is in the
/// extension field.
/// Returns `a * b + c`, where `b, c` are in the extension algebra and `a` in the extension field.
pub fn scalar_mul_add_ext_algebra(
&mut self,
a: ExtensionTarget<D>,
b: ExtensionAlgebraTarget<D>,
mut c: ExtensionAlgebraTarget<D>,
) -> ExtensionAlgebraTarget<D> {
for i in 0..D / 2 {
let res = self.double_arithmetic_extension(
F::ONE,
F::ONE,
a,
b.0[2 * i],
c.0[2 * i],
a,
b.0[2 * i + 1],
c.0[2 * i + 1],
);
c.0[2 * i] = res.0;
c.0[2 * i + 1] = res.1;
}
if D.is_odd() {
c.0[D - 1] = self.arithmetic_extension(F::ONE, F::ONE, a, b.0[D - 1], c.0[D - 1]);
}
c
}
/// Returns `a * b`, where `b` is in the extension algebra and `a` in the extension field.
pub fn scalar_mul_ext_algebra(
&mut self,
a: ExtensionTarget<D>,
mut b: ExtensionAlgebraTarget<D>,
b: ExtensionAlgebraTarget<D>,
) -> ExtensionAlgebraTarget<D> {
for i in 0..D {
b.0[i] = self.mul_extension(a, b.0[i]);
}
b
let zero = self.zero_ext_algebra();
self.scalar_mul_add_ext_algebra(a, b, zero)
}
/// Exponentiate `base` to the power of `2^power_log`.
@ -480,8 +551,43 @@ impl<F: Extendable<D>, const D: usize> CircuitBuilder<F, D> {
x: ExtensionTarget<D>,
y: ExtensionTarget<D>,
) -> ExtensionTarget<D> {
let y_inv = self.inverse_extension(y);
self.mul_extension(x, y_inv)
let inv = self.add_virtual_extension_target();
let one = self.one_extension();
self.add_generator(QuotientGeneratorExtension {
numerator: one,
denominator: y,
quotient: inv,
});
// Enforce that x times its purported inverse equals 1.
let (y_inv, res) = self.mul_two_extension(y, inv, x, inv);
self.assert_equal_extension(y_inv, one);
res
}
/// Computes ` x / y + z`.
pub fn div_add_extension(
&mut self,
x: ExtensionTarget<D>,
y: ExtensionTarget<D>,
z: ExtensionTarget<D>,
) -> ExtensionTarget<D> {
let inv = self.add_virtual_extension_target();
let zero = self.zero_extension();
let one = self.one_extension();
self.add_generator(QuotientGeneratorExtension {
numerator: one,
denominator: y,
quotient: inv,
});
// Enforce that x times its purported inverse equals 1.
let (y_inv, res) =
self.double_arithmetic_extension(F::ONE, F::ONE, y, inv, zero, x, inv, z);
self.assert_equal_extension(y_inv, one);
res
}
/// Computes `q = x / y` by witnessing `q` and requiring that `q * y = x`. This can be unsafe in
@ -585,9 +691,12 @@ impl<F: Extendable<D>, const D: usize> CircuitBuilder<F, D> {
#[cfg(test)]
mod tests {
use std::convert::TryInto;
use anyhow::Result;
use crate::field::crandall_field::CrandallField;
use crate::field::extension_field::algebra::ExtensionAlgebra;
use crate::field::extension_field::quartic::QuarticCrandallField;
use crate::field::field_types::Field;
use crate::iop::witness::PartialWitness;
@ -659,4 +768,35 @@ mod tests {
verify(proof, &data.verifier_only, &data.common)
}
#[test]
fn test_mul_algebra() -> Result<()> {
type F = CrandallField;
type FF = QuarticCrandallField;
const D: usize = 4;
let config = CircuitConfig::large_config();
let pw = PartialWitness::new(config.num_wires);
let mut builder = CircuitBuilder::<F, D>::new(config);
let x = FF::rand_vec(4);
let y = FF::rand_vec(4);
let xa = ExtensionAlgebra(x.try_into().unwrap());
let ya = ExtensionAlgebra(y.try_into().unwrap());
let za = xa * ya;
let xt = builder.constant_ext_algebra(xa);
let yt = builder.constant_ext_algebra(ya);
let zt = builder.constant_ext_algebra(za);
let comp_zt = builder.mul_ext_algebra(xt, yt);
for i in 0..D {
builder.assert_equal_extension(zt.0[i], comp_zt.0[i]);
}
let data = builder.build();
let proof = data.prove(pw)?;
verify(proof, &data.verifier_only, &data.common)
}
}

View File

@ -16,8 +16,12 @@ impl<F: Extendable<D>, const D: usize> CircuitBuilder<F, D> {
// b0 -> b1
// x -> a1 + (x-a0)*(b1-a1)/(b0-a0)
let x_m_a0 = self.sub_extension(evaluation_point, interpolation_points[0].0);
let b1_m_a1 = self.sub_extension(interpolation_points[1].1, interpolation_points[0].1);
let (x_m_a0, b1_m_a1) = self.sub_two_extension(
evaluation_point,
interpolation_points[0].0,
interpolation_points[1].1,
interpolation_points[0].1,
);
let b0_m_a0 = self.sub_extension(interpolation_points[1].0, interpolation_points[0].0);
let quotient = self.div_extension(b1_m_a1, b0_m_a0);

View File

@ -2,6 +2,7 @@ use crate::field::extension_field::target::{ExtensionAlgebraTarget, ExtensionTar
use crate::field::extension_field::Extendable;
use crate::iop::target::Target;
use crate::plonk::circuit_builder::CircuitBuilder;
use crate::util::reducing::ReducingFactorTarget;
pub struct PolynomialCoeffsExtTarget<const D: usize>(pub Vec<ExtensionTarget<D>>);
@ -15,12 +16,9 @@ impl<const D: usize> PolynomialCoeffsExtTarget<D> {
builder: &mut CircuitBuilder<F, D>,
point: Target,
) -> ExtensionTarget<D> {
let mut acc = builder.zero_extension();
for &c in self.0.iter().rev() {
let tmp = builder.scalar_mul_ext(point, acc);
acc = builder.add_extension(tmp, c);
}
acc
let point = builder.convert_to_ext(point);
let mut point = ReducingFactorTarget::new(point);
point.reduce(&self.0, builder)
}
pub fn eval<F: Extendable<D>>(
@ -28,12 +26,8 @@ impl<const D: usize> PolynomialCoeffsExtTarget<D> {
builder: &mut CircuitBuilder<F, D>,
point: ExtensionTarget<D>,
) -> ExtensionTarget<D> {
let mut acc = builder.zero_extension();
for &c in self.0.iter().rev() {
let tmp = builder.mul_extension(point, acc);
acc = builder.add_extension(tmp, c);
}
acc
let mut point = ReducingFactorTarget::new(point);
point.reduce(&self.0, builder)
}
}
@ -50,8 +44,7 @@ impl<const D: usize> PolynomialCoeffsExtAlgebraTarget<D> {
{
let mut acc = builder.zero_ext_algebra();
for &c in self.0.iter().rev() {
let tmp = builder.scalar_mul_ext_algebra(point, acc);
acc = builder.add_ext_algebra(tmp, c);
acc = builder.scalar_mul_add_ext_algebra(point, acc, c);
}
acc
}
@ -66,8 +59,7 @@ impl<const D: usize> PolynomialCoeffsExtAlgebraTarget<D> {
{
let mut acc = builder.zero_ext_algebra();
for &c in self.0.iter().rev() {
let tmp = builder.mul_ext_algebra(point, acc);
acc = builder.add_ext_algebra(tmp, c);
acc = builder.mul_add_ext_algebra(point, acc, c);
}
acc
}

View File

@ -160,6 +160,8 @@ impl<F: Extendable<D>, const D: usize, const R: usize> Gate<F, D> for GMiMCGate<
builder: &mut CircuitBuilder<F, D>,
vars: EvaluationTargets<D>,
) -> Vec<ExtensionTarget<D>> {
let one = builder.one_extension();
let neg_one = builder.neg_one_extension();
let mut constraints = Vec::with_capacity(self.num_constraints());
let swap = vars.local_wires[Self::WIRE_SWAP];
@ -195,8 +197,18 @@ impl<F: Extendable<D>, const D: usize, const R: usize> Gate<F, D> for GMiMCGate<
let cubing_input_wire = vars.local_wires[Self::wire_cubing_input(r)];
constraints.push(builder.sub_extension(cubing_input, cubing_input_wire));
let f = builder.cube_extension(cubing_input_wire);
addition_buffer = builder.add_extension(addition_buffer, f);
state[active] = builder.sub_extension(state[active], f);
// addition_buffer += f
// state[active] -= f
(addition_buffer, state[active]) = builder.double_arithmetic_extension(
F::ONE,
F::ONE,
one,
addition_buffer,
f,
neg_one,
f,
state[active],
);
}
for i in 0..W {

View File

@ -187,14 +187,20 @@ impl<F: Extendable<D>, const D: usize> Gate<F, D> for InsertionGate<F, D> {
let mut new_item = builder.scalar_mul_ext_algebra(insert_here, element_to_insert);
if r > 0 {
let to_add = builder.scalar_mul_ext_algebra(already_inserted, list_items[r - 1]);
new_item = builder.add_ext_algebra(new_item, to_add);
new_item = builder.scalar_mul_add_ext_algebra(
already_inserted,
list_items[r - 1],
new_item,
);
}
already_inserted = builder.add_extension(already_inserted, insert_here);
if r < self.vec_size {
let not_already_inserted = builder.sub_extension(one, already_inserted);
let to_add = builder.scalar_mul_ext_algebra(not_already_inserted, list_items[r]);
new_item = builder.add_ext_algebra(new_item, to_add);
new_item = builder.scalar_mul_add_ext_algebra(
not_already_inserted,
list_items[r],
new_item,
);
}
// Output constraint.

View File

@ -121,9 +121,8 @@ impl<F: Extendable<D>, const D: usize> Gate<F, D> for ReducingGate<D> {
let mut constraints = Vec::new();
let mut acc = old_acc;
for i in 0..self.num_coeffs {
let mut tmp = builder.mul_ext_algebra(acc, alpha);
let coeff = builder.convert_to_ext_algebra(coeffs[i]);
tmp = builder.add_ext_algebra(tmp, coeff);
let mut tmp = builder.mul_add_ext_algebra(acc, alpha, coeff);
tmp = builder.sub_ext_algebra(tmp, accs[i]);
constraints.push(tmp);
acc = accs[i];

View File

@ -1,5 +1,7 @@
use std::borrow::Borrow;
use num::Integer;
use crate::field::extension_field::target::ExtensionTarget;
use crate::field::extension_field::Extendable;
use crate::field::field_types::Field;
@ -7,6 +9,7 @@ use crate::fri::commitment::SALT_SIZE;
use crate::iop::target::Target;
use crate::plonk::circuit_builder::CircuitBuilder;
use crate::polynomial::polynomial::PolynomialCoeffs;
use crate::util::reducing::ReducingFactorTarget;
/// Holds the Merkle tree index and blinding flag of a set of polynomials used in FRI.
#[derive(Debug, Copy, Clone)]
@ -181,11 +184,9 @@ pub(crate) fn reduce_with_powers_ext_recursive<F: Extendable<D>, const D: usize>
terms: &[ExtensionTarget<D>],
alpha: Target,
) -> ExtensionTarget<D> {
let mut sum = builder.zero_extension();
for &term in terms.iter().rev() {
sum = builder.scalar_mul_add_extension(alpha, sum, term);
}
sum
let alpha = builder.convert_to_ext(alpha);
let mut alpha = ReducingFactorTarget::new(alpha);
alpha.reduce(terms, builder)
}
/// Reduce a sequence of field elements by the given coefficients.