From 482dfe559aa151721d6d2ae8b768e12f2a01ffce Mon Sep 17 00:00:00 2001 From: Daniel Lubarov Date: Sun, 20 Mar 2022 08:58:23 -0700 Subject: [PATCH] Vectorize constraint evaluation in Starky (#520) --- field/src/zero_poly_coset.rs | 12 ++++++++ plonky2/src/fri/oracle.rs | 32 +++++++++++++++++++- plonky2/src/plonk/prover.rs | 18 +++++------ plonky2/src/util/reducing.rs | 22 ++++++++++---- starky/src/permutation.rs | 25 +++++++++------- starky/src/prover.rs | 58 +++++++++++++++++++++++------------- starky/src/vanishing_poly.rs | 6 ++-- 7 files changed, 121 insertions(+), 52 deletions(-) diff --git a/field/src/zero_poly_coset.rs b/field/src/zero_poly_coset.rs index 0b7452f5..f4f6e722 100644 --- a/field/src/zero_poly_coset.rs +++ b/field/src/zero_poly_coset.rs @@ -1,4 +1,5 @@ use crate::field_types::Field; +use crate::packed_field::PackedField; /// Precomputations of the evaluation of `Z_H(X) = X^n - 1` on a coset `gK` with `H <= K`. pub struct ZeroPolyOnCoset { @@ -39,6 +40,17 @@ impl ZeroPolyOnCoset { self.inverses[i % self.rate] } + /// Like `eval_inverse`, but for a range of indices starting with `i_start`. + pub fn eval_inverse_packed>(&self, i_start: usize) -> P { + let mut packed = P::ZEROS; + packed + .as_slice_mut() + .iter_mut() + .enumerate() + .for_each(|(j, packed_j)| *packed_j = self.eval_inverse(i_start + j)); + packed + } + /// Returns `L_1(x) = Z_H(x)/(n * (x - 1))` with `x = w^i`. pub fn eval_l1(&self, i: usize, x: F) -> F { // Could also precompute the inverses using Montgomery. diff --git a/plonky2/src/fri/oracle.rs b/plonky2/src/fri/oracle.rs index bd1e9ac5..30058423 100644 --- a/plonky2/src/fri/oracle.rs +++ b/plonky2/src/fri/oracle.rs @@ -1,6 +1,8 @@ +use itertools::Itertools; use plonky2_field::extension_field::Extendable; use plonky2_field::fft::FftRootTable; use plonky2_field::field_types::Field; +use plonky2_field::packed_field::PackedField; use plonky2_field::polynomial::{PolynomialCoeffs, PolynomialValues}; use plonky2_util::{log2_strict, reverse_index_bits_in_place}; use rayon::prelude::*; @@ -126,12 +128,40 @@ impl, C: GenericConfig, const D: usize> .collect() } - pub fn get_lde_values(&self, index: usize) -> &[F] { + /// Fetches LDE values at the `index * step`th point. + pub fn get_lde_values(&self, index: usize, step: usize) -> &[F] { + let index = index * step; let index = reverse_bits(index, self.degree_log + self.rate_bits); let slice = &self.merkle_tree.leaves[index]; &slice[..slice.len() - if self.blinding { SALT_SIZE } else { 0 }] } + /// Like `get_lde_values`, but fetches LDE values from a batch of `P::WIDTH` points, and returns + /// packed values. + pub fn get_lde_values_packed
<P>(&self, index_start: usize, step: usize) -> Vec<P>
+ where + P: PackedField, + { + let row_wise = (0..P::WIDTH) + .map(|i| self.get_lde_values(index_start + i, step)) + .collect_vec(); + + // This is essentially a transpose, but we will not use the generic transpose method as we + // want inner lists to be of type P, not Vecs which would involve allocation. + let leaf_size = row_wise[0].len(); + (0..leaf_size) + .map(|j| { + let mut packed = P::ZEROS; + packed + .as_slice_mut() + .iter_mut() + .zip(&row_wise) + .for_each(|(packed_i, row_i)| *packed_i = row_i[j]); + packed + }) + .collect_vec() + } + /// Produces a batch opening proof. pub fn prove_openings( instance: &FriInstanceInfo, diff --git a/plonky2/src/plonk/prover.rs b/plonky2/src/plonk/prover.rs index 1d99b60a..ce9e1582 100644 --- a/plonky2/src/plonk/prover.rs +++ b/plonky2/src/plonk/prover.rs @@ -352,10 +352,6 @@ fn compute_quotient_polys< let points = F::two_adic_subgroup(common_data.degree_bits + quotient_degree_bits); let lde_size = points.len(); - // Retrieve the LDE values at index `i`. - let get_at_index = - |comm: &'a PolynomialBatch, i: usize| -> &'a [F] { comm.get_lde_values(i * step) }; - let z_h_on_coset = ZeroPolyOnCoset::new(common_data.degree_bits, quotient_degree_bits); let points_batches = points.par_chunks(BATCH_SIZE); @@ -384,15 +380,17 @@ fn compute_quotient_polys< for (&i, &x) in indices_batch.iter().zip(xs_batch) { let shifted_x = F::coset_shift() * x; let i_next = (i + next_step) % lde_size; - let local_constants_sigmas = - get_at_index(&prover_data.constants_sigmas_commitment, i); + let local_constants_sigmas = prover_data + .constants_sigmas_commitment + .get_lde_values(i, step); let local_constants = &local_constants_sigmas[common_data.constants_range()]; let s_sigmas = &local_constants_sigmas[common_data.sigmas_range()]; - let local_wires = get_at_index(wires_commitment, i); - let local_zs_partial_products = get_at_index(zs_partial_products_commitment, i); + let local_wires = wires_commitment.get_lde_values(i, step); + let local_zs_partial_products = + zs_partial_products_commitment.get_lde_values(i, step); let local_zs = &local_zs_partial_products[common_data.zs_range()]; - let next_zs = - &get_at_index(zs_partial_products_commitment, i_next)[common_data.zs_range()]; + let next_zs = &zs_partial_products_commitment.get_lde_values(i_next, step) + [common_data.zs_range()]; let partial_products = &local_zs_partial_products[common_data.partial_products_range()]; diff --git a/plonky2/src/util/reducing.rs b/plonky2/src/util/reducing.rs index 626668e6..991b30e7 100644 --- a/plonky2/src/util/reducing.rs +++ b/plonky2/src/util/reducing.rs @@ -2,6 +2,7 @@ use std::borrow::Borrow; use plonky2_field::extension_field::{Extendable, FieldExtension}; use plonky2_field::field_types::Field; +use plonky2_field::packed_field::PackedField; use plonky2_field::polynomial::PolynomialCoeffs; use crate::gates::arithmetic_extension::ArithmeticExtensionGate; @@ -35,9 +36,14 @@ impl ReducingFactor { self.base * x } - fn mul_ext, const D: usize>(&mut self, x: FE) -> FE { + fn mul_ext(&mut self, x: P) -> P + where + FE: FieldExtension, + P: PackedField, + { self.count += 1; - x.scalar_mul(self.base) + // TODO: Would like to use `FE::scalar_mul`, but it doesn't work with Packed currently. 
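+ // (The expression below lifts `self.base` into the extension field, i.e. into `P::Scalar`,
+ // via `FE::from_basefield`, so it can be multiplied lane-wise into the packed value `x`.)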
+ x * FE::from_basefield(self.base) } fn mul_poly(&mut self, p: &mut PolynomialCoeffs) { @@ -50,12 +56,16 @@ impl ReducingFactor { .fold(F::ZERO, |acc, x| self.mul(acc) + *x.borrow()) } - pub fn reduce_ext, const D: usize>( + pub fn reduce_ext( &mut self, - iter: impl DoubleEndedIterator>, - ) -> FE { + iter: impl DoubleEndedIterator>, + ) -> P + where + FE: FieldExtension, + P: PackedField, + { iter.rev() - .fold(FE::ZERO, |acc, x| self.mul_ext(acc) + *x.borrow()) + .fold(P::ZEROS, |acc, x| self.mul_ext(acc) + *x.borrow()) } pub fn reduce_polys( diff --git a/starky/src/permutation.rs b/starky/src/permutation.rs index 91b1be27..443ff787 100644 --- a/starky/src/permutation.rs +++ b/starky/src/permutation.rs @@ -246,19 +246,23 @@ pub(crate) fn get_permutation_batches<'a, T: Copy>( .collect() } -// TODO: Use slices. -pub struct PermutationCheckVars, const D2: usize> { - pub(crate) local_zs: Vec, - pub(crate) next_zs: Vec, +pub struct PermutationCheckVars +where + F: Field, + FE: FieldExtension, + P: PackedField, +{ + pub(crate) local_zs: Vec
<P>,
+ pub(crate) next_zs: Vec<P>
, pub(crate) permutation_challenge_sets: Vec>, } pub(crate) fn eval_permutation_checks( stark: &S, config: &StarkConfig, - vars: StarkEvaluationVars, - permutation_data: PermutationCheckVars, - consumer: &mut ConstraintConsumer, + vars: StarkEvaluationVars, + permutation_data: PermutationCheckVars, + consumer: &mut ConstraintConsumer
<P>
, ) where F: RichField + Extendable, FE: FieldExtension, @@ -291,7 +295,7 @@ pub(crate) fn eval_permutation_checks, Vec) = instances + let (reduced_lhs, reduced_rhs): (Vec
<P>, Vec<P>
) = instances .iter() .map(|instance| { let PermutationInstance { @@ -309,13 +313,12 @@ pub(crate) fn eval_permutation_checks() + - local_zs[i] * reduced_lhs.into_iter().product::
<P>
(); consumer.constraint(constraint); } } -// TODO: Use slices. pub struct PermutationCheckDataTarget { pub(crate) local_zs: Vec>, pub(crate) next_zs: Vec>, diff --git a/starky/src/prover.rs b/starky/src/prover.rs index da1b5dd4..582054ea 100644 --- a/starky/src/prover.rs +++ b/starky/src/prover.rs @@ -4,6 +4,8 @@ use anyhow::{ensure, Result}; use itertools::Itertools; use plonky2::field::extension_field::Extendable; use plonky2::field::field_types::Field; +use plonky2::field::packable::Packable; +use plonky2::field::packed_field::PackedField; use plonky2::field::polynomial::{PolynomialCoeffs, PolynomialValues}; use plonky2::field::zero_poly_coset::ZeroPolyOnCoset; use plonky2::fri::oracle::PolynomialBatch; @@ -40,6 +42,7 @@ where S: Stark, [(); S::COLUMNS]:, [(); S::PUBLIC_INPUTS]:, + [(); <::Packing>::WIDTH]:, [(); C::Hasher::HASH_SIZE]:, { let degree = trace_poly_values[0].len(); @@ -110,7 +113,7 @@ where } let alphas = challenger.get_n_challenges(config.num_challenges); - let quotient_polys = compute_quotient_polys::( + let quotient_polys = compute_quotient_polys::::Packing, C, S, D>( &stark, &trace_commitment, &permutation_zs_commitment_challenges, @@ -194,7 +197,7 @@ where /// Computes the quotient polynomials `(sum alpha^i C_i(x)) / Z_H(x)` for `alpha` in `alphas`, /// where the `C_i`s are the Stark constraints. -fn compute_quotient_polys<'a, F, C, S, const D: usize>( +fn compute_quotient_polys<'a, F, P, C, S, const D: usize>( stark: &S, trace_commitment: &'a PolynomialBatch, permutation_zs_commitment_challenges: &'a Option<( @@ -208,10 +211,12 @@ fn compute_quotient_polys<'a, F, C, S, const D: usize>( ) -> Vec> where F: RichField + Extendable, + P: PackedField, C: GenericConfig, S: Stark, [(); S::COLUMNS]:, [(); S::PUBLIC_INPUTS]:, + [(); P::WIDTH]:, { let degree = 1 << degree_bits; let rate_bits = config.fri_config.rate_bits; @@ -234,9 +239,12 @@ where let z_h_on_coset = ZeroPolyOnCoset::::new(degree_bits, quotient_degree_bits); // Retrieve the LDE values at index `i`. - let get_at_index = - |comm: &'a PolynomialBatch, i: usize| -> &'a [F] { comm.get_lde_values(i * step) }; - let get_trace_at_index = |i| get_at_index(trace_commitment, i).try_into().unwrap(); + let get_trace_values_packed = |i_start| -> [P; S::COLUMNS] { + trace_commitment + .get_lde_values_packed(i_start, step) + .try_into() + .unwrap() + }; // Last element of the subgroup. let last = F::primitive_root_of_unity(degree_bits).inverse(); @@ -247,41 +255,49 @@ where size, ); + // We will step by `P::WIDTH`, and in each iteration, evaluate the quotient polynomial at + // a batch of `P::WIDTH` points. let quotient_values = (0..size) .into_par_iter() - .map(|i| { - // TODO: Set `P` to a genuine `PackedField` here. 
- let mut consumer = ConstraintConsumer::::new( + .step_by(P::WIDTH) + .map(|i_start| { + let i_next_start = (i_start + next_step) % size; + let i_range = i_start..i_start + P::WIDTH; + let i_next_range = i_next_start..i_next_start + P::WIDTH; + + let x = *P::from_slice(&coset[i_range.clone()]); + let z_last = x - last; + let lagrange_basis_first = *P::from_slice(&lagrange_first.values[i_range.clone()]); + let lagrange_basis_last = *P::from_slice(&lagrange_last.values[i_range]); + + let mut consumer = ConstraintConsumer::new( alphas.clone(), - coset[i] - last, - lagrange_first.values[i], - lagrange_last.values[i], + z_last, + lagrange_basis_first, + lagrange_basis_last, ); - let vars = StarkEvaluationVars:: { - local_values: &get_trace_at_index(i), - next_values: &get_trace_at_index((i + next_step) % size), + let vars = StarkEvaluationVars { + local_values: &get_trace_values_packed(i_start), + next_values: &get_trace_values_packed(i_next_start), public_inputs: &public_inputs, }; let permutation_check_data = permutation_zs_commitment_challenges.as_ref().map( |(permutation_zs_commitment, permutation_challenge_sets)| PermutationCheckVars { - local_zs: get_at_index(permutation_zs_commitment, i).to_vec(), - next_zs: get_at_index(permutation_zs_commitment, (i + next_step) % size) - .to_vec(), + local_zs: permutation_zs_commitment.get_lde_values_packed(i_start, step), + next_zs: permutation_zs_commitment.get_lde_values_packed(i_next_start, step), permutation_challenge_sets: permutation_challenge_sets.to_vec(), }, ); - // TODO: Use packed field for F. - eval_vanishing_poly::( + eval_vanishing_poly::( stark, config, vars, permutation_check_data, &mut consumer, ); - // TODO: Fix this once we use a genuine `PackedField`. let mut constraints_evals = consumer.accumulators(); // We divide the constraints evaluations by `Z_H(x)`. - let denominator_inv = z_h_on_coset.eval_inverse(i); + let denominator_inv = z_h_on_coset.eval_inverse_packed(i_start); for eval in &mut constraints_evals { *eval *= denominator_inv; } diff --git a/starky/src/vanishing_poly.rs b/starky/src/vanishing_poly.rs index c8c75730..dc32b800 100644 --- a/starky/src/vanishing_poly.rs +++ b/starky/src/vanishing_poly.rs @@ -16,9 +16,9 @@ use crate::vars::{StarkEvaluationTargets, StarkEvaluationVars}; pub(crate) fn eval_vanishing_poly( stark: &S, config: &StarkConfig, - vars: StarkEvaluationVars, - permutation_data: Option>, - consumer: &mut ConstraintConsumer, + vars: StarkEvaluationVars, + permutation_data: Option>, + consumer: &mut ConstraintConsumer
<P>,
 ) where
 F: RichField + Extendable<D>,
 FE: FieldExtension<D2, BaseField = F>,
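
For readers new to the packed-evaluation layout introduced above: `get_lde_values_packed` gathers `P::WIDTH` consecutive LDE rows and transposes them so that each column becomes one packed value whose lanes hold that column at the `P::WIDTH` points; the constraint evaluator then runs once per batch instead of once per point. The standalone Rust sketch below illustrates only that row-to-lane transpose, using a hypothetical fixed 4-lane packing (`Packed4` and `pack_rows` are illustrative stand-ins, not plonky2 APIs, and 4 stands in for `P::WIDTH`).

/// Toy stand-in for a 4-lane packed field; in the patch this role is played by
/// `<F as Packable>::Packing`, whose lane count is `P::WIDTH`.
#[derive(Clone, Copy, Debug, Default, PartialEq)]
struct Packed4([u64; 4]);

/// Transposes 4 row-major leaf slices (one per evaluation point) into one
/// `Packed4` per column, mirroring the loop in `get_lde_values_packed`.
fn pack_rows(rows: [&[u64]; 4]) -> Vec<Packed4> {
    let leaf_size = rows[0].len();
    (0..leaf_size)
        .map(|col| {
            let mut packed = Packed4::default();
            for (lane, row) in rows.iter().enumerate() {
                packed.0[lane] = row[col];
            }
            packed
        })
        .collect()
}

fn main() {
    // Four consecutive LDE rows, each holding three columns.
    let rows: [&[u64]; 4] = [&[1, 2, 3], &[4, 5, 6], &[7, 8, 9], &[10, 11, 12]];
    let packed = pack_rows(rows);
    // Each output value now holds one column across the four points.
    assert_eq!(packed[0], Packed4([1, 4, 7, 10]));
    assert_eq!(packed[2], Packed4([3, 6, 9, 12]));
}

In the patch itself, the packed constraint accumulators produced from such batches are then multiplied lane-wise by the packed `1/Z_H(x)` values returned by `eval_inverse_packed`.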