diff --git a/pyproject.toml b/pyproject.toml index 19ac9b16..a77a7e2b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -64,6 +64,7 @@ disable = [ "abstract-method", "invalid-name", "too-many-statements", + "too-many-arguments", "missing-module-docstring", "missing-class-docstring", "missing-function-docstring", diff --git a/src/rail/cli/commands.py b/src/rail/cli/commands.py index 46bb9a84..3f168361 100644 --- a/src/rail/cli/commands.py +++ b/src/rail/cli/commands.py @@ -11,50 +11,50 @@ def cli() -> None: @cli.command() -@options.outdir(default='docs') +@options.outdir(default="docs") @options.clear_output() @options.dry_run() @options.inputs() @options.skip() -def render_nb(outdir, clear_output, dry_run, inputs, skip, **kwargs): +def render_nb(outdir, clear_output, dry_run, inputs, skip, **_kwargs): """Render jupyter notebooks""" return scripts.render_nb(outdir, clear_output, dry_run, inputs, skip) @cli.command() -@options.outdir(default='..') +@options.outdir(default="..") @options.git_mode() @options.dry_run() @options.package_file() -def clone_source(outdir, git_mode, dry_run, package_file, **kwargs): +def clone_source(outdir, git_mode, dry_run, package_file, **_kwargs): """Install packages from source""" scripts.clone_source(outdir, git_mode, dry_run, package_file) return 0 - - + + @cli.command() -@options.outdir(default='..') +@options.outdir(default="..") @options.dry_run() @options.package_file() -def update_source(outdir, dry_run, package_file, **kwargs): +def update_source(outdir, dry_run, package_file, **_kwargs): """Update packages from source""" scripts.update_source(outdir, dry_run, package_file) return 0 - - + + @cli.command() -@options.outdir(default='..') +@options.outdir(default="..") @options.dry_run() @options.from_source() @options.package_file() -def install(outdir, dry_run, from_source, package_file, **kwargs): +def install(outdir, dry_run, from_source, package_file, **_kwargs): """Install rail packages one by one, to be fault tolerant""" scripts.install(outdir, from_source, dry_run, package_file) return 0 - + @cli.command() -@options.outdir(default='..') +@options.outdir(default="..") @options.print_all() @options.print_packages() @options.print_namespaces() diff --git a/src/rail/cli/options.py b/src/rail/cli/options.py index 473bd416..ce736edb 100644 --- a/src/rail/cli/options.py +++ b/src/rail/cli/options.py @@ -1,6 +1,6 @@ import enum -from functools import partial, wraps -from typing import Any, Callable, Type, TypeVar, cast +from functools import partial +from typing import Any, Type, TypeVar import click @@ -20,12 +20,13 @@ "package_file", "skip", "inputs", - "verbose_download" + "verbose_download", ] class GitMode(enum.Enum): """Choose git clone mode""" + ssh = 0 https = 1 cli = 2 @@ -37,9 +38,9 @@ class GitMode(enum.Enum): class EnumChoice(click.Choice): """A version of click.Choice specialized for enum types""" - def __init__(self, enum: EnumType_co, case_sensitive: bool = True) -> None: - self._enum = enum - super().__init__(list(enum.__members__.keys()), case_sensitive=case_sensitive) + def __init__(self, the_enum: EnumType_co, case_sensitive: bool = True) -> None: + self._enum = the_enum + super().__init__(list(the_enum.__members__.keys()), case_sensitive=case_sensitive) def convert(self, value: Any, param, ctx) -> EnumType_co: converted_str = super().convert(value, param, ctx) @@ -50,7 +51,9 @@ class PartialOption: """Wraps click.option with partial arguments for convenient reuse""" def __init__(self, *param_decls: Any, **kwargs: Any) -> None: - 
self._partial = partial(click.option, *param_decls, cls=partial(click.Option), **kwargs)
+        self._partial = partial(
+            click.option, *param_decls, cls=partial(click.Option), **kwargs
+        )
 
     def __call__(self, *args: Any, **kwargs: Any) -> Any:
         return self._partial(*args, **kwargs)
@@ -60,7 +63,9 @@ class PartialArgument:
     """Wraps click.argument with partial arguments for convenient reuse"""
 
     def __init__(self, *param_decls: Any, **kwargs: Any) -> None:
-        self._partial = partial(click.argument, *param_decls, cls=partial(click.Argument), **kwargs)
+        self._partial = partial(
+            click.argument, *param_decls, cls=partial(click.Argument), **kwargs
+        )
 
     def __call__(self, *args: Any, **kwargs: Any) -> Any:
         return self._partial(*args, **kwargs)
@@ -94,7 +99,7 @@ def __call__(self, *args: Any, **kwargs: Any) -> Any:
 git_mode = PartialOption(
     "--git-mode",
     type=EnumChoice(GitMode),
-    default='ssh',
+    default="ssh",
     help="Git Clone mode",
 )
@@ -148,20 +153,15 @@ def __call__(self, *args: Any, **kwargs: Any) -> Any:
     help="Skip files",
 )
 
-inputs = PartialArgument(
-    "inputs",
-    nargs=-1
-)
+inputs = PartialArgument("inputs", nargs=-1)
 
 verbose_download = PartialOption(
-    "-v",
-    "--verbose",
-    help="Verbose output when downloading",
-    is_flag=True
+    "-v", "--verbose", help="Verbose output when downloading", is_flag=True
 )
 
 bpz_demo_data = PartialOption(
     "--bpz-demo-data",
-    help="Download data that is explicitly only for use in the bpz demo and nowhere else (it is dummy data that will not make sense)",
-    is_flag=True
+    help="Download data that is explicitly only for use in the bpz demo and nowhere else "
+    "(it is dummy data that will not make sense)",
+    is_flag=True,
 )
diff --git a/src/rail/cli/scripts.py b/src/rail/cli/scripts.py
index 78e6616a..7a29852c 100644
--- a/src/rail/cli/scripts.py
+++ b/src/rail/cli/scripts.py
@@ -1,6 +1,4 @@
-import sys
 import os
-import glob
 
 import yaml
 
 import rail.stages
@@ -9,8 +7,7 @@
 from rail.core.utils import RAILDIR
 
 
-def render_nb(outdir, clear_output, dry_run, inputs, skip, **kwargs):
-
+def render_nb(outdir, clear_output, dry_run, inputs, skip, **_kwargs):
     command = "jupyter nbconvert"
     options = "--to html"
 
@@ -19,10 +16,10 @@ def render_nb(outdir, clear_output, dry_run, inputs, skip, **kwargs):
     for nb_file in inputs:
         if nb_file in skip:
             continue
-        subdir = os.path.dirname(nb_file).split('/')[-1]
+        subdir = os.path.dirname(nb_file).split("/")[-1]
         basename = os.path.splitext(os.path.basename(nb_file))[0]
-        outfile = os.path.join('..', '..', outdir, f"{subdir}/{basename}.html")
-        relpath = os.path.join(outdir, f'{subdir}')
+        outfile = os.path.join("..", "..", outdir, f"{subdir}/{basename}.html")
+        relpath = os.path.join(outdir, f"{subdir}")
 
         try:
             print(relpath)
@@ -50,18 +47,17 @@ def render_nb(outdir, clear_output, dry_run, inputs, skip, **kwargs):
     if failed_notebooks:
         raise ValueError(f"The following notebooks failed {str(failed_notebooks)}")
 
-
-def clone_source(outdir, git_mode, dry_run, package_file):
-    with open(package_file) as pfile:
+def clone_source(outdir, git_mode, dry_run, package_file):
+    with open(package_file, encoding="utf-8") as pfile:
         package_dict = yaml.safe_load(pfile)
 
-    for key, val in package_dict.items():
+    for key, _val in package_dict.items():
         if os.path.exists(f"{outdir}/{key}"):
             print(f"Skipping existing {outdir}/{key}")
             continue
-
+
         if git_mode == GitMode.ssh:
             com_line = f"git clone git@github.com:LSSTDESC/{key}.git {outdir}/{key}"
         elif git_mode == GitMode.https:
@@ -74,43 +70,40 @@
         else:
             os.system(com_line)
 
-
-def update_source(outdir, dry_run, package_file):
-    with open(package_file) as pfile:
+def update_source(outdir, dry_run, package_file):
+    with open(package_file, encoding="utf-8") as pfile:
         package_dict = yaml.safe_load(pfile)
 
-    currentpath = os.path.abspath('.')
-    for key, val in package_dict.items():
+    currentpath = os.path.abspath(".")
+    for key, _val in package_dict.items():
         abspath = os.path.abspath(f"{outdir}/{key}")
         if os.path.exists(f"{outdir}/{key}") is not True:
             print(f"Package {outdir}/{key} does not exist!")
-            continue
-
+            continue
+
         com_line = f"cd {abspath} && git pull && cd {currentpath}"
 
         if dry_run:
             print(com_line)
         else:
             os.system(com_line)
 
-
-def install(outdir, from_source, dry_run, package_file):
-    with open(package_file) as pfile:
+def install(outdir, from_source, dry_run, package_file):
+    with open(package_file, encoding="utf-8") as pfile:
         package_dict = yaml.safe_load(pfile)
 
     for key, val in package_dict.items():
-
         if not from_source:
             com_line = f"pip install {val}"
-        else:
+        else:
             if not os.path.exists(f"{outdir}/{key}"):
                 print(f"Skipping missing {outdir}/{key}")
                 continue
             com_line = f"pip install -e {outdir}/{key}"
-
+
         if dry_run:
             print(com_line)
         else:
@@ -118,27 +111,26 @@
 
 
 def info(**kwargs):
-
     rail.stages.import_and_attach_all()
 
-    print_all = kwargs.get('print_all', False)
-    if kwargs.get('print_packages') or print_all:
+    print_all = kwargs.get("print_all", False)
+    if kwargs.get("print_packages") or print_all:
         print("======= Printing RAIL packages ==============")
         RailEnv.print_rail_packages()
         print("\n\n")
-    if kwargs.get('print_namespaces') or print_all:
+    if kwargs.get("print_namespaces") or print_all:
         print("======= Printing RAIL namespaces ==============")
         RailEnv.print_rail_namespaces()
         print("\n\n")
-    if kwargs.get('print_modules') or print_all:
+    if kwargs.get("print_modules") or print_all:
         print("======= Printing RAIL modules ==============")
         RailEnv.print_rail_modules()
         print("\n\n")
-    if kwargs.get('print_tree') or print_all:
+    if kwargs.get("print_tree") or print_all:
         print("======= Printing RAIL source tree ==============")
         RailEnv.print_rail_namespace_tree()
         print("\n\n")
-    if kwargs.get('print_stages') or print_all:
+    if kwargs.get("print_stages") or print_all:
         print("======= Printing RAIL stages ==============")
         RailEnv.print_rail_stage_dict()
         print("\n\n")
@@ -146,20 +138,25 @@ def info(**kwargs):
 
 def get_data(verbose, **kwargs):  # pragma: no cover
     if kwargs.get("bpz_demo_data"):
-        # The bpz demo data is quarantined into its own flag, as it contains some
+        # The bpz demo data is quarantined into its own flag, as it contains some
         # non-physical features that would add systematics if run on any real data.
         # This data should NOT be used for any science with real data!
bpz_local_abs_path = os.path.join( - RAILDIR, "rail/examples_data/estimation_data/data/nonphysical_dc2_templates.tar" + RAILDIR, + "rail/examples_data/estimation_data/data/nonphysical_dc2_templates.tar", + ) + bpz_remote_path = ( + "https://portal.nersc.gov/cfs/lsst/PZ/nonphysical_dc2_templates.tar" ) - bpz_remote_path = "https://portal.nersc.gov/cfs/lsst/PZ/nonphysical_dc2_templates.tar" print(f"Check for bpz demo data: {bpz_local_abs_path}") if not os.path.exists(bpz_local_abs_path): os.system(f"curl -o {bpz_local_abs_path} {bpz_remote_path} --create-dirs") print("Downloaded bpz demo data.") else: print("Already have bpz demo data.") - print("\n(Note: you can run get-data without the bpz-demo-data flag to download standard data.)") + print( + "\n(Note: you can run get-data without the bpz-demo-data flag to download standard data.)" + ) else: data_files = [ @@ -171,6 +168,10 @@ def get_data(verbose, **kwargs): # pragma: no cover for data_file in data_files: local_abs_path = os.path.join(RAILDIR, data_file["local_path"]) if verbose: - print(f"Check file exists: {local_abs_path} ({os.path.exists(local_abs_path)})") + print( + f"Check file exists: {local_abs_path} ({os.path.exists(local_abs_path)})" + ) if not os.path.exists(local_abs_path): - os.system(f'curl -o {local_abs_path} {data_file["remote_path"]} --create-dirs') + os.system( + f'curl -o {local_abs_path} {data_file["remote_path"]} --create-dirs' + ) diff --git a/src/rail/core/__init__.py b/src/rail/core/__init__.py index 7a86fd90..3da2451e 100644 --- a/src/rail/core/__init__.py +++ b/src/rail/core/__init__.py @@ -1,23 +1,27 @@ """Core code for RAIL""" import pkgutil +import os import setuptools import rail -import os + +from .stage import RailPipeline, RailStage + +# from .utilPhotometry import PhotormetryManipulator, HyperbolicSmoothing, HyperbolicMagnitudes +from .util_stages import ColumnMapper, RowSelector, TableConverter +from .introspection import RailEnv +from .point_estimation import PointEstimationMixin + def find_version(): """Find the version""" # setuptools_scm should install a # file _version alongside this one. - from . import _version + from . 
import _version  # pylint: disable=import-outside-toplevel
+
     return _version.version
 
+
 try:
     __version__ = find_version()
-except ImportError:  # pragma: no cover
+except ImportError:  # pragma: no cover
     __version__ = "unknown"
-
-from .stage import RailPipeline, RailStage
-#from .utilPhotometry import PhotormetryManipulator, HyperbolicSmoothing, HyperbolicMagnitudes
-from .util_stages import ColumnMapper, RowSelector, TableConverter
-from .introspection import RailEnv
-from .point_estimation import PointEstimationMixin
diff --git a/src/rail/core/algo_utils.py b/src/rail/core/algo_utils.py
index 8f2c2641..fff1f002 100644
--- a/src/rail/core/algo_utils.py
+++ b/src/rail/core/algo_utils.py
@@ -1,19 +1,27 @@
 """Utility functions to test algorithms"""
 import os
+import scipy.special
 
 from rail.core.stage import RailStage
 from rail.core.utils import RAILDIR
 from rail.core.data import TableHandle
-import scipy.special
 
-sci_ver_str = scipy.__version__.split('.')
+
+sci_ver_str = scipy.__version__.split(".")
 
-traindata = os.path.join(RAILDIR, 'rail/examples_data/testdata/training_100gal.hdf5')
-validdata = os.path.join(RAILDIR, 'rail/examples_data/testdata/validation_10gal.hdf5')
+traindata = os.path.join(RAILDIR, "rail/examples_data/testdata/training_100gal.hdf5")
+validdata = os.path.join(RAILDIR, "rail/examples_data/testdata/validation_10gal.hdf5")
 DS = RailStage.data_store
 DS.__class__.allow_overwrite = True
 
 
-def one_algo(key, single_trainer, single_estimator, train_kwargs, estim_kwargs, is_classifier=False):
+def one_algo(
+    key,
+    single_trainer,
+    single_estimator,
+    train_kwargs,
+    estim_kwargs,
+    is_classifier=False,
+):
     """
     A basic test of running an estimator subclass
     Run inform, write temporary trained model to
@@ -22,17 +30,17 @@ def one_algo(key, single_trainer, single_estimator, train_kwargs, estim_kwargs,
     both datasets.
""" DS.clear() - training_data = DS.read_file('training_data', TableHandle, traindata) - validation_data = DS.read_file('validation_data', TableHandle, validdata) + training_data = DS.read_file("training_data", TableHandle, traindata) + validation_data = DS.read_file("validation_data", TableHandle, validdata) if single_trainer is not None: train_pz = single_trainer.make_stage(**train_kwargs) train_pz.inform(training_data) pz = single_estimator.make_stage(name=key, **estim_kwargs) - if is_classifier==False: + if not is_classifier: estim = pz.estimate(validation_data) - elif is_classifier==True: #pragma: no cover + elif is_classifier: # pragma: no cover estim = pz.classify(validation_data) pz_2 = None estim_2 = estim @@ -40,35 +48,37 @@ def one_algo(key, single_trainer, single_estimator, train_kwargs, estim_kwargs, estim_3 = estim copy_estim_kwargs = estim_kwargs.copy() - model_file = copy_estim_kwargs.pop('model', 'None') + model_file = copy_estim_kwargs.pop("model", "None") - if model_file != 'None': - copy_estim_kwargs['model'] = model_file + if model_file != "None": + copy_estim_kwargs["model"] = model_file pz_2 = single_estimator.make_stage(name=f"{pz.name}_copy", **copy_estim_kwargs) - if is_classifier==False: + if not is_classifier: estim_2 = pz_2.estimate(validation_data) - elif is_classifier==True: #pragma: no cover + elif is_classifier: # pragma: no cover estim_2 = pz_2.classify(validation_data) - if single_trainer is not None and 'model' in single_trainer.output_tags(): + if single_trainer is not None and "model" in single_trainer.output_tags(): copy3_estim_kwargs = estim_kwargs.copy() - copy3_estim_kwargs['model'] = train_pz.get_handle('model') - pz_3 = single_estimator.make_stage(name=f"{pz.name}_copy3", **copy3_estim_kwargs) - if is_classifier==False: + copy3_estim_kwargs["model"] = train_pz.get_handle("model") + pz_3 = single_estimator.make_stage( + name=f"{pz.name}_copy3", **copy3_estim_kwargs + ) + if not is_classifier: estim_3 = pz_3.estimate(validation_data) - elif is_classifier==True: #pragma: no cover + elif is_classifier: # pragma: no cover estim_3 = pz_3.classify(validation_data) - os.remove(pz.get_output(pz.get_aliased_tag('output'), final_name=True)) + os.remove(pz.get_output(pz.get_aliased_tag("output"), final_name=True)) if pz_2 is not None: - os.remove(pz_2.get_output(pz_2.get_aliased_tag('output'), final_name=True)) + os.remove(pz_2.get_output(pz_2.get_aliased_tag("output"), final_name=True)) if pz_3 is not None: - os.remove(pz_3.get_output(pz_3.get_aliased_tag('output'), final_name=True)) - model_file = estim_kwargs.get('model', 'None') - if model_file != 'None': + os.remove(pz_3.get_output(pz_3.get_aliased_tag("output"), final_name=True)) + model_file = estim_kwargs.get("model", "None") + if model_file != "None": try: os.remove(model_file) - except FileNotFoundError: #pragma: no cover + except FileNotFoundError: # pragma: no cover pass return estim.data, estim_2.data, estim_3.data diff --git a/src/rail/core/common_params.py b/src/rail/core/common_params.py index 2113624f..6356bf7b 100644 --- a/src/rail/core/common_params.py +++ b/src/rail/core/common_params.py @@ -3,33 +3,47 @@ from ceci.config import StageParameter as Param from ceci.config import StageConfig -lsst_bands = 'ugrizy' -lsst_mag_cols = [f'mag_{band}_lsst' for band in lsst_bands] -lsst_mag_err_cols = [f'mag_err_{band}_lsst' for band in lsst_bands] +lsst_bands = "ugrizy" +lsst_mag_cols = [f"mag_{band}_lsst" for band in lsst_bands] +lsst_mag_err_cols = [f"mag_err_{band}_lsst" for band in lsst_bands] 
 lsst_def_maglims = dict(
     mag_u_lsst=27.79,
     mag_g_lsst=29.04,
     mag_r_lsst=29.06,
     mag_i_lsst=28.62,
     mag_z_lsst=27.98,
-    mag_y_lsst=27.05
+    mag_y_lsst=27.05,
 )
 
 
 SHARED_PARAMS = StageConfig(
-    hdf5_groupname=Param(str, "photometry", msg="name of hdf5 group for data, if None, then set to ''"),
+    hdf5_groupname=Param(
+        str, "photometry", msg="name of hdf5 group for data, if None, then set to ''"
+    ),
     zmin=Param(float, 0.0, msg="The minimum redshift of the z grid"),
     zmax=Param(float, 3.0, msg="The maximum redshift of the z grid"),
     nzbins=Param(int, 301, msg="The number of gridpoints in the z grid"),
     dz=Param(float, 0.01, msg="delta z in grid"),
-    nondetect_val=Param(float, 99.0, msg="value to be replaced with magnitude limit for non detects"),
-    bands=Param(list, lsst_mag_cols, msg="Names of columns for magnitgude by filter band"),
-    err_bands=Param(list, lsst_mag_err_cols, msg="Names of columns for magnitgude errors by filter band"),
+    nondetect_val=Param(
+        float, 99.0, msg="value to be replaced with magnitude limit for non detects"
+    ),
+    bands=Param(
+        list, lsst_mag_cols, msg="Names of columns for magnitude by filter band"
+    ),
+    err_bands=Param(
+        list,
+        lsst_mag_err_cols,
+        msg="Names of columns for magnitude errors by filter band",
+    ),
     mag_limits=Param(dict, lsst_def_maglims, msg="Limiting magnitudes by filter"),
     ref_band=Param(str, "mag_i_lsst", msg="band to use in addition to colors"),
-    redshift_col=Param(str, 'redshift', msg="name of redshift column"),
-    calculated_point_estimates=Param(dtype=list, default=[],
-                                     msg="List of strings defining which point estimates to automatically calculate using `qp.Ensemble`. Options include, 'mean', 'mode', 'median'.")
+    redshift_col=Param(str, "redshift", msg="name of redshift column"),
+    calculated_point_estimates=Param(
+        dtype=list,
+        default=[],
+        msg="List of strings defining which point estimates to automatically calculate using `qp.Ensemble`. "
+        "Options include 'mean', 'mode', 'median'.",
+    ),
 )
@@ -43,10 +57,10 @@ def set_param_default(param_name, default_value):
     try:
         SHARED_PARAMS.get(param_name).set_default(default_value)
     except AttributeError as msg:  # pragma: no cover
-        raise KeyError(f"No shared parameter {param_name} in SHARED_PARAMS")
+        raise KeyError(f"No shared parameter {param_name} in SHARED_PARAMS") from msg
+
 
 def set_param_defaults(**kwargs):  # pragma: no cover
     """Change the default value of several of the shared parameters"""
     for key, val in kwargs.items():
         set_param_default(key, val)
-
diff --git a/src/rail/core/data.py b/src/rail/core/data.py
index 73739ea1..a16779bc 100644
--- a/src/rail/core/data.py
+++ b/src/rail/core/data.py
@@ -1,12 +1,11 @@
 """Rail-specific data management"""
 
 import os
-import tables_io
 import pickle
+import tables_io
 import qp
 
-
 class DataHandle:
     """Class to act as a handle for a bit of data.
    Associating it with a file and providing tools to read & write it to that file
@@ -22,10 +21,11 @@ class DataHandle:
     creator : str or None
         The name of the stage that created this data handle
     """
-    suffix = ''
+
+    suffix = ""
 
     def __init__(self, tag, data=None, path=None, creator=None):
-        """Constructor """
+        """Constructor"""
         self.tag = tag
         if data is not None:
             self._validate_data(data)
@@ -52,14 +52,14 @@ def open(self, **kwargs):
 
     @classmethod
     def _open(cls, path, **kwargs):
-        raise NotImplementedError("DataHandle._open")  #pragma: no cover
+        raise NotImplementedError("DataHandle._open")  # pragma: no cover
 
-    def close(self, **kwargs):  #pylint: disable=unused-argument
-        """Close """
+    def close(self, **kwargs):  # pylint: disable=unused-argument
+        """Close"""
         self.fileObj = None
 
     def read(self, force=False, **kwargs):
-        """Read and return the data from the associated file """
+        """Read and return the data from the associated file"""
         if self.data is not None and not force:
             return self.data
         self.set_data(self._read(os.path.expandvars(self.path), **kwargs))
@@ -73,59 +73,74 @@ def __call__(self, **kwargs):
 
     @classmethod
     def _read(cls, path, **kwargs):
-        raise NotImplementedError("DataHandle._read")  #pragma: no cover
+        raise NotImplementedError("DataHandle._read")  # pragma: no cover
 
     def write(self, **kwargs):
-        """Write the data to the associatied file """
+        """Write the data to the associated file"""
         if self.path is None:
-            raise ValueError("TableHandle.write() called but path has not been specified")
+            raise ValueError(
+                "TableHandle.write() called but path has not been specified"
+            )
         if self.data is None:
-            raise ValueError(f"TableHandle.write() called for path {self.path} with no data")
+            raise ValueError(
+                f"TableHandle.write() called for path {self.path} with no data"
+            )
         outdir = os.path.dirname(os.path.abspath(os.path.expandvars(self.path)))
-        if not os.path.exists(outdir):  #pragma: no cover
+        if not os.path.exists(outdir):  # pragma: no cover
             os.makedirs(outdir, exist_ok=True)
         return self._write(self.data, os.path.expandvars(self.path), **kwargs)
 
     @classmethod
     def _write(cls, data, path, **kwargs):
-        raise NotImplementedError("DataHandle._write")  #pragma: no cover
+        raise NotImplementedError("DataHandle._write")  # pragma: no cover
 
     def initialize_write(self, data_length, **kwargs):
         """Initialize file to be written by chunks"""
-        if self.path is None:  #pragma: no cover
-            raise ValueError("TableHandle.write() called but path has not been specified")
-        self.groups, self.fileObj = self._initialize_write(self.data, os.path.expandvars(self.path), data_length, **kwargs)
+        if self.path is None:  # pragma: no cover
+            raise ValueError(
+                "TableHandle.write() called but path has not been specified"
+            )
+        self.groups, self.fileObj = self._initialize_write(
+            self.data, os.path.expandvars(self.path), data_length, **kwargs
+        )
 
     @classmethod
     def _initialize_write(cls, data, path, data_length, **kwargs):
-        raise NotImplementedError("DataHandle._initialize_write")  #pragma: no cover
+        raise NotImplementedError("DataHandle._initialize_write")  # pragma: no cover
 
     def write_chunk(self, start, end, **kwargs):
-        """Write the data to the associatied file """
+        """Write the data to the associated file"""
         if self.data is None:
-            raise ValueError(f"TableHandle.write_chunk() called for path {self.path} with no data")
+            raise ValueError(
+                f"TableHandle.write_chunk() called for path {self.path} with no data"
+            )
         if self.fileObj is None:
-            raise ValueError(f"TableHandle.write_chunk() called before open for {self.tag} : 
{self.path}") - return self._write_chunk(self.data, self.fileObj, self.groups, start, end, **kwargs) - + raise ValueError( + f"TableHandle.write_chunk() called before open for {self.tag} : {self.path}" + ) + return self._write_chunk( + self.data, self.fileObj, self.groups, start, end, **kwargs + ) @classmethod def _write_chunk(cls, data, fileObj, groups, start, end, **kwargs): - raise NotImplementedError("DataHandle._write_chunk") #pragma: no cover + raise NotImplementedError("DataHandle._write_chunk") # pragma: no cover def finalize_write(self, **kwargs): """Finalize and close file written by chunks""" - if self.fileObj is None: #pragma: no cover - raise ValueError(f"TableHandle.finalize_wite() called before open for {self.tag} : {self.path}") + if self.fileObj is None: # pragma: no cover + raise ValueError( + f"TableHandle.finalize_wite() called before open for {self.tag} : {self.path}" + ) self._finalize_write(self.data, self.fileObj, **kwargs) @classmethod def _finalize_write(cls, data, fileObj, **kwargs): - raise NotImplementedError("DataHandle._finalize_write") #pragma: no cover + raise NotImplementedError("DataHandle._finalize_write") # pragma: no cover def iterator(self, **kwargs): """Iterator over the data""" - #if self.data is not None: + # if self.data is not None: # for i in range(1): # yield i, -1, self.data return self._iterator(self.path, **kwargs) @@ -137,7 +152,7 @@ def set_data(self, data, partial=False): self.partial = partial @classmethod - def _validate_data(cls, data): + def _validate_data(cls, data): # pylint: disable=unused-argument """Make sure that the right type of data is being passed in""" return @@ -145,27 +160,26 @@ def size(self, **kwargs): """Return the size of the data associated to this handle""" return self._size(self.path, **kwargs) - @classmethod - def _size(cls, path, **kwargs): - raise NotImplementedError("DataHandle._size") #pragma: no cover + def _size(self, path, **kwargs): + raise NotImplementedError("DataHandle._size") # pragma: no cover @classmethod def _iterator(cls, path, **kwargs): - raise NotImplementedError("DataHandle._iterator") #pragma: no cover + raise NotImplementedError("DataHandle._iterator") # pragma: no cover @property def has_data(self): - """Return true if the data for this handle are loaded """ + """Return true if the data for this handle are loaded""" return self.data is not None @property def has_path(self): - """Return true if the path for the associated file is defined """ + """Return true if the path for the associated file is defined""" return self.path is not None @property def is_written(self): - """Return true if the associated file has been written """ + """Return true if the associated file has been written""" if self.path is None: return False return os.path.exists(os.path.expandvars(self.path)) @@ -185,16 +199,15 @@ def __str__(self): @classmethod def make_name(cls, tag): - """Construct and return file name for a particular data tag """ + """Construct and return file name for a particular data tag""" if cls.suffix: return f"{tag}.{cls.suffix}" - else: - return tag #pragma: no cover + return tag # pragma: no cover class TableHandle(DataHandle): - """DataHandle for single tables of data - """ + """DataHandle for single tables of data""" + suffix = None @classmethod @@ -206,20 +219,19 @@ def _open(cls, path, **kwargs): This will simply open the file and return a file-like object to the caller. 
         It will not read or cache the data
         """
-        return tables_io.io.io_open(path, **kwargs)  #pylint: disable=no-member
+        return tables_io.io.io_open(path, **kwargs)  # pylint: disable=no-member
 
     @classmethod
     def _read(cls, path, **kwargs):
-        """Read and return the data from the associated file """
+        """Read and return the data from the associated file"""
         return tables_io.read(path, **kwargs)
 
     @classmethod
     def _write(cls, data, path, **kwargs):
-        """Write the data to the associatied file """
+        """Write the data to the associated file"""
         return tables_io.write(data, path, **kwargs)
 
-    @classmethod
-    def _size(cls, path, **kwargs):
+    def _size(self, path, **kwargs):
         return tables_io.io.getInputDataLengthHdf5(path, **kwargs)
 
     @classmethod
@@ -230,13 +242,16 @@ def _iterator(cls, path, **kwargs):
 
 class Hdf5Handle(TableHandle):  # pragma: no cover
     """DataHandle for a table written to HDF5"""
-    suffix = 'hdf5'
+
+    suffix = "hdf5"
 
     @classmethod
     def _initialize_write(cls, data, path, data_length, **kwargs):
         initial_dict = cls._get_allocation_kwds(data, data_length)
-        comm = kwargs.get('communicator', None)
-        group, fout = tables_io.io.initializeHdf5WriteSingle(path, groupname=None, comm=comm, **initial_dict)
+        comm = kwargs.get("communicator", None)
+        group, fout = tables_io.io.initializeHdf5WriteSingle(
+            path, groupname=None, comm=comm, **initial_dict
+        )
         return group, fout
 
     @classmethod
@@ -256,20 +271,23 @@ def _write_chunk(cls, data, fileObj, groups, start, end, **kwargs):
     def _finalize_write(cls, data, fileObj, **kwargs):
         return tables_io.io.finalizeHdf5Write(fileObj, **kwargs)
 
+
 class FitsHandle(TableHandle):
     """DataHandle for a table written to fits"""
-    suffix = 'fits'
+
+    suffix = "fits"
 
 
 class PqHandle(TableHandle):
     """DataHandle for a parquet table"""
-    suffix = 'pq'
+
+    suffix = "pq"
 
 
 class QPHandle(DataHandle):
-    """DataHandle for qp ensembles
-    """
-    suffix = 'hdf5'
+    """DataHandle for qp ensembles"""
+
+    suffix = "hdf5"
 
     @classmethod
     def _open(cls, path, **kwargs):
@@ -280,21 +298,21 @@ def _open(cls, path, **kwargs):
         This will simply open the file and return a file-like object to the caller.
         It will not read or cache the data
         """
-        return tables_io.io.io_open(path, **kwargs)  #pylint: disable=no-member
+        return tables_io.io.io_open(path, **kwargs)  # pylint: disable=no-member
 
     @classmethod
     def _read(cls, path, **kwargs):
-        """Read and return the data from the associated file """
+        """Read and return the data from the associated file"""
         return qp.read(path)
 
     @classmethod
     def _write(cls, data, path, **kwargs):
-        """Write the data to the associatied file """
+        """Write the data to the associated file"""
         return data.write_to(path)
 
     @classmethod
     def _initialize_write(cls, data, path, data_length, **kwargs):
-        comm = kwargs.get('communicator', None)
+        comm = kwargs.get("communicator", None)
         return data.initializeHdf5Write(path, data_length, comm)
 
     @classmethod
@@ -308,87 +326,95 @@ def _finalize_write(cls, data, fileObj, **kwargs):
 
     @classmethod
     def _validate_data(cls, data):
         if not isinstance(data, qp.Ensemble):
-            raise TypeError(f"Expected `data` to be a `qp.Ensemble`, but {type(data)} was provided. Perhaps you meant to use `TableHandle`?")
+            raise TypeError(
+                f"Expected `data` to be a `qp.Ensemble`, but {type(data)} was provided."
+                " Perhaps you meant to use `TableHandle`?"
+            )
 
-    # @classmethod
-    def _size(cls, path, **kwargs):
-        if path == 'None':
-            return cls.data.npdf
-        return tables_io.io.getInputDataLengthHdf5(path, groupname='data')
+    def _size(self, path, **kwargs):
+        if path == "None":
+            return self.data.npdf
+        return tables_io.io.getInputDataLengthHdf5(path, groupname="data")
 
     @classmethod
     def _iterator(cls, path, **kwargs):
         """Iterate over the data"""
-        kwargs.pop('groupname','None')
+        kwargs.pop("groupname", "None")
         return qp.iterator(path, **kwargs)
 
+
 def default_model_read(modelfile):
     """Default function to read model files, simply used pickle.load"""
-    return pickle.load(open(modelfile, 'rb'))
+    return pickle.load(open(modelfile, "rb"))
 
 
 def default_model_write(model, path):
     """Write the model, this default implementation uses pickle"""
-    with open(path, 'wb') as fout:
+    with open(path, "wb") as fout:  # binary mode takes no encoding argument
         pickle.dump(obj=model, file=fout, protocol=pickle.HIGHEST_PROTOCOL)
 
 
 class ModelDict(dict):
     """
-    A specialized dict to keep track of individual estimation models objects: this is just a dict these additional features
+    A specialized dict to keep track of individual estimation model objects:
+    this is just a dict with these additional features
 
     1.  Keys are paths
 
     2.  There is a read(path, force=False) method that reads a model object and inserts it into the dictionary
 
     3.  There is a single static instance of this class
     """
-    def open(self, path, mode, **kwargs):  #pylint: disable=no-self-use
+
+    def open(self, path, mode, **kwargs):
         """Open the file and return the file handle"""
-        return open(path, mode, **kwargs)
+        if "b" in mode:  # binary handles (e.g. pickled models) reject encoding=
+            return open(path, mode, **kwargs)
+        encoding = kwargs.pop("encoding", "utf-8")
+        return open(path, mode, encoding=encoding, **kwargs)
 
-    def read(self, path, force=False, reader=None, **kwargs):  #pylint: disable=unused-argument
+    def read(
+        self, path, force=False, reader=None, **kwargs
+    ):  # pylint: disable=unused-argument
        """Read a model into this dict"""
        if reader is None:
            reader = default_model_read
        if force or path not in self:
            model = reader(path)
-            self.__setitem__(path, model)
+            self[path] = model
            return model
        return self[path]

-    def write(self, model, path, force=False, writer=None, **kwargs):  #pylint: disable=unused-argument
+    def write(
+        self, model, path, force=False, writer=None, **kwargs
+    ):  # pylint: disable=unused-argument
        """Write the model, this default implementation uses pickle"""
        if writer is None:
            writer = default_model_write
        if force or path not in self or not os.path.exists(path):
-            self.__setitem__(path, model)
+            self[path] = model
            writer(model, path)
 
-
 class ModelHandle(DataHandle):
-    """DataHandle for machine learning models
-    """
-    suffix = 'pkl'
+    """DataHandle for machine learning models"""
+
+    suffix = "pkl"
 
     model_factory = ModelDict()
 
     @classmethod
     def _open(cls, path, **kwargs):
-        """Open and return the associated file
-        """
+        """Open and return the associated file"""
         kwcopy = kwargs.copy()
-        if kwcopy.pop('mode', 'r') == 'w':
-            return cls.model_factory.open(path, mode='wb', **kwcopy)
+        if kwcopy.pop("mode", "r") == "w":
+            return cls.model_factory.open(path, mode="wb", **kwcopy)
         return cls.model_factory.read(path, **kwargs)
 
     @classmethod
     def _read(cls, path, **kwargs):
-        """Read and return the data from the associated file """
+        """Read and return the data from the associated file"""
         return cls.model_factory.read(path, **kwargs)
 
     @classmethod
     def _write(cls, data, path, **kwargs):
-        """Write the data to the associatied file """
+        """Write the data to the associated file"""
         return cls.model_factory.write(data, path, **kwargs)
 
@@ -399,10 +425,11 @@ class DataStore(dict):
     1) associates data products with keys
     2) provides functions to read and write the various data products to associated files
     """
+
     allow_overwrite = False
 
     def __init__(self, **kwargs):
-        """ Build from keywords
+        """Build from keywords
 
         Note
         ----
@@ -413,7 +440,7 @@ def __init__(self, **kwargs):
             self[key] = val
 
     def __str__(self):
-        """ Override __str__ casting to deal with `TableHandle` objects in the map """
+        """Override __str__ casting to deal with `TableHandle` objects in the map"""
         s = "{"
         for key, val in self.items():
             s += f"  {key}:{val}\n"
@@ -421,58 +448,63 @@ def __str__(self):
         return s
 
     def __repr__(self):
-        """ A custom representation """
+        """A custom representation"""
         s = "DataStore\n"
         s += self.__str__()
         return s
 
     def __setitem__(self, key, value):
-        """ Override the __setitem__ to work with `TableHandle` """
+        """Override the __setitem__ to work with `TableHandle`"""
         if not isinstance(value, DataHandle):
-            raise TypeError(f"Can only add objects of type DataHandle to DataStore, not {type(value)}")
+            raise TypeError(
+                f"Can only add objects of type DataHandle to DataStore, not {type(value)}"
+            )
         check = self.get(key)
         if check is not None and not self.allow_overwrite:
-            raise ValueError(f"DataStore already has an item with key {key}, of type {type(check)}, created by {check.creator}")
+            raise ValueError(
+                f"DataStore already has an item with key {key}, "
+                f"of type {type(check)}, created by {check.creator}"
+            )
         dict.__setitem__(self, key, value)
         return value
 
     def __getattr__(self, key):
-        """ Allow attribute-like parameter access """
+        """Allow attribute-like parameter access"""
         try:
             return self.__getitem__(key)
         except KeyError as msg:
             # Kludge to get docstrings to work
-            if key in ['__objclass__']:  #pragma: no cover
+            if key in ["__objclass__"]:  # pragma: no cover
                 return None
             raise KeyError from msg
 
     def __setattr__(self, key, value):
-        """ Allow attribute-like parameter setting """
+        """Allow attribute-like parameter setting"""
         return self.__setitem__(key, value)
 
-    def add_data(self, key, data, handle_class, path=None, creator='DataStore'):
-        """ Create a handle for some data, and insert it into the DataStore """
+    def add_data(self, key, data, handle_class, path=None, creator="DataStore"):
+        """Create a handle for some data, and insert it into the DataStore"""
         handle = handle_class(key, path=path, data=data, creator=creator)
         self[key] = handle
         return handle
 
-    def read_file(self, key, handle_class, path, creator='DataStore', **kwargs):
-        """ Create a handle, use it to read a file, and insert it into the DataStore """
+    def read_file(self, key, handle_class, path, creator="DataStore", **kwargs):
+        """Create a handle, use it to read a file, and insert it into the DataStore"""
         handle = handle_class(key, path=path, data=None, creator=creator)
         handle.read(**kwargs)
         self[key] = handle
         return handle
 
     def read(self, key, force=False, **kwargs):
-        """ Read the data associated to a particular key """
+        """Read the data associated to a particular key"""
         try:
             handle = self[key]
         except KeyError as msg:
             raise KeyError(f"Failed to read data {key} because {msg}") from msg
         return handle.read(force, **kwargs)
 
-    def open(self, key, mode='r', **kwargs):
-        """ Open and return the file associated to a particular key """
+    def open(self, key, mode="r", **kwargs):
+        """Open and return the file associated to a particular key"""
         try:
             handle = self[key]
         except KeyError as msg:
@@ -480,7 +512,7 @@ def open(self, key, mode='r', **kwargs):
         return handle.open(mode=mode,
**kwargs) def write(self, key, **kwargs): - """ Write the data associated to a particular key """ + """Write the data associated to a particular key""" try: handle = self[key] except KeyError as msg: @@ -488,7 +520,7 @@ def write(self, key, **kwargs): return handle.write(**kwargs) def write_all(self, force=False, **kwargs): - """ Write all the data in this DataStore """ + """Write all the data in this DataStore""" for key, handle in self.items(): local_kwargs = kwargs.get(key, {}) if handle.is_written and not force: @@ -496,9 +528,9 @@ def write_all(self, force=False, **kwargs): handle.write(**local_kwargs) - _DATA_STORE = DataStore() + def DATA_STORE(): """Return the factory instance""" return _DATA_STORE diff --git a/src/rail/core/introspection.py b/src/rail/core/introspection.py index 069a906f..aae38704 100644 --- a/src/rail/core/introspection.py +++ b/src/rail/core/introspection.py @@ -1,13 +1,12 @@ import pkgutil -import setuptools import os import importlib +import setuptools import rail class RailEnv: - PACKAGES = {} NAMESPACE_PATH_DICT = {} NAMESPACE_MODULE_DICT = {} @@ -20,10 +19,12 @@ class RailEnv: @classmethod def list_rail_packages(cls): """List all the packages that are available in the RAIL ecosystem""" - cls.PACKAGES = {pkg.name:pkg for pkg in pkgutil.iter_modules(rail.__path__, rail.__name__ + '.')} + cls.PACKAGES = { + pkg.name: pkg + for pkg in pkgutil.iter_modules(rail.__path__, rail.__name__ + ".") + } return cls.PACKAGES - @classmethod def print_rail_packages(cls): """Print all the packages that are available in the RAIL ecosystem""" @@ -31,18 +32,17 @@ def print_rail_packages(cls): cls.list_rail_packages() for pkg_name, pkg in cls.PACKAGES.items(): print(f"{pkg_name} @ {pkg[0].path}") - return @classmethod def list_rail_namespaces(cls): """List all the namespaces within rail""" cls.NAMESPACE_PATH_DICT.clear() - + for path_ in rail.__path__: namespaces = setuptools.find_namespace_packages(path_) for namespace_ in namespaces: # exclude stuff that starts with 'example' - if namespace_.find('example') == 0: + if namespace_.find("example") == 0: continue if namespace_ in cls.NAMESPACE_PATH_DICT: # pragma: no cover cls.NAMESPACE_PATH_DICT[namespace_].append(path_) @@ -51,7 +51,6 @@ def list_rail_namespaces(cls): return cls.NAMESPACE_PATH_DICT - @classmethod def print_rail_namespaces(cls): """Print all the namespaces that are available in the RAIL ecosystem""" @@ -61,8 +60,6 @@ def print_rail_namespaces(cls): print(f"Namespace {key}") for vv in val: print(f" {vv}") - return - @classmethod def list_rail_modules(cls): @@ -75,8 +72,11 @@ def list_rail_modules(cls): for key, val in cls.NAMESPACE_PATH_DICT.items(): cls.NAMESPACE_MODULE_DICT[key] = [] for vv in val: - fullpath = os.path.join(vv, key.replace('.', '/')) - modules = [pkg for pkg in pkgutil.iter_modules([fullpath], rail.__name__ + '.' + key + '.')] + fullpath = os.path.join(vv, key.replace(".", "/")) + modules = list(pkgutil.iter_modules( + [fullpath], rail.__name__ + "." + key + "." 
+                )
+            )
             for module_ in modules:
                 if module_ in cls.MODULE_DICT:  # pragma: no cover
                     cls.MODULE_DICT[module_.name].append(key)
@@ -87,13 +87,12 @@
 
         return cls.MODULE_PATH_DICT
 
-
     @classmethod
     def print_rail_modules(cls):
         """Print all the modules that are available in the RAIL ecosystem"""
         if not cls.MODULE_DICT:
             cls.list_rail_modules()
-
+
         for key, val in cls.MODULE_DICT.items():
             print(f"Module {key}")
             for vv in val:
@@ -103,8 +102,6 @@
             print(f"Namespace {key}")
             for vv in val:
                 print(f"    {vv}")
-        return
-
 
     @classmethod
     def build_rail_namespace_tree(cls):
@@ -117,27 +114,29 @@ def build_rail_namespace_tree(cls):
             cls.list_rail_packages()
 
         level_dict = {}
-        for key in cls.NAMESPACE_MODULE_DICT.keys():
-            count = key.count('.')
+        for key in cls.NAMESPACE_MODULE_DICT:
+            count = key.count(".")
             if count in level_dict:
                 level_dict[count].append(key)
             else:
                 level_dict[count] = [key]
 
         depth = max(level_dict.keys())
-        for current_depth in range(depth+1):
+        for current_depth in range(depth + 1):
             for key in level_dict[current_depth]:
-                nsname = f"rail.{key}"
+                _nsname = f"rail.{key}"
                 if current_depth == 0:
-                    nsname = f"rail.{key}"
+                    _nsname = f"rail.{key}"
                     cls.TREE[key] = cls.NAMESPACE_MODULE_DICT[key]
                 else:
-                    parent_key = '.'.join(key.split('.')[0:current_depth])
+                    parent_key = ".".join(key.split(".")[0:current_depth])
                     if parent_key in cls.TREE:
-                        cls.TREE[parent_key].append({key:cls.NAMESPACE_MODULE_DICT[key]})
+                        cls.TREE[parent_key].append(
+                            {key: cls.NAMESPACE_MODULE_DICT[key]}
+                        )
 
         return cls.TREE
-
+
     @classmethod
     def pretty_print_tree(cls, the_dict=None, indent=""):
         """Utility function to help print the namespace tree
@@ -148,7 +147,7 @@ def pretty_print_tree(cls, the_dict=None, indent=""):
         ----------
         the_dict: dict | None
             Current dictionary to print, if None it will print cls.TREE
-
+
         indent: str
             Indentation string prepended to each line
         """
@@ -164,29 +163,25 @@ def pretty_print_tree(cls, the_dict=None, indent=""):
             print(f"{indent}{pkg_type} {nsname}")
             for vv in val:
                 if isinstance(vv, dict):
-                    cls.pretty_print_tree(vv, indent=indent+"    ")
+                    cls.pretty_print_tree(vv, indent=indent + "    ")
                 else:
                     print(f"    {indent}{vv.name}")
-        return
-
 
     @classmethod
     def print_rail_namespace_tree(cls):
         """Print the namespace tree in a nice way"""
         if not cls.TREE:
             cls.build_rail_namespace_tree()
         cls.pretty_print_tree(cls.TREE)
-        return
-
+
     @classmethod
     def do_pkg_api_rst(cls, basedir, key, val):
         """Build the api rst file for a rail package"""
-
+
         api_pkg_toc = f"rail.{key} package\n"
-        api_pkg_toc += "="*len(api_pkg_toc)
+        api_pkg_toc += "=" * len(api_pkg_toc)
 
-        api_pkg_toc += \
-f"""
+        api_pkg_toc += f"""
 .. automodule:: rail.{key}
     :members:
     :undoc-members:
@@ -199,29 +194,26 @@ def do_pkg_api_rst(cls, basedir, key, val):
     :maxdepth: 4
 
 """
-
+
         for vv in val:
             if isinstance(vv, dict):  # pragma: no cover
-                for k3, v3 in vv.items():
+                for _k3, v3 in vv.items():
                     for v4 in v3:
                         api_pkg_toc += f"    {v4.name}.rst\n"
             else:
-                api_pkg_toc += f"    {vv.name}.rst\n"
+                api_pkg_toc += f"    {vv.name}.rst\n"
 
-        with open(os.path.join(basedir, 'api', f"rail.{key}.rst"), 'w') as apitocfile:
+        with open(os.path.join(basedir, "api", f"rail.{key}.rst"), "w", encoding="utf-8") as apitocfile:
             apitocfile.write(api_pkg_toc)
 
-        return
-
     @classmethod
     def do_namespace_api_rst(cls, basedir, key, val):
         """Build the api rst file for a rail namespace"""
 
         api_pkg_toc = f"{key} namespace\n"
-        api_pkg_toc += "="*len(api_pkg_toc)
+        api_pkg_toc += "=" * len(api_pkg_toc)
 
-        api_pkg_toc += \
-"""
+        api_pkg_toc += """
 .. py:module:: rail.estimation
 
@@ -251,20 +243,19 @@ def do_namespace_api_rst(cls, basedir, key, val):
                     sub_packages += f"    rail.{k3}\n"
             else:
                 sub_modules += f"    {vv.name}\n"
-        api_pkg_toc = api_pkg_toc.format(sub_packages=sub_packages, sub_modules=sub_modules)
-
-        with open(os.path.join(basedir, 'api', f"rail.{key}.rst"), 'w') as apitocfile:
+        api_pkg_toc = api_pkg_toc.format(
+            sub_packages=sub_packages, sub_modules=sub_modules
+        )
+
+        with open(os.path.join(basedir, "api", f"rail.{key}.rst"), "w", encoding="utf-8") as apitocfile:
             apitocfile.write(api_pkg_toc)
 
-        return
-
     @classmethod
-    def do_api_rst(cls, basedir='.'):
+    def do_api_rst(cls, basedir="."):
         if not cls.TREE:  # pragma: no cover
             cls.build_rail_namespace_tree()
 
-        apitoc = \
-"""API Documentation
+        apitoc = """API Documentation
 =================
 
 Information on specific functions, classes, and methods.
@@ -274,53 +265,50 @@
 """
         try:
             os.makedirs(basedir)
-        except:
+        except Exception:
             pass
         try:
-            os.makedirs(os.path.join(basedir, 'api'))
-        except:  # pragma: no cover
+            os.makedirs(os.path.join(basedir, "api"))
+        except Exception:  # pragma: no cover
             pass
 
-        for key, val in cls.TREE.items():
+        for key, val in cls.TREE.items():
             nsname = f"rail.{key}"
-            nsfile = os.path.join('api', f"{nsname}.rst")
+            nsfile = os.path.join("api", f"{nsname}.rst")
             apitoc += f"    {nsfile}\n"
             if nsname in cls.PACKAGES:
                 cls.do_pkg_api_rst(basedir, key, val)
             else:
-                cls.do_namespace_api_rst(basedir, key, val)
+                cls.do_namespace_api_rst(basedir, key, val)
 
-        with open(os.path.join(basedir, 'api.rst'), 'w') as apitocfile:
+        with open(os.path.join(basedir, "api.rst"), "w", encoding="utf-8") as apitocfile:
             apitocfile.write(apitoc)
 
-        return
-
-
     @classmethod
     def import_all_packages(cls):
         """Import all the packages that are available in the RAIL ecosystem"""
         pkgs = cls.list_rail_packages()
-        for pkg in pkgs.keys():
+        for pkg in pkgs:
             try:
-                imported_module = importlib.import_module(pkg)
+                _imported_module = importlib.import_module(pkg)
                 print(f"Imported {pkg}")
             except Exception as msg:
                 print(f"Failed to import {pkg} because: {str(msg)}")
 
-
     @classmethod
     def attach_stages(cls, to_module):
         """Attach all the available stages to this module
-
+
         This allows you to do 'from rail.stages import *'
         """
-        from rail.core.stage import RailStage
+        from rail.core.stage import RailStage  # pylint: disable=import-outside-toplevel
+
         cls.STAGE_DICT.clear()
-        cls.STAGE_DICT['none'] = []
+        cls.STAGE_DICT["none"] = []
         cls.BASE_STAGES.clear()
-
+
         n_base_classes = 0
         n_stages = 0
 
@@ -343,9 +331,9 @@ def attach_stages(cls, to_module):
                     break
             cls.STAGE_DICT[baseclass].append(stage_name)
 
-        print(f"Attached {n_base_classes} base classes and {n_stages} fully formed stages to rail.stages")
-        return
-
+        print(
+            f"Attached {n_base_classes} base classes and {n_stages} fully formed stages to rail.stages"
+        )
 
     @classmethod
     def print_rail_stage_dict(cls):
diff --git a/src/rail/core/point_estimation.py b/src/rail/core/point_estimation.py
index b9a98d1d..23542f18 100644
--- a/src/rail/core/point_estimation.py
+++ b/src/rail/core/point_estimation.py
@@ -1,8 +1,8 @@
 import numpy as np
 from numpy.typing import NDArray
 
 
-class PointEstimationMixin():
+class PointEstimationMixin:
     def calculate_point_estimates(self, qp_dist, grid=None):
         """This function drives the calculation of point estimates for qp.Ensembles.
         It is defined here, and called from the `_process_chunk` method in the
@@ -36,20 +36,20 @@ def calculate_point_estimates(self, qp_dist, grid=None):
         ancil_dict = dict()
         calculated_point_estimates = []
-        if 'calculated_point_estimates' in self.config:
-            calculated_point_estimates = self.config['calculated_point_estimates']
+        if "calculated_point_estimates" in self.config:
+            calculated_point_estimates = self.config["calculated_point_estimates"]
 
-        if 'mode' in calculated_point_estimates:
+        if "mode" in calculated_point_estimates:
             mode_value = self._calculate_mode_point_estimate(qp_dist, grid)
-            ancil_dict.update(mode = mode_value)
+            ancil_dict.update(mode=mode_value)
 
-        if 'mean' in calculated_point_estimates:
+        if "mean" in calculated_point_estimates:
             mean_value = self._calculate_mean_point_estimate(qp_dist)
-            ancil_dict.update(mean = mean_value)
+            ancil_dict.update(mean=mean_value)
 
-        if 'median' in calculated_point_estimates:
+        if "median" in calculated_point_estimates:
             median_value = self._calculate_median_point_estimate(qp_dist)
-            ancil_dict.update(median = median_value)
+            ancil_dict.update(median=median_value)
 
         if calculated_point_estimates:
             qp_dist.set_ancil(ancil_dict)
@@ -82,10 +82,12 @@ def _calculate_mode_point_estimate(self, qp_dist, grid=None) -> NDArray:
             we'll raise a KeyError.
         """
         if grid is None:
-            for key in ['zmin', 'zmax', 'nzbins']:
+            for key in ["zmin", "zmax", "nzbins"]:
                 if key not in self.config:
-                    raise KeyError(f"Expected `{key}` to be defined in stage " \
-                                   "configuration dictionary in order to caluclate mode.")
+                    raise KeyError(
+                        f"Expected `{key}` to be defined in stage "
+                        "configuration dictionary in order to calculate mode."
+                    )
 
             grid = np.linspace(self.config.zmin, self.config.zmax, self.config.nzbins)
 
diff --git a/src/rail/core/stage.py b/src/rail/core/stage.py
index 5e6d5653..abfac239 100644
--- a/src/rail/core/stage.py
+++ b/src/rail/core/stage.py
@@ -1,12 +1,12 @@
 """ Base class for PipelineStages in Rail """
 
 import os
+from math import ceil
 
 from ceci import PipelineStage, MiniPipeline
 from ceci.config import StageParameter as Param
 
 from rail.core.data import DATA_STORE, DataHandle
-from math import ceil
 
 
 class StageIO:
     """A small utility class for Stage Input/ Output
@@ -20,6 +20,7 @@ class StageIO:
 
     This allows users to be more concise when writing pipelines.
     """
+
     def __init__(self, parent):
         self._parent = parent
 
@@ -42,6 +43,7 @@ class RailStageBuild:
 
     a_stage = StageClass.make_stage(..)
a_pipe.add_stage(a_stage) """ + def __init__(self, stage_class, **kwargs): self.stage_class = stage_class self._kwargs = kwargs @@ -76,7 +78,7 @@ class RailPipeline(MiniPipeline): """ def __init__(self): - MiniPipeline.__init__(self, [], dict(name='mini')) + MiniPipeline.__init__(self, [], dict(name="mini")) def __setattr__(self, name, value): if isinstance(value, RailStageBuild): @@ -127,14 +129,15 @@ class RailStage(PipelineStage): `self.set_data(inputTag, other.get_handle(outputTag, allow_missing=True), do_read=False)` """ - config_options = dict(output_mode=Param(str, 'default', - msg="What to do with the outputs")) + config_options = dict( + output_mode=Param(str, "default", msg="What to do with the outputs") + ) data_store = DATA_STORE() def __init__(self, args, comm=None): - """ Constructor: - Do RailStage specific initialization """ + """Constructor: + Do RailStage specific initialization""" PipelineStage.__init__(self, args, comm=comm) self._input_length = None self.io = StageIO(self) @@ -156,7 +159,7 @@ def make_and_connect(cls, **kwargs): ------- A stage """ - connections = kwargs.pop('connections', {}) + connections = kwargs.pop("connections", {}) stage = cls.make_stage(**kwargs) for key, val in connections.items(): stage.set_data(key, val, do_read=False) @@ -188,7 +191,9 @@ def get_handle(self, tag, path=None, allow_missing=False): handle = self.data_store.get(aliased_tag) if handle is None: if not allow_missing: - raise KeyError(f'{self.instance_name} failed to get data by handle {aliased_tag}, associated to {tag}') + raise KeyError( + f"{self.instance_name} failed to get data by handle {aliased_tag}, associated to {tag}" + ) handle = self.add_handle(tag, path=path) return handle @@ -218,8 +223,12 @@ def add_handle(self, tag, data=None, path=None): if path is None: path = self.get_output(aliased_tag) handle_type = self.get_output_type(tag) - handle = handle_type(aliased_tag, path=path, data=data, creator=self.instance_name) - print(f"Inserting handle into data store. {aliased_tag}: {handle.path}, {handle.creator}") + handle = handle_type( + aliased_tag, path=path, data=data, creator=self.instance_name + ) + print( + f"Inserting handle into data store. 
{aliased_tag}: {handle.path}, {handle.creator}"
+        )
         self.data_store[aliased_tag] = handle
         return handle
 
@@ -331,39 +340,43 @@ def input_iterator(self, tag, **kwargs):
 
         try:
             self.config.hdf5_groupname
-        except:
+        except Exception:
             self.config.hdf5_groupname = None
 
-        if handle.path and handle.path!='None':
+        if handle.path and handle.path != "None":  # pylint: disable=no-else-return
             self._input_length = handle.size(groupname=self.config.hdf5_groupname)
-            total_chunks_needed = ceil(self._input_length/self.config.chunk_size)
+            total_chunks_needed = ceil(self._input_length / self.config.chunk_size)
             # If the number of processes is larger than we need, we remove some of them
-            if total_chunks_needed < self.size:  #pragma: no cover
-                color = self.rank+1 <= total_chunks_needed
-                newcomm = self.comm.Split(color=color,key=self.rank)
+            if total_chunks_needed < self.size:  # pragma: no cover
+                color = self.rank + 1 <= total_chunks_needed
+                newcomm = self.comm.Split(color=color, key=self.rank)
                 if color:
                     self.setup_mpi(newcomm)
                 else:
                     quit()
-            kwcopy = dict(groupname=self.config.hdf5_groupname,
-                          chunk_size=self.config.chunk_size,
-                          rank=self.rank,
-                          parallel_size=self.size)
+            kwcopy = dict(
+                groupname=self.config.hdf5_groupname,
+                chunk_size=self.config.chunk_size,
+                rank=self.rank,
+                parallel_size=self.size,
+            )
             kwcopy.update(**kwargs)
             return handle.iterator(**kwcopy)
+
+
+        # If the data is in memory and not in a file, it is small enough to process
+        # in a single chunk.
-        else:  #pragma: no cover
+        else:  # pragma: no cover
             if self.config.hdf5_groupname:
-                test_data = self.get_data('input')[self.config.hdf5_groupname]
+                test_data = self.get_data("input")[self.config.hdf5_groupname]
             else:
-                test_data = self.get_data('input')
+                test_data = self.get_data("input")
             max_l = 0
-            for k, v in test_data.items():
+            for _k, v in test_data.items():
                 max_l = max(max_l, len(v))
             self._input_length = max_l
             s = 0
-            iterator=[[s, self._input_length, test_data]]
+            iterator = [[s, self._input_length, test_data]]
             return iterator
 
     def connect_input(self, other, inputTag=None, outputTag=None):
@@ -383,7 +396,7 @@ def connect_input(self, other, inputTag=None, outputTag=None):
         handle :  The input handle for this stage
         """
         if inputTag is None:
-            inputTag = self.inputs[0][0]  #pylint: disable=no-member
+            inputTag = self.inputs[0][0]  # pylint: disable=no-member
         if outputTag is None:
             outputTag = other.outputs[0][0]
         handle = other.get_handle(outputTag, allow_missing=True)
@@ -395,7 +408,7 @@ def _finalize_tag(self, tag):
         This can be overridden by sub-classes for more complicated behavior
         """
         handle = self.get_handle(tag, allow_missing=True)
-        if self.config.output_mode == 'default':
+        if self.config.output_mode == "default":
             if not os.path.exists(handle.path):
                 handle.write()
         final_name = PipelineStage._finalize_tag(self, tag)
diff --git a/src/rail/core/util_stages.py b/src/rail/core/util_stages.py
index 2ee536a4..fb2a4a32 100644
--- a/src/rail/core/util_stages.py
+++ b/src/rail/core/util_stages.py
@@ -1,7 +1,4 @@
 """ Stages that implement utility functions """
-import os
-import numpy as np
-
 import tables_io
 
 from rail.core.stage import RailStage
@@ -20,21 +17,22 @@ class ColumnMapper(RailStage):
     `output_data = input_data.rename(columns=self.config.columns, inplace=self.config.inplace)`
     """
-    name = 'ColumnMapper'
+
+    name = "ColumnMapper"
     config_options = RailStage.config_options.copy()
     config_options.update(chunk_size=100_000, columns=dict, inplace=False)
-    inputs = [('input', PqHandle)]
-    outputs = [('output', PqHandle)]
+    inputs = [("input", PqHandle)]
[("input", PqHandle)] + outputs = [("output", PqHandle)] def __init__(self, args, comm=None): RailStage.__init__(self, args, comm=comm) def run(self): - data = self.get_data('input', allow_missing=True) + data = self.get_data("input", allow_missing=True) out_data = data.rename(columns=self.config.columns, inplace=self.config.inplace) - if self.config.inplace: #pragma: no cover + if self.config.inplace: # pragma: no cover out_data = data - self.add_data('output', out_data) + self.add_data("output", out_data) def __repr__(self): # pragma: no cover printMsg = "Stage that applies remaps the following column names in a pandas DataFrame:\n" @@ -54,10 +52,10 @@ def __call__(self, data): table : Table-like The degraded sample """ - self.set_data('input', data) + self.set_data("input", data) self.run() self.finalize() - return self.get_handle('output') + return self.get_handle("output") class RowSelector(RailStage): @@ -71,19 +69,20 @@ class RowSelector(RailStage): `output_data = input_data[self.config.start:self.config.stop]` """ - name = 'RowSelector' + + name = "RowSelector" config_options = RailStage.config_options.copy() config_options.update(start=int, stop=int) - inputs = [('input', PqHandle)] - outputs = [('output', PqHandle)] + inputs = [("input", PqHandle)] + outputs = [("output", PqHandle)] def __init__(self, args, comm=None): RailStage.__init__(self, args, comm=comm) def run(self): - data = self.get_data('input', allow_missing=True) - out_data = data.iloc[self.config.start:self.config.stop] - self.add_data('output', out_data) + data = self.get_data("input", allow_missing=True) + out_data = data.iloc[self.config.start : self.config.stop] + self.add_data("output", out_data) def __repr__(self): # pragma: no cover printMsg = "Stage that applies remaps the following column names in a pandas DataFrame:\n" @@ -103,10 +102,10 @@ def __call__(self, data): table : table-like The degraded sample """ - self.set_data('input', data) + self.set_data("input", data) self.run() self.finalize() - return self.get_handle('output') + return self.get_handle("output") class TableConverter(RailStage): @@ -115,20 +114,21 @@ class TableConverter(RailStage): FIXME, this is hardwired to convert parquet tables to Hdf5Tables. It would be nice to have more options here. 
""" - name = 'TableConverter' + + name = "TableConverter" config_options = RailStage.config_options.copy() config_options.update(output_format=str) - inputs = [('input', PqHandle)] - outputs = [('output', Hdf5Handle)] + inputs = [("input", PqHandle)] + outputs = [("output", Hdf5Handle)] def __init__(self, args, comm=None): RailStage.__init__(self, args, comm=comm) def run(self): - data = self.get_data('input', allow_missing=True) + data = self.get_data("input", allow_missing=True) out_fmt = tables_io.types.TABULAR_FORMAT_NAMES[self.config.output_format] out_data = tables_io.convert(data, out_fmt) - self.add_data('output', out_data) + self.add_data("output", out_data) def __call__(self, data): """Return a converted table @@ -143,7 +143,7 @@ def __call__(self, data): out_data : table-like The converted version of the table """ - self.set_data('input', data) + self.set_data("input", data) self.run() self.finalize() - return self.get_handle('output') + return self.get_handle("output") diff --git a/src/rail/core/utils.py b/src/rail/core/utils.py index 421c4c7d..bfe21f04 100644 --- a/src/rail/core/utils.py +++ b/src/rail/core/utils.py @@ -4,7 +4,7 @@ import rail import rail.core -RAILDIR = os.path.abspath(os.path.join(os.path.dirname(rail.core.__file__), '..', '..')) +RAILDIR = os.path.abspath(os.path.join(os.path.dirname(rail.core.__file__), "..", "..")) def find_rail_file(relpath): diff --git a/src/rail/creation/degradation/spectroscopic_selections.py b/src/rail/creation/degradation/spectroscopic_selections.py index 83e1a149..98c15ace 100644 --- a/src/rail/creation/degradation/spectroscopic_selections.py +++ b/src/rail/creation/degradation/spectroscopic_selections.py @@ -3,9 +3,9 @@ import os import numpy as np +from scipy.interpolate import interp1d from ceci.config import StageParameter as Param from rail.creation.degrader import Degrader -from scipy.interpolate import interp1d from rail.core.utils import RAILDIR @@ -24,9 +24,11 @@ class SpecSelection(Degrader): If True, then downsample the pre-selected galaxies to N_tot galaxies. success_rate_dir: string, the path to the success rate files. - percentile_cut: If using color-based redshift cut, percentile in redshifts above which redshifts will be cut from the sample. Default is 100 (no cut) + percentile_cut: If using color-based redshift cut, percentile in redshifts above which + redshifts will be cut from the sample. Default is 100 (no cut) colnames: a dictionary that includes necessary columns - (magnitudes, colors and redshift) for selection. For magnitudes, the keys are ugrizy; for colors, the keys are, + (magnitudes, colors and redshift) for selection. For magnitudes, + the keys are ugrizy; for colors, the keys are, for example, gr standing for g-r; for redshift, the key is 'redshift'. random_seed: random seed for reproducibility. @@ -44,7 +46,8 @@ class SpecSelection(Degrader): ), success_rate_dir=Param( str, - os.path.join(RAILDIR, + os.path.join( + RAILDIR, "rail/examples_data/creation_data/data/success_rate_data", ), msg="The path to the directory containing success rate files.", @@ -56,9 +59,10 @@ class SpecSelection(Degrader): **{band: "mag_" + band + "_lsst" for band in "ugrizy"}, **{"redshift": "redshift"}, }, - msg="a dictionary that includes necessary columns\ - (magnitudes, colors and redshift) for selection. 
For magnitudes, the keys are ugrizy; for colors, the keys are, \
-            for example, gr standing for g-r; for redshift, the key is 'redshift'",
+            msg="a dictionary that includes necessary columns "
+            "(magnitudes, colors and redshift) for selection. "
+            "For magnitudes, the keys are ugrizy; for colors, the keys are, "
+            "for example, gr standing for g-r; for redshift, the key is 'redshift'",
         ),
         random_seed=Param(int, 42, msg="random seed for reproducibility"),
     )
@@ -77,7 +81,7 @@ def _validate_settings(self):

         if self.config.N_tot < 0:
             raise ValueError(
-                "Total number of selected sources must be a " "positive integer."
+                "Total number of selected sources must be a positive integer."
             )
         if os.path.exists(self.config.success_rate_dir) is not True:
             raise ValueError(
@@ -97,9 +101,9 @@ def validate_colnames(self, data):
         if check is not True:
             raise ValueError(
                 "Columns in the data are not enough for the selection."
-                + "The data should contain "
-                + str(list(colnames))
-                + ". \n"
+                "The data should contain "
+                + str(list(colnames))
+                + ". \n"
             )

     def selection(self, data):
@@ -129,7 +133,7 @@ def downsampling_N_tot(self):
         """
         N_tot = self.config.N_tot
         N_selected = np.count_nonzero(self.mask)
-        if N_tot > N_selected:
+        if N_tot > N_selected:  # pylint: disable=no-else-return
             print(
                 "Warning: N_tot is greater than the size of spec-selected "
                 + "sample ("
@@ -543,7 +547,8 @@ class SpecSelection_HSC(SpecSelection):

     def photometryCut(self, data):
         """
-        HSC galaxies were binned in color magnitude space with i-band mag from -2 to 6 and g-z color from 13 to 26.
+        HSC galaxies were binned in color magnitude space with i-band mag
+        from 13 to 26 and g-z color from -2 to 6.
         """
         mask = (data[self.config.colnames["i"]] > 13.0) & (
             data[self.config.colnames["i"]] < 26.0
@@ -555,10 +560,11 @@ def photometryCut(self, data):

     def speczSuccess(self, data):
         """
-        HSC galaxies were binned in color magnitude space with i-band mag from -2 to 6 and g-z color from 13 to 26
-        200 bins in each direction. The ratio of of galaxies with spectroscopic redshifts (training galaxies) to
-        galaxies with only photometry in HSC wide field (application galaxies) was computed for each pixel. We divide
-        the data into the same pixels and randomly select galaxies into the training sample based on the HSC ratios
+        HSC galaxies were binned in color magnitude space with i-band mag from 13 to 26 and g-z color
+        from -2 to 6, with 200 bins in each direction. The ratio of galaxies with spectroscopic redshifts
+        (training galaxies) to galaxies with only photometry in HSC wide field (application galaxies)
+        was computed for each pixel. 
We divide the data into the same pixels and randomly select galaxies + into the training sample based on the HSC ratios """ success_rate_dir = self.config.success_rate_dir x_edge = np.linspace(13, 26, 201, endpoint=True) diff --git a/src/rail/creation/degrader.py b/src/rail/creation/degrader.py index 70260c4d..ed8ab187 100644 --- a/src/rail/creation/degrader.py +++ b/src/rail/creation/degrader.py @@ -7,6 +7,7 @@ from rail.core.stage import RailStage from rail.core.data import PqHandle + class Degrader(RailStage): """Base class Degraders, which apply various degradations to synthetic photometric data @@ -14,11 +15,11 @@ class Degrader(RailStage): provide as "output" another pandas dataframes written to Parquet files """ - name = 'Degrader' + name = "Degrader" config_options = RailStage.config_options.copy() config_options.update(seed=12345) - inputs = [('input', PqHandle)] - outputs = [('output', PqHandle)] + inputs = [("input", PqHandle)] + outputs = [("output", PqHandle)] def __init__(self, args, comm=None): """Initialize Degrader that can degrade photometric data""" @@ -54,7 +55,7 @@ def __call__(self, sample, seed: int = None): """ if seed is not None: self.config.seed = seed - self.set_data('input', sample) + self.set_data("input", sample) self.run() self.finalize() - return self.get_handle('output') + return self.get_handle("output") diff --git a/src/rail/creation/engine.py b/src/rail/creation/engine.py index 7d6d7858..33cf0097 100644 --- a/src/rail/creation/engine.py +++ b/src/rail/creation/engine.py @@ -36,7 +36,8 @@ def fit_model(self): Returns ------- - [This will definitely be a file, but the filetype and format depend entirely on the modeling approach!] + [This will definitely be a file, but the filetype and format + depend entirely on the modeling approach!] 
""" self.run() self.finalize() diff --git a/src/rail/estimation/algos/equal_count.py b/src/rail/estimation/algos/equal_count.py index 4a9c1917..be26e78b 100644 --- a/src/rail/estimation/algos/equal_count.py +++ b/src/rail/estimation/algos/equal_count.py @@ -8,44 +8,52 @@ from rail.estimation.classifier import PZClassifier from rail.core.data import TableHandle + class EqualCountClassifier(PZClassifier): """Classifier that simply assign tomographic bins based on point estimate according to SRD""" - name = 'EqualCountClassifier' + name = "EqualCountClassifier" config_options = PZClassifier.config_options.copy() config_options.update( - id_name=Param(str, "", msg="Column name for the object ID in the input data, if empty the row index is used as the ID."), - point_estimate=Param(str, 'zmode', msg="Which point estimate to use"), + id_name=Param( + str, + "", + msg="Column name for the object ID in the input data, if empty the row index is used as the ID.", + ), + point_estimate=Param(str, "zmode", msg="Which point estimate to use"), zmin=Param(float, 0.0, msg="Minimum redshift of the sample"), zmax=Param(float, 3.0, msg="Maximum redshift of the sample"), nbins=Param(int, 5, msg="Number of tomographic bins"), no_assign=Param(int, -99, msg="Value for no assignment flag"), - ) - outputs = [('output', TableHandle)] + ) + outputs = [("output", TableHandle)] def __init__(self, args, comm=None): PZClassifier.__init__(self, args, comm=comm) def run(self): - test_data = self.get_data('input') + test_data = self.get_data("input") npdf = test_data.npdf - + try: zb = test_data.ancil[self.config.point_estimate] - except KeyError: - raise KeyError(f"{self.config.point_estimate} is not contained in the data ancil, you will need to compute it explicitly.") + except KeyError as msg: + raise KeyError( + f"{self.config.point_estimate} is not contained in the data ancil, " + "you will need to compute it explicitly." 
+            ) from msg

         # tomographic bins with equal number density
         sortind = np.argsort(zb)
-        cum=np.arange(1,(len(zb)+1))
+        cum = np.arange(1, (len(zb) + 1))
         bin_index = np.zeros(len(zb))
         for ii in range(self.config.nbins):
-            perc1=ii/self.config.nbins
-            perc2=(ii+1)/self.config.nbins
-            ind=(cum/cum[-1]>perc1)&(cum/cum[-1]<=perc2)
-            useind=sortind[ind]
-            bin_index[useind] = int(ii+1)
+            perc1 = ii / self.config.nbins
+            perc2 = (ii + 1) / self.config.nbins
+            ind = (cum / cum[-1] > perc1) & (cum / cum[-1] <= perc2)
+            useind = sortind[ind]
+            bin_index[useind] = int(ii + 1)

         if self.config.id_name != "":
             # below is commented out and replaced by a redundant line
@@ -55,7 +63,7 @@ def run(self):
         elif self.config.id_name == "":
             # ID set to row index
             obj_id = np.arange(npdf)
-            self.config.id_name="row_index"
-
+            self.config.id_name = "row_index"
+
         class_id = {self.config.id_name: obj_id, "class_id": bin_index}
-        self.add_data('output', class_id)
\ No newline at end of file
+        self.add_data("output", class_id)
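An illustrative aside (not part of the diff): the equal-count assignment above sorts the point estimates and slices the cumulative rank into equal-occupancy bins. A self-contained NumPy sketch with made-up inputs:

    import numpy as np

    # Toy stand-in for the ancil point estimates; values are made up.
    zb = np.random.default_rng(42).uniform(0.0, 3.0, 1000)
    nbins = 5
    sortind = np.argsort(zb)
    cum = np.arange(1, len(zb) + 1)
    bin_index = np.zeros(len(zb))
    for ii in range(nbins):
        perc1 = ii / nbins
        perc2 = (ii + 1) / nbins
        ind = (cum / cum[-1] > perc1) & (cum / cum[-1] <= perc2)
        bin_index[sortind[ind]] = ii + 1
    # each bin now holds len(zb) // nbins == 200 objects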
diff --git a/src/rail/estimation/algos/naive_stack.py b/src/rail/estimation/algos/naive_stack.py
index 4d52c568..821b3dfa 100644
--- a/src/rail/estimation/algos/naive_stack.py
+++ b/src/rail/estimation/algos/naive_stack.py
@@ -12,44 +12,46 @@

 class NaiveStackInformer(PzInformer):
-    """Placeholder Informer
-    """
+    """Placeholder Informer"""

-    name = 'NaiveStackInformer'
+    name = "NaiveStackInformer"
     config_options = PzInformer.config_options.copy()

     def __init__(self, args, comm=None):
         PzInformer.__init__(self, args, comm=comm)

     def run(self):
-        self.add_data('model', np.array([None]))
+        self.add_data("model", np.array([None]))
+

 class NaiveStackSummarizer(PZSummarizer):
-    """Summarizer which stacks individual P(z)
-    """
+    """Summarizer which stacks individual P(z)"""

-    name = 'NaiveStackSummarizer'
+    name = "NaiveStackSummarizer"
     config_options = PZSummarizer.config_options.copy()
-    config_options.update(zmin=Param(float, 0.0, msg="The minimum redshift of the z grid"),
-                          zmax=Param(float, 3.0, msg="The maximum redshift of the z grid"),
-                          nzbins=Param(int, 301, msg="The number of gridpoints in the z grid"),
-                          seed=Param(int, 87, msg="random seed"),
-                          nsamples=Param(int, 1000, msg="Number of sample distributions to create"))
-    inputs = [('input', QPHandle)]
-    outputs = [('output', QPHandle),
-               ('single_NZ', QPHandle)]
+    config_options.update(
+        zmin=Param(float, 0.0, msg="The minimum redshift of the z grid"),
+        zmax=Param(float, 3.0, msg="The maximum redshift of the z grid"),
+        nzbins=Param(int, 301, msg="The number of gridpoints in the z grid"),
+        seed=Param(int, 87, msg="random seed"),
+        nsamples=Param(int, 1000, msg="Number of sample distributions to create"),
+    )
+    inputs = [("input", QPHandle)]
+    outputs = [("output", QPHandle), ("single_NZ", QPHandle)]

     def __init__(self, args, comm=None):
         PZSummarizer.__init__(self, args, comm=comm)
         self.zgrid = None

     def run(self):
-        iterator = self.input_iterator('input')
-        self.zgrid = np.linspace(self.config.zmin, self.config.zmax, self.config.nzbins + 1)
+        iterator = self.input_iterator("input")
+        self.zgrid = np.linspace(
+            self.config.zmin, self.config.zmax, self.config.nzbins + 1
+        )
        # Initializing the stacking pdf's
         yvals = np.zeros((1, len(self.zgrid)))
         bvals = np.zeros((self.config.nsamples, len(self.zgrid)))
-        bootstrap_matrix = self._broadcast_bootstrap_matrix()
+        bootstrap_matrix = self._broadcast_bootstrap_matrix()

         first = True
         for s, e, test_data in iterator:
@@ -60,24 +62,22 @@ def run(self):
             bvals, yvals = self._join_histograms(bvals, yvals)

         if self.rank == 0:
-            sample_ens = qp.Ensemble(qp.interp, data=dict(xvals=self.zgrid, yvals=bvals))
+            sample_ens = qp.Ensemble(
+                qp.interp, data=dict(xvals=self.zgrid, yvals=bvals)
+            )
             qp_d = qp.Ensemble(qp.interp, data=dict(xvals=self.zgrid, yvals=yvals))
-            self.add_data('output', sample_ens)
-            self.add_data('single_NZ', qp_d)
-
+            self.add_data("output", sample_ens)
+            self.add_data("single_NZ", qp_d)

-    def _process_chunk(self, start, end, data, first, bootstrap_matrix, yvals, bvals):
+    def _process_chunk(self, start, end, data, _first, bootstrap_matrix, yvals, bvals):
         pdf_vals = data.pdf(self.zgrid)
-        yvals += np.expand_dims(np.sum(np.where(np.isfinite(pdf_vals), pdf_vals, 0.), axis=0), 0)
+        yvals += np.expand_dims(
+            np.sum(np.where(np.isfinite(pdf_vals), pdf_vals, 0.0), axis=0), 0
+        )
         # qp_d is the normalized probability of the stack, we need to know how many galaxies were
         for i in range(self.config.nsamples):
             bootstrap_draws = bootstrap_matrix[:, i]
            # Not all of the bootstrap draws fall in this chunk, and the indices do not start at "start"
-            mask = (bootstrap_draws>=start) & (bootstrap_draws<end)
+            mask = (bootstrap_draws >= start) & (bootstrap_draws < end)
             bootstrap_draws = bootstrap_draws[mask] - start
             bvals[i] += np.sum(pdf_vals[bootstrap_draws], axis=0)
-
-
-
-
-
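An illustrative aside (not part of the diff): each chunk contributes its summed p(z) evaluated on the shared grid, with non-finite values zeroed out, exactly as _process_chunk does above. A standalone sketch with fake PDF values:

    import numpy as np

    zgrid = np.linspace(0.0, 3.0, 302)  # zmin=0, zmax=3, nzbins=301 -> 302 points
    # Stand-in for test_data.pdf(zgrid): one row of p(z) values per galaxy.
    pdf_vals = np.abs(np.random.default_rng(87).normal(size=(100, len(zgrid))))
    yvals = np.zeros((1, len(zgrid)))
    yvals += np.expand_dims(
        np.sum(np.where(np.isfinite(pdf_vals), pdf_vals, 0.0), axis=0), 0
    )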
diff --git a/src/rail/estimation/algos/point_est_hist.py b/src/rail/estimation/algos/point_est_hist.py
index 43ecae3a..bb2e0bd5 100644
--- a/src/rail/estimation/algos/point_est_hist.py
+++ b/src/rail/estimation/algos/point_est_hist.py
@@ -12,46 +12,46 @@

 class PointEstHistInformer(PzInformer):
-    """Placeholder Informer
-    """
+    """Placeholder Informer"""

-    name = 'PointEstHistInformer'
+    name = "PointEstHistInformer"
     config_options = PzInformer.config_options.copy()

     def __init__(self, args, comm=None):
         PzInformer.__init__(self, args, comm=comm)

     def run(self):
-        self.add_data('model', np.array([None]))
+        self.add_data("model", np.array([None]))

 class PointEstHistSummarizer(PZSummarizer):
-    """Summarizer which simply histograms a point estimate
-    """
+    """Summarizer which simply histograms a point estimate"""

-    name = 'PointEstHistSummarizer'
+    name = "PointEstHistSummarizer"
     config_options = PZSummarizer.config_options.copy()
-    config_options.update(zmin=Param(float, 0.0, msg="The minimum redshift of the z grid"),
-                          zmax=Param(float, 3.0, msg="The maximum redshift of the z grid"),
-                          nzbins=Param(int, 301, msg="The number of gridpoints in the z grid"),
-                          seed=Param(int, 87, msg="random seed"),
-                          point_estimate=Param(str, 'zmode', msg="Which point estimate to use"),
-                          nsamples=Param(int, 1000, msg="Number of sample distributions to return"))
-    inputs = [('input', QPHandle)]
-    outputs = [('output', QPHandle),
-               ('single_NZ', QPHandle)]
+    config_options.update(
+        zmin=Param(float, 0.0, msg="The minimum redshift of the z grid"),
+        zmax=Param(float, 3.0, msg="The maximum redshift of the z grid"),
+        nzbins=Param(int, 301, msg="The number of gridpoints in the z grid"),
+        seed=Param(int, 87, msg="random seed"),
+        point_estimate=Param(str, "zmode", msg="Which point estimate to use"),
+        nsamples=Param(int, 1000, msg="Number of sample distributions to return"),
+    )
+    inputs = [("input", QPHandle)]
+    outputs = [("output", QPHandle), ("single_NZ", QPHandle)]

     def __init__(self, args, comm=None):
         PZSummarizer.__init__(self, args, comm=comm)
         self.zgrid = None
         self.bincents = None
-
     def run(self):
-        iterator = self.input_iterator('input')
-        self.zgrid = np.linspace(self.config.zmin, self.config.zmax, self.config.nzbins + 1)
+        iterator = self.input_iterator("input")
+        self.zgrid = np.linspace(
+            self.config.zmin, self.config.zmax, self.config.nzbins + 1
+        )
         self.bincents = 0.5 * (self.zgrid[1:] + self.zgrid[:-1])
-        bootstrap_matrix = self._broadcast_bootstrap_matrix()
+        bootstrap_matrix = self._broadcast_bootstrap_matrix()
        # Initializing the histograms
         single_hist = np.zeros(self.config.nzbins)
         hist_vals = np.zeros((self.config.nsamples, self.config.nzbins))
@@ -59,27 +59,32 @@ def run(self):
         first = True
         for s, e, test_data in iterator:
             print(f"Process {self.rank} running estimator on chunk {s} - {e}")
-            self._process_chunk(s, e, test_data, first, bootstrap_matrix, single_hist, hist_vals)
+            self._process_chunk(
+                s, e, test_data, first, bootstrap_matrix, single_hist, hist_vals
+            )
             first = False
         if self.comm is not None:  # pragma: no cover
             hist_vals, single_hist = self._join_histograms(hist_vals, single_hist)

         if self.rank == 0:
-            sample_ens = qp.Ensemble(qp.hist,
-                                     data=dict(bins=self.zgrid, pdfs=np.atleast_2d(hist_vals)))
-            qp_d = qp.Ensemble(qp.hist,
-                               data=dict(bins=self.zgrid, pdfs=np.atleast_2d(single_hist)))
-            self.add_data('output', sample_ens)
-            self.add_data('single_NZ', qp_d)
-
-    def _process_chunk(self, start, end, test_data, first, bootstrap_matrix, single_hist, hist_vals):
+            sample_ens = qp.Ensemble(
+                qp.hist, data=dict(bins=self.zgrid, pdfs=np.atleast_2d(hist_vals))
+            )
+            qp_d = qp.Ensemble(
+                qp.hist, data=dict(bins=self.zgrid, pdfs=np.atleast_2d(single_hist))
+            )
+            self.add_data("output", sample_ens)
+            self.add_data("single_NZ", qp_d)
+
+    def _process_chunk(
+        self, start, end, test_data, _first, bootstrap_matrix, single_hist, hist_vals
+    ):
         zb = test_data.ancil[self.config.point_estimate]
         single_hist += np.histogram(zb, bins=self.zgrid)[0]
         for i in range(self.config.nsamples):
             bootstrap_indeces = bootstrap_matrix[:, i]
            # Not all of the bootstrap draws fall in this chunk, and the indices do not start at "start"
-            mask = (bootstrap_indeces>=start) & (bootstrap_indeces<end)
+            mask = (bootstrap_indeces >= start) & (bootstrap_indeces < end)
             bootstrap_indeces = bootstrap_indeces[mask] - start
             zarr = zb[bootstrap_indeces]
             hist_vals[i] += np.histogram(zarr, bins=self.zgrid)[0]
-
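An illustrative aside (not part of the diff): the bootstrap histogramming above draws galaxy indices with replacement once for the whole catalog, then each chunk keeps only the draws falling inside [start, end) and shifts them to chunk-local indices. A standalone sketch with made-up values:

    import numpy as np

    rng = np.random.default_rng(87)
    ngal, nsamples = 500, 10
    zb = rng.uniform(0.0, 3.0, ngal)          # fake point estimates for one chunk
    zgrid = np.linspace(0.0, 3.0, 302)
    bootstrap_matrix = rng.integers(low=0, high=ngal, size=(ngal, nsamples))
    hist_vals = np.zeros((nsamples, len(zgrid) - 1))
    start, end = 0, ngal                      # chunk boundaries
    for i in range(nsamples):
        draws = bootstrap_matrix[:, i]
        draws = draws[(draws >= start) & (draws < end)] - start
        hist_vals[i] += np.histogram(zb[draws], bins=zgrid)[0]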
diff --git a/src/rail/estimation/algos/random_gauss.py b/src/rail/estimation/algos/random_gauss.py
index ea60439a..0654a44f 100644
--- a/src/rail/estimation/algos/random_gauss.py
+++ b/src/rail/estimation/algos/random_gauss.py
@@ -14,38 +14,41 @@

 class RandomGaussInformer(CatInformer):
-    """Placeholder Informer
-    """
+    """Placeholder Informer"""

-    name = 'RandomGaussInformer'
+    name = "RandomGaussInformer"
     config_options = CatInformer.config_options.copy()

     def __init__(self, args, comm=None):
         CatInformer.__init__(self, args, comm=comm)

     def run(self):
-        self.add_data('model', np.array([None]))
+        self.add_data("model", np.array([None]))

 class RandomGaussEstimator(CatEstimator):
-    """Random CatEstimator
-    """
+    """Random CatEstimator"""

-    name = 'RandomGaussEstimator'
-    inputs = [('input', TableHandle)]
+    name = "RandomGaussEstimator"
+    inputs = [("input", TableHandle)]
     config_options = CatEstimator.config_options.copy()
-    config_options.update(rand_width=Param(float, 0.025, "ad hock width of PDF"),
-                          rand_zmin=Param(float, 0.0, msg="The minimum redshift of the z grid"),
-                          rand_zmax=Param(float, 3.0, msg="The maximum redshift of the z grid"),
-                          nzbins=Param(int, 301, msg="The number of gridpoints in the z grid"),
-                          seed=Param(int, 87, msg="random seed"),
-                          column_name=Param(str, "mag_i_lsst",
-                                            msg="name of a column that has the "\
-                                            "correct number of galaxies to find length of"))
+    config_options.update(
+        rand_width=Param(float, 0.025, "ad hoc width of PDF"),
+        rand_zmin=Param(float, 0.0, msg="The minimum redshift of the z grid"),
+        rand_zmax=Param(float, 3.0, msg="The maximum redshift of the z grid"),
+        nzbins=Param(int, 301, msg="The number of gridpoints in the z grid"),
+        seed=Param(int, 87, msg="random seed"),
+        column_name=Param(
+            str,
+            "mag_i_lsst",
+            msg="name of a column that has the "
+            "correct number of galaxies to find length of",
+        ),
+    )

     def __init__(self, args, comm=None):
-        """ Constructor:
-        Do CatEstimator specific initialization """
+        """Constructor:
+        Do CatEstimator specific initialization"""
         CatEstimator.__init__(self, args, comm=comm)
         self.zgrid = None
@@ -56,10 +59,17 @@ def _process_chunk(self, start, end, data, first):
         rng = np.random.default_rng(seed=self.config.seed + start)
         zmode = np.round(rng.uniform(0.0, self.config.rand_zmax, numzs), 3)
         widths = self.config.rand_width * (1.0 + zmode)
-        self.zgrid = np.linspace(self.config.rand_zmin, self.config.rand_zmax, self.config.nzbins)
+        self.zgrid = np.linspace(
+            self.config.rand_zmin, self.config.rand_zmax, self.config.nzbins
+        )
         for i in range(numzs):
             pdf.append(norm.pdf(self.zgrid, zmode[i], widths[i]))
-        qp_d = qp.Ensemble(qp.stats.norm, data=dict(loc=np.expand_dims(zmode, -1),  # pylint: disable=no-member
-                                                    scale=np.expand_dims(widths, -1)))
+        qp_d = qp.Ensemble(
+            qp.stats.norm,  # pylint: disable=no-member
+            data=dict(
+                loc=np.expand_dims(zmode, -1),  # pylint: disable=no-member
+                scale=np.expand_dims(widths, -1),
+            ),
+        )
         qp_d.set_ancil(dict(zmode=zmode))
         self._do_chunk_output(qp_d, start, end, first)
diff --git a/src/rail/estimation/algos/train_z.py b/src/rail/estimation/algos/train_z.py
index f9bc5c81..b93cb481 100644
--- a/src/rail/estimation/algos/train_z.py
+++ b/src/rail/estimation/algos/train_z.py
@@ -6,10 +6,10 @@
 """
 import numpy as np
-from ceci.config import StageParameter as Param
-from rail.estimation.estimator import CatEstimator, CatInformer
-from rail.core.common_params import SHARED_PARAMS
 import qp
+from rail.estimation.estimator import CatEstimator
+from rail.estimation.informer import CatInformer
+from rail.core.common_params import SHARED_PARAMS

 class trainZmodel:
@@ -17,6 +17,7 @@ class trainZmodel:
     Temporary class to store the single trainZ pdf for trained model.
     Given how simple this is to compute, this seems like overkill. 
""" + def __init__(self, zgrid, pdf, zmode): self.zgrid = zgrid self.pdf = pdf @@ -24,24 +25,25 @@ def __init__(self, zgrid, pdf, zmode): class TrainZInformer(CatInformer): - """Train an Estimator which returns a global PDF for all galaxies - """ + """Train an Estimator which returns a global PDF for all galaxies""" - name = 'TrainZInformer' + name = "TrainZInformer" config_options = CatInformer.config_options.copy() - config_options.update(zmin=SHARED_PARAMS, - zmax=SHARED_PARAMS, - nzbins=SHARED_PARAMS, - redshift_col=SHARED_PARAMS) + config_options.update( + zmin=SHARED_PARAMS, + zmax=SHARED_PARAMS, + nzbins=SHARED_PARAMS, + redshift_col=SHARED_PARAMS, + ) def __init__(self, args, comm=None): CatInformer.__init__(self, args, comm=comm) def run(self): if self.config.hdf5_groupname: - training_data = self.get_data('input')[self.config.hdf5_groupname] + training_data = self.get_data("input")[self.config.hdf5_groupname] else: # pragma: no cover - training_data = self.get_data('input') + training_data = self.get_data("input") zbins = np.linspace(self.config.zmin, self.config.zmax, self.config.nzbins + 1) speczs = np.sort(training_data[self.config.redshift_col]) train_pdf, _ = np.histogram(speczs, zbins) @@ -53,18 +55,15 @@ def run(self): train_pdf = train_pdf / norm zgrid = midpoints self.model = trainZmodel(zgrid, train_pdf, zmode) - self.add_data('model', self.model) + self.add_data("model", self.model) class TrainZEstimator(CatEstimator): - """CatEstimator which returns a global PDF for all galaxies - """ + """CatEstimator which returns a global PDF for all galaxies""" - name = 'TrainZEstimator' + name = "TrainZEstimator" config_options = CatEstimator.config_options.copy() - config_options.update(zmin=SHARED_PARAMS, - zmax=SHARED_PARAMS, - nzbins=SHARED_PARAMS) + config_options.update(zmin=SHARED_PARAMS, zmax=SHARED_PARAMS, nzbins=SHARED_PARAMS) def __init__(self, args, comm=None): self.zgrid = None @@ -83,7 +82,9 @@ def open_model(self, **kwargs): def _process_chunk(self, start, end, data, first): test_size = end - start zmode = np.repeat(self.zmode, test_size) - qp_d = qp.Ensemble(qp.interp, - data=dict(xvals=self.zgrid, yvals=np.tile(self.train_pdf, (test_size, 1)))) + qp_d = qp.Ensemble( + qp.interp, + data=dict(xvals=self.zgrid, yvals=np.tile(self.train_pdf, (test_size, 1))), + ) qp_d.set_ancil(dict(zmode=zmode)) self._do_chunk_output(qp_d, start, end, first) diff --git a/src/rail/estimation/algos/uniform_binning.py b/src/rail/estimation/algos/uniform_binning.py index c3eb7962..fde75b97 100644 --- a/src/rail/estimation/algos/uniform_binning.py +++ b/src/rail/estimation/algos/uniform_binning.py @@ -8,52 +8,67 @@ from rail.estimation.classifier import PZClassifier from rail.core.data import TableHandle + class UniformBinningClassifier(PZClassifier): """Classifier that simply assign tomographic bins based on point estimate according to SRD""" - name = 'UniformBinningClassifier' + name = "UniformBinningClassifier" config_options = PZClassifier.config_options.copy() config_options.update( - id_name=Param(str, "", msg="Column name for the object ID in the input data, if empty the row index is used as the ID."), - point_estimate=Param(str, 'zmode', msg="Which point estimate to use"), - zbin_edges=Param(list, [], msg="The tomographic redshift bin edges. 
If this is given (contains two or more entries), all settings below will be ignored."),
+        id_name=Param(
+            str,
+            "",
+            msg="Column name for the object ID in the input data, if empty the row index is used as the ID.",
+        ),
+        point_estimate=Param(str, "zmode", msg="Which point estimate to use"),
+        zbin_edges=Param(
+            list,
+            [],
+            msg="The tomographic redshift bin edges. "
+            "If this is given (contains two or more entries), all settings below will be ignored.",
+        ),
         zmin=Param(float, 0.0, msg="Minimum redshift of the sample"),
         zmax=Param(float, 3.0, msg="Maximum redshift of the sample"),
         nbins=Param(int, 5, msg="Number of tomographic bins"),
         no_assign=Param(int, -99, msg="Value for no assignment flag"),
-        )
-    outputs = [('output', TableHandle)]
+    )
+    outputs = [("output", TableHandle)]

     def __init__(self, args, comm=None):
         PZClassifier.__init__(self, args, comm=comm)

     def run(self):
-        test_data = self.get_data('input')
+        test_data = self.get_data("input")
         npdf = test_data.npdf
-
+
         try:
             zb = test_data.ancil[self.config.point_estimate]
-        except KeyError:
-            raise KeyError(f"{self.config.point_estimate} is not contained in the data ancil, you will need to compute it explicitly.")
+        except KeyError as msg:
+            raise KeyError(
+                f"{self.config.point_estimate} is not contained in the data ancil, "
+                "you will need to compute it explicitly."
+            ) from msg

         # binning options
-        if len(self.config.zbin_edges)>=2:
+        if len(self.config.zbin_edges) >= 2:
             # this overwrites all other key words
             # linear binning defined by zmin, zmax, and nbins
             bin_index = np.digitize(zb, self.config.zbin_edges)
             # assign -99 to objects not in any bin:
-            bin_index[bin_index==0]=self.config.no_assign
-            bin_index[bin_index==len(self.config.zbin_edges)]=self.config.no_assign
+            bin_index[bin_index == 0] = self.config.no_assign
+            bin_index[bin_index == len(self.config.zbin_edges)] = self.config.no_assign

         else:
             # linear binning defined by zmin, zmax, and nbins
-            bin_index = np.digitize(zb, np.linspace(self.config.zmin, self.config.zmax, self.config.nbins+1))
+            bin_index = np.digitize(
+                zb,
+                np.linspace(self.config.zmin, self.config.zmax, self.config.nbins + 1),
+            )
             # assign -99 to objects not in any bin:
-            bin_index[bin_index==0]=self.config.no_assign
-            bin_index[bin_index==(self.config.nbins+1)]=self.config.no_assign
-
-
+            bin_index[bin_index == 0] = self.config.no_assign
+            bin_index[bin_index == (self.config.nbins + 1)] = self.config.no_assign
+
         if self.config.id_name != "":
             # below is commented out and replaced by a redundant line
             # because the data doesn't have ID yet
@@ -62,7 +77,7 @@ def run(self):
         elif self.config.id_name == "":
             # ID set to row index
             obj_id = np.arange(npdf)
-            self.config.id_name="row_index"
-
+            self.config.id_name = "row_index"
+
         class_id = {self.config.id_name: obj_id, "class_id": bin_index}
-        self.add_data('output', class_id)
\ No newline at end of file
+        self.add_data("output", class_id)
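An illustrative aside (not part of the diff): np.digitize returns 0 for values below the first edge and len(edges) for values at or above the last, which is why both ends are flagged with no_assign above. A minimal sketch with made-up redshifts:

    import numpy as np

    zb = np.array([-0.5, 0.2, 1.4, 2.9, 3.5])
    edges = np.linspace(0.0, 3.0, 6)   # zmin=0.0, zmax=3.0, nbins=5
    no_assign = -99
    bin_index = np.digitize(zb, edges)
    bin_index[bin_index == 0] = no_assign           # below zmin
    bin_index[bin_index == len(edges)] = no_assign  # at or above zmax
    # bin_index is now [-99, 1, 3, 5, -99]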
diff --git a/src/rail/estimation/algos/var_inf.py b/src/rail/estimation/algos/var_inf.py
index 9e8b9ff1..b4abf2f2 100644
--- a/src/rail/estimation/algos/var_inf.py
+++ b/src/rail/estimation/algos/var_inf.py
@@ -11,21 +11,20 @@
 from rail.estimation.informer import PzInformer
 from rail.core.data import QPHandle

-TEENY = 1.e-15
+TEENY = 1.0e-15

 class VarInfStackInformer(PzInformer):
-    """Placeholder Informer
-    """
+    """Placeholder Informer"""

-    name = 'VarInfStackInformer'
+    name = "VarInfStackInformer"
     config_options = PzInformer.config_options.copy()

     def __init__(self, args, comm=None):
         PzInformer.__init__(self, args, comm=comm)

     def run(self):
-        self.add_data('model', np.array([None]))
+        self.add_data("model", np.array([None]))

 class VarInfStackSummarizer(PZSummarizer):
@@ -47,17 +46,22 @@ class VarInfStackSummarizer(PZSummarizer):
        number of samples used in dirichlet to determine error bar
     """

-    name = 'VarInfStackSummarizer'
+    name = "VarInfStackSummarizer"
     config_options = PZSummarizer.config_options.copy()
-    config_options.update(zmin=Param(float, 0.0, msg="The minimum redshift of the z grid"),
-                          zmax=Param(float, 3.0, msg="The maximum redshift of the z grid"),
-                          nzbins=Param(int, 301, msg="The number of gridpoints in the z grid"),
-                          seed=Param(int, 87, msg="random seed"),
-                          niter=Param(int, 100, msg="The number of iterations in the variational inference"),
-                          nsamples=Param(int, 500, msg="The number of samples used in dirichlet uncertainty"))
-    inputs = [('input', QPHandle)]
-    outputs = [('output', QPHandle),
-               ('single_NZ', QPHandle)]
+    config_options.update(
+        zmin=Param(float, 0.0, msg="The minimum redshift of the z grid"),
+        zmax=Param(float, 3.0, msg="The maximum redshift of the z grid"),
+        nzbins=Param(int, 301, msg="The number of gridpoints in the z grid"),
+        seed=Param(int, 87, msg="random seed"),
+        niter=Param(
+            int, 100, msg="The number of iterations in the variational inference"
+        ),
+        nsamples=Param(
+            int, 500, msg="The number of samples used in dirichlet uncertainty"
+        ),
+    )
+    inputs = [("input", QPHandle)]
+    outputs = [("output", QPHandle), ("single_NZ", QPHandle)]

     def __init__(self, args, comm=None):
         PZSummarizer.__init__(self, args, comm=comm)
@@ -67,16 +71,15 @@ def run(self):
         # Redefining the chunk size so that all of the data is distributed at once in the
         # nodes. This would fill all the memory if not enough nodes are allocated
-        input_data = self.get_handle('input', allow_missing=True)
+        input_data = self.get_handle("input", allow_missing=True)
         try:
             self.config.hdf5_groupname
-        except:
+        except Exception:
             self.config.hdf5_groupname = None
         input_length = input_data.size(groupname=self.config.hdf5_groupname)
-        self.config.chunk_size = np.ceil(input_length/self.size)
-
+        self.config.chunk_size = np.ceil(input_length / self.size)

-        iterator = self.input_iterator('input')
+        iterator = self.input_iterator("input")
         self.zgrid = np.linspace(self.config.zmin, self.config.zmax, self.config.nzbins)
         first = True
         for s, e, test_data in iterator:
@@ -89,17 +92,24 @@ def run(self):
             # qp_d = qp.Ensemble(qp.interp, data=dict(xvals=self.zgrid, yvals=alpha_trace))
             # instead, sample and save the samples
             rng = np.random.default_rng(seed=self.config.seed)
-            sample_pz = dirichlet.rvs(alpha_trace, size=self.config.nsamples, random_state=rng)
-            qp_d = qp.Ensemble(qp.interp, data=dict(xvals=self.zgrid, yvals=alpha_trace))
-
-            sample_ens = qp.Ensemble(qp.interp, data=dict(xvals=self.zgrid, yvals=sample_pz))
-            self.add_data('output', sample_ens)
-            self.add_data('single_NZ', qp_d)
-
-
-    def _process_chunk(self, start, end, test_data, first):
-        if not first: #pragma: no cover
-            raise ValueError(f"This algorithm needs all data in memory at once, increase nprocess or chunk size.")
+            sample_pz = dirichlet.rvs(
+                alpha_trace, size=self.config.nsamples, random_state=rng
+            )
+            qp_d = qp.Ensemble(
+                qp.interp, data=dict(xvals=self.zgrid, yvals=alpha_trace)
+            )
+
+            sample_ens = qp.Ensemble(
+                qp.interp, data=dict(xvals=self.zgrid, yvals=sample_pz)
+            )
+            self.add_data("output", sample_ens)
+            self.add_data("single_NZ", qp_d)
+
+    def _process_chunk(self, _start, _end, test_data, first):
+        if not first:  # pragma: no cover
+            raise ValueError(
+                "This algorithm needs all data in memory at once, increase nprocess or chunk size."
+            )

        # Initializing arrays
         alpha_trace = np.ones(len(self.zgrid))
@@ -107,7 +117,9 @@ def _process_chunk(self, start, end, test_data, first):
         pdf_vals = test_data.pdf(self.zgrid)
         log_pdf_vals = np.log(np.array(pdf_vals) + TEENY)
         for _ in range(self.config.niter):
-            dig = np.array([digamma(kk) - digamma(np.sum(alpha_trace)) for kk in alpha_trace])
+            dig = np.array(
+                [digamma(kk) - digamma(np.sum(alpha_trace)) for kk in alpha_trace]
+            )
             matrix_grid = np.exp(dig + log_pdf_vals)
             gamma_matrix = np.array([kk / np.sum(kk) for kk in matrix_grid])
             for kk in matrix_grid:
@@ -118,6 +130,4 @@ def _process_chunk(self, start, end, test_data, first):
             else:
                 nk = nk_partial
             alpha_trace = nk + init_trace
-        return(alpha_trace)
-
-
+        return alpha_trace
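An illustrative aside (not part of the diff): one iteration of the variational update above computes expected log-weights via digamma, reweights the per-galaxy PDFs, normalizes each galaxy's responsibilities, and sums them into new Dirichlet parameters. A vectorized standalone sketch with fake PDF values (equivalent to the list comprehensions above):

    import numpy as np
    from scipy.special import digamma

    rng = np.random.default_rng(87)
    zgrid = np.linspace(0.0, 3.0, 301)
    # Stand-in for test_data.pdf(zgrid), offset by TEENY before the log.
    log_pdf_vals = np.log(np.abs(rng.normal(size=(50, len(zgrid)))) + 1.0e-15)
    init_trace = np.ones(len(zgrid))
    alpha_trace = init_trace.copy()
    for _ in range(3):  # a few iterations, for illustration
        dig = digamma(alpha_trace) - digamma(np.sum(alpha_trace))
        matrix_grid = np.exp(dig + log_pdf_vals)
        gamma_matrix = matrix_grid / np.sum(matrix_grid, axis=1, keepdims=True)
        alpha_trace = np.sum(gamma_matrix, axis=0) + init_trace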
""" - self.set_data('input', input_data) + self.set_data("input", input_data) self.run() self.finalize() - return self.get_handle('output') - + return self.get_handle("output") + - class PZClassifier(RailStage): """The base class for assigning classes (tomographic bins) to per-galaxy PZ estimates PZClassifier take as "input" a `qp.Ensemble` with per-galaxy PDFs, and provide as "output" a tabular data which can be appended to the catalogue. """ - - name='PZClassifier' + + name = "PZClassifier" config_options = RailStage.config_options.copy() config_options.update(chunk_size=10000) - inputs = [('input', QPHandle)] - outputs = [('output', TableHandle)] - + inputs = [("input", QPHandle)] + outputs = [("output", TableHandle)] + def __init__(self, args, comm=None): """Initialize Classifier""" RailStage.__init__(self, args, comm=comm) - + def classify(self, input_data): """The main run method for the classifier, should be implemented in the specific subclass. @@ -136,7 +132,7 @@ def classify(self, input_data): output: `dict` Class assignment for each galaxy. """ - self.set_data('input', input_data) + self.set_data("input", input_data) self.run() self.finalize() - return self.get_handle('output') + return self.get_handle("output") diff --git a/src/rail/estimation/estimator.py b/src/rail/estimation/estimator.py index a00c4523..577bf5c3 100644 --- a/src/rail/estimation/estimator.py +++ b/src/rail/estimation/estimator.py @@ -7,9 +7,10 @@ from rail.core.data import TableHandle, QPHandle, ModelHandle from rail.core.stage import RailStage -from rail.estimation.informer import CatInformer from rail.core.point_estimation import PointEstimationMixin -# for backwards compatibility + +# for backwards compatibility, to avoid break stuff that imports it from here +from rail.estimation.informer import CatInformer # pylint: disable=unused-import class CatEstimator(RailStage, PointEstimationMixin): @@ -23,15 +24,15 @@ class CatEstimator(RailStage, PointEstimationMixin): """ - name = 'CatEstimator' + name = "CatEstimator" config_options = RailStage.config_options.copy() config_options.update( chunk_size=10000, - hdf5_groupname=SHARED_PARAMS['hdf5_groupname'], - calculated_point_estimates=SHARED_PARAMS['calculated_point_estimates']) - inputs = [('model', ModelHandle), - ('input', TableHandle)] - outputs = [('output', QPHandle)] + hdf5_groupname=SHARED_PARAMS["hdf5_groupname"], + calculated_point_estimates=SHARED_PARAMS["calculated_point_estimates"], + ) + inputs = [("model", ModelHandle), ("input", TableHandle)] + outputs = [("output", QPHandle)] def __init__(self, args, comm=None): """Initialize Estimator""" @@ -54,18 +55,18 @@ def open_model(self, **kwargs): self.model : `object` The object encapsulating the trained model. 
""" - model = kwargs.get('model', None) - if model is None or model == 'None': + model = kwargs.get("model", None) + if model is None or model == "None": self.model = None return self.model if isinstance(model, str): - self.model = self.set_data('model', data=None, path=model) - self.config['model'] = model + self.model = self.set_data("model", data=None, path=model) + self.config["model"] = model return self.model if isinstance(model, ModelHandle): if model.has_path: - self.config['model'] = model.path - self.model = self.set_data('model', model) + self.config["model"] = model.path + self.model = self.set_data("model", model) return self.model def estimate(self, input_data): @@ -92,16 +93,15 @@ def estimate(self, input_data): output: `QPHandle` Handle providing access to QP ensemble with output data """ - self.set_data('input', input_data) + self.set_data("input", input_data) self.run() self.finalize() - return self.get_handle('output') + return self.get_handle("output") def run(self): - self.open_model(**self.config) - iterator = self.input_iterator('input') + iterator = self.input_iterator("input") first = True self._initialize_run() self._output_handle = None @@ -121,12 +121,16 @@ def _finalize_run(self): self._output_handle.finalize_write() def _process_chunk(self, start, end, data, first): - raise NotImplementedError(f"{self.name}._process_chunk is not implemented") # pragma: no cover + raise NotImplementedError( + f"{self.name}._process_chunk is not implemented" + ) # pragma: no cover def _do_chunk_output(self, qp_dstn, start, end, first): qp_dstn = self.calculate_point_estimates(qp_dstn) if first: - self._output_handle = self.add_handle('output', data=qp_dstn) - self._output_handle.initialize_write(self._input_length, communicator=self.comm) + self._output_handle = self.add_handle("output", data=qp_dstn) + self._output_handle.initialize_write( + self._input_length, communicator=self.comm + ) self._output_handle.set_data(qp_dstn, partial=True) self._output_handle.write_chunk(start, end) diff --git a/src/rail/estimation/informer.py b/src/rail/estimation/informer.py index b47a5d27..a9075a32 100644 --- a/src/rail/estimation/informer.py +++ b/src/rail/estimation/informer.py @@ -9,8 +9,9 @@ from rail.core.data import TableHandle, QPHandle, ModelHandle from rail.core.stage import RailStage + class CatInformer(RailStage): - """The base class for informing models used to make photo-z data products + """The base class for informing models used to make photo-z data products from catalog-like inputs (i.e., tables with fluxes in photometric bands among the set of columns). @@ -26,13 +27,13 @@ class CatInformer(RailStage): They take as "input" catalog-like tabular data, which is used to "inform" the model. 
""" - name = 'CatInformer' + name = "CatInformer" config_options = RailStage.config_options.copy() - inputs = [('input', TableHandle)] - outputs = [('model', ModelHandle)] + inputs = [("input", TableHandle)] + outputs = [("model", ModelHandle)] def __init__(self, args, comm=None): - """Initialize Informer that can inform models for redshift estimation """ + """Initialize Informer that can inform models for redshift estimation""" RailStage.__init__(self, args, comm=comm) self.model = None @@ -60,10 +61,11 @@ def inform(self, training_data): model : ModelHandle Handle providing access to trained model """ - self.set_data('input', training_data) + self.set_data("input", training_data) self.run() self.finalize() - return self.get_handle('model') + return self.get_handle("model") + class PzInformer(RailStage): """The base class for informing models used to make photo-z data products from @@ -81,13 +83,13 @@ class PzInformer(RailStage): They take as "input" a qp.Ensemble of per-galaxy p(z) data, which is used to "inform" the model. """ - name = 'PzInformer' + name = "PzInformer" config_options = RailStage.config_options.copy() - inputs = [('input', QPHandle)] - outputs = [('model', ModelHandle)] + inputs = [("input", QPHandle)] + outputs = [("model", ModelHandle)] def __init__(self, args, comm=None): - """Initialize Informer that can inform models for redshift estimation """ + """Initialize Informer that can inform models for redshift estimation""" RailStage.__init__(self, args, comm=comm) self.model = None @@ -115,7 +117,7 @@ def inform(self, training_data): model : ModelHandle Handle providing access to trained model """ - self.set_data('input', training_data) + self.set_data("input", training_data) self.run() self.finalize() - return self.get_handle('model') + return self.get_handle("model") diff --git a/src/rail/estimation/summarizer.py b/src/rail/estimation/summarizer.py index c3e01b12..6b9f9040 100644 --- a/src/rail/estimation/summarizer.py +++ b/src/rail/estimation/summarizer.py @@ -1,10 +1,11 @@ """ Abstract base classes defining Summarizers of the redshift distribution of an ensemble of galaxies """ +import numpy as np + from rail.core.data import QPHandle, TableHandle, ModelHandle from rail.core.stage import RailStage -from rail.estimation.informer import PzInformer # for backwards compatibility @@ -18,11 +19,11 @@ class CatSummarizer(RailStage): provide as "output" a QPEnsemble, with per-ensemble n(z). """ - name = 'CatSummarizer' + name = "CatSummarizer" config_options = RailStage.config_options.copy() config_options.update(chunk_size=10000) - inputs = [('input', TableHandle)] - outputs = [('output', QPHandle)] + inputs = [("input", TableHandle)] + outputs = [("output", QPHandle)] def __init__(self, args, comm=None): """Initialize Summarizer""" @@ -53,10 +54,10 @@ def summarize(self, input_data): output: `qp.Ensemble` Ensemble with n(z), and any ancilary data """ - self.set_data('input', input_data) + self.set_data("input", input_data) self.run() self.finalize() - return self.get_handle('output') + return self.get_handle("output") class PZSummarizer(RailStage): @@ -66,12 +67,11 @@ class PZSummarizer(RailStage): provide as "output" a QPEnsemble, with per-ensemble n(z). 
""" - name = 'PZtoNZSummarizer' + name = "PZtoNZSummarizer" config_options = RailStage.config_options.copy() config_options.update(chunk_size=10000) - inputs = [('model', ModelHandle), - ('input', QPHandle)] - outputs = [('output', QPHandle)] + inputs = [("model", ModelHandle), ("input", QPHandle)] + outputs = [("output", QPHandle)] def __init__(self, args, comm=None): """Initialize Estimator that can sample galaxy data.""" @@ -102,43 +102,47 @@ def summarize(self, input_data): output: `qp.Ensemble` Ensemble with n(z), and any ancilary data """ - self.set_data('input', input_data) + self.set_data("input", input_data) self.run() self.finalize() - return self.get_handle('output') - + return self.get_handle("output") def _broadcast_bootstrap_matrix(self): - import numpy as np rng = np.random.default_rng(seed=self.config.seed) # Only one of the nodes needs to produce the bootstrap indices ngal = self._input_length if self.rank == 0: - bootstrap_matrix = rng.integers(low=0, high=ngal, size=(ngal,self.config.nsamples)) + bootstrap_matrix = rng.integers( + low=0, high=ngal, size=(ngal, self.config.nsamples) + ) else: # pragma: no cover bootstrap_matrix = None if self.comm is not None: # pragma: no cover self.comm.Barrier() - bootstrap_matrix = self.comm.bcast(bootstrap_matrix, root = 0) + bootstrap_matrix = self.comm.bcast(bootstrap_matrix, root=0) return bootstrap_matrix - def _join_histograms(self, bvals, yvals):#pragma: no cover + def _join_histograms(self, bvals, yvals): # pragma: no cover bvals_r = self.comm.reduce(bvals) yvals_r = self.comm.reduce(yvals) - return(bvals_r, yvals_r) + return (bvals_r, yvals_r) + class SZPZSummarizer(RailStage): """The base class for classes that use two sets of data: a photometry sample with spec-z values, and a photometry sample with unknown redshifts, e.g. minisom_som and outputs a QP Ensemble with bootstrap realization of the N(z) distribution """ - name = 'SZPZtoNZSummarizer' + + name = "SZPZtoNZSummarizer" config_options = RailStage.config_options.copy() config_options.update(chunk_size=10000) - inputs = [('input', TableHandle), - ('spec_input', TableHandle), - ('model', ModelHandle)] - outputs = [('output', QPHandle)] + inputs = [ + ("input", TableHandle), + ("spec_input", TableHandle), + ("model", ModelHandle), + ] + outputs = [("output", QPHandle)] def __init__(self, args, comm=None): """Initialize Estimator that can sample galaxy data.""" @@ -164,18 +168,18 @@ def open_model(self, **kwargs): self.model : `object` The object encapsulating the trained model. 
""" - model = kwargs.get('model', None) - if model is None or model == 'None': # pragma: no cover + model = kwargs.get("model", None) + if model is None or model == "None": # pragma: no cover self.model = None return self.model if isinstance(model, str): - self.model = self.set_data('model', data=None, path=model) - self.config['model'] = model + self.model = self.set_data("model", data=None, path=model) + self.config["model"] = model return self.model if isinstance(model, ModelHandle): if model.has_path: - self.config['model'] = model.path - self.model = self.set_data('model', model) + self.config["model"] = model.path + self.model = self.set_data("model", model) return self.model def summarize(self, input_data, spec_data): @@ -203,8 +207,8 @@ def summarize(self, input_data, spec_data): output: `qp.Ensemble` Ensemble with n(z), and any ancilary data """ - self.set_data('input', input_data) - self.set_data('spec_input', spec_data) + self.set_data("input", input_data) + self.set_data("spec_input", spec_data) self.run() self.finalize() - return self.get_handle('output') + return self.get_handle("output") diff --git a/src/rail/evaluation/evaluator.py b/src/rail/evaluation/evaluator.py index 50679077..ffd4c72c 100644 --- a/src/rail/evaluation/evaluator.py +++ b/src/rail/evaluation/evaluator.py @@ -7,30 +7,36 @@ import numpy as np from ceci.config import StageParameter as Param +from qp.metrics.pit import PIT from rail.core.data import Hdf5Handle, QPHandle from rail.core.stage import RailStage from rail.core.common_params import SHARED_PARAMS from rail.evaluation.metrics.cdeloss import CDELoss -from qp.metrics.pit import PIT -from rail.evaluation.metrics.pointestimates import PointSigmaIQR, PointBias, PointOutlierRate, PointSigmaMAD +from rail.evaluation.metrics.pointestimates import ( + PointSigmaIQR, + PointBias, + PointOutlierRate, + PointSigmaMAD, +) class Evaluator(RailStage): - """Evaluate the performance of a photo-Z estimator """ + """Evaluate the performance of a photo-Z estimator""" - name = 'Evaluator' + name = "Evaluator" config_options = RailStage.config_options.copy() - config_options.update(zmin=Param(float, 0., msg="min z for grid"), - zmax=Param(float, 3.0, msg="max z for grid"), - nzbins=Param(int, 301, msg="# of bins in zgrid"), - pit_metrics=Param(str, 'all', msg='PIT-based metrics to include'), - point_metrics=Param(str, 'all', msg='Point-estimate metrics to include'), - do_cde=Param(bool, True, msg='Evaluate CDE Metric'), - redshift_col=SHARED_PARAMS) - inputs = [('input', QPHandle), - ('truth', Hdf5Handle)] - outputs = [('output', Hdf5Handle)] + config_options.update( + zmin=Param(float, 0.0, msg="min z for grid"), + zmax=Param(float, 3.0, msg="max z for grid"), + nzbins=Param(int, 301, msg="# of bins in zgrid"), + pit_metrics=Param(str, "all", msg="PIT-based metrics to include"), + point_metrics=Param(str, "all", msg="Point-estimate metrics to include"), + do_cde=Param(bool, True, msg="Evaluate CDE Metric"), + redshift_col=SHARED_PARAMS, + ) + inputs = [("input", QPHandle), ("truth", Hdf5Handle)] + outputs = [("output", Hdf5Handle)] def __init__(self, args, comm=None): """Initialize Evaluator""" @@ -61,14 +67,14 @@ def evaluate(self, data, truth): The evaluation metrics """ - self.set_data('input', data) - self.set_data('truth', truth) + self.set_data("input", data) + self.set_data("truth", truth) self.run() self.finalize() - return self.get_handle('output') + return self.get_handle("output") def run(self): - """ Run method + """Run method Evaluate all the metrics and 
put them into a table @@ -79,25 +85,25 @@ def run(self): Puts the data into the data store under this stages 'output' tag """ - pz_data = self.get_data('input') - z_true = self.get_data('truth')[self.config.redshift_col] - zgrid = np.linspace(self.config.zmin, self.config.zmax, self.config.nzbins+1) + pz_data = self.get_data("input") + z_true = self.get_data("truth")[self.config.redshift_col] + zgrid = np.linspace(self.config.zmin, self.config.zmax, self.config.nzbins + 1) # Create an instance of the PIT class pitobj = PIT(pz_data, z_true) # Build reference dictionary of the PIT meta-metrics from this PIT instance PIT_METRICS = dict( - AD=getattr(pitobj, 'evaluate_PIT_anderson_ksamp'), - CvM=getattr(pitobj, 'evaluate_PIT_CvM'), - KS=getattr(pitobj, 'evaluate_PIT_KS'), - OutRate=getattr(pitobj, 'evaluate_PIT_outlier_rate'), + AD=getattr(pitobj, "evaluate_PIT_anderson_ksamp"), + CvM=getattr(pitobj, "evaluate_PIT_CvM"), + KS=getattr(pitobj, "evaluate_PIT_KS"), + OutRate=getattr(pitobj, "evaluate_PIT_outlier_rate"), ) # Parse the input configuration to determine which meta-metrics should be calculated - if self.config.pit_metrics == 'all': + if self.config.pit_metrics == "all": pit_metrics = list(PIT_METRICS.keys()) - else: #pragma: no cover + else: # pragma: no cover pit_metrics = self.config.pit_metrics.split() # Evaluate each of the requested meta-metrics, and store the result in `out_table` @@ -108,19 +114,25 @@ def run(self): # The result objects of some meta-metrics are bespoke scipy objects with inconsistent fields. # Here we do our best to store the relevant fields in `out_table`. if isinstance(value, list): # pragma: no cover - out_table[f'PIT_{pit_metric}'] = value + out_table[f"PIT_{pit_metric}"] = value else: - out_table[f'PIT_{pit_metric}_stat'] = [getattr(value, 'statistic', None)] - out_table[f'PIT_{pit_metric}_pval'] = [getattr(value, 'p_value', None)] - out_table[f'PIT_{pit_metric}_significance_level'] = [getattr(value, 'significance_level', None)] - - POINT_METRICS = dict(SimgaIQR=PointSigmaIQR, - Bias=PointBias, - OutlierRate=PointOutlierRate, - SigmaMAD=PointSigmaMAD) - if self.config.point_metrics == 'all': + out_table[f"PIT_{pit_metric}_stat"] = [ + getattr(value, "statistic", None) + ] + out_table[f"PIT_{pit_metric}_pval"] = [getattr(value, "p_value", None)] + out_table[f"PIT_{pit_metric}_significance_level"] = [ + getattr(value, "significance_level", None) + ] + + POINT_METRICS = dict( + SimgaIQR=PointSigmaIQR, + Bias=PointBias, + OutlierRate=PointOutlierRate, + SigmaMAD=PointSigmaMAD, + ) + if self.config.point_metrics == "all": point_metrics = list(POINT_METRICS.keys()) - else: #pragma: no cover + else: # pragma: no cover point_metrics = self.config.point_metrics.split() z_mode = None @@ -128,13 +140,15 @@ def run(self): if z_mode is None: z_mode = np.squeeze(pz_data.mode(grid=zgrid)) value = POINT_METRICS[point_metric](z_mode, z_true).evaluate() - out_table[f'POINT_{point_metric}'] = [value] + out_table[f"POINT_{point_metric}"] = [value] if self.config.do_cde: value = CDELoss(pz_data, zgrid, z_true).evaluate() - out_table['CDE_stat'] = [value.statistic] - out_table['CDE_pval'] = [value.p_value] - - # Converting any possible None to NaN to write it - out_table_to_write = {key: np.array(val).astype(float) for key, val in out_table.items()} - self.add_data('output', out_table_to_write) + out_table["CDE_stat"] = [value.statistic] + out_table["CDE_pval"] = [value.p_value] + + # Converting any possible None to NaN to write it + out_table_to_write = { + key: 
np.array(val).astype(float) for key, val in out_table.items()
+        }
+        self.add_data("output", out_table_to_write)
diff --git a/src/rail/evaluation/metrics/base.py b/src/rail/evaluation/metrics/base.py
index 510a3a27..5e67c121 100644
--- a/src/rail/evaluation/metrics/base.py
+++ b/src/rail/evaluation/metrics/base.py
@@ -1,5 +1,6 @@
 class MetricEvaluator:
-    """ A superclass for metrics evaluations"""
+    """A superclass for metrics evaluations"""
+
     def __init__(self, qp_ens):
         """Class constructor.
         Parameters
@@ -9,7 +10,7 @@ def __init__(self, qp_ens):
         """
         self._qp_ens = qp_ens

-    def evaluate(self): #pragma: no cover
+    def evaluate(self):  # pragma: no cover
         """
        Evaluates the metric as a function of the truth and prediction
diff --git a/src/rail/evaluation/metrics/cdeloss.py b/src/rail/evaluation/metrics/cdeloss.py
index 6dd372ab..f51f3f8e 100644
--- a/src/rail/evaluation/metrics/cdeloss.py
+++ b/src/rail/evaluation/metrics/cdeloss.py
@@ -1,10 +1,11 @@
 import numpy as np
-from .base import MetricEvaluator
 from rail.evaluation.stats_groups import stat_and_pval
+from .base import MetricEvaluator

 class CDELoss(MetricEvaluator):
-    """ Conditional density loss """
+    """Conditional density loss"""
+
     def __init__(self, qp_ens, zgrid, ztrue):
         """Class constructor"""
         super().__init__(qp_ens)
@@ -23,7 +24,7 @@ def evaluate(self):
         """

         # Calculate first term E[\int f*(z | X)^2 dz]
-        term1 = np.mean(np.trapz(self._pdfs ** 2, x=self._xvals))
+        term1 = np.mean(np.trapz(self._pdfs**2, x=self._xvals))
         # z bin closest to ztrue
         nns = [np.argmin(np.abs(self._xvals - z)) for z in self._ztrue]
         # Calculate second term E[f*(Z | X)]
diff --git a/src/rail/evaluation/metrics/pointestimates.py b/src/rail/evaluation/metrics/pointestimates.py
index 7d5c3fb0..6a5eaa86 100644
--- a/src/rail/evaluation/metrics/pointestimates.py
+++ b/src/rail/evaluation/metrics/pointestimates.py
@@ -2,6 +2,7 @@

 from .base import MetricEvaluator

+
 class PointStatsEz(MetricEvaluator):
     """Copied from PZDC1paper repo. Adapted to remove the cut based on
     magnitude."""
@@ -22,7 +23,7 @@ def __init__(self, pzvec, szvec):
         super().__init__(None)
         self.pzs = pzvec
         self.szs = szvec
-        ez = (pzvec - szvec) / (1. + szvec)
+        ez = (pzvec - szvec) / (1.0 + szvec)
         self.ez = ez

     def evaluate(self):
@@ -42,7 +43,7 @@ def evaluate(self):
         sigma_IQR float: width of ez distribution for full sample
         sigma_IQR_magcut float: width of ez distribution for magcut sample
         """
-        x75, x25 = np.percentile(self.ez, [75., 25.])
+        x75, x25 = np.percentile(self.ez, [75.0, 25.0])
         iqr = x75 - x25
         sigma_iqr = iqr / 1.349
         return sigma_iqr
@@ -53,6 +54,7 @@ class PointBias(PointStatsEz):

     In keeping with the Science Book, this is just the median of the ez values
     """
+
     def evaluate(self):
         """
         Returns:
@@ -77,7 +79,7 @@ def evaluate(self):
         sig_iqr = PointSigmaIQR(self.pzs, self.szs).evaluate()
         threesig = 3.0 * sig_iqr
         cutcriterion = np.maximum(0.06, threesig)
-        mask = (np.fabs(self.ez) > cutcriterion)
+        mask = np.fabs(self.ez) > cutcriterion
         outlier = np.sum(mask)
         frac = float(outlier) / float(num)
         return frac
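An illustrative aside (not part of the diff): the point-estimate metrics above all start from ez = (zphot - zspec) / (1 + zspec), and sigma_IQR rescales the interquartile range to a Gaussian-equivalent width (the IQR of a unit normal is about 1.349). A standalone sketch with synthetic redshifts:

    import numpy as np

    rng = np.random.default_rng(0)
    zspec = rng.uniform(0.0, 3.0, 10000)
    zphot = zspec + 0.02 * (1.0 + zspec) * rng.normal(size=zspec.size)
    ez = (zphot - zspec) / (1.0 + zspec)
    x75, x25 = np.percentile(ez, [75.0, 25.0])
    sigma_iqr = (x75 - x25) / 1.349   # recovers ~0.02 for this toy sample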
-stat_and_pval = namedtuple('stat_and_pval', ['statistic', 'p_value']) -stat_crit_sig = namedtuple('stat_crit_sig', ['statistic', 'critical_values', 'significance_level']) +stat_and_pval = namedtuple("stat_and_pval", ["statistic", "p_value"]) +stat_crit_sig = namedtuple( + "stat_crit_sig", ["statistic", "critical_values", "significance_level"] +) diff --git a/src/rail/stages/__init__.py b/src/rail/stages/__init__.py index 7832f7d8..e6306567 100644 --- a/src/rail/stages/__init__.py +++ b/src/rail/stages/__init__.py @@ -1,4 +1,3 @@ - import rail from rail.core import RailEnv @@ -15,15 +14,17 @@ from rail.creation.degrader import * -#from rail.creation.degradation.spectroscopic_degraders import * + +# from rail.creation.degradation.spectroscopic_degraders import * from rail.creation.degradation.spectroscopic_selections import * from rail.creation.degradation.quantityCut import * from rail.creation.engine import * -#from rail.creation.engines.flowEngine import * -#from rail.creation.engines.galaxy_population_components import * -#from rail.creation.engines.dsps_photometry_creator import * -#from rail.creation.engines.dsps_sed_modeler import * + +# from rail.creation.engines.flowEngine import * +# from rail.creation.engines.galaxy_population_components import * +# from rail.creation.engines.dsps_photometry_creator import * +# from rail.creation.engines.dsps_sed_modeler import * from rail.evaluation.evaluator import Evaluator diff --git a/tests/cli/test_scripts.py b/tests/cli/test_scripts.py index 0e6f9fec..a2c8dc63 100644 --- a/tests/cli/test_scripts.py +++ b/tests/cli/test_scripts.py @@ -6,25 +6,29 @@ def test_render_nb(): nb_dir = "./tests/cli/" - nb_files = glob.glob(os.path.join(nb_dir,'*.ipynb')) - scripts.render_nb('docs', False, True, nb_files, skip=[]) - scripts.render_nb('docs', True, True, nb_files, skip=["./tests/cli/single_number.ipynb"]) - scripts.render_nb('docs', True, False, nb_files, skip=["./tests/cli/single_number.ipynb"]) + nb_files = glob.glob(os.path.join(nb_dir, "*.ipynb")) + scripts.render_nb("docs", False, True, nb_files, skip=[]) + scripts.render_nb( + "docs", True, True, nb_files, skip=["./tests/cli/single_number.ipynb"] + ) + scripts.render_nb( + "docs", True, False, nb_files, skip=["./tests/cli/single_number.ipynb"] + ) def test_clone_source(): - scripts.clone_source('..', GitMode.ssh, True, 'rail_packages.yml') - scripts.clone_source('..', GitMode.https, True, 'rail_packages.yml') - scripts.clone_source('..', GitMode.cli, True, 'rail_packages.yml') + scripts.clone_source("..", GitMode.ssh, True, "rail_packages.yml") + scripts.clone_source("..", GitMode.https, True, "rail_packages.yml") + scripts.clone_source("..", GitMode.cli, True, "rail_packages.yml") + - def test_update_source(): - scripts.update_source('..', True, 'rail_packages.yml') + scripts.update_source("..", True, "rail_packages.yml") def test_install(): - scripts.install('..', False, True, 'rail_packages.yml') - scripts.install('..', True, True, 'rail_packages.yml') + scripts.install("..", False, True, "rail_packages.yml") + scripts.install("..", True, True, "rail_packages.yml") def test_info(): diff --git a/tests/core/test_core.py b/tests/core/test_core.py index f30017c7..5e6c07df 100644 --- a/tests/core/test_core.py +++ b/tests/core/test_core.py @@ -3,12 +3,9 @@ from types import GeneratorType import numpy as np -import pandas as pd import pytest -import tempfile -import rail -from rail.core.common_params import SHARED_PARAMS, copy_param, set_param_default +from rail.core.common_params import 
copy_param, set_param_default from rail.core.data import ( DataHandle, DataStore, @@ -26,7 +23,7 @@ RowSelector, TableConverter, ) - + # def test_data_file(): # with pytest.raises(ValueError) as errinfo: @@ -36,39 +33,43 @@ def test_util_stages(): DS = RailStage.data_store DS.clear() - datapath = os.path.join(RAILDIR, "rail", "examples_data", "testdata", "test_dc2_training_9816.pq") - - data = DS.read_file('data', TableHandle, datapath) + datapath = os.path.join( + RAILDIR, "rail", "examples_data", "testdata", "test_dc2_training_9816.pq" + ) + + data = DS.read_file("data", TableHandle, datapath) table_conv = TableConverter.make_stage(name="conv", output_format="numpyDict") col_map = ColumnMapper.make_stage(name="col_map", columns={}) row_sel = RowSelector.make_stage(name="row_sel", start=1, stop=15) - with pytest.raises(KeyError) as errinfo: + with pytest.raises(KeyError) as _errinfo: table_conv.get_handle("nope", allow_missing=False) - conv_data = table_conv(data) + _conv_data = table_conv(data) mapped_data = col_map(data) - sel_data = row_sel(mapped_data) + _sel_data = row_sel(mapped_data) row_sel_2 = RowSelector.make_stage(name="row_sel_2", start=1, stop=15) row_sel_2.set_data("input", mapped_data.data) handle = row_sel_2.get_handle("input") - row_sel_3 = RowSelector.make_stage(name="row_sel_3", input=handle.path, start=1, stop=15) + row_sel_3 = RowSelector.make_stage( + name="row_sel_3", input=handle.path, start=1, stop=15 + ) row_sel_3.set_data("input", None, do_read=True) - + def do_data_handle(datapath, handle_class): - DS = RailStage.data_store + _DS = RailStage.data_store th = handle_class("data", path=datapath) - with pytest.raises(ValueError) as errinfo: + with pytest.raises(ValueError) as _errinfo: th.write() assert not th.has_data - with pytest.raises(ValueError) as errinfo: + with pytest.raises(ValueError) as _errinfo: th.write_chunk(0, 1) assert th.has_path assert th.is_written @@ -83,11 +84,11 @@ def do_data_handle(datapath, handle_class): assert th2.has_data assert not th2.has_path assert not th2.is_written - with pytest.raises(ValueError) as errinfo: + with pytest.raises(ValueError) as _errinfo: th2.open() - with pytest.raises(ValueError) as errinfo: + with pytest.raises(ValueError) as _errinfo: th2.write() - with pytest.raises(ValueError) as errinfo: + with pytest.raises(ValueError) as _errinfo: th2.write_chunk(0, 1) assert th2.make_name("data2") == f"data2.{handle_class.suffix}" @@ -97,7 +98,9 @@ def do_data_handle(datapath, handle_class): def test_pq_handle(): - datapath = os.path.join(RAILDIR, "rail", "examples_data", "testdata", "test_dc2_training_9816.pq") + datapath = os.path.join( + RAILDIR, "rail", "examples_data", "testdata", "test_dc2_training_9816.pq" + ) handle = do_data_handle(datapath, PqHandle) pqfile = handle.open() assert pqfile @@ -107,7 +110,9 @@ def test_pq_handle(): def test_qp_handle(): - datapath = os.path.join(RAILDIR, "rail", "examples_data", "testdata", "output_BPZ_lite.hdf5") + datapath = os.path.join( + RAILDIR, "rail", "examples_data", "testdata", "output_BPZ_lite.hdf5" + ) handle = do_data_handle(datapath, QPHandle) qpfile = handle.open() assert qpfile @@ -115,21 +120,30 @@ def test_qp_handle(): handle.close() assert handle.fileObj is None - with pytest.raises(TypeError) as errInfo: - bad_dh = QPHandle(data="this is not an Ensemble") + with pytest.raises(TypeError) as _errInfo: + _bad_dh = QPHandle(tag="bad_tag", data="this is not an Ensemble") def test_hdf5_handle(): - datapath = os.path.join(RAILDIR, "rail", "examples_data", "testdata", 
"test_dc2_training_9816.hdf5") + datapath = os.path.join( + RAILDIR, "rail", "examples_data", "testdata", "test_dc2_training_9816.hdf5" + ) handle = do_data_handle(datapath, Hdf5Handle) with handle.open(mode="r") as f: assert f assert handle.fileObj is not None datapath_chunked = os.path.join( - RAILDIR, "rail", "examples_data", "testdata", "test_dc2_training_9816_chunked.hdf5" + RAILDIR, + "rail", + "examples_data", + "testdata", + "test_dc2_training_9816_chunked.hdf5", ) handle_chunked = Hdf5Handle("chunked", handle.data, path=datapath_chunked) - from tables_io.arrayUtils import getGroupInputDataLength, getInitializationForODict, sliceDict + from tables_io.arrayUtils import ( # pylint: disable=import-outside-toplevel + getInitializationForODict, + sliceDict, + ) num_rows = len(handle.data["photometry"]["id"]) check_num_rows = len(handle()["photometry"]["id"]) @@ -143,9 +157,10 @@ def test_hdf5_handle(): for i in range(0, num_rows, chunk_size): start = i end = i + chunk_size - if end > num_rows: - end = num_rows - handle_chunked.set_data(sliceDict(handle.data["photometry"], slice(start, end)), partial=True) + end = min(end, num_rows) + handle_chunked.set_data( + sliceDict(handle.data["photometry"], slice(start, end)), partial=True + ) handle_chunked.write_chunk(start, end) write_size = handle_chunked.size() assert len(handle_chunked.data) <= 1000 @@ -159,7 +174,9 @@ def test_hdf5_handle(): def test_fits_handle(): - datapath = os.path.join(RAILDIR, "rail", "examples_data", "testdata", "output_BPZ_lite.fits") + datapath = os.path.join( + RAILDIR, "rail", "examples_data", "testdata", "output_BPZ_lite.fits" + ) handle = do_data_handle(datapath, FitsHandle) fitsfile = handle.open() assert fitsfile @@ -171,9 +188,21 @@ def test_fits_handle(): def test_model_handle(): DS = RailStage.data_store DS.clear() - model_path = os.path.join(RAILDIR, "rail", "examples_data", "estimation_data", "data", "CWW_HDFN_prior.pkl") + model_path = os.path.join( + RAILDIR, + "rail", + "examples_data", + "estimation_data", + "data", + "CWW_HDFN_prior.pkl", + ) model_path_copy = os.path.join( - RAILDIR, "rail", "examples_data", "estimation_data", "data", "CWW_HDFN_prior_copy.pkl" + RAILDIR, + "rail", + "examples_data", + "estimation_data", + "data", + "CWW_HDFN_prior_copy.pkl", ) mh = ModelHandle("model", path=model_path) mh2 = ModelHandle("model2", path=model_path) @@ -195,8 +224,10 @@ def test_model_handle(): def test_data_hdf5_iter(): DS = RailStage.data_store DS.clear() - - datapath = os.path.join(RAILDIR, "rail", "examples_data", "testdata", "test_dc2_training_9816.hdf5") + + datapath = os.path.join( + RAILDIR, "rail", "examples_data", "testdata", "test_dc2_training_9816.hdf5" + ) # data = DS.read_file('data', TableHandle, datapath) th = Hdf5Handle("data", path=datapath) @@ -207,9 +238,12 @@ def test_data_hdf5_iter(): assert xx[0] == i * 1000 assert xx[1] - xx[0] <= 1000 - data = DS.read_file("input", TableHandle, datapath) + _data = DS.read_file("input", TableHandle, datapath) cm = ColumnMapper.make_stage( - input=datapath, chunk_size=1000, hdf5_groupname="photometry", columns=dict(id="bob") + input=datapath, + chunk_size=1000, + hdf5_groupname="photometry", + columns=dict(id="bob"), ) x = cm.input_iterator("input") @@ -224,13 +258,19 @@ def test_data_store(): DS = RailStage.data_store DS.clear() DS.__class__.allow_overwrite = False - - datapath_hdf5 = os.path.join(RAILDIR, "rail", "examples_data", "testdata", "test_dc2_training_9816.hdf5") - datapath_pq = os.path.join(RAILDIR, "rail", "examples_data", 
"testdata", "test_dc2_training_9816.pq") + + datapath_hdf5 = os.path.join( + RAILDIR, "rail", "examples_data", "testdata", "test_dc2_training_9816.hdf5" + ) + datapath_pq = os.path.join( + RAILDIR, "rail", "examples_data", "testdata", "test_dc2_training_9816.pq" + ) datapath_hdf5_copy = os.path.join( RAILDIR, "rail", "examples_data", "testdata", "test_dc2_training_9816_copy.hdf5" ) - datapath_pq_copy = os.path.join(RAILDIR, "rail", "examples_data", "testdata", "test_dc2_training_9816_copy.pq") + datapath_pq_copy = os.path.join( + RAILDIR, "rail", "examples_data", "testdata", "test_dc2_training_9816_copy.pq" + ) DS.add_data("hdf5", None, Hdf5Handle, path=datapath_hdf5) DS.add_data("pq", None, PqHandle, path=datapath_pq) @@ -246,18 +286,18 @@ def test_data_store(): DS.write("pq_copy") DS.write("hdf5_copy") - with pytest.raises(KeyError) as errinfo: + with pytest.raises(KeyError) as _errinfo: DS.read("nope") - with pytest.raises(KeyError) as errinfo: + with pytest.raises(KeyError) as _errinfo: DS.open("nope") - with pytest.raises(KeyError) as errinfo: + with pytest.raises(KeyError) as _errinfo: DS.write("nope") - with pytest.raises(TypeError) as errinfo: + with pytest.raises(TypeError) as _errinfo: DS["nope"] = None - with pytest.raises(ValueError) as errinfo: + with pytest.raises(ValueError) as _errinfo: DS["pq"] = DS["pq"] - with pytest.raises(ValueError) as errinfo: + with pytest.raises(ValueError) as _errinfo: DS.pq = DS["pq"] assert repr(DS) @@ -275,7 +315,6 @@ def test_data_store(): os.remove(datapath_pq_copy) - def test_common_params(): par = copy_param("zmin") assert par.default == 0.0 @@ -300,13 +339,21 @@ def test_set_data_nonexistent_file(): col_map.set_data("model", None, path="./bad_directory/no_file.py") assert "Unable to find file" in err.context + def test_set_data_real_file(): """Create an instance of a child class of RailStage. Exercise the `set_data` method and pass in a path to model. The output of set_data should be `None`. 
""" DS = RailStage.data_store DS.clear() - model_path = os.path.join(RAILDIR, "rail", "examples_data", "estimation_data", "data", "CWW_HDFN_prior.pkl") + model_path = os.path.join( + RAILDIR, + "rail", + "examples_data", + "estimation_data", + "data", + "CWW_HDFN_prior.pkl", + ) DS.add_data("model", None, ModelHandle, path=model_path) col_map = ColumnMapper.make_stage(name="col_map", columns={}) diff --git a/tests/core/test_degraders.py b/tests/core/test_degraders.py index bb1aead8..d56fc1f7 100644 --- a/tests/core/test_degraders.py +++ b/tests/core/test_degraders.py @@ -1,5 +1,4 @@ import os -from typing import Type import numpy as np import pandas as pd @@ -8,11 +7,11 @@ from rail.core.data import DATA_STORE, TableHandle from rail.core.util_stages import ColumnMapper from rail.creation.degradation.quantityCut import QuantityCut -from rail.creation.degradation.spectroscopic_selections import * +from rail.creation.degradation.spectroscopic_selections import * # pylint: disable=wildcard-import,unused-wildcard-import -@pytest.fixture -def data(): +@pytest.fixture(name='data') +def data_fixture(): """Some dummy data to use below.""" DS = DATA_STORE() @@ -30,8 +29,8 @@ def data(): return DS.add_data("data", df, TableHandle, path="dummy.pd") -@pytest.fixture -def data_forspec(): +@pytest.fixture(name='data_forspec') +def data_forspec_fixture(): """Some dummy data to use below.""" DS = DATA_STORE() @@ -84,10 +83,10 @@ def test_QuantityCut_returns_correct_shape(data): def test_SpecSelection(data): bands = ["u", "g", "r", "i", "z", "y"] - band_dict = {band: f"mag_{band}_lsst" for band in bands} + _band_dict = {band: f"mag_{band}_lsst" for band in bands} rename_dict = {f"{band}_err": f"mag_err_{band}_lsst" for band in bands} rename_dict.update({f"{band}": f"mag_{band}_lsst" for band in bands}) - standard_colnames = [f"mag_{band}_lsst" for band in "ugrizy"] + _standard_colnames = [f"mag_{band}_lsst" for band in "ugrizy"] col_remapper_test = ColumnMapper.make_stage( name="col_remapper_test", hdf5_groupname="", columns=rename_dict @@ -96,51 +95,73 @@ def test_SpecSelection(data): degrader_GAMA = SpecSelection_GAMA.make_stage() degrader_GAMA(data) - degrader_GAMA.__repr__() + repr(degrader_GAMA) - os.remove(degrader_GAMA.get_output(degrader_GAMA.get_aliased_tag("output"), final_name=True)) + os.remove( + degrader_GAMA.get_output( + degrader_GAMA.get_aliased_tag("output"), final_name=True + ) + ) degrader_BOSS = SpecSelection_BOSS.make_stage() degrader_BOSS(data) - degrader_BOSS.__repr__() + repr(degrader_BOSS) - os.remove(degrader_BOSS.get_output(degrader_BOSS.get_aliased_tag("output"), final_name=True)) + os.remove( + degrader_BOSS.get_output( + degrader_BOSS.get_aliased_tag("output"), final_name=True + ) + ) degrader_DEEP2 = SpecSelection_DEEP2.make_stage() degrader_DEEP2(data) - degrader_DEEP2.__repr__() + repr(degrader_DEEP2) - os.remove(degrader_DEEP2.get_output(degrader_DEEP2.get_aliased_tag("output"), final_name=True)) + os.remove( + degrader_DEEP2.get_output( + degrader_DEEP2.get_aliased_tag("output"), final_name=True + ) + ) degrader_VVDSf02 = SpecSelection_VVDSf02.make_stage() degrader_VVDSf02(data) - degrader_VVDSf02.__repr__() + repr(degrader_VVDSf02) - degrader_zCOSMOS = SpecSelection_zCOSMOS.make_stage(colnames={"i": "mag_i_lsst", "redshift": "redshift"}) + degrader_zCOSMOS = SpecSelection_zCOSMOS.make_stage( + colnames={"i": "mag_i_lsst", "redshift": "redshift"} + ) degrader_zCOSMOS(data) - degrader_zCOSMOS.__repr__() + repr(degrader_zCOSMOS) - 
os.remove(degrader_zCOSMOS.get_output(degrader_zCOSMOS.get_aliased_tag("output"), final_name=True)) + os.remove( + degrader_zCOSMOS.get_output( + degrader_zCOSMOS.get_aliased_tag("output"), final_name=True + ) + ) degrader_HSC = SpecSelection_HSC.make_stage() degrader_HSC(data) - degrader_HSC.__repr__() + repr(degrader_HSC) - os.remove(degrader_HSC.get_output(degrader_HSC.get_aliased_tag("output"), final_name=True)) + os.remove( + degrader_HSC.get_output(degrader_HSC.get_aliased_tag("output"), final_name=True) + ) degrader_HSC = SpecSelection_HSC.make_stage(percentile_cut=70) degrader_HSC(data) - degrader_HSC.__repr__() + repr(degrader_HSC) - os.remove(degrader_HSC.get_output(degrader_HSC.get_aliased_tag("output"), final_name=True)) + os.remove( + degrader_HSC.get_output(degrader_HSC.get_aliased_tag("output"), final_name=True) + ) def test_SpecSelection_low_N_tot(data_forspec): bands = ["u", "g", "r", "i", "z", "y"] - band_dict = {band: f"mag_{band}_lsst" for band in bands} + _band_dict = {band: f"mag_{band}_lsst" for band in bands} rename_dict = {f"{band}_err": f"mag_err_{band}_lsst" for band in bands} rename_dict.update({f"{band}": f"mag_{band}_lsst" for band in bands}) - standard_colnames = [f"mag_{band}_lsst" for band in "ugrizy"] + _standard_colnames = [f"mag_{band}_lsst" for band in "ugrizy"] col_remapper_test = ColumnMapper.make_stage( name="col_remapper_test", hdf5_groupname="", columns=rename_dict @@ -150,7 +171,11 @@ def test_SpecSelection_low_N_tot(data_forspec): degrader_zCOSMOS = SpecSelection_zCOSMOS.make_stage(N_tot=1) degrader_zCOSMOS(data_forspec) - os.remove(degrader_zCOSMOS.get_output(degrader_zCOSMOS.get_aliased_tag("output"), final_name=True)) + os.remove( + degrader_zCOSMOS.get_output( + degrader_zCOSMOS.get_aliased_tag("output"), final_name=True + ) + ) @pytest.mark.parametrize("N_tot, errortype", [(-1, ValueError)]) @@ -168,10 +193,10 @@ def test_SpecSelection_bad_colname(data, errortype): degrader_GAMA(data) -@pytest.mark.parametrize("success_rate_dir, errortype", [("/this/path/should/not/exist", ValueError)]) +@pytest.mark.parametrize( + "success_rate_dir, errortype", [("/this/path/should/not/exist", ValueError)] +) def test_SpecSelection_bad_path(success_rate_dir, errortype): """Test bad parameters that should raise TypeError""" with pytest.raises(errortype): SpecSelection.make_stage(success_rate_dir=success_rate_dir) - - diff --git a/tests/core/test_introspection.py b/tests/core/test_introspection.py index 7310f385..6af30081 100644 --- a/tests/core/test_introspection.py +++ b/tests/core/test_introspection.py @@ -1,6 +1,4 @@ import tempfile -import pkgutil -import setuptools import rail from rail.core import RailEnv @@ -10,7 +8,7 @@ def test_print_rail_packages(): RailEnv.print_rail_packages() - + def test_print_rail_namespaces(): RailEnv.print_rail_namespaces() diff --git a/tests/core/test_pipeline.py b/tests/core/test_pipeline.py index c12bd2f9..8df171de 100644 --- a/tests/core/test_pipeline.py +++ b/tests/core/test_pipeline.py @@ -3,29 +3,31 @@ import ceci import numpy as np -import rail -from rail.core.data import TableHandle from rail.core.stage import RailPipeline, RailStage from rail.core.utils import RAILDIR from rail.core.util_stages import ColumnMapper, TableConverter -from rail.creation.degradation.quantityCut import QuantityCut + def test_pipeline(): DS = RailStage.data_store DS.__class__.allow_overwrite = True DS.clear() - input_file = os.path.join(RAILDIR, "rail/examples_data/goldenspike_data/data//test_flow_data.pq") + input_file = os.path.join( + 
RAILDIR, "rail/examples_data/goldenspike_data/data//test_flow_data.pq" + ) bands = ["u", "g", "r", "i", "z", "y"] - band_dict = {band: f"mag_{band}_lsst" for band in bands} + _band_dict = {band: f"mag_{band}_lsst" for band in bands} rename_dict = {f"mag_{band}_lsst": f"{band}_lsst" for band in bands} - post_grid = [float(x) for x in np.linspace(0.0, 5, 21)] + _post_grid = [float(x) for x in np.linspace(0.0, 5, 21)] col_remapper_test = ColumnMapper.make_stage( name="col_remapper_test", hdf5_groupname="", columns=rename_dict ) - table_conv_test = TableConverter.make_stage(name="table_conv_test", output_format="numpyDict", seed=12345) + table_conv_test = TableConverter.make_stage( + name="table_conv_test", output_format="numpyDict", seed=12345 + ) pipe = ceci.Pipeline.interactive() stages = [ @@ -37,7 +39,9 @@ def test_pipeline(): table_conv_test.connect_input(col_remapper_test) - pipe.initialize(dict(input=input_file), dict(output_dir=".", log_dir=".", resume=False), None) + pipe.initialize( + dict(input=input_file), dict(output_dir=".", log_dir=".", resume=False), None + ) pipe.save("stage.yaml") @@ -67,24 +71,28 @@ def test_golden_v2(): DS.clear() pipe = RailPipeline() - input_file = os.path.join(RAILDIR, "rail/examples_data/goldenspike_data/data//test_flow_data.pq") + input_file = os.path.join( + RAILDIR, "rail/examples_data/goldenspike_data/data//test_flow_data.pq" + ) bands = ["u", "g", "r", "i", "z", "y"] - band_dict = {band: f"mag_{band}_lsst" for band in bands} + _band_dict = {band: f"mag_{band}_lsst" for band in bands} rename_dict = {f"mag_{band}_lsst": f"{band}_lsst" for band in bands} - post_grid = [float(x) for x in np.linspace(0.0, 5, 21)] + _post_grid = [float(x) for x in np.linspace(0.0, 5, 21)] pipe.col_remapper_test = ColumnMapper.build( - hdf5_groupname="", - columns=rename_dict, - ) + hdf5_groupname="", + columns=rename_dict, + ) pipe.table_conv_test = TableConverter.build( - connections=dict(input=pipe.col_remapper_test.io.output), + connections=dict(input=pipe.col_remapper_test.io.output), # pylint: disable=no-member output_format="numpyDict", seed=12345, ) - pipe.initialize(dict(input=input_file), dict(output_dir=".", log_dir=".", resume=False), None) + pipe.initialize( + dict(input=input_file), dict(output_dir=".", log_dir=".", resume=False), None + ) pipe.save("stage.yaml") pr = ceci.Pipeline.read("stage.yaml") diff --git a/tests/core/test_point_estimation.py b/tests/core/test_point_estimation.py index b09e63de..56871fc8 100644 --- a/tests/core/test_point_estimation.py +++ b/tests/core/test_point_estimation.py @@ -11,8 +11,9 @@ def test_custom_point_estimate(): """ MEANING_OF_LIFE = 42.0 + class TestEstimator(CatEstimator): - name="TestEstimator" + name = "TestEstimator" def __init__(self, args, comm=None): CatEstimator.__init__(self, args, comm=comm) @@ -20,63 +21,66 @@ def __init__(self, args, comm=None): def _calculate_mode_point_estimate(self, qp_dist=None, grid=None): return np.ones(100) * MEANING_OF_LIFE - config_dict = {'calculated_point_estimates': ['mode']} + config_dict = {"calculated_point_estimates": ["mode"]} - test_estimator = TestEstimator.make_stage(name='test', **config_dict) + test_estimator = TestEstimator.make_stage(name="test", **config_dict) - locs = 2* (np.random.uniform(size=(100,1))-0.5) - scales = 1 + 0.2*(np.random.uniform(size=(100,1))-0.5) - test_ensemble = qp.Ensemble(qp.stats.norm, data=dict(loc=locs, scale=scales)) + locs = 2 * (np.random.uniform(size=(100, 1)) - 0.5) + scales = 1 + 0.2 * (np.random.uniform(size=(100, 1)) - 0.5) + 
test_ensemble = qp.Ensemble(qp.stats.norm, data=dict(loc=locs, scale=scales)) # pylint: disable=no-member result = test_estimator.calculate_point_estimates(test_ensemble) - assert np.all(result.ancil['mode'] == MEANING_OF_LIFE) + assert np.all(result.ancil["mode"] == MEANING_OF_LIFE) + def test_basic_point_estimate(): """This test checks to make sure that all the basic point estimates are executed when requested in the configuration dictionary. """ - config_dict = {'calculated_point_estimates': ['mean', 'median', 'mode'], - 'zmin': 0.0, - 'zmax': 3.0, - 'nzbins': 301} + config_dict = { + "calculated_point_estimates": ["mean", "median", "mode"], + "zmin": 0.0, + "zmax": 3.0, + "nzbins": 301, + } - test_estimator = CatEstimator.make_stage(name='test', **config_dict) + test_estimator = CatEstimator.make_stage(name="test", **config_dict) - locs = 2* (np.random.uniform(size=(100,1))-0.5) - scales = 1 + 0.2*(np.random.uniform(size=(100,1))-0.5) - test_ensemble = qp.Ensemble(qp.stats.norm, data=dict(loc=locs, scale=scales)) + locs = 2 * (np.random.uniform(size=(100, 1)) - 0.5) + scales = 1 + 0.2 * (np.random.uniform(size=(100, 1)) - 0.5) + test_ensemble = qp.Ensemble(qp.stats.norm, data=dict(loc=locs, scale=scales)) # pylint: disable=no-member result = test_estimator.calculate_point_estimates(test_ensemble, None) # note: we're not interested in testing the values of point estimates, # just that they were added to the ancillary data. - assert 'mode' in result.ancil - assert 'median' in result.ancil - assert 'mean' in result.ancil + assert "mode" in result.ancil + assert "median" in result.ancil + assert "mean" in result.ancil + def test_mode_no_grid(): - """This exercises the KeyError logic in `_calculate_mode_point_estimate`. - """ - config_dict = {'zmin':0.0, 'nzbins':100, 'calculated_point_estimates': ['mode']} + """This exercises the KeyError logic in `_calculate_mode_point_estimate`.""" + config_dict = {"zmin": 0.0, "nzbins": 100, "calculated_point_estimates": ["mode"]} - test_estimator = CatEstimator.make_stage(name='test', **config_dict) + test_estimator = CatEstimator.make_stage(name="test", **config_dict) with pytest.raises(KeyError) as excinfo: _ = test_estimator.calculate_point_estimates(None, None) assert "to be defined in stage configuration" in str(excinfo.value) + def test_mode_no_point_estimates(): - """This exercises the KeyError logic in `_calculate_mode_point_estimate`. 
- """ - config_dict = {'zmin':0.0, 'nzbins':100} + """This exercises the KeyError logic in `_calculate_mode_point_estimate`.""" + config_dict = {"zmin": 0.0, "nzbins": 100} - test_estimator = CatEstimator.make_stage(name='test', **config_dict) + test_estimator = CatEstimator.make_stage(name="test", **config_dict) - locs = 2* (np.random.uniform(size=(100,1))-0.5) - scales = 1 + 0.2*(np.random.uniform(size=(100,1))-0.5) - test_ensemble = qp.Ensemble(qp.stats.norm, data=dict(loc=locs, scale=scales)) + locs = 2 * (np.random.uniform(size=(100, 1)) - 0.5) + scales = 1 + 0.2 * (np.random.uniform(size=(100, 1)) - 0.5) + test_ensemble = qp.Ensemble(qp.stats.norm, data=dict(loc=locs, scale=scales)) # pylint: disable=no-member output_ensemble = test_estimator.calculate_point_estimates(test_ensemble, None) diff --git a/tests/estimation/test_algos.py b/tests/estimation/test_algos.py index e0503c0e..ceb171e6 100644 --- a/tests/estimation/test_algos.py +++ b/tests/estimation/test_algos.py @@ -23,18 +23,24 @@ def test_random_pz(): "model": "None", "seed": 42, } - zb_expected = np.array([2.322, 1.317, 2.576, 2.092, 0.283, 2.927, 2.283, 2.358, 0.384, 1.351]) + zb_expected = np.array( + [2.322, 1.317, 2.576, 2.092, 0.283, 2.927, 2.283, 2.358, 0.384, 1.351] + ) train_algo = random_gauss.RandomGaussInformer pz_algo = random_gauss.RandomGaussEstimator results, _, _ = one_algo( "RandomPZ", train_algo, pz_algo, train_config_dict, estim_config_dict ) - assert np.isclose(results.ancil['zmode'], zb_expected).all() + assert np.isclose(results.ancil["zmode"], zb_expected).all() def test_train_pz(): train_config_dict = dict( - zmin=0.0, zmax=3.0, nzbins=301, hdf5_groupname="photometry", model="model_train_z.tmp" + zmin=0.0, + zmax=3.0, + nzbins=301, + hdf5_groupname="photometry", + model="model_train_z.tmp", ) estim_config_dict = dict(hdf5_groupname="photometry", model="model_train_z.tmp") diff --git a/tests/estimation/test_classifier.py b/tests/estimation/test_classifier.py index a9a7b99d..6b795455 100644 --- a/tests/estimation/test_classifier.py +++ b/tests/estimation/test_classifier.py @@ -2,7 +2,7 @@ import numpy as np import pytest -from rail.core.utils import RAILDIR, find_rail_file +from rail.core.utils import RAILDIR from rail.core.stage import RailStage from rail.core.data import QPHandle from rail.estimation.algos.uniform_binning import UniformBinningClassifier @@ -12,133 +12,146 @@ DS = RailStage.data_store DS.__class__.allow_overwrite = True -inputdata = os.path.join(RAILDIR, 'rail/examples_data/testdata/output_BPZ_lite.hdf5') +inputdata = os.path.join(RAILDIR, "rail/examples_data/testdata/output_BPZ_lite.hdf5") + @pytest.mark.parametrize( - "input_param", - [{"zbin_edges": [0.0, 0.3]}, - {"zmin": 0.0, "zmax": 0.3, "nbins": 1}, - {"zbin_edges": [0.0, 0.3], "id_name": "CATAID"}, - ] + "input_param", + [ + {"zbin_edges": [0.0, 0.3]}, + {"zmin": 0.0, "zmax": 0.3, "nbins": 1}, + {"zbin_edges": [0.0, 0.3], "id_name": "CATAID"}, + ], ) - def test_UniformBinningClassifier(input_param): - DS.clear() - input_data = DS.read_file('input_data', QPHandle, inputdata) + DS.clear() + input_data = DS.read_file("input_data", QPHandle, inputdata) tomo = UniformBinningClassifier.make_stage( - point_estimate='zmode', + point_estimate="zmode", no_assign=-99, **input_param, ) - - out_data = tomo.classify(input_data) - + + _out_data = tomo.classify(input_data) + def test_UniformBinningClassifier_binsize(): - DS.clear() - input_data = DS.read_file('input_data', QPHandle, inputdata) + DS.clear() + input_data = 
DS.read_file("input_data", QPHandle, inputdata) tomo = UniformBinningClassifier.make_stage( - point_estimate='zmode', + point_estimate="zmode", no_assign=-99, - zmin=0.0, - zmax=2.0, + zmin=0.0, + zmax=2.0, nbins=2, ) output_data = tomo.classify(input_data) - out_data=output_data.data - + out_data = output_data.data + # check length: - assert len(out_data["class_id"])==len(out_data["row_index"]) - + assert len(out_data["class_id"]) == len(out_data["row_index"]) + # check that the assignment is as expected: - assert (np.in1d(np.unique(out_data["class_id"]),[1,2,-99])).all() - - zb = input_data.data.ancil['zmode'] + assert (np.in1d(np.unique(out_data["class_id"]), [1, 2, -99])).all() + + zb = input_data.data.ancil["zmode"] if 1 in out_data["class_id"]: - assert ((zb[out_data["class_id"]==1]>=0.0)&(zb[out_data["class_id"]==1]<1.0)).all() + assert ( + (zb[out_data["class_id"] == 1] >= 0.0) + & (zb[out_data["class_id"] == 1] < 1.0) + ).all() if 2 in out_data["class_id"]: - assert ((zb[out_data["class_id"]==2]>=1.0)&(zb[out_data["class_id"]==2]<2.0)).all() + assert ( + (zb[out_data["class_id"] == 2] >= 1.0) + & (zb[out_data["class_id"] == 2] < 2.0) + ).all() if -99 in out_data["class_id"]: - assert ((zb[out_data["class_id"]==-99]<0.0)|(zb[out_data["class_id"]==-99]>=2.0)).all() - + assert ( + (zb[out_data["class_id"] == -99] < 0.0) + | (zb[out_data["class_id"] == -99] >= 2.0) + ).all() + def test_UniformBinningClassifier_ancil(): - DS.clear() - input_data = DS.read_file('input_data', QPHandle, inputdata) + DS.clear() + input_data = DS.read_file("input_data", QPHandle, inputdata) tomo = UniformBinningClassifier.make_stage( - point_estimate='zmedian', + point_estimate="zmedian", no_assign=-99, - zmin=0.0, - zmax=2.0, + zmin=0.0, + zmax=2.0, nbins=2, ) with pytest.raises(KeyError): - out_data = tomo.classify(input_data) - + _out_data = tomo.classify(input_data) + @pytest.mark.parametrize( - "input_param", - [{"zmin": 0.0, "zmax": 0.3, "nbins": 1}, - {"zmin": 0.0, "zmax": 0.3, "nbins": 1, "id_name": "CATAID"}, - ] + "input_param", + [ + {"zmin": 0.0, "zmax": 0.3, "nbins": 1}, + {"zmin": 0.0, "zmax": 0.3, "nbins": 1, "id_name": "CATAID"}, + ], ) - def test_EqualCountClassifier(input_param): - DS.clear() - input_data = DS.read_file('input_data', QPHandle, inputdata) + DS.clear() + input_data = DS.read_file("input_data", QPHandle, inputdata) tomo = EqualCountClassifier.make_stage( - point_estimate='zmode', + point_estimate="zmode", no_assign=-99, **input_param, ) - - out_data = tomo.classify(input_data) + + _out_data = tomo.classify(input_data) def test_EqualCountClassifier_nobj(): - DS.clear() - input_data = DS.read_file('input_data', QPHandle, inputdata) + DS.clear() + input_data = DS.read_file("input_data", QPHandle, inputdata) tomo = EqualCountClassifier.make_stage( - point_estimate='zmode', + point_estimate="zmode", no_assign=-99, - zmin=0.0, - zmax=2.0, + zmin=0.0, + zmax=2.0, nbins=2, ) output_data = tomo.classify(input_data) - out_data=output_data.data - + out_data = output_data.data + # check that there are equal number of object in each bin modulo Ngal%Nbins - assert (np.in1d(np.unique(out_data["class_id"]), [1,2,-99])).all() - - Ngal=sum(out_data["class_id"]!=-99) - exp_Ngal_perbin=int(Ngal/2) + assert (np.in1d(np.unique(out_data["class_id"]), [1, 2, -99])).all() + + Ngal = sum(out_data["class_id"] != -99) + exp_Ngal_perbin = int(Ngal / 2) # check that each bin does contain number of objects consistent with expected number # exp_Ngal_perbin + 1 to account for the cases where Ngal%Nbins!=0 - 
assert sum(out_data["class_id"]==1) in [exp_Ngal_perbin, exp_Ngal_perbin+1] - assert sum(out_data["class_id"]==2) in [exp_Ngal_perbin, exp_Ngal_perbin+1] - + assert sum(out_data["class_id"] == 1) in [exp_Ngal_perbin, exp_Ngal_perbin + 1] + assert sum(out_data["class_id"] == 2) in [exp_Ngal_perbin, exp_Ngal_perbin + 1] + # check no assignment is correct - if Ngal=2.0)).all() + if Ngal < len(out_data["class_id"]): + zb = input_data.data.ancil["zmode"] + assert ( + (zb[out_data["class_id"] == -99] < 0.0) + | (zb[out_data["class_id"] == -99] >= 2.0) + ).all() def test_EqualCountClassifier_ancil(): - DS.clear() - input_data = DS.read_file('input_data', QPHandle, inputdata) + DS.clear() + input_data = DS.read_file("input_data", QPHandle, inputdata) tomo = EqualCountClassifier.make_stage( - point_estimate='zmedian', + point_estimate="zmedian", no_assign=-99, - zmin=0.0, - zmax=2.0, + zmin=0.0, + zmax=2.0, nbins=2, ) with pytest.raises(KeyError): - out_data = tomo.classify(input_data) \ No newline at end of file + _out_data = tomo.classify(input_data) diff --git a/tests/estimation/test_summarizers.py b/tests/estimation/test_summarizers.py index 4a59c4db..a48d98ec 100644 --- a/tests/estimation/test_summarizers.py +++ b/tests/estimation/test_summarizers.py @@ -18,16 +18,19 @@ def one_algo(key, summarizer_class, summary_kwargs): test_data = DS.read_file("test_data", QPHandle, testdata) summarizer = summarizer_class.make_stage(name=key, **summary_kwargs) summary_ens = summarizer.summarize(test_data) - os.remove(summarizer.get_output(summarizer.get_aliased_tag("output"), final_name=True)) - os.remove(summarizer.get_output(summarizer.get_aliased_tag("single_NZ"), final_name=True)) + os.remove( + summarizer.get_output(summarizer.get_aliased_tag("output"), final_name=True) + ) + os.remove( + summarizer.get_output(summarizer.get_aliased_tag("single_NZ"), final_name=True) + ) return summary_ens def test_naive_stack(): - """Basic end to end test for the Naive stack informer to estimator stages - """ + """Basic end to end test for the Naive stack informer to estimator stages""" naive_stack_informer_stage = naive_stack.NaiveStackInformer.make_stage() - naive_stack_informer_stage.inform('') + naive_stack_informer_stage.inform("") summary_config_dict = {} summarizer_class = naive_stack.NaiveStackSummarizer @@ -39,7 +42,7 @@ def test_point_estimate_hist(): stages """ point_est_informer_stage = point_est_hist.PointEstHistInformer.make_stage() - point_est_informer_stage.inform('') + point_est_informer_stage.inform("") summary_config_dict = {} summarizer_class = point_est_hist.PointEstHistSummarizer @@ -47,10 +50,9 @@ def test_point_estimate_hist(): def test_var_inference_stack(): - """Basic end to end test for the var inference informer to estimator stages - """ + """Basic end to end test for the var inference informer to estimator stages""" var_inf_informer_stage = var_inf.VarInfStackInformer.make_stage() - var_inf_informer_stage.inform('') + var_inf_informer_stage.inform("") summary_config_dict = {} summarizer_class = var_inf.VarInfStackSummarizer diff --git a/tests/evaluation/test_evaluation.py b/tests/evaluation/test_evaluation.py index e13a7bf2..d9ff37a3 100644 --- a/tests/evaluation/test_evaluation.py +++ b/tests/evaluation/test_evaluation.py @@ -7,7 +7,6 @@ from rail.core.data import QPHandle, TableHandle from rail.core.stage import RailStage from rail.evaluation.evaluator import Evaluator -from rail.evaluation.metrics.cdeloss import CDELoss # values for metrics OUTRATE = 0.0 @@ -30,7 +29,7 @@ def 
construct_test_ensemble(): locs = np.expand_dims(true_zs + np.random.normal(0.0, 0.01, NPDF), -1) true_ez = (locs.flatten() - true_zs) / (1.0 + true_zs) scales = np.ones((NPDF, 1)) * 0.1 + np.random.uniform(size=(NPDF, 1)) * 0.05 - n_ens = qp.Ensemble(qp.stats.norm, data=dict(loc=locs, scale=scales)) + n_ens = qp.Ensemble(qp.stats.norm, data=dict(loc=locs, scale=scales)) # pylint: disable=no-member zgrid = np.linspace(0, nmax, 301) grid_ens = n_ens.convert_to(qp.interp_gen, xvals=zgrid) return zgrid, true_zs, grid_ens, true_ez @@ -59,11 +58,13 @@ def test_point_metrics(): def test_evaluation_stage(): DS = RailStage.data_store - zgrid, zspec, pdf_ens, true_ez = construct_test_ensemble() + _zgrid, zspec, pdf_ens, _true_ez = construct_test_ensemble() pdf = DS.add_data("pdf", pdf_ens, QPHandle) truth_table = dict(redshift=zspec) truth = DS.add_data("truth", truth_table, TableHandle) evaluator = Evaluator.make_stage(name="Eval") evaluator.evaluate(pdf, truth) - os.remove(evaluator.get_output(evaluator.get_aliased_tag("output"), final_name=True)) + os.remove( + evaluator.get_output(evaluator.get_aliased_tag("output"), final_name=True) + )
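To close, a minimal sketch (assuming a rail installation containing the modules touched above; the synthetic data is illustrative) of the point-estimate metrics these tests exercise:

    import numpy as np

    from rail.evaluation.metrics.pointestimates import PointBias, PointSigmaIQR

    # Synthetic point estimates with a small (1 + z)-scaled Gaussian scatter,
    # mirroring construct_test_ensemble above.
    rng = np.random.default_rng(0)
    z_true = rng.uniform(0.0, 3.0, size=1000)
    z_point = z_true + (1.0 + z_true) * rng.normal(0.0, 0.01, size=1000)

    # Both metrics work on e_z = (z_point - z_true) / (1 + z_true); sigma_IQR
    # divides the interquartile range of e_z by 1.349, the factor that
    # recovers sigma for a Gaussian, so it should land near the input 0.01.
    print(PointSigmaIQR(z_point, z_true).evaluate())  # ~0.01
    print(PointBias(z_point, z_true).evaluate())  # ~0.0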