PR feedback

This commit is contained in:
wborgeaud 2021-09-18 09:23:39 +02:00
parent b8f6b3a778
commit 0be8650bca
2 changed files with 52 additions and 82 deletions

View File

@ -58,17 +58,22 @@ where
/// A wire which stores the input of the `i`-th S-box of the `round`-th round of the first set
/// of full rounds.
fn wire_full_sbox_0(round: usize, i: usize) -> usize {
debug_assert!(round < poseidon::HALF_N_FULL_ROUNDS);
debug_assert!(i < WIDTH);
2 * WIDTH + 1 + WIDTH * round + i
}
/// A wire which stores the input of the S-box of the `round`-th round of the partial rounds.
fn wire_partial_sbox(round: usize) -> usize {
debug_assert!(round < poseidon::N_PARTIAL_ROUNDS);
2 * WIDTH + 1 + WIDTH * poseidon::HALF_N_FULL_ROUNDS + round
}
/// A wire which stores the input of the `i`-th S-box of the `round`-th round of the second set
/// of full rounds.
fn wire_full_sbox_1(round: usize, i: usize) -> usize {
debug_assert!(round < poseidon::HALF_N_FULL_ROUNDS);
debug_assert!(i < WIDTH);
2 * WIDTH
+ 1
+ WIDTH * (poseidon::HALF_N_FULL_ROUNDS + round)

View File

@ -18,9 +18,9 @@ use crate::plonk::circuit_builder::CircuitBuilder;
//
// NB: Changing any of these values will require regenerating all of
// the precomputed constant arrays in this file.
pub const HALF_N_FULL_ROUNDS: usize = 4;
pub const N_FULL_ROUNDS_TOTAL: usize = 2 * HALF_N_FULL_ROUNDS;
pub const N_PARTIAL_ROUNDS: usize = 22;
pub(crate) const HALF_N_FULL_ROUNDS: usize = 4;
pub(crate) const N_FULL_ROUNDS_TOTAL: usize = 2 * HALF_N_FULL_ROUNDS;
pub(crate) const N_PARTIAL_ROUNDS: usize = 22;
const N_ROUNDS: usize = N_FULL_ROUNDS_TOTAL + N_PARTIAL_ROUNDS;
const MAX_WIDTH: usize = 12; // we only have width 8 and 12, and 12 is bigger. :)
@ -170,7 +170,7 @@ where
#[inline(always)]
#[unroll_for_loops]
/// Same as `mds_row_shf` for general fields.
/// Same as `mds_row_shf` for field extensions of `Self`.
fn mds_row_shf_field<F: FieldExtension<D, BaseField = Self>, const D: usize>(
r: usize,
v: &[F; WIDTH],
@ -188,8 +188,6 @@ where
res
}
#[inline(always)]
#[unroll_for_loops]
/// Recursive version of `mds_row_shf`.
fn mds_row_shf_recursive<F: RichField + Extendable<D>, const D: usize>(
builder: &mut CircuitBuilder<F, D>,
@ -200,17 +198,14 @@ where
debug_assert!(r < WIDTH);
let mut res = builder.zero_extension();
assert!(WIDTH <= 12);
for i in 0..12 {
if i < WIDTH {
res = builder.arithmetic_extension(
F::from_canonical_u64(1 << Self::MDS_MATRIX_EXPS[i]),
F::ONE,
one,
v[(i + r) % WIDTH],
res,
);
}
for i in 0..WIDTH {
res = builder.arithmetic_extension(
F::from_canonical_u64(1 << Self::MDS_MATRIX_EXPS[i]),
F::ONE,
one,
v[(i + r) % WIDTH],
res,
);
}
res
@ -239,7 +234,7 @@ where
#[inline(always)]
#[unroll_for_loops]
/// Same as `mds_layer` for general fields.
/// Same as `mds_layer` for field extensions of `Self`.
fn mds_layer_field<F: FieldExtension<D, BaseField = Self>, const D: usize>(
state: &[F; WIDTH],
) -> [F; WIDTH] {
@ -255,8 +250,6 @@ where
result
}
#[inline(always)]
#[unroll_for_loops]
/// Recursive version of `mds_layer`.
fn mds_layer_recursive<F: RichField + Extendable<D>, const D: usize>(
builder: &mut CircuitBuilder<F, D>,
@ -264,11 +257,8 @@ where
) -> [ExtensionTarget<D>; WIDTH] {
let mut result = [builder.zero_extension(); WIDTH];
assert!(WIDTH <= 12);
for r in 0..12 {
if r < WIDTH {
result[r] = Self::mds_row_shf_recursive(builder, r, state);
}
for r in 0..WIDTH {
result[r] = Self::mds_row_shf_recursive(builder, r, state);
}
result
@ -287,25 +277,20 @@ where
}
}
#[inline(always)]
#[unroll_for_loops]
/// Recursive version of `partial_first_constant_layer`.
fn partial_first_constant_layer_recursive<F: RichField + Extendable<D>, const D: usize>(
builder: &mut CircuitBuilder<F, D>,
state: &mut [ExtensionTarget<D>; WIDTH],
) {
let one = builder.one_extension();
assert!(WIDTH <= 12);
for i in 0..12 {
if i < WIDTH {
state[i] = builder.arithmetic_extension(
F::from_canonical_u64(Self::FAST_PARTIAL_FIRST_ROUND_CONSTANT[i]),
F::ONE,
one,
one,
state[i],
);
}
for i in 0..WIDTH {
state[i] = builder.arithmetic_extension(
F::from_canonical_u64(Self::FAST_PARTIAL_FIRST_ROUND_CONSTANT[i]),
F::ONE,
one,
one,
state[i],
);
}
}
@ -341,8 +326,6 @@ where
result
}
#[inline(always)]
#[unroll_for_loops]
/// Recursive version of `mds_partial_layer_init`.
fn mds_partial_layer_init_recursive<F: RichField + Extendable<D>, const D: usize>(
builder: &mut CircuitBuilder<F, D>,
@ -353,18 +336,14 @@ where
result[0] = state[0];
assert!(WIDTH <= 12);
for c in 1..12 {
if c < WIDTH {
assert!(WIDTH <= 12);
for r in 1..12 {
if r < WIDTH {
let t = F::from_canonical_u64(
Self::FAST_PARTIAL_ROUND_INITIAL_MATRIX[c - 1][r - 1],
);
result[c] =
builder.arithmetic_extension(t, F::ONE, one, state[r], result[c]);
}
for c in 1..WIDTH {
assert!(WIDTH <= 12);
for r in 1..12 {
if r < WIDTH {
let t = F::from_canonical_u64(
Self::FAST_PARTIAL_ROUND_INITIAL_MATRIX[c - 1][r - 1],
);
result[c] = builder.arithmetic_extension(t, F::ONE, one, state[r], result[c]);
}
}
}
@ -409,7 +388,7 @@ where
#[inline(always)]
#[unroll_for_loops]
/// Same as `mds_partial_layer_fast` for general fields.
/// Same as `mds_partial_layer_fast` for field extensions of `Self`.
fn mds_partial_layer_fast_field<F: FieldExtension<D, BaseField = Self>, const D: usize>(
state: &[F; WIDTH],
r: usize,
@ -437,8 +416,6 @@ where
result
}
#[inline(always)]
#[unroll_for_loops]
/// Recursive version of `mds_partial_layer_fast`.
fn mds_partial_layer_fast_recursive<F: RichField + Extendable<D>, const D: usize>(
builder: &mut CircuitBuilder<F, D>,
@ -456,12 +433,9 @@ where
s0,
zero,
);
assert!(WIDTH <= 12);
for i in 1..12 {
if i < WIDTH {
let t = F::from_canonical_u64(Self::FAST_PARTIAL_ROUND_W_HATS[r][i - 1]);
d = builder.arithmetic_extension(t, F::ONE, one, state[i], d);
}
for i in 1..WIDTH {
let t = F::from_canonical_u64(Self::FAST_PARTIAL_ROUND_W_HATS[r][i - 1]);
d = builder.arithmetic_extension(t, F::ONE, one, state[i], d);
}
let mut result = [zero; WIDTH];
@ -489,6 +463,7 @@ where
#[inline(always)]
#[unroll_for_loops]
/// Same as `constant_layer` for field extensions of `Self`.
fn constant_layer_field<F: FieldExtension<D, BaseField = Self>, const D: usize>(
state: &mut [F; WIDTH],
round_ctr: usize,
@ -501,8 +476,6 @@ where
}
}
#[inline(always)]
#[unroll_for_loops]
/// Recursive version of `constant_layer`.
fn constant_layer_recursive<F: RichField + Extendable<D>, const D: usize>(
builder: &mut CircuitBuilder<F, D>,
@ -510,17 +483,14 @@ where
round_ctr: usize,
) {
let one = builder.one_extension();
assert!(WIDTH <= 12);
for i in 0..12 {
if i < WIDTH {
state[i] = builder.arithmetic_extension(
F::from_canonical_u64(ALL_ROUND_CONSTANTS[i + WIDTH * round_ctr]),
F::ONE,
one,
one,
state[i],
);
}
for i in 0..WIDTH {
state[i] = builder.arithmetic_extension(
F::from_canonical_u64(ALL_ROUND_CONSTANTS[i + WIDTH * round_ctr]),
F::ONE,
one,
one,
state[i],
);
}
}
@ -533,7 +503,6 @@ where
x3 * x4
}
#[inline(always)]
/// Recursive version of `sbox_monomial`.
fn sbox_monomial_recursive<F: RichField + Extendable<D>, const D: usize>(
builder: &mut CircuitBuilder<F, D>,
@ -556,6 +525,7 @@ where
#[inline(always)]
#[unroll_for_loops]
/// Same as `sbox_layer` for field extensions of `Self`.
fn sbox_layer_field<F: FieldExtension<D, BaseField = Self>, const D: usize>(
state: &mut [F; WIDTH],
) {
@ -567,18 +537,13 @@ where
}
}
#[inline(always)]
#[unroll_for_loops]
/// Recursive version of `sbox_layer`.
fn sbox_layer_recursive<F: RichField + Extendable<D>, const D: usize>(
builder: &mut CircuitBuilder<F, D>,
state: &mut [ExtensionTarget<D>; WIDTH],
) {
assert!(WIDTH <= 12);
for i in 0..12 {
if i < WIDTH {
state[i] = Self::sbox_monomial_recursive(builder, state[i]);
}
for i in 0..WIDTH {
state[i] = Self::sbox_monomial_recursive(builder, state[i]);
}
}