167 lines
5.1 KiB
Python

from asyncio import StreamReader
from io import StringIO
from typing import IO, Dict
from unittest.mock import patch
import pytest
from benchmarks.codex.agent import CodexAgent
from benchmarks.codex.client import CodexClient, Cid, Manifest
from benchmarks.codex.logging import CodexDownloadMetric
from benchmarks.core.concurrency import await_predicate_async
from benchmarks.core.utils.streams import BaseStreamReader
from benchmarks.logging.logging import LogParser
class FakeCodexClient(CodexClient):
    """In-memory stand-in for ``CodexClient`` used by the tests below.

    Uploads are recorded as ``Manifest`` entries in ``storage``; downloads are
    served from ``StreamReader`` objects that a test registers beforehand via
    ``create_download_stream`` and then feeds by hand.
    """

    def __init__(self) -> None:
        # Manifests of everything uploaded so far, keyed by cid.
        self.storage: Dict[Cid, Manifest] = {}
        # Test-controlled download streams, keyed by cid.
        self.streams: Dict[Cid, StreamReader] = {}

    async def upload(self, name: str, mime_type: str, content: IO) -> Cid:
        """Read ``content`` fully, store a manifest for it and return its cid.

        The cid is derived from the built-in ``hash`` of the content read, so
        identical content maps to the same cid within a single interpreter run.
        """
        payload = content.read()
        content_id = f"Qm{hash(payload)}"
        manifest = Manifest(
            cid=content_id,
            datasetSize=len(payload),
            mimetype=mime_type,
            blockSize=1,
            filename=name,
            treeCid="",
            uploadedAt=0,
            protected=False,
        )
        self.storage[content_id] = manifest
        return content_id

    async def get_manifest(self, cid: Cid) -> Manifest:
        """Return the manifest stored for ``cid`` (``KeyError`` if unknown)."""
        return self.storage[cid]

    def create_download_stream(self, cid: Cid) -> StreamReader:
        """Register and return a fresh stream that ``download(cid)`` will serve."""
        stream = StreamReader()
        self.streams[cid] = stream
        return stream

    async def download(self, cid: Cid) -> BaseStreamReader:
        """Return the stream previously registered for ``cid``.

        Raises ``KeyError`` if ``create_download_stream`` was not called first.
        """
        return self.streams[cid]
@pytest.mark.asyncio
async def test_should_create_dataset_of_right_size():
    """A dataset created with ``size=1024`` reports that size in its manifest."""
    agent = CodexAgent(FakeCodexClient())

    dataset_cid = await agent.create_dataset(size=1024, name="dataset-1", seed=1234)
    stored_manifest = await agent.client.get_manifest(dataset_cid)

    assert stored_manifest.datasetSize == 1024
@pytest.mark.asyncio
async def test_same_seed_creates_same_cid():
    """Equal seeds must yield equal cids; a different seed must not."""
    agent = CodexAgent(FakeCodexClient())

    # Datasets are created one after another, in the order the seeds are listed.
    baseline, same_seed, other_seed = [
        await agent.create_dataset(size=2048, name="dataset-1", seed=seed)
        for seed in (1234, 1234, 1235)
    ]

    assert baseline == same_seed
    assert baseline != other_seed
@pytest.mark.asyncio
async def test_should_report_download_progress():
    """Feed a 1000-byte download 5 bytes at a time and watch progress advance."""
    fake_client = FakeCodexClient()
    agent = CodexAgent(fake_client)

    cid = await agent.create_dataset(size=1000, name="dataset-1", seed=1234)
    stream = fake_client.create_download_stream(cid)

    handle = await agent.download(cid)
    assert handle.progress() == 0

    chunk = b"0" * 5

    async def feed_and_expect(expected_percent: int) -> None:
        # Push one chunk, check it is sitting in the stream buffer, then wait
        # for the reported progress (as a rounded percentage) to hit the
        # expected value and for the buffer to be drained again.
        stream.feed_data(chunk)
        assert len(stream._buffer) == 5
        assert await await_predicate_async(
            lambda: round(handle.progress() * 100) == expected_percent, timeout=5
        )
        assert await await_predicate_async(
            lambda: len(stream._buffer) == 0, timeout=5
        )

    for step in range(100):
        # Two chunks per iteration: the first is expected to leave the rounded
        # percentage at ``step``, the second to move it to ``step + 1``.
        await feed_and_expect(step)
        await feed_and_expect(step + 1)

    stream.feed_eof()
    await handle.download_task
@pytest.mark.asyncio
async def test_should_raise_exception_on_progress_query_if_download_fails():
    """Polling progress after a premature EOF must surface ``EOFError``."""
    fake_client = FakeCodexClient()
    agent = CodexAgent(fake_client)

    cid = await agent.create_dataset(size=1000, name="dataset-1", seed=1234)
    stream = fake_client.create_download_stream(cid)

    handle = await agent.download(cid)
    # Close the stream before any of the 1000 expected bytes have been fed.
    stream.feed_eof()

    def poll_progress():
        # Never satisfied: the predicate exists only so that progress() keeps
        # being called until it raises.
        handle.progress()
        return False

    with pytest.raises(EOFError):
        await await_predicate_async(poll_progress, timeout=5)
@pytest.mark.asyncio
async def test_should_log_download_progress_as_metric(mock_logger):
    """A full download with ``read_increment=0.2`` logs five download metrics."""
    fake_logger, log_output = mock_logger

    with patch("benchmarks.codex.agent.logger", fake_logger):
        fake_client = FakeCodexClient()
        agent = CodexAgent(fake_client)

        cid = await agent.create_dataset(size=1000, name="dataset-1", seed=1234)

        # Make the whole payload plus EOF available before the download starts.
        stream = fake_client.create_download_stream(cid)
        stream.feed_data(b"0" * 1000)
        stream.feed_eof()

        handle = await agent.download(cid, read_increment=0.2)
        await handle.download_task

    parser = LogParser()
    parser.register(CodexDownloadMetric)
    logged_text = log_output.getvalue()
    metrics = list(parser.parse(StringIO(logged_text)))

    # Timestamps are copied from the parsed metrics, so the comparison only
    # pins cid, value and node. Indexing ``metrics`` by position keeps the
    # IndexError if fewer than five metrics were logged.
    expected = [
        CodexDownloadMetric(
            cid=cid,
            value=expected_value,
            node=agent.node_id,
            timestamp=metrics[position].timestamp,
        )
        for position, expected_value in enumerate((200, 400, 600, 800, 1000))
    ]

    assert metrics == expected