Fix polynomial from/to evaluations

This commit is contained in:
Daniel Sanchez Quiros 2024-03-06 16:02:36 +01:00
parent 7dbd5ef351
commit 0f27a16468
1 changed files with 18 additions and 9 deletions

View File

@ -1,7 +1,9 @@
from itertools import zip_longest
from typing import List, Sequence, Self
from sympy import ntt, intt
from eth2spec.eip7594.mainnet import interpolate_polynomialcoeff
from da.kzg_rs.common import ROOTS_OF_UNITY
class Polynomial[T]:
@ -10,8 +12,11 @@ class Polynomial[T]:
self.modulus = modulus
@classmethod
def from_evaluations(cls, evaluations: Sequence[T], modulus) -> Self:
coefficients = intt(evaluations, prime=modulus)
def from_evaluations(cls, evaluations: Sequence[T], modulus, roots_of_unity: Sequence[int]=ROOTS_OF_UNITY) -> Self:
coefficients = [
x % modulus
for x in map(int, interpolate_polynomialcoeff(ROOTS_OF_UNITY[:len(evaluations)], evaluations))
]
return cls(coefficients, modulus)
def __repr__(self):
@ -77,12 +82,16 @@ class Polynomial[T]:
return self.coefficients[item]
def __eq__(self, other):
return self.coefficients == other.coefficients and self.modulus == other.modulus
return (
self.coefficients == other.coefficients and
self.modulus == other.modulus
)
def eval(self, element):
return sum(
(pow(element, i)*x) % self.modulus for i, x in enumerate(self.coefficients)
) % self.modulus
def eval(self, x):
return (self.coefficients[0] + sum(
(pow(x, i, mod=self.modulus)*coefficient)
for i, coefficient in enumerate(self.coefficients[1:], start=1)
)) % self.modulus
def evaluation_form(self) -> List[T]:
return ntt(self.coefficients, prime=self.modulus)
return [self.eval(ROOTS_OF_UNITY[i]) for i in range(len(self))]