57 lines
1.6 KiB
Python
Raw Normal View History

"""Basic utilities for structuring experiment configurations based on Pydantic schemas."""
2024-12-14 06:34:11 -03:00
2024-11-28 15:15:05 -03:00
import os
from abc import abstractmethod
2024-11-28 15:15:05 -03:00
from io import TextIOBase
2024-12-14 06:31:20 -03:00
from typing import Type, Dict, TextIO
2024-11-28 15:15:05 -03:00
import yaml
2025-01-17 08:34:49 -03:00
from typing_extensions import Generic, overload, TypeVar
from benchmarks.core.pydantic import ConfigModel
2025-01-17 08:34:49 -03:00
T = TypeVar("T")


class Builder(ConfigModel, Generic[T]):
    """Configuration model that knows how to construct an object of type ``T``.

    A :class:`Builder` is a (Pydantic-based) :class:`ConfigModel`. Subclasses declare
    their configuration fields and implement :meth:`build` to turn the validated
    configuration into a useful object.
    """

    @abstractmethod
    def build(self) -> T:
        """Construct and return the object described by this configuration."""
        ...
2024-11-28 15:15:05 -03:00
2025-01-17 08:34:49 -03:00
TBuilder = TypeVar("TBuilder", bound=Builder)


class ConfigParser(Generic[TBuilder]):
    """
    :class:`ConfigParser` is a utility class to parse configuration files into :class:`Builder`s.
    Currently, each :class:`Builder` type can appear at most once in the config file.

    Builder types are registered under their ``alias()``. A configuration is a mapping from
    tag (alias) to that builder's raw configuration, which is validated with the registered
    builder's ``model_validate``.
    """

    def __init__(self, ignore_unknown: bool = True) -> None:
        """
        :param ignore_unknown: when ``True`` (the default), top-level entries whose tag has no
            registered builder are skipped; when ``False``, :meth:`parse` raises
            :class:`KeyError` for them.
        """
        self.experiment_types: Dict[str, Type[TBuilder]] = {}
        self.ignore_unknown = ignore_unknown

    def register(self, root: Type[TBuilder]) -> None:
        """Register ``root`` under ``root.alias()``, replacing any builder already registered
        under the same alias."""
        self.experiment_types[root.alias()] = root

    @overload
    def parse(self, data: dict) -> Dict[str, TBuilder]: ...

    @overload
    def parse(self, data: TextIO) -> Dict[str, TBuilder]: ...

    def parse(self, data: dict | TextIO) -> Dict[str, TBuilder]:
        """Parse ``data`` into validated builder instances keyed by tag.

        :param data: either an already-loaded mapping of tag -> configuration, or a readable
            stream containing YAML. Environment variable references (``$VAR`` / ``${VAR}``)
            in the stream are expanded before parsing; references to unset variables are
            left unchanged.
        :return: a dict mapping each recognized tag to an instance of its registered builder.
            An empty YAML document yields an empty dict.
        :raises KeyError: if ``ignore_unknown`` is ``False`` and ``data`` contains a tag with
            no registered builder.
        :raises TypeError: if the top level of the configuration is not a mapping.
        """
        entries = self._load_entries(data)

        builders: Dict[str, TBuilder] = {}
        for tag, config in entries.items():
            builder_type = self.experiment_types.get(tag)
            if builder_type is None:
                if self.ignore_unknown:
                    continue
                # Still a KeyError (as the plain dict lookup used to raise), but with context.
                raise KeyError(
                    f"Unknown configuration entry {tag!r}; "
                    f"registered entries are: {sorted(self.experiment_types)}"
                )
            builders[tag] = builder_type.model_validate(config)
        return builders

    @staticmethod
    def _load_entries(data: dict | TextIO) -> dict:
        """Return the top-level tag -> configuration mapping held in (or read from) ``data``."""
        # Duck-type on ``read`` rather than checking ``io.TextIOBase``: ``typing.TextIO`` is a
        # structural annotation, and some text streams (e.g. ``codecs.open`` objects) are not
        # ``TextIOBase`` subclasses.
        if hasattr(data, "read"):
            entries = yaml.safe_load(os.path.expandvars(data.read()))
        else:
            entries = data

        # ``yaml.safe_load`` returns ``None`` for an empty document: treat it as "no entries".
        if entries is None:
            return {}
        if not hasattr(entries, "items"):
            raise TypeError(
                "Expected a mapping of configuration entries at the top level, "
                f"got {type(entries).__name__}"
            )
        return entries