diff --git a/src/gates/gmimc.rs b/src/gates/gmimc.rs
index 9df6360c..225af379 100644
--- a/src/gates/gmimc.rs
+++ b/src/gates/gmimc.rs
@@ -13,8 +13,7 @@ use crate::iop::witness::{PartitionWitness, Witness};
 use crate::plonk::circuit_builder::CircuitBuilder;
 use crate::plonk::vars::{EvaluationTargets, EvaluationVars, EvaluationVarsBase};
 
-/// Evaluates a full GMiMC permutation with 12 state elements, and writes the output to the next
-/// gate's first `width` wires (which could be the input of another `GMiMCGate`).
+/// Evaluates a full GMiMC permutation with 12 state elements.
 ///
 /// This also has some extra features to make it suitable for efficiently verifying Merkle proofs.
 /// It has a flag which can be used to swap the first four inputs with the next four, for ordering
diff --git a/src/gates/mod.rs b/src/gates/mod.rs
index 993623c3..a2df7cff 100644
--- a/src/gates/mod.rs
+++ b/src/gates/mod.rs
@@ -11,6 +11,7 @@ pub mod gmimc;
 pub mod insertion;
 pub mod interpolation;
 pub mod noop;
+pub mod poseidon;
 pub(crate) mod public_input;
 pub mod random_access;
 pub mod reducing;
diff --git a/src/hash/poseidon.rs b/src/hash/poseidon.rs
index cec6ea71..4449e89f 100644
--- a/src/hash/poseidon.rs
+++ b/src/hash/poseidon.rs
@@ -7,7 +7,8 @@ use std::convert::TryInto;
 use unroll::unroll_for_loops;
 
 use crate::field::crandall_field::CrandallField;
-use crate::field::field_types::PrimeField;
+use crate::field::extension_field::FieldExtension;
+use crate::field::field_types::{Field, PrimeField};
 
 // The number of full rounds and partial rounds is given by the
 // calc_round_numbers.py script. They happen to be the same for both
@@ -15,9 +16,9 @@ use crate::field::field_types::PrimeField;
 //
 // NB: Changing any of these values will require regenerating all of
 // the precomputed constant arrays in this file.
-const HALF_N_FULL_ROUNDS: usize = 4;
-const N_FULL_ROUNDS_TOTAL: usize = 2 * HALF_N_FULL_ROUNDS;
-const N_PARTIAL_ROUNDS: usize = 22;
+pub const HALF_N_FULL_ROUNDS: usize = 4;
+pub const N_FULL_ROUNDS_TOTAL: usize = 2 * HALF_N_FULL_ROUNDS;
+pub const N_PARTIAL_ROUNDS: usize = 22;
 const N_ROUNDS: usize = N_FULL_ROUNDS_TOTAL + N_PARTIAL_ROUNDS;
 const MAX_WIDTH: usize = 12; // we only have width 8 and 12, and 12 is bigger. :)
 
@@ -25,7 +26,7 @@ const MAX_WIDTH: usize = 12; // we only have width 8 and 12, and 12 is bigger. :
 /// `generate_constants` about how these were generated. We include enough for a WIDTH of 12;
 /// smaller widths just use a subset.
 #[rustfmt::skip]
-const ALL_ROUND_CONSTANTS: [u64; MAX_WIDTH * N_ROUNDS] = [
+pub const ALL_ROUND_CONSTANTS: [u64; MAX_WIDTH * N_ROUNDS] = [
     // WARNING: These must be in 0..CrandallField::ORDER (i.e. canonical form). If this condition is
     // not met, some platform-specific implementation of constant_layer may return incorrect
     // results.
@@ -165,6 +166,32 @@ where
         res
     }
 
+    #[inline(always)]
+    #[unroll_for_loops]
+    fn mds_row_shf_field<F: FieldExtension<D, BaseField = Self>, const D: usize>(
+        r: usize,
+        v: &[F; WIDTH],
+    ) -> F {
+        debug_assert!(r < WIDTH);
+        // The values of MDS_MATRIX_EXPS are known to be small, so we can
+        // accumulate all the products for each row and reduce just once
+        // at the end (done by the caller).
+
+        // NB: Unrolling this, calculating each term independently, and
+        // summing at the end, didn't improve performance for me.
+        let mut res = F::ZERO;
+
+        // This is a hacky way of fully unrolling the loop.
+        assert!(WIDTH <= 12);
+        for i in 0..12 {
+            if i < WIDTH {
+                res += v[(i + r) % WIDTH] * F::from_canonical_u64(1 << Self::MDS_MATRIX_EXPS[i]);
+            }
+        }
+
+        res
+    }
+
     #[inline(always)]
     #[unroll_for_loops]
     fn mds_layer(state_: &[Self; WIDTH]) -> [Self; WIDTH] {
@@ -188,19 +215,41 @@
 
     #[inline(always)]
     #[unroll_for_loops]
-    fn partial_first_constant_layer(state: &mut [Self; WIDTH]) {
+    fn mds_layer_field<F: FieldExtension<D, BaseField = Self>, const D: usize>(
+        state: &[F; WIDTH],
+    ) -> [F; WIDTH] {
+        let mut result = [F::ZERO; WIDTH];
+
+        // This is a hacky way of fully unrolling the loop.
+        assert!(WIDTH <= 12);
+        for r in 0..12 {
+            if r < WIDTH {
+                result[r] = Self::mds_row_shf_field(r, state);
+            }
+        }
+
+        result
+    }
+
+    #[inline(always)]
+    #[unroll_for_loops]
+    fn partial_first_constant_layer<F: FieldExtension<D, BaseField = Self>, const D: usize>(
+        state: &mut [F; WIDTH],
+    ) {
         assert!(WIDTH <= 12);
         for i in 0..12 {
             if i < WIDTH {
-                state[i] += Self::from_canonical_u64(Self::FAST_PARTIAL_FIRST_ROUND_CONSTANT[i]);
+                state[i] += F::from_canonical_u64(Self::FAST_PARTIAL_FIRST_ROUND_CONSTANT[i]);
             }
         }
     }
 
     #[inline(always)]
     #[unroll_for_loops]
-    fn mds_partial_layer_init(state: &[Self; WIDTH]) -> [Self; WIDTH] {
-        let mut result = [Self::ZERO; WIDTH];
+    fn mds_partial_layer_init<F: FieldExtension<D, BaseField = Self>, const D: usize>(
+        state: &[F; WIDTH],
+    ) -> [F; WIDTH] {
+        let mut result = [F::ZERO; WIDTH];
 
         // Initial matrix has first row/column = [1, 0, ..., 0];
@@ -216,7 +265,7 @@
                 // NB: FAST_PARTIAL_ROUND_INITIAL_MATRIX is stored in
                 // column-major order so that this dot product is cache
                 // friendly.
-                let t = Self::from_canonical_u64(
+                let t = F::from_canonical_u64(
                     Self::FAST_PARTIAL_ROUND_INITIAL_MATRIX[c - 1][r - 1],
                 );
                 result[c] += state[r] * t;
@@ -265,17 +314,51 @@
 
     #[inline(always)]
     #[unroll_for_loops]
-    fn constant_layer(state: &mut [Self; WIDTH], round_ctr: usize) {
+    fn mds_partial_layer_fast_field<F: FieldExtension<D, BaseField = Self>, const D: usize>(
+        state: &[F; WIDTH],
+        r: usize,
+    ) -> [F; WIDTH] {
+        // Set d = [M_00 | w^] dot [state]
+
+        let s0 = state[0];
+        let mut d = s0 * F::from_canonical_u64(1 << Self::MDS_MATRIX_EXPS[0]);
+        assert!(WIDTH <= 12);
+        for i in 1..12 {
+            if i < WIDTH {
+                let t = F::from_canonical_u64(Self::FAST_PARTIAL_ROUND_W_HATS[r][i - 1]);
+                d += state[i] * t;
+            }
+        }
+
+        // result = [d] concat [state[0] * v + state[shift up by 1]]
+        let mut result = [F::ZERO; WIDTH];
+        result[0] = d;
+        assert!(WIDTH <= 12);
+        for i in 1..12 {
+            if i < WIDTH {
+                let t = F::from_canonical_u64(Self::FAST_PARTIAL_ROUND_VS[r][i - 1]);
+                result[i] = state[0] * t + state[i];
+            }
+        }
+        result
+    }
+
+    #[inline(always)]
+    #[unroll_for_loops]
+    fn constant_layer<F: FieldExtension<D, BaseField = Self>, const D: usize>(
+        state: &mut [F; WIDTH],
+        round_ctr: usize,
+    ) {
         assert!(WIDTH <= 12);
         for i in 0..12 {
             if i < WIDTH {
-                state[i] += Self::from_canonical_u64(ALL_ROUND_CONSTANTS[i + WIDTH * round_ctr]);
+                state[i] += F::from_canonical_u64(ALL_ROUND_CONSTANTS[i + WIDTH * round_ctr]);
             }
         }
     }
 
     #[inline(always)]
-    fn sbox_monomial(x: Self) -> Self {
+    fn sbox_monomial<F: FieldExtension<D, BaseField = Self>, const D: usize>(x: F) -> F {
         // x |--> x^7
         let x2 = x * x;
         let x4 = x2 * x2;
@@ -285,7 +368,7 @@
 
     #[inline(always)]
     #[unroll_for_loops]
-    fn sbox_layer(state: &mut [Self; WIDTH]) {
+    fn sbox_layer<F: FieldExtension<D, BaseField = Self>, const D: usize>(state: &mut [F; WIDTH]) {
        assert!(WIDTH <= 12);
        for i in 0..12 {
            if i < WIDTH {
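For context: the point of the new generic `*_field` helpers is that the same Poseidon round logic can run over any extension of the base field, which is what the new `gates::poseidon` gate needs when evaluating constraints. Below is a minimal sketch (not from this diff) of how a caller might chain these helpers for one full round. It assumes the methods live on a trait `Poseidon<const WIDTH: usize>` implemented for `CrandallField` at `crate::hash::poseidon::Poseidon`, and that any where-clauses the trait requires are satisfied; those names are assumptions inferred from the signatures shown here.

    use crate::field::crandall_field::CrandallField;
    use crate::field::extension_field::FieldExtension;
    use crate::hash::poseidon::Poseidon; // assumed name/path of the trait these methods belong to

    /// One full Poseidon round over a degree-`D` extension `F` of `CrandallField`,
    /// built from the generic layers added in this diff: add round constants,
    /// apply the x^7 S-box to every limb, then apply the MDS layer.
    fn full_round_ext<F, const D: usize>(state: &mut [F; 12], round_ctr: usize)
    where
        F: FieldExtension<D, BaseField = CrandallField>,
    {
        <CrandallField as Poseidon<12>>::constant_layer::<F, D>(state, round_ctr);
        <CrandallField as Poseidon<12>>::sbox_layer::<F, D>(state);
        *state = <CrandallField as Poseidon<12>>::mds_layer_field::<F, D>(state);
    }

The explicit turbofish is only there to make the generic parameters unambiguous in the sketch; in practice the compiler may infer them from the state type.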