From 0be8650bca117a891e9e950e9c5e55bcfbdcc56e Mon Sep 17 00:00:00 2001 From: wborgeaud Date: Sat, 18 Sep 2021 09:23:39 +0200 Subject: [PATCH] PR feedback --- src/gates/poseidon.rs | 5 ++ src/hash/poseidon.rs | 129 +++++++++++++++--------------------------- 2 files changed, 52 insertions(+), 82 deletions(-) diff --git a/src/gates/poseidon.rs b/src/gates/poseidon.rs index 651ac4fe..3aef44ab 100644 --- a/src/gates/poseidon.rs +++ b/src/gates/poseidon.rs @@ -58,17 +58,22 @@ where /// A wire which stores the input of the `i`-th S-box of the `round`-th round of the first set /// of full rounds. fn wire_full_sbox_0(round: usize, i: usize) -> usize { + debug_assert!(round < poseidon::HALF_N_FULL_ROUNDS); + debug_assert!(i < WIDTH); 2 * WIDTH + 1 + WIDTH * round + i } /// A wire which stores the input of the S-box of the `round`-th round of the partial rounds. fn wire_partial_sbox(round: usize) -> usize { + debug_assert!(round < poseidon::N_PARTIAL_ROUNDS); 2 * WIDTH + 1 + WIDTH * poseidon::HALF_N_FULL_ROUNDS + round } /// A wire which stores the input of the `i`-th S-box of the `round`-th round of the second set /// of full rounds. fn wire_full_sbox_1(round: usize, i: usize) -> usize { + debug_assert!(round < poseidon::HALF_N_FULL_ROUNDS); + debug_assert!(i < WIDTH); 2 * WIDTH + 1 + WIDTH * (poseidon::HALF_N_FULL_ROUNDS + round) diff --git a/src/hash/poseidon.rs b/src/hash/poseidon.rs index e3b6e1e0..ca519634 100644 --- a/src/hash/poseidon.rs +++ b/src/hash/poseidon.rs @@ -18,9 +18,9 @@ use crate::plonk::circuit_builder::CircuitBuilder; // // NB: Changing any of these values will require regenerating all of // the precomputed constant arrays in this file. -pub const HALF_N_FULL_ROUNDS: usize = 4; -pub const N_FULL_ROUNDS_TOTAL: usize = 2 * HALF_N_FULL_ROUNDS; -pub const N_PARTIAL_ROUNDS: usize = 22; +pub(crate) const HALF_N_FULL_ROUNDS: usize = 4; +pub(crate) const N_FULL_ROUNDS_TOTAL: usize = 2 * HALF_N_FULL_ROUNDS; +pub(crate) const N_PARTIAL_ROUNDS: usize = 22; const N_ROUNDS: usize = N_FULL_ROUNDS_TOTAL + N_PARTIAL_ROUNDS; const MAX_WIDTH: usize = 12; // we only have width 8 and 12, and 12 is bigger. :) @@ -170,7 +170,7 @@ where #[inline(always)] #[unroll_for_loops] - /// Same as `mds_row_shf` for general fields. + /// Same as `mds_row_shf` for field extensions of `Self`. fn mds_row_shf_field, const D: usize>( r: usize, v: &[F; WIDTH], @@ -188,8 +188,6 @@ where res } - #[inline(always)] - #[unroll_for_loops] /// Recursive version of `mds_row_shf`. fn mds_row_shf_recursive, const D: usize>( builder: &mut CircuitBuilder, @@ -200,17 +198,14 @@ where debug_assert!(r < WIDTH); let mut res = builder.zero_extension(); - assert!(WIDTH <= 12); - for i in 0..12 { - if i < WIDTH { - res = builder.arithmetic_extension( - F::from_canonical_u64(1 << Self::MDS_MATRIX_EXPS[i]), - F::ONE, - one, - v[(i + r) % WIDTH], - res, - ); - } + for i in 0..WIDTH { + res = builder.arithmetic_extension( + F::from_canonical_u64(1 << Self::MDS_MATRIX_EXPS[i]), + F::ONE, + one, + v[(i + r) % WIDTH], + res, + ); } res @@ -239,7 +234,7 @@ where #[inline(always)] #[unroll_for_loops] - /// Same as `mds_layer` for general fields. + /// Same as `mds_layer` for field extensions of `Self`. fn mds_layer_field, const D: usize>( state: &[F; WIDTH], ) -> [F; WIDTH] { @@ -255,8 +250,6 @@ where result } - #[inline(always)] - #[unroll_for_loops] /// Recursive version of `mds_layer`. fn mds_layer_recursive, const D: usize>( builder: &mut CircuitBuilder, @@ -264,11 +257,8 @@ where ) -> [ExtensionTarget; WIDTH] { let mut result = [builder.zero_extension(); WIDTH]; - assert!(WIDTH <= 12); - for r in 0..12 { - if r < WIDTH { - result[r] = Self::mds_row_shf_recursive(builder, r, state); - } + for r in 0..WIDTH { + result[r] = Self::mds_row_shf_recursive(builder, r, state); } result @@ -287,25 +277,20 @@ where } } - #[inline(always)] - #[unroll_for_loops] /// Recursive version of `partial_first_constant_layer`. fn partial_first_constant_layer_recursive, const D: usize>( builder: &mut CircuitBuilder, state: &mut [ExtensionTarget; WIDTH], ) { let one = builder.one_extension(); - assert!(WIDTH <= 12); - for i in 0..12 { - if i < WIDTH { - state[i] = builder.arithmetic_extension( - F::from_canonical_u64(Self::FAST_PARTIAL_FIRST_ROUND_CONSTANT[i]), - F::ONE, - one, - one, - state[i], - ); - } + for i in 0..WIDTH { + state[i] = builder.arithmetic_extension( + F::from_canonical_u64(Self::FAST_PARTIAL_FIRST_ROUND_CONSTANT[i]), + F::ONE, + one, + one, + state[i], + ); } } @@ -341,8 +326,6 @@ where result } - #[inline(always)] - #[unroll_for_loops] /// Recursive version of `mds_partial_layer_init`. fn mds_partial_layer_init_recursive, const D: usize>( builder: &mut CircuitBuilder, @@ -353,18 +336,14 @@ where result[0] = state[0]; - assert!(WIDTH <= 12); - for c in 1..12 { - if c < WIDTH { - assert!(WIDTH <= 12); - for r in 1..12 { - if r < WIDTH { - let t = F::from_canonical_u64( - Self::FAST_PARTIAL_ROUND_INITIAL_MATRIX[c - 1][r - 1], - ); - result[c] = - builder.arithmetic_extension(t, F::ONE, one, state[r], result[c]); - } + for c in 1..WIDTH { + assert!(WIDTH <= 12); + for r in 1..12 { + if r < WIDTH { + let t = F::from_canonical_u64( + Self::FAST_PARTIAL_ROUND_INITIAL_MATRIX[c - 1][r - 1], + ); + result[c] = builder.arithmetic_extension(t, F::ONE, one, state[r], result[c]); } } } @@ -409,7 +388,7 @@ where #[inline(always)] #[unroll_for_loops] - /// Same as `mds_partial_layer_fast` for general fields. + /// Same as `mds_partial_layer_fast` for field extensions of `Self`. fn mds_partial_layer_fast_field, const D: usize>( state: &[F; WIDTH], r: usize, @@ -437,8 +416,6 @@ where result } - #[inline(always)] - #[unroll_for_loops] /// Recursive version of `mds_partial_layer_fast`. fn mds_partial_layer_fast_recursive, const D: usize>( builder: &mut CircuitBuilder, @@ -456,12 +433,9 @@ where s0, zero, ); - assert!(WIDTH <= 12); - for i in 1..12 { - if i < WIDTH { - let t = F::from_canonical_u64(Self::FAST_PARTIAL_ROUND_W_HATS[r][i - 1]); - d = builder.arithmetic_extension(t, F::ONE, one, state[i], d); - } + for i in 1..WIDTH { + let t = F::from_canonical_u64(Self::FAST_PARTIAL_ROUND_W_HATS[r][i - 1]); + d = builder.arithmetic_extension(t, F::ONE, one, state[i], d); } let mut result = [zero; WIDTH]; @@ -489,6 +463,7 @@ where #[inline(always)] #[unroll_for_loops] + /// Same as `constant_layer` for field extensions of `Self`. fn constant_layer_field, const D: usize>( state: &mut [F; WIDTH], round_ctr: usize, @@ -501,8 +476,6 @@ where } } - #[inline(always)] - #[unroll_for_loops] /// Recursive version of `constant_layer`. fn constant_layer_recursive, const D: usize>( builder: &mut CircuitBuilder, @@ -510,17 +483,14 @@ where round_ctr: usize, ) { let one = builder.one_extension(); - assert!(WIDTH <= 12); - for i in 0..12 { - if i < WIDTH { - state[i] = builder.arithmetic_extension( - F::from_canonical_u64(ALL_ROUND_CONSTANTS[i + WIDTH * round_ctr]), - F::ONE, - one, - one, - state[i], - ); - } + for i in 0..WIDTH { + state[i] = builder.arithmetic_extension( + F::from_canonical_u64(ALL_ROUND_CONSTANTS[i + WIDTH * round_ctr]), + F::ONE, + one, + one, + state[i], + ); } } @@ -533,7 +503,6 @@ where x3 * x4 } - #[inline(always)] /// Recursive version of `sbox_monomial`. fn sbox_monomial_recursive, const D: usize>( builder: &mut CircuitBuilder, @@ -556,6 +525,7 @@ where #[inline(always)] #[unroll_for_loops] + /// Same as `sbox_layer` for field extensions of `Self`. fn sbox_layer_field, const D: usize>( state: &mut [F; WIDTH], ) { @@ -567,18 +537,13 @@ where } } - #[inline(always)] - #[unroll_for_loops] /// Recursive version of `sbox_layer`. fn sbox_layer_recursive, const D: usize>( builder: &mut CircuitBuilder, state: &mut [ExtensionTarget; WIDTH], ) { - assert!(WIDTH <= 12); - for i in 0..12 { - if i < WIDTH { - state[i] = Self::sbox_monomial_recursive(builder, state[i]); - } + for i in 0..WIDTH { + state[i] = Self::sbox_monomial_recursive(builder, state[i]); } }