Randomize the test_recover_polynomial()

This commit is contained in:
George Kadianakis 2024-01-09 15:41:36 +02:00
parent 09c2519938
commit a58c86832a
2 changed files with 36 additions and 8 deletions

View File

@ -486,7 +486,7 @@ def recover_polynomial(cell_ids: Sequence[CellID], cells: Sequence[Cell]) -> Pol
extended_evaluation_rbo[start:end] = cell
extended_evaluation = bit_reversal_permutation(extended_evaluation_rbo)
extended_evaluation_times_zero = [BLSFieldElement(a * b % BLS_MODULUS)
extended_evaluation_times_zero = [BLSFieldElement(int(a) * int(b) % BLS_MODULUS)
for a, b in zip(zero_poly_eval, extended_evaluation)]
extended_evaluations_fft = fft_field(extended_evaluation_times_zero, ROOTS_OF_UNITY_EXTENDED, inv=True)

View File

@ -1,3 +1,4 @@
import random
from eth2spec.test.context import (
spec_test,
single_phase,
@ -17,6 +18,7 @@ def test_fft(spec):
result = spec.fft_field(vals, roots_of_unity)
assert len(result) == len(vals)
# TODO: add more assertions?
# One possible test would be to use polynomial_eval_to_coeff()
@with_peerdas_and_later
@ -53,11 +55,37 @@ def test_verify_cell_proof_batch(spec):
@spec_test
@single_phase
def test_recover_polynomial(spec):
blob = get_sample_blob(spec)
original_polynomial = spec.blob_to_polynomial(blob)
cells = spec.compute_cells(blob)
cell_ids = list(range(spec.CELLS_PER_BLOB // 2))
known_cells = [cells[cell_id] for cell_id in cell_ids]
result = spec.recover_polynomial(cell_ids, known_cells)
rng = random.Random(5566)
assert original_polynomial == result[0:len(result) // 2]
# Number of samples we will be recovering from
N_SAMPLES = spec.CELLS_PER_BLOB // 2
# Get the data we will be working with
blob = get_sample_blob(spec)
# Get the data in evaluation form
original_polynomial = spec.blob_to_polynomial(blob)
# Extend data with Reed-Solomon and split the extended data in cells
cells = spec.compute_cells(blob)
# Compute the cells we will be recovering from
cell_ids = []
known_cells = []
# First figure out just the indices of the cells
for i in range(N_SAMPLES):
j = rng.randint(0, spec.CELLS_PER_BLOB)
while j in cell_ids:
j = rng.randint(0, spec.CELLS_PER_BLOB)
cell_ids.append(j)
# Now the cells themselves
known_cells = [cells[cell_id] for cell_id in cell_ids]
# Recover the data
recovered_data = spec.recover_polynomial(cell_ids, known_cells)
# Check that the original data match the non-extended portion of the recovered data
assert original_polynomial == recovered_data[:len(recovered_data) // 2]
# Now flatten the cells and check that they match the entirety of the recovered data
flattened_cells = [x for xs in cells for x in xs]
assert flattened_cells == recovered_data