mirror of
https://github.com/logos-storage/plonky2.git
synced 2026-01-04 23:03:08 +00:00
PoseidonMdsGate (#330)
PoseidonGate's recursive evaluations were using a lot of gates, and the MDS layer was the main culprit. The other issue is that `constant_layer_recursive` creates a bunch of `ArithmeticGate`s with unique constants. We could either change `ArithmeticGate` to support different constants per operation, or wire in constants from `ConstantGate`, and change `ConstantGate` to support several constants per gate. This won't really help anything near term since we're still between 2^12 and 2^13, but could have some benefits later, depending on what recursion arities and security settings we end up using. `PoseidonMdsGate` needs `2 * D * WIDTH = 48` routed wires, and the combination of adding a gate and increasing routed wires slows down the prover a bit. So for now, I kept it at 28 wires, and the old code path is still used.
This commit is contained in:
parent
caf95ae9dc
commit
c6f91148d5
@ -18,7 +18,7 @@ impl<F: RichField + Extendable<D>, const D: usize> CircuitBuilder<F, D> {
|
||||
|
||||
/// Like `select_ext`, but accepts a condition input which does not necessarily have to be
|
||||
/// binary. In this case, it computes the arithmetic generalization of `if b { x } else { y }`,
|
||||
/// i.e. `bx - (by-y)`, which can be computed with a single `ArithmeticExtensionGate`.
|
||||
/// i.e. `bx - (by-y)`.
|
||||
pub fn select_ext_generalized(
|
||||
&mut self,
|
||||
b: ExtensionTarget<D>,
|
||||
|
||||
@ -14,6 +14,7 @@ pub mod insertion;
|
||||
pub mod interpolation;
|
||||
pub mod noop;
|
||||
pub mod poseidon;
|
||||
pub(crate) mod poseidon_mds;
|
||||
pub(crate) mod public_input;
|
||||
pub mod random_access;
|
||||
pub mod reducing;
|
||||
|
||||
@ -3,8 +3,9 @@ use std::marker::PhantomData;
|
||||
|
||||
use crate::field::extension_field::target::ExtensionTarget;
|
||||
use crate::field::extension_field::Extendable;
|
||||
use crate::field::field_types::{Field, PrimeField, RichField};
|
||||
use crate::field::field_types::{Field, RichField};
|
||||
use crate::gates::gate::Gate;
|
||||
use crate::gates::poseidon_mds::PoseidonMdsGate;
|
||||
use crate::hash::poseidon;
|
||||
use crate::hash::poseidon::Poseidon;
|
||||
use crate::iop::generator::{GeneratedValues, SimpleGenerator, WitnessGenerator};
|
||||
@ -93,7 +94,7 @@ where
|
||||
[(); WIDTH - 1]: ,
|
||||
{
|
||||
fn id(&self) -> String {
|
||||
format!("<WIDTH={}> {:?}", WIDTH, self)
|
||||
format!("{:?}<WIDTH={}>", self, WIDTH)
|
||||
}
|
||||
|
||||
fn eval_unfiltered(&self, vars: EvaluationVars<F, D>) -> Vec<F::Extension> {
|
||||
@ -256,6 +257,10 @@ where
|
||||
builder: &mut CircuitBuilder<F, D>,
|
||||
vars: EvaluationTargets<D>,
|
||||
) -> Vec<ExtensionTarget<D>> {
|
||||
// The naive method is more efficient if we have enough routed wires for PoseidonMdsGate.
|
||||
let naive =
|
||||
builder.config.num_routed_wires >= PoseidonMdsGate::<F, D, WIDTH>::new().num_wires();
|
||||
|
||||
let mut constraints = Vec::with_capacity(self.num_constraints());
|
||||
|
||||
// Assert that `swap` is binary.
|
||||
@ -263,18 +268,23 @@ where
|
||||
constraints.push(builder.mul_sub_extension(swap, swap, swap));
|
||||
|
||||
let mut state = Vec::with_capacity(WIDTH);
|
||||
// We need to compute both `if swap {b} else {a}` and `if swap {a} else {b}`.
|
||||
// We will arithmetize them as
|
||||
// swap (b - a) + a
|
||||
// -swap (b - a) + b
|
||||
// so that `b - a` can be used for both.
|
||||
let mut state_first_4 = vec![];
|
||||
let mut state_next_4 = vec![];
|
||||
for i in 0..4 {
|
||||
let a = vars.local_wires[i];
|
||||
let b = vars.local_wires[i + 4];
|
||||
let delta = builder.sub_extension(b, a);
|
||||
state.push(builder.mul_add_extension(swap, delta, a));
|
||||
}
|
||||
for i in 0..4 {
|
||||
let a = vars.local_wires[i + 4];
|
||||
let b = vars.local_wires[i];
|
||||
let delta = builder.sub_extension(b, a);
|
||||
state.push(builder.mul_add_extension(swap, delta, a));
|
||||
state_first_4.push(builder.mul_add_extension(swap, delta, a));
|
||||
state_next_4.push(builder.arithmetic_extension(F::NEG_ONE, F::ONE, swap, delta, b));
|
||||
}
|
||||
|
||||
state.extend(state_first_4);
|
||||
state.extend(state_next_4);
|
||||
for i in 8..WIDTH {
|
||||
state.push(vars.local_wires[i]);
|
||||
}
|
||||
@ -296,27 +306,39 @@ where
|
||||
}
|
||||
|
||||
// Partial rounds.
|
||||
<F as Poseidon<WIDTH>>::partial_first_constant_layer_recursive(builder, &mut state);
|
||||
state = <F as Poseidon<WIDTH>>::mds_partial_layer_init_recursive(builder, &mut state);
|
||||
for r in 0..(poseidon::N_PARTIAL_ROUNDS - 1) {
|
||||
let sbox_in = vars.local_wires[Self::wire_partial_sbox(r)];
|
||||
if naive {
|
||||
for r in 0..poseidon::N_PARTIAL_ROUNDS {
|
||||
<F as Poseidon<WIDTH>>::constant_layer_recursive(builder, &mut state, round_ctr);
|
||||
let sbox_in = vars.local_wires[Self::wire_partial_sbox(r)];
|
||||
constraints.push(builder.sub_extension(state[0], sbox_in));
|
||||
state[0] = <F as Poseidon<WIDTH>>::sbox_monomial_recursive(builder, sbox_in);
|
||||
state = <F as Poseidon<WIDTH>>::mds_layer_recursive(builder, &state);
|
||||
round_ctr += 1;
|
||||
}
|
||||
} else {
|
||||
<F as Poseidon<WIDTH>>::partial_first_constant_layer_recursive(builder, &mut state);
|
||||
state = <F as Poseidon<WIDTH>>::mds_partial_layer_init_recursive(builder, &mut state);
|
||||
for r in 0..(poseidon::N_PARTIAL_ROUNDS - 1) {
|
||||
let sbox_in = vars.local_wires[Self::wire_partial_sbox(r)];
|
||||
constraints.push(builder.sub_extension(state[0], sbox_in));
|
||||
state[0] = <F as Poseidon<WIDTH>>::sbox_monomial_recursive(builder, sbox_in);
|
||||
state[0] = builder.add_const_extension(
|
||||
state[0],
|
||||
F::from_canonical_u64(<F as Poseidon<WIDTH>>::FAST_PARTIAL_ROUND_CONSTANTS[r]),
|
||||
);
|
||||
state =
|
||||
<F as Poseidon<WIDTH>>::mds_partial_layer_fast_recursive(builder, &state, r);
|
||||
}
|
||||
let sbox_in = vars.local_wires[Self::wire_partial_sbox(poseidon::N_PARTIAL_ROUNDS - 1)];
|
||||
constraints.push(builder.sub_extension(state[0], sbox_in));
|
||||
state[0] = <F as Poseidon<WIDTH>>::sbox_monomial_recursive(builder, sbox_in);
|
||||
state[0] = builder.add_const_extension(
|
||||
state[0],
|
||||
F::from_canonical_u64(<F as Poseidon<WIDTH>>::FAST_PARTIAL_ROUND_CONSTANTS[r]),
|
||||
state = <F as Poseidon<WIDTH>>::mds_partial_layer_fast_recursive(
|
||||
builder,
|
||||
&state,
|
||||
poseidon::N_PARTIAL_ROUNDS - 1,
|
||||
);
|
||||
state = <F as Poseidon<WIDTH>>::mds_partial_layer_fast_recursive(builder, &state, r);
|
||||
round_ctr += poseidon::N_PARTIAL_ROUNDS;
|
||||
}
|
||||
let sbox_in = vars.local_wires[Self::wire_partial_sbox(poseidon::N_PARTIAL_ROUNDS - 1)];
|
||||
constraints.push(builder.sub_extension(state[0], sbox_in));
|
||||
state[0] = <F as Poseidon<WIDTH>>::sbox_monomial_recursive(builder, sbox_in);
|
||||
state = <F as Poseidon<WIDTH>>::mds_partial_layer_fast_recursive(
|
||||
builder,
|
||||
&state,
|
||||
poseidon::N_PARTIAL_ROUNDS - 1,
|
||||
);
|
||||
round_ctr += poseidon::N_PARTIAL_ROUNDS;
|
||||
|
||||
// Second set of full rounds.
|
||||
for r in 0..poseidon::HALF_N_FULL_ROUNDS {
|
||||
|
||||
274
src/gates/poseidon_mds.rs
Normal file
274
src/gates/poseidon_mds.rs
Normal file
@ -0,0 +1,274 @@
|
||||
use std::convert::TryInto;
|
||||
use std::marker::PhantomData;
|
||||
use std::ops::Range;
|
||||
|
||||
use crate::field::extension_field::algebra::ExtensionAlgebra;
|
||||
use crate::field::extension_field::target::{ExtensionAlgebraTarget, ExtensionTarget};
|
||||
use crate::field::extension_field::Extendable;
|
||||
use crate::field::extension_field::FieldExtension;
|
||||
use crate::field::field_types::{Field, RichField};
|
||||
use crate::gates::gate::Gate;
|
||||
use crate::hash::poseidon::Poseidon;
|
||||
use crate::iop::generator::{GeneratedValues, SimpleGenerator, WitnessGenerator};
|
||||
use crate::iop::target::Target;
|
||||
use crate::iop::witness::{PartitionWitness, Witness};
|
||||
use crate::plonk::circuit_builder::CircuitBuilder;
|
||||
use crate::plonk::vars::{EvaluationTargets, EvaluationVars, EvaluationVarsBase};
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct PoseidonMdsGate<
|
||||
F: RichField + Extendable<D> + Poseidon<WIDTH>,
|
||||
const D: usize,
|
||||
const WIDTH: usize,
|
||||
> where
|
||||
[(); WIDTH - 1]: ,
|
||||
{
|
||||
_phantom: PhantomData<F>,
|
||||
}
|
||||
|
||||
impl<F: RichField + Extendable<D> + Poseidon<WIDTH>, const D: usize, const WIDTH: usize>
|
||||
PoseidonMdsGate<F, D, WIDTH>
|
||||
where
|
||||
[(); WIDTH - 1]: ,
|
||||
{
|
||||
pub fn new() -> Self {
|
||||
PoseidonMdsGate {
|
||||
_phantom: PhantomData,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn wires_input(i: usize) -> Range<usize> {
|
||||
assert!(i < WIDTH);
|
||||
i * D..(i + 1) * D
|
||||
}
|
||||
|
||||
pub fn wires_output(i: usize) -> Range<usize> {
|
||||
assert!(i < WIDTH);
|
||||
(WIDTH + i) * D..(WIDTH + i + 1) * D
|
||||
}
|
||||
|
||||
// Following are methods analogous to ones in `Poseidon`, but for extension algebras.
|
||||
|
||||
/// Same as `mds_row_shf` for an extension algebra of `F`.
|
||||
fn mds_row_shf_algebra(
|
||||
r: usize,
|
||||
v: &[ExtensionAlgebra<F::Extension, D>; WIDTH],
|
||||
) -> ExtensionAlgebra<F::Extension, D> {
|
||||
debug_assert!(r < WIDTH);
|
||||
let mut res = ExtensionAlgebra::ZERO;
|
||||
|
||||
for i in 0..WIDTH {
|
||||
let coeff =
|
||||
F::Extension::from_canonical_u64(1 << <F as Poseidon<WIDTH>>::MDS_MATRIX_EXPS[i]);
|
||||
res += v[(i + r) % WIDTH].scalar_mul(coeff);
|
||||
}
|
||||
|
||||
res
|
||||
}
|
||||
|
||||
/// Same as `mds_row_shf_recursive` for an extension algebra of `F`.
|
||||
fn mds_row_shf_algebra_recursive(
|
||||
builder: &mut CircuitBuilder<F, D>,
|
||||
r: usize,
|
||||
v: &[ExtensionAlgebraTarget<D>; WIDTH],
|
||||
) -> ExtensionAlgebraTarget<D> {
|
||||
debug_assert!(r < WIDTH);
|
||||
let mut res = builder.zero_ext_algebra();
|
||||
|
||||
for i in 0..WIDTH {
|
||||
let coeff = builder.constant_extension(F::Extension::from_canonical_u64(
|
||||
1 << <F as Poseidon<WIDTH>>::MDS_MATRIX_EXPS[i],
|
||||
));
|
||||
res = builder.scalar_mul_add_ext_algebra(coeff, v[(i + r) % WIDTH], res);
|
||||
}
|
||||
|
||||
res
|
||||
}
|
||||
|
||||
/// Same as `mds_layer` for an extension algebra of `F`.
|
||||
fn mds_layer_algebra(
|
||||
state: &[ExtensionAlgebra<F::Extension, D>; WIDTH],
|
||||
) -> [ExtensionAlgebra<F::Extension, D>; WIDTH] {
|
||||
let mut result = [ExtensionAlgebra::ZERO; WIDTH];
|
||||
|
||||
for r in 0..WIDTH {
|
||||
result[r] = Self::mds_row_shf_algebra(r, state);
|
||||
}
|
||||
|
||||
result
|
||||
}
|
||||
|
||||
/// Same as `mds_layer_recursive` for an extension algebra of `F`.
|
||||
fn mds_layer_algebra_recursive(
|
||||
builder: &mut CircuitBuilder<F, D>,
|
||||
state: &[ExtensionAlgebraTarget<D>; WIDTH],
|
||||
) -> [ExtensionAlgebraTarget<D>; WIDTH] {
|
||||
let mut result = [builder.zero_ext_algebra(); WIDTH];
|
||||
|
||||
for r in 0..WIDTH {
|
||||
result[r] = Self::mds_row_shf_algebra_recursive(builder, r, state);
|
||||
}
|
||||
|
||||
result
|
||||
}
|
||||
}
|
||||
|
||||
impl<F: RichField + Extendable<D> + Poseidon<WIDTH>, const D: usize, const WIDTH: usize> Gate<F, D>
|
||||
for PoseidonMdsGate<F, D, WIDTH>
|
||||
where
|
||||
[(); WIDTH - 1]: ,
|
||||
{
|
||||
fn id(&self) -> String {
|
||||
format!("{:?}<WIDTH={}>", self, WIDTH)
|
||||
}
|
||||
|
||||
fn eval_unfiltered(&self, vars: EvaluationVars<F, D>) -> Vec<F::Extension> {
|
||||
let inputs: [_; WIDTH] = (0..WIDTH)
|
||||
.map(|i| vars.get_local_ext_algebra(Self::wires_input(i)))
|
||||
.collect::<Vec<_>>()
|
||||
.try_into()
|
||||
.unwrap();
|
||||
|
||||
let computed_outputs = Self::mds_layer_algebra(&inputs);
|
||||
|
||||
(0..WIDTH)
|
||||
.map(|i| vars.get_local_ext_algebra(Self::wires_output(i)))
|
||||
.zip(computed_outputs)
|
||||
.flat_map(|(out, computed_out)| (out - computed_out).to_basefield_array())
|
||||
.collect()
|
||||
}
|
||||
|
||||
fn eval_unfiltered_base(&self, vars: EvaluationVarsBase<F>) -> Vec<F> {
|
||||
let inputs: [_; WIDTH] = (0..WIDTH)
|
||||
.map(|i| vars.get_local_ext(Self::wires_input(i)))
|
||||
.collect::<Vec<_>>()
|
||||
.try_into()
|
||||
.unwrap();
|
||||
|
||||
let computed_outputs = F::mds_layer_field(&inputs);
|
||||
|
||||
(0..WIDTH)
|
||||
.map(|i| vars.get_local_ext(Self::wires_output(i)))
|
||||
.zip(computed_outputs)
|
||||
.flat_map(|(out, computed_out)| (out - computed_out).to_basefield_array())
|
||||
.collect()
|
||||
}
|
||||
|
||||
fn eval_unfiltered_recursively(
|
||||
&self,
|
||||
builder: &mut CircuitBuilder<F, D>,
|
||||
vars: EvaluationTargets<D>,
|
||||
) -> Vec<ExtensionTarget<D>> {
|
||||
let inputs: [_; WIDTH] = (0..WIDTH)
|
||||
.map(|i| vars.get_local_ext_algebra(Self::wires_input(i)))
|
||||
.collect::<Vec<_>>()
|
||||
.try_into()
|
||||
.unwrap();
|
||||
|
||||
let computed_outputs = Self::mds_layer_algebra_recursive(builder, &inputs);
|
||||
|
||||
(0..WIDTH)
|
||||
.map(|i| vars.get_local_ext_algebra(Self::wires_output(i)))
|
||||
.zip(computed_outputs)
|
||||
.flat_map(|(out, computed_out)| {
|
||||
builder
|
||||
.sub_ext_algebra(out, computed_out)
|
||||
.to_ext_target_array()
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
fn generators(
|
||||
&self,
|
||||
gate_index: usize,
|
||||
_local_constants: &[F],
|
||||
) -> Vec<Box<dyn WitnessGenerator<F>>> {
|
||||
let gen = PoseidonMdsGenerator::<D, WIDTH> { gate_index };
|
||||
vec![Box::new(gen.adapter())]
|
||||
}
|
||||
|
||||
fn num_wires(&self) -> usize {
|
||||
2 * D * WIDTH
|
||||
}
|
||||
|
||||
fn num_constants(&self) -> usize {
|
||||
0
|
||||
}
|
||||
|
||||
fn degree(&self) -> usize {
|
||||
1
|
||||
}
|
||||
|
||||
fn num_constraints(&self) -> usize {
|
||||
WIDTH * D
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
struct PoseidonMdsGenerator<const D: usize, const WIDTH: usize>
|
||||
where
|
||||
[(); WIDTH - 1]: ,
|
||||
{
|
||||
gate_index: usize,
|
||||
}
|
||||
|
||||
impl<F: RichField + Extendable<D> + Poseidon<WIDTH>, const D: usize, const WIDTH: usize>
|
||||
SimpleGenerator<F> for PoseidonMdsGenerator<D, WIDTH>
|
||||
where
|
||||
[(); WIDTH - 1]: ,
|
||||
{
|
||||
fn dependencies(&self) -> Vec<Target> {
|
||||
(0..WIDTH)
|
||||
.flat_map(|i| {
|
||||
Target::wires_from_range(
|
||||
self.gate_index,
|
||||
PoseidonMdsGate::<F, D, WIDTH>::wires_input(i),
|
||||
)
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
fn run_once(&self, witness: &PartitionWitness<F>, out_buffer: &mut GeneratedValues<F>) {
|
||||
let get_local_get_target =
|
||||
|wire_range| ExtensionTarget::from_range(self.gate_index, wire_range);
|
||||
let get_local_ext =
|
||||
|wire_range| witness.get_extension_target(get_local_get_target(wire_range));
|
||||
|
||||
let inputs: [_; WIDTH] = (0..WIDTH)
|
||||
.map(|i| get_local_ext(PoseidonMdsGate::<F, D, WIDTH>::wires_input(i)))
|
||||
.collect::<Vec<_>>()
|
||||
.try_into()
|
||||
.unwrap();
|
||||
|
||||
let outputs = F::mds_layer_field(&inputs);
|
||||
|
||||
for (i, &out) in outputs.iter().enumerate() {
|
||||
out_buffer.set_extension_target(
|
||||
get_local_get_target(PoseidonMdsGate::<F, D, WIDTH>::wires_output(i)),
|
||||
out,
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use crate::field::goldilocks_field::GoldilocksField;
|
||||
use crate::gates::gate_testing::{test_eval_fns, test_low_degree};
|
||||
use crate::gates::poseidon_mds::PoseidonMdsGate;
|
||||
use crate::hash::hashing::SPONGE_WIDTH;
|
||||
|
||||
#[test]
|
||||
fn low_degree() {
|
||||
type F = GoldilocksField;
|
||||
let gate = PoseidonMdsGate::<F, 4, SPONGE_WIDTH>::new();
|
||||
test_low_degree(gate)
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn eval_fns() -> anyhow::Result<()> {
|
||||
type F = GoldilocksField;
|
||||
let gate = PoseidonMdsGate::<F, 4, SPONGE_WIDTH>::new();
|
||||
test_eval_fns(gate)
|
||||
}
|
||||
}
|
||||
@ -1,11 +1,15 @@
|
||||
//! Implementation of the Poseidon hash function, as described in
|
||||
//! https://eprint.iacr.org/2019/458.pdf
|
||||
|
||||
use std::convert::TryInto;
|
||||
|
||||
use unroll::unroll_for_loops;
|
||||
|
||||
use crate::field::extension_field::target::ExtensionTarget;
|
||||
use crate::field::extension_field::{Extendable, FieldExtension};
|
||||
use crate::field::field_types::{PrimeField, RichField};
|
||||
use crate::gates::gate::Gate;
|
||||
use crate::gates::poseidon_mds::PoseidonMdsGate;
|
||||
use crate::plonk::circuit_builder::CircuitBuilder;
|
||||
|
||||
// The number of full rounds and partial rounds is given by the
|
||||
@ -205,17 +209,20 @@ where
|
||||
}
|
||||
|
||||
/// Recursive version of `mds_row_shf`.
|
||||
fn mds_row_shf_recursive<F: RichField + Extendable<D>, const D: usize>(
|
||||
builder: &mut CircuitBuilder<F, D>,
|
||||
fn mds_row_shf_recursive<const D: usize>(
|
||||
builder: &mut CircuitBuilder<Self, D>,
|
||||
r: usize,
|
||||
v: &[ExtensionTarget<D>; WIDTH],
|
||||
) -> ExtensionTarget<D> {
|
||||
) -> ExtensionTarget<D>
|
||||
where
|
||||
Self: RichField + Extendable<D>,
|
||||
{
|
||||
debug_assert!(r < WIDTH);
|
||||
let mut res = builder.zero_extension();
|
||||
|
||||
for i in 0..WIDTH {
|
||||
res = builder.mul_const_add_extension(
|
||||
F::from_canonical_u64(1 << Self::MDS_MATRIX_EXPS[i]),
|
||||
Self::from_canonical_u64(1 << <Self as Poseidon<WIDTH>>::MDS_MATRIX_EXPS[i]),
|
||||
v[(i + r) % WIDTH],
|
||||
res,
|
||||
);
|
||||
@ -262,17 +269,38 @@ where
|
||||
}
|
||||
|
||||
/// Recursive version of `mds_layer`.
|
||||
fn mds_layer_recursive<F: RichField + Extendable<D>, const D: usize>(
|
||||
builder: &mut CircuitBuilder<F, D>,
|
||||
fn mds_layer_recursive<const D: usize>(
|
||||
builder: &mut CircuitBuilder<Self, D>,
|
||||
state: &[ExtensionTarget<D>; WIDTH],
|
||||
) -> [ExtensionTarget<D>; WIDTH] {
|
||||
let mut result = [builder.zero_extension(); WIDTH];
|
||||
) -> [ExtensionTarget<D>; WIDTH]
|
||||
where
|
||||
Self: RichField + Extendable<D>,
|
||||
{
|
||||
// If we have enough routed wires, we will use PoseidonMdsGate.
|
||||
let mds_gate = PoseidonMdsGate::<Self, D, WIDTH>::new();
|
||||
if builder.config.num_routed_wires >= mds_gate.num_wires() {
|
||||
let index = builder.add_gate(mds_gate, vec![]);
|
||||
for i in 0..WIDTH {
|
||||
let input_wire = PoseidonMdsGate::<Self, D, WIDTH>::wires_input(i);
|
||||
builder.connect_extension(state[i], ExtensionTarget::from_range(index, input_wire));
|
||||
}
|
||||
(0..WIDTH)
|
||||
.map(|i| {
|
||||
let output_wire = PoseidonMdsGate::<Self, D, WIDTH>::wires_output(i);
|
||||
ExtensionTarget::from_range(index, output_wire)
|
||||
})
|
||||
.collect::<Vec<_>>()
|
||||
.try_into()
|
||||
.unwrap()
|
||||
} else {
|
||||
let mut result = [builder.zero_extension(); WIDTH];
|
||||
|
||||
for r in 0..WIDTH {
|
||||
result[r] = Self::mds_row_shf_recursive(builder, r, state);
|
||||
for r in 0..WIDTH {
|
||||
result[r] = Self::mds_row_shf_recursive(builder, r, state);
|
||||
}
|
||||
|
||||
result
|
||||
}
|
||||
|
||||
result
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
@ -289,14 +317,18 @@ where
|
||||
}
|
||||
|
||||
/// Recursive version of `partial_first_constant_layer`.
|
||||
fn partial_first_constant_layer_recursive<F: RichField + Extendable<D>, const D: usize>(
|
||||
builder: &mut CircuitBuilder<F, D>,
|
||||
fn partial_first_constant_layer_recursive<const D: usize>(
|
||||
builder: &mut CircuitBuilder<Self, D>,
|
||||
state: &mut [ExtensionTarget<D>; WIDTH],
|
||||
) {
|
||||
) where
|
||||
Self: RichField + Extendable<D>,
|
||||
{
|
||||
for i in 0..WIDTH {
|
||||
state[i] = builder.add_const_extension(
|
||||
state[i],
|
||||
F::from_canonical_u64(Self::FAST_PARTIAL_FIRST_ROUND_CONSTANT[i]),
|
||||
Self::from_canonical_u64(
|
||||
<Self as Poseidon<WIDTH>>::FAST_PARTIAL_FIRST_ROUND_CONSTANT[i],
|
||||
),
|
||||
);
|
||||
}
|
||||
}
|
||||
@ -334,18 +366,22 @@ where
|
||||
}
|
||||
|
||||
/// Recursive version of `mds_partial_layer_init`.
|
||||
fn mds_partial_layer_init_recursive<F: RichField + Extendable<D>, const D: usize>(
|
||||
builder: &mut CircuitBuilder<F, D>,
|
||||
fn mds_partial_layer_init_recursive<const D: usize>(
|
||||
builder: &mut CircuitBuilder<Self, D>,
|
||||
state: &[ExtensionTarget<D>; WIDTH],
|
||||
) -> [ExtensionTarget<D>; WIDTH] {
|
||||
) -> [ExtensionTarget<D>; WIDTH]
|
||||
where
|
||||
Self: RichField + Extendable<D>,
|
||||
{
|
||||
let mut result = [builder.zero_extension(); WIDTH];
|
||||
|
||||
result[0] = state[0];
|
||||
|
||||
for r in 1..WIDTH {
|
||||
for c in 1..WIDTH {
|
||||
let t =
|
||||
F::from_canonical_u64(Self::FAST_PARTIAL_ROUND_INITIAL_MATRIX[r - 1][c - 1]);
|
||||
let t = Self::from_canonical_u64(
|
||||
<Self as Poseidon<WIDTH>>::FAST_PARTIAL_ROUND_INITIAL_MATRIX[r - 1][c - 1],
|
||||
);
|
||||
result[c] = builder.mul_const_add_extension(t, state[r], result[c]);
|
||||
}
|
||||
}
|
||||
@ -414,23 +450,32 @@ where
|
||||
}
|
||||
|
||||
/// Recursive version of `mds_partial_layer_fast`.
|
||||
fn mds_partial_layer_fast_recursive<F: RichField + Extendable<D>, const D: usize>(
|
||||
builder: &mut CircuitBuilder<F, D>,
|
||||
fn mds_partial_layer_fast_recursive<const D: usize>(
|
||||
builder: &mut CircuitBuilder<Self, D>,
|
||||
state: &[ExtensionTarget<D>; WIDTH],
|
||||
r: usize,
|
||||
) -> [ExtensionTarget<D>; WIDTH] {
|
||||
) -> [ExtensionTarget<D>; WIDTH]
|
||||
where
|
||||
Self: RichField + Extendable<D>,
|
||||
{
|
||||
let s0 = state[0];
|
||||
let mut d =
|
||||
builder.mul_const_extension(F::from_canonical_u64(1 << Self::MDS_MATRIX_EXPS[0]), s0);
|
||||
let mut d = builder.mul_const_extension(
|
||||
Self::from_canonical_u64(1 << <Self as Poseidon<WIDTH>>::MDS_MATRIX_EXPS[0]),
|
||||
s0,
|
||||
);
|
||||
for i in 1..WIDTH {
|
||||
let t = F::from_canonical_u64(Self::FAST_PARTIAL_ROUND_W_HATS[r][i - 1]);
|
||||
let t = Self::from_canonical_u64(
|
||||
<Self as Poseidon<WIDTH>>::FAST_PARTIAL_ROUND_W_HATS[r][i - 1],
|
||||
);
|
||||
d = builder.mul_const_add_extension(t, state[i], d);
|
||||
}
|
||||
|
||||
let mut result = [builder.zero_extension(); WIDTH];
|
||||
result[0] = d;
|
||||
for i in 1..WIDTH {
|
||||
let t = F::from_canonical_u64(Self::FAST_PARTIAL_ROUND_VS[r][i - 1]);
|
||||
let t = Self::from_canonical_u64(
|
||||
<Self as Poseidon<WIDTH>>::FAST_PARTIAL_ROUND_VS[r][i - 1],
|
||||
);
|
||||
result[i] = builder.mul_const_add_extension(t, state[0], state[i]);
|
||||
}
|
||||
result
|
||||
@ -461,15 +506,17 @@ where
|
||||
}
|
||||
|
||||
/// Recursive version of `constant_layer`.
|
||||
fn constant_layer_recursive<F: RichField + Extendable<D>, const D: usize>(
|
||||
builder: &mut CircuitBuilder<F, D>,
|
||||
fn constant_layer_recursive<const D: usize>(
|
||||
builder: &mut CircuitBuilder<Self, D>,
|
||||
state: &mut [ExtensionTarget<D>; WIDTH],
|
||||
round_ctr: usize,
|
||||
) {
|
||||
) where
|
||||
Self: RichField + Extendable<D>,
|
||||
{
|
||||
for i in 0..WIDTH {
|
||||
state[i] = builder.add_const_extension(
|
||||
state[i],
|
||||
F::from_canonical_u64(ALL_ROUND_CONSTANTS[i + WIDTH * round_ctr]),
|
||||
Self::from_canonical_u64(ALL_ROUND_CONSTANTS[i + WIDTH * round_ctr]),
|
||||
);
|
||||
}
|
||||
}
|
||||
@ -484,10 +531,13 @@ where
|
||||
}
|
||||
|
||||
/// Recursive version of `sbox_monomial`.
|
||||
fn sbox_monomial_recursive<F: RichField + Extendable<D>, const D: usize>(
|
||||
builder: &mut CircuitBuilder<F, D>,
|
||||
fn sbox_monomial_recursive<const D: usize>(
|
||||
builder: &mut CircuitBuilder<Self, D>,
|
||||
x: ExtensionTarget<D>,
|
||||
) -> ExtensionTarget<D> {
|
||||
) -> ExtensionTarget<D>
|
||||
where
|
||||
Self: RichField + Extendable<D>,
|
||||
{
|
||||
// x |--> x^7
|
||||
builder.exp_u64_extension(x, 7)
|
||||
}
|
||||
@ -513,12 +563,14 @@ where
|
||||
}
|
||||
|
||||
/// Recursive version of `sbox_layer`.
|
||||
fn sbox_layer_recursive<F: RichField + Extendable<D>, const D: usize>(
|
||||
builder: &mut CircuitBuilder<F, D>,
|
||||
fn sbox_layer_recursive<const D: usize>(
|
||||
builder: &mut CircuitBuilder<Self, D>,
|
||||
state: &mut [ExtensionTarget<D>; WIDTH],
|
||||
) {
|
||||
) where
|
||||
Self: RichField + Extendable<D>,
|
||||
{
|
||||
for i in 0..WIDTH {
|
||||
state[i] = Self::sbox_monomial_recursive(builder, state[i]);
|
||||
state[i] = <Self as Poseidon<WIDTH>>::sbox_monomial_recursive(builder, state[i]);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user