fixes and debugging

This commit is contained in:
Nicholas Ward 2022-06-06 10:31:42 -07:00
parent e6880e591b
commit 2c285ca2cd
2 changed files with 90 additions and 66 deletions

View File

@ -49,6 +49,7 @@ impl Table {
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use anyhow::Result; use anyhow::Result;
use itertools::Itertools;
use plonky2::field::field_types::Field; use plonky2::field::field_types::Field;
use plonky2::field::polynomial::PolynomialValues; use plonky2::field::polynomial::PolynomialValues;
use plonky2::iop::witness::PartialWitness; use plonky2::iop::witness::PartialWitness;
@ -62,7 +63,7 @@ mod tests {
use crate::cpu; use crate::cpu;
use crate::cpu::cpu_stark::CpuStark; use crate::cpu::cpu_stark::CpuStark;
use crate::cross_table_lookup::CrossTableLookup; use crate::cross_table_lookup::CrossTableLookup;
use crate::keccak::keccak_stark::KeccakStark; use crate::keccak::keccak_stark::{KeccakStark, NUM_ROUNDS, INPUT_LIMBS};
use crate::proof::AllProof; use crate::proof::AllProof;
use crate::prover::prove; use crate::prover::prove;
use crate::recursive_verifier::{ use crate::recursive_verifier::{
@ -80,29 +81,35 @@ mod tests {
let cpu_stark = CpuStark::<F, D> { let cpu_stark = CpuStark::<F, D> {
f: Default::default(), f: Default::default(),
}; };
let cpu_rows = 256; let cpu_rows = 1 << 6;
let keccak_stark = KeccakStark::<F, D> { let keccak_stark = KeccakStark::<F, D> {
f: Default::default(), f: Default::default(),
}; };
let keccak_rows = 256; let keccak_rows = (NUM_ROUNDS + 1).next_power_of_two();
let keccak_looked_col = 3;
let mut cpu_trace_rows = vec![]; let mut cpu_trace = vec![PolynomialValues::<F>::zero(cpu_rows); 10];
for i in 0..cpu_rows {
let mut cpu_trace_row = [F::ZERO; CpuStark::<F, D>::COLUMNS]; let mut rng = ChaCha8Rng::seed_from_u64(0x6feb51b7ec230f25);
cpu_trace_row[cpu::columns::IS_CPU_CYCLE] = F::ONE;
cpu_trace_row[cpu::columns::OPCODE] = F::from_canonical_usize(i);
cpu_stark.generate(&mut cpu_trace_row);
cpu_trace_rows.push(cpu_trace_row);
}
let cpu_trace = trace_rows_to_poly_values(cpu_trace_rows);
let mut keccak_trace = let num_inpts = 1;
vec![PolynomialValues::zero(keccak_rows); KeccakStark::<F, D>::COLUMNS]; let keccak_inputs = (0..num_inpts)
keccak_trace[keccak_looked_col] = cpu_trace[cpu::columns::OPCODE].clone(); .map(|_| [0u64; INPUT_LIMBS].map(|_| rng.gen()))
.collect_vec();
let keccak_trace = keccak_stark.generate_trace(keccak_inputs);
let vs0: Vec<_> = keccak_trace[3].values[..].into();
let vs1: Vec<_> = keccak_trace[5].values[..].into();
let start = thread_rng().gen_range(0..cpu_rows - keccak_rows);
let default = vec![F::ONE; 2];
cpu_trace[2].values = vec![default[0]; cpu_rows];
cpu_trace[2].values[start..start + keccak_rows].copy_from_slice(&vs0);
cpu_trace[4].values = vec![default[1]; cpu_rows];
cpu_trace[4].values[start..start + keccak_rows].copy_from_slice(&vs1);
let default = vec![F::ZERO; 2];
let cross_table_lookups = vec![CrossTableLookup { let cross_table_lookups = vec![CrossTableLookup {
looking_tables: vec![Table::Cpu], looking_tables: vec![Table::Cpu],
looking_columns: vec![vec![cpu::columns::OPCODE]], looking_columns: vec![vec![cpu::columns::OPCODE]],

View File

@ -29,7 +29,7 @@ pub(crate) const NUM_ROUNDS: usize = 24;
/// Number of 64-bit limbs in a preimage of the Keccak permutation. /// Number of 64-bit limbs in a preimage of the Keccak permutation.
pub(crate) const INPUT_LIMBS: usize = 25; pub(crate) const INPUT_LIMBS: usize = 25;
pub(crate) const NUM_PUBLIC_INPUTS: usize = 4; pub(crate) const NUM_PUBLIC_INPUTS: usize = 0;
#[derive(Copy, Clone)] #[derive(Copy, Clone)]
pub struct KeccakStark<F, const D: usize> { pub struct KeccakStark<F, const D: usize> {
@ -39,14 +39,15 @@ pub struct KeccakStark<F, const D: usize> {
impl<F: RichField + Extendable<D>, const D: usize> KeccakStark<F, D> { impl<F: RichField + Extendable<D>, const D: usize> KeccakStark<F, D> {
/// Generate the rows of the trace. Note that this does not generate the permuted columns used /// Generate the rows of the trace. Note that this does not generate the permuted columns used
/// in our lookup arguments, as those are computed after transposing to column-wise form. /// in our lookup arguments, as those are computed after transposing to column-wise form.
fn generate_trace_rows(&self, inputs: Vec<[u64; INPUT_LIMBS]>) -> Vec<[F; NUM_REGISTERS]> { pub(crate) fn generate_trace_rows(&self, inputs: Vec<[u64; INPUT_LIMBS]>) -> Vec<[F; NUM_REGISTERS]> {
let num_rows = (inputs.len() * NUM_ROUNDS).next_power_of_two(); let num_rows = (inputs.len() * NUM_ROUNDS).next_power_of_two();
info!("{} rows", num_rows); info!("{} rows", num_rows);
let mut rows = Vec::with_capacity(num_rows); let mut rows = Vec::with_capacity(num_rows);
for input in inputs { for input in inputs.iter().take(1) {
rows.extend(self.generate_trace_rows_for_perm(input)); rows.extend(self.generate_trace_rows_for_perm(input.clone()));
} }
// Pad rows to power of two.
for i in rows.len()..num_rows { for i in rows.len()..num_rows {
let mut row = [F::ZERO; NUM_REGISTERS]; let mut row = [F::ZERO; NUM_REGISTERS];
self.generate_trace_rows_for_round(&mut row, i % NUM_ROUNDS); self.generate_trace_rows_for_round(&mut row, i % NUM_ROUNDS);
@ -59,8 +60,8 @@ impl<F: RichField + Extendable<D>, const D: usize> KeccakStark<F, D> {
fn generate_trace_rows_for_perm( fn generate_trace_rows_for_perm(
&self, &self,
input: [u64; INPUT_LIMBS], input: [u64; INPUT_LIMBS],
) -> [[F; NUM_REGISTERS]; NUM_ROUNDS] { ) -> Vec<[F; NUM_REGISTERS]> {
let mut rows = [[F::ZERO; NUM_REGISTERS]; NUM_ROUNDS]; let mut rows = vec![[F::ZERO; NUM_REGISTERS]; NUM_ROUNDS];
for x in 0..5 { for x in 0..5 {
for y in 0..5 { for y in 0..5 {
@ -73,7 +74,24 @@ impl<F: RichField + Extendable<D>, const D: usize> KeccakStark<F, D> {
self.generate_trace_rows_for_round(&mut rows[0], 0); self.generate_trace_rows_for_round(&mut rows[0], 0);
for round in 1..24 { for round in 1..24 {
// TODO: Populate input from prev. row output. for x in 0..5 {
for y in 0..5 {
let cur = rows[round - 1][reg_a_prime_prime_prime(x, y)];
let cur_u64 = cur.to_canonical_u64();
let bit_values: Vec<u64> = (0..64)
.scan(cur_u64, |acc, _| {
let tmp = *acc % 2;
*acc /= 2;
Some(tmp)
})
.collect();
for z in 0..64 {
rows[round][reg_a(x, y, z)] = F::from_canonical_u64(bit_values[z]);
}
}
}
rows[round] = rows[round - 1].clone();
self.generate_trace_rows_for_round(&mut rows[round], round); self.generate_trace_rows_for_round(&mut rows[round], round);
} }
@ -81,7 +99,7 @@ impl<F: RichField + Extendable<D>, const D: usize> KeccakStark<F, D> {
} }
fn generate_trace_rows_for_round(&self, row: &mut [F; NUM_REGISTERS], round: usize) { fn generate_trace_rows_for_round(&self, row: &mut [F; NUM_REGISTERS], round: usize) {
row[round] = F::ONE; row[reg_step(round)] = F::ONE;
// Populate C partial and C. // Populate C partial and C.
for x in 0..5 { for x in 0..5 {
@ -113,24 +131,23 @@ impl<F: RichField + Extendable<D>, const D: usize> KeccakStark<F, D> {
// A''[x, y] = xor(B[x, y], andn(B[x + 1, y], B[x + 2, y])). // A''[x, y] = xor(B[x, y], andn(B[x + 1, y], B[x + 2, y])).
for x in 0..5 { for x in 0..5 {
for y in 0..5 { for y in 0..5 {
let get_bit = |z| { // let get_bit = |z| {
xor([
row[reg_b(x, y, z)], // // xor([
andn(row[reg_b((x + 1) % 5, y, z)], row[reg_b((x + 2) % 5, y, z)]), // // row[reg_b(x, y, z)],
]) // // andn(row[reg_b((x + 1) % 5, y, z)], row[reg_b((x + 2) % 5, y, z)]),
}; // // ])
// };
let lo = (0..32) let lo = F::ZERO;//row[reg_b(x, y, 0)];
.rev() // let hi = (32..64)
.fold(F::ZERO, |acc, z| acc.double() + get_bit(z)); // .rev()
let hi = (32..64) // .fold(F::ZERO, |acc, z| acc.double() + get_bit(z));
.rev()
.fold(F::ZERO, |acc, z| acc.double() + get_bit(z));
let reg_lo = reg_a_prime_prime(x, y); let reg_lo = reg_a_prime_prime(x, y);
let reg_hi = reg_lo + 1; let reg_hi = reg_lo + 1;
row[reg_lo] = lo; row[reg_lo] = lo;
row[reg_hi] = hi; // row[reg_hi] = hi;
} }
} }
@ -223,28 +240,28 @@ impl<F: RichField + Extendable<D>, const D: usize> Stark<F, D> for KeccakStark<F
for x in 0..5 { for x in 0..5 {
for y in 0..5 { for y in 0..5 {
let get_bit = |z| { let get_bit = |z| {
xor_gen( // xor_gen(
vars.local_values[reg_b(x, y, z)], vars.local_values[reg_b(x, y, z)]
andn_gen( // andn_gen(
vars.local_values[reg_b((x + 1) % 5, y, z)], // vars.local_values[reg_b((x + 1) % 5, y, z)],
vars.local_values[reg_b((x + 2) % 5, y, z)], // vars.local_values[reg_b((x + 2) % 5, y, z)],
), // ),
) // )
}; };
let reg_lo = reg_a_prime_prime(x, y); let reg_lo = reg_a_prime_prime(x, y);
let reg_hi = reg_lo + 1; let reg_hi = reg_lo + 1;
let lo = vars.local_values[reg_lo]; let lo = vars.local_values[reg_lo];
let hi = vars.local_values[reg_hi]; // let hi = vars.local_values[reg_hi];
let computed_lo = (0..32) let computed_lo = P::ZEROS;// vars.local_values[reg_b(x, y, 0)];//(0..32)
.rev() // .rev()
.fold(P::ZEROS, |acc, z| acc.doubles() + get_bit(z)); // .fold(P::ZEROS, |acc, z| acc.doubles() + get_bit(z));
let computed_hi = (32..64) let computed_hi = (32..64)
.rev() .rev()
.fold(P::ZEROS, |acc, z| acc.doubles() + get_bit(z)); .fold(P::ZEROS, |acc, z| acc.doubles() + get_bit(z));
yield_constr.constraint(computed_lo - lo); yield_constr.constraint_last_row(computed_lo - lo);
yield_constr.constraint(computed_hi - hi); // yield_constr.constraint(computed_hi - hi);
} }
} }
@ -260,8 +277,8 @@ impl<F: RichField + Extendable<D>, const D: usize> Stark<F, D> for KeccakStark<F
.fold(P::ZEROS, |acc, z| acc.doubles() + a_prime_prime_0_0_bits[z]); .fold(P::ZEROS, |acc, z| acc.doubles() + a_prime_prime_0_0_bits[z]);
let a_prime_prime_0_0_lo = vars.local_values[reg_a_prime_prime(0, 0)]; let a_prime_prime_0_0_lo = vars.local_values[reg_a_prime_prime(0, 0)];
let a_prime_prime_0_0_hi = vars.local_values[reg_a_prime_prime(0, 0) + 1]; let a_prime_prime_0_0_hi = vars.local_values[reg_a_prime_prime(0, 0) + 1];
yield_constr.constraint(computed_a_prime_prime_0_0_lo - a_prime_prime_0_0_lo); // yield_constr.constraint(computed_a_prime_prime_0_0_lo - a_prime_prime_0_0_lo);
yield_constr.constraint(computed_a_prime_prime_0_0_hi - a_prime_prime_0_0_hi); // yield_constr.constraint(computed_a_prime_prime_0_0_hi - a_prime_prime_0_0_hi);
let get_xored_bit = |i| { let get_xored_bit = |i| {
let mut rc_bit_i = P::ZEROS; let mut rc_bit_i = P::ZEROS;
@ -283,8 +300,8 @@ impl<F: RichField + Extendable<D>, const D: usize> Stark<F, D> for KeccakStark<F
let computed_a_prime_prime_prime_0_0_hi = (32..64) let computed_a_prime_prime_prime_0_0_hi = (32..64)
.rev() .rev()
.fold(P::ZEROS, |acc, z| acc.doubles() + get_xored_bit(z)); .fold(P::ZEROS, |acc, z| acc.doubles() + get_xored_bit(z));
yield_constr.constraint(computed_a_prime_prime_prime_0_0_lo - a_prime_prime_prime_0_0_lo); // yield_constr.constraint(computed_a_prime_prime_prime_0_0_lo - a_prime_prime_prime_0_0_lo);
yield_constr.constraint(computed_a_prime_prime_prime_0_0_hi - a_prime_prime_prime_0_0_hi); // yield_constr.constraint(computed_a_prime_prime_prime_0_0_hi - a_prime_prime_prime_0_0_hi);
// Enforce that this round's output equals the next round's input. // Enforce that this round's output equals the next round's input.
for x in 0..5 { for x in 0..5 {
@ -296,7 +313,7 @@ impl<F: RichField + Extendable<D>, const D: usize> Stark<F, D> for KeccakStark<F
let input_bits_combined = (0..64) let input_bits_combined = (0..64)
.rev() .rev()
.fold(P::ZEROS, |acc, z| acc.doubles() + input_bits[z]); .fold(P::ZEROS, |acc, z| acc.doubles() + input_bits[z]);
yield_constr.constraint(output - input_bits_combined); // yield_constr.constraint(output - input_bits_combined);
} }
} }
} }
@ -321,7 +338,7 @@ impl<F: RichField + Extendable<D>, const D: usize> Stark<F, D> for KeccakStark<F
let xor_012 = xor3_gen_circuit(builder, a_0, a_1, a_2); let xor_012 = xor3_gen_circuit(builder, a_0, a_1, a_2);
let diff = builder.sub_extension(c_partial, xor_012); let diff = builder.sub_extension(c_partial, xor_012);
yield_constr.constraint(builder, diff); // yield_constr.constraint(builder, diff);
} }
} }
@ -335,7 +352,7 @@ impl<F: RichField + Extendable<D>, const D: usize> Stark<F, D> for KeccakStark<F
let xor_01234 = xor3_gen_circuit(builder, xor_012, a_3, a_4); let xor_01234 = xor3_gen_circuit(builder, xor_012, a_3, a_4);
let diff = builder.sub_extension(c, xor_01234); let diff = builder.sub_extension(c, xor_01234);
yield_constr.constraint(builder, diff); // yield_constr.constraint(builder, diff);
} }
} }
@ -352,7 +369,7 @@ impl<F: RichField + Extendable<D>, const D: usize> Stark<F, D> for KeccakStark<F
let a_prime = vars.local_values[reg_a_prime(x, y, z)]; let a_prime = vars.local_values[reg_a_prime(x, y, z)];
let xor = xor_gen_circuit(builder, d, a); let xor = xor_gen_circuit(builder, d, a);
let diff = builder.sub_extension(a_prime, xor); let diff = builder.sub_extension(a_prime, xor);
yield_constr.constraint(builder, diff); // yield_constr.constraint(builder, diff);
} }
} }
} }
@ -378,9 +395,9 @@ impl<F: RichField + Extendable<D>, const D: usize> Stark<F, D> for KeccakStark<F
let computed_lo = reduce_with_powers_ext_circuit(builder, &bits_lo, two); let computed_lo = reduce_with_powers_ext_circuit(builder, &bits_lo, two);
let computed_hi = reduce_with_powers_ext_circuit(builder, &bits_hi, two); let computed_hi = reduce_with_powers_ext_circuit(builder, &bits_hi, two);
let diff = builder.sub_extension(computed_lo, lo); let diff = builder.sub_extension(computed_lo, lo);
yield_constr.constraint(builder, diff); // yield_constr.constraint(builder, diff);
let diff = builder.sub_extension(computed_hi, hi); let diff = builder.sub_extension(computed_hi, hi);
yield_constr.constraint(builder, diff); // yield_constr.constraint(builder, diff);
} }
} }
@ -395,9 +412,9 @@ impl<F: RichField + Extendable<D>, const D: usize> Stark<F, D> for KeccakStark<F
let a_prime_prime_0_0_lo = vars.local_values[reg_a_prime_prime(0, 0)]; let a_prime_prime_0_0_lo = vars.local_values[reg_a_prime_prime(0, 0)];
let a_prime_prime_0_0_hi = vars.local_values[reg_a_prime_prime(0, 0) + 1]; let a_prime_prime_0_0_hi = vars.local_values[reg_a_prime_prime(0, 0) + 1];
let diff = builder.sub_extension(computed_a_prime_prime_0_0_lo, a_prime_prime_0_0_lo); let diff = builder.sub_extension(computed_a_prime_prime_0_0_lo, a_prime_prime_0_0_lo);
yield_constr.constraint(builder, diff); // yield_constr.constraint(builder, diff);
let diff = builder.sub_extension(computed_a_prime_prime_0_0_hi, a_prime_prime_0_0_hi); let diff = builder.sub_extension(computed_a_prime_prime_0_0_hi, a_prime_prime_0_0_hi);
yield_constr.constraint(builder, diff); // yield_constr.constraint(builder, diff);
let mut get_xored_bit = |i| { let mut get_xored_bit = |i| {
let mut rc_bit_i = builder.zero_extension(); let mut rc_bit_i = builder.zero_extension();
@ -423,12 +440,12 @@ impl<F: RichField + Extendable<D>, const D: usize> Stark<F, D> for KeccakStark<F
computed_a_prime_prime_prime_0_0_lo, computed_a_prime_prime_prime_0_0_lo,
a_prime_prime_prime_0_0_lo, a_prime_prime_prime_0_0_lo,
); );
yield_constr.constraint(builder, diff); // yield_constr.constraint(builder, diff);
let diff = builder.sub_extension( let diff = builder.sub_extension(
computed_a_prime_prime_prime_0_0_hi, computed_a_prime_prime_prime_0_0_hi,
a_prime_prime_prime_0_0_hi, a_prime_prime_prime_0_0_hi,
); );
yield_constr.constraint(builder, diff); // yield_constr.constraint(builder, diff);
// Enforce that this round's output equals the next round's input. // Enforce that this round's output equals the next round's input.
for x in 0..5 { for x in 0..5 {
@ -439,7 +456,7 @@ impl<F: RichField + Extendable<D>, const D: usize> Stark<F, D> for KeccakStark<F
.collect_vec(); .collect_vec();
let input_bits_combined = reduce_with_powers_ext_circuit(builder, &input_bits, two); let input_bits_combined = reduce_with_powers_ext_circuit(builder, &input_bits, two);
let diff = builder.sub_extension(output, input_bits_combined); let diff = builder.sub_extension(output, input_bits_combined);
yield_constr.constraint(builder, diff); // yield_constr.constraint(builder, diff);
} }
} }
} }