From d6be2b987b0c103eed6ffe675fc7a9d8e4c6e85b Mon Sep 17 00:00:00 2001 From: Robin Salen <30937548+Nashtare@users.noreply.github.com> Date: Fri, 22 Sep 2023 09:19:13 -0400 Subject: [PATCH 1/6] Remove `generic_const_exprs` feature from EVM crate (#1246) * Remove const_generic_exprs feature from EVM crate * Get a generic impl of StarkFrame --- evm/src/arithmetic/arithmetic_stark.rs | 25 ++- evm/src/byte_packing/byte_packing_stark.rs | 122 +++++++------ evm/src/cpu/bootstrap_kernel.rs | 21 +-- evm/src/cpu/cpu_stark.rs | 43 +++-- evm/src/cross_table_lookup.rs | 20 ++- evm/src/evaluation_frame.rs | 47 +++++ evm/src/fixed_recursive_verifier.rs | 24 +-- evm/src/keccak/keccak_stark.rs | 178 ++++++++++--------- evm/src/keccak/round_flags.rs | 38 ++-- evm/src/keccak_sponge/keccak_sponge_stark.rs | 33 +++- evm/src/lib.rs | 3 +- evm/src/logic.rs | 18 +- evm/src/lookup.rs | 29 +-- evm/src/memory/memory_stark.rs | 90 +++++----- evm/src/permutation.rs | 15 +- evm/src/prover.rs | 68 ++----- evm/src/recursive_verifier.rs | 11 +- evm/src/stark.rs | 21 ++- evm/src/stark_testing.rs | 52 ++---- evm/src/vanishing_poly.rs | 6 +- evm/src/vars.rs | 19 -- evm/src/verifier.rs | 29 +-- 22 files changed, 465 insertions(+), 447 deletions(-) create mode 100644 evm/src/evaluation_frame.rs delete mode 100644 evm/src/vars.rs diff --git a/evm/src/arithmetic/arithmetic_stark.rs b/evm/src/arithmetic/arithmetic_stark.rs index 5441cf27..a6db1278 100644 --- a/evm/src/arithmetic/arithmetic_stark.rs +++ b/evm/src/arithmetic/arithmetic_stark.rs @@ -7,18 +7,20 @@ use plonky2::field::packed::PackedField; use plonky2::field::polynomial::PolynomialValues; use plonky2::field::types::Field; use plonky2::hash::hash_types::RichField; +use plonky2::iop::ext_target::ExtensionTarget; use plonky2::plonk::circuit_builder::CircuitBuilder; use plonky2::util::transpose; use static_assertions::const_assert; +use super::columns::NUM_ARITH_COLUMNS; use crate::all_stark::Table; use crate::arithmetic::{addcy, byte, columns, divmod, modular, mul, Operation}; use crate::constraint_consumer::{ConstraintConsumer, RecursiveConstraintConsumer}; use crate::cross_table_lookup::{Column, TableWithColumns}; +use crate::evaluation_frame::{StarkEvaluationFrame, StarkFrame}; use crate::lookup::{eval_lookups, eval_lookups_circuit, permuted_cols}; use crate::permutation::PermutationPair; use crate::stark::Stark; -use crate::vars::{StarkEvaluationTargets, StarkEvaluationVars}; /// Link the 16-bit columns of the arithmetic table, split into groups /// of N_LIMBS at a time in `regs`, with the corresponding 32-bit @@ -168,11 +170,16 @@ impl ArithmeticStark { } impl, const D: usize> Stark for ArithmeticStark { - const COLUMNS: usize = columns::NUM_ARITH_COLUMNS; + type EvaluationFrame = StarkFrame + where + FE: FieldExtension, + P: PackedField; + + type EvaluationFrameTarget = StarkFrame, NUM_ARITH_COLUMNS>; fn eval_packed_generic( &self, - vars: StarkEvaluationVars, + vars: &Self::EvaluationFrame, yield_constr: &mut ConstraintConsumer

, ) where FE: FieldExtension, @@ -183,8 +190,8 @@ impl, const D: usize> Stark for ArithmeticSta eval_lookups(vars, yield_constr, col, col + 1); } - let lv = vars.local_values; - let nv = vars.next_values; + let lv: &[P; NUM_ARITH_COLUMNS] = vars.get_local_values().try_into().unwrap(); + let nv: &[P; NUM_ARITH_COLUMNS] = vars.get_next_values().try_into().unwrap(); // Check the range column: First value must be 0, last row // must be 2^16-1, and intermediate rows must increment by 0 @@ -207,7 +214,7 @@ impl, const D: usize> Stark for ArithmeticSta fn eval_ext_circuit( &self, builder: &mut CircuitBuilder, - vars: StarkEvaluationTargets, + vars: &Self::EvaluationFrameTarget, yield_constr: &mut RecursiveConstraintConsumer, ) { // Range check all the columns @@ -215,8 +222,10 @@ impl, const D: usize> Stark for ArithmeticSta eval_lookups_circuit(builder, vars, yield_constr, col, col + 1); } - let lv = vars.local_values; - let nv = vars.next_values; + let lv: &[ExtensionTarget; NUM_ARITH_COLUMNS] = + vars.get_local_values().try_into().unwrap(); + let nv: &[ExtensionTarget; NUM_ARITH_COLUMNS] = + vars.get_next_values().try_into().unwrap(); let rc1 = lv[columns::RANGE_COUNTER]; let rc2 = nv[columns::RANGE_COUNTER]; diff --git a/evm/src/byte_packing/byte_packing_stark.rs b/evm/src/byte_packing/byte_packing_stark.rs index aa6a2dcf..d8f8e2e8 100644 --- a/evm/src/byte_packing/byte_packing_stark.rs +++ b/evm/src/byte_packing/byte_packing_stark.rs @@ -51,9 +51,9 @@ use crate::byte_packing::columns::{ }; use crate::constraint_consumer::{ConstraintConsumer, RecursiveConstraintConsumer}; use crate::cross_table_lookup::Column; +use crate::evaluation_frame::{StarkEvaluationFrame, StarkFrame}; use crate::lookup::{eval_lookups, eval_lookups_circuit, permuted_cols}; use crate::stark::Stark; -use crate::vars::{StarkEvaluationTargets, StarkEvaluationVars}; use crate::witness::memory::MemoryAddress; /// Strict upper bound for the individual bytes range-check. @@ -211,7 +211,7 @@ impl, const D: usize> BytePackingStark { row[value_bytes(i)] = F::from_canonical_u8(byte); row[index_bytes(i)] = F::ONE; - rows.push(row.into()); + rows.push(row); row[index_bytes(i)] = F::ZERO; row[ADDR_VIRTUAL] -= F::ONE; } @@ -248,7 +248,7 @@ impl, const D: usize> BytePackingStark { } } - /// There is only one `i` for which `vars.local_values[index_bytes(i)]` is non-zero, + /// There is only one `i` for which `local_values[index_bytes(i)]` is non-zero, /// and `i+1` is the current position: fn get_active_position(&self, row: &[P; NUM_COLUMNS]) -> P where @@ -281,11 +281,16 @@ impl, const D: usize> BytePackingStark { } impl, const D: usize> Stark for BytePackingStark { - const COLUMNS: usize = NUM_COLUMNS; + type EvaluationFrame = StarkFrame + where + FE: FieldExtension, + P: PackedField; + + type EvaluationFrameTarget = StarkFrame, NUM_COLUMNS>; fn eval_packed_generic( &self, - vars: StarkEvaluationVars, + vars: &Self::EvaluationFrame, yield_constr: &mut ConstraintConsumer

, ) where FE: FieldExtension, @@ -296,68 +301,62 @@ impl, const D: usize> Stark for BytePackingSt eval_lookups(vars, yield_constr, col, col + 1); } + let local_values: &[P; NUM_COLUMNS] = vars.get_local_values().try_into().unwrap(); + let next_values: &[P; NUM_COLUMNS] = vars.get_next_values().try_into().unwrap(); + let one = P::ONES; // We filter active columns by summing all the byte indices. // Constraining each of them to be boolean is done later on below. - let current_filter = vars.local_values[BYTE_INDICES_COLS] - .iter() - .copied() - .sum::

(); + let current_filter = local_values[BYTE_INDICES_COLS].iter().copied().sum::

(); yield_constr.constraint(current_filter * (current_filter - one)); // The filter column must start by one. yield_constr.constraint_first_row(current_filter - one); // The is_read flag must be boolean. - let current_is_read = vars.local_values[IS_READ]; + let current_is_read = local_values[IS_READ]; yield_constr.constraint(current_is_read * (current_is_read - one)); // Each byte index must be boolean. for i in 0..NUM_BYTES { - let idx_i = vars.local_values[index_bytes(i)]; + let idx_i = local_values[index_bytes(i)]; yield_constr.constraint(idx_i * (idx_i - one)); } // The sequence start flag column must start by one. - let current_sequence_start = vars.local_values[index_bytes(0)]; + let current_sequence_start = local_values[index_bytes(0)]; yield_constr.constraint_first_row(current_sequence_start - one); // The sequence end flag must be boolean - let current_sequence_end = vars.local_values[SEQUENCE_END]; + let current_sequence_end = local_values[SEQUENCE_END]; yield_constr.constraint(current_sequence_end * (current_sequence_end - one)); // If filter is off, all flags and byte indices must be off. - let byte_indices = vars.local_values[BYTE_INDICES_COLS] - .iter() - .copied() - .sum::

(); + let byte_indices = local_values[BYTE_INDICES_COLS].iter().copied().sum::

(); yield_constr.constraint( (current_filter - one) * (current_is_read + current_sequence_end + byte_indices), ); // Only padding rows have their filter turned off. - let next_filter = vars.next_values[BYTE_INDICES_COLS] - .iter() - .copied() - .sum::

(); + let next_filter = next_values[BYTE_INDICES_COLS].iter().copied().sum::

(); yield_constr.constraint_transition(next_filter * (next_filter - current_filter)); // Unless the current sequence end flag is activated, the is_read filter must remain unchanged. - let next_is_read = vars.next_values[IS_READ]; + let next_is_read = next_values[IS_READ]; yield_constr .constraint_transition((current_sequence_end - one) * (next_is_read - current_is_read)); // If the sequence end flag is activated, the next row must be a new sequence or filter must be off. - let next_sequence_start = vars.next_values[index_bytes(0)]; + let next_sequence_start = next_values[index_bytes(0)]; yield_constr.constraint_transition( current_sequence_end * next_filter * (next_sequence_start - one), ); // The active position in a byte sequence must increase by one on every row // or be one on the next row (i.e. at the start of a new sequence). - let current_position = self.get_active_position(vars.local_values); - let next_position = self.get_active_position(vars.next_values); + let current_position = self.get_active_position(local_values); + let next_position = self.get_active_position(next_values); yield_constr.constraint_transition( next_filter * (next_position - one) * (next_position - current_position - one), ); @@ -371,14 +370,14 @@ impl, const D: usize> Stark for BytePackingSt // The context, segment and timestamp fields must remain unchanged throughout a byte sequence. // The virtual address must decrement by one at each step of a sequence. - let current_context = vars.local_values[ADDR_CONTEXT]; - let next_context = vars.next_values[ADDR_CONTEXT]; - let current_segment = vars.local_values[ADDR_SEGMENT]; - let next_segment = vars.next_values[ADDR_SEGMENT]; - let current_virtual = vars.local_values[ADDR_VIRTUAL]; - let next_virtual = vars.next_values[ADDR_VIRTUAL]; - let current_timestamp = vars.local_values[TIMESTAMP]; - let next_timestamp = vars.next_values[TIMESTAMP]; + let current_context = local_values[ADDR_CONTEXT]; + let next_context = next_values[ADDR_CONTEXT]; + let current_segment = local_values[ADDR_SEGMENT]; + let next_segment = next_values[ADDR_SEGMENT]; + let current_virtual = local_values[ADDR_VIRTUAL]; + let next_virtual = next_values[ADDR_VIRTUAL]; + let current_timestamp = local_values[TIMESTAMP]; + let next_timestamp = next_values[TIMESTAMP]; yield_constr.constraint_transition( next_filter * (next_sequence_start - one) * (next_context - current_context), ); @@ -395,9 +394,9 @@ impl, const D: usize> Stark for BytePackingSt // If not at the end of a sequence, each next byte must equal the current one // when reading through the sequence, or the next byte index must be one. for i in 0..NUM_BYTES { - let current_byte = vars.local_values[value_bytes(i)]; - let next_byte = vars.next_values[value_bytes(i)]; - let next_byte_index = vars.next_values[index_bytes(i)]; + let current_byte = local_values[value_bytes(i)]; + let next_byte = next_values[value_bytes(i)]; + let next_byte_index = next_values[index_bytes(i)]; yield_constr.constraint_transition( (current_sequence_end - one) * (next_byte_index - one) * (next_byte - current_byte), ); @@ -407,7 +406,7 @@ impl, const D: usize> Stark for BytePackingSt fn eval_ext_circuit( &self, builder: &mut plonky2::plonk::circuit_builder::CircuitBuilder, - vars: StarkEvaluationTargets, + vars: &Self::EvaluationFrameTarget, yield_constr: &mut RecursiveConstraintConsumer, ) { // Range check all the columns @@ -415,9 +414,14 @@ impl, const D: usize> Stark for BytePackingSt eval_lookups_circuit(builder, vars, yield_constr, col, col + 1); } + let local_values: &[ExtensionTarget; NUM_COLUMNS] = + vars.get_local_values().try_into().unwrap(); + let next_values: &[ExtensionTarget; NUM_COLUMNS] = + vars.get_next_values().try_into().unwrap(); + // We filter active columns by summing all the byte indices. // Constraining each of them to be boolean is done later on below. - let current_filter = builder.add_many_extension(&vars.local_values[BYTE_INDICES_COLS]); + let current_filter = builder.add_many_extension(&local_values[BYTE_INDICES_COLS]); let constraint = builder.mul_sub_extension(current_filter, current_filter, current_filter); yield_constr.constraint(builder, constraint); @@ -426,25 +430,25 @@ impl, const D: usize> Stark for BytePackingSt yield_constr.constraint_first_row(builder, constraint); // The is_read flag must be boolean. - let current_is_read = vars.local_values[IS_READ]; + let current_is_read = local_values[IS_READ]; let constraint = builder.mul_sub_extension(current_is_read, current_is_read, current_is_read); yield_constr.constraint(builder, constraint); // Each byte index must be boolean. for i in 0..NUM_BYTES { - let idx_i = vars.local_values[index_bytes(i)]; + let idx_i = local_values[index_bytes(i)]; let constraint = builder.mul_sub_extension(idx_i, idx_i, idx_i); yield_constr.constraint(builder, constraint); } // The sequence start flag column must start by one. - let current_sequence_start = vars.local_values[index_bytes(0)]; + let current_sequence_start = local_values[index_bytes(0)]; let constraint = builder.add_const_extension(current_sequence_start, F::NEG_ONE); yield_constr.constraint_first_row(builder, constraint); // The sequence end flag must be boolean - let current_sequence_end = vars.local_values[SEQUENCE_END]; + let current_sequence_end = local_values[SEQUENCE_END]; let constraint = builder.mul_sub_extension( current_sequence_end, current_sequence_end, @@ -453,27 +457,27 @@ impl, const D: usize> Stark for BytePackingSt yield_constr.constraint(builder, constraint); // If filter is off, all flags and byte indices must be off. - let byte_indices = builder.add_many_extension(&vars.local_values[BYTE_INDICES_COLS]); + let byte_indices = builder.add_many_extension(&local_values[BYTE_INDICES_COLS]); let constraint = builder.add_extension(current_sequence_end, byte_indices); let constraint = builder.add_extension(constraint, current_is_read); let constraint = builder.mul_sub_extension(constraint, current_filter, constraint); yield_constr.constraint(builder, constraint); // Only padding rows have their filter turned off. - let next_filter = builder.add_many_extension(&vars.next_values[BYTE_INDICES_COLS]); + let next_filter = builder.add_many_extension(&next_values[BYTE_INDICES_COLS]); let constraint = builder.sub_extension(next_filter, current_filter); let constraint = builder.mul_extension(next_filter, constraint); yield_constr.constraint_transition(builder, constraint); // Unless the current sequence end flag is activated, the is_read filter must remain unchanged. - let next_is_read = vars.next_values[IS_READ]; + let next_is_read = next_values[IS_READ]; let diff_is_read = builder.sub_extension(next_is_read, current_is_read); let constraint = builder.mul_sub_extension(diff_is_read, current_sequence_end, diff_is_read); yield_constr.constraint_transition(builder, constraint); // If the sequence end flag is activated, the next row must be a new sequence or filter must be off. - let next_sequence_start = vars.next_values[index_bytes(0)]; + let next_sequence_start = next_values[index_bytes(0)]; let constraint = builder.mul_sub_extension( current_sequence_end, next_sequence_start, @@ -484,8 +488,8 @@ impl, const D: usize> Stark for BytePackingSt // The active position in a byte sequence must increase by one on every row // or be one on the next row (i.e. at the start of a new sequence). - let current_position = self.get_active_position_circuit(builder, vars.local_values); - let next_position = self.get_active_position_circuit(builder, vars.next_values); + let current_position = self.get_active_position_circuit(builder, local_values); + let next_position = self.get_active_position_circuit(builder, next_values); let position_diff = builder.sub_extension(next_position, current_position); let is_new_or_inactive = builder.mul_sub_extension(next_filter, next_position, next_filter); @@ -505,14 +509,14 @@ impl, const D: usize> Stark for BytePackingSt // The context, segment and timestamp fields must remain unchanged throughout a byte sequence. // The virtual address must decrement by one at each step of a sequence. - let current_context = vars.local_values[ADDR_CONTEXT]; - let next_context = vars.next_values[ADDR_CONTEXT]; - let current_segment = vars.local_values[ADDR_SEGMENT]; - let next_segment = vars.next_values[ADDR_SEGMENT]; - let current_virtual = vars.local_values[ADDR_VIRTUAL]; - let next_virtual = vars.next_values[ADDR_VIRTUAL]; - let current_timestamp = vars.local_values[TIMESTAMP]; - let next_timestamp = vars.next_values[TIMESTAMP]; + let current_context = local_values[ADDR_CONTEXT]; + let next_context = next_values[ADDR_CONTEXT]; + let current_segment = local_values[ADDR_SEGMENT]; + let next_segment = next_values[ADDR_SEGMENT]; + let current_virtual = local_values[ADDR_VIRTUAL]; + let next_virtual = next_values[ADDR_VIRTUAL]; + let current_timestamp = local_values[TIMESTAMP]; + let next_timestamp = next_values[TIMESTAMP]; let addr_filter = builder.mul_sub_extension(next_filter, next_sequence_start, next_filter); { let constraint = builder.sub_extension(next_context, current_context); @@ -538,9 +542,9 @@ impl, const D: usize> Stark for BytePackingSt // If not at the end of a sequence, each next byte must equal the current one // when reading through the sequence, or the next byte index must be one. for i in 0..NUM_BYTES { - let current_byte = vars.local_values[value_bytes(i)]; - let next_byte = vars.next_values[value_bytes(i)]; - let next_byte_index = vars.next_values[index_bytes(i)]; + let current_byte = local_values[value_bytes(i)]; + let next_byte = next_values[value_bytes(i)]; + let next_byte_index = next_values[index_bytes(i)]; let byte_diff = builder.sub_extension(next_byte, current_byte); let constraint = builder.mul_sub_extension(byte_diff, next_byte_index, byte_diff); let constraint = diff --git a/evm/src/cpu/bootstrap_kernel.rs b/evm/src/cpu/bootstrap_kernel.rs index 4aee617c..759c852a 100644 --- a/evm/src/cpu/bootstrap_kernel.rs +++ b/evm/src/cpu/bootstrap_kernel.rs @@ -1,22 +1,20 @@ //! The initial phase of execution, where the kernel code is hashed while being written to memory. //! The hash is then checked against a precomputed kernel hash. -use std::borrow::Borrow; - use itertools::Itertools; use plonky2::field::extension::Extendable; use plonky2::field::packed::PackedField; use plonky2::field::types::Field; use plonky2::hash::hash_types::RichField; +use plonky2::iop::ext_target::ExtensionTarget; use plonky2::plonk::circuit_builder::CircuitBuilder; use crate::constraint_consumer::{ConstraintConsumer, RecursiveConstraintConsumer}; -use crate::cpu::columns::{CpuColumnsView, NUM_CPU_COLUMNS}; +use crate::cpu::columns::CpuColumnsView; use crate::cpu::kernel::aggregator::KERNEL; use crate::cpu::membus::NUM_GP_CHANNELS; use crate::generation::state::GenerationState; use crate::memory::segments::Segment; -use crate::vars::{StarkEvaluationTargets, StarkEvaluationVars}; use crate::witness::memory::MemoryAddress; use crate::witness::util::{keccak_sponge_log, mem_write_gp_log_and_fill}; @@ -58,13 +56,11 @@ pub(crate) fn generate_bootstrap_kernel(state: &mut GenerationState log::info!("Bootstrapping took {} cycles", state.traces.clock()); } -pub(crate) fn eval_bootstrap_kernel>( - vars: StarkEvaluationVars, +pub(crate) fn eval_bootstrap_kernel_packed>( + local_values: &CpuColumnsView

, + next_values: &CpuColumnsView

, yield_constr: &mut ConstraintConsumer

