mirror of
https://github.com/logos-storage/bittorrent-benchmarks.git
synced 2026-01-05 22:43:11 +00:00
feat: add tracking of download progress via synchronous API at the Codex agent
This commit is contained in:
parent
ec44588fab
commit
cb941b859f
@ -1,15 +1,74 @@
|
||||
import asyncio
|
||||
from asyncio import Task
|
||||
from pathlib import Path
|
||||
from tempfile import TemporaryDirectory
|
||||
from typing import Optional
|
||||
|
||||
from benchmarks.codex.client import CodexClient
|
||||
from benchmarks.core.utils import random_data
|
||||
from benchmarks.codex.client import CodexClient, Manifest
|
||||
from benchmarks.core.utils.random import random_data
|
||||
from benchmarks.core.utils.streams import BaseStreamReader
|
||||
|
||||
Cid = str
|
||||
|
||||
EMPTY_STREAM_BACKOFF = 0.1
|
||||
|
||||
|
||||
class DownloadHandle:
    """Handle on a single dataset download that can be polled for progress.

    The handle drains ``download_stream`` inside an asyncio task (see
    :meth:`begin_download`) while counting the bytes read, so that
    :meth:`progress` can be queried synchronously at any time.
    """

    def __init__(
        self,
        parent: "CodexAgent",
        manifest: "Manifest",
        download_stream: "BaseStreamReader",
        read_increment: float = 0.01,
    ):
        """
        :param parent: the agent that created this handle.
        :param manifest: manifest of the dataset being downloaded; its
            ``datasetSize`` is the number of bytes the download must produce.
        :param download_stream: stream from which the dataset bytes are read.
        :param read_increment: fraction of ``datasetSize`` requested per read.
        """
        self.parent = parent
        self.manifest = manifest
        self.bytes_downloaded = 0
        self.read_increment = read_increment
        self.download_stream = download_stream
        self.download_task: Optional[Task[None]] = None

    def begin_download(self) -> Task:
        """Start draining the stream in a background task and return that task."""
        self.download_task = asyncio.create_task(self._download_loop())
        return self.download_task

    async def _download_loop(self):
        """Read the stream until EOF, updating ``bytes_downloaded`` as we go.

        :raises EOFError: if the stream ends before ``datasetSize`` bytes arrive.
        :raises ValueError: if the stream delivers more than ``datasetSize`` bytes.
        """
        expected_size = self.manifest.datasetSize
        # For small datasets the increment truncates to zero; a zero-sized read
        # returns no data, so the loop would never make progress.
        step_size = max(1, int(expected_size * self.read_increment))

        while not self.download_stream.at_eof():
            remaining = expected_size - self.bytes_downloaded
            # Always ask for at least one byte, even once everything expected has
            # arrived: this is what lets us observe EOF, or surplus data, instead
            # of issuing empty reads forever.
            step = max(1, min(step_size, remaining))
            bytes_read = await self.download_stream.read(step)
            # We actually have no guarantees that an empty read means EOF, so we
            # just back off a bit.
            if not bytes_read:
                await asyncio.sleep(EMPTY_STREAM_BACKOFF)
            self.bytes_downloaded += len(bytes_read)

        if self.bytes_downloaded < expected_size:
            raise EOFError(
                f"Got EOF too early: download size ({self.bytes_downloaded}) was less "
                f"than expected ({self.manifest.datasetSize})."
            )

        if self.bytes_downloaded > expected_size:
            raise ValueError(
                f"Download size ({self.bytes_downloaded}) was greater than expected "
                f"({self.manifest.datasetSize})."
            )

    def progress(self) -> float:
        """Return the downloaded fraction of the dataset.

        Returns 0 before :meth:`begin_download` has been called. If the
        download task has finished with an error, that error is re-raised here.
        """
        if self.download_task is None:
            return 0.0

        if self.download_task.done():
            # This will bubble exceptions up, if any.
            self.download_task.result()

        expected_size = self.manifest.datasetSize
        if expected_size == 0:
            # An empty dataset has nothing left to transfer; avoid dividing by zero.
            return 1.0

        return self.bytes_downloaded / expected_size
