From 19220b21d71b828c3d952e11f2c7716244e0ec43 Mon Sep 17 00:00:00 2001 From: Robin Salen <30937548+Nashtare@users.noreply.github.com> Date: Thu, 14 Sep 2023 15:27:38 -0400 Subject: [PATCH] Remove redundant Keccak sponge cols (#1233) * Rename columns in KeccakSponge for clarity * Remove redundant columns * Apply comments --- evm/src/keccak_sponge/columns.rs | 13 +- evm/src/keccak_sponge/keccak_sponge_stark.rs | 160 ++++++++++++------- 2 files changed, 113 insertions(+), 60 deletions(-) diff --git a/evm/src/keccak_sponge/columns.rs b/evm/src/keccak_sponge/columns.rs index 44f66a5d..431c09e0 100644 --- a/evm/src/keccak_sponge/columns.rs +++ b/evm/src/keccak_sponge/columns.rs @@ -5,11 +5,14 @@ use crate::util::{indices_arr, transmute_no_compile_time_size_checks}; pub(crate) const KECCAK_WIDTH_BYTES: usize = 200; pub(crate) const KECCAK_WIDTH_U32S: usize = KECCAK_WIDTH_BYTES / 4; +pub(crate) const KECCAK_WIDTH_MINUS_DIGEST_U32S: usize = + (KECCAK_WIDTH_BYTES - KECCAK_DIGEST_BYTES) / 4; pub(crate) const KECCAK_RATE_BYTES: usize = 136; pub(crate) const KECCAK_RATE_U32S: usize = KECCAK_RATE_BYTES / 4; pub(crate) const KECCAK_CAPACITY_BYTES: usize = 64; pub(crate) const KECCAK_CAPACITY_U32S: usize = KECCAK_CAPACITY_BYTES / 4; pub(crate) const KECCAK_DIGEST_BYTES: usize = 32; +pub(crate) const KECCAK_DIGEST_U32S: usize = KECCAK_DIGEST_BYTES / 4; #[repr(C)] #[derive(Eq, PartialEq, Debug)] @@ -52,10 +55,14 @@ pub(crate) struct KeccakSpongeColumnsView { pub xored_rate_u32s: [T; KECCAK_RATE_U32S], /// The entire state (rate + capacity) of the sponge, encoded as 32-bit chunks, after the - /// permutation is applied. - pub updated_state_u32s: [T; KECCAK_WIDTH_U32S], + /// permutation is applied, minus the first limbs where the digest is extracted from. + /// Those missing limbs can be recomputed from their corresponding bytes stored in + /// `updated_digest_state_bytes`. + pub partial_updated_state_u32s: [T; KECCAK_WIDTH_MINUS_DIGEST_U32S], - pub updated_state_bytes: [T; KECCAK_DIGEST_BYTES], + /// The first part of the state of the sponge, seen as bytes, after the permutation is applied. + /// This also represents the output digest of the Keccak sponge during the squeezing phase. + pub updated_digest_state_bytes: [T; KECCAK_DIGEST_BYTES], } // `u8` is guaranteed to have a `size_of` of 1. diff --git a/evm/src/keccak_sponge/keccak_sponge_stark.rs b/evm/src/keccak_sponge/keccak_sponge_stark.rs index 5f1a49cc..d78e9651 100644 --- a/evm/src/keccak_sponge/keccak_sponge_stark.rs +++ b/evm/src/keccak_sponge/keccak_sponge_stark.rs @@ -28,7 +28,7 @@ pub(crate) fn ctl_looked_data() -> Vec> { let mut outputs = Vec::with_capacity(8); for i in (0..8).rev() { let cur_col = Column::linear_combination( - cols.updated_state_bytes[i * 4..(i + 1) * 4] + cols.updated_digest_state_bytes[i * 4..(i + 1) * 4] .iter() .enumerate() .map(|(j, &c)| (c, F::from_canonical_u64(1 << (24 - 8 * j)))), @@ -49,15 +49,30 @@ pub(crate) fn ctl_looked_data() -> Vec> { pub(crate) fn ctl_looking_keccak() -> Vec> { let cols = KECCAK_SPONGE_COL_MAP; - Column::singles( + let mut res: Vec<_> = Column::singles( [ cols.xored_rate_u32s.as_slice(), &cols.original_capacity_u32s, - &cols.updated_state_u32s, ] .concat(), ) - .collect() + .collect(); + + // We recover the 32-bit digest limbs from their corresponding bytes, + // and then append them to the rest of the updated state limbs. + let digest_u32s = cols.updated_digest_state_bytes.chunks_exact(4).map(|c| { + Column::linear_combination( + c.iter() + .enumerate() + .map(|(i, &b)| (b, F::from_canonical_usize(1 << (8 * i)))), + ) + }); + + res.extend(digest_u32s); + + res.extend(Column::singles(&cols.partial_updated_state_u32s)); + + res } pub(crate) fn ctl_looking_memory(i: usize) -> Vec> { @@ -239,7 +254,21 @@ impl, const D: usize> KeccakSpongeStark { block.try_into().unwrap(), ); - sponge_state = row.updated_state_u32s.map(|f| f.to_canonical_u64() as u32); + sponge_state[..KECCAK_DIGEST_U32S] + .iter_mut() + .zip(row.updated_digest_state_bytes.chunks_exact(4)) + .for_each(|(s, bs)| { + *s = bs + .iter() + .enumerate() + .map(|(i, b)| (b.to_canonical_u64() as u32) << (8 * i)) + .sum(); + }); + + sponge_state[KECCAK_DIGEST_U32S..] + .iter_mut() + .zip(row.partial_updated_state_u32s) + .for_each(|(s, x)| *s = x.to_canonical_u64() as u32); rows.push(row.into()); already_absorbed_bytes += KECCAK_RATE_BYTES; @@ -357,24 +386,33 @@ impl, const D: usize> KeccakSpongeStark { row.xored_rate_u32s = xored_rate_u32s.map(F::from_canonical_u32); keccakf_u32s(&mut sponge_state); - row.updated_state_u32s = sponge_state.map(F::from_canonical_u32); - let is_final_block = row.is_final_input_len.iter().copied().sum::() == F::ONE; - if is_final_block { - for (l, &elt) in row.updated_state_u32s[..8].iter().enumerate() { + // Store all but the first `KECCAK_DIGEST_U32S` limbs in the updated state. + // Those missing limbs will be broken down into bytes and stored separately. + row.partial_updated_state_u32s.copy_from_slice( + &sponge_state[KECCAK_DIGEST_U32S..] + .iter() + .copied() + .map(|i| F::from_canonical_u32(i)) + .collect::>(), + ); + sponge_state[..KECCAK_DIGEST_U32S] + .iter() + .enumerate() + .for_each(|(l, &elt)| { let mut cur_elt = elt; (0..4).for_each(|i| { - row.updated_state_bytes[l * 4 + i] = - F::from_canonical_u32((cur_elt.to_canonical_u64() & 0xFF) as u32); - cur_elt = F::from_canonical_u64(cur_elt.to_canonical_u64() >> 8); + row.updated_digest_state_bytes[l * 4 + i] = + F::from_canonical_u32(cur_elt & 0xFF); + cur_elt >>= 8; }); - let mut s = row.updated_state_bytes[l * 4].to_canonical_u64(); + // 32-bit limb reconstruction consistency check. + let mut s = row.updated_digest_state_bytes[l * 4].to_canonical_u64(); for i in 1..4 { - s += row.updated_state_bytes[l * 4 + i].to_canonical_u64() << (8 * i); + s += row.updated_digest_state_bytes[l * 4 + i].to_canonical_u64() << (8 * i); } - assert_eq!(elt, F::from_canonical_u64(s), "not equal"); - } - } + assert_eq!(elt as u64, s, "not equal"); + }) } fn generate_padding_row(&self) -> [F; NUM_KECCAK_SPONGE_COLUMNS] { @@ -445,26 +483,39 @@ impl, const D: usize> Stark for KeccakSpongeS ); // If this is a full-input block, the next row's "before" should match our "after" state. + for (current_bytes_after, next_before) in local_values + .updated_digest_state_bytes + .chunks_exact(4) + .zip(&next_values.original_rate_u32s[..KECCAK_DIGEST_U32S]) + { + let mut current_after = current_bytes_after[0]; + for i in 1..4 { + current_after += + current_bytes_after[i] * P::from(FE::from_canonical_usize(1 << (8 * i))); + } + yield_constr + .constraint_transition(is_full_input_block * (*next_before - current_after)); + } for (¤t_after, &next_before) in local_values - .updated_state_u32s + .partial_updated_state_u32s .iter() - .zip(next_values.original_rate_u32s.iter()) + .zip(next_values.original_rate_u32s[KECCAK_DIGEST_U32S..].iter()) { yield_constr.constraint_transition(is_full_input_block * (next_before - current_after)); } for (¤t_after, &next_before) in local_values - .updated_state_u32s + .partial_updated_state_u32s .iter() - .skip(KECCAK_RATE_U32S) + .skip(KECCAK_RATE_U32S - KECCAK_DIGEST_U32S) .zip(next_values.original_capacity_u32s.iter()) { yield_constr.constraint_transition(is_full_input_block * (next_before - current_after)); } - // If this is a full-input block, the next row's already_absorbed_bytes should be ours plus 136. + // If this is a full-input block, the next row's already_absorbed_bytes should be ours plus `KECCAK_RATE_BYTES`. yield_constr.constraint_transition( is_full_input_block - * (already_absorbed_bytes + P::from(FE::from_canonical_u64(136)) + * (already_absorbed_bytes + P::from(FE::from_canonical_usize(KECCAK_RATE_BYTES)) - next_values.already_absorbed_bytes), ); @@ -481,16 +532,6 @@ impl, const D: usize> Stark for KeccakSpongeS let entry_match = offset - P::from(FE::from_canonical_usize(i)); yield_constr.constraint(is_final_len * entry_match); } - - // Adding constraints for byte columns. - for (l, &elt) in local_values.updated_state_u32s[..8].iter().enumerate() { - let mut s = local_values.updated_state_bytes[l * 4]; - for i in 1..4 { - s += local_values.updated_state_bytes[l * 4 + i] - * P::from(FE::from_canonical_usize(1 << (8 * i))); - } - yield_constr.constraint(is_final_block * (s - elt)); - } } fn eval_ext_circuit( @@ -566,19 +607,36 @@ impl, const D: usize> Stark for KeccakSpongeS yield_constr.constraint_transition(builder, constraint); // If this is a full-input block, the next row's "before" should match our "after" state. + for (current_bytes_after, next_before) in local_values + .updated_digest_state_bytes + .chunks_exact(4) + .zip(&next_values.original_rate_u32s[..KECCAK_DIGEST_U32S]) + { + let mut current_after = current_bytes_after[0]; + for i in 1..4 { + current_after = builder.mul_const_add_extension( + F::from_canonical_usize(1 << (8 * i)), + current_bytes_after[i], + current_after, + ); + } + let diff = builder.sub_extension(*next_before, current_after); + let constraint = builder.mul_extension(is_full_input_block, diff); + yield_constr.constraint_transition(builder, constraint); + } for (¤t_after, &next_before) in local_values - .updated_state_u32s + .partial_updated_state_u32s .iter() - .zip(next_values.original_rate_u32s.iter()) + .zip(next_values.original_rate_u32s[KECCAK_DIGEST_U32S..].iter()) { let diff = builder.sub_extension(next_before, current_after); let constraint = builder.mul_extension(is_full_input_block, diff); yield_constr.constraint_transition(builder, constraint); } for (¤t_after, &next_before) in local_values - .updated_state_u32s + .partial_updated_state_u32s .iter() - .skip(KECCAK_RATE_U32S) + .skip(KECCAK_RATE_U32S - KECCAK_DIGEST_U32S) .zip(next_values.original_capacity_u32s.iter()) { let diff = builder.sub_extension(next_before, current_after); @@ -586,9 +644,11 @@ impl, const D: usize> Stark for KeccakSpongeS yield_constr.constraint_transition(builder, constraint); } - // If this is a full-input block, the next row's already_absorbed_bytes should be ours plus 136. - let absorbed_bytes = - builder.add_const_extension(already_absorbed_bytes, F::from_canonical_u64(136)); + // If this is a full-input block, the next row's already_absorbed_bytes should be ours plus `KECCAK_RATE_BYTES`. + let absorbed_bytes = builder.add_const_extension( + already_absorbed_bytes, + F::from_canonical_usize(KECCAK_RATE_BYTES), + ); let absorbed_diff = builder.sub_extension(absorbed_bytes, next_values.already_absorbed_bytes); let constraint = builder.mul_extension(is_full_input_block, absorbed_diff); @@ -615,21 +675,6 @@ impl, const D: usize> Stark for KeccakSpongeS let constraint = builder.mul_extension(is_final_len, entry_match); yield_constr.constraint(builder, constraint); } - - // Adding constraints for byte columns. - for (l, &elt) in local_values.updated_state_u32s[..8].iter().enumerate() { - let mut s = local_values.updated_state_bytes[l * 4]; - for i in 1..4 { - s = builder.mul_const_add_extension( - F::from_canonical_usize(1 << (8 * i)), - local_values.updated_state_bytes[l * 4 + i], - s, - ); - } - let constraint = builder.sub_extension(s, elt); - let constraint = builder.mul_extension(is_final_block, constraint); - yield_constr.constraint(builder, constraint); - } } fn constraint_degree(&self) -> usize { @@ -698,9 +743,10 @@ mod tests { let rows = stark.generate_rows_for_op(op); assert_eq!(rows.len(), 1); let last_row: &KeccakSpongeColumnsView = rows.last().unwrap().borrow(); - let output = last_row.updated_state_u32s[..8] + let output = last_row + .updated_digest_state_bytes .iter() - .flat_map(|x| (x.to_canonical_u64() as u32).to_le_bytes()) + .map(|x| x.to_canonical_u64() as u8) .collect_vec(); assert_eq!(output, expected_output.0);