// plonky2/src/util/mod.rs
use crate::field::field_types::Field;
use crate::polynomial::polynomial::PolynomialValues;

pub(crate) mod bimap;
pub(crate) mod context_tree;
pub(crate) mod marking;
pub(crate) mod partial_products;
pub mod reducing;
pub(crate) mod timing;

/// Returns the number of bits needed to represent `n` in binary, i.e. `64 - n.leading_zeros()`.
/// In particular, `bits_u64(0) == 0`.
pub(crate) fn bits_u64(n: u64) -> usize {
    (64 - n.leading_zeros()) as usize
}

/// Computes `ceil(a / b)`, assuming `a + b` does not overflow.
pub(crate) const fn ceil_div_usize(a: usize, b: usize) -> usize {
    (a + b - 1) / b
}

/// Computes `ceil(log_2(n))`.
pub(crate) fn log2_ceil(n: usize) -> usize {
    n.next_power_of_two().trailing_zeros() as usize
}

/// Computes `log_2(n)`, panicking if `n` is not a power of two.
pub(crate) fn log2_strict(n: usize) -> usize {
    assert!(n.is_power_of_two(), "Not a power of two: {}", n);
    log2_ceil(n)
}

/// Given a list of polynomials in evaluation (value) form, returns the transpose of the matrix
/// whose rows are the polynomials' values.
pub(crate) fn transpose_poly_values<F: Field>(polys: Vec<PolynomialValues<F>>) -> Vec<Vec<F>> {
    let poly_values = polys.into_iter().map(|p| p.values).collect::<Vec<_>>();
    transpose(&poly_values)
}

/// Transposes a rectangular matrix. Assumes `matrix` is non-empty and that every row has the same
/// length.
pub fn transpose<F: Field>(matrix: &[Vec<F>]) -> Vec<Vec<F>> {
    let l = matrix.len();
    let w = matrix[0].len();

    let mut transposed = vec![vec![]; w];
    for i in 0..w {
        transposed[i].reserve_exact(l);
        unsafe {
            // After .reserve_exact(l), transposed[i] will have capacity at least l. Hence, set_len
            // will not cause the buffer to overrun. Every element is overwritten below before it
            // is ever read.
            transposed[i].set_len(l);
        }
    }

    // Optimization: ensure the larger loop is outside.
    if w >= l {
        for i in 0..w {
            for j in 0..l {
                transposed[i][j] = matrix[j][i];
            }
        }
    } else {
        for j in 0..l {
            for i in 0..w {
                transposed[i][j] = matrix[j][i];
            }
        }
    }
    transposed
}
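
// Illustrative sketch only: a hypothetical helper (compiled in test builds, not called by the
// tests below) showing that transposing a 2x3 matrix yields its 3x2 transpose. It assumes `Field`
// exposes `ZERO`/`ONE` constants and implements `Eq` + `Debug`, which the assertions require.
#[cfg(test)]
#[allow(dead_code)]
fn transpose_2x3_example<F: Field>() {
    let matrix = vec![
        vec![F::ZERO, F::ONE, F::ZERO],
        vec![F::ONE, F::ONE, F::ZERO],
    ];
    let transposed = transpose(&matrix);
    // Each column of the input becomes a row of the output.
    assert_eq!(transposed.len(), 3);
    assert_eq!(transposed[0], vec![F::ZERO, F::ONE]);
    assert_eq!(transposed[1], vec![F::ONE, F::ONE]);
    assert_eq!(transposed[2], vec![F::ZERO, F::ZERO]);
}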

/// Permutes `arr` such that each index is mapped to its reverse in binary.
pub(crate) fn reverse_index_bits<T: Copy>(arr: &[T]) -> Vec<T> {
    let n = arr.len();
    let n_power = log2_strict(n);

    if n_power <= 6 {
        reverse_index_bits_small(arr, n_power)
    } else {
        reverse_index_bits_large(arr, n_power)
    }
}

/* Both functions below are semantically equivalent to:
       for i in 0..n {
           result.push(arr[reverse_bits(i, n_power)]);
       }
   where reverse_bits(i, n_power) computes the n_power-bit reverse. The complications are there
   to guide the compiler to generate optimal assembly.
*/

fn reverse_index_bits_small<T: Copy>(arr: &[T], n_power: usize) -> Vec<T> {
    let n = arr.len();
    let mut result = Vec::with_capacity(n);
    // BIT_REVERSE_6BIT holds 6-bit reverses. This shift makes them n_power-bit reverses.
    let dst_shr_amt = 6 - n_power;
    for i in 0..n {
        let src = (BIT_REVERSE_6BIT[i] as usize) >> dst_shr_amt;
        result.push(arr[src]);
    }
    result
}

fn reverse_index_bits_large<T: Copy>(arr: &[T], n_power: usize) -> Vec<T> {
    let n = arr.len();
    // LLVM does not know that it does not need to reverse src at each iteration (which is
    // expensive on x86). We take advantage of the fact that the low bits of dst change rarely and
    // the high bits of dst are dependent only on the low bits of src.
    let src_lo_shr_amt = 64 - (n_power - 6);
    let src_hi_shl_amt = n_power - 6;
    let mut result = Vec::with_capacity(n);
    for i_chunk in 0..(n >> 6) {
        // The destination index is (i_chunk << 6) + i_lo. Its n_power-bit reverse splits into
        // src_lo, the (n_power - 6)-bit reverse of i_chunk, and src_hi, the 6-bit reverse of i_lo
        // shifted into the high bits.
        let src_lo = i_chunk.reverse_bits() >> src_lo_shr_amt;
        for i_lo in 0..(1 << 6) {
            let src_hi = (BIT_REVERSE_6BIT[i_lo] as usize) << src_hi_shl_amt;
            let src = src_hi + src_lo;
            result.push(arr[src]);
        }
    }
    result
}
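
// A direct, unoptimized version of the permutation that the two functions above implement. This is
// a hypothetical helper kept only as a test-build illustration of the semantics described in the
// comment preceding them; it is not referenced elsewhere.
#[cfg(test)]
#[allow(dead_code)]
fn reverse_index_bits_naive<T: Copy>(arr: &[T], n_power: usize) -> Vec<T> {
    (0..arr.len())
        .map(|i| arr[reverse_bits(i, n_power)])
        .collect()
}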

/// Permutes `arr` in place such that each index is mapped to its reverse in binary.
pub(crate) fn reverse_index_bits_in_place<T>(arr: &mut Vec<T>) {
    let n = arr.len();
    let n_power = log2_strict(n);

    if n_power <= 6 {
        reverse_index_bits_in_place_small(arr, n_power);
    } else {
        reverse_index_bits_in_place_large(arr, n_power);
    }
}

/* Both functions below are semantically equivalent to:
       for src in 0..n {
           let dst = reverse_bits(src, n_power);
           if src < dst {
               arr.swap(src, dst);
           }
       }
   where reverse_bits(src, n_power) computes the n_power-bit reverse. The `src < dst` check ensures
   that each pair of indices is swapped exactly once.
*/

fn reverse_index_bits_in_place_small<T>(arr: &mut Vec<T>, n_power: usize) {
    let n = arr.len();
    // BIT_REVERSE_6BIT holds 6-bit reverses. This shift makes them n_power-bit reverses.
    let dst_shr_amt = 6 - n_power;
    for src in 0..n {
        let dst = (BIT_REVERSE_6BIT[src] as usize) >> dst_shr_amt;
        if src < dst {
            arr.swap(src, dst);
        }
    }
}

fn reverse_index_bits_in_place_large<T>(arr: &mut Vec<T>, n_power: usize) {
    let n = arr.len();
    // LLVM does not know that it does not need to reverse src at each iteration (which is
    // expensive on x86). We take advantage of the fact that the low bits of dst change rarely and
    // the high bits of dst are dependent only on the low bits of src.
    let dst_lo_shr_amt = 64 - (n_power - 6);
    let dst_hi_shl_amt = n_power - 6;
    for src_chunk in 0..(n >> 6) {
        let src_hi = src_chunk << 6;
        let dst_lo = src_chunk.reverse_bits() >> dst_lo_shr_amt;
        for src_lo in 0..(1 << 6) {
            let dst_hi = (BIT_REVERSE_6BIT[src_lo] as usize) << dst_hi_shl_amt;
            let src = src_hi + src_lo;
            let dst = dst_hi + dst_lo;
            if src < dst {
                arr.swap(src, dst);
            }
        }
    }
}

// Lookup table of 6-bit reverses.
// NB: 2^6=64 bytes is a cacheline. A smaller table wastes cache space.
#[rustfmt::skip]
const BIT_REVERSE_6BIT: &[u8] = &[
    0o00, 0o40, 0o20, 0o60, 0o10, 0o50, 0o30, 0o70,
    0o04, 0o44, 0o24, 0o64, 0o14, 0o54, 0o34, 0o74,
    0o02, 0o42, 0o22, 0o62, 0o12, 0o52, 0o32, 0o72,
    0o06, 0o46, 0o26, 0o66, 0o16, 0o56, 0o36, 0o76,
    0o01, 0o41, 0o21, 0o61, 0o11, 0o51, 0o31, 0o71,
    0o05, 0o45, 0o25, 0o65, 0o15, 0o55, 0o35, 0o75,
    0o03, 0o43, 0o23, 0o63, 0o13, 0o53, 0o33, 0o73,
    0o07, 0o47, 0o27, 0o67, 0o17, 0o57, 0o37, 0o77,
];

/// Interprets the low `num_bits` bits of `n` as a `num_bits`-bit integer and returns its
/// bit-reversal; any higher bits of `n` are ignored.
pub(crate) fn reverse_bits(n: usize, num_bits: usize) -> usize {
    // NB: The only reason we need overflowing_shr() here as opposed to plain '>>' is to
    // accommodate the case n == num_bits == 0, which would become `0 >> 64`. Rust considers any
    // shift by 64 bits an overflow, even when the argument is zero.
    n.reverse_bits()
        .overflowing_shr(usize::BITS - num_bits as u32)
        .0
}

#[cfg(test)]
mod tests {
    use crate::util::{reverse_bits, reverse_index_bits, reverse_index_bits_in_place};

    #[test]
    fn test_reverse_bits() {
        assert_eq!(reverse_bits(0b0000000000, 10), 0b0000000000);
        assert_eq!(reverse_bits(0b0000000001, 10), 0b1000000000);
        assert_eq!(reverse_bits(0b1000000000, 10), 0b0000000001);
        assert_eq!(reverse_bits(0b00000, 5), 0b00000);
        assert_eq!(reverse_bits(0b01011, 5), 0b11010);
    }
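
    // Sanity checks for the small arithmetic helpers defined near the top of this module; they are
    // referenced via `super::` since they are not in the `use` list above.
    #[test]
    fn test_arithmetic_helpers() {
        assert_eq!(super::bits_u64(0), 0);
        assert_eq!(super::bits_u64(1), 1);
        assert_eq!(super::bits_u64(0b1000), 4);
        assert_eq!(super::ceil_div_usize(10, 4), 3);
        assert_eq!(super::ceil_div_usize(12, 4), 3);
        assert_eq!(super::log2_ceil(1), 0);
        assert_eq!(super::log2_ceil(5), 3);
        assert_eq!(super::log2_strict(32), 5);
    }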

    #[test]
    fn test_reverse_index_bits() {
        assert_eq!(reverse_index_bits(&[10, 20, 30, 40]), vec![10, 30, 20, 40]);

        let input256: Vec<u64> = (0..256).collect();
        #[rustfmt::skip]
        let output256: Vec<u64> = vec![
            0x00, 0x80, 0x40, 0xc0, 0x20, 0xa0, 0x60, 0xe0, 0x10, 0x90, 0x50, 0xd0, 0x30, 0xb0, 0x70, 0xf0,
            0x08, 0x88, 0x48, 0xc8, 0x28, 0xa8, 0x68, 0xe8, 0x18, 0x98, 0x58, 0xd8, 0x38, 0xb8, 0x78, 0xf8,
            0x04, 0x84, 0x44, 0xc4, 0x24, 0xa4, 0x64, 0xe4, 0x14, 0x94, 0x54, 0xd4, 0x34, 0xb4, 0x74, 0xf4,
            0x0c, 0x8c, 0x4c, 0xcc, 0x2c, 0xac, 0x6c, 0xec, 0x1c, 0x9c, 0x5c, 0xdc, 0x3c, 0xbc, 0x7c, 0xfc,
            0x02, 0x82, 0x42, 0xc2, 0x22, 0xa2, 0x62, 0xe2, 0x12, 0x92, 0x52, 0xd2, 0x32, 0xb2, 0x72, 0xf2,
            0x0a, 0x8a, 0x4a, 0xca, 0x2a, 0xaa, 0x6a, 0xea, 0x1a, 0x9a, 0x5a, 0xda, 0x3a, 0xba, 0x7a, 0xfa,
            0x06, 0x86, 0x46, 0xc6, 0x26, 0xa6, 0x66, 0xe6, 0x16, 0x96, 0x56, 0xd6, 0x36, 0xb6, 0x76, 0xf6,
            0x0e, 0x8e, 0x4e, 0xce, 0x2e, 0xae, 0x6e, 0xee, 0x1e, 0x9e, 0x5e, 0xde, 0x3e, 0xbe, 0x7e, 0xfe,
            0x01, 0x81, 0x41, 0xc1, 0x21, 0xa1, 0x61, 0xe1, 0x11, 0x91, 0x51, 0xd1, 0x31, 0xb1, 0x71, 0xf1,
            0x09, 0x89, 0x49, 0xc9, 0x29, 0xa9, 0x69, 0xe9, 0x19, 0x99, 0x59, 0xd9, 0x39, 0xb9, 0x79, 0xf9,
            0x05, 0x85, 0x45, 0xc5, 0x25, 0xa5, 0x65, 0xe5, 0x15, 0x95, 0x55, 0xd5, 0x35, 0xb5, 0x75, 0xf5,
            0x0d, 0x8d, 0x4d, 0xcd, 0x2d, 0xad, 0x6d, 0xed, 0x1d, 0x9d, 0x5d, 0xdd, 0x3d, 0xbd, 0x7d, 0xfd,
            0x03, 0x83, 0x43, 0xc3, 0x23, 0xa3, 0x63, 0xe3, 0x13, 0x93, 0x53, 0xd3, 0x33, 0xb3, 0x73, 0xf3,
            0x0b, 0x8b, 0x4b, 0xcb, 0x2b, 0xab, 0x6b, 0xeb, 0x1b, 0x9b, 0x5b, 0xdb, 0x3b, 0xbb, 0x7b, 0xfb,
            0x07, 0x87, 0x47, 0xc7, 0x27, 0xa7, 0x67, 0xe7, 0x17, 0x97, 0x57, 0xd7, 0x37, 0xb7, 0x77, 0xf7,
            0x0f, 0x8f, 0x4f, 0xcf, 0x2f, 0xaf, 0x6f, 0xef, 0x1f, 0x9f, 0x5f, 0xdf, 0x3f, 0xbf, 0x7f, 0xff,
        ];
        assert_eq!(reverse_index_bits(&input256[..]), output256);
    }
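
    // `BIT_REVERSE_6BIT` is a lookup table of 6-bit reverses; this cross-checks every entry
    // against `reverse_bits`, which computes the same reversal directly.
    #[test]
    fn test_bit_reverse_6bit_table() {
        for i in 0..64usize {
            assert_eq!(super::BIT_REVERSE_6BIT[i] as usize, reverse_bits(i, 6));
        }
    }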

    #[test]
    fn test_reverse_index_bits_in_place() {
        let mut arr4: Vec<u64> = vec![10, 20, 30, 40];
        reverse_index_bits_in_place(&mut arr4);
        assert_eq!(arr4, vec![10, 30, 20, 40]);

        let mut arr256: Vec<u64> = (0..256).collect();
        #[rustfmt::skip]
        let output256: Vec<u64> = vec![
            0x00, 0x80, 0x40, 0xc0, 0x20, 0xa0, 0x60, 0xe0, 0x10, 0x90, 0x50, 0xd0, 0x30, 0xb0, 0x70, 0xf0,
            0x08, 0x88, 0x48, 0xc8, 0x28, 0xa8, 0x68, 0xe8, 0x18, 0x98, 0x58, 0xd8, 0x38, 0xb8, 0x78, 0xf8,
            0x04, 0x84, 0x44, 0xc4, 0x24, 0xa4, 0x64, 0xe4, 0x14, 0x94, 0x54, 0xd4, 0x34, 0xb4, 0x74, 0xf4,
            0x0c, 0x8c, 0x4c, 0xcc, 0x2c, 0xac, 0x6c, 0xec, 0x1c, 0x9c, 0x5c, 0xdc, 0x3c, 0xbc, 0x7c, 0xfc,
            0x02, 0x82, 0x42, 0xc2, 0x22, 0xa2, 0x62, 0xe2, 0x12, 0x92, 0x52, 0xd2, 0x32, 0xb2, 0x72, 0xf2,
            0x0a, 0x8a, 0x4a, 0xca, 0x2a, 0xaa, 0x6a, 0xea, 0x1a, 0x9a, 0x5a, 0xda, 0x3a, 0xba, 0x7a, 0xfa,
            0x06, 0x86, 0x46, 0xc6, 0x26, 0xa6, 0x66, 0xe6, 0x16, 0x96, 0x56, 0xd6, 0x36, 0xb6, 0x76, 0xf6,
            0x0e, 0x8e, 0x4e, 0xce, 0x2e, 0xae, 0x6e, 0xee, 0x1e, 0x9e, 0x5e, 0xde, 0x3e, 0xbe, 0x7e, 0xfe,
            0x01, 0x81, 0x41, 0xc1, 0x21, 0xa1, 0x61, 0xe1, 0x11, 0x91, 0x51, 0xd1, 0x31, 0xb1, 0x71, 0xf1,
            0x09, 0x89, 0x49, 0xc9, 0x29, 0xa9, 0x69, 0xe9, 0x19, 0x99, 0x59, 0xd9, 0x39, 0xb9, 0x79, 0xf9,
            0x05, 0x85, 0x45, 0xc5, 0x25, 0xa5, 0x65, 0xe5, 0x15, 0x95, 0x55, 0xd5, 0x35, 0xb5, 0x75, 0xf5,
            0x0d, 0x8d, 0x4d, 0xcd, 0x2d, 0xad, 0x6d, 0xed, 0x1d, 0x9d, 0x5d, 0xdd, 0x3d, 0xbd, 0x7d, 0xfd,
            0x03, 0x83, 0x43, 0xc3, 0x23, 0xa3, 0x63, 0xe3, 0x13, 0x93, 0x53, 0xd3, 0x33, 0xb3, 0x73, 0xf3,
            0x0b, 0x8b, 0x4b, 0xcb, 0x2b, 0xab, 0x6b, 0xeb, 0x1b, 0x9b, 0x5b, 0xdb, 0x3b, 0xbb, 0x7b, 0xfb,
            0x07, 0x87, 0x47, 0xc7, 0x27, 0xa7, 0x67, 0xe7, 0x17, 0x97, 0x57, 0xd7, 0x37, 0xb7, 0x77, 0xf7,
            0x0f, 0x8f, 0x4f, 0xcf, 0x2f, 0xaf, 0x6f, 0xef, 0x1f, 0x9f, 0x5f, 0xdf, 0x3f, 0xbf, 0x7f, 0xff,
        ];
        reverse_index_bits_in_place(&mut arr256);
        assert_eq!(arr256, output256);
    }
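
    // Cross-check: the in-place permutation should agree with the out-of-place one on both the
    // "small" (n <= 64) and "large" (n > 64) code paths.
    #[test]
    fn test_reverse_index_bits_consistency() {
        for &n in &[16usize, 256] {
            let input: Vec<u64> = (0..n as u64).collect();
            let expected = reverse_index_bits(&input);
            let mut in_place = input.clone();
            reverse_index_bits_in_place(&mut in_place);
            assert_eq!(in_place, expected);
        }
    }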
}