From 50b07f2ceb91b8390d7e2f5759447ff2f41329c9 Mon Sep 17 00:00:00 2001 From: Daniel Lubarov Date: Fri, 30 Jul 2021 09:03:11 -0700 Subject: [PATCH] Special cases for extension field arithmetic (#138) We previously checked for special cases, like arithmetic on constant Targets, in `arithmetic`. We can handle those cases without actually adding an `ArithmeticGate`. Now that `arithmetic` just calls `arithmetic_extension`, it makes more sense to check for special cases in the latter method, so it applies to both base and extension field arithmetic. Reduces gate count from 16149 to 15689. --- src/field/extension_field/target.rs | 2 +- src/gadgets/arithmetic.rs | 64 -------------------------- src/gadgets/arithmetic_extension.rs | 70 +++++++++++++++++++++++++++++ src/plonk/circuit_builder.rs | 20 ++++++++- 4 files changed, 90 insertions(+), 66 deletions(-) diff --git a/src/field/extension_field/target.rs b/src/field/extension_field/target.rs index 9c2e9648..c7f9dc39 100644 --- a/src/field/extension_field/target.rs +++ b/src/field/extension_field/target.rs @@ -8,7 +8,7 @@ use crate::iop::target::Target; use crate::plonk::circuit_builder::CircuitBuilder; /// `Target`s representing an element of an extension field. -#[derive(Copy, Clone, Debug)] +#[derive(Copy, Clone, Eq, PartialEq, Debug)] pub struct ExtensionTarget(pub [Target; D]); impl ExtensionTarget { diff --git a/src/gadgets/arithmetic.rs b/src/gadgets/arithmetic.rs index 4163e00b..4d9ece34 100644 --- a/src/gadgets/arithmetic.rs +++ b/src/gadgets/arithmetic.rs @@ -33,12 +33,6 @@ impl, const D: usize> CircuitBuilder { const_1: F, addend: Target, ) -> Target { - // See if we can determine the result without adding an `ArithmeticGate`. - if let Some(result) = - self.arithmetic_special_cases(const_0, multiplicand_0, multiplicand_1, const_1, addend) - { - return result; - } let multiplicand_0_ext = self.convert_to_ext(multiplicand_0); let multiplicand_1_ext = self.convert_to_ext(multiplicand_1); let addend_ext = self.convert_to_ext(addend); @@ -53,64 +47,6 @@ impl, const D: usize> CircuitBuilder { .0[0] } - /// Checks for special cases where the value of - /// `const_0 * multiplicand_0 * multiplicand_1 + const_1 * addend` - /// can be determined without adding an `ArithmeticGate`. - fn arithmetic_special_cases( - &mut self, - const_0: F, - multiplicand_0: Target, - multiplicand_1: Target, - const_1: F, - addend: Target, - ) -> Option { - let zero = self.zero(); - - let mul_0_const = self.target_as_constant(multiplicand_0); - let mul_1_const = self.target_as_constant(multiplicand_1); - let addend_const = self.target_as_constant(addend); - - let first_term_zero = - const_0 == F::ZERO || multiplicand_0 == zero || multiplicand_1 == zero; - let second_term_zero = const_1 == F::ZERO || addend == zero; - - // If both terms are constant, return their (constant) sum. - let first_term_const = if first_term_zero { - Some(F::ZERO) - } else if let (Some(x), Some(y)) = (mul_0_const, mul_1_const) { - Some(const_0 * x * y) - } else { - None - }; - let second_term_const = if second_term_zero { - Some(F::ZERO) - } else { - addend_const.map(|x| const_1 * x) - }; - if let (Some(x), Some(y)) = (first_term_const, second_term_const) { - return Some(self.constant(x + y)); - } - - if first_term_zero && const_1.is_one() { - return Some(addend); - } - - if second_term_zero { - if let Some(x) = mul_0_const { - if (const_0 * x).is_one() { - return Some(multiplicand_1); - } - } - if let Some(x) = mul_1_const { - if (const_1 * x).is_one() { - return Some(multiplicand_0); - } - } - } - - None - } - /// Computes `x * y + z`. pub fn mul_add(&mut self, x: Target, y: Target, z: Target) -> Target { self.arithmetic(F::ONE, x, y, F::ONE, z) diff --git a/src/gadgets/arithmetic_extension.rs b/src/gadgets/arithmetic_extension.rs index 59ad71d5..e274b781 100644 --- a/src/gadgets/arithmetic_extension.rs +++ b/src/gadgets/arithmetic_extension.rs @@ -5,6 +5,7 @@ use num::Integer; use crate::field::extension_field::target::{ExtensionAlgebraTarget, ExtensionTarget}; use crate::field::extension_field::{Extendable, OEF}; +use crate::field::field_types::Field; use crate::gates::arithmetic::ArithmeticExtensionGate; use crate::iop::generator::{GeneratedValues, SimpleGenerator}; use crate::iop::target::Target; @@ -68,6 +69,17 @@ impl, const D: usize> CircuitBuilder { multiplicand_1: ExtensionTarget, addend: ExtensionTarget, ) -> ExtensionTarget { + // See if we can determine the result without adding an `ArithmeticGate`. + if let Some(result) = self.arithmetic_extension_special_cases( + const_0, + const_1, + multiplicand_0, + multiplicand_1, + addend, + ) { + return result; + } + let zero = self.zero_extension(); self.double_arithmetic_extension( const_0, @@ -82,6 +94,64 @@ impl, const D: usize> CircuitBuilder { .0 } + /// Checks for special cases where the value of + /// `const_0 * multiplicand_0 * multiplicand_1 + const_1 * addend` + /// can be determined without adding an `ArithmeticGate`. + fn arithmetic_extension_special_cases( + &mut self, + const_0: F, + const_1: F, + multiplicand_0: ExtensionTarget, + multiplicand_1: ExtensionTarget, + addend: ExtensionTarget, + ) -> Option> { + let zero = self.zero_extension(); + + let mul_0_const = self.target_as_constant_ext(multiplicand_0); + let mul_1_const = self.target_as_constant_ext(multiplicand_1); + let addend_const = self.target_as_constant_ext(addend); + + let first_term_zero = + const_0 == F::ZERO || multiplicand_0 == zero || multiplicand_1 == zero; + let second_term_zero = const_1 == F::ZERO || addend == zero; + + // If both terms are constant, return their (constant) sum. + let first_term_const = if first_term_zero { + Some(F::Extension::ZERO) + } else if let (Some(x), Some(y)) = (mul_0_const, mul_1_const) { + Some(x * y * const_0.into()) + } else { + None + }; + let second_term_const = if second_term_zero { + Some(F::Extension::ZERO) + } else { + addend_const.map(|x| x * const_1.into()) + }; + if let (Some(x), Some(y)) = (first_term_const, second_term_const) { + return Some(self.constant_extension(x + y)); + } + + if first_term_zero && const_1.is_one() { + return Some(addend); + } + + if second_term_zero { + if let Some(x) = mul_0_const { + if (x * const_0.into()).is_one() { + return Some(multiplicand_1); + } + } + if let Some(x) = mul_1_const { + if (x * const_1.into()).is_one() { + return Some(multiplicand_0); + } + } + } + + None + } + pub fn add_extension( &mut self, a: ExtensionTarget, diff --git a/src/plonk/circuit_builder.rs b/src/plonk/circuit_builder.rs index 3a905cf5..8a0eadc0 100644 --- a/src/plonk/circuit_builder.rs +++ b/src/plonk/circuit_builder.rs @@ -6,7 +6,7 @@ use log::info; use crate::field::cosets::get_unique_coset_shifts; use crate::field::extension_field::target::ExtensionTarget; -use crate::field::extension_field::Extendable; +use crate::field::extension_field::{Extendable, FieldExtension}; use crate::fri::commitment::PolynomialBatchCommitment; use crate::gates::constant::ConstantGate; use crate::gates::gate::{Gate, GateInstance, GateRef, PrefixedGate}; @@ -301,6 +301,24 @@ impl, const D: usize> CircuitBuilder { self.targets_to_constants.get(&target).cloned() } + /// If the given `ExtensionTarget` is a constant (i.e. it was created by the + /// `constant_extension(F)` method), returns its constant value. Otherwise, returns `None`. + pub fn target_as_constant_ext(&self, target: ExtensionTarget) -> Option { + // Get a Vec of any coefficients that are constant. If we end up with exactly D of them, + // then the `ExtensionTarget` as a whole is constant. + let const_coeffs: Vec = target + .0 + .into_iter() + .filter_map(|&t| self.target_as_constant(t)) + .collect(); + + if let Ok(d_const_coeffs) = const_coeffs.try_into() { + Some(F::Extension::from_basefield_array(d_const_coeffs)) + } else { + None + } + } + pub fn push_context(&mut self, level: log::Level, ctx: &str) { self.context_log.push(ctx, level, self.num_gates()); }