feat: add tracking of download progress via synchronous API at the Codex agent

This commit is contained in:
gmega 2025-01-31 15:49:58 -03:00
parent ec44588fab
commit cb941b859f
No known key found for this signature in database
GPG Key ID: 6290D34EAD824B18
20 changed files with 270 additions and 50 deletions

View File

@ -1,15 +1,74 @@
import asyncio
from asyncio import Task
from pathlib import Path from pathlib import Path
from tempfile import TemporaryDirectory from tempfile import TemporaryDirectory
from typing import Optional from typing import Optional
from benchmarks.codex.client import CodexClient from benchmarks.codex.client import CodexClient, Manifest
from benchmarks.core.utils import random_data from benchmarks.core.utils.random import random_data
from benchmarks.core.utils.streams import BaseStreamReader
Cid = str Cid = str
EMPTY_STREAM_BACKOFF = 0.1
class DownloadHandle:
    """Tracks the progress of a single dataset download.

    The handle drains ``download_stream`` in a background asyncio task while
    counting the bytes it reads, so that :meth:`progress` can be polled
    synchronously while the download is in flight.
    """

    def __init__(
        self,
        parent: "CodexAgent",
        manifest: "Manifest",
        download_stream: "BaseStreamReader",
        read_increment: float = 0.01,
    ):
        """
        :param parent: the agent that started this download.
        :param manifest: manifest of the dataset being downloaded; its
            ``datasetSize`` is the number of bytes the download must yield.
        :param download_stream: stream from which the dataset is read.
        :param read_increment: fraction of ``datasetSize`` requested from the
            stream on each read.
        """
        self.parent = parent
        self.manifest = manifest
        self.bytes_downloaded = 0
        self.read_increment = read_increment
        self.download_stream = download_stream
        self.download_task: Optional[Task[None]] = None

    def begin_download(self) -> Task:
        """Start draining the stream in a background task and return that task.

        Must be called from within a running event loop.
        """
        self.download_task = asyncio.create_task(self._download_loop())
        return self.download_task

    async def _download_loop(self):
        """Read the stream until EOF, updating ``bytes_downloaded`` along the way.

        :raises EOFError: if the stream ends before ``datasetSize`` bytes were read.
        :raises ValueError: if more than ``datasetSize`` bytes were read.
        """
        # Never let the step truncate to zero: for datasets smaller than
        # 1 / read_increment bytes a zero-byte step would make every read
        # return nothing, so the stream would never drain nor reach EOF.
        step_size = max(1, int(self.manifest.datasetSize * self.read_increment))

        while not self.download_stream.at_eof():
            step = min(step_size, self.manifest.datasetSize - self.bytes_downloaded)
            bytes_read = await self.download_stream.read(step)
            # We actually have no guarantees that an empty read means EOF, so we
            # just back off a bit and let the loop condition decide.
            if not bytes_read:
                await asyncio.sleep(EMPTY_STREAM_BACKOFF)
                continue
            self.bytes_downloaded += len(bytes_read)

        if self.bytes_downloaded < self.manifest.datasetSize:
            raise EOFError(
                f"Got EOF too early: download size ({self.bytes_downloaded}) was less "
                f"than expected ({self.manifest.datasetSize})."
            )

        # Defensive check: reads are capped at the remaining byte count, so this
        # only triggers if the stream hands back more than was requested.
        if self.bytes_downloaded > self.manifest.datasetSize:
            raise ValueError(
                f"Download size ({self.bytes_downloaded}) was greater than expected "
                f"({self.manifest.datasetSize})."
            )

    def progress(self) -> float:
        """Return the fraction of the dataset downloaded so far.

        Returns ``0.0`` if the download has not been started. If the download
        task has finished with an error, that error is re-raised here.
        """
        if self.download_task is None:
            return 0.0

        if self.download_task.done():
            # This will bubble exceptions up, if any.
            self.download_task.result()

        # An empty dataset has nothing left to download; avoid dividing by zero.
        if self.manifest.datasetSize == 0:
            return 1.0

        return self.bytes_downloaded / self.manifest.datasetSize
