Skip to content

Commit

Permalink
Issues/127/true nz (#129)
Browse files Browse the repository at this point in the history
  • Loading branch information
eacharles authored Jul 3, 2024
1 parent b0421b6 commit 0222836
Show file tree
Hide file tree
Showing 4 changed files with 180 additions and 0 deletions.
125 changes: 125 additions & 0 deletions src/rail/estimation/algos/true_nz.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
"""
A summarizer-like stage that simple makes a histogram the true nz
"""

import numpy as np
import qp

from ceci.config import StageParameter as Param
from rail.core.common_params import SHARED_PARAMS
from rail.core.stage import RailStage
from rail.core.data import QPHandle, TableHandle


class TrueNZHistogrammer(RailStage):
"""Summarizer-like stage which simply histograms the true redshift"""

name = "TrueNZHistogrammer"
config_options = RailStage.config_options.copy()
config_options.update(
zmin=SHARED_PARAMS,
zmax=SHARED_PARAMS,
nzbins=SHARED_PARAMS,
redshift_col=SHARED_PARAMS,
selected_bin=Param(int, -1, msg="Which tomography bin to consider"),
chunk_size=10000,
hdf5_groupname="",
)
inputs = [("input", TableHandle), ("tomography_bins", TableHandle)]
outputs = [("true_NZ", QPHandle)]

def __init__(self, args, comm=None):
RailStage.__init__(self, args, comm=comm)
self.zgrid = None
self.bincents = None

def _setup_iterator(self):

itrs = [
self.input_iterator('input', groupname=self.config.hdf5_groupname),
self.input_iterator('tomography_bins', groupname=""),
]

for it in zip(*itrs):
first = True
mask = None
for s, e, d in it:
if first:
start = s
end = e
pz_data = d
first = False
else:
if self.config.selected_bin < 0:
mask = np.ones(e-s, dtype=bool)
else:
mask = d['class_id'] == self.config.selected_bin
yield start, end, pz_data, mask

def run(self):
iterator = self._setup_iterator()
self.zgrid = np.linspace(
self.config.zmin, self.config.zmax, self.config.nzbins + 1
)
self.bincents = 0.5 * (self.zgrid[1:] + self.zgrid[:-1])
# Initiallizing the histograms
single_hist = np.zeros(self.config.nzbins)

first = True
for s, e, data, mask in iterator:
print(f"Process {self.rank} running estimator on chunk {s} - {e}")
self._process_chunk(
s, e, data, mask, first, single_hist
)
first = False
if self.comm is not None: # pragma: no cover
single_hist = self.comm.reduce(single_hist)

if self.rank == 0:
n_total = single_hist.sum()
qp_d = qp.Ensemble(
qp.hist,
data=dict(bins=self.zgrid, pdfs=np.atleast_2d(single_hist)),
ancil=dict(n_total=np.array([n_total], dtype=int)),
)
self.add_data("true_NZ", qp_d)

def _process_chunk(
self, _start, _end, data, mask, _first, single_hist,
):
squeeze_mask = np.squeeze(mask)
zb = data[self.config.redshift_col][squeeze_mask]
single_hist += np.histogram(zb, bins=self.zgrid)[0]

def histogram(self, catalog, tomo_bins):
"""The main interface method for ``TrueNZHistogrammer``.
Creates histogram of N of Z_true.
This will attach the sample to this `Stage` (for introspection and
provenance tracking).
Then it will call the run() and finalize() methods, which need to be
implemented by the sub-classes.
The run() method will need to register the data that it creates to this
Estimator by using ``self.add_data('output', output_data)``.
Finally, this will return a PqHandle providing access to that output
data.
Parameters
----------
catalog : table-like
The sample with the true NZ column
Returns
-------
output_data : QPHandle
A handle giving access to a the histogram in QP format
"""
self.set_data("input", catalog)
self.set_data("tomography_bins", tomo_bins)
self.run()
self.finalize()
return self.get_handle("true_NZ")
Binary file added src/rail/examples_data/testdata/output_tomo.hdf5
Binary file not shown.
2 changes: 2 additions & 0 deletions src/rail/stages/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from rail.estimation.algos.var_inf import VarInfStackInformer, VarInfStackSummarizer
from rail.estimation.algos.uniform_binning import UniformBinningClassifier
from rail.estimation.algos.equal_count import EqualCountClassifier
from rail.estimation.algos.true_nz import TrueNZHistogrammer

from rail.creation.degrader import Degrader

Expand Down Expand Up @@ -59,6 +60,7 @@ def import_and_attach_all():
"VarInfStackSummarizer",
"UniformBinningClassifier",
"EqualCountClassifier",
"TrueNZHistogrammer",
"Degrader",
"AddColumnOfRandom",
"QuantityCut",
Expand Down
53 changes: 53 additions & 0 deletions tests/estimation/test_true_nz.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
import os
import numpy as np
import pytest
import qp

from rail.utils.path_utils import RAILDIR
from rail.core.stage import RailStage
from rail.core.data import TableHandle
from rail.estimation.algos.true_nz import TrueNZHistogrammer


DS = RailStage.data_store
DS.__class__.allow_overwrite = True

true_nz_file = "src/rail/examples_data/testdata/validation_10gal.hdf5"
tomo_file = "src/rail/examples_data/testdata/output_tomo.hdf5"


def test_true_nz():
DS.clear()
true_nz = DS.read_file('true_nz', path=true_nz_file, handle_class=TableHandle)
tomo_bins = DS.read_file('tomo_bins', path=tomo_file, handle_class=TableHandle)

nz_hist = TrueNZHistogrammer.make_stage(
name='true_nz',
hdf5_groupname='photometry',
redshift_col='redshift',
zmin=0.0,
zmax=3.0,
nzbins=301,
)
out_hist = nz_hist.histogram(true_nz, tomo_bins)

check_ens = qp.read(out_hist.path)
assert check_ens.ancil['n_total'][0] == 10

check_vals = [0, 5, 0, 0, 0]

for i in range(5):
nz_hist = TrueNZHistogrammer.make_stage(
name='true_nz',
hdf5_groupname='photometry',
redshift_col='redshift',
zmin=0.0,
zmax=3.0,
nzbins=301,
selected_bin=i,
)
out_hist = nz_hist.histogram(true_nz, tomo_bins)
check_ens = qp.read(out_hist.path)
assert check_ens.ancil['n_total'][0] == check_vals[i]


0 comments on commit 0222836

Please sign in to comment.