Skip to content

Commit

Permalink
Added TrueNZHistogrammer stage
Browse files Browse the repository at this point in the history
  • Loading branch information
eacharles committed Jun 27, 2024
1 parent b0421b6 commit 05f9a17
Show file tree
Hide file tree
Showing 2 changed files with 82 additions and 0 deletions.
80 changes: 80 additions & 0 deletions src/rail/estimation/algos/true_nz.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
"""
A summarizer-like stage that simple makes a histogram the true nz
"""

import numpy as np
import qp

from rail.core.common_params import SHARED_PARAMS
from rail.core.stage import RailStage
from rail.core.data import QPHandle, TableHandle


class TrueNZHistogrammer(RailStage):
"""Summarizer-like stage which simply histograms the true redshift"""

name = "TrueNZHistogrammer"
config_options = RailStage.config_options.copy()
config_options.update(
zmin=SHARED_PARAMS,
zmax=SHARED_PARAMS,
nzbins=SHARED_PARAMS,
redshift_col=SHARED_PARAMS,
)
inputs = [("input", TableHandle), ("tomography_bins", TableHandle)]
outputs = [("true_NZ", QPHandle)]

def __init__(self, args, comm=None):
RailStage.__init__(self, args, comm=comm)
self.zgrid = None
self.bincents = None

def _setup_iterator(self):
itrs = [self.input_iterator('input'), self.input_iterator('tomography_bins')]

for it in zip(*itrs):
first = True
mask = None
for s, e, d in it:
if first:
start = s
end = e
pz_data = d
first = False
else:
if self.config.selected_bin < 0:
mask = np.ones(len(d))
else:
mask = d['class_id'] == self.config.selected_bin
yield start, end, pz_data, mask

def run(self):
iterator = self._setup_iterator()
self.zgrid = np.linspace(
self.config.zmin, self.config.zmax, self.config.nzbins + 1
)
self.bincents = 0.5 * (self.zgrid[1:] + self.zgrid[:-1])
# Initiallizing the histograms
single_hist = np.zeros(self.config.nzbins)

first = True
for s, e, data, mask in iterator:
print(f"Process {self.rank} running estimator on chunk {s} - {e}")
self._process_chunk(
s, e, data, mask, first, single_hist
)
first = False
if self.comm is not None: # pragma: no cover
single_hist = self.comm.reduce(single_hist)

if self.rank == 0:
qp_d = qp.Ensemble(
qp.hist, data=dict(bins=self.zgrid, pdfs=np.atleast_2d(single_hist))
)
self.add_data("true_NZ", qp_d)

def _process_chunk(
self, _start, _end, data, mask, _first, single_hist,
):
zb = data[self.config.redshift_col][mask]
single_hist += np.histogram(zb, bins=self.zgrid)[0]
2 changes: 2 additions & 0 deletions src/rail/stages/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from rail.estimation.algos.var_inf import VarInfStackInformer, VarInfStackSummarizer
from rail.estimation.algos.uniform_binning import UniformBinningClassifier
from rail.estimation.algos.equal_count import EqualCountClassifier
from rail.estimation.algos.true_nz import TrueNZHistogrammer

from rail.creation.degrader import Degrader

Expand Down Expand Up @@ -59,6 +60,7 @@ def import_and_attach_all():
"VarInfStackSummarizer",
"UniformBinningClassifier",
"EqualCountClassifier",
"TrueNZHistogrammer",
"Degrader",
"AddColumnOfRandom",
"QuantityCut",
Expand Down

0 comments on commit 05f9a17

Please sign in to comment.