240 lines
7.4 KiB
Python

import asyncio
from io import StringIO
from unittest.mock import patch
import pytest
from benchmarks.codex.agent.agent import CodexAgent, DownloadStatus
from benchmarks.codex.agent.tests.fake_codex import FakeCodex
from benchmarks.core.concurrency import await_predicate_async
from benchmarks.logging.logging import LogParser, DownloadMetric
@pytest.mark.asyncio
async def test_should_create_dataset_of_right_size():
    """The manifest of a newly created dataset reports the size that was requested."""
    requested_size = 1024
    agent = CodexAgent(FakeCodex())

    created = await agent.create_dataset(
        size=requested_size, name="dataset-1", seed=1234
    )

    assert created.datasetSize == requested_size
@pytest.mark.asyncio
async def test_same_seed_creates_same_cid():
    """Equal seeds yield equal dataset CIDs; a different seed yields a different CID."""
    agent = CodexAgent(FakeCodex())

    async def make(seed):
        # Everything except the seed is held fixed across the three datasets.
        return await agent.create_dataset(size=2048, name="dataset-1", seed=seed)

    first = await make(1234)
    repeat = await make(1234)
    different = await make(1235)

    assert first.cid == repeat.cid
    assert first.cid != different.cid
@pytest.mark.asyncio
async def test_should_report_download_progress():
    """The download handle's progress follows the fake download as it advances."""
    client = FakeCodex()
    agent = CodexAgent(client, status_backoff=0.01)

    manifest = await agent.create_dataset(size=1000, name="dataset-1", seed=1234)
    fake_download = client.new_download(manifest)
    handle = await agent.download(manifest)

    assert handle.progress() == DownloadStatus(downloaded=0, total=1000)

    # Advance 5 blocks at a time (200 steps * 5 == 1000). After each step, wait
    # until the handle reports the new cumulative total before moving on.
    for downloaded in range(5, 1001, 5):
        fake_download.advance_download(blocks=5)
        expected = DownloadStatus(downloaded=downloaded, total=1000)
        assert await await_predicate_async(
            lambda: handle.progress() == expected,
            timeout=5,
        )

    await handle.download_task
    assert handle.progress() == DownloadStatus(downloaded=1000, total=1000)
@pytest.mark.asyncio
async def test_should_raise_exception_on_progress_query_if_download_fails():
    """An error injected into the fake download is re-raised by the download task.

    NOTE(review): despite the test name, this awaits ``handle.download_task``
    rather than calling ``handle.progress()``; confirm which is intended.
    """

    class SomeError(Exception):
        pass

    client = FakeCodex()
    agent = CodexAgent(client)
    manifest = await agent.create_dataset(size=1000, name="dataset-1", seed=1234)
    fake_download = client.new_download(manifest)
    handle = await agent.download(manifest)

    fake_download.abort(SomeError())

    with pytest.raises(SomeError):
        await handle.download_task
@pytest.mark.asyncio
async def test_should_log_download_progress_as_metric_in_discrete_steps(mock_logger):
    """With ``log_increment=0.2``, a size-1000 download logs five DownloadMetrics.

    The fake download is fully advanced before the agent starts downloading. The
    test still expects one metric per 20% step (200, 400, 600, 800, 1000), in order.
    """
    logger, output = mock_logger
    with patch("benchmarks.codex.agent.agent.logger", logger):
        client = FakeCodex()
        agent = CodexAgent(client)
        manifest = await agent.create_dataset(size=1000, name="dataset-1", seed=1234)
        fake_download = client.new_download(manifest)
        fake_download.advance_download(1000)
        handle = await agent.download(manifest, log_increment=0.2)
        await handle.download_task

    parser = LogParser()
    parser.register(DownloadMetric)
    metrics = list(parser.parse(StringIO(output.getvalue())))

    # Timestamps are copied from the parsed metrics, so only the remaining
    # fields (dataset name, value, node) are actually compared.
    assert metrics == [
        DownloadMetric(
            dataset_name="dataset-1",
            value=value,
            node=agent.node_id,
            timestamp=metrics[index].timestamp,
        )
        for index, value in enumerate(range(200, 1001, 200))
    ]
