Implement i/fft from ethspecs
This commit is contained in:
parent
d7bafffbdc
commit
0488748d5d
|
@ -1,37 +1,37 @@
|
||||||
|
from typing import Sequence
|
||||||
|
|
||||||
def _simple_ft(vals, modulus, roots_of_unity):
|
from eth2spec.deneb.mainnet import BLSFieldElement
|
||||||
L = len(roots_of_unity)
|
|
||||||
o = []
|
|
||||||
for i in range(L):
|
|
||||||
last = 0
|
|
||||||
for j in range(L):
|
|
||||||
last += vals[j] * roots_of_unity[(i*j)%L]
|
|
||||||
o.append(last % modulus)
|
|
||||||
return o
|
|
||||||
|
|
||||||
|
|
||||||
def _fft(vals, modulus, roots_of_unity):
|
def _fft(
|
||||||
if len(vals) == 4:
|
vals: Sequence[BLSFieldElement],
|
||||||
return _simple_ft(vals, modulus, roots_of_unity)
|
roots_of_unity: Sequence[BLSFieldElement],
|
||||||
|
modulus: int,
|
||||||
|
) -> Sequence[BLSFieldElement]:
|
||||||
if len(vals) == 1:
|
if len(vals) == 1:
|
||||||
return vals
|
return vals
|
||||||
L = _fft(vals[::2], modulus, roots_of_unity[::2])
|
L = _fft(vals[::2], roots_of_unity[::2], modulus)
|
||||||
R = _fft(vals[1::2], modulus, roots_of_unity[::2])
|
R = _fft(vals[1::2], roots_of_unity[::2], modulus)
|
||||||
o = [0 for _ in vals]
|
o = [BLSFieldElement(0) for _ in vals]
|
||||||
for i, (x, y) in enumerate(zip(L, R)):
|
for i, (x, y) in enumerate(zip(L, R)):
|
||||||
y_times_root = (y*roots_of_unity[i]) % modulus
|
y_times_root = (int(y) * int(roots_of_unity[i])) % modulus
|
||||||
o[i] = (x+y_times_root) % modulus
|
o[i] = BLSFieldElement((int(x) + y_times_root) % modulus)
|
||||||
o[i+len(L)] = (x-y_times_root+modulus) % modulus
|
o[i + len(L)] = BLSFieldElement((int(x) - y_times_root + modulus) % modulus)
|
||||||
return o
|
return o
|
||||||
|
|
||||||
|
|
||||||
def fft(vals, modulus, root_of_unity):
|
def fft(vals, root_of_unity, modulus):
|
||||||
assert len(vals) == len(root_of_unity)
|
assert len(vals) == len(root_of_unity)
|
||||||
return _fft(vals, modulus, root_of_unity)
|
return _fft(vals, root_of_unity, modulus)
|
||||||
|
|
||||||
|
|
||||||
def ifft(vals, modulus, root_of_unity):
|
def ifft(vals, root_of_unity, modulus):
|
||||||
assert len(vals) == len(root_of_unity)
|
assert len(vals) == len(root_of_unity)
|
||||||
# modular inverse
|
# modular inverse
|
||||||
invlen = pow(len(vals), -1, modulus)
|
invlen = pow(len(vals), modulus-2, modulus)
|
||||||
return [(x * invlen) % modulus for x in _fft(vals, modulus, list(reversed(root_of_unity)))]
|
return [
|
||||||
|
BLSFieldElement((int(x) * invlen) % modulus)
|
||||||
|
for x in _fft(
|
||||||
|
vals, [root_of_unity[0], *root_of_unity[:0:-1]], modulus
|
||||||
|
)
|
||||||
|
]
|
||||||
|
|
|
@ -2,12 +2,12 @@ from unittest import TestCase
|
||||||
|
|
||||||
from da.kzg_rs.common import BLS_MODULUS
|
from da.kzg_rs.common import BLS_MODULUS
|
||||||
from fft import fft, ifft
|
from fft import fft, ifft
|
||||||
from eth2spec.eip7594.mainnet import fft_field, BLSFieldElement
|
|
||||||
|
|
||||||
|
|
||||||
class TestFFT(TestCase):
|
class TestFFT(TestCase):
|
||||||
def test_fft_ifft(self):
|
def test_fft_ifft(self):
|
||||||
roots_of_unity = [pow(2, i, BLS_MODULUS) for i in range(8)]
|
roots_of_unity = [pow(23674694431658770659612952115660802947967373701506253797663184111817857449850, i, BLS_MODULUS) for i in range(1024)]
|
||||||
vals = list(BLSFieldElement(x) for x in range(8))
|
vals = list(x for x in range(1024))
|
||||||
vals_fft = fft_field(vals, roots_of_unity)
|
vals_fft = fft(vals, roots_of_unity, BLS_MODULUS)
|
||||||
self.assertEqual(vals, fft_field(vals_fft, roots_of_unity, inv=True))
|
self.assertEqual(vals, ifft(vals_fft, roots_of_unity, BLS_MODULUS))
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue