diff --git a/da/kzg_rs/kzg.py b/da/kzg_rs/kzg.py index c84c5cf..b476a4e 100644 --- a/da/kzg_rs/kzg.py +++ b/da/kzg_rs/kzg.py @@ -15,7 +15,7 @@ def bytes_to_polynomial(b: bytearray) -> Polynomial: Convert bytes to list of BLS field scalars. """ assert len(b) % BYTES_PER_FIELD_ELEMENT == 0 - return Polynomial([int(bytes_to_bls_field(b)) for b in batched(b, int(BYTES_PER_FIELD_ELEMENT))]) + return Polynomial([int(bytes_to_bls_field(b)) for b in batched(b, int(BYTES_PER_FIELD_ELEMENT))], BLS_MODULUS) def g1_linear_combination(polynomial: Polynomial[BLSFieldElement], global_parameters: Sequence[G1]) -> Commitment: @@ -27,7 +27,7 @@ def g1_linear_combination(polynomial: Polynomial[BLSFieldElement], global_parame assert len(polynomial) <= len(global_parameters) point = reduce( bls.add, - (bls.multiply(g, p) for g, p in zip(global_parameters, polynomial.coef)), + (bls.multiply(g, p) for g, p in zip(global_parameters, polynomial)), bls.Z1() ) return Commitment(bls.G1_to_bytes48(point)) @@ -44,8 +44,7 @@ def generate_element_proof( global_parameters: Sequence[G1] ) -> Proof: # compute a witness polynomial in that satisfies `witness(x) = (f(x)-v)/(x-u)` - f_x_v = polynomial - Polynomial([polynomial.eval(int(element)) % BLS_MODULUS]) - x_u = Polynomial([-element, BLSFieldElement(1)]) - witness = f_x_v // x_u - witness = Polynomial(list(BLSFieldElement(int(x) % BLS_MODULUS) for x in reversed(witness) if x != inf)) + f_x_v = polynomial - Polynomial([polynomial.eval(int(element)) % BLS_MODULUS], BLS_MODULUS) + x_u = Polynomial([-element, BLSFieldElement(1)], BLS_MODULUS) + witness, _ = f_x_v / x_u return g1_linear_combination(witness, global_parameters) diff --git a/da/kzg_rs/poly.py b/da/kzg_rs/poly.py index 30b6aaa..c67c17c 100644 --- a/da/kzg_rs/poly.py +++ b/da/kzg_rs/poly.py @@ -1,22 +1,69 @@ -from typing import Self, List -from eth2spec.eip7594.mainnet import BLS_MODULUS -import numpy as np -from sympy import ntt, intt +class Polynomial[T]: + def __init__(self, coefficients, modulus): + self.coefficients = coefficients + self.modulus = modulus + def __repr__(self): + return "Polynomial({}, modulus={})".format(self.coefficients, self.modulus) -class Polynomial[T](np.polynomial.Polynomial): - def __init__(self, coef, domain=None, window=None, symbol="x"): - self.coef = coef - super().__init__(coef, domain, window, symbol) + def __add__(self, other): + return Polynomial( + [(a + b) % self.modulus for a, b in zip(self.coefficients, other.coefficients)], + self.modulus + ) - def eval(self, x: T) -> T: - return np.polyval(self, x) + def __sub__(self, other): + return Polynomial( + [(a - b) % self.modulus for a, b in zip(self.coefficients, other.coefficients)], + self.modulus + ) - def evaluation_form(self, modulus=BLS_MODULUS) -> Self: - return Polynomial(intt(reversed(self), prime=modulus)) + def __mul__(self, other): + result = [0] * (len(self.coefficients) + len(other.coefficients) - 1) + for i in range(len(self.coefficients)): + for j in range(len(other.coefficients)): + result[i + j] = (result[i + j] + self.coefficients[i] * other.coefficients[j]) % self.modulus + return Polynomial(result, self.modulus) - # def __truediv__(self, other): - # return Polynomial(list(reversed(np.polydiv(list(reversed(self.coef)), list(reversed(other.coef)))))) + def divide(self, other): + if not isinstance(other, Polynomial): + raise ValueError("Unsupported type for division.") + + dividend = list(self.coefficients) + divisor = list(other.coefficients) + + quotient = [] + remainder = dividend + + while len(remainder) >= len(divisor): + factor = remainder[-1] * pow(divisor[-1], -1, self.modulus) % self.modulus + quotient.insert(0, factor) + + # Subtract divisor * factor from remainder + for i in range(len(divisor)): + remainder[len(remainder) - len(divisor) + i] -= divisor[i] * factor + remainder[len(remainder) - len(divisor) + i] %= self.modulus + + # Remove leading zeros from remainder + while remainder and remainder[-1] == 0: + remainder.pop() + + return Polynomial(quotient, self.modulus), Polynomial(remainder, self.modulus) + + def __truediv__(self, other): + return self.divide(other) + + def __neg__(self): + return Polynomial([-1 * c for c in self.coefficients], self.modulus) + + def __len__(self): + return len(self.coefficients) + + def __iter__(self): + return iter(self.coefficients) def __getitem__(self, item): - return self.coef[item] + return self.coefficients[item] + + def eval(self, element): + return sum(pow(x, i)*element for i, x in enumerate(self.coefficients)) diff --git a/da/test_kzg.py b/da/test_kzg.py index 7a0be8b..acd4554 100644 --- a/da/test_kzg.py +++ b/da/test_kzg.py @@ -28,4 +28,5 @@ class TestKZG(TestCase): rand_bytes = self.rand_bytes(32) commit = kzg.bytes_to_commitment(rand_bytes, GLOBAL_PARAMETERS) poly = kzg.bytes_to_polynomial(rand_bytes) - proof = kzg.generate_element_proof(poly[0], poly, GLOBAL_PARAMETERS) \ No newline at end of file + proof = kzg.generate_element_proof(poly[0], poly, GLOBAL_PARAMETERS) + self.assertEqual(len(proof), 48) \ No newline at end of file