From 549ce0d8e9f0a26e7ce810ba0620079a421c3177 Mon Sep 17 00:00:00 2001 From: Jakub Nabaglo Date: Tue, 23 Nov 2021 21:36:12 -0800 Subject: [PATCH] Interleaved batched multiplicative inverse (#371) * Interleaved batched multiplicative inverse * Minor: typo --- benches/field_arithmetic.rs | 60 ++++++++++++++++++++++++ src/field/field_testing.rs | 15 +++--- src/field/field_types.rs | 93 ++++++++++++++++++++++++++++++------- 3 files changed, 144 insertions(+), 24 deletions(-) diff --git a/benches/field_arithmetic.rs b/benches/field_arithmetic.rs index ebecb871..2fb4a24b 100644 --- a/benches/field_arithmetic.rs +++ b/benches/field_arithmetic.rs @@ -112,6 +112,66 @@ pub(crate) fn bench_field<F: Field>(c: &mut Criterion) { c.bench_function(&format!("try_inverse<{}>", type_name::<F>()), |b| { b.iter_batched(|| F::rand(), |x| x.try_inverse(), BatchSize::SmallInput) }); + + c.bench_function( + &format!("batch_multiplicative_inverse-tiny<{}>", type_name::<F>()), + |b| { + b.iter_batched( + || (0..2).into_iter().map(|_| F::rand()).collect::<Vec<_>>(), + |x| F::batch_multiplicative_inverse(&x), + BatchSize::SmallInput, + ) + }, + ); + + c.bench_function( + &format!("batch_multiplicative_inverse-small<{}>", type_name::<F>()), + |b| { + b.iter_batched( + || (0..4).into_iter().map(|_| F::rand()).collect::<Vec<_>>(), + |x| F::batch_multiplicative_inverse(&x), + BatchSize::SmallInput, + ) + }, + ); + + c.bench_function( + &format!("batch_multiplicative_inverse-medium<{}>", type_name::<F>()), + |b| { + b.iter_batched( + || (0..16).into_iter().map(|_| F::rand()).collect::<Vec<_>>(), + |x| F::batch_multiplicative_inverse(&x), + BatchSize::SmallInput, + ) + }, + ); + + c.bench_function( + &format!("batch_multiplicative_inverse-large<{}>", type_name::<F>()), + |b| { + b.iter_batched( + || (0..256).into_iter().map(|_| F::rand()).collect::<Vec<_>>(), + |x| F::batch_multiplicative_inverse(&x), + BatchSize::LargeInput, + ) + }, + ); + + c.bench_function( + &format!("batch_multiplicative_inverse-huge<{}>", type_name::<F>()), + |b| { +
b.iter_batched( + || { + (0..65536) + .into_iter() + .map(|_| F::rand()) + .collect::<Vec<_>>() + }, + |x| F::batch_multiplicative_inverse(&x), + BatchSize::LargeInput, + ) + }, + ); } fn criterion_benchmark(c: &mut Criterion) { diff --git a/src/field/field_testing.rs b/src/field/field_testing.rs index f422d810..d15b712a 100644 --- a/src/field/field_testing.rs +++ b/src/field/field_testing.rs @@ -13,12 +13,15 @@ macro_rules! test_field_arithmetic { #[test] fn batch_inversion() { - let xs = (1..=3) - .map(|i| <$field>::from_canonical_u64(i)) - .collect::<Vec<_>>(); - let invs = <$field>::batch_multiplicative_inverse(&xs); - for (x, inv) in xs.into_iter().zip(invs) { - assert_eq!(x * inv, <$field>::ONE); + for n in 0..20 { + let xs = (1..=n as u64) + .map(|i| <$field>::from_canonical_u64(i)) + .collect::<Vec<_>>(); + let invs = <$field>::batch_multiplicative_inverse(&xs); + assert_eq!(invs.len(), n); + for (x, inv) in xs.into_iter().zip(invs) { + assert_eq!(x * inv, <$field>::ONE); + } } } diff --git a/src/field/field_types.rs b/src/field/field_types.rs index bbb1604e..45839bd9 100644 --- a/src/field/field_types.rs +++ b/src/field/field_types.rs @@ -102,34 +102,91 @@ pub trait Field: // This is Montgomery's trick. At a high level, we invert the product of the given field // elements, then derive the individual inverses from that via multiplication. + // The usual Montgomery trick involves calculating an array of cumulative products, + // resulting in a long dependency chain. To increase instruction-level parallelism, we + // compute WIDTH separate cumulative product arrays that only meet at the end. + + // Higher WIDTH increases instruction-level parallelism, but too high a value will cause us + // to run out of registers. + const WIDTH: usize = 4; + // JN note: WIDTH is 4. The code is specialized to this value and will need + // modification if it is changed. I tried to make it more generic, but Rust's const + // generics are not yet good enough. + + // Handle special cases.
Paradoxically, below is repetitive but concise. + // The branches should be very predictable. let n = x.len(); if n == 0 { return Vec::new(); - } - if n == 1 { + } else if n == 1 { return vec![x[0].inverse()]; + } else if n == 2 { + let x01 = x[0] * x[1]; + let x01inv = x01.inverse(); + return vec![x01inv * x[1], x01inv * x[0]]; + } else if n == 3 { + let x01 = x[0] * x[1]; + let x012 = x01 * x[2]; + let x012inv = x012.inverse(); + let x01inv = x012inv * x[2]; + return vec![x01inv * x[1], x01inv * x[0], x012inv * x01]; } + debug_assert!(n >= WIDTH); - // Fill buf with cumulative product of x. - let mut buf = Vec::with_capacity(n); - let mut cumul_prod = x[0]; - buf.push(cumul_prod); - for i in 1..n { - cumul_prod *= x[i]; - buf.push(cumul_prod); + // Buf is reused for a few things to save allocations. + // Fill buf with cumulative product of x, only taking every 4th value. Concretely, buf will + // be [ + // x[0], x[1], x[2], x[3], + // x[0] * x[4], x[1] * x[5], x[2] * x[6], x[3] * x[7], + // x[0] * x[4] * x[8], x[1] * x[5] * x[9], x[2] * x[6] * x[10], x[3] * x[7] * x[11], + // ... + // ]. + // If n is not a multiple of WIDTH, the result is truncated from the end. For example, + // for n == 5, we get [x[0], x[1], x[2], x[3], x[0] * x[4]]. + let mut buf: Vec<Self> = Vec::with_capacity(n); + // cumul_prod holds the last WIDTH elements of buf. This is redundant, but it's how we + // convince LLVM to keep the values in the registers. + let mut cumul_prod: [Self; WIDTH] = x[..WIDTH].try_into().unwrap(); + buf.extend(cumul_prod); + for (i, &xi) in x[WIDTH..].iter().enumerate() { + cumul_prod[i % WIDTH] *= xi; + buf.push(cumul_prod[i % WIDTH]); } + debug_assert_eq!(buf.len(), n); - // At this stage buf contains the the cumulative product of x. We reuse the buffer for - // efficiency. At the end of the loop, it is filled with inverses of x.
- let mut a_inv = cumul_prod.inverse(); - buf[n - 1] = buf[n - 2] * a_inv; - for i in (1..n - 1).rev() { - a_inv = x[i + 1] * a_inv; - // buf[i - 1] has not been written to by this loop, so it equals x[0] * ... x[n - 1]. - buf[i] = buf[i - 1] * a_inv; + let mut a_inv = { + // This is where the four dependency chains meet. + // Take the last four elements of buf and invert them all. + let c01 = cumul_prod[0] * cumul_prod[1]; + let c23 = cumul_prod[2] * cumul_prod[3]; + let c0123 = c01 * c23; + let c0123inv = c0123.inverse(); + let c01inv = c0123inv * c23; + let c23inv = c0123inv * c01; + [ + c01inv * cumul_prod[1], + c01inv * cumul_prod[0], + c23inv * cumul_prod[3], + c23inv * cumul_prod[2], + ] + }; + + for i in (WIDTH..n).rev() { + // buf[i - WIDTH] has not been written to by this loop, so it equals + // x[i % WIDTH] * x[i % WIDTH + WIDTH] * ... * x[i - WIDTH]. + buf[i] = buf[i - WIDTH] * a_inv[i % WIDTH]; // buf[i] now holds the inverse of x[i]. + a_inv[i % WIDTH] *= x[i]; } - buf[0] = x[1] * a_inv; + for i in (0..WIDTH).rev() { + buf[i] = a_inv[i]; + } + + for (&bi, &xi) in buf.iter().zip(x) { + // Sanity check only. + debug_assert_eq!(bi * xi, Self::ONE); + } + buf }