Merge branch 'mixnet-v2' into mixnet-v2-sim-after-france

This commit is contained in:
Youngjoon Lee 2024-06-27 14:05:50 +09:00
commit 55fd61bc65
No known key found for this signature in database
GPG Key ID: B4253AFBA618BF4D
31 changed files with 433 additions and 692 deletions

View File

@ -4,7 +4,7 @@ from itertools import chain, zip_longest, compress
from typing import List, Generator, Self, Sequence
from eth2spec.eip7594.mainnet import Bytes32, KZGCommitment as Commitment
from py_ecc.bls import G2ProofOfPossession as bls_pop
from py_ecc.bls import G2ProofOfPossession
class NodeId(Bytes32):
@ -67,7 +67,7 @@ class Certificate:
# we sort them as the signers bitfield is sorted by the public keys as well
signers_keys = list(compress(sorted(nodes_public_keys), self.signers))
message = build_attestation_message(self.aggregated_column_commitment, self.row_commitments)
return bls_pop.AggregateVerify(signers_keys, [message]*len(signers_keys), self.aggregated_signatures)
return NomosDaG2ProofOfPossession.AggregateVerify(signers_keys, [message]*len(signers_keys), self.aggregated_signatures)
def build_attestation_message(aggregated_column_commitment: Commitment, row_commitments: Sequence[Commitment]) -> bytes:
@ -76,3 +76,7 @@ def build_attestation_message(aggregated_column_commitment: Commitment, row_comm
for c in row_commitments:
hasher.update(bytes(c))
return hasher.digest()
class NomosDaG2ProofOfPossession(G2ProofOfPossession):
    """BLS proof-of-possession scheme specialized for the Nomos DA protocol.

    Only the DST is overridden, which domain-separates Nomos DA signatures
    from other deployments of the parent scheme.
    NOTE(review): assumes the parent scheme reads ``DST`` as a class
    attribute — confirm against the py_ecc version in use.
    """
    # Domain specific tag for Nomos DA protocol
    DST = b"NOMOS_DA_AVAIL"

View File

@ -2,9 +2,7 @@ from dataclasses import dataclass
from hashlib import sha3_256
from typing import List, Optional, Generator, Sequence
from py_ecc.bls import G2ProofOfPossession as bls_pop
from da.common import Certificate, NodeId, BLSPublicKey, Bitfield, build_attestation_message
from da.common import Certificate, NodeId, BLSPublicKey, Bitfield, build_attestation_message, NomosDaG2ProofOfPossession as bls_pop
from da.encoder import EncodedData
from da.verifier import DABlob, Attestation

View File

@ -1,4 +1,4 @@
from typing import List
from typing import List, Tuple
import eth2spec.eip7594.mainnet
from py_ecc.bls.typing import G1Uncompressed, G2Uncompressed
@ -12,8 +12,11 @@ G2 = G2Uncompressed
BYTES_PER_FIELD_ELEMENT = 32
BLS_MODULUS = eth2spec.eip7594.mainnet.BLS_MODULUS
PRIMITIVE_ROOT: int = 7
GLOBAL_PARAMETERS: List[G1]
GLOBAL_PARAMETERS_G2: List[G2]
# secret is fixed but this should come from a different synchronization protocol
GLOBAL_PARAMETERS, GLOBAL_PARAMETERS_G2 = map(list, generate_setup(1024, 8, 1987))
ROOTS_OF_UNITY: List[int] = compute_roots_of_unity(2, BLS_MODULUS, 4096)
GLOBAL_PARAMETERS, GLOBAL_PARAMETERS_G2 = map(list, generate_setup(4096, 8, 1987))
ROOTS_OF_UNITY: Tuple[int] = compute_roots_of_unity(
PRIMITIVE_ROOT, 4096, BLS_MODULUS
)

65
da/kzg_rs/fft.py Normal file
View File

@ -0,0 +1,65 @@
from typing import Sequence, List
from eth2spec.deneb.mainnet import BLSFieldElement
from eth2spec.utils import bls
from da.kzg_rs.common import G1
def fft_g1(vals: Sequence[G1], roots_of_unity: Sequence[BLSFieldElement], modulus: int) -> List[G1]:
    """Radix-2 Cooley-Tukey FFT over G1 group elements.

    Scalars are taken modulo ``modulus``; point-scalar multiplication is
    delegated to the bls backend.

    :param vals: points to transform; length must be a power of two.
    :param roots_of_unity: the len(vals)-th roots of unity, in order.
    :param modulus: prime field modulus (unused directly here; kept for symmetry with ifft_g1).
    :return: transformed list of points, same length as ``vals``.
    """
    if len(vals) == 1:
        # Base case: a single element is its own transform.
        return vals
    L = fft_g1(vals[::2], roots_of_unity[::2], modulus)
    R = fft_g1(vals[1::2], roots_of_unity[::2], modulus)
    o = [bls.Z1() for _ in vals]
    for i, (x, y) in enumerate(zip(L, R)):
        # Butterfly: combine even/odd halves with the twiddle factor.
        y_times_root = bls.multiply(y, roots_of_unity[i])
        # NOTE(review): relies on the point type overloading binary `+` and
        # unary `-` for group addition/negation — confirm the bls backend
        # in use supports these operators.
        o[i] = (x + y_times_root)
        o[i + len(L)] = x + -y_times_root
    return o
def ifft_g1(vals: Sequence[G1], roots_of_unity: Sequence[BLSFieldElement], modulus: int) -> List[G1]:
    """Inverse FFT over G1 points.

    Runs the forward FFT over the inverse-ordered roots, then scales every
    point by the modular inverse of the length.

    :param vals: points to inverse-transform; length must equal len(roots_of_unity).
    :param roots_of_unity: the forward roots of unity.
    :param modulus: prime field modulus used for the scalar inverse.
    :return: inverse-transformed list of points.
    """
    assert len(vals) == len(roots_of_unity)
    # modular inverse of the length via Fermat's little theorem (modulus is prime)
    invlen = pow(len(vals), modulus-2, modulus)
    return [
        bls.multiply(x, invlen)
        for x in fft_g1(
            # [w^0, w^{n-1}, ..., w^1]: the inverses of the forward roots, in order
            vals, [roots_of_unity[0], *roots_of_unity[:0:-1]], modulus
        )
    ]
def _fft(
        vals: Sequence[BLSFieldElement],
        roots_of_unity: Sequence[BLSFieldElement],
        modulus: int,
) -> Sequence[BLSFieldElement]:
    """Recursive radix-2 Cooley-Tukey FFT over field elements mod ``modulus``.

    :param vals: elements to transform; length must be a power of two.
    :param roots_of_unity: the len(vals)-th roots of unity, in order.
    :param modulus: prime field modulus.
    :return: the transformed sequence, same length as ``vals``.
    """
    if len(vals) == 1:
        # Base case: a single element is its own transform.
        return vals
    L = _fft(vals[::2], roots_of_unity[::2], modulus)
    R = _fft(vals[1::2], roots_of_unity[::2], modulus)
    o = [BLSFieldElement(0) for _ in vals]
    for i, (x, y) in enumerate(zip(L, R)):
        # Butterfly step. All arithmetic is done on plain ints (then wrapped
        # back into BLSFieldElement) so we never rely on the field type's
        # bounds-checked operator overloads; previously the addition mixed
        # an int with a BLSFieldElement while the subtraction converted both.
        y_times_root = BLSFieldElement((int(y) * int(roots_of_unity[i])) % modulus)
        o[i] = BLSFieldElement((int(x) + int(y_times_root)) % modulus)
        o[i + len(L)] = BLSFieldElement((int(x) - int(y_times_root) + modulus) % modulus)
    return o
def fft(vals, root_of_unity, modulus):
    """Forward FFT of ``vals``; ``vals`` and ``root_of_unity`` must have equal length."""
    assert len(vals) == len(root_of_unity)
    transformed = _fft(vals, root_of_unity, modulus)
    return transformed
def ifft(vals, roots_of_unity, modulus):
    """Inverse FFT: undoes ``fft`` over the same roots of unity."""
    assert len(vals) == len(roots_of_unity)
    # Modular inverse of the length via Fermat's little theorem.
    inverse_len = pow(len(vals), modulus - 2, modulus)
    # The inverse transform is the forward FFT over the reversed root order.
    inverse_roots = [roots_of_unity[0], *roots_of_unity[:0:-1]]
    scaled = []
    for element in _fft(vals, inverse_roots, modulus):
        scaled.append(BLSFieldElement((int(element) * inverse_len) % modulus))
    return scaled

78
da/kzg_rs/fk20.py Normal file
View File

@ -0,0 +1,78 @@
from typing import List, Sequence
from eth2spec.deneb.mainnet import KZGProof as Proof, BLSFieldElement
from eth2spec.utils import bls
from da.kzg_rs.common import G1, BLS_MODULUS, PRIMITIVE_ROOT
from da.kzg_rs.fft import fft, fft_g1, ifft_g1
from da.kzg_rs.poly import Polynomial
from da.kzg_rs.roots import compute_roots_of_unity
from da.kzg_rs.utils import is_power_of_two
def __toeplitz1(global_parameters: List[G1], polynomial_degree: int) -> List[G1]:
    """
    FK20 step 1: FFT of the zero-extended trusted-setup vector.

    This part can be precomputed for different global_parameters lengths,
    one per power-of-two polynomial degree, since it depends only on the
    setup and the degree.

    :param global_parameters: trusted-setup G1 points; at least polynomial_degree of them.
    :param polynomial_degree: polynomial size; must be a power of two.
    :return: FFT of the setup points extended with polynomial_degree identity points.
    """
    assert len(global_parameters) >= polynomial_degree
    roots_of_unity = compute_roots_of_unity(PRIMITIVE_ROOT, polynomial_degree*2, BLS_MODULUS)
    global_parameters = global_parameters[:polynomial_degree]
    # algorithm only works on powers of 2 for dft computations
    assert is_power_of_two(len(global_parameters))
    roots_of_unity = roots_of_unity[:2*polynomial_degree]
    # Pad with identity points: multiplying the generator by scalar 0 yields
    # the identity (point at infinity).
    vector_x_extended = global_parameters + [bls.multiply(bls.Z1(), 0) for _ in range(len(global_parameters))]
    vector_x_extended_fft = fft_g1(vector_x_extended, roots_of_unity, BLS_MODULUS)
    return vector_x_extended_fft
def __toeplitz2(coefficients: List[G1], extended_vector: Sequence[G1]) -> List[G1]:
    """FK20 step 2: pointwise-multiply the FFT'd Toeplitz coefficients into the extended setup vector."""
    assert is_power_of_two(len(coefficients))
    roots = compute_roots_of_unity(PRIMITIVE_ROOT, len(coefficients), BLS_MODULUS)
    transformed = fft(coefficients, roots, BLS_MODULUS)
    products = []
    for point, coefficient in zip(extended_vector, transformed):
        products.append(bls.multiply(point, coefficient))
    return products
def __toeplitz3(h_extended_fft: Sequence[G1], polynomial_degree: int) -> List[G1]:
    """FK20 step 3: inverse FFT of the products, keeping only the first
    ``polynomial_degree`` entries (the rest is padding)."""
    roots_of_unity = compute_roots_of_unity(PRIMITIVE_ROOT, len(h_extended_fft), BLS_MODULUS)
    return ifft_g1(h_extended_fft, roots_of_unity, BLS_MODULUS)[:polynomial_degree]
def fk20_generate_proofs(
        polynomial: Polynomial, global_parameters: List[G1]
) -> List[Proof]:
    """
    Generate all proofs for the polynomial points in batch.
    This method uses the fk20 algorithm from https://eprint.iacr.org/2023/033.pdf
    Disclaimer: It only works for polynomial degree of powers of two.
    :param polynomial: polynomial to generate proof for
    :param global_parameters: setup generated parameters
    :return: list of proof for each point in the polynomial
    """
    polynomial_degree = len(polynomial)
    assert len(global_parameters) >= polynomial_degree
    assert is_power_of_two(len(polynomial))
    # 1 - Build toeplitz matrix for h values
    # 1.1 y = dft([s^d-1, s^d-2, ..., s, 1, *[0 for _ in len(polynomial)]])
    # 1.2 z = dft([*[0 for _ in len(polynomial)], f1, f2, ..., fd])
    # 1.3 u = y * v * roots_of_unity(len(polynomial)*2)
    roots_of_unity = compute_roots_of_unity(PRIMITIVE_ROOT, polynomial_degree, BLS_MODULUS)
    # Reverse the setup points (s^{d-2}, ..., s, 1) and pad with the identity
    # point (generator times scalar 0).
    global_parameters = [*global_parameters[polynomial_degree-2::-1], bls.multiply(bls.Z1(), 0)]
    extended_vector = __toeplitz1(global_parameters, polynomial_degree)
    # 2 - Build circulant matrix with the polynomial coefficients (reversed N..n, and padded)
    toeplitz_coefficients = [
        polynomial.coefficients[-1],
        *(BLSFieldElement(0) for _ in range(polynomial_degree+1)),
        *polynomial.coefficients[1:-1]
    ]
    h_extended_vector = __toeplitz2(toeplitz_coefficients, extended_vector)
    # 3 - Perform fft and nub the tail half as it is padding
    h_vector = __toeplitz3(h_extended_vector, polynomial_degree)
    # 4 - proof are the dft of the h vector
    proofs = fft_g1(h_vector, roots_of_unity, BLS_MODULUS)
    proofs = [Proof(bls.G1_to_bytes48(proof)) for proof in proofs]
    return proofs

View File

@ -1,14 +1,25 @@
from typing import Tuple


def compute_root_of_unity(primitive_root: int, order: int, modulus: int) -> int:
    """
    Generate a w such that ``w**order == 1 (mod modulus)``.

    :param primitive_root: a primitive root of the multiplicative group mod ``modulus``.
    :param order: desired order of the root; must divide ``modulus - 1``.
    :param modulus: prime modulus.
    :return: an order-th root of unity modulo ``modulus``.
    """
    # The multiplicative group has order modulus - 1, so an order-th root
    # exists exactly when order divides it.
    assert (modulus - 1) % order == 0
    return pow(primitive_root, (modulus - 1) // order, modulus)
def compute_roots_of_unity(primitive_root: int, order: int, modulus: int) -> Tuple[int, ...]:
    """
    Compute the consecutive powers (1, w, w**2, ...) of an order-th root of unity w.

    The order must divide the multiplicative group order, i.e. ``modulus - 1``.

    :param primitive_root: a primitive root modulo ``modulus``.
    :param order: number of roots to return; must divide ``modulus - 1``.
    :param modulus: prime modulus.
    :return: tuple of the ``order`` roots of unity, starting at 1.
    """
    assert (modulus - 1) % order == 0
    root_of_unity = compute_root_of_unity(primitive_root, order, modulus)
    roots = []
    current_root_of_unity = 1
    for _ in range(order):
        roots.append(current_root_of_unity)
        current_root_of_unity = current_root_of_unity * root_of_unity % modulus
    # Annotation fixed: Tuple[int] means a 1-tuple; this returns a
    # variable-length homogeneous tuple, i.e. Tuple[int, ...].
    return tuple(roots)

14
da/kzg_rs/test_fft.py Normal file
View File

@ -0,0 +1,14 @@
from unittest import TestCase
from .roots import compute_roots_of_unity
from .common import BLS_MODULUS
from .fft import fft, ifft
class TestFFT(TestCase):
    def test_fft_ifft(self):
        """Round-trip property: ifft(fft(v)) == v for every power-of-two size."""
        for n in (16, 32, 64, 128, 256, 512, 1024, 2048, 4096):
            roots = compute_roots_of_unity(2, n, BLS_MODULUS)
            original = list(range(n))
            transformed = fft(original, roots, BLS_MODULUS)
            self.assertEqual(original, ifft(transformed, roots, BLS_MODULUS))

28
da/kzg_rs/test_fk20.py Normal file
View File

@ -0,0 +1,28 @@
from itertools import chain
from unittest import TestCase
import random
from .fk20 import fk20_generate_proofs
from .kzg import generate_element_proof, bytes_to_polynomial
from .common import BLS_MODULUS, BYTES_PER_FIELD_ELEMENT, GLOBAL_PARAMETERS, PRIMITIVE_ROOT
from .roots import compute_roots_of_unity
class TestFK20(TestCase):
    @staticmethod
    def rand_bytes(n_chunks=1024):
        """Concatenate n_chunks random field elements into one byte string."""
        chunks = (
            int.to_bytes(random.randrange(BLS_MODULUS), length=BYTES_PER_FIELD_ELEMENT)
            for _ in range(n_chunks)
        )
        return bytes(chain.from_iterable(chunks))

    def test_fk20(self):
        """FK20 batch proofs must equal the one-at-a-time KZG element proofs."""
        for size in (16, 32, 64, 128, 256):
            roots = compute_roots_of_unity(PRIMITIVE_ROOT, size, BLS_MODULUS)
            polynomial = bytes_to_polynomial(self.rand_bytes(size))
            expected = [
                generate_element_proof(i, polynomial, GLOBAL_PARAMETERS, roots)
                for i in range(size)
            ]
            batched = fk20_generate_proofs(polynomial, GLOBAL_PARAMETERS)
            self.assertEqual(len(expected), len(batched))
            self.assertEqual(expected, batched)

5
da/kzg_rs/utils.py Normal file
View File

@ -0,0 +1,5 @@
# Powers of two from 2**1 through 2**31. Kept for backward compatibility with
# any caller that imports the set directly; is_power_of_two no longer uses it.
POWERS_OF_2 = {2**i for i in range(1, 32)}


def is_power_of_two(n) -> bool:
    """Return True iff ``n`` is a power of two greater than 1.

    Uses the bit trick ``n & (n - 1) == 0``, which holds exactly for powers
    of two (and zero, excluded by the ``n > 1`` guard). Unlike the previous
    set-membership check, this is not capped at 2**31, so e.g. 2**40 is now
    correctly reported as a power of two. ``n == 1`` still returns False,
    matching the original behavior.
    """
    return n > 1 and (n & (n - 1)) == 0

View File

@ -4,11 +4,9 @@ from unittest import TestCase
from da.encoder import DAEncoderParams, DAEncoder
from da.test_encoder import TestEncoder
from da.verifier import DAVerifier, DABlob
from da.common import NodeId, Attestation, Bitfield
from da.common import NodeId, Attestation, Bitfield, NomosDaG2ProofOfPossession as bls_pop
from da.dispersal import Dispersal, EncodedData, DispersalSettings
from py_ecc.bls import G2ProofOfPossession as bls_pop
class TestDispersal(TestCase):
def setUp(self):

View File

@ -2,9 +2,7 @@ from itertools import chain
from unittest import TestCase
from typing import List, Optional
from py_ecc.bls import G2ProofOfPossession as bls_pop
from da.common import NodeId, build_attestation_message, BLSPublicKey
from da.common import NodeId, build_attestation_message, BLSPublicKey, NomosDaG2ProofOfPossession as bls_pop
from da.api.common import DAApi, VID, Metadata
from da.verifier import DAVerifier, DABlob
from da.api.test_flow import MockStore

View File

@ -1,8 +1,6 @@
from unittest import TestCase
from py_ecc.bls import G2ProofOfPossession as bls_pop
from da.common import Column
from da.common import Column, NomosDaG2ProofOfPossession as bls_pop
from da.encoder import DAEncoder
from da.kzg_rs import kzg
from da.kzg_rs.common import GLOBAL_PARAMETERS, ROOTS_OF_UNITY

View File

@ -7,10 +7,9 @@ from eth2spec.eip7594.mainnet import (
KZGCommitment as Commitment,
KZGProof as Proof,
)
from py_ecc.bls import G2ProofOfPossession as bls_pop
import da.common
from da.common import Column, Chunk, Attestation, BLSPrivateKey, BLSPublicKey
from da.common import Column, Chunk, Attestation, BLSPrivateKey, BLSPublicKey, NomosDaG2ProofOfPossession as bls_pop
from da.encoder import DAEncoder
from da.kzg_rs import kzg
from da.kzg_rs.common import ROOTS_OF_UNITY, GLOBAL_PARAMETERS, BLS_MODULUS

View File

@ -1,21 +0,0 @@
# Mixnet Specification
This is the executable specification of Mixnet, which can be used as a networking layer of the Nomos network.
![](structure.png)
## Public Components
- [`mixnet.py`](mixnet.py): A public interface of the Mixnet layer, which can be used by upper layers
- [`robustness.py`](robustness.py): A public interface of the Robustness layer, which can sit on top of the Mixnet layer and be used by upper layers
## Private Components
There are two primary components in the Mixnet layer.
- [`client.py`](client.py): A mix client interface, which splits a message into Sphinx packets, sends packets to mix nodes, and receives messages via gossip. Also, this emits cover packets periodically.
- [`node.py`](node.py): A mix node interface, which receives Sphinx packets from other mix nodes, processes packets, and forwards packets to other mix nodes. This works only when selected by the topology construction.
Each component receives a new topology from the Robustness layer.
There is no interaction between mix client and mix node components.

View File

@ -1,13 +0,0 @@
from typing import TypeAlias

import blspy

from mixnet.utils import random_bytes

# Aliases over blspy primitives so the rest of the package does not depend
# on blspy names directly.
BlsPrivateKey: TypeAlias = blspy.PrivateKey
BlsPublicKey: TypeAlias = blspy.G1Element


def generate_bls() -> BlsPrivateKey:
    """Generate a fresh BLS private key from 32 random bytes."""
    seed = random_bytes(32)
    return blspy.BasicSchemeMPL.key_gen(seed)

View File

@ -1,118 +0,0 @@
from __future__ import annotations
import asyncio
from contextlib import suppress
from typing import Self
from mixnet.config import MixClientConfig, MixnetTopology
from mixnet.node import PacketQueue
from mixnet.packet import PacketBuilder
from mixnet.poisson import poisson_interval_sec
class MixClient:
    """Client side of the mixnet.

    Splits outgoing messages into Sphinx packets and emits them — mixed with
    cover traffic — at a Poisson rate, so traffic timing leaks nothing about
    real activity.
    """

    config: MixClientConfig
    # Real packets waiting to be emitted at the next Poisson tick.
    real_packet_queue: PacketQueue
    # Packets handed to the network layer for sending.
    outbound_socket: PacketQueue
    task: asyncio.Task  # A reference just to prevent task from being garbage collected

    @classmethod
    async def new(
        cls,
        config: MixClientConfig,
    ) -> Self:
        """Create a client and start its background emission loop."""
        self = cls()
        self.config = config
        self.real_packet_queue = asyncio.Queue()
        self.outbound_socket = asyncio.Queue()
        self.task = asyncio.create_task(self.__run())
        return self

    def set_topology(self, topology: MixnetTopology) -> None:
        """
        Replace the old topology with the new topology received.

        In real implementations, this method may be integrated in a long-running task.
        Here in the spec, this method has been simplified as a setter, assuming the single-thread test environment.
        """
        self.config.topology = topology

    # Only for testing
    def get_topology(self) -> MixnetTopology:
        return self.config.topology

    async def send_message(self, msg: bytes) -> None:
        """Split ``msg`` into Sphinx packets and queue them for emission."""
        packets_and_routes = PacketBuilder.build_real_packets(msg, self.config.topology)
        for packet, route in packets_and_routes:
            await self.real_packet_queue.put((route[0].addr, packet))

    def subscribe_messages(self) -> "asyncio.Queue[bytes]":
        """
        Subscribe messages, which went through mix nodes and were broadcasted via gossip.
        """
        return asyncio.Queue()

    async def __run(self):
        """
        Emit packets at the Poisson emission_rate_per_min.

        If a real packet is scheduled to be sent, this thread sends the real packet to the mixnet,
        and schedules redundant real packets to be emitted in the next turns.
        If no real packet is scheduled, this thread emits a cover packet according to the emission_rate_per_min.
        """
        redundant_real_packet_queue: PacketQueue = asyncio.Queue()
        emission_notifier_queue = asyncio.Queue()
        _ = asyncio.create_task(
            self.__emission_notifier(
                self.config.emission_rate_per_min, emission_notifier_queue
            )
        )
        while True:
            # Wait until the next emission time
            _ = await emission_notifier_queue.get()
            try:
                await self.__emit(self.config.redundancy, redundant_real_packet_queue)
            finally:
                # Python convention: indicate that the previously enqueued task has been processed
                emission_notifier_queue.task_done()

    async def __emit(
        self,
        redundancy: int,  # b in the spec
        redundant_real_packet_queue: PacketQueue,
    ):
        """Emit one packet: a pending redundant copy, a real packet, or drop-cover traffic."""
        if not redundant_real_packet_queue.empty():
            addr, packet = redundant_real_packet_queue.get_nowait()
            await self.outbound_socket.put((addr, packet))
            return

        if not self.real_packet_queue.empty():
            addr, packet = self.real_packet_queue.get_nowait()
            # Schedule redundant real packets
            for _ in range(redundancy - 1):
                redundant_real_packet_queue.put_nowait((addr, packet))
            await self.outbound_socket.put((addr, packet))
        # NOTE(review): there is no early return after the real-packet branch
        # in this text, so a cover packet is also emitted in the same turn as
        # a real one; a `return` may have been lost in extraction — confirm
        # against the upstream source.

        packets_and_routes = PacketBuilder.build_drop_cover_packets(
            b"drop cover", self.config.topology
        )
        # We have a for loop here, but we expect that the total num of packets is 1
        # because the dummy message is short.
        for packet, route in packets_and_routes:
            await self.outbound_socket.put((route[0].addr, packet))

    async def __emission_notifier(
        self, emission_rate_per_min: int, queue: asyncio.Queue
    ):
        """Tick ``queue`` at exponentially-distributed intervals (Poisson process)."""
        while True:
            await asyncio.sleep(poisson_interval_sec(emission_rate_per_min))
            queue.put_nowait(None)

    async def cancel(self) -> None:
        """Stop the background emission loop."""
        self.task.cancel()
        with suppress(asyncio.CancelledError):
            await self.task

View File

@ -2,110 +2,56 @@ from __future__ import annotations
import random
from dataclasses import dataclass
from typing import List, TypeAlias
from typing import List
from cryptography.hazmat.primitives.asymmetric.x25519 import (
X25519PrivateKey,
X25519PublicKey,
)
from pysphinx.node import Node
from mixnet.bls import BlsPrivateKey, BlsPublicKey
from mixnet.fisheryates import FisherYates
from pysphinx.sphinx import Node as SphinxNode
@dataclass
class MixnetConfig:
topology_config: MixnetTopologyConfig
mixclient_config: MixClientConfig
mixnode_config: MixNodeConfig
node_configs: List[NodeConfig]
membership: MixMembership
@dataclass
class MixnetTopologyConfig:
mixnode_candidates: List[MixNodeInfo]
size: MixnetTopologySize
entropy: bytes
class NodeConfig:
private_key: X25519PrivateKey
transmission_rate_per_sec: int # Global Transmission Rate
@dataclass
class MixClientConfig:
emission_rate_per_min: int # Poisson rate parameter: lambda
redundancy: int
topology: MixnetTopology
class MixMembership:
nodes: List[NodeInfo]
@dataclass
class MixNodeConfig:
encryption_private_key: X25519PrivateKey
delay_rate_per_min: int # Poisson rate parameter: mu
@dataclass
class MixnetTopology:
# In production, this can be a 1-D array, which is accessible by indexes.
# Here, we use a 2-D array for readability.
layers: List[List[MixNodeInfo]]
def __init__(
self,
config: MixnetTopologyConfig,
) -> None:
"""
Build a new topology deterministically using an entropy and a given set of candidates.
"""
shuffled = FisherYates.shuffle(config.mixnode_candidates, config.entropy)
sampled = shuffled[: config.size.num_total_mixnodes()]
layers = []
for layer_id in range(config.size.num_layers):
start = layer_id * config.size.num_mixnodes_per_layer
layer = sampled[start : start + config.size.num_mixnodes_per_layer]
layers.append(layer)
self.layers = layers
def generate_route(self, mix_destination: MixNodeInfo) -> list[MixNodeInfo]:
def generate_route(self, num_hops: int, last_mix: NodeInfo) -> list[NodeInfo]:
"""
Generate a mix route for a Sphinx packet.
The pre-selected mix_destination is used as a last mix node in the route,
so that associated packets can be merged together into a original message.
"""
route = [random.choice(layer) for layer in self.layers[:-1]]
route.append(mix_destination)
route = [self.choose() for _ in range(num_hops - 1)]
route.append(last_mix)
return route
def choose_mix_destination(self) -> MixNodeInfo:
def choose(self) -> NodeInfo:
"""
Choose a mix node from the last mix layer as a mix destination
that will reconstruct a message from Sphinx packets.
Choose a mix node as a mix destination that will reconstruct a message from Sphinx packets.
"""
return random.choice(self.layers[-1])
return random.choice(self.nodes)
@dataclass
class MixnetTopologySize:
num_layers: int
num_mixnodes_per_layer: int
class NodeInfo:
private_key: X25519PrivateKey
def num_total_mixnodes(self) -> int:
return self.num_layers * self.num_mixnodes_per_layer
def public_key(self) -> X25519PublicKey:
return self.private_key.public_key()
# 32-byte that represents an IP address and a port of a mix node.
NodeAddress: TypeAlias = bytes
@dataclass
class MixNodeInfo:
identity_private_key: BlsPrivateKey
encryption_private_key: X25519PrivateKey
addr: NodeAddress
def identity_public_key(self) -> BlsPublicKey:
return self.identity_private_key.get_g1()
def encryption_public_key(self) -> X25519PublicKey:
return self.encryption_private_key.public_key()
def sphinx_node(self) -> Node:
return Node(self.encryption_private_key, self.addr)
def sphinx_node(self) -> SphinxNode:
# TODO: Use a pre-signed incentive tx, instead of NodeAddress
dummy_node_addr = bytes(32)
return SphinxNode(self.private_key, dummy_node_addr)

View File

@ -1,21 +0,0 @@
import random
from typing import List
class FisherYates:
    @staticmethod
    def shuffle(elements: List, entropy: bytes) -> List:
        """Deterministically shuffle ``elements`` seeded by ``entropy``.

        Python's random.shuffle already implements the Fisher-Yates
        shuffling algorithm:
        https://en.wikipedia.org/wiki/Fisher%E2%80%93Yates_shuffle
        https://softwareengineering.stackexchange.com/a/215780

        :param elements: elements to be shuffled (not mutated; a copy is returned)
        :param entropy: a seed for deterministic sampling
        """
        shuffled = list(elements)
        random.seed(a=entropy, version=2)
        random.shuffle(shuffled)
        # Restore nondeterministic seeding for other users of the module.
        random.seed()
        return shuffled

View File

@ -1,62 +0,0 @@
from __future__ import annotations
import asyncio
from contextlib import suppress
from typing import Self, TypeAlias
from mixnet.client import MixClient
from mixnet.config import MixnetConfig, MixnetTopology, MixnetTopologyConfig
from mixnet.node import MixNode
EntropyQueue: TypeAlias = "asyncio.Queue[bytes]"
class Mixnet:
    """Top-level facade tying together a mix client, a mix node, and the
    entropy-driven topology rebuilds."""

    topology_config: MixnetTopologyConfig
    mixclient: MixClient
    mixnode: MixNode
    # Incoming entropy values that trigger topology reconstruction.
    entropy_queue: EntropyQueue
    task: asyncio.Task  # A reference just to prevent task from being garbage collected

    @classmethod
    async def new(
        cls,
        config: MixnetConfig,
        entropy_queue: EntropyQueue,
    ) -> Self:
        """Start the client, the node, and the entropy-consuming loop."""
        self = cls()
        self.topology_config = config.topology_config
        self.mixclient = await MixClient.new(config.mixclient_config)
        self.mixnode = await MixNode.new(config.mixnode_config)
        self.entropy_queue = entropy_queue
        self.task = asyncio.create_task(self.__consume_entropy())
        return self

    async def publish_message(self, msg: bytes) -> None:
        """Send ``msg`` through the mixnet via the mix client."""
        await self.mixclient.send_message(msg)

    def subscribe_messages(self) -> "asyncio.Queue[bytes]":
        """Subscribe to messages broadcast after traversing the mixnet."""
        return self.mixclient.subscribe_messages()

    async def __consume_entropy(
        self,
    ) -> None:
        # Each entropy value deterministically rebuilds the topology and
        # pushes it to the client.
        while True:
            entropy = await self.entropy_queue.get()
            self.topology_config.entropy = entropy
            topology = MixnetTopology(self.topology_config)
            self.mixclient.set_topology(topology)

    async def cancel(self) -> None:
        """Stop the entropy loop, then shut down the client and node."""
        self.task.cancel()
        with suppress(asyncio.CancelledError):
            await self.task
        await self.mixclient.cancel()
        await self.mixnode.cancel()

    # Only for testing
    def get_topology(self) -> MixnetTopology:
        return self.mixclient.get_topology()

View File

@ -1,107 +1,144 @@
from __future__ import annotations
import asyncio
from contextlib import suppress
from typing import Self, Tuple, TypeAlias
from typing import Awaitable, Callable, TypeAlias
from cryptography.hazmat.primitives.asymmetric.x25519 import (
X25519PrivateKey,
)
from pysphinx.payload import DEFAULT_PAYLOAD_SIZE
from pysphinx.sphinx import (
Payload,
ProcessedFinalHopPacket,
ProcessedForwardHopPacket,
SphinxPacket,
UnknownHeaderTypeError,
)
from mixnet.config import MixNodeConfig, NodeAddress
from mixnet.poisson import poisson_interval_sec
from mixnet.config import MixMembership, NodeConfig
from mixnet.packet import Fragment, MessageFlag, MessageReconstructor, PacketBuilder
PacketQueue: TypeAlias = "asyncio.Queue[Tuple[NodeAddress, SphinxPacket]]"
PacketPayloadQueue: TypeAlias = (
"asyncio.Queue[Tuple[NodeAddress, SphinxPacket | Payload]]"
)
NetworkPacket: TypeAlias = "SphinxPacket | bytes"
NetworkPacketQueue: TypeAlias = "asyncio.Queue[NetworkPacket]"
Connection: TypeAlias = NetworkPacketQueue
BroadcastChannel: TypeAlias = "asyncio.Queue[bytes]"
class MixNode:
"""
A class handling incoming packets with delays
class Node:
config: NodeConfig
membership: MixMembership
mixgossip_channel: MixGossipChannel
reconstructor: MessageReconstructor
broadcast_channel: BroadcastChannel
This class is defined separated with the MixNode class,
in order to define the MixNode as a simple dataclass for clarity.
"""
config: MixNodeConfig
inbound_socket: PacketQueue
outbound_socket: PacketPayloadQueue
task: asyncio.Task # A reference just to prevent task from being garbage collected
@classmethod
async def new(
cls,
config: MixNodeConfig,
) -> Self:
self = cls()
def __init__(self, config: NodeConfig, membership: MixMembership):
self.config = config
self.inbound_socket = asyncio.Queue()
self.outbound_socket = asyncio.Queue()
self.task = asyncio.create_task(self.__run())
return self
self.membership = membership
self.mixgossip_channel = MixGossipChannel(self.__process_sphinx_packet)
self.reconstructor = MessageReconstructor()
self.broadcast_channel = asyncio.Queue()
async def __run(self):
"""
Read SphinxPackets from inbound socket and spawn a thread for each packet to process it.
async def __process_sphinx_packet(
self, packet: SphinxPacket
) -> NetworkPacket | None:
try:
processed = packet.process(self.config.private_key)
match processed:
case ProcessedForwardHopPacket():
return processed.next_packet
case ProcessedFinalHopPacket():
await self.__process_sphinx_payload(processed.payload)
except Exception:
# Return SphinxPacket as it is, if this node cannot unwrap it.
return packet
This thread approximates a M/M/inf queue.
"""
async def __process_sphinx_payload(self, payload: Payload):
msg_with_flag = self.reconstructor.add(
Fragment.from_bytes(payload.recover_plain_playload())
)
if msg_with_flag is not None:
flag, msg = PacketBuilder.parse_msg_and_flag(msg_with_flag)
if flag == MessageFlag.MESSAGE_FLAG_REAL:
await self.broadcast_channel.put(msg)
def connect(self, peer: Node):
conn = asyncio.Queue()
peer.mixgossip_channel.add_inbound(conn)
self.mixgossip_channel.add_outbound(
MixOutboundConnection(conn, self.config.transmission_rate_per_sec)
)
async def send_message(self, msg: bytes):
for packet, _ in PacketBuilder.build_real_packets(msg, self.membership):
await self.mixgossip_channel.gossip(packet)
class MixGossipChannel:
inbound_conns: list[Connection]
outbound_conns: list[MixOutboundConnection]
handler: Callable[[SphinxPacket], Awaitable[NetworkPacket | None]]
def __init__(
self,
handler: Callable[[SphinxPacket], Awaitable[NetworkPacket | None]],
):
self.inbound_conns = []
self.outbound_conns = []
self.handler = handler
# A set just for gathering a reference of tasks to prevent them from being garbage collected.
# https://docs.python.org/3/library/asyncio-task.html#asyncio.create_task
self.tasks = set()
def add_inbound(self, conn: Connection):
self.inbound_conns.append(conn)
task = asyncio.create_task(self.__process_inbound_conn(conn))
self.tasks.add(task)
# To discard the task from the set automatically when it is done.
task.add_done_callback(self.tasks.discard)
def add_outbound(self, conn: MixOutboundConnection):
self.outbound_conns.append(conn)
async def __process_inbound_conn(self, conn: Connection):
while True:
_, packet = await self.inbound_socket.get()
task = asyncio.create_task(
self.__process_packet(
packet,
self.config.encryption_private_key,
self.config.delay_rate_per_min,
)
)
self.tasks.add(task)
# To discard the task from the set automatically when it is done.
task.add_done_callback(self.tasks.discard)
elem = await conn.get()
# In practice, data transmitted through connections is going to be always 'bytes'.
# But here, we use the SphinxPacket type explicitly for simplicity
# without implementing serde for SphinxPacket.
if isinstance(elem, bytes):
assert elem == build_noise_packet()
# Drop packet
continue
elif isinstance(elem, SphinxPacket):
net_packet = await self.handler(elem)
if net_packet is not None:
await self.gossip(net_packet)
async def __process_packet(
self,
packet: SphinxPacket,
encryption_private_key: X25519PrivateKey,
delay_rate_per_min: int, # Poisson rate parameter: mu
):
"""
Process a single packet with a delay that follows exponential distribution,
and forward it to the next mix node or the mix destination
async def gossip(self, packet: NetworkPacket):
for conn in self.outbound_conns:
await conn.send(packet)
This thread is a single server (worker) in a M/M/inf queue that MixNodeRunner approximates.
"""
delay_sec = poisson_interval_sec(delay_rate_per_min)
await asyncio.sleep(delay_sec)
processed = packet.process(encryption_private_key)
match processed:
case ProcessedForwardHopPacket():
await self.outbound_socket.put(
(processed.next_node_address, processed.next_packet)
)
case ProcessedFinalHopPacket():
await self.outbound_socket.put(
(processed.destination_node_address, processed.payload)
)
case _:
raise UnknownHeaderTypeError
class MixOutboundConnection:
queue: NetworkPacketQueue
conn: Connection
transmission_rate_per_sec: int
async def cancel(self) -> None:
self.task.cancel()
with suppress(asyncio.CancelledError):
await self.task
def __init__(self, conn: Connection, transmission_rate_per_sec: int):
self.queue = asyncio.Queue()
self.conn = conn
self.transmission_rate_per_sec = transmission_rate_per_sec
self.task = asyncio.create_task(self.__run())
async def __run(self):
while True:
await asyncio.sleep(1 / self.transmission_rate_per_sec)
# TODO: time mixing
if self.queue.empty():
elem = build_noise_packet()
else:
elem = self.queue.get_nowait()
await self.conn.put(elem)
async def send(self, elem: NetworkPacket):
await self.queue.put(elem)
def build_noise_packet() -> bytes:
return bytes(DEFAULT_PAYLOAD_SIZE)

View File

@ -9,7 +9,7 @@ from typing import Dict, List, Self, Tuple, TypeAlias
from pysphinx.payload import Payload
from pysphinx.sphinx import SphinxPacket
from mixnet.config import MixnetTopology, MixNodeInfo
from mixnet.config import MixMembership, NodeInfo
class MessageFlag(Enum):
@ -23,25 +23,25 @@ class MessageFlag(Enum):
class PacketBuilder:
@staticmethod
def build_real_packets(
message: bytes, topology: MixnetTopology
) -> List[Tuple[SphinxPacket, List[MixNodeInfo]]]:
message: bytes, membership: MixMembership
) -> List[Tuple[SphinxPacket, List[NodeInfo]]]:
return PacketBuilder.__build_packets(
MessageFlag.MESSAGE_FLAG_REAL, message, topology
MessageFlag.MESSAGE_FLAG_REAL, message, membership
)
@staticmethod
def build_drop_cover_packets(
message: bytes, topology: MixnetTopology
) -> List[Tuple[SphinxPacket, List[MixNodeInfo]]]:
message: bytes, membership: MixMembership
) -> List[Tuple[SphinxPacket, List[NodeInfo]]]:
return PacketBuilder.__build_packets(
MessageFlag.MESSAGE_FLAG_DROP_COVER, message, topology
MessageFlag.MESSAGE_FLAG_DROP_COVER, message, membership
)
@staticmethod
def __build_packets(
flag: MessageFlag, message: bytes, topology: MixnetTopology
) -> List[Tuple[SphinxPacket, List[MixNodeInfo]]]:
destination = topology.choose_mix_destination()
flag: MessageFlag, message: bytes, membership: MixMembership
) -> List[Tuple[SphinxPacket, List[NodeInfo]]]:
last_mix = membership.choose()
msg_with_flag = flag.bytes() + message
# NOTE: We don't encrypt msg_with_flag for destination.
@ -50,11 +50,11 @@ class PacketBuilder:
out = []
for fragment in fragment_set.fragments:
route = topology.generate_route(destination)
route = membership.generate_route(3, last_mix)
packet = SphinxPacket.build(
fragment.bytes(),
[mixnode.sphinx_node() for mixnode in route],
destination.sphinx_node(),
last_mix.sphinx_node(),
)
out.append((packet, route))

