From 14bbf5ae11baa27a76516ea79c7dbef6a1020753 Mon Sep 17 00:00:00 2001 From: wborgeaud Date: Fri, 17 Sep 2021 17:50:43 +0200 Subject: [PATCH] Fix AVX2 conflict --- src/gates/poseidon.rs | 24 ++++++++++++------------ src/hash/poseidon.rs | 30 +++++++++++++++++++++++++++--- 2 files changed, 39 insertions(+), 15 deletions(-) diff --git a/src/gates/poseidon.rs b/src/gates/poseidon.rs index c918c17a..651ac4fe 100644 --- a/src/gates/poseidon.rs +++ b/src/gates/poseidon.rs @@ -118,13 +118,13 @@ where // First set of full rounds. for r in 0..poseidon::HALF_N_FULL_ROUNDS { - >::constant_layer(&mut state, round_ctr); + >::constant_layer_field(&mut state, round_ctr); for i in 0..WIDTH { let sbox_in = vars.local_wires[Self::wire_full_sbox_0(r, i)]; constraints.push(state[i] - sbox_in); state[i] = sbox_in; } - >::sbox_layer(&mut state); + >::sbox_layer_field(&mut state); state = >::mds_layer_field(&state); round_ctr += 1; } @@ -152,13 +152,13 @@ where // Second set of full rounds. for r in 0..poseidon::HALF_N_FULL_ROUNDS { - >::constant_layer(&mut state, round_ctr); + >::constant_layer_field(&mut state, round_ctr); for i in 0..WIDTH { let sbox_in = vars.local_wires[Self::wire_full_sbox_1(r, i)]; constraints.push(state[i] - sbox_in); state[i] = sbox_in; } - >::sbox_layer(&mut state); + >::sbox_layer_field(&mut state); state = >::mds_layer_field(&state); round_ctr += 1; } @@ -197,13 +197,13 @@ where // First set of full rounds. for r in 0..poseidon::HALF_N_FULL_ROUNDS { - >::constant_layer(&mut state, round_ctr); + >::constant_layer_field(&mut state, round_ctr); for i in 0..WIDTH { let sbox_in = vars.local_wires[Self::wire_full_sbox_0(r, i)]; constraints.push(state[i] - sbox_in); state[i] = sbox_in; } - >::sbox_layer(&mut state); + >::sbox_layer_field(&mut state); state = >::mds_layer_field(&state); round_ctr += 1; } @@ -230,13 +230,13 @@ where // Second set of full rounds. for r in 0..poseidon::HALF_N_FULL_ROUNDS { - >::constant_layer(&mut state, round_ctr); + >::constant_layer_field(&mut state, round_ctr); for i in 0..WIDTH { let sbox_in = vars.local_wires[Self::wire_full_sbox_1(r, i)]; constraints.push(state[i] - sbox_in); state[i] = sbox_in; } - >::sbox_layer(&mut state); + >::sbox_layer_field(&mut state); state = >::mds_layer_field(&state); round_ctr += 1; } @@ -424,14 +424,14 @@ where let mut round_ctr = 0; for r in 0..poseidon::HALF_N_FULL_ROUNDS { - >::constant_layer(&mut state, round_ctr); + >::constant_layer_field(&mut state, round_ctr); for i in 0..WIDTH { out_buffer.set_wire( local_wire(PoseidonGate::::wire_full_sbox_0(r, i)), state[i], ); } - >::sbox_layer(&mut state); + >::sbox_layer_field(&mut state); state = >::mds_layer_field(&state); round_ctr += 1; } @@ -462,14 +462,14 @@ where round_ctr += poseidon::N_PARTIAL_ROUNDS; for r in 0..poseidon::HALF_N_FULL_ROUNDS { - >::constant_layer(&mut state, round_ctr); + >::constant_layer_field(&mut state, round_ctr); for i in 0..WIDTH { out_buffer.set_wire( local_wire(PoseidonGate::::wire_full_sbox_1(r, i)), state[i], ); } - >::sbox_layer(&mut state); + >::sbox_layer_field(&mut state); state = >::mds_layer_field(&state); round_ctr += 1; } diff --git a/src/hash/poseidon.rs b/src/hash/poseidon.rs index f11c8bf0..e3b6e1e0 100644 --- a/src/hash/poseidon.rs +++ b/src/hash/poseidon.rs @@ -9,7 +9,7 @@ use unroll::unroll_for_loops; use crate::field::crandall_field::CrandallField; use crate::field::extension_field::target::ExtensionTarget; use crate::field::extension_field::{Extendable, FieldExtension}; -use crate::field::field_types::{Field, PrimeField, RichField}; +use crate::field::field_types::{PrimeField, RichField}; use crate::plonk::circuit_builder::CircuitBuilder; // The number of full rounds and partial rounds is given by the @@ -478,7 +478,18 @@ where #[inline(always)] #[unroll_for_loops] - fn constant_layer, const D: usize>( + fn constant_layer(state: &mut [Self; WIDTH], round_ctr: usize) { + assert!(WIDTH <= 12); + for i in 0..12 { + if i < WIDTH { + state[i] += Self::from_canonical_u64(ALL_ROUND_CONSTANTS[i + WIDTH * round_ctr]); + } + } + } + + #[inline(always)] + #[unroll_for_loops] + fn constant_layer_field, const D: usize>( state: &mut [F; WIDTH], round_ctr: usize, ) { @@ -534,7 +545,20 @@ where #[inline(always)] #[unroll_for_loops] - fn sbox_layer, const D: usize>(state: &mut [F; WIDTH]) { + fn sbox_layer(state: &mut [Self; WIDTH]) { + assert!(WIDTH <= 12); + for i in 0..12 { + if i < WIDTH { + state[i] = Self::sbox_monomial(state[i]); + } + } + } + + #[inline(always)] + #[unroll_for_loops] + fn sbox_layer_field, const D: usize>( + state: &mut [F; WIDTH], + ) { assert!(WIDTH <= 12); for i in 0..12 { if i < WIDTH {