Implement DIT FFT for Encoder

Christopher Taylor 2017-05-31 01:11:20 -07:00
parent 9090619d73
commit ac68c62d28
5 changed files with 424 additions and 266 deletions

View File

@ -34,7 +34,6 @@
+ Look into 12-bit fields as a performance optimization
+ Unroll first/final butterflies to avoid extra copies/xors in encoder
+ Skip a lot of the initial FWHT() layers that are just operating on zeroes
+ Skip a lot of the final FWHT() layers that are not needed for calculation
+ For the actual FFT(), I should be unrolling the bottom two layers
and performing them in a specialized function that does 2 <=> 2 and
then 1<=>1, 1<=>1 operations in local registers/cache
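The last item above is essentially what this commit's FFT_DIT4/IFFT_DIT4 routines do per 4-element group, except on whole buffers rather than registers. A minimal sketch of the data movement, restricted to the skew-free case where every butterfly degenerates to a plain XOR (the name and the scalar form here are illustrative only):

    #include <cstdint>

    // Bottom two FFT layers on four adjacent values held in locals, assuming every
    // skew equals kModulus so each butterfly is just y ^= x. The real routines also
    // apply the field multiply and operate on whole buffers.
    static inline void fft_bottom_two_layers_sketch(
        uint16_t& a, uint16_t& b, uint16_t& c, uint16_t& d)
    {
        // 2 <=> 2 layer (distance 2):
        c ^= a;
        d ^= b;
        // 1 <=> 1, 1 <=> 1 layer (distance 1):
        b ^= a;
        d ^= c;
    }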
@ -160,7 +159,7 @@
// Define this to enable the optimized version of FWHT()
#define LEO_FWHT_OPT
// Avoid scheduling reduced FFT operations that are unneeded
// Avoid scheduling FFT operations that are unused
#define LEO_SCHEDULE_OPT
// Avoid calculating final FFT values in decoder using bitfield

View File

@ -364,7 +364,7 @@ static void InitializeLogarithmTables()
/*
The multiplication algorithm used follows the approach outlined in {4}.
Specifically section 7 outlines the algorithm used here for 16-bit fields.
I use the ALTMAP memory layout since I do not need to convert in/out of it.
The ALTMAP memory layout is used since there is no need to convert in/out.
*/
struct {
@ -903,6 +903,7 @@ static ffe_t FFTSkew[kModulus];
// Factors used in the evaluation of the error locator polynomial
static ffe_t LogWalsh[kOrder];
static void FFTInitialize()
{
ffe_t temp[kBits - 1];
@ -1328,12 +1329,12 @@ void ReedSolomonDecode(
// Evaluate error locator polynomial
FWHT(ErrorLocations, kBits);
FWHT(ErrorLocations);
for (unsigned i = 0; i < kOrder; ++i)
ErrorLocations[i] = ((unsigned)ErrorLocations[i] * (unsigned)LogWalsh[i]) % kModulus;
FWHT(ErrorLocations, kBits);
FWHT(ErrorLocations);
// work <- recovery data

View File

@ -121,38 +121,6 @@ static LEO_FORCE_INLINE void FWHT_4(ffe_t* data, unsigned s)
data[y] = t3;
}
static void FWHT_8(ffe_t* data)
{
ffe_t t0 = data[0];
ffe_t t1 = data[1];
ffe_t t2 = data[2];
ffe_t t3 = data[3];
ffe_t t4 = data[4];
ffe_t t5 = data[5];
ffe_t t6 = data[6];
ffe_t t7 = data[7];
FWHT_2(t0, t1);
FWHT_2(t2, t3);
FWHT_2(t4, t5);
FWHT_2(t6, t7);
FWHT_2(t0, t2);
FWHT_2(t1, t3);
FWHT_2(t4, t6);
FWHT_2(t5, t7);
FWHT_2(t0, t4);
FWHT_2(t1, t5);
FWHT_2(t2, t6);
FWHT_2(t3, t7);
data[0] = t0;
data[1] = t1;
data[2] = t2;
data[3] = t3;
data[4] = t4;
data[5] = t5;
data[6] = t6;
data[7] = t7;
}
// Decimation in time (DIT) version
static void FWHT(ffe_t* data, const unsigned bits)
{
@ -174,16 +142,8 @@ static void FWHT(ffe_t* data, const unsigned bits)
FWHT_4(data + j + r, m4);
}
if (bits & 1)
{
for (unsigned i0 = 0; i0 < n; i0 += 8)
FWHT_8(data + i0);
}
else
{
for (unsigned i0 = 0; i0 < n; i0 += 4)
FWHT_4(data + i0);
}
for (unsigned i0 = 0; i0 < n; i0 += 4)
FWHT_4(data + i0);
}
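With FWHT_8 removed, the bottom of every pass now finishes with FWHT_4, which is built from the radix-2 sum/difference step. A minimal sketch of that step using a plain modulo; the names and the straightforward reduction are illustrative, and the library's AddMod/SubMod are assumed to compute the same values more cheaply:

    #include <cstdint>

    static const unsigned kModulusSketch = 65535; // 2^16 - 1, standing in for kModulus

    // Radix-2 Walsh-Hadamard step: (a, b) <- (a + b, a - b) modulo kModulus.
    static inline void fwht_2_sketch(uint16_t& a, uint16_t& b)
    {
        const unsigned sum = ((unsigned)a + b) % kModulusSketch;
        const unsigned dif = ((unsigned)a + kModulusSketch - b) % kModulusSketch;
        a = (uint16_t)sum;
        b = (uint16_t)dif;
    }

    // A 4-point pass applies the step at distance 1 and then at distance 2 -- the same
    // shape as the FWHT_8 routine removed above, one layer shorter.
    static inline void fwht_4_sketch(uint16_t* d)
    {
        fwht_2_sketch(d[0], d[1]);
        fwht_2_sketch(d[2], d[3]);
        fwht_2_sketch(d[0], d[2]);
        fwht_2_sketch(d[1], d[3]);
    }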
#else // LEO_FWHT_OPT
@ -790,230 +750,393 @@ static void FFTInitialize()
FWHT(LogWalsh, kBits);
}
void VectorFFTButterfly(
const uint64_t bytes,
unsigned count,
void** x,
void** y,
const ffe_t log_m)
{
if (log_m == kModulus)
{
VectorXOR(bytes, count, y, x);
return;
}
#ifdef LEO_USE_VECTOR4_OPT
while (count >= 4)
{
fft_butterfly4(
x[0], y[0],
x[1], y[1],
x[2], y[2],
x[3], y[3],
log_m, bytes);
x += 4, y += 4;
count -= 4;
}
#endif // LEO_USE_VECTOR4_OPT
for (unsigned i = 0; i < count; ++i)
fft_butterfly(x[i], y[i], log_m, bytes);
}
void VectorIFFTButterfly(
const uint64_t bytes,
unsigned count,
void** x,
void** y,
const ffe_t log_m)
{
if (log_m == kModulus)
{
VectorXOR(bytes, count, y, x);
return;
}
#ifdef LEO_USE_VECTOR4_OPT
while (count >= 4)
{
ifft_butterfly4(
x[0], y[0],
x[1], y[1],
x[2], y[2],
x[3], y[3],
log_m, bytes);
x += 4, y += 4;
count -= 4;
}
#endif // LEO_USE_VECTOR4_OPT
for (unsigned i = 0; i < count; ++i)
ifft_butterfly(x[i], y[i], log_m, bytes);
}
//------------------------------------------------------------------------------
// Reed-Solomon Encode
/*
Decimation in time IFFT:
The decimation in time IFFT algorithm allows us to unroll 2 layers at a time,
performing calculations on local registers and faster cache memory.
Each ^___^ below indicates a butterfly between the associated indices.
The ifft_butterfly(x, y) operation:
y[] ^= x[]
if (log_m != kModulus)
x[] ^= exp(log(y[]) + log_m)
Layer 0:
0 1 2 3 4 5 6 7 0 1 2 3 4 5 6 7
^_^ ^_^ ^_^ ^_^ ^_^ ^_^ ^_^ ^_^
Layer 1:
0 1 2 3 4 5 6 7 0 1 2 3 4 5 6 7
^___^ ^___^ ^___^ ^___^
^___^ ^___^ ^___^ ^___^
Layer 2:
0 1 2 3 4 5 6 7 0 1 2 3 4 5 6 7
^_______^ ^_______^
^_______^ ^_______^
^_______^ ^_______^
^_______^ ^_______^
Layer 3:
0 1 2 3 4 5 6 7 0 1 2 3 4 5 6 7
^_______________^
^_______________^
^_______________^
^_______________^
^_______________^
^_______________^
^_______________^
^_______________^
DIT layer 0-1 operations, grouped 4 at a time:
{0-1, 2-3, 0-2, 1-3},
{4-5, 6-7, 4-6, 5-7},
DIT layer 1-2 operations, grouped 4 at a time:
{0-2, 4-6, 0-4, 2-6},
{1-3, 5-7, 1-5, 3-7},
DIT layer 2-3 operations, grouped 4 at a time:
{0-4, 0'-4', 0-0', 4-4'},
{1-5, 1'-5', 1-1', 5-5'},
*/
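Per element, each butterfly is a multiply-by-constant plus two XOR folds, and the IFFT form is the exact inverse of the FFT form because the same product is added back. The self-contained sketch below checks that round trip; gf_mul_sketch, its reduction polynomial, and the constant names are illustrative stand-ins (the library multiplies through exp/log tables keyed by log_m and its field representation may differ), so only the structure carries over:

    #include <cassert>
    #include <cstdint>

    // Illustrative GF(2^16) multiply: carryless multiply reduced by x^16 + x^12 + x^3 + x + 1.
    // This is a common textbook polynomial, not necessarily the one the library uses.
    static uint16_t gf_mul_sketch(uint16_t a, uint16_t b)
    {
        uint32_t acc = 0;
        for (unsigned i = 0; i < 16; ++i)
            if (b & (1u << i))
                acc ^= (uint32_t)a << i;
        for (int i = 31; i >= 16; --i)
            if (acc & (1u << i))
                acc ^= 0x1100Bu << (i - 16);
        return (uint16_t)acc;
    }

    // IFFT butterfly: y <- y + x, then x <- x + m*y.
    static void ifft_butterfly_sketch(uint16_t& x, uint16_t& y, uint16_t m)
    {
        y ^= x;
        x ^= gf_mul_sketch(y, m);
    }

    // FFT butterfly: x <- x + m*y, then y <- y + x.
    static void fft_butterfly_sketch(uint16_t& x, uint16_t& y, uint16_t m)
    {
        x ^= gf_mul_sketch(y, m);
        y ^= x;
    }

    int main()
    {
        uint16_t x = 0x1234, y = 0xABCD;
        const uint16_t m = 0x5EED; // arbitrary constant standing in for exp(log_m)
        ifft_butterfly_sketch(x, y, m); // encode direction
        fft_butterfly_sketch(x, y, m);  // decode direction undoes it
        assert(x == 0x1234 && y == 0xABCD);
        return 0;
    }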
// 4-way butterfly
static void IFFT_DIT4(
const uint64_t bytes,
void** work,
unsigned dist,
const ffe_t log_m01,
const ffe_t log_m23,
const ffe_t log_m02)
{
// FIXME: Interleave these
// First layer:
if (log_m01 == kModulus)
xor_mem(work[dist], work[0], bytes);
else
ifft_butterfly(work[0], work[dist], log_m01, bytes);
if (log_m23 == kModulus)
xor_mem(work[dist * 3], work[dist * 2], bytes);
else
ifft_butterfly(work[dist * 2], work[dist * 3], log_m23, bytes);
// Second layer:
if (log_m02 == kModulus)
{
xor_mem(work[dist * 2], work[0], bytes);
xor_mem(work[dist * 3], work[dist], bytes);
}
else
{
ifft_butterfly(work[0], work[dist * 2], log_m02, bytes);
ifft_butterfly(work[dist], work[dist * 3], log_m02, bytes);
}
}
void IFFT_DIT(
const uint64_t bytes,
void* const* data,
const unsigned m_truncated,
void** work,
void** xor_result,
const unsigned m,
const ffe_t* skewLUT)
{
// FIXME: Roll into first layer
if (data)
{
for (unsigned i = 0; i < m_truncated; ++i)
memcpy(work[i], data[i], bytes);
for (unsigned i = m_truncated; i < m; ++i)
memset(work[i], 0, bytes);
}
// Decimation in time: Unroll 2 layers at a time
unsigned dist = 1, dist4 = 4;
for (; dist4 <= m; dist = dist4, dist4 <<= 2)
{
// FIXME: Walk this in reverse order every other pair of layers for better cache locality
// FIXME: m_truncated
// For each set of dist*4 elements:
for (unsigned r = 0; r < m_truncated; r += dist4)
{
const ffe_t log_m01 = skewLUT[r + dist];
const ffe_t log_m23 = skewLUT[r + dist * 3];
const ffe_t log_m02 = skewLUT[r + dist * 2];
// For each set of dist elements:
for (unsigned i = r; i < r + dist; ++i)
{
IFFT_DIT4(
bytes,
work + i,
dist,
log_m01,
log_m23,
log_m02);
}
}
data = nullptr;
}
// If there is one layer left:
if (dist < m)
{
const ffe_t log_m = skewLUT[dist];
if (log_m == kModulus)
{
VectorXOR(bytes, dist, work + dist, work);
}
else
{
for (unsigned i = 0; i < dist; ++i)
{
ifft_butterfly(
work[i],
work[i + dist],
log_m,
bytes);
}
}
}
// FIXME: Roll into last layer
if (xor_result)
for (unsigned i = 0; i < m; ++i)
xor_mem(xor_result[i], work[i], bytes);
}
/*
Decimation in time FFT:
The decimation in time FFT algorithm allows us to unroll 2 layers at a time,
performing calculations on local registers and faster cache memory.
Each ^___^ below indicates a butterfly between the associated indices.
The fft_butterfly(x, y) operation:
if (log_m != kModulus)
x[] ^= exp(log(y[]) + log_m)
y[] ^= x[]
Layer 0:
0 1 2 3 4 5 6 7 0 1 2 3 4 5 6 7
^_______________^
^_______________^
^_______________^
^_______________^
^_______________^
^_______________^
^_______________^
^_______________^
Layer 1:
0 1 2 3 4 5 6 7 0 1 2 3 4 5 6 7
^_______^ ^_______^
^_______^ ^_______^
^_______^ ^_______^
^_______^ ^_______^
Layer 2:
0 1 2 3 4 5 6 7 0 1 2 3 4 5 6 7
^___^ ^___^ ^___^ ^___^
^___^ ^___^ ^___^ ^___^
Layer 3:
0 1 2 3 4 5 6 7 0 1 2 3 4 5 6 7
^_^ ^_^ ^_^ ^_^ ^_^ ^_^ ^_^ ^_^
DIT layer 0-1 operations, grouped 4 at a time:
{0-0', 4-4', 0-4, 0'-4'},
{1-1', 5-5', 1-5, 1'-5'},
DIT layer 1-2 operations, grouped 4 at a time:
{0-4, 2-6, 0-2, 4-6},
{1-5, 3-7, 1-3, 5-7},
DIT layer 2-3 operations, grouped 4 at a time:
{0-2, 1-3, 0-1, 2-3},
{4-6, 5-7, 4-5, 6-7},
*/
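The grouped index pairs listed above can be regenerated mechanically from the loop structure of FFT_DIT below. The small standalone check here prints them for m = 16 (matching the 16-element diagrams) and reproduces the "layer 0-1" and "layer 2-3" groupings, which are the two double-layers the schedule actually uses at that size:

    #include <cstdio>

    int main()
    {
        const unsigned m = 16; // matches the 16-element diagrams above
        unsigned dist4 = m, dist = m >> 2;
        for (; dist != 0; dist4 = dist, dist >>= 2)
        {
            printf("double-layer, dist = %u:\n", dist);
            for (unsigned r = 0; r < m; r += dist4)
                for (unsigned i = r; i < r + dist; ++i)
                    printf("  {%u-%u, %u-%u, %u-%u, %u-%u}\n",
                        i, i + dist * 2,               // first layer pair  (skew log_m02)
                        i + dist, i + dist * 3,        // first layer pair  (skew log_m02)
                        i, i + dist,                   // second layer pair (skew log_m01)
                        i + dist * 2, i + dist * 3);   // second layer pair (skew log_m23)
        }
        return 0;
    }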
static void FFT_DIT4(
const uint64_t bytes,
void** work,
const unsigned dist,
const ffe_t log_m01,
const ffe_t log_m23,
const ffe_t log_m02)
{
// FIXME: Interleave
// First layer:
if (log_m02 == kModulus)
{
xor_mem(work[dist * 2], work[0], bytes);
xor_mem(work[dist * 3], work[dist], bytes);
}
else
{
fft_butterfly(work[0], work[dist * 2], log_m02, bytes);
fft_butterfly(work[dist], work[dist * 3], log_m02, bytes);
}
// Second layer:
if (log_m01 == kModulus)
xor_mem(work[dist], work[0], bytes);
else
fft_butterfly(work[0], work[dist], log_m01, bytes);
if (log_m23 == kModulus)
xor_mem(work[dist * 3], work[dist * 2], bytes);
else
fft_butterfly(work[dist * 2], work[dist * 3], log_m23, bytes);
}
void FFT_DIT(
const uint64_t bytes,
void** work,
const unsigned m_truncated,
const unsigned m,
const ffe_t* skewLUT)
{
// Decimation in time: Unroll 2 layers at a time
unsigned dist4 = m, dist = m >> 2;
for (; dist != 0; dist4 = dist, dist >>= 2)
{
// FIXME: Walk this in reverse order every other pair of layers for better cache locality
// FIXME: m_truncated
// For each set of dist*4 elements:
for (unsigned r = 0; r < m_truncated; r += dist4)
{
const ffe_t log_m01 = skewLUT[r + dist];
const ffe_t log_m23 = skewLUT[r + dist * 3];
const ffe_t log_m02 = skewLUT[r + dist * 2];
// For each set of dist elements:
for (unsigned i = r; i < r + dist; ++i)
{
FFT_DIT4(
bytes,
work + i,
dist,
log_m01,
log_m23,
log_m02);
}
}
}
// If there is one layer left:
if (dist4 == 2)
{
for (unsigned r = 0; r < m_truncated; r += 2)
{
const ffe_t log_m = skewLUT[r + 1];
if (log_m == kModulus)
xor_mem(work[r + 1], work[r], bytes);
else
{
fft_butterfly(
work[r],
work[r + 1],
log_m,
bytes);
}
}
}
}
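For orientation, the dist values each outer loop visits, and when the leftover single layer fires, can be checked standalone; this is illustrative only, with any power-of-two m allowed, and an odd log2(m) is what triggers the leftover layer in both transforms:

    #include <cstdio>

    int main()
    {
        const unsigned m = 32; // log2(32) = 5 is odd, so both leftover layers fire

        printf("IFFT_DIT double-layers: ");
        unsigned idist = 1, idist4 = 4;
        for (; idist4 <= m; idist = idist4, idist4 <<= 2)
            printf("dist=%u ", idist);
        if (idist < m)
            printf("+ leftover layer at distance %u", idist);
        printf("\n");

        printf("FFT_DIT double-layers:  ");
        unsigned fdist4 = m, fdist = m >> 2;
        for (; fdist != 0; fdist4 = fdist, fdist >>= 2)
            printf("dist=%u ", fdist);
        if (fdist4 == 2)
            printf("+ leftover layer at distance 1");
        printf("\n");
        return 0;
    }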
void ReedSolomonEncode(
uint64_t buffer_bytes,
unsigned original_count,
unsigned recovery_count,
unsigned m,
void* const * data,
void* const* data,
void** work)
{
// work <- data
// TBD: Unroll first loop to eliminate this
unsigned first_end = m;
if (original_count < m)
{
first_end = original_count;
for (unsigned i = original_count; i < m; ++i)
memset(work[i], 0, buffer_bytes);
}
for (unsigned i = 0; i < first_end; ++i)
memcpy(work[i], data[i], buffer_bytes);
// work <- IFFT(data, m, m)
for (unsigned width = 1; width < m; width <<= 1)
{
const unsigned range = width << 1;
const ffe_t* skewLUT = FFTSkew + width + m - 1;
const ffe_t* skewLUT = FFTSkew + m - 1;
#ifdef LEO_SCHEDULE_OPT
for (unsigned j = 0; j < first_end; j += range)
#else
for (unsigned j = 0; j < m; j += range)
#endif
{
VectorIFFTButterfly(
buffer_bytes,
width,
work + j,
work + j + width,
skewLUT[j]);
}
}
IFFT_DIT(
buffer_bytes,
data,
original_count < m ? original_count : m,
work,
nullptr, // No xor output
m,
skewLUT);
if (m >= original_count)
goto skip_body;
// For sets of m data pieces:
for (unsigned i = m; i + m <= original_count; i += m)
{
// temp <- data + i
data += m;
void** temp = work + m;
skewLUT += m;
// TBD: Unroll first loop to eliminate this
for (unsigned j = 0; j < m; ++j)
memcpy(temp[j], data[j], buffer_bytes);
// work <- work xor IFFT(data + i, m, m + i)
// temp <- IFFT(temp, m, m + i)
const ffe_t* skewLUT = FFTSkew + m + i - 1;
for (unsigned width = 1; width < m; width <<= 1)
{
const unsigned range = width << 1;
for (unsigned j = width; j < m; j += range)
{
VectorIFFTButterfly(
buffer_bytes,
width,
temp + j - width,
temp + j,
skewLUT[j]);
}
}
// work <- work XOR temp
// TBD: Unroll last loop to eliminate this
VectorXOR(
IFFT_DIT(
buffer_bytes,
data, // data source
m,
work,
temp);
work + m, // temporary workspace
work, // xor destination
m,
skewLUT);
}
// Handle final partial set of m pieces:
const unsigned last_count = original_count % m;
if (last_count != 0)
{
const unsigned i = original_count - last_count;
// temp <- data + i
data += m;
void** temp = work + m;
skewLUT += m;
for (unsigned j = 0; j < last_count; ++j)
memcpy(temp[j], data[j], buffer_bytes);
for (unsigned j = last_count; j < m; ++j)
memset(temp[j], 0, buffer_bytes);
// work <- work xor IFFT(data + i, m, m + i)
// temp <- IFFT(temp, m, m + i)
for (unsigned width = 1, shift = 1; width < m; width <<= 1, ++shift)
{
const unsigned range = width << 1;
const ffe_t* skewLUT = FFTSkew + width + m + i - 1;
#ifdef LEO_SCHEDULE_OPT
// Calculate stop considering that the right is all zeroes
const unsigned stop = ((last_count + range - 1) >> shift) << shift;
for (unsigned j = 0; j < stop; j += range)
#else
for (unsigned j = 0; j < m; j += range)
#endif
{
VectorIFFTButterfly(
buffer_bytes,
width,
temp + j,
temp + j + width,
skewLUT[j]);
}
}
// work <- work XOR temp
// TBD: Unroll last loop to eliminate this
VectorXOR(
IFFT_DIT(
buffer_bytes,
data, // data source
last_count,
work + m, // temporary workspace
work, // xor destination
m,
work,
temp);
skewLUT);
}
skip_body:
// work <- FFT(work, m, 0)
for (unsigned width = (m >> 1); width > 0; width >>= 1)
{
const ffe_t* skewLUT = FFTSkew + width - 1;
const unsigned range = width << 1;
#ifdef LEO_SCHEDULE_OPT
for (unsigned j = 0; j < recovery_count; j += range)
#else
for (unsigned j = 0; j < m; j += range)
#endif
{
VectorFFTButterfly(
buffer_bytes,
width,
work + j,
work + j + width,
skewLUT[j]);
}
}
FFT_DIT(
buffer_bytes,
work,
recovery_count,
m,
FFTSkew - 1);
}
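A brief recap of the schedule, derived from the code above:

    // Recap of the encoder flow implemented above (restating the code, nothing new):
    //   1. work[0..m)  <- IFFT_DIT of the first min(original_count, m) pieces, zero-padded to m.
    //   2. For every further chunk of m pieces (the final partial chunk is padded with zero
    //      buffers inside IFFT_DIT), transform it into the scratch half work[m..2m) and fold it
    //      into work[0..m) through the xor_result output, advancing skewLUT by m per chunk.
    //   3. work <- FFT_DIT(work), truncated to recovery_count outputs: the recovery pieces.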
@ -1088,6 +1211,68 @@ void ErrorBitfield::Prepare()
//------------------------------------------------------------------------------
// Reed-Solomon Decode
void VectorFFTButterfly(
const uint64_t bytes,
unsigned count,
void** x,
void** y,
const ffe_t log_m)
{
if (log_m == kModulus)
{
VectorXOR(bytes, count, y, x);
return;
}
#ifdef LEO_USE_VECTOR4_OPT
while (count >= 4)
{
fft_butterfly4(
x[0], y[0],
x[1], y[1],
x[2], y[2],
x[3], y[3],
log_m, bytes);
x += 4, y += 4;
count -= 4;
}
#endif // LEO_USE_VECTOR4_OPT
for (unsigned i = 0; i < count; ++i)
fft_butterfly(x[i], y[i], log_m, bytes);
}
void VectorIFFTButterfly(
const uint64_t bytes,
unsigned count,
void** x,
void** y,
const ffe_t log_m)
{
if (log_m == kModulus)
{
VectorXOR(bytes, count, y, x);
return;
}
#ifdef LEO_USE_VECTOR4_OPT
while (count >= 4)
{
ifft_butterfly4(
x[0], y[0],
x[1], y[1],
x[2], y[2],
x[3], y[3],
log_m, bytes);
x += 4, y += 4;
count -= 4;
}
#endif // LEO_USE_VECTOR4_OPT
for (unsigned i = 0; i < count; ++i)
ifft_butterfly(x[i], y[i], log_m, bytes);
}
void ReedSolomonDecode(
uint64_t buffer_bytes,
unsigned original_count,
@ -1127,12 +1312,12 @@ void ReedSolomonDecode(
// Evaluate error locator polynomial
FWHT(ErrorLocations, kBits);
FWHT(ErrorLocations);
for (unsigned i = 0; i < kOrder; ++i)
ErrorLocations[i] = ((unsigned)ErrorLocations[i] * (unsigned)LogWalsh[i]) % kModulus;
FWHT(ErrorLocations, kBits);
FWHT(ErrorLocations);
// work <- recovery data

View File

@ -134,34 +134,6 @@ void ifft_butterfly4(
#endif // LEO_USE_VECTOR4_OPT
//------------------------------------------------------------------------------
// FFT
/*
if (log_m != kModulus)
x[] ^= exp(log(y[]) + log_m)
y[] ^= x[]
*/
void VectorFFTButterfly(
const uint64_t bytes,
unsigned count,
void** x,
void** y,
const ffe_t log_m);
/*
y[] ^= x[]
if (log_m != kModulus)
x[] ^= exp(log(y[]) + log_m)
*/
void VectorIFFTButterfly(
const uint64_t bytes,
unsigned count,
void** x,
void** y,
const ffe_t log_m);
//------------------------------------------------------------------------------
// Reed-Solomon Encode

View File

@ -42,14 +42,14 @@ using namespace std;
struct TestParameters
{
#ifdef LEO_HAS_FF16
unsigned original_count = 1000; // under 65536
unsigned recovery_count = 200; // under 65536 - original_count
unsigned original_count = 100; // under 65536
unsigned recovery_count = 20; // under 65536 - original_count
#else
unsigned original_count = 128; // under 65536
unsigned recovery_count = 128; // under 65536 - original_count
#endif
unsigned buffer_bytes = 2560; // multiple of 64 bytes
unsigned loss_count = 500; // some fraction of original_count
unsigned buffer_bytes = 64; // multiple of 64 bytes
unsigned loss_count = 32768; // some fraction of original_count
unsigned seed = 2;
bool multithreaded = true;
};
@ -399,7 +399,7 @@ static LEO_FORCE_INLINE void SIMDSafeFree(void* ptr)
static bool BasicTest(const TestParameters& params)
{
const unsigned kTrials = params.original_count > 8000 ? 1 : 10;
const unsigned kTrials = params.original_count > 8000 ? 1 : 100000;
std::vector<uint8_t*> original_data(params.original_count);
@ -807,8 +807,8 @@ int main(int argc, char **argv)
if (!BasicTest(params))
goto Failed;
static const unsigned kMaxRandomData = 32768;
#if 0
static const unsigned kMaxRandomData = 128;
prng.Seed(params.seed, 8);
for (;; ++params.seed)
@ -822,8 +822,9 @@ int main(int argc, char **argv)
if (!BasicTest(params))
goto Failed;
}
#endif
#ifdef LEO_BENCH_ALL_256_PARAMS
#if 1
for (unsigned original_count = 1; original_count <= 256; ++original_count)
{
for (unsigned recovery_count = 1; recovery_count <= original_count; ++recovery_count)