Use lagrange for interpolation

This commit is contained in:
Daniel Sanchez Quiros 2024-03-04 16:55:07 +01:00
parent d75eb2949f
commit f6b7d5bd3e
4 changed files with 150 additions and 31 deletions

View File

@ -1,9 +1,11 @@
from typing import List
import eth2spec.eip7594.mainnet
from eth2spec.eip7594.mainnet import BLSFieldElement
from py_ecc.bls.typing import G1Uncompressed, G2Uncompressed
from remerkleable.basic import uint64
from da.kzg_rs.roots import compute_roots_of_unity
from da.kzg_rs.fft import compute_roots_of_unity, compute_inverse_roots_of_unity
from da.kzg_rs.trusted_setup import generate_setup
G1 = G1Uncompressed
@ -17,3 +19,4 @@ GLOBAL_PARAMETERS_G2: List[G2]
# secret is fixed but this should come from a different synchronization protocol
GLOBAL_PARAMETERS, GLOBAL_PARAMETERS_G2 = map(list, generate_setup(1024, 8, 1987))
ROOTS_OF_UNITY: List[int] = compute_roots_of_unity(2, BLS_MODULUS, 4096)
INVERSE_ROOTS_OF_UNITY: List[int] = compute_inverse_roots_of_unity(2, BLS_MODULUS, 4096)

View File

