From 61819af07d51b411dc7a03238d9c749016840c18 Mon Sep 17 00:00:00 2001 From: Daniel Lubarov Date: Sun, 14 Aug 2022 21:03:45 -0700 Subject: [PATCH] Improved Keccak implementation Based on the approach @SyxtonPrime described. In terms of columns, the changes are: - Store inputs (`A`) as `u32` limbs, rather than individual bits. - Remove `C_partial`. It was used to store an intermediate product in a 5-way xor, but we've since realized that we can do a 5-way xor directly. - Add `C_prime`, an intermediate result used to help verify the relation between `A` and `A'`. --- evm/src/keccak/columns.rs | 52 +++---- evm/src/keccak/keccak_stark.rs | 256 ++++++++++++++++++--------------- 2 files changed, 168 insertions(+), 140 deletions(-) diff --git a/evm/src/keccak/columns.rs b/evm/src/keccak/columns.rs index 2d9a35f0..39116b4a 100644 --- a/evm/src/keccak/columns.rs +++ b/evm/src/keccak/columns.rs @@ -14,9 +14,12 @@ pub const fn reg_step(i: usize) -> usize { /// `reg_input_limb(2*i+1) -> input[i] >> 32` pub fn reg_input_limb(i: usize) -> Column { debug_assert!(i < 2 * NUM_INPUTS); - let range = if i % 2 == 0 { 0..32 } else { 32..64 }; - let bits = range.map(|j| reg_a((i / 2) / 5, (i / 2) % 5, j)); - Column::le_bits(bits) + let i_u64 = i / 2; // The index of the 64-bit chunk. + let x = i_u64 / 5; + let y = i_u64 % 5; + let reg_low_limb = reg_a(x, y); + let is_high_limb = i % 2; + Column::single(reg_low_limb + is_high_limb) } /// Registers to hold permutation outputs. @@ -24,14 +27,11 @@ pub fn reg_input_limb(i: usize) -> Column { /// `reg_output_limb(2*i+1) -> output[i] >> 32` pub const fn reg_output_limb(i: usize) -> usize { debug_assert!(i < 2 * NUM_INPUTS); - let ii = i / 2; - let x = ii / 5; - let y = ii % 5; - if i % 2 == 0 { - reg_a_prime_prime_prime(x, y) - } else { - reg_a_prime_prime_prime(x, y) + 1 - } + let i_u64 = i / 2; // The index of the 64-bit chunk. + let x = i_u64 / 5; + let y = i_u64 % 5; + let is_high_limb = i % 2; + reg_a_prime_prime_prime(x, y) + is_high_limb } const R: [[u8; 5]; 5] = [ @@ -43,31 +43,33 @@ const R: [[u8; 5]; 5] = [ ]; const START_A: usize = NUM_ROUNDS; -pub(crate) const fn reg_a(x: usize, y: usize, z: usize) -> usize { +pub(crate) const fn reg_a(x: usize, y: usize) -> usize { debug_assert!(x < 5); debug_assert!(y < 5); - debug_assert!(z < 64); - START_A + x * 64 * 5 + y * 64 + z + START_A + (x * 5 + y) * 2 } -// C_partial[x] = xor(A[x, 0], A[x, 1], A[x, 2]) -const START_C_PARTIAL: usize = START_A + 5 * 5 * 64; -pub(crate) const fn reg_c_partial(x: usize, z: usize) -> usize { - START_C_PARTIAL + x * 64 + z -} - -// C[x] = xor(C_partial[x], A[x, 3], A[x, 4]) -const START_C: usize = START_C_PARTIAL + 5 * 64; +// C[x] = xor(A[x, 0], A[x, 1], A[x, 2], A[x, 3], A[x, 4]) +const START_C: usize = START_A + 5 * 5 * 2; pub(crate) const fn reg_c(x: usize, z: usize) -> usize { + debug_assert!(x < 5); + debug_assert!(z < 64); START_C + x * 64 + z } -// D is inlined. -// const fn reg_d(x: usize, z: usize) {} +// C'[x, z] = xor(C[x, z], C[x - 1, z], C[x + 1, z - 1]) +const START_C_PRIME: usize = START_C + 5 * 64; +pub(crate) const fn reg_c_prime(x: usize, z: usize) -> usize { + debug_assert!(x < 5); + debug_assert!(z < 64); + START_C_PRIME + x * 64 + z +} + +// Note: D is inlined, not stored in the witness. // A'[x, y] = xor(A[x, y], D[x]) // = xor(A[x, y], C[x - 1], ROT(C[x + 1], 1)) -const START_A_PRIME: usize = START_C + 5 * 64; +const START_A_PRIME: usize = START_C_PRIME + 5 * 64; pub(crate) const fn reg_a_prime(x: usize, y: usize, z: usize) -> usize { debug_assert!(x < 5); debug_assert!(y < 5); diff --git a/evm/src/keccak/keccak_stark.rs b/evm/src/keccak/keccak_stark.rs index 53dd66ab..94fa795d 100644 --- a/evm/src/keccak/keccak_stark.rs +++ b/evm/src/keccak/keccak_stark.rs @@ -15,7 +15,7 @@ use crate::constraint_consumer::{ConstraintConsumer, RecursiveConstraintConsumer use crate::cross_table_lookup::Column; use crate::keccak::columns::{ reg_a, reg_a_prime, reg_a_prime_prime, reg_a_prime_prime_0_0_bit, reg_a_prime_prime_prime, - reg_b, reg_c, reg_c_partial, reg_input_limb, reg_output_limb, reg_step, NUM_COLUMNS, + reg_b, reg_c, reg_c_prime, reg_input_limb, reg_output_limb, reg_step, NUM_COLUMNS, }; use crate::keccak::constants::{rc_value, rc_value_bit}; use crate::keccak::logic::{ @@ -77,9 +77,10 @@ impl, const D: usize> KeccakStark { for x in 0..5 { for y in 0..5 { let input_xy = input[x * 5 + y]; - for z in 0..64 { - rows[0][reg_a(x, y, z)] = F::from_canonical_u64((input_xy >> z) & 1); - } + let reg_lo = reg_a(x, y); + let reg_hi = reg_lo + 1; + rows[0][reg_lo] = F::from_canonical_u64(input_xy & 0xFFFFFFFF); + rows[0][reg_hi] = F::from_canonical_u64(input_xy >> 32); } } @@ -95,20 +96,12 @@ impl, const D: usize> KeccakStark { fn copy_output_to_input(&self, prev_row: [F; NUM_COLUMNS], next_row: &mut [F; NUM_COLUMNS]) { for x in 0..5 { for y in 0..5 { - let cur_lo = prev_row[reg_a_prime_prime_prime(x, y)]; - let cur_hi = prev_row[reg_a_prime_prime_prime(x, y) + 1]; - let cur_u64 = cur_lo.to_canonical_u64() | (cur_hi.to_canonical_u64() << 32); - let bit_values: Vec = (0..64) - .scan(cur_u64, |acc, _| { - let tmp = *acc & 1; - *acc >>= 1; - Some(tmp) - }) - .collect(); - - for z in 0..64 { - next_row[reg_a(x, y, z)] = F::from_canonical_u64(bit_values[z]); - } + let in_lo = reg_a(x, y); + let in_hi = in_lo + 1; + let out_lo = reg_a_prime_prime_prime(x, y); + let out_hi = out_lo + 1; + next_row[in_lo] = prev_row[out_lo]; + next_row[in_hi] = prev_row[out_hi]; } } } @@ -116,14 +109,28 @@ impl, const D: usize> KeccakStark { fn generate_trace_row_for_round(&self, row: &mut [F; NUM_COLUMNS], round: usize) { row[reg_step(round)] = F::ONE; - // Populate C partial and C. + // Populate C[x] = xor(A[x, 0], A[x, 1], A[x, 2], A[x, 3], A[x, 4]). for x in 0..5 { for z in 0..64 { - let a = [0, 1, 2, 3, 4].map(|i| row[reg_a(x, i, z)]); - let c_partial = xor([a[0], a[1], a[2]]); - let c = xor([c_partial, a[3], a[4]]); - row[reg_c_partial(x, z)] = c_partial; - row[reg_c(x, z)] = c; + let is_high_limb = z / 32; + let bit_in_limb = z % 32; + let a = [0, 1, 2, 3, 4].map(|i| { + let reg_a_limb = reg_a(x, i) + is_high_limb; + let a_limb = row[reg_a_limb].to_canonical_u64() as u32; + F::from_bool(((a_limb >> bit_in_limb) & 1) != 0) + }); + row[reg_c(x, z)] = xor(a); + } + } + + // Populate C'[x, z] = xor(C[x, z], C[x - 1, z], C[x + 1, z - 1]). + for x in 0..5 { + for z in 0..64 { + row[reg_c_prime(x, z)] = xor([ + row[reg_c(x, z)], + row[reg_c((x + 4) % 5, z)], + row[reg_c((x + 1) % 5, (z + 63) % 64)], + ]); } } @@ -133,8 +140,13 @@ impl, const D: usize> KeccakStark { for x in 0..5 { for y in 0..5 { for z in 0..64 { + let is_high_limb = z / 32; + let bit_in_limb = z % 32; + let reg_a_limb = reg_a(x, y) + is_high_limb; + let a_limb = row[reg_a_limb].to_canonical_u64() as u32; + let a_bit = F::from_bool(((a_limb >> bit_in_limb) & 1) != 0); row[reg_a_prime(x, y, z)] = xor([ - row[reg_a(x, y, z)], + a_bit, row[reg_c((x + 4) % 5, z)], row[reg_c((x + 1) % 5, (z + 64 - 1) % 64)], ]); @@ -228,44 +240,58 @@ impl, const D: usize> Stark for KeccakStark, const D: usize> Stark for KeccakStark, const D: usize> Stark for KeccakStark, ) { let two = builder.two(); + let two_ext = builder.two_extension(); + let four_ext = builder.constant_extension(F::Extension::from_canonical_u8(4)); eval_round_flags_recursively(builder, vars, yield_constr); - // C_partial[x] = xor(A[x, 0], A[x, 1], A[x, 2]) + // C'[x, z] = xor(C[x, z], C[x - 1, z], C[x + 1, z - 1]). for x in 0..5 { for z in 0..64 { - let c_partial = vars.local_values[reg_c_partial(x, z)]; - let a_0 = vars.local_values[reg_a(x, 0, z)]; - let a_1 = vars.local_values[reg_a(x, 1, z)]; - let a_2 = vars.local_values[reg_a(x, 2, z)]; - - let xor_012 = xor3_gen_circuit(builder, a_0, a_1, a_2); - let diff = builder.sub_extension(c_partial, xor_012); + let xor = xor3_gen_circuit( + builder, + vars.local_values[reg_c(x, z)], + vars.local_values[reg_c((x + 4) % 5, z)], + vars.local_values[reg_c((x + 1) % 5, (z + 63) % 64)], + ); + let c_prime = vars.local_values[reg_c_prime(x, z)]; + let diff = builder.sub_extension(c_prime, xor); yield_constr.constraint(builder, diff); } } - // C[x] = xor(C_partial[x], A[x, 3], A[x, 4]) + // Check that the input limbs are consistent with A' and D. + // A[x, y, z] = xor(A'[x, y, z], D[x, y, z]) + // = xor(A'[x, y, z], C[x - 1, z], C[x + 1, z - 1]) + // = xor(A'[x, y, z], C[x, z], C'[x, z]). + // The last step is valid based on the identity we checked above. + // It isn't required, but makes this check a bit cleaner. for x in 0..5 { - for z in 0..64 { - let c = vars.local_values[reg_c(x, z)]; - let xor_012 = vars.local_values[reg_c_partial(x, z)]; - let a_3 = vars.local_values[reg_a(x, 3, z)]; - let a_4 = vars.local_values[reg_a(x, 4, z)]; - - let xor_01234 = xor3_gen_circuit(builder, xor_012, a_3, a_4); - let diff = builder.sub_extension(c, xor_01234); - yield_constr.constraint(builder, diff); - } - } - - // A'[x, y] = xor(A[x, y], D[x]) - // = xor(A[x, y], C[x - 1], ROT(C[x + 1], 1)) - for x in 0..5 { - for z in 0..64 { - let c_left = vars.local_values[reg_c((x + 4) % 5, z)]; - let c_right = vars.local_values[reg_c((x + 1) % 5, (z + 64 - 1) % 64)]; - let d = xor_gen_circuit(builder, c_left, c_right); - - for y in 0..5 { - let a = vars.local_values[reg_a(x, y, z)]; + for y in 0..5 { + let a_lo = vars.local_values[reg_a(x, y)]; + let a_hi = vars.local_values[reg_a(x, y) + 1]; + let mut get_bit = |z| { let a_prime = vars.local_values[reg_a_prime(x, y, z)]; - let xor = xor_gen_circuit(builder, d, a); - let diff = builder.sub_extension(a_prime, xor); - yield_constr.constraint(builder, diff); - } + let c = vars.local_values[reg_c(x, z)]; + let c_prime = vars.local_values[reg_c_prime(x, z)]; + xor3_gen_circuit(builder, a_prime, c, c_prime) + }; + let bits_lo = (0..32).map(&mut get_bit).collect_vec(); + let bits_hi = (32..64).map(get_bit).collect_vec(); + let computed_lo = reduce_with_powers_ext_circuit(builder, &bits_lo, two); + let computed_hi = reduce_with_powers_ext_circuit(builder, &bits_hi, two); + let diff = builder.sub_extension(computed_lo, a_lo); + yield_constr.constraint(builder, diff); + let diff = builder.sub_extension(computed_hi, a_hi); + yield_constr.constraint(builder, diff); + } + } + + // xor_{i=0}^4 A'[x, i, z] = C'[x, z], so for each x, z, + // diff * (diff - 2) * (diff - 4) = 0, where + // diff = sum_{i=0}^4 A'[x, i, z] - C'[x, z] + for x in 0..5 { + for z in 0..64 { + let sum = builder.add_many_extension( + [0, 1, 2, 3, 4].map(|i| vars.local_values[reg_a_prime(x, i, z)]), + ); + let diff = builder.sub_extension(sum, vars.local_values[reg_c_prime(x, z)]); + let diff_minus_two = builder.sub_extension(diff, two_ext); + let diff_minus_four = builder.sub_extension(diff, four_ext); + let constraint = + builder.mul_many_extension([diff, diff_minus_two, diff_minus_four]); + yield_constr.constraint(builder, constraint); } } @@ -495,18 +526,13 @@ impl, const D: usize> Stark for KeccakStark