2025-04-03 11:15:09 +02:00

636 lines
23 KiB
Rust

//! Implementation of the Poseidon2 hash function as a Plonky2 gate,
//! based on Plonky2's Poseidon gate:
//! <https://github.com/0xPolygonZero/plonky2/blob/main/plonky2/src/gates/poseidon.rs>
//!
use core::marker::PhantomData;
use plonky2_field::extension::Extendable;
use plonky2_field::types::Field;
use plonky2::gates::gate::Gate;
use plonky2::gates::util::StridedConstraintConsumer;
use plonky2::hash::hash_types::RichField;
use crate::poseidon2_hash::poseidon2::{Poseidon2, FULL_ROUND_BEGIN, FULL_ROUND_END, PARTIAL_ROUNDS, SPONGE_WIDTH};
use plonky2::iop::ext_target::ExtensionTarget;
use plonky2::iop::generator::{GeneratedValues, SimpleGenerator, WitnessGeneratorRef};
use plonky2::iop::target::Target;
use plonky2::iop::wire::Wire;
use plonky2::iop::witness::{PartitionWitness, Witness, WitnessWrite};
use plonky2::plonk::circuit_builder::CircuitBuilder;
use plonky2::plonk::circuit_data::CommonCircuitData;
use plonky2::plonk::vars::{EvaluationTargets, EvaluationVars, EvaluationVarsBase};
use plonky2::util::serialization::{Buffer, IoResult, Read, Write};
/// Plonky2 gate that constrains one complete Poseidon2 permutation over a
/// state of `SPONGE_WIDTH` (12) field elements.
///
/// The gate also carries a boolean `swap` flag aimed at Merkle-proof
/// verification: when the flag is 1, the first four inputs trade places with
/// the next four before the permutation runs, which lets a circuit order
/// sibling digests without extra gates.
#[derive(Default, Debug)]
pub struct Poseidon2Gate<F, const D: usize>(PhantomData<F>)
where
    F: RichField + Extendable<D>;
impl<F: RichField + Extendable<D>, const D: usize> Poseidon2Gate<F, D> {
pub fn new() -> Self {
Self(PhantomData)
}
/// The wire index for the `i`th input to the permutation.
pub fn wire_input(i: usize) -> usize {
i
}
/// The wire index for the `i`th output to the permutation.
pub fn wire_output(i: usize) -> usize {
SPONGE_WIDTH + i
}
/// If this is set to 1, the first four inputs will be swapped with the next
/// four inputs. This is useful for ordering hashes in Merkle proofs.
/// Otherwise, this should be set to 0.
pub const WIRE_SWAP: usize = 2 * SPONGE_WIDTH;
const START_DELTA: usize = 2 * SPONGE_WIDTH + 1;
/// A wire which stores `swap * (input[i + 4] - input[i])`; used to compute
/// the swapped inputs.
fn wire_delta(i: usize) -> usize {
assert!(i < 4);
Self::START_DELTA + i
}
const START_FULL_ROUND_BEGIN: usize = Self::START_DELTA + 4;
/// A wire which stores the input of the `i`-th S-box of the `round`-th
/// round of the first set of full rounds.
fn wire_first_full_round(round: usize, i: usize) -> usize {
debug_assert!(
round != 0,
"First round S-box inputs are not stored as wires"
);
debug_assert!(round < FULL_ROUND_BEGIN);
debug_assert!(i < SPONGE_WIDTH);
Self::START_FULL_ROUND_BEGIN + SPONGE_WIDTH * (round - 1) + i
}
const START_PARTIAL: usize =
Self::START_FULL_ROUND_BEGIN + SPONGE_WIDTH * (FULL_ROUND_BEGIN - 1);
/// A wire which stores the input of the S-box of the `round`-th round of
/// the partial rounds.
fn wire_partial_round(round: usize) -> usize {
debug_assert!(round < PARTIAL_ROUNDS);
Self::START_PARTIAL + round
}
const START_FULL_ROUND_END: usize = Self::START_PARTIAL + PARTIAL_ROUNDS;
/// A wire which stores the input of the `i`-th S-box of the `round`-th
/// round of the second set of full rounds.
fn wire_second_full_round(round: usize, i: usize) -> usize {
debug_assert!(round < FULL_ROUND_BEGIN);
debug_assert!(i < SPONGE_WIDTH);
Self::START_FULL_ROUND_END + SPONGE_WIDTH * round + i
}
/// End of wire indices, exclusive.
fn end() -> usize {
Self::START_FULL_ROUND_END + SPONGE_WIDTH * FULL_ROUND_BEGIN
}
}
impl<F: RichField + Extendable<D> + Poseidon2, const D: usize> Gate<F, D> for Poseidon2Gate<F, D> {
    /// Gate identifier: the `Debug` rendering of the gate tagged with the sponge width.
    fn id(&self) -> String {
        format!("{:?}<WIDTH={}>", self, SPONGE_WIDTH)
    }

