From f7f931f73eb1835b7b9e6a56fc2f450f146cc262 Mon Sep 17 00:00:00 2001 From: Youngjoon Lee <5462944+youngjoon-lee@users.noreply.github.com> Date: Mon, 15 Jul 2024 18:07:26 +0900 Subject: [PATCH] Mixnet: Simulation --- .github/workflows/ci.yml | 4 +- .gitignore | 2 + mixnet/config.py | 6 +- mixnet/connection.py | 56 +++++++++++--- mixnet/error.py | 2 + mixnet/framework/__init__.py | 1 + mixnet/framework/asyncio.py | 49 ++++++++++++ mixnet/framework/framework.py | 47 ++++++++++++ mixnet/framework/usim.py | 55 ++++++++++++++ mixnet/node.py | 25 ++++-- mixnet/nomssip.py | 19 +++-- mixnet/sim/README.md | 81 ++++++++++++++++++++ mixnet/sim/__init__.py | 0 mixnet/sim/config.ci.yaml | 39 ++++++++++ mixnet/sim/config.py | 138 ++++++++++++++++++++++++++++++++++ mixnet/sim/connection.py | 100 ++++++++++++++++++++++++ mixnet/sim/hamming.py | 42 +++++++++++ mixnet/sim/main.py | 25 ++++++ mixnet/sim/simulation.py | 115 ++++++++++++++++++++++++++++ mixnet/sim/state.py | 76 +++++++++++++++++++ mixnet/sim/stats.py | 112 +++++++++++++++++++++++++++ mixnet/test_node.py | 18 +++-- requirements.txt | 5 ++ 23 files changed, 985 insertions(+), 32 deletions(-) create mode 100644 mixnet/error.py create mode 100644 mixnet/framework/__init__.py create mode 100644 mixnet/framework/asyncio.py create mode 100644 mixnet/framework/framework.py create mode 100644 mixnet/framework/usim.py create mode 100644 mixnet/sim/README.md create mode 100644 mixnet/sim/__init__.py create mode 100644 mixnet/sim/config.ci.yaml create mode 100644 mixnet/sim/config.py create mode 100644 mixnet/sim/connection.py create mode 100644 mixnet/sim/hamming.py create mode 100644 mixnet/sim/main.py create mode 100644 mixnet/sim/simulation.py create mode 100644 mixnet/sim/state.py create mode 100644 mixnet/sim/stats.py diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 9dcacf3..dae8701 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -19,10 +19,12 @@ jobs: uses: actions/setup-python@v5 with: # Semantic version range syntax or exact version of a Python version - python-version: '3.x' + python-version: "3.x" - name: Install dependencies run: pip install -r requirements.txt - name: Build and install eth-specs run: ./install-eth-specs.sh - name: Run tests run: python -m unittest + - name: Run a short mixnet simulation + run: python -m mixnet.sim.main --config mixnet/sim/config.ci.yaml diff --git a/.gitignore b/.gitignore index d4b47f4..e3608ca 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,5 @@ .venv __pycache__ + +*.csv diff --git a/mixnet/config.py b/mixnet/config.py index 749038a..45974f4 100644 --- a/mixnet/config.py +++ b/mixnet/config.py @@ -1,7 +1,7 @@ from __future__ import annotations import random -from dataclasses import dataclass +from dataclasses import dataclass, field from typing import List from cryptography.hazmat.primitives.asymmetric.x25519 import ( @@ -19,7 +19,6 @@ class GlobalConfig: membership: MixMembership transmission_rate_per_sec: int # Global Transmission Rate - # TODO: use these two to make the size of Sphinx packet constant max_message_size: int max_mix_path_length: int @@ -49,12 +48,13 @@ class MixMembership: """ nodes: List[NodeInfo] + rng: random.Random = field(default_factory=random.Random) def generate_route(self, length: int) -> list[NodeInfo]: """ Choose `length` nodes with replacement as a mix route. """ - return random.choices(self.nodes, k=length) + return self.rng.choices(self.nodes, k=length) @dataclass diff --git a/mixnet/connection.py b/mixnet/connection.py index 5fc319d..7de41cd 100644 --- a/mixnet/connection.py +++ b/mixnet/connection.py @@ -1,9 +1,40 @@ from __future__ import annotations -import asyncio +import abc -NetworkPacketQueue = asyncio.Queue[bytes] -SimplexConnection = NetworkPacketQueue +from mixnet.framework import Framework, Queue + +NetworkPacketQueue = Queue + + +class SimplexConnection(abc.ABC): + """ + An abstract class for a simplex connection that can send and receive data in one direction + """ + + @abc.abstractmethod + async def send(self, data: bytes) -> None: + pass + + @abc.abstractmethod + async def recv(self) -> bytes: + pass + + +class LocalSimplexConnection(SimplexConnection): + """ + A simplex connection that doesn't have any network latency. + Data sent through this connection can be immediately received from the other end. + """ + + def __init__(self, framework: Framework): + self.queue = framework.queue() + + async def send(self, data: bytes) -> None: + await self.queue.put(data) + + async def recv(self) -> bytes: + return await self.queue.get() class DuplexConnection: @@ -17,7 +48,7 @@ class DuplexConnection: self.outbound = outbound async def recv(self) -> bytes: - return await self.inbound.get() + return await self.inbound.recv() async def send(self, packet: bytes): await self.outbound.send(packet) @@ -29,24 +60,29 @@ class MixSimplexConnection: """ def __init__( - self, conn: SimplexConnection, transmission_rate_per_sec: int, noise_msg: bytes + self, + framework: Framework, + conn: SimplexConnection, + transmission_rate_per_sec: int, + noise_msg: bytes, ): - self.queue = asyncio.Queue() + self.framework = framework + self.queue = framework.queue() self.conn = conn self.transmission_rate_per_sec = transmission_rate_per_sec self.noise_msg = noise_msg - self.task = asyncio.create_task(self.__run()) + self.task = framework.spawn(self.__run()) async def __run(self): while True: - await asyncio.sleep(1 / self.transmission_rate_per_sec) + await self.framework.sleep(1 / self.transmission_rate_per_sec) # TODO: temporal mixing if self.queue.empty(): # To guarantee GTR, send noise if there is no message to send msg = self.noise_msg else: - msg = self.queue.get_nowait() - await self.conn.put(msg) + msg = await self.queue.get() + await self.conn.send(msg) async def send(self, msg: bytes): await self.queue.put(msg) diff --git a/mixnet/error.py b/mixnet/error.py new file mode 100644 index 0000000..96d0c41 --- /dev/null +++ b/mixnet/error.py @@ -0,0 +1,2 @@ +class PeeringDegreeReached(Exception): + pass diff --git a/mixnet/framework/__init__.py b/mixnet/framework/__init__.py new file mode 100644 index 0000000..13aa683 --- /dev/null +++ b/mixnet/framework/__init__.py @@ -0,0 +1 @@ +from .framework import * diff --git a/mixnet/framework/asyncio.py b/mixnet/framework/asyncio.py new file mode 100644 index 0000000..53c19b6 --- /dev/null +++ b/mixnet/framework/asyncio.py @@ -0,0 +1,49 @@ +from __future__ import annotations + +import asyncio +import time +from typing import Any, Awaitable, Coroutine + +from mixnet import framework + + +class Framework(framework.Framework): + """ + An asyncio implementation of the Framework + """ + + def __init__(self): + super().__init__() + + def queue(self) -> framework.Queue: + return Queue() + + async def sleep(self, seconds: float) -> None: + await asyncio.sleep(seconds) + + def now(self) -> float: + return time.time() + + def spawn( + self, coroutine: Coroutine[Any, Any, framework.RT] + ) -> Awaitable[framework.RT]: + return asyncio.create_task(coroutine) + + +class Queue(framework.Queue): + """ + An asyncio implementation of the Queue + """ + + def __init__(self): + super().__init__() + self._queue = asyncio.Queue() + + async def put(self, data: bytes) -> None: + await self._queue.put(data) + + async def get(self) -> bytes: + return await self._queue.get() + + def empty(self) -> bool: + return self._queue.empty() diff --git a/mixnet/framework/framework.py b/mixnet/framework/framework.py new file mode 100644 index 0000000..3844120 --- /dev/null +++ b/mixnet/framework/framework.py @@ -0,0 +1,47 @@ +from __future__ import annotations + +import abc +from typing import Any, Awaitable, Coroutine, TypeVar + +RT = TypeVar("RT") + + +class Framework(abc.ABC): + """ + An abstract class that provides essential asynchronous functions. + This class can be implemented using any asynchronous framework (e.g., asyncio, usim)). + """ + + @abc.abstractmethod + def queue(self) -> Queue: + pass + + @abc.abstractmethod + async def sleep(self, seconds: float) -> None: + pass + + @abc.abstractmethod + def now(self) -> float: + pass + + @abc.abstractmethod + def spawn(self, coroutine: Coroutine[Any, Any, RT]) -> Awaitable[RT]: + pass + + +class Queue(abc.ABC): + """ + An abstract class that provides asynchronous queue operations. + """ + + @abc.abstractmethod + async def put(self, data: bytes) -> None: + pass + + @abc.abstractmethod + async def get(self) -> bytes: + pass + + @abc.abstractmethod + def empty(self) -> bool: + pass diff --git a/mixnet/framework/usim.py b/mixnet/framework/usim.py new file mode 100644 index 0000000..5181802 --- /dev/null +++ b/mixnet/framework/usim.py @@ -0,0 +1,55 @@ +from typing import Any, Awaitable, Coroutine + +import usim + +from mixnet import framework + + +class Framework(framework.Framework): + """ + A usim implementation of the Framework for discrete-time simulation + """ + + def __init__(self, scope: usim.Scope) -> None: + super().__init__() + + # Scope is used to spawn concurrent simulation activities (coroutines). + # μSim waits until all activities spawned in the scope are done + # or until the timeout specified in the scope is reached. + # Because of the way μSim works, the scope must be created using `async with` syntax + # and be passed to this constructor. + self._scope = scope + + def queue(self) -> framework.Queue: + return Queue() + + async def sleep(self, seconds: float) -> None: + await (usim.time + seconds) + + def now(self) -> float: + # Round to milliseconds to make analysis not too heavy + return int(usim.time.now * 1000) / 1000 + + def spawn( + self, coroutine: Coroutine[Any, Any, framework.RT] + ) -> Awaitable[framework.RT]: + return self._scope.do(coroutine) + + +class Queue(framework.Queue): + """ + A usim implementation of the Queue for discrete-time simulation + """ + + def __init__(self): + super().__init__() + self._queue = usim.Queue() + + async def put(self, data: bytes) -> None: + await self._queue.put(data) + + async def get(self) -> bytes: + return await self._queue + + def empty(self) -> bool: + return len(self._queue._buffer) == 0 diff --git a/mixnet/node.py b/mixnet/node.py index d48e997..fa98ca3 100644 --- a/mixnet/node.py +++ b/mixnet/node.py @@ -1,6 +1,5 @@ from __future__ import annotations -import asyncio from typing import TypeAlias from pysphinx.sphinx import ( @@ -10,10 +9,13 @@ from pysphinx.sphinx import ( ) from mixnet.config import GlobalConfig, NodeConfig +from mixnet.connection import SimplexConnection +from mixnet.error import PeeringDegreeReached +from mixnet.framework import Framework, Queue from mixnet.nomssip import Nomssip from mixnet.sphinx import SphinxPacketBuilder -BroadcastChannel: TypeAlias = asyncio.Queue[bytes] +BroadcastChannel = Queue class Node: @@ -24,10 +26,14 @@ class Node: - generates noise """ - def __init__(self, config: NodeConfig, global_config: GlobalConfig): + def __init__( + self, framework: Framework, config: NodeConfig, global_config: GlobalConfig + ): + self.framework = framework self.config = config self.global_config = global_config self.nomssip = Nomssip( + framework, Nomssip.Config( global_config.transmission_rate_per_sec, config.nomssip.peering_degree, @@ -35,7 +41,7 @@ class Node: ), self.__process_msg, ) - self.broadcast_channel = asyncio.Queue() + self.broadcast_channel = framework.queue() @staticmethod def __calculate_message_size(global_config: GlobalConfig) -> int: @@ -84,11 +90,18 @@ class Node: # Return nothing, if it cannot be unwrapped by the private key of this node. return None - def connect(self, peer: Node): + def connect( + self, + peer: Node, + inbound_conn: SimplexConnection, + outbound_conn: SimplexConnection, + ): """ Establish a duplex connection with a peer node. """ - inbound_conn, outbound_conn = asyncio.Queue(), asyncio.Queue() + if not self.nomssip.can_accept_conn() or not peer.nomssip.can_accept_conn(): + raise PeeringDegreeReached() + # Register a duplex connection for its own use self.nomssip.add_conn(inbound_conn, outbound_conn) # Register a duplex connection for the peer diff --git a/mixnet/nomssip.py b/mixnet/nomssip.py index f0fc4e6..35f81e6 100644 --- a/mixnet/nomssip.py +++ b/mixnet/nomssip.py @@ -1,12 +1,13 @@ from __future__ import annotations -import asyncio import hashlib from dataclasses import dataclass from enum import Enum from typing import Awaitable, Callable, Self from mixnet.connection import DuplexConnection, MixSimplexConnection, SimplexConnection +from mixnet.error import PeeringDegreeReached +from mixnet.framework import Framework class Nomssip: @@ -23,9 +24,11 @@ class Nomssip: def __init__( self, + framework: Framework, config: Config, handler: Callable[[bytes], Awaitable[None]], ): + self.framework = framework self.config = config self.conns: list[DuplexConnection] = [] # A handler to process inbound messages. @@ -33,12 +36,15 @@ class Nomssip: self.packet_cache: set[bytes] = set() # A set just for gathering a reference of tasks to prevent them from being garbage collected. # https://docs.python.org/3/library/asyncio-task.html#asyncio.create_task - self.tasks: set[asyncio.Task] = set() + self.tasks: set[Awaitable] = set() + + def can_accept_conn(self) -> bool: + return len(self.conns) < self.config.peering_degree def add_conn(self, inbound: SimplexConnection, outbound: SimplexConnection): - if len(self.conns) >= self.config.peering_degree: + if not self.can_accept_conn(): # For simplicity of the spec, reject the connection if the peering degree is reached. - raise ValueError("The peering degree is reached.") + raise PeeringDegreeReached() noise_packet = FlaggedPacket( FlaggedPacket.Flag.NOISE, bytes(self.config.msg_size) @@ -46,6 +52,7 @@ class Nomssip: conn = DuplexConnection( inbound, MixSimplexConnection( + self.framework, outbound, self.config.transmission_rate_per_sec, noise_packet, @@ -53,10 +60,8 @@ class Nomssip: ) self.conns.append(conn) - task = asyncio.create_task(self.__process_inbound_conn(conn)) + task = self.framework.spawn(self.__process_inbound_conn(conn)) self.tasks.add(task) - # To discard the task from the set automatically when it is done. - task.add_done_callback(self.tasks.discard) async def __process_inbound_conn(self, conn: DuplexConnection): while True: diff --git a/mixnet/sim/README.md b/mixnet/sim/README.md new file mode 100644 index 0000000..5fa2a2a --- /dev/null +++ b/mixnet/sim/README.md @@ -0,0 +1,81 @@ +# NomMix Simulation + +## Installation + +Clone the repository and install the dependencies: +```bash +git clone https://github.com/logos-co/nomos-specs.git +cd nomos-specs +python3 -m venv .venv +source .venv/bin/activate +pip install -r requirements.txt +``` + +## Getting started + +Copy the [`mixnet/sim/config.ci.yaml`](./config.ci.yaml) file and adjust the parameters to your needs. +Each parameter is explained in the config file. +For more details, please refer to the [documentation](https://www.notion.so/NomMix-Sim-Getting-Started-ee0e2191f4e7437e93976aff2627d7ce?pvs=4). + +Run the simulation with the following command: +```bash +python -m mixnet.sim.main --config {config_path} +``` + +All results are printed in the console as below. +And, all plots are shown once all analysis is done. +``` +========================================== + Message Size Distribution +========================================== + msg_size count +0 1405 99990 + +========================================== + Node States of All Nodes over Time +========================================== + Node-0 Node-1 Node-2 Node-3 Node-4 +0 0 0 0 0 0 +1 0 0 0 0 0 +2 0 0 0 0 0 +3 0 0 0 0 0 +4 0 0 0 0 0 +... ... ... ... ... ... +999995 0 0 0 0 0 +999996 0 0 0 0 0 +999997 0 0 0 0 0 +999998 0 0 0 0 0 +999999 0 0 0 0 0 + +[1000000 rows x 5 columns] + +Saved DataFrame to all_node_states_2024-07-15T18:20:23.csv + +State Counts per Node: + Node-0 Node-1 Node-2 Node-3 Node-4 + 0 970003 970003 970003 970003 970003 + 1 19998 19998 19998 19998 19998 +-1 9999 9999 9999 9999 9999 + +Simulation complete! +``` + +Please note that the result of node state analysis is saved as a CSV file, as printed in the console. +``` +Saved DataFrame to all_node_states_2024-07-15T18:20:23.csv +``` + +If you run the simulation again with the different parameters and want to +compare the results of two simulations, +you can calculate the hamming distance between them: +```bash +python -m mixnet.sim.hamming \ + all_node_states_2024-07-15T18:20:23.csv \ + all_node_states_2024-07-15T19:32:45.csv +``` +The output is a floating point number between 0 and 1. +If the output is 0, the results of two simulations are identical. +The closer the result is to 1, the more the two results differ from each other. +``` +Hamming distance: 0.29997 +``` diff --git a/mixnet/sim/__init__.py b/mixnet/sim/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/mixnet/sim/config.ci.yaml b/mixnet/sim/config.ci.yaml new file mode 100644 index 0000000..8c4d60a --- /dev/null +++ b/mixnet/sim/config.ci.yaml @@ -0,0 +1,39 @@ +simulation: + # Desired duration of the simulation in seconds + # Since the simulation uses discrete time steps, the actual duration may be longer or shorter. + duration_sec: 1000 + # Show all plots that have been drawn during the simulation + show_plots: false + +network: + # Total number of nodes in the entire network. + num_nodes: 5 + latency: + # Maximum network latency between nodes in seconds. + # A constant latency will be chosen randomly for each connection within the range [0, max_latency_sec]. + max_latency_sec: 0.1 + # Seed for the random number generator used to determine the network latencies. + seed: 0 + nomssip: + # Target number of peers each node can connect to (both inbound and outbound). + peering_degree: 6 + +mix: + # Global constant transmission rate of each connection in messages per second. + transmission_rate_per_sec: 10 + # Maximum size of a message in bytes that can be encapsulated in a single Sphinx packet. + max_message_size: 1007 + mix_path: + # Maximum number of mix nodes to be chosen for a Sphinx packet. + max_length: 5 + # Seed for the random number generator used to determine the mix path. + seed: 3 + +logic: + sender_lottery: + # Interval between lottery draws in seconds. + interval_sec: 1 + # Probability of a node being selected as a sender in each lottery draw. + probability: 0.001 + # Seed for the random number generator used to determine the lottery winners. + seed: 10 diff --git a/mixnet/sim/config.py b/mixnet/sim/config.py new file mode 100644 index 0000000..5d2e7b1 --- /dev/null +++ b/mixnet/sim/config.py @@ -0,0 +1,138 @@ +from __future__ import annotations + +import hashlib +import random +from dataclasses import dataclass + +import dacite +import yaml +from pysphinx.sphinx import X25519PrivateKey + +from mixnet.config import NodeConfig, NomssipConfig + + +@dataclass +class Config: + simulation: SimulationConfig + network: NetworkConfig + logic: LogicConfig + mix: MixConfig + + @classmethod + def load(cls, yaml_path: str) -> Config: + with open(yaml_path, "r") as f: + data = yaml.safe_load(f) + return dacite.from_dict( + data_class=Config, + data=data, + config=dacite.Config( + type_hooks={random.Random: seed_to_random}, strict=True + ), + ) + + def node_configs(self) -> list[NodeConfig]: + return [ + NodeConfig( + self._gen_private_key(i), + self.mix.mix_path.random_length(), + self.network.nomssip, + ) + for i in range(self.network.num_nodes) + ] + + def _gen_private_key(self, node_idx: int) -> X25519PrivateKey: + return X25519PrivateKey.from_private_bytes( + hashlib.sha256(node_idx.to_bytes(4, "big")).digest()[:32] + ) + + +@dataclass +class SimulationConfig: + # Desired duration of the simulation in seconds + # Since the simulation uses discrete time steps, the actual duration may be longer or shorter. + duration_sec: int + # Show all plots that have been drawn during the simulation + show_plots: bool + + def __post_init__(self): + assert self.duration_sec > 0 + + +@dataclass +class NetworkConfig: + # Total number of nodes in the entire network. + num_nodes: int + latency: LatencyConfig + nomssip: NomssipConfig + + def __post_init__(self): + assert self.num_nodes > 0 + + +@dataclass +class LatencyConfig: + # Maximum network latency between nodes in seconds. + # A constant latency will be chosen randomly for each connection within the range [0, max_latency_sec]. + max_latency_sec: float + # Seed for the random number generator used to determine the network latencies. + seed: random.Random + + def __post_init__(self): + assert self.max_latency_sec > 0 + assert self.seed is not None + + def random_latency(self) -> float: + # round to milliseconds to make analysis not too heavy + return int(self.seed.random() * self.max_latency_sec * 1000) / 1000 + + +@dataclass +class MixConfig: + # Global constant transmission rate of each connection in messages per second. + transmission_rate_per_sec: int + # Maximum size of a message in bytes that can be encapsulated in a single Sphinx packet. + max_message_size: int + mix_path: MixPathConfig + + def __post_init__(self): + assert self.transmission_rate_per_sec > 0 + assert self.max_message_size > 0 + + +@dataclass +class MixPathConfig: + # Maximum number of mix nodes to be chosen for a Sphinx packet. + max_length: int + # Seed for the random number generator used to determine the mix path. + seed: random.Random + + def __post_init__(self): + assert self.max_length > 0 + assert self.seed is not None + + def random_length(self) -> int: + return self.seed.randint(1, self.max_length) + + +@dataclass +class LogicConfig: + sender_lottery: LotteryConfig + + +@dataclass +class LotteryConfig: + # Interval between lottery draws in seconds. + interval_sec: float + # Probability of a node being selected as a sender in each lottery draw. + probability: float + # Seed for the random number generator used to determine the lottery winners. + seed: random.Random + + def __post_init__(self): + assert self.interval_sec > 0 + assert self.probability >= 0 + assert self.seed is not None + + +def seed_to_random(seed: int) -> random.Random: + return random.Random(seed) diff --git a/mixnet/sim/connection.py b/mixnet/sim/connection.py new file mode 100644 index 0000000..695aecd --- /dev/null +++ b/mixnet/sim/connection.py @@ -0,0 +1,100 @@ +import math +from collections import Counter +from typing import Awaitable + +import pandas + +from mixnet.connection import SimplexConnection +from mixnet.framework import Framework, Queue +from mixnet.sim.config import NetworkConfig +from mixnet.sim.state import NodeState + + +class MeteredRemoteSimplexConnection(SimplexConnection): + """ + A simplex connection implementation that simulates network latency and measures bandwidth usages. + """ + + def __init__( + self, + config: NetworkConfig, + framework: Framework, + send_node_states: list[NodeState], + recv_node_states: list[NodeState], + ): + self.framework = framework + # A connection has a random constant latency + self.latency = config.latency.random_latency() + # A queue where a sender puts messages to be sent + self.send_queue = framework.queue() + # A queue that connects send_queue and recv_queue (to measure bandwidths and simulate latency) + self.mid_queue = framework.queue() + # A queue where a receiver gets messages + self.recv_queue = framework.queue() + # A task that reads messages from send_queue, updates bandwidth stats, and puts them to mid_queue + self.send_meters: list[int] = [] + self.send_task = framework.spawn(self.__run_send_task()) + # A task that reads messages from mid_queue, simulates network latency, updates bandwidth stats, and puts them to recv_queue + self.recv_meters: list[int] = [] + self.recv_task = framework.spawn(self.__run_recv_task()) + # To measure node states over time + self.send_node_states = send_node_states + self.recv_node_states = recv_node_states + # To measure the size of messages sent via this connection + self.msg_sizes: Counter[int] = Counter() + + async def send(self, data: bytes) -> None: + await self.send_queue.put(data) + self.msg_sizes.update([len(data)]) + # The time unit of node states is milliseconds + ms = math.floor(self.framework.now() * 1000) + self.send_node_states[ms] = NodeState.SENDING + + async def recv(self) -> bytes: + data = await self.recv_queue.get() + # The time unit of node states is milliseconds + ms = math.floor(self.framework.now() * 1000) + self.send_node_states[ms] = NodeState.RECEIVING + return data + + async def __run_send_task(self): + """ + A task that reads messages from send_queue, updates bandwidth stats, and puts them to mid_queue + """ + start_time = self.framework.now() + while True: + data = await self.send_queue.get() + self.__update_meter(self.send_meters, len(data), start_time) + await self.mid_queue.put(data) + + async def __run_recv_task(self): + """ + A task that reads messages from mid_queue, simulates network latency, updates bandwidth stats, and puts them to recv_queue + """ + start_time = self.framework.now() + while True: + data = await self.mid_queue.get() + if data is None: + break + await self.framework.sleep(self.latency) + self.__update_meter(self.recv_meters, len(data), start_time) + await self.recv_queue.put(data) + + def __update_meter(self, meters: list[int], size: int, start_time: float): + """ + Accumulates the bandwidth usage in the current time slot (seconds). + """ + slot = math.floor(self.framework.now() - start_time) + assert slot >= len(meters) - 1 + # Fill zeros for the empty time slots + meters.extend([0] * (slot - len(meters) + 1)) + meters[-1] += size + + def sending_bandwidths(self) -> pandas.Series: + return self.__bandwidths(self.send_meters) + + def receiving_bandwidths(self) -> pandas.Series: + return self.__bandwidths(self.recv_meters) + + def __bandwidths(self, meters: list[int]) -> pandas.Series: + return pandas.Series(meters, name="bandwidth") diff --git a/mixnet/sim/hamming.py b/mixnet/sim/hamming.py new file mode 100644 index 0000000..974a517 --- /dev/null +++ b/mixnet/sim/hamming.py @@ -0,0 +1,42 @@ +import sys + +import pandas as pd + + +def calculate_hamming_distance(df1, df2): + """ + Caculate the hamming distance between two DataFrames + to quantify the difference between them. + """ + if df1.shape != df2.shape: + raise ValueError( + "DataFrames must have the same shape to calculate Hamming distance." + ) + + # Compare element-wise and count differences + differences = (df1 != df2).sum().sum() + return differences / df1.size # normalize the distance + + +def main(): + if len(sys.argv) != 3: + print("Usage: python hamming.py ") + sys.exit(1) + + csv_path1 = sys.argv[1] + csv_path2 = sys.argv[2] + + # Load the CSV files into DataFrames + df1 = pd.read_csv(csv_path1) + df2 = pd.read_csv(csv_path2) + + # Calculate the Hamming distance + try: + hamming_distance = calculate_hamming_distance(df1, df2) + print(f"Hamming distance: {hamming_distance}") + except ValueError as e: + print(f"Error: {e}") + + +if __name__ == "__main__": + main() diff --git a/mixnet/sim/main.py b/mixnet/sim/main.py new file mode 100644 index 0000000..4016de8 --- /dev/null +++ b/mixnet/sim/main.py @@ -0,0 +1,25 @@ +import argparse + +import usim + +from mixnet.sim.config import Config +from mixnet.sim.simulation import Simulation + +if __name__ == "__main__": + """ + Read a config file and run a simulation + """ + parser = argparse.ArgumentParser( + description="Run mixnet simulation", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + parser.add_argument( + "--config", type=str, required=True, help="Configuration file path" + ) + args = parser.parse_args() + + config = Config.load(args.config) + sim = Simulation(config) + usim.run(sim.run()) + + print("Simulation complete!") diff --git a/mixnet/sim/simulation.py b/mixnet/sim/simulation.py new file mode 100644 index 0000000..283a2e2 --- /dev/null +++ b/mixnet/sim/simulation.py @@ -0,0 +1,115 @@ +import usim +from matplotlib import pyplot + +import mixnet.framework.usim as usimfw +from mixnet.config import GlobalConfig, MixMembership, NodeInfo +from mixnet.framework import Framework +from mixnet.node import Node, PeeringDegreeReached +from mixnet.sim.config import Config +from mixnet.sim.connection import MeteredRemoteSimplexConnection +from mixnet.sim.state import NodeState, NodeStateTable +from mixnet.sim.stats import ConnectionStats + + +class Simulation: + """ + Manages the entire cycle of simulation: initialization, running, and analysis. + """ + + def __init__(self, config: Config): + self.config = config + + async def run(self): + # Run the simulation + conn_stats, node_state_table = await self.__run() + # Analyze the simulation results + conn_stats.analyze() + node_state_table.analyze() + # Show plots + if self.config.simulation.show_plots: + pyplot.show() + + async def __run(self) -> tuple[ConnectionStats, NodeStateTable]: + # Initialize analysis tools + node_state_table = NodeStateTable( + self.config.network.num_nodes, self.config.simulation.duration_sec + ) + conn_stats = ConnectionStats() + + # Create a μSim scope and run the simulation + async with usim.until(usim.time + self.config.simulation.duration_sec) as scope: + self.framework = usimfw.Framework(scope) + nodes, conn_stats, node_state_table = self.__init_nodes( + node_state_table, conn_stats + ) + for node in nodes: + self.framework.spawn(self.__run_node_logic(node)) + + # Return analysis tools once the μSim scope is done + return conn_stats, node_state_table + + def __init_nodes( + self, node_state_table: NodeStateTable, conn_stats: ConnectionStats + ) -> tuple[list[Node], ConnectionStats, NodeStateTable]: + # Initialize node/global configurations + node_configs = self.config.node_configs() + global_config = GlobalConfig( + MixMembership( + [ + NodeInfo(node_config.private_key.public_key()) + for node_config in node_configs + ], + self.config.mix.mix_path.seed, + ), + self.config.mix.transmission_rate_per_sec, + self.config.mix.max_message_size, + self.config.mix.mix_path.max_length, + ) + + # Initialize Node instances + nodes = [ + Node(self.framework, node_config, global_config) + for node_config in node_configs + ] + + # Connect nodes to each other + for i, node in enumerate(nodes): + # For now, we only consider a simple ring topology for simplicity. + peer_idx = (i + 1) % len(nodes) + peer = nodes[peer_idx] + node_states = node_state_table[i] + peer_states = node_state_table[peer_idx] + + # Create simplex inbound/outbound connections + # and use them to connect node and peer. + inbound_conn, outbound_conn = ( + self.__create_conn(peer_states, node_states), + self.__create_conn(node_states, peer_states), + ) + node.connect(peer, inbound_conn, outbound_conn) + # Register the connections to the connection statistics + conn_stats.register(node, inbound_conn, outbound_conn) + conn_stats.register(peer, outbound_conn, inbound_conn) + + return nodes, conn_stats, node_state_table + + def __create_conn( + self, sender_states: list[NodeState], receiver_states: list[NodeState] + ) -> MeteredRemoteSimplexConnection: + return MeteredRemoteSimplexConnection( + self.config.network, + self.framework, + sender_states, + receiver_states, + ) + + async def __run_node_logic(self, node: Node): + """ + Runs the lottery periodically to check if the node is selected to send a block. + If the node is selected, creates a block and sends it through mix nodes. + """ + lottery_config = self.config.logic.sender_lottery + while True: + await (usim.time + lottery_config.interval_sec) + if lottery_config.seed.random() < lottery_config.probability: + await node.send_message(b"selected block") diff --git a/mixnet/sim/state.py b/mixnet/sim/state.py new file mode 100644 index 0000000..1d203f0 --- /dev/null +++ b/mixnet/sim/state.py @@ -0,0 +1,76 @@ +from datetime import datetime +from enum import Enum + +import matplotlib.pyplot as plt +import pandas + + +class NodeState(Enum): + """ + A state of node at a certain time. + For now, we assume that the node cannot send and receive messages at the same time for simplicity. + """ + + SENDING = -1 + IDLE = 0 + RECEIVING = 1 + + +class NodeStateTable: + def __init__(self, num_nodes: int, duration_sec: int): + # Create a table to store the state of each node at each millisecond + self.__table = [ + [NodeState.IDLE] * (duration_sec * 1000) for _ in range(num_nodes) + ] + + def __getitem__(self, idx: int) -> list[NodeState]: + return self.__table[idx] + + def analyze(self): + df = pandas.DataFrame(self.__table).transpose() + df.columns = [f"Node-{i}" for i in range(len(self.__table))] + # Convert NodeState enum to their integer values + df = df.map(lambda state: state.value) + print("==========================================") + print(" Node States of All Nodes over Time") + print("==========================================") + print(f"{df}\n") + + csv_path = f"all_node_states_{datetime.now().isoformat(timespec="seconds")}.csv" + df.to_csv(csv_path) + print(f"Saved DataFrame to {csv_path}\n") + + # Count/print the number of each state for each node + # because the df is usually too big to print + state_counts = df.apply(pandas.Series.value_counts).fillna(0) + print("State Counts per Node:") + print(f"{state_counts}\n") + + # Draw a dot plot + plt.figure(figsize=(15, 8)) + for node in df.columns: + times = df.index + states = df[node] + sending_times = times[states == NodeState.SENDING.value] + receiving_times = times[states == NodeState.RECEIVING.value] + plt.scatter( + sending_times, + [node] * len(sending_times), + color="red", + marker="o", + s=10, + label="SENDING" if node == df.columns[0] else "", + ) + plt.scatter( + receiving_times, + [node] * len(receiving_times), + color="blue", + marker="x", + s=10, + label="RECEIVING" if node == df.columns[0] else "", + ) + plt.xlabel("Time") + plt.ylabel("Node") + plt.title("Node States Over Time") + plt.legend(loc="upper right") + plt.draw() diff --git a/mixnet/sim/stats.py b/mixnet/sim/stats.py new file mode 100644 index 0000000..0efbae2 --- /dev/null +++ b/mixnet/sim/stats.py @@ -0,0 +1,112 @@ +from collections import Counter, defaultdict + +import matplotlib.pyplot as plt +import pandas + +from mixnet.node import Node +from mixnet.sim.connection import MeteredRemoteSimplexConnection + +# A map of nodes to their inbound/outbound connections +NodeConnectionsMap = dict[ + Node, + tuple[list[MeteredRemoteSimplexConnection], list[MeteredRemoteSimplexConnection]], +] + + +class ConnectionStats: + def __init__(self): + self.conns_per_node: NodeConnectionsMap = defaultdict(lambda: ([], [])) + + def register( + self, + node: Node, + inbound_conn: MeteredRemoteSimplexConnection, + outbound_conn: MeteredRemoteSimplexConnection, + ): + self.conns_per_node[node][0].append(inbound_conn) + self.conns_per_node[node][1].append(outbound_conn) + + def analyze(self): + self.__message_sizes() + self.__bandwidths_per_conn() + self.__bandwidths_per_node() + + def __message_sizes(self): + """ + Analyzes all message sizes sent across all connections of all nodes. + """ + sizes: Counter[int] = Counter() + for _, (_, outbound_conns) in self.conns_per_node.items(): + for conn in outbound_conns: + sizes.update(conn.msg_sizes) + + df = pandas.DataFrame.from_dict(sizes, orient="index").reset_index() + df.columns = ["msg_size", "count"] + print("==========================================") + print(" Message Size Distribution") + print("==========================================") + print(f"{df}\n") + + def __bandwidths_per_conn(self): + """ + Analyzes the bandwidth consumed by each simplex connection. + """ + plt.plot(figsize=(12, 6)) + + for _, (_, outbound_conns) in self.conns_per_node.items(): + for conn in outbound_conns: + sending_bandwidths = conn.sending_bandwidths().map(lambda x: x / 1024) + plt.plot(sending_bandwidths.index, sending_bandwidths) + + plt.title("Unidirectional Bandwidths per Connection") + plt.xlabel("Time (s)") + plt.ylabel("Bandwidth (KiB/s)") + plt.ylim(bottom=0) + plt.grid(True) + plt.tight_layout() + plt.draw() + + def __bandwidths_per_node(self): + """ + Analyzes the inbound/outbound bandwidths consumed by each node (sum of all its connections). + """ + _, axs = plt.subplots(nrows=2, ncols=1, figsize=(12, 6)) + + for i, (_, (inbound_conns, outbound_conns)) in enumerate( + self.conns_per_node.items() + ): + inbound_bandwidths = ( + pandas.concat( + [conn.receiving_bandwidths() for conn in inbound_conns], axis=1 + ) + .sum(axis=1) + .map(lambda x: x / 1024) + ) + outbound_bandwidths = ( + pandas.concat( + [conn.sending_bandwidths() for conn in outbound_conns], axis=1 + ) + .sum(axis=1) + .map(lambda x: x / 1024) + ) + axs[0].plot(inbound_bandwidths.index, inbound_bandwidths, label=f"Node-{i}") + axs[1].plot( + outbound_bandwidths.index, outbound_bandwidths, label=f"Node-{i}" + ) + + axs[0].set_title("Inbound Bandwidths per Node") + axs[0].set_xlabel("Time (s)") + axs[0].set_ylabel("Bandwidth (KiB/s)") + axs[0].legend() + axs[0].set_ylim(bottom=0) + axs[0].grid(True) + + axs[1].set_title("Outbound Bandwidths per Node") + axs[1].set_xlabel("Time (s)") + axs[1].set_ylabel("Bandwidth (KiB/s)") + axs[1].legend() + axs[1].set_ylim(bottom=0) + axs[1].grid(True) + + plt.tight_layout() + plt.draw() diff --git a/mixnet/test_node.py b/mixnet/test_node.py index f4ba644..2048f14 100644 --- a/mixnet/test_node.py +++ b/mixnet/test_node.py @@ -1,6 +1,7 @@ -import asyncio from unittest import IsolatedAsyncioTestCase +import mixnet.framework.asyncio as asynciofw +from mixnet.connection import LocalSimplexConnection from mixnet.node import Node from mixnet.test_utils import ( init_mixnet_config, @@ -9,11 +10,18 @@ from mixnet.test_utils import ( class TestNode(IsolatedAsyncioTestCase): async def test_node(self): + framework = asynciofw.Framework() global_config, node_configs, _ = init_mixnet_config(10) - nodes = [Node(node_config, global_config) for node_config in node_configs] + nodes = [ + Node(framework, node_config, global_config) for node_config in node_configs + ] for i, node in enumerate(nodes): try: - node.connect(nodes[(i + 1) % len(nodes)]) + node.connect( + nodes[(i + 1) % len(nodes)], + LocalSimplexConnection(framework), + LocalSimplexConnection(framework), + ) except ValueError as e: print(e) @@ -24,10 +32,10 @@ class TestNode(IsolatedAsyncioTestCase): broadcasted_msgs = [] for node in nodes: if not node.broadcast_channel.empty(): - broadcasted_msgs.append(node.broadcast_channel.get_nowait()) + broadcasted_msgs.append(await node.broadcast_channel.get()) if len(broadcasted_msgs) == 0: - await asyncio.sleep(1) + await framework.sleep(1) else: # We expect only one node to broadcast the message. assert len(broadcasted_msgs) == 1 diff --git a/requirements.txt b/requirements.txt index fb7551f..e90b905 100644 --- a/requirements.txt +++ b/requirements.txt @@ -13,3 +13,8 @@ portalocker==2.8.2 # portable file locking keum==0.2.0 # for CL's use of more obscure curves poseidon-hash==0.1.4 # used as the algebraic hash in CL hypothesis==6.103.0 +dacite==1.8.1 +pandas==2.2.2 +matplotlib==3.9.1 +PyYAML==6.0.1 +usim==0.4.4