Remove reg_preimage columns in KeccakStark (#1279)

* Remove reg_preimage columns in KeccakStark

* Apply comments

* Minor cleanup
This commit is contained in:
Linda Guiga 2023-10-06 15:49:57 -04:00 committed by GitHub
parent 0de6f94962
commit e58d7795f8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 98 additions and 95 deletions

View File

@ -96,7 +96,8 @@ pub(crate) fn all_cross_table_lookups<F: Field>() -> Vec<CrossTableLookup<F>> {
ctl_arithmetic(), ctl_arithmetic(),
ctl_byte_packing(), ctl_byte_packing(),
ctl_keccak_sponge(), ctl_keccak_sponge(),
ctl_keccak(), ctl_keccak_inputs(),
ctl_keccak_outputs(),
ctl_logic(), ctl_logic(),
ctl_memory(), ctl_memory(),
] ]
@ -131,16 +132,33 @@ fn ctl_byte_packing<F: Field>() -> CrossTableLookup<F> {
) )
} }
fn ctl_keccak<F: Field>() -> CrossTableLookup<F> { // We now need two different looked tables for `KeccakStark`:
// one for the inputs and one for the outputs.
// They are linked with the timestamp.
fn ctl_keccak_inputs<F: Field>() -> CrossTableLookup<F> {
let keccak_sponge_looking = TableWithColumns::new( let keccak_sponge_looking = TableWithColumns::new(
Table::KeccakSponge, Table::KeccakSponge,
keccak_sponge_stark::ctl_looking_keccak(), keccak_sponge_stark::ctl_looking_keccak_inputs(),
Some(keccak_sponge_stark::ctl_looking_keccak_filter()), Some(keccak_sponge_stark::ctl_looking_keccak_filter()),
); );
let keccak_looked = TableWithColumns::new( let keccak_looked = TableWithColumns::new(
Table::Keccak, Table::Keccak,
keccak_stark::ctl_data(), keccak_stark::ctl_data_inputs(),
Some(keccak_stark::ctl_filter()), Some(keccak_stark::ctl_filter_inputs()),
);
CrossTableLookup::new(vec![keccak_sponge_looking], keccak_looked)
}
fn ctl_keccak_outputs<F: Field>() -> CrossTableLookup<F> {
let keccak_sponge_looking = TableWithColumns::new(
Table::KeccakSponge,
keccak_sponge_stark::ctl_looking_keccak_outputs(),
Some(keccak_sponge_stark::ctl_looking_keccak_filter()),
);
let keccak_looked = TableWithColumns::new(
Table::Keccak,
keccak_stark::ctl_data_outputs(),
Some(keccak_stark::ctl_filter_outputs()),
); );
CrossTableLookup::new(vec![keccak_sponge_looking], keccak_looked) CrossTableLookup::new(vec![keccak_sponge_looking], keccak_looked)
} }

View File

