From bc3685587cc371f96eabd7d169498251e1db55b1 Mon Sep 17 00:00:00 2001 From: Daniel Lubarov Date: Sun, 20 Feb 2022 17:48:31 -0700 Subject: [PATCH] Rename constraint methods (#497) Most of our constraints apply to all rows, and it seems safest to make that the "default". --- starky/src/constraint_consumer.rs | 20 ++++++++++---------- starky/src/fibonacci_stark.rs | 10 ++++++---- system_zero/src/alu/addition.rs | 4 ++-- system_zero/src/alu/mod.rs | 4 ++-- system_zero/src/core_registers.rs | 8 ++++---- system_zero/src/permutation_unit.rs | 28 ++++++++++++---------------- 6 files changed, 36 insertions(+), 38 deletions(-) diff --git a/starky/src/constraint_consumer.rs b/starky/src/constraint_consumer.rs index 88f66118..ada28730 100644 --- a/starky/src/constraint_consumer.rs +++ b/starky/src/constraint_consumer.rs @@ -53,12 +53,12 @@ impl ConstraintConsumer

{ } /// Add one constraint valid on all rows except the last. - pub fn constraint(&mut self, constraint: P) { - self.constraint_wrapping(constraint * self.z_last); + pub fn constraint_transition(&mut self, constraint: P) { + self.constraint(constraint * self.z_last); } /// Add one constraint on all rows. - pub fn constraint_wrapping(&mut self, constraint: P) { + pub fn constraint(&mut self, constraint: P) { for (&alpha, acc) in self.alphas.iter().zip(&mut self.constraint_accs) { *acc *= alpha; *acc += constraint; @@ -68,13 +68,13 @@ impl ConstraintConsumer

{ /// Add one constraint, but first multiply it by a filter such that it will only apply to the /// first row of the trace. pub fn constraint_first_row(&mut self, constraint: P) { - self.constraint_wrapping(constraint * self.lagrange_basis_first); + self.constraint(constraint * self.lagrange_basis_first); } /// Add one constraint, but first multiply it by a filter such that it will only apply to the /// last row of the trace. pub fn constraint_last_row(&mut self, constraint: P) { - self.constraint_wrapping(constraint * self.lagrange_basis_last); + self.constraint(constraint * self.lagrange_basis_last); } } @@ -122,17 +122,17 @@ impl, const D: usize> RecursiveConstraintConsumer, constraint: ExtensionTarget, ) { let filtered_constraint = builder.mul_extension(constraint, self.z_last); - self.constraint_wrapping(builder, filtered_constraint); + self.constraint(builder, filtered_constraint); } /// Add one constraint valid on all rows. - pub fn constraint_wrapping( + pub fn constraint( &mut self, builder: &mut CircuitBuilder, constraint: ExtensionTarget, @@ -150,7 +150,7 @@ impl, const D: usize> RecursiveConstraintConsumer, ) { let filtered_constraint = builder.mul_extension(constraint, self.lagrange_basis_first); - self.constraint_wrapping(builder, filtered_constraint); + self.constraint(builder, filtered_constraint); } /// Add one constraint, but first multiply it by a filter such that it will only apply to the @@ -161,6 +161,6 @@ impl, const D: usize> RecursiveConstraintConsumer, ) { let filtered_constraint = builder.mul_extension(constraint, self.lagrange_basis_last); - self.constraint_wrapping(builder, filtered_constraint); + self.constraint(builder, filtered_constraint); } } diff --git a/starky/src/fibonacci_stark.rs b/starky/src/fibonacci_stark.rs index bd1775e1..a0204359 100644 --- a/starky/src/fibonacci_stark.rs +++ b/starky/src/fibonacci_stark.rs @@ -68,9 +68,11 @@ impl, const D: usize> Stark for FibonacciStar .constraint_last_row(vars.local_values[1] - vars.public_inputs[Self::PI_INDEX_RES]); // x0' <- x1 - yield_constr.constraint(vars.next_values[0] - vars.local_values[1]); + yield_constr.constraint_transition(vars.next_values[0] - vars.local_values[1]); // x1' <- x0 + x1 - yield_constr.constraint(vars.next_values[1] - vars.local_values[0] - vars.local_values[1]); + yield_constr.constraint_transition( + vars.next_values[1] - vars.local_values[0] - vars.local_values[1], + ); } fn eval_ext_recursively( @@ -91,13 +93,13 @@ impl, const D: usize> Stark for FibonacciStar // x0' <- x1 let first_col_constraint = builder.sub_extension(vars.next_values[0], vars.local_values[1]); - yield_constr.constraint(builder, first_col_constraint); + yield_constr.constraint_transition(builder, first_col_constraint); // x1' <- x0 + x1 let second_col_constraint = { let tmp = builder.sub_extension(vars.next_values[1], vars.local_values[0]); builder.sub_extension(tmp, vars.local_values[1]) }; - yield_constr.constraint(builder, second_col_constraint); + yield_constr.constraint_transition(builder, second_col_constraint); } fn constraint_degree(&self) -> usize { diff --git a/system_zero/src/alu/addition.rs b/system_zero/src/alu/addition.rs index 068092e8..dc83ecb8 100644 --- a/system_zero/src/alu/addition.rs +++ b/system_zero/src/alu/addition.rs @@ -41,7 +41,7 @@ pub(crate) fn eval_addition>( let computed_out = in_1 + in_2 + in_3; - yield_constr.constraint_wrapping(is_add * (out - computed_out)); + yield_constr.constraint(is_add * (out - computed_out)); } pub(crate) fn eval_addition_recursively, const D: usize>( @@ -66,5 +66,5 @@ pub(crate) fn eval_addition_recursively, const D: u let diff = builder.sub_extension(out, computed_out); let filtered_diff = builder.mul_extension(is_add, diff); - yield_constr.constraint_wrapping(builder, filtered_diff); + yield_constr.constraint(builder, filtered_diff); } diff --git a/system_zero/src/alu/mod.rs b/system_zero/src/alu/mod.rs index 17a12df1..4e7e09fa 100644 --- a/system_zero/src/alu/mod.rs +++ b/system_zero/src/alu/mod.rs @@ -45,7 +45,7 @@ pub(crate) fn eval_alu>( // Check that the operation flag values are binary. for col in [IS_ADD, IS_SUB, IS_MUL, IS_DIV] { let val = local_values[col]; - yield_constr.constraint_wrapping(val * val - val); + yield_constr.constraint(val * val - val); } eval_addition(local_values, yield_constr); @@ -65,7 +65,7 @@ pub(crate) fn eval_alu_recursively, const D: usize> for col in [IS_ADD, IS_SUB, IS_MUL, IS_DIV] { let val = local_values[col]; let constraint = builder.mul_sub_extension(val, val, val); - yield_constr.constraint_wrapping(builder, constraint); + yield_constr.constraint(builder, constraint); } eval_addition_recursively(builder, local_values, yield_constr); diff --git a/system_zero/src/core_registers.rs b/system_zero/src/core_registers.rs index c8c6533b..1f33611a 100644 --- a/system_zero/src/core_registers.rs +++ b/system_zero/src/core_registers.rs @@ -49,7 +49,7 @@ pub(crate) fn eval_core_registers>( let next_clock = vars.next_values[COL_CLOCK]; let delta_clock = next_clock - local_clock; yield_constr.constraint_first_row(local_clock); - yield_constr.constraint(delta_clock - F::ONE); + yield_constr.constraint_transition(delta_clock - F::ONE); // The 16-bit table must start with 0, end with 2^16 - 1, and increment by 0 or 1. let local_range_16 = vars.local_values[COL_RANGE_16]; @@ -57,7 +57,7 @@ pub(crate) fn eval_core_registers>( let delta_range_16 = next_range_16 - local_range_16; yield_constr.constraint_first_row(local_range_16); yield_constr.constraint_last_row(local_range_16 - F::from_canonical_u64((1 << 16) - 1)); - yield_constr.constraint(delta_range_16 * delta_range_16 - delta_range_16); + yield_constr.constraint_transition(delta_range_16 * delta_range_16 - delta_range_16); // TODO constraints for stack etc. } @@ -77,7 +77,7 @@ pub(crate) fn eval_core_registers_recursively, cons let delta_clock = builder.sub_extension(next_clock, local_clock); yield_constr.constraint_first_row(builder, local_clock); let constraint = builder.sub_extension(delta_clock, one_ext); - yield_constr.constraint(builder, constraint); + yield_constr.constraint_transition(builder, constraint); // The 16-bit table must start with 0, end with 2^16 - 1, and increment by 0 or 1. let local_range_16 = vars.local_values[COL_RANGE_16]; @@ -87,7 +87,7 @@ pub(crate) fn eval_core_registers_recursively, cons let constraint = builder.sub_extension(local_range_16, max_u16_ext); yield_constr.constraint_last_row(builder, constraint); let constraint = builder.mul_add_extension(delta_range_16, delta_range_16, delta_range_16); - yield_constr.constraint(builder, constraint); + yield_constr.constraint_transition(builder, constraint); // TODO constraints for stack etc. } diff --git a/system_zero/src/permutation_unit.rs b/system_zero/src/permutation_unit.rs index 366cff65..079ab14a 100644 --- a/system_zero/src/permutation_unit.rs +++ b/system_zero/src/permutation_unit.rs @@ -127,8 +127,7 @@ pub(crate) fn eval_permutation_unit( for i in 0..SPONGE_WIDTH { let state_cubed = state[i] * state[i].square(); - yield_constr - .constraint_wrapping(state_cubed - local_values[col_full_first_mid_sbox(r, i)]); + yield_constr.constraint(state_cubed - local_values[col_full_first_mid_sbox(r, i)]); let state_cubed = local_values[col_full_first_mid_sbox(r, i)]; state[i] *= state_cubed.square(); // Form state ** 7. } @@ -136,8 +135,7 @@ pub(crate) fn eval_permutation_unit( state = mds_layer(state); for i in 0..SPONGE_WIDTH { - yield_constr - .constraint_wrapping(state[i] - local_values[col_full_first_after_mds(r, i)]); + yield_constr.constraint(state[i] - local_values[col_full_first_after_mds(r, i)]); state[i] = local_values[col_full_first_after_mds(r, i)]; } } @@ -146,10 +144,10 @@ pub(crate) fn eval_permutation_unit( state = constant_layer(state, HALF_N_FULL_ROUNDS + r); let state0_cubed = state[0] * state[0].square(); - yield_constr.constraint_wrapping(state0_cubed - local_values[col_partial_mid_sbox(r)]); + yield_constr.constraint(state0_cubed - local_values[col_partial_mid_sbox(r)]); let state0_cubed = local_values[col_partial_mid_sbox(r)]; state[0] *= state0_cubed.square(); // Form state ** 7. - yield_constr.constraint_wrapping(state[0] - local_values[col_partial_after_sbox(r)]); + yield_constr.constraint(state[0] - local_values[col_partial_after_sbox(r)]); state[0] = local_values[col_partial_after_sbox(r)]; state = mds_layer(state); @@ -160,8 +158,7 @@ pub(crate) fn eval_permutation_unit( for i in 0..SPONGE_WIDTH { let state_cubed = state[i] * state[i].square(); - yield_constr - .constraint_wrapping(state_cubed - local_values[col_full_second_mid_sbox(r, i)]); + yield_constr.constraint(state_cubed - local_values[col_full_second_mid_sbox(r, i)]); let state_cubed = local_values[col_full_second_mid_sbox(r, i)]; state[i] *= state_cubed.square(); // Form state ** 7. } @@ -169,8 +166,7 @@ pub(crate) fn eval_permutation_unit( state = mds_layer(state); for i in 0..SPONGE_WIDTH { - yield_constr - .constraint_wrapping(state[i] - local_values[col_full_second_after_mds(r, i)]); + yield_constr.constraint(state[i] - local_values[col_full_second_after_mds(r, i)]); state[i] = local_values[col_full_second_after_mds(r, i)]; } } @@ -197,7 +193,7 @@ pub(crate) fn eval_permutation_unit_recursively, co let state_cubed = builder.cube_extension(state[i]); let diff = builder.sub_extension(state_cubed, local_values[col_full_first_mid_sbox(r, i)]); - yield_constr.constraint_wrapping(builder, diff); + yield_constr.constraint(builder, diff); let state_cubed = local_values[col_full_first_mid_sbox(r, i)]; state[i] = builder.mul_many_extension(&[state[i], state_cubed, state_cubed]); // Form state ** 7. @@ -208,7 +204,7 @@ pub(crate) fn eval_permutation_unit_recursively, co for i in 0..SPONGE_WIDTH { let diff = builder.sub_extension(state[i], local_values[col_full_first_after_mds(r, i)]); - yield_constr.constraint_wrapping(builder, diff); + yield_constr.constraint(builder, diff); state[i] = local_values[col_full_first_after_mds(r, i)]; } } @@ -218,11 +214,11 @@ pub(crate) fn eval_permutation_unit_recursively, co let state0_cubed = builder.cube_extension(state[0]); let diff = builder.sub_extension(state0_cubed, local_values[col_partial_mid_sbox(r)]); - yield_constr.constraint_wrapping(builder, diff); + yield_constr.constraint(builder, diff); let state0_cubed = local_values[col_partial_mid_sbox(r)]; state[0] = builder.mul_many_extension(&[state[0], state0_cubed, state0_cubed]); // Form state ** 7. let diff = builder.sub_extension(state[0], local_values[col_partial_after_sbox(r)]); - yield_constr.constraint_wrapping(builder, diff); + yield_constr.constraint(builder, diff); state[0] = local_values[col_partial_after_sbox(r)]; state = F::mds_layer_recursive(builder, &state); @@ -239,7 +235,7 @@ pub(crate) fn eval_permutation_unit_recursively, co let state_cubed = builder.cube_extension(state[i]); let diff = builder.sub_extension(state_cubed, local_values[col_full_second_mid_sbox(r, i)]); - yield_constr.constraint_wrapping(builder, diff); + yield_constr.constraint(builder, diff); let state_cubed = local_values[col_full_second_mid_sbox(r, i)]; state[i] = builder.mul_many_extension(&[state[i], state_cubed, state_cubed]); // Form state ** 7. @@ -250,7 +246,7 @@ pub(crate) fn eval_permutation_unit_recursively, co for i in 0..SPONGE_WIDTH { let diff = builder.sub_extension(state[i], local_values[col_full_second_after_mds(r, i)]); - yield_constr.constraint_wrapping(builder, diff); + yield_constr.constraint(builder, diff); state[i] = local_values[col_full_second_after_mds(r, i)]; } }