mirror of
https://github.com/logos-storage/plonky2.git
synced 2026-01-07 00:03:10 +00:00
Working MulExtensionGate
This commit is contained in:
parent
6f2275bc6d
commit
a8da9b945e
@ -143,7 +143,7 @@ impl<F: Extendable<D>, const D: usize> CircuitBuilder<F, D> {
|
||||
subgroup_x: Target,
|
||||
) -> ExtensionTarget<D> {
|
||||
assert!(D > 1, "Not implemented for D=1.");
|
||||
let config = &self.config.fri_config;
|
||||
let config = &self.config.fri_config.clone();
|
||||
let degree_log = proof.evals_proofs[0].1.siblings.len() - config.rate_bits;
|
||||
let subgroup_x = self.convert_to_ext(subgroup_x);
|
||||
let mut alpha_powers = self.powers(alpha);
|
||||
@ -157,7 +157,8 @@ impl<F: Extendable<D>, const D: usize> CircuitBuilder<F, D> {
|
||||
let evals = [0, 1, 4]
|
||||
.iter()
|
||||
.flat_map(|&i| proof.unsalted_evals(i, config))
|
||||
.map(|&e| self.convert_to_ext(e));
|
||||
.map(|&e| self.convert_to_ext(e))
|
||||
.collect::<Vec<_>>();
|
||||
let openings = os
|
||||
.constants
|
||||
.iter()
|
||||
@ -170,42 +171,42 @@ impl<F: Extendable<D>, const D: usize> CircuitBuilder<F, D> {
|
||||
numerator = self.mul_add_extension(a, diff, numerator);
|
||||
}
|
||||
let denominator = self.sub_extension(subgroup_x, zeta);
|
||||
// let quotient = self.div_unsafe()
|
||||
sum += numerator / denominator;
|
||||
let quotient = self.div_unsafe_extension(numerator, denominator);
|
||||
let sum = self.add_extension(sum, quotient);
|
||||
|
||||
let ev: F::Extension = proof
|
||||
.unsalted_evals(3, config)
|
||||
.iter()
|
||||
.zip(alpha_powers.clone())
|
||||
.map(|(&e, a)| a * e.into())
|
||||
.sum();
|
||||
let zeta_right = F::Extension::primitive_root_of_unity(degree_log) * zeta;
|
||||
let zs_interpol = interpolant(&[
|
||||
(zeta, reduce_with_iter(&os.plonk_zs, alpha_powers.clone())),
|
||||
(
|
||||
zeta_right,
|
||||
reduce_with_iter(&os.plonk_zs_right, &mut alpha_powers),
|
||||
),
|
||||
]);
|
||||
let numerator = ev - zs_interpol.eval(subgroup_x);
|
||||
let denominator = (subgroup_x - zeta) * (subgroup_x - zeta_right);
|
||||
sum += numerator / denominator;
|
||||
|
||||
let ev: F::Extension = proof
|
||||
.unsalted_evals(2, config)
|
||||
.iter()
|
||||
.zip(alpha_powers.clone())
|
||||
.map(|(&e, a)| a * e.into())
|
||||
.sum();
|
||||
let zeta_frob = zeta.frobenius();
|
||||
let wire_evals_frob = os.wires.iter().map(|e| e.frobenius()).collect::<Vec<_>>();
|
||||
let wires_interpol = interpolant(&[
|
||||
(zeta, reduce_with_iter(&os.wires, alpha_powers.clone())),
|
||||
(zeta_frob, reduce_with_iter(&wire_evals_frob, alpha_powers)),
|
||||
]);
|
||||
let numerator = ev - wires_interpol.eval(subgroup_x);
|
||||
let denominator = (subgroup_x - zeta) * (subgroup_x - zeta_frob);
|
||||
sum += numerator / denominator;
|
||||
// let ev: F::Extension = proof
|
||||
// .unsalted_evals(3, config)
|
||||
// .iter()
|
||||
// .zip(alpha_powers.clone())
|
||||
// .map(|(&e, a)| a * e.into())
|
||||
// .sum();
|
||||
// let zeta_right = F::Extension::primitive_root_of_unity(degree_log) * zeta;
|
||||
// let zs_interpol = interpolant(&[
|
||||
// (zeta, reduce_with_iter(&os.plonk_zs, alpha_powers.clone())),
|
||||
// (
|
||||
// zeta_right,
|
||||
// reduce_with_iter(&os.plonk_zs_right, &mut alpha_powers),
|
||||
// ),
|
||||
// ]);
|
||||
// let numerator = ev - zs_interpol.eval(subgroup_x);
|
||||
// let denominator = (subgroup_x - zeta) * (subgroup_x - zeta_right);
|
||||
// sum += numerator / denominator;
|
||||
//
|
||||
// let ev: F::Extension = proof
|
||||
// .unsalted_evals(2, config)
|
||||
// .iter()
|
||||
// .zip(alpha_powers.clone())
|
||||
// .map(|(&e, a)| a * e.into())
|
||||
// .sum();
|
||||
// let zeta_frob = zeta.frobenius();
|
||||
// let wire_evals_frob = os.wires.iter().map(|e| e.frobenius()).collect::<Vec<_>>();
|
||||
// let wires_interpol = interpolant(&[
|
||||
// (zeta, reduce_with_iter(&os.wires, alpha_powers.clone())),
|
||||
// (zeta_frob, reduce_with_iter(&wire_evals_frob, alpha_powers)),
|
||||
// ]);
|
||||
// let numerator = ev - wires_interpol.eval(subgroup_x);
|
||||
// let denominator = (subgroup_x - zeta) * (subgroup_x - zeta_frob);
|
||||
// sum += numerator / denominator;
|
||||
|
||||
sum
|
||||
}
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
use crate::circuit_builder::CircuitBuilder;
|
||||
use crate::field::extension_field::target::ExtensionTarget;
|
||||
use crate::field::extension_field::Extendable;
|
||||
use crate::field::extension_field::{Extendable, FieldExtension};
|
||||
use crate::field::field::Field;
|
||||
use crate::gates::arithmetic::ArithmeticGate;
|
||||
use crate::generator::SimpleGenerator;
|
||||
@ -254,11 +254,12 @@ impl<F: Extendable<D>, const D: usize> CircuitBuilder<F, D> {
|
||||
};
|
||||
|
||||
let q = Target::Wire(wire_multiplicand_0);
|
||||
self.add_generator(QuotientGeneratorExtension {
|
||||
numerator: x,
|
||||
denominator: ExtensionTarget(),
|
||||
quotient: ExtensionTarget(),
|
||||
})
|
||||
todo!()
|
||||
// self.add_generator(QuotientGeneratorExtension {
|
||||
// numerator: x,
|
||||
// denominator: ExtensionTarget(),
|
||||
// quotient: ExtensionTarget(),
|
||||
// })
|
||||
}
|
||||
}
|
||||
|
||||
@ -286,7 +287,7 @@ struct QuotientGeneratorExtension<const D: usize> {
|
||||
quotient: ExtensionTarget<D>,
|
||||
}
|
||||
|
||||
impl<F: Field, const D: usize> SimpleGenerator<F> for QuotientGeneratorExtension<D> {
|
||||
impl<F: Extendable<D>, const D: usize> SimpleGenerator<F> for QuotientGeneratorExtension<D> {
|
||||
fn dependencies(&self) -> Vec<Target> {
|
||||
let mut deps = self.numerator.to_target_array().to_vec();
|
||||
deps.extend(&self.denominator.to_target_array());
|
||||
@ -298,7 +299,12 @@ impl<F: Field, const D: usize> SimpleGenerator<F> for QuotientGeneratorExtension
|
||||
let dem = witness.get_extension_target(self.denominator);
|
||||
let quotient = num / dem;
|
||||
let mut pw = PartialWitness::new();
|
||||
pw.set_ext_wires(self.quotient.to_target_array(), quotient);
|
||||
for i in 0..D {
|
||||
pw.set_target(
|
||||
self.quotient.to_target_array()[i],
|
||||
quotient.to_basefield_array()[i],
|
||||
);
|
||||
}
|
||||
pw
|
||||
}
|
||||
}
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
use crate::circuit_builder::CircuitBuilder;
|
||||
use crate::field::extension_field::target::ExtensionTarget;
|
||||
use crate::field::extension_field::Extendable;
|
||||
use crate::field::extension_field::{Extendable, FieldExtension, OEF};
|
||||
use crate::field::field::Field;
|
||||
use crate::gates::gate::{Gate, GateRef};
|
||||
use crate::generator::{SimpleGenerator, WitnessGenerator};
|
||||
@ -8,9 +8,81 @@ use crate::target::Target;
|
||||
use crate::vars::{EvaluationTargets, EvaluationVars};
|
||||
use crate::wire::Wire;
|
||||
use crate::witness::PartialWitness;
|
||||
use std::convert::TryInto;
|
||||
use std::ops::Range;
|
||||
|
||||
/// A gate which can multiply to field extension elements.
|
||||
// TODO: Replace this when https://github.com/mir-protocol/plonky2/issues/56 is resolved.
|
||||
fn mul_vec<F: Field>(a: &[F], b: &[F], w: F) -> Vec<F> {
|
||||
let (a0, a1, a2, a3) = (a[0], a[1], a[2], a[3]);
|
||||
let (b0, b1, b2, b3) = (b[0], b[1], b[2], b[3]);
|
||||
|
||||
let c0 = a0 * b0 + w * (a1 * b3 + a2 * b2 + a3 * b1);
|
||||
let c1 = a0 * b1 + a1 * b0 + w * (a2 * b3 + a3 * b2);
|
||||
let c2 = a0 * b2 + a1 * b1 + a2 * b0 + w * a3 * b3;
|
||||
let c3 = a0 * b3 + a1 * b2 + a2 * b1 + a3 * b0;
|
||||
|
||||
vec![c0, c1, c2, c3]
|
||||
}
|
||||
impl<F: Extendable<D>, const D: usize> CircuitBuilder<F, D> {
|
||||
fn mul_vec(
|
||||
&mut self,
|
||||
a: &[ExtensionTarget<D>],
|
||||
b: &[ExtensionTarget<D>],
|
||||
w: ExtensionTarget<D>,
|
||||
) -> Vec<ExtensionTarget<D>> {
|
||||
let (a0, a1, a2, a3) = (a[0], a[1], a[2], a[3]);
|
||||
let (b0, b1, b2, b3) = (b[0], b[1], b[2], b[3]);
|
||||
|
||||
// TODO: Optimize this.
|
||||
let c0 = {
|
||||
let tmp0 = self.mul_extension(a0, b0);
|
||||
let tmp1 = self.mul_extension(a1, b3);
|
||||
let tmp2 = self.mul_extension(a2, b2);
|
||||
let tmp3 = self.mul_extension(a3, b1);
|
||||
let tmp = self.add_extension(tmp1, tmp2);
|
||||
let tmp = self.add_extension(tmp, tmp3);
|
||||
let tmp = self.mul_extension(w, tmp);
|
||||
let tmp = self.add_extension(tmp0, tmp);
|
||||
tmp
|
||||
};
|
||||
let c1 = {
|
||||
let tmp0 = self.mul_extension(a0, b1);
|
||||
let tmp1 = self.mul_extension(a1, b0);
|
||||
let tmp2 = self.mul_extension(a2, b3);
|
||||
let tmp3 = self.mul_extension(a3, b2);
|
||||
let tmp = self.add_extension(tmp2, tmp3);
|
||||
let tmp = self.mul_extension(w, tmp);
|
||||
let tmp = self.add_extension(tmp, tmp0);
|
||||
let tmp = self.add_extension(tmp, tmp1);
|
||||
tmp
|
||||
};
|
||||
let c2 = {
|
||||
let tmp0 = self.mul_extension(a0, b2);
|
||||
let tmp1 = self.mul_extension(a1, b1);
|
||||
let tmp2 = self.mul_extension(a2, b0);
|
||||
let tmp3 = self.mul_extension(a3, b3);
|
||||
let tmp = self.mul_extension(w, tmp3);
|
||||
let tmp = self.add_extension(tmp, tmp2);
|
||||
let tmp = self.add_extension(tmp, tmp1);
|
||||
let tmp = self.add_extension(tmp, tmp0);
|
||||
tmp
|
||||
};
|
||||
let c3 = {
|
||||
let tmp0 = self.mul_extension(a0, b3);
|
||||
let tmp1 = self.mul_extension(a1, b2);
|
||||
let tmp2 = self.mul_extension(a2, b1);
|
||||
let tmp3 = self.mul_extension(a3, b0);
|
||||
let tmp = self.add_extension(tmp3, tmp2);
|
||||
let tmp = self.add_extension(tmp, tmp1);
|
||||
let tmp = self.add_extension(tmp, tmp0);
|
||||
tmp
|
||||
};
|
||||
|
||||
vec![c0, c1, c2, c3]
|
||||
}
|
||||
}
|
||||
|
||||
/// A gate which can multiply two field extension elements.
|
||||
/// TODO: Add an addend if `NUM_ROUTED_WIRES` is large enough.
|
||||
#[derive(Debug)]
|
||||
pub struct MulExtensionGate<const D: usize>;
|
||||
@ -38,13 +110,25 @@ impl<F: Extendable<D>, const D: usize> Gate<F, D> for MulExtensionGate<D> {
|
||||
|
||||
fn eval_unfiltered(&self, vars: EvaluationVars<F, D>) -> Vec<F::Extension> {
|
||||
let const_0 = vars.local_constants[0];
|
||||
let const_1 = vars.local_constants[1];
|
||||
let multiplicand_0 = vars.local_wires[Self::WIRE_MULTIPLICAND_0];
|
||||
let multiplicand_1 = vars.local_wires[Self::WIRE_MULTIPLICAND_1];
|
||||
let addend = vars.local_wires[Self::WIRE_ADDEND];
|
||||
let output = vars.local_wires[Self::WIRE_OUTPUT];
|
||||
let computed_output = const_0 * multiplicand_0 * multiplicand_1 + const_1 * addend;
|
||||
vec![computed_output - output]
|
||||
let multiplicand_0 = vars.local_wires[Self::wires_multiplicand_0()].to_vec();
|
||||
let multiplicand_1 = vars.local_wires[Self::wires_multiplicand_1()].to_vec();
|
||||
let output = vars.local_wires[Self::wires_output()].to_vec();
|
||||
let computed_output = mul_vec(
|
||||
&[
|
||||
const_0,
|
||||
F::Extension::ZERO,
|
||||
F::Extension::ZERO,
|
||||
F::Extension::ZERO,
|
||||
],
|
||||
&multiplicand_0,
|
||||
F::Extension::W.into(),
|
||||
);
|
||||
let computed_output = mul_vec(&computed_output, &multiplicand_1, F::Extension::W.into());
|
||||
output
|
||||
.into_iter()
|
||||
.zip(computed_output)
|
||||
.map(|(o, co)| o - co)
|
||||
.collect()
|
||||
}
|
||||
|
||||
fn eval_unfiltered_recursively(
|
||||
@ -53,16 +137,18 @@ impl<F: Extendable<D>, const D: usize> Gate<F, D> for MulExtensionGate<D> {
|
||||
vars: EvaluationTargets<D>,
|
||||
) -> Vec<ExtensionTarget<D>> {
|
||||
let const_0 = vars.local_constants[0];
|
||||
let const_1 = vars.local_constants[1];
|
||||
let multiplicand_0 = vars.local_wires[Self::WIRE_MULTIPLICAND_0];
|
||||
let multiplicand_1 = vars.local_wires[Self::WIRE_MULTIPLICAND_1];
|
||||
let addend = vars.local_wires[Self::WIRE_ADDEND];
|
||||
let output = vars.local_wires[Self::WIRE_OUTPUT];
|
||||
|
||||
let product_term = builder.mul_many_extension(&[const_0, multiplicand_0, multiplicand_1]);
|
||||
let addend_term = builder.mul_extension(const_1, addend);
|
||||
let computed_output = builder.add_many_extension(&[product_term, addend_term]);
|
||||
vec![builder.sub_extension(computed_output, output)]
|
||||
let multiplicand_0 = vars.local_wires[Self::wires_multiplicand_0()].to_vec();
|
||||
let multiplicand_1 = vars.local_wires[Self::wires_multiplicand_1()].to_vec();
|
||||
let output = vars.local_wires[Self::wires_output()].to_vec();
|
||||
let w = builder.constant_extension(F::Extension::W.into());
|
||||
let zero = builder.zero_extension();
|
||||
let computed_output = builder.mul_vec(&[const_0, zero, zero, zero], &multiplicand_0, w);
|
||||
let computed_output = builder.mul_vec(&computed_output, &multiplicand_1, w);
|
||||
output
|
||||
.into_iter()
|
||||
.zip(computed_output)
|
||||
.map(|(o, co)| builder.sub_extension(o, co))
|
||||
.collect()
|
||||
}
|
||||
|
||||
fn generators(
|
||||
@ -70,20 +156,19 @@ impl<F: Extendable<D>, const D: usize> Gate<F, D> for MulExtensionGate<D> {
|
||||
gate_index: usize,
|
||||
local_constants: &[F],
|
||||
) -> Vec<Box<dyn WitnessGenerator<F>>> {
|
||||
let gen = ArithmeticGenerator {
|
||||
let gen = MulExtensionGenerator {
|
||||
gate_index,
|
||||
const_0: local_constants[0],
|
||||
const_1: local_constants[1],
|
||||
};
|
||||
vec![Box::new(gen)]
|
||||
}
|
||||
|
||||
fn num_wires(&self) -> usize {
|
||||
4
|
||||
12
|
||||
}
|
||||
|
||||
fn num_constants(&self) -> usize {
|
||||
2
|
||||
1
|
||||
}
|
||||
|
||||
fn degree(&self) -> usize {
|
||||
@ -91,59 +176,60 @@ impl<F: Extendable<D>, const D: usize> Gate<F, D> for MulExtensionGate<D> {
|
||||
}
|
||||
|
||||
fn num_constraints(&self) -> usize {
|
||||
1
|
||||
D
|
||||
}
|
||||
}
|
||||
|
||||
struct ArithmeticGenerator<F: Field> {
|
||||
struct MulExtensionGenerator<F: Extendable<D>, const D: usize> {
|
||||
gate_index: usize,
|
||||
const_0: F,
|
||||
const_1: F,
|
||||
}
|
||||
|
||||
impl<F: Field> SimpleGenerator<F> for ArithmeticGenerator<F> {
|
||||
impl<F: Extendable<D>, const D: usize> SimpleGenerator<F> for MulExtensionGenerator<F, D> {
|
||||
fn dependencies(&self) -> Vec<Target> {
|
||||
vec![
|
||||
Target::Wire(Wire {
|
||||
gate: self.gate_index,
|
||||
input: ArithmeticGate::WIRE_MULTIPLICAND_0,
|
||||
}),
|
||||
Target::Wire(Wire {
|
||||
gate: self.gate_index,
|
||||
input: ArithmeticGate::WIRE_MULTIPLICAND_1,
|
||||
}),
|
||||
Target::Wire(Wire {
|
||||
gate: self.gate_index,
|
||||
input: ArithmeticGate::WIRE_ADDEND,
|
||||
}),
|
||||
]
|
||||
MulExtensionGate::<D>::wires_multiplicand_0()
|
||||
.chain(MulExtensionGate::<D>::wires_multiplicand_1())
|
||||
.map(|i| {
|
||||
Target::Wire(Wire {
|
||||
gate: self.gate_index,
|
||||
input: i,
|
||||
})
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
fn run_once(&self, witness: &PartialWitness<F>) -> PartialWitness<F> {
|
||||
let multiplicand_0_target = Wire {
|
||||
gate: self.gate_index,
|
||||
input: ArithmeticGate::WIRE_MULTIPLICAND_0,
|
||||
};
|
||||
let multiplicand_1_target = Wire {
|
||||
gate: self.gate_index,
|
||||
input: ArithmeticGate::WIRE_MULTIPLICAND_1,
|
||||
};
|
||||
let addend_target = Wire {
|
||||
gate: self.gate_index,
|
||||
input: ArithmeticGate::WIRE_ADDEND,
|
||||
};
|
||||
let output_target = Wire {
|
||||
gate: self.gate_index,
|
||||
input: ArithmeticGate::WIRE_OUTPUT,
|
||||
};
|
||||
let multiplicand_0 = MulExtensionGate::<D>::wires_multiplicand_0()
|
||||
.map(|i| {
|
||||
witness.get_wire(Wire {
|
||||
gate: self.gate_index,
|
||||
input: i,
|
||||
})
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
let multiplicand_0 = F::Extension::from_basefield_array(multiplicand_0.try_into().unwrap());
|
||||
let multiplicand_1 = MulExtensionGate::<D>::wires_multiplicand_1()
|
||||
.map(|i| {
|
||||
witness.get_wire(Wire {
|
||||
gate: self.gate_index,
|
||||
input: i,
|
||||
})
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
let multiplicand_1 = F::Extension::from_basefield_array(multiplicand_1.try_into().unwrap());
|
||||
let output = MulExtensionGate::<D>::wires_output()
|
||||
.map(|i| Wire {
|
||||
gate: self.gate_index,
|
||||
input: i,
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let multiplicand_0 = witness.get_wire(multiplicand_0_target);
|
||||
let multiplicand_1 = witness.get_wire(multiplicand_1_target);
|
||||
let addend = witness.get_wire(addend_target);
|
||||
let computed_output =
|
||||
F::Extension::from_basefield(self.const_0) * multiplicand_0 * multiplicand_1;
|
||||
|
||||
let output = self.const_0 * multiplicand_0 * multiplicand_1 + self.const_1 * addend;
|
||||
|
||||
PartialWitness::singleton_wire(output_target, output)
|
||||
let mut pw = PartialWitness::new();
|
||||
pw.set_ext_wires(output, computed_output);
|
||||
pw
|
||||
}
|
||||
}
|
||||
|
||||
@ -152,9 +238,10 @@ mod tests {
|
||||
use crate::field::crandall_field::CrandallField;
|
||||
use crate::gates::arithmetic::ArithmeticGate;
|
||||
use crate::gates::gate_testing::test_low_degree;
|
||||
use crate::gates::mul_extension::MulExtensionGate;
|
||||
|
||||
#[test]
|
||||
fn low_degree() {
|
||||
test_low_degree(ArithmeticGate::new::<CrandallField, 4>())
|
||||
test_low_degree(MulExtensionGate::<4>::new::<CrandallField>())
|
||||
}
|
||||
}
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user