From 515bc2c50a808bd126a6be826854a0469f29b7e0 Mon Sep 17 00:00:00 2001 From: Youngjoon Lee <5462944+youngjoon-lee@users.noreply.github.com> Date: Wed, 3 Jul 2024 23:29:26 +0900 Subject: [PATCH] add simulation --- mixnet/config.py | 2 +- mixnet/connection.py | 23 ++++++++++++ mixnet/node.py | 34 ++++++++++-------- mixnet/sim/__init__.py | 0 mixnet/sim/config.py | 77 ++++++++++++++++++++++++++++++++++++++++ mixnet/sim/config.yaml | 16 +++++++++ mixnet/sim/connection.py | 68 +++++++++++++++++++++++++++++++++++ mixnet/sim/main.py | 21 +++++++++++ mixnet/sim/simulation.py | 69 +++++++++++++++++++++++++++++++++++ mixnet/sim/stats.py | 52 +++++++++++++++++++++++++++ 10 files changed, 347 insertions(+), 15 deletions(-) create mode 100644 mixnet/connection.py create mode 100644 mixnet/sim/__init__.py create mode 100644 mixnet/sim/config.py create mode 100644 mixnet/sim/config.yaml create mode 100644 mixnet/sim/connection.py create mode 100644 mixnet/sim/main.py create mode 100644 mixnet/sim/simulation.py create mode 100644 mixnet/sim/stats.py diff --git a/mixnet/config.py b/mixnet/config.py index 696a4ad..140a913 100644 --- a/mixnet/config.py +++ b/mixnet/config.py @@ -14,7 +14,7 @@ from pysphinx.sphinx import Node as SphinxNode @dataclass class GlobalConfig: membership: MixMembership - transmission_rate_per_sec: int # Global Transmission Rate + transmission_rate_per_sec: float # Global Transmission Rate # TODO: use this to make the size of Sphinx packet constant max_mix_path_length: int diff --git a/mixnet/connection.py b/mixnet/connection.py new file mode 100644 index 0000000..bf39cbc --- /dev/null +++ b/mixnet/connection.py @@ -0,0 +1,23 @@ +import abc +import asyncio + + +class SimplexConnection(abc.ABC): + @abc.abstractmethod + async def send(self, data: bytes) -> None: + pass + + @abc.abstractmethod + async def recv(self) -> bytes: + pass + + +class LocalSimplexConnection(SimplexConnection): + def __init__(self): + self.queue = asyncio.Queue() + + async def send(self, data: bytes) -> None: + await self.queue.put(data) + + async def recv(self) -> bytes: + return await self.queue.get() diff --git a/mixnet/node.py b/mixnet/node.py index 994a41c..b4949f0 100644 --- a/mixnet/node.py +++ b/mixnet/node.py @@ -14,10 +14,10 @@ from pysphinx.sphinx import ( ) from mixnet.config import GlobalConfig, NodeConfig +from mixnet.connection import LocalSimplexConnection, SimplexConnection from mixnet.packet import Fragment, MessageFlag, MessageReconstructor, PacketBuilder NetworkPacketQueue: TypeAlias = asyncio.Queue[bytes] -Connection: TypeAlias = NetworkPacketQueue BroadcastChannel: TypeAlias = asyncio.Queue[bytes] @@ -58,14 +58,19 @@ class Node: if msg_with_flag is not None: flag, msg = PacketBuilder.parse_msg_and_flag(msg_with_flag) if flag == MessageFlag.MESSAGE_FLAG_REAL: + print(f"Broadcasting message finally: {msg}") await self.broadcast_channel.put(msg) - def connect(self, peer: Node): - inbound_conn, outbound_conn = asyncio.Queue(), asyncio.Queue() + def connect( + self, + peer: Node, + inbound_conn: SimplexConnection = LocalSimplexConnection(), + outbound_conn: SimplexConnection = LocalSimplexConnection(), + ): self.mixgossip_channel.add_conn( DuplexConnection( inbound_conn, - MixOutboundConnection( + MixSimplexConnection( outbound_conn, self.global_config.transmission_rate_per_sec ), ) @@ -73,13 +78,14 @@ class Node: peer.mixgossip_channel.add_conn( DuplexConnection( outbound_conn, - MixOutboundConnection( + MixSimplexConnection( inbound_conn, self.global_config.transmission_rate_per_sec ), ) ) async def send_message(self, msg: bytes): + print(f"Sending message: {msg}") for packet, _ in PacketBuilder.build_real_packets( msg, self.global_config.membership ): @@ -145,26 +151,26 @@ class MixGossipChannel: class DuplexConnection: - inbound: Connection - outbound: MixOutboundConnection + inbound: SimplexConnection + outbound: MixSimplexConnection - def __init__(self, inbound: Connection, outbound: MixOutboundConnection): + def __init__(self, inbound: SimplexConnection, outbound: MixSimplexConnection): self.inbound = inbound self.outbound = outbound async def recv(self) -> bytes: - return await self.inbound.get() + return await self.inbound.recv() async def send(self, packet: bytes): await self.outbound.send(packet) -class MixOutboundConnection: +class MixSimplexConnection: queue: NetworkPacketQueue - conn: Connection - transmission_rate_per_sec: int + conn: SimplexConnection + transmission_rate_per_sec: float - def __init__(self, conn: Connection, transmission_rate_per_sec: int): + def __init__(self, conn: SimplexConnection, transmission_rate_per_sec: float): self.queue = asyncio.Queue() self.conn = conn self.transmission_rate_per_sec = transmission_rate_per_sec @@ -178,7 +184,7 @@ class MixOutboundConnection: elem = build_noise_packet() else: elem = self.queue.get_nowait() - await self.conn.put(elem) + await self.conn.send(elem) async def send(self, elem: bytes): await self.queue.put(elem) diff --git a/mixnet/sim/__init__.py b/mixnet/sim/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/mixnet/sim/config.py b/mixnet/sim/config.py new file mode 100644 index 0000000..e38f382 --- /dev/null +++ b/mixnet/sim/config.py @@ -0,0 +1,77 @@ +from __future__ import annotations + +from dataclasses import dataclass + +import dacite +import yaml +from pysphinx.sphinx import X25519PrivateKey + +from mixnet.config import NodeConfig + + +@dataclass +class Config: + simulation: SimulationConfig + logic: LogicConfig + mixnet: MixnetConfig + + @classmethod + def load(cls, yaml_path: str) -> Config: + with open(yaml_path, "r") as f: + data = yaml.safe_load(f) + config = dacite.from_dict(data_class=Config, data=data) + + # Validations + config.simulation.validate() + config.logic.validate() + config.mixnet.validate() + + return config + + +@dataclass +class SimulationConfig: + time_scale: float + duration_sec: int + net_latency_sec: float + meter_interval_sec: float + + def validate(self): + assert self.time_scale > 0 + assert self.duration_sec > 0 + assert self.net_latency_sec > 0 + assert self.meter_interval_sec > 0 + + +@dataclass +class LogicConfig: + lottery_interval_sec: float + sender_prob: float + + def validate(self): + assert self.lottery_interval_sec > 0 + assert self.sender_prob > 0 + + +@dataclass +class MixnetConfig: + num_nodes: int + transmission_rate_per_sec: int + peering_degree: int + max_mix_path_length: int + + def validate(self): + assert self.num_nodes > 0 + assert self.transmission_rate_per_sec > 0 + assert self.peering_degree > 0 + assert self.max_mix_path_length > 0 + + def node_configs(self) -> list[NodeConfig]: + return [ + NodeConfig( + X25519PrivateKey.generate(), + self.peering_degree, + self.transmission_rate_per_sec, + ) + for _ in range(self.num_nodes) + ] diff --git a/mixnet/sim/config.yaml b/mixnet/sim/config.yaml new file mode 100644 index 0000000..79e0977 --- /dev/null +++ b/mixnet/sim/config.yaml @@ -0,0 +1,16 @@ +simulation: + time_scale: 0.001 + duration_sec: 10000 + net_latency_sec: 0.01 + meter_interval_sec: 1 + + +logic: + lottery_interval_sec: 1 + sender_prob: 0.01 + +mixnet: + num_nodes: 5 + transmission_rate_per_sec: 10 + peering_degree: 6 + max_mix_path_length: 3 diff --git a/mixnet/sim/connection.py b/mixnet/sim/connection.py new file mode 100644 index 0000000..e2bcbab --- /dev/null +++ b/mixnet/sim/connection.py @@ -0,0 +1,68 @@ +import asyncio +import math +import time + +import pandas + +from mixnet.connection import SimplexConnection + + +class MeteredRemoteSimplexConnection(SimplexConnection): + latency: float + meter_interval: float + outputs: asyncio.Queue + conn: asyncio.Queue + inputs: asyncio.Queue + output_task: asyncio.Task + output_meters: list[int] + input_task: asyncio.Task + input_meters: list[int] + + def __init__(self, latency: float, meter_interval: float): + self.latency = latency + self.meter_interval = meter_interval + self.outputs = asyncio.Queue() + self.conn = asyncio.Queue() + self.inputs = asyncio.Queue() + self.output_meters = [] + self.output_task = asyncio.create_task(self.__run_output_task()) + self.input_meters = [] + self.input_task = asyncio.create_task(self.__run_input_task()) + + async def send(self, data: bytes) -> None: + await self.outputs.put(data) + + async def recv(self) -> bytes: + return await self.inputs.get() + + async def __run_output_task(self): + start_time = time.time() + while True: + data = await self.outputs.get() + self.__update_meter(self.output_meters, len(data), start_time) + await self.conn.put(data) + + async def __run_input_task(self): + start_time = time.time() + while True: + await asyncio.sleep(self.latency) + data = await self.conn.get() + self.__update_meter(self.input_meters, len(data), start_time) + await self.inputs.put(data) + + def __update_meter(self, meters: list[int], size: int, start_time: float): + slot = math.floor((time.time() - start_time) / self.meter_interval) + assert slot >= len(meters) - 1 + meters.extend([0] * (slot - len(meters) + 1)) + meters[-1] += size + + def output_bandwidths(self) -> pandas.Series: + return self.__bandwidths(self.output_meters) + + def input_bandwidths(self) -> pandas.Series: + return self.__bandwidths(self.input_meters) + + def __bandwidths(self, meters: list[int]) -> pandas.Series: + return pandas.Series(meters, name="bandwidth").map( + lambda x: x / self.meter_interval + ) diff --git a/mixnet/sim/main.py b/mixnet/sim/main.py new file mode 100644 index 0000000..11389d5 --- /dev/null +++ b/mixnet/sim/main.py @@ -0,0 +1,21 @@ +import argparse +import asyncio + +from mixnet.sim.config import Config +from mixnet.sim.simulation import Simulation + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description="Run mixnet simulation", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + parser.add_argument( + "--config", type=str, required=True, help="Configuration file path" + ) + args = parser.parse_args() + + config = Config.load(args.config) + sim = Simulation(config) + asyncio.run(sim.run()) + + print("Simulation complete!") diff --git a/mixnet/sim/simulation.py b/mixnet/sim/simulation.py new file mode 100644 index 0000000..928b674 --- /dev/null +++ b/mixnet/sim/simulation.py @@ -0,0 +1,69 @@ +import asyncio +import random +import time + +from mixnet.config import GlobalConfig, MixMembership, NodeInfo +from mixnet.node import Node +from mixnet.sim.config import Config +from mixnet.sim.connection import MeteredRemoteSimplexConnection +from mixnet.sim.stats import ConnectionStats + + +class Simulation: + def __init__(self, config: Config): + random.seed() + self.config = config + + async def run(self): + nodes, conn_measurement = self.init_nodes() + + deadline = time.time() + self.scaled_time(self.config.simulation.duration_sec) + tasks: list[asyncio.Task] = [] + for node in nodes: + tasks.append(asyncio.create_task(self.run_logic(node, deadline))) + await asyncio.gather(*tasks) + + conn_measurement.bandwidths() + + def init_nodes(self) -> tuple[list[Node], ConnectionStats]: + node_configs = self.config.mixnet.node_configs() + global_config = GlobalConfig( + MixMembership( + [ + NodeInfo(node_config.private_key.public_key()) + for node_config in node_configs + ] + ), + self.scaled_rate(self.config.mixnet.transmission_rate_per_sec), + self.config.mixnet.max_mix_path_length, + ) + nodes = [Node(node_config, global_config) for node_config in node_configs] + + conn_stats = ConnectionStats() + for i, node in enumerate(nodes): + inbound_conn, outbound_conn = self.create_conn(), self.create_conn() + node.connect(nodes[(i + 1) % len(nodes)], inbound_conn, outbound_conn) + conn_stats.register(node, inbound_conn, outbound_conn) + + return nodes, conn_stats + + def create_conn(self) -> MeteredRemoteSimplexConnection: + return MeteredRemoteSimplexConnection( + latency=self.scaled_time(self.config.simulation.net_latency_sec), + meter_interval=self.scaled_time(self.config.simulation.meter_interval_sec), + ) + + async def run_logic(self, node: Node, deadline: float): + while time.time() < deadline: + await asyncio.sleep( + self.scaled_time(self.config.logic.lottery_interval_sec) + ) + + if random.random() < self.config.logic.sender_prob: + await node.send_message(b"selected block") + + def scaled_time(self, time: float) -> float: + return time * self.config.simulation.time_scale + + def scaled_rate(self, rate: int) -> float: + return float(rate / self.config.simulation.time_scale) diff --git a/mixnet/sim/stats.py b/mixnet/sim/stats.py new file mode 100644 index 0000000..2e30ff3 --- /dev/null +++ b/mixnet/sim/stats.py @@ -0,0 +1,52 @@ +import pandas + +from mixnet.node import Node +from mixnet.sim.connection import MeteredRemoteSimplexConnection + +NodeConnectionsMap = dict[ + Node, + tuple[list[MeteredRemoteSimplexConnection], list[MeteredRemoteSimplexConnection]], +] + + +class ConnectionStats: + conns_per_node: NodeConnectionsMap + + def __init__(self): + self.conns_per_node = dict() + + def register( + self, + node: Node, + inbound_conn: MeteredRemoteSimplexConnection, + outbound_conn: MeteredRemoteSimplexConnection, + ): + if node not in self.conns_per_node: + self.conns_per_node[node] = ([], []) + self.conns_per_node[node][0].append(inbound_conn) + self.conns_per_node[node][1].append(outbound_conn) + + def bandwidths(self): + for i, (_, (inbound_conns, outbound_conns)) in enumerate( + self.conns_per_node.items() + ): + inbound_bandwidths = ( + pandas.concat( + [conn.input_bandwidths() for conn in inbound_conns], axis=1 + ) + .sum(axis=1) + .map(lambda x: x / 1024 / 1024) + ) + outbound_bandwidths = ( + pandas.concat( + [conn.output_bandwidths() for conn in outbound_conns], axis=1 + ) + .sum(axis=1) + .map(lambda x: x / 1024 / 1024) + ) + + print(f"=== [Node:{i}] ===") + print("--- Inbound bandwidths ---") + print(inbound_bandwidths.describe()) + print("--- Outbound bandwidths ---") + print(outbound_bandwidths.describe())