import matplotlib.pyplot as plt
import pandas

from mixnet.node import Node
from mixnet.sim.connection import MeteredRemoteSimplexConnection

# Maps each node to a pair of lists: (inbound connections, outbound connections).
NodeConnectionsMap = dict[
    Node,
    tuple[list[MeteredRemoteSimplexConnection], list[MeteredRemoteSimplexConnection]],
]


class ConnectionStats:
    """Collect the metered connections of each node and plot their bandwidths.

    Connections are registered per node with :meth:`register`;
    :meth:`bandwidths` then shows two matplotlib figures, one with a line per
    connection and one with a line per node.
    """

    # node -> (inbound connections, outbound connections)
    conns_per_node: NodeConnectionsMap

    def __init__(self) -> None:
        """Create an empty registry."""
        self.conns_per_node = {}

    def register(
        self,
        node: Node,
        inbound_conn: MeteredRemoteSimplexConnection,
        outbound_conn: MeteredRemoteSimplexConnection,
    ) -> None:
        """Record one inbound and one outbound connection for ``node``.

        The first registration of a node creates its pair of lists; later
        registrations append to them.
        """
        inbound_conns, outbound_conns = self.conns_per_node.setdefault(
            node, ([], [])
        )
        inbound_conns.append(inbound_conn)
        outbound_conns.append(outbound_conn)

    def bandwidths(self) -> None:
        """Show the per-connection figure, then the per-node figure."""
        self._bandwidths_per_conn()
        self._bandwidths_per_node()

    def _bandwidths_per_conn(self) -> None:
        """Plot one line per registered connection (inbound on top, outbound below)."""
        _, (inbound_ax, outbound_ax) = plt.subplots(
            nrows=2, ncols=1, figsize=(12, 6)
        )

        for inbound_conns, outbound_conns in self.conns_per_node.values():
            for conn in inbound_conns:
                # Scaled by 1024 to match the "KB/s" axis label; presumably the
                # meters report bytes per second -- confirm against the
                # connection class.
                inbound_bandwidths = conn.input_bandwidths() / 1024
                inbound_ax.plot(inbound_bandwidths.index, inbound_bandwidths)
            for conn in outbound_conns:
                outbound_bandwidths = conn.output_bandwidths() / 1024
                outbound_ax.plot(outbound_bandwidths.index, outbound_bandwidths)

        self._style_axis(inbound_ax, "Inbound Bandwidths per Connection")
        self._style_axis(outbound_ax, "Outbound Bandwidths per Connection")

        plt.tight_layout()
        plt.show()

    def _bandwidths_per_node(self) -> None:
        """Plot one line per node, summing the bandwidths of its connections.

        Nodes are labelled ``Node-<i>`` where ``i`` is their position in
        ``conns_per_node`` (i.e. the order in which they were first registered).
        """
        _, (inbound_ax, outbound_ax) = plt.subplots(
            nrows=2, ncols=1, figsize=(12, 6)
        )

        for i, (inbound_conns, outbound_conns) in enumerate(
            self.conns_per_node.values()
        ):
            label = f"Node-{i}"
            # Place the connections' series side by side (aligned on their
            # index), then add them up row by row to get the node total.
            inbound_bandwidths = (
                pandas.concat(
                    [conn.input_bandwidths() for conn in inbound_conns], axis=1
                ).sum(axis=1)
                / 1024
            )
            outbound_bandwidths = (
                pandas.concat(
                    [conn.output_bandwidths() for conn in outbound_conns], axis=1
                ).sum(axis=1)
                / 1024
            )
            inbound_ax.plot(
                inbound_bandwidths.index, inbound_bandwidths, label=label
            )
            outbound_ax.plot(
                outbound_bandwidths.index, outbound_bandwidths, label=label
            )

        self._style_axis(inbound_ax, "Inbound Bandwidths per Node", legend=True)
        self._style_axis(outbound_ax, "Outbound Bandwidths per Node", legend=True)

        plt.tight_layout()
        plt.show()

    @staticmethod
    def _style_axis(ax, title: str, *, legend: bool = False) -> None:
        """Apply the title, axis labels, zero-based y-axis and grid shared by all plots.

        Args:
            ax: The matplotlib axes to decorate.
            title: Title shown above the axes.
            legend: Whether to draw a legend for the labelled lines.
        """
        ax.set_title(title)
        ax.set_xlabel("Time (s)")
        ax.set_ylabel("Bandwidth (KB/s)")
        if legend:
            ax.legend()
        ax.set_ylim(bottom=0)
        ax.grid(True)