View File

@ -1,13 +0,0 @@
import numpy
def poisson_interval_sec(rate_per_min: int) -> float:
    """Sample the waiting time (in seconds) until the next event of a Poisson
    process with the given per-minute rate.

    Poisson arrivals have exponentially distributed inter-arrival times with
    scale 1/rate, so a single exponential draw (in minutes) is converted to
    seconds.
    """
    sample_min = numpy.random.exponential(scale=1 / rate_per_min, size=1)[0]
    return 60 * sample_min
def poisson_mean_interval_sec(rate_per_min: int) -> float:
    """Return the mean inter-event interval (seconds) of a Poisson process
    that fires ``rate_per_min`` times per minute on average."""
    mean_interval_min = 1 / rate_per_min
    return mean_interval_min * 60

Binary file not shown.

Before

Width:  |  Height:  |  Size: 42 KiB

View File

@ -1,45 +0,0 @@
from datetime import datetime
from unittest import IsolatedAsyncioTestCase
import numpy
from mixnet.client import MixClient
from mixnet.poisson import poisson_mean_interval_sec
from mixnet.test_utils import (
init_mixnet_config,
with_test_timeout,
)
from mixnet.utils import random_bytes
class TestMixClient(IsolatedAsyncioTestCase):
    # Statistical test: verifies the client emits packets at its configured
    # Poisson rate. Bounded by a hard timeout so a stalled client fails fast.
    @with_test_timeout(100)
    async def test_mixclient(self):
        # 30 emissions/min with 3x packet redundancy.
        config = init_mixnet_config().mixclient_config
        config.emission_rate_per_min = 30
        config.redundancy = 3

        mixclient = await MixClient.new(config)
        try:
            # Send a 3500-byte msg, expecting that it is split into at least two packets
            await mixclient.send_message(random_bytes(3500))

            # Calculate intervals between packet emissions from the mix client
            intervals = []
            ts = datetime.now()
            for _ in range(30):
                _ = await mixclient.outbound_socket.get()
                now = datetime.now()
                intervals.append((now - ts).total_seconds())
                ts = now

            # Check if packets were emitted at the Poisson emission_rate
            # If emissions follow the Poisson distribution with a rate `lambda`,
            # a mean interval between emissions must be `1/lambda`.
            # delta=1.0 tolerates sampling noise over only 30 observations.
            self.assertAlmostEqual(
                float(numpy.mean(intervals)),
                poisson_mean_interval_sec(config.emission_rate_per_min),
                delta=1.0,
            )
        finally:
            # Always stop the client's background tasks, even on failure.
            await mixclient.cancel()

