mirror of
https://github.com/logos-storage/plonky2.git
synced 2026-01-02 22:03:07 +00:00
62 lines
2.0 KiB
Python
62 lines
2.0 KiB
Python
from util import lb_exact
|
|
|
|
|
|
def _swap_transpose_square(a, b):
    """Transpose two square matrices in place and exchange their contents.

    Both arguments must have shape ``(n, n, m)``. Only the first two axes
    are transposed. The trailing ``m`` axis holds one vector per matrix
    entry and is carried along unchanged. On return, ``a`` holds the
    transpose of the original ``b``, and ``b`` holds the transpose of the
    original ``a``.

    The work is split recursively into quadrants, so ``n`` must halve into
    square quadrants at every level (e.g. a power of two). Otherwise a shape
    assertion fails on an off-size quadrant after earlier quadrants have
    already been exchanged. Shapes are checked with ``assert``, so these
    checks vanish under ``python -O``. Blocks are exchanged element by
    element, so the two inputs should not share memory.
    """
    assert len(a.shape) == len(b.shape) == 3
    size, _, width = a.shape
    assert a.shape[1] == size
    assert tuple(b.shape) == (size, size, width)

    if size == 0:
        return

    if size == 1:
        # A 1x1 matrix is its own transpose, so only the swap remains.
        # Each entry is a length-`width` vector; exchange it component-wise.
        lhs, rhs = a[0, 0], b[0, 0]
        for k in range(width):
            lhs[k], rhs[k] = rhs[k], lhs[k]
        return

    h = size // 2
    halves = (slice(None, h), slice(h, None))
    # Quadrant (rows, cols) of `a` is exchanged with quadrant (cols, rows)
    # of `b`. Visiting order: top-left, top-right, bottom-left,
    # bottom-right of `a`.
    for rows in halves:
        for cols in halves:
            _swap_transpose_square(a[rows, cols], b[cols, rows])
|
|
|
|
|
|
def transpose_square(a):
    """Transpose a square matrix of vectors in place.

    ``a`` must have shape ``(n, n, m)``. The first two axes are transposed,
    and the trailing ``m`` axis (one vector per entry) is left untouched.
    ``n`` is passed to ``lb_exact``, which presumably rejects sizes that are
    not exact powers of two (TODO: confirm against ``util``).

    Raises:
        ValueError: if ``a`` is not three-dimensional, or if its first two
            axes differ in length.
    """
    if len(a.shape) != 3:
        raise ValueError("a must be a matrix of batches")
    rows, cols, _ = a.shape
    if rows != cols:
        raise ValueError("a must be square")
    lb_exact(rows)

    if rows < 2:
        return  # A matrix with at most one row is its own transpose.

    mid = rows // 2
    top, bottom = slice(None, mid), slice(mid, None)
    # The diagonal quadrants are transposed in place. The two off-diagonal
    # quadrants are transposed and exchanged with each other.
    transpose_square(a[top, top])
    _swap_transpose_square(a[top, bottom], a[bottom, top])
    transpose_square(a[bottom, bottom])
|