128 lines
4.4 KiB
Python
Raw Normal View History

import asyncio
import logging
from asyncio import Task
from pathlib import Path
from tempfile import TemporaryDirectory
from typing import Optional, Dict
2025-02-03 18:00:43 -03:00
from pydantic import BaseModel
from benchmarks.codex.client.async_client import AsyncCodexClient
from benchmarks.codex.client.common import Cid
from benchmarks.codex.client.common import Manifest
from benchmarks.core.utils.random import random_data
2025-02-14 18:19:45 -03:00
from benchmarks.logging.logging import DownloadMetric
2025-02-14 11:00:17 -03:00
# Seconds to wait (passed to asyncio.sleep) before polling the download stream
# again after a read returned no data; see DownloadHandle._download_loop.
EMPTY_STREAM_BACKOFF = 2
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
2025-02-03 18:00:43 -03:00
class DownloadStatus(BaseModel):
    """Snapshot of how far a download has progressed."""

    # Amount downloaded so far (same unit as ``total``).
    downloaded: int
    # Expected size of the complete dataset.
    total: int

    def as_percent(self) -> float:
        """Return the downloaded amount as a percentage of ``total``.

        An empty dataset (``total == 0``) has nothing left to fetch, so it is
        reported as 100% rather than raising ``ZeroDivisionError``.
        """
        if self.total == 0:
            return 100.0
        return (self.downloaded * 100) / self.total
class DownloadHandle:
    """Tracks one dataset download running as a background asyncio task.

    The handle streams the dataset identified by ``manifest.cid`` through the
    parent agent's client, counts the bytes received, and logs a
    ``DownloadMetric`` each time another ``read_increment``-sized slice of the
    dataset has been read.
    """

    def __init__(
        self,
        parent: "CodexAgent",
        manifest: Manifest,
        read_increment: float = 0.01,
    ):
        """Create the handle; nothing is downloaded until ``begin_download``.

        Args:
            parent: Agent whose ``client`` performs the download and whose
                ``node_id`` is attached to the logged metrics.
            manifest: Manifest of the dataset to download; its ``cid``,
                ``datasetSize`` and ``filename`` attributes are read.
            read_increment: Fraction of the dataset size requested per read.
                It is also the granularity at which progress is logged.
        """
        self.parent = parent
        self.manifest = manifest
        self.bytes_downloaded = 0
        self.read_increment = read_increment
        self.download_task: Optional[Task[None]] = None

    def begin_download(self) -> Task[None]:
        """Schedule the download loop on the running event loop.

        Returns:
            The task running the download; it is also stored in
            ``self.download_task``.
        """
        self.download_task = asyncio.create_task(self._download_loop())
        return self.download_task

    async def _download_loop(self):
        """Read the dataset stream to EOF, logging progress along the way.

        Raises:
            EOFError: The stream ended before ``manifest.datasetSize`` bytes
                were received.
            ValueError: More than ``manifest.datasetSize`` bytes were received.
        """
        total = self.manifest.datasetSize
        # For small datasets (or small increments) the product truncates to 0,
        # which would make every read request 0 bytes and the progress check
        # below divide by zero. Always advance by at least one byte per read.
        step_size = max(1, int(total * self.read_increment))

        async with self.parent.client.download(self.manifest.cid) as download_stream:
            logged_step = 0
            while not download_stream.at_eof():
                step = min(step_size, total - self.bytes_downloaded)
                bytes_read = await download_stream.read(step)
                # We actually have no guarantees that an empty read means EOF,
                # so we just back off a bit and poll the stream again.
                if not bytes_read:
                    await asyncio.sleep(EMPTY_STREAM_BACKOFF)
                    continue

                self.bytes_downloaded += len(bytes_read)

                # Integer floor division keeps the step count exact. A read
                # never requests more than step_size bytes, so at most one
                # step boundary is expected to be crossed per iteration.
                if self.bytes_downloaded // step_size > logged_step:
                    logged_step += 1
                    logger.info(
                        DownloadMetric(
                            dataset_name=self.manifest.filename,
                            handle=self.manifest.cid,
                            value=step_size * logged_step,
                            node=self.parent.node_id,
                        )
                    )

            if self.bytes_downloaded < self.manifest.datasetSize:
                raise EOFError(
                    f"Got EOF too early: download size ({self.bytes_downloaded}) was less "
                    f"than expected ({self.manifest.datasetSize})."
                )

            if self.bytes_downloaded > self.manifest.datasetSize:
                raise ValueError(
                    f"Download size ({self.bytes_downloaded}) was greater than expected "
                    f"({self.manifest.datasetSize})."
                )

    def progress(self) -> DownloadStatus:
        """Report how many bytes have been downloaded so far.

        Returns:
            A ``DownloadStatus`` with the bytes received and the expected
            total. Before ``begin_download`` is called it reports zero bytes.

        Raises:
            Exception: If the download task has finished with an error (or was
                cancelled), that error is re-raised here.
        """
        if self.download_task is None:
            return DownloadStatus(downloaded=0, total=self.manifest.datasetSize)

        if self.download_task.done():
            # This will bubble exceptions up, if any.
            self.download_task.result()

        return DownloadStatus(
            downloaded=self.bytes_downloaded, total=self.manifest.datasetSize
        )
class CodexAgent:
    """Drives dataset uploads and downloads through an ``AsyncCodexClient``."""

    def __init__(self, client: AsyncCodexClient, node_id: str = "unknown") -> None:
        """Create an agent.

        Args:
            client: Client used for uploads, manifest lookups and downloads.
            node_id: Identifier attached to the download metrics logged by the
                handles this agent creates.
        """
        self.client = client
        self.node_id = node_id
        # One handle per CID; repeated download requests reuse the same handle.
        self.ongoing_downloads: Dict[Cid, DownloadHandle] = {}

    async def create_dataset(self, name: str, size: int, seed: Optional[int]) -> Cid:
        """Generate a data file with ``random_data`` and upload it.

        The file lives in a temporary directory that is removed once the
        upload call returns.

        Args:
            name: Name passed to the client's ``upload`` call.
            size: Forwarded to ``random_data`` as ``size``.
            seed: Forwarded to ``random_data`` as ``seed``.

        Returns:
            Whatever the client's ``upload`` call returns (annotated as the
            dataset's ``Cid``).
        """
        with TemporaryDirectory() as td:
            data = Path(td) / "datafile.bin"

            with data.open(mode="wb") as outfile:
                random_data(size=size, outfile=outfile, seed=seed)

            with data.open(mode="rb") as infile:
                return await self.client.upload(
                    name=name, mime_type="application/octet-stream", content=infile
                )

    async def download(self, cid: Cid, read_increment: float = 0.01) -> DownloadHandle:
        """Start downloading ``cid``, or return the handle already tracking it.

        Args:
            cid: Identifier of the dataset to download.
            read_increment: Fraction of the dataset size to read per step;
                only used when a new handle is created.

        Returns:
            The ``DownloadHandle`` registered for ``cid``.
        """
        existing = self.ongoing_downloads.get(cid)
        if existing is not None:
            return existing

        manifest = await self.client.manifest(cid)

        # The manifest lookup yields to the event loop, so a concurrent call
        # for the same CID may have registered a handle in the meantime.
        # Re-check so that only one download is ever started per CID.
        existing = self.ongoing_downloads.get(cid)
        if existing is not None:
            return existing

        handle = DownloadHandle(
            self,
            manifest=manifest,
            read_increment=read_increment,
        )
        handle.begin_download()
        self.ongoing_downloads[cid] = handle
        return handle