diff --git a/plonky2/src/util/mod.rs b/plonky2/src/util/mod.rs
index 13a72f78..4cf7119a 100644
--- a/plonky2/src/util/mod.rs
+++ b/plonky2/src/util/mod.rs
@@ -99,7 +99,18 @@ mod tests {
     }
 
     #[test]
-    fn test_reverse_index_bits_in_place() {
+    fn test_reverse_index_bits_in_place_trivial() {
+        let mut arr1: Vec<usize> = vec![10];
+        reverse_index_bits_in_place(&mut arr1);
+        assert_eq!(arr1, vec![10]);
+
+        let mut arr2: Vec<usize> = vec![10, 20];
+        reverse_index_bits_in_place(&mut arr2);
+        assert_eq!(arr2, vec![10, 20]);
+    }
+
+    #[test]
+    fn test_reverse_index_bits_in_place_small() {
         let mut arr4: Vec<usize> = vec![10, 20, 30, 40];
         reverse_index_bits_in_place(&mut arr4);
         assert_eq!(arr4, vec![10, 30, 20, 40]);
@@ -127,4 +138,26 @@ mod tests {
         reverse_index_bits_in_place(&mut arr256);
         assert_eq!(arr256, output256);
     }
+
+    #[test]
+    fn test_reverse_index_bits_in_place_big_even() {
+        let mut arr: Vec<usize> = (0..1 << 16).collect();
+        let target = reverse_index_bits(&arr);
+        reverse_index_bits_in_place(&mut arr);
+        assert_eq!(arr, target);
+        reverse_index_bits_in_place(&mut arr);
+        let range: Vec<usize> = (0..1 << 16).collect();
+        assert_eq!(arr, range);
+    }
+
+    #[test]
+    fn test_reverse_index_bits_in_place_big_odd() {
+        let mut arr: Vec<usize> = (0..1 << 17).collect();
+        let target = reverse_index_bits(&arr);
+        reverse_index_bits_in_place(&mut arr);
+        assert_eq!(arr, target);
+        reverse_index_bits_in_place(&mut arr);
+        let range: Vec<usize> = (0..1 << 17).collect();
+        assert_eq!(arr, range);
+    }
 }
diff --git a/util/src/lib.rs b/util/src/lib.rs
index 5c683a50..f760cfba 100644
--- a/util/src/lib.rs
+++ b/util/src/lib.rs
@@ -7,6 +7,11 @@ use std::arch::asm;
 use std::hint::unreachable_unchecked;
+use std::mem::size_of;
+use std::ptr::{swap, swap_nonoverlapping};
+
+mod transpose_util;
+use crate::transpose_util::transpose_in_place_square;
 
 pub fn bits_u64(n: u64) -> usize {
     (64 - n.leading_zeros()) as usize
@@ -26,6 +31,9 @@ pub fn log2_ceil(n: usize) -> usize {
 pub fn log2_strict(n: usize) -> usize {
     let res = n.trailing_zeros();
     assert!(n.wrapping_shr(res) == 1, "Not a power of two: {}", n);
+    // Tell the optimizer about the semantics of `log2_strict`, i.e. that it can replace `n` with
+    // `1 << res` and vice versa.
+    assume(n == 1 << res);
     res as usize
 }
 
@@ -80,57 +88,129 @@ fn reverse_index_bits_large<T: Copy>(arr: &[T], n_power: usize) -> Vec<T> {
     result
 }
 
+/// Bit-reverse the order of elements in `arr`.
+/// SAFETY: ensure that `arr.len() == 1 << lb_n`.
+#[cfg(not(target_arch = "aarch64"))]
+unsafe fn reverse_index_bits_in_place_small<T>(arr: &mut [T], lb_n: usize) {
+    if lb_n <= 6 {
+        // BIT_REVERSE_6BIT holds 6-bit reverses. This shift makes them lb_n-bit reverses.
+        let dst_shr_amt = 6 - lb_n;
+        for src in 0..arr.len() {
+            let dst = (BIT_REVERSE_6BIT[src] as usize) >> dst_shr_amt;
+            if src < dst {
+                swap(arr.get_unchecked_mut(src), arr.get_unchecked_mut(dst));
+            }
+        }
+    } else {
+        // LLVM does not know that it does not need to reverse src at each iteration (which is
+        // expensive on x86). We take advantage of the fact that the low bits of dst change rarely
+        // and the high bits of dst are dependent only on the low bits of src.
+        let dst_lo_shr_amt = 64 - (lb_n - 6);
+        let dst_hi_shl_amt = lb_n - 6;
+        for src_chunk in 0..(arr.len() >> 6) {
+            let src_hi = src_chunk << 6;
+            let dst_lo = src_chunk.reverse_bits() >> dst_lo_shr_amt;
+            for src_lo in 0..(1 << 6) {
+                let dst_hi = (BIT_REVERSE_6BIT[src_lo] as usize) << dst_hi_shl_amt;
+                let src = src_hi + src_lo;
+                let dst = dst_hi + dst_lo;
+                if src < dst {
+                    swap(arr.get_unchecked_mut(src), arr.get_unchecked_mut(dst));
+                }
+            }
+        }
+    }
+}
+
+/// Bit-reverse the order of elements in `arr`.
+/// SAFETY: ensure that `arr.len() == 1 << lb_n`.
+#[cfg(target_arch = "aarch64")]
+unsafe fn reverse_index_bits_in_place_small<T>(arr: &mut [T], lb_n: usize) {
+    // Aarch64 can reverse bits in one instruction, so the trivial version works best.
+    for src in 0..arr.len() {
+        // `wrapping_shr` handles the case when `arr.len() == 1`. In that case `src == 0`, so
+        // `src.reverse_bits() == 0`. `usize::wrapping_shr` by 64 is a no-op, but it gives the
+        // correct result.
+        let dst = src.reverse_bits().wrapping_shr(usize::BITS - lb_n as u32);
+        if src < dst {
+            swap(arr.get_unchecked_mut(src), arr.get_unchecked_mut(dst));
+        }
+    }
+}
+
+/// Split `arr` into chunks and bit-reverse the order of the chunks. There are
+/// `1 << lb_num_chunks` chunks, each of length `1 << lb_chunk_size`.
+/// SAFETY: ensure that `arr.len() == 1 << (lb_num_chunks + lb_chunk_size)`.
+unsafe fn reverse_index_bits_in_place_chunks<T>(
+    arr: &mut [T],
+    lb_num_chunks: usize,
+    lb_chunk_size: usize,
+) {
+    for i in 0..1usize << lb_num_chunks {
+        // `wrapping_shr` handles the silly case when `lb_num_chunks == 0`.
+        let j = i
+            .reverse_bits()
+            .wrapping_shr(usize::BITS - lb_num_chunks as u32);
+        if i < j {
+            swap_nonoverlapping(
+                arr.get_unchecked_mut(i << lb_chunk_size),
+                arr.get_unchecked_mut(j << lb_chunk_size),
+                1 << lb_chunk_size,
+            );
+        }
+    }
+}
+
+// Ensure that SMALL_ARR_SIZE >= 4 * BIG_T_SIZE.
+const BIG_T_SIZE: usize = 1 << 14;
+const SMALL_ARR_SIZE: usize = 1 << 16;
 
 pub fn reverse_index_bits_in_place<T>(arr: &mut [T]) {
     let n = arr.len();
-    let n_power = log2_strict(n);
-
-    if n_power <= 6 {
-        reverse_index_bits_in_place_small(arr, n_power);
+    let lb_n = log2_strict(n);
+    // If the whole array fits in fast cache, then the trivial algorithm is cache-friendly. Also,
+    // if `T` is really big, then the trivial algorithm is cache-friendly, no matter the size of
+    // the array.
+    if size_of::<T>() << lb_n <= SMALL_ARR_SIZE || size_of::<T>() >= BIG_T_SIZE {
+        unsafe {
+            reverse_index_bits_in_place_small(arr, lb_n);
+        }
     } else {
-        reverse_index_bits_in_place_large(arr, n_power);
-    }
-}
+        debug_assert!(n >= 4); // By our choice of `BIG_T_SIZE` and `SMALL_ARR_SIZE`.
 
-/* Both functions below are semantically equivalent to:
-        for src in 0..n {
-            let dst = reverse_bits(src, n_power);
-            if src < dst {
-                arr.swap(src, dst);
-            }
-        }
-   where reverse_bits(src, n_power) computes the n_power-bit reverse.
-*/
-
-fn reverse_index_bits_in_place_small<T>(arr: &mut [T], n_power: usize) {
-    let n = arr.len();
-    // BIT_REVERSE_6BIT holds 6-bit reverses. This shift makes them n_power-bit reverses.
-    let dst_shr_amt = 6 - n_power;
-    for src in 0..n {
-        let dst = (BIT_REVERSE_6BIT[src] as usize) >> dst_shr_amt;
-        if src < dst {
-            arr.swap(src, dst);
-        }
-    }
-}
-
-fn reverse_index_bits_in_place_large<T>(arr: &mut [T], n_power: usize) {
-    let n = arr.len();
-    // LLVM does not know that it does not need to reverse src at each iteration (which is
-    // expensive on x86). We take advantage of the fact that the low bits of dst change rarely
-    // and the high bits of dst are dependent only on the low bits of src.
-    let dst_lo_shr_amt = 64 - (n_power - 6);
-    let dst_hi_shl_amt = n_power - 6;
-    for src_chunk in 0..(n >> 6) {
-        let src_hi = src_chunk << 6;
-        let dst_lo = src_chunk.reverse_bits() >> dst_lo_shr_amt;
-        for src_lo in 0..(1 << 6) {
-            let dst_hi = (BIT_REVERSE_6BIT[src_lo] as usize) << dst_hi_shl_amt;
-
-            let src = src_hi + src_lo;
-            let dst = dst_hi + dst_lo;
-            if src < dst {
-                arr.swap(src, dst);
+        // Algorithm:
+        //
+        // Treat `arr` as a `sqrt(n)` by `sqrt(n)` row-major matrix. (Assume for now that `lb_n`
+        // is even, i.e., `n` is a square number.) To perform bit-order reversal we:
+        // 1. Bit-reverse the order of the rows. (They are contiguous in memory, so this is
+        //    basically a series of large `memcpy`s.)
+        // 2. Transpose the matrix.
+        // 3. Bit-reverse the order of the rows.
+        // This is equivalent to, for every index `0 <= i < n`:
+        // 1. bit-reversing `i[lb_n / 2..lb_n]`,
+        // 2. swapping `i[0..lb_n / 2]` and `i[lb_n / 2..lb_n]`,
+        // 3. bit-reversing `i[lb_n / 2..lb_n]`.
+        //
+        // If `lb_n` is odd, i.e., `n` is not a square number, then the above procedure requires a
+        // slight modification. At steps 1 and 3 we bit-reverse bits `ceil(lb_n / 2)..lb_n` of the
+        // index (shuffling `1 << floor(lb_n / 2)` chunks of length `1 << ceil(lb_n / 2)`). At
+        // step 2, we perform _two_ transposes. We treat `arr` as two matrices, one where the
+        // middle bit of the index is `0` and another where it is `1`; we transpose each
+        // individually.
+
+        let lb_num_chunks = lb_n >> 1;
+        let lb_chunk_size = lb_n - lb_num_chunks;
+        unsafe {
+            reverse_index_bits_in_place_chunks(arr, lb_num_chunks, lb_chunk_size);
+            transpose_in_place_square(arr, lb_chunk_size, lb_num_chunks, 0);
+            if lb_num_chunks != lb_chunk_size {
+                // `arr` cannot be interpreted as a square matrix. We instead interpret it as a
+                // `1 << lb_num_chunks` by `2` by `1 << lb_num_chunks` tensor, in row-major order.
+                // The above transpose acted on `tensor[..., 0, ...]` (all indices with middle bit
+                // `0`). We still need to transpose `tensor[..., 1, ...]`. To do so, we advance
+                // `arr` by `1 << lb_num_chunks`, effectively adding that offset to every index.
+                let arr_with_offset = &mut arr[1 << lb_num_chunks..];
+                transpose_in_place_square(arr_with_offset, lb_chunk_size, lb_num_chunks, 0);
             }
+            reverse_index_bits_in_place_chunks(arr, lb_num_chunks, lb_chunk_size);
         }
     }
 }
diff --git a/util/src/transpose_util.rs b/util/src/transpose_util.rs
new file mode 100644
index 00000000..1c8280a8
--- /dev/null
+++ b/util/src/transpose_util.rs
@@ -0,0 +1,112 @@
+use std::ptr::swap;
+
+const LB_BLOCK_SIZE: usize = 3;
+
+/// Transpose square matrix in-place
+/// The matrix is of size `1 << lb_size` by `1 << lb_size`. It occupies
+/// `M[i, j] == arr[((i + x) << lb_stride) + j + x]` for `0 <= i, j < 1 << lb_size`. The
+/// transposition swaps `M[i, j]` and `M[j, i]`.
+///
+/// SAFETY:
+/// Make sure that `((i + x) << lb_stride) + j + x` is a valid index in `arr` for all
+/// `0 <= i, j < 1 << lb_size`. Ensure also that `lb_size <= lb_stride` to prevent overlap.
+unsafe fn transpose_in_place_square_small<T>(
+    arr: &mut [T],
+    lb_stride: usize,
+    lb_size: usize,
+    x: usize,
+) {
+    for i in x..x + (1 << lb_size) {
+        for j in x..i {
+            swap(
+                arr.get_unchecked_mut(i + (j << lb_stride)),
+                arr.get_unchecked_mut((i << lb_stride) + j),
+            );
+        }
+    }
+}
+
+/// Transpose square matrices and swap
+/// The matrices are of size `1 << lb_size` by `1 << lb_size`. They occupy
+/// `M0[i, j] == arr[((i + x) << lb_stride) + j + y]` and
+/// `M1[i, j] == arr[i + x + ((j + y) << lb_stride)]` for `0 <= i, j < 1 << lb_size`. The
+/// transposition swaps `M0[i, j]` and `M1[j, i]`.
+///
+/// SAFETY:
+/// Make sure that `((i + x) << lb_stride) + j + y` and `i + x + ((j + y) << lb_stride)` are valid
+/// indices in `arr` for all `0 <= i, j < 1 << lb_size`. Ensure also that `lb_size <= lb_stride`
+/// to prevent overlap.
+unsafe fn transpose_swap_square_small<T>(
+    arr: &mut [T],
+    lb_stride: usize,
+    lb_size: usize,
+    x: usize,
+    y: usize,
+) {
+    for i in x..x + (1 << lb_size) {
+        for j in y..y + (1 << lb_size) {
+            swap(
+                arr.get_unchecked_mut(i + (j << lb_stride)),
+                arr.get_unchecked_mut((i << lb_stride) + j),
+            );
+        }
+    }
+}
+
+/// Transpose square matrices and swap
+/// The matrices are of size `1 << lb_size` by `1 << lb_size`. They occupy
+/// `M0[i, j] == arr[((i + x) << lb_stride) + j + y]` and
+/// `M1[i, j] == arr[i + x + ((j + y) << lb_stride)]` for `0 <= i, j < 1 << lb_size`. The
+/// transposition swaps `M0[i, j]` and `M1[j, i]`.
+///
+/// SAFETY:
+/// Make sure that `((i + x) << lb_stride) + j + y` and `i + x + ((j + y) << lb_stride)` are valid
+/// indices in `arr` for all `0 <= i, j < 1 << lb_size`. Ensure also that `lb_size <= lb_stride`
+/// to prevent overlap.
+unsafe fn transpose_swap_square<T>(
+    arr: &mut [T],
+    lb_stride: usize,
+    lb_size: usize,
+    x: usize,
+    y: usize,
+) {
+    if lb_size <= LB_BLOCK_SIZE {
+        transpose_swap_square_small(arr, lb_stride, lb_size, x, y);
+    } else {
+        let lb_block_size = lb_size - 1;
+        let block_size = 1 << lb_block_size;
+        transpose_swap_square(arr, lb_stride, lb_block_size, x, y);
+        transpose_swap_square(arr, lb_stride, lb_block_size, x + block_size, y);
+        transpose_swap_square(arr, lb_stride, lb_block_size, x, y + block_size);
+        transpose_swap_square(
+            arr,
+            lb_stride,
+            lb_block_size,
+            x + block_size,
+            y + block_size,
+        );
+    }
+}
+
+/// Transpose square matrix in-place
+/// The matrix is of size `1 << lb_size` by `1 << lb_size`. It occupies
+/// `M[i, j] == arr[((i + x) << lb_stride) + j + x]` for `0 <= i, j < 1 << lb_size`. The
+/// transposition swaps `M[i, j]` and `M[j, i]`.
+///
+/// SAFETY:
+/// Make sure that `((i + x) << lb_stride) + j + x` is a valid index in `arr` for all
+/// `0 <= i, j < 1 << lb_size`. Ensure also that `lb_size <= lb_stride` to prevent overlap.
+pub(crate) unsafe fn transpose_in_place_square<T>(
+    arr: &mut [T],
+    lb_stride: usize,
+    lb_size: usize,
+    x: usize,
+) {
+    if lb_size <= LB_BLOCK_SIZE {
+        transpose_in_place_square_small(arr, lb_stride, lb_size, x);
+    } else {
+        let lb_block_size = lb_size - 1;
+        let block_size = 1 << lb_block_size;
+        transpose_in_place_square(arr, lb_stride, lb_block_size, x);
+        transpose_swap_square(arr, lb_stride, lb_block_size, x, x + block_size);
+        transpose_in_place_square(arr, lb_stride, lb_block_size, x + block_size);
+    }
+}
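
Reviewer note, not part of the patch: the chunked loop in the non-aarch64 `reverse_index_bits_in_place_small` relies on the fact that an `lb_n`-bit reverse splits cleanly: the 6 low bits of `src`, reversed, become the high bits of `dst`, while the remaining high bits of `src`, reversed, become the low bits of `dst`. Below is a minimal standalone check of that identity; it uses `usize::reverse_bits` as a stand-in for the `BIT_REVERSE_6BIT` table and is purely illustrative.

fn main() {
    let lb_n: u32 = 10; // any width strictly greater than 6
    for src in 0..1usize << lb_n {
        let src_chunk = src >> 6; // high `lb_n - 6` bits
        let src_lo = src & 0x3f; // low 6 bits
        // 6-bit reverse of the low bits becomes the high part of the destination
        // (the role played by the BIT_REVERSE_6BIT lookup table in the patch).
        let dst_hi = (src_lo.reverse_bits() >> (usize::BITS - 6)) << (lb_n - 6);
        // `(lb_n - 6)`-bit reverse of the chunk index becomes the low part.
        let dst_lo = src_chunk.reverse_bits() >> (usize::BITS - (lb_n - 6));
        // Together they equal the direct `lb_n`-bit reverse of `src`.
        assert_eq!(dst_hi + dst_lo, src.reverse_bits() >> (usize::BITS - lb_n));
    }
}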
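
Likewise illustrative: the `// Algorithm:` comment in `reverse_index_bits_in_place` can be restated in safe Rust and checked against the naive definition of the permutation. The helper names below (`reverse_chunk_order`, `decomposed_reverse`) are invented for this sketch and do not appear in the patch; the sketch covers both the even case (one transpose) and the odd case (two transposes selected by the middle index bit).

// Steps 1 and 3: bit-reverse the order of the contiguous length-`chunk_size` chunks (the "rows").
fn reverse_chunk_order(arr: &mut [usize], lb_num_chunks: u32, chunk_size: usize) {
    for i in 0..1usize << lb_num_chunks {
        let j = i.reverse_bits() >> (usize::BITS - lb_num_chunks);
        if i < j {
            for k in 0..chunk_size {
                arr.swap(i * chunk_size + k, j * chunk_size + k);
            }
        }
    }
}

// Row-reverse / transpose / row-reverse, as described in the `// Algorithm:` comment.
fn decomposed_reverse(arr: &mut [usize], lb_n: u32) {
    let lb_num_chunks = lb_n / 2; // floor(lb_n / 2)
    let lb_chunk_size = lb_n - lb_num_chunks; // ceil(lb_n / 2)
    let num_chunks = 1usize << lb_num_chunks;
    let chunk_size = 1usize << lb_chunk_size;

    reverse_chunk_order(arr, lb_num_chunks, chunk_size);
    // Step 2: transpose each `num_chunks` x `num_chunks` matrix. For even `lb_n` there is
    // exactly one; for odd `lb_n` there are two, selected by the middle bit of the index
    // (column offsets 0 and `num_chunks` within each chunk).
    for offset in 0..chunk_size / num_chunks {
        for i in 0..num_chunks {
            for j in 0..i {
                arr.swap(
                    offset * num_chunks + i * chunk_size + j,
                    offset * num_chunks + j * chunk_size + i,
                );
            }
        }
    }
    reverse_chunk_order(arr, lb_num_chunks, chunk_size);
}

fn main() {
    for lb_n in [6u32, 7u32] {
        // one even and one odd width
        let n = 1usize << lb_n;
        let mut arr: Vec<usize> = (0..n).collect();
        decomposed_reverse(&mut arr, lb_n);
        // Naive definition of the permutation: element `i` ends up at the bit-reverse of `i`.
        let expected: Vec<usize> = (0..n)
            .map(|i| i.reverse_bits() >> (usize::BITS - lb_n))
            .collect();
        assert_eq!(arr, expected);
    }
}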
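
Finally, a safe reference model of what `transpose_in_place_square` is documented to compute: entry `(i, j)` of the embedded matrix lives at `arr[((x + i) << lb_stride) + x + j]` and is swapped with entry `(j, i)`. The recursive code in `transpose_util.rs` performs the same permutation, only visiting it in cache-friendly blocks; `transpose_square_reference` below is a made-up name for a naive stand-in, useful as a mental model or as a test oracle for the recursive version.

// Naive, safe equivalent of the transposition specified in transpose_util.rs.
fn transpose_square_reference(arr: &mut [u64], lb_stride: usize, lb_size: usize, x: usize) {
    for i in 0..1usize << lb_size {
        for j in 0..i {
            arr.swap(((x + i) << lb_stride) + x + j, ((x + j) << lb_stride) + x + i);
        }
    }
}

fn main() {
    // An 8 x 8 row-major buffer; transpose the 4 x 4 block whose corner is at (x, x) = (2, 2).
    let (lb_stride, lb_size, x) = (3usize, 2usize, 2usize);
    let mut arr: Vec<u64> = (0..1 << (2 * lb_stride)).collect();
    let orig = arr.clone();

    transpose_square_reference(&mut arr, lb_stride, lb_size, x);

    // Every entry of the block is now the mirrored entry of the original buffer.
    for i in 0..1usize << lb_size {
        for j in 0..1usize << lb_size {
            assert_eq!(
                arr[((x + i) << lb_stride) + x + j],
                orig[((x + j) << lb_stride) + x + i]
            );
        }
    }
}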