Remove Keccak input limbs

This commit is contained in:
wborgeaud 2022-06-14 01:21:17 +02:00
parent d626679c6c
commit 1dce18495a
4 changed files with 18 additions and 56 deletions

View File

@ -98,8 +98,8 @@ mod tests {
.map(|i| {
(0..2 * NUM_INPUTS)
.map(|j| {
keccak_trace[keccak::registers::reg_input_limb(j)].values
[(i + 1) * NUM_ROUNDS - 1]
keccak::registers::reg_input_limb(j)
.eval_table(&keccak_trace, (i + 1) * NUM_ROUNDS - 1)
})
.collect::<Vec<_>>()
.try_into()
@ -143,8 +143,11 @@ mod tests {
let mut keccak_keccak_input_output = (0..2 * NUM_INPUTS)
.map(keccak::registers::reg_input_limb)
.collect::<Vec<_>>();
keccak_keccak_input_output
.extend((0..2 * NUM_INPUTS).map(keccak::registers::reg_output_limb));
keccak_keccak_input_output.extend(Column::singles(
(0..2 * NUM_INPUTS)
.map(keccak::registers::reg_output_limb)
.collect(),
));
let cross_table_lookups = vec![CrossTableLookup::new(
vec![TableWithColumns::new(
Table::Cpu,
@ -153,7 +156,7 @@ mod tests {
)],
TableWithColumns::new(
Table::Keccak,
Column::singles(keccak_keccak_input_output),
keccak_keccak_input_output,
Column::single(keccak::registers::reg_step(NUM_ROUNDS - 1)),
),
None,

View File

@ -51,8 +51,8 @@ impl<F: Field> Column<F> {
Self::LinearCombination(v)
}
pub fn le_bits(cs: &[usize]) -> Self {
Self::linear_combination(cs.iter().copied().zip(F::TWO.powers()))
pub fn le_bits<I: IntoIterator<Item = usize>>(cs: I) -> Self {
Self::linear_combination(cs.into_iter().zip(F::TWO.powers()))
}
pub fn sum(cs: &[usize]) -> Self {

View File

@ -62,7 +62,6 @@ impl<F: RichField + Extendable<D>, const D: usize> KeccakStark<F, D> {
fn generate_trace_rows_for_perm(&self, input: [u64; NUM_INPUTS]) -> Vec<[F; NUM_REGISTERS]> {
let mut rows = vec![[F::ZERO; NUM_REGISTERS]; NUM_ROUNDS];
self.copy_input(input, &mut rows[0]);
for x in 0..5 {
for y in 0..5 {
let input_xy = input[x * 5 + y];
@ -74,7 +73,6 @@ impl<F: RichField + Extendable<D>, const D: usize> KeccakStark<F, D> {
self.generate_trace_row_for_round(&mut rows[0], 0);
for round in 1..24 {
self.copy_input(input, &mut rows[round]);
self.copy_output_to_input(rows[round - 1], &mut rows[round]);
self.generate_trace_row_for_round(&mut rows[round], round);
}
@ -187,14 +185,6 @@ impl<F: RichField + Extendable<D>, const D: usize> KeccakStark<F, D> {
row[out_reg_hi] = F::from_canonical_u64(row[in_reg_hi].to_canonical_u64() ^ rc_hi);
}
fn copy_input(&self, input: [u64; NUM_INPUTS], row: &mut [F; NUM_REGISTERS]) {
for i in 0..NUM_INPUTS {
let (low, high) = (input[i] as u32, input[i] >> 32);
row[reg_input_limb(2 * i)] = F::from_canonical_u32(low);
row[reg_input_limb(2 * i + 1)] = F::from_canonical_u64(high);
}
}
pub fn generate_trace(&self, inputs: Vec<[u64; NUM_INPUTS]>) -> Vec<PolynomialValues<F>> {
let mut timing = TimingTree::new("generate trace", log::Level::Debug);
@ -230,23 +220,6 @@ impl<F: RichField + Extendable<D>, const D: usize> Stark<F, D> for KeccakStark<F
{
eval_round_flags(vars, yield_constr);
for i in 0..2 * NUM_INPUTS {
let local_input_limb = vars.local_values[reg_input_limb(i)];
let next_input_limb = vars.next_values[reg_input_limb(i)];
let is_last_round = vars.local_values[reg_step(NUM_ROUNDS - 1)];
// Constrain the input registers to be equal throughout the rounds of a permutation.
yield_constr.constraint_transition(
(P::ONES - is_last_round) * (next_input_limb - local_input_limb),
);
// Verify that the bit decomposition is done correctly.
let range = if i % 2 == 0 { 0..32 } else { 32..64 };
let bits = range.map(|j| vars.local_values[reg_a((i / 2) / 5, (i / 2) % 5, j)]);
let expected_input_limb = bits.rev().fold(P::ZEROS, |acc, b| acc.doubles() + b);
let is_first_round = vars.local_values[reg_step(0)];
yield_constr.constraint(is_first_round * (local_input_limb - expected_input_limb));
}
// C_partial[x] = xor(A[x, 0], A[x, 1], A[x, 2])
for x in 0..5 {
for z in 0..64 {
@ -390,25 +363,6 @@ impl<F: RichField + Extendable<D>, const D: usize> Stark<F, D> for KeccakStark<F
eval_round_flags_recursively(builder, vars, yield_constr);
for i in 0..2 * NUM_INPUTS {
let local_input_limb = vars.local_values[reg_input_limb(i)];
let next_input_limb = vars.next_values[reg_input_limb(i)];
let is_last_round = vars.local_values[reg_step(NUM_ROUNDS - 1)];
let diff = builder.sub_extension(local_input_limb, next_input_limb);
let constraint = builder.mul_sub_extension(is_last_round, diff, diff);
yield_constr.constraint_transition(builder, constraint);
let range = if i % 2 == 0 { 0..32 } else { 32..64 };
let bits = range
.map(|j| vars.local_values[reg_a((i / 2) / 5, (i / 2) % 5, j)])
.collect::<Vec<_>>();
let expected_input_limb = reduce_with_powers_ext_circuit(builder, &bits, two);
let is_first_round = vars.local_values[reg_step(0)];
let diff = builder.sub_extension(local_input_limb, expected_input_limb);
let constraint = builder.mul_extension(is_first_round, diff);
yield_constr.constraint(builder, constraint);
}
// C_partial[x] = xor(A[x, 0], A[x, 1], A[x, 2])
for x in 0..5 {
for z in 0..64 {

View File

@ -1,3 +1,6 @@
use plonky2::field::field_types::Field;
use crate::cross_table_lookup::Column;
use crate::keccak::keccak_stark::{NUM_INPUTS, NUM_ROUNDS};
/// A register which is set to 1 if we are in the `i`th round, otherwise 0.
@ -9,9 +12,11 @@ pub const fn reg_step(i: usize) -> usize {
/// Registers to hold permutation inputs.
/// `reg_input_limb(2*i) -> input[i] as u32`
/// `reg_input_limb(2*i+1) -> input[i] >> 32`
pub const fn reg_input_limb(i: usize) -> usize {
pub fn reg_input_limb<F: Field>(i: usize) -> Column<F> {
debug_assert!(i < 2 * NUM_INPUTS);
NUM_ROUNDS + i
let range = if i % 2 == 0 { 0..32 } else { 32..64 };
let bits = range.map(|j| reg_a((i / 2) / 5, (i / 2) % 5, j));
Column::le_bits(bits)
}
/// Registers to hold permutation outputs.
@ -37,7 +42,7 @@ const R: [[u8; 5]; 5] = [
[27, 20, 39, 8, 14],
];
const START_A: usize = NUM_ROUNDS + 2 * NUM_INPUTS;
const START_A: usize = NUM_ROUNDS;
pub(crate) const fn reg_a(x: usize, y: usize, z: usize) -> usize {
debug_assert!(x < 5);
debug_assert!(y < 5);