Move next row logic inside Column

Co-authored-by: Nicholas Ward <npward@berkeley.edu>
Hamy Ratoanina 2023-09-15 18:14:07 -04:00
parent 27d9113feb
commit 1a4caaa08f
6 changed files with 137 additions and 101 deletions
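In short: `Column` now carries an optional next-row part (`next_row_linear_combination`), so cross-table lookups no longer pass a separate `next_columns: Vec<Column<F>>` to `TableWithColumns::new`. A minimal sketch of a CTL declaration against the new API follows; the module paths, the `ctl_example` name, and the column indices are placeholders for illustration, not part of this commit.

use plonky2::field::types::Field;

use crate::all_stark::Table;
use crate::cross_table_lookup::{Column, TableWithColumns};

// Hypothetical CTL declaration written against the post-commit API.
fn ctl_example<F: Field>() -> TableWithColumns<F> {
    // Local-row columns, as before.
    let mut columns: Vec<Column<F>> = Column::singles([0usize, 1, 2]).collect();
    // A value read from the next row is now just another `Column`,
    // instead of going through the old `next_columns` argument.
    columns.push(Column::single_next_row(3));
    // One combined column list plus the filter.
    TableWithColumns::new(Table::Cpu, columns, Some(Column::single(4)))
}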

View File

@@ -128,19 +128,16 @@ fn ctl_byte_packing<F: Field>() -> CrossTableLookup<F> {
let cpu_packing_looking = TableWithColumns::new(
Table::Cpu,
cpu_stark::ctl_data_byte_packing(),
vec![],
Some(cpu_stark::ctl_filter_byte_packing()),
);
let cpu_unpacking_looking = TableWithColumns::new(
Table::Cpu,
cpu_stark::ctl_data_byte_unpacking(),
vec![],
Some(cpu_stark::ctl_filter_byte_unpacking()),
);
let byte_packing_looked = TableWithColumns::new(
Table::BytePacking,
byte_packing_stark::ctl_looked_data(),
vec![],
Some(byte_packing_stark::ctl_looked_filter()),
);
CrossTableLookup::new(
@@ -153,13 +150,11 @@ fn ctl_keccak<F: Field>() -> CrossTableLookup<F> {
let keccak_sponge_looking = TableWithColumns::new(
Table::KeccakSponge,
keccak_sponge_stark::ctl_looking_keccak(),
vec![],
Some(keccak_sponge_stark::ctl_looking_keccak_filter()),
);
let keccak_looked = TableWithColumns::new(
Table::Keccak,
keccak_stark::ctl_data(),
vec![],
Some(keccak_stark::ctl_filter()),
);
CrossTableLookup::new(vec![keccak_sponge_looking], keccak_looked)
@@ -169,13 +164,11 @@ fn ctl_keccak_sponge<F: Field>() -> CrossTableLookup<F> {
let cpu_looking = TableWithColumns::new(
Table::Cpu,
cpu_stark::ctl_data_keccak_sponge(),
vec![],
Some(cpu_stark::ctl_filter_keccak_sponge()),
);
let keccak_sponge_looked = TableWithColumns::new(
Table::KeccakSponge,
keccak_sponge_stark::ctl_looked_data(),
vec![],
Some(keccak_sponge_stark::ctl_looked_filter()),
);
CrossTableLookup::new(vec![cpu_looking], keccak_sponge_looked)
@@ -185,7 +178,6 @@ fn ctl_logic<F: Field>() -> CrossTableLookup<F> {
let cpu_looking = TableWithColumns::new(
Table::Cpu,
cpu_stark::ctl_data_logic(),
vec![],
Some(cpu_stark::ctl_filter_logic()),
);
let mut all_lookers = vec![cpu_looking];
@@ -193,17 +185,12 @@ fn ctl_logic<F: Field>() -> CrossTableLookup<F> {
let keccak_sponge_looking = TableWithColumns::new(
Table::KeccakSponge,
keccak_sponge_stark::ctl_looking_logic(i),
vec![],
Some(keccak_sponge_stark::ctl_looking_logic_filter()),
);
all_lookers.push(keccak_sponge_looking);
}
let logic_looked = TableWithColumns::new(
Table::Logic,
logic::ctl_data(),
vec![],
Some(logic::ctl_filter()),
);
let logic_looked =
TableWithColumns::new(Table::Logic, logic::ctl_data(), Some(logic::ctl_filter()));
CrossTableLookup::new(all_lookers, logic_looked)
}
@@ -211,14 +198,12 @@ fn ctl_memory<F: Field>() -> CrossTableLookup<F> {
let cpu_memory_code_read = TableWithColumns::new(
Table::Cpu,
cpu_stark::ctl_data_code_memory(),
vec![],
Some(cpu_stark::ctl_filter_code_memory()),
);
let cpu_memory_gp_ops = (0..NUM_GP_CHANNELS).map(|channel| {
TableWithColumns::new(
Table::Cpu,
cpu_stark::ctl_data_gp_memory(channel),
vec![],
Some(cpu_stark::ctl_filter_gp_memory(channel)),
)
});
@@ -226,7 +211,6 @@ fn ctl_memory<F: Field>() -> CrossTableLookup<F> {
TableWithColumns::new(
Table::KeccakSponge,
keccak_sponge_stark::ctl_looking_memory(i),
vec![],
Some(keccak_sponge_stark::ctl_looking_memory_filter(i)),
)
});
@@ -234,7 +218,6 @@ fn ctl_memory<F: Field>() -> CrossTableLookup<F> {
TableWithColumns::new(
Table::BytePacking,
byte_packing_stark::ctl_looking_memory(i),
vec![],
Some(byte_packing_stark::ctl_looking_memory_filter(i)),
)
});
@@ -246,7 +229,6 @@ fn ctl_memory<F: Field>() -> CrossTableLookup<F> {
let memory_looked = TableWithColumns::new(
Table::Memory,
memory_stark::ctl_data(),
vec![],
Some(memory_stark::ctl_filter()),
);
CrossTableLookup::new(all_lookers, memory_looked)

View File

@@ -96,7 +96,6 @@ pub fn ctl_arithmetic_rows<F: Field>() -> TableWithColumns<F> {
TableWithColumns::new(
Table::Arithmetic,
cpu_arith_data_link(&COMBINED_OPS, &REGISTER_MAP),
vec![],
filter_column,
)
}

View File

@@ -103,7 +103,6 @@ pub fn ctl_arithmetic_base_rows<F: Field>() -> TableWithColumns<F> {
TableWithColumns::new(
Table::Cpu,
columns,
vec![],
Some(Column::sum([
COL_MAP.op.binary_op,
COL_MAP.op.fp254_op,
@@ -121,12 +120,7 @@ pub fn ctl_arithmetic_shift_rows<F: Field>() -> TableWithColumns<F> {
// (also `ops` is used as the operation filter). The list of
// operations includes binary operations which will simply ignore
// the third input.
TableWithColumns::new(
Table::Cpu,
columns,
vec![],
Some(Column::single(COL_MAP.op.shift)),
)
TableWithColumns::new(Table::Cpu, columns, Some(Column::single(COL_MAP.op.shift)))
}
pub fn ctl_data_byte_packing<F: Field>() -> Vec<Column<F>> {

View File

@@ -25,6 +25,7 @@ use crate::vars::{StarkEvaluationTargets, StarkEvaluationVars};
#[derive(Clone, Debug)]
pub struct Column<F: Field> {
linear_combination: Vec<(usize, F)>,
next_row_linear_combination: Vec<(usize, F)>,
constant: F,
}
@@ -32,6 +33,7 @@ impl<F: Field> Column<F> {
pub fn single(c: usize) -> Self {
Self {
linear_combination: vec![(c, F::ONE)],
next_row_linear_combination: vec![],
constant: F::ZERO,
}
}
@@ -42,9 +44,24 @@ impl<F: Field> Column<F> {
cs.into_iter().map(|c| Self::single(*c.borrow()))
}
pub fn single_next_row(c: usize) -> Self {
Self {
linear_combination: vec![],
next_row_linear_combination: vec![(c, F::ONE)],
constant: F::ZERO,
}
}
pub fn singles_next_row<I: IntoIterator<Item = impl Borrow<usize>>>(
cs: I,
) -> impl Iterator<Item = Self> {
cs.into_iter().map(|c| Self::single_next_row(*c.borrow()))
}
pub fn constant(constant: F) -> Self {
Self {
linear_combination: vec![],
next_row_linear_combination: vec![],
constant,
}
}
@@ -70,6 +87,34 @@ impl<F: Field> Column<F> {
);
Self {
linear_combination: v,
next_row_linear_combination: vec![],
constant,
}
}
pub fn linear_combination_and_next_row_with_constant<I: IntoIterator<Item = (usize, F)>>(
iter: I,
next_row_iter: I,
constant: F,
) -> Self {
let v = iter.into_iter().collect::<Vec<_>>();
let next_row_v = next_row_iter.into_iter().collect::<Vec<_>>();
assert!(!v.is_empty() || !next_row_v.is_empty());
debug_assert_eq!(
v.iter().map(|(c, _)| c).unique().count(),
v.len(),
"Duplicate columns."
);
debug_assert_eq!(
next_row_v.iter().map(|(c, _)| c).unique().count(),
next_row_v.len(),
"Duplicate columns."
);
Self {
linear_combination: v,
next_row_linear_combination: next_row_v,
constant,
}
}
@@ -106,13 +151,43 @@ impl<F: Field> Column<F> {
+ FE::from_basefield(self.constant)
}
pub fn eval_with_next<FE, P, const D: usize>(&self, v: &[P], next_v: &[P]) -> P
where
FE: FieldExtension<D, BaseField = F>,
P: PackedField<Scalar = FE>,
{
self.linear_combination
.iter()
.map(|&(c, f)| v[c] * FE::from_basefield(f))
.sum::<P>()
+ self
.next_row_linear_combination
.iter()
.map(|&(c, f)| next_v[c] * FE::from_basefield(f))
.sum::<P>()
+ FE::from_basefield(self.constant)
}
/// Evaluate on a row of a table given in column-major form.
pub fn eval_table(&self, table: &[PolynomialValues<F>], row: usize) -> F {
self.linear_combination
let mut res = self
.linear_combination
.iter()
.map(|&(c, f)| table[c].values[row] * f)
.sum::<F>()
+ self.constant
+ self.constant;
// If we access the next row at the last row, for sanity, we consider the next row's values to be 0.
// If CTLs are correctly written, the filter should be 0 in that case anyway.
if !self.next_row_linear_combination.is_empty() && row < table[0].values.len() - 1 {
res += self
.next_row_linear_combination
.iter()
.map(|&(c, f)| table[c].values[row + 1] * f)
.sum::<F>();
}
res
}
pub fn eval_circuit<const D: usize>(
@@ -136,27 +211,50 @@ impl<F: Field> Column<F> {
let constant = builder.constant_extension(F::Extension::from_basefield(self.constant));
builder.inner_product_extension(F::ONE, constant, pairs)
}
pub fn eval_with_next_circuit<const D: usize>(
&self,
builder: &mut CircuitBuilder<F, D>,
v: &[ExtensionTarget<D>],
next_v: &[ExtensionTarget<D>],
) -> ExtensionTarget<D>
where
F: RichField + Extendable<D>,
{
let mut pairs = self
.linear_combination
.iter()
.map(|&(c, f)| {
(
v[c],
builder.constant_extension(F::Extension::from_basefield(f)),
)
})
.collect::<Vec<_>>();
let next_row_pairs = self.next_row_linear_combination.iter().map(|&(c, f)| {
(
next_v[c],
builder.constant_extension(F::Extension::from_basefield(f)),
)
});
pairs.extend(next_row_pairs);
let constant = builder.constant_extension(F::Extension::from_basefield(self.constant));
builder.inner_product_extension(F::ONE, constant, pairs)
}
}
#[derive(Clone, Debug)]
pub struct TableWithColumns<F: Field> {
table: Table,
local_columns: Vec<Column<F>>,
next_columns: Vec<Column<F>>,
columns: Vec<Column<F>>,
pub(crate) filter_column: Option<Column<F>>,
}
impl<F: Field> TableWithColumns<F> {
pub fn new(
table: Table,
local_columns: Vec<Column<F>>,
next_columns: Vec<Column<F>>,
filter_column: Option<Column<F>>,
) -> Self {
pub fn new(table: Table, columns: Vec<Column<F>>, filter_column: Option<Column<F>>) -> Self {
Self {
table,
local_columns,
next_columns,
columns,
filter_column,
}
}
@@ -175,8 +273,7 @@ impl<F: Field> CrossTableLookup<F> {
) -> Self {
assert!(looking_tables
.iter()
.all(|twc| (twc.local_columns.len() + twc.next_columns.len())
== (looked_table.local_columns.len() + looked_table.next_columns.len())));
.all(|twc| twc.columns.len() == looked_table.columns.len()));
Self {
looking_tables,
looked_table,
@@ -204,8 +301,7 @@ pub struct CtlData<F: Field> {
pub(crate) struct CtlZData<F: Field> {
pub(crate) z: PolynomialValues<F>,
pub(crate) challenge: GrandProductChallenge<F>,
pub(crate) local_columns: Vec<Column<F>>,
pub(crate) next_columns: Vec<Column<F>>,
pub(crate) columns: Vec<Column<F>>,
pub(crate) filter_column: Option<Column<F>>,
}
@@ -242,16 +338,14 @@ pub(crate) fn cross_table_lookup_data<F: RichField, const D: usize>(
let zs_looking = looking_tables.iter().map(|table| {
partial_products(
&trace_poly_values[table.table as usize],
&table.local_columns,
&table.next_columns,
&table.columns,
&table.filter_column,
challenge,
)
});
let z_looked = partial_products(
&trace_poly_values[looked_table.table as usize],
&looked_table.local_columns,
&looked_table.next_columns,
&looked_table.columns,
&looked_table.filter_column,
challenge,
);
@@ -261,8 +355,7 @@ pub(crate) fn cross_table_lookup_data<F: RichField, const D: usize>(
.push(CtlZData {
z,
challenge,
local_columns: table.local_columns.clone(),
next_columns: table.next_columns.clone(),
columns: table.columns.clone(),
filter_column: table.filter_column.clone(),
});
}
@@ -271,8 +364,7 @@ pub(crate) fn cross_table_lookup_data<F: RichField, const D: usize>(
.push(CtlZData {
z: z_looked,
challenge,
local_columns: looked_table.local_columns.clone(),
next_columns: looked_table.next_columns.clone(),
columns: looked_table.columns.clone(),
filter_column: looked_table.filter_column.clone(),
});
}
@@ -282,8 +374,7 @@ pub(crate) fn cross_table_lookup_data<F: RichField, const D: usize>(
fn partial_products<F: Field>(
trace: &[PolynomialValues<F>],
local_columns: &[Column<F>],
next_columns: &[Column<F>],
columns: &[Column<F>],
filter_column: &Option<Column<F>>,
challenge: GrandProductChallenge<F>,
) -> PolynomialValues<F> {
@@ -297,16 +388,9 @@ fn partial_products<F: Field>(
F::ONE
};
if filter.is_one() {
let evals = local_columns
let evals = columns
.iter()
.map(|c| c.eval_table(trace, i))
.chain(
next_columns
.iter()
// The modulo is there to avoid out of bounds. For any CTL using next row
// values, we expect the filter to be 0 at the last row.
.map(|c| c.eval_table(trace, (i + 1) % degree)),
)
.collect::<Vec<_>>();
partial_prod *= challenge.combine(evals.iter());
} else {
@@ -328,8 +412,7 @@ where
pub(crate) local_z: P,
pub(crate) next_z: P,
pub(crate) challenges: GrandProductChallenge<F>,
pub(crate) local_columns: &'a [Column<F>],
pub(crate) next_columns: &'a [Column<F>],
pub(crate) columns: &'a [Column<F>],
pub(crate) filter_column: &'a Option<Column<F>>,
}
@@ -366,8 +449,7 @@ impl<'a, F: RichField + Extendable<D>, const D: usize>
local_z: *looking_z,
next_z: *looking_z_next,
challenges,
local_columns: &table.local_columns,
next_columns: &table.next_columns,
columns: &table.columns,
filter_column: &table.filter_column,
});
}
@@ -377,8 +459,7 @@ impl<'a, F: RichField + Extendable<D>, const D: usize>
local_z: *looked_z,
next_z: *looked_z_next,
challenges,
local_columns: &looked_table.local_columns,
next_columns: &looked_table.next_columns,
columns: &looked_table.columns,
filter_column: &looked_table.filter_column,
});
}
@@ -406,16 +487,14 @@ pub(crate) fn eval_cross_table_lookup_checks<F, FE, P, S, const D: usize, const
local_z,
next_z,
challenges,
local_columns,
next_columns,
columns,
filter_column,
} = lookup_vars;
let mut evals = local_columns
let evals = columns
.iter()
.map(|c| c.eval(vars.local_values))
.map(|c| c.eval_with_next(vars.local_values, vars.next_values))
.collect::<Vec<_>>();
evals.extend(next_columns.iter().map(|c| c.eval(vars.next_values)));
let combined = challenges.combine(evals.iter());
let local_filter = if let Some(column) = filter_column {
column.eval(vars.local_values)
@@ -436,8 +515,7 @@ pub struct CtlCheckVarsTarget<'a, F: Field, const D: usize> {
pub(crate) local_z: ExtensionTarget<D>,
pub(crate) next_z: ExtensionTarget<D>,
pub(crate) challenges: GrandProductChallenge<Target>,
pub(crate) local_columns: &'a [Column<F>],
pub(crate) next_columns: &'a [Column<F>],
pub(crate) columns: &'a [Column<F>],
pub(crate) filter_column: &'a Option<Column<F>>,
}
@@ -473,8 +551,7 @@ impl<'a, F: Field, const D: usize> CtlCheckVarsTarget<'a, F, D> {
local_z: *looking_z,
next_z: *looking_z_next,
challenges,
local_columns: &looking_table.local_columns,
next_columns: &looking_table.next_columns,
columns: &looking_table.columns,
filter_column: &looking_table.filter_column,
});
}
@@ -486,8 +563,7 @@ impl<'a, F: Field, const D: usize> CtlCheckVarsTarget<'a, F, D> {
local_z: *looked_z,
next_z: *looked_z_next,
challenges,
local_columns: &looked_table.local_columns,
next_columns: &looked_table.next_columns,
columns: &looked_table.columns,
filter_column: &looked_table.filter_column,
});
}
@@ -513,8 +589,7 @@ pub(crate) fn eval_cross_table_lookup_checks_circuit<
local_z,
next_z,
challenges,
local_columns,
next_columns,
columns,
filter_column,
} = lookup_vars;
@@ -534,15 +609,10 @@ pub(crate) fn eval_cross_table_lookup_checks_circuit<
builder.mul_add_extension(filter, x, tmp) // filter * x + 1 - filter
}
let mut evals = local_columns
let evals = columns
.iter()
.map(|c| c.eval_circuit(builder, vars.local_values))
.map(|c| c.eval_with_next_circuit(builder, vars.local_values, vars.next_values))
.collect::<Vec<_>>();
evals.extend(
next_columns
.iter()
.map(|c| c.eval_circuit(builder, vars.next_values)),
);
let combined = challenges.combine_circuit(builder, &evals);
let select = select(builder, local_filter, combined);
@@ -692,15 +762,9 @@ pub(crate) mod testutils {
};
if filter.is_one() {
let row = table
.local_columns
.columns
.iter()
.map(|c| c.eval_table(trace, i))
.chain(
table
.next_columns
.iter()
.map(|c| c.eval_table(trace, (i + 1) % trace[0].len())),
)
.collect::<Vec<_>>();
multiset.entry(row).or_default().push((table.table, i));
} else {
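One behavioral detail of the new `eval_table` above: when a column references the next row, the next-row contribution is simply dropped on the last row (rather than wrapping around with a modulo as the removed code did), on the assumption that a well-formed CTL has a zero filter there. Below is a self-contained toy sketch of that rule, using plain `u64` values and hypothetical names rather than the crate's types.

/// Toy stand-in for one column of a column-major trace table.
struct Poly {
    values: Vec<u64>,
}

/// Hypothetical mirror of the next-row rule in `Column::eval_table`:
/// sum coefficient-weighted values from the current row, and from the
/// next row only when one exists; on the last row that part is 0.
fn eval_table_toy(
    local: &[(usize, u64)],
    next: &[(usize, u64)],
    constant: u64,
    table: &[Poly],
    row: usize,
) -> u64 {
    let mut res: u64 = local
        .iter()
        .map(|&(c, f)| table[c].values[row] * f)
        .sum::<u64>()
        + constant;
    // `table[0].values.len()` is the number of rows; skip the next-row part on the last one.
    if !next.is_empty() && row < table[0].values.len() - 1 {
        res += next
            .iter()
            .map(|&(c, f)| table[c].values[row + 1] * f)
            .sum::<u64>();
    }
    res
}

fn main() {
    // Two columns, three rows each.
    let table = vec![
        Poly { values: vec![1, 2, 3] },
        Poly { values: vec![10, 20, 30] },
    ];
    // value = 1 * col0[row] + 1 * col1[row + 1]
    assert_eq!(eval_table_toy(&[(0, 1)], &[(1, 1)], 0, &table, 0), 1 + 20);
    // On the last row the next-row term is treated as 0.
    assert_eq!(eval_table_toy(&[(0, 1)], &[(1, 1)], 0, &table, 2), 3);
}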

View File

@@ -763,8 +763,7 @@ mod tests {
beta: F::ZERO,
gamma: F::ZERO,
},
local_columns: vec![],
next_columns: vec![],
columns: vec![],
filter_column: None,
};
let ctl_data = CtlData {

View File

@@ -589,8 +589,7 @@ where
next_z: permutation_ctl_zs_commitment.get_lde_values_packed(i_next_start, step)
[num_permutation_zs + i],
challenges: zs_columns.challenge,
local_columns: &zs_columns.local_columns,
next_columns: &zs_columns.next_columns,
columns: &zs_columns.columns,
filter_column: &zs_columns.filter_column,
})
.collect::<Vec<_>>();
@@ -708,8 +707,7 @@ fn check_constraints<'a, F, C, S, const D: usize>(
local_z: permutation_ctl_zs_subgroup_evals[i][num_permutation_zs + iii],
next_z: permutation_ctl_zs_subgroup_evals[i_next][num_permutation_zs + iii],
challenges: zs_columns.challenge,
local_columns: &zs_columns.local_columns,
next_columns: &zs_columns.next_columns,
columns: &zs_columns.columns,
filter_column: &zs_columns.filter_column,
})
.collect::<Vec<_>>();