fix: implement correct timeout behavior for Codex streaming downloads

This commit is contained in:
gmega 2025-02-18 15:41:29 -03:00
parent a0e4181123
commit e47f8848e4
No known key found for this signature in database
GPG Key ID: 6290D34EAD824B18
7 changed files with 206 additions and 57 deletions

View File

@ -5,6 +5,7 @@ from pathlib import Path
from tempfile import TemporaryDirectory
from typing import Optional, Dict
from aiohttp import ClientTimeout
from pydantic import BaseModel
from benchmarks.codex.client.async_client import AsyncCodexClient
@ -32,11 +33,13 @@ class DownloadHandle:
parent: "CodexAgent",
manifest: Manifest,
read_increment: float = 0.01,
read_timeout: Optional[float] = None,
):
self.parent = parent
self.manifest = manifest
self.bytes_downloaded = 0
self.read_increment = read_increment
self.read_timeout = read_timeout
self.download_task: Optional[Task[None]] = None
def begin_download(self) -> Task:
@ -46,7 +49,14 @@ class DownloadHandle:
async def _download_loop(self):
step_size = int(self.manifest.datasetSize * self.read_increment)
async with self.parent.client.download(self.manifest.cid) as download_stream:
async with self.parent.client.download(
self.manifest.cid,
timeout=ClientTimeout(
total=None,
sock_connect=30,
sock_read=self.read_timeout,
),
) as download_stream:
logged_step = 0
while not download_stream.at_eof():
step = min(step_size, self.manifest.datasetSize - self.bytes_downloaded)
@ -94,10 +104,16 @@ class DownloadHandle:
class CodexAgent:
def __init__(self, client: AsyncCodexClient, node_id: str = "unknown") -> None:
def __init__(
self,
client: AsyncCodexClient,
node_id: str = "unknown",
read_timeout: Optional[float] = None,
) -> None:
self.client = client
self.node_id = node_id
self.ongoing_downloads: Dict[Cid, DownloadHandle] = {}
self.read_timeout = read_timeout
async def create_dataset(self, name: str, size: int, seed: Optional[int]) -> Cid:
with TemporaryDirectory() as td:
@ -119,6 +135,7 @@ class CodexAgent:
self,
manifest=await self.client.manifest(cid),
read_increment=read_increment,
read_timeout=self.read_timeout,
)
handle.begin_download()

View File

@ -1,3 +1,5 @@
"""This module contains a REST API wrapping :class:`CodexAgent`."""
from typing import Annotated, Optional
from aiohttp import ClientResponseError

View File

@ -1,3 +1,5 @@
"""A simple client for interacting with the Codex Agent API."""
import socket
import requests

View File

@ -0,0 +1,113 @@
import json
import re
from asyncio import StreamReader
from contextlib import asynccontextmanager
from io import BytesIO
from typing import Dict, Optional, AsyncIterator, Tuple, IO
from aiohttp import web, ClientTimeout
from urllib3.util import Url
from benchmarks.codex.client.async_client import AsyncCodexClient, Cid
from benchmarks.codex.client.common import Manifest
from benchmarks.core.utils.streams import BaseStreamReader
class FakeCodex(AsyncCodexClient):
def __init__(self) -> None:
self.storage: Dict[Cid, Manifest] = {}
self.streams: Dict[Cid, StreamReader] = {}
async def upload(
self,
name: str,
mime_type: str,
content: IO,
timeout: Optional[ClientTimeout] = None,
) -> Cid:
data = content.read()
cid = "Qm" + str(hash(data))
self.storage[cid] = Manifest(
cid=cid,
datasetSize=len(data),
mimetype=mime_type,
blockSize=1,
filename=name,
treeCid="",
protected=False,
)
return cid
async def manifest(self, cid: Cid) -> Manifest:
return self.storage[cid]
def create_download_stream(self, cid: Cid) -> StreamReader:
reader = StreamReader()
self.streams[cid] = reader
return reader
@asynccontextmanager
async def download(
self, cid: Cid, timeout: Optional[ClientTimeout] = None
) -> AsyncIterator[BaseStreamReader]:
yield self.streams[cid]
@asynccontextmanager
async def fake_codex_api() -> AsyncIterator[Tuple[FakeCodex, Url]]:
codex = FakeCodex()
routes = web.RouteTableDef()
@routes.get("/api/codex/v1/data/{cid}/network/manifest")
async def manifest(request):
cid = request.match_info["cid"]
assert cid in codex.storage
# Gets the manifest in a similar shape as the Codex response.
manifest = json.loads(codex.storage[cid].model_dump_json())
return web.json_response(
data={
"cid": manifest.pop("cid"),
"manifest": manifest,
}
)
@routes.post("/api/codex/v1/data")
async def upload(request):
await request.post()
filename = re.findall(
r'filename="(.+)"', request.headers["Content-Disposition"]
)[0]
cid = await codex.upload(
name=filename,
mime_type=request.headers["Content-Type"],
content=BytesIO(await request.read()),
)
return web.Response(text=cid)
@routes.get("/api/codex/v1/data/{cid}")
async def download(request):
cid = request.match_info["cid"]
assert cid in codex.streams
reader = codex.streams[cid]
# We basically copy the stream onto the response.
response = web.StreamResponse()
await response.prepare(request)
while not reader.at_eof():
await response.write(await reader.read(1024))
await response.write_eof()
return response
app = web.Application()
app.add_routes(routes)
runner = web.AppRunner(app)
await runner.setup()
site = web.TCPSite(runner, "localhost", 8888)
await site.start()
try:
yield codex, Url(scheme="http", host="localhost", port=8888)
finally:
await runner.cleanup()