View File

@ -1,21 +0,0 @@
from unittest import TestCase
from mixnet.fisheryates import FisherYates
class TestFisherYates(TestCase):
    def test_shuffle(self):
        """A shuffle must be a permutation, deterministic for a fixed
        entropy, and (almost surely) different under different entropy."""
        seed = b"hello"
        original = [1, 2, 3, 4, 5]

        first = FisherYates.shuffle(original, seed)
        self.assertEqual(sorted(original), sorted(first))

        # Same entropy must reproduce the same permutation.
        second = FisherYates.shuffle(original, seed)
        self.assertEqual(first, second)

        # A different entropy should yield a different permutation,
        # while still being a permutation of the input.
        third = FisherYates.shuffle(original, b"world")
        self.assertNotEqual(first, third)
        self.assertEqual(sorted(original), sorted(third))

View File

@ -1,20 +0,0 @@
import asyncio
from unittest import IsolatedAsyncioTestCase
from mixnet.mixnet import Mixnet
from mixnet.test_utils import init_mixnet_config
class TestMixnet(IsolatedAsyncioTestCase):
    async def test_topology_from_robustness(self):
        # Verifies that new entropy arriving on the queue triggers a
        # topology rebuild inside the running mixnet.
        config = init_mixnet_config()
        entropy_queue = asyncio.Queue()

        mixnet = await Mixnet.new(config, entropy_queue)
        try:
            old_topology = config.mixclient_config.topology
            await entropy_queue.put(b"new entropy")
            # Give the mixnet's background task a moment to consume the
            # entropy and rebuild the topology.
            await asyncio.sleep(1)
            self.assertNotEqual(old_topology, mixnet.get_topology())
        finally:
            # Always tear down background tasks, even on assertion failure.
            await mixnet.cancel()

