diff --git a/src/gates/gmimc.rs b/src/gates/gmimc.rs index 291cab46..221400e0 100644 --- a/src/gates/gmimc.rs +++ b/src/gates/gmimc.rs @@ -7,7 +7,7 @@ use crate::field::field::Field; use crate::gates::deterministic_gate::DeterministicGate; use crate::gates::gate::{Gate, GateRef}; use crate::generator::{SimpleGenerator, WitnessGenerator2}; -use crate::gmimc::gmimc_permute; +use crate::gmimc::{gmimc_permute_array, gmimc_permute}; use crate::target::Target2; use crate::wire::Wire; use crate::witness::PartialWitness; diff --git a/src/gmimc.rs b/src/gmimc.rs index 2e5ede16..9013a2b7 100644 --- a/src/gmimc.rs +++ b/src/gmimc.rs @@ -13,6 +13,30 @@ pub fn gmimc_compress(a: [F; 4], b: [F; 4], constants: [state_1[0], state_1[1], state_1[2], state_1[3]] } +/// Like `gmimc_permute`, but takes constants as an owned array. May be faster. +#[unroll_for_loops] +pub fn gmimc_permute_array( + mut xs: [F; W], + constants: [u64; R], +) -> [F; W] { + // Value that is implicitly added to each element. + // See https://affine.group/2020/02/starkware-challenge + let mut addition_buffer = F::ZERO; + + for r in 0..R { + let active = r % W; + let f = (xs[active] + addition_buffer + F::from_canonical_u64(constants[r])).cube(); + addition_buffer += f; + xs[active] -= f; + } + + for i in 0..W { + xs[i] += addition_buffer; + } + + xs +} + #[unroll_for_loops] pub fn gmimc_permute( mut xs: [F; W], @@ -57,18 +81,29 @@ pub fn gmimc_permute_naive( #[cfg(test)] mod tests { use crate::field::crandall_field::CrandallField; - use crate::field::Field; + use crate::field::field::Field; use crate::gmimc::{gmimc_permute, gmimc_permute_naive}; + use std::sync::Arc; #[test] fn consistency() { type F = CrandallField; - let mut xs = [F::ZERO; 12]; - for i in 0..12 { + const W: usize = 12; + const R: usize = 101; + + let mut constants = [F::ZERO; R]; + for i in 0..R { + constants[i] = F::from_canonical_usize(i); + } + let constants = Arc::new(constants); + + let mut xs = [F::ZERO; W]; + for i in 0..W { xs[i] = F::from_canonical_usize(i); } - let out = gmimc_permute::<_, _, 108>(xs); - let out_naive = gmimc_permute_naive::<_, _, 108>(xs); + + let out = gmimc_permute::(xs, constants.clone()); + let out_naive = gmimc_permute_naive::(xs, constants); assert_eq!(out, out_naive); } } diff --git a/src/main.rs b/src/main.rs index c6e416bf..e5188d89 100644 --- a/src/main.rs +++ b/src/main.rs @@ -53,12 +53,10 @@ fn bench_gmimc() { for i in 0..GMIMC_ROUNDS { constants[i] = F::from_canonical_u64(GMIMC_CONSTANTS[i]); } - let constants = Arc::new(constants); let threads = 12; let hashes_per_poly = 1 << (13 + 3); let threads = (0..threads).map(|_i| { - let constants = constants.clone(); thread::spawn(move || { let mut x = [F::ZERO; 12]; for i in 0..12 { @@ -68,7 +66,7 @@ fn bench_gmimc() { let hashes_per_thread = hashes_per_poly * PROVER_POLYS / threads; let start = Instant::now(); for _ in 0..hashes_per_thread { - x = gmimc::gmimc_permute::<_, 12, GMIMC_ROUNDS>(x, constants.clone()); + x = gmimc::gmimc_permute_array::<_, 12, GMIMC_ROUNDS>(x, GMIMC_CONSTANTS); } let duration = start.elapsed(); println!("took {:?}", duration);