2022-09-15 14:59:16 -07:00

62 lines
2.0 KiB
Python

from util import lb_exact
def _swap_transpose_square(a, b):
    """Exchange two square batched matrices, transposing each, in place.

    When this returns, ``a`` holds the transpose of the original ``b`` and
    ``b`` holds the transpose of the original ``a``. Both arguments must
    have shape ``(n, n, m)``. The trailing ``m`` axis is carried along
    unchanged; only the first two axes are transposed.

    The work is split recursively into quadrants of size ``n >> 1``.
    Sizes that are not a power of two eventually yield non-square
    quadrants and fail the shape assertions below. Callers are expected
    to pass power-of-two sizes.
    """
    assert len(a.shape) == len(b.shape) == 3
    size, batch = a.shape[0], a.shape[2]
    assert size == a.shape[1] == b.shape[0] == b.shape[1]
    assert batch == b.shape[2]

    if size == 0:
        return

    if size == 1:
        # Leaf case: a 1x1 transpose does nothing, so only the two
        # length-`m` cells need to trade contents, one entry at a time.
        # NOTE(review): the tuple swap relies on `cell[k]` producing a copy
        # (true for NumPy scalar indexing). Confirm for other array types.
        cell_a, cell_b = a[0, 0], b[0, 0]
        for k in range(batch):
            cell_a[k], cell_b[k] = cell_b[k], cell_a[k]
        return

    mid = size >> 1
    head, tail = slice(None, mid), slice(mid, None)
    # Quadrant (rows, cols) of `a` pairs with the mirrored quadrant
    # (cols, rows) of `b`: TL<->TL, TR<->BL, BL<->TR, BR<->BR, in that order.
    for rows, cols in ((head, head), (head, tail), (tail, head), (tail, tail)):
        _swap_transpose_square(a[rows, cols], b[cols, rows])
def transpose_square(a):
    """Transpose a square batched matrix in place.

    ``a`` must have shape ``(n, n, m)``. The first two axes are swapped
    and the trailing ``m`` axis is left where it is. The size ``n`` is
    handed to ``lb_exact`` before any data is touched. Presumably
    ``lb_exact`` rejects sizes that are not an exact power of two
    (TODO confirm against ``util``), which the quadrant recursion
    requires.

    Raises:
        ValueError: if ``a`` is not 3-dimensional or is not square.
    """
    if len(a.shape) != 3:
        raise ValueError("a must be a matrix of batches")
    rows, cols, _ = a.shape
    if rows != cols:
        raise ValueError("a must be square")
    lb_exact(rows)

    # Sizes 0 and 1 are already their own transpose, so nothing to do.
    if rows > 1:
        mid = rows >> 1
        upper, lower = slice(None, mid), slice(mid, None)
        # The diagonal quarters are transposed where they sit.
        transpose_square(a[upper, upper])
        # The two off-diagonal quarters trade places, each transposed on
        # the way.
        _swap_transpose_square(a[upper, lower], a[lower, upper])
        transpose_square(a[lower, lower])