Halo2 style lookup arguments in System Zero (#513)

* Halo2 style lookup arguments in System Zero

It's a really nice and simple protocol, particularly for the verifier since the constraints are trivial (aside from the underlying batched permutation checks, which we already support). See the [Halo2 book](https://zcash.github.io/halo2/design/proving-system/lookup.html) and this [talk](https://www.youtube.com/watch?v=YlTt12s7vGE&t=5237s) by @daira.

Previously we generated the whole trace in row-wise form, but it's much more efficient to generate these "permuted" columns column-wise. So I changed our STARK framework to accept the trace in column-wise form. STARK impls now have the flexibility to do some generation row-wise and some column-wise (without extra costs; there's a single transpose as before).

* sorting

* fixes

* PR feedback

* into_iter

* timing
This commit is contained in:
Daniel Lubarov 2022-03-16 17:37:34 -07:00 committed by GitHub
parent 660d785ed1
commit 7d6c0a448d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
16 changed files with 324 additions and 34 deletions

View File

@ -462,6 +462,11 @@ pub trait PrimeField64: PrimeField + Field64 {
fn to_canonical_u64(&self) -> u64;
fn to_noncanonical_u64(&self) -> u64;
#[inline(always)]
fn to_canonical(&self) -> Self {
Self::from_canonical_u64(self.to_canonical_u64())
}
}
/// An iterator over the powers of a certain base element `b`: `b^0, b^1, b^2, ...`.

View File

@ -95,7 +95,7 @@ impl Field for GoldilocksField {
Self(n.mod_floor(&Self::order()).to_u64_digits()[0])
}
#[inline]
#[inline(always)]
fn from_canonical_u64(n: u64) -> Self {
debug_assert!(n < Self::ORDER);
Self(n)
@ -156,6 +156,7 @@ impl PrimeField64 for GoldilocksField {
c
}
#[inline(always)]
fn to_noncanonical_u64(&self) -> u64 {
self.0
}

View File

@ -2,12 +2,14 @@ use std::marker::PhantomData;
use plonky2::field::extension_field::{Extendable, FieldExtension};
use plonky2::field::packed_field::PackedField;
use plonky2::field::polynomial::PolynomialValues;
use plonky2::hash::hash_types::RichField;
use plonky2::plonk::circuit_builder::CircuitBuilder;
use crate::constraint_consumer::{ConstraintConsumer, RecursiveConstraintConsumer};
use crate::permutation::PermutationPair;
use crate::stark::Stark;
use crate::util::trace_rows_to_poly_values;
use crate::vars::{StarkEvaluationTargets, StarkEvaluationVars};
/// Toy STARK system used for testing.
@ -37,8 +39,8 @@ impl<F: RichField + Extendable<D>, const D: usize> FibonacciStark<F, D> {
}
/// Generate the trace using `x0, x1, 0, 1` as initial state values.
fn generate_trace(&self, x0: F, x1: F) -> Vec<[F; Self::COLUMNS]> {
let mut trace = (0..self.num_rows)
fn generate_trace(&self, x0: F, x1: F) -> Vec<PolynomialValues<F>> {
let mut trace_rows = (0..self.num_rows)
.scan([x0, x1, F::ZERO, F::ONE], |acc, _| {
let tmp = *acc;
acc[0] = tmp[1];
@ -48,8 +50,8 @@ impl<F: RichField + Extendable<D>, const D: usize> FibonacciStark<F, D> {
Some(tmp)
})
.collect::<Vec<_>>();
trace[self.num_rows - 1][3] = F::ZERO; // So that column 2 and 3 are permutation of one another.
trace
trace_rows[self.num_rows - 1][3] = F::ZERO; // So that column 2 and 3 are permutation of one another.
trace_rows_to_poly_values(trace_rows)
}
}
@ -113,9 +115,7 @@ impl<F: RichField + Extendable<D>, const D: usize> Stark<F, D> for FibonacciStar
}
fn permutation_pairs(&self) -> Vec<PermutationPair> {
vec![PermutationPair {
column_pairs: vec![(2, 3)],
}]
vec![PermutationPair::singletons(2, 3)]
}
}

View File

@ -15,6 +15,7 @@ pub mod prover;
pub mod recursive_verifier;
pub mod stark;
pub mod stark_testing;
pub mod util;
pub mod vanishing_poly;
pub mod vars;
pub mod verifier;

View File

