mirror of
https://github.com/logos-storage/plonky2.git
synced 2026-01-03 14:23:07 +00:00
Faster modular inverse (#292)
* Working "faster" inverse algo, using u128s. * Faster inverse_2exp for large exp. * More inverse tests. * Make f, g u64. * Comments. * Unroll first two iterations. * Fix bug and re-unroll first two iterations. * Simplify loop. * Refactoring and documentation. * Clean up testing. * Move inverse code to inversion.rs; use in GoldilocksField. * Bench quartic Goldilocks extension too. * cargo fmt * Add more documentation. * Address Jakub's comments.
This commit is contained in:
parent
dc600d5abf
commit
8f59381c87
@ -93,6 +93,7 @@ fn criterion_benchmark(c: &mut Criterion) {
|
||||
bench_field::<CrandallField>(c);
|
||||
bench_field::<GoldilocksField>(c);
|
||||
bench_field::<QuarticExtension<CrandallField>>(c);
|
||||
bench_field::<QuarticExtension<GoldilocksField>>(c);
|
||||
}
|
||||
|
||||
criterion_group!(benches, criterion_benchmark);
|
||||
|
||||
@ -79,8 +79,9 @@ impl Field for CrandallField {
|
||||
Self::ORDER.into()
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
fn try_inverse(&self) -> Option<Self> {
|
||||
try_inverse_u64(self.0, Self::ORDER).map(|inv| Self(inv))
|
||||
try_inverse_u64(self)
|
||||
}
|
||||
|
||||
#[inline]
|
||||
|
||||
@ -81,20 +81,6 @@ macro_rules! test_field_arithmetic {
|
||||
assert_eq!(base.exp_biguint(&pow), base.exp_biguint(&big_pow));
|
||||
assert_ne!(base.exp_biguint(&pow), base.exp_biguint(&big_pow_wrong));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn inverse_2exp() {
|
||||
// Just check consistency with try_inverse()
|
||||
type F = $field;
|
||||
|
||||
let v = <F as Field>::PrimeField::TWO_ADICITY;
|
||||
|
||||
for e in [0, 1, 2, 3, 4, v - 2, v - 1, v, v + 1, v + 2, 123 * v] {
|
||||
let x = F::TWO.exp_u64(e as u64).inverse();
|
||||
let y = F::inverse_2exp(e);
|
||||
assert_eq!(x, y);
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
@ -127,25 +127,34 @@ pub trait Field:
|
||||
/// Compute the inverse of 2^exp in this field.
|
||||
#[inline]
|
||||
fn inverse_2exp(exp: usize) -> Self {
|
||||
// The inverse of 2^exp is p-(p-1)/2^exp when char(F) = p and
|
||||
// exp is at most the t=TWO_ADICITY of the prime field. When
|
||||
// exp exceeds t, we repeatedly multiply by 2^-t and reduce
|
||||
// exp until it's in the right range.
|
||||
|
||||
let p = Self::CHARACTERISTIC;
|
||||
|
||||
if exp <= Self::PrimeField::TWO_ADICITY {
|
||||
// The inverse of 2^exp is p-(p-1)/2^exp when char(F) = p and exp is
|
||||
// at most the TWO_ADICITY of the prime field.
|
||||
//
|
||||
// NB: PrimeFields fit in 64 bits => TWO_ADICITY < 64 =>
|
||||
// exp < 64 => this shift amount is legal.
|
||||
Self::from_canonical_u64(p - ((p - 1) >> exp))
|
||||
} else {
|
||||
// In the general case we compute 1/2 = (p+1)/2 and then exponentiate
|
||||
// by exp to get 1/2^exp. Costs about log_2(exp) operations.
|
||||
let half = Self::from_canonical_u64((p + 1) >> 1);
|
||||
half.exp_u64(exp as u64)
|
||||
// NB: The only reason this is split into two cases is to save
|
||||
// the multiplication (and possible calculation of
|
||||
// inverse_2_pow_adicity) in the usual case that exp <=
|
||||
// TWO_ADICITY. Can remove the branch and simplify if that
|
||||
// saving isn't worth it.
|
||||
|
||||
// TODO: Faster to combine several high powers of 1/2 using multiple
|
||||
// applications of the trick above. E.g. if the 2-adicity is v, then
|
||||
// compute 1/2^(v^2 + v + 13) with 1/2^((v + 1) * v + 13), etc.
|
||||
// (using the v-adic expansion of m). Costs about log_v(exp) operations.
|
||||
if exp > Self::PrimeField::TWO_ADICITY {
|
||||
// NB: This should be a compile-time constant
|
||||
let inverse_2_pow_adicity: Self =
|
||||
Self::from_canonical_u64(p - ((p - 1) >> Self::PrimeField::TWO_ADICITY));
|
||||
|
||||
let mut res = inverse_2_pow_adicity;
|
||||
let mut e = exp - Self::PrimeField::TWO_ADICITY;
|
||||
|
||||
while e > Self::PrimeField::TWO_ADICITY {
|
||||
res *= inverse_2_pow_adicity;
|
||||
e -= Self::PrimeField::TWO_ADICITY;
|
||||
}
|
||||
res * Self::from_canonical_u64(p - ((p - 1) >> e))
|
||||
} else {
|
||||
Self::from_canonical_u64(p - ((p - 1) >> exp))
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@ -84,8 +84,9 @@ impl Field for GoldilocksField {
|
||||
Self::ORDER.into()
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
fn try_inverse(&self) -> Option<Self> {
|
||||
try_inverse_u64(self.0, Self::ORDER).map(|inv| Self(inv))
|
||||
try_inverse_u64(self)
|
||||
}
|
||||
|
||||
#[inline]
|
||||
|
||||
@ -1,61 +1,136 @@
|
||||
use num::{Integer, Zero};
|
||||
use crate::field::field_types::PrimeField;
|
||||
|
||||
/// Try to invert an element in a prime field with the given modulus.
|
||||
#[allow(clippy::many_single_char_names)] // The names are from the paper.
|
||||
pub(crate) fn try_inverse_u64(x: u64, p: u64) -> Option<u64> {
|
||||
if x.is_zero() {
|
||||
/// One 'safe' step of the modular inversion loop.
///
/// It is safe in the sense that it produces the right answer even when
/// `f + g >= 2^64`, because the sum is never formed directly.
///
/// The pair is first ordered so that `f >= g` (swapping `c` and `d` along
/// with it). Then `f` is replaced by the odd part of whichever of `f - g`
/// or `f + g` is divisible by 4, `c` receives the matching `c - d` or
/// `c + d`, `d` is shifted left by the number of bits removed from `f`,
/// and `k` is increased by that same bit count.
///
/// `f` and `g` must be odd and distinct on entry. If they were equal the
/// difference would be zero and the shift by `trailing_zeros()` (64) would
/// overflow; debug builds check the precondition explicitly.
#[inline(always)]
fn safe_iteration(f: &mut u64, g: &mut u64, c: &mut i128, d: &mut i128, k: &mut u32) {
    debug_assert!(
        *f & 1 == 1 && *g & 1 == 1 && *f != *g,
        "safe_iteration requires distinct odd f and g"
    );

    if *f < *g {
        std::mem::swap(f, g);
        std::mem::swap(c, d);
    }

    if *f & 3 == *g & 3 {
        // f - g = 0 (mod 4)
        *f -= *g;
        *c -= *d;

        // kk >= 2 because f is now a nonzero multiple of 4.
        let kk = f.trailing_zeros();
        *f >>= kk;
        *d <<= kk;
        *k += kk;
    } else {
        // f + g = 0 (mod 4)
        //
        // Compute (f + g) / 4 without overflow: f and g are odd with
        // different residues mod 4, so their low two bits are 1 and 3 and
        // add up to exactly 4, which is the trailing `+ 1`.
        *f = (*f >> 2) + (*g >> 2) + 1u64;
        *c += *d;

        // Two bits were already removed by the division by 4 above, so the
        // total shift applied to d and added to k is kk + 2.
        let kk = f.trailing_zeros();
        *f >>= kk;
        *d <<= kk + 2;
        *k += kk + 2;
    }
}
|
||||
|
||||
/// One 'unsafe' step of the modular inversion loop.
///
/// It is unsafe in the sense that it might produce the wrong answer if
/// `f + g >= 2^64`: unlike the safe variant, the sum `f + g` is formed
/// directly in a `u64`.
///
/// The pair is first ordered so that `f >= g` (swapping `c` and `d` along
/// with it). Then `f` is replaced by the odd part of whichever of `f - g`
/// or `f + g` is divisible by 4, `c` receives the matching `c - d` or
/// `c + d`, `d` is shifted left by the number of bits removed from `f`,
/// and `k` is increased by that same bit count.
///
/// # Safety
///
/// This function performs no memory-unsafe operation; the `unsafe` marker
/// flags an arithmetic contract. The caller must ensure that:
///
/// * `f` and `g` are odd and distinct, and
/// * `f + g` fits in a `u64`.
///
/// Debug builds check both conditions.
#[inline(always)]
unsafe fn unsafe_iteration(f: &mut u64, g: &mut u64, c: &mut i128, d: &mut i128, k: &mut u32) {
    debug_assert!(
        *f & 1 == 1 && *g & 1 == 1 && *f != *g,
        "unsafe_iteration requires distinct odd f and g"
    );

    if *f < *g {
        std::mem::swap(f, g);
        std::mem::swap(c, d);
    }

    if *f & 3 == *g & 3 {
        // f - g = 0 (mod 4)
        *f -= *g;
        *c -= *d;
    } else {
        // f + g = 0 (mod 4)
        debug_assert!(
            f.checked_add(*g).is_some(),
            "unsafe_iteration requires f + g < 2^64"
        );
        *f += *g;
        *c += *d;
    }

    // kk >= 2 because f is now a nonzero multiple of 4.
    let kk = f.trailing_zeros();
    *f >>= kk;
    *d <<= kk;
    *k += kk;
}
|
||||
|
||||
/// Try to invert an element in a prime field.
|
||||
///
|
||||
/// The algorithm below is the "plus-minus-inversion" method
|
||||
/// with an "almost Montgomery inverse" flair. See Handbook of
|
||||
/// Elliptic and Hyperelliptic Cryptography, Algorithms 11.6
|
||||
/// and 11.12.
|
||||
#[allow(clippy::many_single_char_names)]
|
||||
pub(crate) fn try_inverse_u64<F: PrimeField>(x: &F) -> Option<F> {
|
||||
let mut f = x.to_noncanonical_u64();
|
||||
let mut g = F::ORDER;
|
||||
// NB: These two are very rarely such that their absolute
|
||||
// value exceeds (p-1)/2; we are paying the price of i128 for
|
||||
// the whole calculation, just for the times they do
|
||||
// though. Measurements suggest a further 10% time saving if c
|
||||
// and d could be replaced with i64's.
|
||||
let mut c = 1i128;
|
||||
let mut d = 0i128;
|
||||
|
||||
if f == 0 {
|
||||
return None;
|
||||
}
|
||||
|
||||
// Based on Algorithm 16 of "Efficient Software-Implementation of Finite Fields with
|
||||
// Applications to Cryptography".
|
||||
// f and g must always be odd.
|
||||
let mut k = f.trailing_zeros();
|
||||
f >>= k;
|
||||
if f == 1 {
|
||||
return Some(F::inverse_2exp(k as usize));
|
||||
}
|
||||
|
||||
let mut u = x;
|
||||
let mut v = p;
|
||||
let mut b = 1u64;
|
||||
let mut c = 0u64;
|
||||
// The first two iterations are unrolled. This is to handle
|
||||
// the case where f and g are both large and f+g can
|
||||
// overflow. log2(max{f,g}) goes down by at least one each
|
||||
// iteration though, so after two iterations we can be sure
|
||||
// that f+g won't overflow.
|
||||
|
||||
while u != 1 && v != 1 {
|
||||
let u_tz = u.trailing_zeros();
|
||||
u >>= u_tz;
|
||||
for _ in 0..u_tz {
|
||||
if b.is_even() {
|
||||
b /= 2;
|
||||
} else {
|
||||
// b = (b + p)/2, avoiding overflow
|
||||
b = (b / 2) + (p / 2) + 1;
|
||||
}
|
||||
// Iteration 1:
|
||||
safe_iteration(&mut f, &mut g, &mut c, &mut d, &mut k);
|
||||
|
||||
if f == 1 {
|
||||
// c must be -1 or 1 here.
|
||||
if c == -1 {
|
||||
return Some(-F::inverse_2exp(k as usize));
|
||||
}
|
||||
debug_assert!(c == 1, "bug in try_inverse_u64");
|
||||
return Some(F::inverse_2exp(k as usize));
|
||||
}
|
||||
|
||||
let v_tz = v.trailing_zeros();
|
||||
v >>= v_tz;
|
||||
for _ in 0..v_tz {
|
||||
if c.is_even() {
|
||||
c /= 2;
|
||||
} else {
|
||||
// c = (c + p)/2, avoiding overflow
|
||||
c = (c / 2) + (p / 2) + 1;
|
||||
}
|
||||
}
|
||||
// Iteration 2:
|
||||
safe_iteration(&mut f, &mut g, &mut c, &mut d, &mut k);
|
||||
|
||||
if u >= v {
|
||||
u -= v;
|
||||
// b -= c
|
||||
let (mut diff, under) = b.overflowing_sub(c);
|
||||
if under {
|
||||
diff = diff.wrapping_add(p);
|
||||
}
|
||||
b = diff;
|
||||
} else {
|
||||
v -= u;
|
||||
// c -= b
|
||||
let (mut diff, under) = c.overflowing_sub(b);
|
||||
if under {
|
||||
diff = diff.wrapping_add(p);
|
||||
}
|
||||
c = diff;
|
||||
// Remaining iterations:
|
||||
while f != 1 {
|
||||
unsafe {
|
||||
unsafe_iteration(&mut f, &mut g, &mut c, &mut d, &mut k);
|
||||
}
|
||||
}
|
||||
|
||||
Some(if u == 1 { b } else { c })
|
||||
// The following two loops adjust c so it's in the canonical range
|
||||
// [0, F::ORDER).
|
||||
|
||||
// The maximum number of iterations observed here is 2; should
|
||||
// prove this.
|
||||
while c < 0 {
|
||||
c += F::ORDER as i128;
|
||||
}
|
||||
|
||||
// The maximum number of iterations observed here is 1; should
|
||||
// prove this.
|
||||
while c >= F::ORDER as i128 {
|
||||
c -= F::ORDER as i128;
|
||||
}
|
||||
|
||||
// Precomputing the binary inverses rather than using inverse_2exp
|
||||
// saves ~5ns on my machine.
|
||||
let res = F::from_canonical_u64(c as u64) * F::inverse_2exp(k as usize);
|
||||
debug_assert!(*x * res == F::ONE, "bug in try_inverse_u64");
|
||||
Some(res)
|
||||
}
|
||||
|
||||
@ -125,14 +125,31 @@ macro_rules! test_prime_field_arithmetic {
|
||||
fn inversion() {
|
||||
let zero = <$field>::ZERO;
|
||||
let one = <$field>::ONE;
|
||||
let order = <$field>::ORDER;
|
||||
let modulus = <$field>::ORDER;
|
||||
|
||||
assert_eq!(zero.try_inverse(), None);
|
||||
|
||||
for x in [1, 2, 3, order - 3, order - 2, order - 1] {
|
||||
let x = <$field>::from_canonical_u64(x);
|
||||
let inv = x.inverse();
|
||||
assert_eq!(x * inv, one);
|
||||
let inputs = crate::field::prime_field_testing::test_inputs(modulus);
|
||||
|
||||
for x in inputs {
|
||||
if x != 0 {
|
||||
let x = <$field>::from_canonical_u64(x);
|
||||
let inv = x.inverse();
|
||||
assert_eq!(x * inv, one);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn inverse_2exp() {
|
||||
type F = $field;
|
||||
|
||||
let v = <F as Field>::PrimeField::TWO_ADICITY;
|
||||
|
||||
for e in [0, 1, 2, 3, 4, v - 2, v - 1, v, v + 1, v + 2, 123 * v] {
|
||||
let x = F::TWO.exp_u64(e as u64);
|
||||
let y = F::inverse_2exp(e);
|
||||
assert_eq!(x * y, F::ONE);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user