From 2b75d6e9b6411e3d0e1a3a68e83df1aea08b46a2 Mon Sep 17 00:00:00 2001 From: Eric Charles Date: Tue, 13 Aug 2024 16:04:07 -0700 Subject: [PATCH] Switch to using Param and SHARED_PARAMS as much as possible (#161) * Switch to using Param and SHARED_PARAMS as much as possible * updated params in table_tools.py --- src/rail/core/common_params.py | 1 + src/rail/creation/degrader.py | 5 ++++- src/rail/creation/degraders/quantityCut.py | 5 ++++- src/rail/creation/engine.py | 14 +++++++++++--- src/rail/estimation/algos/true_nz.py | 4 ++-- src/rail/estimation/classifier.py | 8 ++++++-- src/rail/estimation/estimator.py | 2 +- src/rail/estimation/summarizer.py | 7 ++++--- src/rail/tools/table_tools.py | 15 ++++++++++++--- 9 files changed, 45 insertions(+), 16 deletions(-) diff --git a/src/rail/core/common_params.py b/src/rail/core/common_params.py index 3c9d8704..c9fb34bc 100644 --- a/src/rail/core/common_params.py +++ b/src/rail/core/common_params.py @@ -30,6 +30,7 @@ hdf5_groupname=Param( str, "photometry", msg="name of hdf5 group for data, if None, then set to ''" ), + chunk_size=Param(int, 10000, msg="Number of objects per chunk for parallel processing"), zmin=Param(float, 0.0, msg="The minimum redshift of the z grid"), zmax=Param(float, 3.0, msg="The maximum redshift of the z grid"), nzbins=Param(int, 301, msg="The number of gridpoints in the z grid"), diff --git a/src/rail/creation/degrader.py b/src/rail/creation/degrader.py index d184821e..a074a652 100644 --- a/src/rail/creation/degrader.py +++ b/src/rail/creation/degrader.py @@ -4,6 +4,7 @@ and returns a pandas DataFrame, and wraps the run method. 
""" +from ceci.config import StageParameter as Param from rail.core.stage import RailStage from rail.core.data import PqHandle @@ -19,7 +20,9 @@ class Degrader(RailStage): # pragma: no cover name = "Degrader" config_options = RailStage.config_options.copy() - config_options.update(seed=12345) + config_options.update( + seed=Param(int, default=12345, msg="Random number seed"), + ) inputs = [("input", PqHandle)] outputs = [("output", PqHandle)] diff --git a/src/rail/creation/degraders/quantityCut.py b/src/rail/creation/degraders/quantityCut.py index 171ab6c6..235776c4 100644 --- a/src/rail/creation/degraders/quantityCut.py +++ b/src/rail/creation/degraders/quantityCut.py @@ -2,6 +2,7 @@ from numbers import Number +from ceci.config import StageParameter as Param import numpy as np from rail.creation.selector import Selector @@ -15,7 +16,9 @@ class QuantityCut(Selector): name = "QuantityCut" config_options = Selector.config_options.copy() - config_options.update(cuts=dict) + config_options.update( + cuts=Param(dict, required=True, msg="Cuts to apply"), + ) def __init__(self, args, **kwargs): """Constructor. 
diff --git a/src/rail/creation/engine.py b/src/rail/creation/engine.py index b7a056ad..6d97c2d5 100644 --- a/src/rail/creation/engine.py +++ b/src/rail/creation/engine.py @@ -6,6 +6,7 @@ """ import qp +from ceci.config import StageParameter as Param from rail.core.data import DataHandle, ModelHandle, QPHandle, TableHandle from rail.core.stage import RailStage @@ -15,7 +16,9 @@ class Modeler(RailStage): # pragma: no cover name = "Modeler" config_options = RailStage.config_options.copy() - config_options.update(seed=12345) + config_options.update( + seed=Param(int, default=12345, msg="Random number seed"), + ) inputs = [("input", DataHandle)] outputs = [("model", ModelHandle)] @@ -52,7 +55,10 @@ class Creator(RailStage): # pragma: no cover name = "Creator" config_options = RailStage.config_options.copy() - config_options.update(n_samples=int, seed=12345) + config_options.update( + n_samples=Param(int, required=True, msg="Number of samples to create"), + seed=Param(int, default=12345, msg="Random number seed"), + ) inputs = [("model", ModelHandle)] outputs = [("output", TableHandle)] @@ -137,7 +143,9 @@ class PosteriorCalculator(RailStage): # pragma: no cover name = "PosteriorCalculator" config_options = RailStage.config_options.copy() - config_options.update(column=str) + config_options.update( + column=Param(str, required=True, msg="Column to compute posterior for"), + ) inputs = [ ("model", ModelHandle), ("input", TableHandle), diff --git a/src/rail/estimation/algos/true_nz.py b/src/rail/estimation/algos/true_nz.py index 8b148672..d4ee147f 100644 --- a/src/rail/estimation/algos/true_nz.py +++ b/src/rail/estimation/algos/true_nz.py @@ -22,8 +22,8 @@ class TrueNZHistogrammer(RailStage): nzbins=SHARED_PARAMS, redshift_col=SHARED_PARAMS, selected_bin=Param(int, -1, msg="Which tomography bin to consider"), - chunk_size=10000, - hdf5_groupname="", + chunk_size=SHARED_PARAMS, + hdf5_groupname=SHARED_PARAMS, ) inputs = [("input", TableHandle), ("tomography_bins", 
TableHandle)] outputs = [("true_NZ", QPHandle)] diff --git a/src/rail/estimation/classifier.py b/src/rail/estimation/classifier.py index 5014f886..f3bdd565 100644 --- a/src/rail/estimation/classifier.py +++ b/src/rail/estimation/classifier.py @@ -4,6 +4,7 @@ import gc from rail.core.data import QPHandle, TableHandle, ModelHandle, Hdf5Handle +from rail.core.common_params import SHARED_PARAMS from rail.core.stage import RailStage @@ -19,7 +20,10 @@ class CatClassifier(RailStage): # pragma: no cover name = "CatClassifier" config_options = RailStage.config_options.copy() - config_options.update(chunk_size=10000, hdf5_groupname=str) + config_options.update( + chunk_size=SHARED_PARAMS, + hdf5_groupname=SHARED_PARAMS, + ) inputs = [("model", ModelHandle), ("input", TableHandle)] outputs = [("output", TableHandle)] @@ -102,7 +106,7 @@ class PZClassifier(RailStage): name = "PZClassifier" config_options = RailStage.config_options.copy() - config_options.update(chunk_size=10000) + config_options.update(chunk_size=SHARED_PARAMS) inputs = [("input", QPHandle)] outputs = [("output", Hdf5Handle)] diff --git a/src/rail/estimation/estimator.py b/src/rail/estimation/estimator.py index 59b3d2aa..d0c3203d 100644 --- a/src/rail/estimation/estimator.py +++ b/src/rail/estimation/estimator.py @@ -28,7 +28,7 @@ class CatEstimator(RailStage, PointEstimationMixin): name = "CatEstimator" config_options = RailStage.config_options.copy() config_options.update( - chunk_size=Param(dtype=int, default=10000), + chunk_size=SHARED_PARAMS, hdf5_groupname=SHARED_PARAMS, zmin=SHARED_PARAMS, zmax=SHARED_PARAMS, diff --git a/src/rail/estimation/summarizer.py b/src/rail/estimation/summarizer.py index 4dd65178..a5d2b202 100644 --- a/src/rail/estimation/summarizer.py +++ b/src/rail/estimation/summarizer.py @@ -5,6 +5,7 @@ import numpy as np from rail.core.data import QPHandle, TableHandle, ModelHandle +from rail.core.common_params import SHARED_PARAMS from rail.core.stage import RailStage # for backwards 
compatibility @@ -22,7 +23,7 @@ class CatSummarizer(RailStage): name = "CatSummarizer" config_options = RailStage.config_options.copy() - config_options.update(chunk_size=10000) + config_options.update(chunk_size=SHARED_PARAMS) inputs = [("input", TableHandle)] outputs = [("output", QPHandle)] @@ -66,7 +67,7 @@ class PZSummarizer(RailStage): name = "PZtoNZSummarizer" config_options = RailStage.config_options.copy() - config_options.update(chunk_size=10000) + config_options.update(chunk_size=SHARED_PARAMS) inputs = [("model", ModelHandle), ("input", QPHandle)] outputs = [("output", QPHandle)] @@ -129,7 +130,7 @@ class SZPZSummarizer(RailStage): name = "SZPZtoNZSummarizer" config_options = RailStage.config_options.copy() - config_options.update(chunk_size=10000) + config_options.update(chunk_size=SHARED_PARAMS) inputs = [ ("input", TableHandle), ("spec_input", TableHandle), diff --git a/src/rail/tools/table_tools.py b/src/rail/tools/table_tools.py index c05702a0..d81c759d 100644 --- a/src/rail/tools/table_tools.py +++ b/src/rail/tools/table_tools.py @@ -2,6 +2,7 @@ import tables_io +from ceci.config import StageParameter as Param from rail.core.stage import RailStage from rail.core.data import PqHandle, Hdf5Handle @@ -21,7 +22,10 @@ class ColumnMapper(RailStage): name = "ColumnMapper" config_options = RailStage.config_options.copy() - config_options.update(chunk_size=100_000, columns=dict, inplace=False) + config_options.update( + columns=Param(dict, required=True, msg="Map of columns to rename"), + inplace=Param(bool, default=False, msg="Update file in place"), + ) inputs = [("input", PqHandle)] outputs = [("output", PqHandle)] @@ -70,7 +74,10 @@ class RowSelector(RailStage): name = "RowSelector" config_options = RailStage.config_options.copy() - config_options.update(start=int, stop=int) + config_options.update( + start=Param(int, required=True, msg="Starting row number"), + stop=Param(int, required=True, msg="Stopping row number"), + ) inputs = [("input", 
PqHandle)] outputs = [("output", PqHandle)] @@ -112,7 +119,9 @@ class TableConverter(RailStage): name = "TableConverter" config_options = RailStage.config_options.copy() - config_options.update(output_format=str) + config_options.update( + output_format=Param(str, required=True, msg="Format of output table"), + ) inputs = [("input", PqHandle)] outputs = [("output", Hdf5Handle)]