Add constraints

This commit is contained in:
wborgeaud 2022-06-07 23:09:09 +02:00
parent 2ff738633b
commit 05d2c69eb0
3 changed files with 92 additions and 34 deletions

View File

@ -91,7 +91,7 @@ mod tests {
let mut cpu_trace_rows = vec![];
for i in 0..cpu_rows {
let mut cpu_trace_row = [F::ZERO; CpuStark::<F, D>::COLUMNS];
cpu_trace_row[cpu::columns::IS_CPU_CYCLE] = F::ONE;
cpu_trace_row[cpu::columns::IS_CPU_CYCLE] = F::ZERO;
cpu_trace_row[cpu::columns::OPCODE] = F::from_canonical_usize(i);
cpu_stark.generate(&mut cpu_trace_row);
cpu_trace_rows.push(cpu_trace_row);

View File

@ -76,8 +76,9 @@ impl<F: Field> CrossTableLookup<F> {
pub struct CtlData<F: Field> {
/// Challenges used in the argument.
pub(crate) challenges: GrandProductChallengeSet<F>,
/// Vector of `(Z, columns)` where `Z` is a Z-polynomial for a lookup on columns `columns`.
pub zs_columns: Vec<(PolynomialValues<F>, Vec<usize>)>,
/// Vector of `(Z, columns, filter_columns)` where `Z` is a Z-polynomial for a lookup
/// on columns `columns` with filter columns `filter_columns`.
pub zs_columns: Vec<(PolynomialValues<F>, Vec<usize>, Vec<usize>)>,
}
impl<F: Field> CtlData<F> {
@ -97,7 +98,7 @@ impl<F: Field> CtlData<F> {
}
pub fn z_polys(&self) -> Vec<PolynomialValues<F>> {
self.zs_columns.iter().map(|(p, _)| p.clone()).collect()
self.zs_columns.iter().map(|(p, _, _)| p.clone()).collect()
}
}
@ -155,13 +156,19 @@ pub fn cross_table_lookup_data<F: RichField, C: GenericConfig<D, F = F>, const D
);
for (table, z) in looking_tables.iter().zip(zs_looking) {
ctl_data_per_table[table.table as usize]
.zs_columns
.push((z, table.columns.clone()));
ctl_data_per_table[table.table as usize].zs_columns.push((
z,
table.columns.clone(),
table.filter_columns.clone(),
));
}
ctl_data_per_table[looked_table.table as usize]
.zs_columns
.push((z_looked, looked_table.columns.clone()));
.push((
z_looked,
looked_table.columns.clone(),
looked_table.filter_columns.clone(),
));
}
}
ctl_data_per_table
@ -178,14 +185,16 @@ fn partial_products<F: Field>(
let mut res = Vec::with_capacity(degree);
for i in 0..degree {
let filter = if filter_columns.is_empty() {
1
F::ONE
} else {
filter_columns.iter().sum()
filter_columns.iter().map(|&j| trace[j].values[i]).sum()
};
partial_prod *= match filter {
0 => F::ONE,
1 => challenge.combine(columns.iter().map(|&j| &trace[j].values[i])),
_ => panic!("Non-binary filter?"),
partial_prod *= if filter.is_zero() {
F::ONE
} else if filter.is_one() {
challenge.combine(columns.iter().map(|&j| &trace[j].values[i]))
} else {
panic!("Non-binary filter?")
};
res.push(partial_prod);
}
@ -203,6 +212,7 @@ where
pub(crate) next_z: P,
pub(crate) challenges: GrandProductChallenge<F>,
pub(crate) columns: &'a [usize],
pub(crate) filter_columns: &'a [usize],
}
impl<'a, F: RichField + Extendable<D>, const D: usize>
@ -241,6 +251,7 @@ impl<'a, F: RichField + Extendable<D>, const D: usize>
next_z: *looking_z_next,
challenges,
columns: &table.columns,
filter_columns: &table.filter_columns,
});
}
@ -250,6 +261,7 @@ impl<'a, F: RichField + Extendable<D>, const D: usize>
next_z: *looked_z_next,
challenges,
columns: &looked_table.columns,
filter_columns: &looked_table.filter_columns,
});
}
}
@ -274,13 +286,26 @@ pub(crate) fn eval_cross_table_lookup_checks<F, FE, P, C, S, const D: usize, con
next_z,
challenges,
columns,
filter_columns,
} = lookup_vars;
let combine = |v: &[P]| -> P { challenges.combine(columns.iter().map(|&i| &v[i])) };
let filter = |v: &[P]| -> P {
if filter_columns.is_empty() {
P::ONES
} else {
filter_columns.iter().map(|&i| v[i]).sum()
}
};
let local_filter = filter(vars.local_values);
let next_filter = filter(vars.next_values);
let select = |filter, x| filter * x + P::ONES - filter;
// Check value of `Z(1)`
consumer.constraint_first_row(*local_z - combine(vars.local_values));
consumer.constraint_first_row(*local_z - select(local_filter, combine(vars.local_values)));
// Check `Z(gw) = combination * Z(w)`
consumer.constraint_transition(*next_z - *local_z * combine(vars.next_values));
consumer.constraint_transition(
*next_z - *local_z * select(next_filter, combine(vars.next_values)),
);
}
}
@ -290,6 +315,7 @@ pub struct CtlCheckVarsTarget<'a, const D: usize> {
pub(crate) next_z: ExtensionTarget<D>,
pub(crate) challenges: GrandProductChallenge<Target>,
pub(crate) columns: &'a [usize],
pub(crate) filter_columns: &'a [usize],
}
impl<'a, const D: usize> CtlCheckVarsTarget<'a, D> {
@ -326,6 +352,7 @@ impl<'a, const D: usize> CtlCheckVarsTarget<'a, D> {
next_z: *looking_z_next,
challenges,
columns: &table.columns,
filter_columns: &table.filter_columns,
});
}
@ -335,6 +362,7 @@ impl<'a, const D: usize> CtlCheckVarsTarget<'a, D> {
next_z: *looked_z_next,
challenges,
columns: &looked_table.columns,
filter_columns: &looked_table.filter_columns,
});
}
}
@ -358,8 +386,30 @@ pub(crate) fn eval_cross_table_lookup_checks_circuit<
next_z,
challenges,
columns,
filter_columns,
} = lookup_vars;
let one = builder.one_extension();
let local_filter = if filter_columns.is_empty() {
one
} else {
builder.add_many_extension(filter_columns.iter().map(|&i| vars.local_values[i]))
};
let next_filter = if filter_columns.is_empty() {
one
} else {
builder.add_many_extension(filter_columns.iter().map(|&i| vars.next_values[i]))
};
fn select<F: RichField + Extendable<D>, const D: usize>(
builder: &mut CircuitBuilder<F, D>,
filter: ExtensionTarget<D>,
x: ExtensionTarget<D>,
) -> ExtensionTarget<D> {
let one = builder.one_extension();
let tmp = builder.sub_extension(one, filter);
builder.mul_add_extension(filter, x, tmp) // filter * x + 1 - filter
}
// Check value of `Z(1)`
let combined_local = challenges.combine_circuit(
builder,
@ -368,7 +418,8 @@ pub(crate) fn eval_cross_table_lookup_checks_circuit<
.map(|&i| vars.local_values[i])
.collect::<Vec<_>>(),
);
let first_row = builder.sub_extension(*local_z, combined_local);
let selected_local = select(builder, local_filter, combined_local);
let first_row = builder.sub_extension(*local_z, selected_local);
consumer.constraint_first_row(builder, first_row);
// Check `Z(gw) = combination * Z(w)`
let combined_next = challenges.combine_circuit(
@ -378,7 +429,8 @@ pub(crate) fn eval_cross_table_lookup_checks_circuit<
.map(|&i| vars.next_values[i])
.collect::<Vec<_>>(),
);
let mut transition = builder.mul_extension(*local_z, combined_next);
let selected_next = select(builder, next_filter, combined_next);
let mut transition = builder.mul_extension(*local_z, selected_next);
transition = builder.sub_extension(*next_z, transition);
consumer.constraint_transition(builder, transition);
}

