Merge pull request #250 from mir-protocol/poseidon_gate

Poseidon gate and global move to Poseidon
This commit is contained in:
wborgeaud 2021-09-18 18:33:37 +02:00 committed by GitHub
commit 92f5d39671
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 894 additions and 33 deletions

View File

@ -3,8 +3,10 @@ use std::convert::TryInto;
use crate::field::extension_field::Extendable;
use crate::field::field_types::RichField;
use crate::gates::gmimc::GMiMCGate;
use crate::gates::poseidon::PoseidonGate;
use crate::hash::gmimc::GMiMC;
use crate::hash::hashing::{HashFamily, HASH_FAMILY};
use crate::hash::poseidon::Poseidon;
use crate::iop::target::{BoolTarget, Target};
use crate::iop::wire::Wire;
use crate::plonk::circuit_builder::CircuitBuilder;
@ -13,7 +15,8 @@ use crate::plonk::circuit_builder::CircuitBuilder;
impl<F: RichField + Extendable<D>, const D: usize> CircuitBuilder<F, D> {
pub fn permute<const W: usize>(&mut self, inputs: [Target; W]) -> [Target; W]
where
F: GMiMC<W>,
F: GMiMC<W> + Poseidon<W>,
[(); W - 1]: ,
{
// We don't want to swap any inputs, so set that wire to 0.
let _false = self._false();
@ -28,11 +31,12 @@ impl<F: RichField + Extendable<D>, const D: usize> CircuitBuilder<F, D> {
swap: BoolTarget,
) -> [Target; W]
where
F: GMiMC<W>,
F: GMiMC<W> + Poseidon<W>,
[(); W - 1]: ,
{
match HASH_FAMILY {
HashFamily::GMiMC => self.gmimc_permute_swapped(inputs, swap),
HashFamily::Poseidon => todo!(),
HashFamily::Poseidon => self.poseidon_permute_swapped(inputs, swap),
}
}
@ -79,4 +83,38 @@ impl<F: RichField + Extendable<D>, const D: usize> CircuitBuilder<F, D> {
.try_into()
.unwrap()
}
/// Conditionally swap two chunks of the inputs (useful in verifying Merkle proofs), then apply
/// the Poseidon permutation.
pub(crate) fn poseidon_permute_swapped<const W: usize>(
&mut self,
inputs: [Target; W],
swap: BoolTarget,
) -> [Target; W]
where
F: Poseidon<W>,
[(); W - 1]: ,
{
let gate_type = PoseidonGate::<F, D, W>::new();
let gate = self.add_gate(gate_type, vec![]);
// We don't want to swap any inputs, so set that wire to 0.
let swap_wire = PoseidonGate::<F, D, W>::WIRE_SWAP;
let swap_wire = Target::wire(gate, swap_wire);
self.connect(swap.target, swap_wire);
// Route input wires.
for i in 0..W {
let in_wire = PoseidonGate::<F, D, W>::wire_input(i);
let in_wire = Target::wire(gate, in_wire);
self.connect(inputs[i], in_wire);
}
// Collect output wires.
(0..W)
.map(|i| Target::wire(gate, PoseidonGate::<F, D, W>::wire_output(i)))
.collect::<Vec<_>>()
.try_into()
.unwrap()
}
}

View File

@ -13,13 +13,11 @@ use crate::iop::witness::{PartitionWitness, Witness};
use crate::plonk::circuit_builder::CircuitBuilder;
use crate::plonk::vars::{EvaluationTargets, EvaluationVars, EvaluationVarsBase};
/// Evaluates a full GMiMC permutation with 12 state elements, and writes the output to the next
/// gate's first `width` wires (which could be the input of another `GMiMCGate`).
/// Evaluates a full GMiMC permutation with 12 state elements.
///
/// This also has some extra features to make it suitable for efficiently verifying Merkle proofs.
/// It has a flag which can be used to swap the first four inputs with the next four, for ordering
/// sibling digests. It also has an accumulator that computes the weighted sum of these flags, for
/// computing the index of the leaf based on these swap bits.
/// sibling digests.
#[derive(Debug)]
pub struct GMiMCGate<
F: RichField + Extendable<D> + GMiMC<WIDTH>,

View File

@ -12,6 +12,7 @@ pub mod gmimc;
pub mod insertion;
pub mod interpolation;
pub mod noop;
pub mod poseidon;
pub(crate) mod public_input;
pub mod random_access;
pub mod reducing;

570
src/gates/poseidon.rs Normal file
View File

@ -0,0 +1,570 @@
use std::convert::TryInto;
use std::marker::PhantomData;
use crate::field::extension_field::target::ExtensionTarget;
use crate::field::extension_field::Extendable;
use crate::field::field_types::{Field, RichField};
use crate::gates::gate::Gate;
use crate::hash::poseidon;
use crate::hash::poseidon::Poseidon;
use crate::iop::generator::{GeneratedValues, SimpleGenerator, WitnessGenerator};
use crate::iop::target::Target;
use crate::iop::wire::Wire;
use crate::iop::witness::{PartitionWitness, Witness};
use crate::plonk::circuit_builder::CircuitBuilder;
use crate::plonk::vars::{EvaluationTargets, EvaluationVars, EvaluationVarsBase};
/// Evaluates a full Poseidon permutation with 12 state elements.
///
/// This also has some extra features to make it suitable for efficiently verifying Merkle proofs.
/// It has a flag which can be used to swap the first four inputs with the next four, for ordering
/// sibling digests.
#[derive(Debug)]
pub struct PoseidonGate<
F: RichField + Extendable<D> + Poseidon<WIDTH>,
const D: usize,
const WIDTH: usize,
> where
[(); WIDTH - 1]: ,
{
_phantom: PhantomData<F>,
}
impl<F: RichField + Extendable<D> + Poseidon<WIDTH>, const D: usize, const WIDTH: usize>
PoseidonGate<F, D, WIDTH>
where
[(); WIDTH - 1]: ,
{
pub fn new() -> Self {
PoseidonGate {
_phantom: PhantomData,
}
}
/// The wire index for the `i`th input to the permutation.
pub fn wire_input(i: usize) -> usize {
i
}
/// The wire index for the `i`th output to the permutation.
pub fn wire_output(i: usize) -> usize {
WIDTH + i
}
/// If this is set to 1, the first four inputs will be swapped with the next four inputs. This
/// is useful for ordering hashes in Merkle proofs. Otherwise, this should be set to 0.
pub const WIRE_SWAP: usize = 2 * WIDTH;
/// A wire which stores the input of the `i`-th S-box of the `round`-th round of the first set
/// of full rounds.
fn wire_full_sbox_0(round: usize, i: usize) -> usize {
debug_assert!(round < poseidon::HALF_N_FULL_ROUNDS);
debug_assert!(i < WIDTH);
2 * WIDTH + 1 + WIDTH * round + i
}
/// A wire which stores the input of the S-box of the `round`-th round of the partial rounds.
fn wire_partial_sbox(round: usize) -> usize {
debug_assert!(round < poseidon::N_PARTIAL_ROUNDS);
2 * WIDTH + 1 + WIDTH * poseidon::HALF_N_FULL_ROUNDS + round
}
/// A wire which stores the input of the `i`-th S-box of the `round`-th round of the second set
/// of full rounds.
fn wire_full_sbox_1(round: usize, i: usize) -> usize {
debug_assert!(round < poseidon::HALF_N_FULL_ROUNDS);
debug_assert!(i < WIDTH);
2 * WIDTH
+ 1
+ WIDTH * (poseidon::HALF_N_FULL_ROUNDS + round)
+ poseidon::N_PARTIAL_ROUNDS
+ i
}
/// End of wire indices, exclusive.
fn end() -> usize {
2 * WIDTH + 1 + WIDTH * poseidon::N_FULL_ROUNDS_TOTAL + poseidon::N_PARTIAL_ROUNDS
}
}
impl<F: RichField + Extendable<D> + Poseidon<WIDTH>, const D: usize, const WIDTH: usize> Gate<F, D>
for PoseidonGate<F, D, WIDTH>
where
[(); WIDTH - 1]: ,
{
fn id(&self) -> String {
format!("<WIDTH={}> {:?}", WIDTH, self)
}
fn eval_unfiltered(&self, vars: EvaluationVars<F, D>) -> Vec<F::Extension> {
let mut constraints = Vec::with_capacity(self.num_constraints());
// Assert that `swap` is binary.
let swap = vars.local_wires[Self::WIRE_SWAP];
constraints.push(swap * (swap - F::Extension::ONE));
let mut state = Vec::with_capacity(12);
for i in 0..4 {
let a = vars.local_wires[i];
let b = vars.local_wires[i + 4];
state.push(a + swap * (b - a));
}
for i in 0..4 {
let a = vars.local_wires[i + 4];
let b = vars.local_wires[i];
state.push(a + swap * (b - a));
}
for i in 8..12 {
state.push(vars.local_wires[i]);
}
let mut state: [F::Extension; WIDTH] = state.try_into().unwrap();
let mut round_ctr = 0;
// First set of full rounds.
for r in 0..poseidon::HALF_N_FULL_ROUNDS {
<F as Poseidon<WIDTH>>::constant_layer_field(&mut state, round_ctr);
for i in 0..WIDTH {
let sbox_in = vars.local_wires[Self::wire_full_sbox_0(r, i)];
constraints.push(state[i] - sbox_in);
state[i] = sbox_in;
}
<F as Poseidon<WIDTH>>::sbox_layer_field(&mut state);
state = <F as Poseidon<WIDTH>>::mds_layer_field(&state);
round_ctr += 1;
}
// Partial rounds.
<F as Poseidon<WIDTH>>::partial_first_constant_layer(&mut state);
state = <F as Poseidon<WIDTH>>::mds_partial_layer_init(&mut state);
for r in 0..(poseidon::N_PARTIAL_ROUNDS - 1) {
let sbox_in = vars.local_wires[Self::wire_partial_sbox(r)];
constraints.push(state[0] - sbox_in);
state[0] = <F as Poseidon<WIDTH>>::sbox_monomial(sbox_in);
state[0] += F::Extension::from_canonical_u64(
<F as Poseidon<WIDTH>>::FAST_PARTIAL_ROUND_CONSTANTS[r],
);
state = <F as Poseidon<WIDTH>>::mds_partial_layer_fast_field(&state, r);
}
let sbox_in = vars.local_wires[Self::wire_partial_sbox(poseidon::N_PARTIAL_ROUNDS - 1)];
constraints.push(state[0] - sbox_in);
state[0] = <F as Poseidon<WIDTH>>::sbox_monomial(sbox_in);
state = <F as Poseidon<WIDTH>>::mds_partial_layer_fast_field(
&state,
poseidon::N_PARTIAL_ROUNDS - 1,
);
round_ctr += poseidon::N_PARTIAL_ROUNDS;
// Second set of full rounds.
for r in 0..poseidon::HALF_N_FULL_ROUNDS {
<F as Poseidon<WIDTH>>::constant_layer_field(&mut state, round_ctr);
for i in 0..WIDTH {
let sbox_in = vars.local_wires[Self::wire_full_sbox_1(r, i)];
constraints.push(state[i] - sbox_in);
state[i] = sbox_in;
}
<F as Poseidon<WIDTH>>::sbox_layer_field(&mut state);
state = <F as Poseidon<WIDTH>>::mds_layer_field(&state);
round_ctr += 1;
}
for i in 0..WIDTH {
constraints.push(state[i] - vars.local_wires[Self::wire_output(i)]);
}
constraints
}
fn eval_unfiltered_base(&self, vars: EvaluationVarsBase<F>) -> Vec<F> {
let mut constraints = Vec::with_capacity(self.num_constraints());
// Assert that `swap` is binary.
let swap = vars.local_wires[Self::WIRE_SWAP];
constraints.push(swap * (swap - F::ONE));
let mut state = Vec::with_capacity(12);
for i in 0..4 {
let a = vars.local_wires[i];
let b = vars.local_wires[i + 4];
state.push(a + swap * (b - a));
}
for i in 0..4 {
let a = vars.local_wires[i + 4];
let b = vars.local_wires[i];
state.push(a + swap * (b - a));
}
for i in 8..12 {
state.push(vars.local_wires[i]);
}
let mut state: [F; WIDTH] = state.try_into().unwrap();
let mut round_ctr = 0;
// First set of full rounds.
for r in 0..poseidon::HALF_N_FULL_ROUNDS {
<F as Poseidon<WIDTH>>::constant_layer_field(&mut state, round_ctr);
for i in 0..WIDTH {
let sbox_in = vars.local_wires[Self::wire_full_sbox_0(r, i)];
constraints.push(state[i] - sbox_in);
state[i] = sbox_in;
}
<F as Poseidon<WIDTH>>::sbox_layer_field(&mut state);
state = <F as Poseidon<WIDTH>>::mds_layer_field(&state);
round_ctr += 1;
}
// Partial rounds.
<F as Poseidon<WIDTH>>::partial_first_constant_layer(&mut state);
state = <F as Poseidon<WIDTH>>::mds_partial_layer_init(&mut state);
for r in 0..(poseidon::N_PARTIAL_ROUNDS - 1) {
let sbox_in = vars.local_wires[Self::wire_partial_sbox(r)];
constraints.push(state[0] - sbox_in);
state[0] = <F as Poseidon<WIDTH>>::sbox_monomial(sbox_in);
state[0] +=
F::from_canonical_u64(<F as Poseidon<WIDTH>>::FAST_PARTIAL_ROUND_CONSTANTS[r]);
state = <F as Poseidon<WIDTH>>::mds_partial_layer_fast_field(&state, r);
}
let sbox_in = vars.local_wires[Self::wire_partial_sbox(poseidon::N_PARTIAL_ROUNDS - 1)];
constraints.push(state[0] - sbox_in);
state[0] = <F as Poseidon<WIDTH>>::sbox_monomial(sbox_in);
state = <F as Poseidon<WIDTH>>::mds_partial_layer_fast_field(
&state,
poseidon::N_PARTIAL_ROUNDS - 1,
);
round_ctr += poseidon::N_PARTIAL_ROUNDS;
// Second set of full rounds.
for r in 0..poseidon::HALF_N_FULL_ROUNDS {
<F as Poseidon<WIDTH>>::constant_layer_field(&mut state, round_ctr);
for i in 0..WIDTH {
let sbox_in = vars.local_wires[Self::wire_full_sbox_1(r, i)];
constraints.push(state[i] - sbox_in);
state[i] = sbox_in;
}
<F as Poseidon<WIDTH>>::sbox_layer_field(&mut state);
state = <F as Poseidon<WIDTH>>::mds_layer_field(&state);
round_ctr += 1;
}
for i in 0..WIDTH {
constraints.push(state[i] - vars.local_wires[Self::wire_output(i)]);
}
constraints
}
fn eval_unfiltered_recursively(
&self,
builder: &mut CircuitBuilder<F, D>,
vars: EvaluationTargets<D>,
) -> Vec<ExtensionTarget<D>> {
let one = builder.one_extension();
let mut constraints = Vec::with_capacity(self.num_constraints());
// Assert that `swap` is binary.
let swap = vars.local_wires[Self::WIRE_SWAP];
constraints.push(builder.mul_sub_extension(swap, swap, swap));
let mut state = Vec::with_capacity(12);
for i in 0..4 {
let a = vars.local_wires[i];
let b = vars.local_wires[i + 4];
let delta = builder.sub_extension(b, a);
state.push(builder.mul_add_extension(swap, delta, a));
}
for i in 0..4 {
let a = vars.local_wires[i + 4];
let b = vars.local_wires[i];
let delta = builder.sub_extension(b, a);
state.push(builder.mul_add_extension(swap, delta, a));
}
for i in 8..12 {
state.push(vars.local_wires[i]);
}
let mut state: [ExtensionTarget<D>; WIDTH] = state.try_into().unwrap();
let mut round_ctr = 0;
// First set of full rounds.
for r in 0..poseidon::HALF_N_FULL_ROUNDS {
<F as Poseidon<WIDTH>>::constant_layer_recursive(builder, &mut state, round_ctr);
for i in 0..WIDTH {
let sbox_in = vars.local_wires[Self::wire_full_sbox_0(r, i)];
constraints.push(builder.sub_extension(state[i], sbox_in));
state[i] = sbox_in;
}
<F as Poseidon<WIDTH>>::sbox_layer_recursive(builder, &mut state);
state = <F as Poseidon<WIDTH>>::mds_layer_recursive(builder, &state);
round_ctr += 1;
}
// Partial rounds.
<F as Poseidon<WIDTH>>::partial_first_constant_layer_recursive(builder, &mut state);
state = <F as Poseidon<WIDTH>>::mds_partial_layer_init_recursive(builder, &mut state);
for r in 0..(poseidon::N_PARTIAL_ROUNDS - 1) {
let sbox_in = vars.local_wires[Self::wire_partial_sbox(r)];
constraints.push(builder.sub_extension(state[0], sbox_in));
state[0] = <F as Poseidon<WIDTH>>::sbox_monomial_recursive(builder, sbox_in);
state[0] = builder.arithmetic_extension(
F::from_canonical_u64(<F as Poseidon<WIDTH>>::FAST_PARTIAL_ROUND_CONSTANTS[r]),
F::ONE,
one,
one,
state[0],
);
state = <F as Poseidon<WIDTH>>::mds_partial_layer_fast_recursive(builder, &state, r);
}
let sbox_in = vars.local_wires[Self::wire_partial_sbox(poseidon::N_PARTIAL_ROUNDS - 1)];
constraints.push(builder.sub_extension(state[0], sbox_in));
state[0] = <F as Poseidon<WIDTH>>::sbox_monomial_recursive(builder, sbox_in);
state = <F as Poseidon<WIDTH>>::mds_partial_layer_fast_recursive(
builder,
&state,
poseidon::N_PARTIAL_ROUNDS - 1,
);
round_ctr += poseidon::N_PARTIAL_ROUNDS;
// Second set of full rounds.
for r in 0..poseidon::HALF_N_FULL_ROUNDS {
<F as Poseidon<WIDTH>>::constant_layer_recursive(builder, &mut state, round_ctr);
for i in 0..WIDTH {
let sbox_in = vars.local_wires[Self::wire_full_sbox_1(r, i)];
constraints.push(builder.sub_extension(state[i], sbox_in));
state[i] = sbox_in;
}
<F as Poseidon<WIDTH>>::sbox_layer_recursive(builder, &mut state);
state = <F as Poseidon<WIDTH>>::mds_layer_recursive(builder, &state);
round_ctr += 1;
}
for i in 0..WIDTH {
constraints
.push(builder.sub_extension(state[i], vars.local_wires[Self::wire_output(i)]));
}
constraints
}
fn generators(
&self,
gate_index: usize,
_local_constants: &[F],
) -> Vec<Box<dyn WitnessGenerator<F>>> {
let gen = PoseidonGenerator::<F, D, WIDTH> {
gate_index,
_phantom: PhantomData,
};
vec![Box::new(gen.adapter())]
}
fn num_wires(&self) -> usize {
Self::end()
}
fn num_constants(&self) -> usize {
0
}
fn degree(&self) -> usize {
7
}
fn num_constraints(&self) -> usize {
WIDTH * poseidon::N_FULL_ROUNDS_TOTAL + poseidon::N_PARTIAL_ROUNDS + WIDTH + 1
}
}
#[derive(Debug)]
struct PoseidonGenerator<
F: RichField + Extendable<D> + Poseidon<WIDTH>,
const D: usize,
const WIDTH: usize,
> where
[(); WIDTH - 1]: ,
{
gate_index: usize,
_phantom: PhantomData<F>,
}
impl<F: RichField + Extendable<D> + Poseidon<WIDTH>, const D: usize, const WIDTH: usize>
SimpleGenerator<F> for PoseidonGenerator<F, D, WIDTH>
where
[(); WIDTH - 1]: ,
{
fn dependencies(&self) -> Vec<Target> {
(0..WIDTH)
.map(|i| PoseidonGate::<F, D, WIDTH>::wire_input(i))
.chain(Some(PoseidonGate::<F, D, WIDTH>::WIRE_SWAP))
.map(|input| Target::wire(self.gate_index, input))
.collect()
}
fn run_once(&self, witness: &PartitionWitness<F>, out_buffer: &mut GeneratedValues<F>) {
let local_wire = |input| Wire {
gate: self.gate_index,
input,
};
let mut state = (0..WIDTH)
.map(|i| {
witness.get_wire(Wire {
gate: self.gate_index,
input: PoseidonGate::<F, D, WIDTH>::wire_input(i),
})
})
.collect::<Vec<_>>();
let swap_value = witness.get_wire(Wire {
gate: self.gate_index,
input: PoseidonGate::<F, D, WIDTH>::WIRE_SWAP,
});
debug_assert!(swap_value == F::ZERO || swap_value == F::ONE);
if swap_value == F::ONE {
for i in 0..4 {
state.swap(i, 4 + i);
}
}
let mut state: [F; WIDTH] = state.try_into().unwrap();
let mut round_ctr = 0;
for r in 0..poseidon::HALF_N_FULL_ROUNDS {
<F as Poseidon<WIDTH>>::constant_layer_field(&mut state, round_ctr);
for i in 0..WIDTH {
out_buffer.set_wire(
local_wire(PoseidonGate::<F, D, WIDTH>::wire_full_sbox_0(r, i)),
state[i],
);
}
<F as Poseidon<WIDTH>>::sbox_layer_field(&mut state);
state = <F as Poseidon<WIDTH>>::mds_layer_field(&state);
round_ctr += 1;
}
<F as Poseidon<WIDTH>>::partial_first_constant_layer(&mut state);
state = <F as Poseidon<WIDTH>>::mds_partial_layer_init(&mut state);
for r in 0..(poseidon::N_PARTIAL_ROUNDS - 1) {
out_buffer.set_wire(
local_wire(PoseidonGate::<F, D, WIDTH>::wire_partial_sbox(r)),
state[0],
);
state[0] = <F as Poseidon<WIDTH>>::sbox_monomial(state[0]);
state[0] +=
F::from_canonical_u64(<F as Poseidon<WIDTH>>::FAST_PARTIAL_ROUND_CONSTANTS[r]);
state = <F as Poseidon<WIDTH>>::mds_partial_layer_fast_field(&state, r);
}
out_buffer.set_wire(
local_wire(PoseidonGate::<F, D, WIDTH>::wire_partial_sbox(
poseidon::N_PARTIAL_ROUNDS - 1,
)),
state[0],
);
state[0] = <F as Poseidon<WIDTH>>::sbox_monomial(state[0]);
state = <F as Poseidon<WIDTH>>::mds_partial_layer_fast_field(
&state,
poseidon::N_PARTIAL_ROUNDS - 1,
);
round_ctr += poseidon::N_PARTIAL_ROUNDS;
for r in 0..poseidon::HALF_N_FULL_ROUNDS {
<F as Poseidon<WIDTH>>::constant_layer_field(&mut state, round_ctr);
for i in 0..WIDTH {
out_buffer.set_wire(
local_wire(PoseidonGate::<F, D, WIDTH>::wire_full_sbox_1(r, i)),
state[i],
);
}
<F as Poseidon<WIDTH>>::sbox_layer_field(&mut state);
state = <F as Poseidon<WIDTH>>::mds_layer_field(&state);
round_ctr += 1;
}
for i in 0..WIDTH {
out_buffer.set_wire(
local_wire(PoseidonGate::<F, D, WIDTH>::wire_output(i)),
state[i],
);
}
}
}
#[cfg(test)]
mod tests {
use std::convert::TryInto;
use anyhow::Result;
use crate::field::crandall_field::CrandallField;
use crate::field::field_types::Field;
use crate::gates::gate_testing::{test_eval_fns, test_low_degree};
use crate::gates::poseidon::PoseidonGate;
use crate::hash::poseidon::Poseidon;
use crate::iop::generator::generate_partial_witness;
use crate::iop::wire::Wire;
use crate::iop::witness::{PartialWitness, Witness};
use crate::plonk::circuit_builder::CircuitBuilder;
use crate::plonk::circuit_data::CircuitConfig;
#[test]
fn generated_output() {
type F = CrandallField;
const WIDTH: usize = 12;
let config = CircuitConfig {
num_wires: 143,
..CircuitConfig::large_config()
};
let mut builder = CircuitBuilder::new(config);
type Gate = PoseidonGate<F, 4, WIDTH>;
let gate = Gate::new();
let gate_index = builder.add_gate(gate, vec![]);
let circuit = builder.build_prover();
let permutation_inputs = (0..WIDTH).map(F::from_canonical_usize).collect::<Vec<_>>();
let mut inputs = PartialWitness::new();
inputs.set_wire(
Wire {
gate: gate_index,
input: Gate::WIRE_SWAP,
},
F::ZERO,
);
for i in 0..WIDTH {
inputs.set_wire(
Wire {
gate: gate_index,
input: Gate::wire_input(i),
},
permutation_inputs[i],
);
}
let witness = generate_partial_witness(inputs, &circuit.prover_only);
let expected_outputs: [F; WIDTH] = F::poseidon(permutation_inputs.try_into().unwrap());
for i in 0..WIDTH {
let out = witness.get_wire(Wire {
gate: 0,
input: Gate::wire_output(i),
});
assert_eq!(out, expected_outputs[i]);
}
}
#[test]
fn low_degree() {
type F = CrandallField;
const WIDTH: usize = 12;
let gate = PoseidonGate::<F, 4, WIDTH>::new();
test_low_degree(gate)
}
#[test]
fn eval_fns() -> Result<()> {
type F = CrandallField;
const WIDTH: usize = 12;
let gate = PoseidonGate::<F, 4, WIDTH>::new();
test_eval_fns(gate)
}
}

View File

@ -2,6 +2,7 @@
use crate::field::extension_field::Extendable;
use crate::field::field_types::RichField;
use crate::gates::poseidon::PoseidonGate;
use crate::hash::hash_types::{HashOut, HashOutTarget};
use crate::iop::target::Target;
use crate::plonk::circuit_builder::CircuitBuilder;
@ -10,7 +11,8 @@ pub(crate) const SPONGE_RATE: usize = 8;
pub(crate) const SPONGE_CAPACITY: usize = 4;
pub(crate) const SPONGE_WIDTH: usize = SPONGE_RATE + SPONGE_CAPACITY;
pub(crate) const HASH_FAMILY: HashFamily = HashFamily::GMiMC;
pub(crate) const HASH_FAMILY: HashFamily = HashFamily::Poseidon;
pub(crate) type HashGate<F, const D: usize, const W: usize> = PoseidonGate<F, D, W>;
pub(crate) enum HashFamily {
GMiMC,

View File

@ -6,9 +6,8 @@ use serde::{Deserialize, Serialize};
use crate::field::extension_field::target::ExtensionTarget;
use crate::field::extension_field::Extendable;
use crate::field::field_types::{Field, RichField};
use crate::gates::gmimc::GMiMCGate;
use crate::hash::hash_types::{HashOut, HashOutTarget, MerkleCapTarget};
use crate::hash::hashing::{compress, hash_or_noop};
use crate::hash::hashing::{compress, hash_or_noop, HashGate};
use crate::hash::merkle_tree::MerkleCap;
use crate::iop::target::{BoolTarget, Target};
use crate::iop::wire::Wire;
@ -74,7 +73,7 @@ impl<F: RichField + Extendable<D>, const D: usize> CircuitBuilder<F, D> {
.concat()
.try_into()
.unwrap();
let outputs = self.gmimc_permute_swapped(inputs, bit);
let outputs = self.permute_swapped(inputs, bit);
state = HashOutTarget::from_vec(outputs[0..4].to_vec());
}
@ -107,10 +106,10 @@ impl<F: RichField + Extendable<D>, const D: usize> CircuitBuilder<F, D> {
let mut state: HashOutTarget = self.hash_or_noop(leaf_data);
for (&bit, &sibling) in leaf_index_bits.iter().zip(&proof.siblings) {
let gate_type = GMiMCGate::<F, D, 12>::new();
let gate_type = HashGate::<F, D, 12>::new();
let gate = self.add_gate(gate_type, vec![]);
let swap_wire = GMiMCGate::<F, D, 12>::WIRE_SWAP;
let swap_wire = HashGate::<F, D, 12>::WIRE_SWAP;
let swap_wire = Target::Wire(Wire {
gate,
input: swap_wire,
@ -121,7 +120,7 @@ impl<F: RichField + Extendable<D>, const D: usize> CircuitBuilder<F, D> {
.map(|i| {
Target::Wire(Wire {
gate,
input: GMiMCGate::<F, D, 12>::wire_input(i),
input: HashGate::<F, D, 12>::wire_input(i),
})
})
.collect::<Vec<_>>();
@ -137,7 +136,7 @@ impl<F: RichField + Extendable<D>, const D: usize> CircuitBuilder<F, D> {
.map(|i| {
Target::Wire(Wire {
gate,
input: GMiMCGate::<F, D, 12>::wire_output(i),
input: HashGate::<F, D, 12>::wire_output(i),
})
})
.collect(),

View File

@ -7,7 +7,10 @@ use std::convert::TryInto;
use unroll::unroll_for_loops;
use crate::field::crandall_field::CrandallField;
use crate::field::field_types::PrimeField;
use crate::field::extension_field::target::ExtensionTarget;
use crate::field::extension_field::{Extendable, FieldExtension};
use crate::field::field_types::{PrimeField, RichField};
use crate::plonk::circuit_builder::CircuitBuilder;
// The number of full rounds and partial rounds is given by the
// calc_round_numbers.py script. They happen to be the same for both
@ -15,9 +18,9 @@ use crate::field::field_types::PrimeField;
//
// NB: Changing any of these values will require regenerating all of
// the precomputed constant arrays in this file.
const HALF_N_FULL_ROUNDS: usize = 4;
const N_FULL_ROUNDS_TOTAL: usize = 2 * HALF_N_FULL_ROUNDS;
const N_PARTIAL_ROUNDS: usize = 22;
pub(crate) const HALF_N_FULL_ROUNDS: usize = 4;
pub(crate) const N_FULL_ROUNDS_TOTAL: usize = 2 * HALF_N_FULL_ROUNDS;
pub(crate) const N_PARTIAL_ROUNDS: usize = 22;
const N_ROUNDS: usize = N_FULL_ROUNDS_TOTAL + N_PARTIAL_ROUNDS;
const MAX_WIDTH: usize = 12; // we only have width 8 and 12, and 12 is bigger. :)
@ -25,7 +28,7 @@ const MAX_WIDTH: usize = 12; // we only have width 8 and 12, and 12 is bigger. :
/// `generate_constants` about how these were generated. We include enough for a WIDTH of 12;
/// smaller widths just use a subset.
#[rustfmt::skip]
const ALL_ROUND_CONSTANTS: [u64; MAX_WIDTH * N_ROUNDS] = [
pub const ALL_ROUND_CONSTANTS: [u64; MAX_WIDTH * N_ROUNDS] = [
// WARNING: These must be in 0..CrandallField::ORDER (i.e. canonical form). If this condition is
// not met, some platform-specific implementation of constant_layer may return incorrect
// results.
@ -165,6 +168,49 @@ where
res
}
#[inline(always)]
#[unroll_for_loops]
/// Same as `mds_row_shf` for field extensions of `Self`.
fn mds_row_shf_field<F: FieldExtension<D, BaseField = Self>, const D: usize>(
r: usize,
v: &[F; WIDTH],
) -> F {
debug_assert!(r < WIDTH);
let mut res = F::ZERO;
assert!(WIDTH <= 12);
for i in 0..12 {
if i < WIDTH {
res += v[(i + r) % WIDTH] * F::from_canonical_u64(1 << Self::MDS_MATRIX_EXPS[i]);
}
}
res
}
/// Recursive version of `mds_row_shf`.
fn mds_row_shf_recursive<F: RichField + Extendable<D>, const D: usize>(
builder: &mut CircuitBuilder<F, D>,
r: usize,
v: &[ExtensionTarget<D>; WIDTH],
) -> ExtensionTarget<D> {
let one = builder.one_extension();
debug_assert!(r < WIDTH);
let mut res = builder.zero_extension();
for i in 0..WIDTH {
res = builder.arithmetic_extension(
F::from_canonical_u64(1 << Self::MDS_MATRIX_EXPS[i]),
F::ONE,
one,
v[(i + r) % WIDTH],
res,
);
}
res
}
#[inline(always)]
#[unroll_for_loops]
fn mds_layer(state_: &[Self; WIDTH]) -> [Self; WIDTH] {
@ -188,19 +234,72 @@ where
#[inline(always)]
#[unroll_for_loops]
fn partial_first_constant_layer(state: &mut [Self; WIDTH]) {
/// Same as `mds_layer` for field extensions of `Self`.
fn mds_layer_field<F: FieldExtension<D, BaseField = Self>, const D: usize>(
state: &[F; WIDTH],
) -> [F; WIDTH] {
let mut result = [F::ZERO; WIDTH];
assert!(WIDTH <= 12);
for r in 0..12 {
if r < WIDTH {
result[r] = Self::mds_row_shf_field(r, state);
}
}
result
}
/// Recursive version of `mds_layer`.
fn mds_layer_recursive<F: RichField + Extendable<D>, const D: usize>(
builder: &mut CircuitBuilder<F, D>,
state: &[ExtensionTarget<D>; WIDTH],
) -> [ExtensionTarget<D>; WIDTH] {
let mut result = [builder.zero_extension(); WIDTH];
for r in 0..WIDTH {
result[r] = Self::mds_row_shf_recursive(builder, r, state);
}
result
}
#[inline(always)]
#[unroll_for_loops]
fn partial_first_constant_layer<F: FieldExtension<D, BaseField = Self>, const D: usize>(
state: &mut [F; WIDTH],
) {
assert!(WIDTH <= 12);
for i in 0..12 {
if i < WIDTH {
state[i] += Self::from_canonical_u64(Self::FAST_PARTIAL_FIRST_ROUND_CONSTANT[i]);
state[i] += F::from_canonical_u64(Self::FAST_PARTIAL_FIRST_ROUND_CONSTANT[i]);
}
}
}
/// Recursive version of `partial_first_constant_layer`.
fn partial_first_constant_layer_recursive<F: RichField + Extendable<D>, const D: usize>(
builder: &mut CircuitBuilder<F, D>,
state: &mut [ExtensionTarget<D>; WIDTH],
) {
let one = builder.one_extension();
for i in 0..WIDTH {
state[i] = builder.arithmetic_extension(
F::from_canonical_u64(Self::FAST_PARTIAL_FIRST_ROUND_CONSTANT[i]),
F::ONE,
one,
one,
state[i],
);
}
}
#[inline(always)]
#[unroll_for_loops]
fn mds_partial_layer_init(state: &[Self; WIDTH]) -> [Self; WIDTH] {
let mut result = [Self::ZERO; WIDTH];
fn mds_partial_layer_init<F: FieldExtension<D, BaseField = Self>, const D: usize>(
state: &[F; WIDTH],
) -> [F; WIDTH] {
let mut result = [F::ZERO; WIDTH];
// Initial matrix has first row/column = [1, 0, ..., 0];
@ -216,7 +315,7 @@ where
// NB: FAST_PARTIAL_ROUND_INITIAL_MATRIX is stored in
// column-major order so that this dot product is cache
// friendly.
let t = Self::from_canonical_u64(
let t = F::from_canonical_u64(
Self::FAST_PARTIAL_ROUND_INITIAL_MATRIX[c - 1][r - 1],
);
result[c] += state[r] * t;
@ -227,6 +326,30 @@ where
result
}
/// Recursive version of `mds_partial_layer_init`.
fn mds_partial_layer_init_recursive<F: RichField + Extendable<D>, const D: usize>(
builder: &mut CircuitBuilder<F, D>,
state: &[ExtensionTarget<D>; WIDTH],
) -> [ExtensionTarget<D>; WIDTH] {
let one = builder.one_extension();
let mut result = [builder.zero_extension(); WIDTH];
result[0] = state[0];
for c in 1..WIDTH {
assert!(WIDTH <= 12);
for r in 1..12 {
if r < WIDTH {
let t = F::from_canonical_u64(
Self::FAST_PARTIAL_ROUND_INITIAL_MATRIX[c - 1][r - 1],
);
result[c] = builder.arithmetic_extension(t, F::ONE, one, state[r], result[c]);
}
}
}
result
}
/// Computes s*A where s is the state row vector and A is the matrix
///
/// [ M_00 | v ]
@ -263,6 +386,70 @@ where
result
}
#[inline(always)]
#[unroll_for_loops]
/// Same as `mds_partial_layer_fast` for field extensions of `Self`.
fn mds_partial_layer_fast_field<F: FieldExtension<D, BaseField = Self>, const D: usize>(
state: &[F; WIDTH],
r: usize,
) -> [F; WIDTH] {
let s0 = state[0];
let mut d = s0 * F::from_canonical_u64(1 << Self::MDS_MATRIX_EXPS[0]);
assert!(WIDTH <= 12);
for i in 1..12 {
if i < WIDTH {
let t = F::from_canonical_u64(Self::FAST_PARTIAL_ROUND_W_HATS[r][i - 1]);
d += state[i] * t;
}
}
// result = [d] concat [state[0] * v + state[shift up by 1]]
let mut result = [F::ZERO; WIDTH];
result[0] = d;
assert!(WIDTH <= 12);
for i in 1..12 {
if i < WIDTH {
let t = F::from_canonical_u64(Self::FAST_PARTIAL_ROUND_VS[r][i - 1]);
result[i] = state[0] * t + state[i];
}
}
result
}
/// Recursive version of `mds_partial_layer_fast`.
fn mds_partial_layer_fast_recursive<F: RichField + Extendable<D>, const D: usize>(
builder: &mut CircuitBuilder<F, D>,
state: &[ExtensionTarget<D>; WIDTH],
r: usize,
) -> [ExtensionTarget<D>; WIDTH] {
let zero = builder.zero_extension();
let one = builder.one_extension();
let s0 = state[0];
let mut d = builder.arithmetic_extension(
F::from_canonical_u64(1 << Self::MDS_MATRIX_EXPS[0]),
F::ONE,
one,
s0,
zero,
);
for i in 1..WIDTH {
let t = F::from_canonical_u64(Self::FAST_PARTIAL_ROUND_W_HATS[r][i - 1]);
d = builder.arithmetic_extension(t, F::ONE, one, state[i], d);
}
let mut result = [zero; WIDTH];
result[0] = d;
assert!(WIDTH <= 12);
for i in 1..12 {
if i < WIDTH {
let t = F::from_canonical_u64(Self::FAST_PARTIAL_ROUND_VS[r][i - 1]);
result[i] = builder.arithmetic_extension(t, F::ONE, one, state[0], state[i]);
}
}
result
}
#[inline(always)]
#[unroll_for_loops]
fn constant_layer(state: &mut [Self; WIDTH], round_ctr: usize) {
@ -275,7 +462,40 @@ where
}
#[inline(always)]
fn sbox_monomial(x: Self) -> Self {
#[unroll_for_loops]
/// Same as `constant_layer` for field extensions of `Self`.
fn constant_layer_field<F: FieldExtension<D, BaseField = Self>, const D: usize>(
state: &mut [F; WIDTH],
round_ctr: usize,
) {
assert!(WIDTH <= 12);
for i in 0..12 {
if i < WIDTH {
state[i] += F::from_canonical_u64(ALL_ROUND_CONSTANTS[i + WIDTH * round_ctr]);
}
}
}
/// Recursive version of `constant_layer`.
fn constant_layer_recursive<F: RichField + Extendable<D>, const D: usize>(
builder: &mut CircuitBuilder<F, D>,
state: &mut [ExtensionTarget<D>; WIDTH],
round_ctr: usize,
) {
let one = builder.one_extension();
for i in 0..WIDTH {
state[i] = builder.arithmetic_extension(
F::from_canonical_u64(ALL_ROUND_CONSTANTS[i + WIDTH * round_ctr]),
F::ONE,
one,
one,
state[i],
);
}
}
#[inline(always)]
fn sbox_monomial<F: FieldExtension<D, BaseField = Self>, const D: usize>(x: F) -> F {
// x |--> x^7
let x2 = x * x;
let x4 = x2 * x2;
@ -283,6 +503,15 @@ where
x3 * x4
}
/// Recursive version of `sbox_monomial`.
fn sbox_monomial_recursive<F: RichField + Extendable<D>, const D: usize>(
builder: &mut CircuitBuilder<F, D>,
x: ExtensionTarget<D>,
) -> ExtensionTarget<D> {
// x |--> x^7
builder.exp_u64_extension(x, 7)
}
#[inline(always)]
#[unroll_for_loops]
fn sbox_layer(state: &mut [Self; WIDTH]) {
@ -294,6 +523,30 @@ where
}
}
#[inline(always)]
#[unroll_for_loops]
/// Same as `sbox_layer` for field extensions of `Self`.
fn sbox_layer_field<F: FieldExtension<D, BaseField = Self>, const D: usize>(
state: &mut [F; WIDTH],
) {
assert!(WIDTH <= 12);
for i in 0..12 {
if i < WIDTH {
state[i] = Self::sbox_monomial(state[i]);
}
}
}
/// Recursive version of `sbox_layer`.
fn sbox_layer_recursive<F: RichField + Extendable<D>, const D: usize>(
builder: &mut CircuitBuilder<F, D>,
state: &mut [ExtensionTarget<D>; WIDTH],
) {
for i in 0..WIDTH {
state[i] = Self::sbox_monomial_recursive(builder, state[i]);
}
}
#[inline]
fn full_rounds(state: &mut [Self; WIDTH], round_ctr: &mut usize) {
for _ in 0..HALF_N_FULL_ROUNDS {

View File

@ -4,7 +4,7 @@ use crate::field::extension_field::target::ExtensionTarget;
use crate::field::extension_field::{Extendable, FieldExtension};
use crate::field::field_types::RichField;
use crate::hash::hash_types::{HashOut, HashOutTarget, MerkleCapTarget};
use crate::hash::hashing::{SPONGE_RATE, SPONGE_WIDTH};
use crate::hash::hashing::{permute, SPONGE_RATE, SPONGE_WIDTH};
use crate::hash::merkle_tree::MerkleCap;
use crate::iop::target::Target;
use crate::plonk::circuit_builder::CircuitBuilder;
@ -105,7 +105,7 @@ impl<F: RichField> Challenger<F> {
if self.output_buffer.is_empty() {
// Evaluate the permutation to produce `r` new outputs.
self.sponge_state = F::gmimc_permute(self.sponge_state);
self.sponge_state = permute(self.sponge_state);
self.output_buffer = self.sponge_state[0..SPONGE_RATE].to_vec();
}
@ -160,7 +160,7 @@ impl<F: RichField> Challenger<F> {
}
// Apply the permutation.
self.sponge_state = F::gmimc_permute(self.sponge_state);
self.sponge_state = permute(self.sponge_state);
}
self.output_buffer = self.sponge_state[0..SPONGE_RATE].to_vec();
@ -377,7 +377,7 @@ mod tests {
}
let config = CircuitConfig::large_config();
let mut builder = CircuitBuilder::<F, 4>::new(config.clone());
let mut builder = CircuitBuilder::<F, 4>::new(config);
let mut recursive_challenger = RecursiveChallenger::new(&mut builder);
let mut recursive_outputs_per_round: Vec<Vec<Target>> = Vec::new();
for (r, inputs) in inputs_per_round.iter().enumerate() {

View File

@ -61,7 +61,7 @@ impl CircuitConfig {
#[cfg(test)]
pub(crate) fn large_config() -> Self {
Self {
num_wires: 126,
num_wires: 143,
num_routed_wires: 64,
security_bits: 128,
rate_bits: 3,

View File

@ -361,7 +361,7 @@ mod tests {
type F = CrandallField;
const D: usize = 4;
let config = CircuitConfig {
num_wires: 126,
num_wires: 143,
num_routed_wires: 33,
security_bits: 128,
rate_bits: 3,
@ -416,7 +416,7 @@ mod tests {
type F = CrandallField;
const D: usize = 4;
let config = CircuitConfig {
num_wires: 126,
num_wires: 143,
num_routed_wires: 64,
security_bits: 128,
rate_bits: 3,