2024-11-27 11:22:07 -03:00
|
|
|
"""Basic utilities for structuring experiment configurations based on Pydantic schemas."""
|
2024-11-28 15:15:05 -03:00
|
|
|
import os
|
2024-11-27 11:22:07 -03:00
|
|
|
import re
|
|
|
|
from abc import abstractmethod
|
2024-11-28 15:15:05 -03:00
|
|
|
from io import TextIOBase
|
2024-11-28 15:52:38 -03:00
|
|
|
from typing import Annotated, Type, Dict, TextIO, Callable, cast
|
2024-11-27 11:22:07 -03:00
|
|
|
|
2024-11-28 15:15:05 -03:00
|
|
|
import yaml
|
2024-11-28 15:52:38 -03:00
|
|
|
from pydantic import BaseModel, IPvAnyAddress, AfterValidator, TypeAdapter
|
2024-11-28 15:15:05 -03:00
|
|
|
from typing_extensions import Generic, overload
|
2024-11-27 11:22:07 -03:00
|
|
|
|
2024-11-28 15:15:05 -03:00
|
|
|
from benchmarks.core.experiments.experiments import TExperiment, Experiment
|
2024-11-27 11:22:07 -03:00
|
|
|
|
|
|
|
|
|
|
|
def drop_config_suffix(name: str) -> str:
    """Return ``name`` without a trailing ``'Config'``; names without that suffix are returned as-is.

    >>> drop_config_suffix('DelugeExperimentConfig')
    'DelugeExperiment'
    """
    return name.removesuffix('Config')
|
|
|
|
|
|
|
|
|
2024-11-28 15:15:05 -03:00
|
|
|
def to_snake_case(name: str) -> str:
    """Convert a CamelCase identifier to snake_case.

    An underscore is inserted before every ASCII uppercase letter that is not the
    first character, then the whole result is lowercased. Runs of capitals are split
    letter by letter, e.g. ``'HTTPServer'`` becomes ``'h_t_t_p_server'``.
    """
    pieces = []
    for index, char in enumerate(name):
        # Only ASCII A-Z triggers a split, and never at the very first position.
        if index > 0 and 'A' <= char <= 'Z':
            pieces.append('_')
        pieces.append(char)
    return ''.join(pieces).lower()
|
|
|
|
|
|
|
|
|
2024-11-27 11:22:07 -03:00
|
|
|
def _config_alias(name: str) -> str:
    """Map a name to its config alias: drop a trailing ``Config``, then convert to snake_case."""
    return to_snake_case(drop_config_suffix(name))


class ConfigModel(BaseModel):
    """Base class for the pydantic models that make up an experiment configuration.

    Its ``alias_generator`` strips a trailing ``Config`` from a name and converts it to
    snake_case. Pydantic applies the generator to field names; :class:`ConfigParser`
    also applies it to a builder's class name to derive the tag used in config files.
    """

    model_config = {
        'alias_generator': _config_alias,
    }
|
|
|
|
|
|
|
|
|
|
|
|
# This is a simple regex which is not by any means exhaustive but should cover gross syntax errors.
|
|
|
|
VALID_DOMAIN_NAME = re.compile(r"^localhost$|^(?!-)([A-Za-z0-9-]+\.)+[A-Za-z]{2,6}$")
|
|
|
|
|
|
|
|
|
|
|
|
def is_valid_domain_name(domain_name: str):
|
|
|
|
stripped = domain_name.strip()
|
|
|
|
matches = VALID_DOMAIN_NAME.match(stripped)
|
|
|
|
assert matches is not None
|
|
|
|
return stripped
|
|
|
|
|
|
|
|
|
|
|
|
# A string that must look like a domain name (or ``localhost``) according to
# ``is_valid_domain_name``; surrounding whitespace is stripped and the stripped
# value is what the model stores.
DomainName = Annotated[str, AfterValidator(is_valid_domain_name)]
|
|
|
|
|
2024-11-28 15:52:38 -03:00
|
|
|
# A host given either as an IP address (pydantic's IPvAnyAddress accepts IPv4 or IPv6)
# or as a ``DomainName``. Note: the ``type`` alias statement requires Python 3.12+.
type Host = IPvAnyAddress | DomainName
|
2024-11-27 11:22:07 -03:00
|
|
|
|
|
|
|
|
2024-11-28 15:15:05 -03:00
|
|
|
class ExperimentBuilder(ConfigModel, Generic[TExperiment]):
    """Configuration model that knows how to construct a concrete :class:`Experiment`.

    Subclasses are pydantic models (via :class:`ConfigModel`) holding an experiment's
    parameters, and implement :meth:`build` to turn those parameters into an
    experiment instance of type ``TExperiment``.
    """

    @abstractmethod
    def build(self) -> TExperiment:
        """Create the experiment described by this configuration."""
        ...