    /// The gate carries no per-instance data, so serialization writes nothing.
    fn serialize(
        &self,
        _dst: &mut Vec<u8>,
        _common_data: &CommonCircuitData<F, D>,
    ) -> IoResult<()> {
        Ok(())
    }

    /// Counterpart of `serialize`: reads nothing and returns a fresh gate.
    fn deserialize(_src: &mut Buffer, _common_data: &CommonCircuitData<F, D>) -> IoResult<Self> {
        Ok(Poseidon2Gate::new())
    }

    /// Evaluates every constraint of the gate over the extension field.
    ///
    /// The permutation is recomputed from the input wires. Whenever an S-box
    /// input has its own witness wire, the constraint `computed - wire` is
    /// emitted and the computation continues from the wire value. This keeps
    /// the degree of each constraint bounded by `degree()`.
    fn eval_unfiltered(&self, vars: EvaluationVars<F, D>) -> Vec<F::Extension> {
        let mut constraints = Vec::with_capacity(self.num_constraints());

        // Assert that `swap` is binary.
        let swap = vars.local_wires[Self::WIRE_SWAP];
        constraints.push(swap * (swap - F::Extension::ONE));

        // Assert that each delta wire is set properly: `delta_i = swap * (rhs - lhs)`.
        for i in 0..4 {
            let input_lhs = vars.local_wires[Self::wire_input(i)];
            let input_rhs = vars.local_wires[Self::wire_input(i + 4)];
            let delta_i = vars.local_wires[Self::wire_delta(i)];
            constraints.push(swap * (input_rhs - input_lhs) - delta_i);
        }

        // Compute the possibly-swapped input layer.
        let mut state = [F::Extension::ZERO; SPONGE_WIDTH];
        for i in 0..4 {
            let delta_i = vars.local_wires[Self::wire_delta(i)];
            let input_lhs = Self::wire_input(i);
            let input_rhs = Self::wire_input(i + 4);
            state[i] = vars.local_wires[input_lhs] + delta_i;
            state[i + 4] = vars.local_wires[input_rhs] - delta_i;
        }
        for i in 8..SPONGE_WIDTH {
            state[i] = vars.local_wires[Self::wire_input(i)];
        }

        // Initial linear layer.
        <F as Poseidon2>::matmul_external_field(&mut state);

        // First set of full (external) rounds; round 0 has no S-box wires.
        for r in 0..FULL_ROUND_BEGIN {
            <F as Poseidon2>::constant_layer_field(&mut state, r);
            if r != 0 {
                for i in 0..SPONGE_WIDTH {
                    let sbox_in = vars.local_wires[Self::wire_first_full_round(r, i)];
                    constraints.push(state[i] - sbox_in);
                    state[i] = sbox_in;
                }
            }
            <F as Poseidon2>::sbox_layer_field(&mut state);
            <F as Poseidon2>::matmul_external_field(&mut state);
        }

        // Partial (internal) rounds: only state[0] goes through the S-box.
        for r in 0..PARTIAL_ROUNDS {
            state[0] += F::Extension::from_canonical_u64(<F as Poseidon2>::RC12_MID[r]);
            let sbox_in = vars.local_wires[Self::wire_partial_round(r)];
            constraints.push(state[0] - sbox_in);
            state[0] = <F as Poseidon2>::sbox_p(sbox_in);
            <F as Poseidon2>::matmul_internal_field(&mut state, &<F as Poseidon2>::MAT_DIAG12_M_1);
        }

        // Second set of full (external) rounds.
        for r in FULL_ROUND_BEGIN..FULL_ROUND_END {
            <F as Poseidon2>::constant_layer_field(&mut state, r);
            for i in 0..SPONGE_WIDTH {
                let sbox_in = vars.local_wires[Self::wire_second_full_round(r - FULL_ROUND_BEGIN, i)];
                constraints.push(state[i] - sbox_in);
                state[i] = sbox_in;
            }
            <F as Poseidon2>::sbox_layer_field(&mut state);
            <F as Poseidon2>::matmul_external_field(&mut state);
        }

        // The final state must equal the output wires (SPONGE_WIDTH constraints).
        for i in 0..SPONGE_WIDTH {
            constraints.push(state[i] - vars.local_wires[Self::wire_output(i)]);
        }

        constraints
    }

