Column enum

This commit is contained in:
wborgeaud 2022-06-14 00:53:31 +02:00
parent 732002691b
commit d626679c6c
6 changed files with 166 additions and 74 deletions

View File

@ -62,7 +62,7 @@ mod tests {
use crate::config::StarkConfig;
use crate::cpu::columns::{KECCAK_INPUT_LIMBS, KECCAK_OUTPUT_LIMBS};
use crate::cpu::cpu_stark::CpuStark;
use crate::cross_table_lookup::{CrossTableLookup, TableWithColumns};
use crate::cross_table_lookup::{Column, CrossTableLookup, TableWithColumns};
use crate::keccak::keccak_stark::{KeccakStark, NUM_INPUTS, NUM_ROUNDS};
use crate::proof::AllProof;
use crate::prover::prove;
@ -148,13 +148,13 @@ mod tests {
let cross_table_lookups = vec![CrossTableLookup::new(
vec![TableWithColumns::new(
Table::Cpu,
cpu_keccak_input_output,
vec![cpu::columns::IS_KECCAK],
Column::singles(cpu_keccak_input_output),
Column::single(cpu::columns::IS_KECCAK),
)],
TableWithColumns::new(
Table::Keccak,
keccak_keccak_input_output,
vec![keccak::registers::reg_step(NUM_ROUNDS - 1)],
Column::singles(keccak_keccak_input_output),
Column::single(keccak::registers::reg_step(NUM_ROUNDS - 1)),
),
None,
)];

View File

@ -1,3 +1,5 @@
use std::iter::repeat;
use anyhow::{ensure, Result};
use itertools::Itertools;
use plonky2::field::extension_field::{Extendable, FieldExtension};
@ -21,42 +23,129 @@ use crate::proof::{StarkProofWithPublicInputs, StarkProofWithPublicInputsTarget}
use crate::stark::Stark;
use crate::vars::{StarkEvaluationTargets, StarkEvaluationVars};
/// Represent a column or a linear combination of columns.
#[derive(Clone)]
pub struct TableWithColumns {
table: Table,
columns: Vec<usize>,
/// Vector of columns `[c_1,...,c_k]` used as a filter using the sum `c_1 + ... + c_k`.
/// An empty vector corresponds to no filter.
filter_columns: Vec<usize>,
pub enum Column<F: Field> {
Single(usize),
LinearCombination(Vec<(usize, F)>),
Empty,
}
impl TableWithColumns {
pub fn new(table: Table, columns: Vec<usize>, filter_columns: Vec<usize>) -> Self {
impl<F: Field> Column<F> {
pub fn single(c: usize) -> Self {
Self::Single(c)
}
pub fn singles(cs: Vec<usize>) -> Vec<Self> {
cs.into_iter().map(Self::single).collect()
}
pub fn linear_combination<I: IntoIterator<Item = (usize, F)>>(iter: I) -> Self {
let v = iter.into_iter().collect::<Vec<_>>();
assert!(!v.is_empty());
debug_assert_eq!(
filter_columns.iter().unique().count(),
filter_columns.len(),
v.iter().map(|(c, _)| c).unique().count(),
v.len(),
"Duplicate filter columns."
);
Self::LinearCombination(v)
}
pub fn le_bits(cs: &[usize]) -> Self {
Self::linear_combination(cs.iter().copied().zip(F::TWO.powers()))
}
pub fn sum(cs: &[usize]) -> Self {
Self::linear_combination(cs.iter().copied().zip(repeat(F::ONE)))
}
pub fn is_empty(&self) -> bool {
matches!(self, Self::Empty)
}
pub fn eval<FE, P, const D: usize>(&self, v: &[P]) -> P
where
FE: FieldExtension<D, BaseField = F>,
P: PackedField<Scalar = FE>,
{
match self {
Column::Single(c) => v[*c],
Column::LinearCombination(cs) => {
cs.iter().map(|&(c, f)| v[c] * FE::from_basefield(f)).sum()
}
Column::Empty => panic!("Cannot eval with empty column."),
}
}
/// Evaluate on an row of a table given in column-major form.
pub fn eval_table(&self, table: &[PolynomialValues<F>], row: usize) -> F {
match self {
Column::Single(c) => table[*c].values[row],
Column::LinearCombination(cs) => {
cs.iter().map(|&(c, f)| table[c].values[row] * f).sum()
}
Column::Empty => panic!("Cannot eval with empty column."),
}
}
pub fn eval_circuit<const D: usize>(
&self,
builder: &mut CircuitBuilder<F, D>,
v: &[ExtensionTarget<D>],
) -> ExtensionTarget<D>
where
F: RichField + Extendable<D>,
{
match self {
Column::Single(c) => v[*c],
Column::LinearCombination(cs) => {
let pairs = cs
.iter()
.map(|&(c, f)| {
(
v[c],
builder.constant_extension(F::Extension::from_basefield(f)),
)
})
.collect::<Vec<_>>();
let zero = builder.zero_extension();
builder.inner_product_extension(F::ONE, zero, pairs)
}
Column::Empty => panic!("Cannot eval with empty column."),
}
}
}
#[derive(Clone)]
pub struct TableWithColumns<F: Field> {
table: Table,
columns: Vec<Column<F>>,
filter_column: Column<F>,
}
impl<F: Field> TableWithColumns<F> {
pub fn new(table: Table, columns: Vec<Column<F>>, filter_column: Column<F>) -> Self {
assert!(columns.iter().all(|c| !c.is_empty()));
Self {
table,
columns,
filter_columns,
filter_column,
}
}
}
#[derive(Clone)]
pub struct CrossTableLookup<F: Field> {
looking_tables: Vec<TableWithColumns>,
looked_table: TableWithColumns,
looking_tables: Vec<TableWithColumns<F>>,
looked_table: TableWithColumns<F>,
/// Default value if filters are not used.
default: Option<Vec<F>>,
}
impl<F: Field> CrossTableLookup<F> {
pub fn new(
looking_tables: Vec<TableWithColumns>,
looked_table: TableWithColumns,
looking_tables: Vec<TableWithColumns<F>>,
looked_table: TableWithColumns<F>,
default: Option<Vec<F>>,
) -> Self {
assert!(looking_tables
@ -65,8 +154,8 @@ impl<F: Field> CrossTableLookup<F> {
assert!(
looking_tables
.iter()
.all(|twc| twc.filter_columns.is_empty() == default.is_some())
&& default.is_some() == looked_table.filter_columns.is_empty(),
.all(|twc| twc.filter_column.is_empty() == default.is_some())
&& default.is_some() == looked_table.filter_column.is_empty(),
"Default values should be provided iff there are no filter columns."
);
if let Some(default) = &default {
@ -87,7 +176,7 @@ pub struct CtlData<F: Field> {
pub(crate) challenges: GrandProductChallengeSet<F>,
/// Vector of `(Z, columns, filter_columns)` where `Z` is a Z-polynomial for a lookup
/// on columns `columns` with filter columns `filter_columns`.
pub zs_columns: Vec<(PolynomialValues<F>, Vec<usize>, Vec<usize>)>,
pub zs_columns: Vec<(PolynomialValues<F>, Vec<Column<F>>, Column<F>)>,
}
impl<F: Field> CtlData<F> {
@ -130,14 +219,14 @@ pub fn cross_table_lookup_data<F: RichField, C: GenericConfig<D, F = F>, const D
partial_products(
&trace_poly_values[table.table as usize],
&table.columns,
&table.filter_columns,
&table.filter_column,
challenge,
)
});
let z_looked = partial_products(
&trace_poly_values[looked_table.table as usize],
&looked_table.columns,
&looked_table.filter_columns,
&looked_table.filter_column,
challenge,
);
@ -168,7 +257,7 @@ pub fn cross_table_lookup_data<F: RichField, C: GenericConfig<D, F = F>, const D
ctl_data_per_table[table.table as usize].zs_columns.push((
z,
table.columns.clone(),
table.filter_columns.clone(),
table.filter_column.clone(),
));
}
ctl_data_per_table[looked_table.table as usize]
@ -176,7 +265,7 @@ pub fn cross_table_lookup_data<F: RichField, C: GenericConfig<D, F = F>, const D
.push((
z_looked,
looked_table.columns.clone(),
looked_table.filter_columns.clone(),
looked_table.filter_column.clone(),
));
}
}
@ -185,21 +274,25 @@ pub fn cross_table_lookup_data<F: RichField, C: GenericConfig<D, F = F>, const D
fn partial_products<F: Field>(
trace: &[PolynomialValues<F>],
columns: &[usize],
filter_columns: &[usize],
columns: &[Column<F>],
filter_column: &Column<F>,
challenge: GrandProductChallenge<F>,
) -> PolynomialValues<F> {
let mut partial_prod = F::ONE;
let degree = trace[0].len();
let mut res = Vec::with_capacity(degree);
for i in 0..degree {
let filter = if filter_columns.is_empty() {
let filter = if filter_column.is_empty() {
F::ONE
} else {
filter_columns.iter().map(|&j| trace[j].values[i]).sum()
filter_column.eval_table(trace, i)
};
if filter.is_one() {
partial_prod *= challenge.combine(columns.iter().map(|&j| &trace[j].values[i]));
let evals = columns
.iter()
.map(|c| c.eval_table(trace, i))
.collect::<Vec<_>>();
partial_prod *= challenge.combine(evals.iter());
} else {
assert_eq!(filter, F::ZERO, "Non-binary filter?")
};
@ -218,8 +311,8 @@ where
pub(crate) local_z: P,
pub(crate) next_z: P,
pub(crate) challenges: GrandProductChallenge<F>,
pub(crate) columns: &'a [usize],
pub(crate) filter_columns: &'a [usize],
pub(crate) columns: &'a [Column<F>],
pub(crate) filter_column: &'a Column<F>,
}
impl<'a, F: RichField + Extendable<D>, const D: usize>
@ -258,7 +351,7 @@ impl<'a, F: RichField + Extendable<D>, const D: usize>
next_z: *looking_z_next,
challenges,
columns: &table.columns,
filter_columns: &table.filter_columns,
filter_column: &table.filter_column,
});
}
@ -268,7 +361,7 @@ impl<'a, F: RichField + Extendable<D>, const D: usize>
next_z: *looked_z_next,
challenges,
columns: &looked_table.columns,
filter_columns: &looked_table.filter_columns,
filter_column: &looked_table.filter_column,
});
}
}
@ -293,14 +386,17 @@ pub(crate) fn eval_cross_table_lookup_checks<F, FE, P, C, S, const D: usize, con
next_z,
challenges,
columns,
filter_columns,
filter_column,
} = lookup_vars;
let combine = |v: &[P]| -> P { challenges.combine(columns.iter().map(|&i| &v[i])) };
let combine = |v: &[P]| -> P {
let evals = columns.iter().map(|c| c.eval(v)).collect::<Vec<_>>();
challenges.combine(evals.iter())
};
let filter = |v: &[P]| -> P {
if filter_columns.is_empty() {
if filter_column.is_empty() {
P::ONES
} else {
filter_columns.iter().map(|&i| v[i]).sum()
filter_column.eval(v)
}
};
let local_filter = filter(vars.local_values);
@ -317,16 +413,16 @@ pub(crate) fn eval_cross_table_lookup_checks<F, FE, P, C, S, const D: usize, con
}
#[derive(Clone)]
pub struct CtlCheckVarsTarget<'a, const D: usize> {
pub struct CtlCheckVarsTarget<'a, F: Field, const D: usize> {
pub(crate) local_z: ExtensionTarget<D>,
pub(crate) next_z: ExtensionTarget<D>,
pub(crate) challenges: GrandProductChallenge<Target>,
pub(crate) columns: &'a [usize],
pub(crate) filter_columns: &'a [usize],
pub(crate) columns: &'a [Column<F>],
pub(crate) filter_column: &'a Column<F>,
}
impl<'a, const D: usize> CtlCheckVarsTarget<'a, D> {
pub(crate) fn from_proofs<F: Field>(
impl<'a, F: Field, const D: usize> CtlCheckVarsTarget<'a, F, D> {
pub(crate) fn from_proofs(
proofs: &[StarkProofWithPublicInputsTarget<D>],
cross_table_lookups: &'a [CrossTableLookup<F>],
ctl_challenges: &'a GrandProductChallengeSet<Target>,
@ -359,7 +455,7 @@ impl<'a, const D: usize> CtlCheckVarsTarget<'a, D> {
next_z: *looking_z_next,
challenges,
columns: &table.columns,
filter_columns: &table.filter_columns,
filter_column: &table.filter_column,
});
}
@ -369,7 +465,7 @@ impl<'a, const D: usize> CtlCheckVarsTarget<'a, D> {
next_z: *looked_z_next,
challenges,
columns: &looked_table.columns,
filter_columns: &looked_table.filter_columns,
filter_column: &looked_table.filter_column,
});
}
}
@ -384,7 +480,7 @@ pub(crate) fn eval_cross_table_lookup_checks_circuit<
>(
builder: &mut CircuitBuilder<F, D>,
vars: StarkEvaluationTargets<D, { S::COLUMNS }, { S::PUBLIC_INPUTS }>,
ctl_vars: &[CtlCheckVarsTarget<D>],
ctl_vars: &[CtlCheckVarsTarget<F, D>],
consumer: &mut RecursiveConstraintConsumer<F, D>,
) {
for lookup_vars in ctl_vars {
@ -393,19 +489,19 @@ pub(crate) fn eval_cross_table_lookup_checks_circuit<
next_z,
challenges,
columns,
filter_columns,
filter_column,
} = lookup_vars;
let one = builder.one_extension();
let local_filter = if filter_columns.is_empty() {
let local_filter = if filter_column.is_empty() {
one
} else {
builder.add_many_extension(filter_columns.iter().map(|&i| vars.local_values[i]))
filter_column.eval_circuit(builder, vars.local_values)
};
let next_filter = if filter_columns.is_empty() {
let next_filter = if filter_column.is_empty() {
one
} else {
builder.add_many_extension(filter_columns.iter().map(|&i| vars.next_values[i]))
filter_column.eval_circuit(builder, vars.next_values)
};
fn select<F: RichField + Extendable<D>, const D: usize>(
builder: &mut CircuitBuilder<F, D>,
@ -418,24 +514,20 @@ pub(crate) fn eval_cross_table_lookup_checks_circuit<
}
// Check value of `Z(1)`
let combined_local = challenges.combine_circuit(
builder,
&columns
.iter()
.map(|&i| vars.local_values[i])
.collect::<Vec<_>>(),
);
let local_columns_eval = columns
.iter()
.map(|c| c.eval_circuit(builder, vars.local_values))
.collect::<Vec<_>>();
let combined_local = challenges.combine_circuit(builder, &local_columns_eval);
let selected_local = select(builder, local_filter, combined_local);
let first_row = builder.sub_extension(*local_z, selected_local);
consumer.constraint_first_row(builder, first_row);
// Check `Z(gw) = combination * Z(w)`
let combined_next = challenges.combine_circuit(
builder,
&columns
.iter()
.map(|&i| vars.next_values[i])
.collect::<Vec<_>>(),
);
let next_columns_eval = columns
.iter()
.map(|c| c.eval_circuit(builder, vars.next_values))
.collect::<Vec<_>>();
let combined_next = challenges.combine_circuit(builder, &next_columns_eval);
let selected_next = select(builder, next_filter, combined_next);
let mut transition = builder.mul_extension(*local_z, selected_next);
transition = builder.sub_extension(*next_z, transition);

View File

@ -392,14 +392,14 @@ where
.iter()
.enumerate()
.map(
|(i, (_, columns, filter_columns))| CtlCheckVars::<F, F, P, 1> {
|(i, (_, columns, filter_column))| CtlCheckVars::<F, F, P, 1> {
local_z: permutation_ctl_zs_commitment.get_lde_values_packed(i_start, step)
[num_permutation_zs + i],
next_z: permutation_ctl_zs_commitment
.get_lde_values_packed(i_next_start, step)[num_permutation_zs + i],
challenges: ctl_data.challenges.challenges[i % config.num_challenges],
columns,
filter_columns,
filter_column,
},
)
.collect::<Vec<_>>();
@ -510,14 +510,14 @@ fn check_constraints<'a, F, C, S, const D: usize>(
.iter()
.enumerate()
.map(
|(iii, (_, columns, filter_columns))| CtlCheckVars::<F, F, F, 1> {
|(iii, (_, columns, filter_column))| CtlCheckVars::<F, F, F, 1> {
local_z: get_comm_values(permutation_ctl_zs_commitment, i)
[num_permutation_zs + iii],
next_z: get_comm_values(permutation_ctl_zs_commitment, i_next)
[num_permutation_zs + iii],
challenges: ctl_data.challenges.challenges[iii % config.num_challenges],
columns,
filter_columns,
filter_column,
},
)
.collect::<Vec<_>>();

View File

@ -100,7 +100,7 @@ fn verify_stark_proof_with_challenges_circuit<
stark: S,
proof_with_pis: &StarkProofWithPublicInputsTarget<D>,
challenges: &StarkProofChallengesTarget<D>,
ctl_vars: &[CtlCheckVarsTarget<D>],
ctl_vars: &[CtlCheckVarsTarget<F, D>],
inner_config: &StarkConfig,
) where
C::Hasher: AlgebraicHasher<F>,

View File

@ -50,7 +50,7 @@ pub(crate) fn eval_vanishing_poly_circuit<F, C, S, const D: usize>(
config: &StarkConfig,
vars: StarkEvaluationTargets<D, { S::COLUMNS }, { S::PUBLIC_INPUTS }>,
permutation_data: Option<PermutationCheckDataTarget<D>>,
ctl_vars: &[CtlCheckVarsTarget<D>],
ctl_vars: &[CtlCheckVarsTarget<F, D>],
consumer: &mut RecursiveConstraintConsumer<F, D>,
) where
F: RichField + Extendable<D>,

View File

@ -30,7 +30,7 @@ impl<F: RichField + Extendable<D>, const D: usize> CircuitBuilder<F, D> {
/// Takes an iterator of bits `(b_i)` and returns `sum b_i * 2^i`, i.e.,
/// the number with little-endian bit representation given by `bits`.
pub(crate) fn le_sum(&mut self, bits: impl Iterator<Item = impl Borrow<BoolTarget>>) -> Target {
pub fn le_sum(&mut self, bits: impl Iterator<Item = impl Borrow<BoolTarget>>) -> Target {
let bits = bits.map(|b| *b.borrow()).collect_vec();
let num_bits = bits.len();
if num_bits == 0 {