class CodexAgent: class CodexAgent:
def __init__(self, client: CodexClient): def __init__(self, client: CodexClient) -> None:
self.client = client self.client = client
async def create_dataset(self, name: str, size: int, seed: Optional[int]) -> Cid: async def create_dataset(self, name: str, size: int, seed: Optional[int]) -> Cid:
@ -23,3 +82,14 @@ class CodexAgent:
return await self.client.upload( return await self.client.upload(
name=name, mime_type="application/octet-stream", content=infile name=name, mime_type="application/octet-stream", content=infile
) )
async def download(self, cid: Cid) -> DownloadHandle:
    """Start downloading the dataset identified by ``cid``.

    Fetches the manifest, opens the download stream, and returns a handle
    whose background download task has already been started.
    """
    dataset_manifest = await self.client.get_manifest(cid)
    stream = await self.client.download(cid)

    download_handle = DownloadHandle(
        self, manifest=dataset_manifest, download_stream=stream
    )
    download_handle.begin_download()
    return download_handle

View File

@ -1,9 +1,12 @@
from abc import ABC, abstractmethod
from typing import IO from typing import IO
import aiohttp import aiohttp
from pydantic import BaseModel from pydantic import BaseModel
from urllib3.util import Url from urllib3.util import Url
from benchmarks.core.utils.streams import BaseStreamReader
API_VERSION = "v1" API_VERSION = "v1"
Cid = str Cid = str
@ -20,7 +23,21 @@ class Manifest(BaseModel):
protected: bool protected: bool
class CodexClient: class CodexClient(ABC):
@abstractmethod
async def upload(self, name: str, mime_type: str, content: IO) -> Cid:
    """Upload ``content`` under ``name`` with the given ``mime_type``; return its Cid."""
    pass
@abstractmethod
async def get_manifest(self, cid: Cid) -> Manifest:
    """Return the :class:`Manifest` for the dataset identified by ``cid``."""
    pass
@abstractmethod
async def download(self, cid: Cid) -> BaseStreamReader:
    """Return a stream reader from which the contents of ``cid`` can be read."""
    pass
class CodexClientImpl(CodexClient):
"""A lightweight async wrapper built around the Codex REST API.""" """A lightweight async wrapper built around the Codex REST API."""
def __init__(self, codex_api_url: Url): def __init__(self, codex_api_url: Url):
@ -55,3 +72,15 @@ class CodexClient:
cid = response_contents.pop("cid") cid = response_contents.pop("cid")
return Manifest.model_validate(dict(cid=cid, **response_contents["manifest"])) return Manifest.model_validate(dict(cid=cid, **response_contents["manifest"]))
async def download(self, cid: Cid) -> BaseStreamReader:
    """Request the network download of ``cid`` and return the response body stream.

    Issues a GET against ``/api/codex/v1/data/{cid}/network/download`` on the
    configured Codex API URL and calls ``raise_for_status()`` on the response
    before returning its ``content`` stream.
    """
    # NOTE(review): the session is closed when this ``async with`` block exits,
    # which happens before the caller reads from the returned stream. Confirm
    # that the stream remains readable after the session is closed; otherwise
    # the session's lifetime needs to be tied to the stream's.
    async with aiohttp.ClientSession() as session:
        response = await session.get(
            self.codex_api_url._replace(
                path=f"/api/codex/v1/data/{cid}/network/download"
            ).url,
        )
        response.raise_for_status()
        return response.content

View File

