Merge branch 'dev' into kw/optimize-compute-kzg-proof-multi
This commit is contained in:
commit
c2b7c0b414
|
@ -147,7 +147,7 @@ def recover_matrix(cells_dict: Dict[Tuple[BlobIndex, CellID], Cell], blob_count:
|
||||||
full_polynomial = recover_polynomial(cell_ids, cells_bytes)
|
full_polynomial = recover_polynomial(cell_ids, cells_bytes)
|
||||||
cells_from_full_polynomial = [
|
cells_from_full_polynomial = [
|
||||||
full_polynomial[i * FIELD_ELEMENTS_PER_CELL:(i + 1) * FIELD_ELEMENTS_PER_CELL]
|
full_polynomial[i * FIELD_ELEMENTS_PER_CELL:(i + 1) * FIELD_ELEMENTS_PER_CELL]
|
||||||
for i in range(CELLS_PER_BLOB)
|
for i in range(CELLS_PER_EXT_BLOB)
|
||||||
]
|
]
|
||||||
extended_matrix.extend(cells_from_full_polynomial)
|
extended_matrix.extend(cells_from_full_polynomial)
|
||||||
return ExtendedMatrix(extended_matrix)
|
return ExtendedMatrix(extended_matrix)
|
||||||
|
|
|
@ -84,7 +84,7 @@ Cells are the smallest unit of blob data that can come with their own KZG proofs
|
||||||
| `FIELD_ELEMENTS_PER_EXT_BLOB` | `2 * FIELD_ELEMENTS_PER_BLOB` | Number of field elements in a Reed-Solomon extended blob |
|
| `FIELD_ELEMENTS_PER_EXT_BLOB` | `2 * FIELD_ELEMENTS_PER_BLOB` | Number of field elements in a Reed-Solomon extended blob |
|
||||||
| `FIELD_ELEMENTS_PER_CELL` | `uint64(64)` | Number of field elements in a cell |
|
| `FIELD_ELEMENTS_PER_CELL` | `uint64(64)` | Number of field elements in a cell |
|
||||||
| `BYTES_PER_CELL` | `FIELD_ELEMENTS_PER_CELL * BYTES_PER_FIELD_ELEMENT` | The number of bytes in a cell |
|
| `BYTES_PER_CELL` | `FIELD_ELEMENTS_PER_CELL * BYTES_PER_FIELD_ELEMENT` | The number of bytes in a cell |
|
||||||
| `CELLS_PER_BLOB` | `FIELD_ELEMENTS_PER_EXT_BLOB // FIELD_ELEMENTS_PER_CELL` | The number of cells in a blob |
|
| `CELLS_PER_EXT_BLOB` | `FIELD_ELEMENTS_PER_EXT_BLOB // FIELD_ELEMENTS_PER_CELL` | The number of cells in an extended blob |
|
||||||
| `RANDOM_CHALLENGE_KZG_CELL_BATCH_DOMAIN` | `b'RCKZGCBATCH__V1_'` |
|
| `RANDOM_CHALLENGE_KZG_CELL_BATCH_DOMAIN` | `b'RCKZGCBATCH__V1_'` |
|
||||||
|
|
||||||
## Helper functions
|
## Helper functions
|
||||||
|
@ -106,7 +106,7 @@ def bytes_to_cell(cell_bytes: Vector[Bytes32, FIELD_ELEMENTS_PER_CELL]) -> Cell:
|
||||||
#### `g2_lincomb`
|
#### `g2_lincomb`
|
||||||
|
|
||||||
```python
|
```python
|
||||||
def g2_lincomb(points: Sequence[KZGCommitment], scalars: Sequence[BLSFieldElement]) -> Bytes96:
|
def g2_lincomb(points: Sequence[G2Point], scalars: Sequence[BLSFieldElement]) -> Bytes96:
|
||||||
"""
|
"""
|
||||||
BLS multiscalar multiplication in G2. This function can be optimized using Pippenger's algorithm and variants.
|
BLS multiscalar multiplication in G2. This function can be optimized using Pippenger's algorithm and variants.
|
||||||
"""
|
"""
|
||||||
|
@ -308,16 +308,27 @@ def compute_kzg_proof_multi_impl(
|
||||||
polynomial_coeff: PolynomialCoeff,
|
polynomial_coeff: PolynomialCoeff,
|
||||||
zs: Sequence[BLSFieldElement]) -> Tuple[KZGProof, Sequence[BLSFieldElement]]:
|
zs: Sequence[BLSFieldElement]) -> Tuple[KZGProof, Sequence[BLSFieldElement]]:
|
||||||
"""
|
"""
|
||||||
Helper function that computes multi-evaluation KZG proofs.
|
Compute a KZG multi-evaluation proof for a set of `k` points.
|
||||||
|
|
||||||
|
This is done by committing to the following quotient polynomial:
|
||||||
|
Q(X) = f(X) - r(X) / Z(X)
|
||||||
|
Where:
|
||||||
|
- r(X) is the degree `k-1` polynomial that agrees with f(x) at all `k` points
|
||||||
|
- Z(X) is the degree `k` polynomial that evaluates to zero on all `k` points
|
||||||
|
|
||||||
|
We further note that since the degree of r(X) is less than the degree of Z(X),
|
||||||
|
the computation can be simplified in monomial form to Q(X) = f(X) / Z(X)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# For all x_i, compute p(x_i) - p(z)
|
# For all points, compute the evaluation of those points
|
||||||
ys = [evaluate_polynomialcoeff(polynomial_coeff, z) for z in zs]
|
ys = [evaluate_polynomialcoeff(polynomial_coeff, z) for z in zs]
|
||||||
|
# Compute r(X)
|
||||||
|
interpolation_polynomial = interpolate_polynomialcoeff(zs, ys)
|
||||||
|
|
||||||
# For all x_i, compute (x_i - z)
|
# Compute Z(X)
|
||||||
denominator_poly = vanishing_polynomialcoeff(zs)
|
denominator_poly = vanishing_polynomialcoeff(zs)
|
||||||
|
|
||||||
# Compute the quotient polynomial directly in evaluation form
|
# Compute the quotient polynomial directly in monomial form
|
||||||
quotient_polynomial = divide_polynomialcoeff(polynomial_coeff, denominator_poly)
|
quotient_polynomial = divide_polynomialcoeff(polynomial_coeff, denominator_poly)
|
||||||
|
|
||||||
return KZGProof(g1_lincomb(KZG_SETUP_G1_MONOMIAL[:len(quotient_polynomial)], quotient_polynomial)), ys
|
return KZGProof(g1_lincomb(KZG_SETUP_G1_MONOMIAL[:len(quotient_polynomial)], quotient_polynomial)), ys
|
||||||
|
@ -357,7 +368,7 @@ def coset_for_cell(cell_id: CellID) -> Cell:
|
||||||
"""
|
"""
|
||||||
Get the coset for a given ``cell_id``
|
Get the coset for a given ``cell_id``
|
||||||
"""
|
"""
|
||||||
assert cell_id < CELLS_PER_BLOB
|
assert cell_id < CELLS_PER_EXT_BLOB
|
||||||
roots_of_unity_brp = bit_reversal_permutation(
|
roots_of_unity_brp = bit_reversal_permutation(
|
||||||
compute_roots_of_unity(FIELD_ELEMENTS_PER_EXT_BLOB)
|
compute_roots_of_unity(FIELD_ELEMENTS_PER_EXT_BLOB)
|
||||||
)
|
)
|
||||||
|
@ -372,10 +383,10 @@ def coset_for_cell(cell_id: CellID) -> Cell:
|
||||||
|
|
||||||
```python
|
```python
|
||||||
def compute_cells_and_proofs(blob: Blob) -> Tuple[
|
def compute_cells_and_proofs(blob: Blob) -> Tuple[
|
||||||
Vector[Cell, CELLS_PER_BLOB],
|
Vector[Cell, CELLS_PER_EXT_BLOB],
|
||||||
Vector[KZGProof, CELLS_PER_BLOB]]:
|
Vector[KZGProof, CELLS_PER_EXT_BLOB]]:
|
||||||
"""
|
"""
|
||||||
Compute all the cell proofs for one blob. This is an inefficient O(n^2) algorithm,
|
Compute all the cell proofs for an extended blob. This is an inefficient O(n^2) algorithm,
|
||||||
for performant implementation the FK20 algorithm that runs in O(n log n) should be
|
for performant implementation the FK20 algorithm that runs in O(n log n) should be
|
||||||
used instead.
|
used instead.
|
||||||
|
|
||||||
|
@ -387,7 +398,7 @@ def compute_cells_and_proofs(blob: Blob) -> Tuple[
|
||||||
cells = []
|
cells = []
|
||||||
proofs = []
|
proofs = []
|
||||||
|
|
||||||
for i in range(CELLS_PER_BLOB):
|
for i in range(CELLS_PER_EXT_BLOB):
|
||||||
coset = coset_for_cell(i)
|
coset = coset_for_cell(i)
|
||||||
proof, ys = compute_kzg_proof_multi_impl(polynomial_coeff, coset)
|
proof, ys = compute_kzg_proof_multi_impl(polynomial_coeff, coset)
|
||||||
cells.append(ys)
|
cells.append(ys)
|
||||||
|
@ -399,9 +410,9 @@ def compute_cells_and_proofs(blob: Blob) -> Tuple[
|
||||||
#### `compute_cells`
|
#### `compute_cells`
|
||||||
|
|
||||||
```python
|
```python
|
||||||
def compute_cells(blob: Blob) -> Vector[Cell, CELLS_PER_BLOB]:
|
def compute_cells(blob: Blob) -> Vector[Cell, CELLS_PER_EXT_BLOB]:
|
||||||
"""
|
"""
|
||||||
Compute the cell data for a blob (without computing the proofs).
|
Compute the cell data for an extended blob (without computing the proofs).
|
||||||
|
|
||||||
Public method.
|
Public method.
|
||||||
"""
|
"""
|
||||||
|
@ -412,7 +423,7 @@ def compute_cells(blob: Blob) -> Vector[Cell, CELLS_PER_BLOB]:
|
||||||
compute_roots_of_unity(FIELD_ELEMENTS_PER_EXT_BLOB))
|
compute_roots_of_unity(FIELD_ELEMENTS_PER_EXT_BLOB))
|
||||||
extended_data_rbo = bit_reversal_permutation(extended_data)
|
extended_data_rbo = bit_reversal_permutation(extended_data)
|
||||||
return [extended_data_rbo[i * FIELD_ELEMENTS_PER_CELL:(i + 1) * FIELD_ELEMENTS_PER_CELL]
|
return [extended_data_rbo[i * FIELD_ELEMENTS_PER_CELL:(i + 1) * FIELD_ELEMENTS_PER_CELL]
|
||||||
for i in range(CELLS_PER_BLOB)]
|
for i in range(CELLS_PER_EXT_BLOB)]
|
||||||
```
|
```
|
||||||
|
|
||||||
### Cell verification
|
### Cell verification
|
||||||
|
@ -489,11 +500,11 @@ def construct_vanishing_polynomial(missing_cell_ids: Sequence[CellID]) -> Tuple[
|
||||||
corresponds to a missing field element.
|
corresponds to a missing field element.
|
||||||
"""
|
"""
|
||||||
# Get the small domain
|
# Get the small domain
|
||||||
roots_of_unity_reduced = compute_roots_of_unity(CELLS_PER_BLOB)
|
roots_of_unity_reduced = compute_roots_of_unity(CELLS_PER_EXT_BLOB)
|
||||||
|
|
||||||
# Compute polynomial that vanishes at all the missing cells (over the small domain)
|
# Compute polynomial that vanishes at all the missing cells (over the small domain)
|
||||||
short_zero_poly = vanishing_polynomialcoeff([
|
short_zero_poly = vanishing_polynomialcoeff([
|
||||||
roots_of_unity_reduced[reverse_bits(missing_cell_id, CELLS_PER_BLOB)]
|
roots_of_unity_reduced[reverse_bits(missing_cell_id, CELLS_PER_EXT_BLOB)]
|
||||||
for missing_cell_id in missing_cell_ids
|
for missing_cell_id in missing_cell_ids
|
||||||
])
|
])
|
||||||
|
|
||||||
|
@ -508,7 +519,7 @@ def construct_vanishing_polynomial(missing_cell_ids: Sequence[CellID]) -> Tuple[
|
||||||
zero_poly_eval_brp = bit_reversal_permutation(zero_poly_eval)
|
zero_poly_eval_brp = bit_reversal_permutation(zero_poly_eval)
|
||||||
|
|
||||||
# Sanity check
|
# Sanity check
|
||||||
for cell_id in range(CELLS_PER_BLOB):
|
for cell_id in range(CELLS_PER_EXT_BLOB):
|
||||||
start = cell_id * FIELD_ELEMENTS_PER_CELL
|
start = cell_id * FIELD_ELEMENTS_PER_CELL
|
||||||
end = (cell_id + 1) * FIELD_ELEMENTS_PER_CELL
|
end = (cell_id + 1) * FIELD_ELEMENTS_PER_CELL
|
||||||
if cell_id in missing_cell_ids:
|
if cell_id in missing_cell_ids:
|
||||||
|
@ -603,7 +614,7 @@ def recover_polynomial(cell_ids: Sequence[CellID],
|
||||||
"""
|
"""
|
||||||
assert len(cell_ids) == len(cells_bytes)
|
assert len(cell_ids) == len(cells_bytes)
|
||||||
# Check we have enough cells to be able to perform the reconstruction
|
# Check we have enough cells to be able to perform the reconstruction
|
||||||
assert CELLS_PER_BLOB / 2 <= len(cell_ids) <= CELLS_PER_BLOB
|
assert CELLS_PER_EXT_BLOB / 2 <= len(cell_ids) <= CELLS_PER_EXT_BLOB
|
||||||
# Check for duplicates
|
# Check for duplicates
|
||||||
assert len(cell_ids) == len(set(cell_ids))
|
assert len(cell_ids) == len(set(cell_ids))
|
||||||
|
|
||||||
|
@ -613,7 +624,7 @@ def recover_polynomial(cell_ids: Sequence[CellID],
|
||||||
# Convert from bytes to cells
|
# Convert from bytes to cells
|
||||||
cells = [bytes_to_cell(cell_bytes) for cell_bytes in cells_bytes]
|
cells = [bytes_to_cell(cell_bytes) for cell_bytes in cells_bytes]
|
||||||
|
|
||||||
missing_cell_ids = [cell_id for cell_id in range(CELLS_PER_BLOB) if cell_id not in cell_ids]
|
missing_cell_ids = [cell_id for cell_id in range(CELLS_PER_EXT_BLOB) if cell_id not in cell_ids]
|
||||||
zero_poly_coeff, zero_poly_eval, zero_poly_eval_brp = construct_vanishing_polynomial(missing_cell_ids)
|
zero_poly_coeff, zero_poly_eval, zero_poly_eval_brp = construct_vanishing_polynomial(missing_cell_ids)
|
||||||
|
|
||||||
eval_shifted_extended_evaluation, eval_shifted_zero_poly, shift_inv = recover_shifted_data(
|
eval_shifted_extended_evaluation, eval_shifted_zero_poly, shift_inv = recover_shifted_data(
|
||||||
|
|
|
@ -18,11 +18,12 @@ def test_compute_extended_matrix(spec):
|
||||||
blob_count = 2
|
blob_count = 2
|
||||||
input_blobs = [get_sample_blob(spec, rng=rng) for _ in range(blob_count)]
|
input_blobs = [get_sample_blob(spec, rng=rng) for _ in range(blob_count)]
|
||||||
extended_matrix = spec.compute_extended_matrix(input_blobs)
|
extended_matrix = spec.compute_extended_matrix(input_blobs)
|
||||||
assert len(extended_matrix) == spec.CELLS_PER_BLOB * blob_count
|
assert len(extended_matrix) == spec.CELLS_PER_EXT_BLOB * blob_count
|
||||||
|
|
||||||
rows = [extended_matrix[i:(i + spec.CELLS_PER_BLOB)] for i in range(0, len(extended_matrix), spec.CELLS_PER_BLOB)]
|
rows = [extended_matrix[i:(i + spec.CELLS_PER_EXT_BLOB)]
|
||||||
|
for i in range(0, len(extended_matrix), spec.CELLS_PER_EXT_BLOB)]
|
||||||
assert len(rows) == blob_count
|
assert len(rows) == blob_count
|
||||||
assert len(rows[0]) == spec.CELLS_PER_BLOB
|
assert len(rows[0]) == spec.CELLS_PER_EXT_BLOB
|
||||||
|
|
||||||
for blob_index, row in enumerate(rows):
|
for blob_index, row in enumerate(rows):
|
||||||
extended_blob = []
|
extended_blob = []
|
||||||
|
@ -40,7 +41,7 @@ def test_recover_matrix(spec):
|
||||||
rng = random.Random(5566)
|
rng = random.Random(5566)
|
||||||
|
|
||||||
# Number of samples we will be recovering from
|
# Number of samples we will be recovering from
|
||||||
N_SAMPLES = spec.CELLS_PER_BLOB // 2
|
N_SAMPLES = spec.CELLS_PER_EXT_BLOB // 2
|
||||||
|
|
||||||
blob_count = 2
|
blob_count = 2
|
||||||
cells_dict = {}
|
cells_dict = {}
|
||||||
|
@ -54,9 +55,9 @@ def test_recover_matrix(spec):
|
||||||
cell_ids = []
|
cell_ids = []
|
||||||
# First figure out just the indices of the cells
|
# First figure out just the indices of the cells
|
||||||
for _ in range(N_SAMPLES):
|
for _ in range(N_SAMPLES):
|
||||||
cell_id = rng.randint(0, spec.CELLS_PER_BLOB - 1)
|
cell_id = rng.randint(0, spec.CELLS_PER_EXT_BLOB - 1)
|
||||||
while cell_id in cell_ids:
|
while cell_id in cell_ids:
|
||||||
cell_id = rng.randint(0, spec.CELLS_PER_BLOB - 1)
|
cell_id = rng.randint(0, spec.CELLS_PER_EXT_BLOB - 1)
|
||||||
cell_ids.append(cell_id)
|
cell_ids.append(cell_id)
|
||||||
cell = cells[cell_id]
|
cell = cells[cell_id]
|
||||||
cells_dict[(blob_index, cell_id)] = cell
|
cells_dict[(blob_index, cell_id)] = cell
|
||||||
|
|
|
@ -71,7 +71,7 @@ def test_recover_polynomial(spec):
|
||||||
rng = random.Random(5566)
|
rng = random.Random(5566)
|
||||||
|
|
||||||
# Number of samples we will be recovering from
|
# Number of samples we will be recovering from
|
||||||
N_SAMPLES = spec.CELLS_PER_BLOB // 2
|
N_SAMPLES = spec.CELLS_PER_EXT_BLOB // 2
|
||||||
|
|
||||||
# Get the data we will be working with
|
# Get the data we will be working with
|
||||||
blob = get_sample_blob(spec)
|
blob = get_sample_blob(spec)
|
||||||
|
@ -86,9 +86,9 @@ def test_recover_polynomial(spec):
|
||||||
cell_ids = []
|
cell_ids = []
|
||||||
# First figure out just the indices of the cells
|
# First figure out just the indices of the cells
|
||||||
for i in range(N_SAMPLES):
|
for i in range(N_SAMPLES):
|
||||||
j = rng.randint(0, spec.CELLS_PER_BLOB - 1)
|
j = rng.randint(0, spec.CELLS_PER_EXT_BLOB - 1)
|
||||||
while j in cell_ids:
|
while j in cell_ids:
|
||||||
j = rng.randint(0, spec.CELLS_PER_BLOB - 1)
|
j = rng.randint(0, spec.CELLS_PER_EXT_BLOB - 1)
|
||||||
cell_ids.append(j)
|
cell_ids.append(j)
|
||||||
# Now the cells themselves
|
# Now the cells themselves
|
||||||
known_cells_bytes = [cells_bytes[cell_id] for cell_id in cell_ids]
|
known_cells_bytes = [cells_bytes[cell_id] for cell_id in cell_ids]
|
||||||
|
|
Loading…
Reference in New Issue