From 7d6c0a448ddb68f5c181f9440bf3213f898519aa Mon Sep 17 00:00:00 2001 From: Daniel Lubarov Date: Wed, 16 Mar 2022 17:37:34 -0700 Subject: [PATCH] Halo2 style lookup arguments in System Zero (#513) * Halo2 style lookup arguments in System Zero It's a really nice and simple protocol, particularly for the verifier since the constraints are trivial (aside from the underlying batched permutation checks, which we already support). See the [Halo2 book](https://zcash.github.io/halo2/design/proving-system/lookup.html) and this [talk](https://www.youtube.com/watch?v=YlTt12s7vGE&t=5237s) by @daira. Previously we generated the whole trace in row-wise form, but it's much more efficient to generate these "permuted" columns column-wise. So I changed our STARK framework to accept the trace in column-wise form. STARK impls now have the flexibility to do some generation row-wise and some column-wise (without extra costs; there's a single transpose as before). * sorting * fixes * PR feedback * into_iter * timing --- field/src/field_types.rs | 5 + field/src/goldilocks_field.rs | 3 +- starky/src/fibonacci_stark.rs | 14 +- starky/src/lib.rs | 1 + starky/src/permutation.rs | 8 + starky/src/prover.rs | 16 +- starky/src/util.rs | 16 ++ starky/src/verifier.rs | 5 +- system_zero/Cargo.toml | 8 + system_zero/benches/lookup_permuted_cols.rs | 30 ++++ system_zero/src/lib.rs | 1 + system_zero/src/lookup.rs | 147 ++++++++++++++++++ system_zero/src/registers/lookup.rs | 24 ++- system_zero/src/registers/range_check_16.rs | 2 +- .../src/registers/range_check_degree.rs | 2 +- system_zero/src/system_zero.rs | 76 ++++++++- 16 files changed, 324 insertions(+), 34 deletions(-) create mode 100644 starky/src/util.rs create mode 100644 system_zero/benches/lookup_permuted_cols.rs create mode 100644 system_zero/src/lookup.rs diff --git a/field/src/field_types.rs b/field/src/field_types.rs index 83826b9f..4adfdbf4 100644 --- a/field/src/field_types.rs +++ b/field/src/field_types.rs @@ -462,6 +462,11 @@ pub trait PrimeField64: PrimeField + Field64 { fn to_canonical_u64(&self) -> u64; fn to_noncanonical_u64(&self) -> u64; + + #[inline(always)] + fn to_canonical(&self) -> Self { + Self::from_canonical_u64(self.to_canonical_u64()) + } } /// An iterator over the powers of a certain base element `b`: `b^0, b^1, b^2, ...`. diff --git a/field/src/goldilocks_field.rs b/field/src/goldilocks_field.rs index 4ed32a0d..c3172991 100644 --- a/field/src/goldilocks_field.rs +++ b/field/src/goldilocks_field.rs @@ -95,7 +95,7 @@ impl Field for GoldilocksField { Self(n.mod_floor(&Self::order()).to_u64_digits()[0]) } - #[inline] + #[inline(always)] fn from_canonical_u64(n: u64) -> Self { debug_assert!(n < Self::ORDER); Self(n) @@ -156,6 +156,7 @@ impl PrimeField64 for GoldilocksField { c } + #[inline(always)] fn to_noncanonical_u64(&self) -> u64 { self.0 } diff --git a/starky/src/fibonacci_stark.rs b/starky/src/fibonacci_stark.rs index 7961ad50..fa9ccd87 100644 --- a/starky/src/fibonacci_stark.rs +++ b/starky/src/fibonacci_stark.rs @@ -2,12 +2,14 @@ use std::marker::PhantomData; use plonky2::field::extension_field::{Extendable, FieldExtension}; use plonky2::field::packed_field::PackedField; +use plonky2::field::polynomial::PolynomialValues; use plonky2::hash::hash_types::RichField; use plonky2::plonk::circuit_builder::CircuitBuilder; use crate::constraint_consumer::{ConstraintConsumer, RecursiveConstraintConsumer}; use crate::permutation::PermutationPair; use crate::stark::Stark; +use crate::util::trace_rows_to_poly_values; use crate::vars::{StarkEvaluationTargets, StarkEvaluationVars}; /// Toy STARK system used for testing. @@ -37,8 +39,8 @@ impl, const D: usize> FibonacciStark { } /// Generate the trace using `x0, x1, 0, 1` as initial state values. - fn generate_trace(&self, x0: F, x1: F) -> Vec<[F; Self::COLUMNS]> { - let mut trace = (0..self.num_rows) + fn generate_trace(&self, x0: F, x1: F) -> Vec> { + let mut trace_rows = (0..self.num_rows) .scan([x0, x1, F::ZERO, F::ONE], |acc, _| { let tmp = *acc; acc[0] = tmp[1]; @@ -48,8 +50,8 @@ impl, const D: usize> FibonacciStark { Some(tmp) }) .collect::>(); - trace[self.num_rows - 1][3] = F::ZERO; // So that column 2 and 3 are permutation of one another. - trace + trace_rows[self.num_rows - 1][3] = F::ZERO; // So that column 2 and 3 are permutation of one another. + trace_rows_to_poly_values(trace_rows) } } @@ -113,9 +115,7 @@ impl, const D: usize> Stark for FibonacciStar } fn permutation_pairs(&self) -> Vec { - vec![PermutationPair { - column_pairs: vec![(2, 3)], - }] + vec![PermutationPair::singletons(2, 3)] } } diff --git a/starky/src/lib.rs b/starky/src/lib.rs index 8249d90b..b2293443 100644 --- a/starky/src/lib.rs +++ b/starky/src/lib.rs @@ -15,6 +15,7 @@ pub mod prover; pub mod recursive_verifier; pub mod stark; pub mod stark_testing; +pub mod util; pub mod vanishing_poly; pub mod vars; pub mod verifier; diff --git a/starky/src/permutation.rs b/starky/src/permutation.rs index 2e1d603c..91b1be27 100644 --- a/starky/src/permutation.rs +++ b/starky/src/permutation.rs @@ -30,6 +30,14 @@ pub struct PermutationPair { pub column_pairs: Vec<(usize, usize)>, } +impl PermutationPair { + pub fn singletons(lhs: usize, rhs: usize) -> Self { + Self { + column_pairs: vec![(lhs, rhs)], + } + } +} + /// A single instance of a permutation check protocol. pub(crate) struct PermutationInstance<'a, T: Copy> { pub(crate) pair: &'a PermutationPair, diff --git a/starky/src/prover.rs b/starky/src/prover.rs index 336b9963..da1b5dd4 100644 --- a/starky/src/prover.rs +++ b/starky/src/prover.rs @@ -30,7 +30,7 @@ use crate::vars::StarkEvaluationVars; pub fn prove( stark: S, config: &StarkConfig, - trace: Vec<[F; S::COLUMNS]>, + trace_poly_values: Vec>, public_inputs: [F; S::PUBLIC_INPUTS], timing: &mut TimingTree, ) -> Result> @@ -42,7 +42,7 @@ where [(); S::PUBLIC_INPUTS]:, [(); C::Hasher::HASH_SIZE]:, { - let degree = trace.len(); + let degree = trace_poly_values[0].len(); let degree_bits = log2_strict(degree); let fri_params = config.fri_params(degree_bits); let rate_bits = config.fri_config.rate_bits; @@ -52,18 +52,6 @@ where "FRI total reduction arity is too large.", ); - let trace_vecs = trace.iter().map(|row| row.to_vec()).collect_vec(); - let trace_col_major: Vec> = transpose(&trace_vecs); - - let trace_poly_values: Vec> = timed!( - timing, - "compute trace polynomials", - trace_col_major - .par_iter() - .map(|column| PolynomialValues::new(column.clone())) - .collect() - ); - let trace_commitment = timed!( timing, "compute trace commitment", diff --git a/starky/src/util.rs b/starky/src/util.rs new file mode 100644 index 00000000..011a1add --- /dev/null +++ b/starky/src/util.rs @@ -0,0 +1,16 @@ +use itertools::Itertools; +use plonky2::field::field_types::Field; +use plonky2::field::polynomial::PolynomialValues; +use plonky2::util::transpose; + +/// A helper function to transpose a row-wise trace and put it in the format that `prove` expects. +pub fn trace_rows_to_poly_values( + trace_rows: Vec<[F; COLUMNS]>, +) -> Vec> { + let trace_row_vecs = trace_rows.into_iter().map(|row| row.to_vec()).collect_vec(); + let trace_col_vecs: Vec> = transpose(&trace_row_vecs); + trace_col_vecs + .into_iter() + .map(|column| PolynomialValues::new(column)) + .collect() +} diff --git a/starky/src/verifier.rs b/starky/src/verifier.rs index a9bf897c..d5071af7 100644 --- a/starky/src/verifier.rs +++ b/starky/src/verifier.rs @@ -118,7 +118,10 @@ where .chunks(stark.quotient_degree_factor()) .enumerate() { - ensure!(vanishing_polys_zeta[i] == z_h_zeta * reduce_with_powers(chunk, zeta_pow_deg)); + ensure!( + vanishing_polys_zeta[i] == z_h_zeta * reduce_with_powers(chunk, zeta_pow_deg), + "Mismatch between evaluation and opening of quotient polynomial" + ); } let merkle_caps = once(proof.trace_cap) diff --git a/system_zero/Cargo.toml b/system_zero/Cargo.toml index 032bfb53..a9029dad 100644 --- a/system_zero/Cargo.toml +++ b/system_zero/Cargo.toml @@ -10,6 +10,14 @@ plonky2_util = { path = "../util" } starky = { path = "../starky" } anyhow = "1.0.40" env_logger = "0.9.0" +itertools = "0.10.0" log = "0.4.14" rand = "0.8.4" rand_chacha = "0.3.1" + +[dev-dependencies] +criterion = "0.3.5" + +[[bench]] +name = "lookup_permuted_cols" +harness = false diff --git a/system_zero/benches/lookup_permuted_cols.rs b/system_zero/benches/lookup_permuted_cols.rs new file mode 100644 index 00000000..371b3470 --- /dev/null +++ b/system_zero/benches/lookup_permuted_cols.rs @@ -0,0 +1,30 @@ +use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion}; +use itertools::Itertools; +use plonky2::field::field_types::Field; +use plonky2::field::goldilocks_field::GoldilocksField; +use rand::{thread_rng, Rng}; +use system_zero::lookup::permuted_cols; + +type F = GoldilocksField; + +fn criterion_benchmark(c: &mut Criterion) { + let mut group = c.benchmark_group("lookup-permuted-cols"); + + for size_log in [16, 17, 18] { + let size = 1 << size_log; + group.bench_with_input(BenchmarkId::from_parameter(size), &size, |b, _| { + // We could benchmark a table of random values with + // let table = F::rand_vec(size); + // But in practice we currently use tables that are pre-sorted, which makes + // permuted_cols cheaper since it will sort the table. + let table = (0..size).map(F::from_canonical_usize).collect_vec(); + let input = (0..size) + .map(|_| table[thread_rng().gen_range(0..size)]) + .collect_vec(); + b.iter(|| permuted_cols(&input, &table)); + }); + } +} + +criterion_group!(benches, criterion_benchmark); +criterion_main!(benches); diff --git a/system_zero/src/lib.rs b/system_zero/src/lib.rs index 35576cd3..81e5e9b1 100644 --- a/system_zero/src/lib.rs +++ b/system_zero/src/lib.rs @@ -4,6 +4,7 @@ mod alu; mod core_registers; +pub mod lookup; mod memory; mod permutation_unit; mod public_input_layout; diff --git a/system_zero/src/lookup.rs b/system_zero/src/lookup.rs new file mode 100644 index 00000000..5a5f0da1 --- /dev/null +++ b/system_zero/src/lookup.rs @@ -0,0 +1,147 @@ +//! Implementation of the Halo2 lookup argument. +//! +//! References: +//! - https://zcash.github.io/halo2/design/proving-system/lookup.html +//! - https://www.youtube.com/watch?v=YlTt12s7vGE&t=5237s + +use std::cmp::Ordering; + +use itertools::Itertools; +use plonky2::field::extension_field::Extendable; +use plonky2::field::field_types::{Field, PrimeField64}; +use plonky2::field::packed_field::PackedField; +use plonky2::hash::hash_types::RichField; +use plonky2::plonk::circuit_builder::CircuitBuilder; +use starky::constraint_consumer::{ConstraintConsumer, RecursiveConstraintConsumer}; +use starky::vars::StarkEvaluationTargets; +use starky::vars::StarkEvaluationVars; + +use crate::public_input_layout::NUM_PUBLIC_INPUTS; +use crate::registers::lookup::*; +use crate::registers::NUM_COLUMNS; + +pub(crate) fn generate_lookups(trace_cols: &mut [Vec]) { + for i in 0..NUM_LOOKUPS { + let inputs = &trace_cols[col_input(i)]; + let table = &trace_cols[col_table(i)]; + let (permuted_inputs, permuted_table) = permuted_cols(inputs, table); + trace_cols[col_permuted_input(i)] = permuted_inputs; + trace_cols[col_permuted_table(i)] = permuted_table; + } +} + +/// Given an input column and a table column, generate the permuted input and permuted table columns +/// used in the Halo2 permutation argument. +pub fn permuted_cols(inputs: &[F], table: &[F]) -> (Vec, Vec) { + let n = inputs.len(); + + // The permuted inputs do not have to be ordered, but we found that sorting was faster than + // hash-based grouping. We also sort the table, as this helps us identify "unused" table + // elements efficiently. + + // To compare elements, e.g. for sorting, we first need them in canonical form. It would be + // wasteful to canonicalize in each comparison, as a single element may be involved in many + // comparisons. So we will canonicalize once upfront, then use `to_noncanonical_u64` when + // comparing elements. + + let sorted_inputs = inputs + .iter() + .map(|x| x.to_canonical()) + .sorted_unstable_by_key(|x| x.to_noncanonical_u64()) + .collect_vec(); + let sorted_table = table + .iter() + .map(|x| x.to_canonical()) + .sorted_unstable_by_key(|x| x.to_noncanonical_u64()) + .collect_vec(); + + let mut unused_table_inds = Vec::with_capacity(n); + let mut unused_table_vals = Vec::with_capacity(n); + let mut permuted_table = vec![F::ZERO; n]; + let mut i = 0; + let mut j = 0; + while (j < n) && (i < n) { + let input_val = sorted_inputs[i].to_noncanonical_u64(); + let table_val = sorted_table[j].to_noncanonical_u64(); + match input_val.cmp(&table_val) { + Ordering::Greater => { + unused_table_vals.push(sorted_table[j]); + j += 1; + } + Ordering::Less => { + if let Some(x) = unused_table_vals.pop() { + permuted_table[i] = x; + } else { + unused_table_inds.push(i); + } + i += 1; + } + Ordering::Equal => { + permuted_table[i] = sorted_table[j]; + i += 1; + j += 1; + } + } + } + + #[allow(clippy::needless_range_loop)] // indexing is just more natural here + for jj in j..n { + unused_table_vals.push(sorted_table[jj]); + } + for ii in i..n { + unused_table_inds.push(ii); + } + for (ind, val) in unused_table_inds.into_iter().zip_eq(unused_table_vals) { + permuted_table[ind] = val; + } + + (sorted_inputs, permuted_table) +} + +pub(crate) fn eval_lookups>( + vars: StarkEvaluationVars, + yield_constr: &mut ConstraintConsumer

, +) { + for i in 0..NUM_LOOKUPS { + let local_perm_input = vars.local_values[col_permuted_input(i)]; + let next_perm_table = vars.next_values[col_permuted_table(i)]; + let next_perm_input = vars.next_values[col_permuted_input(i)]; + + // A "vertical" diff between the local and next permuted inputs. + let diff_input_prev = next_perm_input - local_perm_input; + // A "horizontal" diff between the next permuted input and permuted table value. + let diff_input_table = next_perm_input - next_perm_table; + + yield_constr.constraint(diff_input_prev * diff_input_table); + + // This is actually constraining the first row, as per the spec, since `diff_input_table` + // is a diff of the next row's values. In the context of `constraint_last_row`, the next + // row is the first row. + yield_constr.constraint_last_row(diff_input_table); + } +} + +pub(crate) fn eval_lookups_recursively, const D: usize>( + builder: &mut CircuitBuilder, + vars: StarkEvaluationTargets, + yield_constr: &mut RecursiveConstraintConsumer, +) { + for i in 0..NUM_LOOKUPS { + let local_perm_input = vars.local_values[col_permuted_input(i)]; + let next_perm_table = vars.next_values[col_permuted_table(i)]; + let next_perm_input = vars.next_values[col_permuted_input(i)]; + + // A "vertical" diff between the local and next permuted inputs. + let diff_input_prev = builder.sub_extension(next_perm_input, local_perm_input); + // A "horizontal" diff between the next permuted input and permuted table value. + let diff_input_table = builder.sub_extension(next_perm_input, next_perm_table); + + let diff_product = builder.mul_extension(diff_input_prev, diff_input_table); + yield_constr.constraint(builder, diff_product); + + // This is actually constraining the first row, as per the spec, since `diff_input_table` + // is a diff of the next row's values. In the context of `constraint_last_row`, the next + // row is the first row. + yield_constr.constraint_last_row(builder, diff_input_table); + } +} diff --git a/system_zero/src/registers/lookup.rs b/system_zero/src/registers/lookup.rs index eb773acf..fd0abd43 100644 --- a/system_zero/src/registers/lookup.rs +++ b/system_zero/src/registers/lookup.rs @@ -3,19 +3,35 @@ const START_UNIT: usize = super::START_LOOKUP; -const NUM_LOOKUPS: usize = +pub(crate) const NUM_LOOKUPS: usize = super::range_check_16::NUM_RANGE_CHECKS + super::range_check_degree::NUM_RANGE_CHECKS; +pub(crate) const fn col_input(i: usize) -> usize { + if i < super::range_check_16::NUM_RANGE_CHECKS { + super::range_check_16::col_rc_16_input(i) + } else { + super::range_check_degree::col_rc_degree_input(i - super::range_check_16::NUM_RANGE_CHECKS) + } +} + /// This column contains a permutation of the input values. -const fn col_permuted_input(i: usize) -> usize { +pub(crate) const fn col_permuted_input(i: usize) -> usize { debug_assert!(i < NUM_LOOKUPS); START_UNIT + 2 * i } +pub(crate) const fn col_table(i: usize) -> usize { + if i < super::range_check_16::NUM_RANGE_CHECKS { + super::core::COL_RANGE_16 + } else { + super::core::COL_CLOCK + } +} + /// This column contains a permutation of the table values. -const fn col_permuted_table(i: usize) -> usize { +pub(crate) const fn col_permuted_table(i: usize) -> usize { debug_assert!(i < NUM_LOOKUPS); START_UNIT + 2 * i + 1 } -pub(super) const END: usize = START_UNIT + NUM_LOOKUPS; +pub(super) const END: usize = START_UNIT + NUM_LOOKUPS * 2; diff --git a/system_zero/src/registers/range_check_16.rs b/system_zero/src/registers/range_check_16.rs index c44db494..674df302 100644 --- a/system_zero/src/registers/range_check_16.rs +++ b/system_zero/src/registers/range_check_16.rs @@ -1,6 +1,6 @@ //! Range check unit which checks that values are in `[0, 2^16)`. -pub(super) const NUM_RANGE_CHECKS: usize = 5; +pub(crate) const NUM_RANGE_CHECKS: usize = 5; /// The input of the `i`th range check, i.e. the value being range checked. pub(crate) const fn col_rc_16_input(i: usize) -> usize { diff --git a/system_zero/src/registers/range_check_degree.rs b/system_zero/src/registers/range_check_degree.rs index 6d61e6e2..caad705d 100644 --- a/system_zero/src/registers/range_check_degree.rs +++ b/system_zero/src/registers/range_check_degree.rs @@ -1,6 +1,6 @@ //! Range check unit which checks that values are in `[0, degree)`. -pub(super) const NUM_RANGE_CHECKS: usize = 5; +pub(crate) const NUM_RANGE_CHECKS: usize = 5; /// The input of the `i`th range check, i.e. the value being range checked. pub(crate) const fn col_rc_degree_input(i: usize) -> usize { diff --git a/system_zero/src/system_zero.rs b/system_zero/src/system_zero.rs index c42a04a8..32c49266 100644 --- a/system_zero/src/system_zero.rs +++ b/system_zero/src/system_zero.rs @@ -2,8 +2,12 @@ use std::marker::PhantomData; use plonky2::field::extension_field::{Extendable, FieldExtension}; use plonky2::field::packed_field::PackedField; +use plonky2::field::polynomial::PolynomialValues; use plonky2::hash::hash_types::RichField; use plonky2::plonk::circuit_builder::CircuitBuilder; +use plonky2::timed; +use plonky2::util::timing::TimingTree; +use plonky2::util::transpose; use starky::constraint_consumer::{ConstraintConsumer, RecursiveConstraintConsumer}; use starky::permutation::PermutationPair; use starky::stark::Stark; @@ -15,12 +19,13 @@ use crate::core_registers::{ eval_core_registers, eval_core_registers_recursively, generate_first_row_core_registers, generate_next_row_core_registers, }; +use crate::lookup::{eval_lookups, eval_lookups_recursively, generate_lookups}; use crate::memory::TransactionMemory; use crate::permutation_unit::{ eval_permutation_unit, eval_permutation_unit_recursively, generate_permutation_unit, }; use crate::public_input_layout::NUM_PUBLIC_INPUTS; -use crate::registers::NUM_COLUMNS; +use crate::registers::{lookup, NUM_COLUMNS}; /// We require at least 2^16 rows as it helps support efficient 16-bit range checks. const MIN_TRACE_ROWS: usize = 1 << 16; @@ -31,7 +36,9 @@ pub struct SystemZero, const D: usize> { } impl, const D: usize> SystemZero { - fn generate_trace(&self) -> Vec<[F; NUM_COLUMNS]> { + /// Generate the rows of the trace. Note that this does not generate the permuted columns used + /// in our lookup arguments, as those are computed after transposing to column-wise form. + fn generate_trace_rows(&self) -> Vec<[F; NUM_COLUMNS]> { let memory = TransactionMemory::default(); let mut row = [F::ZERO; NUM_COLUMNS]; @@ -59,6 +66,45 @@ impl, const D: usize> SystemZero { trace.push(row); trace } + + fn generate_trace(&self) -> Vec> { + let mut timing = TimingTree::new("generate trace", log::Level::Debug); + + // Generate the witness, except for permuted columns in the lookup argument. + let trace_rows = timed!( + &mut timing, + "generate trace rows", + self.generate_trace_rows() + ); + + // Transpose from row-wise to column-wise. + let trace_row_vecs: Vec<_> = timed!( + &mut timing, + "convert to Vecs", + trace_rows.into_iter().map(|row| row.to_vec()).collect() + ); + let mut trace_col_vecs: Vec> = + timed!(&mut timing, "transpose", transpose(&trace_row_vecs)); + + // Generate permuted columns in the lookup argument. + timed!( + &mut timing, + "generate lookup columns", + generate_lookups(&mut trace_col_vecs) + ); + + let trace_polys = timed!( + &mut timing, + "convert to PolynomialValues", + trace_col_vecs + .into_iter() + .map(|column| PolynomialValues::new(column)) + .collect() + ); + + timing.print(); + trace_polys + } } impl, const D: usize> Default for SystemZero { @@ -84,6 +130,7 @@ impl, const D: usize> Stark for SystemZero(vars, yield_constr); + eval_lookups(vars, yield_constr); // TODO: Other units } @@ -96,6 +143,7 @@ impl, const D: usize> Stark for SystemZero, const D: usize> Stark for SystemZero Vec { + let mut pairs = Vec::new(); + + for i in 0..lookup::NUM_LOOKUPS { + pairs.push(PermutationPair::singletons( + lookup::col_input(i), + lookup::col_permuted_input(i), + )); + pairs.push(PermutationPair::singletons( + lookup::col_table(i), + lookup::col_permuted_table(i), + )); + } + // TODO: Add permutation pairs for memory. - // TODO: Add permutation pairs for range checks. - vec![] + + pairs } } @@ -127,8 +188,9 @@ mod tests { use crate::system_zero::SystemZero; #[test] - #[ignore] // A bit slow. fn run() -> Result<()> { + init_logger(); + type F = GoldilocksField; type C = PoseidonGoldilocksConfig; const D: usize = 2; @@ -154,4 +216,8 @@ mod tests { let system = S::default(); test_stark_low_degree(system) } + + fn init_logger() { + let _ = env_logger::builder().format_timestamp(None).try_init(); + } }