diff --git a/evm/Cargo.toml b/evm/Cargo.toml index 2285a170..e282583e 100644 --- a/evm/Cargo.toml +++ b/evm/Cargo.toml @@ -13,3 +13,5 @@ itertools = "0.10.0" log = "0.4.14" rayon = "1.5.1" rand = "0.8.5" +rand_chacha = "0.3.1" +keccak-rust = { git = "https://github.com/npwardberkeley/keccak-rust" } \ No newline at end of file diff --git a/evm/src/all_stark.rs b/evm/src/all_stark.rs index 4ef79314..c3c9653c 100644 --- a/evm/src/all_stark.rs +++ b/evm/src/all_stark.rs @@ -49,20 +49,22 @@ impl Table { #[cfg(test)] mod tests { use anyhow::Result; + use itertools::Itertools; use plonky2::field::field_types::Field; - use plonky2::field::polynomial::PolynomialValues; use plonky2::iop::witness::PartialWitness; use plonky2::plonk::circuit_builder::CircuitBuilder; use plonky2::plonk::circuit_data::CircuitConfig; use plonky2::plonk::config::{GenericConfig, PoseidonGoldilocksConfig}; use plonky2::util::timing::TimingTree; + use rand::{Rng, SeedableRng}; + use rand_chacha::ChaCha8Rng; use crate::all_stark::{AllStark, Table}; use crate::config::StarkConfig; use crate::cpu; use crate::cpu::cpu_stark::CpuStark; use crate::cross_table_lookup::CrossTableLookup; - use crate::keccak::keccak_stark::KeccakStark; + use crate::keccak::keccak_stark::{KeccakStark, INPUT_LIMBS, NUM_ROUNDS}; use crate::proof::AllProof; use crate::prover::prove; use crate::recursive_verifier::{ @@ -85,31 +87,41 @@ mod tests { let keccak_stark = KeccakStark:: { f: Default::default(), }; - let keccak_rows = 256; + let keccak_rows = (NUM_ROUNDS + 1).next_power_of_two(); let keccak_looked_col = 3; + let mut rng = ChaCha8Rng::seed_from_u64(0x6feb51b7ec230f25); + let num_inputs = 1; + let keccak_inputs = (0..num_inputs) + .map(|_| [0u64; INPUT_LIMBS].map(|_| rng.gen())) + .collect_vec(); + let keccak_trace = keccak_stark.generate_trace(keccak_inputs); + let column_to_copy: Vec<_> = keccak_trace[keccak_looked_col].values[..].into(); + + let default = vec![F::ONE; 1]; + let mut cpu_trace_rows = vec![]; for i in 0..cpu_rows { let mut cpu_trace_row = [F::ZERO; CpuStark::::COLUMNS]; cpu_trace_row[cpu::columns::IS_CPU_CYCLE] = F::ONE; - cpu_trace_row[cpu::columns::OPCODE] = F::from_canonical_usize(i); + if i < keccak_rows { + cpu_trace_row[cpu::columns::OPCODE] = column_to_copy[i]; + } else { + cpu_trace_row[cpu::columns::OPCODE] = default[0]; + } cpu_stark.generate(&mut cpu_trace_row); cpu_trace_rows.push(cpu_trace_row); } let cpu_trace = trace_rows_to_poly_values(cpu_trace_rows); - let mut keccak_trace = - vec![PolynomialValues::zero(keccak_rows); KeccakStark::::COLUMNS]; - keccak_trace[keccak_looked_col] = cpu_trace[cpu::columns::OPCODE].clone(); - - let default = vec![F::ZERO; 2]; - let cross_table_lookups = vec![CrossTableLookup { - looking_tables: vec![Table::Cpu], - looking_columns: vec![vec![cpu::columns::OPCODE]], - looked_table: Table::Keccak, - looked_columns: vec![keccak_looked_col], + // TODO: temporary until cross-table-lookup filters are implemented + let cross_table_lookups = vec![CrossTableLookup::new( + vec![Table::Cpu], + vec![vec![cpu::columns::OPCODE]], + Table::Keccak, + vec![keccak_looked_col], default, - }]; + )]; let all_stark = AllStark { cpu_stark, diff --git a/evm/src/cross_table_lookup.rs b/evm/src/cross_table_lookup.rs index f4887b68..1f89f7b9 100644 --- a/evm/src/cross_table_lookup.rs +++ b/evm/src/cross_table_lookup.rs @@ -23,11 +23,11 @@ use crate::vars::{StarkEvaluationTargets, StarkEvaluationVars}; #[derive(Clone)] pub struct CrossTableLookup { - pub looking_tables: Vec, - pub looking_columns: Vec>, - pub looked_table: Table, - pub looked_columns: Vec, - pub default: Vec, + looking_tables: Vec
, + looking_columns: Vec>, + looked_table: Table, + looked_columns: Vec, + default: Vec, } impl CrossTableLookup { @@ -42,6 +42,7 @@ impl CrossTableLookup { assert!(looking_columns .iter() .all(|cols| cols.len() == looked_columns.len())); + assert!(default.len() == looked_columns.len()); Self { looking_tables, looking_columns, diff --git a/evm/src/keccak/constants.rs b/evm/src/keccak/constants.rs new file mode 100644 index 00000000..72286237 --- /dev/null +++ b/evm/src/keccak/constants.rs @@ -0,0 +1,157 @@ +const RC: [u64; 24] = [ + 0x0000000000000001, + 0x0000000000008082, + 0x800000000000808A, + 0x8000000080008000, + 0x000000000000808B, + 0x0000000080000001, + 0x8000000080008081, + 0x8000000000008009, + 0x000000000000008A, + 0x0000000000000088, + 0x0000000080008009, + 0x000000008000000A, + 0x000000008000808B, + 0x800000000000008B, + 0x8000000000008089, + 0x8000000000008003, + 0x8000000000008002, + 0x8000000000000080, + 0x000000000000800A, + 0x800000008000000A, + 0x8000000080008081, + 0x8000000000008080, + 0x0000000080000001, + 0x8000000080008008, +]; + +const RC_BITS: [[u8; 64]; 24] = [ + [ + 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, + ], + [ + 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, + ], + [ + 0, 1, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 1, + ], + [ + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 1, + ], + [ + 1, 1, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, + ], + [ + 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, + ], + [ + 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 1, + ], + [ + 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 1, + ], + [ + 0, 1, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, + ], + [ + 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, + ], + [ + 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, + ], + [ + 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, + ], + [ + 1, 1, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, + ], + [ + 1, 1, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 1, + ], + [ + 1, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 1, + ], + [ + 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 1, + ], + [ + 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 1, + ], + [ + 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 1, + ], + [ + 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, + ], + [ + 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 1, + ], + [ + 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 1, + ], + [ + 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 1, + ], + [ + 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, + ], + [ + 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 1, + ], +]; + +pub(crate) const fn rc_value_bit(round: usize, bit_index: usize) -> u8 { + RC_BITS[round][bit_index] +} + +pub(crate) const fn rc_value(round: usize) -> u64 { + RC[round] +} diff --git a/evm/src/keccak/keccak_stark.rs b/evm/src/keccak/keccak_stark.rs index 0a8a49f6..cbc2408b 100644 --- a/evm/src/keccak/keccak_stark.rs +++ b/evm/src/keccak/keccak_stark.rs @@ -1,60 +1,522 @@ use std::marker::PhantomData; +use itertools::Itertools; +use log::info; use plonky2::field::extension_field::{Extendable, FieldExtension}; use plonky2::field::packed_field::PackedField; +use plonky2::field::polynomial::PolynomialValues; use plonky2::hash::hash_types::RichField; +use plonky2::plonk::plonk_common::reduce_with_powers_ext_circuit; +use plonky2::timed; +use plonky2::util::timing::TimingTree; use crate::constraint_consumer::{ConstraintConsumer, RecursiveConstraintConsumer}; -use crate::permutation::PermutationPair; +use crate::keccak::constants::{rc_value, rc_value_bit}; +use crate::keccak::logic::{ + andn, andn_gen, andn_gen_circuit, xor, xor3_gen, xor3_gen_circuit, xor_gen, xor_gen_circuit, +}; +use crate::keccak::registers::{ + reg_a, reg_a_prime, reg_a_prime_prime, reg_a_prime_prime_0_0_bit, reg_a_prime_prime_prime, + reg_b, reg_c, reg_c_partial, reg_step, NUM_REGISTERS, +}; +use crate::keccak::round_flags::{eval_round_flags, eval_round_flags_recursively}; use crate::stark::Stark; +use crate::util::trace_rows_to_poly_values; use crate::vars::{StarkEvaluationTargets, StarkEvaluationVars}; +/// Number of rounds in a Keccak permutation. +pub(crate) const NUM_ROUNDS: usize = 24; + +/// Number of 64-bit limbs in a preimage of the Keccak permutation. +pub(crate) const INPUT_LIMBS: usize = 25; + +pub(crate) const NUM_PUBLIC_INPUTS: usize = 0; + #[derive(Copy, Clone)] pub struct KeccakStark { pub(crate) f: PhantomData, } +impl, const D: usize> KeccakStark { + /// Generate the rows of the trace. Note that this does not generate the permuted columns used + /// in our lookup arguments, as those are computed after transposing to column-wise form. + pub(crate) fn generate_trace_rows( + &self, + inputs: Vec<[u64; INPUT_LIMBS]>, + ) -> Vec<[F; NUM_REGISTERS]> { + let num_rows = (inputs.len() * NUM_ROUNDS).next_power_of_two(); + info!("{} rows", num_rows); + let mut rows = Vec::with_capacity(num_rows); + for input in inputs.iter() { + rows.extend(self.generate_trace_rows_for_perm(*input)); + } + + // Pad rows to power of two. + for i in rows.len()..num_rows { + let mut row = [F::ZERO; NUM_REGISTERS]; + self.copy_output_to_input(rows[i - 1], &mut row); + self.generate_trace_row_for_round(&mut row, i % NUM_ROUNDS); + rows.push(row); + } + + rows + } + + fn generate_trace_rows_for_perm(&self, input: [u64; INPUT_LIMBS]) -> Vec<[F; NUM_REGISTERS]> { + let mut rows = vec![[F::ZERO; NUM_REGISTERS]; NUM_ROUNDS]; + + for x in 0..5 { + for y in 0..5 { + let input_xy = input[x * 5 + y]; + for z in 0..64 { + rows[0][reg_a(x, y, z)] = F::from_canonical_u64((input_xy >> z) & 1); + } + } + } + + self.generate_trace_row_for_round(&mut rows[0], 0); + for round in 1..24 { + self.copy_output_to_input(rows[round - 1], &mut rows[round]); + self.generate_trace_row_for_round(&mut rows[round], round); + } + + rows + } + + fn copy_output_to_input( + &self, + prev_row: [F; NUM_REGISTERS], + next_row: &mut [F; NUM_REGISTERS], + ) { + for x in 0..5 { + for y in 0..5 { + let cur_lo = prev_row[reg_a_prime_prime_prime(x, y)]; + let cur_hi = prev_row[reg_a_prime_prime_prime(x, y) + 1]; + let cur_u64 = cur_lo.to_canonical_u64() | (cur_hi.to_canonical_u64() << 32); + let bit_values: Vec = (0..64) + .scan(cur_u64, |acc, _| { + let tmp = *acc & 1; + *acc >>= 1; + Some(tmp) + }) + .collect(); + + for z in 0..64 { + next_row[reg_a(x, y, z)] = F::from_canonical_u64(bit_values[z]); + } + } + } + } + + fn generate_trace_row_for_round(&self, row: &mut [F; NUM_REGISTERS], round: usize) { + row[reg_step(round)] = F::ONE; + + // Populate C partial and C. + for x in 0..5 { + for z in 0..64 { + let a = [0, 1, 2, 3, 4].map(|i| row[reg_a(x, i, z)]); + let c_partial = xor([a[0], a[1], a[2]]); + let c = xor([c_partial, a[3], a[4]]); + row[reg_c_partial(x, z)] = c_partial; + row[reg_c(x, z)] = c; + } + } + + // Populate A'. + // A'[x, y] = xor(A[x, y], D[x]) + // = xor(A[x, y], C[x - 1], ROT(C[x + 1], 1)) + for x in 0..5 { + for y in 0..5 { + for z in 0..64 { + row[reg_a_prime(x, y, z)] = xor([ + row[reg_a(x, y, z)], + row[reg_c((x + 4) % 5, z)], + row[reg_c((x + 1) % 5, (z + 64 - 1) % 64)], + ]); + } + } + } + + // Populate A''. + // A''[x, y] = xor(B[x, y], andn(B[x + 1, y], B[x + 2, y])). + for x in 0..5 { + for y in 0..5 { + let get_bit = |z| { + xor([ + row[reg_b(x, y, z)], + andn(row[reg_b((x + 1) % 5, y, z)], row[reg_b((x + 2) % 5, y, z)]), + ]) + }; + + let lo = (0..32) + .rev() + .fold(F::ZERO, |acc, z| acc.double() + get_bit(z)); + let hi = (32..64) + .rev() + .fold(F::ZERO, |acc, z| acc.double() + get_bit(z)); + + let reg_lo = reg_a_prime_prime(x, y); + let reg_hi = reg_lo + 1; + row[reg_lo] = lo; + row[reg_hi] = hi; + } + } + + // For the XOR, we split A''[0, 0] to bits. + let val_lo = row[reg_a_prime_prime(0, 0)].to_canonical_u64(); + let val_hi = row[reg_a_prime_prime(0, 0) + 1].to_canonical_u64(); + let val = val_lo | (val_hi << 32); + let bit_values: Vec = (0..64) + .scan(val, |acc, _| { + let tmp = *acc & 1; + *acc >>= 1; + Some(tmp) + }) + .collect(); + for i in 0..64 { + row[reg_a_prime_prime_0_0_bit(i)] = F::from_canonical_u64(bit_values[i]); + } + + // A''[0, 0] is additionally xor'd with RC. + let in_reg_lo = reg_a_prime_prime(0, 0); + let in_reg_hi = in_reg_lo + 1; + let out_reg_lo = reg_a_prime_prime_prime(0, 0); + let out_reg_hi = out_reg_lo + 1; + let rc_lo = rc_value(round) & ((1 << 32) - 1); + let rc_hi = rc_value(round) >> 32; + row[out_reg_lo] = F::from_canonical_u64(row[in_reg_lo].to_canonical_u64() ^ rc_lo); + row[out_reg_hi] = F::from_canonical_u64(row[in_reg_hi].to_canonical_u64() ^ rc_hi); + } + + pub fn generate_trace(&self, inputs: Vec<[u64; INPUT_LIMBS]>) -> Vec> { + let mut timing = TimingTree::new("generate trace", log::Level::Debug); + + // Generate the witness, except for permuted columns in the lookup argument. + let trace_rows = timed!( + &mut timing, + "generate trace rows", + self.generate_trace_rows(inputs) + ); + + let trace_polys = timed!( + &mut timing, + "convert to PolynomialValues", + trace_rows_to_poly_values(trace_rows) + ); + + timing.print(); + trace_polys + } +} + impl, const D: usize> Stark for KeccakStark { - const COLUMNS: usize = 7; - const PUBLIC_INPUTS: usize = 0; + const COLUMNS: usize = NUM_REGISTERS; + const PUBLIC_INPUTS: usize = NUM_PUBLIC_INPUTS; fn eval_packed_generic( &self, - _vars: StarkEvaluationVars, - _yield_constr: &mut ConstraintConsumer

, + vars: StarkEvaluationVars, + yield_constr: &mut ConstraintConsumer

, ) where FE: FieldExtension, P: PackedField, { + eval_round_flags(vars, yield_constr); + + // C_partial[x] = xor(A[x, 0], A[x, 1], A[x, 2]) + for x in 0..5 { + for z in 0..64 { + let c_partial = vars.local_values[reg_c_partial(x, z)]; + let a_0 = vars.local_values[reg_a(x, 0, z)]; + let a_1 = vars.local_values[reg_a(x, 1, z)]; + let a_2 = vars.local_values[reg_a(x, 2, z)]; + let xor_012 = xor3_gen(a_0, a_1, a_2); + yield_constr.constraint(c_partial - xor_012); + } + } + + // C[x] = xor(C_partial[x], A[x, 3], A[x, 4]) + for x in 0..5 { + for z in 0..64 { + let c = vars.local_values[reg_c(x, z)]; + let xor_012 = vars.local_values[reg_c_partial(x, z)]; + let a_3 = vars.local_values[reg_a(x, 3, z)]; + let a_4 = vars.local_values[reg_a(x, 4, z)]; + let xor_01234 = xor3_gen(xor_012, a_3, a_4); + yield_constr.constraint(c - xor_01234); + } + } + + // A'[x, y] = xor(A[x, y], D[x]) + // = xor(A[x, y], C[x - 1], ROT(C[x + 1], 1)) + for x in 0..5 { + for z in 0..64 { + let c_left = vars.local_values[reg_c((x + 4) % 5, z)]; + let c_right = vars.local_values[reg_c((x + 1) % 5, (z + 64 - 1) % 64)]; + let d = xor_gen(c_left, c_right); + + for y in 0..5 { + let a = vars.local_values[reg_a(x, y, z)]; + let a_prime = vars.local_values[reg_a_prime(x, y, z)]; + let xor = xor_gen(d, a); + yield_constr.constraint(a_prime - xor); + } + } + } + + // A''[x, y] = xor(B[x, y], andn(B[x + 1, y], B[x + 2, y])). + for x in 0..5 { + for y in 0..5 { + let get_bit = |z| { + xor_gen( + vars.local_values[reg_b(x, y, z)], + andn_gen( + vars.local_values[reg_b((x + 1) % 5, y, z)], + vars.local_values[reg_b((x + 2) % 5, y, z)], + ), + ) + }; + + let reg_lo = reg_a_prime_prime(x, y); + let reg_hi = reg_lo + 1; + let lo = vars.local_values[reg_lo]; + let hi = vars.local_values[reg_hi]; + let computed_lo = (0..32) + .rev() + .fold(P::ZEROS, |acc, z| acc.doubles() + get_bit(z)); + let computed_hi = (32..64) + .rev() + .fold(P::ZEROS, |acc, z| acc.doubles() + get_bit(z)); + + yield_constr.constraint(computed_lo - lo); + yield_constr.constraint(computed_hi - hi); + } + } + + // A'''[0, 0] = A''[0, 0] XOR RC + let a_prime_prime_0_0_bits = (0..64) + .map(|i| vars.local_values[reg_a_prime_prime_0_0_bit(i)]) + .collect_vec(); + let computed_a_prime_prime_0_0_lo = (0..32) + .rev() + .fold(P::ZEROS, |acc, z| acc.doubles() + a_prime_prime_0_0_bits[z]); + let computed_a_prime_prime_0_0_hi = (32..64) + .rev() + .fold(P::ZEROS, |acc, z| acc.doubles() + a_prime_prime_0_0_bits[z]); + let a_prime_prime_0_0_lo = vars.local_values[reg_a_prime_prime(0, 0)]; + let a_prime_prime_0_0_hi = vars.local_values[reg_a_prime_prime(0, 0) + 1]; + yield_constr.constraint(computed_a_prime_prime_0_0_lo - a_prime_prime_0_0_lo); + yield_constr.constraint(computed_a_prime_prime_0_0_hi - a_prime_prime_0_0_hi); + + let get_xored_bit = |i| { + let mut rc_bit_i = P::ZEROS; + for r in 0..NUM_ROUNDS { + let this_round = vars.local_values[reg_step(r)]; + let this_round_constant = + P::from(FE::from_canonical_u32(rc_value_bit(r, i) as u32)); + rc_bit_i += this_round * this_round_constant; + } + + xor_gen(a_prime_prime_0_0_bits[i], rc_bit_i) + }; + + let a_prime_prime_prime_0_0_lo = vars.local_values[reg_a_prime_prime_prime(0, 0)]; + let a_prime_prime_prime_0_0_hi = vars.local_values[reg_a_prime_prime_prime(0, 0) + 1]; + let computed_a_prime_prime_prime_0_0_lo = (0..32) + .rev() + .fold(P::ZEROS, |acc, z| acc.doubles() + get_xored_bit(z)); + let computed_a_prime_prime_prime_0_0_hi = (32..64) + .rev() + .fold(P::ZEROS, |acc, z| acc.doubles() + get_xored_bit(z)); + yield_constr.constraint(computed_a_prime_prime_prime_0_0_lo - a_prime_prime_prime_0_0_lo); + yield_constr.constraint(computed_a_prime_prime_prime_0_0_hi - a_prime_prime_prime_0_0_hi); + + // Enforce that this round's output equals the next round's input. + for x in 0..5 { + for y in 0..5 { + let output_lo = vars.local_values[reg_a_prime_prime_prime(x, y)]; + let output_hi = vars.local_values[reg_a_prime_prime_prime(x, y) + 1]; + let input_bits = (0..64) + .map(|z| vars.next_values[reg_a(x, y, z)]) + .collect_vec(); + let input_bits_combined_lo = (0..32) + .rev() + .fold(P::ZEROS, |acc, z| acc.doubles() + input_bits[z]); + let input_bits_combined_hi = (32..64) + .rev() + .fold(P::ZEROS, |acc, z| acc.doubles() + input_bits[z]); + yield_constr.constraint_transition(output_lo - input_bits_combined_lo); + yield_constr.constraint_transition(output_hi - input_bits_combined_hi); + } + } } fn eval_ext_circuit( &self, - _builder: &mut plonky2::plonk::circuit_builder::CircuitBuilder, - _vars: StarkEvaluationTargets, - _yield_constr: &mut RecursiveConstraintConsumer, + builder: &mut plonky2::plonk::circuit_builder::CircuitBuilder, + vars: StarkEvaluationTargets, + yield_constr: &mut RecursiveConstraintConsumer, ) { + let two = builder.two(); + + eval_round_flags_recursively(builder, vars, yield_constr); + + // C_partial[x] = xor(A[x, 0], A[x, 1], A[x, 2]) + for x in 0..5 { + for z in 0..64 { + let c_partial = vars.local_values[reg_c_partial(x, z)]; + let a_0 = vars.local_values[reg_a(x, 0, z)]; + let a_1 = vars.local_values[reg_a(x, 1, z)]; + let a_2 = vars.local_values[reg_a(x, 2, z)]; + + let xor_012 = xor3_gen_circuit(builder, a_0, a_1, a_2); + let diff = builder.sub_extension(c_partial, xor_012); + yield_constr.constraint(builder, diff); + } + } + + // C[x] = xor(C_partial[x], A[x, 3], A[x, 4]) + for x in 0..5 { + for z in 0..64 { + let c = vars.local_values[reg_c(x, z)]; + let xor_012 = vars.local_values[reg_c_partial(x, z)]; + let a_3 = vars.local_values[reg_a(x, 3, z)]; + let a_4 = vars.local_values[reg_a(x, 4, z)]; + + let xor_01234 = xor3_gen_circuit(builder, xor_012, a_3, a_4); + let diff = builder.sub_extension(c, xor_01234); + yield_constr.constraint(builder, diff); + } + } + + // A'[x, y] = xor(A[x, y], D[x]) + // = xor(A[x, y], C[x - 1], ROT(C[x + 1], 1)) + for x in 0..5 { + for z in 0..64 { + let c_left = vars.local_values[reg_c((x + 4) % 5, z)]; + let c_right = vars.local_values[reg_c((x + 1) % 5, (z + 64 - 1) % 64)]; + let d = xor_gen_circuit(builder, c_left, c_right); + + for y in 0..5 { + let a = vars.local_values[reg_a(x, y, z)]; + let a_prime = vars.local_values[reg_a_prime(x, y, z)]; + let xor = xor_gen_circuit(builder, d, a); + let diff = builder.sub_extension(a_prime, xor); + yield_constr.constraint(builder, diff); + } + } + } + + // A''[x, y] = xor(B[x, y], andn(B[x + 1, y], B[x + 2, y])). + for x in 0..5 { + for y in 0..5 { + let mut get_bit = |z| { + let andn = andn_gen_circuit( + builder, + vars.local_values[reg_b((x + 1) % 5, y, z)], + vars.local_values[reg_b((x + 2) % 5, y, z)], + ); + xor_gen_circuit(builder, vars.local_values[reg_b(x, y, z)], andn) + }; + + let reg_lo = reg_a_prime_prime(x, y); + let reg_hi = reg_lo + 1; + let lo = vars.local_values[reg_lo]; + let hi = vars.local_values[reg_hi]; + let bits_lo = (0..32).map(&mut get_bit).collect_vec(); + let bits_hi = (32..64).map(get_bit).collect_vec(); + let computed_lo = reduce_with_powers_ext_circuit(builder, &bits_lo, two); + let computed_hi = reduce_with_powers_ext_circuit(builder, &bits_hi, two); + let diff = builder.sub_extension(computed_lo, lo); + yield_constr.constraint(builder, diff); + let diff = builder.sub_extension(computed_hi, hi); + yield_constr.constraint(builder, diff); + } + } + + // A'''[0, 0] = A''[0, 0] XOR RC + let a_prime_prime_0_0_bits = (0..64) + .map(|i| vars.local_values[reg_a_prime_prime_0_0_bit(i)]) + .collect_vec(); + let computed_a_prime_prime_0_0_lo = + reduce_with_powers_ext_circuit(builder, &a_prime_prime_0_0_bits[0..32], two); + let computed_a_prime_prime_0_0_hi = + reduce_with_powers_ext_circuit(builder, &a_prime_prime_0_0_bits[32..64], two); + let a_prime_prime_0_0_lo = vars.local_values[reg_a_prime_prime(0, 0)]; + let a_prime_prime_0_0_hi = vars.local_values[reg_a_prime_prime(0, 0) + 1]; + let diff = builder.sub_extension(computed_a_prime_prime_0_0_lo, a_prime_prime_0_0_lo); + yield_constr.constraint(builder, diff); + let diff = builder.sub_extension(computed_a_prime_prime_0_0_hi, a_prime_prime_0_0_hi); + yield_constr.constraint(builder, diff); + + let mut get_xored_bit = |i| { + let mut rc_bit_i = builder.zero_extension(); + for r in 0..NUM_ROUNDS { + let this_round = vars.local_values[reg_step(r)]; + let this_round_constant = builder + .constant_extension(F::from_canonical_u32(rc_value_bit(r, i) as u32).into()); + rc_bit_i = builder.mul_add_extension(this_round, this_round_constant, rc_bit_i); + } + + xor_gen_circuit(builder, a_prime_prime_0_0_bits[i], rc_bit_i) + }; + + let a_prime_prime_prime_0_0_lo = vars.local_values[reg_a_prime_prime_prime(0, 0)]; + let a_prime_prime_prime_0_0_hi = vars.local_values[reg_a_prime_prime_prime(0, 0) + 1]; + let bits_lo = (0..32).map(&mut get_xored_bit).collect_vec(); + let bits_hi = (32..64).map(get_xored_bit).collect_vec(); + let computed_a_prime_prime_prime_0_0_lo = + reduce_with_powers_ext_circuit(builder, &bits_lo, two); + let computed_a_prime_prime_prime_0_0_hi = + reduce_with_powers_ext_circuit(builder, &bits_hi, two); + let diff = builder.sub_extension( + computed_a_prime_prime_prime_0_0_lo, + a_prime_prime_prime_0_0_lo, + ); + yield_constr.constraint(builder, diff); + let diff = builder.sub_extension( + computed_a_prime_prime_prime_0_0_hi, + a_prime_prime_prime_0_0_hi, + ); + yield_constr.constraint(builder, diff); + + // Enforce that this round's output equals the next round's input. + for x in 0..5 { + for y in 0..5 { + let output_lo = vars.local_values[reg_a_prime_prime_prime(x, y)]; + let output_hi = vars.local_values[reg_a_prime_prime_prime(x, y) + 1]; + let input_bits = (0..64) + .map(|z| vars.next_values[reg_a(x, y, z)]) + .collect_vec(); + let input_bits_combined_lo = + reduce_with_powers_ext_circuit(builder, &input_bits[0..32], two); + let input_bits_combined_hi = + reduce_with_powers_ext_circuit(builder, &input_bits[32..64], two); + let diff = builder.sub_extension(output_lo, input_bits_combined_lo); + yield_constr.constraint_transition(builder, diff); + let diff = builder.sub_extension(output_hi, input_bits_combined_hi); + yield_constr.constraint_transition(builder, diff); + } + } } fn constraint_degree(&self) -> usize { 3 } - - fn permutation_pairs(&self) -> Vec { - vec![PermutationPair::singletons(0, 6)] - } } #[cfg(test)] mod tests { use anyhow::Result; + use keccak_rust::{KeccakF, StateBitsWidth}; + use plonky2::field::field_types::Field; use plonky2::plonk::config::{GenericConfig, PoseidonGoldilocksConfig}; - use crate::keccak::keccak_stark::KeccakStark; + use crate::keccak::keccak_stark::{KeccakStark, INPUT_LIMBS, NUM_ROUNDS}; + use crate::keccak::registers::reg_a_prime_prime_prime; use crate::stark_testing::{test_stark_circuit_constraints, test_stark_low_degree}; #[test] - #[ignore] // TODO: remove this when constraints are no longer all 0. fn test_stark_degree() -> Result<()> { const D: usize = 2; type C = PoseidonGoldilocksConfig; @@ -79,4 +541,51 @@ mod tests { }; test_stark_circuit_constraints::(stark) } + + #[test] + fn keccak_correctness_test() -> Result<()> { + let input: [u64; INPUT_LIMBS] = rand::random(); + + const D: usize = 2; + type C = PoseidonGoldilocksConfig; + type F = >::F; + type S = KeccakStark; + + let stark = S { + f: Default::default(), + }; + + let rows = stark.generate_trace_rows(vec![input.try_into().unwrap()]); + let last_row = rows[NUM_ROUNDS - 1]; + let mut output = Vec::new(); + let base = F::from_canonical_u64(1 << 32); + for x in 0..5 { + for y in 0..5 { + output.push( + last_row[reg_a_prime_prime_prime(x, y)] + + base * last_row[reg_a_prime_prime_prime(x, y) + 1], + ); + } + } + + let mut keccak_input: [[u64; 5]; 5] = [ + input[0..5].try_into().unwrap(), + input[5..10].try_into().unwrap(), + input[10..15].try_into().unwrap(), + input[15..20].try_into().unwrap(), + input[20..25].try_into().unwrap(), + ]; + + let keccak = KeccakF::new(StateBitsWidth::F1600); + keccak.permutations(&mut keccak_input); + let expected: Vec<_> = keccak_input + .iter() + .flatten() + .map(|&x| F::from_canonical_u64(x)) + .collect(); + + assert_eq!(output, expected); + + Ok(()) + } } diff --git a/evm/src/keccak/logic.rs b/evm/src/keccak/logic.rs new file mode 100644 index 00000000..7d248258 --- /dev/null +++ b/evm/src/keccak/logic.rs @@ -0,0 +1,65 @@ +use plonky2::field::extension_field::Extendable; +use plonky2::field::field_types::PrimeField64; +use plonky2::field::packed_field::PackedField; +use plonky2::hash::hash_types::RichField; +use plonky2::iop::ext_target::ExtensionTarget; +use plonky2::plonk::circuit_builder::CircuitBuilder; + +pub(crate) fn xor(xs: [F; N]) -> F { + xs.into_iter().fold(F::ZERO, |acc, x| { + debug_assert!(x.is_zero() || x.is_one()); + F::from_canonical_u64(acc.to_canonical_u64() ^ x.to_canonical_u64()) + }) +} + +/// Computes the arithmetic generalization of `xor(x, y)`, i.e. `x + y - 2 x y`. +pub(crate) fn xor_gen(x: P, y: P) -> P { + x + y - x * y.doubles() +} + +/// Computes the arithmetic generalization of `xor3(x, y, z)`. +pub(crate) fn xor3_gen(x: P, y: P, z: P) -> P { + xor_gen(x, xor_gen(y, z)) +} + +/// Computes the arithmetic generalization of `xor(x, y)`, i.e. `x + y - 2 x y`. +pub(crate) fn xor_gen_circuit, const D: usize>( + builder: &mut CircuitBuilder, + x: ExtensionTarget, + y: ExtensionTarget, +) -> ExtensionTarget { + let sum = builder.add_extension(x, y); + builder.arithmetic_extension(-F::TWO, F::ONE, x, y, sum) +} + +/// Computes the arithmetic generalization of `xor(x, y)`, i.e. `x + y - 2 x y`. +pub(crate) fn xor3_gen_circuit, const D: usize>( + builder: &mut CircuitBuilder, + x: ExtensionTarget, + y: ExtensionTarget, + z: ExtensionTarget, +) -> ExtensionTarget { + let x_xor_y = xor_gen_circuit(builder, x, y); + xor_gen_circuit(builder, x_xor_y, z) +} + +pub(crate) fn andn(x: F, y: F) -> F { + debug_assert!(x.is_zero() || x.is_one()); + debug_assert!(y.is_zero() || y.is_one()); + let x = x.to_canonical_u64(); + let y = y.to_canonical_u64(); + F::from_canonical_u64(!x & y) +} + +pub(crate) fn andn_gen(x: P, y: P) -> P { + (P::ONES - x) * y +} + +pub(crate) fn andn_gen_circuit, const D: usize>( + builder: &mut CircuitBuilder, + x: ExtensionTarget, + y: ExtensionTarget, +) -> ExtensionTarget { + // (1 - x) y = -xy + y + builder.arithmetic_extension(F::NEG_ONE, F::ONE, x, y, y) +} diff --git a/evm/src/keccak/mod.rs b/evm/src/keccak/mod.rs index 4c31c32b..2d104339 100644 --- a/evm/src/keccak/mod.rs +++ b/evm/src/keccak/mod.rs @@ -1 +1,5 @@ +pub mod constants; pub mod keccak_stark; +pub mod logic; +pub mod registers; +pub mod round_flags; diff --git a/evm/src/keccak/registers.rs b/evm/src/keccak/registers.rs new file mode 100644 index 00000000..3a891828 --- /dev/null +++ b/evm/src/keccak/registers.rs @@ -0,0 +1,94 @@ +use crate::keccak::keccak_stark::NUM_ROUNDS; + +/// A register which is set to 1 if we are in the `i`th round, otherwise 0. +pub(crate) const fn reg_step(i: usize) -> usize { + debug_assert!(i < NUM_ROUNDS); + i +} + +const R: [[u8; 5]; 5] = [ + [0, 36, 3, 41, 18], + [1, 44, 10, 45, 2], + [62, 6, 43, 15, 61], + [28, 55, 25, 21, 56], + [27, 20, 39, 8, 14], +]; + +const START_A: usize = NUM_ROUNDS; +pub(crate) const fn reg_a(x: usize, y: usize, z: usize) -> usize { + debug_assert!(x < 5); + debug_assert!(y < 5); + debug_assert!(z < 64); + START_A + x * 64 * 5 + y * 64 + z +} + +// C_partial[x] = xor(A[x, 0], A[x, 1], A[x, 2]) +const START_C_PARTIAL: usize = START_A + 5 * 5 * 64; +pub(crate) const fn reg_c_partial(x: usize, z: usize) -> usize { + START_C_PARTIAL + x * 64 + z +} + +// C[x] = xor(C_partial[x], A[x, 3], A[x, 4]) +const START_C: usize = START_C_PARTIAL + 5 * 64; +pub(crate) const fn reg_c(x: usize, z: usize) -> usize { + START_C + x * 64 + z +} + +// D is inlined. +// const fn reg_d(x: usize, z: usize) {} + +// A'[x, y] = xor(A[x, y], D[x]) +// = xor(A[x, y], C[x - 1], ROT(C[x + 1], 1)) +const START_A_PRIME: usize = START_C + 5 * 64; +pub(crate) const fn reg_a_prime(x: usize, y: usize, z: usize) -> usize { + debug_assert!(x < 5); + debug_assert!(y < 5); + debug_assert!(z < 64); + START_A_PRIME + x * 64 * 5 + y * 64 + z +} + +pub(crate) const fn reg_b(x: usize, y: usize, z: usize) -> usize { + debug_assert!(x < 5); + debug_assert!(y < 5); + debug_assert!(z < 64); + // B is just a rotation of A', so these are aliases for A' registers. + // From the spec, + // B[y, (2x + 3y) % 5] = ROT(A'[x, y], r[x, y]) + // So, + // B[x, y] = f((x + 3y) % 5, x) + // where f(a, b) = ROT(A'[a, b], r[a, b]) + let a = (x + 3 * y) % 5; + let b = x; + let rot = R[a][b] as usize; + reg_a_prime(a, b, (z + 64 - rot) % 64) +} + +// A''[x, y] = xor(B[x, y], andn(B[x + 1, y], B[x + 2, y])). +const START_A_PRIME_PRIME: usize = START_A_PRIME + 5 * 5 * 64; +pub(crate) const fn reg_a_prime_prime(x: usize, y: usize) -> usize { + debug_assert!(x < 5); + debug_assert!(y < 5); + START_A_PRIME_PRIME + x * 2 * 5 + y * 2 +} + +const START_A_PRIME_PRIME_0_0_BITS: usize = START_A_PRIME_PRIME + 5 * 5 * 2; +pub(crate) const fn reg_a_prime_prime_0_0_bit(i: usize) -> usize { + debug_assert!(i < 64); + START_A_PRIME_PRIME_0_0_BITS + i +} + +const REG_A_PRIME_PRIME_PRIME_0_0_LO: usize = START_A_PRIME_PRIME_0_0_BITS + 64; +const REG_A_PRIME_PRIME_PRIME_0_0_HI: usize = REG_A_PRIME_PRIME_PRIME_0_0_LO + 1; + +// A'''[0, 0] is additionally xor'd with RC. +pub(crate) const fn reg_a_prime_prime_prime(x: usize, y: usize) -> usize { + debug_assert!(x < 5); + debug_assert!(y < 5); + if x == 0 && y == 0 { + REG_A_PRIME_PRIME_PRIME_0_0_LO + } else { + reg_a_prime_prime(x, y) + } +} + +pub(crate) const NUM_REGISTERS: usize = REG_A_PRIME_PRIME_PRIME_0_0_HI + 1; diff --git a/evm/src/keccak/round_flags.rs b/evm/src/keccak/round_flags.rs new file mode 100644 index 00000000..63128ba1 --- /dev/null +++ b/evm/src/keccak/round_flags.rs @@ -0,0 +1,40 @@ +use plonky2::field::extension_field::Extendable; +use plonky2::field::field_types::Field; +use plonky2::field::packed_field::PackedField; +use plonky2::hash::hash_types::RichField; +use plonky2::plonk::circuit_builder::CircuitBuilder; + +use crate::constraint_consumer::{ConstraintConsumer, RecursiveConstraintConsumer}; +use crate::keccak::keccak_stark::{NUM_PUBLIC_INPUTS, NUM_ROUNDS}; +use crate::keccak::registers::reg_step; +use crate::keccak::registers::NUM_REGISTERS; +use crate::vars::StarkEvaluationTargets; +use crate::vars::StarkEvaluationVars; + +pub(crate) fn eval_round_flags>( + vars: StarkEvaluationVars, + yield_constr: &mut ConstraintConsumer

, +) { + // Initially, the first step flag should be 1 while the others should be 0. + yield_constr.constraint_first_row(vars.local_values[reg_step(0)] - F::ONE); + for i in 1..NUM_ROUNDS { + yield_constr.constraint_first_row(vars.local_values[reg_step(i)]); + } + + // TODO: Transition. +} + +pub(crate) fn eval_round_flags_recursively, const D: usize>( + builder: &mut CircuitBuilder, + vars: StarkEvaluationTargets, + yield_constr: &mut RecursiveConstraintConsumer, +) { + let one = builder.one_extension(); + + // Initially, the first step flag should be 1 while the others should be 0. + let step_0_minus_1 = builder.sub_extension(vars.local_values[reg_step(0)], one); + yield_constr.constraint_first_row(builder, step_0_minus_1); + for i in 1..NUM_ROUNDS { + yield_constr.constraint_first_row(builder, vars.local_values[reg_step(i)]); + } +} diff --git a/evm/src/lib.rs b/evm/src/lib.rs index bd8ca5f3..e0f04e04 100644 --- a/evm/src/lib.rs +++ b/evm/src/lib.rs @@ -1,4 +1,5 @@ #![allow(incomplete_features)] +#![allow(clippy::needless_range_loop)] #![allow(clippy::too_many_arguments)] #![allow(clippy::type_complexity)] #![feature(generic_const_exprs)] diff --git a/field/src/packed_field.rs b/field/src/packed_field.rs index 4b3336d9..bf5fed2d 100644 --- a/field/src/packed_field.rs +++ b/field/src/packed_field.rs @@ -95,6 +95,10 @@ where let n = buf.len() / Self::WIDTH; unsafe { std::slice::from_raw_parts_mut(buf_ptr, n) } } + + fn doubles(&self) -> Self { + *self * Self::Scalar::TWO + } } unsafe impl PackedField for F {