// 2026-01-25 20:49:34 +01:00
//
// 526 lines
// 12 KiB
// Rust

//
// Fixed-size big integers, represented as little-endian arrays of u32 limbs
// (limb 0 is the least significant).
//
#![allow(dead_code)]
#![allow(non_snake_case)]
#![allow(unused_parens)]
#![allow(unused_imports)]
use std::fmt;
use std::cmp::{Ordering,min};
use unroll::unroll_for_loops;
use crate::bn254::platform::*;
use crate::bn254::constant::{PRIME_ARRAY};
//------------------------------------------------------------------------------
/// Fixed-width unsigned big integer made of `N` little-endian `u32` limbs
/// (`self.0[0]` is the least significant limb).
#[derive(Clone, Copy)]
pub struct BigInt<const N: usize>([u32; N]);

/// Free-function constructor: wraps a little-endian limb array as a `BigInt`.
#[inline(always)]
pub fn mkBigInt<const N: usize>(ls: [u32; N]) -> BigInt<N> {
    BigInt::<N>(ls)
}

/// 256-bit integer (8 limbs).
pub type BigInt256 = BigInt<8>;
/// 512-bit integer (16 limbs), wide enough for the product of two `BigInt256`s.
pub type BigInt512 = BigInt<16>;
//------------------------------------------------------------------------------
impl<const N: usize> fmt::Display for BigInt<N> {
    /// Formats the value as `0x` followed by every limb in lowercase hex,
    /// most significant limb first, each zero-padded to 8 digits (so the
    /// output always has exactly `8*N` hex digits).
    ///
    /// Errors reported by the underlying writer are propagated to the caller
    /// instead of being silently ignored.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        f.write_str("0x")?;
        for limb in self.0.iter().rev() {
            write!(f, "{:08x}", limb)?;
        }
        Ok(())
    }
}
impl<const N: usize> BigInt<N> {
    /// Prints one line `"<s> = <value>"` to stdout, where the value uses the
    /// `Display` (hex) formatting.
    pub fn print(s: &str, A: &BigInt<N>) {
        let label: &str = s;
        println!("{} = {}", label, A);
    }
}
//------------------------------------------------------------------------------
impl<const N: usize> BigInt<N> {
    /// Returns the underlying little-endian limb array.
    #[inline(always)]
    pub fn unwrap(big: BigInt<N>) -> [u32; N] {
        big.0
    }
    /// Wraps a little-endian limb array. Being `const`, it can be used to
    /// build compile-time constants.
    #[inline(always)]
    pub const fn make(ls: [u32; N]) -> BigInt<N> {
        BigInt(ls)
    }
    //------------------------------------
    // conversion to/from bytes

    /// Serializes to `4*N` bytes in little-endian order: limb 0 first, each
    /// limb least significant byte first.
    pub fn to_le_bytes(big: &BigInt<N>) -> [u8; 4*N] {
        let mut buf: [u8; 4*N] = [0; 4*N];
        for (chunk, limb) in buf.chunks_exact_mut(4).zip(big.0.iter()) {
            chunk.copy_from_slice(&limb.to_le_bytes());
        }
        buf
    }
    /// Parses `4*N` little-endian bytes (the inverse of `to_le_bytes`).
    pub fn from_le_bytes(buf: &[u8; 4*N]) -> BigInt<N> {
        let mut ws: [u32; N] = [0; N];
        for (w, chunk) in ws.iter_mut().zip(buf.chunks_exact(4)) {
            let mut xs: [u8; 4] = [0; 4];
            xs.copy_from_slice(chunk);
            *w = u32::from_le_bytes(xs);
        }
        BigInt(ws)
    }
    /// Serializes to `4*N` bytes in big-endian order: most significant limb
    /// first, each limb most significant byte first.
    pub fn to_be_bytes(big: &BigInt<N>) -> [u8; 4*N] {
        let mut buf: [u8; 4*N] = [0; 4*N];
        for (chunk, limb) in buf.chunks_exact_mut(4).zip(big.0.iter().rev()) {
            chunk.copy_from_slice(&limb.to_be_bytes());
        }
        buf
    }
    /// Parses `4*N` big-endian bytes (the inverse of `to_be_bytes`).
    pub fn from_be_bytes(buf: &[u8; 4*N]) -> BigInt<N> {
        let mut ws: [u32; N] = [0; N];
        // the first 4-byte chunk is the most significant limb
        for (w, chunk) in ws.iter_mut().rev().zip(buf.chunks_exact(4)) {
            let mut xs: [u8; 4] = [0; 4];
            xs.copy_from_slice(chunk);
            *w = u32::from_be_bytes(xs);
        }
        BigInt(ws)
    }
    //------------------------------------
    // decimal printing

