diff --git a/mixnet/sim/connection.py b/mixnet/sim/connection.py index af555e2..0cba337 100644 --- a/mixnet/sim/connection.py +++ b/mixnet/sim/connection.py @@ -1,4 +1,5 @@ import math +from collections import Counter from typing import Awaitable import pandas @@ -21,6 +22,7 @@ class MeteredRemoteSimplexConnection(SimplexConnection): recv_meters: list[int] send_node_states: list[NodeState] recv_node_states: list[NodeState] + msg_sizes: Counter[int] def __init__( self, @@ -40,9 +42,11 @@ class MeteredRemoteSimplexConnection(SimplexConnection): self.recv_task = framework.spawn(self.__run_recv_task()) self.send_node_states = send_node_states self.recv_node_states = recv_node_states + self.msg_sizes = Counter() async def send(self, data: bytes) -> None: await self.send_queue.put(data) + self.msg_sizes.update([len(data)]) ms = math.floor(self.framework.now() * 1000) self.send_node_states[ms] = NodeState.SENDING diff --git a/mixnet/sim/simulation.py b/mixnet/sim/simulation.py index 98cf187..e1775c2 100644 --- a/mixnet/sim/simulation.py +++ b/mixnet/sim/simulation.py @@ -19,7 +19,7 @@ class Simulation: async def run(self): conn_stats, all_node_states = await self._run() - conn_stats.bandwidths() + conn_stats.analyze() all_node_states.analyze() async def _run(self) -> tuple[ConnectionStats, AllNodeStates]: diff --git a/mixnet/sim/stats.py b/mixnet/sim/stats.py index 1a88918..0e4f745 100644 --- a/mixnet/sim/stats.py +++ b/mixnet/sim/stats.py @@ -1,3 +1,5 @@ +from collections import Counter + import matplotlib.pyplot as plt import pandas @@ -27,10 +29,21 @@ class ConnectionStats: self.conns_per_node[node][0].append(inbound_conn) self.conns_per_node[node][1].append(outbound_conn) - def bandwidths(self): + def analyze(self): + self._message_sizes() self._bandwidths_per_conn() self._bandwidths_per_node() + def _message_sizes(self): + sizes = Counter() + for _, (_, outbound_conns) in self.conns_per_node.items(): + for conn in outbound_conns: + sizes.update(conn.msg_sizes) + + df = pandas.DataFrame.from_dict(sizes, orient="index").reset_index() + df.columns = ["msg_size", "count"] + print(df) + def _bandwidths_per_conn(self): plt.plot(figsize=(12, 6))