@ -3,12 +3,12 @@ import os
import pytest import pytest
from urllib3.util import parse_url from urllib3.util import parse_url
from benchmarks.codex.client import CodexClient from benchmarks.codex.client import CodexClientImpl
@pytest.fixture @pytest.fixture
def codex_client_1(): def codex_client_1():
# TODO wipe data between tests # TODO wipe data between tests
return CodexClient( return CodexClientImpl(
parse_url(f"http://{os.environ.get('CODEX_NODE_1', 'localhost')}:8091") parse_url(f"http://{os.environ.get('CODEX_NODE_1', 'localhost')}:8091")
) )

View File

@ -1,30 +1,118 @@
from asyncio import StreamReader
from typing import IO, Dict
import pytest import pytest
from benchmarks.codex.agent import CodexAgent from benchmarks.codex.agent import CodexAgent
from benchmarks.codex.client import CodexClient, Cid, Manifest
from benchmarks.core.concurrency import await_predicate_async
from benchmarks.core.utils.streams import BaseStreamReader
@pytest.fixture class FakeCodexClient(CodexClient):
def codex_agent(codex_client_1): def __init__(self) -> None:
return CodexAgent(codex_client_1) self.storage: Dict[Cid, Manifest] = {}
self.streams: Dict[Cid, StreamReader] = {}
async def upload(self, name: str, mime_type: str, content: IO) -> Cid:
    """Read ``content`` fully, store a manifest for it in memory, and return its fake Cid."""
    data = content.read()
    # The fake Cid is derived from hash() of the content, so equal content maps
    # to the same Cid. NOTE: Python salts bytes hashes per interpreter run, so
    # these Cids are only stable within a single process.
    cid = "Qm" + str(hash(data))
    # Only the manifest is kept; the uploaded bytes themselves are discarded.
    self.storage[cid] = Manifest(
        cid=cid,
        datasetSize=len(data),
        mimetype=mime_type,
        blockSize=1,
        filename=name,
        treeCid="",
        uploadedAt=0,
        protected=False,
    )
    return cid
async def get_manifest(self, cid: Cid) -> Manifest:
    """Return the stored manifest for ``cid``; raises ``KeyError`` if it was never uploaded."""
    return self.storage[cid]
def create_download_stream(self, cid: Cid) -> StreamReader:
    """Create and register the stream that ``download(cid)`` will hand out.

    The returned reader is the one tests feed data into to simulate a download.
    """
    reader = StreamReader()
    self.streams[cid] = reader
    return reader
async def download(self, cid: Cid) -> BaseStreamReader:
    """Return the stream registered for ``cid``; raises ``KeyError`` if none was created."""
    return self.streams[cid]
