Fix fft helper and add basic test

This commit is contained in:
Hsiao-Wei Wang 2024-01-05 20:31:09 +08:00
parent 7f8858b2ac
commit 03583b1b47
No known key found for this signature in database
GPG Key ID: AE3D6B174F971DE4
7 changed files with 55 additions and 11 deletions

View File

@ -110,28 +110,31 @@ def g2_lincomb(points: Sequence[KZGCommitment], scalars: Sequence[BLSFieldElemen
#### `_fft_field`
```python
def _fft_field(vals, roots_of_unity):
if len(vals) == 0:
def _fft_field(vals: Sequence[BLSFieldElement],
roots_of_unity: Sequence[BLSFieldElement]) -> Sequence[BLSFieldElement]:
if len(vals) == 1:
return vals
L = _fft_field(vals[::2], roots_of_unity[::2])
R = _fft_field(vals[1::2], roots_of_unity[::2])
o = [0 for i in vals]
o = [BLSFieldElement(0) for _ in vals]
for i, (x, y) in enumerate(zip(L, R)):
y_times_root = int(y) * int(roots_of_unity[i]) % BLS_MODULUS
o[i] = (x + y_times_root) % BLS_MODULUS
o[i + len(L)] = (x - y_times_root + BLS_MODULUS) % BLS_MODULUS
y_times_root = (int(y) * int(roots_of_unity[i])) % BLS_MODULUS
o[i] = BLSFieldElement((int(x) + y_times_root) % BLS_MODULUS)
o[i + len(L)] = BLSFieldElement((int(x) - y_times_root + BLS_MODULUS) % BLS_MODULUS)
return o
```
#### `fft_field`
```python
def fft_field(vals, roots_of_unity, inv=False):
def fft_field(vals: Sequence[BLSFieldElement],
roots_of_unity: Sequence[BLSFieldElement],
inv: bool=False) -> Sequence[BLSFieldElement]:
if inv:
# Inverse FFT
invlen = pow(len(vals), BLS_MODULUS - 2, BLS_MODULUS)
return [(x * invlen) % BLS_MODULUS for x in
_fft_field(vals, roots_of_unity[0:1] + roots_of_unity[:0:-1])]
return [BLSFieldElement((int(x) * invlen) % BLS_MODULUS)
for x in _fft_field(vals, roots_of_unity[0:1] + roots_of_unity[:0:-1])]
else:
# Regular FFT
return _fft_field(vals, roots_of_unity)
@ -272,7 +275,7 @@ def evaluate_polynomialcoeff(polynomial_coeff: PolynomialCoeff, z: BLSFieldEleme
"""
y = 0
for coef in polynomial_coeff[::-1]:
y = (int(y) * int(z) + coef) % BLS_MODULUS
y = (int(y) * int(z) + int(coef)) % BLS_MODULUS
return BLSFieldElement(y % BLS_MODULUS)
```

View File

@ -8,7 +8,7 @@ from eth2spec.utils import bls
from .exceptions import SkippedTest
from .helpers.constants import (
PHASE0, ALTAIR, BELLATRIX, CAPELLA, DENEB,
EIP6110, EIP7002,
EIP6110, EIP7002, PEERDAS,
WHISK,
MINIMAL,
ALL_PHASES,
@ -510,6 +510,7 @@ with_deneb_and_later = with_all_phases_from(DENEB)
with_eip6110_and_later = with_all_phases_from(EIP6110)
with_eip7002_and_later = with_all_phases_from(EIP7002)
with_whisk_and_later = with_all_phases_from(WHISK, all_phases=ALLOWED_TEST_RUNNER_FORKS)
with_peerdas_and_later = with_all_phases_from(PEERDAS, all_phases=ALLOWED_TEST_RUNNER_FORKS)
class quoted_str(str):

View File

@ -0,0 +1,30 @@
from eth2spec.test.context import (
spec_test,
single_phase,
with_peerdas_and_later,
)
from eth2spec.test.helpers.sharding import (
get_sample_blob,
)
@with_peerdas_and_later
@spec_test
@single_phase
def test_fft(spec):
vals = [int.from_bytes(x, spec.KZG_ENDIANNESS) for x in spec.KZG_SETUP_G1_MONOMIAL]
roots_of_unity = spec.ROOTS_OF_UNITY
result = spec.fft_field(vals, roots_of_unity)
assert len(result) == len(vals)
# TODO: add more assertions?
@with_peerdas_and_later
@spec_test
@single_phase
def test_compute_and_verify_cells_and_proofs(spec):
blob = get_sample_blob(spec)
commitment = spec.blob_to_kzg_commitment(blob)
cells, proofs = spec.compute_cells_and_proofs(blob)
cell_id = 0
assert spec.verify_cell_proof(commitment, cell_id, cells[cell_id], proofs[cell_id])

View File

@ -4,6 +4,7 @@ from py_ecc.optimized_bls12_381 import ( # noqa: F401
G1 as py_ecc_G1,
G2 as py_ecc_G2,
Z1 as py_ecc_Z1,
Z2 as py_ecc_Z2,
add as py_ecc_add,
multiply as py_ecc_mul,
neg as py_ecc_neg,
@ -243,6 +244,15 @@ def Z1():
return py_ecc_Z1
def Z2():
"""
Returns the identity point in G2
"""
if bls == arkworks_bls or bls == fastest_bls:
return arkworks_G2.identity()
return py_ecc_Z2
def G1():
"""
Returns the chosen generator point in G1