mirror of
https://github.com/logos-storage/bittorrent-benchmarks.git
synced 2026-01-02 13:03:13 +00:00
92 lines
2.6 KiB
Python
92 lines
2.6 KiB
Python
from typing import Annotated, Optional
|
|
|
|
from aiohttp import ClientResponseError
|
|
from fastapi import APIRouter, Response, Depends, HTTPException, Request, FastAPI
|
|
from fastapi.responses import JSONResponse
|
|
from pydantic_core import Url
|
|
from urllib3.util import parse_url
|
|
|
|
from benchmarks.codex.agent.agent import CodexAgent, DownloadStatus
|
|
from benchmarks.codex.client.async_client import AsyncCodexClientImpl
|
|
from benchmarks.core.agent import AgentBuilder
|
|
|
|
router = APIRouter()
|
|
|
|
|
|
def codex_agent() -> CodexAgent:
|
|
raise Exception("Dependency must be set")
|
|
|
|
|
|
@router.get("/api/v1/hello")
|
|
async def hello():
|
|
return {"message": "Server is up"}
|
|
|
|
|
|
@router.post("/api/v1/codex/dataset")
|
|
async def generate(
|
|
agent: Annotated[CodexAgent, Depends(codex_agent)],
|
|
name: str,
|
|
size: int,
|
|
seed: Optional[int],
|
|
):
|
|
return Response(
|
|
await agent.create_dataset(name=name, size=size, seed=seed),
|
|
media_type="text/plain; charset=UTF-8",
|
|
)
|
|
|
|
|
|
@router.post("/api/v1/codex/download")
|
|
async def download(
|
|
request: Request, agent: Annotated[CodexAgent, Depends(codex_agent)], cid: str
|
|
):
|
|
await agent.download(cid)
|
|
return JSONResponse(
|
|
status_code=202,
|
|
content={"status": str(request.url_for("download_status", cid=cid))},
|
|
)
|
|
|
|
|
|
@router.get("/api/v1/codex/download/{cid}/status")
|
|
async def download_status(
|
|
agent: Annotated[CodexAgent, Depends(codex_agent)], cid: str
|
|
) -> DownloadStatus:
|
|
if cid not in agent.ongoing_downloads:
|
|
raise HTTPException(
|
|
status_code=404, detail=f"There are no ongoing downloads for CID {cid}"
|
|
)
|
|
|
|
return agent.ongoing_downloads[cid].progress()
|
|
|
|
|
|
@router.get("/api/v1/codex/download/node-id")
|
|
async def node_id(agent: Annotated[CodexAgent, Depends(codex_agent)]):
|
|
return agent.node_id
|
|
|
|
|
|
def client_response_error_handler(
|
|
_: Request, exception: ClientResponseError
|
|
) -> Response:
|
|
return JSONResponse(
|
|
status_code=exception.status,
|
|
content={"message": exception.message},
|
|
)
|
|
|
|
|
|
class CodexAgentConfig(AgentBuilder):
|
|
codex_api_url: Url
|
|
node_id: str
|
|
|
|
def build(self) -> FastAPI:
|
|
app = FastAPI()
|
|
app.include_router(router)
|
|
# Need to disable typing (https://github.com/encode/starlette/discussions/2391)
|
|
app.add_exception_handler(ClientResponseError, client_response_error_handler) # type: ignore
|
|
agent = CodexAgent(
|
|
client=AsyncCodexClientImpl(
|
|
codex_api_url=parse_url(str(self.codex_api_url))
|
|
),
|
|
node_id=self.node_id,
|
|
)
|
|
app.dependency_overrides[codex_agent] = lambda: agent
|
|
return app
|