Working MulExtensionGate

This commit is contained in:
wborgeaud 2021-06-07 17:09:53 +02:00
parent 6f2275bc6d
commit a8da9b945e
3 changed files with 203 additions and 109 deletions

View File

@ -143,7 +143,7 @@ impl<F: Extendable<D>, const D: usize> CircuitBuilder<F, D> {
subgroup_x: Target,
) -> ExtensionTarget<D> {
assert!(D > 1, "Not implemented for D=1.");
let config = &self.config.fri_config;
let config = &self.config.fri_config.clone();
let degree_log = proof.evals_proofs[0].1.siblings.len() - config.rate_bits;
let subgroup_x = self.convert_to_ext(subgroup_x);
let mut alpha_powers = self.powers(alpha);
@ -157,7 +157,8 @@ impl<F: Extendable<D>, const D: usize> CircuitBuilder<F, D> {
let evals = [0, 1, 4]
.iter()
.flat_map(|&i| proof.unsalted_evals(i, config))
.map(|&e| self.convert_to_ext(e));
.map(|&e| self.convert_to_ext(e))
.collect::<Vec<_>>();
let openings = os
.constants
.iter()
@ -170,42 +171,42 @@ impl<F: Extendable<D>, const D: usize> CircuitBuilder<F, D> {
numerator = self.mul_add_extension(a, diff, numerator);
}
let denominator = self.sub_extension(subgroup_x, zeta);
// let quotient = self.div_unsafe()
sum += numerator / denominator;
let quotient = self.div_unsafe_extension(numerator, denominator);
let sum = self.add_extension(sum, quotient);
let ev: F::Extension = proof
.unsalted_evals(3, config)
.iter()
.zip(alpha_powers.clone())
.map(|(&e, a)| a * e.into())
.sum();
let zeta_right = F::Extension::primitive_root_of_unity(degree_log) * zeta;
let zs_interpol = interpolant(&[
(zeta, reduce_with_iter(&os.plonk_zs, alpha_powers.clone())),
(
zeta_right,
reduce_with_iter(&os.plonk_zs_right, &mut alpha_powers),
),
]);
let numerator = ev - zs_interpol.eval(subgroup_x);
let denominator = (subgroup_x - zeta) * (subgroup_x - zeta_right);
sum += numerator / denominator;
let ev: F::Extension = proof
.unsalted_evals(2, config)
.iter()
.zip(alpha_powers.clone())
.map(|(&e, a)| a * e.into())
.sum();
let zeta_frob = zeta.frobenius();
let wire_evals_frob = os.wires.iter().map(|e| e.frobenius()).collect::<Vec<_>>();
let wires_interpol = interpolant(&[
(zeta, reduce_with_iter(&os.wires, alpha_powers.clone())),
(zeta_frob, reduce_with_iter(&wire_evals_frob, alpha_powers)),
]);
let numerator = ev - wires_interpol.eval(subgroup_x);
let denominator = (subgroup_x - zeta) * (subgroup_x - zeta_frob);
sum += numerator / denominator;
// let ev: F::Extension = proof
// .unsalted_evals(3, config)
// .iter()
// .zip(alpha_powers.clone())
// .map(|(&e, a)| a * e.into())
// .sum();
// let zeta_right = F::Extension::primitive_root_of_unity(degree_log) * zeta;
// let zs_interpol = interpolant(&[
// (zeta, reduce_with_iter(&os.plonk_zs, alpha_powers.clone())),
// (
// zeta_right,
// reduce_with_iter(&os.plonk_zs_right, &mut alpha_powers),
// ),
// ]);
// let numerator = ev - zs_interpol.eval(subgroup_x);
// let denominator = (subgroup_x - zeta) * (subgroup_x - zeta_right);
// sum += numerator / denominator;
//
// let ev: F::Extension = proof
// .unsalted_evals(2, config)
// .iter()
// .zip(alpha_powers.clone())
// .map(|(&e, a)| a * e.into())
// .sum();
// let zeta_frob = zeta.frobenius();
// let wire_evals_frob = os.wires.iter().map(|e| e.frobenius()).collect::<Vec<_>>();
// let wires_interpol = interpolant(&[
// (zeta, reduce_with_iter(&os.wires, alpha_powers.clone())),
// (zeta_frob, reduce_with_iter(&wire_evals_frob, alpha_powers)),
// ]);
// let numerator = ev - wires_interpol.eval(subgroup_x);
// let denominator = (subgroup_x - zeta) * (subgroup_x - zeta_frob);
// sum += numerator / denominator;
sum
}

View File

@ -1,6 +1,6 @@
use crate::circuit_builder::CircuitBuilder;
use crate::field::extension_field::target::ExtensionTarget;
use crate::field::extension_field::Extendable;
use crate::field::extension_field::{Extendable, FieldExtension};
use crate::field::field::Field;
use crate::gates::arithmetic::ArithmeticGate;
use crate::generator::SimpleGenerator;
@ -254,11 +254,12 @@ impl<F: Extendable<D>, const D: usize> CircuitBuilder<F, D> {
};
let q = Target::Wire(wire_multiplicand_0);
self.add_generator(QuotientGeneratorExtension {
numerator: x,
denominator: ExtensionTarget(),
quotient: ExtensionTarget(),
})
todo!()
// self.add_generator(QuotientGeneratorExtension {
// numerator: x,
// denominator: ExtensionTarget(),
// quotient: ExtensionTarget(),
// })
}
}
@ -286,7 +287,7 @@ struct QuotientGeneratorExtension<const D: usize> {
quotient: ExtensionTarget<D>,
}
impl<F: Field, const D: usize> SimpleGenerator<F> for QuotientGeneratorExtension<D> {
impl<F: Extendable<D>, const D: usize> SimpleGenerator<F> for QuotientGeneratorExtension<D> {
fn dependencies(&self) -> Vec<Target> {
let mut deps = self.numerator.to_target_array().to_vec();
deps.extend(&self.denominator.to_target_array());
@ -298,7 +299,12 @@ impl<F: Field, const D: usize> SimpleGenerator<F> for QuotientGeneratorExtension
let dem = witness.get_extension_target(self.denominator);
let quotient = num / dem;
let mut pw = PartialWitness::new();
pw.set_ext_wires(self.quotient.to_target_array(), quotient);
for i in 0..D {
pw.set_target(
self.quotient.to_target_array()[i],
quotient.to_basefield_array()[i],
);
}
pw
}
}

View File

@ -1,6 +1,6 @@
use crate::circuit_builder::CircuitBuilder;
use crate::field::extension_field::target::ExtensionTarget;
use crate::field::extension_field::Extendable;
use crate::field::extension_field::{Extendable, FieldExtension, OEF};
use crate::field::field::Field;
use crate::gates::gate::{Gate, GateRef};
use crate::generator::{SimpleGenerator, WitnessGenerator};
@ -8,9 +8,81 @@ use crate::target::Target;
use crate::vars::{EvaluationTargets, EvaluationVars};
use crate::wire::Wire;
use crate::witness::PartialWitness;
use std::convert::TryInto;
use std::ops::Range;
/// A gate which can multiply to field extension elements.
// TODO: Replace this when https://github.com/mir-protocol/plonky2/issues/56 is resolved.
fn mul_vec<F: Field>(a: &[F], b: &[F], w: F) -> Vec<F> {
let (a0, a1, a2, a3) = (a[0], a[1], a[2], a[3]);
let (b0, b1, b2, b3) = (b[0], b[1], b[2], b[3]);
let c0 = a0 * b0 + w * (a1 * b3 + a2 * b2 + a3 * b1);
let c1 = a0 * b1 + a1 * b0 + w * (a2 * b3 + a3 * b2);
let c2 = a0 * b2 + a1 * b1 + a2 * b0 + w * a3 * b3;
let c3 = a0 * b3 + a1 * b2 + a2 * b1 + a3 * b0;
vec![c0, c1, c2, c3]
}
impl<F: Extendable<D>, const D: usize> CircuitBuilder<F, D> {
fn mul_vec(
&mut self,
a: &[ExtensionTarget<D>],
b: &[ExtensionTarget<D>],
w: ExtensionTarget<D>,
) -> Vec<ExtensionTarget<D>> {
let (a0, a1, a2, a3) = (a[0], a[1], a[2], a[3]);
let (b0, b1, b2, b3) = (b[0], b[1], b[2], b[3]);
// TODO: Optimize this.
let c0 = {
let tmp0 = self.mul_extension(a0, b0);
let tmp1 = self.mul_extension(a1, b3);
let tmp2 = self.mul_extension(a2, b2);
let tmp3 = self.mul_extension(a3, b1);
let tmp = self.add_extension(tmp1, tmp2);
let tmp = self.add_extension(tmp, tmp3);
let tmp = self.mul_extension(w, tmp);
let tmp = self.add_extension(tmp0, tmp);
tmp
};
let c1 = {
let tmp0 = self.mul_extension(a0, b1);
let tmp1 = self.mul_extension(a1, b0);
let tmp2 = self.mul_extension(a2, b3);
let tmp3 = self.mul_extension(a3, b2);
let tmp = self.add_extension(tmp2, tmp3);
let tmp = self.mul_extension(w, tmp);
let tmp = self.add_extension(tmp, tmp0);
let tmp = self.add_extension(tmp, tmp1);
tmp
};
let c2 = {
let tmp0 = self.mul_extension(a0, b2);
let tmp1 = self.mul_extension(a1, b1);
let tmp2 = self.mul_extension(a2, b0);
let tmp3 = self.mul_extension(a3, b3);
let tmp = self.mul_extension(w, tmp3);
let tmp = self.add_extension(tmp, tmp2);
let tmp = self.add_extension(tmp, tmp1);
let tmp = self.add_extension(tmp, tmp0);
tmp
};
let c3 = {
let tmp0 = self.mul_extension(a0, b3);
let tmp1 = self.mul_extension(a1, b2);
let tmp2 = self.mul_extension(a2, b1);
let tmp3 = self.mul_extension(a3, b0);
let tmp = self.add_extension(tmp3, tmp2);
let tmp = self.add_extension(tmp, tmp1);
let tmp = self.add_extension(tmp, tmp0);
tmp
};
vec![c0, c1, c2, c3]
}
}
/// A gate which can multiply two field extension elements.
/// TODO: Add an addend if `NUM_ROUTED_WIRES` is large enough.
#[derive(Debug)]
pub struct MulExtensionGate<const D: usize>;
@ -38,13 +110,25 @@ impl<F: Extendable<D>, const D: usize> Gate<F, D> for MulExtensionGate<D> {
fn eval_unfiltered(&self, vars: EvaluationVars<F, D>) -> Vec<F::Extension> {
let const_0 = vars.local_constants[0];
let const_1 = vars.local_constants[1];
let multiplicand_0 = vars.local_wires[Self::WIRE_MULTIPLICAND_0];
let multiplicand_1 = vars.local_wires[Self::WIRE_MULTIPLICAND_1];
let addend = vars.local_wires[Self::WIRE_ADDEND];
let output = vars.local_wires[Self::WIRE_OUTPUT];
let computed_output = const_0 * multiplicand_0 * multiplicand_1 + const_1 * addend;
vec![computed_output - output]
let multiplicand_0 = vars.local_wires[Self::wires_multiplicand_0()].to_vec();
let multiplicand_1 = vars.local_wires[Self::wires_multiplicand_1()].to_vec();
let output = vars.local_wires[Self::wires_output()].to_vec();
let computed_output = mul_vec(
&[
const_0,
F::Extension::ZERO,
F::Extension::ZERO,
F::Extension::ZERO,
],
&multiplicand_0,
F::Extension::W.into(),
);
let computed_output = mul_vec(&computed_output, &multiplicand_1, F::Extension::W.into());
output
.into_iter()
.zip(computed_output)
.map(|(o, co)| o - co)
.collect()
}
fn eval_unfiltered_recursively(
@ -53,16 +137,18 @@ impl<F: Extendable<D>, const D: usize> Gate<F, D> for MulExtensionGate<D> {
vars: EvaluationTargets<D>,
) -> Vec<ExtensionTarget<D>> {
let const_0 = vars.local_constants[0];
let const_1 = vars.local_constants[1];
let multiplicand_0 = vars.local_wires[Self::WIRE_MULTIPLICAND_0];
let multiplicand_1 = vars.local_wires[Self::WIRE_MULTIPLICAND_1];
let addend = vars.local_wires[Self::WIRE_ADDEND];
let output = vars.local_wires[Self::WIRE_OUTPUT];
let product_term = builder.mul_many_extension(&[const_0, multiplicand_0, multiplicand_1]);
let addend_term = builder.mul_extension(const_1, addend);
let computed_output = builder.add_many_extension(&[product_term, addend_term]);
vec![builder.sub_extension(computed_output, output)]
let multiplicand_0 = vars.local_wires[Self::wires_multiplicand_0()].to_vec();
let multiplicand_1 = vars.local_wires[Self::wires_multiplicand_1()].to_vec();
let output = vars.local_wires[Self::wires_output()].to_vec();
let w = builder.constant_extension(F::Extension::W.into());
let zero = builder.zero_extension();
let computed_output = builder.mul_vec(&[const_0, zero, zero, zero], &multiplicand_0, w);
let computed_output = builder.mul_vec(&computed_output, &multiplicand_1, w);
output
.into_iter()
.zip(computed_output)
.map(|(o, co)| builder.sub_extension(o, co))
.collect()
}
fn generators(
@ -70,20 +156,19 @@ impl<F: Extendable<D>, const D: usize> Gate<F, D> for MulExtensionGate<D> {
gate_index: usize,
local_constants: &[F],
) -> Vec<Box<dyn WitnessGenerator<F>>> {
let gen = ArithmeticGenerator {
let gen = MulExtensionGenerator {
gate_index,
const_0: local_constants[0],
const_1: local_constants[1],
};
vec![Box::new(gen)]
}
fn num_wires(&self) -> usize {
4
12
}
fn num_constants(&self) -> usize {
2
1
}
fn degree(&self) -> usize {
@ -91,59 +176,60 @@ impl<F: Extendable<D>, const D: usize> Gate<F, D> for MulExtensionGate<D> {
}
fn num_constraints(&self) -> usize {
1
D
}
}
struct ArithmeticGenerator<F: Field> {
struct MulExtensionGenerator<F: Extendable<D>, const D: usize> {
gate_index: usize,
const_0: F,
const_1: F,
}
impl<F: Field> SimpleGenerator<F> for ArithmeticGenerator<F> {
impl<F: Extendable<D>, const D: usize> SimpleGenerator<F> for MulExtensionGenerator<F, D> {
fn dependencies(&self) -> Vec<Target> {
vec![
Target::Wire(Wire {
gate: self.gate_index,
input: ArithmeticGate::WIRE_MULTIPLICAND_0,
}),
Target::Wire(Wire {
gate: self.gate_index,
input: ArithmeticGate::WIRE_MULTIPLICAND_1,
}),
Target::Wire(Wire {
gate: self.gate_index,
input: ArithmeticGate::WIRE_ADDEND,
}),
]
MulExtensionGate::<D>::wires_multiplicand_0()
.chain(MulExtensionGate::<D>::wires_multiplicand_1())
.map(|i| {
Target::Wire(Wire {
gate: self.gate_index,
input: i,
})
})
.collect()
}
fn run_once(&self, witness: &PartialWitness<F>) -> PartialWitness<F> {
let multiplicand_0_target = Wire {
gate: self.gate_index,
input: ArithmeticGate::WIRE_MULTIPLICAND_0,
};
let multiplicand_1_target = Wire {
gate: self.gate_index,
input: ArithmeticGate::WIRE_MULTIPLICAND_1,
};
let addend_target = Wire {
gate: self.gate_index,
input: ArithmeticGate::WIRE_ADDEND,
};
let output_target = Wire {
gate: self.gate_index,
input: ArithmeticGate::WIRE_OUTPUT,
};
let multiplicand_0 = MulExtensionGate::<D>::wires_multiplicand_0()
.map(|i| {
witness.get_wire(Wire {
gate: self.gate_index,
input: i,
})
})
.collect::<Vec<_>>();
let multiplicand_0 = F::Extension::from_basefield_array(multiplicand_0.try_into().unwrap());
let multiplicand_1 = MulExtensionGate::<D>::wires_multiplicand_1()
.map(|i| {
witness.get_wire(Wire {
gate: self.gate_index,
input: i,
})
})
.collect::<Vec<_>>();
let multiplicand_1 = F::Extension::from_basefield_array(multiplicand_1.try_into().unwrap());
let output = MulExtensionGate::<D>::wires_output()
.map(|i| Wire {
gate: self.gate_index,
input: i,
})
.collect::<Vec<_>>();
let multiplicand_0 = witness.get_wire(multiplicand_0_target);
let multiplicand_1 = witness.get_wire(multiplicand_1_target);
let addend = witness.get_wire(addend_target);
let computed_output =
F::Extension::from_basefield(self.const_0) * multiplicand_0 * multiplicand_1;
let output = self.const_0 * multiplicand_0 * multiplicand_1 + self.const_1 * addend;
PartialWitness::singleton_wire(output_target, output)
let mut pw = PartialWitness::new();
pw.set_ext_wires(output, computed_output);
pw
}
}
@ -152,9 +238,10 @@ mod tests {
use crate::field::crandall_field::CrandallField;
use crate::gates::arithmetic::ArithmeticGate;
use crate::gates::gate_testing::test_low_degree;
use crate::gates::mul_extension::MulExtensionGate;
#[test]
fn low_degree() {
test_low_degree(ArithmeticGate::new::<CrandallField, 4>())
test_low_degree(MulExtensionGate::<4>::new::<CrandallField>())
}
}