Skip to content

Commit

Permalink
Added config params for estimate to cli
Browse files Browse the repository at this point in the history
  • Loading branch information
eacharles committed Aug 20, 2024
1 parent 2b75d6e commit 1d5c17e
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 1 deletion.
28 changes: 27 additions & 1 deletion src/rail/cli/commands.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import os
import click
import pprint
import yaml

from rail.core import __version__
from rail.cli import options, scripts
Expand Down Expand Up @@ -85,22 +87,46 @@ def get_data(verbose, **kwargs):
@options.stage_name()
@options.stage_class()
@options.stage_module()
@options.stages_config()
@options.model_file()
@options.catalog_tag()
@options.dry_run()
@options.input_file()
def estimate(stage_name, stage_class, stage_module, model_file, catalog_tag, dry_run, input_file):
@options.params()
def estimate(stage_name, stage_class, stage_module, stages_config, model_file, catalog_tag, dry_run, input_file, params):
"""Run a pz estimation stage"""
if catalog_tag:
catalog_utils.apply_defaults(catalog_tag)

if stages_config not in [None, 'none', 'None']:
with open(stages_config) as fin:
config_data = yaml.safe_load(fin)
if stage_name in config_data:
kwargs = config_data[stage_name]
elif isinstance(config_data, dict):
kwargs = config_data
else:
raise ValueError(f"Config file {stages_config} is not properly constructed")
else:
kwargs = {}

kwargs.update(**options.args_to_dict(params))

stage = PZFactory.build_cat_estimator_stage(
stage_name=stage_name,
class_name=stage_class,
module_name=stage_module,
model_path=model_file,
data_path='dummy.in',
**kwargs,
)

if dry_run:
print(f"Running stage {stage.name} of type {type(stage)} on {input_file}")
print("Stage config: ")
pprint.pprint(stage.config.to_dict())
return 0

output = PZFactory.run_cat_estimator_stage(
stage,
data_path=input_file,
Expand Down
16 changes: 16 additions & 0 deletions src/rail/cli/options.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import click

__all__ = [
"args_to_dict",
"clear_output",
"bpz_demo_data",
"catalog_tag",
Expand All @@ -14,6 +15,7 @@
"from_source",
"model_file",
"git_mode",
"params",
"pipeline_class",
"pipeline_yaml",
"print_all",
Expand Down Expand Up @@ -58,6 +60,18 @@ def convert(self, value: Any, param, ctx) -> EnumType_co: # pragma: no cover
return self._enum.__members__[converted_str]


def args_to_dict(args):
"""Convert a series of command line key=value statements
to a dict"""
out_dict = {}
for arg_ in args:
tokens = arg_.split('=')
if len(tokens) != 2:
raise ValueError(f"Poorly formed argument {arg_}. Should by key=value")
out_dict[tokens[0]] = tokens[1]
return out_dict


class PartialOption:
"""Wraps click.option with partial arguments for convenient reuse"""

Expand Down Expand Up @@ -237,6 +251,8 @@ def __call__(self, *args: Any, **kwargs: Any) -> Any: # pragma: no cover

inputs = PartialArgument("inputs", nargs=-1)

params = PartialArgument("params", nargs=-1)

verbose_download = PartialOption(
"-v", "--verbose", help="Verbose output when downloading", is_flag=True
)
Expand Down

0 comments on commit 1d5c17e

Please sign in to comment.