    /// Base-field version of `eval_unfiltered`. It yields the same
    /// constraints in the same order through `yield_constr`.
    fn eval_unfiltered_base_one(
        &self,
        vars: EvaluationVarsBase<F>,
        mut yield_constr: StridedConstraintConsumer<F>,
    ) {
        // Assert that `swap` is binary.
        let swap = vars.local_wires[Self::WIRE_SWAP];
        yield_constr.one(swap * swap.sub_one());

        // Assert that each delta wire is set properly: `delta_i = swap * (rhs - lhs)`.
        for i in 0..4 {
            let input_lhs = vars.local_wires[Self::wire_input(i)];
            let input_rhs = vars.local_wires[Self::wire_input(i + 4)];
            let delta_i = vars.local_wires[Self::wire_delta(i)];
            yield_constr.one(swap * (input_rhs - input_lhs) - delta_i);
        }

        // Compute the possibly-swapped input layer.
        let mut state = [F::ZERO; SPONGE_WIDTH];
        for i in 0..4 {
            let delta_i = vars.local_wires[Self::wire_delta(i)];
            let input_lhs = Self::wire_input(i);
            let input_rhs = Self::wire_input(i + 4);
            state[i] = vars.local_wires[input_lhs] + delta_i;
            state[i + 4] = vars.local_wires[input_rhs] - delta_i;
        }
        for i in 8..SPONGE_WIDTH {
            state[i] = vars.local_wires[Self::wire_input(i)];
        }

        // Initial linear layer.
        <F as Poseidon2>::matmul_external(&mut state);

        // First set of full (external) rounds; round 0 has no S-box wires.
        for r in 0..FULL_ROUND_BEGIN {
            <F as Poseidon2>::constant_layer(&mut state, r);
            if r != 0 {
                for i in 0..SPONGE_WIDTH {
                    let sbox_in = vars.local_wires[Self::wire_first_full_round(r, i)];
                    yield_constr.one(state[i] - sbox_in);
                    state[i] = sbox_in;
                }
            }
            <F as Poseidon2>::sbox_layer(&mut state);
            <F as Poseidon2>::matmul_external(&mut state);
        }

        // Partial (internal) rounds: only state[0] goes through the S-box.
        for r in 0..PARTIAL_ROUNDS {
            state[0] += F::from_canonical_u64(<F as Poseidon2>::RC12_MID[r]);
            let sbox_in = vars.local_wires[Self::wire_partial_round(r)];
            yield_constr.one(state[0] - sbox_in);
            state[0] = <F as Poseidon2>::sbox_p(sbox_in);
            <F as Poseidon2>::matmul_internal(&mut state, &<F as Poseidon2>::MAT_DIAG12_M_1);
        }

        // Second set of full (external) rounds.
        for r in FULL_ROUND_BEGIN..FULL_ROUND_END {
            <F as Poseidon2>::constant_layer(&mut state, r);
            for i in 0..SPONGE_WIDTH {
                let sbox_in = vars.local_wires[Self::wire_second_full_round(r - FULL_ROUND_BEGIN, i)];
                yield_constr.one(state[i] - sbox_in);
                state[i] = sbox_in;
            }
            <F as Poseidon2>::sbox_layer(&mut state);
            <F as Poseidon2>::matmul_external(&mut state);
        }

        // The final state must equal the output wires.
        for i in 0..SPONGE_WIDTH {
            yield_constr.one(state[i] - vars.local_wires[Self::wire_output(i)]);
        }
    }

