diff --git a/evm/src/all_stark.rs b/evm/src/all_stark.rs index acb6c935..ffc1e404 100644 --- a/evm/src/all_stark.rs +++ b/evm/src/all_stark.rs @@ -6,6 +6,7 @@ use crate::cpu::cpu_stark::CpuStark; use crate::cross_table_lookup::CrossTableLookup; use crate::keccak::keccak_stark::KeccakStark; use crate::logic::LogicStark; +use crate::memory::memory_stark::MemoryStark; use crate::stark::Stark; #[derive(Clone)] @@ -13,6 +14,7 @@ pub struct AllStark, const D: usize> { pub cpu_stark: CpuStark, pub keccak_stark: KeccakStark, pub logic_stark: LogicStark, + pub memory_stark: MemoryStark, pub cross_table_lookups: Vec>, } @@ -43,11 +45,12 @@ pub enum Table { Cpu = 0, Keccak = 1, Logic = 2, + Memory = 3, } impl Table { pub(crate) fn num_tables() -> usize { - Table::Logic as usize + 1 + Table::Memory as usize + 1 } } @@ -68,11 +71,12 @@ mod tests { use crate::config::StarkConfig; use crate::cpu::columns::{KECCAK_INPUT_LIMBS, KECCAK_OUTPUT_LIMBS}; use crate::cpu::cpu_stark::{self as cpu_stark_mod, CpuStark}; - use crate::cross_table_lookup::{CrossTableLookup, TableWithColumns}; use crate::keccak::keccak_stark::{ self as keccak_stark_mod, KeccakStark, NUM_INPUTS, NUM_ROUNDS, }; use crate::logic::{self, LogicStark}; + use crate::cross_table_lookup::{Column, CrossTableLookup, TableWithColumns}; + use crate::memory::memory_stark::{generate_random_memory_ops, MemoryStark}; use crate::proof::AllProof; use crate::prover::prove; use crate::recursive_verifier::{ @@ -137,6 +141,7 @@ mod tests { cpu_stark: &CpuStark, keccak_trace: &[PolynomialValues], logic_trace: &[PolynomialValues], + memory_trace: &[PolynomialValues], ) -> Vec> { let keccak_input_limbs: Vec<[F; 2 * NUM_INPUTS]> = (0..num_keccak_perms) .map(|i| { @@ -162,10 +167,6 @@ mod tests { .unwrap() }) .collect(); - let memory_trace = memory_stark.generate_trace(keccak_inputs); - let column_to_copy: Vec<_> = keccak_trace[keccak_looked_col].values[..].into(); - - let default = vec![F::ONE; 1]; let mut cpu_trace_rows = vec![]; for i in 0..num_keccak_perms { @@ -203,6 +204,30 @@ mod tests { cpu_stark.generate(&mut row); cpu_trace_rows.push(row); } + for i in 0..num_memory_ops { + let mem_timestamp: usize = memory_trace[memory::registers::TIMESTAMP].values[i] + .to_canonical_u64() + .try_into() + .unwrap(); + let clock = mem_timestamp / NUM_MEMORY_OPS; + let op = mem_timestamp % NUM_MEMORY_OPS; + + cpu_trace_rows[i][cpu::columns::uses_memop(op)] = F::ONE; + memory_trace[memory::registers::is_memop(op)].values[i] = F::ONE; + cpu_trace_rows[i][cpu::columns::CLOCK] = F::from_canonical_usize(clock); + cpu_trace_rows[i][cpu::columns::memop_is_read(op)] = + memory_trace[memory::registers::IS_READ].values[i]; + cpu_trace_rows[i][cpu::columns::memop_addr_context(op)] = + memory_trace[memory::registers::ADDR_CONTEXT].values[i]; + cpu_trace_rows[i][cpu::columns::memop_addr_segment(op)] = + memory_trace[memory::registers::ADDR_SEGMENT].values[i]; + cpu_trace_rows[i][cpu::columns::memop_addr_virtual(op)] = + memory_trace[memory::registers::ADDR_VIRTUAL].values[i]; + for j in 0..8 { + cpu_trace_rows[i][cpu::columns::memop_value(op, j)] = + memory_trace[memory::registers::value_limb(j)].values[i]; + } + } trace_rows_to_poly_values(cpu_trace_rows) } @@ -220,6 +245,11 @@ mod tests { }; let num_logic_rows = 62; + let memory_stark = MemoryStark:: { + f: Default::default(), + }; + let num_memory_ops = 1 << 5; + let mut rng = thread_rng(); let num_keccak_perms = 2; @@ -228,11 +258,47 @@ mod tests { let cpu_trace = make_cpu_trace( num_keccak_perms, num_logic_rows, + num_memory_ops, &cpu_stark, &keccak_trace, &logic_trace, ); + let memory_ops = generate_random_memory_ops(num_memory_ops); + let memory_trace = memory_stark.generate_trace(memory_ops); + + let mut cpu_keccak_input_output = cpu::columns::KECCAK_INPUT_LIMBS.collect::>(); + cpu_keccak_input_output.extend(cpu::columns::KECCAK_OUTPUT_LIMBS); + let mut keccak_keccak_input_output = (0..2 * NUM_INPUTS) + .map(keccak::registers::reg_input_limb) + .collect::>(); + keccak_keccak_input_output.extend(Column::singles( + (0..2 * NUM_INPUTS).map(keccak::registers::reg_output_limb), + )); + + let cpu_logic_input_output = { + let mut res = vec![ + cpu::columns::IS_AND, + cpu::columns::IS_OR, + cpu::columns::IS_XOR, + ]; + res.extend(cpu::columns::LOGIC_INPUT0); + res.extend(cpu::columns::LOGIC_INPUT1); + res.extend(cpu::columns::LOGIC_OUTPUT); + res + }; + let logic_logic_input_output = { + let mut res = vec![ + logic::columns::IS_AND, + logic::columns::IS_OR, + logic::columns::IS_XOR, + ]; + res.extend(logic::columns::INPUT0_PACKED); + res.extend(logic::columns::INPUT1_PACKED); + res.extend(logic::columns::RESULT); + res + }; + let cross_table_lookups = vec![ CrossTableLookup::new( vec![TableWithColumns::new( @@ -262,13 +328,14 @@ mod tests { cpu_stark, keccak_stark, logic_stark, + memory_stark, cross_table_lookups, }; let proof = prove::( &all_stark, config, - vec![cpu_trace, keccak_trace, logic_trace], + vec![cpu_trace, keccak_trace, logic_trace, memory_trace], vec![vec![]; 3], &mut TimingTree::default(), )?; diff --git a/evm/src/cpu/columns.rs b/evm/src/cpu/columns.rs index 5e5f3f55..44264c52 100644 --- a/evm/src/cpu/columns.rs +++ b/evm/src/cpu/columns.rs @@ -13,10 +13,20 @@ pub const IS_CPU_CYCLE: usize = IS_BOOTSTRAP_CONTRACT + 1; /// If CPU cycle: The opcode being decoded, in {0, ..., 255}. pub const OPCODE: usize = IS_CPU_CYCLE + 1; +<<<<<<< HEAD /// If CPU cycle: flags for EVM instructions. PUSHn, DUPn, and SWAPn only get one flag each. Invalid /// opcodes are split between a number of flags for practical reasons. Exactly one of these flags /// must be 1. pub const IS_STOP: usize = OPCODE + 1; +======= +pub const KECCAK_DUMMY: usize = OPCODE + 1; +pub const MEMORY_DUMMY: usize = KECCAK_DUMMY + 1; + +// If CPU cycle: flags for EVM instructions. PUSHn, DUPn, and SWAPn only get one flag each. Invalid +// opcodes are split between a number of flags for practical reasons. Exactly one of these flags +// must be 1. +pub const IS_STOP: usize = MEMORY_DUMMY + 1; +>>>>>>> ade7e5e0 (updated all_stark framework to include memory stark (doesn't pass yet)) pub const IS_ADD: usize = IS_STOP + 1; pub const IS_MUL: usize = IS_ADD + 1; pub const IS_SUB: usize = IS_MUL + 1; diff --git a/evm/src/memory/memory_stark.rs b/evm/src/memory/memory_stark.rs index ad3d0b76..11219388 100644 --- a/evm/src/memory/memory_stark.rs +++ b/evm/src/memory/memory_stark.rs @@ -1,9 +1,14 @@ +use std::collections::HashMap; use std::marker::PhantomData; use itertools::{izip, multiunzip}; use plonky2::field::extension_field::{Extendable, FieldExtension}; use plonky2::field::packed_field::PackedField; +use plonky2::field::polynomial::PolynomialValues; use plonky2::hash::hash_types::RichField; +use plonky2::timed; +use plonky2::util::timing::TimingTree; +use rand::{thread_rng, Rng}; use crate::constraint_consumer::{ConstraintConsumer, RecursiveConstraintConsumer}; use crate::memory::registers::{ @@ -15,7 +20,7 @@ use crate::memory::registers::{ SORTED_MEMORY_TIMESTAMP, }; use crate::stark::Stark; -use crate::util::permuted_cols; +use crate::util::{permuted_cols, trace_rows_to_poly_values}; use crate::vars::{StarkEvaluationTargets, StarkEvaluationVars}; #[derive(Default)] @@ -42,6 +47,52 @@ pub struct MemoryStark { pub(crate) f: PhantomData, } +pub fn generate_random_memory_ops(num_ops: usize) -> Vec<(F, F, F, [F; 8], F, F)> { + let mut memory_ops = Vec::new(); + + let mut rng = thread_rng(); + + let mut current_memory_values: HashMap<(F, F, F), [F; 8]> = HashMap::new(); + let mut cur_timestamp = 0; + for i in 0..num_ops { + let is_read = if i == 0 { false } else { rng.gen() }; + let is_read_F = F::from_bool(is_read); + + let (context, segment, virt, vals) = if is_read { + let written: Vec<_> = current_memory_values.keys().collect(); + let &(context, segment, virt) = written[rng.gen_range(0..written.len())]; + let &vals = current_memory_values + .get(&(context, segment, virt)) + .unwrap(); + + (context, segment, virt, vals) + } else { + let context = F::from_canonical_usize(rng.gen_range(0..256)); + let segment = F::from_canonical_usize(rng.gen_range(0..8)); + let virt = F::from_canonical_usize(rng.gen_range(0..20)); + + let val: [u32; 8] = rng.gen(); + let vals: [F; 8] = val + .iter() + .map(|&x| F::from_canonical_u32(x)) + .collect::>() + .try_into() + .unwrap(); + + current_memory_values.insert((context, segment, virt), vals); + + (context, segment, virt, vals) + }; + + let timestamp = F::from_canonical_usize(cur_timestamp); + cur_timestamp += 1; + + memory_ops.push((context, segment, virt, vals, is_read_F, timestamp)) + } + + memory_ops +} + pub fn sort_memory_ops( context: &[F], segment: &[F], @@ -152,7 +203,7 @@ pub fn generate_range_check_value( } impl, const D: usize> MemoryStark { - fn generate_trace_rows( + pub(crate) fn generate_trace_rows( &self, memory_ops: Vec<(F, F, F, [F; 8], F, F)>, ) -> Vec<[F; NUM_REGISTERS]> { @@ -175,7 +226,7 @@ impl, const D: usize> MemoryStark { self.generate_memory(&mut trace_cols); - let mut trace_rows = vec![[F::ZERO; NUM_REGISTERS]]; + let mut trace_rows = vec![[F::ZERO; NUM_REGISTERS]; num_ops]; for (i, col) in trace_cols.iter().enumerate() { for (j, &val) in col.iter().enumerate() { trace_rows[j][i] = val; @@ -231,7 +282,7 @@ impl, const D: usize> MemoryStark { trace_cols[MEMORY_VIRTUAL_FIRST_CHANGE] = virtual_first_change; trace_cols[MEMORY_RANGE_CHECK] = range_check_value; - trace_cols[MEMORY_COUNTER] = (0..trace_cols.len()) + trace_cols[MEMORY_COUNTER] = (0..trace_cols[0].len()) .map(|i| F::from_canonical_usize(i)) .collect(); @@ -240,6 +291,29 @@ impl, const D: usize> MemoryStark { trace_cols[MEMORY_RANGE_CHECK_PERMUTED] = permuted_inputs; trace_cols[MEMORY_COUNTER_PERMUTED] = permuted_table; } + + pub fn generate_trace( + &self, + memory_ops: Vec<(F, F, F, [F; 8], F, F)>, + ) -> Vec> { + let mut timing = TimingTree::new("generate trace", log::Level::Debug); + + // Generate the witness, except for permuted columns in the lookup argument. + let trace_rows = timed!( + &mut timing, + "generate trace rows", + self.generate_trace_rows(memory_ops) + ); + + let trace_polys = timed!( + &mut timing, + "convert to PolynomialValues", + trace_rows_to_poly_values(trace_rows) + ); + + timing.print(); + trace_polys + } } impl, const D: usize> Stark for MemoryStark { @@ -287,10 +361,10 @@ impl, const D: usize> Stark for MemoryStark = HashMap::new(); - let mut cur_timestamp = 0; - for i in 0..num_ops { - let is_read = if i == 0 { false } else { rng.gen() }; - let is_read_F = F::from_bool(is_read); - - let (context, segment, virt, vals) = if is_read { - let written: Vec<_> = current_memory_values.keys().collect(); - let &(context, segment, virt) = written[rng.gen_range(0..written.len())]; - let &vals = current_memory_values - .get(&(context, segment, virt)) - .unwrap(); - - (context, segment, virt, vals) - } else { - let context = F::from_canonical_usize(rng.gen_range(0..256)); - let segment = F::from_canonical_usize(rng.gen_range(0..8)); - let virt = F::from_canonical_usize(rng.gen_range(0..20)); - - let val: [u32; 8] = rng.gen(); - let vals: [F; 8] = val - .iter() - .map(|&x| F::from_canonical_u32(x)) - .collect::>() - .try_into() - .unwrap(); - - current_memory_values.insert((context, segment, virt), vals); - - (context, segment, virt, vals) - }; - - let timestamp = F::from_canonical_usize(cur_timestamp); - cur_timestamp += 1; - - memory_ops.push((context, segment, virt, vals, is_read_F, timestamp)) - } + let memory_ops = generate_random_memory_ops(num_ops); let rows = stark.generate_trace_rows(memory_ops); diff --git a/evm/src/prover.rs b/evm/src/prover.rs index 88a45e8b..ba4fac52 100644 --- a/evm/src/prover.rs +++ b/evm/src/prover.rs @@ -22,6 +22,7 @@ use crate::cpu::cpu_stark::CpuStark; use crate::cross_table_lookup::{cross_table_lookup_data, CtlCheckVars, CtlData}; use crate::keccak::keccak_stark::KeccakStark; use crate::logic::LogicStark; +use crate::memory::memory_stark::MemoryStark; use crate::permutation::PermutationCheckVars; use crate::permutation::{ compute_permutation_z_polys, get_n_grand_product_challenge_sets, GrandProductChallengeSet, @@ -49,6 +50,8 @@ where [(); KeccakStark::::PUBLIC_INPUTS]:, [(); LogicStark::::COLUMNS]:, [(); LogicStark::::PUBLIC_INPUTS]:, + [(); MemoryStark::::COLUMNS]:, + [(); MemoryStark::::PUBLIC_INPUTS]:, { let num_starks = Table::num_tables(); debug_assert_eq!(num_starks, trace_poly_values.len()); @@ -132,8 +135,21 @@ where &mut challenger, timing, )?; + let memory_proof = prove_single_table( + &all_stark.memory_stark, + config, + &trace_poly_values[Table::Memory as usize], + &trace_commitments[Table::Memory as usize], + &ctl_data_per_table[Table::Memory as usize], + public_inputs[Table::Memory as usize] + .clone() + .try_into() + .unwrap(), + &mut challenger, + timing, + )?; - let stark_proofs = vec![cpu_proof, keccak_proof, logic_proof]; + let stark_proofs = vec![cpu_proof, keccak_proof, logic_proof, memory_proof]; debug_assert_eq!(stark_proofs.len(), num_starks); Ok(AllProof { stark_proofs }) diff --git a/evm/src/recursive_verifier.rs b/evm/src/recursive_verifier.rs index fe83cc28..37e108cb 100644 --- a/evm/src/recursive_verifier.rs +++ b/evm/src/recursive_verifier.rs @@ -18,6 +18,7 @@ use crate::cpu::cpu_stark::CpuStark; use crate::cross_table_lookup::{verify_cross_table_lookups_circuit, CtlCheckVarsTarget}; use crate::keccak::keccak_stark::KeccakStark; use crate::logic::LogicStark; +use crate::memory::memory_stark::MemoryStark; use crate::permutation::PermutationCheckDataTarget; use crate::proof::{ AllProof, AllProofChallengesTarget, AllProofTarget, StarkOpeningSetTarget, StarkProof, @@ -44,6 +45,8 @@ pub fn verify_proof_circuit< [(); KeccakStark::::PUBLIC_INPUTS]:, [(); LogicStark::::COLUMNS]:, [(); LogicStark::::PUBLIC_INPUTS]:, + [(); MemoryStark::::COLUMNS]:, + [(); MemoryStark::::PUBLIC_INPUTS]:, C::Hasher: AlgebraicHasher, { let AllProofChallengesTarget { @@ -57,6 +60,7 @@ pub fn verify_proof_circuit< cpu_stark, keccak_stark, logic_stark, + memory_stark, cross_table_lookups, } = all_stark; @@ -91,6 +95,14 @@ pub fn verify_proof_circuit< &ctl_vars_per_table[Table::Logic as usize], inner_config, ); + verify_stark_proof_with_challenges_circuit::( + builder, + memory_stark, + &all_proof.stark_proofs[Table::Memory as usize], + &stark_challenges[Table::Memory as usize], + &ctl_vars_per_table[Table::Memory as usize], + inner_config, + ); verify_cross_table_lookups_circuit::( builder, @@ -291,6 +303,20 @@ pub fn add_virtual_all_proof, const D: usize>( public_inputs, } }, + { + let proof = add_virtual_stark_proof( + builder, + all_stark.memory_stark, + config, + degree_bits[Table::Memory as usize], + nums_ctl_zs[Table::Memory as usize], + ); + let public_inputs = builder.add_virtual_targets(KeccakStark::::PUBLIC_INPUTS); + StarkProofWithPublicInputsTarget { + proof, + public_inputs, + } + }, ]; assert_eq!(stark_proofs.len(), Table::num_tables()); diff --git a/evm/src/verifier.rs b/evm/src/verifier.rs index 62f5fc4f..a26e1b16 100644 --- a/evm/src/verifier.rs +++ b/evm/src/verifier.rs @@ -13,6 +13,7 @@ use crate::cpu::cpu_stark::CpuStark; use crate::cross_table_lookup::{verify_cross_table_lookups, CtlCheckVars}; use crate::keccak::keccak_stark::KeccakStark; use crate::logic::LogicStark; +use crate::memory::memory_stark::MemoryStark; use crate::permutation::PermutationCheckVars; use crate::proof::{ AllProof, AllProofChallenges, StarkOpeningSet, StarkProofChallenges, StarkProofWithPublicInputs, @@ -33,6 +34,8 @@ where [(); KeccakStark::::PUBLIC_INPUTS]:, [(); LogicStark::::COLUMNS]:, [(); LogicStark::::PUBLIC_INPUTS]:, + [(); MemoryStark::::COLUMNS]:, + [(); MemoryStark::::PUBLIC_INPUTS]:, [(); C::Hasher::HASH_SIZE]:, { let AllProofChallenges { @@ -46,6 +49,7 @@ where cpu_stark, keccak_stark, logic_stark, + memory_stark, cross_table_lookups, } = all_stark; @@ -70,6 +74,13 @@ where &ctl_vars_per_table[Table::Keccak as usize], config, )?; + verify_stark_proof_with_challenges( + memory_stark, + &all_proof.stark_proofs[Table::Memory as usize], + &stark_challenges[Table::Memory as usize], + &ctl_vars_per_table[Table::Memory as usize], + config, + )?; verify_stark_proof_with_challenges( logic_stark, &all_proof.stark_proofs[Table::Logic as usize],