View File

@ -1,117 +1,37 @@
import asyncio
from datetime import datetime
from unittest import IsolatedAsyncioTestCase
import numpy
from pysphinx.sphinx import SphinxPacket
from mixnet.node import MixNode, NodeAddress, PacketQueue
from mixnet.packet import PacketBuilder
from mixnet.poisson import poisson_interval_sec, poisson_mean_interval_sec
from mixnet.node import Node
from mixnet.test_utils import (
init_mixnet_config,
with_test_timeout,
)
class TestMixNodeRunner(IsolatedAsyncioTestCase):
@with_test_timeout(180)
async def test_mixnode_emission_rate(self):
"""
Test if MixNodeRunner works as a M/M/inf queue.
class TestNode(IsolatedAsyncioTestCase):
async def test_node(self):
config = init_mixnet_config(10)
nodes = [
Node(node_config, config.membership) for node_config in config.node_configs
]
for i, node in enumerate(nodes):
node.connect(nodes[(i + 1) % len(nodes)])
If inputs are arrived at Poisson rate `lambda`,
and if processing is delayed according to an exponential distribution with a rate `mu`,
the rate of outputs should be `lambda`.
"""
config = init_mixnet_config()
config.mixclient_config.emission_rate_per_min = 120 # lambda (= 2msg/sec)
config.mixnode_config.delay_rate_per_min = 30 # mu (= 2s delay on average)
await nodes[0].send_message(b"block selection")
packet, route = PacketBuilder.build_real_packets(
b"msg", config.mixclient_config.topology
)[0]
timeout = 15
for _ in range(timeout):
broadcasted_msgs = []
for node in nodes:
if not node.broadcast_channel.empty():
broadcasted_msgs.append(node.broadcast_channel.get_nowait())
# Start only the first mix node for testing
config.mixnode_config.encryption_private_key = route[0].encryption_private_key
mixnode = await MixNode.new(config.mixnode_config)
try:
# Send packets to the first mix node in a Poisson distribution
packet_count = 100
# This queue is just for counting how many packets have been sent so far.
sent_packet_queue: PacketQueue = asyncio.Queue()
sender_task = asyncio.create_task(
self.send_packets(
mixnode.inbound_socket,
packet,
route[0].addr,
packet_count,
config.mixclient_config.emission_rate_per_min,
sent_packet_queue,
)
)
try:
# Calculate intervals between outputs and gather num_jobs in the first mix node.
intervals = []
num_jobs = []
ts = datetime.now()
for _ in range(packet_count):
_ = await mixnode.outbound_socket.get()
now = datetime.now()
intervals.append((now - ts).total_seconds())
if len(broadcasted_msgs) == 0:
await asyncio.sleep(1)
else:
# We expect only one node to broadcast the message.
assert len(broadcasted_msgs) == 1
self.assertEqual(b"block selection", broadcasted_msgs[0])
return
self.fail("timeout")
# Calculate the current # of jobs staying in the mix node
num_packets_emitted_from_mixnode = len(intervals)
num_packets_sent_to_mixnode = sent_packet_queue.qsize()
num_jobs.append(
num_packets_sent_to_mixnode - num_packets_emitted_from_mixnode
)
ts = now
# Remove the first interval that would be much larger than other intervals,
# because of the delay in mix node.
intervals = intervals[1:]
num_jobs = num_jobs[1:]
# Check if the emission rate of the first mix node is the same as
# the emission rate of the message sender, but with a delay.
# If outputs follow the Poisson distribution with a rate `lambda`,
# a mean interval between outputs must be `1/lambda`.
self.assertAlmostEqual(
float(numpy.mean(intervals)),
poisson_mean_interval_sec(
config.mixclient_config.emission_rate_per_min
),
delta=1.0,
)
# If runner is a M/M/inf queue,
# a mean number of jobs being processed/scheduled in the runner must be `lambda/mu`.
self.assertAlmostEqual(
float(numpy.mean(num_jobs)),
round(
config.mixclient_config.emission_rate_per_min
/ config.mixnode_config.delay_rate_per_min
),
delta=1.5,
)
finally:
await sender_task
finally:
await mixnode.cancel()
@staticmethod
async def send_packets(
inbound_socket: PacketQueue,
packet: SphinxPacket,
node_addr: NodeAddress,
cnt: int,
rate_per_min: int,
# For testing purpose, to inform the caller how many packets have been sent to the inbound_socket
sent_packet_queue: PacketQueue,
):
for _ in range(cnt):
# Since the task is not heavy, just sleep for seconds instead of using emission_notifier
await asyncio.sleep(poisson_interval_sec(rate_per_min))
await inbound_socket.put((node_addr, packet))
await sent_packet_queue.put((node_addr, packet))
# TODO: check noise

