52 lines
1.6 KiB
Python
Raw Normal View History

"""Basic utilities for structuring experiment configurations based on Pydantic schemas."""
2024-11-28 15:15:05 -03:00
import os
from abc import abstractmethod
2024-11-28 15:15:05 -03:00
from io import TextIOBase
from typing import Type, Dict, TextIO, Callable, cast
2024-11-28 15:15:05 -03:00
import yaml
from typing_extensions import Generic, overload
from benchmarks.core.experiments.experiments import TExperiment
from benchmarks.core.pydantic import SnakeCaseModel
class ExperimentBuilder(SnakeCaseModel, Generic[TExperiment]):
    """Configuration schema that can be turned into a runnable experiment.

    Subclasses declare their configuration fields on the model and implement
    :meth:`build` to produce the concrete experiment (of type ``TExperiment``)
    described by those fields.
    """

    @abstractmethod
    def build(self) -> TExperiment:
        """Create the experiment described by this configuration."""
        ...
2024-11-28 15:15:05 -03:00
class ConfigParser:
"""
:class:`ConfigParser` is a utility class to parse configuration files into :class:`ExperimentBuilder`s.
Currently, each :class:`ExperimentBuilder` can appear at most once in the config file.
"""
def __init__(self):
self.experiment_types = {}
2024-11-28 15:15:05 -03:00
2024-11-28 15:52:38 -03:00
def register(self, root: Type[ExperimentBuilder[TExperiment]]):
self.experiment_types[root.alias()] = root
2024-11-28 15:15:05 -03:00
@overload
2024-11-28 15:52:38 -03:00
def parse(self, data: dict) -> Dict[str, ExperimentBuilder[TExperiment]]:
2024-11-28 15:15:05 -03:00
...
@overload
2024-11-28 15:52:38 -03:00
def parse(self, data: TextIO) -> Dict[str, ExperimentBuilder[TExperiment]]:
2024-11-28 15:15:05 -03:00
...
2024-11-28 15:52:38 -03:00
def parse(self, data: dict | TextIO) -> Dict[str, ExperimentBuilder[TExperiment]]:
2024-11-28 15:15:05 -03:00
if isinstance(data, TextIOBase):
entries = yaml.safe_load(os.path.expandvars(data.read()))
else:
entries = data
return {
tag: self.experiment_types[tag].model_validate(config)
2024-11-28 15:15:05 -03:00
for tag, config in entries.items()
}