@pytest.mark.codex_integration
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_should_create_dataset(codex_agent: CodexAgent): async def test_should_create_dataset_of_right_size():
codex_agent = CodexAgent(FakeCodexClient())
cid = await codex_agent.create_dataset(size=1024, name="dataset-1", seed=1234) cid = await codex_agent.create_dataset(size=1024, name="dataset-1", seed=1234)
manifest = await codex_agent.client.get_manifest(cid)
dataset = await codex_agent.client.get_manifest(cid) assert manifest.datasetSize == 1024
assert dataset.cid == cid
assert dataset.datasetSize == 1024
@pytest.mark.codex_integration
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_same_seed_creates_same_cid(codex_agent: CodexAgent): async def test_same_seed_creates_same_cid():
codex_agent = CodexAgent(FakeCodexClient())
cid1 = await codex_agent.create_dataset(size=2048, name="dataset-1", seed=1234) cid1 = await codex_agent.create_dataset(size=2048, name="dataset-1", seed=1234)
cid2 = await codex_agent.create_dataset(size=2048, name="dataset-1", seed=1234) cid2 = await codex_agent.create_dataset(size=2048, name="dataset-1", seed=1234)
cid3 = await codex_agent.create_dataset(size=2048, name="dataset-1", seed=1235) cid3 = await codex_agent.create_dataset(size=2048, name="dataset-1", seed=1235)
assert cid1 == cid2 assert cid1 == cid2
assert cid1 != cid3 assert cid1 != cid3
@pytest.mark.asyncio
async def test_should_report_download_progress():
    """Progress reported by the handle should follow the bytes fed into the stream."""
    client = FakeCodexClient()
    codex_agent = CodexAgent(client)

    # The dataset is 1000 bytes, so every 10 bytes fed below is 1% of progress.
    cid = await codex_agent.create_dataset(size=1000, name="dataset-1", seed=1234)
    download_stream = client.create_download_stream(cid)

    handle = await codex_agent.download(cid)

    assert handle.progress() == 0

    for i in range(100):
        # First half of the step: feed 5 bytes and wait for them to be consumed.
        download_stream.feed_data(b"0" * 5)
        assert len(download_stream._buffer) == 5
        assert await await_predicate_async(
            lambda: round(handle.progress() * 100) == i, timeout=5
        )
        assert await await_predicate_async(
            lambda: len(download_stream._buffer) == 0, timeout=5
        )

        # Second half of the step: 5 more bytes complete the next percent.
        download_stream.feed_data(b"0" * 5)
        assert len(download_stream._buffer) == 5
        assert await await_predicate_async(
            lambda: round(handle.progress() * 100) == (i + 1), timeout=5
        )
        assert await await_predicate_async(
            lambda: len(download_stream._buffer) == 0, timeout=5
        )

    # Signal EOF so the background download task can finish, then wait for it.
    download_stream.feed_eof()
    await handle.download_task
@pytest.mark.asyncio
async def test_should_raise_exception_on_progress_query_if_download_fails():
    """A failed download should surface its error when progress is queried."""
    client = FakeCodexClient()
    codex_agent = CodexAgent(client)

    cid = await codex_agent.create_dataset(size=1000, name="dataset-1", seed=1234)
    download_stream = client.create_download_stream(cid)

    handle = await codex_agent.download(cid)

    # EOF arrives before any of the 1000 expected bytes were fed.
    download_stream.feed_eof()

    with pytest.raises(EOFError):

        # Always returns False, so polling continues until progress() raises
        # (or the timeout expires).
        def _predicate():
            handle.progress()
            return False

        await await_predicate_async(_predicate, timeout=5)

View File

@ -2,7 +2,7 @@ from io import BytesIO
import pytest import pytest
from benchmarks.core.utils import random_data from benchmarks.core.utils.random import random_data
@pytest.fixture @pytest.fixture

View File

@ -1,11 +1,43 @@
import asyncio
from concurrent import futures from concurrent import futures
from concurrent.futures.thread import ThreadPoolExecutor from concurrent.futures.thread import ThreadPoolExecutor
from queue import Queue from queue import Queue
from typing import Iterable, Iterator, List, cast from time import time, sleep
from typing import Iterable, Iterator, List, cast, Awaitable, Callable
from typing_extensions import TypeVar from typing_extensions import TypeVar
def await_predicate(
    predicate: Callable[[], bool], timeout: float = 0, polling_interval: float = 0
) -> bool:
    """Poll ``predicate`` until it returns a truthy value or ``timeout`` expires.

    :param predicate: zero-argument callable to poll.
    :param timeout: maximum time to wait, in seconds; ``0`` means wait forever.
    :param polling_interval: seconds to sleep between successive polls.
    :return: ``True`` if the predicate became true in time, ``False`` on timeout.
    """
    # Measure elapsed time with a monotonic clock so that wall-clock
    # adjustments cannot shorten or extend the timeout.
    from time import monotonic

    deadline = None if timeout == 0 else monotonic() + timeout
    while deadline is None or monotonic() <= deadline:
        if predicate():
            return True
        sleep(polling_interval)

    return False
