Mixnet: Simulation

This commit is contained in:
Youngjoon Lee 2024-07-15 18:07:26 +09:00
parent 693880ad35
commit f7f931f73e
No known key found for this signature in database
GPG Key ID: B4253AFBA618BF4D
23 changed files with 985 additions and 32 deletions

View File

@ -19,10 +19,12 @@ jobs:
uses: actions/setup-python@v5
with:
# Semantic version range syntax or exact version of a Python version
python-version: '3.x'
python-version: "3.x"
- name: Install dependencies
run: pip install -r requirements.txt
- name: Build and install eth-specs
run: ./install-eth-specs.sh
- name: Run tests
run: python -m unittest
- name: Run a short mixnet simulation
run: python -m mixnet.sim.main --config mixnet/sim/config.ci.yaml

2
.gitignore vendored
View File

@ -1,3 +1,5 @@
.venv
__pycache__
*.csv

View File

@ -1,7 +1,7 @@
from __future__ import annotations
import random
from dataclasses import dataclass
from dataclasses import dataclass, field
from typing import List
from cryptography.hazmat.primitives.asymmetric.x25519 import (
@ -19,7 +19,6 @@ class GlobalConfig:
membership: MixMembership
transmission_rate_per_sec: int # Global Transmission Rate
# TODO: use these two to make the size of Sphinx packet constant
max_message_size: int
max_mix_path_length: int
@ -49,12 +48,13 @@ class MixMembership:
"""
nodes: List[NodeInfo]
rng: random.Random = field(default_factory=random.Random)
def generate_route(self, length: int) -> list[NodeInfo]:
"""
Choose `length` nodes with replacement as a mix route.
"""
return random.choices(self.nodes, k=length)
return self.rng.choices(self.nodes, k=length)
@dataclass

View File

@ -1,9 +1,40 @@
from __future__ import annotations
import asyncio
import abc
NetworkPacketQueue = asyncio.Queue[bytes]
SimplexConnection = NetworkPacketQueue
from mixnet.framework import Framework, Queue
NetworkPacketQueue = Queue
class SimplexConnection(abc.ABC):
"""
An abstract class for a simplex connection that can send and receive data in one direction
"""
@abc.abstractmethod
async def send(self, data: bytes) -> None:
pass
@abc.abstractmethod
async def recv(self) -> bytes:
pass
class LocalSimplexConnection(SimplexConnection):
"""
A simplex connection that doesn't have any network latency.
Data sent through this connection can be immediately received from the other end.
"""
def __init__(self, framework: Framework):
self.queue = framework.queue()
async def send(self, data: bytes) -> None:
await self.queue.put(data)
async def recv(self) -> bytes:
return await self.queue.get()
class DuplexConnection:
@ -17,7 +48,7 @@ class DuplexConnection:
self.outbound = outbound
async def recv(self) -> bytes:
return await self.inbound.get()
return await self.inbound.recv()
async def send(self, packet: bytes):
await self.outbound.send(packet)
@ -29,24 +60,29 @@ class MixSimplexConnection:
"""
def __init__(
self, conn: SimplexConnection, transmission_rate_per_sec: int, noise_msg: bytes
self,
framework: Framework,
conn: SimplexConnection,
transmission_rate_per_sec: int,
noise_msg: bytes,
):
self.queue = asyncio.Queue()
self.framework = framework
self.queue = framework.queue()
self.conn = conn
self.transmission_rate_per_sec = transmission_rate_per_sec
self.noise_msg = noise_msg
self.task = asyncio.create_task(self.__run())
self.task = framework.spawn(self.__run())
async def __run(self):
while True:
await asyncio.sleep(1 / self.transmission_rate_per_sec)
await self.framework.sleep(1 / self.transmission_rate_per_sec)
# TODO: temporal mixing
if self.queue.empty():
# To guarantee GTR, send noise if there is no message to send
msg = self.noise_msg
else:
msg = self.queue.get_nowait()
await self.conn.put(msg)
msg = await self.queue.get()
await self.conn.send(msg)
async def send(self, msg: bytes):
await self.queue.put(msg)

2
mixnet/error.py Normal file
View File

@ -0,0 +1,2 @@
class PeeringDegreeReached(Exception):
pass

View File

@ -0,0 +1 @@
from .framework import *

View File

@ -0,0 +1,49 @@
from __future__ import annotations
import asyncio
import time
from typing import Any, Awaitable, Coroutine
from mixnet import framework
class Framework(framework.Framework):
"""
An asyncio implementation of the Framework
"""
def __init__(self):
super().__init__()
def queue(self) -> framework.Queue:
return Queue()
async def sleep(self, seconds: float) -> None:
await asyncio.sleep(seconds)
def now(self) -> float:
return time.time()
def spawn(
self, coroutine: Coroutine[Any, Any, framework.RT]
) -> Awaitable[framework.RT]:
return asyncio.create_task(coroutine)
class Queue(framework.Queue):
"""
An asyncio implementation of the Queue
"""
def __init__(self):
super().__init__()
self._queue = asyncio.Queue()
async def put(self, data: bytes) -> None:
await self._queue.put(data)
async def get(self) -> bytes:
return await self._queue.get()
def empty(self) -> bool:
return self._queue.empty()

View File

@ -0,0 +1,47 @@
from __future__ import annotations
import abc
from typing import Any, Awaitable, Coroutine, TypeVar
RT = TypeVar("RT")
class Framework(abc.ABC):
"""
An abstract class that provides essential asynchronous functions.
This class can be implemented using any asynchronous framework (e.g., asyncio, usim)).
"""
@abc.abstractmethod
def queue(self) -> Queue:
pass
@abc.abstractmethod
async def sleep(self, seconds: float) -> None:
pass
@abc.abstractmethod
def now(self) -> float:
pass
@abc.abstractmethod
def spawn(self, coroutine: Coroutine[Any, Any, RT]) -> Awaitable[RT]:
pass
class Queue(abc.ABC):
"""
An abstract class that provides asynchronous queue operations.
"""
@abc.abstractmethod
async def put(self, data: bytes) -> None:
pass
@abc.abstractmethod
async def get(self) -> bytes:
pass
@abc.abstractmethod
def empty(self) -> bool:
pass

55
mixnet/framework/usim.py Normal file
View File

@ -0,0 +1,55 @@
from typing import Any, Awaitable, Coroutine
import usim
from mixnet import framework
class Framework(framework.Framework):
"""
A usim implementation of the Framework for discrete-time simulation
"""
def __init__(self, scope: usim.Scope) -> None:
super().__init__()
# Scope is used to spawn concurrent simulation activities (coroutines).
# μSim waits until all activities spawned in the scope are done
# or until the timeout specified in the scope is reached.
# Because of the way μSim works, the scope must be created using `async with` syntax
# and be passed to this constructor.
self._scope = scope
def queue(self) -> framework.Queue:
return Queue()
async def sleep(self, seconds: float) -> None:
await (usim.time + seconds)
def now(self) -> float:
# Round to milliseconds to make analysis not too heavy
return int(usim.time.now * 1000) / 1000
def spawn(
self, coroutine: Coroutine[Any, Any, framework.RT]
) -> Awaitable[framework.RT]:
return self._scope.do(coroutine)
class Queue(framework.Queue):
"""
A usim implementation of the Queue for discrete-time simulation
"""
def __init__(self):
super().__init__()
self._queue = usim.Queue()
async def put(self, data: bytes) -> None:
await self._queue.put(data)
async def get(self) -> bytes:
return await self._queue
def empty(self) -> bool:
return len(self._queue._buffer) == 0

View File

@ -1,6 +1,5 @@
from __future__ import annotations
import asyncio
from typing import TypeAlias
from pysphinx.sphinx import (
@ -10,10 +9,13 @@ from pysphinx.sphinx import (
)
from mixnet.config import GlobalConfig, NodeConfig
from mixnet.connection import SimplexConnection
from mixnet.error import PeeringDegreeReached
from mixnet.framework import Framework, Queue
from mixnet.nomssip import Nomssip
from mixnet.sphinx import SphinxPacketBuilder
BroadcastChannel: TypeAlias = asyncio.Queue[bytes]
BroadcastChannel = Queue
class Node:
@ -24,10 +26,14 @@ class Node:
- generates noise
"""
def __init__(self, config: NodeConfig, global_config: GlobalConfig):
def __init__(
self, framework: Framework, config: NodeConfig, global_config: GlobalConfig
):
self.framework = framework
self.config = config
self.global_config = global_config
self.nomssip = Nomssip(
framework,
Nomssip.Config(
global_config.transmission_rate_per_sec,
config.nomssip.peering_degree,
@ -35,7 +41,7 @@ class Node:
),
self.__process_msg,
)
self.broadcast_channel = asyncio.Queue()
self.broadcast_channel = framework.queue()
@staticmethod
def __calculate_message_size(global_config: GlobalConfig) -> int:
@ -84,11 +90,18 @@ class Node:
# Return nothing, if it cannot be unwrapped by the private key of this node.
return None
def connect(self, peer: Node):
def connect(
self,
peer: Node,
inbound_conn: SimplexConnection,
outbound_conn: SimplexConnection,
):
"""
Establish a duplex connection with a peer node.
"""
inbound_conn, outbound_conn = asyncio.Queue(), asyncio.Queue()
if not self.nomssip.can_accept_conn() or not peer.nomssip.can_accept_conn():
raise PeeringDegreeReached()
# Register a duplex connection for its own use
self.nomssip.add_conn(inbound_conn, outbound_conn)
# Register a duplex connection for the peer

View File

@ -1,12 +1,13 @@
from __future__ import annotations
import asyncio
import hashlib
from dataclasses import dataclass
from enum import Enum
from typing import Awaitable, Callable, Self
from mixnet.connection import DuplexConnection, MixSimplexConnection, SimplexConnection
from mixnet.error import PeeringDegreeReached
from mixnet.framework import Framework
class Nomssip:
@ -23,9 +24,11 @@ class Nomssip:
def __init__(
self,
framework: Framework,
config: Config,
handler: Callable[[bytes], Awaitable[None]],
):
self.framework = framework
self.config = config
self.conns: list[DuplexConnection] = []
# A handler to process inbound messages.
@ -33,12 +36,15 @@ class Nomssip:
self.packet_cache: set[bytes] = set()
# A set just for gathering a reference of tasks to prevent them from being garbage collected.
# https://docs.python.org/3/library/asyncio-task.html#asyncio.create_task
self.tasks: set[asyncio.Task] = set()
self.tasks: set[Awaitable] = set()
def can_accept_conn(self) -> bool:
return len(self.conns) < self.config.peering_degree
def add_conn(self, inbound: SimplexConnection, outbound: SimplexConnection):
if len(self.conns) >= self.config.peering_degree:
if not self.can_accept_conn():
# For simplicity of the spec, reject the connection if the peering degree is reached.
raise ValueError("The peering degree is reached.")
raise PeeringDegreeReached()
noise_packet = FlaggedPacket(
FlaggedPacket.Flag.NOISE, bytes(self.config.msg_size)
@ -46,6 +52,7 @@ class Nomssip:
conn = DuplexConnection(
inbound,
MixSimplexConnection(
self.framework,
outbound,
self.config.transmission_rate_per_sec,
noise_packet,
@ -53,10 +60,8 @@ class Nomssip:
)
self.conns.append(conn)
task = asyncio.create_task(self.__process_inbound_conn(conn))
task = self.framework.spawn(self.__process_inbound_conn(conn))
self.tasks.add(task)
# To discard the task from the set automatically when it is done.
task.add_done_callback(self.tasks.discard)
async def __process_inbound_conn(self, conn: DuplexConnection):
while True:

81
mixnet/sim/README.md Normal file
View File

@ -0,0 +1,81 @@
# NomMix Simulation
## Installation
Clone the repository and install the dependencies:
```bash
git clone https://github.com/logos-co/nomos-specs.git
cd nomos-specs
python3 -m venv .venv
source .venv/bin/activate
pip install -r requirements.txt
```
## Getting started
Copy the [`mixnet/sim/config.ci.yaml`](./config.ci.yaml) file and adjust the parameters to your needs.
Each parameter is explained in the config file.
For more details, please refer to the [documentation](https://www.notion.so/NomMix-Sim-Getting-Started-ee0e2191f4e7437e93976aff2627d7ce?pvs=4).
Run the simulation with the following command:
```bash
python -m mixnet.sim.main --config {config_path}
```
All results are printed in the console as below.
And, all plots are shown once all analysis is done.
```
==========================================
Message Size Distribution
==========================================
msg_size count
0 1405 99990
==========================================
Node States of All Nodes over Time
==========================================
Node-0 Node-1 Node-2 Node-3 Node-4
0 0 0 0 0 0
1 0 0 0 0 0
2 0 0 0 0 0
3 0 0 0 0 0
4 0 0 0 0 0
... ... ... ... ... ...
999995 0 0 0 0 0
999996 0 0 0 0 0
999997 0 0 0 0 0
999998 0 0 0 0 0
999999 0 0 0 0 0
[1000000 rows x 5 columns]
Saved DataFrame to all_node_states_2024-07-15T18:20:23.csv
State Counts per Node:
Node-0 Node-1 Node-2 Node-3 Node-4
0 970003 970003 970003 970003 970003
1 19998 19998 19998 19998 19998
-1 9999 9999 9999 9999 9999
Simulation complete!
```
Please note that the result of node state analysis is saved as a CSV file, as printed in the console.
```
Saved DataFrame to all_node_states_2024-07-15T18:20:23.csv
```
If you run the simulation again with the different parameters and want to
compare the results of two simulations,
you can calculate the hamming distance between them:
```bash
python -m mixnet.sim.hamming \
all_node_states_2024-07-15T18:20:23.csv \
all_node_states_2024-07-15T19:32:45.csv
```
The output is a floating point number between 0 and 1.
If the output is 0, the results of two simulations are identical.
The closer the result is to 1, the more the two results differ from each other.
```
Hamming distance: 0.29997
```

0
mixnet/sim/__init__.py Normal file
View File

39
mixnet/sim/config.ci.yaml Normal file
View File

@ -0,0 +1,39 @@
simulation:
# Desired duration of the simulation in seconds
# Since the simulation uses discrete time steps, the actual duration may be longer or shorter.
duration_sec: 1000
# Show all plots that have been drawn during the simulation
show_plots: false
network:
# Total number of nodes in the entire network.
num_nodes: 5
latency:
# Maximum network latency between nodes in seconds.
# A constant latency will be chosen randomly for each connection within the range [0, max_latency_sec].
max_latency_sec: 0.1
# Seed for the random number generator used to determine the network latencies.
seed: 0
nomssip:
# Target number of peers each node can connect to (both inbound and outbound).
peering_degree: 6
mix:
# Global constant transmission rate of each connection in messages per second.
transmission_rate_per_sec: 10
# Maximum size of a message in bytes that can be encapsulated in a single Sphinx packet.
max_message_size: 1007
mix_path:
# Maximum number of mix nodes to be chosen for a Sphinx packet.
max_length: 5
# Seed for the random number generator used to determine the mix path.
seed: 3
logic:
sender_lottery:
# Interval between lottery draws in seconds.
interval_sec: 1
# Probability of a node being selected as a sender in each lottery draw.
probability: 0.001
# Seed for the random number generator used to determine the lottery winners.
seed: 10

138
mixnet/sim/config.py Normal file
View File

@ -0,0 +1,138 @@
from __future__ import annotations
import hashlib
import random
from dataclasses import dataclass
import dacite
import yaml
from pysphinx.sphinx import X25519PrivateKey
from mixnet.config import NodeConfig, NomssipConfig
@dataclass
class Config:
simulation: SimulationConfig
network: NetworkConfig
logic: LogicConfig
mix: MixConfig
@classmethod
def load(cls, yaml_path: str) -> Config:
with open(yaml_path, "r") as f:
data = yaml.safe_load(f)
return dacite.from_dict(
data_class=Config,
data=data,
config=dacite.Config(
type_hooks={random.Random: seed_to_random}, strict=True
),
)
def node_configs(self) -> list[NodeConfig]:
return [
NodeConfig(
self._gen_private_key(i),
self.mix.mix_path.random_length(),
self.network.nomssip,
)
for i in range(self.network.num_nodes)
]
def _gen_private_key(self, node_idx: int) -> X25519PrivateKey:
return X25519PrivateKey.from_private_bytes(
hashlib.sha256(node_idx.to_bytes(4, "big")).digest()[:32]
)
@dataclass
class SimulationConfig:
# Desired duration of the simulation in seconds
# Since the simulation uses discrete time steps, the actual duration may be longer or shorter.
duration_sec: int
# Show all plots that have been drawn during the simulation
show_plots: bool
def __post_init__(self):
assert self.duration_sec > 0
@dataclass
class NetworkConfig:
# Total number of nodes in the entire network.
num_nodes: int
latency: LatencyConfig
nomssip: NomssipConfig
def __post_init__(self):
assert self.num_nodes > 0
@dataclass
class LatencyConfig:
# Maximum network latency between nodes in seconds.
# A constant latency will be chosen randomly for each connection within the range [0, max_latency_sec].
max_latency_sec: float
# Seed for the random number generator used to determine the network latencies.
seed: random.Random
def __post_init__(self):
assert self.max_latency_sec > 0
assert self.seed is not None
def random_latency(self) -> float:
# round to milliseconds to make analysis not too heavy
return int(self.seed.random() * self.max_latency_sec * 1000) / 1000
@dataclass
class MixConfig:
# Global constant transmission rate of each connection in messages per second.
transmission_rate_per_sec: int
# Maximum size of a message in bytes that can be encapsulated in a single Sphinx packet.
max_message_size: int
mix_path: MixPathConfig
def __post_init__(self):
assert self.transmission_rate_per_sec > 0
assert self.max_message_size > 0
@dataclass
class MixPathConfig:
# Maximum number of mix nodes to be chosen for a Sphinx packet.
max_length: int
# Seed for the random number generator used to determine the mix path.
seed: random.Random
def __post_init__(self):
assert self.max_length > 0
assert self.seed is not None
def random_length(self) -> int:
return self.seed.randint(1, self.max_length)
@dataclass
class LogicConfig:
sender_lottery: LotteryConfig
@dataclass
class LotteryConfig:
# Interval between lottery draws in seconds.
interval_sec: float
# Probability of a node being selected as a sender in each lottery draw.
probability: float
# Seed for the random number generator used to determine the lottery winners.
seed: random.Random
def __post_init__(self):
assert self.interval_sec > 0
assert self.probability >= 0
assert self.seed is not None
def seed_to_random(seed: int) -> random.Random:
return random.Random(seed)

100
mixnet/sim/connection.py Normal file
View File

@ -0,0 +1,100 @@
import math
from collections import Counter
from typing import Awaitable
import pandas
from mixnet.connection import SimplexConnection
from mixnet.framework import Framework, Queue
from mixnet.sim.config import NetworkConfig
from mixnet.sim.state import NodeState
class MeteredRemoteSimplexConnection(SimplexConnection):
"""
A simplex connection implementation that simulates network latency and measures bandwidth usages.
"""
def __init__(
self,
config: NetworkConfig,
framework: Framework,
send_node_states: list[NodeState],
recv_node_states: list[NodeState],
):
self.framework = framework
# A connection has a random constant latency
self.latency = config.latency.random_latency()
# A queue where a sender puts messages to be sent
self.send_queue = framework.queue()
# A queue that connects send_queue and recv_queue (to measure bandwidths and simulate latency)
self.mid_queue = framework.queue()
# A queue where a receiver gets messages
self.recv_queue = framework.queue()
# A task that reads messages from send_queue, updates bandwidth stats, and puts them to mid_queue
self.send_meters: list[int] = []
self.send_task = framework.spawn(self.__run_send_task())
# A task that reads messages from mid_queue, simulates network latency, updates bandwidth stats, and puts them to recv_queue
self.recv_meters: list[int] = []
self.recv_task = framework.spawn(self.__run_recv_task())
# To measure node states over time
self.send_node_states = send_node_states
self.recv_node_states = recv_node_states
# To measure the size of messages sent via this connection
self.msg_sizes: Counter[int] = Counter()
async def send(self, data: bytes) -> None:
await self.send_queue.put(data)
self.msg_sizes.update([len(data)])
# The time unit of node states is milliseconds
ms = math.floor(self.framework.now() * 1000)
self.send_node_states[ms] = NodeState.SENDING
async def recv(self) -> bytes:
data = await self.recv_queue.get()
# The time unit of node states is milliseconds
ms = math.floor(self.framework.now() * 1000)
self.send_node_states[ms] = NodeState.RECEIVING
return data
async def __run_send_task(self):
"""
A task that reads messages from send_queue, updates bandwidth stats, and puts them to mid_queue
"""
start_time = self.framework.now()
while True:
data = await self.send_queue.get()
self.__update_meter(self.send_meters, len(data), start_time)
await self.mid_queue.put(data)
async def __run_recv_task(self):
"""
A task that reads messages from mid_queue, simulates network latency, updates bandwidth stats, and puts them to recv_queue
"""
start_time = self.framework.now()
while True:
data = await self.mid_queue.get()
if data is None:
break
await self.framework.sleep(self.latency)
self.__update_meter(self.recv_meters, len(data), start_time)
await self.recv_queue.put(data)
def __update_meter(self, meters: list[int], size: int, start_time: float):
"""
Accumulates the bandwidth usage in the current time slot (seconds).
"""
slot = math.floor(self.framework.now() - start_time)
assert slot >= len(meters) - 1
# Fill zeros for the empty time slots
meters.extend([0] * (slot - len(meters) + 1))
meters[-1] += size
def sending_bandwidths(self) -> pandas.Series:
return self.__bandwidths(self.send_meters)
def receiving_bandwidths(self) -> pandas.Series:
return self.__bandwidths(self.recv_meters)
def __bandwidths(self, meters: list[int]) -> pandas.Series:
return pandas.Series(meters, name="bandwidth")

42
mixnet/sim/hamming.py Normal file
View File

@ -0,0 +1,42 @@
import sys
import pandas as pd
def calculate_hamming_distance(df1, df2):
"""
Caculate the hamming distance between two DataFrames
to quantify the difference between them.
"""
if df1.shape != df2.shape:
raise ValueError(
"DataFrames must have the same shape to calculate Hamming distance."
)
# Compare element-wise and count differences
differences = (df1 != df2).sum().sum()
return differences / df1.size # normalize the distance
def main():
if len(sys.argv) != 3:
print("Usage: python hamming.py <csv_path1> <csv_path2>")
sys.exit(1)
csv_path1 = sys.argv[1]
csv_path2 = sys.argv[2]
# Load the CSV files into DataFrames
df1 = pd.read_csv(csv_path1)
df2 = pd.read_csv(csv_path2)
# Calculate the Hamming distance
try:
hamming_distance = calculate_hamming_distance(df1, df2)
print(f"Hamming distance: {hamming_distance}")
except ValueError as e:
print(f"Error: {e}")
if __name__ == "__main__":
main()

25
mixnet/sim/main.py Normal file
View File

@ -0,0 +1,25 @@
import argparse
import usim
from mixnet.sim.config import Config
from mixnet.sim.simulation import Simulation
if __name__ == "__main__":
"""
Read a config file and run a simulation
"""
parser = argparse.ArgumentParser(
description="Run mixnet simulation",
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
parser.add_argument(
"--config", type=str, required=True, help="Configuration file path"
)
args = parser.parse_args()
config = Config.load(args.config)
sim = Simulation(config)
usim.run(sim.run())
print("Simulation complete!")

115
mixnet/sim/simulation.py Normal file
View File

@ -0,0 +1,115 @@
import usim
from matplotlib import pyplot
import mixnet.framework.usim as usimfw
from mixnet.config import GlobalConfig, MixMembership, NodeInfo
from mixnet.framework import Framework
from mixnet.node import Node, PeeringDegreeReached
from mixnet.sim.config import Config
from mixnet.sim.connection import MeteredRemoteSimplexConnection
from mixnet.sim.state import NodeState, NodeStateTable
from mixnet.sim.stats import ConnectionStats
class Simulation:
"""
Manages the entire cycle of simulation: initialization, running, and analysis.
"""
def __init__(self, config: Config):
self.config = config
async def run(self):
# Run the simulation
conn_stats, node_state_table = await self.__run()
# Analyze the simulation results
conn_stats.analyze()
node_state_table.analyze()
# Show plots
if self.config.simulation.show_plots:
pyplot.show()
async def __run(self) -> tuple[ConnectionStats, NodeStateTable]:
# Initialize analysis tools
node_state_table = NodeStateTable(
self.config.network.num_nodes, self.config.simulation.duration_sec
)
conn_stats = ConnectionStats()
# Create a μSim scope and run the simulation
async with usim.until(usim.time + self.config.simulation.duration_sec) as scope:
self.framework = usimfw.Framework(scope)
nodes, conn_stats, node_state_table = self.__init_nodes(
node_state_table, conn_stats
)
for node in nodes:
self.framework.spawn(self.__run_node_logic(node))
# Return analysis tools once the μSim scope is done
return conn_stats, node_state_table
def __init_nodes(
self, node_state_table: NodeStateTable, conn_stats: ConnectionStats
) -> tuple[list[Node], ConnectionStats, NodeStateTable]:
# Initialize node/global configurations
node_configs = self.config.node_configs()
global_config = GlobalConfig(
MixMembership(
[
NodeInfo(node_config.private_key.public_key())
for node_config in node_configs
],
self.config.mix.mix_path.seed,
),
self.config.mix.transmission_rate_per_sec,
self.config.mix.max_message_size,
self.config.mix.mix_path.max_length,
)
# Initialize Node instances
nodes = [
Node(self.framework, node_config, global_config)
for node_config in node_configs
]
# Connect nodes to each other
for i, node in enumerate(nodes):
# For now, we only consider a simple ring topology for simplicity.
peer_idx = (i + 1) % len(nodes)
peer = nodes[peer_idx]
node_states = node_state_table[i]
peer_states = node_state_table[peer_idx]
# Create simplex inbound/outbound connections
# and use them to connect node and peer.
inbound_conn, outbound_conn = (
self.__create_conn(peer_states, node_states),
self.__create_conn(node_states, peer_states),
)
node.connect(peer, inbound_conn, outbound_conn)
# Register the connections to the connection statistics
conn_stats.register(node, inbound_conn, outbound_conn)
conn_stats.register(peer, outbound_conn, inbound_conn)
return nodes, conn_stats, node_state_table
def __create_conn(
self, sender_states: list[NodeState], receiver_states: list[NodeState]
) -> MeteredRemoteSimplexConnection:
return MeteredRemoteSimplexConnection(
self.config.network,
self.framework,
sender_states,
receiver_states,
)
async def __run_node_logic(self, node: Node):
"""
Runs the lottery periodically to check if the node is selected to send a block.
If the node is selected, creates a block and sends it through mix nodes.
"""
lottery_config = self.config.logic.sender_lottery
while True:
await (usim.time + lottery_config.interval_sec)
if lottery_config.seed.random() < lottery_config.probability:
await node.send_message(b"selected block")

76
mixnet/sim/state.py Normal file
View File

@ -0,0 +1,76 @@
from datetime import datetime
from enum import Enum
import matplotlib.pyplot as plt
import pandas
class NodeState(Enum):
"""
A state of node at a certain time.
For now, we assume that the node cannot send and receive messages at the same time for simplicity.
"""
SENDING = -1
IDLE = 0
RECEIVING = 1
class NodeStateTable:
def __init__(self, num_nodes: int, duration_sec: int):
# Create a table to store the state of each node at each millisecond
self.__table = [
[NodeState.IDLE] * (duration_sec * 1000) for _ in range(num_nodes)
]
def __getitem__(self, idx: int) -> list[NodeState]:
return self.__table[idx]
def analyze(self):
df = pandas.DataFrame(self.__table).transpose()
df.columns = [f"Node-{i}" for i in range(len(self.__table))]
# Convert NodeState enum to their integer values
df = df.map(lambda state: state.value)
print("==========================================")
print(" Node States of All Nodes over Time")
print("==========================================")
print(f"{df}\n")
csv_path = f"all_node_states_{datetime.now().isoformat(timespec="seconds")}.csv"
df.to_csv(csv_path)
print(f"Saved DataFrame to {csv_path}\n")
# Count/print the number of each state for each node
# because the df is usually too big to print
state_counts = df.apply(pandas.Series.value_counts).fillna(0)
print("State Counts per Node:")
print(f"{state_counts}\n")
# Draw a dot plot
plt.figure(figsize=(15, 8))
for node in df.columns:
times = df.index
states = df[node]
sending_times = times[states == NodeState.SENDING.value]
receiving_times = times[states == NodeState.RECEIVING.value]
plt.scatter(
sending_times,
[node] * len(sending_times),
color="red",
marker="o",
s=10,
label="SENDING" if node == df.columns[0] else "",
)
plt.scatter(
receiving_times,
[node] * len(receiving_times),
color="blue",
marker="x",
s=10,
label="RECEIVING" if node == df.columns[0] else "",
)
plt.xlabel("Time")
plt.ylabel("Node")
plt.title("Node States Over Time")
plt.legend(loc="upper right")
plt.draw()

112
mixnet/sim/stats.py Normal file
View File

@ -0,0 +1,112 @@
from collections import Counter, defaultdict
import matplotlib.pyplot as plt
import pandas
from mixnet.node import Node
from mixnet.sim.connection import MeteredRemoteSimplexConnection
# A map of nodes to their inbound/outbound connections
NodeConnectionsMap = dict[
Node,
tuple[list[MeteredRemoteSimplexConnection], list[MeteredRemoteSimplexConnection]],
]
class ConnectionStats:
def __init__(self):
self.conns_per_node: NodeConnectionsMap = defaultdict(lambda: ([], []))
def register(
self,
node: Node,
inbound_conn: MeteredRemoteSimplexConnection,
outbound_conn: MeteredRemoteSimplexConnection,
):
self.conns_per_node[node][0].append(inbound_conn)
self.conns_per_node[node][1].append(outbound_conn)
def analyze(self):
self.__message_sizes()
self.__bandwidths_per_conn()
self.__bandwidths_per_node()
def __message_sizes(self):
"""
Analyzes all message sizes sent across all connections of all nodes.
"""
sizes: Counter[int] = Counter()
for _, (_, outbound_conns) in self.conns_per_node.items():
for conn in outbound_conns:
sizes.update(conn.msg_sizes)
df = pandas.DataFrame.from_dict(sizes, orient="index").reset_index()
df.columns = ["msg_size", "count"]
print("==========================================")
print(" Message Size Distribution")
print("==========================================")
print(f"{df}\n")
def __bandwidths_per_conn(self):
"""
Analyzes the bandwidth consumed by each simplex connection.
"""
plt.plot(figsize=(12, 6))
for _, (_, outbound_conns) in self.conns_per_node.items():
for conn in outbound_conns:
sending_bandwidths = conn.sending_bandwidths().map(lambda x: x / 1024)
plt.plot(sending_bandwidths.index, sending_bandwidths)
plt.title("Unidirectional Bandwidths per Connection")
plt.xlabel("Time (s)")
plt.ylabel("Bandwidth (KiB/s)")
plt.ylim(bottom=0)
plt.grid(True)
plt.tight_layout()
plt.draw()
def __bandwidths_per_node(self):
"""
Analyzes the inbound/outbound bandwidths consumed by each node (sum of all its connections).
"""
_, axs = plt.subplots(nrows=2, ncols=1, figsize=(12, 6))
for i, (_, (inbound_conns, outbound_conns)) in enumerate(
self.conns_per_node.items()
):
inbound_bandwidths = (
pandas.concat(
[conn.receiving_bandwidths() for conn in inbound_conns], axis=1
)
.sum(axis=1)
.map(lambda x: x / 1024)
)
outbound_bandwidths = (
pandas.concat(
[conn.sending_bandwidths() for conn in outbound_conns], axis=1
)
.sum(axis=1)
.map(lambda x: x / 1024)
)
axs[0].plot(inbound_bandwidths.index, inbound_bandwidths, label=f"Node-{i}")
axs[1].plot(
outbound_bandwidths.index, outbound_bandwidths, label=f"Node-{i}"
)
axs[0].set_title("Inbound Bandwidths per Node")
axs[0].set_xlabel("Time (s)")
axs[0].set_ylabel("Bandwidth (KiB/s)")
axs[0].legend()
axs[0].set_ylim(bottom=0)
axs[0].grid(True)
axs[1].set_title("Outbound Bandwidths per Node")
axs[1].set_xlabel("Time (s)")
axs[1].set_ylabel("Bandwidth (KiB/s)")
axs[1].legend()
axs[1].set_ylim(bottom=0)
axs[1].grid(True)
plt.tight_layout()
plt.draw()

View File

@ -1,6 +1,7 @@
import asyncio
from unittest import IsolatedAsyncioTestCase
import mixnet.framework.asyncio as asynciofw
from mixnet.connection import LocalSimplexConnection
from mixnet.node import Node
from mixnet.test_utils import (
init_mixnet_config,
@ -9,11 +10,18 @@ from mixnet.test_utils import (
class TestNode(IsolatedAsyncioTestCase):
async def test_node(self):
framework = asynciofw.Framework()
global_config, node_configs, _ = init_mixnet_config(10)
nodes = [Node(node_config, global_config) for node_config in node_configs]
nodes = [
Node(framework, node_config, global_config) for node_config in node_configs
]
for i, node in enumerate(nodes):
try:
node.connect(nodes[(i + 1) % len(nodes)])
node.connect(
nodes[(i + 1) % len(nodes)],
LocalSimplexConnection(framework),
LocalSimplexConnection(framework),
)
except ValueError as e:
print(e)
@ -24,10 +32,10 @@ class TestNode(IsolatedAsyncioTestCase):
broadcasted_msgs = []
for node in nodes:
if not node.broadcast_channel.empty():
broadcasted_msgs.append(node.broadcast_channel.get_nowait())
broadcasted_msgs.append(await node.broadcast_channel.get())
if len(broadcasted_msgs) == 0:
await asyncio.sleep(1)
await framework.sleep(1)
else:
# We expect only one node to broadcast the message.
assert len(broadcasted_msgs) == 1

View File

@ -13,3 +13,8 @@ portalocker==2.8.2 # portable file locking
keum==0.2.0 # for CL's use of more obscure curves
poseidon-hash==0.1.4 # used as the algebraic hash in CL
hypothesis==6.103.0
dacite==1.8.1
pandas==2.2.2
matplotlib==3.9.1
PyYAML==6.0.1
usim==0.4.4