Working unrolled xor for all modes

Christopher Taylor 2017-06-04 04:07:07 -07:00
parent 94a4c5731b
commit 5dc9f49298
4 changed files with 226 additions and 27 deletions

View File

@@ -137,10 +137,10 @@ void InitializeCPUArch()
 #endif // LEO_TRY_AVX2
 #ifndef LEO_USE_SSSE3_OPT
-    CpuHasAVX2 = false;
+    CpuHasSSSE3 = false;
 #endif // LEO_USE_SSSE3_OPT
 #ifndef LEO_USE_AVX2_OPT
-    CpuHasSSSE3 = false;
+    CpuHasAVX2 = false;
 #endif // LEO_USE_AVX2_OPT
 #endif // LEO_TARGET_MOBILE

View File

@@ -165,7 +165,7 @@
 // Enable 8-bit or 16-bit fields
 #define LEO_HAS_FF8
-//#define LEO_HAS_FF16
+#define LEO_HAS_FF16
 // Enable using SIMD instructions
 #define LEO_USE_SSSE3_OPT

View File

@@ -704,9 +704,11 @@ static void IFFT_DIT2_xor(
     xor_mem(y_in, x_in, bytes);
+    unsigned count = bytes;
+    ffe_t * LEO_RESTRICT y1 = reinterpret_cast<ffe_t *>(y_in);
 #ifdef LEO_TARGET_MOBILE
     ffe_t * LEO_RESTRICT x1 = reinterpret_cast<ffe_t *>(x_in);
-    ffe_t * LEO_RESTRICT y1 = reinterpret_cast<ffe_t *>(y_in);
     do
     {
@@ -714,11 +716,10 @@ static void IFFT_DIT2_xor(
            x1[j] ^= lut[y1[j]];
        x1 += 64, y1 += 64;
-        bytes -= 64;
-    } while (bytes > 0);
+        count -= 64;
+    } while (count > 0);
 #else
     uint64_t * LEO_RESTRICT x8 = reinterpret_cast<uint64_t *>(x_in);
-    ffe_t * LEO_RESTRICT y1 = reinterpret_cast<ffe_t *>(y_in);
     do
     {
@@ -738,8 +739,8 @@ static void IFFT_DIT2_xor(
        }
        x8 += 8;
-        bytes -= 64;
-    } while (bytes > 0);
+        count -= 64;
+    } while (count > 0);
 #endif
     xor_mem(y_out, y_in, bytes);