@ -30,6 +30,14 @@ pub struct PermutationPair {
pub column_pairs: Vec<(usize, usize)>,
}
impl PermutationPair {
pub fn singletons(lhs: usize, rhs: usize) -> Self {
Self {
column_pairs: vec![(lhs, rhs)],
}
}
}
/// A single instance of a permutation check protocol.
pub(crate) struct PermutationInstance<'a, T: Copy> {
pub(crate) pair: &'a PermutationPair,

View File

@ -30,7 +30,7 @@ use crate::vars::StarkEvaluationVars;
pub fn prove<F, C, S, const D: usize>(
stark: S,
config: &StarkConfig,
trace: Vec<[F; S::COLUMNS]>,
trace_poly_values: Vec<PolynomialValues<F>>,
public_inputs: [F; S::PUBLIC_INPUTS],
timing: &mut TimingTree,
) -> Result<StarkProofWithPublicInputs<F, C, D>>
@ -42,7 +42,7 @@ where
[(); S::PUBLIC_INPUTS]:,
[(); C::Hasher::HASH_SIZE]:,
{
let degree = trace.len();
let degree = trace_poly_values[0].len();
let degree_bits = log2_strict(degree);
let fri_params = config.fri_params(degree_bits);
let rate_bits = config.fri_config.rate_bits;
@ -52,18 +52,6 @@ where
"FRI total reduction arity is too large.",
);
let trace_vecs = trace.iter().map(|row| row.to_vec()).collect_vec();
let trace_col_major: Vec<Vec<F>> = transpose(&trace_vecs);
let trace_poly_values: Vec<PolynomialValues<F>> = timed!(
timing,
"compute trace polynomials",
trace_col_major
.par_iter()
.map(|column| PolynomialValues::new(column.clone()))
.collect()
);
let trace_commitment = timed!(
timing,
"compute trace commitment",

16
starky/src/util.rs Normal file
View File

@ -0,0 +1,16 @@
use itertools::Itertools;
use plonky2::field::field_types::Field;
use plonky2::field::polynomial::PolynomialValues;
use plonky2::util::transpose;
/// A helper function to transpose a row-wise trace and put it in the format that `prove` expects.
pub fn trace_rows_to_poly_values<F: Field, const COLUMNS: usize>(
trace_rows: Vec<[F; COLUMNS]>,
) -> Vec<PolynomialValues<F>> {
let trace_row_vecs = trace_rows.into_iter().map(|row| row.to_vec()).collect_vec();
let trace_col_vecs: Vec<Vec<F>> = transpose(&trace_row_vecs);
trace_col_vecs
.into_iter()
.map(|column| PolynomialValues::new(column))
.collect()
}

View File

@ -118,7 +118,10 @@ where
.chunks(stark.quotient_degree_factor())
.enumerate()
{
ensure!(vanishing_polys_zeta[i] == z_h_zeta * reduce_with_powers(chunk, zeta_pow_deg));
ensure!(
vanishing_polys_zeta[i] == z_h_zeta * reduce_with_powers(chunk, zeta_pow_deg),
"Mismatch between evaluation and opening of quotient polynomial"
);
}
let merkle_caps = once(proof.trace_cap)

View File

@ -10,6 +10,14 @@ plonky2_util = { path = "../util" }
starky = { path = "../starky" }
anyhow = "1.0.40"
env_logger = "0.9.0"
itertools = "0.10.0"
log = "0.4.14"
rand = "0.8.4"
rand_chacha = "0.3.1"
[dev-dependencies]
criterion = "0.3.5"
[[bench]]
name = "lookup_permuted_cols"
harness = false

View File

@ -0,0 +1,30 @@
use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion};
use itertools::Itertools;
use plonky2::field::field_types::Field;
use plonky2::field::goldilocks_field::GoldilocksField;
use rand::{thread_rng, Rng};
use system_zero::lookup::permuted_cols;
type F = GoldilocksField;
fn criterion_benchmark(c: &mut Criterion) {
let mut group = c.benchmark_group("lookup-permuted-cols");
for size_log in [16, 17, 18] {
let size = 1 << size_log;
group.bench_with_input(BenchmarkId::from_parameter(size), &size, |b, _| {
// We could benchmark a table of random values with
// let table = F::rand_vec(size);
// But in practice we currently use tables that are pre-sorted, which makes
// permuted_cols cheaper since it will sort the table.
let table = (0..size).map(F::from_canonical_usize).collect_vec();
let input = (0..size)
.map(|_| table[thread_rng().gen_range(0..size)])
.collect_vec();
b.iter(|| permuted_cols(&input, &table));
});
}
}
criterion_group!(benches, criterion_benchmark);
criterion_main!(benches);

View File

@ -4,6 +4,7 @@
mod alu;
mod core_registers;
pub mod lookup;
mod memory;
mod permutation_unit;
mod public_input_layout;

