diff --git a/mixnet/connection.py b/mixnet/connection.py index bf39cbc..8230e31 100644 --- a/mixnet/connection.py +++ b/mixnet/connection.py @@ -1,5 +1,6 @@ import abc -import asyncio + +from mixnet.framework.framework import Framework class SimplexConnection(abc.ABC): @@ -13,8 +14,8 @@ class SimplexConnection(abc.ABC): class LocalSimplexConnection(SimplexConnection): - def __init__(self): - self.queue = asyncio.Queue() + def __init__(self, framework: Framework): + self.queue = framework.queue() async def send(self, data: bytes) -> None: await self.queue.put(data) diff --git a/mixnet/framework/__init__.py b/mixnet/framework/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/mixnet/framework/asyncio.py b/mixnet/framework/asyncio.py new file mode 100644 index 0000000..78ff3ea --- /dev/null +++ b/mixnet/framework/asyncio.py @@ -0,0 +1,43 @@ +from __future__ import annotations + +import asyncio +import time +from typing import Any, Awaitable, Coroutine + +from mixnet.framework import framework + + +class Framework(framework.Framework): + def __init__(self): + super().__init__() + + def queue(self) -> framework.Queue: + return Queue() + + async def sleep(self, seconds: float) -> None: + await asyncio.sleep(seconds) + + def now(self) -> float: + return time.time() + + def spawn( + self, coroutine: Coroutine[Any, Any, framework.RT] + ) -> Awaitable[framework.RT]: + return asyncio.create_task(coroutine) + + +class Queue(framework.Queue): + _queue: asyncio.Queue[bytes] + + def __init__(self): + super().__init__() + self._queue = asyncio.Queue() + + async def put(self, data: bytes) -> None: + await self._queue.put(data) + + async def get(self) -> bytes: + return await self._queue.get() + + def empty(self) -> bool: + return self._queue.empty() diff --git a/mixnet/framework/framework.py b/mixnet/framework/framework.py new file mode 100644 index 0000000..5bbfaa1 --- /dev/null +++ b/mixnet/framework/framework.py @@ -0,0 +1,38 @@ +from __future__ import annotations + +import abc +from typing import Any, Awaitable, Coroutine, TypeVar + +RT = TypeVar("RT") + + +class Framework(abc.ABC): + @abc.abstractmethod + def queue(self) -> Queue: + pass + + @abc.abstractmethod + async def sleep(self, seconds: float) -> None: + pass + + @abc.abstractmethod + def now(self) -> float: + pass + + @abc.abstractmethod + def spawn(self, coroutine: Coroutine[Any, Any, RT]) -> Awaitable[RT]: + pass + + +class Queue(abc.ABC): + @abc.abstractmethod + async def put(self, data: bytes) -> None: + pass + + @abc.abstractmethod + async def get(self) -> bytes: + pass + + @abc.abstractmethod + def empty(self) -> bool: + pass diff --git a/mixnet/framework/usim.py b/mixnet/framework/usim.py new file mode 100644 index 0000000..49e443d --- /dev/null +++ b/mixnet/framework/usim.py @@ -0,0 +1,44 @@ +from typing import Any, Awaitable, Coroutine + +import usim + +from mixnet.framework import framework + + +class Framework(framework.Framework): + _scope: usim.Scope + + def __init__(self, scope: usim.Scope) -> None: + super().__init__() + self._scope = scope + + def queue(self) -> framework.Queue: + return Queue() + + async def sleep(self, seconds: float) -> None: + await (usim.time + seconds) + + def now(self) -> float: + return usim.time.now + + def spawn( + self, coroutine: Coroutine[Any, Any, framework.RT] + ) -> Awaitable[framework.RT]: + return self._scope.do(coroutine) + + +class Queue(framework.Queue): + _queue: usim.Queue[bytes] + + def __init__(self): + super().__init__() + self._queue = usim.Queue() + + async def put(self, data: bytes) -> None: + await self._queue.put(data) + + async def get(self) -> bytes: + return await self._queue + + def empty(self) -> bool: + return len(self._queue._buffer) == 0 diff --git a/mixnet/node.py b/mixnet/node.py index b4949f0..a2e110c 100644 --- a/mixnet/node.py +++ b/mixnet/node.py @@ -1,6 +1,5 @@ from __future__ import annotations -import asyncio import hashlib from enum import Enum from typing import Awaitable, Callable, TypeAlias @@ -14,28 +13,33 @@ from pysphinx.sphinx import ( ) from mixnet.config import GlobalConfig, NodeConfig -from mixnet.connection import LocalSimplexConnection, SimplexConnection +from mixnet.connection import SimplexConnection +from mixnet.framework.framework import Framework, Queue from mixnet.packet import Fragment, MessageFlag, MessageReconstructor, PacketBuilder -NetworkPacketQueue: TypeAlias = asyncio.Queue[bytes] -BroadcastChannel: TypeAlias = asyncio.Queue[bytes] +NetworkPacketQueue: TypeAlias = Queue +BroadcastChannel: TypeAlias = Queue class Node: + framework: Framework config: NodeConfig global_config: GlobalConfig mixgossip_channel: MixGossipChannel reconstructor: MessageReconstructor broadcast_channel: BroadcastChannel - def __init__(self, config: NodeConfig, global_config: GlobalConfig): + def __init__( + self, framework: Framework, config: NodeConfig, global_config: GlobalConfig + ): + self.framework = framework self.config = config self.global_config = global_config self.mixgossip_channel = MixGossipChannel( - config.peering_degree, self.__process_sphinx_packet + framework, config.peering_degree, self.__process_sphinx_packet ) self.reconstructor = MessageReconstructor() - self.broadcast_channel = asyncio.Queue() + self.broadcast_channel = framework.queue() async def __process_sphinx_packet( self, packet: SphinxPacket @@ -64,14 +68,16 @@ class Node: def connect( self, peer: Node, - inbound_conn: SimplexConnection = LocalSimplexConnection(), - outbound_conn: SimplexConnection = LocalSimplexConnection(), + inbound_conn: SimplexConnection, + outbound_conn: SimplexConnection, ): self.mixgossip_channel.add_conn( DuplexConnection( inbound_conn, MixSimplexConnection( - outbound_conn, self.global_config.transmission_rate_per_sec + self.framework, + outbound_conn, + self.global_config.transmission_rate_per_sec, ), ) ) @@ -79,7 +85,9 @@ class Node: DuplexConnection( outbound_conn, MixSimplexConnection( - inbound_conn, self.global_config.transmission_rate_per_sec + self.framework, + inbound_conn, + self.global_config.transmission_rate_per_sec, ), ) ) @@ -93,6 +101,7 @@ class Node: class MixGossipChannel: + framework: Framework peering_degree: int conns: list[DuplexConnection] handler: Callable[[SphinxPacket], Awaitable[SphinxPacket | None]] @@ -100,9 +109,11 @@ class MixGossipChannel: def __init__( self, + framework: Framework, peering_degree: int, handler: Callable[[SphinxPacket], Awaitable[SphinxPacket | None]], ): + self.framework = framework self.peering_degree = peering_degree self.conns = [] self.handler = handler @@ -117,10 +128,8 @@ class MixGossipChannel: raise ValueError("The peering degree is reached.") self.conns.append(conn) - task = asyncio.create_task(self.__process_inbound_conn(conn)) + task = self.framework.spawn(self.__process_inbound_conn(conn)) self.tasks.add(task) - # To discard the task from the set automatically when it is done. - task.add_done_callback(self.tasks.discard) async def __process_inbound_conn(self, conn: DuplexConnection): while True: @@ -166,24 +175,31 @@ class DuplexConnection: class MixSimplexConnection: + framework: Framework queue: NetworkPacketQueue conn: SimplexConnection transmission_rate_per_sec: float - def __init__(self, conn: SimplexConnection, transmission_rate_per_sec: float): - self.queue = asyncio.Queue() + def __init__( + self, + framework: Framework, + conn: SimplexConnection, + transmission_rate_per_sec: float, + ): + self.framework = framework + self.queue = framework.queue() self.conn = conn self.transmission_rate_per_sec = transmission_rate_per_sec - self.task = asyncio.create_task(self.__run()) + self.task = framework.spawn(self.__run()) async def __run(self): while True: - await asyncio.sleep(1 / self.transmission_rate_per_sec) + await self.framework.sleep(1 / self.transmission_rate_per_sec) # TODO: time mixing if self.queue.empty(): elem = build_noise_packet() else: - elem = self.queue.get_nowait() + elem = await self.queue.get() await self.conn.send(elem) async def send(self, elem: bytes): diff --git a/mixnet/sim/config.py b/mixnet/sim/config.py index 0de0ee0..bdffb7a 100644 --- a/mixnet/sim/config.py +++ b/mixnet/sim/config.py @@ -31,27 +31,13 @@ class Config: @dataclass class SimulationConfig: - time_scale: float duration_sec: int net_latency_sec: float def validate(self): - assert self.time_scale > 0 assert self.duration_sec > 0 assert self.net_latency_sec > 0 - def scale_time(self, time: float) -> float: - return time * self.time_scale - - def scale_rate(self, rate: int) -> float: - return float(rate / self.time_scale) - - def scaled_duration(self) -> float: - return self.scale_time(self.duration_sec) - - def scaled_net_latency(self) -> float: - return self.scale_time(self.net_latency_sec) - @dataclass class LogicConfig: diff --git a/mixnet/sim/config.yaml b/mixnet/sim/config.yaml index cd92751..f605b42 100644 --- a/mixnet/sim/config.yaml +++ b/mixnet/sim/config.yaml @@ -1,6 +1,5 @@ simulation: - time_scale: 0.001 - duration_sec: 10000 + duration_sec: 1000 net_latency_sec: 0.01 logic: diff --git a/mixnet/sim/connection.py b/mixnet/sim/connection.py index dfaf1e9..36db151 100644 --- a/mixnet/sim/connection.py +++ b/mixnet/sim/connection.py @@ -1,32 +1,34 @@ -import asyncio import math -import time +from typing import Awaitable import pandas from mixnet.connection import SimplexConnection +from mixnet.framework.framework import Framework, Queue from mixnet.sim.config import SimulationConfig class MeteredRemoteSimplexConnection(SimplexConnection): + framework: Framework config: SimulationConfig - outputs: asyncio.Queue - conn: asyncio.Queue - inputs: asyncio.Queue - output_task: asyncio.Task + outputs: Queue + conn: Queue + inputs: Queue + output_task: Awaitable output_meters: list[int] - input_task: asyncio.Task + input_task: Awaitable input_meters: list[int] - def __init__(self, config: SimulationConfig): + def __init__(self, config: SimulationConfig, framework: Framework): + self.framework = framework self.config = config - self.outputs = asyncio.Queue() - self.conn = asyncio.Queue() - self.inputs = asyncio.Queue() + self.outputs = framework.queue() + self.conn = framework.queue() + self.inputs = framework.queue() self.output_meters = [] - self.output_task = asyncio.create_task(self.__run_output_task()) + self.output_task = framework.spawn(self.__run_output_task()) self.input_meters = [] - self.input_task = asyncio.create_task(self.__run_input_task()) + self.input_task = framework.spawn(self.__run_input_task()) async def send(self, data: bytes) -> None: await self.outputs.put(data) @@ -35,22 +37,24 @@ class MeteredRemoteSimplexConnection(SimplexConnection): return await self.inputs.get() async def __run_output_task(self): - start_time = time.time() + start_time = self.framework.now() while True: data = await self.outputs.get() self.__update_meter(self.output_meters, len(data), start_time) await self.conn.put(data) async def __run_input_task(self): - start_time = time.time() + start_time = self.framework.now() while True: - await asyncio.sleep(self.config.scaled_net_latency()) data = await self.conn.get() + if data is None: + break + await self.framework.sleep(self.config.net_latency_sec) self.__update_meter(self.input_meters, len(data), start_time) await self.inputs.put(data) def __update_meter(self, meters: list[int], size: int, start_time: float): - slot = math.floor((time.time() - start_time) / self.config.time_scale) + slot = math.floor(self.framework.now() - start_time) assert slot >= len(meters) - 1 meters.extend([0] * (slot - len(meters) + 1)) meters[-1] += size diff --git a/mixnet/sim/main.py b/mixnet/sim/main.py index 11389d5..9a9b0cd 100644 --- a/mixnet/sim/main.py +++ b/mixnet/sim/main.py @@ -1,5 +1,6 @@ import argparse -import asyncio + +import usim from mixnet.sim.config import Config from mixnet.sim.simulation import Simulation @@ -16,6 +17,6 @@ if __name__ == "__main__": config = Config.load(args.config) sim = Simulation(config) - asyncio.run(sim.run()) + usim.run(sim.run()) print("Simulation complete!") diff --git a/mixnet/sim/simulation.py b/mixnet/sim/simulation.py index a7d504a..38fde6b 100644 --- a/mixnet/sim/simulation.py +++ b/mixnet/sim/simulation.py @@ -1,8 +1,10 @@ -import asyncio import random -import time +import usim + +import mixnet.framework.usim as usimfw from mixnet.config import GlobalConfig, MixMembership, NodeInfo +from mixnet.framework.framework import Framework from mixnet.node import Node from mixnet.sim.config import Config from mixnet.sim.connection import MeteredRemoteSimplexConnection @@ -10,20 +12,25 @@ from mixnet.sim.stats import ConnectionStats class Simulation: + config: Config + framework: Framework + def __init__(self, config: Config): random.seed() self.config = config async def run(self): - nodes, conn_measurement = self.init_nodes() + conn_stats = await self._run() + conn_stats.bandwidths() - deadline = time.time() + self.config.simulation.scaled_duration() - tasks: list[asyncio.Task] = [] - for node in nodes: - tasks.append(asyncio.create_task(self.run_logic(node, deadline))) - await asyncio.gather(*tasks) - - conn_measurement.bandwidths() + async def _run(self) -> ConnectionStats: + async with usim.until(usim.time + self.config.simulation.duration_sec) as scope: + self.framework = usimfw.Framework(scope) + nodes, conn_stats = self.init_nodes() + for node in nodes: + self.framework.spawn(self.run_logic(node)) + return conn_stats + assert False # unreachable def init_nodes(self) -> tuple[list[Node], ConnectionStats]: node_configs = self.config.mixnet.node_configs() @@ -34,16 +41,20 @@ class Simulation: for node_config in node_configs ] ), - self.config.simulation.scale_rate( - self.config.mixnet.transmission_rate_per_sec - ), + self.config.mixnet.transmission_rate_per_sec, self.config.mixnet.max_mix_path_length, ) - nodes = [Node(node_config, global_config) for node_config in node_configs] + nodes = [ + Node(self.framework, node_config, global_config) + for node_config in node_configs + ] conn_stats = ConnectionStats() for i, node in enumerate(nodes): - inbound_conn, outbound_conn = self.create_conn(), self.create_conn() + inbound_conn, outbound_conn = ( + self.create_conn(), + self.create_conn(), + ) peer = nodes[(i + 1) % len(nodes)] node.connect(peer, inbound_conn, outbound_conn) conn_stats.register(node, inbound_conn, outbound_conn) @@ -52,15 +63,10 @@ class Simulation: return nodes, conn_stats def create_conn(self) -> MeteredRemoteSimplexConnection: - return MeteredRemoteSimplexConnection(self.config.simulation) - - async def run_logic(self, node: Node, deadline: float): - while time.time() < deadline: - await asyncio.sleep( - self.config.simulation.scale_time( - self.config.logic.lottery_interval_sec - ) - ) + return MeteredRemoteSimplexConnection(self.config.simulation, self.framework) + async def run_logic(self, node: Node): + while True: + await (usim.time + self.config.logic.lottery_interval_sec) if random.random() < self.config.logic.sender_prob: await node.send_message(b"selected block") diff --git a/mixnet/test_node.py b/mixnet/test_node.py index f4ba644..bd62c2e 100644 --- a/mixnet/test_node.py +++ b/mixnet/test_node.py @@ -1,6 +1,8 @@ import asyncio from unittest import IsolatedAsyncioTestCase +import mixnet.framework.asyncio as asynciofw +from mixnet.connection import LocalSimplexConnection from mixnet.node import Node from mixnet.test_utils import ( init_mixnet_config, @@ -9,13 +11,17 @@ from mixnet.test_utils import ( class TestNode(IsolatedAsyncioTestCase): async def test_node(self): + framework = asynciofw.Framework() global_config, node_configs, _ = init_mixnet_config(10) - nodes = [Node(node_config, global_config) for node_config in node_configs] + nodes = [ + Node(framework, node_config, global_config) for node_config in node_configs + ] for i, node in enumerate(nodes): - try: - node.connect(nodes[(i + 1) % len(nodes)]) - except ValueError as e: - print(e) + node.connect( + nodes[(i + 1) % len(nodes)], + LocalSimplexConnection(framework), + LocalSimplexConnection(framework), + ) await nodes[0].send_message(b"block selection") @@ -24,7 +30,7 @@ class TestNode(IsolatedAsyncioTestCase): broadcasted_msgs = [] for node in nodes: if not node.broadcast_channel.empty(): - broadcasted_msgs.append(node.broadcast_channel.get_nowait()) + broadcasted_msgs.append(await node.broadcast_channel.get()) if len(broadcasted_msgs) == 0: await asyncio.sleep(1) diff --git a/requirements.txt b/requirements.txt index d89ad06..f6462eb 100644 --- a/requirements.txt +++ b/requirements.txt @@ -6,4 +6,4 @@ pycparser==2.21 pysphinx==0.0.3 scipy==1.11.4 black==23.12.1 -sympy==1.12 +usim==0.4.4