Use lagrange for interpolation
This commit is contained in:
parent
d75eb2949f
commit
f6b7d5bd3e
|
@ -1,9 +1,11 @@
|
|||
from typing import List
|
||||
|
||||
import eth2spec.eip7594.mainnet
|
||||
from eth2spec.eip7594.mainnet import BLSFieldElement
|
||||
from py_ecc.bls.typing import G1Uncompressed, G2Uncompressed
|
||||
from remerkleable.basic import uint64
|
||||
|
||||
from da.kzg_rs.roots import compute_roots_of_unity
|
||||
from da.kzg_rs.fft import compute_roots_of_unity, compute_inverse_roots_of_unity
|
||||
from da.kzg_rs.trusted_setup import generate_setup
|
||||
|
||||
G1 = G1Uncompressed
|
||||
|
@ -17,3 +19,4 @@ GLOBAL_PARAMETERS_G2: List[G2]
|
|||
# secret is fixed but this should come from a different synchronization protocol
|
||||
GLOBAL_PARAMETERS, GLOBAL_PARAMETERS_G2 = map(list, generate_setup(1024, 8, 1987))
|
||||
ROOTS_OF_UNITY: List[int] = compute_roots_of_unity(2, BLS_MODULUS, 4096)
|
||||
INVERSE_ROOTS_OF_UNITY: List[int] = compute_inverse_roots_of_unity(2, BLS_MODULUS, 4096)
|
||||
|
|
134
da/kzg_rs/fft.py
134
da/kzg_rs/fft.py
|
@ -1,24 +1,122 @@
|
|||
from da.kzg_rs.common import BLS_MODULUS
|
||||
# def __fft(vals, modulus, roots_of_unity):
|
||||
# if len(vals) == 1:
|
||||
# return vals
|
||||
# left = __fft(vals[::2], modulus, roots_of_unity[::2])
|
||||
# right = __fft(vals[1::2], modulus, roots_of_unity[::2])
|
||||
# o = [0 for _ in vals]
|
||||
# for i, (x, y) in enumerate(zip(left, right)):
|
||||
# y_times_root = y*int(roots_of_unity[i]) % modulus
|
||||
# o[i] = (x+y_times_root) % modulus
|
||||
# o[i+len(left)] = (x+modulus-y_times_root) % modulus
|
||||
# return o
|
||||
#
|
||||
#
|
||||
# def fft(vals, modulus, roots_of_unity):
|
||||
# return __fft(vals, modulus, roots_of_unity)
|
||||
#
|
||||
#
|
||||
# def ifft(vals, modulus, factor, roots_of_unity):
|
||||
# # Inverse FFT
|
||||
# invlen = pow(len(vals), modulus - factor, modulus)
|
||||
# return [(x * invlen) % modulus for x in __fft(vals, modulus, roots_of_unity[:0:-1])]
|
||||
|
||||
import math
|
||||
|
||||
|
||||
def __fft(vals, modulus, roots_of_unity):
|
||||
if len(vals) == 1:
|
||||
return vals
|
||||
left = __fft(vals[::2], modulus, roots_of_unity[::2])
|
||||
right = __fft(vals[1::2], modulus, roots_of_unity[::2])
|
||||
o = [0 for _ in vals]
|
||||
for i, (x, y) in enumerate(zip(left, right)):
|
||||
y_times_root = y*int(roots_of_unity[i]) % modulus
|
||||
o[i] = (x+y_times_root) % modulus
|
||||
o[i+len(left)] = (x+modulus-y_times_root) % modulus
|
||||
return o
|
||||
def fft(x, p, roots_of_unity):
|
||||
"""
|
||||
Compute the FFT of a sequence x modulo p using precomputed roots of unity.
|
||||
|
||||
Parameters:
|
||||
x (list): Sequence of integers.
|
||||
p (int): Modulus.
|
||||
roots_of_unity (list): List of precomputed roots of unity modulo p.
|
||||
|
||||
Returns:
|
||||
list: FFT of the sequence x.
|
||||
"""
|
||||
N = len(x)
|
||||
if N == 1:
|
||||
return x
|
||||
even = fft(x[0::2], p, roots_of_unity)
|
||||
odd = fft(x[1::2], p, roots_of_unity)
|
||||
factor = 1
|
||||
result = [0] * N
|
||||
for i in range(N // 2):
|
||||
result[i] = (even[i] + factor * odd[i]) % p
|
||||
result[i + N // 2] = (even[i] - factor * odd[i]) % p
|
||||
factor = (factor * roots_of_unity[i]) % p
|
||||
return result
|
||||
|
||||
|
||||
def fft(vals, modulus, roots_of_unity):
|
||||
return __fft(vals, modulus, roots_of_unity)
|
||||
def ifft(y, p, inverse_roots_of_unity):
|
||||
"""
|
||||
Compute the inverse FFT of a sequence y modulo p using precomputed inverse roots of unity.
|
||||
|
||||
Parameters:
|
||||
y (list): Sequence of integers.
|
||||
p (int): Modulus.
|
||||
inverse_roots_of_unity (list): List of precomputed inverse roots of unity modulo p.
|
||||
|
||||
Returns:
|
||||
list: Inverse FFT of the sequence y.
|
||||
"""
|
||||
N = len(y)
|
||||
if N == 1:
|
||||
return y
|
||||
even = ifft(y[0::2], p, inverse_roots_of_unity)
|
||||
odd = ifft(y[1::2], p, inverse_roots_of_unity)
|
||||
factor = 1
|
||||
result = [0] * N
|
||||
for i in range(N // 2):
|
||||
result[i] = (even[i] + factor * odd[i]) % p
|
||||
result[i + N // 2] = (even[i] - factor * odd[i]) % p
|
||||
factor = (factor * inverse_roots_of_unity[i]) % p
|
||||
return result
|
||||
|
||||
|
||||
def ifft(vals, modulus, factor, roots_of_unity):
|
||||
# Inverse FFT
|
||||
invlen = pow(len(vals), modulus - factor, modulus)
|
||||
return [(x * invlen) % modulus for x in __fft(vals, modulus, roots_of_unity[:0:-1])]
|
||||
def find_inverse_primitive_root(primitive_root, p):
|
||||
"""
|
||||
Find the inverse primitive root modulo p.
|
||||
|
||||
Parameters:
|
||||
primitive_root (int): Primitive root modulo p.
|
||||
p (int): Modulus.
|
||||
|
||||
Returns:
|
||||
int: Inverse primitive root modulo p.
|
||||
"""
|
||||
return pow(primitive_root, p - 2, p)
|
||||
|
||||
|
||||
def compute_roots_of_unity(primitive_root, p, n):
|
||||
"""
|
||||
Compute the roots of unity modulo p.
|
||||
|
||||
Parameters:
|
||||
primitive_root (int): Primitive root modulo p.
|
||||
p (int): Modulus.
|
||||
n (int): Number of roots of unity to compute.
|
||||
|
||||
Returns:
|
||||
list: List of roots of unity modulo p.
|
||||
"""
|
||||
roots_of_unity = [pow(primitive_root, i, p) for i in range(n)]
|
||||
return roots_of_unity
|
||||
|
||||
|
||||
def compute_inverse_roots_of_unity(primitive_root, p, n):
|
||||
"""
|
||||
Compute the inverse roots of unity modulo p.
|
||||
|
||||
Parameters:
|
||||
primitive_root (int): Primitive root modulo p.
|
||||
p (int): Modulus.
|
||||
n (int): Number of roots of unity to compute.
|
||||
|
||||
Returns:
|
||||
list: List of inverse roots of unity modulo p.
|
||||
"""
|
||||
inverse_primitive_root = find_inverse_primitive_root(primitive_root, p)
|
||||
inverse_roots_of_unity = [pow(inverse_primitive_root, i, p) for i in range(n)]
|
||||
return inverse_roots_of_unity
|
||||
|
|
|
@ -1,21 +1,36 @@
|
|||
from typing import Sequence
|
||||
from typing import Sequence, List
|
||||
|
||||
import scipy.interpolate
|
||||
from eth2spec.deneb.mainnet import BLSFieldElement
|
||||
|
||||
from .common import G1
|
||||
from .common import G1, BLS_MODULUS
|
||||
from .poly import Polynomial
|
||||
from .fft import fft, ifft
|
||||
|
||||
ExtendedData = Sequence[BLSFieldElement]
|
||||
|
||||
def encode(polynomial: Polynomial, factor: int, roots_of_unity: Sequence[BLSFieldElement]) -> Polynomial:
|
||||
|
||||
def encode(polynomial: Polynomial, factor: int, roots_of_unity: Sequence[BLSFieldElement]) -> ExtendedData:
|
||||
assert factor >= 2
|
||||
assert len(polynomial)*factor <= len(roots_of_unity)
|
||||
extended_polynomial_coefficients = polynomial.coefficients + [0]*(len(polynomial)*factor-1)
|
||||
extended_polynomial_coefficients = fft(extended_polynomial_coefficients, polynomial.modulus, roots_of_unity)
|
||||
return Polynomial(extended_polynomial_coefficients, modulus=polynomial.modulus)
|
||||
extended_polynomial_evaluations = polynomial.coefficients + [0]*(len(polynomial)*factor-1)
|
||||
extended_polynomial_evaluations = [
|
||||
BLSFieldElement(e % polynomial.modulus)
|
||||
for e in fft(extended_polynomial_evaluations, polynomial.modulus, roots_of_unity)
|
||||
]
|
||||
return extended_polynomial_evaluations
|
||||
|
||||
|
||||
def decode(polynomial: Polynomial, factor: int, roots_of_unity: Sequence[BLSFieldElement]) -> Polynomial:
|
||||
coefficients = ifft(polynomial.coefficients, polynomial.modulus, factor, roots_of_unity)
|
||||
return Polynomial(coefficients=coefficients, modulus=polynomial.modulus)
|
||||
def __interpolate(evaluations: List[int], roots_of_unity: List[int], modulus=BLS_MODULUS) -> List[int]:
|
||||
"""
|
||||
Lagrange interpolation
|
||||
"""
|
||||
assert len(evaluations) <= len(roots_of_unity)
|
||||
coefs = scipy.interpolate.lagrange(roots_of_unity[:len(evaluations)], evaluations).coef
|
||||
return [coef % modulus for coef in coefs]
|
||||
|
||||
|
||||
def decode(encoded: ExtendedData, roots_of_unity: Sequence[BLSFieldElement], original_len: int) -> Polynomial:
|
||||
coefs = __interpolate(list(map(int, encoded)), list(map(int, roots_of_unity)))
|
||||
return Polynomial([int(c) for c in coefs], BLS_MODULUS)
|
||||
|
||||
|
|
|
@ -1,14 +1,17 @@
|
|||
from unittest import TestCase
|
||||
|
||||
from da.kzg_rs.common import BLS_MODULUS, ROOTS_OF_UNITY
|
||||
from da.kzg_rs.common import BLS_MODULUS, ROOTS_OF_UNITY, INVERSE_ROOTS_OF_UNITY
|
||||
from da.kzg_rs.fft import fft, ifft
|
||||
from da.kzg_rs.poly import Polynomial
|
||||
from da.kzg_rs.rs import encode, decode
|
||||
|
||||
|
||||
class TestFFT(TestCase):
|
||||
|
||||
def test_encode_decode(self):
|
||||
poly = Polynomial(list(range(10)), modulus=BLS_MODULUS)
|
||||
encoded = encode(poly, 2, ROOTS_OF_UNITY)
|
||||
decoded = decode(encoded, 2, ROOTS_OF_UNITY)
|
||||
decoded = decode(encoded, ROOTS_OF_UNITY, len(poly))
|
||||
self.assertEqual(poly, decoded)
|
||||
for i in range(len(poly)):
|
||||
self.assertEqual(poly.eval(ROOTS_OF_UNITY[i]), decoded.eval(int(ROOTS_OF_UNITY[i])))
|
||||
self.assertEqual(poly.eval(ROOTS_OF_UNITY[i]), decoded.eval(ROOTS_OF_UNITY[i]))
|
||||
|
|
Loading…
Reference in New Issue