add node state collection

This commit is contained in:
Youngjoon Lee 2024-07-05 22:57:46 +09:00
parent 2e58207cf0
commit 1a0c47cfbf
No known key found for this signature in database
GPG Key ID: B4253AFBA618BF4D
7 changed files with 92 additions and 16 deletions

2
.gitignore vendored
View File

@ -1,3 +1,5 @@
.venv
__pycache__
*.csv

View File

@ -19,7 +19,8 @@ class Framework(framework.Framework):
await (usim.time + seconds)
def now(self) -> float:
return usim.time.now
# round to milliseconds to make analysis not too heavy
return int(usim.time.now * 1000) / 1000
def spawn(
self, coroutine: Coroutine[Any, Any, framework.RT]

View File

@ -54,6 +54,10 @@ class NetworkConfig:
assert self.max_latency_sec > 0
assert self.seed is not None
def random_latency(self) -> float:
# round to milliseconds to make analysis not too heavy
return int(self.seed.random() * self.max_latency_sec * 1000) / 1000
@dataclass
class LogicConfig:

View File

@ -1,7 +1,7 @@
simulation:
duration_sec: 1000
network:
max_latency_sec: 0.01
max_latency_sec: 0.1
seed: 0
logic:

View File

@ -6,6 +6,7 @@ import pandas
from mixnet.connection import SimplexConnection
from mixnet.framework.framework import Framework, Queue
from mixnet.sim.config import NetworkConfig
from mixnet.sim.state import NodeState
class MeteredRemoteSimplexConnection(SimplexConnection):
@ -18,10 +19,18 @@ class MeteredRemoteSimplexConnection(SimplexConnection):
output_meters: list[int]
input_task: Awaitable
input_meters: list[int]
output_node_states: list[NodeState]
input_node_states: list[NodeState]
def __init__(self, config: NetworkConfig, framework: Framework):
def __init__(
self,
config: NetworkConfig,
framework: Framework,
output_node_states: list[NodeState],
input_node_states: list[NodeState],
):
self.framework = framework
self.latency = config.seed.random() * config.max_latency_sec
self.latency = config.random_latency()
self.outputs = framework.queue()
self.conn = framework.queue()
self.inputs = framework.queue()
@ -29,12 +38,19 @@ class MeteredRemoteSimplexConnection(SimplexConnection):
self.output_task = framework.spawn(self.__run_output_task())
self.input_meters = []
self.input_task = framework.spawn(self.__run_input_task())
self.output_node_states = output_node_states
self.input_node_states = input_node_states
async def send(self, data: bytes) -> None:
await self.outputs.put(data)
ms = math.floor(self.framework.now() * 1000)
self.output_node_states[ms] = NodeState.SENDING
async def recv(self) -> bytes:
return await self.inputs.get()
data = await self.inputs.get()
ms = math.floor(self.framework.now() * 1000)
self.output_node_states[ms] = NodeState.RECEIVING
return data
async def __run_output_task(self):
start_time = self.framework.now()

View File

@ -6,6 +6,7 @@ from mixnet.framework.framework import Framework
from mixnet.node import Node, PeeringDegreeReached
from mixnet.sim.config import Config
from mixnet.sim.connection import MeteredRemoteSimplexConnection
from mixnet.sim.state import AllNodeStates, NodeState
from mixnet.sim.stats import ConnectionStats
@ -17,19 +18,20 @@ class Simulation:
self.config = config
async def run(self):
conn_stats = await self._run()
conn_stats, all_node_states = await self._run()
conn_stats.bandwidths()
all_node_states.analyze()
async def _run(self) -> ConnectionStats:
async def _run(self) -> tuple[ConnectionStats, AllNodeStates]:
async with usim.until(usim.time + self.config.simulation.duration_sec) as scope:
self.framework = usimfw.Framework(scope)
nodes, conn_stats = self.init_nodes()
nodes, conn_stats, all_node_states = self.init_nodes()
for node in nodes:
self.framework.spawn(self.run_logic(node))
return conn_stats
return conn_stats, all_node_states
assert False # unreachable
def init_nodes(self) -> tuple[list[Node], ConnectionStats]:
def init_nodes(self) -> tuple[list[Node], ConnectionStats, AllNodeStates]:
node_configs = self.config.mixnet.node_configs()
global_config = GlobalConfig(
MixMembership(
@ -47,13 +49,18 @@ class Simulation:
for node_config in node_configs
]
all_node_states = AllNodeStates(len(nodes), self.config.simulation.duration_sec)
conn_stats = ConnectionStats()
for i, node in enumerate(nodes):
peer_idx = (i + 1) % len(nodes)
peer = nodes[peer_idx]
node_states = all_node_states[i]
peer_states = all_node_states[peer_idx]
inbound_conn, outbound_conn = (
self.create_conn(),
self.create_conn(),
self.create_conn(peer_states, node_states),
self.create_conn(node_states, peer_states),
)
peer = nodes[(i + 1) % len(nodes)]
try:
node.connect(peer, inbound_conn, outbound_conn)
except PeeringDegreeReached:
@ -61,11 +68,16 @@ class Simulation:
conn_stats.register(node, inbound_conn, outbound_conn)
conn_stats.register(peer, outbound_conn, inbound_conn)
return nodes, conn_stats
return nodes, conn_stats, all_node_states
def create_conn(self) -> MeteredRemoteSimplexConnection:
def create_conn(
self, sender_states: list[NodeState], receiver_states: list[NodeState]
) -> MeteredRemoteSimplexConnection:
return MeteredRemoteSimplexConnection(
self.config.simulation.network, self.framework
self.config.simulation.network,
self.framework,
sender_states,
receiver_states,
)
async def run_logic(self, node: Node):

41
mixnet/sim/state.py Normal file
View File

@ -0,0 +1,41 @@
from datetime import datetime
from enum import Enum
import pandas
class NodeState(Enum):
SENDING = -1
IDLE = 0
RECEIVING = 1
class AllNodeStates:
_table: list[list[NodeState]]
def __init__(self, num_nodes: int, duration_sec: int):
self._table = [
[NodeState.IDLE] * (duration_sec * 1000) for _ in range(num_nodes)
]
def __getitem__(self, idx: int) -> list[NodeState]:
return self._table[idx]
def analyze(self):
df = pandas.DataFrame(self._table).transpose()
df.columns = [f"Node-{i}" for i in range(len(self._table))]
print(df)
csv_path = f"all_node_states_{datetime.now().isoformat(timespec="seconds")}.csv"
df.to_csv(csv_path)
print(f"\nSaved DataFrame to {csv_path}\n")
# 1. Count the number of each state for each node
state_counts = df.apply(pandas.Series.value_counts).fillna(0)
# 2. Calculate the percentage of each state for each node
state_percentages = state_counts.div(state_counts.sum(axis=0), axis=1) * 100
print("State Counts per Node:")
print(state_counts)
print("\nState Percentages per Node:")
print(state_percentages)