Skip to content

Commit

Permalink
Switch to using Param and SHARED_PARAMS as much as possible (#161)
Browse files Browse the repository at this point in the history
* Switch to using Param and SHARED_PARAMS as much as possible

* updated params in table_tools.py
  • Loading branch information
eacharles authored Aug 13, 2024
1 parent 6a5046d commit 2b75d6e
Show file tree
Hide file tree
Showing 9 changed files with 45 additions and 16 deletions.
1 change: 1 addition & 0 deletions src/rail/core/common_params.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
hdf5_groupname=Param(
str, "photometry", msg="name of hdf5 group for data, if None, then set to ''"
),
chunk_size=Param(int, 10000, msg="Number of object per chunk for parallel processing"),
zmin=Param(float, 0.0, msg="The minimum redshift of the z grid"),
zmax=Param(float, 3.0, msg="The maximum redshift of the z grid"),
nzbins=Param(int, 301, msg="The number of gridpoints in the z grid"),
Expand Down
5 changes: 4 additions & 1 deletion src/rail/creation/degrader.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
and returns a pandas DataFrame, and wraps the run method.
"""

from ceci.config import StageParameter as Param
from rail.core.stage import RailStage
from rail.core.data import PqHandle

Expand All @@ -19,7 +20,9 @@ class Degrader(RailStage): # pragma: no cover

name = "Degrader"
config_options = RailStage.config_options.copy()
config_options.update(seed=12345)
config_options.update(
seed=Param(int, default=12345, msg="Random number seed"),
)
inputs = [("input", PqHandle)]
outputs = [("output", PqHandle)]

Expand Down
5 changes: 4 additions & 1 deletion src/rail/creation/degraders/quantityCut.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from numbers import Number

from ceci.config import StageParameter as Param
import numpy as np
from rail.creation.selector import Selector

Expand All @@ -15,7 +16,9 @@ class QuantityCut(Selector):

name = "QuantityCut"
config_options = Selector.config_options.copy()
config_options.update(cuts=dict)
config_options.update(
cuts=Param(dict, required=True, msg="Cuts to apply"),
)

def __init__(self, args, **kwargs):
"""Constructor.
Expand Down
14 changes: 11 additions & 3 deletions src/rail/creation/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
"""

import qp
from ceci.config import StageParameter as Param
from rail.core.data import DataHandle, ModelHandle, QPHandle, TableHandle
from rail.core.stage import RailStage

Expand All @@ -15,7 +16,9 @@ class Modeler(RailStage): # pragma: no cover

name = "Modeler"
config_options = RailStage.config_options.copy()
config_options.update(seed=12345)
config_options.update(
seed=Param(int, default=12345, msg="Random number seed"),
)
inputs = [("input", DataHandle)]
outputs = [("model", ModelHandle)]

Expand Down Expand Up @@ -52,7 +55,10 @@ class Creator(RailStage): # pragma: no cover

name = "Creator"
config_options = RailStage.config_options.copy()
config_options.update(n_samples=int, seed=12345)
config_options.update(
n_samples=Param(int, required=True, msg="Number of samples to create"),
seed=Param(int, default=12345, msg="Random number seed"),
)
inputs = [("model", ModelHandle)]
outputs = [("output", TableHandle)]

Expand Down Expand Up @@ -137,7 +143,9 @@ class PosteriorCalculator(RailStage): # pragma: no cover

name = "PosteriorCalculator"
config_options = RailStage.config_options.copy()
config_options.update(column=str)
config_options.update(
column=Param(str, required=True, msg="Column to compute posterior for"),
)
inputs = [
("model", ModelHandle),
("input", TableHandle),
Expand Down
4 changes: 2 additions & 2 deletions src/rail/estimation/algos/true_nz.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,8 @@ class TrueNZHistogrammer(RailStage):
nzbins=SHARED_PARAMS,
redshift_col=SHARED_PARAMS,
selected_bin=Param(int, -1, msg="Which tomography bin to consider"),
chunk_size=10000,
hdf5_groupname="",
chunk_size=SHARED_PARAMS,
hdf5_groupname=SHARED_PARAMS,
)
inputs = [("input", TableHandle), ("tomography_bins", TableHandle)]
outputs = [("true_NZ", QPHandle)]
Expand Down
8 changes: 6 additions & 2 deletions src/rail/estimation/classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import gc
from rail.core.data import QPHandle, TableHandle, ModelHandle, Hdf5Handle
from rail.core.common_params import SHARED_PARAMS
from rail.core.stage import RailStage


Expand All @@ -19,7 +20,10 @@ class CatClassifier(RailStage): # pragma: no cover

name = "CatClassifier"
config_options = RailStage.config_options.copy()
config_options.update(chunk_size=10000, hdf5_groupname=str)
config_options.update(
chunk_size=SHARED_PARAMS,
hdf5_groupname=SHARED_PARAMS,
)
inputs = [("model", ModelHandle), ("input", TableHandle)]
outputs = [("output", TableHandle)]

Expand Down Expand Up @@ -102,7 +106,7 @@ class PZClassifier(RailStage):

name = "PZClassifier"
config_options = RailStage.config_options.copy()
config_options.update(chunk_size=10000)
config_options.update(chunk_size=SHARED_PARAMS)
inputs = [("input", QPHandle)]
outputs = [("output", Hdf5Handle)]

Expand Down
2 changes: 1 addition & 1 deletion src/rail/estimation/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ class CatEstimator(RailStage, PointEstimationMixin):
name = "CatEstimator"
config_options = RailStage.config_options.copy()
config_options.update(
chunk_size=Param(dtype=int, default=10000),
chunk_size=SHARED_PARAMS,
hdf5_groupname=SHARED_PARAMS,
zmin=SHARED_PARAMS,
zmax=SHARED_PARAMS,
Expand Down
7 changes: 4 additions & 3 deletions src/rail/estimation/summarizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import numpy as np

from rail.core.data import QPHandle, TableHandle, ModelHandle
from rail.core.common_params import SHARED_PARAMS
from rail.core.stage import RailStage

# for backwards compatibility
Expand All @@ -22,7 +23,7 @@ class CatSummarizer(RailStage):

name = "CatSummarizer"
config_options = RailStage.config_options.copy()
config_options.update(chunk_size=10000)
config_options.update(chunk_size=SHARED_PARAMS)
inputs = [("input", TableHandle)]
outputs = [("output", QPHandle)]

Expand Down Expand Up @@ -66,7 +67,7 @@ class PZSummarizer(RailStage):

name = "PZtoNZSummarizer"
config_options = RailStage.config_options.copy()
config_options.update(chunk_size=10000)
config_options.update(chunk_size=SHARED_PARAMS)
inputs = [("model", ModelHandle), ("input", QPHandle)]
outputs = [("output", QPHandle)]

Expand Down Expand Up @@ -129,7 +130,7 @@ class SZPZSummarizer(RailStage):

name = "SZPZtoNZSummarizer"
config_options = RailStage.config_options.copy()
config_options.update(chunk_size=10000)
config_options.update(chunk_size=SHARED_PARAMS)
inputs = [
("input", TableHandle),
("spec_input", TableHandle),
Expand Down
15 changes: 12 additions & 3 deletions src/rail/tools/table_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import tables_io

from ceci.config import StageParameter as Param
from rail.core.stage import RailStage

from rail.core.data import PqHandle, Hdf5Handle
Expand All @@ -21,7 +22,10 @@ class ColumnMapper(RailStage):

name = "ColumnMapper"
config_options = RailStage.config_options.copy()
config_options.update(chunk_size=100_000, columns=dict, inplace=False)
config_options.update(
columns=Param(dict, required=True, msg="Map of columns to rename"),
inplace=Param(bool, default=False, msg="Update file in place"),
)
inputs = [("input", PqHandle)]
outputs = [("output", PqHandle)]

Expand Down Expand Up @@ -70,7 +74,10 @@ class RowSelector(RailStage):

name = "RowSelector"
config_options = RailStage.config_options.copy()
config_options.update(start=int, stop=int)
config_options.update(
start=Param(int, required=True, msg="Starting row number"),
stop=Param(int, required=True, msg="Stoppig row number"),
)
inputs = [("input", PqHandle)]
outputs = [("output", PqHandle)]

Expand Down Expand Up @@ -112,7 +119,9 @@ class TableConverter(RailStage):

name = "TableConverter"
config_options = RailStage.config_options.copy()
config_options.update(output_format=str)
config_options.update(
output_format=Param(str, required=True, msg="Format of output table"),
)
inputs = [("input", PqHandle)]
outputs = [("output", Hdf5Handle)]

Expand Down

0 comments on commit 2b75d6e

Please sign in to comment.