Mirror of https://github.com/logos-storage/plonky2.git, synced 2026-01-09 09:13:09 +00:00
Style (incl. Daniel PR comments)
parent 7ee7d8bf8a
commit 87f5201e6f
@@ -13,6 +13,6 @@ impl<F: Field> Packable for F {
}

#[cfg(target_feature = "avx2")]
impl Packable for CrandallField {
impl Packable for crate::field::crandall_field::CrandallField {
    type PackedType = crate::field::packed_crandall_avx2::PackedCrandallAVX2;
}

@@ -5,7 +5,6 @@ use std::iter::{Product, Sum};
use std::ops::{Add, AddAssign, Mul, MulAssign, Neg, Sub, SubAssign};

use crate::field::crandall_field::CrandallField;
use crate::field::field_types::Field;
use crate::field::packed_field::PackedField;

// PackedCrandallAVX2 wraps an array of four u64s, with the new and get methods to convert that
@@ -147,10 +146,10 @@ impl PackedField for PackedCrandallAVX2 {
    }
    #[inline]
    fn to_vec(&self) -> Vec<Self::FieldType> {
        let a = unsafe { _mm256_extract_epi64(self.get(), 0) } as u64;
        let b = unsafe { _mm256_extract_epi64(self.get(), 1) } as u64;
        let c = unsafe { _mm256_extract_epi64(self.get(), 2) } as u64;
        let d = unsafe { _mm256_extract_epi64(self.get(), 3) } as u64;
        let a = unsafe { _mm256_extract_epi64::<0>(self.get()) } as u64;
        let b = unsafe { _mm256_extract_epi64::<1>(self.get()) } as u64;
        let c = unsafe { _mm256_extract_epi64::<2>(self.get()) } as u64;
        let d = unsafe { _mm256_extract_epi64::<3>(self.get()) } as u64;
        vec![
            CrandallField(a),
            CrandallField(b),
@@ -293,20 +292,26 @@ unsafe fn canonicalize_s(x_s: __m256i) -> __m256i {
    _mm256_add_epi64(x_s, wrapback_amt)
}

/// Addition u64 + u64 -> u64. Assumes that x + y < 2^64 + FIELD_ORDER. The second argument is
/// pre-shifted by 1 << 63. The result is similarly shifted.
#[inline]
unsafe fn add_no_canonicalize_64_64s_s(x: __m256i, y_s: __m256i) -> __m256i {
    let res_wrapped_s = _mm256_add_epi64(x, y_s);
    let mask = _mm256_cmpgt_epi64(y_s, res_wrapped_s); // -1 if overflowed else 0.
    let wrapback_amt = _mm256_and_si256(mask, epsilon()); // -FIELD_ORDER if overflowed else 0.
    let res_s = _mm256_add_epi64(res_wrapped_s, wrapback_amt);
    res_s
}
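A minimal scalar sketch of the shifted, branchless addition described above (illustrative only, not code from this commit; it assumes the Crandall prime 2^64 - 9 * 2^28 + 1):

const FIELD_ORDER: u64 = 18446744071293632513; // assumed Crandall prime: 2^64 - 9 * 2^28 + 1
const EPSILON: u64 = 2415919103; // 2^64 - FIELD_ORDER, so 2^64 ≡ EPSILON (mod FIELD_ORDER)

// AVX2 only has signed 64-bit compares, so operands are "shifted" by 1 << 63: unsigned order on
// the original values becomes signed order on the shifted ones. shift is its own inverse.
fn shift(x: u64) -> u64 {
    x ^ (1 << 63)
}

// Scalar analogue of add_no_canonicalize_64_64s_s: y_s is pre-shifted and the result is shifted.
fn add_no_canonicalize_64_64s_s(x: u64, y_s: u64) -> u64 {
    let res_wrapped_s = x.wrapping_add(y_s); // == shift((x + y) mod 2^64)
    // A signed compare of shifted values is an unsigned compare of the underlying values, so this
    // detects x + y overflowing 2^64; adding EPSILON then compensates for the lost 2^64.
    let wrapback_amt = if (y_s as i64) > (res_wrapped_s as i64) { EPSILON } else { 0 };
    res_wrapped_s.wrapping_add(wrapback_amt)
}

fn main() {
    let (x, y) = (FIELD_ORDER - 1, FIELD_ORDER - 2); // x + y < 2^64 + FIELD_ORDER holds
    let sum = shift(add_no_canonicalize_64_64s_s(x, shift(y)));
    assert_eq!(sum % FIELD_ORDER, FIELD_ORDER - 3);
}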

// Theoretical throughput (Skylake)
// Scalar version (compiled): 1.75 cycles/(op * word)
// Scalar version (optimized asm): 1 cycle/(op * word)
// Below (256-bit vectors): .75 cycles/(op * word)
#[inline]
unsafe fn add(x: __m256i, y: __m256i) -> __m256i {
    let mut y_s = shift(y);
    y_s = canonicalize_s(y_s);
    let res_wrapped_s = _mm256_add_epi64(x, y_s);
    let mask = _mm256_cmpgt_epi64(y_s, res_wrapped_s); // 1 if overflowed else 0.
    let res_wrapped = shift(res_wrapped_s);
    let wrapback_amt = _mm256_and_si256(mask, epsilon()); // -FIELD_ORDER if overflowed else 0.
    let res = _mm256_add_epi64(res_wrapped, wrapback_amt);
    res
    let y_s = shift(y);
    let res_s = add_no_canonicalize_64_64s_s(x, canonicalize_s(y_s));
    shift(res_s)
}

// Theoretical throughput (Skylake)
@@ -318,7 +323,7 @@ unsafe fn sub(x: __m256i, y: __m256i) -> __m256i {
    let mut y_s = shift(y);
    y_s = canonicalize_s(y_s);
    let x_s = shift(x);
    let mask = _mm256_cmpgt_epi64(y_s, x_s); // 1 if sub will underflow (y > y) else 0.
    let mask = _mm256_cmpgt_epi64(y_s, x_s); // -1 if sub will underflow (y > x) else 0.
    let wrapback_amt = _mm256_and_si256(mask, epsilon()); // -FIELD_ORDER if underflow else 0.
    let res_wrapped = _mm256_sub_epi64(x_s, y_s);
    let res = _mm256_sub_epi64(res_wrapped, wrapback_amt);
@@ -333,7 +338,8 @@ unsafe fn sub(x: __m256i, y: __m256i) -> __m256i {
unsafe fn neg(y: __m256i) -> __m256i {
    let y_s = shift(y);
    let field_order_s = shift(field_order());
    let mask = _mm256_cmpgt_epi64(y_s, field_order_s); // 1 if sub will underflow (y > y) else 0.
    // mask is -1 if sub will underflow (y > field_order) else 0.
    let mask = _mm256_cmpgt_epi64(y_s, field_order_s);
    let wrapback_amt = _mm256_and_si256(mask, epsilon()); // -FIELD_ORDER if underflow else 0.
    let res_wrapped = _mm256_sub_epi64(field_order_s, y_s);
    let res = _mm256_sub_epi64(res_wrapped, wrapback_amt);
@@ -352,39 +358,35 @@ unsafe fn mul64_64_s(x: __m256i, y: __m256i) -> (__m256i, __m256i) {
    let mul_hh = _mm256_mul_epu32(x_hi, y_hi);

    let res_lo0_s = shift(mul_ll);
    let res_hi0 = mul_hh;
    let res_lo1_s = _mm256_add_epi32(res_lo0_s, _mm256_slli_epi64(mul_lh, 32));
    let res_lo2_s = _mm256_add_epi32(res_lo1_s, _mm256_slli_epi64(mul_hl, 32));

    // cmpgt returns -1 on true and 0 on false. Hence, the carry values below are set to -1 on
    // overflow and must be subtracted, not added.
    let carry0 = _mm256_cmpgt_epi64(res_lo0_s, res_lo1_s);
    let carry1 = _mm256_cmpgt_epi64(res_lo1_s, res_lo2_s);

    let res_hi0 = mul_hh;
    let res_hi1 = _mm256_add_epi64(res_hi0, _mm256_srli_epi64(mul_lh, 32));
    let res_hi2 = _mm256_add_epi64(res_hi1, _mm256_srli_epi64(mul_hl, 32));
    let res_hi3 = _mm256_sub_epi64(res_hi2, carry0);
    let res_hi4 = _mm256_sub_epi64(res_hi3, carry1);

    let res_lo3_s = _mm256_add_epi32(res_lo0_s, _mm256_slli_epi64(mul_lh, 32));
    let res_hi3 = _mm256_sub_epi64(res_hi2, _mm256_cmpgt_epi64(res_lo0_s, res_lo3_s)); // Carry.

    let res_lo4_s = _mm256_add_epi32(res_lo3_s, _mm256_slli_epi64(mul_hl, 32));
    let res_hi4 = _mm256_sub_epi64(res_hi3, _mm256_cmpgt_epi64(res_lo3_s, res_lo4_s)); // Carry.

    (res_hi4, res_lo4_s)
    (res_hi4, res_lo2_s)
}
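For reference, a scalar sketch of the same 64x64 -> 128 schoolbook multiply (an illustration, not code from this commit): AVX2 only offers a 32x32 -> 64 multiply, hence the four partial products and explicit carries above; a scalar version can use overflowing_add instead of the shifted signed compares.

// Split each operand into 32-bit halves, form four partial products, then recombine,
// propagating carries out of the low 64 bits exactly as the vector code does.
fn mul64_64(x: u64, y: u64) -> (u64, u64) {
    let (x_lo, x_hi) = (x as u32 as u64, x >> 32);
    let (y_lo, y_hi) = (y as u32 as u64, y >> 32);
    let mul_ll = x_lo * y_lo;
    let mul_lh = x_lo * y_hi;
    let mul_hl = x_hi * y_lo;
    let mul_hh = x_hi * y_hi;

    // Low 64 bits: add the low halves of the cross terms; each addition can carry out.
    // (The vector version keeps this word shifted by 2^63 so it can detect the carries with
    // signed compares; plain overflow flags do the job here.)
    let (res_lo1, carry0) = mul_ll.overflowing_add(mul_lh << 32);
    let (res_lo2, carry1) = res_lo1.overflowing_add(mul_hl << 32);
    // High 64 bits: high halves of the cross terms plus the carries from below.
    let res_hi = mul_hh + (mul_lh >> 32) + (mul_hl >> 32) + carry0 as u64 + carry1 as u64;
    (res_hi, res_lo2)
}

fn main() {
    let (x, y) = (0x0123_4567_89ab_cdef_u64, 0xfedc_ba98_7654_3210_u64);
    let full = (x as u128) * (y as u128);
    assert_eq!(mul64_64(x, y), ((full >> 64) as u64, full as u64));
}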

/// u128 + u64 addition with carry. The second argument is pre-shifted by 2^63. The result is also
/// shifted.
/// (u64 << 64) + u64 + u64 -> u128 addition with carry. The third argument is pre-shifted by 2^63.
/// The result is also shifted.
#[inline]
unsafe fn add_with_carry128_64s_s(x: (__m256i, __m256i), y_s: __m256i) -> (__m256i, __m256i) {
    let (x_hi, x_lo) = x;
    let res_lo_s = _mm256_add_epi64(x_lo, y_s);
    let carry = _mm256_cmpgt_epi64(y_s, res_lo_s);
    let res_hi = _mm256_sub_epi64(x_hi, carry);
    (res_hi, res_lo_s)
}

/// u128 + u64 addition with carry. The first argument is pre-shifted by 2^63. The result is also
/// shifted.
#[inline]
unsafe fn add_with_carry128s_64_s(x_s: (__m256i, __m256i), y: __m256i) -> (__m256i, __m256i) {
    let (x_hi, x_lo_s) = x_s;
    let res_lo_s = _mm256_add_epi64(x_lo_s, y);
    let carry = _mm256_cmpgt_epi64(x_lo_s, res_lo_s);
    let res_hi = _mm256_sub_epi64(x_hi, carry);
unsafe fn add_with_carry_hi_lo_los_s(
    hi: __m256i,
    lo0: __m256i,
    lo1_s: __m256i,
) -> (__m256i, __m256i) {
    let res_lo_s = _mm256_add_epi64(lo0, lo1_s);
    // carry is -1 if overflow (res_lo < lo1) because cmpgt returns -1 on true and 0 on false.
    let carry = _mm256_cmpgt_epi64(lo1_s, res_lo_s);
    let res_hi = _mm256_sub_epi64(hi, carry);
    (res_hi, res_lo_s)
}
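A scalar model of this helper, with the comparison trick spelled out (illustrative only, not code from this commit):

// (hi << 64) + lo0 + lo1, where lo1 and the returned low word carry the 2^63 shift.
fn add_with_carry_hi_lo_los_s(hi: u64, lo0: u64, lo1_s: u64) -> (u64, u64) {
    let res_lo_s = lo0.wrapping_add(lo1_s);
    // Signed compare of shifted values == unsigned compare of the underlying values, so this asks
    // "did lo0 + lo1 wrap past 2^64?"; if so, bump the high word.
    let carry = (lo1_s as i64) > (res_lo_s as i64);
    (hi.wrapping_add(carry as u64), res_lo_s)
}

fn main() {
    let shift = |x: u64| x ^ (1 << 63);
    // 1 * 2^64 + (2^64 - 1) + 5  ==  2 * 2^64 + 4
    assert_eq!(add_with_carry_hi_lo_los_s(1, u64::MAX, shift(5)), (2, shift(4)));
}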

@@ -395,8 +397,8 @@ unsafe fn fmadd_64_32_64s_s(x: __m256i, y: __m256i, z_s: __m256i) -> (__m256i, _
    let x_hi = _mm256_srli_epi64(x, 32);
    let mul_lo = _mm256_mul_epu32(x, y);
    let mul_hi = _mm256_mul_epu32(x_hi, y);
    let tmp_s = add_with_carry128_64s_s((_mm256_srli_epi64(mul_hi, 32), mul_lo), z_s);
    add_with_carry128s_64_s(tmp_s, _mm256_slli_epi64(mul_hi, 32))
    let (tmp_hi, tmp_lo_s) = add_with_carry_hi_lo_los_s(_mm256_srli_epi64(mul_hi, 32), mul_lo, z_s);
    add_with_carry_hi_lo_los_s(tmp_hi, _mm256_slli_epi64(mul_hi, 32), tmp_lo_s)
}

/// Reduce a u128 modulo FIELD_ORDER. The input is (u64, u64), pre-shifted by 2^63. The result is
@@ -406,10 +408,7 @@ unsafe fn reduce128s_s(x_s: (__m256i, __m256i)) -> __m256i {
    let (hi0, lo0_s) = x_s;
    let (hi1, lo1_s) = fmadd_64_32_64s_s(hi0, epsilon(), lo0_s);
    let lo2 = _mm256_mul_epu32(hi1, epsilon());
    let res_wrapped_s = _mm256_add_epi64(lo1_s, lo2);
    let carry_mask = _mm256_cmpgt_epi64(lo1_s, res_wrapped_s); // all 1 if overflow
    let res_s = _mm256_add_epi64(res_wrapped_s, _mm256_and_si256(carry_mask, epsilon()));
    res_s
    add_no_canonicalize_64_64s_s(lo2, lo1_s)
}
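A scalar sketch of this reduction (illustrative only, not code from this commit; it assumes the Crandall prime 2^64 - 9 * 2^28 + 1, so 2^64 ≡ EPSILON and hi * 2^64 + lo ≡ hi * EPSILON + lo mod FIELD_ORDER):

const FIELD_ORDER: u64 = 18446744071293632513; // assumed Crandall prime: 2^64 - 9 * 2^28 + 1
const EPSILON: u64 = 2415919103; // 2^64 - FIELD_ORDER

// Reduce hi * 2^64 + lo to a single u64 congruent to it mod FIELD_ORDER (not fully canonical).
fn reduce128(hi: u64, lo: u64) -> u64 {
    // First round: hi * 2^64 ≡ hi * EPSILON, and EPSILON < 2^32, so the result fits in ~96 bits.
    let t = (hi as u128) * (EPSILON as u128) + lo as u128;
    let (t_hi, t_lo) = ((t >> 64) as u64, t as u64); // t_hi <= EPSILON
    // Second round: fold the small high word back in the same way; one conditional EPSILON
    // fix-up handles the final wraparound, as in add_no_canonicalize_64_64s_s above.
    let (res_wrapped, overflowed) = t_lo.overflowing_add(t_hi * EPSILON);
    res_wrapped.wrapping_add(if overflowed { EPSILON } else { 0 })
}

fn main() {
    let (hi, lo) = (0xffff_ffff_ffff_ffff_u64, 0x1234_5678_9abc_def0_u64);
    let x = ((hi as u128) << 64) | lo as u128;
    assert_eq!(reduce128(hi, lo) as u128 % FIELD_ORDER as u128, x % FIELD_ORDER as u128);
}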

/// Multiply two integers modulo FIELD_ORDER.
@@ -430,7 +429,7 @@ unsafe fn interleave1(x: __m256i, y: __m256i) -> (__m256i, __m256i) {
    let y_lo = _mm256_castsi256_si128(y); // This has 0 cost.

    // 1 places y_lo in the high half of x; 0 would place it in the lower half.
    let a = _mm256_inserti128_si256(x, y_lo, 1);
    let a = _mm256_inserti128_si256::<1>(x, y_lo);
    // NB: _mm256_permute2x128_si256 could be used here as well but _mm256_inserti128_si256 has
    // lower latency on Zen 3 processors.

@@ -440,7 +439,7 @@ unsafe fn interleave1(x: __m256i, y: __m256i) -> (__m256i, __m256i) {
    // 2 => src2[low 128 bits]
    // 3 => src2[high 128 bits]
    // The low (resp. high) nibble chooses the low (resp. high) 128 bits of the result.
    let b = _mm256_permute2x128_si256(x, y, 0x31);
    let b = _mm256_permute2x128_si256::<0x31>(x, y);

    (a, b)
}
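A scalar picture of what the two intrinsics in interleave1 compute on the four u64 lanes (illustrative only, not code from this commit; the lane values mirror the test_interleave data below):

// _mm256_inserti128_si256::<1>(x, y_lo): keep x's low 128 bits, put y's low 128 bits on top.
fn inserti128_hi(x: [u64; 4], y: [u64; 4]) -> [u64; 4] {
    [x[0], x[1], y[0], y[1]]
}

// _mm256_permute2x128_si256::<0x31>(x, y): low nibble 1 selects x's high 128 bits for the low
// half of the result, high nibble 3 selects y's high 128 bits for the high half.
fn permute2x128_0x31(x: [u64; 4], y: [u64; 4]) -> [u64; 4] {
    [x[2], x[3], y[2], y[3]]
}

fn main() {
    let (x, y) = ([0, 1, 2, 3], [10, 11, 12, 13]);
    assert_eq!(inserti128_hi(x, y), [0, 1, 10, 11]);      // matches int1_a in the test below
    assert_eq!(permute2x128_0x31(x, y), [2, 3, 12, 13]);  // matches int1_b
}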

@@ -514,17 +513,14 @@ mod tests {
        let packed_res = packed_a - packed_b;
        let arr_res = packed_res.to_vec();

        let expected = TEST_VALS_A
            .iter()
            .zip(TEST_VALS_B.iter())
            .map(|(&a, &b)| a - b);
        let expected = TEST_VALS_A.iter().zip(TEST_VALS_B).map(|(&a, &b)| a - b);
        for (exp, res) in expected.zip(arr_res) {
            assert_eq!(res, exp);
        }
    }

    #[test]
    fn test_interleave_is_bijection() {
    fn test_interleave_is_involution() {
        let packed_a = PackedCrandallAVX2::new_from_slice(TEST_VALS_A);
        let packed_b = PackedCrandallAVX2::new_from_slice(TEST_VALS_B);
        {
@@ -544,54 +540,54 @@ mod tests {

    #[test]
    fn test_interleave() {
        let arr_a: [CrandallField; 4] = [
        let in_a: [CrandallField; 4] = [
            CrandallField(00),
            CrandallField(01),
            CrandallField(02),
            CrandallField(03),
        ];
        let arr_b: [CrandallField; 4] = [
        let in_b: [CrandallField; 4] = [
            CrandallField(10),
            CrandallField(11),
            CrandallField(12),
            CrandallField(13),
        ];
        let arr_x0: [CrandallField; 4] = [
        let int0_a: [CrandallField; 4] = [
            CrandallField(00),
            CrandallField(10),
            CrandallField(02),
            CrandallField(12),
        ];
        let arr_y0: [CrandallField; 4] = [
        let int0_b: [CrandallField; 4] = [
            CrandallField(01),
            CrandallField(11),
            CrandallField(03),
            CrandallField(13),
        ];
        let arr_x1: [CrandallField; 4] = [
        let int1_a: [CrandallField; 4] = [
            CrandallField(00),
            CrandallField(01),
            CrandallField(10),
            CrandallField(11),
        ];
        let arr_y1: [CrandallField; 4] = [
        let int1_b: [CrandallField; 4] = [
            CrandallField(02),
            CrandallField(03),
            CrandallField(12),
            CrandallField(13),
        ];

        let packed_a = PackedCrandallAVX2::new_from_slice(&arr_a);
        let packed_b = PackedCrandallAVX2::new_from_slice(&arr_b);
        let packed_a = PackedCrandallAVX2::new_from_slice(&in_a);
        let packed_b = PackedCrandallAVX2::new_from_slice(&in_b);
        {
            let (x0, y0) = packed_a.interleave(packed_b, 0);
            assert_eq!(x0.to_vec()[..], arr_x0);
            assert_eq!(y0.to_vec()[..], arr_y0);
            assert_eq!(x0.to_vec()[..], int0_a);
            assert_eq!(y0.to_vec()[..], int0_b);
        }
        {
            let (x1, y1) = packed_a.interleave(packed_b, 1);
            assert_eq!(x1.to_vec()[..], arr_x1);
            assert_eq!(y1.to_vec()[..], arr_y1);
            assert_eq!(x1.to_vec()[..], int1_a);
            assert_eq!(y1.to_vec()[..], int1_b);
        }
    }
}