diff --git a/evm/src/keccak/keccak_stark.rs b/evm/src/keccak/keccak_stark.rs index cbc2408b..1af02d7f 100644 --- a/evm/src/keccak/keccak_stark.rs +++ b/evm/src/keccak/keccak_stark.rs @@ -17,7 +17,7 @@ use crate::keccak::logic::{ }; use crate::keccak::registers::{ reg_a, reg_a_prime, reg_a_prime_prime, reg_a_prime_prime_0_0_bit, reg_a_prime_prime_prime, - reg_b, reg_c, reg_c_partial, reg_step, NUM_REGISTERS, + reg_b, reg_c, reg_c_partial, reg_input_limb, reg_step, NUM_REGISTERS, }; use crate::keccak::round_flags::{eval_round_flags, eval_round_flags_recursively}; use crate::stark::Stark; @@ -65,6 +65,7 @@ impl, const D: usize> KeccakStark { fn generate_trace_rows_for_perm(&self, input: [u64; INPUT_LIMBS]) -> Vec<[F; NUM_REGISTERS]> { let mut rows = vec![[F::ZERO; NUM_REGISTERS]; NUM_ROUNDS]; + self.copy_input(input, &mut rows[0]); for x in 0..5 { for y in 0..5 { let input_xy = input[x * 5 + y]; @@ -76,6 +77,7 @@ impl, const D: usize> KeccakStark { self.generate_trace_row_for_round(&mut rows[0], 0); for round in 1..24 { + self.copy_input(input, &mut rows[round]); self.copy_output_to_input(rows[round - 1], &mut rows[round]); self.generate_trace_row_for_round(&mut rows[round], round); } @@ -188,6 +190,14 @@ impl, const D: usize> KeccakStark { row[out_reg_hi] = F::from_canonical_u64(row[in_reg_hi].to_canonical_u64() ^ rc_hi); } + fn copy_input(&self, input: [u64; INPUT_LIMBS], row: &mut [F; NUM_REGISTERS]) { + for i in 0..INPUT_LIMBS { + let (low, high) = (input[i] as u32, input[i] >> 32); + row[reg_input_limb(2 * i)] = F::from_canonical_u32(low); + row[reg_input_limb(2 * i + 1)] = F::from_canonical_u64(high); + } + } + pub fn generate_trace(&self, inputs: Vec<[u64; INPUT_LIMBS]>) -> Vec> { let mut timing = TimingTree::new("generate trace", log::Level::Debug); @@ -223,6 +233,16 @@ impl, const D: usize> Stark for KeccakStark, const D: usize> Stark for KeccakStark>(); let mut keccak_input: [[u64; 5]; 5] = [ input[0..5].try_into().unwrap(), diff --git a/evm/src/keccak/registers.rs b/evm/src/keccak/registers.rs index 3a891828..9eb265f2 100644 --- a/evm/src/keccak/registers.rs +++ b/evm/src/keccak/registers.rs @@ -1,4 +1,4 @@ -use crate::keccak::keccak_stark::NUM_ROUNDS; +use crate::keccak::keccak_stark::{INPUT_LIMBS, NUM_ROUNDS}; /// A register which is set to 1 if we are in the `i`th round, otherwise 0. pub(crate) const fn reg_step(i: usize) -> usize { @@ -6,6 +6,30 @@ pub(crate) const fn reg_step(i: usize) -> usize { i } +/// Registers to hold permutation inputs. +/// `reg_input_limb(2*i) -> input[i] as u32` +/// `reg_input_limb(2*i+1) -> input[i] >> 32` +pub(crate) const fn reg_input_limb(i: usize) -> usize { + debug_assert!(i < 2 * INPUT_LIMBS); + NUM_ROUNDS + i +} + +/// Registers to hold permutation outputs. +/// `reg_output_limb(2*i) -> output[i] as u32` +/// `reg_output_limb(2*i+1) -> output[i] >> 32` +#[allow(dead_code)] // TODO: Remove once it is used. +pub(crate) const fn reg_output_limb(i: usize) -> usize { + debug_assert!(i < 2 * INPUT_LIMBS); + let ii = i / 2; + let x = ii / 5; + let y = ii % 5; + if i % 2 == 0 { + reg_a_prime_prime_prime(x, y) + } else { + reg_a_prime_prime_prime(x, y) + 1 + } +} + const R: [[u8; 5]; 5] = [ [0, 36, 3, 41, 18], [1, 44, 10, 45, 2], @@ -14,7 +38,7 @@ const R: [[u8; 5]; 5] = [ [27, 20, 39, 8, 14], ]; -const START_A: usize = NUM_ROUNDS; +const START_A: usize = NUM_ROUNDS + 2 * INPUT_LIMBS; pub(crate) const fn reg_a(x: usize, y: usize, z: usize) -> usize { debug_assert!(x < 5); debug_assert!(y < 5);