feat: add logging of download metrics (assuming streaming) to Codex agent

This commit is contained in:
gmega 2025-02-03 15:23:53 -03:00
parent cb941b859f
commit 7844a3e338
No known key found for this signature in database
GPG Key ID: 6290D34EAD824B18
3 changed files with 69 additions and 2 deletions

View File

@ -1,10 +1,12 @@
import asyncio import asyncio
import logging
from asyncio import Task from asyncio import Task
from pathlib import Path from pathlib import Path
from tempfile import TemporaryDirectory from tempfile import TemporaryDirectory
from typing import Optional from typing import Optional
from benchmarks.codex.client import CodexClient, Manifest from benchmarks.codex.client import CodexClient, Manifest
from benchmarks.codex.logging import CodexDownloadMetric
from benchmarks.core.utils.random import random_data from benchmarks.core.utils.random import random_data
from benchmarks.core.utils.streams import BaseStreamReader from benchmarks.core.utils.streams import BaseStreamReader
@ -12,6 +14,8 @@ Cid = str
EMPTY_STREAM_BACKOFF = 0.1 EMPTY_STREAM_BACKOFF = 0.1
logger = logging.getLogger(__name__)
class DownloadHandle: class DownloadHandle:
def __init__( def __init__(
@ -43,6 +47,13 @@ class DownloadHandle:
if not bytes_read: if not bytes_read:
await asyncio.sleep(EMPTY_STREAM_BACKOFF) await asyncio.sleep(EMPTY_STREAM_BACKOFF)
self.bytes_downloaded += len(bytes_read) self.bytes_downloaded += len(bytes_read)
logger.info(
CodexDownloadMetric(
cid=self.manifest.cid,
value=self.bytes_downloaded,
node=self.parent.node_id,
)
)
if self.bytes_downloaded < self.manifest.datasetSize: if self.bytes_downloaded < self.manifest.datasetSize:
raise EOFError( raise EOFError(
@ -68,8 +79,9 @@ class DownloadHandle:
class CodexAgent: class CodexAgent:
def __init__(self, client: CodexClient) -> None: def __init__(self, client: CodexClient, node_id: str = "unknown") -> None:
self.client = client self.client = client
self.node_id = node_id
async def create_dataset(self, name: str, size: int, seed: Optional[int]) -> Cid: async def create_dataset(self, name: str, size: int, seed: Optional[int]) -> Cid:
with TemporaryDirectory() as td: with TemporaryDirectory() as td:
@ -83,11 +95,12 @@ class CodexAgent:
name=name, mime_type="application/octet-stream", content=infile name=name, mime_type="application/octet-stream", content=infile
) )
async def download(self, cid: Cid) -> DownloadHandle: async def download(self, cid: Cid, read_increment: float = 0.01) -> DownloadHandle:
handle = DownloadHandle( handle = DownloadHandle(
self, self,
manifest=await self.client.get_manifest(cid), manifest=await self.client.get_manifest(cid),
download_stream=await self.client.download(cid), download_stream=await self.client.download(cid),
read_increment=read_increment,
) )
handle.begin_download() handle.begin_download()

View File

@ -0,0 +1,6 @@
from benchmarks.logging.logging import Metric
class CodexDownloadMetric(Metric):
name: str = "codex_download"
cid: str

View File

@ -1,12 +1,16 @@
from asyncio import StreamReader from asyncio import StreamReader
from io import StringIO
from typing import IO, Dict from typing import IO, Dict
from unittest.mock import patch
import pytest import pytest
from benchmarks.codex.agent import CodexAgent from benchmarks.codex.agent import CodexAgent
from benchmarks.codex.client import CodexClient, Cid, Manifest from benchmarks.codex.client import CodexClient, Cid, Manifest
from benchmarks.codex.logging import CodexDownloadMetric
from benchmarks.core.concurrency import await_predicate_async from benchmarks.core.concurrency import await_predicate_async
from benchmarks.core.utils.streams import BaseStreamReader from benchmarks.core.utils.streams import BaseStreamReader
from benchmarks.logging.logging import LogParser
class FakeCodexClient(CodexClient): class FakeCodexClient(CodexClient):
@ -116,3 +120,47 @@ async def test_should_raise_exception_on_progress_query_if_download_fails():
return False return False
await await_predicate_async(_predicate, timeout=5) await await_predicate_async(_predicate, timeout=5)
@pytest.mark.asyncio
async def test_should_log_download_progress_as_metric(mock_logger):
logger, output = mock_logger
with patch("benchmarks.codex.agent.logger", logger):
client = FakeCodexClient()
codex_agent = CodexAgent(client)
cid = await codex_agent.create_dataset(size=1000, name="dataset-1", seed=1234)
download_stream = client.create_download_stream(cid)
download_stream.feed_data(b"0" * 1000)
download_stream.feed_eof()
handle = await codex_agent.download(cid, read_increment=0.2)
await handle.download_task
parser = LogParser()
parser.register(CodexDownloadMetric)
metrics = list(parser.parse(StringIO(output.getvalue())))
assert metrics == [
CodexDownloadMetric(
cid=cid, value=200, node=codex_agent.node_id, timestamp=metrics[0].timestamp
),
CodexDownloadMetric(
cid=cid, value=400, node=codex_agent.node_id, timestamp=metrics[1].timestamp
),
CodexDownloadMetric(
cid=cid, value=600, node=codex_agent.node_id, timestamp=metrics[2].timestamp
),
CodexDownloadMetric(
cid=cid, value=800, node=codex_agent.node_id, timestamp=metrics[3].timestamp
),
CodexDownloadMetric(
cid=cid,
value=1000,
node=codex_agent.node_id,
timestamp=metrics[4].timestamp,
),
]