    /// In-circuit (recursive) version of `eval_unfiltered`. It builds targets
    /// for the same constraints, in the same order.
    fn eval_unfiltered_circuit(
        &self,
        builder: &mut CircuitBuilder<F, D>,
        vars: EvaluationTargets<D>,
    ) -> Vec<ExtensionTarget<D>> {
        let mut constraints = Vec::with_capacity(self.num_constraints());

        // Assert that `swap` is binary: swap * swap - swap == swap * (swap - 1).
        let swap = vars.local_wires[Self::WIRE_SWAP];
        constraints.push(builder.mul_sub_extension(swap, swap, swap));

        // Assert that each delta wire is set properly: `delta_i = swap * (rhs - lhs)`.
        for i in 0..4 {
            let input_lhs = vars.local_wires[Self::wire_input(i)];
            let input_rhs = vars.local_wires[Self::wire_input(i + 4)];
            let delta_i = vars.local_wires[Self::wire_delta(i)];
            let diff = builder.sub_extension(input_rhs, input_lhs);
            constraints.push(builder.mul_sub_extension(swap, diff, delta_i));
        }

        // Compute the possibly-swapped input layer.
        let mut state = [builder.zero_extension(); SPONGE_WIDTH];
        for i in 0..4 {
            let delta_i = vars.local_wires[Self::wire_delta(i)];
            let input_lhs = vars.local_wires[Self::wire_input(i)];
            let input_rhs = vars.local_wires[Self::wire_input(i + 4)];
            state[i] = builder.add_extension(input_lhs, delta_i);
            state[i + 4] = builder.sub_extension(input_rhs, delta_i);
        }
        for i in 8..SPONGE_WIDTH {
            state[i] = vars.local_wires[Self::wire_input(i)];
        }

        // Initial linear layer.
        state = <F as Poseidon2>::matmul_external_circuit(builder, &mut state);

        // First set of full (external) rounds; round 0 has no S-box wires.
        for r in 0..FULL_ROUND_BEGIN {
            <F as Poseidon2>::constant_layer_circuit(builder, &mut state, r);
            if r != 0 {
                for i in 0..SPONGE_WIDTH {
                    let sbox_in = vars.local_wires[Self::wire_first_full_round(r, i)];
                    constraints.push(builder.sub_extension(state[i], sbox_in));
                    state[i] = sbox_in;
                }
            }
            <F as Poseidon2>::sbox_layer_circuit(builder, &mut state);
            state = <F as Poseidon2>::matmul_external_circuit(builder, &mut state);
        }

        // Partial (internal) rounds: only state[0] goes through the S-box.
        for r in 0..PARTIAL_ROUNDS {
            let round_constant = F::Extension::from_canonical_u64(<F as Poseidon2>::RC12_MID[r]);
            let round_constant = builder.constant_extension(round_constant);
            state[0] = builder.add_extension(state[0], round_constant);
            let sbox_in = vars.local_wires[Self::wire_partial_round(r)];
            constraints.push(builder.sub_extension(state[0], sbox_in));
            state[0] = <F as Poseidon2>::sbox_p_circuit(builder, sbox_in);
            <F as Poseidon2>::matmul_internal_circuit(builder, &mut state);
        }

        // Second set of full (external) rounds.
        for r in FULL_ROUND_BEGIN..FULL_ROUND_END {
            <F as Poseidon2>::constant_layer_circuit(builder, &mut state, r);
            for i in 0..SPONGE_WIDTH {
                let sbox_in = vars.local_wires[Self::wire_second_full_round(r - FULL_ROUND_BEGIN, i)];
                constraints.push(builder.sub_extension(state[i], sbox_in));
                state[i] = sbox_in;
            }
            <F as Poseidon2>::sbox_layer_circuit(builder, &mut state);
            state = <F as Poseidon2>::matmul_external_circuit(builder, &mut state);
        }

        // The final state must equal the output wires.
        for i in 0..SPONGE_WIDTH {
            constraints
                .push(builder.sub_extension(state[i], vars.local_wires[Self::wire_output(i)]));
        }

        constraints
    }