    /// Divides by a single-limb `modulus` and returns `(quotient, remainder)`.
    ///
    /// # Panics
    /// Panics on division by zero when `modulus == 0` (and `N > 0`).
    pub fn divmod_small(big: &BigInt<N>, modulus: u32) -> (BigInt<N>, u32) {
        let m: u64 = modulus as u64;
        let mut rem: u32 = 0;
        let mut qs: [u32; N] = [0; N];
        // Long division from the most significant limb down. Since
        // rem < modulus, each partial quotient x / m fits in a u32.
        for i in (0..N).rev() {
            let x: u64 = ((rem as u64) << 32) | (big.0[i] as u64);
            qs[i] = (x / m) as u32;
            rem = (x % m) as u32;
        }
        (BigInt(qs), rem)
    }
    /// Renders the number in base 10, returning `"0"` for zero.
    pub fn to_decimal_string(input: &BigInt<N>) -> String {
        // digits are produced least significant first, then reversed
        let mut digits: Vec<u8> = Vec::new();
        let mut big: BigInt<N> = *input;
        while !BigInt::is_zero(&big) {
            let (q, r) = BigInt::divmod_small(&big, 10);
            digits.push(b'0' + r as u8);
            big = q;
        }
        if digits.is_empty() {
            digits.push(b'0');
        }
        digits.reverse();
        // only ASCII digits were pushed above, so this cannot fail
        String::from_utf8(digits).expect("decimal digits are valid ASCII")
    }
    //------------------------------------

    /// Drops the most significant limb of an `(N+1)`-limb number. The dropped
    /// limb is discarded without any overflow check.
    pub fn truncate1(big: &BigInt<{N+1}>) -> BigInt<N> {
        let mut small: [u32; N] = [0; N];
        small.copy_from_slice(&big.0[..N]);
        BigInt(small)
    }
    /// The value 0.
    pub fn zero() -> BigInt<N> {
        BigInt([0; N])
    }
    /// Widens a single `u32` into the lowest limb; all other limbs are zero.
    pub fn from_u32(x: u32) -> BigInt<N> {
        let mut xs = [0; N];
        xs[0] = x;
        BigInt(xs)
    }
    //------------------------------------
    // comparison

    /// Returns `true` iff every limb is zero.
    pub fn is_zero(big: &BigInt<N>) -> bool {
        big.0.iter().all(|&w| w == 0)
    }
    /// Returns `true` iff both numbers have identical limbs.
    pub fn is_equal(big1: &BigInt<N>, big2: &BigInt<N>) -> bool {
        big1.0 == big2.0
    }
    /// Numeric comparison: limbs are compared lexicographically starting from
    /// the most significant one.
    pub fn cmp(big1: &BigInt<N>, big2: &BigInt<N>) -> Ordering {
        big1.0.iter().rev().cmp(big2.0.iter().rev())
    }
    /// `big1 < big2`
    pub fn is_lt(big1: &BigInt<N>, big2: &BigInt<N>) -> bool {
        BigInt::cmp(big1, big2) == Ordering::Less
    }
    /// `big1 > big2`
    pub fn is_gt(big1: &BigInt<N>, big2: &BigInt<N>) -> bool {
        BigInt::cmp(big1, big2) == Ordering::Greater
    }
    /// `big1 <= big2`
    pub fn is_le(big1: &BigInt<N>, big2: &BigInt<N>) -> bool {
        !BigInt::is_gt(big1, big2)
    }
    /// `big1 >= big2`
    pub fn is_ge(big1: &BigInt<N>, big2: &BigInt<N>) -> bool {
        !BigInt::is_lt(big1, big2)
    }
    //------------------------------------
    // addition and subtraction

