diff --git a/src/rail/estimation/algos/true_nz.py b/src/rail/estimation/algos/true_nz.py new file mode 100644 index 00000000..0b89f083 --- /dev/null +++ b/src/rail/estimation/algos/true_nz.py @@ -0,0 +1,125 @@ +""" +A summarizer-like stage that simple makes a histogram the true nz +""" + +import numpy as np +import qp + +from ceci.config import StageParameter as Param +from rail.core.common_params import SHARED_PARAMS +from rail.core.stage import RailStage +from rail.core.data import QPHandle, TableHandle + + +class TrueNZHistogrammer(RailStage): + """Summarizer-like stage which simply histograms the true redshift""" + + name = "TrueNZHistogrammer" + config_options = RailStage.config_options.copy() + config_options.update( + zmin=SHARED_PARAMS, + zmax=SHARED_PARAMS, + nzbins=SHARED_PARAMS, + redshift_col=SHARED_PARAMS, + selected_bin=Param(int, -1, msg="Which tomography bin to consider"), + chunk_size=10000, + hdf5_groupname="", + ) + inputs = [("input", TableHandle), ("tomography_bins", TableHandle)] + outputs = [("true_NZ", QPHandle)] + + def __init__(self, args, comm=None): + RailStage.__init__(self, args, comm=comm) + self.zgrid = None + self.bincents = None + + def _setup_iterator(self): + + itrs = [ + self.input_iterator('input', groupname=self.config.hdf5_groupname), + self.input_iterator('tomography_bins', groupname=""), + ] + + for it in zip(*itrs): + first = True + mask = None + for s, e, d in it: + if first: + start = s + end = e + pz_data = d + first = False + else: + if self.config.selected_bin < 0: + mask = np.ones(e-s, dtype=bool) + else: + mask = d['class_id'] == self.config.selected_bin + yield start, end, pz_data, mask + + def run(self): + iterator = self._setup_iterator() + self.zgrid = np.linspace( + self.config.zmin, self.config.zmax, self.config.nzbins + 1 + ) + self.bincents = 0.5 * (self.zgrid[1:] + self.zgrid[:-1]) + # Initiallizing the histograms + single_hist = np.zeros(self.config.nzbins) + + first = True + for s, e, data, mask in iterator: + print(f"Process {self.rank} running estimator on chunk {s} - {e}") + self._process_chunk( + s, e, data, mask, first, single_hist + ) + first = False + if self.comm is not None: # pragma: no cover + single_hist = self.comm.reduce(single_hist) + + if self.rank == 0: + n_total = single_hist.sum() + qp_d = qp.Ensemble( + qp.hist, + data=dict(bins=self.zgrid, pdfs=np.atleast_2d(single_hist)), + ancil=dict(n_total=np.array([n_total], dtype=int)), + ) + self.add_data("true_NZ", qp_d) + + def _process_chunk( + self, _start, _end, data, mask, _first, single_hist, + ): + squeeze_mask = np.squeeze(mask) + zb = data[self.config.redshift_col][squeeze_mask] + single_hist += np.histogram(zb, bins=self.zgrid)[0] + + def histogram(self, catalog, tomo_bins): + """The main interface method for ``TrueNZHistogrammer``. + + Creates histogram of N of Z_true. + + This will attach the sample to this `Stage` (for introspection and + provenance tracking). + + Then it will call the run() and finalize() methods, which need to be + implemented by the sub-classes. + + The run() method will need to register the data that it creates to this + Estimator by using ``self.add_data('output', output_data)``. + + Finally, this will return a PqHandle providing access to that output + data. + + Parameters + ---------- + catalog : table-like + The sample with the true NZ column + + Returns + ------- + output_data : QPHandle + A handle giving access to a the histogram in QP format + """ + self.set_data("input", catalog) + self.set_data("tomography_bins", tomo_bins) + self.run() + self.finalize() + return self.get_handle("true_NZ") diff --git a/src/rail/examples_data/testdata/output_tomo.hdf5 b/src/rail/examples_data/testdata/output_tomo.hdf5 new file mode 100644 index 00000000..faa84853 Binary files /dev/null and b/src/rail/examples_data/testdata/output_tomo.hdf5 differ diff --git a/src/rail/stages/__init__.py b/src/rail/stages/__init__.py index 25f1adc6..93ec5f07 100644 --- a/src/rail/stages/__init__.py +++ b/src/rail/stages/__init__.py @@ -15,6 +15,7 @@ from rail.estimation.algos.var_inf import VarInfStackInformer, VarInfStackSummarizer from rail.estimation.algos.uniform_binning import UniformBinningClassifier from rail.estimation.algos.equal_count import EqualCountClassifier +from rail.estimation.algos.true_nz import TrueNZHistogrammer from rail.creation.degrader import Degrader @@ -59,6 +60,7 @@ def import_and_attach_all(): "VarInfStackSummarizer", "UniformBinningClassifier", "EqualCountClassifier", + "TrueNZHistogrammer", "Degrader", "AddColumnOfRandom", "QuantityCut", diff --git a/tests/estimation/test_true_nz.py b/tests/estimation/test_true_nz.py new file mode 100644 index 00000000..e2ba3de7 --- /dev/null +++ b/tests/estimation/test_true_nz.py @@ -0,0 +1,53 @@ +import os +import numpy as np +import pytest +import qp + +from rail.utils.path_utils import RAILDIR +from rail.core.stage import RailStage +from rail.core.data import TableHandle +from rail.estimation.algos.true_nz import TrueNZHistogrammer + + +DS = RailStage.data_store +DS.__class__.allow_overwrite = True + +true_nz_file = "src/rail/examples_data/testdata/validation_10gal.hdf5" +tomo_file = "src/rail/examples_data/testdata/output_tomo.hdf5" + + +def test_true_nz(): + DS.clear() + true_nz = DS.read_file('true_nz', path=true_nz_file, handle_class=TableHandle) + tomo_bins = DS.read_file('tomo_bins', path=tomo_file, handle_class=TableHandle) + + nz_hist = TrueNZHistogrammer.make_stage( + name='true_nz', + hdf5_groupname='photometry', + redshift_col='redshift', + zmin=0.0, + zmax=3.0, + nzbins=301, + ) + out_hist = nz_hist.histogram(true_nz, tomo_bins) + + check_ens = qp.read(out_hist.path) + assert check_ens.ancil['n_total'][0] == 10 + + check_vals = [0, 5, 0, 0, 0] + + for i in range(5): + nz_hist = TrueNZHistogrammer.make_stage( + name='true_nz', + hdf5_groupname='photometry', + redshift_col='redshift', + zmin=0.0, + zmax=3.0, + nzbins=301, + selected_bin=i, + ) + out_hist = nz_hist.histogram(true_nz, tomo_bins) + check_ens = qp.read(out_hist.path) + assert check_ens.ancil['n_total'][0] == check_vals[i] + +