improve the Poseidon API

(thanks goes to Chrysostomos Nanakos for the help!)
This commit is contained in:
Balazs Komuves 2026-01-29 14:56:08 +01:00
parent 4740fa3d88
commit e596c5b16b
No known key found for this signature in database
GPG Key ID: F63B7AEF18435562
7 changed files with 113 additions and 234 deletions

View File

@ -3,7 +3,6 @@ use criterion::{criterion_group, criterion_main, Criterion};
use std::hint::{black_box};
use rust_poseidon_bn254_pure::bn254::field::*;
use rust_poseidon_bn254_pure::bn254::montgomery::{Mont};
use rust_poseidon_bn254_pure::poseidon;
use rust_poseidon_bn254_pure::poseidon2;
@ -18,17 +17,10 @@ fn initial_triple() -> Triple {
]
}
pub fn poseidon1_permute_felt(input: Triple) -> Triple {
let mut state: [Mont; 3] = Felt::to_mont_vec(input);
state = poseidon::permutation::permute_mont_T3(state);
let out: Triple = Felt::from_mont_vec(state);
out
}
fn iterate_poseidon1(n: usize) -> Triple {
let mut state: Triple = initial_triple();
for _i in 0..n {
state = poseidon1_permute_felt(state);
state = poseidon::permute::<3>(state);
}
state
}
@ -36,7 +28,7 @@ fn iterate_poseidon1(n: usize) -> Triple {
fn iterate_poseidon2(n: usize) -> Triple {
let mut state: Triple = initial_triple();
for _i in 0..n {
state = poseidon2::permutation::permute_felt(state);
state = poseidon2::permute(state);
}
state
}

View File

@ -8,9 +8,8 @@ use rust_poseidon_bn254_pure::bn254::constant::*;
use rust_poseidon_bn254_pure::bn254::montgomery::*;
use rust_poseidon_bn254_pure::bn254::field::*;
use rust_poseidon_bn254_pure::poseidon2::permutation::*;
use rust_poseidon_bn254_pure::poseidon::permutation::*;
use rust_poseidon_bn254_pure::poseidon;
use rust_poseidon_bn254_pure::poseidon2;
//------------------------------------------------------------------------------
@ -111,7 +110,7 @@ fn main() {
println!("");
let input = [ Felt::from_u32(0) , Felt::from_u32(1) , Felt::from_u32(2) ];
let output = permute_felt( input );
let output = poseidon2::permute( input );
println!("x = {}", input[0] );
println!("y = {}", input[1] );
@ -136,7 +135,7 @@ fn main() {
let now = Instant::now();
let mut state: [Felt; 3] = input.clone();
for _i in 0..10000 {
state = permute_felt(state);
state = poseidon2::permute(state);
}
// expected output:
@ -205,20 +204,16 @@ fn main() {
// compress3 = 6542985608222806190361240322586112750744169038454362455181422643027100751666
// compress4 = 18821383157269793795438455681495246036402687001665670618754263018637548127333
let in1: Felt = Felt::from_u32(1);
let out1 = compress_1(in1);
let out1 = poseidon::hash1( Felt::from_u32(1) );
println!("compress(1) = {}", Felt::to_decimal_string(out1) );
let in2: [Felt; 2] = [ Felt::from_u32(1) , Felt::from_u32(2) ];
let out2 = compress_2(in2);
let out2 = poseidon::hash2( Felt::from_u32(1) , Felt::from_u32(2) );
println!("compress(2) = {}", Felt::to_decimal_string(out2) );
let in3: [Felt; 3] = [ Felt::from_u32(1) , Felt::from_u32(2) , Felt::from_u32(3) ];
let out3 = compress_3(in3);
let out3 = poseidon::hash3( Felt::from_u32(1) , Felt::from_u32(2) , Felt::from_u32(3) );
println!("compress(3) = {}", Felt::to_decimal_string(out3) );
let in4: [Felt; 4] = [ Felt::from_u32(1) , Felt::from_u32(2) , Felt::from_u32(3) , Felt::from_u32(4) ];
let out4 = compress_4(in4);
let out4 = poseidon::hash4( Felt::from_u32(1) , Felt::from_u32(2) , Felt::from_u32(3) , Felt::from_u32(4) );
println!("compress(4) = {}", Felt::to_decimal_string(out4) );
}

View File

@ -1,6 +1,6 @@
use rust_poseidon_bn254_pure::bn254::field::*;
use rust_poseidon_bn254_pure::poseidon2::permutation::*;
use rust_poseidon_bn254_pure::poseidon2;
fn main() {
@ -14,7 +14,7 @@ fn main() {
let mut state: [Felt; 3] = input.clone();
for _i in 0..20 {
state = permute_felt(state);
state = poseidon2::permute(state);
}
println!("x' = {}", state[0] );
println!("y' = {}", state[1] );

View File

@ -2,3 +2,12 @@
pub mod constants;
pub mod permutation;
pub use permutation::hash1;
pub use permutation::hash2;
pub use permutation::hash3;
pub use permutation::hash4;
pub use permutation::compress;
pub use permutation::permute;
pub use permutation::permute_mont;

View File

@ -1,10 +1,9 @@
#![allow(unused)]
//
// circomlib-compatible Poseidon (v1) implementation
//
#![allow(unused)]
#![allow(dead_code)]
#![allow(non_snake_case)]
@ -18,31 +17,36 @@ use crate::poseidon::constants::t5;
//------------------------------------------------------------------------------
// width of the permutation state
#[derive(Copy, Clone)]
#[repr(usize)]
pub enum Width {
T2 = 2,
T3 = 3,
T4 = 4,
T5 = 5,
}
//------------------------------------------------------------------------------
// number of internal rounds for `t = 2..17`
const INTERNAL_ROUND_COUNT: [usize; 16] = [56, 57, 56, 60, 60, 63, 64, 63, 60, 66, 60, 65, 70, 60, 64, 68];
const fn internal_round_count(T: usize) -> usize {
let k = T - 2;
if k < 16 {
INTERNAL_ROUND_COUNT[ k - 2 ]
}
else {
0
}
pub trait PoseidonParams<const T: usize> {
const NP: usize;
fn const_C() -> &'static [Mont];
fn const_M() -> &'static [Mont];
fn const_P() -> &'static [Mont];
fn const_S() -> &'static [Mont];
}
pub struct Params;
macro_rules! impl_params {
($T:literal, $mod:ident) => {
impl PoseidonParams<$T> for Params {
const NP: usize = INTERNAL_ROUND_COUNT[$T - 2];
fn const_C() -> &'static [Mont] { &$mod::CONST_C }
fn const_M() -> &'static [Mont] { &$mod::CONST_M }
fn const_P() -> &'static [Mont] { &$mod::CONST_P }
fn const_S() -> &'static [Mont] { &$mod::CONST_S }
}
};
}
impl_params!(2, t2);
impl_params!(3, t3);
impl_params!(4, t4);
impl_params!(5, t5);
//------------------------------------------------------------------------------
#[inline(always)]
@ -52,7 +56,7 @@ fn sbox(x: Mont) -> Mont {
Mont::mul(x,x4)
}
fn matrix_mul<const T: usize>(input: [Mont; T], mtx: [Mont; T*T]) -> [Mont; T] {
fn matrix_mul<const T: usize>(input: [Mont; T], mtx: &[Mont]) -> [Mont; T] {
let mut out: [Mont; T] = [Mont::zero(); T];
for i in 0..T {
let mut acc: Mont = Mont::zero();
@ -77,13 +81,13 @@ fn mix_S<const T: usize>(input: [Mont; T], scoeffs: &[Mont]) -> [Mont; T] {
out
}
fn internal_round<const T: usize>(rc: Mont, scoeffs: &[Mont], input: [Mont; T]) -> [Mont; T] {
fn internal_round<const T: usize>(input: [Mont; T], rc: Mont, scoeffs: &[Mont]) -> [Mont; T] {
let mut xs: [Mont; T] = input;
xs[0] = Mont::add( sbox( xs[0] ) , rc );
mix_S::<T>(xs, scoeffs)
}
fn external_round<const T: usize>(rcs: &[Mont], input: [Mont; T], mtx: [Mont; T*T]) -> [Mont; T] {
fn external_round<const T: usize>(input: [Mont; T], rcs: &[Mont], mtx: &[Mont]) -> [Mont; T] {
let mut xs: [Mont; T] = [Mont::zero(); T];
for j in 0..T {
xs[j] = Mont::add( sbox( input[j] ) , rcs[j] );
@ -92,30 +96,15 @@ fn external_round<const T: usize>(rcs: &[Mont], input: [Mont; T], mtx: [Mont; T*
}
//------------------------------------------------------------------------------
// TODO: can we somehow unify the different T cases????
/*
// debugging
fn printRound(text: &str, round: usize, state: &[Mont]) {
println!("{} {} -> ", text, round);
for x in state {
println!(" {}", Mont::to_decimal_string(x) );
}
}
*/
pub fn permute_mont<const T: usize>(input: [Mont; T]) -> [Mont; T] where Params: PoseidonParams<T> {
//--------------------------------------
// T = 2
pub fn permute_mont_T2(input: [Mont; 2]) -> [Mont; 2] {
const T: usize = 2;
const TT: usize = 2*T-1;
const NP: usize = INTERNAL_ROUND_COUNT[T-2];
const C: [Mont; 72] = t2::CONST_C;
const M: [Mont; 4] = t2::CONST_M;
const P: [Mont; 4] = t2::CONST_P;
const S: [Mont; 168] = t2::CONST_S;
let TT = 2*T - 1;
let NP = <Params as PoseidonParams<T>>::NP;
let C = <Params as PoseidonParams<T>>::const_C();
let M = <Params as PoseidonParams<T>>::const_M();
let P = <Params as PoseidonParams<T>>::const_P();
let S = <Params as PoseidonParams<T>>::const_S();
let mut state: [Mont; T] = input;
for j in 0..T {
@ -124,173 +113,58 @@ pub fn permute_mont_T2(input: [Mont; 2]) -> [Mont; 2] {
for i in 0..4 {
let rcs: &[Mont] = &C[ ((i+1)*T) .. ((i+2)*T) ];
let mat = if i<3 { M } else { P };
state = external_round::<T>( rcs , state , mat );
// printRound("initial round", i, &state);
state = external_round::<T>( state , rcs , mat );
}
for i in 0..NP {
let rc: Mont = C[ i + 5*T ];
let scoeffs: &[Mont] = &S[ (i*TT) .. ((i+1)*TT) ];
state = internal_round::<T>( rc , scoeffs , state );
// printRound("internal round", i, &state);
state = internal_round::<T>( state , rc , scoeffs );
}
for i in 4..8 {
let rcs: &[Mont] = if i<7 { &C[ (NP + (i+1)*T) .. (NP + (i+2)*T) ] } else { &[Mont::zero(); T] };
state = external_round::<T>( rcs , state , M );
// printRound("final round", i, &state);
state = external_round::<T>( state , rcs , M );
}
state
}
//--------------------------------------
pub fn permute_mont_T3(input: [Mont; 3]) -> [Mont; 3] {
const T: usize = 3;
const TT: usize = 2*T-1;
const NP: usize = INTERNAL_ROUND_COUNT[T-2];
const C: [Mont; 81] = t3::CONST_C;
const M: [Mont; 9] = t3::CONST_M;
const P: [Mont; 9] = t3::CONST_P;
const S: [Mont; 285] = t3::CONST_S;
let mut state: [Mont; T] = input;
for j in 0..T {
state[j] = Mont::add( state[j] , C[j] );
}
for i in 0..4 {
let rcs: &[Mont] = &C[ ((i+1)*T) .. ((i+2)*T) ];
let mat = if i<3 { M } else { P };
state = external_round::<T>( rcs , state , mat );
// printRound("initial round", i, &state);
}
for i in 0..NP {
let rc: Mont = C[ i + 5*T ];
let scoeffs: &[Mont] = &S[ (i*TT) .. ((i+1)*TT) ];
state = internal_round::<T>( rc , scoeffs , state );
// printRound("internal round", i, &state);
}
for i in 4..8 {
let rcs: &[Mont] = if i<7 { &C[ (NP + (i+1)*T) .. (NP + (i+2)*T) ] } else { &[Mont::zero(); T] };
state = external_round::<T>( rcs , state , M );
// printRound("final round", i, &state);
}
state
}
//--------------------------------------
pub fn permute_mont_T4(input: [Mont; 4]) -> [Mont; 4] {
const T: usize = 4;
const TT: usize = 2*T-1;
const NP: usize = INTERNAL_ROUND_COUNT[T-2];
const C: [Mont; 88] = t4::CONST_C;
const M: [Mont; 16] = t4::CONST_M;
const P: [Mont; 16] = t4::CONST_P;
const S: [Mont; 392] = t4::CONST_S;
let mut state: [Mont; T] = input;
for j in 0..T {
state[j] = Mont::add( state[j] , C[j] );
}
for i in 0..4 {
let rcs: &[Mont] = &C[ ((i+1)*T) .. ((i+2)*T) ];
let mat = if i<3 { M } else { P };
state = external_round::<T>( rcs , state , mat );
// printRound("initial round", i, &state);
}
for i in 0..NP {
let rc: Mont = C[ i + 5*T ];
let scoeffs: &[Mont] = &S[ (i*TT) .. ((i+1)*TT) ];
state = internal_round::<T>( rc , scoeffs , state );
// printRound("internal round", i, &state);
}
for i in 4..8 {
let rcs: &[Mont] = if i<7 { &C[ (NP + (i+1)*T) .. (NP + (i+2)*T) ] } else { &[Mont::zero(); T] };
state = external_round::<T>( rcs , state , M );
// printRound("final round", i, &state);
}
state
}
//--------------------------------------
pub fn permute_mont_T5(input: [Mont; 5]) -> [Mont; 5] {
const T: usize = 5;
const TT: usize = 2*T-1;
const NP: usize = INTERNAL_ROUND_COUNT[T-2];
const C: [Mont; 100] = t5::CONST_C;
const M: [Mont; 25] = t5::CONST_M;
const P: [Mont; 25] = t5::CONST_P;
const S: [Mont; 540] = t5::CONST_S;
let mut state: [Mont; T] = input;
for j in 0..T {
state[j] = Mont::add( state[j] , C[j] );
}
for i in 0..4 {
let rcs: &[Mont] = &C[ ((i+1)*T) .. ((i+2)*T) ];
let mat = if i<3 { M } else { P };
state = external_round::<T>( rcs , state , mat );
// printRound("initial round", i, &state);
}
for i in 0..NP {
let rc: Mont = C[ i + 5*T ];
let scoeffs: &[Mont] = &S[ (i*TT) .. ((i+1)*TT) ];
state = internal_round::<T>( rc , scoeffs , state );
// printRound("internal round", i, &state);
}
for i in 4..8 {
let rcs: &[Mont] = if i<7 { &C[ (NP + (i+1)*T) .. (NP + (i+2)*T) ] } else { &[Mont::zero(); T] };
state = external_round::<T>( rcs , state , M );
// printRound("final round", i, &state);
}
state
pub fn compress_mont<const K: usize>(input: [Mont; K]) -> Mont where Params: PoseidonParams<{K+1}> {
let mut state: [Mont; K+1] = [Mont::zero(); K+1];
for i in 0..K { state[i+1] = input[i]; }
state = permute_mont::<{K+1}>(state);
state[0]
}
//------------------------------------------------------------------------------
pub fn compress_1(input: Felt) -> Felt {
let mut state: [Mont; 2] =
[ Mont::zero()
, Felt::to_mont(input)
];
state = permute_mont_T2(state);
Felt::from_mont(state[0])
pub fn permute<const T: usize>(input: [Felt; T]) -> [Felt; T] where Params: PoseidonParams<T> {
let state: [Mont; T] = Felt::to_mont_vec(input);
let output = permute_mont::<T>(state);
Felt::from_mont_vec(output)
}
pub fn compress_2(input: [Felt;2]) -> Felt {
let mut state: [Mont; 3] =
[ Mont::zero()
, Felt::to_mont(input[0])
, Felt::to_mont(input[1])
];
state = permute_mont_T3(state);
Felt::from_mont(state[0])
}
pub fn compress_3(input: [Felt;3]) -> Felt {
let mut state: [Mont; 4] =
[ Mont::zero()
, Felt::to_mont(input[0])
, Felt::to_mont(input[1])
, Felt::to_mont(input[2])
];
state = permute_mont_T4(state);
Felt::from_mont(state[0])
}
pub fn compress_4(input: [Felt;4]) -> Felt {
let mut state: [Mont; 5] =
[ Mont::zero()
, Felt::to_mont(input[0])
, Felt::to_mont(input[1])
, Felt::to_mont(input[2])
, Felt::to_mont(input[3])
];
state = permute_mont_T5(state);
pub fn compress<const K: usize>(input: [Felt; K]) -> Felt where Params: PoseidonParams<{K+1}> {
let mut state: [Mont; K+1] = [Mont::zero(); K+1];
for i in 0..K { state[i+1] = Felt::to_mont(input[i]); }
state = permute_mont::<{K+1}>(state);
Felt::from_mont(state[0])
}
//------------------------------------------------------------------------------
pub fn hash1(a: Felt) -> Felt {
compress::<1>([ a ])
}
pub fn hash2(a: Felt, b: Felt) -> Felt {
compress::<2>([ a, b ])
}
pub fn hash3(a: Felt, b: Felt, c: Felt) -> Felt {
compress::<3>([ a, b, c ])
}
pub fn hash4(a: Felt, b: Felt, c: Felt, d: Felt) -> Felt {
compress::<4>([ a, b, c, d ])
}
//------------------------------------------------------------------------------

View File

@ -2,3 +2,7 @@
pub mod constants;
pub mod permutation;
pub use permutation::compress;
pub use permutation::permute;
pub use permutation::permute_mont;

View File

@ -31,7 +31,7 @@ fn linear(input: MontTriple) -> MontTriple {
]
}
fn internal_round(rc: Mont, input: MontTriple) -> MontTriple {
fn internal_round(input: MontTriple, rc: Mont) -> MontTriple {
let x = sbox( Mont::add( input[0] , rc ) );
let s = add3( x , input[1] , input[2] );
[ Mont::add( s , x )
@ -40,7 +40,7 @@ fn internal_round(rc: Mont, input: MontTriple) -> MontTriple {
]
}
fn external_round(rcs: MontTriple, input: MontTriple) -> MontTriple {
fn external_round(input: MontTriple, rcs: MontTriple) -> MontTriple {
let x = sbox( Mont::add( input[0] , rcs[0] ) );
let y = sbox( Mont::add( input[1] , rcs[1] ) );
let z = sbox( Mont::add( input[2] , rcs[2] ) );
@ -53,29 +53,34 @@ fn external_round(rcs: MontTriple, input: MontTriple) -> MontTriple {
pub fn permute_mont(input: MontTriple) -> MontTriple {
let mut state = linear(input);
for i in 0..4 { state = external_round( get_initial_RCs(i) , state ); }
for i in 0..56 { state = internal_round( INTERNAL_MONT [i] , state ); }
for i in 0..4 { state = external_round( get_final_RCs (i) , state ); }
for i in 0..4 { state = external_round( state , get_initial_RCs(i) ); }
for i in 0..56 { state = internal_round( state , INTERNAL_MONT [i] ); }
for i in 0..4 { state = external_round( state , get_final_RCs (i) ); }
state
}
//------------------------------------------------------------------------------
pub fn permute_felt_iterated(input: FeltTriple, count: usize) -> FeltTriple {
pub fn compress(input: [Felt; 2]) -> Felt {
let mut state: [Mont; 3] = [Mont::zero(); 3];
for i in 0..2 { state[i] = Felt::to_mont(input[i]); }
state = permute_mont(state);
Felt::from_mont(state[0])
}
pub fn permute(input: [Felt; 3]) -> [Felt; 3] {
let state: MontTriple = Felt::to_mont_vec(input);
let output = permute_mont(state);
Felt::from_mont_vec(output)
}
pub fn permute_iterated(input: [Felt; 3], count: usize) -> [Felt; 3] {
let mut state: MontTriple = Felt::to_mont_vec(input);
for _i in 0..count {
state = permute_mont(state);
}
let out: FeltTriple = Felt::from_mont_vec(state);
out
}
pub fn permute_felt(input: FeltTriple) -> FeltTriple {
permute_felt_iterated(input, 1)
}
//------------------------------------------------------------------------------