diff --git a/da/kzg_rs/fft.py b/da/kzg_rs/fft.py index bf93412..0f0ed93 100644 --- a/da/kzg_rs/fft.py +++ b/da/kzg_rs/fft.py @@ -1,37 +1,37 @@ +from typing import Sequence -def _simple_ft(vals, modulus, roots_of_unity): - L = len(roots_of_unity) - o = [] - for i in range(L): - last = 0 - for j in range(L): - last += vals[j] * roots_of_unity[(i*j)%L] - o.append(last % modulus) - return o +from eth2spec.deneb.mainnet import BLSFieldElement -def _fft(vals, modulus, roots_of_unity): - if len(vals) == 4: - return _simple_ft(vals, modulus, roots_of_unity) +def _fft( + vals: Sequence[BLSFieldElement], + roots_of_unity: Sequence[BLSFieldElement], + modulus: int, +) -> Sequence[BLSFieldElement]: if len(vals) == 1: return vals - L = _fft(vals[::2], modulus, roots_of_unity[::2]) - R = _fft(vals[1::2], modulus, roots_of_unity[::2]) - o = [0 for _ in vals] + L = _fft(vals[::2], roots_of_unity[::2], modulus) + R = _fft(vals[1::2], roots_of_unity[::2], modulus) + o = [BLSFieldElement(0) for _ in vals] for i, (x, y) in enumerate(zip(L, R)): - y_times_root = (y*roots_of_unity[i]) % modulus - o[i] = (x+y_times_root) % modulus - o[i+len(L)] = (x-y_times_root+modulus) % modulus + y_times_root = (int(y) * int(roots_of_unity[i])) % modulus + o[i] = BLSFieldElement((int(x) + y_times_root) % modulus) + o[i + len(L)] = BLSFieldElement((int(x) - y_times_root + modulus) % modulus) return o -def fft(vals, modulus, root_of_unity): +def fft(vals, root_of_unity, modulus): assert len(vals) == len(root_of_unity) - return _fft(vals, modulus, root_of_unity) + return _fft(vals, root_of_unity, modulus) -def ifft(vals, modulus, root_of_unity): +def ifft(vals, root_of_unity, modulus): assert len(vals) == len(root_of_unity) # modular inverse - invlen = pow(len(vals), -1, modulus) - return [(x * invlen) % modulus for x in _fft(vals, modulus, list(reversed(root_of_unity)))] + invlen = pow(len(vals), modulus-2, modulus) + return [ + BLSFieldElement((int(x) * invlen) % modulus) + for x in _fft( + vals, [root_of_unity[0], *root_of_unity[:0:-1]], modulus + ) + ] diff --git a/da/kzg_rs/test_fft.py b/da/kzg_rs/test_fft.py index bb7b263..15a5347 100644 --- a/da/kzg_rs/test_fft.py +++ b/da/kzg_rs/test_fft.py @@ -2,12 +2,12 @@ from unittest import TestCase from da.kzg_rs.common import BLS_MODULUS from fft import fft, ifft -from eth2spec.eip7594.mainnet import fft_field, BLSFieldElement class TestFFT(TestCase): def test_fft_ifft(self): - roots_of_unity = [pow(2, i, BLS_MODULUS) for i in range(8)] - vals = list(BLSFieldElement(x) for x in range(8)) - vals_fft = fft_field(vals, roots_of_unity) - self.assertEqual(vals, fft_field(vals_fft, roots_of_unity, inv=True)) + roots_of_unity = [pow(23674694431658770659612952115660802947967373701506253797663184111817857449850, i, BLS_MODULUS) for i in range(1024)] + vals = list(x for x in range(1024)) + vals_fft = fft(vals, roots_of_unity, BLS_MODULUS) + self.assertEqual(vals, ifft(vals_fft, roots_of_unity, BLS_MODULUS)) +