mirror of
https://github.com/logos-storage/bittorrent-benchmarks.git
synced 2026-01-02 13:03:13 +00:00
fix: implement correct timeout behavior for Codex streaming downloads
This commit is contained in:
parent
a0e4181123
commit
e47f8848e4
@ -5,6 +5,7 @@ from pathlib import Path
|
||||
from tempfile import TemporaryDirectory
|
||||
from typing import Optional, Dict
|
||||
|
||||
from aiohttp import ClientTimeout
|
||||
from pydantic import BaseModel
|
||||
|
||||
from benchmarks.codex.client.async_client import AsyncCodexClient
|
||||
@ -32,11 +33,13 @@ class DownloadHandle:
|
||||
parent: "CodexAgent",
|
||||
manifest: Manifest,
|
||||
read_increment: float = 0.01,
|
||||
read_timeout: Optional[float] = None,
|
||||
):
|
||||
self.parent = parent
|
||||
self.manifest = manifest
|
||||
self.bytes_downloaded = 0
|
||||
self.read_increment = read_increment
|
||||
self.read_timeout = read_timeout
|
||||
self.download_task: Optional[Task[None]] = None
|
||||
|
||||
def begin_download(self) -> Task:
|
||||
@ -46,7 +49,14 @@ class DownloadHandle:
|
||||
async def _download_loop(self):
|
||||
step_size = int(self.manifest.datasetSize * self.read_increment)
|
||||
|
||||
async with self.parent.client.download(self.manifest.cid) as download_stream:
|
||||
async with self.parent.client.download(
|
||||
self.manifest.cid,
|
||||
timeout=ClientTimeout(
|
||||
total=None,
|
||||
sock_connect=30,
|
||||
sock_read=self.read_timeout,
|
||||
),
|
||||
) as download_stream:
|
||||
logged_step = 0
|
||||
while not download_stream.at_eof():
|
||||
step = min(step_size, self.manifest.datasetSize - self.bytes_downloaded)
|
||||
@ -94,10 +104,16 @@ class DownloadHandle:
|
||||
|
||||
|
||||
class CodexAgent:
|
||||
def __init__(self, client: AsyncCodexClient, node_id: str = "unknown") -> None:
|
||||
def __init__(
|
||||
self,
|
||||
client: AsyncCodexClient,
|
||||
node_id: str = "unknown",
|
||||
read_timeout: Optional[float] = None,
|
||||
) -> None:
|
||||
self.client = client
|
||||
self.node_id = node_id
|
||||
self.ongoing_downloads: Dict[Cid, DownloadHandle] = {}
|
||||
self.read_timeout = read_timeout
|
||||
|
||||
async def create_dataset(self, name: str, size: int, seed: Optional[int]) -> Cid:
|
||||
with TemporaryDirectory() as td:
|
||||
@ -119,6 +135,7 @@ class CodexAgent:
|
||||
self,
|
||||
manifest=await self.client.manifest(cid),
|
||||
read_increment=read_increment,
|
||||
read_timeout=self.read_timeout,
|
||||
)
|
||||
|
||||
handle.begin_download()
|
||||
|
||||
@ -1,3 +1,5 @@
|
||||
"""This module contains a REST API wrapping :class:`CodexAgent`."""
|
||||
|
||||
from typing import Annotated, Optional
|
||||
|
||||
from aiohttp import ClientResponseError
|
||||
|
||||
@ -1,3 +1,5 @@
|
||||
"""A simple client for interacting with the Codex Agent API."""
|
||||
|
||||
import socket
|
||||
|
||||
import requests
|
||||
|
||||
113
benchmarks/codex/agent/tests/fake_codex.py
Normal file
113
benchmarks/codex/agent/tests/fake_codex.py
Normal file
@ -0,0 +1,113 @@
|
||||
import json
|
||||
import re
|
||||
from asyncio import StreamReader
|
||||
from contextlib import asynccontextmanager
|
||||
from io import BytesIO
|
||||
from typing import Dict, Optional, AsyncIterator, Tuple, IO
|
||||
|
||||
from aiohttp import web, ClientTimeout
|
||||
from urllib3.util import Url
|
||||
|
||||
from benchmarks.codex.client.async_client import AsyncCodexClient, Cid
|
||||
from benchmarks.codex.client.common import Manifest
|
||||
from benchmarks.core.utils.streams import BaseStreamReader
|
||||
|
||||
|
||||
class FakeCodex(AsyncCodexClient):
    """In-memory stand-in for a Codex node, implementing :class:`AsyncCodexClient`.

    Uploads are not stored: only a :class:`Manifest` describing them is kept.
    Downloads are served from :class:`asyncio.StreamReader` objects that the
    test registers up front with :meth:`create_download_stream` and then feeds
    by hand, which gives the test full control over pacing and EOF.
    """

    def __init__(self) -> None:
        # Manifests of everything "uploaded" so far, keyed by CID.
        self.storage: Dict[Cid, Manifest] = {}
        # Hand-fed download streams, keyed by the CID they serve.
        self.streams: Dict[Cid, StreamReader] = {}

    async def upload(
        self,
        name: str,
        mime_type: str,
        content: IO,
        timeout: Optional[ClientTimeout] = None,
    ) -> Cid:
        """Record a manifest for ``content`` and return its fake CID.

        :param name: stored as the manifest's ``filename``.
        :param mime_type: stored as the manifest's ``mimetype``.
        :param content: readable object; it is consumed with a single ``read()``.
        :param timeout: accepted for interface compatibility and ignored.
        :return: ``"Qm"`` followed by the builtin ``hash()`` of the payload.
        """
        payload = content.read()
        # NOTE: builtin hash() of str/bytes is salted per interpreter process
        # (unless PYTHONHASHSEED is pinned), so equal payloads share a CID
        # within one run but CIDs are not reproducible across runs.
        cid = f"Qm{hash(payload)}"
        manifest = Manifest(
            cid=cid,
            datasetSize=len(payload),
            mimetype=mime_type,
            blockSize=1,
            filename=name,
            treeCid="",
            protected=False,
        )
        self.storage[cid] = manifest
        return cid

    async def manifest(self, cid: Cid) -> Manifest:
        """Return the manifest recorded for ``cid``; ``KeyError`` if unknown."""
        return self.storage[cid]

    def create_download_stream(self, cid: Cid) -> StreamReader:
        """Register (or replace) the stream that :meth:`download` yields for ``cid``.

        The caller keeps the returned reader and drives it with
        ``feed_data`` / ``feed_eof``.
        """
        stream = self.streams[cid] = StreamReader()
        return stream

    @asynccontextmanager
    async def download(
        self, cid: Cid, timeout: Optional[ClientTimeout] = None
    ) -> AsyncIterator[BaseStreamReader]:
        """Yield the stream previously registered for ``cid``.

        ``timeout`` is accepted for interface compatibility and ignored.
        Raises ``KeyError`` if no stream was registered for ``cid``.
        """
        yield self.streams[cid]
