Merge pull request #185 from mir-protocol/more_scalar_muls

More scalar muls
This commit is contained in:
wborgeaud 2021-08-17 18:16:08 +02:00 committed by GitHub
commit e98bca6c84
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 56 additions and 24 deletions

View File

@ -23,6 +23,14 @@ impl<F: OEF<D>, const D: usize> ExtensionAlgebra<F, D> {
pub fn to_basefield_array(self) -> [F; D] {
self.0
}
pub fn scalar_mul(&self, scalar: F) -> Self {
let mut res = self.0;
res.iter_mut().for_each(|x| {
*x *= scalar;
});
Self(res)
}
}
impl<F: OEF<D>, const D: usize> From<F> for ExtensionAlgebra<F, D> {
@ -151,6 +159,13 @@ impl<F: OEF<D>, const D: usize> PolynomialCoeffsAlgebra<F, D> {
.rev()
.fold(ExtensionAlgebra::ZERO, |acc, &c| acc * x + c)
}
pub fn eval_base(&self, x: F) -> ExtensionAlgebra<F, D> {
self.coeffs
.iter()
.rev()
.fold(ExtensionAlgebra::ZERO, |acc, &c| acc.scalar_mul(x) + c)
}
}
#[cfg(test)]
@ -205,7 +220,7 @@ mod tests {
let c = a * b;
let res = selector(xs, &ts);
for i in 0..D {
ans[i] += res * c.to_basefield_array()[i].into();
ans[i] += res.scalar_mul(c.to_basefield_array()[i]);
}
}
ans

View File

