Merge pull request #662 from mir-protocol/stack_pruning_opt_perms

For permutations, find the optimal sequence of swaps
This commit is contained in:
Daniel Lubarov 2022-08-12 17:19:47 -07:00 committed by GitHub
commit 1763b6bc37
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 429 additions and 6 deletions

View File

@ -25,6 +25,7 @@ keccak-rust = { git = "https://github.com/npwardberkeley/keccak-rust" }
keccak-hash = "0.9.0"
[dev-dependencies]
criterion = "0.3.5"
hex = "0.4.3"
[features]
@ -35,3 +36,7 @@ parallel = ["maybe_rayon/parallel"]
[[bin]]
name = "assemble"
required-features = ["asmtools"]
[[bench]]
name = "stack_manipulation"
harness = false

View File

@ -0,0 +1,75 @@
use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion};
use plonky2_evm::cpu::kernel::assemble_to_bytes;
fn criterion_benchmark(c: &mut Criterion) {
rotl_group(c);
rotr_group(c);
insert_group(c);
delete_group(c);
replace_group(c);
shuffle_group(c);
misc_group(c);
}
/// Benchmark group `rotl`: assembles a `%stack` directive that moves the first of eight
/// named items to the last position.
fn rotl_group(c: &mut Criterion) {
    const DIRECTIVE: &str = "%stack (a, b, c, d, e, f, g, h) -> (b, c, d, e, f, g, h, a)";

    let mut rotl = c.benchmark_group("rotl");
    rotl.sample_size(10);
    let id = BenchmarkId::from_parameter(8);
    rotl.bench_function(id, |bencher| bencher.iter(|| assemble(DIRECTIVE)));
}
/// Benchmark group `rotr`: assembles a `%stack` directive that moves the last of eight
/// named items to the first position.
fn rotr_group(c: &mut Criterion) {
    const DIRECTIVE: &str = "%stack (a, b, c, d, e, f, g, h) -> (h, a, b, c, d, e, f, g)";

    let mut rotr = c.benchmark_group("rotr");
    rotr.sample_size(10);
    let id = BenchmarkId::from_parameter(8);
    rotr.bench_function(id, |bencher| bencher.iter(|| assemble(DIRECTIVE)));
}
/// Benchmark group `insert`: assembles a `%stack` directive that inserts the literal `123`
/// into the middle of eight named items.
fn insert_group(c: &mut Criterion) {
    const DIRECTIVE: &str = "%stack (a, b, c, d, e, f, g, h) -> (a, b, c, d, 123, e, f, g, h)";

    let mut insert = c.benchmark_group("insert");
    insert.sample_size(10);
    let id = BenchmarkId::from_parameter(8);
    insert.bench_function(id, |bencher| bencher.iter(|| assemble(DIRECTIVE)));
}
/// Benchmark group `delete`: assembles a `%stack` directive that drops one item (`d`) from
/// eight named items.
fn delete_group(c: &mut Criterion) {
    const DIRECTIVE: &str = "%stack (a, b, c, d, e, f, g, h) -> (a, b, c, e, f, g, h)";

    let mut delete = c.benchmark_group("delete");
    delete.sample_size(10);
    let id = BenchmarkId::from_parameter(8);
    delete.bench_function(id, |bencher| bencher.iter(|| assemble(DIRECTIVE)));
}
/// Benchmark group `replace`: assembles a `%stack` directive that replaces one named item
/// (`d`) with the literal `5`.
fn replace_group(c: &mut Criterion) {
    const DIRECTIVE: &str = "%stack (a, b, c, d, e, f, g, h) -> (a, b, c, 5, e, f, g, h)";

    let mut replace = c.benchmark_group("replace");
    replace.sample_size(10);
    let id = BenchmarkId::from_parameter(8);
    replace.bench_function(id, |bencher| bencher.iter(|| assemble(DIRECTIVE)));
}
/// Benchmark group `shuffle`: assembles a `%stack` directive that reorders eight distinct
/// named items.
fn shuffle_group(c: &mut Criterion) {
    const DIRECTIVE: &str = "%stack (a, b, c, d, e, f, g, h) -> (g, d, h, a, f, e, b, c)";

    let mut shuffle = c.benchmark_group("shuffle");
    shuffle.sample_size(10);
    let id = BenchmarkId::from_parameter(8);
    shuffle.bench_function(id, |bencher| bencher.iter(|| assemble(DIRECTIVE)));
}
/// Benchmark group `misc`: assembles a `%stack` directive mixing a repeated source name,
/// literals, dropped items and duplicated items in the destination.
fn misc_group(c: &mut Criterion) {
    const DIRECTIVE: &str = "%stack (a, b, c, a, e, f, g, h) -> (g, 1, h, g, f, 3, b, b)";

    let mut misc = c.benchmark_group("misc");
    misc.sample_size(10);
    let id = BenchmarkId::from_parameter(8);
    misc.bench_function(id, |bencher| bencher.iter(|| assemble(DIRECTIVE)));
}
// Register `criterion_benchmark` as the sole target of the `benches` group, then let
// Criterion generate the benchmark entry point from that group.
criterion_group!(benches, criterion_benchmark);
criterion_main!(benches);
/// Assembles `code` as a single input file and discards whatever the assembler returns;
/// only the work of assembling is of interest to the benchmarks.
fn assemble(code: &str) {
    let _ = assemble_to_bytes(&[code.into()]);
}

