Mixnet: Refactor with asyncio (#53)

This commit is contained in:
Youngjoon Lee 2024-01-25 18:04:55 +09:00 committed by GitHub
parent 30d52791c3
commit ebc069b112
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 165 additions and 186 deletions

View File

@ -1,9 +1,6 @@
from __future__ import annotations
import queue
import time
from datetime import datetime, timedelta
from threading import Thread
import asyncio
from mixnet.mixnet import Mixnet, MixnetTopology
from mixnet.node import PacketQueue
@ -11,7 +8,14 @@ from mixnet.packet import PacketBuilder
from mixnet.poisson import poisson_interval_sec
class MixClientRunner(Thread):
async def mixclient_emitter(
mixnet: Mixnet,
topology: MixnetTopology,
emission_rate_per_min: int, # Poisson rate parameter: lambda in the spec
redundancy: int, # b in the spec
real_packet_queue: PacketQueue,
outbound_socket: PacketQueue,
):
"""
Emit packets at the Poisson emission_rate_per_min.
@ -21,55 +25,55 @@ class MixClientRunner(Thread):
If no real packet is not scheduled, this thread emits a cover packet according to the emission_rate_per_min.
"""
def __init__(
self,
mixnet: Mixnet,
topology: MixnetTopology,
emission_rate_per_min: int, # Poisson rate parameter: lambda in the spec
redundancy: int, # b in the spec
real_packet_queue: PacketQueue,
outbound_socket: PacketQueue,
):
super().__init__()
self.mixnet = mixnet
self.topology = topology
self.emission_rate_per_min = emission_rate_per_min
self.redundancy = redundancy
self.real_packet_queue = real_packet_queue
self.redundant_real_packet_queue: PacketQueue = queue.Queue()
self.outbound_socket = outbound_socket
redundant_real_packet_queue: PacketQueue = asyncio.Queue()
def run(self) -> None:
# Here in Python, this thread is implemented in synchronous manner.
# In the real implementation, consider implementing this in asynchronous if possible.
emission_notifier_queue = asyncio.Queue()
_ = asyncio.create_task(
emission_notifier(emission_rate_per_min, emission_notifier_queue)
)
next_emission_ts = datetime.now() + timedelta(
seconds=poisson_interval_sec(self.emission_rate_per_min)
)
while True:
time.sleep(1 / 1000)
if datetime.now() < next_emission_ts:
continue
next_emission_ts += timedelta(
seconds=poisson_interval_sec(self.emission_rate_per_min)
while True:
# Wait until the next emission time
_ = await emission_notifier_queue.get()
try:
await emit(
mixnet,
topology,
redundancy,
real_packet_queue,
redundant_real_packet_queue,
outbound_socket,
)
finally:
# Python convention: indicate that the previously enqueued task has been processed
emission_notifier_queue.task_done()
if not self.redundant_real_packet_queue.empty():
addr, packet = self.redundant_real_packet_queue.get()
self.outbound_socket.put((addr, packet))
continue
if not self.real_packet_queue.empty():
addr, packet = self.real_packet_queue.get()
# Schedule redundant real packets
for _ in range(self.redundancy - 1):
self.redundant_real_packet_queue.put((addr, packet))
self.outbound_socket.put((addr, packet))
async def emit(
mixnet: Mixnet,
topology: MixnetTopology,
redundancy: int, # b in the spec
real_packet_queue: PacketQueue,
redundant_real_packet_queue: PacketQueue,
outbound_socket: PacketQueue,
):
if not redundant_real_packet_queue.empty():
addr, packet = redundant_real_packet_queue.get_nowait()
await outbound_socket.put((addr, packet))
return
packet, route = PacketBuilder.drop_cover(
b"drop cover", self.mixnet, self.topology
).next()
self.outbound_socket.put((route[0].addr, packet))
if not real_packet_queue.empty():
addr, packet = real_packet_queue.get_nowait()
# Schedule redundant real packets
for _ in range(redundancy - 1):
redundant_real_packet_queue.put_nowait((addr, packet))
await outbound_socket.put((addr, packet))
packet, route = PacketBuilder.drop_cover(b"drop cover", mixnet, topology).next()
await outbound_socket.put((route[0].addr, packet))
async def emission_notifier(emission_rate_per_min: int, queue: asyncio.Queue):
while True:
await asyncio.sleep(poisson_interval_sec(emission_rate_per_min))
queue.put_nowait(None)

View File

@ -1,10 +1,7 @@
from __future__ import annotations
import queue
import threading
import time
import asyncio
from dataclasses import dataclass
from threading import Thread
from typing import Tuple, TypeAlias
from cryptography.hazmat.primitives.asymmetric.x25519 import (
@ -27,9 +24,9 @@ NodeId: TypeAlias = BlsPublicKey
# 32-byte that represents an IP address and a port of a mix node.
NodeAddress: TypeAlias = bytes
PacketQueue: TypeAlias = "queue.Queue[Tuple[NodeAddress, SphinxPacket]]"
PacketQueue: TypeAlias = "asyncio.Queue[Tuple[NodeAddress, SphinxPacket]]"
PacketPayloadQueue: TypeAlias = (
"queue.Queue[Tuple[NodeAddress, SphinxPacket | Payload]]"
"asyncio.Queue[Tuple[NodeAddress, SphinxPacket | Payload]]"
)
@ -50,26 +47,26 @@ class MixNode:
def start(
self,
delay_rate_per_min: int,
delay_rate_per_min: int, # Poisson rate parameter: mu
inbound_socket: PacketQueue,
outbound_socket: PacketPayloadQueue,
) -> MixNodeRunner:
thread = MixNodeRunner(
self.encryption_private_key,
delay_rate_per_min,
inbound_socket,
outbound_socket,
) -> asyncio.Task:
return asyncio.create_task(
MixNodeRunner(
self.encryption_private_key,
delay_rate_per_min,
inbound_socket,
outbound_socket,
).run()
)
thread.daemon = True
thread.start()
return thread
class MixNodeRunner(Thread):
class MixNodeRunner:
"""
Read SphinxPackets from inbound socket and spawn a thread for each packet to process it.
A class handling incoming packets with delays
This thread approximates a M/M/inf queue.
This class is defined separated with the MixNode class,
in order to define the MixNode as a simple dataclass for clarity.
"""
def __init__(
@ -79,96 +76,55 @@ class MixNodeRunner(Thread):
inbound_socket: PacketQueue,
outbound_socket: PacketPayloadQueue,
):
super().__init__()
self.encryption_private_key = encryption_private_key
self.delay_rate_per_min = delay_rate_per_min
self.inbound_socket = inbound_socket
self.outbound_socket = outbound_socket
self.num_processing = AtomicInt(0)
def run(self) -> None:
# Here in Python, this thread is implemented in synchronous manner.
# In the real implementation, consider implementing this in asynchronous if possible,
# to approximate a M/M/inf queue
async def run(self):
"""
Read SphinxPackets from inbound socket and spawn a thread for each packet to process it.
This thread approximates a M/M/inf queue.
"""
# A set just for gathering a reference of tasks to prevent them from being garbage collected.
# https://docs.python.org/3/library/asyncio-task.html#asyncio.create_task
self.tasks = set()
while True:
_, packet = self.inbound_socket.get()
thread = MixNodePacketProcessor(
packet,
self.encryption_private_key,
self.delay_rate_per_min,
self.outbound_socket,
self.num_processing,
_, packet = await self.inbound_socket.get()
task = asyncio.create_task(
self.process_packet(
packet,
)
)
thread.daemon = True
self.num_processing.add(1)
thread.start()
self.tasks.add(task)
# To discard the task from the set automatically when it is done.
task.add_done_callback(self.tasks.discard)
def num_jobs(self) -> int:
"""
Return the number of packets that are being processed or still in the inbound socket.
If this thread works as a M/M/inf queue completely,
the number of packets that are still in the inbound socket must be always 0.
"""
return self.num_processing.get() + self.inbound_socket.qsize()
class MixNodePacketProcessor(Thread):
"""
Process a single packet with a delay that follows exponential distribution,
and forward it to the next mix node or the mix destination
This thread is a single server (worker) in a M/M/inf queue that MixNodeRunner approximates.
"""
def __init__(
async def process_packet(
self,
packet: SphinxPacket,
encryption_private_key: X25519PrivateKey,
delay_rate_per_min: int, # Poisson rate parameter: mu
outbound_socket: PacketPayloadQueue,
num_processing: AtomicInt,
):
super().__init__()
self.packet = packet
self.encryption_private_key = encryption_private_key
self.delay_rate_per_min = delay_rate_per_min
self.outbound_socket = outbound_socket
self.num_processing = num_processing
"""
Process a single packet with a delay that follows exponential distribution,
and forward it to the next mix node or the mix destination
def run(self) -> None:
This thread is a single server (worker) in a M/M/inf queue that MixNodeRunner approximates.
"""
delay_sec = poisson_interval_sec(self.delay_rate_per_min)
time.sleep(delay_sec)
await asyncio.sleep(delay_sec)
processed = self.packet.process(self.encryption_private_key)
processed = packet.process(self.encryption_private_key)
match processed:
case ProcessedForwardHopPacket():
self.outbound_socket.put(
await self.outbound_socket.put(
(processed.next_node_address, processed.next_packet)
)
case ProcessedFinalHopPacket():
self.outbound_socket.put(
await self.outbound_socket.put(
(processed.destination_node_address, processed.payload)
)
case _:
raise UnknownHeaderTypeError
self.num_processing.sub(1)
class AtomicInt:
def __init__(self, initial: int) -> None:
self.lock = threading.Lock()
self.value = initial
def add(self, v: int):
with self.lock:
self.value += v
def sub(self, v: int):
with self.lock:
self.value -= v
def get(self) -> int:
with self.lock:
return self.value

View File

@ -1,54 +1,54 @@
import queue
import asyncio
from datetime import datetime
from typing import Tuple
from unittest import TestCase
from unittest import IsolatedAsyncioTestCase
import numpy
import timeout_decorator
from cryptography.hazmat.primitives.asymmetric.x25519 import X25519PrivateKey
from mixnet.bls import generate_bls
from mixnet.client import MixClientRunner
from mixnet.client import mixclient_emitter
from mixnet.mixnet import Mixnet, MixnetTopology
from mixnet.node import MixNode, PacketQueue
from mixnet.packet import PacketBuilder
from mixnet.poisson import poisson_mean_interval_sec
from mixnet.utils import random_bytes
from mixnet.test_utils import with_test_timeout
class TestMixClientRunner(TestCase):
@timeout_decorator.timeout(180)
def test_mixclient_runner_emission_rate(self):
class TestMixClient(IsolatedAsyncioTestCase):
@with_test_timeout(100)
async def test_mixclient_emitter(self):
mixnet, topology = self.init()
real_packet_queue: PacketQueue = queue.Queue()
outbound_socket: PacketQueue = queue.Queue()
real_packet_queue: PacketQueue = asyncio.Queue()
outbound_socket: PacketQueue = asyncio.Queue()
emission_rate_per_min = 30
redundancy = 3
client = MixClientRunner(
mixnet,
topology,
emission_rate_per_min,
redundancy,
real_packet_queue,
outbound_socket,
_ = asyncio.create_task(
mixclient_emitter(
mixnet,
topology,
emission_rate_per_min,
redundancy,
real_packet_queue,
outbound_socket,
)
)
client.daemon = True
client.start()
# Create packets. At least two packets are expected to be generated from a 3500-byte msg
builder = PacketBuilder.real(random_bytes(3500), mixnet, topology)
# Schedule two packets to the mix client without any interval
packet, route = builder.next()
real_packet_queue.put((route[0].addr, packet))
await real_packet_queue.put((route[0].addr, packet))
packet, route = builder.next()
real_packet_queue.put((route[0].addr, packet))
await real_packet_queue.put((route[0].addr, packet))
# Calculate intervals between packet emissions from the mix client
intervals = []
ts = datetime.now()
for _ in range(30):
_ = outbound_socket.get()
_ = await outbound_socket.get()
now = datetime.now()
intervals.append((now - ts).total_seconds())
ts = now

View File

@ -1,12 +1,9 @@
import queue
import threading
import time
import asyncio
from datetime import datetime
from typing import Tuple
from unittest import TestCase
from unittest import IsolatedAsyncioTestCase
import numpy
import timeout_decorator
from cryptography.hazmat.primitives.asymmetric.x25519 import X25519PrivateKey
from pysphinx.sphinx import SphinxPacket
@ -15,12 +12,13 @@ from mixnet.mixnet import Mixnet, MixnetTopology
from mixnet.node import MixNode, NodeAddress, PacketPayloadQueue, PacketQueue
from mixnet.packet import PacketBuilder
from mixnet.poisson import poisson_interval_sec, poisson_mean_interval_sec
from mixnet.test_utils import with_test_timeout
from mixnet.utils import random_bytes
class TestMixNodeRunner(TestCase):
@timeout_decorator.timeout(180)
def test_mixnode_runner_emission_rate(self):
class TestMixNodeRunner(IsolatedAsyncioTestCase):
@with_test_timeout(180)
async def test_mixnode_runner_emission_rate(self):
"""
Test if MixNodeRunner works as a M/M/inf queue.
@ -29,41 +27,49 @@ class TestMixNodeRunner(TestCase):
the rate of outputs should be `lambda`.
"""
mixnet, topology = self.init()
inbound_socket: PacketQueue = queue.Queue()
outbound_socket: PacketPayloadQueue = queue.Queue()
inbound_socket: PacketQueue = asyncio.Queue()
outbound_socket: PacketPayloadQueue = asyncio.Queue()
packet, route = PacketBuilder.real(b"msg", mixnet, topology).next()
delay_rate_per_min = 30 # mu (= 2s delay on average)
# Start only the first mix node for testing
runner = route[0].start(delay_rate_per_min, inbound_socket, outbound_socket)
_ = route[0].start(delay_rate_per_min, inbound_socket, outbound_socket)
# Send packets to the first mix node in a Poisson distribution
packet_count = 100
emission_rate_per_min = 120 # lambda (= 2msg/sec)
sender = threading.Thread(
target=self.send_packets,
args=(
# This queue is just for counting how many packets have been sent so far.
sent_packet_queue: PacketQueue = asyncio.Queue()
_ = asyncio.create_task(
self.send_packets(
inbound_socket,
packet,
route[0].addr,
packet_count,
emission_rate_per_min,
),
sent_packet_queue,
)
)
sender.daemon = True
sender.start()
# Calculate intervals between outputs and gather num_jobs in the first mix node.
intervals = []
num_jobs = []
ts = datetime.now()
for _ in range(packet_count):
_ = outbound_socket.get()
_ = await outbound_socket.get()
now = datetime.now()
intervals.append((now - ts).total_seconds())
num_jobs.append(runner.num_jobs())
# Calculate the current # of jobs staying in the mix node
num_packets_emitted_from_mixnode = len(intervals)
num_packets_sent_to_mixnode = sent_packet_queue.qsize()
num_jobs.append(
num_packets_sent_to_mixnode - num_packets_emitted_from_mixnode
)
ts = now
# Remove the first interval that would be much larger than other intervals,
# because of the delay in mix node.
intervals = intervals[1:]
@ -87,16 +93,20 @@ class TestMixNodeRunner(TestCase):
)
@staticmethod
def send_packets(
async def send_packets(
inbound_socket: PacketQueue,
packet: SphinxPacket,
node_addr: NodeAddress,
cnt: int,
rate_per_min: int,
# For testing purpose, to inform the caller how many packets have been sent to the inbound_socket
sent_packet_queue: PacketQueue,
):
for _ in range(cnt):
time.sleep(poisson_interval_sec(rate_per_min))
inbound_socket.put((node_addr, packet))
# Since the task is not heavy, just sleep for seconds instead of using emission_notifier
await asyncio.sleep(poisson_interval_sec(rate_per_min))
await inbound_socket.put((node_addr, packet))
await sent_packet_queue.put((node_addr, packet))
@staticmethod
def init() -> Tuple[Mixnet, MixnetTopology]:

12
mixnet/test_utils.py Normal file
View File

@ -0,0 +1,12 @@
import asyncio
def with_test_timeout(t):
def wrapper(coroutine):
async def run(*args, **kwargs):
async with asyncio.timeout(t):
return await coroutine(*args, **kwargs)
return run
return wrapper

View File

@ -5,6 +5,3 @@ numpy==1.26.3
pycparser==2.21
pysphinx==0.0.1
scipy==1.11.4
setuptools==69.0.3
timeout-decorator==0.5.0
wheel==0.42.0