From 820699f0018b8a600d74c87fbe6a60682f67888f Mon Sep 17 00:00:00 2001
From: gmega
Date: Mon, 3 Feb 2025 18:00:43 -0300
Subject: [PATCH] feat: add Codex agent REST API

---
Reviewer notes (ignored by `git am`, not part of the commit message):

* api.py: `seed` on POST /api/v1/codex/dataset now defaults to None. With a
  bare `Optional[int]` annotation and no default, FastAPI still treats the
  query parameter as required, so a request that omits it fails validation.
  This assumes CodexAgent.create_dataset accepts seed=None, as the annotation
  suggests -- please confirm against agent.py.
* test_api.py: the final status request now uses an absolute path
  ("/api/v1/...") like every other request in the file.
* Line breaks and indentation were restored so that every hunk matches the
  line counts in its @@ header.

 benchmarks/codex/agent/agent.py               | 62 +++++++--------
 benchmarks/codex/agent/api.py                 | 53 +++++++++++++
 benchmarks/codex/agent/tests/test_api.py      | 77 +++++++++++++++++++
 .../codex/agent/tests/test_codex_agent.py     | 12 +--
 4 files changed, 165 insertions(+), 39 deletions(-)
 create mode 100644 benchmarks/codex/agent/api.py
 create mode 100644 benchmarks/codex/agent/tests/test_api.py

diff --git a/benchmarks/codex/agent/agent.py b/benchmarks/codex/agent/agent.py
index 8de112d..fd65bdc 100644
--- a/benchmarks/codex/agent/agent.py
+++ b/benchmarks/codex/agent/agent.py
@@ -1,11 +1,12 @@
 import asyncio
 import logging
 from asyncio import Task
-from dataclasses import dataclass
 from pathlib import Path
 from tempfile import TemporaryDirectory
 from typing import Optional, Dict
 
+from pydantic import BaseModel
+
 from benchmarks.codex.agent.codex_client import CodexClient, Manifest
 from benchmarks.codex.logging import CodexDownloadMetric
 from benchmarks.core.utils.random import random_data
@@ -18,8 +19,7 @@ EMPTY_STREAM_BACKOFF = 0.1
 logger = logging.getLogger(__name__)
 
 
-@dataclass
-class DownloadStatus:
+class DownloadStatus(BaseModel):
     downloaded: int
     total: int
 
@@ -44,39 +44,36 @@ class DownloadHandle:
         return self.download_task
 
     async def _download_loop(self):
-        try:
-            step_size = int(self.manifest.datasetSize * self.read_increment)
+        step_size = int(self.manifest.datasetSize * self.read_increment)
 
-            while not self.download_stream.at_eof():
-                step = min(step_size, self.manifest.datasetSize - self.bytes_downloaded)
-                bytes_read = await self.download_stream.read(step)
-                # We actually have no guarantees that an empty read means EOF, so we just back off
-                # a bit.
-                if not bytes_read:
-                    await asyncio.sleep(EMPTY_STREAM_BACKOFF)
-                self.bytes_downloaded += len(bytes_read)
+        while not self.download_stream.at_eof():
+            step = min(step_size, self.manifest.datasetSize - self.bytes_downloaded)
+            bytes_read = await self.download_stream.read(step)
+            # We actually have no guarantees that an empty read means EOF, so we just back off
+            # a bit.
+            if not bytes_read:
+                await asyncio.sleep(EMPTY_STREAM_BACKOFF)
+            self.bytes_downloaded += len(bytes_read)
 
