Mixnet: Refactor with asyncio (#53)

This commit is contained in:
Youngjoon Lee 2024-01-25 18:04:55 +09:00 committed by GitHub
parent 30d52791c3
commit ebc069b112
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 165 additions and 186 deletions

View File

@ -1,9 +1,6 @@
from __future__ import annotations from __future__ import annotations
import queue import asyncio
import time
from datetime import datetime, timedelta
from threading import Thread
from mixnet.mixnet import Mixnet, MixnetTopology from mixnet.mixnet import Mixnet, MixnetTopology
from mixnet.node import PacketQueue from mixnet.node import PacketQueue
@ -11,7 +8,14 @@ from mixnet.packet import PacketBuilder
from mixnet.poisson import poisson_interval_sec from mixnet.poisson import poisson_interval_sec
class MixClientRunner(Thread): async def mixclient_emitter(
mixnet: Mixnet,
topology: MixnetTopology,
emission_rate_per_min: int, # Poisson rate parameter: lambda in the spec
redundancy: int, # b in the spec
real_packet_queue: PacketQueue,
outbound_socket: PacketQueue,
):
""" """
Emit packets at the Poisson emission_rate_per_min. Emit packets at the Poisson emission_rate_per_min.
@ -21,55 +25,55 @@ class MixClientRunner(Thread):
If no real packet is not scheduled, this thread emits a cover packet according to the emission_rate_per_min. If no real packet is not scheduled, this thread emits a cover packet according to the emission_rate_per_min.
""" """
def __init__( redundant_real_packet_queue: PacketQueue = asyncio.Queue()
self,
mixnet: Mixnet,
topology: MixnetTopology,
emission_rate_per_min: int, # Poisson rate parameter: lambda in the spec
redundancy: int, # b in the spec
real_packet_queue: PacketQueue,
outbound_socket: PacketQueue,
):
super().__init__()
self.mixnet = mixnet
self.topology = topology
self.emission_rate_per_min = emission_rate_per_min
self.redundancy = redundancy
self.real_packet_queue = real_packet_queue
self.redundant_real_packet_queue: PacketQueue = queue.Queue()
self.outbound_socket = outbound_socket
def run(self) -> None: emission_notifier_queue = asyncio.Queue()
# Here in Python, this thread is implemented in synchronous manner. _ = asyncio.create_task(
# In the real implementation, consider implementing this in asynchronous if possible. emission_notifier(emission_rate_per_min, emission_notifier_queue)
)
next_emission_ts = datetime.now() + timedelta( while True:
seconds=poisson_interval_sec(self.emission_rate_per_min) # Wait until the next emission time
) _ = await emission_notifier_queue.get()
try:
while True: await emit(
time.sleep(1 / 1000) mixnet,
topology,
if datetime.now() < next_emission_ts: redundancy,
continue real_packet_queue,
redundant_real_packet_queue,
next_emission_ts += timedelta( outbound_socket,
seconds=poisson_interval_sec(self.emission_rate_per_min)
) )
finally:
# Python convention: indicate that the previously enqueued task has been processed
emission_notifier_queue.task_done()
if not self.redundant_real_packet_queue.empty():
addr, packet = self.redundant_real_packet_queue.get()
self.outbound_socket.put((addr, packet))
continue
if not self.real_packet_queue.empty(): async def emit(
addr, packet = self.real_packet_queue.get() mixnet: Mixnet,
# Schedule redundant real packets topology: MixnetTopology,
for _ in range(self.redundancy - 1): redundancy: int, # b in the spec
self.redundant_real_packet_queue.put((addr, packet)) real_packet_queue: PacketQueue,
self.outbound_socket.put((addr, packet)) redundant_real_packet_queue: PacketQueue,
outbound_socket: PacketQueue,
):
if not redundant_real_packet_queue.empty():
addr, packet = redundant_real_packet_queue.get_nowait()
await outbound_socket.put((addr, packet))
return
packet, route = PacketBuilder.drop_cover( if not real_packet_queue.empty():
b"drop cover", self.mixnet, self.topology addr, packet = real_packet_queue.get_nowait()
).next() # Schedule redundant real packets
self.outbound_socket.put((route[0].addr, packet)) for _ in range(redundancy - 1):
redundant_real_packet_queue.put_nowait((addr, packet))
await outbound_socket.put((addr, packet))
packet, route = PacketBuilder.drop_cover(b"drop cover", mixnet, topology).next()
await outbound_socket.put((route[0].addr, packet))
async def emission_notifier(emission_rate_per_min: int, queue: asyncio.Queue):
while True:
await asyncio.sleep(poisson_interval_sec(emission_rate_per_min))
queue.put_nowait(None)

