mirror of
https://github.com/logos-storage/plonky2.git
synced 2026-01-04 06:43:07 +00:00
Python prototype of cache-oblivious FFT (#722)
This commit is contained in:
parent
b933e43cf1
commit
7d9e81362d
229
projects/cache-friendly-fft/__init__.py
Normal file
229
projects/cache-friendly-fft/__init__.py
Normal file
@ -0,0 +1,229 @@
|
|||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from transpose import transpose_square
|
||||||
|
from util import lb_exact
|
||||||
|
|
||||||
|
|
||||||
|
def _interleave(x, scratch):
    """Interleave the two halves of an array in-place.

    For example, if `x` is `array([1, 2, 3, 4, 5, 6, 7, 8])`, then its
    contents will be rearranged to `array([1, 5, 2, 6, 3, 7, 4, 8])`.

    `x` must be one-dimensional and of even length.

    `scratch` is an externally-allocated one-dimensional buffer, whose
    `dtype` matches `x` and whose length is at least half the length of
    `x`. Its first `len(x) // 2` elements are overwritten.
    """
    assert len(x.shape) == len(scratch.shape) == 1

    n, = x.shape
    assert n % 2 == 0

    half_n = n // 2
    assert scratch.shape[0] >= half_n

    assert x.dtype == scratch.dtype
    # Only the first `half_n` slots of the buffer are used.
    scratch = scratch[:half_n]

    scratch[:] = x[:half_n]  # Save the first half of `x`.
    for i in range(half_n):
        # The first half is read from the saved copy. The second half is
        # read from `x` itself, which is safe because the write position
        # `2 * i + 1` never runs ahead of the read position `half_n + i`.
        x[2 * i] = scratch[i]
        x[2 * i + 1] = x[half_n + i]
|
||||||
|
|
||||||
|
|
||||||
|
def _deinterleave(x, scratch):
    """Deinterleave the elements in an array in-place.

    For example, if `x` is `array([1, 2, 3, 4, 5, 6, 7, 8])`, then its
    contents will be rearranged to `array([1, 3, 5, 7, 2, 4, 6, 8])`.

    `x` must be one-dimensional and of even length.

    `scratch` is an externally-allocated one-dimensional buffer, whose
    `dtype` matches `x` and whose length is at least half the length of
    `x`. Its first `len(x) // 2` elements are overwritten.
    """
    assert len(x.shape) == len(scratch.shape) == 1

    n, = x.shape
    assert n % 2 == 0

    half_n = n // 2
    assert scratch.shape[0] >= half_n

    assert x.dtype == scratch.dtype
    # Only the first `half_n` slots of the buffer are used.
    scratch = scratch[:half_n]

    for i in range(half_n):
        # Even-indexed elements are compacted towards the front of `x`;
        # the write position `i` never runs ahead of the read position
        # `2 * i`. Odd-indexed elements are parked in `scratch`.
        x[i] = x[2 * i]
        scratch[i] = x[2 * i + 1]
    # Place the parked odd-indexed elements in the second half.
    x[half_n:] = scratch
|
||||||
|
|
||||||
|
|
||||||
|
def _fft_inplace_evenpow(x, scratch):
    """In-place FFT of length 2^even.

    `x` is a one-dimensional array whose length is an even power of
    two. Its `shape` attribute is reassigned to `(n, n, 1)`, so callers
    that care about the original shape should pass a view (as
    `_fft_inplace` does).

    `scratch` is not used directly here; it is forwarded to the
    recursive `_fft_inplace` calls.
    """
    # Reshape `x` to a square matrix in row-major order.
    vec_len = x.shape[0]
    n = 1 << (lb_exact(vec_len) >> 1)  # Matrix dimension
    x.shape = n, n, 1

    # Loop-invariant part of the twiddle-factor exponent.
    neg_two_pi_j = -2j * np.pi

    # We want to recursively apply FFT to every column. Because `x` is
    # in row-major order, we transpose it to make the columns contiguous
    # in memory, then recurse, and finally transpose it back. While the
    # row is in cache, we also multiply by the twiddle factors.
    transpose_square(x)
    for i, row in enumerate(x[..., 0]):
        # `row` is column `i` of the original matrix.
        _fft_inplace(row, scratch)
        # Multiply by the twiddle factors
        for j in range(n):
            row[j] *= np.exp(neg_two_pi_j * (i * j) / vec_len)
    transpose_square(x)

    # Now recursively apply FFT to the rows.
    for row in x[..., 0]:
        _fft_inplace(row, scratch)

    # Transpose again before returning.
    transpose_square(x)
|
||||||
|
|
||||||
|
|
||||||
|
def _fft_inplace_oddpow(x, scratch):
    """In-place FFT of length 2^odd.

    `x` is a one-dimensional array whose length is an odd power of two.
    Its `shape` attribute is reassigned to `(col_len, col_len, 2)`, so
    callers that care about the original shape should pass a view (as
    `_fft_inplace` does).

    `scratch` is handed to `_deinterleave`/`_interleave` on arrays of
    length `2 * col_len`, and is also forwarded to the recursive
    `_fft_inplace` calls.
    """
    # This code is based on `_fft_inplace_evenpow`, but it has to
    # account for some additional complications.

    vec_len = x.shape[0]
    # `vec_len` is an odd power of 2, so we cannot reshape `x` to a
    # square matrix. Instead, we'll (conceptually) reshape it to a
    # matrix that's twice as wide as it is high. E.g., `[1 ... 8]`
    # becomes `[1 2 3 4]`
    #         `[5 6 7 8]`.
    col_len = 1 << (lb_exact(vec_len) >> 1)
    row_len = col_len << 1

    # We can only perform efficient, in-place transposes on square
    # matrices, so we will actually treat this as a square matrix of
    # 2-tuples, e.g. `[(1 2) (3 4)]`
    #                `[(5 6) (7 8)]`.
    # Note that we can currently `.reshape` it to our intended wide
    # matrix (although this is broken by transposition).
    x.shape = col_len, col_len, 2

    # Loop-invariant part of the twiddle-factor exponent.
    neg_two_pi_j = -2j * np.pi

    # We want to apply FFT to each column. We transpose our
    # matrix-of-tuples and get something like `[(1 2) (5 6)]`
    #                                         `[(3 4) (7 8)]`.
    # Note that each row of the transposed matrix represents two columns
    # of the original matrix. We can deinterleave the values to recover
    # the original columns.
    transpose_square(x)

    for i, row_pair in enumerate(x):
        # `row_pair` represents two columns of the original matrix.
        # Their values must be deinterleaved to recover the columns.
        row_pair.shape = row_len,
        _deinterleave(row_pair, scratch)
        # The below are rows of the transposed matrix (i.e. columns
        # `2 * i` and `2 * i + 1` of the original matrix).
        halves = (row_pair[:col_len], row_pair[col_len:])

        # Apply FFT and twiddle factors to each.
        for k, col in enumerate(halves):
            _fft_inplace(col, scratch)
            orig_col = 2 * i + k
            for j in range(col_len):
                col[j] *= np.exp(neg_two_pi_j * (orig_col * j) / vec_len)

        # Re-interleave them; the transpose back happens after the loop.
        _interleave(row_pair, scratch)

    transpose_square(x)

    # Recursively apply FFT to each row of the matrix.
    for row in x:
        # Turn vec of 2-tuples into vec of single elements.
        row.shape = row_len,
        _fft_inplace(row, scratch)

    # Transpose again before returning. This again involves
    # deinterleaving.
    transpose_square(x)
    for row_pair in x:
        row_pair.shape = row_len,
        _deinterleave(row_pair, scratch)
|
||||||
|
|
||||||
|
|
||||||
|
def _fft_inplace(x, scratch):
    """In-place FFT.

    `x` is a one-dimensional, C-contiguous array whose length is a
    power of two; its contents are overwritten with the transform.
    Lengths 1 and 2 are handled directly; longer inputs are dispatched
    to `_fft_inplace_oddpow` or `_fft_inplace_evenpow` depending on the
    parity of `log2(len(x))`.

    `scratch` is forwarded untouched to those routines; it is not used
    for lengths 1 and 2.
    """
    # Avoid modifying the shape of the original.
    # This does not copy the buffer.
    x = x.view()
    assert x.flags['C_CONTIGUOUS']

    n, = x.shape
    if n == 1:
        # The transform of a single element is the element itself.
        return
    if n == 2:
        # Length-2 butterfly. Unpacking copies both values out before
        # either slot is overwritten.
        x0, x1 = x
        x[0] = x0 + x1
        x[1] = x0 - x1
        return

    lb_n = lb_exact(n)
    is_odd = (lb_n & 1) != 0
    if is_odd:
        _fft_inplace_oddpow(x, scratch)
    else:
        _fft_inplace_evenpow(x, scratch)
|
||||||
|
|
||||||
|
|
||||||
|
def _scrach_length(lb_n):
    """Find the amount of scratch space required to run the FFT.

    `lb_n` is the base-2 logarithm of the input length `n`. The return
    value is a number of elements (0 if no scratch buffer is needed).

    Layers where the input's length is an even power of two do not
    require scratch space, but the layers where that power is odd do.
    """
    if lb_n == 0:
        # Length-1 input. (Also keeps the halving loop below from
        # spinning forever on zero.)
        return 0
    # Repeatedly halve lb_n as long as it's even. This is the same as
    # `n = sqrt(n)`, where the `sqrt` is exact.
    while lb_n & 1 == 0:
        lb_n >>= 1
    # `lb_n` is now odd, so `n` is not an even power of 2.
    lb_res = (lb_n - 1) >> 1
    if lb_res == 0:
        # Special case: repeated square roots brought us down to n == 2
        # (i.e. the original n was 2, 4, 16, 256, ...). Length 2 is
        # handled without scratch space.
        return 0
    return 1 << lb_res
|
||||||
|
|
||||||
|
|
||||||
|
def fft(x):
    """Returns the FFT of `x`.

    This is a wrapper around an in-place routine, provided for user
    convenience. `x` itself is not modified: the transform is computed
    on a C-ordered copy, which is returned.

    `x` must be one-dimensional and its length must be a power of two
    (`lb_exact` raises `ValueError` otherwise).

    The working copy always has a complex dtype, at least as wide as
    the input's. The in-place routines multiply elements by complex
    twiddle factors, which a real- or integer-valued buffer cannot
    store, so such inputs are promoted here.
    """
    x = np.asanyarray(x)
    n, = x.shape
    lb_n = lb_exact(n)  # Raises if not a power of 2.

    # E.g. float32 -> complex64, float64/int64 -> complex128; complex
    # inputs keep their dtype.
    work_dtype = np.result_type(x.dtype, np.complex64)
    res = x.astype(work_dtype, order='C', copy=True)

    # We have one scratch buffer for the whole algorithm. If we were to
    # parallelize it, we'd need one thread-local buffer for each worker
    # thread.
    scratch_len = _scrach_length(lb_n)
    if scratch_len == 0:
        scratch = None
    else:
        # Must share `res`'s dtype: the (de)interleave helpers assert it.
        scratch = np.empty_like(res, shape=scratch_len, order='C', subok=False)

    _fft_inplace(res, scratch)

    return res
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
    # Smoke test: compare our FFT against NumPy's on random input.
    # 2^10 is an even power of two, but the recursion also passes
    # through odd-power layers (length 32 and below).
    LENGTH = 1 << 10
    v = np.random.normal(size=LENGTH).astype(complex)
    print(v)
    numpy_fft = np.fft.fft(v)
    print(numpy_fft)
    our_fft = fft(v)
    print(our_fft)
    # Prints `True` if every element agrees within `np.isclose`'s
    # default tolerances.
    print(np.isclose(numpy_fft, our_fft).all())
|
||||||
61
projects/cache-friendly-fft/transpose.py
Normal file
61
projects/cache-friendly-fft/transpose.py
Normal file
@ -0,0 +1,61 @@
|
|||||||
|
from util import lb_exact
|
||||||
|
|
||||||
|
|
||||||
|
def _swap_transpose_square(a, b):
    """Transpose two square matrices in-place and swap them.

    After the call, `a` holds the transpose of the old `b` and `b`
    holds the transpose of the old `a`.

    The matrices must be of shape `(n, n, m)`, where the `m` dimension
    may be of arbitrary length and is not moved. The recursion halves
    `n` at every level, so `n` is expected to be a power of two (or
    zero); other sizes trip the shape assertions in a recursive call.
    """
    assert len(a.shape) == len(b.shape) == 3
    n = a.shape[0]
    m = a.shape[2]
    assert n == a.shape[1] == b.shape[0] == b.shape[1]
    assert m == b.shape[2]

    if n == 0:
        return
    if n == 1:
        # Swap the two matrices (transposition is a no-op).
        a = a[0, 0]
        b = b[0, 0]
        # Recall that each element of the matrix is an `m`-vector. Swap
        # all `m` elements.
        for i in range(m):
            a[i], b[i] = b[i], a[i]
        return

    half_n = n >> 1
    # Transpose and swap top-left of `a` with top-left of `b`.
    _swap_transpose_square(a[:half_n, :half_n], b[:half_n, :half_n])
    # ...top-right of `a` with bottom-left of `b`.
    _swap_transpose_square(a[:half_n, half_n:], b[half_n:, :half_n])
    # ...bottom-left of `a` with top-right of `b`.
    _swap_transpose_square(a[half_n:, :half_n], b[:half_n, half_n:])
    # ...bottom-right of `a` with bottom-right of `b`.
    _swap_transpose_square(a[half_n:, half_n:], b[half_n:, half_n:])
|
||||||
|
|
||||||
|
|
||||||
|
def transpose_square(a):
    """In-place transpose of a square matrix.

    The matrix must be of shape `(n, n, m)`, where the `m` dimension
    may be of arbitrary length and is not moved. `n` must be zero or a
    power of two.

    Raises `ValueError` if `a` is not three-dimensional, is not square
    in its first two dimensions, or if `n` is non-zero and not a power
    of two.
    """
    if len(a.shape) != 3:
        raise ValueError("a must be a matrix of batches")
    n, n_, _ = a.shape
    if n != n_:
        raise ValueError("a must be square")

    # Check the base case before the power-of-two check, so that an
    # empty matrix is a no-op rather than an error from `lb_exact(0)`.
    if n <= 1:
        return  # Base case: no-op

    lb_exact(n)  # Raises if not a power of 2.

    half_n = n >> 1
    # Transpose top-left quarter in-place.
    transpose_square(a[:half_n, :half_n])
    # Transpose top-right and bottom-left quarters and swap them.
    _swap_transpose_square(a[:half_n, half_n:], a[half_n:, :half_n])
    # Transpose bottom-right quarter in-place.
    transpose_square(a[half_n:, half_n:])
|
||||||
6
projects/cache-friendly-fft/util.py
Normal file
6
projects/cache-friendly-fft/util.py
Normal file
@ -0,0 +1,6 @@
|
|||||||
|
def lb_exact(n):
    """Returns `log2(n)`, raising if `n` is not a power of 2.

    `n` is an `int`. Zero and negative values are rejected with
    `ValueError`, like any other non-power of 2.
    """
    # For a power of 2, the highest set bit is the only set bit, and
    # its index is the logarithm.
    lb = n.bit_length() - 1
    # `lb < 0` catches `n == 0`; the comparison catches everything else
    # (including negative `n`, whose `bit_length` ignores the sign).
    if lb < 0 or n != 1 << lb:
        raise ValueError(f"{n} is not a power of 2")
    return lb
|
||||||
Loading…
x
Reference in New Issue
Block a user