feat: add Codex agent REST API

This commit is contained in:
gmega 2025-02-03 18:00:43 -03:00
parent 849bcad6c8
commit 820699f001
No known key found for this signature in database
GPG Key ID: 6290D34EAD824B18
4 changed files with 165 additions and 39 deletions

View File

@ -1,11 +1,12 @@
import asyncio
import logging
from asyncio import Task
from dataclasses import dataclass
from pathlib import Path
from tempfile import TemporaryDirectory
from typing import Optional, Dict
from pydantic import BaseModel
from benchmarks.codex.agent.codex_client import CodexClient, Manifest
from benchmarks.codex.logging import CodexDownloadMetric
from benchmarks.core.utils.random import random_data
@ -18,8 +19,7 @@ EMPTY_STREAM_BACKOFF = 0.1
logger = logging.getLogger(__name__)
@dataclass
class DownloadStatus:
class DownloadStatus(BaseModel):
downloaded: int
total: int
@ -44,39 +44,36 @@ class DownloadHandle:
return self.download_task
async def _download_loop(self):
try:
step_size = int(self.manifest.datasetSize * self.read_increment)
step_size = int(self.manifest.datasetSize * self.read_increment)
while not self.download_stream.at_eof():
step = min(step_size, self.manifest.datasetSize - self.bytes_downloaded)
bytes_read = await self.download_stream.read(step)
# We actually have no guarantees that an empty read means EOF, so we just back off
# a bit.
if not bytes_read:
await asyncio.sleep(EMPTY_STREAM_BACKOFF)
self.bytes_downloaded += len(bytes_read)
while not self.download_stream.at_eof():
step = min(step_size, self.manifest.datasetSize - self.bytes_downloaded)
bytes_read = await self.download_stream.read(step)
# We actually have no guarantees that an empty read means EOF, so we just back off
# a bit.
if not bytes_read:
await asyncio.sleep(EMPTY_STREAM_BACKOFF)
self.bytes_downloaded += len(bytes_read)
logger.info(
CodexDownloadMetric(
cid=self.manifest.cid,
value=self.bytes_downloaded,
node=self.parent.node_id,
)
logger.info(
CodexDownloadMetric(
cid=self.manifest.cid,
value=self.bytes_downloaded,
node=self.parent.node_id,
)
)
if self.bytes_downloaded < self.manifest.datasetSize:
raise EOFError(
f"Got EOF too early: download size ({self.bytes_downloaded}) was less "
f"than expected ({self.manifest.datasetSize})."
)
if self.bytes_downloaded < self.manifest.datasetSize:
raise EOFError(
f"Got EOF too early: download size ({self.bytes_downloaded}) was less "
f"than expected ({self.manifest.datasetSize})."
)
if self.bytes_downloaded > self.manifest.datasetSize:
raise ValueError(
f"Download size ({self.bytes_downloaded}) was greater than expected "
f"({self.manifest.datasetSize})."
)
finally:
self.parent._download_done(self.manifest.cid)
if self.bytes_downloaded > self.manifest.datasetSize:
raise ValueError(
f"Download size ({self.bytes_downloaded}) was greater than expected "
f"({self.manifest.datasetSize})."
)
def progress(self) -> DownloadStatus:
if self.download_task is None:
@ -124,6 +121,3 @@ class CodexAgent:
self.ongoing_downloads[cid] = handle
return handle
def _download_done(self, cid: Cid):
self.ongoing_downloads.pop(cid)

View File

@ -0,0 +1,53 @@
from typing import Annotated, Optional
from fastapi import APIRouter, Response, Depends, HTTPException, Request
from fastapi.responses import JSONResponse
from benchmarks.codex.agent.agent import CodexAgent, DownloadStatus
router = APIRouter()
def codex_agent() -> CodexAgent:
raise Exception("Dependency must be set")
@router.get("/api/v1/hello")
async def hello():
return {"message": "Server is up"}
@router.post("/api/v1/codex/dataset")
async def generate(
agent: Annotated[CodexAgent, Depends(codex_agent)],
name: str,
size: int,
seed: Optional[int],
):
return Response(
await agent.create_dataset(name=name, size=size, seed=seed),
media_type="text/plain; charset=UTF-8",
)
@router.post("/api/v1/codex/download")
async def download(
request: Request, agent: Annotated[CodexAgent, Depends(codex_agent)], cid: str
):
await agent.download(cid)
return JSONResponse(
status_code=202,
content={"status": str(request.url_for("download_status", cid=cid))},
)
@router.get("/api/v1/codex/download/{cid}/status")
async def download_status(
agent: Annotated[CodexAgent, Depends(codex_agent)], cid: str
) -> DownloadStatus:
if cid not in agent.ongoing_downloads:
raise HTTPException(
status_code=404, detail=f"There are no ongoing downloads for CID {cid}"
)
return agent.ongoing_downloads[cid].progress()

View File

@ -0,0 +1,77 @@
import pytest
from fastapi import FastAPI
from httpx import AsyncClient, ASGITransport
from starlette.testclient import TestClient
from benchmarks.codex.agent import api
from benchmarks.codex.agent.agent import CodexAgent
from benchmarks.codex.agent.tests.test_codex_agent import FakeCodexClient
@pytest.mark.asyncio
async def test_should_create_file():
codex_client = FakeCodexClient()
codex_agent = CodexAgent(codex_client)
app = FastAPI()
app.include_router(api.router)
app.dependency_overrides[api.codex_agent] = lambda: codex_agent
client = TestClient(app)
response = client.post(
"/api/v1/codex/dataset",
params={"name": "dataset-1", "size": 1024, "seed": 12},
)
assert response.status_code == 200
assert response.charset_encoding == "utf-8"
manifest = await codex_client.get_manifest(response.text)
assert manifest.datasetSize == 1024
@pytest.mark.asyncio
async def test_should_report_when_download_is_complete():
codex_client = FakeCodexClient()
codex_agent = CodexAgent(codex_client)
app = FastAPI()
app.include_router(api.router)
app.dependency_overrides[api.codex_agent] = lambda: codex_agent
async with AsyncClient(
transport=ASGITransport(app=app), base_url="http://testserver"
) as client:
response = await client.post(
"/api/v1/codex/dataset",
params={"name": "dataset-1", "size": 1024, "seed": 12},
)
assert response.status_code == 200
assert response.charset_encoding == "utf-8"
cid = response.text
download_stream = codex_client.create_download_stream(cid)
response = await client.post(
"/api/v1/codex/download",
params={"cid": cid},
)
assert response.status_code == 202
assert response.json() == {
"status": f"http://testserver/api/v1/codex/download/{cid}/status"
}
download_stream.feed_data(b"0" * 1024)
download_stream.feed_eof()
await codex_agent.ongoing_downloads[cid].download_task
response = await client.get(f"api/v1/codex/download/{cid}/status")
assert response.status_code == 200
assert response.json() == {"downloaded": 1024, "total": 1024}

View File

@ -153,20 +153,22 @@ async def test_should_log_download_progress_as_metric_in_discrete_steps(mock_log
@pytest.mark.asyncio
async def test_should_track_download_handles_and_dispose_of_them_at_the_end():
async def test_should_track_download_handles():
client = FakeCodexClient()
codex_agent = CodexAgent(client)
cid = await codex_agent.create_dataset(size=1000, name="dataset-1", seed=1356)
assert cid not in codex_agent.ongoing_downloads
download_stream = client.create_download_stream(cid)
handle = await codex_agent.download(cid)
assert codex_agent.ongoing_downloads[cid] == handle
download_stream.feed_data(b"0" * 1000)
download_stream.feed_eof()
assert codex_agent.ongoing_downloads[cid] == handle
await handle.download_task
assert cid not in codex_agent.ongoing_downloads
assert cid in codex_agent.ongoing_downloads