Skip to content

Commit

Permalink
Issue/124/mask summarizers (#130)
Browse files Browse the repository at this point in the history
* WIP, first go at doing masked summarizers

* WIP, delinting

* Fixes to masked summarizers

* Added summarize() method overrides

* added tests for masked summarizers

* Added option to set tomography_bins to 'none' in naive_stack and point_est_hist, and tests for same

* Fix up coverage
  • Loading branch information
eacharles authored Jul 29, 2024
1 parent 284a4bd commit bb064dd
Show file tree
Hide file tree
Showing 5 changed files with 226 additions and 23 deletions.
6 changes: 3 additions & 3 deletions src/rail/estimation/algos/equal_count.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import numpy as np
from ceci.config import StageParameter as Param
from rail.estimation.classifier import PZClassifier
from rail.core.data import TableHandle
from rail.core.data import Hdf5Handle


class EqualCountClassifier(PZClassifier):
Expand All @@ -27,14 +27,14 @@ class EqualCountClassifier(PZClassifier):
nbins=Param(int, 5, msg="Number of tomographic bins"),
no_assign=Param(int, -99, msg="Value for no assignment flag"),
)
outputs = [("output", TableHandle)]
outputs = [("output", Hdf5Handle)]

def run(self):
test_data = self.get_data("input")
npdf = test_data.npdf

