Use usim for simulation for reliable performance

This commit is contained in:
Youngjoon Lee 2024-07-04 16:07:12 +09:00
parent a0da12a93a
commit a260047cef
No known key found for this signature in database
GPG Key ID: B4253AFBA618BF4D
13 changed files with 232 additions and 88 deletions

View File

@ -1,5 +1,6 @@
import abc
import asyncio
from mixnet.framework.framework import Framework
class SimplexConnection(abc.ABC):
@ -13,8 +14,8 @@ class SimplexConnection(abc.ABC):
class LocalSimplexConnection(SimplexConnection):
def __init__(self):
self.queue = asyncio.Queue()
def __init__(self, framework: Framework):
self.queue = framework.queue()
async def send(self, data: bytes) -> None:
await self.queue.put(data)

View File

View File

@ -0,0 +1,43 @@
from __future__ import annotations
import asyncio
import time
from typing import Any, Awaitable, Coroutine
from mixnet.framework import framework
class Framework(framework.Framework):
def __init__(self):
super().__init__()
def queue(self) -> framework.Queue:
return Queue()
async def sleep(self, seconds: float) -> None:
await asyncio.sleep(seconds)
def now(self) -> float:
return time.time()
def spawn(
self, coroutine: Coroutine[Any, Any, framework.RT]
) -> Awaitable[framework.RT]:
return asyncio.create_task(coroutine)
class Queue(framework.Queue):
_queue: asyncio.Queue[bytes]
def __init__(self):
super().__init__()
self._queue = asyncio.Queue()
async def put(self, data: bytes) -> None:
await self._queue.put(data)
async def get(self) -> bytes:
return await self._queue.get()
def empty(self) -> bool:
return self._queue.empty()

View File

@ -0,0 +1,38 @@
from __future__ import annotations
import abc
from typing import Any, Awaitable, Coroutine, TypeVar
RT = TypeVar("RT")
class Framework(abc.ABC):
@abc.abstractmethod
def queue(self) -> Queue:
pass
@abc.abstractmethod
async def sleep(self, seconds: float) -> None:
pass
@abc.abstractmethod
def now(self) -> float:
pass
@abc.abstractmethod
def spawn(self, coroutine: Coroutine[Any, Any, RT]) -> Awaitable[RT]:
pass
class Queue(abc.ABC):
@abc.abstractmethod
async def put(self, data: bytes) -> None:
pass
@abc.abstractmethod
async def get(self) -> bytes:
pass
@abc.abstractmethod
def empty(self) -> bool:
pass

44
mixnet/framework/usim.py Normal file
View File

@ -0,0 +1,44 @@
from typing import Any, Awaitable, Coroutine
import usim
from mixnet.framework import framework
class Framework(framework.Framework):
_scope: usim.Scope
def __init__(self, scope: usim.Scope) -> None:
super().__init__()
self._scope = scope
def queue(self) -> framework.Queue:
return Queue()
async def sleep(self, seconds: float) -> None:
await (usim.time + seconds)
def now(self) -> float:
return usim.time.now
def spawn(
self, coroutine: Coroutine[Any, Any, framework.RT]
) -> Awaitable[framework.RT]:
return self._scope.do(coroutine)
class Queue(framework.Queue):
_queue: usim.Queue[bytes]
def __init__(self):
super().__init__()
self._queue = usim.Queue()
async def put(self, data: bytes) -> None:
await self._queue.put(data)
async def get(self) -> bytes:
return await self._queue
def empty(self) -> bool:
return len(self._queue._buffer) == 0

View File