View File

@ -1,9 +1,10 @@
from random import randint
from typing import List
from unittest import TestCase
from pysphinx.sphinx import ProcessedFinalHopPacket, SphinxPacket
from mixnet.config import MixNodeInfo
from mixnet.config import NodeInfo
from mixnet.packet import (
Fragment,
MessageFlag,
@ -11,14 +12,13 @@ from mixnet.packet import (
PacketBuilder,
)
from mixnet.test_utils import init_mixnet_config
from mixnet.utils import random_bytes
class TestPacket(TestCase):
def test_real_packet(self):
topology = init_mixnet_config().mixclient_config.topology
msg = random_bytes(3500)
packets_and_routes = PacketBuilder.build_real_packets(msg, topology)
membership = init_mixnet_config(10).membership
msg = self.random_bytes(3500)
packets_and_routes = PacketBuilder.build_real_packets(msg, membership)
self.assertEqual(4, len(packets_and_routes))
reconstructor = MessageReconstructor()
@ -47,9 +47,9 @@ class TestPacket(TestCase):
)
def test_cover_packet(self):
topology = init_mixnet_config().mixclient_config.topology
membership = init_mixnet_config(10).membership
msg = b"cover"
packets_and_routes = PacketBuilder.build_drop_cover_packets(msg, topology)
packets_and_routes = PacketBuilder.build_drop_cover_packets(msg, membership)
self.assertEqual(1, len(packets_and_routes))
reconstructor = MessageReconstructor()
@ -63,16 +63,21 @@ class TestPacket(TestCase):
)
@staticmethod
def process_packet(packet: SphinxPacket, route: List[MixNodeInfo]) -> Fragment:
processed = packet.process(route[0].encryption_private_key)
def process_packet(packet: SphinxPacket, route: List[NodeInfo]) -> Fragment:
processed = packet.process(route[0].private_key)
if isinstance(processed, ProcessedFinalHopPacket):
return Fragment.from_bytes(processed.payload.recover_plain_playload())
else:
processed = processed
for node in route[1:]:
p = processed.next_packet.process(node.encryption_private_key)
p = processed.next_packet.process(node.private_key)
if isinstance(p, ProcessedFinalHopPacket):
return Fragment.from_bytes(p.payload.recover_plain_playload())
else:
processed = p
assert False
@staticmethod
def random_bytes(size: int) -> bytes:
assert size >= 0
return bytes([randint(0, 255) for _ in range(size)])

