From 8d56ab7eb6a3bdb6c97e4da288f34f812c11b4c1 Mon Sep 17 00:00:00 2001 From: danielsanchezq Date: Tue, 11 Jun 2024 18:08:19 +0200 Subject: [PATCH] Implement fft for g1 values --- da/kzg_rs/fft.py | 40 +++++++++++++++++++++++++++++++++------- 1 file changed, 33 insertions(+), 7 deletions(-) diff --git a/da/kzg_rs/fft.py b/da/kzg_rs/fft.py index 0f0ed93..db74b6d 100644 --- a/da/kzg_rs/fft.py +++ b/da/kzg_rs/fft.py @@ -1,7 +1,34 @@ -from typing import Sequence +from typing import Sequence, List from eth2spec.deneb.mainnet import BLSFieldElement +from eth2spec.utils import bls +from da.kzg_rs.common import G1 + + +def fft_g1(vals: Sequence[G1], roots_of_unity: Sequence[BLSFieldElement], modulus: int) -> List[G1]: + if len(vals) == 1: + return vals + L = fft_g1(vals[::2], roots_of_unity[::2], modulus) + R = fft_g1(vals[1::2], roots_of_unity[::2], modulus) + o = [bls.Z1() for _ in vals] + for i, (x, y) in enumerate(zip(L, R)): + y_times_root = bls.multiply(y, roots_of_unity[i]) + o[i] = (x + y_times_root) + o[i + len(L)] = x + -y_times_root + return o + + +def ifft_g1(vals: Sequence[G1], roots_of_unity: Sequence[BLSFieldElement], modulus: int) -> List[G1]: + assert len(vals) == len(roots_of_unity) + # modular inverse + invlen = pow(len(vals), modulus-2, modulus) + return [ + bls.multiply(x, invlen) + for x in fft_g1( + vals, [roots_of_unity[0], *roots_of_unity[:0:-1]], modulus + ) + ] def _fft( vals: Sequence[BLSFieldElement], @@ -14,24 +41,23 @@ def _fft( R = _fft(vals[1::2], roots_of_unity[::2], modulus) o = [BLSFieldElement(0) for _ in vals] for i, (x, y) in enumerate(zip(L, R)): - y_times_root = (int(y) * int(roots_of_unity[i])) % modulus + y_times_root = BLSFieldElement((int(y) * int(roots_of_unity[i])) % modulus) o[i] = BLSFieldElement((int(x) + y_times_root) % modulus) - o[i + len(L)] = BLSFieldElement((int(x) - y_times_root + modulus) % modulus) + o[i + len(L)] = BLSFieldElement((int(x) - int(y_times_root) + modulus) % modulus) return o - def fft(vals, root_of_unity, modulus): assert len(vals) == len(root_of_unity) return _fft(vals, root_of_unity, modulus) -def ifft(vals, root_of_unity, modulus): - assert len(vals) == len(root_of_unity) +def ifft(vals, roots_of_unity, modulus): + assert len(vals) == len(roots_of_unity) # modular inverse invlen = pow(len(vals), modulus-2, modulus) return [ BLSFieldElement((int(x) * invlen) % modulus) for x in _fft( - vals, [root_of_unity[0], *root_of_unity[:0:-1]], modulus + vals, [roots_of_unity[0], *roots_of_unity[:0:-1]], modulus ) ]