@ -1,6 +1,5 @@
from __future__ import annotations
import asyncio
import hashlib
from enum import Enum
from typing import Awaitable, Callable, TypeAlias
@ -14,28 +13,33 @@ from pysphinx.sphinx import (
)
from mixnet.config import GlobalConfig, NodeConfig
from mixnet.connection import LocalSimplexConnection, SimplexConnection
from mixnet.connection import SimplexConnection
from mixnet.framework.framework import Framework, Queue
from mixnet.packet import Fragment, MessageFlag, MessageReconstructor, PacketBuilder
NetworkPacketQueue: TypeAlias = asyncio.Queue[bytes]
BroadcastChannel: TypeAlias = asyncio.Queue[bytes]
NetworkPacketQueue: TypeAlias = Queue
BroadcastChannel: TypeAlias = Queue
class Node:
framework: Framework
config: NodeConfig
global_config: GlobalConfig
mixgossip_channel: MixGossipChannel
reconstructor: MessageReconstructor
broadcast_channel: BroadcastChannel
def __init__(self, config: NodeConfig, global_config: GlobalConfig):
def __init__(
self, framework: Framework, config: NodeConfig, global_config: GlobalConfig
):
self.framework = framework
self.config = config
self.global_config = global_config
self.mixgossip_channel = MixGossipChannel(
config.peering_degree, self.__process_sphinx_packet
framework, config.peering_degree, self.__process_sphinx_packet
)
self.reconstructor = MessageReconstructor()
self.broadcast_channel = asyncio.Queue()
self.broadcast_channel = framework.queue()
async def __process_sphinx_packet(
self, packet: SphinxPacket
@ -64,14 +68,16 @@ class Node:
def connect(
self,
peer: Node,
inbound_conn: SimplexConnection = LocalSimplexConnection(),
outbound_conn: SimplexConnection = LocalSimplexConnection(),
inbound_conn: SimplexConnection,
outbound_conn: SimplexConnection,
):
self.mixgossip_channel.add_conn(
DuplexConnection(
inbound_conn,
MixSimplexConnection(
outbound_conn, self.global_config.transmission_rate_per_sec
self.framework,
outbound_conn,
self.global_config.transmission_rate_per_sec,
),
)
)
@ -79,7 +85,9 @@ class Node:
DuplexConnection(
outbound_conn,
MixSimplexConnection(
inbound_conn, self.global_config.transmission_rate_per_sec
self.framework,
inbound_conn,
self.global_config.transmission_rate_per_sec,
),
)
)
@ -93,6 +101,7 @@ class Node:
class MixGossipChannel:
framework: Framework
peering_degree: int
conns: list[DuplexConnection]
handler: Callable[[SphinxPacket], Awaitable[SphinxPacket | None]]
@ -100,9 +109,11 @@ class MixGossipChannel:
def __init__(
self,
framework: Framework,
peering_degree: int,
handler: Callable[[SphinxPacket], Awaitable[SphinxPacket | None]],
):
self.framework = framework
self.peering_degree = peering_degree
self.conns = []
self.handler = handler
@ -117,10 +128,8 @@ class MixGossipChannel:
raise ValueError("The peering degree is reached.")
self.conns.append(conn)
task = asyncio.create_task(self.__process_inbound_conn(conn))
task = self.framework.spawn(self.__process_inbound_conn(conn))
self.tasks.add(task)
# To discard the task from the set automatically when it is done.
task.add_done_callback(self.tasks.discard)
async def __process_inbound_conn(self, conn: DuplexConnection):
while True:
@ -166,24 +175,31 @@ class DuplexConnection:
class MixSimplexConnection:
framework: Framework
queue: NetworkPacketQueue
conn: SimplexConnection
transmission_rate_per_sec: float
def __init__(self, conn: SimplexConnection, transmission_rate_per_sec: float):
self.queue = asyncio.Queue()
def __init__(
self,
framework: Framework,
conn: SimplexConnection,
transmission_rate_per_sec: float,
):
self.framework = framework
self.queue = framework.queue()
self.conn = conn
self.transmission_rate_per_sec = transmission_rate_per_sec
self.task = asyncio.create_task(self.__run())
self.task = framework.spawn(self.__run())
async def __run(self):
while True:
await asyncio.sleep(1 / self.transmission_rate_per_sec)
await self.framework.sleep(1 / self.transmission_rate_per_sec)
# TODO: time mixing
if self.queue.empty():
elem = build_noise_packet()
else:
elem = self.queue.get_nowait()
elem = await self.queue.get()
await self.conn.send(elem)
async def send(self, elem: bytes):

View File

