PR feedback

This commit is contained in:
wborgeaud 2022-06-13 18:54:12 +02:00
parent fdd6a7cad8
commit e969f10b20
3 changed files with 25 additions and 126 deletions

View File

@ -49,7 +49,7 @@ impl Table {
#[cfg(test)]
mod tests {
use anyhow::Result;
use itertools::Itertools;
use itertools::{izip, Itertools};
use plonky2::field::field_types::Field;
use plonky2::iop::witness::PartialWitness;
use plonky2::plonk::circuit_builder::CircuitBuilder;
@ -60,6 +60,7 @@ mod tests {
use crate::all_stark::{AllStark, Table};
use crate::config::StarkConfig;
use crate::cpu::columns::{KECCAK_INPUT_LIMBS, KECCAK_OUTPUT_LIMBS};
use crate::cpu::cpu_stark::CpuStark;
use crate::cross_table_lookup::{CrossTableLookup, TableWithColumns};
use crate::keccak::keccak_stark::{KeccakStark, NUM_INPUTS, NUM_ROUNDS};
@ -88,12 +89,12 @@ mod tests {
};
let mut rng = thread_rng();
let num_inputs = 2;
let keccak_inputs = (0..num_inputs)
let num_keccak_perms = 2;
let keccak_inputs = (0..num_keccak_perms)
.map(|_| [0u64; NUM_INPUTS].map(|_| rng.gen()))
.collect_vec();
let keccak_trace = keccak_stark.generate_trace(keccak_inputs);
let keccak_input_limbs: Vec<[F; 2 * NUM_INPUTS]> = (0..num_inputs)
let keccak_input_limbs: Vec<[F; 2 * NUM_INPUTS]> = (0..num_keccak_perms)
.map(|i| {
(0..2 * NUM_INPUTS)
.map(|j| {
@ -105,7 +106,7 @@ mod tests {
.unwrap()
})
.collect();
let keccak_output_limbs: Vec<[F; 2 * NUM_INPUTS]> = (0..num_inputs)
let keccak_output_limbs: Vec<[F; 2 * NUM_INPUTS]> = (0..num_keccak_perms)
.map(|i| {
(0..2 * NUM_INPUTS)
.map(|j| {
@ -126,16 +127,18 @@ mod tests {
cpu_stark.generate(&mut cpu_trace_row);
cpu_trace_rows.push(cpu_trace_row);
}
for i in 0..num_inputs {
for i in 0..num_keccak_perms {
cpu_trace_rows[i][cpu::columns::IS_KECCAK] = F::ONE;
for j in 0..2 * NUM_INPUTS {
cpu_trace_rows[i][cpu::columns::KECCAK_INPUT_LIMBS[j]] = keccak_input_limbs[i][j];
cpu_trace_rows[i][cpu::columns::KECCAK_OUTPUT_LIMBS[j]] = keccak_output_limbs[i][j];
for (j, input, output) in
izip!(0..2 * NUM_INPUTS, KECCAK_INPUT_LIMBS, KECCAK_OUTPUT_LIMBS)
{
cpu_trace_rows[i][input] = keccak_input_limbs[i][j];
cpu_trace_rows[i][output] = keccak_output_limbs[i][j];
}
}
let cpu_trace = trace_rows_to_poly_values(cpu_trace_rows);
let mut cpu_keccak_input_output = cpu::columns::KECCAK_INPUT_LIMBS.to_vec();
let mut cpu_keccak_input_output = cpu::columns::KECCAK_INPUT_LIMBS.collect::<Vec<_>>();
cpu_keccak_input_output.extend(cpu::columns::KECCAK_OUTPUT_LIMBS);
let mut keccak_keccak_input_output = (0..2 * NUM_INPUTS)
.map(keccak::registers::reg_input_limb)

View File

@ -1,3 +1,5 @@
use std::ops::Range;
// Filter. 1 if the row corresponds to a cycle of execution and 0 otherwise.
// Lets us re-use decode columns in non-cycle rows.
pub const IS_CPU_CYCLE: usize = 0;
@ -137,111 +139,9 @@ pub const IS_KECCAK: usize = OPCODE_BITS[OPCODE_BITS.len() - 1] + 1;
pub const START_KECCAK_INPUT: usize = IS_KECCAK + 1;
#[allow(dead_code)] // TODO: Remove when used
pub const KECCAK_INPUT_LIMBS: [usize; 50] = [
START_KECCAK_INPUT,
START_KECCAK_INPUT + 1,
START_KECCAK_INPUT + 2,
START_KECCAK_INPUT + 3,
START_KECCAK_INPUT + 4,
START_KECCAK_INPUT + 5,
START_KECCAK_INPUT + 6,
START_KECCAK_INPUT + 7,
START_KECCAK_INPUT + 8,
START_KECCAK_INPUT + 9,
START_KECCAK_INPUT + 10,
START_KECCAK_INPUT + 11,
START_KECCAK_INPUT + 12,
START_KECCAK_INPUT + 13,
START_KECCAK_INPUT + 14,
START_KECCAK_INPUT + 15,
START_KECCAK_INPUT + 16,
START_KECCAK_INPUT + 17,
START_KECCAK_INPUT + 18,
START_KECCAK_INPUT + 19,
START_KECCAK_INPUT + 20,
START_KECCAK_INPUT + 21,
START_KECCAK_INPUT + 22,
START_KECCAK_INPUT + 23,
START_KECCAK_INPUT + 24,
START_KECCAK_INPUT + 25,
START_KECCAK_INPUT + 26,
START_KECCAK_INPUT + 27,
START_KECCAK_INPUT + 28,
START_KECCAK_INPUT + 29,
START_KECCAK_INPUT + 30,
START_KECCAK_INPUT + 31,
START_KECCAK_INPUT + 32,
START_KECCAK_INPUT + 33,
START_KECCAK_INPUT + 34,
START_KECCAK_INPUT + 35,
START_KECCAK_INPUT + 36,
START_KECCAK_INPUT + 37,
START_KECCAK_INPUT + 38,
START_KECCAK_INPUT + 39,
START_KECCAK_INPUT + 40,
START_KECCAK_INPUT + 41,
START_KECCAK_INPUT + 42,
START_KECCAK_INPUT + 43,
START_KECCAK_INPUT + 44,
START_KECCAK_INPUT + 45,
START_KECCAK_INPUT + 46,
START_KECCAK_INPUT + 47,
START_KECCAK_INPUT + 48,
START_KECCAK_INPUT + 49,
];
pub const KECCAK_INPUT_LIMBS: Range<usize> = START_KECCAK_INPUT..START_KECCAK_INPUT + 50;
pub const START_KECCAK_OUTPUT: usize = START_KECCAK_INPUT + 50;
pub const KECCAK_OUTPUT_LIMBS: [usize; 50] = [
START_KECCAK_OUTPUT,
START_KECCAK_OUTPUT + 1,
START_KECCAK_OUTPUT + 2,
START_KECCAK_OUTPUT + 3,
START_KECCAK_OUTPUT + 4,
START_KECCAK_OUTPUT + 5,
START_KECCAK_OUTPUT + 6,
START_KECCAK_OUTPUT + 7,
START_KECCAK_OUTPUT + 8,
START_KECCAK_OUTPUT + 9,
START_KECCAK_OUTPUT + 10,
START_KECCAK_OUTPUT + 11,
START_KECCAK_OUTPUT + 12,
START_KECCAK_OUTPUT + 13,
START_KECCAK_OUTPUT + 14,
START_KECCAK_OUTPUT + 15,
START_KECCAK_OUTPUT + 16,
START_KECCAK_OUTPUT + 17,
START_KECCAK_OUTPUT + 18,
START_KECCAK_OUTPUT + 19,
START_KECCAK_OUTPUT + 20,
START_KECCAK_OUTPUT + 21,
START_KECCAK_OUTPUT + 22,
START_KECCAK_OUTPUT + 23,
START_KECCAK_OUTPUT + 24,
START_KECCAK_OUTPUT + 25,
START_KECCAK_OUTPUT + 26,
START_KECCAK_OUTPUT + 27,
START_KECCAK_OUTPUT + 28,
START_KECCAK_OUTPUT + 29,
START_KECCAK_OUTPUT + 30,
START_KECCAK_OUTPUT + 31,
START_KECCAK_OUTPUT + 32,
START_KECCAK_OUTPUT + 33,
START_KECCAK_OUTPUT + 34,
START_KECCAK_OUTPUT + 35,
START_KECCAK_OUTPUT + 36,
START_KECCAK_OUTPUT + 37,
START_KECCAK_OUTPUT + 38,
START_KECCAK_OUTPUT + 39,
START_KECCAK_OUTPUT + 40,
START_KECCAK_OUTPUT + 41,
START_KECCAK_OUTPUT + 42,
START_KECCAK_OUTPUT + 43,
START_KECCAK_OUTPUT + 44,
START_KECCAK_OUTPUT + 45,
START_KECCAK_OUTPUT + 46,
START_KECCAK_OUTPUT + 47,
START_KECCAK_OUTPUT + 48,
START_KECCAK_OUTPUT + 49,
];
pub const START_KECCAK_OUTPUT: usize = KECCAK_INPUT_LIMBS.end;
pub const KECCAK_OUTPUT_LIMBS: Range<usize> = START_KECCAK_OUTPUT..START_KECCAK_OUTPUT + 50;
pub const NUM_CPU_COLUMNS: usize = KECCAK_OUTPUT_LIMBS[KECCAK_OUTPUT_LIMBS.len() - 1] + 1;
pub const NUM_CPU_COLUMNS: usize = KECCAK_OUTPUT_LIMBS.end;

View File

@ -1,6 +1,5 @@
use std::collections::HashSet;
use anyhow::{ensure, Result};
use itertools::Itertools;
use plonky2::field::extension_field::{Extendable, FieldExtension};
use plonky2::field::field_types::Field;
use plonky2::field::packed_field::PackedField;
@ -34,7 +33,7 @@ pub struct TableWithColumns {
impl TableWithColumns {
pub fn new(table: Table, columns: Vec<usize>, filter_columns: Vec<usize>) -> Self {
debug_assert_eq!(
filter_columns.iter().collect::<HashSet<_>>().len(),
filter_columns.iter().unique().count(),
filter_columns.len(),
"Duplicate filter columns."
);
@ -66,8 +65,7 @@ impl<F: Field> CrossTableLookup<F> {
assert!(
looking_tables
.iter()
.all(|twc| twc.filter_columns.is_empty())
== default.is_some()
.all(|twc| twc.filter_columns.is_empty() == default.is_some())
&& default.is_some() == looked_table.filter_columns.is_empty(),
"Default values should be provided iff there are no filter columns."
);
@ -200,12 +198,10 @@ fn partial_products<F: Field>(
} else {
filter_columns.iter().map(|&j| trace[j].values[i]).sum()
};
partial_prod *= if filter.is_zero() {
F::ONE
} else if filter.is_one() {
challenge.combine(columns.iter().map(|&j| &trace[j].values[i]))
if filter.is_one() {
partial_prod *= challenge.combine(columns.iter().map(|&j| &trace[j].values[i]));
} else {
panic!("Non-binary filter?")
assert_eq!(filter, F::ZERO, "Non-binary filter?")
};
res.push(partial_prod);
}