From 0772bf1563e9ecd58472ae0b13434e9aa2f05aae Mon Sep 17 00:00:00 2001 From: Youngjoon Lee <5462944+youngjoon-lee@users.noreply.github.com> Date: Wed, 10 Jul 2024 09:23:33 +0900 Subject: [PATCH] refactor connection and gossip --- mixnet/connection.py | 50 ++++++++++++++ mixnet/gossip.py | 54 +++++++++++++++ mixnet/node.py | 157 ++++++++++--------------------------------- 3 files changed, 138 insertions(+), 123 deletions(-) create mode 100644 mixnet/connection.py create mode 100644 mixnet/gossip.py diff --git a/mixnet/connection.py b/mixnet/connection.py new file mode 100644 index 0000000..0653dc8 --- /dev/null +++ b/mixnet/connection.py @@ -0,0 +1,50 @@ +from __future__ import annotations + +import asyncio + +NetworkPacketQueue = asyncio.Queue[bytes] +SimplexConnection = NetworkPacketQueue + + +class DuplexConnection: + inbound: SimplexConnection + outbound: MixSimplexConnection + + def __init__(self, inbound: SimplexConnection, outbound: MixSimplexConnection): + self.inbound = inbound + self.outbound = outbound + + async def recv(self) -> bytes: + return await self.inbound.get() + + async def send(self, packet: bytes): + await self.outbound.send(packet) + + +class MixSimplexConnection: + queue: NetworkPacketQueue + conn: SimplexConnection + transmission_rate_per_sec: int + noise_msg: bytes + + def __init__( + self, conn: SimplexConnection, transmission_rate_per_sec: int, noise_msg: bytes + ): + self.queue = asyncio.Queue() + self.conn = conn + self.transmission_rate_per_sec = transmission_rate_per_sec + self.noise_msg = noise_msg + self.task = asyncio.create_task(self.__run()) + + async def __run(self): + while True: + await asyncio.sleep(1 / self.transmission_rate_per_sec) + # TODO: time mixing + if self.queue.empty(): + elem = self.noise_msg + else: + elem = self.queue.get_nowait() + await self.conn.put(elem) + + async def send(self, elem: bytes): + await self.queue.put(elem) diff --git a/mixnet/gossip.py b/mixnet/gossip.py new file mode 100644 index 0000000..e855c97 --- /dev/null +++ b/mixnet/gossip.py @@ -0,0 +1,54 @@ +import asyncio +import hashlib +from typing import Awaitable, Callable + +from mixnet.config import GossipConfig +from mixnet.connection import DuplexConnection + + +class GossipChannel: + config: GossipConfig + conns: list[DuplexConnection] + handler: Callable[[bytes], Awaitable[bytes | None]] + msg_cache: set[bytes] + + def __init__( + self, + config: GossipConfig, + handler: Callable[[bytes], Awaitable[bytes | None]], + ): + self.config = config + self.conns = [] + self.handler = handler + self.msg_cache = set() + # A set just for gathering a reference of tasks to prevent them from being garbage collected. + # https://docs.python.org/3/library/asyncio-task.html#asyncio.create_task + self.tasks = set() + + def add_conn(self, conn: DuplexConnection): + if len(self.conns) >= self.config.peering_degree: + # For simplicity of the spec, reject the connection if the peering degree is reached. + raise ValueError("The peering degree is reached.") + + self.conns.append(conn) + task = asyncio.create_task(self.__process_inbound_conn(conn)) + self.tasks.add(task) + # To discard the task from the set automatically when it is done. + task.add_done_callback(self.tasks.discard) + + async def __process_inbound_conn(self, conn: DuplexConnection): + while True: + msg = await conn.recv() + # Don't process the same message twice. + msg_hash = hashlib.sha256(msg).digest() + if msg_hash in self.msg_cache: + continue + self.msg_cache.add(msg_hash) + + new_msg = await self.handler(msg) + if new_msg is not None: + await self.gossip(new_msg) + + async def gossip(self, packet: bytes): + for conn in self.conns: + await conn.send(packet) diff --git a/mixnet/node.py b/mixnet/node.py index d1786dd..717483f 100644 --- a/mixnet/node.py +++ b/mixnet/node.py @@ -1,9 +1,8 @@ from __future__ import annotations import asyncio -import hashlib from enum import Enum -from typing import Awaitable, Callable, TypeAlias +from typing import TypeAlias from pysphinx.sphinx import ( Payload, @@ -12,18 +11,18 @@ from pysphinx.sphinx import ( SphinxPacket, ) -from mixnet.config import GlobalConfig, GossipConfig, NodeConfig +from mixnet.config import GlobalConfig, NodeConfig +from mixnet.connection import DuplexConnection, MixSimplexConnection +from mixnet.gossip import GossipChannel from mixnet.packet import Fragment, MessageFlag, MessageReconstructor, PacketBuilder -NetworkPacketQueue: TypeAlias = asyncio.Queue[bytes] -Connection: TypeAlias = NetworkPacketQueue BroadcastChannel: TypeAlias = asyncio.Queue[bytes] class Node: config: NodeConfig global_config: GlobalConfig - mixgossip_channel: MixGossipChannel + mixgossip_channel: GossipChannel reconstructor: MessageReconstructor broadcast_channel: BroadcastChannel packet_size: int @@ -31,9 +30,7 @@ class Node: def __init__(self, config: NodeConfig, global_config: GlobalConfig): self.config = config self.global_config = global_config - self.mixgossip_channel = MixGossipChannel( - config.gossip, self.__process_sphinx_packet - ) + self.mixgossip_channel = GossipChannel(config.gossip, self.__process_msg) self.reconstructor = MessageReconstructor() self.broadcast_channel = asyncio.Queue() @@ -42,6 +39,20 @@ class Node: )[0] self.packet_size = len(sample_packet.bytes()) + async def __process_msg(self, msg: bytes) -> bytes | None: + flag, msg = Node.__parse_msg(msg) + match flag: + case MsgType.NOISE: + # Drop noise packet + return None + case MsgType.REAL: + # Handle the packet and gossip the result if needed. + sphinx_packet = SphinxPacket.from_bytes(msg) + new_sphinx_packet = await self.__process_sphinx_packet(sphinx_packet) + if new_sphinx_packet is None: + return None + return Node.__build_msg(MsgType.REAL, new_sphinx_packet.bytes()) + async def __process_sphinx_packet( self, packet: SphinxPacket ) -> SphinxPacket | None: @@ -66,12 +77,12 @@ class Node: await self.broadcast_channel.put(msg) def connect(self, peer: Node): - noise_msg = build_msg(MsgType.NOISE, bytes(self.packet_size)) + noise_msg = Node.__build_msg(MsgType.NOISE, bytes(self.packet_size)) inbound_conn, outbound_conn = asyncio.Queue(), asyncio.Queue() self.mixgossip_channel.add_conn( DuplexConnection( inbound_conn, - MixOutboundConnection( + MixSimplexConnection( outbound_conn, self.global_config.transmission_rate_per_sec, noise_msg, @@ -81,7 +92,7 @@ class Node: peer.mixgossip_channel.add_conn( DuplexConnection( outbound_conn, - MixOutboundConnection( + MixSimplexConnection( inbound_conn, self.global_config.transmission_rate_per_sec, noise_msg, @@ -93,121 +104,21 @@ class Node: for packet, _ in PacketBuilder.build_real_packets( msg, self.global_config.membership ): - await self.mixgossip_channel.gossip(build_msg(MsgType.REAL, packet.bytes())) + await self.mixgossip_channel.gossip( + Node.__build_msg(MsgType.REAL, packet.bytes()) + ) + @staticmethod + def __build_msg(flag: MsgType, data: bytes) -> bytes: + return flag.value + data -class MixGossipChannel: - config: GossipConfig - conns: list[DuplexConnection] - handler: Callable[[SphinxPacket], Awaitable[SphinxPacket | None]] - msg_cache: set[bytes] - - def __init__( - self, - config: GossipConfig, - handler: Callable[[SphinxPacket], Awaitable[SphinxPacket | None]], - ): - self.config = config - self.conns = [] - self.handler = handler - self.msg_cache = set() - # A set just for gathering a reference of tasks to prevent them from being garbage collected. - # https://docs.python.org/3/library/asyncio-task.html#asyncio.create_task - self.tasks = set() - - def add_conn(self, conn: DuplexConnection): - if len(self.conns) >= self.config.peering_degree: - # For simplicity of the spec, reject the connection if the peering degree is reached. - raise ValueError("The peering degree is reached.") - - self.conns.append(conn) - task = asyncio.create_task(self.__process_inbound_conn(conn)) - self.tasks.add(task) - # To discard the task from the set automatically when it is done. - task.add_done_callback(self.tasks.discard) - - async def __process_inbound_conn(self, conn: DuplexConnection): - while True: - msg = await conn.recv() - # Don't process the same message twice. - msg_hash = hashlib.sha256(msg).digest() - if msg_hash in self.msg_cache: - continue - self.msg_cache.add(msg_hash) - - flag, msg = parse_msg(msg) - match flag: - case MsgType.NOISE: - # Drop noise packet - continue - case MsgType.REAL: - # Handle the packet and gossip the result if needed. - sphinx_packet = SphinxPacket.from_bytes(msg) - new_sphinx_packet = await self.handler(sphinx_packet) - if new_sphinx_packet is not None: - await self.gossip( - build_msg(MsgType.REAL, new_sphinx_packet.bytes()) - ) - - async def gossip(self, packet: bytes): - for conn in self.conns: - await conn.send(packet) - - -class DuplexConnection: - inbound: Connection - outbound: MixOutboundConnection - - def __init__(self, inbound: Connection, outbound: MixOutboundConnection): - self.inbound = inbound - self.outbound = outbound - - async def recv(self) -> bytes: - return await self.inbound.get() - - async def send(self, packet: bytes): - await self.outbound.send(packet) - - -class MixOutboundConnection: - queue: NetworkPacketQueue - conn: Connection - transmission_rate_per_sec: int - noise_msg: bytes - - def __init__( - self, conn: Connection, transmission_rate_per_sec: int, noise_msg: bytes - ): - self.queue = asyncio.Queue() - self.conn = conn - self.transmission_rate_per_sec = transmission_rate_per_sec - self.noise_msg = noise_msg - self.task = asyncio.create_task(self.__run()) - - async def __run(self): - while True: - await asyncio.sleep(1 / self.transmission_rate_per_sec) - # TODO: time mixing - if self.queue.empty(): - elem = self.noise_msg - else: - elem = self.queue.get_nowait() - await self.conn.put(elem) - - async def send(self, elem: bytes): - await self.queue.put(elem) + @staticmethod + def __parse_msg(data: bytes) -> tuple[MsgType, bytes]: + if len(data) < 1: + raise ValueError("Invalid message format") + return (MsgType(data[:1]), data[1:]) class MsgType(Enum): REAL = b"\x00" NOISE = b"\x01" - - -def build_msg(flag: MsgType, data: bytes) -> bytes: - return flag.value + data - - -def parse_msg(data: bytes) -> tuple[MsgType, bytes]: - if len(data) < 1: - raise ValueError("Invalid message format") - return (MsgType(data[:1]), data[1:])