diff --git a/src/gates/mod.rs b/src/gates/mod.rs index dcb96b96..24cf8159 100644 --- a/src/gates/mod.rs +++ b/src/gates/mod.rs @@ -14,6 +14,7 @@ pub(crate) mod noop; pub(crate) mod public_input; pub mod random_access; pub mod reducing; +pub mod reducing_ext; #[cfg(test)] mod gate_testing; diff --git a/src/gates/reducing_ext.rs b/src/gates/reducing_ext.rs new file mode 100644 index 00000000..fc83d17d --- /dev/null +++ b/src/gates/reducing_ext.rs @@ -0,0 +1,217 @@ +use std::ops::Range; + +use crate::field::extension_field::target::ExtensionTarget; +use crate::field::extension_field::Extendable; +use crate::field::extension_field::FieldExtension; +use crate::gates::gate::Gate; +use crate::iop::generator::{GeneratedValues, SimpleGenerator, WitnessGenerator}; +use crate::iop::target::Target; +use crate::iop::witness::PartialWitness; +use crate::plonk::circuit_builder::CircuitBuilder; +use crate::plonk::vars::{EvaluationTargets, EvaluationVars, EvaluationVarsBase}; + +/// Computes `sum alpha^i c_i` for a vector `c_i` of `num_coeffs` elements of the extension field. +#[derive(Debug, Clone)] +pub struct ReducingExtGate { + pub num_coeffs: usize, +} + +impl ReducingExtGate { + pub fn new(num_coeffs: usize) -> Self { + Self { num_coeffs } + } + + pub fn max_coeffs_len(num_wires: usize, num_routed_wires: usize) -> usize { + ((num_routed_wires - 3 * D) / D).min((num_wires - 2 * D) / (D * 2)) + } + + pub fn wires_output() -> Range { + 0..D + } + pub fn wires_alpha() -> Range { + D..2 * D + } + pub fn wires_old_acc() -> Range { + 2 * D..3 * D + } + const START_COEFFS: usize = 3 * D; + pub fn wires_coeff(i: usize) -> Range { + Self::START_COEFFS + i * D..Self::START_COEFFS + (i + 1) * D + } + fn start_accs(&self) -> usize { + Self::START_COEFFS + self.num_coeffs * D + } + fn wires_accs(&self, i: usize) -> Range { + if i == self.num_coeffs - 1 { + // The last accumulator is the output. + return Self::wires_output(); + } + self.start_accs() + D * i..self.start_accs() + D * (i + 1) + } +} + +impl, const D: usize> Gate for ReducingExtGate { + fn id(&self) -> String { + format!("{:?}", self) + } + + fn eval_unfiltered(&self, vars: EvaluationVars) -> Vec { + let alpha = vars.get_local_ext_algebra(Self::wires_alpha()); + let old_acc = vars.get_local_ext_algebra(Self::wires_old_acc()); + let coeffs = (0..self.num_coeffs) + .map(|i| vars.get_local_ext_algebra(Self::wires_coeff(i))) + .collect::>(); + let accs = (0..self.num_coeffs) + .map(|i| vars.get_local_ext_algebra(self.wires_accs(i))) + .collect::>(); + + let mut constraints = Vec::with_capacity(>::num_constraints(self)); + let mut acc = old_acc; + for i in 0..self.num_coeffs { + constraints.push(acc * alpha + coeffs[i] - accs[i]); + acc = accs[i]; + } + + constraints + .into_iter() + .flat_map(|alg| alg.to_basefield_array()) + .collect() + } + + fn eval_unfiltered_base(&self, vars: EvaluationVarsBase) -> Vec { + let alpha = vars.get_local_ext(Self::wires_alpha()); + let old_acc = vars.get_local_ext(Self::wires_old_acc()); + let coeffs = (0..self.num_coeffs) + .map(|i| vars.get_local_ext(Self::wires_coeff(i))) + .collect::>(); + let accs = (0..self.num_coeffs) + .map(|i| vars.get_local_ext(self.wires_accs(i))) + .collect::>(); + + let mut constraints = Vec::with_capacity(>::num_constraints(self)); + let mut acc = old_acc; + for i in 0..self.num_coeffs { + constraints.extend((acc * alpha + coeffs[i] - accs[i]).to_basefield_array()); + acc = accs[i]; + } + + constraints + } + + fn eval_unfiltered_recursively( + &self, + builder: &mut CircuitBuilder, + vars: EvaluationTargets, + ) -> Vec> { + let alpha = vars.get_local_ext_algebra(Self::wires_alpha()); + let old_acc = vars.get_local_ext_algebra(Self::wires_old_acc()); + let coeffs = (0..self.num_coeffs) + .map(|i| vars.get_local_ext_algebra(Self::wires_coeff(i))) + .collect::>(); + let accs = (0..self.num_coeffs) + .map(|i| vars.get_local_ext_algebra(self.wires_accs(i))) + .collect::>(); + + let mut constraints = Vec::with_capacity(>::num_constraints(self)); + let mut acc = old_acc; + for i in 0..self.num_coeffs { + let coeff = coeffs[i]; + let mut tmp = builder.mul_add_ext_algebra(acc, alpha, coeff); + tmp = builder.sub_ext_algebra(tmp, accs[i]); + constraints.push(tmp); + acc = accs[i]; + } + + constraints + .into_iter() + .flat_map(|alg| alg.to_ext_target_array()) + .collect() + } + + fn generators( + &self, + gate_index: usize, + _local_constants: &[F], + ) -> Vec>> { + vec![Box::new(ReducingGenerator { + gate_index, + gate: self.clone(), + })] + } + + fn num_wires(&self) -> usize { + 2 * D + 2 * D * self.num_coeffs + } + + fn num_constants(&self) -> usize { + 0 + } + + fn degree(&self) -> usize { + 2 + } + + fn num_constraints(&self) -> usize { + D * self.num_coeffs + } +} + +struct ReducingGenerator { + gate_index: usize, + gate: ReducingExtGate, +} + +impl, const D: usize> SimpleGenerator for ReducingGenerator { + fn dependencies(&self) -> Vec { + ReducingExtGate::::wires_alpha() + .chain(ReducingExtGate::::wires_old_acc()) + .chain((0..self.gate.num_coeffs).flat_map(|i| ReducingExtGate::::wires_coeff(i))) + .map(|i| Target::wire(self.gate_index, i)) + .collect() + } + + fn run_once(&self, witness: &PartialWitness, out_buffer: &mut GeneratedValues) { + let extract_extension = |range: Range| -> F::Extension { + let t = ExtensionTarget::from_range(self.gate_index, range); + witness.get_extension_target(t) + }; + + let alpha = extract_extension(ReducingExtGate::::wires_alpha()); + let old_acc = extract_extension(ReducingExtGate::::wires_old_acc()); + let coeffs = (0..self.gate.num_coeffs) + .map(|i| extract_extension(ReducingExtGate::::wires_coeff(i))) + .collect::>(); + let accs = (0..self.gate.num_coeffs) + .map(|i| ExtensionTarget::from_range(self.gate_index, self.gate.wires_accs(i))) + .collect::>(); + let output = + ExtensionTarget::from_range(self.gate_index, ReducingExtGate::::wires_output()); + + let mut acc = old_acc; + for i in 0..self.gate.num_coeffs { + let computed_acc = acc * alpha + coeffs[i]; + out_buffer.set_extension_target(accs[i], computed_acc); + acc = computed_acc; + } + out_buffer.set_extension_target(output, acc); + } +} + +#[cfg(test)] +mod tests { + use anyhow::Result; + + use crate::field::crandall_field::CrandallField; + use crate::gates::gate_testing::{test_eval_fns, test_low_degree}; + use crate::gates::reducing_ext::ReducingExtGate; + + #[test] + fn low_degree() { + test_low_degree::(ReducingExtGate::new(22)); + } + + #[test] + fn eval_fns() -> Result<()> { + test_eval_fns::(ReducingExtGate::new(22)) + } +} diff --git a/src/plonk/circuit_builder.rs b/src/plonk/circuit_builder.rs index bae4b5d6..955fb98b 100644 --- a/src/plonk/circuit_builder.rs +++ b/src/plonk/circuit_builder.rs @@ -39,7 +39,7 @@ pub struct CircuitBuilder, const D: usize> { gates: HashSet>, /// The concrete placement of each gate. - gate_instances: Vec>, + pub gate_instances: Vec>, /// Targets to be made public. public_inputs: Vec, diff --git a/src/plonk/circuit_data.rs b/src/plonk/circuit_data.rs index 59aaa485..861bf1b0 100644 --- a/src/plonk/circuit_data.rs +++ b/src/plonk/circuit_data.rs @@ -62,7 +62,7 @@ impl CircuitConfig { pub(crate) fn large_config() -> Self { Self { num_wires: 126, - num_routed_wires: 33, + num_routed_wires: 64, security_bits: 128, rate_bits: 3, num_challenges: 3, diff --git a/src/util/reducing.rs b/src/util/reducing.rs index 459454db..7e19f9d7 100644 --- a/src/util/reducing.rs +++ b/src/util/reducing.rs @@ -7,6 +7,7 @@ use crate::field::extension_field::{Extendable, Frobenius}; use crate::field::field_types::Field; use crate::gates::arithmetic::ArithmeticExtensionGate; use crate::gates::reducing::ReducingGate; +use crate::gates::reducing_ext::ReducingExtGate; use crate::iop::target::Target; use crate::plonk::circuit_builder::CircuitBuilder; use crate::polynomial::polynomial::PolynomialCoeffs; @@ -164,54 +165,89 @@ impl ReducingFactorTarget { where F: Extendable, { - let zero = builder.zero_extension(); - let l = terms.len(); - self.count += l as u64; - - let mut terms_vec = terms.to_vec(); - // If needed, we pad the original vector so that it has even length. - if terms_vec.len().is_odd() { - terms_vec.push(zero); + let max_coeffs_len = ReducingExtGate::::max_coeffs_len( + builder.config.num_wires, + builder.config.num_routed_wires, + ); + self.count += terms.len() as u64; + let zero_ext = builder.zero_extension(); + let mut acc = zero_ext; + let mut reversed_terms = terms.to_vec(); + while reversed_terms.len() % max_coeffs_len != 0 { + reversed_terms.push(zero_ext); } - terms_vec.reverse(); + reversed_terms.reverse(); + for chunk in reversed_terms.chunks_exact(max_coeffs_len) { + let gate = ReducingExtGate::new(max_coeffs_len); + let gate_index = builder.add_gate(gate.clone(), Vec::new()); - let mut acc = zero; - for pair in terms_vec.chunks(2) { - // We will route the output of the first arithmetic operation to the multiplicand of the - // second, i.e. we compute the following: - // out_0 = alpha acc + pair[0] - // acc' = out_1 = alpha out_0 + pair[1] + builder.route_extension( + self.base, + ExtensionTarget::from_range(gate_index, ReducingGate::::wires_alpha()), + ); + builder.route_extension( + acc, + ExtensionTarget::from_range(gate_index, ReducingGate::::wires_old_acc()), + ); + for (i, &t) in chunk.iter().enumerate() { + builder.route_extension( + t, + ExtensionTarget::from_range(gate_index, ReducingExtGate::::wires_coeff(i)), + ); + } - let (gate, range) = if let Some((g, c_0, c_1)) = builder.free_arithmetic { - if c_0 == F::ONE && c_1 == F::ONE { - (g, ArithmeticExtensionGate::::wires_third_output()) - } else { - ( - builder.num_gates(), - ArithmeticExtensionGate::::wires_first_output(), - ) - } - } else { - ( - builder.num_gates(), - ArithmeticExtensionGate::::wires_first_output(), - ) - }; - let out_0 = ExtensionTarget::from_range(gate, range); - acc = builder - .double_arithmetic_extension( - F::ONE, - F::ONE, - self.base, - acc, - pair[0], - self.base, - out_0, - pair[1], - ) - .1; + acc = ExtensionTarget::from_range(gate_index, ReducingGate::::wires_output()); } + acc + // let zero = builder.zero_extension(); + // let l = terms.len(); + // self.count += l as u64; + // + // let mut terms_vec = terms.to_vec(); + // // If needed, we pad the original vector so that it has even length. + // if terms_vec.len().is_odd() { + // terms_vec.push(zero); + // } + // terms_vec.reverse(); + // + // let mut acc = zero; + // for pair in terms_vec.chunks(2) { + // // We will route the output of the first arithmetic operation to the multiplicand of the + // // second, i.e. we compute the following: + // // out_0 = alpha acc + pair[0] + // // acc' = out_1 = alpha out_0 + pair[1] + // + // let (gate, range) = if let Some((g, c_0, c_1)) = builder.free_arithmetic { + // if c_0 == F::ONE && c_1 == F::ONE { + // (g, ArithmeticExtensionGate::::wires_third_output()) + // } else { + // ( + // builder.num_gates(), + // ArithmeticExtensionGate::::wires_first_output(), + // ) + // } + // } else { + // ( + // builder.num_gates(), + // ArithmeticExtensionGate::::wires_first_output(), + // ) + // }; + // let out_0 = ExtensionTarget::from_range(gate, range); + // acc = builder + // .double_arithmetic_extension( + // F::ONE, + // F::ONE, + // self.base, + // acc, + // pair[0], + // self.base, + // out_0, + // pair[1], + // ) + // .1; + // } + // acc } pub fn shift( @@ -301,7 +337,10 @@ mod tests { type FF = QuarticCrandallField; const D: usize = 4; - let config = CircuitConfig::large_config(); + let config = CircuitConfig { + num_routed_wires: 64, + ..CircuitConfig::large_config() + }; let pw = PartialWitness::new(config.num_wires); let mut builder = CircuitBuilder::::new(config); @@ -321,6 +360,9 @@ mod tests { builder.assert_equal_extension(manual_reduce, circuit_reduce); + for g in &builder.gate_instances { + println!("{}", g.gate_ref.0.id()); + } let data = builder.build(); let proof = data.prove(pw)?; @@ -332,6 +374,11 @@ mod tests { test_reduce_gadget(10) } + #[test] + fn test_yo() -> Result<()> { + test_reduce_gadget(100) + } + #[test] fn test_reduce_gadget_odd() -> Result<()> { test_reduce_gadget(11)