nomos-specs/mixnet/node.py

175 lines
5.2 KiB
Python
Raw Normal View History

2024-01-23 01:29:14 +00:00
from __future__ import annotations
import queue
import threading
import time
from dataclasses import dataclass
from threading import Thread
from typing import Tuple, TypeAlias
from cryptography.hazmat.primitives.asymmetric.x25519 import (
X25519PrivateKey,
X25519PublicKey,
)
from pysphinx.node import Node
from pysphinx.sphinx import (
Payload,
ProcessedFinalHopPacket,
ProcessedForwardHopPacket,
SphinxPacket,
UnknownHeaderTypeError,
)
from mixnet.bls import BlsPrivateKey, BlsPublicKey
from mixnet.poisson import poisson_interval_sec
# A mix node is identified by its BLS public key.
NodeId: TypeAlias = BlsPublicKey

# 32 bytes encoding the IP address and port of a mix node.
NodeAddress: TypeAlias = bytes

# Queues of (destination address, item) pairs. The element types are written
# as string forward references.
PacketQueue: TypeAlias = (
    "queue.Queue[Tuple[NodeAddress, SphinxPacket]]"
)
# Like PacketQueue, but the item may also be the final-hop Payload.
PacketPayloadQueue: TypeAlias = "queue.Queue[Tuple[NodeAddress, SphinxPacket | Payload]]"
@dataclass
class MixNode:
    """Key material and network address of a single mix node.

    ``start`` launches a daemon ``MixNodeRunner`` thread that reads packets
    from ``inbound_socket`` and writes processed results to ``outbound_socket``.
    """

    identity_private_key: BlsPrivateKey
    encryption_private_key: X25519PrivateKey
    addr: NodeAddress

    def identity_public_key(self) -> BlsPublicKey:
        """Return the identity public key, obtained via ``get_g1()`` on the private key."""
        secret = self.identity_private_key
        return secret.get_g1()

    def encryption_public_key(self) -> X25519PublicKey:
        """Return the X25519 public key paired with ``encryption_private_key``."""
        secret = self.encryption_private_key
        return secret.public_key()

    def sphinx_node(self) -> Node:
        """Return a pysphinx ``Node`` built from the encryption key and address."""
        return Node(self.encryption_private_key, self.addr)

    def start(
        self,
        delay_rate_per_min: int,
        inbound_socket: PacketQueue,
        outbound_socket: PacketPayloadQueue,
    ) -> MixNodeRunner:
        """Spawn and start a daemon runner thread for this node, and return it."""
        runner = MixNodeRunner(
            encryption_private_key=self.encryption_private_key,
            delay_rate_per_min=delay_rate_per_min,
            inbound_socket=inbound_socket,
            outbound_socket=outbound_socket,
        )
        # Daemon: the runner loops forever and must not block interpreter exit.
        runner.daemon = True
        runner.start()
        return runner
class MixNodeRunner(Thread):
    """
    Read SphinxPackets from inbound socket and spawn a thread for each packet to process it.
    This thread approximates a M/M/inf queue.

    Every dequeued packet gets its own MixNodePacketProcessor thread, so packets
    are delayed independently of each other. ``num_processing`` is shared with
    those workers: this thread increments it, and each worker decrements it when
    it finishes.
    """

    def __init__(
        self,
        encryption_private_key: X25519PrivateKey,
        delay_rate_per_min: int,  # Poisson rate parameter: mu
        inbound_socket: PacketQueue,
        outbound_socket: PacketPayloadQueue,
    ):
        super().__init__()
        self.encryption_private_key = encryption_private_key
        self.delay_rate_per_min = delay_rate_per_min
        self.inbound_socket = inbound_socket
        self.outbound_socket = outbound_socket
        # Packets taken off inbound_socket whose worker has not yet finished.
        self.num_processing = AtomicInt(0)

    def run(self) -> None:
        # Here in Python, this thread is implemented in synchronous manner.
        # In the real implementation, consider implementing this in asynchronous if possible,
        # to approximate a M/M/inf queue
        while True:
            _, packet = self.inbound_socket.get()
            # Count the packet right after it leaves the inbound queue. Between
            # get() and add() the packet is visible in neither qsize() nor
            # num_processing, so num_jobs() can briefly under-report. Keeping
            # the two calls adjacent keeps that window as short as possible.
            self.num_processing.add(1)
            try:
                thread = MixNodePacketProcessor(
                    packet,
                    self.encryption_private_key,
                    self.delay_rate_per_min,
                    self.outbound_socket,
                    self.num_processing,
                )
                thread.daemon = True
                thread.start()
            except Exception:
                # No worker is running, so nobody will call sub(1) for this
                # packet. Undo the increment so num_jobs() does not count it forever.
                self.num_processing.sub(1)
                raise

    def num_jobs(self) -> int:
        """
        Return the number of packets that are being processed or still in the inbound socket.
        If this thread works as a M/M/inf queue completely,
        the number of packets that are still in the inbound socket must be always 0.

        The two counts are read separately, not under a common lock, so the
        result is a best-effort snapshot. A packet that run() has just dequeued
        but not yet counted is briefly missing from it.
        """
        return self.num_processing.get() + self.inbound_socket.qsize()
class MixNodePacketProcessor(Thread):
"""
Process a single packet with a delay that follows exponential distribution,
and forward it to the next mix node or the mix destination
This thread is a single server (worker) in a M/M/inf queue that MixNodeRunner approximates.
"""
def __init__(
self,
packet: SphinxPacket,
encryption_private_key: X25519PrivateKey,
delay_rate_per_min: int, # Poisson rate parameter: mu
outbound_socket: PacketPayloadQueue,
num_processing: AtomicInt,
):
super().__init__()
self.packet = packet
self.encryption_private_key = encryption_private_key
self.delay_rate_per_min = delay_rate_per_min
self.outbound_socket = outbound_socket
self.num_processing = num_processing
def run(self) -> None:
delay_sec = poisson_interval_sec(self.delay_rate_per_min)
time.sleep(delay_sec)
processed = self.packet.process(self.encryption_private_key)
match processed:
case ProcessedForwardHopPacket():
self.outbound_socket.put(
(processed.next_node_address, processed.next_packet)
)
case ProcessedFinalHopPacket():
self.outbound_socket.put(
(processed.destination_node_address, processed.payload)
)
case _:
raise UnknownHeaderTypeError
self.num_processing.sub(1)
class AtomicInt:
    """Integer counter whose updates and reads are serialized by a lock.

    MixNodeRunner increments one of these per dequeued packet, and each
    MixNodePacketProcessor decrements it when done.
    """

    def __init__(self, initial: int) -> None:
        self.value = initial
        self.lock = threading.Lock()

    def add(self, v: int):
        """Increase the value by ``v`` while holding the lock."""
        with self.lock:
            self.value = self.value + v

    def sub(self, v: int):
        """Decrease the value by ``v`` while holding the lock."""
        with self.lock:
            self.value = self.value - v

    def get(self) -> int:
        """Return the value as read while holding the lock."""
        with self.lock:
            snapshot = self.value
        return snapshot