    /// One `Poseidon2Generator` per gate instance fills in all non-input wires.
    fn generators(&self, row: usize, _local_constants: &[F]) -> Vec<WitnessGeneratorRef<F, D>> {
        // Not named `gen`: that identifier is a reserved keyword in the Rust 2024 edition.
        let generator = Poseidon2Generator::<F, D> {
            row,
            _phantom: PhantomData,
        };
        vec![WitnessGeneratorRef::new(generator.adapter())]
    }

    fn num_wires(&self) -> usize {
        Self::end()
    }

    fn num_constants(&self) -> usize {
        0
    }

    /// Constraint degree of the gate. Each constraint spans at most one S-box
    /// application. The S-box is presumably x^7 (Poseidon2 over Goldilocks);
    /// confirm against `Poseidon2::sbox_p`.
    fn degree(&self) -> usize {
        7
    }

    /// Total number of constraints emitted by the `eval_*` functions.
    fn num_constraints(&self) -> usize {
        // `swap` is boolean.
        let swap_constraints = 1;
        // One per delta wire.
        let delta_constraints = 4;
        // Every full round except the very first stores all SPONGE_WIDTH S-box inputs.
        let full_round_constraints = SPONGE_WIDTH * (FULL_ROUND_END - 1);
        // One S-box input per partial round.
        let partial_round_constraints = PARTIAL_ROUNDS;
        // The final state must match the output wires.
        let output_constraints = SPONGE_WIDTH;

        swap_constraints
            + delta_constraints
            + full_round_constraints
            + partial_round_constraints
            + output_constraints
    }
}
/// Witness generator for a single `Poseidon2Gate` instance.
///
/// It reads the input wires and the swap flag of its row, then writes every
/// other wire of the gate: deltas, stored S-box inputs and outputs.
#[derive(Default, Debug)]
pub struct Poseidon2Generator<F, const D: usize>
where
    F: RichField + Extendable<D> + Poseidon2,
{
    /// Row of the gate instance this generator fills in.
    row: usize,
    _phantom: PhantomData<F>,
}
impl<F: RichField + Extendable<D> + Poseidon2, const D: usize> SimpleGenerator<F, D>
    for Poseidon2Generator<F, D>
{
    fn id(&self) -> String {
        "Poseidon2Generator".to_string()
    }

    /// The generator runs once all input wires and the swap flag of its row are known.
    fn dependencies(&self) -> Vec<Target> {
        (0..SPONGE_WIDTH)
            .map(Poseidon2Gate::<F, D>::wire_input)
            .chain(Some(Poseidon2Gate::<F, D>::WIRE_SWAP))
            .map(|column| Target::wire(self.row, column))
            .collect()
    }

    /// Computes the permutation natively and writes every derived wire.
    ///
    /// The wire order and values mirror the constraints in `Poseidon2Gate`:
    /// deltas, the S-box inputs of full rounds 1..FULL_ROUND_BEGIN, the
    /// partial-round S-box inputs, the second-half full-round S-box inputs,
    /// and finally the outputs.
    ///
    /// # Errors
    /// Propagates any error returned by `set_wire`.
    fn run_once(&self, witness: &PartitionWitness<F>, out_buffer: &mut GeneratedValues<F>) -> anyhow::Result<()> {
        let local_wire = |column| Wire {
            row: self.row,
            column,
        };

        // Read the inputs straight into a fixed-size state; this avoids an
        // intermediate Vec and a `try_into().unwrap()`.
        let mut state: [F; SPONGE_WIDTH] = core::array::from_fn(|i| {
            witness.get_wire(local_wire(Poseidon2Gate::<F, D>::wire_input(i)))
        });

        let swap_value = witness.get_wire(local_wire(Poseidon2Gate::<F, D>::WIRE_SWAP));
        debug_assert!(swap_value == F::ZERO || swap_value == F::ONE);

        // delta_i = swap * (input[i + 4] - input[i]), computed before swapping.
        for i in 0..4 {
            let delta_i = swap_value * (state[i + 4] - state[i]);
            out_buffer.set_wire(local_wire(Poseidon2Gate::<F, D>::wire_delta(i)), delta_i)?;
        }
        if swap_value == F::ONE {
            for i in 0..4 {
                state.swap(i, 4 + i);
            }
        }

        // Initial linear layer.
        <F as Poseidon2>::matmul_external_field(&mut state);

        // First set of full (external) rounds; round 0 S-box inputs are not stored.
        for r in 0..FULL_ROUND_BEGIN {
            <F as Poseidon2>::constant_layer_field(&mut state, r);
            if r != 0 {
                for i in 0..SPONGE_WIDTH {
                    out_buffer.set_wire(
                        local_wire(Poseidon2Gate::<F, D>::wire_first_full_round(r, i)),
                        state[i],
                    )?;
                }
            }
            <F as Poseidon2>::sbox_layer_field(&mut state);
            <F as Poseidon2>::matmul_external_field(&mut state);
        }

        // Partial (internal) rounds: only state[0] goes through the S-box.
        for r in 0..PARTIAL_ROUNDS {
            state[0] += F::from_canonical_u64(<F as Poseidon2>::RC12_MID[r]);
            out_buffer.set_wire(
                local_wire(Poseidon2Gate::<F, D>::wire_partial_round(r)),
                state[0],
            )?;
            state[0] = <F as Poseidon2>::sbox_p(state[0]);
            <F as Poseidon2>::matmul_internal_field(&mut state, &<F as Poseidon2>::MAT_DIAG12_M_1);
        }

        // Second set of full (external) rounds.
        for r in FULL_ROUND_BEGIN..FULL_ROUND_END {
            <F as Poseidon2>::constant_layer_field(&mut state, r);
            for i in 0..SPONGE_WIDTH {
                out_buffer.set_wire(
                    local_wire(Poseidon2Gate::<F, D>::wire_second_full_round(
                        r - FULL_ROUND_BEGIN,
                        i,
                    )),
                    state[i],
                )?;
            }
            <F as Poseidon2>::sbox_layer_field(&mut state);
            <F as Poseidon2>::matmul_external_field(&mut state);
        }

        for (i, &value) in state.iter().enumerate() {
            out_buffer.set_wire(local_wire(Poseidon2Gate::<F, D>::wire_output(i)), value)?;
        }
        Ok(())
    }

    /// Only the row needs to be persisted.
    fn serialize(&self, dst: &mut Vec<u8>, _common_data: &CommonCircuitData<F, D>) -> IoResult<()> {
        dst.write_usize(self.row)
    }

    fn deserialize(src: &mut Buffer, _common_data: &CommonCircuitData<F, D>) -> IoResult<Self> {
        let row = src.read_usize()?;
        Ok(Self {
            row,
            _phantom: PhantomData,
        })
    }
}
//------------------------------------- Tests -----------------------------------------
#[cfg(test)]
mod tests {
    use anyhow::Result;
    use plonky2_field::goldilocks_field::GoldilocksField;
    use plonky2_field::types::Field;
    use plonky2::gates::gate_testing::{test_eval_fns, test_low_degree};
    use crate::gate::poseidon2::Poseidon2Gate;
    use crate::poseidon2_hash::poseidon2::{Poseidon2, SPONGE_WIDTH};
    use plonky2::iop::generator::generate_partial_witness;
    use plonky2::iop::wire::Wire;
    use plonky2::iop::witness::{PartialWitness, Witness, WitnessWrite};
    use plonky2::plonk::circuit_builder::CircuitBuilder;
    use plonky2::plonk::circuit_data::CircuitConfig;
    use plonky2::plonk::config::GenericConfig;
    use crate::config::Poseidon2GoldilocksConfig;

