Use ExtensionAlgebra + new CircuitBuilder::mul_extension

This commit is contained in:
wborgeaud 2021-06-09 10:51:50 +02:00
parent 7f63276623
commit 9adf5bb43f
13 changed files with 95 additions and 145 deletions

View File

@ -1,4 +1,5 @@
use crate::field::field::Field;
use std::convert::TryInto;
pub mod algebra;
pub mod quadratic;
@ -88,10 +89,6 @@ where
{
debug_assert_eq!(l.len() % D, 0);
l.chunks_exact(D)
.map(|c| {
let mut arr = [F::ZERO; D];
arr.copy_from_slice(c);
F::Extension::from_basefield_array(arr)
})
.map(|c| F::Extension::from_basefield_array(c.to_vec().try_into().unwrap()))
.collect()
}

View File

@ -2,7 +2,10 @@ use crate::circuit_builder::CircuitBuilder;
use crate::field::extension_field::algebra::ExtensionAlgebra;
use crate::field::extension_field::{Extendable, FieldExtension, OEF};
use crate::field::field::Field;
use crate::gates::mul_extension::MulExtensionGate;
use crate::target::Target;
use std::convert::{TryFrom, TryInto};
use std::ops::Range;
/// `Target`s representing an element of an extension field.
#[derive(Copy, Clone, Debug)]
@ -26,6 +29,19 @@ impl<const D: usize> ExtensionTarget<D> {
Self(res)
}
pub fn from_range(gate: usize, range: Range<usize>) -> Self {
debug_assert_eq!(range.end - range.start, D);
Target::wires_from_range(gate, range).try_into().unwrap()
}
}
impl<const D: usize> TryFrom<Vec<Target>> for ExtensionTarget<D> {
type Error = Vec<Target>;
fn try_from(value: Vec<Target>) -> Result<Self, Self::Error> {
Ok(Self(value.try_into()?))
}
}
/// `Target`s representing an element of an extension of an extension field.
@ -128,7 +144,34 @@ impl<F: Extendable<D>, const D: usize> CircuitBuilder<F, D> {
a
}
pub fn mul_extension_with_const(
&mut self,
const_0: F,
multiplicand_0: ExtensionTarget<D>,
multiplicand_1: ExtensionTarget<D>,
) -> ExtensionTarget<D> {
let gate = self.add_gate(MulExtensionGate::new(), vec![const_0]);
let wire_multiplicand_0 =
ExtensionTarget::from_range(gate, MulExtensionGate::<D>::wires_multiplicand_0());
let wire_multiplicand_1 =
ExtensionTarget::from_range(gate, MulExtensionGate::<D>::wires_multiplicand_1());
let wire_output = ExtensionTarget::from_range(gate, MulExtensionGate::<D>::wires_output());
self.route_extension(multiplicand_0, wire_multiplicand_0);
self.route_extension(multiplicand_1, wire_multiplicand_1);
wire_output
}
pub fn mul_extension(
&mut self,
multiplicand_0: ExtensionTarget<D>,
multiplicand_1: ExtensionTarget<D>,
) -> ExtensionTarget<D> {
self.mul_extension_with_const(F::ONE, multiplicand_0, multiplicand_1)
}
pub fn mul_extension_naive(
&mut self,
a: ExtensionTarget<D>,
b: ExtensionTarget<D>,
@ -156,7 +199,7 @@ impl<F: Extendable<D>, const D: usize> CircuitBuilder<F, D> {
let w = self.constant(F::Extension::W);
for i in 0..D {
for j in 0..D {
let ai_bi = self.mul_extension(a.0[i], b.0[j]);
let ai_bi = self.mul_extension_naive(a.0[i], b.0[j]);
res[(i + j) % D] = if i + j < D {
self.add_extension(ai_bi, res[(i + j) % D])
} else {
@ -171,7 +214,7 @@ impl<F: Extendable<D>, const D: usize> CircuitBuilder<F, D> {
pub fn mul_many_extension(&mut self, terms: &[ExtensionTarget<D>]) -> ExtensionTarget<D> {
let mut product = self.one_extension();
for term in terms {
product = self.mul_extension(product, *term);
product = self.mul_extension_naive(product, *term);
}
product
}
@ -184,7 +227,7 @@ impl<F: Extendable<D>, const D: usize> CircuitBuilder<F, D> {
b: ExtensionTarget<D>,
c: ExtensionTarget<D>,
) -> ExtensionTarget<D> {
let product = self.mul_extension(a, b);
let product = self.mul_extension_naive(a, b);
self.add_extension(product, c)
}
@ -204,7 +247,7 @@ impl<F: Extendable<D>, const D: usize> CircuitBuilder<F, D> {
mut b: ExtensionAlgebraTarget<D>,
) -> ExtensionAlgebraTarget<D> {
for i in 0..D {
b.0[i] = self.mul_extension(a, b.0[i]);
b.0[i] = self.mul_extension_naive(a, b.0[i]);
}
b
}
@ -225,13 +268,9 @@ pub fn flatten_target<const D: usize>(l: &[ExtensionTarget<D>]) -> Vec<Target> {
}
/// Batch every D-sized chunks into extension targets.
pub fn unflatten_target<const D: usize>(l: &[Target]) -> Vec<ExtensionTarget<D>> {
pub fn unflatten_target<F: Extendable<D>, const D: usize>(l: &[Target]) -> Vec<ExtensionTarget<D>> {
debug_assert_eq!(l.len() % D, 0);
l.chunks_exact(D)
.map(|c| {
let mut arr = Default::default();
arr.copy_from_slice(c);
ExtensionTarget(arr)
})
.map(|c| c.to_vec().try_into().unwrap())
.collect()
}

View File

@ -187,7 +187,7 @@ impl<F: Extendable<D>, const D: usize> CircuitBuilder<F, D> {
}
let g = self.constant_extension(F::Extension::primitive_root_of_unity(degree_log));
let zeta_right = self.mul_extension(g, zeta);
let zeta_right = self.mul_extension_naive(g, zeta);
let mut ev_zeta = self.zero_extension();
for &t in &os.plonk_zs {
let a = alpha_powers.next(self);
@ -203,7 +203,7 @@ impl<F: Extendable<D>, const D: usize> CircuitBuilder<F, D> {
let numerator = self.sub_extension(ev, interpol_val);
let vanish = self.sub_extension(subgroup_x, zeta);
let vanish_right = self.sub_extension(subgroup_x, zeta_right);
let denominator = self.mul_extension(vanish, vanish_right);
let denominator = self.mul_extension_naive(vanish, vanish_right);
let quotient = self.div_unsafe_extension(numerator, denominator);
let sum = self.add_extension(sum, quotient);
@ -237,7 +237,7 @@ impl<F: Extendable<D>, const D: usize> CircuitBuilder<F, D> {
let interpol_val = wires_interpol.eval(self, subgroup_x);
let numerator = self.sub_extension(ev, interpol_val);
let vanish_frob = self.sub_extension(subgroup_x, zeta_frob);
let denominator = self.mul_extension(vanish, vanish_frob);
let denominator = self.mul_extension_naive(vanish, vanish_frob);
let quotient = self.div_unsafe_extension(numerator, denominator);
let sum = self.add_extension(sum, quotient);

View File

@ -238,17 +238,13 @@ impl<F: Extendable<D>, const D: usize> CircuitBuilder<F, D> {
// Add an `ArithmeticGate` to compute `q * y`.
let gate = self.add_gate(MulExtensionGate::new(), vec![F::ONE]);
let multiplicand_0 = MulExtensionGate::<D>::wires_multiplicand_0()
.map(|i| Target::Wire(Wire { gate, input: i }))
.collect::<Vec<_>>();
let multiplicand_0 =
Target::wires_from_range(gate, MulExtensionGate::<D>::wires_multiplicand_0());
let multiplicand_0 = ExtensionTarget(multiplicand_0.try_into().unwrap());
let multiplicand_1 = MulExtensionGate::<D>::wires_multiplicand_1()
.map(|i| Target::Wire(Wire { gate, input: i }))
.collect::<Vec<_>>();
let multiplicand_1 =
Target::wires_from_range(gate, MulExtensionGate::<D>::wires_multiplicand_1());
let multiplicand_1 = ExtensionTarget(multiplicand_1.try_into().unwrap());
let output = MulExtensionGate::<D>::wires_output()
.map(|i| Target::Wire(Wire { gate, input: i }))
.collect::<Vec<_>>();
let output = Target::wires_from_range(gate, MulExtensionGate::<D>::wires_output());
let output = ExtensionTarget(output.try_into().unwrap());
self.add_generator(QuotientGeneratorExtension {
@ -324,7 +320,7 @@ impl<const D: usize> PowersTarget<D> {
builder: &mut CircuitBuilder<F, D>,
) -> ExtensionTarget<D> {
let result = self.current;
self.current = builder.mul_extension(self.base, self.current);
self.current = builder.mul_extension_naive(self.base, self.current);
result
}
}

View File

@ -26,7 +26,7 @@ impl<const D: usize> PolynomialCoeffsExtTarget<D> {
) -> ExtensionTarget<D> {
let mut acc = builder.zero_extension();
for &c in self.0.iter().rev() {
let tmp = builder.mul_extension(point, acc);
let tmp = builder.mul_extension_naive(point, acc);
acc = builder.add_extension(tmp, c);
}
acc

View File

@ -19,9 +19,11 @@ impl<F: Extendable<D>, const D: usize> CircuitBuilder<F, D> {
input: BaseSumGate::<B>::WIRE_SUM,
});
self.route(x, sum);
(BaseSumGate::<B>::WIRE_LIMBS_START..BaseSumGate::<B>::WIRE_LIMBS_START + num_limbs)
.map(|i| Target::Wire(Wire { gate, input: i }))
.collect()
Target::wires_from_range(
gate,
BaseSumGate::<B>::WIRE_LIMBS_START..BaseSumGate::<B>::WIRE_LIMBS_START + num_limbs,
)
}
/// Asserts that `x`'s bit representation has at least `trailing_zeros` trailing zeros.

View File

@ -57,7 +57,7 @@ impl<F: Extendable<D>, const D: usize> Gate<F, D> for ArithmeticGate {
let output = vars.local_wires[Self::WIRE_OUTPUT];
let product_term = builder.mul_many_extension(&[const_0, multiplicand_0, multiplicand_1]);
let addend_term = builder.mul_extension(const_1, addend);
let addend_term = builder.mul_extension_naive(const_1, addend);
let computed_output = builder.add_many_extension(&[product_term, addend_term]);
vec![builder.sub_extension(computed_output, output)]
}

View File

@ -69,7 +69,7 @@ impl<F: Extendable<D>, const D: usize, const B: usize> Gate<F, D> for BaseSumGat
(0..B).for_each(|i| {
let it = builder.constant_extension(F::from_canonical_usize(i).into());
let diff = builder.sub_extension(limb, it);
acc = builder.mul_extension(acc, diff);
acc = builder.mul_extension_naive(acc, diff);
});
acc
});

View File

@ -131,7 +131,7 @@ impl<F: Extendable<D>, const D: usize, const R: usize> Gate<F, D> for GMiMCGate<
let swap = vars.local_wires[Self::WIRE_SWAP];
let one_ext = builder.one_extension();
let not_swap = builder.sub_extension(swap, one_ext);
constraints.push(builder.mul_extension(swap, not_swap));
constraints.push(builder.mul_extension_naive(swap, not_swap));
let old_index_acc = vars.local_wires[Self::WIRE_INDEX_ACCUMULATOR_OLD];
let new_index_acc = vars.local_wires[Self::WIRE_INDEX_ACCUMULATOR_NEW];
@ -168,8 +168,8 @@ impl<F: Extendable<D>, const D: usize, const R: usize> Gate<F, D> for GMiMCGate<
let constant = builder.constant_extension(self.constants[r].into());
let cubing_input =
builder.add_many_extension(&[state[active], addition_buffer, constant]);
let square = builder.mul_extension(cubing_input, cubing_input);
let f = builder.mul_extension(square, cubing_input);
let square = builder.mul_extension_naive(cubing_input, cubing_input);
let f = builder.mul_extension_naive(square, cubing_input);
addition_buffer = builder.add_extension(addition_buffer, f);
state[active] = builder.sub_extension(state[active], f);
}

View File

@ -357,9 +357,7 @@ mod tests {
};
assert!(
gate.eval_unfiltered(vars.clone())
.iter()
.all(|x| x.is_zero()),
gate.eval_unfiltered(vars).iter().all(|x| x.is_zero()),
"Gate constraints are not satisfied."
);
}

View File

@ -1,7 +1,6 @@
use crate::circuit_builder::CircuitBuilder;
use crate::field::extension_field::target::ExtensionTarget;
use crate::field::extension_field::{Extendable, FieldExtension, OEF};
use crate::field::field::Field;
use crate::field::extension_field::{Extendable, FieldExtension};
use crate::gates::gate::{Gate, GateRef};
use crate::generator::{SimpleGenerator, WitnessGenerator};
use crate::target::Target;
@ -11,77 +10,6 @@ use crate::witness::PartialWitness;
use std::convert::TryInto;
use std::ops::Range;
// TODO: Replace this when https://github.com/mir-protocol/plonky2/issues/56 is resolved.
fn mul_vec<F: Field>(a: &[F], b: &[F], w: F) -> Vec<F> {
let (a0, a1, a2, a3) = (a[0], a[1], a[2], a[3]);
let (b0, b1, b2, b3) = (b[0], b[1], b[2], b[3]);
let c0 = a0 * b0 + w * (a1 * b3 + a2 * b2 + a3 * b1);
let c1 = a0 * b1 + a1 * b0 + w * (a2 * b3 + a3 * b2);
let c2 = a0 * b2 + a1 * b1 + a2 * b0 + w * a3 * b3;
let c3 = a0 * b3 + a1 * b2 + a2 * b1 + a3 * b0;
vec![c0, c1, c2, c3]
}
impl<F: Extendable<D>, const D: usize> CircuitBuilder<F, D> {
fn mul_vec(
&mut self,
a: &[ExtensionTarget<D>],
b: &[ExtensionTarget<D>],
w: ExtensionTarget<D>,
) -> Vec<ExtensionTarget<D>> {
let (a0, a1, a2, a3) = (a[0], a[1], a[2], a[3]);
let (b0, b1, b2, b3) = (b[0], b[1], b[2], b[3]);
// TODO: Optimize this.
let c0 = {
let tmp0 = self.mul_extension(a0, b0);
let tmp1 = self.mul_extension(a1, b3);
let tmp2 = self.mul_extension(a2, b2);
let tmp3 = self.mul_extension(a3, b1);
let tmp = self.add_extension(tmp1, tmp2);
let tmp = self.add_extension(tmp, tmp3);
let tmp = self.mul_extension(w, tmp);
let tmp = self.add_extension(tmp0, tmp);
tmp
};
let c1 = {
let tmp0 = self.mul_extension(a0, b1);
let tmp1 = self.mul_extension(a1, b0);
let tmp2 = self.mul_extension(a2, b3);
let tmp3 = self.mul_extension(a3, b2);
let tmp = self.add_extension(tmp2, tmp3);
let tmp = self.mul_extension(w, tmp);
let tmp = self.add_extension(tmp, tmp0);
let tmp = self.add_extension(tmp, tmp1);
tmp
};
let c2 = {
let tmp0 = self.mul_extension(a0, b2);
let tmp1 = self.mul_extension(a1, b1);
let tmp2 = self.mul_extension(a2, b0);
let tmp3 = self.mul_extension(a3, b3);
let tmp = self.mul_extension(w, tmp3);
let tmp = self.add_extension(tmp, tmp2);
let tmp = self.add_extension(tmp, tmp1);
let tmp = self.add_extension(tmp, tmp0);
tmp
};
let c3 = {
let tmp0 = self.mul_extension(a0, b3);
let tmp1 = self.mul_extension(a1, b2);
let tmp2 = self.mul_extension(a2, b1);
let tmp3 = self.mul_extension(a3, b0);
let tmp = self.add_extension(tmp3, tmp2);
let tmp = self.add_extension(tmp, tmp1);
let tmp = self.add_extension(tmp, tmp0);
tmp
};
vec![c0, c1, c2, c3]
}
}
/// A gate which can multiply two field extension elements.
/// TODO: Add an addend if `NUM_ROUTED_WIRES` is large enough.
#[derive(Debug)]
@ -110,25 +38,11 @@ impl<F: Extendable<D>, const D: usize> Gate<F, D> for MulExtensionGate<D> {
fn eval_unfiltered(&self, vars: EvaluationVars<F, D>) -> Vec<F::Extension> {
let const_0 = vars.local_constants[0];
let multiplicand_0 = vars.local_wires[Self::wires_multiplicand_0()].to_vec();
let multiplicand_1 = vars.local_wires[Self::wires_multiplicand_1()].to_vec();
let output = vars.local_wires[Self::wires_output()].to_vec();
let computed_output = mul_vec(
&[
const_0,
F::Extension::ZERO,
F::Extension::ZERO,
F::Extension::ZERO,
],
&multiplicand_0,
F::Extension::W.into(),
);
let computed_output = mul_vec(&computed_output, &multiplicand_1, F::Extension::W.into());
output
.into_iter()
.zip(computed_output)
.map(|(o, co)| o - co)
.collect()
let multiplicand_0 = vars.get_local_ext_algebra(Self::wires_multiplicand_0());
let multiplicand_1 = vars.get_local_ext_algebra(Self::wires_multiplicand_1());
let output = vars.get_local_ext_algebra(Self::wires_output());
let computed_output = multiplicand_0 * multiplicand_1 * const_0.into();
(output - computed_output).to_basefield_array().to_vec()
}
fn eval_unfiltered_recursively(
@ -137,18 +51,13 @@ impl<F: Extendable<D>, const D: usize> Gate<F, D> for MulExtensionGate<D> {
vars: EvaluationTargets<D>,
) -> Vec<ExtensionTarget<D>> {
let const_0 = vars.local_constants[0];
let multiplicand_0 = vars.local_wires[Self::wires_multiplicand_0()].to_vec();
let multiplicand_1 = vars.local_wires[Self::wires_multiplicand_1()].to_vec();
let output = vars.local_wires[Self::wires_output()].to_vec();
let w = builder.constant_extension(F::Extension::W.into());
let zero = builder.zero_extension();
let computed_output = builder.mul_vec(&[const_0, zero, zero, zero], &multiplicand_0, w);
let computed_output = builder.mul_vec(&computed_output, &multiplicand_1, w);
output
.into_iter()
.zip(computed_output)
.map(|(o, co)| builder.sub_extension(o, co))
.collect()
let multiplicand_0 = vars.get_local_ext_algebra(Self::wires_multiplicand_0());
let multiplicand_1 = vars.get_local_ext_algebra(Self::wires_multiplicand_1());
let output = vars.get_local_ext_algebra(Self::wires_output());
let computed_output = builder.mul_ext_algebra(multiplicand_0, multiplicand_1);
let computed_output = builder.scalar_mul_ext_algebra(const_0, computed_output);
let diff = builder.sub_ext_algebra(output, computed_output);
diff.to_ext_target_array().to_vec()
}
fn generators(
@ -236,7 +145,6 @@ impl<F: Extendable<D>, const D: usize> SimpleGenerator<F> for MulExtensionGenera
#[cfg(test)]
mod tests {
use crate::field::crandall_field::CrandallField;
use crate::gates::arithmetic::ArithmeticGate;
use crate::gates::gate_testing::test_low_degree;
use crate::gates::mul_extension::MulExtensionGate;

View File

@ -1,5 +1,6 @@
use crate::circuit_data::CircuitConfig;
use crate::wire::Wire;
use std::ops::Range;
/// A location in the witness.
#[derive(Copy, Clone, Eq, PartialEq, Hash, Debug)]
@ -21,4 +22,8 @@ impl Target {
Target::VirtualAdviceTarget { .. } => false,
}
}
pub fn wires_from_range(gate: usize, range: Range<usize>) -> Vec<Self> {
range.map(|i| Self::wire(gate, i)).collect()
}
}

View File

@ -1,4 +1,5 @@
use crate::circuit_data::CircuitConfig;
use std::ops::Range;
/// Represents a wire in the circuit.
#[derive(Copy, Clone, Eq, PartialEq, Hash, Debug)]
@ -13,4 +14,8 @@ impl Wire {
pub fn is_routable(&self, config: &CircuitConfig) -> bool {
self.input < config.num_routed_wires
}
pub fn from_range(gate: usize, range: Range<usize>) -> Vec<Self> {
range.map(|i| Wire { gate, input: i }).collect()
}
}