Add distinction between (non-)wrapping constraints

wborgeaud 2022-02-02 11:23:03 +01:00
parent 1e04f4f5a4
commit bff763e3e7
6 changed files with 61 additions and 44 deletions
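This commit makes the filter applied by each `ConstraintConsumer` method explicit in its name: the `one*` helpers become `constraint*`, a `z_last` field holding the evaluation of `X - g^(n-1)` is added, and plain `constraint` now multiplies by it, so such a constraint is not enforced on the last trace row, i.e. on the transition that wraps from the last row back to the first. The new `constraint_wrapping` applies no extra filter and is enforced on every row, including that wrap-around. As a sketch of the intent (writing g for the generator of the order-n trace subgroup and L_0, L_{n-1} for the Lagrange basis polynomials of the first and last row, matching the doc comments below), the value folded into the running sum for a constraint c(x) is:

\[
\begin{aligned}
\texttt{constraint}:\quad & c(x)\,(x - g^{n-1}) && \text{(vanishes on the last row)}\\
\texttt{constraint\_wrapping}:\quad & c(x) && \text{(all rows, including the wrap)}\\
\texttt{constraint\_first\_row}:\quad & c(x)\,L_0(x) && \text{(first row only)}\\
\texttt{constraint\_last\_row}:\quad & c(x)\,L_{n-1}(x) && \text{(last row only)}
\end{aligned}
\]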

@@ -14,6 +14,9 @@ pub struct ConstraintConsumer<P: PackedField> {
/// Running sums of constraints that have been emitted so far, scaled by powers of alpha.
constraint_accs: Vec<P>,
/// The evaluation of `X - g^(n-1)`.
z_last: P,
/// The evaluation of the Lagrange basis polynomial which is nonzero at the point associated
/// with the first trace row, and zero at other points in the subgroup.
lagrange_basis_first: P,
@@ -24,10 +27,16 @@ pub struct ConstraintConsumer<P: PackedField> {
}
impl<P: PackedField> ConstraintConsumer<P> {
pub fn new(alphas: Vec<P::Scalar>, lagrange_basis_first: P, lagrange_basis_last: P) -> Self {
pub fn new(
alphas: Vec<P::Scalar>,
z_last: P,
lagrange_basis_first: P,
lagrange_basis_last: P,
) -> Self {
Self {
constraint_accs: vec![P::ZEROS; alphas.len()],
alphas,
z_last,
lagrange_basis_first,
lagrange_basis_last,
}
@@ -41,31 +50,29 @@ impl<P: PackedField> ConstraintConsumer<P> {
.collect()
}
/// Add one constraint.
pub fn one(&mut self, constraint: P) {
/// Add one constraint valid on all rows except the last.
pub fn constraint(&mut self, constraint: P) {
self.constraint_wrapping(constraint * self.z_last);
}
/// Add one constraint on all rows.
pub fn constraint_wrapping(&mut self, constraint: P) {
for (&alpha, acc) in self.alphas.iter().zip(&mut self.constraint_accs) {
*acc *= alpha;
*acc += constraint;
}
}
/// Add a series of constraints.
pub fn many(&mut self, constraints: impl IntoIterator<Item = P>) {
constraints
.into_iter()
.for_each(|constraint| self.one(constraint));
}
/// Add one constraint, but first multiply it by a filter such that it will only apply to the
/// first row of the trace.
pub fn one_first_row(&mut self, constraint: P) {
self.one(constraint * self.lagrange_basis_first);
pub fn constraint_first_row(&mut self, constraint: P) {
self.constraint_wrapping(constraint * self.lagrange_basis_first);
}
/// Add one constraint, but first multiply it by a filter such that it will only apply to the
/// last row of the trace.
pub fn one_last_row(&mut self, constraint: P) {
self.one(constraint * self.lagrange_basis_last);
pub fn constraint_last_row(&mut self, constraint: P) {
self.constraint_wrapping(constraint * self.lagrange_basis_last);
}
}
@@ -76,6 +83,9 @@ pub struct RecursiveConstraintConsumer<F: RichField + Extendable<D>, const D: us
/// A running sum of constraints that have been emitted so far, scaled by powers of alpha.
constraint_acc: ExtensionTarget<D>,
/// The evaluation of `X - g^(n-1)`.
z_last: ExtensionTarget<D>,
/// The evaluation of the Lagrange basis polynomial which is nonzero at the point associated
/// with the first trace row, and zero at other points in the subgroup.
lagrange_basis_first: ExtensionTarget<D>,
@@ -88,42 +98,45 @@ pub struct RecursiveConstraintConsumer<F: RichField + Extendable<D>, const D: us
}
impl<F: RichField + Extendable<D>, const D: usize> RecursiveConstraintConsumer<F, D> {
/// Add one constraint.
pub fn one(&mut self, builder: &mut CircuitBuilder<F, D>, constraint: ExtensionTarget<D>) {
/// Add one constraint valid on all rows except the last.
pub fn constraint(
&mut self,
builder: &mut CircuitBuilder<F, D>,
constraint: ExtensionTarget<D>,
) {
self.constraint_acc =
builder.scalar_mul_add_extension(self.alpha, self.constraint_acc, constraint);
}
/// Add a series of constraints.
pub fn many(
/// Add one constraint valid on all rows.
pub fn constraint_wrapping(
&mut self,
builder: &mut CircuitBuilder<F, D>,
constraints: impl IntoIterator<Item = ExtensionTarget<D>>,
constraint: ExtensionTarget<D>,
) {
constraints
.into_iter()
.for_each(|constraint| self.one(builder, constraint));
let filtered_constraint = builder.mul_extension(constraint, self.z_last);
self.constraint(builder, filtered_constraint);
}
/// Add one constraint, but first multiply it by a filter such that it will only apply to the
/// first row of the trace.
pub fn one_first_row(
pub fn constraint_first_row(
&mut self,
builder: &mut CircuitBuilder<F, D>,
constraint: ExtensionTarget<D>,
) {
let filtered_constraint = builder.mul_extension(constraint, self.lagrange_basis_first);
self.one(builder, filtered_constraint);
self.constraint(builder, filtered_constraint);
}
/// Add one constraint, but first multiply it by a filter such that it will only apply to the
/// last row of the trace.
pub fn one_last_row(
pub fn constraint_last_row(
&mut self,
builder: &mut CircuitBuilder<F, D>,
constraint: ExtensionTarget<D>,
) {
let filtered_constraint = builder.mul_extension(constraint, self.lagrange_basis_last);
self.one(builder, filtered_constraint);
self.constraint(builder, filtered_constraint);
}
}
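Both consumers fold each filtered constraint into the running sum with the same recurrence, acc <- alpha * acc + c (`*acc *= alpha; *acc += constraint` on the packed side, `scalar_mul_add_extension` on the recursive side). After emitting constraints c_0, ..., c_{k-1} the accumulator therefore holds the random linear combination

\[
\mathrm{acc} \;=\; \sum_{i=0}^{k-1} \alpha^{\,k-1-i}\, c_i,
\]

which is what the prover below divides by the vanishing polynomial.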

@@ -60,14 +60,17 @@ impl<F: RichField + Extendable<D>, const D: usize> Stark<F, D> for FibonacciStar
P: PackedField<Scalar = FE>,
{
// Check public inputs.
yield_constr.one_first_row(vars.local_values[0] - vars.public_inputs[Self::PI_INDEX_X0]);
yield_constr.one_first_row(vars.local_values[1] - vars.public_inputs[Self::PI_INDEX_X1]);
yield_constr.one_last_row(vars.local_values[1] - vars.public_inputs[Self::PI_INDEX_RES]);
yield_constr
.constraint_first_row(vars.local_values[0] - vars.public_inputs[Self::PI_INDEX_X0]);
yield_constr
.constraint_first_row(vars.local_values[1] - vars.public_inputs[Self::PI_INDEX_X1]);
yield_constr
.constraint_last_row(vars.local_values[1] - vars.public_inputs[Self::PI_INDEX_RES]);
// x0 <- x1
yield_constr.one(vars.next_values[0] - vars.local_values[1]);
yield_constr.constraint(vars.next_values[0] - vars.local_values[1]);
// x1 <- x0 + x1
yield_constr.one(vars.next_values[1] - vars.local_values[0] - vars.local_values[1]);
yield_constr.constraint(vars.next_values[1] - vars.local_values[0] - vars.local_values[1]);
}
fn eval_ext_recursively(
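In this example the transition constraints go through the non-wrapping `constraint`: the trace is interpolated over a multiplicative subgroup, so `next_values` on the last row wrap around to the first row, and enforcing the Fibonacci recurrence there would tie the final state to the public-input row. The `x - g^(n-1)` filter makes the constraint vacuous at that point; for the second transition constraint (writing x_0, x_1 for `local_values[0..2]` and x_1' for `next_values[1]`):

\[
\bigl(x_1' - x_0 - x_1\bigr)\,(x - g^{n-1})\Big|_{x = g^{n-1}} = 0
\]

regardless of the wrapped-around `next_values`.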

@@ -197,6 +197,7 @@ where
// TODO: Set `P` to a genuine `PackedField` here.
let mut consumer = ConstraintConsumer::<F>::new(
alphas.clone(),
coset[i] - last,
lagrange_first.values[i],
lagrange_last.values[i],
);
@@ -214,9 +215,8 @@
// We divide the constraint evaluations by `Z_H(x) / (x - last)`, i.e., the vanishing
// polynomial of `H` without its last element.
let denominator_inv = z_h_on_coset.eval_inverse(i);
let z_last = coset[i] - last;
for eval in &mut constraints_evals {
*eval *= denominator_inv * z_last;
*eval *= denominator_inv;
}
constraints_evals
})
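This is the prover-side half of the bookkeeping: previously the accumulated value was divided by `Z_H(x) / (x - last)`, i.e. multiplied by `denominator_inv * z_last`; now the consumer has already attached the x - g^(n-1) factor to non-wrapping constraints, so the quotient is just the accumulator divided by the full vanishing polynomial. At a point x_i of the evaluation coset, with c the old accumulator and c~ the new one:

\[
\text{before: } q(x_i) = c(x_i)\,\frac{x_i - g^{n-1}}{Z_H(x_i)}, \qquad \text{after: } q(x_i) = \frac{\tilde{c}(x_i)}{Z_H(x_i)}.
\]

Wrapping constraints are therefore now required to vanish on the whole subgroup, last row included.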

@@ -72,24 +72,25 @@ where
};
let (l_1, l_last) = eval_l_1_and_l_last(degree_bits, challenges.stark_zeta);
let last = F::primitive_root_of_unity(degree_bits).inverse();
let z_last = challenges.stark_zeta - last.into();
let mut consumer = ConstraintConsumer::<F::Extension>::new(
challenges
.stark_alphas
.iter()
.map(|&alpha| F::Extension::from_basefield(alpha))
.collect::<Vec<_>>(),
z_last,
l_1,
l_last,
);
stark.eval_ext(vars, &mut consumer);
let acc = consumer.accumulators();
// Check each polynomial identity, of the form `vanishing(x) = Z_H(x) quotient(x) / (x - last)`, at zeta.
// Check each polynomial identity, of the form `vanishing(x) = Z_H(x) quotient(x)`, at zeta.
let quotient_polys_zeta = &proof.openings.quotient_polys;
let zeta_pow_deg = challenges.stark_zeta.exp_power_of_2(degree_bits);
let z_h_zeta = zeta_pow_deg - F::Extension::ONE;
let last = F::primitive_root_of_unity(degree_bits).inverse();
let z_last = challenges.stark_zeta - last.into();
// `quotient_polys_zeta` holds `num_challenges * quotient_degree_factor` evaluations.
// Each chunk of `quotient_degree_factor` holds the evaluations of `t_0(zeta),...,t_{quotient_degree_factor-1}(zeta)`
// where the "real" quotient polynomial is `t(X) = t_0(X) + t_1(X)*X^n + t_2(X)*X^{2n} + ...`.
@@ -99,7 +100,7 @@ where
.chunks(1 << config.fri_config.rate_bits)
.enumerate()
{
ensure!(acc[i] == z_h_zeta * reduce_with_powers(chunk, zeta_pow_deg) / z_last);
ensure!(acc[i] == z_h_zeta * reduce_with_powers(chunk, zeta_pow_deg));
}
let merkle_caps = &[proof.trace_cap, proof.quotient_polys_cap];
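Correspondingly, the verifier no longer divides out (zeta - g^(n-1)): the identity checked for each chunk changes from acc_i = Z_H(zeta) * t_i(zeta) / (zeta - g^(n-1)) to

\[
\mathrm{acc}_i \;=\; Z_H(\zeta)\, t_i(\zeta), \qquad Z_H(\zeta) = \zeta^{n} - 1,
\]

with Z_H(zeta) computed as `zeta_pow_deg - F::Extension::ONE` and t_i(zeta) reassembled from the chunk by `reduce_with_powers`.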

@@ -55,16 +55,16 @@ impl<F: RichField + Extendable<D>, const D: usize> SystemZero<F, D> {
let local_clock = vars.local_values[COL_CLOCK];
let next_clock = vars.next_values[COL_CLOCK];
let delta_clock = next_clock - local_clock;
yield_constr.one_first_row(local_clock);
yield_constr.one(delta_clock - FE::ONE);
yield_constr.constraint_first_row(local_clock);
yield_constr.constraint(delta_clock - FE::ONE);
// The 16-bit table must start with 0, end with 2^16 - 1, and increment by 0 or 1.
let local_range_16 = vars.local_values[COL_RANGE_16];
let next_range_16 = vars.next_values[COL_RANGE_16];
let delta_range_16 = next_range_16 - local_range_16;
yield_constr.one_first_row(local_range_16);
yield_constr.one_last_row(local_range_16 - FE::from_canonical_u64((1 << 16) - 1));
yield_constr.one(delta_range_16 * (delta_range_16 - FE::ONE));
yield_constr.constraint_first_row(local_range_16);
yield_constr.constraint_last_row(local_range_16 - FE::from_canonical_u64((1 << 16) - 1));
yield_constr.constraint(delta_range_16 * (delta_range_16 - FE::ONE));
todo!()
}
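These constraints show why the distinction matters: the clock and 16-bit range increment checks use the non-wrapping `constraint`, so nothing is asserted about the transition from the last row back to row 0, where the range column must drop from 2^16 - 1 to 0 and the delta is neither 0 nor 1. The boundary rows are pinned directly with `constraint_first_row` and `constraint_last_row` instead. On the last row the filtered increment check reads (writing Delta_16 for `delta_range_16`)

\[
\Delta_{16}\,(\Delta_{16} - 1)\,(x - g^{n-1})\Big|_{x = g^{n-1}} = 0,
\]

so the wrap-around jump is allowed.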

@@ -53,7 +53,7 @@ impl<F: RichField + Extendable<D>, const D: usize> SystemZero<F, D> {
// Assert that the computed output matches the outputs in the trace.
for i in 0..SPONGE_WIDTH {
let out = local_values[col_permutation_output(i)];
yield_constr.one(state[i] - out);
yield_constr.constraint(state[i] - out);
}
}
@@ -80,7 +80,7 @@ impl<F: RichField + Extendable<D>, const D: usize> SystemZero<F, D> {
for i in 0..SPONGE_WIDTH {
let out = local_values[col_permutation_output(i)];
let diff = builder.sub_extension(state[i], out);
yield_constr.one(builder, diff);
yield_constr.constraint(builder, diff);
}
}
}