diff --git a/evm/src/cpu/kernel/interpreter.rs b/evm/src/cpu/kernel/interpreter.rs index c4deba99..8f19a072 100644 --- a/evm/src/cpu/kernel/interpreter.rs +++ b/evm/src/cpu/kernel/interpreter.rs @@ -117,7 +117,7 @@ impl<'a> Interpreter<'a> { let mut result = Self { kernel_mode: true, jumpdests: find_jumpdests(code), - generation_state: GenerationState::new(GenerationInputs::default(), code), + generation_state: GenerationState::new(GenerationInputs::default(), code).unwrap(), prover_inputs_map: prover_inputs, context: 0, halt_offsets: vec![DEFAULT_HALT_OFFSET], @@ -905,7 +905,10 @@ impl<'a> Interpreter<'a> { .prover_inputs_map .get(&(self.generation_state.registers.program_counter - 1)) .ok_or_else(|| anyhow!("Offset not in prover inputs."))?; - let output = self.generation_state.prover_input(prover_input_fn); + let output = self + .generation_state + .prover_input(prover_input_fn) + .map_err(|_| anyhow!("Invalid prover inputs."))?; self.push(output); Ok(()) } diff --git a/evm/src/cpu/kernel/tests/account_code.rs b/evm/src/cpu/kernel/tests/account_code.rs index 805fed04..f4c18fe6 100644 --- a/evm/src/cpu/kernel/tests/account_code.rs +++ b/evm/src/cpu/kernel/tests/account_code.rs @@ -1,6 +1,6 @@ use std::collections::HashMap; -use anyhow::Result; +use anyhow::{anyhow, Result}; use eth_trie_utils::partial_trie::{HashedPartialTrie, PartialTrie}; use ethereum_types::{Address, BigEndianHash, H256, U256}; use keccak_hash::keccak; @@ -46,7 +46,9 @@ fn prepare_interpreter( interpreter.generation_state.registers.program_counter = load_all_mpts; interpreter.push(0xDEADBEEFu32.into()); - interpreter.generation_state.mpt_prover_inputs = all_mpt_prover_inputs_reversed(&trie_inputs); + interpreter.generation_state.mpt_prover_inputs = + all_mpt_prover_inputs_reversed(&trie_inputs) + .map_err(|err| anyhow!("Invalid MPT data: {:?}", err))?; interpreter.run()?; assert_eq!(interpreter.stack(), vec![]); diff --git a/evm/src/cpu/kernel/tests/balance.rs b/evm/src/cpu/kernel/tests/balance.rs index 049bf9f8..40214405 100644 --- a/evm/src/cpu/kernel/tests/balance.rs +++ b/evm/src/cpu/kernel/tests/balance.rs @@ -1,4 +1,4 @@ -use anyhow::Result; +use anyhow::{anyhow, Result}; use eth_trie_utils::partial_trie::{HashedPartialTrie, PartialTrie}; use ethereum_types::{Address, BigEndianHash, H256, U256}; use keccak_hash::keccak; @@ -37,7 +37,9 @@ fn prepare_interpreter( interpreter.generation_state.registers.program_counter = load_all_mpts; interpreter.push(0xDEADBEEFu32.into()); - interpreter.generation_state.mpt_prover_inputs = all_mpt_prover_inputs_reversed(&trie_inputs); + interpreter.generation_state.mpt_prover_inputs = + all_mpt_prover_inputs_reversed(&trie_inputs) + .map_err(|err| anyhow!("Invalid MPT data: {:?}", err))?; interpreter.run()?; assert_eq!(interpreter.stack(), vec![]); diff --git a/evm/src/cpu/kernel/tests/mpt/delete.rs b/evm/src/cpu/kernel/tests/mpt/delete.rs index 532a1603..42e8caf9 100644 --- a/evm/src/cpu/kernel/tests/mpt/delete.rs +++ b/evm/src/cpu/kernel/tests/mpt/delete.rs @@ -1,4 +1,4 @@ -use anyhow::Result; +use anyhow::{anyhow, Result}; use eth_trie_utils::nibbles::Nibbles; use eth_trie_utils::partial_trie::{HashedPartialTrie, PartialTrie}; use ethereum_types::{BigEndianHash, H256}; @@ -61,7 +61,8 @@ fn test_state_trie( let initial_stack = vec![0xDEADBEEFu32.into()]; let mut interpreter = Interpreter::new_with_kernel(load_all_mpts, initial_stack); - interpreter.generation_state.mpt_prover_inputs = all_mpt_prover_inputs_reversed(&trie_inputs); + interpreter.generation_state.mpt_prover_inputs = + all_mpt_prover_inputs_reversed(&trie_inputs).map_err(|_| anyhow!("Invalid MPT data"))?; interpreter.run()?; assert_eq!(interpreter.stack(), vec![]); diff --git a/evm/src/cpu/kernel/tests/mpt/hash.rs b/evm/src/cpu/kernel/tests/mpt/hash.rs index 3d6c2a23..05077a94 100644 --- a/evm/src/cpu/kernel/tests/mpt/hash.rs +++ b/evm/src/cpu/kernel/tests/mpt/hash.rs @@ -1,4 +1,4 @@ -use anyhow::Result; +use anyhow::{anyhow, Result}; use eth_trie_utils::partial_trie::PartialTrie; use ethereum_types::{BigEndianHash, H256}; @@ -113,7 +113,8 @@ fn test_state_trie(trie_inputs: TrieInputs) -> Result<()> { let initial_stack = vec![0xDEADBEEFu32.into()]; let mut interpreter = Interpreter::new_with_kernel(load_all_mpts, initial_stack); - interpreter.generation_state.mpt_prover_inputs = all_mpt_prover_inputs_reversed(&trie_inputs); + interpreter.generation_state.mpt_prover_inputs = + all_mpt_prover_inputs_reversed(&trie_inputs).map_err(|_| anyhow!("Invalid MPT data"))?; interpreter.run()?; assert_eq!(interpreter.stack(), vec![]); diff --git a/evm/src/cpu/kernel/tests/mpt/insert.rs b/evm/src/cpu/kernel/tests/mpt/insert.rs index f8dbc274..6fd95a30 100644 --- a/evm/src/cpu/kernel/tests/mpt/insert.rs +++ b/evm/src/cpu/kernel/tests/mpt/insert.rs @@ -1,4 +1,4 @@ -use anyhow::Result; +use anyhow::{anyhow, Result}; use eth_trie_utils::nibbles::Nibbles; use eth_trie_utils::partial_trie::{HashedPartialTrie, PartialTrie}; use ethereum_types::{BigEndianHash, H256}; @@ -174,7 +174,8 @@ fn test_state_trie( let initial_stack = vec![0xDEADBEEFu32.into()]; let mut interpreter = Interpreter::new_with_kernel(load_all_mpts, initial_stack); - interpreter.generation_state.mpt_prover_inputs = all_mpt_prover_inputs_reversed(&trie_inputs); + interpreter.generation_state.mpt_prover_inputs = + all_mpt_prover_inputs_reversed(&trie_inputs).map_err(|_| anyhow!("Invalid MPT data"))?; interpreter.run()?; assert_eq!(interpreter.stack(), vec![]); diff --git a/evm/src/cpu/kernel/tests/mpt/load.rs b/evm/src/cpu/kernel/tests/mpt/load.rs index aed311d2..50a8a0ef 100644 --- a/evm/src/cpu/kernel/tests/mpt/load.rs +++ b/evm/src/cpu/kernel/tests/mpt/load.rs @@ -1,4 +1,4 @@ -use anyhow::Result; +use anyhow::{anyhow, Result}; use ethereum_types::{BigEndianHash, H256, U256}; use crate::cpu::kernel::aggregator::KERNEL; @@ -23,7 +23,9 @@ fn load_all_mpts_empty() -> Result<()> { let initial_stack = vec![0xDEADBEEFu32.into()]; let mut interpreter = Interpreter::new_with_kernel(load_all_mpts, initial_stack); - interpreter.generation_state.mpt_prover_inputs = all_mpt_prover_inputs_reversed(&trie_inputs); + interpreter.generation_state.mpt_prover_inputs = + all_mpt_prover_inputs_reversed(&trie_inputs) + .map_err(|err| anyhow!("Invalid MPT data: {:?}", err))?; interpreter.run()?; assert_eq!(interpreter.stack(), vec![]); @@ -62,7 +64,9 @@ fn load_all_mpts_leaf() -> Result<()> { let initial_stack = vec![0xDEADBEEFu32.into()]; let mut interpreter = Interpreter::new_with_kernel(load_all_mpts, initial_stack); - interpreter.generation_state.mpt_prover_inputs = all_mpt_prover_inputs_reversed(&trie_inputs); + interpreter.generation_state.mpt_prover_inputs = + all_mpt_prover_inputs_reversed(&trie_inputs) + .map_err(|err| anyhow!("Invalid MPT data: {:?}", err))?; interpreter.run()?; assert_eq!(interpreter.stack(), vec![]); @@ -111,7 +115,9 @@ fn load_all_mpts_hash() -> Result<()> { let initial_stack = vec![0xDEADBEEFu32.into()]; let mut interpreter = Interpreter::new_with_kernel(load_all_mpts, initial_stack); - interpreter.generation_state.mpt_prover_inputs = all_mpt_prover_inputs_reversed(&trie_inputs); + interpreter.generation_state.mpt_prover_inputs = + all_mpt_prover_inputs_reversed(&trie_inputs) + .map_err(|err| anyhow!("Invalid MPT data: {:?}", err))?; interpreter.run()?; assert_eq!(interpreter.stack(), vec![]); @@ -152,7 +158,9 @@ fn load_all_mpts_empty_branch() -> Result<()> { let initial_stack = vec![0xDEADBEEFu32.into()]; let mut interpreter = Interpreter::new_with_kernel(load_all_mpts, initial_stack); - interpreter.generation_state.mpt_prover_inputs = all_mpt_prover_inputs_reversed(&trie_inputs); + interpreter.generation_state.mpt_prover_inputs = + all_mpt_prover_inputs_reversed(&trie_inputs) + .map_err(|err| anyhow!("Invalid MPT data: {:?}", err))?; interpreter.run()?; assert_eq!(interpreter.stack(), vec![]); @@ -207,7 +215,9 @@ fn load_all_mpts_ext_to_leaf() -> Result<()> { let initial_stack = vec![0xDEADBEEFu32.into()]; let mut interpreter = Interpreter::new_with_kernel(load_all_mpts, initial_stack); - interpreter.generation_state.mpt_prover_inputs = all_mpt_prover_inputs_reversed(&trie_inputs); + interpreter.generation_state.mpt_prover_inputs = + all_mpt_prover_inputs_reversed(&trie_inputs) + .map_err(|err| anyhow!("Invalid MPT data: {:?}", err))?; interpreter.run()?; assert_eq!(interpreter.stack(), vec![]); diff --git a/evm/src/cpu/kernel/tests/mpt/read.rs b/evm/src/cpu/kernel/tests/mpt/read.rs index 62313f62..f9ae94f0 100644 --- a/evm/src/cpu/kernel/tests/mpt/read.rs +++ b/evm/src/cpu/kernel/tests/mpt/read.rs @@ -1,4 +1,4 @@ -use anyhow::Result; +use anyhow::{anyhow, Result}; use ethereum_types::BigEndianHash; use crate::cpu::kernel::aggregator::KERNEL; @@ -22,7 +22,9 @@ fn mpt_read() -> Result<()> { let initial_stack = vec![0xdeadbeefu32.into()]; let mut interpreter = Interpreter::new_with_kernel(load_all_mpts, initial_stack); - interpreter.generation_state.mpt_prover_inputs = all_mpt_prover_inputs_reversed(&trie_inputs); + interpreter.generation_state.mpt_prover_inputs = + all_mpt_prover_inputs_reversed(&trie_inputs) + .map_err(|err| anyhow!("Invalid MPT data: {:?}", err))?; interpreter.run()?; assert_eq!(interpreter.stack(), vec![]); diff --git a/evm/src/cpu/kernel/tests/receipt.rs b/evm/src/cpu/kernel/tests/receipt.rs index 783f592b..b5583654 100644 --- a/evm/src/cpu/kernel/tests/receipt.rs +++ b/evm/src/cpu/kernel/tests/receipt.rs @@ -1,4 +1,4 @@ -use anyhow::Result; +use anyhow::{anyhow, Result}; use ethereum_types::{Address, U256}; use hex_literal::hex; use keccak_hash::keccak; @@ -413,7 +413,9 @@ fn test_mpt_insert_receipt() -> Result<()> { let initial_stack = vec![retdest]; let mut interpreter = Interpreter::new_with_kernel(load_all_mpts, initial_stack); - interpreter.generation_state.mpt_prover_inputs = all_mpt_prover_inputs_reversed(&trie_inputs); + interpreter.generation_state.mpt_prover_inputs = + all_mpt_prover_inputs_reversed(&trie_inputs) + .map_err(|err| anyhow!("Invalid MPT data: {:?}", err))?; interpreter.run()?; // If TrieData is empty, we need to push 0 because the first value is always 0. diff --git a/evm/src/generation/mod.rs b/evm/src/generation/mod.rs index 3f5bafba..6b9ce000 100644 --- a/evm/src/generation/mod.rs +++ b/evm/src/generation/mod.rs @@ -1,5 +1,6 @@ use std::collections::HashMap; +use anyhow::anyhow; use eth_trie_utils::partial_trie::{HashedPartialTrie, PartialTrie}; use ethereum_types::{Address, BigEndianHash, H256, U256}; use plonky2::field::extension::Extendable; @@ -220,7 +221,8 @@ pub fn generate_traces, const D: usize>( PublicValues, GenerationOutputs, )> { - let mut state = GenerationState::::new(inputs.clone(), &KERNEL.code); + let mut state = GenerationState::::new(inputs.clone(), &KERNEL.code) + .map_err(|err| anyhow!("Failed to parse all the initial prover inputs: {:?}", err))?; apply_metadata_and_tries_memops(&mut state, &inputs); @@ -238,7 +240,8 @@ pub fn generate_traces, const D: usize>( state.traces.get_lengths() ); - let outputs = get_outputs(&mut state); + let outputs = get_outputs(&mut state) + .map_err(|err| anyhow!("Failed to generate post-state info: {:?}", err))?; let read_metadata = |field| state.memory.read_global_metadata(field); let trie_roots_before = TrieRoots { diff --git a/evm/src/generation/mpt.rs b/evm/src/generation/mpt.rs index 47129ed0..dbc36cac 100644 --- a/evm/src/generation/mpt.rs +++ b/evm/src/generation/mpt.rs @@ -11,6 +11,7 @@ use rlp_derive::{RlpDecodable, RlpEncodable}; use crate::cpu::kernel::constants::trie_type::PartialTrieType; use crate::generation::TrieInputs; +use crate::witness::errors::{ProgramError, ProverInputError}; use crate::Node; #[derive(RlpEncodable, RlpDecodable, Debug)] @@ -60,15 +61,18 @@ pub struct LegacyReceiptRlp { pub logs: Vec, } -pub(crate) fn all_mpt_prover_inputs_reversed(trie_inputs: &TrieInputs) -> Vec { - let mut inputs = all_mpt_prover_inputs(trie_inputs); +pub(crate) fn all_mpt_prover_inputs_reversed( + trie_inputs: &TrieInputs, +) -> Result, ProgramError> { + let mut inputs = all_mpt_prover_inputs(trie_inputs)?; inputs.reverse(); - inputs + Ok(inputs) } -pub(crate) fn parse_receipts(rlp: &[u8]) -> Vec { - let payload_info = PayloadInfo::from(rlp).unwrap(); - let decoded_receipt: LegacyReceiptRlp = rlp::decode(rlp).unwrap(); +pub(crate) fn parse_receipts(rlp: &[u8]) -> Result, ProgramError> { + let payload_info = PayloadInfo::from(rlp).map_err(|_| ProgramError::InvalidRlp)?; + let decoded_receipt: LegacyReceiptRlp = + rlp::decode(rlp).map_err(|_| ProgramError::InvalidRlp)?; let mut parsed_receipt = Vec::new(); parsed_receipt.push(payload_info.value_len.into()); // payload_len of the entire receipt @@ -76,13 +80,15 @@ pub(crate) fn parse_receipts(rlp: &[u8]) -> Vec { parsed_receipt.push(decoded_receipt.cum_gas_used); parsed_receipt.extend(decoded_receipt.bloom.iter().map(|byte| U256::from(*byte))); let encoded_logs = rlp::encode_list(&decoded_receipt.logs); - let logs_payload_info = PayloadInfo::from(&encoded_logs).unwrap(); + let logs_payload_info = + PayloadInfo::from(&encoded_logs).map_err(|_| ProgramError::InvalidRlp)?; parsed_receipt.push(logs_payload_info.value_len.into()); // payload_len of all the logs parsed_receipt.push(decoded_receipt.logs.len().into()); for log in decoded_receipt.logs { let encoded_log = rlp::encode(&log); - let log_payload_info = PayloadInfo::from(&encoded_log).unwrap(); + let log_payload_info = + PayloadInfo::from(&encoded_log).map_err(|_| ProgramError::InvalidRlp)?; parsed_receipt.push(log_payload_info.value_len.into()); // payload of one log parsed_receipt.push(U256::from_big_endian(&log.address.to_fixed_bytes())); parsed_receipt.push(log.topics.len().into()); @@ -91,10 +97,10 @@ pub(crate) fn parse_receipts(rlp: &[u8]) -> Vec { parsed_receipt.extend(log.data.iter().map(|byte| U256::from(*byte))); } - parsed_receipt + Ok(parsed_receipt) } /// Generate prover inputs for the initial MPT data, in the format expected by `mpt/load.asm`. -pub(crate) fn all_mpt_prover_inputs(trie_inputs: &TrieInputs) -> Vec { +pub(crate) fn all_mpt_prover_inputs(trie_inputs: &TrieInputs) -> Result, ProgramError> { let mut prover_inputs = vec![]; let storage_tries_by_state_key = trie_inputs @@ -111,19 +117,19 @@ pub(crate) fn all_mpt_prover_inputs(trie_inputs: &TrieInputs) -> Vec { empty_nibbles(), &mut prover_inputs, &storage_tries_by_state_key, - ); + )?; mpt_prover_inputs(&trie_inputs.transactions_trie, &mut prover_inputs, &|rlp| { - rlp::decode_list(rlp) - }); + Ok(rlp::decode_list(rlp)) + })?; mpt_prover_inputs( &trie_inputs.receipts_trie, &mut prover_inputs, &parse_receipts, - ); + )?; - prover_inputs + Ok(prover_inputs) } /// Given a trie, generate the prover input data for that trie. In essence, this serializes a trie @@ -134,36 +140,52 @@ pub(crate) fn mpt_prover_inputs( trie: &HashedPartialTrie, prover_inputs: &mut Vec, parse_value: &F, -) where - F: Fn(&[u8]) -> Vec, +) -> Result<(), ProgramError> +where + F: Fn(&[u8]) -> Result, ProgramError>, { prover_inputs.push((PartialTrieType::of(trie) as u32).into()); match trie.deref() { - Node::Empty => {} - Node::Hash(h) => prover_inputs.push(U256::from_big_endian(h.as_bytes())), + Node::Empty => Ok(()), + Node::Hash(h) => { + prover_inputs.push(U256::from_big_endian(h.as_bytes())); + Ok(()) + } Node::Branch { children, value } => { if value.is_empty() { prover_inputs.push(U256::zero()); // value_present = 0 } else { - let parsed_value = parse_value(value); + let parsed_value = parse_value(value)?; prover_inputs.push(U256::one()); // value_present = 1 prover_inputs.extend(parsed_value); } for child in children { - mpt_prover_inputs(child, prover_inputs, parse_value); + mpt_prover_inputs(child, prover_inputs, parse_value)?; } + + Ok(()) } Node::Extension { nibbles, child } => { prover_inputs.push(nibbles.count.into()); - prover_inputs.push(nibbles.try_into_u256().unwrap()); - mpt_prover_inputs(child, prover_inputs, parse_value); + prover_inputs.push( + nibbles + .try_into_u256() + .map_err(|_| ProgramError::IntegerTooLarge)?, + ); + mpt_prover_inputs(child, prover_inputs, parse_value) } Node::Leaf { nibbles, value } => { prover_inputs.push(nibbles.count.into()); - prover_inputs.push(nibbles.try_into_u256().unwrap()); - let leaf = parse_value(value); + prover_inputs.push( + nibbles + .try_into_u256() + .map_err(|_| ProgramError::IntegerTooLarge)?, + ); + let leaf = parse_value(value)?; prover_inputs.extend(leaf); + + Ok(()) } } } @@ -175,13 +197,20 @@ pub(crate) fn mpt_prover_inputs_state_trie( key: Nibbles, prover_inputs: &mut Vec, storage_tries_by_state_key: &HashMap, -) { +) -> Result<(), ProgramError> { prover_inputs.push((PartialTrieType::of(trie) as u32).into()); match trie.deref() { - Node::Empty => {} - Node::Hash(h) => prover_inputs.push(U256::from_big_endian(h.as_bytes())), + Node::Empty => Ok(()), + Node::Hash(h) => { + prover_inputs.push(U256::from_big_endian(h.as_bytes())); + Ok(()) + } Node::Branch { children, value } => { - assert!(value.is_empty(), "State trie should not have branch values"); + if !value.is_empty() { + return Err(ProgramError::ProverInputError( + ProverInputError::InvalidMptInput, + )); + } prover_inputs.push(U256::zero()); // value_present = 0 for (i, child) in children.iter().enumerate() { @@ -194,22 +223,28 @@ pub(crate) fn mpt_prover_inputs_state_trie( extended_key, prover_inputs, storage_tries_by_state_key, - ); + )?; } + + Ok(()) } Node::Extension { nibbles, child } => { prover_inputs.push(nibbles.count.into()); - prover_inputs.push(nibbles.try_into_u256().unwrap()); + prover_inputs.push( + nibbles + .try_into_u256() + .map_err(|_| ProgramError::IntegerTooLarge)?, + ); let extended_key = key.merge_nibbles(nibbles); mpt_prover_inputs_state_trie( child, extended_key, prover_inputs, storage_tries_by_state_key, - ); + ) } Node::Leaf { nibbles, value } => { - let account: AccountRlp = rlp::decode(value).expect("Decoding failed"); + let account: AccountRlp = rlp::decode(value).map_err(|_| ProgramError::InvalidRlp)?; let AccountRlp { nonce, balance, @@ -228,18 +263,24 @@ pub(crate) fn mpt_prover_inputs_state_trie( "In TrieInputs, an account's storage_root didn't match the associated storage trie hash"); prover_inputs.push(nibbles.count.into()); - prover_inputs.push(nibbles.try_into_u256().unwrap()); + prover_inputs.push( + nibbles + .try_into_u256() + .map_err(|_| ProgramError::IntegerTooLarge)?, + ); prover_inputs.push(nonce); prover_inputs.push(balance); - mpt_prover_inputs(storage_trie, prover_inputs, &parse_storage_value); + mpt_prover_inputs(storage_trie, prover_inputs, &parse_storage_value)?; prover_inputs.push(code_hash.into_uint()); + + Ok(()) } } } -fn parse_storage_value(value_rlp: &[u8]) -> Vec { - let value: U256 = rlp::decode(value_rlp).expect("Decoding failed"); - vec![value] +fn parse_storage_value(value_rlp: &[u8]) -> Result, ProgramError> { + let value: U256 = rlp::decode(value_rlp).map_err(|_| ProgramError::InvalidRlp)?; + Ok(vec![value]) } fn empty_nibbles() -> Nibbles { diff --git a/evm/src/generation/outputs.rs b/evm/src/generation/outputs.rs index 63a86906..0ce87082 100644 --- a/evm/src/generation/outputs.rs +++ b/evm/src/generation/outputs.rs @@ -8,6 +8,8 @@ use crate::generation::state::GenerationState; use crate::generation::trie_extractor::{ read_state_trie_value, read_storage_trie_value, read_trie, AccountTrieRecord, }; +use crate::util::u256_to_usize; +use crate::witness::errors::ProgramError; /// The post-state after trace generation; intended for debugging. #[derive(Clone, Debug)] @@ -29,47 +31,44 @@ pub struct AccountOutput { pub storage: HashMap, } -pub(crate) fn get_outputs(state: &mut GenerationState) -> GenerationOutputs { - // First observe all addresses passed in the by caller. +pub(crate) fn get_outputs( + state: &mut GenerationState, +) -> Result { + // First observe all addresses passed in by caller. for address in state.inputs.addresses.clone() { state.observe_address(address); } - let account_map = read_trie::( - &state.memory, - state.memory.read_global_metadata(StateTrieRoot).as_usize(), - read_state_trie_value, - ); + let ptr = u256_to_usize(state.memory.read_global_metadata(StateTrieRoot))?; + let account_map = read_trie::(&state.memory, ptr, read_state_trie_value)?; - let accounts = account_map - .into_iter() - .map(|(state_key_nibbles, account)| { - assert_eq!( - state_key_nibbles.count, 64, - "Each state key should have 64 nibbles = 256 bits" - ); - let state_key_h256 = H256::from_uint(&state_key_nibbles.try_into_u256().unwrap()); + let mut accounts = HashMap::with_capacity(account_map.len()); - let addr_or_state_key = - if let Some(address) = state.state_key_to_address.get(&state_key_h256) { - AddressOrStateKey::Address(*address) - } else { - AddressOrStateKey::StateKey(state_key_h256) - }; + for (state_key_nibbles, account) in account_map.into_iter() { + if state_key_nibbles.count != 64 { + return Err(ProgramError::IntegerTooLarge); + } + let state_key_h256 = H256::from_uint(&state_key_nibbles.try_into_u256().unwrap()); - let account_output = account_trie_record_to_output(state, account); - (addr_or_state_key, account_output) - }) - .collect(); + let addr_or_state_key = + if let Some(address) = state.state_key_to_address.get(&state_key_h256) { + AddressOrStateKey::Address(*address) + } else { + AddressOrStateKey::StateKey(state_key_h256) + }; - GenerationOutputs { accounts } + let account_output = account_trie_record_to_output(state, account)?; + accounts.insert(addr_or_state_key, account_output); + } + + Ok(GenerationOutputs { accounts }) } fn account_trie_record_to_output( state: &GenerationState, account: AccountTrieRecord, -) -> AccountOutput { - let storage = get_storage(state, account.storage_ptr); +) -> Result { + let storage = get_storage(state, account.storage_ptr)?; // TODO: This won't work if the account was created during the txn. // Need to track changes to code, similar to how we track addresses @@ -78,27 +77,33 @@ fn account_trie_record_to_output( .inputs .contract_code .get(&account.code_hash) - .unwrap_or_else(|| panic!("Code not found: {:?}", account.code_hash)) + .ok_or(ProgramError::UnknownContractCode)? .clone(); - AccountOutput { + Ok(AccountOutput { balance: account.balance, nonce: account.nonce, storage, code, - } + }) } /// Get an account's storage trie, given a pointer to its root. -fn get_storage(state: &GenerationState, storage_ptr: usize) -> HashMap { - read_trie::(&state.memory, storage_ptr, read_storage_trie_value) - .into_iter() - .map(|(storage_key_nibbles, value)| { - assert_eq!( - storage_key_nibbles.count, 64, - "Each storage key should have 64 nibbles = 256 bits" - ); - (storage_key_nibbles.try_into_u256().unwrap(), value) - }) - .collect() +fn get_storage( + state: &GenerationState, + storage_ptr: usize, +) -> Result, ProgramError> { + let storage_trie = read_trie::(&state.memory, storage_ptr, |x| { + Ok(read_storage_trie_value(x)) + })?; + + let mut map = HashMap::with_capacity(storage_trie.len()); + for (storage_key_nibbles, value) in storage_trie.into_iter() { + if storage_key_nibbles.count != 64 { + return Err(ProgramError::IntegerTooLarge); + }; + map.insert(storage_key_nibbles.try_into_u256().unwrap(), value); + } + + Ok(map) } diff --git a/evm/src/generation/prover_input.rs b/evm/src/generation/prover_input.rs index 14293289..205dff7c 100644 --- a/evm/src/generation/prover_input.rs +++ b/evm/src/generation/prover_input.rs @@ -16,7 +16,9 @@ use crate::generation::prover_input::FieldOp::{Inverse, Sqrt}; use crate::generation::state::GenerationState; use crate::memory::segments::Segment; use crate::memory::segments::Segment::BnPairing; -use crate::util::{biguint_to_mem_vec, mem_vec_to_biguint}; +use crate::util::{biguint_to_mem_vec, mem_vec_to_biguint, u256_to_usize}; +use crate::witness::errors::ProgramError; +use crate::witness::errors::ProverInputError::*; use crate::witness::util::{current_context_peek, stack_peek}; /// Prover input function represented as a scoped function name. @@ -31,7 +33,7 @@ impl From> for ProverInputFn { } impl GenerationState { - pub(crate) fn prover_input(&mut self, input_fn: &ProverInputFn) -> U256 { + pub(crate) fn prover_input(&mut self, input_fn: &ProverInputFn) -> Result { match input_fn.0[0].as_str() { "end_of_txns" => self.run_end_of_txns(), "ff" => self.run_ff(input_fn), @@ -42,51 +44,59 @@ impl GenerationState { "current_hash" => self.run_current_hash(), "account_code" => self.run_account_code(input_fn), "bignum_modmul" => self.run_bignum_modmul(), - _ => panic!("Unrecognized prover input function."), + _ => Err(ProgramError::ProverInputError(InvalidFunction)), } } - fn run_end_of_txns(&mut self) -> U256 { + fn run_end_of_txns(&mut self) -> Result { let end = self.next_txn_index == self.inputs.signed_txns.len(); if end { - U256::one() + Ok(U256::one()) } else { self.next_txn_index += 1; - U256::zero() + Ok(U256::zero()) } } /// Finite field operations. - fn run_ff(&self, input_fn: &ProverInputFn) -> U256 { - let field = EvmField::from_str(input_fn.0[1].as_str()).unwrap(); - let op = FieldOp::from_str(input_fn.0[2].as_str()).unwrap(); - let x = stack_peek(self, 0).expect("Empty stack"); + fn run_ff(&self, input_fn: &ProverInputFn) -> Result { + let field = EvmField::from_str(input_fn.0[1].as_str()) + .map_err(|_| ProgramError::ProverInputError(InvalidFunction))?; + let op = FieldOp::from_str(input_fn.0[2].as_str()) + .map_err(|_| ProgramError::ProverInputError(InvalidFunction))?; + let x = stack_peek(self, 0)?; field.op(op, x) } /// Special finite field operations. - fn run_sf(&self, input_fn: &ProverInputFn) -> U256 { - let field = EvmField::from_str(input_fn.0[1].as_str()).unwrap(); + fn run_sf(&self, input_fn: &ProverInputFn) -> Result { + let field = EvmField::from_str(input_fn.0[1].as_str()) + .map_err(|_| ProgramError::ProverInputError(InvalidFunction))?; let inputs: [U256; 4] = match field { - Bls381Base => std::array::from_fn(|i| { - stack_peek(self, i).expect("Insufficient number of items on stack") - }), + Bls381Base => (0..4) + .map(|i| stack_peek(self, i)) + .collect::, _>>()? + .try_into() + .unwrap(), _ => todo!(), }; - match input_fn.0[2].as_str() { + let res = match input_fn.0[2].as_str() { "add_lo" => field.add_lo(inputs), "add_hi" => field.add_hi(inputs), "mul_lo" => field.mul_lo(inputs), "mul_hi" => field.mul_hi(inputs), "sub_lo" => field.sub_lo(inputs), "sub_hi" => field.sub_hi(inputs), - _ => todo!(), - } + _ => return Err(ProgramError::ProverInputError(InvalidFunction)), + }; + + Ok(res) } /// Finite field extension operations. - fn run_ffe(&self, input_fn: &ProverInputFn) -> U256 { - let field = EvmField::from_str(input_fn.0[1].as_str()).unwrap(); + fn run_ffe(&self, input_fn: &ProverInputFn) -> Result { + let field = EvmField::from_str(input_fn.0[1].as_str()) + .map_err(|_| ProgramError::ProverInputError(InvalidFunction))?; let n = input_fn.0[2] .as_str() .split('_') @@ -94,61 +104,61 @@ impl GenerationState { .unwrap() .parse::() .unwrap(); - let ptr = stack_peek(self, 11 - n) - .expect("Insufficient number of items on stack") - .as_usize(); + let ptr = stack_peek(self, 11 - n).map(u256_to_usize)??; let f: [U256; 12] = match field { Bn254Base => std::array::from_fn(|i| current_context_peek(self, BnPairing, ptr + i)), _ => todo!(), }; - field.field_extension_inverse(n, f) + Ok(field.field_extension_inverse(n, f)) } /// MPT data. - fn run_mpt(&mut self) -> U256 { + fn run_mpt(&mut self) -> Result { self.mpt_prover_inputs .pop() - .unwrap_or_else(|| panic!("Out of MPT data")) + .ok_or(ProgramError::ProverInputError(OutOfMptData)) } /// RLP data. - fn run_rlp(&mut self) -> U256 { + fn run_rlp(&mut self) -> Result { self.rlp_prover_inputs .pop() - .unwrap_or_else(|| panic!("Out of RLP data")) + .ok_or(ProgramError::ProverInputError(OutOfRlpData)) } - fn run_current_hash(&mut self) -> U256 { - U256::from_big_endian(&self.inputs.block_hashes.cur_hash.0) + fn run_current_hash(&mut self) -> Result { + Ok(U256::from_big_endian(&self.inputs.block_hashes.cur_hash.0)) } /// Account code. - fn run_account_code(&mut self, input_fn: &ProverInputFn) -> U256 { + fn run_account_code(&mut self, input_fn: &ProverInputFn) -> Result { match input_fn.0[1].as_str() { "length" => { // Return length of code. // stack: codehash, ... - let codehash = stack_peek(self, 0).expect("Empty stack"); - self.inputs + let codehash = stack_peek(self, 0)?; + Ok(self + .inputs .contract_code .get(&H256::from_uint(&codehash)) - .unwrap_or_else(|| panic!("No code found with hash {codehash}")) + .ok_or(ProgramError::ProverInputError(CodeHashNotFound))? .len() - .into() + .into()) } "get" => { // Return `code[i]`. // stack: i, code_length, codehash, ... - let i = stack_peek(self, 0).expect("Unexpected stack").as_usize(); - let codehash = stack_peek(self, 2).expect("Unexpected stack"); - self.inputs + let i = stack_peek(self, 0).map(u256_to_usize)??; + let codehash = stack_peek(self, 2)?; + Ok(self + .inputs .contract_code .get(&H256::from_uint(&codehash)) - .unwrap_or_else(|| panic!("No code found with hash {codehash}"))[i] - .into() + .ok_or(ProgramError::ProverInputError(CodeHashNotFound))?[i] + .into()) } - _ => panic!("Invalid prover input function."), + _ => Err(ProgramError::ProverInputError(InvalidInput)), } } @@ -156,24 +166,12 @@ impl GenerationState { // On the first call, calculates the remainder and quotient of the given inputs. // These are stored, as limbs, in self.bignum_modmul_result_limbs. // Subsequent calls return one limb at a time, in order (first remainder and then quotient). - fn run_bignum_modmul(&mut self) -> U256 { + fn run_bignum_modmul(&mut self) -> Result { if self.bignum_modmul_result_limbs.is_empty() { - let len = stack_peek(self, 1) - .expect("Stack does not have enough items") - .try_into() - .unwrap(); - let a_start_loc = stack_peek(self, 2) - .expect("Stack does not have enough items") - .try_into() - .unwrap(); - let b_start_loc = stack_peek(self, 3) - .expect("Stack does not have enough items") - .try_into() - .unwrap(); - let m_start_loc = stack_peek(self, 4) - .expect("Stack does not have enough items") - .try_into() - .unwrap(); + let len = stack_peek(self, 1).map(u256_to_usize)??; + let a_start_loc = stack_peek(self, 2).map(u256_to_usize)??; + let b_start_loc = stack_peek(self, 3).map(u256_to_usize)??; + let m_start_loc = stack_peek(self, 4).map(u256_to_usize)??; let (remainder, quotient) = self.bignum_modmul(len, a_start_loc, b_start_loc, m_start_loc); @@ -187,7 +185,9 @@ impl GenerationState { self.bignum_modmul_result_limbs.reverse(); } - self.bignum_modmul_result_limbs.pop().unwrap() + self.bignum_modmul_result_limbs + .pop() + .ok_or(ProgramError::ProverInputError(InvalidInput)) } fn bignum_modmul( @@ -284,27 +284,33 @@ impl EvmField { } } - fn op(&self, op: FieldOp, x: U256) -> U256 { + fn op(&self, op: FieldOp, x: U256) -> Result { match op { FieldOp::Inverse => self.inverse(x), FieldOp::Sqrt => self.sqrt(x), } } - fn inverse(&self, x: U256) -> U256 { + fn inverse(&self, x: U256) -> Result { let n = self.order(); - assert!(x < n); + if x >= n { + return Err(ProgramError::ProverInputError(InvalidInput)); + }; modexp(x, n - 2, n) } - fn sqrt(&self, x: U256) -> U256 { + fn sqrt(&self, x: U256) -> Result { let n = self.order(); - assert!(x < n); + if x >= n { + return Err(ProgramError::ProverInputError(InvalidInput)); + }; let (q, r) = (n + 1).div_mod(4.into()); - assert!( - r.is_zero(), - "Only naive sqrt implementation for now. If needed implement Tonelli-Shanks." - ); + + if !r.is_zero() { + return Err(ProgramError::ProverInputError(InvalidInput)); + }; + + // Only naive sqrt implementation for now. If needed implement Tonelli-Shanks modexp(x, q, n) } @@ -363,15 +369,18 @@ impl EvmField { } } -fn modexp(x: U256, e: U256, n: U256) -> U256 { +fn modexp(x: U256, e: U256, n: U256) -> Result { let mut current = x; let mut product = U256::one(); for j in 0..256 { if e.bit(j) { - product = U256::try_from(product.full_mul(current) % n).unwrap(); + product = U256::try_from(product.full_mul(current) % n) + .map_err(|_| ProgramError::ProverInputError(InvalidInput))?; } - current = U256::try_from(current.full_mul(current) % n).unwrap(); + current = U256::try_from(current.full_mul(current) % n) + .map_err(|_| ProgramError::ProverInputError(InvalidInput))?; } - product + + Ok(product) } diff --git a/evm/src/generation/state.rs b/evm/src/generation/state.rs index 2b85821f..aec01e1b 100644 --- a/evm/src/generation/state.rs +++ b/evm/src/generation/state.rs @@ -10,6 +10,8 @@ use crate::generation::mpt::all_mpt_prover_inputs_reversed; use crate::generation::rlp::all_rlp_prover_inputs_reversed; use crate::generation::GenerationInputs; use crate::memory::segments::Segment; +use crate::util::u256_to_usize; +use crate::witness::errors::ProgramError; use crate::witness::memory::{MemoryAddress, MemoryState}; use crate::witness::state::RegistersState; use crate::witness::traces::{TraceCheckpoint, Traces}; @@ -49,7 +51,7 @@ pub(crate) struct GenerationState { } impl GenerationState { - pub(crate) fn new(inputs: GenerationInputs, kernel_code: &[u8]) -> Self { + pub(crate) fn new(inputs: GenerationInputs, kernel_code: &[u8]) -> Result { log::debug!("Input signed_txns: {:?}", &inputs.signed_txns); log::debug!("Input state_trie: {:?}", &inputs.tries.state_trie); log::debug!( @@ -59,11 +61,11 @@ impl GenerationState { log::debug!("Input receipts_trie: {:?}", &inputs.tries.receipts_trie); log::debug!("Input storage_tries: {:?}", &inputs.tries.storage_tries); log::debug!("Input contract_code: {:?}", &inputs.contract_code); - let mpt_prover_inputs = all_mpt_prover_inputs_reversed(&inputs.tries); + let mpt_prover_inputs = all_mpt_prover_inputs_reversed(&inputs.tries)?; let rlp_prover_inputs = all_rlp_prover_inputs_reversed(&inputs.signed_txns); let bignum_modmul_result_limbs = Vec::new(); - Self { + Ok(Self { inputs, registers: Default::default(), memory: MemoryState::new(kernel_code), @@ -73,23 +75,25 @@ impl GenerationState { rlp_prover_inputs, state_key_to_address: HashMap::new(), bignum_modmul_result_limbs, - } + }) } /// Updates `program_counter`, and potentially adds some extra handling if we're jumping to a /// special location. - pub fn jump_to(&mut self, dst: usize) { + pub fn jump_to(&mut self, dst: usize) -> Result<(), ProgramError> { self.registers.program_counter = dst; if dst == KERNEL.global_labels["observe_new_address"] { - let tip_u256 = stack_peek(self, 0).expect("Empty stack"); + let tip_u256 = stack_peek(self, 0)?; let tip_h256 = H256::from_uint(&tip_u256); let tip_h160 = H160::from(tip_h256); self.observe_address(tip_h160); } else if dst == KERNEL.global_labels["observe_new_contract"] { - let tip_u256 = stack_peek(self, 0).expect("Empty stack"); + let tip_u256 = stack_peek(self, 0)?; let tip_h256 = H256::from_uint(&tip_u256); - self.observe_contract(tip_h256); + self.observe_contract(tip_h256)?; } + + Ok(()) } /// Observe the given address, so that we will be able to recognize the associated state key. @@ -101,9 +105,9 @@ impl GenerationState { /// Observe the given code hash and store the associated code. /// When called, the code corresponding to `codehash` should be stored in the return data. - pub fn observe_contract(&mut self, codehash: H256) { + pub fn observe_contract(&mut self, codehash: H256) -> Result<(), ProgramError> { if self.inputs.contract_code.contains_key(&codehash) { - return; // Return early if the code hash has already been observed. + return Ok(()); // Return early if the code hash has already been observed. } let ctx = self.registers.context; @@ -112,7 +116,7 @@ impl GenerationState { Segment::ContextMetadata, ContextMetadata::ReturndataSize as usize, ); - let returndata_size = self.memory.get(returndata_size_addr).as_usize(); + let returndata_size = u256_to_usize(self.memory.get(returndata_size_addr))?; let code = self.memory.contexts[ctx].segments[Segment::Returndata as usize].content [..returndata_size] .iter() @@ -121,6 +125,8 @@ impl GenerationState { debug_assert_eq!(keccak(&code), codehash); self.inputs.contract_code.insert(codehash, code); + + Ok(()) } pub fn checkpoint(&self) -> GenerationStateCheckpoint { diff --git a/evm/src/generation/trie_extractor.rs b/evm/src/generation/trie_extractor.rs index a508a720..42c50c6d 100644 --- a/evm/src/generation/trie_extractor.rs +++ b/evm/src/generation/trie_extractor.rs @@ -7,6 +7,8 @@ use ethereum_types::{BigEndianHash, H256, U256, U512}; use crate::cpu::kernel::constants::trie_type::PartialTrieType; use crate::memory::segments::Segment; +use crate::util::u256_to_usize; +use crate::witness::errors::ProgramError; use crate::witness::memory::{MemoryAddress, MemoryState}; /// Account data as it's stored in the state trie, with a pointer to the storage trie. @@ -18,13 +20,13 @@ pub(crate) struct AccountTrieRecord { pub(crate) code_hash: H256, } -pub(crate) fn read_state_trie_value(slice: &[U256]) -> AccountTrieRecord { - AccountTrieRecord { +pub(crate) fn read_state_trie_value(slice: &[U256]) -> Result { + Ok(AccountTrieRecord { nonce: slice[0].low_u64(), balance: slice[1], - storage_ptr: slice[2].as_usize(), + storage_ptr: u256_to_usize(slice[2])?, code_hash: H256::from_uint(&slice[3]), - } + }) } pub(crate) fn read_storage_trie_value(slice: &[U256]) -> U256 { @@ -34,72 +36,76 @@ pub(crate) fn read_storage_trie_value(slice: &[U256]) -> U256 { pub(crate) fn read_trie( memory: &MemoryState, ptr: usize, - read_value: fn(&[U256]) -> V, -) -> HashMap { + read_value: fn(&[U256]) -> Result, +) -> Result, ProgramError> { let mut res = HashMap::new(); let empty_nibbles = Nibbles { count: 0, packed: U512::zero(), }; - read_trie_helper::(memory, ptr, read_value, empty_nibbles, &mut res); - res + read_trie_helper::(memory, ptr, read_value, empty_nibbles, &mut res)?; + Ok(res) } pub(crate) fn read_trie_helper( memory: &MemoryState, ptr: usize, - read_value: fn(&[U256]) -> V, + read_value: fn(&[U256]) -> Result, prefix: Nibbles, res: &mut HashMap, -) { +) -> Result<(), ProgramError> { let load = |offset| memory.get(MemoryAddress::new(0, Segment::TrieData, offset)); let load_slice_from = |init_offset| { &memory.contexts[0].segments[Segment::TrieData as usize].content[init_offset..] }; - let trie_type = PartialTrieType::all()[load(ptr).as_usize()]; + let trie_type = PartialTrieType::all()[u256_to_usize(load(ptr))?]; match trie_type { - PartialTrieType::Empty => {} - PartialTrieType::Hash => {} + PartialTrieType::Empty => Ok(()), + PartialTrieType::Hash => Ok(()), PartialTrieType::Branch => { let ptr_payload = ptr + 1; for i in 0u8..16 { - let child_ptr = load(ptr_payload + i as usize).as_usize(); - read_trie_helper::(memory, child_ptr, read_value, prefix.merge_nibble(i), res); + let child_ptr = u256_to_usize(load(ptr_payload + i as usize))?; + read_trie_helper::(memory, child_ptr, read_value, prefix.merge_nibble(i), res)?; } - let value_ptr = load(ptr_payload + 16).as_usize(); + let value_ptr = u256_to_usize(load(ptr_payload + 16))?; if value_ptr != 0 { - res.insert(prefix, read_value(load_slice_from(value_ptr))); + res.insert(prefix, read_value(load_slice_from(value_ptr))?); }; + + Ok(()) } PartialTrieType::Extension => { - let count = load(ptr + 1).as_usize(); + let count = u256_to_usize(load(ptr + 1))?; let packed = load(ptr + 2); let nibbles = Nibbles { count, packed: packed.into(), }; - let child_ptr = load(ptr + 3).as_usize(); + let child_ptr = u256_to_usize(load(ptr + 3))?; read_trie_helper::( memory, child_ptr, read_value, prefix.merge_nibbles(&nibbles), res, - ); + ) } PartialTrieType::Leaf => { - let count = load(ptr + 1).as_usize(); + let count = u256_to_usize(load(ptr + 1))?; let packed = load(ptr + 2); let nibbles = Nibbles { count, packed: packed.into(), }; - let value_ptr = load(ptr + 3).as_usize(); + let value_ptr = u256_to_usize(load(ptr + 3))?; res.insert( prefix.merge_nibbles(&nibbles), - read_value(load_slice_from(value_ptr)), + read_value(load_slice_from(value_ptr))?, ); + + Ok(()) } } } diff --git a/evm/src/recursive_verifier.rs b/evm/src/recursive_verifier.rs index 1457344c..113dd287 100644 --- a/evm/src/recursive_verifier.rs +++ b/evm/src/recursive_verifier.rs @@ -935,7 +935,7 @@ where witness, &public_values_target.extra_block_data, &public_values.extra_block_data, - ); + )?; Ok(()) } @@ -1072,26 +1072,21 @@ pub(crate) fn set_extra_public_values_target( witness: &mut W, ed_target: &ExtraBlockDataTarget, ed: &ExtraBlockData, -) where +) -> Result<(), ProgramError> +where F: RichField + Extendable, W: Witness, { witness.set_target( ed_target.txn_number_before, - F::from_canonical_usize(ed.txn_number_before.as_usize()), + u256_to_u32(ed.txn_number_before)?, ); witness.set_target( ed_target.txn_number_after, - F::from_canonical_usize(ed.txn_number_after.as_usize()), - ); - witness.set_target( - ed_target.gas_used_before, - F::from_canonical_usize(ed.gas_used_before.as_usize()), - ); - witness.set_target( - ed_target.gas_used_after, - F::from_canonical_usize(ed.gas_used_after.as_usize()), + u256_to_u32(ed.txn_number_after)?, ); + witness.set_target(ed_target.gas_used_before, u256_to_u32(ed.gas_used_before)?); + witness.set_target(ed_target.gas_used_after, u256_to_u32(ed.gas_used_after)?); let block_bloom_before = ed.block_bloom_before; let mut block_bloom_limbs = [F::ZERO; 64]; @@ -1108,4 +1103,6 @@ pub(crate) fn set_extra_public_values_target( } witness.set_target_arr(&ed_target.block_bloom_after, &block_bloom_limbs); + + Ok(()) } diff --git a/evm/src/util.rs b/evm/src/util.rs index a3f6d050..08233056 100644 --- a/evm/src/util.rs +++ b/evm/src/util.rs @@ -70,6 +70,11 @@ pub(crate) fn u256_to_u64(u256: U256) -> Result<(F, F), ProgramError> )) } +/// Safe alternative to `U256::as_usize()`, which errors in case of overflow instead of panicking. +pub(crate) fn u256_to_usize(u256: U256) -> Result { + u256.try_into().map_err(|_| ProgramError::IntegerTooLarge) +} + #[allow(unused)] // TODO: Remove? /// Returns the 32-bit little-endian limbs of a `U256`. pub(crate) fn u256_limbs(u256: U256) -> [F; 8] { @@ -171,6 +176,8 @@ pub(crate) fn u256_to_biguint(x: U256) -> BigUint { pub(crate) fn biguint_to_u256(x: BigUint) -> U256 { let bytes = x.to_bytes_le(); + // This could panic if `bytes.len() > 32` but this is only + // used here with `BigUint` constructed from `U256`. U256::from_little_endian(&bytes) } diff --git a/evm/src/witness/errors.rs b/evm/src/witness/errors.rs index 1ab99eae..81862460 100644 --- a/evm/src/witness/errors.rs +++ b/evm/src/witness/errors.rs @@ -6,6 +6,7 @@ pub enum ProgramError { OutOfGas, InvalidOpcode, StackUnderflow, + InvalidRlp, InvalidJumpDestination, InvalidJumpiDestination, StackOverflow, @@ -14,6 +15,8 @@ pub enum ProgramError { GasLimitError, InterpreterError, IntegerTooLarge, + ProverInputError(ProverInputError), + UnknownContractCode, } #[allow(clippy::enum_variant_names)] @@ -23,3 +26,13 @@ pub enum MemoryError { SegmentTooLarge { segment: U256 }, VirtTooLarge { virt: U256 }, } + +#[derive(Debug)] +pub enum ProverInputError { + OutOfMptData, + OutOfRlpData, + CodeHashNotFound, + InvalidMptInput, + InvalidInput, + InvalidFunction, +} diff --git a/evm/src/witness/memory.rs b/evm/src/witness/memory.rs index 62e6a2fe..3b62c945 100644 --- a/evm/src/witness/memory.rs +++ b/evm/src/witness/memory.rs @@ -58,6 +58,8 @@ impl MemoryAddress { if virt.bits() > 32 { return Err(MemoryError(VirtTooLarge { virt })); } + + // Calling `as_usize` here is safe as those have been checked above. Ok(Self { context: context.as_usize(), segment: segment.as_usize(), diff --git a/evm/src/witness/operation.rs b/evm/src/witness/operation.rs index 8349d56d..2abeaea4 100644 --- a/evm/src/witness/operation.rs +++ b/evm/src/witness/operation.rs @@ -15,6 +15,7 @@ use crate::cpu::stack_bounds::MAX_USER_STACK_SIZE; use crate::extension_tower::BN_BASE; use crate::generation::state::GenerationState; use crate::memory::segments::Segment; +use crate::util::u256_to_usize; use crate::witness::errors::MemoryError::{ContextTooLarge, SegmentTooLarge, VirtTooLarge}; use crate::witness::errors::ProgramError; use crate::witness::errors::ProgramError::MemoryError; @@ -127,7 +128,7 @@ pub(crate) fn generate_keccak_general( row.is_keccak_sponge = F::ONE; let [(context, log_in0), (segment, log_in1), (base_virt, log_in2), (len, log_in3)] = stack_pop_with_log_and_fill::<4, _>(state, &mut row)?; - let len = len.as_usize(); + let len = u256_to_usize(len)?; let base_address = MemoryAddress::new_u256s(context, segment, base_virt)?; let input = (0..len) @@ -162,7 +163,7 @@ pub(crate) fn generate_prover_input( ) -> Result<(), ProgramError> { let pc = state.registers.program_counter; let input_fn = &KERNEL.prover_inputs[&pc]; - let input = state.prover_input(input_fn); + let input = state.prover_input(input_fn)?; let write = stack_push_log_and_fill(state, &mut row, input)?; state.traces.push_memory(write); @@ -217,7 +218,7 @@ pub(crate) fn generate_jump( state.traces.push_memory(log_in0); state.traces.push_cpu(row); - state.jump_to(dst as usize); + state.jump_to(dst as usize)?; Ok(()) } @@ -241,7 +242,7 @@ pub(crate) fn generate_jumpi( let dst: u32 = dst .try_into() .map_err(|_| ProgramError::InvalidJumpiDestination)?; - state.jump_to(dst as usize); + state.jump_to(dst as usize)?; } else { row.general.jumps_mut().should_jump = F::ZERO; row.general.jumps_mut().cond_sum_pinv = F::ZERO; @@ -312,7 +313,7 @@ pub(crate) fn generate_set_context( let [(ctx, log_in)] = stack_pop_with_log_and_fill::<1, _>(state, &mut row)?; let sp_to_save = state.registers.stack_len.into(); let old_ctx = state.registers.context; - let new_ctx = ctx.as_usize(); + let new_ctx = u256_to_usize(ctx)?; let sp_field = ContextMetadata::StackSize as usize; let old_sp_addr = MemoryAddress::new(old_ctx, Segment::ContextMetadata, sp_field); @@ -347,7 +348,8 @@ pub(crate) fn generate_set_context( }; state.registers.context = new_ctx; - state.registers.stack_len = new_sp.as_usize(); + let new_sp = u256_to_usize(new_sp)?; + state.registers.stack_len = new_sp; state.traces.push_memory(log_in); state.traces.push_memory(log_write_old_sp); state.traces.push_memory(log_read_new_sp); @@ -362,6 +364,10 @@ pub(crate) fn generate_push( ) -> Result<(), ProgramError> { let code_context = state.registers.code_context(); let num_bytes = n as usize; + if num_bytes > 32 { + // The call to `U256::from_big_endian()` would panic. + return Err(ProgramError::IntegerTooLarge); + } let initial_offset = state.registers.program_counter + 1; // First read val without going through `mem_read_with_log` type methods, so we can pass it @@ -589,7 +595,7 @@ pub(crate) fn generate_syscall( ); let handler_addr = (handler_addr0 << 16) + (handler_addr1 << 8) + handler_addr2; - let new_program_counter = handler_addr.as_usize(); + let new_program_counter = u256_to_usize(handler_addr)?; let syscall_info = U256::from(state.registers.program_counter + 1) + (U256::from(u64::from(state.registers.is_kernel)) << 32) @@ -694,7 +700,11 @@ pub(crate) fn generate_mload_32bytes( ) -> Result<(), ProgramError> { let [(context, log_in0), (segment, log_in1), (base_virt, log_in2), (len, log_in3)] = stack_pop_with_log_and_fill::<4, _>(state, &mut row)?; - let len = len.as_usize(); + let len = u256_to_usize(len)?; + if len > 32 { + // The call to `U256::from_big_endian()` would panic. + return Err(ProgramError::IntegerTooLarge); + } let base_address = MemoryAddress::new_u256s(context, segment, base_virt)?; if usize::MAX - base_address.virt < len { @@ -762,7 +772,7 @@ pub(crate) fn generate_mstore_32bytes( ) -> Result<(), ProgramError> { let [(context, log_in0), (segment, log_in1), (base_virt, log_in2), (val, log_in3), (len, log_in4)] = stack_pop_with_log_and_fill::<5, _>(state, &mut row)?; - let len = len.as_usize(); + let len = u256_to_usize(len)?; let base_address = MemoryAddress::new_u256s(context, segment, base_virt)?; @@ -827,7 +837,7 @@ pub(crate) fn generate_exception( ); let handler_addr = (handler_addr0 << 16) + (handler_addr1 << 8) + handler_addr2; - let new_program_counter = handler_addr.as_usize(); + let new_program_counter = u256_to_usize(handler_addr)?; let exc_info = U256::from(state.registers.program_counter) + (U256::from(state.registers.gas_used) << 192); diff --git a/evm/src/witness/util.rs b/evm/src/witness/util.rs index 94488614..068a8e11 100644 --- a/evm/src/witness/util.rs +++ b/evm/src/witness/util.rs @@ -29,11 +29,14 @@ fn to_bits_le(n: u8) -> [F; 8] { } /// Peek at the stack item `i`th from the top. If `i=0` this gives the tip. -pub(crate) fn stack_peek(state: &GenerationState, i: usize) -> Option { +pub(crate) fn stack_peek( + state: &GenerationState, + i: usize, +) -> Result { if i >= state.registers.stack_len { - return None; + return Err(ProgramError::StackUnderflow); } - Some(state.memory.get(MemoryAddress::new( + Ok(state.memory.get(MemoryAddress::new( state.registers.context, Segment::Stack, state.registers.stack_len - 1 - i,