Merge pull request #17 from mir-protocol/rescue_bench

Make Rescue a bit faster
This commit is contained in:
Daniel Lubarov 2021-04-24 10:53:52 -07:00 committed by GitHub
commit 7ffb9cf9b2
5 changed files with 145 additions and 100 deletions

View File

@ -14,8 +14,8 @@ const PROVER_POLYS: usize = 113 + 3 + 4;
fn main() {
const THREADS: usize = 12;
const LDE_BITS: i32 = 3;
const W: usize = 13;
const HASHES_PER_POLY: usize = 1 << (13 + LDE_BITS);
const W: usize = 12;
const HASHES_PER_POLY: usize = 1 << (13 + LDE_BITS) / 6;
let threads = (0..THREADS)
.map(|_i| {

46
src/bin/bench_rescue.rs Normal file
View File

@ -0,0 +1,46 @@
use std::thread;
use std::time::Instant;
use plonky2::field::crandall_field::CrandallField;
use plonky2::field::field::Field;
use plonky2::rescue::rescue;
type F = CrandallField;
// 113 wire polys, 3 Z polys, 4 parts of quotient poly.
const PROVER_POLYS: usize = 113 + 3 + 4;
fn main() {
const THREADS: usize = 12;
const LDE_BITS: i32 = 3;
const W: usize = 12;
const HASHES_PER_POLY: usize = (1 << (13 + LDE_BITS)) / 6;
let threads = (0..THREADS)
.map(|_i| {
thread::spawn(move || {
let mut x = [F::ZERO; W];
for i in 0..W {
x[i] = F::from_canonical_u64((i as u64) * 123456 + 789);
}
let hashes_per_thread = HASHES_PER_POLY * PROVER_POLYS / THREADS;
let start = Instant::now();
for _ in 0..hashes_per_thread {
x = rescue(x);
}
let duration = start.elapsed();
println!("took {:?}", duration);
println!(
"avg {:?}us",
duration.as_secs_f64() * 1e6 / (hashes_per_thread as f64)
);
println!("result {:?}", x);
})
})
.collect::<Vec<_>>();
for t in threads {
t.join().expect("oops");
}
}

View File

@ -145,6 +145,85 @@ impl Field for CrandallField {
fn from_canonical_u64(n: u64) -> Self {
Self(n)
}
fn cube_root(&self) -> Self {
let x0 = *self;
let x1 = x0.square();
let x2 = x1.square();
let x3 = x2 * x0;
let x4 = x3.square();
let x5 = x4.square();
// let x6 = x4.square();
let x7 = x5.square();
let x8 = x7.square();
let x9 = x8.square();
let x10 = x9.square();
let x11 = x10 * x5;
let x12 = x11.square();
let x13 = x12.square();
let x14 = x13.square();
// let x15 = x13.square();
let x16 = x14.square();
let x17 = x16.square();
let x18 = x17.square();
let x19 = x18.square();
let x20 = x19.square();
let x21 = x20 * x11;
let x22 = x21.square();
let x23 = x22.square();
let x24 = x23.square();
let x25 = x24.square();
let x26 = x25.square();
let x27 = x26.square();
let x28 = x27.square();
let x29 = x28.square();
let x30 = x29.square();
let x31 = x30.square();
let x32 = x31.square();
let x33 = x32 * x14;
let x34 = x33 * x3;
let x35 = x34.square();
let x36 = x35 * x34;
let x37 = x36 * x5;
let x38 = x37 * x34;
let x39 = x38 * x37;
let x40 = x39.square();
let x41 = x40.square();
let x42 = x41 * x38;
let x43 = x42.square();
let x44 = x43.square();
let x45 = x44.square();
let x46 = x45.square();
let x47 = x46.square();
let x48 = x47.square();
let x49 = x48.square();
let x50 = x49.square();
let x51 = x50.square();
let x52 = x51.square();
let x53 = x52.square();
let x54 = x53.square();
let x55 = x54.square();
let x56 = x55.square();
let x57 = x56.square();
let x58 = x57.square();
let x59 = x58.square();
let x60 = x59.square();
let x61 = x60.square();
let x62 = x61.square();
let x63 = x62.square();
let x64 = x63.square();
let x65 = x64.square();
let x66 = x65.square();
let x67 = x66.square();
let x68 = x67.square();
let x69 = x68.square();
let x70 = x69.square();
let x71 = x70.square();
let x72 = x71.square();
let x73 = x72.square();
let x74 = x73 * x39;
x74
}
}
impl Neg for CrandallField {

View File

@ -160,6 +160,16 @@ pub trait Field:
self.exp(Self::from_canonical_usize(power))
}
fn kth_root(&self, k: usize) -> Self {
let p_minus_1 = Self::ORDER - 1;
debug_assert!(p_minus_1 % k as u64 != 0, "Not a permutation in this field");
todo!()
}
fn cube_root(&self) -> Self {
self.kth_root(3)
}
fn powers(&self) -> Powers<Self> {
Powers {
base: *self,

View File

@ -1,8 +1,10 @@
//! Implements Rescue Prime.
use unroll::unroll_for_loops;
use crate::field::field::Field;
const ROUNDS: usize = 10;
const ROUNDS: usize = 8;
const W: usize = 12;
@ -177,7 +179,7 @@ const MDS: [[u64; W]; W] = [
],
];
const RESCUE_CONSTANTS: [[u64; W]; 20] = [
const RESCUE_CONSTANTS: [[u64; W]; 16] = [
[
12050887499329086906,
1748247961703512657,
@ -402,66 +404,10 @@ const RESCUE_CONSTANTS: [[u64; W]; 20] = [
16465224002344550280,
10282380383506806095,
],
[
12608209810104211593,
11808578423511814760,
16177950852717156460,
9394439296563712221,
12586575762376685187,
17703393198607870393,
9811861465513647715,
14126450959506560131,
12713673607080398908,
18301828072718562389,
11180556590297273821,
4451415492203885059,
],
[
10465807219916311101,
1213997644391575261,
17672155373280862521,
1491206970207330736,
10977478805896263804,
13260961975618373124,
16060889403827043708,
3223573072465920682,
17624203443801796697,
10247205738678800822,
11100653267668698651,
14328592975764892571,
],
[
6984072551318461094,
3416562710010527326,
12847783919251969270,
12223185134739244472,
12073170519625198198,
6221124633828606855,
17596623990006806590,
1153871693574764968,
2548851681903410721,
9823373270182377847,
16708030507924899244,
9619306826188519218,
],
[
5842685042453818473,
12400879353954910914,
647112787845575111,
4893664959929687347,
3759391664155971284,
15871181179823725763,
3629377713951158273,
3439101502554162312,
8325686353010019444,
10630488935940555500,
3478529754946055748,
12681233130980545828,
],
];
fn rescue<F: Field>(mut xs: [F; W]) -> [F; W] {
for r in 0..10 {
pub fn rescue<F: Field>(mut xs: [F; W]) -> [F; W] {
for r in 0..8 {
xs = sbox_layer_a(xs);
xs = mds_layer(xs);
xs = constant_layer(xs, &RESCUE_CONSTANTS[r * 2]);
@ -470,62 +416,27 @@ fn rescue<F: Field>(mut xs: [F; W]) -> [F; W] {
xs = mds_layer(xs);
xs = constant_layer(xs, &RESCUE_CONSTANTS[r * 2 + 1]);
}
// for i in 0..W {
// xs[i] = xs[i].to_canonical();
// }
xs
}
// #[inline(always)]
#[unroll_for_loops]
fn sbox_layer_a<F: Field>(x: [F; W]) -> [F; W] {
let mut result = [F::ZERO; W];
for i in 0..W {
result[i] = sbox_a(x[i]);
result[i] = x[i].cube();
}
result
}
// #[inline(always)]
#[unroll_for_loops]
fn sbox_layer_b<F: Field>(x: [F; W]) -> [F; W] {
let mut result = [F::ZERO; W];
for i in 0..W {
result[i] = sbox_b(x[i]);
result[i] = x[i].cube_root();
}
result
}
// #[inline(always)]
#[unroll_for_loops]
fn sbox_a<F: Field>(x: F) -> F {
// x^{-5}, via Fermat's little theorem
// TODO: This only works for our current field.
const EXP: u64 = 7378697628517453005;
let mut product = F::ONE;
let mut current = x;
for i in 0..64 {
if ((EXP >> i) & 1) != 0 {
product = product * current;
}
current = current.square();
}
product
}
#[inline(always)]
fn sbox_b<F: Field>(x: F) -> F {
// x^5
let x2 = x.square();
let x3 = x2 * x;
x2 * x3
}
// #[inline(always)]
#[unroll_for_loops]
fn mds_layer<F: Field>(x: [F; W]) -> [F; W] {
let mut result = [F::ZERO; W];
@ -537,7 +448,6 @@ fn mds_layer<F: Field>(x: [F; W]) -> [F; W] {
result
}
#[inline(always)]
#[unroll_for_loops]
fn constant_layer<F: Field>(xs: [F; W], con: &[u64; W]) -> [F; W] {
let mut result = [F::ZERO; W];