Optimize mul_ext_algebra

This commit is contained in:
wborgeaud 2021-08-09 11:30:03 +02:00
parent f0f8320b93
commit eeb33f99ca
2 changed files with 73 additions and 22 deletions

View File

@ -7,7 +7,7 @@ use crate::field::extension_field::OEF;
/// Let `F_D` be the optimal extension field `F[X]/(X^D-W)`. Then `ExtensionAlgebra<F_D>` is the quotient `F_D[X]/(X^D-W)`.
/// It's a `D`-dimensional algebra over `F_D` useful to lift the multiplication over `F_D` to a multiplication over `(F_D)^D`.
#[derive(Copy, Clone)]
pub struct ExtensionAlgebra<F: OEF<D>, const D: usize>([F; D]);
pub struct ExtensionAlgebra<F: OEF<D>, const D: usize>(pub [F; D]);
impl<F: OEF<D>, const D: usize> ExtensionAlgebra<F, D> {
pub const ZERO: Self = Self([F::ZERO; D]);

View File

@ -12,7 +12,6 @@ use crate::iop::target::Target;
use crate::iop::witness::PartialWitness;
use crate::plonk::circuit_builder::CircuitBuilder;
use crate::util::bits_u64;
use crate::with_context;
impl<F: Extendable<D>, const D: usize> CircuitBuilder<F, D> {
pub fn double_arithmetic_extension(
@ -175,6 +174,31 @@ impl<F: Extendable<D>, const D: usize> CircuitBuilder<F, D> {
res
}
pub fn inner_product_extension(
&mut self,
constant: F,
starting_acc: ExtensionTarget<D>,
vecs: Vec<[ExtensionTarget<D>; 2]>,
) -> ExtensionTarget<D> {
let mut acc = starting_acc;
for chunk in vecs.chunks_exact(2) {
let [a0, b0] = chunk[0];
let [a1, b1] = chunk[1];
let gate = self.num_gates();
let first_out = ExtensionTarget::from_range(
gate,
ArithmeticExtensionGate::<D>::wires_first_output(),
);
acc = self
.double_arithmetic_extension(constant, F::ONE, a0, b0, acc, a1, b1, first_out)
.1;
}
if vecs.len().is_odd() {
let n = vecs.len() - 1;
acc = self.arithmetic_extension(constant, F::ONE, vecs[n][0], vecs[n][1], acc);
}
acc
}
pub fn add_extension(
&mut self,
@ -350,32 +374,26 @@ impl<F: Extendable<D>, const D: usize> CircuitBuilder<F, D> {
b: ExtensionAlgebraTarget<D>,
) -> ExtensionAlgebraTarget<D> {
let zero = self.zero_extension();
let mut ops = Vec::new();
let mut opsw = Vec::new();
let mut inner = vec![vec![]; D];
let mut inner_w = vec![vec![]; D];
for i in 0..D {
for j in 0..D - i {
ops.push([a.0[i], b.0[j], zero]);
inner[(i + j) % D].push([a.0[i], b.0[j]]);
}
for j in D - i..D {
opsw.push([a.0[i], b.0[j], zero]);
inner_w[(i + j) % D].push([a.0[i], b.0[j]]);
}
}
let mut muls = self.arithmetic_many_extension(F::ONE, F::ONE, ops);
let mut mulsw = self.arithmetic_many_extension(F::Extension::W, F::ONE, opsw);
let mut toadd = vec![vec![]; D];
for i in 0..D {
for j in 0..D - i {
toadd[(i + j) % D].push(muls.remove(0));
}
for j in D - i..D {
toadd[(i + j) % D].push(mulsw.remove(0));
}
}
let mut res = [zero; D];
for i in 0..D {
res[i] = self.add_many_extension(&toadd[i]);
}
ExtensionAlgebraTarget(res)
let res = inner_w
.into_iter()
.zip(inner)
.map(|(vecs_w, vecs)| {
let acc = self.inner_product_extension(F::Extension::W, zero, vecs_w);
self.inner_product_extension(F::ONE, acc, vecs)
})
.collect::<Vec<_>>();
ExtensionAlgebraTarget(res.try_into().unwrap())
}
/// Multiply 3 `ExtensionTarget`s with 1 `ArithmeticExtensionGate`s.
@ -623,9 +641,12 @@ impl<F: Extendable<D>, const D: usize> CircuitBuilder<F, D> {
#[cfg(test)]
mod tests {
use std::convert::TryInto;
use anyhow::Result;
use crate::field::crandall_field::CrandallField;
use crate::field::extension_field::algebra::ExtensionAlgebra;
use crate::field::extension_field::quartic::QuarticCrandallField;
use crate::field::field_types::Field;
use crate::iop::witness::PartialWitness;
@ -696,4 +717,34 @@ mod tests {
verify(proof, &data.verifier_only, &data.common)
}
#[test]
fn test_mul_algebra() -> Result<()> {
type F = CrandallField;
type FF = QuarticCrandallField;
const D: usize = 4;
let config = CircuitConfig::large_config();
let mut builder = CircuitBuilder::<F, D>::new(config);
let x = FF::rand_vec(4);
let y = FF::rand_vec(4);
let xa = ExtensionAlgebra(x.try_into().unwrap());
let ya = ExtensionAlgebra(y.try_into().unwrap());
let za = xa * ya;
let xt = builder.constant_ext_algebra(xa);
let yt = builder.constant_ext_algebra(ya);
let zt = builder.constant_ext_algebra(za);
let comp_zt = builder.mul_ext_algebra(xt, yt);
for i in 0..D {
builder.assert_equal_extension(zt.0[i], comp_zt.0[i]);
}
let data = builder.build();
let proof = data.prove(PartialWitness::new())?;
verify(proof, &data.verifier_only, &data.common)
}
}