diff --git a/src/bin/bench_gmimc.rs b/src/bin/bench_gmimc.rs index f234285d..2f81aac0 100644 --- a/src/bin/bench_gmimc.rs +++ b/src/bin/bench_gmimc.rs @@ -14,8 +14,8 @@ const PROVER_POLYS: usize = 113 + 3 + 4; fn main() { const THREADS: usize = 12; const LDE_BITS: i32 = 3; - const W: usize = 13; - const HASHES_PER_POLY: usize = 1 << (13 + LDE_BITS); + const W: usize = 12; + const HASHES_PER_POLY: usize = 1 << (13 + LDE_BITS) / 6; let threads = (0..THREADS) .map(|_i| { diff --git a/src/bin/bench_rescue.rs b/src/bin/bench_rescue.rs new file mode 100644 index 00000000..96334689 --- /dev/null +++ b/src/bin/bench_rescue.rs @@ -0,0 +1,46 @@ +use std::thread; +use std::time::Instant; + +use plonky2::field::crandall_field::CrandallField; +use plonky2::field::field::Field; +use plonky2::rescue::rescue; + +type F = CrandallField; + +// 113 wire polys, 3 Z polys, 4 parts of quotient poly. +const PROVER_POLYS: usize = 113 + 3 + 4; + +fn main() { + const THREADS: usize = 12; + const LDE_BITS: i32 = 3; + const W: usize = 12; + const HASHES_PER_POLY: usize = (1 << (13 + LDE_BITS)) / 6; + + let threads = (0..THREADS) + .map(|_i| { + thread::spawn(move || { + let mut x = [F::ZERO; W]; + for i in 0..W { + x[i] = F::from_canonical_u64((i as u64) * 123456 + 789); + } + + let hashes_per_thread = HASHES_PER_POLY * PROVER_POLYS / THREADS; + let start = Instant::now(); + for _ in 0..hashes_per_thread { + x = rescue(x); + } + let duration = start.elapsed(); + println!("took {:?}", duration); + println!( + "avg {:?}us", + duration.as_secs_f64() * 1e6 / (hashes_per_thread as f64) + ); + println!("result {:?}", x); + }) + }) + .collect::>(); + + for t in threads { + t.join().expect("oops"); + } +} diff --git a/src/field/crandall_field.rs b/src/field/crandall_field.rs index 8983ef67..0ec184e8 100644 --- a/src/field/crandall_field.rs +++ b/src/field/crandall_field.rs @@ -145,6 +145,85 @@ impl Field for CrandallField { fn from_canonical_u64(n: u64) -> Self { Self(n) } + + fn cube_root(&self) -> Self { + let x0 = *self; + let x1 = x0.square(); + let x2 = x1.square(); + let x3 = x2 * x0; + let x4 = x3.square(); + let x5 = x4.square(); + // let x6 = x4.square(); + let x7 = x5.square(); + let x8 = x7.square(); + let x9 = x8.square(); + let x10 = x9.square(); + let x11 = x10 * x5; + let x12 = x11.square(); + let x13 = x12.square(); + let x14 = x13.square(); + // let x15 = x13.square(); + let x16 = x14.square(); + let x17 = x16.square(); + let x18 = x17.square(); + let x19 = x18.square(); + let x20 = x19.square(); + let x21 = x20 * x11; + let x22 = x21.square(); + let x23 = x22.square(); + let x24 = x23.square(); + let x25 = x24.square(); + let x26 = x25.square(); + let x27 = x26.square(); + let x28 = x27.square(); + let x29 = x28.square(); + let x30 = x29.square(); + let x31 = x30.square(); + let x32 = x31.square(); + let x33 = x32 * x14; + let x34 = x33 * x3; + let x35 = x34.square(); + let x36 = x35 * x34; + let x37 = x36 * x5; + let x38 = x37 * x34; + let x39 = x38 * x37; + let x40 = x39.square(); + let x41 = x40.square(); + let x42 = x41 * x38; + let x43 = x42.square(); + let x44 = x43.square(); + let x45 = x44.square(); + let x46 = x45.square(); + let x47 = x46.square(); + let x48 = x47.square(); + let x49 = x48.square(); + let x50 = x49.square(); + let x51 = x50.square(); + let x52 = x51.square(); + let x53 = x52.square(); + let x54 = x53.square(); + let x55 = x54.square(); + let x56 = x55.square(); + let x57 = x56.square(); + let x58 = x57.square(); + let x59 = x58.square(); + let x60 = x59.square(); + let x61 = x60.square(); + let x62 = x61.square(); + let x63 = x62.square(); + let x64 = x63.square(); + let x65 = x64.square(); + let x66 = x65.square(); + let x67 = x66.square(); + let x68 = x67.square(); + let x69 = x68.square(); + let x70 = x69.square(); + let x71 = x70.square(); + let x72 = x71.square(); + let x73 = x72.square(); + let x74 = x73 * x39; + x74 + } } impl Neg for CrandallField { diff --git a/src/field/field.rs b/src/field/field.rs index ed81cb6b..6c12bf7d 100644 --- a/src/field/field.rs +++ b/src/field/field.rs @@ -160,6 +160,16 @@ pub trait Field: self.exp(Self::from_canonical_usize(power)) } + fn kth_root(&self, k: usize) -> Self { + let p_minus_1 = Self::ORDER - 1; + debug_assert!(p_minus_1 % k as u64 != 0, "Not a permutation in this field"); + todo!() + } + + fn cube_root(&self) -> Self { + self.kth_root(3) + } + fn powers(&self) -> Powers { Powers { base: *self, diff --git a/src/rescue.rs b/src/rescue.rs index c28f315a..f088fdb7 100644 --- a/src/rescue.rs +++ b/src/rescue.rs @@ -1,8 +1,10 @@ +//! Implements Rescue Prime. + use unroll::unroll_for_loops; use crate::field::field::Field; -const ROUNDS: usize = 10; +const ROUNDS: usize = 8; const W: usize = 12; @@ -177,7 +179,7 @@ const MDS: [[u64; W]; W] = [ ], ]; -const RESCUE_CONSTANTS: [[u64; W]; 20] = [ +const RESCUE_CONSTANTS: [[u64; W]; 16] = [ [ 12050887499329086906, 1748247961703512657, @@ -402,66 +404,10 @@ const RESCUE_CONSTANTS: [[u64; W]; 20] = [ 16465224002344550280, 10282380383506806095, ], - [ - 12608209810104211593, - 11808578423511814760, - 16177950852717156460, - 9394439296563712221, - 12586575762376685187, - 17703393198607870393, - 9811861465513647715, - 14126450959506560131, - 12713673607080398908, - 18301828072718562389, - 11180556590297273821, - 4451415492203885059, - ], - [ - 10465807219916311101, - 1213997644391575261, - 17672155373280862521, - 1491206970207330736, - 10977478805896263804, - 13260961975618373124, - 16060889403827043708, - 3223573072465920682, - 17624203443801796697, - 10247205738678800822, - 11100653267668698651, - 14328592975764892571, - ], - [ - 6984072551318461094, - 3416562710010527326, - 12847783919251969270, - 12223185134739244472, - 12073170519625198198, - 6221124633828606855, - 17596623990006806590, - 1153871693574764968, - 2548851681903410721, - 9823373270182377847, - 16708030507924899244, - 9619306826188519218, - ], - [ - 5842685042453818473, - 12400879353954910914, - 647112787845575111, - 4893664959929687347, - 3759391664155971284, - 15871181179823725763, - 3629377713951158273, - 3439101502554162312, - 8325686353010019444, - 10630488935940555500, - 3478529754946055748, - 12681233130980545828, - ], ]; -fn rescue(mut xs: [F; W]) -> [F; W] { - for r in 0..10 { +pub fn rescue(mut xs: [F; W]) -> [F; W] { + for r in 0..8 { xs = sbox_layer_a(xs); xs = mds_layer(xs); xs = constant_layer(xs, &RESCUE_CONSTANTS[r * 2]); @@ -470,62 +416,27 @@ fn rescue(mut xs: [F; W]) -> [F; W] { xs = mds_layer(xs); xs = constant_layer(xs, &RESCUE_CONSTANTS[r * 2 + 1]); } - - // for i in 0..W { - // xs[i] = xs[i].to_canonical(); - // } - xs } -// #[inline(always)] #[unroll_for_loops] fn sbox_layer_a(x: [F; W]) -> [F; W] { let mut result = [F::ZERO; W]; for i in 0..W { - result[i] = sbox_a(x[i]); + result[i] = x[i].cube(); } result } -// #[inline(always)] #[unroll_for_loops] fn sbox_layer_b(x: [F; W]) -> [F; W] { let mut result = [F::ZERO; W]; for i in 0..W { - result[i] = sbox_b(x[i]); + result[i] = x[i].cube_root(); } result } -// #[inline(always)] -#[unroll_for_loops] -fn sbox_a(x: F) -> F { - // x^{-5}, via Fermat's little theorem - // TODO: This only works for our current field. - const EXP: u64 = 7378697628517453005; - - let mut product = F::ONE; - let mut current = x; - - for i in 0..64 { - if ((EXP >> i) & 1) != 0 { - product = product * current; - } - current = current.square(); - } - product -} - -#[inline(always)] -fn sbox_b(x: F) -> F { - // x^5 - let x2 = x.square(); - let x3 = x2 * x; - x2 * x3 -} - -// #[inline(always)] #[unroll_for_loops] fn mds_layer(x: [F; W]) -> [F; W] { let mut result = [F::ZERO; W]; @@ -537,7 +448,6 @@ fn mds_layer(x: [F; W]) -> [F; W] { result } -#[inline(always)] #[unroll_for_loops] fn constant_layer(xs: [F; W], con: &[u64; W]) -> [F; W] { let mut result = [F::ZERO; W];