use timeout to join process

This commit is contained in:
Youngjoon Lee 2024-06-13 16:09:25 +09:00
parent fd771893ca
commit dea14e76b8
No known key found for this signature in database
GPG Key ID: 09B750B5BD6F08A2
4 changed files with 39 additions and 19 deletions

View File

@ -1,6 +1,7 @@
import itertools
import multiprocessing
import sys
import threading
from collections import Counter
from typing import TYPE_CHECKING
@ -244,35 +245,52 @@ class Analysis:
tasks = self.prepare_timing_attack_tasks(hops_between_layers)
print(f"{len(tasks)} TASKS")
# Spawn process for each task
processes = []
results = multiprocessing.Manager().list()
accuracy_results = multiprocessing.Manager().list()
for task in tasks:
process = multiprocessing.Process(target=self.spawn_timing_attack, args=(task, results))
process = multiprocessing.Process(target=self.spawn_timing_attack, args=(task, accuracy_results))
process.start()
processes.append(process)
# Join processes using threading to apply a timeout to all processes almost simultaneously.
threads = []
for process in processes:
process.join()
thread = threading.Thread(target=Analysis.join_process,
args=(process, self.config.adversary.timing_attack_timeout))
thread.start()
threads.append(thread)
for thread in threads:
thread.join()
return list(results)
return list(accuracy_results)
def spawn_timing_attack(self, task, results):
origin_id, receiver, time_received, remaining_hops, observed_hops, suspected_origins, senders = task
def spawn_timing_attack(self, task, accuracy_results):
origin_id, receiver, time_received, remaining_hops, observed_hops, senders = task
result = self.run_and_evaluate_timing_attack(
origin_id, receiver, time_received, remaining_hops, observed_hops, suspected_origins, senders
origin_id, receiver, time_received, remaining_hops, observed_hops, senders
)
results.append(result)
print(f"{len(results)} PROCESSES DONE")
accuracy_results.append(result)
print(f"{len(accuracy_results)} PROCESSES DONE")
@staticmethod
def join_process(process, timeout):
process.join(timeout)
if process.is_alive():
process.terminate()
process.join()
print(f"PROCESS TIMED OUT")
def prepare_timing_attack_tasks(self, hops_between_layers: int) -> list:
hops_to_observe = hops_between_layers * (self.config.mixnet.num_mix_layers + 1)
tasks = []
# Prepare a task for each real message received by the adversary
for receiver, times_and_msgs in self.sim.p2p.adversary.final_msgs_received.items():
for time_received, msgs in times_and_msgs.items():
for sender, time_sent, origin_id in msgs:
tasks.append((
origin_id, receiver, time_received, hops_to_observe, 0, Counter(), {sender: [time_sent]}
origin_id, receiver, time_received, hops_to_observe, 0, {sender: [time_sent]}
))
if len(tasks) >= self.config.adversary.timing_attack_max_targets:
return tasks
@ -280,14 +298,13 @@ class Analysis:
return tasks
def run_and_evaluate_timing_attack(self, origin_id: int, receiver: "Node", time_received: Time,
remaining_hops: int, observed_hops: int, suspected_origins: Counter,
remaining_hops: int, observed_hops: int,
senders: dict["Node", list[Time]] = None) -> float:
suspected_origins = self.timing_attack_from_receiver(
receiver, time_received, remaining_hops, observed_hops, suspected_origins, senders
receiver, time_received, remaining_hops, observed_hops, Counter(), senders
)
suspected_origin_ids = {node.id for node in suspected_origins}
if origin_id in suspected_origin_ids:
return 1 / len(suspected_origin_ids) * 100.0
if origin_id in suspected_origins:
return 1 / len(suspected_origins) * 100.0
else:
return 0.0
@ -319,7 +336,7 @@ class Analysis:
# If the sender is sent the message around the message interval, suspect the sender as the origin.
if (self.sim.p2p.adversary.is_around_message_interval(time_sent)
and observed_hops + 1 >= self.min_hops_to_observe_for_timing_attack()):
suspected_origins.update({sender})
suspected_origins.update({sender.id})
# Track back to each time when that sender might have received any messages.
for time_sender_received in range(max_time, min_time - 1, -1):

View File

@ -31,7 +31,7 @@ def bulk_attack():
args = parser.parse_args()
config = Config.load(args.config)
config.simulation.running_time = 300
config.simulation.running_time = 200
config.mixnet.num_nodes = 100
config.mixnet.payload_size = 320
config.mixnet.message_interval = 10

View File

@ -145,9 +145,11 @@ class MeasurementConfig:
@dataclass
class AdversaryConfig:
timing_attack_timeout: int
timing_attack_max_targets: int
timing_attack_max_pool_size: int
def validate(self):
assert self.timing_attack_timeout > 0
assert self.timing_attack_max_targets > 0
assert self.timing_attack_max_pool_size > 0

View File

@ -40,5 +40,6 @@ measurement:
sim_time_per_second: 10
adversary:
timing_attack_max_targets: 5
timing_attack_max_pool_size: 3
timing_attack_timeout: 300
timing_attack_max_targets: 10000000000
timing_attack_max_pool_size: 100