mirror of
https://github.com/codex-storage/bittorrent-benchmarks.git
synced 2025-02-13 11:36:28 +00:00
fix: prevent pending tasks from racing teardown
This commit is contained in:
parent
90dda4f932
commit
63b4c51048
@ -60,6 +60,8 @@ def cmd_run_experiment(experiments: Dict[str, ExperimentBuilder[Experiment]], ar
|
||||
logger.info(DECLogEntry.adapt_instance(experiment))
|
||||
experiment.build().run()
|
||||
|
||||
print(f"Experiment {args.experiment} completed successfully.")
|
||||
|
||||
|
||||
def cmd_describe_experiment(args):
|
||||
if not args.type:
|
||||
|
@ -1,6 +1,7 @@
|
||||
from concurrent import futures
|
||||
from concurrent.futures.thread import ThreadPoolExecutor
|
||||
from queue import Queue
|
||||
from typing import Iterable, Iterator, List
|
||||
from typing import Iterable, Iterator, List, cast
|
||||
|
||||
from typing_extensions import TypeVar
|
||||
|
||||
@ -50,8 +51,27 @@ def pflatmap(
|
||||
yield item
|
||||
|
||||
# This will cause any exceptions thrown in tasks to be re-raised.
|
||||
for future in task_futures:
|
||||
future.result()
|
||||
ensure_successful(task_futures)
|
||||
|
||||
finally:
|
||||
executor.shutdown(wait=True)
|
||||
|
||||
|
||||
def ensure_successful(futs: Iterable[futures.Future[T]]) -> List[T]:
|
||||
future_list = list(futs)
|
||||
futures.wait(future_list, return_when=futures.ALL_COMPLETED)
|
||||
|
||||
# We treat cancelled futures as if they were successful.
|
||||
exceptions = [
|
||||
fut.exception()
|
||||
for fut in future_list
|
||||
if not fut.cancelled() and fut.exception() is not None
|
||||
]
|
||||
|
||||
if exceptions:
|
||||
raise ExceptionGroup(
|
||||
"One or more computations failed to complete successfully",
|
||||
cast(List[Exception], exceptions),
|
||||
)
|
||||
|
||||
return [cast(T, fut.result()) for fut in future_list]
|
||||
|
@ -1,10 +1,12 @@
|
||||
import logging
|
||||
from multiprocessing.pool import ThreadPool
|
||||
from concurrent.futures.thread import ThreadPoolExecutor
|
||||
|
||||
from time import sleep
|
||||
from typing import Sequence, Optional
|
||||
|
||||
from typing_extensions import Generic, List, Tuple
|
||||
|
||||
from benchmarks.core.concurrency import ensure_successful
|
||||
from benchmarks.core.experiments.experiments import ExperimentWithLifecycle
|
||||
from benchmarks.core.network import (
|
||||
TInitialMetadata,
|
||||
@ -36,8 +38,8 @@ class StaticDisseminationExperiment(
|
||||
self.file_size = file_size
|
||||
self.seed = seed
|
||||
|
||||
self._pool = ThreadPool(
|
||||
processes=len(network) - len(seeders)
|
||||
self._executor = ThreadPoolExecutor(
|
||||
max_workers=len(network) - len(seeders)
|
||||
if concurrency is None
|
||||
else concurrency
|
||||
)
|
||||
@ -71,20 +73,34 @@ class StaticDisseminationExperiment(
|
||||
_log_request(leecher, "leech", str(self.meta), RequestEventType.end)
|
||||
return download
|
||||
|
||||
downloads = list(self._pool.imap_unordered(_leech, leechers))
|
||||
|
||||
logger.info("Now waiting for downloads to complete")
|
||||
|
||||
def _await_for_download(element: Tuple[int, DownloadHandle]) -> int:
|
||||
downloads = ensure_successful(
|
||||
[self._executor.submit(_leech, leecher) for leecher in leechers]
|
||||
)
|
||||
|
||||
def _await_for_download(
|
||||
element: Tuple[int, DownloadHandle],
|
||||
) -> Tuple[int, DownloadHandle]:
|
||||
index, download = element
|
||||
if not download.await_for_completion():
|
||||
raise Exception(
|
||||
f"Download ({index}, {str(download)}) did not complete in time."
|
||||
)
|
||||
return index
|
||||
logger.info(
|
||||
"Download %d / %d completed (node: %s)",
|
||||
index + 1,
|
||||
len(downloads),
|
||||
download.node.name,
|
||||
)
|
||||
return element
|
||||
|
||||
for i in self._pool.imap_unordered(_await_for_download, enumerate(downloads)):
|
||||
logger.info("Download %d / %d completed", i + 1, len(downloads))
|
||||
ensure_successful(
|
||||
[
|
||||
self._executor.submit(_await_for_download, (i, download))
|
||||
for i, download in enumerate(downloads)
|
||||
]
|
||||
)
|
||||
|
||||
# FIXME this is a hack to ensure that nodes get a chance to log their data before we
|
||||
# run the teardown hook and remove the torrents.
|
||||
@ -96,15 +112,21 @@ class StaticDisseminationExperiment(
|
||||
index, node = element
|
||||
assert self._cid is not None # to please mypy
|
||||
node.remove(self._cid)
|
||||
return index
|
||||
logger.info("Node %d (%s) removed file", index + 1, node.name)
|
||||
return element
|
||||
|
||||
try:
|
||||
for i in self._pool.imap_unordered(_remove, enumerate(self.nodes)):
|
||||
logger.info("Node %d removed file", i + 1)
|
||||
# Since teardown might be called as the result of an exception, it's expected
|
||||
# that not all removes will succeed, so we don't check their result.
|
||||
ensure_successful(
|
||||
[
|
||||
self._executor.submit(_remove, (i, node))
|
||||
for i, node in enumerate(self.nodes)
|
||||
]
|
||||
)
|
||||
finally:
|
||||
logger.info("Shut down thread pool.")
|
||||
self._pool.close()
|
||||
self._pool.join()
|
||||
self._executor.shutdown(wait=True)
|
||||
logger.info("Done.")
|
||||
|
||||
def _split_nodes(
|
||||
|
@ -1,3 +1,4 @@
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from io import StringIO
|
||||
from typing import Optional, List
|
||||
@ -19,12 +20,22 @@ class MockGenData:
|
||||
|
||||
|
||||
class MockNode(Node[MockGenData, str]):
|
||||
def __init__(self, name="mock_node") -> None:
|
||||
def __init__(
|
||||
self,
|
||||
name="mock_node",
|
||||
download_lag: float = 0,
|
||||
should_fail_download: bool = False,
|
||||
) -> None:
|
||||
self._name = name
|
||||
self.seeding: Optional[MockGenData] = None
|
||||
self.leeching: Optional[MockGenData] = None
|
||||
self.download_was_awaited = False
|
||||
|
||||
self.cleanup_was_called = False
|
||||
self.download_lag = download_lag
|
||||
self.download_completed = False
|
||||
self.download_failed = False
|
||||
|
||||
self.should_fail_download = should_fail_download
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
@ -32,13 +43,18 @@ class MockNode(Node[MockGenData, str]):
|
||||
|
||||
def genseed(self, size: int, seed: int, meta: str) -> MockGenData:
|
||||
self.seeding = MockGenData(size=size, seed=seed, name=meta)
|
||||
self.download_completed = True
|
||||
return self.seeding
|
||||
|
||||
def leech(self, handle: MockGenData):
|
||||
self.leeching = handle
|
||||
return MockDownloadHandle(self)
|
||||
return MockDownloadHandle(self, self.download_lag, self.should_fail_download)
|
||||
|
||||
def remove(self, handle: MockGenData):
|
||||
assert (
|
||||
self.download_completed or self.download_failed
|
||||
), "Removing download before completion"
|
||||
|
||||
if self.leeching is not None:
|
||||
assert self.leeching == handle
|
||||
elif self.seeding is not None:
|
||||
@ -49,19 +65,40 @@ class MockNode(Node[MockGenData, str]):
|
||||
)
|
||||
|
||||
self.remove_was_called = True
|
||||
|
||||
|
||||
class MockDownloadHandle(DownloadHandle):
|
||||
def __init__(self, parent: MockNode) -> None:
|
||||
self.parent = parent
|
||||
|
||||
def await_for_completion(self, timeout: float = 0) -> bool:
|
||||
self.parent.download_was_awaited = True
|
||||
return True
|
||||
|
||||
|
||||
def mock_network(n: int) -> List[MockNode]:
|
||||
return [MockNode(f"node-{i}") for i in range(n)]
|
||||
class MockDownloadHandle(DownloadHandle):
    """Download handle test double bound to a :class:`MockNode`.

    Awaiting the handle either sleeps for ``lag`` seconds and flags the parent
    node's download as completed, or, when ``should_fail`` is set, flags the
    download as failed and raises.
    """

    def __init__(
        self, parent: MockNode, lag: float = 0, should_fail: bool = False
    ) -> None:
        self.should_fail = should_fail
        self.lag = lag
        self.parent = parent

    @property
    def node(self):
        """The mock node this handle was created for."""
        return self.parent

    def await_for_completion(self, timeout: float = 0) -> bool:
        """Simulate waiting for the download; ``timeout`` is not consulted."""
        if not self.should_fail:
            time.sleep(self.lag)
            self.parent.download_completed = True
            return True

        # Record the failure on the node before raising.
        self.parent.download_failed = True
        raise Exception("Oooops, I failed!")
|
||||
|
||||
|
||||
def mock_network(
    n: int, fail: Optional[List[int]] = None, download_lag: float = 0.0
) -> List[MockNode]:
    """Build ``n`` mock nodes named ``node-0`` through ``node-<n-1>``.

    :param n: number of nodes to create.
    :param fail: indices of the nodes created with ``should_fail_download``
        set; ``None`` means no node is configured to fail.
    :param download_lag: value passed as ``download_lag`` to every node.
    :return: the nodes, ordered by index.
    """
    failing = fail if fail else []

    nodes: List[MockNode] = []
    for index in range(n):
        nodes.append(
            MockNode(
                f"node-{index}",
                should_fail_download=index in failing,
                download_lag=download_lag,
            )
        )
    return nodes
|
||||
|
||||
|
||||
def test_should_generate_correct_data_and_seed():
|
||||
@ -104,7 +141,7 @@ def test_should_download_at_remaining_nodes():
|
||||
if node.leeching is not None:
|
||||
assert node.leeching == gendata
|
||||
assert node.seeding is None
|
||||
assert node.download_was_awaited
|
||||
assert node.download_completed
|
||||
actual_leechers.add(index)
|
||||
|
||||
assert actual_leechers == set(range(13)) - set(seeders)
|
||||
@ -199,3 +236,27 @@ def test_should_delete_file_from_nodes_at_the_end_of_the_experiment():
|
||||
|
||||
assert network[0].remove_was_called
|
||||
assert network[1].remove_was_called
|
||||
|
||||
|
||||
def test_should_not_have_pending_download_operations_running_at_teardown():
|
||||
network = mock_network(n=3, fail=[1], download_lag=1)
|
||||
seeders = [0]
|
||||
|
||||
experiment = StaticDisseminationExperiment(
|
||||
seeders=seeders,
|
||||
network=network,
|
||||
meta="dataset-1",
|
||||
file_size=1000,
|
||||
seed=12,
|
||||
)
|
||||
|
||||
try:
|
||||
experiment.run()
|
||||
except* Exception as e:
|
||||
assert len(e.exceptions) == 1
|
||||
assert str(e.exceptions[0]) == "Oooops, I failed!"
|
||||
|
||||
# Downloads should have been marked as completed even
|
||||
# though we had one exception.
|
||||
assert network[0].download_completed
|
||||
assert network[2].download_completed
|
||||
|
@ -9,6 +9,12 @@ TInitialMetadata = TypeVar("TInitialMetadata")
|
||||
class DownloadHandle(ABC):
|
||||
"""A :class:`DownloadHandle` is a reference to an ongoing download operation."""
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def node(self) -> "Node":
|
||||
"""The node that initiated the download."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def await_for_completion(self, timeout: float = 0) -> bool:
|
||||
"""Blocks the current thread until either the download completes or a timeout expires.
|
||||
|
@ -1,9 +1,21 @@
|
||||
from concurrent.futures.thread import ThreadPoolExecutor
|
||||
from threading import Semaphore
|
||||
from typing import Iterable
|
||||
|
||||
import pytest
|
||||
|
||||
from benchmarks.core.concurrency import pflatmap
|
||||
from benchmarks.core.concurrency import pflatmap, ensure_successful
|
||||
|
||||
|
||||
@pytest.fixture
def executor():
    """Provide a three-worker thread pool, shut down once the test is done."""
    # Leaving the ``with`` block calls ``shutdown(wait=True)`` on the pool.
    with ThreadPoolExecutor(max_workers=3) as pool:
        yield pool
|
||||
|
||||
|
||||
def test_should_run_iterators_in_separate_threads():
|
||||
@ -44,7 +56,34 @@ def test_should_raise_exceptions_raised_by_tasks_at_the_end():
|
||||
for val in it:
|
||||
actual_vals.add(val)
|
||||
assert False, "ValueError was not raised"
|
||||
except ValueError:
|
||||
except* ValueError:
|
||||
pass
|
||||
|
||||
assert actual_vals == reference_vals
|
||||
|
||||
|
||||
def test_should_return_results_when_no_failures_occur(executor):
    """When every task succeeds, all task results are returned."""

    def reliable_task(i: int) -> int:
        return i

    submitted = (executor.submit(reliable_task, i) for i in range(10))
    outcome = ensure_successful(submitted)

    assert set(outcome) == set(range(10))
|
||||
|
||||
|
||||
def test_should_raise_exception_when_one_task_fails(executor):
|
||||
def reliable_task(i: int) -> int:
|
||||
return i
|
||||
|
||||
def faulty_task(i: int):
|
||||
raise ValueError("I'm very faulty")
|
||||
|
||||
try:
|
||||
ensure_successful(
|
||||
executor.submit(reliable_task if i % 2 == 0 else faulty_task, i)
|
||||
for i in range(10)
|
||||
)
|
||||
except* ValueError as e:
|
||||
assert len(e.exceptions) == 5
|
||||
for exception in e.exceptions:
|
||||
assert str(exception) == "I'm very faulty"
|
||||
|
@ -207,9 +207,13 @@ class ResilientCallWrapper:
|
||||
|
||||
class DelugeDownloadHandle(DownloadHandle):
|
||||
def __init__(self, torrent: Torrent, node: DelugeNode) -> None:
|
||||
self.node = node
|
||||
self._node = node
|
||||
self.torrent = torrent
|
||||
|
||||
@property
|
||||
def node(self) -> DelugeNode:
|
||||
return self._node
|
||||
|
||||
def await_for_completion(self, timeout: float = 0) -> bool:
|
||||
name = self.torrent.name
|
||||
|
||||
|
@ -26,11 +26,9 @@ types-pyyaml = "^6.0.12.20240917"
|
||||
types-requests = "^2.32.0.20241016"
|
||||
httpx = "^0.28.1"
|
||||
|
||||
|
||||
[tool.poetry.group.dev.dependencies]
|
||||
pre-commit = "^4.0.1"
|
||||
|
||||
|
||||
[tool.poetry.group.agent.dependencies]
|
||||
uvicorn = "^0.34.0"
|
||||
|
||||
@ -42,6 +40,9 @@ markers = [
|
||||
[tool.mypy]
|
||||
ignore_missing_imports = true
|
||||
|
||||
[tool.ruff]
|
||||
target-version = "py312"
|
||||
|
||||
[build-system]
|
||||
requires = ["poetry-core"]
|
||||
build-backend = "poetry.core.masonry.api"
|
||||
|
Loading…
x
Reference in New Issue
Block a user