mirror of
https://github.com/logos-storage/plonky2.git
synced 2026-01-10 01:33:07 +00:00
Small recursion optimizations (#338)
* Small recursion optimizations Main thing is memoizing arithmetic operations. Overall savings is ~50 gates. * feedback
This commit is contained in:
parent
fdce382af3
commit
1450ffb29c
@ -8,7 +8,7 @@ use crate::iop::target::Target;
|
||||
use crate::plonk::circuit_builder::CircuitBuilder;
|
||||
|
||||
/// `Target`s representing an element of an extension field.
|
||||
#[derive(Copy, Clone, Eq, PartialEq, Debug)]
|
||||
#[derive(Copy, Clone, Eq, PartialEq, Hash, Debug)]
|
||||
pub struct ExtensionTarget<const D: usize>(pub [Target; D]);
|
||||
|
||||
impl<const D: usize> ExtensionTarget<D> {
|
||||
|
||||
@ -3,7 +3,7 @@ use std::convert::TryInto;
|
||||
use crate::field::extension_field::target::{ExtensionAlgebraTarget, ExtensionTarget};
|
||||
use crate::field::extension_field::FieldExtension;
|
||||
use crate::field::extension_field::{Extendable, OEF};
|
||||
use crate::field::field_types::{Field, RichField};
|
||||
use crate::field::field_types::{Field, PrimeField, RichField};
|
||||
use crate::gates::arithmetic::ArithmeticExtensionGate;
|
||||
use crate::iop::generator::{GeneratedValues, SimpleGenerator};
|
||||
use crate::iop::target::Target;
|
||||
@ -58,7 +58,29 @@ impl<F: RichField + Extendable<D>, const D: usize> CircuitBuilder<F, D> {
|
||||
return result;
|
||||
}
|
||||
|
||||
let (gate, i) = self.find_arithmetic_gate(const_0, const_1);
|
||||
// See if we've already computed the same operation.
|
||||
let operation = ArithmeticOperation {
|
||||
const_0,
|
||||
const_1,
|
||||
multiplicand_0,
|
||||
multiplicand_1,
|
||||
addend,
|
||||
};
|
||||
if let Some(&result) = self.arithmetic_results.get(&operation) {
|
||||
return result;
|
||||
}
|
||||
|
||||
// Otherwise, we must actually perform the operation using an ArithmeticExtensionGate slot.
|
||||
let result = self.add_arithmetic_extension_operation(operation);
|
||||
self.arithmetic_results.insert(operation, result);
|
||||
result
|
||||
}
|
||||
|
||||
fn add_arithmetic_extension_operation(
|
||||
&mut self,
|
||||
operation: ArithmeticOperation<F, D>,
|
||||
) -> ExtensionTarget<D> {
|
||||
let (gate, i) = self.find_arithmetic_gate(operation.const_0, operation.const_1);
|
||||
let wires_multiplicand_0 = ExtensionTarget::from_range(
|
||||
gate,
|
||||
ArithmeticExtensionGate::<D>::wires_ith_multiplicand_0(i),
|
||||
@ -70,9 +92,9 @@ impl<F: RichField + Extendable<D>, const D: usize> CircuitBuilder<F, D> {
|
||||
let wires_addend =
|
||||
ExtensionTarget::from_range(gate, ArithmeticExtensionGate::<D>::wires_ith_addend(i));
|
||||
|
||||
self.connect_extension(multiplicand_0, wires_multiplicand_0);
|
||||
self.connect_extension(multiplicand_1, wires_multiplicand_1);
|
||||
self.connect_extension(addend, wires_addend);
|
||||
self.connect_extension(operation.multiplicand_0, wires_multiplicand_0);
|
||||
self.connect_extension(operation.multiplicand_1, wires_multiplicand_1);
|
||||
self.connect_extension(operation.addend, wires_addend);
|
||||
|
||||
ExtensionTarget::from_range(gate, ArithmeticExtensionGate::<D>::wires_ith_output(i))
|
||||
}
|
||||
@ -447,7 +469,7 @@ impl<F: RichField + Extendable<D>, const D: usize> CircuitBuilder<F, D> {
|
||||
quotient: inv,
|
||||
});
|
||||
|
||||
// Enforce that x times its purported inverse equals 1.
|
||||
// Enforce that y times its purported inverse equals 1.
|
||||
let y_inv = self.mul_extension(y, inv);
|
||||
self.connect_extension(y_inv, one);
|
||||
|
||||
@ -524,6 +546,16 @@ impl<F: RichField + Extendable<D>, const D: usize> CircuitBuilder<F, D> {
|
||||
}
|
||||
}
|
||||
|
||||
/// Represents an arithmetic operation in the circuit. Used to memoize results.
|
||||
#[derive(Copy, Clone, Eq, PartialEq, Hash)]
|
||||
pub(crate) struct ArithmeticOperation<F: PrimeField + Extendable<D>, const D: usize> {
|
||||
const_0: F,
|
||||
const_1: F,
|
||||
multiplicand_0: ExtensionTarget<D>,
|
||||
multiplicand_1: ExtensionTarget<D>,
|
||||
addend: ExtensionTarget<D>,
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use std::convert::TryInto;
|
||||
|
||||
@ -106,8 +106,7 @@ impl<F: RichField + Extendable<D>, const D: usize> Gate<F, D> for ArithmeticExte
|
||||
let computed_output = {
|
||||
let mul = builder.mul_ext_algebra(multiplicand_0, multiplicand_1);
|
||||
let scaled_mul = builder.scalar_mul_ext_algebra(const_0, mul);
|
||||
let scaled_addend = builder.scalar_mul_ext_algebra(const_1, addend);
|
||||
builder.add_ext_algebra(scaled_mul, scaled_addend)
|
||||
builder.scalar_mul_add_ext_algebra(const_1, addend, scaled_mul)
|
||||
};
|
||||
|
||||
let diff = builder.sub_ext_algebra(output, computed_output);
|
||||
|
||||
@ -258,7 +258,7 @@ where
|
||||
vars: EvaluationTargets<D>,
|
||||
) -> Vec<ExtensionTarget<D>> {
|
||||
// The naive method is more efficient if we have enough routed wires for PoseidonMdsGate.
|
||||
let naive =
|
||||
let use_mds_gate =
|
||||
builder.config.num_routed_wires >= PoseidonMdsGate::<F, D, WIDTH>::new().num_wires();
|
||||
|
||||
let mut constraints = Vec::with_capacity(self.num_constraints());
|
||||
@ -306,7 +306,7 @@ where
|
||||
}
|
||||
|
||||
// Partial rounds.
|
||||
if naive {
|
||||
if use_mds_gate {
|
||||
for r in 0..poseidon::N_PARTIAL_ROUNDS {
|
||||
<F as Poseidon<WIDTH>>::constant_layer_recursive(builder, &mut state, round_ctr);
|
||||
let sbox_in = vars.local_wires[Self::wire_partial_sbox(r)];
|
||||
|
||||
@ -12,6 +12,7 @@ use crate::field::fft::fft_root_table;
|
||||
use crate::field::field_types::{Field, RichField};
|
||||
use crate::fri::commitment::PolynomialBatchCommitment;
|
||||
use crate::fri::{FriConfig, FriParams};
|
||||
use crate::gadgets::arithmetic_extension::ArithmeticOperation;
|
||||
use crate::gates::arithmetic::ArithmeticExtensionGate;
|
||||
use crate::gates::constant::ConstantGate;
|
||||
use crate::gates::gate::{Gate, GateInstance, GateRef, PrefixedGate};
|
||||
@ -70,6 +71,9 @@ pub struct CircuitBuilder<F: RichField + Extendable<D>, const D: usize> {
|
||||
constants_to_targets: HashMap<F, Target>,
|
||||
targets_to_constants: HashMap<Target, F>,
|
||||
|
||||
/// Memoized results of `arithmetic_extension` calls.
|
||||
pub(crate) arithmetic_results: HashMap<ArithmeticOperation<F, D>, ExtensionTarget<D>>,
|
||||
|
||||
/// A map `(c0, c1) -> (g, i)` from constants `(c0,c1)` to an available arithmetic gate using
|
||||
/// these constants with gate index `g` and already using `i` arithmetic operations.
|
||||
pub(crate) free_arithmetic: HashMap<(F, F), (usize, usize)>,
|
||||
@ -100,6 +104,7 @@ impl<F: RichField + Extendable<D>, const D: usize> CircuitBuilder<F, D> {
|
||||
marked_targets: Vec::new(),
|
||||
generators: Vec::new(),
|
||||
constants_to_targets: HashMap::new(),
|
||||
arithmetic_results: HashMap::new(),
|
||||
targets_to_constants: HashMap::new(),
|
||||
free_arithmetic: HashMap::new(),
|
||||
free_random_access: HashMap::new(),
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
use crate::field::extension_field::target::ExtensionTarget;
|
||||
use crate::field::extension_field::{Extendable, FieldExtension};
|
||||
use crate::field::field_types::{Field, PrimeField, RichField};
|
||||
use crate::field::field_types::{Field, RichField};
|
||||
use crate::gates::gate::PrefixedGate;
|
||||
use crate::iop::target::Target;
|
||||
use crate::plonk::circuit_builder::CircuitBuilder;
|
||||
@ -328,7 +328,6 @@ pub(crate) fn eval_vanishing_poly_recursively<F: RichField + Extendable<D>, cons
|
||||
gammas: &[Target],
|
||||
alphas: &[Target],
|
||||
) -> Vec<ExtensionTarget<D>> {
|
||||
let one = builder.one_extension();
|
||||
let max_degree = common_data.quotient_degree_factor;
|
||||
let (num_prods, final_num_prod) = common_data.num_partial_products;
|
||||
|
||||
@ -356,36 +355,37 @@ pub(crate) fn eval_vanishing_poly_recursively<F: RichField + Extendable<D>, cons
|
||||
let mut s_ids = Vec::new();
|
||||
for j in 0..common_data.config.num_routed_wires {
|
||||
let k = builder.constant(common_data.k_is[j]);
|
||||
let k_ext = builder.convert_to_ext(k);
|
||||
s_ids.push(builder.mul_extension(k_ext, x));
|
||||
s_ids.push(builder.scalar_mul_ext(k, x));
|
||||
}
|
||||
|
||||
for i in 0..common_data.config.num_challenges {
|
||||
let z_x = local_zs[i];
|
||||
let z_gz = next_zs[i];
|
||||
|
||||
// L_1(x) Z(x) = 0.
|
||||
vanishing_z_1_terms.push(builder.mul_sub_extension(l1_x, z_x, l1_x));
|
||||
|
||||
let numerator_values = (0..common_data.config.num_routed_wires)
|
||||
.map(|j| {
|
||||
let wire_value = vars.local_wires[j];
|
||||
let beta_ext = builder.convert_to_ext(betas[i]);
|
||||
let gamma_ext = builder.convert_to_ext(gammas[i]);
|
||||
// `beta * s_id + wire_value + gamma`
|
||||
builder.wide_arithmetic_extension(beta_ext, s_ids[j], one, wire_value, gamma_ext)
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
let denominator_values = (0..common_data.config.num_routed_wires)
|
||||
.map(|j| {
|
||||
let wire_value = vars.local_wires[j];
|
||||
let beta_ext = builder.convert_to_ext(betas[i]);
|
||||
let gamma_ext = builder.convert_to_ext(gammas[i]);
|
||||
// `beta * s_sigma + wire_value + gamma`
|
||||
builder.wide_arithmetic_extension(beta_ext, s_sigmas[j], one, wire_value, gamma_ext)
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
let quotient_values = (0..common_data.config.num_routed_wires)
|
||||
.map(|j| builder.div_extension(numerator_values[j], denominator_values[j]))
|
||||
.collect::<Vec<_>>();
|
||||
let mut numerator_values = Vec::new();
|
||||
let mut denominator_values = Vec::new();
|
||||
let mut quotient_values = Vec::new();
|
||||
|
||||
for j in 0..common_data.config.num_routed_wires {
|
||||
let wire_value = vars.local_wires[j];
|
||||
let beta_ext = builder.convert_to_ext(betas[i]);
|
||||
let gamma_ext = builder.convert_to_ext(gammas[i]);
|
||||
|
||||
// The numerator is `beta * s_id + wire_value + gamma`, and the denominator is
|
||||
// `beta * s_sigma + wire_value + gamma`.
|
||||
let wire_value_plus_gamma = builder.add_extension(wire_value, gamma_ext);
|
||||
let numerator = builder.mul_add_extension(beta_ext, s_ids[j], wire_value_plus_gamma);
|
||||
let denominator =
|
||||
builder.mul_add_extension(beta_ext, s_sigmas[j], wire_value_plus_gamma);
|
||||
let quotient = builder.div_extension(numerator, denominator);
|
||||
|
||||
numerator_values.push(numerator);
|
||||
denominator_values.push(denominator);
|
||||
quotient_values.push(quotient);
|
||||
}
|
||||
|
||||
// The partial products considered for this iteration of `i`.
|
||||
let current_partial_products = &partial_products[i * num_prods..(i + 1) * num_prods];
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user