diff --git a/.github/workflows/continuous-integration-workflow.yml b/.github/workflows/continuous-integration-workflow.yml index d929acd9..48848b73 100644 --- a/.github/workflows/continuous-integration-workflow.yml +++ b/.github/workflows/continuous-integration-workflow.yml @@ -24,7 +24,7 @@ jobs: uses: actions-rs/toolchain@v1 with: profile: minimal - toolchain: nightly + toolchain: nightly-2022-11-23 override: true - name: rust-cache @@ -60,7 +60,7 @@ jobs: uses: actions-rs/toolchain@v1 with: profile: minimal - toolchain: nightly + toolchain: nightly-2022-11-23 override: true components: rustfmt, clippy diff --git a/README.md b/README.md index 59fc4d09..59ef6959 100644 --- a/README.md +++ b/README.md @@ -29,7 +29,7 @@ RUSTFLAGS=-Ctarget-cpu=native cargo run --release --example bench_recursion -- - ## Jemalloc -Plonky2 prefers the [Jemalloc](http://jemalloc.net) memory allocator due to its superior performance. To use it, include `jemallocator = "0.3.2"` in`Cargo.toml`and add the following lines +Plonky2 prefers the [Jemalloc](http://jemalloc.net) memory allocator due to its superior performance. 
To use it, include `jemallocator = "0.5.0"` in`Cargo.toml`and add the following lines to your `main.rs`: ```rust diff --git a/evm/Cargo.toml b/evm/Cargo.toml index 17d855c1..03850f7a 100644 --- a/evm/Cargo.toml +++ b/evm/Cargo.toml @@ -31,6 +31,9 @@ sha2 = "0.10.2" static_assertions = "1.1.0" tiny-keccak = "2.0.2" +[target.'cfg(not(target_env = "msvc"))'.dependencies] +jemallocator = "0.5.0" + [dev-dependencies] criterion = "0.4.0" hex = "0.4.3" diff --git a/evm/src/all_stark.rs b/evm/src/all_stark.rs index 26840c5f..22f7c123 100644 --- a/evm/src/all_stark.rs +++ b/evm/src/all_stark.rs @@ -8,12 +8,12 @@ use crate::config::StarkConfig; use crate::cpu::cpu_stark; use crate::cpu::cpu_stark::CpuStark; use crate::cpu::membus::NUM_GP_CHANNELS; -use crate::cross_table_lookup::{CrossTableLookup, TableWithColumns}; +use crate::cross_table_lookup::{Column, CrossTableLookup, TableWithColumns}; use crate::keccak::keccak_stark; use crate::keccak::keccak_stark::KeccakStark; -use crate::keccak_memory::columns::KECCAK_WIDTH_BYTES; -use crate::keccak_memory::keccak_memory_stark; -use crate::keccak_memory::keccak_memory_stark::KeccakMemoryStark; +use crate::keccak_sponge::columns::KECCAK_RATE_BYTES; +use crate::keccak_sponge::keccak_sponge_stark; +use crate::keccak_sponge::keccak_sponge_stark::{num_logic_ctls, KeccakSpongeStark}; use crate::logic; use crate::logic::LogicStark; use crate::memory::memory_stark; @@ -24,7 +24,7 @@ use crate::stark::Stark; pub struct AllStark, const D: usize> { pub cpu_stark: CpuStark, pub keccak_stark: KeccakStark, - pub keccak_memory_stark: KeccakMemoryStark, + pub keccak_sponge_stark: KeccakSpongeStark, pub logic_stark: LogicStark, pub memory_stark: MemoryStark, pub cross_table_lookups: Vec>, @@ -35,7 +35,7 @@ impl, const D: usize> Default for AllStark { Self { cpu_stark: CpuStark::default(), keccak_stark: KeccakStark::default(), - keccak_memory_stark: KeccakMemoryStark::default(), + keccak_sponge_stark: KeccakSpongeStark::default(), logic_stark: 
LogicStark::default(), memory_stark: MemoryStark::default(), cross_table_lookups: all_cross_table_lookups(), @@ -48,7 +48,7 @@ impl, const D: usize> AllStark { [ self.cpu_stark.num_permutation_batches(config), self.keccak_stark.num_permutation_batches(config), - self.keccak_memory_stark.num_permutation_batches(config), + self.keccak_sponge_stark.num_permutation_batches(config), self.logic_stark.num_permutation_batches(config), self.memory_stark.num_permutation_batches(config), ] @@ -58,7 +58,7 @@ impl, const D: usize> AllStark { [ self.cpu_stark.permutation_batch_size(), self.keccak_stark.permutation_batch_size(), - self.keccak_memory_stark.permutation_batch_size(), + self.keccak_sponge_stark.permutation_batch_size(), self.logic_stark.permutation_batch_size(), self.memory_stark.permutation_batch_size(), ] @@ -69,66 +69,77 @@ impl, const D: usize> AllStark { pub enum Table { Cpu = 0, Keccak = 1, - KeccakMemory = 2, + KeccakSponge = 2, Logic = 3, Memory = 4, } pub(crate) const NUM_TABLES: usize = Table::Memory as usize + 1; -#[allow(unused)] // TODO: Should be used soon. pub(crate) fn all_cross_table_lookups() -> Vec> { - vec![ctl_keccak(), ctl_logic(), ctl_memory(), ctl_keccak_memory()] + let mut ctls = vec![ctl_keccak(), ctl_logic(), ctl_memory(), ctl_keccak_sponge()]; + // TODO: Some CTLs temporarily disabled while we get them working. 
+ disable_ctl(&mut ctls[0]); + disable_ctl(&mut ctls[1]); + disable_ctl(&mut ctls[2]); + disable_ctl(&mut ctls[3]); + ctls +} + +fn disable_ctl(ctl: &mut CrossTableLookup) { + for table in &mut ctl.looking_tables { + table.filter_column = Some(Column::zero()); + } + ctl.looked_table.filter_column = Some(Column::zero()); } fn ctl_keccak() -> CrossTableLookup { - let cpu_looking = TableWithColumns::new( - Table::Cpu, - cpu_stark::ctl_data_keccak(), - Some(cpu_stark::ctl_filter_keccak()), + let keccak_sponge_looking = TableWithColumns::new( + Table::KeccakSponge, + keccak_sponge_stark::ctl_looking_keccak(), + Some(keccak_sponge_stark::ctl_looking_keccak_filter()), ); - let keccak_memory_looking = TableWithColumns::new( - Table::KeccakMemory, - keccak_memory_stark::ctl_looking_keccak(), - Some(keccak_memory_stark::ctl_filter()), + let keccak_looked = TableWithColumns::new( + Table::Keccak, + keccak_stark::ctl_data(), + Some(keccak_stark::ctl_filter()), ); - CrossTableLookup::new( - vec![cpu_looking, keccak_memory_looking], - TableWithColumns::new( - Table::Keccak, - keccak_stark::ctl_data(), - Some(keccak_stark::ctl_filter()), - ), - None, - ) + CrossTableLookup::new(vec![keccak_sponge_looking], keccak_looked, None) } -fn ctl_keccak_memory() -> CrossTableLookup { - CrossTableLookup::new( - vec![TableWithColumns::new( - Table::Cpu, - cpu_stark::ctl_data_keccak_memory(), - Some(cpu_stark::ctl_filter_keccak_memory()), - )], - TableWithColumns::new( - Table::KeccakMemory, - keccak_memory_stark::ctl_looked_data(), - Some(keccak_memory_stark::ctl_filter()), - ), - None, - ) +fn ctl_keccak_sponge() -> CrossTableLookup { + let cpu_looking = TableWithColumns::new( + Table::Cpu, + cpu_stark::ctl_data_keccak_sponge(), + Some(cpu_stark::ctl_filter_keccak_sponge()), + ); + let keccak_sponge_looked = TableWithColumns::new( + Table::KeccakSponge, + keccak_sponge_stark::ctl_looked_data(), + Some(keccak_sponge_stark::ctl_looked_filter()), + ); + CrossTableLookup::new(vec![cpu_looking], 
keccak_sponge_looked, None) } fn ctl_logic() -> CrossTableLookup { - CrossTableLookup::new( - vec![TableWithColumns::new( - Table::Cpu, - cpu_stark::ctl_data_logic(), - Some(cpu_stark::ctl_filter_logic()), - )], - TableWithColumns::new(Table::Logic, logic::ctl_data(), Some(logic::ctl_filter())), - None, - ) + let cpu_looking = TableWithColumns::new( + Table::Cpu, + cpu_stark::ctl_data_logic(), + Some(cpu_stark::ctl_filter_logic()), + ); + let mut all_lookers = vec![cpu_looking]; + for i in 0..num_logic_ctls() { + let keccak_sponge_looking = TableWithColumns::new( + Table::KeccakSponge, + keccak_sponge_stark::ctl_looking_logic(i), + // TODO: Double check, but I think it's the same filter for memory and logic? + Some(keccak_sponge_stark::ctl_looking_memory_filter(i)), + ); + all_lookers.push(keccak_sponge_looking); + } + let logic_looked = + TableWithColumns::new(Table::Logic, logic::ctl_data(), Some(logic::ctl_filter())); + CrossTableLookup::new(all_lookers, logic_looked, None) } fn ctl_memory() -> CrossTableLookup { @@ -144,662 +155,21 @@ fn ctl_memory() -> CrossTableLookup { Some(cpu_stark::ctl_filter_gp_memory(channel)), ) }); - let keccak_memory_reads = (0..KECCAK_WIDTH_BYTES).map(|i| { + let keccak_sponge_reads = (0..KECCAK_RATE_BYTES).map(|i| { TableWithColumns::new( - Table::KeccakMemory, - keccak_memory_stark::ctl_looking_memory(i, true), - Some(keccak_memory_stark::ctl_filter()), - ) - }); - let keccak_memory_writes = (0..KECCAK_WIDTH_BYTES).map(|i| { - TableWithColumns::new( - Table::KeccakMemory, - keccak_memory_stark::ctl_looking_memory(i, false), - Some(keccak_memory_stark::ctl_filter()), + Table::KeccakSponge, + keccak_sponge_stark::ctl_looking_memory(i), + Some(keccak_sponge_stark::ctl_looking_memory_filter(i)), ) }); let all_lookers = iter::once(cpu_memory_code_read) .chain(cpu_memory_gp_ops) - .chain(keccak_memory_reads) - .chain(keccak_memory_writes) + .chain(keccak_sponge_reads) .collect(); - CrossTableLookup::new( - all_lookers, - 
TableWithColumns::new( - Table::Memory, - memory_stark::ctl_data(), - Some(memory_stark::ctl_filter()), - ), - None, - ) -} - -#[cfg(test)] -mod tests { - use std::borrow::BorrowMut; - - use anyhow::Result; - use ethereum_types::U256; - use itertools::Itertools; - use plonky2::field::polynomial::PolynomialValues; - use plonky2::field::types::{Field, PrimeField64}; - use plonky2::iop::witness::PartialWitness; - use plonky2::plonk::circuit_builder::CircuitBuilder; - use plonky2::plonk::circuit_data::{CircuitConfig, VerifierCircuitData}; - use plonky2::plonk::config::{GenericConfig, PoseidonGoldilocksConfig}; - use plonky2::util::timing::TimingTree; - use rand::{thread_rng, Rng}; - - use crate::all_stark::{AllStark, NUM_TABLES}; - use crate::config::StarkConfig; - use crate::cpu::cpu_stark::CpuStark; - use crate::cpu::kernel::aggregator::KERNEL; - use crate::cross_table_lookup::testutils::check_ctls; - use crate::keccak::keccak_stark::{KeccakStark, NUM_INPUTS, NUM_ROUNDS}; - use crate::keccak_memory::keccak_memory_stark::KeccakMemoryStark; - use crate::logic::{self, LogicStark, Operation}; - use crate::memory::memory_stark::tests::generate_random_memory_ops; - use crate::memory::memory_stark::MemoryStark; - use crate::memory::NUM_CHANNELS; - use crate::proof::{AllProof, PublicValues}; - use crate::prover::prove_with_traces; - use crate::recursive_verifier::tests::recursively_verify_all_proof; - use crate::recursive_verifier::{ - add_virtual_recursive_all_proof, all_verifier_data_recursive_stark_proof, - set_recursive_all_proof_target, RecursiveAllProof, - }; - use crate::stark::Stark; - use crate::util::{limb_from_bits_le, trace_rows_to_poly_values}; - use crate::verifier::verify_proof; - use crate::{cpu, keccak, memory}; - - const D: usize = 2; - type C = PoseidonGoldilocksConfig; - type F = >::F; - - fn make_keccak_trace( - num_keccak_perms: usize, - keccak_stark: &KeccakStark, - rng: &mut R, - ) -> Vec> { - let keccak_inputs = (0..num_keccak_perms) - .map(|_| 
[0u64; NUM_INPUTS].map(|_| rng.gen())) - .collect_vec(); - keccak_stark.generate_trace(keccak_inputs, &mut TimingTree::default()) - } - - fn make_keccak_memory_trace( - keccak_memory_stark: &KeccakMemoryStark, - config: &StarkConfig, - ) -> Vec> { - keccak_memory_stark.generate_trace( - vec![], - config.fri_config.num_cap_elements(), - &mut TimingTree::default(), - ) - } - - fn make_logic_trace( - num_rows: usize, - logic_stark: &LogicStark, - rng: &mut R, - ) -> Vec> { - let all_ops = [logic::Op::And, logic::Op::Or, logic::Op::Xor]; - let ops = (0..num_rows) - .map(|_| { - let op = all_ops[rng.gen_range(0..all_ops.len())]; - let input0 = U256(rng.gen()); - let input1 = U256(rng.gen()); - Operation::new(op, input0, input1) - }) - .collect(); - logic_stark.generate_trace(ops, &mut TimingTree::default()) - } - - fn make_memory_trace( - num_memory_ops: usize, - memory_stark: &MemoryStark, - rng: &mut R, - ) -> (Vec>, usize) { - let memory_ops = generate_random_memory_ops(num_memory_ops, rng); - let trace = memory_stark.generate_trace(memory_ops, &mut TimingTree::default()); - let num_ops = trace[0].values.len(); - (trace, num_ops) - } - - fn bits_from_opcode(opcode: u8) -> [F; 8] { - [ - F::from_bool(opcode & (1 << 0) != 0), - F::from_bool(opcode & (1 << 1) != 0), - F::from_bool(opcode & (1 << 2) != 0), - F::from_bool(opcode & (1 << 3) != 0), - F::from_bool(opcode & (1 << 4) != 0), - F::from_bool(opcode & (1 << 5) != 0), - F::from_bool(opcode & (1 << 6) != 0), - F::from_bool(opcode & (1 << 7) != 0), - ] - } - - fn make_cpu_trace( - num_keccak_perms: usize, - num_logic_rows: usize, - num_memory_ops: usize, - cpu_stark: &CpuStark, - keccak_trace: &[PolynomialValues], - logic_trace: &[PolynomialValues], - memory_trace: &mut [PolynomialValues], - ) -> Vec> { - let keccak_input_limbs: Vec<[F; 2 * NUM_INPUTS]> = (0..num_keccak_perms) - .map(|i| { - (0..2 * NUM_INPUTS) - .map(|j| { - keccak::columns::reg_input_limb(j) - .eval_table(keccak_trace, (i + 1) * NUM_ROUNDS - 1) - 
}) - .collect::>() - .try_into() - .unwrap() - }) - .collect(); - let keccak_output_limbs: Vec<[F; 2 * NUM_INPUTS]> = (0..num_keccak_perms) - .map(|i| { - (0..2 * NUM_INPUTS) - .map(|j| { - keccak_trace[keccak::columns::reg_output_limb(j)].values - [(i + 1) * NUM_ROUNDS - 1] - }) - .collect::>() - .try_into() - .unwrap() - }) - .collect(); - - let mut cpu_trace_rows: Vec<[F; CpuStark::::COLUMNS]> = vec![]; - let mut bootstrap_row: cpu::columns::CpuColumnsView = - [F::ZERO; CpuStark::::COLUMNS].into(); - bootstrap_row.is_bootstrap_kernel = F::ONE; - cpu_trace_rows.push(bootstrap_row.into()); - - for i in 0..num_keccak_perms { - let mut row: cpu::columns::CpuColumnsView = - [F::ZERO; CpuStark::::COLUMNS].into(); - row.is_keccak = F::ONE; - let keccak = row.general.keccak_mut(); - for j in 0..2 * NUM_INPUTS { - keccak.input_limbs[j] = keccak_input_limbs[i][j]; - keccak.output_limbs[j] = keccak_output_limbs[i][j]; - } - cpu_stark.generate(row.borrow_mut()); - cpu_trace_rows.push(row.into()); - } - - // Pad to `num_memory_ops` for memory testing. 
- for _ in cpu_trace_rows.len()..num_memory_ops { - let mut row: cpu::columns::CpuColumnsView = - [F::ZERO; CpuStark::::COLUMNS].into(); - row.opcode_bits = bits_from_opcode(0x5b); - row.is_cpu_cycle = F::ONE; - row.is_kernel_mode = F::ONE; - row.program_counter = F::from_canonical_usize(KERNEL.global_labels["main"]); - cpu_stark.generate(row.borrow_mut()); - cpu_trace_rows.push(row.into()); - } - - for i in 0..num_memory_ops { - let mem_timestamp: usize = memory_trace[memory::columns::TIMESTAMP].values[i] - .to_canonical_u64() - .try_into() - .unwrap(); - let clock = mem_timestamp / NUM_CHANNELS; - let channel = mem_timestamp % NUM_CHANNELS; - - let filter = memory_trace[memory::columns::FILTER].values[i]; - assert!(filter.is_one() || filter.is_zero()); - let is_actual_op = filter.is_one(); - - if is_actual_op { - let row: &mut cpu::columns::CpuColumnsView = cpu_trace_rows[clock].borrow_mut(); - row.clock = F::from_canonical_usize(clock); - - dbg!(channel, row.mem_channels.len()); - let channel = &mut row.mem_channels[channel]; - channel.used = F::ONE; - channel.is_read = memory_trace[memory::columns::IS_READ].values[i]; - channel.addr_context = memory_trace[memory::columns::ADDR_CONTEXT].values[i]; - channel.addr_segment = memory_trace[memory::columns::ADDR_SEGMENT].values[i]; - channel.addr_virtual = memory_trace[memory::columns::ADDR_VIRTUAL].values[i]; - for j in 0..8 { - channel.value[j] = memory_trace[memory::columns::value_limb(j)].values[i]; - } - } - } - - for i in 0..num_logic_rows { - let mut row: cpu::columns::CpuColumnsView = - [F::ZERO; CpuStark::::COLUMNS].into(); - row.is_cpu_cycle = F::ONE; - row.is_kernel_mode = F::ONE; - - // Since these are the first cycle rows, we must start with PC=main then increment. 
- row.program_counter = F::from_canonical_usize(KERNEL.global_labels["main"] + i); - row.opcode_bits = bits_from_opcode( - if logic_trace[logic::columns::IS_AND].values[i] != F::ZERO { - 0x16 - } else if logic_trace[logic::columns::IS_OR].values[i] != F::ZERO { - 0x17 - } else if logic_trace[logic::columns::IS_XOR].values[i] != F::ZERO { - 0x18 - } else { - panic!() - }, - ); - - let input0_bit_cols = logic::columns::limb_bit_cols_for_input(logic::columns::INPUT0); - for (col_cpu, limb_cols_logic) in - row.mem_channels[0].value.iter_mut().zip(input0_bit_cols) - { - *col_cpu = limb_from_bits_le(limb_cols_logic.map(|col| logic_trace[col].values[i])); - } - - let input1_bit_cols = logic::columns::limb_bit_cols_for_input(logic::columns::INPUT1); - for (col_cpu, limb_cols_logic) in - row.mem_channels[1].value.iter_mut().zip(input1_bit_cols) - { - *col_cpu = limb_from_bits_le(limb_cols_logic.map(|col| logic_trace[col].values[i])); - } - - for (col_cpu, col_logic) in row.mem_channels[2] - .value - .iter_mut() - .zip(logic::columns::RESULT) - { - *col_cpu = logic_trace[col_logic].values[i]; - } - - cpu_stark.generate(row.borrow_mut()); - cpu_trace_rows.push(row.into()); - } - - // Trap to kernel - { - let mut row: cpu::columns::CpuColumnsView = - [F::ZERO; CpuStark::::COLUMNS].into(); - let last_row: cpu::columns::CpuColumnsView = - cpu_trace_rows[cpu_trace_rows.len() - 1].into(); - row.is_cpu_cycle = F::ONE; - row.opcode_bits = bits_from_opcode(0x0a); // `EXP` is implemented in software - row.is_kernel_mode = F::ONE; - row.program_counter = last_row.program_counter + F::ONE; - row.mem_channels[0].value = [ - row.program_counter, - F::ONE, - F::ZERO, - F::ZERO, - F::ZERO, - F::ZERO, - F::ZERO, - F::ZERO, - ]; - cpu_stark.generate(row.borrow_mut()); - cpu_trace_rows.push(row.into()); - } - - // `EXIT_KERNEL` (to kernel) - { - let mut row: cpu::columns::CpuColumnsView = - [F::ZERO; CpuStark::::COLUMNS].into(); - row.is_cpu_cycle = F::ONE; - row.opcode_bits = 
bits_from_opcode(0xf9); - row.is_kernel_mode = F::ONE; - row.program_counter = F::from_canonical_usize(KERNEL.global_labels["sys_exp"]); - row.mem_channels[0].value = [ - F::from_canonical_u16(15682), - F::ONE, - F::ZERO, - F::ZERO, - F::ZERO, - F::ZERO, - F::ZERO, - F::ZERO, - ]; - cpu_stark.generate(row.borrow_mut()); - cpu_trace_rows.push(row.into()); - } - - // `JUMP` (in kernel mode) - { - let mut row: cpu::columns::CpuColumnsView = - [F::ZERO; CpuStark::::COLUMNS].into(); - row.is_cpu_cycle = F::ONE; - row.opcode_bits = bits_from_opcode(0x56); - row.is_kernel_mode = F::ONE; - row.program_counter = F::from_canonical_u16(15682); - row.mem_channels[0].value = [ - F::from_canonical_u16(15106), - F::ZERO, - F::ZERO, - F::ZERO, - F::ZERO, - F::ZERO, - F::ZERO, - F::ZERO, - ]; - row.mem_channels[1].value = [ - F::ONE, - F::ZERO, - F::ZERO, - F::ZERO, - F::ZERO, - F::ZERO, - F::ZERO, - F::ZERO, - ]; - row.general.jumps_mut().input0_upper_zero = F::ONE; - row.general.jumps_mut().dst_valid_or_kernel = F::ONE; - row.general.jumps_mut().input0_jumpable = F::ONE; - row.general.jumps_mut().input1_sum_inv = F::ONE; - row.general.jumps_mut().should_jump = F::ONE; - cpu_stark.generate(row.borrow_mut()); - cpu_trace_rows.push(row.into()); - } - - // `EXIT_KERNEL` (to userspace) - { - let mut row: cpu::columns::CpuColumnsView = - [F::ZERO; CpuStark::::COLUMNS].into(); - row.is_cpu_cycle = F::ONE; - row.opcode_bits = bits_from_opcode(0xf9); - row.is_kernel_mode = F::ONE; - row.program_counter = F::from_canonical_u16(15106); - row.mem_channels[0].value = [ - F::from_canonical_u16(63064), - F::ZERO, - F::ZERO, - F::ZERO, - F::ZERO, - F::ZERO, - F::ZERO, - F::ZERO, - ]; - cpu_stark.generate(row.borrow_mut()); - cpu_trace_rows.push(row.into()); - } - - // `JUMP` (taken) - { - let mut row: cpu::columns::CpuColumnsView = - [F::ZERO; CpuStark::::COLUMNS].into(); - row.is_cpu_cycle = F::ONE; - row.opcode_bits = bits_from_opcode(0x56); - row.is_kernel_mode = F::ZERO; - 
row.program_counter = F::from_canonical_u16(63064); - row.mem_channels[0].value = [ - F::from_canonical_u16(3754), - F::ZERO, - F::ZERO, - F::ZERO, - F::ZERO, - F::ZERO, - F::ZERO, - F::ZERO, - ]; - row.mem_channels[1].value = [ - F::ONE, - F::ZERO, - F::ZERO, - F::ZERO, - F::ZERO, - F::ZERO, - F::ZERO, - F::ZERO, - ]; - row.general.jumps_mut().input0_upper_zero = F::ONE; - row.general.jumps_mut().dst_valid = F::ONE; - row.general.jumps_mut().dst_valid_or_kernel = F::ONE; - row.general.jumps_mut().input0_jumpable = F::ONE; - row.general.jumps_mut().input1_sum_inv = F::ONE; - row.general.jumps_mut().should_jump = F::ONE; - cpu_stark.generate(row.borrow_mut()); - cpu_trace_rows.push(row.into()); - } - - // `JUMPI` (taken) - { - let mut row: cpu::columns::CpuColumnsView = - [F::ZERO; CpuStark::::COLUMNS].into(); - row.is_cpu_cycle = F::ONE; - row.opcode_bits = bits_from_opcode(0x57); - row.is_kernel_mode = F::ZERO; - row.program_counter = F::from_canonical_u16(3754); - row.mem_channels[0].value = [ - F::from_canonical_u16(37543), - F::ZERO, - F::ZERO, - F::ZERO, - F::ZERO, - F::ZERO, - F::ZERO, - F::ZERO, - ]; - row.mem_channels[1].value = [ - F::ZERO, - F::ZERO, - F::ZERO, - F::ONE, - F::ZERO, - F::ZERO, - F::ZERO, - F::ZERO, - ]; - row.general.jumps_mut().input0_upper_zero = F::ONE; - row.general.jumps_mut().dst_valid = F::ONE; - row.general.jumps_mut().dst_valid_or_kernel = F::ONE; - row.general.jumps_mut().input0_jumpable = F::ONE; - row.general.jumps_mut().input1_sum_inv = F::ONE; - row.general.jumps_mut().should_jump = F::ONE; - cpu_stark.generate(row.borrow_mut()); - cpu_trace_rows.push(row.into()); - } - - // `JUMPI` (not taken) - { - let mut row: cpu::columns::CpuColumnsView = - [F::ZERO; CpuStark::::COLUMNS].into(); - row.is_cpu_cycle = F::ONE; - row.opcode_bits = bits_from_opcode(0x57); - row.is_kernel_mode = F::ZERO; - row.program_counter = F::from_canonical_u16(37543); - row.mem_channels[0].value = [ - F::from_canonical_u16(37543), - F::ZERO, - F::ZERO, - 
F::ZERO, - F::ONE, - F::ZERO, - F::ZERO, - F::ZERO, - ]; - row.general.jumps_mut().input0_upper_sum_inv = F::ONE; - row.general.jumps_mut().dst_valid = F::ONE; - row.general.jumps_mut().dst_valid_or_kernel = F::ONE; - row.general.jumps_mut().input0_jumpable = F::ZERO; - row.general.jumps_mut().should_continue = F::ONE; - cpu_stark.generate(row.borrow_mut()); - cpu_trace_rows.push(row.into()); - } - - // `JUMP` (trapping) - { - let mut row: cpu::columns::CpuColumnsView = - [F::ZERO; CpuStark::::COLUMNS].into(); - let last_row: cpu::columns::CpuColumnsView = - cpu_trace_rows[cpu_trace_rows.len() - 1].into(); - row.is_cpu_cycle = F::ONE; - row.opcode_bits = bits_from_opcode(0x56); - row.is_kernel_mode = F::ZERO; - row.program_counter = last_row.program_counter + F::ONE; - row.mem_channels[0].value = [ - F::from_canonical_u16(37543), - F::ZERO, - F::ZERO, - F::ZERO, - F::ONE, - F::ZERO, - F::ZERO, - F::ZERO, - ]; - row.mem_channels[1].value = [ - F::ONE, - F::ZERO, - F::ZERO, - F::ZERO, - F::ZERO, - F::ZERO, - F::ZERO, - F::ZERO, - ]; - row.general.jumps_mut().input0_upper_sum_inv = F::ONE; - row.general.jumps_mut().dst_valid = F::ONE; - row.general.jumps_mut().dst_valid_or_kernel = F::ONE; - row.general.jumps_mut().input0_jumpable = F::ZERO; - row.general.jumps_mut().input1_sum_inv = F::ONE; - row.general.jumps_mut().should_trap = F::ONE; - cpu_stark.generate(row.borrow_mut()); - cpu_trace_rows.push(row.into()); - } - - // Pad to a power of two. - for i in 0..cpu_trace_rows.len().next_power_of_two() - cpu_trace_rows.len() { - let mut row: cpu::columns::CpuColumnsView = - [F::ZERO; CpuStark::::COLUMNS].into(); - row.opcode_bits = bits_from_opcode(0xff); - row.is_cpu_cycle = F::ONE; - row.is_kernel_mode = F::ONE; - row.program_counter = - F::from_canonical_usize(KERNEL.global_labels["fault_exception"] + i); - cpu_stark.generate(row.borrow_mut()); - cpu_trace_rows.push(row.into()); - } - - // Ensure we finish in a halted state. 
- { - let num_rows = cpu_trace_rows.len(); - let halt_label = F::from_canonical_usize(KERNEL.global_labels["halt_pc0"]); - - let last_row: &mut cpu::columns::CpuColumnsView = - cpu_trace_rows[num_rows - 1].borrow_mut(); - last_row.program_counter = halt_label; - } - - trace_rows_to_poly_values(cpu_trace_rows) - } - - fn get_proof(config: &StarkConfig) -> Result<(AllStark, AllProof)> { - let all_stark = AllStark::default(); - - let num_logic_rows = 62; - let num_memory_ops = 1 << 5; - - let mut rng = thread_rng(); - let num_keccak_perms = 2; - - let keccak_trace = make_keccak_trace(num_keccak_perms, &all_stark.keccak_stark, &mut rng); - let keccak_memory_trace = make_keccak_memory_trace(&all_stark.keccak_memory_stark, config); - let logic_trace = make_logic_trace(num_logic_rows, &all_stark.logic_stark, &mut rng); - let mem_trace = make_memory_trace(num_memory_ops, &all_stark.memory_stark, &mut rng); - let mut memory_trace = mem_trace.0; - let num_memory_ops = mem_trace.1; - let cpu_trace = make_cpu_trace( - num_keccak_perms, - num_logic_rows, - num_memory_ops, - &all_stark.cpu_stark, - &keccak_trace, - &logic_trace, - &mut memory_trace, - ); - - let traces = [ - cpu_trace, - keccak_trace, - keccak_memory_trace, - logic_trace, - memory_trace, - ]; - check_ctls(&traces, &all_stark.cross_table_lookups); - - let public_values = PublicValues::default(); - let proof = prove_with_traces::( - &all_stark, - config, - traces, - public_values, - &mut TimingTree::default(), - )?; - - Ok((all_stark, proof)) - } - - #[test] - #[ignore] // Ignoring but not deleting so the test can serve as an API usage example - fn test_all_stark() -> Result<()> { - let config = StarkConfig::standard_fast_config(); - let (all_stark, proof) = get_proof(&config)?; - verify_proof(all_stark, proof, &config) - } - - #[test] - #[ignore] // Ignoring but not deleting so the test can serve as an API usage example - fn test_all_stark_recursive_verifier() -> Result<()> { - init_logger(); - - let config = 
StarkConfig::standard_fast_config(); - let (all_stark, proof) = get_proof(&config)?; - verify_proof(all_stark.clone(), proof.clone(), &config)?; - - recursive_proof(all_stark, proof, &config) - } - - fn recursive_proof( - inner_all_stark: AllStark, - inner_proof: AllProof, - inner_config: &StarkConfig, - ) -> Result<()> { - let circuit_config = CircuitConfig::standard_recursion_config(); - let recursive_all_proof = recursively_verify_all_proof( - &inner_all_stark, - &inner_proof, - inner_config, - &circuit_config, - )?; - - let verifier_data: [VerifierCircuitData; NUM_TABLES] = - all_verifier_data_recursive_stark_proof( - &inner_all_stark, - inner_proof.degree_bits(inner_config), - inner_config, - &circuit_config, - ); - let circuit_config = CircuitConfig::standard_recursion_config(); - let mut builder = CircuitBuilder::::new(circuit_config); - let mut pw = PartialWitness::new(); - let recursive_all_proof_target = - add_virtual_recursive_all_proof(&mut builder, &verifier_data); - set_recursive_all_proof_target(&mut pw, &recursive_all_proof_target, &recursive_all_proof); - RecursiveAllProof::verify_circuit( - &mut builder, - recursive_all_proof_target, - &verifier_data, - inner_all_stark.cross_table_lookups, - inner_config, - ); - - let data = builder.build::(); - let proof = data.prove(pw)?; - data.verify(proof) - } - - fn init_logger() { - let _ = env_logger::builder().format_timestamp(None).try_init(); - } + let memory_looked = TableWithColumns::new( + Table::Memory, + memory_stark::ctl_data(), + Some(memory_stark::ctl_filter()), + ); + CrossTableLookup::new(all_lookers, memory_looked, None) } diff --git a/evm/src/arithmetic/add.rs b/evm/src/arithmetic/add.rs index b09307b0..4e9de4b3 100644 --- a/evm/src/arithmetic/add.rs +++ b/evm/src/arithmetic/add.rs @@ -35,6 +35,7 @@ pub(crate) fn eval_packed_generic_are_equal( is_op: P, larger: I, smaller: J, + is_two_row_op: bool, ) -> P where P: PackedField, @@ -47,7 +48,11 @@ where for (a, b) in larger.zip(smaller) { // t 
should be either 0 or 2^LIMB_BITS let t = cy + a - b; - yield_constr.constraint(is_op * t * (overflow - t)); + if is_two_row_op { + yield_constr.constraint_transition(is_op * t * (overflow - t)); + } else { + yield_constr.constraint(is_op * t * (overflow - t)); + } // cy <-- 0 or 1 // NB: this is multiplication by a constant, so doesn't // increase the degree of the constraint. @@ -62,6 +67,7 @@ pub(crate) fn eval_ext_circuit_are_equal( is_op: ExtensionTarget, larger: I, smaller: J, + is_two_row_op: bool, ) -> ExtensionTarget where F: RichField + Extendable, @@ -87,7 +93,11 @@ where let t2 = builder.mul_extension(t, t1); let filtered_limb_constraint = builder.mul_extension(is_op, t2); - yield_constr.constraint(builder, filtered_limb_constraint); + if is_two_row_op { + yield_constr.constraint_transition(builder, filtered_limb_constraint); + } else { + yield_constr.constraint(builder, filtered_limb_constraint); + } cy = builder.mul_const_extension(overflow_inv, t); } @@ -125,6 +135,7 @@ pub fn eval_packed_generic( is_add, output_computed, output_limbs.iter().copied(), + false, ); } @@ -155,6 +166,7 @@ pub fn eval_ext_circuit, const D: usize>( is_add, output_computed.into_iter(), output_limbs.iter().copied(), + false, ); } diff --git a/evm/src/arithmetic/arithmetic_stark.rs b/evm/src/arithmetic/arithmetic_stark.rs index 5d835e77..5790ae66 100644 --- a/evm/src/arithmetic/arithmetic_stark.rs +++ b/evm/src/arithmetic/arithmetic_stark.rs @@ -17,7 +17,11 @@ pub struct ArithmeticStark { } impl ArithmeticStark { - pub fn generate(&self, local_values: &mut [F; columns::NUM_ARITH_COLUMNS]) { + pub fn generate( + &self, + local_values: &mut [F; columns::NUM_ARITH_COLUMNS], + next_values: &mut [F; columns::NUM_ARITH_COLUMNS], + ) { // Check that at most one operation column is "one" and that the // rest are "zero". 
assert_eq!( @@ -47,17 +51,17 @@ impl ArithmeticStark { } else if local_values[columns::IS_GT].is_one() { compare::generate(local_values, columns::IS_GT); } else if local_values[columns::IS_ADDMOD].is_one() { - modular::generate(local_values, columns::IS_ADDMOD); + modular::generate(local_values, next_values, columns::IS_ADDMOD); } else if local_values[columns::IS_SUBMOD].is_one() { - modular::generate(local_values, columns::IS_SUBMOD); + modular::generate(local_values, next_values, columns::IS_SUBMOD); } else if local_values[columns::IS_MULMOD].is_one() { - modular::generate(local_values, columns::IS_MULMOD); + modular::generate(local_values, next_values, columns::IS_MULMOD); } else if local_values[columns::IS_MOD].is_one() { - modular::generate(local_values, columns::IS_MOD); + modular::generate(local_values, next_values, columns::IS_MOD); } else if local_values[columns::IS_DIV].is_one() { - modular::generate(local_values, columns::IS_DIV); + modular::generate(local_values, next_values, columns::IS_DIV); } else { - todo!("the requested operation has not yet been implemented"); + panic!("the requested operation should not be handled by the arithmetic table"); } } } @@ -74,11 +78,12 @@ impl, const D: usize> Stark for ArithmeticSta P: PackedField, { let lv = vars.local_values; + let nv = vars.next_values; add::eval_packed_generic(lv, yield_constr); sub::eval_packed_generic(lv, yield_constr); mul::eval_packed_generic(lv, yield_constr); compare::eval_packed_generic(lv, yield_constr); - modular::eval_packed_generic(lv, yield_constr); + modular::eval_packed_generic(lv, nv, yield_constr); } fn eval_ext_circuit( @@ -88,11 +93,12 @@ impl, const D: usize> Stark for ArithmeticSta yield_constr: &mut RecursiveConstraintConsumer, ) { let lv = vars.local_values; + let nv = vars.next_values; add::eval_ext_circuit(builder, lv, yield_constr); sub::eval_ext_circuit(builder, lv, yield_constr); mul::eval_ext_circuit(builder, lv, yield_constr); compare::eval_ext_circuit(builder, lv, 
yield_constr); - modular::eval_ext_circuit(builder, lv, yield_constr); + modular::eval_ext_circuit(builder, lv, nv, yield_constr); } fn constraint_degree(&self) -> usize { diff --git a/evm/src/arithmetic/columns.rs b/evm/src/arithmetic/columns.rs index 923fbc73..779be2ee 100644 --- a/evm/src/arithmetic/columns.rs +++ b/evm/src/arithmetic/columns.rs @@ -12,7 +12,11 @@ const fn n_limbs() -> usize { if EVM_REGISTER_BITS % LIMB_BITS != 0 { panic!("limb size must divide EVM register size"); } - EVM_REGISTER_BITS / LIMB_BITS + let n = EVM_REGISTER_BITS / LIMB_BITS; + if n % 2 == 1 { + panic!("number of limbs must be even"); + } + n } /// Number of LIMB_BITS limbs that are in on EVM register-sized number. @@ -40,43 +44,66 @@ pub(crate) const ALL_OPERATIONS: [usize; 12] = [ /// Within the Arithmetic Unit, there are shared columns which can be /// used by any arithmetic circuit, depending on which one is active -/// this cycle. Can be increased as needed as other operations are -/// implemented. -const NUM_SHARED_COLS: usize = 9 * N_LIMBS; // only need 64 for add, sub, and mul +/// this cycle. +/// +/// Modular arithmetic takes 9 * N_LIMBS columns which is split across +/// two rows, the first with 5 * N_LIMBS columns and the second with +/// 4 * N_LIMBS columns. (There are hence N_LIMBS "wasted columns" in +/// the second row.) 
+const NUM_SHARED_COLS: usize = 5 * N_LIMBS; const GENERAL_INPUT_0: Range = START_SHARED_COLS..START_SHARED_COLS + N_LIMBS; const GENERAL_INPUT_1: Range = GENERAL_INPUT_0.end..GENERAL_INPUT_0.end + N_LIMBS; const GENERAL_INPUT_2: Range = GENERAL_INPUT_1.end..GENERAL_INPUT_1.end + N_LIMBS; const GENERAL_INPUT_3: Range = GENERAL_INPUT_2.end..GENERAL_INPUT_2.end + N_LIMBS; -const AUX_INPUT_0: Range = GENERAL_INPUT_3.end..GENERAL_INPUT_3.end + 2 * N_LIMBS; -const AUX_INPUT_1: Range = AUX_INPUT_0.end..AUX_INPUT_0.end + 2 * N_LIMBS; +const AUX_INPUT_0_LO: Range = GENERAL_INPUT_3.end..GENERAL_INPUT_3.end + N_LIMBS; + +// The auxiliary input columns overlap the general input columns +// because they correspond to the values in the second row for modular +// operations. +const AUX_INPUT_0_HI: Range = START_SHARED_COLS..START_SHARED_COLS + N_LIMBS; +const AUX_INPUT_1: Range = AUX_INPUT_0_HI.end..AUX_INPUT_0_HI.end + 2 * N_LIMBS; +// These auxiliary input columns are awkwardly split across two rows, +// with the first half after the general input columns and the second +// half after the auxiliary input columns. 
const AUX_INPUT_2: Range = AUX_INPUT_1.end..AUX_INPUT_1.end + N_LIMBS; +// ADD takes 3 * N_LIMBS = 48 columns pub(crate) const ADD_INPUT_0: Range = GENERAL_INPUT_0; pub(crate) const ADD_INPUT_1: Range = GENERAL_INPUT_1; pub(crate) const ADD_OUTPUT: Range = GENERAL_INPUT_2; +// SUB takes 3 * N_LIMBS = 48 columns pub(crate) const SUB_INPUT_0: Range = GENERAL_INPUT_0; pub(crate) const SUB_INPUT_1: Range = GENERAL_INPUT_1; pub(crate) const SUB_OUTPUT: Range = GENERAL_INPUT_2; +// MUL takes 4 * N_LIMBS = 64 columns pub(crate) const MUL_INPUT_0: Range = GENERAL_INPUT_0; pub(crate) const MUL_INPUT_1: Range = GENERAL_INPUT_1; pub(crate) const MUL_OUTPUT: Range = GENERAL_INPUT_2; pub(crate) const MUL_AUX_INPUT: Range = GENERAL_INPUT_3; +// LT and GT take 4 * N_LIMBS = 64 columns pub(crate) const CMP_INPUT_0: Range = GENERAL_INPUT_0; pub(crate) const CMP_INPUT_1: Range = GENERAL_INPUT_1; pub(crate) const CMP_OUTPUT: usize = GENERAL_INPUT_2.start; pub(crate) const CMP_AUX_INPUT: Range = GENERAL_INPUT_3; +// MULMOD takes 4 * N_LIMBS + 2 * 2*N_LIMBS + N_LIMBS = 144 columns +// but split over two rows of 80 columns and 64 columns. +// +// ADDMOD, SUBMOD, MOD and DIV are currently implemented in terms of +// the general modular code, so they also take 144 columns (also split +// over two rows). 
pub(crate) const MODULAR_INPUT_0: Range = GENERAL_INPUT_0; pub(crate) const MODULAR_INPUT_1: Range = GENERAL_INPUT_1; pub(crate) const MODULAR_MODULUS: Range = GENERAL_INPUT_2; pub(crate) const MODULAR_OUTPUT: Range = GENERAL_INPUT_3; -pub(crate) const MODULAR_QUO_INPUT: Range = AUX_INPUT_0; +pub(crate) const MODULAR_QUO_INPUT_LO: Range = AUX_INPUT_0_LO; // NB: Last value is not used in AUX, it is used in MOD_IS_ZERO -pub(crate) const MODULAR_AUX_INPUT: Range = AUX_INPUT_1; +pub(crate) const MODULAR_QUO_INPUT_HI: Range = AUX_INPUT_0_HI; +pub(crate) const MODULAR_AUX_INPUT: Range = AUX_INPUT_1.start..AUX_INPUT_1.end - 1; pub(crate) const MODULAR_MOD_IS_ZERO: usize = AUX_INPUT_1.end - 1; pub(crate) const MODULAR_OUT_AUX_RED: Range = AUX_INPUT_2; @@ -85,6 +112,6 @@ pub(crate) const DIV_NUMERATOR: Range = MODULAR_INPUT_0; #[allow(unused)] // TODO: Will be used when hooking into the CPU pub(crate) const DIV_DENOMINATOR: Range = MODULAR_MODULUS; #[allow(unused)] // TODO: Will be used when hooking into the CPU -pub(crate) const DIV_OUTPUT: Range = MODULAR_QUO_INPUT.start..MODULAR_QUO_INPUT.start + 16; +pub(crate) const DIV_OUTPUT: Range = MODULAR_QUO_INPUT_LO; pub const NUM_ARITH_COLUMNS: usize = START_SHARED_COLS + NUM_SHARED_COLS; diff --git a/evm/src/arithmetic/compare.rs b/evm/src/arithmetic/compare.rs index 7a360430..780053ce 100644 --- a/evm/src/arithmetic/compare.rs +++ b/evm/src/arithmetic/compare.rs @@ -57,16 +57,27 @@ pub(crate) fn eval_packed_generic_lt( input1: &[P], aux: &[P], output: P, + is_two_row_op: bool, ) { debug_assert!(input0.len() == N_LIMBS && input1.len() == N_LIMBS && aux.len() == N_LIMBS); // Verify (input0 < input1) == output by providing aux such that // input0 - input1 == aux + output*2^256. 
let lhs_limbs = input0.iter().zip(input1).map(|(&a, &b)| a - b); - let cy = eval_packed_generic_are_equal(yield_constr, is_op, aux.iter().copied(), lhs_limbs); + let cy = eval_packed_generic_are_equal( + yield_constr, + is_op, + aux.iter().copied(), + lhs_limbs, + is_two_row_op, + ); // We don't need to check that cy is 0 or 1, since output has // already been checked to be 0 or 1. - yield_constr.constraint(is_op * (cy - output)); + if is_two_row_op { + yield_constr.constraint_transition(is_op * (cy - output)); + } else { + yield_constr.constraint(is_op * (cy - output)); + } } pub fn eval_packed_generic( @@ -88,8 +99,8 @@ pub fn eval_packed_generic( let is_cmp = is_lt + is_gt; eval_packed_generic_check_is_one_bit(yield_constr, is_cmp, output); - eval_packed_generic_lt(yield_constr, is_lt, input0, input1, aux, output); - eval_packed_generic_lt(yield_constr, is_gt, input1, input0, aux, output); + eval_packed_generic_lt(yield_constr, is_lt, input0, input1, aux, output, false); + eval_packed_generic_lt(yield_constr, is_gt, input1, input0, aux, output, false); } fn eval_ext_circuit_check_is_one_bit, const D: usize>( @@ -112,6 +123,7 @@ pub(crate) fn eval_ext_circuit_lt, const D: usize>( input1: &[ExtensionTarget], aux: &[ExtensionTarget], output: ExtensionTarget, + is_two_row_op: bool, ) { debug_assert!(input0.len() == N_LIMBS && input1.len() == N_LIMBS && aux.len() == N_LIMBS); @@ -131,10 +143,11 @@ pub(crate) fn eval_ext_circuit_lt, const D: usize>( is_op, aux.iter().copied(), lhs_limbs.into_iter(), + is_two_row_op, ); let good_output = builder.sub_extension(cy, output); let filter = builder.mul_extension(is_op, good_output); - yield_constr.constraint(builder, filter); + yield_constr.constraint_transition(builder, filter); } pub fn eval_ext_circuit, const D: usize>( @@ -153,8 +166,26 @@ pub fn eval_ext_circuit, const D: usize>( let is_cmp = builder.add_extension(is_lt, is_gt); eval_ext_circuit_check_is_one_bit(builder, yield_constr, is_cmp, output); - 
eval_ext_circuit_lt(builder, yield_constr, is_lt, input0, input1, aux, output); - eval_ext_circuit_lt(builder, yield_constr, is_gt, input1, input0, aux, output); + eval_ext_circuit_lt( + builder, + yield_constr, + is_lt, + input0, + input1, + aux, + output, + false, + ); + eval_ext_circuit_lt( + builder, + yield_constr, + is_gt, + input1, + input0, + aux, + output, + false, + ); } #[cfg(test)] diff --git a/evm/src/arithmetic/mod.rs b/evm/src/arithmetic/mod.rs index a6f59446..1493b292 100644 --- a/evm/src/arithmetic/mod.rs +++ b/evm/src/arithmetic/mod.rs @@ -1,3 +1,9 @@ +use std::str::FromStr; + +use ethereum_types::U256; + +use crate::util::{addmod, mulmod, submod}; + mod add; mod compare; mod modular; @@ -7,3 +13,146 @@ mod utils; pub mod arithmetic_stark; pub(crate) mod columns; + +#[derive(Clone, Copy, Debug, Eq, PartialEq)] +pub(crate) enum BinaryOperator { + Add, + Mul, + Sub, + Div, + Mod, + Lt, + Gt, + Shl, + Shr, + AddFp254, + MulFp254, + SubFp254, +} + +impl BinaryOperator { + pub(crate) fn result(&self, input0: U256, input1: U256) -> U256 { + match self { + BinaryOperator::Add => input0.overflowing_add(input1).0, + BinaryOperator::Mul => input0.overflowing_mul(input1).0, + BinaryOperator::Sub => input0.overflowing_sub(input1).0, + BinaryOperator::Div => { + if input1.is_zero() { + U256::zero() + } else { + input0 / input1 + } + } + BinaryOperator::Mod => { + if input1.is_zero() { + U256::zero() + } else { + input0 % input1 + } + } + BinaryOperator::Lt => { + if input0 < input1 { + U256::one() + } else { + U256::zero() + } + } + BinaryOperator::Gt => { + if input0 > input1 { + U256::one() + } else { + U256::zero() + } + } + BinaryOperator::Shl => { + if input0 > 255.into() { + U256::zero() + } else { + input1 << input0 + } + } + BinaryOperator::Shr => { + if input0 > 255.into() { + U256::zero() + } else { + input1 >> input0 + } + } + BinaryOperator::AddFp254 => addmod(input0, input1, bn_base_order()), + BinaryOperator::MulFp254 => mulmod(input0, input1, 
bn_base_order()), + BinaryOperator::SubFp254 => submod(input0, input1, bn_base_order()), + } + } +} + +#[derive(Clone, Copy, Debug, Eq, PartialEq)] +pub(crate) enum TernaryOperator { + AddMod, + MulMod, +} + +impl TernaryOperator { + pub(crate) fn result(&self, input0: U256, input1: U256, input2: U256) -> U256 { + match self { + TernaryOperator::AddMod => addmod(input0, input1, input2), + TernaryOperator::MulMod => mulmod(input0, input1, input2), + } + } +} + +#[derive(Debug)] +#[allow(unused)] // TODO: Should be used soon. +pub(crate) enum Operation { + BinaryOperation { + operator: BinaryOperator, + input0: U256, + input1: U256, + result: U256, + }, + TernaryOperation { + operator: TernaryOperator, + input0: U256, + input1: U256, + input2: U256, + result: U256, + }, +} + +impl Operation { + pub(crate) fn binary(operator: BinaryOperator, input0: U256, input1: U256) -> Self { + let result = operator.result(input0, input1); + Self::BinaryOperation { + operator, + input0, + input1, + result, + } + } + + pub(crate) fn ternary( + operator: TernaryOperator, + input0: U256, + input1: U256, + input2: U256, + ) -> Self { + let result = operator.result(input0, input1, input2); + Self::TernaryOperation { + operator, + input0, + input1, + input2, + result, + } + } + + pub(crate) fn result(&self) -> U256 { + match self { + Operation::BinaryOperation { result, .. } => *result, + Operation::TernaryOperation { result, .. } => *result, + } + } +} + +fn bn_base_order() -> U256 { + U256::from_str("0x30644e72e131a029b85045b68181585d97816a916871ca8d3c208c16d87cfd47").unwrap() +} diff --git a/evm/src/arithmetic/modular.rs b/evm/src/arithmetic/modular.rs index 09c3996e..46f6c0fa 100644 --- a/evm/src/arithmetic/modular.rs +++ b/evm/src/arithmetic/modular.rs @@ -86,6 +86,27 @@ //! //! In the case of DIV, we do something similar, except that we "replace" //! the modulus with "2^256" to force the quotient to be zero. +//! +//! -*- +//! +//! 
NB: The implementation uses 9 * N_LIMBS = 144 columns because of +//! the requirements of the general purpose MULMOD; since ADDMOD, +//! SUBMOD, MOD and DIV are currently implemented in terms of the +//! general modular code, they also take 144 columns. Possible +//! improvements: +//! +//! - We could reduce the number of columns to 112 for ADDMOD, SUBMOD, +//! etc. if they were implemented separately, so they don't pay the +//! full cost of the general MULMOD. +//! +//! - All these operations could have alternative forms where the +//! output was not guaranteed to be reduced, which is often sufficient +//! in practice, and which would save a further 16 columns. +//! +//! - If the modulus is known in advance (such as for elliptic curve +//! arithmetic), specialised handling of MULMOD in that case would +//! only require 96 columns, or 80 if the output doesn't need to be +//! reduced. use num::bigint::Sign; use num::{BigInt, One, Zero}; @@ -171,11 +192,13 @@ fn bigint_to_columns(num: &BigInt) -> [i64; N] { /// zero if they are not used. fn generate_modular_op( lv: &mut [F; NUM_ARITH_COLUMNS], + nv: &mut [F; NUM_ARITH_COLUMNS], filter: usize, operation: fn([i64; N_LIMBS], [i64; N_LIMBS]) -> [i64; 2 * N_LIMBS - 1], ) { // Inputs are all range-checked in [0, 2^16), so the "as i64" // conversion is safe. 
+ let input0_limbs = read_value_i64_limbs(lv, MODULAR_INPUT_0); let input1_limbs = read_value_i64_limbs(lv, MODULAR_INPUT_1); let mut modulus_limbs = read_value_i64_limbs(lv, MODULAR_MODULUS); @@ -246,21 +269,38 @@ fn generate_modular_op( let aux_limbs = pol_remove_root_2exp::(constr_poly); lv[MODULAR_OUTPUT].copy_from_slice(&output_limbs.map(|c| F::from_canonical_i64(c))); - lv[MODULAR_OUT_AUX_RED].copy_from_slice(&out_aux_red.map(|c| F::from_canonical_i64(c))); - lv[MODULAR_QUO_INPUT].copy_from_slice("_limbs.map(|c| F::from_noncanonical_i64(c))); - lv[MODULAR_AUX_INPUT].copy_from_slice(&aux_limbs.map(|c| F::from_noncanonical_i64(c))); - lv[MODULAR_MOD_IS_ZERO] = mod_is_zero; + + // Copy lo and hi halves of quot_limbs into their respective registers + for (i, &lo) in MODULAR_QUO_INPUT_LO.zip("_limbs[..N_LIMBS]) { + lv[i] = F::from_noncanonical_i64(lo); + } + for (i, &hi) in MODULAR_QUO_INPUT_HI.zip("_limbs[N_LIMBS..]) { + nv[i] = F::from_noncanonical_i64(hi); + } + + for (i, &c) in MODULAR_AUX_INPUT.zip(&aux_limbs[..2 * N_LIMBS - 1]) { + nv[i] = F::from_noncanonical_i64(c); + } + + nv[MODULAR_MOD_IS_ZERO] = mod_is_zero; + nv[MODULAR_OUT_AUX_RED].copy_from_slice(&out_aux_red.map(|c| F::from_canonical_i64(c))); } /// Generate the output and auxiliary values for modular operations. /// /// `filter` must be one of `columns::IS_{ADDMOD,MULMOD,MOD}`. 
-pub(crate) fn generate(lv: &mut [F; NUM_ARITH_COLUMNS], filter: usize) { +pub(crate) fn generate( + lv: &mut [F; NUM_ARITH_COLUMNS], + nv: &mut [F; NUM_ARITH_COLUMNS], + filter: usize, +) { match filter { - columns::IS_ADDMOD => generate_modular_op(lv, filter, pol_add), - columns::IS_SUBMOD => generate_modular_op(lv, filter, pol_sub), - columns::IS_MULMOD => generate_modular_op(lv, filter, pol_mul_wide), - columns::IS_MOD | columns::IS_DIV => generate_modular_op(lv, filter, |a, _| pol_extend(a)), + columns::IS_ADDMOD => generate_modular_op(lv, nv, filter, pol_add), + columns::IS_SUBMOD => generate_modular_op(lv, nv, filter, pol_sub), + columns::IS_MULMOD => generate_modular_op(lv, nv, filter, pol_mul_wide), + columns::IS_MOD | columns::IS_DIV => { + generate_modular_op(lv, nv, filter, |a, _| pol_extend(a)) + } _ => panic!("generate modular operation called with unknown opcode"), } } @@ -275,26 +315,28 @@ pub(crate) fn generate(lv: &mut [F; NUM_ARITH_COLUMNS], filter: us /// and check consistency when m = 0, and that c is reduced. fn modular_constr_poly( lv: &[P; NUM_ARITH_COLUMNS], + nv: &[P; NUM_ARITH_COLUMNS], yield_constr: &mut ConstraintConsumer

, filter: P, ) -> [P; 2 * N_LIMBS] { range_check_error!(MODULAR_INPUT_0, 16); range_check_error!(MODULAR_INPUT_1, 16); range_check_error!(MODULAR_MODULUS, 16); - range_check_error!(MODULAR_QUO_INPUT, 16); + range_check_error!(MODULAR_QUO_INPUT_LO, 16); + range_check_error!(MODULAR_QUO_INPUT_HI, 16); range_check_error!(MODULAR_AUX_INPUT, 20, signed); range_check_error!(MODULAR_OUTPUT, 16); let mut modulus = read_value::(lv, MODULAR_MODULUS); - let mod_is_zero = lv[MODULAR_MOD_IS_ZERO]; + let mod_is_zero = nv[MODULAR_MOD_IS_ZERO]; // Check that mod_is_zero is zero or one - yield_constr.constraint(filter * (mod_is_zero * mod_is_zero - mod_is_zero)); + yield_constr.constraint_transition(filter * (mod_is_zero * mod_is_zero - mod_is_zero)); // Check that mod_is_zero is zero if modulus is not zero (they // could both be zero) let limb_sum = modulus.into_iter().sum::

(); - yield_constr.constraint(filter * limb_sum * mod_is_zero); + yield_constr.constraint_transition(filter * limb_sum * mod_is_zero); // See the file documentation for why this suffices to handle // modulus = 0. @@ -308,8 +350,8 @@ fn modular_constr_poly( output[0] += mod_is_zero * lv[IS_DIV]; // Verify that the output is reduced, i.e. output < modulus. - let out_aux_red = &lv[MODULAR_OUT_AUX_RED]; - // this sets is_less_than to 1 unless we get mod_is_zero when + let out_aux_red = &nv[MODULAR_OUT_AUX_RED]; + // This sets is_less_than to 1 unless we get mod_is_zero when // doing a DIV; in that case, we need is_less_than=0, since the // function checks // @@ -317,6 +359,8 @@ fn modular_constr_poly( // // and we were given output = out_aux_red let is_less_than = P::ONES - mod_is_zero * lv[IS_DIV]; + // NB: output and modulus in lv while out_aux_red and is_less_than + // (via mod_is_zero) depend on nv. eval_packed_generic_lt( yield_constr, filter, @@ -324,16 +368,23 @@ fn modular_constr_poly( &modulus, out_aux_red, is_less_than, + true, ); // restore output[0] output[0] -= mod_is_zero * lv[IS_DIV]; // prod = q(x) * m(x) - let quot = read_value::<{ 2 * N_LIMBS }, _>(lv, MODULAR_QUO_INPUT); + let quot = { + let mut quot = [P::default(); 2 * N_LIMBS]; + quot[..N_LIMBS].copy_from_slice(&lv[MODULAR_QUO_INPUT_LO]); + quot[N_LIMBS..].copy_from_slice(&nv[MODULAR_QUO_INPUT_HI]); + quot + }; + let prod = pol_mul_wide2(quot, modulus); // higher order terms must be zero for &x in prod[2 * N_LIMBS..].iter() { - yield_constr.constraint(filter * x); + yield_constr.constraint_transition(filter * x); } // constr_poly = c(x) + q(x) * m(x) @@ -341,8 +392,11 @@ fn modular_constr_poly( pol_add_assign(&mut constr_poly, &output); // constr_poly = c(x) + q(x) * m(x) + (x - β) * s(x) - let mut aux = read_value::<{ 2 * N_LIMBS }, _>(lv, MODULAR_AUX_INPUT); - aux[2 * N_LIMBS - 1] = P::ZEROS; // zero out the MOD_IS_ZERO flag + let mut aux = [P::ZEROS; 2 * N_LIMBS]; + for (i, j) in 
MODULAR_AUX_INPUT.enumerate() { + aux[i] = nv[j]; + } + let base = P::Scalar::from_canonical_u64(1 << LIMB_BITS); pol_add_assign(&mut constr_poly, &pol_adjoin_root(aux, base)); @@ -352,6 +406,7 @@ fn modular_constr_poly( /// Add constraints for modular operations. pub(crate) fn eval_packed_generic( lv: &[P; NUM_ARITH_COLUMNS], + nv: &[P; NUM_ARITH_COLUMNS], yield_constr: &mut ConstraintConsumer

, ) { // NB: The CTL code guarantees that filter is 0 or 1, i.e. that @@ -362,8 +417,12 @@ pub(crate) fn eval_packed_generic( + lv[columns::IS_SUBMOD] + lv[columns::IS_DIV]; + // Ensure that this operation is not the last row of the table; + // needed because we access the next row of the table in nv. + yield_constr.constraint_last_row(filter); + // constr_poly has 2*N_LIMBS limbs - let constr_poly = modular_constr_poly(lv, yield_constr, filter); + let constr_poly = modular_constr_poly(lv, nv, yield_constr, filter); let input0 = read_value(lv, MODULAR_INPUT_0); let input1 = read_value(lv, MODULAR_INPUT_1); @@ -394,35 +453,36 @@ pub(crate) fn eval_packed_generic( // operation is valid if and only if all of those coefficients // are zero. for &c in constr_poly_copy.iter() { - yield_constr.constraint(filter * c); + yield_constr.constraint_transition(filter * c); } } } fn modular_constr_poly_ext_circuit, const D: usize>( lv: &[ExtensionTarget; NUM_ARITH_COLUMNS], + nv: &[ExtensionTarget; NUM_ARITH_COLUMNS], builder: &mut CircuitBuilder, yield_constr: &mut RecursiveConstraintConsumer, filter: ExtensionTarget, ) -> [ExtensionTarget; 2 * N_LIMBS] { let mut modulus = read_value::(lv, MODULAR_MODULUS); - let mod_is_zero = lv[MODULAR_MOD_IS_ZERO]; + let mod_is_zero = nv[MODULAR_MOD_IS_ZERO]; let t = builder.mul_sub_extension(mod_is_zero, mod_is_zero, mod_is_zero); let t = builder.mul_extension(filter, t); - yield_constr.constraint(builder, t); + yield_constr.constraint_transition(builder, t); let limb_sum = builder.add_many_extension(modulus); let t = builder.mul_extension(limb_sum, mod_is_zero); let t = builder.mul_extension(filter, t); - yield_constr.constraint(builder, t); + yield_constr.constraint_transition(builder, t); modulus[0] = builder.add_extension(modulus[0], mod_is_zero); let mut output = read_value::(lv, MODULAR_OUTPUT); output[0] = builder.mul_add_extension(mod_is_zero, lv[IS_DIV], output[0]); - let out_aux_red = &lv[MODULAR_OUT_AUX_RED]; + let out_aux_red = 
&nv[MODULAR_OUT_AUX_RED]; let one = builder.one_extension(); let is_less_than = builder.arithmetic_extension(F::NEG_ONE, F::ONE, mod_is_zero, lv[IS_DIV], one); @@ -435,22 +495,33 @@ fn modular_constr_poly_ext_circuit, const D: usize> &modulus, out_aux_red, is_less_than, + true, ); output[0] = builder.arithmetic_extension(F::NEG_ONE, F::ONE, mod_is_zero, lv[IS_DIV], output[0]); + let quot = { + let zero = builder.zero_extension(); + let mut quot = [zero; 2 * N_LIMBS]; + quot[..N_LIMBS].copy_from_slice(&lv[MODULAR_QUO_INPUT_LO]); + quot[N_LIMBS..].copy_from_slice(&nv[MODULAR_QUO_INPUT_HI]); + quot + }; - let quot = read_value::<{ 2 * N_LIMBS }, _>(lv, MODULAR_QUO_INPUT); let prod = pol_mul_wide2_ext_circuit(builder, quot, modulus); for &x in prod[2 * N_LIMBS..].iter() { let t = builder.mul_extension(filter, x); - yield_constr.constraint(builder, t); + yield_constr.constraint_transition(builder, t); } let mut constr_poly: [_; 2 * N_LIMBS] = prod[0..2 * N_LIMBS].try_into().unwrap(); pol_add_assign_ext_circuit(builder, &mut constr_poly, &output); - let mut aux = read_value::<{ 2 * N_LIMBS }, _>(lv, MODULAR_AUX_INPUT); - aux[2 * N_LIMBS - 1] = builder.zero_extension(); + let zero = builder.zero_extension(); + let mut aux = [zero; 2 * N_LIMBS]; + for (i, j) in MODULAR_AUX_INPUT.enumerate() { + aux[i] = nv[j]; + } + let base = builder.constant_extension(F::Extension::from_canonical_u64(1u64 << LIMB_BITS)); let t = pol_adjoin_root_ext_circuit(builder, aux, base); pol_add_assign_ext_circuit(builder, &mut constr_poly, &t); @@ -461,6 +532,7 @@ fn modular_constr_poly_ext_circuit, const D: usize> pub(crate) fn eval_ext_circuit, const D: usize>( builder: &mut CircuitBuilder, lv: &[ExtensionTarget; NUM_ARITH_COLUMNS], + nv: &[ExtensionTarget; NUM_ARITH_COLUMNS], yield_constr: &mut RecursiveConstraintConsumer, ) { let filter = builder.add_many_extension([ @@ -471,8 +543,9 @@ pub(crate) fn eval_ext_circuit, const D: usize>( lv[columns::IS_DIV], ]); - let constr_poly = 
modular_constr_poly_ext_circuit(lv, builder, yield_constr, filter); + yield_constr.constraint_last_row(builder, filter); + let constr_poly = modular_constr_poly_ext_circuit(lv, nv, builder, yield_constr, filter); let input0 = read_value(lv, MODULAR_INPUT_0); let input1 = read_value(lv, MODULAR_INPUT_1); @@ -492,7 +565,7 @@ pub(crate) fn eval_ext_circuit, const D: usize>( pol_sub_assign_ext_circuit(builder, &mut constr_poly_copy, input); for &c in constr_poly_copy.iter() { let t = builder.mul_extension(filter, c); - yield_constr.constraint(builder, t); + yield_constr.constraint_transition(builder, t); } } } @@ -518,6 +591,7 @@ mod tests { let mut rng = ChaCha8Rng::seed_from_u64(0x6feb51b7ec230f25); let mut lv = [F::default(); NUM_ARITH_COLUMNS].map(|_| F::sample(&mut rng)); + let nv = [F::default(); NUM_ARITH_COLUMNS].map(|_| F::sample(&mut rng)); // if `IS_ADDMOD == 0`, then the constraints should be met even // if all values are garbage. @@ -533,7 +607,7 @@ mod tests { GoldilocksField::ONE, GoldilocksField::ONE, ); - eval_packed_generic(&lv, &mut constraint_consumer); + eval_packed_generic(&lv, &nv, &mut constraint_consumer); for &acc in &constraint_consumer.constraint_accs { assert_eq!(acc, GoldilocksField::ZERO); } @@ -545,6 +619,7 @@ mod tests { let mut rng = ChaCha8Rng::seed_from_u64(0x6feb51b7ec230f25); let mut lv = [F::default(); NUM_ARITH_COLUMNS].map(|_| F::sample(&mut rng)); + let mut nv = [F::default(); NUM_ARITH_COLUMNS].map(|_| F::sample(&mut rng)); for op_filter in [IS_ADDMOD, IS_DIV, IS_SUBMOD, IS_MOD, IS_MULMOD] { // Reset operation columns, then select one @@ -563,9 +638,9 @@ mod tests { lv[mi] = F::from_canonical_u16(rng.gen()); } - // For the second half of the tests, set the top 16 - - // start digits of the modulus to zero so it is much - // smaller than the inputs. + // For the second half of the tests, set the top + // 16-start digits of the modulus to zero so it is + // much smaller than the inputs. 
if i > N_RND_TESTS / 2 { // 1 <= start < N_LIMBS let start = (rng.gen::() % (N_LIMBS - 1)) + 1; @@ -574,15 +649,15 @@ mod tests { } } - generate(&mut lv, op_filter); + generate(&mut lv, &mut nv, op_filter); let mut constraint_consumer = ConstraintConsumer::new( vec![GoldilocksField(2), GoldilocksField(3), GoldilocksField(5)], GoldilocksField::ONE, - GoldilocksField::ONE, - GoldilocksField::ONE, + GoldilocksField::ZERO, + GoldilocksField::ZERO, ); - eval_packed_generic(&lv, &mut constraint_consumer); + eval_packed_generic(&lv, &nv, &mut constraint_consumer); for &acc in &constraint_consumer.constraint_accs { assert_eq!(acc, GoldilocksField::ZERO); } @@ -596,6 +671,7 @@ mod tests { let mut rng = ChaCha8Rng::seed_from_u64(0x6feb51b7ec230f25); let mut lv = [F::default(); NUM_ARITH_COLUMNS].map(|_| F::sample(&mut rng)); + let mut nv = [F::default(); NUM_ARITH_COLUMNS].map(|_| F::sample(&mut rng)); for op_filter in [IS_ADDMOD, IS_SUBMOD, IS_DIV, IS_MOD, IS_MULMOD] { // Reset operation columns, then select one @@ -609,13 +685,14 @@ mod tests { for _i in 0..N_RND_TESTS { // set inputs to random values and the modulus to zero; // the output is defined to be zero when modulus is zero. 
+ for (ai, bi, mi) in izip!(MODULAR_INPUT_0, MODULAR_INPUT_1, MODULAR_MODULUS) { lv[ai] = F::from_canonical_u16(rng.gen()); lv[bi] = F::from_canonical_u16(rng.gen()); lv[mi] = F::ZERO; } - generate(&mut lv, op_filter); + generate(&mut lv, &mut nv, op_filter); // check that the correct output was generated if op_filter == IS_DIV { @@ -627,24 +704,25 @@ mod tests { let mut constraint_consumer = ConstraintConsumer::new( vec![GoldilocksField(2), GoldilocksField(3), GoldilocksField(5)], GoldilocksField::ONE, - GoldilocksField::ONE, - GoldilocksField::ONE, + GoldilocksField::ZERO, + GoldilocksField::ZERO, ); - eval_packed_generic(&lv, &mut constraint_consumer); + eval_packed_generic(&lv, &nv, &mut constraint_consumer); assert!(constraint_consumer .constraint_accs .iter() .all(|&acc| acc == F::ZERO)); // Corrupt one output limb by setting it to a non-zero value - let random_oi = if op_filter == IS_DIV { - DIV_OUTPUT.start + rng.gen::() % N_LIMBS + if op_filter == IS_DIV { + let random_oi = DIV_OUTPUT.start + rng.gen::() % N_LIMBS; + lv[random_oi] = F::from_canonical_u16(rng.gen_range(1..u16::MAX)); } else { - MODULAR_OUTPUT.start + rng.gen::() % N_LIMBS + let random_oi = MODULAR_OUTPUT.start + rng.gen::() % N_LIMBS; + lv[random_oi] = F::from_canonical_u16(rng.gen_range(1..u16::MAX)); }; - lv[random_oi] = F::from_canonical_u16(rng.gen_range(1..u16::MAX)); - eval_packed_generic(&lv, &mut constraint_consumer); + eval_packed_generic(&lv, &nv, &mut constraint_consumer); // Check that at least one of the constraints was non-zero assert!(constraint_consumer diff --git a/evm/src/arithmetic/sub.rs b/evm/src/arithmetic/sub.rs index d589f323..13f6e8d5 100644 --- a/evm/src/arithmetic/sub.rs +++ b/evm/src/arithmetic/sub.rs @@ -57,6 +57,7 @@ pub fn eval_packed_generic( is_sub, output_limbs.iter().copied(), output_computed, + false, ); } @@ -87,6 +88,7 @@ pub fn eval_ext_circuit, const D: usize>( is_sub, output_limbs.iter().copied(), output_computed.into_iter(), + false, ); } diff --git 
a/evm/src/cpu/bootstrap_kernel.rs b/evm/src/cpu/bootstrap_kernel.rs index 0a894553..ba45f738 100644 --- a/evm/src/cpu/bootstrap_kernel.rs +++ b/evm/src/cpu/bootstrap_kernel.rs @@ -13,53 +13,47 @@ use plonky2::plonk::circuit_builder::CircuitBuilder; use crate::constraint_consumer::{ConstraintConsumer, RecursiveConstraintConsumer}; use crate::cpu::columns::{CpuColumnsView, NUM_CPU_COLUMNS}; use crate::cpu::kernel::aggregator::KERNEL; -use crate::cpu::kernel::keccak_util::keccakf_u32s; +use crate::cpu::membus::NUM_GP_CHANNELS; use crate::generation::state::GenerationState; -use crate::keccak_sponge::columns::KECCAK_RATE_U32S; use crate::memory::segments::Segment; use crate::vars::{StarkEvaluationTargets, StarkEvaluationVars}; - -/// We can't process more than `NUM_CHANNELS` bytes per row, since that's all the memory bandwidth -/// we have. We also can't process more than 4 bytes (or the number of bytes in a `u32`), since we -/// want them to fit in a single limb of Keccak input. -const BYTES_PER_ROW: usize = 4; +use crate::witness::memory::MemoryAddress; +use crate::witness::util::{keccak_sponge_log, mem_write_gp_log_and_fill}; pub(crate) fn generate_bootstrap_kernel(state: &mut GenerationState) { - let mut sponge_state = [0u32; 50]; - let mut sponge_input_pos: usize = 0; - // Iterate through chunks of the code, such that we can write one chunk to memory per row. - for chunk in &KERNEL - .padded_code() - .iter() - .enumerate() - .chunks(BYTES_PER_ROW) - { - state.current_cpu_row.is_bootstrap_kernel = F::ONE; + for chunk in &KERNEL.code.iter().enumerate().chunks(NUM_GP_CHANNELS) { + let mut cpu_row = CpuColumnsView::default(); + cpu_row.clock = F::from_canonical_usize(state.traces.clock()); + cpu_row.is_bootstrap_kernel = F::ONE; // Write this chunk to memory, while simultaneously packing its bytes into a u32 word. 
- let mut packed_bytes: u32 = 0; for (channel, (addr, &byte)) in chunk.enumerate() { - state.set_mem_cpu_current(channel, Segment::Code, addr, byte.into()); - - packed_bytes = (packed_bytes << 8) | byte as u32; + let address = MemoryAddress::new(0, Segment::Code, addr); + let write = + mem_write_gp_log_and_fill(channel, address, state, &mut cpu_row, byte.into()); + state.traces.push_memory(write); } - sponge_state[sponge_input_pos] = packed_bytes; - let keccak = state.current_cpu_row.general.keccak_mut(); - keccak.input_limbs = sponge_state.map(F::from_canonical_u32); - state.commit_cpu_row(); - - sponge_input_pos = (sponge_input_pos + 1) % KECCAK_RATE_U32S; - // If we just crossed a multiple of KECCAK_RATE_LIMBS, then we've filled the Keccak input - // buffer, so it's time to absorb. - if sponge_input_pos == 0 { - state.current_cpu_row.is_keccak = F::ONE; - keccakf_u32s(&mut sponge_state); - let keccak = state.current_cpu_row.general.keccak_mut(); - keccak.output_limbs = sponge_state.map(F::from_canonical_u32); - } + state.traces.push_cpu(cpu_row); } + + let mut final_cpu_row = CpuColumnsView::default(); + final_cpu_row.clock = F::from_canonical_usize(state.traces.clock()); + final_cpu_row.is_bootstrap_kernel = F::ONE; + final_cpu_row.is_keccak_sponge = F::ONE; + // The Keccak sponge CTL uses memory value columns for its inputs and outputs. 
+ final_cpu_row.mem_channels[0].value[0] = F::ZERO; + final_cpu_row.mem_channels[1].value[0] = F::from_canonical_usize(Segment::Code as usize); + final_cpu_row.mem_channels[2].value[0] = F::ZERO; + final_cpu_row.mem_channels[3].value[0] = F::from_canonical_usize(state.traces.clock()); + final_cpu_row.mem_channels[4].value = KERNEL.code_hash.map(F::from_canonical_u32); + state.traces.push_cpu(final_cpu_row); + keccak_sponge_log( + state, + MemoryAddress::new(0, Segment::Code, 0), + KERNEL.code.clone(), + ); } pub(crate) fn eval_bootstrap_kernel>( @@ -77,19 +71,25 @@ pub(crate) fn eval_bootstrap_kernel>( let delta_is_bootstrap = next_is_bootstrap - local_is_bootstrap; yield_constr.constraint_transition(delta_is_bootstrap * (delta_is_bootstrap + P::ONES)); - // TODO: Constraints to enforce that, if IS_BOOTSTRAP_KERNEL, - // - If CLOCK is a multiple of KECCAK_RATE_LIMBS, activate the Keccak CTL, and ensure the output - // is copied to the next row (besides the first limb which will immediately be overwritten). - // - Otherwise, ensure that the Keccak input is copied to the next row (besides the next limb). - // - The next limb we add to the buffer is also written to memory. + // If this is a bootloading row and the i'th memory channel is used, it must have the right + // address, namely context = 0, segment = Code, virt = clock * NUM_GP_CHANNELS + i.
+ let code_segment = F::from_canonical_usize(Segment::Code as usize); + for (i, channel) in local_values.mem_channels.iter().enumerate() { + let filter = local_is_bootstrap * channel.used; + yield_constr.constraint(filter * channel.addr_context); + yield_constr.constraint(filter * (channel.addr_segment - code_segment)); + let expected_virt = local_values.clock * F::from_canonical_usize(NUM_GP_CHANNELS) + + F::from_canonical_usize(i); + yield_constr.constraint(filter * (channel.addr_virtual - expected_virt)); + } - // If IS_BOOTSTRAP_KERNEL changed (from 1 to 0), check that - // - the clock is a multiple of KECCAK_RATE_LIMBS (TODO) + // If this is the final bootstrap row (i.e. delta_is_bootstrap = 1), check that + // - all memory channels are disabled (TODO) // - the current kernel hash matches a precomputed one for (&expected, actual) in KERNEL .code_hash .iter() - .zip(local_values.general.keccak().output_limbs) + .zip(local_values.mem_channels.last().unwrap().value) { let expected = P::from(F::from_canonical_u32(expected)); let diff = expected - actual; @@ -117,19 +117,35 @@ pub(crate) fn eval_bootstrap_kernel_circuit, const builder.mul_add_extension(delta_is_bootstrap, delta_is_bootstrap, delta_is_bootstrap); yield_constr.constraint_transition(builder, constraint); - // TODO: Constraints to enforce that, if IS_BOOTSTRAP_KERNEL, - // - If CLOCK is a multiple of KECCAK_RATE_LIMBS, activate the Keccak CTL, and ensure the output - // is copied to the next row (besides the first limb which will immediately be overwritten). - // - Otherwise, ensure that the Keccak input is copied to the next row (besides the next limb). - // - The next limb we add to the buffer is also written to memory. + // If this is a bootloading row and the i'th memory channel is used, it must have the right + // address, namely context = 0, segment = Code, virt = clock * NUM_GP_CHANNELS + i.
+ let code_segment = + builder.constant_extension(F::Extension::from_canonical_usize(Segment::Code as usize)); + for (i, channel) in local_values.mem_channels.iter().enumerate() { + let filter = builder.mul_extension(local_is_bootstrap, channel.used); + let constraint = builder.mul_extension(filter, channel.addr_context); + yield_constr.constraint(builder, constraint); - // If IS_BOOTSTRAP_KERNEL changed (from 1 to 0), check that - // - the clock is a multiple of KECCAK_RATE_LIMBS (TODO) + let segment_diff = builder.sub_extension(channel.addr_segment, code_segment); + let constraint = builder.mul_extension(filter, segment_diff); + yield_constr.constraint(builder, constraint); + + let i_ext = builder.constant_extension(F::Extension::from_canonical_usize(i)); + let num_gp_channels_f = F::from_canonical_usize(NUM_GP_CHANNELS); + let expected_virt = + builder.mul_const_add_extension(num_gp_channels_f, local_values.clock, i_ext); + let virt_diff = builder.sub_extension(channel.addr_virtual, expected_virt); + let constraint = builder.mul_extension(filter, virt_diff); + yield_constr.constraint(builder, constraint); + } + + // If this is the final bootstrap row (i.e. delta_is_bootstrap = 1), check that + // - all memory channels are disabled (TODO) // - the current kernel hash matches a precomputed one for (&expected, actual) in KERNEL .code_hash .iter() - .zip(local_values.general.keccak().output_limbs) + .zip(local_values.mem_channels.last().unwrap().value) { let expected = builder.constant_extension(F::Extension::from_canonical_u32(expected)); let diff = builder.sub_extension(expected, actual); diff --git a/evm/src/cpu/columns/general.rs b/evm/src/cpu/columns/general.rs index 5a2c9426..67fe4256 100644 --- a/evm/src/cpu/columns/general.rs +++ b/evm/src/cpu/columns/general.rs @@ -4,8 +4,8 @@ use std::mem::{size_of, transmute}; /// General purpose columns, which can have different meanings depending on what CTL or other /// operation is occurring at this row. 
+#[derive(Clone, Copy)] pub(crate) union CpuGeneralColumnsView { - keccak: CpuKeccakView, arithmetic: CpuArithmeticView, logic: CpuLogicView, jumps: CpuJumpsView, @@ -13,16 +13,6 @@ pub(crate) union CpuGeneralColumnsView { } impl CpuGeneralColumnsView { - // SAFETY: Each view is a valid interpretation of the underlying array. - pub(crate) fn keccak(&self) -> &CpuKeccakView { - unsafe { &self.keccak } - } - - // SAFETY: Each view is a valid interpretation of the underlying array. - pub(crate) fn keccak_mut(&mut self) -> &mut CpuKeccakView { - unsafe { &mut self.keccak } - } - // SAFETY: Each view is a valid interpretation of the underlying array. pub(crate) fn arithmetic(&self) -> &CpuArithmeticView { unsafe { &self.arithmetic } @@ -93,12 +83,6 @@ impl BorrowMut<[T; NUM_SHARED_COLUMNS]> for CpuGeneralColumnsView { } } -#[derive(Copy, Clone)] -pub(crate) struct CpuKeccakView { - pub(crate) input_limbs: [T; 50], - pub(crate) output_limbs: [T; 50], -} - #[derive(Copy, Clone)] pub(crate) struct CpuArithmeticView { // TODO: Add "looking" columns for the arithmetic CTL. diff --git a/evm/src/cpu/columns/mod.rs b/evm/src/cpu/columns/mod.rs index d0ef3f28..408e17dc 100644 --- a/evm/src/cpu/columns/mod.rs +++ b/evm/src/cpu/columns/mod.rs @@ -6,6 +6,8 @@ use std::fmt::Debug; use std::mem::{size_of, transmute}; use std::ops::{Index, IndexMut}; +use plonky2::field::types::Field; + use crate::cpu::columns::general::CpuGeneralColumnsView; use crate::cpu::columns::ops::OpsColumnsView; use crate::cpu::membus::NUM_GP_CHANNELS; @@ -31,7 +33,7 @@ pub struct MemoryChannelView { } #[repr(C)] -#[derive(Eq, PartialEq, Debug)] +#[derive(Clone, Copy, Eq, PartialEq, Debug)] pub struct CpuColumnsView { /// Filter. 1 if the row is part of bootstrapping the kernel code, 0 otherwise. pub is_bootstrap_kernel: T, @@ -67,11 +69,8 @@ pub struct CpuColumnsView { /// If CPU cycle: the opcode, broken up into bits in little-endian order. pub opcode_bits: [T; 8], - /// Filter. 
1 iff a Keccak lookup is performed on this row. - pub is_keccak: T, - - /// Filter. 1 iff a Keccak memory lookup is performed on this row. - pub is_keccak_memory: T, + /// Filter. 1 iff a Keccak sponge lookup is performed on this row. + pub is_keccak_sponge: T, pub(crate) general: CpuGeneralColumnsView, @@ -82,6 +81,12 @@ pub struct CpuColumnsView { // `u8` is guaranteed to have a `size_of` of 1. pub const NUM_CPU_COLUMNS: usize = size_of::>(); +impl Default for CpuColumnsView { + fn default() -> Self { + Self::from([F::ZERO; NUM_CPU_COLUMNS]) + } +} + impl From<[T; NUM_CPU_COLUMNS]> for CpuColumnsView { fn from(value: [T; NUM_CPU_COLUMNS]) -> Self { unsafe { transmute_no_compile_time_size_checks(value) } diff --git a/evm/src/cpu/columns/ops.rs b/evm/src/cpu/columns/ops.rs index c265be44..63f6795d 100644 --- a/evm/src/cpu/columns/ops.rs +++ b/evm/src/cpu/columns/ops.rs @@ -5,8 +5,8 @@ use std::ops::{Deref, DerefMut}; use crate::util::{indices_arr, transmute_no_compile_time_size_checks}; #[repr(C)] -#[derive(Eq, PartialEq, Debug)] -pub struct OpsColumnsView { +#[derive(Clone, Copy, Eq, PartialEq, Debug)] +pub struct OpsColumnsView { // TODO: combine ADD, MUL, SUB, DIV, MOD, ADDFP254, MULFP254, SUBFP254, LT, and GT into one flag pub add: T, pub mul: T, @@ -41,12 +41,6 @@ pub struct OpsColumnsView { pub pc: T, pub gas: T, pub jumpdest: T, - // TODO: combine GET_STATE_ROOT and SET_STATE_ROOT into one flag - pub get_state_root: T, - pub set_state_root: T, - // TODO: combine GET_RECEIPT_ROOT and SET_RECEIPT_ROOT into one flag - pub get_receipt_root: T, - pub set_receipt_root: T, pub push: T, pub dup: T, pub swap: T, @@ -65,38 +59,38 @@ pub struct OpsColumnsView { // `u8` is guaranteed to have a `size_of` of 1. 
pub const NUM_OPS_COLUMNS: usize = size_of::>(); -impl From<[T; NUM_OPS_COLUMNS]> for OpsColumnsView { +impl From<[T; NUM_OPS_COLUMNS]> for OpsColumnsView { fn from(value: [T; NUM_OPS_COLUMNS]) -> Self { unsafe { transmute_no_compile_time_size_checks(value) } } } -impl From> for [T; NUM_OPS_COLUMNS] { +impl From> for [T; NUM_OPS_COLUMNS] { fn from(value: OpsColumnsView) -> Self { unsafe { transmute_no_compile_time_size_checks(value) } } } -impl Borrow> for [T; NUM_OPS_COLUMNS] { +impl Borrow> for [T; NUM_OPS_COLUMNS] { fn borrow(&self) -> &OpsColumnsView { unsafe { transmute(self) } } } -impl BorrowMut> for [T; NUM_OPS_COLUMNS] { +impl BorrowMut> for [T; NUM_OPS_COLUMNS] { fn borrow_mut(&mut self) -> &mut OpsColumnsView { unsafe { transmute(self) } } } -impl Deref for OpsColumnsView { +impl Deref for OpsColumnsView { type Target = [T; NUM_OPS_COLUMNS]; fn deref(&self) -> &Self::Target { unsafe { transmute(self) } } } -impl DerefMut for OpsColumnsView { +impl DerefMut for OpsColumnsView { fn deref_mut(&mut self) -> &mut Self::Target { unsafe { transmute(self) } } diff --git a/evm/src/cpu/control_flow.rs b/evm/src/cpu/control_flow.rs index ba0bbd3b..c0adc7bd 100644 --- a/evm/src/cpu/control_flow.rs +++ b/evm/src/cpu/control_flow.rs @@ -8,7 +8,7 @@ use crate::constraint_consumer::{ConstraintConsumer, RecursiveConstraintConsumer use crate::cpu::columns::{CpuColumnsView, COL_MAP}; use crate::cpu::kernel::aggregator::KERNEL; -const NATIVE_INSTRUCTIONS: [usize; 37] = [ +const NATIVE_INSTRUCTIONS: [usize; 33] = [ COL_MAP.op.add, COL_MAP.op.mul, COL_MAP.op.sub, @@ -37,10 +37,6 @@ const NATIVE_INSTRUCTIONS: [usize; 37] = [ COL_MAP.op.pc, COL_MAP.op.gas, COL_MAP.op.jumpdest, - COL_MAP.op.get_state_root, - COL_MAP.op.set_state_root, - COL_MAP.op.get_receipt_root, - COL_MAP.op.set_receipt_root, // not PUSH (need to increment by more than 1) COL_MAP.op.dup, COL_MAP.op.swap, @@ -53,7 +49,7 @@ const NATIVE_INSTRUCTIONS: [usize; 37] = [ // not SYSCALL (performs a jump) ]; -fn 
get_halt_pcs() -> (F, F) { +pub(crate) fn get_halt_pcs() -> (F, F) { let halt_pc0 = KERNEL.global_labels["halt_pc0"]; let halt_pc1 = KERNEL.global_labels["halt_pc1"]; @@ -63,6 +59,12 @@ fn get_halt_pcs() -> (F, F) { ) } +pub(crate) fn get_start_pc() -> F { + let start_pc = KERNEL.global_labels["main"]; + + F::from_canonical_usize(start_pc) +} + pub fn eval_packed_generic( lv: &CpuColumnsView

, nv: &CpuColumnsView

, @@ -89,8 +91,7 @@ pub fn eval_packed_generic( // - execution is in kernel mode, and // - the stack is empty. let is_last_noncpu_cycle = (lv.is_cpu_cycle - P::ONES) * nv.is_cpu_cycle; - let pc_diff = - nv.program_counter - P::Scalar::from_canonical_usize(KERNEL.global_labels["main"]); + let pc_diff = nv.program_counter - get_start_pc::(); yield_constr.constraint_transition(is_last_noncpu_cycle * pc_diff); yield_constr.constraint_transition(is_last_noncpu_cycle * (nv.is_kernel_mode - P::ONES)); yield_constr.constraint_transition(is_last_noncpu_cycle * nv.stack_len); @@ -142,9 +143,7 @@ pub fn eval_ext_circuit, const D: usize>( builder.mul_sub_extension(lv.is_cpu_cycle, nv.is_cpu_cycle, nv.is_cpu_cycle); // Start at `main`. - let main = builder.constant_extension(F::Extension::from_canonical_usize( - KERNEL.global_labels["main"], - )); + let main = builder.constant_extension(get_start_pc::().into()); let pc_diff = builder.sub_extension(nv.program_counter, main); let pc_constr = builder.mul_extension(is_last_noncpu_cycle, pc_diff); yield_constr.constraint_transition(builder, pc_constr); diff --git a/evm/src/cpu/cpu_stark.rs b/evm/src/cpu/cpu_stark.rs index 4cc38823..f0c85638 100644 --- a/evm/src/cpu/cpu_stark.rs +++ b/evm/src/cpu/cpu_stark.rs @@ -20,34 +20,28 @@ use crate::memory::{NUM_CHANNELS, VALUE_LIMBS}; use crate::stark::Stark; use crate::vars::{StarkEvaluationTargets, StarkEvaluationVars}; -pub fn ctl_data_keccak() -> Vec> { - let keccak = COL_MAP.general.keccak(); - let mut res: Vec<_> = Column::singles(keccak.input_limbs).collect(); - res.extend(Column::singles(keccak.output_limbs)); - res -} - -pub fn ctl_data_keccak_memory() -> Vec> { +pub fn ctl_data_keccak_sponge() -> Vec> { // When executing KECCAK_GENERAL, the GP memory channels are used as follows: // GP channel 0: stack[-1] = context // GP channel 1: stack[-2] = segment - // GP channel 2: stack[-3] = virtual + // GP channel 2: stack[-3] = virt + // GP channel 3: stack[-4] = len + // GP channel 4: 
pushed = outputs let context = Column::single(COL_MAP.mem_channels[0].value[0]); let segment = Column::single(COL_MAP.mem_channels[1].value[0]); let virt = Column::single(COL_MAP.mem_channels[2].value[0]); + let len = Column::single(COL_MAP.mem_channels[3].value[0]); let num_channels = F::from_canonical_usize(NUM_CHANNELS); - let clock = Column::linear_combination([(COL_MAP.clock, num_channels)]); + let timestamp = Column::linear_combination([(COL_MAP.clock, num_channels)]); - vec![context, segment, virt, clock] + let mut cols = vec![context, segment, virt, len, timestamp]; + cols.extend(COL_MAP.mem_channels[3].value.map(Column::single)); + cols } -pub fn ctl_filter_keccak() -> Column { - Column::single(COL_MAP.is_keccak) -} - -pub fn ctl_filter_keccak_memory() -> Column { - Column::single(COL_MAP.is_keccak_memory) +pub fn ctl_filter_keccak_sponge() -> Column { + Column::single(COL_MAP.is_keccak_sponge) } pub fn ctl_data_logic() -> Vec> { @@ -122,11 +116,11 @@ pub struct CpuStark { } impl CpuStark { + // TODO: Remove? pub fn generate(&self, local_values: &mut [F; NUM_CPU_COLUMNS]) { let local_values: &mut CpuColumnsView<_> = local_values.borrow_mut(); decode::generate(local_values); membus::generate(local_values); - simple_logic::generate(local_values); stack_bounds::generate(local_values); // Must come after `decode`. } } @@ -144,17 +138,19 @@ impl, const D: usize> Stark for CpuStark, const D: usize> Stark for CpuStark, prover_inputs: HashMap, ) -> Self { - let code_hash = hash_kernel(&Self::padded_code_helper(&code)); - + let code_hash_bytes = keccak(&code).0; + let code_hash = std::array::from_fn(|i| { + u32::from_le_bytes(std::array::from_fn(|j| code_hash_bytes[i * 4 + j])) + }); Self { code, code_hash, @@ -51,16 +51,16 @@ impl Kernel { } } - /// Zero-pads the code such that its length is a multiple of the Keccak rate. 
- pub(crate) fn padded_code(&self) -> Vec { - Self::padded_code_helper(&self.code) + /// Get a string representation of the current offset for debugging purposes. + pub(crate) fn offset_name(&self, offset: usize) -> String { + self.offset_label(offset) + .unwrap_or_else(|| offset.to_string()) } - fn padded_code_helper(code: &[u8]) -> Vec { - let padded_len = ceil_div_usize(code.len(), KECCAK_RATE_BYTES) * KECCAK_RATE_BYTES; - let mut padded_code = code.to_vec(); - padded_code.resize(padded_len, 0); - padded_code + pub(crate) fn offset_label(&self, offset: usize) -> Option { + self.global_labels + .iter() + .find_map(|(k, v)| (*v == offset).then(|| k.clone())) } } diff --git a/evm/src/cpu/kernel/interpreter.rs b/evm/src/cpu/kernel/interpreter.rs index 93ec1902..3871db84 100644 --- a/evm/src/cpu/kernel/interpreter.rs +++ b/evm/src/cpu/kernel/interpreter.rs @@ -11,43 +11,21 @@ use crate::cpu::kernel::aggregator::KERNEL; use crate::cpu::kernel::constants::context_metadata::ContextMetadata; use crate::cpu::kernel::constants::global_metadata::GlobalMetadata; use crate::cpu::kernel::constants::txn_fields::NormalizedTxnField; -use crate::generation::memory::{MemoryContextState, MemorySegmentState}; use crate::generation::prover_input::ProverInputFn; use crate::generation::state::GenerationState; use crate::generation::GenerationInputs; use crate::memory::segments::Segment; +use crate::witness::memory::{MemoryContextState, MemorySegmentState, MemoryState}; +use crate::witness::util::stack_peek; type F = GoldilocksField; /// Halt interpreter execution whenever a jump to this offset is done. 
const DEFAULT_HALT_OFFSET: usize = 0xdeadbeef; -#[derive(Clone, Debug)] -pub(crate) struct InterpreterMemory { - pub(crate) context_memory: Vec, -} - -impl Default for InterpreterMemory { - fn default() -> Self { - Self { - context_memory: vec![MemoryContextState::default()], - } - } -} - -impl InterpreterMemory { - fn with_code_and_stack(code: &[u8], stack: Vec) -> Self { - let mut mem = Self::default(); - for (i, b) in code.iter().copied().enumerate() { - mem.context_memory[0].segments[Segment::Code as usize].set(i, b.into()); - } - mem.context_memory[0].segments[Segment::Stack as usize].content = stack; - - mem - } - +impl MemoryState { fn mload_general(&self, context: usize, segment: Segment, offset: usize) -> U256 { - let value = self.context_memory[context].segments[segment as usize].get(offset); + let value = self.contexts[context].segments[segment as usize].get(offset); assert!( value.bits() <= segment.bit_range(), "Value read from memory exceeds expected range of {:?} segment", @@ -62,16 +40,14 @@ impl InterpreterMemory { "Value written to memory exceeds expected range of {:?} segment", segment ); - self.context_memory[context].segments[segment as usize].set(offset, value) + self.contexts[context].segments[segment as usize].set(offset, value) } } pub struct Interpreter<'a> { kernel_mode: bool, jumpdests: Vec, - pub(crate) offset: usize, pub(crate) context: usize, - pub(crate) memory: InterpreterMemory, pub(crate) generation_state: GenerationState, prover_inputs_map: &'a HashMap, pub(crate) halt_offsets: Vec, @@ -119,19 +95,21 @@ impl<'a> Interpreter<'a> { initial_stack: Vec, prover_inputs: &'a HashMap, ) -> Self { - Self { + let mut result = Self { kernel_mode: true, jumpdests: find_jumpdests(code), - offset: initial_offset, - memory: InterpreterMemory::with_code_and_stack(code, initial_stack), - generation_state: GenerationState::new(GenerationInputs::default()), + generation_state: GenerationState::new(GenerationInputs::default(), code), 
prover_inputs_map: prover_inputs, context: 0, halt_offsets: vec![DEFAULT_HALT_OFFSET], debug_offsets: vec![], running: false, opcode_count: [0; 0x100], - } + }; + result.generation_state.registers.program_counter = initial_offset; + result.generation_state.registers.stack_len = initial_stack.len(); + *result.stack_mut() = initial_stack; + result } pub(crate) fn run(&mut self) -> anyhow::Result<()> { @@ -152,48 +130,51 @@ impl<'a> Interpreter<'a> { } fn code(&self) -> &MemorySegmentState { - &self.memory.context_memory[self.context].segments[Segment::Code as usize] + &self.generation_state.memory.contexts[self.context].segments[Segment::Code as usize] } fn code_slice(&self, n: usize) -> Vec { - self.code().content[self.offset..self.offset + n] + let pc = self.generation_state.registers.program_counter; + self.code().content[pc..pc + n] .iter() .map(|u256| u256.byte(0)) .collect::>() } pub(crate) fn get_txn_field(&self, field: NormalizedTxnField) -> U256 { - self.memory.context_memory[0].segments[Segment::TxnFields as usize].get(field as usize) + self.generation_state.memory.contexts[0].segments[Segment::TxnFields as usize] + .get(field as usize) } pub(crate) fn set_txn_field(&mut self, field: NormalizedTxnField, value: U256) { - self.memory.context_memory[0].segments[Segment::TxnFields as usize] + self.generation_state.memory.contexts[0].segments[Segment::TxnFields as usize] .set(field as usize, value); } pub(crate) fn get_txn_data(&self) -> &[U256] { - &self.memory.context_memory[0].segments[Segment::TxnData as usize].content + &self.generation_state.memory.contexts[0].segments[Segment::TxnData as usize].content } pub(crate) fn get_global_metadata_field(&self, field: GlobalMetadata) -> U256 { - self.memory.context_memory[0].segments[Segment::GlobalMetadata as usize].get(field as usize) + self.generation_state.memory.contexts[0].segments[Segment::GlobalMetadata as usize] + .get(field as usize) } pub(crate) fn set_global_metadata_field(&mut self, field: 
GlobalMetadata, value: U256) { - self.memory.context_memory[0].segments[Segment::GlobalMetadata as usize] + self.generation_state.memory.contexts[0].segments[Segment::GlobalMetadata as usize] .set(field as usize, value) } pub(crate) fn get_trie_data(&self) -> &[U256] { - &self.memory.context_memory[0].segments[Segment::TrieData as usize].content + &self.generation_state.memory.contexts[0].segments[Segment::TrieData as usize].content } pub(crate) fn get_trie_data_mut(&mut self) -> &mut Vec { - &mut self.memory.context_memory[0].segments[Segment::TrieData as usize].content + &mut self.generation_state.memory.contexts[0].segments[Segment::TrieData as usize].content } pub(crate) fn get_rlp_memory(&self) -> Vec { - self.memory.context_memory[0].segments[Segment::RlpRaw as usize] + self.generation_state.memory.contexts[0].segments[Segment::RlpRaw as usize] .content .iter() .map(|x| x.as_u32() as u8) @@ -201,23 +182,24 @@ impl<'a> Interpreter<'a> { } pub(crate) fn set_rlp_memory(&mut self, rlp: Vec) { - self.memory.context_memory[0].segments[Segment::RlpRaw as usize].content = + self.generation_state.memory.contexts[0].segments[Segment::RlpRaw as usize].content = rlp.into_iter().map(U256::from).collect(); } pub(crate) fn set_code(&mut self, context: usize, code: Vec) { assert_ne!(context, 0, "Can't modify kernel code."); - while self.memory.context_memory.len() <= context { - self.memory - .context_memory + while self.generation_state.memory.contexts.len() <= context { + self.generation_state + .memory + .contexts .push(MemoryContextState::default()); } - self.memory.context_memory[context].segments[Segment::Code as usize].content = + self.generation_state.memory.contexts[context].segments[Segment::Code as usize].content = code.into_iter().map(U256::from).collect(); } pub(crate) fn get_jumpdest_bits(&self, context: usize) -> Vec { - self.memory.context_memory[context].segments[Segment::JumpdestBits as usize] + 
self.generation_state.memory.contexts[context].segments[Segment::JumpdestBits as usize] .content .iter() .map(|x| x.bit(0)) @@ -225,19 +207,22 @@ impl<'a> Interpreter<'a> { } fn incr(&mut self, n: usize) { - self.offset += n; + self.generation_state.registers.program_counter += n; } pub(crate) fn stack(&self) -> &[U256] { - &self.memory.context_memory[self.context].segments[Segment::Stack as usize].content + &self.generation_state.memory.contexts[self.context].segments[Segment::Stack as usize] + .content } fn stack_mut(&mut self) -> &mut Vec { - &mut self.memory.context_memory[self.context].segments[Segment::Stack as usize].content + &mut self.generation_state.memory.contexts[self.context].segments[Segment::Stack as usize] + .content } pub(crate) fn push(&mut self, x: U256) { self.stack_mut().push(x); + self.generation_state.registers.stack_len += 1; } fn push_bool(&mut self, x: bool) { @@ -245,11 +230,18 @@ impl<'a> Interpreter<'a> { } pub(crate) fn pop(&mut self) -> U256 { - self.stack_mut().pop().expect("Pop on empty stack.") + let result = stack_peek(&self.generation_state, 0); + self.generation_state.registers.stack_len -= 1; + let new_len = self.stack_len(); + self.stack_mut().truncate(new_len); + result.expect("Empty stack") } fn run_opcode(&mut self) -> anyhow::Result<()> { - let opcode = self.code().get(self.offset).byte(0); + let opcode = self + .code() + .get(self.generation_state.registers.program_counter) + .byte(0); self.opcode_count[opcode as usize] += 1; self.incr(1); match opcode { @@ -321,10 +313,6 @@ impl<'a> Interpreter<'a> { 0x59 => self.run_msize(), // "MSIZE", 0x5a => todo!(), // "GAS", 0x5b => self.run_jumpdest(), // "JUMPDEST", - 0x5c => todo!(), // "GET_STATE_ROOT", - 0x5d => todo!(), // "SET_STATE_ROOT", - 0x5e => todo!(), // "GET_RECEIPT_ROOT", - 0x5f => todo!(), // "SET_RECEIPT_ROOT", x if (0x60..0x80).contains(&x) => self.run_push(x - 0x5f), // "PUSH" x if (0x80..0x90).contains(&x) => self.run_dup(x - 0x7f), // "DUP" x if 
(0x90..0xa0).contains(&x) => self.run_swap(x - 0x8f)?, // "SWAP" @@ -353,7 +341,10 @@ impl<'a> Interpreter<'a> { _ => bail!("Unrecognized opcode {}.", opcode), }; - if self.debug_offsets.contains(&self.offset) { + if self + .debug_offsets + .contains(&self.generation_state.registers.program_counter) + { println!("At {}, stack={:?}", self.offset_name(), self.stack()); } else if let Some(label) = self.offset_label() { println!("At {label}"); @@ -362,18 +353,12 @@ impl<'a> Interpreter<'a> { Ok(()) } - /// Get a string representation of the current offset for debugging purposes. fn offset_name(&self) -> String { - self.offset_label() - .unwrap_or_else(|| self.offset.to_string()) + KERNEL.offset_name(self.generation_state.registers.program_counter) } fn offset_label(&self) -> Option { - // TODO: Not sure we should use KERNEL? Interpreter is more general in other places. - KERNEL - .global_labels - .iter() - .find_map(|(k, v)| (*v == self.offset).then(|| k.clone())) + KERNEL.offset_label(self.generation_state.registers.program_counter) } fn run_stop(&mut self) { @@ -508,12 +493,10 @@ impl<'a> Interpreter<'a> { fn run_byte(&mut self) { let i = self.pop(); let x = self.pop(); - let result = if i > 32.into() { - 0 + let result = if i < 32.into() { + x.byte(31 - i.as_usize()) } else { - let mut bytes = [0; 32]; - x.to_big_endian(&mut bytes); - bytes[i.as_usize()] + 0 }; self.push(result.into()); } @@ -535,7 +518,8 @@ impl<'a> Interpreter<'a> { let size = self.pop().as_usize(); let bytes = (offset..offset + size) .map(|i| { - self.memory + self.generation_state + .memory .mload_general(self.context, Segment::MainMemory, i) .byte(0) }) @@ -552,7 +536,12 @@ impl<'a> Interpreter<'a> { let offset = self.pop().as_usize(); let size = self.pop().as_usize(); let bytes = (offset..offset + size) - .map(|i| self.memory.mload_general(context, segment, i).byte(0)) + .map(|i| { + self.generation_state + .memory + .mload_general(context, segment, i) + .byte(0) + }) .collect::>(); 
println!("Hashing {:?}", &bytes); let hash = keccak(bytes); @@ -561,7 +550,8 @@ impl<'a> Interpreter<'a> { fn run_callvalue(&mut self) { self.push( - self.memory.context_memory[self.context].segments[Segment::ContextMetadata as usize] + self.generation_state.memory.contexts[self.context].segments + [Segment::ContextMetadata as usize] .get(ContextMetadata::CallValue as usize), ) } @@ -571,7 +561,8 @@ impl<'a> Interpreter<'a> { let value = U256::from_big_endian( &(0..32) .map(|i| { - self.memory + self.generation_state + .memory .mload_general(self.context, Segment::Calldata, offset + i) .byte(0) }) @@ -582,7 +573,8 @@ impl<'a> Interpreter<'a> { fn run_calldatasize(&mut self) { self.push( - self.memory.context_memory[self.context].segments[Segment::ContextMetadata as usize] + self.generation_state.memory.contexts[self.context].segments + [Segment::ContextMetadata as usize] .get(ContextMetadata::CalldataSize as usize), ) } @@ -592,10 +584,12 @@ impl<'a> Interpreter<'a> { let offset = self.pop().as_usize(); let size = self.pop().as_usize(); for i in 0..size { - let calldata_byte = - self.memory - .mload_general(self.context, Segment::Calldata, offset + i); - self.memory.mstore_general( + let calldata_byte = self.generation_state.memory.mload_general( + self.context, + Segment::Calldata, + offset + i, + ); + self.generation_state.memory.mstore_general( self.context, Segment::MainMemory, dest_offset + i, @@ -607,10 +601,9 @@ impl<'a> Interpreter<'a> { fn run_prover_input(&mut self) -> anyhow::Result<()> { let prover_input_fn = self .prover_inputs_map - .get(&(self.offset - 1)) + .get(&(self.generation_state.registers.program_counter - 1)) .ok_or_else(|| anyhow!("Offset not in prover inputs."))?; - let stack = self.stack().to_vec(); - let output = self.generation_state.prover_input(&stack, prover_input_fn); + let output = self.generation_state.prover_input(prover_input_fn); self.push(output); Ok(()) } @@ -624,7 +617,8 @@ impl<'a> Interpreter<'a> { let value = 
U256::from_big_endian( &(0..32) .map(|i| { - self.memory + self.generation_state + .memory .mload_general(self.context, Segment::MainMemory, offset + i) .byte(0) }) @@ -639,15 +633,19 @@ impl<'a> Interpreter<'a> { let mut bytes = [0; 32]; value.to_big_endian(&mut bytes); for (i, byte) in (0..32).zip(bytes) { - self.memory - .mstore_general(self.context, Segment::MainMemory, offset + i, byte.into()); + self.generation_state.memory.mstore_general( + self.context, + Segment::MainMemory, + offset + i, + byte.into(), + ); } } fn run_mstore8(&mut self) { let offset = self.pop().as_usize(); let value = self.pop(); - self.memory.mstore_general( + self.generation_state.memory.mstore_general( self.context, Segment::MainMemory, offset, @@ -669,12 +667,13 @@ impl<'a> Interpreter<'a> { } fn run_pc(&mut self) { - self.push((self.offset - 1).into()); + self.push((self.generation_state.registers.program_counter - 1).into()); } fn run_msize(&mut self) { self.push( - self.memory.context_memory[self.context].segments[Segment::ContextMetadata as usize] + self.generation_state.memory.contexts[self.context].segments + [Segment::ContextMetadata as usize] .get(ContextMetadata::MSize as usize), ) } @@ -689,7 +688,7 @@ impl<'a> Interpreter<'a> { panic!("Destination is not a JUMPDEST."); } - self.offset = offset; + self.generation_state.registers.program_counter = offset; if self.halt_offsets.contains(&offset) { self.running = false; @@ -703,11 +702,11 @@ impl<'a> Interpreter<'a> { } fn run_dup(&mut self, n: u8) { - self.push(self.stack()[self.stack().len() - n as usize]); + self.push(self.stack()[self.stack_len() - n as usize]); } fn run_swap(&mut self, n: u8) -> anyhow::Result<()> { - let len = self.stack().len(); + let len = self.stack_len(); ensure!(len > n as usize); self.stack_mut().swap(len - 1, len - n as usize - 1); Ok(()) @@ -726,7 +725,10 @@ impl<'a> Interpreter<'a> { let context = self.pop().as_usize(); let segment = Segment::all()[self.pop().as_usize()]; let offset = 
self.pop().as_usize(); - let value = self.memory.mload_general(context, segment, offset); + let value = self + .generation_state + .memory + .mload_general(context, segment, offset); assert!(value.bits() <= segment.bit_range()); self.push(value); } @@ -743,7 +745,13 @@ impl<'a> Interpreter<'a> { segment, segment.bit_range() ); - self.memory.mstore_general(context, segment, offset, value); + self.generation_state + .memory + .mstore_general(context, segment, offset, value); + } + + fn stack_len(&self) -> usize { + self.generation_state.registers.stack_len } } @@ -833,10 +841,6 @@ fn get_mnemonic(opcode: u8) -> &'static str { 0x59 => "MSIZE", 0x5a => "GAS", 0x5b => "JUMPDEST", - 0x5c => "GET_STATE_ROOT", - 0x5d => "SET_STATE_ROOT", - 0x5e => "GET_RECEIPT_ROOT", - 0x5f => "SET_RECEIPT_ROOT", 0x60 => "PUSH1", 0x61 => "PUSH2", 0x62 => "PUSH3", @@ -969,11 +973,13 @@ mod tests { let run = run(&code, 0, vec![], &pis)?; assert_eq!(run.stack(), &[0xff.into(), 0xff00.into()]); assert_eq!( - run.memory.context_memory[0].segments[Segment::MainMemory as usize].get(0x27), + run.generation_state.memory.contexts[0].segments[Segment::MainMemory as usize] + .get(0x27), 0x42.into() ); assert_eq!( - run.memory.context_memory[0].segments[Segment::MainMemory as usize].get(0x1f), + run.generation_state.memory.contexts[0].segments[Segment::MainMemory as usize] + .get(0x1f), 0xff.into() ); Ok(()) diff --git a/evm/src/cpu/kernel/keccak_util.rs b/evm/src/cpu/kernel/keccak_util.rs index 01d38cc4..38361389 100644 --- a/evm/src/cpu/kernel/keccak_util.rs +++ b/evm/src/cpu/kernel/keccak_util.rs @@ -1,31 +1,9 @@ use tiny_keccak::keccakf; -use crate::keccak_sponge::columns::{KECCAK_RATE_BYTES, KECCAK_RATE_U32S}; - -/// A Keccak-f based hash. -/// -/// This hash does not use standard Keccak padding, since we don't care about extra zeros at the -/// end of the code. It also uses an overwrite-mode sponge, rather than a standard sponge where -/// inputs are xor'ed in. 
-pub(crate) fn hash_kernel(code: &[u8]) -> [u32; 8] { - debug_assert_eq!( - code.len() % KECCAK_RATE_BYTES, - 0, - "Code should have been padded to a multiple of the Keccak rate." - ); - - let mut state = [0u32; 50]; - for chunk in code.chunks(KECCAK_RATE_BYTES) { - for i in 0..KECCAK_RATE_U32S { - state[i] = u32::from_le_bytes(std::array::from_fn(|j| chunk[i * 4 + j])); - } - keccakf_u32s(&mut state); - } - state[..8].try_into().unwrap() -} +use crate::keccak_sponge::columns::{KECCAK_WIDTH_BYTES, KECCAK_WIDTH_U32S}; /// Like tiny-keccak's `keccakf`, but deals with `u32` limbs instead of `u64` limbs. -pub(crate) fn keccakf_u32s(state_u32s: &mut [u32; 50]) { +pub(crate) fn keccakf_u32s(state_u32s: &mut [u32; KECCAK_WIDTH_U32S]) { let mut state_u64s: [u64; 25] = std::array::from_fn(|i| { let lo = state_u32s[i * 2] as u64; let hi = state_u32s[i * 2 + 1] as u64; @@ -39,6 +17,17 @@ pub(crate) fn keccakf_u32s(state_u32s: &mut [u32; 50]) { }); } +/// Like tiny-keccak's `keccakf`, but deals with bytes instead of `u64` limbs. 
+pub(crate) fn keccakf_u8s(state_u8s: &mut [u8; KECCAK_WIDTH_BYTES]) { + let mut state_u64s: [u64; 25] = + std::array::from_fn(|i| u64::from_le_bytes(state_u8s[i * 8..][..8].try_into().unwrap())); + keccakf(&mut state_u64s); + *state_u8s = std::array::from_fn(|i| { + let u64_limb = state_u64s[i / 8]; + u64_limb.to_le_bytes()[i % 8] + }); +} + #[cfg(test)] mod tests { use tiny_keccak::keccakf; diff --git a/evm/src/cpu/kernel/opcodes.rs b/evm/src/cpu/kernel/opcodes.rs index 31074ff6..8b575f79 100644 --- a/evm/src/cpu/kernel/opcodes.rs +++ b/evm/src/cpu/kernel/opcodes.rs @@ -76,10 +76,6 @@ pub(crate) fn get_opcode(mnemonic: &str) -> u8 { "MSIZE" => 0x59, "GAS" => 0x5a, "JUMPDEST" => 0x5b, - "GET_STATE_ROOT" => 0x5c, - "SET_STATE_ROOT" => 0x5d, - "GET_RECEIPT_ROOT" => 0x5e, - "SET_RECEIPT_ROOT" => 0x5f, "DUP1" => 0x80, "DUP2" => 0x81, "DUP3" => 0x82, diff --git a/evm/src/cpu/kernel/tests/account_code.rs b/evm/src/cpu/kernel/tests/account_code.rs index 7e5f88be..c6d7f156 100644 --- a/evm/src/cpu/kernel/tests/account_code.rs +++ b/evm/src/cpu/kernel/tests/account_code.rs @@ -42,7 +42,7 @@ fn prepare_interpreter( let mut state_trie: PartialTrie = Default::default(); let trie_inputs = Default::default(); - interpreter.offset = load_all_mpts; + interpreter.generation_state.registers.program_counter = load_all_mpts; interpreter.push(0xDEADBEEFu32.into()); interpreter.generation_state.mpt_prover_inputs = all_mpt_prover_inputs_reversed(&trie_inputs); @@ -53,7 +53,7 @@ fn prepare_interpreter( keccak(address.to_fixed_bytes()).as_bytes(), )); // Next, execute mpt_insert_state_trie. - interpreter.offset = mpt_insert_state_trie; + interpreter.generation_state.registers.program_counter = mpt_insert_state_trie; let trie_data = interpreter.get_trie_data_mut(); if trie_data.is_empty() { // In the assembly we skip over 0, knowing trie_data[0] = 0 by default. @@ -83,7 +83,7 @@ fn prepare_interpreter( ); // Now, execute mpt_hash_state_trie. 
- interpreter.offset = mpt_hash_state_trie; + interpreter.generation_state.registers.program_counter = mpt_hash_state_trie; interpreter.push(0xDEADBEEFu32.into()); interpreter.run()?; @@ -115,7 +115,7 @@ fn test_extcodesize() -> Result<()> { let extcodesize = KERNEL.global_labels["extcodesize"]; // Test `extcodesize` - interpreter.offset = extcodesize; + interpreter.generation_state.registers.program_counter = extcodesize; interpreter.pop(); assert!(interpreter.stack().is_empty()); interpreter.push(0xDEADBEEFu32.into()); @@ -144,10 +144,10 @@ fn test_extcodecopy() -> Result<()> { // Put random data in main memory and the `KernelAccountCode` segment for realism. let mut rng = thread_rng(); for i in 0..2000 { - interpreter.memory.context_memory[interpreter.context].segments + interpreter.generation_state.memory.contexts[interpreter.context].segments [Segment::MainMemory as usize] .set(i, U256::from(rng.gen::())); - interpreter.memory.context_memory[interpreter.context].segments + interpreter.generation_state.memory.contexts[interpreter.context].segments [Segment::KernelAccountCode as usize] .set(i, U256::from(rng.gen::())); } @@ -158,7 +158,7 @@ fn test_extcodecopy() -> Result<()> { let size = rng.gen_range(0..1500); // Test `extcodecopy` - interpreter.offset = extcodecopy; + interpreter.generation_state.registers.program_counter = extcodecopy; interpreter.pop(); assert!(interpreter.stack().is_empty()); interpreter.push(0xDEADBEEFu32.into()); @@ -173,7 +173,7 @@ fn test_extcodecopy() -> Result<()> { assert!(interpreter.stack().is_empty()); // Check that the code was correctly copied to memory. 
for i in 0..size { - let memory = interpreter.memory.context_memory[interpreter.context].segments + let memory = interpreter.generation_state.memory.contexts[interpreter.context].segments [Segment::MainMemory as usize] .get(dest_offset + i); assert_eq!( diff --git a/evm/src/cpu/kernel/tests/balance.rs b/evm/src/cpu/kernel/tests/balance.rs index 1e784e85..b0e087a9 100644 --- a/evm/src/cpu/kernel/tests/balance.rs +++ b/evm/src/cpu/kernel/tests/balance.rs @@ -33,7 +33,7 @@ fn prepare_interpreter( let mut state_trie: PartialTrie = Default::default(); let trie_inputs = Default::default(); - interpreter.offset = load_all_mpts; + interpreter.generation_state.registers.program_counter = load_all_mpts; interpreter.push(0xDEADBEEFu32.into()); interpreter.generation_state.mpt_prover_inputs = all_mpt_prover_inputs_reversed(&trie_inputs); @@ -44,7 +44,7 @@ fn prepare_interpreter( keccak(address.to_fixed_bytes()).as_bytes(), )); // Next, execute mpt_insert_state_trie. - interpreter.offset = mpt_insert_state_trie; + interpreter.generation_state.registers.program_counter = mpt_insert_state_trie; let trie_data = interpreter.get_trie_data_mut(); if trie_data.is_empty() { // In the assembly we skip over 0, knowing trie_data[0] = 0 by default. @@ -74,7 +74,7 @@ fn prepare_interpreter( ); // Now, execute mpt_hash_state_trie. 
- interpreter.offset = mpt_hash_state_trie; + interpreter.generation_state.registers.program_counter = mpt_hash_state_trie; interpreter.push(0xDEADBEEFu32.into()); interpreter.run()?; @@ -105,7 +105,7 @@ fn test_balance() -> Result<()> { prepare_interpreter(&mut interpreter, address, &account)?; // Test `balance` - interpreter.offset = KERNEL.global_labels["balance"]; + interpreter.generation_state.registers.program_counter = KERNEL.global_labels["balance"]; interpreter.pop(); assert!(interpreter.stack().is_empty()); interpreter.push(0xDEADBEEFu32.into()); diff --git a/evm/src/cpu/kernel/tests/mpt/hash.rs b/evm/src/cpu/kernel/tests/mpt/hash.rs index 6321fb4b..6c6c6f63 100644 --- a/evm/src/cpu/kernel/tests/mpt/hash.rs +++ b/evm/src/cpu/kernel/tests/mpt/hash.rs @@ -113,7 +113,7 @@ fn test_state_trie(trie_inputs: TrieInputs) -> Result<()> { assert_eq!(interpreter.stack(), vec![]); // Now, execute mpt_hash_state_trie. - interpreter.offset = mpt_hash_state_trie; + interpreter.generation_state.registers.program_counter = mpt_hash_state_trie; interpreter.push(0xDEADBEEFu32.into()); interpreter.run()?; diff --git a/evm/src/cpu/kernel/tests/mpt/insert.rs b/evm/src/cpu/kernel/tests/mpt/insert.rs index 6e1ad573..cf546969 100644 --- a/evm/src/cpu/kernel/tests/mpt/insert.rs +++ b/evm/src/cpu/kernel/tests/mpt/insert.rs @@ -164,7 +164,7 @@ fn test_state_trie(mut state_trie: PartialTrie, k: Nibbles, mut account: Account assert_eq!(interpreter.stack(), vec![]); // Next, execute mpt_insert_state_trie. - interpreter.offset = mpt_insert_state_trie; + interpreter.generation_state.registers.program_counter = mpt_insert_state_trie; let trie_data = interpreter.get_trie_data_mut(); if trie_data.is_empty() { // In the assembly we skip over 0, knowing trie_data[0] = 0 by default. @@ -194,7 +194,7 @@ fn test_state_trie(mut state_trie: PartialTrie, k: Nibbles, mut account: Account ); // Now, execute mpt_hash_state_trie. 
- interpreter.offset = mpt_hash_state_trie; + interpreter.generation_state.registers.program_counter = mpt_hash_state_trie; interpreter.push(0xDEADBEEFu32.into()); interpreter.run()?; diff --git a/evm/src/cpu/kernel/tests/mpt/read.rs b/evm/src/cpu/kernel/tests/mpt/read.rs index d8808e24..62313f62 100644 --- a/evm/src/cpu/kernel/tests/mpt/read.rs +++ b/evm/src/cpu/kernel/tests/mpt/read.rs @@ -27,7 +27,7 @@ fn mpt_read() -> Result<()> { assert_eq!(interpreter.stack(), vec![]); // Now, execute mpt_read on the state trie. - interpreter.offset = mpt_read; + interpreter.generation_state.registers.program_counter = mpt_read; interpreter.push(0xdeadbeefu32.into()); interpreter.push(0xABCDEFu64.into()); interpreter.push(6.into()); diff --git a/evm/src/cpu/membus.rs b/evm/src/cpu/membus.rs index 1ec7b3e3..08cae757 100644 --- a/evm/src/cpu/membus.rs +++ b/evm/src/cpu/membus.rs @@ -8,7 +8,7 @@ use crate::constraint_consumer::{ConstraintConsumer, RecursiveConstraintConsumer use crate::cpu::columns::CpuColumnsView; /// General-purpose memory channels; they can read and write to all contexts/segments/addresses. -pub const NUM_GP_CHANNELS: usize = 4; +pub const NUM_GP_CHANNELS: usize = 5; pub mod channel_indices { use std::ops::Range; diff --git a/evm/src/cpu/mod.rs b/evm/src/cpu/mod.rs index ece07c1c..3a2df351 100644 --- a/evm/src/cpu/mod.rs +++ b/evm/src/cpu/mod.rs @@ -1,6 +1,6 @@ pub(crate) mod bootstrap_kernel; pub(crate) mod columns; -mod control_flow; +pub(crate) mod control_flow; pub mod cpu_stark; pub(crate) mod decode; mod dup_swap; @@ -9,7 +9,7 @@ pub mod kernel; pub(crate) mod membus; mod modfp254; mod shift; -mod simple_logic; +pub(crate) mod simple_logic; mod stack; -mod stack_bounds; +pub(crate) mod stack_bounds; mod syscalls; diff --git a/evm/src/cpu/shift.rs b/evm/src/cpu/shift.rs index d383b6b2..bbabf173 100644 --- a/evm/src/cpu/shift.rs +++ b/evm/src/cpu/shift.rs @@ -14,7 +14,7 @@ pub(crate) fn eval_packed( yield_constr: &mut ConstraintConsumer

, ) { let is_shift = lv.op.shl + lv.op.shr; - let displacement = lv.mem_channels[1]; // holds the shift displacement d + let displacement = lv.mem_channels[0]; // holds the shift displacement d let two_exp = lv.mem_channels[2]; // holds 2^d // Not needed here; val is the input and we're verifying that output is @@ -65,7 +65,7 @@ pub(crate) fn eval_ext_circuit, const D: usize>( yield_constr: &mut RecursiveConstraintConsumer, ) { let is_shift = builder.add_extension(lv.op.shl, lv.op.shr); - let displacement = lv.mem_channels[1]; + let displacement = lv.mem_channels[0]; let two_exp = lv.mem_channels[2]; let shift_table_segment = F::from_canonical_u64(Segment::ShiftTable as u64); diff --git a/evm/src/cpu/simple_logic/eq_iszero.rs b/evm/src/cpu/simple_logic/eq_iszero.rs index 37e06248..8a084c14 100644 --- a/evm/src/cpu/simple_logic/eq_iszero.rs +++ b/evm/src/cpu/simple_logic/eq_iszero.rs @@ -1,34 +1,29 @@ +use ethereum_types::U256; use itertools::izip; use plonky2::field::extension::Extendable; use plonky2::field::packed::PackedField; +use plonky2::field::types::Field; use plonky2::hash::hash_types::RichField; use plonky2::iop::ext_target::ExtensionTarget; use crate::constraint_consumer::{ConstraintConsumer, RecursiveConstraintConsumer}; use crate::cpu::columns::CpuColumnsView; -pub fn generate(lv: &mut CpuColumnsView) { - let input0 = lv.mem_channels[0].value; - - let eq_filter = lv.op.eq.to_canonical_u64(); - let iszero_filter = lv.op.iszero.to_canonical_u64(); - assert!(eq_filter <= 1); - assert!(iszero_filter <= 1); - assert!(eq_filter + iszero_filter <= 1); - - if eq_filter + iszero_filter == 0 { - return; +fn limbs(x: U256) -> [u32; 8] { + let mut res = [0; 8]; + let x_u64: [u64; 4] = x.0; + for i in 0..4 { + res[2 * i] = x_u64[i] as u32; + res[2 * i + 1] = (x_u64[i] >> 32) as u32; } + res +} - let input1 = &mut lv.mem_channels[1].value; - if iszero_filter != 0 { - for limb in input1.iter_mut() { - *limb = F::ZERO; - } - } +pub fn generate_pinv_diff(val0: U256, 
val1: U256, lv: &mut CpuColumnsView) { + let val0_limbs = limbs(val0).map(F::from_canonical_u32); + let val1_limbs = limbs(val1).map(F::from_canonical_u32); - let input1 = lv.mem_channels[1].value; - let num_unequal_limbs = izip!(input0, input1) + let num_unequal_limbs = izip!(val0_limbs, val1_limbs) .map(|(limb0, limb1)| (limb0 != limb1) as usize) .sum(); let equal = num_unequal_limbs == 0; @@ -40,7 +35,7 @@ pub fn generate(lv: &mut CpuColumnsView) { } // Form `diff_pinv`. - // Let `diff = input0 - input1`. Consider `x[i] = diff[i]^-1` if `diff[i] != 0` and 0 otherwise. + // Let `diff = val0 - val1`. Consider `x[i] = diff[i]^-1` if `diff[i] != 0` and 0 otherwise. // Then `diff @ x = num_unequal_limbs`, where `@` denotes the dot product. We set // `diff_pinv = num_unequal_limbs^-1 * x` if `num_unequal_limbs != 0` and 0 otherwise. We have // `diff @ diff_pinv = 1 - equal` as desired. @@ -48,7 +43,7 @@ pub fn generate(lv: &mut CpuColumnsView) { let num_unequal_limbs_inv = F::from_canonical_usize(num_unequal_limbs) .try_inverse() .unwrap_or(F::ZERO); - for (limb_pinv, limb0, limb1) in izip!(logic.diff_pinv.iter_mut(), input0, input1) { + for (limb_pinv, limb0, limb1) in izip!(logic.diff_pinv.iter_mut(), val0_limbs, val1_limbs) { *limb_pinv = (limb0 - limb1).try_inverse().unwrap_or(F::ZERO) * num_unequal_limbs_inv; } } diff --git a/evm/src/cpu/simple_logic/mod.rs b/evm/src/cpu/simple_logic/mod.rs index 963b11b2..03d2dd15 100644 --- a/evm/src/cpu/simple_logic/mod.rs +++ b/evm/src/cpu/simple_logic/mod.rs @@ -1,4 +1,4 @@ -mod eq_iszero; +pub(crate) mod eq_iszero; mod not; use plonky2::field::extension::Extendable; @@ -9,17 +9,6 @@ use plonky2::iop::ext_target::ExtensionTarget; use crate::constraint_consumer::{ConstraintConsumer, RecursiveConstraintConsumer}; use crate::cpu::columns::CpuColumnsView; -pub fn generate(lv: &mut CpuColumnsView) { - let cycle_filter = lv.is_cpu_cycle.to_canonical_u64(); - if cycle_filter == 0 { - return; - } - assert_eq!(cycle_filter, 1); - - 
not::generate(lv); - eq_iszero::generate(lv); -} - pub fn eval_packed( lv: &CpuColumnsView

, yield_constr: &mut ConstraintConsumer

, diff --git a/evm/src/cpu/simple_logic/not.rs b/evm/src/cpu/simple_logic/not.rs index 3b8a888f..16572e9c 100644 --- a/evm/src/cpu/simple_logic/not.rs +++ b/evm/src/cpu/simple_logic/not.rs @@ -6,34 +6,18 @@ use plonky2::iop::ext_target::ExtensionTarget; use crate::constraint_consumer::{ConstraintConsumer, RecursiveConstraintConsumer}; use crate::cpu::columns::CpuColumnsView; +use crate::cpu::membus::NUM_GP_CHANNELS; const LIMB_SIZE: usize = 32; const ALL_1_LIMB: u64 = (1 << LIMB_SIZE) - 1; -pub fn generate(lv: &mut CpuColumnsView) { - let is_not_filter = lv.op.not.to_canonical_u64(); - if is_not_filter == 0 { - return; - } - assert_eq!(is_not_filter, 1); - - let input = lv.mem_channels[0].value; - let output = &mut lv.mem_channels[1].value; - for (input, output_ref) in input.into_iter().zip(output.iter_mut()) { - let input = input.to_canonical_u64(); - assert_eq!(input >> LIMB_SIZE, 0); - let output = input ^ ALL_1_LIMB; - *output_ref = F::from_canonical_u64(output); - } -} - pub fn eval_packed( lv: &CpuColumnsView

, yield_constr: &mut ConstraintConsumer

, ) { // This is simple: just do output = 0xffffffff - input. let input = lv.mem_channels[0].value; - let output = lv.mem_channels[1].value; + let output = lv.mem_channels[NUM_GP_CHANNELS - 1].value; let cycle_filter = lv.is_cpu_cycle; let is_not_filter = lv.op.not; let filter = cycle_filter * is_not_filter; @@ -50,7 +34,7 @@ pub fn eval_ext_circuit, const D: usize>( yield_constr: &mut RecursiveConstraintConsumer, ) { let input = lv.mem_channels[0].value; - let output = lv.mem_channels[1].value; + let output = lv.mem_channels[NUM_GP_CHANNELS - 1].value; let cycle_filter = lv.is_cpu_cycle; let is_not_filter = lv.op.not; let filter = builder.mul_extension(cycle_filter, is_not_filter); diff --git a/evm/src/cpu/stack.rs b/evm/src/cpu/stack.rs index ea235578..08ab3044 100644 --- a/evm/src/cpu/stack.rs +++ b/evm/src/cpu/stack.rs @@ -61,19 +61,15 @@ const STACK_BEHAVIORS: OpsColumnsView> = OpsColumnsView { byte: BASIC_BINARY_OP, shl: BASIC_BINARY_OP, shr: BASIC_BINARY_OP, - keccak_general: None, // TODO - prover_input: None, // TODO - pop: None, // TODO - jump: None, // TODO - jumpi: None, // TODO - pc: None, // TODO - gas: None, // TODO - jumpdest: None, // TODO - get_state_root: None, // TODO - set_state_root: None, // TODO - get_receipt_root: None, // TODO - set_receipt_root: None, // TODO - push: None, // TODO + keccak_general: None, // TODO + prover_input: None, // TODO + pop: None, // TODO + jump: None, // TODO + jumpi: None, // TODO + pc: None, // TODO + gas: None, // TODO + jumpdest: None, // TODO + push: None, // TODO dup: None, swap: None, get_context: None, // TODO diff --git a/evm/src/cpu/stack_bounds.rs b/evm/src/cpu/stack_bounds.rs index 99734433..627411ea 100644 --- a/evm/src/cpu/stack_bounds.rs +++ b/evm/src/cpu/stack_bounds.rs @@ -19,7 +19,7 @@ use plonky2::iop::ext_target::ExtensionTarget; use crate::constraint_consumer::{ConstraintConsumer, RecursiveConstraintConsumer}; use crate::cpu::columns::{CpuColumnsView, COL_MAP}; -const MAX_USER_STACK_SIZE: u64 
= 1024; +pub const MAX_USER_STACK_SIZE: usize = 1024; // Below only includes the operations that pop the top of the stack **without reading the value from // memory**, i.e. `POP`. @@ -45,7 +45,7 @@ pub fn generate(lv: &mut CpuColumnsView) { let check_overflow: F = INCREMENTING_FLAGS.map(|i| lv[i]).into_iter().sum(); let no_check = F::ONE - (check_underflow + check_overflow); - let disallowed_len = check_overflow * F::from_canonical_u64(MAX_USER_STACK_SIZE) - no_check; + let disallowed_len = check_overflow * F::from_canonical_usize(MAX_USER_STACK_SIZE) - no_check; let diff = lv.stack_len - disallowed_len; let user_mode = F::ONE - lv.is_kernel_mode; @@ -84,7 +84,7 @@ pub fn eval_packed( // 0 if `check_underflow`, `MAX_USER_STACK_SIZE` if `check_overflow`, and -1 if `no_check`. let disallowed_len = - check_overflow * P::Scalar::from_canonical_u64(MAX_USER_STACK_SIZE) - no_check; + check_overflow * P::Scalar::from_canonical_usize(MAX_USER_STACK_SIZE) - no_check; // This `lhs` must equal some `rhs`. If `rhs` is nonzero, then this shows that `lv.stack_len` is // not `disallowed_len`. let lhs = (lv.stack_len - disallowed_len) * lv.stack_len_bounds_aux; @@ -108,7 +108,7 @@ pub fn eval_ext_circuit, const D: usize>( ) { let one = builder.one_extension(); let max_stack_size = - builder.constant_extension(F::from_canonical_u64(MAX_USER_STACK_SIZE).into()); + builder.constant_extension(F::from_canonical_usize(MAX_USER_STACK_SIZE).into()); // `check_underflow`, `check_overflow`, and `no_check` are mutually exclusive. 
let check_underflow = builder.add_many_extension(DECREMENTING_FLAGS.map(|i| lv[i])); diff --git a/evm/src/cross_table_lookup.rs b/evm/src/cross_table_lookup.rs index a1fd3ce7..4930321a 100644 --- a/evm/src/cross_table_lookup.rs +++ b/evm/src/cross_table_lookup.rs @@ -145,7 +145,7 @@ impl Column { pub struct TableWithColumns { table: Table, columns: Vec>, - filter_column: Option>, + pub(crate) filter_column: Option>, } impl TableWithColumns { @@ -160,8 +160,8 @@ impl TableWithColumns { #[derive(Clone)] pub struct CrossTableLookup { - looking_tables: Vec>, - looked_table: TableWithColumns, + pub(crate) looking_tables: Vec>, + pub(crate) looked_table: TableWithColumns, /// Default value if filters are not used. default: Option>, } @@ -248,6 +248,7 @@ pub fn cross_table_lookup_data, const D default, } in cross_table_lookups { + log::debug!("Processing CTL for {:?}", looked_table.table); for &challenge in &challenges.challenges { let zs_looking = looking_tables.iter().map(|table| { partial_products( @@ -610,16 +611,15 @@ pub(crate) fn verify_cross_table_lookups< .product::(); let looked_z = *ctl_zs_openings[looked_table.table as usize].next().unwrap(); let challenge = challenges.challenges[i % config.num_challenges]; - let combined_default = default - .as_ref() - .map(|default| challenge.combine(default.iter())) - .unwrap_or(F::ONE); - ensure!( - looking_zs_prod - == looked_z * combined_default.exp_u64(looking_degrees_sum - looked_degree), - "Cross-table lookup verification failed." - ); + if let Some(default) = default.as_ref() { + let combined_default = challenge.combine(default.iter()); + ensure!( + looking_zs_prod + == looked_z * combined_default.exp_u64(looking_degrees_sum - looked_degree), + "Cross-table lookup verification failed." 
+ ); + } } } debug_assert!(ctl_zs_openings.iter_mut().all(|iter| iter.next().is_none())); @@ -694,6 +694,7 @@ pub(crate) mod testutils { type MultiSet = HashMap, Vec<(Table, usize)>>; /// Check that the provided traces and cross-table lookups are consistent. + #[allow(unused)] // TODO: used later? pub(crate) fn check_ctls( trace_poly_values: &[Vec>], cross_table_lookups: &[CrossTableLookup], diff --git a/evm/src/generation/memory.rs b/evm/src/generation/memory.rs deleted file mode 100644 index 944b42a6..00000000 --- a/evm/src/generation/memory.rs +++ /dev/null @@ -1,50 +0,0 @@ -use ethereum_types::U256; - -use crate::memory::memory_stark::MemoryOp; -use crate::memory::segments::Segment; - -#[allow(unused)] // TODO: Should be used soon. -#[derive(Debug)] -pub(crate) struct MemoryState { - /// A log of each memory operation, in the order that it occurred. - pub log: Vec, - - pub contexts: Vec, -} - -impl Default for MemoryState { - fn default() -> Self { - Self { - log: vec![], - // We start with an initial context for the kernel. - contexts: vec![MemoryContextState::default()], - } - } -} - -#[derive(Clone, Default, Debug)] -pub(crate) struct MemoryContextState { - /// The content of each memory segment. 
- pub segments: [MemorySegmentState; Segment::COUNT], -} - -#[derive(Clone, Default, Debug)] -pub(crate) struct MemorySegmentState { - pub content: Vec, -} - -impl MemorySegmentState { - pub(crate) fn get(&self, virtual_addr: usize) -> U256 { - self.content - .get(virtual_addr) - .copied() - .unwrap_or(U256::zero()) - } - - pub(crate) fn set(&mut self, virtual_addr: usize, value: U256) { - if virtual_addr >= self.content.len() { - self.content.resize(virtual_addr + 1, U256::zero()); - } - self.content[virtual_addr] = value; - } -} diff --git a/evm/src/generation/mod.rs b/evm/src/generation/mod.rs index 75f434d7..8b662a6d 100644 --- a/evm/src/generation/mod.rs +++ b/evm/src/generation/mod.rs @@ -4,24 +4,27 @@ use eth_trie_utils::partial_trie::PartialTrie; use ethereum_types::{Address, BigEndianHash, H256}; use plonky2::field::extension::Extendable; use plonky2::field::polynomial::PolynomialValues; -use plonky2::field::types::Field; use plonky2::hash::hash_types::RichField; +use plonky2::timed; use plonky2::util::timing::TimingTree; use serde::{Deserialize, Serialize}; +use GlobalMetadata::{ + ReceiptTrieRootDigestAfter, ReceiptTrieRootDigestBefore, StateTrieRootDigestAfter, + StateTrieRootDigestBefore, TransactionTrieRootDigestAfter, TransactionTrieRootDigestBefore, +}; use crate::all_stark::{AllStark, NUM_TABLES}; use crate::config::StarkConfig; use crate::cpu::bootstrap_kernel::generate_bootstrap_kernel; -use crate::cpu::columns::NUM_CPU_COLUMNS; +use crate::cpu::kernel::aggregator::KERNEL; use crate::cpu::kernel::constants::global_metadata::GlobalMetadata; use crate::generation::state::GenerationState; use crate::memory::segments::Segment; -use crate::memory::NUM_CHANNELS; use crate::proof::{BlockMetadata, PublicValues, TrieRoots}; -use crate::util::trace_rows_to_poly_values; +use crate::witness::memory::MemoryAddress; +use crate::witness::transition::transition; -pub(crate) mod memory; -pub(crate) mod mpt; +pub mod mpt; pub(crate) mod prover_input; pub(crate) 
mod rlp; pub(crate) mod state; @@ -65,79 +68,68 @@ pub(crate) fn generate_traces, const D: usize>( config: &StarkConfig, timing: &mut TimingTree, ) -> ([Vec>; NUM_TABLES], PublicValues) { - let mut state = GenerationState::::new(inputs.clone()); + let mut state = GenerationState::::new(inputs.clone(), &KERNEL.code); generate_bootstrap_kernel::(&mut state); - for txn in &inputs.signed_txns { - generate_txn(&mut state, txn); - } + timed!(timing, "simulate CPU", simulate_cpu(&mut state)); - // TODO: Pad to a power of two, ending in the `halt` kernel function. + log::info!( + "Trace lengths (before padding): {:?}", + state.traces.checkpoint() + ); - let cpu_rows = state.cpu_rows.len(); - let mem_end_timestamp = cpu_rows * NUM_CHANNELS; - let mut read_metadata = |field| { - state.get_mem( + let read_metadata = |field| { + state.memory.get(MemoryAddress::new( 0, Segment::GlobalMetadata, field as usize, - mem_end_timestamp, - ) + )) }; let trie_roots_before = TrieRoots { - state_root: H256::from_uint(&read_metadata(GlobalMetadata::StateTrieRootDigestBefore)), - transactions_root: H256::from_uint(&read_metadata( - GlobalMetadata::TransactionTrieRootDigestBefore, - )), - receipts_root: H256::from_uint(&read_metadata(GlobalMetadata::ReceiptTrieRootDigestBefore)), + state_root: H256::from_uint(&read_metadata(StateTrieRootDigestBefore)), + transactions_root: H256::from_uint(&read_metadata(TransactionTrieRootDigestBefore)), + receipts_root: H256::from_uint(&read_metadata(ReceiptTrieRootDigestBefore)), }; let trie_roots_after = TrieRoots { - state_root: H256::from_uint(&read_metadata(GlobalMetadata::StateTrieRootDigestAfter)), - transactions_root: H256::from_uint(&read_metadata( - GlobalMetadata::TransactionTrieRootDigestAfter, - )), - receipts_root: H256::from_uint(&read_metadata(GlobalMetadata::ReceiptTrieRootDigestAfter)), + state_root: H256::from_uint(&read_metadata(StateTrieRootDigestAfter)), + transactions_root: 
H256::from_uint(&read_metadata(TransactionTrieRootDigestAfter)), + receipts_root: H256::from_uint(&read_metadata(ReceiptTrieRootDigestAfter)), }; - let GenerationState { - cpu_rows, - current_cpu_row, - memory, - keccak_inputs, - keccak_memory_inputs, - logic_ops, - .. - } = state; - assert_eq!(current_cpu_row, [F::ZERO; NUM_CPU_COLUMNS].into()); - - let cpu_trace = trace_rows_to_poly_values(cpu_rows); - let keccak_trace = all_stark.keccak_stark.generate_trace(keccak_inputs, timing); - let keccak_memory_trace = all_stark.keccak_memory_stark.generate_trace( - keccak_memory_inputs, - config.fri_config.num_cap_elements(), - timing, - ); - let logic_trace = all_stark.logic_stark.generate_trace(logic_ops, timing); - let memory_trace = all_stark.memory_stark.generate_trace(memory.log, timing); - let traces = [ - cpu_trace, - keccak_trace, - keccak_memory_trace, - logic_trace, - memory_trace, - ]; - let public_values = PublicValues { trie_roots_before, trie_roots_after, block_metadata: inputs.block_metadata, }; - (traces, public_values) + let tables = timed!( + timing, + "convert trace data to tables", + state.traces.into_tables(all_stark, config, timing) + ); + (tables, public_values) } -fn generate_txn(_state: &mut GenerationState, _signed_txn: &[u8]) { - // TODO +fn simulate_cpu, const D: usize>(state: &mut GenerationState) { + let halt_pc0 = KERNEL.global_labels["halt_pc0"]; + let halt_pc1 = KERNEL.global_labels["halt_pc1"]; + + let mut already_in_halt_loop = false; + loop { + // If we've reached the kernel's halt routine, and our trace length is a power of 2, stop. 
+ let pc = state.registers.program_counter; + let in_halt_loop = pc == halt_pc0 || pc == halt_pc1; + if in_halt_loop && !already_in_halt_loop { + log::info!("CPU halted after {} cycles", state.traces.clock()); + } + already_in_halt_loop |= in_halt_loop; + if already_in_halt_loop && state.traces.clock().is_power_of_two() { + log::info!("CPU trace padded to {} cycles", state.traces.clock()); + break; + } + + transition(state); + } } diff --git a/evm/src/generation/mpt.rs b/evm/src/generation/mpt.rs index a5be1205..15b92f45 100644 --- a/evm/src/generation/mpt.rs +++ b/evm/src/generation/mpt.rs @@ -9,11 +9,11 @@ use crate::cpu::kernel::constants::trie_type::PartialTrieType; use crate::generation::TrieInputs; #[derive(RlpEncodable, RlpDecodable, Debug)] -pub(crate) struct AccountRlp { - pub(crate) nonce: U256, - pub(crate) balance: U256, - pub(crate) storage_root: H256, - pub(crate) code_hash: H256, +pub struct AccountRlp { + pub nonce: U256, + pub balance: U256, + pub storage_root: H256, + pub code_hash: H256, } pub(crate) fn all_mpt_prover_inputs_reversed(trie_inputs: &TrieInputs) -> Vec { diff --git a/evm/src/generation/prover_input.rs b/evm/src/generation/prover_input.rs index 4515bd95..885760eb 100644 --- a/evm/src/generation/prover_input.rs +++ b/evm/src/generation/prover_input.rs @@ -8,6 +8,7 @@ use crate::generation::prover_input::EvmField::{ }; use crate::generation::prover_input::FieldOp::{Inverse, Sqrt}; use crate::generation::state::GenerationState; +use crate::witness::util::stack_peek; /// Prover input function represented as a scoped function name. /// Example: `PROVER_INPUT(ff::bn254_base::inverse)` is represented as `ProverInputFn([ff, bn254_base, inverse])`. @@ -21,14 +22,13 @@ impl From> for ProverInputFn { } impl GenerationState { - #[allow(unused)] // TODO: Should be used soon. 
- pub(crate) fn prover_input(&mut self, stack: &[U256], input_fn: &ProverInputFn) -> U256 { + pub(crate) fn prover_input(&mut self, input_fn: &ProverInputFn) -> U256 { match input_fn.0[0].as_str() { "end_of_txns" => self.run_end_of_txns(), - "ff" => self.run_ff(stack, input_fn), + "ff" => self.run_ff(input_fn), "mpt" => self.run_mpt(), "rlp" => self.run_rlp(), - "account_code" => self.run_account_code(stack, input_fn), + "account_code" => self.run_account_code(input_fn), _ => panic!("Unrecognized prover input function."), } } @@ -44,10 +44,10 @@ impl GenerationState { } /// Finite field operations. - fn run_ff(&self, stack: &[U256], input_fn: &ProverInputFn) -> U256 { + fn run_ff(&self, input_fn: &ProverInputFn) -> U256 { let field = EvmField::from_str(input_fn.0[1].as_str()).unwrap(); let op = FieldOp::from_str(input_fn.0[2].as_str()).unwrap(); - let x = *stack.last().expect("Empty stack"); + let x = stack_peek(self, 0).expect("Empty stack"); field.op(op, x) } @@ -66,22 +66,21 @@ impl GenerationState { } /// Account code. - fn run_account_code(&mut self, stack: &[U256], input_fn: &ProverInputFn) -> U256 { + fn run_account_code(&mut self, input_fn: &ProverInputFn) -> U256 { match input_fn.0[1].as_str() { "length" => { // Return length of code. // stack: codehash, ... - let codehash = stack.last().expect("Empty stack"); - self.inputs.contract_code[&H256::from_uint(codehash)] + let codehash = stack_peek(self, 0).expect("Empty stack"); + self.inputs.contract_code[&H256::from_uint(&codehash)] .len() .into() } "get" => { // Return `code[i]`. // stack: i, code_length, codehash, ... 
- let stacklen = stack.len(); - let i = stack[stacklen - 1].as_usize(); - let codehash = stack[stacklen - 3]; + let i = stack_peek(self, 0).expect("Unexpected stack").as_usize(); + let codehash = stack_peek(self, 2).expect("Unexpected stack"); self.inputs.contract_code[&H256::from_uint(&codehash)][i].into() } _ => panic!("Invalid prover input function."), diff --git a/evm/src/generation/state.rs b/evm/src/generation/state.rs index 17d63018..bf1fbd74 100644 --- a/evm/src/generation/state.rs +++ b/evm/src/generation/state.rs @@ -1,35 +1,26 @@ -use std::mem; - use ethereum_types::U256; use plonky2::field::types::Field; -use tiny_keccak::keccakf; -use crate::cpu::columns::{CpuColumnsView, NUM_CPU_COLUMNS}; -use crate::generation::memory::MemoryState; use crate::generation::mpt::all_mpt_prover_inputs_reversed; use crate::generation::rlp::all_rlp_prover_inputs_reversed; use crate::generation::GenerationInputs; -use crate::keccak_memory::keccak_memory_stark::KeccakMemoryOp; -use crate::memory::memory_stark::MemoryOp; -use crate::memory::segments::Segment; -use crate::memory::NUM_CHANNELS; -use crate::util::u256_limbs; -use crate::{keccak, logic}; +use crate::witness::memory::MemoryState; +use crate::witness::state::RegistersState; +use crate::witness::traces::{TraceCheckpoint, Traces}; + +pub(crate) struct GenerationStateCheckpoint { + pub(crate) registers: RegistersState, + pub(crate) traces: TraceCheckpoint, +} #[derive(Debug)] pub(crate) struct GenerationState { - #[allow(unused)] // TODO: Should be used soon. 
pub(crate) inputs: GenerationInputs, - pub(crate) next_txn_index: usize, - pub(crate) cpu_rows: Vec<[F; NUM_CPU_COLUMNS]>, - pub(crate) current_cpu_row: CpuColumnsView, - - pub(crate) current_context: usize, + pub(crate) registers: RegistersState, pub(crate) memory: MemoryState, + pub(crate) traces: Traces, - pub(crate) keccak_inputs: Vec<[u64; keccak::keccak_stark::NUM_INPUTS]>, - pub(crate) keccak_memory_inputs: Vec, - pub(crate) logic_ops: Vec, + pub(crate) next_txn_index: usize, /// Prover inputs containing MPT data, in reverse order so that the next input can be obtained /// via `pop()`. @@ -41,212 +32,30 @@ pub(crate) struct GenerationState { } impl GenerationState { - pub(crate) fn new(inputs: GenerationInputs) -> Self { + pub(crate) fn new(inputs: GenerationInputs, kernel_code: &[u8]) -> Self { let mpt_prover_inputs = all_mpt_prover_inputs_reversed(&inputs.tries); let rlp_prover_inputs = all_rlp_prover_inputs_reversed(&inputs.signed_txns); Self { inputs, + registers: Default::default(), + memory: MemoryState::new(kernel_code), + traces: Traces::default(), next_txn_index: 0, - cpu_rows: vec![], - current_cpu_row: [F::ZERO; NUM_CPU_COLUMNS].into(), - current_context: 0, - memory: MemoryState::default(), - keccak_inputs: vec![], - keccak_memory_inputs: vec![], - logic_ops: vec![], mpt_prover_inputs, rlp_prover_inputs, } } - /// Compute logical AND, and record the operation to be added in the logic table later. - #[allow(unused)] // TODO: Should be used soon. - pub(crate) fn and(&mut self, input0: U256, input1: U256) -> U256 { - self.logic_op(logic::Op::And, input0, input1) + pub fn checkpoint(&self) -> GenerationStateCheckpoint { + GenerationStateCheckpoint { + registers: self.registers, + traces: self.traces.checkpoint(), + } } - /// Compute logical OR, and record the operation to be added in the logic table later. - #[allow(unused)] // TODO: Should be used soon. 
- pub(crate) fn or(&mut self, input0: U256, input1: U256) -> U256 { - self.logic_op(logic::Op::Or, input0, input1) - } - - /// Compute logical XOR, and record the operation to be added in the logic table later. - #[allow(unused)] // TODO: Should be used soon. - pub(crate) fn xor(&mut self, input0: U256, input1: U256) -> U256 { - self.logic_op(logic::Op::Xor, input0, input1) - } - - /// Compute logical AND, and record the operation to be added in the logic table later. - pub(crate) fn logic_op(&mut self, op: logic::Op, input0: U256, input1: U256) -> U256 { - let operation = logic::Operation::new(op, input0, input1); - let result = operation.result; - self.logic_ops.push(operation); - result - } - - /// Like `get_mem_cpu`, but reads from the current context specifically. - #[allow(unused)] // TODO: Should be used soon. - pub(crate) fn get_mem_cpu_current( - &mut self, - channel_index: usize, - segment: Segment, - virt: usize, - ) -> U256 { - let context = self.current_context; - self.get_mem_cpu(channel_index, context, segment, virt) - } - - /// Simulates the CPU reading some memory through the given channel. Besides logging the memory - /// operation, this also generates the associated registers in the current CPU row. - pub(crate) fn get_mem_cpu( - &mut self, - channel_index: usize, - context: usize, - segment: Segment, - virt: usize, - ) -> U256 { - let timestamp = self.cpu_rows.len() * NUM_CHANNELS + channel_index; - let value = self.get_mem(context, segment, virt, timestamp); - - let channel = &mut self.current_cpu_row.mem_channels[channel_index]; - channel.used = F::ONE; - channel.is_read = F::ONE; - channel.addr_context = F::from_canonical_usize(context); - channel.addr_segment = F::from_canonical_usize(segment as usize); - channel.addr_virtual = F::from_canonical_usize(virt); - channel.value = u256_limbs(value); - - value - } - - /// Read some memory, and log the operation. 
- pub(crate) fn get_mem( - &mut self, - context: usize, - segment: Segment, - virt: usize, - timestamp: usize, - ) -> U256 { - let value = self.memory.contexts[context].segments[segment as usize].get(virt); - self.memory.log.push(MemoryOp { - filter: true, - timestamp, - is_read: true, - context, - segment, - virt, - value, - }); - value - } - - /// Write some memory within the current execution context, and log the operation. - pub(crate) fn set_mem_cpu_current( - &mut self, - channel_index: usize, - segment: Segment, - virt: usize, - value: U256, - ) { - let context = self.current_context; - self.set_mem_cpu(channel_index, context, segment, virt, value); - } - - /// Write some memory, and log the operation. - pub(crate) fn set_mem_cpu( - &mut self, - channel_index: usize, - context: usize, - segment: Segment, - virt: usize, - value: U256, - ) { - let timestamp = self.cpu_rows.len() * NUM_CHANNELS + channel_index; - self.set_mem(context, segment, virt, value, timestamp); - - let channel = &mut self.current_cpu_row.mem_channels[channel_index]; - channel.used = F::ONE; - channel.is_read = F::ZERO; // For clarity; should already be 0. - channel.addr_context = F::from_canonical_usize(context); - channel.addr_segment = F::from_canonical_usize(segment as usize); - channel.addr_virtual = F::from_canonical_usize(virt); - channel.value = u256_limbs(value); - } - - /// Write some memory, and log the operation. - pub(crate) fn set_mem( - &mut self, - context: usize, - segment: Segment, - virt: usize, - value: U256, - timestamp: usize, - ) { - self.memory.log.push(MemoryOp { - filter: true, - timestamp, - is_read: false, - context, - segment, - virt, - value, - }); - self.memory.contexts[context].segments[segment as usize].set(virt, value) - } - - /// Evaluate the Keccak-f permutation in-place on some data in memory, and record the operations - /// for the purpose of witness generation. - #[allow(unused)] // TODO: Should be used soon. 
- pub(crate) fn keccak_memory( - &mut self, - context: usize, - segment: Segment, - virt: usize, - ) -> [u64; keccak::keccak_stark::NUM_INPUTS] { - let read_timestamp = self.cpu_rows.len() * NUM_CHANNELS; - let _write_timestamp = read_timestamp + 1; - let input = (0..25) - .map(|i| { - let bytes = [0, 1, 2, 3, 4, 5, 6, 7].map(|j| { - let virt = virt + i * 8 + j; - let byte = self.get_mem(context, segment, virt, read_timestamp); - debug_assert!(byte.bits() <= 8); - byte.as_u32() as u8 - }); - u64::from_le_bytes(bytes) - }) - .collect::>() - .try_into() - .unwrap(); - let output = self.keccak(input); - self.keccak_memory_inputs.push(KeccakMemoryOp { - context, - segment, - virt, - read_timestamp, - input, - output, - }); - // TODO: Write output to memory. - output - } - - /// Evaluate the Keccak-f permutation, and record the operation for the purpose of witness - /// generation. - pub(crate) fn keccak( - &mut self, - mut input: [u64; keccak::keccak_stark::NUM_INPUTS], - ) -> [u64; keccak::keccak_stark::NUM_INPUTS] { - self.keccak_inputs.push(input); - keccakf(&mut input); - input - } - - pub(crate) fn commit_cpu_row(&mut self) { - let mut swapped_row = [F::ZERO; NUM_CPU_COLUMNS].into(); - mem::swap(&mut self.current_cpu_row, &mut swapped_row); - self.cpu_rows.push(swapped_row.into()); + pub fn rollback(&mut self, checkpoint: GenerationStateCheckpoint) { + self.registers = checkpoint.registers; + self.traces.rollback(checkpoint.traces); } } diff --git a/evm/src/keccak/keccak_stark.rs b/evm/src/keccak/keccak_stark.rs index 87a61ae7..7be421fb 100644 --- a/evm/src/keccak/keccak_stark.rs +++ b/evm/src/keccak/keccak_stark.rs @@ -1,7 +1,6 @@ use std::marker::PhantomData; use itertools::Itertools; -use log::info; use plonky2::field::extension::{Extendable, FieldExtension}; use plonky2::field::packed::PackedField; use plonky2::field::polynomial::PolynomialValues; @@ -39,6 +38,7 @@ pub fn ctl_data() -> Vec> { } pub fn ctl_filter() -> Column { + // TODO: Also need to filter out 
padding rows somehow. Column::single(reg_step(NUM_ROUNDS - 1)) } @@ -50,12 +50,14 @@ pub struct KeccakStark { impl, const D: usize> KeccakStark { /// Generate the rows of the trace. Note that this does not generate the permuted columns used /// in our lookup arguments, as those are computed after transposing to column-wise form. - pub(crate) fn generate_trace_rows( + fn generate_trace_rows( &self, inputs: Vec<[u64; NUM_INPUTS]>, + min_rows: usize, ) -> Vec<[F; NUM_COLUMNS]> { - let num_rows = (inputs.len() * NUM_ROUNDS).next_power_of_two(); - info!("{} rows", num_rows); + let num_rows = (inputs.len() * NUM_ROUNDS) + .max(min_rows) + .next_power_of_two(); let mut rows = Vec::with_capacity(num_rows); for input in inputs.iter() { rows.extend(self.generate_trace_rows_for_perm(*input)); @@ -204,13 +206,14 @@ impl, const D: usize> KeccakStark { pub fn generate_trace( &self, inputs: Vec<[u64; NUM_INPUTS]>, + min_rows: usize, timing: &mut TimingTree, ) -> Vec> { // Generate the witness, except for permuted columns in the lookup argument. 
let trace_rows = timed!( timing, "generate trace rows", - self.generate_trace_rows(inputs) + self.generate_trace_rows(inputs, min_rows) ); let trace_polys = timed!( timing, @@ -598,7 +601,7 @@ mod tests { f: Default::default(), }; - let rows = stark.generate_trace_rows(vec![input.try_into().unwrap()]); + let rows = stark.generate_trace_rows(vec![input.try_into().unwrap()], 8); let last_row = rows[NUM_ROUNDS - 1]; let output = (0..NUM_INPUTS) .map(|i| { @@ -637,7 +640,7 @@ mod tests { let trace_poly_values = timed!( timing, "generate trace", - stark.generate_trace(input.try_into().unwrap(), &mut timing) + stark.generate_trace(input.try_into().unwrap(), 8, &mut timing) ); // TODO: Cloning this isn't great; consider having `from_values` accept a reference, diff --git a/evm/src/keccak_memory/columns.rs b/evm/src/keccak_memory/columns.rs deleted file mode 100644 index 92bdbf2b..00000000 --- a/evm/src/keccak_memory/columns.rs +++ /dev/null @@ -1,29 +0,0 @@ -pub(crate) const KECCAK_WIDTH_BYTES: usize = 200; - -/// 1 if this row represents a real operation; 0 if it's a padding row. -pub(crate) const COL_IS_REAL: usize = 0; - -// The address at which we will read inputs and write outputs. -pub(crate) const COL_CONTEXT: usize = 1; -pub(crate) const COL_SEGMENT: usize = 2; -pub(crate) const COL_VIRTUAL: usize = 3; - -/// The timestamp at which inputs should be read from memory. -/// Outputs will be written at the following timestamp. -pub(crate) const COL_READ_TIMESTAMP: usize = 4; - -const START_INPUT_LIMBS: usize = 5; -/// A byte of the input. -pub(crate) fn col_input_byte(i: usize) -> usize { - debug_assert!(i < KECCAK_WIDTH_BYTES); - START_INPUT_LIMBS + i -} - -const START_OUTPUT_LIMBS: usize = START_INPUT_LIMBS + KECCAK_WIDTH_BYTES; -/// A byte of the output. 
-pub(crate) fn col_output_byte(i: usize) -> usize { - debug_assert!(i < KECCAK_WIDTH_BYTES); - START_OUTPUT_LIMBS + i -} - -pub const NUM_COLUMNS: usize = START_OUTPUT_LIMBS + KECCAK_WIDTH_BYTES; diff --git a/evm/src/keccak_memory/keccak_memory_stark.rs b/evm/src/keccak_memory/keccak_memory_stark.rs deleted file mode 100644 index 3719fc8e..00000000 --- a/evm/src/keccak_memory/keccak_memory_stark.rs +++ /dev/null @@ -1,224 +0,0 @@ -use std::marker::PhantomData; - -use plonky2::field::extension::{Extendable, FieldExtension}; -use plonky2::field::packed::PackedField; -use plonky2::field::polynomial::PolynomialValues; -use plonky2::field::types::Field; -use plonky2::hash::hash_types::RichField; -use plonky2::timed; -use plonky2::util::timing::TimingTree; - -use crate::constraint_consumer::{ConstraintConsumer, RecursiveConstraintConsumer}; -use crate::cross_table_lookup::Column; -use crate::keccak::keccak_stark::NUM_INPUTS; -use crate::keccak_memory::columns::*; -use crate::memory::segments::Segment; -use crate::stark::Stark; -use crate::util::trace_rows_to_poly_values; -use crate::vars::{StarkEvaluationTargets, StarkEvaluationVars}; - -pub(crate) fn ctl_looked_data() -> Vec> { - Column::singles([COL_CONTEXT, COL_SEGMENT, COL_VIRTUAL, COL_READ_TIMESTAMP]).collect() -} - -pub(crate) fn ctl_looking_keccak() -> Vec> { - let input_cols = (0..50).map(|i| { - Column::le_bytes((0..4).map(|j| { - let byte_index = i * 4 + j; - col_input_byte(byte_index) - })) - }); - let output_cols = (0..50).map(|i| { - Column::le_bytes((0..4).map(|j| { - let byte_index = i * 4 + j; - col_output_byte(byte_index) - })) - }); - input_cols.chain(output_cols).collect() -} - -pub(crate) fn ctl_looking_memory(i: usize, is_read: bool) -> Vec> { - let mut res = vec![Column::constant(F::from_bool(is_read))]; - res.extend(Column::singles([COL_CONTEXT, COL_SEGMENT, COL_VIRTUAL])); - - res.push(Column::single(col_input_byte(i))); - // Since we're reading or writing a single byte, the higher limbs must be 
zero. - res.extend((1..8).map(|_| Column::zero())); - - // Since COL_READ_TIMESTAMP is the read time, we add 1 if this is a write. - let is_write_f = F::from_bool(!is_read); - res.push(Column::linear_combination_with_constant( - [(COL_READ_TIMESTAMP, F::ONE)], - is_write_f, - )); - - assert_eq!( - res.len(), - crate::memory::memory_stark::ctl_data::().len() - ); - res -} - -/// CTL filter used for both directions (looked and looking). -pub(crate) fn ctl_filter() -> Column { - Column::single(COL_IS_REAL) -} - -/// Information about a Keccak memory operation needed for witness generation. -#[derive(Debug)] -pub(crate) struct KeccakMemoryOp { - // The address at which we will read inputs and write outputs. - pub(crate) context: usize, - pub(crate) segment: Segment, - pub(crate) virt: usize, - - /// The timestamp at which inputs should be read from memory. - /// Outputs will be written at the following timestamp. - pub(crate) read_timestamp: usize, - - /// The input that was read at that address. - pub(crate) input: [u64; NUM_INPUTS], - pub(crate) output: [u64; NUM_INPUTS], -} - -#[derive(Copy, Clone, Default)] -pub struct KeccakMemoryStark { - pub(crate) f: PhantomData, -} - -impl, const D: usize> KeccakMemoryStark { - #[allow(unused)] // TODO: Should be used soon. - pub(crate) fn generate_trace( - &self, - operations: Vec, - min_rows: usize, - timing: &mut TimingTree, - ) -> Vec> { - // Generate the witness row-wise. 
- let trace_rows = timed!( - timing, - "generate trace rows", - self.generate_trace_rows(operations, min_rows) - ); - - let trace_polys = timed!( - timing, - "convert to PolynomialValues", - trace_rows_to_poly_values(trace_rows) - ); - - trace_polys - } - - fn generate_trace_rows( - &self, - operations: Vec, - min_rows: usize, - ) -> Vec<[F; NUM_COLUMNS]> { - let num_rows = operations.len().max(min_rows).next_power_of_two(); - let mut rows = Vec::with_capacity(num_rows); - for op in operations { - rows.push(self.generate_row_for_op(op)); - } - - let padding_row = self.generate_padding_row(); - for _ in rows.len()..num_rows { - rows.push(padding_row); - } - rows - } - - fn generate_row_for_op(&self, op: KeccakMemoryOp) -> [F; NUM_COLUMNS] { - let mut row = [F::ZERO; NUM_COLUMNS]; - row[COL_IS_REAL] = F::ONE; - row[COL_CONTEXT] = F::from_canonical_usize(op.context); - row[COL_SEGMENT] = F::from_canonical_usize(op.segment as usize); - row[COL_VIRTUAL] = F::from_canonical_usize(op.virt); - row[COL_READ_TIMESTAMP] = F::from_canonical_usize(op.read_timestamp); - for i in 0..25 { - let input_u64 = op.input[i]; - let output_u64 = op.output[i]; - for j in 0..8 { - let byte_index = i * 8 + j; - row[col_input_byte(byte_index)] = F::from_canonical_u8(input_u64.to_le_bytes()[j]); - row[col_output_byte(byte_index)] = - F::from_canonical_u8(output_u64.to_le_bytes()[j]); - } - } - row - } - - fn generate_padding_row(&self) -> [F; NUM_COLUMNS] { - // We just need COL_IS_REAL to be zero, which it is by default. - // The other fields will have no effect. - [F::ZERO; NUM_COLUMNS] - } -} - -impl, const D: usize> Stark for KeccakMemoryStark { - const COLUMNS: usize = NUM_COLUMNS; - - fn eval_packed_generic( - &self, - vars: StarkEvaluationVars, - yield_constr: &mut ConstraintConsumer

, - ) where - FE: FieldExtension, - P: PackedField, - { - // is_real must be 0 or 1. - let is_real = vars.local_values[COL_IS_REAL]; - yield_constr.constraint(is_real * (is_real - P::ONES)); - } - - fn eval_ext_circuit( - &self, - builder: &mut plonky2::plonk::circuit_builder::CircuitBuilder, - vars: StarkEvaluationTargets, - yield_constr: &mut RecursiveConstraintConsumer, - ) { - // is_real must be 0 or 1. - let is_real = vars.local_values[COL_IS_REAL]; - let constraint = builder.mul_sub_extension(is_real, is_real, is_real); - yield_constr.constraint(builder, constraint); - } - - fn constraint_degree(&self) -> usize { - 2 - } -} - -#[cfg(test)] -mod tests { - use anyhow::Result; - use plonky2::plonk::config::{GenericConfig, PoseidonGoldilocksConfig}; - - use crate::keccak_memory::keccak_memory_stark::KeccakMemoryStark; - use crate::stark_testing::{test_stark_circuit_constraints, test_stark_low_degree}; - - #[test] - fn test_stark_degree() -> Result<()> { - const D: usize = 2; - type C = PoseidonGoldilocksConfig; - type F = >::F; - type S = KeccakMemoryStark; - - let stark = S { - f: Default::default(), - }; - test_stark_low_degree(stark) - } - - #[test] - fn test_stark_circuit() -> Result<()> { - const D: usize = 2; - type C = PoseidonGoldilocksConfig; - type F = >::F; - type S = KeccakMemoryStark; - - let stark = S { - f: Default::default(), - }; - test_stark_circuit_constraints::(stark) - } -} diff --git a/evm/src/keccak_memory/mod.rs b/evm/src/keccak_memory/mod.rs deleted file mode 100644 index 7b5e3d01..00000000 --- a/evm/src/keccak_memory/mod.rs +++ /dev/null @@ -1,2 +0,0 @@ -pub mod columns; -pub mod keccak_memory_stark; diff --git a/evm/src/keccak_sponge/columns.rs b/evm/src/keccak_sponge/columns.rs index 08194e87..440c59ab 100644 --- a/evm/src/keccak_sponge/columns.rs +++ b/evm/src/keccak_sponge/columns.rs @@ -21,7 +21,7 @@ pub(crate) struct KeccakSpongeColumnsView { /// in the block will be padding bytes; 0 otherwise. 
pub is_final_block: T, - // The address at which we will read the input block. + // The base address at which we will read the input block. pub context: T, pub segment: T, pub virt: T, diff --git a/evm/src/keccak_sponge/keccak_sponge_stark.rs b/evm/src/keccak_sponge/keccak_sponge_stark.rs index f2af8895..ebefce06 100644 --- a/evm/src/keccak_sponge/keccak_sponge_stark.rs +++ b/evm/src/keccak_sponge/keccak_sponge_stark.rs @@ -18,12 +18,11 @@ use crate::constraint_consumer::{ConstraintConsumer, RecursiveConstraintConsumer use crate::cpu::kernel::keccak_util::keccakf_u32s; use crate::cross_table_lookup::Column; use crate::keccak_sponge::columns::*; -use crate::memory::segments::Segment; use crate::stark::Stark; use crate::util::trace_rows_to_poly_values; use crate::vars::{StarkEvaluationTargets, StarkEvaluationVars}; +use crate::witness::memory::MemoryAddress; -#[allow(unused)] // TODO: Should be used soon. pub(crate) fn ctl_looked_data() -> Vec> { let cols = KECCAK_SPONGE_COL_MAP; let outputs = Column::singles(&cols.updated_state_u32s[..8]); @@ -31,14 +30,13 @@ pub(crate) fn ctl_looked_data() -> Vec> { cols.context, cols.segment, cols.virt, - cols.timestamp, cols.len, + cols.timestamp, ]) .chain(outputs) .collect() } -#[allow(unused)] // TODO: Should be used soon. pub(crate) fn ctl_looking_keccak() -> Vec> { let cols = KECCAK_SPONGE_COL_MAP; Column::singles( @@ -52,7 +50,6 @@ pub(crate) fn ctl_looking_keccak() -> Vec> { .collect() } -#[allow(unused)] // TODO: Should be used soon. pub(crate) fn ctl_looking_memory(i: usize) -> Vec> { let cols = KECCAK_SPONGE_COL_MAP; @@ -81,14 +78,18 @@ pub(crate) fn ctl_looking_memory(i: usize) -> Vec> { res } +pub(crate) fn num_logic_ctls() -> usize { + const U8S_PER_CTL: usize = 32; + ceil_div_usize(KECCAK_RATE_BYTES, U8S_PER_CTL) +} + /// CTL for performing the `i`th logic CTL. Since we need to do 136 byte XORs, and the logic CTL can /// XOR 32 bytes per CTL, there are 5 such CTLs. -#[allow(unused)] // TODO: Should be used soon. 
pub(crate) fn ctl_looking_logic(i: usize) -> Vec> { const U32S_PER_CTL: usize = 8; const U8S_PER_CTL: usize = 32; - debug_assert!(i < ceil_div_usize(KECCAK_RATE_BYTES, U8S_PER_CTL)); + debug_assert!(i < num_logic_ctls()); let cols = KECCAK_SPONGE_COL_MAP; let mut res = vec![ @@ -111,7 +112,7 @@ pub(crate) fn ctl_looking_logic(i: usize) -> Vec> { .chunks(size_of::()) .map(|chunk| Column::le_bytes(chunk)) .chain(repeat(Column::zero())) - .take(U8S_PER_CTL), + .take(U32S_PER_CTL), ); // The output contains the XOR'd rate part. @@ -124,14 +125,12 @@ pub(crate) fn ctl_looking_logic(i: usize) -> Vec> { res } -#[allow(unused)] // TODO: Should be used soon. pub(crate) fn ctl_looked_filter() -> Column { // The CPU table is only interested in our final-block rows, since those contain the final // sponge output. Column::single(KECCAK_SPONGE_COL_MAP.is_final_block) } -#[allow(unused)] // TODO: Should be used soon. /// CTL filter for reading the `i`th byte of input from memory. pub(crate) fn ctl_looking_memory_filter(i: usize) -> Column { // We perform the `i`th read if either @@ -141,26 +140,26 @@ pub(crate) fn ctl_looking_memory_filter(i: usize) -> Column { Column::sum(once(&cols.is_full_input_block).chain(&cols.is_final_input_len[i..])) } +pub(crate) fn ctl_looking_keccak_filter() -> Column { + let cols = KECCAK_SPONGE_COL_MAP; + Column::sum([cols.is_full_input_block, cols.is_final_block]) +} + /// Information about a Keccak sponge operation needed for witness generation. #[derive(Debug)] pub(crate) struct KeccakSpongeOp { - // The address at which inputs are read. - pub(crate) context: usize, - pub(crate) segment: Segment, - pub(crate) virt: usize, + /// The base address at which inputs are read. + pub(crate) base_address: MemoryAddress, /// The timestamp at which inputs are read. pub(crate) timestamp: usize, - /// The length of the input, in bytes. - pub(crate) len: usize, - /// The input that was read. 
pub(crate) input: Vec, } #[derive(Copy, Clone, Default)] -pub(crate) struct KeccakSpongeStark { +pub struct KeccakSpongeStark { f: PhantomData, } @@ -261,7 +260,7 @@ impl, const D: usize> KeccakSpongeStark { sponge_state: [u32; KECCAK_WIDTH_U32S], final_inputs: &[u8], ) -> KeccakSpongeColumnsView { - assert_eq!(already_absorbed_bytes + final_inputs.len(), op.len); + assert_eq!(already_absorbed_bytes + final_inputs.len(), op.input.len()); let mut row = KeccakSpongeColumnsView { is_final_block: F::ONE, @@ -295,11 +294,11 @@ impl, const D: usize> KeccakSpongeStark { already_absorbed_bytes: usize, mut sponge_state: [u32; KECCAK_WIDTH_U32S], ) { - row.context = F::from_canonical_usize(op.context); - row.segment = F::from_canonical_usize(op.segment as usize); - row.virt = F::from_canonical_usize(op.virt); + row.context = F::from_canonical_usize(op.base_address.context); + row.segment = F::from_canonical_usize(op.base_address.segment); + row.virt = F::from_canonical_usize(op.base_address.virt); row.timestamp = F::from_canonical_usize(op.timestamp); - row.len = F::from_canonical_usize(op.len); + row.len = F::from_canonical_usize(op.input.len()); row.already_absorbed_bytes = F::from_canonical_usize(already_absorbed_bytes); row.original_rate_u32s = sponge_state[..KECCAK_RATE_U32S] @@ -410,6 +409,7 @@ mod tests { use crate::keccak_sponge::keccak_sponge_stark::{KeccakSpongeOp, KeccakSpongeStark}; use crate::memory::segments::Segment; use crate::stark_testing::{test_stark_circuit_constraints, test_stark_low_degree}; + use crate::witness::memory::MemoryAddress; #[test] fn test_stark_degree() -> Result<()> { @@ -443,11 +443,12 @@ mod tests { let expected_output = keccak(&input); let op = KeccakSpongeOp { - context: 0, - segment: Segment::Code, - virt: 0, + base_address: MemoryAddress { + context: 0, + segment: Segment::Code as usize, + virt: 0, + }, timestamp: 0, - len: input.len(), input, }; let stark = S::default(); diff --git a/evm/src/lib.rs b/evm/src/lib.rs index 
6f332b59..4c368491 100644 --- a/evm/src/lib.rs +++ b/evm/src/lib.rs @@ -2,6 +2,7 @@ #![allow(clippy::needless_range_loop)] #![allow(clippy::too_many_arguments)] #![allow(clippy::type_complexity)] +#![allow(clippy::field_reassign_with_default)] #![feature(let_chains)] #![feature(generic_const_exprs)] @@ -14,7 +15,6 @@ pub mod cross_table_lookup; pub mod generation; mod get_challenges; pub mod keccak; -pub mod keccak_memory; pub mod keccak_sponge; pub mod logic; pub mod lookup; @@ -29,3 +29,12 @@ pub mod util; pub mod vanishing_poly; pub mod vars; pub mod verifier; +pub mod witness; + +// Set up Jemalloc +#[cfg(not(target_env = "msvc"))] +use jemallocator::Jemalloc; + +#[cfg(not(target_env = "msvc"))] +#[global_allocator] +static GLOBAL: Jemalloc = Jemalloc; diff --git a/evm/src/logic.rs b/evm/src/logic.rs index dc6fc777..b7429610 100644 --- a/evm/src/logic.rs +++ b/evm/src/logic.rs @@ -72,13 +72,23 @@ pub struct LogicStark { pub f: PhantomData, } -#[derive(Copy, Clone, Debug)] +#[derive(Copy, Clone, Debug, Eq, PartialEq)] pub(crate) enum Op { And, Or, Xor, } +impl Op { + pub(crate) fn result(&self, a: U256, b: U256) -> U256 { + match self { + Op::And => a & b, + Op::Or => a | b, + Op::Xor => a ^ b, + } + } +} + #[derive(Debug)] pub(crate) struct Operation { operator: Op, @@ -89,11 +99,7 @@ pub(crate) struct Operation { impl Operation { pub(crate) fn new(operator: Op, input0: U256, input1: U256) -> Self { - let result = match operator { - Op::And => input0 & input1, - Op::Or => input0 | input1, - Op::Xor => input0 ^ input1, - }; + let result = operator.result(input0, input1); Operation { operator, input0, @@ -101,18 +107,44 @@ impl Operation { result, } } + + fn into_row(self) -> [F; NUM_COLUMNS] { + let Operation { + operator, + input0, + input1, + result, + } = self; + let mut row = [F::ZERO; NUM_COLUMNS]; + row[match operator { + Op::And => columns::IS_AND, + Op::Or => columns::IS_OR, + Op::Xor => columns::IS_XOR, + }] = F::ONE; + for i in 0..256 { + 
row[columns::INPUT0.start + i] = F::from_bool(input0.bit(i)); + row[columns::INPUT1.start + i] = F::from_bool(input1.bit(i)); + } + let result_limbs: &[u64] = result.as_ref(); + for (i, &limb) in result_limbs.iter().enumerate() { + row[columns::RESULT.start + 2 * i] = F::from_canonical_u32(limb as u32); + row[columns::RESULT.start + 2 * i + 1] = F::from_canonical_u32((limb >> 32) as u32); + } + row + } } impl LogicStark { pub(crate) fn generate_trace( &self, operations: Vec, + min_rows: usize, timing: &mut TimingTree, ) -> Vec> { let trace_rows = timed!( timing, "generate trace rows", - self.generate_trace_rows(operations) + self.generate_trace_rows(operations, min_rows) ); let trace_polys = timed!( timing, @@ -122,46 +154,30 @@ impl LogicStark { trace_polys } - fn generate_trace_rows(&self, operations: Vec) -> Vec<[F; NUM_COLUMNS]> { + fn generate_trace_rows( + &self, + operations: Vec, + min_rows: usize, + ) -> Vec<[F; NUM_COLUMNS]> { let len = operations.len(); - let padded_len = len.next_power_of_two(); + let padded_len = len.max(min_rows).next_power_of_two(); let mut rows = Vec::with_capacity(padded_len); for op in operations { - rows.push(Self::generate_row(op)); + rows.push(op.into_row()); } // Pad to a power of two. 
for _ in len..padded_len { - rows.push([F::ZERO; columns::NUM_COLUMNS]); + rows.push([F::ZERO; NUM_COLUMNS]); } rows } - - fn generate_row(operation: Operation) -> [F; columns::NUM_COLUMNS] { - let mut row = [F::ZERO; columns::NUM_COLUMNS]; - match operation.operator { - Op::And => row[columns::IS_AND] = F::ONE, - Op::Or => row[columns::IS_OR] = F::ONE, - Op::Xor => row[columns::IS_XOR] = F::ONE, - } - for (i, col) in columns::INPUT0.enumerate() { - row[col] = F::from_bool(operation.input0.bit(i)); - } - for (i, col) in columns::INPUT1.enumerate() { - row[col] = F::from_bool(operation.input1.bit(i)); - } - for (i, col) in columns::RESULT.enumerate() { - let bit_range = i * PACKED_LIMB_BITS..(i + 1) * PACKED_LIMB_BITS; - row[col] = limb_from_bits_le(bit_range.map(|j| F::from_bool(operation.result.bit(j)))); - } - row - } } impl, const D: usize> Stark for LogicStark { - const COLUMNS: usize = columns::NUM_COLUMNS; + const COLUMNS: usize = NUM_COLUMNS; fn eval_packed_generic( &self, diff --git a/evm/src/memory/memory_stark.rs b/evm/src/memory/memory_stark.rs index f5455a53..13a4180b 100644 --- a/evm/src/memory/memory_stark.rs +++ b/evm/src/memory/memory_stark.rs @@ -1,6 +1,5 @@ use std::marker::PhantomData; -use ethereum_types::U256; use itertools::Itertools; use maybe_rayon::*; use plonky2::field::extension::{Extendable, FieldExtension}; @@ -20,11 +19,12 @@ use crate::memory::columns::{ COUNTER_PERMUTED, FILTER, IS_READ, NUM_COLUMNS, RANGE_CHECK, RANGE_CHECK_PERMUTED, SEGMENT_FIRST_CHANGE, TIMESTAMP, VIRTUAL_FIRST_CHANGE, }; -use crate::memory::segments::Segment; use crate::memory::VALUE_LIMBS; use crate::permutation::PermutationPair; use crate::stark::Stark; use crate::vars::{StarkEvaluationTargets, StarkEvaluationVars}; +use crate::witness::memory::MemoryOpKind::Read; +use crate::witness::memory::{MemoryAddress, MemoryOp}; pub fn ctl_data() -> Vec> { let mut res = @@ -43,31 +43,24 @@ pub struct MemoryStark { pub(crate) f: PhantomData, } -#[derive(Clone, Debug)] 
-pub(crate) struct MemoryOp { - /// true if this is an actual memory operation, or false if it's a padding row. - pub filter: bool, - pub timestamp: usize, - pub is_read: bool, - pub context: usize, - pub segment: Segment, - pub virt: usize, - pub value: U256, -} - impl MemoryOp { /// Generate a row for a given memory operation. Note that this does not generate columns which /// depend on the next operation, such as `CONTEXT_FIRST_CHANGE`; those are generated later. /// It also does not generate columns such as `COUNTER`, which are generated later, after the /// trace has been transposed into column-major form. - fn to_row(&self) -> [F; NUM_COLUMNS] { + fn into_row(self) -> [F; NUM_COLUMNS] { let mut row = [F::ZERO; NUM_COLUMNS]; row[FILTER] = F::from_bool(self.filter); row[TIMESTAMP] = F::from_canonical_usize(self.timestamp); - row[IS_READ] = F::from_bool(self.is_read); - row[ADDR_CONTEXT] = F::from_canonical_usize(self.context); - row[ADDR_SEGMENT] = F::from_canonical_usize(self.segment as usize); - row[ADDR_VIRTUAL] = F::from_canonical_usize(self.virt); + row[IS_READ] = F::from_bool(self.kind == Read); + let MemoryAddress { + context, + segment, + virt, + } = self.address; + row[ADDR_CONTEXT] = F::from_canonical_usize(context); + row[ADDR_SEGMENT] = F::from_canonical_usize(segment); + row[ADDR_VIRTUAL] = F::from_canonical_usize(virt); for j in 0..VALUE_LIMBS { row[value_limb(j)] = F::from_canonical_u32((self.value >> (j * 32)).low_u32()); } @@ -80,14 +73,14 @@ fn get_max_range_check(memory_ops: &[MemoryOp]) -> usize { .iter() .tuple_windows() .map(|(curr, next)| { - if curr.context != next.context { - next.context - curr.context - 1 - } else if curr.segment != next.segment { - next.segment as usize - curr.segment as usize - 1 - } else if curr.virt != next.virt { - next.virt - curr.virt - 1 + if curr.address.context != next.address.context { + next.address.context - curr.address.context - 1 + } else if curr.address.segment != next.address.segment { + 
next.address.segment - curr.address.segment - 1 + } else if curr.address.virt != next.address.virt { + next.address.virt - curr.address.virt - 1 } else { - next.timestamp - curr.timestamp - 1 + next.timestamp - curr.timestamp } }) .max() @@ -131,7 +124,7 @@ pub fn generate_first_change_flags_and_rc(trace_rows: &mut [[F; NU } else if virtual_first_change { next_virt - virt - F::ONE } else { - next_timestamp - timestamp - F::ONE + next_timestamp - timestamp }; } } @@ -140,13 +133,20 @@ impl, const D: usize> MemoryStark { /// Generate most of the trace rows. Excludes a few columns like `COUNTER`, which are generated /// later, after transposing to column-major form. fn generate_trace_row_major(&self, mut memory_ops: Vec) -> Vec<[F; NUM_COLUMNS]> { - memory_ops.sort_by_key(|op| (op.context, op.segment, op.virt, op.timestamp)); + memory_ops.sort_by_key(|op| { + ( + op.address.context, + op.address.segment, + op.address.virt, + op.timestamp, + ) + }); Self::pad_memory_ops(&mut memory_ops); let mut trace_rows = memory_ops .into_par_iter() - .map(|op| op.to_row()) + .map(|op| op.into_row()) .collect::>(); generate_first_change_flags_and_rc(trace_rows.as_mut_slice()); trace_rows @@ -170,7 +170,7 @@ impl, const D: usize> MemoryStark { let num_ops_padded = num_ops.max(max_range_check + 1).next_power_of_two(); let to_pad = num_ops_padded - num_ops; - let last_op = memory_ops.last().expect("No memory ops?").clone(); + let last_op = *memory_ops.last().expect("No memory ops?"); // We essentially repeat the last operation until our operation list has the desired size, // with a few changes: @@ -181,7 +181,7 @@ impl, const D: usize> MemoryStark { memory_ops.push(MemoryOp { filter: false, timestamp: last_op.timestamp + i + 1, - is_read: true, + kind: Read, ..last_op }); } @@ -283,7 +283,7 @@ impl, const D: usize> Stark for MemoryStark, const D: usize> Stark for MemoryStark, const D: usize> Stark for MemoryStark(num_ops: usize, rng: &mut R) -> Vec { - let mut memory_ops = Vec::new(); 
- - let mut current_memory_values: HashMap<(usize, Segment, usize), U256> = HashMap::new(); - let num_cycles = num_ops / 2; - for clock in 0..num_cycles { - let mut used_indices = HashSet::new(); - let mut new_writes_this_cycle = HashMap::new(); - let mut has_read = false; - for _ in 0..2 { - let mut channel_index = rng.gen_range(0..NUM_CHANNELS); - while used_indices.contains(&channel_index) { - channel_index = rng.gen_range(0..NUM_CHANNELS); - } - used_indices.insert(channel_index); - - let is_read = if clock == 0 { - false - } else { - !has_read && rng.gen() - }; - has_read = has_read || is_read; - - let (context, segment, virt, vals) = if is_read { - let written: Vec<_> = current_memory_values.keys().collect(); - let &(mut context, mut segment, mut virt) = - written[rng.gen_range(0..written.len())]; - while new_writes_this_cycle.contains_key(&(context, segment, virt)) { - (context, segment, virt) = *written[rng.gen_range(0..written.len())]; - } - - let &vals = current_memory_values - .get(&(context, segment, virt)) - .unwrap(); - - (context, segment, virt, vals) - } else { - // TODO: with taller memory table or more padding (to enable range-checking bigger diffs), - // test larger address values. 
- let mut context = rng.gen_range(0..40); - let segments = [Segment::Code, Segment::Stack, Segment::MainMemory]; - let mut segment = *segments.choose(rng).unwrap(); - let mut virt = rng.gen_range(0..20); - while new_writes_this_cycle.contains_key(&(context, segment, virt)) { - context = rng.gen_range(0..40); - segment = *segments.choose(rng).unwrap(); - virt = rng.gen_range(0..20); - } - - let val = U256(rng.gen()); - - new_writes_this_cycle.insert((context, segment, virt), val); - - (context, segment, virt, val) - }; - - let timestamp = clock * NUM_CHANNELS + channel_index; - memory_ops.push(MemoryOp { - filter: true, - timestamp, - is_read, - context, - segment, - virt, - value: vals, - }); - } - for (k, v) in new_writes_this_cycle { - current_memory_values.insert(k, v); - } - } - - memory_ops - } - #[test] fn test_stark_degree() -> Result<()> { const D: usize = 2; diff --git a/evm/src/prover.rs b/evm/src/prover.rs index 4627784d..c99e1873 100644 --- a/evm/src/prover.rs +++ b/evm/src/prover.rs @@ -24,7 +24,7 @@ use crate::cpu::cpu_stark::CpuStark; use crate::cross_table_lookup::{cross_table_lookup_data, CtlCheckVars, CtlData}; use crate::generation::{generate_traces, GenerationInputs}; use crate::keccak::keccak_stark::KeccakStark; -use crate::keccak_memory::keccak_memory_stark::KeccakMemoryStark; +use crate::keccak_sponge::keccak_sponge_stark::KeccakSpongeStark; use crate::logic::LogicStark; use crate::memory::memory_stark::MemoryStark; use crate::permutation::{ @@ -49,7 +49,7 @@ where [(); C::Hasher::HASH_SIZE]:, [(); CpuStark::::COLUMNS]:, [(); KeccakStark::::COLUMNS]:, - [(); KeccakMemoryStark::::COLUMNS]:, + [(); KeccakSpongeStark::::COLUMNS]:, [(); LogicStark::::COLUMNS]:, [(); MemoryStark::::COLUMNS]:, { @@ -71,7 +71,7 @@ where [(); C::Hasher::HASH_SIZE]:, [(); CpuStark::::COLUMNS]:, [(); KeccakStark::::COLUMNS]:, - [(); KeccakMemoryStark::::COLUMNS]:, + [(); KeccakSpongeStark::::COLUMNS]:, [(); LogicStark::::COLUMNS]:, [(); MemoryStark::::COLUMNS]:, { @@ 
-132,12 +132,12 @@ where &mut challenger, timing, )?; - let keccak_memory_proof = prove_single_table( - &all_stark.keccak_memory_stark, + let keccak_sponge_proof = prove_single_table( + &all_stark.keccak_sponge_stark, config, - &trace_poly_values[Table::KeccakMemory as usize], - &trace_commitments[Table::KeccakMemory as usize], - &ctl_data_per_table[Table::KeccakMemory as usize], + &trace_poly_values[Table::KeccakSponge as usize], + &trace_commitments[Table::KeccakSponge as usize], + &ctl_data_per_table[Table::KeccakSponge as usize], &mut challenger, timing, )?; @@ -163,7 +163,7 @@ where let stark_proofs = [ cpu_proof, keccak_proof, - keccak_memory_proof, + keccak_sponge_proof, logic_proof, memory_proof, ]; diff --git a/evm/src/recursive_verifier.rs b/evm/src/recursive_verifier.rs index 999b5e13..bc1357a9 100644 --- a/evm/src/recursive_verifier.rs +++ b/evm/src/recursive_verifier.rs @@ -27,7 +27,7 @@ use crate::cross_table_lookup::{ CtlCheckVarsTarget, }; use crate::keccak::keccak_stark::KeccakStark; -use crate::keccak_memory::keccak_memory_stark::KeccakMemoryStark; +use crate::keccak_sponge::keccak_sponge_stark::KeccakSpongeStark; use crate::logic::LogicStark; use crate::memory::memory_stark::MemoryStark; use crate::permutation::{ @@ -231,7 +231,7 @@ impl, C: GenericConfig, const D: usize> .enumerate() { builder.verify_proof::( - recursive_proof, + &recursive_proof, &verifier_data_target, &verifier_data[i].common, ); @@ -332,7 +332,7 @@ pub fn all_verifier_data_recursive_stark_proof< where [(); CpuStark::::COLUMNS]:, [(); KeccakStark::::COLUMNS]:, - [(); KeccakMemoryStark::::COLUMNS]:, + [(); KeccakSpongeStark::::COLUMNS]:, [(); LogicStark::::COLUMNS]:, [(); MemoryStark::::COLUMNS]:, [(); C::Hasher::HASH_SIZE]:, @@ -356,9 +356,9 @@ where circuit_config, ), verifier_data_recursive_stark_proof( - Table::KeccakMemory, - all_stark.keccak_memory_stark, - degree_bits[Table::KeccakMemory as usize], + Table::KeccakSponge, + all_stark.keccak_sponge_stark, + 
degree_bits[Table::KeccakSponge as usize], &all_stark.cross_table_lookups, inner_config, circuit_config, @@ -534,10 +534,10 @@ pub fn add_virtual_all_proof, const D: usize>( ), add_virtual_stark_proof( builder, - &all_stark.keccak_memory_stark, + &all_stark.keccak_sponge_stark, config, - degree_bits[Table::KeccakMemory as usize], - nums_ctl_zs[Table::KeccakMemory as usize], + degree_bits[Table::KeccakSponge as usize], + nums_ctl_zs[Table::KeccakSponge as usize], ), add_virtual_stark_proof( builder, @@ -853,7 +853,7 @@ pub(crate) mod tests { use crate::cpu::cpu_stark::CpuStark; use crate::cross_table_lookup::{CrossTableLookup, CtlCheckVarsTarget}; use crate::keccak::keccak_stark::KeccakStark; - use crate::keccak_memory::keccak_memory_stark::KeccakMemoryStark; + use crate::keccak_sponge::keccak_sponge_stark::KeccakSpongeStark; use crate::logic::LogicStark; use crate::memory::memory_stark::MemoryStark; use crate::permutation::{GrandProductChallenge, GrandProductChallengeSet}; @@ -866,6 +866,7 @@ pub(crate) mod tests { /// Recursively verify a Stark proof. /// Outputs the recursive proof and the associated verifier data. + #[allow(unused)] // TODO: used later? fn recursively_verify_stark_proof< F: RichField + Extendable, C: GenericConfig, @@ -965,6 +966,7 @@ pub(crate) mod tests { } /// Recursively verify every Stark proof in an `AllProof`. + #[allow(unused)] // TODO: used later? pub fn recursively_verify_all_proof< F: RichField + Extendable, C: GenericConfig, @@ -978,7 +980,7 @@ pub(crate) mod tests { where [(); CpuStark::::COLUMNS]:, [(); KeccakStark::::COLUMNS]:, - [(); KeccakMemoryStark::::COLUMNS]:, + [(); KeccakSpongeStark::::COLUMNS]:, [(); LogicStark::::COLUMNS]:, [(); MemoryStark::::COLUMNS]:, [(); C::Hasher::HASH_SIZE]:, @@ -1013,9 +1015,9 @@ pub(crate) mod tests { )? 
.0, recursively_verify_stark_proof( - Table::KeccakMemory, - all_stark.keccak_memory_stark, - &all_proof.stark_proofs[Table::KeccakMemory as usize], + Table::KeccakSponge, + all_stark.keccak_sponge_stark, + &all_proof.stark_proofs[Table::KeccakSponge as usize], &all_stark.cross_table_lookups, &ctl_challenges, states[2], diff --git a/evm/src/util.rs b/evm/src/util.rs index 7f958fd2..fb3f1f13 100644 --- a/evm/src/util.rs +++ b/evm/src/util.rs @@ -2,6 +2,7 @@ use std::mem::{size_of, transmute_copy, ManuallyDrop}; use ethereum_types::{H160, H256, U256}; use itertools::Itertools; +use num::BigUint; use plonky2::field::extension::Extendable; use plonky2::field::packed::PackedField; use plonky2::field::polynomial::PolynomialValues; @@ -44,6 +45,7 @@ pub fn trace_rows_to_poly_values( .collect() } +#[allow(unused)] // TODO: Remove? /// Returns the 32-bit little-endian limbs of a `U256`. pub(crate) fn u256_limbs(u256: U256) -> [F; 8] { u256.0 @@ -98,3 +100,55 @@ pub(crate) unsafe fn transmute_no_compile_time_size_checks(value: T) -> U // Copy the bit pattern. The original value is no longer safe to use. 
transmute_copy(&value) } + +pub(crate) fn addmod(x: U256, y: U256, m: U256) -> U256 { + if m.is_zero() { + return m; + } + let x = u256_to_biguint(x); + let y = u256_to_biguint(y); + let m = u256_to_biguint(m); + biguint_to_u256((x + y) % m) +} + +pub(crate) fn mulmod(x: U256, y: U256, m: U256) -> U256 { + if m.is_zero() { + return m; + } + let x = u256_to_biguint(x); + let y = u256_to_biguint(y); + let m = u256_to_biguint(m); + biguint_to_u256(x * y % m) +} + +pub(crate) fn submod(x: U256, y: U256, m: U256) -> U256 { + if m.is_zero() { + return m; + } + let mut x = u256_to_biguint(x); + let y = u256_to_biguint(y); + let m = u256_to_biguint(m); + while x < y { + x += &m; + } + biguint_to_u256((x - y) % m) +} + +pub(crate) fn u256_to_biguint(x: U256) -> BigUint { + let mut bytes = [0u8; 32]; + x.to_little_endian(&mut bytes); + BigUint::from_bytes_le(&bytes) +} + +pub(crate) fn biguint_to_u256(x: BigUint) -> U256 { + let bytes = x.to_bytes_le(); + U256::from_little_endian(&bytes) +} + +pub(crate) fn u256_saturating_cast_usize(x: U256) -> usize { + if x > usize::MAX.into() { + usize::MAX + } else { + x.as_usize() + } +} diff --git a/evm/src/verifier.rs b/evm/src/verifier.rs index ce15399a..a0329d04 100644 --- a/evm/src/verifier.rs +++ b/evm/src/verifier.rs @@ -1,3 +1,5 @@ +use std::any::type_name; + use anyhow::{ensure, Result}; use plonky2::field::extension::{Extendable, FieldExtension}; use plonky2::field::types::Field; @@ -12,7 +14,7 @@ use crate::constraint_consumer::ConstraintConsumer; use crate::cpu::cpu_stark::CpuStark; use crate::cross_table_lookup::{verify_cross_table_lookups, CtlCheckVars}; use crate::keccak::keccak_stark::KeccakStark; -use crate::keccak_memory::keccak_memory_stark::KeccakMemoryStark; +use crate::keccak_sponge::keccak_sponge_stark::KeccakSpongeStark; use crate::logic::LogicStark; use crate::memory::memory_stark::MemoryStark; use crate::permutation::PermutationCheckVars; @@ -31,7 +33,7 @@ pub fn verify_proof, C: GenericConfig, co where [(); 
CpuStark::::COLUMNS]:, [(); KeccakStark::::COLUMNS]:, - [(); KeccakMemoryStark::::COLUMNS]:, + [(); KeccakSpongeStark::::COLUMNS]:, [(); LogicStark::::COLUMNS]:, [(); MemoryStark::::COLUMNS]:, [(); C::Hasher::HASH_SIZE]:, @@ -46,7 +48,7 @@ where let AllStark { cpu_stark, keccak_stark, - keccak_memory_stark, + keccak_sponge_stark, logic_stark, memory_stark, cross_table_lookups, @@ -74,10 +76,10 @@ where config, )?; verify_stark_proof_with_challenges( - keccak_memory_stark, - &all_proof.stark_proofs[Table::KeccakMemory as usize], - &stark_challenges[Table::KeccakMemory as usize], - &ctl_vars_per_table[Table::KeccakMemory as usize], + keccak_sponge_stark, + &all_proof.stark_proofs[Table::KeccakSponge as usize], + &stark_challenges[Table::KeccakSponge as usize], + &ctl_vars_per_table[Table::KeccakSponge as usize], config, )?; verify_stark_proof_with_challenges( @@ -122,6 +124,7 @@ where [(); S::COLUMNS]:, [(); C::Hasher::HASH_SIZE]:, { + log::debug!("Checking proof: {}", type_name::()); validate_proof_shape(&stark, proof, config, ctl_vars.len())?; let StarkOpeningSet { local_values, diff --git a/evm/src/witness/errors.rs b/evm/src/witness/errors.rs new file mode 100644 index 00000000..bd4b03c9 --- /dev/null +++ b/evm/src/witness/errors.rs @@ -0,0 +1,10 @@ +#[allow(dead_code)] +#[derive(Debug)] +pub enum ProgramError { + OutOfGas, + InvalidOpcode, + StackUnderflow, + InvalidJumpDestination, + InvalidJumpiDestination, + StackOverflow, +} diff --git a/evm/src/witness/mem_tx.rs b/evm/src/witness/mem_tx.rs new file mode 100644 index 00000000..7cc33653 --- /dev/null +++ b/evm/src/witness/mem_tx.rs @@ -0,0 +1,12 @@ +use crate::witness::memory::{MemoryOp, MemoryOpKind, MemoryState}; + +pub fn apply_mem_ops(state: &mut MemoryState, mut ops: Vec) { + ops.sort_unstable_by_key(|mem_op| mem_op.timestamp); + + for op in ops { + let MemoryOp { address, op, .. 
} = op; + if let MemoryOpKind::Write(val) = op { + state.set(address, val); + } + } +} diff --git a/evm/src/witness/memory.rs b/evm/src/witness/memory.rs new file mode 100644 index 00000000..5e3c7bcf --- /dev/null +++ b/evm/src/witness/memory.rs @@ -0,0 +1,162 @@ +use ethereum_types::U256; + +use crate::cpu::membus::{NUM_CHANNELS, NUM_GP_CHANNELS}; + +#[derive(Clone, Copy, Debug)] +pub enum MemoryChannel { + Code, + GeneralPurpose(usize), +} + +use MemoryChannel::{Code, GeneralPurpose}; + +use crate::memory::segments::Segment; +use crate::util::u256_saturating_cast_usize; + +impl MemoryChannel { + pub fn index(&self) -> usize { + match *self { + Code => 0, + GeneralPurpose(n) => { + assert!(n < NUM_GP_CHANNELS); + n + 1 + } + } + } +} + +#[derive(Clone, Copy, Debug, Eq, PartialEq, Hash)] +pub struct MemoryAddress { + pub(crate) context: usize, + pub(crate) segment: usize, + pub(crate) virt: usize, +} + +impl MemoryAddress { + pub(crate) fn new(context: usize, segment: Segment, virt: usize) -> Self { + Self { + context, + segment: segment as usize, + virt, + } + } + + pub(crate) fn new_u256s(context: U256, segment: U256, virt: U256) -> Self { + Self { + context: u256_saturating_cast_usize(context), + segment: u256_saturating_cast_usize(segment), + virt: u256_saturating_cast_usize(virt), + } + } + + pub(crate) fn increment(&mut self) { + self.virt = self.virt.saturating_add(1); + } +} + +#[derive(Clone, Copy, Debug, Eq, PartialEq)] +pub enum MemoryOpKind { + Read, + Write, +} + +#[derive(Clone, Copy, Debug)] +pub struct MemoryOp { + /// true if this is an actual memory operation, or false if it's a padding row. 
+ pub filter: bool, + pub timestamp: usize, + pub address: MemoryAddress, + pub kind: MemoryOpKind, + pub value: U256, +} + +impl MemoryOp { + pub fn new( + channel: MemoryChannel, + clock: usize, + address: MemoryAddress, + kind: MemoryOpKind, + value: U256, + ) -> Self { + let timestamp = clock * NUM_CHANNELS + channel.index(); + MemoryOp { + filter: true, + timestamp, + address, + kind, + value, + } + } +} + +#[derive(Clone, Debug)] +pub struct MemoryState { + pub(crate) contexts: Vec, +} + +impl MemoryState { + pub fn new(kernel_code: &[u8]) -> Self { + let code_u256s = kernel_code.iter().map(|&x| x.into()).collect(); + let mut result = Self::default(); + result.contexts[0].segments[Segment::Code as usize].content = code_u256s; + result + } + + pub fn apply_ops(&mut self, ops: &[MemoryOp]) { + for &op in ops { + let MemoryOp { + address, + kind, + value, + .. + } = op; + if kind == MemoryOpKind::Write { + self.set(address, value); + } + } + } + + pub fn get(&self, address: MemoryAddress) -> U256 { + self.contexts[address.context].segments[address.segment].get(address.virt) + } + + pub fn set(&mut self, address: MemoryAddress, val: U256) { + self.contexts[address.context].segments[address.segment].set(address.virt, val); + } +} + +impl Default for MemoryState { + fn default() -> Self { + Self { + // We start with an initial context for the kernel. + contexts: vec![MemoryContextState::default()], + } + } +} + +#[derive(Clone, Default, Debug)] +pub(crate) struct MemoryContextState { + /// The content of each memory segment. 
+ pub(crate) segments: [MemorySegmentState; Segment::COUNT], +} + +#[derive(Clone, Default, Debug)] +pub(crate) struct MemorySegmentState { + pub(crate) content: Vec, +} + +impl MemorySegmentState { + pub(crate) fn get(&self, virtual_addr: usize) -> U256 { + self.content + .get(virtual_addr) + .copied() + .unwrap_or(U256::zero()) + } + + pub(crate) fn set(&mut self, virtual_addr: usize, value: U256) { + if virtual_addr >= self.content.len() { + self.content.resize(virtual_addr + 1, U256::zero()); + } + self.content[virtual_addr] = value; + } +} diff --git a/evm/src/witness/mod.rs b/evm/src/witness/mod.rs new file mode 100644 index 00000000..b9da345e --- /dev/null +++ b/evm/src/witness/mod.rs @@ -0,0 +1,7 @@ +mod errors; +pub(crate) mod memory; +mod operation; +pub(crate) mod state; +pub(crate) mod traces; +pub mod transition; +pub(crate) mod util; diff --git a/evm/src/witness/operation.rs b/evm/src/witness/operation.rs new file mode 100644 index 00000000..6d65f16c --- /dev/null +++ b/evm/src/witness/operation.rs @@ -0,0 +1,507 @@ +use ethereum_types::{BigEndianHash, U256}; +use itertools::Itertools; +use keccak_hash::keccak; +use plonky2::field::types::Field; + +use crate::cpu::columns::CpuColumnsView; +use crate::cpu::kernel::aggregator::KERNEL; +use crate::cpu::membus::NUM_GP_CHANNELS; +use crate::cpu::simple_logic::eq_iszero::generate_pinv_diff; +use crate::generation::state::GenerationState; +use crate::memory::segments::Segment; +use crate::util::u256_saturating_cast_usize; +use crate::witness::errors::ProgramError; +use crate::witness::memory::MemoryAddress; +use crate::witness::util::{ + keccak_sponge_log, mem_read_code_with_log_and_fill, mem_read_gp_with_log_and_fill, + mem_write_gp_log_and_fill, stack_pop_with_log_and_fill, stack_push_log_and_fill, +}; +use crate::{arithmetic, logic}; + +#[derive(Clone, Copy, Debug, Eq, PartialEq)] +pub(crate) enum Operation { + Push(u8), + Dup(u8), + Swap(u8), + Iszero, + Not, + Byte, + Syscall(u8), + Eq, + 
BinaryLogic(logic::Op), + BinaryArithmetic(arithmetic::BinaryOperator), + TernaryArithmetic(arithmetic::TernaryOperator), + KeccakGeneral, + ProverInput, + Pop, + Jump, + Jumpi, + Pc, + Gas, + Jumpdest, + GetContext, + SetContext, + ConsumeGas, + ExitKernel, + MloadGeneral, + MstoreGeneral, +} + +pub(crate) fn generate_binary_logic_op( + op: logic::Op, + state: &mut GenerationState, + mut row: CpuColumnsView, +) -> Result<(), ProgramError> { + let [(in0, log_in0), (in1, log_in1)] = stack_pop_with_log_and_fill::<2, _>(state, &mut row)?; + let operation = logic::Operation::new(op, in0, in1); + let log_out = stack_push_log_and_fill(state, &mut row, operation.result)?; + + state.traces.push_logic(operation); + state.traces.push_memory(log_in0); + state.traces.push_memory(log_in1); + state.traces.push_memory(log_out); + state.traces.push_cpu(row); + Ok(()) +} + +pub(crate) fn generate_binary_arithmetic_op( + operator: arithmetic::BinaryOperator, + state: &mut GenerationState, + mut row: CpuColumnsView, +) -> Result<(), ProgramError> { + let [(input0, log_in0), (input1, log_in1)] = + stack_pop_with_log_and_fill::<2, _>(state, &mut row)?; + let operation = arithmetic::Operation::binary(operator, input0, input1); + + let log_out = stack_push_log_and_fill(state, &mut row, operation.result())?; + + if operator == arithmetic::BinaryOperator::Shl || operator == arithmetic::BinaryOperator::Shr { + const LOOKUP_CHANNEL: usize = 2; + let lookup_addr = MemoryAddress::new(0, Segment::ShiftTable, input0.low_u32() as usize); + if input0.bits() <= 32 { + let (_, read) = + mem_read_gp_with_log_and_fill(LOOKUP_CHANNEL, lookup_addr, state, &mut row); + state.traces.push_memory(read); + } else { + // The shift constraints still expect the address to be set, even though no read will occur. 
+ let mut channel = &mut row.mem_channels[LOOKUP_CHANNEL]; + channel.addr_context = F::from_canonical_usize(lookup_addr.context); + channel.addr_segment = F::from_canonical_usize(lookup_addr.segment); + channel.addr_virtual = F::from_canonical_usize(lookup_addr.virt); + } + } + + state.traces.push_arithmetic(operation); + state.traces.push_memory(log_in0); + state.traces.push_memory(log_in1); + state.traces.push_memory(log_out); + state.traces.push_cpu(row); + Ok(()) +} + +pub(crate) fn generate_ternary_arithmetic_op( + operator: arithmetic::TernaryOperator, + state: &mut GenerationState, + mut row: CpuColumnsView, +) -> Result<(), ProgramError> { + let [(input0, log_in0), (input1, log_in1), (input2, log_in2)] = + stack_pop_with_log_and_fill::<3, _>(state, &mut row)?; + let operation = arithmetic::Operation::ternary(operator, input0, input1, input2); + let log_out = stack_push_log_and_fill(state, &mut row, operation.result())?; + + state.traces.push_arithmetic(operation); + state.traces.push_memory(log_in0); + state.traces.push_memory(log_in1); + state.traces.push_memory(log_in2); + state.traces.push_memory(log_out); + state.traces.push_cpu(row); + Ok(()) +} + +pub(crate) fn generate_keccak_general( + state: &mut GenerationState, + mut row: CpuColumnsView, +) -> Result<(), ProgramError> { + row.is_keccak_sponge = F::ONE; + let [(context, log_in0), (segment, log_in1), (base_virt, log_in2), (len, log_in3)] = + stack_pop_with_log_and_fill::<4, _>(state, &mut row)?; + let len = len.as_usize(); + + let base_address = MemoryAddress::new_u256s(context, segment, base_virt); + let input = (0..len) + .map(|i| { + let address = MemoryAddress { + virt: base_address.virt.saturating_add(i), + ..base_address + }; + let val = state.memory.get(address); + val.as_u32() as u8 + }) + .collect_vec(); + log::debug!("Hashing {:?}", input); + + let hash = keccak(&input); + let log_push = stack_push_log_and_fill(state, &mut row, hash.into_uint())?; + + keccak_sponge_log(state, 
base_address, input); + + state.traces.push_memory(log_in0); + state.traces.push_memory(log_in1); + state.traces.push_memory(log_in2); + state.traces.push_memory(log_in3); + state.traces.push_memory(log_push); + state.traces.push_cpu(row); + Ok(()) +} + +pub(crate) fn generate_prover_input( + state: &mut GenerationState, + mut row: CpuColumnsView, +) -> Result<(), ProgramError> { + let pc = state.registers.program_counter; + let input_fn = &KERNEL.prover_inputs[&pc]; + let input = state.prover_input(input_fn); + let write = stack_push_log_and_fill(state, &mut row, input)?; + + state.traces.push_memory(write); + state.traces.push_cpu(row); + Ok(()) +} + +pub(crate) fn generate_pop( + state: &mut GenerationState, + row: CpuColumnsView, +) -> Result<(), ProgramError> { + if state.registers.stack_len == 0 { + return Err(ProgramError::StackUnderflow); + } + + state.registers.stack_len -= 1; + state.traces.push_cpu(row); + Ok(()) +} + +pub(crate) fn generate_jump( + state: &mut GenerationState, + mut row: CpuColumnsView, +) -> Result<(), ProgramError> { + let [(dst, log_in0)] = stack_pop_with_log_and_fill::<1, _>(state, &mut row)?; + + state.traces.push_memory(log_in0); + state.traces.push_cpu(row); + state.registers.program_counter = u256_saturating_cast_usize(dst); + // TODO: Set other cols like input0_upper_sum_inv. + Ok(()) +} + +pub(crate) fn generate_jumpi( + state: &mut GenerationState, + mut row: CpuColumnsView, +) -> Result<(), ProgramError> { + let [(dst, log_in0), (cond, log_in1)] = stack_pop_with_log_and_fill::<2, _>(state, &mut row)?; + + state.traces.push_memory(log_in0); + state.traces.push_memory(log_in1); + state.traces.push_cpu(row); + state.registers.program_counter = if cond.is_zero() { + state.registers.program_counter + 1 + } else { + u256_saturating_cast_usize(dst) + }; + // TODO: Set other cols like input0_upper_sum_inv. 
+ Ok(()) +} + +pub(crate) fn generate_push( + n: u8, + state: &mut GenerationState, + mut row: CpuColumnsView, +) -> Result<(), ProgramError> { + let context = state.registers.effective_context(); + let num_bytes = n as usize + 1; + let initial_offset = state.registers.program_counter + 1; + let offsets = initial_offset..initial_offset + num_bytes; + let mut addrs = offsets.map(|offset| MemoryAddress::new(context, Segment::Code, offset)); + + // First read val without going through `mem_read_with_log` type methods, so we can pass it + // to stack_push_log_and_fill. + let bytes = (0..num_bytes) + .map(|i| { + state + .memory + .get(MemoryAddress::new( + context, + Segment::Code, + initial_offset + i, + )) + .as_u32() as u8 + }) + .collect_vec(); + + let val = U256::from_big_endian(&bytes); + let write = stack_push_log_and_fill(state, &mut row, val)?; + + // In the first cycle, we read up to NUM_GP_CHANNELS - 1 bytes, leaving the last GP channel + // to push the result. + for (i, addr) in (&mut addrs).take(NUM_GP_CHANNELS - 1).enumerate() { + let (_, read) = mem_read_gp_with_log_and_fill(i, addr, state, &mut row); + state.traces.push_memory(read); + } + state.traces.push_memory(write); + state.traces.push_cpu(row); + + // In any subsequent cycles, we read up to 1 + NUM_GP_CHANNELS bytes. 
+ for mut addrs_chunk in &addrs.chunks(1 + NUM_GP_CHANNELS) { + let mut row = CpuColumnsView::default(); + row.is_cpu_cycle = F::ONE; + row.op.push = F::ONE; + + let first_addr = addrs_chunk.next().unwrap(); + let (_, first_read) = mem_read_code_with_log_and_fill(first_addr, state, &mut row); + state.traces.push_memory(first_read); + + for (i, addr) in addrs_chunk.enumerate() { + let (_, read) = mem_read_gp_with_log_and_fill(i, addr, state, &mut row); + state.traces.push_memory(read); + } + + state.traces.push_cpu(row); + } + + Ok(()) +} + +pub(crate) fn generate_dup( + n: u8, + state: &mut GenerationState, + mut row: CpuColumnsView, +) -> Result<(), ProgramError> { + let other_addr_lo = state + .registers + .stack_len + .checked_sub(1 + (n as usize)) + .ok_or(ProgramError::StackUnderflow)?; + let other_addr = MemoryAddress::new( + state.registers.effective_context(), + Segment::Stack, + other_addr_lo, + ); + + let (val, log_in) = mem_read_gp_with_log_and_fill(0, other_addr, state, &mut row); + let log_out = stack_push_log_and_fill(state, &mut row, val)?; + + state.traces.push_memory(log_in); + state.traces.push_memory(log_out); + state.traces.push_cpu(row); + Ok(()) +} + +pub(crate) fn generate_swap( + n: u8, + state: &mut GenerationState, + mut row: CpuColumnsView, +) -> Result<(), ProgramError> { + let other_addr_lo = state + .registers + .stack_len + .checked_sub(2 + (n as usize)) + .ok_or(ProgramError::StackUnderflow)?; + let other_addr = MemoryAddress::new( + state.registers.effective_context(), + Segment::Stack, + other_addr_lo, + ); + + let [(in0, log_in0)] = stack_pop_with_log_and_fill::<1, _>(state, &mut row)?; + let (in1, log_in1) = mem_read_gp_with_log_and_fill(1, other_addr, state, &mut row); + let log_out0 = mem_write_gp_log_and_fill(NUM_GP_CHANNELS - 2, other_addr, state, &mut row, in0); + let log_out1 = stack_push_log_and_fill(state, &mut row, in1)?; + + state.traces.push_memory(log_in0); + state.traces.push_memory(log_in1); + 
state.traces.push_memory(log_out0); + state.traces.push_memory(log_out1); + state.traces.push_cpu(row); + Ok(()) +} + +pub(crate) fn generate_not( + state: &mut GenerationState, + mut row: CpuColumnsView, +) -> Result<(), ProgramError> { + let [(x, log_in)] = stack_pop_with_log_and_fill::<1, _>(state, &mut row)?; + let result = !x; + let log_out = stack_push_log_and_fill(state, &mut row, result)?; + + state.traces.push_memory(log_in); + state.traces.push_memory(log_out); + state.traces.push_cpu(row); + Ok(()) +} + +pub(crate) fn generate_byte( + state: &mut GenerationState, + mut row: CpuColumnsView, +) -> Result<(), ProgramError> { + let [(i, log_in0), (x, log_in1)] = stack_pop_with_log_and_fill::<2, _>(state, &mut row)?; + + let byte = if i < 32.into() { + // byte(i) is the i'th little-endian byte; we want the i'th big-endian byte. + x.byte(31 - i.as_usize()) + } else { + 0 + }; + let log_out = stack_push_log_and_fill(state, &mut row, byte.into())?; + + state.traces.push_memory(log_in0); + state.traces.push_memory(log_in1); + state.traces.push_memory(log_out); + state.traces.push_cpu(row); + Ok(()) +} + +pub(crate) fn generate_iszero( + state: &mut GenerationState, + mut row: CpuColumnsView, +) -> Result<(), ProgramError> { + let [(x, log_in)] = stack_pop_with_log_and_fill::<1, _>(state, &mut row)?; + let is_zero = x.is_zero(); + let result = { + let t: u64 = is_zero.into(); + t.into() + }; + let log_out = stack_push_log_and_fill(state, &mut row, result)?; + + generate_pinv_diff(x, U256::zero(), &mut row); + + state.traces.push_memory(log_in); + state.traces.push_memory(log_out); + state.traces.push_cpu(row); + Ok(()) +} + +pub(crate) fn generate_syscall( + opcode: u8, + state: &mut GenerationState, + mut row: CpuColumnsView, +) -> Result<(), ProgramError> { + let handler_jumptable_addr = KERNEL.global_labels["syscall_jumptable"]; + let handler_addr_addr = handler_jumptable_addr + (opcode as usize); + let (handler_addr0, log_in0) = mem_read_gp_with_log_and_fill( 
+ 0, + MemoryAddress::new(0, Segment::Code, handler_addr_addr), + state, + &mut row, + ); + let (handler_addr1, log_in1) = mem_read_gp_with_log_and_fill( + 1, + MemoryAddress::new(0, Segment::Code, handler_addr_addr + 1), + state, + &mut row, + ); + let (handler_addr2, log_in2) = mem_read_gp_with_log_and_fill( + 2, + MemoryAddress::new(0, Segment::Code, handler_addr_addr + 2), + state, + &mut row, + ); + + let handler_addr = (handler_addr0 << 16) + (handler_addr1 << 8) + handler_addr2; + let new_program_counter = handler_addr.as_usize(); + + let syscall_info = U256::from(state.registers.program_counter) + + (U256::from(u64::from(state.registers.is_kernel)) << 32); + let log_out = stack_push_log_and_fill(state, &mut row, syscall_info)?; + + state.registers.program_counter = new_program_counter; + state.registers.is_kernel = true; + + state.traces.push_memory(log_in0); + state.traces.push_memory(log_in1); + state.traces.push_memory(log_in2); + state.traces.push_memory(log_out); + state.traces.push_cpu(row); + + Ok(()) +} + +pub(crate) fn generate_eq( + state: &mut GenerationState, + mut row: CpuColumnsView, +) -> Result<(), ProgramError> { + let [(in0, log_in0), (in1, log_in1)] = stack_pop_with_log_and_fill::<2, _>(state, &mut row)?; + let eq = in0 == in1; + let result = U256::from(u64::from(eq)); + let log_out = stack_push_log_and_fill(state, &mut row, result)?; + + generate_pinv_diff(in0, in1, &mut row); + + state.traces.push_memory(log_in0); + state.traces.push_memory(log_in1); + state.traces.push_memory(log_out); + state.traces.push_cpu(row); + Ok(()) +} + +pub(crate) fn generate_exit_kernel( + state: &mut GenerationState, + mut row: CpuColumnsView, +) -> Result<(), ProgramError> { + let [(kexit_info, log_in)] = stack_pop_with_log_and_fill::<1, _>(state, &mut row)?; + let kexit_info_u64: [u64; 4] = kexit_info.0; + let program_counter = kexit_info_u64[0] as usize; + let is_kernel_mode_val = (kexit_info_u64[1] >> 32) as u32; + assert!(is_kernel_mode_val == 0 || 
is_kernel_mode_val == 1); + let is_kernel_mode = is_kernel_mode_val != 0; + + state.registers.program_counter = program_counter; + state.registers.is_kernel = is_kernel_mode; + + state.traces.push_memory(log_in); + state.traces.push_cpu(row); + + Ok(()) +} + +pub(crate) fn generate_mload_general( + state: &mut GenerationState, + mut row: CpuColumnsView, +) -> Result<(), ProgramError> { + let [(context, log_in0), (segment, log_in1), (virt, log_in2)] = + stack_pop_with_log_and_fill::<3, _>(state, &mut row)?; + + let val = state + .memory + .get(MemoryAddress::new_u256s(context, segment, virt)); + let log_out = stack_push_log_and_fill(state, &mut row, val)?; + + state.traces.push_memory(log_in0); + state.traces.push_memory(log_in1); + state.traces.push_memory(log_in2); + state.traces.push_memory(log_out); + state.traces.push_cpu(row); + Ok(()) +} + +pub(crate) fn generate_mstore_general( + state: &mut GenerationState, + mut row: CpuColumnsView, +) -> Result<(), ProgramError> { + let [(context, log_in0), (segment, log_in1), (virt, log_in2), (val, log_in3)] = + stack_pop_with_log_and_fill::<4, _>(state, &mut row)?; + + let address = MemoryAddress { + context: context.as_usize(), + segment: segment.as_usize(), + virt: virt.as_usize(), + }; + let log_write = mem_write_gp_log_and_fill(4, address, state, &mut row, val); + + state.traces.push_memory(log_in0); + state.traces.push_memory(log_in1); + state.traces.push_memory(log_in2); + state.traces.push_memory(log_in3); + state.traces.push_memory(log_write); + state.traces.push_cpu(row); + Ok(()) +} diff --git a/evm/src/witness/state.rs b/evm/src/witness/state.rs new file mode 100644 index 00000000..112b08af --- /dev/null +++ b/evm/src/witness/state.rs @@ -0,0 +1,32 @@ +use crate::cpu::kernel::aggregator::KERNEL; + +const KERNEL_CONTEXT: usize = 0; + +#[derive(Clone, Copy, Debug, Eq, PartialEq)] +pub struct RegistersState { + pub program_counter: usize, + pub is_kernel: bool, + pub stack_len: usize, + pub context: usize, +} + 
+impl RegistersState { + pub(crate) fn effective_context(&self) -> usize { + if self.is_kernel { + KERNEL_CONTEXT + } else { + self.context + } + } +} + +impl Default for RegistersState { + fn default() -> Self { + Self { + program_counter: KERNEL.global_labels["main"], + is_kernel: true, + stack_len: 0, + context: 0, + } + } +} diff --git a/evm/src/witness/traces.rs b/evm/src/witness/traces.rs new file mode 100644 index 00000000..41b654fb --- /dev/null +++ b/evm/src/witness/traces.rs @@ -0,0 +1,161 @@ +use std::mem::size_of; + +use itertools::Itertools; +use plonky2::field::extension::Extendable; +use plonky2::field::polynomial::PolynomialValues; +use plonky2::hash::hash_types::RichField; +use plonky2::util::timing::TimingTree; + +use crate::all_stark::{AllStark, NUM_TABLES}; +use crate::config::StarkConfig; +use crate::cpu::columns::CpuColumnsView; +use crate::keccak_sponge::columns::KECCAK_WIDTH_BYTES; +use crate::keccak_sponge::keccak_sponge_stark::KeccakSpongeOp; +use crate::util::trace_rows_to_poly_values; +use crate::witness::memory::MemoryOp; +use crate::{arithmetic, keccak, logic}; + +#[derive(Clone, Copy, Debug)] +pub struct TraceCheckpoint { + pub(self) cpu_len: usize, + pub(self) keccak_len: usize, + pub(self) keccak_sponge_len: usize, + pub(self) logic_len: usize, + pub(self) arithmetic_len: usize, + pub(self) memory_len: usize, +} + +#[derive(Debug)] +pub(crate) struct Traces { + pub(crate) cpu: Vec>, + pub(crate) logic_ops: Vec, + pub(crate) arithmetic: Vec, + pub(crate) memory_ops: Vec, + pub(crate) keccak_inputs: Vec<[u64; keccak::keccak_stark::NUM_INPUTS]>, + pub(crate) keccak_sponge_ops: Vec, +} + +impl Traces { + pub fn new() -> Self { + Traces { + cpu: vec![], + logic_ops: vec![], + arithmetic: vec![], + memory_ops: vec![], + keccak_inputs: vec![], + keccak_sponge_ops: vec![], + } + } + + pub fn checkpoint(&self) -> TraceCheckpoint { + TraceCheckpoint { + cpu_len: self.cpu.len(), + keccak_len: self.keccak_inputs.len(), + keccak_sponge_len: 
self.keccak_sponge_ops.len(), + logic_len: self.logic_ops.len(), + arithmetic_len: self.arithmetic.len(), + memory_len: self.memory_ops.len(), + } + } + + pub fn rollback(&mut self, checkpoint: TraceCheckpoint) { + self.cpu.truncate(checkpoint.cpu_len); + self.keccak_inputs.truncate(checkpoint.keccak_len); + self.keccak_sponge_ops + .truncate(checkpoint.keccak_sponge_len); + self.logic_ops.truncate(checkpoint.logic_len); + self.arithmetic.truncate(checkpoint.arithmetic_len); + self.memory_ops.truncate(checkpoint.memory_len); + } + + pub fn mem_ops_since(&self, checkpoint: TraceCheckpoint) -> &[MemoryOp] { + &self.memory_ops[checkpoint.memory_len..] + } + + pub fn push_cpu(&mut self, val: CpuColumnsView) { + self.cpu.push(val); + } + + pub fn push_logic(&mut self, op: logic::Operation) { + self.logic_ops.push(op); + } + + pub fn push_arithmetic(&mut self, op: arithmetic::Operation) { + self.arithmetic.push(op); + } + + pub fn push_memory(&mut self, op: MemoryOp) { + self.memory_ops.push(op); + } + + pub fn push_keccak(&mut self, input: [u64; keccak::keccak_stark::NUM_INPUTS]) { + self.keccak_inputs.push(input); + } + + pub fn push_keccak_bytes(&mut self, input: [u8; KECCAK_WIDTH_BYTES]) { + let chunks = input + .chunks(size_of::()) + .map(|chunk| u64::from_le_bytes(chunk.try_into().unwrap())) + .collect_vec() + .try_into() + .unwrap(); + self.push_keccak(chunks); + } + + pub fn push_keccak_sponge(&mut self, op: KeccakSpongeOp) { + self.keccak_sponge_ops.push(op); + } + + pub fn clock(&self) -> usize { + self.cpu.len() + } + + pub fn into_tables( + self, + all_stark: &AllStark, + config: &StarkConfig, + timing: &mut TimingTree, + ) -> [Vec>; NUM_TABLES] + where + T: RichField + Extendable, + { + let cap_elements = config.fri_config.num_cap_elements(); + let Traces { + cpu, + logic_ops, + arithmetic: _, // TODO + memory_ops, + keccak_inputs, + keccak_sponge_ops, + } = self; + + let cpu_rows = cpu.into_iter().map(|x| x.into()).collect(); + let cpu_trace = 
trace_rows_to_poly_values(cpu_rows); + let keccak_trace = + all_stark + .keccak_stark + .generate_trace(keccak_inputs, cap_elements, timing); + let keccak_sponge_trace = + all_stark + .keccak_sponge_stark + .generate_trace(keccak_sponge_ops, cap_elements, timing); + let logic_trace = all_stark + .logic_stark + .generate_trace(logic_ops, cap_elements, timing); + let memory_trace = all_stark.memory_stark.generate_trace(memory_ops, timing); + + [ + cpu_trace, + keccak_trace, + keccak_sponge_trace, + logic_trace, + memory_trace, + ] + } +} + +impl Default for Traces { + fn default() -> Self { + Self::new() + } +} diff --git a/evm/src/witness/transition.rs b/evm/src/witness/transition.rs new file mode 100644 index 00000000..39aac810 --- /dev/null +++ b/evm/src/witness/transition.rs @@ -0,0 +1,279 @@ +use itertools::Itertools; +use plonky2::field::types::Field; + +use crate::cpu::columns::CpuColumnsView; +use crate::cpu::kernel::aggregator::KERNEL; +use crate::generation::state::GenerationState; +use crate::memory::segments::Segment; +use crate::witness::errors::ProgramError; +use crate::witness::memory::MemoryAddress; +use crate::witness::operation::*; +use crate::witness::state::RegistersState; +use crate::witness::util::{mem_read_code_with_log_and_fill, stack_peek}; +use crate::{arithmetic, logic}; + +fn read_code_memory(state: &mut GenerationState, row: &mut CpuColumnsView) -> u8 { + let code_context = state.registers.effective_context(); + row.code_context = F::from_canonical_usize(code_context); + + let address = MemoryAddress::new(code_context, Segment::Code, state.registers.program_counter); + let (opcode, mem_log) = mem_read_code_with_log_and_fill(address, state, row); + + state.traces.push_memory(mem_log); + + opcode +} + +fn decode(registers: RegistersState, opcode: u8) -> Result { + match (opcode, registers.is_kernel) { + (0x00, _) => Ok(Operation::Syscall(opcode)), + (0x01, _) => Ok(Operation::BinaryArithmetic(arithmetic::BinaryOperator::Add)), + (0x02, _) 
=> Ok(Operation::BinaryArithmetic(arithmetic::BinaryOperator::Mul)), + (0x03, _) => Ok(Operation::BinaryArithmetic(arithmetic::BinaryOperator::Sub)), + (0x04, _) => Ok(Operation::BinaryArithmetic(arithmetic::BinaryOperator::Div)), + (0x05, _) => Ok(Operation::Syscall(opcode)), + (0x06, _) => Ok(Operation::BinaryArithmetic(arithmetic::BinaryOperator::Mod)), + (0x07, _) => Ok(Operation::Syscall(opcode)), + (0x08, _) => Ok(Operation::TernaryArithmetic( + arithmetic::TernaryOperator::AddMod, + )), + (0x09, _) => Ok(Operation::TernaryArithmetic( + arithmetic::TernaryOperator::MulMod, + )), + (0x0a, _) => Ok(Operation::Syscall(opcode)), + (0x0b, _) => Ok(Operation::Syscall(opcode)), + (0x0c, true) => Ok(Operation::BinaryArithmetic( + arithmetic::BinaryOperator::AddFp254, + )), + (0x0d, true) => Ok(Operation::BinaryArithmetic( + arithmetic::BinaryOperator::MulFp254, + )), + (0x0e, true) => Ok(Operation::BinaryArithmetic( + arithmetic::BinaryOperator::SubFp254, + )), + (0x10, _) => Ok(Operation::BinaryArithmetic(arithmetic::BinaryOperator::Lt)), + (0x11, _) => Ok(Operation::BinaryArithmetic(arithmetic::BinaryOperator::Gt)), + (0x12, _) => Ok(Operation::Syscall(opcode)), + (0x13, _) => Ok(Operation::Syscall(opcode)), + (0x14, _) => Ok(Operation::Eq), + (0x15, _) => Ok(Operation::Iszero), + (0x16, _) => Ok(Operation::BinaryLogic(logic::Op::And)), + (0x17, _) => Ok(Operation::BinaryLogic(logic::Op::Or)), + (0x18, _) => Ok(Operation::BinaryLogic(logic::Op::Xor)), + (0x19, _) => Ok(Operation::Not), + (0x1a, _) => Ok(Operation::Byte), + (0x1b, _) => Ok(Operation::BinaryArithmetic(arithmetic::BinaryOperator::Shl)), + (0x1c, _) => Ok(Operation::BinaryArithmetic(arithmetic::BinaryOperator::Shr)), + (0x1d, _) => Ok(Operation::Syscall(opcode)), + (0x20, _) => Ok(Operation::Syscall(opcode)), + (0x21, true) => Ok(Operation::KeccakGeneral), + (0x30, _) => Ok(Operation::Syscall(opcode)), + (0x31, _) => Ok(Operation::Syscall(opcode)), + (0x32, _) => Ok(Operation::Syscall(opcode)), + 
(0x33, _) => Ok(Operation::Syscall(opcode)), + (0x34, _) => Ok(Operation::Syscall(opcode)), + (0x35, _) => Ok(Operation::Syscall(opcode)), + (0x36, _) => Ok(Operation::Syscall(opcode)), + (0x37, _) => Ok(Operation::Syscall(opcode)), + (0x38, _) => Ok(Operation::Syscall(opcode)), + (0x39, _) => Ok(Operation::Syscall(opcode)), + (0x3a, _) => Ok(Operation::Syscall(opcode)), + (0x3b, _) => Ok(Operation::Syscall(opcode)), + (0x3c, _) => Ok(Operation::Syscall(opcode)), + (0x3d, _) => Ok(Operation::Syscall(opcode)), + (0x3e, _) => Ok(Operation::Syscall(opcode)), + (0x3f, _) => Ok(Operation::Syscall(opcode)), + (0x40, _) => Ok(Operation::Syscall(opcode)), + (0x41, _) => Ok(Operation::Syscall(opcode)), + (0x42, _) => Ok(Operation::Syscall(opcode)), + (0x43, _) => Ok(Operation::Syscall(opcode)), + (0x44, _) => Ok(Operation::Syscall(opcode)), + (0x45, _) => Ok(Operation::Syscall(opcode)), + (0x46, _) => Ok(Operation::Syscall(opcode)), + (0x47, _) => Ok(Operation::Syscall(opcode)), + (0x48, _) => Ok(Operation::Syscall(opcode)), + (0x49, _) => Ok(Operation::ProverInput), + (0x50, _) => Ok(Operation::Pop), + (0x51, _) => Ok(Operation::Syscall(opcode)), + (0x52, _) => Ok(Operation::Syscall(opcode)), + (0x53, _) => Ok(Operation::Syscall(opcode)), + (0x54, _) => Ok(Operation::Syscall(opcode)), + (0x55, _) => Ok(Operation::Syscall(opcode)), + (0x56, _) => Ok(Operation::Jump), + (0x57, _) => Ok(Operation::Jumpi), + (0x58, _) => Ok(Operation::Pc), + (0x59, _) => Ok(Operation::Syscall(opcode)), + (0x5a, _) => Ok(Operation::Gas), + (0x5b, _) => Ok(Operation::Jumpdest), + (0x60..=0x7f, _) => Ok(Operation::Push(opcode & 0x1f)), + (0x80..=0x8f, _) => Ok(Operation::Dup(opcode & 0xf)), + (0x90..=0x9f, _) => Ok(Operation::Swap(opcode & 0xf)), + (0xa0, _) => Ok(Operation::Syscall(opcode)), + (0xa1, _) => Ok(Operation::Syscall(opcode)), + (0xa2, _) => Ok(Operation::Syscall(opcode)), + (0xa3, _) => Ok(Operation::Syscall(opcode)), + (0xa4, _) => Ok(Operation::Syscall(opcode)), + (0xf0, _) => 
Ok(Operation::Syscall(opcode)), + (0xf1, _) => Ok(Operation::Syscall(opcode)), + (0xf2, _) => Ok(Operation::Syscall(opcode)), + (0xf3, _) => Ok(Operation::Syscall(opcode)), + (0xf4, _) => Ok(Operation::Syscall(opcode)), + (0xf5, _) => Ok(Operation::Syscall(opcode)), + (0xf6, true) => Ok(Operation::GetContext), + (0xf7, true) => Ok(Operation::SetContext), + (0xf8, true) => Ok(Operation::ConsumeGas), + (0xf9, true) => Ok(Operation::ExitKernel), + (0xfa, _) => Ok(Operation::Syscall(opcode)), + (0xfb, true) => Ok(Operation::MloadGeneral), + (0xfc, true) => Ok(Operation::MstoreGeneral), + (0xfd, _) => Ok(Operation::Syscall(opcode)), + (0xff, _) => Ok(Operation::Syscall(opcode)), + _ => Err(ProgramError::InvalidOpcode), + } +} + +fn fill_op_flag(op: Operation, row: &mut CpuColumnsView) { + let flags = &mut row.op; + *match op { + Operation::Push(_) => &mut flags.push, + Operation::Dup(_) => &mut flags.dup, + Operation::Swap(_) => &mut flags.swap, + Operation::Iszero => &mut flags.iszero, + Operation::Not => &mut flags.not, + Operation::Byte => &mut flags.byte, + Operation::Syscall(_) => &mut flags.syscall, + Operation::Eq => &mut flags.eq, + Operation::BinaryLogic(logic::Op::And) => &mut flags.and, + Operation::BinaryLogic(logic::Op::Or) => &mut flags.or, + Operation::BinaryLogic(logic::Op::Xor) => &mut flags.xor, + Operation::BinaryArithmetic(arithmetic::BinaryOperator::Add) => &mut flags.add, + Operation::BinaryArithmetic(arithmetic::BinaryOperator::Mul) => &mut flags.mul, + Operation::BinaryArithmetic(arithmetic::BinaryOperator::Sub) => &mut flags.sub, + Operation::BinaryArithmetic(arithmetic::BinaryOperator::Div) => &mut flags.div, + Operation::BinaryArithmetic(arithmetic::BinaryOperator::Mod) => &mut flags.mod_, + Operation::BinaryArithmetic(arithmetic::BinaryOperator::Lt) => &mut flags.lt, + Operation::BinaryArithmetic(arithmetic::BinaryOperator::Gt) => &mut flags.gt, + Operation::BinaryArithmetic(arithmetic::BinaryOperator::Shl) => &mut flags.shl, + 
Operation::BinaryArithmetic(arithmetic::BinaryOperator::Shr) => &mut flags.shr, + Operation::BinaryArithmetic(arithmetic::BinaryOperator::AddFp254) => &mut flags.addfp254, + Operation::BinaryArithmetic(arithmetic::BinaryOperator::MulFp254) => &mut flags.mulfp254, + Operation::BinaryArithmetic(arithmetic::BinaryOperator::SubFp254) => &mut flags.subfp254, + Operation::TernaryArithmetic(arithmetic::TernaryOperator::AddMod) => &mut flags.addmod, + Operation::TernaryArithmetic(arithmetic::TernaryOperator::MulMod) => &mut flags.mulmod, + Operation::KeccakGeneral => &mut flags.keccak_general, + Operation::ProverInput => &mut flags.prover_input, + Operation::Pop => &mut flags.pop, + Operation::Jump => &mut flags.jump, + Operation::Jumpi => &mut flags.jumpi, + Operation::Pc => &mut flags.pc, + Operation::Gas => &mut flags.gas, + Operation::Jumpdest => &mut flags.jumpdest, + Operation::GetContext => &mut flags.get_context, + Operation::SetContext => &mut flags.set_context, + Operation::ConsumeGas => &mut flags.consume_gas, + Operation::ExitKernel => &mut flags.exit_kernel, + Operation::MloadGeneral => &mut flags.mload_general, + Operation::MstoreGeneral => &mut flags.mstore_general, + } = F::ONE; +} + +fn perform_op( + state: &mut GenerationState, + op: Operation, + row: CpuColumnsView, +) -> Result<(), ProgramError> { + match op { + Operation::Push(n) => generate_push(n, state, row)?, + Operation::Dup(n) => generate_dup(n, state, row)?, + Operation::Swap(n) => generate_swap(n, state, row)?, + Operation::Iszero => generate_iszero(state, row)?, + Operation::Not => generate_not(state, row)?, + Operation::Byte => generate_byte(state, row)?, + Operation::Syscall(opcode) => generate_syscall(opcode, state, row)?, + Operation::Eq => generate_eq(state, row)?, + Operation::BinaryLogic(binary_logic_op) => { + generate_binary_logic_op(binary_logic_op, state, row)? 
+ } + Operation::BinaryArithmetic(op) => generate_binary_arithmetic_op(op, state, row)?, + Operation::TernaryArithmetic(op) => generate_ternary_arithmetic_op(op, state, row)?, + Operation::KeccakGeneral => generate_keccak_general(state, row)?, + Operation::ProverInput => generate_prover_input(state, row)?, + Operation::Pop => generate_pop(state, row)?, + Operation::Jump => generate_jump(state, row)?, + Operation::Jumpi => generate_jumpi(state, row)?, + Operation::Pc => todo!(), + Operation::Gas => todo!(), + Operation::Jumpdest => todo!(), + Operation::GetContext => todo!(), + Operation::SetContext => todo!(), + Operation::ConsumeGas => todo!(), + Operation::ExitKernel => generate_exit_kernel(state, row)?, + Operation::MloadGeneral => generate_mload_general(state, row)?, + Operation::MstoreGeneral => generate_mstore_general(state, row)?, + }; + + state.registers.program_counter += match op { + Operation::Syscall(_) | Operation::ExitKernel => 0, + Operation::Push(n) => n as usize + 2, + Operation::Jump | Operation::Jumpi => 0, + _ => 1, + }; + + if let Some(label) = KERNEL.offset_label(state.registers.program_counter) { + if !label.starts_with("halt_pc") { + log::debug!("At {label}"); + } + } + + Ok(()) +} + +fn try_perform_instruction(state: &mut GenerationState) -> Result<(), ProgramError> { + let mut row: CpuColumnsView = CpuColumnsView::default(); + row.is_cpu_cycle = F::ONE; + row.clock = F::from_canonical_usize(state.traces.clock()); + row.context = F::from_canonical_usize(state.registers.context); + row.program_counter = F::from_canonical_usize(state.registers.program_counter); + row.is_kernel_mode = F::from_bool(state.registers.is_kernel); + row.stack_len = F::from_canonical_usize(state.registers.stack_len); + + let opcode = read_code_memory(state, &mut row); + let op = decode(state.registers, opcode)?; + let pc = state.registers.program_counter; + + log::trace!("\nCycle {}", state.traces.clock()); + log::trace!( + "Stack: {:?}", + 
(0..state.registers.stack_len) + .map(|i| stack_peek(state, i).unwrap()) + .collect_vec() + ); + log::trace!("Executing {:?} at {}", op, KERNEL.offset_name(pc)); + fill_op_flag(op, &mut row); + + perform_op(state, op, row) +} + +fn handle_error(_state: &mut GenerationState) { + todo!("generation for exception handling is not implemented"); +} + +pub(crate) fn transition(state: &mut GenerationState) { + let checkpoint = state.checkpoint(); + let result = try_perform_instruction(state); + + match result { + Ok(()) => { + state + .memory + .apply_ops(state.traces.mem_ops_since(checkpoint.traces)); + } + Err(e) => { + if state.registers.is_kernel { + panic!("exception in kernel mode: {:?}", e); + } + state.rollback(checkpoint); + handle_error(state) + } + } +} diff --git a/evm/src/witness/util.rs b/evm/src/witness/util.rs new file mode 100644 index 00000000..08d68edc --- /dev/null +++ b/evm/src/witness/util.rs @@ -0,0 +1,250 @@ +use ethereum_types::U256; +use plonky2::field::types::Field; + +use crate::cpu::columns::CpuColumnsView; +use crate::cpu::kernel::keccak_util::keccakf_u8s; +use crate::cpu::membus::{NUM_CHANNELS, NUM_GP_CHANNELS}; +use crate::cpu::stack_bounds::MAX_USER_STACK_SIZE; +use crate::generation::state::GenerationState; +use crate::keccak_sponge::columns::{KECCAK_RATE_BYTES, KECCAK_WIDTH_BYTES}; +use crate::keccak_sponge::keccak_sponge_stark::KeccakSpongeOp; +use crate::logic; +use crate::memory::segments::Segment; +use crate::witness::errors::ProgramError; +use crate::witness::memory::{MemoryAddress, MemoryChannel, MemoryOp, MemoryOpKind}; + +fn to_byte_checked(n: U256) -> u8 { + let res = n.byte(0); + assert_eq!(n, res.into()); + res +} + +fn to_bits_le(n: u8) -> [F; 8] { + let mut res = [F::ZERO; 8]; + for (i, bit) in res.iter_mut().enumerate() { + *bit = F::from_bool(n & (1 << i) != 0); + } + res +} + +/// Peak at the stack item `i`th from the top. If `i=0` this gives the tip. 
+pub(crate) fn stack_peek(state: &GenerationState, i: usize) -> Option { + if i >= state.registers.stack_len { + return None; + } + Some(state.memory.get(MemoryAddress::new( + state.registers.effective_context(), + Segment::Stack, + state.registers.stack_len - 1 - i, + ))) +} + +pub(crate) fn mem_read_with_log( + channel: MemoryChannel, + address: MemoryAddress, + state: &GenerationState, +) -> (U256, MemoryOp) { + let val = state.memory.get(address); + let op = MemoryOp::new( + channel, + state.traces.clock(), + address, + MemoryOpKind::Read, + val, + ); + (val, op) +} + +pub(crate) fn mem_write_log( + channel: MemoryChannel, + address: MemoryAddress, + state: &mut GenerationState, + val: U256, +) -> MemoryOp { + MemoryOp::new( + channel, + state.traces.clock(), + address, + MemoryOpKind::Write, + val, + ) +} + +pub(crate) fn mem_read_code_with_log_and_fill( + address: MemoryAddress, + state: &GenerationState, + row: &mut CpuColumnsView, +) -> (u8, MemoryOp) { + let (val, op) = mem_read_with_log(MemoryChannel::Code, address, state); + + let val_u8 = to_byte_checked(val); + row.opcode_bits = to_bits_le(val_u8); + + (val_u8, op) +} + +pub(crate) fn mem_read_gp_with_log_and_fill( + n: usize, + address: MemoryAddress, + state: &mut GenerationState, + row: &mut CpuColumnsView, +) -> (U256, MemoryOp) { + let (val, op) = mem_read_with_log(MemoryChannel::GeneralPurpose(n), address, state); + let val_limbs: [u64; 4] = val.0; + + let channel = &mut row.mem_channels[n]; + assert_eq!(channel.used, F::ZERO); + channel.used = F::ONE; + channel.is_read = F::ONE; + channel.addr_context = F::from_canonical_usize(address.context); + channel.addr_segment = F::from_canonical_usize(address.segment); + channel.addr_virtual = F::from_canonical_usize(address.virt); + for (i, limb) in val_limbs.into_iter().enumerate() { + channel.value[2 * i] = F::from_canonical_u32(limb as u32); + channel.value[2 * i + 1] = F::from_canonical_u32((limb >> 32) as u32); + } + + (val, op) +} + +pub(crate) fn 
mem_write_gp_log_and_fill( + n: usize, + address: MemoryAddress, + state: &mut GenerationState, + row: &mut CpuColumnsView, + val: U256, +) -> MemoryOp { + let op = mem_write_log(MemoryChannel::GeneralPurpose(n), address, state, val); + let val_limbs: [u64; 4] = val.0; + + let channel = &mut row.mem_channels[n]; + assert_eq!(channel.used, F::ZERO); + channel.used = F::ONE; + channel.is_read = F::ZERO; + channel.addr_context = F::from_canonical_usize(address.context); + channel.addr_segment = F::from_canonical_usize(address.segment); + channel.addr_virtual = F::from_canonical_usize(address.virt); + for (i, limb) in val_limbs.into_iter().enumerate() { + channel.value[2 * i] = F::from_canonical_u32(limb as u32); + channel.value[2 * i + 1] = F::from_canonical_u32((limb >> 32) as u32); + } + + op +} + +pub(crate) fn stack_pop_with_log_and_fill( + state: &mut GenerationState, + row: &mut CpuColumnsView, +) -> Result<[(U256, MemoryOp); N], ProgramError> { + if state.registers.stack_len < N { + return Err(ProgramError::StackUnderflow); + } + + let result = std::array::from_fn(|i| { + let address = MemoryAddress::new( + state.registers.effective_context(), + Segment::Stack, + state.registers.stack_len - 1 - i, + ); + mem_read_gp_with_log_and_fill(i, address, state, row) + }); + + state.registers.stack_len -= N; + + Ok(result) +} + +pub(crate) fn stack_push_log_and_fill( + state: &mut GenerationState, + row: &mut CpuColumnsView, + val: U256, +) -> Result { + if !state.registers.is_kernel && state.registers.stack_len >= MAX_USER_STACK_SIZE { + return Err(ProgramError::StackOverflow); + } + + let address = MemoryAddress::new( + state.registers.effective_context(), + Segment::Stack, + state.registers.stack_len, + ); + let res = mem_write_gp_log_and_fill(NUM_GP_CHANNELS - 1, address, state, row, val); + + state.registers.stack_len += 1; + + Ok(res) +} + +fn xor_into_sponge( + state: &mut GenerationState, + sponge_state: &mut [u8; KECCAK_WIDTH_BYTES], + block: &[u8; 
KECCAK_RATE_BYTES], +) { + for i in (0..KECCAK_RATE_BYTES).step_by(32) { + let range = i..KECCAK_RATE_BYTES.min(i + 32); + let lhs = U256::from_little_endian(&sponge_state[range.clone()]); + let rhs = U256::from_little_endian(&block[range]); + state + .traces + .push_logic(logic::Operation::new(logic::Op::Xor, lhs, rhs)); + } + for i in 0..KECCAK_RATE_BYTES { + sponge_state[i] ^= block[i]; + } +} + +pub(crate) fn keccak_sponge_log( + state: &mut GenerationState, + base_address: MemoryAddress, + input: Vec, +) { + let clock = state.traces.clock(); + + let mut address = base_address; + let mut input_blocks = input.chunks_exact(KECCAK_RATE_BYTES); + let mut sponge_state = [0u8; KECCAK_WIDTH_BYTES]; + for block in input_blocks.by_ref() { + for &byte in block { + state.traces.push_memory(MemoryOp::new( + MemoryChannel::Code, + clock, + address, + MemoryOpKind::Read, + byte.into(), + )); + address.increment(); + } + xor_into_sponge(state, &mut sponge_state, block.try_into().unwrap()); + state.traces.push_keccak_bytes(sponge_state); + keccakf_u8s(&mut sponge_state); + } + + for &byte in input_blocks.remainder() { + state.traces.push_memory(MemoryOp::new( + MemoryChannel::Code, + clock, + address, + MemoryOpKind::Read, + byte.into(), + )); + address.increment(); + } + let mut final_block = [0u8; KECCAK_RATE_BYTES]; + final_block[..input_blocks.remainder().len()].copy_from_slice(input_blocks.remainder()); + // pad10*1 rule + if input_blocks.remainder().len() == KECCAK_RATE_BYTES - 1 { + // Both 1s are placed in the same byte. 
+ final_block[input_blocks.remainder().len()] = 0b10000001; + } else { + final_block[input_blocks.remainder().len()] = 1; + final_block[KECCAK_RATE_BYTES - 1] = 0b10000000; + } + xor_into_sponge(state, &mut sponge_state, &final_block); + state.traces.push_keccak_bytes(sponge_state); + + state.traces.push_keccak_sponge(KeccakSpongeOp { + base_address, + timestamp: clock * NUM_CHANNELS, + input, + }); +} diff --git a/evm/tests/empty_txn_list.rs b/evm/tests/empty_txn_list.rs index 6e16fa47..abeef644 100644 --- a/evm/tests/empty_txn_list.rs +++ b/evm/tests/empty_txn_list.rs @@ -1,6 +1,7 @@ use std::collections::HashMap; -use eth_trie_utils::partial_trie::{Nibbles, PartialTrie}; +use env_logger::{try_init_from_env, Env, DEFAULT_FILTER_ENV}; +use eth_trie_utils::partial_trie::PartialTrie; use plonky2::field::goldilocks_field::GoldilocksField; use plonky2::plonk::config::PoseidonGoldilocksConfig; use plonky2::util::timing::TimingTree; @@ -17,20 +18,15 @@ type C = PoseidonGoldilocksConfig; /// Execute the empty list of transactions, i.e. a no-op. #[test] -#[ignore] // TODO: Won't work until witness generation logic is finished. 
fn test_empty_txn_list() -> anyhow::Result<()> { + init_logger(); + let all_stark = AllStark::::default(); let config = StarkConfig::standard_fast_config(); let block_metadata = BlockMetadata::default(); - let state_trie = PartialTrie::Leaf { - nibbles: Nibbles { - count: 5, - packed: 0xABCDE.into(), - }, - value: vec![1, 2, 3], - }; + let state_trie = PartialTrie::Empty; let transactions_trie = PartialTrie::Empty; let receipts_trie = PartialTrie::Empty; let storage_tries = vec![]; @@ -51,7 +47,10 @@ fn test_empty_txn_list() -> anyhow::Result<()> { block_metadata, }; - let proof = prove::(&all_stark, &config, inputs, &mut TimingTree::default())?; + let mut timing = TimingTree::new("prove", log::Level::Debug); + let proof = prove::(&all_stark, &config, inputs, &mut timing)?; + timing.print(); + assert_eq!( proof.public_values.trie_roots_before.state_root, state_trie_root @@ -79,3 +78,7 @@ fn test_empty_txn_list() -> anyhow::Result<()> { verify_proof(all_stark, proof, &config) } + +fn init_logger() { + let _ = try_init_from_env(Env::default().filter_or(DEFAULT_FILTER_ENV, "debug")); +} diff --git a/plonky2/Cargo.toml b/plonky2/Cargo.toml index ab131463..9474095e 100644 --- a/plonky2/Cargo.toml +++ b/plonky2/Cargo.toml @@ -46,7 +46,7 @@ structopt = { version = "0.3.26", default-features = false } tynm = { version = "0.1.6", default-features = false } [target.'cfg(not(target_env = "msvc"))'.dev-dependencies] -jemallocator = "0.3.2" +jemallocator = "0.5.0" [[bin]] name = "generate_constants" diff --git a/plonky2/examples/bench_recursion.rs b/plonky2/examples/bench_recursion.rs index 27101d1a..059ca963 100644 --- a/plonky2/examples/bench_recursion.rs +++ b/plonky2/examples/bench_recursion.rs @@ -105,17 +105,14 @@ where { let (inner_proof, inner_vd, inner_cd) = inner; let mut builder = CircuitBuilder::::new(config.clone()); - let mut pw = PartialWitness::new(); let pt = builder.add_virtual_proof_with_pis::(inner_cd); - pw.set_proof_with_pis_target(&pt, inner_proof); let 
inner_data = VerifierCircuitTarget { constants_sigmas_cap: builder.add_virtual_cap(inner_cd.config.fri_config.cap_height), circuit_digest: builder.add_virtual_hash(), }; - pw.set_verifier_data_target(&inner_data, inner_vd); - builder.verify_proof::(pt, &inner_data, inner_cd); + builder.verify_proof::(&pt, &inner_data, inner_cd); builder.print_gate_counts(0); if let Some(min_degree_bits) = min_degree_bits { @@ -131,6 +128,10 @@ where let data = builder.build::(); + let mut pw = PartialWitness::new(); + pw.set_proof_with_pis_target(&pt, inner_proof); + pw.set_verifier_data_target(&inner_data, inner_vd); + let mut timing = TimingTree::new("prove", Level::Debug); let proof = prove(&data.prover_only, &data.common, pw, &mut timing)?; timing.print(); diff --git a/plonky2/src/fri/proof.rs b/plonky2/src/fri/proof.rs index 4c5f65d8..f841b274 100644 --- a/plonky2/src/fri/proof.rs +++ b/plonky2/src/fri/proof.rs @@ -112,7 +112,7 @@ pub struct FriProof, H: Hasher, const D: usize> pub pow_witness: F, } -#[derive(Debug)] +#[derive(Clone, Debug)] pub struct FriProofTarget { pub commit_phase_merkle_caps: Vec, pub query_round_proofs: Vec>, diff --git a/plonky2/src/gadgets/polynomial.rs b/plonky2/src/gadgets/polynomial.rs index f7a59192..80beb62f 100644 --- a/plonky2/src/gadgets/polynomial.rs +++ b/plonky2/src/gadgets/polynomial.rs @@ -7,7 +7,7 @@ use crate::iop::target::Target; use crate::plonk::circuit_builder::CircuitBuilder; use crate::util::reducing::ReducingFactorTarget; -#[derive(Debug)] +#[derive(Clone, Debug)] pub struct PolynomialCoeffsExtTarget(pub Vec>); impl PolynomialCoeffsExtTarget { diff --git a/plonky2/src/hash/hash_types.rs b/plonky2/src/hash/hash_types.rs index b95a2113..c725c45c 100644 --- a/plonky2/src/hash/hash_types.rs +++ b/plonky2/src/hash/hash_types.rs @@ -1,5 +1,6 @@ use alloc::vec::Vec; +use anyhow::ensure; use serde::{Deserialize, Deserializer, Serialize, Serializer}; use crate::field::goldilocks_field::GoldilocksField; @@ -25,6 +26,7 @@ impl HashOut { 
elements: [F::ZERO; 4], }; + // TODO: Switch to a TryFrom impl. pub fn from_vec(elements: Vec) -> Self { debug_assert!(elements.len() == 4); Self { @@ -39,6 +41,23 @@ impl HashOut { } } +impl From<[F; 4]> for HashOut { + fn from(elements: [F; 4]) -> Self { + Self { elements } + } +} + +impl TryFrom<&[F]> for HashOut { + type Error = anyhow::Error; + + fn try_from(elements: &[F]) -> Result { + ensure!(elements.len() == 4); + Ok(Self { + elements: elements.try_into().unwrap(), + }) + } +} + impl Sample for HashOut where F: Field, @@ -97,6 +116,7 @@ pub struct HashOutTarget { } impl HashOutTarget { + // TODO: Switch to a TryFrom impl. pub fn from_vec(elements: Vec) -> Self { debug_assert!(elements.len() == 4); Self { @@ -111,6 +131,23 @@ impl HashOutTarget { } } +impl From<[Target; 4]> for HashOutTarget { + fn from(elements: [Target; 4]) -> Self { + Self { elements } + } +} + +impl TryFrom<&[Target]> for HashOutTarget { + type Error = anyhow::Error; + + fn try_from(elements: &[Target]) -> Result { + ensure!(elements.len() == 4); + Ok(Self { + elements: elements.try_into().unwrap(), + }) + } +} + #[derive(Clone, Debug)] pub struct MerkleCapTarget(pub Vec); diff --git a/plonky2/src/hash/merkle_tree.rs b/plonky2/src/hash/merkle_tree.rs index 92f1dca0..86871701 100644 --- a/plonky2/src/hash/merkle_tree.rs +++ b/plonky2/src/hash/merkle_tree.rs @@ -135,7 +135,9 @@ impl> MerkleTree { let log2_leaves_len = log2_strict(leaves.len()); assert!( cap_height <= log2_leaves_len, - "cap height should be at most log2(leaves.len())" + "cap_height={} should be at most log2(leaves.len())={}", + cap_height, + log2_leaves_len ); let num_digests = 2 * (leaves.len() - (1 << cap_height)); diff --git a/plonky2/src/plonk/circuit_builder.rs b/plonky2/src/plonk/circuit_builder.rs index 8bd1d994..6bad9296 100644 --- a/plonky2/src/plonk/circuit_builder.rs +++ b/plonky2/src/plonk/circuit_builder.rs @@ -244,9 +244,15 @@ impl, const D: usize> CircuitBuilder { self.register_public_input(t); t } + /// 
Add a virtual verifier data, register it as a public input and set it to `self.verifier_data_public_input`. /// WARNING: Do not register any public input after calling this! TODO: relax this - pub(crate) fn add_verifier_data_public_input(&mut self) { + pub fn add_verifier_data_public_inputs(&mut self) { + assert!( + self.verifier_data_public_input.is_none(), + "add_verifier_data_public_inputs only needs to be called once" + ); + let verifier_data = VerifierCircuitTarget { constants_sigmas_cap: self.add_virtual_cap(self.config.fri_config.cap_height), circuit_digest: self.add_virtual_hash(), @@ -886,7 +892,7 @@ impl, const D: usize> CircuitBuilder { num_partial_products, }; if let Some(goal_data) = self.goal_common_data { - assert_eq!(goal_data, common); + assert_eq!(goal_data, common, "The expected circuit data passed to cyclic recursion method did not match the actual circuit"); } let prover_only = ProverOnlyCircuitData { diff --git a/plonky2/src/plonk/proof.rs b/plonky2/src/plonk/proof.rs index caf3a7f8..fb9e6cde 100644 --- a/plonky2/src/plonk/proof.rs +++ b/plonky2/src/plonk/proof.rs @@ -40,7 +40,7 @@ pub struct Proof, C: GenericConfig, const pub opening_proof: FriProof, } -#[derive(Debug)] +#[derive(Clone, Debug)] pub struct ProofTarget { pub wires_cap: MerkleCapTarget, pub plonk_zs_partial_products_cap: MerkleCapTarget, @@ -283,7 +283,7 @@ pub(crate) struct FriInferredElements, const D: usi pub Vec, ); -#[derive(Debug)] +#[derive(Clone, Debug)] pub struct ProofWithPublicInputsTarget { pub proof: ProofTarget, pub public_inputs: Vec, diff --git a/plonky2/src/recursion/conditional_recursive_verifier.rs b/plonky2/src/recursion/conditional_recursive_verifier.rs index 6cbce94a..be7ed028 100644 --- a/plonky2/src/recursion/conditional_recursive_verifier.rs +++ b/plonky2/src/recursion/conditional_recursive_verifier.rs @@ -1,7 +1,5 @@ -use alloc::vec; use alloc::vec::Vec; -use anyhow::{ensure, Result}; use itertools::Itertools; use crate::field::extension::Extendable; @@ 
-9,66 +7,16 @@ use crate::fri::proof::{ FriInitialTreeProofTarget, FriProofTarget, FriQueryRoundTarget, FriQueryStepTarget, }; use crate::gadgets::polynomial::PolynomialCoeffsExtTarget; -use crate::gates::noop::NoopGate; use crate::hash::hash_types::{HashOutTarget, MerkleCapTarget, RichField}; use crate::hash::merkle_proofs::MerkleProofTarget; use crate::iop::ext_target::ExtensionTarget; use crate::iop::target::{BoolTarget, Target}; -use crate::iop::witness::{PartialWitness, Witness}; use crate::plonk::circuit_builder::CircuitBuilder; -use crate::plonk::circuit_data::{ - CommonCircuitData, VerifierCircuitTarget, VerifierOnlyCircuitData, -}; +use crate::plonk::circuit_data::{CommonCircuitData, VerifierCircuitTarget}; use crate::plonk::config::{AlgebraicHasher, GenericConfig}; -use crate::plonk::proof::{ - OpeningSetTarget, ProofTarget, ProofWithPublicInputs, ProofWithPublicInputsTarget, -}; -use crate::util::ceil_div_usize; +use crate::plonk::proof::{OpeningSetTarget, ProofTarget, ProofWithPublicInputsTarget}; use crate::with_context; -/// Generate a proof having a given `CommonCircuitData`. -pub(crate) fn dummy_proof< - F: RichField + Extendable, - C: GenericConfig, - const D: usize, ->( - common_data: &CommonCircuitData, -) -> Result<( - ProofWithPublicInputs, - VerifierOnlyCircuitData, -)> { - let config = common_data.config.clone(); - - let mut pw = PartialWitness::new(); - let mut builder = CircuitBuilder::::new(config); - - ensure!( - !common_data.config.zero_knowledge, - "Degree calculation can be off if zero-knowledge is on." - ); - let degree = common_data.degree(); - // Number of `NoopGate`s to add to get a circuit of size `degree` in the end. - // Need to account for public input hashing, a `PublicInputGate` and a `ConstantGate`. 
- let num_noop_gate = degree - ceil_div_usize(common_data.num_public_inputs, 8) - 2; - - for _ in 0..num_noop_gate { - builder.add_gate(NoopGate, vec![]); - } - for gate in &common_data.gates { - builder.add_gate_to_gate_set(gate.clone()); - } - for _ in 0..common_data.num_public_inputs { - let t = builder.add_virtual_public_input(); - pw.set_target(t, F::ZERO); - } - - let data = builder.build::(); - assert_eq!(&data.common, common_data); - let proof = data.prove(pw)?; - - Ok((proof, data.verifier_only)) -} - impl, const D: usize> CircuitBuilder { /// Verify `proof0` if `condition` else verify `proof1`. /// `proof0` and `proof1` are assumed to use the same `CommonCircuitData`. @@ -143,7 +91,7 @@ impl, const D: usize> CircuitBuilder { ), }; - self.verify_proof::(selected_proof, &selected_verifier_data, inner_common_data); + self.verify_proof::(&selected_proof, &selected_verifier_data, inner_common_data); } /// Conditionally verify a proof with a new generated dummy proof. @@ -369,6 +317,7 @@ impl, const D: usize> CircuitBuilder { #[cfg(test)] mod tests { use anyhow::Result; + use hashbrown::HashMap; use super::*; use crate::field::types::Sample; @@ -376,6 +325,7 @@ mod tests { use crate::iop::witness::{PartialWitness, Witness}; use crate::plonk::circuit_data::CircuitConfig; use crate::plonk::config::{GenericConfig, PoseidonGoldilocksConfig}; + use crate::recursion::dummy_circuit::{dummy_circuit, dummy_proof}; #[test] fn test_conditional_recursive_verifier() -> Result<()> { @@ -400,7 +350,8 @@ mod tests { data.verify(proof.clone())?; // Generate dummy proof with the same `CommonCircuitData`. - let (dummy_proof, dummy_data) = dummy_proof(&data.common)?; + let dummy_data = dummy_circuit(&data.common); + let dummy_proof = dummy_proof(&dummy_data, HashMap::new())?; // Conditionally verify the two proofs. 
let mut builder = CircuitBuilder::::new(config); @@ -418,7 +369,7 @@ mod tests { constants_sigmas_cap: builder.add_virtual_cap(data.common.config.fri_config.cap_height), circuit_digest: builder.add_virtual_hash(), }; - pw.set_verifier_data_target(&dummy_inner_data, &dummy_data); + pw.set_verifier_data_target(&dummy_inner_data, &dummy_data.verifier_only); let b = builder.constant_bool(F::rand().0 % 2 == 0); builder.conditionally_verify_proof::( b, diff --git a/plonky2/src/recursion/cyclic_recursion.rs b/plonky2/src/recursion/cyclic_recursion.rs index 2e4f2613..497d655b 100644 --- a/plonky2/src/recursion/cyclic_recursion.rs +++ b/plonky2/src/recursion/cyclic_recursion.rs @@ -3,6 +3,7 @@ use alloc::vec; use anyhow::{ensure, Result}; +use hashbrown::HashMap; use itertools::Itertools; use crate::field::extension::Extendable; @@ -13,11 +14,11 @@ use crate::iop::target::{BoolTarget, Target}; use crate::iop::witness::{PartialWitness, Witness}; use crate::plonk::circuit_builder::CircuitBuilder; use crate::plonk::circuit_data::{ - CommonCircuitData, VerifierCircuitTarget, VerifierOnlyCircuitData, + CircuitData, CommonCircuitData, VerifierCircuitTarget, VerifierOnlyCircuitData, }; use crate::plonk::config::{AlgebraicHasher, GenericConfig}; use crate::plonk::proof::{ProofWithPublicInputs, ProofWithPublicInputsTarget}; -use crate::recursion::conditional_recursive_verifier::dummy_proof; +use crate::recursion::dummy_circuit::{dummy_circuit, dummy_proof}; pub struct CyclicRecursionData< 'a, @@ -30,12 +31,17 @@ pub struct CyclicRecursionData< common_data: &'a CommonCircuitData, } -pub struct CyclicRecursionTarget { - pub proof: ProofWithPublicInputsTarget, - pub verifier_data: VerifierCircuitTarget, - pub dummy_proof: ProofWithPublicInputsTarget, - pub dummy_verifier_data: VerifierCircuitTarget, - pub condition: BoolTarget, +pub struct CyclicRecursionTarget +where + F: RichField + Extendable, + C: GenericConfig, +{ + pub(crate) proof: ProofWithPublicInputsTarget, + pub(crate) 
verifier_data: VerifierCircuitTarget, + pub(crate) dummy_proof: ProofWithPublicInputsTarget, + pub(crate) dummy_verifier_data: VerifierCircuitTarget, + pub(crate) condition: BoolTarget, + pub(crate) dummy_circuit: CircuitData, } impl, const D: usize> VerifierOnlyCircuitData { @@ -107,17 +113,16 @@ impl, const D: usize> CircuitBuilder { pub fn cyclic_recursion>( &mut self, condition: BoolTarget, - previous_virtual_public_inputs: &[Target], - common_data: &mut CommonCircuitData, - ) -> Result> + proof_with_pis: &ProofWithPublicInputsTarget, + common_data: &CommonCircuitData, + ) -> Result> where C::Hasher: AlgebraicHasher, { - if self.verifier_data_public_input.is_none() { - self.add_verifier_data_public_input(); - } - let verifier_data = self.verifier_data_public_input.clone().unwrap(); - common_data.num_public_inputs = self.num_public_inputs(); + let verifier_data = self + .verifier_data_public_input + .clone() + .expect("Must call add_verifier_data_public_inputs before cyclic recursion"); self.goal_common_data = Some(common_data.clone()); let dummy_verifier_data = VerifierCircuitTarget { @@ -125,10 +130,12 @@ impl, const D: usize> CircuitBuilder { circuit_digest: self.add_virtual_hash(), }; - let proof = self.add_virtual_proof_with_pis::(common_data); let dummy_proof = self.add_virtual_proof_with_pis::(common_data); - let pis = VerifierCircuitTarget::from_slice::(&proof.public_inputs, common_data)?; + let pis = VerifierCircuitTarget::from_slice::( + &proof_with_pis.public_inputs, + common_data, + )?; // Connect previous verifier data to current one. This guarantees that every proof in the cycle uses the same verifier data. 
self.connect_hashes(pis.circuit_digest, verifier_data.circuit_digest); for (h0, h1) in pis @@ -140,17 +147,10 @@ impl, const D: usize> CircuitBuilder { self.connect_hashes(*h0, *h1); } - for (x, y) in previous_virtual_public_inputs - .iter() - .zip(&proof.public_inputs) - { - self.connect(*x, *y); - } - // Verify the real proof if `condition` is set to true, otherwise verify the dummy proof. self.conditionally_verify_proof::( condition, - &proof, + proof_with_pis, &verifier_data, &dummy_proof, &dummy_verifier_data, @@ -167,26 +167,29 @@ impl, const D: usize> CircuitBuilder { } Ok(CyclicRecursionTarget { - proof, - verifier_data: verifier_data.clone(), + proof: proof_with_pis.clone(), + verifier_data, dummy_proof, dummy_verifier_data, condition, + dummy_circuit: dummy_circuit(common_data), }) } } /// Set the targets in a `CyclicRecursionTarget` to their corresponding values in a `CyclicRecursionData`. +/// The `public_inputs` parameter lets the caller specify certain public inputs (identified by their +/// indices) which should be given specific values. The rest will default to zero. pub fn set_cyclic_recursion_data_target< F: RichField + Extendable, C: GenericConfig, const D: usize, >( pw: &mut PartialWitness, - cyclic_recursion_data_target: &CyclicRecursionTarget, + cyclic_recursion_data_target: &CyclicRecursionTarget, cyclic_recursion_data: &CyclicRecursionData, // Public inputs to set in the base case to seed some initial data. 
- public_inputs: &[F], + mut public_inputs: HashMap, ) -> Result<()> where C::Hasher: AlgebraicHasher, @@ -204,36 +207,41 @@ where cyclic_recursion_data.verifier_data, ); } else { - let (dummy_proof, dummy_data) = dummy_proof::(cyclic_recursion_data.common_data)?; pw.set_bool_target(cyclic_recursion_data_target.condition, false); - let mut proof = dummy_proof.clone(); - proof.public_inputs[0..public_inputs.len()].copy_from_slice(public_inputs); - let pis_len = proof.public_inputs.len(); - // The circuit checks that the verifier data is the same throughout the cycle, so - // we set the verifier data to the "real" verifier data even though it's unused in the base case. - let num_cap = cyclic_recursion_data + + let pis_len = cyclic_recursion_data_target + .dummy_circuit + .common + .num_public_inputs; + let cap_elements = cyclic_recursion_data .common_data .config .fri_config .num_cap_elements(); - let s = pis_len - 4 - 4 * num_cap; - proof.public_inputs[s..s + 4] - .copy_from_slice(&cyclic_recursion_data.verifier_data.circuit_digest.elements); - for i in 0..num_cap { - proof.public_inputs[s + 4 * (1 + i)..s + 4 * (2 + i)].copy_from_slice( - &cyclic_recursion_data.verifier_data.constants_sigmas_cap.0[i].elements, - ); + let start_vk_pis = pis_len - 4 - 4 * cap_elements; + + // The circuit checks that the verifier data is the same throughout the cycle, so + // we set the verifier data to the "real" verifier data even though it's unused in the base case. 
+ let verifier_data = &cyclic_recursion_data.verifier_data; + public_inputs.extend((start_vk_pis..).zip(verifier_data.circuit_digest.elements)); + + for i in 0..cap_elements { + let start = start_vk_pis + 4 + 4 * i; + public_inputs.extend((start..).zip(verifier_data.constants_sigmas_cap.0[i].elements)); } + let proof = dummy_proof(&cyclic_recursion_data_target.dummy_circuit, public_inputs)?; pw.set_proof_with_pis_target(&cyclic_recursion_data_target.proof, &proof); pw.set_verifier_data_target( &cyclic_recursion_data_target.verifier_data, cyclic_recursion_data.verifier_data, ); - pw.set_proof_with_pis_target(&cyclic_recursion_data_target.dummy_proof, &dummy_proof); + + let dummy_p = dummy_proof(&cyclic_recursion_data_target.dummy_circuit, HashMap::new())?; + pw.set_proof_with_pis_target(&cyclic_recursion_data_target.dummy_proof, &dummy_p); pw.set_verifier_data_target( &cyclic_recursion_data_target.dummy_verifier_data, - &dummy_data, + &cyclic_recursion_data_target.dummy_circuit.verifier_only, ); } @@ -264,11 +272,12 @@ where #[cfg(test)] mod tests { use anyhow::Result; + use hashbrown::HashMap; use crate::field::extension::Extendable; use crate::field::types::{Field, PrimeField64}; use crate::gates::noop::NoopGate; - use crate::hash::hash_types::RichField; + use crate::hash::hash_types::{HashOutTarget, RichField}; use crate::hash::hashing::hash_n_to_hash_no_pad; use crate::hash::poseidon::{PoseidonHash, PoseidonPermutation}; use crate::iop::witness::PartialWitness; @@ -298,7 +307,7 @@ mod tests { constants_sigmas_cap: builder.add_virtual_cap(data.common.config.fri_config.cap_height), circuit_digest: builder.add_virtual_hash(), }; - builder.verify_proof::(proof, &verifier_data, &data.common); + builder.verify_proof::(&proof, &verifier_data, &data.common); let data = builder.build::(); let config = CircuitConfig::standard_recursion_config(); @@ -308,13 +317,19 @@ mod tests { constants_sigmas_cap: builder.add_virtual_cap(data.common.config.fri_config.cap_height), 
circuit_digest: builder.add_virtual_hash(), }; - builder.verify_proof::(proof, &verifier_data, &data.common); + builder.verify_proof::(&proof, &verifier_data, &data.common); while builder.num_gates() < 1 << 12 { builder.add_gate(NoopGate, vec![]); } builder.build::().common } + /// Uses cyclic recursion to build a hash chain. + /// The circuit has the following public input structure: + /// - Initial hash (4) + /// - Output for the tip of the hash chain (4) + /// - Chain length, i.e. the number of times the hash has been applied (1) + /// - VK for cyclic recursion (?) #[test] fn test_cyclic_recursion() -> Result<()> { const D: usize = 2; @@ -322,55 +337,62 @@ mod tests { type F = >::F; let config = CircuitConfig::standard_recursion_config(); - let mut pw = PartialWitness::new(); let mut builder = CircuitBuilder::::new(config); + let one = builder.one(); // Circuit that computes a repeated hash. let initial_hash = builder.add_virtual_hash(); builder.register_public_inputs(&initial_hash.elements); - // Hash from the previous proof. - let old_hash = builder.add_virtual_hash(); - // The input hash is either the previous hash or the initial hash depending on whether - // the last proof was a base case. - let input_hash = builder.add_virtual_hash(); - let h = builder.hash_n_to_hash_no_pad::(input_hash.elements.to_vec()); - builder.register_public_inputs(&h.elements); - // Previous counter. 
- let old_counter = builder.add_virtual_target(); - let new_counter = builder.add_virtual_public_input(); - let old_pis = [ - initial_hash.elements.as_slice(), - old_hash.elements.as_slice(), - [old_counter].as_slice(), - ] - .concat(); + let current_hash_in = builder.add_virtual_hash(); + let current_hash_out = + builder.hash_n_to_hash_no_pad::(current_hash_in.elements.to_vec()); + builder.register_public_inputs(&current_hash_out.elements); + let counter = builder.add_virtual_public_input(); let mut common_data = common_data_for_recursion::(); + builder.add_verifier_data_public_inputs(); + common_data.num_public_inputs = builder.num_public_inputs(); let condition = builder.add_virtual_bool_target_safe(); - // Add cyclic recursion gadget. + + // Unpack inner proof's public inputs. + let inner_proof_with_pis = builder.add_virtual_proof_with_pis::(&common_data); + let inner_pis = &inner_proof_with_pis.public_inputs; + let inner_initial_hash = HashOutTarget::try_from(&inner_pis[0..4]).unwrap(); + let inner_latest_hash = HashOutTarget::try_from(&inner_pis[4..8]).unwrap(); + let inner_counter = inner_pis[8]; + + // Connect our initial hash to that of our inner proof. (If there is no inner proof, the + // initial hash will be unconstrained, which is intentional.) + builder.connect_hashes(initial_hash, inner_initial_hash); + + // The input hash is the previous hash output if we have an inner proof, or the initial hash + // if this is the base case. + let actual_hash_in = builder.select_hash(condition, inner_latest_hash, initial_hash); + builder.connect_hashes(current_hash_in, actual_hash_in); + + // Our chain length will be inner_counter + 1 if we have an inner proof, or 1 if not. 
+ let new_counter = builder.mul_add(condition.target, inner_counter, one); + builder.connect(counter, new_counter); + let cyclic_data_target = - builder.cyclic_recursion::(condition, &old_pis, &mut common_data)?; - let input_hash_bis = - builder.select_hash(cyclic_data_target.condition, old_hash, initial_hash); - builder.connect_hashes(input_hash, input_hash_bis); - // New counter is the previous counter +1 if the previous proof wasn't a base case. - let new_counter_bis = builder.add(old_counter, condition.target); - builder.connect(new_counter, new_counter_bis); + builder.cyclic_recursion::(condition, &inner_proof_with_pis, &common_data)?; let cyclic_circuit_data = builder.build::(); + let mut pw = PartialWitness::new(); let cyclic_recursion_data = CyclicRecursionData { proof: &None, // Base case: We don't have a proof to put here yet. verifier_data: &cyclic_circuit_data.verifier_only, common_data: &cyclic_circuit_data.common, }; let initial_hash = [F::ZERO, F::ONE, F::TWO, F::from_canonical_usize(3)]; + let initial_hash_pis = initial_hash.into_iter().enumerate().collect(); set_cyclic_recursion_data_target( &mut pw, &cyclic_data_target, &cyclic_recursion_data, - &initial_hash, + initial_hash_pis, )?; let proof = cyclic_circuit_data.prove(pw)?; check_cyclic_proof_verifier_data( @@ -391,7 +413,7 @@ mod tests { &mut pw, &cyclic_data_target, &cyclic_recursion_data, - &[], + HashMap::new(), )?; let proof = cyclic_circuit_data.prove(pw)?; check_cyclic_proof_verifier_data( @@ -412,7 +434,7 @@ mod tests { &mut pw, &cyclic_data_target, &cyclic_recursion_data, - &[], + HashMap::new(), )?; let proof = cyclic_circuit_data.prove(pw)?; check_cyclic_proof_verifier_data( @@ -425,17 +447,20 @@ mod tests { let initial_hash = &proof.public_inputs[..4]; let hash = &proof.public_inputs[4..8]; let counter = proof.public_inputs[8]; - let mut h: [F; 4] = initial_hash.try_into().unwrap(); - assert_eq!( - hash, - core::iter::repeat_with(|| { - h = hash_n_to_hash_no_pad::(&h).elements; - h 
- }) - .nth(counter.to_canonical_u64() as usize) - .unwrap + let expected_hash: [F; 4] = iterate_poseidon( + initial_hash.try_into().unwrap(), + counter.to_canonical_u64() as usize, ); + assert_eq!(hash, expected_hash); cyclic_circuit_data.verify(proof) } + + fn iterate_poseidon(initial_state: [F; 4], n: usize) -> [F; 4] { + let mut current = initial_state; + for _ in 0..n { + current = hash_n_to_hash_no_pad::(&current).elements; + } + current + } } diff --git a/plonky2/src/recursion/dummy_circuit.rs b/plonky2/src/recursion/dummy_circuit.rs new file mode 100644 index 00000000..4012b5e6 --- /dev/null +++ b/plonky2/src/recursion/dummy_circuit.rs @@ -0,0 +1,67 @@ +use alloc::vec; + +use hashbrown::HashMap; +use plonky2_field::extension::Extendable; +use plonky2_util::ceil_div_usize; + +use crate::gates::noop::NoopGate; +use crate::hash::hash_types::RichField; +use crate::iop::witness::{PartialWitness, Witness}; +use crate::plonk::circuit_builder::CircuitBuilder; +use crate::plonk::circuit_data::{CircuitData, CommonCircuitData}; +use crate::plonk::config::GenericConfig; +use crate::plonk::proof::ProofWithPublicInputs; + +/// Generate a proof for a dummy circuit. The `public_inputs` parameter lets the caller specify +/// certain public inputs (identified by their indices) which should be given specific values. +/// The rest will default to zero. +pub(crate) fn dummy_proof( + circuit: &CircuitData, + nonzero_public_inputs: HashMap, +) -> anyhow::Result> +where + F: RichField + Extendable, + C: GenericConfig, +{ + let mut pw = PartialWitness::new(); + for i in 0..circuit.common.num_public_inputs { + let pi = nonzero_public_inputs.get(&i).copied().unwrap_or_default(); + pw.set_target(circuit.prover_only.public_inputs[i], pi); + } + circuit.prove(pw) +} + +/// Generate a circuit matching a given `CommonCircuitData`. 
+pub(crate) fn dummy_circuit< + F: RichField + Extendable, + C: GenericConfig, + const D: usize, +>( + common_data: &CommonCircuitData, +) -> CircuitData { + let config = common_data.config.clone(); + assert!( + !common_data.config.zero_knowledge, + "Degree calculation can be off if zero-knowledge is on." + ); + + // Number of `NoopGate`s to add to get a circuit of size `degree` in the end. + // Need to account for public input hashing, a `PublicInputGate` and a `ConstantGate`. + let degree = common_data.degree(); + let num_noop_gate = degree - ceil_div_usize(common_data.num_public_inputs, 8) - 2; + + let mut builder = CircuitBuilder::::new(config); + for _ in 0..num_noop_gate { + builder.add_gate(NoopGate, vec![]); + } + for gate in &common_data.gates { + builder.add_gate_to_gate_set(gate.clone()); + } + for _ in 0..common_data.num_public_inputs { + builder.add_virtual_public_input(); + } + + let circuit = builder.build::(); + assert_eq!(&circuit.common, common_data); + circuit +} diff --git a/plonky2/src/recursion/mod.rs b/plonky2/src/recursion/mod.rs index 33e8212e..3aba4ffd 100644 --- a/plonky2/src/recursion/mod.rs +++ b/plonky2/src/recursion/mod.rs @@ -1,3 +1,4 @@ pub mod conditional_recursive_verifier; pub mod cyclic_recursion; +pub(crate) mod dummy_circuit; pub mod recursive_verifier; diff --git a/plonky2/src/recursion/recursive_verifier.rs b/plonky2/src/recursion/recursive_verifier.rs index 0854dad8..d53095a4 100644 --- a/plonky2/src/recursion/recursive_verifier.rs +++ b/plonky2/src/recursion/recursive_verifier.rs @@ -16,7 +16,7 @@ impl, const D: usize> CircuitBuilder { /// Recursively verifies an inner proof. 
pub fn verify_proof>( &mut self, - proof_with_pis: ProofWithPublicInputsTarget, + proof_with_pis: &ProofWithPublicInputsTarget, inner_verifier_data: &VerifierCircuitTarget, inner_common_data: &CommonCircuitData, ) where @@ -36,7 +36,7 @@ impl, const D: usize> CircuitBuilder { ); self.verify_proof_with_challenges::( - proof_with_pis.proof, + &proof_with_pis.proof, public_inputs_hash, challenges, inner_verifier_data, @@ -47,7 +47,7 @@ impl, const D: usize> CircuitBuilder { /// Recursively verifies an inner proof. fn verify_proof_with_challenges>( &mut self, - proof: ProofTarget, + proof: &ProofTarget, public_inputs_hash: HashOutTarget, challenges: ProofChallengesTarget, inner_verifier_data: &VerifierCircuitTarget, @@ -106,9 +106,9 @@ impl, const D: usize> CircuitBuilder { let merkle_caps = &[ inner_verifier_data.constants_sigmas_cap.clone(), - proof.wires_cap, - proof.plonk_zs_partial_products_cap, - proof.quotient_polys_cap, + proof.wires_cap.clone(), + proof.plonk_zs_partial_products_cap.clone(), + proof.quotient_polys_cap.clone(), ]; let fri_instance = inner_common_data.get_fri_instance_target(self, challenges.plonk_zeta); @@ -376,7 +376,7 @@ mod tests { ); pw.set_hash_target(inner_data.circuit_digest, inner_vd.circuit_digest); - builder.verify_proof::(pt, &inner_data, &inner_cd); + builder.verify_proof::(&pt, &inner_data, &inner_cd); if print_gate_counts { builder.print_gate_counts(0); diff --git a/system_zero/Cargo.toml b/system_zero/Cargo.toml index 6a36ee25..58a5e489 100644 --- a/system_zero/Cargo.toml +++ b/system_zero/Cargo.toml @@ -23,4 +23,4 @@ name = "lookup_permuted_cols" harness = false [target.'cfg(not(target_env = "msvc"))'.dev-dependencies] -jemallocator = "0.3.2" +jemallocator = "0.5.0"