From ebc069b112f99c8ee0bf9642f5610002a1a312c1 Mon Sep 17 00:00:00 2001 From: Youngjoon Lee <5462944+youngjoon-lee@users.noreply.github.com> Date: Thu, 25 Jan 2024 18:04:55 +0900 Subject: [PATCH] Mixnet: Refactor with asyncio (#53) --- mixnet/client.py | 106 +++++++++++++++++---------------- mixnet/node.py | 134 ++++++++++++++---------------------------- mixnet/test_client.py | 42 ++++++------- mixnet/test_node.py | 54 ++++++++++------- mixnet/test_utils.py | 12 ++++ requirements.txt | 3 - 6 files changed, 165 insertions(+), 186 deletions(-) create mode 100644 mixnet/test_utils.py diff --git a/mixnet/client.py b/mixnet/client.py index cab2b55..295b7c6 100644 --- a/mixnet/client.py +++ b/mixnet/client.py @@ -1,9 +1,6 @@ from __future__ import annotations -import queue -import time -from datetime import datetime, timedelta -from threading import Thread +import asyncio from mixnet.mixnet import Mixnet, MixnetTopology from mixnet.node import PacketQueue @@ -11,7 +8,14 @@ from mixnet.packet import PacketBuilder from mixnet.poisson import poisson_interval_sec -class MixClientRunner(Thread): +async def mixclient_emitter( + mixnet: Mixnet, + topology: MixnetTopology, + emission_rate_per_min: int, # Poisson rate parameter: lambda in the spec + redundancy: int, # b in the spec + real_packet_queue: PacketQueue, + outbound_socket: PacketQueue, +): """ Emit packets at the Poisson emission_rate_per_min. @@ -21,55 +25,55 @@ class MixClientRunner(Thread): If no real packet is not scheduled, this thread emits a cover packet according to the emission_rate_per_min. """ - def __init__( - self, - mixnet: Mixnet, - topology: MixnetTopology, - emission_rate_per_min: int, # Poisson rate parameter: lambda in the spec - redundancy: int, # b in the spec - real_packet_queue: PacketQueue, - outbound_socket: PacketQueue, - ): - super().__init__() - self.mixnet = mixnet - self.topology = topology - self.emission_rate_per_min = emission_rate_per_min - self.redundancy = redundancy - self.real_packet_queue = real_packet_queue - self.redundant_real_packet_queue: PacketQueue = queue.Queue() - self.outbound_socket = outbound_socket + redundant_real_packet_queue: PacketQueue = asyncio.Queue() - def run(self) -> None: - # Here in Python, this thread is implemented in synchronous manner. - # In the real implementation, consider implementing this in asynchronous if possible. + emission_notifier_queue = asyncio.Queue() + _ = asyncio.create_task( + emission_notifier(emission_rate_per_min, emission_notifier_queue) + ) - next_emission_ts = datetime.now() + timedelta( - seconds=poisson_interval_sec(self.emission_rate_per_min) - ) - - while True: - time.sleep(1 / 1000) - - if datetime.now() < next_emission_ts: - continue - - next_emission_ts += timedelta( - seconds=poisson_interval_sec(self.emission_rate_per_min) + while True: + # Wait until the next emission time + _ = await emission_notifier_queue.get() + try: + await emit( + mixnet, + topology, + redundancy, + real_packet_queue, + redundant_real_packet_queue, + outbound_socket, ) + finally: + # Python convention: indicate that the previously enqueued task has been processed + emission_notifier_queue.task_done() - if not self.redundant_real_packet_queue.empty(): - addr, packet = self.redundant_real_packet_queue.get() - self.outbound_socket.put((addr, packet)) - continue - if not self.real_packet_queue.empty(): - addr, packet = self.real_packet_queue.get() - # Schedule redundant real packets - for _ in range(self.redundancy - 1): - self.redundant_real_packet_queue.put((addr, packet)) - self.outbound_socket.put((addr, packet)) +async def emit( + mixnet: Mixnet, + topology: MixnetTopology, + redundancy: int, # b in the spec + real_packet_queue: PacketQueue, + redundant_real_packet_queue: PacketQueue, + outbound_socket: PacketQueue, +): + if not redundant_real_packet_queue.empty(): + addr, packet = redundant_real_packet_queue.get_nowait() + await outbound_socket.put((addr, packet)) + return - packet, route = PacketBuilder.drop_cover( - b"drop cover", self.mixnet, self.topology - ).next() - self.outbound_socket.put((route[0].addr, packet)) + if not real_packet_queue.empty(): + addr, packet = real_packet_queue.get_nowait() + # Schedule redundant real packets + for _ in range(redundancy - 1): + redundant_real_packet_queue.put_nowait((addr, packet)) + await outbound_socket.put((addr, packet)) + + packet, route = PacketBuilder.drop_cover(b"drop cover", mixnet, topology).next() + await outbound_socket.put((route[0].addr, packet)) + + +async def emission_notifier(emission_rate_per_min: int, queue: asyncio.Queue): + while True: + await asyncio.sleep(poisson_interval_sec(emission_rate_per_min)) + queue.put_nowait(None) diff --git a/mixnet/node.py b/mixnet/node.py index defd5b3..0521cdd 100644 --- a/mixnet/node.py +++ b/mixnet/node.py @@ -1,10 +1,7 @@ from __future__ import annotations -import queue -import threading -import time +import asyncio from dataclasses import dataclass -from threading import Thread from typing import Tuple, TypeAlias from cryptography.hazmat.primitives.asymmetric.x25519 import ( @@ -27,9 +24,9 @@ NodeId: TypeAlias = BlsPublicKey # 32-byte that represents an IP address and a port of a mix node. NodeAddress: TypeAlias = bytes -PacketQueue: TypeAlias = "queue.Queue[Tuple[NodeAddress, SphinxPacket]]" +PacketQueue: TypeAlias = "asyncio.Queue[Tuple[NodeAddress, SphinxPacket]]" PacketPayloadQueue: TypeAlias = ( - "queue.Queue[Tuple[NodeAddress, SphinxPacket | Payload]]" + "asyncio.Queue[Tuple[NodeAddress, SphinxPacket | Payload]]" ) @@ -50,26 +47,26 @@ class MixNode: def start( self, - delay_rate_per_min: int, + delay_rate_per_min: int, # Poisson rate parameter: mu inbound_socket: PacketQueue, outbound_socket: PacketPayloadQueue, - ) -> MixNodeRunner: - thread = MixNodeRunner( - self.encryption_private_key, - delay_rate_per_min, - inbound_socket, - outbound_socket, + ) -> asyncio.Task: + return asyncio.create_task( + MixNodeRunner( + self.encryption_private_key, + delay_rate_per_min, + inbound_socket, + outbound_socket, + ).run() ) - thread.daemon = True - thread.start() - return thread -class MixNodeRunner(Thread): +class MixNodeRunner: """ - Read SphinxPackets from inbound socket and spawn a thread for each packet to process it. + A class handling incoming packets with delays - This thread approximates a M/M/inf queue. + This class is defined separated with the MixNode class, + in order to define the MixNode as a simple dataclass for clarity. """ def __init__( @@ -79,96 +76,55 @@ class MixNodeRunner(Thread): inbound_socket: PacketQueue, outbound_socket: PacketPayloadQueue, ): - super().__init__() self.encryption_private_key = encryption_private_key self.delay_rate_per_min = delay_rate_per_min self.inbound_socket = inbound_socket self.outbound_socket = outbound_socket - self.num_processing = AtomicInt(0) - def run(self) -> None: - # Here in Python, this thread is implemented in synchronous manner. - # In the real implementation, consider implementing this in asynchronous if possible, - # to approximate a M/M/inf queue + async def run(self): + """ + Read SphinxPackets from inbound socket and spawn a thread for each packet to process it. + + This thread approximates a M/M/inf queue. + """ + + # A set just for gathering a reference of tasks to prevent them from being garbage collected. + # https://docs.python.org/3/library/asyncio-task.html#asyncio.create_task + self.tasks = set() + while True: - _, packet = self.inbound_socket.get() - thread = MixNodePacketProcessor( - packet, - self.encryption_private_key, - self.delay_rate_per_min, - self.outbound_socket, - self.num_processing, + _, packet = await self.inbound_socket.get() + task = asyncio.create_task( + self.process_packet( + packet, + ) ) - thread.daemon = True - self.num_processing.add(1) - thread.start() + self.tasks.add(task) + # To discard the task from the set automatically when it is done. + task.add_done_callback(self.tasks.discard) - def num_jobs(self) -> int: - """ - Return the number of packets that are being processed or still in the inbound socket. - - If this thread works as a M/M/inf queue completely, - the number of packets that are still in the inbound socket must be always 0. - """ - return self.num_processing.get() + self.inbound_socket.qsize() - - -class MixNodePacketProcessor(Thread): - """ - Process a single packet with a delay that follows exponential distribution, - and forward it to the next mix node or the mix destination - - This thread is a single server (worker) in a M/M/inf queue that MixNodeRunner approximates. - """ - - def __init__( + async def process_packet( self, packet: SphinxPacket, - encryption_private_key: X25519PrivateKey, - delay_rate_per_min: int, # Poisson rate parameter: mu - outbound_socket: PacketPayloadQueue, - num_processing: AtomicInt, ): - super().__init__() - self.packet = packet - self.encryption_private_key = encryption_private_key - self.delay_rate_per_min = delay_rate_per_min - self.outbound_socket = outbound_socket - self.num_processing = num_processing + """ + Process a single packet with a delay that follows exponential distribution, + and forward it to the next mix node or the mix destination - def run(self) -> None: + This thread is a single server (worker) in a M/M/inf queue that MixNodeRunner approximates. + """ delay_sec = poisson_interval_sec(self.delay_rate_per_min) - time.sleep(delay_sec) + await asyncio.sleep(delay_sec) - processed = self.packet.process(self.encryption_private_key) + processed = packet.process(self.encryption_private_key) match processed: case ProcessedForwardHopPacket(): - self.outbound_socket.put( + await self.outbound_socket.put( (processed.next_node_address, processed.next_packet) ) case ProcessedFinalHopPacket(): - self.outbound_socket.put( + await self.outbound_socket.put( (processed.destination_node_address, processed.payload) ) case _: raise UnknownHeaderTypeError - - self.num_processing.sub(1) - - -class AtomicInt: - def __init__(self, initial: int) -> None: - self.lock = threading.Lock() - self.value = initial - - def add(self, v: int): - with self.lock: - self.value += v - - def sub(self, v: int): - with self.lock: - self.value -= v - - def get(self) -> int: - with self.lock: - return self.value diff --git a/mixnet/test_client.py b/mixnet/test_client.py index 1f9f31b..2aea83f 100644 --- a/mixnet/test_client.py +++ b/mixnet/test_client.py @@ -1,54 +1,54 @@ -import queue +import asyncio from datetime import datetime from typing import Tuple -from unittest import TestCase +from unittest import IsolatedAsyncioTestCase import numpy -import timeout_decorator from cryptography.hazmat.primitives.asymmetric.x25519 import X25519PrivateKey from mixnet.bls import generate_bls -from mixnet.client import MixClientRunner +from mixnet.client import mixclient_emitter from mixnet.mixnet import Mixnet, MixnetTopology from mixnet.node import MixNode, PacketQueue from mixnet.packet import PacketBuilder from mixnet.poisson import poisson_mean_interval_sec from mixnet.utils import random_bytes +from mixnet.test_utils import with_test_timeout -class TestMixClientRunner(TestCase): - @timeout_decorator.timeout(180) - def test_mixclient_runner_emission_rate(self): +class TestMixClient(IsolatedAsyncioTestCase): + @with_test_timeout(100) + async def test_mixclient_emitter(self): mixnet, topology = self.init() - real_packet_queue: PacketQueue = queue.Queue() - outbound_socket: PacketQueue = queue.Queue() + real_packet_queue: PacketQueue = asyncio.Queue() + outbound_socket: PacketQueue = asyncio.Queue() emission_rate_per_min = 30 redundancy = 3 - client = MixClientRunner( - mixnet, - topology, - emission_rate_per_min, - redundancy, - real_packet_queue, - outbound_socket, + _ = asyncio.create_task( + mixclient_emitter( + mixnet, + topology, + emission_rate_per_min, + redundancy, + real_packet_queue, + outbound_socket, + ) ) - client.daemon = True - client.start() # Create packets. At least two packets are expected to be generated from a 3500-byte msg builder = PacketBuilder.real(random_bytes(3500), mixnet, topology) # Schedule two packets to the mix client without any interval packet, route = builder.next() - real_packet_queue.put((route[0].addr, packet)) + await real_packet_queue.put((route[0].addr, packet)) packet, route = builder.next() - real_packet_queue.put((route[0].addr, packet)) + await real_packet_queue.put((route[0].addr, packet)) # Calculate intervals between packet emissions from the mix client intervals = [] ts = datetime.now() for _ in range(30): - _ = outbound_socket.get() + _ = await outbound_socket.get() now = datetime.now() intervals.append((now - ts).total_seconds()) ts = now diff --git a/mixnet/test_node.py b/mixnet/test_node.py index 1c5c90c..b7a33d5 100644 --- a/mixnet/test_node.py +++ b/mixnet/test_node.py @@ -1,12 +1,9 @@ -import queue -import threading -import time +import asyncio from datetime import datetime from typing import Tuple -from unittest import TestCase +from unittest import IsolatedAsyncioTestCase import numpy -import timeout_decorator from cryptography.hazmat.primitives.asymmetric.x25519 import X25519PrivateKey from pysphinx.sphinx import SphinxPacket @@ -15,12 +12,13 @@ from mixnet.mixnet import Mixnet, MixnetTopology from mixnet.node import MixNode, NodeAddress, PacketPayloadQueue, PacketQueue from mixnet.packet import PacketBuilder from mixnet.poisson import poisson_interval_sec, poisson_mean_interval_sec +from mixnet.test_utils import with_test_timeout from mixnet.utils import random_bytes -class TestMixNodeRunner(TestCase): - @timeout_decorator.timeout(180) - def test_mixnode_runner_emission_rate(self): +class TestMixNodeRunner(IsolatedAsyncioTestCase): + @with_test_timeout(180) + async def test_mixnode_runner_emission_rate(self): """ Test if MixNodeRunner works as a M/M/inf queue. @@ -29,41 +27,49 @@ class TestMixNodeRunner(TestCase): the rate of outputs should be `lambda`. """ mixnet, topology = self.init() - inbound_socket: PacketQueue = queue.Queue() - outbound_socket: PacketPayloadQueue = queue.Queue() + inbound_socket: PacketQueue = asyncio.Queue() + outbound_socket: PacketPayloadQueue = asyncio.Queue() packet, route = PacketBuilder.real(b"msg", mixnet, topology).next() delay_rate_per_min = 30 # mu (= 2s delay on average) # Start only the first mix node for testing - runner = route[0].start(delay_rate_per_min, inbound_socket, outbound_socket) + _ = route[0].start(delay_rate_per_min, inbound_socket, outbound_socket) # Send packets to the first mix node in a Poisson distribution packet_count = 100 emission_rate_per_min = 120 # lambda (= 2msg/sec) - sender = threading.Thread( - target=self.send_packets, - args=( + # This queue is just for counting how many packets have been sent so far. + sent_packet_queue: PacketQueue = asyncio.Queue() + _ = asyncio.create_task( + self.send_packets( inbound_socket, packet, route[0].addr, packet_count, emission_rate_per_min, - ), + sent_packet_queue, + ) ) - sender.daemon = True - sender.start() # Calculate intervals between outputs and gather num_jobs in the first mix node. intervals = [] num_jobs = [] ts = datetime.now() for _ in range(packet_count): - _ = outbound_socket.get() + _ = await outbound_socket.get() now = datetime.now() intervals.append((now - ts).total_seconds()) - num_jobs.append(runner.num_jobs()) + + # Calculate the current # of jobs staying in the mix node + num_packets_emitted_from_mixnode = len(intervals) + num_packets_sent_to_mixnode = sent_packet_queue.qsize() + num_jobs.append( + num_packets_sent_to_mixnode - num_packets_emitted_from_mixnode + ) + ts = now + # Remove the first interval that would be much larger than other intervals, # because of the delay in mix node. intervals = intervals[1:] @@ -87,16 +93,20 @@ class TestMixNodeRunner(TestCase): ) @staticmethod - def send_packets( + async def send_packets( inbound_socket: PacketQueue, packet: SphinxPacket, node_addr: NodeAddress, cnt: int, rate_per_min: int, + # For testing purpose, to inform the caller how many packets have been sent to the inbound_socket + sent_packet_queue: PacketQueue, ): for _ in range(cnt): - time.sleep(poisson_interval_sec(rate_per_min)) - inbound_socket.put((node_addr, packet)) + # Since the task is not heavy, just sleep for seconds instead of using emission_notifier + await asyncio.sleep(poisson_interval_sec(rate_per_min)) + await inbound_socket.put((node_addr, packet)) + await sent_packet_queue.put((node_addr, packet)) @staticmethod def init() -> Tuple[Mixnet, MixnetTopology]: diff --git a/mixnet/test_utils.py b/mixnet/test_utils.py new file mode 100644 index 0000000..de89bd9 --- /dev/null +++ b/mixnet/test_utils.py @@ -0,0 +1,12 @@ +import asyncio + + +def with_test_timeout(t): + def wrapper(coroutine): + async def run(*args, **kwargs): + async with asyncio.timeout(t): + return await coroutine(*args, **kwargs) + + return run + + return wrapper diff --git a/requirements.txt b/requirements.txt index 81d737e..6bbee51 100644 --- a/requirements.txt +++ b/requirements.txt @@ -5,6 +5,3 @@ numpy==1.26.3 pycparser==2.21 pysphinx==0.0.1 scipy==1.11.4 -setuptools==69.0.3 -timeout-decorator==0.5.0 -wheel==0.42.0