Mixnet: Initial simulation (#6)

This commit is contained in:
Youngjoon Lee 2024-08-01 11:07:52 +09:00 committed by GitHub
parent 537f86f53f
commit 39eabe1537
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
36 changed files with 2417 additions and 0 deletions

30
.github/workflows/ci.yaml vendored Normal file
View File

@ -0,0 +1,30 @@
name: CI
on:
pull_request:
branches:
- "*"
push:
branches: [master]
jobs:
mixnet:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
with:
submodules: true
- name: Set up Python 3.x
uses: actions/setup-python@v5
with:
python-version: "3.x"
- name: Install dependencies for mixnet
working-directory: mixnet
run: pip install -r requirements.txt
- name: Run unit tests
working-directory: mixnet
run: python -m unittest -v
- name: Run a short mixnet simulation
working-directory: mixnet
run: python -m cmd.main --config config.ci.yaml

2
mixnet/.gitignore vendored Normal file
View File

@ -0,0 +1,2 @@
.venv/
*.csv

140
mixnet/README.md Normal file
View File

@ -0,0 +1,140 @@
# NomMix Simulation
* [Project Structure](#project-structure)
* [Features](#features)
* [Future Plans](#future-plans)
* [Installation](#installation)
* [Getting Started](#getting-started)
## Project Structure
- `cmd`: CLIs to run the simulation and analyze the results.
- `sim`: Simulation that runs the NomMix defined in the `protocol` package.
- `protocol`: Core NomMix protocol implementation, which is going to be moved to the [nomos-repos](https://github.com/logos-co/nomos-specs) repository once verified by simulations.
- `framework`: Asynchronous framework that provides essential async functions for simulations and tests, implemented with various async libraries ([asyncio](https://docs.python.org/3/library/asyncio.html), [μSim](https://usim.readthedocs.io/en/latest/), etc.)
## Features
- NomMix protocol simulation
- Performance measurements
- Bandwidth usages
- Message dissemination time
- Privacy property analysis
- Message sizes
- Node states and hamming distances
## Future Plans
- More NomMix features
- Temporal mixing
- Level-1 noise
- Adversary simulation to measure the robustness of NomMix
## Installation
Clone the repository and install the dependencies:
```bash
git clone https://github.com/logos-co/nomos-simulations.git
cd nomos-simulations/mixnet
python3 -m venv .venv
source .venv/bin/activate
pip install -r requirements.txt
```
## Getting Started
Copy the [`config.ci.yaml`](./config.ci.yaml) file and adjust the parameters to your needs.
Each parameter is explained in the config file.
For more details, please refer to the [documentation](https://www.notion.so/NomMix-Sim-Getting-Started-ee0e2191f4e7437e93976aff2627d7ce?pvs=4).
Run the simulation with the following command:
```bash
python -m cmd.main --config {config_path}
```
All results are printed in the console as below.
And, all plots are shown once all analysis is done.
```
Spawning node-0 with 3 conns
Spawning node-1 with 3 conns
Spawning node-2 with 3 conns
Spawning node-3 with 3 conns
Spawning node-4 with 3 conns
Spawning node-5 with 3 conns
==========================================
Message Dissemination Time
==========================================
[Mix Propagation Times]
count 7.000000
mean 1.122000
std 0.106276
min 1.009000
25% 1.024500
50% 1.157000
75% 1.174500
max 1.290000
dtype: float64
[Broadcast Dissemination Times]
count 7.000000
mean 0.118429
std 0.004353
min 0.111000
25% 0.116000
50% 0.120000
75% 0.121500
max 0.123000
dtype: float64
==========================================
Message Size Distribution
==========================================
msg_size count
0 1405 179982
==========================================
Node States of All Nodes over Time
SENDING:-1, IDLE:0, RECEIVING:1
==========================================
Node-0 Node-1 Node-2 Node-3 Node-4 Node-5
0 0 0 0 0 0 0
1 0 0 0 0 0 0
2 0 0 0 0 0 0
3 0 0 0 0 0 0
4 0 0 0 0 0 0
... ... ... ... ... ... ...
999995 0 0 0 0 0 0
999996 0 0 0 0 0 1
999997 0 0 0 0 0 0
999998 0 0 0 1 0 0
999999 0 0 0 0 0 0
[1000000 rows x 6 columns]
Saved DataFrame to all_node_states_2024-07-23T09:10:59.csv
State Counts per Node:
Node-0 Node-1 Node-2 Node-3 Node-4 Node-5
0 960004 960004 960004 960004 960004 960004
1 29997 29997 29997 29997 29997 29997
-1 9999 9999 9999 9999 9999 9999
Simulation complete!
```
Please note that the result of node state analysis is saved as a CSV file, as printed in the console.
If you run the simulation again with the different parameters and want to
compare the results of two simulations,
you can calculate the hamming distance between them:
```bash
python -m cmd.hamming \
all_node_states_2024-07-15T18:20:23.csv \
all_node_states_2024-07-15T19:32:45.csv
```
The output is a floating point number between 0 and 1.
If the output is 0, the results of two simulations are identical.
The closer the result is to 1, the more the two results differ from each other.
```
Hamming distance: 0.29997
```

0
mixnet/cmd/__init__.py Normal file
View File

42
mixnet/cmd/hamming.py Normal file
View File

@ -0,0 +1,42 @@
import sys
import pandas as pd
def calculate_hamming_distance(df1, df2):
"""
Caculate the hamming distance between two DataFrames
to quantify the difference between them.
"""
if df1.shape != df2.shape:
raise ValueError(
"DataFrames must have the same shape to calculate Hamming distance."
)
# Compare element-wise and count differences
differences = (df1 != df2).sum().sum()
return differences / df1.size # normalize the distance
def main():
if len(sys.argv) != 3:
print("Usage: python hamming.py <csv_path1> <csv_path2>")
sys.exit(1)
csv_path1 = sys.argv[1]
csv_path2 = sys.argv[2]
# Load the CSV files into DataFrames
df1 = pd.read_csv(csv_path1)
df2 = pd.read_csv(csv_path2)
# Calculate the Hamming distance
try:
hamming_distance = calculate_hamming_distance(df1, df2)
print(f"Hamming distance: {hamming_distance}")
except ValueError as e:
print(f"Error: {e}")
if __name__ == "__main__":
main()

25
mixnet/cmd/main.py Normal file
View File

@ -0,0 +1,25 @@
import argparse
import usim
from sim.config import Config
from sim.simulation import Simulation
if __name__ == "__main__":
"""
Read a config file and run a simulation
"""
parser = argparse.ArgumentParser(
description="Run mixnet simulation",
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
parser.add_argument(
"--config", type=str, required=True, help="Configuration file path"
)
args = parser.parse_args()
config = Config.load(args.config)
sim = Simulation(config)
usim.run(sim.run())
print("Simulation complete!")

53
mixnet/config.ci.yaml Normal file
View File

@ -0,0 +1,53 @@
simulation:
# Desired duration of the simulation in seconds
# Since the simulation uses discrete time steps, the actual duration may be longer or shorter.
duration_sec: 1000
# Show all plots that have been drawn during the simulation
show_plots: false
network:
# Total number of nodes in the entire network.
num_nodes: 6
latency:
# Minimum/maximum network latency between nodes in seconds.
# A constant latency will be chosen randomly for each connection within the range [min_latency_sec, max_latency_sec].
min_latency_sec: 0
max_latency_sec: 0.1
# Seed for the random number generator used to determine the network latencies.
seed: 0
gossip:
# Expected number of peers each node must connect to if there are enough peers available in the network.
peering_degree: 3
topology:
# Seed for the random number generator used to determine the network topology.
seed: 1
mix:
# Global constant transmission rate of each connection in messages per second.
transmission_rate_per_sec: 10
# Maximum size of a message in bytes that can be encapsulated in a single Sphinx packet.
max_message_size: 1007
mix_path:
# Minimum number of mix nodes to be chosen for a Sphinx packet.
min_length: 5
# Maximum number of mix nodes to be chosen for a Sphinx packet.
max_length: 5
# Seed for the random number generator used to determine the mix path.
seed: 3
temporal_mix:
# none | pure-coin-flipping | pure-random-sampling | permuted-coin-flipping
mix_type: "pure-coin-flipping"
# The minimum size of queue to be mixed.
# If the queue size is less than this value, noise messages are added.
min_queue_size: 5
# Generate the seeds used to create the RNG for each queue that will be created.
seed_generator: 100
logic:
sender_lottery:
# Interval between lottery draws in seconds.
interval_sec: 1
# Probability of a node being selected as a sender in each lottery draw.
probability: 0.001
# Seed for the random number generator used to determine the lottery winners.
seed: 10

View File

@ -0,0 +1 @@
from .framework import *

View File

@ -0,0 +1,52 @@
from __future__ import annotations
import asyncio
import time
from typing import Any, Awaitable, Coroutine, Generic, TypeVar
from framework import framework
class Framework(framework.Framework):
"""
An asyncio implementation of the Framework
"""
def __init__(self):
super().__init__()
def queue(self) -> framework.Queue:
return Queue()
async def sleep(self, seconds: float) -> None:
await asyncio.sleep(seconds)
def now(self) -> float:
return time.time()
def spawn(
self, coroutine: Coroutine[Any, Any, framework.RT]
) -> Awaitable[framework.RT]:
return asyncio.create_task(coroutine)
T = TypeVar("T")
class Queue(framework.Queue[T]):
"""
An asyncio implementation of the Queue
"""
def __init__(self):
super().__init__()
self._queue = asyncio.Queue()
async def put(self, data: T) -> None:
await self._queue.put(data)
async def get(self) -> T:
return await self._queue.get()
def empty(self) -> bool:
return self._queue.empty()

View File

@ -0,0 +1,50 @@
from __future__ import annotations
import abc
from typing import Any, Awaitable, Coroutine, Generic, TypeVar
RT = TypeVar("RT")
class Framework(abc.ABC):
"""
An abstract class that provides essential asynchronous functions.
This class can be implemented using any asynchronous framework (e.g., asyncio, usim).
"""
@abc.abstractmethod
def queue(self) -> Queue:
pass
@abc.abstractmethod
async def sleep(self, seconds: float) -> None:
pass
@abc.abstractmethod
def now(self) -> float:
pass
@abc.abstractmethod
def spawn(self, coroutine: Coroutine[Any, Any, RT]) -> Awaitable[RT]:
pass
T = TypeVar("T")
class Queue(abc.ABC, Generic[T]):
"""
An abstract class that provides asynchronous queue operations.
"""
@abc.abstractmethod
async def put(self, data: T) -> None:
pass
@abc.abstractmethod
async def get(self) -> T:
pass
@abc.abstractmethod
def empty(self) -> bool:
pass

58
mixnet/framework/usim.py Normal file
View File

@ -0,0 +1,58 @@
from typing import Any, Awaitable, Coroutine, TypeVar
import usim
from framework import framework
class Framework(framework.Framework):
"""
A usim implementation of the Framework for discrete-time simulation
"""
def __init__(self, scope: usim.Scope) -> None:
super().__init__()
# Scope is used to spawn concurrent simulation activities (coroutines).
# μSim waits until all activities spawned in the scope are done
# or until the timeout specified in the scope is reached.
# Because of the way μSim works, the scope must be created using `async with` syntax
# and be passed to this constructor.
self._scope = scope
def queue(self) -> framework.Queue:
return Queue()
async def sleep(self, seconds: float) -> None:
await (usim.time + seconds)
def now(self) -> float:
# Round to milliseconds to make analysis not too heavy
return int(usim.time.now * 1000) / 1000
def spawn(
self, coroutine: Coroutine[Any, Any, framework.RT]
) -> Awaitable[framework.RT]:
return self._scope.do(coroutine)
T = TypeVar("T")
class Queue(framework.Queue[T]):
"""
A usim implementation of the Queue for discrete-time simulation
"""
def __init__(self):
super().__init__()
self._queue = usim.Queue()
async def put(self, data: T) -> None:
await self._queue.put(data)
async def get(self) -> T:
return await self._queue
def empty(self) -> bool:
return len(self._queue._buffer) == 0

View File

66
mixnet/protocol/config.py Normal file
View File

@ -0,0 +1,66 @@
from __future__ import annotations
import random
from dataclasses import dataclass, field
from typing import List
from pysphinx.node import X25519PublicKey
from pysphinx.sphinx import Node as SphinxNode
from pysphinx.sphinx import X25519PrivateKey
from protocol.gossip import GossipConfig
from protocol.temporalmix import TemporalMixConfig
@dataclass
class GlobalConfig:
"""
Global parameters used across all nodes in the network
"""
membership: MixMembership
transmission_rate_per_sec: int # Global Transmission Rate
max_message_size: int
max_mix_path_length: int
@dataclass
class NodeConfig:
"""
Node-specific parameters
"""
private_key: X25519PrivateKey
mix_path_length: int
gossip: GossipConfig
temporal_mix: TemporalMixConfig
@dataclass
class MixMembership:
"""
A list of public information of nodes in the network.
We assume that this list is known to all nodes in the network.
"""
nodes: List[NodeInfo]
rng: random.Random = field(default_factory=random.Random)
def generate_route(self, length: int) -> list[NodeInfo]:
"""
Choose `length` nodes with replacement as a mix route.
"""
return self.rng.choices(self.nodes, k=length)
@dataclass
class NodeInfo:
"""
Public information of a node to be shared to all nodes in the network
"""
public_key: X25519PublicKey
def sphinx_node(self) -> SphinxNode:
dummy_node_addr = bytes(32)
return SphinxNode(self.public_key, dummy_node_addr)

View File

@ -0,0 +1,88 @@
from __future__ import annotations
import abc
import random
from framework import Framework, Queue
from protocol.temporalmix import PureCoinFlipppingQueue, TemporalMix, TemporalMixConfig
class SimplexConnection(abc.ABC):
"""
An abstract class for a simplex connection that can send and receive data in one direction
"""
@abc.abstractmethod
async def send(self, data: bytes) -> None:
pass
@abc.abstractmethod
async def recv(self) -> bytes:
pass
class LocalSimplexConnection(SimplexConnection):
"""
A simplex connection that doesn't have any network latency.
Data sent through this connection can be immediately received from the other end.
"""
def __init__(self, framework: Framework):
self.queue: Queue[bytes] = framework.queue()
async def send(self, data: bytes) -> None:
await self.queue.put(data)
async def recv(self) -> bytes:
return await self.queue.get()
class DuplexConnection:
"""
A duplex connection in which data can be transmitted and received simultaneously in both directions.
This is to mimic duplex communication in a real network (such as TCP or QUIC).
"""
def __init__(self, inbound: SimplexConnection, outbound: SimplexConnection):
self.inbound = inbound
self.outbound = outbound
async def recv(self) -> bytes:
return await self.inbound.recv()
async def send(self, packet: bytes):
await self.outbound.send(packet)
class MixSimplexConnection(SimplexConnection):
"""
Wraps a SimplexConnection to add a transmission rate and noise to the connection.
"""
def __init__(
self,
framework: Framework,
conn: SimplexConnection,
transmission_rate_per_sec: int,
noise_msg: bytes,
temporal_mix_config: TemporalMixConfig,
):
self.framework = framework
self.queue: Queue[bytes] = TemporalMix.queue(
temporal_mix_config, framework, noise_msg
)
self.conn = conn
self.transmission_rate_per_sec = transmission_rate_per_sec
self.task = framework.spawn(self.__run())
async def __run(self):
while True:
await self.framework.sleep(1 / self.transmission_rate_per_sec)
msg = await self.queue.get()
await self.conn.send(msg)
async def send(self, data: bytes) -> None:
await self.queue.put(data)
async def recv(self) -> bytes:
return await self.conn.recv()

2
mixnet/protocol/error.py Normal file
View File

@ -0,0 +1,2 @@
class PeeringDegreeReached(Exception):
pass

87
mixnet/protocol/gossip.py Normal file
View File

@ -0,0 +1,87 @@
from __future__ import annotations
import hashlib
from dataclasses import dataclass
from enum import Enum
from typing import Awaitable, Callable, Self
from framework import Framework
from protocol.connection import (
DuplexConnection,
MixSimplexConnection,
SimplexConnection,
)
from protocol.error import PeeringDegreeReached
@dataclass
class GossipConfig:
# Expected number of peers each node must connect to if there are enough peers available in the network.
peering_degree: int
class Gossip:
"""
A gossip channel that broadcasts messages to all connected peers.
Peers are connected via DuplexConnection.
"""
def __init__(
self,
framework: Framework,
config: GossipConfig,
handler: Callable[[bytes], Awaitable[None]],
):
self.framework = framework
self.config = config
self.conns: list[DuplexConnection] = []
# A handler to process inbound messages.
self.handler = handler
self.packet_cache: set[bytes] = set()
# A set just for gathering a reference of tasks to prevent them from being garbage collected.
# https://docs.python.org/3/library/asyncio-task.html#asyncio.create_task
self.tasks: set[Awaitable] = set()
def can_accept_conn(self) -> bool:
return len(self.conns) < self.config.peering_degree
def add_conn(self, inbound: SimplexConnection, outbound: SimplexConnection):
if not self.can_accept_conn():
# For simplicity of the spec, reject the connection if the peering degree is reached.
raise PeeringDegreeReached()
conn = DuplexConnection(
inbound,
outbound,
)
self.conns.append(conn)
task = self.framework.spawn(self.__process_inbound_conn(conn))
self.tasks.add(task)
async def __process_inbound_conn(self, conn: DuplexConnection):
while True:
msg = await conn.recv()
if self.__check_update_cache(msg):
continue
await self.process_inbound_msg(msg)
async def process_inbound_msg(self, msg: bytes):
await self.gossip(msg)
await self.handler(msg)
async def gossip(self, msg: bytes):
"""
Gossip a message to all connected peers.
"""
for conn in self.conns:
await conn.send(msg)
def __check_update_cache(self, packet: bytes) -> bool:
"""
Add a message to the cache, and return True if the message was already in the cache.
"""
hash = hashlib.sha256(packet).digest()
if hash in self.packet_cache:
return True
self.packet_cache.add(hash)
return False

150
mixnet/protocol/node.py Normal file
View File

@ -0,0 +1,150 @@
from __future__ import annotations
from typing import Awaitable, Callable
from pysphinx.sphinx import (
ProcessedFinalHopPacket,
ProcessedForwardHopPacket,
SphinxPacket,
)
from framework import Framework, Queue
from protocol.config import GlobalConfig, NodeConfig
from protocol.connection import SimplexConnection
from protocol.error import PeeringDegreeReached
from protocol.gossip import Gossip
from protocol.nomssip import Nomssip, NomssipConfig
from protocol.sphinx import SphinxPacketBuilder
class Node:
"""
This represents any node in the network, which:
- generates/gossips mix messages (Sphinx packets)
- performs cryptographic mix (unwrapping Sphinx packets)
- generates noise
"""
def __init__(
self,
framework: Framework,
config: NodeConfig,
global_config: GlobalConfig,
# A handler called when a node receives a broadcasted message originated from the last mix.
broadcasted_msg_handler: Callable[[bytes], Awaitable[None]],
# An optional handler only for the simulation,
# which is called when a message is fully recovered by the last mix
# and returns a new message to be broadcasted.
recovered_msg_handler: Callable[[bytes], Awaitable[bytes]] | None = None,
):
self.framework = framework
self.config = config
self.global_config = global_config
self.nomssip = Nomssip(
framework,
NomssipConfig(
config.gossip.peering_degree,
global_config.transmission_rate_per_sec,
self.__calculate_message_size(global_config),
config.temporal_mix,
),
self.__process_msg,
)
self.broadcast = Gossip(framework, config.gossip, broadcasted_msg_handler)
self.recovered_msg_handler = recovered_msg_handler
@staticmethod
def __calculate_message_size(global_config: GlobalConfig) -> int:
"""
Calculate the actual message size to be gossiped, which depends on the maximum length of mix path.
"""
sample_sphinx_packet, _ = SphinxPacketBuilder.build(
bytes(global_config.max_message_size),
global_config,
global_config.max_mix_path_length,
)
return len(sample_sphinx_packet.bytes())
async def __process_msg(self, msg: bytes) -> None:
"""
A handler to process messages received via Nomssip channel
"""
sphinx_packet = SphinxPacket.from_bytes(
msg, self.global_config.max_mix_path_length
)
result = await self.__process_sphinx_packet(sphinx_packet)
match result:
case SphinxPacket():
# Gossip the next Sphinx packet
await self.nomssip.gossip(result.bytes())
case bytes():
if self.recovered_msg_handler is not None:
result = await self.recovered_msg_handler(result)
# Broadcast the message fully recovered from Sphinx packets
await self.broadcast.gossip(result)
case None:
return
async def __process_sphinx_packet(
self, packet: SphinxPacket
) -> SphinxPacket | bytes | None:
"""
Unwrap the Sphinx packet and process the next Sphinx packet or the payload if possible
"""
try:
processed = packet.process(self.config.private_key)
match processed:
case ProcessedForwardHopPacket():
return processed.next_packet
case ProcessedFinalHopPacket():
return processed.payload.recover_plain_playload()
except ValueError:
# Return nothing, if it cannot be unwrapped by the private key of this node.
return None
def connect_mix(
self,
peer: Node,
inbound_conn: SimplexConnection,
outbound_conn: SimplexConnection,
):
Node.__connect(self.nomssip, peer.nomssip, inbound_conn, outbound_conn)
def connect_broadcast(
self,
peer: Node,
inbound_conn: SimplexConnection,
outbound_conn: SimplexConnection,
):
Node.__connect(self.broadcast, peer.broadcast, inbound_conn, outbound_conn)
@staticmethod
def __connect(
self_channel: Gossip,
peer_channel: Gossip,
inbound_conn: SimplexConnection,
outbound_conn: SimplexConnection,
):
"""
Establish a duplex connection with a peer node.
"""
if not self_channel.can_accept_conn() or not peer_channel.can_accept_conn():
raise PeeringDegreeReached()
# Register a duplex connection for its own use
self_channel.add_conn(inbound_conn, outbound_conn)
# Register a duplex connection for the peer
peer_channel.add_conn(outbound_conn, inbound_conn)
async def send_message(self, msg: bytes):
"""
Build a Sphinx packet and gossip it to all connected peers.
"""
# Here, we handle the case in which a msg is split into multiple Sphinx packets.
# But, in practice, we expect a message to be small enough to fit in a single Sphinx packet.
sphinx_packet, _ = SphinxPacketBuilder.build(
msg,
self.global_config,
self.config.mix_path_length,
)
await self.nomssip.gossip(sphinx_packet.bytes())

106
mixnet/protocol/nomssip.py Normal file
View File

@ -0,0 +1,106 @@
from __future__ import annotations
import hashlib
import random
from dataclasses import dataclass
from enum import Enum
from typing import Awaitable, Callable, Self, override
from framework import Framework
from protocol.connection import (
DuplexConnection,
MixSimplexConnection,
SimplexConnection,
)
from protocol.error import PeeringDegreeReached
from protocol.gossip import Gossip, GossipConfig
from protocol.temporalmix import TemporalMixConfig
@dataclass
class NomssipConfig(GossipConfig):
transmission_rate_per_sec: int
msg_size: int
temporal_mix: TemporalMixConfig
class Nomssip(Gossip):
"""
A NomMix gossip channel that extends the Gossip channel
by adding global transmission rate and noise generation.
"""
def __init__(
self,
framework: Framework,
config: NomssipConfig,
handler: Callable[[bytes], Awaitable[None]],
):
super().__init__(framework, config, handler)
self.config = config
@override
def add_conn(self, inbound: SimplexConnection, outbound: SimplexConnection):
noise_packet = FlaggedPacket(
FlaggedPacket.Flag.NOISE, bytes(self.config.msg_size)
).bytes()
super().add_conn(
inbound,
MixSimplexConnection(
self.framework,
outbound,
self.config.transmission_rate_per_sec,
noise_packet,
self.config.temporal_mix,
),
)
@override
async def process_inbound_msg(self, msg: bytes):
packet = FlaggedPacket.from_bytes(msg)
match packet.flag:
case FlaggedPacket.Flag.NOISE:
# Drop noise packet
return
case FlaggedPacket.Flag.REAL:
await self.__gossip_flagged_packet(packet)
await self.handler(packet.message)
@override
async def gossip(self, msg: bytes):
"""
Gossip a message to all connected peers with prepending a message flag
"""
# The message size must be fixed.
assert len(msg) == self.config.msg_size
packet = FlaggedPacket(FlaggedPacket.Flag.REAL, msg)
await self.__gossip_flagged_packet(packet)
async def __gossip_flagged_packet(self, packet: FlaggedPacket):
"""
An internal method to send a flagged packet to all connected peers
"""
await super().gossip(packet.bytes())
class FlaggedPacket:
class Flag(Enum):
REAL = b"\x00"
NOISE = b"\x01"
def __init__(self, flag: Flag, message: bytes):
self.flag = flag
self.message = message
def bytes(self) -> bytes:
return self.flag.value + self.message
@classmethod
def from_bytes(cls, packet: bytes) -> Self:
"""
Parse a flagged packet from bytes
"""
if len(packet) < 1:
raise ValueError("Invalid message format")
return cls(cls.Flag(packet[:1]), packet[1:])

33
mixnet/protocol/sphinx.py Normal file
View File

@ -0,0 +1,33 @@
from __future__ import annotations
from typing import List, Tuple
from pysphinx.sphinx import SphinxPacket
from protocol.config import GlobalConfig, NodeInfo
class SphinxPacketBuilder:
@staticmethod
def build(
message: bytes, global_config: GlobalConfig, path_len: int
) -> Tuple[SphinxPacket, List[NodeInfo]]:
if path_len <= 0:
raise ValueError("path_len must be greater than 0")
if len(message) > global_config.max_message_size:
raise ValueError("message is too long")
route = global_config.membership.generate_route(path_len)
# We don't need the destination (defined in the Loopix Sphinx spec)
# because the last mix will broadcast the fully unwrapped message.
# Later, we will optimize the Sphinx according to our requirements.
dummy_destination = route[-1]
packet = SphinxPacket.build(
message,
route=[mixnode.sphinx_node() for mixnode in route],
destination=dummy_destination.sphinx_node(),
max_route_length=global_config.max_mix_path_length,
max_plain_payload_size=global_config.max_message_size,
)
return (packet, route)

View File

@ -0,0 +1,177 @@
import random
from abc import abstractmethod
from dataclasses import dataclass
from enum import Enum
from typing import TypeVar
from framework.framework import Framework, Queue
class TemporalMixType(Enum):
NONE = "none"
PURE_COIN_FLIPPING = "pure-coin-flipping"
PURE_RANDOM_SAMPLING = "pure-random-sampling"
PERMUTED_COIN_FLIPPING = "permuted-coin-flipping"
NOISY_COIN_FLIPPING = "noisy-coin-flipping"
@dataclass
class TemporalMixConfig:
mix_type: TemporalMixType
# The minimum size of queue to be mixed.
# If the queue size is less than this value, noise messages are added.
min_queue_size: int
# Generate the seeds used to create the RNG for each queue that will be created.
seed_generator: random.Random
def __post_init__(self):
assert self.seed_generator is not None
assert self.min_queue_size > 0
T = TypeVar("T")
class TemporalMix:
@staticmethod
def queue(
config: TemporalMixConfig, framework: Framework, noise_msg: T
) -> Queue[T]:
match config.mix_type:
case TemporalMixType.NONE:
return NonMixQueue(framework, noise_msg)
case TemporalMixType.PURE_COIN_FLIPPING:
return PureCoinFlipppingQueue(
config.min_queue_size,
random.Random(config.seed_generator.random()),
noise_msg,
)
case TemporalMixType.PURE_RANDOM_SAMPLING:
return PureRandomSamplingQueue(
config.min_queue_size,
random.Random(config.seed_generator.random()),
noise_msg,
)
case TemporalMixType.PERMUTED_COIN_FLIPPING:
return PermutedCoinFlipppingQueue(
config.min_queue_size,
random.Random(config.seed_generator.random()),
noise_msg,
)
case TemporalMixType.NOISY_COIN_FLIPPING:
return NoisyCoinFlippingQueue(
random.Random(config.seed_generator.random()),
noise_msg,
)
case _:
raise ValueError(f"Unknown mix type: {config.mix_type}")
class NonMixQueue(Queue[T]):
"""
Queue without temporal mixing. Only have the noise generation when the queue is empty.
"""
def __init__(self, framework: Framework, noise_msg: T):
self.__queue = framework.queue()
self.__noise_msg = noise_msg
async def put(self, data: T) -> None:
await self.__queue.put(data)
async def get(self) -> T:
if self.__queue.empty():
return self.__noise_msg
else:
return await self.__queue.get()
def empty(self) -> bool:
return self.__queue.empty()
class MixQueue(Queue[T]):
def __init__(self, rng: random.Random, noise_msg: T):
super().__init__()
# Assuming that simulations run in a single thread
self._queue: list[T] = []
self._rng = rng
self._noise_msg = noise_msg
async def put(self, data: T) -> None:
self._queue.append(data)
@abstractmethod
async def get(self) -> T:
pass
def empty(self) -> bool:
return len(self._queue) == 0
class MinSizeMixQueue(MixQueue[T]):
def __init__(self, min_pool_size: int, rng: random.Random, noise_msg: T):
super().__init__(rng, noise_msg)
self._mix_pool_size = min_pool_size
@abstractmethod
async def get(self) -> T:
while len(self._queue) < self._mix_pool_size:
self._queue.append(self._noise_msg)
# Subclass must implement this method
pass
class PureCoinFlipppingQueue(MinSizeMixQueue[T]):
async def get(self) -> T:
await super().get()
while True:
for i in range(len(self._queue)):
# coin-flipping
if self._rng.randint(0, 1) == 1:
# After removing a message from the position `i`, we don't fill up the position.
# Instead, the queue is always filled from the back.
return self._queue.pop(i)
class PureRandomSamplingQueue(MinSizeMixQueue[T]):
async def get(self) -> T:
await super().get()
i = self._rng.randint(0, len(self._queue) - 1)
# After removing a message from the position `i`, we don't fill up the position.
# Instead, the queue is always filled from the back.
return self._queue.pop(i)
class PermutedCoinFlipppingQueue(MinSizeMixQueue[T]):
async def get(self) -> T:
await super().get()
self._rng.shuffle(self._queue)
while True:
for i in range(len(self._queue)):
# coin-flipping
if self._rng.randint(0, 1) == 1:
# After removing a message from the position `i`, we don't fill up the position.
# Instead, the queue is always filled from the back.
return self._queue.pop(i)
class NoisyCoinFlippingQueue(MixQueue[T]):
async def get(self) -> T:
if len(self._queue) == 0:
return self._noise_msg
while True:
for i in range(len(self._queue)):
# coin-flipping
if self._rng.randint(0, 1) == 1:
# After removing a message from the position `i`, we don't fill up the position.
# Instead, the queue is always filled from the back.
return self._queue.pop(i)
else:
if i == 0:
return self._noise_msg

View File

@ -0,0 +1,58 @@
from unittest import IsolatedAsyncioTestCase
import framework.asyncio as asynciofw
from framework.framework import Queue
from protocol.connection import LocalSimplexConnection
from protocol.node import Node
from protocol.test_utils import (
init_mixnet_config,
)
class TestNode(IsolatedAsyncioTestCase):
async def test_node(self):
framework = asynciofw.Framework()
global_config, node_configs, _ = init_mixnet_config(10)
queue: Queue[bytes] = framework.queue()
async def broadcasted_msg_handler(msg: bytes) -> None:
await queue.put(msg)
nodes = [
Node(framework, node_config, global_config, broadcasted_msg_handler)
for node_config in node_configs
]
for i, node in enumerate(nodes):
try:
node.connect_mix(
nodes[(i + 1) % len(nodes)],
LocalSimplexConnection(framework),
LocalSimplexConnection(framework),
)
node.connect_broadcast(
nodes[(i + 1) % len(nodes)],
LocalSimplexConnection(framework),
LocalSimplexConnection(framework),
)
except ValueError as e:
print(e)
await nodes[0].send_message(b"block selection")
# Wait for all nodes to receive the broadcast
num_nodes_received_broadcast = 0
timeout = 15
for _ in range(timeout):
await framework.sleep(1)
while not queue.empty():
self.assertEqual(b"block selection", await queue.get())
num_nodes_received_broadcast += 1
if num_nodes_received_broadcast == len(nodes):
break
self.assertEqual(len(nodes), num_nodes_received_broadcast)
# TODO: check noise

View File

@ -0,0 +1,66 @@
from random import randint
from typing import cast
from unittest import TestCase
from pysphinx.sphinx import (
ProcessedFinalHopPacket,
ProcessedForwardHopPacket,
)
from protocol.sphinx import SphinxPacketBuilder
from protocol.test_utils import init_mixnet_config
class TestSphinxPacketBuilder(TestCase):
def test_builder(self):
global_config, _, key_map = init_mixnet_config(10)
msg = self.random_bytes(500)
packet, route = SphinxPacketBuilder.build(msg, global_config, 3)
self.assertEqual(3, len(route))
processed = packet.process(key_map[route[0].public_key.public_bytes_raw()])
self.assertIsInstance(processed, ProcessedForwardHopPacket)
processed = cast(ProcessedForwardHopPacket, processed).next_packet.process(
key_map[route[1].public_key.public_bytes_raw()]
)
self.assertIsInstance(processed, ProcessedForwardHopPacket)
processed = cast(ProcessedForwardHopPacket, processed).next_packet.process(
key_map[route[2].public_key.public_bytes_raw()]
)
self.assertIsInstance(processed, ProcessedFinalHopPacket)
recovered = cast(
ProcessedFinalHopPacket, processed
).payload.recover_plain_playload()
self.assertEqual(msg, recovered)
def test_max_message_size(self):
global_config, _, _ = init_mixnet_config(10, max_message_size=2000)
mix_path_length = global_config.max_mix_path_length
packet1, _ = SphinxPacketBuilder.build(
self.random_bytes(1500), global_config, mix_path_length
)
packet2, _ = SphinxPacketBuilder.build(
self.random_bytes(2000), global_config, mix_path_length
)
self.assertEqual(len(packet1.bytes()), len(packet2.bytes()))
msg = self.random_bytes(2001)
with self.assertRaises(ValueError):
_ = SphinxPacketBuilder.build(msg, global_config, mix_path_length)
def test_max_mix_path_length(self):
global_config, _, _ = init_mixnet_config(10, max_mix_path_length=2)
msg = self.random_bytes(global_config.max_message_size)
packet1, _ = SphinxPacketBuilder.build(msg, global_config, 1)
packet2, _ = SphinxPacketBuilder.build(msg, global_config, 2)
self.assertEqual(len(packet1.bytes()), len(packet2.bytes()))
with self.assertRaises(ValueError):
_ = SphinxPacketBuilder.build(msg, global_config, 3)
@staticmethod
def random_bytes(size: int) -> bytes:
assert size >= 0
return bytes([randint(0, 255) for _ in range(size)])

View File

@ -0,0 +1,103 @@
import random
from unittest import IsolatedAsyncioTestCase
import framework.asyncio as asynciofw
from framework.framework import Queue
from protocol.temporalmix import (
NoisyCoinFlippingQueue,
NonMixQueue,
PermutedCoinFlipppingQueue,
PureCoinFlipppingQueue,
PureRandomSamplingQueue,
TemporalMix,
TemporalMixConfig,
TemporalMixType,
)
class TestTemporalMix(IsolatedAsyncioTestCase):
async def test_queue_builder(self):
# Check if the queue builder generates the correct queue type
for mix_type in TemporalMixType:
await self.__test_queue_builder(mix_type)
async def __test_queue_builder(self, mix_type: TemporalMixType):
queue: Queue[int] = TemporalMix.queue(
TemporalMixConfig(mix_type, 4, random.Random(0)),
asynciofw.Framework(),
-1,
)
match mix_type:
case TemporalMixType.NONE:
self.assertIsInstance(queue, NonMixQueue)
case TemporalMixType.PURE_COIN_FLIPPING:
self.assertIsInstance(queue, PureCoinFlipppingQueue)
case TemporalMixType.PURE_RANDOM_SAMPLING:
self.assertIsInstance(queue, PureRandomSamplingQueue)
case TemporalMixType.PERMUTED_COIN_FLIPPING:
self.assertIsInstance(queue, PermutedCoinFlipppingQueue)
case TemporalMixType.NOISY_COIN_FLIPPING:
self.assertIsInstance(queue, NoisyCoinFlippingQueue)
case _:
self.fail(f"Unknown mix type: {mix_type}")
async def test_non_mix_queue(self):
queue: Queue[int] = TemporalMix.queue(
TemporalMixConfig(TemporalMixType.NONE, 4, random.Random(0)),
asynciofw.Framework(),
-1,
)
# Check if queue is FIFO
await queue.put(0)
await queue.put(1)
self.assertEqual(0, await queue.get())
self.assertEqual(1, await queue.get())
# Check if noise is generated when queue is empty
self.assertEqual(-1, await queue.get())
# FIFO again
await queue.put(2)
self.assertEqual(2, await queue.get())
await queue.put(3)
self.assertEqual(3, await queue.get())
async def test_pure_coin_flipping_queue(self):
await self.__test_mix_queue(TemporalMixType.PURE_COIN_FLIPPING)
async def test_pure_random_sampling(self):
await self.__test_mix_queue(TemporalMixType.PURE_RANDOM_SAMPLING)
async def test_permuted_coin_flipping_queue(self):
await self.__test_mix_queue(TemporalMixType.PERMUTED_COIN_FLIPPING)
async def test_noisy_coin_flipping_queue(self):
await self.__test_mix_queue(TemporalMixType.NOISY_COIN_FLIPPING)
async def __test_mix_queue(self, mix_type: TemporalMixType):
queue: Queue[int] = TemporalMix.queue(
TemporalMixConfig(mix_type, 4, random.Random(0)),
asynciofw.Framework(),
-1,
)
# Check if noise is generated when queue is empty
self.assertEqual(-1, await queue.get())
# Put only 2 elements even though the min queue size is 4
await queue.put(0)
await queue.put(1)
# Wait until 2 elements are returned from the queue
waiting = {0, 1}
while len(waiting) > 0:
e = await queue.get()
if e in waiting:
waiting.remove(e)
else:
# Check if it's the noise
self.assertEqual(-1, e)
# Check if noise is generated when there is no real message inserted
self.assertEqual(-1, await queue.get())

View File

@ -0,0 +1,46 @@
import random
from cryptography.hazmat.primitives.asymmetric.x25519 import X25519PrivateKey
from protocol.config import (
GlobalConfig,
MixMembership,
NodeConfig,
NodeInfo,
)
from protocol.gossip import GossipConfig
from protocol.nomssip import TemporalMixConfig
from protocol.temporalmix import TemporalMixType
def init_mixnet_config(
num_nodes: int,
max_message_size: int = 512,
max_mix_path_length: int = 3,
) -> tuple[GlobalConfig, list[NodeConfig], dict[bytes, X25519PrivateKey]]:
gossip_config = GossipConfig(peering_degree=6)
node_configs = [
NodeConfig(
X25519PrivateKey.generate(),
max_mix_path_length,
gossip_config,
TemporalMixConfig(TemporalMixType.PURE_COIN_FLIPPING, 3, random.Random()),
)
for _ in range(num_nodes)
]
global_config = GlobalConfig(
MixMembership(
[
NodeInfo(node_config.private_key.public_key())
for node_config in node_configs
]
),
transmission_rate_per_sec=3,
max_message_size=max_message_size,
max_mix_path_length=max_mix_path_length,
)
key_map = {
node_config.private_key.public_key().public_bytes_raw(): node_config.private_key
for node_config in node_configs
}
return (global_config, node_configs, key_map)

6
mixnet/requirements.txt Normal file
View File

@ -0,0 +1,6 @@
usim==0.4.4
pysphinx==0.0.5
dacite==1.8.1
pandas==2.2.2
matplotlib==3.9.1
PyYAML==6.0.1

0
mixnet/sim/__init__.py Normal file
View File

163
mixnet/sim/config.py Normal file
View File

@ -0,0 +1,163 @@
from __future__ import annotations
import hashlib
import random
from dataclasses import dataclass
import dacite
import yaml
from pysphinx.sphinx import X25519PrivateKey
from protocol.config import NodeConfig
from protocol.gossip import GossipConfig
from protocol.temporalmix import TemporalMixConfig, TemporalMixType
@dataclass
class Config:
simulation: SimulationConfig
network: NetworkConfig
logic: LogicConfig
mix: MixConfig
@classmethod
def load(cls, yaml_path: str) -> Config:
with open(yaml_path, "r") as f:
data = yaml.safe_load(f)
return dacite.from_dict(
data_class=Config,
data=data,
config=dacite.Config(
type_hooks={
random.Random: seed_to_random,
TemporalMixType: str_to_temporal_mix_type,
},
strict=True,
),
)
def node_configs(self) -> list[NodeConfig]:
return [
NodeConfig(
self.__gen_private_key(i),
self.mix.mix_path.random_length(),
self.network.gossip,
self.mix.temporal_mix,
)
for i in range(self.network.num_nodes)
]
def __gen_private_key(self, node_idx: int) -> X25519PrivateKey:
return X25519PrivateKey.from_private_bytes(
hashlib.sha256(node_idx.to_bytes(4, "big")).digest()[:32]
)
@dataclass
class SimulationConfig:
# Desired duration of the simulation in seconds
# Since the simulation uses discrete time steps, the actual duration may be longer or shorter.
duration_sec: int
# Show all plots that have been drawn during the simulation
show_plots: bool
def __post_init__(self):
assert self.duration_sec > 0
@dataclass
class NetworkConfig:
# Total number of nodes in the entire network.
num_nodes: int
latency: LatencyConfig
gossip: GossipConfig
topology: TopologyConfig
def __post_init__(self):
assert self.num_nodes > 0
@dataclass
class LatencyConfig:
# Minimum/maximum network latency between nodes in seconds.
# A constant latency will be chosen randomly for each connection within the range [min_latency_sec, max_latency_sec].
min_latency_sec: float
max_latency_sec: float
# Seed for the random number generator used to determine the network latencies.
seed: random.Random
def __post_init__(self):
assert 0 <= self.min_latency_sec <= self.max_latency_sec
assert self.seed is not None
def random_latency(self) -> float:
# round to milliseconds to make analysis not too heavy
return round(self.seed.uniform(self.min_latency_sec, self.max_latency_sec), 3)
@dataclass
class TopologyConfig:
# Seed for the random number generator used to determine the network topology.
seed: random.Random
def __post_init__(self):
assert self.seed is not None
@dataclass
class MixConfig:
# Global constant transmission rate of each connection in messages per second.
transmission_rate_per_sec: int
# Maximum size of a message in bytes that can be encapsulated in a single Sphinx packet.
max_message_size: int
mix_path: MixPathConfig
temporal_mix: TemporalMixConfig
def __post_init__(self):
assert self.transmission_rate_per_sec > 0
assert self.max_message_size > 0
@dataclass
class MixPathConfig:
# Minimum number of mix nodes to be chosen for a Sphinx packet.
min_length: int
# Maximum number of mix nodes to be chosen for a Sphinx packet.
max_length: int
# Seed for the random number generator used to determine the mix path.
seed: random.Random
def __post_init__(self):
assert 0 < self.min_length <= self.max_length
assert self.seed is not None
def random_length(self) -> int:
return self.seed.randint(self.min_length, self.max_length)
@dataclass
class LogicConfig:
sender_lottery: LotteryConfig
@dataclass
class LotteryConfig:
# Interval between lottery draws in seconds.
interval_sec: float
# Probability of a node being selected as a sender in each lottery draw.
probability: float
# Seed for the random number generator used to determine the lottery winners.
seed: random.Random
def __post_init__(self):
assert self.interval_sec > 0
assert self.probability >= 0
assert self.seed is not None
def seed_to_random(seed: int) -> random.Random:
return random.Random(seed)
def str_to_temporal_mix_type(val: str) -> TemporalMixType:
return TemporalMixType(val)

139
mixnet/sim/connection.py Normal file
View File

@ -0,0 +1,139 @@
import math
from collections import Counter
from typing import Awaitable
import pandas
from typing_extensions import override
from framework import Framework, Queue
from protocol.connection import SimplexConnection
from sim.config import LatencyConfig, NetworkConfig
from sim.state import NodeState
class MeteredRemoteSimplexConnection(SimplexConnection):
"""
A simplex connection implementation that simulates network latency and measures bandwidth usages.
"""
def __init__(
self,
config: LatencyConfig,
framework: Framework,
meter_start_time: float,
):
self.framework = framework
# A connection has a random constant latency
self.latency = config.random_latency()
# A queue of tuple(timestamp, msg) where a sender puts messages to be sent
self.send_queue: Queue[tuple[float, bytes]] = framework.queue()
# A task that reads messages from send_queue, and puts them to recv_queue.
# Before putting messages to recv_queue, the task simulates network latency according to the timestamp of each message.
self.relayer = framework.spawn(self.__run_relayer())
# A queue where a receiver gets messages
self.recv_queue: Queue[bytes] = framework.queue()
# To measure bandwidth usages
self.meter_start_time = meter_start_time
self.send_meters: list[int] = []
self.recv_meters: list[int] = []
async def send(self, data: bytes) -> None:
await self.send_queue.put((self.framework.now(), data))
self.on_sending(data)
async def recv(self) -> bytes:
return await self.recv_queue.get()
async def __run_relayer(self):
"""
A task that reads messages from send_queue, and puts them to recv_queue.
Before putting messages to recv_queue, the task simulates network latency according to the timestamp of each message.
"""
while True:
sent_time, data = await self.send_queue.get()
# Simulate network latency
delay = self.latency - (self.framework.now() - sent_time)
if delay > 0:
await self.framework.sleep(delay)
# Relay msg to the recv_queue.
# Update related statistics before msg is read from recv_queue by the receiver
# because the time at which enters the node is important when viewed from the outside.
self.on_receiving(data)
await self.recv_queue.put(data)
def on_sending(self, data: bytes) -> None:
"""
Update statistics when sending a message
"""
self.__update_meter(self.send_meters, len(data))
def on_receiving(self, data: bytes) -> None:
"""
Update statistics when receiving a message
"""
self.__update_meter(self.recv_meters, len(data))
def __update_meter(self, meters: list[int], size: int):
"""
Accumulates the bandwidth usage in the current time slot (seconds).
"""
slot = math.floor(self.framework.now() - self.meter_start_time)
assert slot >= len(meters) - 1
# Fill zeros for the empty time slots
meters.extend([0] * (slot - len(meters) + 1))
meters[-1] += size
def sending_bandwidths(self) -> pandas.Series:
"""
Returns the accumulated sending bandwidth usage over time
"""
return self.__bandwidths(self.send_meters)
def receiving_bandwidths(self) -> pandas.Series:
"""
Returns the accumulated receiving bandwidth usage over time
"""
return self.__bandwidths(self.recv_meters)
def __bandwidths(self, meters: list[int]) -> pandas.Series:
return pandas.Series(meters, name="bandwidth")
class ObservedMeteredRemoteSimplexConnection(MeteredRemoteSimplexConnection):
"""
An extension of MeteredRemoteSimplexConnection that is observed by passive observer.
The observer monitors the node states of the sender and receiver and message sizes.
"""
def __init__(
self,
config: LatencyConfig,
framework: Framework,
meter_start_time: float,
send_node_states: list[NodeState],
recv_node_states: list[NodeState],
):
super().__init__(config, framework, meter_start_time)
# To measure node states over time
self.send_node_states = send_node_states
self.recv_node_states = recv_node_states
# To measure the size of messages sent via this connection
self.msg_sizes: Counter[int] = Counter()
@override
def on_sending(self, data: bytes) -> None:
super().on_sending(data)
self.__update_node_state(self.send_node_states, NodeState.SENDING)
self.msg_sizes.update([len(data)])
@override
def on_receiving(self, data: bytes) -> None:
super().on_receiving(data)
self.__update_node_state(self.recv_node_states, NodeState.RECEIVING)
def __update_node_state(self, node_states: list[NodeState], state: NodeState):
# The time unit of node states is milliseconds
ms = math.floor(self.framework.now() * 1000)
node_states[ms] = state

42
mixnet/sim/message.py Normal file
View File

@ -0,0 +1,42 @@
import pickle
from dataclasses import dataclass
from typing import Self
@dataclass
class Message:
"""
A message structure for simulation, which will be sent through mix nodes
and eventually broadcasted to all nodes in the network.
The `id` must ensure the uniqueness of the message.
"""
created_at: float
id: int
body: bytes
def __bytes__(self):
return pickle.dumps(self)
@classmethod
def from_bytes(cls, data: bytes) -> Self:
return pickle.loads(data)
def __hash__(self) -> int:
return self.id
class UniqueMessageBuilder:
"""
Builds a unique message with an incremental ID,
assuming that the simulation is run in a single thread.
"""
def __init__(self):
self.next_id = 0
def next(self, created_at: float, body: bytes) -> Message:
msg = Message(created_at, self.next_id, body)
self.next_id += 1
return msg

198
mixnet/sim/simulation.py Normal file
View File

@ -0,0 +1,198 @@
from dataclasses import asdict, dataclass
from pprint import pprint
from typing import Self
import usim
from matplotlib import pyplot
import framework.usim as usimfw
from framework import Framework
from protocol.config import GlobalConfig, MixMembership, NodeInfo
from protocol.node import Node, PeeringDegreeReached
from sim.config import Config
from sim.connection import (
MeteredRemoteSimplexConnection,
ObservedMeteredRemoteSimplexConnection,
)
from sim.message import Message, UniqueMessageBuilder
from sim.state import NodeState, NodeStateTable
from sim.stats import ConnectionStats, DisseminationTime
from sim.topology import build_full_random_topology
class Simulation:
"""
Manages the entire cycle of simulation: initialization, running, and analysis.
"""
def __init__(self, config: Config):
self.config = config
self.msg_builder = UniqueMessageBuilder()
self.dissemination_time = DisseminationTime(self.config.network.num_nodes)
async def run(self):
# Run the simulation
conn_stats, node_state_table = await self.__run()
# Analyze the dissemination times
self.dissemination_time.analyze()
# Analyze the simulation results
conn_stats.analyze()
node_state_table.analyze()
# Show plots
if self.config.simulation.show_plots:
pyplot.show()
async def __run(self) -> tuple[ConnectionStats, NodeStateTable]:
# Initialize analysis tools
node_state_table = NodeStateTable(
self.config.network.num_nodes, self.config.simulation.duration_sec
)
conn_stats = ConnectionStats()
# Create a μSim scope and run the simulation
async with usim.until(usim.time + self.config.simulation.duration_sec) as scope:
self.framework = usimfw.Framework(scope)
nodes = self.__init_nodes()
self.__connect_nodes(nodes, node_state_table, conn_stats)
for i, node in enumerate(nodes):
print(f"Spawning node-{i} with {len(node.nomssip.conns)} conns")
self.framework.spawn(self.__run_node_logic(node))
# Return analysis tools once the μSim scope is done
return conn_stats, node_state_table
def __init_nodes(self) -> list[Node]:
# Initialize node/global configurations
node_configs = self.config.node_configs()
global_config = GlobalConfig(
MixMembership(
[
NodeInfo(node_config.private_key.public_key())
for node_config in node_configs
],
self.config.mix.mix_path.seed,
),
self.config.mix.transmission_rate_per_sec,
self.config.mix.max_message_size,
self.config.mix.mix_path.max_length,
)
# Initialize/return Node instances
return [
Node(
self.framework,
node_config,
global_config,
self.__process_broadcasted_msg,
self.__process_recovered_msg,
)
for node_config in node_configs
]
def __connect_nodes(
self,
nodes: list[Node],
node_state_table: NodeStateTable,
conn_stats: ConnectionStats,
):
topology = build_full_random_topology(
self.config.network.topology.seed,
len(nodes),
self.config.network.gossip.peering_degree,
)
print("Topology:")
pprint(topology)
meter_start_time = self.framework.now()
# Sort the topology by node index for the connection RULE defined below.
for node_idx, peer_indices in sorted(topology.items()):
for peer_idx in peer_indices:
# Since the topology is undirected, we only need to connect the two nodes once.
# RULE: the node with the smaller index establishes the connection.
assert node_idx != peer_idx
if node_idx > peer_idx:
continue
node = nodes[node_idx]
peer = nodes[peer_idx]
node_states = node_state_table[node_idx]
peer_states = node_state_table[peer_idx]
# Connect the node and peer for Nomos Gossip
inbound_conn, outbound_conn = (
self.__create_observed_conn(
meter_start_time, peer_states, node_states
),
self.__create_observed_conn(
meter_start_time, node_states, peer_states
),
)
node.connect_mix(peer, inbound_conn, outbound_conn)
# Register the connections to the connection statistics
conn_stats.register(node, inbound_conn, outbound_conn)
conn_stats.register(peer, outbound_conn, inbound_conn)
# Connect the node and peer for broadcasting.
node.connect_broadcast(
peer,
self.__create_conn(meter_start_time),
self.__create_conn(meter_start_time),
)
def __create_observed_conn(
self,
meter_start_time: float,
sender_states: list[NodeState],
receiver_states: list[NodeState],
) -> ObservedMeteredRemoteSimplexConnection:
return ObservedMeteredRemoteSimplexConnection(
self.config.network.latency,
self.framework,
meter_start_time,
sender_states,
receiver_states,
)
def __create_conn(
self,
meter_start_time: float,
) -> MeteredRemoteSimplexConnection:
return MeteredRemoteSimplexConnection(
self.config.network.latency,
self.framework,
meter_start_time,
)
async def __run_node_logic(self, node: Node):
"""
Runs the lottery periodically to check if the node is selected to send a block.
If the node is selected, creates a block and sends it through mix nodes.
"""
lottery_config = self.config.logic.sender_lottery
while True:
await self.framework.sleep(lottery_config.interval_sec)
if lottery_config.seed.random() < lottery_config.probability:
msg = self.msg_builder.next(self.framework.now(), b"selected block")
await node.send_message(bytes(msg))
async def __process_broadcasted_msg(self, msg: bytes):
"""
Process a broadcasted message originated from the last mix.
"""
message = Message.from_bytes(msg)
elapsed = self.framework.now() - message.created_at
self.dissemination_time.add_broadcasted_msg(message, elapsed)
async def __process_recovered_msg(self, msg: bytes) -> bytes:
"""
Process a message fully recovered by the last mix
and returns a new message to be broadcasted.
"""
message = Message.from_bytes(msg)
elapsed = self.framework.now() - message.created_at
self.dissemination_time.add_mix_propagation_time(elapsed)
# Update the timestamp and return the message to be broadcasted,
# so that the broadcast dissemination time can be calculated from now.
message.created_at = self.framework.now()
return bytes(message)

77
mixnet/sim/state.py Normal file
View File

@ -0,0 +1,77 @@
from datetime import datetime
from enum import Enum
import matplotlib.pyplot as plt
import pandas
class NodeState(Enum):
"""
A state of node at a certain time.
For now, we assume that the node cannot send and receive messages at the same time for simplicity.
"""
SENDING = -1
IDLE = 0
RECEIVING = 1
class NodeStateTable:
def __init__(self, num_nodes: int, duration_sec: int):
# Create a table to store the state of each node at each millisecond
self.__table = [
[NodeState.IDLE] * (duration_sec * 1000) for _ in range(num_nodes)
]
def __getitem__(self, idx: int) -> list[NodeState]:
return self.__table[idx]
def analyze(self):
df = pandas.DataFrame(self.__table).transpose()
df.columns = [f"Node-{i}" for i in range(len(self.__table))]
# Convert NodeState enum to their integer values
df = df.map(lambda state: state.value)
print("==========================================")
print("Node States of All Nodes over Time")
print(", ".join(f"{state.name}:{state.value}" for state in NodeState))
print("==========================================")
print(f"{df}\n")
csv_path = f"all_node_states_{datetime.now().isoformat(timespec="seconds")}.csv"
df.to_csv(csv_path)
print(f"Saved DataFrame to {csv_path}\n")
# Count/print the number of each state for each node
# because the df is usually too big to print
state_counts = df.apply(pandas.Series.value_counts).fillna(0)
print("State Counts per Node:")
print(f"{state_counts}\n")
# Draw a dot plot
plt.figure(figsize=(15, 8))
for node in df.columns:
times = df.index
states = df[node]
sending_times = times[states == NodeState.SENDING.value]
receiving_times = times[states == NodeState.RECEIVING.value]
plt.scatter(
sending_times,
[node] * len(sending_times),
color="red",
marker="o",
s=10,
label="SENDING" if node == df.columns[0] else "",
)
plt.scatter(
receiving_times,
[node] * len(receiving_times),
color="blue",
marker="x",
s=10,
label="RECEIVING" if node == df.columns[0] else "",
)
plt.xlabel("Time")
plt.ylabel("Node")
plt.title("Node States Over Time")
plt.legend(loc="upper right")
plt.draw()

152
mixnet/sim/stats.py Normal file
View File

@ -0,0 +1,152 @@
from collections import Counter, defaultdict
import matplotlib.pyplot as plt
import numpy
import pandas
from matplotlib.axes import Axes
from protocol.node import Node
from sim.connection import ObservedMeteredRemoteSimplexConnection
from sim.message import Message
# A map of nodes to their inbound/outbound connections
NodeConnectionsMap = dict[
Node,
tuple[
list[ObservedMeteredRemoteSimplexConnection],
list[ObservedMeteredRemoteSimplexConnection],
],
]
class ConnectionStats:
def __init__(self):
self.conns_per_node: NodeConnectionsMap = defaultdict(lambda: ([], []))
def register(
self,
node: Node,
inbound_conn: ObservedMeteredRemoteSimplexConnection,
outbound_conn: ObservedMeteredRemoteSimplexConnection,
):
self.conns_per_node[node][0].append(inbound_conn)
self.conns_per_node[node][1].append(outbound_conn)
def analyze(self):
self.__message_sizes()
self.__bandwidths_per_conn()
self.__bandwidths_per_node()
def __message_sizes(self):
"""
Analyzes all message sizes sent across all connections of all nodes.
"""
sizes: Counter[int] = Counter()
for _, (_, outbound_conns) in self.conns_per_node.items():
for conn in outbound_conns:
sizes.update(conn.msg_sizes)
df = pandas.DataFrame.from_dict(sizes, orient="index").reset_index()
df.columns = ["msg_size", "count"]
print("==========================================")
print("Message Size Distribution")
print("==========================================")
print(f"{df}\n")
def __bandwidths_per_conn(self):
"""
Analyzes the bandwidth consumed by each simplex connection.
"""
plt.plot(figsize=(12, 6))
for _, (_, outbound_conns) in self.conns_per_node.items():
for conn in outbound_conns:
sending_bandwidths = conn.sending_bandwidths().map(lambda x: x / 1024)
plt.plot(sending_bandwidths.index, sending_bandwidths)
plt.title("Unidirectional Bandwidths per Connection")
plt.xlabel("Time (s)")
plt.ylabel("Bandwidth (KiB/s)")
plt.ylim(bottom=0)
plt.grid(True)
plt.tight_layout()
plt.draw()
def __bandwidths_per_node(self):
"""
Analyzes the inbound/outbound bandwidths consumed by each node (sum of all its connections).
"""
_, axs = plt.subplots(nrows=2, ncols=1, figsize=(12, 6))
assert isinstance(axs, numpy.ndarray)
for i, (_, (inbound_conns, outbound_conns)) in enumerate(
self.conns_per_node.items()
):
inbound_bandwidths = (
pandas.concat(
[conn.receiving_bandwidths() for conn in inbound_conns], axis=1
)
.sum(axis=1)
.map(lambda x: x / 1024)
)
outbound_bandwidths = (
pandas.concat(
[conn.sending_bandwidths() for conn in outbound_conns], axis=1
)
.sum(axis=1)
.map(lambda x: x / 1024)
)
axs[0].plot(inbound_bandwidths.index, inbound_bandwidths, label=f"Node-{i}")
axs[1].plot(
outbound_bandwidths.index, outbound_bandwidths, label=f"Node-{i}"
)
axs[0].set_title("Inbound Bandwidths per Node")
axs[0].set_xlabel("Time (s)")
axs[0].set_ylabel("Bandwidth (KiB/s)")
axs[0].legend()
axs[0].set_ylim(bottom=0)
axs[0].grid(True)
axs[1].set_title("Outbound Bandwidths per Node")
axs[1].set_xlabel("Time (s)")
axs[1].set_ylabel("Bandwidth (KiB/s)")
axs[1].legend()
axs[1].set_ylim(bottom=0)
axs[1].grid(True)
plt.tight_layout()
plt.draw()
class DisseminationTime:
def __init__(self, num_nodes: int):
# A collection of time taken for a message to propagate through all mix nodes in its mix route
self.mix_propagation_times: list[float] = []
# A collection of time taken for a message to be broadcasted from the last mix to all nodes in the network
self.broadcast_dissemination_times: list[float] = []
# Data structures to check if a message has been broadcasted to all nodes
self.broadcast_status: Counter[Message] = Counter()
self.num_nodes: int = num_nodes
def add_mix_propagation_time(self, elapsed: float):
self.mix_propagation_times.append(elapsed)
def add_broadcasted_msg(self, msg: Message, elapsed: float):
assert self.broadcast_status[msg] < self.num_nodes
self.broadcast_status.update([msg])
if self.broadcast_status[msg] == self.num_nodes:
self.broadcast_dissemination_times.append(elapsed)
def analyze(self):
print("==========================================")
print("Message Dissemination Time")
print("==========================================")
print("[Mix Propagation Times]")
mix_propagation_times = pandas.Series(self.mix_propagation_times)
print(mix_propagation_times.describe())
print("")
print("[Broadcast Dissemination Times]")
broadcast_travel_times = pandas.Series(self.broadcast_dissemination_times)
print(broadcast_travel_times.describe())
print("")

View File

@ -0,0 +1,104 @@
import math
import random
from unittest import IsolatedAsyncioTestCase
import usim
import framework.usim as usimfw
from protocol.connection import LocalSimplexConnection
from protocol.node import Node
from protocol.test_utils import (
init_mixnet_config,
)
from sim.config import LatencyConfig, NetworkConfig
from sim.connection import (
MeteredRemoteSimplexConnection,
ObservedMeteredRemoteSimplexConnection,
)
from sim.state import NodeState, NodeStateTable
class TestMeteredRemoteSimplexConnection(IsolatedAsyncioTestCase):
async def test_latency(self):
usim.run(self.__test_latency())
async def __test_latency(self):
async with usim.Scope() as scope:
framework = usimfw.Framework(scope)
node_state_table = NodeStateTable(num_nodes=2, duration_sec=3)
conn = MeteredRemoteSimplexConnection(
LatencyConfig(
min_latency_sec=0,
max_latency_sec=1,
seed=random.Random(),
),
framework,
framework.now(),
)
# Send two messages without delay
sent_time = framework.now()
await conn.send(b"hello")
await conn.send(b"world")
# Receive two messages and check if the network latency was simulated well.
# There should be no delay between the two messages because they were sent without delay.
self.assertEqual(b"hello", await conn.recv())
self.assertEqual(conn.latency, framework.now() - sent_time)
self.assertEqual(b"world", await conn.recv())
self.assertEqual(conn.latency, framework.now() - sent_time)
class TestObservedMeteredRemoteSimplexConnection(IsolatedAsyncioTestCase):
async def test_node_state(self):
usim.run(self.__test_node_state())
async def __test_node_state(self):
async with usim.Scope() as scope:
framework = usimfw.Framework(scope)
node_state_table = NodeStateTable(num_nodes=2, duration_sec=3)
meter_start_time = framework.now()
conn = ObservedMeteredRemoteSimplexConnection(
LatencyConfig(
min_latency_sec=0,
max_latency_sec=1,
seed=random.Random(),
),
framework,
meter_start_time,
node_state_table[0],
node_state_table[1],
)
# Sleep and send a message
await framework.sleep(1)
sent_time = framework.now()
await conn.send(b"hello")
# Receive the message. It should be received after the latency.
self.assertEqual(b"hello", await conn.recv())
recv_time = framework.now()
# Check if the sender node state is SENDING at the sent time
timeslot = math.floor((sent_time - meter_start_time) * 1000)
self.assertEqual(
NodeState.SENDING,
node_state_table[0][timeslot],
)
# Ensure that the sender node states in other time slots are IDLE
states = set()
states.update(node_state_table[0][:timeslot])
states.update(node_state_table[0][timeslot + 1 :])
self.assertEqual(set([NodeState.IDLE]), states)
# Check if the receiver node state is RECEIVING at the received time
timeslot = math.floor((recv_time - meter_start_time) * 1000)
self.assertEqual(
NodeState.RECEIVING,
node_state_table[1][timeslot],
)
# Ensure that the receiver node states in other time slots are IDLE
states = set()
states.update(node_state_table[1][:timeslot])
states.update(node_state_table[1][timeslot + 1 :])
self.assertEqual(set([NodeState.IDLE]), states)

View File

@ -0,0 +1,23 @@
import time
from unittest import TestCase
from sim.message import Message, UniqueMessageBuilder
class TestMessage(TestCase):
def test_message_serde(self):
msg = Message(time.time(), 10, b"hello")
serialized = bytes(msg)
deserialized = Message.from_bytes(serialized)
self.assertEqual(msg, deserialized)
class TestUniqueMessageBuilder(TestCase):
def test_uniqueness(self):
builder = UniqueMessageBuilder()
msg1 = builder.next(time.time(), b"hello")
msg2 = builder.next(time.time(), b"hello")
self.assertEqual(0, msg1.id)
self.assertEqual(1, msg2.id)
self.assertNotEqual(msg1, msg2)
self.assertNotEqual(hash(msg1), hash(msg2))

View File

@ -0,0 +1,20 @@
import random
from unittest import TestCase
from sim.topology import are_all_nodes_connected, build_full_random_topology
class TestTopology(TestCase):
def test_full_random(self):
num_nodes = 100
peering_degree = 6
topology = build_full_random_topology(
random.Random(0), num_nodes, peering_degree
)
self.assertEqual(num_nodes, len(topology))
self.assertTrue(are_all_nodes_connected(topology))
for node, peers in topology.items():
self.assertTrue(0 < len(peers) <= peering_degree)
# Check if nodes are interconnected
for peer in peers:
self.assertIn(node, topology[peer])

58
mixnet/sim/topology.py Normal file
View File

@ -0,0 +1,58 @@
import random
from collections import defaultdict
from protocol.node import Node
Topology = dict[int, set[int]]
def build_full_random_topology(
rng: random.Random, num_nodes: int, peering_degree: int
) -> Topology:
"""
Generate a random undirected topology until all nodes are connected.
We don't implement any artificial tool to ensure the connectivity of the topology.
Instead, we regenerate a topology in a fully randomized way until all nodes are connected.
"""
while True:
topology: Topology = defaultdict(set[int])
nodes = list(range(num_nodes))
for node in nodes:
# Filter nodes that can be connected to the current node.
others = []
for other in nodes[:node] + nodes[node + 1 :]:
# Check if the other node is not already connected to the current node
# and the other node has not reached the peering degree.
if (
other not in topology[node]
and len(topology[other]) < peering_degree
):
others.append(other)
# How many more connections the current node needs
n_needs = peering_degree - len(topology[node])
# Sample peers as many as possible
peers = rng.sample(others, k=min(n_needs, len(others)))
# Connect the current node to the peers
topology[node].update(peers)
# Connect the peers to the current node, since the topology is undirected
for peer in peers:
topology[peer].update([node])
if are_all_nodes_connected(topology):
return topology
def are_all_nodes_connected(topology: Topology) -> bool:
visited = set()
def dfs(topology: Topology, node: int) -> None:
if node in visited:
return
visited.add(node)
for peer in topology[node]:
dfs(topology, peer)
# Start DFS from the first node
dfs(topology, next(iter(topology)))
return len(visited) == len(topology)