From f6b7d5bd3e4c705a7c6678a77f67bb16adfccdc5 Mon Sep 17 00:00:00 2001 From: Daniel Sanchez Quiros Date: Mon, 4 Mar 2024 16:55:07 +0100 Subject: [PATCH] Use lagrange for interpolation --- da/kzg_rs/common.py | 5 +- da/kzg_rs/fft.py | 134 ++++++++++++++++++++++++++++++++++++------ da/kzg_rs/rs.py | 33 ++++++++--- da/kzg_rs/test_fft.py | 9 ++- 4 files changed, 150 insertions(+), 31 deletions(-) diff --git a/da/kzg_rs/common.py b/da/kzg_rs/common.py index 9eee3c2..bd28c8e 100644 --- a/da/kzg_rs/common.py +++ b/da/kzg_rs/common.py @@ -1,9 +1,11 @@ from typing import List import eth2spec.eip7594.mainnet +from eth2spec.eip7594.mainnet import BLSFieldElement from py_ecc.bls.typing import G1Uncompressed, G2Uncompressed +from remerkleable.basic import uint64 -from da.kzg_rs.roots import compute_roots_of_unity +from da.kzg_rs.fft import compute_roots_of_unity, compute_inverse_roots_of_unity from da.kzg_rs.trusted_setup import generate_setup G1 = G1Uncompressed @@ -17,3 +19,4 @@ GLOBAL_PARAMETERS_G2: List[G2] # secret is fixed but this should come from a different synchronization protocol GLOBAL_PARAMETERS, GLOBAL_PARAMETERS_G2 = map(list, generate_setup(1024, 8, 1987)) ROOTS_OF_UNITY: List[int] = compute_roots_of_unity(2, BLS_MODULUS, 4096) +INVERSE_ROOTS_OF_UNITY: List[int] = compute_inverse_roots_of_unity(2, BLS_MODULUS, 4096) diff --git a/da/kzg_rs/fft.py b/da/kzg_rs/fft.py index 2760279..3ca29f8 100644 --- a/da/kzg_rs/fft.py +++ b/da/kzg_rs/fft.py @@ -1,24 +1,122 @@ -from da.kzg_rs.common import BLS_MODULUS +# def __fft(vals, modulus, roots_of_unity): +# if len(vals) == 1: +# return vals +# left = __fft(vals[::2], modulus, roots_of_unity[::2]) +# right = __fft(vals[1::2], modulus, roots_of_unity[::2]) +# o = [0 for _ in vals] +# for i, (x, y) in enumerate(zip(left, right)): +# y_times_root = y*int(roots_of_unity[i]) % modulus +# o[i] = (x+y_times_root) % modulus +# o[i+len(left)] = (x+modulus-y_times_root) % modulus +# return o +# +# +# def fft(vals, modulus, roots_of_unity): +# return __fft(vals, modulus, roots_of_unity) +# +# +# def ifft(vals, modulus, factor, roots_of_unity): +# # Inverse FFT +# invlen = pow(len(vals), modulus - factor, modulus) +# return [(x * invlen) % modulus for x in __fft(vals, modulus, roots_of_unity[:0:-1])] + +import math -def __fft(vals, modulus, roots_of_unity): - if len(vals) == 1: - return vals - left = __fft(vals[::2], modulus, roots_of_unity[::2]) - right = __fft(vals[1::2], modulus, roots_of_unity[::2]) - o = [0 for _ in vals] - for i, (x, y) in enumerate(zip(left, right)): - y_times_root = y*int(roots_of_unity[i]) % modulus - o[i] = (x+y_times_root) % modulus - o[i+len(left)] = (x+modulus-y_times_root) % modulus - return o +def fft(x, p, roots_of_unity): + """ + Compute the FFT of a sequence x modulo p using precomputed roots of unity. + + Parameters: + x (list): Sequence of integers. + p (int): Modulus. + roots_of_unity (list): List of precomputed roots of unity modulo p. + + Returns: + list: FFT of the sequence x. + """ + N = len(x) + if N == 1: + return x + even = fft(x[0::2], p, roots_of_unity) + odd = fft(x[1::2], p, roots_of_unity) + factor = 1 + result = [0] * N + for i in range(N // 2): + result[i] = (even[i] + factor * odd[i]) % p + result[i + N // 2] = (even[i] - factor * odd[i]) % p + factor = (factor * roots_of_unity[i]) % p + return result -def fft(vals, modulus, roots_of_unity): - return __fft(vals, modulus, roots_of_unity) +def ifft(y, p, inverse_roots_of_unity): + """ + Compute the inverse FFT of a sequence y modulo p using precomputed inverse roots of unity. + + Parameters: + y (list): Sequence of integers. + p (int): Modulus. + inverse_roots_of_unity (list): List of precomputed inverse roots of unity modulo p. + + Returns: + list: Inverse FFT of the sequence y. + """ + N = len(y) + if N == 1: + return y + even = ifft(y[0::2], p, inverse_roots_of_unity) + odd = ifft(y[1::2], p, inverse_roots_of_unity) + factor = 1 + result = [0] * N + for i in range(N // 2): + result[i] = (even[i] + factor * odd[i]) % p + result[i + N // 2] = (even[i] - factor * odd[i]) % p + factor = (factor * inverse_roots_of_unity[i]) % p + return result -def ifft(vals, modulus, factor, roots_of_unity): - # Inverse FFT - invlen = pow(len(vals), modulus - factor, modulus) - return [(x * invlen) % modulus for x in __fft(vals, modulus, roots_of_unity[:0:-1])] +def find_inverse_primitive_root(primitive_root, p): + """ + Find the inverse primitive root modulo p. + + Parameters: + primitive_root (int): Primitive root modulo p. + p (int): Modulus. + + Returns: + int: Inverse primitive root modulo p. + """ + return pow(primitive_root, p - 2, p) + + +def compute_roots_of_unity(primitive_root, p, n): + """ + Compute the roots of unity modulo p. + + Parameters: + primitive_root (int): Primitive root modulo p. + p (int): Modulus. + n (int): Number of roots of unity to compute. + + Returns: + list: List of roots of unity modulo p. + """ + roots_of_unity = [pow(primitive_root, i, p) for i in range(n)] + return roots_of_unity + + +def compute_inverse_roots_of_unity(primitive_root, p, n): + """ + Compute the inverse roots of unity modulo p. + + Parameters: + primitive_root (int): Primitive root modulo p. + p (int): Modulus. + n (int): Number of roots of unity to compute. + + Returns: + list: List of inverse roots of unity modulo p. + """ + inverse_primitive_root = find_inverse_primitive_root(primitive_root, p) + inverse_roots_of_unity = [pow(inverse_primitive_root, i, p) for i in range(n)] + return inverse_roots_of_unity diff --git a/da/kzg_rs/rs.py b/da/kzg_rs/rs.py index 8ff686a..6cbe615 100644 --- a/da/kzg_rs/rs.py +++ b/da/kzg_rs/rs.py @@ -1,21 +1,36 @@ -from typing import Sequence +from typing import Sequence, List +import scipy.interpolate from eth2spec.deneb.mainnet import BLSFieldElement -from .common import G1 +from .common import G1, BLS_MODULUS from .poly import Polynomial from .fft import fft, ifft +ExtendedData = Sequence[BLSFieldElement] -def encode(polynomial: Polynomial, factor: int, roots_of_unity: Sequence[BLSFieldElement]) -> Polynomial: + +def encode(polynomial: Polynomial, factor: int, roots_of_unity: Sequence[BLSFieldElement]) -> ExtendedData: assert factor >= 2 assert len(polynomial)*factor <= len(roots_of_unity) - extended_polynomial_coefficients = polynomial.coefficients + [0]*(len(polynomial)*factor-1) - extended_polynomial_coefficients = fft(extended_polynomial_coefficients, polynomial.modulus, roots_of_unity) - return Polynomial(extended_polynomial_coefficients, modulus=polynomial.modulus) + extended_polynomial_evaluations = polynomial.coefficients + [0]*(len(polynomial)*factor-1) + extended_polynomial_evaluations = [ + BLSFieldElement(e % polynomial.modulus) + for e in fft(extended_polynomial_evaluations, polynomial.modulus, roots_of_unity) + ] + return extended_polynomial_evaluations -def decode(polynomial: Polynomial, factor: int, roots_of_unity: Sequence[BLSFieldElement]) -> Polynomial: - coefficients = ifft(polynomial.coefficients, polynomial.modulus, factor, roots_of_unity) - return Polynomial(coefficients=coefficients, modulus=polynomial.modulus) +def __interpolate(evaluations: List[int], roots_of_unity: List[int], modulus=BLS_MODULUS) -> List[int]: + """ + Lagrange interpolation + """ + assert len(evaluations) <= len(roots_of_unity) + coefs = scipy.interpolate.lagrange(roots_of_unity[:len(evaluations)], evaluations).coef + return [coef % modulus for coef in coefs] + + +def decode(encoded: ExtendedData, roots_of_unity: Sequence[BLSFieldElement], original_len: int) -> Polynomial: + coefs = __interpolate(list(map(int, encoded)), list(map(int, roots_of_unity))) + return Polynomial([int(c) for c in coefs], BLS_MODULUS) diff --git a/da/kzg_rs/test_fft.py b/da/kzg_rs/test_fft.py index 7f85d23..5691fec 100644 --- a/da/kzg_rs/test_fft.py +++ b/da/kzg_rs/test_fft.py @@ -1,14 +1,17 @@ from unittest import TestCase -from da.kzg_rs.common import BLS_MODULUS, ROOTS_OF_UNITY +from da.kzg_rs.common import BLS_MODULUS, ROOTS_OF_UNITY, INVERSE_ROOTS_OF_UNITY +from da.kzg_rs.fft import fft, ifft from da.kzg_rs.poly import Polynomial from da.kzg_rs.rs import encode, decode class TestFFT(TestCase): + def test_encode_decode(self): poly = Polynomial(list(range(10)), modulus=BLS_MODULUS) encoded = encode(poly, 2, ROOTS_OF_UNITY) - decoded = decode(encoded, 2, ROOTS_OF_UNITY) + decoded = decode(encoded, ROOTS_OF_UNITY, len(poly)) + self.assertEqual(poly, decoded) for i in range(len(poly)): - self.assertEqual(poly.eval(ROOTS_OF_UNITY[i]), decoded.eval(int(ROOTS_OF_UNITY[i]))) + self.assertEqual(poly.eval(ROOTS_OF_UNITY[i]), decoded.eval(ROOTS_OF_UNITY[i]))