diff --git a/src/gates/reducing_extension.rs b/src/gates/reducing_extension.rs index 93c981a6..9ee2134f 100644 --- a/src/gates/reducing_extension.rs +++ b/src/gates/reducing_extension.rs @@ -13,11 +13,11 @@ use crate::plonk::vars::{EvaluationTargets, EvaluationVars, EvaluationVarsBase}; /// Computes `sum alpha^i c_i` for a vector `c_i` of `num_coeffs` elements of the extension field. #[derive(Debug, Clone)] -pub struct ReducingExtGate { +pub struct ReducingExtensionGate { pub num_coeffs: usize, } -impl ReducingExtGate { +impl ReducingExtensionGate { pub fn new(num_coeffs: usize) -> Self { Self { num_coeffs } } @@ -51,7 +51,7 @@ impl ReducingExtGate { } } -impl, const D: usize> Gate for ReducingExtGate { +impl, const D: usize> Gate for ReducingExtensionGate { fn id(&self) -> String { format!("{:?}", self) } @@ -163,14 +163,16 @@ impl, const D: usize> Gate for ReducingExtGat #[derive(Debug)] struct ReducingGenerator { gate_index: usize, - gate: ReducingExtGate, + gate: ReducingExtensionGate, } impl, const D: usize> SimpleGenerator for ReducingGenerator { fn dependencies(&self) -> Vec { - ReducingExtGate::::wires_alpha() - .chain(ReducingExtGate::::wires_old_acc()) - .chain((0..self.gate.num_coeffs).flat_map(|i| ReducingExtGate::::wires_coeff(i))) + ReducingExtensionGate::::wires_alpha() + .chain(ReducingExtensionGate::::wires_old_acc()) + .chain( + (0..self.gate.num_coeffs).flat_map(|i| ReducingExtensionGate::::wires_coeff(i)), + ) .map(|i| Target::wire(self.gate_index, i)) .collect() } @@ -181,16 +183,18 @@ impl, const D: usize> SimpleGenerator for ReducingGenerator< witness.get_extension_target(t) }; - let alpha = extract_extension(ReducingExtGate::::wires_alpha()); - let old_acc = extract_extension(ReducingExtGate::::wires_old_acc()); + let alpha = extract_extension(ReducingExtensionGate::::wires_alpha()); + let old_acc = extract_extension(ReducingExtensionGate::::wires_old_acc()); let coeffs = (0..self.gate.num_coeffs) - .map(|i| extract_extension(ReducingExtGate::::wires_coeff(i))) + .map(|i| extract_extension(ReducingExtensionGate::::wires_coeff(i))) .collect::>(); let accs = (0..self.gate.num_coeffs) .map(|i| ExtensionTarget::from_range(self.gate_index, self.gate.wires_accs(i))) .collect::>(); - let output = - ExtensionTarget::from_range(self.gate_index, ReducingExtGate::::wires_output()); + let output = ExtensionTarget::from_range( + self.gate_index, + ReducingExtensionGate::::wires_output(), + ); let mut acc = old_acc; for i in 0..self.gate.num_coeffs { @@ -208,15 +212,15 @@ mod tests { use crate::field::goldilocks_field::GoldilocksField; use crate::gates::gate_testing::{test_eval_fns, test_low_degree}; - use crate::gates::reducing_extension::ReducingExtGate; + use crate::gates::reducing_extension::ReducingExtensionGate; #[test] fn low_degree() { - test_low_degree::(ReducingExtGate::new(22)); + test_low_degree::(ReducingExtensionGate::new(22)); } #[test] fn eval_fns() -> Result<()> { - test_eval_fns::(ReducingExtGate::new(22)) + test_eval_fns::(ReducingExtensionGate::new(22)) } } diff --git a/src/util/reducing.rs b/src/util/reducing.rs index 12be80f6..f2cd3d55 100644 --- a/src/util/reducing.rs +++ b/src/util/reducing.rs @@ -3,8 +3,9 @@ use std::borrow::Borrow; use crate::field::extension_field::target::ExtensionTarget; use crate::field::extension_field::Extendable; use crate::field::field_types::{Field, RichField}; +use crate::gates::arithmetic_extension::ArithmeticExtensionGate; use crate::gates::reducing::ReducingGate; -use crate::gates::reducing_extension::ReducingExtGate; +use crate::gates::reducing_extension::ReducingExtensionGate; use crate::iop::target::Target; use crate::plonk::circuit_builder::CircuitBuilder; use crate::polynomial::polynomial::PolynomialCoeffs; @@ -94,7 +95,7 @@ impl ReducingFactorTarget { Self { base, count: 0 } } - /// Reduces a length `n` vector of `Target`s using `n/21` `ReducingGate`s (with 33 routed wires and 126 wires). + /// Reduces a vector of `Target`s using `ReducingGate`s. pub fn reduce_base( &mut self, terms: &[Target], @@ -103,11 +104,16 @@ impl ReducingFactorTarget { where F: RichField + Extendable, { + let l = terms.len(); + // For small reductions, use an arithmetic gate. + if l - 1 <= ArithmeticExtensionGate::::new_from_config(&builder.config).num_ops { + return self.reduce_base_arithmetic(terms, builder); + } let max_coeffs_len = ReducingGate::::max_coeffs_len( builder.config.num_wires, builder.config.num_routed_wires, ); - self.count += terms.len() as u64; + self.count += l as u64; let zero = builder.zero(); let zero_ext = builder.zero_extension(); let mut acc = zero_ext; @@ -138,6 +144,26 @@ impl ReducingFactorTarget { acc } + /// Reduces a vector of `Target`s using `ArithmeticGate`s. + fn reduce_base_arithmetic( + &mut self, + terms: &[Target], + builder: &mut CircuitBuilder, + ) -> ExtensionTarget + where + F: RichField + Extendable, + { + self.count += terms.len() as u64; + terms + .iter() + .rev() + .fold(builder.zero_extension(), |acc, &t| { + let et = builder.convert_to_ext(t); + builder.mul_add_extension(self.base, acc, et) + }) + } + + /// Reduces a vector of `ExtensionTarget`s using `ReducingExtensionGate`s. pub fn reduce( &mut self, terms: &[ExtensionTarget], // Could probably work with a `DoubleEndedIterator` too. @@ -146,12 +172,16 @@ impl ReducingFactorTarget { where F: RichField + Extendable, { - let max_coeffs_len = ReducingExtGate::::max_coeffs_len( + let l = terms.len(); + // For small reductions, use an arithmetic gate. + if l - 1 <= ArithmeticExtensionGate::::new_from_config(&builder.config).num_ops { + return self.reduce_arithmetic(terms, builder); + } + let max_coeffs_len = ReducingExtensionGate::::max_coeffs_len( builder.config.num_wires, builder.config.num_routed_wires, ); - self.count += terms.len() as u64; - let zero = builder.zero(); + self.count += l as u64; let zero_ext = builder.zero_extension(); let mut acc = zero_ext; let mut reversed_terms = terms.to_vec(); @@ -160,30 +190,55 @@ impl ReducingFactorTarget { } reversed_terms.reverse(); for chunk in reversed_terms.chunks_exact(max_coeffs_len) { - let gate = ReducingExtGate::new(max_coeffs_len); + let gate = ReducingExtensionGate::new(max_coeffs_len); let gate_index = builder.add_gate(gate.clone(), Vec::new()); builder.connect_extension( self.base, - ExtensionTarget::from_range(gate_index, ReducingExtGate::::wires_alpha()), + ExtensionTarget::from_range(gate_index, ReducingExtensionGate::::wires_alpha()), ); builder.connect_extension( acc, - ExtensionTarget::from_range(gate_index, ReducingExtGate::::wires_old_acc()), + ExtensionTarget::from_range( + gate_index, + ReducingExtensionGate::::wires_old_acc(), + ), ); for (i, &t) in chunk.iter().enumerate() { builder.connect_extension( t, - ExtensionTarget::from_range(gate_index, ReducingExtGate::::wires_coeff(i)), + ExtensionTarget::from_range( + gate_index, + ReducingExtensionGate::::wires_coeff(i), + ), ); } - acc = ExtensionTarget::from_range(gate_index, ReducingExtGate::::wires_output()); + acc = + ExtensionTarget::from_range(gate_index, ReducingExtensionGate::::wires_output()); } acc } + /// Reduces a vector of `ExtensionTarget`s using `ArithmeticGate`s. + fn reduce_arithmetic( + &mut self, + terms: &[ExtensionTarget], + builder: &mut CircuitBuilder, + ) -> ExtensionTarget + where + F: RichField + Extendable, + { + self.count += terms.len() as u64; + terms + .iter() + .rev() + .fold(builder.zero_extension(), |acc, &et| { + builder.mul_add_extension(self.base, acc, et) + }) + } + pub fn shift( &mut self, x: ExtensionTarget,