diff --git a/src/field/extension_field/algebra.rs b/src/field/extension_field/algebra.rs index 633c4d82..b4a044c3 100644 --- a/src/field/extension_field/algebra.rs +++ b/src/field/extension_field/algebra.rs @@ -7,7 +7,7 @@ use crate::field::extension_field::OEF; /// Let `F_D` be the optimal extension field `F[X]/(X^D-W)`. Then `ExtensionAlgebra` is the quotient `F_D[X]/(X^D-W)`. /// It's a `D`-dimensional algebra over `F_D` useful to lift the multiplication over `F_D` to a multiplication over `(F_D)^D`. #[derive(Copy, Clone)] -pub struct ExtensionAlgebra, const D: usize>([F; D]); +pub struct ExtensionAlgebra, const D: usize>(pub [F; D]); impl, const D: usize> ExtensionAlgebra { pub const ZERO: Self = Self([F::ZERO; D]); diff --git a/src/field/extension_field/target.rs b/src/field/extension_field/target.rs index c7f9dc39..683afe0b 100644 --- a/src/field/extension_field/target.rs +++ b/src/field/extension_field/target.rs @@ -105,6 +105,10 @@ impl, const D: usize> CircuitBuilder { self.constant_extension(F::Extension::TWO) } + pub fn neg_one_extension(&mut self) -> ExtensionTarget { + self.constant_extension(F::Extension::NEG_ONE) + } + pub fn zero_ext_algebra(&mut self) -> ExtensionAlgebraTarget { self.constant_ext_algebra(ExtensionAlgebra::ZERO) } diff --git a/src/fri/recursive_verifier.rs b/src/fri/recursive_verifier.rs index d771e884..0ee786f7 100644 --- a/src/fri/recursive_verifier.rs +++ b/src/fri/recursive_verifier.rs @@ -40,14 +40,23 @@ impl, const D: usize> CircuitBuilder { let coset_start = self.mul(start, x); // The answer is gotten by interpolating {(x*g^i, P(x*g^i))} and evaluating at beta. - let points = g + let g_powers = g .powers() - .map(|y| { - let yt = self.constant(y); - self.mul(coset_start, yt) - }) - .zip(evals) + .take(arity) + .map(|y| self.constant(y)) .collect::>(); + let mut coset = Vec::new(); + for i in 0..arity / 2 { + let res = self.mul_two( + coset_start, + g_powers[2 * i], + coset_start, + g_powers[2 * i + 1], + ); + coset.push(res.0); + coset.push(res.1); + } + let points = coset.into_iter().zip(evals).collect::>(); self.interpolate(&points, beta) } @@ -195,6 +204,7 @@ impl, const D: usize> CircuitBuilder { assert!(D > 1, "Not implemented for D=1."); let config = self.config.clone(); let degree_log = proof.evals_proofs[0].1.siblings.len() - config.rate_bits; + let one = self.one_extension(); let subgroup_x = self.convert_to_ext(subgroup_x); let vanish_zeta = self.sub_extension(subgroup_x, zeta); let mut alpha = ReducingFactorTarget::new(alpha); @@ -223,8 +233,7 @@ impl, const D: usize> CircuitBuilder { self.sub_extension(single_composition_eval, precomputed_reduced_evals.single); // This division is safe because the denominator will be nonzero unless zeta is in the // codeword domain, which occurs with negligible probability given a large extension field. - let quotient = self.div_unsafe_extension(single_numerator, vanish_zeta); - sum = self.add_extension(sum, quotient); + sum = self.div_add_extension(single_numerator, vanish_zeta, sum); alpha.reset(); // Polynomials opened at `x` and `g x`, i.e., the Zs polynomials. @@ -245,14 +254,13 @@ impl, const D: usize> CircuitBuilder { ], subgroup_x, ); - let zs_numerator = self.sub_extension(zs_composition_eval, interpol_val); - let vanish_zeta_right = self.sub_extension(subgroup_x, zeta_right); + let (zs_numerator, vanish_zeta_right) = + self.sub_two_extension(zs_composition_eval, interpol_val, subgroup_x, zeta_right); let zs_denominator = self.mul_extension(vanish_zeta, vanish_zeta_right); + sum = alpha.shift(sum, self); // This division is safe because the denominator will be nonzero unless zeta is in the // codeword domain, which occurs with negligible probability given a large extension field. - let zs_quotient = self.div_unsafe_extension(zs_numerator, zs_denominator); - sum = alpha.shift(sum, self); - sum = self.add_extension(sum, zs_quotient); + sum = self.div_add_extension(zs_numerator, zs_denominator, sum); sum } diff --git a/src/gadgets/arithmetic.rs b/src/gadgets/arithmetic.rs index d670f58d..9779da7e 100644 --- a/src/gadgets/arithmetic.rs +++ b/src/gadgets/arithmetic.rs @@ -86,6 +86,16 @@ impl, const D: usize> CircuitBuilder { self.arithmetic(F::ONE, x, y, F::ZERO, x) } + /// Computes `x * y`. + pub fn mul_two(&mut self, a0: Target, b0: Target, a1: Target, b1: Target) -> (Target, Target) { + let a0_ext = self.convert_to_ext(a0); + let b0_ext = self.convert_to_ext(b0); + let a1_ext = self.convert_to_ext(a1); + let b1_ext = self.convert_to_ext(b1); + let res = self.mul_two_extension(a0_ext, b0_ext, a1_ext, b1_ext); + (res.0 .0[0], res.1 .0[0]) + } + /// Multiply `n` `Target`s with `ceil(n/2) + 1` `ArithmeticExtensionGate`s. pub fn mul_many(&mut self, terms: &[Target]) -> Target { let terms_ext = terms diff --git a/src/gadgets/arithmetic_extension.rs b/src/gadgets/arithmetic_extension.rs index 29b759fb..b7db20b9 100644 --- a/src/gadgets/arithmetic_extension.rs +++ b/src/gadgets/arithmetic_extension.rs @@ -152,6 +152,33 @@ impl, const D: usize> CircuitBuilder { None } + /// Returns `sum_{(a,b) in vecs} constant * a * b`. + pub fn inner_product_extension( + &mut self, + constant: F, + starting_acc: ExtensionTarget, + pairs: Vec<(ExtensionTarget, ExtensionTarget)>, + ) -> ExtensionTarget { + let mut acc = starting_acc; + for chunk in pairs.chunks_exact(2) { + let (a0, b0) = chunk[0]; + let (a1, b1) = chunk[1]; + let gate = self.num_gates(); + let first_out = ExtensionTarget::from_range( + gate, + ArithmeticExtensionGate::::wires_first_output(), + ); + acc = self + .double_arithmetic_extension(constant, F::ONE, a0, b0, acc, a1, b1, first_out) + .1; + } + if pairs.len().is_odd() { + let n = pairs.len() - 1; + acc = self.arithmetic_extension(constant, F::ONE, pairs[n].0, pairs[n].1, acc); + } + acc + } + pub fn add_extension( &mut self, a: ExtensionTarget, @@ -320,24 +347,44 @@ impl, const D: usize> CircuitBuilder { self.mul_three_extension(x, x, x) } + /// Returns `a * b + c`. + pub fn mul_add_ext_algebra( + &mut self, + a: ExtensionAlgebraTarget, + b: ExtensionAlgebraTarget, + c: ExtensionAlgebraTarget, + ) -> ExtensionAlgebraTarget { + let mut inner = vec![vec![]; D]; + let mut inner_w = vec![vec![]; D]; + for i in 0..D { + for j in 0..D - i { + inner[(i + j) % D].push((a.0[i], b.0[j])); + } + for j in D - i..D { + inner_w[(i + j) % D].push((a.0[i], b.0[j])); + } + } + let res = inner_w + .into_iter() + .zip(inner) + .zip(c.0) + .map(|((pairs_w, pairs), ci)| { + let acc = self.inner_product_extension(F::Extension::W, ci, pairs_w); + self.inner_product_extension(F::ONE, acc, pairs) + }) + .collect::>(); + + ExtensionAlgebraTarget(res.try_into().unwrap()) + } + + /// Returns `a * b`. pub fn mul_ext_algebra( &mut self, a: ExtensionAlgebraTarget, b: ExtensionAlgebraTarget, ) -> ExtensionAlgebraTarget { - let mut res = [self.zero_extension(); D]; - let w = self.constant(F::Extension::W); - for i in 0..D { - for j in 0..D { - res[(i + j) % D] = if i + j < D { - self.mul_add_extension(a.0[i], b.0[j], res[(i + j) % D]) - } else { - let ai_bi = self.mul_extension(a.0[i], b.0[j]); - self.scalar_mul_add_extension(w, ai_bi, res[(i + j) % D]) - } - } - } - ExtensionAlgebraTarget(res) + let zero = self.zero_ext_algebra(); + self.mul_add_ext_algebra(a, b, zero) } /// Multiply 3 `ExtensionTarget`s with 1 `ArithmeticExtensionGate`s. @@ -422,17 +469,41 @@ impl, const D: usize> CircuitBuilder { self.mul_extension(a_ext, b) } - /// Returns `a * b`, where `b` is in the extension of the extension field, and `a` is in the - /// extension field. + /// Returns `a * b + c`, where `b, c` are in the extension algebra and `a` in the extension field. + pub fn scalar_mul_add_ext_algebra( + &mut self, + a: ExtensionTarget, + b: ExtensionAlgebraTarget, + mut c: ExtensionAlgebraTarget, + ) -> ExtensionAlgebraTarget { + for i in 0..D / 2 { + let res = self.double_arithmetic_extension( + F::ONE, + F::ONE, + a, + b.0[2 * i], + c.0[2 * i], + a, + b.0[2 * i + 1], + c.0[2 * i + 1], + ); + c.0[2 * i] = res.0; + c.0[2 * i + 1] = res.1; + } + if D.is_odd() { + c.0[D - 1] = self.arithmetic_extension(F::ONE, F::ONE, a, b.0[D - 1], c.0[D - 1]); + } + c + } + + /// Returns `a * b`, where `b` is in the extension algebra and `a` in the extension field. pub fn scalar_mul_ext_algebra( &mut self, a: ExtensionTarget, - mut b: ExtensionAlgebraTarget, + b: ExtensionAlgebraTarget, ) -> ExtensionAlgebraTarget { - for i in 0..D { - b.0[i] = self.mul_extension(a, b.0[i]); - } - b + let zero = self.zero_ext_algebra(); + self.scalar_mul_add_ext_algebra(a, b, zero) } /// Exponentiate `base` to the power of `2^power_log`. @@ -480,8 +551,43 @@ impl, const D: usize> CircuitBuilder { x: ExtensionTarget, y: ExtensionTarget, ) -> ExtensionTarget { - let y_inv = self.inverse_extension(y); - self.mul_extension(x, y_inv) + let inv = self.add_virtual_extension_target(); + let one = self.one_extension(); + self.add_generator(QuotientGeneratorExtension { + numerator: one, + denominator: y, + quotient: inv, + }); + + // Enforce that x times its purported inverse equals 1. + let (y_inv, res) = self.mul_two_extension(y, inv, x, inv); + self.assert_equal_extension(y_inv, one); + + res + } + + /// Computes ` x / y + z`. + pub fn div_add_extension( + &mut self, + x: ExtensionTarget, + y: ExtensionTarget, + z: ExtensionTarget, + ) -> ExtensionTarget { + let inv = self.add_virtual_extension_target(); + let zero = self.zero_extension(); + let one = self.one_extension(); + self.add_generator(QuotientGeneratorExtension { + numerator: one, + denominator: y, + quotient: inv, + }); + + // Enforce that x times its purported inverse equals 1. + let (y_inv, res) = + self.double_arithmetic_extension(F::ONE, F::ONE, y, inv, zero, x, inv, z); + self.assert_equal_extension(y_inv, one); + + res } /// Computes `q = x / y` by witnessing `q` and requiring that `q * y = x`. This can be unsafe in @@ -585,9 +691,12 @@ impl, const D: usize> CircuitBuilder { #[cfg(test)] mod tests { + use std::convert::TryInto; + use anyhow::Result; use crate::field::crandall_field::CrandallField; + use crate::field::extension_field::algebra::ExtensionAlgebra; use crate::field::extension_field::quartic::QuarticCrandallField; use crate::field::field_types::Field; use crate::iop::witness::PartialWitness; @@ -659,4 +768,35 @@ mod tests { verify(proof, &data.verifier_only, &data.common) } + + #[test] + fn test_mul_algebra() -> Result<()> { + type F = CrandallField; + type FF = QuarticCrandallField; + const D: usize = 4; + + let config = CircuitConfig::large_config(); + + let pw = PartialWitness::new(config.num_wires); + let mut builder = CircuitBuilder::::new(config); + + let x = FF::rand_vec(4); + let y = FF::rand_vec(4); + let xa = ExtensionAlgebra(x.try_into().unwrap()); + let ya = ExtensionAlgebra(y.try_into().unwrap()); + let za = xa * ya; + + let xt = builder.constant_ext_algebra(xa); + let yt = builder.constant_ext_algebra(ya); + let zt = builder.constant_ext_algebra(za); + let comp_zt = builder.mul_ext_algebra(xt, yt); + for i in 0..D { + builder.assert_equal_extension(zt.0[i], comp_zt.0[i]); + } + + let data = builder.build(); + let proof = data.prove(pw)?; + + verify(proof, &data.verifier_only, &data.common) + } } diff --git a/src/gadgets/interpolation.rs b/src/gadgets/interpolation.rs index 8ed5346a..92090d90 100644 --- a/src/gadgets/interpolation.rs +++ b/src/gadgets/interpolation.rs @@ -16,8 +16,12 @@ impl, const D: usize> CircuitBuilder { // b0 -> b1 // x -> a1 + (x-a0)*(b1-a1)/(b0-a0) - let x_m_a0 = self.sub_extension(evaluation_point, interpolation_points[0].0); - let b1_m_a1 = self.sub_extension(interpolation_points[1].1, interpolation_points[0].1); + let (x_m_a0, b1_m_a1) = self.sub_two_extension( + evaluation_point, + interpolation_points[0].0, + interpolation_points[1].1, + interpolation_points[0].1, + ); let b0_m_a0 = self.sub_extension(interpolation_points[1].0, interpolation_points[0].0); let quotient = self.div_extension(b1_m_a1, b0_m_a0); diff --git a/src/gadgets/polynomial.rs b/src/gadgets/polynomial.rs index a83cbcd4..3d371c53 100644 --- a/src/gadgets/polynomial.rs +++ b/src/gadgets/polynomial.rs @@ -2,6 +2,7 @@ use crate::field::extension_field::target::{ExtensionAlgebraTarget, ExtensionTar use crate::field::extension_field::Extendable; use crate::iop::target::Target; use crate::plonk::circuit_builder::CircuitBuilder; +use crate::util::reducing::ReducingFactorTarget; pub struct PolynomialCoeffsExtTarget(pub Vec>); @@ -15,12 +16,9 @@ impl PolynomialCoeffsExtTarget { builder: &mut CircuitBuilder, point: Target, ) -> ExtensionTarget { - let mut acc = builder.zero_extension(); - for &c in self.0.iter().rev() { - let tmp = builder.scalar_mul_ext(point, acc); - acc = builder.add_extension(tmp, c); - } - acc + let point = builder.convert_to_ext(point); + let mut point = ReducingFactorTarget::new(point); + point.reduce(&self.0, builder) } pub fn eval>( @@ -28,12 +26,8 @@ impl PolynomialCoeffsExtTarget { builder: &mut CircuitBuilder, point: ExtensionTarget, ) -> ExtensionTarget { - let mut acc = builder.zero_extension(); - for &c in self.0.iter().rev() { - let tmp = builder.mul_extension(point, acc); - acc = builder.add_extension(tmp, c); - } - acc + let mut point = ReducingFactorTarget::new(point); + point.reduce(&self.0, builder) } } @@ -50,8 +44,7 @@ impl PolynomialCoeffsExtAlgebraTarget { { let mut acc = builder.zero_ext_algebra(); for &c in self.0.iter().rev() { - let tmp = builder.scalar_mul_ext_algebra(point, acc); - acc = builder.add_ext_algebra(tmp, c); + acc = builder.scalar_mul_add_ext_algebra(point, acc, c); } acc } @@ -66,8 +59,7 @@ impl PolynomialCoeffsExtAlgebraTarget { { let mut acc = builder.zero_ext_algebra(); for &c in self.0.iter().rev() { - let tmp = builder.mul_ext_algebra(point, acc); - acc = builder.add_ext_algebra(tmp, c); + acc = builder.mul_add_ext_algebra(point, acc, c); } acc } diff --git a/src/gates/gmimc.rs b/src/gates/gmimc.rs index e953809b..003c9ab0 100644 --- a/src/gates/gmimc.rs +++ b/src/gates/gmimc.rs @@ -160,6 +160,8 @@ impl, const D: usize, const R: usize> Gate for GMiMCGate< builder: &mut CircuitBuilder, vars: EvaluationTargets, ) -> Vec> { + let one = builder.one_extension(); + let neg_one = builder.neg_one_extension(); let mut constraints = Vec::with_capacity(self.num_constraints()); let swap = vars.local_wires[Self::WIRE_SWAP]; @@ -195,8 +197,18 @@ impl, const D: usize, const R: usize> Gate for GMiMCGate< let cubing_input_wire = vars.local_wires[Self::wire_cubing_input(r)]; constraints.push(builder.sub_extension(cubing_input, cubing_input_wire)); let f = builder.cube_extension(cubing_input_wire); - addition_buffer = builder.add_extension(addition_buffer, f); - state[active] = builder.sub_extension(state[active], f); + // addition_buffer += f + // state[active] -= f + (addition_buffer, state[active]) = builder.double_arithmetic_extension( + F::ONE, + F::ONE, + one, + addition_buffer, + f, + neg_one, + f, + state[active], + ); } for i in 0..W { diff --git a/src/gates/insertion.rs b/src/gates/insertion.rs index 2dd5bee8..a793463e 100644 --- a/src/gates/insertion.rs +++ b/src/gates/insertion.rs @@ -187,14 +187,20 @@ impl, const D: usize> Gate for InsertionGate { let mut new_item = builder.scalar_mul_ext_algebra(insert_here, element_to_insert); if r > 0 { - let to_add = builder.scalar_mul_ext_algebra(already_inserted, list_items[r - 1]); - new_item = builder.add_ext_algebra(new_item, to_add); + new_item = builder.scalar_mul_add_ext_algebra( + already_inserted, + list_items[r - 1], + new_item, + ); } already_inserted = builder.add_extension(already_inserted, insert_here); if r < self.vec_size { let not_already_inserted = builder.sub_extension(one, already_inserted); - let to_add = builder.scalar_mul_ext_algebra(not_already_inserted, list_items[r]); - new_item = builder.add_ext_algebra(new_item, to_add); + new_item = builder.scalar_mul_add_ext_algebra( + not_already_inserted, + list_items[r], + new_item, + ); } // Output constraint. diff --git a/src/gates/reducing.rs b/src/gates/reducing.rs index cfdcaf17..f3799b10 100644 --- a/src/gates/reducing.rs +++ b/src/gates/reducing.rs @@ -121,9 +121,8 @@ impl, const D: usize> Gate for ReducingGate { let mut constraints = Vec::new(); let mut acc = old_acc; for i in 0..self.num_coeffs { - let mut tmp = builder.mul_ext_algebra(acc, alpha); let coeff = builder.convert_to_ext_algebra(coeffs[i]); - tmp = builder.add_ext_algebra(tmp, coeff); + let mut tmp = builder.mul_add_ext_algebra(acc, alpha, coeff); tmp = builder.sub_ext_algebra(tmp, accs[i]); constraints.push(tmp); acc = accs[i]; diff --git a/src/plonk/plonk_common.rs b/src/plonk/plonk_common.rs index fb083764..05812688 100644 --- a/src/plonk/plonk_common.rs +++ b/src/plonk/plonk_common.rs @@ -1,5 +1,7 @@ use std::borrow::Borrow; +use num::Integer; + use crate::field::extension_field::target::ExtensionTarget; use crate::field::extension_field::Extendable; use crate::field::field_types::Field; @@ -7,6 +9,7 @@ use crate::fri::commitment::SALT_SIZE; use crate::iop::target::Target; use crate::plonk::circuit_builder::CircuitBuilder; use crate::polynomial::polynomial::PolynomialCoeffs; +use crate::util::reducing::ReducingFactorTarget; /// Holds the Merkle tree index and blinding flag of a set of polynomials used in FRI. #[derive(Debug, Copy, Clone)] @@ -181,11 +184,9 @@ pub(crate) fn reduce_with_powers_ext_recursive, const D: usize> terms: &[ExtensionTarget], alpha: Target, ) -> ExtensionTarget { - let mut sum = builder.zero_extension(); - for &term in terms.iter().rev() { - sum = builder.scalar_mul_add_extension(alpha, sum, term); - } - sum + let alpha = builder.convert_to_ext(alpha); + let mut alpha = ReducingFactorTarget::new(alpha); + alpha.reduce(terms, builder) } /// Reduce a sequence of field elements by the given coefficients.