From cb941b859fe162561e47136dcc2d76c49808e6f0 Mon Sep 17 00:00:00 2001 From: gmega Date: Fri, 31 Jan 2025 15:49:58 -0300 Subject: [PATCH] feat: add tracking of download progress via synchronous API at the Codex agent --- benchmarks/codex/agent.py | 76 +++++++++++- benchmarks/codex/client.py | 31 ++++- benchmarks/codex/tests/fixtures.py | 4 +- benchmarks/codex/tests/test_codex_agent.py | 110 ++++++++++++++++-- benchmarks/codex/tests/test_codex_client.py | 2 +- benchmarks/core/concurrency.py | 34 +++++- benchmarks/core/experiments/experiments.py | 2 +- .../core/experiments/tests/test_utils.py | 2 +- benchmarks/core/utils/__init__.py | 0 benchmarks/core/{utils.py => utils/random.py} | 23 +--- benchmarks/core/utils/streams.py | 9 ++ benchmarks/core/utils/units.py | 6 + benchmarks/deluge/agent/agent.py | 3 +- benchmarks/deluge/agent/api.py | 2 +- benchmarks/deluge/config.py | 3 +- benchmarks/deluge/deluge_node.py | 3 +- benchmarks/deluge/tests/fixtures.py | 2 +- benchmarks/deluge/tests/test_deluge_node.py | 3 +- .../tests/test_deluge_static_experiment.py | 3 +- benchmarks/logging/sources/tests/fixtures.py | 2 +- 20 files changed, 270 insertions(+), 50 deletions(-) create mode 100644 benchmarks/core/utils/__init__.py rename benchmarks/core/{utils.py => utils/random.py} (55%) create mode 100644 benchmarks/core/utils/streams.py create mode 100644 benchmarks/core/utils/units.py diff --git a/benchmarks/codex/agent.py b/benchmarks/codex/agent.py index f7cd13b..ab1287c 100644 --- a/benchmarks/codex/agent.py +++ b/benchmarks/codex/agent.py @@ -1,15 +1,74 @@ +import asyncio +from asyncio import Task from pathlib import Path from tempfile import TemporaryDirectory from typing import Optional -from benchmarks.codex.client import CodexClient -from benchmarks.core.utils import random_data +from benchmarks.codex.client import CodexClient, Manifest +from benchmarks.core.utils.random import random_data +from benchmarks.core.utils.streams import BaseStreamReader Cid = str +EMPTY_STREAM_BACKOFF = 0.1 + + +class DownloadHandle: + def __init__( + self, + parent: "CodexAgent", + manifest: Manifest, + download_stream: BaseStreamReader, + read_increment: float = 0.01, + ): + self.parent = parent + self.manifest = manifest + self.bytes_downloaded = 0 + self.read_increment = read_increment + self.download_stream = download_stream + self.download_task: Optional[Task[None]] = None + + def begin_download(self) -> Task: + self.download_task = asyncio.create_task(self._download_loop()) + return self.download_task + + async def _download_loop(self): + step_size = int(self.manifest.datasetSize * self.read_increment) + + while not self.download_stream.at_eof(): + step = min(step_size, self.manifest.datasetSize - self.bytes_downloaded) + bytes_read = await self.download_stream.read(step) + # We actually have no guarantees that an empty read means EOF, so we just back off + # a bit. + if not bytes_read: + await asyncio.sleep(EMPTY_STREAM_BACKOFF) + self.bytes_downloaded += len(bytes_read) + + if self.bytes_downloaded < self.manifest.datasetSize: + raise EOFError( + f"Got EOF too early: download size ({self.bytes_downloaded}) was less " + f"than expected ({self.manifest.datasetSize})." + ) + + if self.bytes_downloaded > self.manifest.datasetSize: + raise ValueError( + f"Download size ({self.bytes_downloaded}) was greater than expected " + f"({self.manifest.datasetSize})." + ) + + def progress(self) -> float: + if self.download_task is None: + return 0 + + if self.download_task.done(): + # This will bubble exceptions up, if any. 
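+ # result() on a completed task returns immediately and re-raises the
+ # EOFError/ValueError recorded by _download_loop, so a failed download
+ # surfaces to the caller polling progress().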
+ self.download_task.result() + + return self.bytes_downloaded / self.manifest.datasetSize + class CodexAgent: - def __init__(self, client: CodexClient): + def __init__(self, client: CodexClient) -> None: self.client = client async def create_dataset(self, name: str, size: int, seed: Optional[int]) -> Cid: @@ -23,3 +82,14 @@ class CodexAgent: return await self.client.upload( name=name, mime_type="application/octet-stream", content=infile ) + + async def download(self, cid: Cid) -> DownloadHandle: + handle = DownloadHandle( + self, + manifest=await self.client.get_manifest(cid), + download_stream=await self.client.download(cid), + ) + + handle.begin_download() + + return handle diff --git a/benchmarks/codex/client.py b/benchmarks/codex/client.py index 9289844..c55efe0 100644 --- a/benchmarks/codex/client.py +++ b/benchmarks/codex/client.py @@ -1,9 +1,12 @@ +from abc import ABC, abstractmethod from typing import IO import aiohttp from pydantic import BaseModel from urllib3.util import Url +from benchmarks.core.utils.streams import BaseStreamReader + API_VERSION = "v1" Cid = str @@ -20,7 +23,21 @@ class Manifest(BaseModel): protected: bool -class CodexClient: +class CodexClient(ABC): + @abstractmethod + async def upload(self, name: str, mime_type: str, content: IO) -> Cid: + pass + + @abstractmethod + async def get_manifest(self, cid: Cid) -> Manifest: + pass + + @abstractmethod + async def download(self, cid: Cid) -> BaseStreamReader: + pass + + +class CodexClientImpl(CodexClient): """A lightweight async wrapper built around the Codex REST API.""" def __init__(self, codex_api_url: Url): @@ -55,3 +72,15 @@ class CodexClient: cid = response_contents.pop("cid") return Manifest.model_validate(dict(cid=cid, **response_contents["manifest"])) + + async def download(self, cid: Cid) -> BaseStreamReader: + async with aiohttp.ClientSession() as session: + response = await session.get( + self.codex_api_url._replace( + path=f"/api/codex/v1/data/{cid}/network/download" + ).url, + ) + + response.raise_for_status() + + return response.content diff --git a/benchmarks/codex/tests/fixtures.py b/benchmarks/codex/tests/fixtures.py index 525b9b0..ee69c77 100644 --- a/benchmarks/codex/tests/fixtures.py +++ b/benchmarks/codex/tests/fixtures.py @@ -3,12 +3,12 @@ import os import pytest from urllib3.util import parse_url -from benchmarks.codex.client import CodexClient +from benchmarks.codex.client import CodexClientImpl @pytest.fixture def codex_client_1(): # TODO wipe data between tests - return CodexClient( + return CodexClientImpl( parse_url(f"http://{os.environ.get('CODEX_NODE_1', 'localhost')}:8091") ) diff --git a/benchmarks/codex/tests/test_codex_agent.py b/benchmarks/codex/tests/test_codex_agent.py index fffd6f6..a099cd4 100644 --- a/benchmarks/codex/tests/test_codex_agent.py +++ b/benchmarks/codex/tests/test_codex_agent.py @@ -1,30 +1,118 @@ +from asyncio import StreamReader +from typing import IO, Dict + import pytest from benchmarks.codex.agent import CodexAgent +from benchmarks.codex.client import CodexClient, Cid, Manifest +from benchmarks.core.concurrency import await_predicate_async +from benchmarks.core.utils.streams import BaseStreamReader -@pytest.fixture -def codex_agent(codex_client_1): - return CodexAgent(codex_client_1) +class FakeCodexClient(CodexClient): + def __init__(self) -> None: + self.storage: Dict[Cid, Manifest] = {} + self.streams: Dict[Cid, StreamReader] = {} + + async def upload(self, name: str, mime_type: str, content: IO) -> Cid: + data = content.read() + cid = "Qm" + 
str(hash(data)) + self.storage[cid] = Manifest( + cid=cid, + datasetSize=len(data), + mimetype=mime_type, + blockSize=1, + filename=name, + treeCid="", + uploadedAt=0, + protected=False, + ) + return cid + + async def get_manifest(self, cid: Cid) -> Manifest: + return self.storage[cid] + + def create_download_stream(self, cid: Cid) -> StreamReader: + reader = StreamReader() + self.streams[cid] = reader + return reader + + async def download(self, cid: Cid) -> BaseStreamReader: + return self.streams[cid] -@pytest.mark.codex_integration @pytest.mark.asyncio -async def test_should_create_dataset(codex_agent: CodexAgent): +async def test_should_create_dataset_of_right_size(): + codex_agent = CodexAgent(FakeCodexClient()) cid = await codex_agent.create_dataset(size=1024, name="dataset-1", seed=1234) + manifest = await codex_agent.client.get_manifest(cid) - dataset = await codex_agent.client.get_manifest(cid) - - assert dataset.cid == cid - assert dataset.datasetSize == 1024 + assert manifest.datasetSize == 1024 -@pytest.mark.codex_integration @pytest.mark.asyncio -async def test_same_seed_creates_same_cid(codex_agent: CodexAgent): +async def test_same_seed_creates_same_cid(): + codex_agent = CodexAgent(FakeCodexClient()) + cid1 = await codex_agent.create_dataset(size=2048, name="dataset-1", seed=1234) cid2 = await codex_agent.create_dataset(size=2048, name="dataset-1", seed=1234) cid3 = await codex_agent.create_dataset(size=2048, name="dataset-1", seed=1235) assert cid1 == cid2 assert cid1 != cid3 + + +@pytest.mark.asyncio +async def test_should_report_download_progress(): + client = FakeCodexClient() + codex_agent = CodexAgent(client) + + cid = await codex_agent.create_dataset(size=1000, name="dataset-1", seed=1234) + download_stream = client.create_download_stream(cid) + + handle = await codex_agent.download(cid) + + assert handle.progress() == 0 + + for i in range(100): + download_stream.feed_data(b"0" * 5) + assert len(download_stream._buffer) == 5 + assert await await_predicate_async( + lambda: round(handle.progress() * 100) == i, timeout=5 + ) + assert await await_predicate_async( + lambda: len(download_stream._buffer) == 0, timeout=5 + ) + + download_stream.feed_data(b"0" * 5) + assert len(download_stream._buffer) == 5 + assert await await_predicate_async( + lambda: round(handle.progress() * 100) == (i + 1), timeout=5 + ) + assert await await_predicate_async( + lambda: len(download_stream._buffer) == 0, timeout=5 + ) + + download_stream.feed_eof() + await handle.download_task + + +@pytest.mark.asyncio +async def test_should_raise_exception_on_progress_query_if_download_fails(): + client = FakeCodexClient() + codex_agent = CodexAgent(client) + + cid = await codex_agent.create_dataset(size=1000, name="dataset-1", seed=1234) + download_stream = client.create_download_stream(cid) + + handle = await codex_agent.download(cid) + + download_stream.feed_eof() + + with pytest.raises(EOFError): + + def _predicate(): + handle.progress() + return False + + await await_predicate_async(_predicate, timeout=5) diff --git a/benchmarks/codex/tests/test_codex_client.py b/benchmarks/codex/tests/test_codex_client.py index 10f7739..a8a0f51 100644 --- a/benchmarks/codex/tests/test_codex_client.py +++ b/benchmarks/codex/tests/test_codex_client.py @@ -2,7 +2,7 @@ from io import BytesIO import pytest -from benchmarks.core.utils import random_data +from benchmarks.core.utils.random import random_data @pytest.fixture diff --git a/benchmarks/core/concurrency.py b/benchmarks/core/concurrency.py index e1d53ce..58ededf 
100644 --- a/benchmarks/core/concurrency.py +++ b/benchmarks/core/concurrency.py @@ -1,11 +1,43 @@ +import asyncio from concurrent import futures from concurrent.futures.thread import ThreadPoolExecutor from queue import Queue -from typing import Iterable, Iterator, List, cast +from time import time, sleep +from typing import Iterable, Iterator, List, cast, Awaitable, Callable from typing_extensions import TypeVar +def await_predicate( + predicate: Callable[[], bool], timeout: float = 0, polling_interval: float = 0 +) -> bool: + start_time = time() + while (timeout == 0) or ((time() - start_time) <= timeout): + if predicate(): + return True + sleep(polling_interval) + + return False + + +async def await_predicate_async( + predicate: Callable[[], Awaitable[bool]] | Callable[[], bool], + timeout: float = 0, + polling_interval: float = 0, +) -> bool: + start_time = time() + while (timeout == 0) or ((time() - start_time) <= timeout): + if asyncio.iscoroutinefunction(predicate): + if await predicate(): + return True + else: + if predicate(): + return True + await asyncio.sleep(polling_interval) + + return False + + class _End: pass diff --git a/benchmarks/core/experiments/experiments.py b/benchmarks/core/experiments/experiments.py index c5d8c90..05f0705 100644 --- a/benchmarks/core/experiments/experiments.py +++ b/benchmarks/core/experiments/experiments.py @@ -8,8 +8,8 @@ from typing import List, Optional from typing_extensions import Generic, TypeVar +from benchmarks.core.concurrency import await_predicate from benchmarks.core.config import Builder -from benchmarks.core.utils import await_predicate logger = logging.getLogger(__name__) diff --git a/benchmarks/core/experiments/tests/test_utils.py b/benchmarks/core/experiments/tests/test_utils.py index ceeab6c..06ef50f 100644 --- a/benchmarks/core/experiments/tests/test_utils.py +++ b/benchmarks/core/experiments/tests/test_utils.py @@ -1,6 +1,6 @@ from io import BytesIO -from benchmarks.core.utils import random_data +from benchmarks.core.utils.random import random_data def test_should_generate_the_requested_amount_of_bytes(): diff --git a/benchmarks/core/utils/__init__.py b/benchmarks/core/utils/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/benchmarks/core/utils.py b/benchmarks/core/utils/random.py similarity index 55% rename from benchmarks/core/utils.py rename to benchmarks/core/utils/random.py index 4f56519..142e922 100644 --- a/benchmarks/core/utils.py +++ b/benchmarks/core/utils/random.py @@ -1,18 +1,7 @@ import random -from time import time, sleep -from typing import Iterator, Optional, Callable, IO +from typing import Iterator, IO, Optional - -def await_predicate( - predicate: Callable[[], bool], timeout: float = 0, polling_interval: float = 0 -) -> bool: - start_time = time() - while (timeout == 0) or ((time() - start_time) <= timeout): - if predicate(): - return True - sleep(polling_interval) - - return False +from benchmarks.core.utils.units import megabytes def sample(n: int) -> Iterator[int]: @@ -26,14 +15,6 @@ def sample(n: int) -> Iterator[int]: yield p[i] -def kilobytes(n: int) -> int: - return n * 1024 - - -def megabytes(n: int) -> int: - return kilobytes(n) * 1024 - - def random_data( size: int, outfile: IO, batch_size: int = megabytes(50), seed: Optional[int] = None ): diff --git a/benchmarks/core/utils/streams.py b/benchmarks/core/utils/streams.py new file mode 100644 index 0000000..798a71a --- /dev/null +++ b/benchmarks/core/utils/streams.py @@ -0,0 +1,9 @@ +from typing import Protocol + + +class 
BaseStreamReader(Protocol): + async def read(self, n: int) -> bytes: ... + + def feed_data(self, data: bytes) -> None: ... + + def at_eof(self) -> bool: ... diff --git a/benchmarks/core/utils/units.py b/benchmarks/core/utils/units.py new file mode 100644 index 0000000..b95facd --- /dev/null +++ b/benchmarks/core/utils/units.py @@ -0,0 +1,6 @@ +def kilobytes(n: int) -> int: + return n * 1024 + + +def megabytes(n: int) -> int: + return kilobytes(n) * 1024 diff --git a/benchmarks/deluge/agent/agent.py b/benchmarks/deluge/agent/agent.py index d23f13a..05398a3 100644 --- a/benchmarks/deluge/agent/agent.py +++ b/benchmarks/deluge/agent/agent.py @@ -4,7 +4,8 @@ from typing import Optional from torrentool.torrent import Torrent -from benchmarks.core.utils import random_data, megabytes +from benchmarks.core.utils.random import random_data +from benchmarks.core.utils.units import megabytes logger = logging.getLogger(__name__) diff --git a/benchmarks/deluge/agent/api.py b/benchmarks/deluge/agent/api.py index 85e883d..3104703 100644 --- a/benchmarks/deluge/agent/api.py +++ b/benchmarks/deluge/agent/api.py @@ -4,8 +4,8 @@ from typing import Annotated, Optional from fastapi import FastAPI, Depends, APIRouter, Response from benchmarks.core.agent import AgentBuilder +from benchmarks.core.utils.units import megabytes -from benchmarks.core.utils import megabytes from benchmarks.deluge.agent.agent import DelugeAgent router = APIRouter() diff --git a/benchmarks/deluge/config.py b/benchmarks/deluge/config.py index a5d8586..b689dc4 100644 --- a/benchmarks/deluge/config.py +++ b/benchmarks/deluge/config.py @@ -15,7 +15,8 @@ from benchmarks.core.experiments.experiments import ( from benchmarks.core.experiments.iterated_experiment import IteratedExperiment from benchmarks.core.experiments.static_experiment import StaticDisseminationExperiment from benchmarks.core.pydantic import Host -from benchmarks.core.utils import sample +from benchmarks.core.utils.random import sample + from benchmarks.deluge.agent.client import DelugeAgentClient from benchmarks.deluge.deluge_node import DelugeMeta, DelugeNode from benchmarks.deluge.tracker import Tracker diff --git a/benchmarks/deluge/deluge_node.py b/benchmarks/deluge/deluge_node.py index 77e32eb..1f72498 100644 --- a/benchmarks/deluge/deluge_node.py +++ b/benchmarks/deluge/deluge_node.py @@ -22,9 +22,10 @@ from tenacity.wait import wait_base from torrentool.torrent import Torrent from urllib3.util import Url +from benchmarks.core.concurrency import await_predicate from benchmarks.core.experiments.experiments import ExperimentComponent from benchmarks.core.network import DownloadHandle, Node -from benchmarks.core.utils import await_predicate + from benchmarks.deluge.agent.client import DelugeAgentClient logger = logging.getLogger(__name__) diff --git a/benchmarks/deluge/tests/fixtures.py b/benchmarks/deluge/tests/fixtures.py index 5458505..cee4acc 100644 --- a/benchmarks/deluge/tests/fixtures.py +++ b/benchmarks/deluge/tests/fixtures.py @@ -5,7 +5,7 @@ from typing import Generator import pytest from urllib3.util import parse_url -from benchmarks.core.utils import await_predicate +from benchmarks.core.concurrency import await_predicate from benchmarks.deluge.agent.client import DelugeAgentClient from benchmarks.deluge.deluge_node import DelugeNode from benchmarks.deluge.tracker import Tracker diff --git a/benchmarks/deluge/tests/test_deluge_node.py b/benchmarks/deluge/tests/test_deluge_node.py index 5e3e5a6..2c82c7a 100644 --- a/benchmarks/deluge/tests/test_deluge_node.py 
+++ b/benchmarks/deluge/tests/test_deluge_node.py @@ -1,7 +1,8 @@ import pytest from tenacity import wait_incrementing, stop_after_attempt, RetryError -from benchmarks.core.utils import megabytes, await_predicate +from benchmarks.core.concurrency import await_predicate +from benchmarks.core.utils.units import megabytes from benchmarks.deluge.deluge_node import DelugeNode, DelugeMeta, ResilientCallWrapper from benchmarks.deluge.tracker import Tracker diff --git a/benchmarks/deluge/tests/test_deluge_static_experiment.py b/benchmarks/deluge/tests/test_deluge_static_experiment.py index 1a12ad0..2c63e64 100644 --- a/benchmarks/deluge/tests/test_deluge_static_experiment.py +++ b/benchmarks/deluge/tests/test_deluge_static_experiment.py @@ -2,7 +2,8 @@ import pytest from benchmarks.core.experiments.experiments import ExperimentEnvironment from benchmarks.core.experiments.static_experiment import StaticDisseminationExperiment -from benchmarks.core.utils import megabytes +from benchmarks.core.utils.units import megabytes + from benchmarks.deluge.deluge_node import DelugeMeta from benchmarks.deluge.tests.test_deluge_node import assert_is_seed diff --git a/benchmarks/logging/sources/tests/fixtures.py b/benchmarks/logging/sources/tests/fixtures.py index 7287e25..dd61483 100644 --- a/benchmarks/logging/sources/tests/fixtures.py +++ b/benchmarks/logging/sources/tests/fixtures.py @@ -6,7 +6,7 @@ from typing import Dict, Any import pytest from elasticsearch import Elasticsearch -from benchmarks.core.utils import await_predicate +from benchmarks.core.concurrency import await_predicate def _json_data(data: str) -> Dict[str, Any]:
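
Usage sketch: a minimal example of how the download-progress API added above might be driven, assuming a Codex node is reachable at a hypothetical address (NODE_URL and the 0.5s polling interval below are placeholders); in a real experiment the upload and the download would typically target different nodes.

    import asyncio

    from urllib3.util import parse_url

    from benchmarks.codex.agent import CodexAgent
    from benchmarks.codex.client import CodexClientImpl
    from benchmarks.core.utils.units import megabytes

    NODE_URL = parse_url("http://localhost:8080")  # hypothetical node address


    async def main() -> None:
        agent = CodexAgent(CodexClientImpl(NODE_URL))

        # Create and upload a 10MB random dataset, then start tracking its download.
        cid = await agent.create_dataset(name="dataset-1", size=megabytes(10), seed=1234)
        handle = await agent.download(cid)
        assert handle.download_task is not None

        # progress() is synchronous and cheap, so it can be polled from a status
        # endpoint; it re-raises any error captured by the background download task.
        while not handle.download_task.done():
            print(f"download progress: {handle.progress() * 100:.1f}%")
            await asyncio.sleep(0.5)

        await handle.download_task  # propagate EOFError/ValueError, if any


    asyncio.run(main())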