PR feedback

This commit is contained in:
wborgeaud 2021-09-18 09:23:39 +02:00
parent b8f6b3a778
commit 0be8650bca
2 changed files with 52 additions and 82 deletions

View File

@ -58,17 +58,22 @@ where
/// A wire which stores the input of the `i`-th S-box of the `round`-th round of the first set /// A wire which stores the input of the `i`-th S-box of the `round`-th round of the first set
/// of full rounds. /// of full rounds.
fn wire_full_sbox_0(round: usize, i: usize) -> usize { fn wire_full_sbox_0(round: usize, i: usize) -> usize {
debug_assert!(round < poseidon::HALF_N_FULL_ROUNDS);
debug_assert!(i < WIDTH);
2 * WIDTH + 1 + WIDTH * round + i 2 * WIDTH + 1 + WIDTH * round + i
} }
/// A wire which stores the input of the S-box of the `round`-th round of the partial rounds. /// A wire which stores the input of the S-box of the `round`-th round of the partial rounds.
fn wire_partial_sbox(round: usize) -> usize { fn wire_partial_sbox(round: usize) -> usize {
debug_assert!(round < poseidon::N_PARTIAL_ROUNDS);
2 * WIDTH + 1 + WIDTH * poseidon::HALF_N_FULL_ROUNDS + round 2 * WIDTH + 1 + WIDTH * poseidon::HALF_N_FULL_ROUNDS + round
} }
/// A wire which stores the input of the `i`-th S-box of the `round`-th round of the second set /// A wire which stores the input of the `i`-th S-box of the `round`-th round of the second set
/// of full rounds. /// of full rounds.
fn wire_full_sbox_1(round: usize, i: usize) -> usize { fn wire_full_sbox_1(round: usize, i: usize) -> usize {
debug_assert!(round < poseidon::HALF_N_FULL_ROUNDS);
debug_assert!(i < WIDTH);
2 * WIDTH 2 * WIDTH
+ 1 + 1
+ WIDTH * (poseidon::HALF_N_FULL_ROUNDS + round) + WIDTH * (poseidon::HALF_N_FULL_ROUNDS + round)

View File

@ -18,9 +18,9 @@ use crate::plonk::circuit_builder::CircuitBuilder;
// //
// NB: Changing any of these values will require regenerating all of // NB: Changing any of these values will require regenerating all of
// the precomputed constant arrays in this file. // the precomputed constant arrays in this file.
pub const HALF_N_FULL_ROUNDS: usize = 4; pub(crate) const HALF_N_FULL_ROUNDS: usize = 4;
pub const N_FULL_ROUNDS_TOTAL: usize = 2 * HALF_N_FULL_ROUNDS; pub(crate) const N_FULL_ROUNDS_TOTAL: usize = 2 * HALF_N_FULL_ROUNDS;
pub const N_PARTIAL_ROUNDS: usize = 22; pub(crate) const N_PARTIAL_ROUNDS: usize = 22;
const N_ROUNDS: usize = N_FULL_ROUNDS_TOTAL + N_PARTIAL_ROUNDS; const N_ROUNDS: usize = N_FULL_ROUNDS_TOTAL + N_PARTIAL_ROUNDS;
const MAX_WIDTH: usize = 12; // we only have width 8 and 12, and 12 is bigger. :) const MAX_WIDTH: usize = 12; // we only have width 8 and 12, and 12 is bigger. :)
@ -170,7 +170,7 @@ where
#[inline(always)] #[inline(always)]
#[unroll_for_loops] #[unroll_for_loops]
/// Same as `mds_row_shf` for general fields. /// Same as `mds_row_shf` for field extensions of `Self`.
fn mds_row_shf_field<F: FieldExtension<D, BaseField = Self>, const D: usize>( fn mds_row_shf_field<F: FieldExtension<D, BaseField = Self>, const D: usize>(
r: usize, r: usize,
v: &[F; WIDTH], v: &[F; WIDTH],
@ -188,8 +188,6 @@ where
res res
} }
#[inline(always)]
#[unroll_for_loops]
/// Recursive version of `mds_row_shf`. /// Recursive version of `mds_row_shf`.
fn mds_row_shf_recursive<F: RichField + Extendable<D>, const D: usize>( fn mds_row_shf_recursive<F: RichField + Extendable<D>, const D: usize>(
builder: &mut CircuitBuilder<F, D>, builder: &mut CircuitBuilder<F, D>,
@ -200,17 +198,14 @@ where
debug_assert!(r < WIDTH); debug_assert!(r < WIDTH);
let mut res = builder.zero_extension(); let mut res = builder.zero_extension();
assert!(WIDTH <= 12); for i in 0..WIDTH {
for i in 0..12 { res = builder.arithmetic_extension(
if i < WIDTH { F::from_canonical_u64(1 << Self::MDS_MATRIX_EXPS[i]),
res = builder.arithmetic_extension( F::ONE,
F::from_canonical_u64(1 << Self::MDS_MATRIX_EXPS[i]), one,
F::ONE, v[(i + r) % WIDTH],
one, res,
v[(i + r) % WIDTH], );
res,
);
}
} }
res res
@ -239,7 +234,7 @@ where
#[inline(always)] #[inline(always)]
#[unroll_for_loops] #[unroll_for_loops]
/// Same as `mds_layer` for general fields. /// Same as `mds_layer` for field extensions of `Self`.
fn mds_layer_field<F: FieldExtension<D, BaseField = Self>, const D: usize>( fn mds_layer_field<F: FieldExtension<D, BaseField = Self>, const D: usize>(
state: &[F; WIDTH], state: &[F; WIDTH],
) -> [F; WIDTH] { ) -> [F; WIDTH] {
@ -255,8 +250,6 @@ where
result result
} }
#[inline(always)]
#[unroll_for_loops]
/// Recursive version of `mds_layer`. /// Recursive version of `mds_layer`.
fn mds_layer_recursive<F: RichField + Extendable<D>, const D: usize>( fn mds_layer_recursive<F: RichField + Extendable<D>, const D: usize>(
builder: &mut CircuitBuilder<F, D>, builder: &mut CircuitBuilder<F, D>,
@ -264,11 +257,8 @@ where
) -> [ExtensionTarget<D>; WIDTH] { ) -> [ExtensionTarget<D>; WIDTH] {
let mut result = [builder.zero_extension(); WIDTH]; let mut result = [builder.zero_extension(); WIDTH];
assert!(WIDTH <= 12); for r in 0..WIDTH {
for r in 0..12 { result[r] = Self::mds_row_shf_recursive(builder, r, state);
if r < WIDTH {
result[r] = Self::mds_row_shf_recursive(builder, r, state);
}
} }
result result
@ -287,25 +277,20 @@ where
} }
} }
#[inline(always)]
#[unroll_for_loops]
/// Recursive version of `partial_first_constant_layer`. /// Recursive version of `partial_first_constant_layer`.
fn partial_first_constant_layer_recursive<F: RichField + Extendable<D>, const D: usize>( fn partial_first_constant_layer_recursive<F: RichField + Extendable<D>, const D: usize>(
builder: &mut CircuitBuilder<F, D>, builder: &mut CircuitBuilder<F, D>,
state: &mut [ExtensionTarget<D>; WIDTH], state: &mut [ExtensionTarget<D>; WIDTH],
) { ) {
let one = builder.one_extension(); let one = builder.one_extension();
assert!(WIDTH <= 12); for i in 0..WIDTH {
for i in 0..12 { state[i] = builder.arithmetic_extension(
if i < WIDTH { F::from_canonical_u64(Self::FAST_PARTIAL_FIRST_ROUND_CONSTANT[i]),
state[i] = builder.arithmetic_extension( F::ONE,
F::from_canonical_u64(Self::FAST_PARTIAL_FIRST_ROUND_CONSTANT[i]), one,
F::ONE, one,
one, state[i],
one, );
state[i],
);
}
} }
} }
@ -341,8 +326,6 @@ where
result result
} }
#[inline(always)]
#[unroll_for_loops]
/// Recursive version of `mds_partial_layer_init`. /// Recursive version of `mds_partial_layer_init`.
fn mds_partial_layer_init_recursive<F: RichField + Extendable<D>, const D: usize>( fn mds_partial_layer_init_recursive<F: RichField + Extendable<D>, const D: usize>(
builder: &mut CircuitBuilder<F, D>, builder: &mut CircuitBuilder<F, D>,
@ -353,18 +336,14 @@ where
result[0] = state[0]; result[0] = state[0];
assert!(WIDTH <= 12); for c in 1..WIDTH {
for c in 1..12 { assert!(WIDTH <= 12);
if c < WIDTH { for r in 1..12 {
assert!(WIDTH <= 12); if r < WIDTH {
for r in 1..12 { let t = F::from_canonical_u64(
if r < WIDTH { Self::FAST_PARTIAL_ROUND_INITIAL_MATRIX[c - 1][r - 1],
let t = F::from_canonical_u64( );
Self::FAST_PARTIAL_ROUND_INITIAL_MATRIX[c - 1][r - 1], result[c] = builder.arithmetic_extension(t, F::ONE, one, state[r], result[c]);
);
result[c] =
builder.arithmetic_extension(t, F::ONE, one, state[r], result[c]);
}
} }
} }
} }
@ -409,7 +388,7 @@ where
#[inline(always)] #[inline(always)]
#[unroll_for_loops] #[unroll_for_loops]
/// Same as `mds_partial_layer_fast` for general fields. /// Same as `mds_partial_layer_fast` for field extensions of `Self`.
fn mds_partial_layer_fast_field<F: FieldExtension<D, BaseField = Self>, const D: usize>( fn mds_partial_layer_fast_field<F: FieldExtension<D, BaseField = Self>, const D: usize>(
state: &[F; WIDTH], state: &[F; WIDTH],
r: usize, r: usize,
@ -437,8 +416,6 @@ where
result result
} }
#[inline(always)]
#[unroll_for_loops]
/// Recursive version of `mds_partial_layer_fast`. /// Recursive version of `mds_partial_layer_fast`.
fn mds_partial_layer_fast_recursive<F: RichField + Extendable<D>, const D: usize>( fn mds_partial_layer_fast_recursive<F: RichField + Extendable<D>, const D: usize>(
builder: &mut CircuitBuilder<F, D>, builder: &mut CircuitBuilder<F, D>,
@ -456,12 +433,9 @@ where
s0, s0,
zero, zero,
); );
assert!(WIDTH <= 12); for i in 1..WIDTH {
for i in 1..12 { let t = F::from_canonical_u64(Self::FAST_PARTIAL_ROUND_W_HATS[r][i - 1]);
if i < WIDTH { d = builder.arithmetic_extension(t, F::ONE, one, state[i], d);
let t = F::from_canonical_u64(Self::FAST_PARTIAL_ROUND_W_HATS[r][i - 1]);
d = builder.arithmetic_extension(t, F::ONE, one, state[i], d);
}
} }
let mut result = [zero; WIDTH]; let mut result = [zero; WIDTH];
@ -489,6 +463,7 @@ where
#[inline(always)] #[inline(always)]
#[unroll_for_loops] #[unroll_for_loops]
/// Same as `constant_layer` for field extensions of `Self`.
fn constant_layer_field<F: FieldExtension<D, BaseField = Self>, const D: usize>( fn constant_layer_field<F: FieldExtension<D, BaseField = Self>, const D: usize>(
state: &mut [F; WIDTH], state: &mut [F; WIDTH],
round_ctr: usize, round_ctr: usize,
@ -501,8 +476,6 @@ where
} }
} }
#[inline(always)]
#[unroll_for_loops]
/// Recursive version of `constant_layer`. /// Recursive version of `constant_layer`.
fn constant_layer_recursive<F: RichField + Extendable<D>, const D: usize>( fn constant_layer_recursive<F: RichField + Extendable<D>, const D: usize>(
builder: &mut CircuitBuilder<F, D>, builder: &mut CircuitBuilder<F, D>,
@ -510,17 +483,14 @@ where
round_ctr: usize, round_ctr: usize,
) { ) {
let one = builder.one_extension(); let one = builder.one_extension();
assert!(WIDTH <= 12); for i in 0..WIDTH {
for i in 0..12 { state[i] = builder.arithmetic_extension(
if i < WIDTH { F::from_canonical_u64(ALL_ROUND_CONSTANTS[i + WIDTH * round_ctr]),
state[i] = builder.arithmetic_extension( F::ONE,
F::from_canonical_u64(ALL_ROUND_CONSTANTS[i + WIDTH * round_ctr]), one,
F::ONE, one,
one, state[i],
one, );
state[i],
);
}
} }
} }
@ -533,7 +503,6 @@ where
x3 * x4 x3 * x4
} }
#[inline(always)]
/// Recursive version of `sbox_monomial`. /// Recursive version of `sbox_monomial`.
fn sbox_monomial_recursive<F: RichField + Extendable<D>, const D: usize>( fn sbox_monomial_recursive<F: RichField + Extendable<D>, const D: usize>(
builder: &mut CircuitBuilder<F, D>, builder: &mut CircuitBuilder<F, D>,
@ -556,6 +525,7 @@ where
#[inline(always)] #[inline(always)]
#[unroll_for_loops] #[unroll_for_loops]
/// Same as `sbox_layer` for field extensions of `Self`.
fn sbox_layer_field<F: FieldExtension<D, BaseField = Self>, const D: usize>( fn sbox_layer_field<F: FieldExtension<D, BaseField = Self>, const D: usize>(
state: &mut [F; WIDTH], state: &mut [F; WIDTH],
) { ) {
@ -567,18 +537,13 @@ where
} }
} }
#[inline(always)]
#[unroll_for_loops]
/// Recursive version of `sbox_layer`. /// Recursive version of `sbox_layer`.
fn sbox_layer_recursive<F: RichField + Extendable<D>, const D: usize>( fn sbox_layer_recursive<F: RichField + Extendable<D>, const D: usize>(
builder: &mut CircuitBuilder<F, D>, builder: &mut CircuitBuilder<F, D>,
state: &mut [ExtensionTarget<D>; WIDTH], state: &mut [ExtensionTarget<D>; WIDTH],
) { ) {
assert!(WIDTH <= 12); for i in 0..WIDTH {
for i in 0..12 { state[i] = Self::sbox_monomial_recursive(builder, state[i]);
if i < WIDTH {
state[i] = Self::sbox_monomial_recursive(builder, state[i]);
}
} }
} }