From eeb33f99cae0933e717fad3aad318ec850538b06 Mon Sep 17 00:00:00 2001 From: wborgeaud Date: Mon, 9 Aug 2021 11:30:03 +0200 Subject: [PATCH] Optimize mul_ext_algebra --- src/field/extension_field/algebra.rs | 2 +- src/gadgets/arithmetic_extension.rs | 93 +++++++++++++++++++++------- 2 files changed, 73 insertions(+), 22 deletions(-) diff --git a/src/field/extension_field/algebra.rs b/src/field/extension_field/algebra.rs index 633c4d82..b4a044c3 100644 --- a/src/field/extension_field/algebra.rs +++ b/src/field/extension_field/algebra.rs @@ -7,7 +7,7 @@ use crate::field::extension_field::OEF; /// Let `F_D` be the optimal extension field `F[X]/(X^D-W)`. Then `ExtensionAlgebra` is the quotient `F_D[X]/(X^D-W)`. /// It's a `D`-dimensional algebra over `F_D` useful to lift the multiplication over `F_D` to a multiplication over `(F_D)^D`. #[derive(Copy, Clone)] -pub struct ExtensionAlgebra, const D: usize>([F; D]); +pub struct ExtensionAlgebra, const D: usize>(pub [F; D]); impl, const D: usize> ExtensionAlgebra { pub const ZERO: Self = Self([F::ZERO; D]); diff --git a/src/gadgets/arithmetic_extension.rs b/src/gadgets/arithmetic_extension.rs index e5702ecf..c800cf8c 100644 --- a/src/gadgets/arithmetic_extension.rs +++ b/src/gadgets/arithmetic_extension.rs @@ -12,7 +12,6 @@ use crate::iop::target::Target; use crate::iop::witness::PartialWitness; use crate::plonk::circuit_builder::CircuitBuilder; use crate::util::bits_u64; -use crate::with_context; impl, const D: usize> CircuitBuilder { pub fn double_arithmetic_extension( @@ -175,6 +174,31 @@ impl, const D: usize> CircuitBuilder { res } + pub fn inner_product_extension( + &mut self, + constant: F, + starting_acc: ExtensionTarget, + vecs: Vec<[ExtensionTarget; 2]>, + ) -> ExtensionTarget { + let mut acc = starting_acc; + for chunk in vecs.chunks_exact(2) { + let [a0, b0] = chunk[0]; + let [a1, b1] = chunk[1]; + let gate = self.num_gates(); + let first_out = ExtensionTarget::from_range( + gate, + ArithmeticExtensionGate::::wires_first_output(), + ); + acc = self + .double_arithmetic_extension(constant, F::ONE, a0, b0, acc, a1, b1, first_out) + .1; + } + if vecs.len().is_odd() { + let n = vecs.len() - 1; + acc = self.arithmetic_extension(constant, F::ONE, vecs[n][0], vecs[n][1], acc); + } + acc + } pub fn add_extension( &mut self, @@ -350,32 +374,26 @@ impl, const D: usize> CircuitBuilder { b: ExtensionAlgebraTarget, ) -> ExtensionAlgebraTarget { let zero = self.zero_extension(); - let mut ops = Vec::new(); - let mut opsw = Vec::new(); + let mut inner = vec![vec![]; D]; + let mut inner_w = vec![vec![]; D]; for i in 0..D { for j in 0..D - i { - ops.push([a.0[i], b.0[j], zero]); + inner[(i + j) % D].push([a.0[i], b.0[j]]); } for j in D - i..D { - opsw.push([a.0[i], b.0[j], zero]); + inner_w[(i + j) % D].push([a.0[i], b.0[j]]); } } - let mut muls = self.arithmetic_many_extension(F::ONE, F::ONE, ops); - let mut mulsw = self.arithmetic_many_extension(F::Extension::W, F::ONE, opsw); - let mut toadd = vec![vec![]; D]; - for i in 0..D { - for j in 0..D - i { - toadd[(i + j) % D].push(muls.remove(0)); - } - for j in D - i..D { - toadd[(i + j) % D].push(mulsw.remove(0)); - } - } - let mut res = [zero; D]; - for i in 0..D { - res[i] = self.add_many_extension(&toadd[i]); - } - ExtensionAlgebraTarget(res) + let res = inner_w + .into_iter() + .zip(inner) + .map(|(vecs_w, vecs)| { + let acc = self.inner_product_extension(F::Extension::W, zero, vecs_w); + self.inner_product_extension(F::ONE, acc, vecs) + }) + .collect::>(); + + ExtensionAlgebraTarget(res.try_into().unwrap()) } /// Multiply 3 `ExtensionTarget`s with 1 `ArithmeticExtensionGate`s. @@ -623,9 +641,12 @@ impl, const D: usize> CircuitBuilder { #[cfg(test)] mod tests { + use std::convert::TryInto; + use anyhow::Result; use crate::field::crandall_field::CrandallField; + use crate::field::extension_field::algebra::ExtensionAlgebra; use crate::field::extension_field::quartic::QuarticCrandallField; use crate::field::field_types::Field; use crate::iop::witness::PartialWitness; @@ -696,4 +717,34 @@ mod tests { verify(proof, &data.verifier_only, &data.common) } + + #[test] + fn test_mul_algebra() -> Result<()> { + type F = CrandallField; + type FF = QuarticCrandallField; + const D: usize = 4; + + let config = CircuitConfig::large_config(); + + let mut builder = CircuitBuilder::::new(config); + + let x = FF::rand_vec(4); + let y = FF::rand_vec(4); + let xa = ExtensionAlgebra(x.try_into().unwrap()); + let ya = ExtensionAlgebra(y.try_into().unwrap()); + let za = xa * ya; + + let xt = builder.constant_ext_algebra(xa); + let yt = builder.constant_ext_algebra(ya); + let zt = builder.constant_ext_algebra(za); + let comp_zt = builder.mul_ext_algebra(xt, yt); + for i in 0..D { + builder.assert_equal_extension(zt.0[i], comp_zt.0[i]); + } + + let data = builder.build(); + let proof = data.prove(PartialWitness::new())?; + + verify(proof, &data.verifier_only, &data.common) + } }