Fix polynomial from/to evaluations

This commit is contained in:
Daniel Sanchez Quiros 2024-03-06 16:02:36 +01:00
parent 7dbd5ef351
commit 0f27a16468
1 changed files with 18 additions and 9 deletions

View File

@ -1,7 +1,9 @@
from itertools import zip_longest from itertools import zip_longest
from typing import List, Sequence, Self from typing import List, Sequence, Self
from sympy import ntt, intt from eth2spec.eip7594.mainnet import interpolate_polynomialcoeff
from da.kzg_rs.common import ROOTS_OF_UNITY
class Polynomial[T]: class Polynomial[T]:
@ -10,8 +12,11 @@ class Polynomial[T]:
self.modulus = modulus self.modulus = modulus
@classmethod @classmethod
def from_evaluations(cls, evaluations: Sequence[T], modulus) -> Self: def from_evaluations(cls, evaluations: Sequence[T], modulus, roots_of_unity: Sequence[int]=ROOTS_OF_UNITY) -> Self:
coefficients = intt(evaluations, prime=modulus) coefficients = [
x % modulus
for x in map(int, interpolate_polynomialcoeff(ROOTS_OF_UNITY[:len(evaluations)], evaluations))
]
return cls(coefficients, modulus) return cls(coefficients, modulus)
def __repr__(self): def __repr__(self):
@ -77,12 +82,16 @@ class Polynomial[T]:
return self.coefficients[item] return self.coefficients[item]
def __eq__(self, other): def __eq__(self, other):
return self.coefficients == other.coefficients and self.modulus == other.modulus return (
self.coefficients == other.coefficients and
self.modulus == other.modulus
)
def eval(self, element): def eval(self, x):
return sum( return (self.coefficients[0] + sum(
(pow(element, i)*x) % self.modulus for i, x in enumerate(self.coefficients) (pow(x, i, mod=self.modulus)*coefficient)
) % self.modulus for i, coefficient in enumerate(self.coefficients[1:], start=1)
)) % self.modulus
def evaluation_form(self) -> List[T]: def evaluation_form(self) -> List[T]:
return ntt(self.coefficients, prime=self.modulus) return [self.eval(ROOTS_OF_UNITY[i]) for i in range(len(self))]