diff --git a/mixnet/node.py b/mixnet/node.py index 3fc5c7a..a91723c 100644 --- a/mixnet/node.py +++ b/mixnet/node.py @@ -104,6 +104,7 @@ class MixGossipChannel: if isinstance(elem, bytes): assert elem == build_noise_packet() # Drop packet + print("dropping noise") continue elif isinstance(elem, SphinxPacket): net_packet = await self.handler(elem) @@ -132,6 +133,7 @@ class MixOutboundConnection: # TODO: time mixing if self.queue.empty(): elem = build_noise_packet() + print("generating noise") else: elem = self.queue.get_nowait() await self.conn.put(elem) diff --git a/mixnet/sim/__init__.py b/mixnet/sim/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/mixnet/sim/config.py b/mixnet/sim/config.py new file mode 100644 index 0000000..a751b52 --- /dev/null +++ b/mixnet/sim/config.py @@ -0,0 +1,65 @@ +from __future__ import annotations + +from dataclasses import dataclass + +import dacite +import yaml +from pysphinx.sphinx import X25519PrivateKey +from simpy.core import SimTime + +from mixnet.config import NodeConfig + + +@dataclass +class Config: + simulation: SimulationConfig + mixnet: MixnetConfig + + @classmethod + def load(cls, yaml_path: str) -> Config: + with open(yaml_path, "r") as f: + data = yaml.safe_load(f) + config = dacite.from_dict(data_class=Config, data=data) + + # Validations + config.simulation.validate() + config.mixnet.validate() + + return config + + def description(self): + return f"{self.simulation.description()}\n" f"{self.mixnet.description()}" + + +@dataclass +class SimulationConfig: + running_time: SimTime + + def validate(self): + # SimTime supports float but better to use int for time accuracy + assert isinstance(self.running_time, int) and self.running_time > 0 + + def description(self): + return f"running_time: {self.running_time}" + + +@dataclass +class MixnetConfig: + num_nodes: int + transmission_rate_per_sec: int + + def validate(self): + assert self.num_nodes > 0 + assert self.transmission_rate_per_sec > 0 + + def description(self): + return ( + f"num_nodes: {self.num_nodes}\n" + f"transmission_rate_per_sec: {self.transmission_rate_per_sec}" + ) + + def node_configs(self) -> list[NodeConfig]: + return [ + NodeConfig(X25519PrivateKey.generate(), self.transmission_rate_per_sec) + for _ in range(self.num_nodes) + ] diff --git a/mixnet/sim/config.yaml b/mixnet/sim/config.yaml new file mode 100644 index 0000000..87b52e6 --- /dev/null +++ b/mixnet/sim/config.yaml @@ -0,0 +1,7 @@ +simulation: + # The simulation uses a virtual time (integer). Please see README for more details. + running_time: 300 + +mixnet: + num_nodes: 100 + transmission_rate_per_sec: 10 diff --git a/mixnet/sim/main.py b/mixnet/sim/main.py new file mode 100644 index 0000000..a9384d7 --- /dev/null +++ b/mixnet/sim/main.py @@ -0,0 +1,20 @@ +import argparse + +from mixnet.sim.config import Config +from mixnet.sim.simulation import Simulation + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description="Run mixnet simulation", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + parser.add_argument( + "--config", type=str, required=True, help="Configuration file path" + ) + args = parser.parse_args() + + config = Config.load(args.config) + sim = Simulation(config) + sim.run() + + print("Simulation complete!") diff --git a/mixnet/sim/simulation.py b/mixnet/sim/simulation.py new file mode 100644 index 0000000..ccdca7c --- /dev/null +++ b/mixnet/sim/simulation.py @@ -0,0 +1,26 @@ +import random + +import simpy + +from mixnet.config import MixMembership, NodeInfo +from mixnet.node import Node +from mixnet.sim.config import Config + + +class Simulation: + def __init__(self, config: Config): + random.seed() + self.config = config + self.env = simpy.Environment() + + # Initialize mixnet nodes and establish connections + node_configs = self.config.mixnet.node_configs() + membership = MixMembership( + [NodeInfo(node_config.private_key) for node_config in node_configs] + ) + nodes = [Node(node_config, membership) for node_config in node_configs] + for i, node in enumerate(nodes): + node.connect(nodes[(i + 1) % len(nodes)]) + + def run(self): + self.env.run(until=self.config.simulation.running_time)