Merge pull request #648 from mir-protocol/optimizer

Some simple optimization rules
This commit is contained in:
Daniel Lubarov 2022-08-03 13:53:58 -07:00 committed by GitHub
commit 7481831b74
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 376 additions and 20 deletions

View File

@ -39,7 +39,7 @@ pub(crate) fn combined_kernel() -> Kernel {
];
let parsed_files = files.iter().map(|f| parse(f)).collect_vec();
assemble(parsed_files, evm_constants())
assemble(parsed_files, evm_constants(), true)
}
#[cfg(test)]

View File

@ -120,7 +120,7 @@
// stack: input, ...
PUSH $c
// stack: c, input, ...
GE // Check it backwards: (input <= c) == (c >= input)
LT ISZERO // Check it backwards: (input <= c) == !(c < input)
// stack: input <= c, ...
%endmacro
@ -136,7 +136,7 @@
// stack: input, ...
PUSH $c
// stack: c, input, ...
LE // Check it backwards: (input >= c) == (c <= input)
GT ISZERO // Check it backwards: (input >= c) == !(c > input)
// stack: input >= c, ...
%endmacro

View File

@ -7,6 +7,7 @@ use log::debug;
use super::ast::PushTarget;
use crate::cpu::kernel::ast::StackReplacement;
use crate::cpu::kernel::keccak_util::hash_kernel;
use crate::cpu::kernel::optimizer::optimize_asm;
use crate::cpu::kernel::prover_input::ProverInputFn;
use crate::cpu::kernel::stack_manipulation::expand_stack_manipulation;
use crate::cpu::kernel::utils::u256_to_trimmed_be_bytes;
@ -64,7 +65,11 @@ impl Macro {
}
}
pub(crate) fn assemble(files: Vec<File>, constants: HashMap<String, U256>) -> Kernel {
pub(crate) fn assemble(
files: Vec<File>,
constants: HashMap<String, U256>,
optimize: bool,
) -> Kernel {
let macros = find_macros(&files);
let mut global_labels = HashMap::new();
let mut prover_inputs = HashMap::new();
@ -75,7 +80,10 @@ pub(crate) fn assemble(files: Vec<File>, constants: HashMap<String, U256>) -> Ke
let expanded_file = expand_macros(file.body, &macros);
let expanded_file = expand_repeats(expanded_file);
let expanded_file = inline_constants(expanded_file, &constants);
let expanded_file = expand_stack_manipulation(expanded_file);
let mut expanded_file = expand_stack_manipulation(expanded_file);
if optimize {
optimize_asm(&mut expanded_file);
}
local_labels.push(find_labels(
&expanded_file,
&mut offset,
@ -381,7 +389,7 @@ mod tests {
let expected_kernel = Kernel::new(expected_code, expected_global_labels, HashMap::new());
let program = vec![file_1, file_2];
assert_eq!(assemble(program, HashMap::new()), expected_kernel);
assert_eq!(assemble(program, HashMap::new(), false), expected_kernel);
}
#[test]
@ -399,7 +407,7 @@ mod tests {
Item::StandardOp("JUMPDEST".to_string()),
],
};
assemble(vec![file_1, file_2], HashMap::new());
assemble(vec![file_1, file_2], HashMap::new(), false);
}
#[test]
@ -413,7 +421,7 @@ mod tests {
Item::StandardOp("ADD".to_string()),
],
};
assemble(vec![file], HashMap::new());
assemble(vec![file], HashMap::new(), false);
}
#[test]
@ -421,7 +429,7 @@ mod tests {
let file = File {
body: vec![Item::Bytes(vec![0x12, 42]), Item::Bytes(vec![0xFE, 255])],
};
let code = assemble(vec![file], HashMap::new()).code;
let code = assemble(vec![file], HashMap::new(), false).code;
assert_eq!(code, vec![0x12, 42, 0xfe, 255]);
}
@ -438,10 +446,11 @@ mod tests {
#[test]
fn macro_with_vars() {
let kernel = parse_and_assemble(&[
let files = &[
"%macro add(x, y) PUSH $x PUSH $y ADD %endmacro",
"%add(2, 3)",
]);
];
let kernel = parse_and_assemble_ext(files, HashMap::new(), false);
let push1 = get_push_opcode(1);
let add = get_opcode("ADD");
assert_eq!(kernel.code, vec![push1, 2, push1, 3, add]);
@ -479,7 +488,7 @@ mod tests {
let mut constants = HashMap::new();
constants.insert("DEAD_BEEF".into(), 0xDEADBEEFu64.into());
let kernel = parse_and_assemble_with_constants(code, constants);
let kernel = parse_and_assemble_ext(code, constants, true);
let push4 = get_push_opcode(4);
assert_eq!(kernel.code, vec![push4, 0xDE, 0xAD, 0xBE, 0xEF]);
}
@ -510,7 +519,7 @@ mod tests {
let mut consts = HashMap::new();
consts.insert("LIFE".into(), 42.into());
parse_and_assemble_with_constants(&["%stack (a, b) -> (b, @LIFE)"], consts);
parse_and_assemble_ext(&["%stack (a, b) -> (b, @LIFE)"], consts, true);
// We won't check the code since there are two equally efficient implementations.
let kernel = parse_and_assemble(&["start: %stack (a, b) -> (start)"]);
@ -522,14 +531,15 @@ mod tests {
}
fn parse_and_assemble(files: &[&str]) -> Kernel {
parse_and_assemble_with_constants(files, HashMap::new())
parse_and_assemble_ext(files, HashMap::new(), true)
}
fn parse_and_assemble_with_constants(
fn parse_and_assemble_ext(
files: &[&str],
constants: HashMap<String, U256>,
optimize: bool,
) -> Kernel {
let parsed_files = files.iter().map(|f| parse(f)).collect_vec();
assemble(parsed_files, constants)
assemble(parsed_files, constants, optimize)
}
}

View File

@ -7,7 +7,7 @@ pub(crate) struct File {
pub(crate) body: Vec<Item>,
}
#[derive(Clone, Debug)]
#[derive(Eq, PartialEq, Clone, Debug)]
pub(crate) enum Item {
/// Defines a new macro: name, params, body.
MacroDef(String, Vec<String>, Vec<Item>),
@ -34,7 +34,7 @@ pub(crate) enum Item {
Bytes(Vec<u8>),
}
#[derive(Clone, Debug)]
#[derive(Eq, PartialEq, Clone, Debug)]
pub(crate) enum StackReplacement {
/// Can be either a named item or a label.
Identifier(String),

View File

@ -0,0 +1,37 @@
use crate::cpu::kernel::assembler::BYTES_PER_OFFSET;
use crate::cpu::kernel::ast::Item;
use crate::cpu::kernel::ast::Item::*;
use crate::cpu::kernel::ast::PushTarget::*;
use crate::cpu::kernel::utils::u256_to_trimmed_be_bytes;
pub(crate) fn is_code_improved(before: &[Item], after: &[Item]) -> bool {
cost_estimate(after) < cost_estimate(before)
}
fn cost_estimate(code: &[Item]) -> u32 {
code.iter().map(cost_estimate_item).sum()
}
fn cost_estimate_item(item: &Item) -> u32 {
match item {
MacroDef(_, _, _) => 0,
GlobalLabelDeclaration(_) => 0,
LocalLabelDeclaration(_) => 0,
Push(Literal(n)) => cost_estimate_push(u256_to_trimmed_be_bytes(n).len()),
Push(Label(_)) => cost_estimate_push(BYTES_PER_OFFSET as usize),
ProverInput(_) => 1,
StandardOp(op) => cost_estimate_standard_op(op.as_str()),
_ => panic!("Unexpected item: {:?}", item),
}
}
fn cost_estimate_standard_op(_op: &str) -> u32 {
// For now we just treat any standard operation as having the same cost. This is pretty naive,
// but should work fine with our current set of simple optimization rules.
1
}
fn cost_estimate_push(num_bytes: usize) -> u32 {
// TODO: Once PUSH is actually implemented, check if this needs to be revised.
num_bytes as u32
}

View File

@ -3,9 +3,11 @@ pub mod assembler;
mod ast;
mod constants;
mod context_metadata;
mod cost_estimator;
mod global_metadata;
pub(crate) mod keccak_util;
mod opcodes;
mod optimizer;
mod parser;
pub mod prover_input;
mod stack_manipulation;
@ -26,6 +28,6 @@ use crate::cpu::kernel::constants::evm_constants;
/// This is for debugging the kernel only.
pub fn assemble_to_bytes(files: &[String]) -> Vec<u8> {
let parsed_files: Vec<_> = files.iter().map(|f| parse(f)).collect();
let kernel = assemble(parsed_files, evm_constants());
let kernel = assemble(parsed_files, evm_constants(), true);
kernel.code
}

View File

@ -0,0 +1,260 @@
use ethereum_types::U256;
use Item::{Push, StandardOp};
use PushTarget::Literal;
use crate::cpu::kernel::ast::Item::{GlobalLabelDeclaration, LocalLabelDeclaration};
use crate::cpu::kernel::ast::PushTarget::Label;
use crate::cpu::kernel::ast::{Item, PushTarget};
use crate::cpu::kernel::cost_estimator::is_code_improved;
use crate::cpu::kernel::utils::{replace_windows, u256_from_bool};
pub(crate) fn optimize_asm(code: &mut Vec<Item>) {
// Run the optimizer until nothing changes.
loop {
let old_code = code.clone();
optimize_asm_once(code);
if code == &old_code {
break;
}
}
}
/// A single optimization pass.
fn optimize_asm_once(code: &mut Vec<Item>) {
constant_propagation(code);
no_op_jumps(code);
remove_swapped_pushes(code);
remove_swaps_commutative(code);
remove_ignored_values(code);
}
/// Constant propagation.
fn constant_propagation(code: &mut Vec<Item>) {
// Constant propagation for unary ops: `[PUSH x, UNARYOP] -> [PUSH UNARYOP(x)]`
replace_windows_if_better(code, |window| {
if let [Push(Literal(x)), StandardOp(op)] = window {
match op.as_str() {
"ISZERO" => Some(vec![Push(Literal(u256_from_bool(x.is_zero())))]),
"NOT" => Some(vec![Push(Literal(!x))]),
_ => None,
}
} else {
None
}
});
// Constant propagation for binary ops: `[PUSH y, PUSH x, BINOP] -> [PUSH BINOP(x, y)]`
replace_windows_if_better(code, |window| {
if let [Push(Literal(y)), Push(Literal(x)), StandardOp(op)] = window {
match op.as_str() {
"ADD" => Some(x.overflowing_add(y).0),
"SUB" => Some(x.overflowing_sub(y).0),
"MUL" => Some(x.overflowing_mul(y).0),
"DIV" => Some(x.checked_div(y).unwrap_or(U256::zero())),
"MOD" => Some(x.checked_rem(y).unwrap_or(U256::zero())),
"EXP" => Some(x.overflowing_pow(y).0),
"SHL" => Some(x << y),
"SHR" => Some(x >> y),
"AND" => Some(x & y),
"OR" => Some(x | y),
"XOR" => Some(x ^ y),
"LT" => Some(u256_from_bool(x < y)),
"GT" => Some(u256_from_bool(x > y)),
"EQ" => Some(u256_from_bool(x == y)),
"BYTE" => Some(if x < 32.into() {
y.byte(x.as_usize()).into()
} else {
U256::zero()
}),
_ => None,
}
.map(|res| vec![Push(Literal(res))])
} else {
None
}
});
}
/// Remove no-op jumps: `[PUSH label, JUMP, label:] -> [label:]`.
fn no_op_jumps(code: &mut Vec<Item>) {
replace_windows(code, |window| {
if let [Push(Label(l)), StandardOp(jump), decl] = window
&& &jump == "JUMP"
&& (decl == LocalLabelDeclaration(l.clone()) || decl == GlobalLabelDeclaration(l.clone()))
{
Some(vec![LocalLabelDeclaration(l)])
} else {
None
}
});
}
/// Remove swaps: `[PUSH x, PUSH y, SWAP1] -> [PUSH y, PUSH x]`.
// Could be generalized to recognize more than two pushes.
fn remove_swapped_pushes(code: &mut Vec<Item>) {
replace_windows(code, |window| {
if let [Push(x), Push(y), StandardOp(swap1)] = window
&& &swap1 == "SWAP1" {
Some(vec![Push(y), Push(x)])
} else {
None
}
});
}
/// Remove SWAP1 before a commutative function.
fn remove_swaps_commutative(code: &mut Vec<Item>) {
replace_windows(code, |window| {
if let [StandardOp(swap1), StandardOp(f)] = window && &swap1 == "SWAP1" {
let commutative = matches!(f.as_str(), "ADD" | "MUL" | "AND" | "OR" | "XOR" | "EQ");
commutative.then_some(vec![StandardOp(f)])
} else {
None
}
});
}
/// Remove push-pop type patterns, such as: `[DUP1, POP]`.
// Could be extended to other non-side-effecting operations, e.g. [DUP1, ADD, POP] -> [POP].
fn remove_ignored_values(code: &mut Vec<Item>) {
replace_windows(code, |[a, b]| {
if let StandardOp(pop) = b && &pop == "POP" {
match a {
Push(_) => Some(vec![]),
StandardOp(dup) if dup.starts_with("DUP") => Some(vec![]),
_ => None,
}
} else {
None
}
});
}
/// Like `replace_windows`, but specifically for code, and only makes replacements if our cost
/// estimator thinks that the new code is more efficient.
fn replace_windows_if_better<const W: usize, F>(code: &mut Vec<Item>, maybe_replace: F)
where
F: Fn([Item; W]) -> Option<Vec<Item>>,
{
replace_windows(code, |window| {
maybe_replace(window.clone()).filter(|suggestion| is_code_improved(&window, suggestion))
})
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_constant_propagation_iszero() {
let mut code = vec![Push(Literal(3.into())), StandardOp("ISZERO".into())];
constant_propagation(&mut code);
assert_eq!(code, vec![Push(Literal(0.into()))]);
}
#[test]
fn test_constant_propagation_add_overflowing() {
let mut code = vec![
Push(Literal(U256::max_value())),
Push(Literal(U256::max_value())),
StandardOp("ADD".into()),
];
constant_propagation(&mut code);
assert_eq!(code, vec![Push(Literal(U256::max_value() - 1))]);
}
#[test]
fn test_constant_propagation_sub_underflowing() {
let original = vec![
Push(Literal(U256::one())),
Push(Literal(U256::zero())),
StandardOp("SUB".into()),
];
let mut code = original.clone();
constant_propagation(&mut code);
// Constant propagation could replace the code with [PUSH U256::MAX], but that's actually
// more expensive, so the code shouldn't be changed.
// (The code could also be replaced with [PUSH 0; NOT], which would be an improvement, but
// our optimizer isn't smart enough yet.)
assert_eq!(code, original);
}
#[test]
fn test_constant_propagation_mul() {
let mut code = vec![
Push(Literal(3.into())),
Push(Literal(4.into())),
StandardOp("MUL".into()),
];
constant_propagation(&mut code);
assert_eq!(code, vec![Push(Literal(12.into()))]);
}
#[test]
fn test_constant_propagation_div() {
let mut code = vec![
Push(Literal(3.into())),
Push(Literal(8.into())),
StandardOp("DIV".into()),
];
constant_propagation(&mut code);
assert_eq!(code, vec![Push(Literal(2.into()))]);
}
#[test]
fn test_constant_propagation_div_zero() {
let mut code = vec![
Push(Literal(0.into())),
Push(Literal(1.into())),
StandardOp("DIV".into()),
];
constant_propagation(&mut code);
assert_eq!(code, vec![Push(Literal(0.into()))]);
}
#[test]
fn test_no_op_jump() {
let mut code = vec![
Push(Label("mylabel".into())),
StandardOp("JUMP".into()),
LocalLabelDeclaration("mylabel".into()),
];
no_op_jumps(&mut code);
assert_eq!(code, vec![LocalLabelDeclaration("mylabel".into())]);
}
#[test]
fn test_remove_swapped_pushes() {
let mut code = vec![
Push(Literal("42".into())),
Push(Label("mylabel".into())),
StandardOp("SWAP1".into()),
];
remove_swapped_pushes(&mut code);
assert_eq!(
code,
vec![Push(Label("mylabel".into())), Push(Literal("42".into()))]
);
}
#[test]
fn test_remove_swap_mul() {
let mut code = vec![StandardOp("SWAP1".into()), StandardOp("MUL".into())];
remove_swaps_commutative(&mut code);
assert_eq!(code, vec![StandardOp("MUL".into())]);
}
#[test]
fn test_remove_push_pop() {
let mut code = vec![Push(Literal("42".into())), StandardOp("POP".into())];
remove_ignored_values(&mut code);
assert_eq!(code, vec![]);
}
#[test]
fn test_remove_dup_pop() {
let mut code = vec![StandardOp("DUP5".into()), StandardOp("POP".into())];
remove_ignored_values(&mut code);
assert_eq!(code, vec![]);
}
}

View File

@ -45,7 +45,7 @@ fn parse_item(item: Pair<Rule>) -> Item {
.collect::<Vec<_>>()
.into(),
),
Rule::nullary_instruction => Item::StandardOp(item.as_str().into()),
Rule::nullary_instruction => Item::StandardOp(item.as_str().to_uppercase()),
_ => panic!("Unexpected {:?}", item.as_rule()),
}
}

View File

@ -1,16 +1,63 @@
use std::fmt::Debug;
use ethereum_types::U256;
use plonky2_util::ceil_div_usize;
/// Enumerate the length `W` windows of `vec`, and run `maybe_replace` on each one.
///
/// Whenever `maybe_replace` returns `Some(replacement)`, the given replacement will be applied.
pub(crate) fn replace_windows<const W: usize, T, F>(vec: &mut Vec<T>, maybe_replace: F)
where
T: Clone + Debug,
F: Fn([T; W]) -> Option<Vec<T>>,
{
let mut start = 0;
while start + W <= vec.len() {
let range = start..start + W;
let window = vec[range.clone()].to_vec().try_into().unwrap();
if let Some(replacement) = maybe_replace(window) {
vec.splice(range, replacement);
// Go back to the earliest window that changed.
start = start.saturating_sub(W - 1);
} else {
start += 1;
}
}
}
pub(crate) fn u256_to_trimmed_be_bytes(u256: &U256) -> Vec<u8> {
let num_bytes = ceil_div_usize(u256.bits(), 8).max(1);
// `byte` is little-endian, so we manually reverse it.
(0..num_bytes).rev().map(|i| u256.byte(i)).collect()
}
pub(crate) fn u256_from_bool(b: bool) -> U256 {
if b {
U256::one()
} else {
U256::zero()
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_replace_windows() {
// This replacement function adds pairs of integers together.
let mut vec = vec![1, 2, 3, 4, 5];
replace_windows(&mut vec, |[x, y]| Some(vec![x + y]));
assert_eq!(vec, vec![15u32]);
// This replacement function splits each composite integer into two factors.
let mut vec = vec![9, 1, 6, 8, 15, 7, 9];
replace_windows(&mut vec, |[n]| {
(2..n).find(|d| n % d == 0).map(|d| vec![d, n / d])
});
assert_eq!(vec, vec![3, 3, 1, 2, 3, 2, 2, 2, 3, 5, 7, 3, 3]);
}
#[test]
fn literal_to_be_bytes() {
assert_eq!(u256_to_trimmed_be_bytes(&0.into()), vec![0x00]);