|
||||
|
||||
|
||||
class CodexAgent:
|
||||
def __init__(self, client: CodexClient):
|
||||
def __init__(self, client: CodexClient) -> None:
|
||||
self.client = client
|
||||
|
||||
async def create_dataset(self, name: str, size: int, seed: Optional[int]) -> Cid:
|
||||
@ -23,3 +82,14 @@ class CodexAgent:
|
||||
return await self.client.upload(
|
||||
name=name, mime_type="application/octet-stream", content=infile
|
||||
)
|
||||
|
||||
async def download(self, cid: Cid) -> DownloadHandle:
    """Start downloading the dataset identified by ``cid``.

    Fetches the manifest and then the download stream from the client, wraps
    both in a :class:`DownloadHandle`, starts its download task and returns
    the handle.
    """
    manifest = await self.client.get_manifest(cid)
    stream = await self.client.download(cid)

    handle = DownloadHandle(self, manifest=manifest, download_stream=stream)
    handle.begin_download()
    return handle
|
||||
|
||||
@ -1,9 +1,12 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import IO
|
||||
|
||||
import aiohttp
|
||||
from pydantic import BaseModel
|
||||
from urllib3.util import Url
|
||||
|
||||
from benchmarks.core.utils.streams import BaseStreamReader
|
||||
|
||||
API_VERSION = "v1"
|
||||
|
||||
Cid = str
|
||||
@ -20,7 +23,21 @@ class Manifest(BaseModel):
|
||||
protected: bool
|
||||
|
||||
|
||||
class CodexClient:
|
||||
class CodexClient(ABC):
    """Abstract interface for the Codex operations used by this package."""

    @abstractmethod
    async def upload(self, name: str, mime_type: str, content: IO) -> Cid:
        """Upload ``content`` under ``name`` with the given MIME type; return its CID."""

    @abstractmethod
    async def get_manifest(self, cid: Cid) -> Manifest:
        """Return the :class:`Manifest` for ``cid``."""

    @abstractmethod
    async def download(self, cid: Cid) -> BaseStreamReader:
        """Return a stream reader for the content of ``cid``."""
|
||||
|
||||
|
||||
class CodexClientImpl(CodexClient):
|
||||
"""A lightweight async wrapper built around the Codex REST API."""
|
||||
|
||||
def __init__(self, codex_api_url: Url):
|
||||
@ -55,3 +72,15 @@ class CodexClient:
|
||||
cid = response_contents.pop("cid")
|
||||
|
||||
return Manifest.model_validate(dict(cid=cid, **response_contents["manifest"]))
|
||||
|
||||
async def download(self, cid: Cid) -> BaseStreamReader:
    """Issue a GET to the node's network-download endpoint for ``cid``.

    Raises whatever ``raise_for_status`` raises when the response status
    signals an error; otherwise returns the response's content stream.
    """
    # NOTE(review): the session is closed as soon as this method returns, while
    # the caller still holds the content stream — confirm the stream remains
    # readable after the session (and its connector) has been closed.
    async with aiohttp.ClientSession() as session:
        download_url = self.codex_api_url._replace(
            path=f"/api/codex/v1/data/{cid}/network/download"
        ).url
        response = await session.get(download_url)
        response.raise_for_status()
        return response.content
