2024-11-27 11:22:07 -03:00
|
|
|
"""Basic utilities for structuring experiment configurations based on Pydantic schemas."""
|
2024-12-14 06:34:11 -03:00
|
|
|
|
2024-11-28 15:15:05 -03:00
|
|
|
import os
|
2024-11-27 11:22:07 -03:00
|
|
|
from abc import abstractmethod
|
2024-11-28 15:15:05 -03:00
|
|
|
from io import TextIOBase
|
2024-12-14 06:31:20 -03:00
|
|
|
from typing import Type, Dict, TextIO
|
2024-11-27 11:22:07 -03:00
|
|
|
|
2024-11-28 15:15:05 -03:00
|
|
|
import yaml
|
2025-01-17 08:34:49 -03:00
|
|
|
from typing_extensions import Generic, overload, TypeVar
|
2024-11-27 11:22:07 -03:00
|
|
|
|
2024-12-11 13:52:55 -03:00
|
|
|
from benchmarks.core.pydantic import SnakeCaseModel
|
2024-11-27 11:22:07 -03:00
|
|
|
|
2025-01-17 08:34:49 -03:00
|
|
|
# Type of the object a concrete Builder produces from its build() method.
T = TypeVar("T")
|
2024-11-27 11:22:07 -03:00
|
|
|
|
2025-01-17 08:34:49 -03:00
|
|
|
|
|
|
|
|
class Builder(SnakeCaseModel, Generic[T]):
    """:class:`Builder` is a configuration model that knows how to produce an object of type ``T``.

    Concrete subclasses describe their configuration as model fields and implement
    :meth:`build` to turn a validated configuration into the object it describes.
    """

    @abstractmethod
    def build(self) -> T:
        """Construct and return the object described by this configuration."""
        ...
|
2024-11-28 15:15:05 -03:00
|
|
|
|
|
|
|
|
|
2025-01-17 08:34:49 -03:00
|
|
|
# Type variable ranging over Builder subclasses; ConfigParser is generic over it.
TBuilder = TypeVar("TBuilder", bound=Builder)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class ConfigParser(Generic[TBuilder]):
    """
    :class:`ConfigParser` is a utility class to parse configuration files into :class:`Builder`s.

    Builder types are registered under their ``alias()``. Each top-level key of a configuration
    selects the registered type whose alias matches it, and that type validates the key's value.
    Because top-level keys form a mapping, each :class:`Builder` type can currently appear at
    most once in the config file.
    """

    def __init__(self):
        # Maps a builder's alias (a top-level config key) to the Builder subclass it selects.
        self.experiment_types: Dict[str, Type[TBuilder]] = {}

    def register(self, root: Type[TBuilder]) -> Type[TBuilder]:
        """Register ``root`` under ``root.alias()`` so :meth:`parse` recognizes its key.

        Registering another type with the same alias replaces the previous registration.

        :param root: the :class:`Builder` subclass to register.
        :return: ``root`` itself, so this method can also be used as a class decorator.
        """
        self.experiment_types[root.alias()] = root
        return root

    @overload
    def parse(self, data: dict) -> Dict[str, TBuilder]: ...

    @overload
    def parse(self, data: TextIO) -> Dict[str, TBuilder]: ...

    def parse(self, data: dict | TextIO) -> Dict[str, TBuilder]:
        """Parse ``data`` into a mapping of alias -> validated :class:`Builder` instance.

        :param data: either an already-loaded mapping, or a readable text stream holding YAML.
            For streams, environment variable references are expanded with
            :func:`os.path.expandvars` before the YAML is parsed; references to unset
            variables are left unchanged.
        :return: one validated builder per top-level key; an empty document yields ``{}``.
        :raises KeyError: if a top-level key does not match any registered alias.
        """
        if callable(getattr(data, "read", None)):
            # safe_load only builds plain YAML types, never arbitrary Python objects.
            entries = yaml.safe_load(os.path.expandvars(data.read()))
        else:
            entries = data

        # An empty (or comment-only) YAML document loads as None; treat it as "no entries".
        if entries is None:
            entries = {}

        parsed = {}
        for tag, config in entries.items():
            try:
                builder_type = self.experiment_types[tag]
            except KeyError:
                raise KeyError(
                    f"Unknown configuration key {tag!r}; registered keys are: "
                    f"{list(self.experiment_types)}"
                ) from None
            parsed[tag] = builder_type.model_validate(config)
        return parsed
|