fix timing attack

This commit is contained in:
Youngjoon Lee 2024-06-10 15:48:03 +09:00
parent 3fa8af8850
commit 69598e836c
No known key found for this signature in database
GPG Key ID: 09B750B5BD6F08A2
7 changed files with 49 additions and 57 deletions

View File

@ -1,6 +1,5 @@
from __future__ import annotations
import math
from collections import defaultdict, deque, Counter
from enum import Enum
from typing import TYPE_CHECKING
@ -23,6 +22,7 @@ class Adversary:
self.senders_around_interval = Counter()
self.io_windows = [] # dict[receiver, (deque[time_received], set[sender]))]
self.io_windows.append(defaultdict(lambda: (deque(), set())))
self.final_msgs_received = defaultdict(dict) # dict[receiver, dict[window, sender]]
# self.node_states = defaultdict(dict)
self.env.process(self.update_observation_window())
@ -30,16 +30,19 @@ class Adversary:
def inspect_message_size(self, msg: SphinxPacket | bytes):
self.message_sizes.append(len(msg))
def observe_receiving_node(self, sender: "Node", receiver: "Node"):
def observe_receiving_node(self, sender: "Node", receiver: "Node", msg: SphinxPacket | bytes):
msg_queue, senders = self.io_windows[-1][receiver]
msg_queue.append(self.env.now)
senders.add(sender)
if receiver.operated_by_adversary and not isinstance(msg, SphinxPacket):
self.final_msgs_received[receiver][len(self.io_windows) - 1] = sender
# if node not in self.node_states[self.env.now]:
# self.node_states[self.env.now][node] = NodeState.RECEIVING
def observe_sending_node(self, sender: "Node"):
msg_queue, _ = self.io_windows[-1][sender]
if len(msg_queue) > 0:
# Adversary doesn't know which message in the pool is being emitted. So, pop the oldest one from the pool.
msg_queue.popleft()
if self.is_around_message_interval(self.env.now):
self.senders_around_interval.update({sender})
@ -54,6 +57,7 @@ class Adversary:
new_window = defaultdict(lambda: (deque(), set()))
for receiver, (msg_queue, _) in self.io_windows[-1].items():
for time_received in msg_queue:
# If the message is likely to be still pending and be emitted soon, pass it on to the next window.
if self.env.now - time_received < self.config.mixnet.max_mix_delay:
new_window[receiver][0].append(time_received)
self.io_windows.append(new_window)

View File

