diff --git a/da/kzg_rs/common.py b/da/kzg_rs/common.py index c7ca72a..9eee3c2 100644 --- a/da/kzg_rs/common.py +++ b/da/kzg_rs/common.py @@ -1,9 +1,9 @@ from typing import List import eth2spec.eip7594.mainnet -from eth2spec.eip7594.mainnet import BLSFieldElement, compute_roots_of_unity from py_ecc.bls.typing import G1Uncompressed, G2Uncompressed -from remerkleable.basic import uint64 + +from da.kzg_rs.roots import compute_roots_of_unity from da.kzg_rs.trusted_setup import generate_setup G1 = G1Uncompressed @@ -11,9 +11,9 @@ G2 = G2Uncompressed BYTES_PER_FIELD_ELEMENT = 32 +BLS_MODULUS = eth2spec.eip7594.mainnet.BLS_MODULUS GLOBAL_PARAMETERS: List[G1] GLOBAL_PARAMETERS_G2: List[G2] # secret is fixed but this should come from a different synchronization protocol GLOBAL_PARAMETERS, GLOBAL_PARAMETERS_G2 = map(list, generate_setup(1024, 8, 1987)) -ROOTS_OF_UNITY: List[BLSFieldElement] = list(compute_roots_of_unity(uint64(4096))) -BLS_MODULUS = eth2spec.eip7594.mainnet.BLS_MODULUS +ROOTS_OF_UNITY: List[int] = compute_roots_of_unity(2, BLS_MODULUS, 4096) diff --git a/da/kzg_rs/poly.py b/da/kzg_rs/poly.py index 82a6292..2844106 100644 --- a/da/kzg_rs/poly.py +++ b/da/kzg_rs/poly.py @@ -76,6 +76,9 @@ class Polynomial[T]: def __getitem__(self, item): return self.coefficients[item] + def __eq__(self, other): + return self.coefficients == other.coefficients and self.modulus == other.modulus + def eval(self, element): return sum( (pow(element, i)*x) % self.modulus for i, x in enumerate(self.coefficients) diff --git a/da/kzg_rs/roots.py b/da/kzg_rs/roots.py new file mode 100644 index 0000000..ec5988b --- /dev/null +++ b/da/kzg_rs/roots.py @@ -0,0 +1,14 @@ +def compute_roots_of_unity(primitive_root, p, n): + """ + Compute the roots of unity modulo p. + + Parameters: + primitive_root (int): Primitive root modulo p. + p (int): Modulus. + n (int): Number of roots of unity to compute. + + Returns: + list: List of roots of unity modulo p. + """ + roots_of_unity = [pow(primitive_root, i, p) for i in range(n)] + return roots_of_unity diff --git a/da/kzg_rs/rs.py b/da/kzg_rs/rs.py new file mode 100644 index 0000000..ac600ff --- /dev/null +++ b/da/kzg_rs/rs.py @@ -0,0 +1,55 @@ +from typing import Sequence, List + +import scipy.interpolate +from eth2spec.deneb.mainnet import BLSFieldElement +from eth2spec.eip7594.mainnet import interpolate_polynomialcoeff +from .common import G1, BLS_MODULUS +from .poly import Polynomial + +ExtendedData = Sequence[BLSFieldElement] + + +def encode(polynomial: Polynomial, factor: int, roots_of_unity: Sequence[BLSFieldElement]) -> ExtendedData: + """ + Encode a polynomial extending to the given factor + Parameters: + polynomial: Polynomial to be encoded + factor: Encoding factor + roots_of_unity: Powers of 2 sequence + + Returns: + list: Extended data set + """ + assert factor >= 2 + assert len(polynomial)*factor <= len(roots_of_unity) + return [polynomial.eval(e) for e in roots_of_unity[:len(polynomial)*factor]] + + +def __interpolate(evaluations: List[int], roots_of_unity: List[int]) -> List[int]: + """ + Lagrange interpolation + + Parameters: + evaluations: List of evaluations + roots_of_unity: Powers of 2 sequence + + Returns: + list: Coefficients of the interpolated polynomial + """ + return list(map(int, interpolate_polynomialcoeff(roots_of_unity[:len(evaluations)], evaluations))) + + +def decode(encoded: ExtendedData, roots_of_unity: Sequence[BLSFieldElement], original_len: int) -> Polynomial: + """ + Decode a polynomial from an extended data-set and the roots of unity, cap to original length + + Parameters: + encoded: Extended data set + roots_of_unity: Powers of 2 sequence + original_len: Original length of the encoded polynomial + + Returns: + Polynomial: original polynomial + """ + coefs = __interpolate(list(map(int, encoded)), list(map(int, roots_of_unity)))[:original_len] + return Polynomial([int(c) for c in coefs], BLS_MODULUS) diff --git a/da/test_kzg.py b/da/kzg_rs/test_kzg.py similarity index 100% rename from da/test_kzg.py rename to da/kzg_rs/test_kzg.py diff --git a/da/kzg_rs/test_rs.py b/da/kzg_rs/test_rs.py new file mode 100644 index 0000000..2e41550 --- /dev/null +++ b/da/kzg_rs/test_rs.py @@ -0,0 +1,15 @@ +from unittest import TestCase + +from da.kzg_rs.common import BLS_MODULUS, ROOTS_OF_UNITY +from da.kzg_rs.poly import Polynomial +from da.kzg_rs.rs import encode, decode + + +class TestFFT(TestCase): + def test_encode_decode(self): + poly = Polynomial(list(range(10)), modulus=BLS_MODULUS) + encoded = encode(poly, 2, ROOTS_OF_UNITY) + decoded = decode(encoded, ROOTS_OF_UNITY, len(poly)) + self.assertEqual(poly, decoded) + for i in range(len(poly)): + self.assertEqual(poly.eval(ROOTS_OF_UNITY[i]), decoded.eval(ROOTS_OF_UNITY[i]))