2024-07-08 16:02:42 +09:00

102 lines
3.3 KiB
Python

import matplotlib.pyplot as plt
import pandas
from mixnet.node import Node
from mixnet.sim.connection import MeteredRemoteSimplexConnection
# For every node, the connections registered through ConnectionStats.register,
# kept as a pair of lists: (inbound connections, outbound connections).
_ConnectionList = list[MeteredRemoteSimplexConnection]
NodeConnectionsMap = dict[Node, tuple[_ConnectionList, _ConnectionList]]
class ConnectionStats:
    """Track each node's simplex connections and plot their bandwidths.

    ``conns_per_node`` maps a node to a pair of lists: inbound connections at
    index 0 and outbound connections at index 1 (populated by ``register``).
    """

    # Divisor applied to raw bandwidth values before plotting. The axes are
    # labeled KB/s, so raw values are presumably bytes/s — confirm against
    # MeteredRemoteSimplexConnection.
    _BYTES_PER_KB = 1024

    conns_per_node: NodeConnectionsMap

    def __init__(self):
        self.conns_per_node = {}

    def register(
        self,
        node: Node,
        inbound_conn: MeteredRemoteSimplexConnection,
        outbound_conn: MeteredRemoteSimplexConnection,
    ):
        """Record one inbound and one outbound connection for ``node``.

        The first call for a node creates its ``([], [])`` pair; later calls
        append to the existing lists.
        """
        inbound_conns, outbound_conns = self.conns_per_node.setdefault(node, ([], []))
        inbound_conns.append(inbound_conn)
        outbound_conns.append(outbound_conn)

    def bandwidths(self):
        """Show the per-connection figure, then the per-node aggregate figure."""
        self._bandwidths_per_conn()
        self._bandwidths_per_node()

    def _bandwidths_per_conn(self):
        """Plot one line per connection: inbound on top, outbound below."""
        _, axs = plt.subplots(nrows=2, ncols=1, figsize=(12, 6))
        for inbound_conns, outbound_conns in self.conns_per_node.values():
            for conn in inbound_conns:
                inbound_bandwidths = self._to_kb(conn.input_bandwidths())
                axs[0].plot(inbound_bandwidths.index, inbound_bandwidths)
            for conn in outbound_conns:
                outbound_bandwidths = self._to_kb(conn.output_bandwidths())
                axs[1].plot(outbound_bandwidths.index, outbound_bandwidths)
        self._format_axis(axs[0], "Inbound Bandwidths per Connection")
        self._format_axis(axs[1], "Outbound Bandwidths per Connection")
        plt.tight_layout()
        plt.show()

    def _bandwidths_per_node(self):
        """Plot each node's summed bandwidth across all of its connections."""
        _, axs = plt.subplots(nrows=2, ncols=1, figsize=(12, 6))
        for i, (inbound_conns, outbound_conns) in enumerate(
            self.conns_per_node.values()
        ):
            label = f"Node-{i}"
            inbound_bandwidths = self._sum_to_kb(
                [conn.input_bandwidths() for conn in inbound_conns]
            )
            outbound_bandwidths = self._sum_to_kb(
                [conn.output_bandwidths() for conn in outbound_conns]
            )
            axs[0].plot(inbound_bandwidths.index, inbound_bandwidths, label=label)
            axs[1].plot(outbound_bandwidths.index, outbound_bandwidths, label=label)
        # With no nodes there are no labeled lines; skipping legend() avoids
        # matplotlib's "No artists with labels found" warning.
        has_lines = bool(self.conns_per_node)
        self._format_axis(axs[0], "Inbound Bandwidths per Node", legend=has_lines)
        self._format_axis(axs[1], "Outbound Bandwidths per Node", legend=has_lines)
        plt.tight_layout()
        plt.show()

    @classmethod
    def _to_kb(cls, bandwidths):
        """Scale a bandwidth series by ``_BYTES_PER_KB`` (vectorized division)."""
        return bandwidths / cls._BYTES_PER_KB

    @classmethod
    def _sum_to_kb(cls, series_list):
        """Sum the series row-wise after index alignment, then scale to KB."""
        # concat(axis=1) aligns the series on their index; sum(axis=1) then
        # adds them per index value (pandas skips NaN by default).
        return cls._to_kb(pandas.concat(series_list, axis=1).sum(axis=1))

    @staticmethod
    def _format_axis(ax, title, legend=False):
        """Apply the shared title, axis labels, y-floor and grid to ``ax``."""
        ax.set_title(title)
        ax.set_xlabel("Time (s)")
        ax.set_ylabel("Bandwidth (KB/s)")
        if legend:
            ax.legend()
        ax.set_ylim(bottom=0)
        ax.grid(True)