From c8e043a53ff57a03b2ce6afc9cda55c421f79bcd Mon Sep 17 00:00:00 2001 From: Daniel Lubarov Date: Tue, 2 Nov 2021 14:42:30 -0700 Subject: [PATCH] Optimize recursive Poseidon constraint evaluation (#333) * More wires for ConstantGate * fix * fix * Optimize recursive Poseidon constraint evaluation - Avoid `ArithmeticGate`s with unique constants; use `ConstantGate` wires instead - Avoid an unnecessary squaring in exponentiations Brings Poseidon evaluation down to a reasonable 273 gates when `num_routed_wires = 48`. --- src/gadgets/arithmetic_extension.rs | 4 ++- src/hash/poseidon.rs | 48 +++++++++++++---------------- 2 files changed, 24 insertions(+), 28 deletions(-) diff --git a/src/gadgets/arithmetic_extension.rs b/src/gadgets/arithmetic_extension.rs index 3f35663a..f352cdf9 100644 --- a/src/gadgets/arithmetic_extension.rs +++ b/src/gadgets/arithmetic_extension.rs @@ -412,10 +412,12 @@ impl, const D: usize> CircuitBuilder { let mut product = self.one_extension(); for j in 0..bits_u64(exponent as u64) { + if j != 0 { + current = self.square_extension(current); + } if (exponent >> j & 1) != 0 { product = self.mul_extension(product, current); } - current = self.square_extension(current); } product } diff --git a/src/hash/poseidon.rs b/src/hash/poseidon.rs index 616d2a9f..9a52060c 100644 --- a/src/hash/poseidon.rs +++ b/src/hash/poseidon.rs @@ -7,7 +7,7 @@ use unroll::unroll_for_loops; use crate::field::extension_field::target::ExtensionTarget; use crate::field::extension_field::{Extendable, FieldExtension}; -use crate::field::field_types::{PrimeField, RichField}; +use crate::field::field_types::{Field, PrimeField, RichField}; use crate::gates::gate::Gate; use crate::gates::poseidon_mds::PoseidonMdsGate; use crate::plonk::circuit_builder::CircuitBuilder; @@ -216,11 +216,8 @@ where let mut res = builder.zero_extension(); for i in 0..WIDTH { - res = builder.mul_const_add_extension( - Self::from_canonical_u64(1 << >::MDS_MATRIX_EXPS[i]), - v[(i + r) % WIDTH], - res, - ); + let c = Self::from_canonical_u64(1 << >::MDS_MATRIX_EXPS[i]); + res = builder.mul_const_add_extension(c, v[(i + r) % WIDTH], res); } res @@ -319,12 +316,10 @@ where Self: RichField + Extendable, { for i in 0..WIDTH { - state[i] = builder.add_const_extension( - state[i], - Self::from_canonical_u64( - >::FAST_PARTIAL_FIRST_ROUND_CONSTANT[i], - ), - ); + let c = >::FAST_PARTIAL_FIRST_ROUND_CONSTANT[i]; + let c = Self::Extension::from_canonical_u64(c); + let c = builder.constant_extension(c); + state[i] = builder.add_extension(state[i], c); } } @@ -374,10 +369,10 @@ where for r in 1..WIDTH { for c in 1..WIDTH { - let t = Self::from_canonical_u64( - >::FAST_PARTIAL_ROUND_INITIAL_MATRIX[r - 1][c - 1], - ); - result[c] = builder.mul_const_add_extension(t, state[r], result[c]); + let t = >::FAST_PARTIAL_ROUND_INITIAL_MATRIX[r - 1][c - 1]; + let t = Self::Extension::from_canonical_u64(t); + let t = builder.constant_extension(t); + result[c] = builder.mul_add_extension(t, state[r], result[c]); } } result @@ -459,19 +454,18 @@ where s0, ); for i in 1..WIDTH { - let t = Self::from_canonical_u64( - >::FAST_PARTIAL_ROUND_W_HATS[r][i - 1], - ); + let t = >::FAST_PARTIAL_ROUND_W_HATS[r][i - 1]; + let t = Self::from_canonical_u64(t); d = builder.mul_const_add_extension(t, state[i], d); } let mut result = [builder.zero_extension(); WIDTH]; result[0] = d; for i in 1..WIDTH { - let t = Self::from_canonical_u64( - >::FAST_PARTIAL_ROUND_VS[r][i - 1], - ); - result[i] = builder.mul_const_add_extension(t, state[0], state[i]); + let t = >::FAST_PARTIAL_ROUND_VS[r][i - 1]; + let t = Self::Extension::from_canonical_u64(t); + let t = builder.constant_extension(t); + result[i] = builder.mul_add_extension(t, state[0], state[i]); } result } @@ -509,10 +503,10 @@ where Self: RichField + Extendable, { for i in 0..WIDTH { - state[i] = builder.add_const_extension( - state[i], - Self::from_canonical_u64(ALL_ROUND_CONSTANTS[i + WIDTH * round_ctr]), - ); + let c = ALL_ROUND_CONSTANTS[i + WIDTH * round_ctr]; + let c = Self::Extension::from_canonical_u64(c); + let c = builder.constant_extension(c); + state[i] = builder.add_extension(state[i], c); } }