    /// Pins the fixed part of the wire layout: inputs, outputs, swap flag and deltas.
    #[test]
    fn wire_indices() {
        type F = GoldilocksField;
        type Gate = Poseidon2Gate<F, 4>;
        assert_eq!(Gate::wire_input(0), 0);
        assert_eq!(Gate::wire_input(11), 11);
        assert_eq!(Gate::wire_output(0), 12);
        assert_eq!(Gate::wire_output(11), 23);
        assert_eq!(Gate::WIRE_SWAP, 24);
        assert_eq!(Gate::wire_delta(0), 25);
        assert_eq!(Gate::wire_delta(3), 28);
    }

    /// Runs the gate's witness generator on inputs `0..SPONGE_WIDTH` (no swap)
    /// and checks the output wires against the native `poseidon2` permutation.
    #[test]
    fn generated_output() -> Result<()> {
        const D: usize = 2;
        type C = Poseidon2GoldilocksConfig;
        type F = <C as GenericConfig<D>>::F;
        type Gate = Poseidon2Gate<F, D>;

        let config = CircuitConfig {
            num_wires: 143,
            ..CircuitConfig::standard_recursion_config()
        };
        let mut builder = CircuitBuilder::new(config);
        let row = builder.add_gate(Gate::new(), vec![]);
        let circuit = builder.build_prover::<C>();

        let permutation_inputs = (0..SPONGE_WIDTH)
            .map(F::from_canonical_usize)
            .collect::<Vec<_>>();

        let mut inputs = PartialWitness::new();
        inputs.set_wire(
            Wire {
                row,
                column: Gate::WIRE_SWAP,
            },
            F::ZERO,
        )?;
        for (i, &input) in permutation_inputs.iter().enumerate() {
            inputs.set_wire(
                Wire {
                    row,
                    column: Gate::wire_input(i),
                },
                input,
            )?;
        }

        let witness = generate_partial_witness(inputs, &circuit.prover_only, &circuit.common).unwrap();
        let expected_outputs: [F; SPONGE_WIDTH] = F::poseidon2(permutation_inputs.try_into().unwrap());
        for (i, &expected) in expected_outputs.iter().enumerate() {
            // Read from the row the gate was actually placed in, not a hard-coded 0.
            let out = witness.get_wire(Wire {
                row,
                column: Gate::wire_output(i),
            });
            assert_eq!(out, expected, "output wire {i} mismatch");
        }
        Ok(())
    }

