diff --git a/evm/src/cpu/kernel/aggregator.rs b/evm/src/cpu/kernel/aggregator.rs index 8d45a9a2..7f34f90b 100644 --- a/evm/src/cpu/kernel/aggregator.rs +++ b/evm/src/cpu/kernel/aggregator.rs @@ -39,7 +39,7 @@ pub(crate) fn combined_kernel() -> Kernel { ]; let parsed_files = files.iter().map(|f| parse(f)).collect_vec(); - assemble(parsed_files, evm_constants()) + assemble(parsed_files, evm_constants(), true) } #[cfg(test)] diff --git a/evm/src/cpu/kernel/asm/util/basic_macros.asm b/evm/src/cpu/kernel/asm/util/basic_macros.asm index e266b2cb..4dd93d14 100644 --- a/evm/src/cpu/kernel/asm/util/basic_macros.asm +++ b/evm/src/cpu/kernel/asm/util/basic_macros.asm @@ -120,7 +120,7 @@ // stack: input, ... PUSH $c // stack: c, input, ... - GE // Check it backwards: (input <= c) == (c >= input) + LT ISZERO // Check it backwards: (input <= c) == !(c < input) // stack: input <= c, ... %endmacro @@ -136,7 +136,7 @@ // stack: input, ... PUSH $c // stack: c, input, ... - LE // Check it backwards: (input >= c) == (c <= input) + GT ISZERO // Check it backwards: (input >= c) == !(c > input) // stack: input >= c, ... %endmacro diff --git a/evm/src/cpu/kernel/assembler.rs b/evm/src/cpu/kernel/assembler.rs index 6e98b22c..636251a3 100644 --- a/evm/src/cpu/kernel/assembler.rs +++ b/evm/src/cpu/kernel/assembler.rs @@ -7,6 +7,7 @@ use log::debug; use super::ast::PushTarget; use crate::cpu::kernel::ast::StackReplacement; use crate::cpu::kernel::keccak_util::hash_kernel; +use crate::cpu::kernel::optimizer::optimize_asm; use crate::cpu::kernel::prover_input::ProverInputFn; use crate::cpu::kernel::stack_manipulation::expand_stack_manipulation; use crate::cpu::kernel::utils::u256_to_trimmed_be_bytes; @@ -64,7 +65,11 @@ impl Macro { } } -pub(crate) fn assemble(files: Vec, constants: HashMap) -> Kernel { +pub(crate) fn assemble( + files: Vec, + constants: HashMap, + optimize: bool, +) -> Kernel { let macros = find_macros(&files); let mut global_labels = HashMap::new(); let mut prover_inputs = HashMap::new(); @@ -75,7 +80,10 @@ pub(crate) fn assemble(files: Vec, constants: HashMap) -> Ke let expanded_file = expand_macros(file.body, ¯os); let expanded_file = expand_repeats(expanded_file); let expanded_file = inline_constants(expanded_file, &constants); - let expanded_file = expand_stack_manipulation(expanded_file); + let mut expanded_file = expand_stack_manipulation(expanded_file); + if optimize { + optimize_asm(&mut expanded_file); + } local_labels.push(find_labels( &expanded_file, &mut offset, @@ -381,7 +389,7 @@ mod tests { let expected_kernel = Kernel::new(expected_code, expected_global_labels, HashMap::new()); let program = vec![file_1, file_2]; - assert_eq!(assemble(program, HashMap::new()), expected_kernel); + assert_eq!(assemble(program, HashMap::new(), false), expected_kernel); } #[test] @@ -399,7 +407,7 @@ mod tests { Item::StandardOp("JUMPDEST".to_string()), ], }; - assemble(vec![file_1, file_2], HashMap::new()); + assemble(vec![file_1, file_2], HashMap::new(), false); } #[test] @@ -413,7 +421,7 @@ mod tests { Item::StandardOp("ADD".to_string()), ], }; - assemble(vec![file], HashMap::new()); + assemble(vec![file], HashMap::new(), false); } #[test] @@ -421,7 +429,7 @@ mod tests { let file = File { body: vec![Item::Bytes(vec![0x12, 42]), Item::Bytes(vec![0xFE, 255])], }; - let code = assemble(vec![file], HashMap::new()).code; + let code = assemble(vec![file], HashMap::new(), false).code; assert_eq!(code, vec![0x12, 42, 0xfe, 255]); } @@ -438,10 +446,11 @@ mod tests { #[test] fn macro_with_vars() { - let kernel = parse_and_assemble(&[ + let files = &[ "%macro add(x, y) PUSH $x PUSH $y ADD %endmacro", "%add(2, 3)", - ]); + ]; + let kernel = parse_and_assemble_ext(files, HashMap::new(), false); let push1 = get_push_opcode(1); let add = get_opcode("ADD"); assert_eq!(kernel.code, vec![push1, 2, push1, 3, add]); @@ -479,7 +488,7 @@ mod tests { let mut constants = HashMap::new(); constants.insert("DEAD_BEEF".into(), 0xDEADBEEFu64.into()); - let kernel = parse_and_assemble_with_constants(code, constants); + let kernel = parse_and_assemble_ext(code, constants, true); let push4 = get_push_opcode(4); assert_eq!(kernel.code, vec![push4, 0xDE, 0xAD, 0xBE, 0xEF]); } @@ -510,7 +519,7 @@ mod tests { let mut consts = HashMap::new(); consts.insert("LIFE".into(), 42.into()); - parse_and_assemble_with_constants(&["%stack (a, b) -> (b, @LIFE)"], consts); + parse_and_assemble_ext(&["%stack (a, b) -> (b, @LIFE)"], consts, true); // We won't check the code since there are two equally efficient implementations. let kernel = parse_and_assemble(&["start: %stack (a, b) -> (start)"]); @@ -522,14 +531,15 @@ mod tests { } fn parse_and_assemble(files: &[&str]) -> Kernel { - parse_and_assemble_with_constants(files, HashMap::new()) + parse_and_assemble_ext(files, HashMap::new(), true) } - fn parse_and_assemble_with_constants( + fn parse_and_assemble_ext( files: &[&str], constants: HashMap, + optimize: bool, ) -> Kernel { let parsed_files = files.iter().map(|f| parse(f)).collect_vec(); - assemble(parsed_files, constants) + assemble(parsed_files, constants, optimize) } } diff --git a/evm/src/cpu/kernel/ast.rs b/evm/src/cpu/kernel/ast.rs index bc2a3ec2..a0de748a 100644 --- a/evm/src/cpu/kernel/ast.rs +++ b/evm/src/cpu/kernel/ast.rs @@ -7,7 +7,7 @@ pub(crate) struct File { pub(crate) body: Vec, } -#[derive(Clone, Debug)] +#[derive(Eq, PartialEq, Clone, Debug)] pub(crate) enum Item { /// Defines a new macro: name, params, body. MacroDef(String, Vec, Vec), @@ -34,7 +34,7 @@ pub(crate) enum Item { Bytes(Vec), } -#[derive(Clone, Debug)] +#[derive(Eq, PartialEq, Clone, Debug)] pub(crate) enum StackReplacement { /// Can be either a named item or a label. Identifier(String), diff --git a/evm/src/cpu/kernel/cost_estimator.rs b/evm/src/cpu/kernel/cost_estimator.rs new file mode 100644 index 00000000..3dfcf63a --- /dev/null +++ b/evm/src/cpu/kernel/cost_estimator.rs @@ -0,0 +1,37 @@ +use crate::cpu::kernel::assembler::BYTES_PER_OFFSET; +use crate::cpu::kernel::ast::Item; +use crate::cpu::kernel::ast::Item::*; +use crate::cpu::kernel::ast::PushTarget::*; +use crate::cpu::kernel::utils::u256_to_trimmed_be_bytes; + +pub(crate) fn is_code_improved(before: &[Item], after: &[Item]) -> bool { + cost_estimate(after) < cost_estimate(before) +} + +fn cost_estimate(code: &[Item]) -> u32 { + code.iter().map(cost_estimate_item).sum() +} + +fn cost_estimate_item(item: &Item) -> u32 { + match item { + MacroDef(_, _, _) => 0, + GlobalLabelDeclaration(_) => 0, + LocalLabelDeclaration(_) => 0, + Push(Literal(n)) => cost_estimate_push(u256_to_trimmed_be_bytes(n).len()), + Push(Label(_)) => cost_estimate_push(BYTES_PER_OFFSET as usize), + ProverInput(_) => 1, + StandardOp(op) => cost_estimate_standard_op(op.as_str()), + _ => panic!("Unexpected item: {:?}", item), + } +} + +fn cost_estimate_standard_op(_op: &str) -> u32 { + // For now we just treat any standard operation as having the same cost. This is pretty naive, + // but should work fine with our current set of simple optimization rules. + 1 +} + +fn cost_estimate_push(num_bytes: usize) -> u32 { + // TODO: Once PUSH is actually implemented, check if this needs to be revised. + num_bytes as u32 +} diff --git a/evm/src/cpu/kernel/mod.rs b/evm/src/cpu/kernel/mod.rs index 641ee529..4879ad76 100644 --- a/evm/src/cpu/kernel/mod.rs +++ b/evm/src/cpu/kernel/mod.rs @@ -3,9 +3,11 @@ pub mod assembler; mod ast; mod constants; mod context_metadata; +mod cost_estimator; mod global_metadata; pub(crate) mod keccak_util; mod opcodes; +mod optimizer; mod parser; pub mod prover_input; mod stack_manipulation; @@ -26,6 +28,6 @@ use crate::cpu::kernel::constants::evm_constants; /// This is for debugging the kernel only. pub fn assemble_to_bytes(files: &[String]) -> Vec { let parsed_files: Vec<_> = files.iter().map(|f| parse(f)).collect(); - let kernel = assemble(parsed_files, evm_constants()); + let kernel = assemble(parsed_files, evm_constants(), true); kernel.code } diff --git a/evm/src/cpu/kernel/optimizer.rs b/evm/src/cpu/kernel/optimizer.rs new file mode 100644 index 00000000..2a1db6d3 --- /dev/null +++ b/evm/src/cpu/kernel/optimizer.rs @@ -0,0 +1,260 @@ +use ethereum_types::U256; +use Item::{Push, StandardOp}; +use PushTarget::Literal; + +use crate::cpu::kernel::ast::Item::{GlobalLabelDeclaration, LocalLabelDeclaration}; +use crate::cpu::kernel::ast::PushTarget::Label; +use crate::cpu::kernel::ast::{Item, PushTarget}; +use crate::cpu::kernel::cost_estimator::is_code_improved; +use crate::cpu::kernel::utils::{replace_windows, u256_from_bool}; + +pub(crate) fn optimize_asm(code: &mut Vec) { + // Run the optimizer until nothing changes. + loop { + let old_code = code.clone(); + optimize_asm_once(code); + if code == &old_code { + break; + } + } +} + +/// A single optimization pass. +fn optimize_asm_once(code: &mut Vec) { + constant_propagation(code); + no_op_jumps(code); + remove_swapped_pushes(code); + remove_swaps_commutative(code); + remove_ignored_values(code); +} + +/// Constant propagation. +fn constant_propagation(code: &mut Vec) { + // Constant propagation for unary ops: `[PUSH x, UNARYOP] -> [PUSH UNARYOP(x)]` + replace_windows_if_better(code, |window| { + if let [Push(Literal(x)), StandardOp(op)] = window { + match op.as_str() { + "ISZERO" => Some(vec![Push(Literal(u256_from_bool(x.is_zero())))]), + "NOT" => Some(vec![Push(Literal(!x))]), + _ => None, + } + } else { + None + } + }); + + // Constant propagation for binary ops: `[PUSH y, PUSH x, BINOP] -> [PUSH BINOP(x, y)]` + replace_windows_if_better(code, |window| { + if let [Push(Literal(y)), Push(Literal(x)), StandardOp(op)] = window { + match op.as_str() { + "ADD" => Some(x.overflowing_add(y).0), + "SUB" => Some(x.overflowing_sub(y).0), + "MUL" => Some(x.overflowing_mul(y).0), + "DIV" => Some(x.checked_div(y).unwrap_or(U256::zero())), + "MOD" => Some(x.checked_rem(y).unwrap_or(U256::zero())), + "EXP" => Some(x.overflowing_pow(y).0), + "SHL" => Some(x << y), + "SHR" => Some(x >> y), + "AND" => Some(x & y), + "OR" => Some(x | y), + "XOR" => Some(x ^ y), + "LT" => Some(u256_from_bool(x < y)), + "GT" => Some(u256_from_bool(x > y)), + "EQ" => Some(u256_from_bool(x == y)), + "BYTE" => Some(if x < 32.into() { + y.byte(x.as_usize()).into() + } else { + U256::zero() + }), + _ => None, + } + .map(|res| vec![Push(Literal(res))]) + } else { + None + } + }); +} + +/// Remove no-op jumps: `[PUSH label, JUMP, label:] -> [label:]`. +fn no_op_jumps(code: &mut Vec) { + replace_windows(code, |window| { + if let [Push(Label(l)), StandardOp(jump), decl] = window + && &jump == "JUMP" + && (decl == LocalLabelDeclaration(l.clone()) || decl == GlobalLabelDeclaration(l.clone())) + { + Some(vec![LocalLabelDeclaration(l)]) + } else { + None + } + }); +} + +/// Remove swaps: `[PUSH x, PUSH y, SWAP1] -> [PUSH y, PUSH x]`. +// Could be generalized to recognize more than two pushes. +fn remove_swapped_pushes(code: &mut Vec) { + replace_windows(code, |window| { + if let [Push(x), Push(y), StandardOp(swap1)] = window + && &swap1 == "SWAP1" { + Some(vec![Push(y), Push(x)]) + } else { + None + } + }); +} + +/// Remove SWAP1 before a commutative function. +fn remove_swaps_commutative(code: &mut Vec) { + replace_windows(code, |window| { + if let [StandardOp(swap1), StandardOp(f)] = window && &swap1 == "SWAP1" { + let commutative = matches!(f.as_str(), "ADD" | "MUL" | "AND" | "OR" | "XOR" | "EQ"); + commutative.then_some(vec![StandardOp(f)]) + } else { + None + } + }); +} + +/// Remove push-pop type patterns, such as: `[DUP1, POP]`. +// Could be extended to other non-side-effecting operations, e.g. [DUP1, ADD, POP] -> [POP]. +fn remove_ignored_values(code: &mut Vec) { + replace_windows(code, |[a, b]| { + if let StandardOp(pop) = b && &pop == "POP" { + match a { + Push(_) => Some(vec![]), + StandardOp(dup) if dup.starts_with("DUP") => Some(vec![]), + _ => None, + } + } else { + None + } + }); +} + +/// Like `replace_windows`, but specifically for code, and only makes replacements if our cost +/// estimator thinks that the new code is more efficient. +fn replace_windows_if_better(code: &mut Vec, maybe_replace: F) +where + F: Fn([Item; W]) -> Option>, +{ + replace_windows(code, |window| { + maybe_replace(window.clone()).filter(|suggestion| is_code_improved(&window, suggestion)) + }) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_constant_propagation_iszero() { + let mut code = vec![Push(Literal(3.into())), StandardOp("ISZERO".into())]; + constant_propagation(&mut code); + assert_eq!(code, vec![Push(Literal(0.into()))]); + } + + #[test] + fn test_constant_propagation_add_overflowing() { + let mut code = vec![ + Push(Literal(U256::max_value())), + Push(Literal(U256::max_value())), + StandardOp("ADD".into()), + ]; + constant_propagation(&mut code); + assert_eq!(code, vec![Push(Literal(U256::max_value() - 1))]); + } + + #[test] + fn test_constant_propagation_sub_underflowing() { + let original = vec![ + Push(Literal(U256::one())), + Push(Literal(U256::zero())), + StandardOp("SUB".into()), + ]; + let mut code = original.clone(); + constant_propagation(&mut code); + // Constant propagation could replace the code with [PUSH U256::MAX], but that's actually + // more expensive, so the code shouldn't be changed. + // (The code could also be replaced with [PUSH 0; NOT], which would be an improvement, but + // our optimizer isn't smart enough yet.) + assert_eq!(code, original); + } + + #[test] + fn test_constant_propagation_mul() { + let mut code = vec![ + Push(Literal(3.into())), + Push(Literal(4.into())), + StandardOp("MUL".into()), + ]; + constant_propagation(&mut code); + assert_eq!(code, vec![Push(Literal(12.into()))]); + } + + #[test] + fn test_constant_propagation_div() { + let mut code = vec![ + Push(Literal(3.into())), + Push(Literal(8.into())), + StandardOp("DIV".into()), + ]; + constant_propagation(&mut code); + assert_eq!(code, vec![Push(Literal(2.into()))]); + } + + #[test] + fn test_constant_propagation_div_zero() { + let mut code = vec![ + Push(Literal(0.into())), + Push(Literal(1.into())), + StandardOp("DIV".into()), + ]; + constant_propagation(&mut code); + assert_eq!(code, vec![Push(Literal(0.into()))]); + } + + #[test] + fn test_no_op_jump() { + let mut code = vec![ + Push(Label("mylabel".into())), + StandardOp("JUMP".into()), + LocalLabelDeclaration("mylabel".into()), + ]; + no_op_jumps(&mut code); + assert_eq!(code, vec![LocalLabelDeclaration("mylabel".into())]); + } + + #[test] + fn test_remove_swapped_pushes() { + let mut code = vec![ + Push(Literal("42".into())), + Push(Label("mylabel".into())), + StandardOp("SWAP1".into()), + ]; + remove_swapped_pushes(&mut code); + assert_eq!( + code, + vec![Push(Label("mylabel".into())), Push(Literal("42".into()))] + ); + } + + #[test] + fn test_remove_swap_mul() { + let mut code = vec![StandardOp("SWAP1".into()), StandardOp("MUL".into())]; + remove_swaps_commutative(&mut code); + assert_eq!(code, vec![StandardOp("MUL".into())]); + } + + #[test] + fn test_remove_push_pop() { + let mut code = vec![Push(Literal("42".into())), StandardOp("POP".into())]; + remove_ignored_values(&mut code); + assert_eq!(code, vec![]); + } + + #[test] + fn test_remove_dup_pop() { + let mut code = vec![StandardOp("DUP5".into()), StandardOp("POP".into())]; + remove_ignored_values(&mut code); + assert_eq!(code, vec![]); + } +} diff --git a/evm/src/cpu/kernel/parser.rs b/evm/src/cpu/kernel/parser.rs index 860dc19d..66bf0757 100644 --- a/evm/src/cpu/kernel/parser.rs +++ b/evm/src/cpu/kernel/parser.rs @@ -45,7 +45,7 @@ fn parse_item(item: Pair) -> Item { .collect::>() .into(), ), - Rule::nullary_instruction => Item::StandardOp(item.as_str().into()), + Rule::nullary_instruction => Item::StandardOp(item.as_str().to_uppercase()), _ => panic!("Unexpected {:?}", item.as_rule()), } } diff --git a/evm/src/cpu/kernel/utils.rs b/evm/src/cpu/kernel/utils.rs index d9682679..8900b8e2 100644 --- a/evm/src/cpu/kernel/utils.rs +++ b/evm/src/cpu/kernel/utils.rs @@ -1,16 +1,63 @@ +use std::fmt::Debug; + use ethereum_types::U256; use plonky2_util::ceil_div_usize; +/// Enumerate the length `W` windows of `vec`, and run `maybe_replace` on each one. +/// +/// Whenever `maybe_replace` returns `Some(replacement)`, the given replacement will be applied. +pub(crate) fn replace_windows(vec: &mut Vec, maybe_replace: F) +where + T: Clone + Debug, + F: Fn([T; W]) -> Option>, +{ + let mut start = 0; + while start + W <= vec.len() { + let range = start..start + W; + let window = vec[range.clone()].to_vec().try_into().unwrap(); + if let Some(replacement) = maybe_replace(window) { + vec.splice(range, replacement); + // Go back to the earliest window that changed. + start = start.saturating_sub(W - 1); + } else { + start += 1; + } + } +} + pub(crate) fn u256_to_trimmed_be_bytes(u256: &U256) -> Vec { let num_bytes = ceil_div_usize(u256.bits(), 8).max(1); // `byte` is little-endian, so we manually reverse it. (0..num_bytes).rev().map(|i| u256.byte(i)).collect() } +pub(crate) fn u256_from_bool(b: bool) -> U256 { + if b { + U256::one() + } else { + U256::zero() + } +} + #[cfg(test)] mod tests { use super::*; + #[test] + fn test_replace_windows() { + // This replacement function adds pairs of integers together. + let mut vec = vec![1, 2, 3, 4, 5]; + replace_windows(&mut vec, |[x, y]| Some(vec![x + y])); + assert_eq!(vec, vec![15u32]); + + // This replacement function splits each composite integer into two factors. + let mut vec = vec![9, 1, 6, 8, 15, 7, 9]; + replace_windows(&mut vec, |[n]| { + (2..n).find(|d| n % d == 0).map(|d| vec![d, n / d]) + }); + assert_eq!(vec, vec![3, 3, 1, 2, 3, 2, 2, 2, 3, 5, 7, 3, 3]); + } + #[test] fn literal_to_be_bytes() { assert_eq!(u256_to_trimmed_be_bytes(&0.into()), vec![0x00]);