View File

@ -391,14 +391,17 @@ where
.zs_columns
.iter()
.enumerate()
.map(|(i, (_, columns))| CtlCheckVars::<F, F, P, 1> {
local_z: permutation_ctl_zs_commitment.get_lde_values_packed(i_start, step)
[num_permutation_zs + i],
next_z: permutation_ctl_zs_commitment.get_lde_values_packed(i_next_start, step)
[num_permutation_zs + i],
challenges: ctl_data.challenges.challenges[i % config.num_challenges],
columns,
})
.map(
|(i, (_, columns, filter_columns))| CtlCheckVars::<F, F, P, 1> {
local_z: permutation_ctl_zs_commitment.get_lde_values_packed(i_start, step)
[num_permutation_zs + i],
next_z: permutation_ctl_zs_commitment
.get_lde_values_packed(i_next_start, step)[num_permutation_zs + i],
challenges: ctl_data.challenges.challenges[i % config.num_challenges],
columns,
filter_columns,
},
)
.collect::<Vec<_>>();
eval_vanishing_poly::<F, F, P, C, S, D, 1>(
stark,
@ -506,14 +509,17 @@ fn check_constraints<'a, F, C, S, const D: usize>(
.zs_columns
.iter()
.enumerate()
.map(|(iii, (_, columns))| CtlCheckVars::<F, F, F, 1> {
local_z: get_comm_values(permutation_ctl_zs_commitment, i)
[num_permutation_zs + iii],
next_z: get_comm_values(permutation_ctl_zs_commitment, i_next)
[num_permutation_zs + iii],
challenges: ctl_data.challenges.challenges[iii % config.num_challenges],
columns,
})
.map(
|(iii, (_, columns, filter_columns))| CtlCheckVars::<F, F, F, 1> {
local_z: get_comm_values(permutation_ctl_zs_commitment, i)
[num_permutation_zs + iii],
next_z: get_comm_values(permutation_ctl_zs_commitment, i_next)
[num_permutation_zs + iii],
challenges: ctl_data.challenges.challenges[iii % config.num_challenges],
columns,
filter_columns,
},
)
.collect::<Vec<_>>();
eval_vanishing_poly::<F, F, F, C, S, D, 1>(
stark,