Interleaved batched multiplicative inverse (#371)

* Interleaved batched multiplicative inverse

* Minor: typo
This commit is contained in:
Jakub Nabaglo 2021-11-23 21:36:12 -08:00 committed by GitHub
parent 1fed718a70
commit 549ce0d8e9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 144 additions and 24 deletions

View File

@ -112,6 +112,66 @@ pub(crate) fn bench_field<F: Field>(c: &mut Criterion) {
c.bench_function(&format!("try_inverse<{}>", type_name::<F>()), |b| {
b.iter_batched(|| F::rand(), |x| x.try_inverse(), BatchSize::SmallInput)
});
c.bench_function(
&format!("batch_multiplicative_inverse-tiny<{}>", type_name::<F>()),
|b| {
b.iter_batched(
|| (0..2).into_iter().map(|_| F::rand()).collect::<Vec<_>>(),
|x| F::batch_multiplicative_inverse(&x),
BatchSize::SmallInput,
)
},
);
c.bench_function(
&format!("batch_multiplicative_inverse-small<{}>", type_name::<F>()),
|b| {
b.iter_batched(
|| (0..4).into_iter().map(|_| F::rand()).collect::<Vec<_>>(),
|x| F::batch_multiplicative_inverse(&x),
BatchSize::SmallInput,
)
},
);
c.bench_function(
&format!("batch_multiplicative_inverse-medium<{}>", type_name::<F>()),
|b| {
b.iter_batched(
|| (0..16).into_iter().map(|_| F::rand()).collect::<Vec<_>>(),
|x| F::batch_multiplicative_inverse(&x),
BatchSize::SmallInput,
)
},
);
c.bench_function(
&format!("batch_multiplicative_inverse-large<{}>", type_name::<F>()),
|b| {
b.iter_batched(
|| (0..256).into_iter().map(|_| F::rand()).collect::<Vec<_>>(),
|x| F::batch_multiplicative_inverse(&x),
BatchSize::LargeInput,
)
},
);
c.bench_function(
&format!("batch_multiplicative_inverse-huge<{}>", type_name::<F>()),
|b| {
b.iter_batched(
|| {
(0..65536)
.into_iter()
.map(|_| F::rand())
.collect::<Vec<_>>()
},
|x| F::batch_multiplicative_inverse(&x),
BatchSize::LargeInput,
)
},
);
}
fn criterion_benchmark(c: &mut Criterion) {

View File

@ -13,12 +13,15 @@ macro_rules! test_field_arithmetic {
#[test]
fn batch_inversion() {
let xs = (1..=3)
.map(|i| <$field>::from_canonical_u64(i))
.collect::<Vec<_>>();
let invs = <$field>::batch_multiplicative_inverse(&xs);
for (x, inv) in xs.into_iter().zip(invs) {
assert_eq!(x * inv, <$field>::ONE);
for n in 0..20 {
let xs = (1..=n as u64)
.map(|i| <$field>::from_canonical_u64(i))
.collect::<Vec<_>>();
let invs = <$field>::batch_multiplicative_inverse(&xs);
assert_eq!(invs.len(), n);
for (x, inv) in xs.into_iter().zip(invs) {
assert_eq!(x * inv, <$field>::ONE);
}
}
}

View File

@ -102,34 +102,91 @@ pub trait Field:
// This is Montgomery's trick. At a high level, we invert the product of the given field
// elements, then derive the individual inverses from that via multiplication.
// The usual Montgomery trick involves calculating an array of cumulative products,
// resulting in a long dependency chain. To increase instruction-level parallelism, we
// compute WIDTH separate cumulative product arrays that only meet at the end.
// Higher WIDTH increases instruction-level parallelism, but too high a value will cause us
// to run out of registers.
const WIDTH: usize = 4;
// JN note: WIDTH is 4. The code is specialized to this value and will need
// modification if it is changed. I tried to make it more generic, but Rust's const
// generics are not yet good enough.
// Handle special cases. Paradoxically, below is repetitive but concise.
// The branches should be very predictable.
let n = x.len();
if n == 0 {
return Vec::new();
}
if n == 1 {
} else if n == 1 {
return vec![x[0].inverse()];
} else if n == 2 {
let x01 = x[0] * x[1];
let x01inv = x01.inverse();
return vec![x01inv * x[1], x01inv * x[0]];
} else if n == 3 {
let x01 = x[0] * x[1];
let x012 = x01 * x[2];
let x012inv = x012.inverse();
let x01inv = x012inv * x[2];
return vec![x01inv * x[1], x01inv * x[0], x012inv * x01];
}
debug_assert!(n >= WIDTH);
// Fill buf with cumulative product of x.
let mut buf = Vec::with_capacity(n);
let mut cumul_prod = x[0];
buf.push(cumul_prod);
for i in 1..n {
cumul_prod *= x[i];
buf.push(cumul_prod);
// Buf is reused for a few things to save allocations.
// Fill buf with cumulative product of x, only taking every 4th value. Concretely, buf will
// be [
// x[0], x[1], x[2], x[3],
// x[0] * x[4], x[1] * x[5], x[2] * x[6], x[3] * x[7],
// x[0] * x[4] * x[8], x[1] * x[5] * x[9], x[2] * x[6] * x[10], x[3] * x[7] * x[11],
// ...
// ].
// If n is not a multiple of WIDTH, the result is truncated from the end. For example,
// for n == 5, we get [x[0], x[1], x[2], x[3], x[0] * x[4]].
let mut buf: Vec<Self> = Vec::with_capacity(n);
// cumul_prod holds the last WIDTH elements of buf. This is redundant, but it's how we
// convince LLVM to keep the values in the registers.
let mut cumul_prod: [Self; WIDTH] = x[..WIDTH].try_into().unwrap();
buf.extend(cumul_prod);
for (i, &xi) in x[WIDTH..].iter().enumerate() {
cumul_prod[i % WIDTH] *= xi;
buf.push(cumul_prod[i % WIDTH]);
}
debug_assert_eq!(buf.len(), n);
// At this stage buf contains the the cumulative product of x. We reuse the buffer for
// efficiency. At the end of the loop, it is filled with inverses of x.
let mut a_inv = cumul_prod.inverse();
buf[n - 1] = buf[n - 2] * a_inv;
for i in (1..n - 1).rev() {
a_inv = x[i + 1] * a_inv;
// buf[i - 1] has not been written to by this loop, so it equals x[0] * ... x[n - 1].
buf[i] = buf[i - 1] * a_inv;
let mut a_inv = {
// This is where the four dependency chains meet.
// Take the last four elements of buf and invert them all.
let c01 = cumul_prod[0] * cumul_prod[1];
let c23 = cumul_prod[2] * cumul_prod[3];
let c0123 = c01 * c23;
let c0123inv = c0123.inverse();
let c01inv = c0123inv * c23;
let c23inv = c0123inv * c01;
[
c01inv * cumul_prod[1],
c01inv * cumul_prod[0],
c23inv * cumul_prod[3],
c23inv * cumul_prod[2],
]
};
for i in (WIDTH..n).rev() {
// buf[i - WIDTH] has not been written to by this loop, so it equals
// x[i % WIDTH] * x[i % WIDTH + WIDTH] * ... * x[i - WIDTH].
buf[i] = buf[i - WIDTH] * a_inv[i % WIDTH];
// buf[i] now holds the inverse of x[i].
a_inv[i % WIDTH] *= x[i];
}
buf[0] = x[1] * a_inv;
for i in (0..WIDTH).rev() {
buf[i] = a_inv[i];
}
for (&bi, &xi) in buf.iter().zip(x) {
// Sanity check only.
debug_assert_eq!(bi * xi, Self::ONE);
}
buf
}