-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
4 changed files
with
180 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,125 @@ | ||
""" | ||
A summarizer-like stage that simple makes a histogram the true nz | ||
""" | ||
|
||
import numpy as np | ||
import qp | ||
|
||
from ceci.config import StageParameter as Param | ||
from rail.core.common_params import SHARED_PARAMS | ||
from rail.core.stage import RailStage | ||
from rail.core.data import QPHandle, TableHandle | ||
|
||
|
||
class TrueNZHistogrammer(RailStage): | ||
"""Summarizer-like stage which simply histograms the true redshift""" | ||
|
||
name = "TrueNZHistogrammer" | ||
config_options = RailStage.config_options.copy() | ||
config_options.update( | ||
zmin=SHARED_PARAMS, | ||
zmax=SHARED_PARAMS, | ||
nzbins=SHARED_PARAMS, | ||
redshift_col=SHARED_PARAMS, | ||
selected_bin=Param(int, -1, msg="Which tomography bin to consider"), | ||
chunk_size=10000, | ||
hdf5_groupname="", | ||
) | ||
inputs = [("input", TableHandle), ("tomography_bins", TableHandle)] | ||
outputs = [("true_NZ", QPHandle)] | ||
|
||
def __init__(self, args, comm=None): | ||
RailStage.__init__(self, args, comm=comm) | ||
self.zgrid = None | ||
self.bincents = None | ||
|
||
def _setup_iterator(self): | ||
|
||
itrs = [ | ||
self.input_iterator('input', groupname=self.config.hdf5_groupname), | ||
self.input_iterator('tomography_bins', groupname=""), | ||
] | ||
|
||
for it in zip(*itrs): | ||
first = True | ||
mask = None | ||
for s, e, d in it: | ||
if first: | ||
start = s | ||
end = e | ||
pz_data = d | ||
first = False | ||
else: | ||
if self.config.selected_bin < 0: | ||
mask = np.ones(e-s, dtype=bool) | ||
else: | ||
mask = d['class_id'] == self.config.selected_bin | ||
yield start, end, pz_data, mask | ||
|
||
def run(self): | ||
iterator = self._setup_iterator() | ||
self.zgrid = np.linspace( | ||
self.config.zmin, self.config.zmax, self.config.nzbins + 1 | ||
) | ||
self.bincents = 0.5 * (self.zgrid[1:] + self.zgrid[:-1]) | ||
# Initiallizing the histograms | ||
single_hist = np.zeros(self.config.nzbins) | ||
|
||
first = True | ||
for s, e, data, mask in iterator: | ||
print(f"Process {self.rank} running estimator on chunk {s} - {e}") | ||
self._process_chunk( | ||
s, e, data, mask, first, single_hist | ||
) | ||
first = False | ||
if self.comm is not None: # pragma: no cover | ||
single_hist = self.comm.reduce(single_hist) | ||
|
||
if self.rank == 0: | ||
n_total = single_hist.sum() | ||
qp_d = qp.Ensemble( | ||
qp.hist, | ||
data=dict(bins=self.zgrid, pdfs=np.atleast_2d(single_hist)), | ||
ancil=dict(n_total=np.array([n_total], dtype=int)), | ||
) | ||
self.add_data("true_NZ", qp_d) | ||
|
||
def _process_chunk( | ||
self, _start, _end, data, mask, _first, single_hist, | ||
): | ||
squeeze_mask = np.squeeze(mask) | ||
zb = data[self.config.redshift_col][squeeze_mask] | ||
single_hist += np.histogram(zb, bins=self.zgrid)[0] | ||
|
||
def histogram(self, catalog, tomo_bins): | ||
"""The main interface method for ``TrueNZHistogrammer``. | ||
Creates histogram of N of Z_true. | ||
This will attach the sample to this `Stage` (for introspection and | ||
provenance tracking). | ||
Then it will call the run() and finalize() methods, which need to be | ||
implemented by the sub-classes. | ||
The run() method will need to register the data that it creates to this | ||
Estimator by using ``self.add_data('output', output_data)``. | ||
Finally, this will return a PqHandle providing access to that output | ||
data. | ||
Parameters | ||
---------- | ||
catalog : table-like | ||
The sample with the true NZ column | ||
Returns | ||
------- | ||
output_data : QPHandle | ||
A handle giving access to a the histogram in QP format | ||
""" | ||
self.set_data("input", catalog) | ||
self.set_data("tomography_bins", tomo_bins) | ||
self.run() | ||
self.finalize() | ||
return self.get_handle("true_NZ") |
Binary file not shown.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,53 @@ | ||
import os | ||
import numpy as np | ||
import pytest | ||
import qp | ||
|
||
from rail.utils.path_utils import RAILDIR | ||
from rail.core.stage import RailStage | ||
from rail.core.data import TableHandle | ||
from rail.estimation.algos.true_nz import TrueNZHistogrammer | ||
|
||
|
||
DS = RailStage.data_store | ||
DS.__class__.allow_overwrite = True | ||
|
||
true_nz_file = "src/rail/examples_data/testdata/validation_10gal.hdf5" | ||
tomo_file = "src/rail/examples_data/testdata/output_tomo.hdf5" | ||
|
||
|
||
def test_true_nz(): | ||
DS.clear() | ||
true_nz = DS.read_file('true_nz', path=true_nz_file, handle_class=TableHandle) | ||
tomo_bins = DS.read_file('tomo_bins', path=tomo_file, handle_class=TableHandle) | ||
|
||
nz_hist = TrueNZHistogrammer.make_stage( | ||
name='true_nz', | ||
hdf5_groupname='photometry', | ||
redshift_col='redshift', | ||
zmin=0.0, | ||
zmax=3.0, | ||
nzbins=301, | ||
) | ||
out_hist = nz_hist.histogram(true_nz, tomo_bins) | ||
|
||
check_ens = qp.read(out_hist.path) | ||
assert check_ens.ancil['n_total'][0] == 10 | ||
|
||
check_vals = [0, 5, 0, 0, 0] | ||
|
||
for i in range(5): | ||
nz_hist = TrueNZHistogrammer.make_stage( | ||
name='true_nz', | ||
hdf5_groupname='photometry', | ||
redshift_col='redshift', | ||
zmin=0.0, | ||
zmax=3.0, | ||
nzbins=301, | ||
selected_bin=i, | ||
) | ||
out_hist = nz_hist.histogram(true_nz, tomo_bins) | ||
check_ens = qp.read(out_hist.path) | ||
assert check_ens.ancil['n_total'][0] == check_vals[i] | ||
|
||
|