moved from incorrect location at nomos-specs
This commit is contained in:
parent
e6402007f0
commit
faf399eaaf
|
@ -0,0 +1,58 @@
|
|||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional, List, Sequence
|
||||
|
||||
from da.common import Certificate
|
||||
from da.verifier import DABlob
|
||||
|
||||
|
||||
@dataclass
|
||||
class Metadata:
|
||||
# app identifier
|
||||
app_id: bytes
|
||||
# index of VID certificate blob
|
||||
index: int
|
||||
|
||||
|
||||
@dataclass
|
||||
class VID:
|
||||
# da certificate id
|
||||
cert_id: bytes
|
||||
# application + index information
|
||||
metadata: Metadata
|
||||
|
||||
|
||||
class BlobStore(ABC):
|
||||
@abstractmethod
|
||||
def add(self, certificate: Certificate, metadata: Metadata):
|
||||
"""
|
||||
Raises: ValueError if there is already a registered certificate fot the given metadata
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_multiple(self, app_id: bytes, indexes: Sequence[int]) -> List[Optional[DABlob]]:
|
||||
pass
|
||||
|
||||
|
||||
class DAApi:
|
||||
def __init__(self, bs: BlobStore):
|
||||
self.store = bs
|
||||
|
||||
def write(self, certificate: Certificate, metadata: Metadata):
|
||||
"""
|
||||
Write method should be used by a service that is able to retrieve verified certificates
|
||||
from the latest Block. Once a certificate is retrieved, api creates a relation between
|
||||
the blob of an original data, certificate and index for the app_id of the certificate.
|
||||
Raises: ValueError if there is already a registered certificate for a given metadata
|
||||
"""
|
||||
self.store.add(certificate, metadata)
|
||||
|
||||
def read(self, app_id, indexes) -> List[Optional[DABlob]]:
|
||||
"""
|
||||
Read method should accept only `app_id` and a list of indexes. The returned list of
|
||||
blobs should be ordered in the same sequence as `indexes` in a request.
|
||||
If node does not have the blob for some indexes, then it should add None object as an
|
||||
item.
|
||||
"""
|
||||
return self.store.get_multiple(app_id, indexes)
|
|
@ -0,0 +1,97 @@
|
|||
from unittest import TestCase
|
||||
from collections import defaultdict
|
||||
|
||||
from da.api.common import *
|
||||
|
||||
|
||||
@dataclass
|
||||
class MockCertificate:
|
||||
cert_id: int
|
||||
|
||||
|
||||
class MockStore(BlobStore):
|
||||
def __init__(self):
|
||||
self.blob_store = {}
|
||||
self.app_id_store = defaultdict(dict)
|
||||
|
||||
def populate(self, blob, cert_id: bytes):
|
||||
self.blob_store[cert_id] = blob
|
||||
|
||||
# Implements `add` method from BlobStore abstract class.
|
||||
def add(self, cert_id: bytes, metadata: Metadata):
|
||||
if metadata.index in self.app_id_store[metadata.app_id]:
|
||||
raise ValueError("index already written")
|
||||
|
||||
self.app_id_store[metadata.app_id][metadata.index] = cert_id
|
||||
|
||||
# Implements `get_multiple` method from BlobStore abstract class.
|
||||
def get_multiple(self, app_id, indexes) -> List[Optional[DABlob]]:
|
||||
return [
|
||||
self.blob_store.get(self.app_id_store[app_id].get(i), None) if self.app_id_store[app_id].get(i) else None for i in indexes
|
||||
]
|
||||
|
||||
|
||||
|
||||
class TestFlow(TestCase):
|
||||
def test_api_write_read(self):
|
||||
expected_blob = "hello"
|
||||
cert_id = b"11"*32
|
||||
app_id = 1
|
||||
idx = 1
|
||||
mock_meta = Metadata(1, 1)
|
||||
|
||||
mock_store = MockStore()
|
||||
mock_store.populate(expected_blob, cert_id)
|
||||
|
||||
api = DAApi(mock_store)
|
||||
|
||||
api.write(cert_id, mock_meta)
|
||||
blobs = api.read(app_id, [idx])
|
||||
|
||||
self.assertEqual([expected_blob], blobs)
|
||||
|
||||
def test_same_index(self):
|
||||
expected_blob = "hello"
|
||||
cert_id = b"11"*32
|
||||
app_id = 1
|
||||
idx = 1
|
||||
mock_meta = Metadata(1, 1)
|
||||
|
||||
mock_store = MockStore()
|
||||
mock_store.populate(expected_blob, cert_id)
|
||||
|
||||
api = DAApi(mock_store)
|
||||
|
||||
api.write(cert_id, mock_meta)
|
||||
with self.assertRaises(ValueError):
|
||||
api.write(cert_id, mock_meta)
|
||||
|
||||
blobs = api.read(app_id, [idx])
|
||||
|
||||
self.assertEqual([expected_blob], blobs)
|
||||
|
||||
def test_multiple_indexes_same_data(self):
|
||||
expected_blob = "hello"
|
||||
cert_id = b"11"*32
|
||||
app_id = 1
|
||||
idx1 = 1
|
||||
idx2 = 2
|
||||
mock_meta1 = Metadata(app_id, idx1)
|
||||
mock_meta2 = Metadata(app_id, idx2)
|
||||
|
||||
mock_store = MockStore()
|
||||
mock_store.populate(expected_blob, cert_id)
|
||||
|
||||
api = DAApi(mock_store)
|
||||
|
||||
api.write(cert_id, mock_meta1)
|
||||
mock_store.populate(expected_blob, cert_id)
|
||||
api.write(cert_id, mock_meta2)
|
||||
|
||||
blobs_idx1 = api.read(app_id, [idx1])
|
||||
blobs_idx2 = api.read(app_id, [idx2])
|
||||
|
||||
self.assertEqual([expected_blob], blobs_idx1)
|
||||
self.assertEqual([expected_blob], blobs_idx2)
|
||||
self.assertEqual(mock_store.app_id_store[app_id][idx1], mock_store.app_id_store[app_id][idx2])
|
||||
|
|
@ -0,0 +1,82 @@
|
|||
from dataclasses import dataclass
|
||||
from hashlib import sha3_256
|
||||
from itertools import chain, zip_longest, compress
|
||||
from typing import List, Generator, Self, Sequence
|
||||
|
||||
from eth2spec.eip7594.mainnet import Bytes32, KZGCommitment as Commitment
|
||||
from py_ecc.bls import G2ProofOfPossession
|
||||
|
||||
|
||||
class NodeId(Bytes32):
|
||||
pass
|
||||
|
||||
|
||||
class Chunk(bytes):
|
||||
pass
|
||||
|
||||
|
||||
class Column(List[Bytes32]):
|
||||
def as_bytes(self) -> bytes:
|
||||
return bytes(chain.from_iterable(self))
|
||||
|
||||
|
||||
class Row(List[Bytes32]):
|
||||
def as_bytes(self) -> bytes:
|
||||
return bytes(chain.from_iterable(self))
|
||||
|
||||
|
||||
class ChunksMatrix(List[Row | Column]):
|
||||
@property
|
||||
def columns(self) -> Generator[List[Chunk], None, None]:
|
||||
yield from map(Column, zip_longest(*self, fillvalue=b""))
|
||||
|
||||
def transposed(self) -> Self:
|
||||
return ChunksMatrix(self.columns)
|
||||
|
||||
|
||||
BLSPublicKey = bytes
|
||||
BLSPrivateKey = int
|
||||
BLSSignature = bytes
|
||||
|
||||
|
||||
class Bitfield(List[bool]):
|
||||
pass
|
||||
|
||||
|
||||
@dataclass
|
||||
class Attestation:
|
||||
signature: BLSSignature
|
||||
|
||||
|
||||
@dataclass
|
||||
class Certificate:
|
||||
aggregated_signatures: BLSSignature
|
||||
signers: Bitfield
|
||||
aggregated_column_commitment: Commitment
|
||||
row_commitments: List[Commitment]
|
||||
|
||||
def id(self) -> bytes:
|
||||
return build_attestation_message(self.aggregated_column_commitment, self.row_commitments)
|
||||
|
||||
def verify(self, nodes_public_keys: List[BLSPublicKey]) -> bool:
|
||||
"""
|
||||
List of nodes public keys should be a trusted list of verified proof of possession keys.
|
||||
Otherwise, we could fall under the Rogue Key Attack
|
||||
`assert all(bls_pop.PopVerify(pk, proof) for pk, proof in zip(node_public_keys, pops))`
|
||||
"""
|
||||
# we sort them as the signers bitfield is sorted by the public keys as well
|
||||
signers_keys = list(compress(sorted(nodes_public_keys), self.signers))
|
||||
message = build_attestation_message(self.aggregated_column_commitment, self.row_commitments)
|
||||
return NomosDaG2ProofOfPossession.AggregateVerify(signers_keys, [message]*len(signers_keys), self.aggregated_signatures)
|
||||
|
||||
|
||||
def build_attestation_message(aggregated_column_commitment: Commitment, row_commitments: Sequence[Commitment]) -> bytes:
|
||||
hasher = sha3_256()
|
||||
hasher.update(bytes(aggregated_column_commitment))
|
||||
for c in row_commitments:
|
||||
hasher.update(bytes(c))
|
||||
return hasher.digest()
|
||||
|
||||
class NomosDaG2ProofOfPossession(G2ProofOfPossession):
|
||||
# Domain specific tag for Nomos DA protocol
|
||||
DST = b"NOMOS_DA_AVAIL"
|
|
@ -0,0 +1,90 @@
|
|||
from dataclasses import dataclass
|
||||
from hashlib import sha3_256
|
||||
from typing import List, Optional, Generator, Sequence
|
||||
|
||||
from da.common import Certificate, NodeId, BLSPublicKey, Bitfield, build_attestation_message, NomosDaG2ProofOfPossession as bls_pop
|
||||
from da.encoder import EncodedData
|
||||
from da.verifier import DABlob, Attestation
|
||||
|
||||
|
||||
@dataclass
|
||||
class DispersalSettings:
|
||||
nodes_ids: List[NodeId]
|
||||
nodes_pubkey: List[BLSPublicKey]
|
||||
threshold: int
|
||||
|
||||
|
||||
class Dispersal:
|
||||
def __init__(self, settings: DispersalSettings):
|
||||
self.settings = settings
|
||||
# sort over public keys
|
||||
self.settings.nodes_ids, self.settings.nodes_pubkey = zip(
|
||||
*sorted(zip(self.settings.nodes_ids, self.settings.nodes_pubkey), key=lambda x: x[1])
|
||||
)
|
||||
|
||||
def _prepare_data(self, encoded_data: EncodedData) -> Generator[DABlob, None, None]:
|
||||
assert len(encoded_data.column_commitments) == len(self.settings.nodes_ids)
|
||||
assert len(encoded_data.aggregated_column_proofs) == len(self.settings.nodes_ids)
|
||||
columns = encoded_data.extended_matrix.columns
|
||||
column_commitments = encoded_data.column_commitments
|
||||
row_commitments = encoded_data.row_commitments
|
||||
rows_proofs = encoded_data.row_proofs
|
||||
aggregated_column_commitment = encoded_data.aggregated_column_commitment
|
||||
aggregated_column_proofs = encoded_data.aggregated_column_proofs
|
||||
blobs_data = zip(columns, column_commitments, zip(*rows_proofs), aggregated_column_proofs)
|
||||
for (column, column_commitment, row_proofs, column_proof) in blobs_data:
|
||||
blob = DABlob(
|
||||
column,
|
||||
column_commitment,
|
||||
aggregated_column_commitment,
|
||||
column_proof,
|
||||
row_commitments,
|
||||
row_proofs
|
||||
)
|
||||
yield blob
|
||||
|
||||
def _send_and_await_response(self, node: NodeId, blob: DABlob) -> Optional[Attestation]:
|
||||
pass
|
||||
|
||||
def _build_certificate(
|
||||
self,
|
||||
encoded_data: EncodedData,
|
||||
attestations: Sequence[Attestation],
|
||||
signers: Bitfield
|
||||
) -> Certificate:
|
||||
assert len(attestations) >= self.settings.threshold
|
||||
assert len(attestations) == signers.count(True)
|
||||
aggregated = bls_pop.Aggregate([attestation.signature for attestation in attestations])
|
||||
return Certificate(
|
||||
aggregated_signatures=aggregated,
|
||||
signers=signers,
|
||||
aggregated_column_commitment=encoded_data.aggregated_column_commitment,
|
||||
row_commitments=encoded_data.row_commitments
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _verify_attestation(public_key: BLSPublicKey, attested_message: bytes, attestation: Attestation) -> bool:
|
||||
return bls_pop.Verify(public_key, attested_message, attestation.signature)
|
||||
|
||||
@staticmethod
|
||||
def _build_attestation_message(encoded_data: EncodedData) -> bytes:
|
||||
return build_attestation_message(encoded_data.aggregated_column_commitment, encoded_data.row_commitments)
|
||||
|
||||
def disperse(self, encoded_data: EncodedData) -> Optional[Certificate]:
|
||||
attestations = []
|
||||
attested_message = self._build_attestation_message(encoded_data)
|
||||
signed = Bitfield(False for _ in range(len(self.settings.nodes_ids)))
|
||||
blob_data = zip(
|
||||
range(len(self.settings.nodes_ids)),
|
||||
self.settings.nodes_ids,
|
||||
self.settings.nodes_pubkey,
|
||||
self._prepare_data(encoded_data)
|
||||
)
|
||||
for i, node, pk, blob in blob_data:
|
||||
if attestation := self._send_and_await_response(node, blob):
|
||||
if self._verify_attestation(pk, attested_message, attestation):
|
||||
# mark as received
|
||||
signed[i] = True
|
||||
attestations.append(attestation)
|
||||
if len(attestations) >= self.settings.threshold:
|
||||
return self._build_certificate(encoded_data, attestations, signed)
|
|
@ -0,0 +1,136 @@
|
|||
from dataclasses import dataclass
|
||||
from itertools import batched, chain
|
||||
from typing import List, Sequence, Tuple
|
||||
from hashlib import blake2b
|
||||
|
||||
from eth2spec.eip7594.mainnet import KZGCommitment as Commitment, KZGProof as Proof, BLSFieldElement
|
||||
|
||||
from da.common import ChunksMatrix, Chunk, Row, Column
|
||||
from da.kzg_rs import kzg, rs
|
||||
from da.kzg_rs.common import GLOBAL_PARAMETERS, ROOTS_OF_UNITY, BLS_MODULUS, BYTES_PER_FIELD_ELEMENT
|
||||
from da.kzg_rs.poly import Polynomial
|
||||
|
||||
|
||||
@dataclass
|
||||
class DAEncoderParams:
|
||||
column_count: int
|
||||
bytes_per_chunk: int
|
||||
|
||||
|
||||
@dataclass
|
||||
class EncodedData:
|
||||
data: bytes
|
||||
chunked_data: ChunksMatrix
|
||||
extended_matrix: ChunksMatrix
|
||||
row_commitments: List[Commitment]
|
||||
row_proofs: List[List[Proof]]
|
||||
column_commitments: List[Commitment]
|
||||
aggregated_column_commitment: Commitment
|
||||
aggregated_column_proofs: List[Proof]
|
||||
|
||||
|
||||
class DAEncoder:
|
||||
def __init__(self, params: DAEncoderParams):
|
||||
# we can only encode up to 31 bytes per element which fits without problem in a 32 byte element
|
||||
assert params.bytes_per_chunk < BYTES_PER_FIELD_ELEMENT
|
||||
self.params = params
|
||||
|
||||
def _chunkify_data(self, data: bytes) -> ChunksMatrix:
|
||||
size: int = self.params.column_count * self.params.bytes_per_chunk
|
||||
return ChunksMatrix(
|
||||
Row(Chunk(int.from_bytes(chunk, byteorder="big").to_bytes(length=BYTES_PER_FIELD_ELEMENT))
|
||||
for chunk in batched(b, self.params.bytes_per_chunk)
|
||||
)
|
||||
for b in batched(data, size)
|
||||
)
|
||||
|
||||
def _compute_row_kzg_commitments(self, matrix: ChunksMatrix) -> List[Tuple[Polynomial, Commitment]]:
|
||||
return [
|
||||
kzg.bytes_to_commitment(
|
||||
row.as_bytes(),
|
||||
GLOBAL_PARAMETERS,
|
||||
)
|
||||
for row in matrix
|
||||
]
|
||||
|
||||
def _rs_encode_rows(self, chunks_matrix: ChunksMatrix) -> ChunksMatrix:
|
||||
def __rs_encode_row(row: Row) -> Row:
|
||||
polynomial = kzg.bytes_to_polynomial(row.as_bytes())
|
||||
return Row(
|
||||
Chunk(BLSFieldElement.to_bytes(
|
||||
x,
|
||||
# fixed to 32 bytes as bls_field_elements are 32bytes (256bits) encoded
|
||||
length=32, byteorder="big"
|
||||
)) for x in rs.encode(polynomial, 2, ROOTS_OF_UNITY)
|
||||
)
|
||||
return ChunksMatrix(__rs_encode_row(row) for row in chunks_matrix)
|
||||
|
||||
@staticmethod
|
||||
def _compute_rows_proofs(
|
||||
chunks_matrix: ChunksMatrix,
|
||||
polynomials: Sequence[Polynomial],
|
||||
row_commitments: Sequence[Commitment]
|
||||
) -> List[List[Proof]]:
|
||||
proofs = []
|
||||
for row, poly, commitment in zip(chunks_matrix, polynomials, row_commitments):
|
||||
proofs.append(
|
||||
[
|
||||
kzg.generate_element_proof(i, poly, GLOBAL_PARAMETERS, ROOTS_OF_UNITY)
|
||||
for i in range(len(row))
|
||||
]
|
||||
)
|
||||
return proofs
|
||||
|
||||
def _compute_column_kzg_commitments(self, chunks_matrix: ChunksMatrix) -> List[Tuple[Polynomial, Commitment]]:
|
||||
return self._compute_row_kzg_commitments(chunks_matrix.transposed())
|
||||
|
||||
@staticmethod
|
||||
def _compute_aggregated_column_commitment(
|
||||
chunks_matrix: ChunksMatrix, column_commitments: Sequence[Commitment]
|
||||
) -> Tuple[Polynomial, Commitment]:
|
||||
data = bytes(chain.from_iterable(
|
||||
DAEncoder.hash_column_and_commitment(column, commitment)
|
||||
for column, commitment in zip(chunks_matrix.columns, column_commitments)
|
||||
))
|
||||
return kzg.bytes_to_commitment(data, GLOBAL_PARAMETERS)
|
||||
|
||||
@staticmethod
|
||||
def _compute_aggregated_column_proofs(
|
||||
polynomial: Polynomial,
|
||||
column_commitments: Sequence[Commitment],
|
||||
) -> List[Proof]:
|
||||
return [
|
||||
kzg.generate_element_proof(i, polynomial, GLOBAL_PARAMETERS, ROOTS_OF_UNITY)
|
||||
for i in range(len(column_commitments))
|
||||
]
|
||||
|
||||
def encode(self, data: bytes) -> EncodedData:
|
||||
chunks_matrix = self._chunkify_data(data)
|
||||
row_polynomials, row_commitments = zip(*self._compute_row_kzg_commitments(chunks_matrix))
|
||||
extended_matrix = self._rs_encode_rows(chunks_matrix)
|
||||
row_proofs = self._compute_rows_proofs(extended_matrix, row_polynomials, row_commitments)
|
||||
column_polynomials, column_commitments = zip(*self._compute_column_kzg_commitments(extended_matrix))
|
||||
aggregated_column_polynomial, aggregated_column_commitment = (
|
||||
self._compute_aggregated_column_commitment(extended_matrix, column_commitments)
|
||||
)
|
||||
aggregated_column_proofs = self._compute_aggregated_column_proofs(
|
||||
aggregated_column_polynomial, column_commitments
|
||||
)
|
||||
result = EncodedData(
|
||||
data,
|
||||
chunks_matrix,
|
||||
extended_matrix,
|
||||
row_commitments,
|
||||
row_proofs,
|
||||
column_commitments,
|
||||
aggregated_column_commitment,
|
||||
aggregated_column_proofs
|
||||
)
|
||||
return result
|
||||
|
||||
@staticmethod
|
||||
def hash_column_and_commitment(column: Column, commitment: Commitment) -> bytes:
|
||||
return (
|
||||
# digest size must be 31 bytes as we cannot encode 32 without risking overflowing the BLS_MODULUS
|
||||
int.from_bytes(blake2b(column.as_bytes() + bytes(commitment), digest_size=31).digest())
|
||||
).to_bytes(32, byteorder="big")
|
|
@ -0,0 +1,22 @@
|
|||
from typing import List, Tuple
|
||||
|
||||
import eth2spec.eip7594.mainnet
|
||||
from py_ecc.bls.typing import G1Uncompressed, G2Uncompressed
|
||||
|
||||
from da.kzg_rs.roots import compute_roots_of_unity
|
||||
from da.kzg_rs.trusted_setup import generate_setup
|
||||
|
||||
G1 = G1Uncompressed
|
||||
G2 = G2Uncompressed
|
||||
|
||||
|
||||
BYTES_PER_FIELD_ELEMENT = 32
|
||||
BLS_MODULUS = eth2spec.eip7594.mainnet.BLS_MODULUS
|
||||
PRIMITIVE_ROOT: int = 7
|
||||
GLOBAL_PARAMETERS: List[G1]
|
||||
GLOBAL_PARAMETERS_G2: List[G2]
|
||||
# secret is fixed but this should come from a different synchronization protocol
|
||||
GLOBAL_PARAMETERS, GLOBAL_PARAMETERS_G2 = map(list, generate_setup(4096, 8, 1987))
|
||||
ROOTS_OF_UNITY: Tuple[int] = compute_roots_of_unity(
|
||||
PRIMITIVE_ROOT, 4096, BLS_MODULUS
|
||||
)
|
|
@ -0,0 +1,65 @@
|
|||
from typing import Sequence, List
|
||||
|
||||
from eth2spec.deneb.mainnet import BLSFieldElement
|
||||
from eth2spec.utils import bls
|
||||
|
||||
from da.kzg_rs.common import G1
|
||||
|
||||
|
||||
def fft_g1(vals: Sequence[G1], roots_of_unity: Sequence[BLSFieldElement], modulus: int) -> List[G1]:
|
||||
if len(vals) == 1:
|
||||
return vals
|
||||
L = fft_g1(vals[::2], roots_of_unity[::2], modulus)
|
||||
R = fft_g1(vals[1::2], roots_of_unity[::2], modulus)
|
||||
o = [bls.Z1() for _ in vals]
|
||||
for i, (x, y) in enumerate(zip(L, R)):
|
||||
y_times_root = bls.multiply(y, roots_of_unity[i])
|
||||
o[i] = (x + y_times_root)
|
||||
o[i + len(L)] = x + -y_times_root
|
||||
return o
|
||||
|
||||
|
||||
def ifft_g1(vals: Sequence[G1], roots_of_unity: Sequence[BLSFieldElement], modulus: int) -> List[G1]:
|
||||
assert len(vals) == len(roots_of_unity)
|
||||
# modular inverse
|
||||
invlen = pow(len(vals), modulus-2, modulus)
|
||||
return [
|
||||
bls.multiply(x, invlen)
|
||||
for x in fft_g1(
|
||||
vals, [roots_of_unity[0], *roots_of_unity[:0:-1]], modulus
|
||||
)
|
||||
]
|
||||
|
||||
|
||||
def _fft(
|
||||
vals: Sequence[BLSFieldElement],
|
||||
roots_of_unity: Sequence[BLSFieldElement],
|
||||
modulus: int,
|
||||
) -> Sequence[BLSFieldElement]:
|
||||
if len(vals) == 1:
|
||||
return vals
|
||||
L = _fft(vals[::2], roots_of_unity[::2], modulus)
|
||||
R = _fft(vals[1::2], roots_of_unity[::2], modulus)
|
||||
o = [BLSFieldElement(0) for _ in vals]
|
||||
for i, (x, y) in enumerate(zip(L, R)):
|
||||
y_times_root = BLSFieldElement((int(y) * int(roots_of_unity[i])) % modulus)
|
||||
o[i] = BLSFieldElement((int(x) + y_times_root) % modulus)
|
||||
o[i + len(L)] = BLSFieldElement((int(x) - int(y_times_root) + modulus) % modulus)
|
||||
return o
|
||||
|
||||
|
||||
def fft(vals, root_of_unity, modulus):
|
||||
assert len(vals) == len(root_of_unity)
|
||||
return _fft(vals, root_of_unity, modulus)
|
||||
|
||||
|
||||
def ifft(vals, roots_of_unity, modulus):
|
||||
assert len(vals) == len(roots_of_unity)
|
||||
# modular inverse
|
||||
invlen = pow(len(vals), modulus-2, modulus)
|
||||
return [
|
||||
BLSFieldElement((int(x) * invlen) % modulus)
|
||||
for x in _fft(
|
||||
vals, [roots_of_unity[0], *roots_of_unity[:0:-1]], modulus
|
||||
)
|
||||
]
|
|
@ -0,0 +1,75 @@
|
|||
from typing import List, Sequence
|
||||
|
||||
from eth2spec.deneb.mainnet import KZGProof as Proof, BLSFieldElement
|
||||
from eth2spec.utils import bls
|
||||
|
||||
from da.kzg_rs.common import G1, BLS_MODULUS, PRIMITIVE_ROOT
|
||||
from da.kzg_rs.fft import fft, fft_g1, ifft_g1
|
||||
from da.kzg_rs.poly import Polynomial
|
||||
from da.kzg_rs.roots import compute_roots_of_unity
|
||||
from da.kzg_rs.utils import is_power_of_two
|
||||
|
||||
|
||||
def __toeplitz1(global_parameters: List[G1], polynomial_degree: int) -> List[G1]:
|
||||
"""
|
||||
This part can be precomputed for different global_parameters lengths depending on polynomial degree of powers of two.
|
||||
:param global_parameters:
|
||||
:param roots_of_unity:
|
||||
:param polynomial_degree:
|
||||
:return:
|
||||
"""
|
||||
assert len(global_parameters) == polynomial_degree
|
||||
# algorithm only works on powers of 2 for dft computations
|
||||
assert is_power_of_two(len(global_parameters))
|
||||
roots_of_unity = compute_roots_of_unity(PRIMITIVE_ROOT, polynomial_degree*2, BLS_MODULUS)
|
||||
vector_x_extended = global_parameters + [bls.Z1() for _ in range(polynomial_degree)]
|
||||
vector_x_extended_fft = fft_g1(vector_x_extended, roots_of_unity, BLS_MODULUS)
|
||||
return vector_x_extended_fft
|
||||
|
||||
|
||||
def __toeplitz2(coefficients: List[BLSFieldElement], extended_vector: Sequence[G1]) -> List[G1]:
|
||||
assert is_power_of_two(len(coefficients))
|
||||
roots_of_unity = compute_roots_of_unity(PRIMITIVE_ROOT, len(coefficients), BLS_MODULUS)
|
||||
toeplitz_coefficients_fft = fft(coefficients, roots_of_unity, BLS_MODULUS)
|
||||
return [bls.multiply(v, c) for v, c in zip(extended_vector, toeplitz_coefficients_fft)]
|
||||
|
||||
|
||||
def __toeplitz3(h_extended_fft: Sequence[G1], polynomial_degree: int) -> List[G1]:
|
||||
roots_of_unity = compute_roots_of_unity(PRIMITIVE_ROOT, len(h_extended_fft), BLS_MODULUS)
|
||||
return ifft_g1(h_extended_fft, roots_of_unity, BLS_MODULUS)[:polynomial_degree]
|
||||
|
||||
|
||||
def fk20_generate_proofs(
|
||||
polynomial: Polynomial, global_parameters: List[G1]
|
||||
) -> List[Proof]:
|
||||
"""
|
||||
Generate all proofs for the polynomial points in batch.
|
||||
This method uses the fk20 algorthm from https://eprint.iacr.org/2023/033.pdf
|
||||
Disclaimer: It only works for polynomial degree of powers of two.
|
||||
:param polynomial: polynomial to generate proof for
|
||||
:param global_parameters: setup generated parameters
|
||||
:return: list of proof for each point in the polynomial
|
||||
"""
|
||||
polynomial_degree = len(polynomial)
|
||||
assert len(global_parameters) >= polynomial_degree
|
||||
assert is_power_of_two(len(polynomial))
|
||||
|
||||
# 1 - Build toeplitz matrix for h values
|
||||
# 1.1 y = dft([s^d-1, s^d-2, ..., s, 1, *[0 for _ in len(polynomial)]])
|
||||
# 1.2 z = dft([*[0 for _ in len(polynomial)], f1, f2, ..., fd])
|
||||
# 1.3 u = y * v * roots_of_unity(len(polynomial)*2)
|
||||
roots_of_unity = compute_roots_of_unity(PRIMITIVE_ROOT, polynomial_degree, BLS_MODULUS)
|
||||
global_parameters = list(reversed(global_parameters[:polynomial_degree]))
|
||||
extended_vector = __toeplitz1(global_parameters, polynomial_degree)
|
||||
# 2 - Build circulant matrix with the polynomial coefficients (reversed N..n, and padded)
|
||||
toeplitz_coefficients = [
|
||||
*(BLSFieldElement(0) for _ in range(polynomial_degree)),
|
||||
*polynomial.coefficients
|
||||
]
|
||||
h_extended_vector = __toeplitz2(toeplitz_coefficients, extended_vector)
|
||||
# 3 - Perform fft and nub the tail half as it is padding
|
||||
h_vector = __toeplitz3(h_extended_vector, polynomial_degree)
|
||||
# 4 - proof are the dft of the h vector
|
||||
proofs = fft_g1(h_vector, roots_of_unity, BLS_MODULUS)
|
||||
proofs = [Proof(bls.G1_to_bytes48(proof)) for proof in proofs]
|
||||
return proofs
|
|
@ -0,0 +1,74 @@
|
|||
from functools import reduce
|
||||
from itertools import batched
|
||||
from typing import Sequence, Tuple
|
||||
|
||||
from eth2spec.deneb.mainnet import bytes_to_bls_field, BLSFieldElement, KZGCommitment as Commitment, KZGProof as Proof
|
||||
from eth2spec.utils import bls
|
||||
|
||||
from .common import BYTES_PER_FIELD_ELEMENT, G1, BLS_MODULUS, GLOBAL_PARAMETERS_G2
|
||||
from .poly import Polynomial
|
||||
|
||||
|
||||
def bytes_to_polynomial(b: bytes, bytes_per_field_element=BYTES_PER_FIELD_ELEMENT) -> Polynomial:
|
||||
"""
|
||||
Convert bytes to list of BLS field scalars.
|
||||
"""
|
||||
assert len(b) % bytes_per_field_element == 0
|
||||
eval_form = [int(bytes_to_bls_field(b)) for b in batched(b, int(bytes_per_field_element))]
|
||||
return Polynomial.from_evaluations(eval_form, BLS_MODULUS)
|
||||
|
||||
|
||||
def g1_linear_combination(polynomial: Polynomial[BLSFieldElement], global_parameters: Sequence[G1]) -> Commitment:
|
||||
"""
|
||||
BLS multiscalar multiplication.
|
||||
"""
|
||||
# we assert to have more points available than elements,
|
||||
# this is dependent on the available kzg setup size
|
||||
assert len(polynomial) <= len(global_parameters)
|
||||
point = reduce(
|
||||
bls.add,
|
||||
(bls.multiply(g, p) for g, p in zip(global_parameters, polynomial)),
|
||||
bls.Z1()
|
||||
)
|
||||
return Commitment(bls.G1_to_bytes48(point))
|
||||
|
||||
|
||||
def bytes_to_commitment(b: bytes, global_parameters: Sequence[G1]) -> Tuple[Polynomial, Commitment]:
|
||||
poly = bytes_to_polynomial(b, bytes_per_field_element=BYTES_PER_FIELD_ELEMENT)
|
||||
return poly, g1_linear_combination(poly, global_parameters)
|
||||
|
||||
|
||||
def generate_element_proof(
|
||||
element_index: int,
|
||||
polynomial: Polynomial,
|
||||
global_parameters: Sequence[G1],
|
||||
roots_of_unity: Sequence[BLSFieldElement],
|
||||
) -> Proof:
|
||||
# compute a witness polynomial in that satisfies `witness(x) = (f(x)-v)/(x-u)`
|
||||
u = int(roots_of_unity[element_index])
|
||||
v = polynomial.eval(u)
|
||||
f_x_v = polynomial - Polynomial([v], BLS_MODULUS)
|
||||
x_u = Polynomial([-u, 1], BLS_MODULUS)
|
||||
witness, _ = f_x_v / x_u
|
||||
return g1_linear_combination(witness, global_parameters)
|
||||
|
||||
|
||||
def verify_element_proof(
|
||||
chunk: BLSFieldElement,
|
||||
commitment: Commitment,
|
||||
proof: Proof,
|
||||
element_index: int,
|
||||
roots_of_unity: Sequence[BLSFieldElement],
|
||||
) -> bool:
|
||||
u = int(roots_of_unity[element_index])
|
||||
v = chunk
|
||||
commitment_check_G1 = bls.bytes48_to_G1(commitment) - bls.multiply(bls.G1(), v)
|
||||
proof_check_g2 = bls.add(
|
||||
GLOBAL_PARAMETERS_G2[1],
|
||||
bls.neg(bls.multiply(bls.G2(), u))
|
||||
)
|
||||
return bls.pairing_check([
|
||||
# G2 here needs to be negated due to library requirements as pairing_check([[G1, -G2], [G1, G2]])
|
||||
[commitment_check_G1, bls.neg(bls.G2())],
|
||||
[bls.bytes48_to_G1(proof), proof_check_g2],
|
||||
])
|
|
@ -0,0 +1,111 @@
|
|||
from itertools import zip_longest
|
||||
from typing import List, Sequence, Self
|
||||
|
||||
from eth2spec.eip7594.mainnet import interpolate_polynomialcoeff
|
||||
|
||||
from da.kzg_rs.common import ROOTS_OF_UNITY
|
||||
|
||||
|
||||
class Polynomial[T]:
|
||||
def __init__(self, coefficients, modulus):
|
||||
self.coefficients = coefficients
|
||||
self.modulus = modulus
|
||||
|
||||
@staticmethod
|
||||
def interpolate(evaluations: List[int], roots_of_unity: List[int]) -> List[int]:
|
||||
"""
|
||||
Lagrange interpolation
|
||||
|
||||
Parameters:
|
||||
evaluations: List of evaluations
|
||||
roots_of_unity: Powers of 2 sequence
|
||||
|
||||
Returns:
|
||||
list: Coefficients of the interpolated polynomial
|
||||
"""
|
||||
return list(map(int, interpolate_polynomialcoeff(roots_of_unity[:len(evaluations)], evaluations)))
|
||||
|
||||
@classmethod
|
||||
def from_evaluations(cls, evaluations: Sequence[T], modulus, roots_of_unity: Sequence[int]=ROOTS_OF_UNITY) -> Self:
|
||||
coefficients = [
|
||||
x % modulus
|
||||
for x in map(int, Polynomial.interpolate(evaluations, roots_of_unity))
|
||||
]
|
||||
return cls(coefficients, modulus)
|
||||
|
||||
def __repr__(self):
|
||||
return "Polynomial({}, modulus={})".format(self.coefficients, self.modulus)
|
||||
|
||||
def __add__(self, other):
|
||||
return Polynomial(
|
||||
[(a + b) % self.modulus for a, b in zip_longest(self.coefficients, other.coefficients, fillvalue=0)],
|
||||
self.modulus
|
||||
)
|
||||
|
||||
def __sub__(self, other):
|
||||
return Polynomial(
|
||||
[(a - b) % self.modulus for a, b in zip_longest(self.coefficients, other.coefficients, fillvalue=0)],
|
||||
self.modulus
|
||||
)
|
||||
|
||||
def __mul__(self, other):
|
||||
result = [0] * (len(self.coefficients) + len(other.coefficients) - 1)
|
||||
for i in range(len(self.coefficients)):
|
||||
for j in range(len(other.coefficients)):
|
||||
result[i + j] = (result[i + j] + self.coefficients[i] * other.coefficients[j]) % self.modulus
|
||||
return Polynomial(result, self.modulus)
|
||||
|
||||
def divide(self, other):
|
||||
if not isinstance(other, Polynomial):
|
||||
raise ValueError("Unsupported type for division.")
|
||||
|
||||
dividend = list(self.coefficients)
|
||||
divisor = list(other.coefficients)
|
||||
|
||||
quotient = []
|
||||
remainder = dividend
|
||||
|
||||
while len(remainder) >= len(divisor):
|
||||
factor = remainder[-1] * pow(divisor[-1], -1, self.modulus) % self.modulus
|
||||
quotient.insert(0, factor)
|
||||
|
||||
# Subtract divisor * factor from remainder
|
||||
for i in range(len(divisor)):
|
||||
remainder[len(remainder) - len(divisor) + i] -= divisor[i] * factor
|
||||
remainder[len(remainder) - len(divisor) + i] %= self.modulus
|
||||
|
||||
# Remove leading zeros from remainder
|
||||
while remainder and remainder[-1] == 0:
|
||||
remainder.pop()
|
||||
|
||||
return Polynomial(quotient, self.modulus), Polynomial(remainder, self.modulus)
|
||||
|
||||
def __truediv__(self, other):
|
||||
return self.divide(other)
|
||||
|
||||
def __neg__(self):
|
||||
return Polynomial([-1 * c for c in self.coefficients], self.modulus)
|
||||
|
||||
def __len__(self):
|
||||
return len(self.coefficients)
|
||||
|
||||
def __iter__(self):
|
||||
return iter(self.coefficients)
|
||||
|
||||
def __getitem__(self, item):
|
||||
return self.coefficients[item]
|
||||
|
||||
def __eq__(self, other):
|
||||
return (
|
||||
self.coefficients == other.coefficients and
|
||||
self.modulus == other.modulus
|
||||
)
|
||||
|
||||
def eval(self, x):
|
||||
return (self.coefficients[0] + sum(
|
||||
(pow(x, i, mod=self.modulus)*coefficient)
|
||||
for i, coefficient in enumerate(self.coefficients[1:], start=1)
|
||||
)) % self.modulus
|
||||
|
||||
def evaluation_form(self) -> List[T]:
|
||||
return [self.eval(ROOTS_OF_UNITY[i]) for i in range(len(self))]
|
|
@ -0,0 +1,25 @@
|
|||
from typing import Tuple
|
||||
|
||||
|
||||
def compute_root_of_unity(primitive_root: int, order: int, modulus: int) -> int:
|
||||
"""
|
||||
Generate a w such that ``w**length = 1``.
|
||||
"""
|
||||
assert (modulus - 1) % order == 0
|
||||
return pow(primitive_root, (modulus - 1) // order, modulus)
|
||||
|
||||
|
||||
def compute_roots_of_unity(primitive_root: int, order: int, modulus: int) -> Tuple[int]:
|
||||
"""
|
||||
Compute a list of roots of unity for a given order.
|
||||
The order must divide the BLS multiplicative group order, i.e. BLS_MODULUS - 1
|
||||
"""
|
||||
assert (modulus - 1) % order == 0
|
||||
root_of_unity = compute_root_of_unity(primitive_root, order, modulus)
|
||||
|
||||
roots = []
|
||||
current_root_of_unity = 1
|
||||
for _ in range(order):
|
||||
roots.append(current_root_of_unity)
|
||||
current_root_of_unity = current_root_of_unity * root_of_unity % modulus
|
||||
return tuple(roots)
|
|
@ -0,0 +1,40 @@
|
|||
from typing import Sequence, Optional
|
||||
|
||||
from eth2spec.deneb.mainnet import BLSFieldElement
|
||||
from .common import BLS_MODULUS
|
||||
from .poly import Polynomial
|
||||
|
||||
ExtendedData = Sequence[Optional[BLSFieldElement]]
|
||||
|
||||
|
||||
def encode(polynomial: Polynomial, factor: int, roots_of_unity: Sequence[int]) -> ExtendedData:
|
||||
"""
|
||||
Encode a polynomial extending to the given factor
|
||||
Parameters:
|
||||
polynomial: Polynomial to be encoded
|
||||
factor: Encoding factor
|
||||
roots_of_unity: Powers of 2 sequence
|
||||
|
||||
Returns:
|
||||
list: Extended data set
|
||||
"""
|
||||
assert factor >= 2
|
||||
assert len(polynomial)*factor <= len(roots_of_unity)
|
||||
return [polynomial.eval(e) for e in roots_of_unity[:len(polynomial)*factor]]
|
||||
|
||||
|
||||
def decode(encoded: ExtendedData, roots_of_unity: Sequence[BLSFieldElement], original_len: int) -> Polynomial:
|
||||
"""
|
||||
Decode a polynomial from an extended data-set and the roots of unity, cap to original length
|
||||
|
||||
Parameters:
|
||||
encoded: Extended data set
|
||||
roots_of_unity: Powers of 2 sequence
|
||||
original_len: Original length of the encoded polynomial
|
||||
|
||||
Returns:
|
||||
Polynomial: original polynomial
|
||||
"""
|
||||
encoded, roots_of_unity = zip(*((point, root) for point, root in zip(encoded, roots_of_unity) if point is not None))
|
||||
coefs = Polynomial.interpolate(list(map(int, encoded)), list(map(int, roots_of_unity)))[:original_len]
|
||||
return Polynomial([int(c) for c in coefs], BLS_MODULUS)
|
|
@ -0,0 +1,14 @@
|
|||
from unittest import TestCase
|
||||
|
||||
from .roots import compute_roots_of_unity
|
||||
from .common import BLS_MODULUS
|
||||
from .fft import fft, ifft
|
||||
|
||||
|
||||
class TestFFT(TestCase):
|
||||
def test_fft_ifft(self):
|
||||
for size in [16, 32, 64, 128, 256, 512, 1024, 2048, 4096]:
|
||||
roots_of_unity = compute_roots_of_unity(2, size, BLS_MODULUS)
|
||||
vals = list(x for x in range(size))
|
||||
vals_fft = fft(vals, roots_of_unity, BLS_MODULUS)
|
||||
self.assertEqual(vals, ifft(vals_fft, roots_of_unity, BLS_MODULUS))
|
|
@ -0,0 +1,28 @@
|
|||
from itertools import chain
|
||||
from unittest import TestCase
|
||||
import random
|
||||
from .fk20 import fk20_generate_proofs
|
||||
from .kzg import generate_element_proof, bytes_to_polynomial
|
||||
from .common import BLS_MODULUS, BYTES_PER_FIELD_ELEMENT, GLOBAL_PARAMETERS, PRIMITIVE_ROOT
|
||||
from .roots import compute_roots_of_unity
|
||||
|
||||
|
||||
class TestFK20(TestCase):
|
||||
@staticmethod
|
||||
def rand_bytes(n_chunks=1024):
|
||||
return bytes(
|
||||
chain.from_iterable(
|
||||
int.to_bytes(random.randrange(BLS_MODULUS), length=BYTES_PER_FIELD_ELEMENT)
|
||||
for _ in range(n_chunks)
|
||||
)
|
||||
)
|
||||
|
||||
def test_fk20(self):
|
||||
for size in [16, 32, 64, 128, 256]:
|
||||
roots_of_unity = compute_roots_of_unity(PRIMITIVE_ROOT, size, BLS_MODULUS)
|
||||
rand_bytes = self.rand_bytes(size)
|
||||
polynomial = bytes_to_polynomial(rand_bytes)
|
||||
proofs = [generate_element_proof(i, polynomial, GLOBAL_PARAMETERS, roots_of_unity) for i in range(size)]
|
||||
fk20_proofs = fk20_generate_proofs(polynomial, GLOBAL_PARAMETERS)
|
||||
self.assertEqual(len(proofs), len(fk20_proofs))
|
||||
self.assertEqual(proofs, fk20_proofs)
|
|
@ -0,0 +1,67 @@
|
|||
from itertools import chain, batched
|
||||
from random import randrange
|
||||
from unittest import TestCase
|
||||
|
||||
from eth2spec.deneb.mainnet import BLS_MODULUS, bytes_to_bls_field, BLSFieldElement
|
||||
|
||||
from da.kzg_rs import kzg
|
||||
from da.kzg_rs.common import BYTES_PER_FIELD_ELEMENT, GLOBAL_PARAMETERS, ROOTS_OF_UNITY, GLOBAL_PARAMETERS_G2
|
||||
from da.kzg_rs.trusted_setup import verify_setup
|
||||
|
||||
|
||||
class TestKZG(TestCase):
|
||||
|
||||
@staticmethod
|
||||
def rand_bytes(n_chunks=1024):
|
||||
return bytes(
|
||||
chain.from_iterable(
|
||||
int.to_bytes(randrange(BLS_MODULUS), length=BYTES_PER_FIELD_ELEMENT)
|
||||
for _ in range(n_chunks)
|
||||
)
|
||||
)
|
||||
|
||||
def test_kzg_setup(self):
|
||||
self.assertTrue(verify_setup((GLOBAL_PARAMETERS, GLOBAL_PARAMETERS_G2)))
|
||||
|
||||
def test_poly_forms(self):
|
||||
n_chunks = 16
|
||||
rand_bytes = self.rand_bytes(n_chunks)
|
||||
eval_form = [int(bytes_to_bls_field(b)) for b in batched(rand_bytes, int(BYTES_PER_FIELD_ELEMENT))]
|
||||
poly = kzg.bytes_to_polynomial(rand_bytes)
|
||||
self.assertEqual(poly.evaluation_form(), eval_form)
|
||||
for i, chunk in enumerate(eval_form):
|
||||
self.assertEqual(poly.eval(ROOTS_OF_UNITY[i]), chunk)
|
||||
for i in range(n_chunks):
|
||||
self.assertEqual(poly.evaluation_form()[i], poly.eval(int(ROOTS_OF_UNITY[i])))
|
||||
|
||||
def test_commitment(self):
|
||||
rand_bytes = self.rand_bytes(32)
|
||||
_, commit = kzg.bytes_to_commitment(rand_bytes, GLOBAL_PARAMETERS)
|
||||
self.assertEqual(len(commit), 48)
|
||||
|
||||
def test_proof(self):
|
||||
rand_bytes = self.rand_bytes(2)
|
||||
poly = kzg.bytes_to_polynomial(rand_bytes)
|
||||
proof = kzg.generate_element_proof(0, poly, GLOBAL_PARAMETERS, ROOTS_OF_UNITY)
|
||||
self.assertEqual(len(proof), 48)
|
||||
|
||||
def test_verify(self):
|
||||
n_chunks = 32
|
||||
rand_bytes = self.rand_bytes(n_chunks)
|
||||
_, commit = kzg.bytes_to_commitment(rand_bytes, GLOBAL_PARAMETERS)
|
||||
poly = kzg.bytes_to_polynomial(rand_bytes)
|
||||
for i, chunk in enumerate(batched(rand_bytes, BYTES_PER_FIELD_ELEMENT)):
|
||||
chunk = bytes(chunk)
|
||||
proof = kzg.generate_element_proof(i, poly, GLOBAL_PARAMETERS, ROOTS_OF_UNITY)
|
||||
self.assertEqual(len(proof), 48)
|
||||
self.assertEqual(poly.eval(int(ROOTS_OF_UNITY[i])), bytes_to_bls_field(chunk))
|
||||
self.assertTrue(kzg.verify_element_proof(
|
||||
bytes_to_bls_field(chunk), commit, proof, i, ROOTS_OF_UNITY
|
||||
)
|
||||
)
|
||||
proof = kzg.generate_element_proof(0, poly, GLOBAL_PARAMETERS, ROOTS_OF_UNITY)
|
||||
for n in range(1, n_chunks):
|
||||
self.assertFalse(kzg.verify_element_proof(
|
||||
BLSFieldElement(0), commit, proof, n, ROOTS_OF_UNITY
|
||||
)
|
||||
)
|
|
@ -0,0 +1,18 @@
|
|||
from unittest import TestCase
|
||||
|
||||
from da.kzg_rs.common import BLS_MODULUS, ROOTS_OF_UNITY
|
||||
from da.kzg_rs.poly import Polynomial
|
||||
from da.kzg_rs.rs import encode, decode
|
||||
|
||||
|
||||
class TestFFT(TestCase):
|
||||
def test_encode_decode(self):
|
||||
poly = Polynomial(list(range(10)), modulus=BLS_MODULUS)
|
||||
encoded = encode(poly, 2, ROOTS_OF_UNITY)
|
||||
# remove a few points, but enough so we can reconstruct
|
||||
for i in [1, 3, 7]:
|
||||
encoded[i] = None
|
||||
decoded = decode(encoded, ROOTS_OF_UNITY, len(poly))
|
||||
# self.assertEqual(poly, decoded)
|
||||
for i in range(len(poly)):
|
||||
self.assertEqual(poly.eval(ROOTS_OF_UNITY[i]), decoded.eval(ROOTS_OF_UNITY[i]))
|
|
@ -0,0 +1,48 @@
|
|||
import random
|
||||
from typing import Tuple, Sequence, Generator
|
||||
from eth2spec.utils import bls
|
||||
from itertools import accumulate, repeat
|
||||
|
||||
|
||||
def __linear_combination(points, coeffs, zero=bls.Z1()):
|
||||
o = zero
|
||||
for point, coeff in zip(points, coeffs):
|
||||
o = bls.add(o, bls.multiply(point, coeff))
|
||||
return o
|
||||
|
||||
|
||||
# Verifies the integrity of a setup
|
||||
def verify_setup(setup) -> bool:
|
||||
g1_setup, g2_setup = setup
|
||||
g1_random_coefficients = [random.randrange(2**40) for _ in range(len(g1_setup) - 1)]
|
||||
g1_lower = __linear_combination(g1_setup[:-1], g1_random_coefficients, bls.Z1())
|
||||
g1_upper = __linear_combination(g1_setup[1:], g1_random_coefficients, bls.Z1())
|
||||
g2_random_coefficients = [random.randrange(2**40) for _ in range(len(g2_setup) - 1)]
|
||||
g2_lower = __linear_combination(g2_setup[:-1], g2_random_coefficients, bls.Z2())
|
||||
g2_upper = __linear_combination(g2_setup[1:], g2_random_coefficients, bls.Z2())
|
||||
return (
|
||||
g1_setup[0] == bls.G1() and
|
||||
g2_setup[0] == bls.G2() and
|
||||
bls.pairing_check([[g1_upper, bls.neg(g2_lower)], [g1_lower, g2_upper]])
|
||||
)
|
||||
|
||||
|
||||
def generate_one_sided_setup(length, secret, generator=bls.G1()):
|
||||
def __take(gen):
|
||||
return (next(gen) for _ in range(length))
|
||||
|
||||
secrets = repeat(secret)
|
||||
|
||||
return __take(accumulate(secrets, bls.multiply, initial=generator))
|
||||
|
||||
|
||||
# Generate a trusted setup with the given secret
|
||||
def generate_setup(
|
||||
g1_length,
|
||||
g2_length,
|
||||
secret
|
||||
) -> Tuple[Generator[bls.G1, None, None], Generator[bls.G2, None, None]]:
|
||||
return (
|
||||
generate_one_sided_setup(g1_length, secret, bls.G1()),
|
||||
generate_one_sided_setup(g2_length, secret, bls.G2()),
|
||||
)
|
|
@ -0,0 +1,5 @@
|
|||
POWERS_OF_2 = {2**i for i in range(1, 32)}
|
||||
|
||||
|
||||
def is_power_of_two(n) -> bool:
|
||||
return n in POWERS_OF_2
|
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
|
@ -0,0 +1,21 @@
|
|||
from libp2p.typing import TProtocol
|
||||
|
||||
"""
|
||||
Some constants for use throught the poc
|
||||
"""
|
||||
|
||||
PROTOCOL_ID = TProtocol("/nomosda/1.0.0")
|
||||
MAX_READ_LEN = 2**32 - 1
|
||||
HASH_LENGTH = 256
|
||||
NODE_PORT_BASE = 7560
|
||||
EXECUTOR_PORT = 8766
|
||||
|
||||
# These can be overridden with cli params
|
||||
DEFAULT_DATA_SIZE = 1024
|
||||
DEFAULT_SUBNETS = 256
|
||||
DEFAULT_NODES = 32
|
||||
DEFAULT_SAMPLE_THRESHOLD = 12
|
||||
# how many nodes per subnet minimum
|
||||
DEFAULT_REPLICATION_FACTOR = 4
|
||||
|
||||
DEBUG = False
|
|
@ -0,0 +1,76 @@
|
|||
# Zone Executor to Nomos DA Communication
|
||||
|
||||
Protocol for communication between the Zone Executor and Nomos DA using Protocol Buffers (protobuf).
|
||||
|
||||
## Overview
|
||||
|
||||
The protocol defines messages used to request and respond to data dispersal, sampling operations, and session control within the Nomos DA system. The communication involves the exchange of blobs (binary large objects) and error handling for various operations.
|
||||
|
||||
## Messages
|
||||
|
||||
### Blob
|
||||
- **Blob**: Represents the binary data to be dispersed.
|
||||
- `bytes blob_id`: Unique identifier for the blob.
|
||||
- `bytes data`: The binary data of the blob.
|
||||
|
||||
### Error Handling
|
||||
- **DispersalErr**: Represents errors related to dispersal operations.
|
||||
- `bytes blob_id`: Unique identifier of the blob related to the error.
|
||||
- `enum DispersalErrType`: Enumeration of dispersal error types.
|
||||
- `CHUNK_SIZE`: Error due to incorrect chunk size.
|
||||
- `VERIFICATION`: Error due to verification failure.
|
||||
- `string err_description`: Description of the error.
|
||||
|
||||
- **SampleErr**: Represents errors related to sample operations.
|
||||
- `bytes blob_id`: Unique identifier of the blob related to the error.
|
||||
- `enum SampleErrType`: Enumeration of sample error types.
|
||||
- `NOT_FOUND`: Error when a blob is not found.
|
||||
- `string err_description`: Description of the error.
|
||||
|
||||
### Dispersal
|
||||
- **DispersalReq**: Request message for dispersing a blob.
|
||||
- `Blob blob`: The blob to be dispersed.
|
||||
|
||||
- **DispersalRes**: Response message for a dispersal request.
|
||||
- `oneof message_type`: Contains either a success response or an error.
|
||||
- `bytes blob_id`: Unique identifier of the dispersed blob.
|
||||
- `DispersalErr err`: Error occurred during dispersal.
|
||||
|
||||
### Sample
|
||||
- **SampleReq**: Request message for sampling a blob.
|
||||
- `bytes blob_id`: Unique identifier of the blob to be sampled.
|
||||
|
||||
- **SampleRes**: Response message for a sample request.
|
||||
- `oneof message_type`: Contains either a success response or an error.
|
||||
- `Blob blob`: The sampled blob.
|
||||
- `SampleErr err`: Error occurred during sampling.
|
||||
|
||||
### Session Control
|
||||
- **CloseMsg**: Message to close a session with a reason.
|
||||
- `enum CloseReason`: Enumeration of close reasons.
|
||||
- `GRACEFUL_SHUTDOWN`: Graceful shutdown of the session.
|
||||
- `SUBNET_CHANGE`: Change in the subnet.
|
||||
- `SUBNET_SAMPLE_FAIL`: Subnet sample failure.
|
||||
- `CloseReason reason`: Reason for closing the session.
|
||||
|
||||
- **SessionReq**: Request message for session control.
|
||||
- `oneof message_type`: Contains one of the following message types.
|
||||
- `CloseMsg close_msg`: Message to close the session.
|
||||
|
||||
### DispersalMessage
|
||||
- **DispersalMessage**: Wrapper message for different types of dispersal and sampling messages.
|
||||
- `oneof message_type`: Contains one of the following message types.
|
||||
- `DispersalReq dispersal_req`: Dispersal request.
|
||||
- `DispersalRes dispersal_res`: Dispersal response.
|
||||
- `SampleReq sample_req`: Sample request.
|
||||
- `SampleRes sample_res`: Sample response.
|
||||
|
||||
## Protobuf
|
||||
|
||||
To generate the updated protobuf serializer from `dispersal.proto`, run the following command:
|
||||
|
||||
```bash
|
||||
protoc --python_out=. dispersal.proto
|
||||
```
|
||||
|
||||
This will generate the necessary Python code to serialize and deserialize the messages defined in the `dispersal.proto` file.
|
Binary file not shown.
Binary file not shown.
Binary file not shown.
|
@ -0,0 +1,87 @@
|
|||
syntax = "proto3";
|
||||
|
||||
package nomos.da.dispersal.v1;
|
||||
|
||||
message Blob {
|
||||
bytes blob_id = 1;
|
||||
bytes data = 2;
|
||||
}
|
||||
|
||||
// DISPERSAL
|
||||
|
||||
message DispersalErr {
|
||||
bytes blob_id = 1;
|
||||
|
||||
enum DispersalErrType {
|
||||
CHUNK_SIZE = 0;
|
||||
VERIFICATION = 1;
|
||||
}
|
||||
|
||||
DispersalErrType err_type = 2;
|
||||
string err_description = 3;
|
||||
}
|
||||
|
||||
message DispersalReq {
|
||||
Blob blob = 1;
|
||||
}
|
||||
|
||||
message DispersalRes {
|
||||
oneof message_type {
|
||||
bytes blob_id = 1;
|
||||
DispersalErr err = 2;
|
||||
}
|
||||
}
|
||||
|
||||
// SAMPLING
|
||||
|
||||
message SampleErr {
|
||||
bytes blob_id = 1;
|
||||
|
||||
enum SampleErrType {
|
||||
NOT_FOUND = 0;
|
||||
}
|
||||
|
||||
SampleErrType err_type = 2;
|
||||
string err_description = 3;
|
||||
}
|
||||
|
||||
message SampleReq {
|
||||
bytes blob_id = 1;
|
||||
}
|
||||
|
||||
message SampleRes {
|
||||
oneof message_type {
|
||||
Blob blob = 1;
|
||||
SampleErr err = 2;
|
||||
}
|
||||
}
|
||||
|
||||
// SESSION CONTROL
|
||||
|
||||
message CloseMsg {
|
||||
enum CloseReason {
|
||||
GRACEFUL_SHUTDOWN = 0;
|
||||
SUBNET_CHANGE = 1;
|
||||
SUBNET_SAMPLE_FAIL = 2;
|
||||
}
|
||||
|
||||
CloseReason reason = 1;
|
||||
}
|
||||
|
||||
message SessionReq {
|
||||
oneof message_type {
|
||||
CloseMsg close_msg = 1;
|
||||
}
|
||||
}
|
||||
|
||||
// WRAPPER MESSAGE
|
||||
|
||||
message DispersalMessage {
|
||||
oneof message_type {
|
||||
DispersalReq dispersal_req = 1;
|
||||
DispersalRes dispersal_res = 2;
|
||||
SampleReq sample_req = 3;
|
||||
SampleRes sample_res = 4;
|
||||
SessionReq session_req = 5;
|
||||
}
|
||||
}
|
|
@ -0,0 +1,60 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# Generated by the protocol buffer compiler. DO NOT EDIT!
|
||||
# NO CHECKED-IN PROTOBUF GENCODE
|
||||
# source: dispersal.proto
|
||||
# Protobuf Python Version: 5.27.1
|
||||
"""Generated protocol buffer code."""
|
||||
from google.protobuf import descriptor as _descriptor
|
||||
from google.protobuf import descriptor_pool as _descriptor_pool
|
||||
from google.protobuf import runtime_version as _runtime_version
|
||||
from google.protobuf import symbol_database as _symbol_database
|
||||
from google.protobuf.internal import builder as _builder
|
||||
_runtime_version.ValidateProtobufRuntimeVersion(
|
||||
_runtime_version.Domain.PUBLIC,
|
||||
5,
|
||||
27,
|
||||
1,
|
||||
'',
|
||||
'dispersal.proto'
|
||||
)
|
||||
# @@protoc_insertion_point(imports)
|
||||
|
||||
_sym_db = _symbol_database.Default()
|
||||
|
||||
|
||||
|
||||
|
||||
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x0f\x64ispersal.proto\x12\x15nomos.da.dispersal.v1\"%\n\x04\x42lob\x12\x0f\n\x07\x62lob_id\x18\x01 \x01(\x0c\x12\x0c\n\x04\x64\x61ta\x18\x02 \x01(\x0c\"\xb6\x01\n\x0c\x44ispersalErr\x12\x0f\n\x07\x62lob_id\x18\x01 \x01(\x0c\x12\x46\n\x08\x65rr_type\x18\x02 \x01(\x0e\x32\x34.nomos.da.dispersal.v1.DispersalErr.DispersalErrType\x12\x17\n\x0f\x65rr_description\x18\x03 \x01(\t\"4\n\x10\x44ispersalErrType\x12\x0e\n\nCHUNK_SIZE\x10\x00\x12\x10\n\x0cVERIFICATION\x10\x01\"9\n\x0c\x44ispersalReq\x12)\n\x04\x62lob\x18\x01 \x01(\x0b\x32\x1b.nomos.da.dispersal.v1.Blob\"e\n\x0c\x44ispersalRes\x12\x11\n\x07\x62lob_id\x18\x01 \x01(\x0cH\x00\x12\x32\n\x03\x65rr\x18\x02 \x01(\x0b\x32#.nomos.da.dispersal.v1.DispersalErrH\x00\x42\x0e\n\x0cmessage_type\"\x97\x01\n\tSampleErr\x12\x0f\n\x07\x62lob_id\x18\x01 \x01(\x0c\x12@\n\x08\x65rr_type\x18\x02 \x01(\x0e\x32..nomos.da.dispersal.v1.SampleErr.SampleErrType\x12\x17\n\x0f\x65rr_description\x18\x03 \x01(\t\"\x1e\n\rSampleErrType\x12\r\n\tNOT_FOUND\x10\x00\"\x1c\n\tSampleReq\x12\x0f\n\x07\x62lob_id\x18\x01 \x01(\x0c\"y\n\tSampleRes\x12+\n\x04\x62lob\x18\x01 \x01(\x0b\x32\x1b.nomos.da.dispersal.v1.BlobH\x00\x12/\n\x03\x65rr\x18\x02 \x01(\x0b\x32 .nomos.da.dispersal.v1.SampleErrH\x00\x42\x0e\n\x0cmessage_type\"\x98\x01\n\x08\x43loseMsg\x12;\n\x06reason\x18\x01 \x01(\x0e\x32+.nomos.da.dispersal.v1.CloseMsg.CloseReason\"O\n\x0b\x43loseReason\x12\x15\n\x11GRACEFUL_SHUTDOWN\x10\x00\x12\x11\n\rSUBNET_CHANGE\x10\x01\x12\x16\n\x12SUBNET_SAMPLE_FAIL\x10\x02\"R\n\nSessionReq\x12\x34\n\tclose_msg\x18\x01 \x01(\x0b\x32\x1f.nomos.da.dispersal.v1.CloseMsgH\x00\x42\x0e\n\x0cmessage_type\"\xc8\x02\n\x10\x44ispersalMessage\x12<\n\rdispersal_req\x18\x01 \x01(\x0b\x32#.nomos.da.dispersal.v1.DispersalReqH\x00\x12<\n\rdispersal_res\x18\x02 \x01(\x0b\x32#.nomos.da.dispersal.v1.DispersalResH\x00\x12\x36\n\nsample_req\x18\x03 \x01(\x0b\x32 .nomos.da.dispersal.v1.SampleReqH\x00\x12\x36\n\nsample_res\x18\x04 \x01(\x0b\x32 .nomos.da.dispersal.v1.SampleResH\x00\x12\x38\n\x0bsession_req\x18\x05 \x01(\x0b\x32!.nomos.da.dispersal.v1.SessionReqH\x00\x42\x0e\n\x0cmessage_typeb\x06proto3')
|
||||
|
||||
_globals = globals()
|
||||
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
|
||||
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'dispersal_pb2', _globals)
|
||||
if not _descriptor._USE_C_DESCRIPTORS:
|
||||
DESCRIPTOR._loaded_options = None
|
||||
_globals['_BLOB']._serialized_start=42
|
||||
_globals['_BLOB']._serialized_end=79
|
||||
_globals['_DISPERSALERR']._serialized_start=82
|
||||
_globals['_DISPERSALERR']._serialized_end=264
|
||||
_globals['_DISPERSALERR_DISPERSALERRTYPE']._serialized_start=212
|
||||
_globals['_DISPERSALERR_DISPERSALERRTYPE']._serialized_end=264
|
||||
_globals['_DISPERSALREQ']._serialized_start=266
|
||||
_globals['_DISPERSALREQ']._serialized_end=323
|
||||
_globals['_DISPERSALRES']._serialized_start=325
|
||||
_globals['_DISPERSALRES']._serialized_end=426
|
||||
_globals['_SAMPLEERR']._serialized_start=429
|
||||
_globals['_SAMPLEERR']._serialized_end=580
|
||||
_globals['_SAMPLEERR_SAMPLEERRTYPE']._serialized_start=550
|
||||
_globals['_SAMPLEERR_SAMPLEERRTYPE']._serialized_end=580
|
||||
_globals['_SAMPLEREQ']._serialized_start=582
|
||||
_globals['_SAMPLEREQ']._serialized_end=610
|
||||
_globals['_SAMPLERES']._serialized_start=612
|
||||
_globals['_SAMPLERES']._serialized_end=733
|
||||
_globals['_CLOSEMSG']._serialized_start=736
|
||||
_globals['_CLOSEMSG']._serialized_end=888
|
||||
_globals['_CLOSEMSG_CLOSEREASON']._serialized_start=809
|
||||
_globals['_CLOSEMSG_CLOSEREASON']._serialized_end=888
|
||||
_globals['_SESSIONREQ']._serialized_start=890
|
||||
_globals['_SESSIONREQ']._serialized_end=972
|
||||
_globals['_DISPERSALMESSAGE']._serialized_start=975
|
||||
_globals['_DISPERSALMESSAGE']._serialized_end=1303
|
||||
# @@protoc_insertion_point(module_scope)
|
|
@ -0,0 +1,123 @@
|
|||
import asyncio
|
||||
import argparse
|
||||
import proto
|
||||
from itertools import count
|
||||
|
||||
conn_id_counter = count(start=1)
|
||||
|
||||
class MockTransport:
|
||||
def __init__(self, conn_id, reader, writer, handler):
|
||||
self.conn_id = conn_id
|
||||
self.reader = reader
|
||||
self.writer = writer
|
||||
self.handler = handler
|
||||
|
||||
async def read_and_process(self):
|
||||
try:
|
||||
while True:
|
||||
message = await proto.unpack_from_reader(self.reader)
|
||||
await self.handler(self.conn_id, self.writer, message)
|
||||
except Exception as e:
|
||||
print(f"MockTransport: An error occurred: {e}")
|
||||
finally:
|
||||
self.writer.close()
|
||||
await self.writer.wait_closed()
|
||||
|
||||
async def write(self, message):
|
||||
self.writer.write(message)
|
||||
await self.writer.drain()
|
||||
|
||||
|
||||
class MockNode:
|
||||
def __init__(self, addr, port, handler=None):
|
||||
self.addr = addr
|
||||
self.port = port
|
||||
self.handler = handler if handler else self._handle
|
||||
|
||||
async def _on_conn(self, reader, writer):
|
||||
conn_id = next(conn_id_counter)
|
||||
transport = MockTransport(conn_id, reader, writer, self.handler)
|
||||
await transport.read_and_process()
|
||||
|
||||
async def _handle(self, conn_id, writer, message):
|
||||
if message.HasField('dispersal_req'):
|
||||
blob_id = message.dispersal_req.blob.blob_id
|
||||
data = message.dispersal_req.blob.data
|
||||
print(f"MockNode: Received DispersalRes: blob_id={blob_id}; data={data}")
|
||||
# Imitate succesful verification.
|
||||
writer.write(proto.new_dispersal_res_success_msg(blob_id))
|
||||
elif message.HasField('sample_req'):
|
||||
print(f"MockNode: Received SampleRes: blob_id={message.sample_req.blob_id}")
|
||||
else:
|
||||
print(f"MockNode: Received unknown message: {message} ")
|
||||
|
||||
async def run(self):
|
||||
server = await asyncio.start_server(
|
||||
self._on_conn, self.addr, self.port
|
||||
)
|
||||
print(f"MockNode: Server started at {self.addr}:{self.port}")
|
||||
async with server:
|
||||
await server.serve_forever()
|
||||
|
||||
|
||||
class MockExecutor:
|
||||
def __init__(self, addr, port, col_num, executor=None, handler=None):
|
||||
self.addr = addr
|
||||
self.port = port
|
||||
self.col_num = col_num
|
||||
self.connections = []
|
||||
self.interval = 10
|
||||
self.executor = executor if executor else self._execute
|
||||
self.handler = handler if handler else self._handle
|
||||
|
||||
async def _execute(self):
|
||||
message = proto.new_dispersal_req_msg(b"dummy_blob_id", b"dummy_data")
|
||||
while True:
|
||||
try:
|
||||
await asyncio.gather(*[t.write(message) for t in self.connections])
|
||||
await asyncio.sleep(self.interval)
|
||||
except asyncio.CancelledError:
|
||||
break
|
||||
except Exception as e:
|
||||
print(f"MockExecutor: Error during message sending: {e}")
|
||||
|
||||
async def _handle(self, conn_id, writer, message):
|
||||
if message.HasField('dispersal_res'):
|
||||
print(f"MockExecutor: Received DispersalRes: blob_id={message.dispersal_res.blob_id}")
|
||||
elif message.HasField('sample_res'):
|
||||
print(f"MockExecutor: Received SampleRes: blob_id={message.sample_res.blob_id}")
|
||||
else:
|
||||
print(f"MockExecutor: Received unknown message: {message}")
|
||||
|
||||
async def _connect(self):
|
||||
try:
|
||||
reader, writer = await asyncio.open_connection(self.addr, self.port)
|
||||
conn_id = len(self.connections)
|
||||
transport = MockTransport(conn_id, reader, writer, self.handler)
|
||||
self.connections.append(transport)
|
||||
print(f"MockExecutor: Connected to {self.addr}:{self.port}, ID: {conn_id}")
|
||||
asyncio.create_task(transport.read_and_process())
|
||||
except Exception as e:
|
||||
print(f"MockExecutor: Failed to connect or lost connection: {e}")
|
||||
|
||||
async def run(self):
|
||||
await asyncio.gather(*(self._connect() for _ in range(self.col_num)))
|
||||
await self.executor()
|
||||
|
||||
|
||||
class MockSystem:
|
||||
def __init__(self, addr='localhost'):
|
||||
self.addr = addr
|
||||
|
||||
async def run_node_with_executor(self, col_number):
|
||||
node = MockNode(self.addr, 8888)
|
||||
executor = MockExecutor(self.addr, 8888, col_number)
|
||||
await asyncio.gather(node.run(), executor.run())
|
||||
|
||||
|
||||
def main():
|
||||
app = MockSystem()
|
||||
asyncio.run(app.run_node_with_executor(1))
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
|
@ -0,0 +1,126 @@
|
|||
from itertools import count
|
||||
|
||||
import dispersal.dispersal_pb2 as dispersal_pb2
|
||||
|
||||
MAX_MSG_LEN_BYTES = 2
|
||||
|
||||
|
||||
def pack_message(message):
|
||||
# SerializeToString method returns an instance of bytes.
|
||||
data = message.SerializeToString()
|
||||
length_prefix = len(data).to_bytes(MAX_MSG_LEN_BYTES, byteorder="big")
|
||||
return length_prefix + data
|
||||
|
||||
|
||||
async def unpack_from_reader(reader):
|
||||
length_prefix = await reader.readexactly(MAX_MSG_LEN_BYTES)
|
||||
data_length = int.from_bytes(length_prefix, byteorder="big")
|
||||
data = await reader.readexactly(data_length)
|
||||
return parse(data)
|
||||
|
||||
|
||||
def unpack_from_bytes(data):
|
||||
length_prefix = data[:MAX_MSG_LEN_BYTES]
|
||||
data_length = int.from_bytes(length_prefix, byteorder="big")
|
||||
return parse(data[MAX_MSG_LEN_BYTES : MAX_MSG_LEN_BYTES + data_length])
|
||||
|
||||
|
||||
def parse(data):
|
||||
message = dispersal_pb2.DispersalMessage()
|
||||
message.ParseFromString(data)
|
||||
return message
|
||||
|
||||
|
||||
# DISPERSAL
|
||||
|
||||
|
||||
def new_dispersal_req_msg(blob_id, data):
|
||||
blob = dispersal_pb2.Blob(blob_id=blob_id, data=data)
|
||||
dispersal_req = dispersal_pb2.DispersalReq(blob=blob)
|
||||
dispersal_message = dispersal_pb2.DispersalMessage(dispersal_req=dispersal_req)
|
||||
return pack_message(dispersal_message)
|
||||
|
||||
|
||||
def new_dispersal_res_success_msg(blob_id):
|
||||
dispersal_res = dispersal_pb2.DispersalRes(blob_id=blob_id)
|
||||
dispersal_message = dispersal_pb2.DispersalMessage(dispersal_res=dispersal_res)
|
||||
return pack_message(dispersal_message)
|
||||
|
||||
|
||||
def new_dispersal_res_chunk_size_error_msg(blob_id, description):
|
||||
dispersal_err = dispersal_pb2.DispersalErr(
|
||||
blob_id=blob_id,
|
||||
err_type=dispersal_pb2.DispersalErr.CHUNK_SIZE,
|
||||
err_description=description,
|
||||
)
|
||||
dispersal_res = dispersal_pb2.DispersalRes(err=dispersal_err)
|
||||
dispersal_message = dispersal_pb2.DispersalMessage(dispersal_res=dispersal_res)
|
||||
return pack_message(dispersal_message)
|
||||
|
||||
|
||||
def new_dispersal_res_verification_error_msg(blob_id, description):
|
||||
dispersal_err = dispersal_pb2.DispersalErr(
|
||||
blob_id=blob_id,
|
||||
err_type=dispersal_pb2.DispersalErr.VERIFICATION,
|
||||
err_description=description,
|
||||
)
|
||||
dispersal_res = dispersal_pb2.DispersalRes(err=dispersal_err)
|
||||
dispersal_message = dispersal_pb2.DispersalMessage(dispersal_res=dispersal_res)
|
||||
return pack_message(dispersal_message)
|
||||
|
||||
|
||||
# SAMPLING
|
||||
|
||||
|
||||
def new_sample_req_msg(blob_id):
|
||||
sample_req = dispersal_pb2.SampleReq(blob_id=blob_id)
|
||||
dispersal_message = dispersal_pb2.DispersalMessage(sample_req=sample_req)
|
||||
return pack_message(dispersal_message)
|
||||
|
||||
|
||||
def new_sample_res_success_msg(blob_id, data):
|
||||
blob = dispersal_pb2.Blob(blob_id=blob_id, data=data)
|
||||
sample_res = dispersal_pb2.SampleRes(blob=blob)
|
||||
dispersal_message = dispersal_pb2.DispersalMessage(sample_res=sample_res)
|
||||
return pack_message(dispersal_message)
|
||||
|
||||
|
||||
def new_sample_res_not_found_error_msg(blob_id, description):
|
||||
sample_err = dispersal_pb2.SampleErr(
|
||||
blob_id=blob_id,
|
||||
err_type=dispersal_pb2.SampleErr.NOT_FOUND,
|
||||
err_description=description,
|
||||
)
|
||||
sample_res = dispersal_pb2.SampleRes(err=sample_err)
|
||||
dispersal_message = dispersal_pb2.DispersalMessage(sample_res=sample_res)
|
||||
return pack_message(dispersal_message)
|
||||
|
||||
|
||||
# SESSION CONTROL
|
||||
|
||||
|
||||
def new_close_msg(reason):
|
||||
close_msg = dispersal_pb2.CloseMsg(reason=reason)
|
||||
return close_msg
|
||||
|
||||
|
||||
def new_session_req_close_msg(reason):
|
||||
close_msg = new_close_msg(reason)
|
||||
session_req = dispersal_pb2.SessionReq(close_msg=close_msg)
|
||||
dispersal_message = dispersal_pb2.DispersalMessage(session_req=session_req)
|
||||
return dispersal_message
|
||||
|
||||
|
||||
def new_session_req_graceful_shutdown_msg():
|
||||
message = new_session_req_close_msg(dispersal_pb2.CloseMsg.GRACEFUL_SHUTDOWN)
|
||||
return pack_message(message)
|
||||
|
||||
|
||||
def new_session_req_subnet_change_msg():
|
||||
message = new_session_req_close_msg(dispersal_pb2.CloseMsg.SUBNET_CHANGE)
|
||||
return pack_message(message)
|
||||
|
||||
|
||||
def new_session_req_subnet_sample_fail_msg():
|
||||
message = new_session_req_close_msg(dispersal_pb2.CloseMsg.SUBNET_SAMPLE_FAIL)
|
||||
return pack_message(message)
|
|
@ -0,0 +1,75 @@
|
|||
import dispersal_pb2
|
||||
import proto
|
||||
from unittest import TestCase
|
||||
|
||||
class TestMessageSerialization(TestCase):
|
||||
|
||||
def test_dispersal_req_msg(self):
|
||||
blob_id = b"dummy_blob_id"
|
||||
data = b"dummy_data"
|
||||
packed_message = proto.new_dispersal_req_msg(blob_id, data)
|
||||
message = proto.unpack_from_bytes(packed_message)
|
||||
self.assertTrue(message.HasField('dispersal_req'))
|
||||
self.assertEqual(message.dispersal_req.blob.blob_id, blob_id)
|
||||
self.assertEqual(message.dispersal_req.blob.data, data)
|
||||
|
||||
def test_dispersal_res_success_msg(self):
|
||||
blob_id = b"dummy_blob_id"
|
||||
packed_message = proto.new_dispersal_res_success_msg(blob_id)
|
||||
message = proto.unpack_from_bytes(packed_message)
|
||||
self.assertTrue(message.HasField('dispersal_res'))
|
||||
self.assertEqual(message.dispersal_res.blob_id, blob_id)
|
||||
|
||||
def test_dispersal_res_chunk_size_error_msg(self):
|
||||
blob_id = b"dummy_blob_id"
|
||||
description = "Chunk size error"
|
||||
packed_message = proto.new_dispersal_res_chunk_size_error_msg(blob_id, description)
|
||||
message = proto.unpack_from_bytes(packed_message)
|
||||
self.assertTrue(message.HasField('dispersal_res'))
|
||||
self.assertEqual(message.dispersal_res.err.blob_id, blob_id)
|
||||
self.assertEqual(message.dispersal_res.err.err_type, dispersal_pb2.DispersalErr.CHUNK_SIZE)
|
||||
self.assertEqual(message.dispersal_res.err.err_description, description)
|
||||
|
||||
def test_dispersal_res_verification_error_msg(self):
|
||||
blob_id = b"dummy_blob_id"
|
||||
description = "Verification error"
|
||||
packed_message = proto.new_dispersal_res_verification_error_msg(blob_id, description)
|
||||
message = proto.unpack_from_bytes(packed_message)
|
||||
self.assertTrue(message.HasField('dispersal_res'))
|
||||
self.assertEqual(message.dispersal_res.err.blob_id, blob_id)
|
||||
self.assertEqual(message.dispersal_res.err.err_type, dispersal_pb2.DispersalErr.VERIFICATION)
|
||||
self.assertEqual(message.dispersal_res.err.err_description, description)
|
||||
|
||||
def test_sample_req_msg(self):
|
||||
blob_id = b"dummy_blob_id"
|
||||
packed_message = proto.new_sample_req_msg(blob_id)
|
||||
message = proto.unpack_from_bytes(packed_message)
|
||||
self.assertTrue(message.HasField('sample_req'))
|
||||
self.assertEqual(message.sample_req.blob_id, blob_id)
|
||||
|
||||
def test_sample_res_success_msg(self):
|
||||
blob_id = b"dummy_blob_id"
|
||||
data = b"dummy_data"
|
||||
packed_message = proto.new_sample_res_success_msg(blob_id, data)
|
||||
message = proto.unpack_from_bytes(packed_message)
|
||||
self.assertTrue(message.HasField('sample_res'))
|
||||
self.assertEqual(message.sample_res.blob.blob_id, blob_id)
|
||||
self.assertEqual(message.sample_res.blob.data, data)
|
||||
|
||||
def test_sample_res_not_found_error_msg(self):
|
||||
blob_id = b"dummy_blob_id"
|
||||
description = "Blob not found"
|
||||
packed_message = proto.new_sample_res_not_found_error_msg(blob_id, description)
|
||||
message = proto.unpack_from_bytes(packed_message)
|
||||
self.assertTrue(message.HasField('sample_res'))
|
||||
self.assertEqual(message.sample_res.err.blob_id, blob_id)
|
||||
self.assertEqual(message.sample_res.err.err_type, dispersal_pb2.SampleErr.NOT_FOUND)
|
||||
self.assertEqual(message.sample_res.err.err_description, description)
|
||||
|
||||
def test_session_req_close_msg(self):
|
||||
reason = dispersal_pb2.CloseMsg.GRACEFUL_SHUTDOWN
|
||||
packed_message = proto.new_session_req_graceful_shutdown_msg()
|
||||
message = proto.unpack_from_bytes(packed_message)
|
||||
self.assertTrue(message.HasField('session_req'))
|
||||
self.assertTrue(message.session_req.HasField('close_msg'))
|
||||
self.assertEqual(message.session_req.close_msg.reason, reason)
|
|
@ -0,0 +1,106 @@
|
|||
from hashlib import sha256
|
||||
from random import randbytes
|
||||
from typing import Self
|
||||
|
||||
import dispersal.proto as proto
|
||||
import multiaddr
|
||||
import trio
|
||||
from constants import HASH_LENGTH, PROTOCOL_ID
|
||||
from libp2p import host, new_host
|
||||
from libp2p.network.stream.net_stream_interface import INetStream
|
||||
from libp2p.peer.peerinfo import info_from_p2p_addr
|
||||
|
||||
|
||||
class Executor:
|
||||
"""
|
||||
A class for simulating a simple executor.
|
||||
|
||||
Runs on hardcoded port.
|
||||
Creates random data and disperses it.
|
||||
One packet represents a subnet, and each packet is sent
|
||||
to one DANode.
|
||||
|
||||
"""
|
||||
|
||||
listen_addr: multiaddr.Multiaddr
|
||||
host: host
|
||||
port: int
|
||||
num_subnets: int
|
||||
node_list: {}
|
||||
# size of a packet
|
||||
data_size: int
|
||||
# holds random data for dispersal
|
||||
data: []
|
||||
# stores hashes of the data for later verification
|
||||
data_hashes: []
|
||||
blob_id: int
|
||||
|
||||
@classmethod
|
||||
def new(cls, port, node_list, num_subnets, data_size) -> Self:
|
||||
self = cls()
|
||||
self.listen_addr = multiaddr.Multiaddr(f"/ip4/0.0.0.0/tcp/{port}")
|
||||
self.host = new_host()
|
||||
self.port = port
|
||||
self.num_subnets = num_subnets
|
||||
self.data_size = data_size
|
||||
# one packet per subnet
|
||||
self.data = [[] * data_size] * num_subnets
|
||||
# one hash per packet. **assumes 256 hash length**
|
||||
self.data_hashes = [[] * HASH_LENGTH] * num_subnets
|
||||
self.node_list = node_list
|
||||
# create random simulated data right from the beginning
|
||||
self.__create_data()
|
||||
return self
|
||||
|
||||
def get_id(self):
|
||||
return self.host.get_id()
|
||||
|
||||
def net_iface(self):
|
||||
return self.host
|
||||
|
||||
def get_port(self):
|
||||
return self.port
|
||||
|
||||
def get_hash(self, index: int):
|
||||
return self.data_hashes[index]
|
||||
|
||||
def __create_data(self):
|
||||
"""
|
||||
Create random data for dispersal
|
||||
One packet of self.data_size length per subnet
|
||||
"""
|
||||
id = sha256()
|
||||
for i in range(self.num_subnets):
|
||||
self.data[i] = randbytes(self.data_size)
|
||||
self.data_hashes[i] = sha256(self.data[i]).hexdigest()
|
||||
id.update(self.data[i])
|
||||
self.blob_id = id.digest()
|
||||
|
||||
async def disperse(self, nursery):
|
||||
"""
|
||||
Disperse the data to the DA network.
|
||||
Sends one packet of data per network node
|
||||
"""
|
||||
|
||||
async with self.host.run(listen_addrs=[self.listen_addr]):
|
||||
for subnet, nodes in self.node_list.items():
|
||||
# get first node of each subnet
|
||||
n = nodes[0]
|
||||
# connect to it...
|
||||
await self.host.connect(n)
|
||||
|
||||
# ...and send (async)
|
||||
stream = await self.host.new_stream(n.peer_id, [PROTOCOL_ID])
|
||||
nursery.start_soon(self.write_data, stream, subnet)
|
||||
|
||||
async def write_data(self, stream: INetStream, index: int) -> None:
|
||||
"""
|
||||
Send data to peer (async)
|
||||
The index is the subnet number
|
||||
"""
|
||||
|
||||
blob_id = self.blob_id
|
||||
blob_data = self.data[index]
|
||||
|
||||
message = proto.new_dispersal_req_msg(blob_id, blob_data)
|
||||
await stream.write(message)
|
|
@ -0,0 +1,35 @@
|
|||
import trio
|
||||
from constants import DEBUG, NODE_PORT_BASE
|
||||
from node import DANode
|
||||
|
||||
|
||||
class DANetwork:
|
||||
"""
|
||||
Lightweight wrapper around a network of DA nodes.
|
||||
Really just creates the network for now
|
||||
"""
|
||||
|
||||
num_nodes: int
|
||||
nodes: []
|
||||
|
||||
def __init__(self, nodes):
|
||||
self.num_nodes = nodes
|
||||
self.nodes = []
|
||||
|
||||
async def build(self, nursery, shutdown, disperse_send):
|
||||
port_idx = NODE_PORT_BASE
|
||||
for _ in range(self.num_nodes):
|
||||
port_idx += 1
|
||||
nursery.start_soon(
|
||||
DANode.new,
|
||||
port_idx,
|
||||
self.nodes,
|
||||
nursery,
|
||||
shutdown,
|
||||
disperse_send.clone(),
|
||||
)
|
||||
if DEBUG:
|
||||
print("net built")
|
||||
|
||||
def get_nodes(self):
|
||||
return self.nodes
|
|
@ -0,0 +1,129 @@
|
|||
import sys
|
||||
from hashlib import sha256
|
||||
from random import randint
|
||||
|
||||
import dispersal.proto as proto
|
||||
import multiaddr
|
||||
import trio
|
||||
from blspy import BasicSchemeMPL, G1Element, PrivateKey
|
||||
from constants import *
|
||||
from libp2p import host, new_host
|
||||
from libp2p.network.stream.exceptions import StreamReset
|
||||
from libp2p.network.stream.net_stream_interface import INetStream
|
||||
from libp2p.peer.peerinfo import info_from_p2p_addr
|
||||
|
||||
|
||||
class DANode:
|
||||
"""
|
||||
A class handling Data Availability (DA).
|
||||
|
||||
Runs on a hardcoded port.
|
||||
Starts a libp2p node.
|
||||
Listens on a handler for receiving data.
|
||||
Resends all data it receives to all peers it is connected to
|
||||
(therefore assumes connection logic is established elsewhere)
|
||||
|
||||
"""
|
||||
|
||||
listen_addr: multiaddr.Multiaddr
|
||||
libp2phost: host
|
||||
port: int
|
||||
node_list: []
|
||||
# list of packet hashes it "stores"
|
||||
hashes: set()
|
||||
|
||||
@classmethod
|
||||
async def new(cls, port, node_list, nursery, shutdown, disperse_send):
|
||||
self = cls()
|
||||
self.listen_addr = multiaddr.Multiaddr(f"/ip4/0.0.0.0/tcp/{port}")
|
||||
self.libp2phost = new_host()
|
||||
self.port = port
|
||||
self.node_list = node_list
|
||||
self.hashes = set()
|
||||
nursery.start_soon(self.__run, nursery, shutdown, disperse_send)
|
||||
if DEBUG:
|
||||
print("DA node at port {} initialized".format(port))
|
||||
|
||||
def get_id(self):
|
||||
return self.libp2phost.get_id()
|
||||
|
||||
def net_iface(self):
|
||||
return self.libp2phost
|
||||
|
||||
def get_port(self):
|
||||
return self.port
|
||||
|
||||
async def __run(self, nursery, shutdown, disperse_send):
|
||||
"""
|
||||
Run the node. Starts libp2p host, and listener for data
|
||||
"""
|
||||
async with self.libp2phost.run(listen_addrs=[self.listen_addr]):
|
||||
print("started node at {}...".format(self.listen_addr))
|
||||
|
||||
# handler to run when data is received
|
||||
async def stream_handler(stream: INetStream) -> None:
|
||||
nursery.start_soon(
|
||||
self.read_data, stream, nursery, shutdown, disperse_send
|
||||
)
|
||||
|
||||
# set the above handler
|
||||
self.libp2phost.set_stream_handler(PROTOCOL_ID, stream_handler)
|
||||
# at this point the node is "initialized" - signal it's "ready"
|
||||
self.node_list.append(self)
|
||||
# run until we shutdown
|
||||
await shutdown.wait()
|
||||
|
||||
async def read_data(
|
||||
self, stream: INetStream, nursery, shutdown, disperse_send
|
||||
) -> None:
|
||||
"""
|
||||
We need to wait for incoming data, but also we want to shutdown
|
||||
when the test is finished.
|
||||
The following code makes sure that both events are listened to
|
||||
and the first which occurs is handled.
|
||||
"""
|
||||
|
||||
first_event = None
|
||||
|
||||
async def select_event(async_fn, cancel_scope):
|
||||
nonlocal first_event
|
||||
first_event = await async_fn()
|
||||
cancel_scope.cancel()
|
||||
disperse_send.close()
|
||||
|
||||
async def read_stream():
|
||||
while True:
|
||||
read_bytes = await stream.read(MAX_READ_LEN)
|
||||
if read_bytes is not None:
|
||||
message = proto.unpack_from_bytes(read_bytes)
|
||||
hashstr = sha256(message.dispersal_req.blob.data).hexdigest()
|
||||
if hashstr not in self.hashes:
|
||||
# "store" the received packet
|
||||
self.hashes.add(hashstr)
|
||||
# now disperse this hash to all peers
|
||||
nursery.start_soon(self.disperse, read_bytes, disperse_send)
|
||||
if DEBUG:
|
||||
print(
|
||||
"{} stored {}".format(
|
||||
self.libp2phost.get_id().pretty(), hashstr
|
||||
)
|
||||
)
|
||||
await disperse_send.send(-1)
|
||||
else:
|
||||
print("read_bytes is None, unexpected!")
|
||||
|
||||
nursery.start_soon(select_event, read_stream, nursery.cancel_scope)
|
||||
nursery.start_soon(select_event, shutdown.wait, nursery.cancel_scope)
|
||||
|
||||
async def disperse(self, packet, disperse_send) -> None:
|
||||
# disperse the given packet to all peers
|
||||
for p_id in self.libp2phost.get_peerstore().peer_ids():
|
||||
if p_id == self.libp2phost.get_id():
|
||||
continue
|
||||
await disperse_send.send(1)
|
||||
stream = await self.libp2phost.new_stream(p_id, [PROTOCOL_ID])
|
||||
|
||||
await stream.write(packet)
|
||||
|
||||
async def has_hash(self, hashstr: str):
|
||||
return hashstr in self.hashes
|
|
@ -0,0 +1,271 @@
|
|||
import argparse
|
||||
import sys
|
||||
import time
|
||||
from random import randint
|
||||
|
||||
import multiaddr
|
||||
import trio
|
||||
from constants import *
|
||||
from executor import Executor
|
||||
from libp2p.peer.peerinfo import info_from_p2p_addr
|
||||
from network import DANetwork
|
||||
from subnet import calculate_subnets
|
||||
|
||||
"""
|
||||
Entry point for the poc.
|
||||
Handles cli arguments, initiates the network
|
||||
and waits for it to complete.
|
||||
|
||||
Also does some simple completion check.
|
||||
"""
|
||||
|
||||
|
||||
async def run_network(params):
|
||||
"""
|
||||
Create the network.
|
||||
Run the run_subnets
|
||||
"""
|
||||
|
||||
num_nodes = int(params.nodes)
|
||||
net = DANetwork(num_nodes)
|
||||
shutdown = trio.Event()
|
||||
disperse_send, disperse_recv = trio.open_memory_channel(0)
|
||||
async with trio.open_nursery() as nursery:
|
||||
nursery.start_soon(net.build, nursery, shutdown, disperse_send)
|
||||
nursery.start_soon(
|
||||
run_subnets, net, params, nursery, shutdown, disperse_send, disperse_recv
|
||||
)
|
||||
|
||||
|
||||
async def run_subnets(net, params, nursery, shutdown, disperse_send, disperse_recv):
|
||||
"""
|
||||
Run the actual PoC logic.
|
||||
Calculate the subnets.
|
||||
-> Establish connections based on the subnets <-
|
||||
Runs the executor.
|
||||
Runs simulated sampling.
|
||||
Runs simple completion check
|
||||
"""
|
||||
|
||||
num_nodes = int(params.nodes)
|
||||
num_subnets = int(params.subnets)
|
||||
data_size = int(params.data_size)
|
||||
sample_threshold = int(params.sample_threshold)
|
||||
fault_rate = int(params.fault_rate)
|
||||
replication_factor = int(params.replication_factor)
|
||||
|
||||
while len(net.get_nodes()) != num_nodes:
|
||||
print("nodes not ready yet")
|
||||
await trio.sleep(0.1)
|
||||
|
||||
print("Nodes ready")
|
||||
nodes = net.get_nodes()
|
||||
subnets = calculate_subnets(nodes, num_subnets, replication_factor)
|
||||
await print_subnet_info(subnets)
|
||||
|
||||
print("Establishing connections...")
|
||||
node_list = {}
|
||||
all_node_instances = set()
|
||||
await establish_connections(subnets, node_list, all_node_instances, fault_rate)
|
||||
|
||||
print("Starting executor...")
|
||||
exe = Executor.new(EXECUTOR_PORT, node_list, num_subnets, data_size)
|
||||
|
||||
print("Start dispersal and wait to complete...")
|
||||
print("depending on network and subnet size this may take a while...")
|
||||
global TIMESTAMP
|
||||
TIMESTAMP = time.time()
|
||||
async with trio.open_nursery() as subnursery:
|
||||
subnursery.start_soon(wait_disperse_finished, disperse_recv, num_subnets)
|
||||
subnursery.start_soon(exe.disperse, nursery)
|
||||
subnursery.start_soon(disperse_watcher, disperse_send.clone())
|
||||
|
||||
print()
|
||||
print()
|
||||
|
||||
print("OK. Start sampling...")
|
||||
checked = []
|
||||
for _ in range(sample_threshold):
|
||||
nursery.start_soon(sample_node, exe, subnets, checked)
|
||||
|
||||
print("Waiting for sampling to finish...")
|
||||
await check_complete(checked, sample_threshold)
|
||||
|
||||
print_connections(all_node_instances)
|
||||
|
||||
print("Test completed")
|
||||
shutdown.set()
|
||||
|
||||
|
||||
TIMESTAMP = time.time()
|
||||
|
||||
|
||||
def print_connections(node_list):
|
||||
for n in node_list:
|
||||
for p in n.net_iface().get_peerstore().peer_ids():
|
||||
if p == n.net_iface().get_id():
|
||||
continue
|
||||
print("node {} is connected to {}".format(n.get_id(), p))
|
||||
print()
|
||||
|
||||
|
||||
async def disperse_watcher(disperse_send):
|
||||
while time.time() - TIMESTAMP < 5:
|
||||
await trio.sleep(1)
|
||||
|
||||
await disperse_send.send(9999)
|
||||
print("canceled")
|
||||
|
||||
|
||||
async def wait_disperse_finished(disperse_recv, num_subnets):
|
||||
# run until there are no changes detected
|
||||
async for value in disperse_recv:
|
||||
if value == 9999:
|
||||
print("dispersal finished")
|
||||
return
|
||||
|
||||
print(".", end="")
|
||||
|
||||
global TIMESTAMP
|
||||
TIMESTAMP = time.time()
|
||||
|
||||
|
||||
async def print_subnet_info(subnets):
|
||||
"""
|
||||
Print which node is in what subnet
|
||||
"""
|
||||
|
||||
print()
|
||||
print("By subnets: ")
|
||||
for subnet in subnets:
|
||||
print("subnet: {} - ".format(subnet), end="")
|
||||
for n in subnets[subnet]:
|
||||
print(n.get_id().pretty()[:16], end=", ")
|
||||
print()
|
||||
|
||||
print()
|
||||
print()
|
||||
|
||||
|
||||
async def establish_connections(subnets, node_list, all_node_instances, fault_rate=0):
|
||||
"""
|
||||
Each node in a subnet connects to the other ones in that subnet.
|
||||
"""
|
||||
for subnet in subnets:
|
||||
# n is a DANode
|
||||
for n in subnets[subnet]:
|
||||
# while nodes connect to each other, they are **mutually** added
|
||||
# to their peer lists. Hence, we don't need to establish connections
|
||||
# again to peers we are already connected.
|
||||
# So in each iteration we get the peer list for the current node
|
||||
# to later check if we are already connected with the next peer
|
||||
this_nodes_peers = n.net_iface().get_peerstore().peer_ids()
|
||||
all_node_instances.add(n)
|
||||
faults = []
|
||||
for i in range(fault_rate):
|
||||
faults.append(randint(0, len(subnets[subnet])))
|
||||
for i, nn in enumerate(subnets[subnet]):
|
||||
# don't connect to self
|
||||
if nn.get_id() == n.get_id():
|
||||
continue
|
||||
if i in faults:
|
||||
continue
|
||||
remote_id = nn.get_id().pretty()
|
||||
remote_port = nn.get_port()
|
||||
# this script only works on localhost!
|
||||
addr = "/ip4/127.0.0.1/tcp/{}/p2p/{}/".format(remote_port, remote_id)
|
||||
remote_addr = multiaddr.Multiaddr(addr)
|
||||
remote = info_from_p2p_addr(remote_addr)
|
||||
if subnet not in node_list:
|
||||
node_list[subnet] = []
|
||||
node_list[subnet].append(remote)
|
||||
# check if we are already connected with this peer. If yes, skip connecting
|
||||
if nn.get_id() in this_nodes_peers:
|
||||
continue
|
||||
if DEBUG:
|
||||
print("{} connecting to {}...".format(n.get_id(), addr))
|
||||
await n.net_iface().connect(remote)
|
||||
|
||||
print()
|
||||
|
||||
|
||||
async def check_complete(checked, sample_threshold):
|
||||
"""
|
||||
Simple completion check:
|
||||
Check how many nodes have already been "sampled"
|
||||
"""
|
||||
|
||||
while len(checked) < sample_threshold:
|
||||
await trio.sleep(0.5)
|
||||
print("check_complete exiting")
|
||||
return
|
||||
|
||||
|
||||
async def sample_node(exe, subnets, checked):
|
||||
"""
|
||||
Pick a random subnet.
|
||||
Pick a random node in that subnet.
|
||||
As the executor has a list of hashes per subnet,
|
||||
we can ask that node if it has that hash.
|
||||
"""
|
||||
|
||||
# s: subnet
|
||||
s = randint(0, len(subnets) - 1)
|
||||
# n: node (index)
|
||||
n = randint(0, len(subnets[s]) - 1)
|
||||
# actual node
|
||||
node = subnets[s][n]
|
||||
# pick the hash to check
|
||||
hashstr = exe.get_hash(s)
|
||||
# run the "sampling"
|
||||
has = await node.has_hash(hashstr)
|
||||
if has:
|
||||
print("node {} has hash {}".format(node.get_id().pretty(), hashstr))
|
||||
else:
|
||||
print("node {} does NOT HAVE hash {}".format(node.get_id().pretty(), hashstr))
|
||||
print("TEST FAILED")
|
||||
# signal we "sampled" another node
|
||||
checked.append(1)
|
||||
return
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("-s", "--subnets", help="Number of subnets [default: 256]")
|
||||
parser.add_argument("-n", "--nodes", help="Number of nodes [default: 32]")
|
||||
parser.add_argument(
|
||||
"-t",
|
||||
"--sample-threshold",
|
||||
help="Threshold for sampling request attempts [default: 12]",
|
||||
)
|
||||
parser.add_argument("-d", "--data-size", help="Size of packages [default: 1024]")
|
||||
parser.add_argument("-f", "--fault_rate", help="Fault rate [default: 0]")
|
||||
parser.add_argument(
|
||||
"-r", "--replication_factor", help="Replication factor [default: 4]"
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
if not args.subnets:
|
||||
args.subnets = DEFAULT_SUBNETS
|
||||
if not args.nodes:
|
||||
args.nodes = DEFAULT_NODES
|
||||
if not args.sample_threshold:
|
||||
args.sample_threshold = DEFAULT_SAMPLE_THRESHOLD
|
||||
if not args.data_size:
|
||||
args.data_size = DEFAULT_DATA_SIZE
|
||||
if not args.replication_factor:
|
||||
args.replication_factor = DEFAULT_REPLICATION_FACTOR
|
||||
if not args.fault_rate:
|
||||
args.fault_rate = 0
|
||||
|
||||
print("Number of subnets will be: {}".format(args.subnets))
|
||||
print("Number of nodes will be: {}".format(args.nodes))
|
||||
print("Size of data package will be: {}".format(args.data_size))
|
||||
print("Threshold for sampling attempts will be: {}".format(args.sample_threshold))
|
||||
print("Fault rate will be: {}".format(args.fault_rate))
|
||||
|
||||
print()
|
||||
print("*******************")
|
||||
print("Starting network...")
|
||||
|
||||
trio.run(run_network, args)
|
|
@ -0,0 +1,65 @@
|
|||
# Data Availability Subnets Proof-Of-Concept
|
||||
|
||||
## Contents
|
||||
This folder contains code as implementation for a Proof-Of-Concept (PoC) for the subnets designed
|
||||
to address dispersal and sampling in Data Availability (DA) in Nomos.
|
||||
|
||||
Refer to the [Specification](https://www.notion.so/Runnable-DA-PoC-Specification-50f204f2ff0a41d09de4926962bbb4ef?d=9e9677e5536a46d49fe95f366b7c3320#308624c50f1a42769b6c142976999483)
|
||||
for the details of the design of this PoC.
|
||||
|
||||
|
||||
Being a PoC, this code has no pretentions in terms of quality, and is certainly not meant to reach anywhere near production status.
|
||||
|
||||
## How to run
|
||||
|
||||
The entry point is `poc.py` , which can be run with a python3 binary.
|
||||
|
||||
It can be parametrized with the following options:
|
||||
|
||||
`python poc.py -s 512 -n 64 -t 12 -d 2048`
|
||||
|
||||
To understand what these parameter mean, just look at the help output:
|
||||
|
||||
```sh
|
||||
> python poc.py -h
|
||||
usage: poc.py [-h] [-s SUBNETS] [-n NODES] [-t SAMPLE_THRESHOLD] [-d DATA_SIZE]
|
||||
|
||||
options:
|
||||
-h, --help show this help message and exit
|
||||
-s SUBNETS, --subnets SUBNETS
|
||||
Number of subnets [default: 256]
|
||||
-n NODES, --nodes NODES
|
||||
Number of nodes [default: 32]
|
||||
-t SAMPLE_THRESHOLD, --sample-threshold SAMPLE_THRESHOLD
|
||||
Threshold for sampling request attempts [default: 12]
|
||||
-d DATA_SIZE, --data-size DATA_SIZE
|
||||
Size of packages [default: 1024]
|
||||
```
|
||||
|
||||
|
||||
## What it does
|
||||
The PoC first creates an instance of a light-weight `DANetwork`, which in turn
|
||||
starts the configured number of nodes.
|
||||
|
||||
[!NOTE]
|
||||
Currently ports are hardcoded. Nodes start at 7561 and are instantiated sequentially from there.
|
||||
The Executor simulator runs on 8766.
|
||||
|
||||
After nodes are up, the subnets are calculated. Subnets calculation is explicitly **not part of the PoC**.
|
||||
Therefore, the PoC uses a simple strategy of filling all subnets sequentially, and if not enough nodes are requested,
|
||||
just fills up nodes up to a `REPLICATION_FACTOR` per subnet (thus, each subnet has at least `REPLICATION_FACTOR` nodes).
|
||||
|
||||
After nodes are assigned to subnets, the network connections (via direct libp2p links) are established.
|
||||
Each node in a subnet connects with every other node in that subnet.
|
||||
|
||||
Next, the executor is started. It is just a simulator. It creates random data for each subnet of `DATA_SIZE` length,
|
||||
simulating the columns generated by the NomosDA protocol.
|
||||
|
||||
It then establishes one connection per subnet and sends one packet of `DATA_SIZE` length on each of these connections.
|
||||
The executor also stores a hash of each packet per subnet.
|
||||
|
||||
Receiving nodes then forward this package to each of their peers in the subnet.
|
||||
They also store the respective hash (only).
|
||||
|
||||
Finally a simulated check samples up to `SAMPLE_THRESHOLD` nodes.
|
||||
For each subnet it simply picks a node randomly and asks if it has the hash.
|
|
@ -0,0 +1,65 @@
|
|||
from random import randint
|
||||
|
||||
from constants import *
|
||||
|
||||
|
||||
def calculate_subnets(node_list, num_subnets, replication_factor):
|
||||
"""
|
||||
Calculate in which subnet(s) to place each node.
|
||||
This PoC does NOT require this to be analyzed,
|
||||
nor to find the best solution.
|
||||
|
||||
Hence, we just use a simple model here:
|
||||
|
||||
1. Iterate all nodes and place each node in the subsequent subnet
|
||||
2. If the subnet list can not be filled, start again from the top of the list
|
||||
3. If each subnet does NOT have at least up to REPLICATION_FACTOR nodes, then
|
||||
fill up the list with nodes up to the factor.
|
||||
|
||||
NOTE: This might be incomplete and/or buggy, but should be sufficient for
|
||||
the purpose of the PoC.
|
||||
|
||||
If however, you find a bug, please report.
|
||||
|
||||
"""
|
||||
# key of dict is the subnet number
|
||||
subnets = {}
|
||||
for i, n in enumerate(node_list):
|
||||
idx = i % num_subnets
|
||||
|
||||
# each key has an array, so multiple nodes can be filter
|
||||
# into a subnet
|
||||
if idx not in subnets:
|
||||
subnets[idx] = []
|
||||
subnets[idx].append(n)
|
||||
|
||||
listlen = len(node_list)
|
||||
i = listlen
|
||||
# if there are less nodes than subnets
|
||||
while i < num_subnets:
|
||||
subnets[i] = []
|
||||
subnets[i].append(node_list[i % listlen])
|
||||
i += 1
|
||||
|
||||
# if not each subnet has at least factor number of nodes, fill up
|
||||
if listlen < replication_factor * num_subnets:
|
||||
for subnet in subnets:
|
||||
last = subnets[subnet][len(subnets[subnet]) - 1].get_id()
|
||||
idx = -1
|
||||
# what is the last filled index of a subnet row
|
||||
for j, n in enumerate(node_list):
|
||||
if n.get_id() == last:
|
||||
idx = j + 1
|
||||
# fill up until factor
|
||||
while len(subnets[subnet]) < replication_factor:
|
||||
# wrap index if at end
|
||||
if idx > len(node_list) - 1:
|
||||
idx = 0
|
||||
# don't add same node multiple times
|
||||
if node_list[idx] in subnets[subnet]:
|
||||
idx += 1
|
||||
continue
|
||||
subnets[subnet].append(node_list[idx])
|
||||
idx += 1
|
||||
|
||||
return subnets
|
|
@ -0,0 +1,20 @@
|
|||
from unittest import TestCase
|
||||
|
||||
from da.common import ChunksMatrix
|
||||
|
||||
|
||||
class TestCommon(TestCase):
|
||||
|
||||
def test_chunks_matrix_columns(self):
|
||||
matrix = ChunksMatrix([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
|
||||
expected = [[1, 4, 7], [2, 5, 8], [3, 6, 9]]
|
||||
for c1, c2 in zip(expected, matrix.columns):
|
||||
self.assertEqual(c1, c2)
|
||||
|
||||
def test_chunks_matrix_transposed(self):
|
||||
matrix = ChunksMatrix([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
|
||||
expected = ChunksMatrix([[1, 4, 7], [2, 5, 8], [3, 6, 9]])
|
||||
self.assertEqual(matrix.transposed(), expected)
|
||||
matrix = ChunksMatrix([[1, 2, 3], [4, 5, 6]])
|
||||
expected = ChunksMatrix([[1, 4], [2, 5], [3, 6]])
|
||||
self.assertEqual(matrix.transposed(), expected)
|
|
@ -0,0 +1,74 @@
|
|||
from hashlib import sha3_256
|
||||
from unittest import TestCase
|
||||
|
||||
from da.encoder import DAEncoderParams, DAEncoder
|
||||
from da.test_encoder import TestEncoder
|
||||
from da.verifier import DAVerifier, DABlob
|
||||
from da.common import NodeId, Attestation, Bitfield, NomosDaG2ProofOfPossession as bls_pop
|
||||
from da.dispersal import Dispersal, EncodedData, DispersalSettings
|
||||
|
||||
|
||||
class TestDispersal(TestCase):
|
||||
def setUp(self):
|
||||
self.n_nodes = 16
|
||||
self.nodes_ids = [NodeId(x.to_bytes(length=32, byteorder='big')) for x in range(self.n_nodes)]
|
||||
self.secret_keys = list(range(1, self.n_nodes+1))
|
||||
self.public_keys = [bls_pop.SkToPk(sk) for sk in self.secret_keys]
|
||||
# sort by pk as we do in dispersal
|
||||
self.secret_keys, self.public_keys = zip(
|
||||
*sorted(zip(self.secret_keys, self.public_keys), key=lambda x: x[1])
|
||||
)
|
||||
dispersal_settings = DispersalSettings(
|
||||
self.nodes_ids,
|
||||
self.public_keys,
|
||||
self.n_nodes // 2 + 1
|
||||
)
|
||||
self.dispersal = Dispersal(dispersal_settings)
|
||||
self.encoder_test = TestEncoder()
|
||||
self.encoder_test.setUp()
|
||||
|
||||
def test_build_certificate_insufficient_attestations(self):
|
||||
with self.assertRaises(AssertionError):
|
||||
self.dispersal._build_certificate(None, [], [])
|
||||
|
||||
def test_build_certificate_enough_attestations(self):
|
||||
mock_encoded_data = EncodedData(
|
||||
None, None, None, [], [], [], bytes(b"f"*48), []
|
||||
)
|
||||
mock_message = sha3_256(mock_encoded_data.aggregated_column_commitment).digest()
|
||||
mock_attestations = [Attestation(bls_pop.Sign(sk, mock_message)) for sk in self.secret_keys]
|
||||
certificate = self.dispersal._build_certificate(
|
||||
mock_encoded_data,
|
||||
mock_attestations,
|
||||
Bitfield([True for _ in range(len(self.secret_keys))])
|
||||
)
|
||||
self.assertIsNotNone(certificate)
|
||||
self.assertEqual(certificate.aggregated_column_commitment, mock_encoded_data.aggregated_column_commitment)
|
||||
self.assertEqual(certificate.row_commitments, [])
|
||||
self.assertIsNotNone(certificate.aggregated_signatures)
|
||||
self.assertTrue(
|
||||
certificate.verify(self.public_keys)
|
||||
)
|
||||
|
||||
def test_disperse(self):
|
||||
data = self.encoder_test.data
|
||||
encoding_params = DAEncoderParams(column_count=self.n_nodes // 2, bytes_per_chunk=31)
|
||||
encoded_data = DAEncoder(encoding_params).encode(data)
|
||||
|
||||
# mock send and await method with local verifiers
|
||||
def __send_and_await_response(node: NodeId, blob: DABlob):
|
||||
sk = self.secret_keys[int.from_bytes(node)]
|
||||
verifier = DAVerifier(sk, self.public_keys)
|
||||
return verifier.verify(blob)
|
||||
# inject mock send and await method
|
||||
self.dispersal._send_and_await_response = __send_and_await_response
|
||||
|
||||
certificate = self.dispersal.disperse(encoded_data)
|
||||
self.assertIsNotNone(certificate)
|
||||
self.assertTrue(certificate.verify(self.public_keys)
|
||||
)
|
||||
self.assertEqual(
|
||||
certificate.signers,
|
||||
[True if i < self.dispersal.settings.threshold else False for i in range(self.n_nodes)]
|
||||
)
|
||||
|
|
@ -0,0 +1,137 @@
|
|||
from itertools import chain, batched
|
||||
from random import randrange, randbytes
|
||||
from unittest import TestCase
|
||||
|
||||
from eth2spec.deneb.mainnet import bytes_to_bls_field
|
||||
|
||||
from da import encoder
|
||||
from da.encoder import DAEncoderParams, DAEncoder
|
||||
from eth2spec.eip7594.mainnet import BYTES_PER_FIELD_ELEMENT, BLSFieldElement
|
||||
|
||||
from da.kzg_rs.common import BLS_MODULUS, ROOTS_OF_UNITY
|
||||
from da.kzg_rs import kzg, rs
|
||||
|
||||
|
||||
class TestEncoder(TestCase):
|
||||
def setUp(self):
|
||||
self.params: DAEncoderParams = DAEncoderParams(column_count=16, bytes_per_chunk=31)
|
||||
self.encoder: DAEncoder = DAEncoder(self.params)
|
||||
self.elements = 32
|
||||
self.data = bytearray(
|
||||
chain.from_iterable(
|
||||
randbytes(self.params.bytes_per_chunk)
|
||||
for _ in range(self.elements)
|
||||
)
|
||||
)
|
||||
|
||||
def assert_encoding(self, encoder_params: DAEncoderParams, data: bytes):
|
||||
encoded_data = encoder.DAEncoder(encoder_params).encode(data)
|
||||
self.assertEqual(encoded_data.data, data)
|
||||
extended_factor = 2
|
||||
column_count = encoder_params.column_count*extended_factor
|
||||
columns_len = len(list(encoded_data.extended_matrix.columns))
|
||||
self.assertEqual(columns_len, column_count)
|
||||
chunks_size = (len(data) // encoder_params.bytes_per_chunk) // encoder_params.column_count
|
||||
self.assertEqual(len(encoded_data.row_commitments), chunks_size)
|
||||
self.assertEqual(len(encoded_data.row_proofs), chunks_size)
|
||||
self.assertEqual(len(encoded_data.row_proofs[0]), column_count)
|
||||
self.assertIsNotNone(encoded_data.aggregated_column_commitment)
|
||||
self.assertEqual(len(encoded_data.aggregated_column_proofs), columns_len)
|
||||
|
||||
# verify rows
|
||||
for row, proofs, commitment in zip(encoded_data.extended_matrix, encoded_data.row_proofs, encoded_data.row_commitments):
|
||||
for i, (chunk, proof) in enumerate(zip(row, proofs)):
|
||||
self.assertTrue(
|
||||
kzg.verify_element_proof(bytes_to_bls_field(chunk), commitment, proof, i, ROOTS_OF_UNITY)
|
||||
)
|
||||
|
||||
# verify column aggregation
|
||||
for i, (column, proof) in enumerate(zip(encoded_data.extended_matrix.columns, encoded_data.aggregated_column_proofs)):
|
||||
data = DAEncoder.hash_column_and_commitment(column, commitment)
|
||||
kzg.verify_element_proof(
|
||||
bytes_to_bls_field(data),
|
||||
encoded_data.aggregated_column_commitment,
|
||||
proof,
|
||||
i,
|
||||
ROOTS_OF_UNITY
|
||||
)
|
||||
|
||||
def test_chunkify(self):
|
||||
encoder_settings = DAEncoderParams(column_count=2, bytes_per_chunk=31)
|
||||
elements = 10
|
||||
data = bytes(chain.from_iterable(int.to_bytes(0, length=31, byteorder='big') for _ in range(elements)))
|
||||
_encoder = encoder.DAEncoder(encoder_settings)
|
||||
chunks_matrix = _encoder._chunkify_data(data)
|
||||
self.assertEqual(len(chunks_matrix), elements//encoder_settings.column_count)
|
||||
for row in chunks_matrix:
|
||||
self.assertEqual(len(row), encoder_settings.column_count)
|
||||
self.assertEqual(len(row[0]), 32)
|
||||
|
||||
def test_compute_row_kzg_commitments(self):
|
||||
chunks_matrix = self.encoder._chunkify_data(self.data)
|
||||
polynomials, commitments = zip(*self.encoder._compute_row_kzg_commitments(chunks_matrix))
|
||||
self.assertEqual(len(commitments), len(chunks_matrix))
|
||||
self.assertEqual(len(polynomials), len(chunks_matrix))
|
||||
|
||||
def test_rs_encode_rows(self):
|
||||
chunks_matrix = self.encoder._chunkify_data(self.data)
|
||||
extended_chunks_matrix = self.encoder._rs_encode_rows(chunks_matrix)
|
||||
for r1, r2 in zip(chunks_matrix, extended_chunks_matrix):
|
||||
self.assertEqual(len(r1), len(r2)//2)
|
||||
r2 = [BLSFieldElement.from_bytes(x) for x in r2]
|
||||
poly_1 = kzg.bytes_to_polynomial(r1.as_bytes())
|
||||
# we check against decoding so we now the encoding was properly done
|
||||
poly_2 = rs.decode(r2, ROOTS_OF_UNITY, len(poly_1))
|
||||
self.assertEqual(poly_1, poly_2)
|
||||
|
||||
def test_compute_rows_proofs(self):
|
||||
chunks_matrix = self.encoder._chunkify_data(self.data)
|
||||
polynomials, commitments = zip(*self.encoder._compute_row_kzg_commitments(chunks_matrix))
|
||||
extended_chunks_matrix = self.encoder._rs_encode_rows(chunks_matrix)
|
||||
original_proofs = self.encoder._compute_rows_proofs(chunks_matrix, polynomials, commitments)
|
||||
extended_proofs = self.encoder._compute_rows_proofs(extended_chunks_matrix, polynomials, commitments)
|
||||
# check original sized matrix
|
||||
for row, poly, commitment, proofs in zip(chunks_matrix, polynomials, commitments, original_proofs):
|
||||
self.assertEqual(len(proofs), len(row))
|
||||
for i, chunk in enumerate(row):
|
||||
self.assertTrue(kzg.verify_element_proof(BLSFieldElement.from_bytes(chunk), commitment, proofs[i], i, ROOTS_OF_UNITY))
|
||||
# check extended matrix
|
||||
for row, poly, commitment, proofs in zip(extended_chunks_matrix, polynomials, commitments, extended_proofs):
|
||||
for i, chunk in enumerate(row):
|
||||
self.assertTrue(kzg.verify_element_proof(BLSFieldElement.from_bytes(chunk), commitment, proofs[i], i, ROOTS_OF_UNITY))
|
||||
|
||||
def test_compute_column_kzg_commitments(self):
|
||||
chunks_matrix = self.encoder._chunkify_data(self.data)
|
||||
polynomials, commitments = zip(*self.encoder._compute_column_kzg_commitments(chunks_matrix))
|
||||
self.assertEqual(len(commitments), len(chunks_matrix[0]))
|
||||
self.assertEqual(len(polynomials), len(chunks_matrix[0]))
|
||||
|
||||
def test_generate_aggregated_column_commitments(self):
|
||||
chunks_matrix = self.encoder._chunkify_data(self.data)
|
||||
_, column_commitments = zip(*self.encoder._compute_column_kzg_commitments(chunks_matrix))
|
||||
poly, commitment = self.encoder._compute_aggregated_column_commitment(chunks_matrix, column_commitments)
|
||||
self.assertIsNotNone(poly)
|
||||
self.assertIsNotNone(commitment)
|
||||
|
||||
def test_generate_aggregated_column_proofs(self):
|
||||
chunks_matrix = self.encoder._chunkify_data(self.data)
|
||||
_, column_commitments = zip(*self.encoder._compute_column_kzg_commitments(chunks_matrix))
|
||||
poly, _ = self.encoder._compute_aggregated_column_commitment(chunks_matrix, column_commitments)
|
||||
proofs = self.encoder._compute_aggregated_column_proofs(poly, column_commitments)
|
||||
self.assertEqual(len(proofs), len(column_commitments))
|
||||
|
||||
def test_encode(self):
|
||||
from random import randbytes
|
||||
sizes = [pow(2, exp) for exp in range(4, 8, 2)]
|
||||
encoder_params = DAEncoderParams(
|
||||
column_count=8,
|
||||
bytes_per_chunk=31
|
||||
)
|
||||
for size in sizes:
|
||||
data = bytes(
|
||||
chain.from_iterable(
|
||||
randbytes(encoder_params.bytes_per_chunk)
|
||||
for _ in range(size*encoder_params.column_count)
|
||||
)
|
||||
)
|
||||
self.assert_encoding(encoder_params, data)
|
|
@ -0,0 +1,134 @@
|
|||
from itertools import chain
|
||||
from unittest import TestCase
|
||||
from typing import List, Optional
|
||||
|
||||
from da.common import NodeId, build_attestation_message, BLSPublicKey, NomosDaG2ProofOfPossession as bls_pop
|
||||
from da.api.common import DAApi, VID, Metadata
|
||||
from da.verifier import DAVerifier, DABlob
|
||||
from da.api.test_flow import MockStore
|
||||
from da.dispersal import Dispersal, DispersalSettings
|
||||
from da.test_encoder import TestEncoder
|
||||
from da.encoder import DAEncoderParams, DAEncoder
|
||||
|
||||
|
||||
class DAVerifierWApi:
|
||||
def __init__(self, sk: int, public_keys: List[BLSPublicKey]):
|
||||
self.store = MockStore()
|
||||
self.api = DAApi(self.store)
|
||||
self.verifier = DAVerifier(sk, public_keys)
|
||||
|
||||
def receive_blob(self, blob: DABlob):
|
||||
if attestation := self.verifier.verify(blob):
|
||||
# Warning: If aggregated col commitment and row commitment are the same,
|
||||
# the build_attestation_message method will produce the same output.
|
||||
cert_id = build_attestation_message(blob.aggregated_column_commitment, blob.rows_commitments)
|
||||
self.store.populate(blob, cert_id)
|
||||
return attestation
|
||||
|
||||
def receive_cert(self, vid: VID):
|
||||
# Usually the certificate would be verifier here,
|
||||
# but we are assuming that this it is already coming from the verified block,
|
||||
# in which case all certificates had been already verified by the DA Node.
|
||||
self.api.write(vid.cert_id, vid.metadata)
|
||||
|
||||
def read(self, app_id, indexes) -> List[Optional[DABlob]]:
|
||||
return self.api.read(app_id, indexes)
|
||||
|
||||
|
||||
class TestFullFlow(TestCase):
|
||||
def setUp(self):
|
||||
self.n_nodes = 16
|
||||
self.nodes_ids = [NodeId(x.to_bytes(length=32, byteorder='big')) for x in range(self.n_nodes)]
|
||||
self.secret_keys = list(range(1, self.n_nodes+1))
|
||||
self.public_keys = [bls_pop.SkToPk(sk) for sk in self.secret_keys]
|
||||
# sort by pk as we do in dispersal
|
||||
self.secret_keys, self.public_keys = zip(
|
||||
*sorted(zip(self.secret_keys, self.public_keys), key=lambda x: x[1])
|
||||
)
|
||||
dispersal_settings = DispersalSettings(
|
||||
self.nodes_ids,
|
||||
self.public_keys,
|
||||
self.n_nodes
|
||||
)
|
||||
self.dispersal = Dispersal(dispersal_settings)
|
||||
self.encoder_test = TestEncoder()
|
||||
self.encoder_test.setUp()
|
||||
|
||||
self.api_nodes = [DAVerifierWApi(k, self.public_keys) for k in self.secret_keys]
|
||||
|
||||
def test_full_flow(self):
|
||||
app_id = int.to_bytes(1)
|
||||
index = 1
|
||||
|
||||
# encoder
|
||||
data = self.encoder_test.data
|
||||
encoding_params = DAEncoderParams(column_count=self.n_nodes // 2, bytes_per_chunk=31)
|
||||
encoded_data = DAEncoder(encoding_params).encode(data)
|
||||
|
||||
# mock send and await method with local verifiers
|
||||
def __send_and_await_response(node: int, blob: DABlob):
|
||||
node = self.api_nodes[int.from_bytes(node)]
|
||||
return node.receive_blob(blob)
|
||||
|
||||
# inject mock send and await method
|
||||
self.dispersal._send_and_await_response = __send_and_await_response
|
||||
certificate = self.dispersal.disperse(encoded_data)
|
||||
|
||||
vid = VID(
|
||||
certificate.id(),
|
||||
Metadata(app_id, index)
|
||||
)
|
||||
|
||||
# verifier
|
||||
for node in self.api_nodes:
|
||||
node.receive_cert(vid)
|
||||
|
||||
# read from api and confirm its working
|
||||
# notice that we need to sort the api_nodes by their public key to have the blobs sorted in the same fashion
|
||||
# we do actually do dispersal.
|
||||
blobs = list(chain.from_iterable(
|
||||
node.read(app_id, [index])
|
||||
for node in sorted(self.api_nodes, key=lambda n: bls_pop.SkToPk(n.verifier.sk))
|
||||
))
|
||||
original_blobs = list(self.dispersal._prepare_data(encoded_data))
|
||||
self.assertEqual(blobs, original_blobs)
|
||||
|
||||
def test_same_blob_multiple_indexes(self):
|
||||
app_id = int.to_bytes(1)
|
||||
indexes = [1, 2, 3] # Different indexes to test with the same blob
|
||||
|
||||
# encoder
|
||||
data = self.encoder_test.data
|
||||
encoding_params = DAEncoderParams(column_count=self.n_nodes // 2, bytes_per_chunk=31)
|
||||
encoded_data = DAEncoder(encoding_params).encode(data)
|
||||
|
||||
# mock send and await method with local verifiers
|
||||
def __send_and_await_response(node: int, blob: DABlob):
|
||||
node = self.api_nodes[int.from_bytes(node)]
|
||||
return node.receive_blob(blob)
|
||||
|
||||
# inject mock send and await method
|
||||
self.dispersal._send_and_await_response = __send_and_await_response
|
||||
certificate = self.dispersal.disperse(encoded_data)
|
||||
|
||||
# Loop through each index and simulate dispersal with the same cert_id but different metadata
|
||||
for index in indexes:
|
||||
vid = VID(
|
||||
certificate.id(),
|
||||
Metadata(app_id, index)
|
||||
)
|
||||
|
||||
# verifier
|
||||
for node in self.api_nodes:
|
||||
node.receive_cert(vid)
|
||||
|
||||
# Verify retrieval for each index
|
||||
for index in indexes:
|
||||
# Notice that we need to sort the api_nodes by their public key to have the blobs sorted in the same fashion
|
||||
# as we do actually do dispersal.
|
||||
blobs = list(chain.from_iterable(
|
||||
node.read(app_id, [index])
|
||||
for node in sorted(self.api_nodes, key=lambda n: bls_pop.SkToPk(n.verifier.sk))
|
||||
))
|
||||
original_blobs = list(self.dispersal._prepare_data(encoded_data))
|
||||
self.assertEqual(blobs, original_blobs, f"Failed at index {index}")
|
|
@ -0,0 +1,71 @@
|
|||
from unittest import TestCase
|
||||
|
||||
from da.common import Column, NomosDaG2ProofOfPossession as bls_pop
|
||||
from da.encoder import DAEncoder
|
||||
from da.kzg_rs import kzg
|
||||
from da.kzg_rs.common import GLOBAL_PARAMETERS, ROOTS_OF_UNITY
|
||||
from da.test_encoder import TestEncoder
|
||||
from da.verifier import Attestation, DAVerifier, DABlob
|
||||
|
||||
|
||||
class TestVerifier(TestCase):
|
||||
|
||||
def setUp(self):
|
||||
self.verifier = DAVerifier(1987, [bls_pop.SkToPk(1987)])
|
||||
|
||||
def test_verify_column(self):
|
||||
column = Column(int.to_bytes(i, length=32) for i in range(8))
|
||||
_, column_commitment = kzg.bytes_to_commitment(column.as_bytes(), GLOBAL_PARAMETERS)
|
||||
aggregated_poly, aggregated_column_commitment = kzg.bytes_to_commitment(
|
||||
DAEncoder.hash_column_and_commitment(column, column_commitment), GLOBAL_PARAMETERS
|
||||
)
|
||||
aggregated_proof = kzg.generate_element_proof(0, aggregated_poly, GLOBAL_PARAMETERS, ROOTS_OF_UNITY)
|
||||
self.assertTrue(
|
||||
self.verifier._verify_column(
|
||||
column, column_commitment, aggregated_column_commitment, aggregated_proof, 0
|
||||
)
|
||||
)
|
||||
|
||||
def test_verify(self):
|
||||
_ = TestEncoder()
|
||||
_.setUp()
|
||||
encoded_data = _.encoder.encode(_.data)
|
||||
verifiers_sk = [i for i in range(1000, 1000+len(encoded_data.chunked_data[0]))]
|
||||
vefiers_pk = [bls_pop.SkToPk(k) for k in verifiers_sk]
|
||||
for i, column in enumerate(encoded_data.chunked_data.columns):
|
||||
verifier = DAVerifier(verifiers_sk[i], vefiers_pk)
|
||||
da_blob = DABlob(
|
||||
Column(column),
|
||||
encoded_data.column_commitments[i],
|
||||
encoded_data.aggregated_column_commitment,
|
||||
encoded_data.aggregated_column_proofs[i],
|
||||
encoded_data.row_commitments,
|
||||
[row[i] for row in encoded_data.row_proofs],
|
||||
)
|
||||
self.assertIsNotNone(verifier.verify(da_blob))
|
||||
|
||||
def test_verify_duplicated_blob(self):
|
||||
_ = TestEncoder()
|
||||
_.setUp()
|
||||
encoded_data = _.encoder.encode(_.data)
|
||||
columns = enumerate(encoded_data.chunked_data.columns)
|
||||
i, column = next(columns)
|
||||
da_blob = DABlob(
|
||||
Column(column),
|
||||
encoded_data.column_commitments[i],
|
||||
encoded_data.aggregated_column_commitment,
|
||||
encoded_data.aggregated_column_proofs[i],
|
||||
encoded_data.row_commitments,
|
||||
[row[i] for row in encoded_data.row_proofs],
|
||||
)
|
||||
self.assertIsNotNone(self.verifier.verify(da_blob))
|
||||
for i, column in columns:
|
||||
da_blob = DABlob(
|
||||
Column(column),
|
||||
encoded_data.column_commitments[i],
|
||||
encoded_data.aggregated_column_commitment,
|
||||
encoded_data.aggregated_column_proofs[i],
|
||||
encoded_data.row_commitments,
|
||||
[row[i] for row in encoded_data.row_proofs],
|
||||
)
|
||||
self.assertIsNone(self.verifier.verify(da_blob))
|
|
@ -0,0 +1,114 @@
|
|||
from dataclasses import dataclass
|
||||
from hashlib import sha3_256
|
||||
from typing import List, Optional, Sequence, Set, Dict
|
||||
|
||||
from eth2spec.deneb.mainnet import BLSFieldElement
|
||||
from eth2spec.eip7594.mainnet import (
|
||||
KZGCommitment as Commitment,
|
||||
KZGProof as Proof,
|
||||
)
|
||||
|
||||
import da.common
|
||||
from da.common import Column, Chunk, Attestation, BLSPrivateKey, BLSPublicKey, NomosDaG2ProofOfPossession as bls_pop
|
||||
from da.encoder import DAEncoder
|
||||
from da.kzg_rs import kzg
|
||||
from da.kzg_rs.common import ROOTS_OF_UNITY, GLOBAL_PARAMETERS, BLS_MODULUS
|
||||
|
||||
|
||||
@dataclass
|
||||
class DABlob:
|
||||
column: Column
|
||||
column_commitment: Commitment
|
||||
aggregated_column_commitment: Commitment
|
||||
aggregated_column_proof: Proof
|
||||
rows_commitments: List[Commitment]
|
||||
rows_proofs: List[Proof]
|
||||
|
||||
def id(self) -> bytes:
|
||||
return da.common.build_attestation_message(self.aggregated_column_commitment, self.rows_commitments)
|
||||
|
||||
def column_id(self) -> bytes:
|
||||
return sha3_256(self.column.as_bytes()).digest()
|
||||
|
||||
|
||||
class DAVerifier:
|
||||
def __init__(self, sk: BLSPrivateKey, nodes_pks: List[BLSPublicKey]):
|
||||
self.attested_blobs: Dict[bytes, (bytes, Attestation)] = dict()
|
||||
self.sk = sk
|
||||
self.index = nodes_pks.index(bls_pop.SkToPk(self.sk))
|
||||
|
||||
@staticmethod
|
||||
def _verify_column(
|
||||
column: Column,
|
||||
column_commitment: Commitment,
|
||||
aggregated_column_commitment: Commitment,
|
||||
aggregated_column_proof: Proof,
|
||||
index: int
|
||||
) -> bool:
|
||||
# 1. compute commitment for column
|
||||
_, computed_column_commitment = kzg.bytes_to_commitment(column.as_bytes(), GLOBAL_PARAMETERS)
|
||||
# 2. If computed column commitment != column commitment, fail
|
||||
if column_commitment != computed_column_commitment:
|
||||
return False
|
||||
# 3. compute column hash
|
||||
column_hash = DAEncoder.hash_column_and_commitment(column, column_commitment)
|
||||
# 4. Check proof with commitment and proof over the aggregated column commitment
|
||||
chunk = BLSFieldElement.from_bytes(column_hash)
|
||||
return kzg.verify_element_proof(
|
||||
chunk, aggregated_column_commitment, aggregated_column_proof, index, ROOTS_OF_UNITY
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _verify_chunk(chunk: Chunk, commitment: Commitment, proof: Proof, index: int) -> bool:
|
||||
chunk = BLSFieldElement(int.from_bytes(bytes(chunk)) % BLS_MODULUS)
|
||||
return kzg.verify_element_proof(chunk, commitment, proof, index, ROOTS_OF_UNITY)
|
||||
|
||||
@staticmethod
|
||||
def _verify_chunks(
|
||||
chunks: Sequence[Chunk],
|
||||
commitments: Sequence[Commitment],
|
||||
proofs: Sequence[Proof],
|
||||
index: int
|
||||
) -> bool:
|
||||
if not (len(chunks) == len(commitments) == len(proofs)):
|
||||
return False
|
||||
for chunk, commitment, proof in zip(chunks, commitments, proofs):
|
||||
if not DAVerifier._verify_chunk(chunk, commitment, proof, index):
|
||||
return False
|
||||
return True
|
||||
|
||||
def _build_attestation(self, blob: DABlob) -> Attestation:
|
||||
hasher = sha3_256()
|
||||
hasher.update(bytes(blob.aggregated_column_commitment))
|
||||
for c in blob.rows_commitments:
|
||||
hasher.update(bytes(c))
|
||||
message = hasher.digest()
|
||||
return Attestation(signature=bls_pop.Sign(self.sk, message))
|
||||
|
||||
def verify(self, blob: DABlob) -> Optional[Attestation]:
|
||||
blob_id = blob.id()
|
||||
if previous_attestation := self.attested_blobs.get(blob_id):
|
||||
column_id, attestation = previous_attestation
|
||||
# we already attested, is cached so we return it
|
||||
if column_id == blob.column_id():
|
||||
return attestation
|
||||
# we already attested and they are asking us to attest the same data different column
|
||||
# skip
|
||||
return None
|
||||
is_column_verified = DAVerifier._verify_column(
|
||||
blob.column,
|
||||
blob.column_commitment,
|
||||
blob.aggregated_column_commitment,
|
||||
blob.aggregated_column_proof,
|
||||
self.index
|
||||
)
|
||||
if not is_column_verified:
|
||||
return
|
||||
are_chunks_verified = DAVerifier._verify_chunks(
|
||||
blob.column, blob.rows_commitments, blob.rows_proofs, self.index
|
||||
)
|
||||
if not are_chunks_verified:
|
||||
return
|
||||
attestation = self._build_attestation(blob)
|
||||
self.attested_blobs[blob_id] = (blob.column_id(), attestation)
|
||||
return attestation
|
Loading…
Reference in New Issue