diff --git a/evm/src/keccak/columns.rs b/evm/src/keccak/columns.rs index 039db078..afd92ad9 100644 --- a/evm/src/keccak/columns.rs +++ b/evm/src/keccak/columns.rs @@ -9,10 +9,6 @@ pub const fn reg_step(i: usize) -> usize { i } -/// A register which indicates if a row should be included in the CTL. Should be 1 only for certain -/// rows which are final steps, i.e. with `reg_step(23) = 1`. -pub const REG_FILTER: usize = NUM_ROUNDS; - /// Registers to hold permutation inputs. /// `reg_input_limb(2*i) -> input[i] as u32` /// `reg_input_limb(2*i+1) -> input[i] >> 32` @@ -52,7 +48,7 @@ const R: [[u8; 5]; 5] = [ [27, 20, 39, 8, 14], ]; -const START_PREIMAGE: usize = NUM_ROUNDS + 1; +const START_PREIMAGE: usize = NUM_ROUNDS; /// Registers to hold the original input to a permutation, i.e. the input to the first round. pub(crate) const fn reg_preimage(x: usize, y: usize) -> usize { debug_assert!(x < 5); diff --git a/evm/src/keccak/keccak_stark.rs b/evm/src/keccak/keccak_stark.rs index 7c077da1..7bd3c385 100644 --- a/evm/src/keccak/keccak_stark.rs +++ b/evm/src/keccak/keccak_stark.rs @@ -15,7 +15,7 @@ use crate::cross_table_lookup::Column; use crate::keccak::columns::{ reg_a, reg_a_prime, reg_a_prime_prime, reg_a_prime_prime_0_0_bit, reg_a_prime_prime_prime, reg_b, reg_c, reg_c_prime, reg_input_limb, reg_output_limb, reg_preimage, reg_step, - NUM_COLUMNS, REG_FILTER, + NUM_COLUMNS, }; use crate::keccak::constants::{rc_value, rc_value_bit}; use crate::keccak::logic::{ @@ -39,7 +39,7 @@ pub fn ctl_data() -> Vec> { } pub fn ctl_filter() -> Column { - Column::single(REG_FILTER) + Column::single(reg_step(NUM_ROUNDS - 1)) } #[derive(Copy, Clone, Default)] @@ -58,19 +58,16 @@ impl, const D: usize> KeccakStark { let num_rows = (inputs.len() * NUM_ROUNDS) .max(min_rows) .next_power_of_two(); + let mut rows = Vec::with_capacity(num_rows); for input in inputs.iter() { - let mut rows_for_perm = self.generate_trace_rows_for_perm(*input); - // Since this is a real operation, not padding, we set the filter to 1 on the last row. - rows_for_perm[NUM_ROUNDS - 1][REG_FILTER] = F::ONE; + let rows_for_perm = self.generate_trace_rows_for_perm(*input); rows.extend(rows_for_perm); } - let pad_rows = self.generate_trace_rows_for_perm([0; NUM_INPUTS]); while rows.len() < num_rows { - rows.extend(&pad_rows); + rows.push([F::ZERO; NUM_COLUMNS]); } - rows.drain(num_rows..); rows } @@ -255,7 +252,7 @@ impl, const D: usize> Stark for KeccakStark, const D: usize> Stark for KeccakStark>( yield_constr.constraint_first_row(vars.local_values[reg_step(i)]); } + // Flags should circularly increment, or be all zero for padding rows. + let next_any_flag = (0..NUM_ROUNDS) + .map(|i| vars.next_values[reg_step(i)]) + .sum::

(); for i in 0..NUM_ROUNDS { let current_round_flag = vars.local_values[reg_step(i)]; let next_round_flag = vars.next_values[reg_step((i + 1) % NUM_ROUNDS)]; - yield_constr.constraint_transition(next_round_flag - current_round_flag); + yield_constr.constraint_transition(next_any_flag * (next_round_flag - current_round_flag)); } + + // Padding rows should always be followed by padding rows. + let current_any_flag = (0..NUM_ROUNDS) + .map(|i| vars.local_values[reg_step(i)]) + .sum::

(); + yield_constr.constraint_transition(next_any_flag * (current_any_flag - F::ONE)); } pub(crate) fn eval_round_flags_recursively, const D: usize>( @@ -40,10 +50,20 @@ pub(crate) fn eval_round_flags_recursively, const D yield_constr.constraint_first_row(builder, vars.local_values[reg_step(i)]); } + // Flags should circularly increment, or be all zero for padding rows. + let next_any_flag = + builder.add_many_extension((0..NUM_ROUNDS).map(|i| vars.next_values[reg_step(i)])); for i in 0..NUM_ROUNDS { let current_round_flag = vars.local_values[reg_step(i)]; let next_round_flag = vars.next_values[reg_step((i + 1) % NUM_ROUNDS)]; let diff = builder.sub_extension(next_round_flag, current_round_flag); - yield_constr.constraint_transition(builder, diff); + let constraint = builder.mul_extension(next_any_flag, diff); + yield_constr.constraint_transition(builder, constraint); } + + // Padding rows should always be followed by padding rows. + let current_any_flag = + builder.add_many_extension((0..NUM_ROUNDS).map(|i| vars.local_values[reg_step(i)])); + let constraint = builder.mul_sub_extension(next_any_flag, current_any_flag, next_any_flag); + yield_constr.constraint_transition(builder, constraint); }