diff --git a/benchmarks/codex/agent/agent.py b/benchmarks/codex/agent/agent.py index 2de32eb..12261e7 100644 --- a/benchmarks/codex/agent/agent.py +++ b/benchmarks/codex/agent/agent.py @@ -5,6 +5,7 @@ from pathlib import Path from tempfile import TemporaryDirectory from typing import Optional, Dict +from aiohttp import ClientTimeout from pydantic import BaseModel from benchmarks.codex.client.async_client import AsyncCodexClient @@ -32,11 +33,13 @@ class DownloadHandle: parent: "CodexAgent", manifest: Manifest, read_increment: float = 0.01, + read_timeout: Optional[float] = None, ): self.parent = parent self.manifest = manifest self.bytes_downloaded = 0 self.read_increment = read_increment + self.read_timeout = read_timeout self.download_task: Optional[Task[None]] = None def begin_download(self) -> Task: @@ -46,7 +49,14 @@ class DownloadHandle: async def _download_loop(self): step_size = int(self.manifest.datasetSize * self.read_increment) - async with self.parent.client.download(self.manifest.cid) as download_stream: + async with self.parent.client.download( + self.manifest.cid, + timeout=ClientTimeout( + total=None, + sock_connect=30, + sock_read=self.read_timeout, + ), + ) as download_stream: logged_step = 0 while not download_stream.at_eof(): step = min(step_size, self.manifest.datasetSize - self.bytes_downloaded) @@ -94,10 +104,16 @@ class DownloadHandle: class CodexAgent: - def __init__(self, client: AsyncCodexClient, node_id: str = "unknown") -> None: + def __init__( + self, + client: AsyncCodexClient, + node_id: str = "unknown", + read_timeout: Optional[float] = None, + ) -> None: self.client = client self.node_id = node_id self.ongoing_downloads: Dict[Cid, DownloadHandle] = {} + self.read_timeout = read_timeout async def create_dataset(self, name: str, size: int, seed: Optional[int]) -> Cid: with TemporaryDirectory() as td: @@ -119,6 +135,7 @@ class CodexAgent: self, manifest=await self.client.manifest(cid), read_increment=read_increment, + read_timeout=self.read_timeout, ) handle.begin_download() diff --git a/benchmarks/codex/agent/api.py b/benchmarks/codex/agent/api.py index 878127a..95b7d20 100644 --- a/benchmarks/codex/agent/api.py +++ b/benchmarks/codex/agent/api.py @@ -1,3 +1,5 @@ +"""This module contains a REST API wrapping :class:`CodexAgent`.""" + from typing import Annotated, Optional from aiohttp import ClientResponseError diff --git a/benchmarks/codex/agent/codex_agent_client.py b/benchmarks/codex/agent/codex_agent_client.py index 3ce5b29..df8d439 100644 --- a/benchmarks/codex/agent/codex_agent_client.py +++ b/benchmarks/codex/agent/codex_agent_client.py @@ -1,3 +1,5 @@ +"""A simple client for interacting with the Codex Agent API.""" + import socket import requests diff --git a/benchmarks/codex/agent/tests/fake_codex.py b/benchmarks/codex/agent/tests/fake_codex.py new file mode 100644 index 0000000..21ffeea --- /dev/null +++ b/benchmarks/codex/agent/tests/fake_codex.py @@ -0,0 +1,113 @@ +import json +import re +from asyncio import StreamReader +from contextlib import asynccontextmanager +from io import BytesIO +from typing import Dict, Optional, AsyncIterator, Tuple, IO + +from aiohttp import web, ClientTimeout +from urllib3.util import Url + +from benchmarks.codex.client.async_client import AsyncCodexClient, Cid +from benchmarks.codex.client.common import Manifest +from benchmarks.core.utils.streams import BaseStreamReader + + +class FakeCodex(AsyncCodexClient): + def __init__(self) -> None: + self.storage: Dict[Cid, Manifest] = {} + self.streams: Dict[Cid, StreamReader] = {} + + async def upload( + self, + name: str, + mime_type: str, + content: IO, + timeout: Optional[ClientTimeout] = None, + ) -> Cid: + data = content.read() + cid = "Qm" + str(hash(data)) + self.storage[cid] = Manifest( + cid=cid, + datasetSize=len(data), + mimetype=mime_type, + blockSize=1, + filename=name, + treeCid="", + protected=False, + ) + return cid + + async def manifest(self, cid: Cid) -> Manifest: + return self.storage[cid] + + def create_download_stream(self, cid: Cid) -> StreamReader: + reader = StreamReader() + self.streams[cid] = reader + return reader + + @asynccontextmanager + async def download( + self, cid: Cid, timeout: Optional[ClientTimeout] = None + ) -> AsyncIterator[BaseStreamReader]: + yield self.streams[cid] + + +@asynccontextmanager +async def fake_codex_api() -> AsyncIterator[Tuple[FakeCodex, Url]]: + codex = FakeCodex() + routes = web.RouteTableDef() + + @routes.get("/api/codex/v1/data/{cid}/network/manifest") + async def manifest(request): + cid = request.match_info["cid"] + assert cid in codex.storage + # Gets the manifest in a similar shape as the Codex response. + manifest = json.loads(codex.storage[cid].model_dump_json()) + return web.json_response( + data={ + "cid": manifest.pop("cid"), + "manifest": manifest, + } + ) + + @routes.post("/api/codex/v1/data") + async def upload(request): + await request.post() + filename = re.findall( + r'filename="(.+)"', request.headers["Content-Disposition"] + )[0] + cid = await codex.upload( + name=filename, + mime_type=request.headers["Content-Type"], + content=BytesIO(await request.read()), + ) + return web.Response(text=cid) + + @routes.get("/api/codex/v1/data/{cid}") + async def download(request): + cid = request.match_info["cid"] + assert cid in codex.streams + reader = codex.streams[cid] + + # We basically copy the stream onto the response. + response = web.StreamResponse() + await response.prepare(request) + while not reader.at_eof(): + await response.write(await reader.read(1024)) + + await response.write_eof() + return response + + app = web.Application() + app.add_routes(routes) + + runner = web.AppRunner(app) + await runner.setup() + site = web.TCPSite(runner, "localhost", 8888) + await site.start() + + try: + yield codex, Url(scheme="http", host="localhost", port=8888) + finally: + await runner.cleanup() diff --git a/benchmarks/codex/agent/tests/test_api.py b/benchmarks/codex/agent/tests/test_api.py index 34d5c54..c5b76ae 100644 --- a/benchmarks/codex/agent/tests/test_api.py +++ b/benchmarks/codex/agent/tests/test_api.py @@ -5,12 +5,12 @@ from starlette.testclient import TestClient from benchmarks.codex.agent import api from benchmarks.codex.agent.agent import CodexAgent -from benchmarks.codex.agent.tests.test_codex_agent import FakeCodexClient +from benchmarks.codex.agent.tests.fake_codex import FakeCodex @pytest.mark.asyncio async def test_should_create_file(): - codex_client = FakeCodexClient() + codex_client = FakeCodex() codex_agent = CodexAgent(codex_client) app = FastAPI() @@ -34,7 +34,7 @@ async def test_should_create_file(): @pytest.mark.asyncio async def test_should_report_when_download_is_complete(): - codex_client = FakeCodexClient() + codex_client = FakeCodex() codex_agent = CodexAgent(codex_client) app = FastAPI() diff --git a/benchmarks/codex/agent/tests/test_codex_agent.py b/benchmarks/codex/agent/tests/test_codex_agent.py index a3bba87..9aea254 100644 --- a/benchmarks/codex/agent/tests/test_codex_agent.py +++ b/benchmarks/codex/agent/tests/test_codex_agent.py @@ -1,54 +1,19 @@ -from asyncio import StreamReader -from contextlib import asynccontextmanager +import asyncio from io import StringIO -from typing import IO, Dict, AsyncIterator from unittest.mock import patch import pytest from benchmarks.codex.agent.agent import CodexAgent, DownloadStatus -from benchmarks.codex.client.async_client import AsyncCodexClient -from benchmarks.codex.client.common import Manifest, Cid +from benchmarks.codex.agent.tests.fake_codex import FakeCodex, fake_codex_api +from benchmarks.codex.client.async_client import AsyncCodexClientImpl from benchmarks.core.concurrency import await_predicate_async -from benchmarks.core.utils.streams import BaseStreamReader from benchmarks.logging.logging import LogParser, DownloadMetric -class FakeCodexClient(AsyncCodexClient): - def __init__(self) -> None: - self.storage: Dict[Cid, Manifest] = {} - self.streams: Dict[Cid, StreamReader] = {} - - async def upload(self, name: str, mime_type: str, content: IO) -> Cid: - data = content.read() - cid = "Qm" + str(hash(data)) - self.storage[cid] = Manifest( - cid=cid, - datasetSize=len(data), - mimetype=mime_type, - blockSize=1, - filename=name, - treeCid="", - protected=False, - ) - return cid - - async def manifest(self, cid: Cid) -> Manifest: - return self.storage[cid] - - def create_download_stream(self, cid: Cid) -> StreamReader: - reader = StreamReader() - self.streams[cid] = reader - return reader - - @asynccontextmanager - async def download(self, cid: Cid) -> AsyncIterator[BaseStreamReader]: - yield self.streams[cid] - - @pytest.mark.asyncio async def test_should_create_dataset_of_right_size(): - codex_agent = CodexAgent(FakeCodexClient()) + codex_agent = CodexAgent(FakeCodex()) cid = await codex_agent.create_dataset(size=1024, name="dataset-1", seed=1234) manifest = await codex_agent.client.manifest(cid) @@ -57,7 +22,7 @@ async def test_should_create_dataset_of_right_size(): @pytest.mark.asyncio async def test_same_seed_creates_same_cid(): - codex_agent = CodexAgent(FakeCodexClient()) + codex_agent = CodexAgent(FakeCodex()) cid1 = await codex_agent.create_dataset(size=2048, name="dataset-1", seed=1234) cid2 = await codex_agent.create_dataset(size=2048, name="dataset-1", seed=1234) @@ -69,7 +34,7 @@ async def test_same_seed_creates_same_cid(): @pytest.mark.asyncio async def test_should_report_download_progress(): - client = FakeCodexClient() + client = FakeCodex() codex_agent = CodexAgent(client) cid = await codex_agent.create_dataset(size=1000, name="dataset-1", seed=1234) @@ -95,7 +60,7 @@ async def test_should_report_download_progress(): @pytest.mark.asyncio async def test_should_raise_exception_on_progress_query_if_download_fails(): - client = FakeCodexClient() + client = FakeCodex() codex_agent = CodexAgent(client) cid = await codex_agent.create_dataset(size=1000, name="dataset-1", seed=1234) @@ -114,7 +79,7 @@ async def test_should_log_download_progress_as_metric_in_discrete_steps(mock_log logger, output = mock_logger with patch("benchmarks.codex.agent.agent.logger", logger): - client = FakeCodexClient() + client = FakeCodex() codex_agent = CodexAgent(client) cid = await codex_agent.create_dataset(size=1000, name="dataset-1", seed=1234) @@ -177,7 +142,7 @@ async def test_should_log_download_progress_as_discrete_steps_even_when_underlyi logger, output = mock_logger with patch("benchmarks.codex.agent.agent.logger", logger): - client = FakeCodexClient() + client = FakeCodex() codex_agent = CodexAgent(client) cid = await codex_agent.create_dataset(size=1000, name="dataset-1", seed=1234) download_stream = client.create_download_stream(cid) @@ -244,7 +209,7 @@ async def test_should_log_download_progress_as_discrete_steps_even_when_underlyi @pytest.mark.asyncio async def test_should_track_download_handles(): - client = FakeCodexClient() + client = FakeCodex() codex_agent = CodexAgent(client) cid = await codex_agent.create_dataset(size=1000, name="dataset-1", seed=1356) @@ -262,3 +227,34 @@ async def test_should_track_download_handles(): await handle.download_task assert cid in codex_agent.ongoing_downloads + + +@pytest.mark.asyncio +async def test_should_timeout_if_download_stream_takes_too_long_to_return_content(): + async with fake_codex_api() as (fake_codex, url): + client = AsyncCodexClientImpl(url) + codex_agent = CodexAgent(client, read_timeout=0.5) + + fast_cid = await codex_agent.create_dataset( + size=1000, name="dataset-fast-1", seed=1356 + ) + slow_cid = await codex_agent.create_dataset( + size=1000, name="dataset-slow-1", seed=1353 + ) + + fast_download = fake_codex.create_download_stream(fast_cid) + slow_download = fake_codex.create_download_stream(slow_cid) + + fast_download.feed_data(b"0" * 1000) + fast_download.feed_eof() + fast_handle = await codex_agent.download(fast_cid) + await fast_handle.download_task + + slow_handle = await codex_agent.download(slow_cid) + slow_download.feed_data(b"0" * 500) + await asyncio.sleep(0.6) + slow_download.feed_data(b"0" * 500) + slow_download.feed_eof() + + with pytest.raises(asyncio.TimeoutError): + await slow_handle.download_task diff --git a/benchmarks/codex/client/async_client.py b/benchmarks/codex/client/async_client.py index a725430..c3379af 100644 --- a/benchmarks/codex/client/async_client.py +++ b/benchmarks/codex/client/async_client.py @@ -3,9 +3,10 @@ from abc import ABC, abstractmethod from contextlib import asynccontextmanager -from typing import IO, AsyncIterator, AsyncGenerator +from typing import IO, AsyncIterator, AsyncGenerator, Optional import aiohttp +from aiohttp import ClientTimeout from urllib3.util import Url from benchmarks.codex.client.common import Manifest, Cid @@ -14,7 +15,13 @@ from benchmarks.core.utils.streams import BaseStreamReader class AsyncCodexClient(ABC): @abstractmethod - async def upload(self, name: str, mime_type: str, content: IO) -> Cid: + async def upload( + self, + name: str, + mime_type: str, + content: IO, + timeout: Optional[ClientTimeout] = None, + ) -> Cid: pass @abstractmethod @@ -23,7 +30,9 @@ class AsyncCodexClient(ABC): @asynccontextmanager @abstractmethod - def download(self, cid: Cid) -> AsyncGenerator[BaseStreamReader, None]: + def download( + self, cid: Cid, timeout: Optional[ClientTimeout] = None + ) -> AsyncGenerator[BaseStreamReader, None]: pass @@ -33,8 +42,14 @@ class AsyncCodexClientImpl(AsyncCodexClient): def __init__(self, codex_api_url: Url): self.codex_api_url = codex_api_url - async def upload(self, name: str, mime_type: str, content: IO) -> Cid: - async with aiohttp.ClientSession() as session: + async def upload( + self, + name: str, + mime_type: str, + content: IO, + timeout: Optional[ClientTimeout] = None, + ) -> Cid: + async with aiohttp.ClientSession(timeout=ClientTimeout()) as session: response = await session.post( self.codex_api_url._replace(path="/api/codex/v1/data").url, headers={ @@ -42,6 +57,7 @@ class AsyncCodexClientImpl(AsyncCodexClient): aiohttp.hdrs.CONTENT_DISPOSITION: f'attachment; filename="{name}"', }, data=content, + timeout=timeout, ) response.raise_for_status() @@ -62,10 +78,13 @@ class AsyncCodexClientImpl(AsyncCodexClient): return Manifest.from_codex_api_response(response_contents) @asynccontextmanager - async def download(self, cid: Cid) -> AsyncIterator[BaseStreamReader]: - async with aiohttp.ClientSession() as session: + async def download( + self, cid: Cid, timeout: Optional[ClientTimeout] = None + ) -> AsyncIterator[BaseStreamReader]: + async with aiohttp.ClientSession(timeout=ClientTimeout()) as session: response = await session.get( self.codex_api_url._replace(path=f"/api/codex/v1/data/{cid}").url, + timeout=timeout, ) response.raise_for_status()