@ -77,7 +77,10 @@ impl Field for QuadraticCrandallField {
let a_pow_r = a_pow_r_minus_1 * *self;
debug_assert!(FieldExtension::<2>::is_in_basefield(&a_pow_r));
Some(a_pow_r_minus_1 * a_pow_r.0[0].inverse().into())
Some(FieldExtension::<2>::scalar_mul(
&a_pow_r_minus_1,
a_pow_r.0[0].inverse(),
))
}
fn to_canonical_u64(&self) -> u64 {

View File

@ -110,7 +110,10 @@ impl Field for QuarticCrandallField {
let a_pow_r = a_pow_r_minus_1 * *self;
debug_assert!(FieldExtension::<4>::is_in_basefield(&a_pow_r));
Some(a_pow_r_minus_1 * a_pow_r.0[0].inverse().into())
Some(FieldExtension::<4>::scalar_mul(
&a_pow_r_minus_1,
a_pow_r.0[0].inverse(),
))
}
fn to_canonical_u64(&self) -> u64 {

View File

@ -1,6 +1,7 @@
use std::convert::TryInto;
use crate::field::extension_field::target::{ExtensionAlgebraTarget, ExtensionTarget};
use crate::field::extension_field::FieldExtension;
use crate::field::extension_field::{Extendable, OEF};
use crate::field::field_types::Field;
use crate::gates::arithmetic::{ArithmeticExtensionGate, NUM_ARITHMETIC_OPS};
@ -98,14 +99,14 @@ impl<F: Extendable<D>, const D: usize> CircuitBuilder<F, D> {
let first_term_const = if first_term_zero {
Some(F::Extension::ZERO)
} else if let (Some(x), Some(y)) = (mul_0_const, mul_1_const) {
Some(x * y * const_0.into())
Some((x * y).scalar_mul(const_0))
} else {
None
};
let second_term_const = if second_term_zero {
Some(F::Extension::ZERO)
} else {
addend_const.map(|x| x * const_1.into())
addend_const.map(|x| x.scalar_mul(const_1))
};
if let (Some(x), Some(y)) = (first_term_const, second_term_const) {
return Some(self.constant_extension(x + y));
@ -117,12 +118,12 @@ impl<F: Extendable<D>, const D: usize> CircuitBuilder<F, D> {
if second_term_zero {
if let Some(x) = mul_0_const {
if (x * const_0.into()).is_one() {
if x.scalar_mul(const_0).is_one() {
return Some(multiplicand_1);
}
}
if let Some(x) = mul_1_const {
if (x * const_0.into()).is_one() {
if x.scalar_mul(const_0).is_one() {
return Some(multiplicand_0);
}
}

View File

@ -48,7 +48,7 @@ impl<F: Extendable<D>, const D: usize> Gate<F, D> for ArithmeticExtensionGate<D>
let addend = vars.get_local_ext_algebra(Self::wires_ith_addend(i));
let output = vars.get_local_ext_algebra(Self::wires_ith_output(i));
let computed_output =
multiplicand_0 * multiplicand_1 * const_0.into() + addend * const_1.into();
(multiplicand_0 * multiplicand_1).scalar_mul(const_0) + addend.scalar_mul(const_1);
constraints.extend((output - computed_output).to_basefield_array());
}
@ -176,8 +176,8 @@ impl<F: Extendable<D>, const D: usize> SimpleGenerator<F> for ArithmeticExtensio
ArithmeticExtensionGate::<D>::wires_ith_output(self.i),
);
let computed_output =
multiplicand_0 * multiplicand_1 * self.const_0.into() + addend * self.const_1.into();
let computed_output = (multiplicand_0 * multiplicand_1).scalar_mul(self.const_0)
+ addend.scalar_mul(self.const_1);
out_buffer.set_extension_target(output_target, computed_output)
}

View File

@ -97,13 +97,13 @@ impl<F: Extendable<D>, const D: usize> Gate<F, D> for InsertionGate<F, D> {
constraints.push(difference * equality_dummy - (F::Extension::ONE - insert_here));
constraints.push(insert_here * difference);
let mut new_item = element_to_insert * insert_here.into();
let mut new_item = element_to_insert.scalar_mul(insert_here);
if r > 0 {
new_item += list_items[r - 1] * already_inserted.into();
new_item += list_items[r - 1].scalar_mul(already_inserted);
}
already_inserted += insert_here;
if r < self.vec_size {
new_item += list_items[r] * (F::Extension::ONE - already_inserted).into();
new_item += list_items[r].scalar_mul(F::Extension::ONE - already_inserted);
}
// Output constraint.
@ -135,13 +135,13 @@ impl<F: Extendable<D>, const D: usize> Gate<F, D> for InsertionGate<F, D> {
constraints.push(difference * equality_dummy - (F::ONE - insert_here));
constraints.push(insert_here * difference);
let mut new_item = element_to_insert * insert_here.into();
let mut new_item = element_to_insert.scalar_mul(insert_here);
if r > 0 {
new_item += list_items[r - 1] * already_inserted.into();
new_item += list_items[r - 1].scalar_mul(already_inserted);
}
already_inserted += insert_here;
if r < self.vec_size {
new_item += list_items[r] * (F::ONE - already_inserted).into();
new_item += list_items[r].scalar_mul(F::ONE - already_inserted);
}
// Output constraint.

View File

@ -109,7 +109,7 @@ impl<F: Extendable<D>, const D: usize> Gate<F, D> for InterpolationGate<F, D> {
for i in 0..self.num_points {
let point = vars.local_wires[self.wire_point(i)];
let value = vars.get_local_ext_algebra(self.wires_value(i));
let computed_value = interpolant.eval(point.into());
let computed_value = interpolant.eval_base(point);
constraints.extend(&(value - computed_value).to_basefield_array());
}
@ -132,7 +132,7 @@ impl<F: Extendable<D>, const D: usize> Gate<F, D> for InterpolationGate<F, D> {
for i in 0..self.num_points {
let point = vars.local_wires[self.wire_point(i)];
let value = vars.get_local_ext(self.wires_value(i));
let computed_value = interpolant.eval(point.into());
let computed_value = interpolant.eval_base(point);
constraints.extend(&(value - computed_value).to_basefield_array());
}

View File

@ -85,7 +85,7 @@ impl<F: Extendable<D>, const D: usize> Gate<F, D> for RandomAccessGate<F, D> {
constraints.push(index_matches * difference);
// Value equality constraint.
constraints.extend(
((list_items[i] - claimed_element) * index_matches.into()).to_basefield_array(),
((list_items[i] - claimed_element).scalar_mul(index_matches)).to_basefield_array(),
);
}
@ -112,7 +112,7 @@ impl<F: Extendable<D>, const D: usize> Gate<F, D> for RandomAccessGate<F, D> {
// Value equality constraint.
constraints.extend(
((list_items[i] - claimed_element) * index_matches.into()).to_basefield_array(),
((list_items[i] - claimed_element).scalar_mul(index_matches)).to_basefield_array(),
);
}

View File

@ -1,5 +1,5 @@
use crate::field::extension_field::target::ExtensionTarget;
use crate::field::extension_field::Extendable;
use crate::field::extension_field::{Extendable, FieldExtension};
use crate::field::field_types::Field;
use crate::gates::gate::PrefixedGate;
use crate::iop::target::Target;
@ -51,15 +51,15 @@ pub(crate) fn eval_vanishing_poly<F: Extendable<D>, const D: usize>(
.map(|j| {
let wire_value = vars.local_wires[j];
let k_i = common_data.k_is[j];
let s_id = x * k_i.into();
wire_value + s_id * betas[i].into() + gammas[i].into()
let s_id = x.scalar_mul(k_i);
wire_value + s_id.scalar_mul(betas[i]) + gammas[i].into()
})
.collect::<Vec<_>>();
let denominator_values = (0..common_data.config.num_routed_wires)
.map(|j| {
let wire_value = vars.local_wires[j];
let s_sigma = s_sigmas[j];
wire_value + s_sigma * betas[i].into() + gammas[i].into()
wire_value + s_sigma.scalar_mul(betas[i]) + gammas[i].into()
})
.collect::<Vec<_>>();
let quotient_values = (0..common_data.config.num_routed_wires)

View File

@ -137,6 +137,16 @@ impl<F: Field> PolynomialCoeffs<F> {
.fold(F::ZERO, |acc, &c| acc * x + c)
}
pub fn eval_base<const D: usize>(&self, x: F::BaseField) -> F
where
F: FieldExtension<D>,
{
self.coeffs
.iter()
.rev()
.fold(F::ZERO, |acc, &c| acc.scalar_mul(x) + c)
}
pub fn lde_multiple(polys: Vec<&Self>, rate_bits: usize) -> Vec<Self> {
polys.into_iter().map(|p| p.lde(rate_bits)).collect()
}