From 49ba7ccb52e8ffaa909326800ebfc381310961d9 Mon Sep 17 00:00:00 2001 From: wborgeaud Date: Thu, 16 Sep 2021 18:16:05 +0200 Subject: [PATCH] Working --- src/gates/poseidon.rs | 407 ++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 407 insertions(+) create mode 100644 src/gates/poseidon.rs diff --git a/src/gates/poseidon.rs b/src/gates/poseidon.rs new file mode 100644 index 00000000..2d333696 --- /dev/null +++ b/src/gates/poseidon.rs @@ -0,0 +1,407 @@ +use std::convert::TryInto; +use std::marker::PhantomData; + +use crate::field::extension_field::target::ExtensionTarget; +use crate::field::extension_field::Extendable; +use crate::field::field_types::{Field, RichField}; +use crate::gates::gate::Gate; +use crate::hash::poseidon; +use crate::hash::poseidon::Poseidon; +use crate::iop::generator::{GeneratedValues, SimpleGenerator, WitnessGenerator}; +use crate::iop::target::Target; +use crate::iop::wire::Wire; +use crate::iop::witness::{PartitionWitness, Witness}; +use crate::plonk::circuit_builder::CircuitBuilder; +use crate::plonk::vars::{EvaluationTargets, EvaluationVars, EvaluationVarsBase}; + +/// Evaluates a full Poseidon permutation with 12 state elements. +/// +/// This also has some extra features to make it suitable for efficiently verifying Merkle proofs. +/// It has a flag which can be used to swap the first four inputs with the next four, for ordering +/// sibling digests. It also has an accumulator that computes the weighted sum of these flags, for +/// computing the index of the leaf based on these swap bits. +#[derive(Debug)] +pub struct PoseidonGate< + F: RichField + Extendable + Poseidon, + const D: usize, + const WIDTH: usize, +> where + [(); WIDTH - 1]: , +{ + _phantom: PhantomData, +} + +impl + Poseidon, const D: usize, const WIDTH: usize> + PoseidonGate +where + [(); WIDTH - 1]: , +{ + pub fn new() -> Self { + PoseidonGate { + _phantom: PhantomData, + } + } + + /// The wire index for the `i`th input to the permutation. + pub fn wire_input(i: usize) -> usize { + i + } + + /// The wire index for the `i`th output to the permutation. + pub fn wire_output(i: usize) -> usize { + WIDTH + i + } + + /// If this is set to 1, the first four inputs will be swapped with the next four inputs. This + /// is useful for ordering hashes in Merkle proofs. Otherwise, this should be set to 0. + pub const WIRE_SWAP: usize = 2 * WIDTH; + + /// A wire which stores the input of the `i`-th S-box of the `round`-th round of the first set + /// of full rounds. + fn wire_full_sbox_0(round: usize, i: usize) -> usize { + 2 * WIDTH + 1 + WIDTH * round + i + } + + /// A wire which stores the input of the S-box of the `round`-th round of the partial rounds. + fn wire_partial_sbox(round: usize) -> usize { + 2 * WIDTH + 1 + WIDTH * poseidon::HALF_N_FULL_ROUNDS + round + } + + /// A wire which stores the input of the `i`-th S-box of the `round`-th round of the second set + /// of full rounds. + fn wire_full_sbox_1(round: usize, i: usize) -> usize { + 2 * WIDTH + + 1 + + WIDTH * (poseidon::HALF_N_FULL_ROUNDS + round) + + poseidon::N_PARTIAL_ROUNDS + + i + } + + /// End of wire indices, exclusive. + fn end() -> usize { + 2 * WIDTH + 1 + WIDTH * poseidon::N_FULL_ROUNDS_TOTAL + poseidon::N_PARTIAL_ROUNDS + } +} + +impl + Poseidon, const D: usize, const WIDTH: usize> Gate + for PoseidonGate +where + [(); WIDTH - 1]: , +{ + fn id(&self) -> String { + format!(" {:?}", WIDTH, self) + } + + fn eval_unfiltered(&self, vars: EvaluationVars) -> Vec { + let mut constraints = Vec::with_capacity(self.num_constraints()); + + // Assert that `swap` is binary. + let swap = vars.local_wires[Self::WIRE_SWAP]; + constraints.push(swap * (swap - F::Extension::ONE)); + + let mut state = Vec::with_capacity(12); + for i in 0..4 { + let a = vars.local_wires[i]; + let b = vars.local_wires[i + 4]; + state.push(a + swap * (b - a)); + } + for i in 0..4 { + let a = vars.local_wires[i + 4]; + let b = vars.local_wires[i]; + state.push(a + swap * (b - a)); + } + for i in 8..12 { + state.push(vars.local_wires[i]); + } + + let mut state: [F::Extension; WIDTH] = state.try_into().unwrap(); + let mut round_ctr = 0; + + for r in 0..poseidon::HALF_N_FULL_ROUNDS { + >::constant_layer(&mut state, round_ctr); + for i in 0..WIDTH { + let sbox_in = vars.local_wires[Self::wire_full_sbox_0(r, i)]; + constraints.push(state[i] - sbox_in); + state[i] = sbox_in; + } + >::sbox_layer(&mut state); + state = >::mds_layer_field(&state); + round_ctr += 1; + } + + >::partial_first_constant_layer(&mut state); + state = >::mds_partial_layer_init(&mut state); + // for r in 0..(poseidon::N_PARTIAL_ROUNDS - 1) { + for r in 0..(poseidon::N_PARTIAL_ROUNDS - 1) { + let sbox_in = vars.local_wires[Self::wire_partial_sbox(r)]; + constraints.push(state[0] - sbox_in); + state[0] = >::sbox_monomial(sbox_in); + state[0] += F::Extension::from_canonical_u64( + >::FAST_PARTIAL_ROUND_CONSTANTS[r], + ); + state = >::mds_partial_layer_fast_field(&state, r); + } + let sbox_in = vars.local_wires[Self::wire_partial_sbox(poseidon::N_PARTIAL_ROUNDS - 1)]; + constraints.push(state[0] - sbox_in); + state[0] = >::sbox_monomial(sbox_in); + state = >::mds_partial_layer_fast_field( + &state, + poseidon::N_PARTIAL_ROUNDS - 1, + ); + round_ctr += poseidon::N_PARTIAL_ROUNDS; + + for r in 0..poseidon::HALF_N_FULL_ROUNDS { + >::constant_layer(&mut state, round_ctr); + for i in 0..WIDTH { + let sbox_in = vars.local_wires[Self::wire_full_sbox_1(r, i)]; + constraints.push(state[i] - sbox_in); + state[i] = sbox_in; + } + >::sbox_layer(&mut state); + state = >::mds_layer_field(&state); + round_ctr += 1; + } + + for i in 0..WIDTH { + constraints.push(state[i] - vars.local_wires[Self::wire_output(i)]); + } + + constraints + } + + fn eval_unfiltered_base(&self, vars: EvaluationVarsBase) -> Vec { + todo!() + } + + fn eval_unfiltered_recursively( + &self, + builder: &mut CircuitBuilder, + vars: EvaluationTargets, + ) -> Vec> { + todo!() + } + + fn generators( + &self, + gate_index: usize, + _local_constants: &[F], + ) -> Vec>> { + let gen = PoseidonGenerator:: { + gate_index, + _phantom: PhantomData, + }; + vec![Box::new(gen.adapter())] + } + + fn num_wires(&self) -> usize { + Self::end() + } + + fn num_constants(&self) -> usize { + 0 + } + + fn degree(&self) -> usize { + 7 + } + + fn num_constraints(&self) -> usize { + WIDTH * poseidon::N_FULL_ROUNDS_TOTAL + poseidon::N_PARTIAL_ROUNDS + WIDTH + 1 + } +} + +#[derive(Debug)] +struct PoseidonGenerator< + F: RichField + Extendable + Poseidon, + const D: usize, + const WIDTH: usize, +> where + [(); WIDTH - 1]: , +{ + gate_index: usize, + _phantom: PhantomData, +} + +impl + Poseidon, const D: usize, const WIDTH: usize> + SimpleGenerator for PoseidonGenerator +where + [(); WIDTH - 1]: , +{ + fn dependencies(&self) -> Vec { + (0..WIDTH) + .map(|i| PoseidonGate::::wire_input(i)) + .chain(Some(PoseidonGate::::WIRE_SWAP)) + .map(|input| Target::wire(self.gate_index, input)) + .collect() + } + + fn run_once(&self, witness: &PartitionWitness, out_buffer: &mut GeneratedValues) { + let local_wire = |input| Wire { + gate: self.gate_index, + input, + }; + + let mut state = (0..WIDTH) + .map(|i| { + witness.get_wire(Wire { + gate: self.gate_index, + input: PoseidonGate::::wire_input(i), + }) + }) + .collect::>(); + + let swap_value = witness.get_wire(Wire { + gate: self.gate_index, + input: PoseidonGate::::WIRE_SWAP, + }); + debug_assert!(swap_value == F::ZERO || swap_value == F::ONE); + if swap_value == F::ONE { + for i in 0..4 { + state.swap(i, 4 + i); + } + } + + let mut state: [F; WIDTH] = state.try_into().unwrap(); + let mut round_ctr = 0; + + for r in 0..poseidon::HALF_N_FULL_ROUNDS { + >::constant_layer(&mut state, round_ctr); + for i in 0..WIDTH { + out_buffer.set_wire( + local_wire(PoseidonGate::::wire_full_sbox_0(r, i)), + state[i], + ); + } + >::sbox_layer(&mut state); + state = >::mds_layer_field(&state); + round_ctr += 1; + } + + >::partial_first_constant_layer(&mut state); + state = >::mds_partial_layer_init(&mut state); + for r in 0..(poseidon::N_PARTIAL_ROUNDS - 1) { + out_buffer.set_wire( + local_wire(PoseidonGate::::wire_partial_sbox(r)), + state[0], + ); + state[0] = >::sbox_monomial(state[0]); + state[0] += + F::from_canonical_u64(>::FAST_PARTIAL_ROUND_CONSTANTS[r]); + state = >::mds_partial_layer_fast_field(&state, r); + } + out_buffer.set_wire( + local_wire(PoseidonGate::::wire_partial_sbox( + poseidon::N_PARTIAL_ROUNDS - 1, + )), + state[0], + ); + state[0] = >::sbox_monomial(state[0]); + state = >::mds_partial_layer_fast_field( + &state, + poseidon::N_PARTIAL_ROUNDS - 1, + ); + round_ctr += poseidon::N_PARTIAL_ROUNDS; + + for r in 0..poseidon::HALF_N_FULL_ROUNDS { + >::constant_layer(&mut state, round_ctr); + for i in 0..WIDTH { + out_buffer.set_wire( + local_wire(PoseidonGate::::wire_full_sbox_1(r, i)), + state[i], + ); + } + >::sbox_layer(&mut state); + state = >::mds_layer_field(&state); + round_ctr += 1; + } + + for i in 0..WIDTH { + out_buffer.set_wire( + local_wire(PoseidonGate::::wire_output(i)), + state[i], + ); + } + } +} + +mod yo {} +#[cfg(test)] +mod tests { + use std::convert::TryInto; + + use anyhow::Result; + + use crate::field::crandall_field::CrandallField; + use crate::field::field_types::Field; + use crate::gates::gate_testing::{test_eval_fns, test_low_degree}; + use crate::gates::poseidon::PoseidonGate; + use crate::hash::poseidon::Poseidon; + use crate::iop::generator::generate_partial_witness; + use crate::iop::wire::Wire; + use crate::iop::witness::{PartialWitness, Witness}; + use crate::plonk::circuit_builder::CircuitBuilder; + use crate::plonk::circuit_data::CircuitConfig; + + #[test] + fn generated_output() { + type F = CrandallField; + const WIDTH: usize = 12; + + let config = CircuitConfig { + num_wires: 143, + ..CircuitConfig::large_config() + }; + let mut builder = CircuitBuilder::new(config); + type Gate = PoseidonGate; + let gate = Gate::new(); + let gate_index = builder.add_gate(gate, vec![]); + let circuit = builder.build_prover(); + + let permutation_inputs = (0..WIDTH).map(F::from_canonical_usize).collect::>(); + + let mut inputs = PartialWitness::new(); + inputs.set_wire( + Wire { + gate: gate_index, + input: Gate::WIRE_SWAP, + }, + F::ZERO, + ); + for i in 0..WIDTH { + inputs.set_wire( + Wire { + gate: gate_index, + input: Gate::wire_input(i), + }, + permutation_inputs[i], + ); + } + + let witness = generate_partial_witness(inputs, &circuit.prover_only); + + let expected_outputs: [F; WIDTH] = F::poseidon(permutation_inputs.try_into().unwrap()); + for i in 0..WIDTH { + let out = witness.get_wire(Wire { + gate: 0, + input: Gate::wire_output(i), + }); + assert_eq!(out, expected_outputs[i]); + } + } + + #[test] + fn low_degree() { + type F = CrandallField; + const WIDTH: usize = 12; + let gate = PoseidonGate::::new(); + test_low_degree(gate) + } + + #[test] + fn eval_fns() -> Result<()> { + type F = CrandallField; + const WIDTH: usize = 12; + let gate = PoseidonGate::::new(); + test_eval_fns(gate) + } +}