    #[test]
    fn low_degree() {
        type F = GoldilocksField;
        let gate = Poseidon2Gate::<F, 4>::new();
        test_low_degree(gate)
    }

    #[test]
    fn eval_fns() -> Result<()> {
        const D: usize = 2;
        type C = Poseidon2GoldilocksConfig;
        type F = <C as GenericConfig<D>>::F;
        let gate = Poseidon2Gate::<F, D>::new();
        test_eval_fns::<F, C, _, D>(gate)
    }

    /// Evaluates the gate natively and in-circuit on random wire values, forces
    /// the in-circuit results to equal the native ones, then proves and
    /// verifies the circuit.
    #[test]
    fn test_proof() -> Result<()> {
        use plonky2_field::extension::Extendable;
        use plonky2_field::types::Sample;
        use plonky2::gates::gate::Gate;
        use plonky2::hash::hash_types::HashOut;
        use plonky2::plonk::vars::{EvaluationTargets, EvaluationVars};

        const D: usize = 2;
        type C = Poseidon2GoldilocksConfig;
        type F = <C as GenericConfig<D>>::F;
        type FF = <F as Extendable<D>>::Extension;

        let gate = Poseidon2Gate::<F, D>::new();
        let wires = FF::rand_vec(gate.num_wires());
        let constants = FF::rand_vec(gate.num_constants());
        let public_inputs_hash = HashOut::rand();

        let config = CircuitConfig::standard_recursion_config();
        let mut pw = PartialWitness::new();
        let mut builder = CircuitBuilder::<F, D>::new(config);

        let wires_t = builder.add_virtual_extension_targets(wires.len());
        let constants_t = builder.add_virtual_extension_targets(constants.len());
        pw.set_extension_targets(&wires_t, &wires)?;
        pw.set_extension_targets(&constants_t, &constants)?;
        let public_inputs_hash_t = builder.add_virtual_hash();
        pw.set_hash_target(public_inputs_hash_t, public_inputs_hash)?;

        let vars = EvaluationVars {
            local_constants: &constants,
            local_wires: &wires,
            public_inputs_hash: &public_inputs_hash,
        };
        let evals = gate.eval_unfiltered(vars);

        let vars_t = EvaluationTargets {
            local_constants: &constants_t,
            local_wires: &wires_t,
            public_inputs_hash: &public_inputs_hash_t,
        };
        let evals_t = gate.eval_unfiltered_circuit(&mut builder, vars_t);
        pw.set_extension_targets(&evals_t, &evals)?;

        let data = builder.build::<C>();
        // Propagate the prover error (instead of only checking `is_ok`) and verify the proof.
        let proof = data.prove(pw)?;
        data.verify(proof)
    }
}