mirror of
https://github.com/logos-storage/plonky2.git
synced 2026-01-08 16:53:07 +00:00
Merge pull request #662 from mir-protocol/stack_pruning_opt_perms
For permutations, find the optimal sequence of swaps
This commit is contained in:
commit
1763b6bc37
@ -25,6 +25,7 @@ keccak-rust = { git = "https://github.com/npwardberkeley/keccak-rust" }
|
||||
keccak-hash = "0.9.0"
|
||||
|
||||
[dev-dependencies]
|
||||
criterion = "0.3.5"
|
||||
hex = "0.4.3"
|
||||
|
||||
[features]
|
||||
@ -35,3 +36,7 @@ parallel = ["maybe_rayon/parallel"]
|
||||
[[bin]]
|
||||
name = "assemble"
|
||||
required-features = ["asmtools"]
|
||||
|
||||
[[bench]]
|
||||
name = "stack_manipulation"
|
||||
harness = false
|
||||
|
||||
75
evm/benches/stack_manipulation.rs
Normal file
75
evm/benches/stack_manipulation.rs
Normal file
@ -0,0 +1,75 @@
|
||||
use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion};
|
||||
use plonky2_evm::cpu::kernel::assemble_to_bytes;
|
||||
|
||||
fn criterion_benchmark(c: &mut Criterion) {
|
||||
rotl_group(c);
|
||||
rotr_group(c);
|
||||
insert_group(c);
|
||||
delete_group(c);
|
||||
replace_group(c);
|
||||
shuffle_group(c);
|
||||
misc_group(c);
|
||||
}
|
||||
|
||||
fn rotl_group(c: &mut Criterion) {
|
||||
let mut group = c.benchmark_group("rotl");
|
||||
group.sample_size(10);
|
||||
group.bench_function(BenchmarkId::from_parameter(8), |b| {
|
||||
b.iter(|| assemble("%stack (a, b, c, d, e, f, g, h) -> (b, c, d, e, f, g, h, a)"))
|
||||
});
|
||||
}
|
||||
|
||||
fn rotr_group(c: &mut Criterion) {
|
||||
let mut group = c.benchmark_group("rotr");
|
||||
group.sample_size(10);
|
||||
group.bench_function(BenchmarkId::from_parameter(8), |b| {
|
||||
b.iter(|| assemble("%stack (a, b, c, d, e, f, g, h) -> (h, a, b, c, d, e, f, g)"))
|
||||
});
|
||||
}
|
||||
|
||||
fn insert_group(c: &mut Criterion) {
|
||||
let mut group = c.benchmark_group("insert");
|
||||
group.sample_size(10);
|
||||
group.bench_function(BenchmarkId::from_parameter(8), |b| {
|
||||
b.iter(|| assemble("%stack (a, b, c, d, e, f, g, h) -> (a, b, c, d, 123, e, f, g, h)"))
|
||||
});
|
||||
}
|
||||
|
||||
fn delete_group(c: &mut Criterion) {
|
||||
let mut group = c.benchmark_group("delete");
|
||||
group.sample_size(10);
|
||||
group.bench_function(BenchmarkId::from_parameter(8), |b| {
|
||||
b.iter(|| assemble("%stack (a, b, c, d, e, f, g, h) -> (a, b, c, e, f, g, h)"))
|
||||
});
|
||||
}
|
||||
|
||||
fn replace_group(c: &mut Criterion) {
|
||||
let mut group = c.benchmark_group("replace");
|
||||
group.sample_size(10);
|
||||
group.bench_function(BenchmarkId::from_parameter(8), |b| {
|
||||
b.iter(|| assemble("%stack (a, b, c, d, e, f, g, h) -> (a, b, c, 5, e, f, g, h)"))
|
||||
});
|
||||
}
|
||||
|
||||
fn shuffle_group(c: &mut Criterion) {
|
||||
let mut group = c.benchmark_group("shuffle");
|
||||
group.sample_size(10);
|
||||
group.bench_function(BenchmarkId::from_parameter(8), |b| {
|
||||
b.iter(|| assemble("%stack (a, b, c, d, e, f, g, h) -> (g, d, h, a, f, e, b, c)"))
|
||||
});
|
||||
}
|
||||
|
||||
fn misc_group(c: &mut Criterion) {
|
||||
let mut group = c.benchmark_group("misc");
|
||||
group.sample_size(10);
|
||||
group.bench_function(BenchmarkId::from_parameter(8), |b| {
|
||||
b.iter(|| assemble("%stack (a, b, c, a, e, f, g, h) -> (g, 1, h, g, f, 3, b, b)"))
|
||||
});
|
||||
}
|
||||
|
||||
criterion_group!(benches, criterion_benchmark);
|
||||
criterion_main!(benches);
|
||||
|
||||
fn assemble(code: &str) {
|
||||
assemble_to_bytes(&[code.into()]);
|
||||
}
|
||||
@ -10,7 +10,7 @@ use crate::cpu::kernel::ast::StackReplacement;
|
||||
use crate::cpu::kernel::keccak_util::hash_kernel;
|
||||
use crate::cpu::kernel::optimizer::optimize_asm;
|
||||
use crate::cpu::kernel::prover_input::ProverInputFn;
|
||||
use crate::cpu::kernel::stack_manipulation::expand_stack_manipulation;
|
||||
use crate::cpu::kernel::stack::stack_manipulation::expand_stack_manipulation;
|
||||
use crate::cpu::kernel::utils::u256_to_trimmed_be_bytes;
|
||||
use crate::cpu::kernel::{
|
||||
ast::{File, Item},
|
||||
|
||||
@ -10,7 +10,7 @@ mod opcodes;
|
||||
mod optimizer;
|
||||
mod parser;
|
||||
pub mod prover_input;
|
||||
mod stack_manipulation;
|
||||
pub mod stack;
|
||||
mod txn_fields;
|
||||
mod utils;
|
||||
|
||||
|
||||
2
evm/src/cpu/kernel/stack/mod.rs
Normal file
2
evm/src/cpu/kernel/stack/mod.rs
Normal file
@ -0,0 +1,2 @@
|
||||
mod permutations;
|
||||
pub mod stack_manipulation;
|
||||
278
evm/src/cpu/kernel/stack/permutations.rs
Normal file
278
evm/src/cpu/kernel/stack/permutations.rs
Normal file
@ -0,0 +1,278 @@
|
||||
//! This module contains logic for finding the optimal sequence of swaps to get from one stack state
|
||||
//! to another, specifically for the case where the source and destination states are permutations
|
||||
//! of one another.
|
||||
//!
|
||||
//! We solve the problem in three steps:
|
||||
//! 1. Find a permutation `P` such that `P A = B`.
|
||||
//! 2. If `A` contains duplicates, optimize `P` by reducing the number of cycles.
|
||||
//! 3. Convert each cycle into a set of `(0 i)` transpositions, which correspond to swap
|
||||
//! instructions in the EVM.
|
||||
//!
|
||||
//! We typically represent a permutation as a sequence of cycles. For example, the permutation
|
||||
//! `(1 2 3)(1 2)(4 5)` acts as:
|
||||
//!
|
||||
//! ```ignore
|
||||
//! (1 2 3)(1 2)(4 5)[A_0, A_1, A_2, A_3, A_4, A_5] = (1 2 3)(1 2)[A_0, A_1, A_2, A_3, A_5, A_4]
|
||||
//! = (1 2 3)[A_0, A_2, A_1, A_3, A_5, A_4]
|
||||
//! = [A_0, A_3, A_2, A_1, A_5, A_4]
|
||||
//! ```
|
||||
//!
|
||||
//! We typically represent a `(0 i)` transposition as a single scalar `i`.
|
||||
|
||||
use std::collections::{HashMap, HashSet};
|
||||
use std::hash::Hash;
|
||||
|
||||
use crate::cpu::kernel::stack::stack_manipulation::{StackItem, StackOp};
|
||||
|
||||
/// Find the optimal sequence of stack operations to get from `src` to `dst`. Assumes that `src` and
|
||||
/// `dst` are permutations of one another.
|
||||
pub(crate) fn get_stack_ops_for_perm(src: &[StackItem], dst: &[StackItem]) -> Vec<StackOp> {
|
||||
// We store stacks with the tip at the end, but the permutation calls below use the opposite
|
||||
// convention. They're a bit simpler when SWAP are (0 i) transposes.
|
||||
let mut src = src.to_vec();
|
||||
let mut dst = dst.to_vec();
|
||||
src.reverse();
|
||||
dst.reverse();
|
||||
|
||||
let perm = find_permutation(&src, &dst);
|
||||
let optimized_perm = combine_cycles(perm, &src);
|
||||
let trans = permutation_to_transpositions(optimized_perm);
|
||||
transpositions_to_stack_ops(trans)
|
||||
}
|
||||
|
||||
/// Apply the given permutation to the given list.
|
||||
#[cfg(test)]
|
||||
fn apply_perm<T: Eq + Hash + Clone>(permutation: Vec<Vec<usize>>, mut lst: Vec<T>) -> Vec<T> {
|
||||
// Run through perm in REVERSE order.
|
||||
for cycl in permutation.iter().rev() {
|
||||
let n = cycl.len();
|
||||
let last = lst[cycl[n - 1]].clone();
|
||||
for i in (0..n - 1).rev() {
|
||||
let j = (i + 1) % n;
|
||||
lst[cycl[j]] = lst[cycl[i]].clone();
|
||||
}
|
||||
lst[cycl[0]] = last;
|
||||
}
|
||||
lst
|
||||
}
|
||||
|
||||
/// This function does STEP 1.
|
||||
/// Given 2 lists A, B find a permutation P such that P . A = B.
|
||||
pub fn find_permutation<T: Eq + Hash + Clone>(lst_a: &[T], lst_b: &[T]) -> Vec<Vec<usize>> {
|
||||
// We should check to ensure that A and B are indeed rearrangments of each other.
|
||||
assert!(is_permutation(lst_a, lst_b));
|
||||
|
||||
let n = lst_a.len();
|
||||
|
||||
// Keep track of the A_i's which have been already placed into the correct position.
|
||||
let mut correct_a = HashSet::new();
|
||||
|
||||
// loc_b is a dictionary where loc_b[b] is the indices i where b = B_i != A_i.
|
||||
// We need to swap appropriate A_j's into these positions.
|
||||
let mut loc_b: HashMap<T, Vec<usize>> = HashMap::new();
|
||||
|
||||
for i in 0..n {
|
||||
if lst_a[i] == lst_b[i] {
|
||||
// If A_i = B_i, we never do SWAP_i as we are already in the correct position.
|
||||
correct_a.insert(i);
|
||||
} else {
|
||||
loc_b.entry(lst_b[i].clone()).or_default().push(i);
|
||||
}
|
||||
}
|
||||
|
||||
// This will be a list of disjoint cycles.
|
||||
let mut permutation = vec![];
|
||||
|
||||
// For technical reasons, it's handy to include [0] as a trivial cycle.
|
||||
// This is because if A_0 = A_i for some other i in a cycle,
|
||||
// we can save transpositions by expanding the cycle to include 0.
|
||||
if correct_a.contains(&0) {
|
||||
permutation.push(vec![0]);
|
||||
}
|
||||
|
||||
for i in 0..n {
|
||||
// If i is both not in the correct position and not already in a cycle, it will start a new cycle.
|
||||
if correct_a.contains(&i) {
|
||||
continue;
|
||||
}
|
||||
|
||||
correct_a.insert(i);
|
||||
let mut cycl = vec![i];
|
||||
|
||||
// lst_a[i] need to be swapped into an index j such that lst_b[j] = lst_a[i].
|
||||
// This exactly means j should be an element of loc_b[lst_a[i]].
|
||||
// We pop as each j should only be used once.
|
||||
// In this step we simply find any permutation. We will improve it to an optimal one in STEP 2.
|
||||
let mut j = loc_b.get_mut(&lst_a[i]).unwrap().pop().unwrap();
|
||||
|
||||
// Keep adding elements to the cycle until we return to our initial index
|
||||
while j != i {
|
||||
correct_a.insert(j);
|
||||
cycl.push(j);
|
||||
j = loc_b.get_mut(&lst_a[j]).unwrap().pop().unwrap();
|
||||
}
|
||||
|
||||
permutation.push(cycl);
|
||||
}
|
||||
permutation
|
||||
}
|
||||
|
||||
/// This function does STEP 2. It tests to see if cycles can be combined which might occur if A has duplicates.
|
||||
fn combine_cycles<T: Eq + Hash + Clone>(mut perm: Vec<Vec<usize>>, lst_a: &[T]) -> Vec<Vec<usize>> {
|
||||
// If perm is a single cycle, there is nothing to combine.
|
||||
if perm.len() == 1 {
|
||||
return perm;
|
||||
}
|
||||
|
||||
let n = lst_a.len();
|
||||
|
||||
// Need a dictionary to keep track of duplicates in lst_a.
|
||||
let mut all_a_positions: HashMap<T, Vec<usize>> = HashMap::new();
|
||||
for i in 0..n {
|
||||
all_a_positions.entry(lst_a[i].clone()).or_default().push(i);
|
||||
}
|
||||
|
||||
// For each element a which occurs at positions i1, ..., ij, combine cycles such that all
|
||||
// ik which occur in a cycle occur in the same cycle.
|
||||
for positions in all_a_positions.values() {
|
||||
if positions.len() == 1 {
|
||||
continue;
|
||||
}
|
||||
|
||||
let mut joinedperm = vec![];
|
||||
let mut newperm = vec![];
|
||||
let mut pos = 0;
|
||||
for cycl in perm {
|
||||
// Does cycl include an element of positions?
|
||||
let mut disjoint = true;
|
||||
|
||||
for term in positions {
|
||||
if cycl.contains(term) {
|
||||
if joinedperm.is_empty() {
|
||||
// This is the first cycle we have found including an element of positions.
|
||||
joinedperm = cycl.clone();
|
||||
pos = cycl.iter().position(|x| x == term).unwrap();
|
||||
} else {
|
||||
// Need to merge 2 cycles. If A_i = A_j then the permutations
|
||||
// (C_1, ..., C_k1, i, C_{k1 + 1}, ... C_k2)(D_1, ..., D_k3, j, D_{k3 + 1}, ... D_k4)
|
||||
// (C_1, ..., C_k1, i, D_{k3 + 1}, ... D_k4, D_1, ..., D_k3, j, C_{k1 + 1}, ... C_k2)
|
||||
// lead to the same oupput but the second will require less transpositions.
|
||||
let newpos = cycl.iter().position(|x| x == term).unwrap();
|
||||
joinedperm = [
|
||||
&joinedperm[..pos + 1],
|
||||
&cycl[newpos + 1..],
|
||||
&cycl[..newpos + 1],
|
||||
&joinedperm[pos + 1..],
|
||||
]
|
||||
.concat();
|
||||
}
|
||||
disjoint = false;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if disjoint {
|
||||
newperm.push(cycl);
|
||||
}
|
||||
}
|
||||
if !joinedperm.is_empty() {
|
||||
newperm.push(joinedperm);
|
||||
}
|
||||
perm = newperm;
|
||||
}
|
||||
perm
|
||||
}
|
||||
|
||||
// This function does STEP 3. Converting all cycles to [0, i] transpositions.
|
||||
fn permutation_to_transpositions(perm: Vec<Vec<usize>>) -> Vec<usize> {
|
||||
let mut trans = vec![];
|
||||
// The method is pretty simple, we have:
|
||||
// (0 C_1 ... C_i) = (0 C_i) ... (0 C_1)
|
||||
// (C_1 ... C_i) = (0 C_1) (0 C_i) ... (0\ C_1).
|
||||
// We simply need to check to see if 0 is in our cycle to see which one to use.
|
||||
for cycl in perm {
|
||||
let n = cycl.len();
|
||||
let zero_pos = cycl.iter().position(|x| *x == 0);
|
||||
if let Some(pos) = zero_pos {
|
||||
trans.extend((1..n).map(|i| cycl[(n + pos - i) % n]));
|
||||
} else {
|
||||
trans.extend((0..=n).map(|i| cycl[(n - i) % n]));
|
||||
}
|
||||
}
|
||||
trans
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
fn trans_to_perm(trans: Vec<usize>) -> Vec<Vec<usize>> {
|
||||
trans.into_iter().map(|i| vec![0, i]).collect()
|
||||
}
|
||||
|
||||
fn transpositions_to_stack_ops(trans: Vec<usize>) -> Vec<StackOp> {
|
||||
trans.into_iter().map(|i| StackOp::Swap(i as u8)).collect()
|
||||
}
|
||||
|
||||
pub fn is_permutation<T: Eq + Hash + Clone>(a: &[T], b: &[T]) -> bool {
|
||||
make_multiset(a) == make_multiset(b)
|
||||
}
|
||||
|
||||
fn make_multiset<T: Eq + Hash + Clone>(vals: &[T]) -> HashMap<T, usize> {
|
||||
let mut counts = HashMap::new();
|
||||
for val in vals {
|
||||
*counts.entry(val.clone()).or_default() += 1;
|
||||
}
|
||||
counts
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use rand::prelude::SliceRandom;
|
||||
use rand::thread_rng;
|
||||
|
||||
use crate::cpu::kernel::stack::permutations::{
|
||||
apply_perm, combine_cycles, find_permutation, is_permutation,
|
||||
permutation_to_transpositions, trans_to_perm,
|
||||
};
|
||||
|
||||
#[test]
|
||||
fn test_combine_cycles() {
|
||||
assert_eq!(
|
||||
combine_cycles(vec![vec![0, 2], vec![3, 4]], &['a', 'b', 'c', 'd', 'a']),
|
||||
vec![vec![0, 3, 4, 2]]
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_is_permutation() {
|
||||
assert!(is_permutation(&['a', 'b', 'c'], &['b', 'c', 'a']));
|
||||
assert!(!is_permutation(&['a', 'b', 'c'], &['a', 'b', 'b', 'c']));
|
||||
assert!(!is_permutation(&['a', 'b', 'c'], &['a', 'd', 'c']));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_all() {
|
||||
let mut test_lst = vec![
|
||||
'a', 'a', 'a', 'a', 'b', 'b', 'b', 'c', 'c', 'c', 'd', 'd', 'e', 'f', 'g', 'h', 'k',
|
||||
];
|
||||
|
||||
let mut rng = thread_rng();
|
||||
test_lst.shuffle(&mut rng);
|
||||
for _ in 0..1000 {
|
||||
let lst_a = test_lst.clone();
|
||||
test_lst.shuffle(&mut rng);
|
||||
let lst_b = test_lst.clone();
|
||||
|
||||
let perm = find_permutation(&lst_a, &lst_b);
|
||||
assert_eq!(apply_perm(perm.clone(), lst_a.clone()), lst_b);
|
||||
|
||||
let shortperm = combine_cycles(perm.clone(), &lst_a);
|
||||
assert_eq!(apply_perm(shortperm.clone(), lst_a.clone()), lst_b);
|
||||
|
||||
let trans = trans_to_perm(permutation_to_transpositions(perm));
|
||||
assert_eq!(apply_perm(trans.clone(), lst_a.clone()), lst_b);
|
||||
|
||||
let shorttrans = trans_to_perm(permutation_to_transpositions(shortperm));
|
||||
assert_eq!(apply_perm(shorttrans.clone(), lst_a.clone()), lst_b);
|
||||
|
||||
assert!(shorttrans.len() <= trans.len());
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -1,13 +1,15 @@
|
||||
use std::cmp::Ordering;
|
||||
use std::collections::hash_map::Entry::{Occupied, Vacant};
|
||||
use std::collections::{BinaryHeap, HashMap};
|
||||
use std::hash::Hash;
|
||||
|
||||
use itertools::Itertools;
|
||||
|
||||
use crate::cpu::columns::NUM_CPU_COLUMNS;
|
||||
use crate::cpu::kernel::assembler::BYTES_PER_OFFSET;
|
||||
use crate::cpu::kernel::ast::{Item, PushTarget, StackReplacement};
|
||||
use crate::cpu::kernel::stack_manipulation::StackOp::Pop;
|
||||
use crate::cpu::kernel::stack::permutations::{get_stack_ops_for_perm, is_permutation};
|
||||
use crate::cpu::kernel::stack::stack_manipulation::StackOp::Pop;
|
||||
use crate::cpu::kernel::utils::u256_to_trimmed_be_bytes;
|
||||
use crate::memory;
|
||||
|
||||
@ -164,13 +166,13 @@ impl Ord for Node {
|
||||
|
||||
/// Like `StackReplacement`, but without constants or macro vars, since those were expanded already.
|
||||
#[derive(Eq, PartialEq, Hash, Clone, Debug)]
|
||||
enum StackItem {
|
||||
pub(crate) enum StackItem {
|
||||
NamedItem(String),
|
||||
PushTarget(PushTarget),
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
enum StackOp {
|
||||
pub(crate) enum StackOp {
|
||||
Push(PushTarget),
|
||||
Pop,
|
||||
Dup(u8),
|
||||
@ -188,6 +190,11 @@ fn next_ops(
|
||||
return vec![StackOp::Pop]
|
||||
}
|
||||
|
||||
if is_permutation(src, dst) {
|
||||
// The transpositions are right-associative, so the last one gets applied first, hence pop.
|
||||
return vec![get_stack_ops_for_perm(src, dst).pop().unwrap()];
|
||||
}
|
||||
|
||||
let mut ops = vec![StackOp::Pop];
|
||||
|
||||
ops.extend(
|
||||
@ -220,11 +227,31 @@ fn next_ops(
|
||||
.map(StackOp::Dup),
|
||||
);
|
||||
|
||||
ops.extend((1..src_len).map(StackOp::Swap));
|
||||
ops.extend(
|
||||
(1..src_len)
|
||||
.filter(|i| should_try_swap(src, dst, *i))
|
||||
.map(StackOp::Swap),
|
||||
);
|
||||
|
||||
ops
|
||||
}
|
||||
|
||||
/// Whether we should consider `SWAP_i` in the search.
|
||||
fn should_try_swap(src: &[StackItem], dst: &[StackItem], i: u8) -> bool {
|
||||
if src.is_empty() {
|
||||
return false;
|
||||
}
|
||||
|
||||
let i = i as usize;
|
||||
let i_from = src.len() - 1;
|
||||
let i_to = i_from - i;
|
||||
|
||||
// Only consider a swap if it places one of the two affected elements in the desired position.
|
||||
let top_correct_pos = i_to < dst.len() && src[i_from] == dst[i_to];
|
||||
let other_correct_pos = i_from < dst.len() && src[i_to] == dst[i_from];
|
||||
top_correct_pos | other_correct_pos
|
||||
}
|
||||
|
||||
impl StackOp {
|
||||
fn cost(&self) -> u32 {
|
||||
let (cpu_rows, memory_rows) = match self {
|
||||
@ -287,3 +314,39 @@ impl StackOp {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use env_logger::{try_init_from_env, Env, DEFAULT_FILTER_ENV};
|
||||
|
||||
use crate::cpu::kernel::stack::stack_manipulation::StackItem::NamedItem;
|
||||
use crate::cpu::kernel::stack::stack_manipulation::{shortest_path, StackItem};
|
||||
|
||||
#[test]
|
||||
fn test_shortest_path() {
|
||||
init_logger();
|
||||
shortest_path(
|
||||
vec![named("ret"), named("a"), named("b"), named("d")],
|
||||
vec![named("ret"), named("b"), named("a")],
|
||||
vec![],
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_shortest_path_permutation() {
|
||||
init_logger();
|
||||
shortest_path(
|
||||
vec![named("a"), named("b"), named("c")],
|
||||
vec![named("c"), named("a"), named("b")],
|
||||
vec![],
|
||||
);
|
||||
}
|
||||
|
||||
fn named(name: &str) -> StackItem {
|
||||
NamedItem(name.into())
|
||||
}
|
||||
|
||||
fn init_logger() {
|
||||
let _ = try_init_from_env(Env::default().filter_or(DEFAULT_FILTER_ENV, "debug"));
|
||||
}
|
||||
}
|
||||
Loading…
x
Reference in New Issue
Block a user