mirror of
https://github.com/logos-storage/plonky2.git
synced 2026-01-06 07:43:10 +00:00
Interleaved batched multiplicative inverse (#371)
* Interleaved batched multiplicative inverse * Minor: typo
This commit is contained in:
parent
1fed718a70
commit
549ce0d8e9
@ -112,6 +112,66 @@ pub(crate) fn bench_field<F: Field>(c: &mut Criterion) {
|
||||
c.bench_function(&format!("try_inverse<{}>", type_name::<F>()), |b| {
|
||||
b.iter_batched(|| F::rand(), |x| x.try_inverse(), BatchSize::SmallInput)
|
||||
});
|
||||
|
||||
c.bench_function(
|
||||
&format!("batch_multiplicative_inverse-tiny<{}>", type_name::<F>()),
|
||||
|b| {
|
||||
b.iter_batched(
|
||||
|| (0..2).into_iter().map(|_| F::rand()).collect::<Vec<_>>(),
|
||||
|x| F::batch_multiplicative_inverse(&x),
|
||||
BatchSize::SmallInput,
|
||||
)
|
||||
},
|
||||
);
|
||||
|
||||
c.bench_function(
|
||||
&format!("batch_multiplicative_inverse-small<{}>", type_name::<F>()),
|
||||
|b| {
|
||||
b.iter_batched(
|
||||
|| (0..4).into_iter().map(|_| F::rand()).collect::<Vec<_>>(),
|
||||
|x| F::batch_multiplicative_inverse(&x),
|
||||
BatchSize::SmallInput,
|
||||
)
|
||||
},
|
||||
);
|
||||
|
||||
c.bench_function(
|
||||
&format!("batch_multiplicative_inverse-medium<{}>", type_name::<F>()),
|
||||
|b| {
|
||||
b.iter_batched(
|
||||
|| (0..16).into_iter().map(|_| F::rand()).collect::<Vec<_>>(),
|
||||
|x| F::batch_multiplicative_inverse(&x),
|
||||
BatchSize::SmallInput,
|
||||
)
|
||||
},
|
||||
);
|
||||
|
||||
c.bench_function(
|
||||
&format!("batch_multiplicative_inverse-large<{}>", type_name::<F>()),
|
||||
|b| {
|
||||
b.iter_batched(
|
||||
|| (0..256).into_iter().map(|_| F::rand()).collect::<Vec<_>>(),
|
||||
|x| F::batch_multiplicative_inverse(&x),
|
||||
BatchSize::LargeInput,
|
||||
)
|
||||
},
|
||||
);
|
||||
|
||||
c.bench_function(
|
||||
&format!("batch_multiplicative_inverse-huge<{}>", type_name::<F>()),
|
||||
|b| {
|
||||
b.iter_batched(
|
||||
|| {
|
||||
(0..65536)
|
||||
.into_iter()
|
||||
.map(|_| F::rand())
|
||||
.collect::<Vec<_>>()
|
||||
},
|
||||
|x| F::batch_multiplicative_inverse(&x),
|
||||
BatchSize::LargeInput,
|
||||
)
|
||||
},
|
||||
);
|
||||
}
|
||||
|
||||
fn criterion_benchmark(c: &mut Criterion) {
|
||||
|
||||
@ -13,14 +13,17 @@ macro_rules! test_field_arithmetic {
|
||||
|
||||
#[test]
|
||||
fn batch_inversion() {
|
||||
let xs = (1..=3)
|
||||
for n in 0..20 {
|
||||
let xs = (1..=n as u64)
|
||||
.map(|i| <$field>::from_canonical_u64(i))
|
||||
.collect::<Vec<_>>();
|
||||
let invs = <$field>::batch_multiplicative_inverse(&xs);
|
||||
assert_eq!(invs.len(), n);
|
||||
for (x, inv) in xs.into_iter().zip(invs) {
|
||||
assert_eq!(x * inv, <$field>::ONE);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn primitive_root_order() {
|
||||
|
||||
@ -102,34 +102,91 @@ pub trait Field:
|
||||
// This is Montgomery's trick. At a high level, we invert the product of the given field
|
||||
// elements, then derive the individual inverses from that via multiplication.
|
||||
|
||||
// The usual Montgomery trick involves calculating an array of cumulative products,
|
||||
// resulting in a long dependency chain. To increase instruction-level parallelism, we
|
||||
// compute WIDTH separate cumulative product arrays that only meet at the end.
|
||||
|
||||
// Higher WIDTH increases instruction-level parallelism, but too high a value will cause us
|
||||
// to run out of registers.
|
||||
const WIDTH: usize = 4;
|
||||
// JN note: WIDTH is 4. The code is specialized to this value and will need
|
||||
// modification if it is changed. I tried to make it more generic, but Rust's const
|
||||
// generics are not yet good enough.
|
||||
|
||||
// Handle special cases. Paradoxically, below is repetitive but concise.
|
||||
// The branches should be very predictable.
|
||||
let n = x.len();
|
||||
if n == 0 {
|
||||
return Vec::new();
|
||||
}
|
||||
if n == 1 {
|
||||
} else if n == 1 {
|
||||
return vec![x[0].inverse()];
|
||||
} else if n == 2 {
|
||||
let x01 = x[0] * x[1];
|
||||
let x01inv = x01.inverse();
|
||||
return vec![x01inv * x[1], x01inv * x[0]];
|
||||
} else if n == 3 {
|
||||
let x01 = x[0] * x[1];
|
||||
let x012 = x01 * x[2];
|
||||
let x012inv = x012.inverse();
|
||||
let x01inv = x012inv * x[2];
|
||||
return vec![x01inv * x[1], x01inv * x[0], x012inv * x01];
|
||||
}
|
||||
debug_assert!(n >= WIDTH);
|
||||
|
||||
// Fill buf with cumulative product of x.
|
||||
let mut buf = Vec::with_capacity(n);
|
||||
let mut cumul_prod = x[0];
|
||||
buf.push(cumul_prod);
|
||||
for i in 1..n {
|
||||
cumul_prod *= x[i];
|
||||
buf.push(cumul_prod);
|
||||
// Buf is reused for a few things to save allocations.
|
||||
// Fill buf with cumulative product of x, only taking every 4th value. Concretely, buf will
|
||||
// be [
|
||||
// x[0], x[1], x[2], x[3],
|
||||
// x[0] * x[4], x[1] * x[5], x[2] * x[6], x[3] * x[7],
|
||||
// x[0] * x[4] * x[8], x[1] * x[5] * x[9], x[2] * x[6] * x[10], x[3] * x[7] * x[11],
|
||||
// ...
|
||||
// ].
|
||||
// If n is not a multiple of WIDTH, the result is truncated from the end. For example,
|
||||
// for n == 5, we get [x[0], x[1], x[2], x[3], x[0] * x[4]].
|
||||
let mut buf: Vec<Self> = Vec::with_capacity(n);
|
||||
// cumul_prod holds the last WIDTH elements of buf. This is redundant, but it's how we
|
||||
// convince LLVM to keep the values in the registers.
|
||||
let mut cumul_prod: [Self; WIDTH] = x[..WIDTH].try_into().unwrap();
|
||||
buf.extend(cumul_prod);
|
||||
for (i, &xi) in x[WIDTH..].iter().enumerate() {
|
||||
cumul_prod[i % WIDTH] *= xi;
|
||||
buf.push(cumul_prod[i % WIDTH]);
|
||||
}
|
||||
debug_assert_eq!(buf.len(), n);
|
||||
|
||||
// At this stage buf contains the cumulative product of x. We reuse the buffer for
|
||||
// efficiency. At the end of the loop, it is filled with inverses of x.
|
||||
let mut a_inv = cumul_prod.inverse();
|
||||
buf[n - 1] = buf[n - 2] * a_inv;
|
||||
for i in (1..n - 1).rev() {
|
||||
a_inv = x[i + 1] * a_inv;
|
||||
// buf[i - 1] has not been written to by this loop, so it equals x[0] * ... * x[i - 1].
|
||||
buf[i] = buf[i - 1] * a_inv;
|
||||
let mut a_inv = {
|
||||
// This is where the four dependency chains meet.
|
||||
// Take the last four elements of buf and invert them all.
|
||||
let c01 = cumul_prod[0] * cumul_prod[1];
|
||||
let c23 = cumul_prod[2] * cumul_prod[3];
|
||||
let c0123 = c01 * c23;
|
||||
let c0123inv = c0123.inverse();
|
||||
let c01inv = c0123inv * c23;
|
||||
let c23inv = c0123inv * c01;
|
||||
[
|
||||
c01inv * cumul_prod[1],
|
||||
c01inv * cumul_prod[0],
|
||||
c23inv * cumul_prod[3],
|
||||
c23inv * cumul_prod[2],
|
||||
]
|
||||
};
|
||||
|
||||
for i in (WIDTH..n).rev() {
|
||||
// buf[i - WIDTH] has not been written to by this loop, so it equals
|
||||
// x[i % WIDTH] * x[i % WIDTH + WIDTH] * ... * x[i - WIDTH].
|
||||
buf[i] = buf[i - WIDTH] * a_inv[i % WIDTH];
|
||||
// buf[i] now holds the inverse of x[i].
|
||||
a_inv[i % WIDTH] *= x[i];
|
||||
}
|
||||
buf[0] = x[1] * a_inv;
|
||||
for i in (0..WIDTH).rev() {
|
||||
buf[i] = a_inv[i];
|
||||
}
|
||||
|
||||
for (&bi, &xi) in buf.iter().zip(x) {
|
||||
// Sanity check only.
|
||||
debug_assert_eq!(bi * xi, Self::ONE);
|
||||
}
|
||||
|
||||
buf
|
||||
}
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user