86 lines
2.6 KiB
Python
Raw Normal View History

"""Basic utilities for structuring experiment configurations based on Pydantic schemas."""
2024-11-28 15:15:05 -03:00
import os
import re
from abc import abstractmethod
2024-11-28 15:15:05 -03:00
from io import TextIOBase
2024-11-28 15:52:38 -03:00
from typing import Annotated, Type, Dict, TextIO, Callable, cast
2024-11-28 15:15:05 -03:00
import yaml
2024-11-28 15:52:38 -03:00
from pydantic import BaseModel, IPvAnyAddress, AfterValidator, TypeAdapter
2024-11-28 15:15:05 -03:00
from typing_extensions import Generic, overload
2024-11-28 15:15:05 -03:00
from benchmarks.core.experiments.experiments import TExperiment, Experiment
def drop_config_suffix(name: str) -> str:
    """Return ``name`` without a trailing ``'Config'``; any other name is returned unchanged."""
    return name.removesuffix('Config')
2024-11-28 15:15:05 -03:00
def to_snake_case(name: str) -> str:
    """Convert a CamelCase identifier to snake_case.

    An underscore is inserted before every ASCII uppercase letter that is not
    the first character, then the whole string is lowercased. Consecutive
    capitals are therefore split individually (``'HTTPServer'`` becomes
    ``'h_t_t_p_server'``).
    """
    # Zero-width match: any position (except the start) followed by a capital.
    capital_boundary = re.compile(r'(?<!^)(?=[A-Z])')
    underscored = capital_boundary.sub('_', name)
    return underscored.lower()
class ConfigModel(BaseModel):
    # Base class for configuration schemas. Aliases are derived from a name by
    # dropping a trailing 'Config' (if present) and converting the remainder
    # from CamelCase to snake_case.
    model_config = dict(
        alias_generator=lambda raw_name: to_snake_case(drop_config_suffix(raw_name)),
    )
# Deliberately simple pattern: it is by no means exhaustive, but it rejects
# gross syntax errors. It accepts the literal "localhost", or one or more
# dot-terminated labels followed by a 2-6 letter top-level domain.
VALID_DOMAIN_NAME = re.compile(r"^localhost$|^(?!-)([A-Za-z0-9-]+\.)+[A-Za-z]{2,6}$")


def is_valid_domain_name(domain_name: str) -> str:
    """Validate ``domain_name`` and return it with surrounding whitespace removed.

    Despite the ``is_`` prefix, this returns the cleaned string rather than a
    bool, which lets it be used directly as a validator that also normalizes
    its input.

    Raises:
        ValueError: if the stripped value does not match ``VALID_DOMAIN_NAME``.
    """
    stripped = domain_name.strip()
    # Raise explicitly instead of using ``assert``: assertions are removed when
    # Python runs with -O, which would let invalid names through unchecked.
    if VALID_DOMAIN_NAME.match(stripped) is None:
        raise ValueError(f"invalid domain name: {domain_name!r}")
    return stripped
# A ``str`` that is additionally run through ``is_valid_domain_name``, which
# strips surrounding whitespace and rejects values that fail the domain regex.
DomainName = Annotated[str, AfterValidator(is_valid_domain_name)]
2024-11-28 15:52:38 -03:00
# A host is given either as an IP address or as a domain name.
type Host = IPvAnyAddress | DomainName
2024-11-28 15:15:05 -03:00
class ExperimentBuilder(ConfigModel, Generic[TExperiment]):
    """A :class:`ConfigModel` that can build a real :class:`Experiment` from its configuration."""

    @abstractmethod
    def build(self) -> TExperiment:
        """Build and return the experiment described by this configuration."""
        ...
2024-11-28 15:15:05 -03:00
class ConfigParser:
    """
    :class:`ConfigParser` is a utility class to parse configuration data into :class:`ExperimentBuilder`s.

    Builder classes are registered under a root tag derived from their class name. Each top-level
    key of the parsed data selects the builder class registered under that tag, and the value under
    the key is validated against it. Because the result is keyed by tag, each
    :class:`ExperimentBuilder` can appear at most once in the config file.
    """

    def __init__(self):
        # Maps a root tag (the aliased class name) to the registered builder class.
        self.root_tags: Dict[str, Type[ExperimentBuilder]] = {}

    def register(self, root: Type[ExperimentBuilder[TExperiment]]):
        """Register the builder class ``root`` so that :meth:`parse` can recognise it.

        The tag is the class name passed through the model's ``alias_generator``, or the unmodified
        class name when ``model_config`` defines none. Registering a second class that maps to the
        same tag replaces the first.
        """
        alias_generator = cast(
            Callable[[str], str],
            root.model_config.get('alias_generator', lambda x: x),
        )
        self.root_tags[alias_generator(root.__name__)] = root

    @overload
    def parse(self, data: dict) -> Dict[str, ExperimentBuilder[TExperiment]]:
        ...

    @overload
    def parse(self, data: TextIO) -> Dict[str, ExperimentBuilder[TExperiment]]:
        ...

    def parse(self, data: dict | TextIO) -> Dict[str, ExperimentBuilder[TExperiment]]:
        """Turn configuration data into validated builders, keyed by root tag.

        Args:
            data: either an already-parsed mapping of root tag to builder configuration, or a text
                stream containing YAML. For a stream, environment variable references
                (``$VAR`` / ``${VAR}``) are expanded in the raw text before the YAML is parsed.

        Returns:
            A dict mapping each root tag found in ``data`` to the validated builder instance.
            An empty YAML document yields an empty dict.

        Raises:
            KeyError: if ``data`` contains a root tag for which no builder has been registered.
        """
        if isinstance(data, TextIOBase):
            entries = yaml.safe_load(os.path.expandvars(data.read()))
            # safe_load returns None for an empty document; treat that like an empty
            # mapping so an empty file behaves the same as an empty dict.
            if entries is None:
                entries = {}
        else:
            entries = data

        builders = {}
        for tag, config in entries.items():
            try:
                builder_type = self.root_tags[tag]
            except KeyError:
                # Same exception type as a plain lookup, but say what went wrong.
                raise KeyError(
                    f"no ExperimentBuilder registered for root tag {tag!r}"
                ) from None
            builders[tag] = builder_type.model_validate(config)
        return builders