Mixnet: Initial simulation (#6)
parent 537f86f53f
commit 39eabe1537
@@ -0,0 +1,30 @@
name: CI

on:
  pull_request:
    branches:
      - "*"
  push:
    branches: [master]

jobs:
  mixnet:
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v4
        with:
          submodules: true
      - name: Set up Python 3.x
        uses: actions/setup-python@v5
        with:
          python-version: "3.x"
      - name: Install dependencies for mixnet
        working-directory: mixnet
        run: pip install -r requirements.txt
      - name: Run unit tests
        working-directory: mixnet
        run: python -m unittest -v
      - name: Run a short mixnet simulation
        working-directory: mixnet
        run: python -m cmd.main --config config.ci.yaml
@@ -0,0 +1,2 @@
.venv/
*.csv
@@ -0,0 +1,140 @@
# NomMix Simulation

* [Project Structure](#project-structure)
* [Features](#features)
* [Future Plans](#future-plans)
* [Installation](#installation)
* [Getting Started](#getting-started)

## Project Structure

- `cmd`: CLIs to run the simulation and analyze the results.
- `sim`: Simulation that runs the NomMix defined in the `protocol` package.
- `protocol`: Core NomMix protocol implementation, which will be moved to the [nomos-specs](https://github.com/logos-co/nomos-specs) repository once verified by simulations.
- `framework`: Asynchronous framework that provides essential async functions for simulations and tests, implemented with various async libraries ([asyncio](https://docs.python.org/3/library/asyncio.html), [μSim](https://usim.readthedocs.io/en/latest/), etc.)

## Features

- NomMix protocol simulation
- Performance measurements
  - Bandwidth usage
  - Message dissemination time
- Privacy property analysis
  - Message sizes
  - Node states and Hamming distances

## Future Plans

- More NomMix features
  - Temporal mixing
  - Level-1 noise
- Adversary simulation to measure the robustness of NomMix

## Installation

Clone the repository and install the dependencies:
```bash
git clone https://github.com/logos-co/nomos-simulations.git
cd nomos-simulations/mixnet
python3 -m venv .venv
source .venv/bin/activate
pip install -r requirements.txt
```

## Getting Started

Copy the [`config.ci.yaml`](./config.ci.yaml) file and adjust the parameters to your needs.
Each parameter is explained in the config file.
For more details, please refer to the [documentation](https://www.notion.so/NomMix-Sim-Getting-Started-ee0e2191f4e7437e93976aff2627d7ce?pvs=4).

Run the simulation with the following command:
```bash
python -m cmd.main --config {config_path}
```

All results are printed to the console, as shown below, and all plots are displayed once the analysis is done.
```
Spawning node-0 with 3 conns
Spawning node-1 with 3 conns
Spawning node-2 with 3 conns
Spawning node-3 with 3 conns
Spawning node-4 with 3 conns
Spawning node-5 with 3 conns
==========================================
Message Dissemination Time
==========================================
[Mix Propagation Times]
count    7.000000
mean     1.122000
std      0.106276
min      1.009000
25%      1.024500
50%      1.157000
75%      1.174500
max      1.290000
dtype: float64

[Broadcast Dissemination Times]
count    7.000000
mean     0.118429
std      0.004353
min      0.111000
25%      0.116000
50%      0.120000
75%      0.121500
max      0.123000
dtype: float64

==========================================
Message Size Distribution
==========================================
   msg_size   count
0      1405  179982

==========================================
Node States of All Nodes over Time
SENDING:-1, IDLE:0, RECEIVING:1
==========================================
        Node-0  Node-1  Node-2  Node-3  Node-4  Node-5
0            0       0       0       0       0       0
1            0       0       0       0       0       0
2            0       0       0       0       0       0
3            0       0       0       0       0       0
4            0       0       0       0       0       0
...        ...     ...     ...     ...     ...     ...
999995       0       0       0       0       0       0
999996       0       0       0       0       0       1
999997       0       0       0       0       0       0
999998       0       0       0       1       0       0
999999       0       0       0       0       0       0

[1000000 rows x 6 columns]

Saved DataFrame to all_node_states_2024-07-23T09:10:59.csv

State Counts per Node:
    Node-0  Node-1  Node-2  Node-3  Node-4  Node-5
 0  960004  960004  960004  960004  960004  960004
 1   29997   29997   29997   29997   29997   29997
-1    9999    9999    9999    9999    9999    9999

Simulation complete!
```

Please note that the result of the node state analysis is saved as a CSV file, as printed in the console.

If you run the simulation again with different parameters and want to compare the results of the two simulations, you can calculate the Hamming distance between them:
```bash
python -m cmd.hamming \
    all_node_states_2024-07-15T18:20:23.csv \
    all_node_states_2024-07-15T19:32:45.csv
```
The output is a floating-point number between 0 and 1.
If the output is 0, the two results are identical.
The closer the output is to 1, the more the two results differ from each other.
```
Hamming distance: 0.29997
```
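The normalized Hamming distance described above can be sketched without the `cmd.hamming` CLI: count element-wise differences between two equally-shaped tables and divide by the total cell count (mirroring the pandas expression `(df1 != df2).sum().sum() / df1.size` used by the tool). The toy tables below stand in for the exported node-state CSVs.

```python
# Two toy "node state" tables (rows: time steps, columns: two nodes), same shape
a = [[0, 0], [0, -1], [1, 0], [0, 0]]
b = [[0, 0], [0, 0], [1, 0], [0, 1]]

# Count element-wise differences and normalize by the total number of cells
diffs = sum(x != y for row_a, row_b in zip(a, b) for x, y in zip(row_a, row_b))
distance = diffs / (len(a) * len(a[0]))
print(distance)  # 2 differing cells out of 8 -> 0.25
```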
@@ -0,0 +1,42 @@
import sys

import pandas as pd


def calculate_hamming_distance(df1, df2):
    """
    Calculate the Hamming distance between two DataFrames
    to quantify the difference between them.
    """
    if df1.shape != df2.shape:
        raise ValueError(
            "DataFrames must have the same shape to calculate Hamming distance."
        )

    # Compare element-wise and count differences
    differences = (df1 != df2).sum().sum()
    return differences / df1.size  # normalize the distance


def main():
    if len(sys.argv) != 3:
        print("Usage: python hamming.py <csv_path1> <csv_path2>")
        sys.exit(1)

    csv_path1 = sys.argv[1]
    csv_path2 = sys.argv[2]

    # Load the CSV files into DataFrames
    df1 = pd.read_csv(csv_path1)
    df2 = pd.read_csv(csv_path2)

    # Calculate the Hamming distance
    try:
        hamming_distance = calculate_hamming_distance(df1, df2)
        print(f"Hamming distance: {hamming_distance}")
    except ValueError as e:
        print(f"Error: {e}")


if __name__ == "__main__":
    main()
@@ -0,0 +1,25 @@
import argparse

import usim

from sim.config import Config
from sim.simulation import Simulation

if __name__ == "__main__":
    # Read a config file and run a simulation
    parser = argparse.ArgumentParser(
        description="Run mixnet simulation",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument(
        "--config", type=str, required=True, help="Configuration file path"
    )
    args = parser.parse_args()

    config = Config.load(args.config)
    sim = Simulation(config)
    usim.run(sim.run())

    print("Simulation complete!")
@@ -0,0 +1,53 @@
simulation:
  # Desired duration of the simulation in seconds
  # Since the simulation uses discrete time steps, the actual duration may be longer or shorter.
  duration_sec: 1000
  # Show all plots that have been drawn during the simulation
  show_plots: false

network:
  # Total number of nodes in the entire network.
  num_nodes: 6
  latency:
    # Minimum/maximum network latency between nodes in seconds.
    # A constant latency is chosen randomly for each connection within the range [min_latency_sec, max_latency_sec].
    min_latency_sec: 0
    max_latency_sec: 0.1
    # Seed for the random number generator used to determine the network latencies.
    seed: 0
  gossip:
    # Expected number of peers each node must connect to if there are enough peers available in the network.
    peering_degree: 3
  topology:
    # Seed for the random number generator used to determine the network topology.
    seed: 1

mix:
  # Global constant transmission rate of each connection in messages per second.
  transmission_rate_per_sec: 10
  # Maximum size of a message in bytes that can be encapsulated in a single Sphinx packet.
  max_message_size: 1007
  mix_path:
    # Minimum number of mix nodes to be chosen for a Sphinx packet.
    min_length: 5
    # Maximum number of mix nodes to be chosen for a Sphinx packet.
    max_length: 5
    # Seed for the random number generator used to determine the mix path.
    seed: 3
  temporal_mix:
    # none | pure-coin-flipping | pure-random-sampling | permuted-coin-flipping | noisy-coin-flipping
    mix_type: "pure-coin-flipping"
    # The minimum size of a queue to be mixed.
    # If the queue size is less than this value, noise messages are added.
    min_queue_size: 5
    # Seed for the generator that produces the RNG seed of each queue to be created.
    seed_generator: 100

logic:
  sender_lottery:
    # Interval between lottery draws in seconds.
    interval_sec: 1
    # Probability of a node being selected as a sender in each lottery draw.
    probability: 0.001
    # Seed for the random number generator used to determine the lottery winners.
    seed: 10
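A quick sanity check of what these numbers imply (back-of-the-envelope arithmetic, not part of the simulator): with `transmission_rate_per_sec: 10`, each simplex connection releases one packet (real or noise) every 0.1 s, and with the sender lottery drawn once per second per node at probability 0.001, a 1000-second run with 6 nodes produces about 6 real messages in expectation.

```python
transmission_rate_per_sec = 10
duration_sec = 1000
num_nodes = 6
lottery_interval_sec = 1
lottery_probability = 0.001

# One packet (real or noise) per connection per 1/rate seconds
packets_per_conn = transmission_rate_per_sec * duration_sec
print(packets_per_conn)  # 10000 packets per simplex connection

# Expected number of sender-lottery wins across all nodes over the run
draws_per_node = duration_sec / lottery_interval_sec
expected_senders = num_nodes * draws_per_node * lottery_probability
print(expected_senders)  # about 6 real messages expected over the run
```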
@@ -0,0 +1 @@
from .framework import *
@@ -0,0 +1,52 @@
from __future__ import annotations

import asyncio
import time
from typing import Any, Awaitable, Coroutine, TypeVar

from framework import framework


class Framework(framework.Framework):
    """
    An asyncio implementation of the Framework
    """

    def __init__(self):
        super().__init__()

    def queue(self) -> framework.Queue:
        return Queue()

    async def sleep(self, seconds: float) -> None:
        await asyncio.sleep(seconds)

    def now(self) -> float:
        return time.time()

    def spawn(
        self, coroutine: Coroutine[Any, Any, framework.RT]
    ) -> Awaitable[framework.RT]:
        return asyncio.create_task(coroutine)


T = TypeVar("T")


class Queue(framework.Queue[T]):
    """
    An asyncio implementation of the Queue
    """

    def __init__(self):
        super().__init__()
        self._queue = asyncio.Queue()

    async def put(self, data: T) -> None:
        await self._queue.put(data)

    async def get(self) -> T:
        return await self._queue.get()

    def empty(self) -> bool:
        return self._queue.empty()
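The pattern these classes enable — spawn a producer as a background task and await messages from a queue — can be exercised standalone with plain asyncio (a self-contained sketch that maps `Framework.spawn` to `asyncio.create_task`, without importing the `framework` package):

```python
import asyncio


async def main() -> list[str]:
    queue: asyncio.Queue[str] = asyncio.Queue()

    async def producer() -> None:
        for i in range(3):
            await asyncio.sleep(0)  # yield control, like Framework.sleep(0)
            await queue.put(f"msg-{i}")

    # Framework.spawn corresponds to asyncio.create_task
    task = asyncio.create_task(producer())
    received = [await queue.get() for _ in range(3)]
    await task
    return received


print(asyncio.run(main()))  # ['msg-0', 'msg-1', 'msg-2']
```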
@@ -0,0 +1,50 @@
from __future__ import annotations

import abc
from typing import Any, Awaitable, Coroutine, Generic, TypeVar

RT = TypeVar("RT")


class Framework(abc.ABC):
    """
    An abstract class that provides essential asynchronous functions.
    This class can be implemented using any asynchronous framework (e.g., asyncio, usim).
    """

    @abc.abstractmethod
    def queue(self) -> Queue:
        pass

    @abc.abstractmethod
    async def sleep(self, seconds: float) -> None:
        pass

    @abc.abstractmethod
    def now(self) -> float:
        pass

    @abc.abstractmethod
    def spawn(self, coroutine: Coroutine[Any, Any, RT]) -> Awaitable[RT]:
        pass


T = TypeVar("T")


class Queue(abc.ABC, Generic[T]):
    """
    An abstract class that provides asynchronous queue operations.
    """

    @abc.abstractmethod
    async def put(self, data: T) -> None:
        pass

    @abc.abstractmethod
    async def get(self) -> T:
        pass

    @abc.abstractmethod
    def empty(self) -> bool:
        pass
@@ -0,0 +1,58 @@
from typing import Any, Awaitable, Coroutine, TypeVar

import usim

from framework import framework


class Framework(framework.Framework):
    """
    A usim implementation of the Framework for discrete-time simulation
    """

    def __init__(self, scope: usim.Scope) -> None:
        super().__init__()

        # Scope is used to spawn concurrent simulation activities (coroutines).
        # μSim waits until all activities spawned in the scope are done
        # or until the timeout specified in the scope is reached.
        # Because of the way μSim works, the scope must be created using `async with` syntax
        # and be passed to this constructor.
        self._scope = scope

    def queue(self) -> framework.Queue:
        return Queue()

    async def sleep(self, seconds: float) -> None:
        await (usim.time + seconds)

    def now(self) -> float:
        # Round to milliseconds to keep the analysis manageable
        return int(usim.time.now * 1000) / 1000

    def spawn(
        self, coroutine: Coroutine[Any, Any, framework.RT]
    ) -> Awaitable[framework.RT]:
        return self._scope.do(coroutine)


T = TypeVar("T")


class Queue(framework.Queue[T]):
    """
    A usim implementation of the Queue for discrete-time simulation
    """

    def __init__(self):
        super().__init__()
        self._queue = usim.Queue()

    async def put(self, data: T) -> None:
        await self._queue.put(data)

    async def get(self) -> T:
        return await self._queue

    def empty(self) -> bool:
        # usim.Queue does not expose an emptiness check, so peek at its internal buffer.
        return len(self._queue._buffer) == 0
@@ -0,0 +1,66 @@
from __future__ import annotations

import random
from dataclasses import dataclass, field
from typing import List

from pysphinx.node import X25519PublicKey
from pysphinx.sphinx import Node as SphinxNode
from pysphinx.sphinx import X25519PrivateKey

from protocol.gossip import GossipConfig
from protocol.temporalmix import TemporalMixConfig


@dataclass
class GlobalConfig:
    """
    Global parameters used across all nodes in the network
    """

    membership: MixMembership
    transmission_rate_per_sec: int  # Global Transmission Rate
    max_message_size: int
    max_mix_path_length: int


@dataclass
class NodeConfig:
    """
    Node-specific parameters
    """

    private_key: X25519PrivateKey
    mix_path_length: int
    gossip: GossipConfig
    temporal_mix: TemporalMixConfig


@dataclass
class MixMembership:
    """
    A list of public information of nodes in the network.
    We assume that this list is known to all nodes in the network.
    """

    nodes: List[NodeInfo]
    rng: random.Random = field(default_factory=random.Random)

    def generate_route(self, length: int) -> list[NodeInfo]:
        """
        Choose `length` nodes with replacement as a mix route.
        """
        return self.rng.choices(self.nodes, k=length)


@dataclass
class NodeInfo:
    """
    Public information of a node to be shared with all nodes in the network
    """

    public_key: X25519PublicKey

    def sphinx_node(self) -> SphinxNode:
        dummy_node_addr = bytes(32)
        return SphinxNode(self.public_key, dummy_node_addr)
@@ -0,0 +1,88 @@
from __future__ import annotations

import abc

from framework import Framework, Queue
from protocol.temporalmix import TemporalMix, TemporalMixConfig


class SimplexConnection(abc.ABC):
    """
    An abstract class for a simplex connection that can send and receive data in one direction
    """

    @abc.abstractmethod
    async def send(self, data: bytes) -> None:
        pass

    @abc.abstractmethod
    async def recv(self) -> bytes:
        pass


class LocalSimplexConnection(SimplexConnection):
    """
    A simplex connection without any network latency.
    Data sent through this connection can be immediately received from the other end.
    """

    def __init__(self, framework: Framework):
        self.queue: Queue[bytes] = framework.queue()

    async def send(self, data: bytes) -> None:
        await self.queue.put(data)

    async def recv(self) -> bytes:
        return await self.queue.get()


class DuplexConnection:
    """
    A duplex connection in which data can be transmitted and received simultaneously in both directions.
    This mimics duplex communication in a real network (such as TCP or QUIC).
    """

    def __init__(self, inbound: SimplexConnection, outbound: SimplexConnection):
        self.inbound = inbound
        self.outbound = outbound

    async def recv(self) -> bytes:
        return await self.inbound.recv()

    async def send(self, packet: bytes):
        await self.outbound.send(packet)


class MixSimplexConnection(SimplexConnection):
    """
    Wraps a SimplexConnection to add a constant transmission rate and noise to the connection.
    """

    def __init__(
        self,
        framework: Framework,
        conn: SimplexConnection,
        transmission_rate_per_sec: int,
        noise_msg: bytes,
        temporal_mix_config: TemporalMixConfig,
    ):
        self.framework = framework
        self.queue: Queue[bytes] = TemporalMix.queue(
            temporal_mix_config, framework, noise_msg
        )
        self.conn = conn
        self.transmission_rate_per_sec = transmission_rate_per_sec
        self.task = framework.spawn(self.__run())

    async def __run(self):
        while True:
            await self.framework.sleep(1 / self.transmission_rate_per_sec)
            msg = await self.queue.get()
            await self.conn.send(msg)

    async def send(self, data: bytes) -> None:
        await self.queue.put(data)

    async def recv(self) -> bytes:
        return await self.conn.recv()
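The `__run` loop in `MixSimplexConnection` is the whole transmission-rate mechanism: exactly one message leaves the connection per `1/rate` seconds, regardless of when it was enqueued. A self-contained asyncio sketch of that release loop (a plain `asyncio.Queue` stands in for the framework abstraction and the temporal-mix queue):

```python
import asyncio


async def rate_limited_send(msgs: list[bytes], rate_per_sec: float) -> list[bytes]:
    queue: asyncio.Queue[bytes] = asyncio.Queue()
    sent: list[bytes] = []
    for m in msgs:
        queue.put_nowait(m)

    # Release exactly one message per 1/rate seconds, like MixSimplexConnection.__run
    while not queue.empty():
        await asyncio.sleep(1 / rate_per_sec)
        sent.append(await queue.get())
    return sent


out = asyncio.run(rate_limited_send([b"a", b"b", b"c"], rate_per_sec=1000))
print(out)  # [b'a', b'b', b'c'], released at 1 ms intervals
```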
@@ -0,0 +1,2 @@
class PeeringDegreeReached(Exception):
    pass
@@ -0,0 +1,87 @@
from __future__ import annotations

import hashlib
from dataclasses import dataclass
from typing import Awaitable, Callable

from framework import Framework
from protocol.connection import (
    DuplexConnection,
    SimplexConnection,
)
from protocol.error import PeeringDegreeReached


@dataclass
class GossipConfig:
    # Expected number of peers each node must connect to if there are enough peers available in the network.
    peering_degree: int


class Gossip:
    """
    A gossip channel that broadcasts messages to all connected peers.
    Peers are connected via DuplexConnection.
    """

    def __init__(
        self,
        framework: Framework,
        config: GossipConfig,
        handler: Callable[[bytes], Awaitable[None]],
    ):
        self.framework = framework
        self.config = config
        self.conns: list[DuplexConnection] = []
        # A handler to process inbound messages.
        self.handler = handler
        self.packet_cache: set[bytes] = set()
        # A set just for keeping references to tasks to prevent them from being garbage collected.
        # https://docs.python.org/3/library/asyncio-task.html#asyncio.create_task
        self.tasks: set[Awaitable] = set()

    def can_accept_conn(self) -> bool:
        return len(self.conns) < self.config.peering_degree

    def add_conn(self, inbound: SimplexConnection, outbound: SimplexConnection):
        if not self.can_accept_conn():
            # For simplicity of the spec, reject the connection if the peering degree is reached.
            raise PeeringDegreeReached()

        conn = DuplexConnection(
            inbound,
            outbound,
        )
        self.conns.append(conn)
        task = self.framework.spawn(self.__process_inbound_conn(conn))
        self.tasks.add(task)

    async def __process_inbound_conn(self, conn: DuplexConnection):
        while True:
            msg = await conn.recv()
            if self.__check_update_cache(msg):
                continue
            await self.process_inbound_msg(msg)

    async def process_inbound_msg(self, msg: bytes):
        await self.gossip(msg)
        await self.handler(msg)

    async def gossip(self, msg: bytes):
        """
        Gossip a message to all connected peers.
        """
        for conn in self.conns:
            await conn.send(msg)

    def __check_update_cache(self, packet: bytes) -> bool:
        """
        Add a message to the cache, and return True if the message was already in the cache.
        """
        digest = hashlib.sha256(packet).digest()
        if digest in self.packet_cache:
            return True
        self.packet_cache.add(digest)
        return False
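The `__check_update_cache` dedup logic — hash the packet, drop it if already seen — is what keeps gossip from looping forever. It can be exercised standalone (the `seen` helper below is a hypothetical free-function version of the method, using a module-level cache in place of `self.packet_cache`):

```python
import hashlib

packet_cache: set[bytes] = set()


def seen(packet: bytes) -> bool:
    """Add the packet's hash to the cache; return True if it was already there."""
    digest = hashlib.sha256(packet).digest()
    if digest in packet_cache:
        return True
    packet_cache.add(digest)
    return False


print(seen(b"hello"))  # False: first sighting, now cached
print(seen(b"hello"))  # True: duplicate, would not be re-gossiped
print(seen(b"world"))  # False
```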
@@ -0,0 +1,150 @@
from __future__ import annotations

from typing import Awaitable, Callable

from pysphinx.sphinx import (
    ProcessedFinalHopPacket,
    ProcessedForwardHopPacket,
    SphinxPacket,
)

from framework import Framework
from protocol.config import GlobalConfig, NodeConfig
from protocol.connection import SimplexConnection
from protocol.error import PeeringDegreeReached
from protocol.gossip import Gossip
from protocol.nomssip import Nomssip, NomssipConfig
from protocol.sphinx import SphinxPacketBuilder


class Node:
    """
    This represents any node in the network, which:
    - generates/gossips mix messages (Sphinx packets)
    - performs cryptographic mixing (unwrapping Sphinx packets)
    - generates noise
    """

    def __init__(
        self,
        framework: Framework,
        config: NodeConfig,
        global_config: GlobalConfig,
        # A handler called when a node receives a broadcasted message originated from the last mix.
        broadcasted_msg_handler: Callable[[bytes], Awaitable[None]],
        # An optional handler only for the simulation,
        # which is called when a message is fully recovered by the last mix
        # and returns a new message to be broadcasted.
        recovered_msg_handler: Callable[[bytes], Awaitable[bytes]] | None = None,
    ):
        self.framework = framework
        self.config = config
        self.global_config = global_config
        self.nomssip = Nomssip(
            framework,
            NomssipConfig(
                config.gossip.peering_degree,
                global_config.transmission_rate_per_sec,
                self.__calculate_message_size(global_config),
                config.temporal_mix,
            ),
            self.__process_msg,
        )
        self.broadcast = Gossip(framework, config.gossip, broadcasted_msg_handler)
        self.recovered_msg_handler = recovered_msg_handler

    @staticmethod
    def __calculate_message_size(global_config: GlobalConfig) -> int:
        """
        Calculate the actual message size to be gossiped, which depends on the maximum length of the mix path.
        """
        sample_sphinx_packet, _ = SphinxPacketBuilder.build(
            bytes(global_config.max_message_size),
            global_config,
            global_config.max_mix_path_length,
        )
        return len(sample_sphinx_packet.bytes())

    async def __process_msg(self, msg: bytes) -> None:
        """
        A handler to process messages received via the Nomssip channel
        """
        sphinx_packet = SphinxPacket.from_bytes(
            msg, self.global_config.max_mix_path_length
        )
        result = await self.__process_sphinx_packet(sphinx_packet)
        match result:
            case SphinxPacket():
                # Gossip the next Sphinx packet
                await self.nomssip.gossip(result.bytes())
            case bytes():
                if self.recovered_msg_handler is not None:
                    result = await self.recovered_msg_handler(result)
                # Broadcast the message fully recovered from Sphinx packets
                await self.broadcast.gossip(result)
            case None:
                return

    async def __process_sphinx_packet(
        self, packet: SphinxPacket
    ) -> SphinxPacket | bytes | None:
        """
        Unwrap the Sphinx packet and process the next Sphinx packet or the payload if possible
        """
        try:
            processed = packet.process(self.config.private_key)
            match processed:
                case ProcessedForwardHopPacket():
                    return processed.next_packet
                case ProcessedFinalHopPacket():
                    return processed.payload.recover_plain_playload()
        except ValueError:
            # Return nothing if the packet cannot be unwrapped with the private key of this node.
            return None

    def connect_mix(
        self,
        peer: Node,
        inbound_conn: SimplexConnection,
        outbound_conn: SimplexConnection,
    ):
        Node.__connect(self.nomssip, peer.nomssip, inbound_conn, outbound_conn)

    def connect_broadcast(
        self,
        peer: Node,
        inbound_conn: SimplexConnection,
        outbound_conn: SimplexConnection,
    ):
        Node.__connect(self.broadcast, peer.broadcast, inbound_conn, outbound_conn)

    @staticmethod
    def __connect(
        self_channel: Gossip,
        peer_channel: Gossip,
        inbound_conn: SimplexConnection,
        outbound_conn: SimplexConnection,
    ):
        """
        Establish a duplex connection with a peer node.
        """
        if not self_channel.can_accept_conn() or not peer_channel.can_accept_conn():
            raise PeeringDegreeReached()

        # Register a duplex connection for its own use
        self_channel.add_conn(inbound_conn, outbound_conn)
        # Register a duplex connection for the peer
        peer_channel.add_conn(outbound_conn, inbound_conn)

    async def send_message(self, msg: bytes):
        """
        Build a Sphinx packet and gossip it to all connected peers.
        """
        # Here, we handle the case in which a msg is split into multiple Sphinx packets.
        # But, in practice, we expect a message to be small enough to fit in a single Sphinx packet.
        sphinx_packet, _ = SphinxPacketBuilder.build(
            msg,
            self.global_config,
            self.config.mix_path_length,
        )
        await self.nomssip.gossip(sphinx_packet.bytes())
@@ -0,0 +1,106 @@
from __future__ import annotations

from dataclasses import dataclass
from enum import Enum
from typing import Awaitable, Callable, Self, override

from framework import Framework
from protocol.connection import (
    MixSimplexConnection,
    SimplexConnection,
)
from protocol.gossip import Gossip, GossipConfig
from protocol.temporalmix import TemporalMixConfig


@dataclass
class NomssipConfig(GossipConfig):
    transmission_rate_per_sec: int
    msg_size: int
    temporal_mix: TemporalMixConfig


class Nomssip(Gossip):
    """
    A NomMix gossip channel that extends the Gossip channel
    by adding a global transmission rate and noise generation.
    """

    def __init__(
        self,
        framework: Framework,
        config: NomssipConfig,
        handler: Callable[[bytes], Awaitable[None]],
    ):
        super().__init__(framework, config, handler)
        self.config = config

    @override
    def add_conn(self, inbound: SimplexConnection, outbound: SimplexConnection):
        noise_packet = FlaggedPacket(
            FlaggedPacket.Flag.NOISE, bytes(self.config.msg_size)
        ).bytes()
        super().add_conn(
            inbound,
            MixSimplexConnection(
                self.framework,
                outbound,
                self.config.transmission_rate_per_sec,
                noise_packet,
                self.config.temporal_mix,
            ),
        )

    @override
    async def process_inbound_msg(self, msg: bytes):
        packet = FlaggedPacket.from_bytes(msg)
        match packet.flag:
            case FlaggedPacket.Flag.NOISE:
                # Drop noise packets
                return
            case FlaggedPacket.Flag.REAL:
                await self.__gossip_flagged_packet(packet)
                await self.handler(packet.message)

    @override
    async def gossip(self, msg: bytes):
        """
        Gossip a message to all connected peers after prepending a message flag
        """
        # The message size must be fixed.
        assert len(msg) == self.config.msg_size

        packet = FlaggedPacket(FlaggedPacket.Flag.REAL, msg)
        await self.__gossip_flagged_packet(packet)

    async def __gossip_flagged_packet(self, packet: FlaggedPacket):
        """
        An internal method to send a flagged packet to all connected peers
        """
        await super().gossip(packet.bytes())


class FlaggedPacket:
    class Flag(Enum):
        REAL = b"\x00"
        NOISE = b"\x01"

    def __init__(self, flag: Flag, message: bytes):
        self.flag = flag
        self.message = message

    def bytes(self) -> bytes:
        return self.flag.value + self.message

    @classmethod
    def from_bytes(cls, packet: bytes) -> Self:
        """
        Parse a flagged packet from bytes
        """
        if len(packet) < 1:
            raise ValueError("Invalid message format")
        return cls(cls.Flag(packet[:1]), packet[1:])
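The wire format of `FlaggedPacket` is just a one-byte flag prepended to a fixed-size message. A standalone round-trip check of that framing (the `Flag` enum and the `encode`/`decode` helpers re-inline the class's logic so the snippet runs on its own):

```python
from enum import Enum


class Flag(Enum):
    REAL = b"\x00"
    NOISE = b"\x01"


def encode(flag: Flag, message: bytes) -> bytes:
    # Mirrors FlaggedPacket.bytes(): one flag byte, then the message
    return flag.value + message


def decode(packet: bytes) -> tuple[Flag, bytes]:
    # Mirrors FlaggedPacket.from_bytes(): split off the first byte as the flag
    if len(packet) < 1:
        raise ValueError("Invalid message format")
    return Flag(packet[:1]), packet[1:]


flag, msg = decode(encode(Flag.REAL, b"payload"))
print(flag, msg)  # Flag.REAL b'payload'
assert decode(encode(Flag.NOISE, b""))[0] is Flag.NOISE
```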
@@ -0,0 +1,33 @@
from __future__ import annotations

from typing import List, Tuple

from pysphinx.sphinx import SphinxPacket

from protocol.config import GlobalConfig, NodeInfo


class SphinxPacketBuilder:
    @staticmethod
    def build(
        message: bytes, global_config: GlobalConfig, path_len: int
    ) -> Tuple[SphinxPacket, List[NodeInfo]]:
        if path_len <= 0:
            raise ValueError("path_len must be greater than 0")
        if len(message) > global_config.max_message_size:
            raise ValueError("message is too long")

        route = global_config.membership.generate_route(path_len)
        # We don't need the destination (defined in the Loopix Sphinx spec)
        # because the last mix will broadcast the fully unwrapped message.
        # Later, we will optimize the Sphinx according to our requirements.
        dummy_destination = route[-1]

        packet = SphinxPacket.build(
            message,
            route=[mixnode.sphinx_node() for mixnode in route],
            destination=dummy_destination.sphinx_node(),
            max_route_length=global_config.max_mix_path_length,
            max_plain_payload_size=global_config.max_message_size,
        )
        return (packet, route)
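Route selection above is sampling with replacement from the membership list (`MixMembership.generate_route` calls `random.Random.choices`), seeded so runs are reproducible. A standalone sketch with toy node labels in place of `NodeInfo`:

```python
import random

nodes = [f"node-{i}" for i in range(6)]

# Same seed -> same route, which is what the per-config `mix_path.seed` buys us
rng_a = random.Random(3)
rng_b = random.Random(3)
route_a = rng_a.choices(nodes, k=5)  # with replacement: repeats are allowed
route_b = rng_b.choices(nodes, k=5)

print(route_a == route_b)  # True
print(len(route_a))  # 5
assert all(n in nodes for n in route_a)
```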
@ -0,0 +1,177 @@
import random
from abc import abstractmethod
from dataclasses import dataclass
from enum import Enum
from typing import TypeVar

from framework.framework import Framework, Queue


class TemporalMixType(Enum):
    NONE = "none"
    PURE_COIN_FLIPPING = "pure-coin-flipping"
    PURE_RANDOM_SAMPLING = "pure-random-sampling"
    PERMUTED_COIN_FLIPPING = "permuted-coin-flipping"
    NOISY_COIN_FLIPPING = "noisy-coin-flipping"


@dataclass
class TemporalMixConfig:
    mix_type: TemporalMixType
    # The minimum size of queue to be mixed.
    # If the queue size is less than this value, noise messages are added.
    min_queue_size: int
    # Generates the seeds used to create the RNG for each queue that will be created.
    seed_generator: random.Random

    def __post_init__(self):
        assert self.seed_generator is not None
        assert self.min_queue_size > 0


T = TypeVar("T")


class TemporalMix:
    @staticmethod
    def queue(
        config: TemporalMixConfig, framework: Framework, noise_msg: T
    ) -> Queue[T]:
        match config.mix_type:
            case TemporalMixType.NONE:
                return NonMixQueue(framework, noise_msg)
            case TemporalMixType.PURE_COIN_FLIPPING:
                return PureCoinFlipppingQueue(
                    config.min_queue_size,
                    random.Random(config.seed_generator.random()),
                    noise_msg,
                )
            case TemporalMixType.PURE_RANDOM_SAMPLING:
                return PureRandomSamplingQueue(
                    config.min_queue_size,
                    random.Random(config.seed_generator.random()),
                    noise_msg,
                )
            case TemporalMixType.PERMUTED_COIN_FLIPPING:
                return PermutedCoinFlipppingQueue(
                    config.min_queue_size,
                    random.Random(config.seed_generator.random()),
                    noise_msg,
                )
            case TemporalMixType.NOISY_COIN_FLIPPING:
                return NoisyCoinFlippingQueue(
                    random.Random(config.seed_generator.random()),
                    noise_msg,
                )
            case _:
                raise ValueError(f"Unknown mix type: {config.mix_type}")


class NonMixQueue(Queue[T]):
    """
    Queue without temporal mixing. It only generates noise when the queue is empty.
    """

    def __init__(self, framework: Framework, noise_msg: T):
        self.__queue = framework.queue()
        self.__noise_msg = noise_msg

    async def put(self, data: T) -> None:
        await self.__queue.put(data)

    async def get(self) -> T:
        if self.__queue.empty():
            return self.__noise_msg
        else:
            return await self.__queue.get()

    def empty(self) -> bool:
        return self.__queue.empty()


class MixQueue(Queue[T]):
    def __init__(self, rng: random.Random, noise_msg: T):
        super().__init__()
        # Assuming that simulations run in a single thread
        self._queue: list[T] = []
        self._rng = rng
        self._noise_msg = noise_msg

    async def put(self, data: T) -> None:
        self._queue.append(data)

    @abstractmethod
    async def get(self) -> T:
        pass

    def empty(self) -> bool:
        return len(self._queue) == 0


class MinSizeMixQueue(MixQueue[T]):
    def __init__(self, min_pool_size: int, rng: random.Random, noise_msg: T):
        super().__init__(rng, noise_msg)
        self._mix_pool_size = min_pool_size

    @abstractmethod
    async def get(self) -> T:
        while len(self._queue) < self._mix_pool_size:
            self._queue.append(self._noise_msg)

        # Subclass must implement the rest of this method
        pass


class PureCoinFlipppingQueue(MinSizeMixQueue[T]):
    async def get(self) -> T:
        await super().get()

        while True:
            for i in range(len(self._queue)):
                # coin-flipping
                if self._rng.randint(0, 1) == 1:
                    # After removing a message from the position `i`, we don't fill up the position.
                    # Instead, the queue is always filled from the back.
                    return self._queue.pop(i)


class PureRandomSamplingQueue(MinSizeMixQueue[T]):
    async def get(self) -> T:
        await super().get()

        i = self._rng.randint(0, len(self._queue) - 1)
        # After removing a message from the position `i`, we don't fill up the position.
        # Instead, the queue is always filled from the back.
        return self._queue.pop(i)


class PermutedCoinFlipppingQueue(MinSizeMixQueue[T]):
    async def get(self) -> T:
        await super().get()

        self._rng.shuffle(self._queue)

        while True:
            for i in range(len(self._queue)):
                # coin-flipping
                if self._rng.randint(0, 1) == 1:
                    # After removing a message from the position `i`, we don't fill up the position.
                    # Instead, the queue is always filled from the back.
                    return self._queue.pop(i)


class NoisyCoinFlippingQueue(MixQueue[T]):
    async def get(self) -> T:
        if len(self._queue) == 0:
            return self._noise_msg

        while True:
            for i in range(len(self._queue)):
                # coin-flipping
                if self._rng.randint(0, 1) == 1:
                    # After removing a message from the position `i`, we don't fill up the position.
                    # Instead, the queue is always filled from the back.
                    return self._queue.pop(i)
                else:
                    if i == 0:
                        return self._noise_msg
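The pure coin-flipping discipline above can be sketched standalone (a minimal illustration with a hypothetical helper name `coin_flipping_pop`, not the protocol API): the pool is padded with noise up to the minimum mix size, then positions are coin-flipped until one message is popped.

```python
import random


def coin_flipping_pop(pool: list, rng: random.Random, min_size: int, noise=-1):
    # Pad the pool with noise messages up to the minimum mix size,
    # then coin-flip over positions until one message is selected.
    while len(pool) < min_size:
        pool.append(noise)
    while True:
        for i in range(len(pool)):
            if rng.randint(0, 1) == 1:
                return pool.pop(i)


rng = random.Random(0)
pool = [10, 20]
drained = [coin_flipping_pop(pool, rng, min_size=4) for _ in range(4)]
# Each pop removes exactly one element, so no real message is emitted twice,
# and every real message is either emitted or still in the pool.
```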
@ -0,0 +1,58 @@
from unittest import IsolatedAsyncioTestCase

import framework.asyncio as asynciofw
from framework.framework import Queue
from protocol.connection import LocalSimplexConnection
from protocol.node import Node
from protocol.test_utils import (
    init_mixnet_config,
)


class TestNode(IsolatedAsyncioTestCase):
    async def test_node(self):
        framework = asynciofw.Framework()
        global_config, node_configs, _ = init_mixnet_config(10)

        queue: Queue[bytes] = framework.queue()

        async def broadcasted_msg_handler(msg: bytes) -> None:
            await queue.put(msg)

        nodes = [
            Node(framework, node_config, global_config, broadcasted_msg_handler)
            for node_config in node_configs
        ]
        for i, node in enumerate(nodes):
            try:
                node.connect_mix(
                    nodes[(i + 1) % len(nodes)],
                    LocalSimplexConnection(framework),
                    LocalSimplexConnection(framework),
                )
                node.connect_broadcast(
                    nodes[(i + 1) % len(nodes)],
                    LocalSimplexConnection(framework),
                    LocalSimplexConnection(framework),
                )
            except ValueError as e:
                print(e)

        await nodes[0].send_message(b"block selection")

        # Wait for all nodes to receive the broadcast
        num_nodes_received_broadcast = 0
        timeout = 15
        for _ in range(timeout):
            await framework.sleep(1)

            while not queue.empty():
                self.assertEqual(b"block selection", await queue.get())
                num_nodes_received_broadcast += 1

            if num_nodes_received_broadcast == len(nodes):
                break

        self.assertEqual(len(nodes), num_nodes_received_broadcast)

        # TODO: check noise
@ -0,0 +1,66 @@
from random import randint
from typing import cast
from unittest import TestCase

from pysphinx.sphinx import (
    ProcessedFinalHopPacket,
    ProcessedForwardHopPacket,
)

from protocol.sphinx import SphinxPacketBuilder
from protocol.test_utils import init_mixnet_config


class TestSphinxPacketBuilder(TestCase):
    def test_builder(self):
        global_config, _, key_map = init_mixnet_config(10)
        msg = self.random_bytes(500)
        packet, route = SphinxPacketBuilder.build(msg, global_config, 3)
        self.assertEqual(3, len(route))

        processed = packet.process(key_map[route[0].public_key.public_bytes_raw()])
        self.assertIsInstance(processed, ProcessedForwardHopPacket)
        processed = cast(ProcessedForwardHopPacket, processed).next_packet.process(
            key_map[route[1].public_key.public_bytes_raw()]
        )
        self.assertIsInstance(processed, ProcessedForwardHopPacket)
        processed = cast(ProcessedForwardHopPacket, processed).next_packet.process(
            key_map[route[2].public_key.public_bytes_raw()]
        )
        self.assertIsInstance(processed, ProcessedFinalHopPacket)
        recovered = cast(
            ProcessedFinalHopPacket, processed
        ).payload.recover_plain_playload()
        self.assertEqual(msg, recovered)

    def test_max_message_size(self):
        global_config, _, _ = init_mixnet_config(10, max_message_size=2000)
        mix_path_length = global_config.max_mix_path_length

        packet1, _ = SphinxPacketBuilder.build(
            self.random_bytes(1500), global_config, mix_path_length
        )
        packet2, _ = SphinxPacketBuilder.build(
            self.random_bytes(2000), global_config, mix_path_length
        )
        self.assertEqual(len(packet1.bytes()), len(packet2.bytes()))

        msg = self.random_bytes(2001)
        with self.assertRaises(ValueError):
            _ = SphinxPacketBuilder.build(msg, global_config, mix_path_length)

    def test_max_mix_path_length(self):
        global_config, _, _ = init_mixnet_config(10, max_mix_path_length=2)
        msg = self.random_bytes(global_config.max_message_size)

        packet1, _ = SphinxPacketBuilder.build(msg, global_config, 1)
        packet2, _ = SphinxPacketBuilder.build(msg, global_config, 2)
        self.assertEqual(len(packet1.bytes()), len(packet2.bytes()))

        with self.assertRaises(ValueError):
            _ = SphinxPacketBuilder.build(msg, global_config, 3)

    @staticmethod
    def random_bytes(size: int) -> bytes:
        assert size >= 0
        return bytes([randint(0, 255) for _ in range(size)])
@ -0,0 +1,103 @@
import random
from unittest import IsolatedAsyncioTestCase

import framework.asyncio as asynciofw
from framework.framework import Queue
from protocol.temporalmix import (
    NoisyCoinFlippingQueue,
    NonMixQueue,
    PermutedCoinFlipppingQueue,
    PureCoinFlipppingQueue,
    PureRandomSamplingQueue,
    TemporalMix,
    TemporalMixConfig,
    TemporalMixType,
)


class TestTemporalMix(IsolatedAsyncioTestCase):
    async def test_queue_builder(self):
        # Check if the queue builder generates the correct queue type
        for mix_type in TemporalMixType:
            await self.__test_queue_builder(mix_type)

    async def __test_queue_builder(self, mix_type: TemporalMixType):
        queue: Queue[int] = TemporalMix.queue(
            TemporalMixConfig(mix_type, 4, random.Random(0)),
            asynciofw.Framework(),
            -1,
        )
        match mix_type:
            case TemporalMixType.NONE:
                self.assertIsInstance(queue, NonMixQueue)
            case TemporalMixType.PURE_COIN_FLIPPING:
                self.assertIsInstance(queue, PureCoinFlipppingQueue)
            case TemporalMixType.PURE_RANDOM_SAMPLING:
                self.assertIsInstance(queue, PureRandomSamplingQueue)
            case TemporalMixType.PERMUTED_COIN_FLIPPING:
                self.assertIsInstance(queue, PermutedCoinFlipppingQueue)
            case TemporalMixType.NOISY_COIN_FLIPPING:
                self.assertIsInstance(queue, NoisyCoinFlippingQueue)
            case _:
                self.fail(f"Unknown mix type: {mix_type}")

    async def test_non_mix_queue(self):
        queue: Queue[int] = TemporalMix.queue(
            TemporalMixConfig(TemporalMixType.NONE, 4, random.Random(0)),
            asynciofw.Framework(),
            -1,
        )

        # Check if queue is FIFO
        await queue.put(0)
        await queue.put(1)
        self.assertEqual(0, await queue.get())
        self.assertEqual(1, await queue.get())

        # Check if noise is generated when queue is empty
        self.assertEqual(-1, await queue.get())

        # FIFO again
        await queue.put(2)
        self.assertEqual(2, await queue.get())
        await queue.put(3)
        self.assertEqual(3, await queue.get())

    async def test_pure_coin_flipping_queue(self):
        await self.__test_mix_queue(TemporalMixType.PURE_COIN_FLIPPING)

    async def test_pure_random_sampling(self):
        await self.__test_mix_queue(TemporalMixType.PURE_RANDOM_SAMPLING)

    async def test_permuted_coin_flipping_queue(self):
        await self.__test_mix_queue(TemporalMixType.PERMUTED_COIN_FLIPPING)

    async def test_noisy_coin_flipping_queue(self):
        await self.__test_mix_queue(TemporalMixType.NOISY_COIN_FLIPPING)

    async def __test_mix_queue(self, mix_type: TemporalMixType):
        queue: Queue[int] = TemporalMix.queue(
            TemporalMixConfig(mix_type, 4, random.Random(0)),
            asynciofw.Framework(),
            -1,
        )

        # Check if noise is generated when queue is empty
        self.assertEqual(-1, await queue.get())

        # Put only 2 elements even though the min queue size is 4
        await queue.put(0)
        await queue.put(1)

        # Wait until 2 elements are returned from the queue
        waiting = {0, 1}
        while len(waiting) > 0:
            e = await queue.get()
            if e in waiting:
                waiting.remove(e)
            else:
                # Check if it's the noise
                self.assertEqual(-1, e)

        # Check if noise is generated when there is no real message inserted
        self.assertEqual(-1, await queue.get())
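The noisy coin-flipping behavior exercised by the tests above can also be sketched standalone (a minimal illustration with the hypothetical helper name `noisy_coin_flipping_pop`, not the protocol API): a tails flip on the head of the pool yields noise immediately instead of padding the pool.

```python
import random


def noisy_coin_flipping_pop(pool: list, rng: random.Random, noise=-1):
    # An empty pool always yields noise.
    if not pool:
        return noise
    while True:
        for i in range(len(pool)):
            if rng.randint(0, 1) == 1:
                return pool.pop(i)
            elif i == 0:
                # Tails on the head of the pool yields noise instead of waiting.
                return noise


rng = random.Random(0)
pool = [10, 20]
outs = []
while pool:
    outs.append(noisy_coin_flipping_pop(pool, rng))
# Noise may be interleaved, but each real message is emitted exactly once.
```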
@ -0,0 +1,46 @@
import random

from cryptography.hazmat.primitives.asymmetric.x25519 import X25519PrivateKey

from protocol.config import (
    GlobalConfig,
    MixMembership,
    NodeConfig,
    NodeInfo,
)
from protocol.gossip import GossipConfig
from protocol.nomssip import TemporalMixConfig
from protocol.temporalmix import TemporalMixType


def init_mixnet_config(
    num_nodes: int,
    max_message_size: int = 512,
    max_mix_path_length: int = 3,
) -> tuple[GlobalConfig, list[NodeConfig], dict[bytes, X25519PrivateKey]]:
    gossip_config = GossipConfig(peering_degree=6)
    node_configs = [
        NodeConfig(
            X25519PrivateKey.generate(),
            max_mix_path_length,
            gossip_config,
            TemporalMixConfig(TemporalMixType.PURE_COIN_FLIPPING, 3, random.Random()),
        )
        for _ in range(num_nodes)
    ]
    global_config = GlobalConfig(
        MixMembership(
            [
                NodeInfo(node_config.private_key.public_key())
                for node_config in node_configs
            ]
        ),
        transmission_rate_per_sec=3,
        max_message_size=max_message_size,
        max_mix_path_length=max_mix_path_length,
    )
    key_map = {
        node_config.private_key.public_key().public_bytes_raw(): node_config.private_key
        for node_config in node_configs
    }
    return (global_config, node_configs, key_map)
@ -0,0 +1,6 @@
usim==0.4.4
pysphinx==0.0.5
dacite==1.8.1
pandas==2.2.2
matplotlib==3.9.1
PyYAML==6.0.1
@ -0,0 +1,163 @@
from __future__ import annotations

import hashlib
import random
from dataclasses import dataclass

import dacite
import yaml
from pysphinx.sphinx import X25519PrivateKey

from protocol.config import NodeConfig
from protocol.gossip import GossipConfig
from protocol.temporalmix import TemporalMixConfig, TemporalMixType


@dataclass
class Config:
    simulation: SimulationConfig
    network: NetworkConfig
    logic: LogicConfig
    mix: MixConfig

    @classmethod
    def load(cls, yaml_path: str) -> Config:
        with open(yaml_path, "r") as f:
            data = yaml.safe_load(f)
        return dacite.from_dict(
            data_class=Config,
            data=data,
            config=dacite.Config(
                type_hooks={
                    random.Random: seed_to_random,
                    TemporalMixType: str_to_temporal_mix_type,
                },
                strict=True,
            ),
        )

    def node_configs(self) -> list[NodeConfig]:
        return [
            NodeConfig(
                self.__gen_private_key(i),
                self.mix.mix_path.random_length(),
                self.network.gossip,
                self.mix.temporal_mix,
            )
            for i in range(self.network.num_nodes)
        ]

    def __gen_private_key(self, node_idx: int) -> X25519PrivateKey:
        return X25519PrivateKey.from_private_bytes(
            hashlib.sha256(node_idx.to_bytes(4, "big")).digest()[:32]
        )


@dataclass
class SimulationConfig:
    # Desired duration of the simulation in seconds
    # Since the simulation uses discrete time steps, the actual duration may be longer or shorter.
    duration_sec: int
    # Show all plots that have been drawn during the simulation
    show_plots: bool

    def __post_init__(self):
        assert self.duration_sec > 0


@dataclass
class NetworkConfig:
    # Total number of nodes in the entire network.
    num_nodes: int
    latency: LatencyConfig
    gossip: GossipConfig
    topology: TopologyConfig

    def __post_init__(self):
        assert self.num_nodes > 0


@dataclass
class LatencyConfig:
    # Minimum/maximum network latency between nodes in seconds.
    # A constant latency will be chosen randomly for each connection within the range [min_latency_sec, max_latency_sec].
    min_latency_sec: float
    max_latency_sec: float
    # Seed for the random number generator used to determine the network latencies.
    seed: random.Random

    def __post_init__(self):
        assert 0 <= self.min_latency_sec <= self.max_latency_sec
        assert self.seed is not None

    def random_latency(self) -> float:
        # round to milliseconds to make analysis not too heavy
        return round(self.seed.uniform(self.min_latency_sec, self.max_latency_sec), 3)


@dataclass
class TopologyConfig:
    # Seed for the random number generator used to determine the network topology.
    seed: random.Random

    def __post_init__(self):
        assert self.seed is not None


@dataclass
class MixConfig:
    # Global constant transmission rate of each connection in messages per second.
    transmission_rate_per_sec: int
    # Maximum size of a message in bytes that can be encapsulated in a single Sphinx packet.
    max_message_size: int
    mix_path: MixPathConfig
    temporal_mix: TemporalMixConfig

    def __post_init__(self):
        assert self.transmission_rate_per_sec > 0
        assert self.max_message_size > 0


@dataclass
class MixPathConfig:
    # Minimum number of mix nodes to be chosen for a Sphinx packet.
    min_length: int
    # Maximum number of mix nodes to be chosen for a Sphinx packet.
    max_length: int
    # Seed for the random number generator used to determine the mix path.
    seed: random.Random

    def __post_init__(self):
        assert 0 < self.min_length <= self.max_length
        assert self.seed is not None

    def random_length(self) -> int:
        return self.seed.randint(self.min_length, self.max_length)


@dataclass
class LogicConfig:
    sender_lottery: LotteryConfig


@dataclass
class LotteryConfig:
    # Interval between lottery draws in seconds.
    interval_sec: float
    # Probability of a node being selected as a sender in each lottery draw.
    probability: float
    # Seed for the random number generator used to determine the lottery winners.
    seed: random.Random

    def __post_init__(self):
        assert self.interval_sec > 0
        assert self.probability >= 0
        assert self.seed is not None


def seed_to_random(seed: int) -> random.Random:
    return random.Random(seed)


def str_to_temporal_mix_type(val: str) -> TemporalMixType:
    return TemporalMixType(val)
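The per-connection latency sampling in `LatencyConfig.random_latency` can be sketched as a free function (a minimal standalone illustration, not the config API): a constant latency is drawn uniformly from the configured range and rounded to millisecond resolution.

```python
import random


def random_latency(rng: random.Random, lo: float, hi: float) -> float:
    # Pick a constant per-connection latency in [lo, hi] and
    # round it to milliseconds to keep the analysis light.
    return round(rng.uniform(lo, hi), 3)


latency = random_latency(random.Random(0), 0.1, 0.5)
# Always within bounds, with at most millisecond resolution.
```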
@ -0,0 +1,139 @@
import math
from collections import Counter
from typing import Awaitable

import pandas
from typing_extensions import override

from framework import Framework, Queue
from protocol.connection import SimplexConnection
from sim.config import LatencyConfig, NetworkConfig
from sim.state import NodeState


class MeteredRemoteSimplexConnection(SimplexConnection):
    """
    A simplex connection implementation that simulates network latency and measures bandwidth usage.
    """

    def __init__(
        self,
        config: LatencyConfig,
        framework: Framework,
        meter_start_time: float,
    ):
        self.framework = framework
        # A connection has a random constant latency
        self.latency = config.random_latency()
        # A queue of tuple(timestamp, msg) where a sender puts messages to be sent
        self.send_queue: Queue[tuple[float, bytes]] = framework.queue()
        # A task that reads messages from send_queue and puts them into recv_queue.
        # Before putting messages into recv_queue, the task simulates network latency according to the timestamp of each message.
        self.relayer = framework.spawn(self.__run_relayer())
        # A queue where a receiver gets messages
        self.recv_queue: Queue[bytes] = framework.queue()
        # To measure bandwidth usage
        self.meter_start_time = meter_start_time
        self.send_meters: list[int] = []
        self.recv_meters: list[int] = []

    async def send(self, data: bytes) -> None:
        await self.send_queue.put((self.framework.now(), data))
        self.on_sending(data)

    async def recv(self) -> bytes:
        return await self.recv_queue.get()

    async def __run_relayer(self):
        """
        A task that reads messages from send_queue and puts them into recv_queue.
        Before putting messages into recv_queue, the task simulates network latency according to the timestamp of each message.
        """
        while True:
            sent_time, data = await self.send_queue.get()
            # Simulate network latency
            delay = self.latency - (self.framework.now() - sent_time)
            if delay > 0:
                await self.framework.sleep(delay)

            # Relay msg to the recv_queue.
            # Update related statistics before msg is read from recv_queue by the receiver,
            # because the time at which a message enters the node is what matters when viewed from the outside.
            self.on_receiving(data)
            await self.recv_queue.put(data)

    def on_sending(self, data: bytes) -> None:
        """
        Update statistics when sending a message
        """
        self.__update_meter(self.send_meters, len(data))

    def on_receiving(self, data: bytes) -> None:
        """
        Update statistics when receiving a message
        """
        self.__update_meter(self.recv_meters, len(data))

    def __update_meter(self, meters: list[int], size: int):
        """
        Accumulates the bandwidth usage in the current time slot (seconds).
        """
        slot = math.floor(self.framework.now() - self.meter_start_time)
        assert slot >= len(meters) - 1
        # Fill zeros for the empty time slots
        meters.extend([0] * (slot - len(meters) + 1))
        meters[-1] += size

    def sending_bandwidths(self) -> pandas.Series:
        """
        Returns the accumulated sending bandwidth usage over time
        """
        return self.__bandwidths(self.send_meters)

    def receiving_bandwidths(self) -> pandas.Series:
        """
        Returns the accumulated receiving bandwidth usage over time
        """
        return self.__bandwidths(self.recv_meters)

    def __bandwidths(self, meters: list[int]) -> pandas.Series:
        return pandas.Series(meters, name="bandwidth")


class ObservedMeteredRemoteSimplexConnection(MeteredRemoteSimplexConnection):
    """
    An extension of MeteredRemoteSimplexConnection that is watched by a passive observer.
    The observer monitors the node states of the sender and receiver as well as message sizes.
    """

    def __init__(
        self,
        config: LatencyConfig,
        framework: Framework,
        meter_start_time: float,
        send_node_states: list[NodeState],
        recv_node_states: list[NodeState],
    ):
        super().__init__(config, framework, meter_start_time)

        # To measure node states over time
        self.send_node_states = send_node_states
        self.recv_node_states = recv_node_states
        # To measure the size of messages sent via this connection
        self.msg_sizes: Counter[int] = Counter()

    @override
    def on_sending(self, data: bytes) -> None:
        super().on_sending(data)
        self.__update_node_state(self.send_node_states, NodeState.SENDING)
        self.msg_sizes.update([len(data)])

    @override
    def on_receiving(self, data: bytes) -> None:
        super().on_receiving(data)
        self.__update_node_state(self.recv_node_states, NodeState.RECEIVING)

    def __update_node_state(self, node_states: list[NodeState], state: NodeState):
        # The time unit of node states is milliseconds
        ms = math.floor(self.framework.now() * 1000)
        node_states[ms] = state
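The per-second metering logic of `__update_meter` can be sketched as a free function (a minimal standalone illustration with a hypothetical name `update_meter`): bytes are accumulated into the slot for the current second, and any skipped seconds are back-filled with zeros.

```python
import math


def update_meter(meters: list, size: int, now: float, start: float) -> None:
    # Accumulate `size` bytes into the per-second slot for time `now`,
    # filling zeros for any empty slots in between.
    slot = math.floor(now - start)
    meters.extend([0] * (slot - len(meters) + 1))
    meters[-1] += size


m = []
update_meter(m, 100, now=0.4, start=0.0)  # slot 0
update_meter(m, 50, now=0.6, start=0.0)   # still slot 0
update_meter(m, 70, now=3.2, start=0.0)   # slot 3; slots 1-2 are zero-filled
# m == [150, 0, 0, 70]
```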
@ -0,0 +1,42 @@
import pickle
from dataclasses import dataclass
from typing import Self


@dataclass
class Message:
    """
    A message structure for simulation, which will be sent through mix nodes
    and eventually broadcasted to all nodes in the network.

    The `id` must ensure the uniqueness of the message.
    """

    created_at: float
    id: int
    body: bytes

    def __bytes__(self):
        return pickle.dumps(self)

    @classmethod
    def from_bytes(cls, data: bytes) -> Self:
        return pickle.loads(data)

    def __hash__(self) -> int:
        return self.id


class UniqueMessageBuilder:
    """
    Builds a unique message with an incremental ID,
    assuming that the simulation is run in a single thread.
    """

    def __init__(self):
        self.next_id = 0

    def next(self, created_at: float, body: bytes) -> Message:
        msg = Message(created_at, self.next_id, body)
        self.next_id += 1
        return msg
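The pickle-based serialization of the simulation `Message` is a lossless round trip, which the following self-contained sketch demonstrates (the class is reproduced here only for illustration; the simulation uses the one in `sim/message.py`).

```python
import pickle
from dataclasses import dataclass


@dataclass
class Message:
    created_at: float
    id: int
    body: bytes

    def __bytes__(self) -> bytes:
        # Serialize the whole dataclass with pickle.
        return pickle.dumps(self)

    @classmethod
    def from_bytes(cls, data: bytes) -> "Message":
        return pickle.loads(data)


msg = Message(created_at=1.5, id=7, body=b"selected block")
restored = Message.from_bytes(bytes(msg))
# restored == msg: serialization round-trips all fields.
```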
@ -0,0 +1,198 @@
from dataclasses import asdict, dataclass
from pprint import pprint
from typing import Self

import usim
from matplotlib import pyplot

import framework.usim as usimfw
from framework import Framework
from protocol.config import GlobalConfig, MixMembership, NodeInfo
from protocol.node import Node, PeeringDegreeReached
from sim.config import Config
from sim.connection import (
    MeteredRemoteSimplexConnection,
    ObservedMeteredRemoteSimplexConnection,
)
from sim.message import Message, UniqueMessageBuilder
from sim.state import NodeState, NodeStateTable
from sim.stats import ConnectionStats, DisseminationTime
from sim.topology import build_full_random_topology


class Simulation:
    """
    Manages the entire cycle of simulation: initialization, running, and analysis.
    """

    def __init__(self, config: Config):
        self.config = config
        self.msg_builder = UniqueMessageBuilder()
        self.dissemination_time = DisseminationTime(self.config.network.num_nodes)

    async def run(self):
        # Run the simulation
        conn_stats, node_state_table = await self.__run()
        # Analyze the dissemination times
        self.dissemination_time.analyze()
        # Analyze the simulation results
        conn_stats.analyze()
        node_state_table.analyze()
        # Show plots
        if self.config.simulation.show_plots:
            pyplot.show()

    async def __run(self) -> tuple[ConnectionStats, NodeStateTable]:
        # Initialize analysis tools
        node_state_table = NodeStateTable(
            self.config.network.num_nodes, self.config.simulation.duration_sec
        )
        conn_stats = ConnectionStats()

        # Create a μSim scope and run the simulation
        async with usim.until(usim.time + self.config.simulation.duration_sec) as scope:
            self.framework = usimfw.Framework(scope)
            nodes = self.__init_nodes()
            self.__connect_nodes(nodes, node_state_table, conn_stats)
            for i, node in enumerate(nodes):
                print(f"Spawning node-{i} with {len(node.nomssip.conns)} conns")
                self.framework.spawn(self.__run_node_logic(node))

        # Return analysis tools once the μSim scope is done
        return conn_stats, node_state_table

    def __init_nodes(self) -> list[Node]:
        # Initialize node/global configurations
        node_configs = self.config.node_configs()
        global_config = GlobalConfig(
            MixMembership(
                [
                    NodeInfo(node_config.private_key.public_key())
                    for node_config in node_configs
                ],
                self.config.mix.mix_path.seed,
            ),
            self.config.mix.transmission_rate_per_sec,
            self.config.mix.max_message_size,
            self.config.mix.mix_path.max_length,
        )

        # Initialize/return Node instances
        return [
            Node(
                self.framework,
                node_config,
                global_config,
                self.__process_broadcasted_msg,
                self.__process_recovered_msg,
            )
            for node_config in node_configs
        ]

    def __connect_nodes(
        self,
        nodes: list[Node],
        node_state_table: NodeStateTable,
        conn_stats: ConnectionStats,
    ):
        topology = build_full_random_topology(
            self.config.network.topology.seed,
            len(nodes),
            self.config.network.gossip.peering_degree,
        )
        print("Topology:")
        pprint(topology)

        meter_start_time = self.framework.now()
        # Sort the topology by node index for the connection RULE defined below.
        for node_idx, peer_indices in sorted(topology.items()):
            for peer_idx in peer_indices:
                # Since the topology is undirected, we only need to connect the two nodes once.
                # RULE: the node with the smaller index establishes the connection.
                assert node_idx != peer_idx
                if node_idx > peer_idx:
                    continue

                node = nodes[node_idx]
                peer = nodes[peer_idx]
                node_states = node_state_table[node_idx]
                peer_states = node_state_table[peer_idx]

                # Connect the node and peer for Nomos Gossip
                inbound_conn, outbound_conn = (
                    self.__create_observed_conn(
                        meter_start_time, peer_states, node_states
                    ),
                    self.__create_observed_conn(
                        meter_start_time, node_states, peer_states
                    ),
                )
                node.connect_mix(peer, inbound_conn, outbound_conn)
                # Register the connections to the connection statistics
                conn_stats.register(node, inbound_conn, outbound_conn)
                conn_stats.register(peer, outbound_conn, inbound_conn)

                # Connect the node and peer for broadcasting.
                node.connect_broadcast(
                    peer,
                    self.__create_conn(meter_start_time),
                    self.__create_conn(meter_start_time),
                )

    def __create_observed_conn(
        self,
        meter_start_time: float,
        sender_states: list[NodeState],
        receiver_states: list[NodeState],
    ) -> ObservedMeteredRemoteSimplexConnection:
        return ObservedMeteredRemoteSimplexConnection(
            self.config.network.latency,
            self.framework,
            meter_start_time,
            sender_states,
            receiver_states,
        )

    def __create_conn(
        self,
        meter_start_time: float,
    ) -> MeteredRemoteSimplexConnection:
        return MeteredRemoteSimplexConnection(
            self.config.network.latency,
            self.framework,
            meter_start_time,
        )

    async def __run_node_logic(self, node: Node):
        """
        Runs the lottery periodically to check if the node is selected to send a block.
        If the node is selected, creates a block and sends it through mix nodes.
        """
        lottery_config = self.config.logic.sender_lottery
        while True:
            await self.framework.sleep(lottery_config.interval_sec)
            if lottery_config.seed.random() < lottery_config.probability:
                msg = self.msg_builder.next(self.framework.now(), b"selected block")
                await node.send_message(bytes(msg))

    async def __process_broadcasted_msg(self, msg: bytes):
        """
        Process a broadcasted message originating from the last mix.
        """
        message = Message.from_bytes(msg)
        elapsed = self.framework.now() - message.created_at
        self.dissemination_time.add_broadcasted_msg(message, elapsed)

    async def __process_recovered_msg(self, msg: bytes) -> bytes:
        """
        Process a message fully recovered by the last mix
        and return a new message to be broadcasted.
        """
        message = Message.from_bytes(msg)
        elapsed = self.framework.now() - message.created_at
        self.dissemination_time.add_mix_propagation_time(elapsed)

        # Update the timestamp and return the message to be broadcasted,
|
||||
# so that the broadcast dissemination time can be calculated from now.
|
||||
message.created_at = self.framework.now()
|
||||
return bytes(message)
|
|
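The sender lottery in `__run_node_logic` above draws a random number each interval and sends a block when it falls below the configured probability, so the long-run sending rate converges to that probability. A minimal self-contained sketch of the same Bernoulli-per-interval scheme (the constants and names here are illustrative, not the simulation's config):

```python
import random

# Illustrative stand-ins for lottery_config.probability and its seeded RNG
probability = 0.2
rng = random.Random(7)

# Over n lottery rounds, the expected number of sends is roughly n * probability.
n = 10_000
wins = sum(1 for _ in range(n) if rng.random() < probability)
print(wins / n)  # close to 0.2
```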
@ -0,0 +1,77 @@
from datetime import datetime
from enum import Enum

import matplotlib.pyplot as plt
import pandas


class NodeState(Enum):
    """
    A state of a node at a certain time.
    For now, we assume that a node cannot send and receive messages at the same time, for simplicity.
    """

    SENDING = -1
    IDLE = 0
    RECEIVING = 1


class NodeStateTable:
    def __init__(self, num_nodes: int, duration_sec: int):
        # Create a table to store the state of each node at each millisecond
        self.__table = [
            [NodeState.IDLE] * (duration_sec * 1000) for _ in range(num_nodes)
        ]

    def __getitem__(self, idx: int) -> list[NodeState]:
        return self.__table[idx]

    def analyze(self):
        df = pandas.DataFrame(self.__table).transpose()
        df.columns = [f"Node-{i}" for i in range(len(self.__table))]
        # Convert NodeState enums to their integer values
        df = df.map(lambda state: state.value)
        print("==========================================")
        print("Node States of All Nodes over Time")
        print(", ".join(f"{state.name}:{state.value}" for state in NodeState))
        print("==========================================")
        print(f"{df}\n")

        csv_path = f"all_node_states_{datetime.now().isoformat(timespec='seconds')}.csv"
        df.to_csv(csv_path)
        print(f"Saved DataFrame to {csv_path}\n")

        # Count/print the number of occurrences of each state for each node,
        # because the df is usually too big to print.
        state_counts = df.apply(pandas.Series.value_counts).fillna(0)
        print("State Counts per Node:")
        print(f"{state_counts}\n")

        # Draw a dot plot
        plt.figure(figsize=(15, 8))
        for node in df.columns:
            times = df.index
            states = df[node]
            sending_times = times[states == NodeState.SENDING.value]
            receiving_times = times[states == NodeState.RECEIVING.value]
            plt.scatter(
                sending_times,
                [node] * len(sending_times),
                color="red",
                marker="o",
                s=10,
                label="SENDING" if node == df.columns[0] else "",
            )
            plt.scatter(
                receiving_times,
                [node] * len(receiving_times),
                color="blue",
                marker="x",
                s=10,
                label="RECEIVING" if node == df.columns[0] else "",
            )
        plt.xlabel("Time")
        plt.ylabel("Node")
        plt.title("Node States Over Time")
        plt.legend(loc="upper right")
        plt.draw()
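`NodeStateTable` above records one state slot per node per millisecond, and observers map a simulation timestamp in seconds to a slot via `floor((now - start) * 1000)`. A minimal self-contained sketch of that slotting scheme (the `make_table`/`mark` helpers are illustrative, not the repo's API):

```python
import math
from enum import Enum


class NodeState(Enum):
    SENDING = -1
    IDLE = 0
    RECEIVING = 1


def make_table(num_nodes: int, duration_sec: int) -> list[list[NodeState]]:
    # One IDLE slot per node per millisecond, as in NodeStateTable above.
    return [[NodeState.IDLE] * (duration_sec * 1000) for _ in range(num_nodes)]


def mark(table, node_idx: int, now: float, start: float, state: NodeState) -> None:
    # Map a simulation timestamp (seconds) to its millisecond slot.
    slot = math.floor((now - start) * 1000)
    table[node_idx][slot] = state


table = make_table(num_nodes=2, duration_sec=1)
mark(table, 0, now=0.5, start=0.0, state=NodeState.SENDING)
print(table[0][500].name)  # SENDING
```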
@ -0,0 +1,152 @@
from collections import Counter, defaultdict

import matplotlib.pyplot as plt
import numpy
import pandas
from matplotlib.axes import Axes

from protocol.node import Node
from sim.connection import ObservedMeteredRemoteSimplexConnection
from sim.message import Message

# A map of nodes to their inbound/outbound connections
NodeConnectionsMap = dict[
    Node,
    tuple[
        list[ObservedMeteredRemoteSimplexConnection],
        list[ObservedMeteredRemoteSimplexConnection],
    ],
]


class ConnectionStats:
    def __init__(self):
        self.conns_per_node: NodeConnectionsMap = defaultdict(lambda: ([], []))

    def register(
        self,
        node: Node,
        inbound_conn: ObservedMeteredRemoteSimplexConnection,
        outbound_conn: ObservedMeteredRemoteSimplexConnection,
    ):
        self.conns_per_node[node][0].append(inbound_conn)
        self.conns_per_node[node][1].append(outbound_conn)

    def analyze(self):
        self.__message_sizes()
        self.__bandwidths_per_conn()
        self.__bandwidths_per_node()

    def __message_sizes(self):
        """
        Analyzes the sizes of all messages sent across all connections of all nodes.
        """
        sizes: Counter[int] = Counter()
        for _, (_, outbound_conns) in self.conns_per_node.items():
            for conn in outbound_conns:
                sizes.update(conn.msg_sizes)

        df = pandas.DataFrame.from_dict(sizes, orient="index").reset_index()
        df.columns = ["msg_size", "count"]
        print("==========================================")
        print("Message Size Distribution")
        print("==========================================")
        print(f"{df}\n")

    def __bandwidths_per_conn(self):
        """
        Analyzes the bandwidth consumed by each simplex connection.
        """
        plt.figure(figsize=(12, 6))

        for _, (_, outbound_conns) in self.conns_per_node.items():
            for conn in outbound_conns:
                sending_bandwidths = conn.sending_bandwidths().map(lambda x: x / 1024)
                plt.plot(sending_bandwidths.index, sending_bandwidths)

        plt.title("Unidirectional Bandwidths per Connection")
        plt.xlabel("Time (s)")
        plt.ylabel("Bandwidth (KiB/s)")
        plt.ylim(bottom=0)
        plt.grid(True)
        plt.tight_layout()
        plt.draw()

    def __bandwidths_per_node(self):
        """
        Analyzes the inbound/outbound bandwidths consumed by each node
        (the sum across all of its connections).
        """
        _, axs = plt.subplots(nrows=2, ncols=1, figsize=(12, 6))
        assert isinstance(axs, numpy.ndarray)

        for i, (_, (inbound_conns, outbound_conns)) in enumerate(
            self.conns_per_node.items()
        ):
            inbound_bandwidths = (
                pandas.concat(
                    [conn.receiving_bandwidths() for conn in inbound_conns], axis=1
                )
                .sum(axis=1)
                .map(lambda x: x / 1024)
            )
            outbound_bandwidths = (
                pandas.concat(
                    [conn.sending_bandwidths() for conn in outbound_conns], axis=1
                )
                .sum(axis=1)
                .map(lambda x: x / 1024)
            )
            axs[0].plot(inbound_bandwidths.index, inbound_bandwidths, label=f"Node-{i}")
            axs[1].plot(
                outbound_bandwidths.index, outbound_bandwidths, label=f"Node-{i}"
            )

        axs[0].set_title("Inbound Bandwidths per Node")
        axs[0].set_xlabel("Time (s)")
        axs[0].set_ylabel("Bandwidth (KiB/s)")
        axs[0].legend()
        axs[0].set_ylim(bottom=0)
        axs[0].grid(True)

        axs[1].set_title("Outbound Bandwidths per Node")
        axs[1].set_xlabel("Time (s)")
        axs[1].set_ylabel("Bandwidth (KiB/s)")
        axs[1].legend()
        axs[1].set_ylim(bottom=0)
        axs[1].grid(True)

        plt.tight_layout()
        plt.draw()


class DisseminationTime:
    def __init__(self, num_nodes: int):
        # Times taken for a message to propagate through all mix nodes in its mix route
        self.mix_propagation_times: list[float] = []
        # Times taken for a message to be broadcast from the last mix to all nodes in the network
        self.broadcast_dissemination_times: list[float] = []
        # Data structures to check whether a message has been broadcast to all nodes
        self.broadcast_status: Counter[Message] = Counter()
        self.num_nodes: int = num_nodes

    def add_mix_propagation_time(self, elapsed: float):
        self.mix_propagation_times.append(elapsed)

    def add_broadcasted_msg(self, msg: Message, elapsed: float):
        assert self.broadcast_status[msg] < self.num_nodes
        self.broadcast_status.update([msg])
        if self.broadcast_status[msg] == self.num_nodes:
            self.broadcast_dissemination_times.append(elapsed)

    def analyze(self):
        print("==========================================")
        print("Message Dissemination Time")
        print("==========================================")
        print("[Mix Propagation Times]")
        mix_propagation_times = pandas.Series(self.mix_propagation_times)
        print(mix_propagation_times.describe())
        print("")
        print("[Broadcast Dissemination Times]")
        broadcast_travel_times = pandas.Series(self.broadcast_dissemination_times)
        print(broadcast_travel_times.describe())
        print("")
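`DisseminationTime.add_broadcasted_msg` above counts receipts per message in a `Counter` and records the elapsed time only when the count reaches the number of nodes, i.e. when the last node has received the broadcast. A minimal self-contained sketch of that counting scheme (message IDs and constants here are illustrative):

```python
from collections import Counter

num_nodes = 3
broadcast_status: Counter[str] = Counter()
dissemination_times: list[float] = []


def add_broadcasted_msg(msg_id: str, elapsed: float) -> None:
    # Count one receipt; once every node has seen the message,
    # record the elapsed time as the full broadcast dissemination time.
    assert broadcast_status[msg_id] < num_nodes
    broadcast_status.update([msg_id])
    if broadcast_status[msg_id] == num_nodes:
        dissemination_times.append(elapsed)


# Three nodes report receipt of the same message at increasing times;
# only the last (slowest) receipt is recorded.
for elapsed in (0.1, 0.4, 0.9):
    add_broadcasted_msg("msg-1", elapsed)
print(dissemination_times)  # [0.9]
```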
@ -0,0 +1,104 @@
import math
import random
from unittest import IsolatedAsyncioTestCase

import usim

import framework.usim as usimfw
from protocol.connection import LocalSimplexConnection
from protocol.node import Node
from protocol.test_utils import (
    init_mixnet_config,
)
from sim.config import LatencyConfig, NetworkConfig
from sim.connection import (
    MeteredRemoteSimplexConnection,
    ObservedMeteredRemoteSimplexConnection,
)
from sim.state import NodeState, NodeStateTable


class TestMeteredRemoteSimplexConnection(IsolatedAsyncioTestCase):
    async def test_latency(self):
        usim.run(self.__test_latency())

    async def __test_latency(self):
        async with usim.Scope() as scope:
            framework = usimfw.Framework(scope)
            conn = MeteredRemoteSimplexConnection(
                LatencyConfig(
                    min_latency_sec=0,
                    max_latency_sec=1,
                    seed=random.Random(),
                ),
                framework,
                framework.now(),
            )

            # Send two messages without any delay in between
            sent_time = framework.now()
            await conn.send(b"hello")
            await conn.send(b"world")

            # Receive the two messages and check that the network latency was simulated correctly.
            # There should be no delay between the two messages because they were sent back to back.
            self.assertEqual(b"hello", await conn.recv())
            self.assertEqual(conn.latency, framework.now() - sent_time)
            self.assertEqual(b"world", await conn.recv())
            self.assertEqual(conn.latency, framework.now() - sent_time)


class TestObservedMeteredRemoteSimplexConnection(IsolatedAsyncioTestCase):
    async def test_node_state(self):
        usim.run(self.__test_node_state())

    async def __test_node_state(self):
        async with usim.Scope() as scope:
            framework = usimfw.Framework(scope)
            node_state_table = NodeStateTable(num_nodes=2, duration_sec=3)
            meter_start_time = framework.now()
            conn = ObservedMeteredRemoteSimplexConnection(
                LatencyConfig(
                    min_latency_sec=0,
                    max_latency_sec=1,
                    seed=random.Random(),
                ),
                framework,
                meter_start_time,
                node_state_table[0],
                node_state_table[1],
            )

            # Sleep and then send a message
            await framework.sleep(1)
            sent_time = framework.now()
            await conn.send(b"hello")

            # Receive the message. It should arrive after the simulated latency.
            self.assertEqual(b"hello", await conn.recv())
            recv_time = framework.now()

            # Check that the sender's state is SENDING in the time slot of the send
            timeslot = math.floor((sent_time - meter_start_time) * 1000)
            self.assertEqual(
                NodeState.SENDING,
                node_state_table[0][timeslot],
            )
            # Ensure that the sender's state in all other time slots is IDLE
            states = set()
            states.update(node_state_table[0][:timeslot])
            states.update(node_state_table[0][timeslot + 1 :])
            self.assertEqual({NodeState.IDLE}, states)

            # Check that the receiver's state is RECEIVING in the time slot of the receipt
            timeslot = math.floor((recv_time - meter_start_time) * 1000)
            self.assertEqual(
                NodeState.RECEIVING,
                node_state_table[1][timeslot],
            )
            # Ensure that the receiver's state in all other time slots is IDLE
            states = set()
            states.update(node_state_table[1][:timeslot])
            states.update(node_state_table[1][timeslot + 1 :])
            self.assertEqual({NodeState.IDLE}, states)
@ -0,0 +1,23 @@
import time
from unittest import TestCase

from sim.message import Message, UniqueMessageBuilder


class TestMessage(TestCase):
    def test_message_serde(self):
        msg = Message(time.time(), 10, b"hello")
        serialized = bytes(msg)
        deserialized = Message.from_bytes(serialized)
        self.assertEqual(msg, deserialized)


class TestUniqueMessageBuilder(TestCase):
    def test_uniqueness(self):
        builder = UniqueMessageBuilder()
        msg1 = builder.next(time.time(), b"hello")
        msg2 = builder.next(time.time(), b"hello")
        self.assertEqual(0, msg1.id)
        self.assertEqual(1, msg2.id)
        self.assertNotEqual(msg1, msg2)
        self.assertNotEqual(hash(msg1), hash(msg2))
@ -0,0 +1,20 @@
import random
from unittest import TestCase

from sim.topology import are_all_nodes_connected, build_full_random_topology


class TestTopology(TestCase):
    def test_full_random(self):
        num_nodes = 100
        peering_degree = 6
        topology = build_full_random_topology(
            random.Random(0), num_nodes, peering_degree
        )
        self.assertEqual(num_nodes, len(topology))
        self.assertTrue(are_all_nodes_connected(topology))
        for node, peers in topology.items():
            self.assertTrue(0 < len(peers) <= peering_degree)
            # Check that connections are symmetric
            for peer in peers:
                self.assertIn(node, topology[peer])
@ -0,0 +1,58 @@
import random
from collections import defaultdict

from protocol.node import Node

Topology = dict[int, set[int]]


def build_full_random_topology(
    rng: random.Random, num_nodes: int, peering_degree: int
) -> Topology:
    """
    Generate a random undirected topology until all nodes are connected.
    We don't use any artificial mechanism to guarantee the connectivity of a topology.
    Instead, we regenerate a fully random topology until all nodes are connected.
    """
    while True:
        topology: Topology = defaultdict(set[int])
        nodes = list(range(num_nodes))
        for node in nodes:
            # Filter nodes that can be connected to the current node.
            others = []
            for other in nodes[:node] + nodes[node + 1 :]:
                # Check that the other node is not already connected to the current node
                # and has not yet reached the peering degree.
                if (
                    other not in topology[node]
                    and len(topology[other]) < peering_degree
                ):
                    others.append(other)
            # How many more connections the current node needs
            n_needs = peering_degree - len(topology[node])
            # Sample as many peers as possible
            peers = rng.sample(others, k=min(n_needs, len(others)))
            # Connect the current node to the peers
            topology[node].update(peers)
            # Connect the peers back to the current node, since the topology is undirected
            for peer in peers:
                topology[peer].add(node)

        if are_all_nodes_connected(topology):
            return topology


def are_all_nodes_connected(topology: Topology) -> bool:
    visited = set()

    def dfs(topology: Topology, node: int) -> None:
        if node in visited:
            return
        visited.add(node)
        for peer in topology[node]:
            dfs(topology, peer)

    # Start DFS from the first node
    dfs(topology, next(iter(topology)))

    return len(visited) == len(topology)
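The retry-until-connected approach in `build_full_random_topology` above can be condensed into a dependency-free sketch: generate a fully random degree-bounded graph, check connectivity with a traversal, and regenerate on failure. This is an illustrative restatement under the same scheme, not the repo's implementation (the `build`/`connected` names are hypothetical, and the DFS is iterative to avoid recursion limits):

```python
import random
from collections import defaultdict

Topology = dict[int, set[int]]


def build(rng: random.Random, num_nodes: int, degree: int) -> Topology:
    # Retry fully random generation until the graph is connected,
    # mirroring build_full_random_topology above.
    while True:
        topo: Topology = defaultdict(set)
        for node in range(num_nodes):
            candidates = [
                other
                for other in range(num_nodes)
                if other != node
                and other not in topo[node]
                and len(topo[other]) < degree
            ]
            needed = degree - len(topo[node])
            for peer in rng.sample(candidates, k=min(needed, len(candidates))):
                # Undirected: record the edge on both endpoints.
                topo[node].add(peer)
                topo[peer].add(node)
        if connected(topo):
            return topo


def connected(topo: Topology) -> bool:
    # Iterative DFS: the graph is connected iff every node is reachable.
    stack, seen = [next(iter(topo))], set()
    while stack:
        node = stack.pop()
        if node not in seen:
            seen.add(node)
            stack.extend(topo[node])
    return len(seen) == len(topo)


topo = build(random.Random(0), num_nodes=10, degree=3)
print(len(topo))  # 10
```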