@@ -777,10 +778,10 @@ static void IFFT_DIT4(
        do
        {
+            // First layer:
            LEO_M256 work0_reg = _mm256_loadu_si256(work0);
            LEO_M256 work1_reg = _mm256_loadu_si256(work1);
-            // First layer:
            work1_reg = _mm256_xor_si256(work0_reg, work1_reg);
            if (log_m01 != kModulus)
                LEO_MULADD_256(work0_reg, work1_reg, t01_lo, t01_hi);
@@ -788,7 +789,6 @@ static void IFFT_DIT4(
            LEO_M256 work2_reg = _mm256_loadu_si256(work2);
            LEO_M256 work3_reg = _mm256_loadu_si256(work3);
-            // First layer:
            work3_reg = _mm256_xor_si256(work2_reg, work3_reg);
            if (log_m23 != kModulus)
                LEO_MULADD_256(work2_reg, work3_reg, t23_lo, t23_hi);
@@ -834,10 +834,10 @@ static void IFFT_DIT4(
        do
        {
+            // First layer:
            LEO_M128 work0_reg = _mm_loadu_si128(work0);
            LEO_M128 work1_reg = _mm_loadu_si128(work1);
-            // First layer:
            work1_reg = _mm_xor_si128(work0_reg, work1_reg);
            if (log_m01 != kModulus)
                LEO_MULADD_128(work0_reg, work1_reg, t01_lo, t01_hi);
@@ -845,7 +845,6 @@ static void IFFT_DIT4(
            LEO_M128 work2_reg = _mm_loadu_si128(work2);
            LEO_M128 work3_reg = _mm_loadu_si128(work3);
-            // First layer:
            work3_reg = _mm_xor_si128(work2_reg, work3_reg);
            if (log_m23 != kModulus)
                LEO_MULADD_128(work2_reg, work3_reg, t23_lo, t23_hi);
@@ -897,6 +896,183 @@ static void IFFT_DIT4(
    }
 }
+// xor_result ^= IFFT_DIT4(work)
+static void IFFT_DIT4_xor(
+    uint64_t bytes,
+    void** work_in,
+    void** xor_out,
+    unsigned dist,
+    const ffe_t log_m01,
+    const ffe_t log_m23,
+    const ffe_t log_m02)
+{
+#ifdef LEO_INTERLEAVE_BUTTERFLY4_OPT
+#if defined(LEO_TRY_AVX2)
+    if (CpuHasAVX2)
+    {
+        const LEO_M256 t01_lo = _mm256_loadu_si256(&Multiply256LUT[log_m01].Value[0]);
+        const LEO_M256 t01_hi = _mm256_loadu_si256(&Multiply256LUT[log_m01].Value[1]);
+        const LEO_M256 t23_lo = _mm256_loadu_si256(&Multiply256LUT[log_m23].Value[0]);
+        const LEO_M256 t23_hi = _mm256_loadu_si256(&Multiply256LUT[log_m23].Value[1]);
+        const LEO_M256 t02_lo = _mm256_loadu_si256(&Multiply256LUT[log_m02].Value[0]);
+        const LEO_M256 t02_hi = _mm256_loadu_si256(&Multiply256LUT[log_m02].Value[1]);
+        const LEO_M256 clr_mask = _mm256_set1_epi8(0x0f);
+        const LEO_M256 * LEO_RESTRICT work0 = reinterpret_cast<const LEO_M256 *>(work_in[0]);
+        const LEO_M256 * LEO_RESTRICT work1 = reinterpret_cast<const LEO_M256 *>(work_in[dist]);
+        const LEO_M256 * LEO_RESTRICT work2 = reinterpret_cast<const LEO_M256 *>(work_in[dist * 2]);
+        const LEO_M256 * LEO_RESTRICT work3 = reinterpret_cast<const LEO_M256 *>(work_in[dist * 3]);
+        LEO_M256 * LEO_RESTRICT xor0 = reinterpret_cast<LEO_M256 *>(xor_out[0]);
+        LEO_M256 * LEO_RESTRICT xor1 = reinterpret_cast<LEO_M256 *>(xor_out[dist]);
+        LEO_M256 * LEO_RESTRICT xor2 = reinterpret_cast<LEO_M256 *>(xor_out[dist * 2]);
+        LEO_M256 * LEO_RESTRICT xor3 = reinterpret_cast<LEO_M256 *>(xor_out[dist * 3]);
+        do
+        {
+            // First layer:
+            LEO_M256 work0_reg = _mm256_loadu_si256(work0);
+            LEO_M256 work1_reg = _mm256_loadu_si256(work1);
+            work0++, work1++;
+            work1_reg = _mm256_xor_si256(work0_reg, work1_reg);
+            if (log_m01 != kModulus)
+                LEO_MULADD_256(work0_reg, work1_reg, t01_lo, t01_hi);
+            LEO_M256 work2_reg = _mm256_loadu_si256(work2);
+            LEO_M256 work3_reg = _mm256_loadu_si256(work3);
+            work2++, work3++;
+            work3_reg = _mm256_xor_si256(work2_reg, work3_reg);
+            if (log_m23 != kModulus)
+                LEO_MULADD_256(work2_reg, work3_reg, t23_lo, t23_hi);
+            // Second layer:
+            work2_reg = _mm256_xor_si256(work0_reg, work2_reg);
+            work3_reg = _mm256_xor_si256(work1_reg, work3_reg);
+            if (log_m02 != kModulus)
+            {
+                LEO_MULADD_256(work0_reg, work2_reg, t02_lo, t02_hi);
+                LEO_MULADD_256(work1_reg, work3_reg, t02_lo, t02_hi);
+            }
+            work0_reg = _mm256_xor_si256(work0_reg, _mm256_loadu_si256(xor0));
+            work1_reg = _mm256_xor_si256(work1_reg, _mm256_loadu_si256(xor1));
+            work2_reg = _mm256_xor_si256(work2_reg, _mm256_loadu_si256(xor2));
+            work3_reg = _mm256_xor_si256(work3_reg, _mm256_loadu_si256(xor3));
+            _mm256_storeu_si256(xor0, work0_reg);
+            _mm256_storeu_si256(xor1, work1_reg);
+            _mm256_storeu_si256(xor2, work2_reg);
+            _mm256_storeu_si256(xor3, work3_reg);
+            xor0++, xor1++, xor2++, xor3++;
+            bytes -= 32;
+        } while (bytes > 0);
+        return;
+    }
+#endif // LEO_TRY_AVX2
+    if (CpuHasSSSE3)
+    {
+        const LEO_M128 t01_lo = _mm_loadu_si128(&Multiply128LUT[log_m01].Value[0]);
+        const LEO_M128 t01_hi = _mm_loadu_si128(&Multiply128LUT[log_m01].Value[1]);
+        const LEO_M128 t23_lo = _mm_loadu_si128(&Multiply128LUT[log_m23].Value[0]);
+        const LEO_M128 t23_hi = _mm_loadu_si128(&Multiply128LUT[log_m23].Value[1]);
+        const LEO_M128 t02_lo = _mm_loadu_si128(&Multiply128LUT[log_m02].Value[0]);
+        const LEO_M128 t02_hi = _mm_loadu_si128(&Multiply128LUT[log_m02].Value[1]);
+        const LEO_M128 clr_mask = _mm_set1_epi8(0x0f);
+        const LEO_M128 * LEO_RESTRICT work0 = reinterpret_cast<const LEO_M128 *>(work_in[0]);
+        const LEO_M128 * LEO_RESTRICT work1 = reinterpret_cast<const LEO_M128 *>(work_in[dist]);
+        const LEO_M128 * LEO_RESTRICT work2 = reinterpret_cast<const LEO_M128 *>(work_in[dist * 2]);
+        const LEO_M128 * LEO_RESTRICT work3 = reinterpret_cast<const LEO_M128 *>(work_in[dist * 3]);
+        LEO_M128 * LEO_RESTRICT xor0 = reinterpret_cast<LEO_M128 *>(xor_out[0]);
+        LEO_M128 * LEO_RESTRICT xor1 = reinterpret_cast<LEO_M128 *>(xor_out[dist]);
+        LEO_M128 * LEO_RESTRICT xor2 = reinterpret_cast<LEO_M128 *>(xor_out[dist * 2]);
+        LEO_M128 * LEO_RESTRICT xor3 = reinterpret_cast<LEO_M128 *>(xor_out[dist * 3]);
+        do
+        {
+            // First layer:
+            LEO_M128 work0_reg = _mm_loadu_si128(work0);
+            LEO_M128 work1_reg = _mm_loadu_si128(work1);
+            work0++, work1++;
+            work1_reg = _mm_xor_si128(work0_reg, work1_reg);
+            if (log_m01 != kModulus)
+                LEO_MULADD_128(work0_reg, work1_reg, t01_lo, t01_hi);
+            LEO_M128 work2_reg = _mm_loadu_si128(work2);
+            LEO_M128 work3_reg = _mm_loadu_si128(work3);
+            work2++, work3++;
+            work3_reg = _mm_xor_si128(work2_reg, work3_reg);
+            if (log_m23 != kModulus)
+                LEO_MULADD_128(work2_reg, work3_reg, t23_lo, t23_hi);
+            // Second layer:
+            work2_reg = _mm_xor_si128(work0_reg, work2_reg);
+            work3_reg = _mm_xor_si128(work1_reg, work3_reg);
+            if (log_m02 != kModulus)
+            {
+                LEO_MULADD_128(work0_reg, work2_reg, t02_lo, t02_hi);
+                LEO_MULADD_128(work1_reg, work3_reg, t02_lo, t02_hi);
+            }
+            work0_reg = _mm_xor_si128(work0_reg, _mm_loadu_si128(xor0));
+            work1_reg = _mm_xor_si128(work1_reg, _mm_loadu_si128(xor1));
+            work2_reg = _mm_xor_si128(work2_reg, _mm_loadu_si128(xor2));
+            work3_reg = _mm_xor_si128(work3_reg, _mm_loadu_si128(xor3));
+            _mm_storeu_si128(xor0, work0_reg);
+            _mm_storeu_si128(xor1, work1_reg);
+            _mm_storeu_si128(xor2, work2_reg);
+            _mm_storeu_si128(xor3, work3_reg);
+            xor0++, xor1++, xor2++, xor3++;
+            bytes -= 16;
+        } while (bytes > 0);
+        return;
+    }
+#endif // LEO_INTERLEAVE_BUTTERFLY4_OPT
+    // First layer:
+    if (log_m01 == kModulus)
+        xor_mem(work_in[dist], work_in[0], bytes);
+    else
+        IFFT_DIT2(work_in[0], work_in[dist], log_m01, bytes);
+    if (log_m23 == kModulus)
+        xor_mem(work_in[dist * 3], work_in[dist * 2], bytes);
+    else
+        IFFT_DIT2(work_in[dist * 2], work_in[dist * 3], log_m23, bytes);
+    // Second layer:
+    if (log_m02 == kModulus)
+    {
+        xor_mem(work_in[dist * 2], work_in[0], bytes);
+        xor_mem(work_in[dist * 3], work_in[dist], bytes);
+    }
+    else
+    {
+        IFFT_DIT2(work_in[0], work_in[dist * 2], log_m02, bytes);
+        IFFT_DIT2(work_in[dist], work_in[dist * 3], log_m02, bytes);
+    }
+    xor_mem(xor_out[0], work_in[0], bytes);
+    xor_mem(xor_out[dist], work_in[dist], bytes);
+    xor_mem(xor_out[dist * 2], work_in[dist * 2], bytes);
+    xor_mem(xor_out[dist * 3], work_in[dist * 3], bytes);
+}
 static void IFFT_DIT(
     const uint64_t bytes,
     const void* const* data,
@@ -930,8 +1106,26 @@ static void IFFT_DIT(
            const ffe_t log_m23 = skewLUT[r + dist * 3];
            const ffe_t log_m02 = skewLUT[r + dist * 2];
-            // For each set of dist elements:
            const unsigned i_end = r + dist;
+            if (dist4 == m && xor_result)
+            {
+                // For each set of dist elements:
+                for (unsigned i = r; i < i_end; ++i)
+                {
+                    IFFT_DIT4_xor(
+                        bytes,
+                        work + i,
+                        xor_result + i,
+                        dist,
+                        log_m01,
+                        log_m23,
+                        log_m02);
+                }
+            }
+            else
+            {
+                // For each set of dist elements:
            for (unsigned i = r; i < i_end; ++i)
            {
                IFFT_DIT4(
@@ -943,6 +1137,7 @@ static void IFFT_DIT(
                    log_m02);
            }
        }
+        }
        // I tried alternating sweeps left->right and right->left to reduce cache misses.
        // It provides about 1% performance boost when done for both FFT and IFFT, so it

View File

@@ -45,8 +45,8 @@ struct TestParameters
     unsigned original_count = 100; // under 65536
     unsigned recovery_count = 20; // under 65536 - original_count
 #else
-    unsigned original_count = 128; // under 65536
-    unsigned recovery_count = 128; // under 65536 - original_count
+    unsigned original_count = 100; // under 65536
+    unsigned recovery_count = 20; // under 65536 - original_count
 #endif
     unsigned buffer_bytes = 64000; // multiple of 64 bytes
     unsigned loss_count = 32768; // some fraction of original_count
@@ -54,6 +54,9 @@ struct TestParameters
     bool multithreaded = true;
 };
+static const unsigned kLargeTrialCount = 1;
+static const unsigned kSmallTrialCount = 100;
 //------------------------------------------------------------------------------
 // Windows
@@ -370,7 +373,7 @@ static void ShuffleDeck16(PCGRandom &prng, uint16_t * LEO_RESTRICT deck, uint32_
 static bool Benchmark(const TestParameters& params)
 {
-    const unsigned kTrials = params.original_count > 8000 ? 1 : 100;
+    const unsigned kTrials = params.original_count > 4000 ? kLargeTrialCount : kSmallTrialCount;
     std::vector<uint8_t*> original_data(params.original_count);
@@ -565,7 +568,7 @@ int main(int argc, char **argv)
     if (!Benchmark(params))
         goto Failed;
-#if 1
+#if 0
     static const unsigned kMaxRandomData = 32768;
     prng.Seed(params.seed, 8);
@@ -582,7 +585,7 @@ int main(int argc, char **argv)
     }
 #endif
-#if 0
+#if 1
     for (unsigned original_count = 1; original_count <= 256; ++original_count)
     {
         for (unsigned recovery_count = 1; recovery_count <= original_count; ++recovery_count)
@@ -600,6 +603,7 @@ int main(int argc, char **argv)
 #endif
 Failed:
+    cout << "Tests completed." << endl;
     getchar();
     return 0;