@ -20,7 +20,7 @@ pub fn reg_input_limb<F: Field>(i: usize) -> Column<F> {
let y = i_u64 / 5; let y = i_u64 / 5;
let x = i_u64 % 5; let x = i_u64 % 5;
let reg_low_limb = reg_preimage(x, y); let reg_low_limb = reg_a(x, y);
let is_high_limb = i % 2; let is_high_limb = i % 2;
Column::single(reg_low_limb + is_high_limb) Column::single(reg_low_limb + is_high_limb)
} }
@ -48,15 +48,11 @@ const R: [[u8; 5]; 5] = [
[27, 20, 39, 8, 14], [27, 20, 39, 8, 14],
]; ];
const START_PREIMAGE: usize = NUM_ROUNDS; /// Column holding the timestamp, used to link inputs and outputs
/// Registers to hold the original input to a permutation, i.e. the input to the first round. /// in the `KeccakSpongeStark`.
pub(crate) const fn reg_preimage(x: usize, y: usize) -> usize { pub(crate) const TIMESTAMP: usize = NUM_ROUNDS;
debug_assert!(x < 5);
debug_assert!(y < 5);
START_PREIMAGE + (x * 5 + y) * 2
}
const START_A: usize = START_PREIMAGE + 5 * 5 * 2; const START_A: usize = TIMESTAMP + 1;
pub(crate) const fn reg_a(x: usize, y: usize) -> usize { pub(crate) const fn reg_a(x: usize, y: usize) -> usize {
debug_assert!(x < 5); debug_assert!(x < 5);
debug_assert!(y < 5); debug_assert!(y < 5);

View File

@ -11,13 +11,13 @@ use plonky2::plonk::plonk_common::reduce_with_powers_ext_circuit;
use plonky2::timed; use plonky2::timed;
use plonky2::util::timing::TimingTree; use plonky2::util::timing::TimingTree;
use super::columns::reg_input_limb;
use crate::constraint_consumer::{ConstraintConsumer, RecursiveConstraintConsumer}; use crate::constraint_consumer::{ConstraintConsumer, RecursiveConstraintConsumer};
use crate::cross_table_lookup::Column; use crate::cross_table_lookup::Column;
use crate::evaluation_frame::{StarkEvaluationFrame, StarkFrame}; use crate::evaluation_frame::{StarkEvaluationFrame, StarkFrame};
use crate::keccak::columns::{ use crate::keccak::columns::{
reg_a, reg_a_prime, reg_a_prime_prime, reg_a_prime_prime_0_0_bit, reg_a_prime_prime_prime, reg_a, reg_a_prime, reg_a_prime_prime, reg_a_prime_prime_0_0_bit, reg_a_prime_prime_prime,
reg_b, reg_c, reg_c_prime, reg_input_limb, reg_output_limb, reg_preimage, reg_step, reg_b, reg_c, reg_c_prime, reg_output_limb, reg_step, NUM_COLUMNS, TIMESTAMP,
NUM_COLUMNS,
}; };
use crate::keccak::constants::{rc_value, rc_value_bit}; use crate::keccak::constants::{rc_value, rc_value_bit};
use crate::keccak::logic::{ use crate::keccak::logic::{
@ -33,13 +33,23 @@ pub(crate) const NUM_ROUNDS: usize = 24;
/// Number of 64-bit elements in the Keccak permutation input. /// Number of 64-bit elements in the Keccak permutation input.
pub(crate) const NUM_INPUTS: usize = 25; pub(crate) const NUM_INPUTS: usize = 25;
pub fn ctl_data<F: Field>() -> Vec<Column<F>> { pub fn ctl_data_inputs<F: Field>() -> Vec<Column<F>> {
let mut res: Vec<_> = (0..2 * NUM_INPUTS).map(reg_input_limb).collect(); let mut res: Vec<_> = (0..2 * NUM_INPUTS).map(reg_input_limb).collect();
res.extend(Column::singles((0..2 * NUM_INPUTS).map(reg_output_limb))); res.push(Column::single(TIMESTAMP));
res res
} }
pub fn ctl_filter<F: Field>() -> Column<F> { pub fn ctl_data_outputs<F: Field>() -> Vec<Column<F>> {
let mut res: Vec<_> = Column::singles((0..2 * NUM_INPUTS).map(reg_output_limb)).collect();
res.push(Column::single(TIMESTAMP));
res
}
pub fn ctl_filter_inputs<F: Field>() -> Column<F> {
Column::single(reg_step(0))
}
pub fn ctl_filter_outputs<F: Field>() -> Column<F> {
Column::single(reg_step(NUM_ROUNDS - 1)) Column::single(reg_step(NUM_ROUNDS - 1))
} }
@ -53,16 +63,16 @@ impl<F: RichField + Extendable<D>, const D: usize> KeccakStark<F, D> {
/// in our lookup arguments, as those are computed after transposing to column-wise form. /// in our lookup arguments, as those are computed after transposing to column-wise form.
fn generate_trace_rows( fn generate_trace_rows(
&self, &self,
inputs: Vec<[u64; NUM_INPUTS]>, inputs_and_timestamps: Vec<([u64; NUM_INPUTS], usize)>,
min_rows: usize, min_rows: usize,
) -> Vec<[F; NUM_COLUMNS]> { ) -> Vec<[F; NUM_COLUMNS]> {
let num_rows = (inputs.len() * NUM_ROUNDS) let num_rows = (inputs_and_timestamps.len() * NUM_ROUNDS)
.max(min_rows) .max(min_rows)
.next_power_of_two(); .next_power_of_two();
let mut rows = Vec::with_capacity(num_rows); let mut rows = Vec::with_capacity(num_rows);
for input in inputs.iter() { for input_and_timestamp in inputs_and_timestamps.iter() {
let rows_for_perm = self.generate_trace_rows_for_perm(*input); let rows_for_perm = self.generate_trace_rows_for_perm(*input_and_timestamp);
rows.extend(rows_for_perm); rows.extend(rows_for_perm);
} }
@ -72,20 +82,19 @@ impl<F: RichField + Extendable<D>, const D: usize> KeccakStark<F, D> {
rows rows
} }
fn generate_trace_rows_for_perm(&self, input: [u64; NUM_INPUTS]) -> Vec<[F; NUM_COLUMNS]> { fn generate_trace_rows_for_perm(
&self,
input_and_timestamp: ([u64; NUM_INPUTS], usize),
) -> Vec<[F; NUM_COLUMNS]> {
let mut rows = vec![[F::ZERO; NUM_COLUMNS]; NUM_ROUNDS]; let mut rows = vec![[F::ZERO; NUM_COLUMNS]; NUM_ROUNDS];
let input = input_and_timestamp.0;
// Populate the preimage for each row. let timestamp = input_and_timestamp.1;
// Set the timestamp of the current input.
// It will be checked against the value in `KeccakSponge`.
// The timestamp is used to link the input and output of
// the same permutation together.
for round in 0..24 { for round in 0..24 {
for x in 0..5 { rows[round][TIMESTAMP] = F::from_canonical_usize(timestamp);
for y in 0..5 {
let input_xy = input[y * 5 + x];
let reg_preimage_lo = reg_preimage(x, y);
let reg_preimage_hi = reg_preimage_lo + 1;
rows[round][reg_preimage_lo] = F::from_canonical_u64(input_xy & 0xFFFFFFFF);
rows[round][reg_preimage_hi] = F::from_canonical_u64(input_xy >> 32);
}
}
} }
// Populate the round input for the first round. // Populate the round input for the first round.
@ -220,7 +229,7 @@ impl<F: RichField + Extendable<D>, const D: usize> KeccakStark<F, D> {
pub fn generate_trace( pub fn generate_trace(
&self, &self,
inputs: Vec<[u64; NUM_INPUTS]>, inputs: Vec<([u64; NUM_INPUTS], usize)>,
min_rows: usize, min_rows: usize,
timing: &mut TimingTree, timing: &mut TimingTree,
) -> Vec<PolynomialValues<F>> { ) -> Vec<PolynomialValues<F>> {
@ -269,26 +278,14 @@ impl<F: RichField + Extendable<D>, const D: usize> Stark<F, D> for KeccakStark<F
let not_final_step = P::ONES - final_step; let not_final_step = P::ONES - final_step;
yield_constr.constraint(not_final_step * filter); yield_constr.constraint(not_final_step * filter);
// If this is not the final step, the local and next preimages must match. // If this is not the final step or a padding row,
// Also, if this is the first step, the preimage must match A. // the local and next timestamps must match.
let is_first_step = local_values[reg_step(0)]; let sum_round_flags = (0..NUM_ROUNDS)
for x in 0..5 { .map(|i| local_values[reg_step(i)])
for y in 0..5 { .sum::<P>();
let reg_preimage_lo = reg_preimage(x, y); yield_constr.constraint(
let reg_preimage_hi = reg_preimage_lo + 1; sum_round_flags * not_final_step * (next_values[TIMESTAMP] - local_values[TIMESTAMP]),
let diff_lo = local_values[reg_preimage_lo] - next_values[reg_preimage_lo]; );
let diff_hi = local_values[reg_preimage_hi] - next_values[reg_preimage_hi];
yield_constr.constraint_transition(not_final_step * diff_lo);
yield_constr.constraint_transition(not_final_step * diff_hi);
let reg_a_lo = reg_a(x, y);
let reg_a_hi = reg_a_lo + 1;
let diff_lo = local_values[reg_preimage_lo] - local_values[reg_a_lo];
let diff_hi = local_values[reg_preimage_hi] - local_values[reg_a_hi];
yield_constr.constraint(is_first_step * diff_lo);
yield_constr.constraint(is_first_step * diff_hi);
}
}
// C'[x, z] = xor(C[x, z], C[x - 1, z], C[x + 1, z - 1]). // C'[x, z] = xor(C[x, z], C[x - 1, z], C[x + 1, z - 1]).
for x in 0..5 { for x in 0..5 {
@ -454,34 +451,13 @@ impl<F: RichField + Extendable<D>, const D: usize> Stark<F, D> for KeccakStark<F
let constraint = builder.mul_extension(not_final_step, filter); let constraint = builder.mul_extension(not_final_step, filter);
yield_constr.constraint(builder, constraint); yield_constr.constraint(builder, constraint);
// If this is not the final step, the local and next preimages must match. // If this is not the final step or a padding row,
// Also, if this is the first step, the preimage must match A. // the local and next timestamps must match.
let is_first_step = local_values[reg_step(0)]; let sum_round_flags =
for x in 0..5 { builder.add_many_extension((0..NUM_ROUNDS).map(|i| local_values[reg_step(i)]));
for y in 0..5 { let diff = builder.sub_extension(next_values[TIMESTAMP], local_values[TIMESTAMP]);
let reg_preimage_lo = reg_preimage(x, y); let constr = builder.mul_many_extension([sum_round_flags, not_final_step, diff]);
let reg_preimage_hi = reg_preimage_lo + 1; yield_constr.constraint(builder, constr);
let diff = builder
.sub_extension(local_values[reg_preimage_lo], next_values[reg_preimage_lo]);
let constraint = builder.mul_extension(not_final_step, diff);
yield_constr.constraint_transition(builder, constraint);
let diff = builder
.sub_extension(local_values[reg_preimage_hi], next_values[reg_preimage_hi]);
let constraint = builder.mul_extension(not_final_step, diff);
yield_constr.constraint_transition(builder, constraint);
let reg_a_lo = reg_a(x, y);
let reg_a_hi = reg_a_lo + 1;
let diff_lo =
builder.sub_extension(local_values[reg_preimage_lo], local_values[reg_a_lo]);
let constraint = builder.mul_extension(is_first_step, diff_lo);
yield_constr.constraint(builder, constraint);
let diff_hi =
builder.sub_extension(local_values[reg_preimage_hi], local_values[reg_a_hi]);
let constraint = builder.mul_extension(is_first_step, diff_hi);
yield_constr.constraint(builder, constraint);
}
}
// C'[x, z] = xor(C[x, z], C[x - 1, z], C[x + 1, z - 1]). // C'[x, z] = xor(C[x, z], C[x - 1, z], C[x + 1, z - 1]).
for x in 0..5 { for x in 0..5 {
@ -699,7 +675,7 @@ mod tests {
f: Default::default(), f: Default::default(),
}; };
let rows = stark.generate_trace_rows(vec![input], 8); let rows = stark.generate_trace_rows(vec![(input, 0)], 8);
let last_row = rows[NUM_ROUNDS - 1]; let last_row = rows[NUM_ROUNDS - 1];
let output = (0..NUM_INPUTS) let output = (0..NUM_INPUTS)
.map(|i| { .map(|i| {
@ -732,7 +708,8 @@ mod tests {
init_logger(); init_logger();
let input: Vec<[u64; NUM_INPUTS]> = (0..NUM_PERMS).map(|_| rand::random()).collect(); let input: Vec<([u64; NUM_INPUTS], usize)> =
(0..NUM_PERMS).map(|_| (rand::random(), 0)).collect();
let mut timing = TimingTree::new("prove", log::Level::Debug); let mut timing = TimingTree::new("prove", log::Level::Debug);
let trace_poly_values = timed!( let trace_poly_values = timed!(

View File

@ -47,7 +47,7 @@ pub(crate) fn ctl_looked_data<F: Field>() -> Vec<Column<F>> {
.collect() .collect()
} }
pub(crate) fn ctl_looking_keccak<F: Field>() -> Vec<Column<F>> { pub(crate) fn ctl_looking_keccak_inputs<F: Field>() -> Vec<Column<F>> {
let cols = KECCAK_SPONGE_COL_MAP; let cols = KECCAK_SPONGE_COL_MAP;
let mut res: Vec<_> = Column::singles( let mut res: Vec<_> = Column::singles(
[ [
@ -57,6 +57,13 @@ pub(crate) fn ctl_looking_keccak<F: Field>() -> Vec<Column<F>> {
.concat(), .concat(),
) )
.collect(); .collect();
res.push(Column::single(cols.timestamp));
res
}
pub(crate) fn ctl_looking_keccak_outputs<F: Field>() -> Vec<Column<F>> {
let cols = KECCAK_SPONGE_COL_MAP;
// We recover the 32-bit digest limbs from their corresponding bytes, // We recover the 32-bit digest limbs from their corresponding bytes,
// and then append them to the rest of the updated state limbs. // and then append them to the rest of the updated state limbs.
@ -68,9 +75,10 @@ pub(crate) fn ctl_looking_keccak<F: Field>() -> Vec<Column<F>> {
) )
}); });
res.extend(digest_u32s); let mut res: Vec<_> = digest_u32s.collect();
res.extend(Column::singles(&cols.partial_updated_state_u32s)); res.extend(Column::singles(&cols.partial_updated_state_u32s));
res.push(Column::single(cols.timestamp));
res res
} }

View File

@ -36,7 +36,7 @@ pub(crate) struct Traces<T: Copy> {
pub(crate) cpu: Vec<CpuColumnsView<T>>, pub(crate) cpu: Vec<CpuColumnsView<T>>,
pub(crate) logic_ops: Vec<logic::Operation>, pub(crate) logic_ops: Vec<logic::Operation>,
pub(crate) memory_ops: Vec<MemoryOp>, pub(crate) memory_ops: Vec<MemoryOp>,
pub(crate) keccak_inputs: Vec<[u64; keccak::keccak_stark::NUM_INPUTS]>, pub(crate) keccak_inputs: Vec<([u64; keccak::keccak_stark::NUM_INPUTS], usize)>,
pub(crate) keccak_sponge_ops: Vec<KeccakSpongeOp>, pub(crate) keccak_sponge_ops: Vec<KeccakSpongeOp>,
} }
@ -131,18 +131,18 @@ impl<T: Copy> Traces<T> {
self.byte_packing_ops.push(op); self.byte_packing_ops.push(op);
} }
pub fn push_keccak(&mut self, input: [u64; keccak::keccak_stark::NUM_INPUTS]) { pub fn push_keccak(&mut self, input: [u64; keccak::keccak_stark::NUM_INPUTS], clock: usize) {
self.keccak_inputs.push(input); self.keccak_inputs.push((input, clock));
} }
pub fn push_keccak_bytes(&mut self, input: [u8; KECCAK_WIDTH_BYTES]) { pub fn push_keccak_bytes(&mut self, input: [u8; KECCAK_WIDTH_BYTES], clock: usize) {
let chunks = input let chunks = input
.chunks(size_of::<u64>()) .chunks(size_of::<u64>())
.map(|chunk| u64::from_le_bytes(chunk.try_into().unwrap())) .map(|chunk| u64::from_le_bytes(chunk.try_into().unwrap()))
.collect_vec() .collect_vec()
.try_into() .try_into()
.unwrap(); .unwrap();
self.push_keccak(chunks); self.push_keccak(chunks, clock);
} }
pub fn push_keccak_sponge(&mut self, op: KeccakSpongeOp) { pub fn push_keccak_sponge(&mut self, op: KeccakSpongeOp) {

View File

@ -229,7 +229,9 @@ pub(crate) fn keccak_sponge_log<F: Field>(
address.increment(); address.increment();
} }
xor_into_sponge(state, &mut sponge_state, block.try_into().unwrap()); xor_into_sponge(state, &mut sponge_state, block.try_into().unwrap());
state.traces.push_keccak_bytes(sponge_state); state
.traces
.push_keccak_bytes(sponge_state, clock * NUM_CHANNELS);
keccakf_u8s(&mut sponge_state); keccakf_u8s(&mut sponge_state);
} }
@ -254,7 +256,9 @@ pub(crate) fn keccak_sponge_log<F: Field>(
final_block[KECCAK_RATE_BYTES - 1] = 0b10000000; final_block[KECCAK_RATE_BYTES - 1] = 0b10000000;
} }
xor_into_sponge(state, &mut sponge_state, &final_block); xor_into_sponge(state, &mut sponge_state, &final_block);
state.traces.push_keccak_bytes(sponge_state); state
.traces
.push_keccak_bytes(sponge_state, clock * NUM_CHANNELS);
state.traces.push_keccak_sponge(KeccakSpongeOp { state.traces.push_keccak_sponge(KeccakSpongeOp {
base_address, base_address,