Implement encode/decode+test using fft. Non-working
This commit is contained in:
parent
8ea2fb1fa3
commit
d75eb2949f
|
@ -0,0 +1,24 @@
|
|||
from da.kzg_rs.common import BLS_MODULUS
|
||||
|
||||
|
||||
def __fft(vals, modulus, roots_of_unity):
|
||||
if len(vals) == 1:
|
||||
return vals
|
||||
left = __fft(vals[::2], modulus, roots_of_unity[::2])
|
||||
right = __fft(vals[1::2], modulus, roots_of_unity[::2])
|
||||
o = [0 for _ in vals]
|
||||
for i, (x, y) in enumerate(zip(left, right)):
|
||||
y_times_root = y*int(roots_of_unity[i]) % modulus
|
||||
o[i] = (x+y_times_root) % modulus
|
||||
o[i+len(left)] = (x+modulus-y_times_root) % modulus
|
||||
return o
|
||||
|
||||
|
||||
def fft(vals, modulus, roots_of_unity):
|
||||
return __fft(vals, modulus, roots_of_unity)
|
||||
|
||||
|
||||
def ifft(vals, modulus, factor, roots_of_unity):
|
||||
# Inverse FFT
|
||||
invlen = pow(len(vals), modulus - factor, modulus)
|
||||
return [(x * invlen) % modulus for x in __fft(vals, modulus, roots_of_unity[:0:-1])]
|
|
@ -1,30 +1,21 @@
|
|||
from eth2spec.utils import bls
|
||||
from typing import Sequence
|
||||
|
||||
from .common import BLS_MODULUS
|
||||
from eth2spec.deneb.mainnet import BLSFieldElement
|
||||
|
||||
from .common import G1
|
||||
from .poly import Polynomial
|
||||
from functools import reduce
|
||||
from .fft import fft, ifft
|
||||
|
||||
|
||||
def generator_polynomial(n, k, gen=bls.G1()) -> Polynomial:
|
||||
"""
|
||||
Generate the generator polynomial for RS codes
|
||||
g(x) = (x-α^1)(x-α^2)...(x-α^(n-k))
|
||||
"""
|
||||
g = Polynomial([bls.Z1()], modulus=BLS_MODULUS)
|
||||
return reduce(
|
||||
Polynomial.__mul__,
|
||||
(Polynomial([bls.Z1(), bls.multiply(gen, alpha)], modulus=BLS_MODULUS) for alpha in range(1, n-k+1)),
|
||||
initial=g
|
||||
)
|
||||
def encode(polynomial: Polynomial, factor: int, roots_of_unity: Sequence[BLSFieldElement]) -> Polynomial:
|
||||
assert factor >= 2
|
||||
assert len(polynomial)*factor <= len(roots_of_unity)
|
||||
extended_polynomial_coefficients = polynomial.coefficients + [0]*(len(polynomial)*factor-1)
|
||||
extended_polynomial_coefficients = fft(extended_polynomial_coefficients, polynomial.modulus, roots_of_unity)
|
||||
return Polynomial(extended_polynomial_coefficients, modulus=polynomial.modulus)
|
||||
|
||||
|
||||
def encode(m: Polynomial, g: Polynomial, n: int, k: int) -> Polynomial:
|
||||
# mprime = q*g + b for some q
|
||||
xshift = Polynomial([bls.Z1(), *[0 for _ in range(n-k)]], modulus=m.modulus)
|
||||
mprime = m * xshift
|
||||
_, b = m / g
|
||||
# subtract out b, so now c = q*g
|
||||
c = mprime - b
|
||||
# Since c is a multiple of g, it has (at least) n-k roots: α^1 through
|
||||
# α^(n-k)
|
||||
return c
|
||||
def decode(polynomial: Polynomial, factor: int, roots_of_unity: Sequence[BLSFieldElement]) -> Polynomial:
|
||||
coefficients = ifft(polynomial.coefficients, polynomial.modulus, factor, roots_of_unity)
|
||||
return Polynomial(coefficients=coefficients, modulus=polynomial.modulus)
|
||||
|
||||
|
|
|
@ -0,0 +1,14 @@
|
|||
from unittest import TestCase
|
||||
|
||||
from da.kzg_rs.common import BLS_MODULUS, ROOTS_OF_UNITY
|
||||
from da.kzg_rs.poly import Polynomial
|
||||
from da.kzg_rs.rs import encode, decode
|
||||
|
||||
|
||||
class TestFFT(TestCase):
|
||||
def test_encode_decode(self):
|
||||
poly = Polynomial(list(range(10)), modulus=BLS_MODULUS)
|
||||
encoded = encode(poly, 2, ROOTS_OF_UNITY)
|
||||
decoded = decode(encoded, 2, ROOTS_OF_UNITY)
|
||||
for i in range(len(poly)):
|
||||
self.assertEqual(poly.eval(ROOTS_OF_UNITY[i]), decoded.eval(int(ROOTS_OF_UNITY[i])))
|
Loading…
Reference in New Issue