@pytest.mark.asyncio
async def test_should_log_download_progress_as_discrete_steps_even_when_underlying_stream_is_choppy(
    mock_logger,
):
    """Metrics are still logged at exact 20% steps when progress arrives in small chunks.

    The fake download is advanced 37 units at a time. That is always less than the
    200-unit step implied by ``log_increment=0.2`` on a size-1000 dataset. The test
    expects exactly one metric at each of 200, 400, 600, 800 and 1000, in order.
    """
    logger, output = mock_logger
    with patch("benchmarks.codex.agent.agent.logger", logger):
        client = FakeCodex()
        agent = CodexAgent(client, status_backoff=0.01)
        manifest = await agent.create_dataset(size=1000, name="dataset-1", seed=1234)
        fake_download = client.new_download(manifest)
        handle = await agent.download(manifest, log_increment=0.2)

        # Simulate a choppy stream: feed chunks of 37, with the final chunk being
        # just the remainder. After each chunk, wait until the handle reports the
        # new cumulative total.
        for offset in range(0, 1000, 37):
            chunk = min(37, 1000 - offset)
            fake_download.advance_download(chunk)
            fed = offset + chunk
            assert await await_predicate_async(
                lambda: handle.progress() == DownloadStatus(downloaded=fed, total=1000),
                timeout=5,
            )

        await handle.download_task

    parser = LogParser()
    parser.register(DownloadMetric)
    metrics = list(parser.parse(StringIO(output.getvalue())))

    # Timestamps are copied from the parsed metrics, so only the remaining
    # fields (dataset name, value, node) are actually compared.
    assert metrics == [
        DownloadMetric(
            dataset_name="dataset-1",
            value=value,
            node=agent.node_id,
            timestamp=metrics[index].timestamp,
        )
        for index, value in enumerate(range(200, 1001, 200))
    ]
@pytest.mark.asyncio
async def test_should_track_download_handles():
    """The agent registers each download handle in ``ongoing_downloads``.

    Handles are keyed by the manifest's ``treeCid``. The entry is expected to
    still be present after the download task has completed.
    """
    client = FakeCodex()
    agent = CodexAgent(client)
    manifest = await agent.create_dataset(size=1000, name="dataset-1", seed=1356)
    fake_download = client.new_download(manifest)

    # Nothing is tracked before the download starts.
    assert manifest.treeCid not in agent.ongoing_downloads

    handle = await agent.download(manifest)
    assert agent.ongoing_downloads[manifest.treeCid] == handle

    fake_download.advance_download(1000)
    await handle.download_task

    # NOTE(review): the entry survives completion, so "ongoing" here appears to
    # mean "ever started"; confirm this is the intended contract.
    assert manifest.treeCid in agent.ongoing_downloads
@pytest.mark.asyncio
async def test_should_timeout_if_download_goes_for_too_long_without_any_progress():
    """A download that stalls longer than ``progress_timeout`` fails with a timeout.

    A "fast" download that is complete up front runs first as a control. The
    "slow" download then stalls halfway: the test sleeps 0.6 while
    ``progress_timeout`` is 0.5. Awaiting it must raise ``asyncio.TimeoutError``.
    """
    fake_codex = FakeCodex()
    agent = CodexAgent(fake_codex, status_backoff=0.01, progress_timeout=0.5)

    fast = await agent.create_dataset(size=1000, name="dataset-fast-1", seed=1356)
    slow = await agent.create_dataset(size=1000, name="dataset-slow-1", seed=1353)
    fast_download, slow_download = [fake_codex.new_download(m) for m in (fast, slow)]

    # Control: a download that is already complete finishes without raising.
    fast_download.advance_download(1000)
    fast_handle = await agent.download(fast)
    await fast_handle.download_task

    # Stall the slow download at the halfway mark for longer than progress_timeout.
    slow_handle = await agent.download(slow)
    slow_download.advance_download(500)
    await asyncio.sleep(0.6)
    slow_download.advance_download(500)

    with pytest.raises(asyncio.TimeoutError):
        await slow_handle.download_task