diff --git a/src/bin/testmain.rs b/src/bin/testmain.rs index feb8d84..86ce109 100644 --- a/src/bin/testmain.rs +++ b/src/bin/testmain.rs @@ -147,6 +147,20 @@ fn main() { let elapsed = now.elapsed(); println!("Elapsed: {:.3?}", elapsed); + println!(""); + println!("sanity checking comparison with the prime"); + let one : Big = BigInt::from_u32(1); + let a: Big = BigInt::sub(&FIELD_PRIME, &one); + let b: Big = FIELD_PRIME ; + let c: Big = BigInt::add(&FIELD_PRIME, &one); + println!("a = {}", a ); + println!("b = {}", b ); + println!("c = {}", c ); + println!("{} , {} , {}" , + BigInt::is_lt_prime(&a) , + BigInt::is_lt_prime(&b) , + BigInt::is_lt_prime(&c) ); + //---------------------------------------------------------------------------- } diff --git a/src/bn254/bigint.rs b/src/bn254/bigint.rs index 6b81712..3124b34 100644 --- a/src/bn254/bigint.rs +++ b/src/bn254/bigint.rs @@ -14,6 +14,7 @@ use std::cmp::{Ordering,min}; use unroll::unroll_for_loops; use crate::bn254::platform::*; +use crate::bn254::constant::{PRIME_ARRAY}; //------------------------------------------------------------------------------ @@ -114,7 +115,6 @@ impl BigInt { res } - #[inline(always)] pub fn is_lt(big1: &BigInt, big2: &BigInt) -> bool { BigInt::cmp(&big1, &big2) == Ordering::Less } @@ -131,6 +131,8 @@ impl BigInt { !BigInt::is_lt(&big1, &big2) } + //------------------------------------ + #[inline(always)] #[unroll_for_loops] pub fn addCarry(big1: &BigInt, big2: &BigInt) -> (BigInt, bool) { @@ -169,6 +171,71 @@ impl BigInt { out } + //------------------------------------ + // specialize to the prime number + + #[inline(always)] + #[unroll_for_loops] + pub fn is_lt_prime(big: &BigInt) -> bool { + let mut less: bool = false; + for i in (0..N).rev() { + if big.0[i] < PRIME_ARRAY[i] { + less = true; + break; + } + if big.0[i] > PRIME_ARRAY[i] { + break; + } + } + less + } + + #[inline(always)] + pub fn is_ge_prime(big: &BigInt) -> bool { + !BigInt::is_lt_prime(big) + } + + #[inline(always)] + #[unroll_for_loops] + pub fn add_prime(big: &BigInt) -> (BigInt, bool) { + let mut c : bool = false; + let mut zs : [u32; N] = [0; N]; + for i in 0..N { + let (z,cout) = addCarry32( big.0[i] , PRIME_ARRAY[i] , c ); + zs[i] = z; + c = cout; + } + let big: BigInt = BigInt(zs); + (big, c) + } + + #[inline(always)] + #[unroll_for_loops] + pub fn subtract_prime(big: &BigInt) -> (BigInt, bool) { + let mut c : bool = false; + let mut zs : [u32; N] = [0; N]; + for i in 0..N { + let (z,cout) = subBorrow32( big.0[i] , PRIME_ARRAY[i] , c ); + zs[i] = z; + c = cout; + } + let big: BigInt = BigInt(zs); + (big, c) + } + + #[inline(always)] + pub fn subtract_prime_if_necessary(big: &BigInt) -> BigInt { + if BigInt::is_lt_prime(big) { + *big + } + else { + let (corrected, _) = BigInt::subtract_prime(big); + corrected + } + } + + //------------------------------------ + pub fn scale(scalar: u32, big2: &BigInt) -> (BigInt, u32) { let mut c : u32 = 0; let mut zs : [u32; N] = [0; N]; diff --git a/src/bn254/constant.rs b/src/bn254/constant.rs index cdfe0c4..6969473 100644 --- a/src/bn254/constant.rs +++ b/src/bn254/constant.rs @@ -6,6 +6,7 @@ use crate::bn254::bigint::*; type Big = BigInt<8>; +pub const PRIME_ARRAY : [u32; 8] = [ 0xf0000001 , 0x43e1f593 , 0x79b97091 , 0x2833e848 , 0x8181585d , 0xb85045b6 , 0xe131a029 , 0x30644e72 ]; pub const PRIME_EXT : BigInt<9> = BigInt::make( [ 0xf0000001 , 0x43e1f593 , 0x79b97091 , 0x2833e848 , 0x8181585d , 0xb85045b6 , 0xe131a029 , 0x30644e72 , 0x00000000 ] ); pub const FIELD_PRIME : Big = BigInt::make( [ 0xf0000001 , 0x43e1f593 , 0x79b97091 , 0x2833e848 , 0x8181585d , 0xb85045b6 , 0xe131a029 , 0x30644e72 ] ); pub const PRIME_PLUS_1 : Big = BigInt::make( [ 0xf0000002 , 0x43e1f593 , 0x79b97091 , 0x2833e848 , 0x8181585d , 0xb85045b6 , 0xe131a029 , 0x30644e72 ] ); diff --git a/src/bn254/field.rs b/src/bn254/field.rs index 47aa164..d8797ec 100644 --- a/src/bn254/field.rs +++ b/src/bn254/field.rs @@ -51,7 +51,7 @@ impl Felt { pub fn checked_make( xs: [u32; 8] ) -> Felt { let big: Big = BigInt::make(xs); - if BigInt::is_lt(&big, &FIELD_PRIME) { + if BigInt::is_lt_prime(&big) { Felt(big) } else { @@ -91,19 +91,16 @@ impl Felt { } pub fn add(fld1: &Felt, fld2: &Felt) -> Felt { - let (big, carry) = BigInt::addCarry(&fld1.0, &fld2.0); - if carry || BigInt::is_ge(&big, &FIELD_PRIME) { - Felt(BigInt::sub(&big, &FIELD_PRIME)) - } - else { - Felt(big) - } + let (A, _) = BigInt::addCarry(&fld1.0, &fld2.0); + let B = BigInt::subtract_prime_if_necessary(&A); + Felt(B) } pub fn sub(fld1: &Felt, fld2: &Felt) -> Felt { let (big, carry) = BigInt::subBorrow(&fld1.0, &fld2.0); if carry { - Felt(BigInt::add(&big, &FIELD_PRIME)) + let (corrected, _) = BigInt::add_prime(&big); + Felt(corrected) } else { Felt(big) diff --git a/src/bn254/montgomery.rs b/src/bn254/montgomery.rs index e2e2aa6..413424c 100644 --- a/src/bn254/montgomery.rs +++ b/src/bn254/montgomery.rs @@ -59,20 +59,17 @@ impl Mont { #[inline(always)] pub fn add(mont1: &Mont, mont2: &Mont) -> Mont { - let (big, carry) = BigInt::addCarry(&mont1.0, &mont2.0); - if carry || BigInt::is_ge(&big, &FIELD_PRIME) { - Mont(BigInt::sub(&big, &FIELD_PRIME)) - } - else { - Mont(big) - } + let (A, _) = BigInt::addCarry(&mont1.0, &mont2.0); + let B = BigInt::subtract_prime_if_necessary(&A); + Mont(B) } #[inline(always)] pub fn sub(mont1: &Mont, mont2: &Mont) -> Mont { let (big, carry) = BigInt::subBorrow(&mont1.0, &mont2.0); if carry { - Mont(BigInt::add(&big, &FIELD_PRIME)) + let (corrected, _) = BigInt::add_prime(&big); + Mont(corrected) } else { Mont(big) @@ -134,7 +131,7 @@ impl Mont { let mut carry: u32 = 0; let m: u32 = mulTrunc32( T[i] , MONT_Q ); for j in 0..8 { - let (lo,hi) = mulAddAdd32( m, BigInt::unwrap(FIELD_PRIME)[j], carry, T[i+j] ); + let (lo,hi) = mulAddAdd32( m, PRIME_ARRAY[j], carry, T[i+j] ); T[i+j] = lo; carry = hi; } @@ -148,17 +145,9 @@ impl Mont { let mut S : [u32; 8] = [0; 8]; for i in 0..8 { S[i] = T[8+i]; } - let A : Big = BigInt::make(S); - let (B,c) : (Big,bool) = BigInt::subBorrow( &A , &FIELD_PRIME ); - - if c { - // `A - prime < 0` is equivalent to `A < prime` - A - } - else { - // `A - prime >= 0` is equivalent to `A >= prime` - B - } + let A : Big = BigInt::make(S); + let B : Big = BigInt::subtract_prime_if_necessary(&A); + B } pub fn sqr(mont: &Mont) -> Mont { @@ -174,7 +163,7 @@ impl Mont { // this does conversion from the standard representation // we assume the input is in the range `[0..p-1]` pub fn unsafe_convert_from_big(input: &Big) -> Mont { - let mont: Mont = Mont(input.clone()); + let mont: Mont = Mont(*input); Mont::mul( &mont , &MONT_R2 ) }