Remove redundant Keccak sponge cols (#1233)

* Rename columns in KeccakSponge for clarity

* Remove redundant columns

* Apply comments
Robin Salen 2023-09-14 15:27:38 -04:00 committed by GitHub
parent 06bc73f7ea
commit 19220b21d7
2 changed files with 113 additions and 60 deletions

View File

@@ -5,11 +5,14 @@ use crate::util::{indices_arr, transmute_no_compile_time_size_checks};
pub(crate) const KECCAK_WIDTH_BYTES: usize = 200;
pub(crate) const KECCAK_WIDTH_U32S: usize = KECCAK_WIDTH_BYTES / 4;
+pub(crate) const KECCAK_WIDTH_MINUS_DIGEST_U32S: usize =
+(KECCAK_WIDTH_BYTES - KECCAK_DIGEST_BYTES) / 4;
pub(crate) const KECCAK_RATE_BYTES: usize = 136;
pub(crate) const KECCAK_RATE_U32S: usize = KECCAK_RATE_BYTES / 4;
pub(crate) const KECCAK_CAPACITY_BYTES: usize = 64;
pub(crate) const KECCAK_CAPACITY_U32S: usize = KECCAK_CAPACITY_BYTES / 4;
pub(crate) const KECCAK_DIGEST_BYTES: usize = 32;
+pub(crate) const KECCAK_DIGEST_U32S: usize = KECCAK_DIGEST_BYTES / 4;
#[repr(C)]
#[derive(Eq, PartialEq, Debug)]
@@ -52,10 +55,14 @@ pub(crate) struct KeccakSpongeColumnsView<T: Copy> {
pub xored_rate_u32s: [T; KECCAK_RATE_U32S],
/// The entire state (rate + capacity) of the sponge, encoded as 32-bit chunks, after the
-/// permutation is applied.
-pub updated_state_u32s: [T; KECCAK_WIDTH_U32S],
+/// permutation is applied, minus the first limbs where the digest is extracted from.
+/// Those missing limbs can be recomputed from their corresponding bytes stored in
+/// `updated_digest_state_bytes`.
+pub partial_updated_state_u32s: [T; KECCAK_WIDTH_MINUS_DIGEST_U32S],
-pub updated_state_bytes: [T; KECCAK_DIGEST_BYTES],
+/// The first part of the state of the sponge, seen as bytes, after the permutation is applied.
+/// This also represents the output digest of the Keccak sponge during the squeezing phase.
+pub updated_digest_state_bytes: [T; KECCAK_DIGEST_BYTES],
}
// `u8` is guaranteed to have a `size_of` of 1.
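
The doc comments above are the crux of the change: the first KECCAK_DIGEST_U32S limbs of the post-permutation state are no longer kept as u32 columns, only as the 32 bytes of `updated_digest_state_bytes`, so any consumer has to rebuild those limbs as little-endian base-256 combinations of the bytes. A minimal plain-Rust sketch of that recomputation (the constants mirror the ones above; `digest_limbs_from_bytes` is an illustrative helper, not part of the crate):

    const KECCAK_DIGEST_BYTES: usize = 32;
    const KECCAK_DIGEST_U32S: usize = KECCAK_DIGEST_BYTES / 4;

    /// Rebuilds the limbs dropped from `partial_updated_state_u32s` out of the digest bytes,
    /// least-significant byte first within each 32-bit limb.
    fn digest_limbs_from_bytes(bytes: &[u8; KECCAK_DIGEST_BYTES]) -> [u32; KECCAK_DIGEST_U32S] {
        let mut limbs = [0u32; KECCAK_DIGEST_U32S];
        for (limb, chunk) in limbs.iter_mut().zip(bytes.chunks_exact(4)) {
            *limb = chunk
                .iter()
                .enumerate()
                .map(|(i, &b)| (b as u32) << (8 * i))
                .sum();
        }
        limbs
    }
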

View File

@@ -28,7 +28,7 @@ pub(crate) fn ctl_looked_data<F: Field>() -> Vec<Column<F>> {
let mut outputs = Vec::with_capacity(8);
for i in (0..8).rev() {
let cur_col = Column::linear_combination(
-cols.updated_state_bytes[i * 4..(i + 1) * 4]
+cols.updated_digest_state_bytes[i * 4..(i + 1) * 4]
.iter()
.enumerate()
.map(|(j, &c)| (c, F::from_canonical_u64(1 << (24 - 8 * j)))),
@@ -49,15 +49,30 @@ pub(crate) fn ctl_looked_data<F: Field>() -> Vec<Column<F>> {
pub(crate) fn ctl_looking_keccak<F: Field>() -> Vec<Column<F>> {
let cols = KECCAK_SPONGE_COL_MAP;
-Column::singles(
+let mut res: Vec<_> = Column::singles(
[
cols.xored_rate_u32s.as_slice(),
&cols.original_capacity_u32s,
-&cols.updated_state_u32s,
]
.concat(),
)
-.collect()
+.collect();
+// We recover the 32-bit digest limbs from their corresponding bytes,
+// and then append them to the rest of the updated state limbs.
+let digest_u32s = cols.updated_digest_state_bytes.chunks_exact(4).map(|c| {
+Column::linear_combination(
+c.iter()
+.enumerate()
+.map(|(i, &b)| (b, F::from_canonical_usize(1 << (8 * i)))),
+)
+});
+res.extend(digest_u32s);
+res.extend(Column::singles(&cols.partial_updated_state_u32s));
+res
}
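
Note the two packing orders in this file: `ctl_looked_data` above combines each group of four digest bytes with big-endian weights `1 << (24 - 8 * j)`, while `ctl_looking_keccak` uses little-endian weights `1 << (8 * i)` so the recombined values line up with the u32 limbs exposed by the Keccak table. A plain-integer illustration of the two weightings (ordinary u32 arithmetic standing in for the field-element linear combinations):

    fn main() {
        let bytes: [u8; 4] = [0xAA, 0xBB, 0xCC, 0xDD];

        // ctl_looking_keccak-style packing: weight 1 << (8 * i) on byte i (little endian).
        let le: u32 = bytes.iter().enumerate().map(|(i, &b)| (b as u32) << (8 * i)).sum();
        assert_eq!(le, u32::from_le_bytes(bytes)); // 0xDDCCBBAA

        // ctl_looked_data-style packing: weight 1 << (24 - 8 * j) on byte j (big endian).
        let be: u32 = bytes.iter().enumerate().map(|(j, &b)| (b as u32) << (24 - 8 * j)).sum();
        assert_eq!(be, u32::from_be_bytes(bytes)); // 0xAABBCCDD
    }
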
pub(crate) fn ctl_looking_memory<F: Field>(i: usize) -> Vec<Column<F>> {
@@ -239,7 +254,21 @@ impl<F: RichField + Extendable<D>, const D: usize> KeccakSpongeStark<F, D> {
block.try_into().unwrap(),
);
-sponge_state = row.updated_state_u32s.map(|f| f.to_canonical_u64() as u32);
+sponge_state[..KECCAK_DIGEST_U32S]
+.iter_mut()
+.zip(row.updated_digest_state_bytes.chunks_exact(4))
+.for_each(|(s, bs)| {
+*s = bs
+.iter()
+.enumerate()
+.map(|(i, b)| (b.to_canonical_u64() as u32) << (8 * i))
+.sum();
+});
+sponge_state[KECCAK_DIGEST_U32S..]
+.iter_mut()
+.zip(row.partial_updated_state_u32s)
+.for_each(|(s, x)| *s = x.to_canonical_u64() as u32);
rows.push(row.into());
already_absorbed_bytes += KECCAK_RATE_BYTES;
@@ -357,24 +386,33 @@ impl<F: RichField + Extendable<D>, const D: usize> KeccakSpongeStark<F, D> {
row.xored_rate_u32s = xored_rate_u32s.map(F::from_canonical_u32);
keccakf_u32s(&mut sponge_state);
-row.updated_state_u32s = sponge_state.map(F::from_canonical_u32);
-let is_final_block = row.is_final_input_len.iter().copied().sum::<F>() == F::ONE;
-if is_final_block {
-for (l, &elt) in row.updated_state_u32s[..8].iter().enumerate() {
+// Store all but the first `KECCAK_DIGEST_U32S` limbs in the updated state.
+// Those missing limbs will be broken down into bytes and stored separately.
+row.partial_updated_state_u32s.copy_from_slice(
+&sponge_state[KECCAK_DIGEST_U32S..]
+.iter()
+.copied()
+.map(|i| F::from_canonical_u32(i))
+.collect::<Vec<_>>(),
+);
+sponge_state[..KECCAK_DIGEST_U32S]
+.iter()
+.enumerate()
+.for_each(|(l, &elt)| {
let mut cur_elt = elt;
(0..4).for_each(|i| {
-row.updated_state_bytes[l * 4 + i] =
-F::from_canonical_u32((cur_elt.to_canonical_u64() & 0xFF) as u32);
-cur_elt = F::from_canonical_u64(cur_elt.to_canonical_u64() >> 8);
+row.updated_digest_state_bytes[l * 4 + i] =
+F::from_canonical_u32(cur_elt & 0xFF);
+cur_elt >>= 8;
});
-let mut s = row.updated_state_bytes[l * 4].to_canonical_u64();
+// 32-bit limb reconstruction consistency check.
+let mut s = row.updated_digest_state_bytes[l * 4].to_canonical_u64();
for i in 1..4 {
-s += row.updated_state_bytes[l * 4 + i].to_canonical_u64() << (8 * i);
+s += row.updated_digest_state_bytes[l * 4 + i].to_canonical_u64() << (8 * i);
}
-assert_eq!(elt, F::from_canonical_u64(s), "not equal");
-}
-}
+assert_eq!(elt as u64, s, "not equal");
+})
}
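
Stripped of field-element plumbing, the generation code above does the following after `keccakf_u32s`: the trailing limbs go straight into `partial_updated_state_u32s`, the first KECCAK_DIGEST_U32S limbs are split into bytes for `updated_digest_state_bytes`, and the decomposition is immediately cross-checked. A standalone sketch over plain u32/u8 values (array sizes 50/42/32 follow from the constants in the columns file; the helper name is illustrative only):

    /// Splits a post-permutation sponge state into the two column groups of the new layout.
    fn decompose_updated_state(sponge_state: &[u32; 50]) -> ([u8; 32], [u32; 42]) {
        const KECCAK_DIGEST_U32S: usize = 8;

        // All but the first `KECCAK_DIGEST_U32S` limbs stay as 32-bit columns.
        let mut partial = [0u32; 42];
        partial.copy_from_slice(&sponge_state[KECCAK_DIGEST_U32S..]);

        // The first limbs are stored byte by byte, least significant byte first.
        let mut digest_bytes = [0u8; 32];
        for (l, &elt) in sponge_state[..KECCAK_DIGEST_U32S].iter().enumerate() {
            let mut cur_elt = elt;
            for i in 0..4 {
                digest_bytes[l * 4 + i] = (cur_elt & 0xFF) as u8;
                cur_elt >>= 8;
            }

            // 32-bit limb reconstruction consistency check, as in the trace generation.
            let s: u32 = (0..4).map(|i| (digest_bytes[l * 4 + i] as u32) << (8 * i)).sum();
            assert_eq!(elt, s, "not equal");
        }

        (digest_bytes, partial)
    }
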
fn generate_padding_row(&self) -> [F; NUM_KECCAK_SPONGE_COLUMNS] {
@@ -445,26 +483,39 @@ impl<F: RichField + Extendable<D>, const D: usize> Stark<F, D> for KeccakSpongeS
);
// If this is a full-input block, the next row's "before" should match our "after" state.
+for (current_bytes_after, next_before) in local_values
+.updated_digest_state_bytes
+.chunks_exact(4)
+.zip(&next_values.original_rate_u32s[..KECCAK_DIGEST_U32S])
+{
+let mut current_after = current_bytes_after[0];
+for i in 1..4 {
+current_after +=
+current_bytes_after[i] * P::from(FE::from_canonical_usize(1 << (8 * i)));
+}
+yield_constr
+.constraint_transition(is_full_input_block * (*next_before - current_after));
+}
for (&current_after, &next_before) in local_values
-.updated_state_u32s
+.partial_updated_state_u32s
.iter()
-.zip(next_values.original_rate_u32s.iter())
+.zip(next_values.original_rate_u32s[KECCAK_DIGEST_U32S..].iter())
{
yield_constr.constraint_transition(is_full_input_block * (next_before - current_after));
}
for (&current_after, &next_before) in local_values
-.updated_state_u32s
+.partial_updated_state_u32s
.iter()
-.skip(KECCAK_RATE_U32S)
+.skip(KECCAK_RATE_U32S - KECCAK_DIGEST_U32S)
.zip(next_values.original_capacity_u32s.iter())
{
yield_constr.constraint_transition(is_full_input_block * (next_before - current_after));
}
-// If this is a full-input block, the next row's already_absorbed_bytes should be ours plus 136.
+// If this is a full-input block, the next row's already_absorbed_bytes should be ours plus `KECCAK_RATE_BYTES`.
yield_constr.constraint_transition(
is_full_input_block
-* (already_absorbed_bytes + P::from(FE::from_canonical_u64(136))
+* (already_absorbed_bytes + P::from(FE::from_canonical_usize(KECCAK_RATE_BYTES))
- next_values.already_absorbed_bytes),
);
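
The old single continuity check against `updated_state_u32s` is thus split three ways: the first KECCAK_DIGEST_U32S limbs of the next row's rate are compared against the byte recombination, the remaining rate limbs against the leading entries of `partial_updated_state_u32s`, and the capacity limbs against its tail. A plain-integer sketch of the relation a valid witness satisfies on a full-input block (the `is_full_input_block` filter and field arithmetic are elided; the array sizes 32/42/34/16 follow from the constants):

    fn check_full_input_transition(
        updated_digest_state_bytes: &[u8; 32],
        partial_updated_state_u32s: &[u32; 42],
        next_original_rate_u32s: &[u32; 34],
        next_original_capacity_u32s: &[u32; 16],
    ) {
        const KECCAK_DIGEST_U32S: usize = 8;
        const KECCAK_RATE_U32S: usize = 34;

        // Digest part of the next rate: recombine the bytes with weights 1, 2^8, 2^16, 2^24.
        for (chunk, &next_before) in updated_digest_state_bytes
            .chunks_exact(4)
            .zip(&next_original_rate_u32s[..KECCAK_DIGEST_U32S])
        {
            let current_after: u32 = chunk
                .iter()
                .enumerate()
                .map(|(i, &b)| (b as u32) << (8 * i))
                .sum();
            assert_eq!(next_before, current_after);
        }

        // Remaining rate limbs come straight from the partial updated state...
        for (&current_after, &next_before) in partial_updated_state_u32s
            .iter()
            .zip(&next_original_rate_u32s[KECCAK_DIGEST_U32S..])
        {
            assert_eq!(next_before, current_after);
        }

        // ...and the capacity limbs from its tail, skipping the rest of the rate limbs.
        for (&current_after, &next_before) in partial_updated_state_u32s
            .iter()
            .skip(KECCAK_RATE_U32S - KECCAK_DIGEST_U32S)
            .zip(next_original_capacity_u32s)
        {
            assert_eq!(next_before, current_after);
        }
    }
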
@@ -481,16 +532,6 @@ impl<F: RichField + Extendable<D>, const D: usize> Stark<F, D> for KeccakSpongeS
let entry_match = offset - P::from(FE::from_canonical_usize(i));
yield_constr.constraint(is_final_len * entry_match);
}
-// Adding constraints for byte columns.
-for (l, &elt) in local_values.updated_state_u32s[..8].iter().enumerate() {
-let mut s = local_values.updated_state_bytes[l * 4];
-for i in 1..4 {
-s += local_values.updated_state_bytes[l * 4 + i]
-* P::from(FE::from_canonical_usize(1 << (8 * i)));
-}
-yield_constr.constraint(is_final_block * (s - elt));
-}
}
fn eval_ext_circuit(
@@ -566,19 +607,36 @@ impl<F: RichField + Extendable<D>, const D: usize> Stark<F, D> for KeccakSpongeS
yield_constr.constraint_transition(builder, constraint);
// If this is a full-input block, the next row's "before" should match our "after" state.
+for (current_bytes_after, next_before) in local_values
+.updated_digest_state_bytes
+.chunks_exact(4)
+.zip(&next_values.original_rate_u32s[..KECCAK_DIGEST_U32S])
+{
+let mut current_after = current_bytes_after[0];
+for i in 1..4 {
+current_after = builder.mul_const_add_extension(
+F::from_canonical_usize(1 << (8 * i)),
+current_bytes_after[i],
+current_after,
+);
+}
+let diff = builder.sub_extension(*next_before, current_after);
+let constraint = builder.mul_extension(is_full_input_block, diff);
+yield_constr.constraint_transition(builder, constraint);
+}
for (&current_after, &next_before) in local_values
-.updated_state_u32s
+.partial_updated_state_u32s
.iter()
-.zip(next_values.original_rate_u32s.iter())
+.zip(next_values.original_rate_u32s[KECCAK_DIGEST_U32S..].iter())
{
let diff = builder.sub_extension(next_before, current_after);
let constraint = builder.mul_extension(is_full_input_block, diff);
yield_constr.constraint_transition(builder, constraint);
}
for (&current_after, &next_before) in local_values
-.updated_state_u32s
+.partial_updated_state_u32s
.iter()
-.skip(KECCAK_RATE_U32S)
+.skip(KECCAK_RATE_U32S - KECCAK_DIGEST_U32S)
.zip(next_values.original_capacity_u32s.iter())
{
let diff = builder.sub_extension(next_before, current_after);
@@ -586,9 +644,11 @@ impl<F: RichField + Extendable<D>, const D: usize> Stark<F, D> for KeccakSpongeS
yield_constr.constraint_transition(builder, constraint);
}
-// If this is a full-input block, the next row's already_absorbed_bytes should be ours plus 136.
-let absorbed_bytes =
-builder.add_const_extension(already_absorbed_bytes, F::from_canonical_u64(136));
+// If this is a full-input block, the next row's already_absorbed_bytes should be ours plus `KECCAK_RATE_BYTES`.
+let absorbed_bytes = builder.add_const_extension(
+already_absorbed_bytes,
+F::from_canonical_usize(KECCAK_RATE_BYTES),
+);
let absorbed_diff =
builder.sub_extension(absorbed_bytes, next_values.already_absorbed_bytes);
let constraint = builder.mul_extension(is_full_input_block, absorbed_diff);
@@ -615,21 +675,6 @@ impl<F: RichField + Extendable<D>, const D: usize> Stark<F, D> for KeccakSpongeS
let constraint = builder.mul_extension(is_final_len, entry_match);
yield_constr.constraint(builder, constraint);
}
-// Adding constraints for byte columns.
-for (l, &elt) in local_values.updated_state_u32s[..8].iter().enumerate() {
-let mut s = local_values.updated_state_bytes[l * 4];
-for i in 1..4 {
-s = builder.mul_const_add_extension(
-F::from_canonical_usize(1 << (8 * i)),
-local_values.updated_state_bytes[l * 4 + i],
-s,
-);
-}
-let constraint = builder.sub_extension(s, elt);
-let constraint = builder.mul_extension(is_final_block, constraint);
-yield_constr.constraint(builder, constraint);
-}
}
fn constraint_degree(&self) -> usize {
@@ -698,9 +743,10 @@ mod tests {
let rows = stark.generate_rows_for_op(op);
assert_eq!(rows.len(), 1);
let last_row: &KeccakSpongeColumnsView<F> = rows.last().unwrap().borrow();
-let output = last_row.updated_state_u32s[..8]
+let output = last_row
+.updated_digest_state_bytes
.iter()
-.flat_map(|x| (x.to_canonical_u64() as u32).to_le_bytes())
+.map(|x| x.to_canonical_u64() as u8)
.collect_vec();
assert_eq!(output, expected_output.0);