diff --git a/src/fri/recursive_verifier.rs b/src/fri/recursive_verifier.rs index 165bed41..18c696da 100644 --- a/src/fri/recursive_verifier.rs +++ b/src/fri/recursive_verifier.rs @@ -34,27 +34,17 @@ impl, const D: usize> CircuitBuilder { let g = F::primitive_root_of_unity(arity_bits); let g_inv = g.exp_u64((arity as u64) - 1); - let g_inv_t = self.constant(g_inv); // The evaluation vector needs to be reordered first. let mut evals = evals.to_vec(); reverse_index_bits_in_place(&mut evals); // Want `g^(arity - rev_x_index_within_coset)` as in the out-of-circuit version. Compute it // as `(g^-1)^rev_x_index_within_coset`. - let start = self.exp_from_bits(g_inv_t, x_index_within_coset_bits.iter().rev()); + let start = self.exp_from_bits_const_base(g_inv, x_index_within_coset_bits.iter().rev()); let coset_start = self.mul(start, x); // The answer is gotten by interpolating {(x*g^i, P(x*g^i))} and evaluating at beta. - let points = g - .powers() - .map(|y| { - let yc = self.constant(y); - self.mul(coset_start, yc) - }) - .zip(evals) - .collect::>(); - - self.interpolate(&points, beta) + self.interpolate_coset(arity_bits, coset_start, &evals, beta) } /// Make sure we have enough wires and routed wires to do the FRI checks efficiently. This check @@ -65,7 +55,7 @@ impl, const D: usize> CircuitBuilder { &self.config, max_fri_arity.max(1 << self.config.cap_height), ); - let interpolation_gate = InterpolationGate::::new(max_fri_arity); + let interpolation_gate = InterpolationGate::::new(log2_strict(max_fri_arity)); let min_wires = random_access .num_wires() @@ -335,9 +325,8 @@ impl, const D: usize> CircuitBuilder { // `subgroup_x` is `subgroup[x_index]`, i.e., the actual field element in the domain. let (mut subgroup_x, vanish_zeta) = with_context!(self, "compute x from its index", { let g = self.constant(F::coset_shift()); - let phi = self.constant(F::primitive_root_of_unity(n_log)); - - let phi = self.exp_from_bits(phi, x_index_bits.iter().rev()); + let phi = F::primitive_root_of_unity(n_log); + let phi = self.exp_from_bits_const_base(phi, x_index_bits.iter().rev()); let g_ext = self.convert_to_ext(g); let phi_ext = self.convert_to_ext(phi); // `subgroup_x = g*phi, vanish_zeta = g*phi - zeta` diff --git a/src/gadgets/arithmetic.rs b/src/gadgets/arithmetic.rs index c043afdd..74cc890d 100644 --- a/src/gadgets/arithmetic.rs +++ b/src/gadgets/arithmetic.rs @@ -1,6 +1,7 @@ use std::borrow::Borrow; use crate::field::extension_field::Extendable; +use crate::gates::arithmetic::ArithmeticExtensionGate; use crate::gates::exponentiation::ExponentiationGate; use crate::iop::target::{BoolTarget, Target}; use crate::plonk::circuit_builder::CircuitBuilder; @@ -114,7 +115,16 @@ impl, const D: usize> CircuitBuilder { /// Exponentiate `base` to the power of `2^power_log`. pub fn exp_power_of_2(&mut self, base: Target, power_log: usize) -> Target { - self.exp_u64(base, 1 << power_log) + if power_log > ArithmeticExtensionGate::::new_from_config(&self.config).num_ops { + // Cheaper to just use `ExponentiateGate`. + return self.exp_u64(base, 1 << power_log); + } + + let mut product = base; + for _ in 0..power_log { + product = self.square(product); + } + product } // TODO: Test @@ -150,6 +160,39 @@ impl, const D: usize> CircuitBuilder { self.exp_from_bits(base, exponent_bits.iter()) } + /// Like `exp_from_bits` but with a constant base. + pub fn exp_from_bits_const_base( + &mut self, + base: F, + exponent_bits: impl IntoIterator>, + ) -> Target { + let base_t = self.constant(base); + let exponent_bits: Vec<_> = exponent_bits.into_iter().map(|b| *b.borrow()).collect(); + + if exponent_bits.len() > ArithmeticExtensionGate::::new_from_config(&self.config).num_ops + { + // Cheaper to just use `ExponentiateGate`. + return self.exp_from_bits(base_t, exponent_bits); + } + + let mut product = self.one(); + for (i, bit) in exponent_bits.iter().enumerate() { + let pow = 1 << i; + // If the bit is on, we multiply product by base^pow. + // We can arithmetize this as: + // product *= 1 + bit (base^pow - 1) + // product = (base^pow - 1) product bit + product + product = self.arithmetic( + base.exp_u64(pow as u64) - F::ONE, + F::ONE, + product, + bit.target, + product, + ) + } + product + } + /// Exponentiate `base` to the power of a known `exponent`. // TODO: Test pub fn exp_u64(&mut self, base: Target, mut exponent: u64) -> Target { diff --git a/src/gadgets/interpolation.rs b/src/gadgets/interpolation.rs index 727399c6..4206e810 100644 --- a/src/gadgets/interpolation.rs +++ b/src/gadgets/interpolation.rs @@ -5,17 +5,20 @@ use crate::iop::target::Target; use crate::plonk::circuit_builder::CircuitBuilder; impl, const D: usize> CircuitBuilder { - /// Interpolate a list of point/evaluation pairs at a given point. - /// Returns the evaluation of the interpolated polynomial at `evaluation_point`. - pub fn interpolate( + /// Interpolates a polynomial, whose points are a coset of the multiplicative subgroup with the + /// given size, and whose values are given. Returns the evaluation of the interpolant at + /// `evaluation_point`. + pub fn interpolate_coset( &mut self, - interpolation_points: &[(Target, ExtensionTarget)], + subgroup_bits: usize, + coset_shift: Target, + values: &[ExtensionTarget], evaluation_point: ExtensionTarget, ) -> ExtensionTarget { - let gate = InterpolationGate::new(interpolation_points.len()); + let gate = InterpolationGate::new(subgroup_bits); let gate_index = self.add_gate(gate.clone(), vec![]); - for (i, &(p, v)) in interpolation_points.iter().enumerate() { - self.connect(p, Target::wire(gate_index, gate.wire_point(i))); + self.connect(coset_shift, Target::wire(gate_index, gate.wire_shift())); + for (i, &v) in values.iter().enumerate() { self.connect_extension( v, ExtensionTarget::from_range(gate_index, gate.wires_value(i)), @@ -53,14 +56,17 @@ mod tests { let pw = PartialWitness::new(); let mut builder = CircuitBuilder::::new(config); - let len = 4; - let points = (0..len) - .map(|_| (F::rand(), FF::rand())) - .collect::>(); + let subgroup_bits = 2; + let len = 1 << subgroup_bits; + let coset_shift = F::rand(); + let g = F::primitive_root_of_unity(subgroup_bits); + let points = F::cyclic_subgroup_coset_known_order(g, coset_shift, len); + let values = FF::rand_vec(len); let homogeneous_points = points .iter() - .map(|&(a, b)| (>::from_basefield(a), b)) + .zip(values.iter()) + .map(|(&a, &b)| (>::from_basefield(a), b)) .collect::>(); let true_interpolant = interpolant(&homogeneous_points); @@ -68,14 +74,16 @@ mod tests { let z = FF::rand(); let true_eval = true_interpolant.eval(z); - let points_target = points + let coset_shift_target = builder.constant(coset_shift); + + let value_targets = values .iter() - .map(|&(p, v)| (builder.constant(p), builder.constant_extension(v))) + .map(|&v| (builder.constant_extension(v))) .collect::>(); let zt = builder.constant_extension(z); - let eval = builder.interpolate(&points_target, zt); + let eval = builder.interpolate_coset(subgroup_bits, coset_shift_target, &value_targets, zt); let true_eval_target = builder.constant_extension(true_eval); builder.connect_extension(eval, true_eval_target); diff --git a/src/gates/gate.rs b/src/gates/gate.rs index 8c57d54d..65234f4e 100644 --- a/src/gates/gate.rs +++ b/src/gates/gate.rs @@ -86,18 +86,20 @@ pub trait Gate, const D: usize>: 'static + Send + Sync { .collect() } + /// Adds this gate's filtered constraints into the `combined_gate_constraints` buffer. fn eval_filtered_recursively( &self, builder: &mut CircuitBuilder, mut vars: EvaluationTargets, prefix: &[bool], - ) -> Vec> { + combined_gate_constraints: &mut Vec>, + ) { let filter = compute_filter_recursively(builder, prefix, vars.local_constants); vars.remove_prefix(prefix); - self.eval_unfiltered_recursively(builder, vars) - .into_iter() - .map(|c| builder.mul_extension(filter, c)) - .collect() + let my_constraints = self.eval_unfiltered_recursively(builder, vars); + for (acc, c) in combined_gate_constraints.iter_mut().zip(my_constraints) { + *acc = builder.mul_add_extension(filter, c, *acc); + } } fn generators( diff --git a/src/gates/gate_tree.rs b/src/gates/gate_tree.rs index c5889806..11130c43 100644 --- a/src/gates/gate_tree.rs +++ b/src/gates/gate_tree.rs @@ -243,7 +243,7 @@ mod tests { GateRef::new(ArithmeticExtensionGate { num_ops: 4 }), GateRef::new(BaseSumGate::<4>::new(4)), GateRef::new(GMiMCGate::::new()), - GateRef::new(InterpolationGate::new(4)), + GateRef::new(InterpolationGate::new(2)), ]; let (tree, _, _) = Tree::from_gates(gates.clone()); diff --git a/src/gates/interpolation.rs b/src/gates/interpolation.rs index 9abd4784..88fca7ff 100644 --- a/src/gates/interpolation.rs +++ b/src/gates/interpolation.rs @@ -16,48 +16,45 @@ use crate::plonk::circuit_builder::CircuitBuilder; use crate::plonk::vars::{EvaluationTargets, EvaluationVars, EvaluationVarsBase}; use crate::polynomial::polynomial::PolynomialCoeffs; -/// Evaluates the interpolant of some given elements from a field extension. -/// -/// In particular, this gate takes as inputs `num_points` points, `num_points` values, and the point -/// to evaluate the interpolant at. It computes the interpolant and outputs its evaluation at the -/// given point. +/// Interpolates a polynomial, whose points are a (base field) coset of the multiplicative subgroup +/// with the given size, and whose values are extension field elements, given by input wires. +/// Outputs the evaluation of the interpolant at a given (extension field) evaluation point. #[derive(Clone, Debug)] -pub(crate) struct InterpolationGate, const D: usize> { - pub num_points: usize, +pub(crate) struct InterpolationGate, const D: usize> { + pub subgroup_bits: usize, _phantom: PhantomData, } impl, const D: usize> InterpolationGate { - pub fn new(num_points: usize) -> Self { + pub fn new(subgroup_bits: usize) -> Self { Self { - num_points, + subgroup_bits, _phantom: PhantomData, } } - fn start_points(&self) -> usize { + fn num_points(&self) -> usize { + 1 << self.subgroup_bits + } + + /// Wire index of the coset shift. + pub fn wire_shift(&self) -> usize { 0 } - /// Wire indices of the `i`th interpolant point. - pub fn wire_point(&self, i: usize) -> usize { - debug_assert!(i < self.num_points); - self.start_points() + i - } - fn start_values(&self) -> usize { - self.start_points() + self.num_points + 1 } /// Wire indices of the `i`th interpolant value. pub fn wires_value(&self, i: usize) -> Range { - debug_assert!(i < self.num_points); + debug_assert!(i < self.num_points()); let start = self.start_values() + i * D; start..start + D } fn start_evaluation_point(&self) -> usize { - self.start_values() + self.num_points * D + self.start_values() + self.num_points() * D } /// Wire indices of the point to evaluate the interpolant at. @@ -88,14 +85,46 @@ impl, const D: usize> InterpolationGate { /// Wire indices of the interpolant's `i`th coefficient. pub fn wires_coeff(&self, i: usize) -> Range { - debug_assert!(i < self.num_points); + debug_assert!(i < self.num_points()); let start = self.start_coeffs() + i * D; start..start + D } /// End of wire indices, exclusive. fn end(&self) -> usize { - self.start_coeffs() + self.num_points * D + self.start_coeffs() + self.num_points() * D + } + + /// The domain of the points we're interpolating. + fn coset(&self, shift: F) -> impl Iterator { + let g = F::primitive_root_of_unity(self.subgroup_bits); + let size = 1 << self.subgroup_bits; + // Speed matters here, so we avoid `cyclic_subgroup_coset_known_order` which allocates. + g.powers().take(size).map(move |x| x * shift) + } + + /// The domain of the points we're interpolating. + fn coset_ext(&self, shift: F::Extension) -> impl Iterator { + let g = F::primitive_root_of_unity(self.subgroup_bits); + let size = 1 << self.subgroup_bits; + g.powers().take(size).map(move |x| shift.scalar_mul(x)) + } + + /// The domain of the points we're interpolating. + fn coset_ext_recursive( + &self, + builder: &mut CircuitBuilder, + shift: ExtensionTarget, + ) -> Vec> { + let g = F::primitive_root_of_unity(self.subgroup_bits); + let size = 1 << self.subgroup_bits; + g.powers() + .take(size) + .map(move |x| { + let subgroup_element = builder.constant(x.into()); + builder.scalar_mul_ext(subgroup_element, shift) + }) + .collect() } } @@ -107,13 +136,13 @@ impl, const D: usize> Gate for InterpolationGate { fn eval_unfiltered(&self, vars: EvaluationVars) -> Vec { let mut constraints = Vec::with_capacity(self.num_constraints()); - let coeffs = (0..self.num_points) + let coeffs = (0..self.num_points()) .map(|i| vars.get_local_ext_algebra(self.wires_coeff(i))) .collect(); let interpolant = PolynomialCoeffsAlgebra::new(coeffs); - for i in 0..self.num_points { - let point = vars.local_wires[self.wire_point(i)]; + let coset = self.coset_ext(vars.local_wires[self.wire_shift()]); + for (i, point) in coset.into_iter().enumerate() { let value = vars.get_local_ext_algebra(self.wires_value(i)); let computed_value = interpolant.eval_base(point); constraints.extend(&(value - computed_value).to_basefield_array()); @@ -130,13 +159,13 @@ impl, const D: usize> Gate for InterpolationGate { fn eval_unfiltered_base(&self, vars: EvaluationVarsBase) -> Vec { let mut constraints = Vec::with_capacity(self.num_constraints()); - let coeffs = (0..self.num_points) + let coeffs = (0..self.num_points()) .map(|i| vars.get_local_ext(self.wires_coeff(i))) .collect(); let interpolant = PolynomialCoeffs::new(coeffs); - for i in 0..self.num_points { - let point = vars.local_wires[self.wire_point(i)]; + let coset = self.coset(vars.local_wires[self.wire_shift()]); + for (i, point) in coset.into_iter().enumerate() { let value = vars.get_local_ext(self.wires_value(i)); let computed_value = interpolant.eval_base(point); constraints.extend(&(value - computed_value).to_basefield_array()); @@ -157,13 +186,13 @@ impl, const D: usize> Gate for InterpolationGate { ) -> Vec> { let mut constraints = Vec::with_capacity(self.num_constraints()); - let coeffs = (0..self.num_points) + let coeffs = (0..self.num_points()) .map(|i| vars.get_local_ext_algebra(self.wires_coeff(i))) .collect(); let interpolant = PolynomialCoeffsExtAlgebraTarget(coeffs); - for i in 0..self.num_points { - let point = vars.local_wires[self.wire_point(i)]; + let coset = self.coset_ext_recursive(builder, vars.local_wires[self.wire_shift()]); + for (i, point) in coset.into_iter().enumerate() { let value = vars.get_local_ext_algebra(self.wires_value(i)); let computed_value = interpolant.eval_scalar(builder, point); constraints.extend( @@ -209,13 +238,13 @@ impl, const D: usize> Gate for InterpolationGate { fn degree(&self) -> usize { // The highest power of x is `num_points - 1`, and then multiplication by the coefficient // adds 1. - self.num_points + self.num_points() } fn num_constraints(&self) -> usize { // num_points * D constraints to check for consistency between the coefficients and the // point-value pairs, plus D constraints for the evaluation value. - self.num_points * D + D + self.num_points() * D + D } } @@ -237,18 +266,18 @@ impl, const D: usize> SimpleGenerator for InterpolationGener let local_targets = |inputs: Range| inputs.map(local_target); - let mut deps = Vec::new(); + let num_points = self.gate.num_points(); + let mut deps = Vec::with_capacity(1 + D + num_points * D); + + deps.push(local_target(self.gate.wire_shift())); deps.extend(local_targets(self.gate.wires_evaluation_point())); - for i in 0..self.gate.num_points { - deps.push(local_target(self.gate.wire_point(i))); + for i in 0..num_points { deps.extend(local_targets(self.gate.wires_value(i))); } deps } fn run_once(&self, witness: &PartitionWitness, out_buffer: &mut GeneratedValues) { - let n = self.gate.num_points; - let local_wire = |input| Wire { gate: self.gate_index, input, @@ -264,13 +293,11 @@ impl, const D: usize> SimpleGenerator for InterpolationGener }; // Compute the interpolant. - let points = (0..n) - .map(|i| { - ( - F::Extension::from_basefield(get_local_wire(self.gate.wire_point(i))), - get_local_ext(self.gate.wires_value(i)), - ) - }) + let points = self.gate.coset(get_local_wire(self.gate.wire_shift())); + let points = points + .into_iter() + .enumerate() + .map(|(i, point)| (point.into(), get_local_ext(self.gate.wires_value(i)))) .collect::>(); let interpolant = interpolant(&points); @@ -305,26 +332,25 @@ mod tests { #[test] fn wire_indices() { let gate = InterpolationGate:: { - num_points: 2, + subgroup_bits: 1, _phantom: PhantomData, }; // The exact indices aren't really important, but we want to make sure we don't have any // overlaps or gaps. - assert_eq!(gate.wire_point(0), 0); - assert_eq!(gate.wire_point(1), 1); - assert_eq!(gate.wires_value(0), 2..6); - assert_eq!(gate.wires_value(1), 6..10); - assert_eq!(gate.wires_evaluation_point(), 10..14); - assert_eq!(gate.wires_evaluation_value(), 14..18); - assert_eq!(gate.wires_coeff(0), 18..22); - assert_eq!(gate.wires_coeff(1), 22..26); - assert_eq!(gate.num_wires(), 26); + assert_eq!(gate.wire_shift(), 0); + assert_eq!(gate.wires_value(0), 1..5); + assert_eq!(gate.wires_value(1), 5..9); + assert_eq!(gate.wires_evaluation_point(), 9..13); + assert_eq!(gate.wires_evaluation_value(), 13..17); + assert_eq!(gate.wires_coeff(0), 17..21); + assert_eq!(gate.wires_coeff(1), 21..25); + assert_eq!(gate.num_wires(), 25); } #[test] fn low_degree() { - test_low_degree::(InterpolationGate::new(4)); + test_low_degree::(InterpolationGate::new(2)); } #[test] @@ -332,7 +358,7 @@ mod tests { const D: usize = 2; type C = PoseidonGoldilocksConfig; type F = >::F; - test_eval_fns::(InterpolationGate::new(4)) + test_eval_fns::(InterpolationGate::new(2)) } #[test] @@ -344,15 +370,15 @@ mod tests { /// Returns the local wires for an interpolation gate for given coeffs, points and eval point. fn get_wires( - num_points: usize, + gate: &InterpolationGate, + shift: F, coeffs: PolynomialCoeffs, - points: Vec, eval_point: FF, ) -> Vec { - let mut v = Vec::new(); - v.extend_from_slice(&points); - for j in 0..num_points { - v.extend(coeffs.eval(points[j].into()).0); + let points = gate.coset(shift); + let mut v = vec![shift]; + for x in points { + v.extend(coeffs.eval(x.into()).0); } v.extend(eval_point.0); v.extend(coeffs.eval(eval_point).0); @@ -363,16 +389,13 @@ mod tests { } // Get a working row for InterpolationGate. + let shift = F::rand(); let coeffs = PolynomialCoeffs::new(vec![FF::rand(), FF::rand()]); - let points = vec![F::rand(), F::rand()]; let eval_point = FF::rand(); - let gate = InterpolationGate:: { - num_points: 2, - _phantom: PhantomData, - }; + let gate = InterpolationGate::::new(1); let vars = EvaluationVars { local_constants: &[], - local_wires: &get_wires(2, coeffs, points, eval_point), + local_wires: &get_wires(&gate, shift, coeffs, eval_point), public_inputs_hash: &HashOut::rand(), }; diff --git a/src/plonk/circuit_data.rs b/src/plonk/circuit_data.rs index 8419b7b9..e8ff8797 100644 --- a/src/plonk/circuit_data.rs +++ b/src/plonk/circuit_data.rs @@ -53,7 +53,7 @@ impl CircuitConfig { pub(crate) fn standard_recursion_config() -> Self { Self { num_wires: 143, - num_routed_wires: 28, + num_routed_wires: 25, constant_gate_size: 6, security_bits: 100, rate_bits: 3, diff --git a/src/plonk/vanishing_poly.rs b/src/plonk/vanishing_poly.rs index b16c8396..a68fe5c5 100644 --- a/src/plonk/vanishing_poly.rs +++ b/src/plonk/vanishing_poly.rs @@ -293,24 +293,20 @@ pub fn evaluate_gate_constraints_recursively, const D: usize>( num_gate_constraints: usize, vars: EvaluationTargets, ) -> Vec> { - let mut all_gate_constraints = vec![vec![]; num_gate_constraints]; + let mut all_gate_constraints = vec![builder.zero_extension(); num_gate_constraints]; for gate in gates { - let gate_constraints = with_context!( + with_context!( builder, &format!("evaluate {} constraints", gate.gate.0.id()), - gate.gate - .0 - .eval_filtered_recursively(builder, vars, &gate.prefix) + gate.gate.0.eval_filtered_recursively( + builder, + vars, + &gate.prefix, + &mut all_gate_constraints + ) ); - for (i, c) in gate_constraints.into_iter().enumerate() { - all_gate_constraints[i].push(c); - } } - let mut constraints = vec![builder.zero_extension(); num_gate_constraints]; - for (i, v) in all_gate_constraints.into_iter().enumerate() { - constraints[i] = builder.add_many_extension(&v); - } - constraints + all_gate_constraints } /// Evaluate the vanishing polynomial at `x`. In this context, the vanishing polynomial is a random