async def await_predicate_async(
predicate: Callable[[], Awaitable[bool]] | Callable[[], bool],
timeout: float = 0,
polling_interval: float = 0,
) -> bool:
start_time = time()
while (timeout == 0) or ((time() - start_time) <= timeout):
if asyncio.iscoroutinefunction(predicate):
if await predicate():
return True
else:
if predicate():
return True
await asyncio.sleep(polling_interval)
return False
class _End: class _End:
pass pass

View File

@ -8,8 +8,8 @@ from typing import List, Optional
from typing_extensions import Generic, TypeVar from typing_extensions import Generic, TypeVar
from benchmarks.core.concurrency import await_predicate
from benchmarks.core.config import Builder from benchmarks.core.config import Builder
from benchmarks.core.utils import await_predicate
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)

View File

@ -1,6 +1,6 @@
from io import BytesIO from io import BytesIO
from benchmarks.core.utils import random_data from benchmarks.core.utils.random import random_data
def test_should_generate_the_requested_amount_of_bytes(): def test_should_generate_the_requested_amount_of_bytes():

View File

View File

@ -1,18 +1,7 @@
import random import random
from time import time, sleep from typing import Iterator, IO, Optional
from typing import Iterator, Optional, Callable, IO
from benchmarks.core.utils.units import megabytes
def await_predicate(
predicate: Callable[[], bool], timeout: float = 0, polling_interval: float = 0
) -> bool:
start_time = time()
while (timeout == 0) or ((time() - start_time) <= timeout):
if predicate():
return True
sleep(polling_interval)
return False
def sample(n: int) -> Iterator[int]: def sample(n: int) -> Iterator[int]:
@ -26,14 +15,6 @@ def sample(n: int) -> Iterator[int]:
yield p[i] yield p[i]
def kilobytes(n: int) -> int:
return n * 1024
def megabytes(n: int) -> int:
return kilobytes(n) * 1024
def random_data( def random_data(
size: int, outfile: IO, batch_size: int = megabytes(50), seed: Optional[int] = None size: int, outfile: IO, batch_size: int = megabytes(50), seed: Optional[int] = None
): ):

View File

@ -0,0 +1,9 @@
from typing import Protocol
class BaseStreamReader(Protocol):
    """Structural type for the stream readers used to consume downloads.

    Any object providing these three methods satisfies the protocol; no
    inheritance is required.
    """

    # Read from the stream; ``n`` is the number of bytes requested.
    async def read(self, n: int) -> bytes: ...

    # Hand ``data`` to the stream (presumably making it available to ``read``).
    def feed_data(self, data: bytes) -> None: ...

    # Report whether the stream has reached end-of-file.
    def at_eof(self) -> bool: ...

View File

@ -0,0 +1,6 @@
def kilobytes(n: int) -> int:
    """Return the number of bytes in ``n`` kilobytes (1 kilobyte = 1024 bytes)."""
    return 1024 * n


def megabytes(n: int) -> int:
    """Return the number of bytes in ``n`` megabytes (1 megabyte = 1024 kilobytes)."""
    return 1024 * kilobytes(n)

View File

@ -4,7 +4,8 @@ from typing import Optional
from torrentool.torrent import Torrent from torrentool.torrent import Torrent
from benchmarks.core.utils import random_data, megabytes from benchmarks.core.utils.random import random_data
from benchmarks.core.utils.units import megabytes
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)

View File

@ -4,8 +4,8 @@ from typing import Annotated, Optional
from fastapi import FastAPI, Depends, APIRouter, Response from fastapi import FastAPI, Depends, APIRouter, Response
from benchmarks.core.agent import AgentBuilder from benchmarks.core.agent import AgentBuilder
from benchmarks.core.utils.units import megabytes
from benchmarks.core.utils import megabytes
from benchmarks.deluge.agent.agent import DelugeAgent from benchmarks.deluge.agent.agent import DelugeAgent
router = APIRouter() router = APIRouter()

View File