@ -31,27 +31,13 @@ class Config:
@dataclass
class SimulationConfig:
time_scale: float
duration_sec: int
net_latency_sec: float
def validate(self):
assert self.time_scale > 0
assert self.duration_sec > 0
assert self.net_latency_sec > 0
def scale_time(self, time: float) -> float:
return time * self.time_scale
def scale_rate(self, rate: int) -> float:
return float(rate / self.time_scale)
def scaled_duration(self) -> float:
return self.scale_time(self.duration_sec)
def scaled_net_latency(self) -> float:
return self.scale_time(self.net_latency_sec)
@dataclass
class LogicConfig:

View File

@ -1,6 +1,5 @@
simulation:
time_scale: 0.001
duration_sec: 10000
duration_sec: 1000
net_latency_sec: 0.01
logic:

View File

@ -1,32 +1,34 @@
import asyncio
import math
import time
from typing import Awaitable
import pandas
from mixnet.connection import SimplexConnection
from mixnet.framework.framework import Framework, Queue
from mixnet.sim.config import SimulationConfig
class MeteredRemoteSimplexConnection(SimplexConnection):
framework: Framework
config: SimulationConfig
outputs: asyncio.Queue
conn: asyncio.Queue
inputs: asyncio.Queue
output_task: asyncio.Task
outputs: Queue
conn: Queue
inputs: Queue
output_task: Awaitable
output_meters: list[int]
input_task: asyncio.Task
input_task: Awaitable
input_meters: list[int]
def __init__(self, config: SimulationConfig):
def __init__(self, config: SimulationConfig, framework: Framework):
self.framework = framework
self.config = config
self.outputs = asyncio.Queue()
self.conn = asyncio.Queue()
self.inputs = asyncio.Queue()
self.outputs = framework.queue()
self.conn = framework.queue()
self.inputs = framework.queue()
self.output_meters = []
self.output_task = asyncio.create_task(self.__run_output_task())
self.output_task = framework.spawn(self.__run_output_task())
self.input_meters = []
self.input_task = asyncio.create_task(self.__run_input_task())
self.input_task = framework.spawn(self.__run_input_task())
async def send(self, data: bytes) -> None:
await self.outputs.put(data)
@ -35,22 +37,24 @@ class MeteredRemoteSimplexConnection(SimplexConnection):
return await self.inputs.get()
async def __run_output_task(self):
start_time = time.time()
start_time = self.framework.now()
while True:
data = await self.outputs.get()
self.__update_meter(self.output_meters, len(data), start_time)
await self.conn.put(data)
async def __run_input_task(self):
start_time = time.time()
start_time = self.framework.now()
while True:
await asyncio.sleep(self.config.scaled_net_latency())
data = await self.conn.get()
if data is None:
break
await self.framework.sleep(self.config.net_latency_sec)
self.__update_meter(self.input_meters, len(data), start_time)
await self.inputs.put(data)
def __update_meter(self, meters: list[int], size: int, start_time: float):
slot = math.floor((time.time() - start_time) / self.config.time_scale)
slot = math.floor(self.framework.now() - start_time)
assert slot >= len(meters) - 1
meters.extend([0] * (slot - len(meters) + 1))
meters[-1] += size

View File

@ -1,5 +1,6 @@
import argparse
import asyncio
import usim
from mixnet.sim.config import Config
from mixnet.sim.simulation import Simulation
@ -16,6 +17,6 @@ if __name__ == "__main__":
config = Config.load(args.config)
sim = Simulation(config)
asyncio.run(sim.run())
usim.run(sim.run())
print("Simulation complete!")

View File