View File

@ -10,7 +10,7 @@ use crate::cpu::kernel::ast::StackReplacement;
use crate::cpu::kernel::keccak_util::hash_kernel;
use crate::cpu::kernel::optimizer::optimize_asm;
use crate::cpu::kernel::prover_input::ProverInputFn;
use crate::cpu::kernel::stack_manipulation::expand_stack_manipulation;
use crate::cpu::kernel::stack::stack_manipulation::expand_stack_manipulation;
use crate::cpu::kernel::utils::u256_to_trimmed_be_bytes;
use crate::cpu::kernel::{
ast::{File, Item},

View File

@ -10,7 +10,7 @@ mod opcodes;
mod optimizer;
mod parser;
pub mod prover_input;
mod stack_manipulation;
pub mod stack;
mod txn_fields;
mod utils;

View File

@ -0,0 +1,2 @@
mod permutations;
pub mod stack_manipulation;

View File

@ -0,0 +1,278 @@
//! This module contains logic for finding the optimal sequence of swaps to get from one stack state
//! to another, specifically for the case where the source and destination states are permutations
//! of one another.
//!
//! We solve the problem in three steps:
//! 1. Find a permutation `P` such that `P A = B`.
//! 2. If `A` contains duplicates, optimize `P` by reducing the number of cycles.
//! 3. Convert each cycle into a sequence of `(0 i)` transpositions (their order matters), each
//! of which corresponds to a swap instruction in the EVM.
//!
//! We typically represent a permutation as a sequence of cycles. For example, the permutation
//! `(1 2 3)(1 2)(4 5)` acts as:
//!
//! ```ignore
//! (1 2 3)(1 2)(4 5)[A_0, A_1, A_2, A_3, A_4, A_5] = (1 2 3)(1 2)[A_0, A_1, A_2, A_3, A_5, A_4]
//! = (1 2 3)[A_0, A_2, A_1, A_3, A_5, A_4]
//! = [A_0, A_3, A_2, A_1, A_5, A_4]
//! ```
//!
//! We typically represent a `(0 i)` transposition as a single scalar `i`.
use std::collections::{HashMap, HashSet};
use std::hash::Hash;
use crate::cpu::kernel::stack::stack_manipulation::{StackItem, StackOp};
/// Computes the swap operations that turn the stack `src` into the stack `dst`.
///
/// Both slices store the stack with its tip at the end. `src` and `dst` must be permutations
/// of one another; `find_permutation` asserts this.
pub(crate) fn get_stack_ops_for_perm(src: &[StackItem], dst: &[StackItem]) -> Vec<StackOp> {
    // The permutation helpers index from the tip (position 0), so that every SWAP is a
    // `(0 i)` transposition. Flip both stacks into that convention first.
    let src_tip_first: Vec<StackItem> = src.iter().rev().cloned().collect();
    let dst_tip_first: Vec<StackItem> = dst.iter().rev().cloned().collect();

    let cycles = combine_cycles(
        find_permutation(&src_tip_first, &dst_tip_first),
        &src_tip_first,
    );
    transpositions_to_stack_ops(permutation_to_transpositions(cycles))
}
/// Applies `permutation` to `lst` and returns the rearranged list.
///
/// `permutation` is a product of cycles and is applied right to left, i.e. the last cycle in
/// the vector acts first. A cycle `[c_0, c_1, ..., c_k]` moves the element at index `c_j` to
/// index `c_{j+1}`, and the element at `c_k` to `c_0`. An empty cycle is the identity.
///
/// # Panics
/// Panics if a cycle contains an index that is out of bounds for `lst`.
#[cfg(test)]
fn apply_perm<T: Eq + Hash + Clone>(permutation: Vec<Vec<usize>>, mut lst: Vec<T>) -> Vec<T> {
    // Run through the cycles in REVERSE order.
    for cycl in permutation.iter().rev() {
        // An empty cycle moves nothing; skipping it also avoids underflow in `len() - 1`.
        if cycl.is_empty() {
            continue;
        }
        // Save the element that wraps around before it gets overwritten.
        let wrapped = lst[cycl[cycl.len() - 1]].clone();
        // Shift back to front so each source slot is read before it is written.
        for pair in cycl.windows(2).rev() {
            lst[pair[1]] = lst[pair[0]].clone();
        }
        lst[cycl[0]] = wrapped;
    }
    lst
}
/// STEP 1: given two lists `A` and `B`, finds some permutation `P` with `P . A = B`.
///
/// The result is a list of disjoint cycles over indices. Within a cycle, the element at each
/// index has to move to the index that follows it (wrapping around at the end). If index 0 is
/// already in place, the trivial cycle `[0]` is emitted first.
///
/// # Panics
/// Panics if `lst_a` and `lst_b` are not rearrangements of each other.
pub fn find_permutation<T: Eq + Hash + Clone>(lst_a: &[T], lst_b: &[T]) -> Vec<Vec<usize>> {
    // Everything below relies on A and B holding the same multiset of values.
    assert!(is_permutation(lst_a, lst_b));

    // Indices that need no further attention: either A_i = B_i, or i already sits in a cycle.
    let mut settled: HashSet<usize> = HashSet::new();
    // open_slots[b] lists the indices i with b = B_i != A_i, i.e. the places that still need
    // some copy of `b` moved into them.
    let mut open_slots: HashMap<T, Vec<usize>> = HashMap::new();
    for (idx, (a, b)) in lst_a.iter().zip(lst_b.iter()).enumerate() {
        if a == b {
            // Already in the right place, so SWAP_idx is never needed for it.
            settled.insert(idx);
        } else {
            open_slots.entry(b.clone()).or_default().push(idx);
        }
    }

    let mut cycles: Vec<Vec<usize>> = Vec::new();
    // Keep [0] as a trivial cycle when index 0 is already correct: if A_0 equals some A_i
    // inside another cycle, STEP 2 can widen that cycle to include 0 and save transpositions.
    if settled.contains(&0) {
        cycles.push(vec![0]);
    }

    for start in 0..lst_a.len() {
        // `insert` returns false when `start` was already settled; otherwise it opens a cycle.
        if !settled.insert(start) {
            continue;
        }
        let mut cycle = vec![start];
        let mut at = start;
        loop {
            // A_at must land on an index whose B value equals A_at. Popping guarantees every
            // slot is handed out once. Any valid choice works here; STEP 2 optimizes later.
            let next = open_slots.get_mut(&lst_a[at]).unwrap().pop().unwrap();
            if next == start {
                // Back where we began: the cycle is closed.
                break;
            }
            settled.insert(next);
            cycle.push(next);
            at = next;
        }
        cycles.push(cycle);
    }
    cycles
}
/// STEP 2: merges cycles of `perm` that touch positions holding equal values in `lst_a`.
///
/// When `A_i = A_j` and `i`, `j` lie in different cycles, the two cycles can be spliced into
/// one that produces the same output with fewer transpositions. For every value occurring
/// more than once in `lst_a`, all cycles touching one of its positions are fused into a
/// single cycle, which is appended after the untouched cycles.
fn combine_cycles<T: Eq + Hash + Clone>(mut perm: Vec<Vec<usize>>, lst_a: &[T]) -> Vec<Vec<usize>> {
    // A single cycle leaves nothing to combine.
    if perm.len() == 1 {
        return perm;
    }

    // Map every value of A to all the indices where it occurs, to find the duplicates.
    let mut all_a_positions: HashMap<T, Vec<usize>> = HashMap::new();
    for (idx, item) in lst_a.iter().enumerate() {
        all_a_positions.entry(item.clone()).or_default().push(idx);
    }

    // Only values that occur more than once can link two cycles together.
    for positions in all_a_positions.values().filter(|p| p.len() != 1) {
        // The cycle being grown for this value, and where inside it the first hit sits.
        let mut merged: Vec<usize> = vec![];
        let mut anchor = 0;
        // Cycles that do not touch any position of this value.
        let mut untouched = vec![];

        for cycl in perm {
            // Index within `cycl` of the first entry of `positions` that the cycle contains.
            let hit = positions
                .iter()
                .find_map(|term| cycl.iter().position(|x| x == term));
            match hit {
                None => untouched.push(cycl),
                Some(at) if merged.is_empty() => {
                    // First cycle found for this value: it seeds the merged cycle.
                    anchor = at;
                    merged = cycl;
                }
                Some(at) => {
                    // Splice the two cycles at the equal-valued positions. If A_i = A_j then
                    // (C_1, ..., C_k1, i, C_{k1 + 1}, ... C_k2)(D_1, ..., D_k3, j, D_{k3 + 1}, ... D_k4)
                    // (C_1, ..., C_k1, i, D_{k3 + 1}, ... D_k4, D_1, ..., D_k3, j, C_{k1 + 1}, ... C_k2)
                    // lead to the same output, but the second needs fewer transpositions.
                    merged = [
                        &merged[..=anchor],
                        &cycl[at + 1..],
                        &cycl[..=at],
                        &merged[anchor + 1..],
                    ]
                    .concat();
                }
            }
        }

        if !merged.is_empty() {
            untouched.push(merged);
        }
        perm = untouched;
    }
    perm
}
/// STEP 3: rewrites every cycle as `(0 i)` transpositions, each represented by its `i`.
///
/// The identities used are
///   (0 C_1 ... C_i) = (0 C_i) ... (0 C_1)
///   (C_1 ... C_i)   = (0 C_1) (0 C_i) ... (0 C_1)
/// so a cycle through 0 costs one transposition fewer than its length, and any other cycle
/// costs one more than its length.
fn permutation_to_transpositions(perm: Vec<Vec<usize>>) -> Vec<usize> {
    let mut swaps: Vec<usize> = Vec::new();
    for cycle in &perm {
        match cycle.iter().position(|&idx| idx == 0) {
            Some(zero_at) => {
                // Walk backwards from just before 0, wrapping around, and stop before
                // reaching 0 again: first the entries preceding 0, then those following it.
                let (before_zero, from_zero) = cycle.split_at(zero_at);
                swaps.extend(before_zero.iter().rev().copied());
                swaps.extend(from_zero[1..].iter().rev().copied());
            }
            None => {
                // Bring C_1 to position 0 first, then undo the cycle back to front, which
                // ends on C_1 again.
                swaps.push(cycle[0]);
                swaps.extend(cycle.iter().rev().copied());
            }
        }
    }
    swaps
}
/// Expands each scalar `i` back into the explicit two-element cycle `[0, i]`, preserving
/// order, so that a transposition list can be fed to `apply_perm`.
#[cfg(test)]
fn trans_to_perm(trans: Vec<usize>) -> Vec<Vec<usize>> {
    let mut cycles = Vec::with_capacity(trans.len());
    for i in trans {
        cycles.push(vec![0, i]);
    }
    cycles
}
/// Converts `(0 i)` transpositions, given as their `i`, into `StackOp::Swap(i)` operations in
/// the same order.
///
/// # Panics
/// Panics if any `i` does not fit in a `u8`. A plain `as u8` cast would silently wrap such a
/// depth and emit a swap of the wrong stack slot.
fn transpositions_to_stack_ops(trans: Vec<usize>) -> Vec<StackOp> {
    trans
        .into_iter()
        .map(|i| {
            assert!(
                i <= usize::from(u8::MAX),
                "swap depth {} does not fit in a u8",
                i
            );
            // Lossless: the range was checked just above.
            StackOp::Swap(i as u8)
        })
        .collect()
}
/// Returns `true` when `a` and `b` contain the same values with the same multiplicities,
/// regardless of order.
pub fn is_permutation<T: Eq + Hash + Clone>(a: &[T], b: &[T]) -> bool {
    let counts_a = make_multiset(a);
    let counts_b = make_multiset(b);
    counts_a == counts_b
}
/// Counts how many times each distinct value occurs in `vals`.
fn make_multiset<T: Eq + Hash + Clone>(vals: &[T]) -> HashMap<T, usize> {
    vals.iter().fold(HashMap::new(), |mut tally, item| {
        *tally.entry(item.clone()).or_insert(0) += 1;
        tally
    })
}
#[cfg(test)]
mod tests {
    use rand::prelude::SliceRandom;
    use rand::thread_rng;

    use super::{
        apply_perm, combine_cycles, find_permutation, is_permutation,
        permutation_to_transpositions, trans_to_perm,
    };

    /// Two 2-cycles that both touch a position holding `'a'` fuse into a single 4-cycle.
    #[test]
    fn test_combine_cycles() {
        let cycles = vec![vec![0, 2], vec![3, 4]];
        let values = ['a', 'b', 'c', 'd', 'a'];
        assert_eq!(combine_cycles(cycles, &values), vec![vec![0, 3, 4, 2]]);
    }

    /// Reordering is accepted; differing lengths or differing values are rejected.
    #[test]
    fn test_is_permutation() {
        let abc = ['a', 'b', 'c'];
        assert!(is_permutation(&abc, &['b', 'c', 'a']));
        assert!(!is_permutation(&abc, &['a', 'b', 'b', 'c']));
        assert!(!is_permutation(&abc, &['a', 'd', 'c']));
    }

    /// Randomized end-to-end check: every stage of the pipeline must map `lst_a` onto
    /// `lst_b`, and combining cycles must never increase the number of transpositions.
    #[test]
    fn test_all() {
        let mut rng = thread_rng();
        // A pool with plenty of duplicates, so that cycle merging is actually exercised.
        let mut pool: Vec<char> = "aaaabbbcccddefghk".chars().collect();
        pool.shuffle(&mut rng);

        for _ in 0..1000 {
            let lst_a = pool.clone();
            pool.shuffle(&mut rng);
            let lst_b = pool.clone();

            let perm = find_permutation(&lst_a, &lst_b);
            assert_eq!(apply_perm(perm.clone(), lst_a.clone()), lst_b);

            let shortperm = combine_cycles(perm.clone(), &lst_a);
            assert_eq!(apply_perm(shortperm.clone(), lst_a.clone()), lst_b);

            let trans = trans_to_perm(permutation_to_transpositions(perm));
            assert_eq!(apply_perm(trans.clone(), lst_a.clone()), lst_b);

            let shorttrans = trans_to_perm(permutation_to_transpositions(shortperm));
            assert_eq!(apply_perm(shorttrans.clone(), lst_a.clone()), lst_b);

            assert!(shorttrans.len() <= trans.len());
        }
    }
}

View File

@ -1,13 +1,15 @@
use std::cmp::Ordering;
use std::collections::hash_map::Entry::{Occupied, Vacant};
use std::collections::{BinaryHeap, HashMap};
use std::hash::Hash;
use itertools::Itertools;
use crate::cpu::columns::NUM_CPU_COLUMNS;
use crate::cpu::kernel::assembler::BYTES_PER_OFFSET;
use crate::cpu::kernel::ast::{Item, PushTarget, StackReplacement};
use crate::cpu::kernel::stack_manipulation::StackOp::Pop;
use crate::cpu::kernel::stack::permutations::{get_stack_ops_for_perm, is_permutation};
use crate::cpu::kernel::stack::stack_manipulation::StackOp::Pop;
use crate::cpu::kernel::utils::u256_to_trimmed_be_bytes;
use crate::memory;
@ -164,13 +166,13 @@ impl Ord for Node {
/// Like `StackReplacement`, but without constants or macro vars, since those were expanded already.
#[derive(Eq, PartialEq, Hash, Clone, Debug)]
enum StackItem {
pub(crate) enum StackItem {
NamedItem(String),
PushTarget(PushTarget),
}
#[derive(Clone, Debug)]
enum StackOp {
pub(crate) enum StackOp {
Push(PushTarget),
Pop,
Dup(u8),
@ -188,6 +190,11 @@ fn next_ops(
return vec![StackOp::Pop]
}
if is_permutation(src, dst) {
// The transpositions are right-associative, so the last one gets applied first, hence pop.
return vec![get_stack_ops_for_perm(src, dst).pop().unwrap()];
}
let mut ops = vec![StackOp::Pop];
ops.extend(
@ -220,11 +227,31 @@ fn next_ops(
.map(StackOp::Dup),
);
ops.extend((1..src_len).map(StackOp::Swap));
ops.extend(
(1..src_len)
.filter(|i| should_try_swap(src, dst, *i))
.map(StackOp::Swap),
);
ops
}
/// Whether we should consider `SWAP_i` in the search.
///
/// `src` and `dst` store the stack with its tip at the end. `SWAP_i` exchanges the tip with
/// the item `i` slots below it. The swap is only worth trying if it puts at least one of those
/// two items into the slot where `dst` has an equal item.
///
/// Returns `false` when `src` is empty or when `i` reaches below the bottom of `src`, since
/// no such swap exists.
fn should_try_swap(src: &[StackItem], dst: &[StackItem], i: u8) -> bool {
    // Index of the tip; there is none for an empty stack.
    let i_from = match src.len().checked_sub(1) {
        Some(tip) => tip,
        None => return false,
    };
    // Index of the swap partner; checked so an oversized `i` cannot underflow.
    let i_to = match i_from.checked_sub(usize::from(i)) {
        Some(partner) => partner,
        None => return false,
    };

    // Only consider a swap if it places one of the two affected elements in the desired
    // position. `get` doubles as the bounds check against a shorter `dst`.
    let top_correct_pos = dst.get(i_to).map_or(false, |want| src[i_from] == *want);
    let other_correct_pos = dst.get(i_from).map_or(false, |want| src[i_to] == *want);
    top_correct_pos || other_correct_pos
}
impl StackOp {
fn cost(&self) -> u32 {
let (cpu_rows, memory_rows) = match self {
@ -287,3 +314,39 @@ impl StackOp {
}
}
}
#[cfg(test)]
mod tests {
    use env_logger::{try_init_from_env, Env, DEFAULT_FILTER_ENV};

    use super::StackItem::NamedItem;
    use super::{shortest_path, StackItem};

    /// Smoke test: the destination drops an item and reorders two others.
    #[test]
    fn test_shortest_path() {
        init_logger();
        let src = stack(&["ret", "a", "b", "d"]);
        let dst = stack(&["ret", "b", "a"]);
        shortest_path(src, dst, vec![]);
    }

    /// Smoke test: the destination is a pure permutation of the source.
    #[test]
    fn test_shortest_path_permutation() {
        init_logger();
        let src = stack(&["a", "b", "c"]);
        let dst = stack(&["c", "a", "b"]);
        shortest_path(src, dst, vec![]);
    }

    /// Builds a stack of named items, in the order given.
    fn stack(names: &[&str]) -> Vec<StackItem> {
        names.iter().map(|name| named(name)).collect()
    }

    /// Wraps `name` in a `NamedItem`.
    fn named(name: &str) -> StackItem {
        NamedItem(name.to_owned())
    }

    /// Initializes `env_logger`, falling back to the `debug` filter when the environment sets
    /// none. The result is ignored, so repeated calls across tests are harmless.
    fn init_logger() {
        let env = Env::default().filter_or(DEFAULT_FILTER_ENV, "debug");
        let _ = try_init_from_env(env);
    }
}