From 763d63de0875be4238328b752a39e6919b568062 Mon Sep 17 00:00:00 2001 From: Daniel Lubarov Date: Thu, 4 Aug 2022 17:14:07 -0500 Subject: [PATCH] For permutations, find the optimal sequence of swaps Using a method Angus described. This is mainly his idea and code, I just ported it to Rust. --- evm/Cargo.toml | 5 + evm/benches/stack_manipulation.rs | 75 +++++ evm/src/cpu/kernel/assembler.rs | 2 +- evm/src/cpu/kernel/mod.rs | 2 +- evm/src/cpu/kernel/stack/mod.rs | 2 + evm/src/cpu/kernel/stack/permutations.rs | 278 ++++++++++++++++++ .../kernel/{ => stack}/stack_manipulation.rs | 71 ++++- 7 files changed, 429 insertions(+), 6 deletions(-) create mode 100644 evm/benches/stack_manipulation.rs create mode 100644 evm/src/cpu/kernel/stack/mod.rs create mode 100644 evm/src/cpu/kernel/stack/permutations.rs rename evm/src/cpu/kernel/{ => stack}/stack_manipulation.rs (82%) diff --git a/evm/Cargo.toml b/evm/Cargo.toml index 48ef12d7..e844da3a 100644 --- a/evm/Cargo.toml +++ b/evm/Cargo.toml @@ -25,6 +25,7 @@ keccak-rust = { git = "https://github.com/npwardberkeley/keccak-rust" } keccak-hash = "0.9.0" [dev-dependencies] +criterion = "0.3.5" hex = "0.4.3" [features] @@ -35,3 +36,7 @@ parallel = ["maybe_rayon/parallel"] [[bin]] name = "assemble" required-features = ["asmtools"] + +[[bench]] +name = "stack_manipulation" +harness = false diff --git a/evm/benches/stack_manipulation.rs b/evm/benches/stack_manipulation.rs new file mode 100644 index 00000000..20f86512 --- /dev/null +++ b/evm/benches/stack_manipulation.rs @@ -0,0 +1,75 @@ +use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion}; +use plonky2_evm::cpu::kernel::assemble_to_bytes; + +fn criterion_benchmark(c: &mut Criterion) { + rotl_group(c); + rotr_group(c); + insert_group(c); + delete_group(c); + replace_group(c); + shuffle_group(c); + misc_group(c); +} + +fn rotl_group(c: &mut Criterion) { + let mut group = c.benchmark_group("rotl"); + group.sample_size(10); + 
group.bench_function(BenchmarkId::from_parameter(8), |b| { + b.iter(|| assemble("%stack (a, b, c, d, e, f, g, h) -> (b, c, d, e, f, g, h, a)")) + }); +} + +fn rotr_group(c: &mut Criterion) { + let mut group = c.benchmark_group("rotr"); + group.sample_size(10); + group.bench_function(BenchmarkId::from_parameter(8), |b| { + b.iter(|| assemble("%stack (a, b, c, d, e, f, g, h) -> (h, a, b, c, d, e, f, g)")) + }); +} + +fn insert_group(c: &mut Criterion) { + let mut group = c.benchmark_group("insert"); + group.sample_size(10); + group.bench_function(BenchmarkId::from_parameter(8), |b| { + b.iter(|| assemble("%stack (a, b, c, d, e, f, g, h) -> (a, b, c, d, 123, e, f, g, h)")) + }); +} + +fn delete_group(c: &mut Criterion) { + let mut group = c.benchmark_group("delete"); + group.sample_size(10); + group.bench_function(BenchmarkId::from_parameter(8), |b| { + b.iter(|| assemble("%stack (a, b, c, d, e, f, g, h) -> (a, b, c, e, f, g, h)")) + }); +} + +fn replace_group(c: &mut Criterion) { + let mut group = c.benchmark_group("replace"); + group.sample_size(10); + group.bench_function(BenchmarkId::from_parameter(8), |b| { + b.iter(|| assemble("%stack (a, b, c, d, e, f, g, h) -> (a, b, c, 5, e, f, g, h)")) + }); +} + +fn shuffle_group(c: &mut Criterion) { + let mut group = c.benchmark_group("shuffle"); + group.sample_size(10); + group.bench_function(BenchmarkId::from_parameter(8), |b| { + b.iter(|| assemble("%stack (a, b, c, d, e, f, g, h) -> (g, d, h, a, f, e, b, c)")) + }); +} + +fn misc_group(c: &mut Criterion) { + let mut group = c.benchmark_group("misc"); + group.sample_size(10); + group.bench_function(BenchmarkId::from_parameter(8), |b| { + b.iter(|| assemble("%stack (a, b, c, a, e, f, g, h) -> (g, 1, h, g, f, 3, b, b)")) + }); +} + +criterion_group!(benches, criterion_benchmark); +criterion_main!(benches); + +fn assemble(code: &str) { + assemble_to_bytes(&[code.into()]); +} diff --git a/evm/src/cpu/kernel/assembler.rs b/evm/src/cpu/kernel/assembler.rs index 
14ec9aa0..f5175c41 100644 --- a/evm/src/cpu/kernel/assembler.rs +++ b/evm/src/cpu/kernel/assembler.rs @@ -10,7 +10,7 @@ use crate::cpu::kernel::ast::StackReplacement; use crate::cpu::kernel::keccak_util::hash_kernel; use crate::cpu::kernel::optimizer::optimize_asm; use crate::cpu::kernel::prover_input::ProverInputFn; -use crate::cpu::kernel::stack_manipulation::expand_stack_manipulation; +use crate::cpu::kernel::stack::stack_manipulation::expand_stack_manipulation; use crate::cpu::kernel::utils::u256_to_trimmed_be_bytes; use crate::cpu::kernel::{ ast::{File, Item}, diff --git a/evm/src/cpu/kernel/mod.rs b/evm/src/cpu/kernel/mod.rs index 4879ad76..eceba813 100644 --- a/evm/src/cpu/kernel/mod.rs +++ b/evm/src/cpu/kernel/mod.rs @@ -10,7 +10,7 @@ mod opcodes; mod optimizer; mod parser; pub mod prover_input; -mod stack_manipulation; +pub mod stack; mod txn_fields; mod utils; diff --git a/evm/src/cpu/kernel/stack/mod.rs b/evm/src/cpu/kernel/stack/mod.rs new file mode 100644 index 00000000..4c7640e4 --- /dev/null +++ b/evm/src/cpu/kernel/stack/mod.rs @@ -0,0 +1,2 @@ +mod permutations; +pub mod stack_manipulation; diff --git a/evm/src/cpu/kernel/stack/permutations.rs b/evm/src/cpu/kernel/stack/permutations.rs new file mode 100644 index 00000000..d64755ed --- /dev/null +++ b/evm/src/cpu/kernel/stack/permutations.rs @@ -0,0 +1,278 @@ +//! This module contains logic for finding the optimal sequence of swaps to get from one stack state +//! to another, specifically for the case where the source and destination states are permutations +//! of one another. +//! +//! We solve the problem in three steps: +//! 1. Find a permutation `P` such that `P A = B`. +//! 2. If `A` contains duplicates, optimize `P` by reducing the number of cycles. +//! 3. Convert each cycle into a set of `(0 i)` transpositions, which correspond to swap +//! instructions in the EVM. +//! +//! We typically represent a permutation as a sequence of cycles. For example, the permutation +//! 
`(1 2 3)(1 2)(4 5)` acts as: +//! +//! ```ignore +//! (1 2 3)(1 2)(4 5)[A_0, A_1, A_2, A_3, A_4, A_5] = (1 2 3)(1 2)[A_0, A_1, A_2, A_3, A_5, A_4] +//! = (1 2 3)[A_0, A_2, A_1, A_3, A_5, A_4] +//! = [A_0, A_3, A_2, A_1, A_5, A_4] +//! ``` +//! +//! We typically represent a `(0 i)` transposition as a single scalar `i`. + +use std::collections::{HashMap, HashSet}; +use std::hash::Hash; + +use crate::cpu::kernel::stack::stack_manipulation::{StackItem, StackOp}; + +/// Find the optimal sequence of stack operations to get from `src` to `dst`. Assumes that `src` and +/// `dst` are permutations of one another. +pub(crate) fn get_stack_ops_for_perm(src: &[StackItem], dst: &[StackItem]) -> Vec { + // We store stacks with the tip at the end, but the permutation calls below use the opposite + // convention. They're a bit simpler when SWAP are (0 i) transposes. + let mut src = src.to_vec(); + let mut dst = dst.to_vec(); + src.reverse(); + dst.reverse(); + + let perm = find_permutation(&src, &dst); + let optimized_perm = combine_cycles(perm, &src); + let trans = permutation_to_transpositions(optimized_perm); + transpositions_to_stack_ops(trans) +} + +/// Apply the given permutation to the given list. +#[cfg(test)] +fn apply_perm(permutation: Vec>, mut lst: Vec) -> Vec { + // Run through perm in REVERSE order. + for cycl in permutation.iter().rev() { + let n = cycl.len(); + let last = lst[cycl[n - 1]].clone(); + for i in (0..n - 1).rev() { + let j = (i + 1) % n; + lst[cycl[j]] = lst[cycl[i]].clone(); + } + lst[cycl[0]] = last; + } + lst +} + +/// This function does STEP 1. +/// Given 2 lists A, B find a permutation P such that P . A = B. +pub fn find_permutation(lst_a: &[T], lst_b: &[T]) -> Vec> { + // We should check to ensure that A and B are indeed rearrangements of each other. + assert!(is_permutation(lst_a, lst_b)); + + let n = lst_a.len(); + + // Keep track of the A_i's which have been already placed into the correct position.
+ let mut correct_a = HashSet::new(); + + // loc_b is a dictionary where loc_b[b] is the indices i where b = B_i != A_i. + // We need to swap appropriate A_j's into these positions. + let mut loc_b: HashMap> = HashMap::new(); + + for i in 0..n { + if lst_a[i] == lst_b[i] { + // If A_i = B_i, we never do SWAP_i as we are already in the correct position. + correct_a.insert(i); + } else { + loc_b.entry(lst_b[i].clone()).or_default().push(i); + } + } + + // This will be a list of disjoint cycles. + let mut permutation = vec![]; + + // For technical reasons, it's handy to include [0] as a trivial cycle. + // This is because if A_0 = A_i for some other i in a cycle, + // we can save transpositions by expanding the cycle to include 0. + if correct_a.contains(&0) { + permutation.push(vec![0]); + } + + for i in 0..n { + // If i is both not in the correct position and not already in a cycle, it will start a new cycle. + if correct_a.contains(&i) { + continue; + } + + correct_a.insert(i); + let mut cycl = vec![i]; + + // lst_a[i] needs to be swapped into an index j such that lst_b[j] = lst_a[i]. + // This exactly means j should be an element of loc_b[lst_a[i]]. + // We pop as each j should only be used once. + // In this step we simply find any permutation. We will improve it to an optimal one in STEP 2. + let mut j = loc_b.get_mut(&lst_a[i]).unwrap().pop().unwrap(); + + // Keep adding elements to the cycle until we return to our initial index + while j != i { + correct_a.insert(j); + cycl.push(j); + j = loc_b.get_mut(&lst_a[j]).unwrap().pop().unwrap(); + } + + permutation.push(cycl); + } + permutation +} + +/// This function does STEP 2. It tests to see if cycles can be combined which might occur if A has duplicates. +fn combine_cycles(mut perm: Vec>, lst_a: &[T]) -> Vec> { + // If perm is a single cycle, there is nothing to combine. + if perm.len() == 1 { + return perm; + } + + let n = lst_a.len(); + + // Need a dictionary to keep track of duplicates in lst_a.
+ let mut all_a_positions: HashMap> = HashMap::new(); + for i in 0..n { + all_a_positions.entry(lst_a[i].clone()).or_default().push(i); + } + + // For each element a which occurs at positions i1, ..., ij, combine cycles such that all + // ik which occur in a cycle occur in the same cycle. + for positions in all_a_positions.values() { + if positions.len() == 1 { + continue; + } + + let mut joinedperm = vec![]; + let mut newperm = vec![]; + let mut pos = 0; + for cycl in perm { + // Does cycl include an element of positions? + let mut disjoint = true; + + for term in positions { + if cycl.contains(term) { + if joinedperm.is_empty() { + // This is the first cycle we have found including an element of positions. + joinedperm = cycl.clone(); + pos = cycl.iter().position(|x| x == term).unwrap(); + } else { + // Need to merge 2 cycles. If A_i = A_j then the permutations + // (C_1, ..., C_k1, i, C_{k1 + 1}, ... C_k2)(D_1, ..., D_k3, j, D_{k3 + 1}, ... D_k4) + // (C_1, ..., C_k1, i, D_{k3 + 1}, ... D_k4, D_1, ..., D_k3, j, C_{k1 + 1}, ... C_k2) + // lead to the same output but the second will require fewer transpositions. + let newpos = cycl.iter().position(|x| x == term).unwrap(); + joinedperm = [ + &joinedperm[..pos + 1], + &cycl[newpos + 1..], + &cycl[..newpos + 1], + &joinedperm[pos + 1..], + ] + .concat(); + } + disjoint = false; + break; + } + } + if disjoint { + newperm.push(cycl); + } + } + if !joinedperm.is_empty() { + newperm.push(joinedperm); + } + perm = newperm; + } + perm +} + +// This function does STEP 3. Converting all cycles to [0, i] transpositions. +fn permutation_to_transpositions(perm: Vec>) -> Vec { + let mut trans = vec![]; + // The method is pretty simple, we have: + // (0 C_1 ... C_i) = (0 C_i) ... (0 C_1) + // (C_1 ... C_i) = (0 C_1) (0 C_i) ... (0 C_1). + // We simply need to check to see if 0 is in our cycle to see which one to use.
+ for cycl in perm { + let n = cycl.len(); + let zero_pos = cycl.iter().position(|x| *x == 0); + if let Some(pos) = zero_pos { + trans.extend((1..n).map(|i| cycl[(n + pos - i) % n])); + } else { + trans.extend((0..=n).map(|i| cycl[(n - i) % n])); + } + } + trans +} + +#[cfg(test)] +fn trans_to_perm(trans: Vec) -> Vec> { + trans.into_iter().map(|i| vec![0, i]).collect() +} + +fn transpositions_to_stack_ops(trans: Vec) -> Vec { + trans.into_iter().map(|i| StackOp::Swap(i as u8)).collect() +} + +pub fn is_permutation(a: &[T], b: &[T]) -> bool { + make_multiset(a) == make_multiset(b) +} + +fn make_multiset(vals: &[T]) -> HashMap { + let mut counts = HashMap::new(); + for val in vals { + *counts.entry(val.clone()).or_default() += 1; + } + counts +} + +#[cfg(test)] +mod tests { + use rand::prelude::SliceRandom; + use rand::thread_rng; + + use crate::cpu::kernel::stack::permutations::{ + apply_perm, combine_cycles, find_permutation, is_permutation, + permutation_to_transpositions, trans_to_perm, + }; + + #[test] + fn test_combine_cycles() { + assert_eq!( + combine_cycles(vec![vec![0, 2], vec![3, 4]], &['a', 'b', 'c', 'd', 'a']), + vec![vec![0, 3, 4, 2]] + ); + } + + #[test] + fn test_is_permutation() { + assert!(is_permutation(&['a', 'b', 'c'], &['b', 'c', 'a'])); + assert!(!is_permutation(&['a', 'b', 'c'], &['a', 'b', 'b', 'c'])); + assert!(!is_permutation(&['a', 'b', 'c'], &['a', 'd', 'c'])); + } + + #[test] + fn test_all() { + let mut test_lst = vec![ + 'a', 'a', 'a', 'a', 'b', 'b', 'b', 'c', 'c', 'c', 'd', 'd', 'e', 'f', 'g', 'h', 'k', + ]; + + let mut rng = thread_rng(); + test_lst.shuffle(&mut rng); + for _ in 0..1000 { + let lst_a = test_lst.clone(); + test_lst.shuffle(&mut rng); + let lst_b = test_lst.clone(); + + let perm = find_permutation(&lst_a, &lst_b); + assert_eq!(apply_perm(perm.clone(), lst_a.clone()), lst_b); + + let shortperm = combine_cycles(perm.clone(), &lst_a); + assert_eq!(apply_perm(shortperm.clone(), lst_a.clone()), lst_b); + + let trans = 
trans_to_perm(permutation_to_transpositions(perm)); + assert_eq!(apply_perm(trans.clone(), lst_a.clone()), lst_b); + + let shorttrans = trans_to_perm(permutation_to_transpositions(shortperm)); + assert_eq!(apply_perm(shorttrans.clone(), lst_a.clone()), lst_b); + + assert!(shorttrans.len() <= trans.len()); + } + } +} diff --git a/evm/src/cpu/kernel/stack_manipulation.rs b/evm/src/cpu/kernel/stack/stack_manipulation.rs similarity index 82% rename from evm/src/cpu/kernel/stack_manipulation.rs rename to evm/src/cpu/kernel/stack/stack_manipulation.rs index a1f02c7e..9f685953 100644 --- a/evm/src/cpu/kernel/stack_manipulation.rs +++ b/evm/src/cpu/kernel/stack/stack_manipulation.rs @@ -1,13 +1,15 @@ use std::cmp::Ordering; use std::collections::hash_map::Entry::{Occupied, Vacant}; use std::collections::{BinaryHeap, HashMap}; +use std::hash::Hash; use itertools::Itertools; use crate::cpu::columns::NUM_CPU_COLUMNS; use crate::cpu::kernel::assembler::BYTES_PER_OFFSET; use crate::cpu::kernel::ast::{Item, PushTarget, StackReplacement}; -use crate::cpu::kernel::stack_manipulation::StackOp::Pop; +use crate::cpu::kernel::stack::permutations::{get_stack_ops_for_perm, is_permutation}; +use crate::cpu::kernel::stack::stack_manipulation::StackOp::Pop; use crate::cpu::kernel::utils::u256_to_trimmed_be_bytes; use crate::memory; @@ -164,13 +166,13 @@ impl Ord for Node { /// Like `StackReplacement`, but without constants or macro vars, since those were expanded already. #[derive(Eq, PartialEq, Hash, Clone, Debug)] -enum StackItem { +pub(crate) enum StackItem { NamedItem(String), PushTarget(PushTarget), } #[derive(Clone, Debug)] -enum StackOp { +pub(crate) enum StackOp { Push(PushTarget), Pop, Dup(u8), @@ -188,6 +190,11 @@ fn next_ops( return vec![StackOp::Pop] } + if is_permutation(src, dst) { + // The transpositions are right-associative, so the last one gets applied first, hence pop. 
+ return vec![get_stack_ops_for_perm(src, dst).pop().unwrap()]; + } + let mut ops = vec![StackOp::Pop]; ops.extend( @@ -220,11 +227,31 @@ fn next_ops( .map(StackOp::Dup), ); - ops.extend((1..src_len).map(StackOp::Swap)); + ops.extend( + (1..src_len) + .filter(|i| should_try_swap(src, dst, *i)) + .map(StackOp::Swap), + ); ops } +/// Whether we should consider `SWAP_i` in the search. +fn should_try_swap(src: &[StackItem], dst: &[StackItem], i: u8) -> bool { + if src.is_empty() { + return false; + } + + let i = i as usize; + let i_from = src.len() - 1; + let i_to = i_from - i; + + // Only consider a swap if it places one of the two affected elements in the desired position. + let top_correct_pos = i_to < dst.len() && src[i_from] == dst[i_to]; + let other_correct_pos = i_from < dst.len() && src[i_to] == dst[i_from]; + top_correct_pos | other_correct_pos +} + impl StackOp { fn cost(&self) -> u32 { let (cpu_rows, memory_rows) = match self { @@ -287,3 +314,39 @@ impl StackOp { } } } + +#[cfg(test)] +mod tests { + use env_logger::{try_init_from_env, Env, DEFAULT_FILTER_ENV}; + + use crate::cpu::kernel::stack::stack_manipulation::StackItem::NamedItem; + use crate::cpu::kernel::stack::stack_manipulation::{shortest_path, StackItem}; + + #[test] + fn test_shortest_path() { + init_logger(); + shortest_path( + vec![named("ret"), named("a"), named("b"), named("d")], + vec![named("ret"), named("b"), named("a")], + vec![], + ); + } + + #[test] + fn test_shortest_path_permutation() { + init_logger(); + shortest_path( + vec![named("a"), named("b"), named("c")], + vec![named("c"), named("a"), named("b")], + vec![], + ); + } + + fn named(name: &str) -> StackItem { + NamedItem(name.into()) + } + + fn init_logger() { + let _ = try_init_from_env(Env::default().filter_or(DEFAULT_FILTER_ENV, "debug")); + } +}