use std::borrow::Borrow; use crate::field::extension_field::target::ExtensionTarget; use crate::field::extension_field::Extendable; use crate::field::field_types::{Field, RichField}; use crate::gates::reducing::ReducingGate; use crate::iop::target::Target; use crate::plonk::circuit_builder::CircuitBuilder; use crate::polynomial::polynomial::PolynomialCoeffs; /// When verifying the composition polynomial in FRI we have to compute sums of the form /// `(sum_0^k a^i * x_i)/d_0 + (sum_k^r a^i * y_i)/d_1` /// The most efficient way to do this is to compute both quotient separately using Horner's method, /// scale the second one by `a^(r-1-k)`, and add them up. /// This struct abstract away these operations by implementing Horner's method and keeping track /// of the number of multiplications by `a` to compute the scaling factor. /// See https://github.com/mir-protocol/plonky2/pull/69 for more details and discussions. #[derive(Debug, Clone)] pub struct ReducingFactor { base: F, count: u64, } impl ReducingFactor { pub fn new(base: F) -> Self { Self { base, count: 0 } } fn mul(&mut self, x: F) -> F { self.count += 1; self.base * x } fn mul_poly(&mut self, p: &mut PolynomialCoeffs) { self.count += 1; *p *= self.base; } pub fn reduce(&mut self, iter: impl DoubleEndedIterator>) -> F { iter.rev() .fold(F::ZERO, |acc, x| self.mul(acc) + *x.borrow()) } pub fn reduce_polys( &mut self, polys: impl DoubleEndedIterator>>, ) -> PolynomialCoeffs { polys.rev().fold(PolynomialCoeffs::empty(), |mut acc, x| { self.mul_poly(&mut acc); acc += x.borrow(); acc }) } pub fn reduce_polys_base, const D: usize>( &mut self, polys: impl IntoIterator>>, ) -> PolynomialCoeffs { self.base .powers() .zip(polys) .map(|(base_power, poly)| { self.count += 1; poly.borrow().mul_extension(base_power) }) .sum() } pub fn shift(&mut self, x: F) -> F { let tmp = self.base.exp_u64(self.count) * x; self.count = 0; tmp } pub fn shift_poly(&mut self, p: &mut PolynomialCoeffs) { *p *= self.base.exp_u64(self.count); self.count = 0; } pub fn reset(&mut self) { self.count = 0; } } #[derive(Debug, Clone)] pub struct ReducingFactorTarget { base: ExtensionTarget, count: u64, } impl ReducingFactorTarget { pub fn new(base: ExtensionTarget) -> Self { Self { base, count: 0 } } /// Reduces a length `n` vector of `Target`s using `n/21` `ReducingGate`s (with 33 routed wires and 126 wires). pub fn reduce_base( &mut self, terms: &[Target], builder: &mut CircuitBuilder, ) -> ExtensionTarget where F: RichField + Extendable, { let max_coeffs_len = ReducingGate::::max_coeffs_len( builder.config.num_wires, builder.config.num_routed_wires, ); self.count += terms.len() as u64; let zero = builder.zero(); let zero_ext = builder.zero_extension(); let mut acc = zero_ext; let mut reversed_terms = terms.to_vec(); while reversed_terms.len() % max_coeffs_len != 0 { reversed_terms.push(zero); } reversed_terms.reverse(); for chunk in reversed_terms.chunks_exact(max_coeffs_len) { let gate = ReducingGate::new(max_coeffs_len); let gate_index = builder.add_gate(gate.clone(), Vec::new()); builder.connect_extension( self.base, ExtensionTarget::from_range(gate_index, ReducingGate::::wires_alpha()), ); builder.connect_extension( acc, ExtensionTarget::from_range(gate_index, ReducingGate::::wires_old_acc()), ); for (&t, c) in chunk.iter().zip(gate.wires_coeffs()) { builder.connect(t, Target::wire(gate_index, c)); } acc = ExtensionTarget::from_range(gate_index, ReducingGate::::wires_output()); } acc } pub fn reduce( &mut self, terms: &[ExtensionTarget], // Could probably work with a `DoubleEndedIterator` too. builder: &mut CircuitBuilder, ) -> ExtensionTarget where F: RichField + Extendable, { let l = terms.len(); self.count += l as u64; let mut terms_vec = terms.to_vec(); let mut acc = builder.zero_extension(); terms_vec.reverse(); for x in terms_vec { acc = builder.mul_add_extension(self.base, acc, x); } acc } pub fn shift( &mut self, x: ExtensionTarget, builder: &mut CircuitBuilder, ) -> ExtensionTarget where F: RichField + Extendable, { let exp = builder.exp_u64_extension(self.base, self.count); self.count = 0; builder.mul_extension(exp, x) } pub fn reset(&mut self) { self.count = 0; } } #[cfg(test)] mod tests { use anyhow::Result; use super::*; use crate::field::crandall_field::CrandallField; use crate::field::extension_field::quartic::QuarticExtension; use crate::iop::witness::PartialWitness; use crate::plonk::circuit_data::CircuitConfig; use crate::plonk::verifier::verify; fn test_reduce_gadget_base(n: usize) -> Result<()> { type F = CrandallField; type FF = QuarticExtension; const D: usize = 4; let config = CircuitConfig::large_config(); let pw = PartialWitness::new(); let mut builder = CircuitBuilder::::new(config); let alpha = FF::rand(); let vs = F::rand_vec(n); let manual_reduce = ReducingFactor::new(alpha).reduce(vs.iter().map(|&v| FF::from(v))); let manual_reduce = builder.constant_extension(manual_reduce); let mut alpha_t = ReducingFactorTarget::new(builder.constant_extension(alpha)); let vs_t = vs.iter().map(|&v| builder.constant(v)).collect::>(); let circuit_reduce = alpha_t.reduce_base(&vs_t, &mut builder); builder.connect_extension(manual_reduce, circuit_reduce); let data = builder.build(); let proof = data.prove(pw)?; verify(proof, &data.verifier_only, &data.common) } fn test_reduce_gadget(n: usize) -> Result<()> { type F = CrandallField; type FF = QuarticExtension; const D: usize = 4; let config = CircuitConfig::large_config(); let pw = PartialWitness::new(); let mut builder = CircuitBuilder::::new(config); let alpha = FF::rand(); let vs = (0..n).map(FF::from_canonical_usize).collect::>(); let manual_reduce = ReducingFactor::new(alpha).reduce(vs.iter()); let manual_reduce = builder.constant_extension(manual_reduce); let mut alpha_t = ReducingFactorTarget::new(builder.constant_extension(alpha)); let vs_t = vs .iter() .map(|&v| builder.constant_extension(v)) .collect::>(); let circuit_reduce = alpha_t.reduce(&vs_t, &mut builder); builder.connect_extension(manual_reduce, circuit_reduce); let data = builder.build(); let proof = data.prove(pw)?; verify(proof, &data.verifier_only, &data.common) } #[test] fn test_reduce_gadget_even() -> Result<()> { test_reduce_gadget(10) } #[test] fn test_reduce_gadget_odd() -> Result<()> { test_reduce_gadget(11) } #[test] fn test_reduce_gadget_base_100() -> Result<()> { test_reduce_gadget_base(100) } }