diff --git a/da/kzg_rs/poly.py b/da/kzg_rs/poly.py index 7dfd8d2..524dacc 100644 --- a/da/kzg_rs/poly.py +++ b/da/kzg_rs/poly.py @@ -1,7 +1,9 @@ from itertools import zip_longest from typing import List, Sequence, Self -from sympy import ntt, intt +from eth2spec.eip7594.mainnet import interpolate_polynomialcoeff + +from da.kzg_rs.common import ROOTS_OF_UNITY class Polynomial[T]: @@ -10,8 +12,11 @@ class Polynomial[T]: self.modulus = modulus @classmethod - def from_evaluations(cls, evaluations: Sequence[T], modulus) -> Self: - coefficients = intt(evaluations, prime=modulus) + def from_evaluations(cls, evaluations: Sequence[T], modulus, roots_of_unity: Sequence[int]=ROOTS_OF_UNITY) -> Self: + coefficients = [ + x % modulus + for x in map(int, interpolate_polynomialcoeff(ROOTS_OF_UNITY[:len(evaluations)], evaluations)) + ] return cls(coefficients, modulus) def __repr__(self): @@ -77,12 +82,16 @@ class Polynomial[T]: return self.coefficients[item] def __eq__(self, other): - return self.coefficients == other.coefficients and self.modulus == other.modulus + return ( + self.coefficients == other.coefficients and + self.modulus == other.modulus + ) - def eval(self, element): - return sum( - (pow(element, i)*x) % self.modulus for i, x in enumerate(self.coefficients) - ) % self.modulus + def eval(self, x): + return (self.coefficients[0] + sum( + (pow(x, i, mod=self.modulus)*coefficient) + for i, coefficient in enumerate(self.coefficients[1:], start=1) + )) % self.modulus def evaluation_form(self) -> List[T]: - return ntt(self.coefficients, prime=self.modulus) + return [self.eval(ROOTS_OF_UNITY[i]) for i in range(len(self))] \ No newline at end of file