Chunkify up to 31-byte elements (#87)

Daniel Sanchez 2024-03-20 11:03:39 +01:00 committed by GitHub
parent b1e13f79c5
commit a0175e16f3
6 changed files with 36 additions and 32 deletions

View File

@@ -11,7 +11,7 @@ class NodeId(Bytes32):
     pass

-class Chunk(Bytes32):
+class Chunk(bytes):
     pass
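
Note: Chunk previously inherited the fixed-size Bytes32 container; subclassing plain bytes lets a chunk hold the 31-byte payloads this commit introduces before they are padded to 32 bytes for field-element encoding. A minimal standalone sketch of what the relaxed type permits (illustrative, not the spec's full class hierarchy; the padding step mirrors what the encoder now does):

class Chunk(bytes):
    pass

chunk = Chunk(b"\x00" * 31)
assert len(chunk) == 31  # fine for plain bytes; a fixed Bytes32 type would reject this length
padded = Chunk(chunk.rjust(32, b"\x00"))  # left-pad with zeros, as the encoder does via an int round-trip
assert len(padded) == 32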

View File

@@ -7,14 +7,14 @@ from eth2spec.eip7594.mainnet import KZGCommitment as Commitment, KZGProof as Proof
 from da.common import ChunksMatrix, Chunk, Row, Column
 from da.kzg_rs import kzg, rs
-from da.kzg_rs.common import GLOBAL_PARAMETERS, ROOTS_OF_UNITY, BLS_MODULUS
+from da.kzg_rs.common import GLOBAL_PARAMETERS, ROOTS_OF_UNITY, BLS_MODULUS, BYTES_PER_FIELD_ELEMENT
 from da.kzg_rs.poly import Polynomial

 @dataclass
 class DAEncoderParams:
     column_count: int
-    bytes_per_field_element: int
+    bytes_per_chunk: int

 @dataclass
@@ -29,21 +29,29 @@ class EncodedData:
     aggregated_column_proofs: List[Proof]

 class DAEncoder:
     def __init__(self, params: DAEncoderParams):
+        # we can only encode up to 31 bytes per element which fits without problem in a 32 byte element
+        assert params.bytes_per_chunk < BYTES_PER_FIELD_ELEMENT
         self.params = params

     def _chunkify_data(self, data: bytes) -> ChunksMatrix:
-        size: int = self.params.column_count * self.params.bytes_per_field_element
+        size: int = self.params.column_count * self.params.bytes_per_chunk
         return ChunksMatrix(
-            Row(Chunk(bytes(chunk)) for chunk in batched(b, self.params.bytes_per_field_element))
+            Row(Chunk(int.from_bytes(chunk, byteorder="big").to_bytes(length=BYTES_PER_FIELD_ELEMENT))
+                for chunk in batched(b, self.params.bytes_per_chunk)
+            )
             for b in batched(data, size)
         )

-    @staticmethod
-    def _compute_row_kzg_commitments(matrix: ChunksMatrix) -> List[Tuple[Polynomial, Commitment]]:
-        return [kzg.bytes_to_commitment(row.as_bytes(), GLOBAL_PARAMETERS) for row in matrix]
+    def _compute_row_kzg_commitments(self, matrix: ChunksMatrix) -> List[Tuple[Polynomial, Commitment]]:
+        return [
+            kzg.bytes_to_commitment(
+                row.as_bytes(),
+                GLOBAL_PARAMETERS,
+            )
+            for row in matrix
+        ]

     def _rs_encode_rows(self, chunks_matrix: ChunksMatrix) -> ChunksMatrix:
         def __rs_encode_row(row: Row) -> Row:
@@ -51,7 +59,8 @@ class DAEncoder:
             return Row(
                 Chunk(BLSFieldElement.to_bytes(
                     x,
-                    length=self.params.bytes_per_field_element, byteorder="big"
+                    # fixed to 32 bytes as bls_field_elements are 32bytes (256bits) encoded
+                    length=32, byteorder="big"
                 )) for x in rs.encode(polynomial, 2, ROOTS_OF_UNITY)
             )
         return ChunksMatrix(__rs_encode_row(row) for row in chunks_matrix)
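
The two hunks above carry the core of the change: each chunk is now 31 bytes of payload, left-padded to a 32-byte big-endian field element, and the constructor asserts bytes_per_chunk < BYTES_PER_FIELD_ELEMENT. The reason is that a padded 31-byte value is always below the BLS12-381 scalar field order, while a raw 32-byte chunk may overflow it. A minimal sketch of that invariant, assuming Python 3.12+ for itertools.batched and inlining the BLS_MODULUS constant the spec imports:

from itertools import batched  # stdlib in Python 3.12+

# BLS12-381 scalar field order, as used throughout the spec.
BLS_MODULUS = 0x73eda753299d7d483339d80809a1d80553bda402fffe5bfeffffffff00000001
BYTES_PER_FIELD_ELEMENT = 32
bytes_per_chunk = 31  # must stay strictly below BYTES_PER_FIELD_ELEMENT

data = bytes(range(2 * bytes_per_chunk))  # one row's worth of data for column_count = 2
row = [
    int.from_bytes(chunk, byteorder="big").to_bytes(
        length=BYTES_PER_FIELD_ELEMENT, byteorder="big"
    )
    for chunk in batched(data, bytes_per_chunk)
]
# Left-padding leaves the high byte zero, so every element is < 2**248,
# comfortably below BLS_MODULUS (roughly 2**254.86)...
assert all(len(c) == 32 and c[0] == 0 for c in row)
assert int.from_bytes(b"\xff" * 31, byteorder="big") < BLS_MODULUS
# ...whereas a raw 32-byte chunk can exceed the modulus:
assert int.from_bytes(b"\xff" * 32, byteorder="big") >= BLS_MODULUS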

View File

@@ -9,12 +9,12 @@ from .common import BYTES_PER_FIELD_ELEMENT, G1, BLS_MODULUS, GLOBAL_PARAMETERS_
 from .poly import Polynomial

-def bytes_to_polynomial(b: bytes) -> Polynomial:
+def bytes_to_polynomial(b: bytes, bytes_per_field_element=BYTES_PER_FIELD_ELEMENT) -> Polynomial:
     """
     Convert bytes to list of BLS field scalars.
     """
-    assert len(b) % BYTES_PER_FIELD_ELEMENT == 0
-    eval_form = [int(bytes_to_bls_field(b)) for b in batched(b, int(BYTES_PER_FIELD_ELEMENT))]
+    assert len(b) % bytes_per_field_element == 0
+    eval_form = [int(bytes_to_bls_field(b)) for b in batched(b, int(bytes_per_field_element))]
     return Polynomial.from_evaluations(eval_form, BLS_MODULUS)
@@ -34,7 +34,7 @@ def g1_linear_combination(polynomial: Polynomial[BLSFieldElement], global_parame
 def bytes_to_commitment(b: bytes, global_parameters: Sequence[G1]) -> Tuple[Polynomial, Commitment]:
-    poly = bytes_to_polynomial(b)
+    poly = bytes_to_polynomial(b, bytes_per_field_element=BYTES_PER_FIELD_ELEMENT)
     return poly, g1_linear_combination(poly, global_parameters)
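
With the element size now a parameter, bytes_to_polynomial keeps its 32-byte default while callers can override it, and bytes_to_commitment pins the default explicitly since the encoder hands it rows of already-padded 32-byte chunks. A hedged usage sketch (import paths follow the diff and assume running inside the repository; the 64-byte blob is an arbitrary example):

from da.kzg_rs.kzg import bytes_to_polynomial
from da.kzg_rs.common import BYTES_PER_FIELD_ELEMENT  # 32

blob = bytes(2 * BYTES_PER_FIELD_ELEMENT)  # two zero field elements
poly = bytes_to_polynomial(blob, bytes_per_field_element=BYTES_PER_FIELD_ELEMENT)
# the length assertion inside bytes_to_polynomial now checks against the
# caller-supplied element size rather than the module constant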

View File

@@ -1,9 +1,7 @@
-from typing import Sequence, List
-import scipy.interpolate
+from typing import Sequence
 from eth2spec.deneb.mainnet import BLSFieldElement
-from eth2spec.eip7594.mainnet import interpolate_polynomialcoeff
-from .common import G1, BLS_MODULUS
+from .common import BLS_MODULUS
 from .poly import Polynomial

 ExtendedData = Sequence[BLSFieldElement]

View File

@@ -50,7 +50,7 @@ class TestDispersal(TestCase):
     def test_disperse(self):
         data = self.encoder_test.data
-        encoding_params = DAEncoderParams(column_count=self.n_nodes // 2, bytes_per_field_element=32)
+        encoding_params = DAEncoderParams(column_count=self.n_nodes // 2, bytes_per_chunk=31)
         encoded_data = DAEncoder(encoding_params).encode(data)
         # mock send and await method with local verifiers

View File

@@ -1,5 +1,5 @@
 from itertools import chain, batched
-from random import randrange
+from random import randrange, randbytes
 from unittest import TestCase

 from eth2spec.deneb.mainnet import bytes_to_bls_field
@@ -14,12 +14,12 @@ from da.kzg_rs import kzg, rs
 class TestEncoder(TestCase):
     def setUp(self):
-        self.params: DAEncoderParams = DAEncoderParams(column_count=16, bytes_per_field_element=32)
+        self.params: DAEncoderParams = DAEncoderParams(column_count=16, bytes_per_chunk=31)
         self.encoder: DAEncoder = DAEncoder(self.params)
         self.elements = 32
         self.data = bytearray(
             chain.from_iterable(
-                randrange(BLS_MODULUS).to_bytes(length=self.params.bytes_per_field_element, byteorder='big')
+                randbytes(self.params.bytes_per_chunk)
                 for _ in range(self.elements)
             )
         )
@@ -31,7 +31,7 @@ class TestEncoder(TestCase):
         column_count = encoder_params.column_count*extended_factor
         columns_len = len(list(encoded_data.extended_matrix.columns))
         self.assertEqual(columns_len, column_count)
-        chunks_size = (len(data) // encoder_params.bytes_per_field_element) // encoder_params.column_count
+        chunks_size = (len(data) // encoder_params.bytes_per_chunk) // encoder_params.column_count
         self.assertEqual(len(encoded_data.row_commitments), chunks_size)
         self.assertEqual(len(encoded_data.row_proofs), chunks_size)
         self.assertEqual(len(encoded_data.row_proofs[0]), column_count)
@@ -57,15 +57,15 @@ class TestEncoder(TestCase):
         )

     def test_chunkify(self):
-        encoder_settings = DAEncoderParams(column_count=2, bytes_per_field_element=32)
+        encoder_settings = DAEncoderParams(column_count=2, bytes_per_chunk=31)
         elements = 10
-        data = bytearray(chain.from_iterable(int.to_bytes(0, length=32, byteorder='big') for _ in range(elements)))
+        data = bytes(chain.from_iterable(int.to_bytes(0, length=31, byteorder='big') for _ in range(elements)))
         _encoder = encoder.DAEncoder(encoder_settings)
         chunks_matrix = _encoder._chunkify_data(data)
         self.assertEqual(len(chunks_matrix), elements//encoder_settings.column_count)
         for row in chunks_matrix:
             self.assertEqual(len(row), encoder_settings.column_count)
-            self.assertEqual(len(row[0]), encoder_settings.bytes_per_field_element)
+            self.assertEqual(len(row[0]), 32)

     def test_compute_row_kzg_commitments(self):
         chunks_matrix = self.encoder._chunkify_data(self.data)
@@ -125,16 +125,13 @@ class TestEncoder(TestCase):
         sizes = [pow(2, exp) for exp in range(4, 8, 2)]
         encoder_params = DAEncoderParams(
             column_count=8,
-            bytes_per_field_element=BYTES_PER_FIELD_ELEMENT
+            bytes_per_chunk=31
         )
         for size in sizes:
             data = bytes(
                 chain.from_iterable(
-                    # TODO: For now we make data fit with modulus, we need to research if this is correct
-                    (int.from_bytes(b) % BLS_MODULUS).to_bytes(length=32)
-                    for b in batched(
-                        randbytes(size*self.encoder.params.column_count), self.encoder.params.bytes_per_field_element
-                    )
+                    randbytes(encoder_params.bytes_per_chunk)
+                    for _ in range(size*encoder_params.column_count)
                 )
             )
             self.assert_encoding(encoder_params, data)
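
The updated tests generate data as whole 31-byte random chunks via randbytes, dropping the old modular-reduction workaround (the removed TODO): any 31-byte value is already a valid field element once padded, so no reduction against BLS_MODULUS is needed. A sketch of the pattern, with illustrative parameter values:

from itertools import chain
from random import randbytes

bytes_per_chunk, column_count, size = 31, 8, 16  # illustrative values
data = bytes(
    chain.from_iterable(
        randbytes(bytes_per_chunk) for _ in range(size * column_count)
    )
)
assert len(data) == bytes_per_chunk * column_count * size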