From 69598e836cdf9c24ed8c17575cc4db0c28669336 Mon Sep 17 00:00:00 2001 From: Youngjoon Lee <5462944+youngjoon-lee@users.noreply.github.com> Date: Mon, 10 Jun 2024 15:48:03 +0900 Subject: [PATCH] fix timing attack --- mixnet/v2/sim/adversary.py | 8 ++++-- mixnet/v2/sim/analysis.py | 54 +++++++++++++----------------------- mixnet/v2/sim/config.yaml | 12 ++++---- mixnet/v2/sim/measurement.py | 2 +- mixnet/v2/sim/node.py | 8 ++++-- mixnet/v2/sim/p2p.py | 17 ++++++------ mixnet/v2/sim/simulation.py | 5 ++-- 7 files changed, 49 insertions(+), 57 deletions(-) diff --git a/mixnet/v2/sim/adversary.py b/mixnet/v2/sim/adversary.py index b8f38ec..14927a4 100644 --- a/mixnet/v2/sim/adversary.py +++ b/mixnet/v2/sim/adversary.py @@ -1,6 +1,5 @@ from __future__ import annotations -import math from collections import defaultdict, deque, Counter from enum import Enum from typing import TYPE_CHECKING @@ -23,6 +22,7 @@ class Adversary: self.senders_around_interval = Counter() self.io_windows = [] # dict[receiver, (deque[time_received], set[sender]))] self.io_windows.append(defaultdict(lambda: (deque(), set()))) + self.final_msgs_received = defaultdict(dict) # dict[receiver, dict[window, sender]] # self.node_states = defaultdict(dict) self.env.process(self.update_observation_window()) @@ -30,16 +30,19 @@ class Adversary: def inspect_message_size(self, msg: SphinxPacket | bytes): self.message_sizes.append(len(msg)) - def observe_receiving_node(self, sender: "Node", receiver: "Node"): + def observe_receiving_node(self, sender: "Node", receiver: "Node", msg: SphinxPacket | bytes): msg_queue, senders = self.io_windows[-1][receiver] msg_queue.append(self.env.now) senders.add(sender) + if receiver.operated_by_adversary and not isinstance(msg, SphinxPacket): + self.final_msgs_received[receiver][len(self.io_windows) - 1] = sender # if node not in self.node_states[self.env.now]: # self.node_states[self.env.now][node] = NodeState.RECEIVING def observe_sending_node(self, sender: "Node"): msg_queue, _ = self.io_windows[-1][sender] if len(msg_queue) > 0: + # Adversary doesn't know which message in the pool is being emitted. So, pop the oldest one from the pool. msg_queue.popleft() if self.is_around_message_interval(self.env.now): self.senders_around_interval.update({sender}) @@ -54,6 +57,7 @@ class Adversary: new_window = defaultdict(lambda: (deque(), set())) for receiver, (msg_queue, _) in self.io_windows[-1].items(): for time_received in msg_queue: + # If the message is likely to be still pending and be emitted soon, pass it on to the next window. if self.env.now - time_received < self.config.mixnet.max_mix_delay: new_window[receiver][0].append(time_received) self.io_windows.append(new_window) diff --git a/mixnet/v2/sim/analysis.py b/mixnet/v2/sim/analysis.py index 5256f11..438255e 100644 --- a/mixnet/v2/sim/analysis.py +++ b/mixnet/v2/sim/analysis.py @@ -3,9 +3,9 @@ from typing import TYPE_CHECKING import numpy as np import pandas as pd +import scipy.stats as stats import seaborn from matplotlib import pyplot as plt -import scipy.stats as stats from adversary import NodeState from config import Config @@ -37,8 +37,8 @@ class Analysis: self.messages_emitted_around_interval() self.messages_in_node_over_time() # self.node_states() - self.message_hops() - # self.timing_attack(median_hops) + median_hops = self.message_hops() + self.timing_attack(median_hops) def bandwidth(self, message_size_df: pd.DataFrame): dataframes = [] @@ -96,8 +96,9 @@ class Analysis: def messages_emitted_around_interval(self): # A ground truth that shows how many times each node sent a real message - truth_df = pd.DataFrame([(node.id, count) for node, count in self.sim.p2p.measurement.original_senders.items()], - columns=[COL_NODE_ID, COL_MSG_CNT]) + truth_df = pd.DataFrame( + [(node.id, count) for node, count in self.sim.p2p.measurement.original_senders.items()], + columns=[COL_NODE_ID, COL_MSG_CNT]) # A result of observing nodes who have sent messages around the promised message interval suspected_df = pd.DataFrame( [(node.id, self.sim.p2p.adversary.senders_around_interval[node]) for node in @@ -194,22 +195,12 @@ class Analysis: def timing_attack(self, hops_between_layers: int): hops_to_observe = hops_between_layers * (self.config.mixnet.num_mix_layers + 1) - all_results = Counter() - window = len(self.sim.p2p.adversary.io_windows) - 1 - while window >= 0: - items = self.sim.p2p.adversary.io_windows[window].items() - actual_receivers = [receiver for receiver, (_, senders) in items if len(senders) > 0] - if len(actual_receivers) == 0: - window -= 1 - continue + suspected_senders = Counter() + for receiver, windows_and_senders in self.sim.p2p.adversary.final_msgs_received.items(): + for window, sender in windows_and_senders.items(): + suspected_senders.update(self.timing_attack_with(receiver, window, hops_to_observe, sender)) - for receiver in actual_receivers: - suspected_senders = self.timing_attack_with(receiver, window, hops_to_observe) - # self.print_nodes_per_hop(suspected_senders, window) - all_results.update(suspected_senders) - window -= 1 - - suspected_senders = ({node.id: count for node, count in all_results.items()}) + suspected_senders = ({node.id: count for node, count in suspected_senders.items()}) print(f"suspected nodes count: {len(suspected_senders)}") # Create the bar plot for original sender counts @@ -262,40 +253,33 @@ class Analysis: plt.legend() plt.show() - def timing_attack_with(self, receiver: "Node", window: int, remaining_hops: int) -> Counter: + def timing_attack_with(self, receiver: "Node", window: int, remaining_hops: int, sender: "Node" = None) -> Counter: assert remaining_hops >= 1 # Start inspecting senders who sent messages that were arrived in the receiver at the given window - _, senders = self.sim.p2p.adversary.io_windows[window][receiver] + if sender is not None: + senders = {sender} + else: + _, senders = self.sim.p2p.adversary.io_windows[window][receiver] # If the remaining_hops is 1, return the senders as suspected senders if remaining_hops == 1: return Counter(senders) # A result to be returned after inspecting all senders who sent messages to the receiver - all_suspected_senders = Counter() + suspected_senders = Counter() # Inspect each sender who sent messages to the receiver for sender in senders: - # A sub-result to be filled when tracking back further from the sender - suspected_senders = Counter() - # Track back to each window where that sender might have received any messages. - time_range = self.config.mixnet.max_mix_delay + self.config.p2p.max_network_latency - self.config.adversary.io_window_size + time_range = self.config.mixnet.max_mix_delay + self.config.p2p.max_network_latency window_range = int(time_range / self.config.adversary.io_window_size) for prev_window in range(window - 1, window - 1 - window_range, -1): if prev_window < 0: break suspected_senders.update(self.timing_attack_with(sender, prev_window, remaining_hops - 1)) - # If there is no suspected sender gathered, we can assume that the sender is the original sender - # because it means that nobody has sent messages to the sender within the reasonable time window - if len(suspected_senders) == 0: - all_suspected_senders.update({sender}) - else: - all_suspected_senders.update(suspected_senders) - - return all_suspected_senders + return suspected_senders @staticmethod def print_nodes_per_hop(nodes_per_hop, starting_window: int): diff --git a/mixnet/v2/sim/config.yaml b/mixnet/v2/sim/config.yaml index 747df0b..154b1fa 100644 --- a/mixnet/v2/sim/config.yaml +++ b/mixnet/v2/sim/config.yaml @@ -6,7 +6,7 @@ mixnet: num_nodes: 100 # A number of mix nodes selected by a message sender through which the Sphinx message goes through # If 0, the message is broadcast directly to all nodes without being Sphinx-encoded. - num_mix_layers: 4 + num_mix_layers: 1 # A size of a message payload in bytes (e.g. the size of a block proposal) payload_size: 320 # An interval of sending a new real/cover message @@ -17,9 +17,9 @@ mixnet: # A weight of real message emission probability of some nodes # Each weight is multiplied to the real_message_prob of the node being at the same position in the node list. # The length of the list should be <= p2p.num_nodes. i.e. some nodes won't have a weight. - real_message_prob_weights: [] + real_message_prob_weights: [ ] # A probability of sending a cover message within a cycle if not sending a real message - cover_message_prob: 0.05 + cover_message_prob: 0.00 # A maximum preparation time (computation time) for a message sender before sending the message max_message_prep_time: 0 # A maximum delay of messages mixed in a mix node @@ -28,7 +28,7 @@ mixnet: p2p: # Broadcasting type: 1-to-all | gossip - type: "gossip" + type: "1-to-all" # A connection density, only if the type is gossip connection_density: 6 # A maximum network latency between nodes directly connected with each other @@ -36,8 +36,8 @@ p2p: max_network_latency: 0.20 measurement: - # How many times in simulation represent 1 second in real time - sim_time_per_second: 1 + # How many times in simulation represent 1 second in real time + sim_time_per_second: 1 adversary: # A time window for the adversary to observe inputs and outputs of each node diff --git a/mixnet/v2/sim/measurement.py b/mixnet/v2/sim/measurement.py index bb462be..963ef7a 100644 --- a/mixnet/v2/sim/measurement.py +++ b/mixnet/v2/sim/measurement.py @@ -27,7 +27,7 @@ class Measurement: self.original_senders[node] = 0 def count_original_sender(self, sender: "Node"): - self.original_senders[sender] += 1 + self.original_senders.update({sender}) def measure_egress(self, node: "Node", msg: SphinxPacket | bytes): self.egress_bandwidth_per_time[-1][node] += len(msg) diff --git a/mixnet/v2/sim/node.py b/mixnet/v2/sim/node.py index e4065c6..9833d75 100644 --- a/mixnet/v2/sim/node.py +++ b/mixnet/v2/sim/node.py @@ -10,15 +10,16 @@ from cryptography.hazmat.primitives.asymmetric.x25519 import X25519PrivateKey, X from config import Config from measurement import Measurement -from sphinx import SphinxPacket, Attachment from p2p import P2P +from sphinx import SphinxPacket, Attachment class Node: INCENTIVE_TX_SIZE = 512 PADDING_SEPARATOR = b'\x01' - def __init__(self, id: int, env: simpy.Environment, p2p: P2P, config: Config, measurement: Measurement): + def __init__(self, id: int, env: simpy.Environment, p2p: P2P, config: Config, measurement: Measurement, + operated_by_adversary: bool = False): self.id = id self.env = env self.p2p = p2p @@ -27,6 +28,7 @@ class Node: self.config = config self.payload_id = 0 self.measurement = measurement + self.operated_by_adversary = operated_by_adversary self.action = self.env.process(self.send_message()) def send_message(self): @@ -135,4 +137,4 @@ class Node: class MessageType(Enum): REAL = 0 - COVER = 1 \ No newline at end of file + COVER = 1 diff --git a/mixnet/v2/sim/p2p.py b/mixnet/v2/sim/p2p.py index 62d66b9..1b43613 100644 --- a/mixnet/v2/sim/p2p.py +++ b/mixnet/v2/sim/p2p.py @@ -44,16 +44,17 @@ class P2P(ABC): def send(self, msg: SphinxPacket | bytes, hops_traveled: int, sender: "Node", receiver: "Node", is_first_of_broadcasting: bool): - if is_first_of_broadcasting: - self.adversary.inspect_message_size(msg) - self.adversary.observe_sending_node(sender) - self.measurement.measure_egress(sender, msg) + if sender != receiver: + if is_first_of_broadcasting: + self.adversary.inspect_message_size(msg) + self.adversary.observe_sending_node(sender) + self.measurement.measure_egress(sender, msg) - # simulate network latency - yield self.env.timeout(self.config.p2p.random_network_latency()) + # simulate network latency + yield self.env.timeout(self.config.p2p.random_network_latency()) - self.measurement.measure_ingress(receiver, msg) - self.adversary.observe_receiving_node(sender, receiver) + self.measurement.measure_ingress(receiver, msg) + self.adversary.observe_receiving_node(sender, receiver, msg) self.receive(msg, hops_traveled + 1, sender, receiver) @abstractmethod diff --git a/mixnet/v2/sim/simulation.py b/mixnet/v2/sim/simulation.py index 4945401..cef5caa 100644 --- a/mixnet/v2/sim/simulation.py +++ b/mixnet/v2/sim/simulation.py @@ -13,7 +13,8 @@ class Simulation: self.config = config self.env = simpy.Environment() self.p2p = Simulation.init_p2p(self.env, config) - nodes = [Node(i, self.env, self.p2p, config, self.p2p.measurement) for i in range(config.mixnet.num_nodes)] + nodes = [Node(i, self.env, self.p2p, config, self.p2p.measurement, i == 0) for i in + range(config.mixnet.num_nodes)] self.p2p.set_nodes(nodes) def run(self): @@ -27,4 +28,4 @@ class Simulation: case P2PConfig.TYPE_GOSSIP: return GossipP2P(env, config) case _: - raise ValueError("Unknown P2P type") \ No newline at end of file + raise ValueError("Unknown P2P type")