diff --git a/benchmarks/codex/agent/agent.py b/benchmarks/codex/agent/agent.py index 6f818c1..8de112d 100644 --- a/benchmarks/codex/agent/agent.py +++ b/benchmarks/codex/agent/agent.py @@ -1,9 +1,10 @@ import asyncio import logging from asyncio import Task +from dataclasses import dataclass from pathlib import Path from tempfile import TemporaryDirectory -from typing import Optional +from typing import Optional, Dict from benchmarks.codex.agent.codex_client import CodexClient, Manifest from benchmarks.codex.logging import CodexDownloadMetric @@ -17,6 +18,12 @@ EMPTY_STREAM_BACKOFF = 0.1 logger = logging.getLogger(__name__) +@dataclass +class DownloadStatus: + downloaded: int + total: int + + class DownloadHandle: def __init__( self, @@ -37,51 +44,58 @@ class DownloadHandle: return self.download_task async def _download_loop(self): - step_size = int(self.manifest.datasetSize * self.read_increment) + try: + step_size = int(self.manifest.datasetSize * self.read_increment) - while not self.download_stream.at_eof(): - step = min(step_size, self.manifest.datasetSize - self.bytes_downloaded) - bytes_read = await self.download_stream.read(step) - # We actually have no guarantees that an empty read means EOF, so we just back off - # a bit. - if not bytes_read: - await asyncio.sleep(EMPTY_STREAM_BACKOFF) - self.bytes_downloaded += len(bytes_read) - logger.info( - CodexDownloadMetric( - cid=self.manifest.cid, - value=self.bytes_downloaded, - node=self.parent.node_id, + while not self.download_stream.at_eof(): + step = min(step_size, self.manifest.datasetSize - self.bytes_downloaded) + bytes_read = await self.download_stream.read(step) + # We actually have no guarantees that an empty read means EOF, so we just back off + # a bit. + if not bytes_read: + await asyncio.sleep(EMPTY_STREAM_BACKOFF) + self.bytes_downloaded += len(bytes_read) + + logger.info( + CodexDownloadMetric( + cid=self.manifest.cid, + value=self.bytes_downloaded, + node=self.parent.node_id, + ) ) - ) - if self.bytes_downloaded < self.manifest.datasetSize: - raise EOFError( - f"Got EOF too early: download size ({self.bytes_downloaded}) was less " - f"than expected ({self.manifest.datasetSize})." - ) + if self.bytes_downloaded < self.manifest.datasetSize: + raise EOFError( + f"Got EOF too early: download size ({self.bytes_downloaded}) was less " + f"than expected ({self.manifest.datasetSize})." + ) - if self.bytes_downloaded > self.manifest.datasetSize: - raise ValueError( - f"Download size ({self.bytes_downloaded}) was greater than expected " - f"({self.manifest.datasetSize})." - ) + if self.bytes_downloaded > self.manifest.datasetSize: + raise ValueError( + f"Download size ({self.bytes_downloaded}) was greater than expected " + f"({self.manifest.datasetSize})." + ) + finally: + self.parent._download_done(self.manifest.cid) - def progress(self) -> float: + def progress(self) -> DownloadStatus: if self.download_task is None: - return 0 + return DownloadStatus(downloaded=0, total=self.manifest.datasetSize) if self.download_task.done(): # This will bubble exceptions up, if any. self.download_task.result() - return self.bytes_downloaded / self.manifest.datasetSize + return DownloadStatus( + downloaded=self.bytes_downloaded, total=self.manifest.datasetSize + ) class CodexAgent: def __init__(self, client: CodexClient, node_id: str = "unknown") -> None: self.client = client self.node_id = node_id + self.ongoing_downloads: Dict[Cid, DownloadHandle] = {} async def create_dataset(self, name: str, size: int, seed: Optional[int]) -> Cid: with TemporaryDirectory() as td: @@ -96,6 +110,9 @@ class CodexAgent: ) async def download(self, cid: Cid, read_increment: float = 0.01) -> DownloadHandle: + if cid in self.ongoing_downloads: + return self.ongoing_downloads[cid] + handle = DownloadHandle( self, manifest=await self.client.get_manifest(cid), @@ -105,4 +122,8 @@ class CodexAgent: handle.begin_download() + self.ongoing_downloads[cid] = handle return handle + + def _download_done(self, cid: Cid): + self.ongoing_downloads.pop(cid) diff --git a/benchmarks/codex/agent/tests/test_codex_agent.py b/benchmarks/codex/agent/tests/test_codex_agent.py index 248bccf..a20e8e7 100644 --- a/benchmarks/codex/agent/tests/test_codex_agent.py +++ b/benchmarks/codex/agent/tests/test_codex_agent.py @@ -5,7 +5,7 @@ from unittest.mock import patch import pytest -from benchmarks.codex.agent.agent import CodexAgent +from benchmarks.codex.agent.agent import CodexAgent, DownloadStatus from benchmarks.codex.agent.codex_client import CodexClient, Cid, Manifest from benchmarks.codex.logging import CodexDownloadMetric from benchmarks.core.concurrency import await_predicate_async @@ -76,30 +76,21 @@ async def test_should_report_download_progress(): handle = await codex_agent.download(cid) - assert handle.progress() == 0 + assert handle.progress() == DownloadStatus(downloaded=0, total=1000) - for i in range(100): + for i in range(200): download_stream.feed_data(b"0" * 5) - assert len(download_stream._buffer) == 5 assert await await_predicate_async( - lambda: round(handle.progress() * 100) == i, timeout=5 - ) - assert await await_predicate_async( - lambda: len(download_stream._buffer) == 0, timeout=5 - ) - - download_stream.feed_data(b"0" * 5) - assert len(download_stream._buffer) == 5 - assert await await_predicate_async( - lambda: round(handle.progress() * 100) == (i + 1), timeout=5 - ) - assert await await_predicate_async( - lambda: len(download_stream._buffer) == 0, timeout=5 + lambda: handle.progress() + == DownloadStatus(downloaded=5 * (i + 1), total=1000), + timeout=5, ) download_stream.feed_eof() await handle.download_task + assert handle.progress() == DownloadStatus(downloaded=1000, total=1000) + @pytest.mark.asyncio async def test_should_raise_exception_on_progress_query_if_download_fails(): @@ -114,16 +105,11 @@ async def test_should_raise_exception_on_progress_query_if_download_fails(): download_stream.feed_eof() with pytest.raises(EOFError): - - def _predicate(): - handle.progress() - return False - - await await_predicate_async(_predicate, timeout=5) + await handle.download_task @pytest.mark.asyncio -async def test_should_log_download_progress_as_metric(mock_logger): +async def test_should_log_download_progress_as_metric_in_discrete_steps(mock_logger): logger, output = mock_logger with patch("benchmarks.codex.agent.agent.logger", logger): @@ -164,3 +150,23 @@ async def test_should_log_download_progress_as_metric(mock_logger): timestamp=metrics[4].timestamp, ), ] + + +@pytest.mark.asyncio +async def test_should_track_download_handles_and_dispose_of_them_at_the_end(): + client = FakeCodexClient() + codex_agent = CodexAgent(client) + + cid = await codex_agent.create_dataset(size=1000, name="dataset-1", seed=1356) + download_stream = client.create_download_stream(cid) + + handle = await codex_agent.download(cid) + + assert codex_agent.ongoing_downloads[cid] == handle + + download_stream.feed_data(b"0" * 1000) + download_stream.feed_eof() + + await handle.download_task + + assert cid not in codex_agent.ongoing_downloads