mirror of
https://github.com/logos-co/nomos-specs.git
synced 2025-02-02 02:25:02 +00:00
add simulation
This commit is contained in:
parent
d3e8b0223e
commit
515bc2c50a
@ -14,7 +14,7 @@ from pysphinx.sphinx import Node as SphinxNode
|
||||
@dataclass
|
||||
class GlobalConfig:
|
||||
membership: MixMembership
|
||||
transmission_rate_per_sec: int # Global Transmission Rate
|
||||
transmission_rate_per_sec: float # Global Transmission Rate
|
||||
# TODO: use this to make the size of Sphinx packet constant
|
||||
max_mix_path_length: int
|
||||
|
||||
|
23
mixnet/connection.py
Normal file
23
mixnet/connection.py
Normal file
@ -0,0 +1,23 @@
|
||||
import abc
|
||||
import asyncio
|
||||
|
||||
|
||||
class SimplexConnection(abc.ABC):
|
||||
@abc.abstractmethod
|
||||
async def send(self, data: bytes) -> None:
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
async def recv(self) -> bytes:
|
||||
pass
|
||||
|
||||
|
||||
class LocalSimplexConnection(SimplexConnection):
|
||||
def __init__(self):
|
||||
self.queue = asyncio.Queue()
|
||||
|
||||
async def send(self, data: bytes) -> None:
|
||||
await self.queue.put(data)
|
||||
|
||||
async def recv(self) -> bytes:
|
||||
return await self.queue.get()
|
@ -14,10 +14,10 @@ from pysphinx.sphinx import (
|
||||
)
|
||||
|
||||
from mixnet.config import GlobalConfig, NodeConfig
|
||||
from mixnet.connection import LocalSimplexConnection, SimplexConnection
|
||||
from mixnet.packet import Fragment, MessageFlag, MessageReconstructor, PacketBuilder
|
||||
|
||||
NetworkPacketQueue: TypeAlias = asyncio.Queue[bytes]
|
||||
Connection: TypeAlias = NetworkPacketQueue
|
||||
BroadcastChannel: TypeAlias = asyncio.Queue[bytes]
|
||||
|
||||
|
||||
@ -58,14 +58,19 @@ class Node:
|
||||
if msg_with_flag is not None:
|
||||
flag, msg = PacketBuilder.parse_msg_and_flag(msg_with_flag)
|
||||
if flag == MessageFlag.MESSAGE_FLAG_REAL:
|
||||
print(f"Broadcasting message finally: {msg}")
|
||||
await self.broadcast_channel.put(msg)
|
||||
|
||||
def connect(self, peer: Node):
|
||||
inbound_conn, outbound_conn = asyncio.Queue(), asyncio.Queue()
|
||||
def connect(
|
||||
self,
|
||||
peer: Node,
|
||||
inbound_conn: SimplexConnection = LocalSimplexConnection(),
|
||||
outbound_conn: SimplexConnection = LocalSimplexConnection(),
|
||||
):
|
||||
self.mixgossip_channel.add_conn(
|
||||
DuplexConnection(
|
||||
inbound_conn,
|
||||
MixOutboundConnection(
|
||||
MixSimplexConnection(
|
||||
outbound_conn, self.global_config.transmission_rate_per_sec
|
||||
),
|
||||
)
|
||||
@ -73,13 +78,14 @@ class Node:
|
||||
peer.mixgossip_channel.add_conn(
|
||||
DuplexConnection(
|
||||
outbound_conn,
|
||||
MixOutboundConnection(
|
||||
MixSimplexConnection(
|
||||
inbound_conn, self.global_config.transmission_rate_per_sec
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
async def send_message(self, msg: bytes):
|
||||
print(f"Sending message: {msg}")
|
||||
for packet, _ in PacketBuilder.build_real_packets(
|
||||
msg, self.global_config.membership
|
||||
):
|
||||
@ -145,26 +151,26 @@ class MixGossipChannel:
|
||||
|
||||
|
||||
class DuplexConnection:
|
||||
inbound: Connection
|
||||
outbound: MixOutboundConnection
|
||||
inbound: SimplexConnection
|
||||
outbound: MixSimplexConnection
|
||||
|
||||
def __init__(self, inbound: Connection, outbound: MixOutboundConnection):
|
||||
def __init__(self, inbound: SimplexConnection, outbound: MixSimplexConnection):
|
||||
self.inbound = inbound
|
||||
self.outbound = outbound
|
||||
|
||||
async def recv(self) -> bytes:
|
||||
return await self.inbound.get()
|
||||
return await self.inbound.recv()
|
||||
|
||||
async def send(self, packet: bytes):
|
||||
await self.outbound.send(packet)
|
||||
|
||||
|
||||
class MixOutboundConnection:
|
||||
class MixSimplexConnection:
|
||||
queue: NetworkPacketQueue
|
||||
conn: Connection
|
||||
transmission_rate_per_sec: int
|
||||
conn: SimplexConnection
|
||||
transmission_rate_per_sec: float
|
||||
|
||||
def __init__(self, conn: Connection, transmission_rate_per_sec: int):
|
||||
def __init__(self, conn: SimplexConnection, transmission_rate_per_sec: float):
|
||||
self.queue = asyncio.Queue()
|
||||
self.conn = conn
|
||||
self.transmission_rate_per_sec = transmission_rate_per_sec
|
||||
@ -178,7 +184,7 @@ class MixOutboundConnection:
|
||||
elem = build_noise_packet()
|
||||
else:
|
||||
elem = self.queue.get_nowait()
|
||||
await self.conn.put(elem)
|
||||
await self.conn.send(elem)
|
||||
|
||||
async def send(self, elem: bytes):
|
||||
await self.queue.put(elem)
|
||||
|
0
mixnet/sim/__init__.py
Normal file
0
mixnet/sim/__init__.py
Normal file
77
mixnet/sim/config.py
Normal file
77
mixnet/sim/config.py
Normal file
@ -0,0 +1,77 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
import dacite
|
||||
import yaml
|
||||
from pysphinx.sphinx import X25519PrivateKey
|
||||
|
||||
from mixnet.config import NodeConfig
|
||||
|
||||
|
||||
@dataclass
|
||||
class Config:
|
||||
simulation: SimulationConfig
|
||||
logic: LogicConfig
|
||||
mixnet: MixnetConfig
|
||||
|
||||
@classmethod
|
||||
def load(cls, yaml_path: str) -> Config:
|
||||
with open(yaml_path, "r") as f:
|
||||
data = yaml.safe_load(f)
|
||||
config = dacite.from_dict(data_class=Config, data=data)
|
||||
|
||||
# Validations
|
||||
config.simulation.validate()
|
||||
config.logic.validate()
|
||||
config.mixnet.validate()
|
||||
|
||||
return config
|
||||
|
||||
|
||||
@dataclass
|
||||
class SimulationConfig:
|
||||
time_scale: float
|
||||
duration_sec: int
|
||||
net_latency_sec: float
|
||||
meter_interval_sec: float
|
||||
|
||||
def validate(self):
|
||||
assert self.time_scale > 0
|
||||
assert self.duration_sec > 0
|
||||
assert self.net_latency_sec > 0
|
||||
assert self.meter_interval_sec > 0
|
||||
|
||||
|
||||
@dataclass
|
||||
class LogicConfig:
|
||||
lottery_interval_sec: float
|
||||
sender_prob: float
|
||||
|
||||
def validate(self):
|
||||
assert self.lottery_interval_sec > 0
|
||||
assert self.sender_prob > 0
|
||||
|
||||
|
||||
@dataclass
|
||||
class MixnetConfig:
|
||||
num_nodes: int
|
||||
transmission_rate_per_sec: int
|
||||
peering_degree: int
|
||||
max_mix_path_length: int
|
||||
|
||||
def validate(self):
|
||||
assert self.num_nodes > 0
|
||||
assert self.transmission_rate_per_sec > 0
|
||||
assert self.peering_degree > 0
|
||||
assert self.max_mix_path_length > 0
|
||||
|
||||
def node_configs(self) -> list[NodeConfig]:
|
||||
return [
|
||||
NodeConfig(
|
||||
X25519PrivateKey.generate(),
|
||||
self.peering_degree,
|
||||
self.transmission_rate_per_sec,
|
||||
)
|
||||
for _ in range(self.num_nodes)
|
||||
]
|
16
mixnet/sim/config.yaml
Normal file
16
mixnet/sim/config.yaml
Normal file
@ -0,0 +1,16 @@
|
||||
simulation:
|
||||
time_scale: 0.001
|
||||
duration_sec: 10000
|
||||
net_latency_sec: 0.01
|
||||
meter_interval_sec: 1
|
||||
|
||||
|
||||
logic:
|
||||
lottery_interval_sec: 1
|
||||
sender_prob: 0.01
|
||||
|
||||
mixnet:
|
||||
num_nodes: 5
|
||||
transmission_rate_per_sec: 10
|
||||
peering_degree: 6
|
||||
max_mix_path_length: 3
|
68
mixnet/sim/connection.py
Normal file
68
mixnet/sim/connection.py
Normal file
@ -0,0 +1,68 @@
|
||||
import asyncio
|
||||
import math
|
||||
import time
|
||||
|
||||
import pandas
|
||||
|
||||
from mixnet.connection import SimplexConnection
|
||||
|
||||
|
||||
class MeteredRemoteSimplexConnection(SimplexConnection):
|
||||
latency: float
|
||||
meter_interval: float
|
||||
outputs: asyncio.Queue
|
||||
conn: asyncio.Queue
|
||||
inputs: asyncio.Queue
|
||||
output_task: asyncio.Task
|
||||
output_meters: list[int]
|
||||
input_task: asyncio.Task
|
||||
input_meters: list[int]
|
||||
|
||||
def __init__(self, latency: float, meter_interval: float):
|
||||
self.latency = latency
|
||||
self.meter_interval = meter_interval
|
||||
self.outputs = asyncio.Queue()
|
||||
self.conn = asyncio.Queue()
|
||||
self.inputs = asyncio.Queue()
|
||||
self.output_meters = []
|
||||
self.output_task = asyncio.create_task(self.__run_output_task())
|
||||
self.input_meters = []
|
||||
self.input_task = asyncio.create_task(self.__run_input_task())
|
||||
|
||||
async def send(self, data: bytes) -> None:
|
||||
await self.outputs.put(data)
|
||||
|
||||
async def recv(self) -> bytes:
|
||||
return await self.inputs.get()
|
||||
|
||||
async def __run_output_task(self):
|
||||
start_time = time.time()
|
||||
while True:
|
||||
data = await self.outputs.get()
|
||||
self.__update_meter(self.output_meters, len(data), start_time)
|
||||
await self.conn.put(data)
|
||||
|
||||
async def __run_input_task(self):
|
||||
start_time = time.time()
|
||||
while True:
|
||||
await asyncio.sleep(self.latency)
|
||||
data = await self.conn.get()
|
||||
self.__update_meter(self.input_meters, len(data), start_time)
|
||||
await self.inputs.put(data)
|
||||
|
||||
def __update_meter(self, meters: list[int], size: int, start_time: float):
|
||||
slot = math.floor((time.time() - start_time) / self.meter_interval)
|
||||
assert slot >= len(meters) - 1
|
||||
meters.extend([0] * (slot - len(meters) + 1))
|
||||
meters[-1] += size
|
||||
|
||||
def output_bandwidths(self) -> pandas.Series:
|
||||
return self.__bandwidths(self.output_meters)
|
||||
|
||||
def input_bandwidths(self) -> pandas.Series:
|
||||
return self.__bandwidths(self.input_meters)
|
||||
|
||||
def __bandwidths(self, meters: list[int]) -> pandas.Series:
|
||||
return pandas.Series(meters, name="bandwidth").map(
|
||||
lambda x: x / self.meter_interval
|
||||
)
|
21
mixnet/sim/main.py
Normal file
21
mixnet/sim/main.py
Normal file
@ -0,0 +1,21 @@
|
||||
import argparse
|
||||
import asyncio
|
||||
|
||||
from mixnet.sim.config import Config
|
||||
from mixnet.sim.simulation import Simulation
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Run mixnet simulation",
|
||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
|
||||
)
|
||||
parser.add_argument(
|
||||
"--config", type=str, required=True, help="Configuration file path"
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
config = Config.load(args.config)
|
||||
sim = Simulation(config)
|
||||
asyncio.run(sim.run())
|
||||
|
||||
print("Simulation complete!")
|
69
mixnet/sim/simulation.py
Normal file
69
mixnet/sim/simulation.py
Normal file
@ -0,0 +1,69 @@
|
||||
import asyncio
|
||||
import random
|
||||
import time
|
||||
|
||||
from mixnet.config import GlobalConfig, MixMembership, NodeInfo
|
||||
from mixnet.node import Node
|
||||
from mixnet.sim.config import Config
|
||||
from mixnet.sim.connection import MeteredRemoteSimplexConnection
|
||||
from mixnet.sim.stats import ConnectionStats
|
||||
|
||||
|
||||
class Simulation:
|
||||
def __init__(self, config: Config):
|
||||
random.seed()
|
||||
self.config = config
|
||||
|
||||
async def run(self):
|
||||
nodes, conn_measurement = self.init_nodes()
|
||||
|
||||
deadline = time.time() + self.scaled_time(self.config.simulation.duration_sec)
|
||||
tasks: list[asyncio.Task] = []
|
||||
for node in nodes:
|
||||
tasks.append(asyncio.create_task(self.run_logic(node, deadline)))
|
||||
await asyncio.gather(*tasks)
|
||||
|
||||
conn_measurement.bandwidths()
|
||||
|
||||
def init_nodes(self) -> tuple[list[Node], ConnectionStats]:
|
||||
node_configs = self.config.mixnet.node_configs()
|
||||
global_config = GlobalConfig(
|
||||
MixMembership(
|
||||
[
|
||||
NodeInfo(node_config.private_key.public_key())
|
||||
for node_config in node_configs
|
||||
]
|
||||
),
|
||||
self.scaled_rate(self.config.mixnet.transmission_rate_per_sec),
|
||||
self.config.mixnet.max_mix_path_length,
|
||||
)
|
||||
nodes = [Node(node_config, global_config) for node_config in node_configs]
|
||||
|
||||
conn_stats = ConnectionStats()
|
||||
for i, node in enumerate(nodes):
|
||||
inbound_conn, outbound_conn = self.create_conn(), self.create_conn()
|
||||
node.connect(nodes[(i + 1) % len(nodes)], inbound_conn, outbound_conn)
|
||||
conn_stats.register(node, inbound_conn, outbound_conn)
|
||||
|
||||
return nodes, conn_stats
|
||||
|
||||
def create_conn(self) -> MeteredRemoteSimplexConnection:
|
||||
return MeteredRemoteSimplexConnection(
|
||||
latency=self.scaled_time(self.config.simulation.net_latency_sec),
|
||||
meter_interval=self.scaled_time(self.config.simulation.meter_interval_sec),
|
||||
)
|
||||
|
||||
async def run_logic(self, node: Node, deadline: float):
|
||||
while time.time() < deadline:
|
||||
await asyncio.sleep(
|
||||
self.scaled_time(self.config.logic.lottery_interval_sec)
|
||||
)
|
||||
|
||||
if random.random() < self.config.logic.sender_prob:
|
||||
await node.send_message(b"selected block")
|
||||
|
||||
def scaled_time(self, time: float) -> float:
|
||||
return time * self.config.simulation.time_scale
|
||||
|
||||
def scaled_rate(self, rate: int) -> float:
|
||||
return float(rate / self.config.simulation.time_scale)
|
52
mixnet/sim/stats.py
Normal file
52
mixnet/sim/stats.py
Normal file
@ -0,0 +1,52 @@
|
||||
import pandas
|
||||
|
||||
from mixnet.node import Node
|
||||
from mixnet.sim.connection import MeteredRemoteSimplexConnection
|
||||
|
||||
NodeConnectionsMap = dict[
|
||||
Node,
|
||||
tuple[list[MeteredRemoteSimplexConnection], list[MeteredRemoteSimplexConnection]],
|
||||
]
|
||||
|
||||
|
||||
class ConnectionStats:
|
||||
conns_per_node: NodeConnectionsMap
|
||||
|
||||
def __init__(self):
|
||||
self.conns_per_node = dict()
|
||||
|
||||
def register(
|
||||
self,
|
||||
node: Node,
|
||||
inbound_conn: MeteredRemoteSimplexConnection,
|
||||
outbound_conn: MeteredRemoteSimplexConnection,
|
||||
):
|
||||
if node not in self.conns_per_node:
|
||||
self.conns_per_node[node] = ([], [])
|
||||
self.conns_per_node[node][0].append(inbound_conn)
|
||||
self.conns_per_node[node][1].append(outbound_conn)
|
||||
|
||||
def bandwidths(self):
|
||||
for i, (_, (inbound_conns, outbound_conns)) in enumerate(
|
||||
self.conns_per_node.items()
|
||||
):
|
||||
inbound_bandwidths = (
|
||||
pandas.concat(
|
||||
[conn.input_bandwidths() for conn in inbound_conns], axis=1
|
||||
)
|
||||
.sum(axis=1)
|
||||
.map(lambda x: x / 1024 / 1024)
|
||||
)
|
||||
outbound_bandwidths = (
|
||||
pandas.concat(
|
||||
[conn.output_bandwidths() for conn in outbound_conns], axis=1
|
||||
)
|
||||
.sum(axis=1)
|
||||
.map(lambda x: x / 1024 / 1024)
|
||||
)
|
||||
|
||||
print(f"=== [Node:{i}] ===")
|
||||
print("--- Inbound bandwidths ---")
|
||||
print(inbound_bandwidths.describe())
|
||||
print("--- Outbound bandwidths ---")
|
||||
print(outbound_bandwidths.describe())
|
Loading…
x
Reference in New Issue
Block a user