View File

@ -5,12 +5,12 @@ from starlette.testclient import TestClient
from benchmarks.codex.agent import api
from benchmarks.codex.agent.agent import CodexAgent
from benchmarks.codex.agent.tests.test_codex_agent import FakeCodexClient
from benchmarks.codex.agent.tests.fake_codex import FakeCodex
@pytest.mark.asyncio
async def test_should_create_file():
codex_client = FakeCodexClient()
codex_client = FakeCodex()
codex_agent = CodexAgent(codex_client)
app = FastAPI()
@ -34,7 +34,7 @@ async def test_should_create_file():
@pytest.mark.asyncio
async def test_should_report_when_download_is_complete():
codex_client = FakeCodexClient()
codex_client = FakeCodex()
codex_agent = CodexAgent(codex_client)
app = FastAPI()

View File

@ -1,54 +1,19 @@
from asyncio import StreamReader
from contextlib import asynccontextmanager
import asyncio
from io import StringIO
from typing import IO, Dict, AsyncIterator
from unittest.mock import patch
import pytest
from benchmarks.codex.agent.agent import CodexAgent, DownloadStatus
from benchmarks.codex.client.async_client import AsyncCodexClient
from benchmarks.codex.client.common import Manifest, Cid
from benchmarks.codex.agent.tests.fake_codex import FakeCodex, fake_codex_api
from benchmarks.codex.client.async_client import AsyncCodexClientImpl
from benchmarks.core.concurrency import await_predicate_async
from benchmarks.core.utils.streams import BaseStreamReader
from benchmarks.logging.logging import LogParser, DownloadMetric
class FakeCodexClient(AsyncCodexClient):
def __init__(self) -> None:
self.storage: Dict[Cid, Manifest] = {}
self.streams: Dict[Cid, StreamReader] = {}
async def upload(self, name: str, mime_type: str, content: IO) -> Cid:
data = content.read()
cid = "Qm" + str(hash(data))
self.storage[cid] = Manifest(
cid=cid,
datasetSize=len(data),
mimetype=mime_type,
blockSize=1,
filename=name,
treeCid="",
protected=False,
)
return cid
async def manifest(self, cid: Cid) -> Manifest:
return self.storage[cid]
def create_download_stream(self, cid: Cid) -> StreamReader:
reader = StreamReader()
self.streams[cid] = reader
return reader
@asynccontextmanager
async def download(self, cid: Cid) -> AsyncIterator[BaseStreamReader]:
yield self.streams[cid]
@pytest.mark.asyncio
async def test_should_create_dataset_of_right_size():
codex_agent = CodexAgent(FakeCodexClient())
codex_agent = CodexAgent(FakeCodex())
cid = await codex_agent.create_dataset(size=1024, name="dataset-1", seed=1234)
manifest = await codex_agent.client.manifest(cid)
@ -57,7 +22,7 @@ async def test_should_create_dataset_of_right_size():
@pytest.mark.asyncio
async def test_same_seed_creates_same_cid():
codex_agent = CodexAgent(FakeCodexClient())
codex_agent = CodexAgent(FakeCodex())
cid1 = await codex_agent.create_dataset(size=2048, name="dataset-1", seed=1234)
cid2 = await codex_agent.create_dataset(size=2048, name="dataset-1", seed=1234)
@ -69,7 +34,7 @@ async def test_same_seed_creates_same_cid():
@pytest.mark.asyncio
async def test_should_report_download_progress():
client = FakeCodexClient()
client = FakeCodex()
codex_agent = CodexAgent(client)
cid = await codex_agent.create_dataset(size=1000, name="dataset-1", seed=1234)
@ -95,7 +60,7 @@ async def test_should_report_download_progress():
@pytest.mark.asyncio
async def test_should_raise_exception_on_progress_query_if_download_fails():
client = FakeCodexClient()
client = FakeCodex()
codex_agent = CodexAgent(client)
cid = await codex_agent.create_dataset(size=1000, name="dataset-1", seed=1234)
@ -114,7 +79,7 @@ async def test_should_log_download_progress_as_metric_in_discrete_steps(mock_log
logger, output = mock_logger
with patch("benchmarks.codex.agent.agent.logger", logger):
client = FakeCodexClient()
client = FakeCodex()
codex_agent = CodexAgent(client)
cid = await codex_agent.create_dataset(size=1000, name="dataset-1", seed=1234)
@ -177,7 +142,7 @@ async def test_should_log_download_progress_as_discrete_steps_even_when_underlyi
logger, output = mock_logger
with patch("benchmarks.codex.agent.agent.logger", logger):
client = FakeCodexClient()
client = FakeCodex()
codex_agent = CodexAgent(client)
cid = await codex_agent.create_dataset(size=1000, name="dataset-1", seed=1234)
download_stream = client.create_download_stream(cid)
@ -244,7 +209,7 @@ async def test_should_log_download_progress_as_discrete_steps_even_when_underlyi
@pytest.mark.asyncio
async def test_should_track_download_handles():
client = FakeCodexClient()
client = FakeCodex()
codex_agent = CodexAgent(client)
cid = await codex_agent.create_dataset(size=1000, name="dataset-1", seed=1356)
@ -262,3 +227,34 @@ async def test_should_track_download_handles():
await handle.download_task
assert cid in codex_agent.ongoing_downloads
@pytest.mark.asyncio
async def test_should_timeout_if_download_stream_takes_too_long_to_return_content():
async with fake_codex_api() as (fake_codex, url):
client = AsyncCodexClientImpl(url)
codex_agent = CodexAgent(client, read_timeout=0.5)
fast_cid = await codex_agent.create_dataset(
size=1000, name="dataset-fast-1", seed=1356
)
slow_cid = await codex_agent.create_dataset(
size=1000, name="dataset-slow-1", seed=1353
)
fast_download = fake_codex.create_download_stream(fast_cid)
slow_download = fake_codex.create_download_stream(slow_cid)
fast_download.feed_data(b"0" * 1000)
fast_download.feed_eof()
fast_handle = await codex_agent.download(fast_cid)
await fast_handle.download_task
slow_handle = await codex_agent.download(slow_cid)
slow_download.feed_data(b"0" * 500)
await asyncio.sleep(0.6)
slow_download.feed_data(b"0" * 500)
slow_download.feed_eof()
with pytest.raises(asyncio.TimeoutError):
await slow_handle.download_task

View File

@ -3,9 +3,10 @@
from abc import ABC, abstractmethod
from contextlib import asynccontextmanager
from typing import IO, AsyncIterator, AsyncGenerator
from typing import IO, AsyncIterator, AsyncGenerator, Optional
import aiohttp
from aiohttp import ClientTimeout
from urllib3.util import Url
from benchmarks.codex.client.common import Manifest, Cid
@ -14,7 +15,13 @@ from benchmarks.core.utils.streams import BaseStreamReader
class AsyncCodexClient(ABC):
@abstractmethod
async def upload(self, name: str, mime_type: str, content: IO) -> Cid:
async def upload(
self,
name: str,
mime_type: str,
content: IO,
timeout: Optional[ClientTimeout] = None,
) -> Cid:
pass
@abstractmethod
@ -23,7 +30,9 @@ class AsyncCodexClient(ABC):
@asynccontextmanager
@abstractmethod
def download(self, cid: Cid) -> AsyncGenerator[BaseStreamReader, None]:
def download(
self, cid: Cid, timeout: Optional[ClientTimeout] = None
) -> AsyncGenerator[BaseStreamReader, None]:
pass
@ -33,8 +42,14 @@ class AsyncCodexClientImpl(AsyncCodexClient):
def __init__(self, codex_api_url: Url):
self.codex_api_url = codex_api_url
async def upload(self, name: str, mime_type: str, content: IO) -> Cid:
async with aiohttp.ClientSession() as session:
async def upload(
self,
name: str,
mime_type: str,
content: IO,
timeout: Optional[ClientTimeout] = None,
) -> Cid:
async with aiohttp.ClientSession(timeout=ClientTimeout()) as session:
response = await session.post(
self.codex_api_url._replace(path="/api/codex/v1/data").url,
headers={
@ -42,6 +57,7 @@ class AsyncCodexClientImpl(AsyncCodexClient):
aiohttp.hdrs.CONTENT_DISPOSITION: f'attachment; filename="{name}"',
},
data=content,
timeout=timeout,
)
response.raise_for_status()
@ -62,10 +78,13 @@ class AsyncCodexClientImpl(AsyncCodexClient):
return Manifest.from_codex_api_response(response_contents)
@asynccontextmanager
async def download(self, cid: Cid) -> AsyncIterator[BaseStreamReader]:
async with aiohttp.ClientSession() as session:
async def download(
self, cid: Cid, timeout: Optional[ClientTimeout] = None
) -> AsyncIterator[BaseStreamReader]:
async with aiohttp.ClientSession(timeout=ClientTimeout()) as session:
response = await session.get(
self.codex_api_url._replace(path=f"/api/codex/v1/data/{cid}").url,
timeout=timeout,
)
response.raise_for_status()