experimenting with variations...

This commit is contained in:
Balazs Komuves 2026-01-23 00:57:38 +01:00
parent 4ab91e4b28
commit 1262c72bc7
No known key found for this signature in database
GPG Key ID: F63B7AEF18435562
6 changed files with 152 additions and 56 deletions

View File

@ -15,11 +15,19 @@ criterion = "0.3"
[lib]
bench = false
[profile.release]
debug = true
[[bin]]
name = "testmain"
test = false
bench = false
[[bin]]
name = "twenty"
test = false
bench = false
[[bench]]
name = "iterated_perm"
harness = false

View File

@ -134,6 +134,13 @@ fn main() {
for _i in 0..10000 {
state = permute_felt(&state);
}
// expected output:
//
// x'' = 0x27f23fcc813ee313937d46b6d5bab2df03fcb3cf1829f0332ba9f9968509f130
// y'' = 0x138d88ea0ece1c9618254fe2146a6120080e16128467187bf1448e80f31eee3f
// z'' = 0x1e51d60083aa3e8fa189e1c72844c5e09225f5977a834f53b471bf0de0dd59eb
//
println!("x'' = {}", state.0 );
println!("y'' = {}", state.1 );
println!("z'' = {}", state.2 );

23
src/bin/twenty.rs Normal file
View File

@ -0,0 +1,23 @@
use rust_poseidon_bn254_pure::bn254::field::*;
use rust_poseidon_bn254_pure::poseidon2::permutation::*;
fn main() {
println!("iterating Poseidon2 twenty times");
let input = ( Felt::from_u32(0) , Felt::from_u32(1) , Felt::from_u32(2) );
println!("x = {}", input.0 );
println!("y = {}", input.1 );
println!("z = {}", input.2 );
let mut state: (Felt,Felt,Felt) = input.clone();
for _i in 0..20 {
state = permute_felt(&state);
}
println!("x' = {}", state.0 );
println!("y' = {}", state.1 );
println!("z' = {}", state.2 );
}

View File

@ -16,12 +16,11 @@ use crate::bn254::platform::*;
//------------------------------------------------------------------------------
#[derive(Copy, Clone)]
pub struct BigInt<const N: usize> {
pub limbs: [u32; N]
}
pub struct BigInt<const N: usize>([u32; N]);
#[inline(always)]
pub fn mkBigInt<const N: usize>(ls: [u32; N]) -> BigInt<N> {
BigInt { limbs: ls }
BigInt(ls)
}
pub type BigInt256 = BigInt<8>;
@ -29,6 +28,7 @@ pub type BigInt512 = BigInt<16>;
//------------------------------------------------------------------------------
#[inline(always)]
pub fn boolToU32(c: bool) -> u32 {
if c { 1 } else { 0 }
}
@ -39,7 +39,7 @@ impl<const N: usize> fmt::Display for BigInt<N> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
let _ = f.write_str("0x");
for i in 0..N {
let _ = f.write_fmt(format_args!("{:08x}",self.limbs[N-1-i]));
let _ = f.write_fmt(format_args!("{:08x}",self.0[N-1-i]));
}
Ok(())
}
@ -55,31 +55,37 @@ impl<const N: usize> BigInt<N> {
impl<const N: usize> BigInt<N> {
#[inline(always)]
pub fn unwrap(big: BigInt<N>) -> [u32; N] {
big.0
}
#[inline(always)]
pub const fn make(ls: [u32; N]) -> BigInt<N> {
BigInt { limbs: ls }
BigInt(ls)
}
pub fn truncate1(big : &BigInt<{N+1}>) -> BigInt<N> {
// let small: [u32; N] = &big.limbs[0..N];
let mut small: [u32; N] = [0; N];
for i in 0..N { small[i] = big.limbs[i]; }
BigInt { limbs: small }
for i in 0..N { small[i] = big.0[i]; }
BigInt(small)
}
pub fn zero() -> BigInt<N> {
BigInt { limbs: [0; N] }
BigInt([0; N])
}
pub fn from_u32(x: u32) -> BigInt<N> {
let mut xs = [0; N];
xs[0] = x;
BigInt { limbs: xs }
BigInt(xs)
}
pub fn is_zero(big: &BigInt<N>) -> bool {
let mut ok : bool = true;
for i in 0..N {
if big.limbs[i] != 0 {
if big.0[i] != 0 {
ok = false;
break;
}
@ -90,7 +96,7 @@ impl<const N: usize> BigInt<N> {
pub fn is_equal(big1: &BigInt<N>, big2: &BigInt<N>) -> bool {
let mut ok : bool = true;
for i in 0..N {
if big1.limbs[i] != big2.limbs[i] {
if big1.0[i] != big2.0[i] {
ok = false;
break;
}
@ -101,11 +107,11 @@ impl<const N: usize> BigInt<N> {
pub fn cmp(big1: &BigInt<N>, big2: &BigInt<N>) -> Ordering {
let mut res : Ordering = Ordering::Equal;
for i in (0..N).rev() {
if big1.limbs[i] < big2.limbs[i] {
if big1.0[i] < big2.0[i] {
res = Ordering::Less;
break;
}
if big1.limbs[i] > big2.limbs[i] {
if big1.0[i] > big2.0[i] {
res = Ordering::Greater;
break;
}
@ -113,6 +119,7 @@ impl<const N: usize> BigInt<N> {
res
}
#[inline(always)]
pub fn is_lt(big1: &BigInt<N>, big2: &BigInt<N>) -> bool {
BigInt::cmp(&big1, &big2) == Ordering::Less
}
@ -129,27 +136,29 @@ impl<const N: usize> BigInt<N> {
!BigInt::is_lt(&big1, &big2)
}
#[inline(always)]
pub fn addCarry(big1: &BigInt<N>, big2: &BigInt<N>) -> (BigInt<N>, bool) {
let mut c : bool = false;
let mut zs : [u32; N] = [0; N];
for i in 0..N {
let (z,cout) = addCarry32( big1.limbs[i] , big2.limbs[i] , c);
let (z,cout) = addCarry32( big1.0[i] , big2.0[i] , c);
zs[i] = z;
c = cout;
}
let big: BigInt<N> = BigInt { limbs: zs };
let big: BigInt<N> = BigInt(zs);
(big, c)
}
#[inline(always)]
pub fn subBorrow(big1: &BigInt<N>, big2: &BigInt<N>) -> (BigInt<N>, bool) {
let mut c : bool = false;
let mut zs : [u32; N] = [0; N];
for i in 0..N {
let (z,cout) = subBorrow32( big1.limbs[i] , big2.limbs[i] , c);
let (z,cout) = subBorrow32( big1.0[i] , big2.0[i] , c );
zs[i] = z;
c = cout;
}
let big: BigInt<N> = BigInt { limbs: zs };
let big: BigInt<N> = BigInt(zs);
(big, c)
}
@ -167,11 +176,11 @@ impl<const N: usize> BigInt<N> {
let mut c : u32 = 0;
let mut zs : [u32; N] = [0; N];
for i in 0..N {
let (lo,hi) = mulAdd32(scalar, big2.limbs[i], c);
let (lo,hi) = mulAdd32(scalar, big2.0[i], c);
zs[i] = lo;
c = hi;
}
let big: BigInt<N> = BigInt { limbs: zs };
let big: BigInt<N> = BigInt(zs);
(big, c)
}
@ -179,11 +188,11 @@ impl<const N: usize> BigInt<N> {
let mut c : u32 = 0;
let mut zs : [u32; N] = [0; N];
for i in 0..N {
let (lo,hi) = mulAddAdd32(scalar, big2.limbs[i], c, add.limbs[i]);
let (lo,hi) = mulAddAdd32(scalar, big2.0[i], c, add.0[i]);
zs[i] = lo;
c = hi;
}
let big: BigInt<N> = BigInt { limbs: zs };
let big: BigInt<N> = BigInt(zs);
(big, c)
}
@ -191,16 +200,16 @@ impl<const N: usize> BigInt<N> {
let mut product : [u32; N+M] = [0; N+M];
let mut state : [u32; N] = [0; N];
for j in 0..M {
let (scaled,carry) = BigInt::scaleAdd( big2.limbs[j], &big1, &(BigInt { limbs: state }) );
product[j] = scaled.limbs[0];
for i in 1..N { state[i-1] = scaled.limbs[i] }
let (scaled,carry) = BigInt::scaleAdd( big2.0[j], &big1, &BigInt(state) );
product[j] = scaled.0[0];
for i in 1..N { state[i-1] = scaled.0[i] }
state[N-1] = carry;
}
for i in 0..N {
product[i+M] = state[i]
}
BigInt { limbs: product }
BigInt(product)
}
pub fn mul(big1: &BigInt<N>, big2: &BigInt<N>) -> BigInt<{N+N}> {

View File

@ -17,75 +17,78 @@ use crate::bn254::constant::*;
type Big = BigInt<8>;
#[derive(Copy, Clone)]
pub struct Mont {
pub big: Big
}
pub struct Mont(Big);
pub const MONT_R1 : Mont = Mont { big: BIG_R1 };
pub const MONT_R2 : Mont = Mont { big: BIG_R2 };
pub const MONT_R3 : Mont = Mont { big: BIG_R3 };
pub const MONT_R1 : Mont = Mont(BIG_R1);
pub const MONT_R2 : Mont = Mont(BIG_R2);
pub const MONT_R3 : Mont = Mont(BIG_R3);
//------------------------------------------------------------------------------
impl Mont {
#[inline(always)]
pub const fn unsafe_make( xs: [u32; 8] ) -> Mont {
Mont { big: BigInt::make(xs) }
Mont(BigInt::make(xs))
}
#[inline(always)]
pub fn zero() -> Mont {
Mont { big: BigInt::zero() }
Mont(BigInt::zero())
}
pub fn is_equal(mont1: &Mont, mont2: &Mont) -> bool {
BigInt::is_equal(&mont1.big, &mont2.big)
BigInt::is_equal(&mont1.0, &mont2.0)
}
pub fn neg(mont: &Mont) -> Mont {
if BigInt::is_zero(&mont.big) {
if BigInt::is_zero(&mont.0) {
Mont::zero()
}
else {
Mont { big: BigInt::sub(&FIELD_PRIME, &mont.big) }
Mont(BigInt::sub(&FIELD_PRIME, &mont.0))
}
}
#[inline(always)]
pub fn add(mont1: &Mont, mont2: &Mont) -> Mont {
let (big, carry) = BigInt::addCarry(&mont1.big, &mont2.big);
let (big, carry) = BigInt::addCarry(&mont1.0, &mont2.0);
if carry || BigInt::is_ge(&big, &FIELD_PRIME) {
Mont { big: BigInt::sub(&big, &FIELD_PRIME) }
Mont(BigInt::sub(&big, &FIELD_PRIME))
}
else {
Mont { big: big }
Mont(big)
}
}
#[inline(always)]
pub fn sub(mont1: &Mont, mont2: &Mont) -> Mont {
let (big, carry) = BigInt::subBorrow(&mont1.big, &mont2.big);
let (big, carry) = BigInt::subBorrow(&mont1.0, &mont2.0);
if carry {
Mont { big: BigInt::add(&big, &FIELD_PRIME) }
Mont(BigInt::add(&big, &FIELD_PRIME))
}
else {
Mont { big: big }
Mont(big)
}
}
#[inline(always)]
pub fn dbl(mont: &Mont) -> Mont {
Mont::add(&mont, &mont)
}
// the Montgomery reduction algorithm
// <https://en.wikipedia.org/wiki/Montgomery_modular_multiplication#Montgomery_arithmetic_on_multiprecision_integers>
fn redc(input: BigInt<16>) -> Big {
fn redc_safe(input: BigInt<16>) -> Big {
let mut T: [u32; 17] = [0; 17];
for i in 0..16 { T[i] = input.limbs[i]; }
for i in 0..16 { T[i] = BigInt::unwrap(input)[i]; }
for i in 0..8 {
let mut carry: u32 = 0;
let m: u32 = mulTrunc32( T[i] , MONT_Q );
for j in 0..8 {
let (lo,hi) = mulAddAdd32( m, FIELD_PRIME.limbs[j], carry, T[i+j] );
let (lo,hi) = mulAddAdd32( m, BigInt::unwrap(FIELD_PRIME)[j], carry, T[i+j] );
T[i+j] = lo;
carry = hi;
}
@ -99,7 +102,7 @@ impl Mont {
let mut S : [u32; 9] = [0; 9];
for i in 0..9 { S[i] = T[8+i]; }
let A : BigInt<9> = BigInt { limbs: S };
let A : BigInt<9> = BigInt::make(S);
let (B,c) : (BigInt<9>,bool) = BigInt::subBorrow( &A , &PRIME_EXT );
if c {
@ -112,28 +115,65 @@ impl Mont {
}
}
// we can abuse the fact that we know the prime number `p`,
// for which `p < 2^254` so we won't overflow in the 17th word
fn redc(input: BigInt<16>) -> Big {
let mut T: [u32; 16] = BigInt::unwrap(input);
for i in 0..8 {
let mut carry: u32 = 0;
let m: u32 = mulTrunc32( T[i] , MONT_Q );
for j in 0..8 {
let (lo,hi) = mulAddAdd32( m, BigInt::unwrap(FIELD_PRIME)[j], carry, T[i+j] );
T[i+j] = lo;
carry = hi;
}
for j in 8..(16-i) {
let (x,c) = addCarry32_( T[i+j] , carry );
T[i+j] = x;
carry = boolToU32(c);
}
}
let mut S : [u32; 8] = [0; 8];
for i in 0..8 { S[i] = T[8+i]; }
let A : Big = BigInt::make(S);
let (B,c) : (Big,bool) = BigInt::subBorrow( &A , &FIELD_PRIME );
if c {
// `A - prime < 0` is equivalent to `A < prime`
A
}
else {
// `A - prime >= 0` is equivalent to `A >= prime`
B
}
}
pub fn sqr(mont: &Mont) -> Mont {
let large = BigInt::sqr(&mont.big);
Mont { big: Mont::redc(large) }
let large = BigInt::sqr(&mont.0);
Mont(Mont::redc(large))
}
pub fn mul(mont1: &Mont, mont2: &Mont) -> Mont {
let large = BigInt::mul(&mont1.big, &mont2.big);
Mont { big: Mont::redc(large) }
let large = BigInt::mul(&mont1.0, &mont2.0);
Mont(Mont::redc(large))
}
// this does conversion from the standard representation
// we assume the input is in the range `[0..p-1]`
pub fn unsafe_convert_from_big(input: &Big) -> Mont {
let mont: Mont = Mont { big: input.clone() };
let mont: Mont = Mont(input.clone());
Mont::mul( &mont , &MONT_R2 )
}
// this does conversion to the standard representation
pub fn convert_to_big(mont: &Mont) -> Big {
let mut tmp: [u32; 16] = [0; 16];
for i in 0..8 { tmp[i] = mont.big.limbs[i] }
Mont::redc( BigInt { limbs: tmp } )
for i in 0..8 { tmp[i] = BigInt::unwrap(mont.0)[i] }
Mont::redc( BigInt::make(tmp) )
}
// take a small number, interpret it as modulo P,
@ -151,7 +191,7 @@ impl Mont {
impl fmt::Debug for Mont {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
let _ = f.write_str("[");
let res = f.write_fmt(format_args!("{}",self.big));
let res = f.write_fmt(format_args!("{}",self.0));
let _ = f.write_str("]");
res
}
@ -169,7 +209,7 @@ impl fmt::Display for Mont {
impl Mont {
pub fn print_internal(s: &str, A: &Mont) {
println!("{} = [{}]", s, A.big);
println!("{} = [{}]", s, A.0);
}
pub fn print_standard(s: &str, A: &Mont) {

View File

@ -5,38 +5,47 @@
//------------------------------------------------------------------------------
// unstable version
#[inline(always)]
pub fn addCarry32_(x: u32, y: u32) -> (u32,bool) {
u32::overflowing_add(x,y)
}
#[inline(always)]
pub fn subBorrow32_(x: u32, y: u32) -> (u32,bool) {
u32::overflowing_sub(x,y)
}
#[inline(always)]
pub fn addCarry32(x :u32, y: u32, cin: bool) -> (u32,bool) {
u32::carrying_add(x,y,cin)
}
#[inline(always)]
pub fn subBorrow32(x: u32, y: u32, cin: bool) -> (u32,bool) {
u32::borrowing_sub(x,y,cin)
}
#[inline(always)]
pub fn mulTrunc32(x: u32, y: u32) -> u32 {
u32::wrapping_mul(x,y)
}
#[inline(always)]
pub fn mulExt32(x: u32, y: u32) -> (u32,u32) {
u32::widening_mul(x,y)
}
#[inline(always)]
pub fn mulAdd32(x: u32, y: u32, a: u32) -> (u32,u32) {
u32::carrying_mul(x,y,a)
}
#[inline(always)]
pub fn mulAddAdd32(x: u32, y: u32, a: u32, b: u32) -> (u32,u32) {
u32::carrying_mul_add(x,y,a,b)
}
#[inline(always)]
pub fn takeApart64(x: u64) -> (u32,u32) {
let lo: u32 = (x & 0x_FFFF_FFFF) as u32;
let hi: u32 = (x >> 32 ) as u32;