147
system_zero/src/lookup.rs Normal file
View File

@ -0,0 +1,147 @@
//! Implementation of the Halo2 lookup argument.
//!
//! References:
//! - https://zcash.github.io/halo2/design/proving-system/lookup.html
//! - https://www.youtube.com/watch?v=YlTt12s7vGE&t=5237s
use std::cmp::Ordering;
use itertools::Itertools;
use plonky2::field::extension_field::Extendable;
use plonky2::field::field_types::{Field, PrimeField64};
use plonky2::field::packed_field::PackedField;
use plonky2::hash::hash_types::RichField;
use plonky2::plonk::circuit_builder::CircuitBuilder;
use starky::constraint_consumer::{ConstraintConsumer, RecursiveConstraintConsumer};
use starky::vars::StarkEvaluationTargets;
use starky::vars::StarkEvaluationVars;
use crate::public_input_layout::NUM_PUBLIC_INPUTS;
use crate::registers::lookup::*;
use crate::registers::NUM_COLUMNS;
pub(crate) fn generate_lookups<F: PrimeField64>(trace_cols: &mut [Vec<F>]) {
for i in 0..NUM_LOOKUPS {
let inputs = &trace_cols[col_input(i)];
let table = &trace_cols[col_table(i)];
let (permuted_inputs, permuted_table) = permuted_cols(inputs, table);
trace_cols[col_permuted_input(i)] = permuted_inputs;
trace_cols[col_permuted_table(i)] = permuted_table;
}
}
/// Given an input column and a table column, generate the permuted input and permuted table columns
/// used in the Halo2 permutation argument.
pub fn permuted_cols<F: PrimeField64>(inputs: &[F], table: &[F]) -> (Vec<F>, Vec<F>) {
let n = inputs.len();
// The permuted inputs do not have to be ordered, but we found that sorting was faster than
// hash-based grouping. We also sort the table, as this helps us identify "unused" table
// elements efficiently.
// To compare elements, e.g. for sorting, we first need them in canonical form. It would be
// wasteful to canonicalize in each comparison, as a single element may be involved in many
// comparisons. So we will canonicalize once upfront, then use `to_noncanonical_u64` when
// comparing elements.
let sorted_inputs = inputs
.iter()
.map(|x| x.to_canonical())
.sorted_unstable_by_key(|x| x.to_noncanonical_u64())
.collect_vec();
let sorted_table = table
.iter()
.map(|x| x.to_canonical())
.sorted_unstable_by_key(|x| x.to_noncanonical_u64())
.collect_vec();
let mut unused_table_inds = Vec::with_capacity(n);
let mut unused_table_vals = Vec::with_capacity(n);
let mut permuted_table = vec![F::ZERO; n];
let mut i = 0;
let mut j = 0;
while (j < n) && (i < n) {
let input_val = sorted_inputs[i].to_noncanonical_u64();
let table_val = sorted_table[j].to_noncanonical_u64();
match input_val.cmp(&table_val) {
Ordering::Greater => {
unused_table_vals.push(sorted_table[j]);
j += 1;
}
Ordering::Less => {
if let Some(x) = unused_table_vals.pop() {
permuted_table[i] = x;
} else {
unused_table_inds.push(i);
}
i += 1;
}
Ordering::Equal => {
permuted_table[i] = sorted_table[j];
i += 1;
j += 1;
}
}
}
#[allow(clippy::needless_range_loop)] // indexing is just more natural here
for jj in j..n {
unused_table_vals.push(sorted_table[jj]);
}
for ii in i..n {
unused_table_inds.push(ii);
}
for (ind, val) in unused_table_inds.into_iter().zip_eq(unused_table_vals) {
permuted_table[ind] = val;
}
(sorted_inputs, permuted_table)
}
pub(crate) fn eval_lookups<F: Field, P: PackedField<Scalar = F>>(
vars: StarkEvaluationVars<F, P, NUM_COLUMNS, NUM_PUBLIC_INPUTS>,
yield_constr: &mut ConstraintConsumer<P>,
) {
for i in 0..NUM_LOOKUPS {
let local_perm_input = vars.local_values[col_permuted_input(i)];
let next_perm_table = vars.next_values[col_permuted_table(i)];
let next_perm_input = vars.next_values[col_permuted_input(i)];
// A "vertical" diff between the local and next permuted inputs.
let diff_input_prev = next_perm_input - local_perm_input;
// A "horizontal" diff between the next permuted input and permuted table value.
let diff_input_table = next_perm_input - next_perm_table;
yield_constr.constraint(diff_input_prev * diff_input_table);
// This is actually constraining the first row, as per the spec, since `diff_input_table`
// is a diff of the next row's values. In the context of `constraint_last_row`, the next
// row is the first row.
yield_constr.constraint_last_row(diff_input_table);
}
}
pub(crate) fn eval_lookups_recursively<F: RichField + Extendable<D>, const D: usize>(
builder: &mut CircuitBuilder<F, D>,
vars: StarkEvaluationTargets<D, NUM_COLUMNS, NUM_PUBLIC_INPUTS>,
yield_constr: &mut RecursiveConstraintConsumer<F, D>,
) {
for i in 0..NUM_LOOKUPS {
let local_perm_input = vars.local_values[col_permuted_input(i)];
let next_perm_table = vars.next_values[col_permuted_table(i)];
let next_perm_input = vars.next_values[col_permuted_input(i)];
// A "vertical" diff between the local and next permuted inputs.
let diff_input_prev = builder.sub_extension(next_perm_input, local_perm_input);
// A "horizontal" diff between the next permuted input and permuted table value.
let diff_input_table = builder.sub_extension(next_perm_input, next_perm_table);
let diff_product = builder.mul_extension(diff_input_prev, diff_input_table);
yield_constr.constraint(builder, diff_product);
// This is actually constraining the first row, as per the spec, since `diff_input_table`
// is a diff of the next row's values. In the context of `constraint_last_row`, the next
// row is the first row.
yield_constr.constraint_last_row(builder, diff_input_table);
}
}

