diff --git a/mixnet/sim/simulation.py b/mixnet/sim/simulation.py index 98cf187..a575bbb 100644 --- a/mixnet/sim/simulation.py +++ b/mixnet/sim/simulation.py @@ -1,3 +1,4 @@ +import matplotlib.pyplot as plt import usim import mixnet.framework.usim as usimfw @@ -21,6 +22,8 @@ class Simulation: conn_stats, all_node_states = await self._run() conn_stats.bandwidths() all_node_states.analyze() + # Show all plots that have been created so far + plt.show() async def _run(self) -> tuple[ConnectionStats, AllNodeStates]: async with usim.until(usim.time + self.config.simulation.duration_sec) as scope: diff --git a/mixnet/sim/stats.py b/mixnet/sim/stats.py index 2e30ff3..7a51dab 100644 --- a/mixnet/sim/stats.py +++ b/mixnet/sim/stats.py @@ -1,3 +1,4 @@ +import matplotlib.pyplot as plt import pandas from mixnet.node import Node @@ -27,26 +28,44 @@ class ConnectionStats: self.conns_per_node[node][1].append(outbound_conn) def bandwidths(self): - for i, (_, (inbound_conns, outbound_conns)) in enumerate( - self.conns_per_node.items() - ): + plt.figure(figsize=(12, 6)) + + plt.subplot(2, 1, 1) + + for i, (_, (inbound_conns, _)) in enumerate(self.conns_per_node.items()): inbound_bandwidths = ( pandas.concat( [conn.input_bandwidths() for conn in inbound_conns], axis=1 ) .sum(axis=1) - .map(lambda x: x / 1024 / 1024) + .map(lambda x: x / 1024) ) + plt.plot(inbound_bandwidths.index, inbound_bandwidths, label=f"Node-{i}") + + plt.xlabel("Time (s)") + plt.ylabel("Bandwidth (KB/s)") + plt.title("Inbound Bandwidths per Node") + plt.legend() + plt.ylim(bottom=0) + plt.grid(True) + + plt.subplot(2, 1, 2) + + for i, (_, (_, outbound_conns)) in enumerate(self.conns_per_node.items()): outbound_bandwidths = ( pandas.concat( [conn.output_bandwidths() for conn in outbound_conns], axis=1 ) .sum(axis=1) - .map(lambda x: x / 1024 / 1024) + .map(lambda x: x / 1024) ) + plt.plot(outbound_bandwidths.index, outbound_bandwidths, label=f"Node-{i}") - print(f"=== [Node:{i}] ===") - print("--- Inbound bandwidths ---") - print(inbound_bandwidths.describe()) - print("--- Outbound bandwidths ---") - print(outbound_bandwidths.describe()) + plt.xlabel("Time (s)") + plt.ylabel("Bandwidth (KB/s)") + plt.title("Outbound Bandwidths per Node") + plt.legend() + plt.ylim(bottom=0) + plt.grid(True) + + plt.tight_layout()