Implement encode/decode+test using fft. Non-working

This commit is contained in:
Daniel Sanchez Quiros 2024-02-28 14:05:08 +01:00
parent 8ea2fb1fa3
commit d75eb2949f
3 changed files with 53 additions and 24 deletions

24
da/kzg_rs/fft.py Normal file
View File

@ -0,0 +1,24 @@
from da.kzg_rs.common import BLS_MODULUS
def __fft(vals, modulus, roots_of_unity):
if len(vals) == 1:
return vals
left = __fft(vals[::2], modulus, roots_of_unity[::2])
right = __fft(vals[1::2], modulus, roots_of_unity[::2])
o = [0 for _ in vals]
for i, (x, y) in enumerate(zip(left, right)):
y_times_root = y*int(roots_of_unity[i]) % modulus
o[i] = (x+y_times_root) % modulus
o[i+len(left)] = (x+modulus-y_times_root) % modulus
return o
def fft(vals, modulus, roots_of_unity):
return __fft(vals, modulus, roots_of_unity)
def ifft(vals, modulus, factor, roots_of_unity):
# Inverse FFT
invlen = pow(len(vals), modulus - factor, modulus)
return [(x * invlen) % modulus for x in __fft(vals, modulus, roots_of_unity[:0:-1])]

View File

@ -1,30 +1,21 @@
from eth2spec.utils import bls
from typing import Sequence
from .common import BLS_MODULUS
from eth2spec.deneb.mainnet import BLSFieldElement
from .common import G1
from .poly import Polynomial
from functools import reduce
from .fft import fft, ifft
def generator_polynomial(n, k, gen=bls.G1()) -> Polynomial:
"""
Generate the generator polynomial for RS codes
g(x) = (x-α^1)(x-α^2)...(x-α^(n-k))
"""
g = Polynomial([bls.Z1()], modulus=BLS_MODULUS)
return reduce(
Polynomial.__mul__,
(Polynomial([bls.Z1(), bls.multiply(gen, alpha)], modulus=BLS_MODULUS) for alpha in range(1, n-k+1)),
initial=g
)
def encode(polynomial: Polynomial, factor: int, roots_of_unity: Sequence[BLSFieldElement]) -> Polynomial:
assert factor >= 2
assert len(polynomial)*factor <= len(roots_of_unity)
extended_polynomial_coefficients = polynomial.coefficients + [0]*(len(polynomial)*factor-1)
extended_polynomial_coefficients = fft(extended_polynomial_coefficients, polynomial.modulus, roots_of_unity)
return Polynomial(extended_polynomial_coefficients, modulus=polynomial.modulus)
def encode(m: Polynomial, g: Polynomial, n: int, k: int) -> Polynomial:
# mprime = q*g + b for some q
xshift = Polynomial([bls.Z1(), *[0 for _ in range(n-k)]], modulus=m.modulus)
mprime = m * xshift
_, b = m / g
# subtract out b, so now c = q*g
c = mprime - b
# Since c is a multiple of g, it has (at least) n-k roots: α^1 through
# α^(n-k)
return c
def decode(polynomial: Polynomial, factor: int, roots_of_unity: Sequence[BLSFieldElement]) -> Polynomial:
coefficients = ifft(polynomial.coefficients, polynomial.modulus, factor, roots_of_unity)
return Polynomial(coefficients=coefficients, modulus=polynomial.modulus)

14
da/kzg_rs/test_fft.py Normal file
View File

@ -0,0 +1,14 @@
from unittest import TestCase
from da.kzg_rs.common import BLS_MODULUS, ROOTS_OF_UNITY
from da.kzg_rs.poly import Polynomial
from da.kzg_rs.rs import encode, decode
class TestFFT(TestCase):
def test_encode_decode(self):
poly = Polynomial(list(range(10)), modulus=BLS_MODULUS)
encoded = encode(poly, 2, ROOTS_OF_UNITY)
decoded = decode(encoded, 2, ROOTS_OF_UNITY)
for i in range(len(poly)):
self.assertEqual(poly.eval(ROOTS_OF_UNITY[i]), decoded.eval(int(ROOTS_OF_UNITY[i])))