diff --git a/Cargo.toml b/Cargo.toml index 8914e42..31d4a3b 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -15,11 +15,19 @@ criterion = "0.3" [lib] bench = false +[profile.release] +debug = true + [[bin]] name = "testmain" test = false bench = false +[[bin]] +name = "twenty" +test = false +bench = false + [[bench]] name = "iterated_perm" harness = false diff --git a/src/bin/testmain.rs b/src/bin/testmain.rs index b96cd50..feb8d84 100644 --- a/src/bin/testmain.rs +++ b/src/bin/testmain.rs @@ -134,6 +134,13 @@ fn main() { for _i in 0..10000 { state = permute_felt(&state); } + + // expected output: + // + // x'' = 0x27f23fcc813ee313937d46b6d5bab2df03fcb3cf1829f0332ba9f9968509f130 + // y'' = 0x138d88ea0ece1c9618254fe2146a6120080e16128467187bf1448e80f31eee3f + // z'' = 0x1e51d60083aa3e8fa189e1c72844c5e09225f5977a834f53b471bf0de0dd59eb + // println!("x'' = {}", state.0 ); println!("y'' = {}", state.1 ); println!("z'' = {}", state.2 ); diff --git a/src/bin/twenty.rs b/src/bin/twenty.rs new file mode 100644 index 0000000..c377997 --- /dev/null +++ b/src/bin/twenty.rs @@ -0,0 +1,23 @@ + +use rust_poseidon_bn254_pure::bn254::field::*; +use rust_poseidon_bn254_pure::poseidon2::permutation::*; + +fn main() { + + println!("iterating Poseidon2 twenty times"); + + let input = ( Felt::from_u32(0) , Felt::from_u32(1) , Felt::from_u32(2) ); + + println!("x = {}", input.0 ); + println!("y = {}", input.1 ); + println!("z = {}", input.2 ); + + let mut state: (Felt,Felt,Felt) = input.clone(); + for _i in 0..20 { + state = permute_felt(&state); + } + println!("x' = {}", state.0 ); + println!("y' = {}", state.1 ); + println!("z' = {}", state.2 ); + +} \ No newline at end of file diff --git a/src/bn254/bigint.rs b/src/bn254/bigint.rs index aee9673..dbddace 100644 --- a/src/bn254/bigint.rs +++ b/src/bn254/bigint.rs @@ -16,12 +16,11 @@ use crate::bn254::platform::*; //------------------------------------------------------------------------------ #[derive(Copy, Clone)] -pub struct BigInt { - pub limbs: [u32; N] -} +pub struct BigInt([u32; N]); +#[inline(always)] pub fn mkBigInt(ls: [u32; N]) -> BigInt { - BigInt { limbs: ls } + BigInt(ls) } pub type BigInt256 = BigInt<8>; @@ -29,6 +28,7 @@ pub type BigInt512 = BigInt<16>; //------------------------------------------------------------------------------ +#[inline(always)] pub fn boolToU32(c: bool) -> u32 { if c { 1 } else { 0 } } @@ -39,7 +39,7 @@ impl fmt::Display for BigInt { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { let _ = f.write_str("0x"); for i in 0..N { - let _ = f.write_fmt(format_args!("{:08x}",self.limbs[N-1-i])); + let _ = f.write_fmt(format_args!("{:08x}",self.0[N-1-i])); } Ok(()) } @@ -55,31 +55,37 @@ impl BigInt { impl BigInt { + #[inline(always)] + pub fn unwrap(big: BigInt) -> [u32; N] { + big.0 + } + + #[inline(always)] pub const fn make(ls: [u32; N]) -> BigInt { - BigInt { limbs: ls } + BigInt(ls) } pub fn truncate1(big : &BigInt<{N+1}>) -> BigInt { // let small: [u32; N] = &big.limbs[0..N]; let mut small: [u32; N] = [0; N]; - for i in 0..N { small[i] = big.limbs[i]; } - BigInt { limbs: small } + for i in 0..N { small[i] = big.0[i]; } + BigInt(small) } pub fn zero() -> BigInt { - BigInt { limbs: [0; N] } + BigInt([0; N]) } pub fn from_u32(x: u32) -> BigInt { let mut xs = [0; N]; xs[0] = x; - BigInt { limbs: xs } + BigInt(xs) } pub fn is_zero(big: &BigInt) -> bool { let mut ok : bool = true; for i in 0..N { - if big.limbs[i] != 0 { + if big.0[i] != 0 { ok = false; break; } @@ -90,7 +96,7 @@ impl BigInt { pub fn is_equal(big1: &BigInt, big2: &BigInt) -> bool { let mut ok : bool = true; for i in 0..N { - if big1.limbs[i] != big2.limbs[i] { + if big1.0[i] != big2.0[i] { ok = false; break; } @@ -101,11 +107,11 @@ impl BigInt { pub fn cmp(big1: &BigInt, big2: &BigInt) -> Ordering { let mut res : Ordering = Ordering::Equal; for i in (0..N).rev() { - if big1.limbs[i] < big2.limbs[i] { + if big1.0[i] < big2.0[i] { res = Ordering::Less; break; } - if big1.limbs[i] > big2.limbs[i] { + if big1.0[i] > big2.0[i] { res = Ordering::Greater; break; } @@ -113,6 +119,7 @@ impl BigInt { res } + #[inline(always)] pub fn is_lt(big1: &BigInt, big2: &BigInt) -> bool { BigInt::cmp(&big1, &big2) == Ordering::Less } @@ -129,27 +136,29 @@ impl BigInt { !BigInt::is_lt(&big1, &big2) } + #[inline(always)] pub fn addCarry(big1: &BigInt, big2: &BigInt) -> (BigInt, bool) { let mut c : bool = false; let mut zs : [u32; N] = [0; N]; for i in 0..N { - let (z,cout) = addCarry32( big1.limbs[i] , big2.limbs[i] , c); + let (z,cout) = addCarry32( big1.0[i] , big2.0[i] , c); zs[i] = z; c = cout; } - let big: BigInt = BigInt { limbs: zs }; + let big: BigInt = BigInt(zs); (big, c) } + #[inline(always)] pub fn subBorrow(big1: &BigInt, big2: &BigInt) -> (BigInt, bool) { let mut c : bool = false; let mut zs : [u32; N] = [0; N]; for i in 0..N { - let (z,cout) = subBorrow32( big1.limbs[i] , big2.limbs[i] , c); + let (z,cout) = subBorrow32( big1.0[i] , big2.0[i] , c ); zs[i] = z; c = cout; } - let big: BigInt = BigInt { limbs: zs }; + let big: BigInt = BigInt(zs); (big, c) } @@ -167,11 +176,11 @@ impl BigInt { let mut c : u32 = 0; let mut zs : [u32; N] = [0; N]; for i in 0..N { - let (lo,hi) = mulAdd32(scalar, big2.limbs[i], c); + let (lo,hi) = mulAdd32(scalar, big2.0[i], c); zs[i] = lo; c = hi; } - let big: BigInt = BigInt { limbs: zs }; + let big: BigInt = BigInt(zs); (big, c) } @@ -179,11 +188,11 @@ impl BigInt { let mut c : u32 = 0; let mut zs : [u32; N] = [0; N]; for i in 0..N { - let (lo,hi) = mulAddAdd32(scalar, big2.limbs[i], c, add.limbs[i]); + let (lo,hi) = mulAddAdd32(scalar, big2.0[i], c, add.0[i]); zs[i] = lo; c = hi; } - let big: BigInt = BigInt { limbs: zs }; + let big: BigInt = BigInt(zs); (big, c) } @@ -191,16 +200,16 @@ impl BigInt { let mut product : [u32; N+M] = [0; N+M]; let mut state : [u32; N] = [0; N]; for j in 0..M { - let (scaled,carry) = BigInt::scaleAdd( big2.limbs[j], &big1, &(BigInt { limbs: state }) ); - product[j] = scaled.limbs[0]; - for i in 1..N { state[i-1] = scaled.limbs[i] } + let (scaled,carry) = BigInt::scaleAdd( big2.0[j], &big1, &BigInt(state) ); + product[j] = scaled.0[0]; + for i in 1..N { state[i-1] = scaled.0[i] } state[N-1] = carry; } for i in 0..N { product[i+M] = state[i] } - BigInt { limbs: product } + BigInt(product) } pub fn mul(big1: &BigInt, big2: &BigInt) -> BigInt<{N+N}> { diff --git a/src/bn254/montgomery.rs b/src/bn254/montgomery.rs index 3f6dc11..6545f6d 100644 --- a/src/bn254/montgomery.rs +++ b/src/bn254/montgomery.rs @@ -17,75 +17,78 @@ use crate::bn254::constant::*; type Big = BigInt<8>; #[derive(Copy, Clone)] -pub struct Mont { - pub big: Big -} +pub struct Mont(Big); -pub const MONT_R1 : Mont = Mont { big: BIG_R1 }; -pub const MONT_R2 : Mont = Mont { big: BIG_R2 }; -pub const MONT_R3 : Mont = Mont { big: BIG_R3 }; +pub const MONT_R1 : Mont = Mont(BIG_R1); +pub const MONT_R2 : Mont = Mont(BIG_R2); +pub const MONT_R3 : Mont = Mont(BIG_R3); //------------------------------------------------------------------------------ impl Mont { + #[inline(always)] pub const fn unsafe_make( xs: [u32; 8] ) -> Mont { - Mont { big: BigInt::make(xs) } + Mont(BigInt::make(xs)) } + #[inline(always)] pub fn zero() -> Mont { - Mont { big: BigInt::zero() } + Mont(BigInt::zero()) } pub fn is_equal(mont1: &Mont, mont2: &Mont) -> bool { - BigInt::is_equal(&mont1.big, &mont2.big) + BigInt::is_equal(&mont1.0, &mont2.0) } pub fn neg(mont: &Mont) -> Mont { - if BigInt::is_zero(&mont.big) { + if BigInt::is_zero(&mont.0) { Mont::zero() } else { - Mont { big: BigInt::sub(&FIELD_PRIME, &mont.big) } + Mont(BigInt::sub(&FIELD_PRIME, &mont.0)) } } + #[inline(always)] pub fn add(mont1: &Mont, mont2: &Mont) -> Mont { - let (big, carry) = BigInt::addCarry(&mont1.big, &mont2.big); + let (big, carry) = BigInt::addCarry(&mont1.0, &mont2.0); if carry || BigInt::is_ge(&big, &FIELD_PRIME) { - Mont { big: BigInt::sub(&big, &FIELD_PRIME) } + Mont(BigInt::sub(&big, &FIELD_PRIME)) } else { - Mont { big: big } + Mont(big) } } + #[inline(always)] pub fn sub(mont1: &Mont, mont2: &Mont) -> Mont { - let (big, carry) = BigInt::subBorrow(&mont1.big, &mont2.big); + let (big, carry) = BigInt::subBorrow(&mont1.0, &mont2.0); if carry { - Mont { big: BigInt::add(&big, &FIELD_PRIME) } + Mont(BigInt::add(&big, &FIELD_PRIME)) } else { - Mont { big: big } + Mont(big) } } + #[inline(always)] pub fn dbl(mont: &Mont) -> Mont { Mont::add(&mont, &mont) } // the Montgomery reduction algorithm // - fn redc(input: BigInt<16>) -> Big { + fn redc_safe(input: BigInt<16>) -> Big { let mut T: [u32; 17] = [0; 17]; - for i in 0..16 { T[i] = input.limbs[i]; } + for i in 0..16 { T[i] = BigInt::unwrap(input)[i]; } for i in 0..8 { let mut carry: u32 = 0; let m: u32 = mulTrunc32( T[i] , MONT_Q ); for j in 0..8 { - let (lo,hi) = mulAddAdd32( m, FIELD_PRIME.limbs[j], carry, T[i+j] ); + let (lo,hi) = mulAddAdd32( m, BigInt::unwrap(FIELD_PRIME)[j], carry, T[i+j] ); T[i+j] = lo; carry = hi; } @@ -99,7 +102,7 @@ impl Mont { let mut S : [u32; 9] = [0; 9]; for i in 0..9 { S[i] = T[8+i]; } - let A : BigInt<9> = BigInt { limbs: S }; + let A : BigInt<9> = BigInt::make(S); let (B,c) : (BigInt<9>,bool) = BigInt::subBorrow( &A , &PRIME_EXT ); if c { @@ -112,28 +115,65 @@ impl Mont { } } + // we can abuse the fact that we know the prime number `p`, + // for which `p < 2^254` so we won't overflow in the 17th word + fn redc(input: BigInt<16>) -> Big { + + let mut T: [u32; 16] = BigInt::unwrap(input); + + for i in 0..8 { + let mut carry: u32 = 0; + let m: u32 = mulTrunc32( T[i] , MONT_Q ); + for j in 0..8 { + let (lo,hi) = mulAddAdd32( m, BigInt::unwrap(FIELD_PRIME)[j], carry, T[i+j] ); + T[i+j] = lo; + carry = hi; + } + for j in 8..(16-i) { + let (x,c) = addCarry32_( T[i+j] , carry ); + T[i+j] = x; + carry = boolToU32(c); + } + } + + let mut S : [u32; 8] = [0; 8]; + for i in 0..8 { S[i] = T[8+i]; } + + let A : Big = BigInt::make(S); + let (B,c) : (Big,bool) = BigInt::subBorrow( &A , &FIELD_PRIME ); + + if c { + // `A - prime < 0` is equivalent to `A < prime` + A + } + else { + // `A - prime >= 0` is equivalent to `A >= prime` + B + } + } + pub fn sqr(mont: &Mont) -> Mont { - let large = BigInt::sqr(&mont.big); - Mont { big: Mont::redc(large) } + let large = BigInt::sqr(&mont.0); + Mont(Mont::redc(large)) } pub fn mul(mont1: &Mont, mont2: &Mont) -> Mont { - let large = BigInt::mul(&mont1.big, &mont2.big); - Mont { big: Mont::redc(large) } + let large = BigInt::mul(&mont1.0, &mont2.0); + Mont(Mont::redc(large)) } // this does conversion from the standard representation // we assume the input is in the range `[0..p-1]` pub fn unsafe_convert_from_big(input: &Big) -> Mont { - let mont: Mont = Mont { big: input.clone() }; + let mont: Mont = Mont(input.clone()); Mont::mul( &mont , &MONT_R2 ) } // this does conversion to the standard representation pub fn convert_to_big(mont: &Mont) -> Big { let mut tmp: [u32; 16] = [0; 16]; - for i in 0..8 { tmp[i] = mont.big.limbs[i] } - Mont::redc( BigInt { limbs: tmp } ) + for i in 0..8 { tmp[i] = BigInt::unwrap(mont.0)[i] } + Mont::redc( BigInt::make(tmp) ) } // take a small number, interpret it as modulo P, @@ -151,7 +191,7 @@ impl Mont { impl fmt::Debug for Mont { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { let _ = f.write_str("["); - let res = f.write_fmt(format_args!("{}",self.big)); + let res = f.write_fmt(format_args!("{}",self.0)); let _ = f.write_str("]"); res } @@ -169,7 +209,7 @@ impl fmt::Display for Mont { impl Mont { pub fn print_internal(s: &str, A: &Mont) { - println!("{} = [{}]", s, A.big); + println!("{} = [{}]", s, A.0); } pub fn print_standard(s: &str, A: &Mont) { diff --git a/src/bn254/platform.rs b/src/bn254/platform.rs index 85d9761..5c018be 100644 --- a/src/bn254/platform.rs +++ b/src/bn254/platform.rs @@ -5,38 +5,47 @@ //------------------------------------------------------------------------------ // unstable version +#[inline(always)] pub fn addCarry32_(x: u32, y: u32) -> (u32,bool) { u32::overflowing_add(x,y) } +#[inline(always)] pub fn subBorrow32_(x: u32, y: u32) -> (u32,bool) { u32::overflowing_sub(x,y) } +#[inline(always)] pub fn addCarry32(x :u32, y: u32, cin: bool) -> (u32,bool) { u32::carrying_add(x,y,cin) } +#[inline(always)] pub fn subBorrow32(x: u32, y: u32, cin: bool) -> (u32,bool) { u32::borrowing_sub(x,y,cin) } +#[inline(always)] pub fn mulTrunc32(x: u32, y: u32) -> u32 { u32::wrapping_mul(x,y) } +#[inline(always)] pub fn mulExt32(x: u32, y: u32) -> (u32,u32) { u32::widening_mul(x,y) } +#[inline(always)] pub fn mulAdd32(x: u32, y: u32, a: u32) -> (u32,u32) { u32::carrying_mul(x,y,a) } +#[inline(always)] pub fn mulAddAdd32(x: u32, y: u32, a: u32, b: u32) -> (u32,u32) { u32::carrying_mul_add(x,y,a,b) } +#[inline(always)] pub fn takeApart64(x: u64) -> (u32,u32) { let lo: u32 = (x & 0x_FFFF_FFFF) as u32; let hi: u32 = (x >> 32 ) as u32;