from asyncio import StreamReader
from io import StringIO
from typing import IO, Dict
from unittest.mock import patch

import pytest

from benchmarks.codex.agent import CodexAgent
from benchmarks.codex.client import CodexClient, Cid, Manifest
from benchmarks.codex.logging import CodexDownloadMetric
from benchmarks.core.concurrency import await_predicate_async
from benchmarks.core.utils.streams import BaseStreamReader
from benchmarks.logging.logging import LogParser


class FakeCodexClient(CodexClient):
    """In-memory stand-in for a ``CodexClient``, used to drive ``CodexAgent`` in tests.

    Uploaded content is not retained; only a ``Manifest`` describing it is
    stored. Downloads are served from ``StreamReader`` objects that the test
    creates up front and feeds by hand.
    """

    def __init__(self) -> None:
        # Manifests recorded by upload(), keyed by the CID handed back.
        self.storage: Dict[Cid, Manifest] = {}
        # Hand-fed download streams registered via create_download_stream().
        self.streams: Dict[Cid, StreamReader] = {}

    async def upload(self, name: str, mime_type: str, content: IO) -> Cid:
        """Read ``content`` fully and record a manifest for it.

        Returns:
            A fake CID derived from ``hash()`` of the data read.
        """
        data = content.read()
        # NOTE: hash() of str/bytes is salted per interpreter process, so these
        # CIDs are only comparable within a single test run.
        cid = "Qm" + str(hash(data))
        self.storage[cid] = Manifest(
            cid=cid,
            datasetSize=len(data),
            mimetype=mime_type,
            blockSize=1,
            filename=name,
            treeCid="",
            uploadedAt=0,
            protected=False,
        )
        return cid

    async def get_manifest(self, cid: Cid) -> Manifest:
        """Return the manifest stored for ``cid``; raises ``KeyError`` if unknown."""
        return self.storage[cid]

    def create_download_stream(self, cid: Cid) -> StreamReader:
        """Register and return a fresh stream that ``download(cid)`` will serve.

        Test-only helper: the caller feeds the returned reader
        (``feed_data`` / ``feed_eof``) to simulate incoming download data.
        """
        reader = StreamReader()
        self.streams[cid] = reader
        return reader

    async def download(self, cid: Cid) -> BaseStreamReader:
        """Return the stream registered for ``cid``.

        Raises:
            KeyError: if ``create_download_stream(cid)`` was not called first.
        """
        return self.streams[cid]


@pytest.mark.asyncio
async def test_should_create_dataset_of_right_size():
    """The manifest of a created dataset reports the requested size."""
    codex_agent = CodexAgent(FakeCodexClient())
    cid = await codex_agent.create_dataset(size=1024, name="dataset-1", seed=1234)
    manifest = await codex_agent.client.get_manifest(cid)

    assert manifest.datasetSize == 1024


@pytest.mark.asyncio
async def test_same_seed_creates_same_cid():
    """Equal seeds yield equal CIDs; a different seed yields a different CID."""
    codex_agent = CodexAgent(FakeCodexClient())

    cid1 = await codex_agent.create_dataset(size=2048, name="dataset-1", seed=1234)
    cid2 = await codex_agent.create_dataset(size=2048, name="dataset-1", seed=1234)
    cid3 = await codex_agent.create_dataset(size=2048, name="dataset-1", seed=1235)

    assert cid1 == cid2
    assert cid1 != cid3


@pytest.mark.asyncio
async def test_should_report_download_progress():
    """Progress advances as data is fed into the download stream.

    The 1000-byte dataset is fed in 5-byte chunks; the test expects the
    reported percentage to move forward on every second chunk.
    """
    client = FakeCodexClient()
    codex_agent = CodexAgent(client)

    cid = await codex_agent.create_dataset(size=1000, name="dataset-1", seed=1234)
    download_stream = client.create_download_stream(cid)

    handle = await codex_agent.download(cid)

    assert handle.progress() == 0

    async def feed_chunk_and_expect(percent: int) -> None:
        """Feed one 5-byte chunk, then wait for the expected progress and a drained buffer."""
        download_stream.feed_data(b"0" * 5)
        # _buffer is a private asyncio.StreamReader attribute; it is inspected
        # here to confirm that the agent has not consumed the chunk yet.
        assert len(download_stream._buffer) == 5
        assert await await_predicate_async(
            lambda: round(handle.progress() * 100) == percent, timeout=5
        )
        assert await await_predicate_async(
            lambda: len(download_stream._buffer) == 0, timeout=5
        )

    for i in range(100):
        # The expected percent is passed as an argument so that no lambda
        # closes over the loop variable.
        await feed_chunk_and_expect(i)
        await feed_chunk_and_expect(i + 1)

    download_stream.feed_eof()
    await handle.download_task


@pytest.mark.asyncio
async def test_should_raise_exception_on_progress_query_if_download_fails():
    """Querying progress surfaces ``EOFError`` when the stream ends prematurely."""
    client = FakeCodexClient()
    codex_agent = CodexAgent(client)

    cid = await codex_agent.create_dataset(size=1000, name="dataset-1", seed=1234)
    download_stream = client.create_download_stream(cid)

    handle = await codex_agent.download(cid)

    # EOF is signalled before any of the 1000 bytes have been fed.
    download_stream.feed_eof()

    def _predicate():
        # Never satisfied: polling continues until progress() raises or the
        # timeout expires.
        handle.progress()
        return False

    with pytest.raises(EOFError):
        await await_predicate_async(_predicate, timeout=5)


@pytest.mark.asyncio
async def test_should_log_download_progress_as_metric(mock_logger):
    """With ``read_increment=0.2`` a 1000-byte download logs five metrics, 200 apart."""
    logger, output = mock_logger

    with patch("benchmarks.codex.agent.logger", logger):
        client = FakeCodexClient()
        codex_agent = CodexAgent(client)

        cid = await codex_agent.create_dataset(size=1000, name="dataset-1", seed=1234)

        # The whole payload is available before the download starts.
        download_stream = client.create_download_stream(cid)
        download_stream.feed_data(b"0" * 1000)
        download_stream.feed_eof()

        handle = await codex_agent.download(cid, read_increment=0.2)
        await handle.download_task

    parser = LogParser()
    parser.register(CodexDownloadMetric)

    metrics = list(parser.parse(StringIO(output.getvalue())))

    expected_values = [200, 400, 600, 800, 1000]

    # Check the count first: a missing metric should fail with a readable
    # assertion, not with an IndexError while building the expected list.
    assert len(metrics) == len(expected_values)

    # Timestamps are not predictable, so each one is copied from the parsed
    # metric it is compared against.
    assert metrics == [
        CodexDownloadMetric(
            cid=cid,
            value=value,
            node=codex_agent.node_id,
            timestamp=metric.timestamp,
        )
        for value, metric in zip(expected_values, metrics)
    ]