From 0fcb195c38581090318608272bceba6fd4fb592b Mon Sep 17 00:00:00 2001 From: Youngjoon Lee <5462944+youngjoon-lee@users.noreply.github.com> Date: Tue, 11 Jun 2024 16:37:57 +0900 Subject: [PATCH] fix timing attack by deprecating window --- mixnet/v2/sim/adversary.py | 51 +++++++++++----------- mixnet/v2/sim/analysis.py | 85 +++++++++++++++++------------------- mixnet/v2/sim/config.py | 37 ++++++---------- mixnet/v2/sim/config.yaml | 21 ++++----- mixnet/v2/sim/environment.py | 22 ++++++++++ mixnet/v2/sim/measurement.py | 20 ++++----- mixnet/v2/sim/node.py | 12 ++--- mixnet/v2/sim/p2p.py | 34 ++++++++------- mixnet/v2/sim/simulation.py | 3 +- 9 files changed, 146 insertions(+), 139 deletions(-) create mode 100644 mixnet/v2/sim/environment.py diff --git a/mixnet/v2/sim/adversary.py b/mixnet/v2/sim/adversary.py index 9e2591f..6ce63ef 100644 --- a/mixnet/v2/sim/adversary.py +++ b/mixnet/v2/sim/adversary.py @@ -4,10 +4,8 @@ from collections import defaultdict, deque, Counter from enum import Enum from typing import TYPE_CHECKING -import simpy -from simpy.core import SimTime - from config import Config +from environment import Environment, Time from sphinx import SphinxPacket if TYPE_CHECKING: @@ -15,61 +13,62 @@ if TYPE_CHECKING: class Adversary: - def __init__(self, env: simpy.Environment, config: Config): + def __init__(self, env: Environment, config: Config): self.env = env self.config = config self.message_sizes = [] self.senders_around_interval = Counter() - self.msg_pools_per_window = [] # list[dict[receiver, deque[time_received])]] - self.msg_pools_per_window.append(defaultdict(lambda: deque())) - self.msgs_received_per_window = [] # list[dict[receiver, set[sender])]] - self.msgs_received_per_window.append(defaultdict(set)) - # dict[receiver, dict[window, list[(sender, origin_id)]]] + self.msg_pools_per_time = [] # list[dict[receiver, deque[time_received])]] + self.msg_pools_per_time.append(defaultdict(lambda: deque())) + self.msgs_received_per_time = [] # list[dict[receiver, dict[sender, list[time_sent]]]] + self.msgs_received_per_time.append(defaultdict(lambda: defaultdict(list))) + # dict[receiver, dict[time, list[(sender, time_sent, origin_id)]]] self.final_msgs_received = defaultdict(lambda: defaultdict(list)) # self.node_states = defaultdict(dict) - self.env.process(self.update_observation_window()) + self.env.process(self.update_observation_time()) def inspect_message_size(self, msg: SphinxPacket | bytes): self.message_sizes.append(len(msg)) - def observe_receiving_node(self, sender: "Node", receiver: "Node"): - self.msg_pools_per_window[-1][receiver].append(self.env.now) - self.msgs_received_per_window[-1][receiver].add(sender) + def observe_receiving_node(self, sender: "Node", receiver: "Node", time_sent: Time): + self.msg_pools_per_time[-1][receiver].append(self.env.now()) + self.msgs_received_per_time[-1][receiver][sender].append(time_sent) # if node not in self.node_states[self.env.now]: # self.node_states[self.env.now][node] = NodeState.RECEIVING def observe_sending_node(self, sender: "Node"): - msg_pool = self.msg_pools_per_window[-1][sender] + msg_pool = self.msg_pools_per_time[-1][sender] if len(msg_pool) > 0: # Adversary doesn't know which message in the pool is being emitted. So, pop the oldest one from the pool. msg_pool.popleft() - if self.is_around_message_interval(self.env.now): + if self.is_around_message_interval(self.env.now()): self.senders_around_interval.update({sender}) # self.node_states[self.env.now][node] = NodeState.SENDING - def observe_if_final_msg(self, sender: "Node", receiver: "Node", msg: SphinxPacket | bytes): + def observe_if_final_msg(self, sender: "Node", receiver: "Node", time_sent: Time, msg: SphinxPacket | bytes): origin_id = receiver.inspect_message(msg) if origin_id is not None: - cur_window = len(self.msgs_received_per_window) - 1 - self.final_msgs_received[receiver][cur_window].append((sender, origin_id)) + cur_time = len(self.msgs_received_per_time) - 1 + self.final_msgs_received[receiver][cur_time].append((sender, time_sent, origin_id)) - def is_around_message_interval(self, time: SimTime): + def is_around_message_interval(self, time: Time) -> bool: return time % self.config.mixnet.message_interval <= self.config.mixnet.max_message_prep_time - def update_observation_window(self): + def update_observation_time(self): while True: - yield self.env.timeout(self.config.adversary.window_size) + yield self.env.timeout(1) - self.msgs_received_per_window.append(defaultdict(set)) + self.msgs_received_per_time.append(defaultdict(lambda: defaultdict(list))) new_msg_pool = defaultdict(lambda: deque()) - for receiver, msg_queue in self.msg_pools_per_window[-1].items(): + for receiver, msg_queue in self.msg_pools_per_time[-1].items(): for time_received in msg_queue: - # If the message is likely to be still pending and be emitted soon, pass it on to the next window. - if self.env.now - time_received < self.config.mixnet.max_mix_delay: + # If the message is likely to be still pending and be emitted soon, + # pass it on to the next time slot. + if self.env.now() - time_received < self.config.mixnet.max_mix_delay: new_msg_pool[receiver][0].append(time_received) - self.msg_pools_per_window.append(new_msg_pool) + self.msg_pools_per_time.append(new_msg_pool) class NodeState(Enum): diff --git a/mixnet/v2/sim/analysis.py b/mixnet/v2/sim/analysis.py index c42d155..c44dc22 100644 --- a/mixnet/v2/sim/analysis.py +++ b/mixnet/v2/sim/analysis.py @@ -1,3 +1,4 @@ +import sys from collections import Counter from typing import TYPE_CHECKING @@ -7,6 +8,7 @@ from matplotlib import pyplot as plt from adversary import NodeState from config import Config +from environment import Time from simulation import Simulation if TYPE_CHECKING: @@ -43,8 +45,8 @@ class Analysis: dataframes = [] nonzero_egresses = [] nonzero_ingresses = [] - for egress_bandwidths, ingress_bandwidths in zip(self.sim.p2p.measurement.egress_bandwidth_per_time, - self.sim.p2p.measurement.ingress_bandwidth_per_time): + for egress_bandwidths, ingress_bandwidths in zip(self.sim.p2p.measurement.egress_bandwidth_per_sec, + self.sim.p2p.measurement.ingress_bandwidth_per_sec): rows = [] for node in self.sim.p2p.nodes: egress = egress_bandwidths[node] / 1024.0 @@ -129,11 +131,10 @@ class Analysis: def messages_in_node_over_time(self): dataframes = [] - for window, msg_pools in enumerate(self.sim.p2p.adversary.msg_pools_per_window): - time = window * self.config.adversary.window_size + for time, msg_pools in enumerate(self.sim.p2p.adversary.msg_pools_per_time): data = [] for receiver, msg_pool in msg_pools.items(): - senders = self.sim.p2p.adversary.msgs_received_per_window[window][receiver] + senders = self.sim.p2p.adversary.msgs_received_per_time[time][receiver].keys() data.append((time, receiver.id, len(msg_pool), len(senders))) df = pd.DataFrame(data, columns=[COL_TIME, COL_NODE_ID, COL_MSG_CNT, COL_SENDER_CNT]) if not df.empty: @@ -205,20 +206,18 @@ class Analysis: def timing_attack(self, hops_between_layers: int): hops_to_observe = hops_between_layers * (self.config.mixnet.num_mix_layers + 1) success_rates = [] - for receiver, windows_and_msgs in self.sim.p2p.adversary.final_msgs_received.items(): - for window, senders_and_origins in windows_and_msgs.items(): - for sender, origin_id in senders_and_origins: - print(f"START: receiver:{receiver.id}, window:{window}, sender:{sender.id}, origin:{origin_id}") + for receiver, times_and_msgs in self.sim.p2p.adversary.final_msgs_received.items(): + for time_received, msgs in times_and_msgs.items(): + for sender, time_sent, origin_id in msgs: suspected_origins = Counter() - self.timing_attack_with(receiver, window, hops_to_observe, 0, suspected_origins, sender) + self.timing_attack_with( + receiver, time_received, hops_to_observe, 0, suspected_origins, {sender: [time_sent]} + ) suspected_origin_ids = {node.id for node in suspected_origins} if origin_id in suspected_origin_ids: success_rate = 1 / len(suspected_origin_ids) * 100.0 else: success_rate = 0.0 - print( - f"END: origin:{origin_id}, suspected_origins:{suspected_origin_ids}, success_rate:{success_rate:.2f}%" - ) success_rates.append(success_rate) df = pd.DataFrame(success_rates, columns=[COL_SUCCESS_RATE]) @@ -237,43 +236,39 @@ class Analysis: plt.grid(True) plt.show() - def timing_attack_with(self, receiver: "Node", window: int, remaining_hops: int, observed_hops: int, - suspected_origins: Counter, - sender: "Node" = None): - assert remaining_hops >= 1 + def timing_attack_with(self, receiver: "Node", time_received: Time, + remaining_hops: int, observed_hops: int, suspected_origins: Counter, + senders: dict["Node", list[Time]] = None): + if remaining_hops <= 0: + return + # If all nodes are already suspected, no need to inspect further. if len(suspected_origins) == len(self.sim.p2p.nodes): return - # Start inspecting senders who sent messages that were arrived in the receiver at the given window. + # Start inspecting senders who sent messages that were arrived in the receiver at the given time. # If the specific sender is given, inspect only that sender to maximize the success rate. - if sender is not None: - senders = {sender} - else: - senders = self.sim.p2p.adversary.msgs_received_per_window[window][receiver] - - # Suspect the receiver as the origin, if the receiver has not received any messages at the given window, - # and if the minimum number of hops has been observed. - if len(senders) == 0 and observed_hops > self.sim.config.mixnet.num_mix_layers: - suspected_origins.update({receiver}) - return - - # If the remaining_hops is 1, return the senders as suspected senders - if remaining_hops == 1: - suspected_origins.update(senders) - return + if senders is None: + senders = self.sim.p2p.adversary.msgs_received_per_time[time_received][receiver] # Inspect each sender who sent messages to the receiver - for sender in senders: - # Track back to each window where that sender might have received any messages. - time_range = self.config.mixnet.max_mix_delay + self.config.p2p.max_network_latency - window_range = int(time_range / self.config.adversary.window_size) - for prev_window in range(window - 1, window - 1 - window_range, -1): - if prev_window < 0: - break - self.timing_attack_with(sender, prev_window, remaining_hops - 1, observed_hops + 1, suspected_origins) + for sender, times_sent in senders.items(): + # Calculate the time range where the sender might have received any messages + # related to the message being traced. + min_time, max_time = sys.maxsize, 0 + for time_sent in times_sent: + min_time = min(min_time, time_sent - self.config.mixnet.max_mix_delay) + max_time = max(max_time, time_sent - self.config.mixnet.min_mix_delay) + # If the sender is sent the message around the message interval, suspect the sender as the origin. + if (self.sim.p2p.adversary.is_around_message_interval(time_sent) + and observed_hops >= self.sim.config.mixnet.num_mix_layers): + suspected_origins.update({sender}) - @staticmethod - def print_nodes_per_hop(nodes_per_hop, starting_window: int): - for hop, nodes in enumerate(nodes_per_hop): - print(f"hop-{hop} from w-{starting_window}: {len(nodes)} nodes: {sorted([node.id for node in nodes])}") + # Track back to each time when that sender might have received any messages. + for time_sender_received in range(max_time, min_time - 1, -1): + if time_sender_received < 0: + break + self.timing_attack_with( + sender, time_sender_received, + remaining_hops - 1, observed_hops + 1, suspected_origins + ) diff --git a/mixnet/v2/sim/config.py b/mixnet/v2/sim/config.py index bf34b4c..e9548bc 100644 --- a/mixnet/v2/sim/config.py +++ b/mixnet/v2/sim/config.py @@ -7,6 +7,8 @@ from typing import Self import dacite import yaml +from environment import Time + @dataclass class Config: @@ -14,7 +16,6 @@ class Config: mixnet: MixnetConfig p2p: P2PConfig measurement: MeasurementConfig - adversary: AdversaryConfig @classmethod def load(cls, yaml_path: str) -> Self: @@ -27,7 +28,6 @@ class Config: config.mixnet.validate() config.p2p.validate() config.measurement.validate() - config.adversary.validate() return config @@ -40,7 +40,7 @@ class Config: @dataclass class SimulationConfig: - running_time: int + running_time: Time def validate(self): assert self.running_time > 0 @@ -54,7 +54,7 @@ class MixnetConfig: payload_size: int # An interval of sending a new real/cover message # A probability of actually sending a message depends on the following parameters. - message_interval: int + message_interval: Time # A probability of sending a real message within one cycle real_message_prob: float # A weight of real message emission probability of some nodes @@ -64,10 +64,10 @@ class MixnetConfig: # A probability of sending a cover message within one cycle if not sending a real message cover_message_prob: float # A maximum preparation time (computation time) for a message sender before sending the message - max_message_prep_time: float + max_message_prep_time: Time # A maximum delay of messages mixed in a mix node - min_mix_delay: float - max_mix_delay: float + min_mix_delay: Time + max_mix_delay: Time def validate(self): assert self.num_nodes > 0 @@ -97,8 +97,8 @@ class MixnetConfig: def is_mixing_on(self) -> bool: return self.num_mix_layers > 0 - def random_mix_delay(self) -> float: - return random.uniform(self.min_mix_delay, self.max_mix_delay) + def random_mix_delay(self) -> Time: + return random.randint(self.min_mix_delay, self.max_mix_delay) @dataclass @@ -108,8 +108,8 @@ class P2PConfig: # A connection density, only if the type is gossip connection_density: int # A maximum network latency between nodes directly connected with each other - min_network_latency: float - max_network_latency: float + min_network_latency: Time + max_network_latency: Time TYPE_ONE_TO_ALL = "1-to-all" TYPE_GOSSIP = "gossip" @@ -128,23 +128,14 @@ class P2PConfig: f"max_net_latency: {self.max_network_latency:.2f}" ) - def random_network_latency(self) -> float: - return random.uniform(self.min_network_latency, self.max_network_latency) + def random_network_latency(self) -> Time: + return random.randint(self.min_network_latency, self.max_network_latency) @dataclass class MeasurementConfig: # How many times in simulation represent 1 second in real time - sim_time_per_second: float + sim_time_per_second: Time def validate(self): assert self.sim_time_per_second > 0 - - -@dataclass -class AdversaryConfig: - # A time window for the adversary to observe inputs and outputs of each node - window_size: float - - def validate(self): - assert self.window_size > 0 diff --git a/mixnet/v2/sim/config.yaml b/mixnet/v2/sim/config.yaml index 22558ce..0d7ea0b 100644 --- a/mixnet/v2/sim/config.yaml +++ b/mixnet/v2/sim/config.yaml @@ -1,17 +1,17 @@ simulation: # The simulation uses a virtual time. Please see README for more details. - running_time: 30 + running_time: 300 mixnet: num_nodes: 100 # A number of mix nodes selected by a message sender through which the Sphinx message goes through # If 0, the message is broadcast directly to all nodes without being Sphinx-encoded. - num_mix_layers: 1 + num_mix_layers: 2 # A size of a message payload in bytes (e.g. the size of a block proposal) payload_size: 320 # An interval of sending a new real/cover message # A probability of actually sending a message depends on the following parameters. - message_interval: 1 + message_interval: 10 # A probability of sending a real message within a cycle real_message_prob: 0.01 # A weight of real message emission probability of some nodes @@ -23,8 +23,8 @@ mixnet: # A maximum preparation time (computation time) for a message sender before sending the message max_message_prep_time: 0 # A maximum delay of messages mixed in a mix node - min_mix_delay: 0.0 - max_mix_delay: 0.0 + min_mix_delay: 0 + max_mix_delay: 0 p2p: # Broadcasting type: 1-to-all | gossip @@ -32,14 +32,9 @@ p2p: # A connection density, only if the type is gossip connection_density: 6 # A maximum network latency between nodes directly connected with each other - min_network_latency: 0.10 - max_network_latency: 0.20 + min_network_latency: 1 + max_network_latency: 1 measurement: # How many times in simulation represent 1 second in real time - sim_time_per_second: 1 - -adversary: - # A time window for the adversary to observe inputs and outputs of each node - # Recommendation: Same as `p2p.min_network_latency` - window_size: 0.10 \ No newline at end of file + sim_time_per_second: 10 \ No newline at end of file diff --git a/mixnet/v2/sim/environment.py b/mixnet/v2/sim/environment.py new file mode 100644 index 0000000..e110370 --- /dev/null +++ b/mixnet/v2/sim/environment.py @@ -0,0 +1,22 @@ +from typing import Optional, Any + +import simpy + +Time = int + + +class Environment: + def __init__(self): + self.env = simpy.Environment() + + def now(self) -> Time: + return Time(self.env.now) + + def run(self, until: Time) -> Optional[Any]: + return self.env.run(until=until) + + def timeout(self, timeout: Time) -> simpy.Timeout: + return self.env.timeout(timeout) + + def process(self, generator: simpy.events.ProcessGenerator) -> simpy.Process: + return self.env.process(generator) diff --git a/mixnet/v2/sim/measurement.py b/mixnet/v2/sim/measurement.py index 963ef7a..2b8bd3c 100644 --- a/mixnet/v2/sim/measurement.py +++ b/mixnet/v2/sim/measurement.py @@ -2,9 +2,9 @@ from collections import defaultdict, Counter from typing import TYPE_CHECKING import pandas as pd -import simpy from config import Config +from environment import Environment from sphinx import SphinxPacket if TYPE_CHECKING: @@ -12,12 +12,12 @@ if TYPE_CHECKING: class Measurement: - def __init__(self, env: simpy.Environment, config: Config): + def __init__(self, env: Environment, config: Config): self.env = env self.config = config self.original_senders = Counter() - self.egress_bandwidth_per_time = [] - self.ingress_bandwidth_per_time = [] + self.egress_bandwidth_per_sec = [] + self.ingress_bandwidth_per_sec = [] self.message_hops = defaultdict(int) # dict[msg_hash, hops] self.env.process(self._update_bandwidth_window()) @@ -30,24 +30,24 @@ class Measurement: self.original_senders.update({sender}) def measure_egress(self, node: "Node", msg: SphinxPacket | bytes): - self.egress_bandwidth_per_time[-1][node] += len(msg) + self.egress_bandwidth_per_sec[-1][node] += len(msg) def measure_ingress(self, node: "Node", msg: SphinxPacket | bytes): - self.ingress_bandwidth_per_time[-1][node] += len(msg) + self.ingress_bandwidth_per_sec[-1][node] += len(msg) def update_message_hops(self, msg_hash: bytes, hops: int): self.message_hops[msg_hash] = max(hops, self.message_hops[msg_hash]) def _update_bandwidth_window(self): while True: - self.ingress_bandwidth_per_time.append(defaultdict(int)) - self.egress_bandwidth_per_time.append(defaultdict(int)) + self.ingress_bandwidth_per_sec.append(defaultdict(int)) + self.egress_bandwidth_per_sec.append(defaultdict(int)) yield self.env.timeout(self.config.measurement.sim_time_per_second) def bandwidth(self) -> (pd.Series, pd.Series): nonzero_egresses, nonzero_ingresses = [], [] - for egress_bandwidths, ingress_bandwidths in zip(self.egress_bandwidth_per_time, - self.ingress_bandwidth_per_time): + for egress_bandwidths, ingress_bandwidths in zip(self.egress_bandwidth_per_sec, + self.ingress_bandwidth_per_sec): for bandwidth in egress_bandwidths.values(): if bandwidth > 0: nonzero_egresses.append(bandwidth / 1024.0) diff --git a/mixnet/v2/sim/node.py b/mixnet/v2/sim/node.py index 3bd9379..efe652f 100644 --- a/mixnet/v2/sim/node.py +++ b/mixnet/v2/sim/node.py @@ -4,11 +4,11 @@ import os import random from enum import Enum -import simpy from cryptography.hazmat.primitives import serialization from cryptography.hazmat.primitives.asymmetric.x25519 import X25519PrivateKey, X25519PublicKey from config import Config +from environment import Environment from measurement import Measurement from p2p import P2P from sphinx import SphinxPacket, Attachment @@ -18,7 +18,7 @@ class Node: INCENTIVE_TX_SIZE = 512 PADDING_SEPARATOR = b'\x01' - def __init__(self, id: int, env: simpy.Environment, p2p: P2P, config: Config, measurement: Measurement, + def __init__(self, id: int, env: Environment, p2p: P2P, config: Config, measurement: Measurement, operated_by_adversary: bool = False): self.id = id self.env = env @@ -45,7 +45,7 @@ class Node: self.measurement.count_original_sender(self) msg = self.create_message(message_type) - prep_time = random.uniform(0, self.config.mixnet.max_message_prep_time) + prep_time = random.randint(0, self.config.mixnet.max_message_prep_time) yield self.env.timeout(prep_time) self.log("Sending a message to the mixnet") @@ -73,7 +73,7 @@ class Node: if not self.config.mixnet.is_mixing_on(): return self.build_payload() - mixes = self.p2p.get_nodes(self.config.mixnet.num_mix_layers) + mixes = self.p2p.get_nodes(self.config.mixnet.num_mix_layers, self.id) public_keys = [mix.public_key for mix in mixes] # TODO: replace with realistic tx incentive_txs = [Node.create_incentive_tx(mix.public_key) for mix in mixes] @@ -98,7 +98,7 @@ class Node: self.env.process(self.p2p.broadcast(self, final_padded_msg)) else: # TODO: use Poisson delay or something else, if necessary - yield self.env.timeout(random.uniform(0, self.config.mixnet.max_mix_delay)) + yield self.env.timeout(random.randint(0, self.config.mixnet.max_mix_delay)) self.env.process(self.p2p.broadcast(self, msg)) # else: # self.log("Receiving SphinxPacket, but not mine") @@ -149,7 +149,7 @@ class Node: return tx == Node.create_incentive_tx(self.public_key) def log(self, msg): - print(f"t={self.env.now:.3f}: Node:{self.id}: {msg}") + print(f"t={self.env.now():.3f}: Node:{self.id}: {msg}") class MessageType(Enum): diff --git a/mixnet/v2/sim/p2p.py b/mixnet/v2/sim/p2p.py index 48d175e..b5d13ae 100644 --- a/mixnet/v2/sim/p2p.py +++ b/mixnet/v2/sim/p2p.py @@ -6,10 +6,9 @@ from abc import ABC, abstractmethod from collections import defaultdict from typing import TYPE_CHECKING -import simpy - from adversary import Adversary from config import Config +from environment import Environment, Time from measurement import Measurement from sphinx import SphinxPacket @@ -18,7 +17,7 @@ if TYPE_CHECKING: class P2P(ABC): - def __init__(self, env: simpy.Environment, config: Config): + def __init__(self, env: Environment, config: Config): self.env = env self.config = config self.nodes = [] @@ -29,8 +28,9 @@ class P2P(ABC): self.nodes = nodes self.measurement.set_nodes(nodes) - def get_nodes(self, n: int) -> list["Node"]: - return random.sample(self.nodes, n) + def get_nodes(self, n: int, exclude_node_id: int) -> list["Node"]: + candidates = self.nodes[:exclude_node_id] + self.nodes[exclude_node_id + 1:] + return random.sample(candidates, n) # This should accept only bytes in practice, # but we accept SphinxPacket as well because we don't implement Sphinx deserialization. @@ -42,6 +42,7 @@ class P2P(ABC): def send(self, msg: SphinxPacket | bytes, hops_traveled: int, sender: "Node", receiver: "Node", is_first_of_broadcasting: bool): + time_sent = self.env.now() if sender != receiver: if is_first_of_broadcasting: self.adversary.inspect_message_size(msg) @@ -52,19 +53,20 @@ class P2P(ABC): yield self.env.timeout(self.config.p2p.random_network_latency()) self.measurement.measure_ingress(receiver, msg) - self.adversary.observe_receiving_node(sender, receiver) - self.receive(msg, hops_traveled + 1, sender, receiver) + self.adversary.observe_receiving_node(sender, receiver, time_sent) + self.receive(msg, hops_traveled + 1, sender, receiver, time_sent) @abstractmethod - def receive(self, msg: SphinxPacket | bytes, hops_traveled: int, sender: "Node", receiver: "Node"): + def receive(self, msg: SphinxPacket | bytes, hops_traveled: int, sender: "Node", receiver: "Node", + time_sent: Time): pass def log(self, msg): - print(f"t={self.env.now:.3f}: P2P: {msg}") + print(f"t={self.env.now():.3f}: P2P: {msg}") class NaiveBroadcastP2P(P2P): - def __init__(self, env: simpy.Environment, config: Config): + def __init__(self, env: Environment, config: Config): super().__init__(env, config) self.nodes = [] @@ -77,15 +79,16 @@ class NaiveBroadcastP2P(P2P): for i, receiver in enumerate(self.nodes): self.env.process(self.send(msg, 0, sender, receiver, i == 0)) - def receive(self, msg: SphinxPacket | bytes, hops_traveled: int, sender: "Node", receiver: "Node"): + def receive(self, msg: SphinxPacket | bytes, hops_traveled: int, sender: "Node", receiver: "Node", + time_sent: Time): msg_hash = hashlib.sha256(bytes(msg)).digest() self.measurement.update_message_hops(msg_hash, hops_traveled) - self.adversary.observe_if_final_msg(sender, receiver, msg) + self.adversary.observe_if_final_msg(sender, receiver, time_sent, msg) self.env.process(receiver.receive_message(msg)) class GossipP2P(P2P): - def __init__(self, env: simpy.Environment, config: Config): + def __init__(self, env: Environment, config: Config): super().__init__(env, config) self.topology = defaultdict(set) self.message_cache = defaultdict(dict) # dict[receiver, dict[msg_hash, sender]] @@ -128,14 +131,15 @@ class GossipP2P(P2P): self.env.process(self.send(msg, hops_traveled, sender, receiver, cnt == 0)) cnt += 1 - def receive(self, msg: SphinxPacket | bytes, hops_traveled: int, sender: "Node", receiver: "Node"): + def receive(self, msg: SphinxPacket | bytes, hops_traveled: int, sender: "Node", receiver: "Node", + time_sent: Time): # Receive/gossip the msg only if it hasn't been received before. If not, just ignore the msg. # i.e. each message is received/gossiped at most once by each node. msg_hash = hashlib.sha256(bytes(msg)).digest() if msg_hash not in self.message_cache[receiver]: self.message_cache[receiver][msg_hash] = sender self.measurement.update_message_hops(msg_hash, hops_traveled) - self.adversary.observe_if_final_msg(sender, receiver, msg) + self.adversary.observe_if_final_msg(sender, receiver, time_sent, msg) # Receive and gossip self.env.process(receiver.receive_message(msg)) diff --git a/mixnet/v2/sim/simulation.py b/mixnet/v2/sim/simulation.py index cef5caa..c174ea5 100644 --- a/mixnet/v2/sim/simulation.py +++ b/mixnet/v2/sim/simulation.py @@ -3,6 +3,7 @@ import random import simpy from config import Config, P2PConfig +from environment import Environment from node import Node from p2p import NaiveBroadcastP2P, GossipP2P @@ -11,7 +12,7 @@ class Simulation: def __init__(self, config: Config): random.seed() self.config = config - self.env = simpy.Environment() + self.env = Environment() self.p2p = Simulation.init_p2p(self.env, config) nodes = [Node(i, self.env, self.p2p, config, self.p2p.measurement, i == 0) for i in range(config.mixnet.num_nodes)]