Optimize recursive Poseidon constraint evaluation (#333)

* More wires for ConstantGate

* fix

* fix

* Optimize recursive Poseidon constraint evaluation

- Avoid `ArithmeticGate`s with unique constants; use `ConstantGate` wires instead
- Avoid an unnecessary squaring in exponentiations

Brings Poseidon evaluation down to a reasonable 273 gates when `num_routed_wires = 48`.
This commit is contained in:
Daniel Lubarov 2021-11-02 14:42:30 -07:00 committed by GitHub
parent e39af10a6b
commit c8e043a53f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 24 additions and 28 deletions

View File

@ -412,10 +412,12 @@ impl<F: RichField + Extendable<D>, const D: usize> CircuitBuilder<F, D> {
let mut product = self.one_extension();
for j in 0..bits_u64(exponent as u64) {
if j != 0 {
current = self.square_extension(current);
}
if (exponent >> j & 1) != 0 {
product = self.mul_extension(product, current);
}
current = self.square_extension(current);
}
product
}

View File

@ -7,7 +7,7 @@ use unroll::unroll_for_loops;
use crate::field::extension_field::target::ExtensionTarget;
use crate::field::extension_field::{Extendable, FieldExtension};
use crate::field::field_types::{PrimeField, RichField};
use crate::field::field_types::{Field, PrimeField, RichField};
use crate::gates::gate::Gate;
use crate::gates::poseidon_mds::PoseidonMdsGate;
use crate::plonk::circuit_builder::CircuitBuilder;
@ -216,11 +216,8 @@ where
let mut res = builder.zero_extension();
for i in 0..WIDTH {
res = builder.mul_const_add_extension(
Self::from_canonical_u64(1 << <Self as Poseidon<WIDTH>>::MDS_MATRIX_EXPS[i]),
v[(i + r) % WIDTH],
res,
);
let c = Self::from_canonical_u64(1 << <Self as Poseidon<WIDTH>>::MDS_MATRIX_EXPS[i]);
res = builder.mul_const_add_extension(c, v[(i + r) % WIDTH], res);
}
res
@ -319,12 +316,10 @@ where
Self: RichField + Extendable<D>,
{
for i in 0..WIDTH {
state[i] = builder.add_const_extension(
state[i],
Self::from_canonical_u64(
<Self as Poseidon<WIDTH>>::FAST_PARTIAL_FIRST_ROUND_CONSTANT[i],
),
);
let c = <Self as Poseidon<WIDTH>>::FAST_PARTIAL_FIRST_ROUND_CONSTANT[i];
let c = Self::Extension::from_canonical_u64(c);
let c = builder.constant_extension(c);
state[i] = builder.add_extension(state[i], c);
}
}
@ -374,10 +369,10 @@ where
for r in 1..WIDTH {
for c in 1..WIDTH {
let t = Self::from_canonical_u64(
<Self as Poseidon<WIDTH>>::FAST_PARTIAL_ROUND_INITIAL_MATRIX[r - 1][c - 1],
);
result[c] = builder.mul_const_add_extension(t, state[r], result[c]);
let t = <Self as Poseidon<WIDTH>>::FAST_PARTIAL_ROUND_INITIAL_MATRIX[r - 1][c - 1];
let t = Self::Extension::from_canonical_u64(t);
let t = builder.constant_extension(t);
result[c] = builder.mul_add_extension(t, state[r], result[c]);
}
}
result
@ -459,19 +454,18 @@ where
s0,
);
for i in 1..WIDTH {
let t = Self::from_canonical_u64(
<Self as Poseidon<WIDTH>>::FAST_PARTIAL_ROUND_W_HATS[r][i - 1],
);
let t = <Self as Poseidon<WIDTH>>::FAST_PARTIAL_ROUND_W_HATS[r][i - 1];
let t = Self::from_canonical_u64(t);
d = builder.mul_const_add_extension(t, state[i], d);
}
let mut result = [builder.zero_extension(); WIDTH];
result[0] = d;
for i in 1..WIDTH {
let t = Self::from_canonical_u64(
<Self as Poseidon<WIDTH>>::FAST_PARTIAL_ROUND_VS[r][i - 1],
);
result[i] = builder.mul_const_add_extension(t, state[0], state[i]);
let t = <Self as Poseidon<WIDTH>>::FAST_PARTIAL_ROUND_VS[r][i - 1];
let t = Self::Extension::from_canonical_u64(t);
let t = builder.constant_extension(t);
result[i] = builder.mul_add_extension(t, state[0], state[i]);
}
result
}
@ -509,10 +503,10 @@ where
Self: RichField + Extendable<D>,
{
for i in 0..WIDTH {
state[i] = builder.add_const_extension(
state[i],
Self::from_canonical_u64(ALL_ROUND_CONSTANTS[i + WIDTH * round_ctr]),
);
let c = ALL_ROUND_CONSTANTS[i + WIDTH * round_ctr];
let c = Self::Extension::from_canonical_u64(c);
let c = builder.constant_extension(c);
state[i] = builder.add_extension(state[i], c);
}
}