|
2024-11-28 15:15:05 -03:00
|
|
|
|
|
|
|
|
|
|
|
class ConfigParser:
    """
    :class:`ConfigParser` is a utility class to parse configuration files into
    :class:`ExperimentBuilder` instances.

    Currently, each :class:`ExperimentBuilder` can appear at most once in the config file.
    """

    def __init__(self):
        # Maps a config tag (the alias generated from a builder's class name) to the
        # builder class that validates entries under that tag.
        self.root_tags: Dict[str, Type[ExperimentBuilder]] = {}

    def register(self, root: Type[ExperimentBuilder[TExperiment]]):
        """Register an :class:`ExperimentBuilder` subclass so it can be parsed from config.

        The tag is ``root.__name__`` passed through the class's ``alias_generator``, or the
        raw class name if no generator is configured. Registering another class under the
        same tag replaces the earlier one.
        """
        # ``or`` also covers an explicit ``alias_generator=None`` in the model config.
        alias_generator = cast(
            Callable[[str], str],
            root.model_config.get('alias_generator') or (lambda x: x),
        )
        self.root_tags[alias_generator(root.__name__)] = root

    @overload
    def parse(self, data: dict) -> Dict[str, ExperimentBuilder[TExperiment]]:
        ...

    @overload
    def parse(self, data: TextIO) -> Dict[str, ExperimentBuilder[TExperiment]]:
        ...

    def parse(self, data: dict | TextIO) -> Dict[str, ExperimentBuilder[TExperiment]]:
        """Parse experiment configuration into validated builders.

        :param data: either an already-loaded mapping of ``tag -> config``, or a readable
            stream containing a YAML document with that shape. For streams, environment
            variables (``$VAR`` / ``${VAR}``) are expanded in the raw text before YAML
            parsing; references to unset variables are left unchanged.
        :return: a dict mapping each tag to its validated :class:`ExperimentBuilder`.
        :raises KeyError: if the config contains a tag that was never registered.
        :raises ValueError: if the YAML document's top level is not a mapping.
        """
        if hasattr(data, 'read'):
            # NOTE(review): substitution is textual and happens before YAML parsing, so a
            # variable whose value contains YAML syntax (':' , newlines, ...) can alter the
            # document structure. Only use with trusted environments.
            entries = yaml.safe_load(os.path.expandvars(data.read()))
            if entries is None:
                # An empty document loads as None; treat it as "no experiments".
                entries = {}
            elif not isinstance(entries, dict):
                raise ValueError(
                    f'Expected a mapping of experiment tags at the top level of the config, '
                    f'got {type(entries).__name__}'
                )
        else:
            entries = data

        builders = {}
        for tag, config in entries.items():
            builder_type = self.root_tags.get(tag)
            if builder_type is None:
                raise KeyError(
                    f'Unknown experiment tag {tag!r}; registered tags: {sorted(self.root_tags)}'
                )
            builders[tag] = builder_type.model_validate(config)
        return builders
|