From 8f59381c877ecde10eacb9607ded9437d6e02895 Mon Sep 17 00:00:00 2001 From: Hamish Ivey-Law <426294+unzvfu@users.noreply.github.com> Date: Sun, 10 Oct 2021 10:39:02 +1100 Subject: [PATCH] Faster modular inverse (#292) * Working "faster" inverse algo, using u128s. * Faster inverse_2exp for large exp. * More inverse tests. * Make f, g u64. * Comments. * Unroll first two iterations. * Fix bug and re-unroll first two iterations. * Simplify loop. * Refactoring and documentation. * Clean up testing. * Move inverse code to inversion.rs; use in GoldilocksField. * Bench quartic Goldilocks extension too. * cargo fmt * Add more documentation. * Address Jakub's comments. --- benches/field_arithmetic.rs | 1 + src/field/crandall_field.rs | 3 +- src/field/field_testing.rs | 14 --- src/field/field_types.rs | 41 +++++--- src/field/goldilocks_field.rs | 3 +- src/field/inversion.rs | 171 ++++++++++++++++++++++--------- src/field/prime_field_testing.rs | 27 ++++- 7 files changed, 175 insertions(+), 85 deletions(-) diff --git a/benches/field_arithmetic.rs b/benches/field_arithmetic.rs index f685465a..7990973c 100644 --- a/benches/field_arithmetic.rs +++ b/benches/field_arithmetic.rs @@ -93,6 +93,7 @@ fn criterion_benchmark(c: &mut Criterion) { bench_field::(c); bench_field::(c); bench_field::>(c); + bench_field::>(c); } criterion_group!(benches, criterion_benchmark); diff --git a/src/field/crandall_field.rs b/src/field/crandall_field.rs index 97abf5c8..26a436a5 100644 --- a/src/field/crandall_field.rs +++ b/src/field/crandall_field.rs @@ -79,8 +79,9 @@ impl Field for CrandallField { Self::ORDER.into() } + #[inline(always)] fn try_inverse(&self) -> Option { - try_inverse_u64(self.0, Self::ORDER).map(|inv| Self(inv)) + try_inverse_u64(self) } #[inline] diff --git a/src/field/field_testing.rs b/src/field/field_testing.rs index a1efa5f5..f422d810 100644 --- a/src/field/field_testing.rs +++ b/src/field/field_testing.rs @@ -81,20 +81,6 @@ macro_rules! 
test_field_arithmetic { assert_eq!(base.exp_biguint(&pow), base.exp_biguint(&big_pow)); assert_ne!(base.exp_biguint(&pow), base.exp_biguint(&big_pow_wrong)); } - - #[test] - fn inverse_2exp() { - // Just check consistency with try_inverse() - type F = $field; - - let v = ::PrimeField::TWO_ADICITY; - - for e in [0, 1, 2, 3, 4, v - 2, v - 1, v, v + 1, v + 2, 123 * v] { - let x = F::TWO.exp_u64(e as u64).inverse(); - let y = F::inverse_2exp(e); - assert_eq!(x, y); - } - } } }; } diff --git a/src/field/field_types.rs b/src/field/field_types.rs index 82f27d60..f23a95f3 100644 --- a/src/field/field_types.rs +++ b/src/field/field_types.rs @@ -127,25 +127,34 @@ pub trait Field: /// Compute the inverse of 2^exp in this field. #[inline] fn inverse_2exp(exp: usize) -> Self { + // The inverse of 2^exp is p-(p-1)/2^exp when char(F) = p and + // exp is at most the t=TWO_ADICITY of the prime field. When + // exp exceeds t, we repeatedly multiply by 2^-t and reduce + // exp until it's in the right range. + let p = Self::CHARACTERISTIC; - if exp <= Self::PrimeField::TWO_ADICITY { - // The inverse of 2^exp is p-(p-1)/2^exp when char(F) = p and exp is - // at most the TWO_ADICITY of the prime field. - // - // NB: PrimeFields fit in 64 bits => TWO_ADICITY < 64 => - // exp < 64 => this shift amount is legal. - Self::from_canonical_u64(p - ((p - 1) >> exp)) - } else { - // In the general case we compute 1/2 = (p+1)/2 and then exponentiate - // by exp to get 1/2^exp. Costs about log_2(exp) operations. - let half = Self::from_canonical_u64((p + 1) >> 1); - half.exp_u64(exp as u64) + // NB: The only reason this is split into two cases is to save + // the multiplication (and possible calculation of + // inverse_2_pow_adicity) in the usual case that exp <= + // TWO_ADICITY. Can remove the branch and simplify if that + // saving isn't worth it. - // TODO: Faster to combine several high powers of 1/2 using multiple - // applications of the trick above. E.g. 
if the 2-adicity is v, then - // compute 1/2^(v^2 + v + 13) with 1/2^((v + 1) * v + 13), etc. - // (using the v-adic expansion of m). Costs about log_v(exp) operations. + if exp > Self::PrimeField::TWO_ADICITY { + // NB: This should be a compile-time constant + let inverse_2_pow_adicity: Self = + Self::from_canonical_u64(p - ((p - 1) >> Self::PrimeField::TWO_ADICITY)); + + let mut res = inverse_2_pow_adicity; + let mut e = exp - Self::PrimeField::TWO_ADICITY; + + while e > Self::PrimeField::TWO_ADICITY { + res *= inverse_2_pow_adicity; + e -= Self::PrimeField::TWO_ADICITY; + } + res * Self::from_canonical_u64(p - ((p - 1) >> e)) + } else { + Self::from_canonical_u64(p - ((p - 1) >> exp)) } } diff --git a/src/field/goldilocks_field.rs b/src/field/goldilocks_field.rs index 0c970fcd..ac7b7816 100644 --- a/src/field/goldilocks_field.rs +++ b/src/field/goldilocks_field.rs @@ -84,8 +84,9 @@ impl Field for GoldilocksField { Self::ORDER.into() } + #[inline(always)] fn try_inverse(&self) -> Option { - try_inverse_u64(self.0, Self::ORDER).map(|inv| Self(inv)) + try_inverse_u64(self) } #[inline] diff --git a/src/field/inversion.rs b/src/field/inversion.rs index 8f66fbed..e3cca682 100644 --- a/src/field/inversion.rs +++ b/src/field/inversion.rs @@ -1,61 +1,136 @@ -use num::{Integer, Zero}; +use crate::field::field_types::PrimeField; -/// Try to invert an element in a prime field with the given modulus. -#[allow(clippy::many_single_char_names)] // The names are from the paper. -pub(crate) fn try_inverse_u64(x: u64, p: u64) -> Option { - if x.is_zero() { +/// This is a 'safe' iteration for the modular inversion algorithm. It +/// is safe in the sense that it will produce the right answer even +/// when f + g >= 2^64. 
+#[inline(always)] +fn safe_iteration(f: &mut u64, g: &mut u64, c: &mut i128, d: &mut i128, k: &mut u32) { + if f < g { + std::mem::swap(f, g); + std::mem::swap(c, d); + } + if *f & 3 == *g & 3 { + // f - g = 0 (mod 4) + *f -= *g; + *c -= *d; + + // kk >= 2 because f is now 0 (mod 4). + let kk = f.trailing_zeros(); + *f >>= kk; + *d <<= kk; + *k += kk; + } else { + // f + g = 0 (mod 4) + *f = (*f >> 2) + (*g >> 2) + 1u64; + *c += *d; + let kk = f.trailing_zeros(); + *f >>= kk; + *d <<= kk + 2; + *k += kk + 2; + } +} + +/// This is an 'unsafe' iteration for the modular inversion +/// algorithm. It is unsafe in the sense that it might produce the +/// wrong answer if f + g >= 2^64. +#[inline(always)] +unsafe fn unsafe_iteration(f: &mut u64, g: &mut u64, c: &mut i128, d: &mut i128, k: &mut u32) { + if *f < *g { + std::mem::swap(f, g); + std::mem::swap(c, d); + } + if *f & 3 == *g & 3 { + // f - g = 0 (mod 4) + *f -= *g; + *c -= *d; + } else { + // f + g = 0 (mod 4) + *f += *g; + *c += *d; + } + + // kk >= 2 because f is now 0 (mod 4). + let kk = f.trailing_zeros(); + *f >>= kk; + *d <<= kk; + *k += kk; +} + +/// Try to invert an element in a prime field. +/// +/// The algorithm below is the "plus-minus-inversion" method +/// with an "almost Montgomery inverse" flair. See Handbook of +/// Elliptic and Hyperelliptic Cryptography, Algorithms 11.6 +/// and 11.12. +#[allow(clippy::many_single_char_names)] +pub(crate) fn try_inverse_u64<F: PrimeField>(x: &F) -> Option<F> { + let mut f = x.to_noncanonical_u64(); + let mut g = F::ORDER; + // NB: The absolute values of c and d (declared just below) + // very rarely exceed (p-1)/2; we pay the price of i128 for + // the whole calculation just to cover the rare cases where + // they do. Measurements suggest a further 10% time saving if c + // and d could be replaced with i64's. 
+ let mut c = 1i128; + let mut d = 0i128; + + if f == 0 { return None; } - // Based on Algorithm 16 of "Efficient Software-Implementation of Finite Fields with - // Applications to Cryptography". + // f and g must always be odd. + let mut k = f.trailing_zeros(); + f >>= k; + if f == 1 { + return Some(F::inverse_2exp(k as usize)); + } - let mut u = x; - let mut v = p; - let mut b = 1u64; - let mut c = 0u64; + // The first two iterations are unrolled. This is to handle + // the case where f and g are both large and f+g can + // overflow. log2(max{f,g}) goes down by at least one each + // iteration though, so after two iterations we can be sure + // that f+g won't overflow. - while u != 1 && v != 1 { - let u_tz = u.trailing_zeros(); - u >>= u_tz; - for _ in 0..u_tz { - if b.is_even() { - b /= 2; - } else { - // b = (b + p)/2, avoiding overflow - b = (b / 2) + (p / 2) + 1; - } + // Iteration 1: + safe_iteration(&mut f, &mut g, &mut c, &mut d, &mut k); + + if f == 1 { + // c must be -1 or 1 here. + if c == -1 { + return Some(-F::inverse_2exp(k as usize)); } + debug_assert!(c == 1, "bug in try_inverse_u64"); + return Some(F::inverse_2exp(k as usize)); + } - let v_tz = v.trailing_zeros(); - v >>= v_tz; - for _ in 0..v_tz { - if c.is_even() { - c /= 2; - } else { - // c = (c + p)/2, avoiding overflow - c = (c / 2) + (p / 2) + 1; - } - } + // Iteration 2: + safe_iteration(&mut f, &mut g, &mut c, &mut d, &mut k); - if u >= v { - u -= v; - // b -= c - let (mut diff, under) = b.overflowing_sub(c); - if under { - diff = diff.wrapping_add(p); - } - b = diff; - } else { - v -= u; - // c -= b - let (mut diff, under) = c.overflowing_sub(b); - if under { - diff = diff.wrapping_add(p); - } - c = diff; + // Remaining iterations: + while f != 1 { + unsafe { + unsafe_iteration(&mut f, &mut g, &mut c, &mut d, &mut k); } } - Some(if u == 1 { b } else { c }) + // The following two loops adjust c so it's in the canonical range + // [0, F::ORDER). 
+ + // The maximum number of iterations observed here is 2; should + // prove this. + while c < 0 { + c += F::ORDER as i128; + } + + // The maximum number of iterations observed here is 1; should + // prove this. + while c >= F::ORDER as i128 { + c -= F::ORDER as i128; + } + + // Precomputing the binary inverses rather than calling inverse_2exp + // would save ~5ns (measured on the author's machine). + let res = F::from_canonical_u64(c as u64) * F::inverse_2exp(k as usize); + debug_assert!(*x * res == F::ONE, "bug in try_inverse_u64"); + Some(res) } diff --git a/src/field/prime_field_testing.rs b/src/field/prime_field_testing.rs index f6b88d0c..4febc3a8 100644 --- a/src/field/prime_field_testing.rs +++ b/src/field/prime_field_testing.rs @@ -125,14 +125,31 @@ macro_rules! test_prime_field_arithmetic { fn inversion() { let zero = <$field>::ZERO; let one = <$field>::ONE; - let order = <$field>::ORDER; + let modulus = <$field>::ORDER; assert_eq!(zero.try_inverse(), None); - for x in [1, 2, 3, order - 3, order - 2, order - 1] { - let x = <$field>::from_canonical_u64(x); - let inv = x.inverse(); - assert_eq!(x * inv, one); + let inputs = crate::field::prime_field_testing::test_inputs(modulus); + + for x in inputs { + if x != 0 { + let x = <$field>::from_canonical_u64(x); + let inv = x.inverse(); + assert_eq!(x * inv, one); + } + } + } + + #[test] + fn inverse_2exp() { + type F = $field; + + let v = <F as Field>::PrimeField::TWO_ADICITY; + + for e in [0, 1, 2, 3, 4, v - 2, v - 1, v, v + 1, v + 2, 123 * v] { + let x = F::TWO.exp_u64(e as u64); + let y = F::inverse_2exp(e); + assert_eq!(x * y, F::ONE); } }