diff --git a/starky/src/constraint_consumer.rs b/starky/src/constraint_consumer.rs index bc76f03b..adb88e41 100644 --- a/starky/src/constraint_consumer.rs +++ b/starky/src/constraint_consumer.rs @@ -8,11 +8,11 @@ use plonky2::iop::target::Target; use plonky2::plonk::circuit_builder::CircuitBuilder; pub struct ConstraintConsumer { - /// A random value used to combine multiple constraints into one. - alpha: P::Scalar, + /// Random values used to combine multiple constraints into one. + alphas: Vec, - /// A running sum of constraints that have been emitted so far, scaled by powers of alpha. - constraint_acc: P, + /// Running sums of constraints that have been emitted so far, scaled by powers of alpha. + constraint_accs: Vec

, /// The evaluation of the Lagrange basis polynomial which is nonzero at the point associated /// with the first trace row, and zero at other points in the subgroup. @@ -24,24 +24,29 @@ pub struct ConstraintConsumer { } impl ConstraintConsumer

{ - pub fn new(alpha: P::Scalar, lagrange_basis_first: P, lagrange_basis_last: P) -> Self { + pub fn new(alphas: Vec, lagrange_basis_first: P, lagrange_basis_last: P) -> Self { Self { - alpha, - constraint_acc: P::ZEROS, + constraint_accs: vec![P::ZEROS; alphas.len()], + alphas, lagrange_basis_first, lagrange_basis_last, } } // TODO: Do this correctly. - pub fn accumulator(&self) -> P::Scalar { - self.constraint_acc.as_slice()[0] + pub fn accumulators(self) -> Vec { + self.constraint_accs + .into_iter() + .map(|acc| acc.as_slice()[0]) + .collect() } /// Add one constraint. pub fn one(&mut self, constraint: P) { - self.constraint_acc *= self.alpha; - self.constraint_acc += constraint; + for (&alpha, acc) in self.alphas.iter().zip(&mut self.constraint_accs) { + *acc *= alpha; + *acc += constraint; + } } /// Add a series of constraints. diff --git a/starky/src/prover.rs b/starky/src/prover.rs index 5473db68..e0652b24 100644 --- a/starky/src/prover.rs +++ b/starky/src/prover.rs @@ -74,7 +74,7 @@ where &stark, &trace_commitment, public_inputs, - &alphas, + alphas, degree_bits, rate_bits, ); @@ -145,7 +145,7 @@ fn compute_quotient_polys( stark: &S, trace_commitment: &PolynomialBatch, public_inputs: [F; S::PUBLIC_INPUTS], - alphas: &[F], + alphas: Vec, degree_bits: usize, rate_bits: usize, ) -> Vec> @@ -179,37 +179,34 @@ where comm.get_lde_values(i).try_into().unwrap() }; - alphas - .iter() - .map(|&alpha| { - let quotient_evals = PolynomialValues::new( - (0..degree << rate_bits) - .into_par_iter() - .map(|i| { - // TODO: Set `P` to a genuine `PackedField` here. - let mut consumer = ConstraintConsumer::::new( - alpha, - lagrange_first.values[i], - lagrange_last.values[i], - ); - let vars = - StarkEvaluationVars:: { - local_values: &get_at_index(trace_commitment, i), - next_values: &get_at_index( - trace_commitment, - (i + 1) % (degree << rate_bits), - ), - public_inputs: &public_inputs, - }; - stark.eval_packed_base(vars, &mut consumer); - // TODO: Fix this once we a genuine `PackedField`. - let constraints_eval = consumer.accumulator(); - let denominator_inv = z_h_on_coset.eval_inverse(i); - constraints_eval * denominator_inv - }) - .collect(), + let quotient_values = (0..degree << rate_bits) + .into_par_iter() + .map(|i| { + // TODO: Set `P` to a genuine `PackedField` here. + let mut consumer = ConstraintConsumer::::new( + alphas.clone(), + lagrange_first.values[i], + lagrange_last.values[i], ); - quotient_evals.coset_ifft(F::coset_shift()) + let vars = StarkEvaluationVars:: { + local_values: &get_at_index(trace_commitment, i), + next_values: &get_at_index(trace_commitment, (i + 1) % (degree << rate_bits)), + public_inputs: &public_inputs, + }; + stark.eval_packed_base(vars, &mut consumer); + // TODO: Fix this once we a genuine `PackedField`. + let mut constraints_evals = consumer.accumulators(); + let denominator_inv = z_h_on_coset.eval_inverse(i); + for eval in &mut constraints_evals { + *eval *= denominator_inv; + } + constraints_evals }) + .collect::>(); + + transpose("ient_values) + .into_par_iter() + .map(PolynomialValues::new) + .map(|values| values.coset_ifft(F::coset_shift())) .collect() }