diff --git a/src/gates/poseidon.rs b/src/gates/poseidon.rs index 371ac0a5..e8fa6ccf 100644 --- a/src/gates/poseidon.rs +++ b/src/gates/poseidon.rs @@ -131,7 +131,6 @@ where >::partial_first_constant_layer(&mut state); state = >::mds_partial_layer_init(&mut state); - // for r in 0..(poseidon::N_PARTIAL_ROUNDS - 1) { for r in 0..(poseidon::N_PARTIAL_ROUNDS - 1) { let sbox_in = vars.local_wires[Self::wire_partial_sbox(r)]; constraints.push(state[0] - sbox_in); @@ -170,7 +169,78 @@ where } fn eval_unfiltered_base(&self, vars: EvaluationVarsBase) -> Vec { - todo!() + let mut constraints = Vec::with_capacity(self.num_constraints()); + + // Assert that `swap` is binary. + let swap = vars.local_wires[Self::WIRE_SWAP]; + constraints.push(swap * (swap - F::ONE)); + + let mut state = Vec::with_capacity(12); + for i in 0..4 { + let a = vars.local_wires[i]; + let b = vars.local_wires[i + 4]; + state.push(a + swap * (b - a)); + } + for i in 0..4 { + let a = vars.local_wires[i + 4]; + let b = vars.local_wires[i]; + state.push(a + swap * (b - a)); + } + for i in 8..12 { + state.push(vars.local_wires[i]); + } + + let mut state: [F; WIDTH] = state.try_into().unwrap(); + let mut round_ctr = 0; + + for r in 0..poseidon::HALF_N_FULL_ROUNDS { + >::constant_layer(&mut state, round_ctr); + for i in 0..WIDTH { + let sbox_in = vars.local_wires[Self::wire_full_sbox_0(r, i)]; + constraints.push(state[i] - sbox_in); + state[i] = sbox_in; + } + >::sbox_layer(&mut state); + state = >::mds_layer_field(&state); + round_ctr += 1; + } + + >::partial_first_constant_layer(&mut state); + state = >::mds_partial_layer_init(&mut state); + for r in 0..(poseidon::N_PARTIAL_ROUNDS - 1) { + let sbox_in = vars.local_wires[Self::wire_partial_sbox(r)]; + constraints.push(state[0] - sbox_in); + state[0] = >::sbox_monomial(sbox_in); + state[0] += + F::from_canonical_u64(>::FAST_PARTIAL_ROUND_CONSTANTS[r]); + state = >::mds_partial_layer_fast_field(&state, r); + } + let sbox_in = vars.local_wires[Self::wire_partial_sbox(poseidon::N_PARTIAL_ROUNDS - 1)]; + constraints.push(state[0] - sbox_in); + state[0] = >::sbox_monomial(sbox_in); + state = >::mds_partial_layer_fast_field( + &state, + poseidon::N_PARTIAL_ROUNDS - 1, + ); + round_ctr += poseidon::N_PARTIAL_ROUNDS; + + for r in 0..poseidon::HALF_N_FULL_ROUNDS { + >::constant_layer(&mut state, round_ctr); + for i in 0..WIDTH { + let sbox_in = vars.local_wires[Self::wire_full_sbox_1(r, i)]; + constraints.push(state[i] - sbox_in); + state[i] = sbox_in; + } + >::sbox_layer(&mut state); + state = >::mds_layer_field(&state); + round_ctr += 1; + } + + for i in 0..WIDTH { + constraints.push(state[i] - vars.local_wires[Self::wire_output(i)]); + } + + constraints } fn eval_unfiltered_recursively( @@ -178,7 +248,89 @@ where builder: &mut CircuitBuilder, vars: EvaluationTargets, ) -> Vec> { - todo!() + let one = builder.one_extension(); + let mut constraints = Vec::with_capacity(self.num_constraints()); + + // Assert that `swap` is binary. + let swap = vars.local_wires[Self::WIRE_SWAP]; + constraints.push(builder.mul_sub_extension(swap, swap, swap)); + + let mut state = Vec::with_capacity(12); + for i in 0..4 { + let a = vars.local_wires[i]; + let b = vars.local_wires[i + 4]; + let delta = builder.sub_extension(b, a); + state.push(builder.mul_add_extension(swap, delta, a)); + } + for i in 0..4 { + let a = vars.local_wires[i + 4]; + let b = vars.local_wires[i]; + let delta = builder.sub_extension(b, a); + state.push(builder.mul_add_extension(swap, delta, a)); + } + for i in 8..12 { + state.push(vars.local_wires[i]); + } + + let mut state: [ExtensionTarget; WIDTH] = state.try_into().unwrap(); + let mut round_ctr = 0; + + for r in 0..poseidon::HALF_N_FULL_ROUNDS { + >::constant_layer_recursive(builder, &mut state, round_ctr); + for i in 0..WIDTH { + let sbox_in = vars.local_wires[Self::wire_full_sbox_0(r, i)]; + constraints.push(builder.sub_extension(state[i], sbox_in)); + state[i] = sbox_in; + } + >::sbox_layer_recursive(builder, &mut state); + state = >::mds_layer_recursive(builder, &state); + round_ctr += 1; + } + + >::partial_first_constant_layer_recursive(builder, &mut state); + state = >::mds_partial_layer_init_recursive(builder, &mut state); + for r in 0..(poseidon::N_PARTIAL_ROUNDS - 1) { + let sbox_in = vars.local_wires[Self::wire_partial_sbox(r)]; + constraints.push(builder.sub_extension(state[0], sbox_in)); + state[0] = >::sbox_monomial_recursive(builder, sbox_in); + state[0] = builder.arithmetic_extension( + F::from_canonical_u64(>::FAST_PARTIAL_ROUND_CONSTANTS[r]), + F::ONE, + one, + one, + state[0], + ); + state = + >::mds_partial_layer_fast_field_recursive(builder, &state, r); + } + let sbox_in = vars.local_wires[Self::wire_partial_sbox(poseidon::N_PARTIAL_ROUNDS - 1)]; + constraints.push(builder.sub_extension(state[0], sbox_in)); + state[0] = >::sbox_monomial_recursive(builder, sbox_in); + state = >::mds_partial_layer_fast_field_recursive( + builder, + &state, + poseidon::N_PARTIAL_ROUNDS - 1, + ); + round_ctr += poseidon::N_PARTIAL_ROUNDS; + + for r in 0..poseidon::HALF_N_FULL_ROUNDS { + >::constant_layer_recursive(builder, &mut state, round_ctr); + for i in 0..WIDTH { + let sbox_in = vars.local_wires[Self::wire_full_sbox_1(r, i)]; + constraints.push(builder.sub_extension(state[i], sbox_in)); + state[i] = sbox_in; + } + >::sbox_layer_recursive(builder, &mut state); + state = >::mds_layer_recursive(builder, &state); + round_ctr += 1; + } + + for i in 0..WIDTH { + constraints + .push(builder.sub_extension(state[i], vars.local_wires[Self::wire_output(i)])); + } + + constraints } fn generators( diff --git a/src/hash/poseidon.rs b/src/hash/poseidon.rs index 4449e89f..d3c42bc6 100644 --- a/src/hash/poseidon.rs +++ b/src/hash/poseidon.rs @@ -7,8 +7,10 @@ use std::convert::TryInto; use unroll::unroll_for_loops; use crate::field::crandall_field::CrandallField; -use crate::field::extension_field::FieldExtension; -use crate::field::field_types::{Field, PrimeField}; +use crate::field::extension_field::target::ExtensionTarget; +use crate::field::extension_field::{Extendable, FieldExtension}; +use crate::field::field_types::{Field, PrimeField, RichField}; +use crate::plonk::circuit_builder::CircuitBuilder; // The number of full rounds and partial rounds is given by the // calc_round_numbers.py script. They happen to be the same for both @@ -192,6 +194,40 @@ where res } + #[inline(always)] + #[unroll_for_loops] + fn mds_row_shf_recursive, const D: usize>( + builder: &mut CircuitBuilder, + r: usize, + v: &[ExtensionTarget; WIDTH], + ) -> ExtensionTarget { + let one = builder.one_extension(); + debug_assert!(r < WIDTH); + // The values of MDS_MATRIX_EXPS are known to be small, so we can + // accumulate all the products for each row and reduce just once + // at the end (done by the caller). + + // NB: Unrolling this, calculating each term independently, and + // summing at the end, didn't improve performance for me. + let mut res = builder.zero_extension(); + + // This is a hacky way of fully unrolling the loop. + assert!(WIDTH <= 12); + for i in 0..12 { + if i < WIDTH { + res = builder.arithmetic_extension( + F::from_canonical_u64(1 << Self::MDS_MATRIX_EXPS[i]), + F::ONE, + one, + v[(i + r) % WIDTH], + res, + ); + } + } + + res + } + #[inline(always)] #[unroll_for_loops] fn mds_layer(state_: &[Self; WIDTH]) -> [Self; WIDTH] { @@ -231,6 +267,25 @@ where result } + #[inline(always)] + #[unroll_for_loops] + fn mds_layer_recursive, const D: usize>( + builder: &mut CircuitBuilder, + state: &[ExtensionTarget; WIDTH], + ) -> [ExtensionTarget; WIDTH] { + let mut result = [builder.zero_extension(); WIDTH]; + + // This is a hacky way of fully unrolling the loop. + assert!(WIDTH <= 12); + for r in 0..12 { + if r < WIDTH { + result[r] = Self::mds_row_shf_recursive(builder, r, state); + } + } + + result + } + #[inline(always)] #[unroll_for_loops] fn partial_first_constant_layer, const D: usize>( @@ -244,6 +299,27 @@ where } } + #[inline(always)] + #[unroll_for_loops] + fn partial_first_constant_layer_recursive, const D: usize>( + builder: &mut CircuitBuilder, + state: &mut [ExtensionTarget; WIDTH], + ) { + let one = builder.one_extension(); + assert!(WIDTH <= 12); + for i in 0..12 { + if i < WIDTH { + state[i] = builder.arithmetic_extension( + F::from_canonical_u64(Self::FAST_PARTIAL_FIRST_ROUND_CONSTANT[i]), + F::ONE, + one, + one, + state[i], + ); + } + } + } + #[inline(always)] #[unroll_for_loops] fn mds_partial_layer_init, const D: usize>( @@ -276,6 +352,41 @@ where result } + #[inline(always)] + #[unroll_for_loops] + fn mds_partial_layer_init_recursive, const D: usize>( + builder: &mut CircuitBuilder, + state: &[ExtensionTarget; WIDTH], + ) -> [ExtensionTarget; WIDTH] { + let one = builder.one_extension(); + let mut result = [builder.zero_extension(); WIDTH]; + + // Initial matrix has first row/column = [1, 0, ..., 0]; + + // c = 0 + result[0] = state[0]; + + assert!(WIDTH <= 12); + for c in 1..12 { + if c < WIDTH { + assert!(WIDTH <= 12); + for r in 1..12 { + if r < WIDTH { + // NB: FAST_PARTIAL_ROUND_INITIAL_MATRIX is stored in + // column-major order so that this dot product is cache + // friendly. + let t = F::from_canonical_u64( + Self::FAST_PARTIAL_ROUND_INITIAL_MATRIX[c - 1][r - 1], + ); + result[c] = + builder.arithmetic_extension(t, F::ONE, one, state[r], result[c]); + } + } + } + } + result + } + /// Computes s*A where s is the state row vector and A is the matrix /// /// [ M_00 | v ] @@ -343,6 +454,46 @@ where result } + #[inline(always)] + #[unroll_for_loops] + fn mds_partial_layer_fast_field_recursive, const D: usize>( + builder: &mut CircuitBuilder, + state: &[ExtensionTarget; WIDTH], + r: usize, + ) -> [ExtensionTarget; WIDTH] { + let zero = builder.zero_extension(); + let one = builder.one_extension(); + + // Set d = [M_00 | w^] dot [state] + let s0 = state[0]; + let mut d = builder.arithmetic_extension( + F::from_canonical_u64(1 << Self::MDS_MATRIX_EXPS[0]), + F::ONE, + one, + s0, + zero, + ); + assert!(WIDTH <= 12); + for i in 1..12 { + if i < WIDTH { + let t = F::from_canonical_u64(Self::FAST_PARTIAL_ROUND_W_HATS[r][i - 1]); + d = builder.arithmetic_extension(t, F::ONE, one, state[i], d); + } + } + + // result = [d] concat [state[0] * v + state[shift up by 1]] + let mut result = [zero; WIDTH]; + result[0] = d; + assert!(WIDTH <= 12); + for i in 1..12 { + if i < WIDTH { + let t = F::from_canonical_u64(Self::FAST_PARTIAL_ROUND_VS[r][i - 1]); + result[i] = builder.arithmetic_extension(t, F::ONE, one, state[0], state[i]); + } + } + result + } + #[inline(always)] #[unroll_for_loops] fn constant_layer, const D: usize>( @@ -357,6 +508,28 @@ where } } + #[inline(always)] + #[unroll_for_loops] + fn constant_layer_recursive, const D: usize>( + builder: &mut CircuitBuilder, + state: &mut [ExtensionTarget; WIDTH], + round_ctr: usize, + ) { + let one = builder.one_extension(); + assert!(WIDTH <= 12); + for i in 0..12 { + if i < WIDTH { + state[i] = builder.arithmetic_extension( + F::from_canonical_u64(ALL_ROUND_CONSTANTS[i + WIDTH * round_ctr]), + F::ONE, + one, + one, + state[i], + ); + } + } + } + #[inline(always)] fn sbox_monomial, const D: usize>(x: F) -> F { // x |--> x^7 @@ -366,6 +539,18 @@ where x3 * x4 } + #[inline(always)] + fn sbox_monomial_recursive, const D: usize>( + builder: &mut CircuitBuilder, + x: ExtensionTarget, + ) -> ExtensionTarget { + // x |--> x^7 + let x2 = builder.mul_extension(x, x); + let x4 = builder.mul_extension(x2, x2); + let x3 = builder.mul_extension(x, x2); + builder.mul_extension(x3, x4) + } + #[inline(always)] #[unroll_for_loops] fn sbox_layer, const D: usize>(state: &mut [F; WIDTH]) { @@ -377,6 +562,20 @@ where } } + #[inline(always)] + #[unroll_for_loops] + fn sbox_layer_recursive, const D: usize>( + builder: &mut CircuitBuilder, + state: &mut [ExtensionTarget; WIDTH], + ) { + assert!(WIDTH <= 12); + for i in 0..12 { + if i < WIDTH { + state[i] = Self::sbox_monomial_recursive(builder, state[i]); + } + } + } + #[inline] fn full_rounds(state: &mut [Self; WIDTH], round_ctr: &mut usize) { for _ in 0..HALF_N_FULL_ROUNDS {