feat: add logging of download metrics (assuming streaming) to Codex agent

This commit is contained in:
gmega 2025-02-03 15:23:53 -03:00
parent cb941b859f
commit 7844a3e338
No known key found for this signature in database
GPG Key ID: 6290D34EAD824B18
3 changed files with 69 additions and 2 deletions

View File

@ -1,10 +1,12 @@
import asyncio
import logging
from asyncio import Task
from pathlib import Path
from tempfile import TemporaryDirectory
from typing import Optional
from benchmarks.codex.client import CodexClient, Manifest
from benchmarks.codex.logging import CodexDownloadMetric
from benchmarks.core.utils.random import random_data
from benchmarks.core.utils.streams import BaseStreamReader
@ -12,6 +14,8 @@ Cid = str
EMPTY_STREAM_BACKOFF = 0.1
logger = logging.getLogger(__name__)
class DownloadHandle:
def __init__(
@ -43,6 +47,13 @@ class DownloadHandle:
if not bytes_read:
await asyncio.sleep(EMPTY_STREAM_BACKOFF)
self.bytes_downloaded += len(bytes_read)
logger.info(
CodexDownloadMetric(
cid=self.manifest.cid,
value=self.bytes_downloaded,
node=self.parent.node_id,
)
)
if self.bytes_downloaded < self.manifest.datasetSize:
raise EOFError(
@ -68,8 +79,9 @@ class DownloadHandle:
class CodexAgent:
def __init__(self, client: CodexClient) -> None:
def __init__(self, client: CodexClient, node_id: str = "unknown") -> None:
self.client = client
self.node_id = node_id
async def create_dataset(self, name: str, size: int, seed: Optional[int]) -> Cid:
with TemporaryDirectory() as td:
@ -83,11 +95,12 @@ class CodexAgent:
name=name, mime_type="application/octet-stream", content=infile
)
async def download(self, cid: Cid) -> DownloadHandle:
async def download(self, cid: Cid, read_increment: float = 0.01) -> DownloadHandle:
handle = DownloadHandle(
self,
manifest=await self.client.get_manifest(cid),
download_stream=await self.client.download(cid),
read_increment=read_increment,
)
handle.begin_download()

View File

@ -0,0 +1,6 @@
from benchmarks.logging.logging import Metric
class CodexDownloadMetric(Metric):
name: str = "codex_download"
cid: str

View File

@ -1,12 +1,16 @@
from asyncio import StreamReader
from io import StringIO
from typing import IO, Dict
from unittest.mock import patch
import pytest
from benchmarks.codex.agent import CodexAgent
from benchmarks.codex.client import CodexClient, Cid, Manifest
from benchmarks.codex.logging import CodexDownloadMetric
from benchmarks.core.concurrency import await_predicate_async
from benchmarks.core.utils.streams import BaseStreamReader
from benchmarks.logging.logging import LogParser
class FakeCodexClient(CodexClient):
@ -116,3 +120,47 @@ async def test_should_raise_exception_on_progress_query_if_download_fails():
return False
await await_predicate_async(_predicate, timeout=5)
@pytest.mark.asyncio
async def test_should_log_download_progress_as_metric(mock_logger):
logger, output = mock_logger
with patch("benchmarks.codex.agent.logger", logger):
client = FakeCodexClient()
codex_agent = CodexAgent(client)
cid = await codex_agent.create_dataset(size=1000, name="dataset-1", seed=1234)
download_stream = client.create_download_stream(cid)
download_stream.feed_data(b"0" * 1000)
download_stream.feed_eof()
handle = await codex_agent.download(cid, read_increment=0.2)
await handle.download_task
parser = LogParser()
parser.register(CodexDownloadMetric)
metrics = list(parser.parse(StringIO(output.getvalue())))
assert metrics == [
CodexDownloadMetric(
cid=cid, value=200, node=codex_agent.node_id, timestamp=metrics[0].timestamp
),
CodexDownloadMetric(
cid=cid, value=400, node=codex_agent.node_id, timestamp=metrics[1].timestamp
),
CodexDownloadMetric(
cid=cid, value=600, node=codex_agent.node_id, timestamp=metrics[2].timestamp
),
CodexDownloadMetric(
cid=cid, value=800, node=codex_agent.node_id, timestamp=metrics[3].timestamp
),
CodexDownloadMetric(
cid=cid,
value=1000,
node=codex_agent.node_id,
timestamp=metrics[4].timestamp,
),
]