@ -1,24 +1,122 @@
from da.kzg_rs.common import BLS_MODULUS
# def __fft(vals, modulus, roots_of_unity):
# if len(vals) == 1:
# return vals
# left = __fft(vals[::2], modulus, roots_of_unity[::2])
# right = __fft(vals[1::2], modulus, roots_of_unity[::2])
# o = [0 for _ in vals]
# for i, (x, y) in enumerate(zip(left, right)):
# y_times_root = y*int(roots_of_unity[i]) % modulus
# o[i] = (x+y_times_root) % modulus
# o[i+len(left)] = (x+modulus-y_times_root) % modulus
# return o
#
#
# def fft(vals, modulus, roots_of_unity):
# return __fft(vals, modulus, roots_of_unity)
#
#
# def ifft(vals, modulus, factor, roots_of_unity):
# # Inverse FFT
# invlen = pow(len(vals), modulus - factor, modulus)
# return [(x * invlen) % modulus for x in __fft(vals, modulus, roots_of_unity[:0:-1])]
import math
def __fft(vals, modulus, roots_of_unity):
if len(vals) == 1:
return vals
left = __fft(vals[::2], modulus, roots_of_unity[::2])
right = __fft(vals[1::2], modulus, roots_of_unity[::2])
o = [0 for _ in vals]
for i, (x, y) in enumerate(zip(left, right)):
y_times_root = y*int(roots_of_unity[i]) % modulus
o[i] = (x+y_times_root) % modulus
o[i+len(left)] = (x+modulus-y_times_root) % modulus
return o
def fft(x, p, roots_of_unity):
"""
Compute the FFT of a sequence x modulo p using precomputed roots of unity.
Parameters:
x (list): Sequence of integers.
p (int): Modulus.
roots_of_unity (list): List of precomputed roots of unity modulo p.
Returns:
list: FFT of the sequence x.
"""
N = len(x)
if N == 1:
return x
even = fft(x[0::2], p, roots_of_unity)
odd = fft(x[1::2], p, roots_of_unity)
factor = 1
result = [0] * N
for i in range(N // 2):
result[i] = (even[i] + factor * odd[i]) % p
result[i + N // 2] = (even[i] - factor * odd[i]) % p
factor = (factor * roots_of_unity[i]) % p
return result
def fft(vals, modulus, roots_of_unity):
return __fft(vals, modulus, roots_of_unity)
def ifft(y, p, inverse_roots_of_unity):
"""
Compute the inverse FFT of a sequence y modulo p using precomputed inverse roots of unity.
Parameters:
y (list): Sequence of integers.
p (int): Modulus.
inverse_roots_of_unity (list): List of precomputed inverse roots of unity modulo p.
Returns:
list: Inverse FFT of the sequence y.
"""
N = len(y)
if N == 1:
return y
even = ifft(y[0::2], p, inverse_roots_of_unity)
odd = ifft(y[1::2], p, inverse_roots_of_unity)
factor = 1
result = [0] * N
for i in range(N // 2):
result[i] = (even[i] + factor * odd[i]) % p
result[i + N // 2] = (even[i] - factor * odd[i]) % p
factor = (factor * inverse_roots_of_unity[i]) % p
return result
def ifft(vals, modulus, factor, roots_of_unity):
# Inverse FFT
invlen = pow(len(vals), modulus - factor, modulus)
return [(x * invlen) % modulus for x in __fft(vals, modulus, roots_of_unity[:0:-1])]
def find_inverse_primitive_root(primitive_root, p):
"""
Find the inverse primitive root modulo p.
Parameters:
primitive_root (int): Primitive root modulo p.
p (int): Modulus.
Returns:
int: Inverse primitive root modulo p.
"""
return pow(primitive_root, p - 2, p)
def compute_roots_of_unity(primitive_root, p, n):
"""
Compute the roots of unity modulo p.
Parameters:
primitive_root (int): Primitive root modulo p.
p (int): Modulus.
n (int): Number of roots of unity to compute.
Returns:
list: List of roots of unity modulo p.
"""
roots_of_unity = [pow(primitive_root, i, p) for i in range(n)]
return roots_of_unity
def compute_inverse_roots_of_unity(primitive_root, p, n):
"""
Compute the inverse roots of unity modulo p.
Parameters:
primitive_root (int): Primitive root modulo p.
p (int): Modulus.
n (int): Number of roots of unity to compute.
Returns:
list: List of inverse roots of unity modulo p.
"""
inverse_primitive_root = find_inverse_primitive_root(primitive_root, p)
inverse_roots_of_unity = [pow(inverse_primitive_root, i, p) for i in range(n)]
return inverse_roots_of_unity

View File

@ -1,21 +1,36 @@
from typing import Sequence
from typing import Sequence, List
import scipy.interpolate
from eth2spec.deneb.mainnet import BLSFieldElement
from .common import G1
from .common import G1, BLS_MODULUS
from .poly import Polynomial
from .fft import fft, ifft
ExtendedData = Sequence[BLSFieldElement]
def encode(polynomial: Polynomial, factor: int, roots_of_unity: Sequence[BLSFieldElement]) -> Polynomial:
def encode(polynomial: Polynomial, factor: int, roots_of_unity: Sequence[BLSFieldElement]) -> ExtendedData:
assert factor >= 2
assert len(polynomial)*factor <= len(roots_of_unity)
extended_polynomial_coefficients = polynomial.coefficients + [0]*(len(polynomial)*factor-1)
extended_polynomial_coefficients = fft(extended_polynomial_coefficients, polynomial.modulus, roots_of_unity)
return Polynomial(extended_polynomial_coefficients, modulus=polynomial.modulus)
extended_polynomial_evaluations = polynomial.coefficients + [0]*(len(polynomial)*factor-1)
extended_polynomial_evaluations = [
BLSFieldElement(e % polynomial.modulus)
for e in fft(extended_polynomial_evaluations, polynomial.modulus, roots_of_unity)
]
return extended_polynomial_evaluations
def decode(polynomial: Polynomial, factor: int, roots_of_unity: Sequence[BLSFieldElement]) -> Polynomial:
coefficients = ifft(polynomial.coefficients, polynomial.modulus, factor, roots_of_unity)
return Polynomial(coefficients=coefficients, modulus=polynomial.modulus)
def __interpolate(evaluations: List[int], roots_of_unity: List[int], modulus=BLS_MODULUS) -> List[int]:
"""
Lagrange interpolation
"""
assert len(evaluations) <= len(roots_of_unity)
coefs = scipy.interpolate.lagrange(roots_of_unity[:len(evaluations)], evaluations).coef
return [coef % modulus for coef in coefs]
def decode(encoded: ExtendedData, roots_of_unity: Sequence[BLSFieldElement], original_len: int) -> Polynomial:
coefs = __interpolate(list(map(int, encoded)), list(map(int, roots_of_unity)))
return Polynomial([int(c) for c in coefs], BLS_MODULUS)

View File

@ -1,14 +1,17 @@
from unittest import TestCase
from da.kzg_rs.common import BLS_MODULUS, ROOTS_OF_UNITY
from da.kzg_rs.common import BLS_MODULUS, ROOTS_OF_UNITY, INVERSE_ROOTS_OF_UNITY
from da.kzg_rs.fft import fft, ifft
from da.kzg_rs.poly import Polynomial
from da.kzg_rs.rs import encode, decode
class TestFFT(TestCase):
def test_encode_decode(self):
poly = Polynomial(list(range(10)), modulus=BLS_MODULUS)
encoded = encode(poly, 2, ROOTS_OF_UNITY)
decoded = decode(encoded, 2, ROOTS_OF_UNITY)
decoded = decode(encoded, ROOTS_OF_UNITY, len(poly))
self.assertEqual(poly, decoded)
for i in range(len(poly)):
self.assertEqual(poly.eval(ROOTS_OF_UNITY[i]), decoded.eval(int(ROOTS_OF_UNITY[i])))
self.assertEqual(poly.eval(ROOTS_OF_UNITY[i]), decoded.eval(ROOTS_OF_UNITY[i]))