From 4740fa3d8803c77e1b64f3bd9d8a1a5e07932c1f Mon Sep 17 00:00:00 2001 From: Balazs Komuves Date: Thu, 29 Jan 2026 11:53:41 +0100 Subject: [PATCH] change the Poseidon2 state representation from triple to array --- benches/iterated_perm.rs | 17 ++++-------- src/bin/testmain.rs | 22 ++++++++-------- src/bin/twenty.rs | 16 ++++++------ src/poseidon2/constants.rs | 8 +++--- src/poseidon2/permutation.rs | 50 +++++++++++++++--------------------- 5 files changed, 49 insertions(+), 64 deletions(-) diff --git a/benches/iterated_perm.rs b/benches/iterated_perm.rs index e46af17..36c1908 100644 --- a/benches/iterated_perm.rs +++ b/benches/iterated_perm.rs @@ -9,31 +9,24 @@ use rust_poseidon_bn254_pure::poseidon2; //------------------------------------------------------------------------------ -type Triple = (Felt,Felt,Felt); +type Triple = [Felt; 3]; fn initial_triple() -> Triple { - ( Felt::from_u32(0) - , Felt::from_u32(1) - , Felt::from_u32(2) - ) -} - -fn initial_vector() -> [Felt; 3] { [ Felt::from_u32(0) , Felt::from_u32(1) , Felt::from_u32(2) ] } -pub fn poseidon1_permute_felt(input: [Felt; 3]) -> [Felt; 3] { +pub fn poseidon1_permute_felt(input: Triple) -> Triple { let mut state: [Mont; 3] = Felt::to_mont_vec(input); state = poseidon::permutation::permute_mont_T3(state); - let out: [Felt; 3] = Felt::from_mont_vec(state); + let out: Triple = Felt::from_mont_vec(state); out } -fn iterate_poseidon1(n: usize) -> [Felt; 3] { - let mut state: [Felt; 3] = initial_vector(); +fn iterate_poseidon1(n: usize) -> Triple { + let mut state: Triple = initial_triple(); for _i in 0..n { state = poseidon1_permute_felt(state); } diff --git a/src/bin/testmain.rs b/src/bin/testmain.rs index 522278b..6d8aef9 100644 --- a/src/bin/testmain.rs +++ b/src/bin/testmain.rs @@ -110,12 +110,12 @@ fn main() { println!("poseidon2 KAT:"); println!(""); - let input = ( Felt::from_u32(0) , Felt::from_u32(1) , Felt::from_u32(2) ); + let input = [ Felt::from_u32(0) , Felt::from_u32(1) , Felt::from_u32(2) ]; let output = permute_felt( input ); - println!("x = {}", input.0 ); - println!("y = {}", input.1 ); - println!("z = {}", input.2 ); + println!("x = {}", input[0] ); + println!("y = {}", input[1] ); + println!("z = {}", input[2] ); println!("~> "); @@ -125,16 +125,16 @@ fn main() { // y' = 0x13f731d6ffbad391be22d2ac364151849e19fa38eced4e761bcd21dbdc600288 // z' = 0x1433e2c8f68382c447c5c14b8b3df7cbfd9273dd655fe52f1357c27150da786f // - println!("x' = {}", output.0 ); - println!("y' = {}", output.1 ); - println!("z' = {}", output.2 ); + println!("x' = {}", output[0] ); + println!("y' = {}", output[1] ); + println!("z' = {}", output[2] ); println!(""); println!("poseidon2 iterated 10,000 times:"); println!(""); let now = Instant::now(); - let mut state: (Felt,Felt,Felt) = input.clone(); + let mut state: [Felt; 3] = input.clone(); for _i in 0..10000 { state = permute_felt(state); } @@ -145,9 +145,9 @@ fn main() { // y'' = 0x138d88ea0ece1c9618254fe2146a6120080e16128467187bf1448e80f31eee3f // z'' = 0x1e51d60083aa3e8fa189e1c72844c5e09225f5977a834f53b471bf0de0dd59eb // - println!("x'' = {}", state.0 ); - println!("y'' = {}", state.1 ); - println!("z'' = {}", state.2 ); + println!("x'' = {}", state[0] ); + println!("y'' = {}", state[1] ); + println!("z'' = {}", state[2] ); let elapsed = now.elapsed(); println!("Elapsed: {:.3?}", elapsed); diff --git a/src/bin/twenty.rs b/src/bin/twenty.rs index a07e4fb..a7b9273 100644 --- a/src/bin/twenty.rs +++ b/src/bin/twenty.rs @@ -6,18 +6,18 @@ fn main() { println!("iterating Poseidon2 twenty times"); - let input = ( Felt::from_u32(0) , Felt::from_u32(1) , Felt::from_u32(2) ); + let input = [ Felt::from_u32(0) , Felt::from_u32(1) , Felt::from_u32(2) ]; - println!("x = {}", input.0 ); - println!("y = {}", input.1 ); - println!("z = {}", input.2 ); + println!("x = {}", input[0] ); + println!("y = {}", input[1] ); + println!("z = {}", input[2] ); - let mut state: (Felt,Felt,Felt) = input.clone(); + let mut state: [Felt; 3] = input.clone(); for _i in 0..20 { state = permute_felt(state); } - println!("x' = {}", state.0 ); - println!("y' = {}", state.1 ); - println!("z' = {}", state.2 ); + println!("x' = {}", state[0] ); + println!("y' = {}", state[1] ); + println!("z' = {}", state[2] ); } \ No newline at end of file diff --git a/src/poseidon2/constants.rs b/src/poseidon2/constants.rs index 3584258..3ef2756 100644 --- a/src/poseidon2/constants.rs +++ b/src/poseidon2/constants.rs @@ -110,20 +110,20 @@ pub const FINAL_MONT: [Mont; 12] = //------------------------------------------------------------------------------ -pub const fn get_initial_RCs(r: usize) -> (Mont, Mont, Mont) { +pub const fn get_initial_RCs(r: usize) -> [Mont; 3] { let j = 3*r; let x: Mont = INITIAL_MONT[j ]; let y: Mont = INITIAL_MONT[j+1]; let z: Mont = INITIAL_MONT[j+2]; - (x,y,z) + [x,y,z] } -pub const fn get_final_RCs(r: usize) -> (Mont, Mont, Mont) { +pub const fn get_final_RCs(r: usize) -> [Mont; 3] { let j = 3*r; let x: Mont = FINAL_MONT[j ]; let y: Mont = FINAL_MONT[j+1]; let z: Mont = FINAL_MONT[j+2]; - (x,y,z) + [x,y,z] } //------------------------------------------------------------------------------ diff --git a/src/poseidon2/permutation.rs b/src/poseidon2/permutation.rs index 5528540..35ab07f 100644 --- a/src/poseidon2/permutation.rs +++ b/src/poseidon2/permutation.rs @@ -6,8 +6,8 @@ use crate::bn254::field::*; use crate::bn254::montgomery::*; use crate::poseidon2::constants::*; -pub type FeltTriple = (Felt,Felt,Felt); -pub type MontTriple = (Mont,Mont,Mont); +pub type FeltTriple = [Felt; 3]; +pub type MontTriple = [Mont; 3]; //------------------------------------------------------------------------------ @@ -24,31 +24,31 @@ fn add3(x: Mont, y: Mont, z: Mont) -> Mont { } fn linear(input: MontTriple) -> MontTriple { - let s = add3(input.0, input.1, input.2); - ( Mont::add( s , input.0 ) - , Mont::add( s , input.1 ) - , Mont::add( s , input.2 ) - ) + let s = add3( input[0], input[1], input[2] ); + [ Mont::add( s , input[0] ) + , Mont::add( s , input[1] ) + , Mont::add( s , input[2] ) + ] } fn internal_round(rc: Mont, input: MontTriple) -> MontTriple { - let x = sbox( Mont::add( input.0 , rc ) ); - let s = add3( x , input.1 , input.2 ); - ( Mont::add( s , x ) - , Mont::add( s , input.1 ) - , Mont::add( s , Mont::dbl(input.2) ) - ) + let x = sbox( Mont::add( input[0] , rc ) ); + let s = add3( x , input[1] , input[2] ); + [ Mont::add( s , x ) + , Mont::add( s , input[1] ) + , Mont::add( s , Mont::dbl(input[2]) ) + ] } fn external_round(rcs: MontTriple, input: MontTriple) -> MontTriple { - let x = sbox( Mont::add( input.0 , rcs.0 ) ); - let y = sbox( Mont::add( input.1 , rcs.1 ) ); - let z = sbox( Mont::add( input.2 , rcs.2 ) ); + let x = sbox( Mont::add( input[0] , rcs[0] ) ); + let y = sbox( Mont::add( input[1] , rcs[1] ) ); + let z = sbox( Mont::add( input[2] , rcs[2] ) ); let s = add3( x , y , z ); - ( Mont::add( s , x ) + [ Mont::add( s , x ) , Mont::add( s , y ) , Mont::add( s , z ) - ) + ] } pub fn permute_mont(input: MontTriple) -> MontTriple { @@ -63,21 +63,13 @@ pub fn permute_mont(input: MontTriple) -> MontTriple { pub fn permute_felt_iterated(input: FeltTriple, count: usize) -> FeltTriple { - let mut state: MontTriple = - ( Felt::to_mont(input.0) - , Felt::to_mont(input.1) - , Felt::to_mont(input.2) - ); + let mut state: MontTriple = Felt::to_mont_vec(input); - for _i in 0..count { + for _i in 0..count { state = permute_mont(state); } - let out: FeltTriple = - ( Felt::from_mont(state.0) - , Felt::from_mont(state.1) - , Felt::from_mont(state.2) - ); + let out: FeltTriple = Felt::from_mont_vec(state); out }