refactor: simplify download progress reporting

This commit is contained in:
gmega 2025-02-03 16:45:48 -03:00
parent bd0ef9ca55
commit 849bcad6c8
No known key found for this signature in database
GPG Key ID: 6290D34EAD824B18
2 changed files with 80 additions and 53 deletions

View File

@ -1,9 +1,10 @@
import asyncio import asyncio
import logging import logging
from asyncio import Task from asyncio import Task
from dataclasses import dataclass
from pathlib import Path from pathlib import Path
from tempfile import TemporaryDirectory from tempfile import TemporaryDirectory
from typing import Optional from typing import Optional, Dict
from benchmarks.codex.agent.codex_client import CodexClient, Manifest from benchmarks.codex.agent.codex_client import CodexClient, Manifest
from benchmarks.codex.logging import CodexDownloadMetric from benchmarks.codex.logging import CodexDownloadMetric
@ -17,6 +18,12 @@ EMPTY_STREAM_BACKOFF = 0.1
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@dataclass
class DownloadStatus:
downloaded: int
total: int
class DownloadHandle: class DownloadHandle:
def __init__( def __init__(
self, self,
@ -37,6 +44,7 @@ class DownloadHandle:
return self.download_task return self.download_task
async def _download_loop(self): async def _download_loop(self):
try:
step_size = int(self.manifest.datasetSize * self.read_increment) step_size = int(self.manifest.datasetSize * self.read_increment)
while not self.download_stream.at_eof(): while not self.download_stream.at_eof():
@ -47,6 +55,7 @@ class DownloadHandle:
if not bytes_read: if not bytes_read:
await asyncio.sleep(EMPTY_STREAM_BACKOFF) await asyncio.sleep(EMPTY_STREAM_BACKOFF)
self.bytes_downloaded += len(bytes_read) self.bytes_downloaded += len(bytes_read)
logger.info( logger.info(
CodexDownloadMetric( CodexDownloadMetric(
cid=self.manifest.cid, cid=self.manifest.cid,
@ -66,22 +75,27 @@ class DownloadHandle:
f"Download size ({self.bytes_downloaded}) was greater than expected " f"Download size ({self.bytes_downloaded}) was greater than expected "
f"({self.manifest.datasetSize})." f"({self.manifest.datasetSize})."
) )
finally:
self.parent._download_done(self.manifest.cid)
def progress(self) -> float: def progress(self) -> DownloadStatus:
if self.download_task is None: if self.download_task is None:
return 0 return DownloadStatus(downloaded=0, total=self.manifest.datasetSize)
if self.download_task.done(): if self.download_task.done():
# This will bubble exceptions up, if any. # This will bubble exceptions up, if any.
self.download_task.result() self.download_task.result()
return self.bytes_downloaded / self.manifest.datasetSize return DownloadStatus(
downloaded=self.bytes_downloaded, total=self.manifest.datasetSize
)
class CodexAgent: class CodexAgent:
def __init__(self, client: CodexClient, node_id: str = "unknown") -> None: def __init__(self, client: CodexClient, node_id: str = "unknown") -> None:
self.client = client self.client = client
self.node_id = node_id self.node_id = node_id
self.ongoing_downloads: Dict[Cid, DownloadHandle] = {}
async def create_dataset(self, name: str, size: int, seed: Optional[int]) -> Cid: async def create_dataset(self, name: str, size: int, seed: Optional[int]) -> Cid:
with TemporaryDirectory() as td: with TemporaryDirectory() as td:
@ -96,6 +110,9 @@ class CodexAgent:
) )
async def download(self, cid: Cid, read_increment: float = 0.01) -> DownloadHandle: async def download(self, cid: Cid, read_increment: float = 0.01) -> DownloadHandle:
if cid in self.ongoing_downloads:
return self.ongoing_downloads[cid]
handle = DownloadHandle( handle = DownloadHandle(
self, self,
manifest=await self.client.get_manifest(cid), manifest=await self.client.get_manifest(cid),
@ -105,4 +122,8 @@ class CodexAgent:
handle.begin_download() handle.begin_download()
self.ongoing_downloads[cid] = handle
return handle return handle
def _download_done(self, cid: Cid):
self.ongoing_downloads.pop(cid)

View File

@ -5,7 +5,7 @@ from unittest.mock import patch
import pytest import pytest
from benchmarks.codex.agent.agent import CodexAgent from benchmarks.codex.agent.agent import CodexAgent, DownloadStatus
from benchmarks.codex.agent.codex_client import CodexClient, Cid, Manifest from benchmarks.codex.agent.codex_client import CodexClient, Cid, Manifest
from benchmarks.codex.logging import CodexDownloadMetric from benchmarks.codex.logging import CodexDownloadMetric
from benchmarks.core.concurrency import await_predicate_async from benchmarks.core.concurrency import await_predicate_async
@ -76,30 +76,21 @@ async def test_should_report_download_progress():
handle = await codex_agent.download(cid) handle = await codex_agent.download(cid)
assert handle.progress() == 0 assert handle.progress() == DownloadStatus(downloaded=0, total=1000)
for i in range(100): for i in range(200):
download_stream.feed_data(b"0" * 5) download_stream.feed_data(b"0" * 5)
assert len(download_stream._buffer) == 5
assert await await_predicate_async( assert await await_predicate_async(
lambda: round(handle.progress() * 100) == i, timeout=5 lambda: handle.progress()
) == DownloadStatus(downloaded=5 * (i + 1), total=1000),
assert await await_predicate_async( timeout=5,
lambda: len(download_stream._buffer) == 0, timeout=5
)
download_stream.feed_data(b"0" * 5)
assert len(download_stream._buffer) == 5
assert await await_predicate_async(
lambda: round(handle.progress() * 100) == (i + 1), timeout=5
)
assert await await_predicate_async(
lambda: len(download_stream._buffer) == 0, timeout=5
) )
download_stream.feed_eof() download_stream.feed_eof()
await handle.download_task await handle.download_task
assert handle.progress() == DownloadStatus(downloaded=1000, total=1000)
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_should_raise_exception_on_progress_query_if_download_fails(): async def test_should_raise_exception_on_progress_query_if_download_fails():
@ -114,16 +105,11 @@ async def test_should_raise_exception_on_progress_query_if_download_fails():
download_stream.feed_eof() download_stream.feed_eof()
with pytest.raises(EOFError): with pytest.raises(EOFError):
await handle.download_task
def _predicate():
handle.progress()
return False
await await_predicate_async(_predicate, timeout=5)
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_should_log_download_progress_as_metric(mock_logger): async def test_should_log_download_progress_as_metric_in_discrete_steps(mock_logger):
logger, output = mock_logger logger, output = mock_logger
with patch("benchmarks.codex.agent.agent.logger", logger): with patch("benchmarks.codex.agent.agent.logger", logger):
@ -164,3 +150,23 @@ async def test_should_log_download_progress_as_metric(mock_logger):
timestamp=metrics[4].timestamp, timestamp=metrics[4].timestamp,
), ),
] ]
@pytest.mark.asyncio
async def test_should_track_download_handles_and_dispose_of_them_at_the_end():
client = FakeCodexClient()
codex_agent = CodexAgent(client)
cid = await codex_agent.create_dataset(size=1000, name="dataset-1", seed=1356)
download_stream = client.create_download_stream(cid)
handle = await codex_agent.download(cid)
assert codex_agent.ongoing_downloads[cid] == handle
download_stream.feed_data(b"0" * 1000)
download_stream.feed_eof()
await handle.download_task
assert cid not in codex_agent.ongoing_downloads