Chunkify up to 31byte elements (#87)

This commit is contained in:
Daniel Sanchez 2024-03-20 11:03:39 +01:00 committed by GitHub
parent b1e13f79c5
commit a0175e16f3
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 36 additions and 32 deletions

View File

@ -11,7 +11,7 @@ class NodeId(Bytes32):
pass
class Chunk(Bytes32):
class Chunk(bytes):
pass

View File

@ -7,14 +7,14 @@ from eth2spec.eip7594.mainnet import KZGCommitment as Commitment, KZGProof as Pr
from da.common import ChunksMatrix, Chunk, Row, Column
from da.kzg_rs import kzg, rs
from da.kzg_rs.common import GLOBAL_PARAMETERS, ROOTS_OF_UNITY, BLS_MODULUS
from da.kzg_rs.common import GLOBAL_PARAMETERS, ROOTS_OF_UNITY, BLS_MODULUS, BYTES_PER_FIELD_ELEMENT
from da.kzg_rs.poly import Polynomial
@dataclass
class DAEncoderParams:
column_count: int
bytes_per_field_element: int
bytes_per_chunk: int
@dataclass
@ -29,21 +29,29 @@ class EncodedData:
aggregated_column_proofs: List[Proof]
class DAEncoder:
def __init__(self, params: DAEncoderParams):
# we can only encode up to 31 bytes per element which fits without problem in a 32 byte element
assert params.bytes_per_chunk < BYTES_PER_FIELD_ELEMENT
self.params = params
def _chunkify_data(self, data: bytes) -> ChunksMatrix:
size: int = self.params.column_count * self.params.bytes_per_field_element
size: int = self.params.column_count * self.params.bytes_per_chunk
return ChunksMatrix(
Row(Chunk(bytes(chunk)) for chunk in batched(b, self.params.bytes_per_field_element))
Row(Chunk(int.from_bytes(chunk, byteorder="big").to_bytes(length=BYTES_PER_FIELD_ELEMENT))
for chunk in batched(b, self.params.bytes_per_chunk)
)
for b in batched(data, size)
)
@staticmethod
def _compute_row_kzg_commitments(matrix: ChunksMatrix) -> List[Tuple[Polynomial, Commitment]]:
return [kzg.bytes_to_commitment(row.as_bytes(), GLOBAL_PARAMETERS) for row in matrix]
def _compute_row_kzg_commitments(self, matrix: ChunksMatrix) -> List[Tuple[Polynomial, Commitment]]:
return [
kzg.bytes_to_commitment(
row.as_bytes(),
GLOBAL_PARAMETERS,
)
for row in matrix
]
def _rs_encode_rows(self, chunks_matrix: ChunksMatrix) -> ChunksMatrix:
def __rs_encode_row(row: Row) -> Row:
@ -51,7 +59,8 @@ class DAEncoder:
return Row(
Chunk(BLSFieldElement.to_bytes(
x,
length=self.params.bytes_per_field_element, byteorder="big"
# fixed to 32 bytes as bls_field_elements are 32bytes (256bits) encoded
length=32, byteorder="big"
)) for x in rs.encode(polynomial, 2, ROOTS_OF_UNITY)
)
return ChunksMatrix(__rs_encode_row(row) for row in chunks_matrix)

View File

@ -9,12 +9,12 @@ from .common import BYTES_PER_FIELD_ELEMENT, G1, BLS_MODULUS, GLOBAL_PARAMETERS_
from .poly import Polynomial
def bytes_to_polynomial(b: bytes) -> Polynomial:
def bytes_to_polynomial(b: bytes, bytes_per_field_element=BYTES_PER_FIELD_ELEMENT) -> Polynomial:
"""
Convert bytes to list of BLS field scalars.
"""
assert len(b) % BYTES_PER_FIELD_ELEMENT == 0
eval_form = [int(bytes_to_bls_field(b)) for b in batched(b, int(BYTES_PER_FIELD_ELEMENT))]
assert len(b) % bytes_per_field_element == 0
eval_form = [int(bytes_to_bls_field(b)) for b in batched(b, int(bytes_per_field_element))]
return Polynomial.from_evaluations(eval_form, BLS_MODULUS)
@ -34,7 +34,7 @@ def g1_linear_combination(polynomial: Polynomial[BLSFieldElement], global_parame
def bytes_to_commitment(b: bytes, global_parameters: Sequence[G1]) -> Tuple[Polynomial, Commitment]:
poly = bytes_to_polynomial(b)
poly = bytes_to_polynomial(b, bytes_per_field_element=BYTES_PER_FIELD_ELEMENT)
return poly, g1_linear_combination(poly, global_parameters)

View File

@ -1,9 +1,7 @@
from typing import Sequence, List
from typing import Sequence
import scipy.interpolate
from eth2spec.deneb.mainnet import BLSFieldElement
from eth2spec.eip7594.mainnet import interpolate_polynomialcoeff
from .common import G1, BLS_MODULUS
from .common import BLS_MODULUS
from .poly import Polynomial
ExtendedData = Sequence[BLSFieldElement]

View File

@ -50,7 +50,7 @@ class TestDispersal(TestCase):
def test_disperse(self):
data = self.encoder_test.data
encoding_params = DAEncoderParams(column_count=self.n_nodes // 2, bytes_per_field_element=32)
encoding_params = DAEncoderParams(column_count=self.n_nodes // 2, bytes_per_chunk=31)
encoded_data = DAEncoder(encoding_params).encode(data)
# mock send and await method with local verifiers

View File

@ -1,5 +1,5 @@
from itertools import chain, batched
from random import randrange
from random import randrange, randbytes
from unittest import TestCase
from eth2spec.deneb.mainnet import bytes_to_bls_field
@ -14,12 +14,12 @@ from da.kzg_rs import kzg, rs
class TestEncoder(TestCase):
def setUp(self):
self.params: DAEncoderParams = DAEncoderParams(column_count=16, bytes_per_field_element=32)
self.params: DAEncoderParams = DAEncoderParams(column_count=16, bytes_per_chunk=31)
self.encoder: DAEncoder = DAEncoder(self.params)
self.elements = 32
self.data = bytearray(
chain.from_iterable(
randrange(BLS_MODULUS).to_bytes(length=self.params.bytes_per_field_element, byteorder='big')
randbytes(self.params.bytes_per_chunk)
for _ in range(self.elements)
)
)
@ -31,7 +31,7 @@ class TestEncoder(TestCase):
column_count = encoder_params.column_count*extended_factor
columns_len = len(list(encoded_data.extended_matrix.columns))
self.assertEqual(columns_len, column_count)
chunks_size = (len(data) // encoder_params.bytes_per_field_element) // encoder_params.column_count
chunks_size = (len(data) // encoder_params.bytes_per_chunk) // encoder_params.column_count
self.assertEqual(len(encoded_data.row_commitments), chunks_size)
self.assertEqual(len(encoded_data.row_proofs), chunks_size)
self.assertEqual(len(encoded_data.row_proofs[0]), column_count)
@ -57,15 +57,15 @@ class TestEncoder(TestCase):
)
def test_chunkify(self):
encoder_settings = DAEncoderParams(column_count=2, bytes_per_field_element=32)
encoder_settings = DAEncoderParams(column_count=2, bytes_per_chunk=31)
elements = 10
data = bytearray(chain.from_iterable(int.to_bytes(0, length=32, byteorder='big') for _ in range(elements)))
data = bytes(chain.from_iterable(int.to_bytes(0, length=31, byteorder='big') for _ in range(elements)))
_encoder = encoder.DAEncoder(encoder_settings)
chunks_matrix = _encoder._chunkify_data(data)
self.assertEqual(len(chunks_matrix), elements//encoder_settings.column_count)
for row in chunks_matrix:
self.assertEqual(len(row), encoder_settings.column_count)
self.assertEqual(len(row[0]), encoder_settings.bytes_per_field_element)
self.assertEqual(len(row[0]), 32)
def test_compute_row_kzg_commitments(self):
chunks_matrix = self.encoder._chunkify_data(self.data)
@ -125,16 +125,13 @@ class TestEncoder(TestCase):
sizes = [pow(2, exp) for exp in range(4, 8, 2)]
encoder_params = DAEncoderParams(
column_count=8,
bytes_per_field_element=BYTES_PER_FIELD_ELEMENT
bytes_per_chunk=31
)
for size in sizes:
data = bytes(
chain.from_iterable(
# TODO: For now we make data fit with modulus, we need to research if this is correct
(int.from_bytes(b) % BLS_MODULUS).to_bytes(length=32)
for b in batched(
randbytes(size*self.encoder.params.column_count), self.encoder.params.bytes_per_field_element
)
randbytes(encoder_params.bytes_per_chunk)
for _ in range(size*encoder_params.column_count)
)
)
self.assert_encoding(encoder_params, data)