diff --git a/src/gadgets/select.rs b/src/gadgets/select.rs index 844fef25..2b906a61 100644 --- a/src/gadgets/select.rs +++ b/src/gadgets/select.rs @@ -18,7 +18,7 @@ impl, const D: usize> CircuitBuilder { /// Like `select_ext`, but accepts a condition input which does not necessarily have to be /// binary. In this case, it computes the arithmetic generalization of `if b { x } else { y }`, - /// i.e. `bx - (by-y)`, which can be computed with a single `ArithmeticExtensionGate`. + /// i.e. `bx - (by-y)`. pub fn select_ext_generalized( &mut self, b: ExtensionTarget, diff --git a/src/gates/mod.rs b/src/gates/mod.rs index a3513361..76066285 100644 --- a/src/gates/mod.rs +++ b/src/gates/mod.rs @@ -14,6 +14,7 @@ pub mod insertion; pub mod interpolation; pub mod noop; pub mod poseidon; +pub(crate) mod poseidon_mds; pub(crate) mod public_input; pub mod random_access; pub mod reducing; diff --git a/src/gates/poseidon.rs b/src/gates/poseidon.rs index 493bd263..39fe0e98 100644 --- a/src/gates/poseidon.rs +++ b/src/gates/poseidon.rs @@ -3,8 +3,9 @@ use std::marker::PhantomData; use crate::field::extension_field::target::ExtensionTarget; use crate::field::extension_field::Extendable; -use crate::field::field_types::{Field, PrimeField, RichField}; +use crate::field::field_types::{Field, RichField}; use crate::gates::gate::Gate; +use crate::gates::poseidon_mds::PoseidonMdsGate; use crate::hash::poseidon; use crate::hash::poseidon::Poseidon; use crate::iop::generator::{GeneratedValues, SimpleGenerator, WitnessGenerator}; @@ -93,7 +94,7 @@ where [(); WIDTH - 1]: , { fn id(&self) -> String { - format!(" {:?}", WIDTH, self) + format!("{:?}", self, WIDTH) } fn eval_unfiltered(&self, vars: EvaluationVars) -> Vec { @@ -256,6 +257,10 @@ where builder: &mut CircuitBuilder, vars: EvaluationTargets, ) -> Vec> { + // The naive method is more efficient if we have enough routed wires for PoseidonMdsGate. + let naive = + builder.config.num_routed_wires >= PoseidonMdsGate::::new().num_wires(); + let mut constraints = Vec::with_capacity(self.num_constraints()); // Assert that `swap` is binary. @@ -263,18 +268,23 @@ where constraints.push(builder.mul_sub_extension(swap, swap, swap)); let mut state = Vec::with_capacity(WIDTH); + // We need to compute both `if swap {b} else {a}` and `if swap {a} else {b}`. + // We will arithmetize them as + // swap (b - a) + a + // -swap (b - a) + b + // so that `b - a` can be used for both. + let mut state_first_4 = vec![]; + let mut state_next_4 = vec![]; for i in 0..4 { let a = vars.local_wires[i]; let b = vars.local_wires[i + 4]; let delta = builder.sub_extension(b, a); - state.push(builder.mul_add_extension(swap, delta, a)); - } - for i in 0..4 { - let a = vars.local_wires[i + 4]; - let b = vars.local_wires[i]; - let delta = builder.sub_extension(b, a); - state.push(builder.mul_add_extension(swap, delta, a)); + state_first_4.push(builder.mul_add_extension(swap, delta, a)); + state_next_4.push(builder.arithmetic_extension(F::NEG_ONE, F::ONE, swap, delta, b)); } + + state.extend(state_first_4); + state.extend(state_next_4); for i in 8..WIDTH { state.push(vars.local_wires[i]); } @@ -296,27 +306,39 @@ where } // Partial rounds. - >::partial_first_constant_layer_recursive(builder, &mut state); - state = >::mds_partial_layer_init_recursive(builder, &mut state); - for r in 0..(poseidon::N_PARTIAL_ROUNDS - 1) { - let sbox_in = vars.local_wires[Self::wire_partial_sbox(r)]; + if naive { + for r in 0..poseidon::N_PARTIAL_ROUNDS { + >::constant_layer_recursive(builder, &mut state, round_ctr); + let sbox_in = vars.local_wires[Self::wire_partial_sbox(r)]; + constraints.push(builder.sub_extension(state[0], sbox_in)); + state[0] = >::sbox_monomial_recursive(builder, sbox_in); + state = >::mds_layer_recursive(builder, &state); + round_ctr += 1; + } + } else { + >::partial_first_constant_layer_recursive(builder, &mut state); + state = >::mds_partial_layer_init_recursive(builder, &mut state); + for r in 0..(poseidon::N_PARTIAL_ROUNDS - 1) { + let sbox_in = vars.local_wires[Self::wire_partial_sbox(r)]; + constraints.push(builder.sub_extension(state[0], sbox_in)); + state[0] = >::sbox_monomial_recursive(builder, sbox_in); + state[0] = builder.add_const_extension( + state[0], + F::from_canonical_u64(>::FAST_PARTIAL_ROUND_CONSTANTS[r]), + ); + state = + >::mds_partial_layer_fast_recursive(builder, &state, r); + } + let sbox_in = vars.local_wires[Self::wire_partial_sbox(poseidon::N_PARTIAL_ROUNDS - 1)]; constraints.push(builder.sub_extension(state[0], sbox_in)); state[0] = >::sbox_monomial_recursive(builder, sbox_in); - state[0] = builder.add_const_extension( - state[0], - F::from_canonical_u64(>::FAST_PARTIAL_ROUND_CONSTANTS[r]), + state = >::mds_partial_layer_fast_recursive( + builder, + &state, + poseidon::N_PARTIAL_ROUNDS - 1, ); - state = >::mds_partial_layer_fast_recursive(builder, &state, r); + round_ctr += poseidon::N_PARTIAL_ROUNDS; } - let sbox_in = vars.local_wires[Self::wire_partial_sbox(poseidon::N_PARTIAL_ROUNDS - 1)]; - constraints.push(builder.sub_extension(state[0], sbox_in)); - state[0] = >::sbox_monomial_recursive(builder, sbox_in); - state = >::mds_partial_layer_fast_recursive( - builder, - &state, - poseidon::N_PARTIAL_ROUNDS - 1, - ); - round_ctr += poseidon::N_PARTIAL_ROUNDS; // Second set of full rounds. for r in 0..poseidon::HALF_N_FULL_ROUNDS { diff --git a/src/gates/poseidon_mds.rs b/src/gates/poseidon_mds.rs new file mode 100644 index 00000000..8a42b588 --- /dev/null +++ b/src/gates/poseidon_mds.rs @@ -0,0 +1,274 @@ +use std::convert::TryInto; +use std::marker::PhantomData; +use std::ops::Range; + +use crate::field::extension_field::algebra::ExtensionAlgebra; +use crate::field::extension_field::target::{ExtensionAlgebraTarget, ExtensionTarget}; +use crate::field::extension_field::Extendable; +use crate::field::extension_field::FieldExtension; +use crate::field::field_types::{Field, RichField}; +use crate::gates::gate::Gate; +use crate::hash::poseidon::Poseidon; +use crate::iop::generator::{GeneratedValues, SimpleGenerator, WitnessGenerator}; +use crate::iop::target::Target; +use crate::iop::witness::{PartitionWitness, Witness}; +use crate::plonk::circuit_builder::CircuitBuilder; +use crate::plonk::vars::{EvaluationTargets, EvaluationVars, EvaluationVarsBase}; + +#[derive(Debug)] +pub struct PoseidonMdsGate< + F: RichField + Extendable + Poseidon, + const D: usize, + const WIDTH: usize, +> where + [(); WIDTH - 1]: , +{ + _phantom: PhantomData, +} + +impl + Poseidon, const D: usize, const WIDTH: usize> + PoseidonMdsGate +where + [(); WIDTH - 1]: , +{ + pub fn new() -> Self { + PoseidonMdsGate { + _phantom: PhantomData, + } + } + + pub fn wires_input(i: usize) -> Range { + assert!(i < WIDTH); + i * D..(i + 1) * D + } + + pub fn wires_output(i: usize) -> Range { + assert!(i < WIDTH); + (WIDTH + i) * D..(WIDTH + i + 1) * D + } + + // Following are methods analogous to ones in `Poseidon`, but for extension algebras. + + /// Same as `mds_row_shf` for an extension algebra of `F`. + fn mds_row_shf_algebra( + r: usize, + v: &[ExtensionAlgebra; WIDTH], + ) -> ExtensionAlgebra { + debug_assert!(r < WIDTH); + let mut res = ExtensionAlgebra::ZERO; + + for i in 0..WIDTH { + let coeff = + F::Extension::from_canonical_u64(1 << >::MDS_MATRIX_EXPS[i]); + res += v[(i + r) % WIDTH].scalar_mul(coeff); + } + + res + } + + /// Same as `mds_row_shf_recursive` for an extension algebra of `F`. + fn mds_row_shf_algebra_recursive( + builder: &mut CircuitBuilder, + r: usize, + v: &[ExtensionAlgebraTarget; WIDTH], + ) -> ExtensionAlgebraTarget { + debug_assert!(r < WIDTH); + let mut res = builder.zero_ext_algebra(); + + for i in 0..WIDTH { + let coeff = builder.constant_extension(F::Extension::from_canonical_u64( + 1 << >::MDS_MATRIX_EXPS[i], + )); + res = builder.scalar_mul_add_ext_algebra(coeff, v[(i + r) % WIDTH], res); + } + + res + } + + /// Same as `mds_layer` for an extension algebra of `F`. + fn mds_layer_algebra( + state: &[ExtensionAlgebra; WIDTH], + ) -> [ExtensionAlgebra; WIDTH] { + let mut result = [ExtensionAlgebra::ZERO; WIDTH]; + + for r in 0..WIDTH { + result[r] = Self::mds_row_shf_algebra(r, state); + } + + result + } + + /// Same as `mds_layer_recursive` for an extension algebra of `F`. + fn mds_layer_algebra_recursive( + builder: &mut CircuitBuilder, + state: &[ExtensionAlgebraTarget; WIDTH], + ) -> [ExtensionAlgebraTarget; WIDTH] { + let mut result = [builder.zero_ext_algebra(); WIDTH]; + + for r in 0..WIDTH { + result[r] = Self::mds_row_shf_algebra_recursive(builder, r, state); + } + + result + } +} + +impl + Poseidon, const D: usize, const WIDTH: usize> Gate + for PoseidonMdsGate +where + [(); WIDTH - 1]: , +{ + fn id(&self) -> String { + format!("{:?}", self, WIDTH) + } + + fn eval_unfiltered(&self, vars: EvaluationVars) -> Vec { + let inputs: [_; WIDTH] = (0..WIDTH) + .map(|i| vars.get_local_ext_algebra(Self::wires_input(i))) + .collect::>() + .try_into() + .unwrap(); + + let computed_outputs = Self::mds_layer_algebra(&inputs); + + (0..WIDTH) + .map(|i| vars.get_local_ext_algebra(Self::wires_output(i))) + .zip(computed_outputs) + .flat_map(|(out, computed_out)| (out - computed_out).to_basefield_array()) + .collect() + } + + fn eval_unfiltered_base(&self, vars: EvaluationVarsBase) -> Vec { + let inputs: [_; WIDTH] = (0..WIDTH) + .map(|i| vars.get_local_ext(Self::wires_input(i))) + .collect::>() + .try_into() + .unwrap(); + + let computed_outputs = F::mds_layer_field(&inputs); + + (0..WIDTH) + .map(|i| vars.get_local_ext(Self::wires_output(i))) + .zip(computed_outputs) + .flat_map(|(out, computed_out)| (out - computed_out).to_basefield_array()) + .collect() + } + + fn eval_unfiltered_recursively( + &self, + builder: &mut CircuitBuilder, + vars: EvaluationTargets, + ) -> Vec> { + let inputs: [_; WIDTH] = (0..WIDTH) + .map(|i| vars.get_local_ext_algebra(Self::wires_input(i))) + .collect::>() + .try_into() + .unwrap(); + + let computed_outputs = Self::mds_layer_algebra_recursive(builder, &inputs); + + (0..WIDTH) + .map(|i| vars.get_local_ext_algebra(Self::wires_output(i))) + .zip(computed_outputs) + .flat_map(|(out, computed_out)| { + builder + .sub_ext_algebra(out, computed_out) + .to_ext_target_array() + }) + .collect() + } + + fn generators( + &self, + gate_index: usize, + _local_constants: &[F], + ) -> Vec>> { + let gen = PoseidonMdsGenerator:: { gate_index }; + vec![Box::new(gen.adapter())] + } + + fn num_wires(&self) -> usize { + 2 * D * WIDTH + } + + fn num_constants(&self) -> usize { + 0 + } + + fn degree(&self) -> usize { + 1 + } + + fn num_constraints(&self) -> usize { + WIDTH * D + } +} + +#[derive(Clone, Debug)] +struct PoseidonMdsGenerator +where + [(); WIDTH - 1]: , +{ + gate_index: usize, +} + +impl + Poseidon, const D: usize, const WIDTH: usize> + SimpleGenerator for PoseidonMdsGenerator +where + [(); WIDTH - 1]: , +{ + fn dependencies(&self) -> Vec { + (0..WIDTH) + .flat_map(|i| { + Target::wires_from_range( + self.gate_index, + PoseidonMdsGate::::wires_input(i), + ) + }) + .collect() + } + + fn run_once(&self, witness: &PartitionWitness, out_buffer: &mut GeneratedValues) { + let get_local_get_target = + |wire_range| ExtensionTarget::from_range(self.gate_index, wire_range); + let get_local_ext = + |wire_range| witness.get_extension_target(get_local_get_target(wire_range)); + + let inputs: [_; WIDTH] = (0..WIDTH) + .map(|i| get_local_ext(PoseidonMdsGate::::wires_input(i))) + .collect::>() + .try_into() + .unwrap(); + + let outputs = F::mds_layer_field(&inputs); + + for (i, &out) in outputs.iter().enumerate() { + out_buffer.set_extension_target( + get_local_get_target(PoseidonMdsGate::::wires_output(i)), + out, + ); + } + } +} + +#[cfg(test)] +mod tests { + use crate::field::goldilocks_field::GoldilocksField; + use crate::gates::gate_testing::{test_eval_fns, test_low_degree}; + use crate::gates::poseidon_mds::PoseidonMdsGate; + use crate::hash::hashing::SPONGE_WIDTH; + + #[test] + fn low_degree() { + type F = GoldilocksField; + let gate = PoseidonMdsGate::::new(); + test_low_degree(gate) + } + + #[test] + fn eval_fns() -> anyhow::Result<()> { + type F = GoldilocksField; + let gate = PoseidonMdsGate::::new(); + test_eval_fns(gate) + } +} diff --git a/src/hash/poseidon.rs b/src/hash/poseidon.rs index bb3e8a75..193ef2c3 100644 --- a/src/hash/poseidon.rs +++ b/src/hash/poseidon.rs @@ -1,11 +1,15 @@ //! Implementation of the Poseidon hash function, as described in //! https://eprint.iacr.org/2019/458.pdf +use std::convert::TryInto; + use unroll::unroll_for_loops; use crate::field::extension_field::target::ExtensionTarget; use crate::field::extension_field::{Extendable, FieldExtension}; use crate::field::field_types::{PrimeField, RichField}; +use crate::gates::gate::Gate; +use crate::gates::poseidon_mds::PoseidonMdsGate; use crate::plonk::circuit_builder::CircuitBuilder; // The number of full rounds and partial rounds is given by the @@ -205,17 +209,20 @@ where } /// Recursive version of `mds_row_shf`. - fn mds_row_shf_recursive, const D: usize>( - builder: &mut CircuitBuilder, + fn mds_row_shf_recursive( + builder: &mut CircuitBuilder, r: usize, v: &[ExtensionTarget; WIDTH], - ) -> ExtensionTarget { + ) -> ExtensionTarget + where + Self: RichField + Extendable, + { debug_assert!(r < WIDTH); let mut res = builder.zero_extension(); for i in 0..WIDTH { res = builder.mul_const_add_extension( - F::from_canonical_u64(1 << Self::MDS_MATRIX_EXPS[i]), + Self::from_canonical_u64(1 << >::MDS_MATRIX_EXPS[i]), v[(i + r) % WIDTH], res, ); @@ -262,17 +269,38 @@ where } /// Recursive version of `mds_layer`. - fn mds_layer_recursive, const D: usize>( - builder: &mut CircuitBuilder, + fn mds_layer_recursive( + builder: &mut CircuitBuilder, state: &[ExtensionTarget; WIDTH], - ) -> [ExtensionTarget; WIDTH] { - let mut result = [builder.zero_extension(); WIDTH]; + ) -> [ExtensionTarget; WIDTH] + where + Self: RichField + Extendable, + { + // If we have enough routed wires, we will use PoseidonMdsGate. + let mds_gate = PoseidonMdsGate::::new(); + if builder.config.num_routed_wires >= mds_gate.num_wires() { + let index = builder.add_gate(mds_gate, vec![]); + for i in 0..WIDTH { + let input_wire = PoseidonMdsGate::::wires_input(i); + builder.connect_extension(state[i], ExtensionTarget::from_range(index, input_wire)); + } + (0..WIDTH) + .map(|i| { + let output_wire = PoseidonMdsGate::::wires_output(i); + ExtensionTarget::from_range(index, output_wire) + }) + .collect::>() + .try_into() + .unwrap() + } else { + let mut result = [builder.zero_extension(); WIDTH]; - for r in 0..WIDTH { - result[r] = Self::mds_row_shf_recursive(builder, r, state); + for r in 0..WIDTH { + result[r] = Self::mds_row_shf_recursive(builder, r, state); + } + + result } - - result } #[inline(always)] @@ -289,14 +317,18 @@ where } /// Recursive version of `partial_first_constant_layer`. - fn partial_first_constant_layer_recursive, const D: usize>( - builder: &mut CircuitBuilder, + fn partial_first_constant_layer_recursive( + builder: &mut CircuitBuilder, state: &mut [ExtensionTarget; WIDTH], - ) { + ) where + Self: RichField + Extendable, + { for i in 0..WIDTH { state[i] = builder.add_const_extension( state[i], - F::from_canonical_u64(Self::FAST_PARTIAL_FIRST_ROUND_CONSTANT[i]), + Self::from_canonical_u64( + >::FAST_PARTIAL_FIRST_ROUND_CONSTANT[i], + ), ); } } @@ -334,18 +366,22 @@ where } /// Recursive version of `mds_partial_layer_init`. - fn mds_partial_layer_init_recursive, const D: usize>( - builder: &mut CircuitBuilder, + fn mds_partial_layer_init_recursive( + builder: &mut CircuitBuilder, state: &[ExtensionTarget; WIDTH], - ) -> [ExtensionTarget; WIDTH] { + ) -> [ExtensionTarget; WIDTH] + where + Self: RichField + Extendable, + { let mut result = [builder.zero_extension(); WIDTH]; result[0] = state[0]; for r in 1..WIDTH { for c in 1..WIDTH { - let t = - F::from_canonical_u64(Self::FAST_PARTIAL_ROUND_INITIAL_MATRIX[r - 1][c - 1]); + let t = Self::from_canonical_u64( + >::FAST_PARTIAL_ROUND_INITIAL_MATRIX[r - 1][c - 1], + ); result[c] = builder.mul_const_add_extension(t, state[r], result[c]); } } @@ -414,23 +450,32 @@ where } /// Recursive version of `mds_partial_layer_fast`. - fn mds_partial_layer_fast_recursive, const D: usize>( - builder: &mut CircuitBuilder, + fn mds_partial_layer_fast_recursive( + builder: &mut CircuitBuilder, state: &[ExtensionTarget; WIDTH], r: usize, - ) -> [ExtensionTarget; WIDTH] { + ) -> [ExtensionTarget; WIDTH] + where + Self: RichField + Extendable, + { let s0 = state[0]; - let mut d = - builder.mul_const_extension(F::from_canonical_u64(1 << Self::MDS_MATRIX_EXPS[0]), s0); + let mut d = builder.mul_const_extension( + Self::from_canonical_u64(1 << >::MDS_MATRIX_EXPS[0]), + s0, + ); for i in 1..WIDTH { - let t = F::from_canonical_u64(Self::FAST_PARTIAL_ROUND_W_HATS[r][i - 1]); + let t = Self::from_canonical_u64( + >::FAST_PARTIAL_ROUND_W_HATS[r][i - 1], + ); d = builder.mul_const_add_extension(t, state[i], d); } let mut result = [builder.zero_extension(); WIDTH]; result[0] = d; for i in 1..WIDTH { - let t = F::from_canonical_u64(Self::FAST_PARTIAL_ROUND_VS[r][i - 1]); + let t = Self::from_canonical_u64( + >::FAST_PARTIAL_ROUND_VS[r][i - 1], + ); result[i] = builder.mul_const_add_extension(t, state[0], state[i]); } result @@ -461,15 +506,17 @@ where } /// Recursive version of `constant_layer`. - fn constant_layer_recursive, const D: usize>( - builder: &mut CircuitBuilder, + fn constant_layer_recursive( + builder: &mut CircuitBuilder, state: &mut [ExtensionTarget; WIDTH], round_ctr: usize, - ) { + ) where + Self: RichField + Extendable, + { for i in 0..WIDTH { state[i] = builder.add_const_extension( state[i], - F::from_canonical_u64(ALL_ROUND_CONSTANTS[i + WIDTH * round_ctr]), + Self::from_canonical_u64(ALL_ROUND_CONSTANTS[i + WIDTH * round_ctr]), ); } } @@ -484,10 +531,13 @@ where } /// Recursive version of `sbox_monomial`. - fn sbox_monomial_recursive, const D: usize>( - builder: &mut CircuitBuilder, + fn sbox_monomial_recursive( + builder: &mut CircuitBuilder, x: ExtensionTarget, - ) -> ExtensionTarget { + ) -> ExtensionTarget + where + Self: RichField + Extendable, + { // x |--> x^7 builder.exp_u64_extension(x, 7) } @@ -513,12 +563,14 @@ where } /// Recursive version of `sbox_layer`. - fn sbox_layer_recursive, const D: usize>( - builder: &mut CircuitBuilder, + fn sbox_layer_recursive( + builder: &mut CircuitBuilder, state: &mut [ExtensionTarget; WIDTH], - ) { + ) where + Self: RichField + Extendable, + { for i in 0..WIDTH { - state[i] = Self::sbox_monomial_recursive(builder, state[i]); + state[i] = >::sbox_monomial_recursive(builder, state[i]); } }