From dbf1b781344e0968f3e85a07094dcad9520fb60f Mon Sep 17 00:00:00 2001 From: Youngjoon Lee <5462944+youngjoon-lee@users.noreply.github.com> Date: Thu, 15 Aug 2024 02:40:06 +0900 Subject: [PATCH] optimize: use generic for messages to reduce the size of msg cache in gossip and discard serde cost --- mixnet/protocol/connection.py | 37 +++++++------- mixnet/protocol/gossip.py | 48 ++++++++++-------- mixnet/protocol/node.py | 90 +++++++++++++++++++-------------- mixnet/protocol/nomssip.py | 95 ++++++++++++++++++----------------- mixnet/protocol/sphinx.py | 12 +++++ mixnet/protocol/test_node.py | 51 +++++++++++++++---- mixnet/queuesim/message.py | 30 +++++++++++ mixnet/queuesim/node.py | 20 +++++--- mixnet/queuesim/queuesim.py | 3 +- mixnet/queuesim/simulation.py | 54 ++++++-------------- mixnet/sim/connection.py | 40 +++++++++------ mixnet/sim/message.py | 32 +++++++++--- mixnet/sim/simulation.py | 53 ++++++++++--------- mixnet/sim/stats.py | 11 ++-- mixnet/sim/test_message.py | 13 +++-- 15 files changed, 352 insertions(+), 237 deletions(-) create mode 100644 mixnet/queuesim/message.py diff --git a/mixnet/protocol/connection.py b/mixnet/protocol/connection.py index a7d1a6c..e5e92fa 100644 --- a/mixnet/protocol/connection.py +++ b/mixnet/protocol/connection.py @@ -1,59 +1,62 @@ from __future__ import annotations import abc +from typing import Generic, TypeVar from framework import Framework, Queue from protocol.temporalmix import TemporalMix, TemporalMixConfig +T = TypeVar("T") -class SimplexConnection(abc.ABC): + +class SimplexConnection(abc.ABC, Generic[T]): """ An abstract class for a simplex connection that can send and receive data in one direction """ @abc.abstractmethod - async def send(self, data: bytes) -> None: + async def send(self, data: T) -> None: pass @abc.abstractmethod - async def recv(self) -> bytes: + async def recv(self) -> T: pass -class LocalSimplexConnection(SimplexConnection): +class LocalSimplexConnection(SimplexConnection[T]): """ A simplex connection that doesn't have any network latency. Data sent through this connection can be immediately received from the other end. """ def __init__(self, framework: Framework): - self.queue: Queue[bytes] = framework.queue() + self.queue: Queue[T] = framework.queue() - async def send(self, data: bytes) -> None: + async def send(self, data: T) -> None: await self.queue.put(data) - async def recv(self) -> bytes: + async def recv(self) -> T: return await self.queue.get() -class DuplexConnection: +class DuplexConnection(Generic[T]): """ A duplex connection in which data can be transmitted and received simultaneously in both directions. This is to mimic duplex communication in a real network (such as TCP or QUIC). """ - def __init__(self, inbound: SimplexConnection, outbound: SimplexConnection): + def __init__(self, inbound: SimplexConnection[T], outbound: SimplexConnection[T]): self.inbound = inbound self.outbound = outbound - async def recv(self) -> bytes: + async def recv(self) -> T: return await self.inbound.recv() - async def send(self, packet: bytes): + async def send(self, packet: T): await self.outbound.send(packet) -class MixSimplexConnection(SimplexConnection): +class MixSimplexConnection(SimplexConnection[T]): """ Wraps a SimplexConnection to add a transmission rate and noise to the connection. """ @@ -61,16 +64,16 @@ class MixSimplexConnection(SimplexConnection): def __init__( self, framework: Framework, - conn: SimplexConnection, + conn: SimplexConnection[T], transmission_rate_per_sec: int, - noise_msg: bytes, + noise_msg: T, temporal_mix_config: TemporalMixConfig, # OPTIMIZATION ONLY FOR EXPERIMENTS WITHOUT BANDWIDTH MEASUREMENT # If True, skip sending a noise even if it's time to send one. skip_sending_noise: bool, ): self.framework = framework - self.queue: Queue[bytes] = TemporalMix.queue( + self.queue: Queue[T] = TemporalMix.queue( temporal_mix_config, framework, noise_msg ) self.conn = conn @@ -87,8 +90,8 @@ class MixSimplexConnection(SimplexConnection): continue await self.conn.send(msg) - async def send(self, data: bytes) -> None: + async def send(self, data: T) -> None: await self.queue.put(data) - async def recv(self) -> bytes: + async def recv(self) -> T: return await self.conn.recv() diff --git a/mixnet/protocol/gossip.py b/mixnet/protocol/gossip.py index 01d0e49..ce8e25c 100644 --- a/mixnet/protocol/gossip.py +++ b/mixnet/protocol/gossip.py @@ -1,8 +1,7 @@ from __future__ import annotations -import hashlib from dataclasses import dataclass -from typing import Awaitable, Callable +from typing import Awaitable, Callable, Generic, Protocol, TypeVar from framework import Framework from protocol.connection import ( @@ -18,7 +17,14 @@ class GossipConfig: peering_degree: int -class Gossip: +class HasId(Protocol): + def id(self) -> int: ... + + +T = TypeVar("T", bound=HasId) + + +class Gossip(Generic[T]): """ A gossip channel that broadcasts messages to all connected peers. Peers are connected via DuplexConnection. @@ -28,15 +34,15 @@ class Gossip: self, framework: Framework, config: GossipConfig, - handler: Callable[[bytes], Awaitable[None]], + handler: Callable[[T], Awaitable[None]], ): self.framework = framework self.config = config - self.conns: list[DuplexConnection] = [] + self.conns: list[DuplexConnection[T]] = [] # A handler to process inbound messages. self.handler = handler - # msg -> received_cnt - self.packet_cache: dict[bytes, int] = dict() + # msg_id -> received_cnt + self.packet_cache: dict[int, int] = dict() # A set just for gathering a reference of tasks to prevent them from being garbage collected. # https://docs.python.org/3/library/asyncio-task.html#asyncio.create_task self.tasks: set[Awaitable] = set() @@ -44,12 +50,12 @@ class Gossip: def can_accept_conn(self) -> bool: return len(self.conns) < self.config.peering_degree - def add_conn(self, inbound: SimplexConnection, outbound: SimplexConnection): + def add_conn(self, inbound: SimplexConnection[T], outbound: SimplexConnection[T]): if not self.can_accept_conn(): # For simplicity of the spec, reject the connection if the peering degree is reached. raise PeeringDegreeReached() - conn = DuplexConnection( + conn = DuplexConnection[T]( inbound, outbound, ) @@ -57,18 +63,18 @@ class Gossip: task = self.framework.spawn(self.__process_inbound_conn(conn)) self.tasks.add(task) - async def __process_inbound_conn(self, conn: DuplexConnection): + async def __process_inbound_conn(self, conn: DuplexConnection[T]): while True: msg = await conn.recv() if self._check_update_cache(msg): continue await self._process_inbound_msg(msg, conn) - async def _process_inbound_msg(self, msg: bytes, received_from: DuplexConnection): + async def _process_inbound_msg(self, msg: T, received_from: DuplexConnection[T]): await self._gossip(msg, [received_from]) await self.handler(msg) - async def publish(self, msg: bytes): + async def publish(self, msg: T): """ Publish a message to all nodes in the network. """ @@ -83,7 +89,7 @@ class Gossip: # which means that we consider that this publisher node received the message. await self.handler(msg) - async def _gossip(self, msg: bytes, excludes: list[DuplexConnection] = []): + async def _gossip(self, msg: T, excludes: list[DuplexConnection] = []): """ Gossip a message to all peers connected to this node. """ @@ -91,26 +97,26 @@ class Gossip: if conn not in excludes: await conn.send(msg) - def _check_update_cache(self, packet: bytes, publishing: bool = False) -> bool: + def _check_update_cache(self, msg: T, publishing: bool = False) -> bool: """ Add a message to the cache, and return True if the message was already in the cache. """ - hash = hashlib.sha256(packet).digest() - seen = hash in self.packet_cache + id = msg.id() + seen = id in self.packet_cache if publishing: if not seen: # Put 0 when publishing, so that the publisher node doesn't gossip the message again # even when it first receive the message from one of its peers later. - self.packet_cache[hash] = 0 + self.packet_cache[id] = 0 else: if not seen: - self.packet_cache[hash] = 1 + self.packet_cache[id] = 1 else: - self.packet_cache[hash] += 1 + self.packet_cache[id] += 1 # Remove the message from the cache if it's received from all adjacent peers in the end # to reduce the size of cache. - if self.packet_cache[hash] >= self.config.peering_degree: - del self.packet_cache[hash] + if self.packet_cache[id] >= self.config.peering_degree: + del self.packet_cache[id] return seen diff --git a/mixnet/protocol/node.py b/mixnet/protocol/node.py index 4b6d77f..490698e 100644 --- a/mixnet/protocol/node.py +++ b/mixnet/protocol/node.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import Awaitable, Callable +from typing import Awaitable, Callable, Generic, Protocol, Self, Type, TypeVar from pysphinx.sphinx import ( ProcessedFinalHopPacket, @@ -13,11 +13,22 @@ from protocol.config import GlobalConfig, NodeConfig from protocol.connection import SimplexConnection from protocol.error import PeeringDegreeReached from protocol.gossip import Gossip -from protocol.nomssip import Nomssip, NomssipConfig +from protocol.nomssip import Nomssip, NomssipConfig, NomssipMessage from protocol.sphinx import SphinxPacketBuilder -class Node: +class HasIdAndLenAndBytes(Protocol): + def id(self) -> int: ... + def __len__(self) -> int: ... + def __bytes__(self) -> bytes: ... + @classmethod + def from_bytes(cls, data: bytes) -> Self: ... + + +T = TypeVar("T", bound=HasIdAndLenAndBytes) + + +class Node(Generic[T]): """ This represents any node in the network, which: - generates/gossips mix messages (Sphinx packets) @@ -31,57 +42,53 @@ class Node: config: NodeConfig, global_config: GlobalConfig, # A handler called when a node receives a broadcasted message originated from the last mix. - broadcasted_msg_handler: Callable[[bytes], Awaitable[None]], - # An optional handler only for the simulation, - # which is called when a message is fully recovered by the last mix + broadcasted_msg_handler: Callable[[T], Awaitable[None]], + # A handler called when a message is fully recovered by the last mix # and returns a new message to be broadcasted. - recovered_msg_handler: Callable[[bytes], Awaitable[bytes]] | None = None, + recovered_msg_handler: Callable[[bytes], Awaitable[T]], + noise_msg: T, ): self.framework = framework self.config = config self.global_config = global_config + nomssip_config = NomssipConfig( + config.gossip.peering_degree, + global_config.transmission_rate_per_sec, + SphinxPacketBuilder.size(global_config), + config.temporal_mix, + ) self.nomssip = Nomssip( framework, - NomssipConfig( - config.gossip.peering_degree, - global_config.transmission_rate_per_sec, - self.__calculate_message_size(global_config), - config.temporal_mix, - ), + nomssip_config, self.__process_msg, + noise_msg=NomssipMessage[T](NomssipMessage.Flag.NOISE, noise_msg), ) - self.broadcast = Gossip(framework, config.gossip, broadcasted_msg_handler) + self.broadcast = Gossip[T](framework, config.gossip, broadcasted_msg_handler) self.recovered_msg_handler = recovered_msg_handler - @staticmethod - def __calculate_message_size(global_config: GlobalConfig) -> int: - """ - Calculate the actual message size to be gossiped, which depends on the maximum length of mix path. - """ - sample_sphinx_packet, _ = SphinxPacketBuilder.build( - bytes(global_config.max_message_size), - global_config, - global_config.max_mix_path_length, - ) - return len(sample_sphinx_packet.bytes()) - - async def __process_msg(self, msg: bytes) -> None: + async def __process_msg(self, msg: NomssipMessage[T]) -> None: """ A handler to process messages received via Nomssip channel """ + assert msg.flag == NomssipMessage.Flag.REAL + sphinx_packet = SphinxPacket.from_bytes( - msg, self.global_config.max_mix_path_length + bytes(msg.message), self.global_config.max_mix_path_length ) result = await self.__process_sphinx_packet(sphinx_packet) match result: case SphinxPacket(): # Gossip the next Sphinx packet - await self.nomssip.publish(result.bytes()) + t: Type[T] = type(msg.message) + await self.nomssip.publish( + NomssipMessage[T]( + NomssipMessage.Flag.REAL, + t.from_bytes(result.bytes()), + ) + ) case bytes(): - if self.recovered_msg_handler is not None: - result = await self.recovered_msg_handler(result) # Broadcast the message fully recovered from Sphinx packets - await self.broadcast.publish(result) + await self.broadcast.publish(await self.recovered_msg_handler(result)) case None: return @@ -105,31 +112,36 @@ class Node: def connect_mix( self, peer: Node, - inbound_conn: SimplexConnection, - outbound_conn: SimplexConnection, + inbound_conn: SimplexConnection[NomssipMessage[T]], + outbound_conn: SimplexConnection[NomssipMessage[T]], ): connect_nodes(self.nomssip, peer.nomssip, inbound_conn, outbound_conn) def connect_broadcast( self, peer: Node, - inbound_conn: SimplexConnection, - outbound_conn: SimplexConnection, + inbound_conn: SimplexConnection[T], + outbound_conn: SimplexConnection[T], ): connect_nodes(self.broadcast, peer.broadcast, inbound_conn, outbound_conn) - async def send_message(self, msg: bytes): + async def send_message(self, msg: T): """ Build a Sphinx packet and gossip it to all connected peers. """ # Here, we handle the case in which a msg is split into multiple Sphinx packets. # But, in practice, we expect a message to be small enough to fit in a single Sphinx packet. sphinx_packet, _ = SphinxPacketBuilder.build( - msg, + bytes(msg), self.global_config, self.config.mix_path_length, ) - await self.nomssip.publish(sphinx_packet.bytes()) + t: Type[T] = type(msg) + await self.nomssip.publish( + NomssipMessage( + NomssipMessage.Flag.REAL, t.from_bytes(sphinx_packet.bytes()) + ) + ) def connect_nodes( diff --git a/mixnet/protocol/nomssip.py b/mixnet/protocol/nomssip.py index 2cd8c9b..e3288f5 100644 --- a/mixnet/protocol/nomssip.py +++ b/mixnet/protocol/nomssip.py @@ -1,8 +1,6 @@ -from __future__ import annotations - from dataclasses import dataclass from enum import Enum -from typing import Awaitable, Callable, Self, override +from typing import Awaitable, Callable, Generic, Protocol, TypeVar, override from framework import Framework from protocol.connection import ( @@ -24,7 +22,31 @@ class NomssipConfig(GossipConfig): skip_sending_noise: bool = False -class Nomssip(Gossip): +class HasIdAndLen(Protocol): + def id(self) -> int: ... + def __len__(self) -> int: ... + + +T = TypeVar("T", bound=HasIdAndLen) + + +class NomssipMessage(Generic[T]): + class Flag(Enum): + REAL = b"\x00" + NOISE = b"\x01" + + def __init__(self, flag: Flag, message: T): + self.flag = flag + self.message = message + + def id(self) -> int: + return self.message.id() + + def __len__(self) -> int: + return len(self.flag.value) + len(self.message) + + +class Nomssip(Gossip[NomssipMessage[T]]): """ A NomMix gossip channel that extends the Gossip channel by adding global transmission rate and noise generation. @@ -34,72 +56,53 @@ class Nomssip(Gossip): self, framework: Framework, config: NomssipConfig, - handler: Callable[[bytes], Awaitable[None]], + handler: Callable[[NomssipMessage[T]], Awaitable[None]], + noise_msg: NomssipMessage[T], ): super().__init__(framework, config, handler) self.config = config + self.noise_msg = noise_msg @override - def add_conn(self, inbound: SimplexConnection, outbound: SimplexConnection): - noise_packet = FlaggedPacket( - FlaggedPacket.Flag.NOISE, bytes(self.config.msg_size) - ).bytes() + def add_conn( + self, + inbound: SimplexConnection[NomssipMessage[T]], + outbound: SimplexConnection[NomssipMessage[T]], + ): super().add_conn( inbound, - MixSimplexConnection( + MixSimplexConnection[NomssipMessage[T]]( self.framework, outbound, self.config.transmission_rate_per_sec, - noise_packet, + self.noise_msg, self.config.temporal_mix, self.config.skip_sending_noise, ), ) @override - async def _process_inbound_msg(self, msg: bytes, received_from: DuplexConnection): - packet = FlaggedPacket.from_bytes(msg) - match packet.flag: - case FlaggedPacket.Flag.NOISE: + async def _process_inbound_msg( + self, msg: NomssipMessage[T], received_from: DuplexConnection + ): + match msg.flag: + case NomssipMessage.Flag.NOISE: # Drop noise packet return - case FlaggedPacket.Flag.REAL: - self.assert_message_size(packet.message) + case NomssipMessage.Flag.REAL: + self.assert_message_size(msg.message) await super()._gossip(msg, [received_from]) - await self.handler(packet.message) + await self.handler(msg) @override - async def publish(self, msg: bytes): - self.assert_message_size(msg) + async def publish(self, msg: NomssipMessage[T]): + self.assert_message_size(msg.message) - packet = FlaggedPacket(FlaggedPacket.Flag.REAL, msg).bytes() # Please see comments in super().publish() for the reason of the following line. - if not self._check_update_cache(packet, publishing=True): - await self._gossip(packet) + if not self._check_update_cache(msg, publishing=True): + await self._gossip(msg) await self.handler(msg) - def assert_message_size(self, msg: bytes): + def assert_message_size(self, msg: T): # The message size must be fixed. assert len(msg) == self.config.msg_size, f"{len(msg)} != {self.config.msg_size}" - - -class FlaggedPacket: - class Flag(Enum): - REAL = b"\x00" - NOISE = b"\x01" - - def __init__(self, flag: Flag, message: bytes): - self.flag = flag - self.message = message - - def bytes(self) -> bytes: - return self.flag.value + self.message - - @classmethod - def from_bytes(cls, packet: bytes) -> Self: - """ - Parse a flagged packet from bytes - """ - if len(packet) < 1: - raise ValueError("Invalid message format") - return cls(cls.Flag(packet[:1]), packet[1:]) diff --git a/mixnet/protocol/sphinx.py b/mixnet/protocol/sphinx.py index a219e77..8ca431f 100644 --- a/mixnet/protocol/sphinx.py +++ b/mixnet/protocol/sphinx.py @@ -31,3 +31,15 @@ class SphinxPacketBuilder: max_plain_payload_size=global_config.max_message_size, ) return (packet, route) + + @staticmethod + def size(global_config: GlobalConfig) -> int: + """ + Calculate the size of Sphinx packet, which depends on the maximum length of mix path. + """ + sample_sphinx_packet, _ = SphinxPacketBuilder.build( + bytes(global_config.max_message_size), + global_config, + global_config.max_mix_path_length, + ) + return len(sample_sphinx_packet.bytes()) diff --git a/mixnet/protocol/test_node.py b/mixnet/protocol/test_node.py index 00f2e0b..6190f7b 100644 --- a/mixnet/protocol/test_node.py +++ b/mixnet/protocol/test_node.py @@ -1,9 +1,13 @@ +import hashlib +from dataclasses import dataclass +from typing import Self from unittest import IsolatedAsyncioTestCase import framework.asyncio as asynciofw from framework.framework import Queue from protocol.connection import LocalSimplexConnection from protocol.node import Node +from protocol.nomssip import NomssipMessage from protocol.test_utils import ( init_mixnet_config, ) @@ -14,31 +18,42 @@ class TestNode(IsolatedAsyncioTestCase): framework = asynciofw.Framework() global_config, node_configs, _ = init_mixnet_config(10) - queue: Queue[bytes] = framework.queue() + queue: Queue[Message] = framework.queue() - async def broadcasted_msg_handler(msg: bytes) -> None: + async def broadcasted_msg_handler(msg: Message) -> None: await queue.put(msg) + async def recovered_msg_handler(msg: bytes) -> Message: + return Message(msg) + nodes = [ - Node(framework, node_config, global_config, broadcasted_msg_handler) + Node[Message]( + framework, + node_config, + global_config, + broadcasted_msg_handler, + recovered_msg_handler, + noise_msg=Message(b""), + ) for node_config in node_configs ] for i, node in enumerate(nodes): try: node.connect_mix( nodes[(i + 1) % len(nodes)], - LocalSimplexConnection(framework), - LocalSimplexConnection(framework), + LocalSimplexConnection[NomssipMessage[Message]](framework), + LocalSimplexConnection[NomssipMessage[Message]](framework), ) node.connect_broadcast( nodes[(i + 1) % len(nodes)], - LocalSimplexConnection(framework), - LocalSimplexConnection(framework), + LocalSimplexConnection[Message](framework), + LocalSimplexConnection[Message](framework), ) except ValueError as e: print(e) - await nodes[0].send_message(b"block selection") + msg = Message(b"block selection") + await nodes[0].send_message(msg) # Wait for all nodes to receive the broadcast num_nodes_received_broadcast = 0 @@ -47,7 +62,7 @@ class TestNode(IsolatedAsyncioTestCase): await framework.sleep(1) while not queue.empty(): - self.assertEqual(b"block selection", await queue.get()) + self.assertEqual(msg, await queue.get()) num_nodes_received_broadcast += 1 if num_nodes_received_broadcast == len(nodes): @@ -56,3 +71,21 @@ class TestNode(IsolatedAsyncioTestCase): self.assertEqual(len(nodes), num_nodes_received_broadcast) # TODO: check noise + + +@dataclass +class Message: + data: bytes + + def id(self) -> int: + return int.from_bytes(hashlib.sha256(self.data).digest(), byteorder="big") + + def __len__(self) -> int: + return len(self.data) + + def __bytes__(self) -> bytes: + return self.data + + @classmethod + def from_bytes(cls, data: bytes) -> Self: + return cls(data) diff --git a/mixnet/queuesim/message.py b/mixnet/queuesim/message.py new file mode 100644 index 0000000..5241977 --- /dev/null +++ b/mixnet/queuesim/message.py @@ -0,0 +1,30 @@ +from dataclasses import dataclass + +from framework.framework import Framework + +MESSAGE_SIZE = 1 + + +@dataclass +class Message: + _id: int + sent_time: float + + def id(self) -> int: + return self._id + + def __len__(self) -> int: + # Return any number here, since we don't use Sphinx encoding for queuesim and byte serialization. + # This must be matched with NomssipConfig.msg_size. + return MESSAGE_SIZE + + +class MessageBuilder: + def __init__(self, framework: Framework): + self.framework = framework + self.next_id = 0 + + def next(self) -> Message: + msg = Message(self.next_id, self.framework.now()) + self.next_id += 1 + return msg diff --git a/mixnet/queuesim/node.py b/mixnet/queuesim/node.py index f9bbdf6..b3470cc 100644 --- a/mixnet/queuesim/node.py +++ b/mixnet/queuesim/node.py @@ -5,7 +5,8 @@ from typing import Awaitable, Callable from framework.framework import Framework from protocol.connection import SimplexConnection from protocol.node import connect_nodes -from protocol.nomssip import Nomssip, NomssipConfig +from protocol.nomssip import Nomssip, NomssipConfig, NomssipMessage +from queuesim.message import Message class Node: @@ -13,20 +14,25 @@ class Node: self, framework: Framework, nomssip_config: NomssipConfig, - msg_handler: Callable[[bytes], Awaitable[None]], + msg_handler: Callable[[NomssipMessage[Message]], Awaitable[None]], ): - self.nomssip = Nomssip(framework, nomssip_config, msg_handler) + self.nomssip = Nomssip( + framework, + nomssip_config, + msg_handler, + noise_msg=NomssipMessage(NomssipMessage.Flag.NOISE, Message(-1, 0)), + ) def connect( self, peer: Node, - inbound_conn: SimplexConnection, - outbound_conn: SimplexConnection, + inbound_conn: SimplexConnection[NomssipMessage[Message]], + outbound_conn: SimplexConnection[NomssipMessage[Message]], ): connect_nodes(self.nomssip, peer.nomssip, inbound_conn, outbound_conn) - async def send_message(self, msg: bytes): + async def send_message(self, msg: Message): """ Send the message via Nomos Gossip to all connected peers. """ - await self.nomssip.publish(msg) + await self.nomssip.publish(NomssipMessage(NomssipMessage.Flag.REAL, msg)) diff --git a/mixnet/queuesim/queuesim.py b/mixnet/queuesim/queuesim.py index be01aa3..b93e861 100644 --- a/mixnet/queuesim/queuesim.py +++ b/mixnet/queuesim/queuesim.py @@ -16,6 +16,7 @@ import usim from protocol.nomssip import NomssipConfig from protocol.temporalmix import TemporalMixConfig, TemporalMixType from queuesim.config import Config +from queuesim.message import MESSAGE_SIZE from queuesim.paramset import ( EXPERIMENT_TITLES, ExperimentID, @@ -32,7 +33,7 @@ DEFAULT_CONFIG = Config( nomssip=NomssipConfig( peering_degree=3, transmission_rate_per_sec=10, - msg_size=8, + msg_size=MESSAGE_SIZE, temporal_mix=TemporalMixConfig( mix_type=TemporalMixType.NONE, min_queue_size=10, diff --git a/mixnet/queuesim/simulation.py b/mixnet/queuesim/simulation.py index 9fdf6c9..2d799bf 100644 --- a/mixnet/queuesim/simulation.py +++ b/mixnet/queuesim/simulation.py @@ -1,7 +1,5 @@ import csv -import struct -from dataclasses import dataclass -from typing import Counter, Self +from typing import Counter import pandas as pd import usim @@ -9,7 +7,9 @@ import usim from framework.framework import Queue from framework.usim import Framework from protocol.connection import LocalSimplexConnection, SimplexConnection +from protocol.nomssip import NomssipMessage from queuesim.config import Config +from queuesim.message import Message, MessageBuilder from queuesim.node import Node from sim.connection import RemoteSimplexConnection from sim.topology import build_full_random_topology @@ -31,7 +31,7 @@ class Simulation: self.framework.stop_tasks() async def __run(self, out_csv_path: str, topology_path: str): - self.received_msg_queue: Queue[tuple[float, bytes]] = self.framework.queue() + self.received_msg_queue: Queue[tuple[float, Message]] = self.framework.queue() # Run and connect nodes nodes = self.__run_nodes() @@ -48,7 +48,7 @@ class Simulation: writer = csv.writer(f) writer.writerow(["dissemination_time", "sent_time", "all_received_time"]) # To count how many nodes have received each message - received_msg_counters: Counter[bytes] = Counter() + received_msg_counters: Counter[int] = Counter() # To count how many results (dissemination time) have been collected so far result_cnt = 0 # Wait until all messages are disseminated to the entire network. @@ -56,13 +56,16 @@ class Simulation: # Wait until a node notifies that it has received a new message. received_time, msg = await self.received_msg_queue.get() # If the message has been received by all nodes, calculate the dissemination time. - received_msg_counters.update([msg]) - if received_msg_counters[msg] == len(nodes): - sent_time = Message.from_bytes(msg).sent_time - dissemination_time = received_time - sent_time + received_msg_counters.update([msg.id()]) + if received_msg_counters[msg.id()] == len(nodes): + dissemination_time = received_time - msg.sent_time # Use repr to convert a float to a string with as much precision as Python can provide writer.writerow( - [repr(dissemination_time), repr(sent_time), repr(received_time)] + [ + repr(dissemination_time), + repr(msg.sent_time), + repr(received_time), + ] ) result_cnt += 1 @@ -76,13 +79,13 @@ class Simulation: for _ in range(self.config.num_nodes) ] - async def __process_msg(self, msg: bytes) -> None: + async def __process_msg(self, msg: NomssipMessage[Message]) -> None: """ A handler to process messages received via Nomos Gossip channel """ # Notify that a new message has been received by the node. # The received time is also included in the notification. - await self.received_msg_queue.put((self.framework.now(), msg)) + await self.received_msg_queue.put((self.framework.now(), msg.message)) def __connect_nodes(self, nodes: list[Node], topology_path: str): topology = build_full_random_topology( @@ -127,30 +130,5 @@ class Simulation: for i in range(self.config.num_sent_msgs): if i > 0: await self.framework.sleep(self.config.msg_interval_sec) - msg = bytes(self.message_builder.next()) + msg = self.message_builder.next() await sender.send_message(msg) - - -@dataclass -class Message: - id: int - sent_time: float - - def __bytes__(self) -> bytes: - return struct.pack("if", self.id, self.sent_time) - - @classmethod - def from_bytes(cls, data: bytes) -> Self: - id, sent_from = struct.unpack("if", data) - return cls(id, sent_from) - - -class MessageBuilder: - def __init__(self, framework: Framework): - self.framework = framework - self.next_id = 0 - - def next(self) -> Message: - msg = Message(self.next_id, self.framework.now()) - self.next_id += 1 - return msg diff --git a/mixnet/sim/connection.py b/mixnet/sim/connection.py index 1296c63..0420459 100644 --- a/mixnet/sim/connection.py +++ b/mixnet/sim/connection.py @@ -1,18 +1,19 @@ import math -from abc import abstractmethod from collections import Counter -from typing import Awaitable +from typing import Protocol, TypeVar import pandas from typing_extensions import override from framework import Framework, Queue from protocol.connection import SimplexConnection -from sim.config import LatencyConfig, NetworkConfig +from sim.config import LatencyConfig from sim.state import NodeState +T = TypeVar("T") -class RemoteSimplexConnection(SimplexConnection): + +class RemoteSimplexConnection(SimplexConnection[T]): """ A simplex connection implementation that simulates network latency. """ @@ -22,18 +23,18 @@ class RemoteSimplexConnection(SimplexConnection): # A connection has a random constant latency self.latency = config.random_latency() # A queue of tuple(timestamp, msg) where a sender puts messages to be sent - self.send_queue: Queue[tuple[float, bytes]] = framework.queue() + self.send_queue: Queue[tuple[float, T]] = framework.queue() # A task that reads messages from send_queue, and puts them to recv_queue. # Before putting messages to recv_queue, the task simulates network latency according to the timestamp of each message. self.relayer = framework.spawn(self.__run_relayer()) # A queue where a receiver gets messages - self.recv_queue: Queue[bytes] = framework.queue() + self.recv_queue: Queue[T] = framework.queue() - async def send(self, data: bytes) -> None: + async def send(self, data: T) -> None: await self.send_queue.put((self.framework.now(), data)) self.on_sending(data) - async def recv(self) -> bytes: + async def recv(self) -> T: return await self.recv_queue.get() async def __run_relayer(self): @@ -54,16 +55,23 @@ class RemoteSimplexConnection(SimplexConnection): self.on_receiving(data) await self.recv_queue.put(data) - def on_sending(self, data: bytes) -> None: + def on_sending(self, data: T) -> None: # Should be overridden by subclass pass - def on_receiving(self, data: bytes) -> None: + def on_receiving(self, data: T) -> None: # Should be overridden by subclass pass -class MeteredRemoteSimplexConnection(RemoteSimplexConnection): +class HasLen(Protocol): + def __len__(self) -> int: ... + + +TL = TypeVar("TL", bound=HasLen) + + +class MeteredRemoteSimplexConnection(RemoteSimplexConnection[TL]): """ An extension of RemoteSimplexConnection that measures bandwidth usages. """ @@ -81,14 +89,14 @@ class MeteredRemoteSimplexConnection(RemoteSimplexConnection): self.recv_meters: list[int] = [] @override - def on_sending(self, data: bytes) -> None: + def on_sending(self, data: TL) -> None: """ Update statistics when sending a message """ self.__update_meter(self.send_meters, len(data)) @override - def on_receiving(self, data: bytes) -> None: + def on_receiving(self, data: TL) -> None: """ Update statistics when receiving a message """ @@ -120,7 +128,7 @@ class MeteredRemoteSimplexConnection(RemoteSimplexConnection): return pandas.Series(meters, name="bandwidth") -class ObservedMeteredRemoteSimplexConnection(MeteredRemoteSimplexConnection): +class ObservedMeteredRemoteSimplexConnection(MeteredRemoteSimplexConnection[TL]): """ An extension of MeteredRemoteSimplexConnection that is observed by passive observer. The observer monitors the node states of the sender and receiver and message sizes. @@ -143,13 +151,13 @@ class ObservedMeteredRemoteSimplexConnection(MeteredRemoteSimplexConnection): self.msg_sizes: Counter[int] = Counter() @override - def on_sending(self, data: bytes) -> None: + def on_sending(self, data: TL) -> None: super().on_sending(data) self.__update_node_state(self.send_node_states, NodeState.SENDING) self.msg_sizes.update([len(data)]) @override - def on_receiving(self, data: bytes) -> None: + def on_receiving(self, data: TL) -> None: super().on_receiving(data) self.__update_node_state(self.recv_node_states, NodeState.RECEIVING) diff --git a/mixnet/sim/message.py b/mixnet/sim/message.py index cfc939a..df4b42f 100644 --- a/mixnet/sim/message.py +++ b/mixnet/sim/message.py @@ -1,3 +1,4 @@ +import hashlib import pickle from dataclasses import dataclass from typing import Self @@ -8,7 +9,29 @@ class Message: """ A message structure for simulation, which will be sent through mix nodes and eventually broadcasted to all nodes in the network. + """ + # The bytes of Sphinx packet + data: bytes + + def id(self) -> int: + return int.from_bytes(hashlib.sha256(self.data).digest(), byteorder="big") + + def __len__(self) -> int: + return len(self.data) + + def __bytes__(self) -> bytes: + return self.data + + @classmethod + def from_bytes(cls, data: bytes) -> Self: + return cls(data) + + +@dataclass +class InnerMessage: + """ + The inner message that is wrapped by Sphinx packet. The `id` must ensure the uniqueness of the message. """ @@ -23,11 +46,8 @@ class Message: def from_bytes(cls, data: bytes) -> Self: return pickle.loads(data) - def __hash__(self) -> int: - return self.id - -class UniqueMessageBuilder: +class UniqueInnerMessageBuilder: """ Builds a unique message with an incremental ID, assuming that the simulation is run in a single thread. @@ -36,7 +56,7 @@ class UniqueMessageBuilder: def __init__(self): self.next_id = 0 - def next(self, created_at: float, body: bytes) -> Message: - msg = Message(created_at, self.next_id, body) + def next(self, created_at: float, body: bytes) -> InnerMessage: + msg = InnerMessage(created_at, self.next_id, body) self.next_id += 1 return msg diff --git a/mixnet/sim/simulation.py b/mixnet/sim/simulation.py index 84453c4..7aebd02 100644 --- a/mixnet/sim/simulation.py +++ b/mixnet/sim/simulation.py @@ -1,20 +1,19 @@ -from dataclasses import asdict, dataclass from pprint import pprint -from typing import Self import usim from matplotlib import pyplot import framework.usim as usimfw -from framework import Framework from protocol.config import GlobalConfig, MixMembership, NodeInfo -from protocol.node import Node, PeeringDegreeReached +from protocol.node import Node +from protocol.nomssip import NomssipMessage +from protocol.sphinx import SphinxPacketBuilder from sim.config import Config from sim.connection import ( MeteredRemoteSimplexConnection, ObservedMeteredRemoteSimplexConnection, ) -from sim.message import Message, UniqueMessageBuilder +from sim.message import InnerMessage, Message, UniqueInnerMessageBuilder from sim.state import NodeState, NodeStateTable from sim.stats import ConnectionStats, DisseminationTime from sim.topology import build_full_random_topology @@ -27,7 +26,7 @@ class Simulation: def __init__(self, config: Config): self.config = config - self.msg_builder = UniqueMessageBuilder() + self.inner_msg_builder = UniqueInnerMessageBuilder() self.dissemination_time = DisseminationTime(self.config.network.num_nodes) async def run(self): @@ -61,7 +60,7 @@ class Simulation: # Return analysis tools once the μSim scope is done return conn_stats, node_state_table - def __init_nodes(self) -> list[Node]: + def __init_nodes(self) -> list[Node[Message]]: # Initialize node/global configurations node_configs = self.config.node_configs() global_config = GlobalConfig( @@ -78,20 +77,22 @@ class Simulation: ) # Initialize/return Node instances + noise_msg = Message(bytes(SphinxPacketBuilder.size(global_config))) return [ - Node( + Node[Message]( self.framework, node_config, global_config, self.__process_broadcasted_msg, self.__process_recovered_msg, + noise_msg, ) for node_config in node_configs ] def __connect_nodes( self, - nodes: list[Node], + nodes: list[Node[Message]], node_state_table: NodeStateTable, conn_stats: ConnectionStats, ): @@ -144,8 +145,8 @@ class Simulation: meter_start_time: float, sender_states: list[NodeState], receiver_states: list[NodeState], - ) -> ObservedMeteredRemoteSimplexConnection: - return ObservedMeteredRemoteSimplexConnection( + ) -> ObservedMeteredRemoteSimplexConnection[NomssipMessage[Message]]: + return ObservedMeteredRemoteSimplexConnection[NomssipMessage[Message]]( self.config.network.latency, self.framework, meter_start_time, @@ -156,14 +157,14 @@ class Simulation: def __create_conn( self, meter_start_time: float, - ) -> MeteredRemoteSimplexConnection: - return MeteredRemoteSimplexConnection( + ) -> MeteredRemoteSimplexConnection[Message]: + return MeteredRemoteSimplexConnection[Message]( self.config.network.latency, self.framework, meter_start_time, ) - async def __run_node_logic(self, node: Node): + async def __run_node_logic(self, node: Node[Message]): """ Runs the lottery periodically to check if the node is selected to send a block. If the node is selected, creates a block and sends it through mix nodes. @@ -172,27 +173,29 @@ class Simulation: while True: await self.framework.sleep(lottery_config.interval_sec) if lottery_config.seed.random() < lottery_config.probability: - msg = self.msg_builder.next(self.framework.now(), b"selected block") - await node.send_message(bytes(msg)) + inner_msg = self.inner_msg_builder.next( + self.framework.now(), b"selected block" + ) + await node.send_message(Message(bytes(inner_msg))) - async def __process_broadcasted_msg(self, msg: bytes): + async def __process_broadcasted_msg(self, msg: Message): """ Process a broadcasted message originated from the last mix. """ - message = Message.from_bytes(msg) - elapsed = self.framework.now() - message.created_at - self.dissemination_time.add_broadcasted_msg(message, elapsed) + inner_msg = InnerMessage.from_bytes(msg.data) + elapsed = self.framework.now() - inner_msg.created_at + self.dissemination_time.add_broadcasted_msg(msg, elapsed) - async def __process_recovered_msg(self, msg: bytes) -> bytes: + async def __process_recovered_msg(self, msg: bytes) -> Message: """ Process a message fully recovered by the last mix and returns a new message to be broadcasted. """ - message = Message.from_bytes(msg) - elapsed = self.framework.now() - message.created_at + inner_msg = InnerMessage.from_bytes(Message.from_bytes(msg).data) + elapsed = self.framework.now() - inner_msg.created_at self.dissemination_time.add_mix_propagation_time(elapsed) # Update the timestamp and return the message to be broadcasted, # so that the broadcast dissemination time can be calculated from now. - message.created_at = self.framework.now() - return bytes(message) + inner_msg.created_at = self.framework.now() + return Message(bytes(inner_msg)) diff --git a/mixnet/sim/stats.py b/mixnet/sim/stats.py index d8a6efe..562a308 100644 --- a/mixnet/sim/stats.py +++ b/mixnet/sim/stats.py @@ -3,7 +3,6 @@ from collections import Counter, defaultdict import matplotlib.pyplot as plt import numpy import pandas -from matplotlib.axes import Axes from protocol.node import Node from sim.connection import ObservedMeteredRemoteSimplexConnection @@ -126,16 +125,18 @@ class DisseminationTime: # A collection of time taken for a message to be broadcasted from the last mix to all nodes in the network self.broadcast_dissemination_times: list[float] = [] # Data structures to check if a message has been broadcasted to all nodes - self.broadcast_status: Counter[Message] = Counter() + # msg_id (int) is a key. + self.broadcast_status: Counter[int] = Counter() self.num_nodes: int = num_nodes def add_mix_propagation_time(self, elapsed: float): self.mix_propagation_times.append(elapsed) def add_broadcasted_msg(self, msg: Message, elapsed: float): - assert self.broadcast_status[msg] < self.num_nodes - self.broadcast_status.update([msg]) - if self.broadcast_status[msg] == self.num_nodes: + id = msg.id() + assert self.broadcast_status[id] < self.num_nodes + self.broadcast_status.update([id]) + if self.broadcast_status[id] == self.num_nodes: self.broadcast_dissemination_times.append(elapsed) def analyze(self): diff --git a/mixnet/sim/test_message.py b/mixnet/sim/test_message.py index 3eddae8..b2b46c9 100644 --- a/mixnet/sim/test_message.py +++ b/mixnet/sim/test_message.py @@ -1,23 +1,22 @@ import time from unittest import TestCase -from sim.message import Message, UniqueMessageBuilder +from sim.message import InnerMessage, UniqueInnerMessageBuilder class TestMessage(TestCase): - def test_message_serde(self): - msg = Message(time.time(), 10, b"hello") + def test_inner_message_serde(self): + msg = InnerMessage(time.time(), 10, b"hello") serialized = bytes(msg) - deserialized = Message.from_bytes(serialized) + deserialized = InnerMessage.from_bytes(serialized) self.assertEqual(msg, deserialized) -class TestUniqueMessageBuilder(TestCase): +class TestUniqueInnerMessageBuilder(TestCase): def test_uniqueness(self): - builder = UniqueMessageBuilder() + builder = UniqueInnerMessageBuilder() msg1 = builder.next(time.time(), b"hello") msg2 = builder.next(time.time(), b"hello") self.assertEqual(0, msg1.id) self.assertEqual(1, msg2.id) self.assertNotEqual(msg1, msg2) - self.assertNotEqual(hash(msg1), hash(msg2))