diff --git a/da/kzg_rs/fft.py b/da/kzg_rs/fft.py new file mode 100644 index 0000000..2760279 --- /dev/null +++ b/da/kzg_rs/fft.py @@ -0,0 +1,24 @@ +from da.kzg_rs.common import BLS_MODULUS + + +def __fft(vals, modulus, roots_of_unity): + if len(vals) == 1: + return vals + left = __fft(vals[::2], modulus, roots_of_unity[::2]) + right = __fft(vals[1::2], modulus, roots_of_unity[::2]) + o = [0 for _ in vals] + for i, (x, y) in enumerate(zip(left, right)): + y_times_root = y*int(roots_of_unity[i]) % modulus + o[i] = (x+y_times_root) % modulus + o[i+len(left)] = (x+modulus-y_times_root) % modulus + return o + + +def fft(vals, modulus, roots_of_unity): + return __fft(vals, modulus, roots_of_unity) + + +def ifft(vals, modulus, factor, roots_of_unity): + # Inverse FFT + invlen = pow(len(vals), modulus - factor, modulus) + return [(x * invlen) % modulus for x in __fft(vals, modulus, roots_of_unity[:0:-1])] diff --git a/da/kzg_rs/rs.py b/da/kzg_rs/rs.py index 18e5537..8ff686a 100644 --- a/da/kzg_rs/rs.py +++ b/da/kzg_rs/rs.py @@ -1,30 +1,21 @@ -from eth2spec.utils import bls +from typing import Sequence -from .common import BLS_MODULUS +from eth2spec.deneb.mainnet import BLSFieldElement + +from .common import G1 from .poly import Polynomial -from functools import reduce +from .fft import fft, ifft -def generator_polynomial(n, k, gen=bls.G1()) -> Polynomial: - """ - Generate the generator polynomial for RS codes - g(x) = (x-α^1)(x-α^2)...(x-α^(n-k)) - """ - g = Polynomial([bls.Z1()], modulus=BLS_MODULUS) - return reduce( - Polynomial.__mul__, - (Polynomial([bls.Z1(), bls.multiply(gen, alpha)], modulus=BLS_MODULUS) for alpha in range(1, n-k+1)), - initial=g - ) +def encode(polynomial: Polynomial, factor: int, roots_of_unity: Sequence[BLSFieldElement]) -> Polynomial: + assert factor >= 2 + assert len(polynomial)*factor <= len(roots_of_unity) + extended_polynomial_coefficients = polynomial.coefficients + [0]*(len(polynomial)*factor-1) + extended_polynomial_coefficients = fft(extended_polynomial_coefficients, polynomial.modulus, roots_of_unity) + return Polynomial(extended_polynomial_coefficients, modulus=polynomial.modulus) -def encode(m: Polynomial, g: Polynomial, n: int, k: int) -> Polynomial: - # mprime = q*g + b for some q - xshift = Polynomial([bls.Z1(), *[0 for _ in range(n-k)]], modulus=m.modulus) - mprime = m * xshift - _, b = m / g - # subtract out b, so now c = q*g - c = mprime - b - # Since c is a multiple of g, it has (at least) n-k roots: α^1 through - # α^(n-k) - return c +def decode(polynomial: Polynomial, factor: int, roots_of_unity: Sequence[BLSFieldElement]) -> Polynomial: + coefficients = ifft(polynomial.coefficients, polynomial.modulus, factor, roots_of_unity) + return Polynomial(coefficients=coefficients, modulus=polynomial.modulus) + diff --git a/da/kzg_rs/test_fft.py b/da/kzg_rs/test_fft.py new file mode 100644 index 0000000..7f85d23 --- /dev/null +++ b/da/kzg_rs/test_fft.py @@ -0,0 +1,14 @@ +from unittest import TestCase + +from da.kzg_rs.common import BLS_MODULUS, ROOTS_OF_UNITY +from da.kzg_rs.poly import Polynomial +from da.kzg_rs.rs import encode, decode + + +class TestFFT(TestCase): + def test_encode_decode(self): + poly = Polynomial(list(range(10)), modulus=BLS_MODULUS) + encoded = encode(poly, 2, ROOTS_OF_UNITY) + decoded = decode(encoded, 2, ROOTS_OF_UNITY) + for i in range(len(poly)): + self.assertEqual(poly.eval(ROOTS_OF_UNITY[i]), decoded.eval(int(ROOTS_OF_UNITY[i])))