From 7844a3e33879787643388e3f989526d36e0562c8 Mon Sep 17 00:00:00 2001 From: gmega Date: Mon, 3 Feb 2025 15:23:53 -0300 Subject: [PATCH] feat: add logging of download metrics (assuming streaming) to Codex agent --- benchmarks/codex/agent.py | 17 +++++++- benchmarks/codex/logging.py | 6 +++ benchmarks/codex/tests/test_codex_agent.py | 48 ++++++++++++++++++++++ 3 files changed, 69 insertions(+), 2 deletions(-) create mode 100644 benchmarks/codex/logging.py diff --git a/benchmarks/codex/agent.py b/benchmarks/codex/agent.py index ab1287c..a4343e7 100644 --- a/benchmarks/codex/agent.py +++ b/benchmarks/codex/agent.py @@ -1,10 +1,12 @@ import asyncio +import logging from asyncio import Task from pathlib import Path from tempfile import TemporaryDirectory from typing import Optional from benchmarks.codex.client import CodexClient, Manifest +from benchmarks.codex.logging import CodexDownloadMetric from benchmarks.core.utils.random import random_data from benchmarks.core.utils.streams import BaseStreamReader @@ -12,6 +14,8 @@ Cid = str EMPTY_STREAM_BACKOFF = 0.1 +logger = logging.getLogger(__name__) + class DownloadHandle: def __init__( @@ -43,6 +47,13 @@ class DownloadHandle: if not bytes_read: await asyncio.sleep(EMPTY_STREAM_BACKOFF) self.bytes_downloaded += len(bytes_read) + logger.info( + CodexDownloadMetric( + cid=self.manifest.cid, + value=self.bytes_downloaded, + node=self.parent.node_id, + ) + ) if self.bytes_downloaded < self.manifest.datasetSize: raise EOFError( @@ -68,8 +79,9 @@ class DownloadHandle: class CodexAgent: - def __init__(self, client: CodexClient) -> None: + def __init__(self, client: CodexClient, node_id: str = "unknown") -> None: self.client = client + self.node_id = node_id async def create_dataset(self, name: str, size: int, seed: Optional[int]) -> Cid: with TemporaryDirectory() as td: @@ -83,11 +95,12 @@ class CodexAgent: name=name, mime_type="application/octet-stream", content=infile ) - async def download(self, cid: Cid) -> DownloadHandle: + async def download(self, cid: Cid, read_increment: float = 0.01) -> DownloadHandle: handle = DownloadHandle( self, manifest=await self.client.get_manifest(cid), download_stream=await self.client.download(cid), + read_increment=read_increment, ) handle.begin_download() diff --git a/benchmarks/codex/logging.py b/benchmarks/codex/logging.py new file mode 100644 index 0000000..0e99515 --- /dev/null +++ b/benchmarks/codex/logging.py @@ -0,0 +1,6 @@ +from benchmarks.logging.logging import Metric + + +class CodexDownloadMetric(Metric): + name: str = "codex_download" + cid: str diff --git a/benchmarks/codex/tests/test_codex_agent.py b/benchmarks/codex/tests/test_codex_agent.py index a099cd4..dc29045 100644 --- a/benchmarks/codex/tests/test_codex_agent.py +++ b/benchmarks/codex/tests/test_codex_agent.py @@ -1,12 +1,16 @@ from asyncio import StreamReader +from io import StringIO from typing import IO, Dict +from unittest.mock import patch import pytest from benchmarks.codex.agent import CodexAgent from benchmarks.codex.client import CodexClient, Cid, Manifest +from benchmarks.codex.logging import CodexDownloadMetric from benchmarks.core.concurrency import await_predicate_async from benchmarks.core.utils.streams import BaseStreamReader +from benchmarks.logging.logging import LogParser class FakeCodexClient(CodexClient): @@ -116,3 +120,47 @@ async def test_should_raise_exception_on_progress_query_if_download_fails(): return False await await_predicate_async(_predicate, timeout=5) + + +@pytest.mark.asyncio +async def test_should_log_download_progress_as_metric(mock_logger): + logger, output = mock_logger + + with patch("benchmarks.codex.agent.logger", logger): + client = FakeCodexClient() + codex_agent = CodexAgent(client) + + cid = await codex_agent.create_dataset(size=1000, name="dataset-1", seed=1234) + + download_stream = client.create_download_stream(cid) + download_stream.feed_data(b"0" * 1000) + download_stream.feed_eof() + + handle = await codex_agent.download(cid, read_increment=0.2) + await handle.download_task + + parser = LogParser() + parser.register(CodexDownloadMetric) + + metrics = list(parser.parse(StringIO(output.getvalue()))) + + assert metrics == [ + CodexDownloadMetric( + cid=cid, value=200, node=codex_agent.node_id, timestamp=metrics[0].timestamp + ), + CodexDownloadMetric( + cid=cid, value=400, node=codex_agent.node_id, timestamp=metrics[1].timestamp + ), + CodexDownloadMetric( + cid=cid, value=600, node=codex_agent.node_id, timestamp=metrics[2].timestamp + ), + CodexDownloadMetric( + cid=cid, value=800, node=codex_agent.node_id, timestamp=metrics[3].timestamp + ), + CodexDownloadMetric( + cid=cid, + value=1000, + node=codex_agent.node_id, + timestamp=metrics[4].timestamp, + ), + ]