@ -3,9 +3,9 @@ from typing import TYPE_CHECKING
import numpy as np
import pandas as pd
import scipy.stats as stats
import seaborn
from matplotlib import pyplot as plt
import scipy.stats as stats
from adversary import NodeState
from config import Config
@ -37,8 +37,8 @@ class Analysis:
self.messages_emitted_around_interval()
self.messages_in_node_over_time()
# self.node_states()
self.message_hops()
# self.timing_attack(median_hops)
median_hops = self.message_hops()
self.timing_attack(median_hops)
def bandwidth(self, message_size_df: pd.DataFrame):
dataframes = []
@ -96,8 +96,9 @@ class Analysis:
def messages_emitted_around_interval(self):
# A ground truth that shows how many times each node sent a real message
truth_df = pd.DataFrame([(node.id, count) for node, count in self.sim.p2p.measurement.original_senders.items()],
columns=[COL_NODE_ID, COL_MSG_CNT])
truth_df = pd.DataFrame(
[(node.id, count) for node, count in self.sim.p2p.measurement.original_senders.items()],
columns=[COL_NODE_ID, COL_MSG_CNT])
# A result of observing nodes who have sent messages around the promised message interval
suspected_df = pd.DataFrame(
[(node.id, self.sim.p2p.adversary.senders_around_interval[node]) for node in
@ -194,22 +195,12 @@ class Analysis:
def timing_attack(self, hops_between_layers: int):
hops_to_observe = hops_between_layers * (self.config.mixnet.num_mix_layers + 1)
all_results = Counter()
window = len(self.sim.p2p.adversary.io_windows) - 1
while window >= 0:
items = self.sim.p2p.adversary.io_windows[window].items()
actual_receivers = [receiver for receiver, (_, senders) in items if len(senders) > 0]
if len(actual_receivers) == 0:
window -= 1
continue
suspected_senders = Counter()
for receiver, windows_and_senders in self.sim.p2p.adversary.final_msgs_received.items():
for window, sender in windows_and_senders.items():
suspected_senders.update(self.timing_attack_with(receiver, window, hops_to_observe, sender))
for receiver in actual_receivers:
suspected_senders = self.timing_attack_with(receiver, window, hops_to_observe)
# self.print_nodes_per_hop(suspected_senders, window)
all_results.update(suspected_senders)
window -= 1
suspected_senders = ({node.id: count for node, count in all_results.items()})
suspected_senders = ({node.id: count for node, count in suspected_senders.items()})
print(f"suspected nodes count: {len(suspected_senders)}")
# Create the bar plot for original sender counts
@ -262,40 +253,33 @@ class Analysis:
plt.legend()
plt.show()
def timing_attack_with(self, receiver: "Node", window: int, remaining_hops: int) -> Counter:
def timing_attack_with(self, receiver: "Node", window: int, remaining_hops: int, sender: "Node" = None) -> Counter:
assert remaining_hops >= 1
# Start inspecting senders who sent messages that were arrived in the receiver at the given window
_, senders = self.sim.p2p.adversary.io_windows[window][receiver]
if sender is not None:
senders = {sender}
else:
_, senders = self.sim.p2p.adversary.io_windows[window][receiver]
# If the remaining_hops is 1, return the senders as suspected senders
if remaining_hops == 1:
return Counter(senders)
# A result to be returned after inspecting all senders who sent messages to the receiver
all_suspected_senders = Counter()
suspected_senders = Counter()
# Inspect each sender who sent messages to the receiver
for sender in senders:
# A sub-result to be filled when tracking back further from the sender
suspected_senders = Counter()
# Track back to each window where that sender might have received any messages.
time_range = self.config.mixnet.max_mix_delay + self.config.p2p.max_network_latency - self.config.adversary.io_window_size
time_range = self.config.mixnet.max_mix_delay + self.config.p2p.max_network_latency
window_range = int(time_range / self.config.adversary.io_window_size)
for prev_window in range(window - 1, window - 1 - window_range, -1):
if prev_window < 0:
break
suspected_senders.update(self.timing_attack_with(sender, prev_window, remaining_hops - 1))
# If there is no suspected sender gathered, we can assume that the sender is the original sender
# because it means that nobody has sent messages to the sender within the reasonable time window
if len(suspected_senders) == 0:
all_suspected_senders.update({sender})
else:
all_suspected_senders.update(suspected_senders)
return all_suspected_senders
return suspected_senders
@staticmethod
def print_nodes_per_hop(nodes_per_hop, starting_window: int):

View File

@ -6,7 +6,7 @@ mixnet:
num_nodes: 100
# A number of mix nodes selected by a message sender through which the Sphinx message goes through
# If 0, the message is broadcast directly to all nodes without being Sphinx-encoded.
num_mix_layers: 4
num_mix_layers: 1
# A size of a message payload in bytes (e.g. the size of a block proposal)
payload_size: 320
# An interval of sending a new real/cover message
@ -17,9 +17,9 @@ mixnet:
# A weight of real message emission probability of some nodes
# Each weight is multiplied to the real_message_prob of the node being at the same position in the node list.
# The length of the list should be <= p2p.num_nodes. i.e. some nodes won't have a weight.
real_message_prob_weights: []
real_message_prob_weights: [ ]
# A probability of sending a cover message within a cycle if not sending a real message
cover_message_prob: 0.05
cover_message_prob: 0.00
# A maximum preparation time (computation time) for a message sender before sending the message
max_message_prep_time: 0
# A maximum delay of messages mixed in a mix node
@ -28,7 +28,7 @@ mixnet:
p2p:
# Broadcasting type: 1-to-all | gossip
type: "gossip"
type: "1-to-all"
# A connection density, only if the type is gossip
connection_density: 6
# A maximum network latency between nodes directly connected with each other
@ -36,8 +36,8 @@ p2p:
max_network_latency: 0.20
measurement:
# How many times in simulation represent 1 second in real time
sim_time_per_second: 1
# How many times in simulation represent 1 second in real time
sim_time_per_second: 1
adversary:
# A time window for the adversary to observe inputs and outputs of each node

View File

@ -27,7 +27,7 @@ class Measurement:
self.original_senders[node] = 0
def count_original_sender(self, sender: "Node"):
self.original_senders[sender] += 1
self.original_senders.update({sender})
def measure_egress(self, node: "Node", msg: SphinxPacket | bytes):
self.egress_bandwidth_per_time[-1][node] += len(msg)

View File

@ -10,15 +10,16 @@ from cryptography.hazmat.primitives.asymmetric.x25519 import X25519PrivateKey, X
from config import Config
from measurement import Measurement
from sphinx import SphinxPacket, Attachment
from p2p import P2P
from sphinx import SphinxPacket, Attachment
class Node:
INCENTIVE_TX_SIZE = 512
PADDING_SEPARATOR = b'\x01'
def __init__(self, id: int, env: simpy.Environment, p2p: P2P, config: Config, measurement: Measurement):
def __init__(self, id: int, env: simpy.Environment, p2p: P2P, config: Config, measurement: Measurement,
operated_by_adversary: bool = False):
self.id = id
self.env = env
self.p2p = p2p
@ -27,6 +28,7 @@ class Node:
self.config = config
self.payload_id = 0
self.measurement = measurement
self.operated_by_adversary = operated_by_adversary
self.action = self.env.process(self.send_message())
def send_message(self):
@ -135,4 +137,4 @@ class Node:
class MessageType(Enum):
REAL = 0
COVER = 1
COVER = 1

View File

@ -44,16 +44,17 @@ class P2P(ABC):
def send(self, msg: SphinxPacket | bytes, hops_traveled: int, sender: "Node", receiver: "Node",
is_first_of_broadcasting: bool):
if is_first_of_broadcasting:
self.adversary.inspect_message_size(msg)
self.adversary.observe_sending_node(sender)
self.measurement.measure_egress(sender, msg)
if sender != receiver:
if is_first_of_broadcasting:
self.adversary.inspect_message_size(msg)
self.adversary.observe_sending_node(sender)
self.measurement.measure_egress(sender, msg)
# simulate network latency
yield self.env.timeout(self.config.p2p.random_network_latency())
# simulate network latency
yield self.env.timeout(self.config.p2p.random_network_latency())
self.measurement.measure_ingress(receiver, msg)
self.adversary.observe_receiving_node(sender, receiver)
self.measurement.measure_ingress(receiver, msg)
self.adversary.observe_receiving_node(sender, receiver, msg)
self.receive(msg, hops_traveled + 1, sender, receiver)
@abstractmethod

View File

@ -13,7 +13,8 @@ class Simulation:
self.config = config
self.env = simpy.Environment()
self.p2p = Simulation.init_p2p(self.env, config)
nodes = [Node(i, self.env, self.p2p, config, self.p2p.measurement) for i in range(config.mixnet.num_nodes)]
nodes = [Node(i, self.env, self.p2p, config, self.p2p.measurement, i == 0) for i in
range(config.mixnet.num_nodes)]
self.p2p.set_nodes(nodes)
def run(self):
@ -27,4 +28,4 @@ class Simulation:
case P2PConfig.TYPE_GOSSIP:
return GossipP2P(env, config)
case _:
raise ValueError("Unknown P2P type")
raise ValueError("Unknown P2P type")