Rename constraint methods (#497)

Most of our constraints apply to all rows, and it seems safest to make that the "default".
This commit is contained in:
Daniel Lubarov 2022-02-20 17:48:31 -07:00 committed by GitHub
parent bedd2aa711
commit bc3685587c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 36 additions and 38 deletions

View File

@ -53,12 +53,12 @@ impl<P: PackedField> ConstraintConsumer<P> {
}
/// Add one constraint valid on all rows except the last.
pub fn constraint(&mut self, constraint: P) {
self.constraint_wrapping(constraint * self.z_last);
pub fn constraint_transition(&mut self, constraint: P) {
self.constraint(constraint * self.z_last);
}
/// Add one constraint on all rows.
pub fn constraint_wrapping(&mut self, constraint: P) {
pub fn constraint(&mut self, constraint: P) {
for (&alpha, acc) in self.alphas.iter().zip(&mut self.constraint_accs) {
*acc *= alpha;
*acc += constraint;
@ -68,13 +68,13 @@ impl<P: PackedField> ConstraintConsumer<P> {
/// Add one constraint, but first multiply it by a filter such that it will only apply to the
/// first row of the trace.
pub fn constraint_first_row(&mut self, constraint: P) {
self.constraint_wrapping(constraint * self.lagrange_basis_first);
self.constraint(constraint * self.lagrange_basis_first);
}
/// Add one constraint, but first multiply it by a filter such that it will only apply to the
/// last row of the trace.
pub fn constraint_last_row(&mut self, constraint: P) {
self.constraint_wrapping(constraint * self.lagrange_basis_last);
self.constraint(constraint * self.lagrange_basis_last);
}
}
@ -122,17 +122,17 @@ impl<F: RichField + Extendable<D>, const D: usize> RecursiveConstraintConsumer<F
}
/// Add one constraint valid on all rows except the last.
pub fn constraint(
pub fn constraint_transition(
&mut self,
builder: &mut CircuitBuilder<F, D>,
constraint: ExtensionTarget<D>,
) {
let filtered_constraint = builder.mul_extension(constraint, self.z_last);
self.constraint_wrapping(builder, filtered_constraint);
self.constraint(builder, filtered_constraint);
}
/// Add one constraint valid on all rows.
pub fn constraint_wrapping(
pub fn constraint(
&mut self,
builder: &mut CircuitBuilder<F, D>,
constraint: ExtensionTarget<D>,
@ -150,7 +150,7 @@ impl<F: RichField + Extendable<D>, const D: usize> RecursiveConstraintConsumer<F
constraint: ExtensionTarget<D>,
) {
let filtered_constraint = builder.mul_extension(constraint, self.lagrange_basis_first);
self.constraint_wrapping(builder, filtered_constraint);
self.constraint(builder, filtered_constraint);
}
/// Add one constraint, but first multiply it by a filter such that it will only apply to the
@ -161,6 +161,6 @@ impl<F: RichField + Extendable<D>, const D: usize> RecursiveConstraintConsumer<F
constraint: ExtensionTarget<D>,
) {
let filtered_constraint = builder.mul_extension(constraint, self.lagrange_basis_last);
self.constraint_wrapping(builder, filtered_constraint);
self.constraint(builder, filtered_constraint);
}
}

View File

@ -68,9 +68,11 @@ impl<F: RichField + Extendable<D>, const D: usize> Stark<F, D> for FibonacciStar
.constraint_last_row(vars.local_values[1] - vars.public_inputs[Self::PI_INDEX_RES]);
// x0' <- x1
yield_constr.constraint(vars.next_values[0] - vars.local_values[1]);
yield_constr.constraint_transition(vars.next_values[0] - vars.local_values[1]);
// x1' <- x0 + x1
yield_constr.constraint(vars.next_values[1] - vars.local_values[0] - vars.local_values[1]);
yield_constr.constraint_transition(
vars.next_values[1] - vars.local_values[0] - vars.local_values[1],
);
}
fn eval_ext_recursively(
@ -91,13 +93,13 @@ impl<F: RichField + Extendable<D>, const D: usize> Stark<F, D> for FibonacciStar
// x0' <- x1
let first_col_constraint = builder.sub_extension(vars.next_values[0], vars.local_values[1]);
yield_constr.constraint(builder, first_col_constraint);
yield_constr.constraint_transition(builder, first_col_constraint);
// x1' <- x0 + x1
let second_col_constraint = {
let tmp = builder.sub_extension(vars.next_values[1], vars.local_values[0]);
builder.sub_extension(tmp, vars.local_values[1])
};
yield_constr.constraint(builder, second_col_constraint);
yield_constr.constraint_transition(builder, second_col_constraint);
}
fn constraint_degree(&self) -> usize {

View File

@ -41,7 +41,7 @@ pub(crate) fn eval_addition<F: Field, P: PackedField<Scalar = F>>(
let computed_out = in_1 + in_2 + in_3;
yield_constr.constraint_wrapping(is_add * (out - computed_out));
yield_constr.constraint(is_add * (out - computed_out));
}
pub(crate) fn eval_addition_recursively<F: RichField + Extendable<D>, const D: usize>(
@ -66,5 +66,5 @@ pub(crate) fn eval_addition_recursively<F: RichField + Extendable<D>, const D: u
let diff = builder.sub_extension(out, computed_out);
let filtered_diff = builder.mul_extension(is_add, diff);
yield_constr.constraint_wrapping(builder, filtered_diff);
yield_constr.constraint(builder, filtered_diff);
}

View File

@ -45,7 +45,7 @@ pub(crate) fn eval_alu<F: Field, P: PackedField<Scalar = F>>(
// Check that the operation flag values are binary.
for col in [IS_ADD, IS_SUB, IS_MUL, IS_DIV] {
let val = local_values[col];
yield_constr.constraint_wrapping(val * val - val);
yield_constr.constraint(val * val - val);
}
eval_addition(local_values, yield_constr);
@ -65,7 +65,7 @@ pub(crate) fn eval_alu_recursively<F: RichField + Extendable<D>, const D: usize>
for col in [IS_ADD, IS_SUB, IS_MUL, IS_DIV] {
let val = local_values[col];
let constraint = builder.mul_sub_extension(val, val, val);
yield_constr.constraint_wrapping(builder, constraint);
yield_constr.constraint(builder, constraint);
}
eval_addition_recursively(builder, local_values, yield_constr);

View File

@ -49,7 +49,7 @@ pub(crate) fn eval_core_registers<F: Field, P: PackedField<Scalar = F>>(
let next_clock = vars.next_values[COL_CLOCK];
let delta_clock = next_clock - local_clock;
yield_constr.constraint_first_row(local_clock);
yield_constr.constraint(delta_clock - F::ONE);
yield_constr.constraint_transition(delta_clock - F::ONE);
// The 16-bit table must start with 0, end with 2^16 - 1, and increment by 0 or 1.
let local_range_16 = vars.local_values[COL_RANGE_16];
@ -57,7 +57,7 @@ pub(crate) fn eval_core_registers<F: Field, P: PackedField<Scalar = F>>(
let delta_range_16 = next_range_16 - local_range_16;
yield_constr.constraint_first_row(local_range_16);
yield_constr.constraint_last_row(local_range_16 - F::from_canonical_u64((1 << 16) - 1));
yield_constr.constraint(delta_range_16 * delta_range_16 - delta_range_16);
yield_constr.constraint_transition(delta_range_16 * delta_range_16 - delta_range_16);
// TODO constraints for stack etc.
}
@ -77,7 +77,7 @@ pub(crate) fn eval_core_registers_recursively<F: RichField + Extendable<D>, cons
let delta_clock = builder.sub_extension(next_clock, local_clock);
yield_constr.constraint_first_row(builder, local_clock);
let constraint = builder.sub_extension(delta_clock, one_ext);
yield_constr.constraint(builder, constraint);
yield_constr.constraint_transition(builder, constraint);
// The 16-bit table must start with 0, end with 2^16 - 1, and increment by 0 or 1.
let local_range_16 = vars.local_values[COL_RANGE_16];
@ -87,7 +87,7 @@ pub(crate) fn eval_core_registers_recursively<F: RichField + Extendable<D>, cons
let constraint = builder.sub_extension(local_range_16, max_u16_ext);
yield_constr.constraint_last_row(builder, constraint);
let constraint = builder.mul_add_extension(delta_range_16, delta_range_16, delta_range_16);
yield_constr.constraint(builder, constraint);
yield_constr.constraint_transition(builder, constraint);
// TODO constraints for stack etc.
}

View File

@ -127,8 +127,7 @@ pub(crate) fn eval_permutation_unit<F, FE, P, const D: usize>(
for i in 0..SPONGE_WIDTH {
let state_cubed = state[i] * state[i].square();
yield_constr
.constraint_wrapping(state_cubed - local_values[col_full_first_mid_sbox(r, i)]);
yield_constr.constraint(state_cubed - local_values[col_full_first_mid_sbox(r, i)]);
let state_cubed = local_values[col_full_first_mid_sbox(r, i)];
state[i] *= state_cubed.square(); // Form state ** 7.
}
@ -136,8 +135,7 @@ pub(crate) fn eval_permutation_unit<F, FE, P, const D: usize>(
state = mds_layer(state);
for i in 0..SPONGE_WIDTH {
yield_constr
.constraint_wrapping(state[i] - local_values[col_full_first_after_mds(r, i)]);
yield_constr.constraint(state[i] - local_values[col_full_first_after_mds(r, i)]);
state[i] = local_values[col_full_first_after_mds(r, i)];
}
}
@ -146,10 +144,10 @@ pub(crate) fn eval_permutation_unit<F, FE, P, const D: usize>(
state = constant_layer(state, HALF_N_FULL_ROUNDS + r);
let state0_cubed = state[0] * state[0].square();
yield_constr.constraint_wrapping(state0_cubed - local_values[col_partial_mid_sbox(r)]);
yield_constr.constraint(state0_cubed - local_values[col_partial_mid_sbox(r)]);
let state0_cubed = local_values[col_partial_mid_sbox(r)];
state[0] *= state0_cubed.square(); // Form state ** 7.
yield_constr.constraint_wrapping(state[0] - local_values[col_partial_after_sbox(r)]);
yield_constr.constraint(state[0] - local_values[col_partial_after_sbox(r)]);
state[0] = local_values[col_partial_after_sbox(r)];
state = mds_layer(state);
@ -160,8 +158,7 @@ pub(crate) fn eval_permutation_unit<F, FE, P, const D: usize>(
for i in 0..SPONGE_WIDTH {
let state_cubed = state[i] * state[i].square();
yield_constr
.constraint_wrapping(state_cubed - local_values[col_full_second_mid_sbox(r, i)]);
yield_constr.constraint(state_cubed - local_values[col_full_second_mid_sbox(r, i)]);
let state_cubed = local_values[col_full_second_mid_sbox(r, i)];
state[i] *= state_cubed.square(); // Form state ** 7.
}
@ -169,8 +166,7 @@ pub(crate) fn eval_permutation_unit<F, FE, P, const D: usize>(
state = mds_layer(state);
for i in 0..SPONGE_WIDTH {
yield_constr
.constraint_wrapping(state[i] - local_values[col_full_second_after_mds(r, i)]);
yield_constr.constraint(state[i] - local_values[col_full_second_after_mds(r, i)]);
state[i] = local_values[col_full_second_after_mds(r, i)];
}
}
@ -197,7 +193,7 @@ pub(crate) fn eval_permutation_unit_recursively<F: RichField + Extendable<D>, co
let state_cubed = builder.cube_extension(state[i]);
let diff =
builder.sub_extension(state_cubed, local_values[col_full_first_mid_sbox(r, i)]);
yield_constr.constraint_wrapping(builder, diff);
yield_constr.constraint(builder, diff);
let state_cubed = local_values[col_full_first_mid_sbox(r, i)];
state[i] = builder.mul_many_extension(&[state[i], state_cubed, state_cubed]);
// Form state ** 7.
@ -208,7 +204,7 @@ pub(crate) fn eval_permutation_unit_recursively<F: RichField + Extendable<D>, co
for i in 0..SPONGE_WIDTH {
let diff =
builder.sub_extension(state[i], local_values[col_full_first_after_mds(r, i)]);
yield_constr.constraint_wrapping(builder, diff);
yield_constr.constraint(builder, diff);
state[i] = local_values[col_full_first_after_mds(r, i)];
}
}
@ -218,11 +214,11 @@ pub(crate) fn eval_permutation_unit_recursively<F: RichField + Extendable<D>, co
let state0_cubed = builder.cube_extension(state[0]);
let diff = builder.sub_extension(state0_cubed, local_values[col_partial_mid_sbox(r)]);
yield_constr.constraint_wrapping(builder, diff);
yield_constr.constraint(builder, diff);
let state0_cubed = local_values[col_partial_mid_sbox(r)];
state[0] = builder.mul_many_extension(&[state[0], state0_cubed, state0_cubed]); // Form state ** 7.
let diff = builder.sub_extension(state[0], local_values[col_partial_after_sbox(r)]);
yield_constr.constraint_wrapping(builder, diff);
yield_constr.constraint(builder, diff);
state[0] = local_values[col_partial_after_sbox(r)];
state = F::mds_layer_recursive(builder, &state);
@ -239,7 +235,7 @@ pub(crate) fn eval_permutation_unit_recursively<F: RichField + Extendable<D>, co
let state_cubed = builder.cube_extension(state[i]);
let diff =
builder.sub_extension(state_cubed, local_values[col_full_second_mid_sbox(r, i)]);
yield_constr.constraint_wrapping(builder, diff);
yield_constr.constraint(builder, diff);
let state_cubed = local_values[col_full_second_mid_sbox(r, i)];
state[i] = builder.mul_many_extension(&[state[i], state_cubed, state_cubed]);
// Form state ** 7.
@ -250,7 +246,7 @@ pub(crate) fn eval_permutation_unit_recursively<F: RichField + Extendable<D>, co
for i in 0..SPONGE_WIDTH {
let diff =
builder.sub_extension(state[i], local_values[col_full_second_after_mds(r, i)]);
yield_constr.constraint_wrapping(builder, diff);
yield_constr.constraint(builder, diff);
state[i] = local_values[col_full_second_after_mds(r, i)];
}
}