diff --git a/src/gates/low_degree_interpolation.rs b/src/gates/low_degree_interpolation.rs index c79a2a37..245c6d17 100644 --- a/src/gates/low_degree_interpolation.rs +++ b/src/gates/low_degree_interpolation.rs @@ -26,38 +26,6 @@ pub(crate) struct LowDegreeInterpolationGate, const _phantom: PhantomData, } -impl, const D: usize> LowDegreeInterpolationGate { - pub fn powers_init(&self, i: usize) -> usize { - debug_assert!(0 < i && i < self.num_points()); - if i == 1 { - return self.wire_shift(); - } - self.end_coeffs() + i - 2 - } - - pub fn powers_eval(&self, i: usize) -> Range { - debug_assert!(0 < i && i < self.num_points()); - if i == 1 { - return self.wires_evaluation_point(); - } - let start = self.end_coeffs() + self.num_points() - 2 + (i - 2) * D; - start..start + D - } - - /// End of wire indices, exclusive. - fn end(&self) -> usize { - self.powers_eval(self.num_points() - 1).end - } - - /// The domain of the points we're interpolating. - fn coset(&self, shift: F) -> impl Iterator { - let g = F::primitive_root_of_unity(self.subgroup_bits); - let size = 1 << self.subgroup_bits; - // Speed matters here, so we avoid `cyclic_subgroup_coset_known_order` which allocates. - g.powers().take(size).map(move |x| x * shift) - } -} - impl, const D: usize> InterpolationGate for LowDegreeInterpolationGate { @@ -130,6 +98,40 @@ impl, const D: usize> InterpolationGate } } +impl, const D: usize> LowDegreeInterpolationGate { + /// `powers_shift(i)` is the wire index of `wire_shift^i`. + pub fn powers_shift(&self, i: usize) -> usize { + debug_assert!(0 < i && i < self.num_points()); + if i == 1 { + return self.wire_shift(); + } + self.end_coeffs() + i - 2 + } + + /// `powers_evalutation_point(i)` is the wire index of `evalutation_point^i`. + pub fn powers_evaluation_point(&self, i: usize) -> Range { + debug_assert!(0 < i && i < self.num_points()); + if i == 1 { + return self.wires_evaluation_point(); + } + let start = self.end_coeffs() + self.num_points() - 2 + (i - 2) * D; + start..start + D + } + + /// End of wire indices, exclusive. + fn end(&self) -> usize { + self.powers_evaluation_point(self.num_points() - 1).end + } + + /// The domain of the points we're interpolating. + fn coset(&self, shift: F) -> impl Iterator { + let g = F::primitive_root_of_unity(self.subgroup_bits); + let size = 1 << self.subgroup_bits; + // Speed matters here, so we avoid `cyclic_subgroup_coset_known_order` which allocates. + g.powers().take(size).map(move |x| x * shift) + } +} + impl, const D: usize> Gate for LowDegreeInterpolationGate { fn id(&self) -> String { format!("{:?}", self, D) @@ -142,33 +144,35 @@ impl, const D: usize> Gate for LowDegreeInter .map(|i| vars.get_local_ext_algebra(self.wires_coeff(i))) .collect::>(); - let mut powers_init = (1..self.num_points()) - .map(|i| vars.local_wires[self.powers_init(i)]) + let mut powers_shift = (1..self.num_points()) + .map(|i| vars.local_wires[self.powers_shift(i)]) .collect::>(); - powers_init.insert(0, F::Extension::ONE); - let wire_shift = powers_init[1]; - for i in 2..self.num_points() { - constraints.push(powers_init[i - 1] * wire_shift - powers_init[i]); + let shift = powers_shift[0]; + for i in 1..self.num_points() - 1 { + constraints.push(powers_shift[i - 1] * shift - powers_shift[i]); } - let ocoeffs = coeffs + powers_shift.insert(0, F::Extension::ONE); + // `altered_coeffs[i] = c_i * shift^i`, where `c_i` is the original coefficient. + // Then, `altered(w^i) = original(shift*w^i)`. + let altered_coeffs = coeffs .iter() - .zip(powers_init) + .zip(powers_shift) .map(|(&c, p)| c.scalar_mul(p)) .collect::>(); let interpolant = PolynomialCoeffsAlgebra::new(coeffs); - let ointerpolant = PolynomialCoeffsAlgebra::new(ocoeffs); + let altered_interpolant = PolynomialCoeffsAlgebra::new(altered_coeffs); for (i, point) in F::Extension::two_adic_subgroup(self.subgroup_bits) .into_iter() .enumerate() { let value = vars.get_local_ext_algebra(self.wires_value(i)); - let computed_value = ointerpolant.eval_base(point); + let computed_value = altered_interpolant.eval_base(point); constraints.extend(&(value - computed_value).to_basefield_array()); } let evaluation_point_powers = (1..self.num_points()) - .map(|i| vars.get_local_ext_algebra(self.powers_eval(i))) + .map(|i| vars.get_local_ext_algebra(self.powers_evaluation_point(i))) .collect::>(); let evaluation_point = evaluation_point_powers[0]; for i in 1..self.num_points() - 1 { @@ -190,33 +194,36 @@ impl, const D: usize> Gate for LowDegreeInter let coeffs = (0..self.num_points()) .map(|i| vars.get_local_ext(self.wires_coeff(i))) .collect::>(); - let mut powers_init = (1..self.num_points()) - .map(|i| vars.local_wires[self.powers_init(i)]) + + let mut powers_shift = (1..self.num_points()) + .map(|i| vars.local_wires[self.powers_shift(i)]) .collect::>(); - powers_init.insert(0, F::ONE); - let wire_shift = powers_init[1]; - for i in 2..self.num_points() { - constraints.push(powers_init[i - 1] * wire_shift - powers_init[i]); + let shift = powers_shift[0]; + for i in 1..self.num_points() - 1 { + constraints.push(powers_shift[i - 1] * shift - powers_shift[i]); } - let ocoeffs = coeffs + powers_shift.insert(0, F::ONE); + // `altered_coeffs[i] = c_i * shift^i`, where `c_i` is the original coefficient. + // Then, `altered(w^i) = original(shift*w^i)`. + let altered_coeffs = coeffs .iter() - .zip(powers_init) + .zip(powers_shift) .map(|(&c, p)| c.scalar_mul(p)) .collect::>(); let interpolant = PolynomialCoeffs::new(coeffs); - let ointerpolant = PolynomialCoeffs::new(ocoeffs); + let altered_interpolant = PolynomialCoeffs::new(altered_coeffs); for (i, point) in F::two_adic_subgroup(self.subgroup_bits) .into_iter() .enumerate() { let value = vars.get_local_ext(self.wires_value(i)); - let computed_value = ointerpolant.eval_base(point); + let computed_value = altered_interpolant.eval_base(point); constraints.extend(&(value - computed_value).to_basefield_array()); } let evaluation_point_powers = (1..self.num_points()) - .map(|i| vars.get_local_ext(self.powers_eval(i))) + .map(|i| vars.get_local_ext(self.powers_evaluation_point(i))) .collect::>(); let evaluation_point = evaluation_point_powers[0]; for i in 1..self.num_points() - 1 { @@ -242,25 +249,28 @@ impl, const D: usize> Gate for LowDegreeInter let coeffs = (0..self.num_points()) .map(|i| vars.get_local_ext_algebra(self.wires_coeff(i))) .collect::>(); - let mut powers_init = (1..self.num_points()) - .map(|i| vars.local_wires[self.powers_init(i)]) + + let mut powers_shift = (1..self.num_points()) + .map(|i| vars.local_wires[self.powers_shift(i)]) .collect::>(); - powers_init.insert(0, builder.one_extension()); - let wire_shift = powers_init[1]; - for i in 2..self.num_points() { + let shift = powers_shift[0]; + for i in 1..self.num_points() - 1 { constraints.push(builder.mul_sub_extension( - powers_init[i - 1], - wire_shift, - powers_init[i], + powers_shift[i - 1], + shift, + powers_shift[i], )); } - let ocoeffs = coeffs + powers_shift.insert(0, builder.one_extension()); + // `altered_coeffs[i] = c_i * shift^i`, where `c_i` is the original coefficient. + // Then, `altered(w^i) = original(shift*w^i)`. + let altered_coeffs = coeffs .iter() - .zip(powers_init) + .zip(powers_shift) .map(|(&c, p)| builder.scalar_mul_ext_algebra(p, c)) .collect::>(); let interpolant = PolynomialCoeffsExtAlgebraTarget(coeffs); - let ointerpolant = PolynomialCoeffsExtAlgebraTarget(ocoeffs); + let altered_interpolant = PolynomialCoeffsExtAlgebraTarget(altered_coeffs); for (i, point) in F::Extension::two_adic_subgroup(self.subgroup_bits) .into_iter() @@ -268,7 +278,7 @@ impl, const D: usize> Gate for LowDegreeInter { let value = vars.get_local_ext_algebra(self.wires_value(i)); let point = builder.constant_extension(point); - let computed_value = ointerpolant.eval_scalar(builder, point); + let computed_value = altered_interpolant.eval_scalar(builder, point); constraints.extend( &builder .sub_ext_algebra(value, computed_value) @@ -277,7 +287,7 @@ impl, const D: usize> Gate for LowDegreeInter } let evaluation_point_powers = (1..self.num_points()) - .map(|i| vars.get_local_ext_algebra(self.powers_eval(i))) + .map(|i| vars.get_local_ext_algebra(self.powers_evaluation_point(i))) .collect::>(); let evaluation_point = evaluation_point_powers[0]; for i in 1..self.num_points() - 1 { @@ -328,8 +338,6 @@ impl, const D: usize> Gate for LowDegreeInter } fn degree(&self) -> usize { - // The highest power of x is `num_points - 1`, and then multiplication by the coefficient - // adds 1. 2 } @@ -395,7 +403,7 @@ impl, const D: usize> SimpleGenerator .enumerate() .skip(2) { - out_buffer.set_wire(local_wire(self.gate.powers_init(i)), power); + out_buffer.set_wire(local_wire(self.gate.powers_shift(i)), power); } // Compute the interpolant. @@ -413,10 +421,15 @@ impl, const D: usize> SimpleGenerator } let evaluation_point = get_local_ext(self.gate.wires_evaluation_point()); - for i in 2..self.gate.num_points() { + for (i, power) in evaluation_point + .powers() + .take(self.gate.num_points()) + .enumerate() + .skip(2) + { out_buffer.set_extension_target( - ExtensionTarget::from_range(self.gate_index, self.gate.powers_eval(i)), - evaluation_point.exp_u64(i as u64), + ExtensionTarget::from_range(self.gate_index, self.gate.powers_evaluation_point(i)), + power, ); } let evaluation_value = interpolant.eval(evaluation_point);