|
||||
|
||||
|
||||
@asynccontextmanager
async def fake_codex_api(
    host: str = "localhost", port: int = 8888
) -> AsyncIterator[Tuple[FakeCodex, Url]]:
    """Serve a :class:`FakeCodex` over HTTP for the duration of the context.

    Starts an aiohttp server exposing the subset of the Codex REST API backed
    by a fresh :class:`FakeCodex`: manifest lookup, upload and streaming
    download. Yields the fake (so the test can register and feed download
    streams) together with the base URL a real client should be pointed at.

    :param host: interface to bind; defaults to the previously hard-coded
        ``"localhost"``.
    :param port: TCP port to bind; defaults to the previously hard-coded
        ``8888``. Must be a concrete port, since it is echoed verbatim in the
        yielded URL.
    :raises OSError: if the address cannot be bound. The runner is cleaned up
        before the error propagates.
    """
    codex = FakeCodex()
    routes = web.RouteTableDef()

    @routes.get("/api/codex/v1/data/{cid}/network/manifest")
    async def manifest(request):
        cid = request.match_info["cid"]
        # Explicit 404 rather than ``assert``: asserts vanish under ``-O``.
        if cid not in codex.storage:
            raise web.HTTPNotFound(text=f"unknown cid: {cid}")
        # Gets the manifest in a similar shape as the Codex response: the CID
        # is lifted out of the manifest body and reported next to it.
        body = json.loads(codex.storage[cid].model_dump_json())
        manifest_cid = body.pop("cid")
        return web.json_response(data={"cid": manifest_cid, "manifest": body})

    @routes.post("/api/codex/v1/data")
    async def upload(request):
        # NOTE(review): kept from the original; for non-form content types
        # aiohttp's ``post()`` should not consume the body, so the ``read()``
        # below still sees the payload -- confirm before removing.
        await request.post()
        filename = re.findall(
            r'filename="(.+)"', request.headers["Content-Disposition"]
        )[0]
        cid = await codex.upload(
            name=filename,
            mime_type=request.headers["Content-Type"],
            content=BytesIO(await request.read()),
        )
        return web.Response(text=cid)

    @routes.get("/api/codex/v1/data/{cid}")
    async def download(request):
        cid = request.match_info["cid"]
        # A download is only possible once the test has registered a stream.
        if cid not in codex.streams:
            raise web.HTTPNotFound(text=f"no download stream for cid: {cid}")
        reader = codex.streams[cid]

        # We basically copy the stream onto the response, so the HTTP client
        # observes the same pacing the test imposes on the reader.
        response = web.StreamResponse()
        await response.prepare(request)
        while not reader.at_eof():
            await response.write(await reader.read(1024))

        await response.write_eof()
        return response

    app = web.Application()
    app.add_routes(routes)

    runner = web.AppRunner(app)
    await runner.setup()
    try:
        # Binding happens inside the ``try`` so that a failure to start the
        # site (e.g. port already in use) still releases the runner.
        site = web.TCPSite(runner, host, port)
        await site.start()
        yield codex, Url(scheme="http", host=host, port=port)
    finally:
        await runner.cleanup()
