Support accessing previous row in CTLs

This commit is contained in:
Daniel Lubarov 2022-08-22 11:07:39 -07:00
parent 8a8b3f36aa
commit a37dec9881
3 changed files with 137 additions and 73 deletions

View File

@ -145,6 +145,7 @@ mod tests {
use crate::cpu::cpu_stark::CpuStark; use crate::cpu::cpu_stark::CpuStark;
use crate::cpu::kernel::aggregator::KERNEL; use crate::cpu::kernel::aggregator::KERNEL;
use crate::cross_table_lookup::testutils::check_ctls; use crate::cross_table_lookup::testutils::check_ctls;
use crate::cross_table_lookup::Column;
use crate::keccak::keccak_stark::{KeccakStark, NUM_INPUTS, NUM_ROUNDS}; use crate::keccak::keccak_stark::{KeccakStark, NUM_INPUTS, NUM_ROUNDS};
use crate::logic::{self, LogicStark, Operation}; use crate::logic::{self, LogicStark, Operation};
use crate::memory::memory_stark::tests::generate_random_memory_ops; use crate::memory::memory_stark::tests::generate_random_memory_ops;
@ -216,8 +217,10 @@ mod tests {
.map(|i| { .map(|i| {
(0..2 * NUM_INPUTS) (0..2 * NUM_INPUTS)
.map(|j| { .map(|j| {
keccak::columns::reg_input_limb(j) // There's an extra -1 because the argument to eval_table is the local row,
.eval_table(keccak_trace, (i + 1) * NUM_ROUNDS - 1) // but the inputs/outputs live in the next row.
let local_row = (i + 1) * NUM_ROUNDS - 1 - 1;
keccak::columns::reg_input_limb(j).eval_table(keccak_trace, local_row)
}) })
.collect::<Vec<_>>() .collect::<Vec<_>>()
.try_into() .try_into()
@ -228,8 +231,11 @@ mod tests {
.map(|i| { .map(|i| {
(0..2 * NUM_INPUTS) (0..2 * NUM_INPUTS)
.map(|j| { .map(|j| {
keccak_trace[keccak::columns::reg_output_limb(j)].values let out_limb = Column::single(keccak::columns::reg_output_limb(j));
[(i + 1) * NUM_ROUNDS - 1] // There's an extra -1 because the argument to eval_table is the local row,
// but the inputs/outputs live in the next row.
let local_row = (i + 1) * NUM_ROUNDS - 1 - 1;
out_limb.eval_table(keccak_trace, local_row)
}) })
.collect::<Vec<_>>() .collect::<Vec<_>>()
.try_into() .try_into()

View File

@ -10,7 +10,7 @@ use plonky2::hash::hash_types::RichField;
use crate::constraint_consumer::{ConstraintConsumer, RecursiveConstraintConsumer}; use crate::constraint_consumer::{ConstraintConsumer, RecursiveConstraintConsumer};
use crate::cpu::columns::{CpuColumnsView, COL_MAP, NUM_CPU_COLUMNS}; use crate::cpu::columns::{CpuColumnsView, COL_MAP, NUM_CPU_COLUMNS};
use crate::cpu::{bootstrap_kernel, control_flow, decode, jumps, simple_logic, syscalls}; use crate::cpu::{bootstrap_kernel, control_flow, decode, jumps, simple_logic, syscalls};
use crate::cross_table_lookup::Column; use crate::cross_table_lookup::{Column, WeightedColumn};
use crate::memory::NUM_CHANNELS; use crate::memory::NUM_CHANNELS;
use crate::stark::Stark; use crate::stark::Stark;
use crate::vars::{StarkEvaluationTargets, StarkEvaluationVars}; use crate::vars::{StarkEvaluationTargets, StarkEvaluationVars};
@ -50,10 +50,14 @@ pub fn ctl_data_memory<F: Field>(channel: usize) -> Vec<Column<F>> {
.collect_vec(); .collect_vec();
cols.extend(Column::singles(COL_MAP.mem_value[channel])); cols.extend(Column::singles(COL_MAP.mem_value[channel]));
let scalar = F::from_canonical_usize(NUM_CHANNELS); let weight = F::from_canonical_usize(NUM_CHANNELS);
let addend = F::from_canonical_usize(channel); let addend = F::from_canonical_usize(channel);
cols.push(Column::linear_combination_with_constant( cols.push(Column::linear_combination_with_constant(
vec![(COL_MAP.clock, scalar)], vec![WeightedColumn {
column: COL_MAP.clock,
next: true,
weight,
}],
addend, addend,
)); ));

View File

@ -1,5 +1,3 @@
use std::iter::repeat;
use anyhow::{ensure, Result}; use anyhow::{ensure, Result};
use itertools::Itertools; use itertools::Itertools;
use plonky2::field::extension::{Extendable, FieldExtension}; use plonky2::field::extension::{Extendable, FieldExtension};
@ -24,16 +22,30 @@ use crate::stark::Stark;
use crate::vars::{StarkEvaluationTargets, StarkEvaluationVars}; use crate::vars::{StarkEvaluationTargets, StarkEvaluationVars};
/// Represent a linear combination of columns. /// Represent a linear combination of columns.
#[derive(Clone)] #[derive(Clone, Debug)]
pub struct Column<F: Field> { pub struct Column<F: Field> {
linear_combination: Vec<(usize, F)>, linear_combination: Vec<WeightedColumn<F>>,
constant: F, constant: F,
} }
#[derive(Clone, Debug)]
pub(crate) struct WeightedColumn<F: Field> {
/// The index of the column.
pub(crate) column: usize,
/// True if this column refers to a column in the next row, rather than the local row.
/// Most CTLs consist of only "next" columns.
pub(crate) next: bool,
pub(crate) weight: F,
}
impl<F: Field> Column<F> { impl<F: Field> Column<F> {
pub fn single(c: usize) -> Self { pub fn single(column: usize) -> Self {
Self { Self {
linear_combination: vec![(c, F::ONE)], linear_combination: vec![WeightedColumn {
column,
next: true,
weight: F::ONE,
}],
constant: F::ZERO, constant: F::ZERO,
} }
} }
@ -42,14 +54,17 @@ impl<F: Field> Column<F> {
cs.into_iter().map(Self::single) cs.into_iter().map(Self::single)
} }
pub fn linear_combination_with_constant<I: IntoIterator<Item = (usize, F)>>( pub(crate) fn linear_combination_with_constant<I: IntoIterator<Item = WeightedColumn<F>>>(
iter: I, iter: I,
constant: F, constant: F,
) -> Self { ) -> Self {
let v = iter.into_iter().collect::<Vec<_>>(); let v = iter.into_iter().collect::<Vec<_>>();
assert!(!v.is_empty()); assert!(!v.is_empty());
debug_assert_eq!( debug_assert_eq!(
v.iter().map(|(c, _)| c).unique().count(), v.iter()
.map(|weighted_col| weighted_col.column)
.unique()
.count(),
v.len(), v.len(),
"Duplicate columns." "Duplicate columns."
); );
@ -59,35 +74,65 @@ impl<F: Field> Column<F> {
} }
} }
pub fn linear_combination<I: IntoIterator<Item = (usize, F)>>(iter: I) -> Self { pub(crate) fn linear_combination<I: IntoIterator<Item = WeightedColumn<F>>>(iter: I) -> Self {
Self::linear_combination_with_constant(iter, F::ZERO) Self::linear_combination_with_constant(iter, F::ZERO)
} }
pub fn le_bits<I: IntoIterator<Item = usize>>(cs: I) -> Self { pub fn le_bits<I: IntoIterator<Item = usize>>(cs: I) -> Self {
Self::linear_combination(cs.into_iter().zip(F::TWO.powers())) Self::linear_combination(cs.into_iter().zip(F::TWO.powers()).map(|(column, weight)| {
WeightedColumn {
column,
next: true,
weight,
}
}))
} }
pub fn sum<I: IntoIterator<Item = usize>>(cs: I) -> Self { pub fn sum<I: IntoIterator<Item = usize>>(cs: I) -> Self {
Self::linear_combination(cs.into_iter().zip(repeat(F::ONE))) Self::linear_combination(cs.into_iter().map(|column| WeightedColumn {
column,
next: true,
weight: F::ONE,
}))
} }
pub fn eval<FE, P, const D: usize>(&self, v: &[P]) -> P pub fn eval<FE, P, const D: usize>(&self, local_values: &[P], next_values: &[P]) -> P
where where
FE: FieldExtension<D, BaseField = F>, FE: FieldExtension<D, BaseField = F>,
P: PackedField<Scalar = FE>, P: PackedField<Scalar = FE>,
{ {
self.linear_combination self.linear_combination
.iter() .iter()
.map(|&(c, f)| v[c] * FE::from_basefield(f)) .map(|weighted_col| {
let values = if weighted_col.next {
next_values
} else {
local_values
};
values[weighted_col.column] * FE::from_basefield(weighted_col.weight)
})
.sum::<P>() .sum::<P>()
+ FE::from_basefield(self.constant) + FE::from_basefield(self.constant)
} }
/// Evaluate on an row of a table given in column-major form. /// Evaluate on an row of a table given in column-major form.
pub fn eval_table(&self, table: &[PolynomialValues<F>], row: usize) -> F { pub fn eval_table(&self, table: &[PolynomialValues<F>], local_row: usize) -> F {
let mut next_row = local_row + 1;
if next_row == table[0].len() {
next_row = 0;
}
self.linear_combination self.linear_combination
.iter() .iter()
.map(|&(c, f)| table[c].values[row] * f) .map(|weighted_col| {
let row = if weighted_col.next {
next_row
} else {
local_row
};
let poly = &table[weighted_col.column];
poly.values[row] * weighted_col.weight
})
.sum::<F>() .sum::<F>()
+ self.constant + self.constant
} }
@ -95,7 +140,8 @@ impl<F: Field> Column<F> {
pub fn eval_circuit<const D: usize>( pub fn eval_circuit<const D: usize>(
&self, &self,
builder: &mut CircuitBuilder<F, D>, builder: &mut CircuitBuilder<F, D>,
v: &[ExtensionTarget<D>], local_values: &[ExtensionTarget<D>],
next_values: &[ExtensionTarget<D>],
) -> ExtensionTarget<D> ) -> ExtensionTarget<D>
where where
F: RichField + Extendable<D>, F: RichField + Extendable<D>,
@ -103,10 +149,15 @@ impl<F: Field> Column<F> {
let pairs = self let pairs = self
.linear_combination .linear_combination
.iter() .iter()
.map(|&(c, f)| { .map(|weighted_col| {
let values = if weighted_col.next {
next_values
} else {
local_values
};
( (
v[c], values[weighted_col.column],
builder.constant_extension(F::Extension::from_basefield(f)), builder.constant_extension(F::Extension::from_basefield(weighted_col.weight)),
) )
}) })
.collect::<Vec<_>>(); .collect::<Vec<_>>();
@ -115,7 +166,7 @@ impl<F: Field> Column<F> {
} }
} }
#[derive(Clone)] #[derive(Clone, Debug)]
pub struct TableWithColumns<F: Field> { pub struct TableWithColumns<F: Field> {
table: Table, table: Table,
columns: Vec<Column<F>>, columns: Vec<Column<F>>,
@ -132,7 +183,7 @@ impl<F: Field> TableWithColumns<F> {
} }
} }
#[derive(Clone)] #[derive(Clone, Debug)]
pub struct CrossTableLookup<F: Field> { pub struct CrossTableLookup<F: Field> {
looking_tables: Vec<TableWithColumns<F>>, looking_tables: Vec<TableWithColumns<F>>,
looked_table: TableWithColumns<F>, looked_table: TableWithColumns<F>,
@ -279,16 +330,22 @@ fn partial_products<F: Field>(
let mut partial_prod = F::ONE; let mut partial_prod = F::ONE;
let degree = trace[0].len(); let degree = trace[0].len();
let mut res = Vec::with_capacity(degree); let mut res = Vec::with_capacity(degree);
for i in 0..degree { for next_row in 0..degree {
let local_row = if next_row == 0 {
degree - 1
} else {
next_row - 1
};
let filter = if let Some(column) = filter_column { let filter = if let Some(column) = filter_column {
column.eval_table(trace, i) column.eval_table(trace, local_row)
} else { } else {
F::ONE F::ONE
}; };
if filter.is_one() { if filter.is_one() {
let evals = columns let evals = columns
.iter() .iter()
.map(|c| c.eval_table(trace, i)) .map(|c| c.eval_table(trace, local_row))
.collect::<Vec<_>>(); .collect::<Vec<_>>();
partial_prod *= challenge.combine(evals.iter()); partial_prod *= challenge.combine(evals.iter());
} else { } else {
@ -386,27 +443,23 @@ pub(crate) fn eval_cross_table_lookup_checks<F, FE, P, C, S, const D: usize, con
columns, columns,
filter_column, filter_column,
} = lookup_vars; } = lookup_vars;
let combine = |v: &[P]| -> P { // TODO: Avoid collecting here.
let evals = columns.iter().map(|c| c.eval(v)).collect::<Vec<_>>(); let evals = columns
challenges.combine(evals.iter()) .iter()
.map(|c| c.eval(vars.local_values, vars.next_values))
.collect::<Vec<_>>();
let combined = challenges.combine(evals.iter());
let filter = if let Some(column) = filter_column {
column.eval(vars.local_values, vars.next_values)
} else {
P::ONES
}; };
let filter = |v: &[P]| -> P { let multiplier = filter * combined + P::ONES - filter;
if let Some(column) = filter_column {
column.eval(v)
} else {
P::ONES
}
};
let local_filter = filter(vars.local_values);
let next_filter = filter(vars.next_values);
let select = |filter, x| filter * x + P::ONES - filter;
// Check value of `Z(1)` // Check value of `Z(1)`
consumer.constraint_first_row(*local_z - select(local_filter, combine(vars.local_values))); consumer.constraint_last_row(*next_z - multiplier);
// Check `Z(gw) = combination * Z(w)` // Check `Z(gw) = combination * Z(w)`
consumer.constraint_transition( consumer.constraint_transition(*next_z - *local_z * multiplier);
*next_z - *local_z * select(next_filter, combine(vars.next_values)),
);
} }
} }
@ -491,16 +544,12 @@ pub(crate) fn eval_cross_table_lookup_checks_circuit<
} = lookup_vars; } = lookup_vars;
let one = builder.one_extension(); let one = builder.one_extension();
let local_filter = if let Some(column) = filter_column { let filter = if let Some(column) = filter_column {
column.eval_circuit(builder, vars.local_values) column.eval_circuit(builder, vars.local_values, vars.next_values)
} else {
one
};
let next_filter = if let Some(column) = filter_column {
column.eval_circuit(builder, vars.next_values)
} else { } else {
one one
}; };
// TODO: Can use builder.select_ext_generalized.
fn select<F: RichField + Extendable<D>, const D: usize>( fn select<F: RichField + Extendable<D>, const D: usize>(
builder: &mut CircuitBuilder<F, D>, builder: &mut CircuitBuilder<F, D>,
filter: ExtensionTarget<D>, filter: ExtensionTarget<D>,
@ -512,23 +561,17 @@ pub(crate) fn eval_cross_table_lookup_checks_circuit<
} }
// Check value of `Z(1)` // Check value of `Z(1)`
let local_columns_eval = columns let evals = columns
.iter() .iter()
.map(|c| c.eval_circuit(builder, vars.local_values)) .map(|c| c.eval_circuit(builder, vars.local_values, vars.next_values))
.collect::<Vec<_>>(); .collect::<Vec<_>>();
let combined_local = challenges.combine_circuit(builder, &local_columns_eval); let combined = challenges.combine_circuit(builder, &evals);
let selected_local = select(builder, local_filter, combined_local); let multiplier = select(builder, filter, combined);
let first_row = builder.sub_extension(*local_z, selected_local); let first_row = builder.sub_extension(*next_z, multiplier);
consumer.constraint_first_row(builder, first_row); consumer.constraint_last_row(builder, first_row);
// Check `Z(gw) = combination * Z(w)` // Check `Z(gw) = combination * Z(w)`
let next_columns_eval = columns let product = builder.mul_extension(*local_z, multiplier);
.iter() let transition = builder.sub_extension(*next_z, product);
.map(|c| c.eval_circuit(builder, vars.next_values))
.collect::<Vec<_>>();
let combined_next = challenges.combine_circuit(builder, &next_columns_eval);
let selected_next = select(builder, next_filter, combined_next);
let mut transition = builder.mul_extension(*local_z, selected_next);
transition = builder.sub_extension(*next_z, transition);
consumer.constraint_transition(builder, transition); consumer.constraint_transition(builder, transition);
} }
} }
@ -746,9 +789,17 @@ pub(crate) mod testutils {
multiset: &mut MultiSet<F>, multiset: &mut MultiSet<F>,
) { ) {
let trace = &trace_poly_values[table.table as usize]; let trace = &trace_poly_values[table.table as usize];
for i in 0..trace[0].len() { let degree = trace[0].len();
for next_row in 0..trace[0].len() {
let local_row = if next_row == 0 {
degree - 1
} else {
next_row - 1
};
let filter = if let Some(column) = &table.filter_column { let filter = if let Some(column) = &table.filter_column {
column.eval_table(trace, i) column.eval_table(trace, local_row)
} else { } else {
F::ONE F::ONE
}; };
@ -756,9 +807,12 @@ pub(crate) mod testutils {
let row = table let row = table
.columns .columns
.iter() .iter()
.map(|c| c.eval_table(trace, i)) .map(|c| c.eval_table(trace, local_row))
.collect::<Vec<_>>(); .collect::<Vec<_>>();
multiset.entry(row).or_default().push((table.table, i)); multiset
.entry(row)
.or_default()
.push((table.table, local_row));
} else { } else {
assert_eq!(filter, F::ZERO, "Non-binary filter?") assert_eq!(filter, F::ZERO, "Non-binary filter?")
} }