From bb029db2a7100720ae320ff1962405de0004001f Mon Sep 17 00:00:00 2001 From: Jakub Nabaglo Date: Fri, 3 Dec 2021 13:12:19 -0800 Subject: [PATCH] Type tweaks for packed types (#387) * PackedField tweaks * AVX2 changes * FFT fixes * tests * test fixes * Lints * Rename things for clarity * Minor interleave fixes * Minor interleave fixes the sequel * Rebase fixes * Docs * Daniel PR comments --- src/field/fft.rs | 21 ++- src/field/packable.rs | 6 +- ...ked_prime_field.rs => avx2_prime_field.rs} | 166 +++++++++++------- src/field/packed_avx2/common.rs | 2 +- src/field/packed_avx2/goldilocks.rs | 4 +- src/field/packed_avx2/mod.rs | 125 +++++++------ src/field/packed_field.rs | 115 ++++++------ 7 files changed, 255 insertions(+), 184 deletions(-) rename src/field/packed_avx2/{packed_prime_field.rs => avx2_prime_field.rs} (75%) diff --git a/src/field/fft.rs b/src/field/fft.rs index 09672278..76e0fd42 100644 --- a/src/field/fft.rs +++ b/src/field/fft.rs @@ -98,12 +98,12 @@ pub fn ifft_with_options( /// Generic FFT implementation that works with both scalar and packed inputs. #[unroll_for_loops] fn fft_classic_simd( - values: &mut [P::FieldType], + values: &mut [P::Scalar], r: usize, lg_n: usize, - root_table: &FftRootTable, + root_table: &FftRootTable, ) { - let lg_packed_width = P::LOG2_WIDTH; // 0 when P is a scalar. + let lg_packed_width = log2_strict(P::WIDTH); // 0 when P is a scalar. let packed_values = P::pack_slice_mut(values); let packed_n = packed_values.len(); debug_assert!(packed_n == 1 << (lg_n - lg_packed_width)); @@ -121,19 +121,18 @@ fn fft_classic_simd( let half_m = 1 << lg_half_m; // Set omega to root_table[lg_half_m][0..half_m] but repeated. - let mut omega_vec = P::zero().to_vec(); - for (j, omega) in omega_vec.iter_mut().enumerate() { - *omega = root_table[lg_half_m][j % half_m]; + let mut omega = P::ZERO; + for (j, omega_j) in omega.as_slice_mut().iter_mut().enumerate() { + *omega_j = root_table[lg_half_m][j % half_m]; } - let omega = P::from_slice(&omega_vec[..]); for k in (0..packed_n).step_by(2) { // We have two vectors and want to do math on pairs of adjacent elements (or for // lg_half_m > 0, pairs of adjacent blocks of elements). .interleave does the // appropriate shuffling and is its own inverse. - let (u, v) = packed_values[k].interleave(packed_values[k + 1], lg_half_m); + let (u, v) = packed_values[k].interleave(packed_values[k + 1], half_m); let t = omega * v; - (packed_values[k], packed_values[k + 1]) = (u + t).interleave(u - t, lg_half_m); + (packed_values[k], packed_values[k + 1]) = (u + t).interleave(u - t, half_m); } } } @@ -197,13 +196,13 @@ pub(crate) fn fft_classic(input: &[F], r: usize, root_table: &FftRootT } } - let lg_packed_width = ::PackedType::LOG2_WIDTH; + let lg_packed_width = log2_strict(::Packing::WIDTH); if lg_n <= lg_packed_width { // Need the slice to be at least the width of two packed vectors for the vectorized version // to work. Do this tiny problem in scalar. fft_classic_simd::(&mut values[..], r, lg_n, root_table); } else { - fft_classic_simd::<::PackedType>(&mut values[..], r, lg_n, root_table); + fft_classic_simd::<::Packing>(&mut values[..], r, lg_n, root_table); } values } diff --git a/src/field/packable.rs b/src/field/packable.rs index e5fc2ac5..a3f96197 100644 --- a/src/field/packable.rs +++ b/src/field/packable.rs @@ -5,14 +5,14 @@ use crate::field::packed_field::PackedField; /// PackedField for a particular Field (e.g. every Field is also a PackedField), but this is the /// recommended one. 
The recommended packing varies by target_arch and target_feature. pub trait Packable: Field { - type PackedType: PackedField; + type Packing: PackedField; } impl Packable for F { - default type PackedType = Self; + default type Packing = Self; } #[cfg(target_feature = "avx2")] impl Packable for crate::field::goldilocks_field::GoldilocksField { - type PackedType = crate::field::packed_avx2::PackedGoldilocksAVX2; + type Packing = crate::field::packed_avx2::PackedGoldilocksAvx2; } diff --git a/src/field/packed_avx2/packed_prime_field.rs b/src/field/packed_avx2/avx2_prime_field.rs similarity index 75% rename from src/field/packed_avx2/packed_prime_field.rs rename to src/field/packed_avx2/avx2_prime_field.rs index 5800d0bd..b42814c2 100644 --- a/src/field/packed_avx2/packed_prime_field.rs +++ b/src/field/packed_avx2/avx2_prime_field.rs @@ -2,20 +2,20 @@ use core::arch::x86_64::*; use std::fmt; use std::fmt::{Debug, Formatter}; use std::iter::{Product, Sum}; -use std::ops::{Add, AddAssign, Mul, MulAssign, Neg, Sub, SubAssign}; +use std::ops::{Add, AddAssign, Div, DivAssign, Mul, MulAssign, Neg, Sub, SubAssign}; use crate::field::field_types::PrimeField; use crate::field::packed_avx2::common::{ - add_no_canonicalize_64_64s_s, epsilon, field_order, shift, ReducibleAVX2, + add_no_canonicalize_64_64s_s, epsilon, field_order, shift, ReducibleAvx2, }; use crate::field::packed_field::PackedField; -// PackedPrimeField wraps an array of four u64s, with the new and get methods to convert that +// Avx2PrimeField wraps an array of four u64s, with the new and get methods to convert that // array to and from __m256i, which is the type we actually operate on. This indirection is a -// terrible trick to change PackedPrimeField's alignment. -// We'd like to be able to cast slices of PrimeField to slices of PackedPrimeField. Rust +// terrible trick to change Avx2PrimeField's alignment. +// We'd like to be able to cast slices of PrimeField to slices of Avx2PrimeField. Rust // aligns __m256i to 32 bytes but PrimeField has a lower alignment. That alignment extends to -// PackedPrimeField and it appears that it cannot be lowered with #[repr(C, blah)]. It is +// Avx2PrimeField and it appears that it cannot be lowered with #[repr(C, blah)]. It is // important for Rust not to assume 32-byte alignment, so we cannot wrap __m256i directly. // There are two versions of vectorized load/store instructions on x86: aligned (vmovaps and // friends) and unaligned (vmovups etc.). The difference between them is that aligned loads and @@ -23,12 +23,12 @@ use crate::field::packed_field::PackedField; // were faster, and although this is no longer the case, compilers prefer the aligned versions if // they know that the address is aligned. Using aligned instructions on unaligned addresses leads to // bugs that can be frustrating to diagnose. Hence, we can't have Rust assuming alignment, and -// therefore PackedPrimeField wraps [F; 4] and not __m256i. +// therefore Avx2PrimeField wraps [F; 4] and not __m256i. 
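A minimal sketch of the alignment point above, assuming a u64-backed field element (as GoldilocksField is); the Wrapper type and alignment_demo function are illustrative only and not part of the patch:

#[cfg(target_arch = "x86_64")]
fn alignment_demo() {
    use core::arch::x86_64::__m256i;
    use std::mem::align_of;

    // __m256i forces 32-byte alignment, but a [u64; 4] only needs 8 bytes.
    assert_eq!(align_of::<__m256i>(), 32);
    assert_eq!(align_of::<[u64; 4]>(), 8);

    // A #[repr(transparent)] wrapper keeps the array's 8-byte alignment, so a
    // &[u64] chunk of length 4 can be reborrowed as &Wrapper; the compiler cannot
    // assume 32-byte alignment and must emit unaligned (vmovups-style) loads.
    #[repr(transparent)]
    struct Wrapper([u64; 4]);
    assert_eq!(align_of::<Wrapper>(), 8);

    let xs: [u64; 4] = [1, 2, 3, 4];
    let w: &Wrapper = unsafe { &*xs.as_ptr().cast::<Wrapper>() };
    assert_eq!(w.0, xs);
}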
#[derive(Copy, Clone)] #[repr(transparent)] -pub struct PackedPrimeField(pub [F; 4]); +pub struct Avx2PrimeField(pub [F; 4]); -impl PackedPrimeField { +impl Avx2PrimeField { #[inline] fn new(x: __m256i) -> Self { let mut obj = Self([F::ZERO; 4]); @@ -45,75 +45,109 @@ impl PackedPrimeField { } } -impl Add for PackedPrimeField { +impl Add for Avx2PrimeField { type Output = Self; #[inline] fn add(self, rhs: Self) -> Self { Self::new(unsafe { add::(self.get(), rhs.get()) }) } } -impl Add for PackedPrimeField { +impl Add for Avx2PrimeField { type Output = Self; #[inline] fn add(self, rhs: F) -> Self { - self + Self::broadcast(rhs) + self + Self::from(rhs) } } -impl AddAssign for PackedPrimeField { +impl Add> for as PackedField>::Scalar { + type Output = Avx2PrimeField; + #[inline] + fn add(self, rhs: Self::Output) -> Self::Output { + Self::Output::from(self) + rhs + } +} +impl AddAssign for Avx2PrimeField { #[inline] fn add_assign(&mut self, rhs: Self) { *self = *self + rhs; } } -impl AddAssign for PackedPrimeField { +impl AddAssign for Avx2PrimeField { #[inline] fn add_assign(&mut self, rhs: F) { *self = *self + rhs; } } -impl Debug for PackedPrimeField { +impl Debug for Avx2PrimeField { #[inline] fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { write!(f, "({:?})", self.get()) } } -impl Default for PackedPrimeField { +impl Default for Avx2PrimeField { #[inline] fn default() -> Self { - Self::zero() + Self::ZERO } } -impl Mul for PackedPrimeField { +impl Div for Avx2PrimeField { + type Output = Self; + #[inline] + fn div(self, rhs: F) -> Self { + self * rhs.inverse() + } +} +impl DivAssign for Avx2PrimeField { + #[inline] + fn div_assign(&mut self, rhs: F) { + *self *= rhs.inverse(); + } +} + +impl From for Avx2PrimeField { + fn from(x: F) -> Self { + Self([x; 4]) + } +} + +impl Mul for Avx2PrimeField { type Output = Self; #[inline] fn mul(self, rhs: Self) -> Self { Self::new(unsafe { mul::(self.get(), rhs.get()) }) } } -impl Mul for PackedPrimeField { +impl Mul for Avx2PrimeField { type Output = Self; #[inline] fn mul(self, rhs: F) -> Self { - self * Self::broadcast(rhs) + self * Self::from(rhs) } } -impl MulAssign for PackedPrimeField { +impl Mul> for as PackedField>::Scalar { + type Output = Avx2PrimeField; + #[inline] + fn mul(self, rhs: Avx2PrimeField) -> Self::Output { + Self::Output::from(self) * rhs + } +} +impl MulAssign for Avx2PrimeField { #[inline] fn mul_assign(&mut self, rhs: Self) { *self = *self * rhs; } } -impl MulAssign for PackedPrimeField { +impl MulAssign for Avx2PrimeField { #[inline] fn mul_assign(&mut self, rhs: F) { *self = *self * rhs; } } -impl Neg for PackedPrimeField { +impl Neg for Avx2PrimeField { type Output = Self; #[inline] fn neg(self) -> Self { @@ -121,52 +155,59 @@ impl Neg for PackedPrimeField { } } -impl Product for PackedPrimeField { +impl Product for Avx2PrimeField { #[inline] fn product>(iter: I) -> Self { - iter.reduce(|x, y| x * y).unwrap_or(Self::one()) + iter.reduce(|x, y| x * y).unwrap_or(Self::ONE) } } -impl PackedField for PackedPrimeField { - const LOG2_WIDTH: usize = 2; +unsafe impl PackedField for Avx2PrimeField { + const WIDTH: usize = 4; - type FieldType = F; + type Scalar = F; + type PackedPrimeField = Avx2PrimeField; + + const ZERO: Self = Self([F::ZERO; 4]); + const ONE: Self = Self([F::ONE; 4]); #[inline] - fn broadcast(x: F) -> Self { - Self([x; 4]) - } - - #[inline] - fn from_arr(arr: [F; Self::WIDTH]) -> Self { + fn from_arr(arr: [Self::Scalar; Self::WIDTH]) -> Self { Self(arr) } #[inline] - fn to_arr(&self) -> [F; Self::WIDTH] { 
+ fn as_arr(&self) -> [Self::Scalar; Self::WIDTH] { self.0 } #[inline] - fn from_slice(slice: &[F]) -> Self { - assert!(slice.len() == 4); - Self([slice[0], slice[1], slice[2], slice[3]]) + fn from_slice(slice: &[Self::Scalar]) -> &Self { + assert_eq!(slice.len(), Self::WIDTH); + unsafe { &*slice.as_ptr().cast() } + } + #[inline] + fn from_slice_mut(slice: &mut [Self::Scalar]) -> &mut Self { + assert_eq!(slice.len(), Self::WIDTH); + unsafe { &mut *slice.as_mut_ptr().cast() } + } + #[inline] + fn as_slice(&self) -> &[Self::Scalar] { + &self.0[..] + } + #[inline] + fn as_slice_mut(&mut self) -> &mut [Self::Scalar] { + &mut self.0[..] } #[inline] - fn to_vec(&self) -> Vec { - self.0.into() - } - - #[inline] - fn interleave(&self, other: Self, r: usize) -> (Self, Self) { + fn interleave(&self, other: Self, block_len: usize) -> (Self, Self) { let (v0, v1) = (self.get(), other.get()); - let (res0, res1) = match r { - 0 => unsafe { interleave0(v0, v1) }, + let (res0, res1) = match block_len { 1 => unsafe { interleave1(v0, v1) }, - 2 => (v0, v1), - _ => panic!("r cannot be more than LOG2_WIDTH"), + 2 => unsafe { interleave2(v0, v1) }, + 4 => (v0, v1), + _ => panic!("unsupported block_len"), }; (Self::new(res0), Self::new(res1)) } @@ -177,37 +218,44 @@ impl PackedField for PackedPrimeField { } } -impl Sub for PackedPrimeField { +impl Sub for Avx2PrimeField { type Output = Self; #[inline] fn sub(self, rhs: Self) -> Self { Self::new(unsafe { sub::(self.get(), rhs.get()) }) } } -impl Sub for PackedPrimeField { +impl Sub for Avx2PrimeField { type Output = Self; #[inline] fn sub(self, rhs: F) -> Self { - self - Self::broadcast(rhs) + self - Self::from(rhs) } } -impl SubAssign for PackedPrimeField { +impl Sub> for as PackedField>::Scalar { + type Output = Avx2PrimeField; + #[inline] + fn sub(self, rhs: Avx2PrimeField) -> Self::Output { + Self::Output::from(self) - rhs + } +} +impl SubAssign for Avx2PrimeField { #[inline] fn sub_assign(&mut self, rhs: Self) { *self = *self - rhs; } } -impl SubAssign for PackedPrimeField { +impl SubAssign for Avx2PrimeField { #[inline] fn sub_assign(&mut self, rhs: F) { *self = *self - rhs; } } -impl Sum for PackedPrimeField { +impl Sum for Avx2PrimeField { #[inline] fn sum>(iter: I) -> Self { - iter.reduce(|x, y| x + y).unwrap_or(Self::zero()) + iter.reduce(|x, y| x + y).unwrap_or(Self::ZERO) } } @@ -367,25 +415,25 @@ unsafe fn square64(x: __m256i) -> (__m256i, __m256i) { /// Multiply two integers modulo FIELD_ORDER. #[inline] -unsafe fn mul(x: __m256i, y: __m256i) -> __m256i { +unsafe fn mul(x: __m256i, y: __m256i) -> __m256i { F::reduce128(mul64_64(x, y)) } /// Square an integer modulo FIELD_ORDER. #[inline] -unsafe fn square(x: __m256i) -> __m256i { +unsafe fn square(x: __m256i) -> __m256i { F::reduce128(square64(x)) } #[inline] -unsafe fn interleave0(x: __m256i, y: __m256i) -> (__m256i, __m256i) { +unsafe fn interleave1(x: __m256i, y: __m256i) -> (__m256i, __m256i) { let a = _mm256_unpacklo_epi64(x, y); let b = _mm256_unpackhi_epi64(x, y); (a, b) } #[inline] -unsafe fn interleave1(x: __m256i, y: __m256i) -> (__m256i, __m256i) { +unsafe fn interleave2(x: __m256i, y: __m256i) -> (__m256i, __m256i) { let y_lo = _mm256_castsi256_si128(y); // This has 0 cost. // 1 places y_lo in the high half of x; 0 would place it in the lower half. 
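The interleave1 and interleave2 intrinsics above are easiest to sanity-check against a scalar reference model. The sketch below is illustrative only (interleave_ref is a hypothetical helper, not part of the patch); it implements the block-interleave semantics documented in packed_field.rs for WIDTH = 4 and, like the vectorized version, is its own inverse:

fn interleave_ref(a: [u64; 4], b: [u64; 4], block_len: usize) -> ([u64; 4], [u64; 4]) {
    const WIDTH: usize = 4;
    assert!(block_len.is_power_of_two() && block_len <= WIDTH);
    if block_len == WIDTH {
        return (a, b); // interleaving at full width is a no-op
    }
    let (mut out_a, mut out_b) = ([0u64; WIDTH], [0u64; WIDTH]);
    for k in (0..WIDTH / block_len).step_by(2) {
        for j in 0..block_len {
            // Blocks k and k+1 of out_a take block k of a and block k of b;
            // blocks k and k+1 of out_b take block k+1 of a and block k+1 of b.
            out_a[k * block_len + j] = a[k * block_len + j];
            out_a[(k + 1) * block_len + j] = b[k * block_len + j];
            out_b[k * block_len + j] = a[(k + 1) * block_len + j];
            out_b[(k + 1) * block_len + j] = b[(k + 1) * block_len + j];
        }
    }
    (out_a, out_b)
}

fn main() {
    let a = [0, 1, 2, 3];
    let b = [10, 11, 12, 13];
    // These match the expectations in test_interleave below.
    assert_eq!(interleave_ref(a, b, 1), ([0, 10, 2, 12], [1, 11, 3, 13]));
    assert_eq!(interleave_ref(a, b, 2), ([0, 1, 10, 11], [2, 3, 12, 13]));
    assert_eq!(interleave_ref(a, b, 4), (a, b));
    // Involution: interleaving twice with the same block_len recovers the inputs.
    let (x, y) = interleave_ref(a, b, 1);
    assert_eq!(interleave_ref(x, y, 1), (a, b));
}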
diff --git a/src/field/packed_avx2/common.rs b/src/field/packed_avx2/common.rs index c100e6dc..48f9524d 100644 --- a/src/field/packed_avx2/common.rs +++ b/src/field/packed_avx2/common.rs @@ -2,7 +2,7 @@ use core::arch::x86_64::*; use crate::field::field_types::PrimeField; -pub trait ReducibleAVX2: PrimeField { +pub trait ReducibleAvx2: PrimeField { unsafe fn reduce128(x: (__m256i, __m256i)) -> __m256i; } diff --git a/src/field/packed_avx2/goldilocks.rs b/src/field/packed_avx2/goldilocks.rs index 186c8e0c..954516b8 100644 --- a/src/field/packed_avx2/goldilocks.rs +++ b/src/field/packed_avx2/goldilocks.rs @@ -2,12 +2,12 @@ use core::arch::x86_64::*; use crate::field::goldilocks_field::GoldilocksField; use crate::field::packed_avx2::common::{ - add_no_canonicalize_64_64s_s, epsilon, shift, sub_no_canonicalize_64s_64_s, ReducibleAVX2, + add_no_canonicalize_64_64s_s, epsilon, shift, sub_no_canonicalize_64s_64_s, ReducibleAvx2, }; /// Reduce a u128 modulo FIELD_ORDER. The input is (u64, u64), pre-shifted by 2^63. The result is /// similarly shifted. -impl ReducibleAVX2 for GoldilocksField { +impl ReducibleAvx2 for GoldilocksField { #[inline] unsafe fn reduce128(x: (__m256i, __m256i)) -> __m256i { let (hi0, lo0) = x; diff --git a/src/field/packed_avx2/mod.rs b/src/field/packed_avx2/mod.rs index 20eecba7..5f6294a4 100644 --- a/src/field/packed_avx2/mod.rs +++ b/src/field/packed_avx2/mod.rs @@ -1,21 +1,21 @@ +mod avx2_prime_field; mod common; mod goldilocks; -mod packed_prime_field; -use packed_prime_field::PackedPrimeField; +use avx2_prime_field::Avx2PrimeField; use crate::field::goldilocks_field::GoldilocksField; -pub type PackedGoldilocksAVX2 = PackedPrimeField; +pub type PackedGoldilocksAvx2 = Avx2PrimeField; #[cfg(test)] mod tests { use crate::field::goldilocks_field::GoldilocksField; - use crate::field::packed_avx2::common::ReducibleAVX2; - use crate::field::packed_avx2::packed_prime_field::PackedPrimeField; + use crate::field::packed_avx2::avx2_prime_field::Avx2PrimeField; + use crate::field::packed_avx2::common::ReducibleAvx2; use crate::field::packed_field::PackedField; - fn test_vals_a() -> [F; 4] { + fn test_vals_a() -> [F; 4] { [ F::from_noncanonical_u64(14479013849828404771), F::from_noncanonical_u64(9087029921428221768), @@ -23,7 +23,7 @@ mod tests { F::from_noncanonical_u64(5646033492608483824), ] } - fn test_vals_b() -> [F; 4] { + fn test_vals_b() -> [F; 4] { [ F::from_noncanonical_u64(17891926589593242302), F::from_noncanonical_u64(11009798273260028228), @@ -32,17 +32,17 @@ mod tests { ] } - fn test_add() + fn test_add() where - [(); PackedPrimeField::::WIDTH]:, + [(); Avx2PrimeField::::WIDTH]:, { let a_arr = test_vals_a::(); let b_arr = test_vals_b::(); - let packed_a = PackedPrimeField::::from_arr(a_arr); - let packed_b = PackedPrimeField::::from_arr(b_arr); + let packed_a = Avx2PrimeField::::from_arr(a_arr); + let packed_b = Avx2PrimeField::::from_arr(b_arr); let packed_res = packed_a + packed_b; - let arr_res = packed_res.to_arr(); + let arr_res = packed_res.as_arr(); let expected = a_arr.iter().zip(b_arr).map(|(&a, b)| a + b); for (exp, res) in expected.zip(arr_res) { @@ -50,17 +50,17 @@ mod tests { } } - fn test_mul() + fn test_mul() where - [(); PackedPrimeField::::WIDTH]:, + [(); Avx2PrimeField::::WIDTH]:, { let a_arr = test_vals_a::(); let b_arr = test_vals_b::(); - let packed_a = PackedPrimeField::::from_arr(a_arr); - let packed_b = PackedPrimeField::::from_arr(b_arr); + let packed_a = Avx2PrimeField::::from_arr(a_arr); + let packed_b = Avx2PrimeField::::from_arr(b_arr); 
let packed_res = packed_a * packed_b; - let arr_res = packed_res.to_arr(); + let arr_res = packed_res.as_arr(); let expected = a_arr.iter().zip(b_arr).map(|(&a, b)| a * b); for (exp, res) in expected.zip(arr_res) { @@ -68,15 +68,15 @@ mod tests { } } - fn test_square() + fn test_square() where - [(); PackedPrimeField::::WIDTH]:, + [(); Avx2PrimeField::::WIDTH]:, { let a_arr = test_vals_a::(); - let packed_a = PackedPrimeField::::from_arr(a_arr); + let packed_a = Avx2PrimeField::::from_arr(a_arr); let packed_res = packed_a.square(); - let arr_res = packed_res.to_arr(); + let arr_res = packed_res.as_arr(); let expected = a_arr.iter().map(|&a| a.square()); for (exp, res) in expected.zip(arr_res) { @@ -84,15 +84,15 @@ mod tests { } } - fn test_neg() + fn test_neg() where - [(); PackedPrimeField::::WIDTH]:, + [(); Avx2PrimeField::::WIDTH]:, { let a_arr = test_vals_a::(); - let packed_a = PackedPrimeField::::from_arr(a_arr); + let packed_a = Avx2PrimeField::::from_arr(a_arr); let packed_res = -packed_a; - let arr_res = packed_res.to_arr(); + let arr_res = packed_res.as_arr(); let expected = a_arr.iter().map(|&a| -a); for (exp, res) in expected.zip(arr_res) { @@ -100,17 +100,17 @@ mod tests { } } - fn test_sub() + fn test_sub() where - [(); PackedPrimeField::::WIDTH]:, + [(); Avx2PrimeField::::WIDTH]:, { let a_arr = test_vals_a::(); let b_arr = test_vals_b::(); - let packed_a = PackedPrimeField::::from_arr(a_arr); - let packed_b = PackedPrimeField::::from_arr(b_arr); + let packed_a = Avx2PrimeField::::from_arr(a_arr); + let packed_b = Avx2PrimeField::::from_arr(b_arr); let packed_res = packed_a - packed_b; - let arr_res = packed_res.to_arr(); + let arr_res = packed_res.as_arr(); let expected = a_arr.iter().zip(b_arr).map(|(&a, b)| a - b); for (exp, res) in expected.zip(arr_res) { @@ -118,33 +118,39 @@ mod tests { } } - fn test_interleave_is_involution() + fn test_interleave_is_involution() where - [(); PackedPrimeField::::WIDTH]:, + [(); Avx2PrimeField::::WIDTH]:, { let a_arr = test_vals_a::(); let b_arr = test_vals_b::(); - let packed_a = PackedPrimeField::::from_arr(a_arr); - let packed_b = PackedPrimeField::::from_arr(b_arr); + let packed_a = Avx2PrimeField::::from_arr(a_arr); + let packed_b = Avx2PrimeField::::from_arr(b_arr); { // Interleave, then deinterleave. 
- let (x, y) = packed_a.interleave(packed_b, 0); - let (res_a, res_b) = x.interleave(y, 0); - assert_eq!(res_a.to_arr(), a_arr); - assert_eq!(res_b.to_arr(), b_arr); - } - { let (x, y) = packed_a.interleave(packed_b, 1); let (res_a, res_b) = x.interleave(y, 1); - assert_eq!(res_a.to_arr(), a_arr); - assert_eq!(res_b.to_arr(), b_arr); + assert_eq!(res_a.as_arr(), a_arr); + assert_eq!(res_b.as_arr(), b_arr); + } + { + let (x, y) = packed_a.interleave(packed_b, 2); + let (res_a, res_b) = x.interleave(y, 2); + assert_eq!(res_a.as_arr(), a_arr); + assert_eq!(res_b.as_arr(), b_arr); + } + { + let (x, y) = packed_a.interleave(packed_b, 4); + let (res_a, res_b) = x.interleave(y, 4); + assert_eq!(res_a.as_arr(), a_arr); + assert_eq!(res_b.as_arr(), b_arr); } } - fn test_interleave() + fn test_interleave() where - [(); PackedPrimeField::::WIDTH]:, + [(); Avx2PrimeField::::WIDTH]:, { let in_a: [F; 4] = [ F::from_noncanonical_u64(00), @@ -158,42 +164,47 @@ mod tests { F::from_noncanonical_u64(12), F::from_noncanonical_u64(13), ]; - let int0_a: [F; 4] = [ + let int1_a: [F; 4] = [ F::from_noncanonical_u64(00), F::from_noncanonical_u64(10), F::from_noncanonical_u64(02), F::from_noncanonical_u64(12), ]; - let int0_b: [F; 4] = [ + let int1_b: [F; 4] = [ F::from_noncanonical_u64(01), F::from_noncanonical_u64(11), F::from_noncanonical_u64(03), F::from_noncanonical_u64(13), ]; - let int1_a: [F; 4] = [ + let int2_a: [F; 4] = [ F::from_noncanonical_u64(00), F::from_noncanonical_u64(01), F::from_noncanonical_u64(10), F::from_noncanonical_u64(11), ]; - let int1_b: [F; 4] = [ + let int2_b: [F; 4] = [ F::from_noncanonical_u64(02), F::from_noncanonical_u64(03), F::from_noncanonical_u64(12), F::from_noncanonical_u64(13), ]; - let packed_a = PackedPrimeField::::from_arr(in_a); - let packed_b = PackedPrimeField::::from_arr(in_b); - { - let (x0, y0) = packed_a.interleave(packed_b, 0); - assert_eq!(x0.to_arr(), int0_a); - assert_eq!(y0.to_arr(), int0_b); - } + let packed_a = Avx2PrimeField::::from_arr(in_a); + let packed_b = Avx2PrimeField::::from_arr(in_b); { let (x1, y1) = packed_a.interleave(packed_b, 1); - assert_eq!(x1.to_arr(), int1_a); - assert_eq!(y1.to_arr(), int1_b); + assert_eq!(x1.as_arr(), int1_a); + assert_eq!(y1.as_arr(), int1_b); + } + { + let (x2, y2) = packed_a.interleave(packed_b, 2); + assert_eq!(x2.as_arr(), int2_a); + assert_eq!(y2.as_arr(), int2_b); + } + { + let (x4, y4) = packed_a.interleave(packed_b, 4); + assert_eq!(x4.as_arr(), in_a); + assert_eq!(y4.as_arr(), in_b); } } diff --git a/src/field/packed_field.rs b/src/field/packed_field.rs index 69733bca..f2b0c83e 100644 --- a/src/field/packed_field.rs +++ b/src/field/packed_field.rs @@ -1,76 +1,82 @@ use std::fmt::Debug; use std::iter::{Product, Sum}; -use std::ops::{Add, AddAssign, Mul, MulAssign, Neg, Sub, SubAssign}; +use std::ops::{Add, AddAssign, Div, Mul, MulAssign, Neg, Sub, SubAssign}; +use std::slice; use crate::field::field_types::Field; -pub trait PackedField: +/// # Safety +/// - WIDTH is assumed to be a power of 2. +/// - If P implements PackedField then P must be castable to/from [P::Scalar; P::WIDTH] without UB. +pub unsafe trait PackedField: 'static + Add - + Add + + Add + AddAssign - + AddAssign + + AddAssign + Copy + Debug + Default - // TODO: Implementing Div sounds like a pain so it's a worry for later. 
+ + From + // TODO: Implement packed / packed division + + Div + Mul - + Mul + + Mul + MulAssign - + MulAssign + + MulAssign + Neg + Product + Send + Sub - + Sub + + Sub + SubAssign - + SubAssign + + SubAssign + Sum + Sync +where + Self::Scalar: Add, + Self::Scalar: Mul, + Self::Scalar: Sub, { - type FieldType: Field; + type Scalar: Field; + type PackedPrimeField: PackedField::PrimeField>; - const LOG2_WIDTH: usize; - const WIDTH: usize = 1 << Self::LOG2_WIDTH; + const WIDTH: usize; + const ZERO: Self; + const ONE: Self; fn square(&self) -> Self { *self * *self } - fn zero() -> Self { - Self::broadcast(Self::FieldType::ZERO) - } - fn one() -> Self { - Self::broadcast(Self::FieldType::ONE) - } + fn from_arr(arr: [Self::Scalar; Self::WIDTH]) -> Self; + fn as_arr(&self) -> [Self::Scalar; Self::WIDTH]; - fn broadcast(x: Self::FieldType) -> Self; + fn from_slice(slice: &[Self::Scalar]) -> &Self; + fn from_slice_mut(slice: &mut [Self::Scalar]) -> &mut Self; + fn as_slice(&self) -> &[Self::Scalar]; + fn as_slice_mut(&mut self) -> &mut [Self::Scalar]; - fn from_arr(arr: [Self::FieldType; Self::WIDTH]) -> Self; - fn to_arr(&self) -> [Self::FieldType; Self::WIDTH]; - - fn from_slice(slice: &[Self::FieldType]) -> Self; - fn to_vec(&self) -> Vec; - - /// Take interpret two vectors as chunks of (1 << r) elements. Unpack and interleave those + /// Take interpret two vectors as chunks of block_len elements. Unpack and interleave those /// chunks. This is best seen with an example. If we have: /// A = [x0, y0, x1, y1], /// B = [x2, y2, x3, y3], /// then - /// interleave(A, B, 0) = ([x0, x2, x1, x3], [y0, y2, y1, y3]). + /// interleave(A, B, 1) = ([x0, x2, x1, x3], [y0, y2, y1, y3]). /// Pairs that were adjacent in the input are at corresponding positions in the output. - /// r lets us set the size of chunks we're interleaving. If we set r = 1, then for + /// r lets us set the size of chunks we're interleaving. If we set block_len = 2, then for /// A = [x0, x1, y0, y1], /// B = [x2, x3, y2, y3], /// we obtain - /// interleave(A, B, r) = ([x0, x1, x2, x3], [y0, y1, y2, y3]). + /// interleave(A, B, block_len) = ([x0, x1, x2, x3], [y0, y1, y2, y3]). /// We can also think about this as stacking the vectors, dividing them into 2x2 matrices, and /// transposing those matrices. - /// When r = LOG2_WIDTH, this operation is a no-op. Values of r > LOG2_WIDTH are not - /// permitted. - fn interleave(&self, other: Self, r: usize) -> (Self, Self); + /// When block_len = WIDTH, this operation is a no-op. block_len must divide WIDTH. Since + /// WIDTH is specified to be a power of 2, block_len must also be a power of 2. It cannot be 0 + /// and it cannot be > WIDTH. 
+ fn interleave(&self, other: Self, block_len: usize) -> (Self, Self); - fn pack_slice(buf: &[Self::FieldType]) -> &[Self] { + fn pack_slice(buf: &[Self::Scalar]) -> &[Self] { assert!( buf.len() % Self::WIDTH == 0, "Slice length (got {}) must be a multiple of packed field width ({}).", @@ -81,7 +87,7 @@ pub trait PackedField: let n = buf.len() / Self::WIDTH; unsafe { std::slice::from_raw_parts(buf_ptr, n) } } - fn pack_slice_mut(buf: &mut [Self::FieldType]) -> &mut [Self] { + fn pack_slice_mut(buf: &mut [Self::Scalar]) -> &mut [Self] { assert!( buf.len() % Self::WIDTH == 0, "Slice length (got {}) must be a multiple of packed field width ({}).", @@ -94,35 +100,42 @@ pub trait PackedField: } } -impl PackedField for F { - type FieldType = Self; +unsafe impl PackedField for F { + type Scalar = Self; + type PackedPrimeField = F::PrimeField; - const LOG2_WIDTH: usize = 0; + const WIDTH: usize = 1; + const ZERO: Self = ::ZERO; + const ONE: Self = ::ONE; - fn broadcast(x: Self::FieldType) -> Self { - x + fn square(&self) -> Self { + ::square(self) } - fn from_arr(arr: [Self::FieldType; Self::WIDTH]) -> Self { + fn from_arr(arr: [Self::Scalar; Self::WIDTH]) -> Self { arr[0] } - fn to_arr(&self) -> [Self::FieldType; Self::WIDTH] { + fn as_arr(&self) -> [Self::Scalar; Self::WIDTH] { [*self] } - fn from_slice(slice: &[Self::FieldType]) -> Self { - assert_eq!(slice.len(), 1); - slice[0] + fn from_slice(slice: &[Self::Scalar]) -> &Self { + &slice[0] } - fn to_vec(&self) -> Vec { - vec![*self] + fn from_slice_mut(slice: &mut [Self::Scalar]) -> &mut Self { + &mut slice[0] + } + fn as_slice(&self) -> &[Self::Scalar] { + slice::from_ref(self) + } + fn as_slice_mut(&mut self) -> &mut [Self::Scalar] { + slice::from_mut(self) } - fn interleave(&self, other: Self, r: usize) -> (Self, Self) { - if r == 0 { - (*self, other) - } else { - panic!("r > LOG2_WIDTH"); + fn interleave(&self, other: Self, block_len: usize) -> (Self, Self) { + match block_len { + 1 => (*self, other), + _ => panic!("unsupported block length"), } } }
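As a usage sketch (not part of the patch), code generic over the packing can stay width-agnostic by going through this trait. The helper below is hypothetical: it assumes the MulAssign-by-scalar bound listed above and handles any leftover tail in scalar code; with the blanket scalar impl (WIDTH = 1) the packed loop degenerates to one element per iteration.

use crate::field::packed_field::PackedField;

/// Hypothetical helper: multiply every element of `values` by `c`, WIDTH lanes at a time.
fn scale_in_place<P: PackedField>(values: &mut [P::Scalar], c: P::Scalar) {
    let packed_len = values.len() - values.len() % P::WIDTH;
    let (head, tail) = values.split_at_mut(packed_len);
    for chunk in P::pack_slice_mut(head) {
        *chunk *= c; // relies on the MulAssign-by-scalar bound on PackedField
    }
    for x in tail {
        *x *= c; // scalar remainder when the length is not a multiple of WIDTH
    }
}

On an AVX2-enabled build this could be instantiated as scale_in_place::<<GoldilocksField as Packable>::Packing>(..), resolving Packing to PackedGoldilocksAvx2; on other targets the `default type Packing = Self` fallback makes the same call site run entirely in scalar code.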