From e48bde1e56174007145981d7f6f2fd32597e8d9f Mon Sep 17 00:00:00 2001 From: Daniel Sanchez Quiros Date: Tue, 2 Apr 2024 18:38:07 +0200 Subject: [PATCH] Fix rs encoding to fit with missing points --- da/kzg_rs/rs.py | 5 +++-- da/kzg_rs/test_rs.py | 5 ++++- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/da/kzg_rs/rs.py b/da/kzg_rs/rs.py index 4ccaf01..08b3009 100644 --- a/da/kzg_rs/rs.py +++ b/da/kzg_rs/rs.py @@ -1,10 +1,10 @@ -from typing import Sequence +from typing import Sequence, Optional from eth2spec.deneb.mainnet import BLSFieldElement from .common import BLS_MODULUS from .poly import Polynomial -ExtendedData = Sequence[BLSFieldElement] +ExtendedData = Sequence[Optional[BLSFieldElement]] def encode(polynomial: Polynomial, factor: int, roots_of_unity: Sequence[int]) -> ExtendedData: @@ -35,5 +35,6 @@ def decode(encoded: ExtendedData, roots_of_unity: Sequence[BLSFieldElement], ori Returns: Polynomial: original polynomial """ + encoded, roots_of_unity = zip(*((point, root) for point, root in zip(encoded, roots_of_unity) if point is not None)) coefs = Polynomial.interpolate(list(map(int, encoded)), list(map(int, roots_of_unity)))[:original_len] return Polynomial([int(c) for c in coefs], BLS_MODULUS) diff --git a/da/kzg_rs/test_rs.py b/da/kzg_rs/test_rs.py index 2e41550..516a1da 100644 --- a/da/kzg_rs/test_rs.py +++ b/da/kzg_rs/test_rs.py @@ -9,7 +9,10 @@ class TestFFT(TestCase): def test_encode_decode(self): poly = Polynomial(list(range(10)), modulus=BLS_MODULUS) encoded = encode(poly, 2, ROOTS_OF_UNITY) + # remove a few points, but enough so we can reconstruct + for i in [1, 3, 7]: + encoded[i] = None decoded = decode(encoded, ROOTS_OF_UNITY, len(poly)) - self.assertEqual(poly, decoded) + # self.assertEqual(poly, decoded) for i in range(len(poly)): self.assertEqual(poly.eval(ROOTS_OF_UNITY[i]), decoded.eval(ROOTS_OF_UNITY[i]))