From e596c5b16bb01415751445135d92b5b39448336e Mon Sep 17 00:00:00 2001 From: Balazs Komuves Date: Thu, 29 Jan 2026 14:56:08 +0100 Subject: [PATCH] improve the Poseidon API (thanks goes to Chrysostomos Nanakos for the help!) --- benches/iterated_perm.rs | 12 +- src/bin/testmain.rs | 21 ++- src/bin/twenty.rs | 4 +- src/poseidon/mod.rs | 9 ++ src/poseidon/permutation.rs | 266 +++++++++-------------------------- src/poseidon2/mod.rs | 4 + src/poseidon2/permutation.rs | 31 ++-- 7 files changed, 113 insertions(+), 234 deletions(-) diff --git a/benches/iterated_perm.rs b/benches/iterated_perm.rs index 36c1908..6290f3c 100644 --- a/benches/iterated_perm.rs +++ b/benches/iterated_perm.rs @@ -3,7 +3,6 @@ use criterion::{criterion_group, criterion_main, Criterion}; use std::hint::{black_box}; use rust_poseidon_bn254_pure::bn254::field::*; -use rust_poseidon_bn254_pure::bn254::montgomery::{Mont}; use rust_poseidon_bn254_pure::poseidon; use rust_poseidon_bn254_pure::poseidon2; @@ -18,17 +17,10 @@ fn initial_triple() -> Triple { ] } -pub fn poseidon1_permute_felt(input: Triple) -> Triple { - let mut state: [Mont; 3] = Felt::to_mont_vec(input); - state = poseidon::permutation::permute_mont_T3(state); - let out: Triple = Felt::from_mont_vec(state); - out -} - fn iterate_poseidon1(n: usize) -> Triple { let mut state: Triple = initial_triple(); for _i in 0..n { - state = poseidon1_permute_felt(state); + state = poseidon::permute::<3>(state); } state } @@ -36,7 +28,7 @@ fn iterate_poseidon1(n: usize) -> Triple { fn iterate_poseidon2(n: usize) -> Triple { let mut state: Triple = initial_triple(); for _i in 0..n { - state = poseidon2::permutation::permute_felt(state); + state = poseidon2::permute(state); } state } diff --git a/src/bin/testmain.rs b/src/bin/testmain.rs index 6d8aef9..7180580 100644 --- a/src/bin/testmain.rs +++ b/src/bin/testmain.rs @@ -8,9 +8,8 @@ use rust_poseidon_bn254_pure::bn254::constant::*; use rust_poseidon_bn254_pure::bn254::montgomery::*; use rust_poseidon_bn254_pure::bn254::field::*; -use rust_poseidon_bn254_pure::poseidon2::permutation::*; - -use rust_poseidon_bn254_pure::poseidon::permutation::*; +use rust_poseidon_bn254_pure::poseidon; +use rust_poseidon_bn254_pure::poseidon2; //------------------------------------------------------------------------------ @@ -111,7 +110,7 @@ fn main() { println!(""); let input = [ Felt::from_u32(0) , Felt::from_u32(1) , Felt::from_u32(2) ]; - let output = permute_felt( input ); + let output = poseidon2::permute( input ); println!("x = {}", input[0] ); println!("y = {}", input[1] ); @@ -136,7 +135,7 @@ fn main() { let now = Instant::now(); let mut state: [Felt; 3] = input.clone(); for _i in 0..10000 { - state = permute_felt(state); + state = poseidon2::permute(state); } // expected output: @@ -205,20 +204,16 @@ fn main() { // compress3 = 6542985608222806190361240322586112750744169038454362455181422643027100751666 // compress4 = 18821383157269793795438455681495246036402687001665670618754263018637548127333 - let in1: Felt = Felt::from_u32(1); - let out1 = compress_1(in1); + let out1 = poseidon::hash1( Felt::from_u32(1) ); println!("compress(1) = {}", Felt::to_decimal_string(out1) ); - let in2: [Felt; 2] = [ Felt::from_u32(1) , Felt::from_u32(2) ]; - let out2 = compress_2(in2); + let out2 = poseidon::hash2( Felt::from_u32(1) , Felt::from_u32(2) ); println!("compress(2) = {}", Felt::to_decimal_string(out2) ); - let in3: [Felt; 3] = [ Felt::from_u32(1) , Felt::from_u32(2) , Felt::from_u32(3) ]; - let out3 = compress_3(in3); + let out3 = poseidon::hash3( Felt::from_u32(1) , Felt::from_u32(2) , Felt::from_u32(3) ); println!("compress(3) = {}", Felt::to_decimal_string(out3) ); - let in4: [Felt; 4] = [ Felt::from_u32(1) , Felt::from_u32(2) , Felt::from_u32(3) , Felt::from_u32(4) ]; - let out4 = compress_4(in4); + let out4 = poseidon::hash4( Felt::from_u32(1) , Felt::from_u32(2) , Felt::from_u32(3) , Felt::from_u32(4) ); println!("compress(4) = {}", Felt::to_decimal_string(out4) ); } diff --git a/src/bin/twenty.rs b/src/bin/twenty.rs index a7b9273..04a4898 100644 --- a/src/bin/twenty.rs +++ b/src/bin/twenty.rs @@ -1,6 +1,6 @@ use rust_poseidon_bn254_pure::bn254::field::*; -use rust_poseidon_bn254_pure::poseidon2::permutation::*; +use rust_poseidon_bn254_pure::poseidon2; fn main() { @@ -14,7 +14,7 @@ fn main() { let mut state: [Felt; 3] = input.clone(); for _i in 0..20 { - state = permute_felt(state); + state = poseidon2::permute(state); } println!("x' = {}", state[0] ); println!("y' = {}", state[1] ); diff --git a/src/poseidon/mod.rs b/src/poseidon/mod.rs index 925bc29..9715253 100644 --- a/src/poseidon/mod.rs +++ b/src/poseidon/mod.rs @@ -2,3 +2,12 @@ pub mod constants; pub mod permutation; +pub use permutation::hash1; +pub use permutation::hash2; +pub use permutation::hash3; +pub use permutation::hash4; + +pub use permutation::compress; + +pub use permutation::permute; +pub use permutation::permute_mont; \ No newline at end of file diff --git a/src/poseidon/permutation.rs b/src/poseidon/permutation.rs index 25e0f57..8d0ee9d 100644 --- a/src/poseidon/permutation.rs +++ b/src/poseidon/permutation.rs @@ -1,10 +1,9 @@ -#![allow(unused)] - // // circomlib-compatible Poseidon (v1) implementation // +#![allow(unused)] #![allow(dead_code)] #![allow(non_snake_case)] @@ -18,31 +17,36 @@ use crate::poseidon::constants::t5; //------------------------------------------------------------------------------ -// width of the permutation state -#[derive(Copy, Clone)] -#[repr(usize)] -pub enum Width { - T2 = 2, - T3 = 3, - T4 = 4, - T5 = 5, -} - -//------------------------------------------------------------------------------ - // number of internal rounds for `t = 2..17` const INTERNAL_ROUND_COUNT: [usize; 16] = [56, 57, 56, 60, 60, 63, 64, 63, 60, 66, 60, 65, 70, 60, 64, 68]; -const fn internal_round_count(T: usize) -> usize { - let k = T - 2; - if k < 16 { - INTERNAL_ROUND_COUNT[ k - 2 ] - } - else { - 0 - } +pub trait PoseidonParams { + const NP: usize; + fn const_C() -> &'static [Mont]; + fn const_M() -> &'static [Mont]; + fn const_P() -> &'static [Mont]; + fn const_S() -> &'static [Mont]; } +pub struct Params; + +macro_rules! impl_params { + ($T:literal, $mod:ident) => { + impl PoseidonParams<$T> for Params { + const NP: usize = INTERNAL_ROUND_COUNT[$T - 2]; + fn const_C() -> &'static [Mont] { &$mod::CONST_C } + fn const_M() -> &'static [Mont] { &$mod::CONST_M } + fn const_P() -> &'static [Mont] { &$mod::CONST_P } + fn const_S() -> &'static [Mont] { &$mod::CONST_S } + } + }; +} + +impl_params!(2, t2); +impl_params!(3, t3); +impl_params!(4, t4); +impl_params!(5, t5); + //------------------------------------------------------------------------------ #[inline(always)] @@ -52,7 +56,7 @@ fn sbox(x: Mont) -> Mont { Mont::mul(x,x4) } -fn matrix_mul(input: [Mont; T], mtx: [Mont; T*T]) -> [Mont; T] { +fn matrix_mul(input: [Mont; T], mtx: &[Mont]) -> [Mont; T] { let mut out: [Mont; T] = [Mont::zero(); T]; for i in 0..T { let mut acc: Mont = Mont::zero(); @@ -77,13 +81,13 @@ fn mix_S(input: [Mont; T], scoeffs: &[Mont]) -> [Mont; T] { out } -fn internal_round(rc: Mont, scoeffs: &[Mont], input: [Mont; T]) -> [Mont; T] { +fn internal_round(input: [Mont; T], rc: Mont, scoeffs: &[Mont]) -> [Mont; T] { let mut xs: [Mont; T] = input; xs[0] = Mont::add( sbox( xs[0] ) , rc ); mix_S::(xs, scoeffs) } -fn external_round(rcs: &[Mont], input: [Mont; T], mtx: [Mont; T*T]) -> [Mont; T] { +fn external_round(input: [Mont; T], rcs: &[Mont], mtx: &[Mont]) -> [Mont; T] { let mut xs: [Mont; T] = [Mont::zero(); T]; for j in 0..T { xs[j] = Mont::add( sbox( input[j] ) , rcs[j] ); @@ -92,30 +96,15 @@ fn external_round(rcs: &[Mont], input: [Mont; T], mtx: [Mont; T* } //------------------------------------------------------------------------------ -// TODO: can we somehow unify the different T cases???? -/* -// debugging -fn printRound(text: &str, round: usize, state: &[Mont]) { - println!("{} {} -> ", text, round); - for x in state { - println!(" {}", Mont::to_decimal_string(x) ); - } -} -*/ +pub fn permute_mont(input: [Mont; T]) -> [Mont; T] where Params: PoseidonParams { -//-------------------------------------- -// T = 2 - -pub fn permute_mont_T2(input: [Mont; 2]) -> [Mont; 2] { - const T: usize = 2; - - const TT: usize = 2*T-1; - const NP: usize = INTERNAL_ROUND_COUNT[T-2]; - const C: [Mont; 72] = t2::CONST_C; - const M: [Mont; 4] = t2::CONST_M; - const P: [Mont; 4] = t2::CONST_P; - const S: [Mont; 168] = t2::CONST_S; + let TT = 2*T - 1; + let NP = >::NP; + let C = >::const_C(); + let M = >::const_M(); + let P = >::const_P(); + let S = >::const_S(); let mut state: [Mont; T] = input; for j in 0..T { @@ -124,173 +113,58 @@ pub fn permute_mont_T2(input: [Mont; 2]) -> [Mont; 2] { for i in 0..4 { let rcs: &[Mont] = &C[ ((i+1)*T) .. ((i+2)*T) ]; let mat = if i<3 { M } else { P }; - state = external_round::( rcs , state , mat ); - // printRound("initial round", i, &state); + state = external_round::( state , rcs , mat ); } for i in 0..NP { let rc: Mont = C[ i + 5*T ]; let scoeffs: &[Mont] = &S[ (i*TT) .. ((i+1)*TT) ]; - state = internal_round::( rc , scoeffs , state ); - // printRound("internal round", i, &state); + state = internal_round::( state , rc , scoeffs ); } for i in 4..8 { let rcs: &[Mont] = if i<7 { &C[ (NP + (i+1)*T) .. (NP + (i+2)*T) ] } else { &[Mont::zero(); T] }; - state = external_round::( rcs , state , M ); - // printRound("final round", i, &state); + state = external_round::( state , rcs , M ); } state } -//-------------------------------------- - -pub fn permute_mont_T3(input: [Mont; 3]) -> [Mont; 3] { - const T: usize = 3; - - const TT: usize = 2*T-1; - const NP: usize = INTERNAL_ROUND_COUNT[T-2]; - const C: [Mont; 81] = t3::CONST_C; - const M: [Mont; 9] = t3::CONST_M; - const P: [Mont; 9] = t3::CONST_P; - const S: [Mont; 285] = t3::CONST_S; - - let mut state: [Mont; T] = input; - for j in 0..T { - state[j] = Mont::add( state[j] , C[j] ); - } - for i in 0..4 { - let rcs: &[Mont] = &C[ ((i+1)*T) .. ((i+2)*T) ]; - let mat = if i<3 { M } else { P }; - state = external_round::( rcs , state , mat ); - // printRound("initial round", i, &state); - } - for i in 0..NP { - let rc: Mont = C[ i + 5*T ]; - let scoeffs: &[Mont] = &S[ (i*TT) .. ((i+1)*TT) ]; - state = internal_round::( rc , scoeffs , state ); - // printRound("internal round", i, &state); - } - for i in 4..8 { - let rcs: &[Mont] = if i<7 { &C[ (NP + (i+1)*T) .. (NP + (i+2)*T) ] } else { &[Mont::zero(); T] }; - state = external_round::( rcs , state , M ); - // printRound("final round", i, &state); - } - state -} - -//-------------------------------------- - -pub fn permute_mont_T4(input: [Mont; 4]) -> [Mont; 4] { - const T: usize = 4; - - const TT: usize = 2*T-1; - const NP: usize = INTERNAL_ROUND_COUNT[T-2]; - const C: [Mont; 88] = t4::CONST_C; - const M: [Mont; 16] = t4::CONST_M; - const P: [Mont; 16] = t4::CONST_P; - const S: [Mont; 392] = t4::CONST_S; - - let mut state: [Mont; T] = input; - for j in 0..T { - state[j] = Mont::add( state[j] , C[j] ); - } - for i in 0..4 { - let rcs: &[Mont] = &C[ ((i+1)*T) .. ((i+2)*T) ]; - let mat = if i<3 { M } else { P }; - state = external_round::( rcs , state , mat ); - // printRound("initial round", i, &state); - } - for i in 0..NP { - let rc: Mont = C[ i + 5*T ]; - let scoeffs: &[Mont] = &S[ (i*TT) .. ((i+1)*TT) ]; - state = internal_round::( rc , scoeffs , state ); - // printRound("internal round", i, &state); - } - for i in 4..8 { - let rcs: &[Mont] = if i<7 { &C[ (NP + (i+1)*T) .. (NP + (i+2)*T) ] } else { &[Mont::zero(); T] }; - state = external_round::( rcs , state , M ); - // printRound("final round", i, &state); - } - state -} - -//-------------------------------------- - -pub fn permute_mont_T5(input: [Mont; 5]) -> [Mont; 5] { - const T: usize = 5; - - const TT: usize = 2*T-1; - const NP: usize = INTERNAL_ROUND_COUNT[T-2]; - const C: [Mont; 100] = t5::CONST_C; - const M: [Mont; 25] = t5::CONST_M; - const P: [Mont; 25] = t5::CONST_P; - const S: [Mont; 540] = t5::CONST_S; - - let mut state: [Mont; T] = input; - for j in 0..T { - state[j] = Mont::add( state[j] , C[j] ); - } - for i in 0..4 { - let rcs: &[Mont] = &C[ ((i+1)*T) .. ((i+2)*T) ]; - let mat = if i<3 { M } else { P }; - state = external_round::( rcs , state , mat ); - // printRound("initial round", i, &state); - } - for i in 0..NP { - let rc: Mont = C[ i + 5*T ]; - let scoeffs: &[Mont] = &S[ (i*TT) .. ((i+1)*TT) ]; - state = internal_round::( rc , scoeffs , state ); - // printRound("internal round", i, &state); - } - for i in 4..8 { - let rcs: &[Mont] = if i<7 { &C[ (NP + (i+1)*T) .. (NP + (i+2)*T) ] } else { &[Mont::zero(); T] }; - state = external_round::( rcs , state , M ); - // printRound("final round", i, &state); - } - state +pub fn compress_mont(input: [Mont; K]) -> Mont where Params: PoseidonParams<{K+1}> { + let mut state: [Mont; K+1] = [Mont::zero(); K+1]; + for i in 0..K { state[i+1] = input[i]; } + state = permute_mont::<{K+1}>(state); + state[0] } //------------------------------------------------------------------------------ -pub fn compress_1(input: Felt) -> Felt { - let mut state: [Mont; 2] = - [ Mont::zero() - , Felt::to_mont(input) - ]; - state = permute_mont_T2(state); - Felt::from_mont(state[0]) +pub fn permute(input: [Felt; T]) -> [Felt; T] where Params: PoseidonParams { + let state: [Mont; T] = Felt::to_mont_vec(input); + let output = permute_mont::(state); + Felt::from_mont_vec(output) } -pub fn compress_2(input: [Felt;2]) -> Felt { - let mut state: [Mont; 3] = - [ Mont::zero() - , Felt::to_mont(input[0]) - , Felt::to_mont(input[1]) - ]; - state = permute_mont_T3(state); - Felt::from_mont(state[0]) -} - -pub fn compress_3(input: [Felt;3]) -> Felt { - let mut state: [Mont; 4] = - [ Mont::zero() - , Felt::to_mont(input[0]) - , Felt::to_mont(input[1]) - , Felt::to_mont(input[2]) - ]; - state = permute_mont_T4(state); - Felt::from_mont(state[0]) -} - -pub fn compress_4(input: [Felt;4]) -> Felt { - let mut state: [Mont; 5] = - [ Mont::zero() - , Felt::to_mont(input[0]) - , Felt::to_mont(input[1]) - , Felt::to_mont(input[2]) - , Felt::to_mont(input[3]) - ]; - state = permute_mont_T5(state); +pub fn compress(input: [Felt; K]) -> Felt where Params: PoseidonParams<{K+1}> { + let mut state: [Mont; K+1] = [Mont::zero(); K+1]; + for i in 0..K { state[i+1] = Felt::to_mont(input[i]); } + state = permute_mont::<{K+1}>(state); Felt::from_mont(state[0]) } //------------------------------------------------------------------------------ + +pub fn hash1(a: Felt) -> Felt { + compress::<1>([ a ]) +} + +pub fn hash2(a: Felt, b: Felt) -> Felt { + compress::<2>([ a, b ]) +} + +pub fn hash3(a: Felt, b: Felt, c: Felt) -> Felt { + compress::<3>([ a, b, c ]) +} + +pub fn hash4(a: Felt, b: Felt, c: Felt, d: Felt) -> Felt { + compress::<4>([ a, b, c, d ]) +} + +//------------------------------------------------------------------------------ diff --git a/src/poseidon2/mod.rs b/src/poseidon2/mod.rs index 925bc29..0c41fcc 100644 --- a/src/poseidon2/mod.rs +++ b/src/poseidon2/mod.rs @@ -2,3 +2,7 @@ pub mod constants; pub mod permutation; +pub use permutation::compress; + +pub use permutation::permute; +pub use permutation::permute_mont; \ No newline at end of file diff --git a/src/poseidon2/permutation.rs b/src/poseidon2/permutation.rs index 35ab07f..e8d5e5f 100644 --- a/src/poseidon2/permutation.rs +++ b/src/poseidon2/permutation.rs @@ -31,7 +31,7 @@ fn linear(input: MontTriple) -> MontTriple { ] } -fn internal_round(rc: Mont, input: MontTriple) -> MontTriple { +fn internal_round(input: MontTriple, rc: Mont) -> MontTriple { let x = sbox( Mont::add( input[0] , rc ) ); let s = add3( x , input[1] , input[2] ); [ Mont::add( s , x ) @@ -40,7 +40,7 @@ fn internal_round(rc: Mont, input: MontTriple) -> MontTriple { ] } -fn external_round(rcs: MontTriple, input: MontTriple) -> MontTriple { +fn external_round(input: MontTriple, rcs: MontTriple) -> MontTriple { let x = sbox( Mont::add( input[0] , rcs[0] ) ); let y = sbox( Mont::add( input[1] , rcs[1] ) ); let z = sbox( Mont::add( input[2] , rcs[2] ) ); @@ -53,29 +53,34 @@ fn external_round(rcs: MontTriple, input: MontTriple) -> MontTriple { pub fn permute_mont(input: MontTriple) -> MontTriple { let mut state = linear(input); - for i in 0..4 { state = external_round( get_initial_RCs(i) , state ); } - for i in 0..56 { state = internal_round( INTERNAL_MONT [i] , state ); } - for i in 0..4 { state = external_round( get_final_RCs (i) , state ); } + for i in 0..4 { state = external_round( state , get_initial_RCs(i) ); } + for i in 0..56 { state = internal_round( state , INTERNAL_MONT [i] ); } + for i in 0..4 { state = external_round( state , get_final_RCs (i) ); } state } //------------------------------------------------------------------------------ -pub fn permute_felt_iterated(input: FeltTriple, count: usize) -> FeltTriple { +pub fn compress(input: [Felt; 2]) -> Felt { + let mut state: [Mont; 3] = [Mont::zero(); 3]; + for i in 0..2 { state[i] = Felt::to_mont(input[i]); } + state = permute_mont(state); + Felt::from_mont(state[0]) +} +pub fn permute(input: [Felt; 3]) -> [Felt; 3] { + let state: MontTriple = Felt::to_mont_vec(input); + let output = permute_mont(state); + Felt::from_mont_vec(output) +} + +pub fn permute_iterated(input: [Felt; 3], count: usize) -> [Felt; 3] { let mut state: MontTriple = Felt::to_mont_vec(input); - for _i in 0..count { state = permute_mont(state); } - let out: FeltTriple = Felt::from_mont_vec(state); - out } -pub fn permute_felt(input: FeltTriple) -> FeltTriple { - permute_felt_iterated(input, 1) -} - //------------------------------------------------------------------------------