use std::cmp::min; use std::collections::HashMap; use std::mem::transmute; use std::str::FromStr; use anyhow::{bail, Error}; use ethereum_types::{BigEndianHash, H256, U256, U512}; use itertools::{enumerate, Itertools}; use num_bigint::BigUint; use plonky2::field::extension::Extendable; use plonky2::field::types::Field; use plonky2::hash::hash_types::RichField; use serde::{Deserialize, Serialize}; use crate::cpu::kernel::aggregator::KERNEL; use crate::cpu::kernel::constants::context_metadata::ContextMetadata; use crate::cpu::kernel::constants::global_metadata::GlobalMetadata; use crate::extension_tower::{FieldExt, Fp12, BLS381, BN254}; use crate::generation::prover_input::EvmField::{ Bls381Base, Bls381Scalar, Bn254Base, Bn254Scalar, Secp256k1Base, Secp256k1Scalar, }; use crate::generation::prover_input::FieldOp::{Inverse, Sqrt}; use crate::generation::simulate_cpu_between_labels_and_get_user_jumps; use crate::generation::state::GenerationState; use crate::memory::segments::Segment; use crate::memory::segments::Segment::BnPairing; use crate::util::{biguint_to_mem_vec, mem_vec_to_biguint, u256_to_u8, u256_to_usize}; use crate::witness::errors::ProverInputError::*; use crate::witness::errors::{ProgramError, ProverInputError}; use crate::witness::memory::MemoryAddress; use crate::witness::operation::CONTEXT_SCALING_FACTOR; use crate::witness::util::{current_context_peek, stack_peek}; /// Prover input function represented as a scoped function name. /// Example: `PROVER_INPUT(ff::bn254_base::inverse)` is represented as `ProverInputFn([ff, bn254_base, inverse])`. #[derive(PartialEq, Eq, Debug, Clone, Serialize, Deserialize)] pub struct ProverInputFn(Vec); impl From> for ProverInputFn { fn from(v: Vec) -> Self { Self(v) } } impl GenerationState { pub(crate) fn prover_input(&mut self, input_fn: &ProverInputFn) -> Result { match input_fn.0[0].as_str() { "no_txn" => self.no_txn(), "trie_ptr" => self.run_trie_ptr(input_fn), "ff" => self.run_ff(input_fn), "sf" => self.run_sf(input_fn), "ffe" => self.run_ffe(input_fn), "rlp" => self.run_rlp(), "current_hash" => self.run_current_hash(), "account_code" => self.run_account_code(), "bignum_modmul" => self.run_bignum_modmul(), "withdrawal" => self.run_withdrawal(), "num_bits" => self.run_num_bits(), "jumpdest_table" => self.run_jumpdest_table(input_fn), _ => Err(ProgramError::ProverInputError(InvalidFunction)), } } fn no_txn(&mut self) -> Result { Ok(U256::from(self.inputs.signed_txn.is_none() as u8)) } fn run_trie_ptr(&mut self, input_fn: &ProverInputFn) -> Result { let trie = input_fn.0[1].as_str(); match trie { "state" => Ok(U256::from(self.trie_root_ptrs.state_root_ptr)), "txn" => Ok(U256::from(self.trie_root_ptrs.txn_root_ptr)), "receipt" => Ok(U256::from(self.trie_root_ptrs.receipt_root_ptr)), _ => Err(ProgramError::ProverInputError(InvalidInput)), } } /// Finite field operations. fn run_ff(&self, input_fn: &ProverInputFn) -> Result { let field = EvmField::from_str(input_fn.0[1].as_str()) .map_err(|_| ProgramError::ProverInputError(InvalidFunction))?; let op = FieldOp::from_str(input_fn.0[2].as_str()) .map_err(|_| ProgramError::ProverInputError(InvalidFunction))?; let x = stack_peek(self, 0)?; field.op(op, x) } /// Special finite field operations. fn run_sf(&self, input_fn: &ProverInputFn) -> Result { let field = EvmField::from_str(input_fn.0[1].as_str()) .map_err(|_| ProgramError::ProverInputError(InvalidFunction))?; let inputs: [U256; 4] = match field { Bls381Base => (0..4) .map(|i| stack_peek(self, i)) .collect::, _>>()? .try_into() .unwrap(), _ => todo!(), }; let res = match input_fn.0[2].as_str() { "add_lo" => field.add_lo(inputs), "add_hi" => field.add_hi(inputs), "mul_lo" => field.mul_lo(inputs), "mul_hi" => field.mul_hi(inputs), "sub_lo" => field.sub_lo(inputs), "sub_hi" => field.sub_hi(inputs), _ => return Err(ProgramError::ProverInputError(InvalidFunction)), }; Ok(res) } /// Finite field extension operations. fn run_ffe(&self, input_fn: &ProverInputFn) -> Result { let field = EvmField::from_str(input_fn.0[1].as_str()) .map_err(|_| ProgramError::ProverInputError(InvalidFunction))?; let n = input_fn.0[2] .as_str() .split('_') .nth(1) .unwrap() .parse::() .unwrap(); let ptr = stack_peek(self, 11 - n).map(u256_to_usize)??; let f: [U256; 12] = match field { Bn254Base => std::array::from_fn(|i| current_context_peek(self, BnPairing, ptr + i)), _ => todo!(), }; Ok(field.field_extension_inverse(n, f)) } /// RLP data. fn run_rlp(&mut self) -> Result { self.rlp_prover_inputs .pop() .ok_or(ProgramError::ProverInputError(OutOfRlpData)) } fn run_current_hash(&mut self) -> Result { Ok(U256::from_big_endian(&self.inputs.block_hashes.cur_hash.0)) } /// Account code loading. /// Initializes the code segment of the given context with the code corresponding /// to the provided hash. /// Returns the length of the code. fn run_account_code(&mut self) -> Result { // stack: codehash, ctx, ... let codehash = stack_peek(self, 0)?; let context = stack_peek(self, 1)? >> CONTEXT_SCALING_FACTOR; let context = u256_to_usize(context)?; let mut address = MemoryAddress::new(context, Segment::Code, 0); let code = self .inputs .contract_code .get(&H256::from_uint(&codehash)) .ok_or(ProgramError::ProverInputError(CodeHashNotFound))?; for &byte in code { self.memory.set(address, byte.into()); address.increment(); } Ok(code.len().into()) } // Bignum modular multiplication. // On the first call, calculates the remainder and quotient of the given inputs. // These are stored, as limbs, in self.bignum_modmul_result_limbs. // Subsequent calls return one limb at a time, in order (first remainder and then quotient). fn run_bignum_modmul(&mut self) -> Result { if self.bignum_modmul_result_limbs.is_empty() { let len = stack_peek(self, 1).map(u256_to_usize)??; let a_start_loc = stack_peek(self, 2).map(u256_to_usize)??; let b_start_loc = stack_peek(self, 3).map(u256_to_usize)??; let m_start_loc = stack_peek(self, 4).map(u256_to_usize)??; let (remainder, quotient) = self.bignum_modmul(len, a_start_loc, b_start_loc, m_start_loc); self.bignum_modmul_result_limbs = remainder .iter() .cloned() .pad_using(len, |_| 0.into()) .chain(quotient.iter().cloned().pad_using(2 * len, |_| 0.into())) .collect(); self.bignum_modmul_result_limbs.reverse(); } self.bignum_modmul_result_limbs .pop() .ok_or(ProgramError::ProverInputError(InvalidInput)) } fn bignum_modmul( &mut self, len: usize, a_start_loc: usize, b_start_loc: usize, m_start_loc: usize, ) -> (Vec, Vec) { let n = self.memory.contexts.len(); let a = &self.memory.contexts[n - 1].segments[Segment::KernelGeneral.unscale()].content [a_start_loc..a_start_loc + len]; let b = &self.memory.contexts[n - 1].segments[Segment::KernelGeneral.unscale()].content [b_start_loc..b_start_loc + len]; let m = &self.memory.contexts[n - 1].segments[Segment::KernelGeneral.unscale()].content [m_start_loc..m_start_loc + len]; let a_biguint = mem_vec_to_biguint(a); let b_biguint = mem_vec_to_biguint(b); let m_biguint = mem_vec_to_biguint(m); let prod = a_biguint * b_biguint; let quo = if m_biguint == BigUint::default() { BigUint::default() } else { &prod / &m_biguint }; let rem = prod - m_biguint * &quo; (biguint_to_mem_vec(rem), biguint_to_mem_vec(quo)) } /// Withdrawal data. fn run_withdrawal(&mut self) -> Result { self.withdrawal_prover_inputs .pop() .ok_or(ProgramError::ProverInputError(OutOfWithdrawalData)) } /// Return the number of bits of the top of the stack or an error if /// the top of the stack is zero or empty. fn run_num_bits(&mut self) -> Result { let value = stack_peek(self, 0)?; if value.is_zero() { Err(ProgramError::ProverInputError(NumBitsError)) } else { let num_bits = value.bits(); Ok(num_bits.into()) } } /// Generate the either the next used jump address or the the proof for the last jump address. fn run_jumpdest_table(&mut self, input_fn: &ProverInputFn) -> Result { match input_fn.0[1].as_str() { "next_address" => self.run_next_jumpdest_table_address(), "next_proof" => self.run_next_jumpdest_table_proof(), _ => Err(ProgramError::ProverInputError(InvalidInput)), } } /// Returns the next used jump address. fn run_next_jumpdest_table_address(&mut self) -> Result { let context = self.registers.context; let code_len = u256_to_usize(self.get_code_len()?.into()); if self.jumpdest_proofs.is_none() { self.generate_jumpdest_proofs()?; } let Some(jumpdest_proofs) = &mut self.jumpdest_proofs else { return Err(ProgramError::ProverInputError( ProverInputError::InvalidJumpdestSimulation, )); }; if let Some(ctx_jumpdest_proofs) = jumpdest_proofs.get_mut(&self.registers.context) && let Some(next_jumpdest_address) = ctx_jumpdest_proofs.pop() { Ok((next_jumpdest_address + 1).into()) } else { self.jumpdest_proofs = None; Ok(U256::zero()) } } /// Returns the proof for the last jump address. fn run_next_jumpdest_table_proof(&mut self) -> Result { let Some(jumpdest_proofs) = &mut self.jumpdest_proofs else { return Err(ProgramError::ProverInputError( ProverInputError::InvalidJumpdestSimulation, )); }; if let Some(ctx_jumpdest_proofs) = jumpdest_proofs.get_mut(&self.registers.context) && let Some(next_jumpdest_proof) = ctx_jumpdest_proofs.pop() { Ok(next_jumpdest_proof.into()) } else { Err(ProgramError::ProverInputError( ProverInputError::InvalidJumpdestSimulation, )) } } } impl GenerationState { /// Simulate the user's code and store all the jump addresses with their respective contexts. fn generate_jumpdest_proofs(&mut self) -> Result<(), ProgramError> { let checkpoint = self.checkpoint(); let memory = self.memory.clone(); let code = self.get_current_code()?; // We need to set the simulated jumpdest bits to one as otherwise // the simulation will fail. // Simulate the user's code and (unnecessarily) part of the kernel code, skipping the validate table call let Some(jumpdest_table) = simulate_cpu_between_labels_and_get_user_jumps( "jumpdest_analysis_end", "terminate_common", self, ) else { self.jumpdest_proofs = Some(HashMap::new()); return Ok(()); }; // Return to the state before starting the simulation self.rollback(checkpoint); self.memory = memory; // Find proofs for all contexts self.set_proofs_and_jumpdests(jumpdest_table); Ok(()) } /// Given a HashMap containing the contexts and the jumpdest addresses, compute their respective proofs, /// by calling `get_proofs_and_jumpdests` pub(crate) fn set_proofs_and_jumpdests( &mut self, jumpdest_table: HashMap>, ) { self.jumpdest_proofs = Some(HashMap::from_iter(jumpdest_table.into_iter().map( |(ctx, jumpdest_table)| { let code = self.get_code(ctx).unwrap(); if let Some(&largest_address) = jumpdest_table.last() { let proofs = get_proofs_and_jumpdests(&code, largest_address, jumpdest_table); (ctx, proofs) } else { (ctx, vec![]) } }, ))); } fn get_current_code(&self) -> Result, ProgramError> { self.get_code(self.registers.context) } fn get_code(&self, context: usize) -> Result, ProgramError> { let code_len = self.get_code_len()?; let code = (0..code_len) .map(|i| { u256_to_u8(self.memory.get(MemoryAddress::new( self.registers.context, Segment::Code, i, ))) }) .collect::, _>>()?; Ok(code) } fn get_code_len(&self) -> Result { let code_len = u256_to_usize(self.memory.get(MemoryAddress::new( self.registers.context, Segment::ContextMetadata, ContextMetadata::CodeSize.unscale(), )))?; Ok(code_len) } fn set_jumpdest_bits(&mut self, code: &[u8]) { const JUMPDEST_OPCODE: u8 = 0x5b; for (pos, opcode) in CodeIterator::new(code) { if opcode == JUMPDEST_OPCODE { self.memory.set( MemoryAddress::new(self.registers.context, Segment::JumpdestBits, pos), U256::one(), ); } } } } /// For all address in `jumpdest_table`, each bounded by `largest_address`, /// this function searches for a proof. A proof is the closest address /// for which none of the previous 32 bytes in the code (including opcodes /// and pushed bytes are PUSHXX and the address is in its range. It returns /// a vector of even size containing proofs followed by their addresses. fn get_proofs_and_jumpdests( code: &[u8], largest_address: usize, jumpdest_table: std::collections::BTreeSet, ) -> Vec { const PUSH1_OPCODE: u8 = 0x60; const PUSH32_OPCODE: u8 = 0x7f; let (proofs, _) = CodeIterator::until(code, largest_address + 1).fold( (vec![], 0), |(mut proofs, acc), (pos, opcode)| { let has_prefix = if let Some(prefix_start) = pos.checked_sub(32) { code[prefix_start..pos] .iter() .enumerate() .fold(true, |acc, (prefix_pos, &byte)| { let cond1 = byte > PUSH32_OPCODE; let cond2 = (prefix_start + prefix_pos) as i32 + (byte as i32 - PUSH1_OPCODE as i32) + 1 < pos as i32; acc && (cond1 || cond2) }) } else { false }; let acc = if has_prefix { pos - 32 } else { acc }; if jumpdest_table.contains(&pos) { // Push the proof proofs.push(acc); // Push the address proofs.push(pos); } (proofs, acc) }, ); proofs } /// An iterator over the EVM code contained in `code`, which skips the bytes /// that are the arguments of a PUSHXX opcode. struct CodeIterator<'a> { code: &'a [u8], pos: usize, end: usize, } impl<'a> CodeIterator<'a> { fn new(code: &'a [u8]) -> Self { CodeIterator { end: code.len(), code, pos: 0, } } fn until(code: &'a [u8], end: usize) -> Self { CodeIterator { end: std::cmp::min(code.len(), end), code, pos: 0, } } } impl<'a> Iterator for CodeIterator<'a> { type Item = (usize, u8); fn next(&mut self) -> Option { const PUSH1_OPCODE: u8 = 0x60; const PUSH32_OPCODE: u8 = 0x7f; let CodeIterator { code, pos, end } = self; if *pos >= *end { return None; } let opcode = code[*pos]; let old_pos = *pos; *pos += if (PUSH1_OPCODE..=PUSH32_OPCODE).contains(&opcode) { (opcode - PUSH1_OPCODE + 2).into() } else { 1 }; Some((old_pos, opcode)) } } enum EvmField { Bls381Base, Bls381Scalar, Bn254Base, Bn254Scalar, Secp256k1Base, Secp256k1Scalar, } enum FieldOp { Inverse, Sqrt, } impl FromStr for EvmField { type Err = Error; fn from_str(s: &str) -> Result { Ok(match s { "bls381_base" => Bls381Base, "bls381_scalar" => Bls381Scalar, "bn254_base" => Bn254Base, "bn254_scalar" => Bn254Scalar, "secp256k1_base" => Secp256k1Base, "secp256k1_scalar" => Secp256k1Scalar, _ => bail!("Unrecognized field."), }) } } impl FromStr for FieldOp { type Err = Error; fn from_str(s: &str) -> Result { Ok(match s { "inverse" => Inverse, "sqrt" => Sqrt, _ => bail!("Unrecognized field operation."), }) } } impl EvmField { fn order(&self) -> U256 { match self { EvmField::Bls381Base => todo!(), EvmField::Bls381Scalar => todo!(), EvmField::Bn254Base => { U256::from_str("0x30644e72e131a029b85045b68181585d97816a916871ca8d3c208c16d87cfd47") .unwrap() } EvmField::Bn254Scalar => todo!(), EvmField::Secp256k1Base => { U256::from_str("0xfffffffffffffffffffffffffffffffffffffffffffffffffffffffefffffc2f") .unwrap() } EvmField::Secp256k1Scalar => { U256::from_str("0xfffffffffffffffffffffffffffffffebaaedce6af48a03bbfd25e8cd0364141") .unwrap() } } } fn op(&self, op: FieldOp, x: U256) -> Result { match op { FieldOp::Inverse => self.inverse(x), FieldOp::Sqrt => self.sqrt(x), } } fn inverse(&self, x: U256) -> Result { let n = self.order(); if x >= n { return Err(ProgramError::ProverInputError(InvalidInput)); }; modexp(x, n - 2, n) } fn sqrt(&self, x: U256) -> Result { let n = self.order(); if x >= n { return Err(ProgramError::ProverInputError(InvalidInput)); }; let (q, r) = (n + 1).div_mod(4.into()); if !r.is_zero() { return Err(ProgramError::ProverInputError(InvalidInput)); }; // Only naive sqrt implementation for now. If needed implement Tonelli-Shanks modexp(x, q, n) } fn add_lo(&self, inputs: [U256; 4]) -> U256 { let [y1, x0, x1, y0] = inputs; let x = U512::from(x0) + (U512::from(x1) << 256); let y = U512::from(y0) + (U512::from(y1) << 256); let z = BLS381 { val: x } + BLS381 { val: y }; z.lo() } fn add_hi(&self, inputs: [U256; 4]) -> U256 { let [x0, x1, y0, y1] = inputs; let x = U512::from(x0) + (U512::from(x1) << 256); let y = U512::from(y0) + (U512::from(y1) << 256); let z = BLS381 { val: x } + BLS381 { val: y }; z.hi() } fn mul_lo(&self, inputs: [U256; 4]) -> U256 { let [y1, x0, x1, y0] = inputs; let x = U512::from(x0) + (U512::from(x1) << 256); let y = U512::from(y0) + (U512::from(y1) << 256); let z = BLS381 { val: x } * BLS381 { val: y }; z.lo() } fn mul_hi(&self, inputs: [U256; 4]) -> U256 { let [x0, x1, y0, y1] = inputs; let x = U512::from(x0) + (U512::from(x1) << 256); let y = U512::from(y0) + (U512::from(y1) << 256); let z = BLS381 { val: x } * BLS381 { val: y }; z.hi() } fn sub_lo(&self, inputs: [U256; 4]) -> U256 { let [y1, x0, x1, y0] = inputs; let x = U512::from(x0) + (U512::from(x1) << 256); let y = U512::from(y0) + (U512::from(y1) << 256); let z = BLS381 { val: x } - BLS381 { val: y }; z.lo() } fn sub_hi(&self, inputs: [U256; 4]) -> U256 { let [x0, x1, y0, y1] = inputs; let x = U512::from(x0) + (U512::from(x1) << 256); let y = U512::from(y0) + (U512::from(y1) << 256); let z = BLS381 { val: x } - BLS381 { val: y }; z.hi() } fn field_extension_inverse(&self, n: usize, f: [U256; 12]) -> U256 { let f: Fp12 = unsafe { transmute(f) }; let f_inv: [U256; 12] = unsafe { transmute(f.inv()) }; f_inv[n] } } fn modexp(x: U256, e: U256, n: U256) -> Result { let mut current = x; let mut product = U256::one(); for j in 0..256 { if e.bit(j) { product = U256::try_from(product.full_mul(current) % n) .map_err(|_| ProgramError::ProverInputError(InvalidInput))?; } current = U256::try_from(current.full_mul(current) % n) .map_err(|_| ProgramError::ProverInputError(InvalidInput))?; } Ok(product) }