Implement multi-table CTLs

This commit is contained in:
wborgeaud 2022-06-01 18:53:19 +02:00
parent e8fc5b5752
commit 2e3a738bc5
3 changed files with 87 additions and 51 deletions

View File

@ -107,8 +107,8 @@ mod tests {
keccak_trace[5].values[..].copy_from_slice(&vs1);
let cross_table_lookups = vec![CrossTableLookup {
looking_table: Table::Cpu,
looking_columns: vec![2, 4],
looking_tables: vec![Table::Cpu],
looking_columns: vec![vec![2, 4]],
looked_table: Table::Keccak,
looked_columns: vec![3, 5],
default: vec![F::ONE; 2],

View File

@ -1,4 +1,5 @@
use anyhow::{ensure, Result};
use itertools::izip;
use plonky2::field::extension_field::{Extendable, FieldExtension};
use plonky2::field::field_types::Field;
use plonky2::field::packed_field::PackedField;
@ -22,8 +23,8 @@ use crate::vars::{StarkEvaluationTargets, StarkEvaluationVars};
#[derive(Clone)]
pub struct CrossTableLookup<F: Field> {
pub looking_table: Table,
pub looking_columns: Vec<usize>,
pub looking_tables: Vec<Table>,
pub looking_columns: Vec<Vec<usize>>,
pub looked_table: Table,
pub looked_columns: Vec<usize>,
pub default: Vec<F>,
@ -31,15 +32,18 @@ pub struct CrossTableLookup<F: Field> {
impl<F: Field> CrossTableLookup<F> {
pub fn new(
looking_table: Table,
looking_columns: Vec<usize>,
looking_tables: Vec<Table>,
looking_columns: Vec<Vec<usize>>,
looked_table: Table,
looked_columns: Vec<usize>,
default: Vec<F>,
) -> Self {
assert_eq!(looking_columns.len(), looked_columns.len());
assert_eq!(looking_tables.len(), looking_columns.len());
assert!(looking_columns
.iter()
.all(|cols| cols.len() == looked_columns.len()));
Self {
looking_table,
looking_tables,
looking_columns,
looked_table,
looked_columns,
@ -87,7 +91,7 @@ pub fn cross_table_lookup_data<F: RichField, C: GenericConfig<D, F = F>, const D
let challenges = get_grand_product_challenge_set(challenger, config.num_challenges);
let mut ctl_data_per_table = vec![CtlData::new(challenges.clone()); trace_poly_values.len()];
for CrossTableLookup {
looking_table,
looking_tables,
looking_columns,
looked_table,
looked_columns,
@ -95,11 +99,13 @@ pub fn cross_table_lookup_data<F: RichField, C: GenericConfig<D, F = F>, const D
} in cross_table_lookups
{
for &challenge in &challenges.challenges {
let z_looking = partial_products(
&trace_poly_values[*looking_table as usize],
looking_columns,
challenge,
);
let zs_looking = looking_tables
.iter()
.zip(looking_columns)
.map(|(table, columns)| {
partial_products(&trace_poly_values[*table as usize], columns, challenge)
})
.collect::<Vec<_>>();
let z_looked = partial_products(
&trace_poly_values[*looked_table as usize],
looked_columns,
@ -107,17 +113,25 @@ pub fn cross_table_lookup_data<F: RichField, C: GenericConfig<D, F = F>, const D
);
debug_assert_eq!(
*z_looking.values.last().unwrap(),
zs_looking
.iter()
.map(|z| *z.values.last().unwrap())
.product::<F>(),
*z_looked.values.last().unwrap()
* challenge.combine(default).exp_u64(
trace_poly_values[*looking_table as usize][0].len() as u64
looking_tables
.iter()
.map(|table| trace_poly_values[*table as usize][0].len() as u64)
.sum::<u64>()
- trace_poly_values[*looked_table as usize][0].len() as u64
)
);
ctl_data_per_table[*looking_table as usize]
.zs_columns
.push((z_looking, looking_columns.clone()));
for (table, columns, z) in izip!(looking_tables, looking_columns, zs_looking) {
ctl_data_per_table[*table as usize]
.zs_columns
.push((z, columns.clone()));
}
ctl_data_per_table[*looked_table as usize]
.zs_columns
.push((z_looked, looked_columns.clone()));
@ -177,7 +191,7 @@ impl<'a, F: RichField + Extendable<D>, const D: usize>
let mut ctl_vars_per_table = vec![vec![]; proofs.len()];
for CrossTableLookup {
looking_table,
looking_tables,
looking_columns,
looked_table,
looked_columns,
@ -185,13 +199,15 @@ impl<'a, F: RichField + Extendable<D>, const D: usize>
} in cross_table_lookups
{
for &challenges in &ctl_challenges.challenges {
let (looking_z, looking_z_next) = ctl_zs[*looking_table as usize].next().unwrap();
ctl_vars_per_table[*looking_table as usize].push(Self {
local_z: *looking_z,
next_z: *looking_z_next,
challenges,
columns: looking_columns,
});
for (table, columns) in looking_tables.iter().zip(looking_columns) {
let (looking_z, looking_z_next) = ctl_zs[*table as usize].next().unwrap();
ctl_vars_per_table[*table as usize].push(Self {
local_z: *looking_z,
next_z: *looking_z_next,
challenges,
columns,
});
}
let (looked_z, looked_z_next) = ctl_zs[*looked_table as usize].next().unwrap();
ctl_vars_per_table[*looked_table as usize].push(Self {
@ -262,7 +278,7 @@ impl<'a, const D: usize> CtlCheckVarsTarget<'a, D> {
let mut ctl_vars_per_table = vec![vec![]; proofs.len()];
for CrossTableLookup {
looking_table,
looking_tables,
looking_columns,
looked_table,
looked_columns,
@ -270,13 +286,15 @@ impl<'a, const D: usize> CtlCheckVarsTarget<'a, D> {
} in cross_table_lookups
{
for &challenges in &ctl_challenges.challenges {
let (looking_z, looking_z_next) = ctl_zs[*looking_table as usize].next().unwrap();
ctl_vars_per_table[*looking_table as usize].push(Self {
local_z: *looking_z,
next_z: *looking_z_next,
challenges,
columns: looking_columns,
});
for (table, columns) in looking_tables.iter().zip(looking_columns) {
let (looking_z, looking_z_next) = ctl_zs[*table as usize].next().unwrap();
ctl_vars_per_table[*table as usize].push(Self {
local_z: *looking_z,
next_z: *looking_z_next,
challenges,
columns,
});
}
let (looked_z, looked_z_next) = ctl_zs[*looked_table as usize].next().unwrap();
ctl_vars_per_table[*looked_table as usize].push(Self {
@ -354,22 +372,29 @@ pub(crate) fn verify_cross_table_lookups<
for (
i,
CrossTableLookup {
looking_table,
looking_tables,
looked_table,
default,
..
},
) in cross_table_lookups.into_iter().enumerate()
{
let looking_degree = 1 << degrees_bits[looking_table as usize];
let looking_degrees_sum = looking_tables
.iter()
.map(|&table| 1 << degrees_bits[table as usize])
.sum::<u64>();
let looked_degree = 1 << degrees_bits[looked_table as usize];
let looking_z = *ctl_zs_openings[looking_table as usize].next().unwrap();
let looking_zs_prod = looking_tables
.into_iter()
.map(|table| *ctl_zs_openings[table as usize].next().unwrap())
.product::<F>();
let looked_z = *ctl_zs_openings[looked_table as usize].next().unwrap();
let challenge = challenges.challenges[i % config.num_challenges];
let combined_default = challenge.combine(default.iter());
ensure!(
looking_z == looked_z * combined_default.exp_u64(looking_degree - looked_degree),
looking_zs_prod
== looked_z * combined_default.exp_u64(looking_degrees_sum - looked_degree),
"Cross-table lookup verification failed."
);
}
@ -399,16 +424,23 @@ pub(crate) fn verify_cross_table_lookups_circuit<
for (
i,
CrossTableLookup {
looking_table,
looking_tables,
looked_table,
default,
..
},
) in cross_table_lookups.into_iter().enumerate()
{
let looking_degree = 1 << degrees_bits[looking_table as usize];
let looking_degrees_sum = looking_tables
.iter()
.map(|&table| 1 << degrees_bits[table as usize])
.sum::<u64>();
let looked_degree = 1 << degrees_bits[looked_table as usize];
let looking_z = *ctl_zs_openings[looking_table as usize].next().unwrap();
let looking_zs_prod = builder.mul_many(
looking_tables
.into_iter()
.map(|table| *ctl_zs_openings[table as usize].next().unwrap()),
);
let looked_z = *ctl_zs_openings[looked_table as usize].next().unwrap();
let challenge = challenges.challenges[i % inner_config.num_challenges];
let default = default
@ -417,8 +449,8 @@ pub(crate) fn verify_cross_table_lookups_circuit<
.collect::<Vec<_>>();
let combined_default = challenge.combine_base_circuit(builder, &default);
let pad = builder.exp_u64(combined_default, looking_degree - looked_degree);
let pad = builder.exp_u64(combined_default, looking_degrees_sum - looked_degree);
let padded_looked_z = builder.mul(looked_z, pad);
builder.connect(looking_z, padded_looked_z);
builder.connect(looking_zs_prod, padded_looked_z);
}
}

View File

@ -25,7 +25,7 @@ impl<F: RichField + Extendable<D>, const D: usize> CircuitBuilder<F, D> {
/// Computes `x^3`.
pub fn cube(&mut self, x: Target) -> Target {
self.mul_many(&[x, x, x])
self.mul_many([x, x, x])
}
/// Computes `const_0 * multiplicand_0 * multiplicand_1 + const_1 * addend`.
@ -206,12 +206,16 @@ impl<F: RichField + Extendable<D>, const D: usize> CircuitBuilder<F, D> {
}
/// Multiply `n` `Target`s.
pub fn mul_many(&mut self, terms: &[Target]) -> Target {
terms
.iter()
.copied()
.reduce(|acc, t| self.mul(acc, t))
.unwrap_or_else(|| self.one())
pub fn mul_many<T>(&mut self, terms: impl IntoIterator<Item = T>) -> Target
where
T: Borrow<Target>,
{
let mut iter = terms.into_iter();
if let Some(first) = iter.next() {
iter.fold(*first.borrow(), |acc, t| self.mul(acc, *t.borrow()))
} else {
self.one()
}
}
/// Exponentiate `base` to the power of `2^power_log`.