diff --git a/da/common.py b/da/common.py index a591808..cb1e3c9 100644 --- a/da/common.py +++ b/da/common.py @@ -1,5 +1,5 @@ from dataclasses import dataclass -from itertools import chain +from itertools import chain, zip_longest from typing import List, Generator, Self from eth2spec.eip7594.mainnet import Bytes32 @@ -25,7 +25,7 @@ class Row(List[Bytes32]): class ChunksMatrix(List[Row | Column]): @property def columns(self) -> Generator[List[Chunk], None, None]: - yield from map(Column, zip(*self)) + yield from map(Column, zip_longest(*self, fillvalue=b"")) def transposed(self) -> Self: return ChunksMatrix(self.columns) diff --git a/da/encoder.py b/da/encoder.py index 87e5f09..c4b3344 100644 --- a/da/encoder.py +++ b/da/encoder.py @@ -1,11 +1,13 @@ from dataclasses import dataclass from itertools import batched, chain from typing import List, Sequence, Tuple +from hashlib import sha256 + from eth2spec.eip7594.mainnet import KZGCommitment as Commitment, KZGProof as Proof, BLSFieldElement -from da.common import ChunksMatrix, Chunk, Row -from da.kzg_rs import kzg, rs, poly -from da.kzg_rs.common import GLOBAL_PARAMETERS, ROOTS_OF_UNITY +from da.common import ChunksMatrix, Chunk, Row, Column +from da.kzg_rs import kzg, rs +from da.kzg_rs.common import GLOBAL_PARAMETERS, ROOTS_OF_UNITY, BLS_MODULUS from da.kzg_rs.poly import Polynomial @@ -52,11 +54,12 @@ class DAEncoder: ) return ChunksMatrix(__rs_encode_row(row) for row in chunks_matrix) + @staticmethod def _compute_rows_proofs( - self, - chunks_matrix: ChunksMatrix, - polynomials: Sequence[Polynomial], - row_commitments: Sequence[Commitment]) -> List[List[Proof]]: + chunks_matrix: ChunksMatrix, + polynomials: Sequence[Polynomial], + row_commitments: Sequence[Commitment] + ) -> List[List[Proof]]: proofs = [] for row, poly, commitment in zip(chunks_matrix, polynomials, row_commitments): proofs.append( @@ -70,26 +73,38 @@ class DAEncoder: def _compute_column_kzg_commitments(self, chunks_matrix: ChunksMatrix) -> List[Tuple[Polynomial, Commitment]]: return self._compute_row_kzg_commitments(chunks_matrix.transposed()) - def _compute_aggregated_column_commitments( - self, chunks_matrix: ChunksMatrix, column_commitments: List[Commitment] - ) -> Commitment: - ... + @staticmethod + def _compute_aggregated_column_commitment( + chunks_matrix: ChunksMatrix, column_commitments: Sequence[Commitment] + ) -> Tuple[Polynomial, Commitment]: + data = bytes(chain.from_iterable( + DAEncoder._hash_column_and_commitment(column, commitment) + for column, commitment in zip(chunks_matrix.columns, column_commitments) + )) + return kzg.bytes_to_commitment(data, GLOBAL_PARAMETERS) + @staticmethod def _compute_aggregated_column_proofs( - self, - chunks_matrix: ChunksMatrix, - aggregated_column_commitment: Commitment + polynomial: Polynomial, + column_commitments: Sequence[Commitment], ) -> List[Proof]: - ... + return [ + kzg.generate_element_proof(i, polynomial, GLOBAL_PARAMETERS, ROOTS_OF_UNITY) + for i in range(len(column_commitments)) + ] def encode(self, data: bytes) -> EncodedData: chunks_matrix = self._chunkify_data(data) - row_commitments = self._compute_row_kzg_commitments(chunks_matrix) + row_polynomials, row_commitments = zip(*self._compute_row_kzg_commitments(chunks_matrix)) extended_matrix = self._rs_encode_rows(chunks_matrix) - row_proofs = self._compute_rows_proofs(extended_matrix, row_commitments) - column_commitments = self._compute_column_kzg_commitments(extended_matrix) - aggregated_column_commitment = self._compute_aggregated_column_commitments(extended_matrix, column_commitments) - aggregated_column_proofs = self._compute_aggregated_column_proofs(extended_matrix, aggregated_column_commitment) + row_proofs = self._compute_rows_proofs(extended_matrix, row_polynomials, row_commitments) + column_polynomials, column_commitments = zip(*self._compute_column_kzg_commitments(extended_matrix)) + aggregated_column_polynomial, aggregated_column_commitment = ( + self._compute_aggregated_column_commitment(extended_matrix, column_commitments) + ) + aggregated_column_proofs = self._compute_aggregated_column_proofs( + aggregated_column_polynomial, column_commitments + ) result = EncodedData( data, extended_matrix, @@ -100,3 +115,10 @@ class DAEncoder: aggregated_column_proofs ) return result + + @staticmethod + def _hash_column_and_commitment(column: Column, commitment: Commitment) -> bytes: + # TODO: Check correctness of bytes to blsfieldelement using modulus over the hash + return ( + int.from_bytes(sha256(column.as_bytes() + bytes(commitment)).digest()) % BLS_MODULUS + ).to_bytes(32, byteorder="big") diff --git a/da/test_encoder.py b/da/test_encoder.py index e431d2d..125ab3d 100644 --- a/da/test_encoder.py +++ b/da/test_encoder.py @@ -2,6 +2,8 @@ from itertools import chain, batched from random import randrange from unittest import TestCase +from eth2spec.deneb.mainnet import bytes_to_bls_field + from da import encoder from da.encoder import DAEncoderParams, Commitment, DAEncoder from eth2spec.eip7594.mainnet import BYTES_PER_FIELD_ELEMENT, BLSFieldElement @@ -22,13 +24,38 @@ class TestEncoder(TestCase): ) ) - def assert_encoding(self, encoder_params: DAEncoderParams, data: bytearray): + def assert_encoding(self, encoder_params: DAEncoderParams, data: bytes): encoded_data = encoder.DAEncoder(encoder_params).encode(data) self.assertEqual(encoded_data.data, data) - self.assertEqual(len(encoded_data.extended_matrix), encoder_params.column_count) + extended_factor = 2 + column_count = encoder_params.column_count*extended_factor + columns_len = len(list(encoded_data.extended_matrix.columns)) + self.assertEqual(columns_len, column_count) chunks_size = (len(data) // encoder_params.bytes_per_field_element) // encoder_params.column_count self.assertEqual(len(encoded_data.row_commitments), chunks_size) self.assertEqual(len(encoded_data.row_proofs), chunks_size) + self.assertEqual(len(encoded_data.row_proofs[0]), column_count) + self.assertIsNotNone(encoded_data.aggregated_column_commitment) + self.assertEqual(len(encoded_data.aggregated_column_proofs), columns_len) + + # verify rows + for row, proofs, commitment in zip(encoded_data.extended_matrix, encoded_data.row_proofs, encoded_data.row_commitments): + for i, (chunk, proof) in enumerate(zip(row, proofs)): + self.assertTrue( + kzg.verify_element_proof(bytes_to_bls_field(chunk), commitment, proof, i, ROOTS_OF_UNITY) + ) + + # verify column aggregation + for i, (column, proof) in enumerate(zip(encoded_data.extended_matrix.columns, encoded_data.aggregated_column_proofs)): + data = DAEncoder._hash_column_and_commitment(column, commitment) + kzg.verify_element_proof( + bytes_to_bls_field(data), + encoded_data.aggregated_column_commitment, + proof, + i, + ROOTS_OF_UNITY + ) + def test_chunkify(self): encoder_settings = DAEncoderParams(column_count=2, bytes_per_field_element=32) @@ -81,17 +108,34 @@ class TestEncoder(TestCase): self.assertEqual(len(polynomials), len(chunks_matrix[0])) def test_generate_aggregated_column_commitments(self): - pass + chunks_matrix = self.encoder._chunkify_data(self.data) + _, column_commitments = zip(*self.encoder._compute_column_kzg_commitments(chunks_matrix)) + poly, commitment = self.encoder._compute_aggregated_column_commitment(chunks_matrix, column_commitments) + self.assertIsNotNone(poly) + self.assertIsNotNone(commitment) + + def test_generate_aggregated_column_proofs(self): + chunks_matrix = self.encoder._chunkify_data(self.data) + _, column_commitments = zip(*self.encoder._compute_column_kzg_commitments(chunks_matrix)) + poly, _ = self.encoder._compute_aggregated_column_commitment(chunks_matrix, column_commitments) + proofs = self.encoder._compute_aggregated_column_proofs(poly, column_commitments) + self.assertEqual(len(proofs), len(column_commitments)) def test_encode(self): - # TODO: remove return, for now we make it work for now so we do not disturb other modules - return from random import randbytes - sizes = [pow(2, exp) for exp in range(0, 8, 2)] + sizes = [pow(2, exp) for exp in range(4, 8, 2)] encoder_params = DAEncoderParams( - column_count=10, + column_count=8, bytes_per_field_element=BYTES_PER_FIELD_ELEMENT ) for size in sizes: - data = bytearray(randbytes(size*1024)) + data = bytes( + chain.from_iterable( + # TODO: For now we make data fit with modulus, we need to research if this is correct + (int.from_bytes(b) % BLS_MODULUS).to_bytes(length=32) + for b in batched( + randbytes(size*self.encoder.params.column_count), self.encoder.params.bytes_per_field_element + ) + ) + ) self.assert_encoding(encoder_params, data)