Python prototype of cache-oblivious FFT (#722)

This commit is contained in:
Jacqueline Nabaglo 2022-09-15 14:59:16 -07:00 committed by GitHub
parent b933e43cf1
commit 7d9e81362d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 296 additions and 0 deletions

View File

@ -0,0 +1,229 @@
import numpy as np
from transpose import transpose_square
from util import lb_exact
def _interleave(x, scratch):
    """Riffle the two halves of `x` together, in place.

    Element `i` of the first half ends up at index `2 * i` and element
    `i` of the second half ends up at index `2 * i + 1`. For example,
    `array([1, 2, 3, 4, 5, 6, 7, 8])` is rearranged to
    `array([1, 5, 2, 6, 3, 7, 4, 8])`.

    `x` and `scratch` must both be one-dimensional and share a `dtype`.
    `x` must have even length, and `scratch` (allocated by the caller)
    must hold at least `len(x) // 2` elements; only that prefix of
    `scratch` is written.
    """
    assert (x.ndim, scratch.ndim) == (1, 1)
    (length,) = x.shape
    assert length % 2 == 0
    half = length // 2
    assert half <= scratch.shape[0]
    assert x.dtype == scratch.dtype

    # Stash the first half: its slots get overwritten before all of its
    # elements have been moved to their destinations.
    saved = scratch[:half]
    saved[:] = x[:half]

    # Slots `2k` and `2k + 1` always lie below `half + k`, so the part
    # of the second half that is still unread is never clobbered.
    for k in range(half):
        dst = 2 * k
        x[dst] = saved[k]
        x[dst + 1] = x[half + k]
def _deinterleave(x, scratch):
    """Separate the even- and odd-indexed elements of `x`, in place.

    The elements at even indices are packed into the first half of `x`
    and the elements at odd indices into the second half, each keeping
    its relative order. For example, `array([1, 2, 3, 4, 5, 6, 7, 8])`
    is rearranged to `array([1, 3, 5, 7, 2, 4, 6, 8])`.

    `x` and `scratch` must both be one-dimensional and share a `dtype`.
    `x` must have even length, and `scratch` (allocated by the caller)
    must hold at least `len(x) // 2` elements; only that prefix of
    `scratch` is written.
    """
    assert (x.ndim, scratch.ndim) == (1, 1)
    (length,) = x.shape
    assert length % 2 == 0
    half = length // 2
    assert half <= scratch.shape[0]
    assert x.dtype == scratch.dtype

    odds = scratch[:half]
    # Compact the even-indexed elements towards the front. The write
    # index `k` never overtakes the read indices `2k` and `2k + 1`, so
    # nothing is overwritten before it has been read. Odd-indexed
    # elements are parked in scratch until the front half is done.
    for k in range(half):
        src = 2 * k
        x[k] = x[src]
        odds[k] = x[src + 1]
    x[half:] = odds
def _fft_inplace_evenpow(x, scratch):
    """In-place FFT of a vector whose length is an even power of two.

    The vector is treated as a square matrix stored in row-major order.
    An FFT is applied to every column, the entries are multiplied by
    twiddle factors, and then an FFT is applied to every row.

    Note that this assigns to `x.shape`; `_fft_inplace` passes in a view
    so that the caller's array keeps its original shape. `scratch` is
    only forwarded to the recursive calls.
    """
    total = x.shape[0]
    side = 1 << (lb_exact(total) >> 1)  # Matrix is `side` x `side`.
    # The trailing axis of length 1 is there because `transpose_square`
    # expects an `(n, n, m)` array.
    x.shape = side, side, 1
    # 2-D view of the same buffer; it stays valid across the in-place
    # transposes below because they only move data around.
    grid = x[..., 0]

    # Column pass. Columns of a row-major matrix are strided, so
    # transpose first to make each one contiguous, recurse on it, and
    # apply its twiddle factors right away while it is still in cache.
    transpose_square(x)
    for r in range(side):
        column = grid[r]
        _fft_inplace(column, scratch)
        for c in range(side):
            twiddle = np.exp(-2j * np.pi * (r * c) / total)
            column[c] *= twiddle
    transpose_square(x)

    # Row pass: rows are already contiguous.
    for r in range(side):
        _fft_inplace(grid[r], scratch)

    # One last transpose before returning.
    transpose_square(x)
def _fft_inplace_oddpow(x, scratch):
    """In-place FFT of a vector whose length is an odd power of two.

    Follows the same column-FFT / twiddle / row-FFT plan as
    `_fft_inplace_evenpow`, with extra bookkeeping because the vector
    cannot be laid out as a square matrix. Conceptually it is a matrix
    twice as wide as it is high, e.g. `[1 ... 8]` is
        `[1 2 3 4]`
        `[5 6 7 8]`.
    Only square matrices can be transposed in place here, so the buffer
    is instead shaped as a square matrix of 2-tuples,
        `[(1 2) (3 4)]`
        `[(5 6) (7 8)]`,
    and `_deinterleave` / `_interleave` convert between "row of tuples"
    and "two separate columns" whenever needed.

    Note that this assigns to `x.shape`; `_fft_inplace` passes in a view
    so that the caller's array keeps its original shape. `scratch` is
    used by the (de)interleaving steps, which need at least `height`
    elements of it, and is forwarded to the recursive calls.
    """
    total = x.shape[0]
    height = 1 << (lb_exact(total) >> 1)  # Rows of the wide matrix.
    width = height << 1  # Columns of the wide matrix.
    x.shape = height, height, 2

    def fft_and_twiddle(column, index):
        # FFT one column of the wide matrix, then scale its entries by
        # the twiddle factors belonging to column number `index`.
        _fft_inplace(column, scratch)
        for k in range(height):
            column[k] *= np.exp(-2j * np.pi * (index * k) / total)

    # Column pass. After transposing the matrix of tuples, e.g. to
    #     `[(1 2) (5 6)]`
    #     `[(3 4) (7 8)]`,
    # row `r` holds columns `2r` and `2r + 1` of the wide matrix with
    # their entries interleaved.
    transpose_square(x)
    for r in range(height):
        pair = x[r]
        pair.shape = width,  # Flatten the row of tuples.
        _deinterleave(pair, scratch)
        fft_and_twiddle(pair[:height], 2 * r)
        fft_and_twiddle(pair[height:], 2 * r + 1)
        # Restore the tuple layout so the matrix can be transposed back.
        _interleave(pair, scratch)
    transpose_square(x)

    # Row pass: each row of tuples, flattened, is one row of the wide
    # matrix.
    for r in range(height):
        wide_row = x[r]
        wide_row.shape = width,
        _fft_inplace(wide_row, scratch)

    # Final transpose before returning; as in the column pass, the
    # tuples have to be deinterleaved afterwards.
    transpose_square(x)
    for r in range(height):
        pair = x[r]
        pair.shape = width,
        _deinterleave(pair, scratch)
def _fft_inplace(x, scratch):
    """Overwrite the one-dimensional array `x` with its FFT.

    Lengths 1 and 2 are handled directly. Longer inputs are dispatched
    on whether the base-2 logarithm of the length is odd or even;
    `lb_exact` raises for lengths that are not a power of two.

    `x` must be C-contiguous. `scratch` is passed through unchanged to
    the odd/even-power routines.
    """
    # Work on a fresh view: the routines below assign to `.shape`, and
    # that must not leak to the caller's array. No data is copied.
    view = x.view()
    assert view.flags['C_CONTIGUOUS']
    (length,) = view.shape

    if length == 1:
        # A single element is its own transform.
        return

    if length == 2:
        # Size-2 butterfly. Unpacking reads both values before either
        # slot is overwritten.
        first, second = view
        view[0] = first + second
        view[1] = first - second
        return

    if lb_exact(length) % 2 == 1:
        _fft_inplace_oddpow(view, scratch)
    else:
        _fft_inplace_evenpow(view, scratch)
def _scrach_length(lb_n):
    """Return the number of scratch elements needed for an FFT of
    length `2 ** lb_n`.

    Scratch space is only consumed at recursion levels whose length is
    an odd power of two; levels with an even power need none.
    """
    if lb_n == 0:
        # Length-1 input.
        return 0

    # Strip factors of two from the exponent. Halving the exponent is
    # taking an exact square root of the length.
    odd_part = lb_n
    while odd_part % 2 == 0:
        odd_part //= 2

    # `odd_part` is odd here, so the corresponding length is an odd
    # power of two.
    half_exponent = (odd_part - 1) // 2
    # An exponent of 1 (reached from lengths 2 and 4, among others)
    # needs no scratch at all.
    return (1 << half_exponent) if half_exponent else 0
def fft(x):
    """Returns the FFT of `x` as a new array; `x` itself is not modified.

    This is a wrapper around an in-place routine, provided for user
    convenience.

    `x` must be a one-dimensional array whose length is a power of 2;
    otherwise a `ValueError` is raised (by the shape unpacking or by
    `lb_exact`, respectively).

    The result always has a complex dtype. Real and integer inputs are
    promoted (e.g. `float64` -> `complex128`), because the in-place
    routine multiplies elements by complex twiddle factors and stores
    them back into the same buffer. Complex inputs keep their dtype.
    """
    n, = x.shape
    lb_n = lb_exact(n)  # Raises if not a power of 2.

    # Copy into a C-contiguous working buffer that can hold complex
    # values. Without the promotion, a real-valued buffer would lose
    # the imaginary parts of the twiddle products (or reject them).
    work_dtype = np.result_type(x.dtype, np.complex64)
    res = x.astype(work_dtype, order='C', copy=True)

    # We have one scratch buffer for the whole algorithm. If we were to
    # parallelize it, we'd need one thread-local buffer for each worker
    # thread. Its dtype must match the working buffer, not the input.
    scratch_len = _scrach_length(lb_n)
    if scratch_len == 0:
        scratch = None
    else:
        scratch = np.empty_like(
            res, shape=scratch_len, order='C', subok=False
        )

    _fft_inplace(res, scratch)
    return res
if __name__ == "__main__":
    # Smoke test: transform random complex data with both numpy's FFT
    # and ours, print everything, and finally print whether the two
    # results agree element-wise within `np.isclose` tolerances.
    LENGTH = 1 << 10

    samples = np.random.normal(size=LENGTH).astype(complex)
    print(samples)

    reference = np.fft.fft(samples)
    print(reference)

    ours = fft(samples)
    print(ours)

    print(np.isclose(reference, ours).all())

View File

@ -0,0 +1,61 @@
from util import lb_exact
def _swap_transpose_square(a, b):
    """Replace `a` with the transpose of `b` and `b` with the transpose
    of `a`, in place.

    Both arguments must have the same shape `(n, n, m)`. Only the first
    two axes are transposed; every matrix entry is an `m`-vector that
    is moved as a unit. The recursion halves `n` at each level, so it
    expects `n` to be a power of 2 (or 0).
    """
    assert len(a.shape) == len(b.shape) == 3
    side, a_cols, depth = a.shape
    b_rows, b_cols, b_depth = b.shape
    assert side == a_cols == b_rows == b_cols
    assert depth == b_depth

    if side == 0:
        return

    if side == 1:
        # A 1x1 matrix is its own transpose, so only the swap remains.
        # Exchange the two `m`-vectors one element at a time.
        cell_a, cell_b = a[0, 0], b[0, 0]
        for k in range(depth):
            cell_a[k], cell_b[k] = cell_b[k], cell_a[k]
        return

    half = side >> 1
    lo, hi = slice(None, half), slice(half, None)
    # Quadrant (rows, cols) of `a` trades places with the mirrored
    # quadrant (cols, rows) of `b`: top-left with top-left, top-right
    # with bottom-left, bottom-left with top-right, and bottom-right
    # with bottom-right.
    for rows in (lo, hi):
        for cols in (lo, hi):
            _swap_transpose_square(a[rows, cols], b[cols, rows])
def transpose_square(a):
    """Transpose a square matrix in place.

    `a` must have shape `(n, n, m)`. Only the first two axes are
    transposed; the trailing axis of length `m` is carried along
    untouched. `n` must be a power of 2.

    Raises `ValueError` if `a` is not three-dimensional, if its first
    two dimensions differ, or (from `lb_exact`) if `n` is not a power
    of 2.
    """
    dims = a.shape
    if len(dims) != 3:
        raise ValueError("a must be a matrix of batches")
    rows, cols = dims[0], dims[1]
    if rows != cols:
        raise ValueError("a must be square")
    lb_exact(rows)  # Called only for its power-of-2 check.

    if rows <= 1:
        return  # Base case: no-op

    half = rows >> 1
    lo, hi = slice(None, half), slice(half, None)
    # The two diagonal quarters are transposed where they are.
    transpose_square(a[lo, lo])
    # The two off-diagonal quarters are transposed and trade places.
    _swap_transpose_square(a[lo, hi], a[hi, lo])
    transpose_square(a[hi, hi])

View File

@ -0,0 +1,6 @@
def lb_exact(n):
    """Returns the exact base-2 logarithm of the integer `n`.

    Raises `ValueError` if `n` is not a power of 2 (this includes zero
    and negative numbers).
    """
    # For a power of two, the bit length is one more than the exponent.
    exponent = n.bit_length() - 1
    # `exponent` is -1 only for `n == 0`; the shift must be skipped in
    # that case, which the short-circuiting `and` takes care of.
    is_power_of_two = exponent >= 0 and n == 1 << exponent
    if not is_power_of_two:
        raise ValueError(f"{n} is not a power of 2")
    return exponent