Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Issue/134/interfaces #135

Merged
merged 5 commits into from
Jul 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 26 additions & 0 deletions src/rail/cli/commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

from rail.core import __version__
from rail.cli import options, scripts
from rail.interfaces.pz_factory import PZFactory
import ceci


@click.group()
Expand Down Expand Up @@ -74,3 +76,27 @@ def get_data(verbose, **kwargs):
"""Downloads data from NERSC (if not already found)"""
scripts.get_data(verbose, **kwargs)
return 0


@cli.command()
@options.stage_name()
@options.stage_class()
@options.stage_module()
@options.model_file()
@options.dry_run()
@options.input_file()
def estimate(stage_name, stage_class, stage_module, model_file, dry_run, input_file):
    """Run a pz estimation stage"""
    # The docstring above is deliberately a single line: click displays it as
    # the command's help text.
    #
    # stage_name / stage_class / stage_module identify the estimator to build,
    # model_file is the model it uses, and input_file is the catalog to run on.
    #
    # NOTE(review): dry_run is accepted from the command line but is not
    # consulted anywhere below -- confirm whether the estimation should be
    # skipped when it is set.

    # The real input is only supplied when the stage is run, so the stage is
    # configured here with a placeholder input path.
    stage = PZFactory.build_cat_estimator_stage(
        stage_name=stage_name,
        class_name=stage_class,
        module_name=stage_module,
        model_path=model_file,
        data_path='dummy.in',
    )

    PZFactory.run_cat_estimator_stage(
        stage,
        data_path=input_file,
    )
    # Report success the same way the other commands in this module do.
    return 0

63 changes: 63 additions & 0 deletions src/rail/cli/options.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
"dry_run",
"outdir",
"from_source",
"model_file",
"git_mode",
"print_all",
"print_packages",
Expand All @@ -19,6 +20,10 @@
"print_stages",
"package_file",
"skip",
"stage_class",
"stage_module",
"stage_name",
"stages_config",
"inputs",
"verbose_download",
]
Expand Down Expand Up @@ -98,6 +103,37 @@ def __call__(self, *args: Any, **kwargs: Any) -> Any: # pragma: no cover
help="Output directory.",
)

# --output_yaml: optional path of a yaml file to write (defaults to None).
output_yaml = PartialOption(
    "--output_yaml",
    type=click.Path(),
    default=None,
    help="Path for output yaml file",
)

# --pipeline_class: fully qualified (dotted) name of a pipeline class.
pipeline_class = PartialOption(
    "--pipeline_class",
    type=str,
    help="Full class name for pipeline, e.g., rail.pipelines.estimation.train_z.TrainZPipeline",
)

# --model_file: model used for pz estimation, given as a string.
model_file = PartialOption(
    "--model_file",
    type=str,
    help="Model for pz estimation",
)

# --input_file: input data file for pz estimation, given as a string.
input_file = PartialOption(
    "--input_file",
    type=str,
    help="Input data file for pz estimation",
)

# --pipeline_yaml: path of the yaml file that defines a pipeline.
pipeline_yaml = PartialOption(
    "--pipeline_yaml",
    type=click.Path(),
    help="Yaml file that defines pipeline",
)

git_mode = PartialOption(
"--git-mode",
type=EnumChoice(GitMode),
Expand Down Expand Up @@ -155,6 +191,33 @@ def __call__(self, *args: Any, **kwargs: Any) -> Any: # pragma: no cover
help="Skip files",
)

# --stage_class: name of the python class implementing a pipeline stage.
stage_class = PartialOption(
"--stage_class",
type=str,
help="Name of a pipeline stage python class",
)

# --stage_module: import path of the python module to load.
# NOTE(review): presumably the module that contains the stage class -- confirm.
stage_module = PartialOption(
"--stage_module",
type=str,
help="Import path for a python module",
)

# --stage_name: name given to a pipeline stage.
stage_name = PartialOption(
"--stage_name",
type=str,
help="Name of a pipeline stage",
)

# --stages_config: optional stage configuration file (defaults to None).
stages_config = PartialOption(
"--stages_config",
type=str,
help="Stage config file",
default=None,
)



inputs = PartialArgument("inputs", nargs=-1)