|
||||
@ -5,12 +5,12 @@ from starlette.testclient import TestClient
|
||||
|
||||
from benchmarks.codex.agent import api
|
||||
from benchmarks.codex.agent.agent import CodexAgent
|
||||
from benchmarks.codex.agent.tests.test_codex_agent import FakeCodexClient
|
||||
from benchmarks.codex.agent.tests.fake_codex import FakeCodex
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_should_create_file():
|
||||
codex_client = FakeCodexClient()
|
||||
codex_client = FakeCodex()
|
||||
codex_agent = CodexAgent(codex_client)
|
||||
|
||||
app = FastAPI()
|
||||
@ -34,7 +34,7 @@ async def test_should_create_file():
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_should_report_when_download_is_complete():
|
||||
codex_client = FakeCodexClient()
|
||||
codex_client = FakeCodex()
|
||||
codex_agent = CodexAgent(codex_client)
|
||||
|
||||
app = FastAPI()
|
||||
|
||||
@ -1,54 +1,19 @@
|
||||
from asyncio import StreamReader
|
||||
from contextlib import asynccontextmanager
|
||||
import asyncio
|
||||
from io import StringIO
|
||||
from typing import IO, Dict, AsyncIterator
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from benchmarks.codex.agent.agent import CodexAgent, DownloadStatus
|
||||
from benchmarks.codex.client.async_client import AsyncCodexClient
|
||||
from benchmarks.codex.client.common import Manifest, Cid
|
||||
from benchmarks.codex.agent.tests.fake_codex import FakeCodex, fake_codex_api
|
||||
from benchmarks.codex.client.async_client import AsyncCodexClientImpl
|
||||
from benchmarks.core.concurrency import await_predicate_async
|
||||
from benchmarks.core.utils.streams import BaseStreamReader
|
||||
from benchmarks.logging.logging import LogParser, DownloadMetric
|
||||
|
||||
|
||||
class FakeCodexClient(AsyncCodexClient):
    """In-memory fake of :class:`AsyncCodexClient` for agent unit tests.

    Keeps a manifest per uploaded payload and serves downloads from
    :class:`asyncio.StreamReader` objects that the test registers with
    :meth:`create_download_stream` and feeds by hand.
    """

    def __init__(self) -> None:
        # Manifests of everything "uploaded" so far, keyed by CID.
        self.storage: Dict[Cid, Manifest] = {}
        # Hand-fed download streams, keyed by the CID they serve.
        self.streams: Dict[Cid, StreamReader] = {}

    # ``timeout`` mirrors the abstract client's signature so that callers
    # passing ``timeout=...`` do not fail with ``TypeError``. It is left
    # unannotated because the import block visible for this module does not
    # bring ``Optional`` / ``ClientTimeout`` into scope.
    async def upload(self, name: str, mime_type: str, content: IO, timeout=None) -> Cid:
        """Record a manifest for ``content`` and return its fake CID.

        :param name: stored as the manifest's ``filename``.
        :param mime_type: stored as the manifest's ``mimetype``.
        :param content: readable object; consumed with a single ``read()``.
        :param timeout: accepted for interface compatibility and ignored.
        :return: ``"Qm"`` followed by the builtin ``hash()`` of the payload.
        """
        data = content.read()
        cid = "Qm" + str(hash(data))
        self.storage[cid] = Manifest(
            cid=cid,
            datasetSize=len(data),
            mimetype=mime_type,
            blockSize=1,
            filename=name,
            treeCid="",
            protected=False,
        )
        return cid

    async def manifest(self, cid: Cid) -> Manifest:
        """Return the manifest recorded for ``cid``; ``KeyError`` if unknown."""
        return self.storage[cid]

    def create_download_stream(self, cid: Cid) -> StreamReader:
        """Register (or replace) the stream that :meth:`download` yields for ``cid``."""
        reader = StreamReader()
        self.streams[cid] = reader
        return reader

    @asynccontextmanager
    async def download(self, cid: Cid, timeout=None) -> AsyncIterator[BaseStreamReader]:
        """Yield the stream registered for ``cid``.

        ``timeout`` is accepted for interface compatibility and ignored.
        Raises ``KeyError`` if no stream was registered for ``cid``.
        """
        yield self.streams[cid]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_should_create_dataset_of_right_size():
