Special cases for extension field arithmetic (#138)

We previously checked for special cases, like arithmetic on constant Targets, in `arithmetic`. We can handle those cases without actually adding an `ArithmeticGate`.

Now that `arithmetic` just calls `arithmetic_extension`, it makes more sense to check for special cases in the latter method, so it applies to both base and extension field arithmetic.

Reduces gate count from 16149 to 15689.
This commit is contained in:
Daniel Lubarov 2021-07-30 09:03:11 -07:00 committed by GitHub
parent bb316fb146
commit 50b07f2ceb
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 90 additions and 66 deletions

View File

@ -8,7 +8,7 @@ use crate::iop::target::Target;
use crate::plonk::circuit_builder::CircuitBuilder;
/// `Target`s representing an element of an extension field.
#[derive(Copy, Clone, Debug)]
#[derive(Copy, Clone, Eq, PartialEq, Debug)]
pub struct ExtensionTarget<const D: usize>(pub [Target; D]);
impl<const D: usize> ExtensionTarget<D> {

View File

@ -33,12 +33,6 @@ impl<F: Extendable<D>, const D: usize> CircuitBuilder<F, D> {
const_1: F,
addend: Target,
) -> Target {
// See if we can determine the result without adding an `ArithmeticGate`.
if let Some(result) =
self.arithmetic_special_cases(const_0, multiplicand_0, multiplicand_1, const_1, addend)
{
return result;
}
let multiplicand_0_ext = self.convert_to_ext(multiplicand_0);
let multiplicand_1_ext = self.convert_to_ext(multiplicand_1);
let addend_ext = self.convert_to_ext(addend);
@ -53,64 +47,6 @@ impl<F: Extendable<D>, const D: usize> CircuitBuilder<F, D> {
.0[0]
}
/// Checks for special cases where the value of
/// `const_0 * multiplicand_0 * multiplicand_1 + const_1 * addend`
/// can be determined without adding an `ArithmeticGate`.
fn arithmetic_special_cases(
&mut self,
const_0: F,
multiplicand_0: Target,
multiplicand_1: Target,
const_1: F,
addend: Target,
) -> Option<Target> {
let zero = self.zero();
let mul_0_const = self.target_as_constant(multiplicand_0);
let mul_1_const = self.target_as_constant(multiplicand_1);
let addend_const = self.target_as_constant(addend);
let first_term_zero =
const_0 == F::ZERO || multiplicand_0 == zero || multiplicand_1 == zero;
let second_term_zero = const_1 == F::ZERO || addend == zero;
// If both terms are constant, return their (constant) sum.
let first_term_const = if first_term_zero {
Some(F::ZERO)
} else if let (Some(x), Some(y)) = (mul_0_const, mul_1_const) {
Some(const_0 * x * y)
} else {
None
};
let second_term_const = if second_term_zero {
Some(F::ZERO)
} else {
addend_const.map(|x| const_1 * x)
};
if let (Some(x), Some(y)) = (first_term_const, second_term_const) {
return Some(self.constant(x + y));
}
if first_term_zero && const_1.is_one() {
return Some(addend);
}
if second_term_zero {
if let Some(x) = mul_0_const {
if (const_0 * x).is_one() {
return Some(multiplicand_1);
}
}
if let Some(x) = mul_1_const {
if (const_1 * x).is_one() {
return Some(multiplicand_0);
}
}
}
None
}
/// Computes `x * y + z`.
pub fn mul_add(&mut self, x: Target, y: Target, z: Target) -> Target {
self.arithmetic(F::ONE, x, y, F::ONE, z)

View File

@ -5,6 +5,7 @@ use num::Integer;
use crate::field::extension_field::target::{ExtensionAlgebraTarget, ExtensionTarget};
use crate::field::extension_field::{Extendable, OEF};
use crate::field::field_types::Field;
use crate::gates::arithmetic::ArithmeticExtensionGate;
use crate::iop::generator::{GeneratedValues, SimpleGenerator};
use crate::iop::target::Target;
@ -68,6 +69,17 @@ impl<F: Extendable<D>, const D: usize> CircuitBuilder<F, D> {
multiplicand_1: ExtensionTarget<D>,
addend: ExtensionTarget<D>,
) -> ExtensionTarget<D> {
// See if we can determine the result without adding an `ArithmeticGate`.
if let Some(result) = self.arithmetic_extension_special_cases(
const_0,
const_1,
multiplicand_0,
multiplicand_1,
addend,
) {
return result;
}
let zero = self.zero_extension();
self.double_arithmetic_extension(
const_0,
@ -82,6 +94,64 @@ impl<F: Extendable<D>, const D: usize> CircuitBuilder<F, D> {
.0
}
/// Checks for special cases where the value of
/// `const_0 * multiplicand_0 * multiplicand_1 + const_1 * addend`
/// can be determined without adding an `ArithmeticGate`.
fn arithmetic_extension_special_cases(
&mut self,
const_0: F,
const_1: F,
multiplicand_0: ExtensionTarget<D>,
multiplicand_1: ExtensionTarget<D>,
addend: ExtensionTarget<D>,
) -> Option<ExtensionTarget<D>> {
let zero = self.zero_extension();
let mul_0_const = self.target_as_constant_ext(multiplicand_0);
let mul_1_const = self.target_as_constant_ext(multiplicand_1);
let addend_const = self.target_as_constant_ext(addend);
let first_term_zero =
const_0 == F::ZERO || multiplicand_0 == zero || multiplicand_1 == zero;
let second_term_zero = const_1 == F::ZERO || addend == zero;
// If both terms are constant, return their (constant) sum.
let first_term_const = if first_term_zero {
Some(F::Extension::ZERO)
} else if let (Some(x), Some(y)) = (mul_0_const, mul_1_const) {
Some(x * y * const_0.into())
} else {
None
};
let second_term_const = if second_term_zero {
Some(F::Extension::ZERO)
} else {
addend_const.map(|x| x * const_1.into())
};
if let (Some(x), Some(y)) = (first_term_const, second_term_const) {
return Some(self.constant_extension(x + y));
}
if first_term_zero && const_1.is_one() {
return Some(addend);
}
if second_term_zero {
if let Some(x) = mul_0_const {
if (x * const_0.into()).is_one() {
return Some(multiplicand_1);
}
}
if let Some(x) = mul_1_const {
if (x * const_1.into()).is_one() {
return Some(multiplicand_0);
}
}
}
None
}
pub fn add_extension(
&mut self,
a: ExtensionTarget<D>,

View File

@ -6,7 +6,7 @@ use log::info;
use crate::field::cosets::get_unique_coset_shifts;
use crate::field::extension_field::target::ExtensionTarget;
use crate::field::extension_field::Extendable;
use crate::field::extension_field::{Extendable, FieldExtension};
use crate::fri::commitment::PolynomialBatchCommitment;
use crate::gates::constant::ConstantGate;
use crate::gates::gate::{Gate, GateInstance, GateRef, PrefixedGate};
@ -301,6 +301,24 @@ impl<F: Extendable<D>, const D: usize> CircuitBuilder<F, D> {
self.targets_to_constants.get(&target).cloned()
}
/// If the given `ExtensionTarget` is a constant (i.e. it was created by the
/// `constant_extension(F)` method), returns its constant value. Otherwise, returns `None`.
pub fn target_as_constant_ext(&self, target: ExtensionTarget<D>) -> Option<F::Extension> {
// Get a Vec of any coefficients that are constant. If we end up with exactly D of them,
// then the `ExtensionTarget` as a whole is constant.
let const_coeffs: Vec<F> = target
.0
.into_iter()
.filter_map(|&t| self.target_as_constant(t))
.collect();
if let Ok(d_const_coeffs) = const_coeffs.try_into() {
Some(F::Extension::from_basefield_array(d_const_coeffs))
} else {
None
}
}
pub fn push_context(&mut self, level: log::Level, ctx: &str) {
self.context_log.push(ctx, level, self.num_gates());
}