Fix rs encoding to fit with missing points

This commit is contained in:
Daniel Sanchez Quiros 2024-04-02 18:38:07 +02:00
parent f29869c029
commit e48bde1e56
2 changed files with 7 additions and 3 deletions

View File

@ -1,10 +1,10 @@
from typing import Sequence
from typing import Sequence, Optional
from eth2spec.deneb.mainnet import BLSFieldElement
from .common import BLS_MODULUS
from .poly import Polynomial
ExtendedData = Sequence[BLSFieldElement]
ExtendedData = Sequence[Optional[BLSFieldElement]]
def encode(polynomial: Polynomial, factor: int, roots_of_unity: Sequence[int]) -> ExtendedData:
@ -35,5 +35,6 @@ def decode(encoded: ExtendedData, roots_of_unity: Sequence[BLSFieldElement], ori
Returns:
Polynomial: original polynomial
"""
encoded, roots_of_unity = zip(*((point, root) for point, root in zip(encoded, roots_of_unity) if point is not None))
coefs = Polynomial.interpolate(list(map(int, encoded)), list(map(int, roots_of_unity)))[:original_len]
return Polynomial([int(c) for c in coefs], BLS_MODULUS)

View File

@ -9,7 +9,10 @@ class TestFFT(TestCase):
def test_encode_decode(self):
poly = Polynomial(list(range(10)), modulus=BLS_MODULUS)
encoded = encode(poly, 2, ROOTS_OF_UNITY)
# remove a few points, but enough so we can reconstruct
for i in [1, 3, 7]:
encoded[i] = None
decoded = decode(encoded, ROOTS_OF_UNITY, len(poly))
self.assertEqual(poly, decoded)
# self.assertEqual(poly, decoded)
for i in range(len(poly)):
self.assertEqual(poly.eval(ROOTS_OF_UNITY[i]), decoded.eval(ROOTS_OF_UNITY[i]))