-                logger.info(
-                    CodexDownloadMetric(
-                        cid=self.manifest.cid,
-                        value=self.bytes_downloaded,
-                        node=self.parent.node_id,
-                    )
+            logger.info(
+                CodexDownloadMetric(
+                    cid=self.manifest.cid,
+                    value=self.bytes_downloaded,
+                    node=self.parent.node_id,
                 )
+            )
 
-            if self.bytes_downloaded < self.manifest.datasetSize:
-                raise EOFError(
-                    f"Got EOF too early: download size ({self.bytes_downloaded}) was less "
-                    f"than expected ({self.manifest.datasetSize})."
-                )
+        if self.bytes_downloaded < self.manifest.datasetSize:
+            raise EOFError(
+                f"Got EOF too early: download size ({self.bytes_downloaded}) was less "
+                f"than expected ({self.manifest.datasetSize})."
+            )
 
-            if self.bytes_downloaded > self.manifest.datasetSize:
-                raise ValueError(
-                    f"Download size ({self.bytes_downloaded}) was greater than expected "
-                    f"({self.manifest.datasetSize})."
-                )
-        finally:
-            self.parent._download_done(self.manifest.cid)
+        if self.bytes_downloaded > self.manifest.datasetSize:
+            raise ValueError(
+                f"Download size ({self.bytes_downloaded}) was greater than expected "
+                f"({self.manifest.datasetSize})."
+            )
 
     def progress(self) -> DownloadStatus:
         if self.download_task is None:
@@ -124,6 +121,3 @@ class CodexAgent:
 
         self.ongoing_downloads[cid] = handle
         return handle
-
-    def _download_done(self, cid: Cid):
-        self.ongoing_downloads.pop(cid)
diff --git a/benchmarks/codex/agent/api.py b/benchmarks/codex/agent/api.py
new file mode 100644
index 0000000..ea47df3
--- /dev/null
+++ b/benchmarks/codex/agent/api.py
@@ -0,0 +1,53 @@
+from typing import Annotated, Optional
+
+from fastapi import APIRouter, Response, Depends, HTTPException, Request
+from fastapi.responses import JSONResponse
+
+from benchmarks.codex.agent.agent import CodexAgent, DownloadStatus
+
+router = APIRouter()
+
+
+def codex_agent() -> CodexAgent:
+    raise Exception("Dependency must be set")
+
+
+@router.get("/api/v1/hello")
+async def hello():
+    return {"message": "Server is up"}
+
+
+@router.post("/api/v1/codex/dataset")
+async def generate(
+    agent: Annotated[CodexAgent, Depends(codex_agent)],
+    name: str,
+    size: int,
+    seed: Optional[int] = None,
+):
+    return Response(
+        await agent.create_dataset(name=name, size=size, seed=seed),
+        media_type="text/plain; charset=UTF-8",
+    )
+
+
+@router.post("/api/v1/codex/download")
+async def download(
+    request: Request, agent: Annotated[CodexAgent, Depends(codex_agent)], cid: str
+):
+    await agent.download(cid)
+    return JSONResponse(
+        status_code=202,
+        content={"status": str(request.url_for("download_status", cid=cid))},
+    )
+
+
+@router.get("/api/v1/codex/download/{cid}/status")
+async def download_status(
+    agent: Annotated[CodexAgent, Depends(codex_agent)], cid: str
+) -> DownloadStatus:
+    if cid not in agent.ongoing_downloads:
+        raise HTTPException(
+            status_code=404, detail=f"There are no ongoing downloads for CID {cid}"
+        )
+
+    return agent.ongoing_downloads[cid].progress()
diff --git a/benchmarks/codex/agent/tests/test_api.py b/benchmarks/codex/agent/tests/test_api.py
new file mode 100644
index 0000000..d0645ef
--- /dev/null
+++ b/benchmarks/codex/agent/tests/test_api.py
@@ -0,0 +1,77 @@
+import pytest
+from fastapi import FastAPI
+from httpx import AsyncClient, ASGITransport
+from starlette.testclient import TestClient
+
+from benchmarks.codex.agent import api
+from benchmarks.codex.agent.agent import CodexAgent
+from benchmarks.codex.agent.tests.test_codex_agent import FakeCodexClient
+
+
+@pytest.mark.asyncio
+async def test_should_create_file():
+    codex_client = FakeCodexClient()
+    codex_agent = CodexAgent(codex_client)
+
+    app = FastAPI()
+    app.include_router(api.router)
+    app.dependency_overrides[api.codex_agent] = lambda: codex_agent
+
+    client = TestClient(app)
+
+    response = client.post(
+        "/api/v1/codex/dataset",
+        params={"name": "dataset-1", "size": 1024, "seed": 12},
+    )
+
+    assert response.status_code == 200
+    assert response.charset_encoding == "utf-8"
+
+    manifest = await codex_client.get_manifest(response.text)
+
+    assert manifest.datasetSize == 1024
+
+
+@pytest.mark.asyncio
+async def test_should_report_when_download_is_complete():
+    codex_client = FakeCodexClient()
+    codex_agent = CodexAgent(codex_client)
+
+    app = FastAPI()
+    app.include_router(api.router)
+    app.dependency_overrides[api.codex_agent] = lambda: codex_agent
+
+    async with AsyncClient(
+        transport=ASGITransport(app=app), base_url="http://testserver"
+    ) as client:
+        response = await client.post(
+            "/api/v1/codex/dataset",
+            params={"name": "dataset-1", "size": 1024, "seed": 12},
+        )
+
+        assert response.status_code == 200
+        assert response.charset_encoding == "utf-8"
+
+        cid = response.text
+
+        download_stream = codex_client.create_download_stream(cid)
+
+        response = await client.post(
+            "/api/v1/codex/download",
+            params={"cid": cid},
+        )
+
+        assert response.status_code == 202
+        assert response.json() == {
+            "status": f"http://testserver/api/v1/codex/download/{cid}/status"
+        }
+
+        download_stream.feed_data(b"0" * 1024)
+        download_stream.feed_eof()
+
+        await codex_agent.ongoing_downloads[cid].download_task
+
+        response = await client.get(f"/api/v1/codex/download/{cid}/status")
+
+        assert response.status_code == 200
+        assert response.json() == {"downloaded": 1024, "total": 1024}
diff --git a/benchmarks/codex/agent/tests/test_codex_agent.py b/benchmarks/codex/agent/tests/test_codex_agent.py
index a20e8e7..8afb270 100644
--- a/benchmarks/codex/agent/tests/test_codex_agent.py
+++ b/benchmarks/codex/agent/tests/test_codex_agent.py
@@ -153,20 +153,22 @@ async def test_should_log_download_progress_as_metric_in_discrete_steps(mock_log
 
 
 @pytest.mark.asyncio
-async def test_should_track_download_handles_and_dispose_of_them_at_the_end():
+async def test_should_track_download_handles():
     client = FakeCodexClient()
     codex_agent = CodexAgent(client)
 
     cid = await codex_agent.create_dataset(size=1000, name="dataset-1", seed=1356)
+
+    assert cid not in codex_agent.ongoing_downloads
+
     download_stream = client.create_download_stream(cid)
-
     handle = await codex_agent.download(cid)
 
-    assert codex_agent.ongoing_downloads[cid] == handle
-
     download_stream.feed_data(b"0" * 1000)
     download_stream.feed_eof()
 
+    assert codex_agent.ongoing_downloads[cid] == handle
+
     await handle.download_task
 
-    assert cid not in codex_agent.ongoing_downloads
+    assert cid in codex_agent.ongoing_downloads