mirror of
https://github.com/logos-storage/plonky2.git
synced 2026-01-12 18:53:11 +00:00
Merge branch 'evm_keccak_stark' into filtered_ctl
# Conflicts: # evm/src/all_stark.rs # evm/src/cross_table_lookup.rs
This commit is contained in:
commit
afda9db00a
@ -13,3 +13,5 @@ itertools = "0.10.0"
|
||||
log = "0.4.14"
|
||||
rayon = "1.5.1"
|
||||
rand = "0.8.5"
|
||||
rand_chacha = "0.3.1"
|
||||
keccak-rust = { git = "https://github.com/npwardberkeley/keccak-rust" }
|
||||
@ -49,20 +49,22 @@ impl Table {
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use anyhow::Result;
|
||||
use itertools::Itertools;
|
||||
use plonky2::field::field_types::Field;
|
||||
use plonky2::field::polynomial::PolynomialValues;
|
||||
use plonky2::iop::witness::PartialWitness;
|
||||
use plonky2::plonk::circuit_builder::CircuitBuilder;
|
||||
use plonky2::plonk::circuit_data::CircuitConfig;
|
||||
use plonky2::plonk::config::{GenericConfig, PoseidonGoldilocksConfig};
|
||||
use plonky2::util::timing::TimingTree;
|
||||
use rand::{Rng, SeedableRng};
|
||||
use rand_chacha::ChaCha8Rng;
|
||||
|
||||
use crate::all_stark::{AllStark, Table};
|
||||
use crate::config::StarkConfig;
|
||||
use crate::cpu;
|
||||
use crate::cpu::cpu_stark::CpuStark;
|
||||
use crate::cross_table_lookup::{CrossTableLookup, TableWithColumns};
|
||||
use crate::keccak::keccak_stark::KeccakStark;
|
||||
use crate::cross_table_lookup::CrossTableLookup;
|
||||
use crate::keccak::keccak_stark::{KeccakStark, INPUT_LIMBS, NUM_ROUNDS};
|
||||
use crate::proof::AllProof;
|
||||
use crate::prover::prove;
|
||||
use crate::recursive_verifier::{
|
||||
@ -85,33 +87,41 @@ mod tests {
|
||||
let keccak_stark = KeccakStark::<F, D> {
|
||||
f: Default::default(),
|
||||
};
|
||||
let keccak_rows = 256;
|
||||
let keccak_rows = (NUM_ROUNDS + 1).next_power_of_two();
|
||||
let keccak_looked_col = 3;
|
||||
|
||||
let mut rng = ChaCha8Rng::seed_from_u64(0x6feb51b7ec230f25);
|
||||
let num_inputs = 1;
|
||||
let keccak_inputs = (0..num_inputs)
|
||||
.map(|_| [0u64; INPUT_LIMBS].map(|_| rng.gen()))
|
||||
.collect_vec();
|
||||
let keccak_trace = keccak_stark.generate_trace(keccak_inputs);
|
||||
let column_to_copy: Vec<_> = keccak_trace[keccak_looked_col].values[..].into();
|
||||
|
||||
let default = vec![F::ONE; 1];
|
||||
|
||||
let mut cpu_trace_rows = vec![];
|
||||
for i in 0..cpu_rows {
|
||||
let mut cpu_trace_row = [F::ZERO; CpuStark::<F, D>::COLUMNS];
|
||||
cpu_trace_row[cpu::columns::IS_CPU_CYCLE] = F::ZERO;
|
||||
cpu_trace_row[cpu::columns::OPCODE] = F::from_canonical_usize(i);
|
||||
cpu_trace_row[cpu::columns::IS_CPU_CYCLE] = F::ONE;
|
||||
if i < keccak_rows {
|
||||
cpu_trace_row[cpu::columns::OPCODE] = column_to_copy[i];
|
||||
} else {
|
||||
cpu_trace_row[cpu::columns::OPCODE] = default[0];
|
||||
}
|
||||
cpu_stark.generate(&mut cpu_trace_row);
|
||||
cpu_trace_rows.push(cpu_trace_row);
|
||||
}
|
||||
let cpu_trace = trace_rows_to_poly_values(cpu_trace_rows);
|
||||
|
||||
let mut keccak_trace =
|
||||
vec![PolynomialValues::zero(keccak_rows); KeccakStark::<F, D>::COLUMNS];
|
||||
keccak_trace[keccak_looked_col] = cpu_trace[cpu::columns::OPCODE].clone();
|
||||
|
||||
let default = vec![F::ZERO; 2];
|
||||
let cross_table_lookups = vec![CrossTableLookup {
|
||||
looking_tables: vec![TableWithColumns::new(
|
||||
Table::Cpu,
|
||||
vec![cpu::columns::OPCODE],
|
||||
vec![],
|
||||
)],
|
||||
looked_table: TableWithColumns::new(Table::Keccak, vec![keccak_looked_col], vec![]),
|
||||
default: Some(default),
|
||||
}];
|
||||
// TODO: temporary until cross-table-lookup filters are implemented
|
||||
let cross_table_lookups = vec![CrossTableLookup::new(
|
||||
vec![Table::Cpu],
|
||||
vec![vec![cpu::columns::OPCODE]],
|
||||
Table::Keccak,
|
||||
vec![keccak_looked_col],
|
||||
default,
|
||||
)];
|
||||
|
||||
let all_stark = AllStark {
|
||||
cpu_stark,
|
||||
|
||||
@ -48,10 +48,10 @@ impl TableWithColumns {
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct CrossTableLookup<F: Field> {
|
||||
pub looking_tables: Vec<TableWithColumns>,
|
||||
pub looked_table: TableWithColumns,
|
||||
looking_tables: Vec<TableWithColumns>,
|
||||
looked_table: TableWithColumns,
|
||||
/// Default value if filters are not used.
|
||||
pub default: Option<Vec<F>>,
|
||||
default: Option<Vec<F>>,
|
||||
}
|
||||
|
||||
impl<F: Field> CrossTableLookup<F> {
|
||||
@ -70,6 +70,7 @@ impl<F: Field> CrossTableLookup<F> {
|
||||
== default.is_some()
|
||||
&& default.is_some() == looked_table.filter_columns.is_empty()
|
||||
);
|
||||
assert!(default.len() == looked_columns.len());
|
||||
Self {
|
||||
looking_tables,
|
||||
looked_table,
|
||||
|
||||
157
evm/src/keccak/constants.rs
Normal file
157
evm/src/keccak/constants.rs
Normal file
@ -0,0 +1,157 @@
|
||||
const RC: [u64; 24] = [
|
||||
0x0000000000000001,
|
||||
0x0000000000008082,
|
||||
0x800000000000808A,
|
||||
0x8000000080008000,
|
||||
0x000000000000808B,
|
||||
0x0000000080000001,
|
||||
0x8000000080008081,
|
||||
0x8000000000008009,
|
||||
0x000000000000008A,
|
||||
0x0000000000000088,
|
||||
0x0000000080008009,
|
||||
0x000000008000000A,
|
||||
0x000000008000808B,
|
||||
0x800000000000008B,
|
||||
0x8000000000008089,
|
||||
0x8000000000008003,
|
||||
0x8000000000008002,
|
||||
0x8000000000000080,
|
||||
0x000000000000800A,
|
||||
0x800000008000000A,
|
||||
0x8000000080008081,
|
||||
0x8000000000008080,
|
||||
0x0000000080000001,
|
||||
0x8000000080008008,
|
||||
];
|
||||
|
||||
const RC_BITS: [[u8; 64]; 24] = [
|
||||
[
|
||||
1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
||||
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
||||
0, 0, 0, 0,
|
||||
],
|
||||
[
|
||||
0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
||||
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
||||
0, 0, 0, 0,
|
||||
],
|
||||
[
|
||||
0, 1, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
||||
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
||||
0, 0, 0, 1,
|
||||
],
|
||||
[
|
||||
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
||||
0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
||||
0, 0, 0, 1,
|
||||
],
|
||||
[
|
||||
1, 1, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
||||
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
||||
0, 0, 0, 0,
|
||||
],
|
||||
[
|
||||
1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
||||
0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
||||
0, 0, 0, 0,
|
||||
],
|
||||
[
|
||||
1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
||||
0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
||||
0, 0, 0, 1,
|
||||
],
|
||||
[
|
||||
1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
||||
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
||||
0, 0, 0, 1,
|
||||
],
|
||||
[
|
||||
0, 1, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
||||
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
||||
0, 0, 0, 0,
|
||||
],
|
||||
[
|
||||
0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
||||
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
||||
0, 0, 0, 0,
|
||||
],
|
||||
[
|
||||
1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
||||
0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
||||
0, 0, 0, 0,
|
||||
],
|
||||
[
|
||||
0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
||||
0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
||||
0, 0, 0, 0,
|
||||
],
|
||||
[
|
||||
1, 1, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
||||
0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
||||
0, 0, 0, 0,
|
||||
],
|
||||
[
|
||||
1, 1, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
||||
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
||||
0, 0, 0, 1,
|
||||
],
|
||||
[
|
||||
1, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
||||
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
||||
0, 0, 0, 1,
|
||||
],
|
||||
[
|
||||
1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
||||
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
||||
0, 0, 0, 1,
|
||||
],
|
||||
[
|
||||
0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
||||
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
||||
0, 0, 0, 1,
|
||||
],
|
||||
[
|
||||
0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
||||
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
||||
0, 0, 0, 1,
|
||||
],
|
||||
[
|
||||
0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
||||
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
||||
0, 0, 0, 0,
|
||||
],
|
||||
[
|
||||
0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
||||
0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
||||
0, 0, 0, 1,
|
||||
],
|
||||
[
|
||||
1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
||||
0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
||||
0, 0, 0, 1,
|
||||
],
|
||||
[
|
||||
0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
||||
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
||||
0, 0, 0, 1,
|
||||
],
|
||||
[
|
||||
1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
||||
0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
||||
0, 0, 0, 0,
|
||||
],
|
||||
[
|
||||
0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
||||
0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
||||
0, 0, 0, 1,
|
||||
],
|
||||
];
|
||||
|
||||
pub(crate) const fn rc_value_bit(round: usize, bit_index: usize) -> u8 {
|
||||
RC_BITS[round][bit_index]
|
||||
}
|
||||
|
||||
pub(crate) const fn rc_value(round: usize) -> u64 {
|
||||
RC[round]
|
||||
}
|
||||
@ -1,60 +1,522 @@
|
||||
use std::marker::PhantomData;
|
||||
|
||||
use itertools::Itertools;
|
||||
use log::info;
|
||||
use plonky2::field::extension_field::{Extendable, FieldExtension};
|
||||
use plonky2::field::packed_field::PackedField;
|
||||
use plonky2::field::polynomial::PolynomialValues;
|
||||
use plonky2::hash::hash_types::RichField;
|
||||
use plonky2::plonk::plonk_common::reduce_with_powers_ext_circuit;
|
||||
use plonky2::timed;
|
||||
use plonky2::util::timing::TimingTree;
|
||||
|
||||
use crate::constraint_consumer::{ConstraintConsumer, RecursiveConstraintConsumer};
|
||||
use crate::permutation::PermutationPair;
|
||||
use crate::keccak::constants::{rc_value, rc_value_bit};
|
||||
use crate::keccak::logic::{
|
||||
andn, andn_gen, andn_gen_circuit, xor, xor3_gen, xor3_gen_circuit, xor_gen, xor_gen_circuit,
|
||||
};
|
||||
use crate::keccak::registers::{
|
||||
reg_a, reg_a_prime, reg_a_prime_prime, reg_a_prime_prime_0_0_bit, reg_a_prime_prime_prime,
|
||||
reg_b, reg_c, reg_c_partial, reg_step, NUM_REGISTERS,
|
||||
};
|
||||
use crate::keccak::round_flags::{eval_round_flags, eval_round_flags_recursively};
|
||||
use crate::stark::Stark;
|
||||
use crate::util::trace_rows_to_poly_values;
|
||||
use crate::vars::{StarkEvaluationTargets, StarkEvaluationVars};
|
||||
|
||||
/// Number of rounds in a Keccak permutation.
|
||||
pub(crate) const NUM_ROUNDS: usize = 24;
|
||||
|
||||
/// Number of 64-bit limbs in a preimage of the Keccak permutation.
|
||||
pub(crate) const INPUT_LIMBS: usize = 25;
|
||||
|
||||
pub(crate) const NUM_PUBLIC_INPUTS: usize = 0;
|
||||
|
||||
#[derive(Copy, Clone)]
|
||||
pub struct KeccakStark<F, const D: usize> {
|
||||
pub(crate) f: PhantomData<F>,
|
||||
}
|
||||
|
||||
impl<F: RichField + Extendable<D>, const D: usize> KeccakStark<F, D> {
|
||||
/// Generate the rows of the trace. Note that this does not generate the permuted columns used
|
||||
/// in our lookup arguments, as those are computed after transposing to column-wise form.
|
||||
pub(crate) fn generate_trace_rows(
|
||||
&self,
|
||||
inputs: Vec<[u64; INPUT_LIMBS]>,
|
||||
) -> Vec<[F; NUM_REGISTERS]> {
|
||||
let num_rows = (inputs.len() * NUM_ROUNDS).next_power_of_two();
|
||||
info!("{} rows", num_rows);
|
||||
let mut rows = Vec::with_capacity(num_rows);
|
||||
for input in inputs.iter() {
|
||||
rows.extend(self.generate_trace_rows_for_perm(*input));
|
||||
}
|
||||
|
||||
// Pad rows to power of two.
|
||||
for i in rows.len()..num_rows {
|
||||
let mut row = [F::ZERO; NUM_REGISTERS];
|
||||
self.copy_output_to_input(rows[i - 1], &mut row);
|
||||
self.generate_trace_row_for_round(&mut row, i % NUM_ROUNDS);
|
||||
rows.push(row);
|
||||
}
|
||||
|
||||
rows
|
||||
}
|
||||
|
||||
fn generate_trace_rows_for_perm(&self, input: [u64; INPUT_LIMBS]) -> Vec<[F; NUM_REGISTERS]> {
|
||||
let mut rows = vec![[F::ZERO; NUM_REGISTERS]; NUM_ROUNDS];
|
||||
|
||||
for x in 0..5 {
|
||||
for y in 0..5 {
|
||||
let input_xy = input[x * 5 + y];
|
||||
for z in 0..64 {
|
||||
rows[0][reg_a(x, y, z)] = F::from_canonical_u64((input_xy >> z) & 1);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
self.generate_trace_row_for_round(&mut rows[0], 0);
|
||||
for round in 1..24 {
|
||||
self.copy_output_to_input(rows[round - 1], &mut rows[round]);
|
||||
self.generate_trace_row_for_round(&mut rows[round], round);
|
||||
}
|
||||
|
||||
rows
|
||||
}
|
||||
|
||||
fn copy_output_to_input(
|
||||
&self,
|
||||
prev_row: [F; NUM_REGISTERS],
|
||||
next_row: &mut [F; NUM_REGISTERS],
|
||||
) {
|
||||
for x in 0..5 {
|
||||
for y in 0..5 {
|
||||
let cur_lo = prev_row[reg_a_prime_prime_prime(x, y)];
|
||||
let cur_hi = prev_row[reg_a_prime_prime_prime(x, y) + 1];
|
||||
let cur_u64 = cur_lo.to_canonical_u64() | (cur_hi.to_canonical_u64() << 32);
|
||||
let bit_values: Vec<u64> = (0..64)
|
||||
.scan(cur_u64, |acc, _| {
|
||||
let tmp = *acc & 1;
|
||||
*acc >>= 1;
|
||||
Some(tmp)
|
||||
})
|
||||
.collect();
|
||||
|
||||
for z in 0..64 {
|
||||
next_row[reg_a(x, y, z)] = F::from_canonical_u64(bit_values[z]);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn generate_trace_row_for_round(&self, row: &mut [F; NUM_REGISTERS], round: usize) {
|
||||
row[reg_step(round)] = F::ONE;
|
||||
|
||||
// Populate C partial and C.
|
||||
for x in 0..5 {
|
||||
for z in 0..64 {
|
||||
let a = [0, 1, 2, 3, 4].map(|i| row[reg_a(x, i, z)]);
|
||||
let c_partial = xor([a[0], a[1], a[2]]);
|
||||
let c = xor([c_partial, a[3], a[4]]);
|
||||
row[reg_c_partial(x, z)] = c_partial;
|
||||
row[reg_c(x, z)] = c;
|
||||
}
|
||||
}
|
||||
|
||||
// Populate A'.
|
||||
// A'[x, y] = xor(A[x, y], D[x])
|
||||
// = xor(A[x, y], C[x - 1], ROT(C[x + 1], 1))
|
||||
for x in 0..5 {
|
||||
for y in 0..5 {
|
||||
for z in 0..64 {
|
||||
row[reg_a_prime(x, y, z)] = xor([
|
||||
row[reg_a(x, y, z)],
|
||||
row[reg_c((x + 4) % 5, z)],
|
||||
row[reg_c((x + 1) % 5, (z + 64 - 1) % 64)],
|
||||
]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Populate A''.
|
||||
// A''[x, y] = xor(B[x, y], andn(B[x + 1, y], B[x + 2, y])).
|
||||
for x in 0..5 {
|
||||
for y in 0..5 {
|
||||
let get_bit = |z| {
|
||||
xor([
|
||||
row[reg_b(x, y, z)],
|
||||
andn(row[reg_b((x + 1) % 5, y, z)], row[reg_b((x + 2) % 5, y, z)]),
|
||||
])
|
||||
};
|
||||
|
||||
let lo = (0..32)
|
||||
.rev()
|
||||
.fold(F::ZERO, |acc, z| acc.double() + get_bit(z));
|
||||
let hi = (32..64)
|
||||
.rev()
|
||||
.fold(F::ZERO, |acc, z| acc.double() + get_bit(z));
|
||||
|
||||
let reg_lo = reg_a_prime_prime(x, y);
|
||||
let reg_hi = reg_lo + 1;
|
||||
row[reg_lo] = lo;
|
||||
row[reg_hi] = hi;
|
||||
}
|
||||
}
|
||||
|
||||
// For the XOR, we split A''[0, 0] to bits.
|
||||
let val_lo = row[reg_a_prime_prime(0, 0)].to_canonical_u64();
|
||||
let val_hi = row[reg_a_prime_prime(0, 0) + 1].to_canonical_u64();
|
||||
let val = val_lo | (val_hi << 32);
|
||||
let bit_values: Vec<u64> = (0..64)
|
||||
.scan(val, |acc, _| {
|
||||
let tmp = *acc & 1;
|
||||
*acc >>= 1;
|
||||
Some(tmp)
|
||||
})
|
||||
.collect();
|
||||
for i in 0..64 {
|
||||
row[reg_a_prime_prime_0_0_bit(i)] = F::from_canonical_u64(bit_values[i]);
|
||||
}
|
||||
|
||||
// A''[0, 0] is additionally xor'd with RC.
|
||||
let in_reg_lo = reg_a_prime_prime(0, 0);
|
||||
let in_reg_hi = in_reg_lo + 1;
|
||||
let out_reg_lo = reg_a_prime_prime_prime(0, 0);
|
||||
let out_reg_hi = out_reg_lo + 1;
|
||||
let rc_lo = rc_value(round) & ((1 << 32) - 1);
|
||||
let rc_hi = rc_value(round) >> 32;
|
||||
row[out_reg_lo] = F::from_canonical_u64(row[in_reg_lo].to_canonical_u64() ^ rc_lo);
|
||||
row[out_reg_hi] = F::from_canonical_u64(row[in_reg_hi].to_canonical_u64() ^ rc_hi);
|
||||
}
|
||||
|
||||
pub fn generate_trace(&self, inputs: Vec<[u64; INPUT_LIMBS]>) -> Vec<PolynomialValues<F>> {
|
||||
let mut timing = TimingTree::new("generate trace", log::Level::Debug);
|
||||
|
||||
// Generate the witness, except for permuted columns in the lookup argument.
|
||||
let trace_rows = timed!(
|
||||
&mut timing,
|
||||
"generate trace rows",
|
||||
self.generate_trace_rows(inputs)
|
||||
);
|
||||
|
||||
let trace_polys = timed!(
|
||||
&mut timing,
|
||||
"convert to PolynomialValues",
|
||||
trace_rows_to_poly_values(trace_rows)
|
||||
);
|
||||
|
||||
timing.print();
|
||||
trace_polys
|
||||
}
|
||||
}
|
||||
|
||||
impl<F: RichField + Extendable<D>, const D: usize> Stark<F, D> for KeccakStark<F, D> {
|
||||
const COLUMNS: usize = 7;
|
||||
const PUBLIC_INPUTS: usize = 0;
|
||||
const COLUMNS: usize = NUM_REGISTERS;
|
||||
const PUBLIC_INPUTS: usize = NUM_PUBLIC_INPUTS;
|
||||
|
||||
fn eval_packed_generic<FE, P, const D2: usize>(
|
||||
&self,
|
||||
_vars: StarkEvaluationVars<FE, P, { Self::COLUMNS }, { Self::PUBLIC_INPUTS }>,
|
||||
_yield_constr: &mut ConstraintConsumer<P>,
|
||||
vars: StarkEvaluationVars<FE, P, { Self::COLUMNS }, { Self::PUBLIC_INPUTS }>,
|
||||
yield_constr: &mut ConstraintConsumer<P>,
|
||||
) where
|
||||
FE: FieldExtension<D2, BaseField = F>,
|
||||
P: PackedField<Scalar = FE>,
|
||||
{
|
||||
eval_round_flags(vars, yield_constr);
|
||||
|
||||
// C_partial[x] = xor(A[x, 0], A[x, 1], A[x, 2])
|
||||
for x in 0..5 {
|
||||
for z in 0..64 {
|
||||
let c_partial = vars.local_values[reg_c_partial(x, z)];
|
||||
let a_0 = vars.local_values[reg_a(x, 0, z)];
|
||||
let a_1 = vars.local_values[reg_a(x, 1, z)];
|
||||
let a_2 = vars.local_values[reg_a(x, 2, z)];
|
||||
let xor_012 = xor3_gen(a_0, a_1, a_2);
|
||||
yield_constr.constraint(c_partial - xor_012);
|
||||
}
|
||||
}
|
||||
|
||||
// C[x] = xor(C_partial[x], A[x, 3], A[x, 4])
|
||||
for x in 0..5 {
|
||||
for z in 0..64 {
|
||||
let c = vars.local_values[reg_c(x, z)];
|
||||
let xor_012 = vars.local_values[reg_c_partial(x, z)];
|
||||
let a_3 = vars.local_values[reg_a(x, 3, z)];
|
||||
let a_4 = vars.local_values[reg_a(x, 4, z)];
|
||||
let xor_01234 = xor3_gen(xor_012, a_3, a_4);
|
||||
yield_constr.constraint(c - xor_01234);
|
||||
}
|
||||
}
|
||||
|
||||
// A'[x, y] = xor(A[x, y], D[x])
|
||||
// = xor(A[x, y], C[x - 1], ROT(C[x + 1], 1))
|
||||
for x in 0..5 {
|
||||
for z in 0..64 {
|
||||
let c_left = vars.local_values[reg_c((x + 4) % 5, z)];
|
||||
let c_right = vars.local_values[reg_c((x + 1) % 5, (z + 64 - 1) % 64)];
|
||||
let d = xor_gen(c_left, c_right);
|
||||
|
||||
for y in 0..5 {
|
||||
let a = vars.local_values[reg_a(x, y, z)];
|
||||
let a_prime = vars.local_values[reg_a_prime(x, y, z)];
|
||||
let xor = xor_gen(d, a);
|
||||
yield_constr.constraint(a_prime - xor);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// A''[x, y] = xor(B[x, y], andn(B[x + 1, y], B[x + 2, y])).
|
||||
for x in 0..5 {
|
||||
for y in 0..5 {
|
||||
let get_bit = |z| {
|
||||
xor_gen(
|
||||
vars.local_values[reg_b(x, y, z)],
|
||||
andn_gen(
|
||||
vars.local_values[reg_b((x + 1) % 5, y, z)],
|
||||
vars.local_values[reg_b((x + 2) % 5, y, z)],
|
||||
),
|
||||
)
|
||||
};
|
||||
|
||||
let reg_lo = reg_a_prime_prime(x, y);
|
||||
let reg_hi = reg_lo + 1;
|
||||
let lo = vars.local_values[reg_lo];
|
||||
let hi = vars.local_values[reg_hi];
|
||||
let computed_lo = (0..32)
|
||||
.rev()
|
||||
.fold(P::ZEROS, |acc, z| acc.doubles() + get_bit(z));
|
||||
let computed_hi = (32..64)
|
||||
.rev()
|
||||
.fold(P::ZEROS, |acc, z| acc.doubles() + get_bit(z));
|
||||
|
||||
yield_constr.constraint(computed_lo - lo);
|
||||
yield_constr.constraint(computed_hi - hi);
|
||||
}
|
||||
}
|
||||
|
||||
// A'''[0, 0] = A''[0, 0] XOR RC
|
||||
let a_prime_prime_0_0_bits = (0..64)
|
||||
.map(|i| vars.local_values[reg_a_prime_prime_0_0_bit(i)])
|
||||
.collect_vec();
|
||||
let computed_a_prime_prime_0_0_lo = (0..32)
|
||||
.rev()
|
||||
.fold(P::ZEROS, |acc, z| acc.doubles() + a_prime_prime_0_0_bits[z]);
|
||||
let computed_a_prime_prime_0_0_hi = (32..64)
|
||||
.rev()
|
||||
.fold(P::ZEROS, |acc, z| acc.doubles() + a_prime_prime_0_0_bits[z]);
|
||||
let a_prime_prime_0_0_lo = vars.local_values[reg_a_prime_prime(0, 0)];
|
||||
let a_prime_prime_0_0_hi = vars.local_values[reg_a_prime_prime(0, 0) + 1];
|
||||
yield_constr.constraint(computed_a_prime_prime_0_0_lo - a_prime_prime_0_0_lo);
|
||||
yield_constr.constraint(computed_a_prime_prime_0_0_hi - a_prime_prime_0_0_hi);
|
||||
|
||||
let get_xored_bit = |i| {
|
||||
let mut rc_bit_i = P::ZEROS;
|
||||
for r in 0..NUM_ROUNDS {
|
||||
let this_round = vars.local_values[reg_step(r)];
|
||||
let this_round_constant =
|
||||
P::from(FE::from_canonical_u32(rc_value_bit(r, i) as u32));
|
||||
rc_bit_i += this_round * this_round_constant;
|
||||
}
|
||||
|
||||
xor_gen(a_prime_prime_0_0_bits[i], rc_bit_i)
|
||||
};
|
||||
|
||||
let a_prime_prime_prime_0_0_lo = vars.local_values[reg_a_prime_prime_prime(0, 0)];
|
||||
let a_prime_prime_prime_0_0_hi = vars.local_values[reg_a_prime_prime_prime(0, 0) + 1];
|
||||
let computed_a_prime_prime_prime_0_0_lo = (0..32)
|
||||
.rev()
|
||||
.fold(P::ZEROS, |acc, z| acc.doubles() + get_xored_bit(z));
|
||||
let computed_a_prime_prime_prime_0_0_hi = (32..64)
|
||||
.rev()
|
||||
.fold(P::ZEROS, |acc, z| acc.doubles() + get_xored_bit(z));
|
||||
yield_constr.constraint(computed_a_prime_prime_prime_0_0_lo - a_prime_prime_prime_0_0_lo);
|
||||
yield_constr.constraint(computed_a_prime_prime_prime_0_0_hi - a_prime_prime_prime_0_0_hi);
|
||||
|
||||
// Enforce that this round's output equals the next round's input.
|
||||
for x in 0..5 {
|
||||
for y in 0..5 {
|
||||
let output_lo = vars.local_values[reg_a_prime_prime_prime(x, y)];
|
||||
let output_hi = vars.local_values[reg_a_prime_prime_prime(x, y) + 1];
|
||||
let input_bits = (0..64)
|
||||
.map(|z| vars.next_values[reg_a(x, y, z)])
|
||||
.collect_vec();
|
||||
let input_bits_combined_lo = (0..32)
|
||||
.rev()
|
||||
.fold(P::ZEROS, |acc, z| acc.doubles() + input_bits[z]);
|
||||
let input_bits_combined_hi = (32..64)
|
||||
.rev()
|
||||
.fold(P::ZEROS, |acc, z| acc.doubles() + input_bits[z]);
|
||||
yield_constr.constraint_transition(output_lo - input_bits_combined_lo);
|
||||
yield_constr.constraint_transition(output_hi - input_bits_combined_hi);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn eval_ext_circuit(
|
||||
&self,
|
||||
_builder: &mut plonky2::plonk::circuit_builder::CircuitBuilder<F, D>,
|
||||
_vars: StarkEvaluationTargets<D, { Self::COLUMNS }, { Self::PUBLIC_INPUTS }>,
|
||||
_yield_constr: &mut RecursiveConstraintConsumer<F, D>,
|
||||
builder: &mut plonky2::plonk::circuit_builder::CircuitBuilder<F, D>,
|
||||
vars: StarkEvaluationTargets<D, { Self::COLUMNS }, { Self::PUBLIC_INPUTS }>,
|
||||
yield_constr: &mut RecursiveConstraintConsumer<F, D>,
|
||||
) {
|
||||
let two = builder.two();
|
||||
|
||||
eval_round_flags_recursively(builder, vars, yield_constr);
|
||||
|
||||
// C_partial[x] = xor(A[x, 0], A[x, 1], A[x, 2])
|
||||
for x in 0..5 {
|
||||
for z in 0..64 {
|
||||
let c_partial = vars.local_values[reg_c_partial(x, z)];
|
||||
let a_0 = vars.local_values[reg_a(x, 0, z)];
|
||||
let a_1 = vars.local_values[reg_a(x, 1, z)];
|
||||
let a_2 = vars.local_values[reg_a(x, 2, z)];
|
||||
|
||||
let xor_012 = xor3_gen_circuit(builder, a_0, a_1, a_2);
|
||||
let diff = builder.sub_extension(c_partial, xor_012);
|
||||
yield_constr.constraint(builder, diff);
|
||||
}
|
||||
}
|
||||
|
||||
// C[x] = xor(C_partial[x], A[x, 3], A[x, 4])
|
||||
for x in 0..5 {
|
||||
for z in 0..64 {
|
||||
let c = vars.local_values[reg_c(x, z)];
|
||||
let xor_012 = vars.local_values[reg_c_partial(x, z)];
|
||||
let a_3 = vars.local_values[reg_a(x, 3, z)];
|
||||
let a_4 = vars.local_values[reg_a(x, 4, z)];
|
||||
|
||||
let xor_01234 = xor3_gen_circuit(builder, xor_012, a_3, a_4);
|
||||
let diff = builder.sub_extension(c, xor_01234);
|
||||
yield_constr.constraint(builder, diff);
|
||||
}
|
||||
}
|
||||
|
||||
// A'[x, y] = xor(A[x, y], D[x])
|
||||
// = xor(A[x, y], C[x - 1], ROT(C[x + 1], 1))
|
||||
for x in 0..5 {
|
||||
for z in 0..64 {
|
||||
let c_left = vars.local_values[reg_c((x + 4) % 5, z)];
|
||||
let c_right = vars.local_values[reg_c((x + 1) % 5, (z + 64 - 1) % 64)];
|
||||
let d = xor_gen_circuit(builder, c_left, c_right);
|
||||
|
||||
for y in 0..5 {
|
||||
let a = vars.local_values[reg_a(x, y, z)];
|
||||
let a_prime = vars.local_values[reg_a_prime(x, y, z)];
|
||||
let xor = xor_gen_circuit(builder, d, a);
|
||||
let diff = builder.sub_extension(a_prime, xor);
|
||||
yield_constr.constraint(builder, diff);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// A''[x, y] = xor(B[x, y], andn(B[x + 1, y], B[x + 2, y])).
|
||||
for x in 0..5 {
|
||||
for y in 0..5 {
|
||||
let mut get_bit = |z| {
|
||||
let andn = andn_gen_circuit(
|
||||
builder,
|
||||
vars.local_values[reg_b((x + 1) % 5, y, z)],
|
||||
vars.local_values[reg_b((x + 2) % 5, y, z)],
|
||||
);
|
||||
xor_gen_circuit(builder, vars.local_values[reg_b(x, y, z)], andn)
|
||||
};
|
||||
|
||||
let reg_lo = reg_a_prime_prime(x, y);
|
||||
let reg_hi = reg_lo + 1;
|
||||
let lo = vars.local_values[reg_lo];
|
||||
let hi = vars.local_values[reg_hi];
|
||||
let bits_lo = (0..32).map(&mut get_bit).collect_vec();
|
||||
let bits_hi = (32..64).map(get_bit).collect_vec();
|
||||
let computed_lo = reduce_with_powers_ext_circuit(builder, &bits_lo, two);
|
||||
let computed_hi = reduce_with_powers_ext_circuit(builder, &bits_hi, two);
|
||||
let diff = builder.sub_extension(computed_lo, lo);
|
||||
yield_constr.constraint(builder, diff);
|
||||
let diff = builder.sub_extension(computed_hi, hi);
|
||||
yield_constr.constraint(builder, diff);
|
||||
}
|
||||
}
|
||||
|
||||
// A'''[0, 0] = A''[0, 0] XOR RC
|
||||
let a_prime_prime_0_0_bits = (0..64)
|
||||
.map(|i| vars.local_values[reg_a_prime_prime_0_0_bit(i)])
|
||||
.collect_vec();
|
||||
let computed_a_prime_prime_0_0_lo =
|
||||
reduce_with_powers_ext_circuit(builder, &a_prime_prime_0_0_bits[0..32], two);
|
||||
let computed_a_prime_prime_0_0_hi =
|
||||
reduce_with_powers_ext_circuit(builder, &a_prime_prime_0_0_bits[32..64], two);
|
||||
let a_prime_prime_0_0_lo = vars.local_values[reg_a_prime_prime(0, 0)];
|
||||
let a_prime_prime_0_0_hi = vars.local_values[reg_a_prime_prime(0, 0) + 1];
|
||||
let diff = builder.sub_extension(computed_a_prime_prime_0_0_lo, a_prime_prime_0_0_lo);
|
||||
yield_constr.constraint(builder, diff);
|
||||
let diff = builder.sub_extension(computed_a_prime_prime_0_0_hi, a_prime_prime_0_0_hi);
|
||||
yield_constr.constraint(builder, diff);
|
||||
|
||||
let mut get_xored_bit = |i| {
|
||||
let mut rc_bit_i = builder.zero_extension();
|
||||
for r in 0..NUM_ROUNDS {
|
||||
let this_round = vars.local_values[reg_step(r)];
|
||||
let this_round_constant = builder
|
||||
.constant_extension(F::from_canonical_u32(rc_value_bit(r, i) as u32).into());
|
||||
rc_bit_i = builder.mul_add_extension(this_round, this_round_constant, rc_bit_i);
|
||||
}
|
||||
|
||||
xor_gen_circuit(builder, a_prime_prime_0_0_bits[i], rc_bit_i)
|
||||
};
|
||||
|
||||
let a_prime_prime_prime_0_0_lo = vars.local_values[reg_a_prime_prime_prime(0, 0)];
|
||||
let a_prime_prime_prime_0_0_hi = vars.local_values[reg_a_prime_prime_prime(0, 0) + 1];
|
||||
let bits_lo = (0..32).map(&mut get_xored_bit).collect_vec();
|
||||
let bits_hi = (32..64).map(get_xored_bit).collect_vec();
|
||||
let computed_a_prime_prime_prime_0_0_lo =
|
||||
reduce_with_powers_ext_circuit(builder, &bits_lo, two);
|
||||
let computed_a_prime_prime_prime_0_0_hi =
|
||||
reduce_with_powers_ext_circuit(builder, &bits_hi, two);
|
||||
let diff = builder.sub_extension(
|
||||
computed_a_prime_prime_prime_0_0_lo,
|
||||
a_prime_prime_prime_0_0_lo,
|
||||
);
|
||||
yield_constr.constraint(builder, diff);
|
||||
let diff = builder.sub_extension(
|
||||
computed_a_prime_prime_prime_0_0_hi,
|
||||
a_prime_prime_prime_0_0_hi,
|
||||
);
|
||||
yield_constr.constraint(builder, diff);
|
||||
|
||||
// Enforce that this round's output equals the next round's input.
|
||||
for x in 0..5 {
|
||||
for y in 0..5 {
|
||||
let output_lo = vars.local_values[reg_a_prime_prime_prime(x, y)];
|
||||
let output_hi = vars.local_values[reg_a_prime_prime_prime(x, y) + 1];
|
||||
let input_bits = (0..64)
|
||||
.map(|z| vars.next_values[reg_a(x, y, z)])
|
||||
.collect_vec();
|
||||
let input_bits_combined_lo =
|
||||
reduce_with_powers_ext_circuit(builder, &input_bits[0..32], two);
|
||||
let input_bits_combined_hi =
|
||||
reduce_with_powers_ext_circuit(builder, &input_bits[32..64], two);
|
||||
let diff = builder.sub_extension(output_lo, input_bits_combined_lo);
|
||||
yield_constr.constraint_transition(builder, diff);
|
||||
let diff = builder.sub_extension(output_hi, input_bits_combined_hi);
|
||||
yield_constr.constraint_transition(builder, diff);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn constraint_degree(&self) -> usize {
|
||||
3
|
||||
}
|
||||
|
||||
fn permutation_pairs(&self) -> Vec<PermutationPair> {
|
||||
vec![PermutationPair::singletons(0, 6)]
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use anyhow::Result;
|
||||
use keccak_rust::{KeccakF, StateBitsWidth};
|
||||
use plonky2::field::field_types::Field;
|
||||
use plonky2::plonk::config::{GenericConfig, PoseidonGoldilocksConfig};
|
||||
|
||||
use crate::keccak::keccak_stark::KeccakStark;
|
||||
use crate::keccak::keccak_stark::{KeccakStark, INPUT_LIMBS, NUM_ROUNDS};
|
||||
use crate::keccak::registers::reg_a_prime_prime_prime;
|
||||
use crate::stark_testing::{test_stark_circuit_constraints, test_stark_low_degree};
|
||||
|
||||
#[test]
|
||||
#[ignore] // TODO: remove this when constraints are no longer all 0.
|
||||
fn test_stark_degree() -> Result<()> {
|
||||
const D: usize = 2;
|
||||
type C = PoseidonGoldilocksConfig;
|
||||
@ -79,4 +541,51 @@ mod tests {
|
||||
};
|
||||
test_stark_circuit_constraints::<F, C, S, D>(stark)
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn keccak_correctness_test() -> Result<()> {
|
||||
let input: [u64; INPUT_LIMBS] = rand::random();
|
||||
|
||||
const D: usize = 2;
|
||||
type C = PoseidonGoldilocksConfig;
|
||||
type F = <C as GenericConfig<D>>::F;
|
||||
type S = KeccakStark<F, D>;
|
||||
|
||||
let stark = S {
|
||||
f: Default::default(),
|
||||
};
|
||||
|
||||
let rows = stark.generate_trace_rows(vec![input.try_into().unwrap()]);
|
||||
let last_row = rows[NUM_ROUNDS - 1];
|
||||
let mut output = Vec::new();
|
||||
let base = F::from_canonical_u64(1 << 32);
|
||||
for x in 0..5 {
|
||||
for y in 0..5 {
|
||||
output.push(
|
||||
last_row[reg_a_prime_prime_prime(x, y)]
|
||||
+ base * last_row[reg_a_prime_prime_prime(x, y) + 1],
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
let mut keccak_input: [[u64; 5]; 5] = [
|
||||
input[0..5].try_into().unwrap(),
|
||||
input[5..10].try_into().unwrap(),
|
||||
input[10..15].try_into().unwrap(),
|
||||
input[15..20].try_into().unwrap(),
|
||||
input[20..25].try_into().unwrap(),
|
||||
];
|
||||
|
||||
let keccak = KeccakF::new(StateBitsWidth::F1600);
|
||||
keccak.permutations(&mut keccak_input);
|
||||
let expected: Vec<_> = keccak_input
|
||||
.iter()
|
||||
.flatten()
|
||||
.map(|&x| F::from_canonical_u64(x))
|
||||
.collect();
|
||||
|
||||
assert_eq!(output, expected);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
65
evm/src/keccak/logic.rs
Normal file
65
evm/src/keccak/logic.rs
Normal file
@ -0,0 +1,65 @@
|
||||
use plonky2::field::extension_field::Extendable;
|
||||
use plonky2::field::field_types::PrimeField64;
|
||||
use plonky2::field::packed_field::PackedField;
|
||||
use plonky2::hash::hash_types::RichField;
|
||||
use plonky2::iop::ext_target::ExtensionTarget;
|
||||
use plonky2::plonk::circuit_builder::CircuitBuilder;
|
||||
|
||||
pub(crate) fn xor<F: PrimeField64, const N: usize>(xs: [F; N]) -> F {
|
||||
xs.into_iter().fold(F::ZERO, |acc, x| {
|
||||
debug_assert!(x.is_zero() || x.is_one());
|
||||
F::from_canonical_u64(acc.to_canonical_u64() ^ x.to_canonical_u64())
|
||||
})
|
||||
}
|
||||
|
||||
/// Computes the arithmetic generalization of `xor(x, y)`, i.e. `x + y - 2 x y`.
|
||||
pub(crate) fn xor_gen<P: PackedField>(x: P, y: P) -> P {
|
||||
x + y - x * y.doubles()
|
||||
}
|
||||
|
||||
/// Computes the arithmetic generalization of `xor3(x, y, z)`.
|
||||
pub(crate) fn xor3_gen<P: PackedField>(x: P, y: P, z: P) -> P {
|
||||
xor_gen(x, xor_gen(y, z))
|
||||
}
|
||||
|
||||
/// Computes the arithmetic generalization of `xor(x, y)`, i.e. `x + y - 2 x y`.
|
||||
pub(crate) fn xor_gen_circuit<F: RichField + Extendable<D>, const D: usize>(
|
||||
builder: &mut CircuitBuilder<F, D>,
|
||||
x: ExtensionTarget<D>,
|
||||
y: ExtensionTarget<D>,
|
||||
) -> ExtensionTarget<D> {
|
||||
let sum = builder.add_extension(x, y);
|
||||
builder.arithmetic_extension(-F::TWO, F::ONE, x, y, sum)
|
||||
}
|
||||
|
||||
/// Computes the arithmetic generalization of `xor(x, y)`, i.e. `x + y - 2 x y`.
|
||||
pub(crate) fn xor3_gen_circuit<F: RichField + Extendable<D>, const D: usize>(
|
||||
builder: &mut CircuitBuilder<F, D>,
|
||||
x: ExtensionTarget<D>,
|
||||
y: ExtensionTarget<D>,
|
||||
z: ExtensionTarget<D>,
|
||||
) -> ExtensionTarget<D> {
|
||||
let x_xor_y = xor_gen_circuit(builder, x, y);
|
||||
xor_gen_circuit(builder, x_xor_y, z)
|
||||
}
|
||||
|
||||
pub(crate) fn andn<F: PrimeField64>(x: F, y: F) -> F {
|
||||
debug_assert!(x.is_zero() || x.is_one());
|
||||
debug_assert!(y.is_zero() || y.is_one());
|
||||
let x = x.to_canonical_u64();
|
||||
let y = y.to_canonical_u64();
|
||||
F::from_canonical_u64(!x & y)
|
||||
}
|
||||
|
||||
pub(crate) fn andn_gen<P: PackedField>(x: P, y: P) -> P {
|
||||
(P::ONES - x) * y
|
||||
}
|
||||
|
||||
pub(crate) fn andn_gen_circuit<F: RichField + Extendable<D>, const D: usize>(
|
||||
builder: &mut CircuitBuilder<F, D>,
|
||||
x: ExtensionTarget<D>,
|
||||
y: ExtensionTarget<D>,
|
||||
) -> ExtensionTarget<D> {
|
||||
// (1 - x) y = -xy + y
|
||||
builder.arithmetic_extension(F::NEG_ONE, F::ONE, x, y, y)
|
||||
}
|
||||
@ -1 +1,5 @@
|
||||
pub mod constants;
|
||||
pub mod keccak_stark;
|
||||
pub mod logic;
|
||||
pub mod registers;
|
||||
pub mod round_flags;
|
||||
|
||||
94
evm/src/keccak/registers.rs
Normal file
94
evm/src/keccak/registers.rs
Normal file
@ -0,0 +1,94 @@
|
||||
use crate::keccak::keccak_stark::NUM_ROUNDS;
|
||||
|
||||
/// A register which is set to 1 if we are in the `i`th round, otherwise 0.
|
||||
pub(crate) const fn reg_step(i: usize) -> usize {
|
||||
debug_assert!(i < NUM_ROUNDS);
|
||||
i
|
||||
}
|
||||
|
||||
const R: [[u8; 5]; 5] = [
|
||||
[0, 36, 3, 41, 18],
|
||||
[1, 44, 10, 45, 2],
|
||||
[62, 6, 43, 15, 61],
|
||||
[28, 55, 25, 21, 56],
|
||||
[27, 20, 39, 8, 14],
|
||||
];
|
||||
|
||||
const START_A: usize = NUM_ROUNDS;
|
||||
pub(crate) const fn reg_a(x: usize, y: usize, z: usize) -> usize {
|
||||
debug_assert!(x < 5);
|
||||
debug_assert!(y < 5);
|
||||
debug_assert!(z < 64);
|
||||
START_A + x * 64 * 5 + y * 64 + z
|
||||
}
|
||||
|
||||
// C_partial[x] = xor(A[x, 0], A[x, 1], A[x, 2])
|
||||
const START_C_PARTIAL: usize = START_A + 5 * 5 * 64;
|
||||
pub(crate) const fn reg_c_partial(x: usize, z: usize) -> usize {
|
||||
START_C_PARTIAL + x * 64 + z
|
||||
}
|
||||
|
||||
// C[x] = xor(C_partial[x], A[x, 3], A[x, 4])
|
||||
const START_C: usize = START_C_PARTIAL + 5 * 64;
|
||||
pub(crate) const fn reg_c(x: usize, z: usize) -> usize {
|
||||
START_C + x * 64 + z
|
||||
}
|
||||
|
||||
// D is inlined.
|
||||
// const fn reg_d(x: usize, z: usize) {}
|
||||
|
||||
// A'[x, y] = xor(A[x, y], D[x])
|
||||
// = xor(A[x, y], C[x - 1], ROT(C[x + 1], 1))
|
||||
const START_A_PRIME: usize = START_C + 5 * 64;
|
||||
pub(crate) const fn reg_a_prime(x: usize, y: usize, z: usize) -> usize {
|
||||
debug_assert!(x < 5);
|
||||
debug_assert!(y < 5);
|
||||
debug_assert!(z < 64);
|
||||
START_A_PRIME + x * 64 * 5 + y * 64 + z
|
||||
}
|
||||
|
||||
pub(crate) const fn reg_b(x: usize, y: usize, z: usize) -> usize {
|
||||
debug_assert!(x < 5);
|
||||
debug_assert!(y < 5);
|
||||
debug_assert!(z < 64);
|
||||
// B is just a rotation of A', so these are aliases for A' registers.
|
||||
// From the spec,
|
||||
// B[y, (2x + 3y) % 5] = ROT(A'[x, y], r[x, y])
|
||||
// So,
|
||||
// B[x, y] = f((x + 3y) % 5, x)
|
||||
// where f(a, b) = ROT(A'[a, b], r[a, b])
|
||||
let a = (x + 3 * y) % 5;
|
||||
let b = x;
|
||||
let rot = R[a][b] as usize;
|
||||
reg_a_prime(a, b, (z + 64 - rot) % 64)
|
||||
}
|
||||
|
||||
// A''[x, y] = xor(B[x, y], andn(B[x + 1, y], B[x + 2, y])).
|
||||
const START_A_PRIME_PRIME: usize = START_A_PRIME + 5 * 5 * 64;
|
||||
pub(crate) const fn reg_a_prime_prime(x: usize, y: usize) -> usize {
|
||||
debug_assert!(x < 5);
|
||||
debug_assert!(y < 5);
|
||||
START_A_PRIME_PRIME + x * 2 * 5 + y * 2
|
||||
}
|
||||
|
||||
const START_A_PRIME_PRIME_0_0_BITS: usize = START_A_PRIME_PRIME + 5 * 5 * 2;
|
||||
pub(crate) const fn reg_a_prime_prime_0_0_bit(i: usize) -> usize {
|
||||
debug_assert!(i < 64);
|
||||
START_A_PRIME_PRIME_0_0_BITS + i
|
||||
}
|
||||
|
||||
const REG_A_PRIME_PRIME_PRIME_0_0_LO: usize = START_A_PRIME_PRIME_0_0_BITS + 64;
|
||||
const REG_A_PRIME_PRIME_PRIME_0_0_HI: usize = REG_A_PRIME_PRIME_PRIME_0_0_LO + 1;
|
||||
|
||||
// A'''[0, 0] is additionally xor'd with RC.
|
||||
pub(crate) const fn reg_a_prime_prime_prime(x: usize, y: usize) -> usize {
|
||||
debug_assert!(x < 5);
|
||||
debug_assert!(y < 5);
|
||||
if x == 0 && y == 0 {
|
||||
REG_A_PRIME_PRIME_PRIME_0_0_LO
|
||||
} else {
|
||||
reg_a_prime_prime(x, y)
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) const NUM_REGISTERS: usize = REG_A_PRIME_PRIME_PRIME_0_0_HI + 1;
|
||||
40
evm/src/keccak/round_flags.rs
Normal file
40
evm/src/keccak/round_flags.rs
Normal file
@ -0,0 +1,40 @@
|
||||
use plonky2::field::extension_field::Extendable;
|
||||
use plonky2::field::field_types::Field;
|
||||
use plonky2::field::packed_field::PackedField;
|
||||
use plonky2::hash::hash_types::RichField;
|
||||
use plonky2::plonk::circuit_builder::CircuitBuilder;
|
||||
|
||||
use crate::constraint_consumer::{ConstraintConsumer, RecursiveConstraintConsumer};
|
||||
use crate::keccak::keccak_stark::{NUM_PUBLIC_INPUTS, NUM_ROUNDS};
|
||||
use crate::keccak::registers::reg_step;
|
||||
use crate::keccak::registers::NUM_REGISTERS;
|
||||
use crate::vars::StarkEvaluationTargets;
|
||||
use crate::vars::StarkEvaluationVars;
|
||||
|
||||
pub(crate) fn eval_round_flags<F: Field, P: PackedField<Scalar = F>>(
|
||||
vars: StarkEvaluationVars<F, P, NUM_REGISTERS, NUM_PUBLIC_INPUTS>,
|
||||
yield_constr: &mut ConstraintConsumer<P>,
|
||||
) {
|
||||
// Initially, the first step flag should be 1 while the others should be 0.
|
||||
yield_constr.constraint_first_row(vars.local_values[reg_step(0)] - F::ONE);
|
||||
for i in 1..NUM_ROUNDS {
|
||||
yield_constr.constraint_first_row(vars.local_values[reg_step(i)]);
|
||||
}
|
||||
|
||||
// TODO: Transition.
|
||||
}
|
||||
|
||||
pub(crate) fn eval_round_flags_recursively<F: RichField + Extendable<D>, const D: usize>(
|
||||
builder: &mut CircuitBuilder<F, D>,
|
||||
vars: StarkEvaluationTargets<D, NUM_REGISTERS, NUM_PUBLIC_INPUTS>,
|
||||
yield_constr: &mut RecursiveConstraintConsumer<F, D>,
|
||||
) {
|
||||
let one = builder.one_extension();
|
||||
|
||||
// Initially, the first step flag should be 1 while the others should be 0.
|
||||
let step_0_minus_1 = builder.sub_extension(vars.local_values[reg_step(0)], one);
|
||||
yield_constr.constraint_first_row(builder, step_0_minus_1);
|
||||
for i in 1..NUM_ROUNDS {
|
||||
yield_constr.constraint_first_row(builder, vars.local_values[reg_step(i)]);
|
||||
}
|
||||
}
|
||||
@ -1,4 +1,5 @@
|
||||
#![allow(incomplete_features)]
|
||||
#![allow(clippy::needless_range_loop)]
|
||||
#![allow(clippy::too_many_arguments)]
|
||||
#![allow(clippy::type_complexity)]
|
||||
#![feature(generic_const_exprs)]
|
||||
|
||||
@ -95,6 +95,10 @@ where
|
||||
let n = buf.len() / Self::WIDTH;
|
||||
unsafe { std::slice::from_raw_parts_mut(buf_ptr, n) }
|
||||
}
|
||||
|
||||
fn doubles(&self) -> Self {
|
||||
*self * Self::Scalar::TWO
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl<F: Field> PackedField for F {
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user