2025-02-03 18:00:43 -03:00

175 lines
5.3 KiB
Python

from asyncio import StreamReader
from io import StringIO
from typing import IO, Dict
from unittest.mock import patch
import pytest
from benchmarks.codex.agent.agent import CodexAgent, DownloadStatus
from benchmarks.codex.agent.codex_client import CodexClient, Cid, Manifest
from benchmarks.codex.logging import CodexDownloadMetric
from benchmarks.core.concurrency import await_predicate_async
from benchmarks.core.utils.streams import BaseStreamReader
from benchmarks.logging.logging import LogParser
class FakeCodexClient(CodexClient):
def __init__(self) -> None:
self.storage: Dict[Cid, Manifest] = {}
self.streams: Dict[Cid, StreamReader] = {}
async def upload(self, name: str, mime_type: str, content: IO) -> Cid:
data = content.read()
cid = "Qm" + str(hash(data))
self.storage[cid] = Manifest(
cid=cid,
datasetSize=len(data),
mimetype=mime_type,
blockSize=1,
filename=name,
treeCid="",
uploadedAt=0,
protected=False,
)
return cid
async def get_manifest(self, cid: Cid) -> Manifest:
return self.storage[cid]
def create_download_stream(self, cid: Cid) -> StreamReader:
reader = StreamReader()
self.streams[cid] = reader
return reader
async def download(self, cid: Cid) -> BaseStreamReader:
return self.streams[cid]
@pytest.mark.asyncio
async def test_should_create_dataset_of_right_size():
codex_agent = CodexAgent(FakeCodexClient())
cid = await codex_agent.create_dataset(size=1024, name="dataset-1", seed=1234)
manifest = await codex_agent.client.get_manifest(cid)
assert manifest.datasetSize == 1024
@pytest.mark.asyncio
async def test_same_seed_creates_same_cid():
codex_agent = CodexAgent(FakeCodexClient())
cid1 = await codex_agent.create_dataset(size=2048, name="dataset-1", seed=1234)
cid2 = await codex_agent.create_dataset(size=2048, name="dataset-1", seed=1234)
cid3 = await codex_agent.create_dataset(size=2048, name="dataset-1", seed=1235)
assert cid1 == cid2
assert cid1 != cid3
@pytest.mark.asyncio
async def test_should_report_download_progress():
client = FakeCodexClient()
codex_agent = CodexAgent(client)
cid = await codex_agent.create_dataset(size=1000, name="dataset-1", seed=1234)
download_stream = client.create_download_stream(cid)
handle = await codex_agent.download(cid)
assert handle.progress() == DownloadStatus(downloaded=0, total=1000)
for i in range(200):
download_stream.feed_data(b"0" * 5)
assert await await_predicate_async(
lambda: handle.progress()
== DownloadStatus(downloaded=5 * (i + 1), total=1000),
timeout=5,
)
download_stream.feed_eof()
await handle.download_task
assert handle.progress() == DownloadStatus(downloaded=1000, total=1000)
@pytest.mark.asyncio
async def test_should_raise_exception_on_progress_query_if_download_fails():
client = FakeCodexClient()
codex_agent = CodexAgent(client)
cid = await codex_agent.create_dataset(size=1000, name="dataset-1", seed=1234)
download_stream = client.create_download_stream(cid)
handle = await codex_agent.download(cid)
download_stream.feed_eof()
with pytest.raises(EOFError):
await handle.download_task
@pytest.mark.asyncio
async def test_should_log_download_progress_as_metric_in_discrete_steps(mock_logger):
logger, output = mock_logger
with patch("benchmarks.codex.agent.agent.logger", logger):
client = FakeCodexClient()
codex_agent = CodexAgent(client)
cid = await codex_agent.create_dataset(size=1000, name="dataset-1", seed=1234)
download_stream = client.create_download_stream(cid)
download_stream.feed_data(b"0" * 1000)
download_stream.feed_eof()
handle = await codex_agent.download(cid, read_increment=0.2)
await handle.download_task
parser = LogParser()
parser.register(CodexDownloadMetric)
metrics = list(parser.parse(StringIO(output.getvalue())))
assert metrics == [
CodexDownloadMetric(
cid=cid, value=200, node=codex_agent.node_id, timestamp=metrics[0].timestamp
),
CodexDownloadMetric(
cid=cid, value=400, node=codex_agent.node_id, timestamp=metrics[1].timestamp
),
CodexDownloadMetric(
cid=cid, value=600, node=codex_agent.node_id, timestamp=metrics[2].timestamp
),
CodexDownloadMetric(
cid=cid, value=800, node=codex_agent.node_id, timestamp=metrics[3].timestamp
),
CodexDownloadMetric(
cid=cid,
value=1000,
node=codex_agent.node_id,
timestamp=metrics[4].timestamp,
),
]
@pytest.mark.asyncio
async def test_should_track_download_handles():
client = FakeCodexClient()
codex_agent = CodexAgent(client)
cid = await codex_agent.create_dataset(size=1000, name="dataset-1", seed=1356)
assert cid not in codex_agent.ongoing_downloads
download_stream = client.create_download_stream(cid)
handle = await codex_agent.download(cid)
download_stream.feed_data(b"0" * 1000)
download_stream.feed_eof()
assert codex_agent.ongoing_downloads[cid] == handle
await handle.download_task
assert cid in codex_agent.ongoing_downloads