@ -15,7 +15,8 @@ from benchmarks.core.experiments.experiments import (
from benchmarks.core.experiments.iterated_experiment import IteratedExperiment from benchmarks.core.experiments.iterated_experiment import IteratedExperiment
from benchmarks.core.experiments.static_experiment import StaticDisseminationExperiment from benchmarks.core.experiments.static_experiment import StaticDisseminationExperiment
from benchmarks.core.pydantic import Host from benchmarks.core.pydantic import Host
from benchmarks.core.utils import sample from benchmarks.core.utils.random import sample
from benchmarks.deluge.agent.client import DelugeAgentClient from benchmarks.deluge.agent.client import DelugeAgentClient
from benchmarks.deluge.deluge_node import DelugeMeta, DelugeNode from benchmarks.deluge.deluge_node import DelugeMeta, DelugeNode
from benchmarks.deluge.tracker import Tracker from benchmarks.deluge.tracker import Tracker

View File

@ -22,9 +22,10 @@ from tenacity.wait import wait_base
from torrentool.torrent import Torrent from torrentool.torrent import Torrent
from urllib3.util import Url from urllib3.util import Url
from benchmarks.core.concurrency import await_predicate
from benchmarks.core.experiments.experiments import ExperimentComponent from benchmarks.core.experiments.experiments import ExperimentComponent
from benchmarks.core.network import DownloadHandle, Node from benchmarks.core.network import DownloadHandle, Node
from benchmarks.core.utils import await_predicate
from benchmarks.deluge.agent.client import DelugeAgentClient from benchmarks.deluge.agent.client import DelugeAgentClient
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)

View File

@ -5,7 +5,7 @@ from typing import Generator
import pytest import pytest
from urllib3.util import parse_url from urllib3.util import parse_url
from benchmarks.core.utils import await_predicate from benchmarks.core.concurrency import await_predicate
from benchmarks.deluge.agent.client import DelugeAgentClient from benchmarks.deluge.agent.client import DelugeAgentClient
from benchmarks.deluge.deluge_node import DelugeNode from benchmarks.deluge.deluge_node import DelugeNode
from benchmarks.deluge.tracker import Tracker from benchmarks.deluge.tracker import Tracker

View File

@ -1,7 +1,8 @@
import pytest import pytest
from tenacity import wait_incrementing, stop_after_attempt, RetryError from tenacity import wait_incrementing, stop_after_attempt, RetryError
from benchmarks.core.utils import megabytes, await_predicate from benchmarks.core.concurrency import await_predicate
from benchmarks.core.utils.units import megabytes
from benchmarks.deluge.deluge_node import DelugeNode, DelugeMeta, ResilientCallWrapper from benchmarks.deluge.deluge_node import DelugeNode, DelugeMeta, ResilientCallWrapper
from benchmarks.deluge.tracker import Tracker from benchmarks.deluge.tracker import Tracker

View File

@ -2,7 +2,8 @@ import pytest
from benchmarks.core.experiments.experiments import ExperimentEnvironment from benchmarks.core.experiments.experiments import ExperimentEnvironment
from benchmarks.core.experiments.static_experiment import StaticDisseminationExperiment from benchmarks.core.experiments.static_experiment import StaticDisseminationExperiment
from benchmarks.core.utils import megabytes from benchmarks.core.utils.units import megabytes
from benchmarks.deluge.deluge_node import DelugeMeta from benchmarks.deluge.deluge_node import DelugeMeta
from benchmarks.deluge.tests.test_deluge_node import assert_is_seed from benchmarks.deluge.tests.test_deluge_node import assert_is_seed

View File

@ -6,7 +6,7 @@ from typing import Dict, Any
import pytest import pytest
from elasticsearch import Elasticsearch from elasticsearch import Elasticsearch
from benchmarks.core.utils import await_predicate from benchmarks.core.concurrency import await_predicate
def _json_data(data: str) -> Dict[str, Any]: def _json_data(data: str) -> Dict[str, Any]: