fixes and debugging

This commit is contained in:
Nicholas Ward 2022-06-06 10:31:42 -07:00
parent e6880e591b
commit 2c285ca2cd
2 changed files with 90 additions and 66 deletions

View File

@ -49,6 +49,7 @@ impl Table {
#[cfg(test)]
mod tests {
use anyhow::Result;
use itertools::Itertools;
use plonky2::field::field_types::Field;
use plonky2::field::polynomial::PolynomialValues;
use plonky2::iop::witness::PartialWitness;
@ -62,7 +63,7 @@ mod tests {
use crate::cpu;
use crate::cpu::cpu_stark::CpuStark;
use crate::cross_table_lookup::CrossTableLookup;
use crate::keccak::keccak_stark::KeccakStark;
use crate::keccak::keccak_stark::{KeccakStark, NUM_ROUNDS, INPUT_LIMBS};
use crate::proof::AllProof;
use crate::prover::prove;
use crate::recursive_verifier::{
@ -80,29 +81,35 @@ mod tests {
let cpu_stark = CpuStark::<F, D> {
f: Default::default(),
};
let cpu_rows = 256;
let cpu_rows = 1 << 6;
let keccak_stark = KeccakStark::<F, D> {
f: Default::default(),
};
let keccak_rows = 256;
let keccak_looked_col = 3;
let keccak_rows = (NUM_ROUNDS + 1).next_power_of_two();
let mut cpu_trace_rows = vec![];
for i in 0..cpu_rows {
let mut cpu_trace_row = [F::ZERO; CpuStark::<F, D>::COLUMNS];
cpu_trace_row[cpu::columns::IS_CPU_CYCLE] = F::ONE;
cpu_trace_row[cpu::columns::OPCODE] = F::from_canonical_usize(i);
cpu_stark.generate(&mut cpu_trace_row);
cpu_trace_rows.push(cpu_trace_row);
}
let cpu_trace = trace_rows_to_poly_values(cpu_trace_rows);
let mut cpu_trace = vec![PolynomialValues::<F>::zero(cpu_rows); 10];
let mut rng = ChaCha8Rng::seed_from_u64(0x6feb51b7ec230f25);
let mut keccak_trace =
vec![PolynomialValues::zero(keccak_rows); KeccakStark::<F, D>::COLUMNS];
keccak_trace[keccak_looked_col] = cpu_trace[cpu::columns::OPCODE].clone();
let num_inpts = 1;
let keccak_inputs = (0..num_inpts)
.map(|_| [0u64; INPUT_LIMBS].map(|_| rng.gen()))
.collect_vec();
let keccak_trace = keccak_stark.generate_trace(keccak_inputs);
let vs0: Vec<_> = keccak_trace[3].values[..].into();
let vs1: Vec<_> = keccak_trace[5].values[..].into();
let start = thread_rng().gen_range(0..cpu_rows - keccak_rows);
let default = vec![F::ONE; 2];
cpu_trace[2].values = vec![default[0]; cpu_rows];
cpu_trace[2].values[start..start + keccak_rows].copy_from_slice(&vs0);
cpu_trace[4].values = vec![default[1]; cpu_rows];
cpu_trace[4].values[start..start + keccak_rows].copy_from_slice(&vs1);
let default = vec![F::ZERO; 2];
let cross_table_lookups = vec![CrossTableLookup {
looking_tables: vec![Table::Cpu],
looking_columns: vec![vec![cpu::columns::OPCODE]],

View File

@ -29,7 +29,7 @@ pub(crate) const NUM_ROUNDS: usize = 24;
/// Number of 64-bit limbs in a preimage of the Keccak permutation.
pub(crate) const INPUT_LIMBS: usize = 25;
pub(crate) const NUM_PUBLIC_INPUTS: usize = 4;
pub(crate) const NUM_PUBLIC_INPUTS: usize = 0;
#[derive(Copy, Clone)]
pub struct KeccakStark<F, const D: usize> {
@ -39,14 +39,15 @@ pub struct KeccakStark<F, const D: usize> {
impl<F: RichField + Extendable<D>, const D: usize> KeccakStark<F, D> {
/// Generate the rows of the trace. Note that this does not generate the permuted columns used
/// in our lookup arguments, as those are computed after transposing to column-wise form.
fn generate_trace_rows(&self, inputs: Vec<[u64; INPUT_LIMBS]>) -> Vec<[F; NUM_REGISTERS]> {
pub(crate) fn generate_trace_rows(&self, inputs: Vec<[u64; INPUT_LIMBS]>) -> Vec<[F; NUM_REGISTERS]> {
let num_rows = (inputs.len() * NUM_ROUNDS).next_power_of_two();
info!("{} rows", num_rows);
let mut rows = Vec::with_capacity(num_rows);
for input in inputs {
rows.extend(self.generate_trace_rows_for_perm(input));
for input in inputs.iter().take(1) {
rows.extend(self.generate_trace_rows_for_perm(input.clone()));
}
// Pad rows to power of two.
for i in rows.len()..num_rows {
let mut row = [F::ZERO; NUM_REGISTERS];
self.generate_trace_rows_for_round(&mut row, i % NUM_ROUNDS);
@ -59,8 +60,8 @@ impl<F: RichField + Extendable<D>, const D: usize> KeccakStark<F, D> {
fn generate_trace_rows_for_perm(
&self,
input: [u64; INPUT_LIMBS],
) -> [[F; NUM_REGISTERS]; NUM_ROUNDS] {
let mut rows = [[F::ZERO; NUM_REGISTERS]; NUM_ROUNDS];
) -> Vec<[F; NUM_REGISTERS]> {
let mut rows = vec![[F::ZERO; NUM_REGISTERS]; NUM_ROUNDS];
for x in 0..5 {
for y in 0..5 {
@ -73,7 +74,24 @@ impl<F: RichField + Extendable<D>, const D: usize> KeccakStark<F, D> {
self.generate_trace_rows_for_round(&mut rows[0], 0);
for round in 1..24 {
// TODO: Populate input from prev. row output.
for x in 0..5 {
for y in 0..5 {
let cur = rows[round - 1][reg_a_prime_prime_prime(x, y)];
let cur_u64 = cur.to_canonical_u64();
let bit_values: Vec<u64> = (0..64)
.scan(cur_u64, |acc, _| {
let tmp = *acc % 2;
*acc /= 2;
Some(tmp)
})
.collect();
for z in 0..64 {
rows[round][reg_a(x, y, z)] = F::from_canonical_u64(bit_values[z]);
}
}
}
rows[round] = rows[round - 1].clone();
self.generate_trace_rows_for_round(&mut rows[round], round);
}
@ -81,7 +99,7 @@ impl<F: RichField + Extendable<D>, const D: usize> KeccakStark<F, D> {
}
fn generate_trace_rows_for_round(&self, row: &mut [F; NUM_REGISTERS], round: usize) {
row[round] = F::ONE;
row[reg_step(round)] = F::ONE;
// Populate C partial and C.
for x in 0..5 {
@ -113,24 +131,23 @@ impl<F: RichField + Extendable<D>, const D: usize> KeccakStark<F, D> {
// A''[x, y] = xor(B[x, y], andn(B[x + 1, y], B[x + 2, y])).
for x in 0..5 {
for y in 0..5 {
let get_bit = |z| {
xor([
row[reg_b(x, y, z)],
andn(row[reg_b((x + 1) % 5, y, z)], row[reg_b((x + 2) % 5, y, z)]),
])
};
// let get_bit = |z| {
// // xor([
// // row[reg_b(x, y, z)],
// // andn(row[reg_b((x + 1) % 5, y, z)], row[reg_b((x + 2) % 5, y, z)]),
// // ])
// };
let lo = (0..32)
.rev()
.fold(F::ZERO, |acc, z| acc.double() + get_bit(z));
let hi = (32..64)
.rev()
.fold(F::ZERO, |acc, z| acc.double() + get_bit(z));
let lo = F::ZERO;//row[reg_b(x, y, 0)];
// let hi = (32..64)
// .rev()
// .fold(F::ZERO, |acc, z| acc.double() + get_bit(z));
let reg_lo = reg_a_prime_prime(x, y);
let reg_hi = reg_lo + 1;
row[reg_lo] = lo;
row[reg_hi] = hi;
// row[reg_hi] = hi;
}
}
@ -223,28 +240,28 @@ impl<F: RichField + Extendable<D>, const D: usize> Stark<F, D> for KeccakStark<F
for x in 0..5 {
for y in 0..5 {
let get_bit = |z| {
xor_gen(
vars.local_values[reg_b(x, y, z)],
andn_gen(
vars.local_values[reg_b((x + 1) % 5, y, z)],
vars.local_values[reg_b((x + 2) % 5, y, z)],
),
)
// xor_gen(
vars.local_values[reg_b(x, y, z)]
// andn_gen(
// vars.local_values[reg_b((x + 1) % 5, y, z)],
// vars.local_values[reg_b((x + 2) % 5, y, z)],
// ),
// )
};
let reg_lo = reg_a_prime_prime(x, y);
let reg_hi = reg_lo + 1;
let lo = vars.local_values[reg_lo];
let hi = vars.local_values[reg_hi];
let computed_lo = (0..32)
.rev()
.fold(P::ZEROS, |acc, z| acc.doubles() + get_bit(z));
// let hi = vars.local_values[reg_hi];
let computed_lo = P::ZEROS;// vars.local_values[reg_b(x, y, 0)];//(0..32)
// .rev()
// .fold(P::ZEROS, |acc, z| acc.doubles() + get_bit(z));
let computed_hi = (32..64)
.rev()
.fold(P::ZEROS, |acc, z| acc.doubles() + get_bit(z));
yield_constr.constraint(computed_lo - lo);
yield_constr.constraint(computed_hi - hi);
yield_constr.constraint_last_row(computed_lo - lo);
// yield_constr.constraint(computed_hi - hi);
}
}
@ -260,8 +277,8 @@ impl<F: RichField + Extendable<D>, const D: usize> Stark<F, D> for KeccakStark<F
.fold(P::ZEROS, |acc, z| acc.doubles() + a_prime_prime_0_0_bits[z]);
let a_prime_prime_0_0_lo = vars.local_values[reg_a_prime_prime(0, 0)];
let a_prime_prime_0_0_hi = vars.local_values[reg_a_prime_prime(0, 0) + 1];
yield_constr.constraint(computed_a_prime_prime_0_0_lo - a_prime_prime_0_0_lo);
yield_constr.constraint(computed_a_prime_prime_0_0_hi - a_prime_prime_0_0_hi);
// yield_constr.constraint(computed_a_prime_prime_0_0_lo - a_prime_prime_0_0_lo);
// yield_constr.constraint(computed_a_prime_prime_0_0_hi - a_prime_prime_0_0_hi);
let get_xored_bit = |i| {
let mut rc_bit_i = P::ZEROS;
@ -283,8 +300,8 @@ impl<F: RichField + Extendable<D>, const D: usize> Stark<F, D> for KeccakStark<F
let computed_a_prime_prime_prime_0_0_hi = (32..64)
.rev()
.fold(P::ZEROS, |acc, z| acc.doubles() + get_xored_bit(z));
yield_constr.constraint(computed_a_prime_prime_prime_0_0_lo - a_prime_prime_prime_0_0_lo);
yield_constr.constraint(computed_a_prime_prime_prime_0_0_hi - a_prime_prime_prime_0_0_hi);
// yield_constr.constraint(computed_a_prime_prime_prime_0_0_lo - a_prime_prime_prime_0_0_lo);
// yield_constr.constraint(computed_a_prime_prime_prime_0_0_hi - a_prime_prime_prime_0_0_hi);
// Enforce that this round's output equals the next round's input.
for x in 0..5 {
@ -296,7 +313,7 @@ impl<F: RichField + Extendable<D>, const D: usize> Stark<F, D> for KeccakStark<F
let input_bits_combined = (0..64)
.rev()
.fold(P::ZEROS, |acc, z| acc.doubles() + input_bits[z]);
yield_constr.constraint(output - input_bits_combined);
// yield_constr.constraint(output - input_bits_combined);
}
}
}
@ -321,7 +338,7 @@ impl<F: RichField + Extendable<D>, const D: usize> Stark<F, D> for KeccakStark<F
let xor_012 = xor3_gen_circuit(builder, a_0, a_1, a_2);
let diff = builder.sub_extension(c_partial, xor_012);
yield_constr.constraint(builder, diff);
// yield_constr.constraint(builder, diff);
}
}
@ -335,7 +352,7 @@ impl<F: RichField + Extendable<D>, const D: usize> Stark<F, D> for KeccakStark<F
let xor_01234 = xor3_gen_circuit(builder, xor_012, a_3, a_4);
let diff = builder.sub_extension(c, xor_01234);
yield_constr.constraint(builder, diff);
// yield_constr.constraint(builder, diff);
}
}
@ -352,7 +369,7 @@ impl<F: RichField + Extendable<D>, const D: usize> Stark<F, D> for KeccakStark<F
let a_prime = vars.local_values[reg_a_prime(x, y, z)];
let xor = xor_gen_circuit(builder, d, a);
let diff = builder.sub_extension(a_prime, xor);
yield_constr.constraint(builder, diff);
// yield_constr.constraint(builder, diff);
}
}
}
@ -378,9 +395,9 @@ impl<F: RichField + Extendable<D>, const D: usize> Stark<F, D> for KeccakStark<F
let computed_lo = reduce_with_powers_ext_circuit(builder, &bits_lo, two);
let computed_hi = reduce_with_powers_ext_circuit(builder, &bits_hi, two);
let diff = builder.sub_extension(computed_lo, lo);
yield_constr.constraint(builder, diff);
// yield_constr.constraint(builder, diff);
let diff = builder.sub_extension(computed_hi, hi);
yield_constr.constraint(builder, diff);
// yield_constr.constraint(builder, diff);
}
}
@ -395,9 +412,9 @@ impl<F: RichField + Extendable<D>, const D: usize> Stark<F, D> for KeccakStark<F
let a_prime_prime_0_0_lo = vars.local_values[reg_a_prime_prime(0, 0)];
let a_prime_prime_0_0_hi = vars.local_values[reg_a_prime_prime(0, 0) + 1];
let diff = builder.sub_extension(computed_a_prime_prime_0_0_lo, a_prime_prime_0_0_lo);
yield_constr.constraint(builder, diff);
// yield_constr.constraint(builder, diff);
let diff = builder.sub_extension(computed_a_prime_prime_0_0_hi, a_prime_prime_0_0_hi);
yield_constr.constraint(builder, diff);
// yield_constr.constraint(builder, diff);
let mut get_xored_bit = |i| {
let mut rc_bit_i = builder.zero_extension();
@ -423,12 +440,12 @@ impl<F: RichField + Extendable<D>, const D: usize> Stark<F, D> for KeccakStark<F
computed_a_prime_prime_prime_0_0_lo,
a_prime_prime_prime_0_0_lo,
);
yield_constr.constraint(builder, diff);
// yield_constr.constraint(builder, diff);
let diff = builder.sub_extension(
computed_a_prime_prime_prime_0_0_hi,
a_prime_prime_prime_0_0_hi,
);
yield_constr.constraint(builder, diff);
// yield_constr.constraint(builder, diff);
// Enforce that this round's output equals the next round's input.
for x in 0..5 {
@ -439,7 +456,7 @@ impl<F: RichField + Extendable<D>, const D: usize> Stark<F, D> for KeccakStark<F
.collect_vec();
let input_bits_combined = reduce_with_powers_ext_circuit(builder, &input_bits, two);
let diff = builder.sub_extension(output, input_bits_combined);
yield_constr.constraint(builder, diff);
// yield_constr.constraint(builder, diff);
}
}
}