    /// Limb-wise addition with carry propagation. Returns the low `N` limbs of
    /// `big1 + big2` and the final carry out (assuming `addCarry32(x, y, c)`
    /// yields `(x + y + c mod 2^32, carry_out)`).
    #[inline(always)]
    #[unroll_for_loops]
    pub fn addCarry(big1: &BigInt<N>, big2: &BigInt<N>) -> (BigInt<N>, bool) {
        let mut c : bool = false;
        let mut zs : [u32; N] = [0; N];
        for i in 0..N {
            let (z,cout) = addCarry32( big1.0[i] , big2.0[i] , c);
            zs[i] = z;
            c = cout;
        }
        let big: BigInt<N> = BigInt(zs);
        (big, c)
    }
    /// Limb-wise subtraction with borrow propagation. Returns the low `N` limbs
    /// of `big1 - big2` and the final borrow out (assuming
    /// `subBorrow32(x, y, b)` yields `(x - y - b mod 2^32, borrow_out)`).
    #[inline(always)]
    #[unroll_for_loops]
    pub fn subBorrow(big1: &BigInt<N>, big2: &BigInt<N>) -> (BigInt<N>, bool) {
        let mut c : bool = false;
        let mut zs : [u32; N] = [0; N];
        for i in 0..N {
            let (z,cout) = subBorrow32( big1.0[i] , big2.0[i] , c );
            zs[i] = z;
            c = cout;
        }
        let big: BigInt<N> = BigInt(zs);
        (big, c)
    }
    /// Addition modulo `2^(32*N)`; the carry out is discarded.
    pub fn add(big1: &BigInt<N>, big2: &BigInt<N>) -> BigInt<N> {
        let (out,_) = BigInt::addCarry(big1,big2);
        out
    }
    /// Subtraction modulo `2^(32*N)`; the borrow out is discarded.
    pub fn sub(big1: &BigInt<N>, big2: &BigInt<N>) -> BigInt<N> {
        let (out,_) = BigInt::subBorrow(big1,big2);
        out
    }
    //------------------------------------
    // specialize to the prime number
    //
    // NOTE(review): the functions below index `PRIME_ARRAY` with `0..N`, so they
    // assume `N` equals the length of `PRIME_ARRAY` (presumably 8 limbs for
    // BN254) -- TODO confirm; other `N` would panic or compare a truncated prime.

    /// Returns `true` iff `big < PRIME_ARRAY`, comparing from the most
    /// significant limb down.
    #[inline(always)]
    #[unroll_for_loops]
    pub fn is_lt_prime(big: &BigInt<N>) -> bool {
        let mut less: bool = false;
        for i in (0..N).rev() {
            if big.0[i] < PRIME_ARRAY[i] {
                less = true;
                break;
            }
            if big.0[i] > PRIME_ARRAY[i] {
                break;
            }
        }
        less
    }
    /// Returns `true` iff `big >= PRIME_ARRAY`.
    #[inline(always)]
    pub fn is_ge_prime(big: &BigInt<N>) -> bool {
        !BigInt::is_lt_prime(big)
    }
    /// Returns the low `N` limbs of `big + PRIME_ARRAY` and the carry out.
    #[inline(always)]
    #[unroll_for_loops]
    pub fn add_prime(big: &BigInt<N>) -> (BigInt<N>, bool) {
        let mut c : bool = false;
        let mut zs : [u32; N] = [0; N];
        for i in 0..N {
            let (z,cout) = addCarry32( big.0[i] , PRIME_ARRAY[i] , c );
            zs[i] = z;
            c = cout;
        }
        let big: BigInt<N> = BigInt(zs);
        (big, c)
    }
    /// Returns the low `N` limbs of `big - PRIME_ARRAY` and the borrow out.
    #[inline(always)]
    #[unroll_for_loops]
    pub fn subtract_prime(big: &BigInt<N>) -> (BigInt<N>, bool) {
        let mut c : bool = false;
        let mut zs : [u32; N] = [0; N];
        for i in 0..N {
            let (z,cout) = subBorrow32( big.0[i] , PRIME_ARRAY[i] , c );
            zs[i] = z;
            c = cout;
        }
        let big: BigInt<N> = BigInt(zs);
        (big, c)
    }
    /// Subtracts the prime once if `big >= PRIME_ARRAY`, otherwise returns `big`
    /// unchanged. Only a single conditional subtraction is done, so the result
    /// is below the prime only when the input is below twice the prime.
    #[inline(always)]
    pub fn subtract_prime_if_necessary(big: &BigInt<N>) -> BigInt<N> {
        if BigInt::is_lt_prime(big) {
            *big
        }
        else {
            let (corrected, _) = BigInt::subtract_prime(big);
            corrected
        }
    }
    //------------------------------------
    // multiplication