View File

@ -1,10 +1,7 @@
from __future__ import annotations from __future__ import annotations
import queue import asyncio
import threading
import time
from dataclasses import dataclass from dataclasses import dataclass
from threading import Thread
from typing import Tuple, TypeAlias from typing import Tuple, TypeAlias
from cryptography.hazmat.primitives.asymmetric.x25519 import ( from cryptography.hazmat.primitives.asymmetric.x25519 import (
@ -27,9 +24,9 @@ NodeId: TypeAlias = BlsPublicKey
# 32-byte that represents an IP address and a port of a mix node. # 32-byte that represents an IP address and a port of a mix node.
NodeAddress: TypeAlias = bytes NodeAddress: TypeAlias = bytes
PacketQueue: TypeAlias = "queue.Queue[Tuple[NodeAddress, SphinxPacket]]" PacketQueue: TypeAlias = "asyncio.Queue[Tuple[NodeAddress, SphinxPacket]]"
PacketPayloadQueue: TypeAlias = ( PacketPayloadQueue: TypeAlias = (
"queue.Queue[Tuple[NodeAddress, SphinxPacket | Payload]]" "asyncio.Queue[Tuple[NodeAddress, SphinxPacket | Payload]]"
) )
@ -50,26 +47,26 @@ class MixNode:
def start( def start(
self, self,
delay_rate_per_min: int, delay_rate_per_min: int, # Poisson rate parameter: mu
inbound_socket: PacketQueue, inbound_socket: PacketQueue,
outbound_socket: PacketPayloadQueue, outbound_socket: PacketPayloadQueue,
) -> MixNodeRunner: ) -> asyncio.Task:
thread = MixNodeRunner( return asyncio.create_task(
self.encryption_private_key, MixNodeRunner(
delay_rate_per_min, self.encryption_private_key,
inbound_socket, delay_rate_per_min,
outbound_socket, inbound_socket,
outbound_socket,
).run()
) )
thread.daemon = True
thread.start()
return thread
class MixNodeRunner(Thread): class MixNodeRunner:
""" """
Read SphinxPackets from inbound socket and spawn a thread for each packet to process it. A class handling incoming packets with delays
This thread approximates a M/M/inf queue. This class is defined separated with the MixNode class,
in order to define the MixNode as a simple dataclass for clarity.
""" """
def __init__( def __init__(
@ -79,96 +76,55 @@ class MixNodeRunner(Thread):
inbound_socket: PacketQueue, inbound_socket: PacketQueue,
outbound_socket: PacketPayloadQueue, outbound_socket: PacketPayloadQueue,
): ):
super().__init__()
self.encryption_private_key = encryption_private_key self.encryption_private_key = encryption_private_key
self.delay_rate_per_min = delay_rate_per_min self.delay_rate_per_min = delay_rate_per_min
self.inbound_socket = inbound_socket self.inbound_socket = inbound_socket
self.outbound_socket = outbound_socket self.outbound_socket = outbound_socket
self.num_processing = AtomicInt(0)
def run(self) -> None: async def run(self):
# Here in Python, this thread is implemented in synchronous manner. """
# In the real implementation, consider implementing this in asynchronous if possible, Read SphinxPackets from inbound socket and spawn a thread for each packet to process it.
# to approximate a M/M/inf queue
This thread approximates a M/M/inf queue.
"""
# A set just for gathering a reference of tasks to prevent them from being garbage collected.
# https://docs.python.org/3/library/asyncio-task.html#asyncio.create_task
self.tasks = set()
while True: while True:
_, packet = self.inbound_socket.get() _, packet = await self.inbound_socket.get()
thread = MixNodePacketProcessor( task = asyncio.create_task(
packet, self.process_packet(
self.encryption_private_key, packet,
self.delay_rate_per_min, )
self.outbound_socket,
self.num_processing,
) )
thread.daemon = True self.tasks.add(task)
self.num_processing.add(1) # To discard the task from the set automatically when it is done.
thread.start() task.add_done_callback(self.tasks.discard)
def num_jobs(self) -> int: async def process_packet(
"""
Return the number of packets that are being processed or still in the inbound socket.
If this thread works as a M/M/inf queue completely,
the number of packets that are still in the inbound socket must be always 0.
"""
return self.num_processing.get() + self.inbound_socket.qsize()
class MixNodePacketProcessor(Thread):
"""
Process a single packet with a delay that follows exponential distribution,
and forward it to the next mix node or the mix destination
This thread is a single server (worker) in a M/M/inf queue that MixNodeRunner approximates.
"""
def __init__(
self, self,
packet: SphinxPacket, packet: SphinxPacket,
encryption_private_key: X25519PrivateKey,
delay_rate_per_min: int, # Poisson rate parameter: mu
outbound_socket: PacketPayloadQueue,
num_processing: AtomicInt,
): ):
super().__init__() """
self.packet = packet Process a single packet with a delay that follows exponential distribution,
self.encryption_private_key = encryption_private_key and forward it to the next mix node or the mix destination
self.delay_rate_per_min = delay_rate_per_min
self.outbound_socket = outbound_socket
self.num_processing = num_processing
def run(self) -> None: This thread is a single server (worker) in a M/M/inf queue that MixNodeRunner approximates.
"""
delay_sec = poisson_interval_sec(self.delay_rate_per_min) delay_sec = poisson_interval_sec(self.delay_rate_per_min)
time.sleep(delay_sec) await asyncio.sleep(delay_sec)
processed = self.packet.process(self.encryption_private_key) processed = packet.process(self.encryption_private_key)
match processed: match processed:
case ProcessedForwardHopPacket(): case ProcessedForwardHopPacket():
self.outbound_socket.put( await self.outbound_socket.put(
(processed.next_node_address, processed.next_packet) (processed.next_node_address, processed.next_packet)
) )
case ProcessedFinalHopPacket(): case ProcessedFinalHopPacket():
self.outbound_socket.put( await self.outbound_socket.put(
(processed.destination_node_address, processed.payload) (processed.destination_node_address, processed.payload)
) )
case _: case _:
raise UnknownHeaderTypeError raise UnknownHeaderTypeError
self.num_processing.sub(1)
class AtomicInt:
def __init__(self, initial: int) -> None:
self.lock = threading.Lock()
self.value = initial
def add(self, v: int):
with self.lock:
self.value += v
def sub(self, v: int):
with self.lock:
self.value -= v
def get(self) -> int:
with self.lock:
return self.value

View File

@ -1,54 +1,54 @@
import queue import asyncio
from datetime import datetime from datetime import datetime
from typing import Tuple from typing import Tuple
from unittest import TestCase from unittest import IsolatedAsyncioTestCase
import numpy import numpy
import timeout_decorator
from cryptography.hazmat.primitives.asymmetric.x25519 import X25519PrivateKey from cryptography.hazmat.primitives.asymmetric.x25519 import X25519PrivateKey
from mixnet.bls import generate_bls from mixnet.bls import generate_bls
from mixnet.client import MixClientRunner from mixnet.client import mixclient_emitter
from mixnet.mixnet import Mixnet, MixnetTopology from mixnet.mixnet import Mixnet, MixnetTopology
from mixnet.node import MixNode, PacketQueue from mixnet.node import MixNode, PacketQueue
from mixnet.packet import PacketBuilder from mixnet.packet import PacketBuilder
from mixnet.poisson import poisson_mean_interval_sec from mixnet.poisson import poisson_mean_interval_sec
from mixnet.utils import random_bytes from mixnet.utils import random_bytes
from mixnet.test_utils import with_test_timeout
class TestMixClientRunner(TestCase): class TestMixClient(IsolatedAsyncioTestCase):
@timeout_decorator.timeout(180) @with_test_timeout(100)
def test_mixclient_runner_emission_rate(self): async def test_mixclient_emitter(self):
mixnet, topology = self.init() mixnet, topology = self.init()
real_packet_queue: PacketQueue = queue.Queue() real_packet_queue: PacketQueue = asyncio.Queue()
outbound_socket: PacketQueue = queue.Queue() outbound_socket: PacketQueue = asyncio.Queue()
emission_rate_per_min = 30 emission_rate_per_min = 30
redundancy = 3 redundancy = 3
client = MixClientRunner( _ = asyncio.create_task(
mixnet, mixclient_emitter(
topology, mixnet,
emission_rate_per_min, topology,
redundancy, emission_rate_per_min,
real_packet_queue, redundancy,
outbound_socket, real_packet_queue,
outbound_socket,
)
) )
client.daemon = True
client.start()
# Create packets. At least two packets are expected to be generated from a 3500-byte msg # Create packets. At least two packets are expected to be generated from a 3500-byte msg
builder = PacketBuilder.real(random_bytes(3500), mixnet, topology) builder = PacketBuilder.real(random_bytes(3500), mixnet, topology)
# Schedule two packets to the mix client without any interval # Schedule two packets to the mix client without any interval
packet, route = builder.next() packet, route = builder.next()
real_packet_queue.put((route[0].addr, packet)) await real_packet_queue.put((route[0].addr, packet))
packet, route = builder.next() packet, route = builder.next()
real_packet_queue.put((route[0].addr, packet)) await real_packet_queue.put((route[0].addr, packet))
# Calculate intervals between packet emissions from the mix client # Calculate intervals between packet emissions from the mix client
intervals = [] intervals = []
ts = datetime.now() ts = datetime.now()
for _ in range(30): for _ in range(30):
_ = outbound_socket.get() _ = await outbound_socket.get()
now = datetime.now() now = datetime.now()
intervals.append((now - ts).total_seconds()) intervals.append((now - ts).total_seconds())
ts = now ts = now

View File

@ -1,12 +1,9 @@
import queue import asyncio
import threading
import time
from datetime import datetime from datetime import datetime
from typing import Tuple from typing import Tuple
from unittest import TestCase from unittest import IsolatedAsyncioTestCase
import numpy import numpy
import timeout_decorator
from cryptography.hazmat.primitives.asymmetric.x25519 import X25519PrivateKey from cryptography.hazmat.primitives.asymmetric.x25519 import X25519PrivateKey
from pysphinx.sphinx import SphinxPacket from pysphinx.sphinx import SphinxPacket
@ -15,12 +12,13 @@ from mixnet.mixnet import Mixnet, MixnetTopology
from mixnet.node import MixNode, NodeAddress, PacketPayloadQueue, PacketQueue from mixnet.node import MixNode, NodeAddress, PacketPayloadQueue, PacketQueue
from mixnet.packet import PacketBuilder from mixnet.packet import PacketBuilder
from mixnet.poisson import poisson_interval_sec, poisson_mean_interval_sec from mixnet.poisson import poisson_interval_sec, poisson_mean_interval_sec
from mixnet.test_utils import with_test_timeout
from mixnet.utils import random_bytes from mixnet.utils import random_bytes
class TestMixNodeRunner(TestCase): class TestMixNodeRunner(IsolatedAsyncioTestCase):
@timeout_decorator.timeout(180) @with_test_timeout(180)
def test_mixnode_runner_emission_rate(self): async def test_mixnode_runner_emission_rate(self):
""" """
Test if MixNodeRunner works as a M/M/inf queue. Test if MixNodeRunner works as a M/M/inf queue.
@ -29,41 +27,49 @@ class TestMixNodeRunner(TestCase):
the rate of outputs should be `lambda`. the rate of outputs should be `lambda`.
""" """
mixnet, topology = self.init() mixnet, topology = self.init()
inbound_socket: PacketQueue = queue.Queue() inbound_socket: PacketQueue = asyncio.Queue()
outbound_socket: PacketPayloadQueue = queue.Queue() outbound_socket: PacketPayloadQueue = asyncio.Queue()
packet, route = PacketBuilder.real(b"msg", mixnet, topology).next() packet, route = PacketBuilder.real(b"msg", mixnet, topology).next()
delay_rate_per_min = 30 # mu (= 2s delay on average) delay_rate_per_min = 30 # mu (= 2s delay on average)
# Start only the first mix node for testing # Start only the first mix node for testing
runner = route[0].start(delay_rate_per_min, inbound_socket, outbound_socket) _ = route[0].start(delay_rate_per_min, inbound_socket, outbound_socket)
# Send packets to the first mix node in a Poisson distribution # Send packets to the first mix node in a Poisson distribution
packet_count = 100 packet_count = 100
emission_rate_per_min = 120 # lambda (= 2msg/sec) emission_rate_per_min = 120 # lambda (= 2msg/sec)
sender = threading.Thread( # This queue is just for counting how many packets have been sent so far.
target=self.send_packets, sent_packet_queue: PacketQueue = asyncio.Queue()
args=( _ = asyncio.create_task(
self.send_packets(
inbound_socket, inbound_socket,
packet, packet,
route[0].addr, route[0].addr,
packet_count, packet_count,
emission_rate_per_min, emission_rate_per_min,
), sent_packet_queue,
)
) )
sender.daemon = True
sender.start()
# Calculate intervals between outputs and gather num_jobs in the first mix node. # Calculate intervals between outputs and gather num_jobs in the first mix node.
intervals = [] intervals = []
num_jobs = [] num_jobs = []
ts = datetime.now() ts = datetime.now()
for _ in range(packet_count): for _ in range(packet_count):
_ = outbound_socket.get() _ = await outbound_socket.get()
now = datetime.now() now = datetime.now()
intervals.append((now - ts).total_seconds()) intervals.append((now - ts).total_seconds())
num_jobs.append(runner.num_jobs())
# Calculate the current # of jobs staying in the mix node
num_packets_emitted_from_mixnode = len(intervals)
num_packets_sent_to_mixnode = sent_packet_queue.qsize()
num_jobs.append(
num_packets_sent_to_mixnode - num_packets_emitted_from_mixnode
)
ts = now ts = now
# Remove the first interval that would be much larger than other intervals, # Remove the first interval that would be much larger than other intervals,
# because of the delay in mix node. # because of the delay in mix node.
intervals = intervals[1:] intervals = intervals[1:]
@ -87,16 +93,20 @@ class TestMixNodeRunner(TestCase):
) )
@staticmethod @staticmethod
def send_packets( async def send_packets(
inbound_socket: PacketQueue, inbound_socket: PacketQueue,
packet: SphinxPacket, packet: SphinxPacket,
node_addr: NodeAddress, node_addr: NodeAddress,
cnt: int, cnt: int,
rate_per_min: int, rate_per_min: int,
# For testing purpose, to inform the caller how many packets have been sent to the inbound_socket
sent_packet_queue: PacketQueue,
): ):
for _ in range(cnt): for _ in range(cnt):
time.sleep(poisson_interval_sec(rate_per_min)) # Since the task is not heavy, just sleep for seconds instead of using emission_notifier
inbound_socket.put((node_addr, packet)) await asyncio.sleep(poisson_interval_sec(rate_per_min))
await inbound_socket.put((node_addr, packet))
await sent_packet_queue.put((node_addr, packet))
@staticmethod @staticmethod
def init() -> Tuple[Mixnet, MixnetTopology]: def init() -> Tuple[Mixnet, MixnetTopology]:

12
mixnet/test_utils.py Normal file
View File

@ -0,0 +1,12 @@
import asyncio
def with_test_timeout(t):
def wrapper(coroutine):
async def run(*args, **kwargs):
async with asyncio.timeout(t):
return await coroutine(*args, **kwargs)
return run
return wrapper

View File

@ -5,6 +5,3 @@ numpy==1.26.3
pycparser==2.21 pycparser==2.21
pysphinx==0.0.1 pysphinx==0.0.1
scipy==1.11.4 scipy==1.11.4
setuptools==69.0.3
timeout-decorator==0.5.0
wheel==0.42.0