From 671bb9be2e1e84b22b9a8296ffdc44c22b9a0316 Mon Sep 17 00:00:00 2001 From: Daniel Lubarov Date: Fri, 5 Nov 2021 09:29:08 -0700 Subject: [PATCH] Specialize `InterpolationGate` (#339) * Specialize `InterpolationGate` To cosets of subgroups of roots of unity. This way - `InterpolationGate` needs fewer routed wires, bringing our minimum routed wires down from 28 to 25. - The recursive `compute_evaluation` avoids some multiplications, saving 100~200 gates depending on `num_routed_wires`. * Update test * feedback --- src/fri/recursive_verifier.rs | 13 +-- src/gadgets/interpolation.rs | 38 ++++---- src/gates/gate_tree.rs | 2 +- src/gates/interpolation.rs | 157 +++++++++++++++++++--------------- src/plonk/circuit_data.rs | 2 +- 5 files changed, 117 insertions(+), 95 deletions(-) diff --git a/src/fri/recursive_verifier.rs b/src/fri/recursive_verifier.rs index aae0963c..fc0a7341 100644 --- a/src/fri/recursive_verifier.rs +++ b/src/fri/recursive_verifier.rs @@ -43,16 +43,7 @@ impl, const D: usize> CircuitBuilder { let coset_start = self.mul(start, x); // The answer is gotten by interpolating {(x*g^i, P(x*g^i))} and evaluating at beta. - let points = g - .powers() - .map(|y| { - let yc = self.constant(y); - self.mul(coset_start, yc) - }) - .zip(evals) - .collect::>(); - - self.interpolate(&points, beta) + self.interpolate_coset(arity_bits, coset_start, &evals, beta) } /// Make sure we have enough wires and routed wires to do the FRI checks efficiently. This check @@ -63,7 +54,7 @@ impl, const D: usize> CircuitBuilder { &self.config, max_fri_arity.max(1 << self.config.cap_height), ); - let interpolation_gate = InterpolationGate::::new(max_fri_arity); + let interpolation_gate = InterpolationGate::::new(log2_strict(max_fri_arity)); let min_wires = random_access .num_wires() diff --git a/src/gadgets/interpolation.rs b/src/gadgets/interpolation.rs index f72ed0fc..09b6329b 100644 --- a/src/gadgets/interpolation.rs +++ b/src/gadgets/interpolation.rs @@ -6,17 +6,20 @@ use crate::iop::target::Target; use crate::plonk::circuit_builder::CircuitBuilder; impl, const D: usize> CircuitBuilder { - /// Interpolate a list of point/evaluation pairs at a given point. - /// Returns the evaluation of the interpolated polynomial at `evaluation_point`. - pub fn interpolate( + /// Interpolates a polynomial, whose points are a coset of the multiplicative subgroup with the + /// given size, and whose values are given. Returns the evaluation of the interpolant at + /// `evaluation_point`. + pub fn interpolate_coset( &mut self, - interpolation_points: &[(Target, ExtensionTarget)], + subgroup_bits: usize, + coset_shift: Target, + values: &[ExtensionTarget], evaluation_point: ExtensionTarget, ) -> ExtensionTarget { - let gate = InterpolationGate::new(interpolation_points.len()); + let gate = InterpolationGate::new(subgroup_bits); let gate_index = self.add_gate(gate.clone(), vec![]); - for (i, &(p, v)) in interpolation_points.iter().enumerate() { - self.connect(p, Target::wire(gate_index, gate.wire_point(i))); + self.connect(coset_shift, Target::wire(gate_index, gate.wire_shift())); + for (i, &v) in values.iter().enumerate() { self.connect_extension( v, ExtensionTarget::from_range(gate_index, gate.wires_value(i)), @@ -53,14 +56,17 @@ mod tests { let pw = PartialWitness::new(); let mut builder = CircuitBuilder::::new(config); - let len = 4; - let points = (0..len) - .map(|_| (F::rand(), FF::rand())) - .collect::>(); + let subgroup_bits = 2; + let len = 1 << subgroup_bits; + let coset_shift = F::rand(); + let g = F::primitive_root_of_unity(subgroup_bits); + let points = F::cyclic_subgroup_coset_known_order(g, coset_shift, len); + let values = FF::rand_vec(len); let homogeneous_points = points .iter() - .map(|&(a, b)| (>::from_basefield(a), b)) + .zip(values.iter()) + .map(|(&a, &b)| (>::from_basefield(a), b)) .collect::>(); let true_interpolant = interpolant(&homogeneous_points); @@ -68,14 +74,16 @@ mod tests { let z = FF::rand(); let true_eval = true_interpolant.eval(z); - let points_target = points + let coset_shift_target = builder.constant(coset_shift); + + let value_targets = values .iter() - .map(|&(p, v)| (builder.constant(p), builder.constant_extension(v))) + .map(|&v| (builder.constant_extension(v))) .collect::>(); let zt = builder.constant_extension(z); - let eval = builder.interpolate(&points_target, zt); + let eval = builder.interpolate_coset(subgroup_bits, coset_shift_target, &value_targets, zt); let true_eval_target = builder.constant_extension(true_eval); builder.connect_extension(eval, true_eval_target); diff --git a/src/gates/gate_tree.rs b/src/gates/gate_tree.rs index 83a7e2fe..ed9a73ac 100644 --- a/src/gates/gate_tree.rs +++ b/src/gates/gate_tree.rs @@ -242,7 +242,7 @@ mod tests { GateRef::new(ArithmeticExtensionGate { num_ops: 4 }), GateRef::new(BaseSumGate::<4>::new(4)), GateRef::new(GMiMCGate::::new()), - GateRef::new(InterpolationGate::new(4)), + GateRef::new(InterpolationGate::new(2)), ]; let (tree, _, _) = Tree::from_gates(gates.clone()); diff --git a/src/gates/interpolation.rs b/src/gates/interpolation.rs index 24b755d0..52dca440 100644 --- a/src/gates/interpolation.rs +++ b/src/gates/interpolation.rs @@ -17,48 +17,45 @@ use crate::plonk::circuit_builder::CircuitBuilder; use crate::plonk::vars::{EvaluationTargets, EvaluationVars, EvaluationVarsBase}; use crate::polynomial::polynomial::PolynomialCoeffs; -/// Evaluates the interpolant of some given elements from a field extension. -/// -/// In particular, this gate takes as inputs `num_points` points, `num_points` values, and the point -/// to evaluate the interpolant at. It computes the interpolant and outputs its evaluation at the -/// given point. +/// Interpolates a polynomial, whose points are a (base field) coset of the multiplicative subgroup +/// with the given size, and whose values are extension field elements, given by input wires. +/// Outputs the evaluation of the interpolant at a given (extension field) evaluation point. #[derive(Clone, Debug)] pub(crate) struct InterpolationGate, const D: usize> { - pub num_points: usize, + pub subgroup_bits: usize, _phantom: PhantomData, } impl, const D: usize> InterpolationGate { - pub fn new(num_points: usize) -> Self { + pub fn new(subgroup_bits: usize) -> Self { Self { - num_points, + subgroup_bits, _phantom: PhantomData, } } - fn start_points(&self) -> usize { + fn num_points(&self) -> usize { + 1 << self.subgroup_bits + } + + /// Wire index of the coset shift. + pub fn wire_shift(&self) -> usize { 0 } - /// Wire indices of the `i`th interpolant point. - pub fn wire_point(&self, i: usize) -> usize { - debug_assert!(i < self.num_points); - self.start_points() + i - } - fn start_values(&self) -> usize { - self.start_points() + self.num_points + 1 } /// Wire indices of the `i`th interpolant value. pub fn wires_value(&self, i: usize) -> Range { - debug_assert!(i < self.num_points); + debug_assert!(i < self.num_points()); let start = self.start_values() + i * D; start..start + D } fn start_evaluation_point(&self) -> usize { - self.start_values() + self.num_points * D + self.start_values() + self.num_points() * D } /// Wire indices of the point to evaluate the interpolant at. @@ -89,14 +86,46 @@ impl, const D: usize> InterpolationGate { /// Wire indices of the interpolant's `i`th coefficient. pub fn wires_coeff(&self, i: usize) -> Range { - debug_assert!(i < self.num_points); + debug_assert!(i < self.num_points()); let start = self.start_coeffs() + i * D; start..start + D } /// End of wire indices, exclusive. fn end(&self) -> usize { - self.start_coeffs() + self.num_points * D + self.start_coeffs() + self.num_points() * D + } + + /// The domain of the points we're interpolating. + fn coset(&self, shift: F) -> impl Iterator { + let g = F::primitive_root_of_unity(self.subgroup_bits); + let size = 1 << self.subgroup_bits; + // Speed matters here, so we avoid `cyclic_subgroup_coset_known_order` which allocates. + g.powers().take(size).map(move |x| x * shift) + } + + /// The domain of the points we're interpolating. + fn coset_ext(&self, shift: F::Extension) -> impl Iterator { + let g = F::primitive_root_of_unity(self.subgroup_bits); + let size = 1 << self.subgroup_bits; + g.powers().take(size).map(move |x| shift.scalar_mul(x)) + } + + /// The domain of the points we're interpolating. + fn coset_ext_recursive( + &self, + builder: &mut CircuitBuilder, + shift: ExtensionTarget, + ) -> Vec> { + let g = F::primitive_root_of_unity(self.subgroup_bits); + let size = 1 << self.subgroup_bits; + g.powers() + .take(size) + .map(move |x| { + let subgroup_element = builder.constant(x.into()); + builder.scalar_mul_ext(subgroup_element, shift) + }) + .collect() } } @@ -108,13 +137,13 @@ impl, const D: usize> Gate for InterpolationG fn eval_unfiltered(&self, vars: EvaluationVars) -> Vec { let mut constraints = Vec::with_capacity(self.num_constraints()); - let coeffs = (0..self.num_points) + let coeffs = (0..self.num_points()) .map(|i| vars.get_local_ext_algebra(self.wires_coeff(i))) .collect(); let interpolant = PolynomialCoeffsAlgebra::new(coeffs); - for i in 0..self.num_points { - let point = vars.local_wires[self.wire_point(i)]; + let coset = self.coset_ext(vars.local_wires[self.wire_shift()]); + for (i, point) in coset.into_iter().enumerate() { let value = vars.get_local_ext_algebra(self.wires_value(i)); let computed_value = interpolant.eval_base(point); constraints.extend(&(value - computed_value).to_basefield_array()); @@ -131,13 +160,13 @@ impl, const D: usize> Gate for InterpolationG fn eval_unfiltered_base(&self, vars: EvaluationVarsBase) -> Vec { let mut constraints = Vec::with_capacity(self.num_constraints()); - let coeffs = (0..self.num_points) + let coeffs = (0..self.num_points()) .map(|i| vars.get_local_ext(self.wires_coeff(i))) .collect(); let interpolant = PolynomialCoeffs::new(coeffs); - for i in 0..self.num_points { - let point = vars.local_wires[self.wire_point(i)]; + let coset = self.coset(vars.local_wires[self.wire_shift()]); + for (i, point) in coset.into_iter().enumerate() { let value = vars.get_local_ext(self.wires_value(i)); let computed_value = interpolant.eval_base(point); constraints.extend(&(value - computed_value).to_basefield_array()); @@ -158,13 +187,13 @@ impl, const D: usize> Gate for InterpolationG ) -> Vec> { let mut constraints = Vec::with_capacity(self.num_constraints()); - let coeffs = (0..self.num_points) + let coeffs = (0..self.num_points()) .map(|i| vars.get_local_ext_algebra(self.wires_coeff(i))) .collect(); let interpolant = PolynomialCoeffsExtAlgebraTarget(coeffs); - for i in 0..self.num_points { - let point = vars.local_wires[self.wire_point(i)]; + let coset = self.coset_ext_recursive(builder, vars.local_wires[self.wire_shift()]); + for (i, point) in coset.into_iter().enumerate() { let value = vars.get_local_ext_algebra(self.wires_value(i)); let computed_value = interpolant.eval_scalar(builder, point); constraints.extend( @@ -210,13 +239,13 @@ impl, const D: usize> Gate for InterpolationG fn degree(&self) -> usize { // The highest power of x is `num_points - 1`, and then multiplication by the coefficient // adds 1. - self.num_points + self.num_points() } fn num_constraints(&self) -> usize { // num_points * D constraints to check for consistency between the coefficients and the // point-value pairs, plus D constraints for the evaluation value. - self.num_points * D + D + self.num_points() * D + D } } @@ -240,18 +269,18 @@ impl, const D: usize> SimpleGenerator let local_targets = |inputs: Range| inputs.map(local_target); - let mut deps = Vec::new(); + let num_points = self.gate.num_points(); + let mut deps = Vec::with_capacity(1 + D + num_points * D); + + deps.push(local_target(self.gate.wire_shift())); deps.extend(local_targets(self.gate.wires_evaluation_point())); - for i in 0..self.gate.num_points { - deps.push(local_target(self.gate.wire_point(i))); + for i in 0..num_points { deps.extend(local_targets(self.gate.wires_value(i))); } deps } fn run_once(&self, witness: &PartitionWitness, out_buffer: &mut GeneratedValues) { - let n = self.gate.num_points; - let local_wire = |input| Wire { gate: self.gate_index, input, @@ -267,13 +296,11 @@ impl, const D: usize> SimpleGenerator }; // Compute the interpolant. - let points = (0..n) - .map(|i| { - ( - F::Extension::from_basefield(get_local_wire(self.gate.wire_point(i))), - get_local_ext(self.gate.wires_value(i)), - ) - }) + let points = self.gate.coset(get_local_wire(self.gate.wire_shift())); + let points = points + .into_iter() + .enumerate() + .map(|(i, point)| (point.into(), get_local_ext(self.gate.wires_value(i)))) .collect::>(); let interpolant = interpolant(&points); @@ -308,31 +335,30 @@ mod tests { #[test] fn wire_indices() { let gate = InterpolationGate:: { - num_points: 2, + subgroup_bits: 1, _phantom: PhantomData, }; // The exact indices aren't really important, but we want to make sure we don't have any // overlaps or gaps. - assert_eq!(gate.wire_point(0), 0); - assert_eq!(gate.wire_point(1), 1); - assert_eq!(gate.wires_value(0), 2..6); - assert_eq!(gate.wires_value(1), 6..10); - assert_eq!(gate.wires_evaluation_point(), 10..14); - assert_eq!(gate.wires_evaluation_value(), 14..18); - assert_eq!(gate.wires_coeff(0), 18..22); - assert_eq!(gate.wires_coeff(1), 22..26); - assert_eq!(gate.num_wires(), 26); + assert_eq!(gate.wire_shift(), 0); + assert_eq!(gate.wires_value(0), 1..5); + assert_eq!(gate.wires_value(1), 5..9); + assert_eq!(gate.wires_evaluation_point(), 9..13); + assert_eq!(gate.wires_evaluation_value(), 13..17); + assert_eq!(gate.wires_coeff(0), 17..21); + assert_eq!(gate.wires_coeff(1), 21..25); + assert_eq!(gate.num_wires(), 25); } #[test] fn low_degree() { - test_low_degree::(InterpolationGate::new(4)); + test_low_degree::(InterpolationGate::new(2)); } #[test] fn eval_fns() -> Result<()> { - test_eval_fns::(InterpolationGate::new(4)) + test_eval_fns::(InterpolationGate::new(2)) } #[test] @@ -343,15 +369,15 @@ mod tests { /// Returns the local wires for an interpolation gate for given coeffs, points and eval point. fn get_wires( - num_points: usize, + gate: &InterpolationGate, + shift: F, coeffs: PolynomialCoeffs, - points: Vec, eval_point: FF, ) -> Vec { - let mut v = Vec::new(); - v.extend_from_slice(&points); - for j in 0..num_points { - v.extend(coeffs.eval(points[j].into()).0); + let points = gate.coset(shift); + let mut v = vec![shift]; + for x in points { + v.extend(coeffs.eval(x.into()).0); } v.extend(eval_point.0); v.extend(coeffs.eval(eval_point).0); @@ -362,16 +388,13 @@ mod tests { } // Get a working row for InterpolationGate. + let shift = F::rand(); let coeffs = PolynomialCoeffs::new(vec![FF::rand(), FF::rand()]); - let points = vec![F::rand(), F::rand()]; let eval_point = FF::rand(); - let gate = InterpolationGate:: { - num_points: 2, - _phantom: PhantomData, - }; + let gate = InterpolationGate::::new(1); let vars = EvaluationVars { local_constants: &[], - local_wires: &get_wires(2, coeffs, points, eval_point), + local_wires: &get_wires(&gate, shift, coeffs, eval_point), public_inputs_hash: &HashOut::rand(), }; diff --git a/src/plonk/circuit_data.rs b/src/plonk/circuit_data.rs index 2f98d0a6..9ba05a87 100644 --- a/src/plonk/circuit_data.rs +++ b/src/plonk/circuit_data.rs @@ -53,7 +53,7 @@ impl CircuitConfig { pub(crate) fn standard_recursion_config() -> Self { Self { num_wires: 143, - num_routed_wires: 28, + num_routed_wires: 25, constant_gate_size: 6, security_bits: 100, rate_bits: 3,