refactor: simplify download progress reporting

This commit is contained in:
gmega 2025-02-03 16:45:48 -03:00
parent bd0ef9ca55
commit 849bcad6c8
No known key found for this signature in database
GPG Key ID: 6290D34EAD824B18
2 changed files with 80 additions and 53 deletions

View File

@ -1,9 +1,10 @@
import asyncio
import logging
from asyncio import Task
from dataclasses import dataclass
from pathlib import Path
from tempfile import TemporaryDirectory
from typing import Optional
from typing import Optional, Dict
from benchmarks.codex.agent.codex_client import CodexClient, Manifest
from benchmarks.codex.logging import CodexDownloadMetric
@ -17,6 +18,12 @@ EMPTY_STREAM_BACKOFF = 0.1
logger = logging.getLogger(__name__)
@dataclass
class DownloadStatus:
downloaded: int
total: int
class DownloadHandle:
def __init__(
self,
@ -37,51 +44,58 @@ class DownloadHandle:
return self.download_task
async def _download_loop(self):
step_size = int(self.manifest.datasetSize * self.read_increment)
try:
step_size = int(self.manifest.datasetSize * self.read_increment)
while not self.download_stream.at_eof():
step = min(step_size, self.manifest.datasetSize - self.bytes_downloaded)
bytes_read = await self.download_stream.read(step)
# We actually have no guarantees that an empty read means EOF, so we just back off
# a bit.
if not bytes_read:
await asyncio.sleep(EMPTY_STREAM_BACKOFF)
self.bytes_downloaded += len(bytes_read)
logger.info(
CodexDownloadMetric(
cid=self.manifest.cid,
value=self.bytes_downloaded,
node=self.parent.node_id,
while not self.download_stream.at_eof():
step = min(step_size, self.manifest.datasetSize - self.bytes_downloaded)
bytes_read = await self.download_stream.read(step)
# We actually have no guarantees that an empty read means EOF, so we just back off
# a bit.
if not bytes_read:
await asyncio.sleep(EMPTY_STREAM_BACKOFF)
self.bytes_downloaded += len(bytes_read)
logger.info(
CodexDownloadMetric(
cid=self.manifest.cid,
value=self.bytes_downloaded,
node=self.parent.node_id,
)
)
)
if self.bytes_downloaded < self.manifest.datasetSize:
raise EOFError(
f"Got EOF too early: download size ({self.bytes_downloaded}) was less "
f"than expected ({self.manifest.datasetSize})."
)
if self.bytes_downloaded < self.manifest.datasetSize:
raise EOFError(
f"Got EOF too early: download size ({self.bytes_downloaded}) was less "
f"than expected ({self.manifest.datasetSize})."
)
if self.bytes_downloaded > self.manifest.datasetSize:
raise ValueError(
f"Download size ({self.bytes_downloaded}) was greater than expected "
f"({self.manifest.datasetSize})."
)
if self.bytes_downloaded > self.manifest.datasetSize:
raise ValueError(
f"Download size ({self.bytes_downloaded}) was greater than expected "
f"({self.manifest.datasetSize})."
)
finally:
self.parent._download_done(self.manifest.cid)
def progress(self) -> float:
def progress(self) -> DownloadStatus:
if self.download_task is None:
return 0
return DownloadStatus(downloaded=0, total=self.manifest.datasetSize)
if self.download_task.done():
# This will bubble exceptions up, if any.
self.download_task.result()
return self.bytes_downloaded / self.manifest.datasetSize
return DownloadStatus(
downloaded=self.bytes_downloaded, total=self.manifest.datasetSize
)
class CodexAgent:
def __init__(self, client: CodexClient, node_id: str = "unknown") -> None:
self.client = client
self.node_id = node_id
self.ongoing_downloads: Dict[Cid, DownloadHandle] = {}
async def create_dataset(self, name: str, size: int, seed: Optional[int]) -> Cid:
with TemporaryDirectory() as td:
@ -96,6 +110,9 @@ class CodexAgent:
)
async def download(self, cid: Cid, read_increment: float = 0.01) -> DownloadHandle:
if cid in self.ongoing_downloads:
return self.ongoing_downloads[cid]
handle = DownloadHandle(
self,
manifest=await self.client.get_manifest(cid),
@ -105,4 +122,8 @@ class CodexAgent:
handle.begin_download()
self.ongoing_downloads[cid] = handle
return handle
def _download_done(self, cid: Cid):
self.ongoing_downloads.pop(cid)

View File

@ -5,7 +5,7 @@ from unittest.mock import patch
import pytest
from benchmarks.codex.agent.agent import CodexAgent
from benchmarks.codex.agent.agent import CodexAgent, DownloadStatus
from benchmarks.codex.agent.codex_client import CodexClient, Cid, Manifest
from benchmarks.codex.logging import CodexDownloadMetric
from benchmarks.core.concurrency import await_predicate_async
@ -76,30 +76,21 @@ async def test_should_report_download_progress():
handle = await codex_agent.download(cid)
assert handle.progress() == 0
assert handle.progress() == DownloadStatus(downloaded=0, total=1000)
for i in range(100):
for i in range(200):
download_stream.feed_data(b"0" * 5)
assert len(download_stream._buffer) == 5
assert await await_predicate_async(
lambda: round(handle.progress() * 100) == i, timeout=5
)
assert await await_predicate_async(
lambda: len(download_stream._buffer) == 0, timeout=5
)
download_stream.feed_data(b"0" * 5)
assert len(download_stream._buffer) == 5
assert await await_predicate_async(
lambda: round(handle.progress() * 100) == (i + 1), timeout=5
)
assert await await_predicate_async(
lambda: len(download_stream._buffer) == 0, timeout=5
lambda: handle.progress()
== DownloadStatus(downloaded=5 * (i + 1), total=1000),
timeout=5,
)
download_stream.feed_eof()
await handle.download_task
assert handle.progress() == DownloadStatus(downloaded=1000, total=1000)
@pytest.mark.asyncio
async def test_should_raise_exception_on_progress_query_if_download_fails():
@ -114,16 +105,11 @@ async def test_should_raise_exception_on_progress_query_if_download_fails():
download_stream.feed_eof()
with pytest.raises(EOFError):
def _predicate():
handle.progress()
return False
await await_predicate_async(_predicate, timeout=5)
await handle.download_task
@pytest.mark.asyncio
async def test_should_log_download_progress_as_metric(mock_logger):
async def test_should_log_download_progress_as_metric_in_discrete_steps(mock_logger):
logger, output = mock_logger
with patch("benchmarks.codex.agent.agent.logger", logger):
@ -164,3 +150,23 @@ async def test_should_log_download_progress_as_metric(mock_logger):
timestamp=metrics[4].timestamp,
),
]
@pytest.mark.asyncio
async def test_should_track_download_handles_and_dispose_of_them_at_the_end():
client = FakeCodexClient()
codex_agent = CodexAgent(client)
cid = await codex_agent.create_dataset(size=1000, name="dataset-1", seed=1356)
download_stream = client.create_download_stream(cid)
handle = await codex_agent.download(cid)
assert codex_agent.ongoing_downloads[cid] == handle
download_stream.feed_data(b"0" * 1000)
download_stream.feed_eof()
await handle.download_task
assert cid not in codex_agent.ongoing_downloads