mirror of
https://github.com/logos-storage/bittorrent-benchmarks.git
synced 2026-01-07 15:33:10 +00:00
refactor: simplify download progress reporting
This commit is contained in:
parent
bd0ef9ca55
commit
849bcad6c8
@ -1,9 +1,10 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
from asyncio import Task
|
from asyncio import Task
|
||||||
|
from dataclasses import dataclass
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from tempfile import TemporaryDirectory
|
from tempfile import TemporaryDirectory
|
||||||
from typing import Optional
|
from typing import Optional, Dict
|
||||||
|
|
||||||
from benchmarks.codex.agent.codex_client import CodexClient, Manifest
|
from benchmarks.codex.agent.codex_client import CodexClient, Manifest
|
||||||
from benchmarks.codex.logging import CodexDownloadMetric
|
from benchmarks.codex.logging import CodexDownloadMetric
|
||||||
@ -17,6 +18,12 @@ EMPTY_STREAM_BACKOFF = 0.1
|
|||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class DownloadStatus:
|
||||||
|
downloaded: int
|
||||||
|
total: int
|
||||||
|
|
||||||
|
|
||||||
class DownloadHandle:
|
class DownloadHandle:
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
@ -37,6 +44,7 @@ class DownloadHandle:
|
|||||||
return self.download_task
|
return self.download_task
|
||||||
|
|
||||||
async def _download_loop(self):
|
async def _download_loop(self):
|
||||||
|
try:
|
||||||
step_size = int(self.manifest.datasetSize * self.read_increment)
|
step_size = int(self.manifest.datasetSize * self.read_increment)
|
||||||
|
|
||||||
while not self.download_stream.at_eof():
|
while not self.download_stream.at_eof():
|
||||||
@ -47,6 +55,7 @@ class DownloadHandle:
|
|||||||
if not bytes_read:
|
if not bytes_read:
|
||||||
await asyncio.sleep(EMPTY_STREAM_BACKOFF)
|
await asyncio.sleep(EMPTY_STREAM_BACKOFF)
|
||||||
self.bytes_downloaded += len(bytes_read)
|
self.bytes_downloaded += len(bytes_read)
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
CodexDownloadMetric(
|
CodexDownloadMetric(
|
||||||
cid=self.manifest.cid,
|
cid=self.manifest.cid,
|
||||||
@ -66,22 +75,27 @@ class DownloadHandle:
|
|||||||
f"Download size ({self.bytes_downloaded}) was greater than expected "
|
f"Download size ({self.bytes_downloaded}) was greater than expected "
|
||||||
f"({self.manifest.datasetSize})."
|
f"({self.manifest.datasetSize})."
|
||||||
)
|
)
|
||||||
|
finally:
|
||||||
|
self.parent._download_done(self.manifest.cid)
|
||||||
|
|
||||||
def progress(self) -> float:
|
def progress(self) -> DownloadStatus:
|
||||||
if self.download_task is None:
|
if self.download_task is None:
|
||||||
return 0
|
return DownloadStatus(downloaded=0, total=self.manifest.datasetSize)
|
||||||
|
|
||||||
if self.download_task.done():
|
if self.download_task.done():
|
||||||
# This will bubble exceptions up, if any.
|
# This will bubble exceptions up, if any.
|
||||||
self.download_task.result()
|
self.download_task.result()
|
||||||
|
|
||||||
return self.bytes_downloaded / self.manifest.datasetSize
|
return DownloadStatus(
|
||||||
|
downloaded=self.bytes_downloaded, total=self.manifest.datasetSize
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class CodexAgent:
|
class CodexAgent:
|
||||||
def __init__(self, client: CodexClient, node_id: str = "unknown") -> None:
|
def __init__(self, client: CodexClient, node_id: str = "unknown") -> None:
|
||||||
self.client = client
|
self.client = client
|
||||||
self.node_id = node_id
|
self.node_id = node_id
|
||||||
|
self.ongoing_downloads: Dict[Cid, DownloadHandle] = {}
|
||||||
|
|
||||||
async def create_dataset(self, name: str, size: int, seed: Optional[int]) -> Cid:
|
async def create_dataset(self, name: str, size: int, seed: Optional[int]) -> Cid:
|
||||||
with TemporaryDirectory() as td:
|
with TemporaryDirectory() as td:
|
||||||
@ -96,6 +110,9 @@ class CodexAgent:
|
|||||||
)
|
)
|
||||||
|
|
||||||
async def download(self, cid: Cid, read_increment: float = 0.01) -> DownloadHandle:
|
async def download(self, cid: Cid, read_increment: float = 0.01) -> DownloadHandle:
|
||||||
|
if cid in self.ongoing_downloads:
|
||||||
|
return self.ongoing_downloads[cid]
|
||||||
|
|
||||||
handle = DownloadHandle(
|
handle = DownloadHandle(
|
||||||
self,
|
self,
|
||||||
manifest=await self.client.get_manifest(cid),
|
manifest=await self.client.get_manifest(cid),
|
||||||
@ -105,4 +122,8 @@ class CodexAgent:
|
|||||||
|
|
||||||
handle.begin_download()
|
handle.begin_download()
|
||||||
|
|
||||||
|
self.ongoing_downloads[cid] = handle
|
||||||
return handle
|
return handle
|
||||||
|
|
||||||
|
def _download_done(self, cid: Cid):
|
||||||
|
self.ongoing_downloads.pop(cid)
|
||||||
|
|||||||
@ -5,7 +5,7 @@ from unittest.mock import patch
|
|||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from benchmarks.codex.agent.agent import CodexAgent
|
from benchmarks.codex.agent.agent import CodexAgent, DownloadStatus
|
||||||
from benchmarks.codex.agent.codex_client import CodexClient, Cid, Manifest
|
from benchmarks.codex.agent.codex_client import CodexClient, Cid, Manifest
|
||||||
from benchmarks.codex.logging import CodexDownloadMetric
|
from benchmarks.codex.logging import CodexDownloadMetric
|
||||||
from benchmarks.core.concurrency import await_predicate_async
|
from benchmarks.core.concurrency import await_predicate_async
|
||||||
@ -76,30 +76,21 @@ async def test_should_report_download_progress():
|
|||||||
|
|
||||||
handle = await codex_agent.download(cid)
|
handle = await codex_agent.download(cid)
|
||||||
|
|
||||||
assert handle.progress() == 0
|
assert handle.progress() == DownloadStatus(downloaded=0, total=1000)
|
||||||
|
|
||||||
for i in range(100):
|
for i in range(200):
|
||||||
download_stream.feed_data(b"0" * 5)
|
download_stream.feed_data(b"0" * 5)
|
||||||
assert len(download_stream._buffer) == 5
|
|
||||||
assert await await_predicate_async(
|
assert await await_predicate_async(
|
||||||
lambda: round(handle.progress() * 100) == i, timeout=5
|
lambda: handle.progress()
|
||||||
)
|
== DownloadStatus(downloaded=5 * (i + 1), total=1000),
|
||||||
assert await await_predicate_async(
|
timeout=5,
|
||||||
lambda: len(download_stream._buffer) == 0, timeout=5
|
|
||||||
)
|
|
||||||
|
|
||||||
download_stream.feed_data(b"0" * 5)
|
|
||||||
assert len(download_stream._buffer) == 5
|
|
||||||
assert await await_predicate_async(
|
|
||||||
lambda: round(handle.progress() * 100) == (i + 1), timeout=5
|
|
||||||
)
|
|
||||||
assert await await_predicate_async(
|
|
||||||
lambda: len(download_stream._buffer) == 0, timeout=5
|
|
||||||
)
|
)
|
||||||
|
|
||||||
download_stream.feed_eof()
|
download_stream.feed_eof()
|
||||||
await handle.download_task
|
await handle.download_task
|
||||||
|
|
||||||
|
assert handle.progress() == DownloadStatus(downloaded=1000, total=1000)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_should_raise_exception_on_progress_query_if_download_fails():
|
async def test_should_raise_exception_on_progress_query_if_download_fails():
|
||||||
@ -114,16 +105,11 @@ async def test_should_raise_exception_on_progress_query_if_download_fails():
|
|||||||
download_stream.feed_eof()
|
download_stream.feed_eof()
|
||||||
|
|
||||||
with pytest.raises(EOFError):
|
with pytest.raises(EOFError):
|
||||||
|
await handle.download_task
|
||||||
def _predicate():
|
|
||||||
handle.progress()
|
|
||||||
return False
|
|
||||||
|
|
||||||
await await_predicate_async(_predicate, timeout=5)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_should_log_download_progress_as_metric(mock_logger):
|
async def test_should_log_download_progress_as_metric_in_discrete_steps(mock_logger):
|
||||||
logger, output = mock_logger
|
logger, output = mock_logger
|
||||||
|
|
||||||
with patch("benchmarks.codex.agent.agent.logger", logger):
|
with patch("benchmarks.codex.agent.agent.logger", logger):
|
||||||
@ -164,3 +150,23 @@ async def test_should_log_download_progress_as_metric(mock_logger):
|
|||||||
timestamp=metrics[4].timestamp,
|
timestamp=metrics[4].timestamp,
|
||||||
),
|
),
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_should_track_download_handles_and_dispose_of_them_at_the_end():
|
||||||
|
client = FakeCodexClient()
|
||||||
|
codex_agent = CodexAgent(client)
|
||||||
|
|
||||||
|
cid = await codex_agent.create_dataset(size=1000, name="dataset-1", seed=1356)
|
||||||
|
download_stream = client.create_download_stream(cid)
|
||||||
|
|
||||||
|
handle = await codex_agent.download(cid)
|
||||||
|
|
||||||
|
assert codex_agent.ongoing_downloads[cid] == handle
|
||||||
|
|
||||||
|
download_stream.feed_data(b"0" * 1000)
|
||||||
|
download_stream.feed_eof()
|
||||||
|
|
||||||
|
await handle.download_task
|
||||||
|
|
||||||
|
assert cid not in codex_agent.ongoing_downloads
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user