This commit is contained in:
Nicholas Ward 2022-06-01 09:19:23 -07:00
parent 60c0b4ee79
commit 69aed6586a
2 changed files with 56 additions and 35 deletions

View File

@ -16,7 +16,7 @@ use crate::keccak::logic::{
};
use crate::keccak::registers::{
rc_value, rc_value_bit, reg_a, reg_a_prime, reg_a_prime_prime, reg_a_prime_prime_0_0_bit,
reg_a_prime_prime_prime, reg_b, reg_c, reg_c_partial, reg_step, NUM_REGISTERS,
reg_a_prime_prime_prime, reg_b, reg_c, reg_c_partial, reg_step, NUM_REGISTERS, reg_dummy,
};
use crate::keccak::round_flags::{eval_round_flags, eval_round_flags_recursively};
use crate::stark::Stark;
@ -54,6 +54,7 @@ impl<F: RichField + Extendable<D>, const D: usize> KeccakStark<F, D> {
for i in rows.len()..num_rows {
let mut row = [F::ZERO; NUM_REGISTERS];
self.generate_trace_rows_for_round(&mut row, i % NUM_ROUNDS);
row[reg_dummy()] = F::ONE;
rows.push(row);
}
@ -74,29 +75,34 @@ impl<F: RichField + Extendable<D>, const D: usize> KeccakStark<F, D> {
self.generate_trace_rows_for_round(&mut rows[0], 0);
for round in 1..24 {
for x in 0..5 {
for y in 0..5 {
let cur = rows[round - 1][reg_a_prime_prime_prime(x, y)];
let cur_u64 = cur.to_canonical_u64();
let bit_values: Vec<u64> = (0..64)
.scan(cur_u64, |acc, _| {
let tmp = *acc % 2;
*acc /= 2;
Some(tmp)
})
.collect();
for z in 0..64 {
rows[round][reg_a(x, y, z)] = F::from_canonical_u64(bit_values[z]);
}
}
}
self.copy_output_to_input(rows[round - 1], &mut rows[round]);
self.generate_trace_rows_for_round(&mut rows[round], round);
}
rows
}
fn copy_output_to_input(&self, prev_row: [F; NUM_REGISTERS], next_row: &mut [F; NUM_REGISTERS]) {
for x in 0..5 {
for y in 0..5 {
let cur_lo = prev_row[reg_a_prime_prime_prime(x, y)];
let cur_hi = prev_row[reg_a_prime_prime_prime(x, y) + 1];
let cur_u64 = cur_lo.to_canonical_u64() + (1 << 32) * cur_hi.to_canonical_u64();
let bit_values: Vec<u64> = (0..64)
.scan(cur_u64, |acc, _| {
let tmp = *acc % 2;
*acc /= 2;
Some(tmp)
})
.collect();
for z in 0..64 {
next_row[reg_a(x, y, z)] = F::from_canonical_u64(bit_values[z]);
}
}
}
}
fn generate_trace_rows_for_round(&self, row: &mut [F; NUM_REGISTERS], round: usize) {
row[reg_step(round)] = F::ONE;
@ -323,14 +329,20 @@ impl<F: RichField + Extendable<D>, const D: usize> Stark<F, D> for KeccakStark<F
// Enforce that this round's output equals the next round's input.
for x in 0..5 {
for y in 0..5 {
let output = vars.local_values[reg_a_prime_prime_prime(x, y)];
let output_lo = vars.local_values[reg_a_prime_prime_prime(x, y)];
let output_hi = vars.local_values[reg_a_prime_prime_prime(x, y) + 1];
let input_bits = (0..64)
.map(|z| vars.next_values[reg_a(x, y, z)])
.collect_vec();
let input_bits_combined = (0..64)
let input_bits_combined_lo = (0..32)
.rev()
.fold(P::ZEROS, |acc, z| acc.doubles() + input_bits[z]);
// yield_constr.constraint(output - input_bits_combined);
let input_bits_combined_hi = (32..64)
.rev()
.fold(P::ZEROS, |acc, z| acc.doubles() + input_bits[z]);
let dummy = vars.next_values[reg_dummy()];
yield_constr.constraint_transition((P::ONES - dummy) * (output_lo - input_bits_combined_lo));
yield_constr.constraint_transition((P::ONES - dummy) * (output_hi - input_bits_combined_hi));
}
}
}
@ -355,7 +367,7 @@ impl<F: RichField + Extendable<D>, const D: usize> Stark<F, D> for KeccakStark<F
let xor_012 = xor3_gen_circuit(builder, a_0, a_1, a_2);
let diff = builder.sub_extension(c_partial, xor_012);
// yield_constr.constraint(builder, diff);
yield_constr.constraint(builder, diff);
}
}
@ -369,7 +381,7 @@ impl<F: RichField + Extendable<D>, const D: usize> Stark<F, D> for KeccakStark<F
let xor_01234 = xor3_gen_circuit(builder, xor_012, a_3, a_4);
let diff = builder.sub_extension(c, xor_01234);
// yield_constr.constraint(builder, diff);
yield_constr.constraint(builder, diff);
}
}
@ -386,7 +398,7 @@ impl<F: RichField + Extendable<D>, const D: usize> Stark<F, D> for KeccakStark<F
let a_prime = vars.local_values[reg_a_prime(x, y, z)];
let xor = xor_gen_circuit(builder, d, a);
let diff = builder.sub_extension(a_prime, xor);
// yield_constr.constraint(builder, diff);
yield_constr.constraint(builder, diff);
}
}
}
@ -412,9 +424,9 @@ impl<F: RichField + Extendable<D>, const D: usize> Stark<F, D> for KeccakStark<F
let computed_lo = reduce_with_powers_ext_circuit(builder, &bits_lo, two);
let computed_hi = reduce_with_powers_ext_circuit(builder, &bits_hi, two);
let diff = builder.sub_extension(computed_lo, lo);
// yield_constr.constraint(builder, diff);
yield_constr.constraint(builder, diff);
let diff = builder.sub_extension(computed_hi, hi);
// yield_constr.constraint(builder, diff);
yield_constr.constraint(builder, diff);
}
}
@ -429,9 +441,9 @@ impl<F: RichField + Extendable<D>, const D: usize> Stark<F, D> for KeccakStark<F
let a_prime_prime_0_0_lo = vars.local_values[reg_a_prime_prime(0, 0)];
let a_prime_prime_0_0_hi = vars.local_values[reg_a_prime_prime(0, 0) + 1];
let diff = builder.sub_extension(computed_a_prime_prime_0_0_lo, a_prime_prime_0_0_lo);
// yield_constr.constraint(builder, diff);
yield_constr.constraint(builder, diff);
let diff = builder.sub_extension(computed_a_prime_prime_0_0_hi, a_prime_prime_0_0_hi);
// yield_constr.constraint(builder, diff);
yield_constr.constraint(builder, diff);
let mut get_xored_bit = |i| {
let mut rc_bit_i = builder.zero_extension();
@ -457,23 +469,27 @@ impl<F: RichField + Extendable<D>, const D: usize> Stark<F, D> for KeccakStark<F
computed_a_prime_prime_prime_0_0_lo,
a_prime_prime_prime_0_0_lo,
);
// yield_constr.constraint(builder, diff);
yield_constr.constraint(builder, diff);
let diff = builder.sub_extension(
computed_a_prime_prime_prime_0_0_hi,
a_prime_prime_prime_0_0_hi,
);
// yield_constr.constraint(builder, diff);
yield_constr.constraint(builder, diff);
// Enforce that this round's output equals the next round's input.
for x in 0..5 {
for y in 0..5 {
let output = vars.local_values[reg_a_prime_prime_prime(x, y)];
let output_lo = vars.local_values[reg_a_prime_prime_prime(x, y)];
let output_hi = vars.local_values[reg_a_prime_prime_prime(x, y) + 1];
let input_bits = (0..64)
.map(|z| vars.next_values[reg_a(x, y, z)])
.collect_vec();
let input_bits_combined = reduce_with_powers_ext_circuit(builder, &input_bits, two);
let diff = builder.sub_extension(output, input_bits_combined);
// yield_constr.constraint(builder, diff);
let input_bits_combined_lo = reduce_with_powers_ext_circuit(builder, &input_bits[0..32], two);
let input_bits_combined_hi = reduce_with_powers_ext_circuit(builder, &input_bits[32..64], two);
let diff = builder.sub_extension(output_lo, input_bits_combined_lo);
yield_constr.constraint(builder, diff);
let diff = builder.sub_extension(output_hi, input_bits_combined_hi);
yield_constr.constraint(builder, diff);
}
}
}

View File

@ -6,6 +6,11 @@ pub(crate) const fn reg_step(i: usize) -> usize {
i
}
/// A register which is set to 1 if we are in the `i`th round, otherwise 0.
pub(crate) const fn reg_dummy() -> usize {
NUM_ROUNDS
}
const R: [[u8; 5]; 5] = [
[0, 18, 41, 3, 36],
[1, 2, 45, 10, 44],
@ -172,7 +177,7 @@ pub(crate) const fn rc_value(round: usize) -> u64 {
RC[round]
}
const START_A: usize = NUM_ROUNDS;
const START_A: usize = NUM_ROUNDS + 1;
pub(crate) const fn reg_a(x: usize, y: usize, z: usize) -> usize {
debug_assert!(x < 5);
debug_assert!(y < 5);