@ -1,8 +1,10 @@
import asyncio
import random
import time
import usim
import mixnet.framework.usim as usimfw
from mixnet.config import GlobalConfig, MixMembership, NodeInfo
from mixnet.framework.framework import Framework
from mixnet.node import Node
from mixnet.sim.config import Config
from mixnet.sim.connection import MeteredRemoteSimplexConnection
@ -10,20 +12,25 @@ from mixnet.sim.stats import ConnectionStats
class Simulation:
config: Config
framework: Framework
def __init__(self, config: Config):
random.seed()
self.config = config
async def run(self):
nodes, conn_measurement = self.init_nodes()
conn_stats = await self._run()
conn_stats.bandwidths()
deadline = time.time() + self.config.simulation.scaled_duration()
tasks: list[asyncio.Task] = []
for node in nodes:
tasks.append(asyncio.create_task(self.run_logic(node, deadline)))
await asyncio.gather(*tasks)
conn_measurement.bandwidths()
async def _run(self) -> ConnectionStats:
async with usim.until(usim.time + self.config.simulation.duration_sec) as scope:
self.framework = usimfw.Framework(scope)
nodes, conn_stats = self.init_nodes()
for node in nodes:
self.framework.spawn(self.run_logic(node))
return conn_stats
assert False # unreachable
def init_nodes(self) -> tuple[list[Node], ConnectionStats]:
node_configs = self.config.mixnet.node_configs()
@ -34,16 +41,20 @@ class Simulation:
for node_config in node_configs
]
),
self.config.simulation.scale_rate(
self.config.mixnet.transmission_rate_per_sec
),
self.config.mixnet.transmission_rate_per_sec,
self.config.mixnet.max_mix_path_length,
)
nodes = [Node(node_config, global_config) for node_config in node_configs]
nodes = [
Node(self.framework, node_config, global_config)
for node_config in node_configs
]
conn_stats = ConnectionStats()
for i, node in enumerate(nodes):
inbound_conn, outbound_conn = self.create_conn(), self.create_conn()
inbound_conn, outbound_conn = (
self.create_conn(),
self.create_conn(),
)
peer = nodes[(i + 1) % len(nodes)]
node.connect(peer, inbound_conn, outbound_conn)
conn_stats.register(node, inbound_conn, outbound_conn)
@ -52,15 +63,10 @@ class Simulation:
return nodes, conn_stats
def create_conn(self) -> MeteredRemoteSimplexConnection:
return MeteredRemoteSimplexConnection(self.config.simulation)
async def run_logic(self, node: Node, deadline: float):
while time.time() < deadline:
await asyncio.sleep(
self.config.simulation.scale_time(
self.config.logic.lottery_interval_sec
)
)
return MeteredRemoteSimplexConnection(self.config.simulation, self.framework)
async def run_logic(self, node: Node):
while True:
await (usim.time + self.config.logic.lottery_interval_sec)
if random.random() < self.config.logic.sender_prob:
await node.send_message(b"selected block")

View File

@ -1,6 +1,8 @@
import asyncio
from unittest import IsolatedAsyncioTestCase
import mixnet.framework.asyncio as asynciofw
from mixnet.connection import LocalSimplexConnection
from mixnet.node import Node
from mixnet.test_utils import (
init_mixnet_config,
@ -9,13 +11,17 @@ from mixnet.test_utils import (
class TestNode(IsolatedAsyncioTestCase):
async def test_node(self):
framework = asynciofw.Framework()
global_config, node_configs, _ = init_mixnet_config(10)
nodes = [Node(node_config, global_config) for node_config in node_configs]
nodes = [
Node(framework, node_config, global_config) for node_config in node_configs
]
for i, node in enumerate(nodes):
try:
node.connect(nodes[(i + 1) % len(nodes)])
except ValueError as e:
print(e)
node.connect(
nodes[(i + 1) % len(nodes)],
LocalSimplexConnection(framework),
LocalSimplexConnection(framework),
)
await nodes[0].send_message(b"block selection")
@ -24,7 +30,7 @@ class TestNode(IsolatedAsyncioTestCase):
broadcasted_msgs = []
for node in nodes:
if not node.broadcast_channel.empty():
broadcasted_msgs.append(node.broadcast_channel.get_nowait())
broadcasted_msgs.append(await node.broadcast_channel.get())
if len(broadcasted_msgs) == 0:
await asyncio.sleep(1)

View File

@ -6,4 +6,4 @@ pycparser==2.21
pysphinx==0.0.3
scipy==1.11.4
black==23.12.1
sympy==1.12
usim==0.4.4