mirror of
https://github.com/logos-storage/plonky2.git
synced 2026-02-05 14:33:31 +00:00
Small recursion optimizations (#338)
* Small recursion optimizations Main thing is memoizing arithmetic operations. Overall savings is ~50 gates. * feedback
This commit is contained in:
parent
fdce382af3
commit
1450ffb29c
@ -8,7 +8,7 @@ use crate::iop::target::Target;
|
|||||||
use crate::plonk::circuit_builder::CircuitBuilder;
|
use crate::plonk::circuit_builder::CircuitBuilder;
|
||||||
|
|
||||||
/// `Target`s representing an element of an extension field.
|
/// `Target`s representing an element of an extension field.
|
||||||
#[derive(Copy, Clone, Eq, PartialEq, Debug)]
|
#[derive(Copy, Clone, Eq, PartialEq, Hash, Debug)]
|
||||||
pub struct ExtensionTarget<const D: usize>(pub [Target; D]);
|
pub struct ExtensionTarget<const D: usize>(pub [Target; D]);
|
||||||
|
|
||||||
impl<const D: usize> ExtensionTarget<D> {
|
impl<const D: usize> ExtensionTarget<D> {
|
||||||
|
|||||||
@ -3,7 +3,7 @@ use std::convert::TryInto;
|
|||||||
use crate::field::extension_field::target::{ExtensionAlgebraTarget, ExtensionTarget};
|
use crate::field::extension_field::target::{ExtensionAlgebraTarget, ExtensionTarget};
|
||||||
use crate::field::extension_field::FieldExtension;
|
use crate::field::extension_field::FieldExtension;
|
||||||
use crate::field::extension_field::{Extendable, OEF};
|
use crate::field::extension_field::{Extendable, OEF};
|
||||||
use crate::field::field_types::{Field, RichField};
|
use crate::field::field_types::{Field, PrimeField, RichField};
|
||||||
use crate::gates::arithmetic::ArithmeticExtensionGate;
|
use crate::gates::arithmetic::ArithmeticExtensionGate;
|
||||||
use crate::iop::generator::{GeneratedValues, SimpleGenerator};
|
use crate::iop::generator::{GeneratedValues, SimpleGenerator};
|
||||||
use crate::iop::target::Target;
|
use crate::iop::target::Target;
|
||||||
@ -58,7 +58,29 @@ impl<F: RichField + Extendable<D>, const D: usize> CircuitBuilder<F, D> {
|
|||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
let (gate, i) = self.find_arithmetic_gate(const_0, const_1);
|
// See if we've already computed the same operation.
|
||||||
|
let operation = ArithmeticOperation {
|
||||||
|
const_0,
|
||||||
|
const_1,
|
||||||
|
multiplicand_0,
|
||||||
|
multiplicand_1,
|
||||||
|
addend,
|
||||||
|
};
|
||||||
|
if let Some(&result) = self.arithmetic_results.get(&operation) {
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Otherwise, we must actually perform the operation using an ArithmeticExtensionGate slot.
|
||||||
|
let result = self.add_arithmetic_extension_operation(operation);
|
||||||
|
self.arithmetic_results.insert(operation, result);
|
||||||
|
result
|
||||||
|
}
|
||||||
|
|
||||||
|
fn add_arithmetic_extension_operation(
|
||||||
|
&mut self,
|
||||||
|
operation: ArithmeticOperation<F, D>,
|
||||||
|
) -> ExtensionTarget<D> {
|
||||||
|
let (gate, i) = self.find_arithmetic_gate(operation.const_0, operation.const_1);
|
||||||
let wires_multiplicand_0 = ExtensionTarget::from_range(
|
let wires_multiplicand_0 = ExtensionTarget::from_range(
|
||||||
gate,
|
gate,
|
||||||
ArithmeticExtensionGate::<D>::wires_ith_multiplicand_0(i),
|
ArithmeticExtensionGate::<D>::wires_ith_multiplicand_0(i),
|
||||||
@ -70,9 +92,9 @@ impl<F: RichField + Extendable<D>, const D: usize> CircuitBuilder<F, D> {
|
|||||||
let wires_addend =
|
let wires_addend =
|
||||||
ExtensionTarget::from_range(gate, ArithmeticExtensionGate::<D>::wires_ith_addend(i));
|
ExtensionTarget::from_range(gate, ArithmeticExtensionGate::<D>::wires_ith_addend(i));
|
||||||
|
|
||||||
self.connect_extension(multiplicand_0, wires_multiplicand_0);
|
self.connect_extension(operation.multiplicand_0, wires_multiplicand_0);
|
||||||
self.connect_extension(multiplicand_1, wires_multiplicand_1);
|
self.connect_extension(operation.multiplicand_1, wires_multiplicand_1);
|
||||||
self.connect_extension(addend, wires_addend);
|
self.connect_extension(operation.addend, wires_addend);
|
||||||
|
|
||||||
ExtensionTarget::from_range(gate, ArithmeticExtensionGate::<D>::wires_ith_output(i))
|
ExtensionTarget::from_range(gate, ArithmeticExtensionGate::<D>::wires_ith_output(i))
|
||||||
}
|
}
|
||||||
@ -447,7 +469,7 @@ impl<F: RichField + Extendable<D>, const D: usize> CircuitBuilder<F, D> {
|
|||||||
quotient: inv,
|
quotient: inv,
|
||||||
});
|
});
|
||||||
|
|
||||||
// Enforce that x times its purported inverse equals 1.
|
// Enforce that y times its purported inverse equals 1.
|
||||||
let y_inv = self.mul_extension(y, inv);
|
let y_inv = self.mul_extension(y, inv);
|
||||||
self.connect_extension(y_inv, one);
|
self.connect_extension(y_inv, one);
|
||||||
|
|
||||||
@ -524,6 +546,16 @@ impl<F: RichField + Extendable<D>, const D: usize> CircuitBuilder<F, D> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Represents an arithmetic operation in the circuit. Used to memoize results.
|
||||||
|
#[derive(Copy, Clone, Eq, PartialEq, Hash)]
|
||||||
|
pub(crate) struct ArithmeticOperation<F: PrimeField + Extendable<D>, const D: usize> {
|
||||||
|
const_0: F,
|
||||||
|
const_1: F,
|
||||||
|
multiplicand_0: ExtensionTarget<D>,
|
||||||
|
multiplicand_1: ExtensionTarget<D>,
|
||||||
|
addend: ExtensionTarget<D>,
|
||||||
|
}
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
use std::convert::TryInto;
|
use std::convert::TryInto;
|
||||||
|
|||||||
@ -106,8 +106,7 @@ impl<F: RichField + Extendable<D>, const D: usize> Gate<F, D> for ArithmeticExte
|
|||||||
let computed_output = {
|
let computed_output = {
|
||||||
let mul = builder.mul_ext_algebra(multiplicand_0, multiplicand_1);
|
let mul = builder.mul_ext_algebra(multiplicand_0, multiplicand_1);
|
||||||
let scaled_mul = builder.scalar_mul_ext_algebra(const_0, mul);
|
let scaled_mul = builder.scalar_mul_ext_algebra(const_0, mul);
|
||||||
let scaled_addend = builder.scalar_mul_ext_algebra(const_1, addend);
|
builder.scalar_mul_add_ext_algebra(const_1, addend, scaled_mul)
|
||||||
builder.add_ext_algebra(scaled_mul, scaled_addend)
|
|
||||||
};
|
};
|
||||||
|
|
||||||
let diff = builder.sub_ext_algebra(output, computed_output);
|
let diff = builder.sub_ext_algebra(output, computed_output);
|
||||||
|
|||||||
@ -258,7 +258,7 @@ where
|
|||||||
vars: EvaluationTargets<D>,
|
vars: EvaluationTargets<D>,
|
||||||
) -> Vec<ExtensionTarget<D>> {
|
) -> Vec<ExtensionTarget<D>> {
|
||||||
// The naive method is more efficient if we have enough routed wires for PoseidonMdsGate.
|
// The naive method is more efficient if we have enough routed wires for PoseidonMdsGate.
|
||||||
let naive =
|
let use_mds_gate =
|
||||||
builder.config.num_routed_wires >= PoseidonMdsGate::<F, D, WIDTH>::new().num_wires();
|
builder.config.num_routed_wires >= PoseidonMdsGate::<F, D, WIDTH>::new().num_wires();
|
||||||
|
|
||||||
let mut constraints = Vec::with_capacity(self.num_constraints());
|
let mut constraints = Vec::with_capacity(self.num_constraints());
|
||||||
@ -306,7 +306,7 @@ where
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Partial rounds.
|
// Partial rounds.
|
||||||
if naive {
|
if use_mds_gate {
|
||||||
for r in 0..poseidon::N_PARTIAL_ROUNDS {
|
for r in 0..poseidon::N_PARTIAL_ROUNDS {
|
||||||
<F as Poseidon<WIDTH>>::constant_layer_recursive(builder, &mut state, round_ctr);
|
<F as Poseidon<WIDTH>>::constant_layer_recursive(builder, &mut state, round_ctr);
|
||||||
let sbox_in = vars.local_wires[Self::wire_partial_sbox(r)];
|
let sbox_in = vars.local_wires[Self::wire_partial_sbox(r)];
|
||||||
|
|||||||
@ -12,6 +12,7 @@ use crate::field::fft::fft_root_table;
|
|||||||
use crate::field::field_types::{Field, RichField};
|
use crate::field::field_types::{Field, RichField};
|
||||||
use crate::fri::commitment::PolynomialBatchCommitment;
|
use crate::fri::commitment::PolynomialBatchCommitment;
|
||||||
use crate::fri::{FriConfig, FriParams};
|
use crate::fri::{FriConfig, FriParams};
|
||||||
|
use crate::gadgets::arithmetic_extension::ArithmeticOperation;
|
||||||
use crate::gates::arithmetic::ArithmeticExtensionGate;
|
use crate::gates::arithmetic::ArithmeticExtensionGate;
|
||||||
use crate::gates::constant::ConstantGate;
|
use crate::gates::constant::ConstantGate;
|
||||||
use crate::gates::gate::{Gate, GateInstance, GateRef, PrefixedGate};
|
use crate::gates::gate::{Gate, GateInstance, GateRef, PrefixedGate};
|
||||||
@ -70,6 +71,9 @@ pub struct CircuitBuilder<F: RichField + Extendable<D>, const D: usize> {
|
|||||||
constants_to_targets: HashMap<F, Target>,
|
constants_to_targets: HashMap<F, Target>,
|
||||||
targets_to_constants: HashMap<Target, F>,
|
targets_to_constants: HashMap<Target, F>,
|
||||||
|
|
||||||
|
/// Memoized results of `arithmetic_extension` calls.
|
||||||
|
pub(crate) arithmetic_results: HashMap<ArithmeticOperation<F, D>, ExtensionTarget<D>>,
|
||||||
|
|
||||||
/// A map `(c0, c1) -> (g, i)` from constants `(c0,c1)` to an available arithmetic gate using
|
/// A map `(c0, c1) -> (g, i)` from constants `(c0,c1)` to an available arithmetic gate using
|
||||||
/// these constants with gate index `g` and already using `i` arithmetic operations.
|
/// these constants with gate index `g` and already using `i` arithmetic operations.
|
||||||
pub(crate) free_arithmetic: HashMap<(F, F), (usize, usize)>,
|
pub(crate) free_arithmetic: HashMap<(F, F), (usize, usize)>,
|
||||||
@ -100,6 +104,7 @@ impl<F: RichField + Extendable<D>, const D: usize> CircuitBuilder<F, D> {
|
|||||||
marked_targets: Vec::new(),
|
marked_targets: Vec::new(),
|
||||||
generators: Vec::new(),
|
generators: Vec::new(),
|
||||||
constants_to_targets: HashMap::new(),
|
constants_to_targets: HashMap::new(),
|
||||||
|
arithmetic_results: HashMap::new(),
|
||||||
targets_to_constants: HashMap::new(),
|
targets_to_constants: HashMap::new(),
|
||||||
free_arithmetic: HashMap::new(),
|
free_arithmetic: HashMap::new(),
|
||||||
free_random_access: HashMap::new(),
|
free_random_access: HashMap::new(),
|
||||||
|
|||||||
@ -1,6 +1,6 @@
|
|||||||
use crate::field::extension_field::target::ExtensionTarget;
|
use crate::field::extension_field::target::ExtensionTarget;
|
||||||
use crate::field::extension_field::{Extendable, FieldExtension};
|
use crate::field::extension_field::{Extendable, FieldExtension};
|
||||||
use crate::field::field_types::{Field, PrimeField, RichField};
|
use crate::field::field_types::{Field, RichField};
|
||||||
use crate::gates::gate::PrefixedGate;
|
use crate::gates::gate::PrefixedGate;
|
||||||
use crate::iop::target::Target;
|
use crate::iop::target::Target;
|
||||||
use crate::plonk::circuit_builder::CircuitBuilder;
|
use crate::plonk::circuit_builder::CircuitBuilder;
|
||||||
@ -328,7 +328,6 @@ pub(crate) fn eval_vanishing_poly_recursively<F: RichField + Extendable<D>, cons
|
|||||||
gammas: &[Target],
|
gammas: &[Target],
|
||||||
alphas: &[Target],
|
alphas: &[Target],
|
||||||
) -> Vec<ExtensionTarget<D>> {
|
) -> Vec<ExtensionTarget<D>> {
|
||||||
let one = builder.one_extension();
|
|
||||||
let max_degree = common_data.quotient_degree_factor;
|
let max_degree = common_data.quotient_degree_factor;
|
||||||
let (num_prods, final_num_prod) = common_data.num_partial_products;
|
let (num_prods, final_num_prod) = common_data.num_partial_products;
|
||||||
|
|
||||||
@ -356,36 +355,37 @@ pub(crate) fn eval_vanishing_poly_recursively<F: RichField + Extendable<D>, cons
|
|||||||
let mut s_ids = Vec::new();
|
let mut s_ids = Vec::new();
|
||||||
for j in 0..common_data.config.num_routed_wires {
|
for j in 0..common_data.config.num_routed_wires {
|
||||||
let k = builder.constant(common_data.k_is[j]);
|
let k = builder.constant(common_data.k_is[j]);
|
||||||
let k_ext = builder.convert_to_ext(k);
|
s_ids.push(builder.scalar_mul_ext(k, x));
|
||||||
s_ids.push(builder.mul_extension(k_ext, x));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
for i in 0..common_data.config.num_challenges {
|
for i in 0..common_data.config.num_challenges {
|
||||||
let z_x = local_zs[i];
|
let z_x = local_zs[i];
|
||||||
let z_gz = next_zs[i];
|
let z_gz = next_zs[i];
|
||||||
|
|
||||||
|
// L_1(x) Z(x) = 0.
|
||||||
vanishing_z_1_terms.push(builder.mul_sub_extension(l1_x, z_x, l1_x));
|
vanishing_z_1_terms.push(builder.mul_sub_extension(l1_x, z_x, l1_x));
|
||||||
|
|
||||||
let numerator_values = (0..common_data.config.num_routed_wires)
|
let mut numerator_values = Vec::new();
|
||||||
.map(|j| {
|
let mut denominator_values = Vec::new();
|
||||||
let wire_value = vars.local_wires[j];
|
let mut quotient_values = Vec::new();
|
||||||
let beta_ext = builder.convert_to_ext(betas[i]);
|
|
||||||
let gamma_ext = builder.convert_to_ext(gammas[i]);
|
for j in 0..common_data.config.num_routed_wires {
|
||||||
// `beta * s_id + wire_value + gamma`
|
let wire_value = vars.local_wires[j];
|
||||||
builder.wide_arithmetic_extension(beta_ext, s_ids[j], one, wire_value, gamma_ext)
|
let beta_ext = builder.convert_to_ext(betas[i]);
|
||||||
})
|
let gamma_ext = builder.convert_to_ext(gammas[i]);
|
||||||
.collect::<Vec<_>>();
|
|
||||||
let denominator_values = (0..common_data.config.num_routed_wires)
|
// The numerator is `beta * s_id + wire_value + gamma`, and the denominator is
|
||||||
.map(|j| {
|
// `beta * s_sigma + wire_value + gamma`.
|
||||||
let wire_value = vars.local_wires[j];
|
let wire_value_plus_gamma = builder.add_extension(wire_value, gamma_ext);
|
||||||
let beta_ext = builder.convert_to_ext(betas[i]);
|
let numerator = builder.mul_add_extension(beta_ext, s_ids[j], wire_value_plus_gamma);
|
||||||
let gamma_ext = builder.convert_to_ext(gammas[i]);
|
let denominator =
|
||||||
// `beta * s_sigma + wire_value + gamma`
|
builder.mul_add_extension(beta_ext, s_sigmas[j], wire_value_plus_gamma);
|
||||||
builder.wide_arithmetic_extension(beta_ext, s_sigmas[j], one, wire_value, gamma_ext)
|
let quotient = builder.div_extension(numerator, denominator);
|
||||||
})
|
|
||||||
.collect::<Vec<_>>();
|
numerator_values.push(numerator);
|
||||||
let quotient_values = (0..common_data.config.num_routed_wires)
|
denominator_values.push(denominator);
|
||||||
.map(|j| builder.div_extension(numerator_values[j], denominator_values[j]))
|
quotient_values.push(quotient);
|
||||||
.collect::<Vec<_>>();
|
}
|
||||||
|
|
||||||
// The partial products considered for this iteration of `i`.
|
// The partial products considered for this iteration of `i`.
|
||||||
let current_partial_products = &partial_products[i * num_prods..(i + 1) * num_prods];
|
let current_partial_products = &partial_products[i * num_prods..(i + 1) * num_prods];
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user