diff --git a/evm/Cargo.toml b/evm/Cargo.toml index 1e22ef33..c10ab104 100644 --- a/evm/Cargo.toml +++ b/evm/Cargo.toml @@ -11,6 +11,7 @@ anyhow = "1.0.40" env_logger = "0.9.0" ethereum-types = "0.13.1" hex = { version = "0.4.3", optional = true } +hex-literal = "0.3.4" itertools = "0.10.3" log = "0.4.14" once_cell = "1.13.0" @@ -24,7 +25,6 @@ keccak-rust = { git = "https://github.com/npwardberkeley/keccak-rust" } keccak-hash = "0.9.0" [dev-dependencies] -hex-literal = "0.3.4" hex = "0.4.3" [features] diff --git a/evm/src/cpu/columns.rs b/evm/src/cpu/columns.rs index 42b0a5bc..ae6872df 100644 --- a/evm/src/cpu/columns.rs +++ b/evm/src/cpu/columns.rs @@ -52,7 +52,7 @@ pub struct CpuColumnsView { pub is_shl: T, pub is_shr: T, pub is_sar: T, - pub is_sha3: T, + pub is_keccak256: T, pub is_address: T, pub is_balance: T, pub is_origin: T, diff --git a/evm/src/cpu/decode.rs b/evm/src/cpu/decode.rs index bd7ea5fe..233c01c4 100644 --- a/evm/src/cpu/decode.rs +++ b/evm/src/cpu/decode.rs @@ -45,7 +45,7 @@ const OPCODES: [(u64, usize, usize); 107] = [ (0x1c, 0, COL_MAP.is_shr), (0x1d, 0, COL_MAP.is_sar), (0x1e, 1, COL_MAP.is_invalid_1), // 0x1e-0x1f - (0x20, 0, COL_MAP.is_sha3), + (0x20, 0, COL_MAP.is_keccak256), (0x21, 0, COL_MAP.is_invalid_2), (0x22, 1, COL_MAP.is_invalid_3), // 0x22-0x23 (0x24, 2, COL_MAP.is_invalid_4), // 0x24-0x27 diff --git a/evm/src/cpu/kernel/aggregator.rs b/evm/src/cpu/kernel/aggregator.rs index c6c47387..ec42f5c4 100644 --- a/evm/src/cpu/kernel/aggregator.rs +++ b/evm/src/cpu/kernel/aggregator.rs @@ -3,6 +3,7 @@ use std::collections::HashMap; use ethereum_types::U256; +use hex_literal::hex; use itertools::Itertools; use once_cell::sync::Lazy; @@ -14,6 +15,12 @@ pub static KERNEL: Lazy = Lazy::new(combined_kernel); pub fn evm_constants() -> HashMap { let mut c = HashMap::new(); + c.insert( + "BN_BASE".into(), + U256::from_big_endian(&hex!( + "30644e72e131a029b85045b68181585d97816a916871ca8d3c208c16d87cfd47" + )), + ); for segment in Segment::all() { c.insert(segment.var_name().into(), (segment as u32).into()); } @@ -28,6 +35,7 @@ pub(crate) fn combined_kernel() -> Kernel { include_str!("asm/exp.asm"), include_str!("asm/curve_mul.asm"), include_str!("asm/curve_add.asm"), + include_str!("asm/memory.asm"), include_str!("asm/moddiv.asm"), include_str!("asm/secp256k1/curve_mul.asm"), include_str!("asm/secp256k1/curve_add.asm"), diff --git a/evm/src/cpu/kernel/asm/basic_macros.asm b/evm/src/cpu/kernel/asm/basic_macros.asm index 7bf001b4..20f8958c 100644 --- a/evm/src/cpu/kernel/asm/basic_macros.asm +++ b/evm/src/cpu/kernel/asm/basic_macros.asm @@ -26,6 +26,24 @@ %endrep %endmacro +%macro pop5 + %rep 5 + pop + %endrep +%endmacro + +%macro pop6 + %rep 6 + pop + %endrep +%endmacro + +%macro pop7 + %rep 7 + pop + %endrep +%endmacro + %macro add_const(c) // stack: input, ... PUSH $c @@ -64,6 +82,13 @@ // stack: input / c, ... %endmacro +%macro shl_const(c) + // stack: input, ... + PUSH $c + SHL + // stack: input << c, ... +%endmacro + %macro eq_const(c) // stack: input, ... PUSH $c diff --git a/evm/src/cpu/kernel/asm/curve_add.asm b/evm/src/cpu/kernel/asm/curve_add.asm index 4ac4e0e4..15f9df05 100644 --- a/evm/src/cpu/kernel/asm/curve_add.asm +++ b/evm/src/cpu/kernel/asm/curve_add.asm @@ -94,14 +94,8 @@ global ec_add_valid_points: ec_add_first_zero: JUMPDEST // stack: x0, y0, x1, y1, retdest - // Just return (x1,y1) - %pop2 - // stack: x1, y1, retdest - SWAP1 - // stack: y1, x1, retdest - SWAP2 - // stack: retdest, x1, y1 + %stack (x0, y0, x1, y1, retdest) -> (retdest, x1, y1) JUMP // BN254 elliptic curve addition. @@ -110,19 +104,8 @@ ec_add_snd_zero: JUMPDEST // stack: x0, y0, x1, y1, retdest - // Just return (x1,y1) - SWAP2 - // stack: x1, y0, x0, y1, retdest - POP - // stack: y0, x0, y1, retdest - SWAP2 - // stack: y1, x0, y0, retdest - POP - // stack: x0, y0, retdest - SWAP1 - // stack: y0, x0, retdest - SWAP2 - // stack: retdest, x0, y0 + // Just return (x0,y0) + %stack (x0, y0, x1, y1, retdest) -> (retdest, x0, y0) JUMP // BN254 elliptic curve addition. @@ -170,16 +153,7 @@ ec_add_valid_points_with_lambda: // stack: y2, x2, lambda, x0, y0, x1, y1, retdest // Return x2,y2 - SWAP5 - // stack: x1, x2, lambda, x0, y0, y2, y1, retdest - POP - // stack: x2, lambda, x0, y0, y2, y1, retdest - SWAP5 - // stack: y1, lambda, x0, y0, y2, x2, retdest - %pop4 - // stack: y2, x2, retdest - SWAP2 - // stack: retdest, x2, y2 + %stack (y2, x2, lambda, x0, y0, x1, y1, retdest) -> (retdest, x2, y2) JUMP // BN254 elliptic curve addition. @@ -291,21 +265,7 @@ global ec_double: // stack: y < N, x < N, x, y AND // stack: (y < N) & (x < N), x, y - SWAP2 - // stack: y, x, (y < N) & (x < N), x - SWAP1 - // stack: x, y, (y < N) & (x < N) - %bn_base - // stack: N, x, y, b - %bn_base - // stack: N, N, x, y, b - DUP3 - // stack: x, N, N, x, y, b - %bn_base - // stack: N, x, N, N, x, y, b - DUP2 - // stack: x, N, x, N, N, x, y, b - DUP1 + %stack (b, x, y) -> (x, x, @BN_BASE, x, @BN_BASE, @BN_BASE, x, y, b) // stack: x, x, N, x, N, N, x, y, b MULMOD // stack: x^2 % N, x, N, N, x, y, b diff --git a/evm/src/cpu/kernel/asm/ecrecover.asm b/evm/src/cpu/kernel/asm/ecrecover.asm index d0994054..538a86dc 100644 --- a/evm/src/cpu/kernel/asm/ecrecover.asm +++ b/evm/src/cpu/kernel/asm/ecrecover.asm @@ -107,33 +107,53 @@ ecrecover_with_first_point: // stack: u2, Y, X, retdest // Compute u2 * GENERATOR and chain the call to `ec_mul` with a call to `ec_add` to compute PUBKEY = (X,Y) + u2 * GENERATOR, - // and a call to `final_hashing` to get the final result `SHA3(PUBKEY)[-20:]`. - PUSH final_hashing - // stack: final_hashing, u2, Y, X, retdest + // and a call to `pubkey_to_addr` to get the final result `KECCAK256(PUBKEY)[-20:]`. + PUSH pubkey_to_addr + // stack: pubkey_to_addr, u2, Y, X, retdest SWAP3 - // stack: X, u2, Y, final_hashing, retdest + // stack: X, u2, Y, pubkey_to_addr, retdest PUSH ec_add_valid_points_secp - // stack: ec_add_valid_points_secp, X, u2, Y, final_hashing, retdest + // stack: ec_add_valid_points_secp, X, u2, Y, pubkey_to_addr, retdest SWAP1 - // stack: X, ec_add_valid_points_secp, u2, Y, final_hashing, retdest + // stack: X, ec_add_valid_points_secp, u2, Y, pubkey_to_addr, retdest PUSH 0x79be667ef9dcbbac55a06295ce870b07029bfcdb2dce28d959f2815b16f81798 // x-coordinate of generator - // stack: Gx, X, ec_add_valid_points_secp, u2, Y, final_hashing, retdest + // stack: Gx, X, ec_add_valid_points_secp, u2, Y, pubkey_to_addr, retdest SWAP1 - // stack: X, Gx, ec_add_valid_points_secp, u2, Y, final_hashing, retdest + // stack: X, Gx, ec_add_valid_points_secp, u2, Y, pubkey_to_addr, retdest PUSH 0x483ada7726a3c4655da4fbfc0e1108a8fd17b448a68554199c47d08ffb10d4b8 // y-coordinate of generator - // stack: Gy, X, Gx, ec_add_valid_points_secp, u2, Y, final_hashing, retdest + // stack: Gy, X, Gx, ec_add_valid_points_secp, u2, Y, pubkey_to_addr, retdest SWAP1 - // stack: X, Gy, Gx, ec_add_valid_points_secp, u2, Y, final_hashing, retdest + // stack: X, Gy, Gx, ec_add_valid_points_secp, u2, Y, pubkey_to_addr, retdest SWAP4 - // stack: u2, Gy, Gx, ec_add_valid_points_secp, X, Y, final_hashing, retdest + // stack: u2, Gy, Gx, ec_add_valid_points_secp, X, Y, pubkey_to_addr, retdest SWAP2 - // stack: Gx, Gy, u2, ec_add_valid_points_secp, X, Y, final_hashing, retdest + // stack: Gx, Gy, u2, ec_add_valid_points_secp, X, Y, pubkey_to_addr, retdest %jump(ec_mul_valid_point_secp) -// TODO -final_hashing: +// Take a public key (PKx, PKy) and return the associated address KECCAK256(PKx || PKy)[-20:]. +pubkey_to_addr: JUMPDEST - PUSH 0xdeadbeef + // stack: PKx, PKy, retdest + PUSH 0 + // stack: 0, PKx, PKy, retdest + MSTORE // TODO: switch to kernel memory (like `%mstore_current(@SEGMENT_KERNEL_GENERAL)`). + // stack: PKy, retdest + PUSH 0x20 + // stack: 0x20, PKy, retdest + MSTORE + // stack: retdest + PUSH 0x40 + // stack: 0x40, retdest + PUSH 0 + // stack: 0, 0x40, retdest + KECCAK256 + // stack: hash, retdest + PUSH 0xffffffffffffffffffffffffffffffffffffffff + // stack: 2^160-1, hash, retdest + AND + // stack: address, retdest + SWAP1 + // stack: retdest, address JUMP // Check if v, r, and s are in correct form. diff --git a/evm/src/cpu/kernel/asm/memory.asm b/evm/src/cpu/kernel/asm/memory.asm index e3af9954..26d0b855 100644 --- a/evm/src/cpu/kernel/asm/memory.asm +++ b/evm/src/cpu/kernel/asm/memory.asm @@ -1,4 +1,4 @@ -// Load a byte from the given segment of the current context's memory space. +// Load a value from the given segment of the current context's memory space. // Note that main memory values are one byte each, but in general memory values // can be 256 bits. This macro deals with a single address (unlike MLOAD), so // if it is used with main memory, it will load a single byte. @@ -12,7 +12,7 @@ // stack: value %endmacro -// Store a byte to the given segment of the current context's memory space. +// Store a value to the given segment of the current context's memory space. // Note that main memory values are one byte each, but in general memory values // can be 256 bits. This macro deals with a single address (unlike MSTORE), so // if it is used with main memory, it will store a single byte. @@ -25,3 +25,96 @@ MSTORE_GENERAL // stack: (empty) %endmacro + +// Load a single byte from kernel code. +%macro mload_kernel_code + // stack: offset + PUSH @SEGMENT_CODE + // stack: segment, offset + PUSH 0 // kernel has context 0 + // stack: context, segment, offset + MLOAD_GENERAL + // stack: value +%endmacro + +// Load a big-endian u32, consisting of 4 bytes (c_3, c_2, c_1, c_0), +// from kernel code. +%macro mload_kernel_code_u32 + // stack: offset + DUP1 + %mload_kernel_code + // stack: c_3, offset + %shl_const(8) + // stack: c_3 << 8, offset + DUP2 + %add_const(1) + %mload_kernel_code + OR + // stack: (c_3 << 8) | c_2, offset + %shl_const(8) + // stack: ((c_3 << 8) | c_2) << 8, offset + DUP2 + %add_const(2) + %mload_kernel_code + OR + // stack: (((c_3 << 8) | c_2) << 8) | c_1, offset + %shl_const(8) + // stack: ((((c_3 << 8) | c_2) << 8) | c_1) << 8, offset + SWAP1 + %add_const(3) + %mload_kernel_code + OR + // stack: (((((c_3 << 8) | c_2) << 8) | c_1) << 8) | c_0 +%endmacro + +// Copies `count` values from +// SRC = (src_ctx, src_segment, src_addr) +// to +// DST = (dst_ctx, dst_segment, dst_addr). +// These tuple definitions are used for brevity in the stack comments below. +global memcpy: + JUMPDEST + // stack: DST, SRC, count, retdest + DUP7 + // stack: count, DST, SRC, count, retdest + ISZERO + // stack: count == 0, DST, SRC, count, retdest + %jumpi(memcpy_finish) + // stack: DST, SRC, count, retdest + + // Copy the next value. + DUP6 + DUP6 + DUP6 + // stack: SRC, DST, SRC, count, retdest + MLOAD_GENERAL + // stack: value, DST, SRC, count, retdest + DUP4 + DUP4 + DUP4 + // stack: DST, value, DST, SRC, count, retdest + MSTORE_GENERAL + // stack: DST, SRC, count, retdest + + // Increment dst_addr. + SWAP2 + %add_const(1) + SWAP2 + // Increment src_addr. + SWAP5 + %add_const(1) + SWAP5 + // Decrement count. + SWAP6 + %sub_const(1) + SWAP6 + + // Continue the loop. + %jump(memcpy) + +memcpy_finish: + JUMPDEST + // stack: DST, SRC, count, retdest + %pop7 + // stack: retdest + JUMP diff --git a/evm/src/cpu/kernel/assembler.rs b/evm/src/cpu/kernel/assembler.rs index 8b7327dc..070ec291 100644 --- a/evm/src/cpu/kernel/assembler.rs +++ b/evm/src/cpu/kernel/assembler.rs @@ -5,8 +5,9 @@ use itertools::izip; use log::debug; use super::ast::PushTarget; -use crate::cpu::kernel::ast::Literal; +use crate::cpu::kernel::ast::{Literal, StackReplacement}; use crate::cpu::kernel::keccak_util::hash_kernel; +use crate::cpu::kernel::stack_manipulation::expand_stack_manipulation; use crate::cpu::kernel::{ ast::{File, Item}, opcodes::{get_opcode, get_push_opcode}, @@ -63,6 +64,7 @@ pub(crate) fn assemble(files: Vec, constants: HashMap) -> Ke let expanded_file = expand_macros(file.body, ¯os); let expanded_file = expand_repeats(expanded_file); let expanded_file = inline_constants(expanded_file, &constants); + let expanded_file = expand_stack_manipulation(expanded_file); local_labels.push(find_labels(&expanded_file, &mut offset, &mut global_labels)); expanded_files.push(expanded_file); } @@ -163,14 +165,31 @@ fn expand_repeats(body: Vec) -> Vec { } fn inline_constants(body: Vec, constants: &HashMap) -> Vec { + let resolve_const = |c| { + Literal::Decimal( + constants + .get(&c) + .unwrap_or_else(|| panic!("No such constant: {}", c)) + .to_string(), + ) + }; + body.into_iter() .map(|item| { if let Item::Push(PushTarget::Constant(c)) = item { - let value = constants - .get(&c) - .unwrap_or_else(|| panic!("No such constant: {}", c)); - let literal = Literal::Decimal(value.to_string()); - Item::Push(PushTarget::Literal(literal)) + Item::Push(PushTarget::Literal(resolve_const(c))) + } else if let Item::StackManipulation(from, to) = item { + let to = to + .into_iter() + .map(|replacement| { + if let StackReplacement::Constant(c) = replacement { + StackReplacement::Literal(resolve_const(c)) + } else { + replacement + } + }) + .collect(); + Item::StackManipulation(from, to) } else { item } @@ -187,8 +206,11 @@ fn find_labels( let mut local_labels = HashMap::::new(); for item in body { match item { - Item::MacroDef(_, _, _) | Item::MacroCall(_, _) | Item::Repeat(_, _) => { - panic!("Macros and repeats should have been expanded already") + Item::MacroDef(_, _, _) + | Item::MacroCall(_, _) + | Item::Repeat(_, _) + | Item::StackManipulation(_, _) => { + panic!("Item should have been expanded already: {:?}", item); } Item::GlobalLabelDeclaration(label) => { let old = global_labels.insert(label.clone(), *offset); @@ -215,8 +237,11 @@ fn assemble_file( // Assemble the file. for item in body { match item { - Item::MacroDef(_, _, _) | Item::MacroCall(_, _) | Item::Repeat(_, _) => { - panic!("Macros and repeats should have been expanded already") + Item::MacroDef(_, _, _) + | Item::MacroCall(_, _) + | Item::Repeat(_, _) + | Item::StackManipulation(_, _) => { + panic!("Item should have been expanded already: {:?}", item); } Item::GlobalLabelDeclaration(_) | Item::LocalLabelDeclaration(_) => { // Nothing to do; we processed labels in the prior phase. @@ -427,6 +452,24 @@ mod tests { assert_eq!(kernel.code, vec![add, add, add]); } + #[test] + fn stack_manipulation() { + let pop = get_opcode("POP"); + let swap1 = get_opcode("SWAP1"); + let swap2 = get_opcode("SWAP2"); + + let kernel = parse_and_assemble(&["%stack (a, b, c) -> (c, b, a)"]); + assert_eq!(kernel.code, vec![swap2]); + + let kernel = parse_and_assemble(&["%stack (a, b, c) -> (b)"]); + assert_eq!(kernel.code, vec![pop, swap1, pop]); + + let mut consts = HashMap::new(); + consts.insert("LIFE".into(), 42.into()); + parse_and_assemble_with_constants(&["%stack (a, b) -> (b, @LIFE)"], consts); + // We won't check the code since there are two equally efficient implementations. + } + fn parse_and_assemble(files: &[&str]) -> Kernel { parse_and_assemble_with_constants(files, HashMap::new()) } diff --git a/evm/src/cpu/kernel/ast.rs b/evm/src/cpu/kernel/ast.rs index 9bb315ff..92728104 100644 --- a/evm/src/cpu/kernel/ast.rs +++ b/evm/src/cpu/kernel/ast.rs @@ -14,6 +14,11 @@ pub(crate) enum Item { MacroCall(String, Vec), /// Repetition, like `%rep` in NASM. Repeat(Literal, Vec), + /// A directive to manipulate the stack according to a specified pattern. + /// The first list gives names to items on the top of the stack. + /// The second list specifies replacement items. + /// Example: `(a, b, c) -> (c, 5, 0x20, @SOME_CONST, a)`. + StackManipulation(Vec, Vec), /// Declares a global label. GlobalLabelDeclaration(String), /// Declares a label that is local to the current file. @@ -26,6 +31,14 @@ pub(crate) enum Item { Bytes(Vec), } +#[derive(Clone, Debug)] +pub(crate) enum StackReplacement { + NamedItem(String), + Literal(Literal), + MacroVar(String), + Constant(String), +} + /// The target of a `PUSH` operation. #[derive(Clone, Debug)] pub(crate) enum PushTarget { @@ -35,7 +48,7 @@ pub(crate) enum PushTarget { Constant(String), } -#[derive(Clone, Debug)] +#[derive(Clone, Debug, Eq, PartialEq, Hash)] pub(crate) enum Literal { Decimal(String), Hex(String), diff --git a/evm/src/cpu/kernel/evm_asm.pest b/evm/src/cpu/kernel/evm_asm.pest index d5a89d99..78938b64 100644 --- a/evm/src/cpu/kernel/evm_asm.pest +++ b/evm/src/cpu/kernel/evm_asm.pest @@ -15,12 +15,15 @@ literal = { literal_hex | literal_decimal } variable = ${ "$" ~ identifier } constant = ${ "@" ~ identifier } -item = { macro_def | macro_call | repeat | global_label | local_label | bytes_item | push_instruction | nullary_instruction } -macro_def = { ^"%macro" ~ identifier ~ macro_paramlist? ~ item* ~ ^"%endmacro" } -macro_call = ${ "%" ~ !(^"macro" | ^"endmacro" | ^"rep" | ^"endrep") ~ identifier ~ macro_arglist? } +item = { macro_def | macro_call | repeat | stack | global_label | local_label | bytes_item | push_instruction | nullary_instruction } +macro_def = { ^"%macro" ~ identifier ~ paramlist? ~ item* ~ ^"%endmacro" } +macro_call = ${ "%" ~ !(^"macro" | ^"endmacro" | ^"rep" | ^"endrep" | ^"stack") ~ identifier ~ macro_arglist? } repeat = { ^"%rep" ~ literal ~ item* ~ ^"%endrep" } -macro_paramlist = { "(" ~ identifier ~ ("," ~ identifier)* ~ ")" } +paramlist = { "(" ~ identifier ~ ("," ~ identifier)* ~ ")" } macro_arglist = !{ "(" ~ push_target ~ ("," ~ push_target)* ~ ")" } +stack = { ^"%stack" ~ paramlist ~ "->" ~ stack_replacements } +stack_replacements = { "(" ~ stack_replacement ~ ("," ~ stack_replacement)* ~ ")" } +stack_replacement = { literal | identifier | constant } global_label = { ^"GLOBAL " ~ identifier ~ ":" } local_label = { identifier ~ ":" } bytes_item = { ^"BYTES " ~ literal ~ ("," ~ literal)* } diff --git a/evm/src/cpu/kernel/interpreter.rs b/evm/src/cpu/kernel/interpreter.rs index aa5d1ac3..f2fb276a 100644 --- a/evm/src/cpu/kernel/interpreter.rs +++ b/evm/src/cpu/kernel/interpreter.rs @@ -1,5 +1,6 @@ use anyhow::{anyhow, bail}; -use ethereum_types::{U256, U512}; +use ethereum_types::{BigEndianHash, U256, U512}; +use keccak_hash::keccak; /// Halt interpreter execution whenever a jump to this offset is done. const HALT_OFFSET: usize = 0xdeadbeef; @@ -26,6 +27,11 @@ impl EvmMemory { U256::from_big_endian(&self.memory[offset..offset + 32]) } + fn mload8(&mut self, offset: usize) -> u8 { + self.expand(offset + 1); + self.memory[offset] + } + fn mstore(&mut self, offset: usize, value: U256) { self.expand(offset + 32); let value_be = { @@ -140,7 +146,7 @@ impl<'a> Interpreter<'a> { 0x1b => todo!(), // "SHL", 0x1c => todo!(), // "SHR", 0x1d => todo!(), // "SAR", - 0x20 => todo!(), // "KECCAK256", + 0x20 => self.run_keccak256(), // "KECCAK256", 0x30 => todo!(), // "ADDRESS", 0x31 => todo!(), // "BALANCE", 0x32 => todo!(), // "ORIGIN", @@ -320,6 +326,16 @@ impl<'a> Interpreter<'a> { self.push(!x); } + fn run_keccak256(&mut self) { + let offset = self.pop().as_usize(); + let size = self.pop().as_usize(); + let bytes = (offset..offset + size) + .map(|i| self.memory.mload8(i)) + .collect::>(); + let hash = keccak(bytes); + self.push(hash.into_uint()); + } + fn run_prover_input(&mut self) -> anyhow::Result<()> { let input = self .prover_inputs diff --git a/evm/src/cpu/kernel/mod.rs b/evm/src/cpu/kernel/mod.rs index 2dd70aa3..1f13a042 100644 --- a/evm/src/cpu/kernel/mod.rs +++ b/evm/src/cpu/kernel/mod.rs @@ -4,6 +4,7 @@ mod ast; pub(crate) mod keccak_util; mod opcodes; mod parser; +mod stack_manipulation; #[cfg(test)] mod interpreter; diff --git a/evm/src/cpu/kernel/parser.rs b/evm/src/cpu/kernel/parser.rs index b8ac3f40..aa84ee05 100644 --- a/evm/src/cpu/kernel/parser.rs +++ b/evm/src/cpu/kernel/parser.rs @@ -1,7 +1,7 @@ use pest::iterators::Pair; use pest::Parser; -use crate::cpu::kernel::ast::{File, Item, Literal, PushTarget}; +use crate::cpu::kernel::ast::{File, Item, Literal, PushTarget, StackReplacement}; /// Parses EVM assembly code. #[derive(pest_derive::Parser)] @@ -24,6 +24,7 @@ fn parse_item(item: Pair) -> Item { Rule::macro_def => parse_macro_def(item), Rule::macro_call => parse_macro_call(item), Rule::repeat => parse_repeat(item), + Rule::stack => parse_stack(item), Rule::global_label => { Item::GlobalLabelDeclaration(item.into_inner().next().unwrap().as_str().into()) } @@ -44,7 +45,7 @@ fn parse_macro_def(item: Pair) -> Item { let name = inner.next().unwrap().as_str().into(); // The parameter list is optional. - let params = if let Some(Rule::macro_paramlist) = inner.peek().map(|pair| pair.as_rule()) { + let params = if let Some(Rule::paramlist) = inner.peek().map(|pair| pair.as_rule()) { let params = inner.next().unwrap().into_inner(); params.map(|param| param.as_str().to_string()).collect() } else { @@ -78,6 +79,42 @@ fn parse_repeat(item: Pair) -> Item { Item::Repeat(count, inner.map(parse_item).collect()) } +fn parse_stack(item: Pair) -> Item { + assert_eq!(item.as_rule(), Rule::stack); + let mut inner = item.into_inner().peekable(); + + let params = inner.next().unwrap(); + assert_eq!(params.as_rule(), Rule::paramlist); + let replacements = inner.next().unwrap(); + assert_eq!(replacements.as_rule(), Rule::stack_replacements); + + let params = params + .into_inner() + .map(|param| param.as_str().to_string()) + .collect(); + let replacements = replacements + .into_inner() + .map(parse_stack_replacement) + .collect(); + Item::StackManipulation(params, replacements) +} + +fn parse_stack_replacement(target: Pair) -> StackReplacement { + assert_eq!(target.as_rule(), Rule::stack_replacement); + let inner = target.into_inner().next().unwrap(); + match inner.as_rule() { + Rule::identifier => StackReplacement::NamedItem(inner.as_str().into()), + Rule::literal => StackReplacement::Literal(parse_literal(inner)), + Rule::variable => { + StackReplacement::MacroVar(inner.into_inner().next().unwrap().as_str().into()) + } + Rule::constant => { + StackReplacement::Constant(inner.into_inner().next().unwrap().as_str().into()) + } + _ => panic!("Unexpected {:?}", inner.as_rule()), + } +} + fn parse_push_target(target: Pair) -> PushTarget { assert_eq!(target.as_rule(), Rule::push_target); let inner = target.into_inner().next().unwrap(); diff --git a/evm/src/cpu/kernel/stack_manipulation.rs b/evm/src/cpu/kernel/stack_manipulation.rs new file mode 100644 index 00000000..63d0566c --- /dev/null +++ b/evm/src/cpu/kernel/stack_manipulation.rs @@ -0,0 +1,262 @@ +use std::cmp::Ordering; +use std::collections::hash_map::Entry::{Occupied, Vacant}; +use std::collections::{BinaryHeap, HashMap}; + +use itertools::Itertools; + +use crate::cpu::columns::NUM_CPU_COLUMNS; +use crate::cpu::kernel::ast::{Item, Literal, PushTarget, StackReplacement}; +use crate::cpu::kernel::stack_manipulation::StackOp::Pop; +use crate::memory; + +pub(crate) fn expand_stack_manipulation(body: Vec) -> Vec { + let mut expanded = vec![]; + for item in body { + if let Item::StackManipulation(names, replacements) = item { + expanded.extend(expand(names, replacements)); + } else { + expanded.push(item); + } + } + expanded +} + +fn expand(names: Vec, replacements: Vec) -> Vec { + let mut src = names.into_iter().map(StackItem::NamedItem).collect_vec(); + + let unique_literals = replacements + .iter() + .filter_map(|item| match item { + StackReplacement::Literal(n) => Some(n.clone()), + _ => None, + }) + .unique() + .collect_vec(); + + let mut dst = replacements + .into_iter() + .map(|item| match item { + StackReplacement::NamedItem(name) => StackItem::NamedItem(name), + StackReplacement::Literal(n) => StackItem::Literal(n), + StackReplacement::MacroVar(_) | StackReplacement::Constant(_) => { + panic!("Should have been expanded already: {:?}", item) + } + }) + .collect_vec(); + + // %stack uses our convention where the top item is written on the left side. + // `shortest_path` expects the opposite, so we reverse src and dst. + src.reverse(); + dst.reverse(); + + let path = shortest_path(src, dst, unique_literals); + path.into_iter().map(StackOp::into_item).collect() +} + +/// Finds the lowest-cost sequence of `StackOp`s that transforms `src` to `dst`. +/// Uses a variant of Dijkstra's algorithm. +fn shortest_path( + src: Vec, + dst: Vec, + unique_literals: Vec, +) -> Vec { + // Nodes to visit, starting with the lowest-cost node. + let mut queue = BinaryHeap::new(); + queue.push(Node { + stack: src.clone(), + cost: 0, + }); + + // For each node, stores `(best_cost, Option<(parent, op)>)`. + let mut node_info = HashMap::, (u32, Option<(Vec, StackOp)>)>::new(); + node_info.insert(src.clone(), (0, None)); + + while let Some(node) = queue.pop() { + if node.stack == dst { + // The destination is now the lowest-cost node, so we must have found the best path. + let mut path = vec![]; + let mut stack = &node.stack; + // Rewind back to src, recording a list of operations which will be backwards. + while let Some((parent, op)) = &node_info[stack].1 { + stack = parent; + path.push(op.clone()); + } + assert_eq!(stack, &src); + path.reverse(); + return path; + } + + let (best_cost, _) = node_info[&node.stack]; + if best_cost < node.cost { + // Since we can't efficiently remove nodes from the heap, it can contain duplicates. + // In this case, we've already visited this stack state with a lower cost. + continue; + } + + for op in next_ops(&node.stack, &dst, &unique_literals) { + let neighbor = match op.apply_to(node.stack.clone()) { + Some(n) => n, + None => continue, + }; + + let cost = node.cost + op.cost(); + let entry = node_info.entry(neighbor.clone()); + if let Occupied(e) = &entry && e.get().0 <= cost { + // We already found a better or equal path. + continue; + } + + let neighbor_info = (cost, Some((node.stack.clone(), op.clone()))); + match entry { + Occupied(mut e) => { + e.insert(neighbor_info); + } + Vacant(e) => { + e.insert(neighbor_info); + } + } + + queue.push(Node { + stack: neighbor, + cost, + }); + } + } + + panic!("No path found from {:?} to {:?}", src, dst) +} + +/// A node in the priority queue used by Dijkstra's algorithm. +#[derive(Eq, PartialEq)] +struct Node { + stack: Vec, + cost: u32, +} + +impl PartialOrd for Node { + fn partial_cmp(&self, other: &Self) -> Option { + Some(self.cmp(other)) + } +} + +impl Ord for Node { + fn cmp(&self, other: &Self) -> Ordering { + // We want a min-heap rather than the default max-heap, so this is the opposite of the + // natural ordering of costs. + other.cost.cmp(&self.cost) + } +} + +/// Like `StackReplacement`, but without constants or macro vars, since those were expanded already. +#[derive(Eq, PartialEq, Hash, Clone, Debug)] +enum StackItem { + NamedItem(String), + Literal(Literal), +} + +#[derive(Clone, Debug)] +enum StackOp { + Push(Literal), + Pop, + Dup(u8), + Swap(u8), +} + +/// A set of candidate operations to consider for the next step in the path from `src` to `dst`. +fn next_ops(src: &[StackItem], dst: &[StackItem], unique_literals: &[Literal]) -> Vec { + if let Some(top) = src.last() && !dst.contains(top) { + // If the top of src doesn't appear in dst, don't bother with anything other than a POP. + return vec![StackOp::Pop] + } + + let mut ops = vec![StackOp::Pop]; + + ops.extend( + unique_literals + .iter() + // Only consider pushing this literal if we need more occurrences of it, otherwise swaps + // will be a better way to rearrange the existing occurrences as needed. + .filter(|lit| { + let item = StackItem::Literal((*lit).clone()); + let src_count = src.iter().filter(|x| **x == item).count(); + let dst_count = dst.iter().filter(|x| **x == item).count(); + src_count < dst_count + }) + .cloned() + .map(StackOp::Push), + ); + + let src_len = src.len() as u8; + + ops.extend( + (1..=src_len) + // Only consider duplicating this item if we need more occurrences of it, otherwise swaps + // will be a better way to rearrange the existing occurrences as needed. + .filter(|i| { + let item = &src[src.len() - *i as usize]; + let src_count = src.iter().filter(|x| *x == item).count(); + let dst_count = dst.iter().filter(|x| *x == item).count(); + src_count < dst_count + }) + .map(StackOp::Dup), + ); + + ops.extend((1..src_len).map(StackOp::Swap)); + + ops +} + +impl StackOp { + fn cost(&self) -> u32 { + let (cpu_rows, memory_rows) = match self { + StackOp::Push(n) => { + let bytes = n.to_trimmed_be_bytes().len() as u32; + // This is just a rough estimate; we can update it after implementing PUSH. + (bytes, bytes) + } + // A POP takes one cycle, and doesn't involve memory, it just decrements a pointer. + Pop => (1, 0), + // A DUP takes one cycle, and a read and a write. + StackOp::Dup(_) => (1, 2), + // A SWAP takes one cycle with four memory ops, to read both values then write to them. + StackOp::Swap(_) => (1, 4), + }; + + let cpu_cost = cpu_rows * NUM_CPU_COLUMNS as u32; + let memory_cost = memory_rows * memory::columns::NUM_COLUMNS as u32; + cpu_cost + memory_cost + } + + /// Returns an updated stack after this operation is performed, or `None` if this operation + /// would not be valid on the given stack. + fn apply_to(&self, mut stack: Vec) -> Option> { + let len = stack.len(); + match self { + StackOp::Push(n) => { + stack.push(StackItem::Literal(n.clone())); + } + Pop => { + stack.pop()?; + } + StackOp::Dup(n) => { + let idx = len.checked_sub(*n as usize)?; + stack.push(stack[idx].clone()); + } + StackOp::Swap(n) => { + let from = len.checked_sub(1)?; + let to = len.checked_sub(*n as usize + 1)?; + stack.swap(from, to); + } + } + Some(stack) + } + + fn into_item(self) -> Item { + match self { + StackOp::Push(n) => Item::Push(PushTarget::Literal(n)), + Pop => Item::StandardOp("POP".into()), + StackOp::Dup(n) => Item::StandardOp(format!("DUP{}", n)), + StackOp::Swap(n) => Item::StandardOp(format!("SWAP{}", n)), + } + } +} diff --git a/evm/src/cpu/kernel/tests/ecrecover.rs b/evm/src/cpu/kernel/tests/ecrecover.rs index d6eac6f7..78bdea3e 100644 --- a/evm/src/cpu/kernel/tests/ecrecover.rs +++ b/evm/src/cpu/kernel/tests/ecrecover.rs @@ -1,20 +1,13 @@ +use std::str::FromStr; + use anyhow::Result; use ethereum_types::U256; -use keccak_hash::keccak; use crate::cpu::kernel::aggregator::combined_kernel; use crate::cpu::kernel::assembler::Kernel; use crate::cpu::kernel::interpreter::run; use crate::cpu::kernel::tests::u256ify; -fn pubkey_to_addr(x: U256, y: U256) -> Vec { - let mut buf = [0; 64]; - x.to_big_endian(&mut buf[0..32]); - y.to_big_endian(&mut buf[32..64]); - let hash = keccak(buf); - hash.0[12..].to_vec() -} - fn test_valid_ecrecover( hash: &str, v: &str, @@ -24,10 +17,9 @@ fn test_valid_ecrecover( kernel: &Kernel, ) -> Result<()> { let ecrecover = kernel.global_labels["ecrecover"]; - let initial_stack = u256ify([s, r, v, hash])?; + let initial_stack = u256ify(["0xdeadbeef", s, r, v, hash])?; let stack = run(&kernel.code, ecrecover, initial_stack)?.stack; - let got = pubkey_to_addr(stack[1], stack[0]); - assert_eq!(got, hex::decode(&expected[2..]).unwrap()); + assert_eq!(stack[0], U256::from_str(expected).unwrap()); Ok(()) }