Add Keccak input registers

This commit is contained in:
wborgeaud 2022-06-09 22:31:33 +02:00
parent 10ac355d06
commit 1cc38bb032
2 changed files with 60 additions and 13 deletions

View File

@ -17,7 +17,7 @@ use crate::keccak::logic::{
};
use crate::keccak::registers::{
reg_a, reg_a_prime, reg_a_prime_prime, reg_a_prime_prime_0_0_bit, reg_a_prime_prime_prime,
reg_b, reg_c, reg_c_partial, reg_step, NUM_REGISTERS,
reg_b, reg_c, reg_c_partial, reg_input_limb, reg_step, NUM_REGISTERS,
};
use crate::keccak::round_flags::{eval_round_flags, eval_round_flags_recursively};
use crate::stark::Stark;
@ -65,6 +65,7 @@ impl<F: RichField + Extendable<D>, const D: usize> KeccakStark<F, D> {
fn generate_trace_rows_for_perm(&self, input: [u64; INPUT_LIMBS]) -> Vec<[F; NUM_REGISTERS]> {
let mut rows = vec![[F::ZERO; NUM_REGISTERS]; NUM_ROUNDS];
self.copy_input(input, &mut rows[0]);
for x in 0..5 {
for y in 0..5 {
let input_xy = input[x * 5 + y];
@ -76,6 +77,7 @@ impl<F: RichField + Extendable<D>, const D: usize> KeccakStark<F, D> {
self.generate_trace_row_for_round(&mut rows[0], 0);
for round in 1..24 {
self.copy_input(input, &mut rows[round]);
self.copy_output_to_input(rows[round - 1], &mut rows[round]);
self.generate_trace_row_for_round(&mut rows[round], round);
}
@ -188,6 +190,14 @@ impl<F: RichField + Extendable<D>, const D: usize> KeccakStark<F, D> {
row[out_reg_hi] = F::from_canonical_u64(row[in_reg_hi].to_canonical_u64() ^ rc_hi);
}
fn copy_input(&self, input: [u64; INPUT_LIMBS], row: &mut [F; NUM_REGISTERS]) {
for i in 0..INPUT_LIMBS {
let (low, high) = (input[i] as u32, input[i] >> 32);
row[reg_input_limb(2 * i)] = F::from_canonical_u32(low);
row[reg_input_limb(2 * i + 1)] = F::from_canonical_u64(high);
}
}
pub fn generate_trace(&self, inputs: Vec<[u64; INPUT_LIMBS]>) -> Vec<PolynomialValues<F>> {
let mut timing = TimingTree::new("generate trace", log::Level::Debug);
@ -223,6 +233,16 @@ impl<F: RichField + Extendable<D>, const D: usize> Stark<F, D> for KeccakStark<F
{
eval_round_flags(vars, yield_constr);
// Constraint the input registers to be equal throughout the rounds of a permutation.
for i in 0..2 * INPUT_LIMBS {
let local_input_limb = vars.local_values[reg_input_limb(i)];
let next_input_limb = vars.next_values[reg_input_limb(i)];
let is_last_round = vars.local_values[reg_step(NUM_ROUNDS - 1)];
yield_constr.constraint_transition(
(P::ONES - is_last_round) * (next_input_limb - local_input_limb),
);
}
// C_partial[x] = xor(A[x, 0], A[x, 1], A[x, 2])
for x in 0..5 {
for z in 0..64 {
@ -361,6 +381,15 @@ impl<F: RichField + Extendable<D>, const D: usize> Stark<F, D> for KeccakStark<F
eval_round_flags_recursively(builder, vars, yield_constr);
for i in 0..2 * INPUT_LIMBS {
let local_input_limb = vars.local_values[reg_input_limb(i)];
let next_input_limb = vars.next_values[reg_input_limb(i)];
let is_last_round = vars.local_values[reg_step(NUM_ROUNDS - 1)];
let diff = builder.sub_extension(local_input_limb, next_input_limb);
let constraint = builder.mul_sub_extension(is_last_round, diff, diff);
yield_constr.constraint_transition(builder, constraint);
}
// C_partial[x] = xor(A[x, 0], A[x, 1], A[x, 2])
for x in 0..5 {
for z in 0..64 {
@ -513,7 +542,7 @@ mod tests {
use plonky2::plonk::config::{GenericConfig, PoseidonGoldilocksConfig};
use crate::keccak::keccak_stark::{KeccakStark, INPUT_LIMBS, NUM_ROUNDS};
use crate::keccak::registers::reg_a_prime_prime_prime;
use crate::keccak::registers::reg_output_limb;
use crate::stark_testing::{test_stark_circuit_constraints, test_stark_low_degree};
#[test]
@ -557,16 +586,10 @@ mod tests {
let rows = stark.generate_trace_rows(vec![input.try_into().unwrap()]);
let last_row = rows[NUM_ROUNDS - 1];
let mut output = Vec::new();
let base = F::from_canonical_u64(1 << 32);
for x in 0..5 {
for y in 0..5 {
output.push(
last_row[reg_a_prime_prime_prime(x, y)]
+ base * last_row[reg_a_prime_prime_prime(x, y) + 1],
);
}
}
let output = (0..INPUT_LIMBS)
.map(|i| last_row[reg_output_limb(2 * i)] + base * last_row[reg_output_limb(2 * i + 1)])
.collect::<Vec<_>>();
let mut keccak_input: [[u64; 5]; 5] = [
input[0..5].try_into().unwrap(),

View File

@ -1,4 +1,4 @@
use crate::keccak::keccak_stark::NUM_ROUNDS;
use crate::keccak::keccak_stark::{INPUT_LIMBS, NUM_ROUNDS};
/// A register which is set to 1 if we are in the `i`th round, otherwise 0.
pub(crate) const fn reg_step(i: usize) -> usize {
@ -6,6 +6,30 @@ pub(crate) const fn reg_step(i: usize) -> usize {
i
}
/// Registers to hold permutation inputs.
/// `reg_input_limb(2*i) -> input[i] as u32`
/// `reg_input_limb(2*i+1) -> input[i] >> 32`
pub(crate) const fn reg_input_limb(i: usize) -> usize {
debug_assert!(i < 2 * INPUT_LIMBS);
NUM_ROUNDS + i
}
/// Registers to hold permutation outputs.
/// `reg_output_limb(2*i) -> output[i] as u32`
/// `reg_output_limb(2*i+1) -> output[i] >> 32`
#[allow(dead_code)] // TODO: Remove once it is used.
pub(crate) const fn reg_output_limb(i: usize) -> usize {
debug_assert!(i < 2 * INPUT_LIMBS);
let ii = i / 2;
let x = ii / 5;
let y = ii % 5;
if i % 2 == 0 {
reg_a_prime_prime_prime(x, y)
} else {
reg_a_prime_prime_prime(x, y) + 1
}
}
const R: [[u8; 5]; 5] = [
[0, 36, 3, 41, 18],
[1, 44, 10, 45, 2],
@ -14,7 +38,7 @@ const R: [[u8; 5]; 5] = [
[27, 20, 39, 8, 14],
];
const START_A: usize = NUM_ROUNDS;
const START_A: usize = NUM_ROUNDS + 2 * INPUT_LIMBS;
pub(crate) const fn reg_a(x: usize, y: usize, z: usize) -> usize {
debug_assert!(x < 5);
debug_assert!(y < 5);