Merge branch 'dev' into kw/optimize-compute-kzg-proof-multi

This commit is contained in:
Kevaundray Wedderburn 2024-04-19 12:04:15 +01:00
commit c2b7c0b414
4 changed files with 41 additions and 29 deletions

View File

@ -147,7 +147,7 @@ def recover_matrix(cells_dict: Dict[Tuple[BlobIndex, CellID], Cell], blob_count:
full_polynomial = recover_polynomial(cell_ids, cells_bytes) full_polynomial = recover_polynomial(cell_ids, cells_bytes)
cells_from_full_polynomial = [ cells_from_full_polynomial = [
full_polynomial[i * FIELD_ELEMENTS_PER_CELL:(i + 1) * FIELD_ELEMENTS_PER_CELL] full_polynomial[i * FIELD_ELEMENTS_PER_CELL:(i + 1) * FIELD_ELEMENTS_PER_CELL]
for i in range(CELLS_PER_BLOB) for i in range(CELLS_PER_EXT_BLOB)
] ]
extended_matrix.extend(cells_from_full_polynomial) extended_matrix.extend(cells_from_full_polynomial)
return ExtendedMatrix(extended_matrix) return ExtendedMatrix(extended_matrix)

View File

@ -84,7 +84,7 @@ Cells are the smallest unit of blob data that can come with their own KZG proofs
| `FIELD_ELEMENTS_PER_EXT_BLOB` | `2 * FIELD_ELEMENTS_PER_BLOB` | Number of field elements in a Reed-Solomon extended blob | | `FIELD_ELEMENTS_PER_EXT_BLOB` | `2 * FIELD_ELEMENTS_PER_BLOB` | Number of field elements in a Reed-Solomon extended blob |
| `FIELD_ELEMENTS_PER_CELL` | `uint64(64)` | Number of field elements in a cell | | `FIELD_ELEMENTS_PER_CELL` | `uint64(64)` | Number of field elements in a cell |
| `BYTES_PER_CELL` | `FIELD_ELEMENTS_PER_CELL * BYTES_PER_FIELD_ELEMENT` | The number of bytes in a cell | | `BYTES_PER_CELL` | `FIELD_ELEMENTS_PER_CELL * BYTES_PER_FIELD_ELEMENT` | The number of bytes in a cell |
| `CELLS_PER_BLOB` | `FIELD_ELEMENTS_PER_EXT_BLOB // FIELD_ELEMENTS_PER_CELL` | The number of cells in a blob | | `CELLS_PER_EXT_BLOB` | `FIELD_ELEMENTS_PER_EXT_BLOB // FIELD_ELEMENTS_PER_CELL` | The number of cells in an extended blob |
| `RANDOM_CHALLENGE_KZG_CELL_BATCH_DOMAIN` | `b'RCKZGCBATCH__V1_'` | | `RANDOM_CHALLENGE_KZG_CELL_BATCH_DOMAIN` | `b'RCKZGCBATCH__V1_'` |
## Helper functions ## Helper functions
@ -106,7 +106,7 @@ def bytes_to_cell(cell_bytes: Vector[Bytes32, FIELD_ELEMENTS_PER_CELL]) -> Cell:
#### `g2_lincomb` #### `g2_lincomb`
```python ```python
def g2_lincomb(points: Sequence[KZGCommitment], scalars: Sequence[BLSFieldElement]) -> Bytes96: def g2_lincomb(points: Sequence[G2Point], scalars: Sequence[BLSFieldElement]) -> Bytes96:
""" """
BLS multiscalar multiplication in G2. This function can be optimized using Pippenger's algorithm and variants. BLS multiscalar multiplication in G2. This function can be optimized using Pippenger's algorithm and variants.
""" """
@ -308,16 +308,27 @@ def compute_kzg_proof_multi_impl(
polynomial_coeff: PolynomialCoeff, polynomial_coeff: PolynomialCoeff,
zs: Sequence[BLSFieldElement]) -> Tuple[KZGProof, Sequence[BLSFieldElement]]: zs: Sequence[BLSFieldElement]) -> Tuple[KZGProof, Sequence[BLSFieldElement]]:
""" """
Helper function that computes multi-evaluation KZG proofs. Compute a KZG multi-evaluation proof for a set of `k` points.
This is done by committing to the following quotient polynomial:
Q(X) = (f(X) - r(X)) / Z(X)
Where:
- r(X) is the polynomial of degree at most `k-1` that agrees with f(X) at all `k` points
- Z(X) is the degree `k` polynomial that evaluates to zero on all `k` points
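We further note that since the degree of r(X) is less than the degree of Z(X),
r(X) is exactly the remainder of the polynomial long division of f(X) by Z(X), so
the computation can be simplified in monomial form to Q(X) = f(X) / Z(X),
where `/` denotes polynomial division with the remainder discarded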
""" """
# For all x_i, compute p(x_i) - p(z) # For all points, compute the evaluation of those points
ys = [evaluate_polynomialcoeff(polynomial_coeff, z) for z in zs] ys = [evaluate_polynomialcoeff(polynomial_coeff, z) for z in zs]
# Compute r(X)
interpolation_polynomial = interpolate_polynomialcoeff(zs, ys)
# For all x_i, compute (x_i - z) # Compute Z(X)
denominator_poly = vanishing_polynomialcoeff(zs) denominator_poly = vanishing_polynomialcoeff(zs)
# Compute the quotient polynomial directly in evaluation form # Compute the quotient polynomial directly in monomial form
quotient_polynomial = divide_polynomialcoeff(polynomial_coeff, denominator_poly) quotient_polynomial = divide_polynomialcoeff(polynomial_coeff, denominator_poly)
return KZGProof(g1_lincomb(KZG_SETUP_G1_MONOMIAL[:len(quotient_polynomial)], quotient_polynomial)), ys return KZGProof(g1_lincomb(KZG_SETUP_G1_MONOMIAL[:len(quotient_polynomial)], quotient_polynomial)), ys
@ -357,7 +368,7 @@ def coset_for_cell(cell_id: CellID) -> Cell:
""" """
Get the coset for a given ``cell_id`` Get the coset for a given ``cell_id``
""" """
assert cell_id < CELLS_PER_BLOB assert cell_id < CELLS_PER_EXT_BLOB
roots_of_unity_brp = bit_reversal_permutation( roots_of_unity_brp = bit_reversal_permutation(
compute_roots_of_unity(FIELD_ELEMENTS_PER_EXT_BLOB) compute_roots_of_unity(FIELD_ELEMENTS_PER_EXT_BLOB)
) )
@ -372,10 +383,10 @@ def coset_for_cell(cell_id: CellID) -> Cell:
```python ```python
def compute_cells_and_proofs(blob: Blob) -> Tuple[ def compute_cells_and_proofs(blob: Blob) -> Tuple[
Vector[Cell, CELLS_PER_BLOB], Vector[Cell, CELLS_PER_EXT_BLOB],
Vector[KZGProof, CELLS_PER_BLOB]]: Vector[KZGProof, CELLS_PER_EXT_BLOB]]:
""" """
Compute all the cell proofs for one blob. This is an inefficient O(n^2) algorithm, Compute all the cell proofs for an extended blob. This is an inefficient O(n^2) algorithm,
for performant implementation the FK20 algorithm that runs in O(n log n) should be for performant implementation the FK20 algorithm that runs in O(n log n) should be
used instead. used instead.
@ -387,7 +398,7 @@ def compute_cells_and_proofs(blob: Blob) -> Tuple[
cells = [] cells = []
proofs = [] proofs = []
for i in range(CELLS_PER_BLOB): for i in range(CELLS_PER_EXT_BLOB):
coset = coset_for_cell(i) coset = coset_for_cell(i)
proof, ys = compute_kzg_proof_multi_impl(polynomial_coeff, coset) proof, ys = compute_kzg_proof_multi_impl(polynomial_coeff, coset)
cells.append(ys) cells.append(ys)
@ -399,9 +410,9 @@ def compute_cells_and_proofs(blob: Blob) -> Tuple[
#### `compute_cells` #### `compute_cells`
```python ```python
def compute_cells(blob: Blob) -> Vector[Cell, CELLS_PER_BLOB]: def compute_cells(blob: Blob) -> Vector[Cell, CELLS_PER_EXT_BLOB]:
""" """
Compute the cell data for a blob (without computing the proofs). Compute the cell data for an extended blob (without computing the proofs).
Public method. Public method.
""" """
@ -412,7 +423,7 @@ def compute_cells(blob: Blob) -> Vector[Cell, CELLS_PER_BLOB]:
compute_roots_of_unity(FIELD_ELEMENTS_PER_EXT_BLOB)) compute_roots_of_unity(FIELD_ELEMENTS_PER_EXT_BLOB))
extended_data_rbo = bit_reversal_permutation(extended_data) extended_data_rbo = bit_reversal_permutation(extended_data)
return [extended_data_rbo[i * FIELD_ELEMENTS_PER_CELL:(i + 1) * FIELD_ELEMENTS_PER_CELL] return [extended_data_rbo[i * FIELD_ELEMENTS_PER_CELL:(i + 1) * FIELD_ELEMENTS_PER_CELL]
for i in range(CELLS_PER_BLOB)] for i in range(CELLS_PER_EXT_BLOB)]
``` ```
### Cell verification ### Cell verification
@ -489,11 +500,11 @@ def construct_vanishing_polynomial(missing_cell_ids: Sequence[CellID]) -> Tuple[
corresponds to a missing field element. corresponds to a missing field element.
""" """
# Get the small domain # Get the small domain
roots_of_unity_reduced = compute_roots_of_unity(CELLS_PER_BLOB) roots_of_unity_reduced = compute_roots_of_unity(CELLS_PER_EXT_BLOB)
# Compute polynomial that vanishes at all the missing cells (over the small domain) # Compute polynomial that vanishes at all the missing cells (over the small domain)
short_zero_poly = vanishing_polynomialcoeff([ short_zero_poly = vanishing_polynomialcoeff([
roots_of_unity_reduced[reverse_bits(missing_cell_id, CELLS_PER_BLOB)] roots_of_unity_reduced[reverse_bits(missing_cell_id, CELLS_PER_EXT_BLOB)]
for missing_cell_id in missing_cell_ids for missing_cell_id in missing_cell_ids
]) ])
@ -508,7 +519,7 @@ def construct_vanishing_polynomial(missing_cell_ids: Sequence[CellID]) -> Tuple[
zero_poly_eval_brp = bit_reversal_permutation(zero_poly_eval) zero_poly_eval_brp = bit_reversal_permutation(zero_poly_eval)
# Sanity check # Sanity check
for cell_id in range(CELLS_PER_BLOB): for cell_id in range(CELLS_PER_EXT_BLOB):
start = cell_id * FIELD_ELEMENTS_PER_CELL start = cell_id * FIELD_ELEMENTS_PER_CELL
end = (cell_id + 1) * FIELD_ELEMENTS_PER_CELL end = (cell_id + 1) * FIELD_ELEMENTS_PER_CELL
if cell_id in missing_cell_ids: if cell_id in missing_cell_ids:
@ -603,7 +614,7 @@ def recover_polynomial(cell_ids: Sequence[CellID],
""" """
assert len(cell_ids) == len(cells_bytes) assert len(cell_ids) == len(cells_bytes)
# Check we have enough cells to be able to perform the reconstruction # Check we have enough cells to be able to perform the reconstruction
assert CELLS_PER_BLOB / 2 <= len(cell_ids) <= CELLS_PER_BLOB assert CELLS_PER_EXT_BLOB / 2 <= len(cell_ids) <= CELLS_PER_EXT_BLOB
# Check for duplicates # Check for duplicates
assert len(cell_ids) == len(set(cell_ids)) assert len(cell_ids) == len(set(cell_ids))
@ -613,7 +624,7 @@ def recover_polynomial(cell_ids: Sequence[CellID],
# Convert from bytes to cells # Convert from bytes to cells
cells = [bytes_to_cell(cell_bytes) for cell_bytes in cells_bytes] cells = [bytes_to_cell(cell_bytes) for cell_bytes in cells_bytes]
missing_cell_ids = [cell_id for cell_id in range(CELLS_PER_BLOB) if cell_id not in cell_ids] missing_cell_ids = [cell_id for cell_id in range(CELLS_PER_EXT_BLOB) if cell_id not in cell_ids]
zero_poly_coeff, zero_poly_eval, zero_poly_eval_brp = construct_vanishing_polynomial(missing_cell_ids) zero_poly_coeff, zero_poly_eval, zero_poly_eval_brp = construct_vanishing_polynomial(missing_cell_ids)
eval_shifted_extended_evaluation, eval_shifted_zero_poly, shift_inv = recover_shifted_data( eval_shifted_extended_evaluation, eval_shifted_zero_poly, shift_inv = recover_shifted_data(

View File

@ -18,11 +18,12 @@ def test_compute_extended_matrix(spec):
blob_count = 2 blob_count = 2
input_blobs = [get_sample_blob(spec, rng=rng) for _ in range(blob_count)] input_blobs = [get_sample_blob(spec, rng=rng) for _ in range(blob_count)]
extended_matrix = spec.compute_extended_matrix(input_blobs) extended_matrix = spec.compute_extended_matrix(input_blobs)
assert len(extended_matrix) == spec.CELLS_PER_BLOB * blob_count assert len(extended_matrix) == spec.CELLS_PER_EXT_BLOB * blob_count
rows = [extended_matrix[i:(i + spec.CELLS_PER_BLOB)] for i in range(0, len(extended_matrix), spec.CELLS_PER_BLOB)] rows = [extended_matrix[i:(i + spec.CELLS_PER_EXT_BLOB)]
for i in range(0, len(extended_matrix), spec.CELLS_PER_EXT_BLOB)]
assert len(rows) == blob_count assert len(rows) == blob_count
assert len(rows[0]) == spec.CELLS_PER_BLOB assert len(rows[0]) == spec.CELLS_PER_EXT_BLOB
for blob_index, row in enumerate(rows): for blob_index, row in enumerate(rows):
extended_blob = [] extended_blob = []
@ -40,7 +41,7 @@ def test_recover_matrix(spec):
rng = random.Random(5566) rng = random.Random(5566)
# Number of samples we will be recovering from # Number of samples we will be recovering from
N_SAMPLES = spec.CELLS_PER_BLOB // 2 N_SAMPLES = spec.CELLS_PER_EXT_BLOB // 2
blob_count = 2 blob_count = 2
cells_dict = {} cells_dict = {}
@ -54,9 +55,9 @@ def test_recover_matrix(spec):
cell_ids = [] cell_ids = []
# First figure out just the indices of the cells # First figure out just the indices of the cells
for _ in range(N_SAMPLES): for _ in range(N_SAMPLES):
cell_id = rng.randint(0, spec.CELLS_PER_BLOB - 1) cell_id = rng.randint(0, spec.CELLS_PER_EXT_BLOB - 1)
while cell_id in cell_ids: while cell_id in cell_ids:
cell_id = rng.randint(0, spec.CELLS_PER_BLOB - 1) cell_id = rng.randint(0, spec.CELLS_PER_EXT_BLOB - 1)
cell_ids.append(cell_id) cell_ids.append(cell_id)
cell = cells[cell_id] cell = cells[cell_id]
cells_dict[(blob_index, cell_id)] = cell cells_dict[(blob_index, cell_id)] = cell

View File

@ -71,7 +71,7 @@ def test_recover_polynomial(spec):
rng = random.Random(5566) rng = random.Random(5566)
# Number of samples we will be recovering from # Number of samples we will be recovering from
N_SAMPLES = spec.CELLS_PER_BLOB // 2 N_SAMPLES = spec.CELLS_PER_EXT_BLOB // 2
# Get the data we will be working with # Get the data we will be working with
blob = get_sample_blob(spec) blob = get_sample_blob(spec)
@ -86,9 +86,9 @@ def test_recover_polynomial(spec):
cell_ids = [] cell_ids = []
# First figure out just the indices of the cells # First figure out just the indices of the cells
for i in range(N_SAMPLES): for i in range(N_SAMPLES):
j = rng.randint(0, spec.CELLS_PER_BLOB - 1) j = rng.randint(0, spec.CELLS_PER_EXT_BLOB - 1)
while j in cell_ids: while j in cell_ids:
j = rng.randint(0, spec.CELLS_PER_BLOB - 1) j = rng.randint(0, spec.CELLS_PER_EXT_BLOB - 1)
cell_ids.append(j) cell_ids.append(j)
# Now the cells themselves # Now the cells themselves
known_cells_bytes = [cells_bytes[cell_id] for cell_id in cell_ids] known_cells_bytes = [cells_bytes[cell_id] for cell_id in cell_ids]