From 9096b758f4d6cedbe89e7d1fd72209db1f1f812c Mon Sep 17 00:00:00 2001 From: wborgeaud Date: Wed, 4 May 2022 20:57:07 +0200 Subject: [PATCH] Start of multi-table STARKs --- Cargo.toml | 2 +- starky2/Cargo.toml | 14 + starky2/src/config.rs | 34 +++ starky2/src/constraint_consumer.rs | 166 ++++++++++++ starky2/src/get_challenges.rs | 335 +++++++++++++++++++++++ starky2/src/lib.rs | 19 ++ starky2/src/mock_stark.rs | 360 +++++++++++++++++++++++++ starky2/src/permutation.rs | 397 +++++++++++++++++++++++++++ starky2/src/proof.rs | 213 +++++++++++++++ starky2/src/prover.rs | 414 +++++++++++++++++++++++++++++ starky2/src/recursive_verifier.rs | 333 +++++++++++++++++++++++ starky2/src/stark.rs | 204 ++++++++++++++ starky2/src/stark_testing.rs | 87 ++++++ starky2/src/util.rs | 16 ++ starky2/src/vanishing_poly.rs | 68 +++++ starky2/src/vars.rs | 26 ++ starky2/src/verifier.rs | 208 +++++++++++++++ 17 files changed, 2895 insertions(+), 1 deletion(-) create mode 100644 starky2/Cargo.toml create mode 100644 starky2/src/config.rs create mode 100644 starky2/src/constraint_consumer.rs create mode 100644 starky2/src/get_challenges.rs create mode 100644 starky2/src/lib.rs create mode 100644 starky2/src/mock_stark.rs create mode 100644 starky2/src/permutation.rs create mode 100644 starky2/src/proof.rs create mode 100644 starky2/src/prover.rs create mode 100644 starky2/src/recursive_verifier.rs create mode 100644 starky2/src/stark.rs create mode 100644 starky2/src/stark_testing.rs create mode 100644 starky2/src/util.rs create mode 100644 starky2/src/vanishing_poly.rs create mode 100644 starky2/src/vars.rs create mode 100644 starky2/src/verifier.rs diff --git a/Cargo.toml b/Cargo.toml index 5f95ea39..f8a59b92 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,5 +1,5 @@ [workspace] -members = ["field", "insertion", "plonky2", "starky", "system_zero", "util", "waksman", "ecdsa", "u32"] +members = ["field", "insertion", "plonky2", "starky", "system_zero", "util", "waksman", "ecdsa", "u32", "starky2"] [profile.release] opt-level = 3 diff --git a/starky2/Cargo.toml b/starky2/Cargo.toml new file mode 100644 index 00000000..feb4cbc3 --- /dev/null +++ b/starky2/Cargo.toml @@ -0,0 +1,14 @@ +[package] +name = "starky2" +description = "Implementation of STARKs 2" +version = "0.1.0" +edition = "2021" + +[dependencies] +plonky2 = { path = "../plonky2" } +plonky2_util = { path = "../util" } +anyhow = "1.0.40" +env_logger = "0.9.0" +itertools = "0.10.0" +log = "0.4.14" +rayon = "1.5.1" diff --git a/starky2/src/config.rs b/starky2/src/config.rs new file mode 100644 index 00000000..500cd957 --- /dev/null +++ b/starky2/src/config.rs @@ -0,0 +1,34 @@ +use plonky2::fri::reduction_strategies::FriReductionStrategy; +use plonky2::fri::{FriConfig, FriParams}; + +pub struct StarkConfig { + pub security_bits: usize, + + /// The number of challenge points to generate, for IOPs that have soundness errors of (roughly) + /// `degree / |F|`. + pub num_challenges: usize, + + pub fri_config: FriConfig, +} + +impl StarkConfig { + /// A typical configuration with a rate of 2, resulting in fast but large proofs. + /// Targets ~100 bit conjectured security. 
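+    /// (Informally: with `rate_bits = 1`, each FRI query contributes roughly one
+    /// bit of conjectured security, so 90 query rounds plus `proof_of_work_bits = 10`
+    /// reach the ~100-bit target.)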
+ pub fn standard_fast_config() -> Self { + Self { + security_bits: 100, + num_challenges: 2, + fri_config: FriConfig { + rate_bits: 1, + cap_height: 4, + proof_of_work_bits: 10, + reduction_strategy: FriReductionStrategy::ConstantArityBits(4, 5), + num_query_rounds: 90, + }, + } + } + + pub(crate) fn fri_params(&self, degree_bits: usize) -> FriParams { + self.fri_config.fri_params(degree_bits, false) + } +} diff --git a/starky2/src/constraint_consumer.rs b/starky2/src/constraint_consumer.rs new file mode 100644 index 00000000..ada28730 --- /dev/null +++ b/starky2/src/constraint_consumer.rs @@ -0,0 +1,166 @@ +use std::marker::PhantomData; + +use plonky2::field::extension_field::Extendable; +use plonky2::field::packed_field::PackedField; +use plonky2::hash::hash_types::RichField; +use plonky2::iop::ext_target::ExtensionTarget; +use plonky2::iop::target::Target; +use plonky2::plonk::circuit_builder::CircuitBuilder; + +pub struct ConstraintConsumer { + /// Random values used to combine multiple constraints into one. + alphas: Vec, + + /// Running sums of constraints that have been emitted so far, scaled by powers of alpha. + // TODO(JN): This is pub so it can be used in a test. Once we have an API for accessing this + // result, it should be made private. + pub constraint_accs: Vec
<P>
, + + /// The evaluation of `X - g^(n-1)`. + z_last: P, + + /// The evaluation of the Lagrange basis polynomial which is nonzero at the point associated + /// with the first trace row, and zero at other points in the subgroup. + lagrange_basis_first: P, + + /// The evaluation of the Lagrange basis polynomial which is nonzero at the point associated + /// with the last trace row, and zero at other points in the subgroup. + lagrange_basis_last: P, +} + +impl ConstraintConsumer
<P>
{ + pub fn new( + alphas: Vec, + z_last: P, + lagrange_basis_first: P, + lagrange_basis_last: P, + ) -> Self { + Self { + constraint_accs: vec![P::ZEROS; alphas.len()], + alphas, + z_last, + lagrange_basis_first, + lagrange_basis_last, + } + } + + // TODO: Do this correctly. + pub fn accumulators(self) -> Vec { + self.constraint_accs + .into_iter() + .map(|acc| acc.as_slice()[0]) + .collect() + } + + /// Add one constraint valid on all rows except the last. + pub fn constraint_transition(&mut self, constraint: P) { + self.constraint(constraint * self.z_last); + } + + /// Add one constraint on all rows. + pub fn constraint(&mut self, constraint: P) { + for (&alpha, acc) in self.alphas.iter().zip(&mut self.constraint_accs) { + *acc *= alpha; + *acc += constraint; + } + } + + /// Add one constraint, but first multiply it by a filter such that it will only apply to the + /// first row of the trace. + pub fn constraint_first_row(&mut self, constraint: P) { + self.constraint(constraint * self.lagrange_basis_first); + } + + /// Add one constraint, but first multiply it by a filter such that it will only apply to the + /// last row of the trace. + pub fn constraint_last_row(&mut self, constraint: P) { + self.constraint(constraint * self.lagrange_basis_last); + } +} + +pub struct RecursiveConstraintConsumer, const D: usize> { + /// A random value used to combine multiple constraints into one. + alphas: Vec, + + /// A running sum of constraints that have been emitted so far, scaled by powers of alpha. + constraint_accs: Vec>, + + /// The evaluation of `X - g^(n-1)`. + z_last: ExtensionTarget, + + /// The evaluation of the Lagrange basis polynomial which is nonzero at the point associated + /// with the first trace row, and zero at other points in the subgroup. + lagrange_basis_first: ExtensionTarget, + + /// The evaluation of the Lagrange basis polynomial which is nonzero at the point associated + /// with the last trace row, and zero at other points in the subgroup. + lagrange_basis_last: ExtensionTarget, + + _phantom: PhantomData, +} + +impl, const D: usize> RecursiveConstraintConsumer { + pub fn new( + zero: ExtensionTarget, + alphas: Vec, + z_last: ExtensionTarget, + lagrange_basis_first: ExtensionTarget, + lagrange_basis_last: ExtensionTarget, + ) -> Self { + Self { + constraint_accs: vec![zero; alphas.len()], + alphas, + z_last, + lagrange_basis_first, + lagrange_basis_last, + _phantom: Default::default(), + } + } + + pub fn accumulators(self) -> Vec> { + self.constraint_accs + } + + /// Add one constraint valid on all rows except the last. + pub fn constraint_transition( + &mut self, + builder: &mut CircuitBuilder, + constraint: ExtensionTarget, + ) { + let filtered_constraint = builder.mul_extension(constraint, self.z_last); + self.constraint(builder, filtered_constraint); + } + + /// Add one constraint valid on all rows. + pub fn constraint( + &mut self, + builder: &mut CircuitBuilder, + constraint: ExtensionTarget, + ) { + for (&alpha, acc) in self.alphas.iter().zip(&mut self.constraint_accs) { + *acc = builder.scalar_mul_add_extension(alpha, *acc, constraint); + } + } + + /// Add one constraint, but first multiply it by a filter such that it will only apply to the + /// first row of the trace. 
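+    /// (Multiplying by the first-row Lagrange basis makes the filtered polynomial
+    /// vanish on every other row automatically, so divisibility by `Z_H` ends up
+    /// constraining the first row alone.)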
+ pub fn constraint_first_row( + &mut self, + builder: &mut CircuitBuilder, + constraint: ExtensionTarget, + ) { + let filtered_constraint = builder.mul_extension(constraint, self.lagrange_basis_first); + self.constraint(builder, filtered_constraint); + } + + /// Add one constraint, but first multiply it by a filter such that it will only apply to the + /// last row of the trace. + pub fn constraint_last_row( + &mut self, + builder: &mut CircuitBuilder, + constraint: ExtensionTarget, + ) { + let filtered_constraint = builder.mul_extension(constraint, self.lagrange_basis_last); + self.constraint(builder, filtered_constraint); + } +} diff --git a/starky2/src/get_challenges.rs b/starky2/src/get_challenges.rs new file mode 100644 index 00000000..ceafa297 --- /dev/null +++ b/starky2/src/get_challenges.rs @@ -0,0 +1,335 @@ +use plonky2::field::extension_field::Extendable; +use plonky2::field::polynomial::PolynomialCoeffs; +use plonky2::fri::proof::{FriProof, FriProofTarget}; +use plonky2::gadgets::polynomial::PolynomialCoeffsExtTarget; +use plonky2::hash::hash_types::{MerkleCapTarget, RichField}; +use plonky2::hash::merkle_tree::MerkleCap; +use plonky2::iop::challenger::{Challenger, RecursiveChallenger}; +use plonky2::iop::target::Target; +use plonky2::plonk::circuit_builder::CircuitBuilder; +use plonky2::plonk::config::{AlgebraicHasher, GenericConfig}; + +use crate::config::StarkConfig; +use crate::permutation::{ + get_n_permutation_challenge_sets, get_n_permutation_challenge_sets_target, +}; +use crate::proof::*; +use crate::stark::Stark; + +fn get_challenges( + stark: &S, + trace_cap: &MerkleCap, + permutation_zs_cap: Option<&MerkleCap>, + quotient_polys_cap: &MerkleCap, + openings: &StarkOpeningSet, + commit_phase_merkle_caps: &[MerkleCap], + final_poly: &PolynomialCoeffs, + pow_witness: F, + config: &StarkConfig, + degree_bits: usize, +) -> StarkProofChallenges +where + F: RichField + Extendable, + C: GenericConfig, + S: Stark, +{ + let num_challenges = config.num_challenges; + + let mut challenger = Challenger::::new(); + + challenger.observe_cap(trace_cap); + + let permutation_challenge_sets = permutation_zs_cap.map(|permutation_zs_cap| { + let tmp = get_n_permutation_challenge_sets( + &mut challenger, + num_challenges, + stark.permutation_batch_size(), + ); + challenger.observe_cap(permutation_zs_cap); + tmp + }); + + let stark_alphas = challenger.get_n_challenges(num_challenges); + + challenger.observe_cap(quotient_polys_cap); + let stark_zeta = challenger.get_extension_challenge::(); + + challenger.observe_openings(&openings.to_fri_openings()); + + StarkProofChallenges { + permutation_challenge_sets, + stark_alphas, + stark_zeta, + fri_challenges: challenger.fri_challenges::( + commit_phase_merkle_caps, + final_poly, + pow_witness, + degree_bits, + &config.fri_config, + ), + } +} + +impl StarkProofWithPublicInputs +where + F: RichField + Extendable, + C: GenericConfig, +{ + // TODO: Should be used later in compression? + #![allow(dead_code)] + pub(crate) fn fri_query_indices>( + &self, + stark: &S, + config: &StarkConfig, + degree_bits: usize, + ) -> Vec { + self.get_challenges(stark, config, degree_bits) + .fri_challenges + .fri_query_indices + } + + /// Computes all Fiat-Shamir challenges used in the STARK proof. 
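+    /// (The observe/sample order must mirror the prover exactly: trace cap, then
+    /// permutation challenges and the permutation Z cap if used, then the alphas,
+    /// quotient cap, zeta, and finally the FRI challenges.)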
+ pub(crate) fn get_challenges>( + &self, + stark: &S, + config: &StarkConfig, + degree_bits: usize, + ) -> StarkProofChallenges { + let StarkProof { + trace_cap, + permutation_zs_cap, + quotient_polys_cap, + openings, + opening_proof: + FriProof { + commit_phase_merkle_caps, + final_poly, + pow_witness, + .. + }, + } = &self.proof; + + get_challenges::( + stark, + trace_cap, + permutation_zs_cap.as_ref(), + quotient_polys_cap, + openings, + commit_phase_merkle_caps, + final_poly, + *pow_witness, + config, + degree_bits, + ) + } +} + +#[allow(clippy::too_many_arguments)] +pub(crate) fn get_challenges_target< + F: RichField + Extendable, + C: GenericConfig, + S: Stark, + const D: usize, +>( + builder: &mut CircuitBuilder, + stark: &S, + trace_cap: &MerkleCapTarget, + permutation_zs_cap: Option<&MerkleCapTarget>, + quotient_polys_cap: &MerkleCapTarget, + openings: &StarkOpeningSetTarget, + commit_phase_merkle_caps: &[MerkleCapTarget], + final_poly: &PolynomialCoeffsExtTarget, + pow_witness: Target, + config: &StarkConfig, +) -> StarkProofChallengesTarget +where + C::Hasher: AlgebraicHasher, +{ + let num_challenges = config.num_challenges; + + let mut challenger = RecursiveChallenger::::new(builder); + + challenger.observe_cap(trace_cap); + + let permutation_challenge_sets = permutation_zs_cap.map(|permutation_zs_cap| { + let tmp = get_n_permutation_challenge_sets_target( + builder, + &mut challenger, + num_challenges, + stark.permutation_batch_size(), + ); + challenger.observe_cap(permutation_zs_cap); + tmp + }); + + let stark_alphas = challenger.get_n_challenges(builder, num_challenges); + + challenger.observe_cap(quotient_polys_cap); + let stark_zeta = challenger.get_extension_challenge(builder); + + challenger.observe_openings(&openings.to_fri_openings()); + + StarkProofChallengesTarget { + permutation_challenge_sets, + stark_alphas, + stark_zeta, + fri_challenges: challenger.fri_challenges::( + builder, + commit_phase_merkle_caps, + final_poly, + pow_witness, + &config.fri_config, + ), + } +} + +impl StarkProofWithPublicInputsTarget { + pub(crate) fn get_challenges< + F: RichField + Extendable, + C: GenericConfig, + S: Stark, + >( + &self, + builder: &mut CircuitBuilder, + stark: &S, + config: &StarkConfig, + ) -> StarkProofChallengesTarget + where + C::Hasher: AlgebraicHasher, + { + let StarkProofTarget { + trace_cap, + permutation_zs_cap, + quotient_polys_cap, + openings, + opening_proof: + FriProofTarget { + commit_phase_merkle_caps, + final_poly, + pow_witness, + .. + }, + } = &self.proof; + + get_challenges_target::( + builder, + stark, + trace_cap, + permutation_zs_cap.as_ref(), + quotient_polys_cap, + openings, + commit_phase_merkle_caps, + final_poly, + *pow_witness, + config, + ) + } +} + +// TODO: Deal with the compressed stuff. +// impl, C: GenericConfig, const D: usize> +// CompressedProofWithPublicInputs +// { +// /// Computes all Fiat-Shamir challenges used in the Plonk proof. +// pub(crate) fn get_challenges( +// &self, +// common_data: &CommonCircuitData, +// ) -> anyhow::Result> { +// let CompressedProof { +// wires_cap, +// plonk_zs_partial_products_cap, +// quotient_polys_cap, +// openings, +// opening_proof: +// CompressedFriProof { +// commit_phase_merkle_caps, +// final_poly, +// pow_witness, +// .. 
+// }, +// } = &self.proof; +// +// get_challenges( +// self.get_public_inputs_hash(), +// wires_cap, +// plonk_zs_partial_products_cap, +// quotient_polys_cap, +// openings, +// commit_phase_merkle_caps, +// final_poly, +// *pow_witness, +// common_data, +// ) +// } +// +// /// Computes all coset elements that can be inferred in the FRI reduction steps. +// pub(crate) fn get_inferred_elements( +// &self, +// challenges: &ProofChallenges, +// common_data: &CommonCircuitData, +// ) -> FriInferredElements { +// let ProofChallenges { +// plonk_zeta, +// fri_alpha, +// fri_betas, +// fri_query_indices, +// .. +// } = challenges; +// let mut fri_inferred_elements = Vec::new(); +// // Holds the indices that have already been seen at each reduction depth. +// let mut seen_indices_by_depth = +// vec![HashSet::new(); common_data.fri_params.reduction_arity_bits.len()]; +// let precomputed_reduced_evals = PrecomputedReducedOpenings::from_os_and_alpha( +// &self.proof.openings.to_fri_openings(), +// *fri_alpha, +// ); +// let log_n = common_data.degree_bits + common_data.config.fri_config.rate_bits; +// // Simulate the proof verification and collect the inferred elements. +// // The content of the loop is basically the same as the `fri_verifier_query_round` function. +// for &(mut x_index) in fri_query_indices { +// let mut subgroup_x = F::MULTIPLICATIVE_GROUP_GENERATOR +// * F::primitive_root_of_unity(log_n).exp_u64(reverse_bits(x_index, log_n) as u64); +// let mut old_eval = fri_combine_initial::( +// &common_data.get_fri_instance(*plonk_zeta), +// &self +// .proof +// .opening_proof +// .query_round_proofs +// .initial_trees_proofs[&x_index], +// *fri_alpha, +// subgroup_x, +// &precomputed_reduced_evals, +// &common_data.fri_params, +// ); +// for (i, &arity_bits) in common_data +// .fri_params +// .reduction_arity_bits +// .iter() +// .enumerate() +// { +// let coset_index = x_index >> arity_bits; +// if !seen_indices_by_depth[i].insert(coset_index) { +// // If this index has already been seen, we can skip the rest of the reductions. 
+// break; +// } +// fri_inferred_elements.push(old_eval); +// let arity = 1 << arity_bits; +// let mut evals = self.proof.opening_proof.query_round_proofs.steps[i][&coset_index] +// .evals +// .clone(); +// let x_index_within_coset = x_index & (arity - 1); +// evals.insert(x_index_within_coset, old_eval); +// old_eval = compute_evaluation( +// subgroup_x, +// x_index_within_coset, +// arity_bits, +// &evals, +// fri_betas[i], +// ); +// subgroup_x = subgroup_x.exp_power_of_2(arity_bits); +// x_index = coset_index; +// } +// } +// FriInferredElements(fri_inferred_elements) +// } +// } diff --git a/starky2/src/lib.rs b/starky2/src/lib.rs new file mode 100644 index 00000000..0924600d --- /dev/null +++ b/starky2/src/lib.rs @@ -0,0 +1,19 @@ +#![allow(incomplete_features)] +#![allow(clippy::too_many_arguments)] +#![allow(clippy::type_complexity)] +#![feature(generic_const_exprs)] + +pub mod config; +pub mod constraint_consumer; +mod get_challenges; +pub mod mock_stark; +pub mod permutation; +pub mod proof; +pub mod prover; +pub mod recursive_verifier; +pub mod stark; +pub mod stark_testing; +pub mod util; +pub mod vanishing_poly; +pub mod vars; +pub mod verifier; diff --git a/starky2/src/mock_stark.rs b/starky2/src/mock_stark.rs new file mode 100644 index 00000000..337c0142 --- /dev/null +++ b/starky2/src/mock_stark.rs @@ -0,0 +1,360 @@ +use std::marker::PhantomData; + +use plonky2::field::extension_field::{Extendable, FieldExtension}; +use plonky2::field::packed_field::PackedField; +use plonky2::field::polynomial::PolynomialValues; +use plonky2::hash::hash_types::RichField; +use plonky2::plonk::circuit_builder::CircuitBuilder; + +use crate::constraint_consumer::{ConstraintConsumer, RecursiveConstraintConsumer}; +use crate::permutation::PermutationPair; +use crate::stark::Stark; +use crate::util::trace_rows_to_poly_values; +use crate::vars::{StarkEvaluationTargets, StarkEvaluationVars}; + +struct AllStarks, const D: usize> { + fibonacci: FibonacciStark, + multiplications: MultiplicationStark, +} + +/// Toy STARK system used for testing. +/// Computes a Fibonacci sequence with state `[x0, x1, i, j]` using the state transition +/// `x0' <- x1, x1' <- x0 + x1, i' <- i+1, j' <- j+1`. +/// Note: The `i, j` columns are only used to test the permutation argument. +#[derive(Copy, Clone)] +struct FibonacciStark, const D: usize> { + num_rows: usize, + _phantom: PhantomData, +} + +impl, const D: usize> FibonacciStark { + // The first public input is `x0`. + const PI_INDEX_X0: usize = 0; + // The second public input is `x1`. + const PI_INDEX_X1: usize = 1; + // The third public input is the second element of the last row, which should be equal to the + // `num_rows`-th Fibonacci number. + const PI_INDEX_RES: usize = 2; + + fn new(num_rows: usize) -> Self { + Self { + num_rows, + _phantom: PhantomData, + } + } + + /// Generate the trace using `x0, x1, 0, 1` as initial state values. + fn generate_trace(&self, x0: F, x1: F) -> Vec> { + let mut trace_rows = (0..self.num_rows) + .scan([x0, x1, F::ZERO, F::ONE], |acc, _| { + let tmp = *acc; + acc[0] = tmp[1]; + acc[1] = tmp[0] + tmp[1]; + acc[2] = tmp[2] + F::ONE; + acc[3] = tmp[3] + F::ONE; + Some(tmp) + }) + .collect::>(); + trace_rows[self.num_rows - 1][3] = F::ZERO; // So that column 2 and 3 are permutation of one another. 
+ trace_rows_to_poly_values(trace_rows) + } +} + +impl, const D: usize> Stark for FibonacciStark { + const COLUMNS: usize = 4; + const PUBLIC_INPUTS: usize = 3; + + fn eval_packed_generic( + &self, + vars: StarkEvaluationVars, + yield_constr: &mut ConstraintConsumer
<P>
, + ) where + FE: FieldExtension, + P: PackedField, + { + // Check public inputs. + yield_constr + .constraint_first_row(vars.local_values[0] - vars.public_inputs[Self::PI_INDEX_X0]); + yield_constr + .constraint_first_row(vars.local_values[1] - vars.public_inputs[Self::PI_INDEX_X1]); + yield_constr + .constraint_last_row(vars.local_values[1] - vars.public_inputs[Self::PI_INDEX_RES]); + + // x0' <- x1 + yield_constr.constraint_transition(vars.next_values[0] - vars.local_values[1]); + // x1' <- x0 + x1 + yield_constr.constraint_transition( + vars.next_values[1] - vars.local_values[0] - vars.local_values[1], + ); + } + + fn eval_ext_recursively( + &self, + builder: &mut CircuitBuilder, + vars: StarkEvaluationTargets, + yield_constr: &mut RecursiveConstraintConsumer, + ) { + // Check public inputs. + let pis_constraints = [ + builder.sub_extension(vars.local_values[0], vars.public_inputs[Self::PI_INDEX_X0]), + builder.sub_extension(vars.local_values[1], vars.public_inputs[Self::PI_INDEX_X1]), + builder.sub_extension(vars.local_values[1], vars.public_inputs[Self::PI_INDEX_RES]), + ]; + yield_constr.constraint_first_row(builder, pis_constraints[0]); + yield_constr.constraint_first_row(builder, pis_constraints[1]); + yield_constr.constraint_last_row(builder, pis_constraints[2]); + + // x0' <- x1 + let first_col_constraint = builder.sub_extension(vars.next_values[0], vars.local_values[1]); + yield_constr.constraint_transition(builder, first_col_constraint); + // x1' <- x0 + x1 + let second_col_constraint = { + let tmp = builder.sub_extension(vars.next_values[1], vars.local_values[0]); + builder.sub_extension(tmp, vars.local_values[1]) + }; + yield_constr.constraint_transition(builder, second_col_constraint); + } + + fn constraint_degree(&self) -> usize { + 2 + } + + fn permutation_pairs(&self) -> Vec { + vec![PermutationPair::singletons(2, 3)] + } +} + +#[derive(Copy, Clone)] +struct MultiplicationStark< + F: RichField + Extendable, + const D: usize, + const NUM_MULTIPLICANDS: usize, +> { + num_rows: usize, + _phantom: PhantomData, +} + +impl, const D: usize, const W: usize> MultiplicationStark { + fn multiplicand(&self, i: usize) -> usize { + debug_assert!(i < W); + i + } + + // Product of the first `i` multiplicands + fn intermediate_product(&self, i: usize) -> usize { + debug_assert!(i < W && i > 0); + W + i - 1 + } + + fn product(&self) -> usize { + 2 * W - 2 + } + + const fn num_columns() -> usize { + 2 * W - 1 + } + + fn new(num_rows: usize) -> Self { + Self { + num_rows, + _phantom: PhantomData, + } + } + + fn generate_trace(&self, multiplicands: &[Vec]) -> Vec> + where + [(); Self::num_columns()]:, + { + debug_assert_eq!(multiplicands.len(), self.num_rows); + let mut trace_rows = multiplicands + .iter() + .map(|row| { + debug_assert_eq!(row.len(), W); + let mut result = [F::ZERO; Self::num_columns()]; + for i in 0..W { + result[self.multiplicand(i)] = row[i]; + } + let mut acc = row[0] * row[1]; + for i in 1..W - 1 { + result[self.intermediate_product(i)] = acc; + acc *= row[i + 1]; + } + result[self.product()] = acc; + result + }) + .collect::>(); + trace_rows_to_poly_values(trace_rows) + } +} + +impl, const D: usize, const W: usize> Stark + for MultiplicationStark +{ + const COLUMNS: usize = 2 * W - 1; + const PUBLIC_INPUTS: usize = 0; + + fn eval_packed_generic( + &self, + vars: StarkEvaluationVars, + yield_constr: &mut ConstraintConsumer
<P>
, + ) where + FE: FieldExtension, + P: PackedField, + { + yield_constr.constraint( + vars.local_values[self.intermediate_product(1)] + - vars.local_values[self.multiplicand(0)] * vars.local_values[self.multiplicand(1)], + ); + for i in 2..W - 1 { + yield_constr.constraint( + vars.local_values[self.intermediate_product(i)] + - vars.local_values[self.intermediate_product(i - 1)] + * vars.local_values[self.multiplicand(i)], + ) + } + } + + fn eval_ext_recursively( + &self, + builder: &mut CircuitBuilder, + vars: StarkEvaluationTargets, + yield_constr: &mut RecursiveConstraintConsumer, + ) { + todo!() + } + + fn constraint_degree(&self) -> usize { + 2 + } +} + +#[cfg(test)] +mod tests { + use anyhow::Result; + use plonky2::field::extension_field::Extendable; + use plonky2::field::field_types::Field; + use plonky2::hash::hash_types::RichField; + use plonky2::iop::witness::PartialWitness; + use plonky2::plonk::circuit_builder::CircuitBuilder; + use plonky2::plonk::circuit_data::CircuitConfig; + use plonky2::plonk::config::{ + AlgebraicHasher, GenericConfig, Hasher, PoseidonGoldilocksConfig, + }; + use plonky2::util::timing::TimingTree; + + use crate::config::StarkConfig; + use crate::mock_stark::FibonacciStark; + use crate::proof::StarkProofWithPublicInputs; + use crate::prover::prove; + use crate::recursive_verifier::{ + add_virtual_stark_proof_with_pis, recursively_verify_stark_proof, + set_stark_proof_with_pis_target, + }; + use crate::stark::Stark; + use crate::stark_testing::test_stark_low_degree; + use crate::verifier::verify_stark_proof; + + fn fibonacci(n: usize, x0: F, x1: F) -> F { + (0..n).fold((x0, x1), |x, _| (x.1, x.0 + x.1)).1 + } + + #[test] + fn test_fibonacci_stark() -> Result<()> { + const D: usize = 2; + type C = PoseidonGoldilocksConfig; + type F = >::F; + type S = FibonacciStark; + + let config = StarkConfig::standard_fast_config(); + let num_rows = 1 << 5; + let public_inputs = [F::ZERO, F::ONE, fibonacci(num_rows - 1, F::ZERO, F::ONE)]; + let stark = S::new(num_rows); + let trace = stark.generate_trace(public_inputs[0], public_inputs[1]); + let proof = prove::( + stark, + &config, + trace, + public_inputs, + &mut TimingTree::default(), + )?; + + verify_stark_proof(stark, proof, &config) + } + + #[test] + fn test_fibonacci_stark_degree() -> Result<()> { + const D: usize = 2; + type C = PoseidonGoldilocksConfig; + type F = >::F; + type S = FibonacciStark; + + let num_rows = 1 << 5; + let stark = S::new(num_rows); + test_stark_low_degree(stark) + } + + #[test] + fn test_recursive_stark_verifier() -> Result<()> { + init_logger(); + const D: usize = 2; + type C = PoseidonGoldilocksConfig; + type F = >::F; + type S = FibonacciStark; + + let config = StarkConfig::standard_fast_config(); + let num_rows = 1 << 5; + let public_inputs = [F::ZERO, F::ONE, fibonacci(num_rows - 1, F::ZERO, F::ONE)]; + let stark = S::new(num_rows); + let trace = stark.generate_trace(public_inputs[0], public_inputs[1]); + let proof = prove::( + stark, + &config, + trace, + public_inputs, + &mut TimingTree::default(), + )?; + verify_stark_proof(stark, proof.clone(), &config)?; + + recursive_proof::(stark, proof, &config, true) + } + + fn recursive_proof< + F: RichField + Extendable, + C: GenericConfig, + S: Stark + Copy, + InnerC: GenericConfig, + const D: usize, + >( + stark: S, + inner_proof: StarkProofWithPublicInputs, + inner_config: &StarkConfig, + print_gate_counts: bool, + ) -> Result<()> + where + InnerC::Hasher: AlgebraicHasher, + [(); S::COLUMNS]:, + [(); S::PUBLIC_INPUTS]:, + [(); 
C::Hasher::HASH_SIZE]:, + { + let circuit_config = CircuitConfig::standard_recursion_config(); + let mut builder = CircuitBuilder::::new(circuit_config); + let mut pw = PartialWitness::new(); + let degree_bits = inner_proof.proof.recover_degree_bits(inner_config); + let pt = add_virtual_stark_proof_with_pis(&mut builder, stark, inner_config, degree_bits); + set_stark_proof_with_pis_target(&mut pw, &pt, &inner_proof); + + recursively_verify_stark_proof::(&mut builder, stark, pt, inner_config); + + if print_gate_counts { + builder.print_gate_counts(0); + } + + let data = builder.build::(); + let proof = data.prove(pw)?; + data.verify(proof) + } + + fn init_logger() { + let _ = env_logger::builder().format_timestamp(None).try_init(); + } +} diff --git a/starky2/src/permutation.rs b/starky2/src/permutation.rs new file mode 100644 index 00000000..443ff787 --- /dev/null +++ b/starky2/src/permutation.rs @@ -0,0 +1,397 @@ +//! Permutation arguments. + +use itertools::Itertools; +use plonky2::field::batch_util::batch_multiply_inplace; +use plonky2::field::extension_field::{Extendable, FieldExtension}; +use plonky2::field::field_types::Field; +use plonky2::field::packed_field::PackedField; +use plonky2::field::polynomial::PolynomialValues; +use plonky2::hash::hash_types::RichField; +use plonky2::iop::challenger::{Challenger, RecursiveChallenger}; +use plonky2::iop::ext_target::ExtensionTarget; +use plonky2::iop::target::Target; +use plonky2::plonk::circuit_builder::CircuitBuilder; +use plonky2::plonk::config::{AlgebraicHasher, GenericConfig, Hasher}; +use plonky2::util::reducing::{ReducingFactor, ReducingFactorTarget}; +use rayon::prelude::*; + +use crate::config::StarkConfig; +use crate::constraint_consumer::{ConstraintConsumer, RecursiveConstraintConsumer}; +use crate::stark::Stark; +use crate::vars::{StarkEvaluationTargets, StarkEvaluationVars}; + +/// A pair of lists of columns, `lhs` and `rhs`, that should be permutations of one another. +/// In particular, there should exist some permutation `pi` such that for any `i`, +/// `trace[lhs[i]] = pi(trace[rhs[i]])`. Here `trace` denotes the trace in column-major form, so +/// `trace[col]` is a column vector. +pub struct PermutationPair { + /// Each entry contains two column indices, representing two columns which should be + /// permutations of one another. + pub column_pairs: Vec<(usize, usize)>, +} + +impl PermutationPair { + pub fn singletons(lhs: usize, rhs: usize) -> Self { + Self { + column_pairs: vec![(lhs, rhs)], + } + } +} + +/// A single instance of a permutation check protocol. +pub(crate) struct PermutationInstance<'a, T: Copy> { + pub(crate) pair: &'a PermutationPair, + pub(crate) challenge: PermutationChallenge, +} + +/// Randomness for a single instance of a permutation check protocol. +#[derive(Copy, Clone)] +pub(crate) struct PermutationChallenge { + /// Randomness used to combine multiple columns into one. + pub(crate) beta: T, + /// Random offset that's added to the beta-reduced column values. + pub(crate) gamma: T, +} + +/// Like `PermutationChallenge`, but with `num_challenges` copies to boost soundness. +#[derive(Clone)] +pub(crate) struct PermutationChallengeSet { + pub(crate) challenges: Vec>, +} + +/// Compute all Z polynomials (for permutation arguments). 
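+/// (Each `Z` is a grand-product column: `Z(1) = 1` and
+/// `Z(g x) = Z(x) * reduced_lhs(x) / reduced_rhs(x)`, so `Z` returns to 1 after a
+/// full cycle exactly when the reduced column values agree as multisets, i.e. are
+/// permutations of one another.)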
+pub(crate) fn compute_permutation_z_polys( + stark: &S, + config: &StarkConfig, + trace_poly_values: &[PolynomialValues], + permutation_challenge_sets: &[PermutationChallengeSet], +) -> Vec> +where + F: RichField + Extendable, + C: GenericConfig, + S: Stark, +{ + let permutation_pairs = stark.permutation_pairs(); + let permutation_batches = get_permutation_batches( + &permutation_pairs, + permutation_challenge_sets, + config.num_challenges, + stark.permutation_batch_size(), + ); + + permutation_batches + .into_par_iter() + .map(|instances| compute_permutation_z_poly(&instances, trace_poly_values)) + .collect() +} + +/// Compute a single Z polynomial. +fn compute_permutation_z_poly( + instances: &[PermutationInstance], + trace_poly_values: &[PolynomialValues], +) -> PolynomialValues { + let degree = trace_poly_values[0].len(); + let (reduced_lhs_polys, reduced_rhs_polys): (Vec<_>, Vec<_>) = instances + .iter() + .map(|instance| permutation_reduced_polys(instance, trace_poly_values, degree)) + .unzip(); + + let numerator = poly_product_elementwise(reduced_lhs_polys.into_iter()); + let denominator = poly_product_elementwise(reduced_rhs_polys.into_iter()); + + // Compute the quotients. + let denominator_inverses = F::batch_multiplicative_inverse(&denominator.values); + let mut quotients = numerator.values; + batch_multiply_inplace(&mut quotients, &denominator_inverses); + + // Compute Z, which contains partial products of the quotients. + let mut partial_products = Vec::with_capacity(degree); + let mut acc = F::ONE; + for q in quotients { + partial_products.push(acc); + acc *= q; + } + PolynomialValues::new(partial_products) +} + +/// Computes the reduced polynomial, `\sum beta^i f_i(x) + gamma`, for both the "left" and "right" +/// sides of a given `PermutationPair`. +fn permutation_reduced_polys( + instance: &PermutationInstance, + trace_poly_values: &[PolynomialValues], + degree: usize, +) -> (PolynomialValues, PolynomialValues) { + let PermutationInstance { + pair: PermutationPair { column_pairs }, + challenge: PermutationChallenge { beta, gamma }, + } = instance; + + let mut reduced_lhs = PolynomialValues::constant(*gamma, degree); + let mut reduced_rhs = PolynomialValues::constant(*gamma, degree); + for ((lhs, rhs), weight) in column_pairs.iter().zip(beta.powers()) { + reduced_lhs.add_assign_scaled(&trace_poly_values[*lhs], weight); + reduced_rhs.add_assign_scaled(&trace_poly_values[*rhs], weight); + } + (reduced_lhs, reduced_rhs) +} + +/// Computes the elementwise product of a set of polynomials. Assumes that the set is non-empty and +/// that each polynomial has the same length. 
+fn poly_product_elementwise( + mut polys: impl Iterator>, +) -> PolynomialValues { + let mut product = polys.next().expect("Expected at least one polynomial"); + for poly in polys { + batch_multiply_inplace(&mut product.values, &poly.values) + } + product +} + +fn get_permutation_challenge>( + challenger: &mut Challenger, +) -> PermutationChallenge { + let beta = challenger.get_challenge(); + let gamma = challenger.get_challenge(); + PermutationChallenge { beta, gamma } +} + +fn get_permutation_challenge_set>( + challenger: &mut Challenger, + num_challenges: usize, +) -> PermutationChallengeSet { + let challenges = (0..num_challenges) + .map(|_| get_permutation_challenge(challenger)) + .collect(); + PermutationChallengeSet { challenges } +} + +pub(crate) fn get_n_permutation_challenge_sets>( + challenger: &mut Challenger, + num_challenges: usize, + num_sets: usize, +) -> Vec> { + (0..num_sets) + .map(|_| get_permutation_challenge_set(challenger, num_challenges)) + .collect() +} + +fn get_permutation_challenge_target< + F: RichField + Extendable, + H: AlgebraicHasher, + const D: usize, +>( + builder: &mut CircuitBuilder, + challenger: &mut RecursiveChallenger, +) -> PermutationChallenge { + let beta = challenger.get_challenge(builder); + let gamma = challenger.get_challenge(builder); + PermutationChallenge { beta, gamma } +} + +fn get_permutation_challenge_set_target< + F: RichField + Extendable, + H: AlgebraicHasher, + const D: usize, +>( + builder: &mut CircuitBuilder, + challenger: &mut RecursiveChallenger, + num_challenges: usize, +) -> PermutationChallengeSet { + let challenges = (0..num_challenges) + .map(|_| get_permutation_challenge_target(builder, challenger)) + .collect(); + PermutationChallengeSet { challenges } +} + +pub(crate) fn get_n_permutation_challenge_sets_target< + F: RichField + Extendable, + H: AlgebraicHasher, + const D: usize, +>( + builder: &mut CircuitBuilder, + challenger: &mut RecursiveChallenger, + num_challenges: usize, + num_sets: usize, +) -> Vec> { + (0..num_sets) + .map(|_| get_permutation_challenge_set_target(builder, challenger, num_challenges)) + .collect() +} + +/// Get a list of instances of our batch-permutation argument. These are permutation arguments +/// where the same `Z(x)` polynomial is used to check more than one permutation. +/// Before batching, each permutation pair leads to `num_challenges` permutation arguments, so we +/// start with the cartesian product of `permutation_pairs` and `0..num_challenges`. Then we +/// chunk these arguments based on our batch size. +pub(crate) fn get_permutation_batches<'a, T: Copy>( + permutation_pairs: &'a [PermutationPair], + permutation_challenge_sets: &[PermutationChallengeSet], + num_challenges: usize, + batch_size: usize, +) -> Vec>> { + permutation_pairs + .iter() + .cartesian_product(0..num_challenges) + .chunks(batch_size) + .into_iter() + .map(|batch| { + batch + .enumerate() + .map(|(i, (pair, chal))| { + let challenge = permutation_challenge_sets[i].challenges[chal]; + PermutationInstance { pair, challenge } + }) + .collect_vec() + }) + .collect() +} + +pub struct PermutationCheckVars +where + F: Field, + FE: FieldExtension, + P: PackedField, +{ + pub(crate) local_zs: Vec
<P>
, + pub(crate) next_zs: Vec
<P>
, + pub(crate) permutation_challenge_sets: Vec>, +} + +pub(crate) fn eval_permutation_checks( + stark: &S, + config: &StarkConfig, + vars: StarkEvaluationVars, + permutation_data: PermutationCheckVars, + consumer: &mut ConstraintConsumer
<P>
, +) where + F: RichField + Extendable, + FE: FieldExtension, + P: PackedField, + C: GenericConfig, + S: Stark, + [(); S::COLUMNS]:, + [(); S::PUBLIC_INPUTS]:, +{ + let PermutationCheckVars { + local_zs, + next_zs, + permutation_challenge_sets, + } = permutation_data; + + // Check that Z(1) = 1; + for &z in &local_zs { + consumer.constraint_first_row(z - FE::ONE); + } + + let permutation_pairs = stark.permutation_pairs(); + + let permutation_batches = get_permutation_batches( + &permutation_pairs, + &permutation_challenge_sets, + config.num_challenges, + stark.permutation_batch_size(), + ); + + // Each zs value corresponds to a permutation batch. + for (i, instances) in permutation_batches.iter().enumerate() { + // Z(gx) * down = Z x * up + let (reduced_lhs, reduced_rhs): (Vec
<P>
, Vec
<P>
) = instances + .iter() + .map(|instance| { + let PermutationInstance { + pair: PermutationPair { column_pairs }, + challenge: PermutationChallenge { beta, gamma }, + } = instance; + let mut factor = ReducingFactor::new(*beta); + let (lhs, rhs): (Vec<_>, Vec<_>) = column_pairs + .iter() + .map(|&(i, j)| (vars.local_values[i], vars.local_values[j])) + .unzip(); + ( + factor.reduce_ext(lhs.into_iter()) + FE::from_basefield(*gamma), + factor.reduce_ext(rhs.into_iter()) + FE::from_basefield(*gamma), + ) + }) + .unzip(); + let constraint = next_zs[i] * reduced_rhs.into_iter().product::
<P>
() + - local_zs[i] * reduced_lhs.into_iter().product::
<P>
(); + consumer.constraint(constraint); + } +} + +pub struct PermutationCheckDataTarget { + pub(crate) local_zs: Vec>, + pub(crate) next_zs: Vec>, + pub(crate) permutation_challenge_sets: Vec>, +} + +pub(crate) fn eval_permutation_checks_recursively( + builder: &mut CircuitBuilder, + stark: &S, + config: &StarkConfig, + vars: StarkEvaluationTargets, + permutation_data: PermutationCheckDataTarget, + consumer: &mut RecursiveConstraintConsumer, +) where + F: RichField + Extendable, + S: Stark, + [(); S::COLUMNS]:, + [(); S::PUBLIC_INPUTS]:, +{ + let PermutationCheckDataTarget { + local_zs, + next_zs, + permutation_challenge_sets, + } = permutation_data; + + let one = builder.one_extension(); + // Check that Z(1) = 1; + for &z in &local_zs { + let z_1 = builder.sub_extension(z, one); + consumer.constraint_first_row(builder, z_1); + } + + let permutation_pairs = stark.permutation_pairs(); + + let permutation_batches = get_permutation_batches( + &permutation_pairs, + &permutation_challenge_sets, + config.num_challenges, + stark.permutation_batch_size(), + ); + + // Each zs value corresponds to a permutation batch. + for (i, instances) in permutation_batches.iter().enumerate() { + let (reduced_lhs, reduced_rhs): (Vec>, Vec>) = + instances + .iter() + .map(|instance| { + let PermutationInstance { + pair: PermutationPair { column_pairs }, + challenge: PermutationChallenge { beta, gamma }, + } = instance; + let beta_ext = builder.convert_to_ext(*beta); + let gamma_ext = builder.convert_to_ext(*gamma); + let mut factor = ReducingFactorTarget::new(beta_ext); + let (lhs, rhs): (Vec<_>, Vec<_>) = column_pairs + .iter() + .map(|&(i, j)| (vars.local_values[i], vars.local_values[j])) + .unzip(); + let reduced_lhs = factor.reduce(&lhs, builder); + let reduced_rhs = factor.reduce(&rhs, builder); + ( + builder.add_extension(reduced_lhs, gamma_ext), + builder.add_extension(reduced_rhs, gamma_ext), + ) + }) + .unzip(); + let reduced_lhs_product = builder.mul_many_extension(&reduced_lhs); + let reduced_rhs_product = builder.mul_many_extension(&reduced_rhs); + // constraint = next_zs[i] * reduced_rhs_product - local_zs[i] * reduced_lhs_product + let constraint = { + let tmp = builder.mul_extension(local_zs[i], reduced_lhs_product); + builder.mul_sub_extension(next_zs[i], reduced_rhs_product, tmp) + }; + consumer.constraint(builder, constraint) + } +} diff --git a/starky2/src/proof.rs b/starky2/src/proof.rs new file mode 100644 index 00000000..afefdd96 --- /dev/null +++ b/starky2/src/proof.rs @@ -0,0 +1,213 @@ +use itertools::Itertools; +use plonky2::field::extension_field::{Extendable, FieldExtension}; +use plonky2::fri::oracle::PolynomialBatch; +use plonky2::fri::proof::{ + CompressedFriProof, FriChallenges, FriChallengesTarget, FriProof, FriProofTarget, +}; +use plonky2::fri::structure::{ + FriOpeningBatch, FriOpeningBatchTarget, FriOpenings, FriOpeningsTarget, +}; +use plonky2::hash::hash_types::{MerkleCapTarget, RichField}; +use plonky2::hash::merkle_tree::MerkleCap; +use plonky2::iop::ext_target::ExtensionTarget; +use plonky2::iop::target::Target; +use plonky2::plonk::config::GenericConfig; +use rayon::prelude::*; + +use crate::config::StarkConfig; +use crate::permutation::PermutationChallengeSet; + +#[derive(Debug, Clone)] +pub struct StarkProof, C: GenericConfig, const D: usize> { + /// Merkle cap of LDEs of trace values. + pub trace_cap: MerkleCap, + /// Merkle cap of LDEs of permutation Z values. + pub permutation_zs_cap: Option>, + /// Merkle cap of LDEs of trace values. 
+ pub quotient_polys_cap: MerkleCap, + /// Purported values of each polynomial at the challenge point. + pub openings: StarkOpeningSet, + /// A batch FRI argument for all openings. + pub opening_proof: FriProof, +} + +impl, C: GenericConfig, const D: usize> StarkProof { + /// Recover the length of the trace from a STARK proof and a STARK config. + pub fn recover_degree_bits(&self, config: &StarkConfig) -> usize { + let initial_merkle_proof = &self.opening_proof.query_round_proofs[0] + .initial_trees_proof + .evals_proofs[0] + .1; + let lde_bits = config.fri_config.cap_height + initial_merkle_proof.siblings.len(); + lde_bits - config.fri_config.rate_bits + } +} + +pub struct StarkProofTarget { + pub trace_cap: MerkleCapTarget, + pub permutation_zs_cap: Option, + pub quotient_polys_cap: MerkleCapTarget, + pub openings: StarkOpeningSetTarget, + pub opening_proof: FriProofTarget, +} + +impl StarkProofTarget { + /// Recover the length of the trace from a STARK proof and a STARK config. + pub fn recover_degree_bits(&self, config: &StarkConfig) -> usize { + let initial_merkle_proof = &self.opening_proof.query_round_proofs[0] + .initial_trees_proof + .evals_proofs[0] + .1; + let lde_bits = config.fri_config.cap_height + initial_merkle_proof.siblings.len(); + lde_bits - config.fri_config.rate_bits + } +} + +#[derive(Debug, Clone)] +pub struct StarkProofWithPublicInputs< + F: RichField + Extendable, + C: GenericConfig, + const D: usize, +> { + pub proof: StarkProof, + // TODO: Maybe make it generic over a `S: Stark` and replace with `[F; S::PUBLIC_INPUTS]`. + pub public_inputs: Vec, +} + +pub struct StarkProofWithPublicInputsTarget { + pub proof: StarkProofTarget, + pub public_inputs: Vec, +} + +pub struct CompressedStarkProof< + F: RichField + Extendable, + C: GenericConfig, + const D: usize, +> { + /// Merkle cap of LDEs of trace values. + pub trace_cap: MerkleCap, + /// Purported values of each polynomial at the challenge point. + pub openings: StarkOpeningSet, + /// A batch FRI argument for all openings. + pub opening_proof: CompressedFriProof, +} + +pub struct CompressedStarkProofWithPublicInputs< + F: RichField + Extendable, + C: GenericConfig, + const D: usize, +> { + pub proof: CompressedStarkProof, + pub public_inputs: Vec, +} + +pub(crate) struct StarkProofChallenges, const D: usize> { + /// Randomness used in any permutation arguments. + pub permutation_challenge_sets: Option>>, + + /// Random values used to combine STARK constraints. + pub stark_alphas: Vec, + + /// Point at which the STARK polynomials are opened. + pub stark_zeta: F::Extension, + + pub fri_challenges: FriChallenges, +} + +pub(crate) struct StarkProofChallengesTarget { + pub permutation_challenge_sets: Option>>, + pub stark_alphas: Vec, + pub stark_zeta: ExtensionTarget, + pub fri_challenges: FriChallengesTarget, +} + +/// Purported values of each polynomial at the challenge point. 
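+/// (Each polynomial is opened at `zeta` and at `g * zeta`, the point one row
+/// "ahead", which is how transition constraints relate consecutive rows.)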
+#[derive(Debug, Clone)] +pub struct StarkOpeningSet, const D: usize> { + pub local_values: Vec, + pub next_values: Vec, + pub permutation_zs: Option>, + pub permutation_zs_right: Option>, + pub quotient_polys: Vec, +} + +impl, const D: usize> StarkOpeningSet { + pub fn new>( + zeta: F::Extension, + g: F, + trace_commitment: &PolynomialBatch, + permutation_zs_commitment: Option<&PolynomialBatch>, + quotient_commitment: &PolynomialBatch, + ) -> Self { + let eval_commitment = |z: F::Extension, c: &PolynomialBatch| { + c.polynomials + .par_iter() + .map(|p| p.to_extension().eval(z)) + .collect::>() + }; + let zeta_right = zeta.scalar_mul(g); + Self { + local_values: eval_commitment(zeta, trace_commitment), + next_values: eval_commitment(zeta_right, trace_commitment), + permutation_zs: permutation_zs_commitment.map(|c| eval_commitment(zeta, c)), + permutation_zs_right: permutation_zs_commitment.map(|c| eval_commitment(zeta_right, c)), + quotient_polys: eval_commitment(zeta, quotient_commitment), + } + } + + pub(crate) fn to_fri_openings(&self) -> FriOpenings { + let zeta_batch = FriOpeningBatch { + values: self + .local_values + .iter() + .chain(self.permutation_zs.iter().flatten()) + .chain(&self.quotient_polys) + .copied() + .collect_vec(), + }; + let zeta_right_batch = FriOpeningBatch { + values: self + .next_values + .iter() + .chain(self.permutation_zs_right.iter().flatten()) + .copied() + .collect_vec(), + }; + FriOpenings { + batches: vec![zeta_batch, zeta_right_batch], + } + } +} + +pub struct StarkOpeningSetTarget { + pub local_values: Vec>, + pub next_values: Vec>, + pub permutation_zs: Option>>, + pub permutation_zs_right: Option>>, + pub quotient_polys: Vec>, +} + +impl StarkOpeningSetTarget { + pub(crate) fn to_fri_openings(&self) -> FriOpeningsTarget { + let zeta_batch = FriOpeningBatchTarget { + values: self + .local_values + .iter() + .chain(self.permutation_zs.iter().flatten()) + .chain(&self.quotient_polys) + .copied() + .collect_vec(), + }; + let zeta_right_batch = FriOpeningBatchTarget { + values: self + .next_values + .iter() + .chain(self.permutation_zs_right.iter().flatten()) + .copied() + .collect_vec(), + }; + FriOpeningsTarget { + batches: vec![zeta_batch, zeta_right_batch], + } + } +} diff --git a/starky2/src/prover.rs b/starky2/src/prover.rs new file mode 100644 index 00000000..808fb50b --- /dev/null +++ b/starky2/src/prover.rs @@ -0,0 +1,414 @@ +use std::iter::once; +use std::marker::PhantomData; + +use anyhow::{ensure, Result}; +use itertools::Itertools; +use plonky2::field::extension_field::Extendable; +use plonky2::field::field_types::Field; +use plonky2::field::packable::Packable; +use plonky2::field::packed_field::PackedField; +use plonky2::field::polynomial::{PolynomialCoeffs, PolynomialValues}; +use plonky2::field::zero_poly_coset::ZeroPolyOnCoset; +use plonky2::fri::oracle::PolynomialBatch; +use plonky2::hash::hash_types::RichField; +use plonky2::iop::challenger::Challenger; +use plonky2::plonk::config::{GenericConfig, Hasher}; +use plonky2::timed; +use plonky2::util::timing::TimingTree; +use plonky2::util::transpose; +use plonky2_util::{log2_ceil, log2_strict}; +use rayon::prelude::*; + +use crate::config::StarkConfig; +use crate::constraint_consumer::ConstraintConsumer; +use crate::permutation::PermutationCheckVars; +use crate::permutation::{ + compute_permutation_z_polys, get_n_permutation_challenge_sets, PermutationChallengeSet, +}; +use crate::proof::{StarkOpeningSet, StarkProof, StarkProofWithPublicInputs}; +use crate::stark::Stark; +use 
crate::vanishing_poly::eval_vanishing_poly; +use crate::vars::StarkEvaluationVars; + +enum Table { + Cpu = 0, + Keccak = 1, +} + +struct CpuStark { + f: PhantomData, +} +struct KeccakStark { + f: PhantomData, +} + +struct AllStarks, const D: usize> { + cpu: CpuStark, + keccak: KeccakStark, +} + +struct CrossTableLookup { + looking_table: Table, + looking_columns: Vec, + looked_table: usize, + looked_columns: Vec, +} + +impl CrossTableLookup { + pub fn new( + looking_table: Table, + looking_columns: Vec, + looked_table: usize, + looked_columns: Vec, + ) -> Self { + assert_eq!(looking_columns.len(), looked_columns.len()); + Self { + looking_table: looking_table, + looking_columns: looking_columns, + looked_table: looked_table, + looked_columns: looked_columns, + } + } +} + +pub fn prove( + all_starks: AllStarks, + config: &StarkConfig, + trace_poly_values: Vec>>, + cross_table_lookups: Vec, + public_inputs: Vec>, + timing: &mut TimingTree, +) -> Result> +where + F: RichField + Extendable, + C: GenericConfig, + [(); <::Packing>::WIDTH]:, + [(); C::Hasher::HASH_SIZE]:, +{ + let num_starks = Table::Keccak as usize + 1; + debug_assert_eq!(num_starks, trace_poly_values.len()); + debug_assert_eq!(num_starks, public_inputs.len()); + + let degree = trace_poly_values[0].len(); + let degree_bits = log2_strict(degree); + let fri_params = config.fri_params(degree_bits); + let rate_bits = config.fri_config.rate_bits; + let cap_height = config.fri_config.cap_height; + assert!( + fri_params.total_arities() <= degree_bits + rate_bits - cap_height, + "FRI total reduction arity is too large.", + ); + + let trace_commitments = timed!( + timing, + "compute trace commitments", + trace_poly_values + .iter() + .map(|trace| { + PolynomialBatch::::from_values( + // TODO: Cloning this isn't great; consider having `from_values` accept a reference, + // or having `compute_permutation_z_polys` read trace values from the `PolynomialBatch`. + trace.clone(), + rate_bits, + false, + cap_height, + timing, + None, + ) + }) + .collect::>() + ); + + let trace_caps = trace_commitments + .iter() + .map(|c| c.merkle_tree.cap) + .collect::>(); + let mut challenger = Challenger::new(); + for cap in &trace_caps { + challenger.observe_cap(cap); + } + + let permutation_zs_commitment_challenges = (0..num_starks) + .map(|i| { + permutation_challenges( + all_starks.stark(i), + &trace_poly_values[i], + config, + &mut challenger, + timing, + ) + }) + .collect::>(); + + let permutation_zs_commitment = permutation_zs_commitment_challenges + .iter() + .map(|pzcc| pzcc.map(|(comm, _)| comm)) + .collect::>(); + let permutation_zs_cap = permutation_zs_commitment + .iter() + .map(|pzc| pzc.as_ref().map(|commit| commit.merkle_tree.cap.clone())) + .collect::>(); + for cap in &permutation_zs_cap { + challenger.observe_cap(cap); + } + + // let alphas = challenger.get_n_challenges(config.num_challenges); + // let quotient_polys = compute_quotient_polys::::Packing, C, S, D>( + // &stark, + // &trace_commitment, + // &permutation_zs_commitment_challenges, + // public_inputs, + // alphas, + // degree_bits, + // config, + // ); + // let all_quotient_chunks = quotient_polys + // .into_par_iter() + // .flat_map(|mut quotient_poly| { + // quotient_poly + // .trim_to_len(degree * stark.quotient_degree_factor()) + // .expect("Quotient has failed, the vanishing polynomial is not divisible by Z_H"); + // // Split quotient into degree-n chunks. 
+ // quotient_poly.chunks(degree) + // }) + // .collect(); + // let quotient_commitment = timed!( + // timing, + // "compute quotient commitment", + // PolynomialBatch::from_coeffs( + // all_quotient_chunks, + // rate_bits, + // false, + // config.fri_config.cap_height, + // timing, + // None, + // ) + // ); + // let quotient_polys_cap = quotient_commitment.merkle_tree.cap.clone(); + // challenger.observe_cap("ient_polys_cap); + // + // let zeta = challenger.get_extension_challenge::(); + // // To avoid leaking witness data, we want to ensure that our opening locations, `zeta` and + // // `g * zeta`, are not in our subgroup `H`. It suffices to check `zeta` only, since + // // `(g * zeta)^n = zeta^n`, where `n` is the order of `g`. + // let g = F::primitive_root_of_unity(degree_bits); + // ensure!( + // zeta.exp_power_of_2(degree_bits) != F::Extension::ONE, + // "Opening point is in the subgroup." + // ); + // let openings = StarkOpeningSet::new( + // zeta, + // g, + // &trace_commitment, + // permutation_zs_commitment, + // "ient_commitment, + // ); + // challenger.observe_openings(&openings.to_fri_openings()); + // + // let initial_merkle_trees = once(&trace_commitment) + // .chain(permutation_zs_commitment) + // .chain(once("ient_commitment)) + // .collect_vec(); + // + // let opening_proof = timed!( + // timing, + // "compute openings proof", + // PolynomialBatch::prove_openings( + // &stark.fri_instance(zeta, g, config), + // &initial_merkle_trees, + // &mut challenger, + // &fri_params, + // timing, + // ) + // ); + // let proof = StarkProof { + // trace_cap, + // permutation_zs_cap, + // quotient_polys_cap, + // openings, + // opening_proof, + // }; + // + // Ok(StarkProofWithPublicInputs { + // proof, + // public_inputs: public_inputs.to_vec(), + // }) + todo!() +} + +fn add_cross_table_lookup_columns( + config: &StarkConfig, + trace_poly_values: Vec>>, + cross_table_lookups: Vec, +) { + for cross_table_lookup in cross_table_lookups { + let CrossTableLookup { + looking_table: source_table, + looking_columns: source_columns, + looked_table: target_table, + looked_columns: target_columns, + } = cross_table_lookup; + } +} + +fn permutation_challenges<'a, F, P, C, S, const D: usize>( + stark: &S, + trace_poly_values: &[PolynomialValues], + config: &StarkConfig, + challenger: &mut Challenger, + timing: &mut TimingTree, +) -> Option<(PolynomialBatch, Vec>)> +where + F: RichField + Extendable, + P: PackedField, + C: GenericConfig, + S: Stark, + [(); S::COLUMNS]:, + [(); S::PUBLIC_INPUTS]:, + [(); P::WIDTH]:, +{ + // Permutation arguments. + stark.uses_permutation_args().then(|| { + let permutation_challenge_sets = get_n_permutation_challenge_sets( + challenger, + config.num_challenges, + stark.permutation_batch_size(), + ); + let permutation_z_polys = compute_permutation_z_polys::( + &stark, + config, + &trace_poly_values, + &permutation_challenge_sets, + ); + + let permutation_zs_commitment = timed!( + timing, + "compute permutation Z commitments", + PolynomialBatch::from_values( + permutation_z_polys, + rate_bits, + false, + config.fri_config.cap_height, + timing, + None, + ) + ); + (permutation_zs_commitment, permutation_challenge_sets) + }) +} + +/// Computes the quotient polynomials `(sum alpha^i C_i(x)) / Z_H(x)` for `alpha` in `alphas`, +/// where the `C_i`s are the Stark constraints. 
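+/// (`Z_H(x) = x^n - 1` vanishes on the size-`n` trace subgroup `H`, so the division
+/// is exact precisely when the combined constraint vanishes on all of `H`. The
+/// evaluations are taken on a coset of a larger subgroup, where `Z_H` is nonzero.)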
+fn compute_quotient_polys<'a, F, P, C, S, const D: usize>( + stark: &S, + trace_commitment: &'a PolynomialBatch, + permutation_zs_commitment_challenges: &'a Option<( + PolynomialBatch, + Vec>, + )>, + public_inputs: [F; S::PUBLIC_INPUTS], + alphas: Vec, + degree_bits: usize, + config: &StarkConfig, +) -> Vec> +where + F: RichField + Extendable, + P: PackedField, + C: GenericConfig, + S: Stark, + [(); S::COLUMNS]:, + [(); S::PUBLIC_INPUTS]:, + [(); P::WIDTH]:, +{ + let degree = 1 << degree_bits; + let rate_bits = config.fri_config.rate_bits; + + let quotient_degree_bits = log2_ceil(stark.quotient_degree_factor()); + assert!( + quotient_degree_bits <= rate_bits, + "Having constraints of degree higher than the rate is not supported yet." + ); + let step = 1 << (rate_bits - quotient_degree_bits); + // When opening the `Z`s polys at the "next" point, need to look at the point `next_step` steps away. + let next_step = 1 << quotient_degree_bits; + + // Evaluation of the first Lagrange polynomial on the LDE domain. + let lagrange_first = PolynomialValues::selector(degree, 0).lde_onto_coset(quotient_degree_bits); + // Evaluation of the last Lagrange polynomial on the LDE domain. + let lagrange_last = + PolynomialValues::selector(degree, degree - 1).lde_onto_coset(quotient_degree_bits); + + let z_h_on_coset = ZeroPolyOnCoset::::new(degree_bits, quotient_degree_bits); + + // Retrieve the LDE values at index `i`. + let get_trace_values_packed = |i_start| -> [P; S::COLUMNS] { + trace_commitment + .get_lde_values_packed(i_start, step) + .try_into() + .unwrap() + }; + + // Last element of the subgroup. + let last = F::primitive_root_of_unity(degree_bits).inverse(); + let size = degree << quotient_degree_bits; + let coset = F::cyclic_subgroup_coset_known_order( + F::primitive_root_of_unity(degree_bits + quotient_degree_bits), + F::coset_shift(), + size, + ); + + // We will step by `P::WIDTH`, and in each iteration, evaluate the quotient polynomial at + // a batch of `P::WIDTH` points. + let quotient_values = (0..size) + .into_par_iter() + .step_by(P::WIDTH) + .map(|i_start| { + let i_next_start = (i_start + next_step) % size; + let i_range = i_start..i_start + P::WIDTH; + + let x = *P::from_slice(&coset[i_range.clone()]); + let z_last = x - last; + let lagrange_basis_first = *P::from_slice(&lagrange_first.values[i_range.clone()]); + let lagrange_basis_last = *P::from_slice(&lagrange_last.values[i_range]); + + let mut consumer = ConstraintConsumer::new( + alphas.clone(), + z_last, + lagrange_basis_first, + lagrange_basis_last, + ); + let vars = StarkEvaluationVars { + local_values: &get_trace_values_packed(i_start), + next_values: &get_trace_values_packed(i_next_start), + public_inputs: &public_inputs, + }; + let permutation_check_data = permutation_zs_commitment_challenges.as_ref().map( + |(permutation_zs_commitment, permutation_challenge_sets)| PermutationCheckVars { + local_zs: permutation_zs_commitment.get_lde_values_packed(i_start, step), + next_zs: permutation_zs_commitment.get_lde_values_packed(i_next_start, step), + permutation_challenge_sets: permutation_challenge_sets.to_vec(), + }, + ); + eval_vanishing_poly::( + stark, + config, + vars, + permutation_check_data, + &mut consumer, + ); + let mut constraints_evals = consumer.accumulators(); + // We divide the constraints evaluations by `Z_H(x)`. 
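+            // (`Z_H` depends only on `x^n`, which takes just `2^quotient_degree_bits`
+            // distinct values on the coset, so `ZeroPolyOnCoset` can precompute the
+            // needed inverses cheaply.)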
+            let denominator_inv = z_h_on_coset.eval_inverse_packed(i_start);
+            for eval in &mut constraints_evals {
+                *eval *= denominator_inv;
+            }
+            constraints_evals
+        })
+        .collect::<Vec<_>>();
+
+    transpose(&quotient_values)
+        .into_par_iter()
+        .map(PolynomialValues::new)
+        .map(|values| values.coset_ifft(F::coset_shift()))
+        .collect()
+}
diff --git a/starky2/src/recursive_verifier.rs b/starky2/src/recursive_verifier.rs
new file mode 100644
index 00000000..e091d64c
--- /dev/null
+++ b/starky2/src/recursive_verifier.rs
@@ -0,0 +1,333 @@
+use std::iter::once;
+
+use anyhow::{ensure, Result};
+use itertools::Itertools;
+use plonky2::field::extension_field::Extendable;
+use plonky2::field::field_types::Field;
+use plonky2::fri::witness_util::set_fri_proof_target;
+use plonky2::hash::hash_types::RichField;
+use plonky2::iop::ext_target::ExtensionTarget;
+use plonky2::iop::witness::Witness;
+use plonky2::plonk::circuit_builder::CircuitBuilder;
+use plonky2::plonk::config::{AlgebraicHasher, GenericConfig};
+use plonky2::util::reducing::ReducingFactorTarget;
+use plonky2::with_context;
+
+use crate::config::StarkConfig;
+use crate::constraint_consumer::RecursiveConstraintConsumer;
+use crate::permutation::PermutationCheckDataTarget;
+use crate::proof::{
+    StarkOpeningSetTarget, StarkProof, StarkProofChallengesTarget, StarkProofTarget,
+    StarkProofWithPublicInputs, StarkProofWithPublicInputsTarget,
+};
+use crate::stark::Stark;
+use crate::vanishing_poly::eval_vanishing_poly_recursively;
+use crate::vars::StarkEvaluationTargets;
+
+pub fn recursively_verify_stark_proof<
+    F: RichField + Extendable<D>,
+    C: GenericConfig<D, F = F>,
+    S: Stark<F, D>,
+    const D: usize,
+>(
+    builder: &mut CircuitBuilder<F, D>,
+    stark: S,
+    proof_with_pis: StarkProofWithPublicInputsTarget<D>,
+    inner_config: &StarkConfig,
+) where
+    C::Hasher: AlgebraicHasher<F>,
+    [(); S::COLUMNS]:,
+    [(); S::PUBLIC_INPUTS]:,
+{
+    assert_eq!(proof_with_pis.public_inputs.len(), S::PUBLIC_INPUTS);
+    let degree_bits = proof_with_pis.proof.recover_degree_bits(inner_config);
+    let challenges = with_context!(
+        builder,
+        "compute challenges",
+        proof_with_pis.get_challenges::<F, C, S>(builder, &stark, inner_config)
+    );
+
+    recursively_verify_stark_proof_with_challenges::<F, C, S, D>(
+        builder,
+        stark,
+        proof_with_pis,
+        challenges,
+        inner_config,
+        degree_bits,
+    );
+}
+
+/// Recursively verifies an inner proof.
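+/// Mirrors `verify_stark_proof_with_challenges` in `verifier.rs`, but instead of checking field
+/// equalities directly, it emits the same checks as circuit constraints (e.g. `connect_extension`
+/// in place of an `ensure!`).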
+fn recursively_verify_stark_proof_with_challenges<
+    F: RichField + Extendable<D>,
+    C: GenericConfig<D, F = F>,
+    S: Stark<F, D>,
+    const D: usize,
+>(
+    builder: &mut CircuitBuilder<F, D>,
+    stark: S,
+    proof_with_pis: StarkProofWithPublicInputsTarget<D>,
+    challenges: StarkProofChallengesTarget<D>,
+    inner_config: &StarkConfig,
+    degree_bits: usize,
+) where
+    C::Hasher: AlgebraicHasher<F>,
+    [(); S::COLUMNS]:,
+    [(); S::PUBLIC_INPUTS]:,
+{
+    check_permutation_options(&stark, &proof_with_pis, &challenges).unwrap();
+    let one = builder.one_extension();
+
+    let StarkProofWithPublicInputsTarget {
+        proof,
+        public_inputs,
+    } = proof_with_pis;
+    let StarkOpeningSetTarget {
+        local_values,
+        next_values,
+        permutation_zs,
+        permutation_zs_right,
+        quotient_polys,
+    } = &proof.openings;
+    let vars = StarkEvaluationTargets {
+        local_values: &local_values.to_vec().try_into().unwrap(),
+        next_values: &next_values.to_vec().try_into().unwrap(),
+        public_inputs: &public_inputs
+            .into_iter()
+            .map(|t| builder.convert_to_ext(t))
+            .collect::<Vec<_>>()
+            .try_into()
+            .unwrap(),
+    };
+
+    let zeta_pow_deg = builder.exp_power_of_2_extension(challenges.stark_zeta, degree_bits);
+    let z_h_zeta = builder.sub_extension(zeta_pow_deg, one);
+    let (l_1, l_last) =
+        eval_l_1_and_l_last_recursively(builder, degree_bits, challenges.stark_zeta, z_h_zeta);
+    let last =
+        builder.constant_extension(F::Extension::primitive_root_of_unity(degree_bits).inverse());
+    let z_last = builder.sub_extension(challenges.stark_zeta, last);
+
+    let mut consumer = RecursiveConstraintConsumer::<F, D>::new(
+        builder.zero_extension(),
+        challenges.stark_alphas,
+        z_last,
+        l_1,
+        l_last,
+    );
+
+    let permutation_data = stark
+        .uses_permutation_args()
+        .then(|| PermutationCheckDataTarget {
+            local_zs: permutation_zs.as_ref().unwrap().clone(),
+            next_zs: permutation_zs_right.as_ref().unwrap().clone(),
+            permutation_challenge_sets: challenges.permutation_challenge_sets.unwrap(),
+        });
+
+    with_context!(
+        builder,
+        "evaluate vanishing polynomial",
+        eval_vanishing_poly_recursively::<F, C, S, D>(
+            builder,
+            &stark,
+            inner_config,
+            vars,
+            permutation_data,
+            &mut consumer,
+        )
+    );
+    let vanishing_polys_zeta = consumer.accumulators();
+
+    // Check each polynomial identity, of the form `vanishing(x) = Z_H(x) quotient(x)`, at zeta.
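+    // As in the native verifier: each chunk holds the openings `t_0(zeta), ..., t_{k-1}(zeta)` of a
+    // split quotient `t(X) = t_0(X) + X^n t_1(X) + X^{2n} t_2(X) + ...`, so `t(zeta)` is recombined
+    // by reducing the chunk with powers of `zeta^n = zeta_pow_deg`.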
+    let mut scale = ReducingFactorTarget::new(zeta_pow_deg);
+    for (i, chunk) in quotient_polys
+        .chunks(stark.quotient_degree_factor())
+        .enumerate()
+    {
+        let recombined_quotient = scale.reduce(chunk, builder);
+        let computed_vanishing_poly = builder.mul_extension(z_h_zeta, recombined_quotient);
+        builder.connect_extension(vanishing_polys_zeta[i], computed_vanishing_poly);
+    }
+
+    let merkle_caps = once(proof.trace_cap)
+        .chain(proof.permutation_zs_cap)
+        .chain(once(proof.quotient_polys_cap))
+        .collect_vec();
+
+    let fri_instance = stark.fri_instance_target(
+        builder,
+        challenges.stark_zeta,
+        F::primitive_root_of_unity(degree_bits),
+        inner_config,
+    );
+    builder.verify_fri_proof::<C>(
+        &fri_instance,
+        &proof.openings.to_fri_openings(),
+        &challenges.fri_challenges,
+        &merkle_caps,
+        &proof.opening_proof,
+        &inner_config.fri_params(degree_bits),
+    );
+}
+
+fn eval_l_1_and_l_last_recursively<F: RichField + Extendable<D>, const D: usize>(
+    builder: &mut CircuitBuilder<F, D>,
+    log_n: usize,
+    x: ExtensionTarget<D>,
+    z_x: ExtensionTarget<D>,
+) -> (ExtensionTarget<D>, ExtensionTarget<D>) {
+    let n = builder.constant_extension(F::Extension::from_canonical_usize(1 << log_n));
+    let g = builder.constant_extension(F::Extension::primitive_root_of_unity(log_n));
+    let one = builder.one_extension();
+    let l_1_deno = builder.mul_sub_extension(n, x, n);
+    let l_last_deno = builder.mul_sub_extension(g, x, one);
+    let l_last_deno = builder.mul_extension(n, l_last_deno);
+
+    (
+        builder.div_extension(z_x, l_1_deno),
+        builder.div_extension(z_x, l_last_deno),
+    )
+}
+
+pub fn add_virtual_stark_proof_with_pis<
+    F: RichField + Extendable<D>,
+    S: Stark<F, D>,
+    const D: usize,
+>(
+    builder: &mut CircuitBuilder<F, D>,
+    stark: S,
+    config: &StarkConfig,
+    degree_bits: usize,
+) -> StarkProofWithPublicInputsTarget<D> {
+    let proof = add_virtual_stark_proof::<F, S, D>(builder, stark, config, degree_bits);
+    let public_inputs = builder.add_virtual_targets(S::PUBLIC_INPUTS);
+    StarkProofWithPublicInputsTarget {
+        proof,
+        public_inputs,
+    }
+}
+
+pub fn add_virtual_stark_proof<F: RichField + Extendable<D>, S: Stark<F, D>, const D: usize>(
+    builder: &mut CircuitBuilder<F, D>,
+    stark: S,
+    config: &StarkConfig,
+    degree_bits: usize,
+) -> StarkProofTarget<D> {
+    let fri_params = config.fri_params(degree_bits);
+    let cap_height = fri_params.config.cap_height;
+
+    let num_leaves_per_oracle = once(S::COLUMNS)
+        .chain(
+            stark
+                .uses_permutation_args()
+                .then(|| stark.num_permutation_batches(config)),
+        )
+        .chain(once(stark.quotient_degree_factor() * config.num_challenges))
+        .collect_vec();
+
+    let permutation_zs_cap = stark
+        .uses_permutation_args()
+        .then(|| builder.add_virtual_cap(cap_height));
+
+    StarkProofTarget {
+        trace_cap: builder.add_virtual_cap(cap_height),
+        permutation_zs_cap,
+        quotient_polys_cap: builder.add_virtual_cap(cap_height),
+        openings: add_stark_opening_set::<F, S, D>(builder, stark, config),
+        opening_proof: builder.add_virtual_fri_proof(&num_leaves_per_oracle, &fri_params),
+    }
+}
+
+fn add_stark_opening_set<F: RichField + Extendable<D>, S: Stark<F, D>, const D: usize>(
+    builder: &mut CircuitBuilder<F, D>,
+    stark: S,
+    config: &StarkConfig,
+) -> StarkOpeningSetTarget<D> {
+    let num_challenges = config.num_challenges;
+    StarkOpeningSetTarget {
+        local_values: builder.add_virtual_extension_targets(S::COLUMNS),
+        next_values: builder.add_virtual_extension_targets(S::COLUMNS),
+        permutation_zs: stark
+            .uses_permutation_args()
+            .then(|| builder.add_virtual_extension_targets(stark.num_permutation_batches(config))),
+        permutation_zs_right: stark
+            .uses_permutation_args()
+            .then(|| builder.add_virtual_extension_targets(stark.num_permutation_batches(config))),
+        quotient_polys: builder
+            .add_virtual_extension_targets(stark.quotient_degree_factor() * num_challenges),
+    }
+}
+
+pub fn set_stark_proof_with_pis_target<F, C: GenericConfig<D, F = F>, W, const D: usize>(
+    witness: &mut W,
+    stark_proof_with_pis_target: &StarkProofWithPublicInputsTarget<D>,
+    stark_proof_with_pis: &StarkProofWithPublicInputs<F, C, D>,
+) where
+    F: RichField + Extendable<D>,
+    C::Hasher: AlgebraicHasher<F>,
+    W: Witness<F>,
+{
+    let StarkProofWithPublicInputs {
+        proof,
+        public_inputs,
+    } = stark_proof_with_pis;
+    let StarkProofWithPublicInputsTarget {
+        proof: pt,
+        public_inputs: pi_targets,
+    } = stark_proof_with_pis_target;
+
+    // Set public inputs.
+    for (&pi_t, &pi) in pi_targets.iter().zip_eq(public_inputs) {
+        witness.set_target(pi_t, pi);
+    }
+
+    set_stark_proof_target(witness, pt, proof);
+}
+
+pub fn set_stark_proof_target<F, C: GenericConfig<D, F = F>, W, const D: usize>(
+    witness: &mut W,
+    proof_target: &StarkProofTarget<D>,
+    proof: &StarkProof<F, C, D>,
+) where
+    F: RichField + Extendable<D>,
+    C::Hasher: AlgebraicHasher<F>,
+    W: Witness<F>,
+{
+    witness.set_cap_target(&proof_target.trace_cap, &proof.trace_cap);
+    witness.set_cap_target(&proof_target.quotient_polys_cap, &proof.quotient_polys_cap);
+
+    witness.set_fri_openings(
+        &proof_target.openings.to_fri_openings(),
+        &proof.openings.to_fri_openings(),
+    );
+
+    if let (Some(permutation_zs_cap_target), Some(permutation_zs_cap)) =
+        (&proof_target.permutation_zs_cap, &proof.permutation_zs_cap)
+    {
+        witness.set_cap_target(permutation_zs_cap_target, permutation_zs_cap);
+    }
+
+    set_fri_proof_target(witness, &proof_target.opening_proof, &proof.opening_proof);
+}
+
+/// Utility function to check that all permutation data wrapped in `Option`s are `Some` iff
+/// the Stark uses a permutation argument.
+fn check_permutation_options<F: RichField + Extendable<D>, S: Stark<F, D>, const D: usize>(
+    stark: &S,
+    proof_with_pis: &StarkProofWithPublicInputsTarget<D>,
+    challenges: &StarkProofChallengesTarget<D>,
+) -> Result<()> {
+    let options_is_some = [
+        proof_with_pis.proof.permutation_zs_cap.is_some(),
+        proof_with_pis.proof.openings.permutation_zs.is_some(),
+        proof_with_pis.proof.openings.permutation_zs_right.is_some(),
+        challenges.permutation_challenge_sets.is_some(),
+    ];
+    ensure!(
+        options_is_some
+            .into_iter()
+            .all(|b| b == stark.uses_permutation_args()),
+        "Permutation data doesn't match with Stark configuration."
+    );
+    Ok(())
+}
diff --git a/starky2/src/stark.rs b/starky2/src/stark.rs
new file mode 100644
index 00000000..a130d283
--- /dev/null
+++ b/starky2/src/stark.rs
@@ -0,0 +1,204 @@
+use plonky2::field::extension_field::{Extendable, FieldExtension};
+use plonky2::field::packed_field::PackedField;
+use plonky2::fri::structure::{
+    FriBatchInfo, FriBatchInfoTarget, FriInstanceInfo, FriInstanceInfoTarget, FriOracleInfo,
+    FriPolynomialInfo,
+};
+use plonky2::hash::hash_types::RichField;
+use plonky2::iop::ext_target::ExtensionTarget;
+use plonky2::plonk::circuit_builder::CircuitBuilder;
+use plonky2_util::ceil_div_usize;
+
+use crate::config::StarkConfig;
+use crate::constraint_consumer::{ConstraintConsumer, RecursiveConstraintConsumer};
+use crate::permutation::PermutationPair;
+use crate::vars::StarkEvaluationTargets;
+use crate::vars::StarkEvaluationVars;
+
+/// Represents a STARK system.
+pub trait Stark<F: RichField + Extendable<D>, const D: usize>: Sync {
+    /// The total number of columns in the trace.
+    const COLUMNS: usize;
+    /// The number of public inputs.
+    const PUBLIC_INPUTS: usize;
+
+    /// Evaluate constraints at a vector of points.
+    ///
+    /// The points are elements of a field `FE`, a degree `D2` extension of `F`. This lets us
+    /// evaluate constraints over a larger domain if desired. This can also be called with `FE = F`
+    /// and `D2 = 1`, in which case we are using the trivial extension, i.e. just evaluating
+    /// constraints over `F`.
+    fn eval_packed_generic<FE, P, const D2: usize>(
+        &self,
+        vars: StarkEvaluationVars<FE, P, { Self::COLUMNS }, { Self::PUBLIC_INPUTS }>,
+        yield_constr: &mut ConstraintConsumer<P>,
+    ) where
+        FE: FieldExtension<D2, BaseField = F>,
+        P: PackedField<Scalar = FE>;
+
+    /// Evaluate constraints at a vector of points from the base field `F`.
+    fn eval_packed_base<P: PackedField<Scalar = F>>(
+        &self,
+        vars: StarkEvaluationVars<F, P, { Self::COLUMNS }, { Self::PUBLIC_INPUTS }>,
+        yield_constr: &mut ConstraintConsumer<P>,
+    ) {
+        self.eval_packed_generic(vars, yield_constr)
+    }
+
+    /// Evaluate constraints at a single point from the degree `D` extension field.
+    fn eval_ext(
+        &self,
+        vars: StarkEvaluationVars<
+            F::Extension,
+            F::Extension,
+            { Self::COLUMNS },
+            { Self::PUBLIC_INPUTS },
+        >,
+        yield_constr: &mut ConstraintConsumer<F::Extension>,
+    ) {
+        self.eval_packed_generic(vars, yield_constr)
+    }
+
+    /// Evaluate constraints at a vector of points from the degree `D` extension field. This is like
+    /// `eval_ext`, except in the context of a recursive circuit.
+    /// Note: constraints must be added through `yield_constr.constraint(builder, constraint)` in the
+    /// same order as they are given in `eval_packed_generic`.
+    fn eval_ext_recursively(
+        &self,
+        builder: &mut CircuitBuilder<F, D>,
+        vars: StarkEvaluationTargets<D, { Self::COLUMNS }, { Self::PUBLIC_INPUTS }>,
+        yield_constr: &mut RecursiveConstraintConsumer<F, D>,
+    );
+
+    /// The maximum constraint degree.
+    fn constraint_degree(&self) -> usize;
+
+    /// The degree of the quotient polynomial as a multiple of the trace degree, i.e.
+    /// `constraint_degree - 1` (but at least 1).
+    fn quotient_degree_factor(&self) -> usize {
+        1.max(self.constraint_degree() - 1)
+    }
+
+    /// Computes the FRI instance used to prove this Stark.
+    fn fri_instance(
+        &self,
+        zeta: F::Extension,
+        g: F,
+        config: &StarkConfig,
+    ) -> FriInstanceInfo<F, D> {
+        let no_blinding_oracle = FriOracleInfo { blinding: false };
+        let mut oracle_indices = 0..;
+
+        let trace_info =
+            FriPolynomialInfo::from_range(oracle_indices.next().unwrap(), 0..Self::COLUMNS);
+
+        let permutation_zs_info = if self.uses_permutation_args() {
+            FriPolynomialInfo::from_range(
+                oracle_indices.next().unwrap(),
+                0..self.num_permutation_batches(config),
+            )
+        } else {
+            vec![]
+        };
+
+        let quotient_info = FriPolynomialInfo::from_range(
+            oracle_indices.next().unwrap(),
+            0..self.quotient_degree_factor() * config.num_challenges,
+        );
+
+        let zeta_batch = FriBatchInfo {
+            point: zeta,
+            polynomials: [
+                trace_info.clone(),
+                permutation_zs_info.clone(),
+                quotient_info,
+            ]
+            .concat(),
+        };
+        let zeta_right_batch = FriBatchInfo {
+            point: zeta.scalar_mul(g),
+            polynomials: [trace_info, permutation_zs_info].concat(),
+        };
+        FriInstanceInfo {
+            oracles: vec![no_blinding_oracle; oracle_indices.next().unwrap()],
+            batches: vec![zeta_batch, zeta_right_batch],
+        }
+    }
+
+    /// Computes the FRI instance used to recursively verify this Stark.
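+    ///
+    /// This is the in-circuit analogue of `fri_instance`: the same two batches are opened, one at
+    /// `zeta` (trace, permutation `Z`s and quotient polynomials) and one at `g * zeta` (trace and
+    /// permutation `Z`s only).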
+    fn fri_instance_target(
+        &self,
+        builder: &mut CircuitBuilder<F, D>,
+        zeta: ExtensionTarget<D>,
+        g: F,
+        config: &StarkConfig,
+    ) -> FriInstanceInfoTarget<D> {
+        let no_blinding_oracle = FriOracleInfo { blinding: false };
+        let mut oracle_indices = 0..;
+
+        let trace_info =
+            FriPolynomialInfo::from_range(oracle_indices.next().unwrap(), 0..Self::COLUMNS);
+
+        let permutation_zs_info = if self.uses_permutation_args() {
+            FriPolynomialInfo::from_range(
+                oracle_indices.next().unwrap(),
+                0..self.num_permutation_batches(config),
+            )
+        } else {
+            vec![]
+        };
+
+        let quotient_info = FriPolynomialInfo::from_range(
+            oracle_indices.next().unwrap(),
+            0..self.quotient_degree_factor() * config.num_challenges,
+        );
+
+        let zeta_batch = FriBatchInfoTarget {
+            point: zeta,
+            polynomials: [
+                trace_info.clone(),
+                permutation_zs_info.clone(),
+                quotient_info,
+            ]
+            .concat(),
+        };
+        let zeta_right = builder.mul_const_extension(g, zeta);
+        let zeta_right_batch = FriBatchInfoTarget {
+            point: zeta_right,
+            polynomials: [trace_info, permutation_zs_info].concat(),
+        };
+        FriInstanceInfoTarget {
+            oracles: vec![no_blinding_oracle; oracle_indices.next().unwrap()],
+            batches: vec![zeta_batch, zeta_right_batch],
+        }
+    }
+
+    /// Pairs of lists of columns that should be permutations of one another. A permutation argument
+    /// will be used for each such pair. Empty by default.
+    fn permutation_pairs(&self) -> Vec<PermutationPair> {
+        vec![]
+    }
+
+    fn uses_permutation_args(&self) -> bool {
+        !self.permutation_pairs().is_empty()
+    }
+
+    /// The number of permutation argument instances that can be combined into a single constraint.
+    fn permutation_batch_size(&self) -> usize {
+        // The permutation argument constraints look like
+        //     Z(x) \prod(...) = Z(g x) \prod(...)
+        // where each product has a number of terms equal to the batch size. So our batch size
+        // should be one less than our constraint degree, which happens to be our quotient degree.
+        self.quotient_degree_factor()
+    }
+
+    fn num_permutation_instances(&self, config: &StarkConfig) -> usize {
+        self.permutation_pairs().len() * config.num_challenges
+    }
+
+    fn num_permutation_batches(&self, config: &StarkConfig) -> usize {
+        ceil_div_usize(
+            self.num_permutation_instances(config),
+            self.permutation_batch_size(),
+        )
+    }
+}
diff --git a/starky2/src/stark_testing.rs b/starky2/src/stark_testing.rs
new file mode 100644
index 00000000..222ebf39
--- /dev/null
+++ b/starky2/src/stark_testing.rs
@@ -0,0 +1,87 @@
+use anyhow::{ensure, Result};
+use plonky2::field::extension_field::Extendable;
+use plonky2::field::field_types::Field;
+use plonky2::field::polynomial::{PolynomialCoeffs, PolynomialValues};
+use plonky2::hash::hash_types::RichField;
+use plonky2::util::transpose;
+use plonky2_util::{log2_ceil, log2_strict};
+
+use crate::constraint_consumer::ConstraintConsumer;
+use crate::stark::Stark;
+use crate::vars::StarkEvaluationVars;
+
+const WITNESS_SIZE: usize = 1 << 5;
+
+/// Tests that the constraints imposed by the given STARK are low-degree by applying them to random
+/// low-degree witness polynomials.
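+///
+/// If every constraint has degree at most `constraint_degree`, then applying it to columns of
+/// degree less than `WITNESS_SIZE` yields a composition of degree at most
+/// `WITNESS_SIZE * constraint_degree - 1`, which is the bound checked below; a higher observed
+/// degree means some constraint exceeds the declared degree.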
+pub fn test_stark_low_degree<F: RichField + Extendable<D>, S: Stark<F, D>, const D: usize>(
+    stark: S,
+) -> Result<()>
+where
+    [(); S::COLUMNS]:,
+    [(); S::PUBLIC_INPUTS]:,
+{
+    let rate_bits = log2_ceil(stark.constraint_degree() + 1);
+
+    let trace_ldes = random_low_degree_matrix::<F>(S::COLUMNS, rate_bits);
+    let size = trace_ldes.len();
+    let public_inputs = F::rand_arr::<{ S::PUBLIC_INPUTS }>();
+
+    let lagrange_first = PolynomialValues::selector(WITNESS_SIZE, 0).lde(rate_bits);
+    let lagrange_last = PolynomialValues::selector(WITNESS_SIZE, WITNESS_SIZE - 1).lde(rate_bits);
+
+    let last = F::primitive_root_of_unity(log2_strict(WITNESS_SIZE)).inverse();
+    let subgroup =
+        F::cyclic_subgroup_known_order(F::primitive_root_of_unity(log2_strict(size)), size);
+    let alpha = F::rand();
+    let constraint_evals = (0..size)
+        .map(|i| {
+            let vars = StarkEvaluationVars {
+                local_values: &trace_ldes[i].clone().try_into().unwrap(),
+                next_values: &trace_ldes[(i + (1 << rate_bits)) % size]
+                    .clone()
+                    .try_into()
+                    .unwrap(),
+                public_inputs: &public_inputs,
+            };
+
+            let mut consumer = ConstraintConsumer::<F>::new(
+                vec![alpha],
+                subgroup[i] - last,
+                lagrange_first.values[i],
+                lagrange_last.values[i],
+            );
+            stark.eval_packed_base(vars, &mut consumer);
+            consumer.accumulators()[0]
+        })
+        .collect::<Vec<_>>();
+
+    let constraint_eval_degree = PolynomialValues::new(constraint_evals).degree();
+    let maximum_degree = WITNESS_SIZE * stark.constraint_degree() - 1;
+
+    ensure!(
+        constraint_eval_degree <= maximum_degree,
+        "Expected degrees at most {} * {} - 1 = {}, actual {:?}",
+        WITNESS_SIZE,
+        stark.constraint_degree(),
+        maximum_degree,
+        constraint_eval_degree
+    );
+
+    Ok(())
+}
+
+fn random_low_degree_matrix<F: Field>(num_polys: usize, rate_bits: usize) -> Vec<Vec<F>> {
+    let polys = (0..num_polys)
+        .map(|_| random_low_degree_values(rate_bits))
+        .collect::<Vec<_>>();
+
+    transpose(&polys)
+}
+
+fn random_low_degree_values<F: Field>(rate_bits: usize) -> Vec<F> {
+    PolynomialCoeffs::new(F::rand_vec(WITNESS_SIZE))
+        .lde(rate_bits)
+        .fft()
+        .values
+}
diff --git a/starky2/src/util.rs b/starky2/src/util.rs
new file mode 100644
index 00000000..011a1add
--- /dev/null
+++ b/starky2/src/util.rs
@@ -0,0 +1,16 @@
+use itertools::Itertools;
+use plonky2::field::field_types::Field;
+use plonky2::field::polynomial::PolynomialValues;
+use plonky2::util::transpose;
+
+/// A helper function to transpose a row-wise trace and put it in the format that `prove` expects.
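+///
+/// A minimal usage sketch (the 2-column trace below is hypothetical):
+///
+/// ```ignore
+/// type F = GoldilocksField;
+/// let rows: Vec<[F; 2]> = (0..8)
+///     .map(|i| [F::from_canonical_usize(i), F::from_canonical_usize(i * i)])
+///     .collect();
+/// let columns = trace_rows_to_poly_values(rows); // two `PolynomialValues<F>`, each of length 8
+/// ```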
+pub fn trace_rows_to_poly_values<F: Field, const COLUMNS: usize>(
+    trace_rows: Vec<[F; COLUMNS]>,
+) -> Vec<PolynomialValues<F>> {
+    let trace_row_vecs = trace_rows.into_iter().map(|row| row.to_vec()).collect_vec();
+    let trace_col_vecs: Vec<Vec<F>> = transpose(&trace_row_vecs);
+    trace_col_vecs
+        .into_iter()
+        .map(PolynomialValues::new)
+        .collect()
+}
diff --git a/starky2/src/vanishing_poly.rs b/starky2/src/vanishing_poly.rs
new file mode 100644
index 00000000..dc32b800
--- /dev/null
+++ b/starky2/src/vanishing_poly.rs
@@ -0,0 +1,68 @@
+use plonky2::field::extension_field::{Extendable, FieldExtension};
+use plonky2::field::packed_field::PackedField;
+use plonky2::hash::hash_types::RichField;
+use plonky2::plonk::circuit_builder::CircuitBuilder;
+use plonky2::plonk::config::GenericConfig;
+
+use crate::config::StarkConfig;
+use crate::constraint_consumer::{ConstraintConsumer, RecursiveConstraintConsumer};
+use crate::permutation::{
+    eval_permutation_checks, eval_permutation_checks_recursively, PermutationCheckDataTarget,
+    PermutationCheckVars,
+};
+use crate::stark::Stark;
+use crate::vars::{StarkEvaluationTargets, StarkEvaluationVars};
+
+pub(crate) fn eval_vanishing_poly<F, FE, P, C, S, const D: usize, const D2: usize>(
+    stark: &S,
+    config: &StarkConfig,
+    vars: StarkEvaluationVars<FE, P, { S::COLUMNS }, { S::PUBLIC_INPUTS }>,
+    permutation_data: Option<PermutationCheckVars<F, FE, P, D2>>,
+    consumer: &mut ConstraintConsumer<P>,
+) where
+    F: RichField + Extendable<D>,
+    FE: FieldExtension<D2, BaseField = F>,
+    P: PackedField<Scalar = FE>,
+    C: GenericConfig<D, F = F>,
+    S: Stark<F, D>,
+    [(); S::COLUMNS]:,
+    [(); S::PUBLIC_INPUTS]:,
+{
+    stark.eval_packed_generic(vars, consumer);
+    if let Some(permutation_data) = permutation_data {
+        eval_permutation_checks::<F, FE, P, C, S, D, D2>(
+            stark,
+            config,
+            vars,
+            permutation_data,
+            consumer,
+        );
+    }
+}
+
+pub(crate) fn eval_vanishing_poly_recursively<F, C, S, const D: usize>(
+    builder: &mut CircuitBuilder<F, D>,
+    stark: &S,
+    config: &StarkConfig,
+    vars: StarkEvaluationTargets<D, { S::COLUMNS }, { S::PUBLIC_INPUTS }>,
+    permutation_data: Option<PermutationCheckDataTarget<D>>,
+    consumer: &mut RecursiveConstraintConsumer<F, D>,
+) where
+    F: RichField + Extendable<D>,
+    C: GenericConfig<D, F = F>,
+    S: Stark<F, D>,
+    [(); S::COLUMNS]:,
+    [(); S::PUBLIC_INPUTS]:,
+{
+    stark.eval_ext_recursively(builder, vars, consumer);
+    if let Some(permutation_data) = permutation_data {
+        eval_permutation_checks_recursively::<F, C, S, D>(
+            builder,
+            stark,
+            config,
+            vars,
+            permutation_data,
+            consumer,
+        );
+    }
+}
diff --git a/starky2/src/vars.rs b/starky2/src/vars.rs
new file mode 100644
index 00000000..cb83aeb7
--- /dev/null
+++ b/starky2/src/vars.rs
@@ -0,0 +1,26 @@
+use plonky2::field::field_types::Field;
+use plonky2::field::packed_field::PackedField;
+use plonky2::iop::ext_target::ExtensionTarget;
+
+#[derive(Debug, Copy, Clone)]
+pub struct StarkEvaluationVars<'a, F, P, const COLUMNS: usize, const PUBLIC_INPUTS: usize>
+where
+    F: Field,
+    P: PackedField<Scalar = F>,
+{
+    pub local_values: &'a [P; COLUMNS],
+    pub next_values: &'a [P; COLUMNS],
+    pub public_inputs: &'a [P::Scalar; PUBLIC_INPUTS],
+}
+
+#[derive(Debug, Copy, Clone)]
+pub struct StarkEvaluationTargets<
+    'a,
+    const D: usize,
+    const COLUMNS: usize,
+    const PUBLIC_INPUTS: usize,
+> {
+    pub local_values: &'a [ExtensionTarget<D>; COLUMNS],
+    pub next_values: &'a [ExtensionTarget<D>; COLUMNS],
+    pub public_inputs: &'a [ExtensionTarget<D>; PUBLIC_INPUTS],
+}
diff --git a/starky2/src/verifier.rs b/starky2/src/verifier.rs
new file mode 100644
index 00000000..ca88ae8b
--- /dev/null
+++ b/starky2/src/verifier.rs
@@ -0,0 +1,208 @@
+use std::iter::once;
+
+use anyhow::{ensure, Result};
+use itertools::Itertools;
+use plonky2::field::extension_field::{Extendable, FieldExtension};
+use plonky2::field::field_types::Field;
+use plonky2::fri::verifier::verify_fri_proof;
+use plonky2::hash::hash_types::RichField;
+use plonky2::plonk::config::{GenericConfig, Hasher};
+use plonky2::plonk::plonk_common::reduce_with_powers;
+
+use crate::config::StarkConfig;
+use crate::constraint_consumer::ConstraintConsumer;
+use crate::permutation::PermutationCheckVars;
+use crate::proof::{StarkOpeningSet, StarkProofChallenges, StarkProofWithPublicInputs};
+use crate::stark::Stark;
+use crate::vanishing_poly::eval_vanishing_poly;
+use crate::vars::StarkEvaluationVars;
+
+pub fn verify_stark_proof<
+    F: RichField + Extendable<D>,
+    C: GenericConfig<D, F = F>,
+    S: Stark<F, D>,
+    const D: usize,
+>(
+    stark: S,
+    proof_with_pis: StarkProofWithPublicInputs<F, C, D>,
+    config: &StarkConfig,
+) -> Result<()>
+where
+    [(); S::COLUMNS]:,
+    [(); S::PUBLIC_INPUTS]:,
+    [(); C::Hasher::HASH_SIZE]:,
+{
+    ensure!(proof_with_pis.public_inputs.len() == S::PUBLIC_INPUTS);
+    let degree_bits = proof_with_pis.proof.recover_degree_bits(config);
+    let challenges = proof_with_pis.get_challenges(&stark, config, degree_bits);
+    verify_stark_proof_with_challenges(stark, proof_with_pis, challenges, degree_bits, config)
+}
+
+pub(crate) fn verify_stark_proof_with_challenges<
+    F: RichField + Extendable<D>,
+    C: GenericConfig<D, F = F>,
+    S: Stark<F, D>,
+    const D: usize,
+>(
+    stark: S,
+    proof_with_pis: StarkProofWithPublicInputs<F, C, D>,
+    challenges: StarkProofChallenges<F, D>,
+    degree_bits: usize,
+    config: &StarkConfig,
+) -> Result<()>
+where
+    [(); S::COLUMNS]:,
+    [(); S::PUBLIC_INPUTS]:,
+    [(); C::Hasher::HASH_SIZE]:,
+{
+    check_permutation_options(&stark, &proof_with_pis, &challenges)?;
+    let StarkProofWithPublicInputs {
+        proof,
+        public_inputs,
+    } = proof_with_pis;
+    let StarkOpeningSet {
+        local_values,
+        next_values,
+        permutation_zs,
+        permutation_zs_right,
+        quotient_polys,
+    } = &proof.openings;
+    let vars = StarkEvaluationVars {
+        local_values: &local_values.to_vec().try_into().unwrap(),
+        next_values: &next_values.to_vec().try_into().unwrap(),
+        public_inputs: &public_inputs
+            .into_iter()
+            .map(F::Extension::from_basefield)
+            .collect::<Vec<_>>()
+            .try_into()
+            .unwrap(),
+    };
+
+    let (l_1, l_last) = eval_l_1_and_l_last(degree_bits, challenges.stark_zeta);
+    let last = F::primitive_root_of_unity(degree_bits).inverse();
+    let z_last = challenges.stark_zeta - last.into();
+    let mut consumer = ConstraintConsumer::<F::Extension>::new(
+        challenges
+            .stark_alphas
+            .iter()
+            .map(|&alpha| F::Extension::from_basefield(alpha))
+            .collect::<Vec<_>>(),
+        z_last,
+        l_1,
+        l_last,
+    );
+    let permutation_data = stark.uses_permutation_args().then(|| PermutationCheckVars {
+        local_zs: permutation_zs.as_ref().unwrap().clone(),
+        next_zs: permutation_zs_right.as_ref().unwrap().clone(),
+        permutation_challenge_sets: challenges.permutation_challenge_sets.unwrap(),
+    });
+    eval_vanishing_poly::<F, F::Extension, F::Extension, C, S, D, D>(
+        &stark,
+        config,
+        vars,
+        permutation_data,
+        &mut consumer,
+    );
+    let vanishing_polys_zeta = consumer.accumulators();
+
+    // Check each polynomial identity, of the form `vanishing(x) = Z_H(x) quotient(x)`, at zeta.
+    let zeta_pow_deg = challenges.stark_zeta.exp_power_of_2(degree_bits);
+    let z_h_zeta = zeta_pow_deg - F::Extension::ONE;
+    // `quotient_polys_zeta` holds `num_challenges * quotient_degree_factor` evaluations.
+    // Each chunk of `quotient_degree_factor` holds the evaluations of
+    // `t_0(zeta), ..., t_{quotient_degree_factor-1}(zeta)`, where the "real" quotient polynomial
+    // is `t(X) = t_0(X) + t_1(X)*X^n + t_2(X)*X^{2n} + ...`.
+    // So to reconstruct `t(zeta)` we can compute `reduce_with_powers(chunk, zeta^n)` for each
+    // `quotient_degree_factor`-sized chunk of the original evaluations.
+    for (i, chunk) in quotient_polys
+        .chunks(stark.quotient_degree_factor())
+        .enumerate()
+    {
+        ensure!(
+            vanishing_polys_zeta[i] == z_h_zeta * reduce_with_powers(chunk, zeta_pow_deg),
+            "Mismatch between evaluation and opening of quotient polynomial"
+        );
+    }
+
+    let merkle_caps = once(proof.trace_cap)
+        .chain(proof.permutation_zs_cap)
+        .chain(once(proof.quotient_polys_cap))
+        .collect_vec();
+
+    verify_fri_proof::<F, C, D>(
+        &stark.fri_instance(
+            challenges.stark_zeta,
+            F::primitive_root_of_unity(degree_bits),
+            config,
+        ),
+        &proof.openings.to_fri_openings(),
+        &challenges.fri_challenges,
+        &merkle_caps,
+        &proof.opening_proof,
+        &config.fri_params(degree_bits),
+    )?;
+
+    Ok(())
+}
+
+/// Evaluates the Lagrange polynomials `L_1` and `L_n` at a point `x`:
+/// `L_1(x) = (x^n - 1)/(n * (x - 1))`,
+/// `L_n(x) = (x^n - 1)/(n * (g * x - 1))`, with `g` a generator of the subgroup.
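+///
+/// Both follow from the generic `L_i(x) = (g^{i-1} / n) * (x^n - 1) / (x - g^{i-1})`: for `i = 1`
+/// this is the first formula, and for `i = n`, multiplying numerator and denominator by `g` and
+/// using `g^n = 1` gives the second.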
+fn eval_l_1_and_l_last<F: Field>(log_n: usize, x: F) -> (F, F) {
+    let n = F::from_canonical_usize(1 << log_n);
+    let g = F::primitive_root_of_unity(log_n);
+    let z_x = x.exp_power_of_2(log_n) - F::ONE;
+    let invs = F::batch_multiplicative_inverse(&[n * (x - F::ONE), n * (g * x - F::ONE)]);
+
+    (z_x * invs[0], z_x * invs[1])
+}
+
+/// Utility function to check that all permutation data wrapped in `Option`s are `Some` iff
+/// the Stark uses a permutation argument.
+fn check_permutation_options<
+    F: RichField + Extendable<D>,
+    C: GenericConfig<D, F = F>,
+    S: Stark<F, D>,
+    const D: usize,
+>(
+    stark: &S,
+    proof_with_pis: &StarkProofWithPublicInputs<F, C, D>,
+    challenges: &StarkProofChallenges<F, D>,
+) -> Result<()> {
+    let options_is_some = [
+        proof_with_pis.proof.permutation_zs_cap.is_some(),
+        proof_with_pis.proof.openings.permutation_zs.is_some(),
+        proof_with_pis.proof.openings.permutation_zs_right.is_some(),
+        challenges.permutation_challenge_sets.is_some(),
+    ];
+    ensure!(
+        options_is_some
+            .into_iter()
+            .all(|b| b == stark.uses_permutation_args()),
+        "Permutation data doesn't match with Stark configuration."
+    );
+    Ok(())
+}
+
+#[cfg(test)]
+mod tests {
+    use plonky2::field::field_types::Field;
+    use plonky2::field::goldilocks_field::GoldilocksField;
+    use plonky2::field::polynomial::PolynomialValues;
+
+    use crate::verifier::eval_l_1_and_l_last;
+
+    #[test]
+    fn test_eval_l_1_and_l_last() {
+        type F = GoldilocksField;
+        let log_n = 5;
+        let n = 1 << log_n;
+
+        let x = F::rand(); // challenge point
+        let expected_l_first_x = PolynomialValues::selector(n, 0).ifft().eval(x);
+        let expected_l_last_x = PolynomialValues::selector(n, n - 1).ifft().eval(x);
+
+        let (l_first_x, l_last_x) = eval_l_1_and_l_last(log_n, x);
+        assert_eq!(l_first_x, expected_l_first_x);
+        assert_eq!(l_last_x, expected_l_last_x);
+    }
+}