From b11e54d6edea362882f9dbc4cd8e9aab6e8de24f Mon Sep 17 00:00:00 2001 From: wborgeaud Date: Thu, 16 Sep 2021 17:51:07 +0200 Subject: [PATCH 01/12] Semi-working --- src/gates/gmimc.rs | 3 +- src/gates/mod.rs | 1 + src/hash/poseidon.rs | 111 +++++++++++++++++++++++++++++++++++++------ 3 files changed, 99 insertions(+), 16 deletions(-) diff --git a/src/gates/gmimc.rs b/src/gates/gmimc.rs index 9df6360c..225af379 100644 --- a/src/gates/gmimc.rs +++ b/src/gates/gmimc.rs @@ -13,8 +13,7 @@ use crate::iop::witness::{PartitionWitness, Witness}; use crate::plonk::circuit_builder::CircuitBuilder; use crate::plonk::vars::{EvaluationTargets, EvaluationVars, EvaluationVarsBase}; -/// Evaluates a full GMiMC permutation with 12 state elements, and writes the output to the next -/// gate's first `width` wires (which could be the input of another `GMiMCGate`). +/// Evaluates a full GMiMC permutation with 12 state elements. /// /// This also has some extra features to make it suitable for efficiently verifying Merkle proofs. /// It has a flag which can be used to swap the first four inputs with the next four, for ordering diff --git a/src/gates/mod.rs b/src/gates/mod.rs index 993623c3..a2df7cff 100644 --- a/src/gates/mod.rs +++ b/src/gates/mod.rs @@ -11,6 +11,7 @@ pub mod gmimc; pub mod insertion; pub mod interpolation; pub mod noop; +pub mod poseidon; pub(crate) mod public_input; pub mod random_access; pub mod reducing; diff --git a/src/hash/poseidon.rs b/src/hash/poseidon.rs index cec6ea71..4449e89f 100644 --- a/src/hash/poseidon.rs +++ b/src/hash/poseidon.rs @@ -7,7 +7,8 @@ use std::convert::TryInto; use unroll::unroll_for_loops; use crate::field::crandall_field::CrandallField; -use crate::field::field_types::PrimeField; +use crate::field::extension_field::FieldExtension; +use crate::field::field_types::{Field, PrimeField}; // The number of full rounds and partial rounds is given by the // calc_round_numbers.py script. They happen to be the same for both @@ -15,9 +16,9 @@ use crate::field::field_types::PrimeField; // // NB: Changing any of these values will require regenerating all of // the precomputed constant arrays in this file. -const HALF_N_FULL_ROUNDS: usize = 4; -const N_FULL_ROUNDS_TOTAL: usize = 2 * HALF_N_FULL_ROUNDS; -const N_PARTIAL_ROUNDS: usize = 22; +pub const HALF_N_FULL_ROUNDS: usize = 4; +pub const N_FULL_ROUNDS_TOTAL: usize = 2 * HALF_N_FULL_ROUNDS; +pub const N_PARTIAL_ROUNDS: usize = 22; const N_ROUNDS: usize = N_FULL_ROUNDS_TOTAL + N_PARTIAL_ROUNDS; const MAX_WIDTH: usize = 12; // we only have width 8 and 12, and 12 is bigger. :) @@ -25,7 +26,7 @@ const MAX_WIDTH: usize = 12; // we only have width 8 and 12, and 12 is bigger. : /// `generate_constants` about how these were generated. We include enough for a WIDTH of 12; /// smaller widths just use a subset. #[rustfmt::skip] -const ALL_ROUND_CONSTANTS: [u64; MAX_WIDTH * N_ROUNDS] = [ +pub const ALL_ROUND_CONSTANTS: [u64; MAX_WIDTH * N_ROUNDS] = [ // WARNING: These must be in 0..CrandallField::ORDER (i.e. canonical form). If this condition is // not met, some platform-specific implementation of constant_layer may return incorrect // results. @@ -165,6 +166,32 @@ where res } + #[inline(always)] + #[unroll_for_loops] + fn mds_row_shf_field, const D: usize>( + r: usize, + v: &[F; WIDTH], + ) -> F { + debug_assert!(r < WIDTH); + // The values of MDS_MATRIX_EXPS are known to be small, so we can + // accumulate all the products for each row and reduce just once + // at the end (done by the caller). + + // NB: Unrolling this, calculating each term independently, and + // summing at the end, didn't improve performance for me. + let mut res = F::ZERO; + + // This is a hacky way of fully unrolling the loop. + assert!(WIDTH <= 12); + for i in 0..12 { + if i < WIDTH { + res += v[(i + r) % WIDTH] * F::from_canonical_u64(1 << Self::MDS_MATRIX_EXPS[i]); + } + } + + res + } + #[inline(always)] #[unroll_for_loops] fn mds_layer(state_: &[Self; WIDTH]) -> [Self; WIDTH] { @@ -188,19 +215,41 @@ where #[inline(always)] #[unroll_for_loops] - fn partial_first_constant_layer(state: &mut [Self; WIDTH]) { + fn mds_layer_field, const D: usize>( + state: &[F; WIDTH], + ) -> [F; WIDTH] { + let mut result = [F::ZERO; WIDTH]; + + // This is a hacky way of fully unrolling the loop. + assert!(WIDTH <= 12); + for r in 0..12 { + if r < WIDTH { + result[r] = Self::mds_row_shf_field(r, state); + } + } + + result + } + + #[inline(always)] + #[unroll_for_loops] + fn partial_first_constant_layer, const D: usize>( + state: &mut [F; WIDTH], + ) { assert!(WIDTH <= 12); for i in 0..12 { if i < WIDTH { - state[i] += Self::from_canonical_u64(Self::FAST_PARTIAL_FIRST_ROUND_CONSTANT[i]); + state[i] += F::from_canonical_u64(Self::FAST_PARTIAL_FIRST_ROUND_CONSTANT[i]); } } } #[inline(always)] #[unroll_for_loops] - fn mds_partial_layer_init(state: &[Self; WIDTH]) -> [Self; WIDTH] { - let mut result = [Self::ZERO; WIDTH]; + fn mds_partial_layer_init, const D: usize>( + state: &[F; WIDTH], + ) -> [F; WIDTH] { + let mut result = [F::ZERO; WIDTH]; // Initial matrix has first row/column = [1, 0, ..., 0]; @@ -216,7 +265,7 @@ where // NB: FAST_PARTIAL_ROUND_INITIAL_MATRIX is stored in // column-major order so that this dot product is cache // friendly. - let t = Self::from_canonical_u64( + let t = F::from_canonical_u64( Self::FAST_PARTIAL_ROUND_INITIAL_MATRIX[c - 1][r - 1], ); result[c] += state[r] * t; @@ -265,17 +314,51 @@ where #[inline(always)] #[unroll_for_loops] - fn constant_layer(state: &mut [Self; WIDTH], round_ctr: usize) { + fn mds_partial_layer_fast_field, const D: usize>( + state: &[F; WIDTH], + r: usize, + ) -> [F; WIDTH] { + // Set d = [M_00 | w^] dot [state] + + let s0 = state[0]; + let mut d = s0 * F::from_canonical_u64(1 << Self::MDS_MATRIX_EXPS[0]); + assert!(WIDTH <= 12); + for i in 1..12 { + if i < WIDTH { + let t = F::from_canonical_u64(Self::FAST_PARTIAL_ROUND_W_HATS[r][i - 1]); + d += state[i] * t; + } + } + + // result = [d] concat [state[0] * v + state[shift up by 1]] + let mut result = [F::ZERO; WIDTH]; + result[0] = d; + assert!(WIDTH <= 12); + for i in 1..12 { + if i < WIDTH { + let t = F::from_canonical_u64(Self::FAST_PARTIAL_ROUND_VS[r][i - 1]); + result[i] = state[0] * t + state[i]; + } + } + result + } + + #[inline(always)] + #[unroll_for_loops] + fn constant_layer, const D: usize>( + state: &mut [F; WIDTH], + round_ctr: usize, + ) { assert!(WIDTH <= 12); for i in 0..12 { if i < WIDTH { - state[i] += Self::from_canonical_u64(ALL_ROUND_CONSTANTS[i + WIDTH * round_ctr]); + state[i] += F::from_canonical_u64(ALL_ROUND_CONSTANTS[i + WIDTH * round_ctr]); } } } #[inline(always)] - fn sbox_monomial(x: Self) -> Self { + fn sbox_monomial, const D: usize>(x: F) -> F { // x |--> x^7 let x2 = x * x; let x4 = x2 * x2; @@ -285,7 +368,7 @@ where #[inline(always)] #[unroll_for_loops] - fn sbox_layer(state: &mut [Self; WIDTH]) { + fn sbox_layer, const D: usize>(state: &mut [F; WIDTH]) { assert!(WIDTH <= 12); for i in 0..12 { if i < WIDTH { From 49ba7ccb52e8ffaa909326800ebfc381310961d9 Mon Sep 17 00:00:00 2001 From: wborgeaud Date: Thu, 16 Sep 2021 18:16:05 +0200 Subject: [PATCH 02/12] Working --- src/gates/poseidon.rs | 407 ++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 407 insertions(+) create mode 100644 src/gates/poseidon.rs diff --git a/src/gates/poseidon.rs b/src/gates/poseidon.rs new file mode 100644 index 00000000..2d333696 --- /dev/null +++ b/src/gates/poseidon.rs @@ -0,0 +1,407 @@ +use std::convert::TryInto; +use std::marker::PhantomData; + +use crate::field::extension_field::target::ExtensionTarget; +use crate::field::extension_field::Extendable; +use crate::field::field_types::{Field, RichField}; +use crate::gates::gate::Gate; +use crate::hash::poseidon; +use crate::hash::poseidon::Poseidon; +use crate::iop::generator::{GeneratedValues, SimpleGenerator, WitnessGenerator}; +use crate::iop::target::Target; +use crate::iop::wire::Wire; +use crate::iop::witness::{PartitionWitness, Witness}; +use crate::plonk::circuit_builder::CircuitBuilder; +use crate::plonk::vars::{EvaluationTargets, EvaluationVars, EvaluationVarsBase}; + +/// Evaluates a full Poseidon permutation with 12 state elements. +/// +/// This also has some extra features to make it suitable for efficiently verifying Merkle proofs. +/// It has a flag which can be used to swap the first four inputs with the next four, for ordering +/// sibling digests. It also has an accumulator that computes the weighted sum of these flags, for +/// computing the index of the leaf based on these swap bits. +#[derive(Debug)] +pub struct PoseidonGate< + F: RichField + Extendable + Poseidon, + const D: usize, + const WIDTH: usize, +> where + [(); WIDTH - 1]: , +{ + _phantom: PhantomData, +} + +impl + Poseidon, const D: usize, const WIDTH: usize> + PoseidonGate +where + [(); WIDTH - 1]: , +{ + pub fn new() -> Self { + PoseidonGate { + _phantom: PhantomData, + } + } + + /// The wire index for the `i`th input to the permutation. + pub fn wire_input(i: usize) -> usize { + i + } + + /// The wire index for the `i`th output to the permutation. + pub fn wire_output(i: usize) -> usize { + WIDTH + i + } + + /// If this is set to 1, the first four inputs will be swapped with the next four inputs. This + /// is useful for ordering hashes in Merkle proofs. Otherwise, this should be set to 0. + pub const WIRE_SWAP: usize = 2 * WIDTH; + + /// A wire which stores the input of the `i`-th S-box of the `round`-th round of the first set + /// of full rounds. + fn wire_full_sbox_0(round: usize, i: usize) -> usize { + 2 * WIDTH + 1 + WIDTH * round + i + } + + /// A wire which stores the input of the S-box of the `round`-th round of the partial rounds. + fn wire_partial_sbox(round: usize) -> usize { + 2 * WIDTH + 1 + WIDTH * poseidon::HALF_N_FULL_ROUNDS + round + } + + /// A wire which stores the input of the `i`-th S-box of the `round`-th round of the second set + /// of full rounds. + fn wire_full_sbox_1(round: usize, i: usize) -> usize { + 2 * WIDTH + + 1 + + WIDTH * (poseidon::HALF_N_FULL_ROUNDS + round) + + poseidon::N_PARTIAL_ROUNDS + + i + } + + /// End of wire indices, exclusive. + fn end() -> usize { + 2 * WIDTH + 1 + WIDTH * poseidon::N_FULL_ROUNDS_TOTAL + poseidon::N_PARTIAL_ROUNDS + } +} + +impl + Poseidon, const D: usize, const WIDTH: usize> Gate + for PoseidonGate +where + [(); WIDTH - 1]: , +{ + fn id(&self) -> String { + format!(" {:?}", WIDTH, self) + } + + fn eval_unfiltered(&self, vars: EvaluationVars) -> Vec { + let mut constraints = Vec::with_capacity(self.num_constraints()); + + // Assert that `swap` is binary. + let swap = vars.local_wires[Self::WIRE_SWAP]; + constraints.push(swap * (swap - F::Extension::ONE)); + + let mut state = Vec::with_capacity(12); + for i in 0..4 { + let a = vars.local_wires[i]; + let b = vars.local_wires[i + 4]; + state.push(a + swap * (b - a)); + } + for i in 0..4 { + let a = vars.local_wires[i + 4]; + let b = vars.local_wires[i]; + state.push(a + swap * (b - a)); + } + for i in 8..12 { + state.push(vars.local_wires[i]); + } + + let mut state: [F::Extension; WIDTH] = state.try_into().unwrap(); + let mut round_ctr = 0; + + for r in 0..poseidon::HALF_N_FULL_ROUNDS { + >::constant_layer(&mut state, round_ctr); + for i in 0..WIDTH { + let sbox_in = vars.local_wires[Self::wire_full_sbox_0(r, i)]; + constraints.push(state[i] - sbox_in); + state[i] = sbox_in; + } + >::sbox_layer(&mut state); + state = >::mds_layer_field(&state); + round_ctr += 1; + } + + >::partial_first_constant_layer(&mut state); + state = >::mds_partial_layer_init(&mut state); + // for r in 0..(poseidon::N_PARTIAL_ROUNDS - 1) { + for r in 0..(poseidon::N_PARTIAL_ROUNDS - 1) { + let sbox_in = vars.local_wires[Self::wire_partial_sbox(r)]; + constraints.push(state[0] - sbox_in); + state[0] = >::sbox_monomial(sbox_in); + state[0] += F::Extension::from_canonical_u64( + >::FAST_PARTIAL_ROUND_CONSTANTS[r], + ); + state = >::mds_partial_layer_fast_field(&state, r); + } + let sbox_in = vars.local_wires[Self::wire_partial_sbox(poseidon::N_PARTIAL_ROUNDS - 1)]; + constraints.push(state[0] - sbox_in); + state[0] = >::sbox_monomial(sbox_in); + state = >::mds_partial_layer_fast_field( + &state, + poseidon::N_PARTIAL_ROUNDS - 1, + ); + round_ctr += poseidon::N_PARTIAL_ROUNDS; + + for r in 0..poseidon::HALF_N_FULL_ROUNDS { + >::constant_layer(&mut state, round_ctr); + for i in 0..WIDTH { + let sbox_in = vars.local_wires[Self::wire_full_sbox_1(r, i)]; + constraints.push(state[i] - sbox_in); + state[i] = sbox_in; + } + >::sbox_layer(&mut state); + state = >::mds_layer_field(&state); + round_ctr += 1; + } + + for i in 0..WIDTH { + constraints.push(state[i] - vars.local_wires[Self::wire_output(i)]); + } + + constraints + } + + fn eval_unfiltered_base(&self, vars: EvaluationVarsBase) -> Vec { + todo!() + } + + fn eval_unfiltered_recursively( + &self, + builder: &mut CircuitBuilder, + vars: EvaluationTargets, + ) -> Vec> { + todo!() + } + + fn generators( + &self, + gate_index: usize, + _local_constants: &[F], + ) -> Vec>> { + let gen = PoseidonGenerator:: { + gate_index, + _phantom: PhantomData, + }; + vec![Box::new(gen.adapter())] + } + + fn num_wires(&self) -> usize { + Self::end() + } + + fn num_constants(&self) -> usize { + 0 + } + + fn degree(&self) -> usize { + 7 + } + + fn num_constraints(&self) -> usize { + WIDTH * poseidon::N_FULL_ROUNDS_TOTAL + poseidon::N_PARTIAL_ROUNDS + WIDTH + 1 + } +} + +#[derive(Debug)] +struct PoseidonGenerator< + F: RichField + Extendable + Poseidon, + const D: usize, + const WIDTH: usize, +> where + [(); WIDTH - 1]: , +{ + gate_index: usize, + _phantom: PhantomData, +} + +impl + Poseidon, const D: usize, const WIDTH: usize> + SimpleGenerator for PoseidonGenerator +where + [(); WIDTH - 1]: , +{ + fn dependencies(&self) -> Vec { + (0..WIDTH) + .map(|i| PoseidonGate::::wire_input(i)) + .chain(Some(PoseidonGate::::WIRE_SWAP)) + .map(|input| Target::wire(self.gate_index, input)) + .collect() + } + + fn run_once(&self, witness: &PartitionWitness, out_buffer: &mut GeneratedValues) { + let local_wire = |input| Wire { + gate: self.gate_index, + input, + }; + + let mut state = (0..WIDTH) + .map(|i| { + witness.get_wire(Wire { + gate: self.gate_index, + input: PoseidonGate::::wire_input(i), + }) + }) + .collect::>(); + + let swap_value = witness.get_wire(Wire { + gate: self.gate_index, + input: PoseidonGate::::WIRE_SWAP, + }); + debug_assert!(swap_value == F::ZERO || swap_value == F::ONE); + if swap_value == F::ONE { + for i in 0..4 { + state.swap(i, 4 + i); + } + } + + let mut state: [F; WIDTH] = state.try_into().unwrap(); + let mut round_ctr = 0; + + for r in 0..poseidon::HALF_N_FULL_ROUNDS { + >::constant_layer(&mut state, round_ctr); + for i in 0..WIDTH { + out_buffer.set_wire( + local_wire(PoseidonGate::::wire_full_sbox_0(r, i)), + state[i], + ); + } + >::sbox_layer(&mut state); + state = >::mds_layer_field(&state); + round_ctr += 1; + } + + >::partial_first_constant_layer(&mut state); + state = >::mds_partial_layer_init(&mut state); + for r in 0..(poseidon::N_PARTIAL_ROUNDS - 1) { + out_buffer.set_wire( + local_wire(PoseidonGate::::wire_partial_sbox(r)), + state[0], + ); + state[0] = >::sbox_monomial(state[0]); + state[0] += + F::from_canonical_u64(>::FAST_PARTIAL_ROUND_CONSTANTS[r]); + state = >::mds_partial_layer_fast_field(&state, r); + } + out_buffer.set_wire( + local_wire(PoseidonGate::::wire_partial_sbox( + poseidon::N_PARTIAL_ROUNDS - 1, + )), + state[0], + ); + state[0] = >::sbox_monomial(state[0]); + state = >::mds_partial_layer_fast_field( + &state, + poseidon::N_PARTIAL_ROUNDS - 1, + ); + round_ctr += poseidon::N_PARTIAL_ROUNDS; + + for r in 0..poseidon::HALF_N_FULL_ROUNDS { + >::constant_layer(&mut state, round_ctr); + for i in 0..WIDTH { + out_buffer.set_wire( + local_wire(PoseidonGate::::wire_full_sbox_1(r, i)), + state[i], + ); + } + >::sbox_layer(&mut state); + state = >::mds_layer_field(&state); + round_ctr += 1; + } + + for i in 0..WIDTH { + out_buffer.set_wire( + local_wire(PoseidonGate::::wire_output(i)), + state[i], + ); + } + } +} + +mod yo {} +#[cfg(test)] +mod tests { + use std::convert::TryInto; + + use anyhow::Result; + + use crate::field::crandall_field::CrandallField; + use crate::field::field_types::Field; + use crate::gates::gate_testing::{test_eval_fns, test_low_degree}; + use crate::gates::poseidon::PoseidonGate; + use crate::hash::poseidon::Poseidon; + use crate::iop::generator::generate_partial_witness; + use crate::iop::wire::Wire; + use crate::iop::witness::{PartialWitness, Witness}; + use crate::plonk::circuit_builder::CircuitBuilder; + use crate::plonk::circuit_data::CircuitConfig; + + #[test] + fn generated_output() { + type F = CrandallField; + const WIDTH: usize = 12; + + let config = CircuitConfig { + num_wires: 143, + ..CircuitConfig::large_config() + }; + let mut builder = CircuitBuilder::new(config); + type Gate = PoseidonGate; + let gate = Gate::new(); + let gate_index = builder.add_gate(gate, vec![]); + let circuit = builder.build_prover(); + + let permutation_inputs = (0..WIDTH).map(F::from_canonical_usize).collect::>(); + + let mut inputs = PartialWitness::new(); + inputs.set_wire( + Wire { + gate: gate_index, + input: Gate::WIRE_SWAP, + }, + F::ZERO, + ); + for i in 0..WIDTH { + inputs.set_wire( + Wire { + gate: gate_index, + input: Gate::wire_input(i), + }, + permutation_inputs[i], + ); + } + + let witness = generate_partial_witness(inputs, &circuit.prover_only); + + let expected_outputs: [F; WIDTH] = F::poseidon(permutation_inputs.try_into().unwrap()); + for i in 0..WIDTH { + let out = witness.get_wire(Wire { + gate: 0, + input: Gate::wire_output(i), + }); + assert_eq!(out, expected_outputs[i]); + } + } + + #[test] + fn low_degree() { + type F = CrandallField; + const WIDTH: usize = 12; + let gate = PoseidonGate::::new(); + test_low_degree(gate) + } + + #[test] + fn eval_fns() -> Result<()> { + type F = CrandallField; + const WIDTH: usize = 12; + let gate = PoseidonGate::::new(); + test_eval_fns(gate) + } +} From c508fe4362327d94b0b408054ef9cd046b653ec9 Mon Sep 17 00:00:00 2001 From: wborgeaud Date: Thu, 16 Sep 2021 18:16:19 +0200 Subject: [PATCH 03/12] Minor --- src/gates/poseidon.rs | 1 - 1 file changed, 1 deletion(-) diff --git a/src/gates/poseidon.rs b/src/gates/poseidon.rs index 2d333696..371ac0a5 100644 --- a/src/gates/poseidon.rs +++ b/src/gates/poseidon.rs @@ -324,7 +324,6 @@ where } } -mod yo {} #[cfg(test)] mod tests { use std::convert::TryInto; From 5d7f4de2a67742bbd7009484ce6f0148680b04bc Mon Sep 17 00:00:00 2001 From: wborgeaud Date: Thu, 16 Sep 2021 19:17:37 +0200 Subject: [PATCH 04/12] Working recursively --- src/gates/poseidon.rs | 158 +++++++++++++++++++++++++++++++- src/hash/poseidon.rs | 203 +++++++++++++++++++++++++++++++++++++++++- 2 files changed, 356 insertions(+), 5 deletions(-) diff --git a/src/gates/poseidon.rs b/src/gates/poseidon.rs index 371ac0a5..e8fa6ccf 100644 --- a/src/gates/poseidon.rs +++ b/src/gates/poseidon.rs @@ -131,7 +131,6 @@ where >::partial_first_constant_layer(&mut state); state = >::mds_partial_layer_init(&mut state); - // for r in 0..(poseidon::N_PARTIAL_ROUNDS - 1) { for r in 0..(poseidon::N_PARTIAL_ROUNDS - 1) { let sbox_in = vars.local_wires[Self::wire_partial_sbox(r)]; constraints.push(state[0] - sbox_in); @@ -170,7 +169,78 @@ where } fn eval_unfiltered_base(&self, vars: EvaluationVarsBase) -> Vec { - todo!() + let mut constraints = Vec::with_capacity(self.num_constraints()); + + // Assert that `swap` is binary. + let swap = vars.local_wires[Self::WIRE_SWAP]; + constraints.push(swap * (swap - F::ONE)); + + let mut state = Vec::with_capacity(12); + for i in 0..4 { + let a = vars.local_wires[i]; + let b = vars.local_wires[i + 4]; + state.push(a + swap * (b - a)); + } + for i in 0..4 { + let a = vars.local_wires[i + 4]; + let b = vars.local_wires[i]; + state.push(a + swap * (b - a)); + } + for i in 8..12 { + state.push(vars.local_wires[i]); + } + + let mut state: [F; WIDTH] = state.try_into().unwrap(); + let mut round_ctr = 0; + + for r in 0..poseidon::HALF_N_FULL_ROUNDS { + >::constant_layer(&mut state, round_ctr); + for i in 0..WIDTH { + let sbox_in = vars.local_wires[Self::wire_full_sbox_0(r, i)]; + constraints.push(state[i] - sbox_in); + state[i] = sbox_in; + } + >::sbox_layer(&mut state); + state = >::mds_layer_field(&state); + round_ctr += 1; + } + + >::partial_first_constant_layer(&mut state); + state = >::mds_partial_layer_init(&mut state); + for r in 0..(poseidon::N_PARTIAL_ROUNDS - 1) { + let sbox_in = vars.local_wires[Self::wire_partial_sbox(r)]; + constraints.push(state[0] - sbox_in); + state[0] = >::sbox_monomial(sbox_in); + state[0] += + F::from_canonical_u64(>::FAST_PARTIAL_ROUND_CONSTANTS[r]); + state = >::mds_partial_layer_fast_field(&state, r); + } + let sbox_in = vars.local_wires[Self::wire_partial_sbox(poseidon::N_PARTIAL_ROUNDS - 1)]; + constraints.push(state[0] - sbox_in); + state[0] = >::sbox_monomial(sbox_in); + state = >::mds_partial_layer_fast_field( + &state, + poseidon::N_PARTIAL_ROUNDS - 1, + ); + round_ctr += poseidon::N_PARTIAL_ROUNDS; + + for r in 0..poseidon::HALF_N_FULL_ROUNDS { + >::constant_layer(&mut state, round_ctr); + for i in 0..WIDTH { + let sbox_in = vars.local_wires[Self::wire_full_sbox_1(r, i)]; + constraints.push(state[i] - sbox_in); + state[i] = sbox_in; + } + >::sbox_layer(&mut state); + state = >::mds_layer_field(&state); + round_ctr += 1; + } + + for i in 0..WIDTH { + constraints.push(state[i] - vars.local_wires[Self::wire_output(i)]); + } + + constraints } fn eval_unfiltered_recursively( @@ -178,7 +248,89 @@ where builder: &mut CircuitBuilder, vars: EvaluationTargets, ) -> Vec> { - todo!() + let one = builder.one_extension(); + let mut constraints = Vec::with_capacity(self.num_constraints()); + + // Assert that `swap` is binary. + let swap = vars.local_wires[Self::WIRE_SWAP]; + constraints.push(builder.mul_sub_extension(swap, swap, swap)); + + let mut state = Vec::with_capacity(12); + for i in 0..4 { + let a = vars.local_wires[i]; + let b = vars.local_wires[i + 4]; + let delta = builder.sub_extension(b, a); + state.push(builder.mul_add_extension(swap, delta, a)); + } + for i in 0..4 { + let a = vars.local_wires[i + 4]; + let b = vars.local_wires[i]; + let delta = builder.sub_extension(b, a); + state.push(builder.mul_add_extension(swap, delta, a)); + } + for i in 8..12 { + state.push(vars.local_wires[i]); + } + + let mut state: [ExtensionTarget; WIDTH] = state.try_into().unwrap(); + let mut round_ctr = 0; + + for r in 0..poseidon::HALF_N_FULL_ROUNDS { + >::constant_layer_recursive(builder, &mut state, round_ctr); + for i in 0..WIDTH { + let sbox_in = vars.local_wires[Self::wire_full_sbox_0(r, i)]; + constraints.push(builder.sub_extension(state[i], sbox_in)); + state[i] = sbox_in; + } + >::sbox_layer_recursive(builder, &mut state); + state = >::mds_layer_recursive(builder, &state); + round_ctr += 1; + } + + >::partial_first_constant_layer_recursive(builder, &mut state); + state = >::mds_partial_layer_init_recursive(builder, &mut state); + for r in 0..(poseidon::N_PARTIAL_ROUNDS - 1) { + let sbox_in = vars.local_wires[Self::wire_partial_sbox(r)]; + constraints.push(builder.sub_extension(state[0], sbox_in)); + state[0] = >::sbox_monomial_recursive(builder, sbox_in); + state[0] = builder.arithmetic_extension( + F::from_canonical_u64(>::FAST_PARTIAL_ROUND_CONSTANTS[r]), + F::ONE, + one, + one, + state[0], + ); + state = + >::mds_partial_layer_fast_field_recursive(builder, &state, r); + } + let sbox_in = vars.local_wires[Self::wire_partial_sbox(poseidon::N_PARTIAL_ROUNDS - 1)]; + constraints.push(builder.sub_extension(state[0], sbox_in)); + state[0] = >::sbox_monomial_recursive(builder, sbox_in); + state = >::mds_partial_layer_fast_field_recursive( + builder, + &state, + poseidon::N_PARTIAL_ROUNDS - 1, + ); + round_ctr += poseidon::N_PARTIAL_ROUNDS; + + for r in 0..poseidon::HALF_N_FULL_ROUNDS { + >::constant_layer_recursive(builder, &mut state, round_ctr); + for i in 0..WIDTH { + let sbox_in = vars.local_wires[Self::wire_full_sbox_1(r, i)]; + constraints.push(builder.sub_extension(state[i], sbox_in)); + state[i] = sbox_in; + } + >::sbox_layer_recursive(builder, &mut state); + state = >::mds_layer_recursive(builder, &state); + round_ctr += 1; + } + + for i in 0..WIDTH { + constraints + .push(builder.sub_extension(state[i], vars.local_wires[Self::wire_output(i)])); + } + + constraints } fn generators( diff --git a/src/hash/poseidon.rs b/src/hash/poseidon.rs index 4449e89f..d3c42bc6 100644 --- a/src/hash/poseidon.rs +++ b/src/hash/poseidon.rs @@ -7,8 +7,10 @@ use std::convert::TryInto; use unroll::unroll_for_loops; use crate::field::crandall_field::CrandallField; -use crate::field::extension_field::FieldExtension; -use crate::field::field_types::{Field, PrimeField}; +use crate::field::extension_field::target::ExtensionTarget; +use crate::field::extension_field::{Extendable, FieldExtension}; +use crate::field::field_types::{Field, PrimeField, RichField}; +use crate::plonk::circuit_builder::CircuitBuilder; // The number of full rounds and partial rounds is given by the // calc_round_numbers.py script. They happen to be the same for both @@ -192,6 +194,40 @@ where res } + #[inline(always)] + #[unroll_for_loops] + fn mds_row_shf_recursive, const D: usize>( + builder: &mut CircuitBuilder, + r: usize, + v: &[ExtensionTarget; WIDTH], + ) -> ExtensionTarget { + let one = builder.one_extension(); + debug_assert!(r < WIDTH); + // The values of MDS_MATRIX_EXPS are known to be small, so we can + // accumulate all the products for each row and reduce just once + // at the end (done by the caller). + + // NB: Unrolling this, calculating each term independently, and + // summing at the end, didn't improve performance for me. + let mut res = builder.zero_extension(); + + // This is a hacky way of fully unrolling the loop. + assert!(WIDTH <= 12); + for i in 0..12 { + if i < WIDTH { + res = builder.arithmetic_extension( + F::from_canonical_u64(1 << Self::MDS_MATRIX_EXPS[i]), + F::ONE, + one, + v[(i + r) % WIDTH], + res, + ); + } + } + + res + } + #[inline(always)] #[unroll_for_loops] fn mds_layer(state_: &[Self; WIDTH]) -> [Self; WIDTH] { @@ -231,6 +267,25 @@ where result } + #[inline(always)] + #[unroll_for_loops] + fn mds_layer_recursive, const D: usize>( + builder: &mut CircuitBuilder, + state: &[ExtensionTarget; WIDTH], + ) -> [ExtensionTarget; WIDTH] { + let mut result = [builder.zero_extension(); WIDTH]; + + // This is a hacky way of fully unrolling the loop. + assert!(WIDTH <= 12); + for r in 0..12 { + if r < WIDTH { + result[r] = Self::mds_row_shf_recursive(builder, r, state); + } + } + + result + } + #[inline(always)] #[unroll_for_loops] fn partial_first_constant_layer, const D: usize>( @@ -244,6 +299,27 @@ where } } + #[inline(always)] + #[unroll_for_loops] + fn partial_first_constant_layer_recursive, const D: usize>( + builder: &mut CircuitBuilder, + state: &mut [ExtensionTarget; WIDTH], + ) { + let one = builder.one_extension(); + assert!(WIDTH <= 12); + for i in 0..12 { + if i < WIDTH { + state[i] = builder.arithmetic_extension( + F::from_canonical_u64(Self::FAST_PARTIAL_FIRST_ROUND_CONSTANT[i]), + F::ONE, + one, + one, + state[i], + ); + } + } + } + #[inline(always)] #[unroll_for_loops] fn mds_partial_layer_init, const D: usize>( @@ -276,6 +352,41 @@ where result } + #[inline(always)] + #[unroll_for_loops] + fn mds_partial_layer_init_recursive, const D: usize>( + builder: &mut CircuitBuilder, + state: &[ExtensionTarget; WIDTH], + ) -> [ExtensionTarget; WIDTH] { + let one = builder.one_extension(); + let mut result = [builder.zero_extension(); WIDTH]; + + // Initial matrix has first row/column = [1, 0, ..., 0]; + + // c = 0 + result[0] = state[0]; + + assert!(WIDTH <= 12); + for c in 1..12 { + if c < WIDTH { + assert!(WIDTH <= 12); + for r in 1..12 { + if r < WIDTH { + // NB: FAST_PARTIAL_ROUND_INITIAL_MATRIX is stored in + // column-major order so that this dot product is cache + // friendly. + let t = F::from_canonical_u64( + Self::FAST_PARTIAL_ROUND_INITIAL_MATRIX[c - 1][r - 1], + ); + result[c] = + builder.arithmetic_extension(t, F::ONE, one, state[r], result[c]); + } + } + } + } + result + } + /// Computes s*A where s is the state row vector and A is the matrix /// /// [ M_00 | v ] @@ -343,6 +454,46 @@ where result } + #[inline(always)] + #[unroll_for_loops] + fn mds_partial_layer_fast_field_recursive, const D: usize>( + builder: &mut CircuitBuilder, + state: &[ExtensionTarget; WIDTH], + r: usize, + ) -> [ExtensionTarget; WIDTH] { + let zero = builder.zero_extension(); + let one = builder.one_extension(); + + // Set d = [M_00 | w^] dot [state] + let s0 = state[0]; + let mut d = builder.arithmetic_extension( + F::from_canonical_u64(1 << Self::MDS_MATRIX_EXPS[0]), + F::ONE, + one, + s0, + zero, + ); + assert!(WIDTH <= 12); + for i in 1..12 { + if i < WIDTH { + let t = F::from_canonical_u64(Self::FAST_PARTIAL_ROUND_W_HATS[r][i - 1]); + d = builder.arithmetic_extension(t, F::ONE, one, state[i], d); + } + } + + // result = [d] concat [state[0] * v + state[shift up by 1]] + let mut result = [zero; WIDTH]; + result[0] = d; + assert!(WIDTH <= 12); + for i in 1..12 { + if i < WIDTH { + let t = F::from_canonical_u64(Self::FAST_PARTIAL_ROUND_VS[r][i - 1]); + result[i] = builder.arithmetic_extension(t, F::ONE, one, state[0], state[i]); + } + } + result + } + #[inline(always)] #[unroll_for_loops] fn constant_layer, const D: usize>( @@ -357,6 +508,28 @@ where } } + #[inline(always)] + #[unroll_for_loops] + fn constant_layer_recursive, const D: usize>( + builder: &mut CircuitBuilder, + state: &mut [ExtensionTarget; WIDTH], + round_ctr: usize, + ) { + let one = builder.one_extension(); + assert!(WIDTH <= 12); + for i in 0..12 { + if i < WIDTH { + state[i] = builder.arithmetic_extension( + F::from_canonical_u64(ALL_ROUND_CONSTANTS[i + WIDTH * round_ctr]), + F::ONE, + one, + one, + state[i], + ); + } + } + } + #[inline(always)] fn sbox_monomial, const D: usize>(x: F) -> F { // x |--> x^7 @@ -366,6 +539,18 @@ where x3 * x4 } + #[inline(always)] + fn sbox_monomial_recursive, const D: usize>( + builder: &mut CircuitBuilder, + x: ExtensionTarget, + ) -> ExtensionTarget { + // x |--> x^7 + let x2 = builder.mul_extension(x, x); + let x4 = builder.mul_extension(x2, x2); + let x3 = builder.mul_extension(x, x2); + builder.mul_extension(x3, x4) + } + #[inline(always)] #[unroll_for_loops] fn sbox_layer, const D: usize>(state: &mut [F; WIDTH]) { @@ -377,6 +562,20 @@ where } } + #[inline(always)] + #[unroll_for_loops] + fn sbox_layer_recursive, const D: usize>( + builder: &mut CircuitBuilder, + state: &mut [ExtensionTarget; WIDTH], + ) { + assert!(WIDTH <= 12); + for i in 0..12 { + if i < WIDTH { + state[i] = Self::sbox_monomial_recursive(builder, state[i]); + } + } + } + #[inline] fn full_rounds(state: &mut [Self; WIDTH], round_ctr: &mut usize) { for _ in 0..HALF_N_FULL_ROUNDS { From b63d83aacff58d914da46f84c84ff8c3f53844b3 Mon Sep 17 00:00:00 2001 From: wborgeaud Date: Thu, 16 Sep 2021 21:18:08 +0200 Subject: [PATCH 05/12] Add Poseidon gadget --- src/gadgets/hash.rs | 44 ++++++++++++++++++++++++++++++--- src/hash/hashing.rs | 2 +- src/plonk/circuit_data.rs | 2 +- src/plonk/recursive_verifier.rs | 2 +- 4 files changed, 44 insertions(+), 6 deletions(-) diff --git a/src/gadgets/hash.rs b/src/gadgets/hash.rs index 01519912..82de1095 100644 --- a/src/gadgets/hash.rs +++ b/src/gadgets/hash.rs @@ -3,8 +3,10 @@ use std::convert::TryInto; use crate::field::extension_field::Extendable; use crate::field::field_types::RichField; use crate::gates::gmimc::GMiMCGate; +use crate::gates::poseidon::PoseidonGate; use crate::hash::gmimc::GMiMC; use crate::hash::hashing::{HashFamily, HASH_FAMILY}; +use crate::hash::poseidon::Poseidon; use crate::iop::target::{BoolTarget, Target}; use crate::iop::wire::Wire; use crate::plonk::circuit_builder::CircuitBuilder; @@ -13,7 +15,8 @@ use crate::plonk::circuit_builder::CircuitBuilder; impl, const D: usize> CircuitBuilder { pub fn permute(&mut self, inputs: [Target; W]) -> [Target; W] where - F: GMiMC, + F: GMiMC + Poseidon, + [(); W - 1]: , { // We don't want to swap any inputs, so set that wire to 0. let _false = self._false(); @@ -28,11 +31,12 @@ impl, const D: usize> CircuitBuilder { swap: BoolTarget, ) -> [Target; W] where - F: GMiMC, + F: GMiMC + Poseidon, + [(); W - 1]: , { match HASH_FAMILY { HashFamily::GMiMC => self.gmimc_permute_swapped(inputs, swap), - HashFamily::Poseidon => todo!(), + HashFamily::Poseidon => self.poseidon_permute_swapped(inputs, swap), } } @@ -79,4 +83,38 @@ impl, const D: usize> CircuitBuilder { .try_into() .unwrap() } + + /// Conditionally swap two chunks of the inputs (useful in verifying Merkle proofs), then apply + /// the Poseidon permutation. + pub(crate) fn poseidon_permute_swapped( + &mut self, + inputs: [Target; W], + swap: BoolTarget, + ) -> [Target; W] + where + F: Poseidon, + [(); W - 1]: , + { + let gate_type = PoseidonGate::::new(); + let gate = self.add_gate(gate_type, vec![]); + + // We don't want to swap any inputs, so set that wire to 0. + let swap_wire = PoseidonGate::::WIRE_SWAP; + let swap_wire = Target::wire(gate, swap_wire); + self.connect(swap.target, swap_wire); + + // Route input wires. + for i in 0..W { + let in_wire = PoseidonGate::::wire_input(i); + let in_wire = Target::wire(gate, in_wire); + self.connect(inputs[i], in_wire); + } + + // Collect output wires. + (0..W) + .map(|i| Target::wire(gate, PoseidonGate::::wire_output(i))) + .collect::>() + .try_into() + .unwrap() + } } diff --git a/src/hash/hashing.rs b/src/hash/hashing.rs index 9b4791d3..ae14e058 100644 --- a/src/hash/hashing.rs +++ b/src/hash/hashing.rs @@ -10,7 +10,7 @@ pub(crate) const SPONGE_RATE: usize = 8; pub(crate) const SPONGE_CAPACITY: usize = 4; pub(crate) const SPONGE_WIDTH: usize = SPONGE_RATE + SPONGE_CAPACITY; -pub(crate) const HASH_FAMILY: HashFamily = HashFamily::GMiMC; +pub(crate) const HASH_FAMILY: HashFamily = HashFamily::Poseidon; pub(crate) enum HashFamily { GMiMC, diff --git a/src/plonk/circuit_data.rs b/src/plonk/circuit_data.rs index 21dfc28b..540e3e84 100644 --- a/src/plonk/circuit_data.rs +++ b/src/plonk/circuit_data.rs @@ -61,7 +61,7 @@ impl CircuitConfig { #[cfg(test)] pub(crate) fn large_config() -> Self { Self { - num_wires: 126, + num_wires: 143, num_routed_wires: 64, security_bits: 128, rate_bits: 3, diff --git a/src/plonk/recursive_verifier.rs b/src/plonk/recursive_verifier.rs index 3bd50bec..6b0bd8c3 100644 --- a/src/plonk/recursive_verifier.rs +++ b/src/plonk/recursive_verifier.rs @@ -416,7 +416,7 @@ mod tests { type F = CrandallField; const D: usize = 4; let config = CircuitConfig { - num_wires: 126, + num_wires: 143, num_routed_wires: 64, security_bits: 128, rate_bits: 3, From 3534018fec6dfc5316df7891238d33f6da525a2c Mon Sep 17 00:00:00 2001 From: wborgeaud Date: Thu, 16 Sep 2021 22:19:54 +0200 Subject: [PATCH 06/12] Remove hardcoded GMiMC --- src/hash/merkle_proofs.rs | 2 +- src/iop/challenger.rs | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/hash/merkle_proofs.rs b/src/hash/merkle_proofs.rs index 22b6f318..793e114c 100644 --- a/src/hash/merkle_proofs.rs +++ b/src/hash/merkle_proofs.rs @@ -74,7 +74,7 @@ impl, const D: usize> CircuitBuilder { .concat() .try_into() .unwrap(); - let outputs = self.gmimc_permute_swapped(inputs, bit); + let outputs = self.permute_swapped(inputs, bit); state = HashOutTarget::from_vec(outputs[0..4].to_vec()); } diff --git a/src/iop/challenger.rs b/src/iop/challenger.rs index 47d57db8..2fb43979 100644 --- a/src/iop/challenger.rs +++ b/src/iop/challenger.rs @@ -4,7 +4,7 @@ use crate::field::extension_field::target::ExtensionTarget; use crate::field::extension_field::{Extendable, FieldExtension}; use crate::field::field_types::RichField; use crate::hash::hash_types::{HashOut, HashOutTarget, MerkleCapTarget}; -use crate::hash::hashing::{SPONGE_RATE, SPONGE_WIDTH}; +use crate::hash::hashing::{permute, SPONGE_RATE, SPONGE_WIDTH}; use crate::hash::merkle_tree::MerkleCap; use crate::iop::target::Target; use crate::plonk::circuit_builder::CircuitBuilder; @@ -105,7 +105,7 @@ impl Challenger { if self.output_buffer.is_empty() { // Evaluate the permutation to produce `r` new outputs. - self.sponge_state = F::gmimc_permute(self.sponge_state); + self.sponge_state = permute(self.sponge_state); self.output_buffer = self.sponge_state[0..SPONGE_RATE].to_vec(); } @@ -160,7 +160,7 @@ impl Challenger { } // Apply the permutation. - self.sponge_state = F::gmimc_permute(self.sponge_state); + self.sponge_state = permute(self.sponge_state); } self.output_buffer = self.sponge_state[0..SPONGE_RATE].to_vec(); From 5488be2acdd86ee200c0fad0557292ed8851e98f Mon Sep 17 00:00:00 2001 From: wborgeaud Date: Fri, 17 Sep 2021 13:15:22 +0200 Subject: [PATCH 07/12] Add HashGate constant type --- src/hash/hashing.rs | 2 ++ src/hash/merkle_proofs.rs | 11 +++++------ src/hash/poseidon.rs | 5 +---- 3 files changed, 8 insertions(+), 10 deletions(-) diff --git a/src/hash/hashing.rs b/src/hash/hashing.rs index ae14e058..00cd0a74 100644 --- a/src/hash/hashing.rs +++ b/src/hash/hashing.rs @@ -2,6 +2,7 @@ use crate::field::extension_field::Extendable; use crate::field::field_types::RichField; +use crate::gates::poseidon::PoseidonGate; use crate::hash::hash_types::{HashOut, HashOutTarget}; use crate::iop::target::Target; use crate::plonk::circuit_builder::CircuitBuilder; @@ -11,6 +12,7 @@ pub(crate) const SPONGE_CAPACITY: usize = 4; pub(crate) const SPONGE_WIDTH: usize = SPONGE_RATE + SPONGE_CAPACITY; pub(crate) const HASH_FAMILY: HashFamily = HashFamily::Poseidon; +pub(crate) type HashGate = PoseidonGate; pub(crate) enum HashFamily { GMiMC, diff --git a/src/hash/merkle_proofs.rs b/src/hash/merkle_proofs.rs index 793e114c..7a176dd9 100644 --- a/src/hash/merkle_proofs.rs +++ b/src/hash/merkle_proofs.rs @@ -6,9 +6,8 @@ use serde::{Deserialize, Serialize}; use crate::field::extension_field::target::ExtensionTarget; use crate::field::extension_field::Extendable; use crate::field::field_types::{Field, RichField}; -use crate::gates::gmimc::GMiMCGate; use crate::hash::hash_types::{HashOut, HashOutTarget, MerkleCapTarget}; -use crate::hash::hashing::{compress, hash_or_noop}; +use crate::hash::hashing::{compress, hash_or_noop, HashGate}; use crate::hash::merkle_tree::MerkleCap; use crate::iop::target::{BoolTarget, Target}; use crate::iop::wire::Wire; @@ -107,10 +106,10 @@ impl, const D: usize> CircuitBuilder { let mut state: HashOutTarget = self.hash_or_noop(leaf_data); for (&bit, &sibling) in leaf_index_bits.iter().zip(&proof.siblings) { - let gate_type = GMiMCGate::::new(); + let gate_type = HashGate::::new(); let gate = self.add_gate(gate_type, vec![]); - let swap_wire = GMiMCGate::::WIRE_SWAP; + let swap_wire = HashGate::::WIRE_SWAP; let swap_wire = Target::Wire(Wire { gate, input: swap_wire, @@ -121,7 +120,7 @@ impl, const D: usize> CircuitBuilder { .map(|i| { Target::Wire(Wire { gate, - input: GMiMCGate::::wire_input(i), + input: HashGate::::wire_input(i), }) }) .collect::>(); @@ -137,7 +136,7 @@ impl, const D: usize> CircuitBuilder { .map(|i| { Target::Wire(Wire { gate, - input: GMiMCGate::::wire_output(i), + input: HashGate::::wire_output(i), }) }) .collect(), diff --git a/src/hash/poseidon.rs b/src/hash/poseidon.rs index d3c42bc6..16f9117d 100644 --- a/src/hash/poseidon.rs +++ b/src/hash/poseidon.rs @@ -545,10 +545,7 @@ where x: ExtensionTarget, ) -> ExtensionTarget { // x |--> x^7 - let x2 = builder.mul_extension(x, x); - let x4 = builder.mul_extension(x2, x2); - let x3 = builder.mul_extension(x, x2); - builder.mul_extension(x3, x4) + builder.exp_u64_extension(x, 7) } #[inline(always)] From e418997d6fa0b46d2246a776712b341405c13e56 Mon Sep 17 00:00:00 2001 From: wborgeaud Date: Fri, 17 Sep 2021 13:29:59 +0200 Subject: [PATCH 08/12] Cleanup --- src/hash/poseidon.rs | 26 -------------------------- src/iop/challenger.rs | 2 +- 2 files changed, 1 insertion(+), 27 deletions(-) diff --git a/src/hash/poseidon.rs b/src/hash/poseidon.rs index 16f9117d..c3a01636 100644 --- a/src/hash/poseidon.rs +++ b/src/hash/poseidon.rs @@ -175,15 +175,8 @@ where v: &[F; WIDTH], ) -> F { debug_assert!(r < WIDTH); - // The values of MDS_MATRIX_EXPS are known to be small, so we can - // accumulate all the products for each row and reduce just once - // at the end (done by the caller). - - // NB: Unrolling this, calculating each term independently, and - // summing at the end, didn't improve performance for me. let mut res = F::ZERO; - // This is a hacky way of fully unrolling the loop. assert!(WIDTH <= 12); for i in 0..12 { if i < WIDTH { @@ -203,15 +196,8 @@ where ) -> ExtensionTarget { let one = builder.one_extension(); debug_assert!(r < WIDTH); - // The values of MDS_MATRIX_EXPS are known to be small, so we can - // accumulate all the products for each row and reduce just once - // at the end (done by the caller). - - // NB: Unrolling this, calculating each term independently, and - // summing at the end, didn't improve performance for me. let mut res = builder.zero_extension(); - // This is a hacky way of fully unrolling the loop. assert!(WIDTH <= 12); for i in 0..12 { if i < WIDTH { @@ -256,7 +242,6 @@ where ) -> [F; WIDTH] { let mut result = [F::ZERO; WIDTH]; - // This is a hacky way of fully unrolling the loop. assert!(WIDTH <= 12); for r in 0..12 { if r < WIDTH { @@ -275,7 +260,6 @@ where ) -> [ExtensionTarget; WIDTH] { let mut result = [builder.zero_extension(); WIDTH]; - // This is a hacky way of fully unrolling the loop. assert!(WIDTH <= 12); for r in 0..12 { if r < WIDTH { @@ -361,9 +345,6 @@ where let one = builder.one_extension(); let mut result = [builder.zero_extension(); WIDTH]; - // Initial matrix has first row/column = [1, 0, ..., 0]; - - // c = 0 result[0] = state[0]; assert!(WIDTH <= 12); @@ -372,9 +353,6 @@ where assert!(WIDTH <= 12); for r in 1..12 { if r < WIDTH { - // NB: FAST_PARTIAL_ROUND_INITIAL_MATRIX is stored in - // column-major order so that this dot product is cache - // friendly. let t = F::from_canonical_u64( Self::FAST_PARTIAL_ROUND_INITIAL_MATRIX[c - 1][r - 1], ); @@ -429,8 +407,6 @@ where state: &[F; WIDTH], r: usize, ) -> [F; WIDTH] { - // Set d = [M_00 | w^] dot [state] - let s0 = state[0]; let mut d = s0 * F::from_canonical_u64(1 << Self::MDS_MATRIX_EXPS[0]); assert!(WIDTH <= 12); @@ -464,7 +440,6 @@ where let zero = builder.zero_extension(); let one = builder.one_extension(); - // Set d = [M_00 | w^] dot [state] let s0 = state[0]; let mut d = builder.arithmetic_extension( F::from_canonical_u64(1 << Self::MDS_MATRIX_EXPS[0]), @@ -481,7 +456,6 @@ where } } - // result = [d] concat [state[0] * v + state[shift up by 1]] let mut result = [zero; WIDTH]; result[0] = d; assert!(WIDTH <= 12); diff --git a/src/iop/challenger.rs b/src/iop/challenger.rs index 2fb43979..d33c19ce 100644 --- a/src/iop/challenger.rs +++ b/src/iop/challenger.rs @@ -377,7 +377,7 @@ mod tests { } let config = CircuitConfig::large_config(); - let mut builder = CircuitBuilder::::new(config.clone()); + let mut builder = CircuitBuilder::::new(config); let mut recursive_challenger = RecursiveChallenger::new(&mut builder); let mut recursive_outputs_per_round: Vec> = Vec::new(); for (r, inputs) in inputs_per_round.iter().enumerate() { From f83c587cc56218cfbd04efaff646f72dbc906f99 Mon Sep 17 00:00:00 2001 From: wborgeaud Date: Fri, 17 Sep 2021 13:47:08 +0200 Subject: [PATCH 09/12] Comments --- src/gates/gmimc.rs | 3 +-- src/gates/poseidon.rs | 12 ++++++++++-- src/hash/poseidon.rs | 13 ++++++++++++- src/plonk/recursive_verifier.rs | 2 +- 4 files changed, 24 insertions(+), 6 deletions(-) diff --git a/src/gates/gmimc.rs b/src/gates/gmimc.rs index 225af379..5c031cfe 100644 --- a/src/gates/gmimc.rs +++ b/src/gates/gmimc.rs @@ -17,8 +17,7 @@ use crate::plonk::vars::{EvaluationTargets, EvaluationVars, EvaluationVarsBase}; /// /// This also has some extra features to make it suitable for efficiently verifying Merkle proofs. /// It has a flag which can be used to swap the first four inputs with the next four, for ordering -/// sibling digests. It also has an accumulator that computes the weighted sum of these flags, for -/// computing the index of the leaf based on these swap bits. +/// sibling digests. #[derive(Debug)] pub struct GMiMCGate< F: RichField + Extendable + GMiMC, diff --git a/src/gates/poseidon.rs b/src/gates/poseidon.rs index e8fa6ccf..8ae5b2ad 100644 --- a/src/gates/poseidon.rs +++ b/src/gates/poseidon.rs @@ -18,8 +18,7 @@ use crate::plonk::vars::{EvaluationTargets, EvaluationVars, EvaluationVarsBase}; /// /// This also has some extra features to make it suitable for efficiently verifying Merkle proofs. /// It has a flag which can be used to swap the first four inputs with the next four, for ordering -/// sibling digests. It also has an accumulator that computes the weighted sum of these flags, for -/// computing the index of the leaf based on these swap bits. +/// sibling digests. #[derive(Debug)] pub struct PoseidonGate< F: RichField + Extendable + Poseidon, @@ -117,6 +116,7 @@ where let mut state: [F::Extension; WIDTH] = state.try_into().unwrap(); let mut round_ctr = 0; + // First set of full rounds. for r in 0..poseidon::HALF_N_FULL_ROUNDS { >::constant_layer(&mut state, round_ctr); for i in 0..WIDTH { @@ -129,6 +129,7 @@ where round_ctr += 1; } + // Partial rounds. >::partial_first_constant_layer(&mut state); state = >::mds_partial_layer_init(&mut state); for r in 0..(poseidon::N_PARTIAL_ROUNDS - 1) { @@ -149,6 +150,7 @@ where ); round_ctr += poseidon::N_PARTIAL_ROUNDS; + // Second set of full rounds. for r in 0..poseidon::HALF_N_FULL_ROUNDS { >::constant_layer(&mut state, round_ctr); for i in 0..WIDTH { @@ -193,6 +195,7 @@ where let mut state: [F; WIDTH] = state.try_into().unwrap(); let mut round_ctr = 0; + // First set of full rounds. for r in 0..poseidon::HALF_N_FULL_ROUNDS { >::constant_layer(&mut state, round_ctr); for i in 0..WIDTH { @@ -205,6 +208,7 @@ where round_ctr += 1; } + // Partial rounds. >::partial_first_constant_layer(&mut state); state = >::mds_partial_layer_init(&mut state); for r in 0..(poseidon::N_PARTIAL_ROUNDS - 1) { @@ -224,6 +228,7 @@ where ); round_ctr += poseidon::N_PARTIAL_ROUNDS; + // Second set of full rounds. for r in 0..poseidon::HALF_N_FULL_ROUNDS { >::constant_layer(&mut state, round_ctr); for i in 0..WIDTH { @@ -275,6 +280,7 @@ where let mut state: [ExtensionTarget; WIDTH] = state.try_into().unwrap(); let mut round_ctr = 0; + // First set of full rounds. for r in 0..poseidon::HALF_N_FULL_ROUNDS { >::constant_layer_recursive(builder, &mut state, round_ctr); for i in 0..WIDTH { @@ -287,6 +293,7 @@ where round_ctr += 1; } + // Partial rounds. >::partial_first_constant_layer_recursive(builder, &mut state); state = >::mds_partial_layer_init_recursive(builder, &mut state); for r in 0..(poseidon::N_PARTIAL_ROUNDS - 1) { @@ -313,6 +320,7 @@ where ); round_ctr += poseidon::N_PARTIAL_ROUNDS; + // Second set of full rounds. for r in 0..poseidon::HALF_N_FULL_ROUNDS { >::constant_layer_recursive(builder, &mut state, round_ctr); for i in 0..WIDTH { diff --git a/src/hash/poseidon.rs b/src/hash/poseidon.rs index c3a01636..f11c8bf0 100644 --- a/src/hash/poseidon.rs +++ b/src/hash/poseidon.rs @@ -170,6 +170,7 @@ where #[inline(always)] #[unroll_for_loops] + /// Same as `mds_row_shf` for general fields. fn mds_row_shf_field, const D: usize>( r: usize, v: &[F; WIDTH], @@ -189,6 +190,7 @@ where #[inline(always)] #[unroll_for_loops] + /// Recursive version of `mds_row_shf`. fn mds_row_shf_recursive, const D: usize>( builder: &mut CircuitBuilder, r: usize, @@ -237,6 +239,7 @@ where #[inline(always)] #[unroll_for_loops] + /// Same as `mds_layer` for general fields. fn mds_layer_field, const D: usize>( state: &[F; WIDTH], ) -> [F; WIDTH] { @@ -254,6 +257,7 @@ where #[inline(always)] #[unroll_for_loops] + /// Recursive version of `mds_layer`. fn mds_layer_recursive, const D: usize>( builder: &mut CircuitBuilder, state: &[ExtensionTarget; WIDTH], @@ -285,6 +289,7 @@ where #[inline(always)] #[unroll_for_loops] + /// Recursive version of `partial_first_constant_layer`. fn partial_first_constant_layer_recursive, const D: usize>( builder: &mut CircuitBuilder, state: &mut [ExtensionTarget; WIDTH], @@ -338,6 +343,7 @@ where #[inline(always)] #[unroll_for_loops] + /// Recursive version of `mds_partial_layer_init`. fn mds_partial_layer_init_recursive, const D: usize>( builder: &mut CircuitBuilder, state: &[ExtensionTarget; WIDTH], @@ -403,6 +409,7 @@ where #[inline(always)] #[unroll_for_loops] + /// Same as `mds_partial_layer_fast` for general fields. fn mds_partial_layer_fast_field, const D: usize>( state: &[F; WIDTH], r: usize, @@ -432,7 +439,8 @@ where #[inline(always)] #[unroll_for_loops] - fn mds_partial_layer_fast_field_recursive, const D: usize>( + /// Recursive version of `mds_partial_layer_fast`. + fn mds_partial_layer_fast_recursive, const D: usize>( builder: &mut CircuitBuilder, state: &[ExtensionTarget; WIDTH], r: usize, @@ -484,6 +492,7 @@ where #[inline(always)] #[unroll_for_loops] + /// Recursive version of `constant_layer`. fn constant_layer_recursive, const D: usize>( builder: &mut CircuitBuilder, state: &mut [ExtensionTarget; WIDTH], @@ -514,6 +523,7 @@ where } #[inline(always)] + /// Recursive version of `sbox_monomial`. fn sbox_monomial_recursive, const D: usize>( builder: &mut CircuitBuilder, x: ExtensionTarget, @@ -535,6 +545,7 @@ where #[inline(always)] #[unroll_for_loops] + /// Recursive version of `sbox_layer`. fn sbox_layer_recursive, const D: usize>( builder: &mut CircuitBuilder, state: &mut [ExtensionTarget; WIDTH], diff --git a/src/plonk/recursive_verifier.rs b/src/plonk/recursive_verifier.rs index 6b0bd8c3..0b371b9d 100644 --- a/src/plonk/recursive_verifier.rs +++ b/src/plonk/recursive_verifier.rs @@ -361,7 +361,7 @@ mod tests { type F = CrandallField; const D: usize = 4; let config = CircuitConfig { - num_wires: 126, + num_wires: 143, num_routed_wires: 33, security_bits: 128, rate_bits: 3, From 675f32835ba9c518c80cdf6f0a48101cc4e1cb3e Mon Sep 17 00:00:00 2001 From: wborgeaud Date: Fri, 17 Sep 2021 13:50:42 +0200 Subject: [PATCH 10/12] Minor --- src/gates/poseidon.rs | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/src/gates/poseidon.rs b/src/gates/poseidon.rs index 8ae5b2ad..c918c17a 100644 --- a/src/gates/poseidon.rs +++ b/src/gates/poseidon.rs @@ -307,13 +307,12 @@ where one, state[0], ); - state = - >::mds_partial_layer_fast_field_recursive(builder, &state, r); + state = >::mds_partial_layer_fast_recursive(builder, &state, r); } let sbox_in = vars.local_wires[Self::wire_partial_sbox(poseidon::N_PARTIAL_ROUNDS - 1)]; constraints.push(builder.sub_extension(state[0], sbox_in)); state[0] = >::sbox_monomial_recursive(builder, sbox_in); - state = >::mds_partial_layer_fast_field_recursive( + state = >::mds_partial_layer_fast_recursive( builder, &state, poseidon::N_PARTIAL_ROUNDS - 1, From 14bbf5ae11baa27a76516ea79c7dbef6a1020753 Mon Sep 17 00:00:00 2001 From: wborgeaud Date: Fri, 17 Sep 2021 17:50:43 +0200 Subject: [PATCH 11/12] Fix AVX2 conflict --- src/gates/poseidon.rs | 24 ++++++++++++------------ src/hash/poseidon.rs | 30 +++++++++++++++++++++++++++--- 2 files changed, 39 insertions(+), 15 deletions(-) diff --git a/src/gates/poseidon.rs b/src/gates/poseidon.rs index c918c17a..651ac4fe 100644 --- a/src/gates/poseidon.rs +++ b/src/gates/poseidon.rs @@ -118,13 +118,13 @@ where // First set of full rounds. for r in 0..poseidon::HALF_N_FULL_ROUNDS { - >::constant_layer(&mut state, round_ctr); + >::constant_layer_field(&mut state, round_ctr); for i in 0..WIDTH { let sbox_in = vars.local_wires[Self::wire_full_sbox_0(r, i)]; constraints.push(state[i] - sbox_in); state[i] = sbox_in; } - >::sbox_layer(&mut state); + >::sbox_layer_field(&mut state); state = >::mds_layer_field(&state); round_ctr += 1; } @@ -152,13 +152,13 @@ where // Second set of full rounds. for r in 0..poseidon::HALF_N_FULL_ROUNDS { - >::constant_layer(&mut state, round_ctr); + >::constant_layer_field(&mut state, round_ctr); for i in 0..WIDTH { let sbox_in = vars.local_wires[Self::wire_full_sbox_1(r, i)]; constraints.push(state[i] - sbox_in); state[i] = sbox_in; } - >::sbox_layer(&mut state); + >::sbox_layer_field(&mut state); state = >::mds_layer_field(&state); round_ctr += 1; } @@ -197,13 +197,13 @@ where // First set of full rounds. for r in 0..poseidon::HALF_N_FULL_ROUNDS { - >::constant_layer(&mut state, round_ctr); + >::constant_layer_field(&mut state, round_ctr); for i in 0..WIDTH { let sbox_in = vars.local_wires[Self::wire_full_sbox_0(r, i)]; constraints.push(state[i] - sbox_in); state[i] = sbox_in; } - >::sbox_layer(&mut state); + >::sbox_layer_field(&mut state); state = >::mds_layer_field(&state); round_ctr += 1; } @@ -230,13 +230,13 @@ where // Second set of full rounds. for r in 0..poseidon::HALF_N_FULL_ROUNDS { - >::constant_layer(&mut state, round_ctr); + >::constant_layer_field(&mut state, round_ctr); for i in 0..WIDTH { let sbox_in = vars.local_wires[Self::wire_full_sbox_1(r, i)]; constraints.push(state[i] - sbox_in); state[i] = sbox_in; } - >::sbox_layer(&mut state); + >::sbox_layer_field(&mut state); state = >::mds_layer_field(&state); round_ctr += 1; } @@ -424,14 +424,14 @@ where let mut round_ctr = 0; for r in 0..poseidon::HALF_N_FULL_ROUNDS { - >::constant_layer(&mut state, round_ctr); + >::constant_layer_field(&mut state, round_ctr); for i in 0..WIDTH { out_buffer.set_wire( local_wire(PoseidonGate::::wire_full_sbox_0(r, i)), state[i], ); } - >::sbox_layer(&mut state); + >::sbox_layer_field(&mut state); state = >::mds_layer_field(&state); round_ctr += 1; } @@ -462,14 +462,14 @@ where round_ctr += poseidon::N_PARTIAL_ROUNDS; for r in 0..poseidon::HALF_N_FULL_ROUNDS { - >::constant_layer(&mut state, round_ctr); + >::constant_layer_field(&mut state, round_ctr); for i in 0..WIDTH { out_buffer.set_wire( local_wire(PoseidonGate::::wire_full_sbox_1(r, i)), state[i], ); } - >::sbox_layer(&mut state); + >::sbox_layer_field(&mut state); state = >::mds_layer_field(&state); round_ctr += 1; } diff --git a/src/hash/poseidon.rs b/src/hash/poseidon.rs index f11c8bf0..e3b6e1e0 100644 --- a/src/hash/poseidon.rs +++ b/src/hash/poseidon.rs @@ -9,7 +9,7 @@ use unroll::unroll_for_loops; use crate::field::crandall_field::CrandallField; use crate::field::extension_field::target::ExtensionTarget; use crate::field::extension_field::{Extendable, FieldExtension}; -use crate::field::field_types::{Field, PrimeField, RichField}; +use crate::field::field_types::{PrimeField, RichField}; use crate::plonk::circuit_builder::CircuitBuilder; // The number of full rounds and partial rounds is given by the @@ -478,7 +478,18 @@ where #[inline(always)] #[unroll_for_loops] - fn constant_layer, const D: usize>( + fn constant_layer(state: &mut [Self; WIDTH], round_ctr: usize) { + assert!(WIDTH <= 12); + for i in 0..12 { + if i < WIDTH { + state[i] += Self::from_canonical_u64(ALL_ROUND_CONSTANTS[i + WIDTH * round_ctr]); + } + } + } + + #[inline(always)] + #[unroll_for_loops] + fn constant_layer_field, const D: usize>( state: &mut [F; WIDTH], round_ctr: usize, ) { @@ -534,7 +545,20 @@ where #[inline(always)] #[unroll_for_loops] - fn sbox_layer, const D: usize>(state: &mut [F; WIDTH]) { + fn sbox_layer(state: &mut [Self; WIDTH]) { + assert!(WIDTH <= 12); + for i in 0..12 { + if i < WIDTH { + state[i] = Self::sbox_monomial(state[i]); + } + } + } + + #[inline(always)] + #[unroll_for_loops] + fn sbox_layer_field, const D: usize>( + state: &mut [F; WIDTH], + ) { assert!(WIDTH <= 12); for i in 0..12 { if i < WIDTH { From 0be8650bca117a891e9e950e9c5e55bcfbdcc56e Mon Sep 17 00:00:00 2001 From: wborgeaud Date: Sat, 18 Sep 2021 09:23:39 +0200 Subject: [PATCH 12/12] PR feedback --- src/gates/poseidon.rs | 5 ++ src/hash/poseidon.rs | 129 +++++++++++++++--------------------------- 2 files changed, 52 insertions(+), 82 deletions(-) diff --git a/src/gates/poseidon.rs b/src/gates/poseidon.rs index 651ac4fe..3aef44ab 100644 --- a/src/gates/poseidon.rs +++ b/src/gates/poseidon.rs @@ -58,17 +58,22 @@ where /// A wire which stores the input of the `i`-th S-box of the `round`-th round of the first set /// of full rounds. fn wire_full_sbox_0(round: usize, i: usize) -> usize { + debug_assert!(round < poseidon::HALF_N_FULL_ROUNDS); + debug_assert!(i < WIDTH); 2 * WIDTH + 1 + WIDTH * round + i } /// A wire which stores the input of the S-box of the `round`-th round of the partial rounds. fn wire_partial_sbox(round: usize) -> usize { + debug_assert!(round < poseidon::N_PARTIAL_ROUNDS); 2 * WIDTH + 1 + WIDTH * poseidon::HALF_N_FULL_ROUNDS + round } /// A wire which stores the input of the `i`-th S-box of the `round`-th round of the second set /// of full rounds. fn wire_full_sbox_1(round: usize, i: usize) -> usize { + debug_assert!(round < poseidon::HALF_N_FULL_ROUNDS); + debug_assert!(i < WIDTH); 2 * WIDTH + 1 + WIDTH * (poseidon::HALF_N_FULL_ROUNDS + round) diff --git a/src/hash/poseidon.rs b/src/hash/poseidon.rs index e3b6e1e0..ca519634 100644 --- a/src/hash/poseidon.rs +++ b/src/hash/poseidon.rs @@ -18,9 +18,9 @@ use crate::plonk::circuit_builder::CircuitBuilder; // // NB: Changing any of these values will require regenerating all of // the precomputed constant arrays in this file. -pub const HALF_N_FULL_ROUNDS: usize = 4; -pub const N_FULL_ROUNDS_TOTAL: usize = 2 * HALF_N_FULL_ROUNDS; -pub const N_PARTIAL_ROUNDS: usize = 22; +pub(crate) const HALF_N_FULL_ROUNDS: usize = 4; +pub(crate) const N_FULL_ROUNDS_TOTAL: usize = 2 * HALF_N_FULL_ROUNDS; +pub(crate) const N_PARTIAL_ROUNDS: usize = 22; const N_ROUNDS: usize = N_FULL_ROUNDS_TOTAL + N_PARTIAL_ROUNDS; const MAX_WIDTH: usize = 12; // we only have width 8 and 12, and 12 is bigger. :) @@ -170,7 +170,7 @@ where #[inline(always)] #[unroll_for_loops] - /// Same as `mds_row_shf` for general fields. + /// Same as `mds_row_shf` for field extensions of `Self`. fn mds_row_shf_field, const D: usize>( r: usize, v: &[F; WIDTH], @@ -188,8 +188,6 @@ where res } - #[inline(always)] - #[unroll_for_loops] /// Recursive version of `mds_row_shf`. fn mds_row_shf_recursive, const D: usize>( builder: &mut CircuitBuilder, @@ -200,17 +198,14 @@ where debug_assert!(r < WIDTH); let mut res = builder.zero_extension(); - assert!(WIDTH <= 12); - for i in 0..12 { - if i < WIDTH { - res = builder.arithmetic_extension( - F::from_canonical_u64(1 << Self::MDS_MATRIX_EXPS[i]), - F::ONE, - one, - v[(i + r) % WIDTH], - res, - ); - } + for i in 0..WIDTH { + res = builder.arithmetic_extension( + F::from_canonical_u64(1 << Self::MDS_MATRIX_EXPS[i]), + F::ONE, + one, + v[(i + r) % WIDTH], + res, + ); } res @@ -239,7 +234,7 @@ where #[inline(always)] #[unroll_for_loops] - /// Same as `mds_layer` for general fields. + /// Same as `mds_layer` for field extensions of `Self`. fn mds_layer_field, const D: usize>( state: &[F; WIDTH], ) -> [F; WIDTH] { @@ -255,8 +250,6 @@ where result } - #[inline(always)] - #[unroll_for_loops] /// Recursive version of `mds_layer`. fn mds_layer_recursive, const D: usize>( builder: &mut CircuitBuilder, @@ -264,11 +257,8 @@ where ) -> [ExtensionTarget; WIDTH] { let mut result = [builder.zero_extension(); WIDTH]; - assert!(WIDTH <= 12); - for r in 0..12 { - if r < WIDTH { - result[r] = Self::mds_row_shf_recursive(builder, r, state); - } + for r in 0..WIDTH { + result[r] = Self::mds_row_shf_recursive(builder, r, state); } result @@ -287,25 +277,20 @@ where } } - #[inline(always)] - #[unroll_for_loops] /// Recursive version of `partial_first_constant_layer`. fn partial_first_constant_layer_recursive, const D: usize>( builder: &mut CircuitBuilder, state: &mut [ExtensionTarget; WIDTH], ) { let one = builder.one_extension(); - assert!(WIDTH <= 12); - for i in 0..12 { - if i < WIDTH { - state[i] = builder.arithmetic_extension( - F::from_canonical_u64(Self::FAST_PARTIAL_FIRST_ROUND_CONSTANT[i]), - F::ONE, - one, - one, - state[i], - ); - } + for i in 0..WIDTH { + state[i] = builder.arithmetic_extension( + F::from_canonical_u64(Self::FAST_PARTIAL_FIRST_ROUND_CONSTANT[i]), + F::ONE, + one, + one, + state[i], + ); } } @@ -341,8 +326,6 @@ where result } - #[inline(always)] - #[unroll_for_loops] /// Recursive version of `mds_partial_layer_init`. fn mds_partial_layer_init_recursive, const D: usize>( builder: &mut CircuitBuilder, @@ -353,18 +336,14 @@ where result[0] = state[0]; - assert!(WIDTH <= 12); - for c in 1..12 { - if c < WIDTH { - assert!(WIDTH <= 12); - for r in 1..12 { - if r < WIDTH { - let t = F::from_canonical_u64( - Self::FAST_PARTIAL_ROUND_INITIAL_MATRIX[c - 1][r - 1], - ); - result[c] = - builder.arithmetic_extension(t, F::ONE, one, state[r], result[c]); - } + for c in 1..WIDTH { + assert!(WIDTH <= 12); + for r in 1..12 { + if r < WIDTH { + let t = F::from_canonical_u64( + Self::FAST_PARTIAL_ROUND_INITIAL_MATRIX[c - 1][r - 1], + ); + result[c] = builder.arithmetic_extension(t, F::ONE, one, state[r], result[c]); } } } @@ -409,7 +388,7 @@ where #[inline(always)] #[unroll_for_loops] - /// Same as `mds_partial_layer_fast` for general fields. + /// Same as `mds_partial_layer_fast` for field extensions of `Self`. fn mds_partial_layer_fast_field, const D: usize>( state: &[F; WIDTH], r: usize, @@ -437,8 +416,6 @@ where result } - #[inline(always)] - #[unroll_for_loops] /// Recursive version of `mds_partial_layer_fast`. fn mds_partial_layer_fast_recursive, const D: usize>( builder: &mut CircuitBuilder, @@ -456,12 +433,9 @@ where s0, zero, ); - assert!(WIDTH <= 12); - for i in 1..12 { - if i < WIDTH { - let t = F::from_canonical_u64(Self::FAST_PARTIAL_ROUND_W_HATS[r][i - 1]); - d = builder.arithmetic_extension(t, F::ONE, one, state[i], d); - } + for i in 1..WIDTH { + let t = F::from_canonical_u64(Self::FAST_PARTIAL_ROUND_W_HATS[r][i - 1]); + d = builder.arithmetic_extension(t, F::ONE, one, state[i], d); } let mut result = [zero; WIDTH]; @@ -489,6 +463,7 @@ where #[inline(always)] #[unroll_for_loops] + /// Same as `constant_layer` for field extensions of `Self`. fn constant_layer_field, const D: usize>( state: &mut [F; WIDTH], round_ctr: usize, @@ -501,8 +476,6 @@ where } } - #[inline(always)] - #[unroll_for_loops] /// Recursive version of `constant_layer`. fn constant_layer_recursive, const D: usize>( builder: &mut CircuitBuilder, @@ -510,17 +483,14 @@ where round_ctr: usize, ) { let one = builder.one_extension(); - assert!(WIDTH <= 12); - for i in 0..12 { - if i < WIDTH { - state[i] = builder.arithmetic_extension( - F::from_canonical_u64(ALL_ROUND_CONSTANTS[i + WIDTH * round_ctr]), - F::ONE, - one, - one, - state[i], - ); - } + for i in 0..WIDTH { + state[i] = builder.arithmetic_extension( + F::from_canonical_u64(ALL_ROUND_CONSTANTS[i + WIDTH * round_ctr]), + F::ONE, + one, + one, + state[i], + ); } } @@ -533,7 +503,6 @@ where x3 * x4 } - #[inline(always)] /// Recursive version of `sbox_monomial`. fn sbox_monomial_recursive, const D: usize>( builder: &mut CircuitBuilder, @@ -556,6 +525,7 @@ where #[inline(always)] #[unroll_for_loops] + /// Same as `sbox_layer` for field extensions of `Self`. fn sbox_layer_field, const D: usize>( state: &mut [F; WIDTH], ) { @@ -567,18 +537,13 @@ where } } - #[inline(always)] - #[unroll_for_loops] /// Recursive version of `sbox_layer`. fn sbox_layer_recursive, const D: usize>( builder: &mut CircuitBuilder, state: &mut [ExtensionTarget; WIDTH], ) { - assert!(WIDTH <= 12); - for i in 0..12 { - if i < WIDTH { - state[i] = Self::sbox_monomial_recursive(builder, state[i]); - } + for i in 0..WIDTH { + state[i] = Self::sbox_monomial_recursive(builder, state[i]); } }