    /// Multiplies by a single limb. Returns the low `N` limbs of
    /// `scalar * big2` and the overflow limb (assuming `mulAdd32(a, b, c)`
    /// yields the `(lo, hi)` words of `a*b + c`).
    pub fn scale(scalar: u32, big2: &BigInt<N>) -> (BigInt<N>, u32) {
        let mut c : u32 = 0;
        let mut zs : [u32; N] = [0; N];
        for i in 0..N {
            let (lo,hi) = mulAdd32(scalar, big2.0[i], c);
            zs[i] = lo;
            c = hi;
        }
        let big: BigInt<N> = BigInt(zs);
        (big, c)
    }
    /// Returns the low `N` limbs of `scalar * vector + add` and the top limb
    /// (assuming `mulAddAdd32(a, b, c, d)` yields the `(lo, hi)` words of
    /// `a*b + c + d`).
    #[inline(always)]
    #[unroll_for_loops]
    pub fn scaleAdd(scalar: u32, vector: &BigInt<N>, add: &BigInt<N>) -> (BigInt<N>, u32) {
        let mut c : u32 = 0;
        let mut zs : [u32; N] = [0; N];
        for i in 0..N {
            let (lo,hi) = mulAddAdd32(scalar, vector.0[i], c, add.0[i]);
            zs[i] = lo;
            c = hi;
        }
        let big: BigInt<N> = BigInt(zs);
        (big, c)
    }
    /// Schoolbook product of an `N`-limb and an `M`-limb number. The exact
    /// result always fits in `N + M` limbs.
    #[inline(always)]
    #[unroll_for_loops]
    pub fn multiply<const M: usize>(big1: &BigInt<N>, big2: &BigInt<M>) -> BigInt<{N+M}> {
        let mut product : [u32; N+M] = [0; N+M];
        // `state` holds the not-yet-finalized upper N limbs of the running sum
        let mut state : [u32; N] = [0; N];
        for j in 0..M {
            let (scaled,carry) = BigInt::scaleAdd( big2.0[j], &big1, &BigInt(state) );
            // the lowest limb is final once row j has been added
            product[j] = scaled.0[0];
            for i in 1..N { state[i-1] = scaled.0[i] }
            state[N-1] = carry;
        }
        for i in 0..N {
            product[i+M] = state[i]
        }
        BigInt(product)
    }
    /// Full `2N`-limb product of two `N`-limb numbers.
    #[inline(always)]
    pub fn mul(big1: &BigInt<N>, big2: &BigInt<N>) -> BigInt<{N+N}> {
        BigInt::multiply(big1,big2)
    }
    /// Computes `big1 * big2 + big3` exactly as a `2N`-limb number. This cannot
    /// overflow, since `(B-1)^2 + (B-1) < B^2` for `B = 2^(32*N)`.
    #[inline(always)]
    pub fn mulAdd(big1: &BigInt<N>, big2: &BigInt<N>, big3: &BigInt<N>) -> BigInt<{N+N}> {
        // first compute the product
        let mut product: [u32; N+N] = BigInt::mul(big1, big2).0;
        // then add the third number into the low half...
        let mut carry: bool = false;
        for i in 0..N {
            let (z, c) = addCarry32(product[i], big3.0[i], carry);
            product[i] = z;
            carry = c;
        }
        // ...and ripple the carry through the high half
        for i in N..(N+N) {
            let (z, c) = addCarry32(product[i], 0, carry);
            product[i] = z;
            carry = c;
        }
        BigInt(product)
    }
    /// Computes `big1 * big2 + (big3 << (32*N))`. In other words, `big3` is
    /// added to the high half of the product; for `N = 8` that is a shift by
    /// 256 bits.
    ///
    /// NOTE: the carry out of the most significant limb is discarded, so the
    /// result is taken modulo `2^(64*N)`. Callers presumably guarantee that the
    /// true sum fits -- TODO confirm.
    #[inline(always)]
    pub fn mulAddShifted(big1: &BigInt<N>, big2: &BigInt<N>, big3: &BigInt<N>) -> BigInt<{N+N}> {
        // first compute the product
        let mut product: [u32; N+N] = BigInt::mul(big1, big2).0;
        // then add the third number into the high half
        let mut carry: bool = false;
        for i in 0..N {
            let (z, c) = addCarry32(product[i+N], big3.0[i], carry);
            product[i+N] = z;
            carry = c;
        }
        BigInt(product)
    }
    /// Square via the general multiplication (no squaring-specific savings).
    // TODO: optimize this?!
    pub fn sqr_naive(big: &BigInt<N>) -> BigInt<{N+N}> {
        BigInt::multiply(big,big)
    }
    /// Square of `big` as a `2N`-limb number (currently the same as `sqr_naive`).
    #[inline(always)]
    pub fn sqr(big: &BigInt<N>) -> BigInt<{N+N}> {
        BigInt::multiply(big,big)
    }
    // -----------------------------------------------------------------------------
    // Abandoned optimization attempts, kept for reference only. Judging by their
    // names, both turned out slower than `multiply`. NOTE: `sqr_is_actually_slower`
    // still uses an old named-field layout (`limbs`) and would not compile as-is.
    /*
    pub fn sqr_also_slower(big: &BigInt<N>) -> BigInt<{N+N}> {
        let mut mul_mtx : [[(u32,u32); N]; N] = [[(0,0); N]; N];
        for i in 0..N {
            for j in i..N {
                // i <= j
                let lo_hi = mulExt32( big.0[i] , big.0[j] );
                mul_mtx[i][j] = lo_hi;
                mul_mtx[j][i] = lo_hi;
            }
        }
        let mut product : [u32; N+N] = [0; N+N];
        let mut state : [u32; N] = [0; N];
        for j in 0..N {
            // let (scaled,carry) = BigInt::scaleAdd( big2.0[j], &big1, &BigInt(state) );
            let mut scaled : [u32; N] = [0; N];
            let mut carry : u32 = 0;
            for k in 0..N {
                // scalar = big2.0[j]
                // vector = big1
                let (lo,hi) = u64AddAdd32( mul_mtx[j][k], carry, state[k] );
                scaled[k] = lo;
                carry = hi;
            }
            product[j] = scaled[0];
            for i in 1..N { state[i-1] = scaled[i] }
            state[N-1] = carry;
        }
        for i in 0..N {
            product[i+N] = state[i]
        }
        BigInt(product)
    }
    // -----------------------------------
    pub fn sqr_is_actually_slower(big: &BigInt<N>) -> BigInt<{N+N}> {
        let mut product : [u32; N+N] = [0; N+N];
        let mut carry : u64 = 0;
        for k in 0..(N+N-1) {
            let mut sum_lo: u64 = carry;
            let mut sum_hi: u64 = 0;
            for i in 0..min(N,k+1) {
                let j = k - i;
                if j < N && i <= j {
                    let (lo,hi) = mulExt32( big.limbs[i], big.limbs[j] );
                    sum_lo += (lo as u64);
                    sum_hi += (hi as u64);
                    if i < j {
                        sum_lo += (lo as u64);
                        sum_hi += (hi as u64);
                    }
                }
            }
            let (u,v) = takeApart64(sum_lo);
            product[k] = u;
            carry = sum_hi + (v as u64);
        }
        product[N+N-1] = (carry as u32);
        BigInt { limbs: product }
    }
    */
}
//------------------------------------------------------------------------------