verbose_download = PartialOption(
Expand Down
7 changes: 7 additions & 0 deletions src/rail/interfaces/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@

from .pz_factory import PZFactory


# Public names of this package: only the PZFactory class is exported.
__all__ = [
"PZFactory",
]
130 changes: 130 additions & 0 deletions src/rail/interfaces/pz_factory.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@

from qp import Ensemble
from ceci.stage import PipelineStage
from rail.core.stage import RailStage
from rail.core.data import DataHandle
from rail.estimation.estimator import CatEstimator


class PZFactory:
    """Factory class to provide a unified interface to rail p(z) estimation algorithms.

    Stages built with `build_cat_estimator_stage` are cached by name in a
    class-level dictionary, so they can be retrieved later with
    `get_cat_estimator_stage` and discarded with `reset`.
    """

    # Cache of constructed stages keyed by stage name; it lives on the class,
    # so it is shared by every user of PZFactory.
    _stage_dict = {}

    @classmethod
    def reset(cls):
        """Reset the dictionary of cached stage objects."""
        cls._stage_dict = {}

    @classmethod
    def build_cat_estimator_stage(
        cls,
        stage_name: str,
        class_name: str,
        module_name: str,
        model_path: str,
        data_path: str = 'none',
        **config_params: dict,
    ) -> CatEstimator:
        """Build and configure an estimator that can evaluate p(z) given an input catalog.

        The new stage is also stored in the cache under `stage_name`,
        replacing any stage previously cached under that name.

        Parameters
        ----------
        stage_name: str
            Name of the stage instance, used to construct output file name

        class_name: str
            Name of the class, e.g., TrainZEstimator, used to find the class

        module_name: str
            Name of the python module that contains the class, used to import it

        model_path: str
            Path to the model file used by this estimator

        data_path: str
            Path to the input data, defaults to 'none'

        config_params: dict
            Configuration parameters for the stage

        Returns
        -------
        stage_obj: CatEstimator
            Newly constructed and configured Estimator instance
        """
        stage_class = PipelineStage.get_stage(class_name, module_name)
        stage_obj = stage_class.make_stage(
            name=stage_name,
            model=model_path,
            input=data_path,
            **config_params,
        )
        cls._stage_dict[stage_name] = stage_obj
        return stage_obj

    @classmethod
    def get_cat_estimator_stage(
        cls,
        stage_name: str,
    ) -> CatEstimator:
        """Return a cached p(z) estimator.

        Parameters
        ----------
        stage_name: str
            Name under which the stage was built

        Returns
        -------
        stage_obj: CatEstimator
            The cached stage

        Raises
        ------
        KeyError
            If no stage with that name is in the cache
        """
        try:
            return cls._stage_dict[stage_name]
        except KeyError as msg:
            # The two fragments are joined by implicit concatenation, so the
            # first must end with a space to keep the sentences apart.
            raise KeyError(
                f"Could not find stage named {stage_name}, did you build it? "
                f"Existing stages are: {list(cls._stage_dict.keys())}"
            ) from msg

    @staticmethod
    def run_cat_estimator_stage(
        stage_obj: CatEstimator,
        data_path: str,
    ) -> DataHandle:
        """Run a p(z) estimator on an input data file.

        Parameters
        ----------
        stage_obj: CatEstimator
            Object that will do the estimation

        data_path: str
            Path to the input data file

        Returns
        -------
        data_handle: DataHandle
            Object that can give access to the data
        """
        # Clear the shared RailStage data store before attaching the new input.
        RailStage.data_store.clear()
        handle = stage_obj.get_handle('input', path=data_path, allow_missing=True)
        return stage_obj.estimate(handle)

    @staticmethod
    def estimate_single_pz(
        stage_obj: CatEstimator,
        data_table: dict,
        input_size: int = 1,
    ) -> Ensemble:
        """Run a p(z) estimator on some objects.

        Parameters
        ----------
        stage_obj: CatEstimator
            Object that will do the estimation

        data_table: dict
            Input data presented as dict of numpy arrays objects

        input_size: int
            Number of objects in the input table

        Returns
        -------
        pz : qp.Ensemble
            Output pz
        """
        RailStage.data_store.clear()
        # Load the model on first use, using the stage's own configuration.
        if stage_obj.model is None:
            stage_obj.open_model(**stage_obj.config)
        # The table is handed to the stage as one chunk covering [0, input_size).
        # NOTE(review): the trailing True presumably marks this as the first
        # chunk -- confirm against CatEstimator._process_chunk.
        stage_obj._input_length = input_size
        stage_obj._process_chunk(0, input_size, data_table, True)
        return stage_obj._output_handle.data


Binary file added tests/interfaces/model_inform_trainz.pkl
Binary file not shown.
51 changes: 51 additions & 0 deletions tests/interfaces/test_pz_factory.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
import os
import numpy as np
import pytest

from rail.utils.path_utils import find_rail_file
from rail.interfaces import PZFactory


def test_pz_factory():
    """Exercise PZFactory: build a stage, run it two ways, then check the cache."""
    try:
        stage = PZFactory.build_cat_estimator_stage(
            'train_z',
            'TrainZEstimator',
            'rail.estimation.algos.train_z',
            'tests/interfaces/model_inform_trainz.pkl',
            'dummy.in',
        )

        input_file = find_rail_file('examples_data/testdata/validation_10gal.hdf5')

        # Estimate directly from an in-memory table (default input_size of 1).
        out_single = PZFactory.estimate_single_pz(stage, {'d': np.array([1, 1])})
        assert out_single.npdf == 1

        # Estimate from a file on disk; only checks that the run completes.
        PZFactory.run_cat_estimator_stage(
            stage,
            input_file,
        )

        # The stage built above must be retrievable from the cache by name.
        check_stage = PZFactory.get_cat_estimator_stage('train_z')
        assert check_stage == stage

        with pytest.raises(KeyError):
            PZFactory.get_cat_estimator_stage('nope')

        PZFactory.reset()
        assert not PZFactory._stage_dict
    finally:
        # Best-effort removal of output files this test may leave in the
        # working directory. Runs even when an assertion above fails, and
        # ignores only the "file was never created" case.
        for leftover in ('inprogress_output_train_z.hdf5', 'output_train_z.hdf5'):
            try:
                os.unlink(leftover)
            except FileNotFoundError:
                pass






Loading