From 722f99743a4630c0a898e856e6edbb6dcb178c1c Mon Sep 17 00:00:00 2001 From: Daniel Lubarov Date: Tue, 17 Aug 2021 00:49:01 -0700 Subject: [PATCH 1/2] Use scalar_mul vs converting --- src/gates/arithmetic.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/gates/arithmetic.rs b/src/gates/arithmetic.rs index c8f8f59c..9fd72aef 100644 --- a/src/gates/arithmetic.rs +++ b/src/gates/arithmetic.rs @@ -67,7 +67,7 @@ impl, const D: usize> Gate for ArithmeticExtensionGate let addend = vars.get_local_ext(Self::wires_ith_addend(i)); let output = vars.get_local_ext(Self::wires_ith_output(i)); let computed_output = - multiplicand_0 * multiplicand_1 * const_0.into() + addend * const_1.into(); + (multiplicand_0 * multiplicand_1).scalar_mul(const_0) + addend.scalar_mul(const_1); constraints.extend((output - computed_output).to_basefield_array()); } From 561228103f25f562568b11bfde4bfbda75054105 Mon Sep 17 00:00:00 2001 From: wborgeaud Date: Tue, 17 Aug 2021 10:26:31 +0200 Subject: [PATCH 2/2] More scalar muls --- src/field/extension_field/algebra.rs | 17 ++++++++++++++++- src/field/extension_field/quadratic.rs | 5 ++++- src/field/extension_field/quartic.rs | 5 ++++- src/gadgets/arithmetic_extension.rs | 9 +++++---- src/gates/arithmetic.rs | 6 +++--- src/gates/insertion.rs | 12 ++++++------ src/gates/interpolation.rs | 4 ++-- src/gates/random_access.rs | 4 ++-- src/plonk/vanishing_poly.rs | 8 ++++---- src/polynomial/polynomial.rs | 10 ++++++++++ 10 files changed, 56 insertions(+), 24 deletions(-) diff --git a/src/field/extension_field/algebra.rs b/src/field/extension_field/algebra.rs index b4a044c3..78f787fe 100644 --- a/src/field/extension_field/algebra.rs +++ b/src/field/extension_field/algebra.rs @@ -23,6 +23,14 @@ impl, const D: usize> ExtensionAlgebra { pub fn to_basefield_array(self) -> [F; D] { self.0 } + + pub fn scalar_mul(&self, scalar: F) -> Self { + let mut res = self.0; + res.iter_mut().for_each(|x| { + *x *= scalar; + }); + Self(res) + } } impl, const D: usize> From for ExtensionAlgebra { @@ -151,6 +159,13 @@ impl, const D: usize> PolynomialCoeffsAlgebra { .rev() .fold(ExtensionAlgebra::ZERO, |acc, &c| acc * x + c) } + + pub fn eval_base(&self, x: F) -> ExtensionAlgebra { + self.coeffs + .iter() + .rev() + .fold(ExtensionAlgebra::ZERO, |acc, &c| acc.scalar_mul(x) + c) + } } #[cfg(test)] @@ -205,7 +220,7 @@ mod tests { let c = a * b; let res = selector(xs, &ts); for i in 0..D { - ans[i] += res * c.to_basefield_array()[i].into(); + ans[i] += res.scalar_mul(c.to_basefield_array()[i]); } } ans diff --git a/src/field/extension_field/quadratic.rs b/src/field/extension_field/quadratic.rs index fa698fb6..1d31a0be 100644 --- a/src/field/extension_field/quadratic.rs +++ b/src/field/extension_field/quadratic.rs @@ -77,7 +77,10 @@ impl Field for QuadraticCrandallField { let a_pow_r = a_pow_r_minus_1 * *self; debug_assert!(FieldExtension::<2>::is_in_basefield(&a_pow_r)); - Some(a_pow_r_minus_1 * a_pow_r.0[0].inverse().into()) + Some(FieldExtension::<2>::scalar_mul( + &a_pow_r_minus_1, + a_pow_r.0[0].inverse(), + )) } fn to_canonical_u64(&self) -> u64 { diff --git a/src/field/extension_field/quartic.rs b/src/field/extension_field/quartic.rs index 2bc1c9b3..2c111627 100644 --- a/src/field/extension_field/quartic.rs +++ b/src/field/extension_field/quartic.rs @@ -110,7 +110,10 @@ impl Field for QuarticCrandallField { let a_pow_r = a_pow_r_minus_1 * *self; debug_assert!(FieldExtension::<4>::is_in_basefield(&a_pow_r)); - Some(a_pow_r_minus_1 * a_pow_r.0[0].inverse().into()) + Some(FieldExtension::<4>::scalar_mul( + &a_pow_r_minus_1, + a_pow_r.0[0].inverse(), + )) } fn to_canonical_u64(&self) -> u64 { diff --git a/src/gadgets/arithmetic_extension.rs b/src/gadgets/arithmetic_extension.rs index 1471b031..94b37eaa 100644 --- a/src/gadgets/arithmetic_extension.rs +++ b/src/gadgets/arithmetic_extension.rs @@ -1,6 +1,7 @@ use std::convert::TryInto; use crate::field::extension_field::target::{ExtensionAlgebraTarget, ExtensionTarget}; +use crate::field::extension_field::FieldExtension; use crate::field::extension_field::{Extendable, OEF}; use crate::field::field_types::Field; use crate::gates::arithmetic::{ArithmeticExtensionGate, NUM_ARITHMETIC_OPS}; @@ -98,14 +99,14 @@ impl, const D: usize> CircuitBuilder { let first_term_const = if first_term_zero { Some(F::Extension::ZERO) } else if let (Some(x), Some(y)) = (mul_0_const, mul_1_const) { - Some(x * y * const_0.into()) + Some((x * y).scalar_mul(const_0)) } else { None }; let second_term_const = if second_term_zero { Some(F::Extension::ZERO) } else { - addend_const.map(|x| x * const_1.into()) + addend_const.map(|x| x.scalar_mul(const_1)) }; if let (Some(x), Some(y)) = (first_term_const, second_term_const) { return Some(self.constant_extension(x + y)); @@ -117,12 +118,12 @@ impl, const D: usize> CircuitBuilder { if second_term_zero { if let Some(x) = mul_0_const { - if (x * const_0.into()).is_one() { + if x.scalar_mul(const_0).is_one() { return Some(multiplicand_1); } } if let Some(x) = mul_1_const { - if (x * const_0.into()).is_one() { + if x.scalar_mul(const_0).is_one() { return Some(multiplicand_0); } } diff --git a/src/gates/arithmetic.rs b/src/gates/arithmetic.rs index 9fd72aef..dcabcee8 100644 --- a/src/gates/arithmetic.rs +++ b/src/gates/arithmetic.rs @@ -48,7 +48,7 @@ impl, const D: usize> Gate for ArithmeticExtensionGate let addend = vars.get_local_ext_algebra(Self::wires_ith_addend(i)); let output = vars.get_local_ext_algebra(Self::wires_ith_output(i)); let computed_output = - multiplicand_0 * multiplicand_1 * const_0.into() + addend * const_1.into(); + (multiplicand_0 * multiplicand_1).scalar_mul(const_0) + addend.scalar_mul(const_1); constraints.extend((output - computed_output).to_basefield_array()); } @@ -176,8 +176,8 @@ impl, const D: usize> SimpleGenerator for ArithmeticExtensio ArithmeticExtensionGate::::wires_ith_output(self.i), ); - let computed_output = - multiplicand_0 * multiplicand_1 * self.const_0.into() + addend * self.const_1.into(); + let computed_output = (multiplicand_0 * multiplicand_1).scalar_mul(self.const_0) + + addend.scalar_mul(self.const_1); out_buffer.set_extension_target(output_target, computed_output) } diff --git a/src/gates/insertion.rs b/src/gates/insertion.rs index 2e0ac8a1..7a42be41 100644 --- a/src/gates/insertion.rs +++ b/src/gates/insertion.rs @@ -97,13 +97,13 @@ impl, const D: usize> Gate for InsertionGate { constraints.push(difference * equality_dummy - (F::Extension::ONE - insert_here)); constraints.push(insert_here * difference); - let mut new_item = element_to_insert * insert_here.into(); + let mut new_item = element_to_insert.scalar_mul(insert_here); if r > 0 { - new_item += list_items[r - 1] * already_inserted.into(); + new_item += list_items[r - 1].scalar_mul(already_inserted); } already_inserted += insert_here; if r < self.vec_size { - new_item += list_items[r] * (F::Extension::ONE - already_inserted).into(); + new_item += list_items[r].scalar_mul(F::Extension::ONE - already_inserted); } // Output constraint. @@ -135,13 +135,13 @@ impl, const D: usize> Gate for InsertionGate { constraints.push(difference * equality_dummy - (F::ONE - insert_here)); constraints.push(insert_here * difference); - let mut new_item = element_to_insert * insert_here.into(); + let mut new_item = element_to_insert.scalar_mul(insert_here); if r > 0 { - new_item += list_items[r - 1] * already_inserted.into(); + new_item += list_items[r - 1].scalar_mul(already_inserted); } already_inserted += insert_here; if r < self.vec_size { - new_item += list_items[r] * (F::ONE - already_inserted).into(); + new_item += list_items[r].scalar_mul(F::ONE - already_inserted); } // Output constraint. diff --git a/src/gates/interpolation.rs b/src/gates/interpolation.rs index d04ca5e8..9356d428 100644 --- a/src/gates/interpolation.rs +++ b/src/gates/interpolation.rs @@ -109,7 +109,7 @@ impl, const D: usize> Gate for InterpolationGate { for i in 0..self.num_points { let point = vars.local_wires[self.wire_point(i)]; let value = vars.get_local_ext_algebra(self.wires_value(i)); - let computed_value = interpolant.eval(point.into()); + let computed_value = interpolant.eval_base(point); constraints.extend(&(value - computed_value).to_basefield_array()); } @@ -132,7 +132,7 @@ impl, const D: usize> Gate for InterpolationGate { for i in 0..self.num_points { let point = vars.local_wires[self.wire_point(i)]; let value = vars.get_local_ext(self.wires_value(i)); - let computed_value = interpolant.eval(point.into()); + let computed_value = interpolant.eval_base(point); constraints.extend(&(value - computed_value).to_basefield_array()); } diff --git a/src/gates/random_access.rs b/src/gates/random_access.rs index 07173de5..0dd008f0 100644 --- a/src/gates/random_access.rs +++ b/src/gates/random_access.rs @@ -85,7 +85,7 @@ impl, const D: usize> Gate for RandomAccessGate { constraints.push(index_matches * difference); // Value equality constraint. constraints.extend( - ((list_items[i] - claimed_element) * index_matches.into()).to_basefield_array(), + ((list_items[i] - claimed_element).scalar_mul(index_matches)).to_basefield_array(), ); } @@ -112,7 +112,7 @@ impl, const D: usize> Gate for RandomAccessGate { // Value equality constraint. constraints.extend( - ((list_items[i] - claimed_element) * index_matches.into()).to_basefield_array(), + ((list_items[i] - claimed_element).scalar_mul(index_matches)).to_basefield_array(), ); } diff --git a/src/plonk/vanishing_poly.rs b/src/plonk/vanishing_poly.rs index 6a282be9..5871a008 100644 --- a/src/plonk/vanishing_poly.rs +++ b/src/plonk/vanishing_poly.rs @@ -1,5 +1,5 @@ use crate::field::extension_field::target::ExtensionTarget; -use crate::field::extension_field::Extendable; +use crate::field::extension_field::{Extendable, FieldExtension}; use crate::field::field_types::Field; use crate::gates::gate::PrefixedGate; use crate::iop::target::Target; @@ -51,15 +51,15 @@ pub(crate) fn eval_vanishing_poly, const D: usize>( .map(|j| { let wire_value = vars.local_wires[j]; let k_i = common_data.k_is[j]; - let s_id = x * k_i.into(); - wire_value + s_id * betas[i].into() + gammas[i].into() + let s_id = x.scalar_mul(k_i); + wire_value + s_id.scalar_mul(betas[i]) + gammas[i].into() }) .collect::>(); let denominator_values = (0..common_data.config.num_routed_wires) .map(|j| { let wire_value = vars.local_wires[j]; let s_sigma = s_sigmas[j]; - wire_value + s_sigma * betas[i].into() + gammas[i].into() + wire_value + s_sigma.scalar_mul(betas[i]) + gammas[i].into() }) .collect::>(); let quotient_values = (0..common_data.config.num_routed_wires) diff --git a/src/polynomial/polynomial.rs b/src/polynomial/polynomial.rs index 4e15be56..386d49ae 100644 --- a/src/polynomial/polynomial.rs +++ b/src/polynomial/polynomial.rs @@ -137,6 +137,16 @@ impl PolynomialCoeffs { .fold(F::ZERO, |acc, &c| acc * x + c) } + pub fn eval_base(&self, x: F::BaseField) -> F + where + F: FieldExtension, + { + self.coeffs + .iter() + .rev() + .fold(F::ZERO, |acc, &c| acc.scalar_mul(x) + c) + } + pub fn lde_multiple(polys: Vec<&Self>, rate_bits: usize) -> Vec { polys.into_iter().map(|p| p.lde(rate_bits)).collect() }