|
||||
|
||||
@ -3,12 +3,12 @@ import os
|
||||
import pytest
|
||||
from urllib3.util import parse_url
|
||||
|
||||
from benchmarks.codex.client import CodexClient
|
||||
from benchmarks.codex.client import CodexClientImpl
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def codex_client_1():
|
||||
# TODO wipe data between tests
|
||||
return CodexClient(
|
||||
return CodexClientImpl(
|
||||
parse_url(f"http://{os.environ.get('CODEX_NODE_1', 'localhost')}:8091")
|
||||
)
|
||||
|
||||
@ -1,30 +1,118 @@
|
||||
from asyncio import StreamReader
|
||||
from typing import IO, Dict
|
||||
|
||||
import pytest
|
||||
|
||||
from benchmarks.codex.agent import CodexAgent
|
||||
from benchmarks.codex.client import CodexClient, Cid, Manifest
|
||||
from benchmarks.core.concurrency import await_predicate_async
|
||||
from benchmarks.core.utils.streams import BaseStreamReader
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def codex_agent(codex_client_1):
|
||||
return CodexAgent(codex_client_1)
|
||||
class FakeCodexClient(CodexClient):
    """In-memory ``CodexClient`` for tests.

    Manifests are kept in ``storage`` and download streams in ``streams``,
    both keyed by CID. A stream must be registered with
    :meth:`create_download_stream` before :meth:`download` is called for
    that CID.
    """

    def __init__(self) -> None:
        self.storage: Dict[Cid, Manifest] = {}
        self.streams: Dict[Cid, StreamReader] = {}

    async def upload(self, name: str, mime_type: str, content: IO) -> Cid:
        """Read ``content`` fully, record a manifest for it and return its CID."""
        payload = content.read()
        # The CID is derived from the hash of the content, so identical content
        # maps to the same CID.
        cid = f"Qm{hash(payload)}"
        manifest = Manifest(
            cid=cid,
            datasetSize=len(payload),
            mimetype=mime_type,
            blockSize=1,
            filename=name,
            treeCid="",
            uploadedAt=0,
            protected=False,
        )
        self.storage[cid] = manifest
        return cid

    async def get_manifest(self, cid: Cid) -> Manifest:
        """Return the manifest recorded for ``cid`` (``KeyError`` if unknown)."""
        return self.storage[cid]

    def create_download_stream(self, cid: Cid) -> StreamReader:
        """Register and return a fresh ``StreamReader`` that tests feed by hand."""
        self.streams[cid] = StreamReader()
        return self.streams[cid]

    async def download(self, cid: Cid) -> BaseStreamReader:
        """Return the stream previously registered for ``cid``."""
        return self.streams[cid]
|
||||
|
||||
|
||||
@pytest.mark.codex_integration
|
||||
@pytest.mark.asyncio
|
||||
async def test_should_create_dataset(codex_agent: CodexAgent):
|
||||
async def test_should_create_dataset_of_right_size():
|
||||
codex_agent = CodexAgent(FakeCodexClient())
|
||||
cid = await codex_agent.create_dataset(size=1024, name="dataset-1", seed=1234)
|
||||
manifest = await codex_agent.client.get_manifest(cid)
|
||||
|
||||
dataset = await codex_agent.client.get_manifest(cid)
|
||||
|
||||
assert dataset.cid == cid
|
||||
assert dataset.datasetSize == 1024
|
||||
assert manifest.datasetSize == 1024
|
||||
|
||||
|
||||
@pytest.mark.codex_integration
|
||||
@pytest.mark.asyncio
|
||||
async def test_same_seed_creates_same_cid(codex_agent: CodexAgent):
|
||||
async def test_same_seed_creates_same_cid():
|
||||
codex_agent = CodexAgent(FakeCodexClient())
|
||||
|
||||
cid1 = await codex_agent.create_dataset(size=2048, name="dataset-1", seed=1234)
|
||||
cid2 = await codex_agent.create_dataset(size=2048, name="dataset-1", seed=1234)
|
||||
cid3 = await codex_agent.create_dataset(size=2048, name="dataset-1", seed=1235)
|
||||
|
||||
assert cid1 == cid2
|
||||
assert cid1 != cid3
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_should_report_download_progress():
    """Feed a 1000-byte download in 5-byte chunks and check the reported progress."""
    client = FakeCodexClient()
    codex_agent = CodexAgent(client)

    cid = await codex_agent.create_dataset(size=1000, name="dataset-1", seed=1234)
    download_stream = client.create_download_stream(cid)

    handle = await codex_agent.download(cid)

    assert handle.progress() == 0

    async def feed_chunk_and_expect(expected_percent: int) -> None:
        # Push 5 bytes, then wait for the rounded percentage to match and for
        # the download loop to drain the stream buffer.
        download_stream.feed_data(b"0" * 5)
        assert len(download_stream._buffer) == 5
        assert await await_predicate_async(
            lambda: round(handle.progress() * 100) == expected_percent, timeout=5
        )
        assert await await_predicate_async(
            lambda: len(download_stream._buffer) == 0, timeout=5
        )

    # Two 5-byte chunks per iteration: 100 iterations deliver all 1000 bytes.
    for i in range(100):
        await feed_chunk_and_expect(i)
        await feed_chunk_and_expect(i + 1)

    download_stream.feed_eof()
    await handle.download_task
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_should_raise_exception_on_progress_query_if_download_fails():
    """An early EOF must surface as ``EOFError`` when progress is polled."""
    client = FakeCodexClient()
    codex_agent = CodexAgent(client)

    cid = await codex_agent.create_dataset(size=1000, name="dataset-1", seed=1234)
    download_stream = client.create_download_stream(cid)

    handle = await codex_agent.download(cid)

    # Close the stream before any data was delivered.
    download_stream.feed_eof()

    def _poll_progress():
        # Never satisfied: we only poll so that progress() gets the chance to
        # re-raise the download task's failure.
        handle.progress()
        return False

    with pytest.raises(EOFError):
        await await_predicate_async(_poll_progress, timeout=5)
|
||||
|
||||
@ -2,7 +2,7 @@ from io import BytesIO
|
||||
|
||||
import pytest
|
||||
|
||||
from benchmarks.core.utils import random_data
|
||||
from benchmarks.core.utils.random import random_data
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
|
||||
@ -1,11 +1,43 @@
|
||||
import asyncio
|
||||
from concurrent import futures
|
||||
from concurrent.futures.thread import ThreadPoolExecutor
|
||||
from queue import Queue
|
||||
from typing import Iterable, Iterator, List, cast
|
||||
from time import time, sleep
|
||||
from typing import Iterable, Iterator, List, cast, Awaitable, Callable
|
||||
|
||||
from typing_extensions import TypeVar
|
||||
|
||||
|
||||
def await_predicate(
    predicate: Callable[[], bool], timeout: float = 0, polling_interval: float = 0
) -> bool:
    """Block until ``predicate`` returns a truthy value or ``timeout`` elapses.

    :param predicate: zero-argument callable polled on every iteration.
    :param timeout: seconds to keep polling; ``0`` means poll indefinitely.
    :param polling_interval: seconds to sleep between polls.
    :return: ``True`` if the predicate was satisfied, ``False`` on timeout.
    """
    wait_forever = timeout == 0
    started_at = time()

    while wait_forever or (time() - started_at) <= timeout:
        if predicate():
            return True
        sleep(polling_interval)

    return False
|
||||
|
||||
|
||||
async def await_predicate_async(
|
||||
predicate: Callable[[], Awaitable[bool]] | Callable[[], bool],
|
||||
timeout: float = 0,
|
||||
polling_interval: float = 0,
|
||||
) -> bool:
|
||||
start_time = time()
|
||||
while (timeout == 0) or ((time() - start_time) <= timeout):
|
||||
if asyncio.iscoroutinefunction(predicate):
|
||||
if await predicate():
|
||||
return True
|
||||
else:
|
||||
if predicate():
|
||||
return True
|
||||
await asyncio.sleep(polling_interval)
|
||||
|
||||
return False
|
||||
|
||||
|
||||
class _End:
|
||||
pass
|
||||
|
||||
|
||||
@ -8,8 +8,8 @@ from typing import List, Optional
|
||||
|
||||
from typing_extensions import Generic, TypeVar
|
||||
|
||||
from benchmarks.core.concurrency import await_predicate
|
||||
from benchmarks.core.config import Builder
|
||||
from benchmarks.core.utils import await_predicate
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
from io import BytesIO
|
||||
|
||||
from benchmarks.core.utils import random_data
|
||||
from benchmarks.core.utils.random import random_data
|
||||
|
||||
|
||||
def test_should_generate_the_requested_amount_of_bytes():
|
||||
|
||||
0
benchmarks/core/utils/__init__.py
Normal file
0
benchmarks/core/utils/__init__.py
Normal file
@ -1,18 +1,7 @@
|
||||
import random
|
||||
from time import time, sleep
|
||||
from typing import Iterator, Optional, Callable, IO
|
||||
from typing import Iterator, IO, Optional
|
||||
|
||||
|
||||
def await_predicate(
|
||||
predicate: Callable[[], bool], timeout: float = 0, polling_interval: float = 0
|
||||
) -> bool:
|
||||
start_time = time()
|
||||
while (timeout == 0) or ((time() - start_time) <= timeout):
|
||||
if predicate():
|
||||
return True
|
||||
sleep(polling_interval)
|
||||
|
||||
return False
|
||||
from benchmarks.core.utils.units import megabytes
|
||||
|
||||
|
||||
def sample(n: int) -> Iterator[int]:
|
||||
@ -26,14 +15,6 @@ def sample(n: int) -> Iterator[int]:
|
||||
yield p[i]
|
||||
|
||||
|
||||
def kilobytes(n: int) -> int:
|
||||
return n * 1024
|
||||
|
||||
|
||||
def megabytes(n: int) -> int:
|
||||
return kilobytes(n) * 1024
|
||||
|
||||
|
||||
def random_data(
|
||||
size: int, outfile: IO, batch_size: int = megabytes(50), seed: Optional[int] = None
|
||||
):
|
||||
9
benchmarks/core/utils/streams.py
Normal file
9
benchmarks/core/utils/streams.py
Normal file
@ -0,0 +1,9 @@
|
||||
from typing import Protocol
|
||||
|
||||
|
||||
class BaseStreamReader(Protocol):
    """Structural type describing the stream-reader operations this project uses."""

    async def read(self, n: int) -> bytes:
        """Read bytes from the stream, given a requested size ``n``."""
        ...

    def feed_data(self, data: bytes) -> None:
        """Feed ``data`` into the stream."""
        ...

    def at_eof(self) -> bool:
        """Report whether the stream is at EOF."""
        ...
|
||||
6
benchmarks/core/utils/units.py
Normal file
6
benchmarks/core/utils/units.py
Normal file
@ -0,0 +1,6 @@
|
||||
def kilobytes(n: int) -> int:
    """Return ``n`` multiplied by 1024, i.e. ``n`` KiB expressed in bytes."""
    size_in_bytes = n * 1024
    return size_in_bytes
|
||||
|
||||
|
||||
def megabytes(n: int) -> int:
    """Return ``n`` multiplied by 1024 * 1024, i.e. ``n`` MiB expressed in bytes."""
    size_in_kilobytes = kilobytes(n)
    return size_in_kilobytes * 1024
|
||||
@ -4,7 +4,8 @@ from typing import Optional
|
||||
|
||||
from torrentool.torrent import Torrent
|
||||
|
||||
from benchmarks.core.utils import random_data, megabytes
|
||||
from benchmarks.core.utils.random import random_data
|
||||
from benchmarks.core.utils.units import megabytes
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@ -4,8 +4,8 @@ from typing import Annotated, Optional
|
||||
from fastapi import FastAPI, Depends, APIRouter, Response
|
||||
|
||||
from benchmarks.core.agent import AgentBuilder
|
||||
from benchmarks.core.utils.units import megabytes
|
||||
|
||||
from benchmarks.core.utils import megabytes
|
||||
from benchmarks.deluge.agent.agent import DelugeAgent
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
@ -15,7 +15,8 @@ from benchmarks.core.experiments.experiments import (
|
||||
from benchmarks.core.experiments.iterated_experiment import IteratedExperiment
|
||||
from benchmarks.core.experiments.static_experiment import StaticDisseminationExperiment
|
||||
from benchmarks.core.pydantic import Host
|
||||
from benchmarks.core.utils import sample
|
||||
from benchmarks.core.utils.random import sample
|
||||
|
||||
from benchmarks.deluge.agent.client import DelugeAgentClient
|
||||
from benchmarks.deluge.deluge_node import DelugeMeta, DelugeNode
|
||||
from benchmarks.deluge.tracker import Tracker
|
||||
|
||||
@ -22,9 +22,10 @@ from tenacity.wait import wait_base
|
||||
from torrentool.torrent import Torrent
|
||||
from urllib3.util import Url
|
||||
|
||||
from benchmarks.core.concurrency import await_predicate
|
||||
from benchmarks.core.experiments.experiments import ExperimentComponent
|
||||
from benchmarks.core.network import DownloadHandle, Node
|
||||
from benchmarks.core.utils import await_predicate
|
||||
|
||||
from benchmarks.deluge.agent.client import DelugeAgentClient
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -5,7 +5,7 @@ from typing import Generator
|
||||
import pytest
|
||||
from urllib3.util import parse_url
|
||||
|
||||
from benchmarks.core.utils import await_predicate
|
||||
from benchmarks.core.concurrency import await_predicate
|
||||
from benchmarks.deluge.agent.client import DelugeAgentClient
|
||||
from benchmarks.deluge.deluge_node import DelugeNode
|
||||
from benchmarks.deluge.tracker import Tracker
|
||||
|
||||
@ -1,7 +1,8 @@
|
||||
import pytest
|
||||
from tenacity import wait_incrementing, stop_after_attempt, RetryError
|
||||
|
||||
from benchmarks.core.utils import megabytes, await_predicate
|
||||
from benchmarks.core.concurrency import await_predicate
|
||||
from benchmarks.core.utils.units import megabytes
|
||||
from benchmarks.deluge.deluge_node import DelugeNode, DelugeMeta, ResilientCallWrapper
|
||||
from benchmarks.deluge.tracker import Tracker
|
||||
|
||||
|
||||
@ -2,7 +2,8 @@ import pytest
|
||||
|
||||
from benchmarks.core.experiments.experiments import ExperimentEnvironment
|
||||
from benchmarks.core.experiments.static_experiment import StaticDisseminationExperiment
|
||||
from benchmarks.core.utils import megabytes
|
||||
from benchmarks.core.utils.units import megabytes
|
||||
|
||||
from benchmarks.deluge.deluge_node import DelugeMeta
|
||||
from benchmarks.deluge.tests.test_deluge_node import assert_is_seed
|
||||
|
||||
|
||||
@ -6,7 +6,7 @@ from typing import Dict, Any
|
||||
import pytest
|
||||
from elasticsearch import Elasticsearch
|
||||
|
||||
from benchmarks.core.utils import await_predicate
|
||||
from benchmarks.core.concurrency import await_predicate
|
||||
|
||||
|
||||
def _json_data(data: str) -> Dict[str, Any]:
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user