View File

@ -1,46 +1,20 @@
import asyncio
from cryptography.hazmat.primitives.asymmetric.x25519 import X25519PrivateKey
from mixnet.bls import generate_bls
from mixnet.config import (
MixClientConfig,
MixNodeConfig,
MixMembership,
MixnetConfig,
MixNodeInfo,
MixnetTopology,
MixnetTopologyConfig,
MixnetTopologySize,
NodeConfig,
NodeInfo,
)
from mixnet.utils import random_bytes
def with_test_timeout(t):
def wrapper(coroutine):
async def run(*args, **kwargs):
async with asyncio.timeout(t):
return await coroutine(*args, **kwargs)
return run
return wrapper
def init_mixnet_config() -> MixnetConfig:
topology_config = MixnetTopologyConfig(
[
MixNodeInfo(
generate_bls(),
X25519PrivateKey.generate(),
random_bytes(32),
)
for _ in range(12)
],
MixnetTopologySize(3, 3),
b"entropy",
def init_mixnet_config(num_nodes: int) -> MixnetConfig:
transmission_rate_per_sec = 3
node_configs = [
NodeConfig(X25519PrivateKey.generate(), transmission_rate_per_sec)
for _ in range(num_nodes)
]
membership = MixMembership(
[NodeInfo(node_config.private_key) for node_config in node_configs]
)
mixclient_config = MixClientConfig(30, 3, MixnetTopology(topology_config))
mixnode_config = MixNodeConfig(
topology_config.mixnode_candidates[0].encryption_private_key, 30
)
return MixnetConfig(topology_config, mixclient_config, mixnode_config)
return MixnetConfig(node_configs, membership)

View File

@ -1,6 +0,0 @@
from random import randint
def random_bytes(size: int) -> bytes:
    """Return *size* pseudo-random bytes drawn via ``random.randint``.

    Args:
        size: number of bytes to generate; must be non-negative.

    Raises:
        ValueError: if ``size`` is negative. (Previously enforced with
            ``assert``, which is silently stripped under ``python -O``.)
    """
    if size < 0:
        raise ValueError("size must be non-negative")
    return bytes(randint(0, 255) for _ in range(size))

View File

@ -1,4 +1,4 @@
blspy==2.0.2
blspy==2.0.3
cffi==1.16.0
cryptography==41.0.7
numpy==1.26.3