feat: add tracking of download progress via synchronous API at the Codex agent

This commit is contained in:
gmega 2025-01-31 15:49:58 -03:00
parent ec44588fab
commit cb941b859f
No known key found for this signature in database
GPG Key ID: 6290D34EAD824B18
20 changed files with 270 additions and 50 deletions

View File

@ -1,15 +1,74 @@
import asyncio
from asyncio import Task
from pathlib import Path
from tempfile import TemporaryDirectory
from typing import Optional
from benchmarks.codex.client import CodexClient
from benchmarks.core.utils import random_data
from benchmarks.codex.client import CodexClient, Manifest
from benchmarks.core.utils.random import random_data
from benchmarks.core.utils.streams import BaseStreamReader
Cid = str
EMPTY_STREAM_BACKOFF = 0.1
class DownloadHandle:
def __init__(
self,
parent: "CodexAgent",
manifest: Manifest,
download_stream: BaseStreamReader,
read_increment: float = 0.01,
):
self.parent = parent
self.manifest = manifest
self.bytes_downloaded = 0
self.read_increment = read_increment
self.download_stream = download_stream
self.download_task: Optional[Task[None]] = None
def begin_download(self) -> Task:
self.download_task = asyncio.create_task(self._download_loop())
return self.download_task
async def _download_loop(self):
step_size = int(self.manifest.datasetSize * self.read_increment)
while not self.download_stream.at_eof():
step = min(step_size, self.manifest.datasetSize - self.bytes_downloaded)
bytes_read = await self.download_stream.read(step)
# We actually have no guarantees that an empty read means EOF, so we just back off
# a bit.
if not bytes_read:
await asyncio.sleep(EMPTY_STREAM_BACKOFF)
self.bytes_downloaded += len(bytes_read)
if self.bytes_downloaded < self.manifest.datasetSize:
raise EOFError(
f"Got EOF too early: download size ({self.bytes_downloaded}) was less "
f"than expected ({self.manifest.datasetSize})."
)
if self.bytes_downloaded > self.manifest.datasetSize:
raise ValueError(
f"Download size ({self.bytes_downloaded}) was greater than expected "
f"({self.manifest.datasetSize})."
)
def progress(self) -> float:
if self.download_task is None:
return 0
if self.download_task.done():
# This will bubble exceptions up, if any.
self.download_task.result()
return self.bytes_downloaded / self.manifest.datasetSize
class CodexAgent:
def __init__(self, client: CodexClient):
def __init__(self, client: CodexClient) -> None:
self.client = client
async def create_dataset(self, name: str, size: int, seed: Optional[int]) -> Cid:
@ -23,3 +82,14 @@ class CodexAgent:
return await self.client.upload(
name=name, mime_type="application/octet-stream", content=infile
)
async def download(self, cid: Cid) -> DownloadHandle:
handle = DownloadHandle(
self,
manifest=await self.client.get_manifest(cid),
download_stream=await self.client.download(cid),
)
handle.begin_download()
return handle

View File

@ -1,9 +1,12 @@
from abc import ABC, abstractmethod
from typing import IO
import aiohttp
from pydantic import BaseModel
from urllib3.util import Url
from benchmarks.core.utils.streams import BaseStreamReader
API_VERSION = "v1"
Cid = str
@ -20,7 +23,21 @@ class Manifest(BaseModel):
protected: bool
class CodexClient:
class CodexClient(ABC):
@abstractmethod
async def upload(self, name: str, mime_type: str, content: IO) -> Cid:
pass
@abstractmethod
async def get_manifest(self, cid: Cid) -> Manifest:
pass
@abstractmethod
async def download(self, cid: Cid) -> BaseStreamReader:
pass
class CodexClientImpl(CodexClient):
"""A lightweight async wrapper built around the Codex REST API."""
def __init__(self, codex_api_url: Url):
@ -55,3 +72,15 @@ class CodexClient:
cid = response_contents.pop("cid")
return Manifest.model_validate(dict(cid=cid, **response_contents["manifest"]))
async def download(self, cid: Cid) -> BaseStreamReader:
async with aiohttp.ClientSession() as session:
response = await session.get(
self.codex_api_url._replace(
path=f"/api/codex/v1/data/{cid}/network/download"
).url,
)
response.raise_for_status()
return response.content

View File