View File

@ -3,19 +3,35 @@
const START_UNIT: usize = super::START_LOOKUP;
const NUM_LOOKUPS: usize =
pub(crate) const NUM_LOOKUPS: usize =
super::range_check_16::NUM_RANGE_CHECKS + super::range_check_degree::NUM_RANGE_CHECKS;
pub(crate) const fn col_input(i: usize) -> usize {
if i < super::range_check_16::NUM_RANGE_CHECKS {
super::range_check_16::col_rc_16_input(i)
} else {
super::range_check_degree::col_rc_degree_input(i - super::range_check_16::NUM_RANGE_CHECKS)
}
}
/// This column contains a permutation of the input values.
const fn col_permuted_input(i: usize) -> usize {
pub(crate) const fn col_permuted_input(i: usize) -> usize {
debug_assert!(i < NUM_LOOKUPS);
START_UNIT + 2 * i
}
pub(crate) const fn col_table(i: usize) -> usize {
if i < super::range_check_16::NUM_RANGE_CHECKS {
super::core::COL_RANGE_16
} else {
super::core::COL_CLOCK
}
}
/// This column contains a permutation of the table values.
const fn col_permuted_table(i: usize) -> usize {
pub(crate) const fn col_permuted_table(i: usize) -> usize {
debug_assert!(i < NUM_LOOKUPS);
START_UNIT + 2 * i + 1
}
pub(super) const END: usize = START_UNIT + NUM_LOOKUPS;
pub(super) const END: usize = START_UNIT + NUM_LOOKUPS * 2;

View File

@ -1,6 +1,6 @@
//! Range check unit which checks that values are in `[0, 2^16)`.
pub(super) const NUM_RANGE_CHECKS: usize = 5;
pub(crate) const NUM_RANGE_CHECKS: usize = 5;
/// The input of the `i`th range check, i.e. the value being range checked.
pub(crate) const fn col_rc_16_input(i: usize) -> usize {

View File

@ -1,6 +1,6 @@
//! Range check unit which checks that values are in `[0, degree)`.
pub(super) const NUM_RANGE_CHECKS: usize = 5;
pub(crate) const NUM_RANGE_CHECKS: usize = 5;
/// The input of the `i`th range check, i.e. the value being range checked.
pub(crate) const fn col_rc_degree_input(i: usize) -> usize {

View File

@ -2,8 +2,12 @@ use std::marker::PhantomData;
use plonky2::field::extension_field::{Extendable, FieldExtension};
use plonky2::field::packed_field::PackedField;
use plonky2::field::polynomial::PolynomialValues;
use plonky2::hash::hash_types::RichField;
use plonky2::plonk::circuit_builder::CircuitBuilder;
use plonky2::timed;
use plonky2::util::timing::TimingTree;
use plonky2::util::transpose;
use starky::constraint_consumer::{ConstraintConsumer, RecursiveConstraintConsumer};
use starky::permutation::PermutationPair;
use starky::stark::Stark;
@ -15,12 +19,13 @@ use crate::core_registers::{
eval_core_registers, eval_core_registers_recursively, generate_first_row_core_registers,
generate_next_row_core_registers,
};
use crate::lookup::{eval_lookups, eval_lookups_recursively, generate_lookups};
use crate::memory::TransactionMemory;
use crate::permutation_unit::{
eval_permutation_unit, eval_permutation_unit_recursively, generate_permutation_unit,
};
use crate::public_input_layout::NUM_PUBLIC_INPUTS;
use crate::registers::NUM_COLUMNS;
use crate::registers::{lookup, NUM_COLUMNS};
/// We require at least 2^16 rows as it helps support efficient 16-bit range checks.
const MIN_TRACE_ROWS: usize = 1 << 16;
@ -31,7 +36,9 @@ pub struct SystemZero<F: RichField + Extendable<D>, const D: usize> {
}
impl<F: RichField + Extendable<D>, const D: usize> SystemZero<F, D> {
fn generate_trace(&self) -> Vec<[F; NUM_COLUMNS]> {
/// Generate the rows of the trace. Note that this does not generate the permuted columns used
/// in our lookup arguments, as those are computed after transposing to column-wise form.
fn generate_trace_rows(&self) -> Vec<[F; NUM_COLUMNS]> {
let memory = TransactionMemory::default();
let mut row = [F::ZERO; NUM_COLUMNS];
@ -59,6 +66,45 @@ impl<F: RichField + Extendable<D>, const D: usize> SystemZero<F, D> {
trace.push(row);
trace
}
fn generate_trace(&self) -> Vec<PolynomialValues<F>> {
let mut timing = TimingTree::new("generate trace", log::Level::Debug);
// Generate the witness, except for permuted columns in the lookup argument.
let trace_rows = timed!(
&mut timing,
"generate trace rows",
self.generate_trace_rows()
);
// Transpose from row-wise to column-wise.
let trace_row_vecs: Vec<_> = timed!(
&mut timing,
"convert to Vecs",
trace_rows.into_iter().map(|row| row.to_vec()).collect()
);
let mut trace_col_vecs: Vec<Vec<F>> =
timed!(&mut timing, "transpose", transpose(&trace_row_vecs));
// Generate permuted columns in the lookup argument.
timed!(
&mut timing,
"generate lookup columns",
generate_lookups(&mut trace_col_vecs)
);
let trace_polys = timed!(
&mut timing,
"convert to PolynomialValues",
trace_col_vecs
.into_iter()
.map(|column| PolynomialValues::new(column))
.collect()
);
timing.print();
trace_polys
}
}
impl<F: RichField + Extendable<D>, const D: usize> Default for SystemZero<F, D> {
@ -84,6 +130,7 @@ impl<F: RichField + Extendable<D>, const D: usize> Stark<F, D> for SystemZero<F,
eval_core_registers(vars, yield_constr);
eval_alu(vars, yield_constr);
eval_permutation_unit::<F, FE, P, D2>(vars, yield_constr);
eval_lookups(vars, yield_constr);
// TODO: Other units
}
@ -96,6 +143,7 @@ impl<F: RichField + Extendable<D>, const D: usize> Stark<F, D> for SystemZero<F,
eval_core_registers_recursively(builder, vars, yield_constr);
eval_alu_recursively(builder, vars, yield_constr);
eval_permutation_unit_recursively(builder, vars, yield_constr);
eval_lookups_recursively(builder, vars, yield_constr);
// TODO: Other units
}
@ -104,9 +152,22 @@ impl<F: RichField + Extendable<D>, const D: usize> Stark<F, D> for SystemZero<F,
}
fn permutation_pairs(&self) -> Vec<PermutationPair> {
let mut pairs = Vec::new();
for i in 0..lookup::NUM_LOOKUPS {
pairs.push(PermutationPair::singletons(
lookup::col_input(i),
lookup::col_permuted_input(i),
));
pairs.push(PermutationPair::singletons(
lookup::col_table(i),
lookup::col_permuted_table(i),
));
}
// TODO: Add permutation pairs for memory.
// TODO: Add permutation pairs for range checks.
vec![]
pairs
}
}
@ -127,8 +188,9 @@ mod tests {
use crate::system_zero::SystemZero;
#[test]
#[ignore] // A bit slow.
fn run() -> Result<()> {
init_logger();
type F = GoldilocksField;
type C = PoseidonGoldilocksConfig;
const D: usize = 2;
@ -154,4 +216,8 @@ mod tests {
let system = S::default();
test_stark_low_degree(system)
}
fn init_logger() {
let _ = env_logger::builder().format_timestamp(None).try_init();
}
}