From 9adf5bb43f530f49962065b3d4be7494e909f043 Mon Sep 17 00:00:00 2001 From: wborgeaud Date: Wed, 9 Jun 2021 10:51:50 +0200 Subject: [PATCH] Use `ExtensionAlgebra` + new `CircuitBuilder::mul_extension` --- src/field/extension_field/mod.rs | 7 +- src/field/extension_field/target.rs | 59 +++++++++++--- src/fri/recursive_verifier.rs | 6 +- src/gadgets/arithmetic.rs | 16 ++-- src/gadgets/polynomial.rs | 2 +- src/gadgets/split_base.rs | 8 +- src/gates/arithmetic.rs | 2 +- src/gates/base_sum.rs | 2 +- src/gates/gmimc.rs | 6 +- src/gates/interpolation.rs | 4 +- src/gates/mul_extension.rs | 118 +++------------------------- src/target.rs | 5 ++ src/wire.rs | 5 ++ 13 files changed, 95 insertions(+), 145 deletions(-) diff --git a/src/field/extension_field/mod.rs b/src/field/extension_field/mod.rs index 60d2b2e1..d706a341 100644 --- a/src/field/extension_field/mod.rs +++ b/src/field/extension_field/mod.rs @@ -1,4 +1,5 @@ use crate::field::field::Field; +use std::convert::TryInto; pub mod algebra; pub mod quadratic; @@ -88,10 +89,6 @@ where { debug_assert_eq!(l.len() % D, 0); l.chunks_exact(D) - .map(|c| { - let mut arr = [F::ZERO; D]; - arr.copy_from_slice(c); - F::Extension::from_basefield_array(arr) - }) + .map(|c| F::Extension::from_basefield_array(c.to_vec().try_into().unwrap())) .collect() } diff --git a/src/field/extension_field/target.rs b/src/field/extension_field/target.rs index 901cf854..17d28172 100644 --- a/src/field/extension_field/target.rs +++ b/src/field/extension_field/target.rs @@ -2,7 +2,10 @@ use crate::circuit_builder::CircuitBuilder; use crate::field::extension_field::algebra::ExtensionAlgebra; use crate::field::extension_field::{Extendable, FieldExtension, OEF}; use crate::field::field::Field; +use crate::gates::mul_extension::MulExtensionGate; use crate::target::Target; +use std::convert::{TryFrom, TryInto}; +use std::ops::Range; /// `Target`s representing an element of an extension field. #[derive(Copy, Clone, Debug)] @@ -26,6 +29,19 @@ impl ExtensionTarget { Self(res) } + + pub fn from_range(gate: usize, range: Range) -> Self { + debug_assert_eq!(range.end - range.start, D); + Target::wires_from_range(gate, range).try_into().unwrap() + } +} + +impl TryFrom> for ExtensionTarget { + type Error = Vec; + + fn try_from(value: Vec) -> Result { + Ok(Self(value.try_into()?)) + } } /// `Target`s representing an element of an extension of an extension field. @@ -128,7 +144,34 @@ impl, const D: usize> CircuitBuilder { a } + pub fn mul_extension_with_const( + &mut self, + const_0: F, + multiplicand_0: ExtensionTarget, + multiplicand_1: ExtensionTarget, + ) -> ExtensionTarget { + let gate = self.add_gate(MulExtensionGate::new(), vec![const_0]); + + let wire_multiplicand_0 = + ExtensionTarget::from_range(gate, MulExtensionGate::::wires_multiplicand_0()); + let wire_multiplicand_1 = + ExtensionTarget::from_range(gate, MulExtensionGate::::wires_multiplicand_1()); + let wire_output = ExtensionTarget::from_range(gate, MulExtensionGate::::wires_output()); + + self.route_extension(multiplicand_0, wire_multiplicand_0); + self.route_extension(multiplicand_1, wire_multiplicand_1); + wire_output + } + pub fn mul_extension( + &mut self, + multiplicand_0: ExtensionTarget, + multiplicand_1: ExtensionTarget, + ) -> ExtensionTarget { + self.mul_extension_with_const(F::ONE, multiplicand_0, multiplicand_1) + } + + pub fn mul_extension_naive( &mut self, a: ExtensionTarget, b: ExtensionTarget, @@ -156,7 +199,7 @@ impl, const D: usize> CircuitBuilder { let w = self.constant(F::Extension::W); for i in 0..D { for j in 0..D { - let ai_bi = self.mul_extension(a.0[i], b.0[j]); + let ai_bi = self.mul_extension_naive(a.0[i], b.0[j]); res[(i + j) % D] = if i + j < D { self.add_extension(ai_bi, res[(i + j) % D]) } else { @@ -171,7 +214,7 @@ impl, const D: usize> CircuitBuilder { pub fn mul_many_extension(&mut self, terms: &[ExtensionTarget]) -> ExtensionTarget { let mut product = self.one_extension(); for term in terms { - product = self.mul_extension(product, *term); + product = self.mul_extension_naive(product, *term); } product } @@ -184,7 +227,7 @@ impl, const D: usize> CircuitBuilder { b: ExtensionTarget, c: ExtensionTarget, ) -> ExtensionTarget { - let product = self.mul_extension(a, b); + let product = self.mul_extension_naive(a, b); self.add_extension(product, c) } @@ -204,7 +247,7 @@ impl, const D: usize> CircuitBuilder { mut b: ExtensionAlgebraTarget, ) -> ExtensionAlgebraTarget { for i in 0..D { - b.0[i] = self.mul_extension(a, b.0[i]); + b.0[i] = self.mul_extension_naive(a, b.0[i]); } b } @@ -225,13 +268,9 @@ pub fn flatten_target(l: &[ExtensionTarget]) -> Vec { } /// Batch every D-sized chunks into extension targets. -pub fn unflatten_target(l: &[Target]) -> Vec> { +pub fn unflatten_target, const D: usize>(l: &[Target]) -> Vec> { debug_assert_eq!(l.len() % D, 0); l.chunks_exact(D) - .map(|c| { - let mut arr = Default::default(); - arr.copy_from_slice(c); - ExtensionTarget(arr) - }) + .map(|c| c.to_vec().try_into().unwrap()) .collect() } diff --git a/src/fri/recursive_verifier.rs b/src/fri/recursive_verifier.rs index 38948bf7..2ae287f4 100644 --- a/src/fri/recursive_verifier.rs +++ b/src/fri/recursive_verifier.rs @@ -187,7 +187,7 @@ impl, const D: usize> CircuitBuilder { } let g = self.constant_extension(F::Extension::primitive_root_of_unity(degree_log)); - let zeta_right = self.mul_extension(g, zeta); + let zeta_right = self.mul_extension_naive(g, zeta); let mut ev_zeta = self.zero_extension(); for &t in &os.plonk_zs { let a = alpha_powers.next(self); @@ -203,7 +203,7 @@ impl, const D: usize> CircuitBuilder { let numerator = self.sub_extension(ev, interpol_val); let vanish = self.sub_extension(subgroup_x, zeta); let vanish_right = self.sub_extension(subgroup_x, zeta_right); - let denominator = self.mul_extension(vanish, vanish_right); + let denominator = self.mul_extension_naive(vanish, vanish_right); let quotient = self.div_unsafe_extension(numerator, denominator); let sum = self.add_extension(sum, quotient); @@ -237,7 +237,7 @@ impl, const D: usize> CircuitBuilder { let interpol_val = wires_interpol.eval(self, subgroup_x); let numerator = self.sub_extension(ev, interpol_val); let vanish_frob = self.sub_extension(subgroup_x, zeta_frob); - let denominator = self.mul_extension(vanish, vanish_frob); + let denominator = self.mul_extension_naive(vanish, vanish_frob); let quotient = self.div_unsafe_extension(numerator, denominator); let sum = self.add_extension(sum, quotient); diff --git a/src/gadgets/arithmetic.rs b/src/gadgets/arithmetic.rs index 4a2e4bd3..98c4d0d0 100644 --- a/src/gadgets/arithmetic.rs +++ b/src/gadgets/arithmetic.rs @@ -238,17 +238,13 @@ impl, const D: usize> CircuitBuilder { // Add an `ArithmeticGate` to compute `q * y`. let gate = self.add_gate(MulExtensionGate::new(), vec![F::ONE]); - let multiplicand_0 = MulExtensionGate::::wires_multiplicand_0() - .map(|i| Target::Wire(Wire { gate, input: i })) - .collect::>(); + let multiplicand_0 = + Target::wires_from_range(gate, MulExtensionGate::::wires_multiplicand_0()); let multiplicand_0 = ExtensionTarget(multiplicand_0.try_into().unwrap()); - let multiplicand_1 = MulExtensionGate::::wires_multiplicand_1() - .map(|i| Target::Wire(Wire { gate, input: i })) - .collect::>(); + let multiplicand_1 = + Target::wires_from_range(gate, MulExtensionGate::::wires_multiplicand_1()); let multiplicand_1 = ExtensionTarget(multiplicand_1.try_into().unwrap()); - let output = MulExtensionGate::::wires_output() - .map(|i| Target::Wire(Wire { gate, input: i })) - .collect::>(); + let output = Target::wires_from_range(gate, MulExtensionGate::::wires_output()); let output = ExtensionTarget(output.try_into().unwrap()); self.add_generator(QuotientGeneratorExtension { @@ -324,7 +320,7 @@ impl PowersTarget { builder: &mut CircuitBuilder, ) -> ExtensionTarget { let result = self.current; - self.current = builder.mul_extension(self.base, self.current); + self.current = builder.mul_extension_naive(self.base, self.current); result } } diff --git a/src/gadgets/polynomial.rs b/src/gadgets/polynomial.rs index 543be834..9ccfc6a8 100644 --- a/src/gadgets/polynomial.rs +++ b/src/gadgets/polynomial.rs @@ -26,7 +26,7 @@ impl PolynomialCoeffsExtTarget { ) -> ExtensionTarget { let mut acc = builder.zero_extension(); for &c in self.0.iter().rev() { - let tmp = builder.mul_extension(point, acc); + let tmp = builder.mul_extension_naive(point, acc); acc = builder.add_extension(tmp, c); } acc diff --git a/src/gadgets/split_base.rs b/src/gadgets/split_base.rs index 35d5ac93..810939fb 100644 --- a/src/gadgets/split_base.rs +++ b/src/gadgets/split_base.rs @@ -19,9 +19,11 @@ impl, const D: usize> CircuitBuilder { input: BaseSumGate::::WIRE_SUM, }); self.route(x, sum); - (BaseSumGate::::WIRE_LIMBS_START..BaseSumGate::::WIRE_LIMBS_START + num_limbs) - .map(|i| Target::Wire(Wire { gate, input: i })) - .collect() + + Target::wires_from_range( + gate, + BaseSumGate::::WIRE_LIMBS_START..BaseSumGate::::WIRE_LIMBS_START + num_limbs, + ) } /// Asserts that `x`'s bit representation has at least `trailing_zeros` trailing zeros. diff --git a/src/gates/arithmetic.rs b/src/gates/arithmetic.rs index 0d0fdd7c..8208f0f8 100644 --- a/src/gates/arithmetic.rs +++ b/src/gates/arithmetic.rs @@ -57,7 +57,7 @@ impl, const D: usize> Gate for ArithmeticGate { let output = vars.local_wires[Self::WIRE_OUTPUT]; let product_term = builder.mul_many_extension(&[const_0, multiplicand_0, multiplicand_1]); - let addend_term = builder.mul_extension(const_1, addend); + let addend_term = builder.mul_extension_naive(const_1, addend); let computed_output = builder.add_many_extension(&[product_term, addend_term]); vec![builder.sub_extension(computed_output, output)] } diff --git a/src/gates/base_sum.rs b/src/gates/base_sum.rs index 8c8064a1..f56519e1 100644 --- a/src/gates/base_sum.rs +++ b/src/gates/base_sum.rs @@ -69,7 +69,7 @@ impl, const D: usize, const B: usize> Gate for BaseSumGat (0..B).for_each(|i| { let it = builder.constant_extension(F::from_canonical_usize(i).into()); let diff = builder.sub_extension(limb, it); - acc = builder.mul_extension(acc, diff); + acc = builder.mul_extension_naive(acc, diff); }); acc }); diff --git a/src/gates/gmimc.rs b/src/gates/gmimc.rs index 19042d57..7c75951d 100644 --- a/src/gates/gmimc.rs +++ b/src/gates/gmimc.rs @@ -131,7 +131,7 @@ impl, const D: usize, const R: usize> Gate for GMiMCGate< let swap = vars.local_wires[Self::WIRE_SWAP]; let one_ext = builder.one_extension(); let not_swap = builder.sub_extension(swap, one_ext); - constraints.push(builder.mul_extension(swap, not_swap)); + constraints.push(builder.mul_extension_naive(swap, not_swap)); let old_index_acc = vars.local_wires[Self::WIRE_INDEX_ACCUMULATOR_OLD]; let new_index_acc = vars.local_wires[Self::WIRE_INDEX_ACCUMULATOR_NEW]; @@ -168,8 +168,8 @@ impl, const D: usize, const R: usize> Gate for GMiMCGate< let constant = builder.constant_extension(self.constants[r].into()); let cubing_input = builder.add_many_extension(&[state[active], addition_buffer, constant]); - let square = builder.mul_extension(cubing_input, cubing_input); - let f = builder.mul_extension(square, cubing_input); + let square = builder.mul_extension_naive(cubing_input, cubing_input); + let f = builder.mul_extension_naive(square, cubing_input); addition_buffer = builder.add_extension(addition_buffer, f); state[active] = builder.sub_extension(state[active], f); } diff --git a/src/gates/interpolation.rs b/src/gates/interpolation.rs index 9bf76c2c..b3fd35d0 100644 --- a/src/gates/interpolation.rs +++ b/src/gates/interpolation.rs @@ -357,9 +357,7 @@ mod tests { }; assert!( - gate.eval_unfiltered(vars.clone()) - .iter() - .all(|x| x.is_zero()), + gate.eval_unfiltered(vars).iter().all(|x| x.is_zero()), "Gate constraints are not satisfied." ); } diff --git a/src/gates/mul_extension.rs b/src/gates/mul_extension.rs index 18ec5827..00a0ad8d 100644 --- a/src/gates/mul_extension.rs +++ b/src/gates/mul_extension.rs @@ -1,7 +1,6 @@ use crate::circuit_builder::CircuitBuilder; use crate::field::extension_field::target::ExtensionTarget; -use crate::field::extension_field::{Extendable, FieldExtension, OEF}; -use crate::field::field::Field; +use crate::field::extension_field::{Extendable, FieldExtension}; use crate::gates::gate::{Gate, GateRef}; use crate::generator::{SimpleGenerator, WitnessGenerator}; use crate::target::Target; @@ -11,77 +10,6 @@ use crate::witness::PartialWitness; use std::convert::TryInto; use std::ops::Range; -// TODO: Replace this when https://github.com/mir-protocol/plonky2/issues/56 is resolved. -fn mul_vec(a: &[F], b: &[F], w: F) -> Vec { - let (a0, a1, a2, a3) = (a[0], a[1], a[2], a[3]); - let (b0, b1, b2, b3) = (b[0], b[1], b[2], b[3]); - - let c0 = a0 * b0 + w * (a1 * b3 + a2 * b2 + a3 * b1); - let c1 = a0 * b1 + a1 * b0 + w * (a2 * b3 + a3 * b2); - let c2 = a0 * b2 + a1 * b1 + a2 * b0 + w * a3 * b3; - let c3 = a0 * b3 + a1 * b2 + a2 * b1 + a3 * b0; - - vec![c0, c1, c2, c3] -} -impl, const D: usize> CircuitBuilder { - fn mul_vec( - &mut self, - a: &[ExtensionTarget], - b: &[ExtensionTarget], - w: ExtensionTarget, - ) -> Vec> { - let (a0, a1, a2, a3) = (a[0], a[1], a[2], a[3]); - let (b0, b1, b2, b3) = (b[0], b[1], b[2], b[3]); - - // TODO: Optimize this. - let c0 = { - let tmp0 = self.mul_extension(a0, b0); - let tmp1 = self.mul_extension(a1, b3); - let tmp2 = self.mul_extension(a2, b2); - let tmp3 = self.mul_extension(a3, b1); - let tmp = self.add_extension(tmp1, tmp2); - let tmp = self.add_extension(tmp, tmp3); - let tmp = self.mul_extension(w, tmp); - let tmp = self.add_extension(tmp0, tmp); - tmp - }; - let c1 = { - let tmp0 = self.mul_extension(a0, b1); - let tmp1 = self.mul_extension(a1, b0); - let tmp2 = self.mul_extension(a2, b3); - let tmp3 = self.mul_extension(a3, b2); - let tmp = self.add_extension(tmp2, tmp3); - let tmp = self.mul_extension(w, tmp); - let tmp = self.add_extension(tmp, tmp0); - let tmp = self.add_extension(tmp, tmp1); - tmp - }; - let c2 = { - let tmp0 = self.mul_extension(a0, b2); - let tmp1 = self.mul_extension(a1, b1); - let tmp2 = self.mul_extension(a2, b0); - let tmp3 = self.mul_extension(a3, b3); - let tmp = self.mul_extension(w, tmp3); - let tmp = self.add_extension(tmp, tmp2); - let tmp = self.add_extension(tmp, tmp1); - let tmp = self.add_extension(tmp, tmp0); - tmp - }; - let c3 = { - let tmp0 = self.mul_extension(a0, b3); - let tmp1 = self.mul_extension(a1, b2); - let tmp2 = self.mul_extension(a2, b1); - let tmp3 = self.mul_extension(a3, b0); - let tmp = self.add_extension(tmp3, tmp2); - let tmp = self.add_extension(tmp, tmp1); - let tmp = self.add_extension(tmp, tmp0); - tmp - }; - - vec![c0, c1, c2, c3] - } -} - /// A gate which can multiply two field extension elements. /// TODO: Add an addend if `NUM_ROUTED_WIRES` is large enough. #[derive(Debug)] @@ -110,25 +38,11 @@ impl, const D: usize> Gate for MulExtensionGate { fn eval_unfiltered(&self, vars: EvaluationVars) -> Vec { let const_0 = vars.local_constants[0]; - let multiplicand_0 = vars.local_wires[Self::wires_multiplicand_0()].to_vec(); - let multiplicand_1 = vars.local_wires[Self::wires_multiplicand_1()].to_vec(); - let output = vars.local_wires[Self::wires_output()].to_vec(); - let computed_output = mul_vec( - &[ - const_0, - F::Extension::ZERO, - F::Extension::ZERO, - F::Extension::ZERO, - ], - &multiplicand_0, - F::Extension::W.into(), - ); - let computed_output = mul_vec(&computed_output, &multiplicand_1, F::Extension::W.into()); - output - .into_iter() - .zip(computed_output) - .map(|(o, co)| o - co) - .collect() + let multiplicand_0 = vars.get_local_ext_algebra(Self::wires_multiplicand_0()); + let multiplicand_1 = vars.get_local_ext_algebra(Self::wires_multiplicand_1()); + let output = vars.get_local_ext_algebra(Self::wires_output()); + let computed_output = multiplicand_0 * multiplicand_1 * const_0.into(); + (output - computed_output).to_basefield_array().to_vec() } fn eval_unfiltered_recursively( @@ -137,18 +51,13 @@ impl, const D: usize> Gate for MulExtensionGate { vars: EvaluationTargets, ) -> Vec> { let const_0 = vars.local_constants[0]; - let multiplicand_0 = vars.local_wires[Self::wires_multiplicand_0()].to_vec(); - let multiplicand_1 = vars.local_wires[Self::wires_multiplicand_1()].to_vec(); - let output = vars.local_wires[Self::wires_output()].to_vec(); - let w = builder.constant_extension(F::Extension::W.into()); - let zero = builder.zero_extension(); - let computed_output = builder.mul_vec(&[const_0, zero, zero, zero], &multiplicand_0, w); - let computed_output = builder.mul_vec(&computed_output, &multiplicand_1, w); - output - .into_iter() - .zip(computed_output) - .map(|(o, co)| builder.sub_extension(o, co)) - .collect() + let multiplicand_0 = vars.get_local_ext_algebra(Self::wires_multiplicand_0()); + let multiplicand_1 = vars.get_local_ext_algebra(Self::wires_multiplicand_1()); + let output = vars.get_local_ext_algebra(Self::wires_output()); + let computed_output = builder.mul_ext_algebra(multiplicand_0, multiplicand_1); + let computed_output = builder.scalar_mul_ext_algebra(const_0, computed_output); + let diff = builder.sub_ext_algebra(output, computed_output); + diff.to_ext_target_array().to_vec() } fn generators( @@ -236,7 +145,6 @@ impl, const D: usize> SimpleGenerator for MulExtensionGenera #[cfg(test)] mod tests { use crate::field::crandall_field::CrandallField; - use crate::gates::arithmetic::ArithmeticGate; use crate::gates::gate_testing::test_low_degree; use crate::gates::mul_extension::MulExtensionGate; diff --git a/src/target.rs b/src/target.rs index b5736564..423865fa 100644 --- a/src/target.rs +++ b/src/target.rs @@ -1,5 +1,6 @@ use crate::circuit_data::CircuitConfig; use crate::wire::Wire; +use std::ops::Range; /// A location in the witness. #[derive(Copy, Clone, Eq, PartialEq, Hash, Debug)] @@ -21,4 +22,8 @@ impl Target { Target::VirtualAdviceTarget { .. } => false, } } + + pub fn wires_from_range(gate: usize, range: Range) -> Vec { + range.map(|i| Self::wire(gate, i)).collect() + } } diff --git a/src/wire.rs b/src/wire.rs index 61b7f5be..02d43029 100644 --- a/src/wire.rs +++ b/src/wire.rs @@ -1,4 +1,5 @@ use crate::circuit_data::CircuitConfig; +use std::ops::Range; /// Represents a wire in the circuit. #[derive(Copy, Clone, Eq, PartialEq, Hash, Debug)] @@ -13,4 +14,8 @@ impl Wire { pub fn is_routable(&self, config: &CircuitConfig) -> bool { self.input < config.num_routed_wires } + + pub fn from_range(gate: usize, range: Range) -> Vec { + range.map(|i| Wire { gate, input: i }).collect() + } }