@ -3,12 +3,12 @@ import os
import pytest
from urllib3.util import parse_url
from benchmarks.codex.client import CodexClient
from benchmarks.codex.client import CodexClientImpl
@pytest.fixture
def codex_client_1():
# TODO wipe data between tests
return CodexClient(
return CodexClientImpl(
parse_url(f"http://{os.environ.get('CODEX_NODE_1', 'localhost')}:8091")
)

View File

@ -1,30 +1,118 @@
from asyncio import StreamReader
from typing import IO, Dict
import pytest
from benchmarks.codex.agent import CodexAgent
from benchmarks.codex.client import CodexClient, Cid, Manifest
from benchmarks.core.concurrency import await_predicate_async
from benchmarks.core.utils.streams import BaseStreamReader
@pytest.fixture
def codex_agent(codex_client_1):
return CodexAgent(codex_client_1)
class FakeCodexClient(CodexClient):
def __init__(self) -> None:
self.storage: Dict[Cid, Manifest] = {}
self.streams: Dict[Cid, StreamReader] = {}
async def upload(self, name: str, mime_type: str, content: IO) -> Cid:
data = content.read()
cid = "Qm" + str(hash(data))
self.storage[cid] = Manifest(
cid=cid,
datasetSize=len(data),
mimetype=mime_type,
blockSize=1,
filename=name,
treeCid="",
uploadedAt=0,
protected=False,
)
return cid
async def get_manifest(self, cid: Cid) -> Manifest:
return self.storage[cid]
def create_download_stream(self, cid: Cid) -> StreamReader:
reader = StreamReader()
self.streams[cid] = reader
return reader
async def download(self, cid: Cid) -> BaseStreamReader:
return self.streams[cid]
@pytest.mark.codex_integration
@pytest.mark.asyncio
async def test_should_create_dataset(codex_agent: CodexAgent):
async def test_should_create_dataset_of_right_size():
codex_agent = CodexAgent(FakeCodexClient())
cid = await codex_agent.create_dataset(size=1024, name="dataset-1", seed=1234)
manifest = await codex_agent.client.get_manifest(cid)
dataset = await codex_agent.client.get_manifest(cid)
assert dataset.cid == cid
assert dataset.datasetSize == 1024
assert manifest.datasetSize == 1024
@pytest.mark.codex_integration
@pytest.mark.asyncio
async def test_same_seed_creates_same_cid(codex_agent: CodexAgent):
async def test_same_seed_creates_same_cid():
codex_agent = CodexAgent(FakeCodexClient())
cid1 = await codex_agent.create_dataset(size=2048, name="dataset-1", seed=1234)
cid2 = await codex_agent.create_dataset(size=2048, name="dataset-1", seed=1234)
cid3 = await codex_agent.create_dataset(size=2048, name="dataset-1", seed=1235)
assert cid1 == cid2
assert cid1 != cid3
@pytest.mark.asyncio
async def test_should_report_download_progress():
client = FakeCodexClient()
codex_agent = CodexAgent(client)
cid = await codex_agent.create_dataset(size=1000, name="dataset-1", seed=1234)
download_stream = client.create_download_stream(cid)
handle = await codex_agent.download(cid)
assert handle.progress() == 0
for i in range(100):
download_stream.feed_data(b"0" * 5)
assert len(download_stream._buffer) == 5
assert await await_predicate_async(
lambda: round(handle.progress() * 100) == i, timeout=5
)
assert await await_predicate_async(
lambda: len(download_stream._buffer) == 0, timeout=5
)
download_stream.feed_data(b"0" * 5)
assert len(download_stream._buffer) == 5
assert await await_predicate_async(
lambda: round(handle.progress() * 100) == (i + 1), timeout=5
)
assert await await_predicate_async(
lambda: len(download_stream._buffer) == 0, timeout=5
)
download_stream.feed_eof()
await handle.download_task
@pytest.mark.asyncio
async def test_should_raise_exception_on_progress_query_if_download_fails():
client = FakeCodexClient()
codex_agent = CodexAgent(client)
cid = await codex_agent.create_dataset(size=1000, name="dataset-1", seed=1234)
download_stream = client.create_download_stream(cid)
handle = await codex_agent.download(cid)
download_stream.feed_eof()
with pytest.raises(EOFError):
def _predicate():
handle.progress()
return False
await await_predicate_async(_predicate, timeout=5)

View File

@ -2,7 +2,7 @@ from io import BytesIO
import pytest
from benchmarks.core.utils import random_data
from benchmarks.core.utils.random import random_data
@pytest.fixture

View File

@ -1,11 +1,43 @@
import asyncio
from concurrent import futures
from concurrent.futures.thread import ThreadPoolExecutor
from queue import Queue
from typing import Iterable, Iterator, List, cast
from time import time, sleep
from typing import Iterable, Iterator, List, cast, Awaitable, Callable
from typing_extensions import TypeVar
def await_predicate(
predicate: Callable[[], bool], timeout: float = 0, polling_interval: float = 0
) -> bool:
start_time = time()
while (timeout == 0) or ((time() - start_time) <= timeout):
if predicate():
return True
sleep(polling_interval)
return False
async def await_predicate_async(
predicate: Callable[[], Awaitable[bool]] | Callable[[], bool],
timeout: float = 0,
polling_interval: float = 0,
) -> bool:
start_time = time()
while (timeout == 0) or ((time() - start_time) <= timeout):
if asyncio.iscoroutinefunction(predicate):
if await predicate():
return True
else:
if predicate():
return True
await asyncio.sleep(polling_interval)
return False
class _End:
pass

View File

@ -8,8 +8,8 @@ from typing import List, Optional
from typing_extensions import Generic, TypeVar
from benchmarks.core.concurrency import await_predicate
from benchmarks.core.config import Builder
from benchmarks.core.utils import await_predicate
logger = logging.getLogger(__name__)

View File

@ -1,6 +1,6 @@
from io import BytesIO
from benchmarks.core.utils import random_data
from benchmarks.core.utils.random import random_data
def test_should_generate_the_requested_amount_of_bytes():

View File

View File

@ -1,18 +1,7 @@
import random
from time import time, sleep
from typing import Iterator, Optional, Callable, IO
from typing import Iterator, IO, Optional
def await_predicate(
predicate: Callable[[], bool], timeout: float = 0, polling_interval: float = 0
) -> bool:
start_time = time()
while (timeout == 0) or ((time() - start_time) <= timeout):
if predicate():
return True
sleep(polling_interval)
return False
from benchmarks.core.utils.units import megabytes
def sample(n: int) -> Iterator[int]:
@ -26,14 +15,6 @@ def sample(n: int) -> Iterator[int]:
yield p[i]
def kilobytes(n: int) -> int:
return n * 1024
def megabytes(n: int) -> int:
return kilobytes(n) * 1024
def random_data(
size: int, outfile: IO, batch_size: int = megabytes(50), seed: Optional[int] = None
):

View File

@ -0,0 +1,9 @@
from typing import Protocol
class BaseStreamReader(Protocol):
async def read(self, n: int) -> bytes: ...
def feed_data(self, data: bytes) -> None: ...
def at_eof(self) -> bool: ...

View File

@ -0,0 +1,6 @@
def kilobytes(n: int) -> int:
return n * 1024
def megabytes(n: int) -> int:
return kilobytes(n) * 1024

View File

@ -4,7 +4,8 @@ from typing import Optional
from torrentool.torrent import Torrent
from benchmarks.core.utils import random_data, megabytes
from benchmarks.core.utils.random import random_data
from benchmarks.core.utils.units import megabytes
logger = logging.getLogger(__name__)

View File

@ -4,8 +4,8 @@ from typing import Annotated, Optional
from fastapi import FastAPI, Depends, APIRouter, Response
from benchmarks.core.agent import AgentBuilder
from benchmarks.core.utils.units import megabytes
from benchmarks.core.utils import megabytes
from benchmarks.deluge.agent.agent import DelugeAgent
router = APIRouter()

View File

@ -15,7 +15,8 @@ from benchmarks.core.experiments.experiments import (
from benchmarks.core.experiments.iterated_experiment import IteratedExperiment
from benchmarks.core.experiments.static_experiment import StaticDisseminationExperiment
from benchmarks.core.pydantic import Host
from benchmarks.core.utils import sample
from benchmarks.core.utils.random import sample
from benchmarks.deluge.agent.client import DelugeAgentClient
from benchmarks.deluge.deluge_node import DelugeMeta, DelugeNode
from benchmarks.deluge.tracker import Tracker

View File

@ -22,9 +22,10 @@ from tenacity.wait import wait_base
from torrentool.torrent import Torrent
from urllib3.util import Url
from benchmarks.core.concurrency import await_predicate
from benchmarks.core.experiments.experiments import ExperimentComponent
from benchmarks.core.network import DownloadHandle, Node
from benchmarks.core.utils import await_predicate
from benchmarks.deluge.agent.client import DelugeAgentClient
logger = logging.getLogger(__name__)

View File

@ -5,7 +5,7 @@ from typing import Generator
import pytest
from urllib3.util import parse_url
from benchmarks.core.utils import await_predicate
from benchmarks.core.concurrency import await_predicate
from benchmarks.deluge.agent.client import DelugeAgentClient
from benchmarks.deluge.deluge_node import DelugeNode
from benchmarks.deluge.tracker import Tracker

View File

@ -1,7 +1,8 @@
import pytest
from tenacity import wait_incrementing, stop_after_attempt, RetryError
from benchmarks.core.utils import megabytes, await_predicate
from benchmarks.core.concurrency import await_predicate
from benchmarks.core.utils.units import megabytes
from benchmarks.deluge.deluge_node import DelugeNode, DelugeMeta, ResilientCallWrapper
from benchmarks.deluge.tracker import Tracker

View File

@ -2,7 +2,8 @@ import pytest
from benchmarks.core.experiments.experiments import ExperimentEnvironment
from benchmarks.core.experiments.static_experiment import StaticDisseminationExperiment
from benchmarks.core.utils import megabytes
from benchmarks.core.utils.units import megabytes
from benchmarks.deluge.deluge_node import DelugeMeta
from benchmarks.deluge.tests.test_deluge_node import assert_is_seed

View File

@ -6,7 +6,7 @@ from typing import Dict, Any
import pytest
from elasticsearch import Elasticsearch
from benchmarks.core.utils import await_predicate
from benchmarks.core.concurrency import await_predicate
def _json_data(data: str) -> Dict[str, Any]: