Small recursion optimizations (#338)

* Small recursion optimizations

Main thing is memoizing arithmetic operations. Overall savings is ~50 gates.

* feedback
This commit is contained in:
Daniel Lubarov 2021-11-04 16:23:01 -07:00 committed by GitHub
parent fdce382af3
commit 1450ffb29c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 72 additions and 36 deletions

View File

@ -8,7 +8,7 @@ use crate::iop::target::Target;
use crate::plonk::circuit_builder::CircuitBuilder;
/// `Target`s representing an element of an extension field.
#[derive(Copy, Clone, Eq, PartialEq, Debug)]
#[derive(Copy, Clone, Eq, PartialEq, Hash, Debug)]
pub struct ExtensionTarget<const D: usize>(pub [Target; D]);
impl<const D: usize> ExtensionTarget<D> {

View File

@ -3,7 +3,7 @@ use std::convert::TryInto;
use crate::field::extension_field::target::{ExtensionAlgebraTarget, ExtensionTarget};
use crate::field::extension_field::FieldExtension;
use crate::field::extension_field::{Extendable, OEF};
use crate::field::field_types::{Field, RichField};
use crate::field::field_types::{Field, PrimeField, RichField};
use crate::gates::arithmetic::ArithmeticExtensionGate;
use crate::iop::generator::{GeneratedValues, SimpleGenerator};
use crate::iop::target::Target;
@ -58,7 +58,29 @@ impl<F: RichField + Extendable<D>, const D: usize> CircuitBuilder<F, D> {
return result;
}
let (gate, i) = self.find_arithmetic_gate(const_0, const_1);
// See if we've already computed the same operation.
let operation = ArithmeticOperation {
const_0,
const_1,
multiplicand_0,
multiplicand_1,
addend,
};
if let Some(&result) = self.arithmetic_results.get(&operation) {
return result;
}
// Otherwise, we must actually perform the operation using an ArithmeticExtensionGate slot.
let result = self.add_arithmetic_extension_operation(operation);
self.arithmetic_results.insert(operation, result);
result
}
fn add_arithmetic_extension_operation(
&mut self,
operation: ArithmeticOperation<F, D>,
) -> ExtensionTarget<D> {
let (gate, i) = self.find_arithmetic_gate(operation.const_0, operation.const_1);
let wires_multiplicand_0 = ExtensionTarget::from_range(
gate,
ArithmeticExtensionGate::<D>::wires_ith_multiplicand_0(i),
@ -70,9 +92,9 @@ impl<F: RichField + Extendable<D>, const D: usize> CircuitBuilder<F, D> {
let wires_addend =
ExtensionTarget::from_range(gate, ArithmeticExtensionGate::<D>::wires_ith_addend(i));
self.connect_extension(multiplicand_0, wires_multiplicand_0);
self.connect_extension(multiplicand_1, wires_multiplicand_1);
self.connect_extension(addend, wires_addend);
self.connect_extension(operation.multiplicand_0, wires_multiplicand_0);
self.connect_extension(operation.multiplicand_1, wires_multiplicand_1);
self.connect_extension(operation.addend, wires_addend);
ExtensionTarget::from_range(gate, ArithmeticExtensionGate::<D>::wires_ith_output(i))
}
@ -447,7 +469,7 @@ impl<F: RichField + Extendable<D>, const D: usize> CircuitBuilder<F, D> {
quotient: inv,
});
// Enforce that x times its purported inverse equals 1.
// Enforce that y times its purported inverse equals 1.
let y_inv = self.mul_extension(y, inv);
self.connect_extension(y_inv, one);
@ -524,6 +546,16 @@ impl<F: RichField + Extendable<D>, const D: usize> CircuitBuilder<F, D> {
}
}
/// Represents an arithmetic operation in the circuit. Used to memoize results.
#[derive(Copy, Clone, Eq, PartialEq, Hash)]
pub(crate) struct ArithmeticOperation<F: PrimeField + Extendable<D>, const D: usize> {
const_0: F,
const_1: F,
multiplicand_0: ExtensionTarget<D>,
multiplicand_1: ExtensionTarget<D>,
addend: ExtensionTarget<D>,
}
#[cfg(test)]
mod tests {
use std::convert::TryInto;

View File

@ -106,8 +106,7 @@ impl<F: RichField + Extendable<D>, const D: usize> Gate<F, D> for ArithmeticExte
let computed_output = {
let mul = builder.mul_ext_algebra(multiplicand_0, multiplicand_1);
let scaled_mul = builder.scalar_mul_ext_algebra(const_0, mul);
let scaled_addend = builder.scalar_mul_ext_algebra(const_1, addend);
builder.add_ext_algebra(scaled_mul, scaled_addend)
builder.scalar_mul_add_ext_algebra(const_1, addend, scaled_mul)
};
let diff = builder.sub_ext_algebra(output, computed_output);

View File

@ -258,7 +258,7 @@ where
vars: EvaluationTargets<D>,
) -> Vec<ExtensionTarget<D>> {
// The naive method is more efficient if we have enough routed wires for PoseidonMdsGate.
let naive =
let use_mds_gate =
builder.config.num_routed_wires >= PoseidonMdsGate::<F, D, WIDTH>::new().num_wires();
let mut constraints = Vec::with_capacity(self.num_constraints());
@ -306,7 +306,7 @@ where
}
// Partial rounds.
if naive {
if use_mds_gate {
for r in 0..poseidon::N_PARTIAL_ROUNDS {
<F as Poseidon<WIDTH>>::constant_layer_recursive(builder, &mut state, round_ctr);
let sbox_in = vars.local_wires[Self::wire_partial_sbox(r)];

View File

@ -12,6 +12,7 @@ use crate::field::fft::fft_root_table;
use crate::field::field_types::{Field, RichField};
use crate::fri::commitment::PolynomialBatchCommitment;
use crate::fri::{FriConfig, FriParams};
use crate::gadgets::arithmetic_extension::ArithmeticOperation;
use crate::gates::arithmetic::ArithmeticExtensionGate;
use crate::gates::constant::ConstantGate;
use crate::gates::gate::{Gate, GateInstance, GateRef, PrefixedGate};
@ -70,6 +71,9 @@ pub struct CircuitBuilder<F: RichField + Extendable<D>, const D: usize> {
constants_to_targets: HashMap<F, Target>,
targets_to_constants: HashMap<Target, F>,
/// Memoized results of `arithmetic_extension` calls.
pub(crate) arithmetic_results: HashMap<ArithmeticOperation<F, D>, ExtensionTarget<D>>,
/// A map `(c0, c1) -> (g, i)` from constants `(c0,c1)` to an available arithmetic gate using
/// these constants with gate index `g` and already using `i` arithmetic operations.
pub(crate) free_arithmetic: HashMap<(F, F), (usize, usize)>,
@ -100,6 +104,7 @@ impl<F: RichField + Extendable<D>, const D: usize> CircuitBuilder<F, D> {
marked_targets: Vec::new(),
generators: Vec::new(),
constants_to_targets: HashMap::new(),
arithmetic_results: HashMap::new(),
targets_to_constants: HashMap::new(),
free_arithmetic: HashMap::new(),
free_random_access: HashMap::new(),

View File

@ -1,6 +1,6 @@
use crate::field::extension_field::target::ExtensionTarget;
use crate::field::extension_field::{Extendable, FieldExtension};
use crate::field::field_types::{Field, PrimeField, RichField};
use crate::field::field_types::{Field, RichField};
use crate::gates::gate::PrefixedGate;
use crate::iop::target::Target;
use crate::plonk::circuit_builder::CircuitBuilder;
@ -328,7 +328,6 @@ pub(crate) fn eval_vanishing_poly_recursively<F: RichField + Extendable<D>, cons
gammas: &[Target],
alphas: &[Target],
) -> Vec<ExtensionTarget<D>> {
let one = builder.one_extension();
let max_degree = common_data.quotient_degree_factor;
let (num_prods, final_num_prod) = common_data.num_partial_products;
@ -356,36 +355,37 @@ pub(crate) fn eval_vanishing_poly_recursively<F: RichField + Extendable<D>, cons
let mut s_ids = Vec::new();
for j in 0..common_data.config.num_routed_wires {
let k = builder.constant(common_data.k_is[j]);
let k_ext = builder.convert_to_ext(k);
s_ids.push(builder.mul_extension(k_ext, x));
s_ids.push(builder.scalar_mul_ext(k, x));
}
for i in 0..common_data.config.num_challenges {
let z_x = local_zs[i];
let z_gz = next_zs[i];
// L_1(x) Z(x) = 0.
vanishing_z_1_terms.push(builder.mul_sub_extension(l1_x, z_x, l1_x));
let numerator_values = (0..common_data.config.num_routed_wires)
.map(|j| {
let wire_value = vars.local_wires[j];
let beta_ext = builder.convert_to_ext(betas[i]);
let gamma_ext = builder.convert_to_ext(gammas[i]);
// `beta * s_id + wire_value + gamma`
builder.wide_arithmetic_extension(beta_ext, s_ids[j], one, wire_value, gamma_ext)
})
.collect::<Vec<_>>();
let denominator_values = (0..common_data.config.num_routed_wires)
.map(|j| {
let wire_value = vars.local_wires[j];
let beta_ext = builder.convert_to_ext(betas[i]);
let gamma_ext = builder.convert_to_ext(gammas[i]);
// `beta * s_sigma + wire_value + gamma`
builder.wide_arithmetic_extension(beta_ext, s_sigmas[j], one, wire_value, gamma_ext)
})
.collect::<Vec<_>>();
let quotient_values = (0..common_data.config.num_routed_wires)
.map(|j| builder.div_extension(numerator_values[j], denominator_values[j]))
.collect::<Vec<_>>();
let mut numerator_values = Vec::new();
let mut denominator_values = Vec::new();
let mut quotient_values = Vec::new();
for j in 0..common_data.config.num_routed_wires {
let wire_value = vars.local_wires[j];
let beta_ext = builder.convert_to_ext(betas[i]);
let gamma_ext = builder.convert_to_ext(gammas[i]);
// The numerator is `beta * s_id + wire_value + gamma`, and the denominator is
// `beta * s_sigma + wire_value + gamma`.
let wire_value_plus_gamma = builder.add_extension(wire_value, gamma_ext);
let numerator = builder.mul_add_extension(beta_ext, s_ids[j], wire_value_plus_gamma);
let denominator =
builder.mul_add_extension(beta_ext, s_sigmas[j], wire_value_plus_gamma);
let quotient = builder.div_extension(numerator, denominator);
numerator_values.push(numerator);
denominator_values.push(denominator);
quotient_values.push(quotient);
}
// The partial products considered for this iteration of `i`.
let current_partial_products = &partial_products[i * num_prods..(i + 1) * num_prods];