Replace the u64 "plus-minus" modular inversion with Fermat inversion in GoldilocksField.

GoldilocksField::try_inverse now computes a^(p-2) with a fixed chain of squarings and
multiplications (helper: exp_acc). field/src/inversion.rs and its `mod inversion;`
declaration in field/src/lib.rs are removed.

diff --git a/field/src/goldilocks_field.rs b/field/src/goldilocks_field.rs
index 4432d9c4..6e4361fc 100644
--- a/field/src/goldilocks_field.rs
+++ b/field/src/goldilocks_field.rs
@@ -7,7 +7,7 @@ use num::{BigUint, Integer};
 use plonky2_util::{assume, branch_hint};
 use serde::{Deserialize, Serialize};
 
-use crate::inversion::try_inverse_u64;
+use crate::ops::Square;
 use crate::types::{Field, Field64, PrimeField, PrimeField64, Sample};
 
 const EPSILON: u64 = (1 << 32) - 1;
@@ -95,9 +95,55 @@ impl Field for GoldilocksField {
         Self::order()
     }
 
-    #[inline(always)]
+    /// Returns the inverse of the field element, using Fermat's little theorem.
+    /// The inverse of `a` is computed as `a^(p-2)`, where `p` is the prime order of the field.
+    ///
+    /// Mathematically, this is equivalent to:
+    /// $a^(p-1) = 1 (mod p)$
+    /// $a^(p-2) * a = 1 (mod p)$
+    /// Therefore $a^(p-2) = a^-1 (mod p)$
+    ///
+    /// The following code has been adapted from winterfell/math/src/field/f64/mod.rs
+    /// located at https://github.com/facebook/winterfell.
     fn try_inverse(&self) -> Option<Self> {
-        try_inverse_u64(self)
+        if self.is_zero() {
+            return None;
+        }
+
+        // compute base^(P - 2) using 72 multiplications
+        // The exponent P - 2 is represented in binary as:
+        // 0b1111111111111111111111111111111011111111111111111111111111111111
+
+        // compute base^11
+        let t2 = self.square() * *self;
+
+        // compute base^111
+        let t3 = t2.square() * *self;
+
+        // compute base^111111 (6 ones)
+        // repeatedly square t3 3 times and multiply by t3
+        let t6 = exp_acc::<3>(t3, t3);
+
+        // compute base^111111111111 (12 ones)
+        // repeatedly square t6 6 times and multiply by t6
+        let t12 = exp_acc::<6>(t6, t6);
+
+        // compute base^111111111111111111111111 (24 ones)
+        // repeatedly square t12 12 times and multiply by t12
+        let t24 = exp_acc::<12>(t12, t12);
+
+        // compute base^1111111111111111111111111111111 (31 ones)
+        // repeatedly square t24 6 times and multiply by t6 first. then square t30 and
+        // multiply by base
+        let t30 = exp_acc::<6>(t24, t6);
+        let t31 = t30.square() * *self;
+
+        // compute base^111111111111111111111111111111101111111111111111111111111111111
+        // repeatedly square t31 32 times and multiply by t31
+        let t63 = exp_acc::<32>(t31, t31);
+
+        // compute base^1111111111111111111111111111111011111111111111111111111111111111
+        Some(t63.square() * *self)
     }
 
     fn from_noncanonical_biguint(n: BigUint) -> Self {
@@ -402,6 +448,12 @@ pub(crate) unsafe fn reduce160(x_lo: u128, x_hi: u32) -> GoldilocksField {
     GoldilocksField(t2)
 }
 
+/// Squares the base N number of times and multiplies the result by the tail value.
+#[inline(always)]
+fn exp_acc<const N: usize>(base: GoldilocksField, tail: GoldilocksField) -> GoldilocksField {
+    base.exp_power_of_2(N) * tail
+}
+
 #[cfg(test)]
 mod tests {
     use crate::{test_field_arithmetic, test_prime_field_arithmetic};
diff --git a/field/src/inversion.rs b/field/src/inversion.rs
deleted file mode 100644
index 45d17ab5..00000000
--- a/field/src/inversion.rs
+++ /dev/null
@@ -1,136 +0,0 @@
-use crate::types::PrimeField64;
-
-/// This is a 'safe' iteration for the modular inversion algorithm. It
-/// is safe in the sense that it will produce the right answer even
-/// when f + g >= 2^64.
-#[inline(always)]
-fn safe_iteration(f: &mut u64, g: &mut u64, c: &mut i128, d: &mut i128, k: &mut u32) {
-    if f < g {
-        core::mem::swap(f, g);
-        core::mem::swap(c, d);
-    }
-    if *f & 3 == *g & 3 {
-        // f - g = 0 (mod 4)
-        *f -= *g;
-        *c -= *d;
-
-        // kk >= 2 because f is now 0 (mod 4).
-        let kk = f.trailing_zeros();
-        *f >>= kk;
-        *d <<= kk;
-        *k += kk;
-    } else {
-        // f + g = 0 (mod 4)
-        *f = (*f >> 2) + (*g >> 2) + 1u64;
-        *c += *d;
-        let kk = f.trailing_zeros();
-        *f >>= kk;
-        *d <<= kk + 2;
-        *k += kk + 2;
-    }
-}
-
-/// This is an 'unsafe' iteration for the modular inversion
-/// algorithm. It is unsafe in the sense that it might produce the
-/// wrong answer if f + g >= 2^64.
-#[inline(always)]
-unsafe fn unsafe_iteration(f: &mut u64, g: &mut u64, c: &mut i128, d: &mut i128, k: &mut u32) {
-    if *f < *g {
-        core::mem::swap(f, g);
-        core::mem::swap(c, d);
-    }
-    if *f & 3 == *g & 3 {
-        // f - g = 0 (mod 4)
-        *f -= *g;
-        *c -= *d;
-    } else {
-        // f + g = 0 (mod 4)
-        *f += *g;
-        *c += *d;
-    }
-
-    // kk >= 2 because f is now 0 (mod 4).
-    let kk = f.trailing_zeros();
-    *f >>= kk;
-    *d <<= kk;
-    *k += kk;
-}
-
-/// Try to invert an element in a prime field.
-///
-/// The algorithm below is the "plus-minus-inversion" method
-/// with an "almost Montgomery inverse" flair. See Handbook of
-/// Elliptic and Hyperelliptic Cryptography, Algorithms 11.6
-/// and 11.12.
-#[allow(clippy::many_single_char_names)]
-pub(crate) fn try_inverse_u64<F: PrimeField64>(x: &F) -> Option<F> {
-    let mut f = x.to_noncanonical_u64();
-    let mut g = F::ORDER;
-    // NB: These two are very rarely such that their absolute
-    // value exceeds (p-1)/2; we are paying the price of i128 for
-    // the whole calculation, just for the times they do
-    // though. Measurements suggest a further 10% time saving if c
-    // and d could be replaced with i64's.
-    let mut c = 1i128;
-    let mut d = 0i128;
-
-    if f == 0 {
-        return None;
-    }
-
-    // f and g must always be odd.
-    let mut k = f.trailing_zeros();
-    f >>= k;
-    if f == 1 {
-        return Some(F::inverse_2exp(k as usize));
-    }
-
-    // The first two iterations are unrolled. This is to handle
-    // the case where f and g are both large and f+g can
-    // overflow. log2(max{f,g}) goes down by at least one each
-    // iteration though, so after two iterations we can be sure
-    // that f+g won't overflow.
-
-    // Iteration 1:
-    safe_iteration(&mut f, &mut g, &mut c, &mut d, &mut k);
-
-    if f == 1 {
-        // c must be -1 or 1 here.
-        if c == -1 {
-            return Some(-F::inverse_2exp(k as usize));
-        }
-        debug_assert!(c == 1, "bug in try_inverse_u64");
-        return Some(F::inverse_2exp(k as usize));
-    }
-
-    // Iteration 2:
-    safe_iteration(&mut f, &mut g, &mut c, &mut d, &mut k);
-
-    // Remaining iterations:
-    while f != 1 {
-        unsafe {
-            unsafe_iteration(&mut f, &mut g, &mut c, &mut d, &mut k);
-        }
-    }
-
-    // The following two loops adjust c so it's in the canonical range
-    // [0, F::ORDER).
-
-    // The maximum number of iterations observed here is 2; should
-    // prove this.
-    while c < 0 {
-        c += F::ORDER as i128;
-    }
-
-    // The maximum number of iterations observed here is 1; should
-    // prove this.
-    while c >= F::ORDER as i128 {
-        c -= F::ORDER as i128;
-    }
-
-    // Precomputing the binary inverses rather than using inverse_2exp
-    // saves ~5ns on my machine.
-    let res = F::from_canonical_u64(c as u64) * F::inverse_2exp(k as usize);
-    debug_assert!(*x * res == F::ONE, "bug in try_inverse_u64");
-    Some(res)
-}
diff --git a/field/src/lib.rs b/field/src/lib.rs
index 461b60c1..d0806bc8 100644
--- a/field/src/lib.rs
+++ b/field/src/lib.rs
@@ -9,8 +9,6 @@
 
 extern crate alloc;
 
-mod inversion;
-
 pub(crate) mod arch;
 
 pub mod batch_util;