From 39eabe1537e407846cc7c3d04b14255990c78253 Mon Sep 17 00:00:00 2001 From: Youngjoon Lee <5462944+youngjoon-lee@users.noreply.github.com> Date: Thu, 1 Aug 2024 11:07:52 +0900 Subject: [PATCH] Mixnet: Initial simulation (#6) --- .github/workflows/ci.yaml | 30 +++++ mixnet/.gitignore | 2 + mixnet/README.md | 140 ++++++++++++++++++++ mixnet/cmd/__init__.py | 0 mixnet/cmd/hamming.py | 42 ++++++ mixnet/cmd/main.py | 25 ++++ mixnet/config.ci.yaml | 53 ++++++++ mixnet/framework/__init__.py | 1 + mixnet/framework/asyncio.py | 52 ++++++++ mixnet/framework/framework.py | 50 +++++++ mixnet/framework/usim.py | 58 ++++++++ mixnet/protocol/__init__.py | 0 mixnet/protocol/config.py | 66 ++++++++++ mixnet/protocol/connection.py | 88 +++++++++++++ mixnet/protocol/error.py | 2 + mixnet/protocol/gossip.py | 87 ++++++++++++ mixnet/protocol/node.py | 150 +++++++++++++++++++++ mixnet/protocol/nomssip.py | 106 +++++++++++++++ mixnet/protocol/sphinx.py | 33 +++++ mixnet/protocol/temporalmix.py | 177 +++++++++++++++++++++++++ mixnet/protocol/test_node.py | 58 ++++++++ mixnet/protocol/test_sphinx.py | 66 ++++++++++ mixnet/protocol/test_temporalmix.py | 103 +++++++++++++++ mixnet/protocol/test_utils.py | 46 +++++++ mixnet/requirements.txt | 6 + mixnet/sim/__init__.py | 0 mixnet/sim/config.py | 163 +++++++++++++++++++++++ mixnet/sim/connection.py | 139 +++++++++++++++++++ mixnet/sim/message.py | 42 ++++++ mixnet/sim/simulation.py | 198 ++++++++++++++++++++++++++++ mixnet/sim/state.py | 77 +++++++++++ mixnet/sim/stats.py | 152 +++++++++++++++++++++ mixnet/sim/test_connection.py | 104 +++++++++++++++ mixnet/sim/test_message.py | 23 ++++ mixnet/sim/test_topology.py | 20 +++ mixnet/sim/topology.py | 58 ++++++++ 36 files changed, 2417 insertions(+) create mode 100644 .github/workflows/ci.yaml create mode 100644 mixnet/.gitignore create mode 100644 mixnet/README.md create mode 100644 mixnet/cmd/__init__.py create mode 100644 mixnet/cmd/hamming.py create mode 100644 mixnet/cmd/main.py create mode 100644 mixnet/config.ci.yaml create mode 100644 mixnet/framework/__init__.py create mode 100644 mixnet/framework/asyncio.py create mode 100644 mixnet/framework/framework.py create mode 100644 mixnet/framework/usim.py create mode 100644 mixnet/protocol/__init__.py create mode 100644 mixnet/protocol/config.py create mode 100644 mixnet/protocol/connection.py create mode 100644 mixnet/protocol/error.py create mode 100644 mixnet/protocol/gossip.py create mode 100644 mixnet/protocol/node.py create mode 100644 mixnet/protocol/nomssip.py create mode 100644 mixnet/protocol/sphinx.py create mode 100644 mixnet/protocol/temporalmix.py create mode 100644 mixnet/protocol/test_node.py create mode 100644 mixnet/protocol/test_sphinx.py create mode 100644 mixnet/protocol/test_temporalmix.py create mode 100644 mixnet/protocol/test_utils.py create mode 100644 mixnet/requirements.txt create mode 100644 mixnet/sim/__init__.py create mode 100644 mixnet/sim/config.py create mode 100644 mixnet/sim/connection.py create mode 100644 mixnet/sim/message.py create mode 100644 mixnet/sim/simulation.py create mode 100644 mixnet/sim/state.py create mode 100644 mixnet/sim/stats.py create mode 100644 mixnet/sim/test_connection.py create mode 100644 mixnet/sim/test_message.py create mode 100644 mixnet/sim/test_topology.py create mode 100644 mixnet/sim/topology.py diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml new file mode 100644 index 0000000..c2628e9 --- /dev/null +++ b/.github/workflows/ci.yaml @@ -0,0 +1,30 @@ +name: CI + +on: + pull_request: + branches: + - "*" + push: + branches: [master] + +jobs: + mixnet: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + with: + submodules: true + - name: Set up Python 3.x + uses: actions/setup-python@v5 + with: + python-version: "3.x" + - name: Install dependencies for mixnet + working-directory: mixnet + run: pip install -r requirements.txt + - name: Run unit tests + working-directory: mixnet + run: python -m unittest -v + - name: Run a short mixnet simulation + working-directory: mixnet + run: python -m cmd.main --config config.ci.yaml + diff --git a/mixnet/.gitignore b/mixnet/.gitignore new file mode 100644 index 0000000..d05906f --- /dev/null +++ b/mixnet/.gitignore @@ -0,0 +1,2 @@ +.venv/ +*.csv diff --git a/mixnet/README.md b/mixnet/README.md new file mode 100644 index 0000000..bed1aa6 --- /dev/null +++ b/mixnet/README.md @@ -0,0 +1,140 @@ +# NomMix Simulation + +* [Project Structure](#project-structure) +* [Features](#features) +* [Future Plans](#future-plans) +* [Installation](#installation) +* [Getting Started](#getting-started) + +## Project Structure + +- `cmd`: CLIs to run the simulation and analyze the results. +- `sim`: Simulation that runs the NomMix defined in the `protocol` package. +- `protocol`: Core NomMix protocol implementation, which is going to be moved to the [nomos-repos](https://github.com/logos-co/nomos-specs) repository once verified by simulations. +- `framework`: Asynchronous framework that provides essential async functions for simulations and tests, implemented with various async libraries ([asyncio](https://docs.python.org/3/library/asyncio.html), [μSim](https://usim.readthedocs.io/en/latest/), etc.) + +## Features + +- NomMix protocol simulation +- Performance measurements + - Bandwidth usages + - Message dissemination time +- Privacy property analysis + - Message sizes + - Node states and hamming distances + +## Future Plans + +- More NomMix features + - Temporal mixing + - Level-1 noise +- Adversary simulation to measure the robustness of NomMix + +## Installation + +Clone the repository and install the dependencies: +```bash +git clone https://github.com/logos-co/nomos-simulations.git +cd nomos-simulations/mixnet +python3 -m venv .venv +source .venv/bin/activate +pip install -r requirements.txt +``` + +## Getting Started + +Copy the [`config.ci.yaml`](./config.ci.yaml) file and adjust the parameters to your needs. +Each parameter is explained in the config file. +For more details, please refer to the [documentation](https://www.notion.so/NomMix-Sim-Getting-Started-ee0e2191f4e7437e93976aff2627d7ce?pvs=4). + +Run the simulation with the following command: +```bash +python -m cmd.main --config {config_path} +``` + +All results are printed in the console as below. +And, all plots are shown once all analysis is done. +``` +Spawning node-0 with 3 conns +Spawning node-1 with 3 conns +Spawning node-2 with 3 conns +Spawning node-3 with 3 conns +Spawning node-4 with 3 conns +Spawning node-5 with 3 conns +========================================== +Message Dissemination Time +========================================== +[Mix Propagation Times] +count 7.000000 +mean 1.122000 +std 0.106276 +min 1.009000 +25% 1.024500 +50% 1.157000 +75% 1.174500 +max 1.290000 +dtype: float64 + +[Broadcast Dissemination Times] +count 7.000000 +mean 0.118429 +std 0.004353 +min 0.111000 +25% 0.116000 +50% 0.120000 +75% 0.121500 +max 0.123000 +dtype: float64 + +========================================== +Message Size Distribution +========================================== + msg_size count +0 1405 179982 + +========================================== +Node States of All Nodes over Time +SENDING:-1, IDLE:0, RECEIVING:1 +========================================== + Node-0 Node-1 Node-2 Node-3 Node-4 Node-5 +0 0 0 0 0 0 0 +1 0 0 0 0 0 0 +2 0 0 0 0 0 0 +3 0 0 0 0 0 0 +4 0 0 0 0 0 0 +... ... ... ... ... ... ... +999995 0 0 0 0 0 0 +999996 0 0 0 0 0 1 +999997 0 0 0 0 0 0 +999998 0 0 0 1 0 0 +999999 0 0 0 0 0 0 + +[1000000 rows x 6 columns] + +Saved DataFrame to all_node_states_2024-07-23T09:10:59.csv + +State Counts per Node: + Node-0 Node-1 Node-2 Node-3 Node-4 Node-5 + 0 960004 960004 960004 960004 960004 960004 + 1 29997 29997 29997 29997 29997 29997 +-1 9999 9999 9999 9999 9999 9999 + +Simulation complete! +``` + +Please note that the result of node state analysis is saved as a CSV file, as printed in the console. + +If you run the simulation again with the different parameters and want to +compare the results of two simulations, +you can calculate the hamming distance between them: +```bash +python -m cmd.hamming \ + all_node_states_2024-07-15T18:20:23.csv \ + all_node_states_2024-07-15T19:32:45.csv +``` +The output is a floating point number between 0 and 1. +If the output is 0, the results of two simulations are identical. +The closer the result is to 1, the more the two results differ from each other. +``` +Hamming distance: 0.29997 +``` diff --git a/mixnet/cmd/__init__.py b/mixnet/cmd/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/mixnet/cmd/hamming.py b/mixnet/cmd/hamming.py new file mode 100644 index 0000000..974a517 --- /dev/null +++ b/mixnet/cmd/hamming.py @@ -0,0 +1,42 @@ +import sys + +import pandas as pd + + +def calculate_hamming_distance(df1, df2): + """ + Caculate the hamming distance between two DataFrames + to quantify the difference between them. + """ + if df1.shape != df2.shape: + raise ValueError( + "DataFrames must have the same shape to calculate Hamming distance." + ) + + # Compare element-wise and count differences + differences = (df1 != df2).sum().sum() + return differences / df1.size # normalize the distance + + +def main(): + if len(sys.argv) != 3: + print("Usage: python hamming.py ") + sys.exit(1) + + csv_path1 = sys.argv[1] + csv_path2 = sys.argv[2] + + # Load the CSV files into DataFrames + df1 = pd.read_csv(csv_path1) + df2 = pd.read_csv(csv_path2) + + # Calculate the Hamming distance + try: + hamming_distance = calculate_hamming_distance(df1, df2) + print(f"Hamming distance: {hamming_distance}") + except ValueError as e: + print(f"Error: {e}") + + +if __name__ == "__main__": + main() diff --git a/mixnet/cmd/main.py b/mixnet/cmd/main.py new file mode 100644 index 0000000..05b2807 --- /dev/null +++ b/mixnet/cmd/main.py @@ -0,0 +1,25 @@ +import argparse + +import usim + +from sim.config import Config +from sim.simulation import Simulation + +if __name__ == "__main__": + """ + Read a config file and run a simulation + """ + parser = argparse.ArgumentParser( + description="Run mixnet simulation", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + parser.add_argument( + "--config", type=str, required=True, help="Configuration file path" + ) + args = parser.parse_args() + + config = Config.load(args.config) + sim = Simulation(config) + usim.run(sim.run()) + + print("Simulation complete!") diff --git a/mixnet/config.ci.yaml b/mixnet/config.ci.yaml new file mode 100644 index 0000000..8cdcf63 --- /dev/null +++ b/mixnet/config.ci.yaml @@ -0,0 +1,53 @@ +simulation: + # Desired duration of the simulation in seconds + # Since the simulation uses discrete time steps, the actual duration may be longer or shorter. + duration_sec: 1000 + # Show all plots that have been drawn during the simulation + show_plots: false + +network: + # Total number of nodes in the entire network. + num_nodes: 6 + latency: + # Minimum/maximum network latency between nodes in seconds. + # A constant latency will be chosen randomly for each connection within the range [min_latency_sec, max_latency_sec]. + min_latency_sec: 0 + max_latency_sec: 0.1 + # Seed for the random number generator used to determine the network latencies. + seed: 0 + gossip: + # Expected number of peers each node must connect to if there are enough peers available in the network. + peering_degree: 3 + topology: + # Seed for the random number generator used to determine the network topology. + seed: 1 + +mix: + # Global constant transmission rate of each connection in messages per second. + transmission_rate_per_sec: 10 + # Maximum size of a message in bytes that can be encapsulated in a single Sphinx packet. + max_message_size: 1007 + mix_path: + # Minimum number of mix nodes to be chosen for a Sphinx packet. + min_length: 5 + # Maximum number of mix nodes to be chosen for a Sphinx packet. + max_length: 5 + # Seed for the random number generator used to determine the mix path. + seed: 3 + temporal_mix: + # none | pure-coin-flipping | pure-random-sampling | permuted-coin-flipping + mix_type: "pure-coin-flipping" + # The minimum size of queue to be mixed. + # If the queue size is less than this value, noise messages are added. + min_queue_size: 5 + # Generate the seeds used to create the RNG for each queue that will be created. + seed_generator: 100 + +logic: + sender_lottery: + # Interval between lottery draws in seconds. + interval_sec: 1 + # Probability of a node being selected as a sender in each lottery draw. + probability: 0.001 + # Seed for the random number generator used to determine the lottery winners. + seed: 10 diff --git a/mixnet/framework/__init__.py b/mixnet/framework/__init__.py new file mode 100644 index 0000000..13aa683 --- /dev/null +++ b/mixnet/framework/__init__.py @@ -0,0 +1 @@ +from .framework import * diff --git a/mixnet/framework/asyncio.py b/mixnet/framework/asyncio.py new file mode 100644 index 0000000..fe82627 --- /dev/null +++ b/mixnet/framework/asyncio.py @@ -0,0 +1,52 @@ +from __future__ import annotations + +import asyncio +import time +from typing import Any, Awaitable, Coroutine, Generic, TypeVar + +from framework import framework + + +class Framework(framework.Framework): + """ + An asyncio implementation of the Framework + """ + + def __init__(self): + super().__init__() + + def queue(self) -> framework.Queue: + return Queue() + + async def sleep(self, seconds: float) -> None: + await asyncio.sleep(seconds) + + def now(self) -> float: + return time.time() + + def spawn( + self, coroutine: Coroutine[Any, Any, framework.RT] + ) -> Awaitable[framework.RT]: + return asyncio.create_task(coroutine) + + +T = TypeVar("T") + + +class Queue(framework.Queue[T]): + """ + An asyncio implementation of the Queue + """ + + def __init__(self): + super().__init__() + self._queue = asyncio.Queue() + + async def put(self, data: T) -> None: + await self._queue.put(data) + + async def get(self) -> T: + return await self._queue.get() + + def empty(self) -> bool: + return self._queue.empty() diff --git a/mixnet/framework/framework.py b/mixnet/framework/framework.py new file mode 100644 index 0000000..035fed2 --- /dev/null +++ b/mixnet/framework/framework.py @@ -0,0 +1,50 @@ +from __future__ import annotations + +import abc +from typing import Any, Awaitable, Coroutine, Generic, TypeVar + +RT = TypeVar("RT") + + +class Framework(abc.ABC): + """ + An abstract class that provides essential asynchronous functions. + This class can be implemented using any asynchronous framework (e.g., asyncio, usim). + """ + + @abc.abstractmethod + def queue(self) -> Queue: + pass + + @abc.abstractmethod + async def sleep(self, seconds: float) -> None: + pass + + @abc.abstractmethod + def now(self) -> float: + pass + + @abc.abstractmethod + def spawn(self, coroutine: Coroutine[Any, Any, RT]) -> Awaitable[RT]: + pass + + +T = TypeVar("T") + + +class Queue(abc.ABC, Generic[T]): + """ + An abstract class that provides asynchronous queue operations. + """ + + @abc.abstractmethod + async def put(self, data: T) -> None: + pass + + @abc.abstractmethod + async def get(self) -> T: + pass + + @abc.abstractmethod + def empty(self) -> bool: + pass diff --git a/mixnet/framework/usim.py b/mixnet/framework/usim.py new file mode 100644 index 0000000..67ededa --- /dev/null +++ b/mixnet/framework/usim.py @@ -0,0 +1,58 @@ +from typing import Any, Awaitable, Coroutine, TypeVar + +import usim + +from framework import framework + + +class Framework(framework.Framework): + """ + A usim implementation of the Framework for discrete-time simulation + """ + + def __init__(self, scope: usim.Scope) -> None: + super().__init__() + + # Scope is used to spawn concurrent simulation activities (coroutines). + # μSim waits until all activities spawned in the scope are done + # or until the timeout specified in the scope is reached. + # Because of the way μSim works, the scope must be created using `async with` syntax + # and be passed to this constructor. + self._scope = scope + + def queue(self) -> framework.Queue: + return Queue() + + async def sleep(self, seconds: float) -> None: + await (usim.time + seconds) + + def now(self) -> float: + # Round to milliseconds to make analysis not too heavy + return int(usim.time.now * 1000) / 1000 + + def spawn( + self, coroutine: Coroutine[Any, Any, framework.RT] + ) -> Awaitable[framework.RT]: + return self._scope.do(coroutine) + + +T = TypeVar("T") + + +class Queue(framework.Queue[T]): + """ + A usim implementation of the Queue for discrete-time simulation + """ + + def __init__(self): + super().__init__() + self._queue = usim.Queue() + + async def put(self, data: T) -> None: + await self._queue.put(data) + + async def get(self) -> T: + return await self._queue + + def empty(self) -> bool: + return len(self._queue._buffer) == 0 diff --git a/mixnet/protocol/__init__.py b/mixnet/protocol/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/mixnet/protocol/config.py b/mixnet/protocol/config.py new file mode 100644 index 0000000..48a594d --- /dev/null +++ b/mixnet/protocol/config.py @@ -0,0 +1,66 @@ +from __future__ import annotations + +import random +from dataclasses import dataclass, field +from typing import List + +from pysphinx.node import X25519PublicKey +from pysphinx.sphinx import Node as SphinxNode +from pysphinx.sphinx import X25519PrivateKey + +from protocol.gossip import GossipConfig +from protocol.temporalmix import TemporalMixConfig + + +@dataclass +class GlobalConfig: + """ + Global parameters used across all nodes in the network + """ + + membership: MixMembership + transmission_rate_per_sec: int # Global Transmission Rate + max_message_size: int + max_mix_path_length: int + + +@dataclass +class NodeConfig: + """ + Node-specific parameters + """ + + private_key: X25519PrivateKey + mix_path_length: int + gossip: GossipConfig + temporal_mix: TemporalMixConfig + + +@dataclass +class MixMembership: + """ + A list of public information of nodes in the network. + We assume that this list is known to all nodes in the network. + """ + + nodes: List[NodeInfo] + rng: random.Random = field(default_factory=random.Random) + + def generate_route(self, length: int) -> list[NodeInfo]: + """ + Choose `length` nodes with replacement as a mix route. + """ + return self.rng.choices(self.nodes, k=length) + + +@dataclass +class NodeInfo: + """ + Public information of a node to be shared to all nodes in the network + """ + + public_key: X25519PublicKey + + def sphinx_node(self) -> SphinxNode: + dummy_node_addr = bytes(32) + return SphinxNode(self.public_key, dummy_node_addr) diff --git a/mixnet/protocol/connection.py b/mixnet/protocol/connection.py new file mode 100644 index 0000000..11a7c1f --- /dev/null +++ b/mixnet/protocol/connection.py @@ -0,0 +1,88 @@ +from __future__ import annotations + +import abc +import random + +from framework import Framework, Queue +from protocol.temporalmix import PureCoinFlipppingQueue, TemporalMix, TemporalMixConfig + + +class SimplexConnection(abc.ABC): + """ + An abstract class for a simplex connection that can send and receive data in one direction + """ + + @abc.abstractmethod + async def send(self, data: bytes) -> None: + pass + + @abc.abstractmethod + async def recv(self) -> bytes: + pass + + +class LocalSimplexConnection(SimplexConnection): + """ + A simplex connection that doesn't have any network latency. + Data sent through this connection can be immediately received from the other end. + """ + + def __init__(self, framework: Framework): + self.queue: Queue[bytes] = framework.queue() + + async def send(self, data: bytes) -> None: + await self.queue.put(data) + + async def recv(self) -> bytes: + return await self.queue.get() + + +class DuplexConnection: + """ + A duplex connection in which data can be transmitted and received simultaneously in both directions. + This is to mimic duplex communication in a real network (such as TCP or QUIC). + """ + + def __init__(self, inbound: SimplexConnection, outbound: SimplexConnection): + self.inbound = inbound + self.outbound = outbound + + async def recv(self) -> bytes: + return await self.inbound.recv() + + async def send(self, packet: bytes): + await self.outbound.send(packet) + + +class MixSimplexConnection(SimplexConnection): + """ + Wraps a SimplexConnection to add a transmission rate and noise to the connection. + """ + + def __init__( + self, + framework: Framework, + conn: SimplexConnection, + transmission_rate_per_sec: int, + noise_msg: bytes, + temporal_mix_config: TemporalMixConfig, + ): + self.framework = framework + self.queue: Queue[bytes] = TemporalMix.queue( + temporal_mix_config, framework, noise_msg + ) + self.conn = conn + self.transmission_rate_per_sec = transmission_rate_per_sec + self.task = framework.spawn(self.__run()) + + async def __run(self): + while True: + await self.framework.sleep(1 / self.transmission_rate_per_sec) + msg = await self.queue.get() + await self.conn.send(msg) + + async def send(self, data: bytes) -> None: + await self.queue.put(data) + + async def recv(self) -> bytes: + return await self.conn.recv() diff --git a/mixnet/protocol/error.py b/mixnet/protocol/error.py new file mode 100644 index 0000000..96d0c41 --- /dev/null +++ b/mixnet/protocol/error.py @@ -0,0 +1,2 @@ +class PeeringDegreeReached(Exception): + pass diff --git a/mixnet/protocol/gossip.py b/mixnet/protocol/gossip.py new file mode 100644 index 0000000..ff88e2f --- /dev/null +++ b/mixnet/protocol/gossip.py @@ -0,0 +1,87 @@ +from __future__ import annotations + +import hashlib +from dataclasses import dataclass +from enum import Enum +from typing import Awaitable, Callable, Self + +from framework import Framework +from protocol.connection import ( + DuplexConnection, + MixSimplexConnection, + SimplexConnection, +) +from protocol.error import PeeringDegreeReached + + +@dataclass +class GossipConfig: + # Expected number of peers each node must connect to if there are enough peers available in the network. + peering_degree: int + + +class Gossip: + """ + A gossip channel that broadcasts messages to all connected peers. + Peers are connected via DuplexConnection. + """ + + def __init__( + self, + framework: Framework, + config: GossipConfig, + handler: Callable[[bytes], Awaitable[None]], + ): + self.framework = framework + self.config = config + self.conns: list[DuplexConnection] = [] + # A handler to process inbound messages. + self.handler = handler + self.packet_cache: set[bytes] = set() + # A set just for gathering a reference of tasks to prevent them from being garbage collected. + # https://docs.python.org/3/library/asyncio-task.html#asyncio.create_task + self.tasks: set[Awaitable] = set() + + def can_accept_conn(self) -> bool: + return len(self.conns) < self.config.peering_degree + + def add_conn(self, inbound: SimplexConnection, outbound: SimplexConnection): + if not self.can_accept_conn(): + # For simplicity of the spec, reject the connection if the peering degree is reached. + raise PeeringDegreeReached() + + conn = DuplexConnection( + inbound, + outbound, + ) + self.conns.append(conn) + task = self.framework.spawn(self.__process_inbound_conn(conn)) + self.tasks.add(task) + + async def __process_inbound_conn(self, conn: DuplexConnection): + while True: + msg = await conn.recv() + if self.__check_update_cache(msg): + continue + await self.process_inbound_msg(msg) + + async def process_inbound_msg(self, msg: bytes): + await self.gossip(msg) + await self.handler(msg) + + async def gossip(self, msg: bytes): + """ + Gossip a message to all connected peers. + """ + for conn in self.conns: + await conn.send(msg) + + def __check_update_cache(self, packet: bytes) -> bool: + """ + Add a message to the cache, and return True if the message was already in the cache. + """ + hash = hashlib.sha256(packet).digest() + if hash in self.packet_cache: + return True + self.packet_cache.add(hash) + return False diff --git a/mixnet/protocol/node.py b/mixnet/protocol/node.py new file mode 100644 index 0000000..5818694 --- /dev/null +++ b/mixnet/protocol/node.py @@ -0,0 +1,150 @@ +from __future__ import annotations + +from typing import Awaitable, Callable + +from pysphinx.sphinx import ( + ProcessedFinalHopPacket, + ProcessedForwardHopPacket, + SphinxPacket, +) + +from framework import Framework, Queue +from protocol.config import GlobalConfig, NodeConfig +from protocol.connection import SimplexConnection +from protocol.error import PeeringDegreeReached +from protocol.gossip import Gossip +from protocol.nomssip import Nomssip, NomssipConfig +from protocol.sphinx import SphinxPacketBuilder + + +class Node: + """ + This represents any node in the network, which: + - generates/gossips mix messages (Sphinx packets) + - performs cryptographic mix (unwrapping Sphinx packets) + - generates noise + """ + + def __init__( + self, + framework: Framework, + config: NodeConfig, + global_config: GlobalConfig, + # A handler called when a node receives a broadcasted message originated from the last mix. + broadcasted_msg_handler: Callable[[bytes], Awaitable[None]], + # An optional handler only for the simulation, + # which is called when a message is fully recovered by the last mix + # and returns a new message to be broadcasted. + recovered_msg_handler: Callable[[bytes], Awaitable[bytes]] | None = None, + ): + self.framework = framework + self.config = config + self.global_config = global_config + self.nomssip = Nomssip( + framework, + NomssipConfig( + config.gossip.peering_degree, + global_config.transmission_rate_per_sec, + self.__calculate_message_size(global_config), + config.temporal_mix, + ), + self.__process_msg, + ) + self.broadcast = Gossip(framework, config.gossip, broadcasted_msg_handler) + self.recovered_msg_handler = recovered_msg_handler + + @staticmethod + def __calculate_message_size(global_config: GlobalConfig) -> int: + """ + Calculate the actual message size to be gossiped, which depends on the maximum length of mix path. + """ + sample_sphinx_packet, _ = SphinxPacketBuilder.build( + bytes(global_config.max_message_size), + global_config, + global_config.max_mix_path_length, + ) + return len(sample_sphinx_packet.bytes()) + + async def __process_msg(self, msg: bytes) -> None: + """ + A handler to process messages received via Nomssip channel + """ + sphinx_packet = SphinxPacket.from_bytes( + msg, self.global_config.max_mix_path_length + ) + result = await self.__process_sphinx_packet(sphinx_packet) + match result: + case SphinxPacket(): + # Gossip the next Sphinx packet + await self.nomssip.gossip(result.bytes()) + case bytes(): + if self.recovered_msg_handler is not None: + result = await self.recovered_msg_handler(result) + # Broadcast the message fully recovered from Sphinx packets + await self.broadcast.gossip(result) + case None: + return + + async def __process_sphinx_packet( + self, packet: SphinxPacket + ) -> SphinxPacket | bytes | None: + """ + Unwrap the Sphinx packet and process the next Sphinx packet or the payload if possible + """ + try: + processed = packet.process(self.config.private_key) + match processed: + case ProcessedForwardHopPacket(): + return processed.next_packet + case ProcessedFinalHopPacket(): + return processed.payload.recover_plain_playload() + except ValueError: + # Return nothing, if it cannot be unwrapped by the private key of this node. + return None + + def connect_mix( + self, + peer: Node, + inbound_conn: SimplexConnection, + outbound_conn: SimplexConnection, + ): + Node.__connect(self.nomssip, peer.nomssip, inbound_conn, outbound_conn) + + def connect_broadcast( + self, + peer: Node, + inbound_conn: SimplexConnection, + outbound_conn: SimplexConnection, + ): + Node.__connect(self.broadcast, peer.broadcast, inbound_conn, outbound_conn) + + @staticmethod + def __connect( + self_channel: Gossip, + peer_channel: Gossip, + inbound_conn: SimplexConnection, + outbound_conn: SimplexConnection, + ): + """ + Establish a duplex connection with a peer node. + """ + if not self_channel.can_accept_conn() or not peer_channel.can_accept_conn(): + raise PeeringDegreeReached() + + # Register a duplex connection for its own use + self_channel.add_conn(inbound_conn, outbound_conn) + # Register a duplex connection for the peer + peer_channel.add_conn(outbound_conn, inbound_conn) + + async def send_message(self, msg: bytes): + """ + Build a Sphinx packet and gossip it to all connected peers. + """ + # Here, we handle the case in which a msg is split into multiple Sphinx packets. + # But, in practice, we expect a message to be small enough to fit in a single Sphinx packet. + sphinx_packet, _ = SphinxPacketBuilder.build( + msg, + self.global_config, + self.config.mix_path_length, + ) + await self.nomssip.gossip(sphinx_packet.bytes()) diff --git a/mixnet/protocol/nomssip.py b/mixnet/protocol/nomssip.py new file mode 100644 index 0000000..b5e08f6 --- /dev/null +++ b/mixnet/protocol/nomssip.py @@ -0,0 +1,106 @@ +from __future__ import annotations + +import hashlib +import random +from dataclasses import dataclass +from enum import Enum +from typing import Awaitable, Callable, Self, override + +from framework import Framework +from protocol.connection import ( + DuplexConnection, + MixSimplexConnection, + SimplexConnection, +) +from protocol.error import PeeringDegreeReached +from protocol.gossip import Gossip, GossipConfig +from protocol.temporalmix import TemporalMixConfig + + +@dataclass +class NomssipConfig(GossipConfig): + transmission_rate_per_sec: int + msg_size: int + temporal_mix: TemporalMixConfig + + +class Nomssip(Gossip): + """ + A NomMix gossip channel that extends the Gossip channel + by adding global transmission rate and noise generation. + """ + + def __init__( + self, + framework: Framework, + config: NomssipConfig, + handler: Callable[[bytes], Awaitable[None]], + ): + super().__init__(framework, config, handler) + self.config = config + + @override + def add_conn(self, inbound: SimplexConnection, outbound: SimplexConnection): + noise_packet = FlaggedPacket( + FlaggedPacket.Flag.NOISE, bytes(self.config.msg_size) + ).bytes() + super().add_conn( + inbound, + MixSimplexConnection( + self.framework, + outbound, + self.config.transmission_rate_per_sec, + noise_packet, + self.config.temporal_mix, + ), + ) + + @override + async def process_inbound_msg(self, msg: bytes): + packet = FlaggedPacket.from_bytes(msg) + match packet.flag: + case FlaggedPacket.Flag.NOISE: + # Drop noise packet + return + case FlaggedPacket.Flag.REAL: + await self.__gossip_flagged_packet(packet) + await self.handler(packet.message) + + @override + async def gossip(self, msg: bytes): + """ + Gossip a message to all connected peers with prepending a message flag + """ + # The message size must be fixed. + assert len(msg) == self.config.msg_size + + packet = FlaggedPacket(FlaggedPacket.Flag.REAL, msg) + await self.__gossip_flagged_packet(packet) + + async def __gossip_flagged_packet(self, packet: FlaggedPacket): + """ + An internal method to send a flagged packet to all connected peers + """ + await super().gossip(packet.bytes()) + + +class FlaggedPacket: + class Flag(Enum): + REAL = b"\x00" + NOISE = b"\x01" + + def __init__(self, flag: Flag, message: bytes): + self.flag = flag + self.message = message + + def bytes(self) -> bytes: + return self.flag.value + self.message + + @classmethod + def from_bytes(cls, packet: bytes) -> Self: + """ + Parse a flagged packet from bytes + """ + if len(packet) < 1: + raise ValueError("Invalid message format") + return cls(cls.Flag(packet[:1]), packet[1:]) diff --git a/mixnet/protocol/sphinx.py b/mixnet/protocol/sphinx.py new file mode 100644 index 0000000..a219e77 --- /dev/null +++ b/mixnet/protocol/sphinx.py @@ -0,0 +1,33 @@ +from __future__ import annotations + +from typing import List, Tuple + +from pysphinx.sphinx import SphinxPacket + +from protocol.config import GlobalConfig, NodeInfo + + +class SphinxPacketBuilder: + @staticmethod + def build( + message: bytes, global_config: GlobalConfig, path_len: int + ) -> Tuple[SphinxPacket, List[NodeInfo]]: + if path_len <= 0: + raise ValueError("path_len must be greater than 0") + if len(message) > global_config.max_message_size: + raise ValueError("message is too long") + + route = global_config.membership.generate_route(path_len) + # We don't need the destination (defined in the Loopix Sphinx spec) + # because the last mix will broadcast the fully unwrapped message. + # Later, we will optimize the Sphinx according to our requirements. + dummy_destination = route[-1] + + packet = SphinxPacket.build( + message, + route=[mixnode.sphinx_node() for mixnode in route], + destination=dummy_destination.sphinx_node(), + max_route_length=global_config.max_mix_path_length, + max_plain_payload_size=global_config.max_message_size, + ) + return (packet, route) diff --git a/mixnet/protocol/temporalmix.py b/mixnet/protocol/temporalmix.py new file mode 100644 index 0000000..6775528 --- /dev/null +++ b/mixnet/protocol/temporalmix.py @@ -0,0 +1,177 @@ +import random +from abc import abstractmethod +from dataclasses import dataclass +from enum import Enum +from typing import TypeVar + +from framework.framework import Framework, Queue + + +class TemporalMixType(Enum): + NONE = "none" + PURE_COIN_FLIPPING = "pure-coin-flipping" + PURE_RANDOM_SAMPLING = "pure-random-sampling" + PERMUTED_COIN_FLIPPING = "permuted-coin-flipping" + NOISY_COIN_FLIPPING = "noisy-coin-flipping" + + +@dataclass +class TemporalMixConfig: + mix_type: TemporalMixType + # The minimum size of queue to be mixed. + # If the queue size is less than this value, noise messages are added. + min_queue_size: int + # Generate the seeds used to create the RNG for each queue that will be created. + seed_generator: random.Random + + def __post_init__(self): + assert self.seed_generator is not None + assert self.min_queue_size > 0 + + +T = TypeVar("T") + + +class TemporalMix: + @staticmethod + def queue( + config: TemporalMixConfig, framework: Framework, noise_msg: T + ) -> Queue[T]: + match config.mix_type: + case TemporalMixType.NONE: + return NonMixQueue(framework, noise_msg) + case TemporalMixType.PURE_COIN_FLIPPING: + return PureCoinFlipppingQueue( + config.min_queue_size, + random.Random(config.seed_generator.random()), + noise_msg, + ) + case TemporalMixType.PURE_RANDOM_SAMPLING: + return PureRandomSamplingQueue( + config.min_queue_size, + random.Random(config.seed_generator.random()), + noise_msg, + ) + case TemporalMixType.PERMUTED_COIN_FLIPPING: + return PermutedCoinFlipppingQueue( + config.min_queue_size, + random.Random(config.seed_generator.random()), + noise_msg, + ) + case TemporalMixType.NOISY_COIN_FLIPPING: + return NoisyCoinFlippingQueue( + random.Random(config.seed_generator.random()), + noise_msg, + ) + case _: + raise ValueError(f"Unknown mix type: {config.mix_type}") + + +class NonMixQueue(Queue[T]): + """ + Queue without temporal mixing. Only have the noise generation when the queue is empty. + """ + + def __init__(self, framework: Framework, noise_msg: T): + self.__queue = framework.queue() + self.__noise_msg = noise_msg + + async def put(self, data: T) -> None: + await self.__queue.put(data) + + async def get(self) -> T: + if self.__queue.empty(): + return self.__noise_msg + else: + return await self.__queue.get() + + def empty(self) -> bool: + return self.__queue.empty() + + +class MixQueue(Queue[T]): + def __init__(self, rng: random.Random, noise_msg: T): + super().__init__() + # Assuming that simulations run in a single thread + self._queue: list[T] = [] + self._rng = rng + self._noise_msg = noise_msg + + async def put(self, data: T) -> None: + self._queue.append(data) + + @abstractmethod + async def get(self) -> T: + pass + + def empty(self) -> bool: + return len(self._queue) == 0 + + +class MinSizeMixQueue(MixQueue[T]): + def __init__(self, min_pool_size: int, rng: random.Random, noise_msg: T): + super().__init__(rng, noise_msg) + self._mix_pool_size = min_pool_size + + @abstractmethod + async def get(self) -> T: + while len(self._queue) < self._mix_pool_size: + self._queue.append(self._noise_msg) + + # Subclass must implement this method + pass + + +class PureCoinFlipppingQueue(MinSizeMixQueue[T]): + async def get(self) -> T: + await super().get() + + while True: + for i in range(len(self._queue)): + # coin-flipping + if self._rng.randint(0, 1) == 1: + # After removing a message from the position `i`, we don't fill up the position. + # Instead, the queue is always filled from the back. + return self._queue.pop(i) + + +class PureRandomSamplingQueue(MinSizeMixQueue[T]): + async def get(self) -> T: + await super().get() + + i = self._rng.randint(0, len(self._queue) - 1) + # After removing a message from the position `i`, we don't fill up the position. + # Instead, the queue is always filled from the back. + return self._queue.pop(i) + + +class PermutedCoinFlipppingQueue(MinSizeMixQueue[T]): + async def get(self) -> T: + await super().get() + + self._rng.shuffle(self._queue) + + while True: + for i in range(len(self._queue)): + # coin-flipping + if self._rng.randint(0, 1) == 1: + # After removing a message from the position `i`, we don't fill up the position. + # Instead, the queue is always filled from the back. + return self._queue.pop(i) + + +class NoisyCoinFlippingQueue(MixQueue[T]): + async def get(self) -> T: + if len(self._queue) == 0: + return self._noise_msg + + while True: + for i in range(len(self._queue)): + # coin-flipping + if self._rng.randint(0, 1) == 1: + # After removing a message from the position `i`, we don't fill up the position. + # Instead, the queue is always filled from the back. + return self._queue.pop(i) + else: + if i == 0: + return self._noise_msg diff --git a/mixnet/protocol/test_node.py b/mixnet/protocol/test_node.py new file mode 100644 index 0000000..00f2e0b --- /dev/null +++ b/mixnet/protocol/test_node.py @@ -0,0 +1,58 @@ +from unittest import IsolatedAsyncioTestCase + +import framework.asyncio as asynciofw +from framework.framework import Queue +from protocol.connection import LocalSimplexConnection +from protocol.node import Node +from protocol.test_utils import ( + init_mixnet_config, +) + + +class TestNode(IsolatedAsyncioTestCase): + async def test_node(self): + framework = asynciofw.Framework() + global_config, node_configs, _ = init_mixnet_config(10) + + queue: Queue[bytes] = framework.queue() + + async def broadcasted_msg_handler(msg: bytes) -> None: + await queue.put(msg) + + nodes = [ + Node(framework, node_config, global_config, broadcasted_msg_handler) + for node_config in node_configs + ] + for i, node in enumerate(nodes): + try: + node.connect_mix( + nodes[(i + 1) % len(nodes)], + LocalSimplexConnection(framework), + LocalSimplexConnection(framework), + ) + node.connect_broadcast( + nodes[(i + 1) % len(nodes)], + LocalSimplexConnection(framework), + LocalSimplexConnection(framework), + ) + except ValueError as e: + print(e) + + await nodes[0].send_message(b"block selection") + + # Wait for all nodes to receive the broadcast + num_nodes_received_broadcast = 0 + timeout = 15 + for _ in range(timeout): + await framework.sleep(1) + + while not queue.empty(): + self.assertEqual(b"block selection", await queue.get()) + num_nodes_received_broadcast += 1 + + if num_nodes_received_broadcast == len(nodes): + break + + self.assertEqual(len(nodes), num_nodes_received_broadcast) + + # TODO: check noise diff --git a/mixnet/protocol/test_sphinx.py b/mixnet/protocol/test_sphinx.py new file mode 100644 index 0000000..39dd233 --- /dev/null +++ b/mixnet/protocol/test_sphinx.py @@ -0,0 +1,66 @@ +from random import randint +from typing import cast +from unittest import TestCase + +from pysphinx.sphinx import ( + ProcessedFinalHopPacket, + ProcessedForwardHopPacket, +) + +from protocol.sphinx import SphinxPacketBuilder +from protocol.test_utils import init_mixnet_config + + +class TestSphinxPacketBuilder(TestCase): + def test_builder(self): + global_config, _, key_map = init_mixnet_config(10) + msg = self.random_bytes(500) + packet, route = SphinxPacketBuilder.build(msg, global_config, 3) + self.assertEqual(3, len(route)) + + processed = packet.process(key_map[route[0].public_key.public_bytes_raw()]) + self.assertIsInstance(processed, ProcessedForwardHopPacket) + processed = cast(ProcessedForwardHopPacket, processed).next_packet.process( + key_map[route[1].public_key.public_bytes_raw()] + ) + self.assertIsInstance(processed, ProcessedForwardHopPacket) + processed = cast(ProcessedForwardHopPacket, processed).next_packet.process( + key_map[route[2].public_key.public_bytes_raw()] + ) + self.assertIsInstance(processed, ProcessedFinalHopPacket) + recovered = cast( + ProcessedFinalHopPacket, processed + ).payload.recover_plain_playload() + self.assertEqual(msg, recovered) + + def test_max_message_size(self): + global_config, _, _ = init_mixnet_config(10, max_message_size=2000) + mix_path_length = global_config.max_mix_path_length + + packet1, _ = SphinxPacketBuilder.build( + self.random_bytes(1500), global_config, mix_path_length + ) + packet2, _ = SphinxPacketBuilder.build( + self.random_bytes(2000), global_config, mix_path_length + ) + self.assertEqual(len(packet1.bytes()), len(packet2.bytes())) + + msg = self.random_bytes(2001) + with self.assertRaises(ValueError): + _ = SphinxPacketBuilder.build(msg, global_config, mix_path_length) + + def test_max_mix_path_length(self): + global_config, _, _ = init_mixnet_config(10, max_mix_path_length=2) + msg = self.random_bytes(global_config.max_message_size) + + packet1, _ = SphinxPacketBuilder.build(msg, global_config, 1) + packet2, _ = SphinxPacketBuilder.build(msg, global_config, 2) + self.assertEqual(len(packet1.bytes()), len(packet2.bytes())) + + with self.assertRaises(ValueError): + _ = SphinxPacketBuilder.build(msg, global_config, 3) + + @staticmethod + def random_bytes(size: int) -> bytes: + assert size >= 0 + return bytes([randint(0, 255) for _ in range(size)]) diff --git a/mixnet/protocol/test_temporalmix.py b/mixnet/protocol/test_temporalmix.py new file mode 100644 index 0000000..c856ff8 --- /dev/null +++ b/mixnet/protocol/test_temporalmix.py @@ -0,0 +1,103 @@ +import random +from unittest import IsolatedAsyncioTestCase + +import framework.asyncio as asynciofw +from framework.framework import Queue +from protocol.temporalmix import ( + NoisyCoinFlippingQueue, + NonMixQueue, + PermutedCoinFlipppingQueue, + PureCoinFlipppingQueue, + PureRandomSamplingQueue, + TemporalMix, + TemporalMixConfig, + TemporalMixType, +) + + +class TestTemporalMix(IsolatedAsyncioTestCase): + async def test_queue_builder(self): + # Check if the queue builder generates the correct queue type + for mix_type in TemporalMixType: + await self.__test_queue_builder(mix_type) + + async def __test_queue_builder(self, mix_type: TemporalMixType): + queue: Queue[int] = TemporalMix.queue( + TemporalMixConfig(mix_type, 4, random.Random(0)), + asynciofw.Framework(), + -1, + ) + match mix_type: + case TemporalMixType.NONE: + self.assertIsInstance(queue, NonMixQueue) + case TemporalMixType.PURE_COIN_FLIPPING: + self.assertIsInstance(queue, PureCoinFlipppingQueue) + case TemporalMixType.PURE_RANDOM_SAMPLING: + self.assertIsInstance(queue, PureRandomSamplingQueue) + case TemporalMixType.PERMUTED_COIN_FLIPPING: + self.assertIsInstance(queue, PermutedCoinFlipppingQueue) + case TemporalMixType.NOISY_COIN_FLIPPING: + self.assertIsInstance(queue, NoisyCoinFlippingQueue) + case _: + self.fail(f"Unknown mix type: {mix_type}") + + async def test_non_mix_queue(self): + queue: Queue[int] = TemporalMix.queue( + TemporalMixConfig(TemporalMixType.NONE, 4, random.Random(0)), + asynciofw.Framework(), + -1, + ) + + # Check if queue is FIFO + await queue.put(0) + await queue.put(1) + self.assertEqual(0, await queue.get()) + self.assertEqual(1, await queue.get()) + + # Check if noise is generated when queue is empty + self.assertEqual(-1, await queue.get()) + + # FIFO again + await queue.put(2) + self.assertEqual(2, await queue.get()) + await queue.put(3) + self.assertEqual(3, await queue.get()) + + async def test_pure_coin_flipping_queue(self): + await self.__test_mix_queue(TemporalMixType.PURE_COIN_FLIPPING) + + async def test_pure_random_sampling(self): + await self.__test_mix_queue(TemporalMixType.PURE_RANDOM_SAMPLING) + + async def test_permuted_coin_flipping_queue(self): + await self.__test_mix_queue(TemporalMixType.PERMUTED_COIN_FLIPPING) + + async def test_noisy_coin_flipping_queue(self): + await self.__test_mix_queue(TemporalMixType.NOISY_COIN_FLIPPING) + + async def __test_mix_queue(self, mix_type: TemporalMixType): + queue: Queue[int] = TemporalMix.queue( + TemporalMixConfig(mix_type, 4, random.Random(0)), + asynciofw.Framework(), + -1, + ) + + # Check if noise is generated when queue is empty + self.assertEqual(-1, await queue.get()) + + # Put only 2 elements even though the min queue size is 4 + await queue.put(0) + await queue.put(1) + + # Wait until 2 elements are returned from the queue + waiting = {0, 1} + while len(waiting) > 0: + e = await queue.get() + if e in waiting: + waiting.remove(e) + else: + # Check if it's the noise + self.assertEqual(-1, e) + + # Check if noise is generated when there is no real message inserted + self.assertEqual(-1, await queue.get()) diff --git a/mixnet/protocol/test_utils.py b/mixnet/protocol/test_utils.py new file mode 100644 index 0000000..5bc3cd9 --- /dev/null +++ b/mixnet/protocol/test_utils.py @@ -0,0 +1,46 @@ +import random + +from cryptography.hazmat.primitives.asymmetric.x25519 import X25519PrivateKey + +from protocol.config import ( + GlobalConfig, + MixMembership, + NodeConfig, + NodeInfo, +) +from protocol.gossip import GossipConfig +from protocol.nomssip import TemporalMixConfig +from protocol.temporalmix import TemporalMixType + + +def init_mixnet_config( + num_nodes: int, + max_message_size: int = 512, + max_mix_path_length: int = 3, +) -> tuple[GlobalConfig, list[NodeConfig], dict[bytes, X25519PrivateKey]]: + gossip_config = GossipConfig(peering_degree=6) + node_configs = [ + NodeConfig( + X25519PrivateKey.generate(), + max_mix_path_length, + gossip_config, + TemporalMixConfig(TemporalMixType.PURE_COIN_FLIPPING, 3, random.Random()), + ) + for _ in range(num_nodes) + ] + global_config = GlobalConfig( + MixMembership( + [ + NodeInfo(node_config.private_key.public_key()) + for node_config in node_configs + ] + ), + transmission_rate_per_sec=3, + max_message_size=max_message_size, + max_mix_path_length=max_mix_path_length, + ) + key_map = { + node_config.private_key.public_key().public_bytes_raw(): node_config.private_key + for node_config in node_configs + } + return (global_config, node_configs, key_map) diff --git a/mixnet/requirements.txt b/mixnet/requirements.txt new file mode 100644 index 0000000..25753ba --- /dev/null +++ b/mixnet/requirements.txt @@ -0,0 +1,6 @@ +usim==0.4.4 +pysphinx==0.0.5 +dacite==1.8.1 +pandas==2.2.2 +matplotlib==3.9.1 +PyYAML==6.0.1 diff --git a/mixnet/sim/__init__.py b/mixnet/sim/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/mixnet/sim/config.py b/mixnet/sim/config.py new file mode 100644 index 0000000..b078d5c --- /dev/null +++ b/mixnet/sim/config.py @@ -0,0 +1,163 @@ +from __future__ import annotations + +import hashlib +import random +from dataclasses import dataclass + +import dacite +import yaml +from pysphinx.sphinx import X25519PrivateKey + +from protocol.config import NodeConfig +from protocol.gossip import GossipConfig +from protocol.temporalmix import TemporalMixConfig, TemporalMixType + + +@dataclass +class Config: + simulation: SimulationConfig + network: NetworkConfig + logic: LogicConfig + mix: MixConfig + + @classmethod + def load(cls, yaml_path: str) -> Config: + with open(yaml_path, "r") as f: + data = yaml.safe_load(f) + return dacite.from_dict( + data_class=Config, + data=data, + config=dacite.Config( + type_hooks={ + random.Random: seed_to_random, + TemporalMixType: str_to_temporal_mix_type, + }, + strict=True, + ), + ) + + def node_configs(self) -> list[NodeConfig]: + return [ + NodeConfig( + self.__gen_private_key(i), + self.mix.mix_path.random_length(), + self.network.gossip, + self.mix.temporal_mix, + ) + for i in range(self.network.num_nodes) + ] + + def __gen_private_key(self, node_idx: int) -> X25519PrivateKey: + return X25519PrivateKey.from_private_bytes( + hashlib.sha256(node_idx.to_bytes(4, "big")).digest()[:32] + ) + + +@dataclass +class SimulationConfig: + # Desired duration of the simulation in seconds + # Since the simulation uses discrete time steps, the actual duration may be longer or shorter. + duration_sec: int + # Show all plots that have been drawn during the simulation + show_plots: bool + + def __post_init__(self): + assert self.duration_sec > 0 + + +@dataclass +class NetworkConfig: + # Total number of nodes in the entire network. + num_nodes: int + latency: LatencyConfig + gossip: GossipConfig + topology: TopologyConfig + + def __post_init__(self): + assert self.num_nodes > 0 + + +@dataclass +class LatencyConfig: + # Minimum/maximum network latency between nodes in seconds. + # A constant latency will be chosen randomly for each connection within the range [min_latency_sec, max_latency_sec]. + min_latency_sec: float + max_latency_sec: float + # Seed for the random number generator used to determine the network latencies. + seed: random.Random + + def __post_init__(self): + assert 0 <= self.min_latency_sec <= self.max_latency_sec + assert self.seed is not None + + def random_latency(self) -> float: + # round to milliseconds to make analysis not too heavy + return round(self.seed.uniform(self.min_latency_sec, self.max_latency_sec), 3) + + +@dataclass +class TopologyConfig: + # Seed for the random number generator used to determine the network topology. + seed: random.Random + + def __post_init__(self): + assert self.seed is not None + + +@dataclass +class MixConfig: + # Global constant transmission rate of each connection in messages per second. + transmission_rate_per_sec: int + # Maximum size of a message in bytes that can be encapsulated in a single Sphinx packet. + max_message_size: int + mix_path: MixPathConfig + temporal_mix: TemporalMixConfig + + def __post_init__(self): + assert self.transmission_rate_per_sec > 0 + assert self.max_message_size > 0 + + +@dataclass +class MixPathConfig: + # Minimum number of mix nodes to be chosen for a Sphinx packet. + min_length: int + # Maximum number of mix nodes to be chosen for a Sphinx packet. + max_length: int + # Seed for the random number generator used to determine the mix path. + seed: random.Random + + def __post_init__(self): + assert 0 < self.min_length <= self.max_length + assert self.seed is not None + + def random_length(self) -> int: + return self.seed.randint(self.min_length, self.max_length) + + +@dataclass +class LogicConfig: + sender_lottery: LotteryConfig + + +@dataclass +class LotteryConfig: + # Interval between lottery draws in seconds. + interval_sec: float + # Probability of a node being selected as a sender in each lottery draw. + probability: float + # Seed for the random number generator used to determine the lottery winners. + seed: random.Random + + def __post_init__(self): + assert self.interval_sec > 0 + assert self.probability >= 0 + assert self.seed is not None + + +def seed_to_random(seed: int) -> random.Random: + return random.Random(seed) + + +def str_to_temporal_mix_type(val: str) -> TemporalMixType: + return TemporalMixType(val) diff --git a/mixnet/sim/connection.py b/mixnet/sim/connection.py new file mode 100644 index 0000000..6a8911c --- /dev/null +++ b/mixnet/sim/connection.py @@ -0,0 +1,139 @@ +import math +from collections import Counter +from typing import Awaitable + +import pandas +from typing_extensions import override + +from framework import Framework, Queue +from protocol.connection import SimplexConnection +from sim.config import LatencyConfig, NetworkConfig +from sim.state import NodeState + + +class MeteredRemoteSimplexConnection(SimplexConnection): + """ + A simplex connection implementation that simulates network latency and measures bandwidth usages. + """ + + def __init__( + self, + config: LatencyConfig, + framework: Framework, + meter_start_time: float, + ): + self.framework = framework + # A connection has a random constant latency + self.latency = config.random_latency() + # A queue of tuple(timestamp, msg) where a sender puts messages to be sent + self.send_queue: Queue[tuple[float, bytes]] = framework.queue() + # A task that reads messages from send_queue, and puts them to recv_queue. + # Before putting messages to recv_queue, the task simulates network latency according to the timestamp of each message. + self.relayer = framework.spawn(self.__run_relayer()) + # A queue where a receiver gets messages + self.recv_queue: Queue[bytes] = framework.queue() + # To measure bandwidth usages + self.meter_start_time = meter_start_time + self.send_meters: list[int] = [] + self.recv_meters: list[int] = [] + + async def send(self, data: bytes) -> None: + await self.send_queue.put((self.framework.now(), data)) + self.on_sending(data) + + async def recv(self) -> bytes: + return await self.recv_queue.get() + + async def __run_relayer(self): + """ + A task that reads messages from send_queue, and puts them to recv_queue. + Before putting messages to recv_queue, the task simulates network latency according to the timestamp of each message. + """ + while True: + sent_time, data = await self.send_queue.get() + # Simulate network latency + delay = self.latency - (self.framework.now() - sent_time) + if delay > 0: + await self.framework.sleep(delay) + + # Relay msg to the recv_queue. + # Update related statistics before msg is read from recv_queue by the receiver + # because the time at which enters the node is important when viewed from the outside. + self.on_receiving(data) + await self.recv_queue.put(data) + + def on_sending(self, data: bytes) -> None: + """ + Update statistics when sending a message + """ + self.__update_meter(self.send_meters, len(data)) + + def on_receiving(self, data: bytes) -> None: + """ + Update statistics when receiving a message + """ + self.__update_meter(self.recv_meters, len(data)) + + def __update_meter(self, meters: list[int], size: int): + """ + Accumulates the bandwidth usage in the current time slot (seconds). + """ + slot = math.floor(self.framework.now() - self.meter_start_time) + assert slot >= len(meters) - 1 + # Fill zeros for the empty time slots + meters.extend([0] * (slot - len(meters) + 1)) + meters[-1] += size + + def sending_bandwidths(self) -> pandas.Series: + """ + Returns the accumulated sending bandwidth usage over time + """ + return self.__bandwidths(self.send_meters) + + def receiving_bandwidths(self) -> pandas.Series: + """ + Returns the accumulated receiving bandwidth usage over time + """ + return self.__bandwidths(self.recv_meters) + + def __bandwidths(self, meters: list[int]) -> pandas.Series: + return pandas.Series(meters, name="bandwidth") + + +class ObservedMeteredRemoteSimplexConnection(MeteredRemoteSimplexConnection): + """ + An extension of MeteredRemoteSimplexConnection that is observed by passive observer. + The observer monitors the node states of the sender and receiver and message sizes. + """ + + def __init__( + self, + config: LatencyConfig, + framework: Framework, + meter_start_time: float, + send_node_states: list[NodeState], + recv_node_states: list[NodeState], + ): + super().__init__(config, framework, meter_start_time) + + # To measure node states over time + self.send_node_states = send_node_states + self.recv_node_states = recv_node_states + # To measure the size of messages sent via this connection + self.msg_sizes: Counter[int] = Counter() + + @override + def on_sending(self, data: bytes) -> None: + super().on_sending(data) + self.__update_node_state(self.send_node_states, NodeState.SENDING) + self.msg_sizes.update([len(data)]) + + @override + def on_receiving(self, data: bytes) -> None: + super().on_receiving(data) + self.__update_node_state(self.recv_node_states, NodeState.RECEIVING) + + def __update_node_state(self, node_states: list[NodeState], state: NodeState): + # The time unit of node states is milliseconds + ms = math.floor(self.framework.now() * 1000) + node_states[ms] = state diff --git a/mixnet/sim/message.py b/mixnet/sim/message.py new file mode 100644 index 0000000..cfc939a --- /dev/null +++ b/mixnet/sim/message.py @@ -0,0 +1,42 @@ +import pickle +from dataclasses import dataclass +from typing import Self + + +@dataclass +class Message: + """ + A message structure for simulation, which will be sent through mix nodes + and eventually broadcasted to all nodes in the network. + + The `id` must ensure the uniqueness of the message. + """ + + created_at: float + id: int + body: bytes + + def __bytes__(self): + return pickle.dumps(self) + + @classmethod + def from_bytes(cls, data: bytes) -> Self: + return pickle.loads(data) + + def __hash__(self) -> int: + return self.id + + +class UniqueMessageBuilder: + """ + Builds a unique message with an incremental ID, + assuming that the simulation is run in a single thread. + """ + + def __init__(self): + self.next_id = 0 + + def next(self, created_at: float, body: bytes) -> Message: + msg = Message(created_at, self.next_id, body) + self.next_id += 1 + return msg diff --git a/mixnet/sim/simulation.py b/mixnet/sim/simulation.py new file mode 100644 index 0000000..84453c4 --- /dev/null +++ b/mixnet/sim/simulation.py @@ -0,0 +1,198 @@ +from dataclasses import asdict, dataclass +from pprint import pprint +from typing import Self + +import usim +from matplotlib import pyplot + +import framework.usim as usimfw +from framework import Framework +from protocol.config import GlobalConfig, MixMembership, NodeInfo +from protocol.node import Node, PeeringDegreeReached +from sim.config import Config +from sim.connection import ( + MeteredRemoteSimplexConnection, + ObservedMeteredRemoteSimplexConnection, +) +from sim.message import Message, UniqueMessageBuilder +from sim.state import NodeState, NodeStateTable +from sim.stats import ConnectionStats, DisseminationTime +from sim.topology import build_full_random_topology + + +class Simulation: + """ + Manages the entire cycle of simulation: initialization, running, and analysis. + """ + + def __init__(self, config: Config): + self.config = config + self.msg_builder = UniqueMessageBuilder() + self.dissemination_time = DisseminationTime(self.config.network.num_nodes) + + async def run(self): + # Run the simulation + conn_stats, node_state_table = await self.__run() + # Analyze the dissemination times + self.dissemination_time.analyze() + # Analyze the simulation results + conn_stats.analyze() + node_state_table.analyze() + # Show plots + if self.config.simulation.show_plots: + pyplot.show() + + async def __run(self) -> tuple[ConnectionStats, NodeStateTable]: + # Initialize analysis tools + node_state_table = NodeStateTable( + self.config.network.num_nodes, self.config.simulation.duration_sec + ) + conn_stats = ConnectionStats() + + # Create a μSim scope and run the simulation + async with usim.until(usim.time + self.config.simulation.duration_sec) as scope: + self.framework = usimfw.Framework(scope) + nodes = self.__init_nodes() + self.__connect_nodes(nodes, node_state_table, conn_stats) + for i, node in enumerate(nodes): + print(f"Spawning node-{i} with {len(node.nomssip.conns)} conns") + self.framework.spawn(self.__run_node_logic(node)) + + # Return analysis tools once the μSim scope is done + return conn_stats, node_state_table + + def __init_nodes(self) -> list[Node]: + # Initialize node/global configurations + node_configs = self.config.node_configs() + global_config = GlobalConfig( + MixMembership( + [ + NodeInfo(node_config.private_key.public_key()) + for node_config in node_configs + ], + self.config.mix.mix_path.seed, + ), + self.config.mix.transmission_rate_per_sec, + self.config.mix.max_message_size, + self.config.mix.mix_path.max_length, + ) + + # Initialize/return Node instances + return [ + Node( + self.framework, + node_config, + global_config, + self.__process_broadcasted_msg, + self.__process_recovered_msg, + ) + for node_config in node_configs + ] + + def __connect_nodes( + self, + nodes: list[Node], + node_state_table: NodeStateTable, + conn_stats: ConnectionStats, + ): + topology = build_full_random_topology( + self.config.network.topology.seed, + len(nodes), + self.config.network.gossip.peering_degree, + ) + print("Topology:") + pprint(topology) + + meter_start_time = self.framework.now() + # Sort the topology by node index for the connection RULE defined below. + for node_idx, peer_indices in sorted(topology.items()): + for peer_idx in peer_indices: + # Since the topology is undirected, we only need to connect the two nodes once. + # RULE: the node with the smaller index establishes the connection. + assert node_idx != peer_idx + if node_idx > peer_idx: + continue + + node = nodes[node_idx] + peer = nodes[peer_idx] + node_states = node_state_table[node_idx] + peer_states = node_state_table[peer_idx] + + # Connect the node and peer for Nomos Gossip + inbound_conn, outbound_conn = ( + self.__create_observed_conn( + meter_start_time, peer_states, node_states + ), + self.__create_observed_conn( + meter_start_time, node_states, peer_states + ), + ) + node.connect_mix(peer, inbound_conn, outbound_conn) + # Register the connections to the connection statistics + conn_stats.register(node, inbound_conn, outbound_conn) + conn_stats.register(peer, outbound_conn, inbound_conn) + + # Connect the node and peer for broadcasting. + node.connect_broadcast( + peer, + self.__create_conn(meter_start_time), + self.__create_conn(meter_start_time), + ) + + def __create_observed_conn( + self, + meter_start_time: float, + sender_states: list[NodeState], + receiver_states: list[NodeState], + ) -> ObservedMeteredRemoteSimplexConnection: + return ObservedMeteredRemoteSimplexConnection( + self.config.network.latency, + self.framework, + meter_start_time, + sender_states, + receiver_states, + ) + + def __create_conn( + self, + meter_start_time: float, + ) -> MeteredRemoteSimplexConnection: + return MeteredRemoteSimplexConnection( + self.config.network.latency, + self.framework, + meter_start_time, + ) + + async def __run_node_logic(self, node: Node): + """ + Runs the lottery periodically to check if the node is selected to send a block. + If the node is selected, creates a block and sends it through mix nodes. + """ + lottery_config = self.config.logic.sender_lottery + while True: + await self.framework.sleep(lottery_config.interval_sec) + if lottery_config.seed.random() < lottery_config.probability: + msg = self.msg_builder.next(self.framework.now(), b"selected block") + await node.send_message(bytes(msg)) + + async def __process_broadcasted_msg(self, msg: bytes): + """ + Process a broadcasted message originated from the last mix. + """ + message = Message.from_bytes(msg) + elapsed = self.framework.now() - message.created_at + self.dissemination_time.add_broadcasted_msg(message, elapsed) + + async def __process_recovered_msg(self, msg: bytes) -> bytes: + """ + Process a message fully recovered by the last mix + and returns a new message to be broadcasted. + """ + message = Message.from_bytes(msg) + elapsed = self.framework.now() - message.created_at + self.dissemination_time.add_mix_propagation_time(elapsed) + + # Update the timestamp and return the message to be broadcasted, + # so that the broadcast dissemination time can be calculated from now. + message.created_at = self.framework.now() + return bytes(message) diff --git a/mixnet/sim/state.py b/mixnet/sim/state.py new file mode 100644 index 0000000..fbbd033 --- /dev/null +++ b/mixnet/sim/state.py @@ -0,0 +1,77 @@ +from datetime import datetime +from enum import Enum + +import matplotlib.pyplot as plt +import pandas + + +class NodeState(Enum): + """ + A state of node at a certain time. + For now, we assume that the node cannot send and receive messages at the same time for simplicity. + """ + + SENDING = -1 + IDLE = 0 + RECEIVING = 1 + + +class NodeStateTable: + def __init__(self, num_nodes: int, duration_sec: int): + # Create a table to store the state of each node at each millisecond + self.__table = [ + [NodeState.IDLE] * (duration_sec * 1000) for _ in range(num_nodes) + ] + + def __getitem__(self, idx: int) -> list[NodeState]: + return self.__table[idx] + + def analyze(self): + df = pandas.DataFrame(self.__table).transpose() + df.columns = [f"Node-{i}" for i in range(len(self.__table))] + # Convert NodeState enum to their integer values + df = df.map(lambda state: state.value) + print("==========================================") + print("Node States of All Nodes over Time") + print(", ".join(f"{state.name}:{state.value}" for state in NodeState)) + print("==========================================") + print(f"{df}\n") + + csv_path = f"all_node_states_{datetime.now().isoformat(timespec="seconds")}.csv" + df.to_csv(csv_path) + print(f"Saved DataFrame to {csv_path}\n") + + # Count/print the number of each state for each node + # because the df is usually too big to print + state_counts = df.apply(pandas.Series.value_counts).fillna(0) + print("State Counts per Node:") + print(f"{state_counts}\n") + + # Draw a dot plot + plt.figure(figsize=(15, 8)) + for node in df.columns: + times = df.index + states = df[node] + sending_times = times[states == NodeState.SENDING.value] + receiving_times = times[states == NodeState.RECEIVING.value] + plt.scatter( + sending_times, + [node] * len(sending_times), + color="red", + marker="o", + s=10, + label="SENDING" if node == df.columns[0] else "", + ) + plt.scatter( + receiving_times, + [node] * len(receiving_times), + color="blue", + marker="x", + s=10, + label="RECEIVING" if node == df.columns[0] else "", + ) + plt.xlabel("Time") + plt.ylabel("Node") + plt.title("Node States Over Time") + plt.legend(loc="upper right") + plt.draw() diff --git a/mixnet/sim/stats.py b/mixnet/sim/stats.py new file mode 100644 index 0000000..d8a6efe --- /dev/null +++ b/mixnet/sim/stats.py @@ -0,0 +1,152 @@ +from collections import Counter, defaultdict + +import matplotlib.pyplot as plt +import numpy +import pandas +from matplotlib.axes import Axes + +from protocol.node import Node +from sim.connection import ObservedMeteredRemoteSimplexConnection +from sim.message import Message + +# A map of nodes to their inbound/outbound connections +NodeConnectionsMap = dict[ + Node, + tuple[ + list[ObservedMeteredRemoteSimplexConnection], + list[ObservedMeteredRemoteSimplexConnection], + ], +] + + +class ConnectionStats: + def __init__(self): + self.conns_per_node: NodeConnectionsMap = defaultdict(lambda: ([], [])) + + def register( + self, + node: Node, + inbound_conn: ObservedMeteredRemoteSimplexConnection, + outbound_conn: ObservedMeteredRemoteSimplexConnection, + ): + self.conns_per_node[node][0].append(inbound_conn) + self.conns_per_node[node][1].append(outbound_conn) + + def analyze(self): + self.__message_sizes() + self.__bandwidths_per_conn() + self.__bandwidths_per_node() + + def __message_sizes(self): + """ + Analyzes all message sizes sent across all connections of all nodes. + """ + sizes: Counter[int] = Counter() + for _, (_, outbound_conns) in self.conns_per_node.items(): + for conn in outbound_conns: + sizes.update(conn.msg_sizes) + + df = pandas.DataFrame.from_dict(sizes, orient="index").reset_index() + df.columns = ["msg_size", "count"] + print("==========================================") + print("Message Size Distribution") + print("==========================================") + print(f"{df}\n") + + def __bandwidths_per_conn(self): + """ + Analyzes the bandwidth consumed by each simplex connection. + """ + plt.plot(figsize=(12, 6)) + + for _, (_, outbound_conns) in self.conns_per_node.items(): + for conn in outbound_conns: + sending_bandwidths = conn.sending_bandwidths().map(lambda x: x / 1024) + plt.plot(sending_bandwidths.index, sending_bandwidths) + + plt.title("Unidirectional Bandwidths per Connection") + plt.xlabel("Time (s)") + plt.ylabel("Bandwidth (KiB/s)") + plt.ylim(bottom=0) + plt.grid(True) + plt.tight_layout() + plt.draw() + + def __bandwidths_per_node(self): + """ + Analyzes the inbound/outbound bandwidths consumed by each node (sum of all its connections). + """ + _, axs = plt.subplots(nrows=2, ncols=1, figsize=(12, 6)) + assert isinstance(axs, numpy.ndarray) + + for i, (_, (inbound_conns, outbound_conns)) in enumerate( + self.conns_per_node.items() + ): + inbound_bandwidths = ( + pandas.concat( + [conn.receiving_bandwidths() for conn in inbound_conns], axis=1 + ) + .sum(axis=1) + .map(lambda x: x / 1024) + ) + outbound_bandwidths = ( + pandas.concat( + [conn.sending_bandwidths() for conn in outbound_conns], axis=1 + ) + .sum(axis=1) + .map(lambda x: x / 1024) + ) + axs[0].plot(inbound_bandwidths.index, inbound_bandwidths, label=f"Node-{i}") + axs[1].plot( + outbound_bandwidths.index, outbound_bandwidths, label=f"Node-{i}" + ) + + axs[0].set_title("Inbound Bandwidths per Node") + axs[0].set_xlabel("Time (s)") + axs[0].set_ylabel("Bandwidth (KiB/s)") + axs[0].legend() + axs[0].set_ylim(bottom=0) + axs[0].grid(True) + + axs[1].set_title("Outbound Bandwidths per Node") + axs[1].set_xlabel("Time (s)") + axs[1].set_ylabel("Bandwidth (KiB/s)") + axs[1].legend() + axs[1].set_ylim(bottom=0) + axs[1].grid(True) + + plt.tight_layout() + plt.draw() + + +class DisseminationTime: + def __init__(self, num_nodes: int): + # A collection of time taken for a message to propagate through all mix nodes in its mix route + self.mix_propagation_times: list[float] = [] + # A collection of time taken for a message to be broadcasted from the last mix to all nodes in the network + self.broadcast_dissemination_times: list[float] = [] + # Data structures to check if a message has been broadcasted to all nodes + self.broadcast_status: Counter[Message] = Counter() + self.num_nodes: int = num_nodes + + def add_mix_propagation_time(self, elapsed: float): + self.mix_propagation_times.append(elapsed) + + def add_broadcasted_msg(self, msg: Message, elapsed: float): + assert self.broadcast_status[msg] < self.num_nodes + self.broadcast_status.update([msg]) + if self.broadcast_status[msg] == self.num_nodes: + self.broadcast_dissemination_times.append(elapsed) + + def analyze(self): + print("==========================================") + print("Message Dissemination Time") + print("==========================================") + print("[Mix Propagation Times]") + mix_propagation_times = pandas.Series(self.mix_propagation_times) + print(mix_propagation_times.describe()) + print("") + print("[Broadcast Dissemination Times]") + broadcast_travel_times = pandas.Series(self.broadcast_dissemination_times) + print(broadcast_travel_times.describe()) + print("") diff --git a/mixnet/sim/test_connection.py b/mixnet/sim/test_connection.py new file mode 100644 index 0000000..2b59807 --- /dev/null +++ b/mixnet/sim/test_connection.py @@ -0,0 +1,104 @@ +import math +import random +from unittest import IsolatedAsyncioTestCase + +import usim + +import framework.usim as usimfw +from protocol.connection import LocalSimplexConnection +from protocol.node import Node +from protocol.test_utils import ( + init_mixnet_config, +) +from sim.config import LatencyConfig, NetworkConfig +from sim.connection import ( + MeteredRemoteSimplexConnection, + ObservedMeteredRemoteSimplexConnection, +) +from sim.state import NodeState, NodeStateTable + + +class TestMeteredRemoteSimplexConnection(IsolatedAsyncioTestCase): + async def test_latency(self): + usim.run(self.__test_latency()) + + async def __test_latency(self): + async with usim.Scope() as scope: + framework = usimfw.Framework(scope) + node_state_table = NodeStateTable(num_nodes=2, duration_sec=3) + conn = MeteredRemoteSimplexConnection( + LatencyConfig( + min_latency_sec=0, + max_latency_sec=1, + seed=random.Random(), + ), + framework, + framework.now(), + ) + + # Send two messages without delay + sent_time = framework.now() + await conn.send(b"hello") + await conn.send(b"world") + + # Receive two messages and check if the network latency was simulated well. + # There should be no delay between the two messages because they were sent without delay. + self.assertEqual(b"hello", await conn.recv()) + self.assertEqual(conn.latency, framework.now() - sent_time) + self.assertEqual(b"world", await conn.recv()) + self.assertEqual(conn.latency, framework.now() - sent_time) + + +class TestObservedMeteredRemoteSimplexConnection(IsolatedAsyncioTestCase): + async def test_node_state(self): + usim.run(self.__test_node_state()) + + async def __test_node_state(self): + async with usim.Scope() as scope: + framework = usimfw.Framework(scope) + node_state_table = NodeStateTable(num_nodes=2, duration_sec=3) + meter_start_time = framework.now() + conn = ObservedMeteredRemoteSimplexConnection( + LatencyConfig( + min_latency_sec=0, + max_latency_sec=1, + seed=random.Random(), + ), + framework, + meter_start_time, + node_state_table[0], + node_state_table[1], + ) + + # Sleep and send a message + await framework.sleep(1) + sent_time = framework.now() + await conn.send(b"hello") + + # Receive the message. It should be received after the latency. + self.assertEqual(b"hello", await conn.recv()) + recv_time = framework.now() + + # Check if the sender node state is SENDING at the sent time + timeslot = math.floor((sent_time - meter_start_time) * 1000) + self.assertEqual( + NodeState.SENDING, + node_state_table[0][timeslot], + ) + # Ensure that the sender node states in other time slots are IDLE + states = set() + states.update(node_state_table[0][:timeslot]) + states.update(node_state_table[0][timeslot + 1 :]) + self.assertEqual(set([NodeState.IDLE]), states) + + # Check if the receiver node state is RECEIVING at the received time + timeslot = math.floor((recv_time - meter_start_time) * 1000) + self.assertEqual( + NodeState.RECEIVING, + node_state_table[1][timeslot], + ) + # Ensure that the receiver node states in other time slots are IDLE + states = set() + states.update(node_state_table[1][:timeslot]) + states.update(node_state_table[1][timeslot + 1 :]) + self.assertEqual(set([NodeState.IDLE]), states) diff --git a/mixnet/sim/test_message.py b/mixnet/sim/test_message.py new file mode 100644 index 0000000..3eddae8 --- /dev/null +++ b/mixnet/sim/test_message.py @@ -0,0 +1,23 @@ +import time +from unittest import TestCase + +from sim.message import Message, UniqueMessageBuilder + + +class TestMessage(TestCase): + def test_message_serde(self): + msg = Message(time.time(), 10, b"hello") + serialized = bytes(msg) + deserialized = Message.from_bytes(serialized) + self.assertEqual(msg, deserialized) + + +class TestUniqueMessageBuilder(TestCase): + def test_uniqueness(self): + builder = UniqueMessageBuilder() + msg1 = builder.next(time.time(), b"hello") + msg2 = builder.next(time.time(), b"hello") + self.assertEqual(0, msg1.id) + self.assertEqual(1, msg2.id) + self.assertNotEqual(msg1, msg2) + self.assertNotEqual(hash(msg1), hash(msg2)) diff --git a/mixnet/sim/test_topology.py b/mixnet/sim/test_topology.py new file mode 100644 index 0000000..0a4e3cd --- /dev/null +++ b/mixnet/sim/test_topology.py @@ -0,0 +1,20 @@ +import random +from unittest import TestCase + +from sim.topology import are_all_nodes_connected, build_full_random_topology + + +class TestTopology(TestCase): + def test_full_random(self): + num_nodes = 100 + peering_degree = 6 + topology = build_full_random_topology( + random.Random(0), num_nodes, peering_degree + ) + self.assertEqual(num_nodes, len(topology)) + self.assertTrue(are_all_nodes_connected(topology)) + for node, peers in topology.items(): + self.assertTrue(0 < len(peers) <= peering_degree) + # Check if nodes are interconnected + for peer in peers: + self.assertIn(node, topology[peer]) diff --git a/mixnet/sim/topology.py b/mixnet/sim/topology.py new file mode 100644 index 0000000..a1cbc13 --- /dev/null +++ b/mixnet/sim/topology.py @@ -0,0 +1,58 @@ +import random +from collections import defaultdict + +from protocol.node import Node + +Topology = dict[int, set[int]] + + +def build_full_random_topology( + rng: random.Random, num_nodes: int, peering_degree: int +) -> Topology: + """ + Generate a random undirected topology until all nodes are connected. + We don't implement any artificial tool to ensure the connectivity of the topology. + Instead, we regenerate a topology in a fully randomized way until all nodes are connected. + """ + while True: + topology: Topology = defaultdict(set[int]) + nodes = list(range(num_nodes)) + for node in nodes: + # Filter nodes that can be connected to the current node. + others = [] + for other in nodes[:node] + nodes[node + 1 :]: + # Check if the other node is not already connected to the current node + # and the other node has not reached the peering degree. + if ( + other not in topology[node] + and len(topology[other]) < peering_degree + ): + others.append(other) + # How many more connections the current node needs + n_needs = peering_degree - len(topology[node]) + # Sample peers as many as possible + peers = rng.sample(others, k=min(n_needs, len(others))) + # Connect the current node to the peers + topology[node].update(peers) + # Connect the peers to the current node, since the topology is undirected + for peer in peers: + topology[peer].update([node]) + + if are_all_nodes_connected(topology): + return topology + + +def are_all_nodes_connected(topology: Topology) -> bool: + visited = set() + + def dfs(topology: Topology, node: int) -> None: + if node in visited: + return + visited.add(node) + for peer in topology[node]: + dfs(topology, peer) + + # Start DFS from the first node + dfs(topology, next(iter(topology))) + + return len(visited) == len(topology)