Mixnet: Refactor with asyncio (#53)
This commit is contained in:
parent
30d52791c3
commit
ebc069b112
106
mixnet/client.py
106
mixnet/client.py
|
@ -1,9 +1,6 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import queue
|
||||
import time
|
||||
from datetime import datetime, timedelta
|
||||
from threading import Thread
|
||||
import asyncio
|
||||
|
||||
from mixnet.mixnet import Mixnet, MixnetTopology
|
||||
from mixnet.node import PacketQueue
|
||||
|
@ -11,7 +8,14 @@ from mixnet.packet import PacketBuilder
|
|||
from mixnet.poisson import poisson_interval_sec
|
||||
|
||||
|
||||
class MixClientRunner(Thread):
|
||||
async def mixclient_emitter(
|
||||
mixnet: Mixnet,
|
||||
topology: MixnetTopology,
|
||||
emission_rate_per_min: int, # Poisson rate parameter: lambda in the spec
|
||||
redundancy: int, # b in the spec
|
||||
real_packet_queue: PacketQueue,
|
||||
outbound_socket: PacketQueue,
|
||||
):
|
||||
"""
|
||||
Emit packets at the Poisson emission_rate_per_min.
|
||||
|
||||
|
@ -21,55 +25,55 @@ class MixClientRunner(Thread):
|
|||
If no real packet is not scheduled, this thread emits a cover packet according to the emission_rate_per_min.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
mixnet: Mixnet,
|
||||
topology: MixnetTopology,
|
||||
emission_rate_per_min: int, # Poisson rate parameter: lambda in the spec
|
||||
redundancy: int, # b in the spec
|
||||
real_packet_queue: PacketQueue,
|
||||
outbound_socket: PacketQueue,
|
||||
):
|
||||
super().__init__()
|
||||
self.mixnet = mixnet
|
||||
self.topology = topology
|
||||
self.emission_rate_per_min = emission_rate_per_min
|
||||
self.redundancy = redundancy
|
||||
self.real_packet_queue = real_packet_queue
|
||||
self.redundant_real_packet_queue: PacketQueue = queue.Queue()
|
||||
self.outbound_socket = outbound_socket
|
||||
redundant_real_packet_queue: PacketQueue = asyncio.Queue()
|
||||
|
||||
def run(self) -> None:
|
||||
# Here in Python, this thread is implemented in synchronous manner.
|
||||
# In the real implementation, consider implementing this in asynchronous if possible.
|
||||
emission_notifier_queue = asyncio.Queue()
|
||||
_ = asyncio.create_task(
|
||||
emission_notifier(emission_rate_per_min, emission_notifier_queue)
|
||||
)
|
||||
|
||||
next_emission_ts = datetime.now() + timedelta(
|
||||
seconds=poisson_interval_sec(self.emission_rate_per_min)
|
||||
)
|
||||
|
||||
while True:
|
||||
time.sleep(1 / 1000)
|
||||
|
||||
if datetime.now() < next_emission_ts:
|
||||
continue
|
||||
|
||||
next_emission_ts += timedelta(
|
||||
seconds=poisson_interval_sec(self.emission_rate_per_min)
|
||||
while True:
|
||||
# Wait until the next emission time
|
||||
_ = await emission_notifier_queue.get()
|
||||
try:
|
||||
await emit(
|
||||
mixnet,
|
||||
topology,
|
||||
redundancy,
|
||||
real_packet_queue,
|
||||
redundant_real_packet_queue,
|
||||
outbound_socket,
|
||||
)
|
||||
finally:
|
||||
# Python convention: indicate that the previously enqueued task has been processed
|
||||
emission_notifier_queue.task_done()
|
||||
|
||||
if not self.redundant_real_packet_queue.empty():
|
||||
addr, packet = self.redundant_real_packet_queue.get()
|
||||
self.outbound_socket.put((addr, packet))
|
||||
continue
|
||||
|
||||
if not self.real_packet_queue.empty():
|
||||
addr, packet = self.real_packet_queue.get()
|
||||
# Schedule redundant real packets
|
||||
for _ in range(self.redundancy - 1):
|
||||
self.redundant_real_packet_queue.put((addr, packet))
|
||||
self.outbound_socket.put((addr, packet))
|
||||
async def emit(
|
||||
mixnet: Mixnet,
|
||||
topology: MixnetTopology,
|
||||
redundancy: int, # b in the spec
|
||||
real_packet_queue: PacketQueue,
|
||||
redundant_real_packet_queue: PacketQueue,
|
||||
outbound_socket: PacketQueue,
|
||||
):
|
||||
if not redundant_real_packet_queue.empty():
|
||||
addr, packet = redundant_real_packet_queue.get_nowait()
|
||||
await outbound_socket.put((addr, packet))
|
||||
return
|
||||
|
||||
packet, route = PacketBuilder.drop_cover(
|
||||
b"drop cover", self.mixnet, self.topology
|
||||
).next()
|
||||
self.outbound_socket.put((route[0].addr, packet))
|
||||
if not real_packet_queue.empty():
|
||||
addr, packet = real_packet_queue.get_nowait()
|
||||
# Schedule redundant real packets
|
||||
for _ in range(redundancy - 1):
|
||||
redundant_real_packet_queue.put_nowait((addr, packet))
|
||||
await outbound_socket.put((addr, packet))
|
||||
|
||||
packet, route = PacketBuilder.drop_cover(b"drop cover", mixnet, topology).next()
|
||||
await outbound_socket.put((route[0].addr, packet))
|
||||
|
||||
|
||||
async def emission_notifier(emission_rate_per_min: int, queue: asyncio.Queue):
|
||||
while True:
|
||||
await asyncio.sleep(poisson_interval_sec(emission_rate_per_min))
|
||||
queue.put_nowait(None)
|
||||
|
|
134
mixnet/node.py
134
mixnet/node.py
|
@ -1,10 +1,7 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import queue
|
||||
import threading
|
||||
import time
|
||||
import asyncio
|
||||
from dataclasses import dataclass
|
||||
from threading import Thread
|
||||
from typing import Tuple, TypeAlias
|
||||
|
||||
from cryptography.hazmat.primitives.asymmetric.x25519 import (
|
||||
|
@ -27,9 +24,9 @@ NodeId: TypeAlias = BlsPublicKey
|
|||
# 32-byte that represents an IP address and a port of a mix node.
|
||||
NodeAddress: TypeAlias = bytes
|
||||
|
||||
PacketQueue: TypeAlias = "queue.Queue[Tuple[NodeAddress, SphinxPacket]]"
|
||||
PacketQueue: TypeAlias = "asyncio.Queue[Tuple[NodeAddress, SphinxPacket]]"
|
||||
PacketPayloadQueue: TypeAlias = (
|
||||
"queue.Queue[Tuple[NodeAddress, SphinxPacket | Payload]]"
|
||||
"asyncio.Queue[Tuple[NodeAddress, SphinxPacket | Payload]]"
|
||||
)
|
||||
|
||||
|
||||
|
@ -50,26 +47,26 @@ class MixNode:
|
|||
|
||||
def start(
|
||||
self,
|
||||
delay_rate_per_min: int,
|
||||
delay_rate_per_min: int, # Poisson rate parameter: mu
|
||||
inbound_socket: PacketQueue,
|
||||
outbound_socket: PacketPayloadQueue,
|
||||
) -> MixNodeRunner:
|
||||
thread = MixNodeRunner(
|
||||
self.encryption_private_key,
|
||||
delay_rate_per_min,
|
||||
inbound_socket,
|
||||
outbound_socket,
|
||||
) -> asyncio.Task:
|
||||
return asyncio.create_task(
|
||||
MixNodeRunner(
|
||||
self.encryption_private_key,
|
||||
delay_rate_per_min,
|
||||
inbound_socket,
|
||||
outbound_socket,
|
||||
).run()
|
||||
)
|
||||
thread.daemon = True
|
||||
thread.start()
|
||||
return thread
|
||||
|
||||
|
||||
class MixNodeRunner(Thread):
|
||||
class MixNodeRunner:
|
||||
"""
|
||||
Read SphinxPackets from inbound socket and spawn a thread for each packet to process it.
|
||||
A class handling incoming packets with delays
|
||||
|
||||
This thread approximates a M/M/inf queue.
|
||||
This class is defined separated with the MixNode class,
|
||||
in order to define the MixNode as a simple dataclass for clarity.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
|
@ -79,96 +76,55 @@ class MixNodeRunner(Thread):
|
|||
inbound_socket: PacketQueue,
|
||||
outbound_socket: PacketPayloadQueue,
|
||||
):
|
||||
super().__init__()
|
||||
self.encryption_private_key = encryption_private_key
|
||||
self.delay_rate_per_min = delay_rate_per_min
|
||||
self.inbound_socket = inbound_socket
|
||||
self.outbound_socket = outbound_socket
|
||||
self.num_processing = AtomicInt(0)
|
||||
|
||||
def run(self) -> None:
|
||||
# Here in Python, this thread is implemented in synchronous manner.
|
||||
# In the real implementation, consider implementing this in asynchronous if possible,
|
||||
# to approximate a M/M/inf queue
|
||||
async def run(self):
|
||||
"""
|
||||
Read SphinxPackets from inbound socket and spawn a thread for each packet to process it.
|
||||
|
||||
This thread approximates a M/M/inf queue.
|
||||
"""
|
||||
|
||||
# A set just for gathering a reference of tasks to prevent them from being garbage collected.
|
||||
# https://docs.python.org/3/library/asyncio-task.html#asyncio.create_task
|
||||
self.tasks = set()
|
||||
|
||||
while True:
|
||||
_, packet = self.inbound_socket.get()
|
||||
thread = MixNodePacketProcessor(
|
||||
packet,
|
||||
self.encryption_private_key,
|
||||
self.delay_rate_per_min,
|
||||
self.outbound_socket,
|
||||
self.num_processing,
|
||||
_, packet = await self.inbound_socket.get()
|
||||
task = asyncio.create_task(
|
||||
self.process_packet(
|
||||
packet,
|
||||
)
|
||||
)
|
||||
thread.daemon = True
|
||||
self.num_processing.add(1)
|
||||
thread.start()
|
||||
self.tasks.add(task)
|
||||
# To discard the task from the set automatically when it is done.
|
||||
task.add_done_callback(self.tasks.discard)
|
||||
|
||||
def num_jobs(self) -> int:
|
||||
"""
|
||||
Return the number of packets that are being processed or still in the inbound socket.
|
||||
|
||||
If this thread works as a M/M/inf queue completely,
|
||||
the number of packets that are still in the inbound socket must be always 0.
|
||||
"""
|
||||
return self.num_processing.get() + self.inbound_socket.qsize()
|
||||
|
||||
|
||||
class MixNodePacketProcessor(Thread):
|
||||
"""
|
||||
Process a single packet with a delay that follows exponential distribution,
|
||||
and forward it to the next mix node or the mix destination
|
||||
|
||||
This thread is a single server (worker) in a M/M/inf queue that MixNodeRunner approximates.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
async def process_packet(
|
||||
self,
|
||||
packet: SphinxPacket,
|
||||
encryption_private_key: X25519PrivateKey,
|
||||
delay_rate_per_min: int, # Poisson rate parameter: mu
|
||||
outbound_socket: PacketPayloadQueue,
|
||||
num_processing: AtomicInt,
|
||||
):
|
||||
super().__init__()
|
||||
self.packet = packet
|
||||
self.encryption_private_key = encryption_private_key
|
||||
self.delay_rate_per_min = delay_rate_per_min
|
||||
self.outbound_socket = outbound_socket
|
||||
self.num_processing = num_processing
|
||||
"""
|
||||
Process a single packet with a delay that follows exponential distribution,
|
||||
and forward it to the next mix node or the mix destination
|
||||
|
||||
def run(self) -> None:
|
||||
This thread is a single server (worker) in a M/M/inf queue that MixNodeRunner approximates.
|
||||
"""
|
||||
delay_sec = poisson_interval_sec(self.delay_rate_per_min)
|
||||
time.sleep(delay_sec)
|
||||
await asyncio.sleep(delay_sec)
|
||||
|
||||
processed = self.packet.process(self.encryption_private_key)
|
||||
processed = packet.process(self.encryption_private_key)
|
||||
match processed:
|
||||
case ProcessedForwardHopPacket():
|
||||
self.outbound_socket.put(
|
||||
await self.outbound_socket.put(
|
||||
(processed.next_node_address, processed.next_packet)
|
||||
)
|
||||
case ProcessedFinalHopPacket():
|
||||
self.outbound_socket.put(
|
||||
await self.outbound_socket.put(
|
||||
(processed.destination_node_address, processed.payload)
|
||||
)
|
||||
case _:
|
||||
raise UnknownHeaderTypeError
|
||||
|
||||
self.num_processing.sub(1)
|
||||
|
||||
|
||||
class AtomicInt:
|
||||
def __init__(self, initial: int) -> None:
|
||||
self.lock = threading.Lock()
|
||||
self.value = initial
|
||||
|
||||
def add(self, v: int):
|
||||
with self.lock:
|
||||
self.value += v
|
||||
|
||||
def sub(self, v: int):
|
||||
with self.lock:
|
||||
self.value -= v
|
||||
|
||||
def get(self) -> int:
|
||||
with self.lock:
|
||||
return self.value
|
||||
|
|
|
@ -1,54 +1,54 @@
|
|||
import queue
|
||||
import asyncio
|
||||
from datetime import datetime
|
||||
from typing import Tuple
|
||||
from unittest import TestCase
|
||||
from unittest import IsolatedAsyncioTestCase
|
||||
|
||||
import numpy
|
||||
import timeout_decorator
|
||||
from cryptography.hazmat.primitives.asymmetric.x25519 import X25519PrivateKey
|
||||
|
||||
from mixnet.bls import generate_bls
|
||||
from mixnet.client import MixClientRunner
|
||||
from mixnet.client import mixclient_emitter
|
||||
from mixnet.mixnet import Mixnet, MixnetTopology
|
||||
from mixnet.node import MixNode, PacketQueue
|
||||
from mixnet.packet import PacketBuilder
|
||||
from mixnet.poisson import poisson_mean_interval_sec
|
||||
from mixnet.utils import random_bytes
|
||||
from mixnet.test_utils import with_test_timeout
|
||||
|
||||
|
||||
class TestMixClientRunner(TestCase):
|
||||
@timeout_decorator.timeout(180)
|
||||
def test_mixclient_runner_emission_rate(self):
|
||||
class TestMixClient(IsolatedAsyncioTestCase):
|
||||
@with_test_timeout(100)
|
||||
async def test_mixclient_emitter(self):
|
||||
mixnet, topology = self.init()
|
||||
real_packet_queue: PacketQueue = queue.Queue()
|
||||
outbound_socket: PacketQueue = queue.Queue()
|
||||
real_packet_queue: PacketQueue = asyncio.Queue()
|
||||
outbound_socket: PacketQueue = asyncio.Queue()
|
||||
|
||||
emission_rate_per_min = 30
|
||||
redundancy = 3
|
||||
client = MixClientRunner(
|
||||
mixnet,
|
||||
topology,
|
||||
emission_rate_per_min,
|
||||
redundancy,
|
||||
real_packet_queue,
|
||||
outbound_socket,
|
||||
_ = asyncio.create_task(
|
||||
mixclient_emitter(
|
||||
mixnet,
|
||||
topology,
|
||||
emission_rate_per_min,
|
||||
redundancy,
|
||||
real_packet_queue,
|
||||
outbound_socket,
|
||||
)
|
||||
)
|
||||
client.daemon = True
|
||||
client.start()
|
||||
|
||||
# Create packets. At least two packets are expected to be generated from a 3500-byte msg
|
||||
builder = PacketBuilder.real(random_bytes(3500), mixnet, topology)
|
||||
# Schedule two packets to the mix client without any interval
|
||||
packet, route = builder.next()
|
||||
real_packet_queue.put((route[0].addr, packet))
|
||||
await real_packet_queue.put((route[0].addr, packet))
|
||||
packet, route = builder.next()
|
||||
real_packet_queue.put((route[0].addr, packet))
|
||||
await real_packet_queue.put((route[0].addr, packet))
|
||||
|
||||
# Calculate intervals between packet emissions from the mix client
|
||||
intervals = []
|
||||
ts = datetime.now()
|
||||
for _ in range(30):
|
||||
_ = outbound_socket.get()
|
||||
_ = await outbound_socket.get()
|
||||
now = datetime.now()
|
||||
intervals.append((now - ts).total_seconds())
|
||||
ts = now
|
||||
|
|
|
@ -1,12 +1,9 @@
|
|||
import queue
|
||||
import threading
|
||||
import time
|
||||
import asyncio
|
||||
from datetime import datetime
|
||||
from typing import Tuple
|
||||
from unittest import TestCase
|
||||
from unittest import IsolatedAsyncioTestCase
|
||||
|
||||
import numpy
|
||||
import timeout_decorator
|
||||
from cryptography.hazmat.primitives.asymmetric.x25519 import X25519PrivateKey
|
||||
from pysphinx.sphinx import SphinxPacket
|
||||
|
||||
|
@ -15,12 +12,13 @@ from mixnet.mixnet import Mixnet, MixnetTopology
|
|||
from mixnet.node import MixNode, NodeAddress, PacketPayloadQueue, PacketQueue
|
||||
from mixnet.packet import PacketBuilder
|
||||
from mixnet.poisson import poisson_interval_sec, poisson_mean_interval_sec
|
||||
from mixnet.test_utils import with_test_timeout
|
||||
from mixnet.utils import random_bytes
|
||||
|
||||
|
||||
class TestMixNodeRunner(TestCase):
|
||||
@timeout_decorator.timeout(180)
|
||||
def test_mixnode_runner_emission_rate(self):
|
||||
class TestMixNodeRunner(IsolatedAsyncioTestCase):
|
||||
@with_test_timeout(180)
|
||||
async def test_mixnode_runner_emission_rate(self):
|
||||
"""
|
||||
Test if MixNodeRunner works as a M/M/inf queue.
|
||||
|
||||
|
@ -29,41 +27,49 @@ class TestMixNodeRunner(TestCase):
|
|||
the rate of outputs should be `lambda`.
|
||||
"""
|
||||
mixnet, topology = self.init()
|
||||
inbound_socket: PacketQueue = queue.Queue()
|
||||
outbound_socket: PacketPayloadQueue = queue.Queue()
|
||||
inbound_socket: PacketQueue = asyncio.Queue()
|
||||
outbound_socket: PacketPayloadQueue = asyncio.Queue()
|
||||
|
||||
packet, route = PacketBuilder.real(b"msg", mixnet, topology).next()
|
||||
|
||||
delay_rate_per_min = 30 # mu (= 2s delay on average)
|
||||
# Start only the first mix node for testing
|
||||
runner = route[0].start(delay_rate_per_min, inbound_socket, outbound_socket)
|
||||
_ = route[0].start(delay_rate_per_min, inbound_socket, outbound_socket)
|
||||
|
||||
# Send packets to the first mix node in a Poisson distribution
|
||||
packet_count = 100
|
||||
emission_rate_per_min = 120 # lambda (= 2msg/sec)
|
||||
sender = threading.Thread(
|
||||
target=self.send_packets,
|
||||
args=(
|
||||
# This queue is just for counting how many packets have been sent so far.
|
||||
sent_packet_queue: PacketQueue = asyncio.Queue()
|
||||
_ = asyncio.create_task(
|
||||
self.send_packets(
|
||||
inbound_socket,
|
||||
packet,
|
||||
route[0].addr,
|
||||
packet_count,
|
||||
emission_rate_per_min,
|
||||
),
|
||||
sent_packet_queue,
|
||||
)
|
||||
)
|
||||
sender.daemon = True
|
||||
sender.start()
|
||||
|
||||
# Calculate intervals between outputs and gather num_jobs in the first mix node.
|
||||
intervals = []
|
||||
num_jobs = []
|
||||
ts = datetime.now()
|
||||
for _ in range(packet_count):
|
||||
_ = outbound_socket.get()
|
||||
_ = await outbound_socket.get()
|
||||
now = datetime.now()
|
||||
intervals.append((now - ts).total_seconds())
|
||||
num_jobs.append(runner.num_jobs())
|
||||
|
||||
# Calculate the current # of jobs staying in the mix node
|
||||
num_packets_emitted_from_mixnode = len(intervals)
|
||||
num_packets_sent_to_mixnode = sent_packet_queue.qsize()
|
||||
num_jobs.append(
|
||||
num_packets_sent_to_mixnode - num_packets_emitted_from_mixnode
|
||||
)
|
||||
|
||||
ts = now
|
||||
|
||||
# Remove the first interval that would be much larger than other intervals,
|
||||
# because of the delay in mix node.
|
||||
intervals = intervals[1:]
|
||||
|
@ -87,16 +93,20 @@ class TestMixNodeRunner(TestCase):
|
|||
)
|
||||
|
||||
@staticmethod
|
||||
def send_packets(
|
||||
async def send_packets(
|
||||
inbound_socket: PacketQueue,
|
||||
packet: SphinxPacket,
|
||||
node_addr: NodeAddress,
|
||||
cnt: int,
|
||||
rate_per_min: int,
|
||||
# For testing purpose, to inform the caller how many packets have been sent to the inbound_socket
|
||||
sent_packet_queue: PacketQueue,
|
||||
):
|
||||
for _ in range(cnt):
|
||||
time.sleep(poisson_interval_sec(rate_per_min))
|
||||
inbound_socket.put((node_addr, packet))
|
||||
# Since the task is not heavy, just sleep for seconds instead of using emission_notifier
|
||||
await asyncio.sleep(poisson_interval_sec(rate_per_min))
|
||||
await inbound_socket.put((node_addr, packet))
|
||||
await sent_packet_queue.put((node_addr, packet))
|
||||
|
||||
@staticmethod
|
||||
def init() -> Tuple[Mixnet, MixnetTopology]:
|
||||
|
|
|
@ -0,0 +1,12 @@
|
|||
import asyncio
|
||||
|
||||
|
||||
def with_test_timeout(t):
|
||||
def wrapper(coroutine):
|
||||
async def run(*args, **kwargs):
|
||||
async with asyncio.timeout(t):
|
||||
return await coroutine(*args, **kwargs)
|
||||
|
||||
return run
|
||||
|
||||
return wrapper
|
|
@ -5,6 +5,3 @@ numpy==1.26.3
|
|||
pycparser==2.21
|
||||
pysphinx==0.0.1
|
||||
scipy==1.11.4
|
||||
setuptools==69.0.3
|
||||
timeout-decorator==0.5.0
|
||||
wheel==0.42.0
|
||||
|
|
Loading…
Reference in New Issue