diff --git a/src/field/extension_field/target.rs b/src/field/extension_field/target.rs index a40013fa..ab3c3619 100644 --- a/src/field/extension_field/target.rs +++ b/src/field/extension_field/target.rs @@ -8,7 +8,7 @@ use crate::iop::target::Target; use crate::plonk::circuit_builder::CircuitBuilder; /// `Target`s representing an element of an extension field. -#[derive(Copy, Clone, Eq, PartialEq, Debug)] +#[derive(Copy, Clone, Eq, PartialEq, Hash, Debug)] pub struct ExtensionTarget(pub [Target; D]); impl ExtensionTarget { diff --git a/src/gadgets/arithmetic_extension.rs b/src/gadgets/arithmetic_extension.rs index 814cd8ae..24499760 100644 --- a/src/gadgets/arithmetic_extension.rs +++ b/src/gadgets/arithmetic_extension.rs @@ -3,7 +3,7 @@ use std::convert::TryInto; use crate::field::extension_field::target::{ExtensionAlgebraTarget, ExtensionTarget}; use crate::field::extension_field::FieldExtension; use crate::field::extension_field::{Extendable, OEF}; -use crate::field::field_types::{Field, RichField}; +use crate::field::field_types::{Field, PrimeField, RichField}; use crate::gates::arithmetic::ArithmeticExtensionGate; use crate::iop::generator::{GeneratedValues, SimpleGenerator}; use crate::iop::target::Target; @@ -58,7 +58,29 @@ impl, const D: usize> CircuitBuilder { return result; } - let (gate, i) = self.find_arithmetic_gate(const_0, const_1); + // See if we've already computed the same operation. + let operation = ArithmeticOperation { + const_0, + const_1, + multiplicand_0, + multiplicand_1, + addend, + }; + if let Some(&result) = self.arithmetic_results.get(&operation) { + return result; + } + + // Otherwise, we must actually perform the operation using an ArithmeticExtensionGate slot. + let result = self.add_arithmetic_extension_operation(operation); + self.arithmetic_results.insert(operation, result); + result + } + + fn add_arithmetic_extension_operation( + &mut self, + operation: ArithmeticOperation, + ) -> ExtensionTarget { + let (gate, i) = self.find_arithmetic_gate(operation.const_0, operation.const_1); let wires_multiplicand_0 = ExtensionTarget::from_range( gate, ArithmeticExtensionGate::::wires_ith_multiplicand_0(i), @@ -70,9 +92,9 @@ impl, const D: usize> CircuitBuilder { let wires_addend = ExtensionTarget::from_range(gate, ArithmeticExtensionGate::::wires_ith_addend(i)); - self.connect_extension(multiplicand_0, wires_multiplicand_0); - self.connect_extension(multiplicand_1, wires_multiplicand_1); - self.connect_extension(addend, wires_addend); + self.connect_extension(operation.multiplicand_0, wires_multiplicand_0); + self.connect_extension(operation.multiplicand_1, wires_multiplicand_1); + self.connect_extension(operation.addend, wires_addend); ExtensionTarget::from_range(gate, ArithmeticExtensionGate::::wires_ith_output(i)) } @@ -447,7 +469,7 @@ impl, const D: usize> CircuitBuilder { quotient: inv, }); - // Enforce that x times its purported inverse equals 1. + // Enforce that y times its purported inverse equals 1. let y_inv = self.mul_extension(y, inv); self.connect_extension(y_inv, one); @@ -524,6 +546,16 @@ impl, const D: usize> CircuitBuilder { } } +/// Represents an arithmetic operation in the circuit. Used to memoize results. +#[derive(Copy, Clone, Eq, PartialEq, Hash)] +pub(crate) struct ArithmeticOperation, const D: usize> { + const_0: F, + const_1: F, + multiplicand_0: ExtensionTarget, + multiplicand_1: ExtensionTarget, + addend: ExtensionTarget, +} + #[cfg(test)] mod tests { use std::convert::TryInto; diff --git a/src/gates/arithmetic.rs b/src/gates/arithmetic.rs index 98c15395..95b48e2f 100644 --- a/src/gates/arithmetic.rs +++ b/src/gates/arithmetic.rs @@ -106,8 +106,7 @@ impl, const D: usize> Gate for ArithmeticExte let computed_output = { let mul = builder.mul_ext_algebra(multiplicand_0, multiplicand_1); let scaled_mul = builder.scalar_mul_ext_algebra(const_0, mul); - let scaled_addend = builder.scalar_mul_ext_algebra(const_1, addend); - builder.add_ext_algebra(scaled_mul, scaled_addend) + builder.scalar_mul_add_ext_algebra(const_1, addend, scaled_mul) }; let diff = builder.sub_ext_algebra(output, computed_output); diff --git a/src/gates/poseidon.rs b/src/gates/poseidon.rs index 9188a473..1f5f746d 100644 --- a/src/gates/poseidon.rs +++ b/src/gates/poseidon.rs @@ -258,7 +258,7 @@ where vars: EvaluationTargets, ) -> Vec> { // The naive method is more efficient if we have enough routed wires for PoseidonMdsGate. - let naive = + let use_mds_gate = builder.config.num_routed_wires >= PoseidonMdsGate::::new().num_wires(); let mut constraints = Vec::with_capacity(self.num_constraints()); @@ -306,7 +306,7 @@ where } // Partial rounds. - if naive { + if use_mds_gate { for r in 0..poseidon::N_PARTIAL_ROUNDS { >::constant_layer_recursive(builder, &mut state, round_ctr); let sbox_in = vars.local_wires[Self::wire_partial_sbox(r)]; diff --git a/src/plonk/circuit_builder.rs b/src/plonk/circuit_builder.rs index 62042cd1..5dcde1e0 100644 --- a/src/plonk/circuit_builder.rs +++ b/src/plonk/circuit_builder.rs @@ -12,6 +12,7 @@ use crate::field::fft::fft_root_table; use crate::field::field_types::{Field, RichField}; use crate::fri::commitment::PolynomialBatchCommitment; use crate::fri::{FriConfig, FriParams}; +use crate::gadgets::arithmetic_extension::ArithmeticOperation; use crate::gates::arithmetic::ArithmeticExtensionGate; use crate::gates::constant::ConstantGate; use crate::gates::gate::{Gate, GateInstance, GateRef, PrefixedGate}; @@ -70,6 +71,9 @@ pub struct CircuitBuilder, const D: usize> { constants_to_targets: HashMap, targets_to_constants: HashMap, + /// Memoized results of `arithmetic_extension` calls. + pub(crate) arithmetic_results: HashMap, ExtensionTarget>, + /// A map `(c0, c1) -> (g, i)` from constants `(c0,c1)` to an available arithmetic gate using /// these constants with gate index `g` and already using `i` arithmetic operations. pub(crate) free_arithmetic: HashMap<(F, F), (usize, usize)>, @@ -100,6 +104,7 @@ impl, const D: usize> CircuitBuilder { marked_targets: Vec::new(), generators: Vec::new(), constants_to_targets: HashMap::new(), + arithmetic_results: HashMap::new(), targets_to_constants: HashMap::new(), free_arithmetic: HashMap::new(), free_random_access: HashMap::new(), diff --git a/src/plonk/vanishing_poly.rs b/src/plonk/vanishing_poly.rs index 6d0dc982..0ebdbd22 100644 --- a/src/plonk/vanishing_poly.rs +++ b/src/plonk/vanishing_poly.rs @@ -1,6 +1,6 @@ use crate::field::extension_field::target::ExtensionTarget; use crate::field::extension_field::{Extendable, FieldExtension}; -use crate::field::field_types::{Field, PrimeField, RichField}; +use crate::field::field_types::{Field, RichField}; use crate::gates::gate::PrefixedGate; use crate::iop::target::Target; use crate::plonk::circuit_builder::CircuitBuilder; @@ -328,7 +328,6 @@ pub(crate) fn eval_vanishing_poly_recursively, cons gammas: &[Target], alphas: &[Target], ) -> Vec> { - let one = builder.one_extension(); let max_degree = common_data.quotient_degree_factor; let (num_prods, final_num_prod) = common_data.num_partial_products; @@ -356,36 +355,37 @@ pub(crate) fn eval_vanishing_poly_recursively, cons let mut s_ids = Vec::new(); for j in 0..common_data.config.num_routed_wires { let k = builder.constant(common_data.k_is[j]); - let k_ext = builder.convert_to_ext(k); - s_ids.push(builder.mul_extension(k_ext, x)); + s_ids.push(builder.scalar_mul_ext(k, x)); } for i in 0..common_data.config.num_challenges { let z_x = local_zs[i]; let z_gz = next_zs[i]; + + // L_1(x) Z(x) = 0. vanishing_z_1_terms.push(builder.mul_sub_extension(l1_x, z_x, l1_x)); - let numerator_values = (0..common_data.config.num_routed_wires) - .map(|j| { - let wire_value = vars.local_wires[j]; - let beta_ext = builder.convert_to_ext(betas[i]); - let gamma_ext = builder.convert_to_ext(gammas[i]); - // `beta * s_id + wire_value + gamma` - builder.wide_arithmetic_extension(beta_ext, s_ids[j], one, wire_value, gamma_ext) - }) - .collect::>(); - let denominator_values = (0..common_data.config.num_routed_wires) - .map(|j| { - let wire_value = vars.local_wires[j]; - let beta_ext = builder.convert_to_ext(betas[i]); - let gamma_ext = builder.convert_to_ext(gammas[i]); - // `beta * s_sigma + wire_value + gamma` - builder.wide_arithmetic_extension(beta_ext, s_sigmas[j], one, wire_value, gamma_ext) - }) - .collect::>(); - let quotient_values = (0..common_data.config.num_routed_wires) - .map(|j| builder.div_extension(numerator_values[j], denominator_values[j])) - .collect::>(); + let mut numerator_values = Vec::new(); + let mut denominator_values = Vec::new(); + let mut quotient_values = Vec::new(); + + for j in 0..common_data.config.num_routed_wires { + let wire_value = vars.local_wires[j]; + let beta_ext = builder.convert_to_ext(betas[i]); + let gamma_ext = builder.convert_to_ext(gammas[i]); + + // The numerator is `beta * s_id + wire_value + gamma`, and the denominator is + // `beta * s_sigma + wire_value + gamma`. + let wire_value_plus_gamma = builder.add_extension(wire_value, gamma_ext); + let numerator = builder.mul_add_extension(beta_ext, s_ids[j], wire_value_plus_gamma); + let denominator = + builder.mul_add_extension(beta_ext, s_sigmas[j], wire_value_plus_gamma); + let quotient = builder.div_extension(numerator, denominator); + + numerator_values.push(numerator); + denominator_values.push(denominator); + quotient_values.push(quotient); + } // The partial products considered for this iteration of `i`. let current_partial_products = &partial_products[i * num_prods..(i + 1) * num_prods];