diff --git a/src/fft_fr.c b/src/fft_fr.c index 439049b..942d6ac 100644 --- a/src/fft_fr.c +++ b/src/fft_fr.c @@ -17,7 +17,7 @@ #include "fft_fr.h" // Slow Fourier Transform (simple, good for small sizes) -void slow_ft(blst_fr *out, blst_fr *in, uint64_t offset, uint64_t stride, blst_fr *roots, uint64_t roots_stride, uint64_t l) { +void fft_fr_slow(blst_fr *out, blst_fr *in, uint64_t offset, uint64_t stride, blst_fr *roots, uint64_t roots_stride, uint64_t l) { blst_fr v, last, tmp; for (uint64_t i = 0; i < l; i++) { blst_fr jv = in[offset]; @@ -36,10 +36,10 @@ void slow_ft(blst_fr *out, blst_fr *in, uint64_t offset, uint64_t stride, blst_f } // Fast Fourier Transform -void fast_ft(blst_fr *out, blst_fr *in, uint64_t offset, uint64_t stride, blst_fr *roots, uint64_t roots_stride, uint64_t l) { +void fft_fr_fast(blst_fr *out, blst_fr *in, uint64_t offset, uint64_t stride, blst_fr *roots, uint64_t roots_stride, uint64_t l) { uint64_t half = l / 2; - fft_helper(out, in, offset, stride * 2, roots, roots_stride * 2, l / 2); - fft_helper(out + half, in, offset + stride, stride * 2, roots, roots_stride * 2, l / 2); + fft_fr_helper(out, in, offset, stride * 2, roots, roots_stride * 2, l / 2); + fft_fr_helper(out + half, in, offset + stride, stride * 2, roots, roots_stride * 2, l / 2); for (uint64_t i = 0; i < half; i++) { blst_fr y_times_root; blst_fr x = out[i]; @@ -49,17 +49,17 @@ void fast_ft(blst_fr *out, blst_fr *in, uint64_t offset, uint64_t stride, blst_f } } -void fft_helper(blst_fr *out, blst_fr *in, uint64_t offset, uint64_t stride, blst_fr *roots, uint64_t roots_stride, uint64_t l) { +void fft_fr_helper(blst_fr *out, blst_fr *in, uint64_t offset, uint64_t stride, blst_fr *roots, uint64_t roots_stride, uint64_t l) { // TODO: Tunable parameter if (l <= 4) { - slow_ft(out, in, offset, stride, roots, roots_stride, l); + fft_fr_slow(out, in, offset, stride, roots, roots_stride, l); } else { - fast_ft(out, in, offset, stride, roots, roots_stride, l); + fft_fr_fast(out, in, offset, stride, roots, roots_stride, l); } } // The main entry point for forward and reverse FFTs -void fft (blst_fr *out, blst_fr *in, FFTSettings *fs, bool inv, uint64_t n) { +void fft_fr (blst_fr *out, blst_fr *in, FFTSettings *fs, bool inv, uint64_t n) { uint64_t stride = fs->max_width / n; assert(n <= fs->max_width); assert(is_power_of_two(n)); @@ -67,11 +67,11 @@ void fft (blst_fr *out, blst_fr *in, FFTSettings *fs, bool inv, uint64_t n) { blst_fr inv_len; fr_from_uint64(&inv_len, n); blst_fr_eucl_inverse(&inv_len, &inv_len); - fft_helper(out, in, 0, 1, fs->reverse_roots_of_unity, stride, fs->max_width); + fft_fr_helper(out, in, 0, 1, fs->reverse_roots_of_unity, stride, fs->max_width); for (uint64_t i = 0; i < fs->max_width; i++) { blst_fr_mul(&out[i], &out[i], &inv_len); } } else { - fft_helper(out, in, 0, 1, fs->expanded_roots_of_unity, stride, fs->max_width); + fft_fr_helper(out, in, 0, 1, fs->expanded_roots_of_unity, stride, fs->max_width); } } diff --git a/src/fft_fr.h b/src/fft_fr.h index 1006bca..e09a490 100644 --- a/src/fft_fr.h +++ b/src/fft_fr.h @@ -17,7 +17,7 @@ #include "c-kzg.h" #include "fft_util.h" -void slow_ft(blst_fr *out, blst_fr *in, uint64_t offset, uint64_t stride, blst_fr *roots, uint64_t roots_stride, uint64_t l); -void fast_ft(blst_fr *out, blst_fr *in, uint64_t offset, uint64_t stride, blst_fr *roots, uint64_t roots_stride, uint64_t l); -void fft_helper(blst_fr *out, blst_fr *in, uint64_t offset, uint64_t stride, blst_fr *roots, uint64_t roots_stride, uint64_t l); -void fft (blst_fr *out, blst_fr *in, FFTSettings *fs, bool inv, uint64_t n); +void fft_fr_slow(blst_fr *out, blst_fr *in, uint64_t offset, uint64_t stride, blst_fr *roots, uint64_t roots_stride, uint64_t l); +void fft_fr_fast(blst_fr *out, blst_fr *in, uint64_t offset, uint64_t stride, blst_fr *roots, uint64_t roots_stride, uint64_t l); +void fft_fr_helper(blst_fr *out, blst_fr *in, uint64_t offset, uint64_t stride, blst_fr *roots, uint64_t roots_stride, uint64_t l); +void fft_fr (blst_fr *out, blst_fr *in, FFTSettings *fs, bool inv, uint64_t n); diff --git a/src/fft_fr_test.c b/src/fft_fr_test.c index 8f30782..d8f2945 100644 --- a/src/fft_fr_test.c +++ b/src/fft_fr_test.c @@ -48,8 +48,8 @@ void compare_sft_fft(void) { } // Do both fast and slow transforms - slow_ft(out0, data, 0, 1, fs.expanded_roots_of_unity, 1, fs.max_width); - fast_ft(out1, data, 0, 1, fs.expanded_roots_of_unity, 1, fs.max_width); + fft_fr_slow(out0, data, 0, 1, fs.expanded_roots_of_unity, 1, fs.max_width); + fft_fr_fast(out1, data, 0, 1, fs.expanded_roots_of_unity, 1, fs.max_width); // Verify the results are identical for (int i = 0; i < fs.max_width; i++) { @@ -69,8 +69,8 @@ void roundtrip_fft(void) { } // Forward and reverse FFT - fft(coeffs, data, &fs, false, fs.max_width); - fft(data, coeffs, &fs, true, fs.max_width); + fft_fr(coeffs, data, &fs, false, fs.max_width); + fft_fr(data, coeffs, &fs, true, fs.max_width); // Verify that the result is still ascending values of i for (int i = 0; i < fs.max_width; i++) { @@ -91,7 +91,7 @@ void inverse_fft(void) { } // Inverst FFT - fft(out, data, &fs, true, fs.max_width); + fft_fr(out, data, &fs, true, fs.max_width); // Verify against the known result, `inv_fft_expected` int n = sizeof(inv_fft_expected) / sizeof(inv_fft_expected[0]);