mirror of
https://github.com/logos-storage/plonky2.git
synced 2026-01-03 14:23:07 +00:00
Optimize bit reverse transpose (#198)
* Bit reverse improvements * Formatting * Tests * Daniel PR comment
This commit is contained in:
parent
d4ee2a6c18
commit
8c4961222f
158
src/util/mod.rs
158
src/util/mod.rs
@ -68,9 +68,48 @@ pub(crate) fn reverse_index_bits<T: Copy>(arr: &[T]) -> Vec<T> {
|
||||
let n = arr.len();
|
||||
let n_power = log2_strict(n);
|
||||
|
||||
if n_power <= 6 {
|
||||
reverse_index_bits_small(arr, n_power)
|
||||
} else {
|
||||
reverse_index_bits_large(arr, n_power)
|
||||
}
|
||||
}
|
||||
|
||||
/* Both functions below are semantically equivalent to:
|
||||
for i in 0..n {
|
||||
result.push(arr[reverse_bits(i, n_power)]);
|
||||
}
|
||||
where reverse_bits(i, n_power) computes the n_power-bit reverse. The complications are there
|
||||
to guide the compiler to generate optimal assembly.
|
||||
*/
|
||||
|
||||
fn reverse_index_bits_small<T: Copy>(arr: &[T], n_power: usize) -> Vec<T> {
|
||||
let n = arr.len();
|
||||
let mut result = Vec::with_capacity(n);
|
||||
// BIT_REVERSE_6BIT holds 6-bit reverses. This shift makes them n_power-bit reverses.
|
||||
let dst_shr_amt = 6 - n_power;
|
||||
for i in 0..n {
|
||||
result.push(arr[reverse_bits(i, n_power)]);
|
||||
let src = (BIT_REVERSE_6BIT[i] as usize) >> dst_shr_amt;
|
||||
result.push(arr[src]);
|
||||
}
|
||||
result
|
||||
}
|
||||
|
||||
fn reverse_index_bits_large<T: Copy>(arr: &[T], n_power: usize) -> Vec<T> {
|
||||
let n = arr.len();
|
||||
// LLVM does not know that it does not need to reverse src at each iteration (which is expensive
|
||||
// on x86). We take advantage of the fact that the low bits of dst change rarely and the high
|
||||
// bits of dst are dependent only on the low bits of src.
|
||||
let src_lo_shr_amt = 64 - (n_power - 6);
|
||||
let src_hi_shl_amt = n_power - 6;
|
||||
let mut result = Vec::with_capacity(n);
|
||||
for i_chunk in 0..(n >> 6) {
|
||||
let src_lo = i_chunk.reverse_bits() >> src_lo_shr_amt;
|
||||
for i_lo in 0..(1 << 6) {
|
||||
let src_hi = (BIT_REVERSE_6BIT[i_lo] as usize) << src_hi_shl_amt;
|
||||
let src = src_hi + src_lo;
|
||||
result.push(arr[src]);
|
||||
}
|
||||
}
|
||||
result
|
||||
}
|
||||
@ -79,14 +118,71 @@ pub(crate) fn reverse_index_bits_in_place<T>(arr: &mut Vec<T>) {
|
||||
let n = arr.len();
|
||||
let n_power = log2_strict(n);
|
||||
|
||||
if n_power <= 6 {
|
||||
reverse_index_bits_in_place_small(arr, n_power);
|
||||
} else {
|
||||
reverse_index_bits_in_place_large(arr, n_power);
|
||||
}
|
||||
}
|
||||
|
||||
/* Both functions below are semantically equivalent to:
|
||||
for src in 0..n {
|
||||
let dst = reverse_bits(src, n_power);
|
||||
if src < dst {
|
||||
arr.swap(src, dst);
|
||||
}
|
||||
}
|
||||
where reverse_bits(src, n_power) computes the n_power-bit reverse.
|
||||
*/
|
||||
|
||||
fn reverse_index_bits_in_place_small<T>(arr: &mut Vec<T>, n_power: usize) {
|
||||
let n = arr.len();
|
||||
// BIT_REVERSE_6BIT holds 6-bit reverses. This shift makes them n_power-bit reverses.
|
||||
let dst_shr_amt = 6 - n_power;
|
||||
for src in 0..n {
|
||||
let dst = reverse_bits(src, n_power);
|
||||
let dst = (BIT_REVERSE_6BIT[src] as usize) >> dst_shr_amt;
|
||||
if src < dst {
|
||||
arr.swap(src, dst);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn reverse_index_bits_in_place_large<T>(arr: &mut Vec<T>, n_power: usize) {
|
||||
let n = arr.len();
|
||||
// LLVM does not know that it does not need to reverse src at each iteration (which is expensive
|
||||
// on x86). We take advantage of the fact that the low bits of dst change rarely and the high
|
||||
// bits of dst are dependent only on the low bits of src.
|
||||
let dst_lo_shr_amt = 64 - (n_power - 6);
|
||||
let dst_hi_shl_amt = n_power - 6;
|
||||
for src_chunk in 0..(n >> 6) {
|
||||
let src_hi = src_chunk << 6;
|
||||
let dst_lo = src_chunk.reverse_bits() >> dst_lo_shr_amt;
|
||||
for src_lo in 0..(1 << 6) {
|
||||
let dst_hi = (BIT_REVERSE_6BIT[src_lo] as usize) << dst_hi_shl_amt;
|
||||
|
||||
let src = src_hi + src_lo;
|
||||
let dst = dst_hi + dst_lo;
|
||||
if src < dst {
|
||||
arr.swap(src, dst);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Lookup table of 6-bit reverses: entry `i` holds `i` with its six low bits in reverse order.
// NB: 2^6=64 bytes is a cacheline. A smaller table wastes cache space.
const BIT_REVERSE_6BIT: &[u8] = &{
    let mut table = [0u8; 1 << 6];
    let mut i = 0;
    while i < table.len() {
        // Reverse all eight bits of the index, then discard the two low bits (always zero for
        // indices below 64) so the value fits in six bits.
        table[i] = (i as u8).reverse_bits() >> 2;
        i += 1;
    }
    table
};
|
||||
|
||||
pub(crate) fn reverse_bits(n: usize, num_bits: usize) -> usize {
|
||||
// NB: The only reason we need overflowing_shr() here as opposed
|
||||
// to plain '>>' is to accommodate the case n == num_bits == 0,
|
||||
@ -99,7 +195,7 @@ pub(crate) fn reverse_bits(n: usize, num_bits: usize) -> usize {
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use crate::util::{reverse_bits, reverse_index_bits};
|
||||
use crate::util::{reverse_bits, reverse_index_bits, reverse_index_bits_in_place};
|
||||
|
||||
#[test]
|
||||
fn test_reverse_bits() {
|
||||
@ -113,9 +209,57 @@ mod tests {
|
||||
#[test]
fn test_reverse_index_bits() {
    // Two index bits: only the middle pair trades places.
    assert_eq!(reverse_index_bits(&[10, 20, 30, 40]), vec![10, 30, 20, 40]);

    // Four index bits, still on the small (table-only) path.
    let input16 = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15];
    let output16 = vec![0, 8, 4, 12, 2, 10, 6, 14, 1, 9, 5, 13, 3, 11, 7, 15];
    assert_eq!(reverse_index_bits(&input16), output16);

    // Eight index bits exercise the large path. With the identity array as input, slot `i` of
    // the result must hold the 8-bit reversal of `i`.
    let input256: Vec<u64> = (0..256).collect();
    let output256: Vec<u64> = (0..=u8::MAX)
        .map(|byte| u64::from(byte.reverse_bits()))
        .collect();
    assert_eq!(reverse_index_bits(&input256[..]), output256);
}
|
||||
|
||||
#[test]
fn test_reverse_index_bits_in_place() {
    // Two index bits: only the middle pair trades places.
    let mut arr4: Vec<u64> = vec![10, 20, 30, 40];
    reverse_index_bits_in_place(&mut arr4);
    assert_eq!(arr4, vec![10, 30, 20, 40]);

    // Eight index bits exercise the large path. Starting from the identity array, slot `i` must
    // end up holding the 8-bit reversal of `i`.
    let mut arr256: Vec<u64> = (0..256).collect();
    let output256: Vec<u64> = (0..=u8::MAX)
        .map(|byte| u64::from(byte.reverse_bits()))
        .collect();
    reverse_index_bits_in_place(&mut arr256);
    assert_eq!(arr256, output256);
}
|
||||
}
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user