add simulation

This commit is contained in:
Youngjoon Lee 2024-07-03 23:29:26 +09:00
parent d3e8b0223e
commit 515bc2c50a
No known key found for this signature in database
GPG Key ID: B4253AFBA618BF4D
10 changed files with 347 additions and 15 deletions

View File

@ -14,7 +14,7 @@ from pysphinx.sphinx import Node as SphinxNode
@dataclass
class GlobalConfig:
membership: MixMembership
transmission_rate_per_sec: int # Global Transmission Rate
transmission_rate_per_sec: float # Global Transmission Rate
# TODO: use this to make the size of Sphinx packet constant
max_mix_path_length: int

23
mixnet/connection.py Normal file
View File

@ -0,0 +1,23 @@
import abc
import asyncio
class SimplexConnection(abc.ABC):
@abc.abstractmethod
async def send(self, data: bytes) -> None:
pass
@abc.abstractmethod
async def recv(self) -> bytes:
pass
class LocalSimplexConnection(SimplexConnection):
def __init__(self):
self.queue = asyncio.Queue()
async def send(self, data: bytes) -> None:
await self.queue.put(data)
async def recv(self) -> bytes:
return await self.queue.get()

View File

@ -14,10 +14,10 @@ from pysphinx.sphinx import (
)
from mixnet.config import GlobalConfig, NodeConfig
from mixnet.connection import LocalSimplexConnection, SimplexConnection
from mixnet.packet import Fragment, MessageFlag, MessageReconstructor, PacketBuilder
NetworkPacketQueue: TypeAlias = asyncio.Queue[bytes]
Connection: TypeAlias = NetworkPacketQueue
BroadcastChannel: TypeAlias = asyncio.Queue[bytes]
@ -58,14 +58,19 @@ class Node:
if msg_with_flag is not None:
flag, msg = PacketBuilder.parse_msg_and_flag(msg_with_flag)
if flag == MessageFlag.MESSAGE_FLAG_REAL:
print(f"Broadcasting message finally: {msg}")
await self.broadcast_channel.put(msg)
def connect(self, peer: Node):
inbound_conn, outbound_conn = asyncio.Queue(), asyncio.Queue()
def connect(
self,
peer: Node,
inbound_conn: SimplexConnection = LocalSimplexConnection(),
outbound_conn: SimplexConnection = LocalSimplexConnection(),
):
self.mixgossip_channel.add_conn(
DuplexConnection(
inbound_conn,
MixOutboundConnection(
MixSimplexConnection(
outbound_conn, self.global_config.transmission_rate_per_sec
),
)
@ -73,13 +78,14 @@ class Node:
peer.mixgossip_channel.add_conn(
DuplexConnection(
outbound_conn,
MixOutboundConnection(
MixSimplexConnection(
inbound_conn, self.global_config.transmission_rate_per_sec
),
)
)
async def send_message(self, msg: bytes):
print(f"Sending message: {msg}")
for packet, _ in PacketBuilder.build_real_packets(
msg, self.global_config.membership
):
@ -145,26 +151,26 @@ class MixGossipChannel:
class DuplexConnection:
inbound: Connection
outbound: MixOutboundConnection
inbound: SimplexConnection
outbound: MixSimplexConnection
def __init__(self, inbound: Connection, outbound: MixOutboundConnection):
def __init__(self, inbound: SimplexConnection, outbound: MixSimplexConnection):
self.inbound = inbound
self.outbound = outbound
async def recv(self) -> bytes:
return await self.inbound.get()
return await self.inbound.recv()
async def send(self, packet: bytes):
await self.outbound.send(packet)
class MixOutboundConnection:
class MixSimplexConnection:
queue: NetworkPacketQueue
conn: Connection
transmission_rate_per_sec: int
conn: SimplexConnection
transmission_rate_per_sec: float
def __init__(self, conn: Connection, transmission_rate_per_sec: int):
def __init__(self, conn: SimplexConnection, transmission_rate_per_sec: float):
self.queue = asyncio.Queue()
self.conn = conn
self.transmission_rate_per_sec = transmission_rate_per_sec
@ -178,7 +184,7 @@ class MixOutboundConnection:
elem = build_noise_packet()
else:
elem = self.queue.get_nowait()
await self.conn.put(elem)
await self.conn.send(elem)
async def send(self, elem: bytes):
await self.queue.put(elem)

0
mixnet/sim/__init__.py Normal file
View File

77
mixnet/sim/config.py Normal file
View File

@ -0,0 +1,77 @@
from __future__ import annotations
from dataclasses import dataclass
import dacite
import yaml
from pysphinx.sphinx import X25519PrivateKey
from mixnet.config import NodeConfig
@dataclass
class Config:
simulation: SimulationConfig
logic: LogicConfig
mixnet: MixnetConfig
@classmethod
def load(cls, yaml_path: str) -> Config:
with open(yaml_path, "r") as f:
data = yaml.safe_load(f)
config = dacite.from_dict(data_class=Config, data=data)
# Validations
config.simulation.validate()
config.logic.validate()
config.mixnet.validate()
return config
@dataclass
class SimulationConfig:
time_scale: float
duration_sec: int
net_latency_sec: float
meter_interval_sec: float
def validate(self):
assert self.time_scale > 0
assert self.duration_sec > 0
assert self.net_latency_sec > 0
assert self.meter_interval_sec > 0
@dataclass
class LogicConfig:
lottery_interval_sec: float
sender_prob: float
def validate(self):
assert self.lottery_interval_sec > 0
assert self.sender_prob > 0
@dataclass
class MixnetConfig:
num_nodes: int
transmission_rate_per_sec: int
peering_degree: int
max_mix_path_length: int
def validate(self):
assert self.num_nodes > 0
assert self.transmission_rate_per_sec > 0
assert self.peering_degree > 0
assert self.max_mix_path_length > 0
def node_configs(self) -> list[NodeConfig]:
return [
NodeConfig(
X25519PrivateKey.generate(),
self.peering_degree,
self.transmission_rate_per_sec,
)
for _ in range(self.num_nodes)
]

16
mixnet/sim/config.yaml Normal file
View File

@ -0,0 +1,16 @@
simulation:
time_scale: 0.001
duration_sec: 10000
net_latency_sec: 0.01
meter_interval_sec: 1
logic:
lottery_interval_sec: 1
sender_prob: 0.01
mixnet:
num_nodes: 5
transmission_rate_per_sec: 10
peering_degree: 6
max_mix_path_length: 3

68
mixnet/sim/connection.py Normal file
View File

@ -0,0 +1,68 @@
import asyncio
import math
import time
import pandas
from mixnet.connection import SimplexConnection
class MeteredRemoteSimplexConnection(SimplexConnection):
latency: float
meter_interval: float
outputs: asyncio.Queue
conn: asyncio.Queue
inputs: asyncio.Queue
output_task: asyncio.Task
output_meters: list[int]
input_task: asyncio.Task
input_meters: list[int]
def __init__(self, latency: float, meter_interval: float):
self.latency = latency
self.meter_interval = meter_interval
self.outputs = asyncio.Queue()
self.conn = asyncio.Queue()
self.inputs = asyncio.Queue()
self.output_meters = []
self.output_task = asyncio.create_task(self.__run_output_task())
self.input_meters = []
self.input_task = asyncio.create_task(self.__run_input_task())
async def send(self, data: bytes) -> None:
await self.outputs.put(data)
async def recv(self) -> bytes:
return await self.inputs.get()
async def __run_output_task(self):
start_time = time.time()
while True:
data = await self.outputs.get()
self.__update_meter(self.output_meters, len(data), start_time)
await self.conn.put(data)
async def __run_input_task(self):
start_time = time.time()
while True:
await asyncio.sleep(self.latency)
data = await self.conn.get()
self.__update_meter(self.input_meters, len(data), start_time)
await self.inputs.put(data)
def __update_meter(self, meters: list[int], size: int, start_time: float):
slot = math.floor((time.time() - start_time) / self.meter_interval)
assert slot >= len(meters) - 1
meters.extend([0] * (slot - len(meters) + 1))
meters[-1] += size
def output_bandwidths(self) -> pandas.Series:
return self.__bandwidths(self.output_meters)
def input_bandwidths(self) -> pandas.Series:
return self.__bandwidths(self.input_meters)
def __bandwidths(self, meters: list[int]) -> pandas.Series:
return pandas.Series(meters, name="bandwidth").map(
lambda x: x / self.meter_interval
)

21
mixnet/sim/main.py Normal file
View File

@ -0,0 +1,21 @@
import argparse
import asyncio
from mixnet.sim.config import Config
from mixnet.sim.simulation import Simulation
if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="Run mixnet simulation",
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
parser.add_argument(
"--config", type=str, required=True, help="Configuration file path"
)
args = parser.parse_args()
config = Config.load(args.config)
sim = Simulation(config)
asyncio.run(sim.run())
print("Simulation complete!")

69
mixnet/sim/simulation.py Normal file
View File

@ -0,0 +1,69 @@
import asyncio
import random
import time
from mixnet.config import GlobalConfig, MixMembership, NodeInfo
from mixnet.node import Node
from mixnet.sim.config import Config
from mixnet.sim.connection import MeteredRemoteSimplexConnection
from mixnet.sim.stats import ConnectionStats
class Simulation:
def __init__(self, config: Config):
random.seed()
self.config = config
async def run(self):
nodes, conn_measurement = self.init_nodes()
deadline = time.time() + self.scaled_time(self.config.simulation.duration_sec)
tasks: list[asyncio.Task] = []
for node in nodes:
tasks.append(asyncio.create_task(self.run_logic(node, deadline)))
await asyncio.gather(*tasks)
conn_measurement.bandwidths()
def init_nodes(self) -> tuple[list[Node], ConnectionStats]:
node_configs = self.config.mixnet.node_configs()
global_config = GlobalConfig(
MixMembership(
[
NodeInfo(node_config.private_key.public_key())
for node_config in node_configs
]
),
self.scaled_rate(self.config.mixnet.transmission_rate_per_sec),
self.config.mixnet.max_mix_path_length,
)
nodes = [Node(node_config, global_config) for node_config in node_configs]
conn_stats = ConnectionStats()
for i, node in enumerate(nodes):
inbound_conn, outbound_conn = self.create_conn(), self.create_conn()
node.connect(nodes[(i + 1) % len(nodes)], inbound_conn, outbound_conn)
conn_stats.register(node, inbound_conn, outbound_conn)
return nodes, conn_stats
def create_conn(self) -> MeteredRemoteSimplexConnection:
return MeteredRemoteSimplexConnection(
latency=self.scaled_time(self.config.simulation.net_latency_sec),
meter_interval=self.scaled_time(self.config.simulation.meter_interval_sec),
)
async def run_logic(self, node: Node, deadline: float):
while time.time() < deadline:
await asyncio.sleep(
self.scaled_time(self.config.logic.lottery_interval_sec)
)
if random.random() < self.config.logic.sender_prob:
await node.send_message(b"selected block")
def scaled_time(self, time: float) -> float:
return time * self.config.simulation.time_scale
def scaled_rate(self, rate: int) -> float:
return float(rate / self.config.simulation.time_scale)

52
mixnet/sim/stats.py Normal file
View File

@ -0,0 +1,52 @@
import pandas
from mixnet.node import Node
from mixnet.sim.connection import MeteredRemoteSimplexConnection
NodeConnectionsMap = dict[
Node,
tuple[list[MeteredRemoteSimplexConnection], list[MeteredRemoteSimplexConnection]],
]
class ConnectionStats:
conns_per_node: NodeConnectionsMap
def __init__(self):
self.conns_per_node = dict()
def register(
self,
node: Node,
inbound_conn: MeteredRemoteSimplexConnection,
outbound_conn: MeteredRemoteSimplexConnection,
):
if node not in self.conns_per_node:
self.conns_per_node[node] = ([], [])
self.conns_per_node[node][0].append(inbound_conn)
self.conns_per_node[node][1].append(outbound_conn)
def bandwidths(self):
for i, (_, (inbound_conns, outbound_conns)) in enumerate(
self.conns_per_node.items()
):
inbound_bandwidths = (
pandas.concat(
[conn.input_bandwidths() for conn in inbound_conns], axis=1
)
.sum(axis=1)
.map(lambda x: x / 1024 / 1024)
)
outbound_bandwidths = (
pandas.concat(
[conn.output_bandwidths() for conn in outbound_conns], axis=1
)
.sum(axis=1)
.map(lambda x: x / 1024 / 1024)
)
print(f"=== [Node:{i}] ===")
print("--- Inbound bandwidths ---")
print(inbound_bandwidths.describe())
print("--- Outbound bandwidths ---")
print(outbound_bandwidths.describe())