This commit is contained in:
wborgeaud 2021-09-17 13:29:59 +02:00
parent 5488be2acd
commit e418997d6f
2 changed files with 1 additions and 27 deletions

View File

@ -175,15 +175,8 @@ where
v: &[F; WIDTH],
) -> F {
debug_assert!(r < WIDTH);
// The values of MDS_MATRIX_EXPS are known to be small, so we can
// accumulate all the products for each row and reduce just once
// at the end (done by the caller).
// NB: Unrolling this, calculating each term independently, and
// summing at the end, didn't improve performance for me.
let mut res = F::ZERO;
// This is a hacky way of fully unrolling the loop.
assert!(WIDTH <= 12);
for i in 0..12 {
if i < WIDTH {
@ -203,15 +196,8 @@ where
) -> ExtensionTarget<D> {
let one = builder.one_extension();
debug_assert!(r < WIDTH);
// The values of MDS_MATRIX_EXPS are known to be small, so we can
// accumulate all the products for each row and reduce just once
// at the end (done by the caller).
// NB: Unrolling this, calculating each term independently, and
// summing at the end, didn't improve performance for me.
let mut res = builder.zero_extension();
// This is a hacky way of fully unrolling the loop.
assert!(WIDTH <= 12);
for i in 0..12 {
if i < WIDTH {
@ -256,7 +242,6 @@ where
) -> [F; WIDTH] {
let mut result = [F::ZERO; WIDTH];
// This is a hacky way of fully unrolling the loop.
assert!(WIDTH <= 12);
for r in 0..12 {
if r < WIDTH {
@ -275,7 +260,6 @@ where
) -> [ExtensionTarget<D>; WIDTH] {
let mut result = [builder.zero_extension(); WIDTH];
// This is a hacky way of fully unrolling the loop.
assert!(WIDTH <= 12);
for r in 0..12 {
if r < WIDTH {
@ -361,9 +345,6 @@ where
let one = builder.one_extension();
let mut result = [builder.zero_extension(); WIDTH];
// Initial matrix has first row/column = [1, 0, ..., 0];
// c = 0
result[0] = state[0];
assert!(WIDTH <= 12);
@ -372,9 +353,6 @@ where
assert!(WIDTH <= 12);
for r in 1..12 {
if r < WIDTH {
// NB: FAST_PARTIAL_ROUND_INITIAL_MATRIX is stored in
// column-major order so that this dot product is cache
// friendly.
let t = F::from_canonical_u64(
Self::FAST_PARTIAL_ROUND_INITIAL_MATRIX[c - 1][r - 1],
);
@ -429,8 +407,6 @@ where
state: &[F; WIDTH],
r: usize,
) -> [F; WIDTH] {
// Set d = [M_00 | w^] dot [state]
let s0 = state[0];
let mut d = s0 * F::from_canonical_u64(1 << Self::MDS_MATRIX_EXPS[0]);
assert!(WIDTH <= 12);
@ -464,7 +440,6 @@ where
let zero = builder.zero_extension();
let one = builder.one_extension();
// Set d = [M_00 | w^] dot [state]
let s0 = state[0];
let mut d = builder.arithmetic_extension(
F::from_canonical_u64(1 << Self::MDS_MATRIX_EXPS[0]),
@ -481,7 +456,6 @@ where
}
}
// result = [d] concat [state[0] * v + state[shift up by 1]]
let mut result = [zero; WIDTH];
result[0] = d;
assert!(WIDTH <= 12);

View File

@ -377,7 +377,7 @@ mod tests {
}
let config = CircuitConfig::large_config();
let mut builder = CircuitBuilder::<F, 4>::new(config.clone());
let mut builder = CircuitBuilder::<F, 4>::new(config);
let mut recursive_challenger = RecursiveChallenger::new(&mut builder);
let mut recursive_outputs_per_round: Vec<Vec<Target>> = Vec::new();
for (r, inputs) in inputs_per_round.iter().enumerate() {