From 7d9e81362d8bc849698668f4d42d182d786a6d25 Mon Sep 17 00:00:00 2001 From: Jacqueline Nabaglo Date: Thu, 15 Sep 2022 14:59:16 -0700 Subject: [PATCH] Python prototype of cache-oblivious FFT (#722) --- projects/cache-friendly-fft/__init__.py | 229 +++++++++++++++++++++++ projects/cache-friendly-fft/transpose.py | 61 ++++++ projects/cache-friendly-fft/util.py | 6 + 3 files changed, 296 insertions(+) create mode 100644 projects/cache-friendly-fft/__init__.py create mode 100644 projects/cache-friendly-fft/transpose.py create mode 100644 projects/cache-friendly-fft/util.py diff --git a/projects/cache-friendly-fft/__init__.py b/projects/cache-friendly-fft/__init__.py new file mode 100644 index 00000000..08f1acac --- /dev/null +++ b/projects/cache-friendly-fft/__init__.py @@ -0,0 +1,229 @@ +import numpy as np + +from transpose import transpose_square +from util import lb_exact + + +def _interleave(x, scratch): + """Interleave the elements in an array in-place. + + For example, if `x` is `array([1, 2, 3, 4, 5, 6, 7, 8])`, then its + contents will be rearranged to `array([1, 5, 2, 6, 3, 7, 4, 8])`. + + `scratch` is an externally-allocated buffer, whose `dtype` matches + `x` and whose length is at least half the length of `x`. + """ + assert len(x.shape) == len(scratch.shape) == 1 + + n, = x.shape + assert n % 2 == 0 + + half_n = n // 2 + assert scratch.shape[0] >= half_n + + assert x.dtype == scratch.dtype + scratch = scratch[:half_n] + + scratch[:] = x[:half_n] # Save the first half of `x`. + for i in range(half_n): + x[2 * i] = scratch[i] + x[2 * i + 1] = x[half_n + i] + + +def _deinterleave(x, scratch): + """Deinterleave the elements in an array in-place. + + For example, if `x` is `array([1, 2, 3, 4, 5, 6, 7, 8])`, then its + contents will be rearranged to `array([1, 3, 5, 7, 2, 4, 6, 8])`. + + `scratch` is an externally-allocated buffer, whose `dtype` matches + `x` and whose length is at least half the length of `x`. + """ + assert len(x.shape) == len(scratch.shape) == 1 + + n, = x.shape + assert n % 2 == 0 + + half_n = n // 2 + assert scratch.shape[0] >= half_n + + assert x.dtype == scratch.dtype + scratch = scratch[:half_n] + + for i in range(half_n): + x[i] = x[2 * i] + scratch[i] = x[2 * i + 1] + x[half_n:] = scratch + + +def _fft_inplace_evenpow(x, scratch): + """In-place FFT of length 2^even""" + # Reshape `x` to a square matrix in row-major order. + vec_len = x.shape[0] + n = 1 << (lb_exact(vec_len) >> 1) # Matrix dimension + x.shape = n, n, 1 + + # We want to recursively apply FFT to every column. Because `x` is + # in row-major order, we transpose it to make the columns contiguous + # in memory, then recurse, and finally transpose it back. While the + # row is in cache, we also multiply by the twiddle factors. + transpose_square(x) + for i, row in enumerate(x[..., 0]): + _fft_inplace(row, scratch) + # Multiply by the twiddle factors + for j in range(n): + row[j] *= np.exp(-2j * np.pi * (i * j) / vec_len) + transpose_square(x) + + # Now recursively apply FFT to the rows. + for row in x[..., 0]: + _fft_inplace(row, scratch) + + # Transpose again before returning. + transpose_square(x) + + +def _fft_inplace_oddpow(x, scratch): + """In-place FFT of length 2^odd""" + # This code is based on `_fft_inplace_evenpow`, but it has to + # account for some additional complications. + + vec_len = x.shape[0] + # `vec_len` is an odd power of 2, so we cannot reshape `x` to a + # matrix square. Instead, we'll (conceptually) reshape it to a + # matrix that's twice as wide as it is high. E.g., `[1 ... 8]` + # becomes `[1 2 3 4]` + # `[5 6 7 8]`. + col_len = 1 << (lb_exact(vec_len) >> 1) + row_len = col_len << 1 + + # We can only perform efficient, in-place transposes on square + # matrices, so we will actually treat this as a square matrix of + # 2-tuples, e.g. `[(1 2) (3 4)]` + # `[(5 6) (7 8)]`. + # Note that we can currently `.reshape` it to our intended wide + # matrix (although this is broken by transposition). + x.shape = col_len, col_len, 2 + + # We want to apply FFT to each column. We transpose our + # matrix-of-tuples and get something like `[(1 2) (5 6)]` + # `[(3 4) (7 8)]`. + # Note that each row of the transposed matrix represents two columns + # of the original matrix. We can deinterleave the values to recover + # the original columns. + transpose_square(x) + + for i, row_pair in enumerate(x): + # `row_pair` represents two columns of the original matrix. + # Their values must be deinterleaved to recover the columns. + row_pair.shape = row_len, + _deinterleave(row_pair, scratch) + # The below are rows of the transposed matrix(/cols of the + # original matrix. + row0 = row_pair[:col_len] + row1 = row_pair[col_len:] + + # Apply FFT and twiddle factors to each. + _fft_inplace(row0, scratch) + for j in range(col_len): + row0[j] *= np.exp(-2j * np.pi * ((2 * i) * j) / vec_len) + _fft_inplace(row1, scratch) + for j in range(col_len): + row1[j] *= np.exp(-2j * np.pi * ((2 * i + 1) * j) / vec_len) + + # Re-interleave them and transpose back. + _interleave(row_pair, scratch) + + transpose_square(x) + + # Recursively apply FFT to each row of the matrix. + for row in x: + # Turn vec of 2-tuples into vec of single elements. + row.shape = row_len, + _fft_inplace(row, scratch) + + # Transpose again before returning. This again involves + # deinterleaving. + transpose_square(x) + for row_pair in x: + row_pair.shape = row_len, + _deinterleave(row_pair, scratch) + + +def _fft_inplace(x, scratch): + """In-place FFT.""" + # Avoid modifying the shape of the original. + # This does not copy the buffer. + x = x.view() + assert x.flags['C_CONTIGUOUS'] + + n, = x.shape + if n == 1: + return + if n == 2: + x0, x1 = x + x[0] = x0 + x1 + x[1] = x0 - x1 + return + + lb_n = lb_exact(n) + is_odd = lb_n & 1 != 0 + if is_odd: + _fft_inplace_oddpow(x, scratch) + else: + _fft_inplace_evenpow(x, scratch) + + +def _scrach_length(lb_n): + """Find the amount of scratch space required to run the FFT. + + Layers where the input's length is an even power of two do not + require scratch space, but the layers where that power is odd do. + """ + if lb_n == 0: + # Length-1 input. + return 0 + # Repeatedly halve lb_n as long as it's even. This is the same as + # `n = sqrt(n)`, where the `sqrt` is exact. + while lb_n & 1 == 0: + lb_n >>= 1 + # `lb_n` is now odd, so `n` is not an even power of 2. + lb_res = (lb_n - 1) >> 1 + if lb_res == 0: + # Special case (n == 2 or n == 4): no scratch needed. + return 0 + return 1 << lb_res + + +def fft(x): + """Returns the FFT of `x`. + + This is a wrapper around an in-place routine, provided for user + convenience. + """ + n, = x.shape + lb_n = lb_exact(n) # Raises if not a power of 2. + # We have one scratch buffer for the whole algorithm. If we were to + # parallelize it, we'd need one thread-local buffer for each worker + # thread. + scratch_len = _scrach_length(lb_n) + if scratch_len == 0: + scratch = None + else: + scratch = np.empty_like(x, shape=scratch_len, order='C', subok=False) + + res = x.copy(order='C') + _fft_inplace(res, scratch) + + return res + + +if __name__ == "__main__": + LENGTH = 1 << 10 + v = np.random.normal(size=LENGTH).astype(complex) + print(v) + numpy_fft = np.fft.fft(v) + print(numpy_fft) + our_fft = fft(v) + print(our_fft) + print(np.isclose(numpy_fft, our_fft).all()) diff --git a/projects/cache-friendly-fft/transpose.py b/projects/cache-friendly-fft/transpose.py new file mode 100644 index 00000000..ea20bf6b --- /dev/null +++ b/projects/cache-friendly-fft/transpose.py @@ -0,0 +1,61 @@ +from util import lb_exact + + +def _swap_transpose_square(a, b): + """Transpose two square matrices in-place and swap them. + + The matrices must be a of shape `(n, n, m)`, where the `m` dimension + may be of arbitrary length and is not moved. + """ + assert len(a.shape) == len(b.shape) == 3 + n = a.shape[0] + m = a.shape[2] + assert n == a.shape[1] == b.shape[0] == b.shape[1] + assert m == b.shape[2] + + if n == 0: + return + if n == 1: + # Swap the two matrices (transposition is a no-op). + a = a[0, 0] + b = b[0, 0] + # Recall that each element of the matrix is an `m`-vector. Swap + # all `m` elements. + for i in range(m): + a[i], b[i] = b[i], a[i] + return + + half_n = n >> 1 + # Transpose and swap top-left of `a` with top-left of `b`. + _swap_transpose_square(a[:half_n, :half_n], b[:half_n, :half_n]) + # ...top-right of `a` with bottom-left of `b`. + _swap_transpose_square(a[:half_n, half_n:], b[half_n:, :half_n]) + # ...bottom-left of `a` with top-right of `b`. + _swap_transpose_square(a[half_n:, :half_n], b[:half_n, half_n:]) + # ...bottom-right of `a` with bottom-right of `b`. + _swap_transpose_square(a[half_n:, half_n:], b[half_n:, half_n:]) + + +def transpose_square(a): + """In-place transpose of a square matrix. + + The matrix must be a of shape `(n, n, m)`, where the `m` dimension + may be of arbitrary length and is not moved. + """ + if len(a.shape) != 3: + raise ValueError("a must be a matrix of batches") + n, n_, _ = a.shape + if n != n_: + raise ValueError("a must be square") + lb_exact(n) + + if n <= 1: + return # Base case: no-op + + half_n = n >> 1 + # Transpose top-left quarter in-place. + transpose_square(a[:half_n, :half_n]) + # Transpose top-right and bottom-left quarters and swap them. + _swap_transpose_square(a[:half_n, half_n:], a[half_n:, :half_n]) + # Transpose bottom-right quarter in-place. + transpose_square(a[half_n:, half_n:]) diff --git a/projects/cache-friendly-fft/util.py b/projects/cache-friendly-fft/util.py new file mode 100644 index 00000000..50118827 --- /dev/null +++ b/projects/cache-friendly-fft/util.py @@ -0,0 +1,6 @@ +def lb_exact(n): + """Returns `log2(n)`, raising if `n` is not a power of 2.""" + lb = n.bit_length() - 1 + if lb < 0 or n != 1 << lb: + raise ValueError(f"{n} is not a power of 2") + return lb