From a8da9b945ec799044a3f36bba5b3aadebfcf442a Mon Sep 17 00:00:00 2001 From: wborgeaud Date: Mon, 7 Jun 2021 17:09:53 +0200 Subject: [PATCH] Working MulExtensionGate --- src/fri/recursive_verifier.rs | 75 ++++++------ src/gadgets/arithmetic.rs | 22 ++-- src/gates/mul_extension.rs | 215 ++++++++++++++++++++++++---------- 3 files changed, 203 insertions(+), 109 deletions(-) diff --git a/src/fri/recursive_verifier.rs b/src/fri/recursive_verifier.rs index 6fe38b45..edb9b075 100644 --- a/src/fri/recursive_verifier.rs +++ b/src/fri/recursive_verifier.rs @@ -143,7 +143,7 @@ impl, const D: usize> CircuitBuilder { subgroup_x: Target, ) -> ExtensionTarget { assert!(D > 1, "Not implemented for D=1."); - let config = &self.config.fri_config; + let config = &self.config.fri_config.clone(); let degree_log = proof.evals_proofs[0].1.siblings.len() - config.rate_bits; let subgroup_x = self.convert_to_ext(subgroup_x); let mut alpha_powers = self.powers(alpha); @@ -157,7 +157,8 @@ impl, const D: usize> CircuitBuilder { let evals = [0, 1, 4] .iter() .flat_map(|&i| proof.unsalted_evals(i, config)) - .map(|&e| self.convert_to_ext(e)); + .map(|&e| self.convert_to_ext(e)) + .collect::>(); let openings = os .constants .iter() @@ -170,42 +171,42 @@ impl, const D: usize> CircuitBuilder { numerator = self.mul_add_extension(a, diff, numerator); } let denominator = self.sub_extension(subgroup_x, zeta); - // let quotient = self.div_unsafe() - sum += numerator / denominator; + let quotient = self.div_unsafe_extension(numerator, denominator); + let sum = self.add_extension(sum, quotient); - let ev: F::Extension = proof - .unsalted_evals(3, config) - .iter() - .zip(alpha_powers.clone()) - .map(|(&e, a)| a * e.into()) - .sum(); - let zeta_right = F::Extension::primitive_root_of_unity(degree_log) * zeta; - let zs_interpol = interpolant(&[ - (zeta, reduce_with_iter(&os.plonk_zs, alpha_powers.clone())), - ( - zeta_right, - reduce_with_iter(&os.plonk_zs_right, &mut alpha_powers), - ), - ]); - let numerator = ev - zs_interpol.eval(subgroup_x); - let denominator = (subgroup_x - zeta) * (subgroup_x - zeta_right); - sum += numerator / denominator; - - let ev: F::Extension = proof - .unsalted_evals(2, config) - .iter() - .zip(alpha_powers.clone()) - .map(|(&e, a)| a * e.into()) - .sum(); - let zeta_frob = zeta.frobenius(); - let wire_evals_frob = os.wires.iter().map(|e| e.frobenius()).collect::>(); - let wires_interpol = interpolant(&[ - (zeta, reduce_with_iter(&os.wires, alpha_powers.clone())), - (zeta_frob, reduce_with_iter(&wire_evals_frob, alpha_powers)), - ]); - let numerator = ev - wires_interpol.eval(subgroup_x); - let denominator = (subgroup_x - zeta) * (subgroup_x - zeta_frob); - sum += numerator / denominator; + // let ev: F::Extension = proof + // .unsalted_evals(3, config) + // .iter() + // .zip(alpha_powers.clone()) + // .map(|(&e, a)| a * e.into()) + // .sum(); + // let zeta_right = F::Extension::primitive_root_of_unity(degree_log) * zeta; + // let zs_interpol = interpolant(&[ + // (zeta, reduce_with_iter(&os.plonk_zs, alpha_powers.clone())), + // ( + // zeta_right, + // reduce_with_iter(&os.plonk_zs_right, &mut alpha_powers), + // ), + // ]); + // let numerator = ev - zs_interpol.eval(subgroup_x); + // let denominator = (subgroup_x - zeta) * (subgroup_x - zeta_right); + // sum += numerator / denominator; + // + // let ev: F::Extension = proof + // .unsalted_evals(2, config) + // .iter() + // .zip(alpha_powers.clone()) + // .map(|(&e, a)| a * e.into()) + // .sum(); + // let zeta_frob = zeta.frobenius(); + // let wire_evals_frob = os.wires.iter().map(|e| e.frobenius()).collect::>(); + // let wires_interpol = interpolant(&[ + // (zeta, reduce_with_iter(&os.wires, alpha_powers.clone())), + // (zeta_frob, reduce_with_iter(&wire_evals_frob, alpha_powers)), + // ]); + // let numerator = ev - wires_interpol.eval(subgroup_x); + // let denominator = (subgroup_x - zeta) * (subgroup_x - zeta_frob); + // sum += numerator / denominator; sum } diff --git a/src/gadgets/arithmetic.rs b/src/gadgets/arithmetic.rs index a84caf03..fd32c83d 100644 --- a/src/gadgets/arithmetic.rs +++ b/src/gadgets/arithmetic.rs @@ -1,6 +1,6 @@ use crate::circuit_builder::CircuitBuilder; use crate::field::extension_field::target::ExtensionTarget; -use crate::field::extension_field::Extendable; +use crate::field::extension_field::{Extendable, FieldExtension}; use crate::field::field::Field; use crate::gates::arithmetic::ArithmeticGate; use crate::generator::SimpleGenerator; @@ -254,11 +254,12 @@ impl, const D: usize> CircuitBuilder { }; let q = Target::Wire(wire_multiplicand_0); - self.add_generator(QuotientGeneratorExtension { - numerator: x, - denominator: ExtensionTarget(), - quotient: ExtensionTarget(), - }) + todo!() + // self.add_generator(QuotientGeneratorExtension { + // numerator: x, + // denominator: ExtensionTarget(), + // quotient: ExtensionTarget(), + // }) } } @@ -286,7 +287,7 @@ struct QuotientGeneratorExtension { quotient: ExtensionTarget, } -impl SimpleGenerator for QuotientGeneratorExtension { +impl, const D: usize> SimpleGenerator for QuotientGeneratorExtension { fn dependencies(&self) -> Vec { let mut deps = self.numerator.to_target_array().to_vec(); deps.extend(&self.denominator.to_target_array()); @@ -298,7 +299,12 @@ impl SimpleGenerator for QuotientGeneratorExtension let dem = witness.get_extension_target(self.denominator); let quotient = num / dem; let mut pw = PartialWitness::new(); - pw.set_ext_wires(self.quotient.to_target_array(), quotient); + for i in 0..D { + pw.set_target( + self.quotient.to_target_array()[i], + quotient.to_basefield_array()[i], + ); + } pw } } diff --git a/src/gates/mul_extension.rs b/src/gates/mul_extension.rs index ec5ba671..18ec5827 100644 --- a/src/gates/mul_extension.rs +++ b/src/gates/mul_extension.rs @@ -1,6 +1,6 @@ use crate::circuit_builder::CircuitBuilder; use crate::field::extension_field::target::ExtensionTarget; -use crate::field::extension_field::Extendable; +use crate::field::extension_field::{Extendable, FieldExtension, OEF}; use crate::field::field::Field; use crate::gates::gate::{Gate, GateRef}; use crate::generator::{SimpleGenerator, WitnessGenerator}; @@ -8,9 +8,81 @@ use crate::target::Target; use crate::vars::{EvaluationTargets, EvaluationVars}; use crate::wire::Wire; use crate::witness::PartialWitness; +use std::convert::TryInto; use std::ops::Range; -/// A gate which can multiply to field extension elements. +// TODO: Replace this when https://github.com/mir-protocol/plonky2/issues/56 is resolved. +fn mul_vec(a: &[F], b: &[F], w: F) -> Vec { + let (a0, a1, a2, a3) = (a[0], a[1], a[2], a[3]); + let (b0, b1, b2, b3) = (b[0], b[1], b[2], b[3]); + + let c0 = a0 * b0 + w * (a1 * b3 + a2 * b2 + a3 * b1); + let c1 = a0 * b1 + a1 * b0 + w * (a2 * b3 + a3 * b2); + let c2 = a0 * b2 + a1 * b1 + a2 * b0 + w * a3 * b3; + let c3 = a0 * b3 + a1 * b2 + a2 * b1 + a3 * b0; + + vec![c0, c1, c2, c3] +} +impl, const D: usize> CircuitBuilder { + fn mul_vec( + &mut self, + a: &[ExtensionTarget], + b: &[ExtensionTarget], + w: ExtensionTarget, + ) -> Vec> { + let (a0, a1, a2, a3) = (a[0], a[1], a[2], a[3]); + let (b0, b1, b2, b3) = (b[0], b[1], b[2], b[3]); + + // TODO: Optimize this. + let c0 = { + let tmp0 = self.mul_extension(a0, b0); + let tmp1 = self.mul_extension(a1, b3); + let tmp2 = self.mul_extension(a2, b2); + let tmp3 = self.mul_extension(a3, b1); + let tmp = self.add_extension(tmp1, tmp2); + let tmp = self.add_extension(tmp, tmp3); + let tmp = self.mul_extension(w, tmp); + let tmp = self.add_extension(tmp0, tmp); + tmp + }; + let c1 = { + let tmp0 = self.mul_extension(a0, b1); + let tmp1 = self.mul_extension(a1, b0); + let tmp2 = self.mul_extension(a2, b3); + let tmp3 = self.mul_extension(a3, b2); + let tmp = self.add_extension(tmp2, tmp3); + let tmp = self.mul_extension(w, tmp); + let tmp = self.add_extension(tmp, tmp0); + let tmp = self.add_extension(tmp, tmp1); + tmp + }; + let c2 = { + let tmp0 = self.mul_extension(a0, b2); + let tmp1 = self.mul_extension(a1, b1); + let tmp2 = self.mul_extension(a2, b0); + let tmp3 = self.mul_extension(a3, b3); + let tmp = self.mul_extension(w, tmp3); + let tmp = self.add_extension(tmp, tmp2); + let tmp = self.add_extension(tmp, tmp1); + let tmp = self.add_extension(tmp, tmp0); + tmp + }; + let c3 = { + let tmp0 = self.mul_extension(a0, b3); + let tmp1 = self.mul_extension(a1, b2); + let tmp2 = self.mul_extension(a2, b1); + let tmp3 = self.mul_extension(a3, b0); + let tmp = self.add_extension(tmp3, tmp2); + let tmp = self.add_extension(tmp, tmp1); + let tmp = self.add_extension(tmp, tmp0); + tmp + }; + + vec![c0, c1, c2, c3] + } +} + +/// A gate which can multiply two field extension elements. /// TODO: Add an addend if `NUM_ROUTED_WIRES` is large enough. #[derive(Debug)] pub struct MulExtensionGate; @@ -38,13 +110,25 @@ impl, const D: usize> Gate for MulExtensionGate { fn eval_unfiltered(&self, vars: EvaluationVars) -> Vec { let const_0 = vars.local_constants[0]; - let const_1 = vars.local_constants[1]; - let multiplicand_0 = vars.local_wires[Self::WIRE_MULTIPLICAND_0]; - let multiplicand_1 = vars.local_wires[Self::WIRE_MULTIPLICAND_1]; - let addend = vars.local_wires[Self::WIRE_ADDEND]; - let output = vars.local_wires[Self::WIRE_OUTPUT]; - let computed_output = const_0 * multiplicand_0 * multiplicand_1 + const_1 * addend; - vec![computed_output - output] + let multiplicand_0 = vars.local_wires[Self::wires_multiplicand_0()].to_vec(); + let multiplicand_1 = vars.local_wires[Self::wires_multiplicand_1()].to_vec(); + let output = vars.local_wires[Self::wires_output()].to_vec(); + let computed_output = mul_vec( + &[ + const_0, + F::Extension::ZERO, + F::Extension::ZERO, + F::Extension::ZERO, + ], + &multiplicand_0, + F::Extension::W.into(), + ); + let computed_output = mul_vec(&computed_output, &multiplicand_1, F::Extension::W.into()); + output + .into_iter() + .zip(computed_output) + .map(|(o, co)| o - co) + .collect() } fn eval_unfiltered_recursively( @@ -53,16 +137,18 @@ impl, const D: usize> Gate for MulExtensionGate { vars: EvaluationTargets, ) -> Vec> { let const_0 = vars.local_constants[0]; - let const_1 = vars.local_constants[1]; - let multiplicand_0 = vars.local_wires[Self::WIRE_MULTIPLICAND_0]; - let multiplicand_1 = vars.local_wires[Self::WIRE_MULTIPLICAND_1]; - let addend = vars.local_wires[Self::WIRE_ADDEND]; - let output = vars.local_wires[Self::WIRE_OUTPUT]; - - let product_term = builder.mul_many_extension(&[const_0, multiplicand_0, multiplicand_1]); - let addend_term = builder.mul_extension(const_1, addend); - let computed_output = builder.add_many_extension(&[product_term, addend_term]); - vec![builder.sub_extension(computed_output, output)] + let multiplicand_0 = vars.local_wires[Self::wires_multiplicand_0()].to_vec(); + let multiplicand_1 = vars.local_wires[Self::wires_multiplicand_1()].to_vec(); + let output = vars.local_wires[Self::wires_output()].to_vec(); + let w = builder.constant_extension(F::Extension::W.into()); + let zero = builder.zero_extension(); + let computed_output = builder.mul_vec(&[const_0, zero, zero, zero], &multiplicand_0, w); + let computed_output = builder.mul_vec(&computed_output, &multiplicand_1, w); + output + .into_iter() + .zip(computed_output) + .map(|(o, co)| builder.sub_extension(o, co)) + .collect() } fn generators( @@ -70,20 +156,19 @@ impl, const D: usize> Gate for MulExtensionGate { gate_index: usize, local_constants: &[F], ) -> Vec>> { - let gen = ArithmeticGenerator { + let gen = MulExtensionGenerator { gate_index, const_0: local_constants[0], - const_1: local_constants[1], }; vec![Box::new(gen)] } fn num_wires(&self) -> usize { - 4 + 12 } fn num_constants(&self) -> usize { - 2 + 1 } fn degree(&self) -> usize { @@ -91,59 +176,60 @@ impl, const D: usize> Gate for MulExtensionGate { } fn num_constraints(&self) -> usize { - 1 + D } } -struct ArithmeticGenerator { +struct MulExtensionGenerator, const D: usize> { gate_index: usize, const_0: F, - const_1: F, } -impl SimpleGenerator for ArithmeticGenerator { +impl, const D: usize> SimpleGenerator for MulExtensionGenerator { fn dependencies(&self) -> Vec { - vec![ - Target::Wire(Wire { - gate: self.gate_index, - input: ArithmeticGate::WIRE_MULTIPLICAND_0, - }), - Target::Wire(Wire { - gate: self.gate_index, - input: ArithmeticGate::WIRE_MULTIPLICAND_1, - }), - Target::Wire(Wire { - gate: self.gate_index, - input: ArithmeticGate::WIRE_ADDEND, - }), - ] + MulExtensionGate::::wires_multiplicand_0() + .chain(MulExtensionGate::::wires_multiplicand_1()) + .map(|i| { + Target::Wire(Wire { + gate: self.gate_index, + input: i, + }) + }) + .collect() } fn run_once(&self, witness: &PartialWitness) -> PartialWitness { - let multiplicand_0_target = Wire { - gate: self.gate_index, - input: ArithmeticGate::WIRE_MULTIPLICAND_0, - }; - let multiplicand_1_target = Wire { - gate: self.gate_index, - input: ArithmeticGate::WIRE_MULTIPLICAND_1, - }; - let addend_target = Wire { - gate: self.gate_index, - input: ArithmeticGate::WIRE_ADDEND, - }; - let output_target = Wire { - gate: self.gate_index, - input: ArithmeticGate::WIRE_OUTPUT, - }; + let multiplicand_0 = MulExtensionGate::::wires_multiplicand_0() + .map(|i| { + witness.get_wire(Wire { + gate: self.gate_index, + input: i, + }) + }) + .collect::>(); + let multiplicand_0 = F::Extension::from_basefield_array(multiplicand_0.try_into().unwrap()); + let multiplicand_1 = MulExtensionGate::::wires_multiplicand_1() + .map(|i| { + witness.get_wire(Wire { + gate: self.gate_index, + input: i, + }) + }) + .collect::>(); + let multiplicand_1 = F::Extension::from_basefield_array(multiplicand_1.try_into().unwrap()); + let output = MulExtensionGate::::wires_output() + .map(|i| Wire { + gate: self.gate_index, + input: i, + }) + .collect::>(); - let multiplicand_0 = witness.get_wire(multiplicand_0_target); - let multiplicand_1 = witness.get_wire(multiplicand_1_target); - let addend = witness.get_wire(addend_target); + let computed_output = + F::Extension::from_basefield(self.const_0) * multiplicand_0 * multiplicand_1; - let output = self.const_0 * multiplicand_0 * multiplicand_1 + self.const_1 * addend; - - PartialWitness::singleton_wire(output_target, output) + let mut pw = PartialWitness::new(); + pw.set_ext_wires(output, computed_output); + pw } } @@ -152,9 +238,10 @@ mod tests { use crate::field::crandall_field::CrandallField; use crate::gates::arithmetic::ArithmeticGate; use crate::gates::gate_testing::test_low_degree; + use crate::gates::mul_extension::MulExtensionGate; #[test] fn low_degree() { - test_low_degree(ArithmeticGate::new::()) + test_low_degree(MulExtensionGate::<4>::new::()) } }