diff --git a/mixnet/sim/__init__.py b/mixnet/sim/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/mixnet/sim/config.py b/mixnet/sim/config.py new file mode 100644 index 0000000..564df9a --- /dev/null +++ b/mixnet/sim/config.py @@ -0,0 +1,73 @@ +from __future__ import annotations + +from dataclasses import dataclass + +import dacite +import yaml +from pysphinx.sphinx import X25519PrivateKey + +from mixnet.config import NodeConfig + + +@dataclass +class Config: + simulation: SimulationConfig + mixnet: MixnetConfig + + @classmethod + def load(cls, yaml_path: str) -> Config: + with open(yaml_path, "r") as f: + data = yaml.safe_load(f) + config = dacite.from_dict(data_class=Config, data=data) + + # Validations + config.simulation.validate() + config.mixnet.validate() + + return config + + def description(self): + return f"{self.simulation.description()}\n" f"{self.mixnet.description()}" + + +@dataclass +class SimulationConfig: + duration_sec: int + + def validate(self): + assert self.duration_sec > 0 + + def description(self): + return f"running_secs: {self.duration_sec}" + + +@dataclass +class MixnetConfig: + num_nodes: int + transmission_rate_per_sec: int + peering_degree: int + max_mix_path_length: int + + def validate(self): + assert self.num_nodes > 0 + assert self.transmission_rate_per_sec > 0 + assert self.peering_degree > 0 + assert self.max_mix_path_length > 0 + + def description(self): + return ( + f"num_nodes: {self.num_nodes}\n" + f"transmission_rate_per_sec: {self.transmission_rate_per_sec}\n" + f"peering_degree: {self.peering_degree}\n" + f"max_mix_path_length: {self.max_mix_path_length}\n" + ) + + def node_configs(self) -> list[NodeConfig]: + return [ + NodeConfig( + X25519PrivateKey.generate(), + self.peering_degree, + self.transmission_rate_per_sec, + ) + for _ in range(self.num_nodes) + ] diff --git a/mixnet/sim/config.yaml b/mixnet/sim/config.yaml new file mode 100644 index 0000000..d3973cd --- /dev/null +++ 
b/mixnet/sim/config.yaml @@ -0,0 +1,8 @@ +simulation: + duration_sec: 10 + +mixnet: + num_nodes: 5 + transmission_rate_per_sec: 1 + peering_degree: 6 + max_mix_path_length: 3 diff --git a/mixnet/sim/main.py b/mixnet/sim/main.py new file mode 100644 index 0000000..11389d5 --- /dev/null +++ b/mixnet/sim/main.py @@ -0,0 +1,21 @@ +import argparse +import asyncio + +from mixnet.sim.config import Config +from mixnet.sim.simulation import Simulation + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description="Run mixnet simulation", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + parser.add_argument( + "--config", type=str, required=True, help="Configuration file path" + ) + args = parser.parse_args() + + config = Config.load(args.config) + sim = Simulation(config) + asyncio.run(sim.run()) + + print("Simulation complete!") diff --git a/mixnet/sim/simulation.py b/mixnet/sim/simulation.py new file mode 100644 index 0000000..fc5dcf2 --- /dev/null +++ b/mixnet/sim/simulation.py @@ -0,0 +1,31 @@ +import asyncio +import random + +from mixnet.config import GlobalConfig, MixMembership, NodeInfo +from mixnet.node import Node +from mixnet.sim.config import Config + + +class Simulation: + def __init__(self, config: Config): + random.seed() + self.config = config + + async def run(self): + # Initialize mixnet nodes and establish connections + node_configs = self.config.mixnet.node_configs() + global_config = GlobalConfig( + MixMembership( + [ + NodeInfo(node_config.private_key.public_key()) + for node_config in node_configs + ] + ), + self.config.mixnet.transmission_rate_per_sec, + self.config.mixnet.max_mix_path_length, + ) + nodes = [Node(node_config, global_config) for node_config in node_configs] + for i, node in enumerate(nodes): + node.connect(nodes[(i + 1) % len(nodes)]) + + await asyncio.sleep(self.config.simulation.duration_sec) diff --git a/mixnet/v2/sim/2024-05-25T22:29:13.csv b/mixnet/v2/sim/2024-05-25T22:29:13.csv new file mode 100644 
index 0000000..edb6242 --- /dev/null +++ b/mixnet/v2/sim/2024-05-25T22:29:13.csv @@ -0,0 +1,37 @@ +num_nodes,config,ingress_mean,ingress_max,egress_mean,egress_max +10,1-to-all: 0: 0.0,0.390625,0.625,3.125,3.125 +10,1-to-all: 0: 0.02,0.3125,0.3125,3.125,3.125 +10,1-to-all: 2: 0.0,4.125,4.125,16.5,27.5 +10,1-to-all: 2: 0.02,2.75,4.125,13.75,13.75 +10,1-to-all: 4: 0.0,12.03125,12.03125,25.32894736842105,48.125 +10,1-to-all: 4: 0.02,12.03125,12.03125,25.32894736842105,48.125 +200,1-to-all: 0: 0.0,0.6,1.5625,62.5,62.5 +200,1-to-all: 0: 0.02,1.4732142857142858,2.8125,62.5,62.5 +200,1-to-all: 2: 0.0,10.3125,20.625,276.41752577319585,550.0 +200,1-to-all: 2: 0.02,10.525862068965518,20.625,278.7671232876712,550.0 +200,1-to-all: 4: 0.0,24.525240384615383,48.125,496.2305447470817,962.5 +200,1-to-all: 4: 0.02,26.717672413793103,72.1875,493.5111464968153,962.5 +400,1-to-all: 0: 0.0,1.3038793103448276,2.8125,125.0,125.0 +400,1-to-all: 0: 0.02,2.726293103448276,4.375,125.0,125.0 +400,1-to-all: 2: 0.0,16.78448275862069,41.25,562.7167630057803,1100.0 +400,1-to-all: 2: 0.02,23.232758620689655,49.5,562.6304801670146,1100.0 +400,1-to-all: 4: 0.0,49.369612068965516,132.34375,994.2491319444445,2887.5 +400,1-to-all: 4: 0.02,59.741379310344826,129.9375,1005.8055152394775,2887.5 +600,1-to-all: 0: 0.0,1.918103448275862,3.75,187.5,187.5 +600,1-to-all: 0: 0.02,3.9331896551724137,6.875,187.5,187.5 +600,1-to-all: 2: 0.0,25.03448275862069,49.5,831.2977099236641,1650.0 +600,1-to-all: 2: 0.02,33.23706896551724,68.75,839.3686502177068,1650.0 +600,1-to-all: 4: 0.0,71.77262931034483,192.5,1488.4907628128724,4331.25 +600,1-to-all: 4: 0.02,79.82112068965517,180.46875,1482.2705442902882,2887.5 +800,1-to-all: 0: 0.0,2.456896551724138,4.375,250.0,250.0 +800,1-to-all: 0: 0.02,5.226293103448276,7.5,250.0,250.0 +800,1-to-all: 2: 0.0,32.28879310344828,49.5,1114.732142857143,2200.0 +800,1-to-all: 2: 0.02,41.25,72.875,1127.208480565371,3300.0 +800,1-to-all: 4: 
0.0,95.00538793103448,156.40625,1971.489266547406,3850.0 +800,1-to-all: 4: 0.02,122.13793103448276,192.5,1980.1537386443047,5775.0 +1000,1-to-all: 0: 0.0,2.877155172413793,5.9375,312.5,312.5 +1000,1-to-all: 0: 0.02,6.228448275862069,9.375,312.5,312.5 +1000,1-to-all: 2: 0.0,45.23275862068966,74.25,1388.095238095238,2750.0 +1000,1-to-all: 2: 0.02,52.866379310344826,74.25,1392.4841053587647,4125.0 +1000,1-to-all: 4: 0.0,120.72737068965517,192.5,2458.6332514044943,4812.5 +1000,1-to-all: 4: 0.02,136.57543103448276,223.78125,2461.58328154133,7218.75 diff --git a/mixnet/v2/sim/2024-05-25T23:16:39.csv b/mixnet/v2/sim/2024-05-25T23:16:39.csv new file mode 100644 index 0000000..1a7d683 --- /dev/null +++ b/mixnet/v2/sim/2024-05-25T23:16:39.csv @@ -0,0 +1,37 @@ +num_nodes,config,ingress_mean,ingress_max,egress_mean,egress_max +10,gossip: 0: 0.0,1.65625,2.5,1.65625,1.875 +10,gossip: 0: 0.02,1.6875,2.5,1.6875,1.875 +10,gossip: 2: 0.0,21.93125,33.0,21.93125,23.375 +10,gossip: 2: 0.02,7.5625,12.375,7.5625,8.25 +10,gossip: 4: 0.0,72.66875,105.875,72.66875,84.21875 +10,gossip: 4: 0.02,46.248125,93.84375,46.248125,72.1875 +200,gossip: 0: 0.0,4.8561875,24.0625,4.8561875,11.25 +200,gossip: 0: 0.02,8.74229525862069,37.5,8.74229525862069,18.75 +200,gossip: 2: 0.0,60.64959202175884,288.75,60.815,123.75 +200,gossip: 2: 0.02,67.73155172413793,330.0,67.73155172413793,165.0 +200,gossip: 4: 0.0,169.6930894308943,620.8125,174.26546489563566,360.9375 +200,gossip: 4: 0.02,183.93375,690.59375,184.40130926724137,404.25 +400,gossip: 0: 0.0,6.451481681034482,32.8125,6.451481681034482,13.125 +400,gossip: 0: 0.02,15.15870150862069,64.0625,15.15870150862069,28.125 +400,gossip: 2: 0.0,103.01286637931034,474.375,103.08695043103448,222.75 +400,gossip: 2: 0.02,133.36148706896552,496.375,133.36207974137932,255.75 +400,gossip: 4: 0.0,238.6203448275862,1217.5625,239.8322528668736,664.125 +400,gossip: 4: 0.02,348.4793480603448,1241.625,349.7957327586207,678.5625 +600,gossip: 0: 
0.0,10.845707614942528,48.125,10.845707614942528,20.625 +600,gossip: 0: 0.02,21.99876077586207,85.0,21.99876077586207,31.875 +600,gossip: 2: 0.0,134.6136063218391,833.25,134.6136063218391,297.0 +600,gossip: 2: 0.02,186.97431752873564,603.625,186.98158764367815,288.75 +600,gossip: 4: 0.0,468.0078807471264,1722.875,469.76112428160917,851.8125 +600,gossip: 4: 0.02,596.8570366379311,2057.34375,598.305626795977,1025.0625 +800,gossip: 0: 0.0,13.49982489224138,70.0,13.49982489224138,30.0 +800,gossip: 0: 0.02,31.12440732758621,96.25,31.12440732758621,41.25 +800,gossip: 2: 0.0,204.4971713362069,820.875,204.60337823275862,363.0 +800,gossip: 2: 0.02,263.0346551724138,1427.25,263.1884536637931,486.75 +800,gossip: 4: 0.0,510.1774811422414,1944.25,511.705864762931,895.125 +800,gossip: 4: 0.02,677.2293130387931,2057.34375,678.0591581357759,981.75 +1000,gossip: 0: 0.0,17.376422413793104,89.0625,17.376422413793104,35.625 +1000,gossip: 0: 0.02,35.141433189655174,138.125,35.141433189655174,63.75 +1000,gossip: 2: 0.0,248.81137068965518,979.0,249.05687931034484,404.25 +1000,gossip: 2: 0.02,327.9825431034483,1344.75,328.2077586206897,552.75 +1000,gossip: 4: 0.0,710.1483480603448,2632.4375,711.7312456896552,1212.75 +1000,gossip: 4: 0.02,836.671213362069,2622.8125,839.1504806034483,1284.9375 diff --git a/mixnet/v2/sim/2024-05-27T14:14:58.csv b/mixnet/v2/sim/2024-05-27T14:14:58.csv new file mode 100644 index 0000000..65baa35 --- /dev/null +++ b/mixnet/v2/sim/2024-05-27T14:14:58.csv @@ -0,0 +1,6 @@ +num_nodes,num_mix_layers,p2p_type,real_message_prob,cover_message_prob,egress_mean,egress_max,ingress_mean,ingress_max +1000,4,gossip,0.03,0.06,2541.7524924568966,3335.0625,2537.05615625,7637.4375 +1000,4,gossip,0.03,0.12,3481.2558782327587,4287.9375,3475.6617618534483,9107.65625 +1000,4,gossip,0.03,0.18,4284.96495150862,4908.75,4280.065577586207,11877.25 +1000,4,gossip,0.03,0.24,5208.071651939656,5962.6875,5202.543334051724,12478.8125 
+1000,4,gossip,0.03,0.30,6017.549585129311,7132.125,6013.211033405172,15058.3125 \ No newline at end of file diff --git a/mixnet/v2/sim/2024-05-27T15:04:51.csv b/mixnet/v2/sim/2024-05-27T15:04:51.csv new file mode 100644 index 0000000..8c6fc0d --- /dev/null +++ b/mixnet/v2/sim/2024-05-27T15:04:51.csv @@ -0,0 +1,6 @@ +num_nodes,num_mix_layers,p2p_type,real_message_prob,cover_message_prob,egress_mean,egress_max,ingress_mean,ingress_max +10,4,gossip,0.01,0.02,36.4375,79.40625,36.4375,98.65625 +10,4,gossip,0.01,0.04,14.718229166666667,28.875,14.718229166666667,33.6875 +10,4,gossip,0.01,0.06,27.848333333333333,98.65625,27.848333333333333,113.09375 +10,4,gossip,0.01,0.08,19.75955882352941,79.40625,19.75955882352941,91.4375 +10,4,gossip,0.01,0.1,27.62121710526316,98.65625,27.62121710526316,134.75 diff --git a/mixnet/v2/sim/2024-05-27T15:55:35.csv b/mixnet/v2/sim/2024-05-27T15:55:35.csv new file mode 100644 index 0000000..4253b04 --- /dev/null +++ b/mixnet/v2/sim/2024-05-27T15:55:35.csv @@ -0,0 +1,6 @@ +num_nodes,num_mix_layers,p2p_type,real_message_prob,cover_message_prob,egress_mean,egress_max,ingress_mean,ingress_max +1000,4,gossip,0.01,0.02,789.9003512931034,1270.5,786.5745818965518,2697.40625 +1000,4,gossip,0.01,0.04,1040.332811422414,1660.3125,1038.5553394396552,4716.25 +1000,4,gossip,0.01,0.06,1408.3879989224138,1992.375,1406.3984450431035,5681.15625 +1000,4,gossip,0.01,0.08,1721.6532887931035,2252.25,1720.1033318965517,6102.25 +1000,4,gossip,0.01,0.1,1984.492207974138,2483.25,1981.9849784482758,5625.8125 diff --git a/mixnet/v2/sim/README.md b/mixnet/v2/sim/README.md new file mode 100644 index 0000000..64577ab --- /dev/null +++ b/mixnet/v2/sim/README.md @@ -0,0 +1,60 @@ +# Mixnet v2 Simulation + +* [How to Run](#how-to-run) + + [Time in simulation](#time-in-simulation) +* [Mixnet Functionalities](#mixnet-functionalities) +* [Adversary Models](#adversary-models) + +## How to Run + +First, make sure that all dependencies specified in the `requirements.txt` in the 
project root are installed. +Then, configure parameters in the [`config.yaml`](./config.yaml), and run the following command: +```bash +python main.py --config ./config.yaml +``` +The simulation runs for the specified duration, prints the results to the console, and shows some plots. + +### Time in simulation + +The simulation is implemented based on [SimPy](https://simpy.readthedocs.io/en/latest/), which is a discrete-event simulation framework. +All events are processed sequentially by a single thread. +However, multiple parallel events, which should be processed at the same time, can also be simulated by scheduling them at the same "time". + +The simulation has the virtual time concept, which doesn't have the same scale as the real time. +If the event A is scheduled to happen at time 10, and the event B is scheduled to happen at time 11, +the simulation guarantees that the event B happens only after the event A happens. +However, this doesn't mean that the event B happens exactly 1 second after the event A. The real elapsed time will be much shorter. + +If two events are scheduled at the same time (e.g. 10), the simulation processes the one that is scheduled first and the other one next (FIFO). +However, it is guaranteed that both events are processed before the events scheduled at time 11+. + +Using this virtual time, complex distributed systems can be simulated in a simple way without worrying about the real-time synchronization. +For more details, please see the [Time and Scheduling](https://simpy.readthedocs.io/en/latest/topical_guides/time_and_scheduling.html#what-is-time) section in the SimPy documentation. + +## Progress + +### Mixnet Functionalities +- Modified Sphinx + - [x] Without encryption + - [ ] With encryption +- P2P Broadcasting + - [x] Naive 1-to-all + - [x] More realistic broadcasting (e.g. gossipsub) +- [x] Sending a real message to the mixnet at the promised interval + - Each node has its own probability of sending a real message at each interval. 
+- [x] Cover traffic + - All nodes have the same probability of sending a cover message at each interval. +- [x] Forwarding messages through mixes, and then broadcasting messages to all nodes if the message is real. +- Mix delays + - [x] Naive random delays + - [ ] More sophisticated delays (e.g. Poisson) if necessary + +### Performance Measurements + +- [x] Bandwidth Usage + +### [Adversary Models](https://www.notion.so/Mixnet-v2-Proof-of-Concept-102d0563e75345a3a6f1c11791fbd746?pvs=4#c5ffa49486ce47ed81d25028bc0d9d40) +- [x] Inspecting message sizes to analyze how far each message has traveled since emitted by the original sender. +- [x] Identifying nodes emitting messages around the promised interval. +- [ ] Correlating senders-receivers based on timing +- [ ] Active attacks \ No newline at end of file diff --git a/mixnet/v2/sim/adversary.py b/mixnet/v2/sim/adversary.py new file mode 100644 index 0000000..6ce63ef --- /dev/null +++ b/mixnet/v2/sim/adversary.py @@ -0,0 +1,76 @@ +from __future__ import annotations + +from collections import defaultdict, deque, Counter +from enum import Enum +from typing import TYPE_CHECKING + +from config import Config +from environment import Environment, Time +from sphinx import SphinxPacket + +if TYPE_CHECKING: + from node import Node + + +class Adversary: + def __init__(self, env: Environment, config: Config): + self.env = env + self.config = config + self.message_sizes = [] + self.senders_around_interval = Counter() + self.msg_pools_per_time = [] # list[dict[receiver, deque[time_received])]] + self.msg_pools_per_time.append(defaultdict(lambda: deque())) + self.msgs_received_per_time = [] # list[dict[receiver, dict[sender, list[time_sent]]]] + self.msgs_received_per_time.append(defaultdict(lambda: defaultdict(list))) + # dict[receiver, dict[time, list[(sender, time_sent, origin_id)]]] + self.final_msgs_received = defaultdict(lambda: defaultdict(list)) + # self.node_states = defaultdict(dict) + + 
self.env.process(self.update_observation_time()) + + def inspect_message_size(self, msg: SphinxPacket | bytes): + self.message_sizes.append(len(msg)) + + def observe_receiving_node(self, sender: "Node", receiver: "Node", time_sent: Time): + self.msg_pools_per_time[-1][receiver].append(self.env.now()) + self.msgs_received_per_time[-1][receiver][sender].append(time_sent) + # if node not in self.node_states[self.env.now]: + # self.node_states[self.env.now][node] = NodeState.RECEIVING + + def observe_sending_node(self, sender: "Node"): + msg_pool = self.msg_pools_per_time[-1][sender] + if len(msg_pool) > 0: + # Adversary doesn't know which message in the pool is being emitted. So, pop the oldest one from the pool. + msg_pool.popleft() + if self.is_around_message_interval(self.env.now()): + self.senders_around_interval.update({sender}) + # self.node_states[self.env.now][node] = NodeState.SENDING + + def observe_if_final_msg(self, sender: "Node", receiver: "Node", time_sent: Time, msg: SphinxPacket | bytes): + origin_id = receiver.inspect_message(msg) + if origin_id is not None: + cur_time = len(self.msgs_received_per_time) - 1 + self.final_msgs_received[receiver][cur_time].append((sender, time_sent, origin_id)) + + def is_around_message_interval(self, time: Time) -> bool: + return time % self.config.mixnet.message_interval <= self.config.mixnet.max_message_prep_time + + def update_observation_time(self): + while True: + yield self.env.timeout(1) + + self.msgs_received_per_time.append(defaultdict(lambda: defaultdict(list))) + + new_msg_pool = defaultdict(lambda: deque()) + for receiver, msg_queue in self.msg_pools_per_time[-1].items(): + for time_received in msg_queue: + # If the message is likely to be still pending and be emitted soon, + # pass it on to the next time slot. 
+ if self.env.now() - time_received < self.config.mixnet.max_mix_delay: + new_msg_pool[receiver].append(time_received) # each pool is a plain deque of receive times + self.msg_pools_per_time.append(new_msg_pool) + + +class NodeState(Enum): + SENDING = 0 + RECEIVING = 1 diff --git a/mixnet/v2/sim/analysis.py b/mixnet/v2/sim/analysis.py new file mode 100644 index 0000000..055f1f2 --- /dev/null +++ b/mixnet/v2/sim/analysis.py @@ -0,0 +1,352 @@ +import itertools +import multiprocessing +import sys +import threading +from collections import Counter +from typing import TYPE_CHECKING + +import pandas as pd +import seaborn +from matplotlib import pyplot as plt + +from adversary import NodeState +from config import Config +from environment import Time +from simulation import Simulation + +if TYPE_CHECKING: + from node import Node + +COL_TIME = "Time" +COL_NODE_ID = "Node ID" +COL_MSG_CNT = "Message Count" +COL_SENDER_CNT = "Sender Count" +COL_NODE_STATE = "Node State" +COL_HOPS = "Hops" +COL_EXPECTED = "Expected" +COL_MSG_SIZE = "Message Size" +COL_EGRESS = "Egress" +COL_INGRESS = "Ingress" +COL_SUCCESS_RATE = "Success Rate (%)" + + +class Analysis: + def __init__(self, sim: Simulation, config: Config, show_plots: bool = True): + self.sim = sim + self.config = config + self.show_plots = show_plots + + def run(self): + message_size_df = self.message_size_distribution() + self.bandwidth(message_size_df) + self.messages_emitted_around_interval() + self.messages_in_node_over_time() + # self.node_states() + median_hops = self.message_hops() + self.timing_attack(median_hops) + + def bandwidth(self, message_size_df: pd.DataFrame): + if not self.show_plots: + return + + dataframes = [] + nonzero_egresses = [] + nonzero_ingresses = [] + for egress_bandwidths, ingress_bandwidths in zip(self.sim.p2p.measurement.egress_bandwidth_per_sec, + self.sim.p2p.measurement.ingress_bandwidth_per_sec): + rows = [] + for node in self.sim.p2p.nodes: + egress = egress_bandwidths[node] / 1024.0 + ingress = ingress_bandwidths[node] / 1024.0 + 
rows.append((node.id, egress, ingress)) + if egress > 0: + nonzero_egresses.append(egress) + if ingress > 0: + nonzero_ingresses.append(ingress) + df = pd.DataFrame(rows, columns=[COL_NODE_ID, COL_EGRESS, COL_INGRESS]) + dataframes.append(df) + + times = range(len(dataframes)) + df = pd.concat([df.assign(Time=time) for df, time in zip(dataframes, times)], ignore_index=True) + df = df.pivot(index=COL_TIME, columns=COL_NODE_ID, values=[COL_EGRESS, COL_INGRESS]) + plt.figure(figsize=(12, 6)) + for column in df.columns: + marker = "x" if column[0] == COL_INGRESS else "o" + plt.plot(df.index, df[column], marker=marker, label=column[0]) + plt.title("Egress/ingress bandwidth of each node over time") + plt.xlabel(COL_TIME) + plt.ylabel("Bandwidth (KiB/s)") + plt.ylim(bottom=0) + # Customize the legend to show only "egress" and "ingress" regardless of node_id + handles, labels = plt.gca().get_legend_handles_labels() + by_label = dict(zip(labels, handles)) + plt.legend(by_label.values(), by_label.keys()) + plt.grid(True) + + # Adding descriptions on the right size of the plot + egress_series = pd.Series(nonzero_egresses) + ingress_series = pd.Series(nonzero_ingresses) + desc = ( + f"message: {message_size_df[COL_MSG_SIZE].mean():.0f} bytes\n" + f"{self.config.description()}\n\n" + f"[egress(>0)]\nmean: {egress_series.mean():.2f} KiB/s\nmax: {egress_series.max():.2f} KiB/s\n\n" + f"[ingress(>0)]\nmean: {ingress_series.mean():.2f} KiB/s\nmax: {ingress_series.max():.2f} KiB/s" + ) + plt.text(1.02, 0.5, desc, transform=plt.gca().transAxes, verticalalignment="center", fontsize=12) + plt.subplots_adjust(right=0.8) # Adjust layout to make room for the text + + plt.show() + + def message_size_distribution(self) -> pd.DataFrame: + df = pd.DataFrame(self.sim.p2p.adversary.message_sizes, columns=[COL_MSG_SIZE]) + print(df.describe()) + return df + + def messages_emitted_around_interval(self) -> (float, float, float): + # A ground truth that shows how many times each node sent a real 
message + truth_df = pd.DataFrame( + [(node.id, count) for node, count in self.sim.p2p.measurement.original_senders.items()], + columns=[COL_NODE_ID, COL_MSG_CNT]) + # A result of observing nodes who have sent messages around the promised message interval + suspected_df = pd.DataFrame( + [(node.id, self.sim.p2p.adversary.senders_around_interval[node]) for node in + self.sim.p2p.measurement.original_senders.keys()], + columns=[COL_NODE_ID, COL_MSG_CNT] + ) + + if self.show_plots: + width = 0.4 + fig, ax = plt.subplots(figsize=(12, 8)) + ax.bar(truth_df[COL_NODE_ID] - width / 2, truth_df[COL_MSG_CNT], width, label="Ground Truth", color="b") + ax.bar(truth_df[COL_NODE_ID] + width / 2, suspected_df[COL_MSG_CNT], width, label="Adversary's Inference", + color="r") + ax.set_title("Nodes who generated real messages") + ax.set_xlabel(COL_NODE_ID) + ax.set_ylabel(COL_MSG_CNT) + ax.set_xlim(-1, len(truth_df[COL_NODE_ID])) + ax.legend() + plt.tight_layout() + plt.show() + + # Calculate precision, recall, and F1 score + truth = set(truth_df[truth_df[COL_MSG_CNT] > 0][COL_NODE_ID]) + suspected = set(suspected_df[suspected_df[COL_MSG_CNT] > 0][COL_NODE_ID]) + true_positives = truth.intersection(suspected) + precision = len(true_positives) / len(suspected) * 100.0 if len(suspected) > 0 else 0.0 + recall = len(true_positives) / len(truth) * 100.0 if len(truth) > 0 else 0.0 + f1_score = 2 * precision * recall / (precision + recall) if precision + recall > 0 else 0.0 + print(f"Precision: {precision:.2f}%, Recall: {recall:.2f}%, F1 Score: {f1_score:.2f}%") + return precision, recall, f1_score + + def messages_in_node_over_time(self): + if not self.show_plots: + return + + dataframes = [] + for time, msg_pools in enumerate(self.sim.p2p.adversary.msg_pools_per_time): + data = [] + for receiver, msg_pool in msg_pools.items(): + senders = self.sim.p2p.adversary.msgs_received_per_time[time][receiver].keys() + data.append((time, receiver.id, len(msg_pool), len(senders))) + df = 
pd.DataFrame(data, columns=[COL_TIME, COL_NODE_ID, COL_MSG_CNT, COL_SENDER_CNT]) + if not df.empty: + dataframes.append(df) + df = pd.concat(dataframes, ignore_index=True) + + msg_cnt_df = df.pivot(index=COL_TIME, columns=COL_NODE_ID, values=COL_MSG_CNT) + plt.figure(figsize=(12, 6)) + for column in msg_cnt_df.columns: + plt.plot(msg_cnt_df.index, msg_cnt_df[column], marker=None, label=column) + plt.title("Messages within each node over time") + plt.xlabel(COL_TIME) + plt.ylabel(COL_MSG_CNT) + plt.ylim(bottom=0) + plt.grid(True) + plt.tight_layout() + plt.show() + + sender_cnt_df = df.pivot(index=COL_TIME, columns=COL_NODE_ID, values=COL_SENDER_CNT) + plt.figure(figsize=(12, 6)) + for column in sender_cnt_df.columns: + plt.plot(sender_cnt_df.index, sender_cnt_df[column], marker=None, label=column) + plt.title("Diversity of senders of messages received by each node over time") + plt.xlabel(COL_TIME) + plt.ylabel("# of senders of messages received by each node") + plt.ylim(bottom=0) + plt.grid(True) + plt.tight_layout() + plt.show() + + plt.figure(figsize=(12, 6)) + df.boxplot(column=COL_SENDER_CNT, by=COL_TIME, medianprops={"color": "red", "linewidth": 2.5}) + plt.title("Diversity of senders of messages received by each node over time") + plt.suptitle("") + plt.xticks([]) + plt.xlabel(COL_TIME) + plt.ylabel("# of senders of messages received by each node") + plt.ylim(bottom=0) + plt.grid(axis="x") + plt.tight_layout() + plt.show() + + def node_states(self): + if not self.show_plots: + return + + rows = [] + for time, node_states in self.sim.p2p.adversary.node_states.items(): + for node, state in node_states.items(): + rows.append((time, node.id, state)) + df = pd.DataFrame(rows, columns=[COL_TIME, COL_NODE_ID, COL_NODE_STATE]) + + plt.figure(figsize=(10, 6)) + seaborn.scatterplot(data=df, x=COL_TIME, y=COL_NODE_ID, hue=COL_NODE_STATE, + palette={NodeState.SENDING: "red", NodeState.RECEIVING: "blue"}) + plt.title("Node states over time") + plt.xlabel(COL_TIME) + 
plt.ylabel(COL_NODE_ID) + plt.legend(title=COL_NODE_STATE) + plt.show() + + def message_hops(self) -> int: + df = pd.DataFrame(self.sim.p2p.measurement.message_hops.values(), columns=[COL_HOPS]) + print(df.describe()) + if self.show_plots: + plt.figure(figsize=(6, 6)) + seaborn.boxplot(data=df, y=COL_HOPS, medianprops={"color": "red", "linewidth": 2.5}) + plt.ylim(bottom=0) + plt.title("The distribution of max hops of single broadcasting") + plt.show() + return int(df.median().iloc[0]) + + def timing_attack(self, hops_between_layers: int) -> pd.DataFrame: + success_rates = self.spawn_timing_attacks(hops_between_layers) + df = pd.DataFrame(success_rates, columns=[COL_SUCCESS_RATE]) + print(df.describe()) + + if self.show_plots: + plt.figure(figsize=(6, 6)) + plt.boxplot(df[COL_SUCCESS_RATE], vert=True, patch_artist=True, boxprops=dict(facecolor="lightblue"), + medianprops=dict(color="orange")) + mean = df[COL_SUCCESS_RATE].mean() + median = df[COL_SUCCESS_RATE].median() + plt.axhline(mean, color="red", linestyle="--", linewidth=1, label=f"Mean: {mean:.2f}%") + plt.axhline(median, color="orange", linestyle="-", linewidth=1, label=f"Median: {median:.2f}%") + plt.ylabel(COL_SUCCESS_RATE) + plt.ylim(-5, 105) + plt.title("Timing attack success rate distribution") + plt.legend() + plt.grid(True) + plt.show() + + return df + + def spawn_timing_attacks(self, hops_between_layers: int) -> list[float]: + tasks = self.prepare_timing_attack_tasks(hops_between_layers) + print(f"{len(tasks)} TASKS") + + # Spawn process for each task + processes = [] + accuracy_results = multiprocessing.Manager().list() + for task in tasks: + process = multiprocessing.Process(target=self.spawn_timing_attack, args=(task, accuracy_results)) + process.start() + processes.append(process) + + # Join processes using threading to apply a timeout to all processes almost simultaneously. 
+ threads = [] + for process in processes: + thread = threading.Thread(target=Analysis.join_process, + args=(process, self.config.adversary.timing_attack_timeout)) + thread.start() + threads.append(thread) + for thread in threads: + thread.join() + + return list(accuracy_results) + + def spawn_timing_attack(self, task, accuracy_results): + origin_id, receiver, time_received, remaining_hops, observed_hops, senders = task + result = self.run_and_evaluate_timing_attack( + origin_id, receiver, time_received, remaining_hops, observed_hops, senders + ) + accuracy_results.append(result) + print(f"{len(accuracy_results)} PROCESSES DONE") + + @staticmethod + def join_process(process, timeout): + process.join(timeout) + if process.is_alive(): + process.terminate() + process.join() + print(f"PROCESS TIMED OUT") + + def prepare_timing_attack_tasks(self, hops_between_layers: int) -> list: + hops_to_observe = hops_between_layers * (self.config.mixnet.num_mix_layers + 1) + tasks = [] + + # Prepare a task for each real message received by the adversary + for receiver, times_and_msgs in self.sim.p2p.adversary.final_msgs_received.items(): + for time_received, msgs in times_and_msgs.items(): + for sender, time_sent, origin_id in msgs: + tasks.append(( + origin_id, receiver, time_received, hops_to_observe, 0, {sender: [time_sent]} + )) + if len(tasks) >= self.config.adversary.timing_attack_max_targets: + return tasks + + return tasks + + def run_and_evaluate_timing_attack(self, origin_id: int, receiver: "Node", time_received: Time, + remaining_hops: int, observed_hops: int, + senders: dict["Node", list[Time]] = None) -> float: + suspected_origins = self.timing_attack_from_receiver( + receiver, time_received, remaining_hops, observed_hops, Counter(), senders + ) + if origin_id in suspected_origins: + return 1 / len(suspected_origins) * 100.0 + else: + return 0.0 + + def timing_attack_from_receiver(self, receiver: "Node", time_received: Time, + remaining_hops: int, observed_hops: int, 
suspected_origins: Counter, + senders: dict["Node", list[Time]] = None) -> Counter: + if remaining_hops <= 0: + return suspected_origins + + # If all nodes are already suspected, no need to inspect further. + if len(suspected_origins) == self.config.mixnet.num_nodes: + return suspected_origins + + # Start inspecting senders who sent messages that were arrived in the receiver at the given time. + # If the specific sender is given, inspect only that sender to maximize the success rate. + if senders is None: + senders = self.sim.p2p.adversary.msgs_received_per_time[time_received][receiver] + + senders = dict(itertools.islice(senders.items(), self.config.adversary.timing_attack_max_pool_size)) + + # Inspect each sender who sent messages to the receiver + for sender, times_sent in senders.items(): + # Calculate the time range where the sender might have received any messages + # related to the message being traced. + min_time, max_time = sys.maxsize, 0 + for time_sent in times_sent: + min_time = min(min_time, time_sent - self.config.mixnet.max_mix_delay) + max_time = max(max_time, time_sent - self.config.mixnet.min_mix_delay) + # If the sender is sent the message around the message interval, suspect the sender as the origin. + if (self.sim.p2p.adversary.is_around_message_interval(time_sent) + and observed_hops + 1 >= self.min_hops_to_observe_for_timing_attack()): + suspected_origins.update({sender.id}) + + # Track back to each time when that sender might have received any messages. 
+ for time_sender_received in range(max_time, min_time - 1, -1): + if time_sender_received < 0: + break + self.timing_attack_from_receiver( + sender, time_sender_received, remaining_hops - 1, observed_hops + 1, suspected_origins + ) + + return suspected_origins + + def min_hops_to_observe_for_timing_attack(self) -> int: + return self.config.mixnet.num_mix_layers + 1 diff --git a/mixnet/v2/sim/bulk_attack.py b/mixnet/v2/sim/bulk_attack.py new file mode 100644 index 0000000..c17b01c --- /dev/null +++ b/mixnet/v2/sim/bulk_attack.py @@ -0,0 +1,152 @@ +import argparse +from datetime import datetime + +import pandas as pd +from matplotlib import pyplot as plt + +from analysis import Analysis +from config import Config, P2PConfig +from simulation import Simulation + +COL_P2P_TYPE = "p2p_type" +COL_NUM_MIX_LAYERS = "num_mix_layers" +COL_COVER_MESSAGE_PROB = "cover_message_prob" +COL_MIX_DELAY = "mix_delay" +COL_GLOBAL_PRECISION = "global_precision" +COL_GLOBAL_RECALL = "global_recall" +COL_GLOBAL_F1_SCORE = "global_f1_score" +COL_TARGET_ACCURACY_MEDIAN = "target_accuracy_median" +COL_TARGET_ACCURACY_STD = "target_accuracy_std" +COL_TARGET_ACCURACY_MIN = "target_accuracy_min" +COL_TARGET_ACCURACY_25p = "target_accuracy_25p" +COL_TARGET_ACCURACY_MEAN = "target_accuracy_mean" +COL_TARGET_ACCURACY_75p = "target_accuracy_75p" +COL_TARGET_ACCURACY_MAX = "target_accuracy_max" + + +def bulk_attack(): + parser = argparse.ArgumentParser(description="Run multiple passive adversary attack simulations", + formatter_class=argparse.ArgumentDefaultsHelpFormatter) + parser.add_argument("--config", type=str, required=True, help="Configuration file path") + args = parser.parse_args() + config = Config.load(args.config) + + config.simulation.running_time = 200 + config.mixnet.num_nodes = 100 + config.mixnet.payload_size = 320 + config.mixnet.message_interval = 10 + config.mixnet.real_message_prob = 0.01 + config.mixnet.real_message_prob_weights = [] + config.mixnet.max_message_prep_time = 0 + 
config.p2p.connection_density = 6 + config.p2p.min_network_latency = 1 + config.p2p.max_network_latency = 1 + config.measurement.sim_time_per_second = 10 + + results = [] + + for p2p_type in [P2PConfig.TYPE_ONE_TO_ALL, P2PConfig.TYPE_GOSSIP]: + config.p2p.type = p2p_type + + for num_mix_layers in [0, 1, 2, 3]: + config.mixnet.num_mix_layers = num_mix_layers + + for cover_message_prob in [0.0, 0.1, 0.2, 0.3]: + config.mixnet.cover_message_prob = cover_message_prob + + for mix_delay in [0]: + config.mixnet.min_mix_delay = mix_delay + config.mixnet.max_mix_delay = mix_delay + + sim = Simulation(config) + sim.run() + + analysis = Analysis(sim, config, show_plots=False) + precision, recall, f1_score = analysis.messages_emitted_around_interval() + print( + f"STARTING TIMING ATTACK: p2p_type:{p2p_type}, {num_mix_layers} layers, {cover_message_prob} cover, {mix_delay} delay") + timing_attack_df = analysis.timing_attack(analysis.message_hops()) + + results.append({ + COL_P2P_TYPE: p2p_type, + COL_NUM_MIX_LAYERS: num_mix_layers, + COL_COVER_MESSAGE_PROB: cover_message_prob, + COL_MIX_DELAY: mix_delay, + COL_GLOBAL_PRECISION: precision, + COL_GLOBAL_RECALL: recall, + COL_GLOBAL_F1_SCORE: f1_score, + COL_TARGET_ACCURACY_MEDIAN: float(timing_attack_df.median().iloc[0]), + COL_TARGET_ACCURACY_STD: float(timing_attack_df.std().iloc[0]), + COL_TARGET_ACCURACY_MIN: float(timing_attack_df.min().iloc[0]), + COL_TARGET_ACCURACY_25p: float(timing_attack_df.quantile(0.25).iloc[0]), + COL_TARGET_ACCURACY_MEAN: float(timing_attack_df.mean().iloc[0]), + COL_TARGET_ACCURACY_75p: float(timing_attack_df.quantile(0.75).iloc[0]), + COL_TARGET_ACCURACY_MAX: float(timing_attack_df.max().iloc[0]), + }) + + df = pd.DataFrame(results) + df.to_csv(f"bulk-attack-{datetime.now().replace(microsecond=0).isoformat()}.csv", index=False) + plot_global_metrics(df) + plot_target_accuracy(df) + + +def plot_global_metrics(df: pd.DataFrame): + for p2p_type in df[COL_P2P_TYPE].unique(): + # Plotting global 
precision, recall, and f1 score against different parameters + fig, axes = plt.subplots(nrows=3, ncols=1, figsize=(10, 15)) + + # Precision plot + for cover_message_prob in df[COL_COVER_MESSAGE_PROB].unique(): + subset = df[(df[COL_COVER_MESSAGE_PROB] == cover_message_prob) & (df[COL_P2P_TYPE] == p2p_type)] + axes[0].plot(subset[COL_NUM_MIX_LAYERS], subset[COL_GLOBAL_PRECISION], + label=f"{cover_message_prob} cover rate") + axes[0].set_title(f"Global Precision ({p2p_type})") + axes[0].set_xlabel("# of Mix Layers") + axes[0].set_ylabel("Global Precision (%)") + axes[0].set_ylim(0, 105) + axes[0].legend() + + # Recall plot + for cover_message_prob in df[COL_COVER_MESSAGE_PROB].unique(): + subset = df[(df[COL_COVER_MESSAGE_PROB] == cover_message_prob) & (df[COL_P2P_TYPE] == p2p_type)] + axes[1].plot(subset[COL_NUM_MIX_LAYERS], subset[COL_GLOBAL_RECALL], + label=f"{cover_message_prob} cover rate") + axes[1].set_title(f"Global Recall ({p2p_type})") + axes[1].set_xlabel("# of Mix Layers") + axes[1].set_ylabel("Global Recall (%)") + axes[1].set_ylim(0, 105) + axes[1].legend() + + # F1 Score plot + for cover_message_prob in df[COL_COVER_MESSAGE_PROB].unique(): + subset = df[(df[COL_COVER_MESSAGE_PROB] == cover_message_prob) & (df[COL_P2P_TYPE] == p2p_type)] + axes[2].plot(subset[COL_NUM_MIX_LAYERS], subset[COL_GLOBAL_F1_SCORE], + label=f"{cover_message_prob} cover rate") + axes[2].set_title(f"Global F1 Score ({p2p_type})") + axes[2].set_xlabel("# of Mix Layers") + axes[2].set_ylabel("Global F1 Score (%)") + axes[2].set_ylim(0, 105) + axes[2].legend() + + plt.tight_layout() + plt.show() + + +def plot_target_accuracy(df: pd.DataFrame): + for p2p_type in df[COL_P2P_TYPE].unique(): + plt.figure(figsize=(12, 6)) + for cover_message_prob in df[COL_COVER_MESSAGE_PROB].unique(): + subset = df[(df[COL_COVER_MESSAGE_PROB] == cover_message_prob) & (df[COL_P2P_TYPE] == p2p_type)] + plt.plot(subset[COL_NUM_MIX_LAYERS], subset[COL_TARGET_ACCURACY_MEDIAN], + label=f"{cover_message_prob} 
cover rate") + plt.title(f"Timing Attack Accuracy ({p2p_type})") + plt.xlabel("# of Mix Layers") + plt.ylabel("Median of Accuracy (%)") + plt.ylim(0, 105) + plt.legend() + plt.tight_layout() + plt.show() + + +if __name__ == "__main__": + bulk_attack() diff --git a/mixnet/v2/sim/bulk_run_cover.py b/mixnet/v2/sim/bulk_run_cover.py new file mode 100644 index 0000000..65a5fed --- /dev/null +++ b/mixnet/v2/sim/bulk_run_cover.py @@ -0,0 +1,82 @@ +import argparse +from datetime import datetime + +import pandas as pd +from matplotlib import pyplot as plt + +from config import P2PConfig, Config +from simulation import Simulation + + +def bulk_run_cover(): + parser = argparse.ArgumentParser(description="Run simulation", + formatter_class=argparse.ArgumentDefaultsHelpFormatter) + parser.add_argument("--config", type=str, required=True, help="Configuration file path") + args = parser.parse_args() + config = Config.load(args.config) + + results = [] + + config.simulation.running_time = 30 + config.mixnet.num_nodes = 1000 + config.mixnet.num_mix_layers = 4 + config.mixnet.payload_size = 320 + config.mixnet.message_interval = 1 + config.mixnet.real_message_prob = 0.01 + config.mixnet.real_message_prob_weights = [] + config.mixnet.max_message_prep_time = 0 + config.mixnet.max_mix_delay = 0 + config.p2p.type = P2PConfig.TYPE_GOSSIP + config.p2p.connection_density = 6 + config.p2p.max_network_latency = 0.20 + config.measurement.sim_time_per_second = 1 + + base = config.mixnet.real_message_prob * 2 + for cover_message_prob in [base, base * 2, base * 3, base * 4, base * 5]: + config.mixnet.cover_message_prob = cover_message_prob + + sim = Simulation(config) + sim.run() + + egress, ingress = sim.p2p.measurement.bandwidth() + results.append({ + "num_nodes": config.mixnet.num_nodes, + "num_mix_layers": config.mixnet.num_mix_layers, + "p2p_type": config.p2p.type, + "real_message_prob": config.mixnet.real_message_prob, + "cover_message_prob": cover_message_prob, + "egress_mean": 
egress.mean(), + "egress_max": egress.max(), + "ingress_mean": ingress.mean(), + "ingress_max": ingress.max(), + }) + + df = pd.DataFrame(results) + df.to_csv(f"{datetime.now().replace(microsecond=0).isoformat()}.csv", index=False) + draw_plot(df) + + +def load_and_plot(): + # with skipping the header + df = pd.read_csv("2024-05-27T14:14:58.csv") + print(df) + draw_plot(df) + + +def draw_plot(df: pd.DataFrame): + plt.plot(df["cover_message_prob"], df["egress_mean"], label="Egress Mean", marker="o") + plt.plot(df["cover_message_prob"], df["egress_max"], label="Egress Max", marker="x") + plt.plot(df["cover_message_prob"], df["ingress_mean"], label="Ingress Mean", marker="v") + plt.plot(df["cover_message_prob"], df["ingress_max"], label="Ingress Max", marker="^") + + plt.xlabel("Cover Emission Rate") + plt.ylabel("Bandwidth (KiB/s)") + plt.title("Bandwidth vs Cover Emission Rate") + plt.legend() + plt.grid(True) + plt.show() + + +if __name__ == "__main__": + bulk_run_cover() + # load_and_plot() diff --git a/mixnet/v2/sim/bulk_run_num_nodes.py b/mixnet/v2/sim/bulk_run_num_nodes.py new file mode 100644 index 0000000..5f8a1f1 --- /dev/null +++ b/mixnet/v2/sim/bulk_run_num_nodes.py @@ -0,0 +1,103 @@ +import argparse +from datetime import datetime + +import pandas as pd +from matplotlib import pyplot as plt + +from config import P2PConfig, Config +from simulation import Simulation + +# https://matplotlib.org/stable/api/markers_api.html +MARKERS = ['o', 'x', 'v', '^', '<', '>'] +NUM_NODES_SET = [10, 200, 400, 600, 800, 1000] +NUM_MIX_LAYERS_SET = [0, 2, 4] + + +def bulk_run(): + parser = argparse.ArgumentParser(description="Run simulation", + formatter_class=argparse.ArgumentDefaultsHelpFormatter) + parser.add_argument("--config", type=str, required=True, help="Configuration file path") + args = parser.parse_args() + config = Config.load(args.config) + + results = [] + + config.simulation.running_time = 30 + config.mixnet.payload_size = 320 + config.mixnet.real_message_prob 
= 0.01 + config.mixnet.real_message_prob_weights = [] + config.mixnet.max_message_prep_time = 0 + config.mixnet.max_mix_delay = 0 + + for num_nodes in NUM_NODES_SET: + config.mixnet.num_nodes = num_nodes + + for p2p_type in [P2PConfig.TYPE_GOSSIP]: + config.p2p.type = p2p_type + + for num_mix_layers in NUM_MIX_LAYERS_SET: + config.mixnet.num_mix_layers = num_mix_layers + + for cover_message_prob in [0.0, config.mixnet.real_message_prob * 2]: + config.mixnet.cover_message_prob = cover_message_prob + + sim = Simulation(config) + sim.run() + + egress, ingress = sim.p2p.measurement.bandwidth() + results.append({ + "num_nodes": num_nodes, + "config": f"{p2p_type}: {num_mix_layers}: {cover_message_prob}", + "egress_mean": egress.mean(), + "egress_max": egress.max(), + "ingress_mean": ingress.mean(), + "ingress_max": ingress.max(), + }) + + df = pd.DataFrame(results) + df.to_csv(f"{datetime.now().replace(microsecond=0).isoformat()}.csv", index=False) + draw_plots(df) + + +def load_and_plot(): + # with skipping the header + df = pd.read_csv("2024-05-25T23:16:39.csv") + print(df) + draw_plots(df) + + +def draw_plots(df: pd.DataFrame): + max_ylim = draw_plot(df, "num_nodes", "config", "egress_max", "Egress Bandwidth (Max)", + "Number of Nodes", "Max Bandwidth (KiB/s)") + draw_plot(df, "num_nodes", "config", "egress_mean", "Egress Bandwidth (Mean)", + "Number of Nodes", "Mean Bandwidth (KiB/s)", max_ylim) + + max_ylim = draw_plot(df, "num_nodes", "config", "ingress_max", "Ingress Bandwidth (Max)", + "Number of Nodes", "Max Bandwidth (KiB/s)") + draw_plot(df, "num_nodes", "config", "ingress_mean", "Ingress Bandwidth (Mean)", + "Number of Nodes", "Mean Bandwidth (KiB/s)", max_ylim) + + +def draw_plot(df: pd.DataFrame, index: str, column: str, value: str, title: str, xlabel: str, ylabel: str, + ylim: float = None) -> float: + df_pivot = df.pivot(index=index, columns=column, values=value) + plt.figure(figsize=(12, 6)) + fig, ax = plt.subplots() + for i, config in 
enumerate(df_pivot.columns): + marker = MARKERS[NUM_MIX_LAYERS_SET.index(int(config.split(":")[1].strip()))] + ax.plot(df_pivot.index, df_pivot[config], label=config, marker=marker) + plt.title(title) + plt.xlabel(xlabel) + plt.ylabel(ylabel) + plt.legend(title="mode: layers: cover", loc="upper left") + plt.tight_layout() + plt.grid(True) + if ylim is not None: + ax.set_ylim(ylim) + plt.show() + return ax.get_ylim() + + +if __name__ == "__main__": + # bulk_run() + load_and_plot() diff --git a/mixnet/v2/sim/config.py b/mixnet/v2/sim/config.py new file mode 100644 index 0000000..55b7d25 --- /dev/null +++ b/mixnet/v2/sim/config.py @@ -0,0 +1,155 @@ +from __future__ import annotations + +import random +from dataclasses import dataclass +from typing import Self + +import dacite +import yaml + +from environment import Time + + +@dataclass +class Config: + simulation: SimulationConfig + mixnet: MixnetConfig + p2p: P2PConfig + measurement: MeasurementConfig + adversary: AdversaryConfig + + @classmethod + def load(cls, yaml_path: str) -> Self: + with open(yaml_path, "r") as f: + data = yaml.safe_load(f) + config = dacite.from_dict(data_class=Config, data=data) + + # Validations + config.simulation.validate() + config.mixnet.validate() + config.p2p.validate() + config.measurement.validate() + config.adversary.validate() + + return config + + def description(self): + return ( + f"{self.mixnet.description()}\n" + f"{self.p2p.description()}" + ) + + +@dataclass +class SimulationConfig: + running_time: Time + + def validate(self): + assert self.running_time > 0 + + +@dataclass +class MixnetConfig: + num_nodes: int + num_mix_layers: int + # A size of a message payload in bytes (e.g. the size of a block proposal) + payload_size: int + # An interval of sending a new real/cover message + # A probability of actually sending a message depends on the following parameters. 
+ message_interval: Time + # A probability of sending a real message within one cycle + real_message_prob: float + # A weight of real message emission probability of some nodes + # Each weight is multiplied to the real_message_prob of the node being at the same position in the node list. + # The length of the list should be <= num_nodes. i.e. some nodes won't have a weight. + real_message_prob_weights: list[float] + # A probability of sending a cover message within one cycle if not sending a real message + cover_message_prob: float + # A maximum preparation time (computation time) for a message sender before sending the message + max_message_prep_time: Time + # A maximum delay of messages mixed in a mix node + min_mix_delay: Time + max_mix_delay: Time + + def validate(self): + assert self.num_nodes > 0 + assert 0 <= self.num_mix_layers <= self.num_nodes + assert self.payload_size > 0 + assert self.message_interval > 0 + assert self.real_message_prob > 0 + assert len(self.real_message_prob_weights) <= self.num_nodes + for weight in self.real_message_prob_weights: + assert weight >= 1 + assert self.cover_message_prob >= 0 + assert self.max_message_prep_time >= 0 + assert 0 <= self.min_mix_delay <= self.max_mix_delay + + def description(self): + return ( + f"payload: {self.payload_size} bytes\n" + f"num_nodes: {self.num_nodes}\n" + f"num_mix_layers: {self.num_mix_layers}\n" + f"min_mix_delay: {self.min_mix_delay}\n" + f"max_mix_delay: {self.max_mix_delay}\n" + f"msg_interval: {self.message_interval}\n" + f"real_msg_prob: {self.real_message_prob:.2f}\n" + f"cover_msg_prob: {self.cover_message_prob:.2f}" + ) + + def is_mixing_on(self) -> bool: + return self.num_mix_layers > 0 + + def random_mix_delay(self) -> Time: + return random.randint(self.min_mix_delay, self.max_mix_delay) + + +@dataclass +class P2PConfig: + # Broadcasting type: 1-to-all | gossip + type: str + # A connection density, only if the type is gossip + connection_density: int + # A maximum network latency 
between nodes directly connected with each other + min_network_latency: Time + max_network_latency: Time + + TYPE_ONE_TO_ALL = "1-to-all" + TYPE_GOSSIP = "gossip" + + def validate(self): + assert self.type in [self.TYPE_ONE_TO_ALL, self.TYPE_GOSSIP] + if self.type == self.TYPE_GOSSIP: + assert self.connection_density > 0 + assert 0 < self.min_network_latency <= self.max_network_latency + + def description(self): + return ( + f"p2p_type: {self.type}\n" + f"conn_density: {self.connection_density}\n" + f"min_net_latency: {self.min_network_latency:.2f}\n" + f"max_net_latency: {self.max_network_latency:.2f}" + ) + + def random_network_latency(self) -> Time: + return random.randint(self.min_network_latency, self.max_network_latency) + + +@dataclass +class MeasurementConfig: + # How many times in simulation represent 1 second in real time + sim_time_per_second: Time + + def validate(self): + assert self.sim_time_per_second > 0 + + +@dataclass +class AdversaryConfig: + timing_attack_timeout: int + timing_attack_max_targets: int + timing_attack_max_pool_size: int + + def validate(self): + assert self.timing_attack_timeout > 0 + assert self.timing_attack_max_targets > 0 + assert self.timing_attack_max_pool_size > 0 diff --git a/mixnet/v2/sim/config.yaml b/mixnet/v2/sim/config.yaml new file mode 100644 index 0000000..8c653dd --- /dev/null +++ b/mixnet/v2/sim/config.yaml @@ -0,0 +1,45 @@ +simulation: + # The simulation uses a virtual time. Please see README for more details. + running_time: 300 + +mixnet: + num_nodes: 100 + # The number of mix nodes, selected by a message sender, that the Sphinx message passes through + # If 0, the message is broadcast directly to all nodes without being Sphinx-encoded. + num_mix_layers: 2 + # A size of a message payload in bytes (e.g. the size of a block proposal) + payload_size: 320 + # An interval of sending a new real/cover message + # A probability of actually sending a message depends on the following parameters. 
+ message_interval: 10 + # A probability of sending a real message within a cycle + real_message_prob: 0.01 + # A weight of real message emission probability of some nodes + # Each weight is multiplied to the real_message_prob of the node being at the same position in the node list. + # The length of the list should be <= mixnet.num_nodes, i.e. some nodes won't have a weight. + real_message_prob_weights: [ ] + # A probability of sending a cover message within a cycle if not sending a real message + cover_message_prob: 0.00 + # A maximum preparation time (computation time) for a message sender before sending the message + max_message_prep_time: 0 + # A minimum and maximum delay of messages mixed in a mix node + min_mix_delay: 0 + max_mix_delay: 0 + +p2p: + # Broadcasting type: 1-to-all | gossip + type: "1-to-all" + # A connection density, only if the type is gossip + connection_density: 6 + # A minimum and maximum network latency between nodes directly connected with each other + min_network_latency: 1 + max_network_latency: 1 + +measurement: + # How many units of simulation time represent 1 second in real time + sim_time_per_second: 10 + +adversary: + timing_attack_timeout: 300 + timing_attack_max_targets: 10000000000 + timing_attack_max_pool_size: 100 \ No newline at end of file diff --git a/mixnet/v2/sim/environment.py b/mixnet/v2/sim/environment.py new file mode 100644 index 0000000..e110370 --- /dev/null +++ b/mixnet/v2/sim/environment.py @@ -0,0 +1,22 @@ +from typing import Optional, Any + +import simpy + +Time = int + + +class Environment: + def __init__(self): + self.env = simpy.Environment() + + def now(self) -> Time: + return Time(self.env.now) + + def run(self, until: Time) -> Optional[Any]: + return self.env.run(until=until) + + def timeout(self, timeout: Time) -> simpy.Timeout: + return self.env.timeout(timeout) + + def process(self, generator: simpy.events.ProcessGenerator) -> simpy.Process: + return self.env.process(generator) diff --git a/mixnet/v2/sim/main.py 
b/mixnet/v2/sim/main.py new file mode 100644 index 0000000..e22bece --- /dev/null +++ b/mixnet/v2/sim/main.py @@ -0,0 +1,18 @@ +import argparse + +from config import Config +from analysis import Analysis +from simulation import Simulation + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Run simulation", formatter_class=argparse.ArgumentDefaultsHelpFormatter) + parser.add_argument("--config", type=str, required=True, help="Configuration file path") + args = parser.parse_args() + + config = Config.load(args.config) + sim = Simulation(config) + sim.run() + + Analysis(sim, config).run() + + print("Simulation complete!") \ No newline at end of file diff --git a/mixnet/v2/sim/measurement.py b/mixnet/v2/sim/measurement.py new file mode 100644 index 0000000..2b8bd3c --- /dev/null +++ b/mixnet/v2/sim/measurement.py @@ -0,0 +1,57 @@ +from collections import defaultdict, Counter +from typing import TYPE_CHECKING + +import pandas as pd + +from config import Config +from environment import Environment +from sphinx import SphinxPacket + +if TYPE_CHECKING: + from node import Node + + +class Measurement: + def __init__(self, env: Environment, config: Config): + self.env = env + self.config = config + self.original_senders = Counter() + self.egress_bandwidth_per_sec = [] + self.ingress_bandwidth_per_sec = [] + self.message_hops = defaultdict(int) # dict[msg_hash, hops] + + self.env.process(self._update_bandwidth_window()) + + def set_nodes(self, nodes: list["Node"]): + for node in nodes: + self.original_senders[node] = 0 + + def count_original_sender(self, sender: "Node"): + self.original_senders.update({sender}) + + def measure_egress(self, node: "Node", msg: SphinxPacket | bytes): + self.egress_bandwidth_per_sec[-1][node] += len(msg) + + def measure_ingress(self, node: "Node", msg: SphinxPacket | bytes): + self.ingress_bandwidth_per_sec[-1][node] += len(msg) + + def update_message_hops(self, msg_hash: bytes, hops: int): + self.message_hops[msg_hash] = 
max(hops, self.message_hops[msg_hash]) + + def _update_bandwidth_window(self): + while True: + self.ingress_bandwidth_per_sec.append(defaultdict(int)) + self.egress_bandwidth_per_sec.append(defaultdict(int)) + yield self.env.timeout(self.config.measurement.sim_time_per_second) + + def bandwidth(self) -> (pd.Series, pd.Series): + nonzero_egresses, nonzero_ingresses = [], [] + for egress_bandwidths, ingress_bandwidths in zip(self.egress_bandwidth_per_sec, + self.ingress_bandwidth_per_sec): + for bandwidth in egress_bandwidths.values(): + if bandwidth > 0: + nonzero_egresses.append(bandwidth / 1024.0) + for bandwidth in ingress_bandwidths.values(): + if bandwidth > 0: + nonzero_ingresses.append(bandwidth / 1024.0) + return pd.Series(nonzero_egresses), pd.Series(nonzero_ingresses) diff --git a/mixnet/v2/sim/node.py b/mixnet/v2/sim/node.py new file mode 100644 index 0000000..efe652f --- /dev/null +++ b/mixnet/v2/sim/node.py @@ -0,0 +1,157 @@ +from __future__ import annotations + +import os +import random +from enum import Enum + +from cryptography.hazmat.primitives import serialization +from cryptography.hazmat.primitives.asymmetric.x25519 import X25519PrivateKey, X25519PublicKey + +from config import Config +from environment import Environment +from measurement import Measurement +from p2p import P2P +from sphinx import SphinxPacket, Attachment + + +class Node: + INCENTIVE_TX_SIZE = 512 + PADDING_SEPARATOR = b'\x01' + + def __init__(self, id: int, env: Environment, p2p: P2P, config: Config, measurement: Measurement, + operated_by_adversary: bool = False): + self.id = id + self.env = env + self.p2p = p2p + self.private_key = X25519PrivateKey.generate() + self.public_key = self.private_key.public_key() + self.config = config + self.payload_id = 0 + self.measurement = measurement + self.operated_by_adversary = operated_by_adversary + self.action = self.env.process(self.send_message()) + + def send_message(self): + """ + Creates/encapsulate a message and send it to the 
network through the mixnet + """ + while True: + yield self.env.timeout(self.config.mixnet.message_interval) + + message_type = self.message_type_to_send() + if message_type is None: # nothing to send in this turn + continue + elif message_type == MessageType.REAL: + self.measurement.count_original_sender(self) + + msg = self.create_message(message_type) + prep_time = random.randint(0, self.config.mixnet.max_message_prep_time) + yield self.env.timeout(prep_time) + + self.log("Sending a message to the mixnet") + self.env.process(self.p2p.broadcast(self, msg)) + + def message_type_to_send(self) -> MessageType | None: + rnd = random.random() + if rnd < self.real_message_prob(): + return MessageType.REAL + elif rnd < self.config.mixnet.cover_message_prob: + return MessageType.COVER + else: + return None + + def real_message_prob(self): + weight = self.config.mixnet.real_message_prob_weights[self.id] \ + if self.id < len(self.config.mixnet.real_message_prob_weights) else 1 + return self.config.mixnet.real_message_prob * weight + + def create_message(self, message_type: MessageType) -> SphinxPacket | bytes: + """ + Creates a real or cover message + @return: + """ + if not self.config.mixnet.is_mixing_on(): + return self.build_payload() + + mixes = self.p2p.get_nodes(self.config.mixnet.num_mix_layers, self.id) + public_keys = [mix.public_key for mix in mixes] + # TODO: replace with realistic tx + incentive_txs = [Node.create_incentive_tx(mix.public_key) for mix in mixes] + if message_type == MessageType.COVER: + # Set invalid txs for a cover message, + # so that nobody will recognize that as a real message to be forwarded to the next mix. 
+ incentive_txs = [Attachment(os.urandom(len(bytes(tx)))) for tx in incentive_txs] + return SphinxPacket(public_keys, incentive_txs, self.build_payload()) + + def receive_message(self, msg: SphinxPacket | bytes): + """ + Receives a message from the network, processes it, + and forwards it to the next mix or the entire network if necessary. + @param msg: the message to be processed + """ + if isinstance(msg, SphinxPacket): + msg, incentive_tx = msg.unwrap(self.private_key) + if self.is_my_incentive_tx(incentive_tx): + # self.log("Receiving SphinxPacket. It's mine!") + if msg.is_all_unwrapped(): + final_padded_msg = self.pad_payload(msg.payload, len(msg)) + self.env.process(self.p2p.broadcast(self, final_padded_msg)) + else: + # TODO: use Poisson delay or something else, if necessary + yield self.env.timeout(random.randint(0, self.config.mixnet.max_mix_delay)) + self.env.process(self.p2p.broadcast(self, msg)) + # else: + # self.log("Receiving SphinxPacket, but not mine") + else: + final_msg = msg[:msg.rfind(self.PADDING_SEPARATOR)] + self.log("Received final message: %s" % final_msg) + + def inspect_message(self, msg: SphinxPacket | bytes) -> int | None: + """ + Inspects the message if the node is operated by adversary. 
+ @param msg: SphinxPacket or final unwrapped message + @return: Origin Node ID, or None if the node is not operated by adversary + """ + if self.operated_by_adversary and isinstance(msg, bytes): + origin_id, _ = Node.parse_payload(msg) + return origin_id + return None + + def build_payload(self) -> bytes: + payload = bytes(f"{self.id}-{self.payload_id}-", "utf-8") + self.payload_id += 1 + return payload + bytes(self.config.mixnet.payload_size - len(payload)) + + @staticmethod + def parse_payload(payload: bytes) -> (int, int): + parts = payload.split(b"-") + node_id, payload_id = int(parts[0]), int(parts[1]) + return node_id, payload_id + + def pad_payload(self, payload: bytes, target_size: int) -> bytes: + """ + Pad the final msg to the target size (e.g. the same size as a SphinxPacket), + assuming that the final msg is going to be sent via secure channels (TLS, Noise, etc.) + """ + return (payload + + self.PADDING_SEPARATOR + + bytes(target_size - len(payload) - len(self.PADDING_SEPARATOR))) + + # TODO: This is a dummy logic + @classmethod + def create_incentive_tx(cls, mix_public_key: X25519PublicKey) -> Attachment: + public_key = mix_public_key.public_bytes(encoding=serialization.Encoding.Raw, + format=serialization.PublicFormat.Raw) + public_key += bytes(cls.INCENTIVE_TX_SIZE - len(public_key)) + return Attachment(public_key) + + def is_my_incentive_tx(self, tx: Attachment) -> bool: + return tx == Node.create_incentive_tx(self.public_key) + + def log(self, msg): + print(f"t={self.env.now():.3f}: Node:{self.id}: {msg}") + + +class MessageType(Enum): + REAL = 0 + COVER = 1 diff --git a/mixnet/v2/sim/p2p.py b/mixnet/v2/sim/p2p.py new file mode 100644 index 0000000..b5d13ae --- /dev/null +++ b/mixnet/v2/sim/p2p.py @@ -0,0 +1,146 @@ +from __future__ import annotations + +import hashlib +import random +from abc import ABC, abstractmethod +from collections import defaultdict +from typing import TYPE_CHECKING + +from adversary import Adversary +from config import 
Config +from environment import Environment, Time +from measurement import Measurement +from sphinx import SphinxPacket + +if TYPE_CHECKING: + from node import Node + + +class P2P(ABC): + def __init__(self, env: Environment, config: Config): + self.env = env + self.config = config + self.nodes = [] + self.measurement = Measurement(env, config) + self.adversary = Adversary(env, config) + + def set_nodes(self, nodes: list["Node"]): + self.nodes = nodes + self.measurement.set_nodes(nodes) + + def get_nodes(self, n: int, exclude_node_id: int) -> list["Node"]: + candidates = self.nodes[:exclude_node_id] + self.nodes[exclude_node_id + 1:] + return random.sample(candidates, n) + + # This should accept only bytes in practice, + # but we accept SphinxPacket as well because we don't implement Sphinx deserialization. + @abstractmethod + def broadcast(self, sender: "Node", msg: SphinxPacket | bytes, hops_traveled: int = 0): + # Yield 0 to ensure that the broadcast is done in the same time step. + # Without any yield, SimPy complains that the broadcast func is not a generator. 
+ yield self.env.timeout(0) + + def send(self, msg: SphinxPacket | bytes, hops_traveled: int, sender: "Node", receiver: "Node", + is_first_of_broadcasting: bool): + time_sent = self.env.now() + if sender != receiver: + if is_first_of_broadcasting: + self.adversary.inspect_message_size(msg) + self.adversary.observe_sending_node(sender) + self.measurement.measure_egress(sender, msg) + + # simulate network latency + yield self.env.timeout(self.config.p2p.random_network_latency()) + + self.measurement.measure_ingress(receiver, msg) + self.adversary.observe_receiving_node(sender, receiver, time_sent) + self.receive(msg, hops_traveled + 1, sender, receiver, time_sent) + + @abstractmethod + def receive(self, msg: SphinxPacket | bytes, hops_traveled: int, sender: "Node", receiver: "Node", + time_sent: Time): + pass + + def log(self, msg): + print(f"t={self.env.now():.3f}: P2P: {msg}") + + +class NaiveBroadcastP2P(P2P): + def __init__(self, env: Environment, config: Config): + super().__init__(env, config) + self.nodes = [] + + # This should accept only bytes in practice, + # but we accept SphinxPacket as well because we don't implement Sphinx deserialization. 
+ def broadcast(self, sender: "Node", msg: SphinxPacket | bytes, hops_traveled: int = 0): + yield from super().broadcast(sender, msg) + + self.log(f"Node:{sender.id}: Broadcasting a msg: {len(msg)} bytes") + for i, receiver in enumerate(self.nodes): + self.env.process(self.send(msg, 0, sender, receiver, i == 0)) + + def receive(self, msg: SphinxPacket | bytes, hops_traveled: int, sender: "Node", receiver: "Node", + time_sent: Time): + msg_hash = hashlib.sha256(bytes(msg)).digest() + self.measurement.update_message_hops(msg_hash, hops_traveled) + self.adversary.observe_if_final_msg(sender, receiver, time_sent, msg) + self.env.process(receiver.receive_message(msg)) + + +class GossipP2P(P2P): + def __init__(self, env: Environment, config: Config): + super().__init__(env, config) + self.topology = defaultdict(set) + self.message_cache = defaultdict(dict) # dict[receiver, dict[msg_hash, sender]] + + def set_nodes(self, nodes: list["Node"]): + super().set_nodes(nodes) + for i, node in enumerate(nodes): + # Each node is chained with the right neighbor, so that no node is not orphaned. + # And then, each node is connected to a random subset of other nodes. + front, back = nodes[:i], nodes[i + 1:] + if len(back) > 0: + neighbor = back[0] + back = back[1:] + elif len(front) > 0: + neighbor = front[0] + front = front[1:] + else: + continue + + others = front + back + n = min(self.config.p2p.connection_density - 1, len(others)) + conns = set(random.sample(others, n)) + conns.add(neighbor) + self.topology[node] = conns + + def broadcast(self, sender: "Node", msg: SphinxPacket | bytes, hops_traveled: int = 0): + yield from super().broadcast(sender, msg) + self.log(f"Node:{sender.id}: Gossiping a msg: {len(msg)} bytes") + + # if the msg is created originally by the sender (not forwarded from others), cache it with the sender itself. 
+ msg_hash = hashlib.sha256(bytes(msg)).digest() + if msg_hash not in self.message_cache[sender]: + self.message_cache[sender][msg_hash] = sender + + cnt = 0 + for receiver in self.topology[sender]: + # Don't gossip the message if it was received from the node who is going to be the receiver, + # which means that the node already knows the message. + if receiver != self.message_cache[sender][msg_hash]: + self.env.process(self.send(msg, hops_traveled, sender, receiver, cnt == 0)) + cnt += 1 + + def receive(self, msg: SphinxPacket | bytes, hops_traveled: int, sender: "Node", receiver: "Node", + time_sent: Time): + # Receive/gossip the msg only if it hasn't been received before. If not, just ignore the msg. + # i.e. each message is received/gossiped at most once by each node. + msg_hash = hashlib.sha256(bytes(msg)).digest() + if msg_hash not in self.message_cache[receiver]: + self.message_cache[receiver][msg_hash] = sender + self.measurement.update_message_hops(msg_hash, hops_traveled) + self.adversary.observe_if_final_msg(sender, receiver, time_sent, msg) + + # Receive and gossip + self.env.process(receiver.receive_message(msg)) + self.env.process(self.broadcast(receiver, msg, hops_traveled)) diff --git a/mixnet/v2/sim/simulation.py b/mixnet/v2/sim/simulation.py new file mode 100644 index 0000000..c174ea5 --- /dev/null +++ b/mixnet/v2/sim/simulation.py @@ -0,0 +1,32 @@ +import random + +import simpy + +from config import Config, P2PConfig +from environment import Environment +from node import Node +from p2p import NaiveBroadcastP2P, GossipP2P + + +class Simulation: + def __init__(self, config: Config): + random.seed() + self.config = config + self.env = Environment() + self.p2p = Simulation.init_p2p(self.env, config) + nodes = [Node(i, self.env, self.p2p, config, self.p2p.measurement, i == 0) for i in + range(config.mixnet.num_nodes)] + self.p2p.set_nodes(nodes) + + def run(self): + self.env.run(until=self.config.simulation.running_time) + + @classmethod + def 
class SphinxPacket:
    """Simplified Sphinx packet: a header of per-hop attachments plus a payload.

    The payload is carried as-is; encryption is still a TODO.
    """
    # TODO: define max path length

    def __init__(self, public_keys: "list[X25519PublicKey]", attachments: "list[Attachment]", payload: bytes):
        """Build a packet with one attachment per public key.

        :param public_keys: public keys, one per attachment.
        :param attachments: attachments, in the order they will be unwrapped.
        :param payload: raw payload bytes.
        :raises AssertionError: if the two lists differ in length.
        """
        assert len(public_keys) == len(attachments), "one attachment is required per public key"

        ephemeral_private_key = X25519PrivateKey.generate()
        ephemeral_public_key = ephemeral_private_key.public_key()
        shared_keys = [SharedSecret(ephemeral_private_key, pk) for pk in public_keys]
        self.header = SphinxHeader(ephemeral_public_key, shared_keys, attachments)
        self.payload = payload  # TODO: encrypt payload

    def __bytes__(self):
        """Serialise as header bytes followed by the payload."""
        return bytes(self.header) + self.payload

    def __len__(self):
        """Return the size of the serialised packet in bytes."""
        return len(bytes(self))

    def unwrap(self, private_key: "X25519PrivateKey") -> "tuple[SphinxPacket, Attachment]":
        """Return a copy of this packet with one layer removed, and that layer's attachment.

        The packet itself is left untouched; the unwrapping happens on a deep copy.
        """
        packet = deepcopy(self)
        attachment = packet.header.unwrap_inplace(private_key)
        # TODO: decrypt packet._payload
        return packet, attachment

    def is_all_unwrapped(self) -> bool:
        """Return True once the header has no real attachment left to unwrap."""
        return self.header.is_all_unwrapped()


class SphinxHeader:
    """Packet header: the raw ephemeral public key followed by the attachments."""

    # Placeholder appended after every attachment when the header is serialised.
    DUMMY_MAC = b'\xFF' * 16

    def __init__(self, ephemeral_public_key: "X25519PublicKey", shared_keys: "list[SharedSecret]",
                 attachments: "list[Attachment]"):
        """Store the raw public key bytes and the attachments.

        ``shared_keys`` is only checked for length for now (see the TODO below).

        :raises AssertionError: if ``shared_keys`` and ``attachments`` differ in length.
        """
        assert len(shared_keys) == len(attachments), "one shared key is required per attachment"
        self.ephemeral_public_key = ephemeral_public_key.public_bytes(
            encoding=serialization.Encoding.Raw,
            format=serialization.PublicFormat.Raw,
        )
        # Keep a private copy: unwrap_inplace() pops from and appends to this
        # list, which must not modify the list object owned by the caller.
        self.attachments = list(attachments)  # TODO: encapsulation using node_keys

    def __bytes__(self):
        return b"".join([self.ephemeral_public_key] + [bytes(att) + self.DUMMY_MAC for att in self.attachments])

    @staticmethod
    def _dummy_like(attachment: "Attachment") -> "Attachment":
        """Return a zero-filled attachment of the same byte length as ``attachment``."""
        return Attachment(bytes(len(bytes(attachment))))

    def unwrap_inplace(self, private_key: "X25519PrivateKey") -> "Attachment":
        """Remove and return the first attachment, padding the tail with a dummy.

        The dummy keeps the number of attachments, and therefore the serialised
        header length, unchanged.

        :raises IndexError: if the header holds no attachments.
        """
        # TODO: shared_secret = SharedSecret(private_key, header.ephemeral_public_key)
        if not self.attachments:
            raise IndexError("no attachment left to unwrap")
        attachment = self.attachments.pop(0)
        self.attachments.append(self._dummy_like(attachment))  # append a dummy attachment
        return attachment

    def is_all_unwrapped(self) -> bool:
        """Return True if the first attachment is a dummy, or there are no attachments.

        A dummy is recognised purely by being zero-filled, so a real attachment
        consisting only of zero bytes cannot be told apart from one.
        """
        if not self.attachments:
            # Nothing to unwrap at all (previously an IndexError).
            return True
        first = self.attachments[0]
        return first == self._dummy_like(first)


class SharedSecret:
    """Result of an X25519 key exchange between a private and a public key."""

    def __init__(self, private_key: "X25519PrivateKey", public_key: "X25519PublicKey"):
        self.key = private_key.exchange(public_key)  # 32 bytes

    def __bytes__(self):
        return self.key


class Attachment:
    """Opaque per-hop data carried in a Sphinx header."""

    def __init__(self, data: bytes):
        self.data = data

    def __bytes__(self):
        return self.data

    def __repr__(self):
        return f"{type(self).__name__}({self.data!r})"

    def __eq__(self, other):
        """Compare by byte content with bytes-like objects or objects defining ``__bytes__``.

        Any other operand yields ``NotImplemented``. Calling ``bytes(other)``
        unconditionally is wrong: ``bytes(3)`` is three zero bytes, so an
        integer would compare equal to a zero-filled attachment, and
        ``bytes("text")`` raises ``TypeError``.
        """
        if isinstance(other, (bytes, bytearray, memoryview)) or hasattr(type(other), "__bytes__"):
            return bytes(self) == bytes(other)
        return NotImplemented

    # Defining __eq__ without __hash__ already makes instances unhashable;
    # stated explicitly so that it is clearly intentional.
    __hash__ = None