Mirror of https://github.com/logos-storage/plonky2.git (synced 2026-01-02 13:53:07 +00:00)
Bit-order reversal optimizations (#442)
* Bit-order in-place reversal optimizations
* optimization/simplification
* Done modulo documentation and testing on x86
* Minor type fixes on non-ARM
* Minor x86
* Transpose docs
* Docs
* Make rustfmt happy
* Bug fixes + tests
* Minor docs + lints
parent 86dc4c933a
commit 5f0eee1a9b
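For orientation (this note and the sketch are not part of the commit), the routines touched below implement the bit-reversal permutation: the element at index `i` moves to the index whose `lb_n`-bit representation is `i` with its bits reversed, where `lb_n` is the log2 of the array length. A minimal out-of-place reference, with illustrative names only:

/// Out-of-place reference for the bit-reversal permutation; `arr.len()` must be a power of two.
fn reverse_index_bits_reference<T: Copy>(arr: &[T]) -> Vec<T> {
    let n = arr.len();
    assert!(n.is_power_of_two());
    let lb_n = n.trailing_zeros();
    (0..n)
        // `wrapping_shr` makes the `n == 1` case (a shift by 64) a no-op, mirroring the diff below.
        .map(|i| arr[i.reverse_bits().wrapping_shr(usize::BITS - lb_n)])
        .collect()
}

fn main() {
    // Matches the expectation in `test_reverse_index_bits_in_place_small` below.
    assert_eq!(reverse_index_bits_reference(&[10, 20, 30, 40]), vec![10, 30, 20, 40]);
}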
@@ -99,7 +99,18 @@ mod tests {
    }

    #[test]
    fn test_reverse_index_bits_in_place() {
    fn test_reverse_index_bits_in_place_trivial() {
        let mut arr1: Vec<u64> = vec![10];
        reverse_index_bits_in_place(&mut arr1);
        assert_eq!(arr1, vec![10]);

        let mut arr2: Vec<u64> = vec![10, 20];
        reverse_index_bits_in_place(&mut arr2);
        assert_eq!(arr2, vec![10, 20]);
    }

    #[test]
    fn test_reverse_index_bits_in_place_small() {
        let mut arr4: Vec<u64> = vec![10, 20, 30, 40];
        reverse_index_bits_in_place(&mut arr4);
        assert_eq!(arr4, vec![10, 30, 20, 40]);
@@ -127,4 +138,26 @@ mod tests {
        reverse_index_bits_in_place(&mut arr256);
        assert_eq!(arr256, output256);
    }

    #[test]
    fn test_reverse_index_bits_in_place_big_even() {
        let mut arr: Vec<u64> = (0..1 << 16).collect();
        let target = reverse_index_bits(&arr);
        reverse_index_bits_in_place(&mut arr);
        assert_eq!(arr, target);
        reverse_index_bits_in_place(&mut arr);
        let range: Vec<u64> = (0..1 << 16).collect();
        assert_eq!(arr, range);
    }

    #[test]
    fn test_reverse_index_bits_in_place_big_odd() {
        let mut arr: Vec<u64> = (0..1 << 17).collect();
        let target = reverse_index_bits(&arr);
        reverse_index_bits_in_place(&mut arr);
        assert_eq!(arr, target);
        reverse_index_bits_in_place(&mut arr);
        let range: Vec<u64> = (0..1 << 17).collect();
        assert_eq!(arr, range);
    }
}
util/src/lib.rs (172 changed lines)
@@ -7,6 +7,11 @@

use std::arch::asm;
use std::hint::unreachable_unchecked;
use std::mem::size_of;
use std::ptr::{swap, swap_nonoverlapping};

mod transpose_util;
use crate::transpose_util::transpose_in_place_square;

pub fn bits_u64(n: u64) -> usize {
    (64 - n.leading_zeros()) as usize
@@ -26,6 +31,9 @@ pub fn log2_ceil(n: usize) -> usize {
pub fn log2_strict(n: usize) -> usize {
    let res = n.trailing_zeros();
    assert!(n.wrapping_shr(res) == 1, "Not a power of two: {}", n);
    // Tell the optimizer about the semantics of `log2_strict`. i.e. it can replace `n` with
    // `1 << res` and vice versa.
    assume(n == 1 << res);
    res as usize
}
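The `assume` call above is an optimizer hint whose definition is not part of this hunk. A plausible, hypothetical sketch of such a helper built on the `unreachable_unchecked` import shown in the first hunk (the real definition in the crate may differ):

/// Hypothetical sketch: promise the optimizer that `p` holds, so it may use facts such as
/// `n == 1 << res` in later transformations. Calling this with a false `p` is undefined behavior.
#[inline(always)]
fn assume(p: bool) {
    if !p {
        // SAFETY: the caller guarantees `p`, so this branch is never reached.
        unsafe { std::hint::unreachable_unchecked() }
    }
}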
@@ -80,57 +88,129 @@ fn reverse_index_bits_large<T: Copy>(arr: &[T], n_power: usize) -> Vec<T> {
    result
}

/// Bit-reverse the order of elements in `arr`.
/// SAFETY: ensure that `arr.len() == 1 << lb_n`.
#[cfg(not(target_arch = "aarch64"))]
unsafe fn reverse_index_bits_in_place_small<T>(arr: &mut [T], lb_n: usize) {
    if lb_n <= 6 {
        // BIT_REVERSE_6BIT holds 6-bit reverses. This shift makes them lb_n-bit reverses.
        let dst_shr_amt = 6 - lb_n;
        for src in 0..arr.len() {
            let dst = (BIT_REVERSE_6BIT[src] as usize) >> dst_shr_amt;
            if src < dst {
                swap(arr.get_unchecked_mut(src), arr.get_unchecked_mut(dst));
            }
        }
    } else {
        // LLVM does not know that it does not need to reverse src at each iteration (which is
        // expensive on x86). We take advantage of the fact that the low bits of dst change rarely and the high
        // bits of dst are dependent only on the low bits of src.
        let dst_lo_shr_amt = 64 - (lb_n - 6);
        let dst_hi_shl_amt = lb_n - 6;
        for src_chunk in 0..(arr.len() >> 6) {
            let src_hi = src_chunk << 6;
            let dst_lo = src_chunk.reverse_bits() >> dst_lo_shr_amt;
            for src_lo in 0..(1 << 6) {
                let dst_hi = (BIT_REVERSE_6BIT[src_lo] as usize) << dst_hi_shl_amt;
                let src = src_hi + src_lo;
                let dst = dst_hi + dst_lo;
                if src < dst {
                    swap(arr.get_unchecked_mut(src), arr.get_unchecked_mut(dst));
                }
            }
        }
    }
}
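The chunked `dst` computation in the `else` branch above can be checked against a plain `lb_n`-bit reversal. A standalone sanity check, assumed rather than taken from the diff, with `rev6` standing in for the `BIT_REVERSE_6BIT` table:

// Verify that dst_hi + dst_lo equals the naive lb_n-bit reversal of src.
fn rev6(x: usize) -> usize {
    // 6-bit reversal via an 8-bit reversal followed by a shift.
    (x as u8).reverse_bits() as usize >> 2
}

fn main() {
    let lb_n = 10;
    for src in 0..(1usize << lb_n) {
        let (src_chunk, src_lo) = (src >> 6, src & 63);
        let dst_lo = src_chunk.reverse_bits() >> (usize::BITS as usize - (lb_n - 6));
        let dst_hi = rev6(src_lo) << (lb_n - 6);
        let naive = src.reverse_bits() >> (usize::BITS as usize - lb_n);
        assert_eq!(dst_hi + dst_lo, naive);
    }
    println!("chunked bit reversal matches the naive reversal");
}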

/// Bit-reverse the order of elements in `arr`.
/// SAFETY: ensure that `arr.len() == 1 << lb_n`.
#[cfg(target_arch = "aarch64")]
unsafe fn reverse_index_bits_in_place_small<T>(arr: &mut [T], lb_n: usize) {
    // Aarch64 can reverse bits in one instruction, so the trivial version works best.
    for src in 0..arr.len() {
        // `wrapping_shr` handles the case when `arr.len() == 1`. In that case `src == 0`, so
        // `src.reverse_bits() == 0`. `usize::wrapping_shr` by 64 is a no-op, but it gives the
        // correct result.
        let dst = src.reverse_bits().wrapping_shr(usize::BITS - lb_n as u32);
        if src < dst {
            swap(arr.get_unchecked_mut(src), arr.get_unchecked_mut(dst));
        }
    }
}

/// Split `arr` into chunks and bit-reverse the order of the chunks. There are `1 << lb_num_chunks`
/// chunks, each of length `1 << lb_chunk_size`.
/// SAFETY: ensure that `arr.len() == 1 << (lb_num_chunks + lb_chunk_size)`.
unsafe fn reverse_index_bits_in_place_chunks<T>(
    arr: &mut [T],
    lb_num_chunks: usize,
    lb_chunk_size: usize,
) {
    for i in 0..1usize << lb_num_chunks {
        // `wrapping_shr` handles the silly case when `lb_num_chunks == 0`.
        let j = i
            .reverse_bits()
            .wrapping_shr(usize::BITS - lb_num_chunks as u32);
        if i < j {
            swap_nonoverlapping(
                arr.get_unchecked_mut(i << lb_chunk_size),
                arr.get_unchecked_mut(j << lb_chunk_size),
                1 << lb_chunk_size,
            );
        }
    }
}

// Ensure that SMALL_ARR_SIZE >= 4 * BIG_T_SIZE.
const BIG_T_SIZE: usize = 1 << 14;
const SMALL_ARR_SIZE: usize = 1 << 16;
pub fn reverse_index_bits_in_place<T>(arr: &mut [T]) {
    let n = arr.len();
    let n_power = log2_strict(n);

    if n_power <= 6 {
        reverse_index_bits_in_place_small(arr, n_power);
    let lb_n = log2_strict(n);
    // If the whole array fits in fast cache, then the trivial algorithm is cache-friendly. Also, if
    // `T` is really big, then the trivial algorithm is cache-friendly, no matter the size of the
    // array.
    if size_of::<T>() << lb_n <= SMALL_ARR_SIZE || size_of::<T>() >= BIG_T_SIZE {
        unsafe {
            reverse_index_bits_in_place_small(arr, lb_n);
        }
    } else {
        reverse_index_bits_in_place_large(arr, n_power);
    }
}
        debug_assert!(n >= 4); // By our choice of `BIG_T_SIZE` and `SMALL_ARR_SIZE`.

/* Both functions below are semantically equivalent to:
        for src in 0..n {
            let dst = reverse_bits(src, n_power);
            if src < dst {
                arr.swap(src, dst);
            }
        }
   where reverse_bits(src, n_power) computes the n_power-bit reverse.
*/

fn reverse_index_bits_in_place_small<T>(arr: &mut [T], n_power: usize) {
    let n = arr.len();
    // BIT_REVERSE_6BIT holds 6-bit reverses. This shift makes them n_power-bit reverses.
    let dst_shr_amt = 6 - n_power;
    for src in 0..n {
        let dst = (BIT_REVERSE_6BIT[src] as usize) >> dst_shr_amt;
        if src < dst {
            arr.swap(src, dst);
        }
    }
}

fn reverse_index_bits_in_place_large<T>(arr: &mut [T], n_power: usize) {
    let n = arr.len();
    // LLVM does not know that it does not need to reverse src at each iteration (which is expensive
    // on x86). We take advantage of the fact that the low bits of dst change rarely and the high
    // bits of dst are dependent only on the low bits of src.
    let dst_lo_shr_amt = 64 - (n_power - 6);
    let dst_hi_shl_amt = n_power - 6;
    for src_chunk in 0..(n >> 6) {
        let src_hi = src_chunk << 6;
        let dst_lo = src_chunk.reverse_bits() >> dst_lo_shr_amt;
        for src_lo in 0..(1 << 6) {
            let dst_hi = (BIT_REVERSE_6BIT[src_lo] as usize) << dst_hi_shl_amt;

            let src = src_hi + src_lo;
            let dst = dst_hi + dst_lo;
            if src < dst {
                arr.swap(src, dst);
        // Algorithm:
        //
        // Treat `arr` as a `sqrt(n)` by `sqrt(n)` row-major matrix. (Assume for now that `lb_n` is
        // even, i.e., `n` is a square number.) To perform bit-order reversal we:
        // 1. Bit-reverse the order of the rows. (They are contiguous in memory, so this is
        //    basically a series of large `memcpy`s.)
        // 2. Transpose the matrix.
        // 3. Bit-reverse the order of the rows.
        // This is equivalent to, for every index `0 <= i < n`:
        // 1. bit-reversing `i[lb_n / 2..lb_n]`,
        // 2. swapping `i[0..lb_n / 2]` and `i[lb_n / 2..lb_n]`,
        // 3. bit-reversing `i[lb_n / 2..lb_n]`.
        //
        // If `lb_n` is odd, i.e., `n` is not a square number, then the above procedure requires
        // slight modification. At steps 1 and 3 we bit-reverse bits `ceil(lb_n / 2)..lb_n` of the
        // index (shuffling `floor(lb_n / 2)` chunks of length `ceil(lb_n / 2)`). At step 2, we
        // perform _two_ transposes. We treat `arr` as two matrices, one where the middle bit of the
        // index is `0` and another, where the middle bit is `1`; we transpose each individually.

        let lb_num_chunks = lb_n >> 1;
        let lb_chunk_size = lb_n - lb_num_chunks;
        unsafe {
            reverse_index_bits_in_place_chunks(arr, lb_num_chunks, lb_chunk_size);
            transpose_in_place_square(arr, lb_chunk_size, lb_num_chunks, 0);
            if lb_num_chunks != lb_chunk_size {
                // `arr` cannot be interpreted as a square matrix. We instead interpret it as a
                // `1 << lb_num_chunks` by `2` by `1 << lb_num_chunks` tensor, in row-major order.
                // The above transpose acted on `tensor[..., 0, ...]` (all indices with middle bit
                // `0`). We still need to transpose `tensor[..., 1, ...]`. To do so, we advance
                // `arr` by `1 << lb_num_chunks`, effectively adding that to every index.
                let arr_with_offset = &mut arr[1 << lb_num_chunks..];
                transpose_in_place_square(arr_with_offset, lb_chunk_size, lb_num_chunks, 0);
            }
            reverse_index_bits_in_place_chunks(arr, lb_num_chunks, lb_chunk_size);
        }
    }
}
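As a standalone check of the square-case decomposition described in the algorithm comment above (chunk-reverse, transpose, chunk-reverse equals a full bit reversal), here is a safe sketch with illustrative names; it is not part of the commit:

// For lb_n = 8 (a perfect square), reverse the rows, transpose, reverse the rows again,
// and compare with a direct reversal of all 2 * lb_half index bits.
fn rev_bits(x: usize, bits: usize) -> usize {
    x.reverse_bits() >> (usize::BITS as usize - bits)
}

fn main() {
    let lb_half = 4;
    let (rows, cols) = (1usize << lb_half, 1usize << lb_half);
    let n = rows * cols;
    let mut arr: Vec<usize> = (0..n).collect();

    // 1. Bit-reverse the order of the rows (contiguous chunks of length `cols`).
    let mut tmp = arr.clone();
    for r in 0..rows {
        tmp[rev_bits(r, lb_half) * cols..][..cols].copy_from_slice(&arr[r * cols..][..cols]);
    }
    // 2. Transpose the rows-by-cols matrix.
    let mut t = vec![0usize; n];
    for r in 0..rows {
        for c in 0..cols {
            t[c * rows + r] = tmp[r * cols + c];
        }
    }
    // 3. Bit-reverse the order of the rows again.
    for r in 0..rows {
        arr[rev_bits(r, lb_half) * cols..][..cols].copy_from_slice(&t[r * cols..][..cols]);
    }

    let expected: Vec<usize> = (0..n).map(|i| rev_bits(i, 2 * lb_half)).collect();
    assert_eq!(arr, expected);
    println!("decomposition matches the full bit-reversal permutation");
}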
util/src/transpose_util.rs (new file, 112 lines)
@@ -0,0 +1,112 @@
use std::ptr::swap;

const LB_BLOCK_SIZE: usize = 3;

/// Transpose square matrix in-place
/// The matrix is of size `1 << lb_size` by `1 << lb_size`. It occupies
/// `M[i, j] == arr[(i + x << lb_stride) + j + x]` for `0 <= i, j < 1 << lb_size`. The transposition
/// swaps `M[i, j]` and `M[j, i]`.
///
/// SAFETY:
/// Make sure that `(i + x << lb_stride) + j + x` is a valid index in `arr` for all
/// `0 <= i, j < 1 << lb_size`. Ensure also that `lb_size <= lb_stride` to prevent overlap.
unsafe fn transpose_in_place_square_small<T>(
    arr: &mut [T],
    lb_stride: usize,
    lb_size: usize,
    x: usize,
) {
    for i in x..x + (1 << lb_size) {
        for j in x..i {
            swap(
                arr.get_unchecked_mut(i + (j << lb_stride)),
                arr.get_unchecked_mut((i << lb_stride) + j),
            );
        }
    }
}

/// Transpose square matrices and swap
/// The matrices are of size `1 << lb_size` by `1 << lb_size`. They occupy
/// `M0[i, j] == arr[(i + x << lb_stride) + j + y]`, `M1[i, j] == arr[i + x + (j + y << lb_stride)]`
/// for `0 <= i, j < 1 << lb_size`. The transposition swaps `M0[i, j]` and `M1[j, i]`.
///
/// SAFETY:
/// Make sure that `(i + x << lb_stride) + j + y` and `i + x + (j + y << lb_stride)` are valid
/// indices in `arr` for all `0 <= i, j < 1 << lb_size`. Ensure also that `lb_size <= lb_stride` to
/// prevent overlap.
unsafe fn transpose_swap_square_small<T>(
    arr: &mut [T],
    lb_stride: usize,
    lb_size: usize,
    x: usize,
    y: usize,
) {
    for i in x..x + (1 << lb_size) {
        for j in y..y + (1 << lb_size) {
            swap(
                arr.get_unchecked_mut(i + (j << lb_stride)),
                arr.get_unchecked_mut((i << lb_stride) + j),
            );
        }
    }
}

/// Transpose square matrices and swap
/// The matrices are of size `1 << lb_size` by `1 << lb_size`. They occupy
/// `M0[i, j] == arr[(i + x << lb_stride) + j + y]`, `M1[i, j] == arr[i + x + (j + y << lb_stride)]`
/// for `0 <= i, j < 1 << lb_size`. The transposition swaps `M0[i, j]` and `M1[j, i]`.
///
/// SAFETY:
/// Make sure that `(i + x << lb_stride) + j + y` and `i + x + (j + y << lb_stride)` are valid
/// indices in `arr` for all `0 <= i, j < 1 << lb_size`. Ensure also that `lb_size <= lb_stride` to
/// prevent overlap.
unsafe fn transpose_swap_square<T>(
    arr: &mut [T],
    lb_stride: usize,
    lb_size: usize,
    x: usize,
    y: usize,
) {
    if lb_size <= LB_BLOCK_SIZE {
        transpose_swap_square_small(arr, lb_stride, lb_size, x, y);
    } else {
        let lb_block_size = lb_size - 1;
        let block_size = 1 << lb_block_size;
        transpose_swap_square(arr, lb_stride, lb_block_size, x, y);
        transpose_swap_square(arr, lb_stride, lb_block_size, x + block_size, y);
        transpose_swap_square(arr, lb_stride, lb_block_size, x, y + block_size);
        transpose_swap_square(
            arr,
            lb_stride,
            lb_block_size,
            x + block_size,
            y + block_size,
        );
    }
}

/// Transpose square matrix in-place
/// The matrix is of size `1 << lb_size` by `1 << lb_size`. It occupies
/// `M[i, j] == arr[(i + x << lb_stride) + j + x]` for `0 <= i, j < 1 << lb_size`. The transposition
/// swaps `M[i, j]` and `M[j, i]`.
///
/// SAFETY:
/// Make sure that `(i + x << lb_stride) + j + x` is a valid index in `arr` for all
/// `0 <= i, j < 1 << lb_size`. Ensure also that `lb_size <= lb_stride` to prevent overlap.
pub(crate) unsafe fn transpose_in_place_square<T>(
    arr: &mut [T],
    lb_stride: usize,
    lb_size: usize,
    x: usize,
) {
    if lb_size <= LB_BLOCK_SIZE {
        transpose_in_place_square_small(arr, lb_stride, lb_size, x);
    } else {
        let lb_block_size = lb_size - 1;
        let block_size = 1 << lb_block_size;
        transpose_in_place_square(arr, lb_stride, lb_block_size, x);
        transpose_swap_square(arr, lb_stride, lb_block_size, x, x + block_size);
        transpose_in_place_square(arr, lb_stride, lb_block_size, x + block_size);
    }
}
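A hypothetical in-crate test, not part of this commit, illustrating the layout convention: with `x = 0` and `lb_stride == lb_size`, the matrix is a plain row-major square occupying the whole slice. The `super::transpose_in_place_square` path assumes the test module would live inside transpose_util.rs.

#[cfg(test)]
mod transpose_tests {
    use super::transpose_in_place_square;

    #[test]
    fn transpose_4x4() {
        let lb_size = 2; // a 4x4 matrix
        let lb_stride = 2; // rows are contiguous: stride equals row length
        let mut arr: Vec<u32> = (0..16).collect();
        // SAFETY: every index `(i << lb_stride) + j` with `0 <= i, j < 4` lies inside `arr`,
        // and `lb_size <= lb_stride`.
        unsafe { transpose_in_place_square(&mut arr, lb_stride, lb_size, 0) };
        // After transposition, position (i, j) holds the value originally at (j, i).
        let expected: Vec<u32> = (0..16).map(|k| (k % 4) * 4 + k / 4).collect();
        assert_eq!(arr, expected);
    }
}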