mirror of
https://github.com/logos-storage/bittorrent-benchmarks.git
synced 2026-01-04 22:13:12 +00:00
feat: add Codex agent REST API
This commit is contained in:
parent
849bcad6c8
commit
820699f001
@ -1,11 +1,12 @@
|
||||
import asyncio
|
||||
import logging
|
||||
from asyncio import Task
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from tempfile import TemporaryDirectory
|
||||
from typing import Optional, Dict
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from benchmarks.codex.agent.codex_client import CodexClient, Manifest
|
||||
from benchmarks.codex.logging import CodexDownloadMetric
|
||||
from benchmarks.core.utils.random import random_data
|
||||
@ -18,8 +19,7 @@ EMPTY_STREAM_BACKOFF = 0.1
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class DownloadStatus:
|
||||
class DownloadStatus(BaseModel):
|
||||
downloaded: int
|
||||
total: int
|
||||
|
||||
@ -44,39 +44,36 @@ class DownloadHandle:
|
||||
return self.download_task
|
||||
|
||||
async def _download_loop(self):
|
||||
try:
|
||||
step_size = int(self.manifest.datasetSize * self.read_increment)
|
||||
step_size = int(self.manifest.datasetSize * self.read_increment)
|
||||
|
||||
while not self.download_stream.at_eof():
|
||||
step = min(step_size, self.manifest.datasetSize - self.bytes_downloaded)
|
||||
bytes_read = await self.download_stream.read(step)
|
||||
# We actually have no guarantees that an empty read means EOF, so we just back off
|
||||
# a bit.
|
||||
if not bytes_read:
|
||||
await asyncio.sleep(EMPTY_STREAM_BACKOFF)
|
||||
self.bytes_downloaded += len(bytes_read)
|
||||
while not self.download_stream.at_eof():
|
||||
step = min(step_size, self.manifest.datasetSize - self.bytes_downloaded)
|
||||
bytes_read = await self.download_stream.read(step)
|
||||
# We actually have no guarantees that an empty read means EOF, so we just back off
|
||||
# a bit.
|
||||
if not bytes_read:
|
||||
await asyncio.sleep(EMPTY_STREAM_BACKOFF)
|
||||
self.bytes_downloaded += len(bytes_read)
|
||||
|
||||
logger.info(
|
||||
CodexDownloadMetric(
|
||||
cid=self.manifest.cid,
|
||||
value=self.bytes_downloaded,
|
||||
node=self.parent.node_id,
|
||||
)
|
||||
logger.info(
|
||||
CodexDownloadMetric(
|
||||
cid=self.manifest.cid,
|
||||
value=self.bytes_downloaded,
|
||||
node=self.parent.node_id,
|
||||
)
|
||||
)
|
||||
|
||||
if self.bytes_downloaded < self.manifest.datasetSize:
|
||||
raise EOFError(
|
||||
f"Got EOF too early: download size ({self.bytes_downloaded}) was less "
|
||||
f"than expected ({self.manifest.datasetSize})."
|
||||
)
|
||||
if self.bytes_downloaded < self.manifest.datasetSize:
|
||||
raise EOFError(
|
||||
f"Got EOF too early: download size ({self.bytes_downloaded}) was less "
|
||||
f"than expected ({self.manifest.datasetSize})."
|
||||
)
|
||||
|
||||
if self.bytes_downloaded > self.manifest.datasetSize:
|
||||
raise ValueError(
|
||||
f"Download size ({self.bytes_downloaded}) was greater than expected "
|
||||
f"({self.manifest.datasetSize})."
|
||||
)
|
||||
finally:
|
||||
self.parent._download_done(self.manifest.cid)
|
||||
if self.bytes_downloaded > self.manifest.datasetSize:
|
||||
raise ValueError(
|
||||
f"Download size ({self.bytes_downloaded}) was greater than expected "
|
||||
f"({self.manifest.datasetSize})."
|
||||
)
|
||||
|
||||
def progress(self) -> DownloadStatus:
|
||||
if self.download_task is None:
|
||||
@ -124,6 +121,3 @@ class CodexAgent:
|
||||
|
||||
self.ongoing_downloads[cid] = handle
|
||||
return handle
|
||||
|
||||
def _download_done(self, cid: Cid):
|
||||
self.ongoing_downloads.pop(cid)
|
||||
|
||||
53
benchmarks/codex/agent/api.py
Normal file
53
benchmarks/codex/agent/api.py
Normal file
@ -0,0 +1,53 @@
|
||||
from typing import Annotated, Optional
|
||||
|
||||
from fastapi import APIRouter, Response, Depends, HTTPException, Request
|
||||
from fastapi.responses import JSONResponse
|
||||
|
||||
from benchmarks.codex.agent.agent import CodexAgent, DownloadStatus
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
def codex_agent() -> CodexAgent:
|
||||
raise Exception("Dependency must be set")
|
||||
|
||||
|
||||
@router.get("/api/v1/hello")
|
||||
async def hello():
|
||||
return {"message": "Server is up"}
|
||||
|
||||
|
||||
@router.post("/api/v1/codex/dataset")
|
||||
async def generate(
|
||||
agent: Annotated[CodexAgent, Depends(codex_agent)],
|
||||
name: str,
|
||||
size: int,
|
||||
seed: Optional[int],
|
||||
):
|
||||
return Response(
|
||||
await agent.create_dataset(name=name, size=size, seed=seed),
|
||||
media_type="text/plain; charset=UTF-8",
|
||||
)
|
||||
|
||||
|
||||
@router.post("/api/v1/codex/download")
|
||||
async def download(
|
||||
request: Request, agent: Annotated[CodexAgent, Depends(codex_agent)], cid: str
|
||||
):
|
||||
await agent.download(cid)
|
||||
return JSONResponse(
|
||||
status_code=202,
|
||||
content={"status": str(request.url_for("download_status", cid=cid))},
|
||||
)
|
||||
|
||||
|
||||
@router.get("/api/v1/codex/download/{cid}/status")
|
||||
async def download_status(
|
||||
agent: Annotated[CodexAgent, Depends(codex_agent)], cid: str
|
||||
) -> DownloadStatus:
|
||||
if cid not in agent.ongoing_downloads:
|
||||
raise HTTPException(
|
||||
status_code=404, detail=f"There are no ongoing downloads for CID {cid}"
|
||||
)
|
||||
|
||||
return agent.ongoing_downloads[cid].progress()
|
||||
77
benchmarks/codex/agent/tests/test_api.py
Normal file
77
benchmarks/codex/agent/tests/test_api.py
Normal file
@ -0,0 +1,77 @@
|
||||
import pytest
|
||||
from fastapi import FastAPI
|
||||
from httpx import AsyncClient, ASGITransport
|
||||
from starlette.testclient import TestClient
|
||||
|
||||
from benchmarks.codex.agent import api
|
||||
from benchmarks.codex.agent.agent import CodexAgent
|
||||
from benchmarks.codex.agent.tests.test_codex_agent import FakeCodexClient
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_should_create_file():
|
||||
codex_client = FakeCodexClient()
|
||||
codex_agent = CodexAgent(codex_client)
|
||||
|
||||
app = FastAPI()
|
||||
app.include_router(api.router)
|
||||
app.dependency_overrides[api.codex_agent] = lambda: codex_agent
|
||||
|
||||
client = TestClient(app)
|
||||
|
||||
response = client.post(
|
||||
"/api/v1/codex/dataset",
|
||||
params={"name": "dataset-1", "size": 1024, "seed": 12},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.charset_encoding == "utf-8"
|
||||
|
||||
manifest = await codex_client.get_manifest(response.text)
|
||||
|
||||
assert manifest.datasetSize == 1024
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_should_report_when_download_is_complete():
|
||||
codex_client = FakeCodexClient()
|
||||
codex_agent = CodexAgent(codex_client)
|
||||
|
||||
app = FastAPI()
|
||||
app.include_router(api.router)
|
||||
app.dependency_overrides[api.codex_agent] = lambda: codex_agent
|
||||
|
||||
async with AsyncClient(
|
||||
transport=ASGITransport(app=app), base_url="http://testserver"
|
||||
) as client:
|
||||
response = await client.post(
|
||||
"/api/v1/codex/dataset",
|
||||
params={"name": "dataset-1", "size": 1024, "seed": 12},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.charset_encoding == "utf-8"
|
||||
|
||||
cid = response.text
|
||||
|
||||
download_stream = codex_client.create_download_stream(cid)
|
||||
|
||||
response = await client.post(
|
||||
"/api/v1/codex/download",
|
||||
params={"cid": cid},
|
||||
)
|
||||
|
||||
assert response.status_code == 202
|
||||
assert response.json() == {
|
||||
"status": f"http://testserver/api/v1/codex/download/{cid}/status"
|
||||
}
|
||||
|
||||
download_stream.feed_data(b"0" * 1024)
|
||||
download_stream.feed_eof()
|
||||
|
||||
await codex_agent.ongoing_downloads[cid].download_task
|
||||
|
||||
response = await client.get(f"api/v1/codex/download/{cid}/status")
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.json() == {"downloaded": 1024, "total": 1024}
|
||||
@ -153,20 +153,22 @@ async def test_should_log_download_progress_as_metric_in_discrete_steps(mock_log
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_should_track_download_handles_and_dispose_of_them_at_the_end():
|
||||
async def test_should_track_download_handles():
|
||||
client = FakeCodexClient()
|
||||
codex_agent = CodexAgent(client)
|
||||
|
||||
cid = await codex_agent.create_dataset(size=1000, name="dataset-1", seed=1356)
|
||||
|
||||
assert cid not in codex_agent.ongoing_downloads
|
||||
|
||||
download_stream = client.create_download_stream(cid)
|
||||
|
||||
handle = await codex_agent.download(cid)
|
||||
|
||||
assert codex_agent.ongoing_downloads[cid] == handle
|
||||
|
||||
download_stream.feed_data(b"0" * 1000)
|
||||
download_stream.feed_eof()
|
||||
|
||||
assert codex_agent.ongoing_downloads[cid] == handle
|
||||
|
||||
await handle.download_task
|
||||
|
||||
assert cid not in codex_agent.ongoing_downloads
|
||||
assert cid in codex_agent.ongoing_downloads
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user