Optimize recursive Poseidon constraint evaluation (#333)

* More wires for ConstantGate

* fix

* fix

* Optimize recursive Poseidon constraint evaluation

- Avoid `ArithmeticGate`s with unique constants; use `ConstantGate` wires instead
- Avoid an unnecessary squaring in exponentiations

Brings Poseidon evaluation down to a reasonable 273 gates when `num_routed_wires = 48`.
This commit is contained in:
Daniel Lubarov 2021-11-02 14:42:30 -07:00 committed by GitHub
parent e39af10a6b
commit c8e043a53f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 24 additions and 28 deletions

View File

@ -412,10 +412,12 @@ impl<F: RichField + Extendable<D>, const D: usize> CircuitBuilder<F, D> {
let mut product = self.one_extension(); let mut product = self.one_extension();
for j in 0..bits_u64(exponent as u64) { for j in 0..bits_u64(exponent as u64) {
if j != 0 {
current = self.square_extension(current);
}
if (exponent >> j & 1) != 0 { if (exponent >> j & 1) != 0 {
product = self.mul_extension(product, current); product = self.mul_extension(product, current);
} }
current = self.square_extension(current);
} }
product product
} }

View File

@ -7,7 +7,7 @@ use unroll::unroll_for_loops;
use crate::field::extension_field::target::ExtensionTarget; use crate::field::extension_field::target::ExtensionTarget;
use crate::field::extension_field::{Extendable, FieldExtension}; use crate::field::extension_field::{Extendable, FieldExtension};
use crate::field::field_types::{PrimeField, RichField}; use crate::field::field_types::{Field, PrimeField, RichField};
use crate::gates::gate::Gate; use crate::gates::gate::Gate;
use crate::gates::poseidon_mds::PoseidonMdsGate; use crate::gates::poseidon_mds::PoseidonMdsGate;
use crate::plonk::circuit_builder::CircuitBuilder; use crate::plonk::circuit_builder::CircuitBuilder;
@ -216,11 +216,8 @@ where
let mut res = builder.zero_extension(); let mut res = builder.zero_extension();
for i in 0..WIDTH { for i in 0..WIDTH {
res = builder.mul_const_add_extension( let c = Self::from_canonical_u64(1 << <Self as Poseidon<WIDTH>>::MDS_MATRIX_EXPS[i]);
Self::from_canonical_u64(1 << <Self as Poseidon<WIDTH>>::MDS_MATRIX_EXPS[i]), res = builder.mul_const_add_extension(c, v[(i + r) % WIDTH], res);
v[(i + r) % WIDTH],
res,
);
} }
res res
@ -319,12 +316,10 @@ where
Self: RichField + Extendable<D>, Self: RichField + Extendable<D>,
{ {
for i in 0..WIDTH { for i in 0..WIDTH {
state[i] = builder.add_const_extension( let c = <Self as Poseidon<WIDTH>>::FAST_PARTIAL_FIRST_ROUND_CONSTANT[i];
state[i], let c = Self::Extension::from_canonical_u64(c);
Self::from_canonical_u64( let c = builder.constant_extension(c);
<Self as Poseidon<WIDTH>>::FAST_PARTIAL_FIRST_ROUND_CONSTANT[i], state[i] = builder.add_extension(state[i], c);
),
);
} }
} }
@ -374,10 +369,10 @@ where
for r in 1..WIDTH { for r in 1..WIDTH {
for c in 1..WIDTH { for c in 1..WIDTH {
let t = Self::from_canonical_u64( let t = <Self as Poseidon<WIDTH>>::FAST_PARTIAL_ROUND_INITIAL_MATRIX[r - 1][c - 1];
<Self as Poseidon<WIDTH>>::FAST_PARTIAL_ROUND_INITIAL_MATRIX[r - 1][c - 1], let t = Self::Extension::from_canonical_u64(t);
); let t = builder.constant_extension(t);
result[c] = builder.mul_const_add_extension(t, state[r], result[c]); result[c] = builder.mul_add_extension(t, state[r], result[c]);
} }
} }
result result
@ -459,19 +454,18 @@ where
s0, s0,
); );
for i in 1..WIDTH { for i in 1..WIDTH {
let t = Self::from_canonical_u64( let t = <Self as Poseidon<WIDTH>>::FAST_PARTIAL_ROUND_W_HATS[r][i - 1];
<Self as Poseidon<WIDTH>>::FAST_PARTIAL_ROUND_W_HATS[r][i - 1], let t = Self::from_canonical_u64(t);
);
d = builder.mul_const_add_extension(t, state[i], d); d = builder.mul_const_add_extension(t, state[i], d);
} }
let mut result = [builder.zero_extension(); WIDTH]; let mut result = [builder.zero_extension(); WIDTH];
result[0] = d; result[0] = d;
for i in 1..WIDTH { for i in 1..WIDTH {
let t = Self::from_canonical_u64( let t = <Self as Poseidon<WIDTH>>::FAST_PARTIAL_ROUND_VS[r][i - 1];
<Self as Poseidon<WIDTH>>::FAST_PARTIAL_ROUND_VS[r][i - 1], let t = Self::Extension::from_canonical_u64(t);
); let t = builder.constant_extension(t);
result[i] = builder.mul_const_add_extension(t, state[0], state[i]); result[i] = builder.mul_add_extension(t, state[0], state[i]);
} }
result result
} }
@ -509,10 +503,10 @@ where
Self: RichField + Extendable<D>, Self: RichField + Extendable<D>,
{ {
for i in 0..WIDTH { for i in 0..WIDTH {
state[i] = builder.add_const_extension( let c = ALL_ROUND_CONSTANTS[i + WIDTH * round_ctr];
state[i], let c = Self::Extension::from_canonical_u64(c);
Self::from_canonical_u64(ALL_ROUND_CONSTANTS[i + WIDTH * round_ctr]), let c = builder.constant_extension(c);
); state[i] = builder.add_extension(state[i], c);
} }
} }