diff --git a/src/gates/mod.rs b/src/gates/mod.rs index 7f7ee32b..93de5e97 100644 --- a/src/gates/mod.rs +++ b/src/gates/mod.rs @@ -20,6 +20,7 @@ pub(crate) mod poseidon_mds; pub(crate) mod public_input; pub mod random_access; pub mod reducing; +pub mod reducing_extension; pub mod subtraction_u32; pub mod switch; diff --git a/src/gates/reducing_extension.rs b/src/gates/reducing_extension.rs new file mode 100644 index 00000000..532b484f --- /dev/null +++ b/src/gates/reducing_extension.rs @@ -0,0 +1,222 @@ +use std::ops::Range; + +use crate::field::extension_field::target::ExtensionTarget; +use crate::field::extension_field::Extendable; +use crate::field::extension_field::FieldExtension; +use crate::field::field_types::RichField; +use crate::gates::gate::Gate; +use crate::iop::generator::{GeneratedValues, SimpleGenerator, WitnessGenerator}; +use crate::iop::target::Target; +use crate::iop::witness::{PartitionWitness, Witness}; +use crate::plonk::circuit_builder::CircuitBuilder; +use crate::plonk::vars::{EvaluationTargets, EvaluationVars, EvaluationVarsBase}; + +/// Computes `sum alpha^i c_i` for a vector `c_i` of `num_coeffs` elements of the extension field. +#[derive(Debug, Clone)] +pub struct ReducingExtensionGate { + pub num_coeffs: usize, +} + +impl ReducingExtensionGate { + pub fn new(num_coeffs: usize) -> Self { + Self { num_coeffs } + } + + pub fn max_coeffs_len(num_wires: usize, num_routed_wires: usize) -> usize { + // `3*D` routed wires are used for the output, alpha and old accumulator. + // Need `num_coeffs*D` routed wires for coeffs, and `(num_coeffs-1)*D` wires for accumulators. + ((num_routed_wires - 3 * D) / D).min((num_wires - 2 * D) / (D * 2)) + } + + pub fn wires_output() -> Range { + 0..D + } + pub fn wires_alpha() -> Range { + D..2 * D + } + pub fn wires_old_acc() -> Range { + 2 * D..3 * D + } + const START_COEFFS: usize = 3 * D; + pub fn wires_coeff(i: usize) -> Range { + Self::START_COEFFS + i * D..Self::START_COEFFS + (i + 1) * D + } + fn start_accs(&self) -> usize { + Self::START_COEFFS + self.num_coeffs * D + } + fn wires_accs(&self, i: usize) -> Range { + debug_assert!(i < self.num_coeffs); + if i == self.num_coeffs - 1 { + // The last accumulator is the output. + return Self::wires_output(); + } + self.start_accs() + D * i..self.start_accs() + D * (i + 1) + } +} + +impl, const D: usize> Gate for ReducingExtensionGate { + fn id(&self) -> String { + format!("{:?}", self) + } + + fn eval_unfiltered(&self, vars: EvaluationVars) -> Vec { + let alpha = vars.get_local_ext_algebra(Self::wires_alpha()); + let old_acc = vars.get_local_ext_algebra(Self::wires_old_acc()); + let coeffs = (0..self.num_coeffs) + .map(|i| vars.get_local_ext_algebra(Self::wires_coeff(i))) + .collect::>(); + let accs = (0..self.num_coeffs) + .map(|i| vars.get_local_ext_algebra(self.wires_accs(i))) + .collect::>(); + + let mut constraints = Vec::with_capacity(>::num_constraints(self)); + let mut acc = old_acc; + for i in 0..self.num_coeffs { + constraints.push(acc * alpha + coeffs[i] - accs[i]); + acc = accs[i]; + } + + constraints + .into_iter() + .flat_map(|alg| alg.to_basefield_array()) + .collect() + } + + fn eval_unfiltered_base(&self, vars: EvaluationVarsBase) -> Vec { + let alpha = vars.get_local_ext(Self::wires_alpha()); + let old_acc = vars.get_local_ext(Self::wires_old_acc()); + let coeffs = (0..self.num_coeffs) + .map(|i| vars.get_local_ext(Self::wires_coeff(i))) + .collect::>(); + let accs = (0..self.num_coeffs) + .map(|i| vars.get_local_ext(self.wires_accs(i))) + .collect::>(); + + let mut constraints = Vec::with_capacity(>::num_constraints(self)); + let mut acc = old_acc; + for i in 0..self.num_coeffs { + constraints.extend((acc * alpha + coeffs[i] - accs[i]).to_basefield_array()); + acc = accs[i]; + } + + constraints + } + + fn eval_unfiltered_recursively( + &self, + builder: &mut CircuitBuilder, + vars: EvaluationTargets, + ) -> Vec> { + let alpha = vars.get_local_ext_algebra(Self::wires_alpha()); + let old_acc = vars.get_local_ext_algebra(Self::wires_old_acc()); + let coeffs = (0..self.num_coeffs) + .map(|i| vars.get_local_ext_algebra(Self::wires_coeff(i))) + .collect::>(); + let accs = (0..self.num_coeffs) + .map(|i| vars.get_local_ext_algebra(self.wires_accs(i))) + .collect::>(); + + let mut constraints = Vec::with_capacity(>::num_constraints(self)); + let mut acc = old_acc; + for i in 0..self.num_coeffs { + let coeff = coeffs[i]; + let mut tmp = builder.mul_add_ext_algebra(acc, alpha, coeff); + tmp = builder.sub_ext_algebra(tmp, accs[i]); + constraints.push(tmp); + acc = accs[i]; + } + + constraints + .into_iter() + .flat_map(|alg| alg.to_ext_target_array()) + .collect() + } + + fn generators( + &self, + gate_index: usize, + _local_constants: &[F], + ) -> Vec>> { + vec![Box::new( + ReducingGenerator { + gate_index, + gate: self.clone(), + } + .adapter(), + )] + } + + fn num_wires(&self) -> usize { + 2 * D + 2 * D * self.num_coeffs + } + + fn num_constants(&self) -> usize { + 0 + } + + fn degree(&self) -> usize { + 2 + } + + fn num_constraints(&self) -> usize { + D * self.num_coeffs + } +} + +#[derive(Debug)] +struct ReducingGenerator { + gate_index: usize, + gate: ReducingExtensionGate, +} + +impl, const D: usize> SimpleGenerator for ReducingGenerator { + fn dependencies(&self) -> Vec { + ReducingExtensionGate::::wires_alpha() + .chain(ReducingExtensionGate::::wires_old_acc()) + .chain((0..self.gate.num_coeffs).flat_map(ReducingExtensionGate::::wires_coeff)) + .map(|i| Target::wire(self.gate_index, i)) + .collect() + } + + fn run_once(&self, witness: &PartitionWitness, out_buffer: &mut GeneratedValues) { + let local_extension = |range: Range| -> F::Extension { + let t = ExtensionTarget::from_range(self.gate_index, range); + witness.get_extension_target(t) + }; + + let alpha = local_extension(ReducingExtensionGate::::wires_alpha()); + let old_acc = local_extension(ReducingExtensionGate::::wires_old_acc()); + let coeffs = (0..self.gate.num_coeffs) + .map(|i| local_extension(ReducingExtensionGate::::wires_coeff(i))) + .collect::>(); + let accs = (0..self.gate.num_coeffs) + .map(|i| ExtensionTarget::from_range(self.gate_index, self.gate.wires_accs(i))) + .collect::>(); + + let mut acc = old_acc; + for i in 0..self.gate.num_coeffs { + let computed_acc = acc * alpha + coeffs[i]; + out_buffer.set_extension_target(accs[i], computed_acc); + acc = computed_acc; + } + } +} + +#[cfg(test)] +mod tests { + use anyhow::Result; + + use crate::field::goldilocks_field::GoldilocksField; + use crate::gates::gate_testing::{test_eval_fns, test_low_degree}; + use crate::gates::reducing_extension::ReducingExtensionGate; + + #[test] + fn low_degree() { + test_low_degree::(ReducingExtensionGate::new(22)); + } + + #[test] + fn eval_fns() -> Result<()> { + test_eval_fns::(ReducingExtensionGate::new(22)) + } +} diff --git a/src/util/reducing.rs b/src/util/reducing.rs index 3e00602c..f700a6ff 100644 --- a/src/util/reducing.rs +++ b/src/util/reducing.rs @@ -3,7 +3,9 @@ use std::borrow::Borrow; use crate::field::extension_field::target::ExtensionTarget; use crate::field::extension_field::Extendable; use crate::field::field_types::{Field, RichField}; +use crate::gates::arithmetic_extension::ArithmeticExtensionGate; use crate::gates::reducing::ReducingGate; +use crate::gates::reducing_extension::ReducingExtensionGate; use crate::iop::target::Target; use crate::plonk::circuit_builder::CircuitBuilder; use crate::polynomial::polynomial::PolynomialCoeffs; @@ -93,7 +95,7 @@ impl ReducingFactorTarget { Self { base, count: 0 } } - /// Reduces a length `n` vector of `Target`s using `n/21` `ReducingGate`s (with 33 routed wires and 126 wires). + /// Reduces a vector of `Target`s using `ReducingGate`s. pub fn reduce_base( &mut self, terms: &[Target], @@ -102,11 +104,22 @@ impl ReducingFactorTarget { where F: RichField + Extendable, { + let l = terms.len(); + + // For small reductions, use an arithmetic gate. + if l <= ArithmeticExtensionGate::::new_from_config(&builder.config).num_ops + 1 { + let terms_ext = terms + .iter() + .map(|&t| builder.convert_to_ext(t)) + .collect::>(); + return self.reduce_arithmetic(&terms_ext, builder); + } + let max_coeffs_len = ReducingGate::::max_coeffs_len( builder.config.num_wires, builder.config.num_routed_wires, ); - self.count += terms.len() as u64; + self.count += l as u64; let zero = builder.zero(); let zero_ext = builder.zero_extension(); let mut acc = zero_ext; @@ -137,6 +150,7 @@ impl ReducingFactorTarget { acc } + /// Reduces a vector of `ExtensionTarget`s using `ReducingExtensionGate`s. pub fn reduce( &mut self, terms: &[ExtensionTarget], // Could probably work with a `DoubleEndedIterator` too. @@ -146,18 +160,74 @@ impl ReducingFactorTarget { F: RichField + Extendable, { let l = terms.len(); - self.count += l as u64; - let mut terms_vec = terms.to_vec(); - let mut acc = builder.zero_extension(); - terms_vec.reverse(); - - for x in terms_vec { - acc = builder.mul_add_extension(self.base, acc, x); + // For small reductions, use an arithmetic gate. + if l <= ArithmeticExtensionGate::::new_from_config(&builder.config).num_ops + 1 { + return self.reduce_arithmetic(terms, builder); } + + let max_coeffs_len = ReducingExtensionGate::::max_coeffs_len( + builder.config.num_wires, + builder.config.num_routed_wires, + ); + self.count += l as u64; + let zero_ext = builder.zero_extension(); + let mut acc = zero_ext; + let mut reversed_terms = terms.to_vec(); + while reversed_terms.len() % max_coeffs_len != 0 { + reversed_terms.push(zero_ext); + } + reversed_terms.reverse(); + for chunk in reversed_terms.chunks_exact(max_coeffs_len) { + let gate = ReducingExtensionGate::new(max_coeffs_len); + let gate_index = builder.add_gate(gate.clone(), Vec::new()); + + builder.connect_extension( + self.base, + ExtensionTarget::from_range(gate_index, ReducingExtensionGate::::wires_alpha()), + ); + builder.connect_extension( + acc, + ExtensionTarget::from_range( + gate_index, + ReducingExtensionGate::::wires_old_acc(), + ), + ); + for (i, &t) in chunk.iter().enumerate() { + builder.connect_extension( + t, + ExtensionTarget::from_range( + gate_index, + ReducingExtensionGate::::wires_coeff(i), + ), + ); + } + + acc = + ExtensionTarget::from_range(gate_index, ReducingExtensionGate::::wires_output()); + } + acc } + /// Reduces a vector of `ExtensionTarget`s using `ArithmeticGate`s. + fn reduce_arithmetic( + &mut self, + terms: &[ExtensionTarget], + builder: &mut CircuitBuilder, + ) -> ExtensionTarget + where + F: RichField + Extendable, + { + self.count += terms.len() as u64; + terms + .iter() + .rev() + .fold(builder.zero_extension(), |acc, &et| { + builder.mul_add_extension(self.base, acc, et) + }) + } + pub fn shift( &mut self, x: ExtensionTarget, @@ -260,4 +330,9 @@ mod tests { fn test_reduce_gadget_base_100() -> Result<()> { test_reduce_gadget_base(100) } + + #[test] + fn test_reduce_gadget_100() -> Result<()> { + test_reduce_gadget(100) + } }