From 848b3f400c7e3ac725d73fa8ac4547005a3eb9de Mon Sep 17 00:00:00 2001 From: Youngjoon Lee <5462944+youngjoon-lee@users.noreply.github.com> Date: Wed, 10 Jul 2024 09:55:51 +0900 Subject: [PATCH] add comments --- mixnet/config.py | 17 +++++++++++++++++ mixnet/connection.py | 22 ++++++++++++++++------ mixnet/gossip.py | 8 ++++++++ mixnet/node.py | 33 +++++++++++++++++++++++++++++++++ 4 files changed, 74 insertions(+), 6 deletions(-) diff --git a/mixnet/config.py b/mixnet/config.py index 4525041..60e0583 100644 --- a/mixnet/config.py +++ b/mixnet/config.py @@ -13,6 +13,10 @@ from pysphinx.sphinx import Node as SphinxNode @dataclass class GlobalConfig: + """ + Global parameters used across all nodes in the network + """ + membership: MixMembership transmission_rate_per_sec: int # Global Transmission Rate # TODO: use this to make the size of Sphinx packet constant @@ -21,6 +25,10 @@ class GlobalConfig: @dataclass class NodeConfig: + """ + Node-specific parameters + """ + private_key: X25519PrivateKey mix_path_length: int gossip: GossipConfig @@ -34,6 +42,11 @@ class GossipConfig: @dataclass class MixMembership: + """ + A list of public information of nodes in the network. + We assume that this list is eventually known to all nodes in the network (e.g. via p2p advertising). + """ + nodes: List[NodeInfo] def generate_route(self, num_hops: int, last_mix: NodeInfo) -> list[NodeInfo]: @@ -53,6 +66,10 @@ class MixMembership: @dataclass class NodeInfo: + """ + Public information of a node to be shared to all nodes in the network + """ + public_key: X25519PublicKey def sphinx_node(self) -> SphinxNode: diff --git a/mixnet/connection.py b/mixnet/connection.py index 0653dc8..9b6acf6 100644 --- a/mixnet/connection.py +++ b/mixnet/connection.py @@ -7,6 +7,11 @@ SimplexConnection = NetworkPacketQueue class DuplexConnection: + """ + A duplex connection in which data can be transmitted and received simultaneously in both directions. + This is to mimic duplex communication in a real network (such as TCP or QUIC). + """ + inbound: SimplexConnection outbound: MixSimplexConnection @@ -22,6 +27,10 @@ class DuplexConnection: class MixSimplexConnection: + """ + Wraps a SimplexConnection to add a transmission rate and noise to the connection. + """ + queue: NetworkPacketQueue conn: SimplexConnection transmission_rate_per_sec: int @@ -39,12 +48,13 @@ class MixSimplexConnection: async def __run(self): while True: await asyncio.sleep(1 / self.transmission_rate_per_sec) - # TODO: time mixing + # TODO: temporal mixing if self.queue.empty(): - elem = self.noise_msg + # To guarantee GTR, send noise if there is no message to send + msg = self.noise_msg else: - elem = self.queue.get_nowait() - await self.conn.put(elem) + msg = self.queue.get_nowait() + await self.conn.put(msg) - async def send(self, elem: bytes): - await self.queue.put(elem) + async def send(self, msg: bytes): + await self.queue.put(msg) diff --git a/mixnet/gossip.py b/mixnet/gossip.py index e855c97..0c622a9 100644 --- a/mixnet/gossip.py +++ b/mixnet/gossip.py @@ -7,9 +7,17 @@ from mixnet.connection import DuplexConnection class GossipChannel: + """ + A gossip channel that broadcasts messages to all connected peers. + Peers are connected via DuplexConnection. + This class simplifies and simulates the libp2p gossipsub. + """ + config: GossipConfig conns: list[DuplexConnection] + # A handler to process inbound messages. handler: Callable[[bytes], Awaitable[bytes | None]] + # A set of message hashes to prevent processing the same message twice. msg_cache: set[bytes] def __init__( diff --git a/mixnet/node.py b/mixnet/node.py index ed26d35..07a0abc 100644 --- a/mixnet/node.py +++ b/mixnet/node.py @@ -20,6 +20,13 @@ BroadcastChannel: TypeAlias = asyncio.Queue[bytes] class Node: + """ + This represents any node in the network, which: + - generates/gossips mix messages (Sphinx packets) + - performs cryptographic mix (unwrapping Sphinx packets) + - generates noise + """ + config: NodeConfig global_config: GlobalConfig mixgossip_channel: GossipChannel @@ -40,6 +47,9 @@ class Node: self.packet_size = len(sample_packet.bytes()) async def __process_msg(self, msg: bytes) -> bytes | None: + """ + A handler to process messages received via gossip channel + """ flag, msg = Node.__parse_msg(msg) match flag: case MsgType.NOISE: @@ -56,6 +66,9 @@ class Node: async def __process_sphinx_packet( self, packet: SphinxPacket ) -> SphinxPacket | None: + """ + Unwrap the Sphinx packet and process the next Sphinx packet or the payload. + """ try: processed = packet.process(self.config.private_key) match processed: @@ -68,6 +81,9 @@ class Node: return packet async def __process_sphinx_payload(self, payload: Payload): + """ + Process the Sphinx payload and broadcast it if it is a real message. + """ msg_with_flag = self.reconstructor.add( Fragment.from_bytes(payload.recover_plain_playload()) ) @@ -77,8 +93,13 @@ class Node: await self.broadcast_channel.put(msg) def connect(self, peer: Node): + """ + Establish a duplex connection with a peer node. + """ noise_msg = Node.__build_msg(MsgType.NOISE, bytes(self.packet_size)) inbound_conn, outbound_conn = asyncio.Queue(), asyncio.Queue() + + # Register a duplex connection for its own use self.mixgossip_channel.add_conn( DuplexConnection( inbound_conn, @@ -89,6 +110,7 @@ class Node: ), ) ) + # Register the same duplex connection for the peer peer.mixgossip_channel.add_conn( DuplexConnection( outbound_conn, @@ -101,6 +123,11 @@ class Node: ) async def send_message(self, msg: bytes): + """ + Build a Sphinx packet and gossip it to all connected peers. + """ + # Here, we handle the case in which a msg is split into multiple Sphinx packets. + # But, in practice, we expect a message to be small enough to fit in a single Sphinx packet. for packet, _ in PacketBuilder.build_real_packets( msg, self.global_config.membership, @@ -112,10 +139,16 @@ class Node: @staticmethod def __build_msg(flag: MsgType, data: bytes) -> bytes: + """ + Prepend a flag to the message, right before sending it via network channel. + """ return flag.value + data @staticmethod def __parse_msg(data: bytes) -> tuple[MsgType, bytes]: + """ + Parse the message and extract the flag. + """ if len(data) < 1: raise ValueError("Invalid message format") return (MsgType(data[:1]), data[1:])