diff --git a/benchmarks/cli.py b/benchmarks/cli.py index 9bef296..5340bfb 100644 --- a/benchmarks/cli.py +++ b/benchmarks/cli.py @@ -18,9 +18,9 @@ def _parse_config(config: Path): print(f'Config file {config} does not exist.') sys.exit(-1) - with config.open(encoding='utf-8') as config: + with config.open(encoding='utf-8') as infile: try: - return parser.parse(config) + return parser.parse(infile) except ValidationError as e: print(f'There were errors parsing the config file.') for error in e.errors(): @@ -33,9 +33,9 @@ def list(config: Path): """ Lists the experiments available in CONFIG. """ - parsed = _parse_config(config) + experiments = _parse_config(config) print(f'Available experiments in {config}:') - for experiment in parsed.keys(): + for experiment in experiments.keys(): print(f' - {experiment}') @@ -44,11 +44,11 @@ def run(config: Path, experiment: str): """ Runs the experiment with name EXPERIMENT. """ - parsed = _parse_config(config) - if experiment not in parsed: + experiments = _parse_config(config) + if experiment not in experiments: print(f'Experiment {experiment} not found in {config}.') sys.exit(-1) - parsed[experiment].run() + experiments[experiment].run() if __name__ == '__main__': diff --git a/benchmarks/core/config.py b/benchmarks/core/config.py index 7ef45a5..c0da879 100644 --- a/benchmarks/core/config.py +++ b/benchmarks/core/config.py @@ -3,10 +3,10 @@ import os import re from abc import abstractmethod from io import TextIOBase -from typing import Annotated, Type, Dict, TextIO +from typing import Annotated, Type, Dict, TextIO, Callable, cast import yaml -from pydantic import BaseModel, IPvAnyAddress, AfterValidator +from pydantic import BaseModel, IPvAnyAddress, AfterValidator, TypeAdapter from typing_extensions import Generic, overload from benchmarks.core.experiments.experiments import TExperiment, Experiment @@ -39,9 +39,7 @@ def is_valid_domain_name(domain_name: str): DomainName = Annotated[str, AfterValidator(is_valid_domain_name)] - -class Host(BaseModel): - address: IPvAnyAddress | DomainName +type Host = IPvAnyAddress | DomainName class ExperimentBuilder(ConfigModel, Generic[TExperiment]): @@ -61,20 +59,21 @@ class ConfigParser: def __init__(self): self.root_tags = {} - def register(self, root: Type[ExperimentBuilder[Experiment]]): + def register(self, root: Type[ExperimentBuilder[TExperiment]]): name = root.__name__ - alias = root.model_config.get('alias_generator', lambda x: x)(name) + alias = cast(Callable[[str], str], + root.model_config.get('alias_generator', lambda x: x))(name) self.root_tags[alias] = root @overload - def parse(self, data: dict) -> Dict[str, ExperimentBuilder[Experiment]]: + def parse(self, data: dict) -> Dict[str, ExperimentBuilder[TExperiment]]: ... @overload - def parse(self, data: TextIO) -> Dict[str, ExperimentBuilder[Experiment]]: + def parse(self, data: TextIO) -> Dict[str, ExperimentBuilder[TExperiment]]: ... - def parse(self, data: dict | TextIO) -> Dict[str, ExperimentBuilder[Experiment]]: + def parse(self, data: dict | TextIO) -> Dict[str, ExperimentBuilder[TExperiment]]: if isinstance(data, TextIOBase): entries = yaml.safe_load(os.path.expandvars(data.read())) else: diff --git a/benchmarks/core/experiments/static_experiment.py b/benchmarks/core/experiments/static_experiment.py index c81f8bd..4bbb46e 100644 --- a/benchmarks/core/experiments/static_experiment.py +++ b/benchmarks/core/experiments/static_experiment.py @@ -30,10 +30,12 @@ class StaticDisseminationExperiment(Generic[TNetworkHandle, TInitialMetadata], E ) with self.data as (meta, data): - meta_or_cid = meta + cid = None for node in seeders: - meta_or_cid = node.seed(data, meta_or_cid) + cid = node.seed(data, meta if cid is None else cid) - downloads = [node.leech(meta_or_cid) for node in leechers] + assert cid is not None # to please mypy + + downloads = [node.leech(cid) for node in leechers] for download in downloads: download.await_for_completion() diff --git a/benchmarks/core/tests/test_config.py b/benchmarks/core/tests/test_config.py index 6aa3d71..c1c79d8 100644 --- a/benchmarks/core/tests/test_config.py +++ b/benchmarks/core/tests/test_config.py @@ -5,37 +5,37 @@ from typing import cast import pytest import yaml -from pydantic import ValidationError, BaseModel +from pydantic import ValidationError, TypeAdapter from benchmarks.core.config import Host, DomainName, ConfigParser, ConfigModel def test_should_parse_ipv4_address(): - h = Host(address='192.168.1.1') - assert h.address == IPv4Address('192.168.1.1') + h = TypeAdapter(Host).validate_strings('192.168.1.1') + assert h == IPv4Address('192.168.1.1') def test_should_parse_ipv6_address(): - h = Host(address='2001:0000:130F:0000:0000:09C0:876A:130B') - assert h.address == IPv6Address('2001:0000:130F:0000:0000:09C0:876A:130B') + h = TypeAdapter(Host).validate_strings('2001:0000:130F:0000:0000:09C0:876A:130B') + assert h == IPv6Address('2001:0000:130F:0000:0000:09C0:876A:130B') def test_should_parse_simple_dns_names(): - h = Host(address='node-1.local.svc') - assert h.address == DomainName('node-1.local.svc') + h = TypeAdapter(Host).validate_strings('node-1.local.svc') + assert h == DomainName('node-1.local.svc') def test_should_parse_localhost(): - h = Host(address='localhost') - assert h.address == DomainName('localhost') + h = TypeAdapter(Host).validate_strings('localhost') + assert h == DomainName('localhost') def test_should_return_correct_string_representation_for_addresses(): - h = Host(address='localhost') - assert str(h.address) == 'localhost' + h = TypeAdapter(Host).validate_strings('localhost') + assert h == DomainName('localhost') - h = Host(address='192.168.1.1') - assert str(h.address) == '192.168.1.1' + h = TypeAdapter(Host).validate_strings('192.168.1.1') + assert h == IPv4Address('192.168.1.1') def test_should_fail_invalid_names(): @@ -50,7 +50,7 @@ def test_should_fail_invalid_names(): for invalid_name in invalid_names: with pytest.raises(ValidationError): - Host(address=invalid_name) + TypeAdapter(Host).validate_strings(invalid_name) class Root1(ConfigModel): diff --git a/benchmarks/deluge/config.py b/benchmarks/deluge/config.py index e74594b..6e8dae7 100644 --- a/benchmarks/deluge/config.py +++ b/benchmarks/deluge/config.py @@ -30,7 +30,7 @@ class DelugeNodeSetConfig(BaseModel): def expand_nodes(self): self.nodes = [ DelugeNodeConfig( - address=Host(address=self.address.format(node_index=str(i))), + address=self.address.format(node_index=str(i)), daemon_port=self.daemon_port, listen_ports=self.listen_ports, ) @@ -59,7 +59,7 @@ class DelugeExperimentConfig(ExperimentBuilder[DelugeDisseminationExperiment]): name=f'deluge-{i}', volume=self.shared_volume_path / f'deluge-{i}', daemon_port=node.daemon_port, - daemon_address=str(node.address.address), + daemon_address=str(node.address), ) for i, node in enumerate(nodes) ], diff --git a/benchmarks/deluge/tests/test_config.py b/benchmarks/deluge/tests/test_config.py index d8a10c9..e7c5ed2 100644 --- a/benchmarks/deluge/tests/test_config.py +++ b/benchmarks/deluge/tests/test_config.py @@ -4,7 +4,6 @@ from unittest.mock import patch import yaml -from benchmarks.core.config import Host from benchmarks.deluge.config import DelugeNodeSetConfig, DelugeNodeConfig, DelugeExperimentConfig from benchmarks.deluge.deluge_node import DelugeNode @@ -19,22 +18,22 @@ def test_should_expand_node_sets_into_simple_nodes(): assert nodeset.nodes == [ DelugeNodeConfig( - address=Host(address='deluge-1.local.svc'), + address='deluge-1.local.svc', daemon_port=6080, listen_ports=[6081, 6082], ), DelugeNodeConfig( - address=Host(address='deluge-2.local.svc'), + address='deluge-2.local.svc', daemon_port=6080, listen_ports=[6081, 6082], ), DelugeNodeConfig( - address=Host(address='deluge-3.local.svc'), + address='deluge-3.local.svc', daemon_port=6080, listen_ports=[6081, 6082], ), DelugeNodeConfig( - address=Host(address='deluge-4.local.svc'), + address='deluge-4.local.svc', daemon_port=6080, listen_ports=[6081, 6082], ),