diff --git a/benchmarks/cli.py b/benchmarks/cli.py index 4e6cab1..26962de 100644 --- a/benchmarks/cli.py +++ b/benchmarks/cli.py @@ -26,6 +26,16 @@ def cmd_run(experiments: Dict[str, ExperimentBuilder[Experiment]], args): experiments[args.experiment].build().run() +def cmd_describe(args): + if not args.type: + print(f'Available experiment types are:') + for experiment in config_parser.experiment_types.keys(): + print(f' - {experiment}') + return + + print(config_parser.experiment_types[args.type].schema_json(indent=2)) + + def _parse_config(config: Path) -> Dict[str, ExperimentBuilder[Experiment]]: if not config.exists(): print(f'Config file {config} does not exist.') @@ -52,21 +62,31 @@ def _init_logging(): def main(): parser = argparse.ArgumentParser() - parser.add_argument('config', type=Path, help="Path to the experiment configuration file.") commands = parser.add_subparsers(required=True) - list_cmd = commands.add_parser('list', help='Lists available experiments.') - list_cmd.set_defaults(func=cmd_list) - run_cmd = commands.add_parser('run') + experiments = commands.add_parser('experiments', help='List or run experiments in config file.') + experiments.add_argument('config', type=Path, help='Path to the experiment configuration file.') + experiment_commands = experiments.add_subparsers(required=True) + + list_cmd = experiment_commands.add_parser('list', help='Lists available experiments.') + list_cmd.set_defaults(func=lambda args: cmd_list(_parse_config(args.config), args)) + + run_cmd = experiment_commands.add_parser('run', help='Runs an experiment') run_cmd.add_argument('experiment', type=str, help='Name of the experiment to run.') - run_cmd.set_defaults(func=cmd_run) + run_cmd.set_defaults(func=lambda args: cmd_run(_parse_config(args.config), args)) + + describe = commands.add_parser('describe', help='Shows the JSON schema for the various experiment types.') + describe.add_argument('type', type=str, help='Type of the experiment to describe.', + choices=config_parser.experiment_types.keys(), nargs='?') + + describe.set_defaults(func=cmd_describe) args = parser.parse_args() _init_logging() - args.func(_parse_config(args.config), args) + args.func(args) if __name__ == '__main__': diff --git a/benchmarks/core/config.py b/benchmarks/core/config.py index c0da879..d88237f 100644 --- a/benchmarks/core/config.py +++ b/benchmarks/core/config.py @@ -57,13 +57,13 @@ class ConfigParser: """ def __init__(self): - self.root_tags = {} + self.experiment_types = {} def register(self, root: Type[ExperimentBuilder[TExperiment]]): name = root.__name__ alias = cast(Callable[[str], str], root.model_config.get('alias_generator', lambda x: x))(name) - self.root_tags[alias] = root + self.experiment_types[alias] = root @overload def parse(self, data: dict) -> Dict[str, ExperimentBuilder[TExperiment]]: @@ -80,6 +80,6 @@ class ConfigParser: entries = data return { - tag: self.root_tags[tag].model_validate(config) + tag: self.experiment_types[tag].model_validate(config) for tag, config in entries.items() }