try:
zb = test_data.ancil[self.config.point_estimate]
zb = np.squeeze(test_data.ancil[self.config.point_estimate])
except KeyError as msg:
raise KeyError(
f"{self.config.point_estimate} is not contained in the data ancil, "
Expand Down
93 changes: 84 additions & 9 deletions src/rail/estimation/algos/naive_stack.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from ceci.config import StageParameter as Param
from rail.estimation.summarizer import PZSummarizer
from rail.estimation.informer import PzInformer
from rail.core.data import QPHandle
from rail.core.data import QPHandle, TableHandle


class NaiveStackInformer(PzInformer):
Expand Down Expand Up @@ -40,8 +40,16 @@ def __init__(self, args, **kwargs):
super().__init__(args, **kwargs)
self.zgrid = None

def _setup_iterator(self):
    """Iterate over the "input" data, attaching a keep-everything mask.

    Yields
    ------
    tuple
        ``(start, end, data, mask)`` for every ``(start, end, data)``
        triple produced by ``self.input_iterator("input")``. ``mask`` is a
        boolean array of length ``end - start`` with every entry True, so
        no row of the chunk is excluded.
    """
    for start, end, chunk in self.input_iterator("input"):
        keep_all = np.full(end - start, True, dtype=bool)
        yield start, end, chunk, keep_all


def run(self):
iterator = self.input_iterator("input")
handle = self.get_handle("input", allow_missing=True)
self._input_length = handle.size()
iterator = self._setup_iterator()
self.zgrid = np.linspace(
self.config.zmin, self.config.zmax, self.config.nzbins + 1
)
Expand All @@ -51,9 +59,9 @@ def run(self):
bootstrap_matrix = self._broadcast_bootstrap_matrix()

first = True
for s, e, test_data in iterator:
for s, e, test_data, mask in iterator:
print(f"Process {self.rank} running estimator on chunk {s} - {e}")
self._process_chunk(s, e, test_data, first, bootstrap_matrix, yvals, bvals)
self._process_chunk(s, e, test_data, mask, first, bootstrap_matrix, yvals, bvals)
first = False
if self.comm is not None: # pragma: no cover
bvals, yvals = self._join_histograms(bvals, yvals)
Expand All @@ -66,15 +74,82 @@ def run(self):
self.add_data("output", sample_ens)
self.add_data("single_NZ", qp_d)

def _process_chunk(self, start, end, data, _first, bootstrap_matrix, yvals, bvals):
def _process_chunk(self, start, end, data, mask, _first, bootstrap_matrix, yvals, bvals):
pdf_vals = data.pdf(self.zgrid)
squeeze_mask = np.squeeze(mask)
yvals += np.expand_dims(
np.sum(np.where(np.isfinite(pdf_vals), pdf_vals, 0.0), axis=0), 0
np.sum(np.where(np.isfinite(pdf_vals[squeeze_mask,:]), pdf_vals[squeeze_mask], 0.0), axis=0), 0
)
# qp_d is the normalized probability of the stack, we need to know how many galaxies were
for i in range(self.config.nsamples):
bootstrap_draws = bootstrap_matrix[:, i]
# Not all of the bootstrap draws fall inside this chunk, and those that do must be shifted by "start" to index into it
mask = (bootstrap_draws >= start) & (bootstrap_draws < end)
bootstrap_draws = bootstrap_draws[mask] - start
bvals[i] += np.sum(pdf_vals[bootstrap_draws], axis=0)
chunk_mask = (bootstrap_draws >= start) & (bootstrap_draws < end)
bootstrap_draws = bootstrap_draws[chunk_mask] - start
zarr = np.where(np.expand_dims(mask, -1), pdf_vals, 0.)[bootstrap_draws]
bvals[i] += np.sum(zarr, axis=0)


class NaiveStackMaskedSummarizer(NaiveStackSummarizer):
    """Variant of NaiveStackSummarizer that can restrict itself to one bin.

    When a ``tomography_bins`` input is configured and ``selected_bin`` is
    not -1, each chunk is paired with a mask that is True where the
    ``class_id`` column of the tomography table equals ``selected_bin``.
    Otherwise every row of every chunk is kept.
    """

    name = "NaiveStackMaskedSummarizer"
    config_options = NaiveStackSummarizer.config_options.copy()
    config_options.update(
        selected_bin=Param(int, -1, msg="bin to use"),
    )
    inputs = [("input", QPHandle), ("tomography_bins", TableHandle)]
    outputs = [("output", QPHandle), ("single_NZ", QPHandle)]

    def _setup_iterator(self):
        """Yield ``(start, end, pz_data, mask)`` for each chunk of the input.

        The start/end indices and the data come from the "input" iterator.
        The mask is all True (length ``pz_data.npdf``) when no bin selection
        is active, and ``class_id == selected_bin`` from the matching
        "tomography_bins" chunk otherwise.
        """
        bin_id = self.config.selected_bin
        no_bins_configured = self.config.tomography_bins in ['none', None]

        if no_bins_configured or bin_id == -1:
            # No selection requested: keep every object in every chunk.
            for start, end, pz_data in self.input_iterator('input'):
                yield start, end, pz_data, np.ones(pz_data.npdf, dtype=bool)
            return

        # Walk both inputs in lockstep; indices are taken from the p(z) chunk.
        paired_chunks = zip(
            self.input_iterator('input'),
            self.input_iterator('tomography_bins'),
        )
        for (start, end, pz_data), (_, _, bin_data) in paired_chunks:
            yield start, end, pz_data, bin_data['class_id'] == bin_id

    def summarize(self, input_data, tomo_bins=None):
        """Run the summarizer, optionally with a tomographic bin table.

        Overrides the base ``summarize()`` so that the tomography bins can
        be passed as a second argument.

        Parameters
        ----------
        input_data : `qp.Ensemble`
            Per-galaxy p(z), and any ancillary data associated with it
        tomo_bins : `table-like`
            Tomographic bins file. If None, the ``tomography_bins`` config
            entry is set to None and no masking is applied.

        Returns
        -------
        output: `qp.Ensemble`
            Handle for the "output" tag, holding the n(z) ensemble
        """
        self.set_data("input", input_data)
        if tomo_bins is not None:
            self.set_data("tomography_bins", tomo_bins)
        else:
            self.config.tomography_bins = None
        self.run()
        self.finalize()
        return self.get_handle("output")


91 changes: 82 additions & 9 deletions src/rail/estimation/algos/point_est_hist.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from ceci.config import StageParameter as Param
from rail.estimation.summarizer import PZSummarizer
from rail.estimation.informer import PzInformer
from rail.core.data import QPHandle
from rail.core.data import QPHandle, TableHandle


class PointEstHistInformer(PzInformer):
Expand Down Expand Up @@ -42,8 +42,15 @@ def __init__(self, args, **kwargs):
self.zgrid = None
self.bincents = None

def _setup_iterator(self):
    """Iterate over the "input" data, attaching a keep-everything mask.

    Yields
    ------
    tuple
        ``(start, end, data, mask)`` for every ``(start, end, data)``
        triple produced by ``self.input_iterator("input")``. ``mask`` is a
        boolean array of length ``end - start`` with every entry True, so
        no row of the chunk is excluded.
    """
    for start, end, chunk in self.input_iterator("input"):
        keep_all = np.full(end - start, True, dtype=bool)
        yield start, end, chunk, keep_all

def run(self):
iterator = self.input_iterator("input")
handle = self.get_handle("input", allow_missing=True)
self._input_length = handle.size()
iterator = self._setup_iterator()
self.zgrid = np.linspace(
self.config.zmin, self.config.zmax, self.config.nzbins + 1
)
Expand All @@ -54,10 +61,10 @@ def run(self):
hist_vals = np.zeros((self.config.nsamples, self.config.nzbins))

first = True
for s, e, test_data in iterator:
for s, e, test_data, mask in iterator:
print(f"Process {self.rank} running estimator on chunk {s} - {e}")
self._process_chunk(
s, e, test_data, first, bootstrap_matrix, single_hist, hist_vals
s, e, test_data, mask, first, bootstrap_matrix, single_hist, hist_vals
)
first = False
if self.comm is not None: # pragma: no cover
Expand All @@ -74,14 +81,80 @@ def run(self):
self.add_data("single_NZ", qp_d)

def _process_chunk(
self, start, end, test_data, _first, bootstrap_matrix, single_hist, hist_vals
self, start, end, test_data, mask, _first, bootstrap_matrix, single_hist, hist_vals
):
zb = test_data.ancil[self.config.point_estimate]
single_hist += np.histogram(zb, bins=self.zgrid)[0]
single_hist += np.histogram(zb[mask], bins=self.zgrid)[0]
for i in range(self.config.nsamples):
bootstrap_indeces = bootstrap_matrix[:, i]
# Not all of the bootstrap draws fall inside this chunk, and those that do must be shifted by "start" to index into it
mask = (bootstrap_indeces >= start) & (bootstrap_indeces < end)
bootstrap_indeces = bootstrap_indeces[mask] - start
zarr = zb[bootstrap_indeces]
chunk_mask = (bootstrap_indeces >= start) & (bootstrap_indeces < end)
bootstrap_indeces = bootstrap_indeces[chunk_mask] - start
zarr = np.where(mask, zb, np.nan)[bootstrap_indeces]
hist_vals[i] += np.histogram(zarr, bins=self.zgrid)[0]


class PointEstHistMaskedSummarizer(PointEstHistSummarizer):
    """Variant of PointEstHistSummarizer that can restrict itself to one bin.

    When a ``tomography_bins`` input is configured and ``selected_bin`` is
    not -1, each chunk is paired with a mask that is True where the
    ``class_id`` column of the tomography table equals ``selected_bin``.
    Otherwise every row of every chunk is kept.
    """

    name = "PointEstHistMaskedSummarizer"
    config_options = PointEstHistSummarizer.config_options.copy()
    config_options.update(
        selected_bin=Param(int, -1, msg="bin to use"),
    )
    inputs = [("input", QPHandle), ("tomography_bins", TableHandle)]
    outputs = [("output", QPHandle), ("single_NZ", QPHandle)]

    def _setup_iterator(self):
        """Yield ``(start, end, pz_data, mask)`` for each chunk of the input.

        The start/end indices and the data come from the "input" iterator.
        The mask is all True (length ``pz_data.npdf``) when no bin selection
        is active, and ``class_id == selected_bin`` from the matching
        "tomography_bins" chunk otherwise.
        """
        bin_id = self.config.selected_bin
        no_bins_configured = self.config.tomography_bins in ['none', None]

        if no_bins_configured or bin_id == -1:
            # No selection requested: keep every object in every chunk.
            for start, end, pz_data in self.input_iterator('input'):
                yield start, end, pz_data, np.ones(pz_data.npdf, dtype=bool)
            return

        # Walk both inputs in lockstep; indices are taken from the p(z) chunk.
        paired_chunks = zip(
            self.input_iterator('input'),
            self.input_iterator('tomography_bins'),
        )
        for (start, end, pz_data), (_, _, bin_data) in paired_chunks:
            yield start, end, pz_data, bin_data['class_id'] == bin_id

    def summarize(self, input_data, tomo_bins=None):
        """Run the summarizer, optionally with a tomographic bin table.

        Overrides the base ``summarize()`` so that the tomography bins can
        be passed as a second argument.

        Parameters
        ----------
        input_data : `qp.Ensemble`
            Per-galaxy p(z), and any ancillary data associated with it
        tomo_bins : `table-like`
            Tomographic bins file. If None, the ``tomography_bins`` config
            entry is set to None and no masking is applied.

        Returns
        -------
        output: `qp.Ensemble`
            Handle for the "output" tag, holding the n(z) ensemble
        """
        self.set_data("input", input_data)
        if tomo_bins is not None:
            self.set_data("tomography_bins", tomo_bins)
        else:
            self.config.tomography_bins = None
        self.run()
        self.finalize()
        return self.get_handle("output")


2 changes: 1 addition & 1 deletion src/rail/estimation/summarizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ def summarize(self, input_data):
self.finalize()
return self.get_handle("output")

def _broadcast_bootstrap_matrix(self):
def _broadcast_bootstrap_matrix(self):
rng = np.random.default_rng(seed=self.config.seed)
# Only one of the nodes needs to produce the bootstrap indices
ngal = self._input_length
Expand Down
57 changes: 56 additions & 1 deletion tests/estimation/test_summarizers.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
import os

from rail.core.data import QPHandle
from rail.core.data import QPHandle, TableHandle
from rail.core.stage import RailStage
from rail.utils.path_utils import RAILDIR
from rail.estimation.algos import naive_stack, point_est_hist, var_inf

testdata = os.path.join(RAILDIR, "rail/examples_data/testdata/output_BPZ_lite.hdf5")
tomobins = os.path.join(RAILDIR, "rail/examples_data/testdata/output_tomo.hdf5")
DS = RailStage.data_store


Expand All @@ -15,6 +16,7 @@ def one_algo(key, summarizer_class, summary_kwargs):
Run summarize
"""
DS.__class__.allow_overwrite = True
DS.clear()
test_data = DS.read_file("test_data", QPHandle, testdata)
summarizer = summarizer_class.make_stage(name=key, **summary_kwargs)
summary_ens = summarizer.summarize(test_data)
Expand All @@ -27,6 +29,37 @@ def one_algo(key, summarizer_class, summary_kwargs):
return summary_ens


def one_mask_algo(key, summarizer_class, summary_kwargs):
    """Run a masked summarizer class twice: with and without a bin table.

    The first stage is built with ``selected_bin=1`` and is given the
    tomography bins; the second (named ``f"{key}_2"``) is given ``None``
    instead. The files written for the "output" and "single_NZ" tags are
    deleted after each run.

    Returns
    -------
    list
        The values returned by the two ``summarize`` calls, masked run first.
    """

    def _remove_outputs(stage):
        # Delete the files behind both output tags, "output" first.
        for tag in ("output", "single_NZ"):
            os.remove(stage.get_output(stage.get_aliased_tag(tag), final_name=True))

    DS.__class__.allow_overwrite = True
    DS.clear()
    test_data = DS.read_file("test_data", QPHandle, testdata)
    tomo_bins = DS.read_file("tomo_bins", TableHandle, tomobins)

    masked_stage = summarizer_class.make_stage(name=key, selected_bin=1, **summary_kwargs)
    masked_ens = masked_stage.summarize(test_data, tomo_bins)
    _remove_outputs(masked_stage)

    unmasked_stage = summarizer_class.make_stage(name=f"{key}_2", **summary_kwargs)
    unmasked_ens = unmasked_stage.summarize(test_data, None)
    _remove_outputs(unmasked_stage)

    return [masked_ens, unmasked_ens]


def test_naive_stack():
"""Basic end to end test for the Naive stack informer to estimator stages"""
naive_stack_informer_stage = naive_stack.NaiveStackInformer.make_stage()
Expand Down Expand Up @@ -57,3 +90,25 @@ def test_var_inference_stack():
summary_config_dict = {}
summarizer_class = var_inf.VarInfStackSummarizer
_ = one_algo("VariationalInference", summarizer_class, summary_config_dict)


def test_naive_stack_masked():
    """Run NaiveStackMaskedSummarizer through the masked and plain helpers."""
    summary_config_dict = {"chunk_size": 5}
    summarizer_class = naive_stack.NaiveStackMaskedSummarizer
    # Masked helper first, then the single-input helper, as two separate runs.
    for run_helper in (one_mask_algo, one_algo):
        _ = run_helper("NaiveStack", summarizer_class, summary_config_dict)


def test_point_estimate_hist_masekd():
    """Run PointEstHistMaskedSummarizer through the masked and plain helpers."""
    summary_config_dict = {"chunk_size": 5}
    summarizer_class = point_est_hist.PointEstHistMaskedSummarizer
    # Masked helper first, then the single-input helper, as two separate runs.
    for run_helper in (one_mask_algo, one_algo):
        _ = run_helper("PointEstimateHist", summarizer_class, summary_config_dict)

0 comments on commit bb064dd

Please sign in to comment.