, ) { - let local_values: &CpuColumnsView<_> = vars.local_values.borrow(); - let next_values: &CpuColumnsView<_> = vars.next_values.borrow(); - // IS_BOOTSTRAP_KERNEL must have an init value of 1, a final value of 0, and a delta in {0, -1}. let local_is_bootstrap = local_values.is_bootstrap_kernel; let next_is_bootstrap = next_values.is_bootstrap_kernel; @@ -103,13 +99,12 @@ pub(crate) fn eval_bootstrap_kernel>( } } -pub(crate) fn eval_bootstrap_kernel_circuit, const D: usize>( +pub(crate) fn eval_bootstrap_kernel_ext_circuit, const D: usize>( builder: &mut CircuitBuilder, - vars: StarkEvaluationTargets, + local_values: &CpuColumnsView>, + next_values: &CpuColumnsView>, yield_constr: &mut RecursiveConstraintConsumer, ) { - let local_values: &CpuColumnsView<_> = vars.local_values.borrow(); - let next_values: &CpuColumnsView<_> = vars.next_values.borrow(); let one = builder.one_extension(); // IS_BOOTSTRAP_KERNEL must have an init value of 1, a final value of 0, and a delta in {0, -1}. diff --git a/evm/src/cpu/cpu_stark.rs b/evm/src/cpu/cpu_stark.rs index bd2fcf19..14bb6015 100644 --- a/evm/src/cpu/cpu_stark.rs +++ b/evm/src/cpu/cpu_stark.rs @@ -7,6 +7,7 @@ use plonky2::field::extension::{Extendable, FieldExtension}; use plonky2::field::packed::PackedField; use plonky2::field::types::Field; use plonky2::hash::hash_types::RichField; +use plonky2::iop::ext_target::ExtensionTarget; use super::halt; use crate::all_stark::Table; @@ -18,10 +19,10 @@ use crate::cpu::{ modfp254, pc, push0, shift, simple_logic, stack, stack_bounds, syscalls_exceptions, }; use crate::cross_table_lookup::{Column, TableWithColumns}; +use crate::evaluation_frame::{StarkEvaluationFrame, StarkFrame}; use crate::memory::segments::Segment; use crate::memory::{NUM_CHANNELS, VALUE_LIMBS}; use crate::stark::Stark; -use crate::vars::{StarkEvaluationTargets, StarkEvaluationVars}; pub fn ctl_data_keccak_sponge() -> Vec> { // When executing KECCAK_GENERAL, the GP memory channels are used as follows: @@ -227,19 +228,29 @@ impl CpuStark { } impl, const D: usize> Stark for CpuStark { - const COLUMNS: usize = NUM_CPU_COLUMNS; + type EvaluationFrame = StarkFrame + where + FE: FieldExtension, + P: PackedField; + + type EvaluationFrameTarget = StarkFrame, NUM_CPU_COLUMNS>; fn eval_packed_generic( &self, - vars: StarkEvaluationVars, + vars: &Self::EvaluationFrame, yield_constr: &mut ConstraintConsumer

, ) where FE: FieldExtension, P: PackedField, { - let local_values = vars.local_values.borrow(); - let next_values = vars.next_values.borrow(); - bootstrap_kernel::eval_bootstrap_kernel(vars, yield_constr); + let local_values = + TryInto::<[P; NUM_CPU_COLUMNS]>::try_into(vars.get_local_values()).unwrap(); + let local_values: &CpuColumnsView

= local_values.borrow(); + let next_values = + TryInto::<[P; NUM_CPU_COLUMNS]>::try_into(vars.get_next_values()).unwrap(); + let next_values: &CpuColumnsView

= next_values.borrow(); + + bootstrap_kernel::eval_bootstrap_kernel_packed(local_values, next_values, yield_constr); contextops::eval_packed(local_values, next_values, yield_constr); control_flow::eval_packed_generic(local_values, next_values, yield_constr); decode::eval_packed_generic(local_values, yield_constr); @@ -262,12 +273,24 @@ impl, const D: usize> Stark for CpuStark, - vars: StarkEvaluationTargets, + vars: &Self::EvaluationFrameTarget, yield_constr: &mut RecursiveConstraintConsumer, ) { - let local_values = vars.local_values.borrow(); - let next_values = vars.next_values.borrow(); - bootstrap_kernel::eval_bootstrap_kernel_circuit(builder, vars, yield_constr); + let local_values = + TryInto::<[ExtensionTarget; NUM_CPU_COLUMNS]>::try_into(vars.get_local_values()) + .unwrap(); + let local_values: &CpuColumnsView> = local_values.borrow(); + let next_values = + TryInto::<[ExtensionTarget; NUM_CPU_COLUMNS]>::try_into(vars.get_next_values()) + .unwrap(); + let next_values: &CpuColumnsView> = next_values.borrow(); + + bootstrap_kernel::eval_bootstrap_kernel_ext_circuit( + builder, + local_values, + next_values, + yield_constr, + ); contextops::eval_ext_circuit(builder, local_values, next_values, yield_constr); control_flow::eval_ext_circuit(builder, local_values, next_values, yield_constr); decode::eval_ext_circuit(builder, local_values, yield_constr); diff --git a/evm/src/cross_table_lookup.rs b/evm/src/cross_table_lookup.rs index a2dad1ab..28b18994 100644 --- a/evm/src/cross_table_lookup.rs +++ b/evm/src/cross_table_lookup.rs @@ -16,10 +16,10 @@ use plonky2::plonk::config::GenericConfig; use crate::all_stark::{Table, NUM_TABLES}; use crate::config::StarkConfig; use crate::constraint_consumer::{ConstraintConsumer, RecursiveConstraintConsumer}; +use crate::evaluation_frame::StarkEvaluationFrame; use crate::permutation::{GrandProductChallenge, GrandProductChallengeSet}; use crate::proof::{StarkProofTarget, StarkProofWithMetadata}; use crate::stark::Stark; -use crate::vars::{StarkEvaluationTargets, StarkEvaluationVars}; /// Represent a linear combination of columns. #[derive(Clone, Debug)] @@ -473,7 +473,7 @@ impl<'a, F: RichField + Extendable, const D: usize> /// Z(w) = Z(gw) * combine(w) where combine is called on the local row /// and not the next. This enables CTLs across two rows. pub(crate) fn eval_cross_table_lookup_checks( - vars: StarkEvaluationVars, + vars: &S::EvaluationFrame, ctl_vars: &[CtlCheckVars], consumer: &mut ConstraintConsumer

, ) where @@ -482,6 +482,9 @@ pub(crate) fn eval_cross_table_lookup_checks, S: Stark, { + let local_values = vars.get_local_values(); + let next_values = vars.get_next_values(); + for lookup_vars in ctl_vars { let CtlCheckVars { local_z, @@ -493,11 +496,11 @@ pub(crate) fn eval_cross_table_lookup_checks>(); let combined = challenges.combine(evals.iter()); let local_filter = if let Some(column) = filter_column { - column.eval_with_next(vars.local_values, vars.next_values) + column.eval_with_next(local_values, next_values) } else { P::ONES }; @@ -580,10 +583,13 @@ pub(crate) fn eval_cross_table_lookup_checks_circuit< const D: usize, >( builder: &mut CircuitBuilder, - vars: StarkEvaluationTargets, + vars: &S::EvaluationFrameTarget, ctl_vars: &[CtlCheckVarsTarget], consumer: &mut RecursiveConstraintConsumer, ) { + let local_values = vars.get_local_values(); + let next_values = vars.get_next_values(); + for lookup_vars in ctl_vars { let CtlCheckVarsTarget { local_z, @@ -595,7 +601,7 @@ pub(crate) fn eval_cross_table_lookup_checks_circuit< let one = builder.one_extension(); let local_filter = if let Some(column) = filter_column { - column.eval_circuit(builder, vars.local_values) + column.eval_circuit(builder, local_values) } else { one }; @@ -611,7 +617,7 @@ pub(crate) fn eval_cross_table_lookup_checks_circuit< let evals = columns .iter() - .map(|c| c.eval_with_next_circuit(builder, vars.local_values, vars.next_values)) + .map(|c| c.eval_with_next_circuit(builder, local_values, next_values)) .collect::>(); let combined = challenges.combine_circuit(builder, &evals); diff --git a/evm/src/evaluation_frame.rs b/evm/src/evaluation_frame.rs new file mode 100644 index 00000000..0f6bbe2c --- /dev/null +++ b/evm/src/evaluation_frame.rs @@ -0,0 +1,47 @@ +/// A trait for viewing an evaluation frame of a STARK table. +/// +/// It allows to access the current and next rows at a given step +/// and can be used to implement constraint evaluation both natively +/// and recursively. +pub trait StarkEvaluationFrame: Sized { + /// The number of columns for the STARK table this evaluation frame views. + const COLUMNS: usize; + + /// Returns the local values (i.e. current row) for this evaluation frame. + fn get_local_values(&self) -> &[T]; + /// Returns the next values (i.e. next row) for this evaluation frame. + fn get_next_values(&self) -> &[T]; + + /// Outputs a new evaluation frame from the provided local and next values. + /// + /// **NOTE**: Concrete implementations of this method SHOULD ensure that + /// the provided slices lengths match the `Self::COLUMNS` value. + fn from_values(lv: &[T], nv: &[T]) -> Self; +} + +pub struct StarkFrame { + local_values: [T; N], + next_values: [T; N], +} + +impl StarkEvaluationFrame for StarkFrame { + const COLUMNS: usize = N; + + fn get_local_values(&self) -> &[T] { + &self.local_values + } + + fn get_next_values(&self) -> &[T] { + &self.next_values + } + + fn from_values(lv: &[T], nv: &[T]) -> Self { + assert_eq!(lv.len(), Self::COLUMNS); + assert_eq!(nv.len(), Self::COLUMNS); + + Self { + local_values: lv.try_into().unwrap(), + next_values: nv.try_into().unwrap(), + } + } +} diff --git a/evm/src/fixed_recursive_verifier.rs b/evm/src/fixed_recursive_verifier.rs index 05ec015c..7fefe95f 100644 --- a/evm/src/fixed_recursive_verifier.rs +++ b/evm/src/fixed_recursive_verifier.rs @@ -28,17 +28,10 @@ use plonky2::util::timing::TimingTree; use plonky2_util::log2_ceil; use crate::all_stark::{all_cross_table_lookups, AllStark, Table, NUM_TABLES}; -use crate::arithmetic::arithmetic_stark::ArithmeticStark; -use crate::byte_packing::byte_packing_stark::BytePackingStark; use crate::config::StarkConfig; -use crate::cpu::cpu_stark::CpuStark; use crate::cross_table_lookup::{verify_cross_table_lookups_circuit, CrossTableLookup}; use crate::generation::GenerationInputs; use crate::get_challenges::observe_public_values_target; -use crate::keccak::keccak_stark::KeccakStark; -use crate::keccak_sponge::keccak_sponge_stark::KeccakSpongeStark; -use crate::logic::LogicStark; -use crate::memory::memory_stark::MemoryStark; use crate::permutation::{get_grand_product_challenge_set_target, GrandProductChallengeSet}; use crate::proof::{ BlockHashesTarget, BlockMetadataTarget, ExtraBlockDataTarget, PublicValues, PublicValuesTarget, @@ -297,13 +290,6 @@ where F: RichField + Extendable, C: GenericConfig + 'static, C::Hasher: AlgebraicHasher, - [(); ArithmeticStark::::COLUMNS]:, - [(); BytePackingStark::::COLUMNS]:, - [(); CpuStark::::COLUMNS]:, - [(); KeccakStark::::COLUMNS]:, - [(); KeccakSpongeStark::::COLUMNS]:, - [(); LogicStark::::COLUMNS]:, - [(); MemoryStark::::COLUMNS]:, { pub fn to_bytes( &self, @@ -1083,10 +1069,7 @@ where degree_bits_range: Range, all_ctls: &[CrossTableLookup], stark_config: &StarkConfig, - ) -> Self - where - [(); S::COLUMNS]:, - { + ) -> Self { let by_stark_size = degree_bits_range .map(|degree_bits| { ( @@ -1207,10 +1190,7 @@ where degree_bits: usize, all_ctls: &[CrossTableLookup], stark_config: &StarkConfig, - ) -> Self - where - [(); S::COLUMNS]:, - { + ) -> Self { let initial_wrapper = recursive_stark_circuit( table, stark, diff --git a/evm/src/keccak/keccak_stark.rs b/evm/src/keccak/keccak_stark.rs index 74f92622..c517a5f6 100644 --- a/evm/src/keccak/keccak_stark.rs +++ b/evm/src/keccak/keccak_stark.rs @@ -6,12 +6,14 @@ use plonky2::field::packed::PackedField; use plonky2::field::polynomial::PolynomialValues; use plonky2::field::types::Field; use plonky2::hash::hash_types::RichField; +use plonky2::iop::ext_target::ExtensionTarget; use plonky2::plonk::plonk_common::reduce_with_powers_ext_circuit; use plonky2::timed; use plonky2::util::timing::TimingTree; use crate::constraint_consumer::{ConstraintConsumer, RecursiveConstraintConsumer}; use crate::cross_table_lookup::Column; +use crate::evaluation_frame::{StarkEvaluationFrame, StarkFrame}; use crate::keccak::columns::{ reg_a, reg_a_prime, reg_a_prime_prime, reg_a_prime_prime_0_0_bit, reg_a_prime_prime_prime, reg_b, reg_c, reg_c_prime, reg_input_limb, reg_output_limb, reg_preimage, reg_step, @@ -24,7 +26,6 @@ use crate::keccak::logic::{ use crate::keccak::round_flags::{eval_round_flags, eval_round_flags_recursively}; use crate::stark::Stark; use crate::util::trace_rows_to_poly_values; -use crate::vars::{StarkEvaluationTargets, StarkEvaluationVars}; /// Number of rounds in a Keccak permutation. pub(crate) const NUM_ROUNDS: usize = 24; @@ -239,11 +240,16 @@ impl, const D: usize> KeccakStark { } impl, const D: usize> Stark for KeccakStark { - const COLUMNS: usize = NUM_COLUMNS; + type EvaluationFrame = StarkFrame + where + FE: FieldExtension, + P: PackedField; + + type EvaluationFrameTarget = StarkFrame, NUM_COLUMNS>; fn eval_packed_generic( &self, - vars: StarkEvaluationVars, + vars: &Self::EvaluationFrame, yield_constr: &mut ConstraintConsumer

, ) where FE: FieldExtension, @@ -251,33 +257,34 @@ impl, const D: usize> Stark for KeccakStark, const D: usize> Stark for KeccakStark, const D: usize> Stark for KeccakStark, const D: usize> Stark for KeccakStark, const D: usize> Stark for KeccakStark, const D: usize> Stark for KeccakStark, const D: usize> Stark for KeccakStark, const D: usize> Stark for KeccakStark, const D: usize> Stark for KeccakStark, const D: usize> Stark for KeccakStark, - vars: StarkEvaluationTargets, + vars: &Self::EvaluationFrameTarget, yield_constr: &mut RecursiveConstraintConsumer, ) { let one_ext = builder.one_extension(); @@ -433,49 +440,44 @@ impl, const D: usize> Stark for KeccakStark, const D: usize> Stark for KeccakStark, const D: usize> Stark for KeccakStark, const D: usize> Stark for KeccakStark, const D: usize> Stark for KeccakStark, const D: usize> Stark for KeccakStark, const D: usize> Stark for KeccakStark, const D: usize> Stark for KeccakStark, const D: usize> Stark for KeccakStark>( - vars: StarkEvaluationVars, + vars: &StarkFrame, yield_constr: &mut ConstraintConsumer

, ) { + let local_values = vars.get_local_values(); + let next_values = vars.get_next_values(); + // Initially, the first step flag should be 1 while the others should be 0. - yield_constr.constraint_first_row(vars.local_values[reg_step(0)] - F::ONE); + yield_constr.constraint_first_row(local_values[reg_step(0)] - F::ONE); for i in 1..NUM_ROUNDS { - yield_constr.constraint_first_row(vars.local_values[reg_step(i)]); + yield_constr.constraint_first_row(local_values[reg_step(i)]); } // Flags should circularly increment, or be all zero for padding rows. - let next_any_flag = (0..NUM_ROUNDS) - .map(|i| vars.next_values[reg_step(i)]) - .sum::

(); + let next_any_flag = (0..NUM_ROUNDS).map(|i| next_values[reg_step(i)]).sum::

(); for i in 0..NUM_ROUNDS { - let current_round_flag = vars.local_values[reg_step(i)]; - let next_round_flag = vars.next_values[reg_step((i + 1) % NUM_ROUNDS)]; + let current_round_flag = local_values[reg_step(i)]; + let next_round_flag = next_values[reg_step((i + 1) % NUM_ROUNDS)]; yield_constr.constraint_transition(next_any_flag * (next_round_flag - current_round_flag)); } // Padding rows should always be followed by padding rows. let current_any_flag = (0..NUM_ROUNDS) - .map(|i| vars.local_values[reg_step(i)]) + .map(|i| local_values[reg_step(i)]) .sum::

(); yield_constr.constraint_transition(next_any_flag * (current_any_flag - F::ONE)); } pub(crate) fn eval_round_flags_recursively, const D: usize>( builder: &mut CircuitBuilder, - vars: StarkEvaluationTargets, + vars: &StarkFrame, NUM_COLUMNS>, yield_constr: &mut RecursiveConstraintConsumer, ) { let one = builder.one_extension(); + let local_values = vars.get_local_values(); + let next_values = vars.get_next_values(); // Initially, the first step flag should be 1 while the others should be 0. - let step_0_minus_1 = builder.sub_extension(vars.local_values[reg_step(0)], one); + let step_0_minus_1 = builder.sub_extension(local_values[reg_step(0)], one); yield_constr.constraint_first_row(builder, step_0_minus_1); for i in 1..NUM_ROUNDS { - yield_constr.constraint_first_row(builder, vars.local_values[reg_step(i)]); + yield_constr.constraint_first_row(builder, local_values[reg_step(i)]); } // Flags should circularly increment, or be all zero for padding rows. let next_any_flag = - builder.add_many_extension((0..NUM_ROUNDS).map(|i| vars.next_values[reg_step(i)])); + builder.add_many_extension((0..NUM_ROUNDS).map(|i| next_values[reg_step(i)])); for i in 0..NUM_ROUNDS { - let current_round_flag = vars.local_values[reg_step(i)]; - let next_round_flag = vars.next_values[reg_step((i + 1) % NUM_ROUNDS)]; + let current_round_flag = local_values[reg_step(i)]; + let next_round_flag = next_values[reg_step((i + 1) % NUM_ROUNDS)]; let diff = builder.sub_extension(next_round_flag, current_round_flag); let constraint = builder.mul_extension(next_any_flag, diff); yield_constr.constraint_transition(builder, constraint); @@ -63,7 +67,7 @@ pub(crate) fn eval_round_flags_recursively, const D // Padding rows should always be followed by padding rows. let current_any_flag = - builder.add_many_extension((0..NUM_ROUNDS).map(|i| vars.local_values[reg_step(i)])); + builder.add_many_extension((0..NUM_ROUNDS).map(|i| local_values[reg_step(i)])); let constraint = builder.mul_sub_extension(next_any_flag, current_any_flag, next_any_flag); yield_constr.constraint_transition(builder, constraint); } diff --git a/evm/src/keccak_sponge/keccak_sponge_stark.rs b/evm/src/keccak_sponge/keccak_sponge_stark.rs index d78e9651..65edc941 100644 --- a/evm/src/keccak_sponge/keccak_sponge_stark.rs +++ b/evm/src/keccak_sponge/keccak_sponge_stark.rs @@ -17,10 +17,10 @@ use plonky2_util::ceil_div_usize; use crate::constraint_consumer::{ConstraintConsumer, RecursiveConstraintConsumer}; use crate::cpu::kernel::keccak_util::keccakf_u32s; use crate::cross_table_lookup::Column; +use crate::evaluation_frame::{StarkEvaluationFrame, StarkFrame}; use crate::keccak_sponge::columns::*; use crate::stark::Stark; use crate::util::trace_rows_to_poly_values; -use crate::vars::{StarkEvaluationTargets, StarkEvaluationVars}; use crate::witness::memory::MemoryAddress; pub(crate) fn ctl_looked_data() -> Vec> { @@ -423,18 +423,27 @@ impl, const D: usize> KeccakSpongeStark { } impl, const D: usize> Stark for KeccakSpongeStark { - const COLUMNS: usize = NUM_KECCAK_SPONGE_COLUMNS; + type EvaluationFrame = StarkFrame + where + FE: FieldExtension, + P: PackedField; + + type EvaluationFrameTarget = StarkFrame, NUM_KECCAK_SPONGE_COLUMNS>; fn eval_packed_generic( &self, - vars: StarkEvaluationVars, + vars: &Self::EvaluationFrame, yield_constr: &mut ConstraintConsumer

, ) where FE: FieldExtension, P: PackedField, { - let local_values: &KeccakSpongeColumnsView

= vars.local_values.borrow(); - let next_values: &KeccakSpongeColumnsView

= vars.next_values.borrow(); + let local_values = + TryInto::<[P; NUM_KECCAK_SPONGE_COLUMNS]>::try_into(vars.get_local_values()).unwrap(); + let local_values: &KeccakSpongeColumnsView

= local_values.borrow(); + let next_values = + TryInto::<[P; NUM_KECCAK_SPONGE_COLUMNS]>::try_into(vars.get_next_values()).unwrap(); + let next_values: &KeccakSpongeColumnsView

= next_values.borrow(); // Each flag (full-input block, final block or implied dummy flag) must be boolean. let is_full_input_block = local_values.is_full_input_block; @@ -537,11 +546,19 @@ impl, const D: usize> Stark for KeccakSpongeS fn eval_ext_circuit( &self, builder: &mut plonky2::plonk::circuit_builder::CircuitBuilder, - vars: StarkEvaluationTargets, + vars: &Self::EvaluationFrameTarget, yield_constr: &mut RecursiveConstraintConsumer, ) { - let local_values: &KeccakSpongeColumnsView> = vars.local_values.borrow(); - let next_values: &KeccakSpongeColumnsView> = vars.next_values.borrow(); + let local_values = TryInto::<[ExtensionTarget; NUM_KECCAK_SPONGE_COLUMNS]>::try_into( + vars.get_local_values(), + ) + .unwrap(); + let local_values: &KeccakSpongeColumnsView> = local_values.borrow(); + let next_values = TryInto::<[ExtensionTarget; NUM_KECCAK_SPONGE_COLUMNS]>::try_into( + vars.get_next_values(), + ) + .unwrap(); + let next_values: &KeccakSpongeColumnsView> = next_values.borrow(); let one = builder.one_extension(); diff --git a/evm/src/lib.rs b/evm/src/lib.rs index ab48cda0..474d6faf 100644 --- a/evm/src/lib.rs +++ b/evm/src/lib.rs @@ -4,7 +4,6 @@ #![allow(clippy::type_complexity)] #![allow(clippy::field_reassign_with_default)] #![feature(let_chains)] -#![feature(generic_const_exprs)] pub mod all_stark; pub mod arithmetic; @@ -14,6 +13,7 @@ pub mod constraint_consumer; pub mod cpu; pub mod cross_table_lookup; pub mod curve_pairings; +pub mod evaluation_frame; pub mod extension_tower; pub mod fixed_recursive_verifier; pub mod generation; @@ -31,7 +31,6 @@ pub mod stark; pub mod stark_testing; pub mod util; pub mod vanishing_poly; -pub mod vars; pub mod verifier; pub mod witness; diff --git a/evm/src/logic.rs b/evm/src/logic.rs index 3529ed9a..319dfab2 100644 --- a/evm/src/logic.rs +++ b/evm/src/logic.rs @@ -7,16 +7,17 @@ use plonky2::field::packed::PackedField; use plonky2::field::polynomial::PolynomialValues; use plonky2::field::types::Field; use plonky2::hash::hash_types::RichField; +use plonky2::iop::ext_target::ExtensionTarget; use plonky2::timed; use plonky2::util::timing::TimingTree; use plonky2_util::ceil_div_usize; use crate::constraint_consumer::{ConstraintConsumer, RecursiveConstraintConsumer}; use crate::cross_table_lookup::Column; +use crate::evaluation_frame::{StarkEvaluationFrame, StarkFrame}; use crate::logic::columns::NUM_COLUMNS; use crate::stark::Stark; use crate::util::{limb_from_bits_le, limb_from_bits_le_recursive, trace_rows_to_poly_values}; -use crate::vars::{StarkEvaluationTargets, StarkEvaluationVars}; // Total number of bits per input/output. const VAL_BITS: usize = 256; @@ -181,17 +182,22 @@ impl LogicStark { } impl, const D: usize> Stark for LogicStark { - const COLUMNS: usize = NUM_COLUMNS; + type EvaluationFrame = StarkFrame + where + FE: FieldExtension, + P: PackedField; + + type EvaluationFrameTarget = StarkFrame, NUM_COLUMNS>; fn eval_packed_generic( &self, - vars: StarkEvaluationVars, + vars: &Self::EvaluationFrame, yield_constr: &mut ConstraintConsumer

, ) where FE: FieldExtension, P: PackedField, { - let lv = &vars.local_values; + let lv = vars.get_local_values(); // IS_AND, IS_OR, and IS_XOR come from the CPU table, so we assume they're valid. let is_and = lv[columns::IS_AND]; @@ -237,10 +243,10 @@ impl, const D: usize> Stark for LogicStark, - vars: StarkEvaluationTargets, + vars: &Self::EvaluationFrameTarget, yield_constr: &mut RecursiveConstraintConsumer, ) { - let lv = &vars.local_values; + let lv = vars.get_local_values(); // IS_AND, IS_OR, and IS_XOR come from the CPU table, so we assume they're valid. let is_and = lv[columns::IS_AND]; diff --git a/evm/src/lookup.rs b/evm/src/lookup.rs index d7e12bac..d6c1b217 100644 --- a/evm/src/lookup.rs +++ b/evm/src/lookup.rs @@ -5,20 +5,24 @@ use plonky2::field::extension::Extendable; use plonky2::field::packed::PackedField; use plonky2::field::types::{Field, PrimeField64}; use plonky2::hash::hash_types::RichField; +use plonky2::iop::ext_target::ExtensionTarget; use plonky2::plonk::circuit_builder::CircuitBuilder; use crate::constraint_consumer::{ConstraintConsumer, RecursiveConstraintConsumer}; -use crate::vars::{StarkEvaluationTargets, StarkEvaluationVars}; +use crate::evaluation_frame::StarkEvaluationFrame; -pub(crate) fn eval_lookups, const COLS: usize>( - vars: StarkEvaluationVars, +pub(crate) fn eval_lookups, E: StarkEvaluationFrame

>( + vars: &E, yield_constr: &mut ConstraintConsumer

, col_permuted_input: usize, col_permuted_table: usize, ) { - let local_perm_input = vars.local_values[col_permuted_input]; - let next_perm_table = vars.next_values[col_permuted_table]; - let next_perm_input = vars.next_values[col_permuted_input]; + let local_values = vars.get_local_values(); + let next_values = vars.get_next_values(); + + let local_perm_input = local_values[col_permuted_input]; + let next_perm_table = next_values[col_permuted_table]; + let next_perm_input = next_values[col_permuted_input]; // A "vertical" diff between the local and next permuted inputs. let diff_input_prev = next_perm_input - local_perm_input; @@ -35,18 +39,21 @@ pub(crate) fn eval_lookups, const COLS: usi pub(crate) fn eval_lookups_circuit< F: RichField + Extendable, + E: StarkEvaluationFrame>, const D: usize, - const COLS: usize, >( builder: &mut CircuitBuilder, - vars: StarkEvaluationTargets, + vars: &E, yield_constr: &mut RecursiveConstraintConsumer, col_permuted_input: usize, col_permuted_table: usize, ) { - let local_perm_input = vars.local_values[col_permuted_input]; - let next_perm_table = vars.next_values[col_permuted_table]; - let next_perm_input = vars.next_values[col_permuted_input]; + let local_values = vars.get_local_values(); + let next_values = vars.get_next_values(); + + let local_perm_input = local_values[col_permuted_input]; + let next_perm_table = next_values[col_permuted_table]; + let next_perm_input = next_values[col_permuted_input]; // A "vertical" diff between the local and next permuted inputs. let diff_input_prev = builder.sub_extension(next_perm_input, local_perm_input); diff --git a/evm/src/memory/memory_stark.rs b/evm/src/memory/memory_stark.rs index 36f75665..fde93be0 100644 --- a/evm/src/memory/memory_stark.rs +++ b/evm/src/memory/memory_stark.rs @@ -7,6 +7,7 @@ use plonky2::field::packed::PackedField; use plonky2::field::polynomial::PolynomialValues; use plonky2::field::types::Field; use plonky2::hash::hash_types::RichField; +use plonky2::iop::ext_target::ExtensionTarget; use plonky2::timed; use plonky2::util::timing::TimingTree; use plonky2::util::transpose; @@ -14,6 +15,7 @@ use plonky2_maybe_rayon::*; use crate::constraint_consumer::{ConstraintConsumer, RecursiveConstraintConsumer}; use crate::cross_table_lookup::Column; +use crate::evaluation_frame::{StarkEvaluationFrame, StarkFrame}; use crate::lookup::{eval_lookups, eval_lookups_circuit, permuted_cols}; use crate::memory::columns::{ value_limb, ADDR_CONTEXT, ADDR_SEGMENT, ADDR_VIRTUAL, CONTEXT_FIRST_CHANGE, COUNTER, @@ -23,7 +25,6 @@ use crate::memory::columns::{ use crate::memory::VALUE_LIMBS; use crate::permutation::PermutationPair; use crate::stark::Stark; -use crate::vars::{StarkEvaluationTargets, StarkEvaluationVars}; use crate::witness::memory::MemoryOpKind::Read; use crate::witness::memory::{MemoryAddress, MemoryOp}; @@ -238,48 +239,55 @@ impl, const D: usize> MemoryStark { } impl, const D: usize> Stark for MemoryStark { - const COLUMNS: usize = NUM_COLUMNS; + type EvaluationFrame = StarkFrame + where + FE: FieldExtension, + P: PackedField; + + type EvaluationFrameTarget = StarkFrame, NUM_COLUMNS>; fn eval_packed_generic( &self, - vars: StarkEvaluationVars, + vars: &Self::EvaluationFrame, yield_constr: &mut ConstraintConsumer

, ) where FE: FieldExtension, P: PackedField, { let one = P::from(FE::ONE); + let local_values = vars.get_local_values(); + let next_values = vars.get_next_values(); - let timestamp = vars.local_values[TIMESTAMP]; - let addr_context = vars.local_values[ADDR_CONTEXT]; - let addr_segment = vars.local_values[ADDR_SEGMENT]; - let addr_virtual = vars.local_values[ADDR_VIRTUAL]; - let values: Vec<_> = (0..8).map(|i| vars.local_values[value_limb(i)]).collect(); + let timestamp = local_values[TIMESTAMP]; + let addr_context = local_values[ADDR_CONTEXT]; + let addr_segment = local_values[ADDR_SEGMENT]; + let addr_virtual = local_values[ADDR_VIRTUAL]; + let value_limbs: Vec<_> = (0..8).map(|i| local_values[value_limb(i)]).collect(); - let next_timestamp = vars.next_values[TIMESTAMP]; - let next_is_read = vars.next_values[IS_READ]; - let next_addr_context = vars.next_values[ADDR_CONTEXT]; - let next_addr_segment = vars.next_values[ADDR_SEGMENT]; - let next_addr_virtual = vars.next_values[ADDR_VIRTUAL]; - let next_values: Vec<_> = (0..8).map(|i| vars.next_values[value_limb(i)]).collect(); + let next_timestamp = next_values[TIMESTAMP]; + let next_is_read = next_values[IS_READ]; + let next_addr_context = next_values[ADDR_CONTEXT]; + let next_addr_segment = next_values[ADDR_SEGMENT]; + let next_addr_virtual = next_values[ADDR_VIRTUAL]; + let next_values_limbs: Vec<_> = (0..8).map(|i| next_values[value_limb(i)]).collect(); // The filter must be 0 or 1. - let filter = vars.local_values[FILTER]; + let filter = local_values[FILTER]; yield_constr.constraint(filter * (filter - P::ONES)); // If this is a dummy row (filter is off), it must be a read. This means the prover can // insert reads which never appear in the CPU trace (which are harmless), but not writes. let is_dummy = P::ONES - filter; - let is_write = P::ONES - vars.local_values[IS_READ]; + let is_write = P::ONES - local_values[IS_READ]; yield_constr.constraint(is_dummy * is_write); - let context_first_change = vars.local_values[CONTEXT_FIRST_CHANGE]; - let segment_first_change = vars.local_values[SEGMENT_FIRST_CHANGE]; - let virtual_first_change = vars.local_values[VIRTUAL_FIRST_CHANGE]; + let context_first_change = local_values[CONTEXT_FIRST_CHANGE]; + let segment_first_change = local_values[SEGMENT_FIRST_CHANGE]; + let virtual_first_change = local_values[VIRTUAL_FIRST_CHANGE]; let address_unchanged = one - context_first_change - segment_first_change - virtual_first_change; - let range_check = vars.local_values[RANGE_CHECK]; + let range_check = local_values[RANGE_CHECK]; let not_context_first_change = one - context_first_change; let not_segment_first_change = one - segment_first_change; @@ -313,7 +321,7 @@ impl, const D: usize> Stark for MemoryStark, const D: usize> Stark for MemoryStark, - vars: StarkEvaluationTargets, + vars: &Self::EvaluationFrameTarget, yield_constr: &mut RecursiveConstraintConsumer, ) { let one = builder.one_extension(); + let local_values = vars.get_local_values(); + let next_values = vars.get_next_values(); - let addr_context = vars.local_values[ADDR_CONTEXT]; - let addr_segment = vars.local_values[ADDR_SEGMENT]; - let addr_virtual = vars.local_values[ADDR_VIRTUAL]; - let values: Vec<_> = (0..8).map(|i| vars.local_values[value_limb(i)]).collect(); - let timestamp = vars.local_values[TIMESTAMP]; + let addr_context = local_values[ADDR_CONTEXT]; + let addr_segment = local_values[ADDR_SEGMENT]; + let addr_virtual = local_values[ADDR_VIRTUAL]; + let value_limbs: Vec<_> = (0..8).map(|i| local_values[value_limb(i)]).collect(); + let timestamp = local_values[TIMESTAMP]; - let next_addr_context = vars.next_values[ADDR_CONTEXT]; - let next_addr_segment = vars.next_values[ADDR_SEGMENT]; - let next_addr_virtual = vars.next_values[ADDR_VIRTUAL]; - let next_values: Vec<_> = (0..8).map(|i| vars.next_values[value_limb(i)]).collect(); - let next_is_read = vars.next_values[IS_READ]; - let next_timestamp = vars.next_values[TIMESTAMP]; + let next_addr_context = next_values[ADDR_CONTEXT]; + let next_addr_segment = next_values[ADDR_SEGMENT]; + let next_addr_virtual = next_values[ADDR_VIRTUAL]; + let next_values_limbs: Vec<_> = (0..8).map(|i| next_values[value_limb(i)]).collect(); + let next_is_read = next_values[IS_READ]; + let next_timestamp = next_values[TIMESTAMP]; // The filter must be 0 or 1. - let filter = vars.local_values[FILTER]; + let filter = local_values[FILTER]; let constraint = builder.mul_sub_extension(filter, filter, filter); yield_constr.constraint(builder, constraint); // If this is a dummy row (filter is off), it must be a read. This means the prover can // insert reads which never appear in the CPU trace (which are harmless), but not writes. let is_dummy = builder.sub_extension(one, filter); - let is_write = builder.sub_extension(one, vars.local_values[IS_READ]); + let is_write = builder.sub_extension(one, local_values[IS_READ]); let is_dummy_write = builder.mul_extension(is_dummy, is_write); yield_constr.constraint(builder, is_dummy_write); - let context_first_change = vars.local_values[CONTEXT_FIRST_CHANGE]; - let segment_first_change = vars.local_values[SEGMENT_FIRST_CHANGE]; - let virtual_first_change = vars.local_values[VIRTUAL_FIRST_CHANGE]; + let context_first_change = local_values[CONTEXT_FIRST_CHANGE]; + let segment_first_change = local_values[SEGMENT_FIRST_CHANGE]; + let virtual_first_change = local_values[VIRTUAL_FIRST_CHANGE]; let address_unchanged = { let mut cur = builder.sub_extension(one, context_first_change); cur = builder.sub_extension(cur, segment_first_change); builder.sub_extension(cur, virtual_first_change) }; - let range_check = vars.local_values[RANGE_CHECK]; + let range_check = local_values[RANGE_CHECK]; let not_context_first_change = builder.sub_extension(one, context_first_change); let not_segment_first_change = builder.sub_extension(one, segment_first_change); @@ -433,7 +443,7 @@ impl, const D: usize> Stark for MemoryStark( stark: &S, config: &StarkConfig, - vars: StarkEvaluationVars, + vars: &S::EvaluationFrame, permutation_vars: PermutationCheckVars, consumer: &mut ConstraintConsumer

, ) where @@ -335,6 +335,8 @@ pub(crate) fn eval_permutation_checks, S: Stark, { + let local_values = vars.get_local_values(); + let PermutationCheckVars { local_zs, next_zs, @@ -368,7 +370,7 @@ pub(crate) fn eval_permutation_checks, Vec<_>) = column_pairs .iter() - .map(|&(i, j)| (vars.local_values[i], vars.local_values[j])) + .map(|&(i, j)| (local_values[i], local_values[j])) .unzip(); ( factor.reduce_ext(lhs.into_iter()) + FE::from_basefield(*gamma), @@ -392,14 +394,15 @@ pub(crate) fn eval_permutation_checks_circuit( builder: &mut CircuitBuilder, stark: &S, config: &StarkConfig, - vars: StarkEvaluationTargets, + vars: &S::EvaluationFrameTarget, permutation_data: PermutationCheckDataTarget, consumer: &mut RecursiveConstraintConsumer, ) where F: RichField + Extendable, S: Stark, - [(); S::COLUMNS]:, { + let local_values = vars.get_local_values(); + let PermutationCheckDataTarget { local_zs, next_zs, @@ -437,7 +440,7 @@ pub(crate) fn eval_permutation_checks_circuit( let mut factor = ReducingFactorTarget::new(beta_ext); let (lhs, rhs): (Vec<_>, Vec<_>) = column_pairs .iter() - .map(|&(i, j)| (vars.local_values[i], vars.local_values[j])) + .map(|&(i, j)| (local_values[i], local_values[j])) .unzip(); let reduced_lhs = factor.reduce(&lhs, builder); let reduced_rhs = factor.reduce(&rhs, builder); diff --git a/evm/src/prover.rs b/evm/src/prover.rs index 7b960c95..d7368c78 100644 --- a/evm/src/prover.rs +++ b/evm/src/prover.rs @@ -20,20 +20,14 @@ use plonky2_maybe_rayon::*; use plonky2_util::{log2_ceil, log2_strict}; use crate::all_stark::{AllStark, Table, NUM_TABLES}; -use crate::arithmetic::arithmetic_stark::ArithmeticStark; -use crate::byte_packing::byte_packing_stark::BytePackingStark; use crate::config::StarkConfig; use crate::constraint_consumer::ConstraintConsumer; -use crate::cpu::cpu_stark::CpuStark; use crate::cpu::kernel::aggregator::KERNEL; use crate::cross_table_lookup::{cross_table_lookup_data, CtlCheckVars, CtlData}; +use crate::evaluation_frame::StarkEvaluationFrame; use crate::generation::outputs::GenerationOutputs; use crate::generation::{generate_traces, GenerationInputs}; use crate::get_challenges::observe_public_values; -use crate::keccak::keccak_stark::KeccakStark; -use crate::keccak_sponge::keccak_sponge_stark::KeccakSpongeStark; -use crate::logic::LogicStark; -use crate::memory::memory_stark::MemoryStark; use crate::permutation::{ compute_permutation_z_polys, get_grand_product_challenge_set, get_n_grand_product_challenge_sets, GrandProductChallengeSet, PermutationCheckVars, @@ -41,7 +35,6 @@ use crate::permutation::{ use crate::proof::{AllProof, PublicValues, StarkOpeningSet, StarkProof, StarkProofWithMetadata}; use crate::stark::Stark; use crate::vanishing_poly::eval_vanishing_poly; -use crate::vars::StarkEvaluationVars; /// Generate traces, then create all STARK proofs. pub fn prove( @@ -53,13 +46,6 @@ pub fn prove( where F: RichField + Extendable, C: GenericConfig, - [(); ArithmeticStark::::COLUMNS]:, - [(); BytePackingStark::::COLUMNS]:, - [(); CpuStark::::COLUMNS]:, - [(); KeccakStark::::COLUMNS]:, - [(); KeccakSpongeStark::::COLUMNS]:, - [(); LogicStark::::COLUMNS]:, - [(); MemoryStark::::COLUMNS]:, { let (proof, _outputs) = prove_with_outputs(all_stark, config, inputs, timing)?; Ok(proof) @@ -76,13 +62,6 @@ pub fn prove_with_outputs( where F: RichField + Extendable, C: GenericConfig, - [(); ArithmeticStark::::COLUMNS]:, - [(); BytePackingStark::::COLUMNS]:, - [(); CpuStark::::COLUMNS]:, - [(); KeccakStark::::COLUMNS]:, - [(); KeccakSpongeStark::::COLUMNS]:, - [(); LogicStark::::COLUMNS]:, - [(); MemoryStark::::COLUMNS]:, { timed!(timing, "build kernel", Lazy::force(&KERNEL)); let (traces, public_values, outputs) = timed!( @@ -105,13 +84,6 @@ pub(crate) fn prove_with_traces( where F: RichField + Extendable, C: GenericConfig, - [(); ArithmeticStark::::COLUMNS]:, - [(); BytePackingStark::::COLUMNS]:, - [(); CpuStark::::COLUMNS]:, - [(); KeccakStark::::COLUMNS]:, - [(); KeccakSpongeStark::::COLUMNS]:, - [(); LogicStark::::COLUMNS]:, - [(); MemoryStark::::COLUMNS]:, { let rate_bits = config.fri_config.rate_bits; let cap_height = config.fri_config.cap_height; @@ -197,13 +169,6 @@ fn prove_with_commitments( where F: RichField + Extendable, C: GenericConfig, - [(); ArithmeticStark::::COLUMNS]:, - [(); BytePackingStark::::COLUMNS]:, - [(); CpuStark::::COLUMNS]:, - [(); KeccakStark::::COLUMNS]:, - [(); KeccakSpongeStark::::COLUMNS]:, - [(); LogicStark::::COLUMNS]:, - [(); MemoryStark::::COLUMNS]:, { let arithmetic_proof = timed!( timing, @@ -322,7 +287,6 @@ where F: RichField + Extendable, C: GenericConfig, S: Stark, - [(); S::COLUMNS]:, { let degree = trace_poly_values[0].len(); let degree_bits = log2_strict(degree); @@ -507,7 +471,6 @@ where P: PackedField, C: GenericConfig, S: Stark, - [(); S::COLUMNS]:, { let degree = 1 << degree_bits; let rate_bits = config.fri_config.rate_bits; @@ -530,12 +493,8 @@ where let z_h_on_coset = ZeroPolyOnCoset::::new(degree_bits, quotient_degree_bits); // Retrieve the LDE values at index `i`. - let get_trace_values_packed = |i_start| -> [P; S::COLUMNS] { - trace_commitment - .get_lde_values_packed(i_start, step) - .try_into() - .unwrap() - }; + let get_trace_values_packed = + |i_start| -> Vec

{ trace_commitment.get_lde_values_packed(i_start, step) }; // Last element of the subgroup. let last = F::primitive_root_of_unity(degree_bits).inverse(); @@ -566,10 +525,10 @@ where lagrange_basis_first, lagrange_basis_last, ); - let vars = StarkEvaluationVars { - local_values: &get_trace_values_packed(i_start), - next_values: &get_trace_values_packed(i_next_start), - }; + let vars = S::EvaluationFrame::from_values( + &get_trace_values_packed(i_start), + &get_trace_values_packed(i_next_start), + ); let permutation_check_vars = permutation_challenges.map(|permutation_challenge_sets| PermutationCheckVars { local_zs: permutation_ctl_zs_commitment.get_lde_values_packed(i_start, step) @@ -597,7 +556,7 @@ where eval_vanishing_poly::( stark, config, - vars, + &vars, permutation_check_vars, &ctl_vars, &mut consumer, @@ -642,7 +601,6 @@ fn check_constraints<'a, F, C, S, const D: usize>( F: RichField + Extendable, C: GenericConfig, S: Stark, - [(); S::COLUMNS]:, { let degree = 1 << degree_bits; let rate_bits = 0; // Set this to higher value to check constraint degree. @@ -688,10 +646,10 @@ fn check_constraints<'a, F, C, S, const D: usize>( lagrange_basis_first, lagrange_basis_last, ); - let vars = StarkEvaluationVars { - local_values: trace_subgroup_evals[i].as_slice().try_into().unwrap(), - next_values: trace_subgroup_evals[i_next].as_slice().try_into().unwrap(), - }; + let vars = S::EvaluationFrame::from_values( + &trace_subgroup_evals[i], + &trace_subgroup_evals[i_next], + ); let permutation_check_vars = permutation_challenges.map(|permutation_challenge_sets| PermutationCheckVars { local_zs: permutation_ctl_zs_subgroup_evals[i][..num_permutation_zs].to_vec(), @@ -715,7 +673,7 @@ fn check_constraints<'a, F, C, S, const D: usize>( eval_vanishing_poly::( stark, config, - vars, + &vars, permutation_check_vars, &ctl_vars, &mut consumer, diff --git a/evm/src/recursive_verifier.rs b/evm/src/recursive_verifier.rs index 76a92338..f4e76c39 100644 --- a/evm/src/recursive_verifier.rs +++ b/evm/src/recursive_verifier.rs @@ -30,6 +30,7 @@ use crate::config::StarkConfig; use crate::constraint_consumer::RecursiveConstraintConsumer; use crate::cpu::kernel::constants::global_metadata::GlobalMetadata; use crate::cross_table_lookup::{verify_cross_table_lookups, CrossTableLookup, CtlCheckVarsTarget}; +use crate::evaluation_frame::StarkEvaluationFrame; use crate::memory::segments::Segment; use crate::memory::VALUE_LIMBS; use crate::permutation::{ @@ -45,7 +46,6 @@ use crate::proof::{ use crate::stark::Stark; use crate::util::{h256_limbs, u256_limbs, u256_to_u32, u256_to_u64}; use crate::vanishing_poly::eval_vanishing_poly_circuit; -use crate::vars::StarkEvaluationTargets; use crate::witness::errors::ProgramError; /// Table-wise recursive proofs of an `AllProof`. @@ -297,7 +297,6 @@ pub(crate) fn recursive_stark_circuit< min_degree_bits: usize, ) -> StarkWrapperCircuit where - [(); S::COLUMNS]:, C::Hasher: AlgebraicHasher, { let mut builder = CircuitBuilder::::new(circuit_config.clone()); @@ -405,7 +404,6 @@ fn verify_stark_proof_with_challenges_circuit< inner_config: &StarkConfig, ) where C::Hasher: AlgebraicHasher, - [(); S::COLUMNS]:, { let zero = builder.zero(); let one = builder.one_extension(); @@ -418,10 +416,7 @@ fn verify_stark_proof_with_challenges_circuit< ctl_zs_first, quotient_polys, } = &proof.openings; - let vars = StarkEvaluationTargets { - local_values: &local_values.to_vec().try_into().unwrap(), - next_values: &next_values.to_vec().try_into().unwrap(), - }; + let vars = S::EvaluationFrameTarget::from_values(local_values, next_values); let degree_bits = proof.recover_degree_bits(inner_config); let zeta_pow_deg = builder.exp_power_of_2_extension(challenges.stark_zeta, degree_bits); @@ -456,7 +451,7 @@ fn verify_stark_proof_with_challenges_circuit< builder, stark, inner_config, - vars, + &vars, permutation_data, ctl_vars, &mut consumer, diff --git a/evm/src/stark.rs b/evm/src/stark.rs index 73b51ada..b3ea818f 100644 --- a/evm/src/stark.rs +++ b/evm/src/stark.rs @@ -12,8 +12,8 @@ use plonky2_util::ceil_div_usize; use crate::config::StarkConfig; use crate::constraint_consumer::{ConstraintConsumer, RecursiveConstraintConsumer}; +use crate::evaluation_frame::StarkEvaluationFrame; use crate::permutation::PermutationPair; -use crate::vars::{StarkEvaluationTargets, StarkEvaluationVars}; const TRACE_ORACLE_INDEX: usize = 0; const PERMUTATION_CTL_ORACLE_INDEX: usize = 1; @@ -22,7 +22,16 @@ const QUOTIENT_ORACLE_INDEX: usize = 2; /// Represents a STARK system. pub trait Stark, const D: usize>: Sync { /// The total number of columns in the trace. - const COLUMNS: usize; + const COLUMNS: usize = Self::EvaluationFrameTarget::COLUMNS; + + /// This is used to evaluate constraints natively. + type EvaluationFrame: StarkEvaluationFrame

+ where + FE: FieldExtension, + P: PackedField; + + /// The `Target` version of `Self::EvaluationFrame`, used to evaluate constraints recursively. + type EvaluationFrameTarget: StarkEvaluationFrame>; /// Evaluate constraints at a vector of points. /// @@ -32,7 +41,7 @@ pub trait Stark, const D: usize>: Sync { /// constraints over `F`. fn eval_packed_generic( &self, - vars: StarkEvaluationVars, + vars: &Self::EvaluationFrame, yield_constr: &mut ConstraintConsumer

, ) where FE: FieldExtension, @@ -41,7 +50,7 @@ pub trait Stark, const D: usize>: Sync { /// Evaluate constraints at a vector of points from the base field `F`. fn eval_packed_base>( &self, - vars: StarkEvaluationVars, + vars: &Self::EvaluationFrame, yield_constr: &mut ConstraintConsumer

, ) { self.eval_packed_generic(vars, yield_constr) @@ -50,7 +59,7 @@ pub trait Stark, const D: usize>: Sync { /// Evaluate constraints at a single point from the degree `D` extension field. fn eval_ext( &self, - vars: StarkEvaluationVars, + vars: &Self::EvaluationFrame, yield_constr: &mut ConstraintConsumer, ) { self.eval_packed_generic(vars, yield_constr) @@ -63,7 +72,7 @@ pub trait Stark, const D: usize>: Sync { fn eval_ext_circuit( &self, builder: &mut CircuitBuilder, - vars: StarkEvaluationTargets, + vars: &Self::EvaluationFrameTarget, yield_constr: &mut RecursiveConstraintConsumer, ); diff --git a/evm/src/stark_testing.rs b/evm/src/stark_testing.rs index e005d2ea..5fe44127 100644 --- a/evm/src/stark_testing.rs +++ b/evm/src/stark_testing.rs @@ -3,17 +3,16 @@ use plonky2::field::extension::{Extendable, FieldExtension}; use plonky2::field::polynomial::{PolynomialCoeffs, PolynomialValues}; use plonky2::field::types::{Field, Sample}; use plonky2::hash::hash_types::RichField; -use plonky2::hash::hashing::PlonkyPermutation; use plonky2::iop::witness::{PartialWitness, WitnessWrite}; use plonky2::plonk::circuit_builder::CircuitBuilder; use plonky2::plonk::circuit_data::CircuitConfig; -use plonky2::plonk::config::{GenericConfig, Hasher}; +use plonky2::plonk::config::GenericConfig; use plonky2::util::transpose; use plonky2_util::{log2_ceil, log2_strict}; use crate::constraint_consumer::{ConstraintConsumer, RecursiveConstraintConsumer}; +use crate::evaluation_frame::StarkEvaluationFrame; use crate::stark::Stark; -use crate::vars::{StarkEvaluationTargets, StarkEvaluationVars}; const WITNESS_SIZE: usize = 1 << 5; @@ -21,10 +20,7 @@ const WITNESS_SIZE: usize = 1 << 5; /// low-degree witness polynomials. pub fn test_stark_low_degree, S: Stark, const D: usize>( stark: S, -) -> Result<()> -where - [(); S::COLUMNS]:, -{ +) -> Result<()> { let rate_bits = log2_ceil(stark.constraint_degree() + 1); let trace_ldes = random_low_degree_matrix::(S::COLUMNS, rate_bits); @@ -39,13 +35,10 @@ where let alpha = F::rand(); let constraint_evals = (0..size) .map(|i| { - let vars = StarkEvaluationVars { - local_values: &trace_ldes[i].clone().try_into().unwrap(), - next_values: &trace_ldes[(i + (1 << rate_bits)) % size] - .clone() - .try_into() - .unwrap(), - }; + let vars = S::EvaluationFrame::from_values( + &trace_ldes[i], + &trace_ldes[(i + (1 << rate_bits)) % size], + ); let mut consumer = ConstraintConsumer::::new( vec![alpha], @@ -53,7 +46,7 @@ where lagrange_first.values[i], lagrange_last.values[i], ); - stark.eval_packed_base(vars, &mut consumer); + stark.eval_packed_base(&vars, &mut consumer); consumer.accumulators()[0] }) .collect::>(); @@ -84,17 +77,13 @@ pub fn test_stark_circuit_constraints< const D: usize, >( stark: S, -) -> Result<()> -where - [(); S::COLUMNS]:, - [(); >::Permutation::WIDTH]:, - [(); >::Permutation::WIDTH]:, -{ +) -> Result<()> { // Compute native constraint evaluation on random values. - let vars = StarkEvaluationVars { - local_values: &F::Extension::rand_array::<{ S::COLUMNS }>(), - next_values: &F::Extension::rand_array::<{ S::COLUMNS }>(), - }; + let vars = S::EvaluationFrame::from_values( + &F::Extension::rand_vec(S::COLUMNS), + &F::Extension::rand_vec(S::COLUMNS), + ); + let alphas = F::rand_vec(1); let z_last = F::Extension::rand(); let lagrange_first = F::Extension::rand(); @@ -109,7 +98,7 @@ where lagrange_first, lagrange_last, ); - stark.eval_ext(vars, &mut consumer); + stark.eval_ext(&vars, &mut consumer); let native_eval = consumer.accumulators()[0]; // Compute circuit constraint evaluation on same random values. @@ -118,9 +107,9 @@ where let mut pw = PartialWitness::::new(); let locals_t = builder.add_virtual_extension_targets(S::COLUMNS); - pw.set_extension_targets(&locals_t, vars.local_values); + pw.set_extension_targets(&locals_t, vars.get_local_values()); let nexts_t = builder.add_virtual_extension_targets(S::COLUMNS); - pw.set_extension_targets(&nexts_t, vars.next_values); + pw.set_extension_targets(&nexts_t, vars.get_next_values()); let alphas_t = builder.add_virtual_targets(1); pw.set_target(alphas_t[0], alphas[0]); let z_last_t = builder.add_virtual_extension_target(); @@ -130,10 +119,7 @@ where let lagrange_last_t = builder.add_virtual_extension_target(); pw.set_extension_target(lagrange_last_t, lagrange_last); - let vars = StarkEvaluationTargets:: { - local_values: &locals_t.try_into().unwrap(), - next_values: &nexts_t.try_into().unwrap(), - }; + let vars = S::EvaluationFrameTarget::from_values(&locals_t, &nexts_t); let mut consumer = RecursiveConstraintConsumer::::new( builder.zero_extension(), alphas_t, @@ -141,7 +127,7 @@ where lagrange_first_t, lagrange_last_t, ); - stark.eval_ext_circuit(&mut builder, vars, &mut consumer); + stark.eval_ext_circuit(&mut builder, &vars, &mut consumer); let circuit_eval = consumer.accumulators()[0]; let native_eval_t = builder.constant_extension(native_eval); builder.connect_extension(circuit_eval, native_eval_t); diff --git a/evm/src/vanishing_poly.rs b/evm/src/vanishing_poly.rs index 3a2da78c..6c88b16e 100644 --- a/evm/src/vanishing_poly.rs +++ b/evm/src/vanishing_poly.rs @@ -14,12 +14,11 @@ use crate::permutation::{ PermutationCheckVars, }; use crate::stark::Stark; -use crate::vars::{StarkEvaluationTargets, StarkEvaluationVars}; pub(crate) fn eval_vanishing_poly( stark: &S, config: &StarkConfig, - vars: StarkEvaluationVars, + vars: &S::EvaluationFrame, permutation_vars: Option>, ctl_vars: &[CtlCheckVars], consumer: &mut ConstraintConsumer

, @@ -46,14 +45,13 @@ pub(crate) fn eval_vanishing_poly_circuit( builder: &mut CircuitBuilder, stark: &S, config: &StarkConfig, - vars: StarkEvaluationTargets, + vars: &S::EvaluationFrameTarget, permutation_data: Option>, ctl_vars: &[CtlCheckVarsTarget], consumer: &mut RecursiveConstraintConsumer, ) where F: RichField + Extendable, S: Stark, - [(); S::COLUMNS]:, { stark.eval_ext_circuit(builder, vars, consumer); if let Some(permutation_data) = permutation_data { diff --git a/evm/src/vars.rs b/evm/src/vars.rs deleted file mode 100644 index 6c82675c..00000000 --- a/evm/src/vars.rs +++ /dev/null @@ -1,19 +0,0 @@ -use plonky2::field::packed::PackedField; -use plonky2::field::types::Field; -use plonky2::iop::ext_target::ExtensionTarget; - -#[derive(Debug, Copy, Clone)] -pub struct StarkEvaluationVars<'a, F, P, const COLUMNS: usize> -where - F: Field, - P: PackedField, -{ - pub local_values: &'a [P; COLUMNS], - pub next_values: &'a [P; COLUMNS], -} - -#[derive(Debug, Copy, Clone)] -pub struct StarkEvaluationTargets<'a, const D: usize, const COLUMNS: usize> { - pub local_values: &'a [ExtensionTarget; COLUMNS], - pub next_values: &'a [ExtensionTarget; COLUMNS], -} diff --git a/evm/src/verifier.rs b/evm/src/verifier.rs index 11f8155d..c7b58060 100644 --- a/evm/src/verifier.rs +++ b/evm/src/verifier.rs @@ -11,17 +11,11 @@ use plonky2::plonk::config::GenericConfig; use plonky2::plonk::plonk_common::reduce_with_powers; use crate::all_stark::{AllStark, Table, NUM_TABLES}; -use crate::arithmetic::arithmetic_stark::ArithmeticStark; -use crate::byte_packing::byte_packing_stark::BytePackingStark; use crate::config::StarkConfig; use crate::constraint_consumer::ConstraintConsumer; -use crate::cpu::cpu_stark::CpuStark; use crate::cpu::kernel::constants::global_metadata::GlobalMetadata; use crate::cross_table_lookup::{verify_cross_table_lookups, CtlCheckVars}; -use crate::keccak::keccak_stark::KeccakStark; -use crate::keccak_sponge::keccak_sponge_stark::KeccakSpongeStark; -use crate::logic::LogicStark; -use crate::memory::memory_stark::MemoryStark; +use crate::evaluation_frame::StarkEvaluationFrame; use crate::memory::segments::Segment; use crate::memory::VALUE_LIMBS; use crate::permutation::{GrandProductChallenge, PermutationCheckVars}; @@ -31,7 +25,6 @@ use crate::proof::{ use crate::stark::Stark; use crate::util::h2u; use crate::vanishing_poly::eval_vanishing_poly; -use crate::vars::StarkEvaluationVars; pub fn verify_proof, C: GenericConfig, const D: usize>( all_stark: &AllStark, @@ -39,13 +32,6 @@ pub fn verify_proof, C: GenericConfig, co config: &StarkConfig, ) -> Result<()> where - [(); ArithmeticStark::::COLUMNS]:, - [(); BytePackingStark::::COLUMNS]:, - [(); CpuStark::::COLUMNS]:, - [(); KeccakStark::::COLUMNS]:, - [(); KeccakSpongeStark::::COLUMNS]:, - [(); LogicStark::::COLUMNS]:, - [(); MemoryStark::::COLUMNS]:, { let AllProofChallenges { stark_challenges, @@ -301,10 +287,7 @@ pub(crate) fn verify_stark_proof_with_challenges< challenges: &StarkProofChallenges, ctl_vars: &[CtlCheckVars], config: &StarkConfig, -) -> Result<()> -where - [(); S::COLUMNS]:, -{ +) -> Result<()> { log::debug!("Checking proof: {}", type_name::()); validate_proof_shape(stark, proof, config, ctl_vars.len())?; let StarkOpeningSet { @@ -315,10 +298,7 @@ where ctl_zs_first, quotient_polys, } = &proof.openings; - let vars = StarkEvaluationVars { - local_values: &local_values.to_vec().try_into().unwrap(), - next_values: &next_values.to_vec().try_into().unwrap(), - }; + let vars = S::EvaluationFrame::from_values(local_values, next_values); let degree_bits = proof.recover_degree_bits(config); let (l_0, l_last) = eval_l_0_and_l_last(degree_bits, challenges.stark_zeta); @@ -343,7 +323,7 @@ where eval_vanishing_poly::( stark, config, - vars, + &vars, permutation_data, ctl_vars, &mut consumer, @@ -401,7 +381,6 @@ where F: RichField + Extendable, C: GenericConfig, S: Stark, - [(); S::COLUMNS]:, { let StarkProof { trace_cap, From 0abc3b92104893bb0f725178ad35f6bb94098f35 Mon Sep 17 00:00:00 2001 From: Robin Salen <30937548+Nashtare@users.noreply.github.com> Date: Fri, 22 Sep 2023 10:14:47 -0400 Subject: [PATCH 2/6] Apply comments (#1248) --- evm/src/cpu/cpu_stark.rs | 16 ++++++---------- evm/src/keccak_sponge/keccak_sponge_stark.rs | 20 ++++++++------------ 2 files changed, 14 insertions(+), 22 deletions(-) diff --git a/evm/src/cpu/cpu_stark.rs b/evm/src/cpu/cpu_stark.rs index 14bb6015..f23ff308 100644 --- a/evm/src/cpu/cpu_stark.rs +++ b/evm/src/cpu/cpu_stark.rs @@ -243,11 +243,9 @@ impl, const D: usize> Stark for CpuStark, P: PackedField, { - let local_values = - TryInto::<[P; NUM_CPU_COLUMNS]>::try_into(vars.get_local_values()).unwrap(); + let local_values: &[P; NUM_CPU_COLUMNS] = vars.get_local_values().try_into().unwrap(); let local_values: &CpuColumnsView

= local_values.borrow(); - let next_values = - TryInto::<[P; NUM_CPU_COLUMNS]>::try_into(vars.get_next_values()).unwrap(); + let next_values: &[P; NUM_CPU_COLUMNS] = vars.get_next_values().try_into().unwrap(); let next_values: &CpuColumnsView

= next_values.borrow(); bootstrap_kernel::eval_bootstrap_kernel_packed(local_values, next_values, yield_constr); @@ -276,13 +274,11 @@ impl, const D: usize> Stark for CpuStark, ) { - let local_values = - TryInto::<[ExtensionTarget; NUM_CPU_COLUMNS]>::try_into(vars.get_local_values()) - .unwrap(); + let local_values: &[ExtensionTarget; NUM_CPU_COLUMNS] = + vars.get_local_values().try_into().unwrap(); let local_values: &CpuColumnsView> = local_values.borrow(); - let next_values = - TryInto::<[ExtensionTarget; NUM_CPU_COLUMNS]>::try_into(vars.get_next_values()) - .unwrap(); + let next_values: &[ExtensionTarget; NUM_CPU_COLUMNS] = + vars.get_next_values().try_into().unwrap(); let next_values: &CpuColumnsView> = next_values.borrow(); bootstrap_kernel::eval_bootstrap_kernel_ext_circuit( diff --git a/evm/src/keccak_sponge/keccak_sponge_stark.rs b/evm/src/keccak_sponge/keccak_sponge_stark.rs index 65edc941..2ed31c1f 100644 --- a/evm/src/keccak_sponge/keccak_sponge_stark.rs +++ b/evm/src/keccak_sponge/keccak_sponge_stark.rs @@ -438,11 +438,11 @@ impl, const D: usize> Stark for KeccakSpongeS FE: FieldExtension, P: PackedField, { - let local_values = - TryInto::<[P; NUM_KECCAK_SPONGE_COLUMNS]>::try_into(vars.get_local_values()).unwrap(); + let local_values: &[P; NUM_KECCAK_SPONGE_COLUMNS] = + vars.get_local_values().try_into().unwrap(); let local_values: &KeccakSpongeColumnsView

= local_values.borrow(); - let next_values = - TryInto::<[P; NUM_KECCAK_SPONGE_COLUMNS]>::try_into(vars.get_next_values()).unwrap(); + let next_values: &[P; NUM_KECCAK_SPONGE_COLUMNS] = + vars.get_next_values().try_into().unwrap(); let next_values: &KeccakSpongeColumnsView

= next_values.borrow(); // Each flag (full-input block, final block or implied dummy flag) must be boolean. @@ -549,15 +549,11 @@ impl, const D: usize> Stark for KeccakSpongeS vars: &Self::EvaluationFrameTarget, yield_constr: &mut RecursiveConstraintConsumer, ) { - let local_values = TryInto::<[ExtensionTarget; NUM_KECCAK_SPONGE_COLUMNS]>::try_into( - vars.get_local_values(), - ) - .unwrap(); + let local_values: &[ExtensionTarget; NUM_KECCAK_SPONGE_COLUMNS] = + vars.get_local_values().try_into().unwrap(); let local_values: &KeccakSpongeColumnsView> = local_values.borrow(); - let next_values = TryInto::<[ExtensionTarget; NUM_KECCAK_SPONGE_COLUMNS]>::try_into( - vars.get_next_values(), - ) - .unwrap(); + let next_values: &[ExtensionTarget; NUM_KECCAK_SPONGE_COLUMNS] = + vars.get_next_values().try_into().unwrap(); let next_values: &KeccakSpongeColumnsView> = next_values.borrow(); let one = builder.one_extension(); From 8c78271f5c17fac12f55e1026819d3c61e9c1e81 Mon Sep 17 00:00:00 2001 From: wborgeaud Date: Mon, 25 Sep 2023 18:20:22 +0200 Subject: [PATCH 3/6] Add `random` value to block metadata and fix `sys_prevrandao` (#1207) * Add random to block metadata and fix `sys_prevrandao` * Minor * Observe block_random * Write block_random * cargo fmt * block_random: H256 * Move sys_prevrandao to metadata.asm and delete syscall_stubs.asm * Set block_random in set_block_metadata_target * Minor * Minor --- evm/src/cpu/kernel/aggregator.rs | 1 - evm/src/cpu/kernel/asm/core/syscall_stubs.asm | 12 ---- evm/src/cpu/kernel/asm/memory/metadata.asm | 7 +++ .../cpu/kernel/constants/global_metadata.rs | 59 ++++++++++--------- evm/src/generation/mod.rs | 4 ++ evm/src/get_challenges.rs | 2 + evm/src/proof.rs | 25 ++++++-- evm/src/recursive_verifier.rs | 14 ++++- evm/src/verifier.rs | 6 +- evm/tests/add11_yml.rs | 3 +- evm/tests/basic_smart_contract.rs | 1 + evm/tests/log_opcode.rs | 5 +- evm/tests/self_balance_gas_cost.rs | 1 + evm/tests/simple_transfer.rs | 3 +- 14 files changed, 90 insertions(+), 53 deletions(-) delete mode 100644 evm/src/cpu/kernel/asm/core/syscall_stubs.asm diff --git a/evm/src/cpu/kernel/aggregator.rs b/evm/src/cpu/kernel/aggregator.rs index 160702df..0c7c6579 100644 --- a/evm/src/cpu/kernel/aggregator.rs +++ b/evm/src/cpu/kernel/aggregator.rs @@ -36,7 +36,6 @@ pub(crate) fn combined_kernel() -> Kernel { include_str!("asm/core/nonce.asm"), include_str!("asm/core/process_txn.asm"), include_str!("asm/core/syscall.asm"), - include_str!("asm/core/syscall_stubs.asm"), include_str!("asm/core/terminate.asm"), include_str!("asm/core/transfer.asm"), include_str!("asm/core/util.asm"), diff --git a/evm/src/cpu/kernel/asm/core/syscall_stubs.asm b/evm/src/cpu/kernel/asm/core/syscall_stubs.asm deleted file mode 100644 index 70a64b87..00000000 --- a/evm/src/cpu/kernel/asm/core/syscall_stubs.asm +++ /dev/null @@ -1,12 +0,0 @@ -// Labels for unimplemented syscalls to make the kernel assemble. -// Each label should be removed from this file once it is implemented. - -// This is a temporary version that returns the block difficulty (i.e. the old version of this opcode). -// TODO: Fix this. -// TODO: What semantics will this have for Edge? -global sys_prevrandao: - // stack: kexit_info - %charge_gas_const(@GAS_BASE) - %mload_global_metadata(@GLOBAL_METADATA_BLOCK_DIFFICULTY) - %stack (difficulty, kexit_info) -> (kexit_info, difficulty) - EXIT_KERNEL diff --git a/evm/src/cpu/kernel/asm/memory/metadata.asm b/evm/src/cpu/kernel/asm/memory/metadata.asm index 5b1417da..c26d3d5f 100644 --- a/evm/src/cpu/kernel/asm/memory/metadata.asm +++ b/evm/src/cpu/kernel/asm/memory/metadata.asm @@ -383,3 +383,10 @@ zero_hash: %decrement %mstore_global_metadata(@GLOBAL_METADATA_CALL_STACK_DEPTH) %endmacro + +global sys_prevrandao: + // stack: kexit_info + %charge_gas_const(@GAS_BASE) + %mload_global_metadata(@GLOBAL_METADATA_BLOCK_RANDOM) + %stack (random, kexit_info) -> (kexit_info, random) + EXIT_KERNEL diff --git a/evm/src/cpu/kernel/constants/global_metadata.rs b/evm/src/cpu/kernel/constants/global_metadata.rs index 01f7d0dc..ad685222 100644 --- a/evm/src/cpu/kernel/constants/global_metadata.rs +++ b/evm/src/cpu/kernel/constants/global_metadata.rs @@ -39,55 +39,56 @@ pub(crate) enum GlobalMetadata { BlockTimestamp = 15, BlockNumber = 16, BlockDifficulty = 17, - BlockGasLimit = 18, - BlockChainId = 19, - BlockBaseFee = 20, - BlockGasUsed = 21, + BlockRandom = 18, + BlockGasLimit = 19, + BlockChainId = 20, + BlockBaseFee = 21, + BlockGasUsed = 22, /// Before current transactions block values. - BlockGasUsedBefore = 22, + BlockGasUsedBefore = 23, /// After current transactions block values. - BlockGasUsedAfter = 23, + BlockGasUsedAfter = 24, /// Current block header hash - BlockCurrentHash = 24, + BlockCurrentHash = 25, /// Gas to refund at the end of the transaction. - RefundCounter = 25, + RefundCounter = 26, /// Length of the addresses access list. - AccessedAddressesLen = 26, + AccessedAddressesLen = 27, /// Length of the storage keys access list. - AccessedStorageKeysLen = 27, + AccessedStorageKeysLen = 28, /// Length of the self-destruct list. - SelfDestructListLen = 28, + SelfDestructListLen = 29, /// Length of the bloom entry buffer. - BloomEntryLen = 29, + BloomEntryLen = 30, /// Length of the journal. - JournalLen = 30, + JournalLen = 31, /// Length of the `JournalData` segment. - JournalDataLen = 31, + JournalDataLen = 32, /// Current checkpoint. - CurrentCheckpoint = 32, - TouchedAddressesLen = 33, + CurrentCheckpoint = 33, + TouchedAddressesLen = 34, // Gas cost for the access list in type-1 txns. See EIP-2930. - AccessListDataCost = 34, + AccessListDataCost = 35, // Start of the access list in the RLP for type-1 txns. - AccessListRlpStart = 35, + AccessListRlpStart = 36, // Length of the access list in the RLP for type-1 txns. - AccessListRlpLen = 36, + AccessListRlpLen = 37, // Boolean flag indicating if the txn is a contract creation txn. - ContractCreation = 37, - IsPrecompileFromEoa = 38, - CallStackDepth = 39, + ContractCreation = 38, + IsPrecompileFromEoa = 39, + CallStackDepth = 40, /// Transaction logs list length - LogsLen = 40, - LogsDataLen = 41, - LogsPayloadLen = 42, - TxnNumberBefore = 43, - TxnNumberAfter = 44, + LogsLen = 41, + LogsDataLen = 42, + LogsPayloadLen = 43, + TxnNumberBefore = 44, + TxnNumberAfter = 45, } impl GlobalMetadata { - pub(crate) const COUNT: usize = 45; + pub(crate) const COUNT: usize = 46; pub(crate) fn all() -> [Self; Self::COUNT] { [ @@ -109,6 +110,7 @@ impl GlobalMetadata { Self::BlockTimestamp, Self::BlockNumber, Self::BlockDifficulty, + Self::BlockRandom, Self::BlockGasLimit, Self::BlockChainId, Self::BlockBaseFee, @@ -160,6 +162,7 @@ impl GlobalMetadata { Self::BlockTimestamp => "GLOBAL_METADATA_BLOCK_TIMESTAMP", Self::BlockNumber => "GLOBAL_METADATA_BLOCK_NUMBER", Self::BlockDifficulty => "GLOBAL_METADATA_BLOCK_DIFFICULTY", + Self::BlockRandom => "GLOBAL_METADATA_BLOCK_RANDOM", Self::BlockGasLimit => "GLOBAL_METADATA_BLOCK_GAS_LIMIT", Self::BlockChainId => "GLOBAL_METADATA_BLOCK_CHAIN_ID", Self::BlockBaseFee => "GLOBAL_METADATA_BLOCK_BASE_FEE", diff --git a/evm/src/generation/mod.rs b/evm/src/generation/mod.rs index 35078e07..3f5bafba 100644 --- a/evm/src/generation/mod.rs +++ b/evm/src/generation/mod.rs @@ -100,6 +100,10 @@ fn apply_metadata_and_tries_memops, const D: usize> (GlobalMetadata::BlockTimestamp, metadata.block_timestamp), (GlobalMetadata::BlockNumber, metadata.block_number), (GlobalMetadata::BlockDifficulty, metadata.block_difficulty), + ( + GlobalMetadata::BlockRandom, + metadata.block_random.into_uint(), + ), (GlobalMetadata::BlockGasLimit, metadata.block_gaslimit), (GlobalMetadata::BlockChainId, metadata.block_chain_id), (GlobalMetadata::BlockBaseFee, metadata.block_base_fee), diff --git a/evm/src/get_challenges.rs b/evm/src/get_challenges.rs index ab25a28d..1d2ae602 100644 --- a/evm/src/get_challenges.rs +++ b/evm/src/get_challenges.rs @@ -65,6 +65,7 @@ fn observe_block_metadata< challenger.observe_element(u256_to_u32(block_metadata.block_number)?); challenger.observe_element(u256_to_u32(block_metadata.block_difficulty)?); challenger.observe_element(u256_to_u32(block_metadata.block_gaslimit)?); + challenger.observe_elements(&h256_limbs::(block_metadata.block_random)); challenger.observe_element(u256_to_u32(block_metadata.block_chain_id)?); let basefee = u256_to_u64(block_metadata.block_base_fee)?; challenger.observe_element(basefee.0); @@ -91,6 +92,7 @@ fn observe_block_metadata_target< challenger.observe_element(block_metadata.block_timestamp); challenger.observe_element(block_metadata.block_number); challenger.observe_element(block_metadata.block_difficulty); + challenger.observe_elements(&block_metadata.block_random); challenger.observe_element(block_metadata.block_gaslimit); challenger.observe_element(block_metadata.block_chain_id); challenger.observe_elements(&block_metadata.block_base_fee); diff --git a/evm/src/proof.rs b/evm/src/proof.rs index 4da5ad23..03e520ca 100644 --- a/evm/src/proof.rs +++ b/evm/src/proof.rs @@ -101,6 +101,7 @@ pub struct BlockMetadata { pub block_number: U256, /// The difficulty (before PoS transition) of this block. pub block_difficulty: U256, + pub block_random: H256, /// The gas limit of this block. It must fit in a `u32`. pub block_gaslimit: U256, /// The chain id of this block. @@ -175,6 +176,7 @@ impl PublicValuesTarget { block_timestamp, block_number, block_difficulty, + block_random, block_gaslimit, block_chain_id, block_base_fee, @@ -186,6 +188,7 @@ impl PublicValuesTarget { buffer.write_target(block_timestamp)?; buffer.write_target(block_number)?; buffer.write_target(block_difficulty)?; + buffer.write_target_array(&block_random)?; buffer.write_target(block_gaslimit)?; buffer.write_target(block_chain_id)?; buffer.write_target_array(&block_base_fee)?; @@ -235,6 +238,7 @@ impl PublicValuesTarget { block_timestamp: buffer.read_target()?, block_number: buffer.read_target()?, block_difficulty: buffer.read_target()?, + block_random: buffer.read_target_array()?, block_gaslimit: buffer.read_target()?, block_chain_id: buffer.read_target()?, block_base_fee: buffer.read_target_array()?, @@ -407,6 +411,7 @@ pub struct BlockMetadataTarget { pub block_timestamp: Target, pub block_number: Target, pub block_difficulty: Target, + pub block_random: [Target; 8], pub block_gaslimit: Target, pub block_chain_id: Target, pub block_base_fee: [Target; 2], @@ -415,24 +420,26 @@ pub struct BlockMetadataTarget { } impl BlockMetadataTarget { - const SIZE: usize = 77; + const SIZE: usize = 85; pub fn from_public_inputs(pis: &[Target]) -> Self { let block_beneficiary = pis[0..5].try_into().unwrap(); let block_timestamp = pis[5]; let block_number = pis[6]; let block_difficulty = pis[7]; - let block_gaslimit = pis[8]; - let block_chain_id = pis[9]; - let block_base_fee = pis[10..12].try_into().unwrap(); - let block_gas_used = pis[12]; - let block_bloom = pis[13..77].try_into().unwrap(); + let block_random = pis[8..16].try_into().unwrap(); + let block_gaslimit = pis[16]; + let block_chain_id = pis[17]; + let block_base_fee = pis[18..20].try_into().unwrap(); + let block_gas_used = pis[20]; + let block_bloom = pis[21..85].try_into().unwrap(); Self { block_beneficiary, block_timestamp, block_number, block_difficulty, + block_random, block_gaslimit, block_chain_id, block_base_fee, @@ -458,6 +465,9 @@ impl BlockMetadataTarget { block_timestamp: builder.select(condition, bm0.block_timestamp, bm1.block_timestamp), block_number: builder.select(condition, bm0.block_number, bm1.block_number), block_difficulty: builder.select(condition, bm0.block_difficulty, bm1.block_difficulty), + block_random: core::array::from_fn(|i| { + builder.select(condition, bm0.block_random[i], bm1.block_random[i]) + }), block_gaslimit: builder.select(condition, bm0.block_gaslimit, bm1.block_gaslimit), block_chain_id: builder.select(condition, bm0.block_chain_id, bm1.block_chain_id), block_base_fee: core::array::from_fn(|i| { @@ -481,6 +491,9 @@ impl BlockMetadataTarget { builder.connect(bm0.block_timestamp, bm1.block_timestamp); builder.connect(bm0.block_number, bm1.block_number); builder.connect(bm0.block_difficulty, bm1.block_difficulty); + for i in 0..8 { + builder.connect(bm0.block_random[i], bm1.block_random[i]); + } builder.connect(bm0.block_gaslimit, bm1.block_gaslimit); builder.connect(bm0.block_chain_id, bm1.block_chain_id); for i in 0..2 { diff --git a/evm/src/recursive_verifier.rs b/evm/src/recursive_verifier.rs index f4e76c39..1457344c 100644 --- a/evm/src/recursive_verifier.rs +++ b/evm/src/recursive_verifier.rs @@ -548,11 +548,15 @@ pub(crate) fn get_memory_extra_looking_products_circuit< ), ]; - let beneficiary_base_fee_cur_hash_fields: [(usize, &[Target]); 3] = [ + let beneficiary_random_base_fee_cur_hash_fields: [(usize, &[Target]); 4] = [ ( GlobalMetadata::BlockBeneficiary as usize, &public_values.block_metadata.block_beneficiary, ), + ( + GlobalMetadata::BlockRandom as usize, + &public_values.block_metadata.block_random, + ), ( GlobalMetadata::BlockBaseFee as usize, &public_values.block_metadata.block_base_fee, @@ -576,7 +580,7 @@ pub(crate) fn get_memory_extra_looking_products_circuit< ); }); - beneficiary_base_fee_cur_hash_fields.map(|(field, targets)| { + beneficiary_random_base_fee_cur_hash_fields.map(|(field, targets)| { product = add_data_write( builder, challenge, @@ -772,6 +776,7 @@ pub(crate) fn add_virtual_block_metadata, const D: let block_timestamp = builder.add_virtual_public_input(); let block_number = builder.add_virtual_public_input(); let block_difficulty = builder.add_virtual_public_input(); + let block_random = builder.add_virtual_public_input_arr(); let block_gaslimit = builder.add_virtual_public_input(); let block_chain_id = builder.add_virtual_public_input(); let block_base_fee = builder.add_virtual_public_input_arr(); @@ -782,6 +787,7 @@ pub(crate) fn add_virtual_block_metadata, const D: block_timestamp, block_number, block_difficulty, + block_random, block_gaslimit, block_chain_id, block_base_fee, @@ -1014,6 +1020,10 @@ where block_metadata_target.block_difficulty, u256_to_u32(block_metadata.block_difficulty)?, ); + witness.set_target_arr( + &block_metadata_target.block_random, + &h256_limbs(block_metadata.block_random), + ); witness.set_target( block_metadata_target.block_gaslimit, u256_to_u32(block_metadata.block_gaslimit)?, diff --git a/evm/src/verifier.rs b/evm/src/verifier.rs index c7b58060..96ef2860 100644 --- a/evm/src/verifier.rs +++ b/evm/src/verifier.rs @@ -1,7 +1,7 @@ use std::any::type_name; use anyhow::{ensure, Result}; -use ethereum_types::U256; +use ethereum_types::{BigEndianHash, U256}; use itertools::Itertools; use plonky2::field::extension::{Extendable, FieldExtension}; use plonky2::field::types::Field; @@ -157,6 +157,10 @@ where GlobalMetadata::BlockNumber, public_values.block_metadata.block_number, ), + ( + GlobalMetadata::BlockRandom, + public_values.block_metadata.block_random.into_uint(), + ), ( GlobalMetadata::BlockDifficulty, public_values.block_metadata.block_difficulty, diff --git a/evm/tests/add11_yml.rs b/evm/tests/add11_yml.rs index f628e944..a456f02d 100644 --- a/evm/tests/add11_yml.rs +++ b/evm/tests/add11_yml.rs @@ -5,7 +5,7 @@ use std::time::Duration; use env_logger::{try_init_from_env, Env, DEFAULT_FILTER_ENV}; use eth_trie_utils::nibbles::Nibbles; use eth_trie_utils::partial_trie::{HashedPartialTrie, PartialTrie}; -use ethereum_types::{Address, H256}; +use ethereum_types::{Address, BigEndianHash, H256}; use hex_literal::hex; use keccak_hash::keccak; use plonky2::field::goldilocks_field::GoldilocksField; @@ -83,6 +83,7 @@ fn add11_yml() -> anyhow::Result<()> { block_timestamp: 0x03e8.into(), block_number: 1.into(), block_difficulty: 0x020000.into(), + block_random: H256::from_uint(&0x020000.into()), block_gaslimit: 0xff112233u32.into(), block_chain_id: 1.into(), block_base_fee: 0xa.into(), diff --git a/evm/tests/basic_smart_contract.rs b/evm/tests/basic_smart_contract.rs index 2cd549ff..0130c6fe 100644 --- a/evm/tests/basic_smart_contract.rs +++ b/evm/tests/basic_smart_contract.rs @@ -115,6 +115,7 @@ fn test_basic_smart_contract() -> anyhow::Result<()> { block_gas_used: gas_used.into(), block_bloom: [0.into(); 8], block_base_fee: 0xa.into(), + block_random: Default::default(), }; let mut contract_code = HashMap::new(); diff --git a/evm/tests/log_opcode.rs b/evm/tests/log_opcode.rs index 16d83bd0..67407807 100644 --- a/evm/tests/log_opcode.rs +++ b/evm/tests/log_opcode.rs @@ -8,7 +8,7 @@ use bytes::Bytes; use env_logger::{try_init_from_env, Env, DEFAULT_FILTER_ENV}; use eth_trie_utils::nibbles::Nibbles; use eth_trie_utils::partial_trie::{HashedPartialTrie, PartialTrie}; -use ethereum_types::{Address, H256, U256}; +use ethereum_types::{Address, BigEndianHash, H256, U256}; use hex_literal::hex; use keccak_hash::keccak; use plonky2::field::goldilocks_field::GoldilocksField; @@ -135,6 +135,7 @@ fn test_log_opcodes() -> anyhow::Result<()> { block_timestamp: 0x03e8.into(), block_number: 1.into(), block_difficulty: 0x020000.into(), + block_random: H256::from_uint(&0x020000.into()), block_gaslimit: 0xffffffffu32.into(), block_chain_id: 1.into(), block_base_fee: 0xa.into(), @@ -365,6 +366,7 @@ fn test_log_with_aggreg() -> anyhow::Result<()> { .unwrap(), U256::from_dec_str("2722259584404615024560450425766186844160").unwrap(), ], + block_random: Default::default(), }; let beneficiary_account_after = AccountRlp { @@ -791,6 +793,7 @@ fn test_two_txn() -> anyhow::Result<()> { block_timestamp: 0x03e8.into(), block_number: 1.into(), block_difficulty: 0x020000.into(), + block_random: H256::from_uint(&0x020000.into()), block_gaslimit: 0xffffffffu32.into(), block_chain_id: 1.into(), block_base_fee: 0xa.into(), diff --git a/evm/tests/self_balance_gas_cost.rs b/evm/tests/self_balance_gas_cost.rs index 9ba1ac54..de16db94 100644 --- a/evm/tests/self_balance_gas_cost.rs +++ b/evm/tests/self_balance_gas_cost.rs @@ -104,6 +104,7 @@ fn self_balance_gas_cost() -> anyhow::Result<()> { block_gas_used: gas_used.into(), block_bloom: [0.into(); 8], block_base_fee: 0xa.into(), + block_random: Default::default(), }; let mut contract_code = HashMap::new(); diff --git a/evm/tests/simple_transfer.rs b/evm/tests/simple_transfer.rs index b8c47fe9..268ad661 100644 --- a/evm/tests/simple_transfer.rs +++ b/evm/tests/simple_transfer.rs @@ -5,7 +5,7 @@ use std::time::Duration; use env_logger::{try_init_from_env, Env, DEFAULT_FILTER_ENV}; use eth_trie_utils::nibbles::Nibbles; use eth_trie_utils::partial_trie::{HashedPartialTrie, PartialTrie}; -use ethereum_types::{Address, H256, U256}; +use ethereum_types::{Address, BigEndianHash, H256, U256}; use hex_literal::hex; use keccak_hash::keccak; use plonky2::field::goldilocks_field::GoldilocksField; @@ -71,6 +71,7 @@ fn test_simple_transfer() -> anyhow::Result<()> { block_timestamp: 0x03e8.into(), block_number: 1.into(), block_difficulty: 0x020000.into(), + block_random: H256::from_uint(&0x020000.into()), block_gaslimit: 0xff112233u32.into(), block_chain_id: 1.into(), block_base_fee: 0xa.into(), From 043d12c20e7b2b4c70051d27bbce3c217c271dab Mon Sep 17 00:00:00 2001 From: Linda Guiga Date: Mon, 25 Sep 2023 17:30:31 -0400 Subject: [PATCH 4/6] Fix observe_block_metadata --- evm/src/get_challenges.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/evm/src/get_challenges.rs b/evm/src/get_challenges.rs index 1d2ae602..cf17231a 100644 --- a/evm/src/get_challenges.rs +++ b/evm/src/get_challenges.rs @@ -64,8 +64,8 @@ fn observe_block_metadata< challenger.observe_element(u256_to_u32(block_metadata.block_timestamp)?); challenger.observe_element(u256_to_u32(block_metadata.block_number)?); challenger.observe_element(u256_to_u32(block_metadata.block_difficulty)?); - challenger.observe_element(u256_to_u32(block_metadata.block_gaslimit)?); challenger.observe_elements(&h256_limbs::(block_metadata.block_random)); + challenger.observe_element(u256_to_u32(block_metadata.block_gaslimit)?); challenger.observe_element(u256_to_u32(block_metadata.block_chain_id)?); let basefee = u256_to_u64(block_metadata.block_base_fee)?; challenger.observe_element(basefee.0); From 72241ca728b2534f39ab91f0df864679f7f80678 Mon Sep 17 00:00:00 2001 From: wborgeaud Date: Tue, 26 Sep 2023 15:05:45 +0200 Subject: [PATCH 5/6] Connect block_gas_used (#1253) --- evm/src/proof.rs | 1 + 1 file changed, 1 insertion(+) diff --git a/evm/src/proof.rs b/evm/src/proof.rs index 03e520ca..d14485c5 100644 --- a/evm/src/proof.rs +++ b/evm/src/proof.rs @@ -499,6 +499,7 @@ impl BlockMetadataTarget { for i in 0..2 { builder.connect(bm0.block_base_fee[i], bm1.block_base_fee[i]) } + builder.connect(bm0.block_gas_used, bm1.block_gas_used); for i in 0..64 { builder.connect(bm0.block_bloom[i], bm1.block_bloom[i]) } From 03a95581981038d9c5501d35f6b59564e8c1a13c Mon Sep 17 00:00:00 2001 From: Robin Salen <30937548+Nashtare@users.noreply.github.com> Date: Tue, 26 Sep 2023 11:13:57 -0400 Subject: [PATCH 6/6] Handle additional panics (#1250) * Remove some panic risks * Remove more panics * Handle jump with empty stack * Handle last expect * More panics * Handle from_big_endian * Handle from_little_endian * Remove remaining risky as_usize() * Remove explicit panic * Clippy * Handle unwrap * Make error messages more explicit * Simplify u256 to usize conversion --- evm/src/cpu/kernel/interpreter.rs | 7 +- evm/src/cpu/kernel/tests/account_code.rs | 6 +- evm/src/cpu/kernel/tests/balance.rs | 6 +- evm/src/cpu/kernel/tests/mpt/delete.rs | 5 +- evm/src/cpu/kernel/tests/mpt/hash.rs | 5 +- evm/src/cpu/kernel/tests/mpt/insert.rs | 5 +- evm/src/cpu/kernel/tests/mpt/load.rs | 22 +++- evm/src/cpu/kernel/tests/mpt/read.rs | 6 +- evm/src/cpu/kernel/tests/receipt.rs | 6 +- evm/src/generation/mod.rs | 7 +- evm/src/generation/mpt.rs | 117 +++++++++++------ evm/src/generation/outputs.rs | 89 ++++++------- evm/src/generation/prover_input.rs | 153 ++++++++++++----------- evm/src/generation/state.rs | 28 +++-- evm/src/generation/trie_extractor.rs | 52 ++++---- evm/src/recursive_verifier.rs | 21 ++-- evm/src/util.rs | 7 ++ evm/src/witness/errors.rs | 13 ++ evm/src/witness/memory.rs | 2 + evm/src/witness/operation.rs | 30 +++-- evm/src/witness/util.rs | 9 +- 21 files changed, 361 insertions(+), 235 deletions(-) diff --git a/evm/src/cpu/kernel/interpreter.rs b/evm/src/cpu/kernel/interpreter.rs index c4deba99..8f19a072 100644 --- a/evm/src/cpu/kernel/interpreter.rs +++ b/evm/src/cpu/kernel/interpreter.rs @@ -117,7 +117,7 @@ impl<'a> Interpreter<'a> { let mut result = Self { kernel_mode: true, jumpdests: find_jumpdests(code), - generation_state: GenerationState::new(GenerationInputs::default(), code), + generation_state: GenerationState::new(GenerationInputs::default(), code).unwrap(), prover_inputs_map: prover_inputs, context: 0, halt_offsets: vec![DEFAULT_HALT_OFFSET], @@ -905,7 +905,10 @@ impl<'a> Interpreter<'a> { .prover_inputs_map .get(&(self.generation_state.registers.program_counter - 1)) .ok_or_else(|| anyhow!("Offset not in prover inputs."))?; - let output = self.generation_state.prover_input(prover_input_fn); + let output = self + .generation_state + .prover_input(prover_input_fn) + .map_err(|_| anyhow!("Invalid prover inputs."))?; self.push(output); Ok(()) } diff --git a/evm/src/cpu/kernel/tests/account_code.rs b/evm/src/cpu/kernel/tests/account_code.rs index 805fed04..f4c18fe6 100644 --- a/evm/src/cpu/kernel/tests/account_code.rs +++ b/evm/src/cpu/kernel/tests/account_code.rs @@ -1,6 +1,6 @@ use std::collections::HashMap; -use anyhow::Result; +use anyhow::{anyhow, Result}; use eth_trie_utils::partial_trie::{HashedPartialTrie, PartialTrie}; use ethereum_types::{Address, BigEndianHash, H256, U256}; use keccak_hash::keccak; @@ -46,7 +46,9 @@ fn prepare_interpreter( interpreter.generation_state.registers.program_counter = load_all_mpts; interpreter.push(0xDEADBEEFu32.into()); - interpreter.generation_state.mpt_prover_inputs = all_mpt_prover_inputs_reversed(&trie_inputs); + interpreter.generation_state.mpt_prover_inputs = + all_mpt_prover_inputs_reversed(&trie_inputs) + .map_err(|err| anyhow!("Invalid MPT data: {:?}", err))?; interpreter.run()?; assert_eq!(interpreter.stack(), vec![]); diff --git a/evm/src/cpu/kernel/tests/balance.rs b/evm/src/cpu/kernel/tests/balance.rs index 049bf9f8..40214405 100644 --- a/evm/src/cpu/kernel/tests/balance.rs +++ b/evm/src/cpu/kernel/tests/balance.rs @@ -1,4 +1,4 @@ -use anyhow::Result; +use anyhow::{anyhow, Result}; use eth_trie_utils::partial_trie::{HashedPartialTrie, PartialTrie}; use ethereum_types::{Address, BigEndianHash, H256, U256}; use keccak_hash::keccak; @@ -37,7 +37,9 @@ fn prepare_interpreter( interpreter.generation_state.registers.program_counter = load_all_mpts; interpreter.push(0xDEADBEEFu32.into()); - interpreter.generation_state.mpt_prover_inputs = all_mpt_prover_inputs_reversed(&trie_inputs); + interpreter.generation_state.mpt_prover_inputs = + all_mpt_prover_inputs_reversed(&trie_inputs) + .map_err(|err| anyhow!("Invalid MPT data: {:?}", err))?; interpreter.run()?; assert_eq!(interpreter.stack(), vec![]); diff --git a/evm/src/cpu/kernel/tests/mpt/delete.rs b/evm/src/cpu/kernel/tests/mpt/delete.rs index 532a1603..42e8caf9 100644 --- a/evm/src/cpu/kernel/tests/mpt/delete.rs +++ b/evm/src/cpu/kernel/tests/mpt/delete.rs @@ -1,4 +1,4 @@ -use anyhow::Result; +use anyhow::{anyhow, Result}; use eth_trie_utils::nibbles::Nibbles; use eth_trie_utils::partial_trie::{HashedPartialTrie, PartialTrie}; use ethereum_types::{BigEndianHash, H256}; @@ -61,7 +61,8 @@ fn test_state_trie( let initial_stack = vec![0xDEADBEEFu32.into()]; let mut interpreter = Interpreter::new_with_kernel(load_all_mpts, initial_stack); - interpreter.generation_state.mpt_prover_inputs = all_mpt_prover_inputs_reversed(&trie_inputs); + interpreter.generation_state.mpt_prover_inputs = + all_mpt_prover_inputs_reversed(&trie_inputs).map_err(|_| anyhow!("Invalid MPT data"))?; interpreter.run()?; assert_eq!(interpreter.stack(), vec![]); diff --git a/evm/src/cpu/kernel/tests/mpt/hash.rs b/evm/src/cpu/kernel/tests/mpt/hash.rs index 3d6c2a23..05077a94 100644 --- a/evm/src/cpu/kernel/tests/mpt/hash.rs +++ b/evm/src/cpu/kernel/tests/mpt/hash.rs @@ -1,4 +1,4 @@ -use anyhow::Result; +use anyhow::{anyhow, Result}; use eth_trie_utils::partial_trie::PartialTrie; use ethereum_types::{BigEndianHash, H256}; @@ -113,7 +113,8 @@ fn test_state_trie(trie_inputs: TrieInputs) -> Result<()> { let initial_stack = vec![0xDEADBEEFu32.into()]; let mut interpreter = Interpreter::new_with_kernel(load_all_mpts, initial_stack); - interpreter.generation_state.mpt_prover_inputs = all_mpt_prover_inputs_reversed(&trie_inputs); + interpreter.generation_state.mpt_prover_inputs = + all_mpt_prover_inputs_reversed(&trie_inputs).map_err(|_| anyhow!("Invalid MPT data"))?; interpreter.run()?; assert_eq!(interpreter.stack(), vec![]); diff --git a/evm/src/cpu/kernel/tests/mpt/insert.rs b/evm/src/cpu/kernel/tests/mpt/insert.rs index f8dbc274..6fd95a30 100644 --- a/evm/src/cpu/kernel/tests/mpt/insert.rs +++ b/evm/src/cpu/kernel/tests/mpt/insert.rs @@ -1,4 +1,4 @@ -use anyhow::Result; +use anyhow::{anyhow, Result}; use eth_trie_utils::nibbles::Nibbles; use eth_trie_utils::partial_trie::{HashedPartialTrie, PartialTrie}; use ethereum_types::{BigEndianHash, H256}; @@ -174,7 +174,8 @@ fn test_state_trie( let initial_stack = vec![0xDEADBEEFu32.into()]; let mut interpreter = Interpreter::new_with_kernel(load_all_mpts, initial_stack); - interpreter.generation_state.mpt_prover_inputs = all_mpt_prover_inputs_reversed(&trie_inputs); + interpreter.generation_state.mpt_prover_inputs = + all_mpt_prover_inputs_reversed(&trie_inputs).map_err(|_| anyhow!("Invalid MPT data"))?; interpreter.run()?; assert_eq!(interpreter.stack(), vec![]); diff --git a/evm/src/cpu/kernel/tests/mpt/load.rs b/evm/src/cpu/kernel/tests/mpt/load.rs index aed311d2..50a8a0ef 100644 --- a/evm/src/cpu/kernel/tests/mpt/load.rs +++ b/evm/src/cpu/kernel/tests/mpt/load.rs @@ -1,4 +1,4 @@ -use anyhow::Result; +use anyhow::{anyhow, Result}; use ethereum_types::{BigEndianHash, H256, U256}; use crate::cpu::kernel::aggregator::KERNEL; @@ -23,7 +23,9 @@ fn load_all_mpts_empty() -> Result<()> { let initial_stack = vec![0xDEADBEEFu32.into()]; let mut interpreter = Interpreter::new_with_kernel(load_all_mpts, initial_stack); - interpreter.generation_state.mpt_prover_inputs = all_mpt_prover_inputs_reversed(&trie_inputs); + interpreter.generation_state.mpt_prover_inputs = + all_mpt_prover_inputs_reversed(&trie_inputs) + .map_err(|err| anyhow!("Invalid MPT data: {:?}", err))?; interpreter.run()?; assert_eq!(interpreter.stack(), vec![]); @@ -62,7 +64,9 @@ fn load_all_mpts_leaf() -> Result<()> { let initial_stack = vec![0xDEADBEEFu32.into()]; let mut interpreter = Interpreter::new_with_kernel(load_all_mpts, initial_stack); - interpreter.generation_state.mpt_prover_inputs = all_mpt_prover_inputs_reversed(&trie_inputs); + interpreter.generation_state.mpt_prover_inputs = + all_mpt_prover_inputs_reversed(&trie_inputs) + .map_err(|err| anyhow!("Invalid MPT data: {:?}", err))?; interpreter.run()?; assert_eq!(interpreter.stack(), vec![]); @@ -111,7 +115,9 @@ fn load_all_mpts_hash() -> Result<()> { let initial_stack = vec![0xDEADBEEFu32.into()]; let mut interpreter = Interpreter::new_with_kernel(load_all_mpts, initial_stack); - interpreter.generation_state.mpt_prover_inputs = all_mpt_prover_inputs_reversed(&trie_inputs); + interpreter.generation_state.mpt_prover_inputs = + all_mpt_prover_inputs_reversed(&trie_inputs) + .map_err(|err| anyhow!("Invalid MPT data: {:?}", err))?; interpreter.run()?; assert_eq!(interpreter.stack(), vec![]); @@ -152,7 +158,9 @@ fn load_all_mpts_empty_branch() -> Result<()> { let initial_stack = vec![0xDEADBEEFu32.into()]; let mut interpreter = Interpreter::new_with_kernel(load_all_mpts, initial_stack); - interpreter.generation_state.mpt_prover_inputs = all_mpt_prover_inputs_reversed(&trie_inputs); + interpreter.generation_state.mpt_prover_inputs = + all_mpt_prover_inputs_reversed(&trie_inputs) + .map_err(|err| anyhow!("Invalid MPT data: {:?}", err))?; interpreter.run()?; assert_eq!(interpreter.stack(), vec![]); @@ -207,7 +215,9 @@ fn load_all_mpts_ext_to_leaf() -> Result<()> { let initial_stack = vec![0xDEADBEEFu32.into()]; let mut interpreter = Interpreter::new_with_kernel(load_all_mpts, initial_stack); - interpreter.generation_state.mpt_prover_inputs = all_mpt_prover_inputs_reversed(&trie_inputs); + interpreter.generation_state.mpt_prover_inputs = + all_mpt_prover_inputs_reversed(&trie_inputs) + .map_err(|err| anyhow!("Invalid MPT data: {:?}", err))?; interpreter.run()?; assert_eq!(interpreter.stack(), vec![]); diff --git a/evm/src/cpu/kernel/tests/mpt/read.rs b/evm/src/cpu/kernel/tests/mpt/read.rs index 62313f62..f9ae94f0 100644 --- a/evm/src/cpu/kernel/tests/mpt/read.rs +++ b/evm/src/cpu/kernel/tests/mpt/read.rs @@ -1,4 +1,4 @@ -use anyhow::Result; +use anyhow::{anyhow, Result}; use ethereum_types::BigEndianHash; use crate::cpu::kernel::aggregator::KERNEL; @@ -22,7 +22,9 @@ fn mpt_read() -> Result<()> { let initial_stack = vec![0xdeadbeefu32.into()]; let mut interpreter = Interpreter::new_with_kernel(load_all_mpts, initial_stack); - interpreter.generation_state.mpt_prover_inputs = all_mpt_prover_inputs_reversed(&trie_inputs); + interpreter.generation_state.mpt_prover_inputs = + all_mpt_prover_inputs_reversed(&trie_inputs) + .map_err(|err| anyhow!("Invalid MPT data: {:?}", err))?; interpreter.run()?; assert_eq!(interpreter.stack(), vec![]); diff --git a/evm/src/cpu/kernel/tests/receipt.rs b/evm/src/cpu/kernel/tests/receipt.rs index 783f592b..b5583654 100644 --- a/evm/src/cpu/kernel/tests/receipt.rs +++ b/evm/src/cpu/kernel/tests/receipt.rs @@ -1,4 +1,4 @@ -use anyhow::Result; +use anyhow::{anyhow, Result}; use ethereum_types::{Address, U256}; use hex_literal::hex; use keccak_hash::keccak; @@ -413,7 +413,9 @@ fn test_mpt_insert_receipt() -> Result<()> { let initial_stack = vec![retdest]; let mut interpreter = Interpreter::new_with_kernel(load_all_mpts, initial_stack); - interpreter.generation_state.mpt_prover_inputs = all_mpt_prover_inputs_reversed(&trie_inputs); + interpreter.generation_state.mpt_prover_inputs = + all_mpt_prover_inputs_reversed(&trie_inputs) + .map_err(|err| anyhow!("Invalid MPT data: {:?}", err))?; interpreter.run()?; // If TrieData is empty, we need to push 0 because the first value is always 0. diff --git a/evm/src/generation/mod.rs b/evm/src/generation/mod.rs index 3f5bafba..6b9ce000 100644 --- a/evm/src/generation/mod.rs +++ b/evm/src/generation/mod.rs @@ -1,5 +1,6 @@ use std::collections::HashMap; +use anyhow::anyhow; use eth_trie_utils::partial_trie::{HashedPartialTrie, PartialTrie}; use ethereum_types::{Address, BigEndianHash, H256, U256}; use plonky2::field::extension::Extendable; @@ -220,7 +221,8 @@ pub fn generate_traces, const D: usize>( PublicValues, GenerationOutputs, )> { - let mut state = GenerationState::::new(inputs.clone(), &KERNEL.code); + let mut state = GenerationState::::new(inputs.clone(), &KERNEL.code) + .map_err(|err| anyhow!("Failed to parse all the initial prover inputs: {:?}", err))?; apply_metadata_and_tries_memops(&mut state, &inputs); @@ -238,7 +240,8 @@ pub fn generate_traces, const D: usize>( state.traces.get_lengths() ); - let outputs = get_outputs(&mut state); + let outputs = get_outputs(&mut state) + .map_err(|err| anyhow!("Failed to generate post-state info: {:?}", err))?; let read_metadata = |field| state.memory.read_global_metadata(field); let trie_roots_before = TrieRoots { diff --git a/evm/src/generation/mpt.rs b/evm/src/generation/mpt.rs index 47129ed0..dbc36cac 100644 --- a/evm/src/generation/mpt.rs +++ b/evm/src/generation/mpt.rs @@ -11,6 +11,7 @@ use rlp_derive::{RlpDecodable, RlpEncodable}; use crate::cpu::kernel::constants::trie_type::PartialTrieType; use crate::generation::TrieInputs; +use crate::witness::errors::{ProgramError, ProverInputError}; use crate::Node; #[derive(RlpEncodable, RlpDecodable, Debug)] @@ -60,15 +61,18 @@ pub struct LegacyReceiptRlp { pub logs: Vec, } -pub(crate) fn all_mpt_prover_inputs_reversed(trie_inputs: &TrieInputs) -> Vec { - let mut inputs = all_mpt_prover_inputs(trie_inputs); +pub(crate) fn all_mpt_prover_inputs_reversed( + trie_inputs: &TrieInputs, +) -> Result, ProgramError> { + let mut inputs = all_mpt_prover_inputs(trie_inputs)?; inputs.reverse(); - inputs + Ok(inputs) } -pub(crate) fn parse_receipts(rlp: &[u8]) -> Vec { - let payload_info = PayloadInfo::from(rlp).unwrap(); - let decoded_receipt: LegacyReceiptRlp = rlp::decode(rlp).unwrap(); +pub(crate) fn parse_receipts(rlp: &[u8]) -> Result, ProgramError> { + let payload_info = PayloadInfo::from(rlp).map_err(|_| ProgramError::InvalidRlp)?; + let decoded_receipt: LegacyReceiptRlp = + rlp::decode(rlp).map_err(|_| ProgramError::InvalidRlp)?; let mut parsed_receipt = Vec::new(); parsed_receipt.push(payload_info.value_len.into()); // payload_len of the entire receipt @@ -76,13 +80,15 @@ pub(crate) fn parse_receipts(rlp: &[u8]) -> Vec { parsed_receipt.push(decoded_receipt.cum_gas_used); parsed_receipt.extend(decoded_receipt.bloom.iter().map(|byte| U256::from(*byte))); let encoded_logs = rlp::encode_list(&decoded_receipt.logs); - let logs_payload_info = PayloadInfo::from(&encoded_logs).unwrap(); + let logs_payload_info = + PayloadInfo::from(&encoded_logs).map_err(|_| ProgramError::InvalidRlp)?; parsed_receipt.push(logs_payload_info.value_len.into()); // payload_len of all the logs parsed_receipt.push(decoded_receipt.logs.len().into()); for log in decoded_receipt.logs { let encoded_log = rlp::encode(&log); - let log_payload_info = PayloadInfo::from(&encoded_log).unwrap(); + let log_payload_info = + PayloadInfo::from(&encoded_log).map_err(|_| ProgramError::InvalidRlp)?; parsed_receipt.push(log_payload_info.value_len.into()); // payload of one log parsed_receipt.push(U256::from_big_endian(&log.address.to_fixed_bytes())); parsed_receipt.push(log.topics.len().into()); @@ -91,10 +97,10 @@ pub(crate) fn parse_receipts(rlp: &[u8]) -> Vec { parsed_receipt.extend(log.data.iter().map(|byte| U256::from(*byte))); } - parsed_receipt + Ok(parsed_receipt) } /// Generate prover inputs for the initial MPT data, in the format expected by `mpt/load.asm`. -pub(crate) fn all_mpt_prover_inputs(trie_inputs: &TrieInputs) -> Vec { +pub(crate) fn all_mpt_prover_inputs(trie_inputs: &TrieInputs) -> Result, ProgramError> { let mut prover_inputs = vec![]; let storage_tries_by_state_key = trie_inputs @@ -111,19 +117,19 @@ pub(crate) fn all_mpt_prover_inputs(trie_inputs: &TrieInputs) -> Vec { empty_nibbles(), &mut prover_inputs, &storage_tries_by_state_key, - ); + )?; mpt_prover_inputs(&trie_inputs.transactions_trie, &mut prover_inputs, &|rlp| { - rlp::decode_list(rlp) - }); + Ok(rlp::decode_list(rlp)) + })?; mpt_prover_inputs( &trie_inputs.receipts_trie, &mut prover_inputs, &parse_receipts, - ); + )?; - prover_inputs + Ok(prover_inputs) } /// Given a trie, generate the prover input data for that trie. In essence, this serializes a trie @@ -134,36 +140,52 @@ pub(crate) fn mpt_prover_inputs( trie: &HashedPartialTrie, prover_inputs: &mut Vec, parse_value: &F, -) where - F: Fn(&[u8]) -> Vec, +) -> Result<(), ProgramError> +where + F: Fn(&[u8]) -> Result, ProgramError>, { prover_inputs.push((PartialTrieType::of(trie) as u32).into()); match trie.deref() { - Node::Empty => {} - Node::Hash(h) => prover_inputs.push(U256::from_big_endian(h.as_bytes())), + Node::Empty => Ok(()), + Node::Hash(h) => { + prover_inputs.push(U256::from_big_endian(h.as_bytes())); + Ok(()) + } Node::Branch { children, value } => { if value.is_empty() { prover_inputs.push(U256::zero()); // value_present = 0 } else { - let parsed_value = parse_value(value); + let parsed_value = parse_value(value)?; prover_inputs.push(U256::one()); // value_present = 1 prover_inputs.extend(parsed_value); } for child in children { - mpt_prover_inputs(child, prover_inputs, parse_value); + mpt_prover_inputs(child, prover_inputs, parse_value)?; } + + Ok(()) } Node::Extension { nibbles, child } => { prover_inputs.push(nibbles.count.into()); - prover_inputs.push(nibbles.try_into_u256().unwrap()); - mpt_prover_inputs(child, prover_inputs, parse_value); + prover_inputs.push( + nibbles + .try_into_u256() + .map_err(|_| ProgramError::IntegerTooLarge)?, + ); + mpt_prover_inputs(child, prover_inputs, parse_value) } Node::Leaf { nibbles, value } => { prover_inputs.push(nibbles.count.into()); - prover_inputs.push(nibbles.try_into_u256().unwrap()); - let leaf = parse_value(value); + prover_inputs.push( + nibbles + .try_into_u256() + .map_err(|_| ProgramError::IntegerTooLarge)?, + ); + let leaf = parse_value(value)?; prover_inputs.extend(leaf); + + Ok(()) } } } @@ -175,13 +197,20 @@ pub(crate) fn mpt_prover_inputs_state_trie( key: Nibbles, prover_inputs: &mut Vec, storage_tries_by_state_key: &HashMap, -) { +) -> Result<(), ProgramError> { prover_inputs.push((PartialTrieType::of(trie) as u32).into()); match trie.deref() { - Node::Empty => {} - Node::Hash(h) => prover_inputs.push(U256::from_big_endian(h.as_bytes())), + Node::Empty => Ok(()), + Node::Hash(h) => { + prover_inputs.push(U256::from_big_endian(h.as_bytes())); + Ok(()) + } Node::Branch { children, value } => { - assert!(value.is_empty(), "State trie should not have branch values"); + if !value.is_empty() { + return Err(ProgramError::ProverInputError( + ProverInputError::InvalidMptInput, + )); + } prover_inputs.push(U256::zero()); // value_present = 0 for (i, child) in children.iter().enumerate() { @@ -194,22 +223,28 @@ pub(crate) fn mpt_prover_inputs_state_trie( extended_key, prover_inputs, storage_tries_by_state_key, - ); + )?; } + + Ok(()) } Node::Extension { nibbles, child } => { prover_inputs.push(nibbles.count.into()); - prover_inputs.push(nibbles.try_into_u256().unwrap()); + prover_inputs.push( + nibbles + .try_into_u256() + .map_err(|_| ProgramError::IntegerTooLarge)?, + ); let extended_key = key.merge_nibbles(nibbles); mpt_prover_inputs_state_trie( child, extended_key, prover_inputs, storage_tries_by_state_key, - ); + ) } Node::Leaf { nibbles, value } => { - let account: AccountRlp = rlp::decode(value).expect("Decoding failed"); + let account: AccountRlp = rlp::decode(value).map_err(|_| ProgramError::InvalidRlp)?; let AccountRlp { nonce, balance, @@ -228,18 +263,24 @@ pub(crate) fn mpt_prover_inputs_state_trie( "In TrieInputs, an account's storage_root didn't match the associated storage trie hash"); prover_inputs.push(nibbles.count.into()); - prover_inputs.push(nibbles.try_into_u256().unwrap()); + prover_inputs.push( + nibbles + .try_into_u256() + .map_err(|_| ProgramError::IntegerTooLarge)?, + ); prover_inputs.push(nonce); prover_inputs.push(balance); - mpt_prover_inputs(storage_trie, prover_inputs, &parse_storage_value); + mpt_prover_inputs(storage_trie, prover_inputs, &parse_storage_value)?; prover_inputs.push(code_hash.into_uint()); + + Ok(()) } } } -fn parse_storage_value(value_rlp: &[u8]) -> Vec { - let value: U256 = rlp::decode(value_rlp).expect("Decoding failed"); - vec![value] +fn parse_storage_value(value_rlp: &[u8]) -> Result, ProgramError> { + let value: U256 = rlp::decode(value_rlp).map_err(|_| ProgramError::InvalidRlp)?; + Ok(vec![value]) } fn empty_nibbles() -> Nibbles { diff --git a/evm/src/generation/outputs.rs b/evm/src/generation/outputs.rs index 63a86906..0ce87082 100644 --- a/evm/src/generation/outputs.rs +++ b/evm/src/generation/outputs.rs @@ -8,6 +8,8 @@ use crate::generation::state::GenerationState; use crate::generation::trie_extractor::{ read_state_trie_value, read_storage_trie_value, read_trie, AccountTrieRecord, }; +use crate::util::u256_to_usize; +use crate::witness::errors::ProgramError; /// The post-state after trace generation; intended for debugging. #[derive(Clone, Debug)] @@ -29,47 +31,44 @@ pub struct AccountOutput { pub storage: HashMap, } -pub(crate) fn get_outputs(state: &mut GenerationState) -> GenerationOutputs { - // First observe all addresses passed in the by caller. +pub(crate) fn get_outputs( + state: &mut GenerationState, +) -> Result { + // First observe all addresses passed in by caller. for address in state.inputs.addresses.clone() { state.observe_address(address); } - let account_map = read_trie::( - &state.memory, - state.memory.read_global_metadata(StateTrieRoot).as_usize(), - read_state_trie_value, - ); + let ptr = u256_to_usize(state.memory.read_global_metadata(StateTrieRoot))?; + let account_map = read_trie::(&state.memory, ptr, read_state_trie_value)?; - let accounts = account_map - .into_iter() - .map(|(state_key_nibbles, account)| { - assert_eq!( - state_key_nibbles.count, 64, - "Each state key should have 64 nibbles = 256 bits" - ); - let state_key_h256 = H256::from_uint(&state_key_nibbles.try_into_u256().unwrap()); + let mut accounts = HashMap::with_capacity(account_map.len()); - let addr_or_state_key = - if let Some(address) = state.state_key_to_address.get(&state_key_h256) { - AddressOrStateKey::Address(*address) - } else { - AddressOrStateKey::StateKey(state_key_h256) - }; + for (state_key_nibbles, account) in account_map.into_iter() { + if state_key_nibbles.count != 64 { + return Err(ProgramError::IntegerTooLarge); + } + let state_key_h256 = H256::from_uint(&state_key_nibbles.try_into_u256().unwrap()); - let account_output = account_trie_record_to_output(state, account); - (addr_or_state_key, account_output) - }) - .collect(); + let addr_or_state_key = + if let Some(address) = state.state_key_to_address.get(&state_key_h256) { + AddressOrStateKey::Address(*address) + } else { + AddressOrStateKey::StateKey(state_key_h256) + }; - GenerationOutputs { accounts } + let account_output = account_trie_record_to_output(state, account)?; + accounts.insert(addr_or_state_key, account_output); + } + + Ok(GenerationOutputs { accounts }) } fn account_trie_record_to_output( state: &GenerationState, account: AccountTrieRecord, -) -> AccountOutput { - let storage = get_storage(state, account.storage_ptr); +) -> Result { + let storage = get_storage(state, account.storage_ptr)?; // TODO: This won't work if the account was created during the txn. // Need to track changes to code, similar to how we track addresses @@ -78,27 +77,33 @@ fn account_trie_record_to_output( .inputs .contract_code .get(&account.code_hash) - .unwrap_or_else(|| panic!("Code not found: {:?}", account.code_hash)) + .ok_or(ProgramError::UnknownContractCode)? .clone(); - AccountOutput { + Ok(AccountOutput { balance: account.balance, nonce: account.nonce, storage, code, - } + }) } /// Get an account's storage trie, given a pointer to its root. -fn get_storage(state: &GenerationState, storage_ptr: usize) -> HashMap { - read_trie::(&state.memory, storage_ptr, read_storage_trie_value) - .into_iter() - .map(|(storage_key_nibbles, value)| { - assert_eq!( - storage_key_nibbles.count, 64, - "Each storage key should have 64 nibbles = 256 bits" - ); - (storage_key_nibbles.try_into_u256().unwrap(), value) - }) - .collect() +fn get_storage( + state: &GenerationState, + storage_ptr: usize, +) -> Result, ProgramError> { + let storage_trie = read_trie::(&state.memory, storage_ptr, |x| { + Ok(read_storage_trie_value(x)) + })?; + + let mut map = HashMap::with_capacity(storage_trie.len()); + for (storage_key_nibbles, value) in storage_trie.into_iter() { + if storage_key_nibbles.count != 64 { + return Err(ProgramError::IntegerTooLarge); + }; + map.insert(storage_key_nibbles.try_into_u256().unwrap(), value); + } + + Ok(map) } diff --git a/evm/src/generation/prover_input.rs b/evm/src/generation/prover_input.rs index 14293289..205dff7c 100644 --- a/evm/src/generation/prover_input.rs +++ b/evm/src/generation/prover_input.rs @@ -16,7 +16,9 @@ use crate::generation::prover_input::FieldOp::{Inverse, Sqrt}; use crate::generation::state::GenerationState; use crate::memory::segments::Segment; use crate::memory::segments::Segment::BnPairing; -use crate::util::{biguint_to_mem_vec, mem_vec_to_biguint}; +use crate::util::{biguint_to_mem_vec, mem_vec_to_biguint, u256_to_usize}; +use crate::witness::errors::ProgramError; +use crate::witness::errors::ProverInputError::*; use crate::witness::util::{current_context_peek, stack_peek}; /// Prover input function represented as a scoped function name. @@ -31,7 +33,7 @@ impl From> for ProverInputFn { } impl GenerationState { - pub(crate) fn prover_input(&mut self, input_fn: &ProverInputFn) -> U256 { + pub(crate) fn prover_input(&mut self, input_fn: &ProverInputFn) -> Result { match input_fn.0[0].as_str() { "end_of_txns" => self.run_end_of_txns(), "ff" => self.run_ff(input_fn), @@ -42,51 +44,59 @@ impl GenerationState { "current_hash" => self.run_current_hash(), "account_code" => self.run_account_code(input_fn), "bignum_modmul" => self.run_bignum_modmul(), - _ => panic!("Unrecognized prover input function."), + _ => Err(ProgramError::ProverInputError(InvalidFunction)), } } - fn run_end_of_txns(&mut self) -> U256 { + fn run_end_of_txns(&mut self) -> Result { let end = self.next_txn_index == self.inputs.signed_txns.len(); if end { - U256::one() + Ok(U256::one()) } else { self.next_txn_index += 1; - U256::zero() + Ok(U256::zero()) } } /// Finite field operations. - fn run_ff(&self, input_fn: &ProverInputFn) -> U256 { - let field = EvmField::from_str(input_fn.0[1].as_str()).unwrap(); - let op = FieldOp::from_str(input_fn.0[2].as_str()).unwrap(); - let x = stack_peek(self, 0).expect("Empty stack"); + fn run_ff(&self, input_fn: &ProverInputFn) -> Result { + let field = EvmField::from_str(input_fn.0[1].as_str()) + .map_err(|_| ProgramError::ProverInputError(InvalidFunction))?; + let op = FieldOp::from_str(input_fn.0[2].as_str()) + .map_err(|_| ProgramError::ProverInputError(InvalidFunction))?; + let x = stack_peek(self, 0)?; field.op(op, x) } /// Special finite field operations. - fn run_sf(&self, input_fn: &ProverInputFn) -> U256 { - let field = EvmField::from_str(input_fn.0[1].as_str()).unwrap(); + fn run_sf(&self, input_fn: &ProverInputFn) -> Result { + let field = EvmField::from_str(input_fn.0[1].as_str()) + .map_err(|_| ProgramError::ProverInputError(InvalidFunction))?; let inputs: [U256; 4] = match field { - Bls381Base => std::array::from_fn(|i| { - stack_peek(self, i).expect("Insufficient number of items on stack") - }), + Bls381Base => (0..4) + .map(|i| stack_peek(self, i)) + .collect::, _>>()? + .try_into() + .unwrap(), _ => todo!(), }; - match input_fn.0[2].as_str() { + let res = match input_fn.0[2].as_str() { "add_lo" => field.add_lo(inputs), "add_hi" => field.add_hi(inputs), "mul_lo" => field.mul_lo(inputs), "mul_hi" => field.mul_hi(inputs), "sub_lo" => field.sub_lo(inputs), "sub_hi" => field.sub_hi(inputs), - _ => todo!(), - } + _ => return Err(ProgramError::ProverInputError(InvalidFunction)), + }; + + Ok(res) } /// Finite field extension operations. - fn run_ffe(&self, input_fn: &ProverInputFn) -> U256 { - let field = EvmField::from_str(input_fn.0[1].as_str()).unwrap(); + fn run_ffe(&self, input_fn: &ProverInputFn) -> Result { + let field = EvmField::from_str(input_fn.0[1].as_str()) + .map_err(|_| ProgramError::ProverInputError(InvalidFunction))?; let n = input_fn.0[2] .as_str() .split('_') @@ -94,61 +104,61 @@ impl GenerationState { .unwrap() .parse::() .unwrap(); - let ptr = stack_peek(self, 11 - n) - .expect("Insufficient number of items on stack") - .as_usize(); + let ptr = stack_peek(self, 11 - n).map(u256_to_usize)??; let f: [U256; 12] = match field { Bn254Base => std::array::from_fn(|i| current_context_peek(self, BnPairing, ptr + i)), _ => todo!(), }; - field.field_extension_inverse(n, f) + Ok(field.field_extension_inverse(n, f)) } /// MPT data. - fn run_mpt(&mut self) -> U256 { + fn run_mpt(&mut self) -> Result { self.mpt_prover_inputs .pop() - .unwrap_or_else(|| panic!("Out of MPT data")) + .ok_or(ProgramError::ProverInputError(OutOfMptData)) } /// RLP data. - fn run_rlp(&mut self) -> U256 { + fn run_rlp(&mut self) -> Result { self.rlp_prover_inputs .pop() - .unwrap_or_else(|| panic!("Out of RLP data")) + .ok_or(ProgramError::ProverInputError(OutOfRlpData)) } - fn run_current_hash(&mut self) -> U256 { - U256::from_big_endian(&self.inputs.block_hashes.cur_hash.0) + fn run_current_hash(&mut self) -> Result { + Ok(U256::from_big_endian(&self.inputs.block_hashes.cur_hash.0)) } /// Account code. - fn run_account_code(&mut self, input_fn: &ProverInputFn) -> U256 { + fn run_account_code(&mut self, input_fn: &ProverInputFn) -> Result { match input_fn.0[1].as_str() { "length" => { // Return length of code. // stack: codehash, ... - let codehash = stack_peek(self, 0).expect("Empty stack"); - self.inputs + let codehash = stack_peek(self, 0)?; + Ok(self + .inputs .contract_code .get(&H256::from_uint(&codehash)) - .unwrap_or_else(|| panic!("No code found with hash {codehash}")) + .ok_or(ProgramError::ProverInputError(CodeHashNotFound))? .len() - .into() + .into()) } "get" => { // Return `code[i]`. // stack: i, code_length, codehash, ... - let i = stack_peek(self, 0).expect("Unexpected stack").as_usize(); - let codehash = stack_peek(self, 2).expect("Unexpected stack"); - self.inputs + let i = stack_peek(self, 0).map(u256_to_usize)??; + let codehash = stack_peek(self, 2)?; + Ok(self + .inputs .contract_code .get(&H256::from_uint(&codehash)) - .unwrap_or_else(|| panic!("No code found with hash {codehash}"))[i] - .into() + .ok_or(ProgramError::ProverInputError(CodeHashNotFound))?[i] + .into()) } - _ => panic!("Invalid prover input function."), + _ => Err(ProgramError::ProverInputError(InvalidInput)), } } @@ -156,24 +166,12 @@ impl GenerationState { // On the first call, calculates the remainder and quotient of the given inputs. // These are stored, as limbs, in self.bignum_modmul_result_limbs. // Subsequent calls return one limb at a time, in order (first remainder and then quotient). - fn run_bignum_modmul(&mut self) -> U256 { + fn run_bignum_modmul(&mut self) -> Result { if self.bignum_modmul_result_limbs.is_empty() { - let len = stack_peek(self, 1) - .expect("Stack does not have enough items") - .try_into() - .unwrap(); - let a_start_loc = stack_peek(self, 2) - .expect("Stack does not have enough items") - .try_into() - .unwrap(); - let b_start_loc = stack_peek(self, 3) - .expect("Stack does not have enough items") - .try_into() - .unwrap(); - let m_start_loc = stack_peek(self, 4) - .expect("Stack does not have enough items") - .try_into() - .unwrap(); + let len = stack_peek(self, 1).map(u256_to_usize)??; + let a_start_loc = stack_peek(self, 2).map(u256_to_usize)??; + let b_start_loc = stack_peek(self, 3).map(u256_to_usize)??; + let m_start_loc = stack_peek(self, 4).map(u256_to_usize)??; let (remainder, quotient) = self.bignum_modmul(len, a_start_loc, b_start_loc, m_start_loc); @@ -187,7 +185,9 @@ impl GenerationState { self.bignum_modmul_result_limbs.reverse(); } - self.bignum_modmul_result_limbs.pop().unwrap() + self.bignum_modmul_result_limbs + .pop() + .ok_or(ProgramError::ProverInputError(InvalidInput)) } fn bignum_modmul( @@ -284,27 +284,33 @@ impl EvmField { } } - fn op(&self, op: FieldOp, x: U256) -> U256 { + fn op(&self, op: FieldOp, x: U256) -> Result { match op { FieldOp::Inverse => self.inverse(x), FieldOp::Sqrt => self.sqrt(x), } } - fn inverse(&self, x: U256) -> U256 { + fn inverse(&self, x: U256) -> Result { let n = self.order(); - assert!(x < n); + if x >= n { + return Err(ProgramError::ProverInputError(InvalidInput)); + }; modexp(x, n - 2, n) } - fn sqrt(&self, x: U256) -> U256 { + fn sqrt(&self, x: U256) -> Result { let n = self.order(); - assert!(x < n); + if x >= n { + return Err(ProgramError::ProverInputError(InvalidInput)); + }; let (q, r) = (n + 1).div_mod(4.into()); - assert!( - r.is_zero(), - "Only naive sqrt implementation for now. If needed implement Tonelli-Shanks." - ); + + if !r.is_zero() { + return Err(ProgramError::ProverInputError(InvalidInput)); + }; + + // Only naive sqrt implementation for now. If needed implement Tonelli-Shanks modexp(x, q, n) } @@ -363,15 +369,18 @@ impl EvmField { } } -fn modexp(x: U256, e: U256, n: U256) -> U256 { +fn modexp(x: U256, e: U256, n: U256) -> Result { let mut current = x; let mut product = U256::one(); for j in 0..256 { if e.bit(j) { - product = U256::try_from(product.full_mul(current) % n).unwrap(); + product = U256::try_from(product.full_mul(current) % n) + .map_err(|_| ProgramError::ProverInputError(InvalidInput))?; } - current = U256::try_from(current.full_mul(current) % n).unwrap(); + current = U256::try_from(current.full_mul(current) % n) + .map_err(|_| ProgramError::ProverInputError(InvalidInput))?; } - product + + Ok(product) } diff --git a/evm/src/generation/state.rs b/evm/src/generation/state.rs index 2b85821f..aec01e1b 100644 --- a/evm/src/generation/state.rs +++ b/evm/src/generation/state.rs @@ -10,6 +10,8 @@ use crate::generation::mpt::all_mpt_prover_inputs_reversed; use crate::generation::rlp::all_rlp_prover_inputs_reversed; use crate::generation::GenerationInputs; use crate::memory::segments::Segment; +use crate::util::u256_to_usize; +use crate::witness::errors::ProgramError; use crate::witness::memory::{MemoryAddress, MemoryState}; use crate::witness::state::RegistersState; use crate::witness::traces::{TraceCheckpoint, Traces}; @@ -49,7 +51,7 @@ pub(crate) struct GenerationState { } impl GenerationState { - pub(crate) fn new(inputs: GenerationInputs, kernel_code: &[u8]) -> Self { + pub(crate) fn new(inputs: GenerationInputs, kernel_code: &[u8]) -> Result { log::debug!("Input signed_txns: {:?}", &inputs.signed_txns); log::debug!("Input state_trie: {:?}", &inputs.tries.state_trie); log::debug!( @@ -59,11 +61,11 @@ impl GenerationState { log::debug!("Input receipts_trie: {:?}", &inputs.tries.receipts_trie); log::debug!("Input storage_tries: {:?}", &inputs.tries.storage_tries); log::debug!("Input contract_code: {:?}", &inputs.contract_code); - let mpt_prover_inputs = all_mpt_prover_inputs_reversed(&inputs.tries); + let mpt_prover_inputs = all_mpt_prover_inputs_reversed(&inputs.tries)?; let rlp_prover_inputs = all_rlp_prover_inputs_reversed(&inputs.signed_txns); let bignum_modmul_result_limbs = Vec::new(); - Self { + Ok(Self { inputs, registers: Default::default(), memory: MemoryState::new(kernel_code), @@ -73,23 +75,25 @@ impl GenerationState { rlp_prover_inputs, state_key_to_address: HashMap::new(), bignum_modmul_result_limbs, - } + }) } /// Updates `program_counter`, and potentially adds some extra handling if we're jumping to a /// special location. - pub fn jump_to(&mut self, dst: usize) { + pub fn jump_to(&mut self, dst: usize) -> Result<(), ProgramError> { self.registers.program_counter = dst; if dst == KERNEL.global_labels["observe_new_address"] { - let tip_u256 = stack_peek(self, 0).expect("Empty stack"); + let tip_u256 = stack_peek(self, 0)?; let tip_h256 = H256::from_uint(&tip_u256); let tip_h160 = H160::from(tip_h256); self.observe_address(tip_h160); } else if dst == KERNEL.global_labels["observe_new_contract"] { - let tip_u256 = stack_peek(self, 0).expect("Empty stack"); + let tip_u256 = stack_peek(self, 0)?; let tip_h256 = H256::from_uint(&tip_u256); - self.observe_contract(tip_h256); + self.observe_contract(tip_h256)?; } + + Ok(()) } /// Observe the given address, so that we will be able to recognize the associated state key. @@ -101,9 +105,9 @@ impl GenerationState { /// Observe the given code hash and store the associated code. /// When called, the code corresponding to `codehash` should be stored in the return data. - pub fn observe_contract(&mut self, codehash: H256) { + pub fn observe_contract(&mut self, codehash: H256) -> Result<(), ProgramError> { if self.inputs.contract_code.contains_key(&codehash) { - return; // Return early if the code hash has already been observed. + return Ok(()); // Return early if the code hash has already been observed. } let ctx = self.registers.context; @@ -112,7 +116,7 @@ impl GenerationState { Segment::ContextMetadata, ContextMetadata::ReturndataSize as usize, ); - let returndata_size = self.memory.get(returndata_size_addr).as_usize(); + let returndata_size = u256_to_usize(self.memory.get(returndata_size_addr))?; let code = self.memory.contexts[ctx].segments[Segment::Returndata as usize].content [..returndata_size] .iter() @@ -121,6 +125,8 @@ impl GenerationState { debug_assert_eq!(keccak(&code), codehash); self.inputs.contract_code.insert(codehash, code); + + Ok(()) } pub fn checkpoint(&self) -> GenerationStateCheckpoint { diff --git a/evm/src/generation/trie_extractor.rs b/evm/src/generation/trie_extractor.rs index a508a720..42c50c6d 100644 --- a/evm/src/generation/trie_extractor.rs +++ b/evm/src/generation/trie_extractor.rs @@ -7,6 +7,8 @@ use ethereum_types::{BigEndianHash, H256, U256, U512}; use crate::cpu::kernel::constants::trie_type::PartialTrieType; use crate::memory::segments::Segment; +use crate::util::u256_to_usize; +use crate::witness::errors::ProgramError; use crate::witness::memory::{MemoryAddress, MemoryState}; /// Account data as it's stored in the state trie, with a pointer to the storage trie. @@ -18,13 +20,13 @@ pub(crate) struct AccountTrieRecord { pub(crate) code_hash: H256, } -pub(crate) fn read_state_trie_value(slice: &[U256]) -> AccountTrieRecord { - AccountTrieRecord { +pub(crate) fn read_state_trie_value(slice: &[U256]) -> Result { + Ok(AccountTrieRecord { nonce: slice[0].low_u64(), balance: slice[1], - storage_ptr: slice[2].as_usize(), + storage_ptr: u256_to_usize(slice[2])?, code_hash: H256::from_uint(&slice[3]), - } + }) } pub(crate) fn read_storage_trie_value(slice: &[U256]) -> U256 { @@ -34,72 +36,76 @@ pub(crate) fn read_storage_trie_value(slice: &[U256]) -> U256 { pub(crate) fn read_trie( memory: &MemoryState, ptr: usize, - read_value: fn(&[U256]) -> V, -) -> HashMap { + read_value: fn(&[U256]) -> Result, +) -> Result, ProgramError> { let mut res = HashMap::new(); let empty_nibbles = Nibbles { count: 0, packed: U512::zero(), }; - read_trie_helper::(memory, ptr, read_value, empty_nibbles, &mut res); - res + read_trie_helper::(memory, ptr, read_value, empty_nibbles, &mut res)?; + Ok(res) } pub(crate) fn read_trie_helper( memory: &MemoryState, ptr: usize, - read_value: fn(&[U256]) -> V, + read_value: fn(&[U256]) -> Result, prefix: Nibbles, res: &mut HashMap, -) { +) -> Result<(), ProgramError> { let load = |offset| memory.get(MemoryAddress::new(0, Segment::TrieData, offset)); let load_slice_from = |init_offset| { &memory.contexts[0].segments[Segment::TrieData as usize].content[init_offset..] }; - let trie_type = PartialTrieType::all()[load(ptr).as_usize()]; + let trie_type = PartialTrieType::all()[u256_to_usize(load(ptr))?]; match trie_type { - PartialTrieType::Empty => {} - PartialTrieType::Hash => {} + PartialTrieType::Empty => Ok(()), + PartialTrieType::Hash => Ok(()), PartialTrieType::Branch => { let ptr_payload = ptr + 1; for i in 0u8..16 { - let child_ptr = load(ptr_payload + i as usize).as_usize(); - read_trie_helper::(memory, child_ptr, read_value, prefix.merge_nibble(i), res); + let child_ptr = u256_to_usize(load(ptr_payload + i as usize))?; + read_trie_helper::(memory, child_ptr, read_value, prefix.merge_nibble(i), res)?; } - let value_ptr = load(ptr_payload + 16).as_usize(); + let value_ptr = u256_to_usize(load(ptr_payload + 16))?; if value_ptr != 0 { - res.insert(prefix, read_value(load_slice_from(value_ptr))); + res.insert(prefix, read_value(load_slice_from(value_ptr))?); }; + + Ok(()) } PartialTrieType::Extension => { - let count = load(ptr + 1).as_usize(); + let count = u256_to_usize(load(ptr + 1))?; let packed = load(ptr + 2); let nibbles = Nibbles { count, packed: packed.into(), }; - let child_ptr = load(ptr + 3).as_usize(); + let child_ptr = u256_to_usize(load(ptr + 3))?; read_trie_helper::( memory, child_ptr, read_value, prefix.merge_nibbles(&nibbles), res, - ); + ) } PartialTrieType::Leaf => { - let count = load(ptr + 1).as_usize(); + let count = u256_to_usize(load(ptr + 1))?; let packed = load(ptr + 2); let nibbles = Nibbles { count, packed: packed.into(), }; - let value_ptr = load(ptr + 3).as_usize(); + let value_ptr = u256_to_usize(load(ptr + 3))?; res.insert( prefix.merge_nibbles(&nibbles), - read_value(load_slice_from(value_ptr)), + read_value(load_slice_from(value_ptr))?, ); + + Ok(()) } } } diff --git a/evm/src/recursive_verifier.rs b/evm/src/recursive_verifier.rs index 1457344c..113dd287 100644 --- a/evm/src/recursive_verifier.rs +++ b/evm/src/recursive_verifier.rs @@ -935,7 +935,7 @@ where witness, &public_values_target.extra_block_data, &public_values.extra_block_data, - ); + )?; Ok(()) } @@ -1072,26 +1072,21 @@ pub(crate) fn set_extra_public_values_target( witness: &mut W, ed_target: &ExtraBlockDataTarget, ed: &ExtraBlockData, -) where +) -> Result<(), ProgramError> +where F: RichField + Extendable, W: Witness, { witness.set_target( ed_target.txn_number_before, - F::from_canonical_usize(ed.txn_number_before.as_usize()), + u256_to_u32(ed.txn_number_before)?, ); witness.set_target( ed_target.txn_number_after, - F::from_canonical_usize(ed.txn_number_after.as_usize()), - ); - witness.set_target( - ed_target.gas_used_before, - F::from_canonical_usize(ed.gas_used_before.as_usize()), - ); - witness.set_target( - ed_target.gas_used_after, - F::from_canonical_usize(ed.gas_used_after.as_usize()), + u256_to_u32(ed.txn_number_after)?, ); + witness.set_target(ed_target.gas_used_before, u256_to_u32(ed.gas_used_before)?); + witness.set_target(ed_target.gas_used_after, u256_to_u32(ed.gas_used_after)?); let block_bloom_before = ed.block_bloom_before; let mut block_bloom_limbs = [F::ZERO; 64]; @@ -1108,4 +1103,6 @@ pub(crate) fn set_extra_public_values_target( } witness.set_target_arr(&ed_target.block_bloom_after, &block_bloom_limbs); + + Ok(()) } diff --git a/evm/src/util.rs b/evm/src/util.rs index a3f6d050..08233056 100644 --- a/evm/src/util.rs +++ b/evm/src/util.rs @@ -70,6 +70,11 @@ pub(crate) fn u256_to_u64(u256: U256) -> Result<(F, F), ProgramError> )) } +/// Safe alternative to `U256::as_usize()`, which errors in case of overflow instead of panicking. +pub(crate) fn u256_to_usize(u256: U256) -> Result { + u256.try_into().map_err(|_| ProgramError::IntegerTooLarge) +} + #[allow(unused)] // TODO: Remove? /// Returns the 32-bit little-endian limbs of a `U256`. pub(crate) fn u256_limbs(u256: U256) -> [F; 8] { @@ -171,6 +176,8 @@ pub(crate) fn u256_to_biguint(x: U256) -> BigUint { pub(crate) fn biguint_to_u256(x: BigUint) -> U256 { let bytes = x.to_bytes_le(); + // This could panic if `bytes.len() > 32` but this is only + // used here with `BigUint` constructed from `U256`. U256::from_little_endian(&bytes) } diff --git a/evm/src/witness/errors.rs b/evm/src/witness/errors.rs index 1ab99eae..81862460 100644 --- a/evm/src/witness/errors.rs +++ b/evm/src/witness/errors.rs @@ -6,6 +6,7 @@ pub enum ProgramError { OutOfGas, InvalidOpcode, StackUnderflow, + InvalidRlp, InvalidJumpDestination, InvalidJumpiDestination, StackOverflow, @@ -14,6 +15,8 @@ pub enum ProgramError { GasLimitError, InterpreterError, IntegerTooLarge, + ProverInputError(ProverInputError), + UnknownContractCode, } #[allow(clippy::enum_variant_names)] @@ -23,3 +26,13 @@ pub enum MemoryError { SegmentTooLarge { segment: U256 }, VirtTooLarge { virt: U256 }, } + +#[derive(Debug)] +pub enum ProverInputError { + OutOfMptData, + OutOfRlpData, + CodeHashNotFound, + InvalidMptInput, + InvalidInput, + InvalidFunction, +} diff --git a/evm/src/witness/memory.rs b/evm/src/witness/memory.rs index 62e6a2fe..3b62c945 100644 --- a/evm/src/witness/memory.rs +++ b/evm/src/witness/memory.rs @@ -58,6 +58,8 @@ impl MemoryAddress { if virt.bits() > 32 { return Err(MemoryError(VirtTooLarge { virt })); } + + // Calling `as_usize` here is safe as those have been checked above. Ok(Self { context: context.as_usize(), segment: segment.as_usize(), diff --git a/evm/src/witness/operation.rs b/evm/src/witness/operation.rs index 8349d56d..2abeaea4 100644 --- a/evm/src/witness/operation.rs +++ b/evm/src/witness/operation.rs @@ -15,6 +15,7 @@ use crate::cpu::stack_bounds::MAX_USER_STACK_SIZE; use crate::extension_tower::BN_BASE; use crate::generation::state::GenerationState; use crate::memory::segments::Segment; +use crate::util::u256_to_usize; use crate::witness::errors::MemoryError::{ContextTooLarge, SegmentTooLarge, VirtTooLarge}; use crate::witness::errors::ProgramError; use crate::witness::errors::ProgramError::MemoryError; @@ -127,7 +128,7 @@ pub(crate) fn generate_keccak_general( row.is_keccak_sponge = F::ONE; let [(context, log_in0), (segment, log_in1), (base_virt, log_in2), (len, log_in3)] = stack_pop_with_log_and_fill::<4, _>(state, &mut row)?; - let len = len.as_usize(); + let len = u256_to_usize(len)?; let base_address = MemoryAddress::new_u256s(context, segment, base_virt)?; let input = (0..len) @@ -162,7 +163,7 @@ pub(crate) fn generate_prover_input( ) -> Result<(), ProgramError> { let pc = state.registers.program_counter; let input_fn = &KERNEL.prover_inputs[&pc]; - let input = state.prover_input(input_fn); + let input = state.prover_input(input_fn)?; let write = stack_push_log_and_fill(state, &mut row, input)?; state.traces.push_memory(write); @@ -217,7 +218,7 @@ pub(crate) fn generate_jump( state.traces.push_memory(log_in0); state.traces.push_cpu(row); - state.jump_to(dst as usize); + state.jump_to(dst as usize)?; Ok(()) } @@ -241,7 +242,7 @@ pub(crate) fn generate_jumpi( let dst: u32 = dst .try_into() .map_err(|_| ProgramError::InvalidJumpiDestination)?; - state.jump_to(dst as usize); + state.jump_to(dst as usize)?; } else { row.general.jumps_mut().should_jump = F::ZERO; row.general.jumps_mut().cond_sum_pinv = F::ZERO; @@ -312,7 +313,7 @@ pub(crate) fn generate_set_context( let [(ctx, log_in)] = stack_pop_with_log_and_fill::<1, _>(state, &mut row)?; let sp_to_save = state.registers.stack_len.into(); let old_ctx = state.registers.context; - let new_ctx = ctx.as_usize(); + let new_ctx = u256_to_usize(ctx)?; let sp_field = ContextMetadata::StackSize as usize; let old_sp_addr = MemoryAddress::new(old_ctx, Segment::ContextMetadata, sp_field); @@ -347,7 +348,8 @@ pub(crate) fn generate_set_context( }; state.registers.context = new_ctx; - state.registers.stack_len = new_sp.as_usize(); + let new_sp = u256_to_usize(new_sp)?; + state.registers.stack_len = new_sp; state.traces.push_memory(log_in); state.traces.push_memory(log_write_old_sp); state.traces.push_memory(log_read_new_sp); @@ -362,6 +364,10 @@ pub(crate) fn generate_push( ) -> Result<(), ProgramError> { let code_context = state.registers.code_context(); let num_bytes = n as usize; + if num_bytes > 32 { + // The call to `U256::from_big_endian()` would panic. + return Err(ProgramError::IntegerTooLarge); + } let initial_offset = state.registers.program_counter + 1; // First read val without going through `mem_read_with_log` type methods, so we can pass it @@ -589,7 +595,7 @@ pub(crate) fn generate_syscall( ); let handler_addr = (handler_addr0 << 16) + (handler_addr1 << 8) + handler_addr2; - let new_program_counter = handler_addr.as_usize(); + let new_program_counter = u256_to_usize(handler_addr)?; let syscall_info = U256::from(state.registers.program_counter + 1) + (U256::from(u64::from(state.registers.is_kernel)) << 32) @@ -694,7 +700,11 @@ pub(crate) fn generate_mload_32bytes( ) -> Result<(), ProgramError> { let [(context, log_in0), (segment, log_in1), (base_virt, log_in2), (len, log_in3)] = stack_pop_with_log_and_fill::<4, _>(state, &mut row)?; - let len = len.as_usize(); + let len = u256_to_usize(len)?; + if len > 32 { + // The call to `U256::from_big_endian()` would panic. + return Err(ProgramError::IntegerTooLarge); + } let base_address = MemoryAddress::new_u256s(context, segment, base_virt)?; if usize::MAX - base_address.virt < len { @@ -762,7 +772,7 @@ pub(crate) fn generate_mstore_32bytes( ) -> Result<(), ProgramError> { let [(context, log_in0), (segment, log_in1), (base_virt, log_in2), (val, log_in3), (len, log_in4)] = stack_pop_with_log_and_fill::<5, _>(state, &mut row)?; - let len = len.as_usize(); + let len = u256_to_usize(len)?; let base_address = MemoryAddress::new_u256s(context, segment, base_virt)?; @@ -827,7 +837,7 @@ pub(crate) fn generate_exception( ); let handler_addr = (handler_addr0 << 16) + (handler_addr1 << 8) + handler_addr2; - let new_program_counter = handler_addr.as_usize(); + let new_program_counter = u256_to_usize(handler_addr)?; let exc_info = U256::from(state.registers.program_counter) + (U256::from(state.registers.gas_used) << 192); diff --git a/evm/src/witness/util.rs b/evm/src/witness/util.rs index 94488614..068a8e11 100644 --- a/evm/src/witness/util.rs +++ b/evm/src/witness/util.rs @@ -29,11 +29,14 @@ fn to_bits_le(n: u8) -> [F; 8] { } /// Peek at the stack item `i`th from the top. If `i=0` this gives the tip. -pub(crate) fn stack_peek(state: &GenerationState, i: usize) -> Option { +pub(crate) fn stack_peek( + state: &GenerationState, + i: usize, +) -> Result { if i >= state.registers.stack_len { - return None; + return Err(ProgramError::StackUnderflow); } - Some(state.memory.get(MemoryAddress::new( + Ok(state.memory.get(MemoryAddress::new( state.registers.context, Segment::Stack, state.registers.stack_len - 1 - i,