//! Implementation of the Poseidon2 hash function as Plonky2 Gate //! based on Poseidon Gate: //! https://github.com/0xPolygonZero/plonky2/blob/main/plonky2/src/gates/poseidon.rs //! use core::marker::PhantomData; use plonky2_field::extension::Extendable; use plonky2_field::types::Field; use plonky2::gates::gate::Gate; use plonky2::gates::util::StridedConstraintConsumer; use plonky2::hash::hash_types::RichField; use crate::poseidon2_hash::poseidon2::{Poseidon2, FULL_ROUND_BEGIN, FULL_ROUND_END, PARTIAL_ROUNDS, SPONGE_WIDTH}; use plonky2::iop::ext_target::ExtensionTarget; use plonky2::iop::generator::{GeneratedValues, SimpleGenerator, WitnessGeneratorRef}; use plonky2::iop::target::Target; use plonky2::iop::wire::Wire; use plonky2::iop::witness::{PartitionWitness, Witness, WitnessWrite}; use plonky2::plonk::circuit_builder::CircuitBuilder; use plonky2::plonk::circuit_data::CommonCircuitData; use plonky2::plonk::vars::{EvaluationTargets, EvaluationVars, EvaluationVarsBase}; use plonky2::util::serialization::{Buffer, IoResult, Read, Write}; /// Evaluates a full Poseidon2 permutation with 12 state elements. /// /// This also has some extra features to make it suitable for efficiently /// verifying Merkle proofs. It has a flag which can be used to swap the first /// four inputs with the next four, for ordering sibling digests. #[derive(Debug, Default)] pub struct Poseidon2Gate, const D: usize>(PhantomData); impl, const D: usize> Poseidon2Gate { pub fn new() -> Self { Self(PhantomData) } /// The wire index for the `i`th input to the permutation. pub fn wire_input(i: usize) -> usize { i } /// The wire index for the `i`th output to the permutation. pub fn wire_output(i: usize) -> usize { SPONGE_WIDTH + i } /// If this is set to 1, the first four inputs will be swapped with the next /// four inputs. This is useful for ordering hashes in Merkle proofs. /// Otherwise, this should be set to 0. pub const WIRE_SWAP: usize = 2 * SPONGE_WIDTH; const START_DELTA: usize = 2 * SPONGE_WIDTH + 1; /// A wire which stores `swap * (input[i + 4] - input[i])`; used to compute /// the swapped inputs. fn wire_delta(i: usize) -> usize { assert!(i < 4); Self::START_DELTA + i } const START_FULL_ROUND_BEGIN: usize = Self::START_DELTA + 4; /// A wire which stores the input of the `i`-th S-box of the `round`-th /// round of the first set of full rounds. fn wire_first_full_round(round: usize, i: usize) -> usize { debug_assert!( round != 0, "First round S-box inputs are not stored as wires" ); debug_assert!(round < FULL_ROUND_BEGIN); debug_assert!(i < SPONGE_WIDTH); Self::START_FULL_ROUND_BEGIN + SPONGE_WIDTH * (round - 1) + i } const START_PARTIAL: usize = Self::START_FULL_ROUND_BEGIN + SPONGE_WIDTH * (FULL_ROUND_BEGIN - 1); /// A wire which stores the input of the S-box of the `round`-th round of /// the partial rounds. fn wire_partial_round(round: usize) -> usize { debug_assert!(round < PARTIAL_ROUNDS); Self::START_PARTIAL + round } const START_FULL_ROUND_END: usize = Self::START_PARTIAL + PARTIAL_ROUNDS; /// A wire which stores the input of the `i`-th S-box of the `round`-th /// round of the second set of full rounds. fn wire_second_full_round(round: usize, i: usize) -> usize { debug_assert!(round < FULL_ROUND_BEGIN); debug_assert!(i < SPONGE_WIDTH); Self::START_FULL_ROUND_END + SPONGE_WIDTH * round + i } /// End of wire indices, exclusive. fn end() -> usize { Self::START_FULL_ROUND_END + SPONGE_WIDTH * FULL_ROUND_BEGIN } } impl + Poseidon2, const D: usize> Gate for Poseidon2Gate { fn id(&self) -> String { format!("{:?}", self, SPONGE_WIDTH) } fn serialize( &self, _dst: &mut Vec, _common_data: &CommonCircuitData, ) -> IoResult<()> { Ok(()) } fn deserialize(_src: &mut Buffer, _common_data: &CommonCircuitData) -> IoResult { Ok(Poseidon2Gate::new()) } fn eval_unfiltered(&self, vars: EvaluationVars) -> Vec { let mut constraints = Vec::with_capacity(self.num_constraints()); // Assert that `swap` is binary. let swap = vars.local_wires[Self::WIRE_SWAP]; constraints.push(swap * (swap - F::Extension::ONE)); // Assert that each delta wire is set properly: `delta_i = swap * (rhs - lhs)`. for i in 0..4 { let input_lhs = vars.local_wires[Self::wire_input(i)]; let input_rhs = vars.local_wires[Self::wire_input(i + 4)]; let delta_i = vars.local_wires[Self::wire_delta(i)]; constraints.push(swap * (input_rhs - input_lhs) - delta_i); } // Compute the possibly-swapped input layer. let mut state = [F::Extension::ZERO; SPONGE_WIDTH]; for i in 0..4 { let delta_i = vars.local_wires[Self::wire_delta(i)]; let input_lhs = Self::wire_input(i); let input_rhs = Self::wire_input(i + 4); state[i] = vars.local_wires[input_lhs] + delta_i; state[i + 4] = vars.local_wires[input_rhs] - delta_i; } for i in 8..SPONGE_WIDTH { state[i] = vars.local_wires[Self::wire_input(i)]; } // linear layer ::matmul_external_field(&mut state); // First External layer for r in 0..FULL_ROUND_BEGIN { ::constant_layer_field(&mut state, r); if r != 0 { for i in 0..SPONGE_WIDTH { let sbox_in = vars.local_wires[Self::wire_first_full_round(r, i)]; constraints.push(state[i] - sbox_in); state[i] = sbox_in; } } ::sbox_layer_field(&mut state); ::matmul_external_field(&mut state); } // Internal layer for r in 0..PARTIAL_ROUNDS { state[0] += F::Extension::from_canonical_u64(::RC12_MID[r]); let sbox_in = vars.local_wires[Self::wire_partial_round(r)]; constraints.push(state[0] - sbox_in); state[0] = ::sbox_p(sbox_in); ::matmul_internal_field(&mut state, &::MAT_DIAG12_M_1); } // Second External layer for r in FULL_ROUND_BEGIN..FULL_ROUND_END { ::constant_layer_field(&mut state, r); for i in 0..SPONGE_WIDTH { let sbox_in = vars.local_wires[Self::wire_second_full_round(r - FULL_ROUND_BEGIN, i)]; constraints.push(state[i] - sbox_in); state[i] = sbox_in; } ::sbox_layer_field(&mut state); ::matmul_external_field(&mut state); } //12 constraints for i in 0..SPONGE_WIDTH { constraints.push(state[i] - vars.local_wires[Self::wire_output(i)]); } constraints } fn eval_unfiltered_base_one( &self, vars: EvaluationVarsBase, mut yield_constr: StridedConstraintConsumer, ) { // Assert that `swap` is binary. let swap = vars.local_wires[Self::WIRE_SWAP]; yield_constr.one(swap * swap.sub_one()); // Assert that each delta wire is set properly: `delta_i = swap * (rhs - lhs)`. for i in 0..4 { let input_lhs = vars.local_wires[Self::wire_input(i)]; let input_rhs = vars.local_wires[Self::wire_input(i + 4)]; let delta_i = vars.local_wires[Self::wire_delta(i)]; yield_constr.one(swap * (input_rhs - input_lhs) - delta_i); } // Compute the possibly-swapped input layer. let mut state = [F::ZERO; SPONGE_WIDTH]; for i in 0..4 { let delta_i = vars.local_wires[Self::wire_delta(i)]; let input_lhs = Self::wire_input(i); let input_rhs = Self::wire_input(i + 4); state[i] = vars.local_wires[input_lhs] + delta_i; state[i + 4] = vars.local_wires[input_rhs] - delta_i; } for i in 8..SPONGE_WIDTH { state[i] = vars.local_wires[Self::wire_input(i)]; } // linear layer ::matmul_external(&mut state); // First External layer for r in 0..FULL_ROUND_BEGIN { ::constant_layer(&mut state, r); if r != 0 { for i in 0..SPONGE_WIDTH { let sbox_in = vars.local_wires[Self::wire_first_full_round(r, i)]; yield_constr.one(state[i] - sbox_in); state[i] = sbox_in; } } ::sbox_layer(&mut state); ::matmul_external(&mut state); } // Internal layer for r in 0..PARTIAL_ROUNDS { state[0] += F::from_canonical_u64(::RC12_MID[r]); let sbox_in = vars.local_wires[Self::wire_partial_round(r)]; yield_constr.one(state[0] - sbox_in); state[0] = sbox_in; state[0] = ::sbox_p(state[0]); ::matmul_internal(&mut state, &::MAT_DIAG12_M_1); } // Second External layer for r in FULL_ROUND_BEGIN..FULL_ROUND_END { ::constant_layer(&mut state, r); for i in 0..SPONGE_WIDTH { let sbox_in = vars.local_wires[Self::wire_second_full_round(r - FULL_ROUND_BEGIN, i)]; yield_constr.one(state[i] - sbox_in); state[i] = sbox_in; } ::sbox_layer(&mut state); ::matmul_external(&mut state); } for i in 0..SPONGE_WIDTH { yield_constr.one(state[i] - vars.local_wires[Self::wire_output(i)]); } } fn eval_unfiltered_circuit( &self, builder: &mut CircuitBuilder, vars: EvaluationTargets, ) -> Vec> { let mut constraints = Vec::with_capacity(self.num_constraints()); // Assert that `swap` is binary. let swap = vars.local_wires[Self::WIRE_SWAP]; constraints.push(builder.mul_sub_extension(swap, swap, swap)); // Assert that each delta wire is set properly: `delta_i = swap * (rhs - lhs)`. for i in 0..4 { let input_lhs = vars.local_wires[Self::wire_input(i)]; let input_rhs = vars.local_wires[Self::wire_input(i + 4)]; let delta_i = vars.local_wires[Self::wire_delta(i)]; let diff = builder.sub_extension(input_rhs, input_lhs); constraints.push(builder.mul_sub_extension(swap, diff, delta_i)); } // Compute the possibly-swapped input layer. let mut state = [builder.zero_extension(); SPONGE_WIDTH]; for i in 0..4 { let delta_i = vars.local_wires[Self::wire_delta(i)]; let input_lhs = vars.local_wires[Self::wire_input(i)]; let input_rhs = vars.local_wires[Self::wire_input(i + 4)]; state[i] = builder.add_extension(input_lhs, delta_i); state[i + 4] = builder.sub_extension(input_rhs, delta_i); } for i in 8..SPONGE_WIDTH { state[i] = vars.local_wires[Self::wire_input(i)]; } // linear layer state = ::matmul_external_circuit(builder, &mut state); // First External layer for r in 0..FULL_ROUND_BEGIN { ::constant_layer_circuit(builder, &mut state, r); if r != 0 { for i in 0..SPONGE_WIDTH { let sbox_in = vars.local_wires[Self::wire_first_full_round(r, i)]; constraints.push(builder.sub_extension(state[i], sbox_in)); state[i] = sbox_in; } } ::sbox_layer_circuit(builder, &mut state); state = ::matmul_external_circuit(builder, &mut state); } // Internal layer for r in 0..PARTIAL_ROUNDS { let round_constant = F::Extension::from_canonical_u64(::RC12_MID[r]); let round_constant = builder.constant_extension(round_constant); state[0] = builder.add_extension(state[0], round_constant); let sbox_in = vars.local_wires[Self::wire_partial_round(r)]; constraints.push(builder.sub_extension(state[0], sbox_in)); //state[0] = sbox_in; state[0] = ::sbox_p_circuit(builder, sbox_in); ::matmul_internal_circuit(builder, &mut state); } // Second External layer for r in FULL_ROUND_BEGIN..FULL_ROUND_END { ::constant_layer_circuit(builder, &mut state, r); for i in 0..SPONGE_WIDTH { let sbox_in = vars.local_wires[Self::wire_second_full_round(r - FULL_ROUND_BEGIN, i)]; constraints.push(builder.sub_extension(state[i], sbox_in)); state[i] = sbox_in; } ::sbox_layer_circuit(builder, &mut state); state = ::matmul_external_circuit(builder, &mut state); } for i in 0..SPONGE_WIDTH { constraints .push(builder.sub_extension(state[i], vars.local_wires[Self::wire_output(i)])); } constraints } fn generators(&self, row: usize, _local_constants: &[F]) -> Vec> { let gen = Poseidon2Generator:: { row, _phantom: PhantomData, }; vec![WitnessGeneratorRef::new(gen.adapter())] } fn num_wires(&self) -> usize { Self::end() } fn num_constants(&self) -> usize { 0 } fn degree(&self) -> usize { 7 } fn num_constraints(&self) -> usize { SPONGE_WIDTH * (FULL_ROUND_END - 1) + PARTIAL_ROUNDS + SPONGE_WIDTH + 1 + 4 } } #[derive(Debug, Default)] pub struct Poseidon2Generator + Poseidon2, const D: usize> { row: usize, _phantom: PhantomData, } impl + Poseidon2, const D: usize> SimpleGenerator for Poseidon2Generator { fn id(&self) -> String { "Poseidon2Generator".to_string() } fn dependencies(&self) -> Vec { (0..SPONGE_WIDTH) .map(|i| Poseidon2Gate::::wire_input(i)) .chain(Some(Poseidon2Gate::::WIRE_SWAP)) .map(|column| Target::wire(self.row, column)) .collect() } fn run_once(&self, witness: &PartitionWitness, out_buffer: &mut GeneratedValues) -> anyhow::Result<()> { let local_wire = |column| Wire { row: self.row, column, }; let mut state = (0..SPONGE_WIDTH) .map(|i| witness.get_wire(local_wire(Poseidon2Gate::::wire_input(i)))) .collect::>(); let swap_value = witness.get_wire(local_wire(Poseidon2Gate::::WIRE_SWAP)); debug_assert!(swap_value == F::ZERO || swap_value == F::ONE); for i in 0..4 { let delta_i = swap_value * (state[i + 4] - state[i]); out_buffer.set_wire(local_wire(Poseidon2Gate::::wire_delta(i)), delta_i)?; } if swap_value == F::ONE { for i in 0..4 { state.swap(i, 4 + i); } } let mut state: [F; SPONGE_WIDTH] = state.try_into().unwrap(); // Linear layer ::matmul_external_field(&mut state); // first External layer for r in 0..FULL_ROUND_BEGIN { ::constant_layer_field(&mut state, r); if r != 0 { for i in 0..SPONGE_WIDTH { out_buffer.set_wire( local_wire(Poseidon2Gate::::wire_first_full_round(r, i)), state[i], )?; } } ::sbox_layer_field(&mut state); ::matmul_external_field(&mut state); } // Internal layer for r in 0..PARTIAL_ROUNDS { state[0] += F::from_canonical_u64(::RC12_MID[r]); out_buffer.set_wire( local_wire(Poseidon2Gate::::wire_partial_round(r)), state[0], )?; state[0] = ::sbox_p(state[0]); ::matmul_internal_field(&mut state, &::MAT_DIAG12_M_1); } // Second External layer for r in FULL_ROUND_BEGIN..FULL_ROUND_END { ::constant_layer_field(&mut state, r); for i in 0..SPONGE_WIDTH { out_buffer.set_wire( local_wire(Poseidon2Gate::::wire_second_full_round( r - FULL_ROUND_BEGIN, i, )), state[i], )?; } ::sbox_layer_field(&mut state); ::matmul_external_field(&mut state); } for i in 0..SPONGE_WIDTH { out_buffer.set_wire(local_wire(Poseidon2Gate::::wire_output(i)), state[i])?; } Ok(()) } fn serialize(&self, dst: &mut Vec, _common_data: &CommonCircuitData) -> IoResult<()> { dst.write_usize(self.row) } fn deserialize(src: &mut Buffer, _common_data: &CommonCircuitData) -> IoResult { let row = src.read_usize()?; Ok(Self { row, _phantom: PhantomData, }) } } //------------------------------------- Tests ----------------------------------------- #[cfg(test)] mod tests { use anyhow::Result; use plonky2_field::goldilocks_field::GoldilocksField; use plonky2_field::types::Field; use plonky2::gates::gate_testing::{test_eval_fns, test_low_degree}; use crate::gate::poseidon2::Poseidon2Gate; use crate::poseidon2_hash::poseidon2::{Poseidon2, SPONGE_WIDTH}; use plonky2::iop::generator::generate_partial_witness; use plonky2::iop::wire::Wire; use plonky2::iop::witness::{PartialWitness, Witness, WitnessWrite}; use plonky2::plonk::circuit_builder::CircuitBuilder; use plonky2::plonk::circuit_data::CircuitConfig; use plonky2::plonk::config::GenericConfig; use crate::config::Poseidon2GoldilocksConfig; #[test] fn wire_indices() { type F = GoldilocksField; type Gate = Poseidon2Gate; assert_eq!(Gate::wire_input(0), 0); assert_eq!(Gate::wire_input(11), 11); assert_eq!(Gate::wire_output(0), 12); assert_eq!(Gate::wire_output(11), 23); assert_eq!(Gate::WIRE_SWAP, 24); assert_eq!(Gate::wire_delta(0), 25); assert_eq!(Gate::wire_delta(3), 28); } #[test] fn generated_output() -> Result<()>{ const D: usize = 2; type C = Poseidon2GoldilocksConfig; type F = >::F; let config = CircuitConfig { num_wires: 143, ..CircuitConfig::standard_recursion_config() }; let mut builder = CircuitBuilder::new(config); type Gate = Poseidon2Gate; let gate = Gate::new(); let row = builder.add_gate(gate, vec![]); let circuit = builder.build_prover::(); println!("width = {}", SPONGE_WIDTH); let permutation_inputs = (0..SPONGE_WIDTH).map(F::from_canonical_usize).collect::>(); for i in 0..SPONGE_WIDTH { println!("out {} = {}", i, permutation_inputs[i].clone()); } let mut inputs = PartialWitness::new(); inputs.set_wire( Wire { row, column: Gate::WIRE_SWAP, }, F::ZERO, )?; for i in 0..SPONGE_WIDTH { inputs.set_wire( Wire { row, column: Gate::wire_input(i), }, permutation_inputs[i], )?; } let witness = generate_partial_witness(inputs, &circuit.prover_only, &circuit.common).unwrap(); let expected_outputs: [F; SPONGE_WIDTH] = F::poseidon2(permutation_inputs.try_into().unwrap()); for i in 0..SPONGE_WIDTH { let out = witness.get_wire(Wire { row: 0, column: Gate::wire_output(i), }); println!("out {} = {}", i, out.clone()); assert_eq!(out, expected_outputs[i]); }; Ok(()) } #[test] fn low_degree() { type F = GoldilocksField; let gate = Poseidon2Gate::::new(); test_low_degree(gate) } #[test] fn eval_fns() -> Result<()> { const D: usize = 2; type C = Poseidon2GoldilocksConfig; type F = >::F; let gate = Poseidon2Gate::::new(); test_eval_fns::(gate) } #[test] fn test_proof() -> Result<()>{ use plonky2_field::types::Sample; use plonky2::gates::gate::Gate; use plonky2::hash::hash_types::HashOut; use plonky2::plonk::vars::{EvaluationTargets, EvaluationVars}; const D: usize = 2; type C = Poseidon2GoldilocksConfig; type F = >::F; let gate = Poseidon2Gate::::new(); let wires = <>::F as plonky2_field::extension::Extendable>::Extension::rand_vec(gate.num_wires()); let constants = <>::F as plonky2_field::extension::Extendable>::Extension::rand_vec(gate.num_constants()); let public_inputs_hash = HashOut::rand(); let config = CircuitConfig::standard_recursion_config(); let mut pw = PartialWitness::new(); let mut builder = CircuitBuilder::::new(config); let wires_t = builder.add_virtual_extension_targets(wires.len()); let constants_t = builder.add_virtual_extension_targets(constants.len()); pw.set_extension_targets(&wires_t, &wires)?; pw.set_extension_targets(&constants_t, &constants)?; let public_inputs_hash_t = builder.add_virtual_hash(); pw.set_hash_target(public_inputs_hash_t, public_inputs_hash)?; let vars = EvaluationVars { local_constants: &constants, local_wires: &wires, public_inputs_hash: &public_inputs_hash, }; let evals = gate.eval_unfiltered(vars); let vars_t = EvaluationTargets { local_constants: &constants_t, local_wires: &wires_t, public_inputs_hash: &public_inputs_hash_t, }; let evals_t = gate.eval_unfiltered_circuit(&mut builder, vars_t); pw.set_extension_targets(&evals_t, &evals)?; let data = builder.build::(); let proof = data.prove(pw); assert!(proof.is_ok()); Ok(()) } }