|
||||
codex_agent = CodexAgent(FakeCodexClient())
|
||||
codex_agent = CodexAgent(FakeCodex())
|
||||
cid = await codex_agent.create_dataset(size=1024, name="dataset-1", seed=1234)
|
||||
manifest = await codex_agent.client.manifest(cid)
|
||||
|
||||
@ -57,7 +22,7 @@ async def test_should_create_dataset_of_right_size():
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_same_seed_creates_same_cid():
|
||||
codex_agent = CodexAgent(FakeCodexClient())
|
||||
codex_agent = CodexAgent(FakeCodex())
|
||||
|
||||
cid1 = await codex_agent.create_dataset(size=2048, name="dataset-1", seed=1234)
|
||||
cid2 = await codex_agent.create_dataset(size=2048, name="dataset-1", seed=1234)
|
||||
@ -69,7 +34,7 @@ async def test_same_seed_creates_same_cid():
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_should_report_download_progress():
|
||||
client = FakeCodexClient()
|
||||
client = FakeCodex()
|
||||
codex_agent = CodexAgent(client)
|
||||
|
||||
cid = await codex_agent.create_dataset(size=1000, name="dataset-1", seed=1234)
|
||||
@ -95,7 +60,7 @@ async def test_should_report_download_progress():
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_should_raise_exception_on_progress_query_if_download_fails():
|
||||
client = FakeCodexClient()
|
||||
client = FakeCodex()
|
||||
codex_agent = CodexAgent(client)
|
||||
|
||||
cid = await codex_agent.create_dataset(size=1000, name="dataset-1", seed=1234)
|
||||
@ -114,7 +79,7 @@ async def test_should_log_download_progress_as_metric_in_discrete_steps(mock_log
|
||||
logger, output = mock_logger
|
||||
|
||||
with patch("benchmarks.codex.agent.agent.logger", logger):
|
||||
client = FakeCodexClient()
|
||||
client = FakeCodex()
|
||||
codex_agent = CodexAgent(client)
|
||||
|
||||
cid = await codex_agent.create_dataset(size=1000, name="dataset-1", seed=1234)
|
||||
@ -177,7 +142,7 @@ async def test_should_log_download_progress_as_discrete_steps_even_when_underlyi
|
||||
logger, output = mock_logger
|
||||
|
||||
with patch("benchmarks.codex.agent.agent.logger", logger):
|
||||
client = FakeCodexClient()
|
||||
client = FakeCodex()
|
||||
codex_agent = CodexAgent(client)
|
||||
cid = await codex_agent.create_dataset(size=1000, name="dataset-1", seed=1234)
|
||||
download_stream = client.create_download_stream(cid)
|
||||
@ -244,7 +209,7 @@ async def test_should_log_download_progress_as_discrete_steps_even_when_underlyi
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_should_track_download_handles():
|
||||
client = FakeCodexClient()
|
||||
client = FakeCodex()
|
||||
codex_agent = CodexAgent(client)
|
||||
|
||||
cid = await codex_agent.create_dataset(size=1000, name="dataset-1", seed=1356)
|
||||
@ -262,3 +227,34 @@ async def test_should_track_download_handles():
|
||||
await handle.download_task
|
||||
|
||||
assert cid in codex_agent.ongoing_downloads
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_should_timeout_if_download_stream_takes_too_long_to_return_content():
    """A download whose stream stalls longer than ``read_timeout`` must fail.

    Runs the real HTTP client against the fake Codex API so that the timeout
    is enforced on an actual socket. A stream that delivers everything at once
    completes normally; a stream that pauses for longer than the agent's read
    timeout between chunks makes the download task raise
    ``asyncio.TimeoutError``.
    """
    async with fake_codex_api() as (fake_codex, url):
        client = AsyncCodexClientImpl(url)
        # 0.5 s read timeout; the slow stream below pauses for 0.6 s.
        codex_agent = CodexAgent(client, read_timeout=0.5)

        fast_cid = await codex_agent.create_dataset(
            size=1000, name="dataset-fast-1", seed=1356
        )
        slow_cid = await codex_agent.create_dataset(
            size=1000, name="dataset-slow-1", seed=1353
        )

        # Streams must be registered before the download requests reach the
        # fake server, which looks them up by CID.
        fast_download = fake_codex.create_download_stream(fast_cid)
        slow_download = fake_codex.create_download_stream(slow_cid)

        # Control case: all data and EOF are available immediately, so the
        # download finishes without tripping the timeout.
        fast_download.feed_data(b"0" * 1000)
        fast_download.feed_eof()
        fast_handle = await codex_agent.download(fast_cid)
        await fast_handle.download_task

        # Stalling case: half the data, a pause longer than the read timeout,
        # then the rest. The second half arrives too late to matter.
        slow_handle = await codex_agent.download(slow_cid)
        slow_download.feed_data(b"0" * 500)
        await asyncio.sleep(0.6)
        slow_download.feed_data(b"0" * 500)
        slow_download.feed_eof()

        with pytest.raises(asyncio.TimeoutError):
            await slow_handle.download_task
|
||||
|
||||
@ -3,9 +3,10 @@
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
from contextlib import asynccontextmanager
|
||||
from typing import IO, AsyncIterator, AsyncGenerator
|
||||
from typing import IO, AsyncIterator, AsyncGenerator, Optional
|
||||
|
||||
import aiohttp
|
||||
from aiohttp import ClientTimeout
|
||||
from urllib3.util import Url
|
||||
|
||||
from benchmarks.codex.client.common import Manifest, Cid
|
||||
@ -14,7 +15,13 @@ from benchmarks.core.utils.streams import BaseStreamReader
|
||||
|
||||
class AsyncCodexClient(ABC):
|
||||
@abstractmethod
|
||||
async def upload(self, name: str, mime_type: str, content: IO) -> Cid:
|
||||
async def upload(
|
||||
self,
|
||||
name: str,
|
||||
mime_type: str,
|
||||
content: IO,
|
||||
timeout: Optional[ClientTimeout] = None,
|
||||
) -> Cid:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
@ -23,7 +30,9 @@ class AsyncCodexClient(ABC):
|
||||
|
||||
@asynccontextmanager
|
||||
@abstractmethod
|
||||
def download(self, cid: Cid) -> AsyncGenerator[BaseStreamReader, None]:
|
||||
def download(
|
||||
self, cid: Cid, timeout: Optional[ClientTimeout] = None
|
||||
) -> AsyncGenerator[BaseStreamReader, None]:
|
||||
pass
|
||||
|
||||
|
||||
@ -33,8 +42,14 @@ class AsyncCodexClientImpl(AsyncCodexClient):
|
||||
def __init__(self, codex_api_url: Url):
|
||||
self.codex_api_url = codex_api_url
|
||||
|
||||
async def upload(self, name: str, mime_type: str, content: IO) -> Cid:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async def upload(
|
||||
self,
|
||||
name: str,
|
||||
mime_type: str,
|
||||
content: IO,
|
||||
timeout: Optional[ClientTimeout] = None,
|
||||
) -> Cid:
|
||||
async with aiohttp.ClientSession(timeout=ClientTimeout()) as session:
|
||||
response = await session.post(
|
||||
self.codex_api_url._replace(path="/api/codex/v1/data").url,
|
||||
headers={
|
||||
@ -42,6 +57,7 @@ class AsyncCodexClientImpl(AsyncCodexClient):
|
||||
aiohttp.hdrs.CONTENT_DISPOSITION: f'attachment; filename="{name}"',
|
||||
},
|
||||
data=content,
|
||||
timeout=timeout,
|
||||
)
|
||||
|
||||
response.raise_for_status()
|
||||
@ -62,10 +78,13 @@ class AsyncCodexClientImpl(AsyncCodexClient):
|
||||
return Manifest.from_codex_api_response(response_contents)
|
||||
|
||||
@asynccontextmanager
|
||||
async def download(self, cid: Cid) -> AsyncIterator[BaseStreamReader]:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async def download(
|
||||
self, cid: Cid, timeout: Optional[ClientTimeout] = None
|
||||
) -> AsyncIterator[BaseStreamReader]:
|
||||
async with aiohttp.ClientSession(timeout=ClientTimeout()) as session:
|
||||
response = await session.get(
|
||||
self.codex_api_url._replace(path=f"/api/codex/v1/data/{cid}").url,
|
||||
timeout=timeout,
|
||||
)
|
||||
|
||||
response.raise_for_status()
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user