moved from incorrect location at nomos-specs

holisticode 2024-07-19 13:19:35 -05:00
parent e6402007f0
commit faf399eaaf
50 changed files with 2844 additions and 0 deletions

0
da/__init__.py Normal file

0
da/api/__init__.py Normal file

58
da/api/common.py Normal file

@ -0,0 +1,58 @@
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Optional, List, Sequence
from da.common import Certificate
from da.verifier import DABlob
@dataclass
class Metadata:
# app identifier
app_id: bytes
# index of VID certificate blob
index: int
@dataclass
class VID:
# da certificate id
cert_id: bytes
# application + index information
metadata: Metadata
class BlobStore(ABC):
@abstractmethod
def add(self, certificate: Certificate, metadata: Metadata):
"""
Raises: ValueError if there is already a registered certificate for the given metadata
"""
pass
@abstractmethod
def get_multiple(self, app_id: bytes, indexes: Sequence[int]) -> List[Optional[DABlob]]:
pass
class DAApi:
def __init__(self, bs: BlobStore):
self.store = bs
def write(self, certificate: Certificate, metadata: Metadata):
"""
Write method should be used by a service that is able to retrieve verified certificates
from the latest block. Once a certificate is retrieved, the api creates a relation between
the blob of the original data, the certificate and the index for the certificate's app_id.
Raises: ValueError if there is already a registered certificate for a given metadata
"""
self.store.add(certificate, metadata)
def read(self, app_id, indexes) -> List[Optional[DABlob]]:
"""
Read method should accept only `app_id` and a list of indexes. The returned list of
blobs should be ordered in the same sequence as `indexes` in the request.
If the node does not have the blob for some index, it should add a None object as an
item in its place.
"""
return self.store.get_multiple(app_id, indexes)

97
da/api/test_flow.py Normal file

@ -0,0 +1,97 @@
from unittest import TestCase
from collections import defaultdict
from da.api.common import *
@dataclass
class MockCertificate:
cert_id: int
class MockStore(BlobStore):
def __init__(self):
self.blob_store = {}
self.app_id_store = defaultdict(dict)
def populate(self, blob, cert_id: bytes):
self.blob_store[cert_id] = blob
# Implements `add` method from BlobStore abstract class.
def add(self, cert_id: bytes, metadata: Metadata):
if metadata.index in self.app_id_store[metadata.app_id]:
raise ValueError("index already written")
self.app_id_store[metadata.app_id][metadata.index] = cert_id
# Implements `get_multiple` method from BlobStore abstract class.
def get_multiple(self, app_id, indexes) -> List[Optional[DABlob]]:
return [
self.blob_store.get(self.app_id_store[app_id].get(i), None) if self.app_id_store[app_id].get(i) else None for i in indexes
]
class TestFlow(TestCase):
def test_api_write_read(self):
expected_blob = "hello"
cert_id = b"11"*32
app_id = 1
idx = 1
mock_meta = Metadata(app_id, idx)
mock_store = MockStore()
mock_store.populate(expected_blob, cert_id)
api = DAApi(mock_store)
api.write(cert_id, mock_meta)
blobs = api.read(app_id, [idx])
self.assertEqual([expected_blob], blobs)
def test_same_index(self):
expected_blob = "hello"
cert_id = b"11"*32
app_id = 1
idx = 1
mock_meta = Metadata(app_id, idx)
mock_store = MockStore()
mock_store.populate(expected_blob, cert_id)
api = DAApi(mock_store)
api.write(cert_id, mock_meta)
with self.assertRaises(ValueError):
api.write(cert_id, mock_meta)
blobs = api.read(app_id, [idx])
self.assertEqual([expected_blob], blobs)
def test_multiple_indexes_same_data(self):
expected_blob = "hello"
cert_id = b"11"*32
app_id = 1
idx1 = 1
idx2 = 2
mock_meta1 = Metadata(app_id, idx1)
mock_meta2 = Metadata(app_id, idx2)
mock_store = MockStore()
mock_store.populate(expected_blob, cert_id)
api = DAApi(mock_store)
api.write(cert_id, mock_meta1)
mock_store.populate(expected_blob, cert_id)
api.write(cert_id, mock_meta2)
blobs_idx1 = api.read(app_id, [idx1])
blobs_idx2 = api.read(app_id, [idx2])
self.assertEqual([expected_blob], blobs_idx1)
self.assertEqual([expected_blob], blobs_idx2)
self.assertEqual(mock_store.app_id_store[app_id][idx1], mock_store.app_id_store[app_id][idx2])

82
da/common.py Normal file

@ -0,0 +1,82 @@
from dataclasses import dataclass
from hashlib import sha3_256
from itertools import chain, zip_longest, compress
from typing import List, Generator, Self, Sequence
from eth2spec.eip7594.mainnet import Bytes32, KZGCommitment as Commitment
from py_ecc.bls import G2ProofOfPossession
class NodeId(Bytes32):
pass
class Chunk(bytes):
pass
class Column(List[Bytes32]):
def as_bytes(self) -> bytes:
return bytes(chain.from_iterable(self))
class Row(List[Bytes32]):
def as_bytes(self) -> bytes:
return bytes(chain.from_iterable(self))
class ChunksMatrix(List[Row | Column]):
@property
def columns(self) -> Generator[List[Chunk], None, None]:
yield from map(Column, zip_longest(*self, fillvalue=b""))
def transposed(self) -> Self:
return ChunksMatrix(self.columns)
BLSPublicKey = bytes
BLSPrivateKey = int
BLSSignature = bytes
class Bitfield(List[bool]):
pass
@dataclass
class Attestation:
signature: BLSSignature
@dataclass
class Certificate:
aggregated_signatures: BLSSignature
signers: Bitfield
aggregated_column_commitment: Commitment
row_commitments: List[Commitment]
def id(self) -> bytes:
return build_attestation_message(self.aggregated_column_commitment, self.row_commitments)
def verify(self, nodes_public_keys: List[BLSPublicKey]) -> bool:
"""
The list of node public keys should be a trusted list of keys with verified proofs of possession.
Otherwise, we could fall victim to a rogue key attack:
`assert all(bls_pop.PopVerify(pk, proof) for pk, proof in zip(node_public_keys, pops))`
"""
# we sort them as the signers bitfield is sorted by the public keys as well
signers_keys = list(compress(sorted(nodes_public_keys), self.signers))
message = build_attestation_message(self.aggregated_column_commitment, self.row_commitments)
return NomosDaG2ProofOfPossession.AggregateVerify(signers_keys, [message]*len(signers_keys), self.aggregated_signatures)
def build_attestation_message(aggregated_column_commitment: Commitment, row_commitments: Sequence[Commitment]) -> bytes:
hasher = sha3_256()
hasher.update(bytes(aggregated_column_commitment))
for c in row_commitments:
hasher.update(bytes(c))
return hasher.digest()
class NomosDaG2ProofOfPossession(G2ProofOfPossession):
# Domain specific tag for Nomos DA protocol
DST = b"NOMOS_DA_AVAIL"

90
da/dispersal.py Normal file

@ -0,0 +1,90 @@
from dataclasses import dataclass
from hashlib import sha3_256
from typing import List, Optional, Generator, Sequence
from da.common import Certificate, NodeId, BLSPublicKey, Bitfield, build_attestation_message, NomosDaG2ProofOfPossession as bls_pop
from da.encoder import EncodedData
from da.verifier import DABlob, Attestation
@dataclass
class DispersalSettings:
nodes_ids: List[NodeId]
nodes_pubkey: List[BLSPublicKey]
threshold: int
class Dispersal:
def __init__(self, settings: DispersalSettings):
self.settings = settings
# sort over public keys
self.settings.nodes_ids, self.settings.nodes_pubkey = zip(
*sorted(zip(self.settings.nodes_ids, self.settings.nodes_pubkey), key=lambda x: x[1])
)
def _prepare_data(self, encoded_data: EncodedData) -> Generator[DABlob, None, None]:
assert len(encoded_data.column_commitments) == len(self.settings.nodes_ids)
assert len(encoded_data.aggregated_column_proofs) == len(self.settings.nodes_ids)
columns = encoded_data.extended_matrix.columns
column_commitments = encoded_data.column_commitments
row_commitments = encoded_data.row_commitments
rows_proofs = encoded_data.row_proofs
aggregated_column_commitment = encoded_data.aggregated_column_commitment
aggregated_column_proofs = encoded_data.aggregated_column_proofs
blobs_data = zip(columns, column_commitments, zip(*rows_proofs), aggregated_column_proofs)
for (column, column_commitment, row_proofs, column_proof) in blobs_data:
blob = DABlob(
column,
column_commitment,
aggregated_column_commitment,
column_proof,
row_commitments,
row_proofs
)
yield blob
def _send_and_await_response(self, node: NodeId, blob: DABlob) -> Optional[Attestation]:
pass
def _build_certificate(
self,
encoded_data: EncodedData,
attestations: Sequence[Attestation],
signers: Bitfield
) -> Certificate:
assert len(attestations) >= self.settings.threshold
assert len(attestations) == signers.count(True)
aggregated = bls_pop.Aggregate([attestation.signature for attestation in attestations])
return Certificate(
aggregated_signatures=aggregated,
signers=signers,
aggregated_column_commitment=encoded_data.aggregated_column_commitment,
row_commitments=encoded_data.row_commitments
)
@staticmethod
def _verify_attestation(public_key: BLSPublicKey, attested_message: bytes, attestation: Attestation) -> bool:
return bls_pop.Verify(public_key, attested_message, attestation.signature)
@staticmethod
def _build_attestation_message(encoded_data: EncodedData) -> bytes:
return build_attestation_message(encoded_data.aggregated_column_commitment, encoded_data.row_commitments)
def disperse(self, encoded_data: EncodedData) -> Optional[Certificate]:
attestations = []
attested_message = self._build_attestation_message(encoded_data)
signed = Bitfield(False for _ in range(len(self.settings.nodes_ids)))
blob_data = zip(
range(len(self.settings.nodes_ids)),
self.settings.nodes_ids,
self.settings.nodes_pubkey,
self._prepare_data(encoded_data)
)
for i, node, pk, blob in blob_data:
if attestation := self._send_and_await_response(node, blob):
if self._verify_attestation(pk, attested_message, attestation):
# mark as received
signed[i] = True
attestations.append(attestation)
if len(attestations) >= self.settings.threshold:
return self._build_certificate(encoded_data, attestations, signed)
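
`_send_and_await_response` is deliberately left unimplemented above; it is the integration point with the transport layer. The following is only a rough local-testing sketch (the class name, key handling and `stage_message` helper are hypothetical, and a real node would derive the attested message from the received `DABlob` rather than having it staged by the caller):

```python
from typing import List, Optional

from da.common import Attestation, NodeId, NomosDaG2ProofOfPossession as bls_pop
from da.dispersal import Dispersal, DispersalSettings
from da.verifier import DABlob


class LocalDispersal(Dispersal):
    """Hypothetical test double: each 'node' signs a caller-staged message in-process,
    standing in for the request/response round trip over the network."""

    def __init__(self, settings: DispersalSettings, secret_keys: List[int]):
        # map node ids to their secret keys before the parent reorders ids by public key;
        # secret_keys must be given in the same order as settings.nodes_ids
        self._sk_by_node = {bytes(node): sk for node, sk in zip(settings.nodes_ids, secret_keys)}
        super().__init__(settings)
        self._staged_message = b""

    def stage_message(self, attested_message: bytes) -> None:
        self._staged_message = attested_message

    def _send_and_await_response(self, node: NodeId, blob: DABlob) -> Optional[Attestation]:
        # reuse the Attestation dataclass from da.common; disperse() only reads `.signature`
        secret_key = self._sk_by_node[bytes(node)]
        return Attestation(signature=bls_pop.Sign(secret_key, self._staged_message))
```

Staging `dispersal._build_attestation_message(encoded_data)` before calling `dispersal.disperse(encoded_data)` then makes every mocked node return a valid attestation, which exercises the bitfield bookkeeping and the threshold check without any networking.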

136
da/encoder.py Normal file

@ -0,0 +1,136 @@
from dataclasses import dataclass
from itertools import batched, chain
from typing import List, Sequence, Tuple
from hashlib import blake2b
from eth2spec.eip7594.mainnet import KZGCommitment as Commitment, KZGProof as Proof, BLSFieldElement
from da.common import ChunksMatrix, Chunk, Row, Column
from da.kzg_rs import kzg, rs
from da.kzg_rs.common import GLOBAL_PARAMETERS, ROOTS_OF_UNITY, BLS_MODULUS, BYTES_PER_FIELD_ELEMENT
from da.kzg_rs.poly import Polynomial
@dataclass
class DAEncoderParams:
column_count: int
bytes_per_chunk: int
@dataclass
class EncodedData:
data: bytes
chunked_data: ChunksMatrix
extended_matrix: ChunksMatrix
row_commitments: List[Commitment]
row_proofs: List[List[Proof]]
column_commitments: List[Commitment]
aggregated_column_commitment: Commitment
aggregated_column_proofs: List[Proof]
class DAEncoder:
def __init__(self, params: DAEncoderParams):
# we can only encode up to 31 bytes per chunk so that each chunk fits into a 32-byte field element without overflowing the modulus
assert params.bytes_per_chunk < BYTES_PER_FIELD_ELEMENT
self.params = params
def _chunkify_data(self, data: bytes) -> ChunksMatrix:
size: int = self.params.column_count * self.params.bytes_per_chunk
return ChunksMatrix(
Row(Chunk(int.from_bytes(chunk, byteorder="big").to_bytes(length=BYTES_PER_FIELD_ELEMENT))
for chunk in batched(b, self.params.bytes_per_chunk)
)
for b in batched(data, size)
)
def _compute_row_kzg_commitments(self, matrix: ChunksMatrix) -> List[Tuple[Polynomial, Commitment]]:
return [
kzg.bytes_to_commitment(
row.as_bytes(),
GLOBAL_PARAMETERS,
)
for row in matrix
]
def _rs_encode_rows(self, chunks_matrix: ChunksMatrix) -> ChunksMatrix:
def __rs_encode_row(row: Row) -> Row:
polynomial = kzg.bytes_to_polynomial(row.as_bytes())
return Row(
Chunk(BLSFieldElement.to_bytes(
x,
# fixed to 32 bytes as BLS field elements are encoded as 32 bytes (256 bits)
length=32, byteorder="big"
)) for x in rs.encode(polynomial, 2, ROOTS_OF_UNITY)
)
return ChunksMatrix(__rs_encode_row(row) for row in chunks_matrix)
@staticmethod
def _compute_rows_proofs(
chunks_matrix: ChunksMatrix,
polynomials: Sequence[Polynomial],
row_commitments: Sequence[Commitment]
) -> List[List[Proof]]:
proofs = []
for row, poly, commitment in zip(chunks_matrix, polynomials, row_commitments):
proofs.append(
[
kzg.generate_element_proof(i, poly, GLOBAL_PARAMETERS, ROOTS_OF_UNITY)
for i in range(len(row))
]
)
return proofs
def _compute_column_kzg_commitments(self, chunks_matrix: ChunksMatrix) -> List[Tuple[Polynomial, Commitment]]:
return self._compute_row_kzg_commitments(chunks_matrix.transposed())
@staticmethod
def _compute_aggregated_column_commitment(
chunks_matrix: ChunksMatrix, column_commitments: Sequence[Commitment]
) -> Tuple[Polynomial, Commitment]:
data = bytes(chain.from_iterable(
DAEncoder.hash_column_and_commitment(column, commitment)
for column, commitment in zip(chunks_matrix.columns, column_commitments)
))
return kzg.bytes_to_commitment(data, GLOBAL_PARAMETERS)
@staticmethod
def _compute_aggregated_column_proofs(
polynomial: Polynomial,
column_commitments: Sequence[Commitment],
) -> List[Proof]:
return [
kzg.generate_element_proof(i, polynomial, GLOBAL_PARAMETERS, ROOTS_OF_UNITY)
for i in range(len(column_commitments))
]
def encode(self, data: bytes) -> EncodedData:
chunks_matrix = self._chunkify_data(data)
row_polynomials, row_commitments = zip(*self._compute_row_kzg_commitments(chunks_matrix))
extended_matrix = self._rs_encode_rows(chunks_matrix)
row_proofs = self._compute_rows_proofs(extended_matrix, row_polynomials, row_commitments)
column_polynomials, column_commitments = zip(*self._compute_column_kzg_commitments(extended_matrix))
aggregated_column_polynomial, aggregated_column_commitment = (
self._compute_aggregated_column_commitment(extended_matrix, column_commitments)
)
aggregated_column_proofs = self._compute_aggregated_column_proofs(
aggregated_column_polynomial, column_commitments
)
result = EncodedData(
data,
chunks_matrix,
extended_matrix,
row_commitments,
row_proofs,
column_commitments,
aggregated_column_commitment,
aggregated_column_proofs
)
return result
@staticmethod
def hash_column_and_commitment(column: Column, commitment: Commitment) -> bytes:
return (
# digest size must be 31 bytes as we cannot encode 32 without risking overflowing the BLS_MODULUS
int.from_bytes(blake2b(column.as_bytes() + bytes(commitment), digest_size=31).digest())
).to_bytes(32, byteorder="big")
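
A brief usage sketch with illustrative, non-normative parameters (4 columns, 31-byte chunks, a 2-row payload) to show the shape of the resulting `EncodedData`:

```python
from da.encoder import DAEncoder, DAEncoderParams

# hypothetical parameters: 4 columns, 31-byte chunks (one byte below the field element size)
params = DAEncoderParams(column_count=4, bytes_per_chunk=31)
encoder = DAEncoder(params)

data = bytes(range(31)) * 8  # exactly 2 rows of 4 chunks each
encoded = encoder.encode(data)

assert len(encoded.row_commitments) == len(encoded.chunked_data)  # one commitment per row
# rows are RS-extended by a factor of 2, so each extended row has twice as many chunks
assert all(len(row) == 2 * params.column_count for row in encoded.extended_matrix)
# one aggregated column proof per (extended) column commitment
assert len(encoded.aggregated_column_proofs) == len(encoded.column_commitments)
```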

0
da/kzg_rs/__init__.py Normal file

22
da/kzg_rs/common.py Normal file

@ -0,0 +1,22 @@
from typing import List, Tuple
import eth2spec.eip7594.mainnet
from py_ecc.bls.typing import G1Uncompressed, G2Uncompressed
from da.kzg_rs.roots import compute_roots_of_unity
from da.kzg_rs.trusted_setup import generate_setup
G1 = G1Uncompressed
G2 = G2Uncompressed
BYTES_PER_FIELD_ELEMENT = 32
BLS_MODULUS = eth2spec.eip7594.mainnet.BLS_MODULUS
PRIMITIVE_ROOT: int = 7
GLOBAL_PARAMETERS: List[G1]
GLOBAL_PARAMETERS_G2: List[G2]
# secret is fixed but this should come from a different synchronization protocol
GLOBAL_PARAMETERS, GLOBAL_PARAMETERS_G2 = map(list, generate_setup(4096, 8, 1987))
ROOTS_OF_UNITY: Tuple[int] = compute_roots_of_unity(
PRIMITIVE_ROOT, 4096, BLS_MODULUS
)

65
da/kzg_rs/fft.py Normal file

@ -0,0 +1,65 @@
from typing import Sequence, List
from eth2spec.deneb.mainnet import BLSFieldElement
from eth2spec.utils import bls
from da.kzg_rs.common import G1
def fft_g1(vals: Sequence[G1], roots_of_unity: Sequence[BLSFieldElement], modulus: int) -> List[G1]:
if len(vals) == 1:
return vals
L = fft_g1(vals[::2], roots_of_unity[::2], modulus)
R = fft_g1(vals[1::2], roots_of_unity[::2], modulus)
o = [bls.Z1() for _ in vals]
for i, (x, y) in enumerate(zip(L, R)):
y_times_root = bls.multiply(y, roots_of_unity[i])
o[i] = (x + y_times_root)
o[i + len(L)] = x + -y_times_root
return o
def ifft_g1(vals: Sequence[G1], roots_of_unity: Sequence[BLSFieldElement], modulus: int) -> List[G1]:
assert len(vals) == len(roots_of_unity)
# modular inverse
invlen = pow(len(vals), modulus-2, modulus)
return [
bls.multiply(x, invlen)
for x in fft_g1(
vals, [roots_of_unity[0], *roots_of_unity[:0:-1]], modulus
)
]
def _fft(
vals: Sequence[BLSFieldElement],
roots_of_unity: Sequence[BLSFieldElement],
modulus: int,
) -> Sequence[BLSFieldElement]:
if len(vals) == 1:
return vals
L = _fft(vals[::2], roots_of_unity[::2], modulus)
R = _fft(vals[1::2], roots_of_unity[::2], modulus)
o = [BLSFieldElement(0) for _ in vals]
for i, (x, y) in enumerate(zip(L, R)):
y_times_root = BLSFieldElement((int(y) * int(roots_of_unity[i])) % modulus)
o[i] = BLSFieldElement((int(x) + y_times_root) % modulus)
o[i + len(L)] = BLSFieldElement((int(x) - int(y_times_root) + modulus) % modulus)
return o
def fft(vals, root_of_unity, modulus):
assert len(vals) == len(root_of_unity)
return _fft(vals, root_of_unity, modulus)
def ifft(vals, roots_of_unity, modulus):
assert len(vals) == len(roots_of_unity)
# modular inverse
invlen = pow(len(vals), modulus-2, modulus)
return [
BLSFieldElement((int(x) * invlen) % modulus)
for x in _fft(
vals, [roots_of_unity[0], *roots_of_unity[:0:-1]], modulus
)
]

75
da/kzg_rs/fk20.py Normal file

@ -0,0 +1,75 @@
from typing import List, Sequence
from eth2spec.deneb.mainnet import KZGProof as Proof, BLSFieldElement
from eth2spec.utils import bls
from da.kzg_rs.common import G1, BLS_MODULUS, PRIMITIVE_ROOT
from da.kzg_rs.fft import fft, fft_g1, ifft_g1
from da.kzg_rs.poly import Polynomial
from da.kzg_rs.roots import compute_roots_of_unity
from da.kzg_rs.utils import is_power_of_two
def __toeplitz1(global_parameters: List[G1], polynomial_degree: int) -> List[G1]:
"""
This part can be precomputed once per global_parameters length, for polynomial degrees that are powers of two.
:param global_parameters:
:param polynomial_degree:
:return:
"""
assert len(global_parameters) == polynomial_degree
# algorithm only works on powers of 2 for dft computations
assert is_power_of_two(len(global_parameters))
roots_of_unity = compute_roots_of_unity(PRIMITIVE_ROOT, polynomial_degree*2, BLS_MODULUS)
vector_x_extended = global_parameters + [bls.Z1() for _ in range(polynomial_degree)]
vector_x_extended_fft = fft_g1(vector_x_extended, roots_of_unity, BLS_MODULUS)
return vector_x_extended_fft
def __toeplitz2(coefficients: List[BLSFieldElement], extended_vector: Sequence[G1]) -> List[G1]:
assert is_power_of_two(len(coefficients))
roots_of_unity = compute_roots_of_unity(PRIMITIVE_ROOT, len(coefficients), BLS_MODULUS)
toeplitz_coefficients_fft = fft(coefficients, roots_of_unity, BLS_MODULUS)
return [bls.multiply(v, c) for v, c in zip(extended_vector, toeplitz_coefficients_fft)]
def __toeplitz3(h_extended_fft: Sequence[G1], polynomial_degree: int) -> List[G1]:
roots_of_unity = compute_roots_of_unity(PRIMITIVE_ROOT, len(h_extended_fft), BLS_MODULUS)
return ifft_g1(h_extended_fft, roots_of_unity, BLS_MODULUS)[:polynomial_degree]
def fk20_generate_proofs(
polynomial: Polynomial, global_parameters: List[G1]
) -> List[Proof]:
"""
Generate all proofs for the polynomial points in batch.
This method uses the FK20 algorithm from https://eprint.iacr.org/2023/033.pdf
Disclaimer: it only works for polynomial degrees that are powers of two.
:param polynomial: polynomial to generate proofs for
:param global_parameters: setup generated parameters
:return: list of proofs, one for each point in the polynomial
"""
polynomial_degree = len(polynomial)
assert len(global_parameters) >= polynomial_degree
assert is_power_of_two(len(polynomial))
# 1 - Build toeplitz matrix for h values
# 1.1 y = dft([s^d-1, s^d-2, ..., s, 1, *[0 for _ in len(polynomial)]])
# 1.2 z = dft([*[0 for _ in len(polynomial)], f1, f2, ..., fd])
# 1.3 u = y * v * roots_of_unity(len(polynomial)*2)
roots_of_unity = compute_roots_of_unity(PRIMITIVE_ROOT, polynomial_degree, BLS_MODULUS)
global_parameters = list(reversed(global_parameters[:polynomial_degree]))
extended_vector = __toeplitz1(global_parameters, polynomial_degree)
# 2 - Build circulant matrix with the polynomial coefficients (reversed N..n, and padded)
toeplitz_coefficients = [
*(BLSFieldElement(0) for _ in range(polynomial_degree)),
*polynomial.coefficients
]
h_extended_vector = __toeplitz2(toeplitz_coefficients, extended_vector)
# 3 - Perform ifft and drop the tail half as it is padding
h_vector = __toeplitz3(h_extended_vector, polynomial_degree)
# 4 - proofs are the dft of the h vector
proofs = fft_g1(h_vector, roots_of_unity, BLS_MODULUS)
proofs = [Proof(bls.G1_to_bytes48(proof)) for proof in proofs]
return proofs

74
da/kzg_rs/kzg.py Normal file

@ -0,0 +1,74 @@
from functools import reduce
from itertools import batched
from typing import Sequence, Tuple
from eth2spec.deneb.mainnet import bytes_to_bls_field, BLSFieldElement, KZGCommitment as Commitment, KZGProof as Proof
from eth2spec.utils import bls
from .common import BYTES_PER_FIELD_ELEMENT, G1, BLS_MODULUS, GLOBAL_PARAMETERS_G2
from .poly import Polynomial
def bytes_to_polynomial(b: bytes, bytes_per_field_element=BYTES_PER_FIELD_ELEMENT) -> Polynomial:
"""
Convert bytes to a list of BLS field elements and interpolate them into a polynomial in coefficient form.
"""
assert len(b) % bytes_per_field_element == 0
eval_form = [int(bytes_to_bls_field(b)) for b in batched(b, int(bytes_per_field_element))]
return Polynomial.from_evaluations(eval_form, BLS_MODULUS)
def g1_linear_combination(polynomial: Polynomial[BLSFieldElement], global_parameters: Sequence[G1]) -> Commitment:
"""
BLS multiscalar multiplication.
"""
# we assert that at least as many setup points are available as polynomial coefficients;
# this depends on the available kzg setup size
assert len(polynomial) <= len(global_parameters)
point = reduce(
bls.add,
(bls.multiply(g, p) for g, p in zip(global_parameters, polynomial)),
bls.Z1()
)
return Commitment(bls.G1_to_bytes48(point))
def bytes_to_commitment(b: bytes, global_parameters: Sequence[G1]) -> Tuple[Polynomial, Commitment]:
poly = bytes_to_polynomial(b, bytes_per_field_element=BYTES_PER_FIELD_ELEMENT)
return poly, g1_linear_combination(poly, global_parameters)
def generate_element_proof(
element_index: int,
polynomial: Polynomial,
global_parameters: Sequence[G1],
roots_of_unity: Sequence[BLSFieldElement],
) -> Proof:
# compute a witness polynomial that satisfies `witness(x) = (f(x)-v)/(x-u)`
u = int(roots_of_unity[element_index])
v = polynomial.eval(u)
f_x_v = polynomial - Polynomial([v], BLS_MODULUS)
x_u = Polynomial([-u, 1], BLS_MODULUS)
witness, _ = f_x_v / x_u
return g1_linear_combination(witness, global_parameters)
def verify_element_proof(
chunk: BLSFieldElement,
commitment: Commitment,
proof: Proof,
element_index: int,
roots_of_unity: Sequence[BLSFieldElement],
) -> bool:
u = int(roots_of_unity[element_index])
v = chunk
commitment_check_G1 = bls.bytes48_to_G1(commitment) - bls.multiply(bls.G1(), v)
proof_check_g2 = bls.add(
GLOBAL_PARAMETERS_G2[1],
bls.neg(bls.multiply(bls.G2(), u))
)
return bls.pairing_check([
# G2 here needs to be negated due to library requirements as pairing_check([[G1, -G2], [G1, G2]])
[commitment_check_G1, bls.neg(bls.G2())],
[bls.bytes48_to_G1(proof), proof_check_g2],
])
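
As a quick round-trip sketch (arbitrary small values; it exercises the same path as `test_kzg.py` further below): commit to two field elements, open the element at index 0, and verify the opening against the commitment.

```python
from eth2spec.deneb.mainnet import bytes_to_bls_field

from da.kzg_rs import kzg
from da.kzg_rs.common import GLOBAL_PARAMETERS, ROOTS_OF_UNITY

# two 32-byte big-endian field elements, values well below the modulus
data = (1).to_bytes(32, "big") + (2).to_bytes(32, "big")
polynomial, commitment = kzg.bytes_to_commitment(data, GLOBAL_PARAMETERS)

proof = kzg.generate_element_proof(0, polynomial, GLOBAL_PARAMETERS, ROOTS_OF_UNITY)
assert kzg.verify_element_proof(
    bytes_to_bls_field(data[:32]), commitment, proof, 0, ROOTS_OF_UNITY
)
```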

111
da/kzg_rs/poly.py Normal file

@ -0,0 +1,111 @@
from itertools import zip_longest
from typing import List, Sequence, Self
from eth2spec.eip7594.mainnet import interpolate_polynomialcoeff
from da.kzg_rs.common import ROOTS_OF_UNITY
class Polynomial[T]:
def __init__(self, coefficients, modulus):
self.coefficients = coefficients
self.modulus = modulus
@staticmethod
def interpolate(evaluations: List[int], roots_of_unity: List[int]) -> List[int]:
"""
Lagrange interpolation
Parameters:
evaluations: List of evaluations
roots_of_unity: Sequence of powers of a root of unity (the evaluation domain)
Returns:
list: Coefficients of the interpolated polynomial
"""
return list(map(int, interpolate_polynomialcoeff(roots_of_unity[:len(evaluations)], evaluations)))
@classmethod
def from_evaluations(cls, evaluations: Sequence[T], modulus, roots_of_unity: Sequence[int]=ROOTS_OF_UNITY) -> Self:
coefficients = [
x % modulus
for x in map(int, Polynomial.interpolate(evaluations, roots_of_unity))
]
return cls(coefficients, modulus)
def __repr__(self):
return "Polynomial({}, modulus={})".format(self.coefficients, self.modulus)
def __add__(self, other):
return Polynomial(
[(a + b) % self.modulus for a, b in zip_longest(self.coefficients, other.coefficients, fillvalue=0)],
self.modulus
)
def __sub__(self, other):
return Polynomial(
[(a - b) % self.modulus for a, b in zip_longest(self.coefficients, other.coefficients, fillvalue=0)],
self.modulus
)
def __mul__(self, other):
result = [0] * (len(self.coefficients) + len(other.coefficients) - 1)
for i in range(len(self.coefficients)):
for j in range(len(other.coefficients)):
result[i + j] = (result[i + j] + self.coefficients[i] * other.coefficients[j]) % self.modulus
return Polynomial(result, self.modulus)
def divide(self, other):
if not isinstance(other, Polynomial):
raise ValueError("Unsupported type for division.")
dividend = list(self.coefficients)
divisor = list(other.coefficients)
quotient = []
remainder = dividend
while len(remainder) >= len(divisor):
factor = remainder[-1] * pow(divisor[-1], -1, self.modulus) % self.modulus
quotient.insert(0, factor)
# Subtract divisor * factor from remainder
for i in range(len(divisor)):
remainder[len(remainder) - len(divisor) + i] -= divisor[i] * factor
remainder[len(remainder) - len(divisor) + i] %= self.modulus
# Remove leading zeros from remainder
while remainder and remainder[-1] == 0:
remainder.pop()
return Polynomial(quotient, self.modulus), Polynomial(remainder, self.modulus)
def __truediv__(self, other):
return self.divide(other)
def __neg__(self):
return Polynomial([-1 * c for c in self.coefficients], self.modulus)
def __len__(self):
return len(self.coefficients)
def __iter__(self):
return iter(self.coefficients)
def __getitem__(self, item):
return self.coefficients[item]
def __eq__(self, other):
return (
self.coefficients == other.coefficients and
self.modulus == other.modulus
)
def eval(self, x):
return (self.coefficients[0] + sum(
(pow(x, i, mod=self.modulus)*coefficient)
for i, coefficient in enumerate(self.coefficients[1:], start=1)
)) % self.modulus
def evaluation_form(self) -> List[T]:
return [self.eval(ROOTS_OF_UNITY[i]) for i in range(len(self))]

25
da/kzg_rs/roots.py Normal file

@ -0,0 +1,25 @@
from typing import Tuple
def compute_root_of_unity(primitive_root: int, order: int, modulus: int) -> int:
"""
Generate a w such that ``w**order = 1``.
"""
assert (modulus - 1) % order == 0
return pow(primitive_root, (modulus - 1) // order, modulus)
def compute_roots_of_unity(primitive_root: int, order: int, modulus: int) -> Tuple[int]:
"""
Compute a list of roots of unity for a given order.
The order must divide the BLS multiplicative group order, i.e. BLS_MODULUS - 1
"""
assert (modulus - 1) % order == 0
root_of_unity = compute_root_of_unity(primitive_root, order, modulus)
roots = []
current_root_of_unity = 1
for _ in range(order):
roots.append(current_root_of_unity)
current_root_of_unity = current_root_of_unity * root_of_unity % modulus
return tuple(roots)
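
A small sanity check (illustrative order of 8; `PRIMITIVE_ROOT = 7` mirrors the value in `da/kzg_rs/common.py`): the generated roots are the successive powers of a single root of unity and wrap around after `order` steps.

```python
from eth2spec.eip7594.mainnet import BLS_MODULUS

from da.kzg_rs.roots import compute_root_of_unity, compute_roots_of_unity

PRIMITIVE_ROOT = 7  # same value as da/kzg_rs/common.py
modulus = int(BLS_MODULUS)

w = compute_root_of_unity(PRIMITIVE_ROOT, 8, modulus)
roots = compute_roots_of_unity(PRIMITIVE_ROOT, 8, modulus)

assert pow(w, 8, modulus) == 1                               # w is an 8th root of unity
assert roots == tuple(pow(w, i, modulus) for i in range(8))  # roots are its successive powers
```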

40
da/kzg_rs/rs.py Normal file

@ -0,0 +1,40 @@
from typing import Sequence, Optional
from eth2spec.deneb.mainnet import BLSFieldElement
from .common import BLS_MODULUS
from .poly import Polynomial
ExtendedData = Sequence[Optional[BLSFieldElement]]
def encode(polynomial: Polynomial, factor: int, roots_of_unity: Sequence[int]) -> ExtendedData:
"""
Encode a polynomial, extending it by the given factor
Parameters:
polynomial: Polynomial to be encoded
factor: Encoding factor
roots_of_unity: Sequence of powers of a root of unity (the evaluation domain)
Returns:
list: Extended data set
"""
assert factor >= 2
assert len(polynomial)*factor <= len(roots_of_unity)
return [polynomial.eval(e) for e in roots_of_unity[:len(polynomial)*factor]]
def decode(encoded: ExtendedData, roots_of_unity: Sequence[BLSFieldElement], original_len: int) -> Polynomial:
"""
Decode a polynomial from an extended data set and the roots of unity, capped to the original length
Parameters:
encoded: Extended data set
roots_of_unity: Sequence of powers of a root of unity (the evaluation domain)
original_len: Original length of the encoded polynomial
Returns:
Polynomial: original polynomial
"""
encoded, roots_of_unity = zip(*((point, root) for point, root in zip(encoded, roots_of_unity) if point is not None))
coefs = Polynomial.interpolate(list(map(int, encoded)), list(map(int, roots_of_unity)))[:original_len]
return Polynomial([int(c) for c in coefs], BLS_MODULUS)

14
da/kzg_rs/test_fft.py Normal file

@ -0,0 +1,14 @@
from unittest import TestCase
from .roots import compute_roots_of_unity
from .common import BLS_MODULUS
from .fft import fft, ifft
class TestFFT(TestCase):
def test_fft_ifft(self):
for size in [16, 32, 64, 128, 256, 512, 1024, 2048, 4096]:
roots_of_unity = compute_roots_of_unity(2, size, BLS_MODULUS)
vals = list(x for x in range(size))
vals_fft = fft(vals, roots_of_unity, BLS_MODULUS)
self.assertEqual(vals, ifft(vals_fft, roots_of_unity, BLS_MODULUS))

28
da/kzg_rs/test_fk20.py Normal file

@ -0,0 +1,28 @@
from itertools import chain
from unittest import TestCase
import random
from .fk20 import fk20_generate_proofs
from .kzg import generate_element_proof, bytes_to_polynomial
from .common import BLS_MODULUS, BYTES_PER_FIELD_ELEMENT, GLOBAL_PARAMETERS, PRIMITIVE_ROOT
from .roots import compute_roots_of_unity
class TestFK20(TestCase):
@staticmethod
def rand_bytes(n_chunks=1024):
return bytes(
chain.from_iterable(
int.to_bytes(random.randrange(BLS_MODULUS), length=BYTES_PER_FIELD_ELEMENT)
for _ in range(n_chunks)
)
)
def test_fk20(self):
for size in [16, 32, 64, 128, 256]:
roots_of_unity = compute_roots_of_unity(PRIMITIVE_ROOT, size, BLS_MODULUS)
rand_bytes = self.rand_bytes(size)
polynomial = bytes_to_polynomial(rand_bytes)
proofs = [generate_element_proof(i, polynomial, GLOBAL_PARAMETERS, roots_of_unity) for i in range(size)]
fk20_proofs = fk20_generate_proofs(polynomial, GLOBAL_PARAMETERS)
self.assertEqual(len(proofs), len(fk20_proofs))
self.assertEqual(proofs, fk20_proofs)

67
da/kzg_rs/test_kzg.py Normal file

@ -0,0 +1,67 @@
from itertools import chain, batched
from random import randrange
from unittest import TestCase
from eth2spec.deneb.mainnet import BLS_MODULUS, bytes_to_bls_field, BLSFieldElement
from da.kzg_rs import kzg
from da.kzg_rs.common import BYTES_PER_FIELD_ELEMENT, GLOBAL_PARAMETERS, ROOTS_OF_UNITY, GLOBAL_PARAMETERS_G2
from da.kzg_rs.trusted_setup import verify_setup
class TestKZG(TestCase):
@staticmethod
def rand_bytes(n_chunks=1024):
return bytes(
chain.from_iterable(
int.to_bytes(randrange(BLS_MODULUS), length=BYTES_PER_FIELD_ELEMENT)
for _ in range(n_chunks)
)
)
def test_kzg_setup(self):
self.assertTrue(verify_setup((GLOBAL_PARAMETERS, GLOBAL_PARAMETERS_G2)))
def test_poly_forms(self):
n_chunks = 16
rand_bytes = self.rand_bytes(n_chunks)
eval_form = [int(bytes_to_bls_field(b)) for b in batched(rand_bytes, int(BYTES_PER_FIELD_ELEMENT))]
poly = kzg.bytes_to_polynomial(rand_bytes)
self.assertEqual(poly.evaluation_form(), eval_form)
for i, chunk in enumerate(eval_form):
self.assertEqual(poly.eval(ROOTS_OF_UNITY[i]), chunk)
for i in range(n_chunks):
self.assertEqual(poly.evaluation_form()[i], poly.eval(int(ROOTS_OF_UNITY[i])))
def test_commitment(self):
rand_bytes = self.rand_bytes(32)
_, commit = kzg.bytes_to_commitment(rand_bytes, GLOBAL_PARAMETERS)
self.assertEqual(len(commit), 48)
def test_proof(self):
rand_bytes = self.rand_bytes(2)
poly = kzg.bytes_to_polynomial(rand_bytes)
proof = kzg.generate_element_proof(0, poly, GLOBAL_PARAMETERS, ROOTS_OF_UNITY)
self.assertEqual(len(proof), 48)
def test_verify(self):
n_chunks = 32
rand_bytes = self.rand_bytes(n_chunks)
_, commit = kzg.bytes_to_commitment(rand_bytes, GLOBAL_PARAMETERS)
poly = kzg.bytes_to_polynomial(rand_bytes)
for i, chunk in enumerate(batched(rand_bytes, BYTES_PER_FIELD_ELEMENT)):
chunk = bytes(chunk)
proof = kzg.generate_element_proof(i, poly, GLOBAL_PARAMETERS, ROOTS_OF_UNITY)
self.assertEqual(len(proof), 48)
self.assertEqual(poly.eval(int(ROOTS_OF_UNITY[i])), bytes_to_bls_field(chunk))
self.assertTrue(kzg.verify_element_proof(
bytes_to_bls_field(chunk), commit, proof, i, ROOTS_OF_UNITY
)
)
proof = kzg.generate_element_proof(0, poly, GLOBAL_PARAMETERS, ROOTS_OF_UNITY)
for n in range(1, n_chunks):
self.assertFalse(kzg.verify_element_proof(
BLSFieldElement(0), commit, proof, n, ROOTS_OF_UNITY
)
)

18
da/kzg_rs/test_rs.py Normal file

@ -0,0 +1,18 @@
from unittest import TestCase
from da.kzg_rs.common import BLS_MODULUS, ROOTS_OF_UNITY
from da.kzg_rs.poly import Polynomial
from da.kzg_rs.rs import encode, decode
class TestRS(TestCase):
def test_encode_decode(self):
poly = Polynomial(list(range(10)), modulus=BLS_MODULUS)
encoded = encode(poly, 2, ROOTS_OF_UNITY)
# remove a few points, but enough so we can reconstruct
for i in [1, 3, 7]:
encoded[i] = None
decoded = decode(encoded, ROOTS_OF_UNITY, len(poly))
# self.assertEqual(poly, decoded)
for i in range(len(poly)):
self.assertEqual(poly.eval(ROOTS_OF_UNITY[i]), decoded.eval(ROOTS_OF_UNITY[i]))


@ -0,0 +1,48 @@
import random
from typing import Tuple, Sequence, Generator
from eth2spec.utils import bls
from itertools import accumulate, repeat
def __linear_combination(points, coeffs, zero=bls.Z1()):
o = zero
for point, coeff in zip(points, coeffs):
o = bls.add(o, bls.multiply(point, coeff))
return o
# Verifies the integrity of a setup
def verify_setup(setup) -> bool:
g1_setup, g2_setup = setup
g1_random_coefficients = [random.randrange(2**40) for _ in range(len(g1_setup) - 1)]
g1_lower = __linear_combination(g1_setup[:-1], g1_random_coefficients, bls.Z1())
g1_upper = __linear_combination(g1_setup[1:], g1_random_coefficients, bls.Z1())
g2_random_coefficients = [random.randrange(2**40) for _ in range(len(g2_setup) - 1)]
g2_lower = __linear_combination(g2_setup[:-1], g2_random_coefficients, bls.Z2())
g2_upper = __linear_combination(g2_setup[1:], g2_random_coefficients, bls.Z2())
return (
g1_setup[0] == bls.G1() and
g2_setup[0] == bls.G2() and
bls.pairing_check([[g1_upper, bls.neg(g2_lower)], [g1_lower, g2_upper]])
)
def generate_one_sided_setup(length, secret, generator=bls.G1()):
def __take(gen):
return (next(gen) for _ in range(length))
secrets = repeat(secret)
return __take(accumulate(secrets, bls.multiply, initial=generator))
# Generate a trusted setup with the given secret
def generate_setup(
g1_length,
g2_length,
secret
) -> Tuple[Generator[bls.G1, None, None], Generator[bls.G2, None, None]]:
return (
generate_one_sided_setup(g1_length, secret, bls.G1()),
generate_one_sided_setup(g2_length, secret, bls.G2()),
)
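
A brief usage sketch (small illustrative sizes and a throwaway secret) for the functions above, which are imported elsewhere as `da.kzg_rs.trusted_setup`; it mirrors how `da/kzg_rs/common.py` materialises its global parameters and how they could be checked:

```python
from da.kzg_rs.trusted_setup import generate_setup, verify_setup

# throwaway secret and small sizes, for illustration only
g1_setup, g2_setup = map(list, generate_setup(16, 4, 1234))

assert len(g1_setup) == 16 and len(g2_setup) == 4
assert verify_setup((g1_setup, g2_setup))  # random linear-combination pairing check
```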

5
da/kzg_rs/utils.py Normal file

@ -0,0 +1,5 @@
POWERS_OF_2 = {2**i for i in range(1, 32)}
def is_power_of_two(n) -> bool:
return n in POWERS_OF_2

0
da/network/__init__.py Normal file

21
da/network/constants.py Normal file

@ -0,0 +1,21 @@
from libp2p.typing import TProtocol
"""
Some constants for use throughout the PoC
"""
PROTOCOL_ID = TProtocol("/nomosda/1.0.0")
MAX_READ_LEN = 2**32 - 1
HASH_LENGTH = 256
NODE_PORT_BASE = 7560
EXECUTOR_PORT = 8766
# These can be overridden with cli params
DEFAULT_DATA_SIZE = 1024
DEFAULT_SUBNETS = 256
DEFAULT_NODES = 32
DEFAULT_SAMPLE_THRESHOLD = 12
# minimum number of nodes per subnet
DEFAULT_REPLICATION_FACTOR = 4
DEBUG = False


@ -0,0 +1,76 @@
# Zone Executor to Nomos DA Communication
Protocol for communication between the Zone Executor and Nomos DA using Protocol Buffers (protobuf).
## Overview
The protocol defines messages used to request and respond to data dispersal, sampling operations, and session control within the Nomos DA system. The communication involves the exchange of blobs (binary large objects) and error handling for various operations.
## Messages
### Blob
- **Blob**: Represents the binary data to be dispersed.
- `bytes blob_id`: Unique identifier for the blob.
- `bytes data`: The binary data of the blob.
### Error Handling
- **DispersalErr**: Represents errors related to dispersal operations.
- `bytes blob_id`: Unique identifier of the blob related to the error.
- `enum DispersalErrType`: Enumeration of dispersal error types.
- `CHUNK_SIZE`: Error due to incorrect chunk size.
- `VERIFICATION`: Error due to verification failure.
- `string err_description`: Description of the error.
- **SampleErr**: Represents errors related to sample operations.
- `bytes blob_id`: Unique identifier of the blob related to the error.
- `enum SampleErrType`: Enumeration of sample error types.
- `NOT_FOUND`: Error when a blob is not found.
- `string err_description`: Description of the error.
### Dispersal
- **DispersalReq**: Request message for dispersing a blob.
- `Blob blob`: The blob to be dispersed.
- **DispersalRes**: Response message for a dispersal request.
- `oneof message_type`: Contains either a success response or an error.
- `bytes blob_id`: Unique identifier of the dispersed blob.
- `DispersalErr err`: Error occurred during dispersal.
### Sample
- **SampleReq**: Request message for sampling a blob.
- `bytes blob_id`: Unique identifier of the blob to be sampled.
- **SampleRes**: Response message for a sample request.
- `oneof message_type`: Contains either a success response or an error.
- `Blob blob`: The sampled blob.
- `SampleErr err`: Error occurred during sampling.
### Session Control
- **CloseMsg**: Message to close a session with a reason.
- `enum CloseReason`: Enumeration of close reasons.
- `GRACEFUL_SHUTDOWN`: Graceful shutdown of the session.
- `SUBNET_CHANGE`: Change in the subnet.
- `SUBNET_SAMPLE_FAIL`: Subnet sample failure.
- `CloseReason reason`: Reason for closing the session.
- **SessionReq**: Request message for session control.
- `oneof message_type`: Contains one of the following message types.
- `CloseMsg close_msg`: Message to close the session.
### DispersalMessage
- **DispersalMessage**: Wrapper message for different types of dispersal and sampling messages.
- `oneof message_type`: Contains one of the following message types.
- `DispersalReq dispersal_req`: Dispersal request.
- `DispersalRes dispersal_res`: Dispersal response.
- `SampleReq sample_req`: Sample request.
- `SampleRes sample_res`: Sample response.
## Protobuf
To generate the updated protobuf serializer from `dispersal.proto`, run the following command:
```bash
protoc --python_out=. dispersal.proto
```
This will generate the necessary Python code to serialize and deserialize the messages defined in the `dispersal.proto` file.
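
For a rough sense of how the generated module is used directly (a sketch only; the exact import path depends on where the generated file lives, and the helpers in `proto.py` wrap this together with length-prefix framing), a dispersal request can be built, serialized and parsed back like this:

```python
import dispersal_pb2  # module produced by the protoc command above

# build a DispersalMessage carrying a DispersalReq
blob = dispersal_pb2.Blob(blob_id=b"example_blob_id", data=b"example_payload")
request = dispersal_pb2.DispersalMessage(
    dispersal_req=dispersal_pb2.DispersalReq(blob=blob)
)

# serialize for the wire and parse it back
wire_bytes = request.SerializeToString()
parsed = dispersal_pb2.DispersalMessage()
parsed.ParseFromString(wire_bytes)
assert parsed.HasField("dispersal_req")
assert parsed.dispersal_req.blob.blob_id == b"example_blob_id"
```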


@ -0,0 +1,87 @@
syntax = "proto3";
package nomos.da.dispersal.v1;
message Blob {
bytes blob_id = 1;
bytes data = 2;
}
// DISPERSAL
message DispersalErr {
bytes blob_id = 1;
enum DispersalErrType {
CHUNK_SIZE = 0;
VERIFICATION = 1;
}
DispersalErrType err_type = 2;
string err_description = 3;
}
message DispersalReq {
Blob blob = 1;
}
message DispersalRes {
oneof message_type {
bytes blob_id = 1;
DispersalErr err = 2;
}
}
// SAMPLING
message SampleErr {
bytes blob_id = 1;
enum SampleErrType {
NOT_FOUND = 0;
}
SampleErrType err_type = 2;
string err_description = 3;
}
message SampleReq {
bytes blob_id = 1;
}
message SampleRes {
oneof message_type {
Blob blob = 1;
SampleErr err = 2;
}
}
// SESSION CONTROL
message CloseMsg {
enum CloseReason {
GRACEFUL_SHUTDOWN = 0;
SUBNET_CHANGE = 1;
SUBNET_SAMPLE_FAIL = 2;
}
CloseReason reason = 1;
}
message SessionReq {
oneof message_type {
CloseMsg close_msg = 1;
}
}
// WRAPPER MESSAGE
message DispersalMessage {
oneof message_type {
DispersalReq dispersal_req = 1;
DispersalRes dispersal_res = 2;
SampleReq sample_req = 3;
SampleRes sample_res = 4;
SessionReq session_req = 5;
}
}


@ -0,0 +1,60 @@
# -*- coding: utf-8 -*-
# Generated by the protocol buffer compiler. DO NOT EDIT!
# NO CHECKED-IN PROTOBUF GENCODE
# source: dispersal.proto
# Protobuf Python Version: 5.27.1
"""Generated protocol buffer code."""
from google.protobuf import descriptor as _descriptor
from google.protobuf import descriptor_pool as _descriptor_pool
from google.protobuf import runtime_version as _runtime_version
from google.protobuf import symbol_database as _symbol_database
from google.protobuf.internal import builder as _builder
_runtime_version.ValidateProtobufRuntimeVersion(
_runtime_version.Domain.PUBLIC,
5,
27,
1,
'',
'dispersal.proto'
)
# @@protoc_insertion_point(imports)
_sym_db = _symbol_database.Default()
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x0f\x64ispersal.proto\x12\x15nomos.da.dispersal.v1\"%\n\x04\x42lob\x12\x0f\n\x07\x62lob_id\x18\x01 \x01(\x0c\x12\x0c\n\x04\x64\x61ta\x18\x02 \x01(\x0c\"\xb6\x01\n\x0c\x44ispersalErr\x12\x0f\n\x07\x62lob_id\x18\x01 \x01(\x0c\x12\x46\n\x08\x65rr_type\x18\x02 \x01(\x0e\x32\x34.nomos.da.dispersal.v1.DispersalErr.DispersalErrType\x12\x17\n\x0f\x65rr_description\x18\x03 \x01(\t\"4\n\x10\x44ispersalErrType\x12\x0e\n\nCHUNK_SIZE\x10\x00\x12\x10\n\x0cVERIFICATION\x10\x01\"9\n\x0c\x44ispersalReq\x12)\n\x04\x62lob\x18\x01 \x01(\x0b\x32\x1b.nomos.da.dispersal.v1.Blob\"e\n\x0c\x44ispersalRes\x12\x11\n\x07\x62lob_id\x18\x01 \x01(\x0cH\x00\x12\x32\n\x03\x65rr\x18\x02 \x01(\x0b\x32#.nomos.da.dispersal.v1.DispersalErrH\x00\x42\x0e\n\x0cmessage_type\"\x97\x01\n\tSampleErr\x12\x0f\n\x07\x62lob_id\x18\x01 \x01(\x0c\x12@\n\x08\x65rr_type\x18\x02 \x01(\x0e\x32..nomos.da.dispersal.v1.SampleErr.SampleErrType\x12\x17\n\x0f\x65rr_description\x18\x03 \x01(\t\"\x1e\n\rSampleErrType\x12\r\n\tNOT_FOUND\x10\x00\"\x1c\n\tSampleReq\x12\x0f\n\x07\x62lob_id\x18\x01 \x01(\x0c\"y\n\tSampleRes\x12+\n\x04\x62lob\x18\x01 \x01(\x0b\x32\x1b.nomos.da.dispersal.v1.BlobH\x00\x12/\n\x03\x65rr\x18\x02 \x01(\x0b\x32 .nomos.da.dispersal.v1.SampleErrH\x00\x42\x0e\n\x0cmessage_type\"\x98\x01\n\x08\x43loseMsg\x12;\n\x06reason\x18\x01 \x01(\x0e\x32+.nomos.da.dispersal.v1.CloseMsg.CloseReason\"O\n\x0b\x43loseReason\x12\x15\n\x11GRACEFUL_SHUTDOWN\x10\x00\x12\x11\n\rSUBNET_CHANGE\x10\x01\x12\x16\n\x12SUBNET_SAMPLE_FAIL\x10\x02\"R\n\nSessionReq\x12\x34\n\tclose_msg\x18\x01 \x01(\x0b\x32\x1f.nomos.da.dispersal.v1.CloseMsgH\x00\x42\x0e\n\x0cmessage_type\"\xc8\x02\n\x10\x44ispersalMessage\x12<\n\rdispersal_req\x18\x01 \x01(\x0b\x32#.nomos.da.dispersal.v1.DispersalReqH\x00\x12<\n\rdispersal_res\x18\x02 \x01(\x0b\x32#.nomos.da.dispersal.v1.DispersalResH\x00\x12\x36\n\nsample_req\x18\x03 \x01(\x0b\x32 .nomos.da.dispersal.v1.SampleReqH\x00\x12\x36\n\nsample_res\x18\x04 \x01(\x0b\x32 .nomos.da.dispersal.v1.SampleResH\x00\x12\x38\n\x0bsession_req\x18\x05 \x01(\x0b\x32!.nomos.da.dispersal.v1.SessionReqH\x00\x42\x0e\n\x0cmessage_typeb\x06proto3')
_globals = globals()
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'dispersal_pb2', _globals)
if not _descriptor._USE_C_DESCRIPTORS:
DESCRIPTOR._loaded_options = None
_globals['_BLOB']._serialized_start=42
_globals['_BLOB']._serialized_end=79
_globals['_DISPERSALERR']._serialized_start=82
_globals['_DISPERSALERR']._serialized_end=264
_globals['_DISPERSALERR_DISPERSALERRTYPE']._serialized_start=212
_globals['_DISPERSALERR_DISPERSALERRTYPE']._serialized_end=264
_globals['_DISPERSALREQ']._serialized_start=266
_globals['_DISPERSALREQ']._serialized_end=323
_globals['_DISPERSALRES']._serialized_start=325
_globals['_DISPERSALRES']._serialized_end=426
_globals['_SAMPLEERR']._serialized_start=429
_globals['_SAMPLEERR']._serialized_end=580
_globals['_SAMPLEERR_SAMPLEERRTYPE']._serialized_start=550
_globals['_SAMPLEERR_SAMPLEERRTYPE']._serialized_end=580
_globals['_SAMPLEREQ']._serialized_start=582
_globals['_SAMPLEREQ']._serialized_end=610
_globals['_SAMPLERES']._serialized_start=612
_globals['_SAMPLERES']._serialized_end=733
_globals['_CLOSEMSG']._serialized_start=736
_globals['_CLOSEMSG']._serialized_end=888
_globals['_CLOSEMSG_CLOSEREASON']._serialized_start=809
_globals['_CLOSEMSG_CLOSEREASON']._serialized_end=888
_globals['_SESSIONREQ']._serialized_start=890
_globals['_SESSIONREQ']._serialized_end=972
_globals['_DISPERSALMESSAGE']._serialized_start=975
_globals['_DISPERSALMESSAGE']._serialized_end=1303
# @@protoc_insertion_point(module_scope)


@ -0,0 +1,123 @@
import asyncio
import argparse
import proto
from itertools import count
conn_id_counter = count(start=1)
class MockTransport:
def __init__(self, conn_id, reader, writer, handler):
self.conn_id = conn_id
self.reader = reader
self.writer = writer
self.handler = handler
async def read_and_process(self):
try:
while True:
message = await proto.unpack_from_reader(self.reader)
await self.handler(self.conn_id, self.writer, message)
except Exception as e:
print(f"MockTransport: An error occurred: {e}")
finally:
self.writer.close()
await self.writer.wait_closed()
async def write(self, message):
self.writer.write(message)
await self.writer.drain()
class MockNode:
def __init__(self, addr, port, handler=None):
self.addr = addr
self.port = port
self.handler = handler if handler else self._handle
async def _on_conn(self, reader, writer):
conn_id = next(conn_id_counter)
transport = MockTransport(conn_id, reader, writer, self.handler)
await transport.read_and_process()
async def _handle(self, conn_id, writer, message):
if message.HasField('dispersal_req'):
blob_id = message.dispersal_req.blob.blob_id
data = message.dispersal_req.blob.data
print(f"MockNode: Received DispersalRes: blob_id={blob_id}; data={data}")
# Imitate succesful verification.
writer.write(proto.new_dispersal_res_success_msg(blob_id))
elif message.HasField('sample_req'):
print(f"MockNode: Received SampleRes: blob_id={message.sample_req.blob_id}")
else:
print(f"MockNode: Received unknown message: {message} ")
async def run(self):
server = await asyncio.start_server(
self._on_conn, self.addr, self.port
)
print(f"MockNode: Server started at {self.addr}:{self.port}")
async with server:
await server.serve_forever()
class MockExecutor:
def __init__(self, addr, port, col_num, executor=None, handler=None):
self.addr = addr
self.port = port
self.col_num = col_num
self.connections = []
self.interval = 10
self.executor = executor if executor else self._execute
self.handler = handler if handler else self._handle
async def _execute(self):
message = proto.new_dispersal_req_msg(b"dummy_blob_id", b"dummy_data")
while True:
try:
await asyncio.gather(*[t.write(message) for t in self.connections])
await asyncio.sleep(self.interval)
except asyncio.CancelledError:
break
except Exception as e:
print(f"MockExecutor: Error during message sending: {e}")
async def _handle(self, conn_id, writer, message):
if message.HasField('dispersal_res'):
print(f"MockExecutor: Received DispersalRes: blob_id={message.dispersal_res.blob_id}")
elif message.HasField('sample_res'):
print(f"MockExecutor: Received SampleRes: blob_id={message.sample_res.blob_id}")
else:
print(f"MockExecutor: Received unknown message: {message}")
async def _connect(self):
try:
reader, writer = await asyncio.open_connection(self.addr, self.port)
conn_id = len(self.connections)
transport = MockTransport(conn_id, reader, writer, self.handler)
self.connections.append(transport)
print(f"MockExecutor: Connected to {self.addr}:{self.port}, ID: {conn_id}")
asyncio.create_task(transport.read_and_process())
except Exception as e:
print(f"MockExecutor: Failed to connect or lost connection: {e}")
async def run(self):
await asyncio.gather(*(self._connect() for _ in range(self.col_num)))
await self.executor()
class MockSystem:
def __init__(self, addr='localhost'):
self.addr = addr
async def run_node_with_executor(self, col_number):
node = MockNode(self.addr, 8888)
executor = MockExecutor(self.addr, 8888, col_number)
await asyncio.gather(node.run(), executor.run())
def main():
app = MockSystem()
asyncio.run(app.run_node_with_executor(1))
if __name__ == '__main__':
main()


@ -0,0 +1,126 @@
from itertools import count
import dispersal.dispersal_pb2 as dispersal_pb2
MAX_MSG_LEN_BYTES = 2
def pack_message(message):
# SerializeToString method returns an instance of bytes.
data = message.SerializeToString()
length_prefix = len(data).to_bytes(MAX_MSG_LEN_BYTES, byteorder="big")
return length_prefix + data
async def unpack_from_reader(reader):
length_prefix = await reader.readexactly(MAX_MSG_LEN_BYTES)
data_length = int.from_bytes(length_prefix, byteorder="big")
data = await reader.readexactly(data_length)
return parse(data)
def unpack_from_bytes(data):
length_prefix = data[:MAX_MSG_LEN_BYTES]
data_length = int.from_bytes(length_prefix, byteorder="big")
return parse(data[MAX_MSG_LEN_BYTES : MAX_MSG_LEN_BYTES + data_length])
def parse(data):
message = dispersal_pb2.DispersalMessage()
message.ParseFromString(data)
return message
# DISPERSAL
def new_dispersal_req_msg(blob_id, data):
blob = dispersal_pb2.Blob(blob_id=blob_id, data=data)
dispersal_req = dispersal_pb2.DispersalReq(blob=blob)
dispersal_message = dispersal_pb2.DispersalMessage(dispersal_req=dispersal_req)
return pack_message(dispersal_message)
def new_dispersal_res_success_msg(blob_id):
dispersal_res = dispersal_pb2.DispersalRes(blob_id=blob_id)
dispersal_message = dispersal_pb2.DispersalMessage(dispersal_res=dispersal_res)
return pack_message(dispersal_message)
def new_dispersal_res_chunk_size_error_msg(blob_id, description):
dispersal_err = dispersal_pb2.DispersalErr(
blob_id=blob_id,
err_type=dispersal_pb2.DispersalErr.CHUNK_SIZE,
err_description=description,
)
dispersal_res = dispersal_pb2.DispersalRes(err=dispersal_err)
dispersal_message = dispersal_pb2.DispersalMessage(dispersal_res=dispersal_res)
return pack_message(dispersal_message)
def new_dispersal_res_verification_error_msg(blob_id, description):
dispersal_err = dispersal_pb2.DispersalErr(
blob_id=blob_id,
err_type=dispersal_pb2.DispersalErr.VERIFICATION,
err_description=description,
)
dispersal_res = dispersal_pb2.DispersalRes(err=dispersal_err)
dispersal_message = dispersal_pb2.DispersalMessage(dispersal_res=dispersal_res)
return pack_message(dispersal_message)
# SAMPLING
def new_sample_req_msg(blob_id):
sample_req = dispersal_pb2.SampleReq(blob_id=blob_id)
dispersal_message = dispersal_pb2.DispersalMessage(sample_req=sample_req)
return pack_message(dispersal_message)
def new_sample_res_success_msg(blob_id, data):
blob = dispersal_pb2.Blob(blob_id=blob_id, data=data)
sample_res = dispersal_pb2.SampleRes(blob=blob)
dispersal_message = dispersal_pb2.DispersalMessage(sample_res=sample_res)
return pack_message(dispersal_message)
def new_sample_res_not_found_error_msg(blob_id, description):
sample_err = dispersal_pb2.SampleErr(
blob_id=blob_id,
err_type=dispersal_pb2.SampleErr.NOT_FOUND,
err_description=description,
)
sample_res = dispersal_pb2.SampleRes(err=sample_err)
dispersal_message = dispersal_pb2.DispersalMessage(sample_res=sample_res)
return pack_message(dispersal_message)
# SESSION CONTROL
def new_close_msg(reason):
close_msg = dispersal_pb2.CloseMsg(reason=reason)
return close_msg
def new_session_req_close_msg(reason):
close_msg = new_close_msg(reason)
session_req = dispersal_pb2.SessionReq(close_msg=close_msg)
dispersal_message = dispersal_pb2.DispersalMessage(session_req=session_req)
return dispersal_message
def new_session_req_graceful_shutdown_msg():
message = new_session_req_close_msg(dispersal_pb2.CloseMsg.GRACEFUL_SHUTDOWN)
return pack_message(message)
def new_session_req_subnet_change_msg():
message = new_session_req_close_msg(dispersal_pb2.CloseMsg.SUBNET_CHANGE)
return pack_message(message)
def new_session_req_subnet_sample_fail_msg():
message = new_session_req_close_msg(dispersal_pb2.CloseMsg.SUBNET_SAMPLE_FAIL)
return pack_message(message)


@ -0,0 +1,75 @@
import dispersal_pb2
import proto
from unittest import TestCase
class TestMessageSerialization(TestCase):
def test_dispersal_req_msg(self):
blob_id = b"dummy_blob_id"
data = b"dummy_data"
packed_message = proto.new_dispersal_req_msg(blob_id, data)
message = proto.unpack_from_bytes(packed_message)
self.assertTrue(message.HasField('dispersal_req'))
self.assertEqual(message.dispersal_req.blob.blob_id, blob_id)
self.assertEqual(message.dispersal_req.blob.data, data)
def test_dispersal_res_success_msg(self):
blob_id = b"dummy_blob_id"
packed_message = proto.new_dispersal_res_success_msg(blob_id)
message = proto.unpack_from_bytes(packed_message)
self.assertTrue(message.HasField('dispersal_res'))
self.assertEqual(message.dispersal_res.blob_id, blob_id)
def test_dispersal_res_chunk_size_error_msg(self):
blob_id = b"dummy_blob_id"
description = "Chunk size error"
packed_message = proto.new_dispersal_res_chunk_size_error_msg(blob_id, description)
message = proto.unpack_from_bytes(packed_message)
self.assertTrue(message.HasField('dispersal_res'))
self.assertEqual(message.dispersal_res.err.blob_id, blob_id)
self.assertEqual(message.dispersal_res.err.err_type, dispersal_pb2.DispersalErr.CHUNK_SIZE)
self.assertEqual(message.dispersal_res.err.err_description, description)
def test_dispersal_res_verification_error_msg(self):
blob_id = b"dummy_blob_id"
description = "Verification error"
packed_message = proto.new_dispersal_res_verification_error_msg(blob_id, description)
message = proto.unpack_from_bytes(packed_message)
self.assertTrue(message.HasField('dispersal_res'))
self.assertEqual(message.dispersal_res.err.blob_id, blob_id)
self.assertEqual(message.dispersal_res.err.err_type, dispersal_pb2.DispersalErr.VERIFICATION)
self.assertEqual(message.dispersal_res.err.err_description, description)
def test_sample_req_msg(self):
blob_id = b"dummy_blob_id"
packed_message = proto.new_sample_req_msg(blob_id)
message = proto.unpack_from_bytes(packed_message)
self.assertTrue(message.HasField('sample_req'))
self.assertEqual(message.sample_req.blob_id, blob_id)
def test_sample_res_success_msg(self):
blob_id = b"dummy_blob_id"
data = b"dummy_data"
packed_message = proto.new_sample_res_success_msg(blob_id, data)
message = proto.unpack_from_bytes(packed_message)
self.assertTrue(message.HasField('sample_res'))
self.assertEqual(message.sample_res.blob.blob_id, blob_id)
self.assertEqual(message.sample_res.blob.data, data)
def test_sample_res_not_found_error_msg(self):
blob_id = b"dummy_blob_id"
description = "Blob not found"
packed_message = proto.new_sample_res_not_found_error_msg(blob_id, description)
message = proto.unpack_from_bytes(packed_message)
self.assertTrue(message.HasField('sample_res'))
self.assertEqual(message.sample_res.err.blob_id, blob_id)
self.assertEqual(message.sample_res.err.err_type, dispersal_pb2.SampleErr.NOT_FOUND)
self.assertEqual(message.sample_res.err.err_description, description)
def test_session_req_close_msg(self):
reason = dispersal_pb2.CloseMsg.GRACEFUL_SHUTDOWN
packed_message = proto.new_session_req_graceful_shutdown_msg()
message = proto.unpack_from_bytes(packed_message)
self.assertTrue(message.HasField('session_req'))
self.assertTrue(message.session_req.HasField('close_msg'))
self.assertEqual(message.session_req.close_msg.reason, reason)

106
da/network/executor.py Normal file

@ -0,0 +1,106 @@
from hashlib import sha256
from random import randbytes
from typing import Self
import dispersal.proto as proto
import multiaddr
import trio
from constants import HASH_LENGTH, PROTOCOL_ID
from libp2p import host, new_host
from libp2p.network.stream.net_stream_interface import INetStream
from libp2p.peer.peerinfo import info_from_p2p_addr
class Executor:
"""
A class for simulating a simple executor.
Runs on a hardcoded port.
Creates random data and disperses it.
One packet represents a subnet, and each packet is sent
to one DANode.
"""
listen_addr: multiaddr.Multiaddr
host: host
port: int
num_subnets: int
node_list: {}
# size of a packet
data_size: int
# holds random data for dispersal
data: []
# stores hashes of the data for later verification
data_hashes: []
blob_id: int
@classmethod
def new(cls, port, node_list, num_subnets, data_size) -> Self:
self = cls()
self.listen_addr = multiaddr.Multiaddr(f"/ip4/0.0.0.0/tcp/{port}")
self.host = new_host()
self.port = port
self.num_subnets = num_subnets
self.data_size = data_size
# one packet per subnet
self.data = [[] * data_size] * num_subnets
# one hash per packet. **assumes 256 hash length**
self.data_hashes = [[] * HASH_LENGTH] * num_subnets
self.node_list = node_list
# create random simulated data right from the beginning
self.__create_data()
return self
def get_id(self):
return self.host.get_id()
def net_iface(self):
return self.host
def get_port(self):
return self.port
def get_hash(self, index: int):
return self.data_hashes[index]
def __create_data(self):
"""
Create random data for dispersal
One packet of self.data_size length per subnet
"""
id = sha256()
for i in range(self.num_subnets):
self.data[i] = randbytes(self.data_size)
self.data_hashes[i] = sha256(self.data[i]).hexdigest()
id.update(self.data[i])
self.blob_id = id.digest()
async def disperse(self, nursery):
"""
Disperse the data to the DA network.
Sends one packet of data to the first node of each subnet
"""
async with self.host.run(listen_addrs=[self.listen_addr]):
for subnet, nodes in self.node_list.items():
# get first node of each subnet
n = nodes[0]
# connect to it...
await self.host.connect(n)
# ...and send (async)
stream = await self.host.new_stream(n.peer_id, [PROTOCOL_ID])
nursery.start_soon(self.write_data, stream, subnet)
async def write_data(self, stream: INetStream, index: int) -> None:
"""
Send data to peer (async)
The index is the subnet number
"""
blob_id = self.blob_id
blob_data = self.data[index]
message = proto.new_dispersal_req_msg(blob_id, blob_data)
await stream.write(message)

35
da/network/network.py Normal file
View File

@ -0,0 +1,35 @@
import trio
from constants import DEBUG, NODE_PORT_BASE
from node import DANode
class DANetwork:
"""
Lightweight wrapper around a network of DA nodes.
Really just creates the network for now
"""
num_nodes: int
    nodes: list
    def __init__(self, num_nodes):
        self.num_nodes = num_nodes
self.nodes = []
async def build(self, nursery, shutdown, disperse_send):
port_idx = NODE_PORT_BASE
for _ in range(self.num_nodes):
port_idx += 1
nursery.start_soon(
DANode.new,
port_idx,
self.nodes,
nursery,
shutdown,
disperse_send.clone(),
)
if DEBUG:
print("net built")
def get_nodes(self):
return self.nodes

129
da/network/node.py Normal file
View File

@ -0,0 +1,129 @@
import sys
from hashlib import sha256
from random import randint
import dispersal.proto as proto
import multiaddr
import trio
from blspy import BasicSchemeMPL, G1Element, PrivateKey
from constants import *
from libp2p import host, new_host
from libp2p.network.stream.exceptions import StreamReset
from libp2p.network.stream.net_stream_interface import INetStream
from libp2p.peer.peerinfo import info_from_p2p_addr
class DANode:
"""
A class handling Data Availability (DA).
Runs on a hardcoded port.
Starts a libp2p node.
Listens on a handler for receiving data.
Resends all data it receives to all peers it is connected to
(therefore assumes connection logic is established elsewhere)
"""
listen_addr: multiaddr.Multiaddr
libp2phost: host
port: int
    node_list: list
    # set of packet hashes it "stores"
    hashes: set
@classmethod
async def new(cls, port, node_list, nursery, shutdown, disperse_send):
self = cls()
self.listen_addr = multiaddr.Multiaddr(f"/ip4/0.0.0.0/tcp/{port}")
self.libp2phost = new_host()
self.port = port
self.node_list = node_list
self.hashes = set()
nursery.start_soon(self.__run, nursery, shutdown, disperse_send)
if DEBUG:
print("DA node at port {} initialized".format(port))
def get_id(self):
return self.libp2phost.get_id()
def net_iface(self):
return self.libp2phost
def get_port(self):
return self.port
async def __run(self, nursery, shutdown, disperse_send):
"""
Run the node. Starts libp2p host, and listener for data
"""
async with self.libp2phost.run(listen_addrs=[self.listen_addr]):
print("started node at {}...".format(self.listen_addr))
# handler to run when data is received
async def stream_handler(stream: INetStream) -> None:
nursery.start_soon(
self.read_data, stream, nursery, shutdown, disperse_send
)
# set the above handler
self.libp2phost.set_stream_handler(PROTOCOL_ID, stream_handler)
# at this point the node is "initialized" - signal it's "ready"
self.node_list.append(self)
# run until we shutdown
await shutdown.wait()
async def read_data(
self, stream: INetStream, nursery, shutdown, disperse_send
) -> None:
"""
        We need to wait for incoming data, but we also want to shut down
when the test is finished.
The following code makes sure that both events are listened to
and the first which occurs is handled.
"""
first_event = None
async def select_event(async_fn, cancel_scope):
nonlocal first_event
first_event = await async_fn()
cancel_scope.cancel()
disperse_send.close()
async def read_stream():
while True:
read_bytes = await stream.read(MAX_READ_LEN)
if read_bytes is not None:
message = proto.unpack_from_bytes(read_bytes)
hashstr = sha256(message.dispersal_req.blob.data).hexdigest()
if hashstr not in self.hashes:
# "store" the received packet
self.hashes.add(hashstr)
# now disperse this hash to all peers
nursery.start_soon(self.disperse, read_bytes, disperse_send)
if DEBUG:
print(
"{} stored {}".format(
self.libp2phost.get_id().pretty(), hashstr
)
)
await disperse_send.send(-1)
else:
print("read_bytes is None, unexpected!")
nursery.start_soon(select_event, read_stream, nursery.cancel_scope)
nursery.start_soon(select_event, shutdown.wait, nursery.cancel_scope)
async def disperse(self, packet, disperse_send) -> None:
# disperse the given packet to all peers
for p_id in self.libp2phost.get_peerstore().peer_ids():
if p_id == self.libp2phost.get_id():
continue
await disperse_send.send(1)
stream = await self.libp2phost.new_stream(p_id, [PROTOCOL_ID])
await stream.write(packet)
async def has_hash(self, hashstr: str):
return hashstr in self.hashes

271
da/network/poc.py Normal file
View File

@ -0,0 +1,271 @@
import argparse
import sys
import time
from random import randint
import multiaddr
import trio
from constants import *
from executor import Executor
from libp2p.peer.peerinfo import info_from_p2p_addr
from network import DANetwork
from subnet import calculate_subnets
"""
Entry point for the poc.
Handles cli arguments, initiates the network
and waits for it to complete.
Also does some simple completion check.
"""
async def run_network(params):
"""
Create the network.
Run the run_subnets
"""
num_nodes = int(params.nodes)
net = DANetwork(num_nodes)
shutdown = trio.Event()
disperse_send, disperse_recv = trio.open_memory_channel(0)
async with trio.open_nursery() as nursery:
nursery.start_soon(net.build, nursery, shutdown, disperse_send)
nursery.start_soon(
run_subnets, net, params, nursery, shutdown, disperse_send, disperse_recv
)
async def run_subnets(net, params, nursery, shutdown, disperse_send, disperse_recv):
"""
Run the actual PoC logic.
Calculate the subnets.
-> Establish connections based on the subnets <-
Runs the executor.
Runs simulated sampling.
Runs simple completion check
"""
num_nodes = int(params.nodes)
num_subnets = int(params.subnets)
data_size = int(params.data_size)
sample_threshold = int(params.sample_threshold)
fault_rate = int(params.fault_rate)
replication_factor = int(params.replication_factor)
while len(net.get_nodes()) != num_nodes:
print("nodes not ready yet")
await trio.sleep(0.1)
print("Nodes ready")
nodes = net.get_nodes()
subnets = calculate_subnets(nodes, num_subnets, replication_factor)
await print_subnet_info(subnets)
print("Establishing connections...")
node_list = {}
all_node_instances = set()
await establish_connections(subnets, node_list, all_node_instances, fault_rate)
print("Starting executor...")
exe = Executor.new(EXECUTOR_PORT, node_list, num_subnets, data_size)
print("Start dispersal and wait to complete...")
print("depending on network and subnet size this may take a while...")
global TIMESTAMP
TIMESTAMP = time.time()
async with trio.open_nursery() as subnursery:
subnursery.start_soon(wait_disperse_finished, disperse_recv, num_subnets)
subnursery.start_soon(exe.disperse, nursery)
subnursery.start_soon(disperse_watcher, disperse_send.clone())
print()
print()
print("OK. Start sampling...")
checked = []
for _ in range(sample_threshold):
nursery.start_soon(sample_node, exe, subnets, checked)
print("Waiting for sampling to finish...")
await check_complete(checked, sample_threshold)
print_connections(all_node_instances)
print("Test completed")
shutdown.set()
TIMESTAMP = time.time()
def print_connections(node_list):
for n in node_list:
for p in n.net_iface().get_peerstore().peer_ids():
if p == n.net_iface().get_id():
continue
print("node {} is connected to {}".format(n.get_id(), p))
print()
async def disperse_watcher(disperse_send):
while time.time() - TIMESTAMP < 5:
await trio.sleep(1)
await disperse_send.send(9999)
print("canceled")
async def wait_disperse_finished(disperse_recv, num_subnets):
# run until there are no changes detected
async for value in disperse_recv:
if value == 9999:
print("dispersal finished")
return
print(".", end="")
global TIMESTAMP
TIMESTAMP = time.time()
async def print_subnet_info(subnets):
"""
Print which node is in what subnet
"""
print()
print("By subnets: ")
for subnet in subnets:
print("subnet: {} - ".format(subnet), end="")
for n in subnets[subnet]:
print(n.get_id().pretty()[:16], end=", ")
print()
print()
print()
async def establish_connections(subnets, node_list, all_node_instances, fault_rate=0):
"""
Each node in a subnet connects to the other ones in that subnet.
"""
for subnet in subnets:
# n is a DANode
for n in subnets[subnet]:
# while nodes connect to each other, they are **mutually** added
# to their peer lists. Hence, we don't need to establish connections
            # again to peers we are already connected to.
# So in each iteration we get the peer list for the current node
# to later check if we are already connected with the next peer
this_nodes_peers = n.net_iface().get_peerstore().peer_ids()
all_node_instances.add(n)
faults = []
for i in range(fault_rate):
                faults.append(randint(0, len(subnets[subnet]) - 1))
for i, nn in enumerate(subnets[subnet]):
# don't connect to self
if nn.get_id() == n.get_id():
continue
if i in faults:
continue
remote_id = nn.get_id().pretty()
remote_port = nn.get_port()
# this script only works on localhost!
addr = "/ip4/127.0.0.1/tcp/{}/p2p/{}/".format(remote_port, remote_id)
remote_addr = multiaddr.Multiaddr(addr)
remote = info_from_p2p_addr(remote_addr)
if subnet not in node_list:
node_list[subnet] = []
node_list[subnet].append(remote)
# check if we are already connected with this peer. If yes, skip connecting
if nn.get_id() in this_nodes_peers:
continue
if DEBUG:
print("{} connecting to {}...".format(n.get_id(), addr))
await n.net_iface().connect(remote)
print()
async def check_complete(checked, sample_threshold):
"""
Simple completion check:
Check how many nodes have already been "sampled"
"""
while len(checked) < sample_threshold:
await trio.sleep(0.5)
print("check_complete exiting")
return
async def sample_node(exe, subnets, checked):
"""
Pick a random subnet.
Pick a random node in that subnet.
As the executor has a list of hashes per subnet,
we can ask that node if it has that hash.
"""
# s: subnet
s = randint(0, len(subnets) - 1)
# n: node (index)
n = randint(0, len(subnets[s]) - 1)
# actual node
node = subnets[s][n]
# pick the hash to check
hashstr = exe.get_hash(s)
# run the "sampling"
has = await node.has_hash(hashstr)
if has:
print("node {} has hash {}".format(node.get_id().pretty(), hashstr))
else:
print("node {} does NOT HAVE hash {}".format(node.get_id().pretty(), hashstr))
print("TEST FAILED")
# signal we "sampled" another node
checked.append(1)
return
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("-s", "--subnets", help="Number of subnets [default: 256]")
parser.add_argument("-n", "--nodes", help="Number of nodes [default: 32]")
parser.add_argument(
"-t",
"--sample-threshold",
help="Threshold for sampling request attempts [default: 12]",
)
parser.add_argument("-d", "--data-size", help="Size of packages [default: 1024]")
parser.add_argument("-f", "--fault_rate", help="Fault rate [default: 0]")
parser.add_argument(
"-r", "--replication_factor", help="Replication factor [default: 4]"
)
args = parser.parse_args()
if not args.subnets:
args.subnets = DEFAULT_SUBNETS
if not args.nodes:
args.nodes = DEFAULT_NODES
if not args.sample_threshold:
args.sample_threshold = DEFAULT_SAMPLE_THRESHOLD
if not args.data_size:
args.data_size = DEFAULT_DATA_SIZE
if not args.replication_factor:
args.replication_factor = DEFAULT_REPLICATION_FACTOR
if not args.fault_rate:
args.fault_rate = 0
print("Number of subnets will be: {}".format(args.subnets))
print("Number of nodes will be: {}".format(args.nodes))
print("Size of data package will be: {}".format(args.data_size))
print("Threshold for sampling attempts will be: {}".format(args.sample_threshold))
print("Fault rate will be: {}".format(args.fault_rate))
print()
print("*******************")
print("Starting network...")
trio.run(run_network, args)

65
da/network/readme.md Normal file
View File

@ -0,0 +1,65 @@
# Data Availability Subnets Proof-Of-Concept
## Contents
This folder contains the implementation of a Proof-Of-Concept (PoC) for the subnets designed
to address dispersal and sampling in Data Availability (DA) in Nomos.
Refer to the [Specification](https://www.notion.so/Runnable-DA-PoC-Specification-50f204f2ff0a41d09de4926962bbb4ef?d=9e9677e5536a46d49fe95f366b7c3320#308624c50f1a42769b6c142976999483)
for the details of the design of this PoC.
Being a PoC, this code has no pretensions in terms of quality and is certainly not meant to reach anywhere near production status.
## How to run
The entry point is `poc.py`, which can be run with python3.
It can be parametrized with the following options:
`python poc.py -s 512 -n 64 -t 12 -d 2048`
To understand what these parameters mean, look at the help output:
```sh
> python poc.py -h
usage: poc.py [-h] [-s SUBNETS] [-n NODES] [-t SAMPLE_THRESHOLD] [-d DATA_SIZE]
options:
-h, --help show this help message and exit
-s SUBNETS, --subnets SUBNETS
Number of subnets [default: 256]
-n NODES, --nodes NODES
Number of nodes [default: 32]
-t SAMPLE_THRESHOLD, --sample-threshold SAMPLE_THRESHOLD
Threshold for sampling request attempts [default: 12]
-d DATA_SIZE, --data-size DATA_SIZE
Size of packages [default: 1024]
```
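The script also accepts a fault rate and a replication factor (`-f`/`--fault_rate` and `-r`/`--replication_factor`, defined in `poc.py`). For example, the following invocation (values chosen purely for illustration) uses 128 subnets, 64 nodes, a replication factor of 4 and a fault rate of 1:

`python poc.py -s 128 -n 64 -r 4 -f 1`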
## What it does
The PoC first creates an instance of a light-weight `DANetwork`, which in turn
starts the configured number of nodes.
> [!NOTE]
> Currently ports are hardcoded. Nodes start at 7561 and are instantiated sequentially from there.
> The executor simulator runs on port 8766.
After the nodes are up, the subnets are calculated. Subnet calculation is explicitly **not part of the PoC**.
Therefore, the PoC uses a simple strategy: it fills all subnets sequentially, and if there are not enough nodes,
it keeps assigning nodes (wrapping around the node list) until each subnet has at least `REPLICATION_FACTOR` nodes.
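As an illustration (a sketch only, following the sequential-fill strategy implemented in `subnet.py`), 3 nodes A, B, C spread over 4 subnets with `REPLICATION_FACTOR = 2` would end up as:

```python
# illustrative only: nodes A, B, C; 4 subnets; REPLICATION_FACTOR = 2
# sequential fill (wrapping around):  {0: [A], 1: [B], 2: [C], 3: [A]}
# fill up to the replication factor:  {0: [A, B], 1: [B, C], 2: [C, A], 3: [A, B]}
```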
After nodes are assigned to subnets, the network connections (via direct libp2p links) are established.
Each node in a subnet connects with every other node in that subnet.
Next, the executor is started. It is just a simulator. It creates random data for each subnet of `DATA_SIZE` length,
simulating the columns generated by the NomosDA protocol.
It then establishes one connection per subnet and sends one packet of `DATA_SIZE` length on each of these connections.
The executor also stores a hash of each packet per subnet.
Receiving nodes then forward the packet to each of their peers in the subnet.
They store only the packet's hash (not the data itself).
Finally, a simulated check samples up to `SAMPLE_THRESHOLD` nodes.
Each sampling round picks a random subnet and a random node in it, and asks that node whether it has the hash for that subnet.
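A minimal sketch of one such sampling round follows (it mirrors `sample_node` in `poc.py`; the helper name `sample_once` is illustrative):

```python
from random import randint

async def sample_once(executor, subnets, checked):
    s = randint(0, len(subnets) - 1)                     # pick a random subnet
    node = subnets[s][randint(0, len(subnets[s]) - 1)]   # pick a random node in it
    expected_hash = executor.get_hash(s)                 # hash the executor stored for this subnet
    if not await node.has_hash(expected_hash):           # ask the node whether it "stored" the packet
        print("TEST FAILED")
    checked.append(1)                                    # signal that one more node was sampled
```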

65
da/network/subnet.py Normal file
View File

@ -0,0 +1,65 @@
from random import randint
from constants import *
def calculate_subnets(node_list, num_subnets, replication_factor):
"""
Calculate in which subnet(s) to place each node.
This PoC does NOT require this to be analyzed,
nor to find the best solution.
Hence, we just use a simple model here:
1. Iterate all nodes and place each node in the subsequent subnet
2. If the subnet list can not be filled, start again from the top of the list
    3. If a subnet does NOT have at least REPLICATION_FACTOR nodes, then
    fill it up with nodes until the factor is reached.
NOTE: This might be incomplete and/or buggy, but should be sufficient for
the purpose of the PoC.
If however, you find a bug, please report.
"""
# key of dict is the subnet number
subnets = {}
for i, n in enumerate(node_list):
idx = i % num_subnets
        # each key maps to a list, so multiple nodes can be placed
        # into a subnet
if idx not in subnets:
subnets[idx] = []
subnets[idx].append(n)
listlen = len(node_list)
i = listlen
    # if there are fewer nodes than subnets
while i < num_subnets:
subnets[i] = []
subnets[i].append(node_list[i % listlen])
i += 1
# if not each subnet has at least factor number of nodes, fill up
if listlen < replication_factor * num_subnets:
for subnet in subnets:
            last = subnets[subnet][-1].get_id()
idx = -1
# what is the last filled index of a subnet row
for j, n in enumerate(node_list):
if n.get_id() == last:
idx = j + 1
# fill up until factor
while len(subnets[subnet]) < replication_factor:
# wrap index if at end
if idx > len(node_list) - 1:
idx = 0
# don't add same node multiple times
if node_list[idx] in subnets[subnet]:
idx += 1
continue
subnets[subnet].append(node_list[idx])
idx += 1
return subnets

20
da/test_common.py Normal file
View File

@ -0,0 +1,20 @@
from unittest import TestCase
from da.common import ChunksMatrix
class TestCommon(TestCase):
def test_chunks_matrix_columns(self):
matrix = ChunksMatrix([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
expected = [[1, 4, 7], [2, 5, 8], [3, 6, 9]]
for c1, c2 in zip(expected, matrix.columns):
self.assertEqual(c1, c2)
def test_chunks_matrix_transposed(self):
matrix = ChunksMatrix([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
expected = ChunksMatrix([[1, 4, 7], [2, 5, 8], [3, 6, 9]])
self.assertEqual(matrix.transposed(), expected)
matrix = ChunksMatrix([[1, 2, 3], [4, 5, 6]])
expected = ChunksMatrix([[1, 4], [2, 5], [3, 6]])
self.assertEqual(matrix.transposed(), expected)

74
da/test_dispersal.py Normal file
View File

@ -0,0 +1,74 @@
from hashlib import sha3_256
from unittest import TestCase
from da.encoder import DAEncoderParams, DAEncoder
from da.test_encoder import TestEncoder
from da.verifier import DAVerifier, DABlob
from da.common import NodeId, Attestation, Bitfield, NomosDaG2ProofOfPossession as bls_pop
from da.dispersal import Dispersal, EncodedData, DispersalSettings
class TestDispersal(TestCase):
def setUp(self):
self.n_nodes = 16
self.nodes_ids = [NodeId(x.to_bytes(length=32, byteorder='big')) for x in range(self.n_nodes)]
self.secret_keys = list(range(1, self.n_nodes+1))
self.public_keys = [bls_pop.SkToPk(sk) for sk in self.secret_keys]
# sort by pk as we do in dispersal
self.secret_keys, self.public_keys = zip(
*sorted(zip(self.secret_keys, self.public_keys), key=lambda x: x[1])
)
dispersal_settings = DispersalSettings(
self.nodes_ids,
self.public_keys,
self.n_nodes // 2 + 1
)
self.dispersal = Dispersal(dispersal_settings)
self.encoder_test = TestEncoder()
self.encoder_test.setUp()
def test_build_certificate_insufficient_attestations(self):
with self.assertRaises(AssertionError):
self.dispersal._build_certificate(None, [], [])
def test_build_certificate_enough_attestations(self):
mock_encoded_data = EncodedData(
None, None, None, [], [], [], bytes(b"f"*48), []
)
mock_message = sha3_256(mock_encoded_data.aggregated_column_commitment).digest()
mock_attestations = [Attestation(bls_pop.Sign(sk, mock_message)) for sk in self.secret_keys]
certificate = self.dispersal._build_certificate(
mock_encoded_data,
mock_attestations,
Bitfield([True for _ in range(len(self.secret_keys))])
)
self.assertIsNotNone(certificate)
self.assertEqual(certificate.aggregated_column_commitment, mock_encoded_data.aggregated_column_commitment)
self.assertEqual(certificate.row_commitments, [])
self.assertIsNotNone(certificate.aggregated_signatures)
self.assertTrue(
certificate.verify(self.public_keys)
)
def test_disperse(self):
data = self.encoder_test.data
encoding_params = DAEncoderParams(column_count=self.n_nodes // 2, bytes_per_chunk=31)
encoded_data = DAEncoder(encoding_params).encode(data)
# mock send and await method with local verifiers
def __send_and_await_response(node: NodeId, blob: DABlob):
sk = self.secret_keys[int.from_bytes(node)]
verifier = DAVerifier(sk, self.public_keys)
return verifier.verify(blob)
# inject mock send and await method
self.dispersal._send_and_await_response = __send_and_await_response
certificate = self.dispersal.disperse(encoded_data)
self.assertIsNotNone(certificate)
        self.assertTrue(certificate.verify(self.public_keys))
self.assertEqual(
certificate.signers,
[True if i < self.dispersal.settings.threshold else False for i in range(self.n_nodes)]
)

137
da/test_encoder.py Normal file
View File

@ -0,0 +1,137 @@
from itertools import chain, batched
from random import randrange, randbytes
from unittest import TestCase
from eth2spec.deneb.mainnet import bytes_to_bls_field
from da import encoder
from da.encoder import DAEncoderParams, DAEncoder
from eth2spec.eip7594.mainnet import BYTES_PER_FIELD_ELEMENT, BLSFieldElement
from da.kzg_rs.common import BLS_MODULUS, ROOTS_OF_UNITY
from da.kzg_rs import kzg, rs
class TestEncoder(TestCase):
def setUp(self):
self.params: DAEncoderParams = DAEncoderParams(column_count=16, bytes_per_chunk=31)
self.encoder: DAEncoder = DAEncoder(self.params)
self.elements = 32
self.data = bytearray(
chain.from_iterable(
randbytes(self.params.bytes_per_chunk)
for _ in range(self.elements)
)
)
def assert_encoding(self, encoder_params: DAEncoderParams, data: bytes):
encoded_data = encoder.DAEncoder(encoder_params).encode(data)
self.assertEqual(encoded_data.data, data)
extended_factor = 2
column_count = encoder_params.column_count*extended_factor
columns_len = len(list(encoded_data.extended_matrix.columns))
self.assertEqual(columns_len, column_count)
chunks_size = (len(data) // encoder_params.bytes_per_chunk) // encoder_params.column_count
self.assertEqual(len(encoded_data.row_commitments), chunks_size)
self.assertEqual(len(encoded_data.row_proofs), chunks_size)
self.assertEqual(len(encoded_data.row_proofs[0]), column_count)
self.assertIsNotNone(encoded_data.aggregated_column_commitment)
self.assertEqual(len(encoded_data.aggregated_column_proofs), columns_len)
# verify rows
for row, proofs, commitment in zip(encoded_data.extended_matrix, encoded_data.row_proofs, encoded_data.row_commitments):
for i, (chunk, proof) in enumerate(zip(row, proofs)):
self.assertTrue(
kzg.verify_element_proof(bytes_to_bls_field(chunk), commitment, proof, i, ROOTS_OF_UNITY)
)
# verify column aggregation
for i, (column, proof) in enumerate(zip(encoded_data.extended_matrix.columns, encoded_data.aggregated_column_proofs)):
data = DAEncoder.hash_column_and_commitment(column, commitment)
kzg.verify_element_proof(
bytes_to_bls_field(data),
encoded_data.aggregated_column_commitment,
proof,
i,
ROOTS_OF_UNITY
)
def test_chunkify(self):
encoder_settings = DAEncoderParams(column_count=2, bytes_per_chunk=31)
elements = 10
data = bytes(chain.from_iterable(int.to_bytes(0, length=31, byteorder='big') for _ in range(elements)))
_encoder = encoder.DAEncoder(encoder_settings)
chunks_matrix = _encoder._chunkify_data(data)
self.assertEqual(len(chunks_matrix), elements//encoder_settings.column_count)
for row in chunks_matrix:
self.assertEqual(len(row), encoder_settings.column_count)
self.assertEqual(len(row[0]), 32)
def test_compute_row_kzg_commitments(self):
chunks_matrix = self.encoder._chunkify_data(self.data)
polynomials, commitments = zip(*self.encoder._compute_row_kzg_commitments(chunks_matrix))
self.assertEqual(len(commitments), len(chunks_matrix))
self.assertEqual(len(polynomials), len(chunks_matrix))
def test_rs_encode_rows(self):
chunks_matrix = self.encoder._chunkify_data(self.data)
extended_chunks_matrix = self.encoder._rs_encode_rows(chunks_matrix)
for r1, r2 in zip(chunks_matrix, extended_chunks_matrix):
self.assertEqual(len(r1), len(r2)//2)
r2 = [BLSFieldElement.from_bytes(x) for x in r2]
poly_1 = kzg.bytes_to_polynomial(r1.as_bytes())
            # we check against decoding so we know the encoding was properly done
poly_2 = rs.decode(r2, ROOTS_OF_UNITY, len(poly_1))
self.assertEqual(poly_1, poly_2)
def test_compute_rows_proofs(self):
chunks_matrix = self.encoder._chunkify_data(self.data)
polynomials, commitments = zip(*self.encoder._compute_row_kzg_commitments(chunks_matrix))
extended_chunks_matrix = self.encoder._rs_encode_rows(chunks_matrix)
original_proofs = self.encoder._compute_rows_proofs(chunks_matrix, polynomials, commitments)
extended_proofs = self.encoder._compute_rows_proofs(extended_chunks_matrix, polynomials, commitments)
# check original sized matrix
for row, poly, commitment, proofs in zip(chunks_matrix, polynomials, commitments, original_proofs):
self.assertEqual(len(proofs), len(row))
for i, chunk in enumerate(row):
self.assertTrue(kzg.verify_element_proof(BLSFieldElement.from_bytes(chunk), commitment, proofs[i], i, ROOTS_OF_UNITY))
# check extended matrix
for row, poly, commitment, proofs in zip(extended_chunks_matrix, polynomials, commitments, extended_proofs):
for i, chunk in enumerate(row):
self.assertTrue(kzg.verify_element_proof(BLSFieldElement.from_bytes(chunk), commitment, proofs[i], i, ROOTS_OF_UNITY))
def test_compute_column_kzg_commitments(self):
chunks_matrix = self.encoder._chunkify_data(self.data)
polynomials, commitments = zip(*self.encoder._compute_column_kzg_commitments(chunks_matrix))
self.assertEqual(len(commitments), len(chunks_matrix[0]))
self.assertEqual(len(polynomials), len(chunks_matrix[0]))
def test_generate_aggregated_column_commitments(self):
chunks_matrix = self.encoder._chunkify_data(self.data)
_, column_commitments = zip(*self.encoder._compute_column_kzg_commitments(chunks_matrix))
poly, commitment = self.encoder._compute_aggregated_column_commitment(chunks_matrix, column_commitments)
self.assertIsNotNone(poly)
self.assertIsNotNone(commitment)
def test_generate_aggregated_column_proofs(self):
chunks_matrix = self.encoder._chunkify_data(self.data)
_, column_commitments = zip(*self.encoder._compute_column_kzg_commitments(chunks_matrix))
poly, _ = self.encoder._compute_aggregated_column_commitment(chunks_matrix, column_commitments)
proofs = self.encoder._compute_aggregated_column_proofs(poly, column_commitments)
self.assertEqual(len(proofs), len(column_commitments))
def test_encode(self):
from random import randbytes
sizes = [pow(2, exp) for exp in range(4, 8, 2)]
encoder_params = DAEncoderParams(
column_count=8,
bytes_per_chunk=31
)
for size in sizes:
data = bytes(
chain.from_iterable(
randbytes(encoder_params.bytes_per_chunk)
for _ in range(size*encoder_params.column_count)
)
)
self.assert_encoding(encoder_params, data)

134
da/test_full_flow.py Normal file
View File

@ -0,0 +1,134 @@
from itertools import chain
from unittest import TestCase
from typing import List, Optional
from da.common import NodeId, build_attestation_message, BLSPublicKey, NomosDaG2ProofOfPossession as bls_pop
from da.api.common import DAApi, VID, Metadata
from da.verifier import DAVerifier, DABlob
from da.api.test_flow import MockStore
from da.dispersal import Dispersal, DispersalSettings
from da.test_encoder import TestEncoder
from da.encoder import DAEncoderParams, DAEncoder
class DAVerifierWApi:
def __init__(self, sk: int, public_keys: List[BLSPublicKey]):
self.store = MockStore()
self.api = DAApi(self.store)
self.verifier = DAVerifier(sk, public_keys)
def receive_blob(self, blob: DABlob):
if attestation := self.verifier.verify(blob):
# Warning: If aggregated col commitment and row commitment are the same,
# the build_attestation_message method will produce the same output.
cert_id = build_attestation_message(blob.aggregated_column_commitment, blob.rows_commitments)
self.store.populate(blob, cert_id)
return attestation
def receive_cert(self, vid: VID):
        # Usually the certificate would be verified here,
        # but we are assuming that it is already coming from a verified block,
        # in which case all certificates have already been verified by the DA node.
self.api.write(vid.cert_id, vid.metadata)
def read(self, app_id, indexes) -> List[Optional[DABlob]]:
return self.api.read(app_id, indexes)
class TestFullFlow(TestCase):
def setUp(self):
self.n_nodes = 16
self.nodes_ids = [NodeId(x.to_bytes(length=32, byteorder='big')) for x in range(self.n_nodes)]
self.secret_keys = list(range(1, self.n_nodes+1))
self.public_keys = [bls_pop.SkToPk(sk) for sk in self.secret_keys]
# sort by pk as we do in dispersal
self.secret_keys, self.public_keys = zip(
*sorted(zip(self.secret_keys, self.public_keys), key=lambda x: x[1])
)
dispersal_settings = DispersalSettings(
self.nodes_ids,
self.public_keys,
self.n_nodes
)
self.dispersal = Dispersal(dispersal_settings)
self.encoder_test = TestEncoder()
self.encoder_test.setUp()
self.api_nodes = [DAVerifierWApi(k, self.public_keys) for k in self.secret_keys]
def test_full_flow(self):
app_id = int.to_bytes(1)
index = 1
# encoder
data = self.encoder_test.data
encoding_params = DAEncoderParams(column_count=self.n_nodes // 2, bytes_per_chunk=31)
encoded_data = DAEncoder(encoding_params).encode(data)
# mock send and await method with local verifiers
def __send_and_await_response(node: int, blob: DABlob):
node = self.api_nodes[int.from_bytes(node)]
return node.receive_blob(blob)
# inject mock send and await method
self.dispersal._send_and_await_response = __send_and_await_response
certificate = self.dispersal.disperse(encoded_data)
vid = VID(
certificate.id(),
Metadata(app_id, index)
)
# verifier
for node in self.api_nodes:
node.receive_cert(vid)
        # read from the api and confirm it's working
        # notice that we need to sort the api_nodes by their public key to have the blobs sorted in the same fashion
        # as we actually do in dispersal.
blobs = list(chain.from_iterable(
node.read(app_id, [index])
for node in sorted(self.api_nodes, key=lambda n: bls_pop.SkToPk(n.verifier.sk))
))
original_blobs = list(self.dispersal._prepare_data(encoded_data))
self.assertEqual(blobs, original_blobs)
def test_same_blob_multiple_indexes(self):
app_id = int.to_bytes(1)
indexes = [1, 2, 3] # Different indexes to test with the same blob
# encoder
data = self.encoder_test.data
encoding_params = DAEncoderParams(column_count=self.n_nodes // 2, bytes_per_chunk=31)
encoded_data = DAEncoder(encoding_params).encode(data)
# mock send and await method with local verifiers
def __send_and_await_response(node: int, blob: DABlob):
node = self.api_nodes[int.from_bytes(node)]
return node.receive_blob(blob)
# inject mock send and await method
self.dispersal._send_and_await_response = __send_and_await_response
certificate = self.dispersal.disperse(encoded_data)
# Loop through each index and simulate dispersal with the same cert_id but different metadata
for index in indexes:
vid = VID(
certificate.id(),
Metadata(app_id, index)
)
# verifier
for node in self.api_nodes:
node.receive_cert(vid)
# Verify retrieval for each index
for index in indexes:
# Notice that we need to sort the api_nodes by their public key to have the blobs sorted in the same fashion
            # as we actually do in dispersal.
blobs = list(chain.from_iterable(
node.read(app_id, [index])
for node in sorted(self.api_nodes, key=lambda n: bls_pop.SkToPk(n.verifier.sk))
))
original_blobs = list(self.dispersal._prepare_data(encoded_data))
self.assertEqual(blobs, original_blobs, f"Failed at index {index}")

71
da/test_verifier.py Normal file
View File

@ -0,0 +1,71 @@
from unittest import TestCase
from da.common import Column, NomosDaG2ProofOfPossession as bls_pop
from da.encoder import DAEncoder
from da.kzg_rs import kzg
from da.kzg_rs.common import GLOBAL_PARAMETERS, ROOTS_OF_UNITY
from da.test_encoder import TestEncoder
from da.verifier import Attestation, DAVerifier, DABlob
class TestVerifier(TestCase):
def setUp(self):
self.verifier = DAVerifier(1987, [bls_pop.SkToPk(1987)])
def test_verify_column(self):
column = Column(int.to_bytes(i, length=32) for i in range(8))
_, column_commitment = kzg.bytes_to_commitment(column.as_bytes(), GLOBAL_PARAMETERS)
aggregated_poly, aggregated_column_commitment = kzg.bytes_to_commitment(
DAEncoder.hash_column_and_commitment(column, column_commitment), GLOBAL_PARAMETERS
)
aggregated_proof = kzg.generate_element_proof(0, aggregated_poly, GLOBAL_PARAMETERS, ROOTS_OF_UNITY)
self.assertTrue(
self.verifier._verify_column(
column, column_commitment, aggregated_column_commitment, aggregated_proof, 0
)
)
def test_verify(self):
_ = TestEncoder()
_.setUp()
encoded_data = _.encoder.encode(_.data)
verifiers_sk = [i for i in range(1000, 1000+len(encoded_data.chunked_data[0]))]
        verifiers_pk = [bls_pop.SkToPk(k) for k in verifiers_sk]
        for i, column in enumerate(encoded_data.chunked_data.columns):
            verifier = DAVerifier(verifiers_sk[i], verifiers_pk)
da_blob = DABlob(
Column(column),
encoded_data.column_commitments[i],
encoded_data.aggregated_column_commitment,
encoded_data.aggregated_column_proofs[i],
encoded_data.row_commitments,
[row[i] for row in encoded_data.row_proofs],
)
self.assertIsNotNone(verifier.verify(da_blob))
def test_verify_duplicated_blob(self):
_ = TestEncoder()
_.setUp()
encoded_data = _.encoder.encode(_.data)
columns = enumerate(encoded_data.chunked_data.columns)
i, column = next(columns)
da_blob = DABlob(
Column(column),
encoded_data.column_commitments[i],
encoded_data.aggregated_column_commitment,
encoded_data.aggregated_column_proofs[i],
encoded_data.row_commitments,
[row[i] for row in encoded_data.row_proofs],
)
self.assertIsNotNone(self.verifier.verify(da_blob))
for i, column in columns:
da_blob = DABlob(
Column(column),
encoded_data.column_commitments[i],
encoded_data.aggregated_column_commitment,
encoded_data.aggregated_column_proofs[i],
encoded_data.row_commitments,
[row[i] for row in encoded_data.row_proofs],
)
self.assertIsNone(self.verifier.verify(da_blob))

114
da/verifier.py Normal file
View File

@ -0,0 +1,114 @@
from dataclasses import dataclass
from hashlib import sha3_256
from typing import Dict, List, Optional, Sequence, Set, Tuple
from eth2spec.deneb.mainnet import BLSFieldElement
from eth2spec.eip7594.mainnet import (
KZGCommitment as Commitment,
KZGProof as Proof,
)
import da.common
from da.common import Column, Chunk, Attestation, BLSPrivateKey, BLSPublicKey, NomosDaG2ProofOfPossession as bls_pop
from da.encoder import DAEncoder
from da.kzg_rs import kzg
from da.kzg_rs.common import ROOTS_OF_UNITY, GLOBAL_PARAMETERS, BLS_MODULUS
@dataclass
class DABlob:
column: Column
column_commitment: Commitment
aggregated_column_commitment: Commitment
aggregated_column_proof: Proof
rows_commitments: List[Commitment]
rows_proofs: List[Proof]
def id(self) -> bytes:
return da.common.build_attestation_message(self.aggregated_column_commitment, self.rows_commitments)
def column_id(self) -> bytes:
return sha3_256(self.column.as_bytes()).digest()
class DAVerifier:
def __init__(self, sk: BLSPrivateKey, nodes_pks: List[BLSPublicKey]):
        self.attested_blobs: Dict[bytes, Tuple[bytes, Attestation]] = dict()
self.sk = sk
self.index = nodes_pks.index(bls_pop.SkToPk(self.sk))
@staticmethod
def _verify_column(
column: Column,
column_commitment: Commitment,
aggregated_column_commitment: Commitment,
aggregated_column_proof: Proof,
index: int
) -> bool:
# 1. compute commitment for column
_, computed_column_commitment = kzg.bytes_to_commitment(column.as_bytes(), GLOBAL_PARAMETERS)
# 2. If computed column commitment != column commitment, fail
if column_commitment != computed_column_commitment:
return False
# 3. compute column hash
column_hash = DAEncoder.hash_column_and_commitment(column, column_commitment)
# 4. Check proof with commitment and proof over the aggregated column commitment
chunk = BLSFieldElement.from_bytes(column_hash)
return kzg.verify_element_proof(
chunk, aggregated_column_commitment, aggregated_column_proof, index, ROOTS_OF_UNITY
)
@staticmethod
def _verify_chunk(chunk: Chunk, commitment: Commitment, proof: Proof, index: int) -> bool:
chunk = BLSFieldElement(int.from_bytes(bytes(chunk)) % BLS_MODULUS)
return kzg.verify_element_proof(chunk, commitment, proof, index, ROOTS_OF_UNITY)
@staticmethod
def _verify_chunks(
chunks: Sequence[Chunk],
commitments: Sequence[Commitment],
proofs: Sequence[Proof],
index: int
) -> bool:
if not (len(chunks) == len(commitments) == len(proofs)):
return False
for chunk, commitment, proof in zip(chunks, commitments, proofs):
if not DAVerifier._verify_chunk(chunk, commitment, proof, index):
return False
return True
def _build_attestation(self, blob: DABlob) -> Attestation:
hasher = sha3_256()
hasher.update(bytes(blob.aggregated_column_commitment))
for c in blob.rows_commitments:
hasher.update(bytes(c))
message = hasher.digest()
return Attestation(signature=bls_pop.Sign(self.sk, message))
def verify(self, blob: DABlob) -> Optional[Attestation]:
blob_id = blob.id()
if previous_attestation := self.attested_blobs.get(blob_id):
column_id, attestation = previous_attestation
            # we already attested; the attestation is cached, so we return it
if column_id == blob.column_id():
return attestation
            # we already attested, and they are asking us to attest the same data with a different column
            # skip
return None
is_column_verified = DAVerifier._verify_column(
blob.column,
blob.column_commitment,
blob.aggregated_column_commitment,
blob.aggregated_column_proof,
self.index
)
if not is_column_verified:
return
are_chunks_verified = DAVerifier._verify_chunks(
blob.column, blob.rows_commitments, blob.rows_proofs, self.index
)
if not are_chunks_verified:
return
attestation = self._build_attestation(blob)
self.attested_blobs[blob_id] = (blob.column_id(), attestation)
return attestation